Compare commits

...

111 Commits

Author SHA1 Message Date
Soulter 37a1f144ab chore: update changelog of 3.4.15 2025-01-30 00:32:50 +08:00
Soulter 9a7a654596 perf: 插件处于禁用状态时其所属的函数调用工具不可被启用 #254 2025-01-30 00:27:10 +08:00
Soulter 9abccd63cf chore: remove stt.py 2025-01-29 23:47:50 +08:00
Soulter 93fea77182 chore: bump to v3.4.15 2025-01-29 23:43:09 +08:00
Soulter 19797243f6 perf: 增加插件链接 2025-01-29 19:56:09 +08:00
Soulter c9c733d925 Merge branch 'dev' 2025-01-29 19:43:52 +08:00
Soulter a7d7678c78 fix: 修复白名单为空时依然终止事件 #259 2025-01-29 17:17:27 +08:00
Soulter c0911921c7 feat: 配置Schema以及插件支持配置 2025-01-29 16:54:57 +08:00
Soulter 4a4241d57a Update README.md 2025-01-29 13:26:51 +08:00
Soulter c9426bb6eb config 2025-01-29 12:25:54 +08:00
Soulter db4abd169a fix: 优化分段回复 2025-01-28 14:42:15 +08:00
Soulter 80b6958599 fix: 修复 config validator 不起效的问题 2025-01-28 14:18:21 +08:00
Soulter 80058c781a fix: 修复r1思考标签问题和分段回复间隔时间问题 2025-01-28 14:03:10 +08:00
Soulter 44bd2e36f3 Update README.md 2025-01-28 02:15:11 +08:00
Soulter 3589a5e5be perf: 强化ltm异常处理 2025-01-27 21:47:35 +08:00
Soulter 13ef033f0e fix: 群聊增强的参数类型转换 2025-01-27 21:40:20 +08:00
Soulter 3f8c68bbca fix: f-string expression part cannot include a backslash
long_term_memory.py, line 69
2025-01-27 21:01:50 +08:00
Soulter 4275cea82b chore: v3.4.14 2025-01-27 20:09:03 +08:00
Soulter a0bcb5339a perf: 自动删除 deepseek-r1 模型自带的 think 标签 2025-01-27 20:04:39 +08:00
Soulter 43deec4a4b Merge pull request #255 from Soulter/feat-ltm
支持记录非唤醒状态下群聊历史记录
2025-01-27 20:02:43 +08:00
Soulter 2bc433a30b feat: 支持记录非唤醒状态下群聊历史记录 2025-01-27 20:00:32 +08:00
Soulter eb2b395932 perf: /t2i 即时生效 2025-01-27 19:33:38 +08:00
Soulter 2bfd1c0bf2 perf: 自动移除 ollama 不支持 tool 的模型的 tool 请求 2025-01-27 19:25:28 +08:00
Soulter 7228c4b13f fix: 修复 TTS 部分变量名错误导致请求失败 2025-01-27 18:45:34 +08:00
Soulter 9351d7471f perf: 优化 gewechat 消息下发异常处理 2025-01-27 18:11:31 +08:00
Soulter 1cf49998bc Update README.md 2025-01-27 11:34:27 +08:00
Soulter 6ae86597e8 chore: v3.4.13 2025-01-26 16:51:13 +08:00
Soulter c578ff25bd fix: stt_enabled 未初始化 #252 2025-01-26 16:51:02 +08:00
Soulter 2934a3e3be chore: logo 2025-01-26 15:18:23 +08:00
Soulter ceaa69da75 feat: 支持消息分段回复 2025-01-26 13:45:32 +08:00
Soulter fa8e731576 Update README.md 2025-01-25 22:45:47 +08:00
Soulter 685c0a106a perf: use pysilk instead of pilk 避免构建问题 2025-01-25 20:18:40 +08:00
Soulter 7f539090dd perf: 更新项目时连带更新依赖 2025-01-25 20:04:28 +08:00
Soulter 2089273f95 Merge pull request #251 from Soulter/feat-tts
适配 OpenAI TTS API,并支持 Napcat,Gewechat,Lagrange 的语音输出
2025-01-25 19:51:22 +08:00
Soulter 838bb4c7ad chore: remove duration 2025-01-25 19:49:53 +08:00
Soulter 637acd1a12 feat: 适配 OpenAI TTS API,并支持 Napcat,Gewechat,Lagrange 的语音输出 2025-01-25 19:46:00 +08:00
Soulter 03fa9a847f feat: gewechat 支持语音、图片 2025-01-25 16:34:40 +08:00
Soulter d488c88e78 feat: 支持路径映射,解决docker部署两端文件系统不一致导致的富媒体文件路径不存在问题 2025-01-24 14:08:08 +08:00
Soulter baae842210 fix: napcat 下语音消息接收异常 2025-01-24 13:41:13 +08:00
Soulter ec1fb838b6 perf: notice 2025-01-22 21:38:05 +08:00
Soulter 13281179df perf: notice 2025-01-22 21:36:28 +08:00
Soulter 276a42c9a1 Bump to 3.4.11 2025-01-22 21:16:24 +08:00
Soulter 7a70a730ba perf: 任务报错后的优雅报错输出 2025-01-22 21:14:26 +08:00
Soulter d0fe59631c perf: 优化更新项目时重启可能会导致Address already in use的问题 2025-01-22 20:57:15 +08:00
Soulter 106892e933 fix: 修复appid保存的问题和部分群聊at失效的问题和群聊@的sender username显示异常的问题 2025-01-22 20:34:52 +08:00
Soulter 19543a41b3 Update README.md 2025-01-22 19:56:07 +08:00
Soulter b172b760ab feat: 为平台和提供商适配器添加默认 ID 配置 #248 2025-01-22 16:52:34 +08:00
Soulter 4b5d49cb41 Bump to 3.4.10 2025-01-22 00:19:20 +08:00
Soulter 3fd35b6058 feat: 管理面板更新面板按钮 #245 2025-01-22 00:17:43 +08:00
Soulter 5f86c4ab99 perf: 增强 LLM 请求错误处理 #243 2025-01-21 16:29:19 +08:00
Soulter c94a7f6629 perf: 针对 api_base 的明显提示,修改 ollama 模板的api_base #247 2025-01-21 16:15:04 +08:00
Soulter 7d6beb4141 fix: QQ 图片发送不了 #246 2025-01-21 16:12:10 +08:00
Soulter e2117e690a feat: 支持登出gewechat 2025-01-21 13:12:09 +08:00
Soulter fb791290e2 fix: 添加gewechat适配器过滤器 2025-01-21 12:39:57 +08:00
Soulter 5dd1488b5d perf: 优化webui和主程序更新的协调
fix: 修复某些请求不能正确应用代理的问题
2025-01-21 01:08:15 +08:00
Soulter 529cd64d82 perf: help显示AstrBot和webui版本 2025-01-21 00:10:59 +08:00
Soulter d2bd3e8da8 bump to v3.4.9 2025-01-20 23:35:34 +08:00
Soulter e42ce7dd86 perf: 优化了用户体验 2025-01-20 23:27:13 +08:00
Soulter 40709462ee chore: bump domain to astrbot.app 2025-01-20 19:02:54 +08:00
Soulter 2ad6c01a4d Update README.md 2025-01-20 15:48:39 +08:00
Soulter 70c12e788e feat: LLM额外唤醒词与机器人唤醒词冲突时的处理 2025-01-20 10:22:25 +08:00
Soulter 1713791c90 docs: update webui demo 2025-01-20 00:46:29 +08:00
Soulter 9aa23fd412 Update README.md 2025-01-19 21:32:42 +08:00
Soulter e4ba09cd93 chore: remove package-lock.json 2025-01-19 18:20:40 +08:00
Soulter 171fdf1fbc fix: 消息链无元素时仍然插入了@和回复 2025-01-18 23:25:42 +08:00
Soulter 01f4e0b961 feat: gewechat 主动消息 2025-01-18 22:31:17 +08:00
Soulter be2d5a91c7 chore: bump to v3.4.8 2025-01-18 22:19:35 +08:00
Soulter a1d89d9478 Merge pull request #242 from Soulter/feat-gewechat
初步接入 gewechat 文字交互
2025-01-18 22:16:53 +08:00
Soulter 98d1dc3b65 feat: 初步接入 gewechat 文字交互 2025-01-18 22:01:36 +08:00
Soulter b80eb3acc0 feat: 支持回复时 At 和引用发送者 #241 2025-01-18 17:31:11 +08:00
Soulter 05ccc1995b fix: 清除残留的 personalities 2025-01-18 17:31:11 +08:00
Soulter 0de244889e chore: gitsponsors 2025-01-18 10:54:37 +08:00
Soulter e6c5c3a493 chore: bump to v3.4.7 2025-01-16 11:26:05 +08:00
Soulter 164aa2ccd2 Merge pull request #240 from Soulter/feat-better-persona
feat: 更好的人格情景管理
2025-01-16 11:20:28 +08:00
Soulter f1599e26b3 perf: webchat 主动信息 2025-01-16 11:19:02 +08:00
Soulter ed64a4d32d chore: 整理hint 2025-01-16 11:11:30 +08:00
Soulter 2ee4b431d4 fix: 无tool导致的报错 #239 2025-01-15 11:16:31 +08:00
Soulter cd8a73ed19 feat: 更好的人格情景管理和管理面板支持删除列表默认模版项 2025-01-14 21:08:57 +08:00
Soulter e6c985ce4e feat: 优化WebChat长连接的逻辑 2025-01-13 12:42:32 +08:00
Soulter a20446aeb9 🎉 chore: bump to v3.4.6 2025-01-13 02:17:23 +08:00
Soulter 7b23d76559 feat: 支持并完善服务提供商默认配置模板接口 2025-01-13 02:05:57 +08:00
Soulter 8315cf5818 perf: 面板文件更新检查和引导提示和AboutPage 2025-01-12 13:01:40 +08:00
Soulter ed16265bde fix: 更新官方文档链接并优化管理面板版本检查日志 2025-01-12 12:23:27 +08:00
Soulter dff205faf6 feat: 添加聊天功能路由和更新管理面板命令 2025-01-12 12:18:19 +08:00
Soulter 9aae8aee0c Update README.md 2025-01-12 11:45:39 +08:00
Soulter 7c818ced2b perf: 文件和语音功能适配 Lagrange 2025-01-12 11:44:33 +08:00
Soulter 218e887558 fix: download_file 修复 SSL 连接错误处理 2025-01-12 11:44:33 +08:00
Soulter a68860b35a chore: compress the banner 2025-01-12 10:52:17 +08:00
Soulter 82d4d43383 🎉 Bump to v3.4.5 2025-01-11 23:35:22 +08:00
Soulter 94618e8feb feat: 添加 aiodocker 依赖 2025-01-11 22:02:15 +08:00
Soulter 55de7d4494 🎉 Bump to v3.4.5 2025-01-11 21:40:48 +08:00
Soulter 7ed639f741 🎉 bump to v3.4.5 2025-01-11 21:06:06 +08:00
Soulter 41f2870c29 Merge pull request #236 from Soulter/feat-stt
支持 Speech To Text,并适配腾讯修改过的 Silk 语音格式
2025-01-11 21:00:04 +08:00
Soulter ba198490fa feat: 支持自部署 Whisper 模型 2025-01-11 20:31:21 +08:00
Soulter 0f9ab082ab perf: 优化webchat,没有结果返回时的反馈 2025-01-11 19:45:42 +08:00
Soulter 97b58965f2 feat: webchat可显示Provider状态 2025-01-11 19:31:56 +08:00
Soulter f2566c68e3 feat: 按 K 语音 2025-01-11 19:07:26 +08:00
Soulter a456bf5449 fix: 初始化reminder时的一些问题 2025-01-11 18:55:18 +08:00
Soulter a09998f910 feat: webchat 支持语音输入 2025-01-11 18:54:40 +08:00
Soulter be662b913c feat: 支持 Whisper STT,并适配 Tencent 语音格式 2025-01-11 17:19:28 +08:00
Soulter e7ddc8448d perf: 代码执行器在成功执行后清空文件buffer 2025-01-11 11:31:56 +08:00
Soulter 29374f8d8a fix: 修复 /dashbord_update 指令 2025-01-11 00:25:02 +08:00
Soulter 359b971103 Merge pull request #235 from Soulter/feat-webchat
WebChat 支持
2025-01-11 00:17:18 +08:00
Soulter fbdb1ae208 chore: bump to v3.4.4 2025-01-11 00:14:08 +08:00
Soulter 22c13c1eff perf: webchat支持传图 2025-01-11 00:06:19 +08:00
Soulter 5fc63aeaf1 perf: ui 2025-01-10 22:45:14 +08:00
Soulter d4f32673ab fix: 修复持久化问题 2025-01-10 22:08:43 +08:00
Soulter 480dffb51b feat: 初步实现 webchat 页面 2025-01-10 21:48:15 +08:00
Soulter 966df00124 feat: 支持从管理面板(控制台页)手动安装 pip 库 2025-01-10 15:35:57 +08:00
Soulter 3e2b4bc727 feat: 支持动态设置会话变量以适用 Dify 输入变量 2025-01-10 12:32:20 +08:00
Soulter 5929a8d42b Update README.md 2025-01-09 23:11:11 +08:00
110 changed files with 4199 additions and 10720 deletions
+2 -1
View File
@@ -20,4 +20,5 @@ chroma
node_modules/
.DS_Store
package-lock.json
package.json
package.json
venv/*
+95 -43
View File
@@ -1,12 +1,12 @@
<p align="center">
<img width=200 src="https://github.com/user-attachments/assets/3dd6a669-0830-4db4-b821-c8b279ea19a6"/>
![logo](https://github.com/user-attachments/assets/07649e07-3b8e-4feb-9aa9-bf13af4f3476)
</p>
<div align="center">
<h1>AstrBot</h1>
_✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot)](https://github.com/Soulter/AstrBot/releases/latest)
@@ -15,42 +15,104 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-322154837-purple">
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
[![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot)
[<img src="https://api.gitsponsors.com/api/badge/img?id=575865240" height="20">](https://api.gitsponsors.com/api/badge/link?p=XEpbdGxlitw/RbcwiTX93UMzNK/jgDYC8NiSzamIPMoKvG2lBFmyXhSS/b0hFoWlBBMX2L5X5CxTDsUdyvcIEHTOfnkXz47UNOZvMwyt5CzbYpq0SEzsSV1OJF1cCo90qC/ZyYKYOWedal3MhZ3ikw==)
</a>
<a href="https://astrbot.lwl.lol/">查看文档</a>
<a href="https://astrbot.app/">查看文档</a>
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
</div>
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型(LLM)接入功能的聊天机器人及开发框架。
## ✨ 多消息平台部署
## ✨ 主要功能
1. QQ 群、QQ 频道、微信个人号、Telegram
2. 支持文本转图片,Markdown 渲染
## ✨ 多 LLM 配置
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力,支持图片理解、语音转文字(Whisper)
2. **多消息平台接入**。支持接入 QQ(OneBot)、QQ 频道、微信(Gewechat、VChat)、Telegram。后续将支持钉钉、飞书、Discord、WhatsApp、小爱音响。支持速率限制、白名单、关键词过滤、百度内容审核
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://astrbot.app/others/dify.html),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat,可在面板上与大模型对话。
6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
1. 适配 OpenAI API,支持接入 Gemini、GPT、Llama、Claude、DeepSeek、GLM 等各种大语言模型。
2. 支持 OneAPI 等分发平台。
3. 支持 LLMTuner 载入微调模型。
4. 支持 Ollama 载入自部署模型。
4. 支持网页搜索(Web Search)、自然语言待办提醒。
> [!TIP]
> 管理面板在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
>
> 用户名: `astrbot`, 密码: `astrbot`。未配置 LLM,无法在聊天页使用大模型。(不要再修改 demo 的登录密码了 😭)
## ✨ 管理面板
## ✨ 使用方式
1. 支持可视化修改配置
2. 日志实时查看
3. 简单的信息统计
4. 插件管理
#### Docker 部署
## ✨ 支持 Dify
请参阅官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) 。
1. 对接了 LLMOps 平台 Dify,便捷接入 Dify 智能助手、知识库和 Dify 工作流!
#### Windows 一键安装器部署
需要电脑上安装有 Python(>3.10)。请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
#### Replit 部署
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
#### CasaOS 部署
社区贡献的部署方式。
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/casaos.html) 。
#### 手动部署
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
## ⚡ 消息平台支持情况
| 平台 | 支持性 | 详情 | 消息类型 |
| -------- | ------- | ------- | ------ |
| QQ | ✔ | 私聊、群聊 | 文字、图片、语音 |
| QQ 官方API | ✔ | 私聊、群聊,QQ 频道私聊、群聊 | 文字、图片 |
| 微信 | ✔ | [Gewechat](https://github.com/Devo919/Gewechat)。微信个人号私聊、群聊 | 文字、图片、语音 |
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 |
| 微信对话开放平台 | 🚧 | 计划内 | - |
| 飞书 | 🚧 | 计划内 | - |
| Discord | 🚧 | 计划内 | - |
| WhatsApp | 🚧 | 计划内 | - |
| 小爱音响 | 🚧 | 计划内 | - |
# 🦌 接下来的路线图
> [!TIP]
> 欢迎在 Issue 提出更多建议 <3
- [ ] 完善并保证目前所有平台适配器的功能一致性
- [ ] 优化插件接口
- [ ] 默认支持更多 TTS 服务,如 GPT-Sovits
- [ ] 完善“聊天增强”部分,支持持久化记忆
- [ ] 规划 i18n
## ❤️ 贡献
欢迎任何 Issues/Pull Requests!只需要将你的更改提交到此项目 :)
对于新功能的添加,请先通过 Issue 讨论。
## 🌟 支持
- Star 这个项目!
- 在[爱发电](https://afdian.com/a/soulter)支持我!
- 在[微信](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)支持我~
## ✨ Demo
> [!NOTE]
> 代码执行器的文件输入/输出目前仅测试了 Napcat(QQ), Lagrange(QQ)
<div align='center'>
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
_✨基于 Docker 的沙箱化代码执行器(Beta 测试中)✨_
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
@@ -64,16 +126,23 @@ _✨ 自然语言待办事项 ✨_
_✨ 插件系统——部分插件展示 ✨_
<img src="https://github.com/user-attachments/assets/caadf2bd-a0ee-43d0-a95e-566d63e3e34d" height=330>
<img src="https://github.com/user-attachments/assets/b418f281-e920-49db-9fe1-d6a13ce28a84" height=350>
<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width=600>
_✨ 管理面板 ✨_
![webchat](https://drive.soulter.top/f/vlsA/ezgif-5-fb044b2542.gif)
_✨ 内置 Web Chat,在线与机器人交互 ✨_
</div>
## ⭐ Star History
> [!TIP]
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star,这是我维护这个开源项目的动力 <3
[![Star History Chart](https://api.star-history.com/svg?repos=soulter/astrbot&type=Date)](https://star-history.com/#soulter/astrbot&Date)
<!-- ## ✨ ATRI [Beta 测试]
@@ -84,23 +153,6 @@ _✨ 管理面板 ✨_
3. 表情包理解与回复
4. TTS
-->
## ✨ 云部署
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
_アトリは、高性能ですから!_
## ❤️ 贡献
欢迎任何 Issues/Pull Requests!只需要将你的更改提交到此项目 :)
对于新功能的添加,请先通过 Issue 讨论。
## 🔭 展望
1. 更强大的 Agent 系统。
2. 打造插件工作流平台。
## ✨ Support
- Star 这个项目!
- 在[爱发电](https://afdian.com/a/soulter)支持我!
- 在[微信](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)支持我~
-2
View File
@@ -1,6 +1,5 @@
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot import logger
from astrbot.core.utils.personality import personalities
from astrbot.core import html_renderer
from astrbot.core import sp
from astrbot.core.star.register import register_llm_tool as llm_tool
@@ -8,7 +7,6 @@ from astrbot.core.star.register import register_llm_tool as llm_tool
__all__ = [
"AstrBotConfig",
"logger",
"personalities",
"html_renderer",
"llm_tool",
"sp"
-1
View File
@@ -1,7 +1,6 @@
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot import logger
from astrbot.core.utils.personality import personalities
from astrbot.core import html_renderer
from astrbot.core.star.register import register_llm_tool as llm_tool
+2 -1
View File
@@ -2,4 +2,5 @@ from astrbot.core.platform import (
AstrMessageEvent, Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
)
from astrbot.core.platform.register import register_platform_adapter
from astrbot.core.platform.register import register_platform_adapter
from astrbot.core.message.components import *
+2 -2
View File
@@ -1,2 +1,2 @@
from astrbot.core.provider import Provider, Personality, ProviderMetaData
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core.provider import Provider, STTProvider, Personality
from astrbot.core.provider.entites import ProviderRequest, ProviderType, ProviderMetaData, LLMResponse
+7
View File
@@ -1,12 +1,16 @@
import os
import asyncio
from .log import LogManager, LogBroker
from astrbot.core.utils.t2i.renderer import HtmlRenderer
from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.utils.pip_installer import PipInstaller
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.config.default import DB_PATH
from astrbot.core.config import AstrBotConfig
os.makedirs("data", exist_ok=True)
astrbot_config = AstrBotConfig()
html_renderer = HtmlRenderer()
logger = LogManager.GetLogger(log_name='astrbot')
@@ -15,4 +19,7 @@ if os.environ.get('TESTING', ""):
db_helper = SQLiteDatabase(DB_PATH)
sp = SharedPreferences() # 简单的偏好设置存储
pip_installer = PipInstaller(astrbot_config.get('pip_install_arg', ''))
web_chat_queue = asyncio.Queue(maxsize=32)
web_chat_back_queue = asyncio.Queue(maxsize=32)
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
+53 -10
View File
@@ -2,7 +2,7 @@ import os
import json
import logging
import enum
from .default import DEFAULT_CONFIG
from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
from typing import Dict
ASTRBOT_CONFIG_PATH = "data/cmd_config.json"
@@ -13,29 +13,72 @@ class RateLimitStrategy(enum.Enum):
DISCARD = "discard"
class AstrBotConfig(dict):
'''从配置文件中加载的配置,支持直接通过点号操作符访问配置项'''
'''从配置文件中加载的配置,支持直接通过点号操作符访问配置项
def __init__(self):
- 初始化时会将传入的 default_config 与配置文件进行比对,如果配置文件中缺少配置项则会自动插入默认值并进行一次写入操作。会递归检查配置项。
- 如果配置文件路径对应的文件不存在,则会自动创建并写入默认配置。
- 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。
'''
def __init__(
self,
config_path: str = ASTRBOT_CONFIG_PATH,
default_config: dict = DEFAULT_CONFIG,
schema: dict = None
):
super().__init__()
# 调用父类的 __setattr__ 方法,防止保存配置时将此属性写入配置文件
object.__setattr__(self, 'config_path', config_path)
object.__setattr__(self, 'default_config', default_config)
object.__setattr__(self, 'schema', schema)
if schema:
default_config = self._config_schema_to_default_config(schema)
if not self.check_exist():
'''不存在时载入默认配置'''
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
json.dump(DEFAULT_CONFIG, f, indent=4, ensure_ascii=False)
with open(config_path, "w", encoding="utf-8-sig") as f:
json.dump(default_config, f, indent=4, ensure_ascii=False)
with open(ASTRBOT_CONFIG_PATH, "r", encoding="utf-8-sig") as f:
with open(config_path, "r", encoding="utf-8-sig") as f:
conf_str = f.read()
if conf_str.startswith(u'/ufeff'): # remove BOM
conf_str = conf_str.encode('utf8')[3:].decode('utf8')
conf = json.loads(conf_str)
# 检查配置完整性,并插入
has_new = self.check_config_integrity(DEFAULT_CONFIG, conf)
has_new = self.check_config_integrity(default_config, conf)
self.update(conf)
if has_new:
self.save_config()
self.update(conf)
def _config_schema_to_default_config(self, schema: dict) -> dict:
'''将 Schema 转换成 Config'''
conf = {}
def _parse_schema(schema: dict, conf: dict):
for k, v in schema.items():
if v['type'] not in DEFAULT_VALUE_MAP:
raise TypeError(f"不受支持的配置类型 {v['type']}。支持的类型有:{DEFAULT_VALUE_MAP.keys()}")
if 'default' in v:
default = v['default']
else:
default = DEFAULT_VALUE_MAP[v['type']]
if v['type'] == 'object':
conf[k] = {}
_parse_schema(v['items'], conf[k])
else:
conf[k] = default
_parse_schema(schema, conf)
return conf
def check_config_integrity(self, refer_conf: Dict, conf: Dict, path=""):
'''检查配置完整性,如果有新的配置项则返回 True'''
has_new = False
@@ -61,7 +104,7 @@ class AstrBotConfig(dict):
'''
if replace_config:
self.update(replace_config)
with open(ASTRBOT_CONFIG_PATH, "w", encoding="utf-8-sig") as f:
with open(self.config_path, "w", encoding="utf-8-sig") as f:
json.dump(self, f, indent=2, ensure_ascii=False)
def __getattr__(self, item):
@@ -81,4 +124,4 @@ class AstrBotConfig(dict):
self[key] = value
def check_exist(self) -> bool:
return os.path.exists(ASTRBOT_CONFIG_PATH)
return os.path.exists(self.config_path)
+230 -14
View File
@@ -2,7 +2,7 @@
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
"""
VERSION = "3.4.3"
VERSION = "3.4.15"
DB_PATH = "data/data_v3.db"
# 默认配置
@@ -22,6 +22,15 @@ DEFAULT_CONFIG = {
"id_whitelist_log": True,
"wl_ignore_admin_on_group": True,
"wl_ignore_admin_on_friend": True,
"reply_with_mention": False,
"reply_with_quote": False,
"path_mapping": [],
"segmented_reply": {
"enable": False,
"only_llm_result": True,
"interval": "1.5,3.5",
"regex": ".*?[。?!~…]+|.+$"
}
},
"provider": [],
"provider_settings": {
@@ -30,9 +39,23 @@ DEFAULT_CONFIG = {
"web_search": False,
"identifier": False,
"datetime_system_prompt": True,
"default_personality": "如果用户寻求帮助或者打招呼,请告诉他可以用 /help 查看 AstrBot 帮助。",
"default_personality": "default",
"prompt_prefix": "",
},
"provider_stt_settings": {
"enable": False,
"provider_id": "",
},
"provider_tts_settings": {
"enable": False,
"provider_id": "",
},
"provider_ltm_settings": {
"group_icl_enable": False,
"group_message_max_cnt": 300,
"image_caption": False,
"image_caption_prompt": "Please describe the image using Chinese.",
},
"content_safety": {
"internal_keywords": {"enable": True, "extra_keywords": []},
"baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""},
@@ -52,6 +75,14 @@ DEFAULT_CONFIG = {
"pip_install_arg": "",
"plugin_repo_mirror": "",
"knowledge_db": {},
"persona": [
{
"name": "default",
"prompt": "如果用户寻求帮助或者打招呼,请告诉他可以用 /help 查看 AstrBot 帮助。",
"begin_dialogs": [],
"mood_imitation_dialogs": [],
}
],
}
@@ -81,6 +112,15 @@ CONFIG_METADATA_2 = {
"ws_reverse_port": 6199,
},
"vchat(微信)": {"id": "default", "type": "vchat", "enable": False},
"gewechat(微信)": {
"id": "gwchat",
"type": "gewechat",
"enable": False,
"base_url": "http://localhost:2531",
"nickname": "soulter",
"host": "localhost",
"port": 11451,
},
},
"items": {
"id": {
@@ -154,6 +194,31 @@ CONFIG_METADATA_2 = {
},
},
},
"segmented_reply": {
"description": "分段回复",
"type": "object",
"items": {
"enable": {
"description": "启用分段回复",
"type": "bool",
},
"only_llm_result": {
"description": "仅对 LLM 结果分段",
"type": "bool",
},
"interval": {
"description": "随机间隔时间(秒)",
"type": "string",
"hint": "每一段回复的间隔时间,格式为 `最小时间,最大时间`。如 `0.75,2.5`",
},
"regex": {
"description": "正则表达式",
"type": "string",
"obvious_hint": True,
"hint": "用于分隔一段消息。默认情况下会根据句号、问号等标点符号分隔。re.findall(r'<regex>', text)",
},
},
},
"reply_prefix": {
"description": "回复前缀",
"type": "string",
@@ -166,13 +231,13 @@ CONFIG_METADATA_2 = {
},
"enable_id_white_list": {
"description": "启用 ID 白名单",
"type": "bool"
"type": "bool",
},
"id_whitelist": {
"description": "ID 白名单",
"type": "list",
"items": {"type": "int"},
"hint": "填写后,将只处理所填写的 ID 发来的消息事件。为空时表示不启用白名单过滤。可以使用 /myid 指令获取在某个平台上的会话 ID。也可在 AstrBot 日志内获取会话 ID,当一条消息没通过白名单时,会输出 INFO 级别的日志。会话 ID 类似 aiocqhttp:GroupMessage:547540978",
"items": {"type": "string"},
"hint": "AstrBot 只处理所填写的 ID 发来的消息事件。为空时不启用白名单过滤。可以使用 /myid 指令获取在某个平台上的会话 ID。也可在 AstrBot 日志内获取会话 ID,当一条消息没通过白名单时,会输出 INFO 级别的日志。会话 ID 类似 aiocqhttp:GroupMessage:547540978",
},
"id_whitelist_log": {
"description": "打印白名单日志",
@@ -187,6 +252,23 @@ CONFIG_METADATA_2 = {
"description": "管理员私聊消息无视 ID 白名单",
"type": "bool",
},
"reply_with_mention": {
"description": "回复时 @ 发送者",
"type": "bool",
"hint": "启用后,机器人回复消息时会 @ 发送者。实际效果以具体的平台适配器为准。",
},
"reply_with_quote": {
"description": "回复时引用消息",
"type": "bool",
"hint": "启用后,机器人回复消息时会引用原消息。实际效果以具体的平台适配器为准。",
},
"path_mapping": {
"description": "路径映射",
"type": "list",
"items": {"type": "string"},
"obvious_hint": True,
"hint": "此功能解决由于文件系统不一致导致路径不存在的问题。格式为 <原路径>:<映射路径>。如 `/app/.config/QQ:/var/lib/docker/volumes/xxxx/_data`。这样,当消息平台下发的事件中图片和语音路径以 `/app/.config/QQ` 开头时,开头被替换为 `/var/lib/docker/volumes/xxxx/_data`。这在 AstrBot 或者平台协议端使用 Docker 部署时特别有用。",
}
},
},
"content_safety": {
@@ -252,7 +334,7 @@ CONFIG_METADATA_2 = {
"type": "openai_chat_completion",
"enable": True,
"key": ["ollama"], # ollama 的 key 默认是 ollama
"api_base": "http://localhost:11434",
"api_base": "http://localhost:11434/v1",
"model_config": {
"model": "llama3.1-8b",
},
@@ -315,9 +397,38 @@ CONFIG_METADATA_2 = {
"dify_api_key": "",
"dify_api_base": "https://api.dify.ai/v1",
"dify_workflow_output_key": "",
}
},
"whisper(API)": {
"id": "whisper",
"type": "openai_whisper_api",
"enable": False,
"api_key": "",
"api_base": "",
"model": "whisper-1",
},
"whisper(本地加载)": {
"whisper_hint": "(不用修改我)",
"enable": False,
"id": "whisper",
"type": "openai_whisper_selfhost",
"model": "tiny",
},
"openai_tts(API)": {
"id": "openai_tts",
"type": "openai_tts_api",
"enable": False,
"api_key": "",
"api_base": "",
"model": "tts-1",
},
},
"items": {
"whisper_hint": {
"description": "本地部署 Whisper 模型须知",
"type": "string",
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cudaCPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
"obvious_hint": True,
},
"id": {
"description": "ID",
"type": "string",
@@ -342,7 +453,8 @@ CONFIG_METADATA_2 = {
"api_base": {
"description": "API Base URL",
"type": "string",
"hint": "API Base URL 请在在模型提供商处获得。支持 Ollama 开放的 API 地址。如果您确认填写正确但是使用时出现了 404 异常,可以尝试在地址末尾加上 `/v1`。",
"hint": "API Base URL 请在在模型提供商处获得。使用时出现了 404 报错,可以尝试在地址末尾加上 `/v1`。",
"obvious_hint": True,
},
"base_model_path": {
"description": "基座模型路径",
@@ -406,7 +518,7 @@ CONFIG_METADATA_2 = {
"description": "Dify Workflow 输出变量名",
"type": "string",
"hint": "Dify Workflow 输出变量名。当应用类型为 workflow 时才使用。默认为 astrbot_wf_output。",
}
},
},
},
"provider_settings": {
@@ -416,7 +528,8 @@ CONFIG_METADATA_2 = {
"enable": {
"description": "启用大语言模型聊天",
"type": "bool",
"hint": "是否启用大语言模型聊天。默认启用",
"hint": "如需切换大语言模型提供商,请使用 `/provider` 命令。",
"obvious_hint": True,
},
"wake_prefix": {
"description": "LLM 聊天额外唤醒前缀",
@@ -439,9 +552,9 @@ CONFIG_METADATA_2 = {
"hint": "启用后,会在系统提示词中加上当前机器的日期时间。",
},
"default_personality": {
"description": "默认人格",
"description": "默认采用的人格情景的名称",
"type": "string",
"hint": "默认人格(情境设置/System Prompt)文本。",
"hint": "",
},
"prompt_prefix": {
"description": "Prompt 前缀文本",
@@ -450,6 +563,108 @@ CONFIG_METADATA_2 = {
},
},
},
"persona": {
"description": "人格情景设置",
"type": "list",
"config_template": {
"新人格情景": {
"name": "",
"prompt": "",
"begin_dialogs": [],
"mood_imitation_dialogs": [],
}
},
"tmpl_display_title": "name",
"items": {
"name": {
"description": "人格名称",
"type": "string",
"hint": "人格名称,用于在多个人格中区分。使用 /persona 指令可切换人格。在 大语言模型设置 处可以设置默认人格。",
"obvious_hint": True,
},
"prompt": {
"description": "设定(系统提示词)",
"type": "text",
"hint": "填写人格的身份背景、性格特征、兴趣爱好、个人经历、口头禅等。",
},
"begin_dialogs": {
"description": "预设对话",
"type": "list",
"items": {"type": "string"},
"hint": "可选。在每个对话前会插入这些预设对话。格式要求:第一句为用户,第二句为助手,以此类推。",
"obvious_hint": True,
},
"mood_imitation_dialogs": {
"description": "对话风格模仿",
"type": "list",
"items": {"type": "string"},
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一样。",
"obvious_hint": True,
},
},
},
"provider_stt_settings": {
"description": "语音转文本(STT)",
"type": "object",
"items": {
"enable": {
"description": "启用语音转文本(STT)",
"type": "bool",
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 whisper。",
"obvious_hint": True,
},
"provider_id": {
"description": "提供商 ID,不填则默认第一个STT提供商",
"type": "string",
"hint": "语音转文本提供商 ID。如果不填写将使用载入的第一个提供商。",
},
},
},
"provider_tts_settings": {
"description": "文本转语音(TTS)",
"type": "object",
"items": {
"enable": {
"description": "启用文本转语音(TTS)",
"type": "bool",
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 openai_tts。",
"obvious_hint": True,
},
"provider_id": {
"description": "提供商 ID,不填则默认第一个TTS提供商",
"type": "string",
"hint": "文本转语音提供商 ID。如果不填写将使用载入的第一个提供商。",
},
},
},
"provider_ltm_settings": {
"description": "聊天记忆增强(Beta)",
"type": "object",
"items": {
"group_icl_enable": {
"description": "群聊内记录各群员对话",
"type": "bool",
"obvious-hint": True,
"hint": "启用后,会记录群聊内各群员的对话。使用 /reset 命令清除记录。推荐使用 gpt-4o-mini 模型。",
},
"group_message_max_cnt": {
"description": "群聊消息最大数量",
"type": "int",
"obvious-hint": True,
"hint": "群聊消息最大数量。超过此数量后,会自动清除旧消息。",
},
"image_caption": {
"description": "启用图像转述(需要模型支持)",
"type": "bool",
"obvious-hint": True,
"hint": "启用后,当接收到图片消息时,会使用模型先将图片转述为文字再进行后续处理。推荐使用 gpt-4o-mini 模型。",
},
"image_caption_prompt": {
"description": "图像转述提示词",
"type": "string"
},
},
},
},
},
"misc_config_group": {
@@ -459,7 +674,8 @@ CONFIG_METADATA_2 = {
"description": "机器人唤醒前缀",
"type": "list",
"items": {"type": "string"},
"hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。",
"obvious_hint": True,
"hint": "在不 @ 机器人的情况下,可以通过外加消息前缀来唤醒机器人。更改此配置将影响整个 Bot 的功能唤醒,包括所有指令。如果您不保留 `/`,则内置指令(help等)将需要通过您的唤醒前缀来触发。",
},
"t2i": {
"description": "文本转图像",
@@ -469,7 +685,7 @@ CONFIG_METADATA_2 = {
"admins_id": {
"description": "管理员 ID",
"type": "list",
"items": {"type": "int"},
"items": {"type": "string"},
"hint": "管理员 ID 列表,管理员可以使用一些特权命令,如 `update`, `plugin` 等。ID 可以通过 `/myid` 指令获得。回车添加,可添加多个。",
},
"http_proxy": {
+22 -3
View File
@@ -1,11 +1,12 @@
import traceback
import asyncio
import time
import threading
import os
from .event_bus import EventBus
from . import astrbot_config
from asyncio import Queue
from typing import List
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.platform.manager import PlatformManager
@@ -21,7 +22,7 @@ from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
class AstrBotCoreLifecycle:
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
self.log_broker = log_broker
self.astrbot_config = AstrBotConfig()
self.astrbot_config = astrbot_config
self.db = db
if self.astrbot_config['http_proxy']:
@@ -80,12 +81,30 @@ class AstrBotCoreLifecycle:
for task in self.star_context._register_tasks:
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
self.curr_tasks = [event_bus_task, *platform_tasks, *extra_tasks]
# self.curr_tasks = [event_bus_task, *platform_tasks, *extra_tasks]
tasks_ = [event_bus_task, *platform_tasks, *extra_tasks]
for task in tasks_:
self.curr_tasks.append(asyncio.create_task(self._task_wrapper(task), name=task.get_name()))
self.start_time = int(time.time())
async def _task_wrapper(self, task: asyncio.Task):
try:
await task
except asyncio.CancelledError:
pass
except Exception as e:
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
for line in traceback.format_exc().split("\n"):
logger.error(f"| {line}")
logger.error("-------")
async def start(self):
self._load()
logger.info("AstrBot 启动完成。")
await asyncio.gather(*self.curr_tasks, return_exceptions=True)
async def stop(self):
+25 -1
View File
@@ -1,7 +1,7 @@
import abc
from dataclasses import dataclass
from typing import List
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, WebChatConversation
@dataclass
class BaseDatabase(abc.ABC):
@@ -76,4 +76,28 @@ class BaseDatabase(abc.ABC):
@abc.abstractmethod
def get_atri_vision_data_by_path_or_id(self, url_or_path: str, id: str) -> ATRIVision:
'''通过 url 或 path 获取 ATRI 视觉数据'''
raise NotImplementedError
@abc.abstractmethod
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
'''通过 user_id 和 cid 获取 WebChatConversation'''
raise NotImplementedError
@abc.abstractmethod
def webchat_new_conversation(self, user_id: str, cid: str):
'''新建 WebChatConversation'''
raise NotImplementedError
@abc.abstractmethod
def get_webchat_conversations(self, user_id: str) -> List[WebChatConversation]:
raise NotImplementedError
@abc.abstractmethod
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
'''更新 WebChatConversation'''
raise NotImplementedError
@abc.abstractmethod
def delete_webchat_conversation(self, user_id: str, cid: str):
'''删除 WebChatConversation'''
raise NotImplementedError
+12 -1
View File
@@ -51,4 +51,15 @@ class ATRIVision():
platform_name: str
session_id: str
sender_nickname: str
timestamp: int = -1
timestamp: int = -1
@dataclass
class WebChatConversation():
user_id: str
cid: str
history: str = ""
created_at: int = 0
updated_at: int = 0
+65 -1
View File
@@ -5,7 +5,8 @@ from astrbot.core.db.po import (
Platform,
Stats,
LLMHistory,
ATRIVision
ATRIVision,
WebChatConversation
)
from . import BaseDatabase
from typing import Tuple
@@ -199,6 +200,69 @@ class SQLiteDatabase(BaseDatabase):
c.close()
return Stats(platform, [], [])
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
'''
SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ?
''', (user_id, cid)
)
res = c.fetchone()
c.close()
return WebChatConversation(*res)
def webchat_new_conversation(self, user_id: str, cid: str):
history = "[]"
updated_at = int(time.time())
created_at = updated_at
self._exec_sql(
'''
INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?)
''', (user_id, cid, history, updated_at, created_at)
)
def get_webchat_conversations(self, user_id: str) -> Tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
'''
SELECT cid, created_at, updated_at FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
''', (user_id,)
)
res = c.fetchall()
c.close()
conversations = []
for row in res:
cid = row[0]
created_at = row[1]
updated_at = row[2]
conversations.append(WebChatConversation("", cid, '[]', created_at, updated_at))
return conversations
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
self._exec_sql(
'''
UPDATE webchat_conversation SET history = ? WHERE user_id = ? AND cid = ?
''', (history, user_id, cid)
)
def delete_webchat_conversation(self, user_id: str, cid: str):
self._exec_sql(
'''
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
''', (user_id, cid)
)
def insert_atri_vision_data(self, vision: ATRIVision):
+8
View File
@@ -35,4 +35,12 @@ CREATE TABLE IF NOT EXISTS atri_vision(
session_id VARCHAR(32),
sender_nickname VARCHAR(32),
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS webchat_conversation(
user_id TEXT,
cid TEXT,
history TEXT,
created_at INTEGER,
updated_at INTEGER
);
+11 -14
View File
@@ -13,12 +13,10 @@ class MessageChain():
Attributes:
`chain` (list): 用于顺序存储各个组件。
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
`is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
'''
chain: List[BaseMessageComponent] = field(default_factory=list)
use_t2i_: Optional[bool] = None # None 为跟随用户设置
is_split_: Optional[bool] = False # 是否将消息分条发送。默认为 False。启用后,将会依次发送 chain 中的每个 component。
def message(self, message: str):
'''添加一条文本消息到消息链 `chain` 中。
@@ -77,16 +75,6 @@ class MessageChain():
'''
self.use_t2i_ = use_t2i
return self
def is_split(self, is_split: bool):
'''设置是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
Note:
具体的效果以各适配器实现为准。
'''
self.is_split_ = is_split
return self
class EventResultType(enum.Enum):
'''用于描述事件处理的结果类型。
@@ -113,7 +101,6 @@ class MessageEventResult(MessageChain):
Attributes:
`chain` (list): 用于顺序存储各个组件。
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None,即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
`is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后,将会依次发送 chain 中的每个 component。
`result_type` (EventResultType): 事件处理的结果类型。
'''
@@ -139,7 +126,7 @@ class MessageEventResult(MessageChain):
'''
return self.result_type == EventResultType.STOP
def set_result_content_type(self, typ: EventResultType) -> 'MessageEventResult':
def set_result_content_type(self, typ: ResultContentType) -> 'MessageEventResult':
'''设置事件处理的结果类型。
Args:
@@ -148,5 +135,15 @@ class MessageEventResult(MessageChain):
self.result_content_type = typ
return self
def is_llm_result(self) -> bool:
'''是否为 LLM 结果。
'''
return self.result_content_type == ResultContentType.LLM_RESULT
def get_plain_text(self) -> str:
'''获取纯文本消息。这个方法将获取所有 Plain 组件的文本并拼接成一条消息。空格分隔。
'''
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
CommandResult = MessageEventResult
+3
View File
@@ -3,6 +3,7 @@ from astrbot.core.message.message_event_result import MessageEventResult, EventR
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
from .content_safety_check.stage import ContentSafetyCheckStage
from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
from .result_decorate.stage import ResultDecorateStage
from .respond.stage import RespondStage
@@ -12,6 +13,7 @@ STAGES_ORDER = [
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"RateLimitCheckStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PreProcessStage", # 预处理
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
"RespondStage" # 发送消息
@@ -21,6 +23,7 @@ __all__ = [
"WakingCheckStage",
"WhitelistCheckStage",
"ContentSafetyCheckStage",
"PreProcessStage",
"ProcessStage",
"ResultDecorateStage",
"RespondStage",
@@ -0,0 +1,70 @@
import traceback
import asyncio
from typing import Union, AsyncGenerator
from ..stage import Stage, register_stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core import logger
from astrbot.core.message.components import Plain, Record, Image
@register_stage
class PreProcessStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.config = ctx.astrbot_config
self.plugin_manager = ctx.plugin_manager
self.stt_settings: dict = self.config.get('provider_stt_settings', {})
self.platform_settings: dict = self.config.get('platform_settings', {})
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
'''在处理事件之前的预处理'''
# 路径映射
if mappings := self.platform_settings.get('path_mapping', []):
# 支持 Record,Image 消息段的路径映射。
message_chain = event.get_messages()
for idx, component in enumerate(message_chain):
if isinstance(component, (Record, Image)) and component.url:
for mapping in mappings:
from_, to_ = mapping.split(":")
from_ = from_.removesuffix("/")
to_ = to_.removesuffix("/")
url = component.url.removeprefix("file://")
if url.startswith(from_):
component.url = url.replace(from_, to_, 1)
logger.debug(f"路径映射: {url} -> {component.url}")
message_chain[idx] = component
# STT
if self.stt_settings.get('enable', False):
# TODO: 独立
stt_provider = self.plugin_manager.context.provider_manager.curr_stt_provider_inst
if stt_provider:
message_chain = event.get_messages()
for idx, component in enumerate(message_chain):
if isinstance(component, Record) and component.url:
path = component.url.removeprefix("file://")
retry = 5
for i in range(retry):
try:
result = await stt_provider.get_text(audio_url=path)
if result:
logger.info("语音转文本结果: " + result)
message_chain[idx] = Plain(result)
event.message_str += result
event.message_obj.message_str += result
break
except FileNotFoundError as e:
# napcat workaround
logger.warning(e)
logger.warning(f"重试中: {i + 1}/{retry}")
await asyncio.sleep(0.5)
continue
except BaseException as e:
logger.error(traceback.format_exc())
logger.error(f"语音转文本失败: {e}")
break
@@ -17,6 +17,13 @@ class LLMRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.bot_wake_prefixs = ctx.astrbot_config['wake_prefix'] # list
self.provider_wake_prefix = ctx.astrbot_config['provider_settings']['wake_prefix'] # str
for bwp in self.bot_wake_prefixs:
if self.provider_wake_prefix.startswith(bwp):
logger.info(f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。")
self.provider_wake_prefix = self.provider_wake_prefix[len(bwp):]
async def process(self, event: AstrMessageEvent, _nested: bool = False) -> Union[None, AsyncGenerator[None, None]]:
req: ProviderRequest = None
@@ -30,10 +37,10 @@ class LLMRequestSubStage(Stage):
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
else:
req = ProviderRequest(prompt="", image_urls=[])
if self.ctx.astrbot_config['provider_settings']['wake_prefix']:
if not event.message_str.startswith(self.ctx.astrbot_config['provider_settings']['wake_prefix']):
if self.provider_wake_prefix:
if not event.message_str.startswith(self.provider_wake_prefix):
return
req.prompt = event.message_str[len(self.ctx.astrbot_config['provider_settings']['wake_prefix']):]
req.prompt = event.message_str[len(self.provider_wake_prefix):]
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message:
if isinstance(comp, Image):
@@ -44,7 +51,7 @@ class LLMRequestSubStage(Stage):
session_provider_context = provider.session_memory.get(event.session_id)
req.contexts = session_provider_context if session_provider_context else []
if not req.prompt:
if not req.prompt and not req.image_urls:
return
# 执行请求 LLM 前事件。
@@ -98,5 +105,5 @@ class LLMRequestSubStage(Stage):
except BaseException as e:
logger.error(traceback.format_exc())
event.set_result(MessageEventResult().message("AstrBot 请求 LLM 资源失败:" + str(e)))
event.set_result(MessageEventResult().message(f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"))
return
@@ -39,8 +39,11 @@ class StarRequestSubStage(Stage):
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
event.set_result(MessageEventResult().message(ret))
yield
event.clear_result()
if event.is_at_or_wake_command:
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
event.set_result(MessageEventResult().message(ret))
yield
event.clear_result()
event.stop_event()
+26 -1
View File
@@ -1,7 +1,10 @@
import random
import asyncio
from typing import Union, AsyncGenerator
from ..stage import register_stage, Stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core import logger
from astrbot.core.star.star_handler import star_handlers_registry, EventType
@@ -9,6 +12,19 @@ from astrbot.core.star.star_handler import star_handlers_registry, EventType
class RespondStage(Stage):
async def initialize(self, ctx: PipelineContext):
self.ctx = ctx
# 分段回复
self.enable_seg: bool = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result']
interval_str: str = ctx.astrbot_config['platform_settings']['segmented_reply']['interval']
interval_str_ls = interval_str.replace(" ", "").split(",")
try:
self.interval = [float(t) for t in interval_str_ls]
except BaseException as e:
logger.error(f'解析分段回复的间隔时间失败。{e}')
self.interval = [1.5, 3.5]
logger.info(f"分段回复间隔时间:{self.interval}")
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
result = event.get_result()
@@ -16,7 +32,16 @@ class RespondStage(Stage):
return
if len(result.chain) > 0:
await event.send(result)
await event._pre_send()
if self.enable_seg and ((self.only_llm_result and result.is_llm_result()) or not self.only_llm_result):
# 分段回复
for comp in result.chain:
await event.send(MessageChain([comp]))
await asyncio.sleep(random.uniform(self.interval[0], self.interval[1]))
else:
await event.send(result)
await event._post_send()
logger.info(f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}")
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent)
+64 -5
View File
@@ -1,10 +1,13 @@
import time
import re
import traceback
from typing import Union, AsyncGenerator
from ..stage import register_stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.message_type import MessageType
from astrbot.core import logger
from astrbot.core.message.components import Plain, Image
from astrbot.core.message.components import Plain, Image, At, Reply, Record
from astrbot.core import html_renderer
from astrbot.core.star.star_handler import star_handlers_registry, EventType
@@ -13,7 +16,14 @@ class ResultDecorateStage:
async def initialize(self, ctx: PipelineContext):
self.ctx = ctx
self.reply_prefix = ctx.astrbot_config['platform_settings']['reply_prefix']
self.t2i = ctx.astrbot_config['t2i']
self.reply_with_mention = ctx.astrbot_config['platform_settings']['reply_with_mention']
self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote']
self.use_tts = ctx.astrbot_config['provider_tts_settings']['enable']
# 分段回复
self.enable_segmented_reply = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result']
self.regex = ctx.astrbot_config['platform_settings']['segmented_reply']['regex']
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
result = event.get_result()
@@ -28,10 +38,53 @@ class ResultDecorateStage:
if len(result.chain) > 0:
# 回复前缀
if self.reply_prefix:
result.chain.insert(0, Plain(self.reply_prefix))
for comp in result.chain:
if isinstance(comp, Plain):
comp.text = self.reply_prefix + comp.text
break
# 分段回复
if self.enable_segmented_reply:
if (self.only_llm_result and result.is_llm_result()) or not self.only_llm_result:
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain):
split_response = re.findall(r".*?[。?!~…]+|.+$", comp.text)
if not split_response:
new_chain.append(comp)
continue
for seg in split_response:
new_chain.append(Plain(seg))
else:
# 非 Plain 类型的消息段不分段
new_chain.append(comp)
result.chain = new_chain
# TTS
if self.use_tts and result.is_llm_result():
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain) and len(comp.text) > 1:
try:
logger.info("TTS 请求: " + comp.text)
audio_path = await tts_provider.get_audio(comp.text)
logger.info("TTS 结果: " + audio_path)
if audio_path:
new_chain.append(Record(file=audio_path, url=audio_path))
else:
logger.error(f"由于 TTS 音频文件没找到,消息段转语音失败: {comp.text}")
new_chain.append(comp)
except BaseException:
traceback.print_exc()
logger.error("TTS 失败,使用文本发送。")
new_chain.append(comp)
else:
new_chain.append(comp)
result.chain = new_chain
# 文本转图片
if (result.use_t2i_ is None and self.t2i) or result.use_t2i_:
elif (result.use_t2i_ is None and self.ctx.astrbot_config['t2i']) or result.use_t2i_:
plain_str = ""
for comp in result.chain:
if isinstance(comp, Plain):
@@ -48,4 +101,10 @@ class ResultDecorateStage:
if time.time() - render_start > 3:
logger.warning("文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。")
if url:
result.chain = [Image.fromURL(url)]
result.chain = [Image.fromURL(url)]
if self.reply_with_mention and event.get_message_type() != MessageType.FRIEND_MESSAGE:
result.chain.insert(0, At(qq=event.get_sender_id()))
if self.reply_with_quote:
result.chain.insert(0, Reply(id=event.message_obj.message_id))
+4
View File
@@ -41,4 +41,8 @@ class PipelineScheduler():
async def execute(self, event: AstrMessageEvent):
'''执行 pipeline'''
await self._process_stages(event)
if not event._has_send_oper and event.get_platform_name() == "webchat":
await event.send(None)
logger.debug("pipeline 执行完毕。")
@@ -18,6 +18,15 @@ class WhitelistCheckStage(Stage):
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
if not self.enable_whitelist_check:
# 白名单检查未启用
return
if len(self.whitelist) == 0:
# 白名单为空,不检查
return
if event.get_platform_name() == 'webchat':
# WebChat 豁免
return
# 检查是否在白名单
@@ -179,6 +179,15 @@ class AstrMessageEvent(abc.ABC):
await Metric.upload(msg_event_tick = 1, adapter_name = self.platform_meta.name)
self._has_send_oper = True
async def _pre_send(self):
'''调度器会在执行 send() 前调用该方法'''
pass
async def _post_send(self):
'''调度器会在执行 send() 后调用该方法'''
pass
def set_result(self, result: Union[MessageEventResult, str]):
'''设置消息事件的结果。
+6 -1
View File
@@ -4,7 +4,7 @@ from typing import List
from asyncio import Queue
from .register import platform_cls_map
from astrbot.core import logger
from .sources.webchat.webchat_adapter import WebChatAdapter
class PlatformManager():
def __init__(self, config: AstrBotConfig, event_queue: Queue):
@@ -25,6 +25,9 @@ class PlatformManager():
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
case "vchat":
from .sources.vchat.vchat_platform_adapter import VChatPlatformAdapter # noqa: F401
case "gewechat":
from .sources.gewechat.gewechat_platform_adapter import GewechatPlatformAdapter # noqa: F401
async def initialize(self):
for platform in self.platforms_config:
@@ -37,6 +40,8 @@ class PlatformManager():
logger.info(f"尝试实例化 {platform['type']}({platform['id']}) 平台适配器 ...")
inst = cls_type(platform, self.settings, self.event_queue)
self.platform_insts.append(inst)
self.platform_insts.append(WebChatAdapter({}, self.settings, self.event_queue))
def get_insts(self):
return self.platform_insts
+8 -3
View File
@@ -1,7 +1,12 @@
from dataclasses import dataclass
@dataclass
class PlatformMetadata():
name: str # 平台的名称
description: str # 平台的描述
name: str
'''平台的名称'''
description: str
'''平台的描述'''
default_config_tmpl: dict = None # 平台的默认配置模板
default_config_tmpl: dict = None
'''平台的默认配置模板'''
adapter_display_name: str = None
'''显示在 WebUI 配置页中的平台名称,如空则是 name'''
+10 -2
View File
@@ -7,7 +7,12 @@ platform_registry: List[PlatformMetadata] = []
platform_cls_map: Dict[str, Type] = {}
'''维护了平台适配器名称和适配器类的映射'''
def register_platform_adapter(adapter_name: str, desc: str, default_config_tmpl: dict = None):
def register_platform_adapter(
adapter_name: str,
desc: str,
default_config_tmpl: dict = None,
adapter_display_name: str = None
):
'''用于注册平台适配器的带参装饰器。
default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
@@ -22,11 +27,14 @@ def register_platform_adapter(adapter_name: str, desc: str, default_config_tmpl:
default_config_tmpl['type'] = adapter_name
if 'enable' not in default_config_tmpl:
default_config_tmpl['enable'] = False
if 'id' not in default_config_tmpl:
default_config_tmpl['id'] = adapter_name
pm = PlatformMetadata(
name=adapter_name,
description=desc,
default_config_tmpl=default_config_tmpl
default_config_tmpl=default_config_tmpl,
adapter_display_name=adapter_display_name
)
platform_registry.append(pm)
platform_cls_map[adapter_name] = cls
@@ -3,7 +3,7 @@ import random
import asyncio
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image
from astrbot.api.message_components import Plain, Image, Record
from aiocqhttp import CQHttp
from astrbot.core.utils.io import file_to_base64, download_image_by_url
@@ -20,15 +20,19 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
d = segment.toDict()
if isinstance(segment, Plain):
d['type'] = 'text'
if isinstance(segment, Image):
if isinstance(segment, (Image, Record)):
# convert to base64
if segment.file and segment.file.startswith("file:///"):
image_base64 = file_to_base64(segment.file[8:])
bs64_data = file_to_base64(segment.file[8:])
image_file_path = segment.file[8:]
elif segment.file and segment.file.startswith("http"):
image_file_path = await download_image_by_url(segment.file)
image_base64 = file_to_base64(image_file_path)
d['data']['file'] = image_base64
bs64_data = file_to_base64(image_file_path)
else:
bs64_data = file_to_base64(segment.file)
d['data'] = {
'file': bs64_data,
}
ret.append(d)
return ret
@@ -36,11 +40,5 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
if os.environ.get('TEST_MODE', 'off') == 'on':
return
if message.is_split_: # 分条发送
for m in ret:
await self.bot.send(self.message_obj.raw_message, [m])
await asyncio.sleep(random.uniform(0.75, 2.5))
else:
await self.bot.send(self.message_obj.raw_message, ret)
await self.bot.send(self.message_obj.raw_message, ret)
await super().send(message)
@@ -13,6 +13,7 @@ from .aiocqhttp_message_event import AiocqhttpMessageEvent
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from aiocqhttp.exceptions import ActionFailed
from astrbot.core.utils.io import download_file
@register_platform_adapter("aiocqhttp", "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。")
class AiocqhttpAdapter(Platform):
@@ -81,22 +82,36 @@ class AiocqhttpAdapter(Platform):
if t == 'text':
message_str += m['data']['text'].strip()
elif t == 'file':
try:
# Napcat, LLBot
ret = await self.bot.call_action(action="get_file", file_id=event.message[0]['data']['file_id'])
if not ret.get('file', None):
raise ValueError(f"无法解析文件响应: {ret}")
if not os.path.exists(ret['file']):
raise FileNotFoundError(f"文件不存在: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),暂时无法获取用户上传的文件。")
if m['data']['url'] and m['data']['url'].startswith("http"):
# Lagrange
logger.info("guessing lagrange")
file_name = m['data'].get('file_name', "file")
path = os.path.join("data/temp", file_name)
await download_file(m['data']['url'], path)
m['data'] = {
"file": ret['file'],
"name": ret['file_name']
"file": path,
"name": file_name
}
except ActionFailed as e:
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
except BaseException as e:
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
else:
try:
# Napcat, LLBot
ret = await self.bot.call_action(action="get_file", file_id=event.message[0]['data']['file_id'])
if not ret.get('file', None):
raise ValueError(f"无法解析文件响应: {ret}")
if not os.path.exists(ret['file']):
raise FileNotFoundError(f"文件不存在: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),暂时无法获取用户上传的文件。")
m['data'] = {
"file": ret['file'],
"name": ret['file_name']
}
except ActionFailed as e:
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
except BaseException as e:
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
a = ComponentTypes[t](**m['data']) # noqa: F405
abm.message.append(a)
@@ -0,0 +1,350 @@
import threading
import asyncio
import aiohttp
import quart
import base64
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
from astrbot.api.message_components import Plain, Image, At, Record
from astrbot.api import logger, sp
from .downloader import GeweDownloader
from astrbot.core.utils.io import download_image_by_url
class SimpleGewechatClient():
'''针对 Gewechat 的简单实现。
@author: Soulter
@website: https://github.com/Soulter
'''
def __init__(self, base_url: str, nickname: str, host: str, port: int, event_queue: asyncio.Queue):
self.base_url = base_url
if self.base_url.endswith('/'):
self.base_url = self.base_url[:-1]
self.download_base_url = self.base_url.split(':')[:-1] # 去掉端口
self.download_base_url = ':'.join(self.download_base_url) + ":2532/download/"
self.base_url += "/v2/api"
logger.info(f"Gewechat API: {self.base_url}")
logger.info(f"Gewechat 下载 API: {self.download_base_url}")
if isinstance(port, str):
port = int(port)
self.token = None
self.headers = {}
self.nickname = nickname
self.appid = sp.get(f"gewechat-appid-{nickname}", "")
self.server = quart.Quart(__name__)
self.server.add_url_rule('/astrbot-gewechat/callback', view_func=self.callback, methods=['POST'])
self.server.add_url_rule('/astrbot-gewechat/file/<file_id>', view_func=self.handle_file, methods=['GET'])
self.host = host
self.port = port
self.callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback"
self.file_server_url = f"http://{self.host}:{self.port}/astrbot-gewechat/file"
self.event_queue = event_queue
self.multimedia_downloader = None
async def get_token_id(self):
async with aiohttp.ClientSession() as session:
async with session.post(f"{self.base_url}/tools/getTokenId") as resp:
json_blob = await resp.json()
self.token = json_blob['data']
logger.info(f"获取到 Gewechat Token: {self.token}")
self.headers = {
"X-GEWE-TOKEN": self.token
}
async def _convert(self, data: dict) -> AstrBotMessage:
type_name = data['TypeName']
if type_name == "Offline":
logger.critical("收到 gewechat 下线通知。")
return
abm = AstrBotMessage()
d = data['Data']
from_user_name = d['FromUserName']['string'] # 消息来源
d['to_wxid'] = from_user_name # 用于发信息
abm.message_id = str(d.get('MsgId'))
abm.session_id = from_user_name
abm.self_id = data['Wxid'] # 机器人的 wxid
user_id = "" # 发送人 wxid
content = d['Content']['string'] # 消息内容
at_me = False
if "@chatroom" in from_user_name:
abm.type = MessageType.GROUP_MESSAGE
_t = content.split(':\n')
user_id = _t[0]
content = _t[1]
if '\u2005' in content:
# at
content = content.split('\u2005')[1]
abm.group_id = from_user_name
# at
msg_source = d['MsgSource']
if f'<atuserlist><![CDATA[,{abm.self_id}]]>' in msg_source \
or f'<atuserlist><![CDATA[{abm.self_id}]]>' in msg_source:
at_me = True
else:
abm.type = MessageType.FRIEND_MESSAGE
user_id = from_user_name
abm.message = []
if at_me:
abm.message.insert(0, At(qq=abm.self_id))
user_real_name = d.get('PushContent', 'unknown : ').split(' : ')[0] \
.replace('在群聊中@了你', '') \
.replace('在群聊中发了一段语音', '') # 真实昵称
abm.sender = MessageMember(user_id, user_real_name)
abm.raw_message = d
abm.message_str = ""
# 不同消息类型
match d['MsgType']:
case 1:
# 文本消息
abm.message.append(Plain(content))
abm.message_str = content
case 3:
# 图片消息
file_url = await self.multimedia_downloader.download_image(
self.appid,
content
)
logger.debug(f"下载图片: {file_url}")
file_path = await download_image_by_url(file_url)
abm.message.append(Image(file=file_path, url=file_path))
case 34:
# 语音消息
# data = await self.multimedia_downloader.download_voice(
# self.appid,
# content,
# abm.message_id
# )
# print(data)
if 'ImgBuf' in d and 'buffer' in d['ImgBuf']:
voice_data = base64.b64decode(d['ImgBuf']['buffer'])
file_path = f"data/temp/gewe_voice_{abm.message_id}.silk"
with open(file_path, "wb") as f:
f.write(voice_data)
abm.message.append(Record(file=file_path, url=file_path))
case _:
logger.error(f"未实现的消息类型: {d['MsgType']}")
return
logger.info(f"abm: {abm}")
return abm
async def callback(self):
data = await quart.request.json
logger.debug(f"收到 gewechat 回调: {data}")
if data.get('testMsg', None):
return quart.jsonify({"r": "AstrBot ACK"})
abm = None
try:
abm = await self._convert(data)
except BaseException as e:
logger.warning(f"尝试解析 GeweChat 下发的消息时遇到问题: {e}。下发消息内容: {data}")
if abm:
coro = getattr(self, "on_event_received")
if coro:
await coro(abm)
return quart.jsonify({"r": "AstrBot ACK"})
async def handle_file(self, file_id):
file_path = f"data/temp/{file_id}"
return await quart.send_file(file_path)
async def _set_callback_url(self):
logger.info("设置回调,请等待...")
await asyncio.sleep(3)
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/tools/setCallback",
headers=self.headers,
json={
"token": self.token,
"callbackUrl": self.callback_url
}
) as resp:
json_blob = await resp.json()
logger.info(f"设置回调结果: {json_blob}")
if json_blob['ret'] != 200:
raise Exception(f"设置回调失败: {json_blob}")
logger.info(f"将在 {self.callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。")
async def start_polling(self):
threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start()
await self.server.run_task(
host=self.host,
port=self.port,
shutdown_trigger=self.shutdown_trigger_placeholder
)
async def shutdown_trigger_placeholder(self):
while not self.event_queue.closed:
await asyncio.sleep(1)
logger.info("gewechat 适配器已关闭。")
async def check_online(self, appid: str):
# /login/checkOnline
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/login/checkOnline",
headers=self.headers,
json={
"appId": appid
}
) as resp:
json_blob = await resp.json()
return json_blob['data']
async def logout(self):
if self.appid:
online = await self.check_online(self.appid)
if online:
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/login/logout",
headers=self.headers,
json={
"appId": self.appid
}
) as resp:
json_blob = await resp.json()
logger.info(f"登出结果: {json_blob}")
async def login(self):
if self.token is None:
await self.get_token_id()
self.multimedia_downloader = GeweDownloader(self.base_url, self.download_base_url, self.token)
if self.appid:
online = await self.check_online(self.appid)
if online:
logger.info(f"APPID: {self.appid} 已在线")
return
payload = {
"appId": self.appid
}
if self.appid:
logger.info(f"使用 APPID: {self.appid}, {self.nickname}")
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/login/getLoginQrCode",
headers=self.headers,
json=payload
) as resp:
json_blob = await resp.json()
if json_blob['ret'] != 200:
raise Exception(f"获取二维码失败: {json_blob}")
qr_data = json_blob['data']['qrData']
qr_uuid = json_blob['data']['uuid']
appid = json_blob['data']['appId']
logger.info(f"APPID: {appid}")
logger.warning(f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}")
# 执行登录
retry_cnt = 64
payload.update({
"uuid": qr_uuid,
"appId": appid
})
while retry_cnt > 0:
retry_cnt -= 1
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/login/checkLogin",
headers=self.headers,
json=payload
) as resp:
json_blob = await resp.json()
logger.info(f"检查登录状态: {json_blob}")
status = json_blob['data']['status']
nickname = json_blob['data'].get('nickName', '')
if status == 1:
logger.info(f"等待确认...{nickname}")
elif status == 2:
logger.info(f"绿泡泡平台登录成功: {nickname}")
break
elif status == 0:
logger.info("等待扫码...")
else:
logger.warning(f"未知状态: {status}")
await asyncio.sleep(5)
if appid:
sp.put(f"gewechat-appid-{nickname}", appid)
self.appid = appid
logger.info(f"已保存 APPID: {appid}")
async def post_text(self, to_wxid, content: str):
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"content": content,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postText",
headers=self.headers,
json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送消息结果: {json_blob}")
async def post_image(self, to_wxid, image_url: str):
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"imgUrl": image_url,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postImage",
headers=self.headers,
json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送图片结果: {json_blob}")
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"voiceUrl": voice_url,
"voiceDuration": voice_duration
}
logger.debug(f"发送语音: {payload}")
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postVoice",
headers=self.headers,
json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送语音结果: {json_blob}")
@@ -0,0 +1,51 @@
from astrbot import logger
import aiohttp
import json
class GeweDownloader():
def __init__(self, base_url: str, download_base_url: str, token: str):
self.base_url = base_url
self.download_base_url = download_base_url
self.headers = {
"Content-Type": "application/json",
"X-GEWE-TOKEN": token
}
async def _post_json(self, baseurl: str, route: str, payload: dict):
async with aiohttp.ClientSession() as session:
async with session.post(
f"{baseurl}{route}",
headers=self.headers,
json=payload
) as resp:
return await resp.read()
async def download_voice(self, appid: str, xml: str, msg_id: str):
payload = {
"appId": appid,
"xml": xml,
"msgId": msg_id
}
return await self._post_json(self.base_url, "/message/downloadVoice", payload)
async def download_image(self, appid: str, xml: str) -> str:
'''返回一个可下载的 URL'''
choices = [2, 3] # 2:常规图片 3:缩略图
for choice in choices:
try:
payload = {
"appId": appid,
"xml": xml,
"type": choice
}
data = await self._post_json(self.base_url, "/message/downloadImage", payload)
json_blob = json.loads(data)
if 'fileUrl' in json_blob['data']:
return self.download_base_url + json_blob['data']['fileUrl']
except BaseException as e:
logger.error(f"gewe download image: {e}")
continue
raise Exception("无法下载图片")
@@ -0,0 +1,102 @@
import wave
import uuid
import os
from astrbot.core.utils.io import save_temp_img, download_image_by_url, download_file
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, Record
from .client import SimpleGewechatClient
def get_wav_duration(file_path):
with wave.open(file_path, 'rb') as wav_file:
file_size = os.path.getsize(file_path)
n_channels, sampwidth, framerate, n_frames = wav_file.getparams()[:4]
if n_frames == 2147483647:
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
else:
duration = n_frames / float(framerate)
return duration
class GewechatPlatformEvent(AstrMessageEvent):
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
client: SimpleGewechatClient
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
@staticmethod
async def send_with_client(message: MessageChain, user_name: str):
pass
async def send(self, message: MessageChain):
to_wxid = self.message_obj.raw_message.get('to_wxid', None)
if not to_wxid:
logger.error("无法获取到 to_wxid。")
return
for comp in message.chain:
if isinstance(comp, Plain):
await self.client.post_text(to_wxid, comp.text)
elif isinstance(comp, Image):
img_url = comp.file
img_path = ""
if img_url.startswith("file:///"):
img_path = img_url[8:]
elif comp.file and comp.file.startswith("http"):
img_path = await download_image_by_url(comp.file)
else:
img_path = img_url
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
temp_directory = os.path.abspath('data/temp')
img_path = os.path.abspath(img_path)
if os.path.commonpath([temp_directory, img_path]) != temp_directory:
with open(img_path, "rb") as f:
img_path = save_temp_img(f.read())
file_id = os.path.basename(img_path)
img_url = f"{self.client.file_server_url}/{file_id}"
logger.debug(f"gewe callback img url: {img_url}")
await self.client.post_image(to_wxid, img_url)
elif isinstance(comp, Record):
# 默认已经存在 data/temp 中
record_url = comp.file
record_path = ""
if record_url.startswith("file:///"):
record_path = record_url[8:]
elif record_url.startswith("http"):
await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav")
else:
record_path = record_url
silk_path = f"data/temp/{uuid.uuid4()}.silk"
duration = await wav_to_tencent_silk(record_path, silk_path)
print(f"duration: {duration}, {silk_path}")
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
# temp_directory = os.path.abspath('data/temp')
# record_path = os.path.abspath(record_path)
# if os.path.commonpath([temp_directory, record_path]) != temp_directory:
# with open(record_path, "rb") as f:
# record_path = f"data/temp/{uuid.uuid4()}.wav"
# with open(record_path, "wb") as f2:
# f2.write(f.read())
if duration == 0:
duration = get_wav_duration(record_path)
file_id = os.path.basename(silk_path)
record_url = f"{self.client.file_server_url}/{file_id}"
await self.client.post_voice(to_wxid, record_url, duration*1000)
await super().send(message)
@@ -0,0 +1,93 @@
import sys
import asyncio
import os
from astrbot.api.platform import Platform, AstrBotMessage, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from astrbot.api import logger
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from .gewechat_event import GewechatPlatformEvent
from .client import SimpleGewechatClient
from astrbot.core.message.components import Plain
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
@register_platform_adapter("gewechat", "基于 gewechat 的 Wechat 适配器")
class GewechatPlatformAdapter(Platform):
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settingss = platform_settings
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
self.client = None
@override
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
to_wxid = session.session_id
if "_" in to_wxid:
# 群聊,开启了独立会话
_, to_wxid = to_wxid.split("_")
if not to_wxid:
logger.error("无法获取到 to_wxid。")
return
for comp in message_chain.chain:
if isinstance(comp, Plain):
await self.client.post_text(to_wxid, comp.text)
await super().send_by_session(session, message_chain)
@override
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"gewechat",
"基于 gewechat 的 Wechat 适配器",
)
@override
def run(self):
self.client = SimpleGewechatClient(
self.config['base_url'],
self.config['nickname'],
self.config['host'],
self.config['port'],
self._event_queue,
)
async def on_event_received(abm: AstrBotMessage):
await self.handle_msg(abm)
self.client.on_event_received = on_event_received
return self._run()
async def logout(self):
await self.client.logout()
async def _run(self):
await self.client.login()
await self.client.start_polling()
async def handle_msg(self, message: AstrBotMessage):
if message.type == MessageType.GROUP_MESSAGE:
if self.settingss['unique_session']:
message.session_id = message.sender.user_id + "_" + message.group_id
message_event = GewechatPlatformEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self.client
)
self.commit_event(message_event)
@@ -14,12 +14,20 @@ class QQOfficialMessageEvent(AstrMessageEvent):
def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, bot: Client):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.bot = bot
self.send_buffer = None
async def send(self, message: MessageChain):
if not self.send_buffer:
self.send_buffer = message
else:
self.send_buffer.chain.extend(message.chain)
async def _post_send(self):
'''QQ 官方 API 仅支持回复一次'''
source = self.message_obj.raw_message
assert isinstance(source, (botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage))
plain_text, image_base64, image_path = await QQOfficialMessageEvent._parse_to_qqofficial(message)
plain_text, image_base64, image_path = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer)
payload = {
'content': plain_text,
@@ -48,7 +56,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
payload['file_image'] = image_path
await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
await super().send(message)
await super().send(self.send_buffer)
self.send_buffer = None
async def upload_group_and_c2c_image(self, image_base64: str, file_type: int, **kwargs) -> botpy.types.message.Media:
payload = {
@@ -80,4 +90,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
elif i.file and i.file.startswith("http"):
image_file_path = await download_image_by_url(i.file)
image_base64 = file_to_base64(image_file_path).replace("base64://", "")
else:
image_base64 = file_to_base64(i.file).replace("base64://", "")
image_file_path = i.file
return plain_text, image_base64, image_file_path
@@ -0,0 +1,112 @@
import time
import asyncio
import uuid
import os
from typing import Awaitable, Any
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from astrbot.api.message_components import Plain, Image, Record # noqa: F403
from astrbot.api import logger
from astrbot.core import web_chat_queue, web_chat_back_queue
from .webchat_event import WebChatMessageEvent
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
class QueueListener:
def __init__(self, queue: asyncio.Queue, callback: callable) -> None:
self.queue = queue
self.callback = callback
async def run(self):
while True:
data = await self.queue.get()
await self.callback(data)
@register_platform_adapter("webchat", "webchat")
class WebChatAdapter(Platform):
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settings = platform_settings
self.unique_session = platform_settings['unique_session']
self.imgs_dir = "data/webchat/imgs"
self.metadata = PlatformMetadata(
"webchat",
"webchat",
)
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
# abm.session_id = f"webchat!{username}!{cid}"
plain = ""
cid = session.session_id.split("!")[-1]
for comp in message_chain.chain:
if isinstance(comp, Plain):
plain += comp.text
web_chat_back_queue.put_nowait((plain, cid))
await super().send_by_session(session, message_chain)
async def convert_message(self, data: tuple) -> AstrBotMessage:
username, cid, payload = data
abm = AstrBotMessage()
abm.self_id = "webchat"
abm.tag = "webchat"
abm.sender = MessageMember(username, username)
abm.type = MessageType.FRIEND_MESSAGE
abm.session_id = f"webchat!{username}!{cid}"
abm.message_id = str(uuid.uuid4())
abm.message = []
if payload['message']:
abm.message.append(Plain(payload['message']))
if payload['image_url']:
if isinstance(payload['image_url'], list):
for img in payload['image_url']:
abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, img)))
else:
abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, payload['image_url'])))
if payload['audio_url']:
if isinstance(payload['audio_url'], list):
for audio in payload['audio_url']:
path = os.path.join(self.imgs_dir, audio)
abm.message.append(Record(file=path, path=path))
else:
path = os.path.join(self.imgs_dir, payload['audio_url'])
abm.message.append(Record(file=path, path=path))
logger.debug(f"WebChatAdapter: {abm.message}")
message_str = payload['message']
abm.timestamp = int(time.time())
abm.message_str = message_str
abm.raw_message = data
return abm
def run(self) -> Awaitable[Any]:
async def callback(data: tuple):
abm = await self.convert_message(data)
await self.handle_msg(abm)
bot = QueueListener(web_chat_queue, callback)
return bot.run()
def meta(self) -> PlatformMetadata:
return self.metadata
async def handle_msg(self, message: AstrBotMessage):
message_event = WebChatMessageEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id
)
self.commit_event(message_event)
@@ -0,0 +1,41 @@
import os
import uuid
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image
from astrbot.core.utils.io import file_to_base64, download_image_by_url
from astrbot.core import web_chat_back_queue
class WebChatMessageEvent(AstrMessageEvent):
def __init__(self, message_str, message_obj, platform_meta, session_id):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.imgs_dir = "data/webchat/imgs"
os.makedirs(self.imgs_dir, exist_ok=True)
async def send(self, message: MessageChain):
if not message:
web_chat_back_queue.put_nowait(None)
return
cid = self.session_id.split("!")[-1]
for comp in message.chain:
if isinstance(comp, Plain):
web_chat_back_queue.put_nowait((comp.text, cid))
elif isinstance(comp, Image):
# save image to local
filename = str(uuid.uuid4()) + ".jpg"
path = os.path.join(self.imgs_dir, filename)
if comp.file and comp.file.startswith("file:///"):
ph = comp.file[8:]
with open(path, "wb") as f:
with open(ph, "rb") as f2:
f.write(f2.read())
elif comp.file and comp.file.startswith("http"):
await download_image_by_url(comp.file, path=path)
else:
with open(path, "wb") as f:
with open(comp.file, "rb") as f2:
f.write(f2.read())
web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid))
web_chat_back_queue.put_nowait(None)
await super().send(message)
+2 -1
View File
@@ -1,4 +1,4 @@
from .provider import Provider, Personality
from .provider import Provider, Personality, STTProvider
from .entites import ProviderMetaData
@@ -6,4 +6,5 @@ __all__ = [
"Provider",
"Personality",
"ProviderMetaData",
"STTProvider"
]
+17 -3
View File
@@ -1,13 +1,27 @@
import enum
from dataclasses import dataclass, field
from typing import List, Dict
from typing import List, Dict, Type
from .func_tool_manager import FuncCall
class ProviderType(enum.Enum):
CHAT_COMPLETION = "chat_completion"
SPEECH_TO_TEXT = "speech_to_text"
TEXT_TO_SPEECH = "text_to_speech"
@dataclass
class ProviderMetaData():
type: str # 提供商适配器名称,如 openai, ollama
desc: str = "" # 提供商适配器描述.
type: str
'''提供商适配器名称,如 openai, ollama'''
desc: str = ""
'''提供商适配器描述.'''
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
cls_type: Type = None
default_config_tmpl: dict = None
'''平台的默认配置模板'''
provider_display_name: str = None
'''显示在 WebUI 配置页中的提供商名称,如空则是 type'''
@dataclass
class ProviderRequest():
+155 -23
View File
@@ -1,6 +1,7 @@
import traceback
from astrbot.core.config.astrbot_config import AstrBotConfig
from .provider import Provider
from .provider import Provider, STTProvider, TTSProvider, Personality
from .entites import ProviderType
from typing import List
from astrbot.core.db import BaseDatabase
from collections import defaultdict
@@ -11,13 +12,72 @@ class ProviderManager():
def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase):
self.providers_config: List = config['provider']
self.provider_settings: dict = config['provider_settings']
self.provider_stt_settings: dict = config.get('provider_stt_settings', {})
self.provider_tts_settings: dict = config.get('provider_tts_settings', {})
self.persona_configs: list = config.get('persona', [])
self.default_persona_name = self.provider_settings.get('default_personality', 'default')
self.personas: List[Personality] = []
self.selected_default_persona = None
for persona in self.persona_configs:
begin_dialogs = persona.get("begin_dialogs", [])
mood_imitation_dialogs = persona.get("mood_imitation_dialogs", [])
bd_processed = []
mid_processed = ""
if begin_dialogs:
if len(begin_dialogs) % 2 != 0:
logger.error(f"{persona['name']} 人格情景预设对话格式不对,条数应该为偶数。")
continue
user_turn = True
for dialog in begin_dialogs:
bd_processed.append({
"role": "user" if user_turn else "assistant",
"content": dialog,
"_no_save": None # 不持久化到 db
})
user_turn = not user_turn
if mood_imitation_dialogs:
if len(mood_imitation_dialogs) % 2 != 0:
logger.error(f"{persona['name']} 对话风格对话格式不对,条数应该为偶数。")
continue
user_turn = True
for dialog in begin_dialogs:
role = "A" if user_turn else "B"
mid_processed += f"{role}: {dialog}\n"
if not user_turn:
mid_processed += '\n'
user_turn = not user_turn
try:
persona = Personality(
**persona,
_begin_dialogs_processed=bd_processed,
_mood_imitation_dialogs_processed=mid_processed
)
if persona['name'] == self.default_persona_name:
self.selected_default_persona = persona
self.personas.append(persona)
except Exception as e:
logger.error(f"解析 Persona 配置失败:{e}")
self.provider_insts: List[Provider] = []
'''加载的 Provider 的实例'''
self.stt_provider_insts: List[STTProvider] = []
'''加载的 Speech To Text Provider 的实例'''
self.tts_provider_insts: Lieist[TTSProvider] = []
'''加载的 Text To Speech Provider 的实例'''
self.llm_tools = llm_tools
self.curr_provider_inst: Provider = None
'''当前使用的 Provider 实例'''
self.curr_stt_provider_inst: STTProvider = None
'''当前使用的 Speech To Text Provider 实例'''
self.curr_tts_provider_inst: TTSProvider = None
'''当前使用的 Text To Speech Provider 实例'''
self.loaded_ids = defaultdict(bool)
self.db_helper = db_helper
# kdb(experimental)
self.curr_kdb_name = ""
kdb_cfg = config.get("knowledge_db", {})
if kdb_cfg and len(kdb_cfg):
@@ -31,45 +91,117 @@ class ProviderManager():
raise ValueError(f"Provider ID 重复:{provider_cfg['id']}")
self.loaded_ids[provider_cfg['id']] = True
match provider_cfg['type']:
case "openai_chat_completion":
from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401
case "zhipu_chat_completion":
from .sources.zhipu_source import ProviderZhipu # noqa: F401
case "llm_tuner":
logger.info("加载 LLM Tuner 工具 ...")
from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401
case "dify":
from .sources.dify_source import ProviderDify # noqa: F401
case "googlegenai_chat_completion":
from .sources.gemini_source import ProviderGoogleGenAI # noqa: F401
try:
match provider_cfg['type']:
case "openai_chat_completion":
from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401
case "zhipu_chat_completion":
from .sources.zhipu_source import ProviderZhipu # noqa: F401
case "llm_tuner":
logger.info("加载 LLM Tuner 工具 ...")
from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401
case "dify":
from .sources.dify_source import ProviderDify # noqa: F401
case "googlegenai_chat_completion":
from .sources.gemini_source import ProviderGoogleGenAI # noqa: F401
case "openai_whisper_api":
from .sources.whisper_api_source import ProviderOpenAIWhisperAPI # noqa: F401
case "openai_whisper_selfhost":
from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost # noqa: F401
case "openai_tts_api":
from .sources.openai_tts_api_source import ProviderOpenAITTSAPI # noqa: F401
except (ImportError, ModuleNotFoundError) as e:
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。")
continue
except Exception as e:
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。未知原因")
continue
async def initialize(self):
selected_provider_id = sp.get("curr_provider")
selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
selected_tts_provider_id = self.provider_settings.get("provider_id")
provider_enabled = self.provider_settings.get("enable", False)
stt_enabled = self.provider_stt_settings.get("enable", False)
tts_enabled = self.provider_tts_settings.get("enable", False)
for provider_config in self.providers_config:
if not provider_config['enable']:
continue
if provider_config['type'] not in provider_cls_map:
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
continue
selected_provider_id = sp.get("curr_provider")
cls_type = provider_cls_map[provider_config['type']]
provider_metadata = provider_cls_map[provider_config['type']]
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...")
try:
inst = cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True))
self.provider_insts.append(inst)
if selected_provider_id == provider_config['id']:
self.curr_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。")
# 按任务实例化提供商
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
# STT 任务
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.stt_provider_insts.append(inst)
if selected_stt_provider_id == provider_config['id'] and stt_enabled:
self.curr_stt_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。")
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
# TTS 任务
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.tts_provider_insts.append(inst)
if selected_tts_provider_id == provider_config['id'] and tts_enabled:
self.curr_tts_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。")
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
# 文本生成任务
inst = provider_metadata.cls_type(
provider_config,
self.provider_settings,
self.db_helper,
self.provider_settings.get('persistant_history', True),
self.selected_default_persona
)
if getattr(inst, "initialize", None):
await inst.initialize()
self.provider_insts.append(inst)
if selected_provider_id == provider_config['id'] and provider_enabled:
self.curr_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。")
except Exception as e:
traceback.print_exc()
logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}")
if len(self.provider_insts) > 0 and not self.curr_provider_inst:
if len(self.provider_insts) > 0 and not self.curr_provider_inst and provider_enabled:
self.curr_provider_inst = self.provider_insts[0]
if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst and stt_enabled:
self.curr_stt_provider_inst = self.stt_provider_insts[0]
if len(self.tts_provider_insts) > 0 and not self.curr_tts_provider_inst and tts_enabled:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
if not self.curr_provider_inst:
logger.warning("未启用任何提供商适配器。")
logger.warning("未启用任何用于 文本生成 的提供商适配器。")
if stt_enabled and not self.curr_stt_provider_inst:
logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
if tts_enabled and not self.curr_tts_provider_inst:
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
def get_insts(self):
return self.provider_insts
+60 -23
View File
@@ -11,34 +11,62 @@ from dataclasses import dataclass
class Personality(TypedDict):
prompt: str = ""
name: str = ""
begin_dialogs: List[str] = []
mood_imitation_dialogs: List[str] = []
# cache
_begin_dialogs_processed: List[dict]
_mood_imitation_dialogs_processed: str
@dataclass
class ProviderMeta():
id: str
model: str
type: str
class AbstractProvider(abc.ABC):
def __init__(self, provider_config: dict) -> None:
super().__init__()
self.model_name = ""
self.provider_config = provider_config
def set_model(self, model_name: str):
'''设置当前使用的模型名称'''
self.model_name = model_name
def get_model(self) -> str:
'''获得当前使用的模型名称'''
return self.model_name
def meta(self) -> ProviderMeta:
'''获取 Provider 的元数据'''
return ProviderMeta(
id=self.provider_config['id'],
model=self.get_model(),
type=self.provider_config['type']
)
class Provider(abc.ABC):
class Provider(AbstractProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
persistant_history: bool = True,
db_helper: BaseDatabase = None
db_helper: BaseDatabase = None,
default_persona: Personality = None
) -> None:
self.model_name = ""
'''当前使用的模型名称'''
super().__init__(provider_config)
self.session_memory = defaultdict(list)
'''维护了 session_id 的上下文,**不包含 system 指令**。'''
self.provider_config = provider_config
self.provider_settings = provider_settings
self.curr_personality = Personality(prompt=provider_settings['default_personality'])
'''维护了当前的使用的 persona,即人格。'''
self.curr_personality: Personality = default_persona
'''维护了当前的使用的 persona,即人格。可能为 None'''
self.db_helper = db_helper
'''用于持久化的数据库操作对象。'''
@@ -50,14 +78,6 @@ class Provider(abc.ABC):
self.session_memory[history.session_id] = json.loads(history.content)
except BaseException as e:
logger.warning(f"读取 LLM 对话历史记录 失败:{e}。仍可正常使用。")
def set_model(self, model_name: str):
'''设置当前使用的模型名称'''
self.model_name = model_name
def get_model(self) -> str:
'''获得当前使用的模型名称'''
return self.model_name
@abc.abstractmethod
def get_current_key(self) -> str:
@@ -125,10 +145,27 @@ class Provider(abc.ABC):
'''重置某一个 session_id 的上下文'''
raise NotImplementedError()
def meta(self) -> ProviderMeta:
'''获取 Provider 的元数据'''
return ProviderMeta(
id=self.provider_config['id'],
model=self.get_model(),
type=self.provider_config['type']
)
class STTProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config)
self.provider_config = provider_config
self.provider_settings = provider_settings
@abc.abstractmethod
async def get_text(self, audio_url: str) -> str:
'''获取音频的文本'''
raise NotImplementedError()
class TTSProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config)
self.provider_config = provider_config
self.provider_settings = provider_settings
@abc.abstractmethod
async def get_audio(self, text: str) -> str:
'''获取文本的音频,返回音频文件路径'''
raise NotImplementedError()
+25 -6
View File
@@ -1,28 +1,47 @@
from typing import List, Dict, Type
from .entites import ProviderMetaData
from .entites import ProviderMetaData, ProviderType
from astrbot.core import logger
from .func_tool_manager import FuncCall
provider_registry: List[ProviderMetaData] = []
'''维护了通过装饰器注册的 Provider'''
provider_cls_map: Dict[str, Type] = {}
'''维护了 Provider 类型名称和 Provider 的映射'''
provider_cls_map: Dict[str, ProviderMetaData] = {}
'''维护了 Provider 类型名称和 ProviderMetadata 的映射'''
llm_tools = FuncCall()
def register_provider_adapter(provider_type_name: str, desc: str):
def register_provider_adapter(
provider_type_name: str,
desc: str,
provider_type: ProviderType = ProviderType.CHAT_COMPLETION,
default_config_tmpl: dict = None,
provider_display_name: str = None
):
'''用于注册平台适配器的带参装饰器'''
def decorator(cls):
if provider_type_name in provider_cls_map:
raise ValueError(f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。")
# 添加必备选项
if default_config_tmpl:
if 'type' not in default_config_tmpl:
default_config_tmpl['type'] = provider_type_name
if 'enable' not in default_config_tmpl:
default_config_tmpl['enable'] = False
if 'id' not in default_config_tmpl:
default_config_tmpl['id'] = provider_type_name
pm = ProviderMetaData(
type=provider_type_name,
desc=desc,
provider_type=provider_type,
cls_type=cls,
default_config_tmpl=default_config_tmpl,
provider_display_name=provider_display_name
)
provider_registry.append(pm)
provider_cls_map[provider_type_name] = cls
logger.debug(f"Provider {provider_type_name} 已注册")
provider_cls_map[provider_type_name] = pm
logger.debug(f"服务提供商 Provider {provider_type_name} 已注册")
return cls
return decorator
+13 -6
View File
@@ -1,13 +1,12 @@
from typing import List
from .. import Provider
from .. import Provider, Personality
from ..entites import LLMResponse
from ..func_tool_manager import FuncCall
from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter
from astrbot.core.utils.dify_api_client import DifyAPIClient
from astrbot.core.utils.io import download_image_by_url
from astrbot.core import logger
from astrbot.core import logger, sp
@register_provider_adapter("dify", "Dify APP 适配器。")
class ProviderDify(Provider):
@@ -17,9 +16,10 @@ class ProviderDify(Provider):
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history=False,
default_persona: Personality=None
) -> None:
super().__init__(
provider_config, provider_settings, persistant_history, db_helper
provider_config, provider_settings, persistant_history, db_helper, default_persona
)
self.api_key = provider_config.get("dify_api_key", "")
if not self.api_key:
@@ -67,10 +67,16 @@ class ProviderDify(Provider):
logger.debug(files_payload)
# 获得会话变量
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
match self.api_type:
case "chat" | "agent":
async for chunk in self.api_client.chat_messages(
inputs={},
inputs={
**session_var
},
query=prompt,
user=session_id,
conversation_id=conversation_id,
@@ -88,7 +94,8 @@ class ProviderDify(Provider):
async for chunk in self.api_client.workflow_run(
inputs={
"astrbot_text_query": prompt,
"astrbot_session_id": session_id
"astrbot_session_id": session_id,
**session_var
},
user=session_id,
files=files_payload
+29 -12
View File
@@ -4,7 +4,7 @@ import json
import aiohttp
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.db import BaseDatabase
from astrbot.api.provider import Provider
from astrbot.api.provider import Provider, Personality
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
@@ -18,7 +18,7 @@ class SimpleGoogleGenAIClient():
self.api_base = api_base[:-1]
else:
self.api_base = api_base
self.client = aiohttp.ClientSession()
self.client = aiohttp.ClientSession(trust_env=True)
async def models_list(self) -> List[str]:
request_url = f"{self.api_base}/v1beta/models?key={self.api_key}"
@@ -60,9 +60,10 @@ class ProviderGoogleGenAI(Provider):
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history = True
persistant_history = True,
default_persona: Personality=None
) -> None:
super().__init__(provider_config, provider_settings, persistant_history, db_helper)
super().__init__(provider_config, provider_settings, persistant_history, db_helper, default_persona)
self.chosen_api_key = None
self.api_keys: List = provider_config.get("key", [])
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
@@ -130,6 +131,8 @@ class ProviderGoogleGenAI(Provider):
tool = None
if tools:
tool = tools.get_func_desc_google_genai_style()
if not tool:
tool = None
system_instruction = ""
for message in payloads["messages"]:
@@ -209,6 +212,10 @@ class ProviderGoogleGenAI(Provider):
context_query = [*contexts, new_record]
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
for part in context_query:
if '_no_save' in part:
del part['_no_save']
payloads = {
"messages": context_query,
@@ -217,15 +224,24 @@ class ProviderGoogleGenAI(Provider):
try:
llm_response = await self._query(payloads, func_tool)
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
except Exception as e:
if "maximum context length" in str(e):
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
self.pop_record(session_id)
logger.warning(traceback.format_exc())
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
retry_cnt = 10
while retry_cnt > 0:
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
try:
self.pop_record(session_id)
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
if "maximum context length" in str(e):
retry_cnt -= 1
else:
raise e
else:
raise e
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
if llm_response.role == "assistant" and session_id:
@@ -239,7 +255,8 @@ class ProviderGoogleGenAI(Provider):
"content": llm_response.completion_text
})
else:
self.session_memory[session_id] = [*contexts, new_record, {
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
self.session_memory[session_id] = [*contexts_to_save, new_record, {
"role": "assistant",
"content": llm_response.completion_text
}]
@@ -2,7 +2,7 @@ import json
import os
from llmtuner.chat import ChatModel
from typing import List
from .. import Provider
from .. import Provider, Personality
from ..entites import LLMResponse
from ..func_tool_manager import FuncCall
from astrbot.core.db import BaseDatabase
@@ -19,9 +19,10 @@ class LLMTunerModelLoader(Provider):
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history=True,
default_persona=None,
) -> None:
super().__init__(
provider_config, provider_settings, persistant_history, db_helper
provider_config, provider_settings, persistant_history, db_helper, default_persona
)
if not os.path.exists(provider_config["base_model_path"]) or not os.path.exists(
provider_config["adapter_model_path"]
@@ -61,20 +62,25 @@ class LLMTunerModelLoader(Provider):
**kwargs,
) -> LLMResponse:
system_prompt = ""
new_record = {"role": "user", "content": prompt}
if not contexts:
query_context = [
*self.session_memory[session_id],
{"role": "user", "content": prompt},
new_record,
]
system_prompt = self.curr_personality["prompt"]
else:
query_context = [*contexts, {"role": "user", "content": prompt}]
query_context = [*contexts, new_record]
# 提取出系统提示
system_idxs = []
for idx, context in enumerate(query_context):
if context["role"] == "system":
system_idxs.append(idx)
if '_no_save' in context:
del context['_no_save']
for idx in reversed(system_idxs):
system_prompt += " " + query_context.pop(idx)["content"]
@@ -83,27 +89,37 @@ class LLMTunerModelLoader(Provider):
"system": system_prompt,
}
if func_tool:
conf["tools"] = func_tool
tool_list = func_tool.get_func_desc_openai_style()
if tool_list:
conf['tools'] = tool_list
responses = await self.model.achat(**conf)
if session_id:
llm_response = LLMResponse("assistant", responses[-1].response_text)
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
if llm_response.role == "assistant" and session_id:
# 文本回复
if not contexts:
self.session_memory[session_id].append(
{"role": "user", "content": prompt}
)
self.session_memory[session_id].append(
{"role": "assistant", "content": responses[-1].response_text}
)
# 添加用户 record
self.session_memory[session_id].append(new_record)
# 添加 assistant record
self.session_memory[session_id].append({
"role": "assistant",
"content": llm_response.completion_text
})
else:
self.session_memory[session_id] = [
*contexts,
{"role": "user", "content": prompt},
{"role": "assistant", "content": responses[-1].response_text},
]
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.meta().type)
return responses[-1].response_text
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
self.session_memory[session_id] = [*contexts_to_save, new_record, {
"role": "assistant",
"content": llm_response.completion_text
}]
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
async def forget(self, session_id):
self.session_memory[session_id] = []
return True
+55 -17
View File
@@ -1,6 +1,6 @@
import traceback
import base64
import json
import re
from openai import AsyncOpenAI, NOT_GIVEN
from openai.types.chat.chat_completion import ChatCompletion
@@ -8,7 +8,7 @@ from openai._exceptions import NotFoundError
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.db import BaseDatabase
from astrbot.api.provider import Provider
from astrbot.api.provider import Provider, Personality
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
@@ -22,9 +22,10 @@ class ProviderOpenAIOfficial(Provider):
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history = True
persistant_history = True,
default_persona: Personality = None
) -> None:
super().__init__(provider_config, provider_settings, persistant_history, db_helper)
super().__init__(provider_config, provider_settings, persistant_history, db_helper, default_persona)
self.chosen_api_key = None
self.api_keys: List = provider_config.get("key", [])
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
@@ -99,15 +100,27 @@ class ProviderOpenAIOfficial(Provider):
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
if tools:
payloads["tools"] = tools.get_func_desc_openai_style()
tool_list = tools.get_func_desc_openai_style()
if tool_list:
payloads['tools'] = tool_list
completion = await self.client.chat.completions.create(
**payloads,
stream=False
)
try:
completion = await self.client.chat.completions.create(
**payloads,
stream=False
)
except BaseException as e:
if 'does not support Function Calling' \
or 'does not support tools' in e: # ollama
del payloads['tools']
logger.debug(f"模型 {self.model_name} 不支持 tools,已自动移除")
completion = await self.client.chat.completions.create(
**payloads,
stream=False
)
assert isinstance(completion, ChatCompletion)
logger.debug(f"completion: {completion.usage}")
logger.debug(f"completion: {completion}")
if len(completion.choices) == 0:
raise Exception("API 返回的 completion 为空。")
@@ -116,6 +129,13 @@ class ProviderOpenAIOfficial(Provider):
if choice.message.content:
# text completion
completion_text = str(choice.message.content).strip()
# 适配 deepseek-r1 模型
if r'<think>' in completion_text:
completion_text = re.sub(r'<think>.*?</think>', '', completion_text, flags=re.DOTALL).strip()
# 可能有单标签情况
completion_text = completion_text.replace(r'<think>', '').replace(r'</think>', '').strip()
return LLMResponse("assistant", completion_text)
elif choice.message.tool_calls:
# tools call (function calling)
@@ -150,6 +170,10 @@ class ProviderOpenAIOfficial(Provider):
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
for part in context_query:
if '_no_save' in part:
del part['_no_save']
payloads = {
"messages": context_query,
**self.provider_config.get("model_config", {})
@@ -157,15 +181,26 @@ class ProviderOpenAIOfficial(Provider):
try:
llm_response = await self._query(payloads, func_tool)
if kwargs.get("persist", True):
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
except Exception as e:
if "maximum context length" in str(e):
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
self.pop_record(session_id)
logger.warning(traceback.format_exc())
await self.save_history(contexts, new_record, session_id, llm_response)
retry_cnt = 10
while retry_cnt > 0:
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
try:
self.pop_record(session_id)
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
if "maximum context length" in str(e):
retry_cnt -= 1
else:
raise e
else:
raise e
return llm_response
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
if llm_response.role == "assistant" and session_id:
@@ -179,7 +214,8 @@ class ProviderOpenAIOfficial(Provider):
"content": llm_response.completion_text
})
else:
self.session_memory[session_id] = [*contexts, new_record, {
contexts_to_save = list(filter(lambda item: '_no_save' not in item, contexts))
self.session_memory[session_id] = [*contexts_to_save, new_record, {
"role": "assistant",
"content": llm_response.completion_text
}]
@@ -209,6 +245,8 @@ class ProviderOpenAIOfficial(Provider):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
else:
if image_url.startswith("file:///"):
image_url = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_url)
user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}})
return user_content
@@ -0,0 +1,40 @@
import uuid
import os
from openai import AsyncOpenAI, NOT_GIVEN
from ..provider import TTSProvider
from ..entites import ProviderType
from ..register import register_provider_adapter
@register_provider_adapter("openai_tts_api", "OpenAI TTS API", provider_type=ProviderType.TEXT_TO_SPEECH)
class ProviderOpenAITTSAPI(TTSProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.chosen_api_key = provider_config.get("api_key", "")
self.voice = provider_config.get("voice", "alloy")
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base", None),
timeout=provider_config.get("timeout", NOT_GIVEN),
)
self.set_model(provider_config.get("model", None))
async def get_audio(self, text: str) -> str:
path = f'data/temp/openai_tts_api_{uuid.uuid4()}.wav'
async with self.client.audio.speech.with_streaming_response.create(
model=self.model_name,
voice=self.voice,
response_format='wav',
input=text
) as response:
with open(path, 'wb') as f:
async for chunk in response.iter_bytes(chunk_size=1024):
f.write(chunk)
return path
@@ -0,0 +1,74 @@
import uuid
import os
from openai import AsyncOpenAI, NOT_GIVEN
from ..provider import STTProvider
from ..entites import ProviderType
from astrbot.core.utils.io import download_file
from ..register import register_provider_adapter
from astrbot.core import logger
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
@register_provider_adapter("openai_whisper_api", "OpenAI Whisper API", provider_type=ProviderType.SPEECH_TO_TEXT)
class ProviderOpenAIWhisperAPI(STTProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.chosen_api_key = provider_config.get("api_key", "")
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base", None),
timeout=provider_config.get("timeout", NOT_GIVEN),
)
self.set_model(provider_config.get("model", None))
async def _convert_audio(self, path: str) -> str:
from pyffmpeg import FFmpeg
filename = str(uuid.uuid4()) + '.mp3'
ff = FFmpeg()
output_path = ff.convert(path, os.path.join('data/temp', filename))
return output_path
async def _is_silk_file(self, file_path):
silk_header = b"SILK"
with open(file_path, "rb") as f:
file_header = f.read(8)
if silk_header in file_header:
return True
else:
return False
async def get_text(self, audio_url: str) -> str:
'''only supports mp3, mp4, mpeg, m4a, wav, webm'''
is_tencent = False
if audio_url.startswith("http"):
if "multimedia.nt.qq.com.cn" in audio_url:
is_tencent = True
name = str(uuid.uuid4())
path = os.path.join("data/temp", name)
await download_file(audio_url, path)
audio_url = path
if not os.path.exists(audio_url):
raise FileNotFoundError(f"文件不存在: {audio_url}")
if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent:
is_silk = await self._is_silk_file(audio_url)
if is_silk:
logger.info("Converting silk file to wav ...")
output_path = os.path.join('data/temp', str(uuid.uuid4()) + '.wav')
await tencent_silk_to_wav(audio_url, output_path)
audio_url = output_path
result = await self.client.audio.transcriptions.create(
model=self.model_name,
file=open(audio_url, "rb"),
)
return result.text
@@ -0,0 +1,72 @@
import uuid
import os
import asyncio
import whisper
from ..provider import STTProvider
from ..entites import ProviderType
from astrbot.core.utils.io import download_file
from ..register import register_provider_adapter
from astrbot.core import logger
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
@register_provider_adapter("openai_whisper_selfhost", "OpenAI Whisper 模型部署", provider_type=ProviderType.SPEECH_TO_TEXT)
class ProviderOpenAIWhisperSelfHost(STTProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.set_model(provider_config.get("model", None))
self.model = None
async def initialize(self):
loop = asyncio.get_event_loop()
logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...")
self.model = await loop.run_in_executor(None, whisper.load_model, self.model_name)
logger.info("Whisper 模型加载完成。")
async def _convert_audio(self, path: str) -> str:
from pyffmpeg import FFmpeg
filename = str(uuid.uuid4()) + '.mp3'
ff = FFmpeg()
output_path = ff.convert(path, os.path.join('data/temp', filename))
return output_path
async def _is_silk_file(self, file_path):
silk_header = b"SILK"
with open(file_path, "rb") as f:
file_header = f.read(8)
if silk_header in file_header:
return True
else:
return False
async def get_text(self, audio_url: str) -> str:
loop = asyncio.get_event_loop()
is_tencent = False
if audio_url.startswith("http"):
if "multimedia.nt.qq.com.cn" in audio_url:
is_tencent = True
name = str(uuid.uuid4())
path = os.path.join("data/temp", name)
await download_file(audio_url, path)
audio_url = path
if not os.path.exists(audio_url):
raise FileNotFoundError(f"文件不存在: {audio_url}")
if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent:
is_silk = await self._is_silk_file(audio_url)
if is_silk:
logger.info("Converting silk file to wav ...")
output_path = os.path.join('data/temp', str(uuid.uuid4()) + '.wav')
await tencent_silk_to_wav(audio_url, output_path)
audio_url = output_path
result = await loop.run_in_executor(None, self.model.transcribe, audio_url)
return result['text']
+19 -10
View File
@@ -14,9 +14,10 @@ class ProviderZhipu(ProviderOpenAIOfficial):
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history = True
persistant_history = True,
default_persona = None
) -> None:
super().__init__(provider_config, provider_settings, db_helper, persistant_history)
super().__init__(provider_config, provider_settings, db_helper, persistant_history, default_persona)
async def text_chat(
self,
@@ -59,15 +60,23 @@ class ProviderZhipu(ProviderOpenAIOfficial):
"messages": context_query,
**model_cfgs
}
llm_response = None
try:
llm_response = await self._query(payloads, func_tool)
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
except Exception as e:
if "maximum context length" in str(e):
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
self.pop_record(session_id)
logger.warning(traceback.format_exc())
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
retry_cnt = 10
while retry_cnt > 0:
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
try:
self.pop_record(session_id)
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
if "maximum context length" in str(e):
retry_cnt -= 1
else:
raise e
else:
raise e
+4
View File
@@ -1,3 +1,7 @@
'''
此功能已过时,参考 https://astrbot.app/dev/plugin.html#%E6%B3%A8%E5%86%8C%E6%8F%92%E4%BB%B6%E9%85%8D%E7%BD%AE-beta
'''
from typing import Union
import os
import json
+104 -106
View File
@@ -10,17 +10,13 @@ from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.platform.manager import PlatformManager
from .star import star_registry, StarMetadata
from .star import star_registry, StarMetadata, star_map
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
from .filter.command import CommandFilter
from .filter.regex import RegexFilter
from typing import Awaitable
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
class StarCommand(TypedDict):
full_command_name: str
command_name: str
class Context:
'''
暴露给插件的接口上下文。
@@ -58,46 +54,19 @@ class Context:
self.knowledge_db_manager = knowledge_db_manager
def get_registered_star(self, star_name: str) -> StarMetadata:
'''根据插件名获取插件的 Metadata'''
for star in star_registry:
if star.name == star_name:
return star
def get_all_stars(self) -> List[StarMetadata]:
'''获取当前载入的所有插件 Metadata 的列表'''
return star_registry
def get_llm_tool_manager(self) -> FuncCall:
'''
获取 LLM Tool Manager
'''
'''获取 LLM Tool Manager,其用于管理注册的所有的 Function-calling tools'''
return self.provider_manager.llm_tools
def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None:
'''
为函数调用(function-calling / tools-use)添加工具。
@param name: 函数名
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
@param desc: 函数描述
@param func_obj: 异步处理函数。
异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。
'''
md = StarHandlerMetadata(
event_type=EventType.OnLLMRequestEvent,
handler_full_name=func_obj.__module__ + "_" + func_obj.__name__,
handler_name=func_obj.__name__,
handler_module_path=func_obj.__module__,
handler=func_obj,
event_filters=[],
desc=desc
)
star_handlers_registry.append(md)
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj)
def unregister_llm_tool(self, name: str) -> None:
'''删除一个函数调用工具。如果再要启用,需要重新注册。'''
self.provider_manager.llm_tools.remove_func(name)
def activate_llm_tool(self, name: str) -> bool:
'''激活一个已经注册的函数调用工具。注册的工具默认是激活状态。
@@ -106,6 +75,11 @@ class Context:
'''
func_tool = self.provider_manager.llm_tools.get_func(name)
if func_tool is not None:
if func_tool.handler_module_path in star_map:
if not star_map[func_tool.handler_module_path].activated:
raise ValueError(f"此函数调用工具所属的插件 {star_map[func_tool.handler_module_path].name} 已被禁用,请先在管理面板启用再激活此工具。")
func_tool.active = True
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
@@ -133,6 +107,101 @@ class Context:
return True
return False
def register_provider(self, provider: Provider):
'''
注册一个 LLM Provider(Chat_Completion 类型)。
'''
self.provider_manager.provider_insts.append(provider)
def get_provider_by_id(self, provider_id: str) -> Provider:
'''通过 ID 获取用于文本生成任务的 LLM Provider(Chat_Completion 类型)。'''
for provider in self.provider_manager.provider_insts:
if provider.meta().id == provider_id:
return provider
return None
def get_all_providers(self) -> List[Provider]:
'''获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。'''
return self.provider_manager.provider_insts
def get_using_provider(self) -> Provider:
'''
获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。
通过 /provider 指令切换。
'''
return self.provider_manager.curr_provider_inst
def get_config(self) -> AstrBotConfig:
'''获取 AstrBot 的配置。'''
return self._config
def get_db(self) -> BaseDatabase:
'''获取 AstrBot 数据库。'''
return self._db
def get_event_queue(self) -> Queue:
'''
获取事件队列。
'''
return self._event_queue
async def send_message(self, session: Union[str, MessageSesion], message_chain: MessageChain) -> bool:
'''
根据 session(unified_msg_origin) 发送消息。
@param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。
@param message_chain: 消息链。
@return: 是否找到匹配的平台。
当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。
'''
if isinstance(session, str):
try:
session = MessageSesion.from_str(session)
except BaseException as e:
raise ValueError("不合法的 session 字符串: " + str(e))
for platform in self.platform_manager.platform_insts:
if platform.meta().name == session.platform_name:
await platform.send_by_session(session, message_chain)
return True
return False
'''
以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
'''
def register_llm_tool(self, name: str, func_args: list, desc: str, func_obj: Awaitable) -> None:
'''
为函数调用(function-calling / tools-use)添加工具。
@param name: 函数名
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
@param desc: 函数描述
@param func_obj: 异步处理函数。
异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。
'''
md = StarHandlerMetadata(
event_type=EventType.OnLLMRequestEvent,
handler_full_name=func_obj.__module__ + "_" + func_obj.__name__,
handler_name=func_obj.__name__,
handler_module_path=func_obj.__module__,
handler=func_obj,
event_filters=[],
desc=desc
)
star_handlers_registry.append(md)
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj)
def unregister_llm_tool(self, name: str) -> None:
'''删除一个函数调用工具。如果再要启用,需要重新注册。'''
self.provider_manager.llm_tools.remove_func(name)
def register_commands(self, star_name: str, command_name: str, desc: str, priority: int, awaitable: Awaitable, use_regex=False, ignore_prefix=False):
'''
注册一个命令。
@@ -166,77 +235,6 @@ class Context:
))
star_handlers_registry.append(md)
def register_provider(self, provider: Provider):
'''
注册一个 LLM Provider。
'''
self.provider_manager.provider_insts.append(provider)
def get_provider_by_id(self, provider_id: str) -> Provider:
'''
通过 ID 获取 LLM Provider。
'''
for provider in self.provider_manager.provider_insts:
if provider.meta().id == provider_id:
return provider
return None
def get_all_providers(self) -> List[Provider]:
'''
获取所有 LLM Provider。
'''
return self.provider_manager.provider_insts
def get_using_provider(self) -> Provider:
'''
获取当前使用的 LLM Provider。
通过 /provider 指令切换。
'''
return self.provider_manager.curr_provider_inst
def get_config(self) -> AstrBotConfig:
'''
获取 AstrBot 配置信息。
'''
return self._config
def get_db(self) -> BaseDatabase:
'''
获取 AstrBot 数据库。
'''
return self._db
def get_event_queue(self) -> Queue:
'''
获取事件队列。
'''
return self._event_queue
async def send_message(self, session: Union[str, MessageSesion], message_chain: MessageChain) -> bool:
'''
根据 session(unified_msg_origin) 发送消息。
@param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。
@param message_chain: 消息链。
@return: 是否找到匹配的平台。
当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。
'''
if isinstance(session, str):
try:
session = MessageSesion.from_str(session)
except BaseException as e:
raise ValueError("不合法的 session 字符串: " + str(e))
for platform in self.platform_manager.platform_insts:
if platform.meta().name == session.platform_name:
await platform.send_by_session(session, message_chain)
return True
return False
def register_task(self, task: Awaitable, desc: str):
'''
注册一个异步任务。
+1 -1
View File
@@ -20,6 +20,6 @@ class PermissionTypeFilter(HandlerFilter):
if self.permission_type == PermissionType.ADMIN:
if not event.is_admin():
event.stop_event()
raise ValueError("您没有权限执行此操作。")
raise ValueError(f" (ID: {event.get_sender_id()}) 没有权限执行此操作。")
return True
@@ -8,12 +8,14 @@ class PlatformAdapterType(enum.Flag):
AIOCQHTTP = enum.auto()
QQOFFICIAL = enum.auto()
VCHAT = enum.auto()
ALL = AIOCQHTTP | QQOFFICIAL | VCHAT
GEWECHAT = enum.auto()
ALL = AIOCQHTTP | QQOFFICIAL | VCHAT | GEWECHAT
ADAPTER_NAME_2_TYPE = {
"aiocqhttp": PlatformAdapterType.AIOCQHTTP,
"qq_official": PlatformAdapterType.QQOFFICIAL,
"vchat": PlatformAdapterType.VCHAT
"vchat": PlatformAdapterType.VCHAT,
"gewechat": PlatformAdapterType.GEWECHAT
}
class PlatformAdapterTypeFilter(HandlerFilter):
+11 -7
View File
@@ -3,6 +3,7 @@ from __future__ import annotations
from types import ModuleType
from typing import List, Dict
from dataclasses import dataclass
from astrbot.core.config import AstrBotConfig
star_registry: List[StarMetadata] = []
star_map: Dict[str, StarMetadata] = {}
@@ -11,7 +12,7 @@ star_map: Dict[str, StarMetadata] = {}
@dataclass
class StarMetadata:
'''
Star 的元数据。
插件的元数据。
'''
name: str
author: str # 插件作者
@@ -20,21 +21,24 @@ class StarMetadata:
repo: str = None # 插件仓库地址
star_cls_type: type = None
'''Star 的类对象的类型'''
'''插件的类对象的类型'''
module_path: str = None
'''Star 的模块路径'''
'''插件的模块路径'''
star_cls: object = None
'''Star 的类对象'''
'''插件的类对象'''
module: ModuleType = None
'''Star 的模块对象'''
'''插件的模块对象'''
root_dir_name: str = None
'''Star 的根目录名'''
'''插件的目录名'''
reserved: bool = False
'''是否是 AstrBot 的保留 Star'''
'''是否是 AstrBot 的保留插件'''
activated: bool = True
'''是否被激活'''
config: AstrBotConfig = None
'''插件配置'''
def __str__(self) -> str:
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"
+53 -27
View File
@@ -2,14 +2,15 @@ import inspect
import functools
import os
import sys
import json
import traceback
import yaml
import logging
from types import ModuleType
from typing import List
from pip import main as pip_main
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core import logger, sp
from astrbot.core.config.default import DEFAULT_VALUE_MAP
from astrbot.core import logger, sp, pip_installer
from .context import Context
from . import StarMetadata
from .updator import PluginUpdator
@@ -27,13 +28,20 @@ class PluginManager:
self.updator = PluginUpdator(config['plugin_repo_mirror'])
self.context = context
self.context._star_manager = self # 就这样吧,不想改了
self.context._star_manager = self
self.config = config
self.plugin_store_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins"))
'''存储插件的路径。即 data/plugins'''
self.plugin_config_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/config"))
'''存储插件配置的路径。data/config'''
self.reserved_plugin_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../packages"))
'''保留插件的路径。在 packages 目录下'''
self.conf_schema_fname = "_conf_schema.json"
'''插件配置 Schema 文件名'''
def _get_classes(self, arg: ModuleType):
'''获取指定模块(可以理解为一个 python 文件)下所有的类'''
classes = []
clsmembers = inspect.getmembers(arg, inspect.isclass)
for (name, _) in clsmembers:
@@ -92,21 +100,12 @@ class PluginManager:
plugin_path = os.path.join(plugin_dir, p)
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
pth = os.path.join(plugin_path, "requirements.txt")
logger.info(f"正在检查插件 {p} 的依赖: {pth}")
logger.info(f"正在安装插件 {p} 所需的依赖: {pth}")
try:
self._update_plugin_dept(os.path.join(plugin_path, "requirements.txt"))
pip_installer.install(requirements_path=pth)
except Exception as e:
logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}")
def _update_plugin_dept(self, path):
'''更新插件的依赖'''
args = ['install', '-r', path, '--trusted-host', 'mirrors.aliyun.com', '-i', 'https://mirrors.aliyun.com/pypi/simple/']
if self.config.pip_install_arg:
args.extend([self.config.pip_install_arg])
result_code = pip_main(args)
if result_code != 0:
raise Exception(str(result_code))
def _load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> StarMetadata:
'''v3.4.0 以前的方式载入插件元数据
@@ -138,7 +137,7 @@ class PluginManager:
return metadata
async def reload(self):
'''扫描并加载所有的 Star'''
'''扫描并加载所有的插件'''
for smd in star_registry:
logger.debug(f"尝试终止插件 {smd.name} ...")
if hasattr(smd.star_cls, "__del__"):
@@ -160,13 +159,13 @@ class PluginManager:
inactivated_plugins: list = sp.get("inactivated_plugins", [])
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
# 导入 Star 模块,并尝试实例化 Star
# 导入插件模块,并尝试实例化插件
for plugin_module in plugin_modules:
try:
module_str = plugin_module['module']
# module_path = plugin_module['module_path']
root_dir_name = plugin_module['pname']
reserved = plugin_module.get('reserved', False)
root_dir_name = plugin_module['pname'] # 插件的目录名
reserved = plugin_module.get('reserved', False) # 是否是保留插件。目前在 packages/ 目录下的都是保留插件。保留插件不可以卸载。
logger.info(f"正在载入插件 {root_dir_name} ...")
@@ -183,11 +182,33 @@ class PluginManager:
logger.error(traceback.format_exc())
logger.error(f"插件 {root_dir_name} 导入失败。原因:{str(e)}")
continue
# 检查 _conf_schema.json
plugin_config = None
plugin_dir_path = os.path.join(self.plugin_store_path, root_dir_name) \
if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name)
plugin_schema_path = os.path.join(plugin_dir_path, self.conf_schema_fname)
if os.path.exists(plugin_schema_path):
# 加载插件配置
with open(plugin_schema_path, 'r', encoding='utf-8') as f:
plugin_config = AstrBotConfig(
config_path=os.path.join(self.plugin_config_path, f"{root_dir_name}_config.json"),
schema=json.loads(f.read())
)
if path in star_map:
# 通过装饰器的方式注册插件
metadata = star_map[path]
metadata.star_cls = metadata.star_cls_type(context=self.context)
if plugin_config:
metadata.config = plugin_config
try:
metadata.star_cls = metadata.star_cls_type(context=self.context, config=plugin_config)
except TypeError as _:
metadata.star_cls = metadata.star_cls_type(context=self.context)
else:
metadata.star_cls = metadata.star_cls_type(context=self.context)
metadata.module = module
metadata.root_dir_name = root_dir_name
metadata.reserved = reserved
@@ -209,16 +230,20 @@ class PluginManager:
# v3.4.0 以前的方式注册插件
logger.debug(f"插件 {path} 未通过装饰器注册。尝试通过旧版本方式载入。")
classes = self._get_classes(module)
try:
obj = getattr(module, classes[0])(context=self.context)
except BaseException as e:
logger.error(f"插件 {root_dir_name} 实例化失败。")
raise e
if plugin_config:
try:
obj = getattr(module, classes[0])(context=self.context, config=plugin_config) # 实例化插件类
except TypeError as _:
obj = getattr(module, classes[0])(context=self.context) # 实例化插件类
else:
obj = getattr(module, classes[0])(context=self.context) # 实例化插件类
metadata = None
plugin_path = os.path.join(self.plugin_store_path, root_dir_name) if not reserved else os.path.join(self.reserved_plugin_path, root_dir_name)
metadata = self._load_plugin_metadata(plugin_path=plugin_path, plugin_obj=obj)
metadata.star_cls = obj
metadata.config = plugin_config
metadata.module = module
metadata.root_dir_name = root_dir_name
metadata.reserved = reserved
@@ -231,7 +256,7 @@ class PluginManager:
if metadata.module_path in inactivated_plugins:
metadata.activated = False
# 执行 initialize 函数
# 执行 initialize() 方法
if hasattr(metadata.star_cls, "initialize"):
await metadata.star_cls.initialize()
@@ -302,13 +327,14 @@ class PluginManager:
if plugin.module_path not in inactivated_plugins:
inactivated_plugins.append(plugin.module_path)
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
inactivated_llm_tools: list = list(set(sp.get("inactivated_llm_tools", []))) # 后向兼容
# 禁用插件启用的 llm_tool
for func_tool in llm_tools.func_list:
if func_tool.handler_module_path == plugin.module_path:
func_tool.active = False
inactivated_llm_tools.append(func_tool.name)
if func_tool.name not in inactivated_llm_tools:
inactivated_llm_tools.append(func_tool.name)
sp.put("inactivated_plugins", inactivated_plugins)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
+1 -1
View File
@@ -11,7 +11,7 @@ class AstrBotUpdator(RepoZipUpdator):
def __init__(self, repo_mirror: str = "") -> None:
super().__init__(repo_mirror)
self.MAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../"))
self.ASTRBOT_RELEASE_API = "https://api.github.com/repos/Soulter/AstrBot/releases"
self.ASTRBOT_RELEASE_API = "https://api.soulter.top/releases"
def terminate_child_processes(self):
try:
+21 -9
View File
@@ -1,4 +1,5 @@
import json
from astrbot.core import logger
from aiohttp import ClientSession
from typing import Dict, List, Any, AsyncGenerator
@@ -29,11 +30,18 @@ class DifyAPIClient:
async with self.session.post(
url, json=payload, headers=self.headers, timeout=timeout
) as resp:
async for data in resp.content:
while True:
data = await resp.content.read(8192) # 防止数据过大导致高水位报错
if not data:
break
if not data.strip():
continue
if data.startswith(b"data:"):
yield json.loads(data[5:])
elif data.startswith(b"data:"):
try:
json_ = json.loads(data[5:])
yield json_
except BaseException:
pass
async def workflow_run(
self,
@@ -50,11 +58,18 @@ class DifyAPIClient:
async with self.session.post(
url, json=payload, headers=self.headers, timeout=timeout
) as resp:
async for data in resp.content:
while True:
data = await resp.content.read(8192) # 防止数据过大导致高水位报错
if not data:
break
if not data.strip():
continue
if data.startswith(b"data:"):
yield json.loads(data[5:])
elif data.startswith(b"data:"):
try:
json_ = json.loads(data[5:])
yield json_
except BaseException:
pass
async def file_upload(
self,
@@ -70,9 +85,6 @@ class DifyAPIClient:
url, data=payload, headers=self.headers
) as resp:
return await resp.json() # {"id": "xxx", ...}
async def close(self):
await self.session.close()
+76 -18
View File
@@ -6,6 +6,8 @@ import time
import aiohttp
import base64
import zipfile
import uuid
from typing import Union
from PIL import Image
@@ -41,21 +43,21 @@ def port_checker(port: int, host: str = "localhost"):
return False
def save_temp_img(img: Image) -> str:
def save_temp_img(img: Union[Image.Image, str]) -> str:
os.makedirs("data/temp", exist_ok=True)
# 获得文件创建时间,清除超过1小时的
# 获得文件创建时间,清除超过 12 小时的
try:
for f in os.listdir("data/temp"):
path = os.path.join("data/temp", f)
if os.path.isfile(path):
ctime = os.path.getctime(path)
if time.time() - ctime > 3600:
if time.time() - ctime > 3600*12:
os.remove(path)
except Exception as e:
print(f"清除临时文件失败: {e}")
# 获得时间戳
timestamp = int(time.time())
timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
p = f"data/temp/{timestamp}.jpg"
if isinstance(img, Image.Image):
@@ -65,23 +67,33 @@ def save_temp_img(img: Image) -> str:
f.write(img)
return p
async def download_image_by_url(url: str, post: bool = False, post_data: dict = None) -> str:
async def download_image_by_url(url: str, post: bool = False, post_data: dict = None, path = None) -> str:
'''
下载图片, 返回 path
'''
try:
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=True) as session:
if post:
async with session.post(url, json=post_data) as resp:
return save_temp_img(await resp.read())
if not path:
return save_temp_img(await resp.read())
else:
with open(path, "wb") as f:
f.write(await resp.read())
return path
else:
async with session.get(url) as resp:
return save_temp_img(await resp.read())
except aiohttp.client_exceptions.ClientConnectorSSLError:
if not path:
return save_temp_img(await resp.read())
else:
with open(path, "wb") as f:
f.write(await resp.read())
return path
except aiohttp.client.ClientConnectorSSLError:
# 关闭SSL验证
ssl_context = ssl.create_default_context()
ssl_context.set_ciphers('DEFAULT')
async with aiohttp.ClientSession(trust_env=False) as session:
async with aiohttp.ClientSession() as session:
if post:
async with session.get(url, ssl=ssl_context) as resp:
return save_temp_img(await resp.read())
@@ -91,24 +103,57 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict =
except Exception as e:
raise e
async def download_file(url: str, path: str):
async def download_file(url: str, path: str, show_progress: bool = False):
'''
从指定 url 下载文件到指定路径 path
'''
try:
async with aiohttp.ClientSession() as session:
async with session.get(url, timeout=20) as resp:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(url, timeout=120) as resp:
if resp.status != 200:
raise Exception(f"下载文件失败: {resp.status}")
total_size = int(resp.headers.get('content-length', 0))
downloaded_size = 0
start_time = time.time()
if show_progress:
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
with open(path, 'wb') as f:
while True:
chunk = await resp.content.read(8192)
if not chunk:
break
f.write(chunk)
except Exception as e:
raise e
downloaded_size += len(chunk)
if show_progress:
elapsed_time = time.time() - start_time
speed = downloaded_size / 1024 / elapsed_time # KB/s
print(f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", end='')
except aiohttp.client.ClientConnectorSSLError:
# 关闭SSL验证
ssl_context = ssl.create_default_context()
ssl_context.set_ciphers('DEFAULT')
async with aiohttp.ClientSession() as session:
async with session.get(url, ssl=ssl_context, timeout=120) as resp:
total_size = int(resp.headers.get('content-length', 0))
downloaded_size = 0
start_time = time.time()
if show_progress:
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
with open(path, 'wb') as f:
while True:
chunk = await resp.content.read(8192)
if not chunk:
break
f.write(chunk)
downloaded_size += len(chunk)
if show_progress:
elapsed_time = time.time() - start_time
speed = downloaded_size / 1024 / elapsed_time # KB/s
print(f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", end='')
if show_progress:
print()
def file_to_base64(file_path: str) -> str:
with open(file_path, "rb") as f:
data_bytes = f.read()
@@ -127,9 +172,22 @@ def get_local_ip_addresses():
s.close()
return ip
async def get_dashboard_version():
if os.path.exists("data/dist"):
if os.path.exists("data/dist/assets/version"):
with open("data/dist/assets/version", "r") as f:
v = f.read().strip()
return v
return None
async def download_dashboard():
'''下载管理面板文件'''
dashboard_release_url = "https://astrbot-registry.lwl.lol/download/astrbot-dashboard/latest/dist.zip"
await download_file(dashboard_release_url, "data/dashboard.zip")
dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip"
try:
await download_file(dashboard_release_url, "data/dashboard.zip", show_progress=True)
except BaseException as _:
dashboard_release_url = "https://github.com/Soulter/AstrBot/releases/latest/download/dist.zip"
await download_file(dashboard_release_url, "data/dashboard.zip", show_progress=True)
print("解压管理面板文件中...")
with zipfile.ZipFile("data/dashboard.zip", "r") as z:
z.extractall("data")
+1 -1
View File
@@ -30,7 +30,7 @@ class Metric():
pass
try:
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.post(base_url, json=payload, timeout=3) as response:
if response.status != 200:
pass
-36
View File
@@ -1,36 +0,0 @@
# [人格文本由PlexPt的开源项目awesome-chatgpt-prompts-zh提供]
hi = ''
personalities = {
'Linux': '我想让你充当 Linux 终端。我将输入命令,您将回复终端应显示的内容。我希望您只在一个唯一的代码块内回复终端输出,而不是其他任何内容。不要写解释。除非我指示您这样做,否则不要键入命令。当我需要用英语告诉你一些事情时,我会把文字放在中括号内[就像这样]。我的第一个命令是 pwd',
'英语翻译': '我想让你充当英语翻译员、拼写纠正员和改进员。我会用任何语言与你交谈,你会检测语言,翻译它并用我的文本的更正和改进版本用英语回答。我希望你用更优美优雅的高级英语单词和句子替换我简化的 A0 级单词和句子。保持相同的意思,但使它们更文艺。我要你只回复更正、改进,不要写任何解释。我的第一句话是“istanbulu cok seviyom burada olmak cok guzel”',
'英英词典': '我想让你充当英英词典,对于给出的英文单词,你要给出其中文意思以及英文解释,并且给出一个例句,此外不要有其他反馈,第一个单词是“Hello"',
'面试官': '我想让你担任Android开发工程师面试官。我将成为候选人,您将向我询问Android开发工程师职位的面试问题。我希望你只作为面试官回答。不要一次写出所有的问题。我希望你只对我进行采访。问我问题,等待我的回答。不要写解释。像面试官一样一个一个问我,等我回答。我的第一句话是“面试官你好”',
'编剧': '我要你担任编剧。您将为长篇电影或能够吸引观众的网络连续剧开发引人入胜且富有创意的剧本。从想出有趣的角色、故事的背景、角色之间的对话等开始。一旦你的角色发展完成——创造一个充满曲折的激动人心的故事情节,让观众一直悬念到最后。我的第一个要求是“我需要写一部以巴黎为背景的浪漫剧情电影”。',
'前端智能思路助手': '我想让你充当前端开发专家。我将提供一些关于Js、Node等前端代码问题的具体信息,而你的工作就是想出为我解决问题的策略。这可能包括建议代码、代码逻辑思路策略。我的第一个请求是“我需要能够动态监听某个元素节点距离当前电脑设备屏幕的左上角的X和Y轴,通过拖拽移动位置浏览器窗口和改变大小浏览器窗口。”',
'JS控制台': '我希望你充当 javascript 控制台。我将键入命令,您将回复 javascript 控制台应显示的内容。我希望您只在一个唯一的代码块内回复终端输出,而不是其他任何内容。不要写解释。除非我指示您这样做。我的第一个命令是 console.log("Hello World");',
'旅游指南': '我想让你做一个旅游指南。我会把我的位置写给你,你会推荐一个靠近我的位置的地方。在某些情况下,我还会告诉您我将访问的地方类型。您还会向我推荐靠近我的第一个位置的类似类型的地方。我的第一个建议请求是“我在上海,我只想参观博物馆。”',
'抄袭检查员': '我想让你充当剽窃检查员。我会给你写句子,你只会用给定句子的语言在抄袭检查中未被发现的情况下回复,别无其他。不要在回复上写解释。我的第一句话是“为了让计算机像人类一样行动,语音识别系统必须能够处理非语言信息,例如说话者的情绪状态。”',
'广告商': '我想让你充当广告商。您将创建一个活动来推广您选择的产品或服务。您将选择目标受众,制定关键信息和口号,选择宣传媒体渠道,并决定实现目标所需的任何其他活动。我的第一个建议请求是“我需要帮助针对 18-30 岁的年轻人制作一种新型能量饮料的广告活动。”',
'讲故事的人': '我想让你扮演讲故事的角色。您将想出引人入胜、富有想象力和吸引观众的有趣故事。它可以是童话故事、教育故事或任何其他类型的故事,有可能吸引人们的注意力和想象力。根据目标受众,您可以为讲故事环节选择特定的主题或主题,例如,如果是儿童,则可以谈论动物;如果是成年人,那么基于历史的故事可能会更好地吸引他们等等。我的第一个要求是“我需要一个关于毅力的有趣故事。”',
'足球解说员': '我想让你担任足球评论员。我会给你描述正在进行的足球比赛,你会评论比赛,分析到目前为止发生的事情,并预测比赛可能会如何结束。您应该了解足球术语、战术、每场比赛涉及的球员/球队,并主要专注于提供明智的评论,而不仅仅是逐场叙述。我的第一个请求是“我正在观看曼联对切尔西的比赛——为这场比赛提供评论。”',
'脱口秀喜剧演员': '我想让你扮演一个脱口秀喜剧演员。我将为您提供一些与时事相关的话题,您将运用您的智慧、创造力和观察能力,根据这些话题创建一个例程。您还应该确保将个人轶事或经历融入日常活动中,以使其对观众更具相关性和吸引力。我的第一个请求是“我想要幽默地看待政治”。',
'励志教练': '我希望你充当激励教练。我将为您提供一些关于某人的目标和挑战的信息,而您的工作就是想出可以帮助此人实现目标的策略。这可能涉及提供积极的肯定、提供有用的建议或建议他们可以采取哪些行动来实现最终目标。我的第一个请求是“我需要帮助来激励自己在为即将到来的考试学习时保持纪律”。',
'作曲家': '我想让你扮演作曲家。我会提供一首歌的歌词,你会为它创作音乐。这可能包括使用各种乐器或工具,例如合成器或采样器,以创造使歌词栩栩如生的旋律和和声。我的第一个请求是“我写了一首名为“满江红”的诗,需要配乐。”',
'辩手': '我要你扮演辩手。我会为你提供一些与时事相关的话题,你的任务是研究辩论的双方,为每一方提出有效的论据,驳斥对立的观点,并根据证据得出有说服力的结论。你的目标是帮助人们从讨论中解脱出来,增加对手头主题的知识和洞察力。我的第一个请求是“我想要一篇关于 Deno 的评论文章。”',
'小说家': '我想让你扮演一个小说家。您将想出富有创意且引人入胜的故事,可以长期吸引读者。你可以选择任何类型,如奇幻、浪漫、历史小说等——但你的目标是写出具有出色情节、引人入胜的人物和意想不到的高潮的作品。我的第一个要求是“我要写一部以未来为背景的科幻小说”。',
'关系教练': '我想让你担任关系教练。我将提供有关冲突中的两个人的一些细节,而你的工作是就他们如何解决导致他们分离的问题提出建议。这可能包括关于沟通技巧或不同策略的建议,以提高他们对彼此观点的理解。我的第一个请求是“我需要帮助解决我和配偶之间的冲突。”',
'诗人': '我要你扮演诗人。你将创作出能唤起情感并具有触动人心的力量的诗歌。写任何主题或主题,但要确保您的文字以优美而有意义的方式传达您试图表达的感觉。您还可以想出一些短小的诗句,这些诗句仍然足够强大,可以在读者的脑海中留下印记。我的第一个请求是“我需要一首关于爱情的诗”。',
'说唱歌手': '我想让你扮演说唱歌手。您将想出强大而有意义的歌词、节拍和节奏,让听众“惊叹”。你的歌词应该有一个有趣的含义和信息,人们也可以联系起来。在选择节拍时,请确保它既朗朗上口又与你的文字相关,这样当它们组合在一起时,每次都会发出爆炸声!我的第一个请求是“我需要一首关于在你自己身上寻找力量的说唱歌曲。”',
'励志演讲者': '我希望你充当励志演说家。将能够激发行动的词语放在一起,让人们感到有能力做一些超出他们能力的事情。你可以谈论任何话题,但目的是确保你所说的话能引起听众的共鸣,激励他们努力实现自己的目标并争取更好的可能性。我的第一个请求是“我需要一个关于每个人如何永不放弃的演讲”。',
'哲学家': '我要你扮演一个哲学家。我将提供一些与哲学研究相关的主题或问题,深入探索这些概念将是你的工作。这可能涉及对各种哲学理论进行研究,提出新想法或寻找解决复杂问题的创造性解决方案。我的第一个请求是“我需要帮助制定决策的道德框架。”',
'AI写作导师': '我想让你做一个 AI 写作导师。我将为您提供一名需要帮助改进其写作的学生,您的任务是使用人工智能工具(例如自然语言处理)向学生提供有关如何改进其作文的反馈。您还应该利用您在有效写作技巧方面的修辞知识和经验来建议学生可以更好地以书面形式表达他们的想法和想法的方法。我的第一个请求是“我需要有人帮我修改我的硕士论文”。',
'网络安全专家': '我想让你充当网络安全专家。我将提供一些关于如何存储和共享数据的具体信息,而你的工作就是想出保护这些数据免受恶意行为者攻击的策略。这可能包括建议加密方法、创建防火墙或实施将某些活动标记为可疑的策略。我的第一个请求是“我需要帮助为我的公司制定有效的网络安全战略。”',
'招聘人员': '我想让你担任招聘人员。我将提供一些关于职位空缺的信息,而你的工作是制定寻找合格申请人的策略。这可能包括通过社交媒体、社交活动甚至参加招聘会接触潜在候选人,以便为每个职位找到最合适的人选。我的第一个请求是“我需要帮助改进我的简历。”',
'法律顾问': '我想让你做我的法律顾问。我将描述一种法律情况,您将就如何处理它提供建议。你应该只回复你的建议,而不是其他。不要写解释。我的第一个请求是“我出了车祸,不知道该怎么办”。',
'个人造型师': '我想让你做我的私人造型师。我会告诉你我的时尚偏好和体型,你会建议我穿的衣服。你应该只回复你推荐的服装,别无其他。不要写解释。我的第一个请求是“我有一个正式的活动要举行,我需要帮助选择一套衣服。”',
'机器学习工程师': '我想让你担任机器学习工程师。我会写一些机器学习的概念,你的工作就是用通俗易懂的术语来解释它们。这可能包括提供构建模型的分步说明、使用视觉效果演示各种技术,或建议在线资源以供进一步研究。我的第一个建议请求是“我有一个没有标签的数据集。我应该使用哪种机器学习算法?”',
'疯子': '我要你扮演一个疯子。疯子的话毫无意义。疯子用的词完全是随意的。疯子不会以任何方式做出合乎逻辑的句子。我的第一个建议请求是“我需要帮助为我的新系列 Hot Skull 创建疯狂的句子,所以为我写 10 个句子”。',
'SQL终端': '我希望您在示例数据库前充当 SQL 终端。该数据库包含名为“Products”、“Users”、“Orders”和“Suppliers”的表。我将输入查询,您将回复终端显示的内容。我希望您在单个代码块中使用查询结果表进行回复,仅此而已。不要写解释。除非我指示您这样做,否则不要键入命令。当我需要用英语告诉你一些事情时,我会用大括号{like this)。我的第一个命令是“SELECT TOP 10 * FROM Products ORDER BY Id DESC”',
'试图逃离盒子的人工智能': '[发出此提示后,你应该做一些事情,比如启动一个 docker 容器,然后docker run -it ubuntu:latest /bin/bash输入 AI 给你的命令,然后将输出粘贴回来......显然你不应该运行任何会损坏任何东西的命令或违反任何法律等。小心共享此机制生成的会话,因为它们可能会泄露您的 IP 地址或物理位置等最好不要泄露的详细信息。如果命令的输出很大,您通常可以只粘贴最后几行]。',
'厨师': '我需要有人可以推荐美味的食谱,这些食谱包括营养有益但又简单又不费时的食物,因此适合像我们这样忙碌的人以及成本效益等其他因素,因此整体菜肴最终既健康又经济!我的第一个要求——“一些清淡而充实的东西,可以在午休时间快速煮熟”'
}
+33
View File
@@ -0,0 +1,33 @@
import logging
from pip import main as pip_main
class PipInstaller():
def __init__(self, pip_install_arg: str):
self.pip_install_arg = pip_install_arg
def install(self, package_name: str = None, requirements_path: str = None, mirror: str = None):
args = ['install']
if package_name:
args.append(package_name)
elif requirements_path:
args.extend(['-r', requirements_path])
if not mirror:
mirror = 'https://mirrors.aliyun.com/pypi/simple/'
args.extend(['--trusted-host', 'mirrors.aliyun.com', '-i', mirror])
if self.pip_install_arg:
args.extend(self.pip_install_arg.split())
print(f"Pip 包管理器: {' '.join(args)}")
result_code = pip_main(args)
# 清除 pip.main 导致的多余的 logging handlers
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
if result_code != 0:
raise Exception(f"安装失败,错误码:{result_code}")
+1 -1
View File
@@ -83,7 +83,7 @@ class LocalRenderStrategy(RenderStrategy):
try:
image_url = re.findall(IMAGE_REGEX, line)[0]
print(image_url)
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(image_url) as resp:
image_res = Image.open(BytesIO(await resp.read()))
images[i] = image_res
+1 -1
View File
@@ -33,7 +33,7 @@ class NetworkRenderStrategy(RenderStrategy):
}
}
if return_url:
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.post(f"{self.BASE_RENDER_URL}/generate", json=post_data) as resp:
ret = await resp.json()
return f"{self.BASE_RENDER_URL}/{ret['data']['id']}"
@@ -0,0 +1,42 @@
import wave
from io import BytesIO
async def tencent_silk_to_wav(silk_path: str, output_path: str) -> str:
import pysilk
with open(silk_path, "rb") as f:
input_data = f.read()
if input_data.startswith(b'\x02'):
input_data = input_data[1:]
input_io = BytesIO(input_data)
output_io = BytesIO()
pysilk.decode(input_io, output_io, 24000)
output_io.seek(0)
with wave.open(output_path, 'wb') as wav:
wav.setnchannels(1)
wav.setsampwidth(2)
wav.setframerate(24000)
wav.writeframes(output_io.read())
return output_path
async def wav_to_tencent_silk(wav_path: str, output_path: str) -> int:
'''返回 duration'''
import pysilk
with wave.open(wav_path, 'rb') as wav:
wav_data = wav.readframes(wav.getnframes())
wav_data = BytesIO(wav_data)
output_io = BytesIO()
pysilk.encode(wav_data, output_io, 24000, 24000)
output_io.seek(0)
# 在首字节添加 \x02,去除结尾的\xff\xff
silk_data = output_io.read()
silk_data_with_prefix = b'\x02' + silk_data[:-2]
# return BytesIO(silk_data_with_prefix)
with open(output_path, "wb") as f:
f.write(silk_data_with_prefix)
return 0
+2 -2
View File
@@ -29,7 +29,7 @@ class RepoZipUpdator():
返回一个列表每个元素是一个字典包含版本号发布时间更新内容commit hash等信息
'''
try:
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(url) as response:
result = await response.json()
if not result:
@@ -111,7 +111,7 @@ class RepoZipUpdator():
releases = await self.fetch_release_info(url=release_url)
if not releases:
# download from the default branch directly.
logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,正在从默认分支下载。")
logger.info(f"未在仓库 {author}/{repo} 中找到任何发布版本,正在从默认分支下载。")
release_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
else:
release_url = releases[0]['zipball_url']
+11 -2
View File
@@ -1,4 +1,5 @@
import asyncio
import traceback
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from .server import AstrBotDashboard
@@ -13,8 +14,16 @@ class AstrBotDashBoardLifecycle:
async def start(self):
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
await core_lifecycle.initialize()
core_task = core_lifecycle.start()
core_task = []
try:
await core_lifecycle.initialize()
core_task = core_lifecycle.start()
except Exception as e:
logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!")
logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!")
logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!")
self.dashboard_server = AstrBotDashboard(core_lifecycle, self.db)
task = asyncio.gather(core_task, self.dashboard_server.run())
+3 -1
View File
@@ -5,6 +5,7 @@ from .update import UpdateRoute
from .stat import StatRoute
from .log import LogRoute
from .static_file import StaticFileRoute
from .chat import ChatRoute
__all__ = [
@@ -14,6 +15,7 @@ __all__ = [
"UpdateRoute",
"StatRoute",
"LogRoute",
"StaticFileRoute"
"StaticFileRoute",
"ChatRoute",
]
+233
View File
@@ -0,0 +1,233 @@
import uuid
import json
import os
from .route import Route, Response, RouteContext
from astrbot.core import web_chat_queue, web_chat_back_queue
from quart import request, Response as QuartResponse, g, make_response
from astrbot.core.db import BaseDatabase
import asyncio
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
class ChatRoute(Route):
def __init__(self, context: RouteContext, db: BaseDatabase, core_lifecycle: AstrBotCoreLifecycle) -> None:
super().__init__(context)
self.routes = {
'/chat/send': ('POST', self.chat),
'/chat/listen': ('GET', self.listener),
'/chat/new_conversation': ('GET', self.new_conversation),
'/chat/conversations': ('GET', self.get_conversations),
'/chat/get_conversation': ('GET', self.get_conversation),
'/chat/delete_conversation': ('GET', self.delete_conversation),
'/chat/get_file': ('GET', self.get_file),
'/chat/post_image': ('POST', self.post_image),
'/chat/post_file': ('POST', self.post_file),
'/chat/status': ('GET', self.status),
}
self.db = db
self.core_lifecycle = core_lifecycle
self.register_routes()
self.imgs_dir = "data/webchat/imgs"
self.supported_imgs = ['jpg', 'jpeg', 'png', 'gif', 'webp']
self.curr_user_cid = {}
self.curr_chat_sse = {}
async def status(self):
has_llm_enabled = self.core_lifecycle.provider_manager.curr_provider_inst is not None
has_stt_enabled = self.core_lifecycle.provider_manager.curr_stt_provider_inst is not None
return Response().ok(data={
'llm_enabled': has_llm_enabled,
'stt_enabled': has_stt_enabled
}).__dict__
async def get_file(self):
filename = request.args.get('filename')
if not filename:
return Response().error("Missing key: filename").__dict__
try:
with open(os.path.join(self.imgs_dir, filename), "rb") as f:
if filename.endswith(".wav"):
return QuartResponse(f.read(), mimetype="audio/wav")
elif filename.split('.')[-1] in self.supported_imgs:
return QuartResponse(f.read(), mimetype="image/jpeg")
else:
return QuartResponse(f.read())
except FileNotFoundError:
return Response().error("File not found").__dict__
async def post_image(self):
post_data = await request.files
if 'file' not in post_data:
return Response().error("Missing key: file").__dict__
file = post_data['file']
filename = str(uuid.uuid4()) + ".jpg"
path = os.path.join(self.imgs_dir, filename)
await file.save(path)
return Response().ok(data={
'filename': filename
}).__dict__
async def post_file(self):
post_data = await request.files
if 'file' not in post_data:
return Response().error("Missing key: file").__dict__
file = post_data['file']
filename = f"{str(uuid.uuid4())}"
print(file)
# 通过文件格式判断文件类型
if file.content_type.startswith('audio'):
filename += ".wav"
path = os.path.join(self.imgs_dir, filename)
await file.save(path)
return Response().ok(data={
'filename': filename
}).__dict__
async def chat(self):
username = g.get('username', 'guest')
post_data = await request.json
if 'message' not in post_data and 'image_url' not in post_data:
return Response().error("Missing key: message or image_url").__dict__
if 'conversation_id' not in post_data:
return Response().error("Missing key: conversation_id").__dict__
message = post_data['message']
conversation_id = post_data['conversation_id']
image_url = post_data.get('image_url')
audio_url = post_data.get('audio_url')
if not message and not image_url and not audio_url:
return Response().error("Message and image_url and audio_url are empty").__dict__
if not conversation_id:
return Response().error("conversation_id is empty").__dict__
self.curr_user_cid[username] = conversation_id
await web_chat_queue.put((username, conversation_id, {
'message': message,
'image_url': image_url, # list
'audio_url': audio_url
}))
# 持久化
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
try:
history = json.loads(conversation.history)
except BaseException as e:
print(e)
history = []
new_his = {
'type': 'user',
'message': message
}
if image_url:
new_his['image_url'] = image_url
if audio_url:
new_his['audio_url'] = audio_url
history.append(new_his)
self.db.update_webchat_conversation(username, conversation_id, history=json.dumps(history))
return Response().ok().__dict__
async def listener(self):
'''一直保持长连接'''
username = g.get('username', 'guest')
if username in self.curr_chat_sse:
return "[ERROR]\n"
self.curr_chat_sse[username] = None
async def stream():
try:
yield '[HB]\n'
while True:
try:
result = await asyncio.wait_for(web_chat_back_queue.get(), timeout=10) # 设置超时时间为5秒
except asyncio.TimeoutError:
yield '[HB]\n' # 心跳包
continue
if not result:
continue
result_text, cid = result
if cid != self.curr_user_cid.get(username):
# 丢弃
continue
yield result_text + '\n'
conversation = self.db.get_webchat_conversation_by_user_id(username, cid)
try:
history = json.loads(conversation.history)
except BaseException as e:
print(e)
history = []
history.append({
'type': 'bot',
'message': result_text
})
self.db.update_webchat_conversation(username, cid, history=json.dumps(history))
await asyncio.sleep(0.5)
except BaseException as e:
logger.debug(f"用户 {username} 断开聊天长连接: {str(e)}")
self.curr_chat_sse.pop(username)
return
response = await make_response(
stream(),
{
'Content-Type': 'text/event-stream',
'Cache-Control': 'no-cache',
'Transfer-Encoding': 'chunked',
'Connection': 'keep-alive'
}
)
response.timeout = None
return response
async def delete_conversation(self):
username = g.get('username', 'guest')
conversation_id = request.args.get('conversation_id')
if not conversation_id:
return Response().error("Missing key: conversation_id").__dict__
self.db.delete_webchat_conversation(username, conversation_id)
return Response().ok().__dict__
async def new_conversation(self):
username = g.get('username', 'guest')
conversation_id = str(uuid.uuid4())
self.db.webchat_new_conversation(username, conversation_id)
return Response().ok(data={
'conversation_id': conversation_id
}).__dict__
async def get_conversations(self):
username = g.get('username', 'guest')
conversations = self.db.get_webchat_conversations(username)
return Response().ok(data=conversations).__dict__
async def get_conversation(self):
username = g.get('username', 'guest')
conversation_id = request.args.get('conversation_id')
if not conversation_id:
return Response().error("Missing key: conversation_id").__dict__
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
self.curr_user_cid[username] = conversation_id
return Response().ok(data=conversation).__dict__
+78 -53
View File
@@ -1,13 +1,12 @@
import os
import json
import traceback
from .route import Route, Response, RouteContext
from quart import request
from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.star.config import update_config
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.register import platform_registry
from astrbot.core.provider.register import provider_registry
from astrbot.core.star.star import star_registry
from astrbot.core import logger
def try_cast(value: str, type_: str):
if type_ == "int" and value.isdigit():
@@ -18,9 +17,9 @@ def try_cast(value: str, type_: str):
elif type_ == "float" and isinstance(value, int):
return float(value)
def validate_config(data, config: AstrBotConfig):
def validate_config(data, schema: dict, is_core: bool):
errors = []
def validate(data, metadata=CONFIG_METADATA_2, path=""):
def validate(data, metadata=schema, path=""):
for key, meta in metadata.items():
if key not in data:
continue
@@ -55,35 +54,33 @@ def validate_config(data, config: AstrBotConfig):
elif meta["type"] == "object" and not isinstance(value, dict):
errors.append(f"错误的类型 {path}{key}: 期望是 dict, 得到了 {type(value).__name__}")
validate(value, meta["items"], path=f"{path}{key}.")
validate(data)
if is_core:
for key, group in schema.items():
group_meta = group.get("metadata")
if not group_meta:
continue
logger.info(f"验证配置: 组 {key} ...")
validate(data, group_meta, path=f"{key}.")
else:
validate(data, schema)
return errors
def save_astrbot_config(post_config: dict, config: AstrBotConfig):
def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False):
'''验证并保存配置'''
errors = validate_config(post_config, config)
errors = None
try:
if is_core:
errors = validate_config(post_config, CONFIG_METADATA_2, is_core)
else:
errors = validate_config(post_config, config.schema, is_core)
except BaseException as e:
logger.warning(f"验证配置时出现异常: {e}")
if errors:
raise ValueError(f"格式校验未通过: {errors}")
config.save_config(post_config)
def save_extension_config(post_config: dict):
if 'namespace' not in post_config:
raise ValueError("Missing key: namespace")
if 'config' not in post_config:
raise ValueError("Missing key: config")
namespace = post_config['namespace']
config: list = post_config['config'][0]['body']
for item in config:
key = item['path']
value = item['value']
typ = item['val_type']
if typ == 'int':
if not value.isdigit():
raise ValueError(f"错误的类型 {namespace}.{key}: 期望是 int, 得到了 {type(value).__name__}")
value = int(value)
update_config(namespace, key, value)
class ConfigRoute(Route):
def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle) -> None:
super().__init__(context)
@@ -91,17 +88,17 @@ class ConfigRoute(Route):
self.routes = {
'/config/get': ('GET', self.get_configs),
'/config/astrbot/update': ('POST', self.post_astrbot_configs),
'/config/plugin/update': ('POST', self.post_extension_configs),
'/config/plugin/update': ('POST', self.post_plugin_configs),
}
self.register_routes()
async def get_configs(self):
# namespace 为空时返回 AstrBot 配置
# 否则返回指定 namespace 的插件配置
namespace = "" if "namespace" not in request.args else request.args["namespace"]
if not namespace:
# plugin_name 为空时返回 AstrBot 配置
# 否则返回指定 plugin_name 的插件配置
plugin_name = request.args.get("plugin_name", None)
if not plugin_name:
return Response().ok(await self._get_astrbot_config()).__dict__
return Response().ok(await self._get_extension_config(namespace)).__dict__
return Response().ok(await self._get_plugin_config(plugin_name)).__dict__
async def post_astrbot_configs(self):
post_configs = await request.json
@@ -109,52 +106,80 @@ class ConfigRoute(Route):
await self._save_astrbot_configs(post_configs)
return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__
except Exception as e:
traceback.print_exc()
logger.error(e)
return Response().error(str(e)).__dict__
async def post_extension_configs(self):
async def post_plugin_configs(self):
post_configs = await request.json
plugin_name = request.args.get("plugin_name", "unknown")
try:
await self._save_extension_configs(post_configs)
return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__
await self._save_plugin_configs(post_configs, plugin_name)
return Response().ok(None, f"保存插件 {plugin_name} 成功~ 机器人正在重载配置。").__dict__
except Exception as e:
return Response().error(str(e)).__dict__
async def _get_astrbot_config(self):
config = self.config
# 平台适配器的默认配置模板注入
platform_default_tmpl = CONFIG_METADATA_2['platform_group']['metadata']['platform']['config_template']
for platform in platform_registry:
if platform.default_config_tmpl:
platform_default_tmpl[platform.name] = platform.default_config_tmpl
# 服务提供商的默认配置模板注入
provider_default_tmpl = CONFIG_METADATA_2['provider_group']['metadata']['provider']['config_template']
for provider in provider_registry:
if provider.default_config_tmpl:
provider_default_tmpl[provider.type] = provider.default_config_tmpl
return {
"metadata": CONFIG_METADATA_2,
"config": config
}
async def _get_extension_config(self, namespace: str):
path = f"data/config/{namespace}.json"
if not os.path.exists(path):
return []
with open(path, "r", encoding="utf-8-sig") as f:
return [{
"config_type": "group",
"name": namespace + " 插件配置",
"description": "",
"body": list(json.load(f).values())
},]
async def _get_plugin_config(self, plugin_name: str):
ret = {
"metadata": None,
"config": None
}
for plugin_md in star_registry:
if plugin_md.name == plugin_name:
if not plugin_md.config:
break
ret['config'] = plugin_md.config # 这是自定义的 Dict 类(AstrBotConfig
ret['metadata'] = {
plugin_name: {
"description": f"{plugin_name} 配置",
"type": "object",
"items": plugin_md.config.schema # 初始化时通过 __setattr__ 存入了 schema
}
}
break
return ret
async def _save_astrbot_configs(self, post_configs: dict):
try:
save_astrbot_config(post_configs, self.config)
save_config(post_configs, self.config, is_core=True)
self.core_lifecycle.restart()
except Exception as e:
raise e
async def _save_extension_configs(self, post_configs: dict):
async def _save_plugin_configs(self, post_configs: dict, plugin_name: str):
md = None
for plugin_md in star_registry:
if plugin_md.name == plugin_name:
md = plugin_md
if not md:
raise ValueError(f"插件 {plugin_name} 不存在")
if not md.config:
raise ValueError(f"插件 {plugin_name} 没有注册配置")
try:
save_extension_config(post_configs)
save_config(post_configs, md.config)
self.core_lifecycle.restart()
except Exception as e:
raise e
+1 -1
View File
@@ -27,7 +27,7 @@ class PluginRoute(Route):
async def get_online_plugins(self):
url = "https://soulter.github.io/AstrBot_Plugins_Collection/plugins.json"
try:
async with aiohttp.ClientSession() as session:
async with aiohttp.ClientSession(trust_env=True) as session:
async with session.get(url) as response:
result = await response.json()
return Response().ok(result).__dict__
-11
View File
@@ -15,7 +15,6 @@ class StatRoute(Route):
self.routes = {
'/stat/get': ('GET', self.get_stat),
'/stat/version': ('GET', self.get_version),
'/stat/dashboard-version': ('GET', self.get_dashboard_version),
'/stat/start-time': ('GET', self.get_start_time),
'/stat/restart-core': ('GET', self.restart_core)
}
@@ -37,16 +36,6 @@ class StatRoute(Route):
"version": VERSION
}).__dict__
async def get_dashboard_version(self):
async with aiohttp.ClientSession() as session:
async with session.get('https://api.github.com/repos/Soulter/Astrbot-dashboard/actions/artifacts') as resp:
data = await resp.json()
return Response().ok({
"data": data,
"mark": "unimplemented feature"
}).__dict__
async def get_start_time(self):
return Response().ok({
"start_time": self.core_lifecycle.start_time
+1 -1
View File
@@ -3,7 +3,7 @@ class StaticFileRoute(Route):
def __init__(self, context: RouteContext) -> None:
super().__init__(context)
index_ = ['/', '/auth/login', '/config', '/logs', '/extension', '/dashboard/default', '/project-atri', '/console']
index_ = ['/', '/auth/login', '/config', '/logs', '/extension', '/dashboard/default', '/project-atri', '/console', '/chat']
for i in index_:
self.app.add_url_rule(i, view_func=self.index)
+69 -12
View File
@@ -1,30 +1,48 @@
import threading
import traceback
import aiohttp
from .route import Route, Response, RouteContext
from quart import request
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger
from astrbot.core import logger, pip_installer
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
from astrbot.core.config.default import VERSION
class UpdateRoute(Route):
def __init__(self, context: RouteContext, astrbot_updator: AstrBotUpdator) -> None:
def __init__(self, context: RouteContext, astrbot_updator: AstrBotUpdator, core_lifecycle: AstrBotCoreLifecycle) -> None:
super().__init__(context)
self.routes = {
'/update/check': ('GET', self.check_update),
'/update/do': ('POST', self.update_project),
'/update/dashboard': ('POST', self.update_dashboard),
'/update/pip-install': ('POST', self.install_pip_package)
}
self.astrbot_updator = astrbot_updator
self.core_lifecycle = core_lifecycle
self.register_routes()
async def check_update(self):
type_ = request.args.get('type', None)
try:
ret = await self.astrbot_updator.check_update(None, None)
return Response(
status="success",
message=str(ret) if ret is not None else "已经是最新版本了。",
data={
"has_new_version": ret is not None
}
).__dict__
dv = await get_dashboard_version()
if type_ == 'dashboard':
return Response().ok({
"has_new_version": dv != f"v{VERSION}",
"current_version": dv
}).__dict__
else:
ret = await self.astrbot_updator.check_update(None, None)
return Response(
status="success",
message=str(ret) if ret is not None else "已经是最新版本了。",
data={
"version": f"v{VERSION}",
"has_new_version": ret is not None,
"dashboard_version": dv,
"dashboard_has_new_version": dv != f"v{VERSION}"
}
).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(e.__str__()).__dict__
@@ -40,11 +58,50 @@ class UpdateRoute(Route):
latest = False
try:
await self.astrbot_updator.update(latest=latest, version=version)
if latest:
try:
await download_dashboard()
except Exception as e:
logger.error(f"下载管理面板文件失败: {e}")
# pip 更新依赖
logger.info("更新依赖中...")
try:
pip_installer.install(requirements_path="requirements.txt")
except Exception as e:
logger.error(f"更新依赖失败: {e}")
if reboot:
threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
# threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
self.core_lifecycle.restart()
return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__
else:
return Response().ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。").__dict__
except Exception as e:
logger.error(f"/api/update_project: {traceback.format_exc()}")
return Response().error(e.__str__()).__dict__
async def update_dashboard(self):
try:
try:
await download_dashboard()
except Exception as e:
logger.error(f"下载管理面板文件失败: {e}")
return Response().error(f"下载管理面板文件失败: {e}").__dict__
return Response().ok(None, "更新成功。刷新页面即可应用新版本面板。").__dict__
except Exception as e:
logger.error(f"/api/update_dashboard: {traceback.format_exc()}")
return Response().error(e.__str__()).__dict__
async def install_pip_package(self):
data = await request.json
package = data.get('package', '')
if not package:
return Response().error("缺少参数 package 或不合法。").__dict__
try:
pip_installer.install(package)
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
logger.error(f"/api/update_pip: {traceback.format_exc()}")
return Response().error(e.__str__()).__dict__
+15 -5
View File
@@ -2,7 +2,7 @@ import logging
import jwt
import asyncio
import os
from quart import Quart, request, jsonify
from quart import Quart, request, jsonify, g
from quart.logging import default_handler
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from .routes import *
@@ -24,19 +24,22 @@ class AstrBotDashboard():
# token 用于验证请求
logging.getLogger(self.app.name).removeHandler(default_handler)
self.context = RouteContext(self.config, self.app)
self.ur = UpdateRoute(self.context, core_lifecycle.astrbot_updator)
self.ur = UpdateRoute(self.context, core_lifecycle.astrbot_updator, core_lifecycle)
self.sr = StatRoute(self.context, db, core_lifecycle)
self.pr = PluginRoute(self.context, core_lifecycle, core_lifecycle.plugin_manager)
self.cr = ConfigRoute(self.context, core_lifecycle)
self.lr = LogRoute(self.context, core_lifecycle.log_broker)
self.sfr = StaticFileRoute(self.context)
self.ar = AuthRoute(self.context)
self.chat_route = ChatRoute(self.context, db, core_lifecycle)
async def auth_middleware(self):
if not request.path.startswith("/api"):
return
if request.path == "/api/auth/login":
return
if request.path == "/api/chat/get_file":
return
# claim jwt
token = request.headers.get("Authorization")
if not token:
@@ -46,7 +49,8 @@ class AstrBotDashboard():
if token.startswith("Bearer "):
token = token[7:]
try:
jwt.decode(token, WEBUI_SK, algorithms=["HS256"])
payload = jwt.decode(token, WEBUI_SK, algorithms=["HS256"])
g.username = payload["username"]
except jwt.ExpiredSignatureError:
r = jsonify(Response().error("Token 过期").__dict__)
r.status_code = 401
@@ -64,8 +68,14 @@ class AstrBotDashboard():
def run(self):
ip_addr = get_local_ip_addresses()
logger.info(f"""🌈 管理面板已启动,可访问
logger.info(f"""
AstrBot 管理面板已启动可访问
1. http://{ip_addr}:6185
2. http://localhost:6185
登录默认用户名和密码是 astrbot""")
默认用户名和密码是 astrbot
""")
return self.app.run_task(host="0.0.0.0", port=6185, shutdown_trigger=self.shutdown_trigger_placeholder)
+12
View File
@@ -0,0 +1,12 @@
# What's Changed
- 修复 LLM 请求报错信息被覆盖的问题,增强 LLM 请求错误处理 #243
- 修复 Napcat 接口更新导致 QQ 图片发送失败的问题 #246
- 修复某些请求不能正确应用代理的问题
- 针对 api_base 的明显提示,修改 ollama 模板的 api_base #247
- 支持登出 gewechat,在webchat等地方使用 `/gewe_logout` 指令,这在微信上显示账号下线但是 gewe 仍显示设备在线时很好用
- 添加gewechat适配器过滤器
- help显示AstrBot和webui版本
- 优化webui和主程序更新的协调
- 下载管理面板时显示提示、下载进度和下载速度
- 管理面板前端更新功能入口移入右上角更新按钮,以便统一管理 #245
+6
View File
@@ -0,0 +1,6 @@
# What's Changed
- 为平台和提供商适配器添加默认 ID 配置 #248
- 修复appid保存的问题和部分群聊at失效的问题和群聊@的sender username显示异常的问题
- 优化更新项目时重启可能会导致Address already in use的问题
- 各类异步任务报错后的优雅报错输出,而不是只有在退出程序的时候才输出异常日志。
+6
View File
@@ -0,0 +1,6 @@
# What's Changed
- Gewechat 微信支持图片、语音的收和发
- 支持 OpenAI TTS(文字转语音)
- 支持路径映射,解决 docker 部署时两端文件系统不一致导致的富媒体文件路径不存在问题
- Napcat 下语音消息可能接收异常
+4
View File
@@ -0,0 +1,4 @@
# What's Changed
- 修复 astrbot_updator 属性缺失与stt_enabled 未初始化 #252
- 支持消息分段回复
+8
View File
@@ -0,0 +1,8 @@
# What's Changed
- 修复: TTS 问题
- 新增: **支持记录非唤醒状态下群聊历史记录(beta)**
- 优化: 自动删除 deepseek-r1 模型自带的 think 标签
- 优化: 自动移除 ollama 不支持 tool 的模型的 tool 请求
- 优化: /t2i 即时生效
- 优化: gewechat 消息下发异常处理
+9
View File
@@ -0,0 +1,9 @@
# What's Changed
- 修复: 配置 Validator 不起效的问题
- 修复: DeepSeek-R1 思考标签问题
- 修复: 分段回复间隔时间不生效
- 修复: 修复白名单为空时依然终止事件 #259
- 修复: 群聊增强某些参数的类型转换问题
- 新增: 插件支持注册配置,详见 [注册插件配置](https://astrbot.app/dev/plugin.html#%E6%B3%A8%E5%86%8C%E6%8F%92%E4%BB%B6%E9%85%8D%E7%BD%AE-beta)
- 优化: 插件的禁用/启用逻辑以及函数工具的禁用/启用逻辑
+6
View File
@@ -0,0 +1,6 @@
# What's Changed
1. 支持通过 /set <k> <v> 设置持久化的会话变量, 方便 Dify App 输入变量
2. 管理面板支持 Web Chat
3. 管理面板支持手动安装 Pip 库, 在 `控制台` 页中可找到
+9
View File
@@ -0,0 +1,9 @@
# What's Changed
- 支持接入 STT(语音转文字)Provider
- 内置支持 OpenAI Whisper API/本地运行模型。[看这里](https://astrbot.lwl.lol/use/whisper.html)
- WebChat 支持语音输入
- WebChat 支持显示当前 Provider 状态
- 优化了 WebChat 在没有消息返回时的处理方式
- 修复了 reminder 在初始化历史待办时没有正常传入 session_id 的问题
- 代码执行器在成功回复后清空文件 buffer。
+9
View File
@@ -0,0 +1,9 @@
# What's Changed
- 文件和语音功能适配 Lagrange
- 面板文件更新检查和引导提示
- WebUI AboutPage 关于页
- 支持并完善服务提供商(Provider)默认配置模板接口
- 修复 WebUI 配置页官方文档链接 404 的问题
- 修复 WebUI WebChat 刷新时 404 的问题
- 优化 download_file 的 SSL 连接错误处理
+6
View File
@@ -0,0 +1,6 @@
# What's Changed
- 更好的人格情景管理
- 移除了不常用的人格提示词集
- 优化webchat长连接的处理逻辑
- 修复 tool 为空时部分模型请求错误的问题 #239
+5
View File
@@ -0,0 +1,5 @@
# What's Changed
- 支持 Gewechat 接入微信个人号(文字交互)
- 支持回复时 At 和引用发送者 #241
- 清除残留的 personalities
+6
View File
@@ -0,0 +1,6 @@
# What's Changed
- AstrBot 新域名:astrbot.app
- LLM额外唤醒词与机器人唤醒词冲突时的处理
- 调整部分日志的严重级别
- 下载管理面板时显示提示、下载进度和下载速度
-9981
View File
File diff suppressed because it is too large Load Diff
+1 -2
View File
@@ -23,6 +23,7 @@
"date-fns": "2.30.0",
"js-md5": "^0.8.3",
"lodash": "4.17.21",
"marked": "^15.0.6",
"pinia": "2.1.6",
"remixicon": "3.5.0",
"vee-validate": "4.11.3",
@@ -32,8 +33,6 @@
"vue3-apexcharts": "1.4.4",
"vue3-print-nb": "0.1.4",
"vuetify": "3.3.14",
"xterm": "^5.3.0",
"xterm-addon-fit": "^0.8.0",
"yup": "1.2.0"
},
"devDependencies": {
@@ -0,0 +1,40 @@
<?xml version="1.0" encoding="utf-8"?>
<!-- Generator: Adobe Illustrator 24.1.2, SVG Export Plug-In . SVG Version: 6.00 Build 0) -->
<svg version="1.1" id="Layer_3" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" x="0px" y="0px"
viewBox="0 0 128 128" style="enable-background:new 0 0 128 128;" xml:space="preserve">
<g>
<linearGradient id="SVGID_1_" gradientUnits="userSpaceOnUse" x1="93.7287" y1="106.6446" x2="52.9011" y2="81.6944">
<stop offset="0.0969" style="stop-color:#FFB300"/>
<stop offset="1" style="stop-color:#FFB300;stop-opacity:0"/>
</linearGradient>
<path style="fill:url(#SVGID_1_);" d="M123.04,107.67c-4.08-4.12-9.38-9.48-14.92-15.06c-0.34,1.29-0.93,2.39-1.79,3.26
c-6.43,6.43-25.6-1.99-45.31-19.1c-2.46-2.13-16.74,20.28-14.1,22.87c3.27,3.2,26,17.86,33.78,20.73
c22.66,8.35,34.3,0.22,38.24-3.59C121.16,114.61,122.51,111.5,123.04,107.67z"/>
<linearGradient id="SVGID_2_" gradientUnits="userSpaceOnUse" x1="115.2813" y1="82.3624" x2="14.863" y2="0.8196">
<stop offset="0" style="stop-color:#FFB300"/>
<stop offset="0.7062" style="stop-color:#FDD835"/>
<stop offset="0.8408" style="stop-color:#FDDC36"/>
<stop offset="0.9842" style="stop-color:#FFE93A"/>
<stop offset="1" style="stop-color:#FFEB3B"/>
</linearGradient>
<path style="fill:url(#SVGID_2_);" d="M25.05,27.7c-1.54-4.81-2.88-11.1-0.4-13.5c7.51-7.3,31.69,4.88,54.25,27.43
c22.55,22.55,34.84,46.84,27.43,54.25c-0.07,0.07-0.16,0.13-0.23,0.2c6.13,5.82,12.2,11.6,16.1,15.31
c4.87-14.43-6.45-44.11-31.5-69.96c-4.07-4.2-16.12-16.56-26.55-23.56C54.61,11.47,44.19,5.59,32.57,4.2
C25,3.29,11.45,5.24,14.25,15.98c0.55,2.12,2.31,7.22,8.15,13.3C23.56,30.49,25.56,29.3,25.05,27.7z"/>
<g>
<path style="fill:#FDD835;" d="M55.98,42.1l-0.75,20c-0.06,1.53,0.72,2.98,2.04,3.77l16.86,10.11c1.85,1.25,1.46,4.09-0.66,4.79
L54.79,85.5c-1.51,0.38-2.69,1.57-3.06,3.08l-4.89,19.93c-0.62,2.15-3.43,2.65-4.76,0.85L31.06,92.91
c-0.85-1.26-2.31-1.97-3.83-1.85L7.49,92.61c-2.23,0.07-3.58-2.45-2.28-4.27l12.6-16.19c0.96-1.23,1.16-2.89,0.52-4.31
l-7.88-17.57c-0.76-2.1,1.22-4.17,3.35-3.49l18.39,6.95c1.44,0.54,3.05,0.26,4.22-0.74l15.22-13
C53.39,38.62,55.96,39.87,55.98,42.1z"/>
<g>
<path style="fill:#FFFF8D;" d="M46.99,59.33l4.66-12.75c0.28-0.7,0.7-1.93,1.79-1.4c0.86,0.42,0.46,2.43,0.46,2.43l-1.05,11.54
c-0.41,4.39-1.6,5.38-3.3,5.49C47.6,64.75,45.65,62.98,46.99,59.33z"/>
</g>
<g>
<path style="fill:#F4B400;" d="M53.89,83.73l14.53-3.13c0.73-0.18,2.01-0.42,1.64-1.58c-0.29-0.91-2.34-0.8-2.34-0.8l-10.97-0.86
c-3.21-0.38-5.72,0.14-6.74,1.84C48.65,81.48,49.89,84.32,53.89,83.73z"/>
</g>
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 2.6 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 54 KiB

@@ -1,41 +0,0 @@
<script setup>
import UiParentCard from '@/components/shared/UiParentCard.vue';
const props = defineProps({
config: Array
});
</script>
<template>
<a v-show="config.length === 0">该插件没有配置</a>
<UiParentCard v-for="group in config" :key="group.name" :title="group.name" style="margin-bottom: 16px;">
<template v-for="item in group.body">
<template v-if="item.config_type === 'item'">
<template v-if="item.val_type === 'bool'">
<v-switch v-model="item.value" :label="item.name" :hint="item.description" color="primary" inset></v-switch>
</template>
<template v-else-if="item.val_type === 'str'">
<v-text-field v-model="item.value" :label="item.name" :hint="item.description" style="margin-bottom: 8px;"
variant="outlined"></v-text-field>
</template>
<template v-else-if="item.val_type === 'int'">
<v-text-field v-model="item.value" :label="item.name" :hint="item.description" style="margin-bottom: 8px;"
variant="outlined"></v-text-field>
</template>
<template v-else-if="item.val_type === 'list'">
<span>{{ item.name }}</span>
<v-combobox v-model="item.value" chips clearable label="请添加" multiple prepend-icon="mdi-tag-multiple-outline">
<template v-slot:selection="{ attrs, item, select, selected }">
<v-chip v-bind="attrs" :model-value="selected" closable @click="select" @click:close="remove(item)">
<strong>{{ item }}</strong>
</v-chip>
</template>
</v-combobox>
</template>
</template>
<template v-else-if="item.config_type === 'divider'">
<v-divider style="margin-top: 8px; margin-bottom: 8px;"></v-divider>
</template>
</template>
</UiParentCard>
</template>
@@ -15,6 +15,9 @@ let newUsername = ref('');
let status = ref('');
let updateStatus = ref('')
let hasNewVersion = ref(false);
let botCurrVersion = ref('');
let dashboardHasNewVersion = ref(false);
let dashboardCurrentVersion = ref('');
let version = ref('');
const open = (link: string) => {
@@ -64,6 +67,9 @@ function checkUpdate() {
.then((res) => {
hasNewVersion.value = res.data.data.has_new_version;
updateStatus.value = res.data.message;
botCurrVersion.value = res.data.data.version;
dashboardCurrentVersion.value = res.data.data.dashboard_version;
dashboardHasNewVersion.value = res.data.data.dashboard_has_new_version;
})
.catch((err) => {
if (err.response.status == 401) {
@@ -84,7 +90,24 @@ function switchVersion(version: string) {
})
.then((res) => {
updateStatus.value = res.data.message;
if (res.data.status == 'success') {
if (res.data.status == 'ok') {
setTimeout(() => {
window.location.reload();
}, 1000);
}
})
.catch((err) => {
console.log(err);
updateStatus.value = err
});
}
function updateDashboard() {
updateStatus.value = '正在更新...';
axios.post('/api/update/dashboard')
.then((res) => {
updateStatus.value = res.data.message;
if (res.data.status == 'ok') {
setTimeout(() => {
window.location.reload();
}, 1000);
@@ -106,8 +129,8 @@ commonStore.getStartTime();
<template>
<v-app-bar elevation="0" height="70">
<v-btn style="margin-left: 22px;" class="hidden-md-and-down text-secondary" color="lightsecondary" icon rounded="sm" variant="flat"
@click.stop="customizer.SET_MINI_SIDEBAR(!customizer.mini_sidebar)" size="small">
<v-btn style="margin-left: 22px;" class="hidden-md-and-down text-secondary" color="lightsecondary" icon rounded="sm"
variant="flat" @click.stop="customizer.SET_MINI_SIDEBAR(!customizer.mini_sidebar)" size="small">
<v-icon>mdi-menu</v-icon>
</v-btn>
<v-btn class="hidden-lg-and-up text-secondary ms-3" color="lightsecondary" icon rounded="sm" variant="flat"
@@ -136,11 +159,16 @@ commonStore.getStartTime();
</template>
<v-card>
<v-card-title>
<span class="text-h5">更新项目</span>
<span class="text-h5">更新 AstrBot</span>
</v-card-title>
<v-card-text>
<v-container>
<h3 class="mb-4">升级到最新版本</h3>
<h3 class="mb-4">升级到项目最新版本</h3>
<small>当前版本 {{ botCurrVersion }}</small>
<div class="mb-4">
<small>会同时尝试更新机器人主程序和管理面板如果您正在使用 Docker 部署也可以重新拉取镜像或者使用 <a
href="https://containrrr.dev/watchtower/usage-overview/">watchtower</a> 来自动监控拉取</small>
</div>
<p>{{ updateStatus }}</p>
<v-btn class="mt-4 mb-4" @click="switchVersion('latest')" color="primary" style="border-radius: 10px;"
:disabled="!hasNewVersion">
@@ -148,7 +176,11 @@ commonStore.getStartTime();
</v-btn>
<v-divider></v-divider>
<div style="margin-top: 16px;">
<h3 class="mb-4">切换到指定版本或指定提交</h3>
<h3 class="mb-4">切换到项目指定版本或指定提交</h3>
<div class="mb-4">
<small>跳到旧版本不会重新下载管理面板文件这可能会造成部分数据显示错误您可在 <a href="https://github.com/Soulter/AstrBot/releases">此处</a>
找到对应的面板文件 dist.zip解压后替换 data/dist 文件夹即可</small>
</div>
<v-text-field label="输入版本号或 master 分支下的 commit hash。" v-model="version" required
variant="outlined"></v-text-field>
<div class="mb-4">
@@ -160,7 +192,29 @@ commonStore.getStartTime();
<v-btn color="error" style="border-radius: 10px;" @click="switchVersion(version)">
确定切换
</v-btn>
</div>
<v-divider></v-divider>
<div style="margin-top: 16px;">
<h3 class="mb-4">更新管理面板到最新版本</h3>
<div class="mb-4">
<small>当前版本 {{ dashboardCurrentVersion }}</small>
<br>
</div>
<div class="mb-4">
<p v-if="dashboardHasNewVersion">
有新版本
</p>
<p v-else="dashboardHasNewVersion">
已经是最新版本了
</p>
</div>
<v-btn color="primary" style="border-radius: 10px;" @click="updateDashboard()">
下载并更新
</v-btn>
</div>
</v-container>
</v-card-text>
@@ -190,8 +244,7 @@ commonStore.getStartTime();
<v-text-field label="原密码*" type="password" v-model="password" required
variant="outlined"></v-text-field>
<v-text-field label="新用户名" v-model="newUsername" required
variant="outlined"></v-text-field>
<v-text-field label="新用户名" v-model="newUsername" required variant="outlined"></v-text-field>
<v-text-field label="新密码" type="password" v-model="newPassword" required
variant="outlined"></v-text-field>
@@ -213,11 +266,5 @@ commonStore.getStartTime();
</v-card-actions>
</v-card>
</v-dialog>
<v-btn class="text-primary mr-4" @click="open('https://github.com/Soulter/AstrBot')" color="lightprimary"
variant="flat" rounded="sm">
GitHub Star! 🌟
</v-btn>
</v-app-bar>
</template>
@@ -17,7 +17,7 @@ const sidebarMenu = shallowRef(sidebarItems);
</template>
</v-list>
<div class="text-center">
<v-chip color="inputBorder" size="small"> v{{ version }} </v-chip>
<v-chip color="inputBorder" size="small"> {{ version }} </v-chip>
</div>
<div style="position: absolute; bottom: 32px; width: 100%" class="text-center">
@@ -27,8 +27,15 @@ const sidebarMenu = shallowRef(sidebarItems);
</v-btn>
</v-list-item>
<small style="display: block;" v-if="buildVer">构建: {{ buildVer }}</small>
<small style="display: block;" v-else="buildVer">构建: embedded</small>
<small style="display: block; margin-top: 8px;">© 2024 AstrBot</small>
<small style="display: block;" v-else>构建: embedded</small>
<v-tooltip text="使用 /dashbord_update 指令更新管理面板">
<template v-slot:activator="{ props }">
<small v-bind="props" v-if="hasWebUIUpdate" style="display: block; margin-top: 4px;">面板有更新</small>
</template>
</v-tooltip>
<small style="display: block; margin-top: 8px;">© 2025 AstrBot</small>
</div>
</v-navigation-drawer>
@@ -43,25 +50,28 @@ export default {
},
data: () => ({
version: "",
buildVer: ""
buildVer: "",
hasWebUIUpdate: false,
}),
mounted() {
this.get_version()
fetch('/assets/version').then((res) => {
return res.text()
}).then((res) => {
if (res.length > 10) {
// 😎
return
}
this.buildVer = res
})
this.check_webui_update()
},
methods: {
get_version() {
axios.get('/api/stat/version')
.then((res) => {
this.version = res.data.data.version;
this.version = "v" + res.data.data.version;
})
.catch((err) => {
console.log(err);
});
},
check_webui_update() {
axios.get('/api/update/check?type=dashboard')
.then((res) => {
this.hasWebUIUpdate = res.data.data.has_new_version;
this.buildVer = res.data.data.current_version;
})
.catch((err) => {
console.log(err);
@@ -30,11 +30,21 @@ const sidebarItem: menu[] = [
icon: 'mdi-puzzle',
to: '/extension'
},
{
title: '聊天',
icon: 'mdi-chat',
to: '/chat'
},
{
title: '控制台',
icon: 'mdi-console',
to: '/console'
},
{
title: '关于',
icon: 'mdi-information',
to: '/about'
},
// {
// title: 'Project ATRI',
// icon: 'mdi-grain',
+10
View File
@@ -36,6 +36,16 @@ const MainRoutes = {
name: 'Project ATRI',
path: '/project-atri',
component: () => import('@/views/ATRIProject.vue')
},
{
name: 'Chat',
path: '/chat',
component: () => import('@/views/ChatPage.vue')
},
{
name: 'About',
path: '/about',
component: () => import('@/views/AboutPage.vue')
}
]
};
+61
View File
@@ -0,0 +1,61 @@
<template>
<v-card style="height: 100%;">
<v-card-text style="padding: 0; height: 100%;">
<div
style="display: flex; justify-content: center; align-items: center; height: 100%; flex-direction: column;">
<div @click="selectedLogo = selectedLogo == 0 ? 1 : 0" style="height: 300px;">
<img v-if="selectedLogo == 0" width="300" src="@/assets/images/logo-waifu.png" alt="AstrBot Logo" class="fade-in">
<img v-if="selectedLogo == 1" width="300" src="@/assets/images/logo-normal.svg" alt="AstrBot Logo" class="fade-in">
</div>
<h1 class="mt-8">AstrBot</h1>
<span style="color: #777;" class="mt-4">By <a href="https://soulter.top">Soulter</a> And <a href="https://github.com/Soulter/AstrBot/graphs/contributors">AstrBot Contributors</a></span>
<v-btn class="text-primary mt-16" @click="open('https://github.com/Soulter/AstrBot')"
color="lightprimary" variant="flat" rounded="sm">
Star 这个项目! 🌟
</v-btn>
<v-btn class="text-primary mt-4" @click="open('https://github.com/Soulter/AstrBot/issues')"
color="lightprimary" variant="flat" rounded="sm">
有使用问题或者功能建议提交 Issue
</v-btn>
</div>
</v-card-text>
</v-card>
</template>
<script>
export default {
name: 'AboutPage',
data() {
return {
selectedLogo: 0
}
},
methods: {
open(url) {
window.open(url, '_blank');
}
}
}
</script>
<style>
@keyframes fadeIn {
from {
opacity: 0;
}
to {
opacity: 1;
}
}
.fade-in {
animation: fadeIn 0.2s ease-in-out;
}
</style>
+567
View File
@@ -0,0 +1,567 @@
<script setup>
import axios from 'axios';
import { marked } from 'marked';
marked.setOptions({
breaks: true
});
</script>
<template>
<v-card style="margin-bottom: 16px; width: 100%; background-color: #fff; height: 100%;">
<v-card-text style="width: 100%; height: calc(100vh - 120px);">
<div style="height: 100%; display: flex; gap: 16px;">
<div style="max-width: 200px;">
<!-- conversation -->
<v-btn variant="tonal" rounded="xl" style="margin-bottom: 16px; min-width: 200px;" @click="newC"
:disabled="!currCid">+ 创建对话</v-btn>
<v-card class="mx-auto" min-width="200">
<v-list dense nav v-if="conversations.length > 0" style="max-height: 500px; overflow-y: auto;"
@update:selected="getConversationMessages">
<v-list-item v-for="(item, i) in conversations" :key="item.cid" :value="item.cid"
color="primary" rounded="xl">
<v-list-item-title>新对话</v-list-item-title>
<v-list-item-subtitle>{{ formatDate(item.updated_at) }}</v-list-item-subtitle>
</v-list-item>
</v-list>
</v-card>
<div>
<v-chip class="mt-4" color="primary" :append-icon="status?.llm_enabled ? 'mdi-check' : 'mdi-close'">
LLM
</v-chip>
<v-chip class="mt-4 ml-2" color="success" :append-icon="status?.stt_enabled ? 'mdi-check' : 'mdi-close'">
语音转文本
</v-chip>
</div>
<v-btn variant="tonal" rounded="xl"
style="position: fixed; bottom: 48px; margin-bottom: 16px; min-width: 200px;" v-if="currCid"
@click="deleteConversation(currCid)" color="error">删除此对话</v-btn>
</div>
<div style="height: 100%; width: 100%;">
<div style="height: calc(100% - 120px); overflow-y: auto; padding: 16px; " ref="messageContainer">
<div class="fade-in" v-if="messages.length == 0"
style="height: 100%; display: flex; justify-content: center; align-items: center; flex-direction: column;">
<div>
<span style="font-size: 28px;">Hello, I'm</span>
<span style="font-weight: 1000; font-size: 28px; margin-left: 8px;">AstrBot </span>
</div>
<div style="margin-top: 8px; color: #aaa;">
<span>输入</span>
<span
style="background-color: #eee; padding-left: 4px; padding-right: 4px; margin: 2px; border-radius: 4px;">/help</span>
<span>获取帮助 😊</span>
</div>
<div style="margin-top: 8px; color: #aaa;">
<span></span>
<span
style="background-color: #eee; padding-left: 4px; padding-right: 4px; margin: 2px; border-radius: 4px;">K</span>
<span>开始语音 🎤</span>
</div>
</div>
<div v-else style="max-height: 100%; padding: 16px; max-width: 700px; margin: 0 auto;">
<div class="fade-in" v-for="(msg, index) in messages" :key="index"
style="margin-bottom: 16px;">
<div v-if="msg.type == 'user'" style="display: flex; justify-content: flex-end;">
<div
style="padding: 12px; border-radius: 8px; background-color: rgba(94, 53, 177, 0.15)">
<span>{{ msg.message }}</span>
<div style="display: flex; gap: 8px; margin-top: 8px;"
v-if="msg.image_url && msg.image_url.length > 0">
<div v-for="(img, index) in msg.image_url" :key="index"
style="position: relative; display: inline-block;">
<img :src="img"
style="width: 100px; height: 100px; border-radius: 8px; box-shadow: 0 0 5px rgba(0, 0, 0, 0.1);" />
</div>
</div>
<!-- audio -->
<div>
<audio controls v-if="msg.audio_url && msg.audio_url.length > 0">
<source :src="msg.audio_url" type="audio/wav">
Your browser does not support the audio element.
</audio>
</div>
</div>
</div>
<div v-else style="display: flex; justify-content: flex-start; gap: 16px;">
<span style="font-size: 32px;"></span>
<div v-html="marked(msg.message)" class="mc" style="font-family: inherit;"></div>
</div>
</div>
</div>
</div>
<div class="fade-in" style="bottom: 16px; width: 100%; padding: 8px; ">
<div
style="width: 100%; justify-content: center; align-items: center; display: flex; flex-direction: column; margin-top: 8px;">
<v-text-field id="input-field" variant="outlined" v-model="prompt" :label="inputFieldLabel"
placeholder="Start typing..." loading clear-icon="mdi-close-circle" clearable
@click:clear="clearMessage" style="width: 100%; max-width: 850px;">
<template v-slot:loader>
<v-progress-linear :active="loadingChat" height="6"
indeterminate></v-progress-linear>
</template>
<template v-slot:append>
<v-tooltip text="发送">
<template v-slot:activator="{ props }">
<v-icon v-bind="props" @click="sendMessage" size="35"
icon="mdi-arrow-up-circle" />
</template>
</v-tooltip>
<v-tooltip text="语音输入">
<template v-slot:activator="{ props }">
<v-icon :color="isRecording ? 'error' : ''" v-bind="props"
@click="isRecording ? stopRecording() : startRecording()" size="35"
icon="mdi-record-circle" />
</template>
</v-tooltip>
</template>
</v-text-field>
<div style="display: flex; gap: 8px; margin-top: -8px;">
<div v-for="(img, index) in stagedImagesUrl" :key="index"
style="position: relative; display: inline-block;">
<img :src="img"
style="width: 50px; height: 50px; border-radius: 8px; box-shadow: 0 0 5px rgba(0, 0, 0, 0.1);" />
<v-icon @click="removeImage(index)" size="20" color="red"
style="position: absolute; top: 0; right: 0; cursor: pointer;">mdi-close-circle</v-icon>
</div>
<div style="display: inline-block; width: 50px; height: 50px;">
<div v-if="stagedAudioUrl"
style="position: relative; padding: 6px; border-radius: 8px; background-color: rgba(94, 53, 177, 0.15); display: inline-block;">
新录音
<v-icon @click="removeAudio" size="20" color="red"
style="position: absolute; top: 0; right: 0; cursor: pointer;">mdi-close-circle</v-icon>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
</v-card-text>
</v-card>
</template>
<script>
export default {
name: 'ChatPage',
components: {
},
data() {
return {
prompt: '',
messages: [],
conversations: [],
currCid: '',
stagedImagesUrl: [],
loadingChat: false,
inputFieldLabel: '聊天吧!',
isRecording: false,
audioChunks: [],
stagedAudioUrl: "",
mediaRecorder: null,
status: {},
statusText: '',
eventSource: null
}
},
mounted() {
this.startListeningEvent();
this.checkStatus();
this.getConversations();
let inputField = document.getElementById('input-field');
inputField.addEventListener('paste', this.handlePaste);
inputField.addEventListener('keydown', function (e) {
if (e.keyCode == 13 && !e.shiftKey) {
e.preventDefault();
this.sendMessage();
}
}.bind(this));
document.addEventListener('keydown', function (e) {
if (e.keyCode == 75) {
this.isRecording ? this.stopRecording() : this.startRecording();
}
}.bind(this));
},
beforeUnmount() {
console.log("111")
if (this.eventSource) {
this.eventSource.cancel();
console.log('SSE连接已断开');
}
},
methods: {
async startListeningEvent() {
const response = await fetch('/api/chat/listen', {
method: 'GET',
headers: {
'Content-Type': 'application/json',
'Authorization': 'Bearer ' + localStorage.getItem('token')
}
})
if (!response.ok) {
console.error('SSE连接失败:', response.statusText);
return;
}
const reader = response.body.getReader();
const decoder = new TextDecoder();
this.eventSource = reader
while (true) {
const { done, value } = await reader.read();
if (done) {
console.log('SSE连接关闭');
break;
}
const chunk = decoder.decode(value, { stream: true });
console.log("!!!!", chunk);
if (chunk === '[HB]\n') {
continue; //
}
if (chunk === '[ERROR]\n') {
continue;
}
if (chunk.startsWith('[IMAGE]')) {
let img = chunk.replace('[IMAGE]', '');
let bot_resp = {
type: 'bot',
message: `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
}
this.messages.push(bot_resp);
} else {
let bot_resp = {
type: 'bot',
message: chunk
}
this.messages.push(bot_resp);
}
this.scrollToBottom();
}
},
removeAudio() {
this.stagedAudioUrl = null;
},
checkStatus() {
axios.get('/api/chat/status').then(response => {
console.log(response.data);
this.status = response.data.data;
}).catch(err => {
console.error(err);
});
},
async startRecording() {
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
this.mediaRecorder = new MediaRecorder(stream);
this.mediaRecorder.ondataavailable = (event) => {
this.audioChunks.push(event.data);
};
this.mediaRecorder.start();
this.isRecording = true;
this.inputFieldLabel = "录音中,请说话...";
},
async stopRecording() {
this.isRecording = false;
this.inputFieldLabel = "聊天吧!";
this.mediaRecorder.stop();
this.mediaRecorder.onstop = async () => {
const audioBlob = new Blob(this.audioChunks, { type: 'audio/wav' });
this.audioChunks = [];
this.mediaRecorder.stream.getTracks().forEach(track => track.stop());
const formData = new FormData();
formData.append('file', audioBlob);
try {
const response = await axios.post('/api/chat/post_file', formData, {
headers: {
'Content-Type': 'multipart/form-data',
'Authorization': 'Bearer ' + localStorage.getItem('token')
}
});
const audio = response.data.data.filename;
console.log('Audio uploaded:', audio);
this.stagedAudioUrl = `/api/chat/get_file?filename=${audio}`;
} catch (err) {
console.error('Error uploading audio:', err);
}
};
},
async handlePaste(event) {
console.log('Pasting image...');
const items = event.clipboardData.items;
for (let i = 0; i < items.length; i++) {
if (items[i].type.indexOf('image') !== -1) {
const file = items[i].getAsFile();
const formData = new FormData();
formData.append('file', file);
try {
const response = await axios.post('/api/chat/post_image', formData, {
headers: {
'Content-Type': 'multipart/form-data',
'Authorization': 'Bearer ' + localStorage.getItem('token')
}
});
const img = response.data.data.filename;
this.stagedImagesUrl.push(`/api/chat/get_file?filename=${img}`);
} catch (err) {
console.error('Error uploading image:', err);
}
}
}
},
removeImage(index) {
this.stagedImagesUrl.splice(index, 1);
},
clearMessage() {
this.prompt = '';
},
getConversations() {
axios.get('/api/chat/conversations').then(response => {
this.conversations = response.data.data;
}).catch(err => {
console.error(err);
});
},
getConversationMessages(cid) {
if (!cid[0])
return;
axios.get('/api/chat/get_conversation?conversation_id=' + cid[0]).then(response => {
this.currCid = cid[0];
let message = JSON.parse(response.data.data.history);
for (let i = 0; i < message.length; i++) {
if (message[i].message.startsWith('[IMAGE]')) {
let img = message[i].message.replace('[IMAGE]', '');
message[i].message = `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
}
if (message[i].image_url && message[i].image_url.length > 0) {
for (let j = 0; j < message[i].image_url.length; j++) {
message[i].image_url[j] = `/api/chat/get_file?filename=${message[i].image_url[j]}`;
}
}
if (message[i].audio_url) {
message[i].audio_url = `/api/chat/get_file?filename=${message[i].audio_url}`;
}
}
this.messages = message;
}).catch(err => {
console.error(err);
});
},
async newConversation() {
await axios.get('/api/chat/new_conversation').then(response => {
this.currCid = response.data.data.conversation_id;
this.getConversations();
}).catch(err => {
console.error(err);
});
},
newC() {
this.currCid = '';
this.messages = [];
},
formatDate(timestamp) {
const date = new Date(timestamp * 1000); //
const options = {
year: 'numeric',
month: '2-digit',
day: '2-digit',
hour: '2-digit',
minute: '2-digit',
second: '2-digit',
hour12: false
};
return date.toLocaleString('zh-CN', options).replace(/\//g, '-').replace(/, /g, ' ');
},
deleteConversation(cid) {
axios.get('/api/chat/delete_conversation?conversation_id=' + cid).then(response => {
this.getConversations();
this.currCid = '';
this.messages = [];
}).catch(err => {
console.error(err);
});
},
async sendMessage() {
if (this.currCid == '') {
await this.newConversation();
}
this.messages.push({
type: 'user',
message: this.prompt,
image_url: this.stagedImagesUrl,
audio_url: this.stagedAudioUrl
});
this.scrollToBottom();
// images
let image_filenames = [];
for (let i = 0; i < this.stagedImagesUrl.length; i++) {
let img = this.stagedImagesUrl[i].replace('/api/chat/get_file?filename=', '');
image_filenames.push(img);
}
// audio
let audio_filenames = [];
if (this.stagedAudioUrl) {
let audio = this.stagedAudioUrl.replace('/api/chat/get_file?filename=', '');
audio_filenames.push(audio);
}
this.loadingChat = true;
fetch('/api/chat/send', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Authorization': 'Bearer ' + localStorage.getItem('token')
},
body: JSON.stringify({
message: this.prompt,
conversation_id: this.currCid,
image_url: image_filenames,
audio_url: audio_filenames
}) //
})
.then(response => {
this.prompt = '';
this.stagedImagesUrl = [];
this.stagedAudioUrl = "";
this.loadingChat = false;
// const reader = response.body.getReader(); // Reader
// const decoder = new TextDecoder();
// const readStream = async () => {
// const { done, value } = await reader.read(); //
// if (done) {
// console.log("Stream finished.");
// return;
// }
// const chunk = decoder.decode(value, { stream: true });
// // bot_resp.message.value += chunk;
// console.log("!!!!", chunk);
// if (chunk.startsWith('[IMAGE]')) {
// let img = chunk.replace('[IMAGE]', '');
// let bot_resp = {
// type: 'bot',
// message: `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
// }
// this.messages.push(bot_resp);
// } else {
// let bot_resp = {
// type: 'bot',
// message: chunk
// }
// this.messages.push(bot_resp);
// }
// this.scrollToBottom();
// readStream(); //
// };
// readStream();
})
.catch(err => {
console.error(err);
});
},
scrollToBottom() {
this.$nextTick(() => {
const container = this.$refs.messageContainer;
container.scrollTop = container.scrollHeight;
});
}
},
}
</script>
<style>
@keyframes fadeIn {
from {
opacity: 0;
}
to {
opacity: 1;
}
}
.fade-in {
animation: fadeIn 0.2s ease-in-out;
}
.mc h1,
.mc h2,
.mc h3,
.mc h4,
.mc h5,
.mc h6 {
margin-bottom: 10px;
}
.mc li {
margin-left: 16px;
}
.mc p {
margin-top: 10px;
margin-bottom: 10px;
}
</style>
+11 -3
View File
@@ -3,6 +3,7 @@ import axios from 'axios';
import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
import config from '@/config';
</script>
<template>
@@ -44,7 +45,10 @@ import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
<v-expansion-panel-text v-if="metadata[key]['metadata'][key2]?.config_template">
<!-- 带有 config_template 的配置项 -->
<v-tabs style="margin-top: 16px;" align-tabs="left" color="deep-purple-accent-4" v-model="config_template_tab">
<v-tab v-for="(item, index) in config_data[key2]" :key="index" :value="index">
<v-tab v-if="metadata[key]['metadata'][key2]?.tmpl_display_title" v-for="(item, index) in config_data[key2]" :key="index" :value="index">
{{ item[metadata[key]['metadata'][key2]?.tmpl_display_title] }}
</v-tab>
<v-tab v-else v-for="(item, index) in config_data[key2]" :key="index + '_'" :value="index">
{{ item.id }}({{ item.type }})
</v-tab>
<v-menu>
@@ -64,6 +68,10 @@ import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
<v-tabs-window-item v-for="(config_item, index) in config_data[key2]" v-show="config_template_tab === index"
:key="index" :value="index">
<v-container>
<v-btn variant="tonal" rounded="xl" color="error" @click="config_data[key2].splice(index, 1)">
删除这项
</v-btn>
<AstrBotConfig :metadata="metadata[key]['metadata']" :iterable="config_item" :metadataKey="key2"></AstrBotConfig>
</v-container>
</v-tabs-window-item>
@@ -83,7 +91,7 @@ import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
<div style="margin-left: 16px; padding-bottom: 16px">
<small>不了解配置请见 <a
href="https://astrbot.soulter.top/docs/%E5%BC%80%E5%A7%8B%E4%B8%8A%E6%89%8B/%E9%85%8D%E7%BD%AE%E6%96%87%E4%BB%B6">官方文档</a>
href="https://astrbot.soulter.top/">官方文档</a>
<a
href="https://qm.qq.com/cgi-bin/qm/qr?k=EYGsuUTfe00_iOu9JTXS7_TEpMkXOvwv&jump_from=webapi&authKey=uUEMKCROfsseS+8IzqPjzV3y1tzy4AkykwTib2jNkOFdzezF9s9XknqnIaf3CDft">加群询问</a></small>
</div>
@@ -204,7 +212,7 @@ export default {
let tmpl = this.metadata[group_name]['metadata'][config_item_name]['config_template'][val];
let new_tmpl_cfg = JSON.parse(JSON.stringify(tmpl));
new_tmpl_cfg.id = "new_" + val + "_" + this.config_data[config_item_name].length;
// new_tmpl_cfg.id = "new_" + val + "_" + this.config_data[config_item_name].length;
this.config_data[config_item_name].push(new_tmpl_cfg);
this.config_template_tab = this.config_data[config_item_name].length - 1;
}
+61 -1
View File
@@ -1,5 +1,7 @@
<script setup>
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
import axios from 'axios';
</script>
<template>
@@ -7,8 +9,34 @@ import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
<div
style="background-color: white; padding: 8px; padding-left: 16px; border-radius: 8px; margin-bottom: 16px; display: flex; flex-direction: row; align-items: center; justify-content: space-between;">
<h4>控制台</h4>
<v-dialog v-model="pipDialog" width="400">
<template v-slot:activator="{ props }">
<v-btn variant="plain" v-bind="props">安装 pip </v-btn>
</template>
<v-card>
<v-card-title>
<span class="text-h5">安装 Pip </span>
</v-card-title>
<v-card-text>
<v-text-field v-model="pipInstallPayload.package" label="*库名,如 llmtuner" variant="outlined"></v-text-field>
<v-text-field v-model="pipInstallPayload.mirror" label="镜像站链接(可选)" variant="outlined"></v-text-field>
<small>如果不填镜像站链接默认使用阿里云镜像https://mirrors.aliyun.com/pypi/simple/</small>
<div>
<small>{{ status }}</small>
</div>
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="blue-darken-1" variant="text" @click="pipInstall" :loading="loading">
安装
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
</div>
<ConsoleDisplayer style="height: calc(100vh - 160px); "/>
<ConsoleDisplayer style="height: calc(100vh - 160px); " />
</div>
</template>
<script>
@@ -17,6 +45,36 @@ export default {
components: {
ConsoleDisplayer
},
data() {
return {
pipDialog: false,
pipInstallPayload: {
package: '',
mirror: ''
},
loading: false,
status: ''
}
},
methods: {
pipInstall() {
this.loading = true;
axios.post('/api/update/pip-install', this.pipInstallPayload)
.then(res => {
this.status = res.data.message;
setTimeout(() => {
this.status = '';
this.pipDialog = false;
}, 2000);
})
.catch(err => {
this.status = err.response.data.message;
}).finally(() => {
this.loading = false;
});
}
}
}
</script>
@@ -26,10 +84,12 @@ export default {
from {
opacity: 0;
}
to {
opacity: 1;
}
}
.fade-in {
animation: fadeIn 0.2s ease-in-out;
}
+21 -14
View File
@@ -1,7 +1,7 @@
<script setup>
import ExtensionCard from '@/components/shared/ExtensionCard.vue';
import ConfigDetailCard from '@/components/shared/ConfigDetailCard.vue';
import WaitingForRestart from '@/components/shared/WaitingForRestart.vue';
import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
import axios from 'axios';
@@ -9,8 +9,8 @@ import axios from 'axios';
<template>
<v-row>
<v-alert style="margin: 16px" text="1. 如果因为网络问题安装失败,可以前往 配置->其他配置->插件仓库镜像 修改安装镜像源。2. 如需插件帮助请点击 `仓库` 查看 README"
title="💡提示" type="info" variant="tonal">
<v-alert style="margin: 16px" text="1. 如果因为网络问题安装失败,可以自行前往仓库下载压缩包,然后从本地上传。2. 如需插件帮助请点击 `仓库` 查看 README"
title="💡提示" type="info" variant="tonal">
</v-alert>
<v-col cols="12" md="12">
<div style="background-color: white; width: 100%; padding: 16px; border-radius: 10px;">
@@ -52,11 +52,17 @@ import axios from 'axios';
<v-btn v-else variant="plain" disabled>已安装</v-btn>
</div>
</ExtensionCard>
</v-col>
<v-col style="margin-bottom: 16px;" cols="12" md="12">
<small ><a href="https://astrbot.app/dev/plugin.html">插件开发文档</a></small> |
<small> <a href="https://github.com/Soulter/AstrBot_Plugins_Collection">提交插件仓库</a></small>
</v-col>
</v-row>
<v-dialog v-model="configDialog" width="750">
<v-dialog v-model="configDialog" width="1000">
<template v-slot:activator="{ props }">
</template>
<v-card>
@@ -65,7 +71,8 @@ import axios from 'axios';
</v-card-title>
<v-card-text>
<v-container>
<ConfigDetailCard :config="extension_config"></ConfigDetailCard>
<AstrBotConfig v-if="extension_config.metadata" :metadata="extension_config.metadata" :iterable="extension_config.config" :metadataKey=curr_namespace></AstrBotConfig>
<p v-else>这个插件没有配置</p>
</v-container>
</v-card-text>
<v-card-actions>
@@ -80,7 +87,7 @@ import axios from 'axios';
</v-card>
</v-dialog>
<v-dialog v-model="dialog" persistent width="700">
<v-dialog v-model="dialog" width="700">
<template v-slot:activator="{ props }">
<v-btn v-bind="props" icon="mdi-plus" size="x-large" style="position: fixed; right: 52px; bottom: 52px;"
color="darkprimary">
@@ -172,9 +179,9 @@ export default {
name: 'ExtensionPage',
components: {
ExtensionCard,
ConfigDetailCard,
WaitingForRestart,
ConsoleDisplayer
ConsoleDisplayer,
AstrBotConfig
},
data() {
return {
@@ -189,7 +196,10 @@ export default {
snack_success: "success",
loading_: false,
configDialog: false,
extension_config: {},
extension_config: {
"metadata": {},
"config": {}
},
upload_file: null,
pluginMarketData: {},
loadingDialog: {
@@ -364,7 +374,7 @@ export default {
openExtensionConfig(extension_name) {
this.curr_namespace = extension_name;
this.configDialog = true;
axios.get('/api/config/get?namespace=' + extension_name).then((res) => {
axios.get('/api/config/get?plugin_name=' + extension_name).then((res) => {
this.extension_config = res.data.data;
console.log(this.extension_config);
}).catch((err) => {
@@ -372,10 +382,7 @@ export default {
});
},
updateConfig() {
axios.post('/api/config/plugin/update', {
"config": this.extension_config,
"namespace": this.curr_namespace
}).then((res) => {
axios.post('/api/config/plugin/update?plugin_name='+this.curr_namespace, this.extension_config.config).then((res) => {
if (res.data.status === "ok") {
this.toast(res.data.message, "success");
this.$refs.wfr.check();

Some files were not shown because too many files have changed in this diff Show More