Compare commits

...

276 Commits

Author SHA1 Message Date
Soulter 062af1ac08 🎈 perf: 优化 WebUI 日志错误处理 2025-04-07 10:38:03 +08:00
Soulter 79d38f9597 📦release: v3.5.2 2025-04-06 22:36:31 +08:00
Soulter 4d186baa35 Merge pull request #1128 from anka-afk/anka-dev
feature: 实现了 #1127 还有 #1133 还有 #1143
2025-04-06 22:22:01 +08:00
Raven95676 e54eaab842 将验证器字典移到类级别,避免重复创建 2025-04-05 21:19:53 +08:00
Raven95676 43b6297b5d reminder将时区设置移入try块,统一为self.timezone 2025-04-05 21:08:52 +08:00
Raven95676 c20f4f5adf 删除默认值,调整logger逻辑 2025-04-05 21:03:02 +08:00
Soulter dc1f222cd2 fix: 使用 zoneinfo 替代 tzinfo; 默认不设置时区(使用系统默认时区) 2025-04-05 17:27:46 +08:00
Soulter c2b687212c cleanup 2025-04-05 16:51:06 +08:00
Soulter 849913276d 🎈 perf: 钉钉支持 Markdown 渲染输出
fixes: #1104
2025-04-05 16:29:14 +08:00
Soulter 23579c1e4a 🐛 fix: 阿里百炼应用无法多轮会话
fixes: #1123
2025-04-05 16:21:41 +08:00
Soulter e031161fd4 🐛 修复: 移除文本输入框的 auto-grow 属性
fixes: #1038
2025-04-05 15:58:17 +08:00
Soulter 4800ee6c0a Merge pull request #1152 from AstrBotDevs/feat-log-filter
 feat: 更新日志发布机制,支持日志级别和内容的字典格式,增加日志筛选功能
2025-04-05 15:49:09 +08:00
Soulter d3a7fef9b0 🐛 修复: 移除多余的 console 语句 2025-04-05 15:46:45 +08:00
Soulter 40822fe77a feat: 更新日志发布机制,支持日志级别和内容的字典格式,增加日志筛选功能
fixes: #1010
2025-04-05 15:43:40 +08:00
Soulter 837b670213 feat(webui): 支持修改列表项
fixes: #1086
2025-04-05 15:10:44 +08:00
Soulter 57ce69f3fb feat: WebChat 支持语音输出
fixes: #1087
2025-04-05 15:02:34 +08:00
anka be022c4894 fix: add StarTools to api 2025-04-05 11:55:25 +08:00
anka 8a366964bb feature: 增加时区设置支持 2025-04-05 11:52:51 +08:00
anka ee86b68470 fix: 漏加classmethod了! 2025-04-05 01:15:56 +08:00
anka 60352307aa fix: 重生之我要苦读设计模式, 终于知道怎么整了哈哈哈: 使用静态类实现工具集合, 并且正确初始化 2025-04-05 01:11:10 +08:00
anka 3ebd2f746f feature: 添加插件工具类, 暂时这么多 2025-04-05 00:51:52 +08:00
anka 1c1a65b637 fix: 全部消息段的检验弄好了! 2025-04-05 00:21:28 +08:00
anka 010e60d029 Merge remote-tracking branch 'origin/HEAD' into anka-dev 2025-04-04 23:13:43 +08:00
Soulter 7a25568861 Merge pull request #1131 from AliveGh0st/feature/gemini-safety-settings
feature:增加对Gemini系列模型的安全设置参数支持
2025-04-04 21:22:58 +08:00
AliveGh0st 5f4f913661 feat: 增加对 Gemini 系列模型的输入安全设置参数支持
fixes: #216

Squashed:

Update astrbot/core/config/default.py

描述更正.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

🎨 style: clean up

🐛 fix: 修复安全设置参数的默认值为列表
2025-04-04 21:12:51 +08:00
Soulter ccd0e34a53 Merge pull request #1145 from AstrBotDevs/feat-telegram-markdownv2
 feat: 支持 Telegram MarkdownV2 渲染
2025-04-04 20:54:04 +08:00
Soulter 72f1ffccd3 feat: 支持 Telegram MarkdownV2 渲染
fixes: #649 #907
2025-04-04 20:52:22 +08:00
Soulter ea7a52945f Merge pull request #1132 from Captain-Slacker-OwO/dify-md
docs: 更新 Dify 平台链接为官方域名
2025-04-04 01:12:19 +08:00
Soulter 89d4d1351a Merge pull request #1135 from AstrBotDevs/feat-dashscope-tts
feat: 支持阿里云百炼 TTS
2025-04-04 01:03:36 +08:00
Soulter b757c91d93 🐛 fix: 修复无法识别到函数调用异常的问题 2025-04-04 01:02:39 +08:00
Soulter 27203d7a4d 🐛 fix: update voice key name 2025-04-04 00:47:50 +08:00
Soulter 9ad4e18ac5 feat: 支持阿里云百炼 TTS 2025-04-04 00:32:37 +08:00
anka fcdc8f3ce7 Merge remote-tracking branch 'origin/HEAD' into anka-dev 2025-04-03 21:57:24 +08:00
Captain-Slacker-OwO 78b994b84a docs: 更新 Dify 平台链接为官方域名
将 README 文件中的 Dify 平台链接从旧域名更新为官方域名 dify.ai,确保文档的准确性和权威性。
2025-04-03 19:00:44 +08:00
Soulter 58bfc677e2 🐛 fix: dify error Arg user must be provided
fixes #1073
2025-04-03 16:49:05 +08:00
Soulter 7d17285a0c 🐛 fix: ensure whitelist entries are stripped of whitespace and converted to strings 2025-04-03 16:44:37 +08:00
Soulter e9eb00a0d4 feat: 插件市场帮助按钮 2025-04-03 16:19:01 +08:00
anka 48d07af574 feature(fix?): 在发送消息之前统一检查消息内容是否为空, 不允许发送空消息, 以解决该消息内容不支持查看以及gemini返回<empty content>问题 2025-04-03 11:50:12 +08:00
Soulter 2fc62efd88 Merge pull request #1116 from AstrBotDevs/feat-log-sse
🏗 refactor: log 通信使用 SSE 替代 Websockets
2025-04-02 21:07:40 +08:00
Soulter be516d75bd 🐛 fix: upadte method name 2025-04-02 21:06:59 +08:00
Soulter 951d5fde85 🏗 refactor: log 通信使用 SSE 替代 Websockets 2025-04-02 20:59:25 +08:00
Soulter 1389abc052 Merge pull request #1112 from AstrBotDevs/fix-aiocqhttp-empty-plain
修复 aiocqhttp 适配器下空白 plain 导致的报错
2025-04-02 16:27:12 +08:00
Soulter 19ad67a77f 🐛 fix: 修复 aiocqhttp 适配器下空白 plain 导致的 the object is not a proper segment chain 报错问题 2025-04-02 16:24:36 +08:00
Soulter 641f308344 Update README.md 2025-04-01 11:35:56 +08:00
Soulter 9f097fa4d5 Update README.md 2025-04-01 11:33:38 +08:00
Soulter 5ad362c52b Merge pull request #1081 from anka-afk/anka-dev
fix #1074 and add some comment
2025-04-01 10:57:40 +08:00
Soulter 614f238a61 Merge pull request #1072 from zhx8702/feat-add-plugin-md-dialog
feat: 安装完插件后自动弹出插件仓库 README 对话框
2025-04-01 10:56:24 +08:00
zhx dec91950bc feat: 安装完插件后自动弹出插件仓库 README 对话框 2025-04-01 10:04:04 +08:00
anka 6cef9c23f0 bug fix: #1074 修改最多携带对话数量时出现bug 2025-03-31 22:41:23 +08:00
anka 3f568bf136 Merge remote-tracking branch 'origin/HEAD' into anka-dev 2025-03-31 22:32:40 +08:00
anka 5484b421ce perf: 增加部分注释 2025-03-31 22:30:43 +08:00
Soulter 02f21e07d3 📦 release: v3.5.1 2025-03-31 10:59:32 +08:00
Soulter fff1f23a83 Update README.md 2025-03-31 00:57:23 +08:00
Soulter a056ec0d38 Merge pull request #1065 from AstrBotDevs/perf-openai-source-balance
🎈 perf: OpenAI sources supports api key load balance(random)
2025-03-30 22:53:27 +08:00
Soulter 2eb9e5dde3 perf: 添加重试等待 2025-03-30 22:51:34 +08:00
渡鸦95676 627d2a4701 新增重试间隔 2025-03-30 22:33:21 +08:00
Soulter 76895fe86d chore: improve variable names 2025-03-30 22:12:34 +08:00
Soulter 64c3c85780 Merge pull request #1056 from Raven95676/master
perf: 优化无对话情况下设置人格的反馈;若禁用提供商,自动切换到另一个可用的提供商
2025-03-30 22:10:23 +08:00
Soulter 7288348857 🎈 perf: OpenAI sources supports api key load balance(random) 2025-03-30 22:00:45 +08:00
Soulter 62e73299b1 🐛 fix: forcely write shared preference data
Note: this is a fast fix for recent feedbacks, we'll improve its performance.
2025-03-30 21:33:41 +08:00
Raven95676 fe76c41ed8 perf: 若禁用提供商,自动切换到另一个可用的提供商 2025-03-30 15:18:48 +08:00
Raven95676 1a92edf8be perf: 优化无对话情况下设置人格的反馈 2025-03-30 14:38:40 +08:00
Soulter b63b606a4e docs: 推荐使用 uv 进行手动部署 2025-03-30 10:39:14 +08:00
Soulter 8e2ef3d22b Merge pull request #1050 from advent259141/master
回复空@功能的修复
2025-03-30 00:15:26 +08:00
Gao Jinzhe c6c4a32283 Add files via upload 2025-03-29 22:37:18 +08:00
Soulter b70b3b158e feat: 支持 gemini-2.0-flash-exp-image-generation 对图片模态的输入 #1017 2025-03-29 20:51:27 +08:00
Soulter 3d59ab8108 fix: conversation and tool use page refresh 404 2025-03-29 19:17:56 +08:00
Soulter b6c3089510 🎈 perf: 优化空 at 回复 2025-03-29 19:09:35 +08:00
Soulter bd92aac280 feat: 支持 /llm 指令快捷启停 LLM 功能 #296 2025-03-29 18:31:07 +08:00
Soulter 5299e802e9 Merge pull request #1046 from AstrBotDevs/feat-docker-embedded-ffmpeg
docker 镜像提供内置 ffmpeg
2025-03-29 17:53:40 +08:00
Soulter 8e5a57d7dd Merge pull request #1045 from Raven95676/master
在lifecycle新增插件资源清理逻辑
2025-03-29 17:53:16 +08:00
Soulter beaa324fb6 Merge pull request #1012 from Zhenyi-Wang/master
feat: gewechat client增加获取通讯录列表接口
2025-03-29 17:51:35 +08:00
Soulter 79e64fe206 Merge pull request #1011 from left666/left666
feat(core): 在 MessageChain 类中添加 at 和 at_all 方法
2025-03-29 17:50:55 +08:00
Soulter 93f525e3fe 🎈 perf: edge tts 支持使用代理;移除了一些不需要的方法 2025-03-29 17:48:22 +08:00
Soulter aacb803c64 Merge pull request #999 from Futureppo/master
部分api获取不到model导致key泄露,使用正则表达式过滤掉key内容
2025-03-29 17:43:10 +08:00
Soulter 8a0665b222 🎈 feat: 更新 Dockerfile,添加 Node.js 支持并优化依赖安装 2025-03-29 17:42:31 +08:00
Soulter 20e41a7f73 🐛 fix: newgroup 指令名显示错误 2025-03-29 17:42:31 +08:00
Soulter 93a1699a35 Update README.md 2025-03-29 17:42:31 +08:00
Soulter c33c07e4af Update README.md 2025-03-29 17:42:31 +08:00
Soulter c7484d0cc9 Update README.md 2025-03-29 17:42:31 +08:00
Soulter fb85a7bb35 feat: add demo mode 2025-03-29 17:42:31 +08:00
Soulter 42ff9a4d34 Update README.md 2025-03-29 17:42:31 +08:00
Soulter 005e9eae7c 🐛 fix: 插件更新时没有正确应用加速地址 2025-03-29 17:42:31 +08:00
Soulter 3e325debcc Update README.md 2025-03-29 17:42:31 +08:00
Soulter a221de9a2b 🐛 fix: 修复 LLM 响应后事件钩子无法生效的问题 2025-03-29 17:42:31 +08:00
Soulter 32b0cc1865 Update README.md 2025-03-29 17:42:31 +08:00
Soulter bbf85f8a12 🐛 fix: remove error logging for empty result and refresh extensions after upload 2025-03-29 17:42:31 +08:00
Soulter 67a0172b28 📦 release: v3.5.0 2025-03-29 17:42:31 +08:00
zhx fb19d4d45b fix: install_plugin_from_file 方法load传参数改为文件名 2025-03-29 17:42:31 +08:00
Soulter a156b1af14 feat: 支持通过指令下载插件 /plugin get 2025-03-29 17:42:31 +08:00
Soulter a604b4943c 🎈 perf: 优化新版本时的信息显示 2025-03-29 17:42:31 +08:00
pre-commit-ci[bot] 3f0b6435d9 🎈 auto fixes by pre-commit hooks 2025-03-29 17:42:31 +08:00
Gao Jinzhe e0f029e2cb Add files via upload 2025-03-29 17:42:31 +08:00
Soulter 89d3fd5fab 🎈 perf: 优化 WebUI 对话数据库中文历史检索 2025-03-29 17:42:31 +08:00
Soulter a38b00be6b 🐛 fix: 修复部分可能形成 SQL 注入的风险 2025-03-29 17:42:31 +08:00
Futureppo 0e8d52b591 :ballon: feat: 使用正则表达式过滤掉 /model 可能暴露的 api_key
Squashed:

更新正则表达式

🎈 auto fixes by pre-commit hooks

Update main.py

Update main.py

chore: bugfixes
2025-03-29 17:40:48 +08:00
Soulter 298c77740d feat: docker 镜像提供内置 ffmpeg #979 2025-03-29 17:26:57 +08:00
Raven95676 c681aae8ee 修复日志问题 2025-03-29 17:25:38 +08:00
Raven95676 faef98b089 在lifecycle新增插件资源清理逻辑 2025-03-29 17:07:12 +08:00
Soulter 84a3e0a30b 🎈 feat: 更新 Dockerfile,添加 Node.js 支持并优化依赖安装 2025-03-29 16:36:02 +08:00
Soulter 69bd553ce0 Merge pull request #1035 from AstrBotDevs/fix-1034-bug
🐛 fix: groupnew 指令名显示错误
2025-03-28 23:46:30 +08:00
Soulter fd0c0f8975 🐛 fix: newgroup 指令名显示错误 2025-03-28 23:45:19 +08:00
Zhenyi-Wang 860ceb06b4 Merge branch 'Soulter:master' into master 2025-03-28 21:27:25 +08:00
anka ecf501bf72 Merge remote-tracking branch 'origin/HEAD' into anka-dev 2025-03-28 19:04:35 +08:00
Soulter 81a2ed1e25 Update README.md 2025-03-28 18:20:33 +08:00
Soulter 76ab28338a Update README.md 2025-03-28 13:24:41 +08:00
Soulter 9a56c9630f Update README.md 2025-03-28 13:23:29 +08:00
anka 53b9497c18 perf: 增加部分注释 2025-03-27 21:32:38 +08:00
Soulter 750b16b6ee feat: add demo mode 2025-03-27 15:54:23 +08:00
anka 0ee3e0779a Merge remote-tracking branch 'origin/HEAD' into anka-dev 2025-03-27 15:21:04 +08:00
pre-commit-ci[bot] 333c2d9299 🎈 auto fixes by pre-commit hooks 2025-03-27 03:21:43 +00:00
Zhenyi Wang ad37ff5048 feat: gewechat client增加获取通讯录列表接口 2025-03-27 11:17:52 +08:00
pre-commit-ci[bot] 33f86f3bde 🎈 auto fixes by pre-commit hooks 2025-03-27 02:56:55 +00:00
Soulter 8acb969a49 Update README.md 2025-03-27 10:39:18 +08:00
left666 b74b5933b8 feat(core): 在 MessageChain 类中添加 at 和 at_all 方法
- 新增 at 方法,用于添加 At 消息到消息链中
- 新增 at_all 方法,用于添加 AtAll 消息到消息链中
2025-03-27 10:30:19 +08:00
Soulter 681c556b7e 🐛 fix: 插件更新时没有正确应用加速地址 2025-03-27 10:04:40 +08:00
anka 1746684e52 perf: 修改部分注释 2025-03-26 23:52:03 +08:00
Soulter 0b93d06555 Update README.md 2025-03-26 20:51:53 +08:00
anka 8a8b8c7c27 Merge remote-tracking branch 'origin/master' into anka-dev 2025-03-26 17:59:53 +08:00
anka 6b6577006d perf: 格式化 2025-03-26 17:59:30 +08:00
Soulter 23ee5e81c9 🐛 fix: 修复 LLM 响应后事件钩子无法生效的问题 2025-03-26 17:56:55 +08:00
Soulter 483f55e4b1 Update README.md 2025-03-26 16:16:03 +08:00
Soulter 1bb1bc2553 🐛 fix: remove error logging for empty result and refresh extensions after upload 2025-03-26 15:43:56 +08:00
Soulter a4e4e36f94 📦 release: v3.5.0 2025-03-26 15:30:09 +08:00
Soulter 6849415812 Merge pull request #996 from zhx8702/fix-star-manager
fix: install_plugin_from_file 方法load传参数改为文件名
2025-03-26 15:26:53 +08:00
zhx 86f6cb038e fix: install_plugin_from_file 方法load传参数改为文件名 2025-03-26 15:06:33 +08:00
Soulter 7480a1d6ce feat: 支持通过指令下载插件 /plugin get 2025-03-26 14:33:45 +08:00
Soulter 3cd10117dd 🎈 perf: 优化新版本时的信息显示 2025-03-26 14:14:01 +08:00
Soulter 0caf19d390 Merge pull request #937 from advent259141/master
将对只有一个 @ 的消息内容的处理改成调用llm回复
2025-03-26 13:54:43 +08:00
anka 5c14ebb049 Merge remote-tracking branch 'origin/master' into anka-dev 2025-03-26 13:53:21 +08:00
anka 9717a736b1 perf: 更新部分描述 2025-03-26 13:50:54 +08:00
Soulter 9c9ab50d1a 🎈 perf: 优化 WebUI 对话数据库中文历史检索 2025-03-26 13:50:11 +08:00
Soulter d4bcb8174e 🐛 fix: 修复部分可能形成 SQL 注入的风险 2025-03-26 13:41:18 +08:00
anka 9e7fe773bd perf: 更新部分注释 2025-03-26 11:14:46 +08:00
Soulter aca18fab0f feat: 优化配置文件中的提示信息,增强可读性 2025-03-26 00:56:51 +08:00
Soulter 691de01b79 feat: 支持设置最多携带对话数量 2025-03-26 00:46:15 +08:00
Soulter 3383f15142 Merge pull request #988 from Soulter/NiceAir/master
 feat: Update UI elements and improve layout in various components
2025-03-25 23:17:11 +08:00
Soulter 84c1593889 feat: Update UI elements and improve layout in various components 2025-03-25 21:52:15 +08:00
Soulter 3c80fa1e33 Update README.md 2025-03-25 21:31:23 +08:00
Soulter 06b16a1deb Merge pull request #983 from Soulter/feat-conversation-webui-mgr
 支持 WebUI 对话管理
2025-03-25 21:26:00 +08:00
Soulter 4c4246fb09 Merge pull request #982 from NiceAir/master
添加对gewe的表情包、引用消息、视频的支持
2025-03-25 21:25:00 +08:00
Soulter 364be1e9f6 🐛 fix: Handle missing defusedxml dependency for Gewechat message parsing 2025-03-25 21:21:38 +08:00
NiceAir f959ed71aa feat: Gewechat 支持表情包、引用消息、视频
Co-authored-by: Soulter <905617992@qq.com>
2025-03-25 21:00:12 +08:00
anka 5c4326c302 perf: 部分详细注释, 符合PEP8标准 2025-03-25 20:53:23 +08:00
Soulter 125fc3a622 feat: 支持 WebUI 对话管理 2025-03-25 19:44:46 +08:00
Soulter 6b9e785db3 Merge pull request #968 from Soulter/pre-commit-ci-update-config
🎈 pre-commit autoupdate
2025-03-25 15:03:39 +08:00
Soulter 25d34e9a43 Merge pull request #974 from zhx8702/feat-webui-add-search-keys
feat: 插件市场列表卡片过滤条件提出变量保持一致
2025-03-25 15:03:09 +08:00
Soulter 457d4aa1dc Merge pull request #976 from Raven95676/master
Improves Telegram adapter termination
2025-03-25 15:01:04 +08:00
Raven95676 ff0c0992ff Improves Telegram adapter termination 2025-03-25 14:46:20 +08:00
Soulter d379e012c4 🐛 fix: telegram /start issue #751 2025-03-25 14:03:46 +08:00
zhx 151fff26fd feat: 插件市场列表卡片过滤条件提出变量保持一致 2025-03-25 13:50:16 +08:00
Soulter 3d0d561215 Update compose.yml 2025-03-25 13:24:37 +08:00
Soulter 22d586ed7b Update compose.yml 2025-03-25 13:24:19 +08:00
Soulter 6dc19b29e8 🐛 fix: remove redundant validation call in config validation function #901 2025-03-25 12:56:48 +08:00
Soulter 50975a87d4 🐛 fix: handle message sending failures with error logging 2025-03-25 12:34:43 +08:00
Soulter ce721d9f0f 🐛 fix: platform adapter server blocks ctrl+c 2025-03-25 11:31:46 +08:00
Soulter 20510a33f7 feat: improve pyproject and use uv as package mgr 2025-03-25 11:07:20 +08:00
pre-commit-ci[bot] 3abd9c8763 🎈 pre-commit autoupdate
updates:
- [github.com/astral-sh/ruff-pre-commit: v0.11.0 → v0.11.2](https://github.com/astral-sh/ruff-pre-commit/compare/v0.11.0...v0.11.2)
2025-03-24 17:08:12 +00:00
Soulter e9eff7420b feat: 更加完善和美观的 本地 Markdown 渲染 2025-03-25 00:56:19 +08:00
Soulter 64c250c9d8 🎈perf: 优化可能的 conversation 为 None 的问题 2025-03-25 00:06:25 +08:00
Soulter 8047f82bfd 🎈perf: 优化删除插件目录的逻辑,抛出异常细节;完善 mcp 未安装时的提示 2025-03-24 23:07:56 +08:00
Soulter af6467fb3d Merge pull request #962 from zhx8702/feat-webui-add-double-confirm
feat: 删除插件添加二次确认,插件列表添加非空判断
2025-03-24 23:01:43 +08:00
zhx 3ff1664aec feat: 删除多余代码 2025-03-24 20:27:05 +08:00
zhx 34ea2b44b8 Merge remote-tracking branch 'upstream/master' into feat-webui-add-double-confirm 2025-03-24 19:42:47 +08:00
Soulter 6c8d851109 Merge pull request #955 from Raven95676/master
Telegram适配器消息处理功能增强
2025-03-24 18:10:51 +08:00
Soulter d678299a74 Merge branch 'master' into master 2025-03-24 18:10:27 +08:00
Soulter 7aed0db2b6 Merge pull request #951 from IGCrystal/master
fix: fix SSLCertVerificationError
2025-03-24 18:05:49 +08:00
Soulter 0355524345 Merge branch 'master' into master 2025-03-24 17:58:00 +08:00
Soulter 0a43e4672e style: format codes 2025-03-24 17:57:28 +08:00
zhx 71e0ccdfec feat: 删除插件添加二次确认,插件列表添加非空判断 2025-03-24 16:41:54 +08:00
冰苷晶 1df33ac3c8 fix: fix error 2025-03-24 13:28:14 +08:00
pre-commit-ci[bot] 7334090ac1 🎈 auto fixes by pre-commit hooks 2025-03-24 05:20:37 +00:00
冰苷晶 6b0f044198 fix: fix other errors 2025-03-24 13:20:05 +08:00
pre-commit-ci[bot] ddf54c9cf8 🎈 auto fixes by pre-commit hooks 2025-03-24 04:32:21 +00:00
IGCrystal 7c64e184e2 Merge branch 'Soulter:master' into master 2025-03-24 12:32:16 +08:00
渡鸦95676 a904db033c Merge branch 'Soulter:master' into master 2025-03-24 12:19:17 +08:00
渡鸦95676 b234856b02 Remove unused variable
移除以通过ruff检查
在Ubuntu24.04LTS中,移除未见对现有功能的影响
2025-03-24 11:36:46 +08:00
Soulter 89d51d2afc 🎈 perf: config UI 2025-03-24 11:36:38 +08:00
Soulter 37cb9678e9 Merge pull request #826 from XuYingJie-cmd/master
新增了关于gewe发送视频的功能
2025-03-24 11:25:24 +08:00
pre-commit-ci[bot] 0500ff333a 🎈 auto fixes by pre-commit hooks 2025-03-24 02:50:28 +00:00
Raven95676 08528510ef Fix incorrect handling of reply messages within topics 2025-03-24 10:41:33 +08:00
Raven95676 ddbd03dc1e Adds sticker handling in Telegram adapter 2025-03-24 10:40:20 +08:00
Soulter ade87f378a 🎈 perf: UI 优化 2025-03-24 00:32:40 +08:00
冰苷晶 4db14b905f fix: fix error 2025-03-23 23:40:06 +08:00
pre-commit-ci[bot] b669b31451 🎈 auto fixes by pre-commit hooks 2025-03-23 15:07:22 +00:00
冰苷晶 1cb2b62f81 fix: fix error 2025-03-23 23:02:34 +08:00
Soulter e5828713cf 🎈 perf: improve ChatPage and ConfigPage UI 2025-03-23 22:57:02 +08:00
冰苷晶 d10cb84068 fix: fix SSLCertVerificationError 2025-03-23 22:55:07 +08:00
Soulter 4222f8516f Merge pull request #844 from AraragiEro/mcp_adapt
支持 MCP 服务并优化函数调用流程
2025-03-23 22:35:35 +08:00
Soulter 7f998c7611 chore: remove useless print output 2025-03-23 22:28:00 +08:00
Soulter db46000337 🎨 style: format codes 2025-03-23 22:22:11 +08:00
Soulter 1aac8d8041 feat: 适配完整的 function-calling 流程 2025-03-23 22:21:47 +08:00
Soulter c59c8e05f7 🐛 fix: tools result 2025-03-23 17:03:18 +08:00
Soulter 4942d0a629 feat: 在工具使用页面添加函数调用信息提示和链接功能 2025-03-23 17:00:38 +08:00
Soulter 873b7715f4 🎈 perf: 优化 MCP Client 异步 Event 管理 2025-03-23 16:51:28 +08:00
pre-commit-ci[bot] 98e7ed6920 🎈 auto fixes by pre-commit hooks 2025-03-23 08:34:05 +00:00
Soulter 046f5e645e feat: 完善 MCP 管理和实现 WebUI MCP 相关的页面 2025-03-23 16:33:44 +08:00
pre-commit-ci[bot] f5e5a7094c 🎈 auto fixes by pre-commit hooks 2025-03-23 06:39:13 +00:00
Gao Jinzhe 154125fee6 Add files via upload 2025-03-23 14:35:44 +08:00
pre-commit-ci[bot] 9f8e960ebe 🎈 auto fixes by pre-commit hooks 2025-03-23 03:31:20 +00:00
Soulter 4179b0be0a chore: 优化注解格式和 requirements.txt 2025-03-23 11:31:10 +08:00
Soulter 28bafa38db Merge branch 'master' into mcp_adapt 2025-03-23 11:01:44 +08:00
Soulter b07552565e Merge pull request #926 from Soulter/perf-graceful-shutdown
支持所有消息平台的优雅退出
2025-03-23 10:56:56 +08:00
Soulter c4427471d2 🎨 style: format codes 2025-03-23 00:25:26 +08:00
Soulter 08f81c6784 🐛 fix: 修复图片没有被存储到上下文中的问题 2025-03-23 00:23:42 +08:00
Soulter a471e98aca 🐛 fix: Telegram 下无法识别图片描述(Caption) #910 2025-03-23 00:23:01 +08:00
Soulter 75a8fcc8a0 🐛 fix: 修复 Telegram 下非默认群组话题引用消息异常 #906 2025-03-22 23:39:21 +08:00
Soulter 46ef76c168 feat: 支持消息平台的热重载 2025-03-22 19:54:54 +08:00
Soulter 66637446c9 Merge remote-tracking branch 'origin/master' into perf-graceful-shutdown 2025-03-22 19:26:35 +08:00
Soulter 21efeb888a Merge pull request #904 from LunarMeal/master
新增了newgroup指令
2025-03-22 19:18:06 +08:00
Soulter a4ee8b5322 Merge remote-tracking branch 'origin/master' into LunarMeal/master 2025-03-22 19:17:12 +08:00
Soulter 36519ac47e 🐛 fix: groupnew 设置为管理员指令 2025-03-22 19:14:58 +08:00
Soulter 3f514fceca 🎨 style: format codes 2025-03-22 19:07:47 +08:00
pre-commit-ci[bot] c2249fdfac 🎈 auto fixes by pre-commit hooks 2025-03-22 11:06:42 +00:00
Soulter c610719a44 feat: 为各平台适配器支持优雅关闭 2025-03-22 19:02:49 +08:00
Soulter 36a6c2461a 🐛 fix: 修复 Telegram Topic 群组下LLM 上下文及主动消息混乱的问题 #908 2025-03-22 18:15:43 +08:00
Soulter c29f22c39e Update PLUGIN_PUBLISH.yml 2025-03-22 15:51:35 +08:00
Soulter 30d3062944 🎈 perf: 优化钉钉在配置错误之后堵塞整个线程的问题 #885
a.k.a 帮钉钉擦屁股
2025-03-22 15:44:42 +08:00
Soulter 69ba75abf4 Update README.md 2025-03-22 01:26:03 +08:00
Soulter e4d486fec5 docs: 宝塔面板部署方式 2025-03-22 00:42:04 +08:00
Soulter f242144dcf 更新 README.md 2025-03-21 19:21:35 +08:00
Soulter 02dee2d664 🎈 perf: add error handling for missing pyffmpeg library in video sending functionality 2025-03-21 16:51:23 +08:00
Soulter a3dd2c3069 Merge remote-tracking branch 'origin/master' into XuYingJie-cmd/master 2025-03-21 16:49:15 +08:00
Soulter a23425e8aa Merge pull request #781 from Moyuyanli/master
添加gewe的群相关操作
2025-03-21 16:31:10 +08:00
Moyuyanli be79ddc9a3 fix:去掉跟post_text功能相同的接口方法 2025-03-21 16:24:31 +08:00
Soulter 7d71015e8c Update README.md 2025-03-21 16:12:25 +08:00
Soulter ad54549b51 Update README.md 2025-03-21 15:58:40 +08:00
Soulter 6cf032a164 Update compose.yml 2025-03-21 11:06:22 +08:00
Soulter 6390d796ac Update compose.yml 2025-03-21 11:05:44 +08:00
Soulter 98b8411905 Update compose.yml 2025-03-21 10:53:09 +08:00
LunarMeal ddf1029afa Merge branch 'master' of https://github.com/LunarMeal/AstrBot 2025-03-20 22:53:29 +08:00
LunarMeal 1effbc5cc9 fix 2025-03-20 22:53:21 +08:00
pre-commit-ci[bot] 414b645e9f 🎈 auto fixes by pre-commit hooks 2025-03-20 14:42:37 +00:00
LunarMeal 398c76f496 新增了newgroup指令 2025-03-20 22:39:49 +08:00
Soulter 1bc456dd95 🎈 perf: 改善一些术语描述 2025-03-20 20:31:36 +08:00
Soulter 2e8421884e Merge pull request #864 from Soulter/pre-commit-ci-update-config
🎈 pre-commit autoupdate
2025-03-20 20:23:45 +08:00
Soulter 70d9b193ac 🐛 fix: 修复私聊下 get_group 的一些问题 2025-03-20 20:18:20 +08:00
Moyuyanli b49c11004a fix:还原回原来的依赖信息 2025-03-20 19:57:35 +08:00
Soulter 34843eea90 🎨 style: format codes 2025-03-20 18:07:24 +08:00
pre-commit-ci[bot] 2d6d7f31e8 🎈 auto fixes by pre-commit hooks 2025-03-20 10:06:11 +00:00
Soulter 7a24cbff1c feat: 支持 aiocqhttp 适配器下的获取群消息 2025-03-20 18:05:44 +08:00
pre-commit-ci[bot] 1e7eb2cf1c 🎈 auto fixes by pre-commit hooks 2025-03-20 09:21:32 +00:00
Soulter 361256e016 chore: 添加了一些 gewechat client 的注释 2025-03-20 17:20:32 +08:00
Soulter 8838dbd003 🎨 style: format codes 2025-03-20 16:54:27 +08:00
pre-commit-ci[bot] 13a95e1f2b 🎈 auto fixes by pre-commit hooks 2025-03-20 08:42:40 +00:00
Soulter 1aaa451a3e Merge branch 'master' into Moyuyanli/master 2025-03-20 16:42:13 +08:00
Soulter cbba81e54d 🐛 fix: 无法接收图片 aiocqhttp 2025-03-20 16:03:41 +08:00
Soulter 370868dfac 🎈 perf: 消息平台和配置提供商配置页中,自动更新旧的配置,添加新的配置项 2025-03-20 13:22:49 +08:00
Soulter 77f692aae2 🎈 perf: 配置项显示优化 2025-03-20 13:17:27 +08:00
Soulter 9318e205ea feat: 阿里云百炼应用支持 RAG 应用 #878 2025-03-20 13:17:06 +08:00
Soulter ebcc717c19 🎈 perf: Dify 下支持更多类型的图片输入及提高代码复用性 #893
🐛 fix: 修复飞书下无法进行图片输入的问题
2025-03-20 11:21:45 +08:00
Soulter 4c16b564ee 🎈 perf: 忽略微信团队消息 #859 2025-03-19 01:09:01 +08:00
Soulter e2283d1453 🐛 fix: 修复 dify 下某些修改了 LLM 响应的插件可能不生效的问题 #876 2025-03-19 01:05:28 +08:00
Soulter d891801c5a v3.4.39 2025-03-18 22:43:35 +08:00
Soulter de75386944 🎈 perf: 登录后检查默认密码和弹出修改警告 2025-03-18 22:41:33 +08:00
Soulter 82dc37de50 style: format codes 2025-03-18 22:21:47 +08:00
Soulter b6fa7f62dc chore: 添加安全提示信息 2025-03-18 22:18:01 +08:00
Soulter f9e0a95c5e chore: 默认地址改回 0.0.0.0 2025-03-18 22:15:22 +08:00
pre-commit-ci[bot] b2c6e12647 🎈 auto fixes by pre-commit hooks 2025-03-17 17:10:06 +00:00
pre-commit-ci[bot] caffb83780 🎈 pre-commit autoupdate
updates:
- [github.com/astral-sh/ruff-pre-commit: v0.9.10 → v0.11.0](https://github.com/astral-sh/ruff-pre-commit/compare/v0.9.10...v0.11.0)
2025-03-17 17:09:59 +00:00
Moyuyanli 2e4fef6c66 feat:添加消息记录器 2025-03-17 16:02:55 +08:00
Alero 8585cd8e21 修复codecheck 2025-03-15 20:26:17 +08:00
Alero 9fa2a7eeea 修复codecheck 2025-03-15 20:24:36 +08:00
pre-commit-ci[bot] 2d1f74228d 🎈 auto fixes by pre-commit hooks 2025-03-15 12:10:17 +00:00
Alero 3d6f7aa0e1 修复codecheck 2025-03-15 20:09:49 +08:00
pre-commit-ci[bot] 3dea60366a 🎈 auto fixes by pre-commit hooks 2025-03-15 11:54:09 +00:00
Alero d4d9a1df4c feat:新增MCP服务支持并优化工具调用逻辑
引入MCP客户端支持,增加mcp_server.json配置样例,完善工具描述生成及调用逻辑以支持MCP服务工具功能。同时调整相关逻辑以区分本地工具与MCP工具的调用方式,提升扩展性和灵活性。
2025-03-15 19:47:06 +08:00
Moyuyanli c095248176 Merge remote-tracking branch 'origin/master' 2025-03-14 18:30:42 +08:00
Moyuyanli 44601c8954 fix:修复gewe的ModContacts消息类型 2025-03-14 18:30:27 +08:00
pre-commit-ci[bot] c95682a0c7 🎈 auto fixes by pre-commit hooks 2025-03-14 09:11:21 +00:00
Moyuyanli d177b9f7fa feat:添加主动添加好友事件 2025-03-14 17:11:10 +08:00
徐英杰 9b57615d94 新增了关于gewe发送视频的功能 2025-03-14 16:19:41 +08:00
pre-commit-ci[bot] 00f5189f58 🎈 auto fixes by pre-commit hooks 2025-03-11 09:16:43 +00:00
Moyuyanli 4a8309ed1f style:idea默认格式化了部分代码
feat:添加根据消息事件获取群信息的接口
2025-03-11 17:10:55 +08:00
Moyuyanli 76cfc31a1d feat:添加 Group 类型 2025-03-11 17:10:04 +08:00
Moyuyanli d9ec434699 feat:gewe的client添加 添加好友接口
feat:gewe的client添加 获取群信息/群成员接口
feat:gewe的client添加 添加群成员为好友接口
2025-03-11 17:08:33 +08:00
138 changed files with 12692 additions and 2257 deletions
+5 -4
View File
@@ -6,7 +6,7 @@ body:
- type: markdown
attributes:
value: |
欢迎发布插件到插件市场!
欢迎发布插件到插件市场!请确保您的插件经过**完整的**测试。
- type: textarea
attributes:
@@ -22,9 +22,10 @@ body:
插件名:
插件作者:
插件简介:
标签: (可选)
社交链接: (可选, 将会在插件市场作者名称上作为可点击的链接)
description: 必填。请以列表的字段按顺序将插件名、插件作者、插件简介放在这里。
支持的消息平台:(必填,如 QQ、微信、飞书)
标签:(可选)
社交链接:(可选, 将会在插件市场作者名称上作为可点击的链接)
description: 必填。请以列表的字段按顺序将插件名、插件作者、插件简介放在这里。如果您不知道支持哪些消息平台,请填写测试过的消息平台。
- type: checkboxes
attributes:
+3 -1
View File
@@ -1,6 +1,8 @@
__pycache__
botpy.log
.vscode
.venv*
.idea
data_v2.db
data_v3.db
configs/session
@@ -26,5 +28,5 @@ venv/*
packages/python_interpreter/workplace
.venv/*
.conda/
.idea/
.idea
pytest.ini
+1 -1
View File
@@ -7,7 +7,7 @@ ci:
autoupdate_commit_msg: ":balloon: pre-commit autoupdate"
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.10
rev: v0.11.2
hooks:
- id: ruff
- id: ruff-format
+10 -2
View File
@@ -9,12 +9,20 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
python3-dev \
libffi-dev \
libssl-dev \
ca-certificates \
bash \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
RUN python -m pip install -r requirements.txt --no-cache-dir
RUN python -m pip install uv
RUN uv pip install -r requirements.txt --no-cache-dir --system
RUN uv pip install socksio uv pyffmpeg pilk --no-cache-dir --system
RUN python -m pip install socksio wechatpy cryptography --no-cache-dir
# 释出 ffmpeg
RUN python -c "from pyffmpeg import FFmpeg; ff = FFmpeg();"
# add /root/.pyffmpeg/bin/ffmpeg to PATH, inorder to use ffmpeg
RUN echo 'export PATH=$PATH:/root/.pyffmpeg/bin' >> ~/.bashrc
EXPOSE 6185
EXPOSE 6186
+35
View File
@@ -0,0 +1,35 @@
FROM python:3.10-slim
WORKDIR /AstrBot
COPY . /AstrBot/
RUN apt-get update && apt-get install -y --no-install-recommends \
gcc \
build-essential \
python3-dev \
libffi-dev \
libssl-dev \
curl \
unzip \
ca-certificates \
bash \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
# Installation of Node.js
ENV NVM_DIR="/root/.nvm"
RUN curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.2/install.sh | bash && \
. "$NVM_DIR/nvm.sh" && \
nvm install 22 && \
nvm use 22
RUN /bin/bash -c ". \"$NVM_DIR/nvm.sh\" && node -v && npm -v"
RUN python -m pip install uv
RUN uv pip install -r requirements.txt --no-cache-dir --system
RUN uv pip install socksio uv pyffmpeg --no-cache-dir --system
EXPOSE 6185
EXPOSE 6186
CMD ["python", "main.py"]
+50 -49
View File
@@ -1,6 +1,6 @@
<p align="center">
![6e1279651f16d7fdf4727558b72bbaf1](https://github.com/user-attachments/assets/ead4c551-fc3c-48f7-a6f7-afbfdb820512)
![yjtp](https://github.com/user-attachments/assets/dcc74009-c57e-4b66-9ae3-0a81fc001255)
</p>
@@ -10,14 +10,12 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot)](https://github.com/Soulter/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple"></a>
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B6%88%E6%81%AF%E4%B8%8A%E8%A1%8C%E9%87%8F&cacheSeconds=60)
[![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot)
[![star](https://gitcode.com/Soulter/AstrBot/star/badge.svg)](https://gitcode.com/Soulter/AstrBot)
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot?style=for-the-badge&color=76bad9)](https://github.com/Soulter/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg?style=for-the-badge&color=76bad9)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B4%BB%E8%B7%83%E9%87%8F&cacheSeconds=10800&style=for-the-badge&color=3b618e)
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a>
@@ -27,19 +25,31 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型(LLM)接入功能的聊天机器人及开发框架。
[![star](https://gitcode.com/Soulter/AstrBot/star/badge.svg?style=for-the-badge)](https://gitcode.com/Soulter/AstrBot)
<!-- [![codecov](https://img.shields.io/codecov/c/github/soulter/astrbot?style=for-the-badge)](https://codecov.io/gh/Soulter/AstrBot)
-->
## ✨ 近期更新
1. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器!
## ✨ 主要功能
> [!NOTE]
> 🪧 我们正基于前沿科研成果,设计并实现适用于角色扮演和情感陪伴的长短期记忆模型及情绪控制模型,旨在提升对话的真实性与情感表达能力。敬请期待 `v3.6.0` 版本!
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力,支持图片理解、语音转文字(Whisper)。
2. **多消息平台接入**。支持接入 QQ(OneBot)、QQ 频道、微信(Gewechat)、飞书、Telegram。后续将支持钉钉、Discord、WhatsApp、小爱音响。支持速率限制、白名单、关键词过滤、百度内容审核。
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://astrbot.app/others/dify.html),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://dify.ai/),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat,可在面板上与大模型对话。
6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
> [!TIP]
> 管理面板在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
> WebUI 在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
>
> 用户名: `astrbot`, 密码: `astrbot`。未配置 LLM,无法在聊天页使用大模型。(不要再修改 demo 的登录密码了 😭)
> 用户名: `astrbot`, 密码: `astrbot`。
## ✨ 使用方式
@@ -49,30 +59,33 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
#### Windows 一键安装器部署
需要电脑上安装有 Python>3.10)。请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
#### Replit 部署
#### 宝塔面板部署
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。
#### CasaOS 部署
社区贡献的部署方式。
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/casaos.html) 。
请参阅官方文档 [CasaOS 部署](https://astrbot.app/deploy/astrbot/casaos.html) 。
#### 手动部署
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html)
推荐使用 `uv`
## 🚀 路线图
```bash
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
pip install uv
uv run main.py
```
### 垂类功能
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
1. 更好的上下文管理:限制 token 总数、对话上下文总结
3. AstrBot in Minecraft
#### Replit 部署
### 横功能
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
## ⚡ 消息平台支持情况
@@ -94,7 +107,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
| 名称 | 支持性 | 类型 | 备注 |
| -------- | ------- | ------- | ------- |
| OpenAI API | ✔ | 文本生成 | 同时也支持 DeepSeek、Google Gemini、GLM(智谱)、Moonshot(月之暗面)、阿里云百炼、硅基流动、xAI 等所有兼容 OpenAI API 的服务 |
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Google Gemini、GLM、Kimi、硅基流动、xAI 等兼容 OpenAI API 的服务 |
| Claude API | ✔ | 文本生成 | |
| Google Gemini API | ✔ | 文本生成 | |
| Dify | ✔ | LLMOps | |
@@ -106,6 +119,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
| Whisper | ✔ | 语音转文本 | 支持 API、本地部署 |
| SenseVoice | ✔ | 语音转文本 | 本地部署 |
| OpenAI TTS API | ✔ | 文本转语音 | |
| GSVI | ✔ | 文本转语音 | GPT-Sovits-Inference |
| Fishaudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
| Edge-TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
@@ -135,38 +149,36 @@ pre-commit install
## ✨ Demo
> [!NOTE]
> 代码执行器的文件输入/输出目前仅测试了 Napcat(QQ), Lagrange(QQ)
<div align='center'>
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
_✨基于 Docker 的沙箱化代码执行器(Beta 测试)✨_
_✨基于 Docker 的沙箱化代码执行器(Beta 测试)✨_
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
_✨ 自然语言待办事项 ✨_
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
_✨ 插件系统——部分插件展示 ✨_
<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width=600>
<img src="https://github.com/user-attachments/assets/0cdbf564-2f59-4da5-b524-ce0e7ef3d978" width=600>
_管理面板_
![webchat](https://drive.soulter.top/f/vlsA/ezgif-5-fb044b2542.gif)
_✨ 内置 Web Chat,在线与机器人交互 ✨_
_WebUI_
</div>
## ❤️ Special Thanks
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
</a>
## ⭐ Star History
> [!TIP]
@@ -184,16 +196,5 @@ _✨ 内置 Web Chat,在线与机器人交互 ✨_
2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility.
3. Please ensure compliance with local laws and regulations when using this project.
<!-- ## ✨ ATRI [Beta 测试]
该功能作为插件载入。插件仓库地址:[astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
1. 基于《ATRI ~ My Dear Moments》主角 ATRI 角色台词作为微调数据集的 `Qwen1.5-7B-Chat Lora` 微调模型。
2. 长期记忆
3. 表情包理解与回复
4. TTS
-->
_私は、高性能ですから!_
+1 -1
View File
@@ -28,7 +28,7 @@ AstrBot is a loosely coupled, asynchronous chatbot and development framework tha
1. **LLM Conversations** - Supports various LLMs including OpenAI API, Google Gemini, Llama, Deepseek, ChatGLM, etc. Enables local model deployment via Ollama/LLMTuner. Features multi-turn dialogues, personality contexts, multimodal capabilities (image understanding), and speech-to-text (Whisper).
2. **Multi-platform Integration** - Supports QQ (OneBot), QQ Channels, WeChat (Gewechat), Feishu, and Telegram. Planned support for DingTalk, Discord, WhatsApp, and Xiaomi Smart Speakers. Includes rate limiting, whitelisting, keyword filtering, and Baidu content moderation.
3. **Agent Capabilities** - Native support for code execution, natural language TODO lists, web search. Integrates with [Dify Platform](https://astrbot.app/others/dify.html) for easy access to Dify assistants/knowledge bases/workflows.
3. **Agent Capabilities** - Native support for code execution, natural language TODO lists, web search. Integrates with [Dify Platform](https://dify.ai/) for easy access to Dify assistants/knowledge bases/workflows.
4. **Plugin System** - Optimized plugin mechanism with minimal development effort. Supports multiple installed plugins.
5. **Web Dashboard** - Visual configuration management, plugin controls, logging, and WebChat interface for direct LLM interaction.
6. **High Stability & Modularity** - Event bus and pipeline architecture ensures high modularization and loose coupling.
+1 -1
View File
@@ -28,7 +28,7 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ
1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換(Whisper)をサポートします。
2. **複数のメッセージプラットフォームの接続**。QQOneBot)、QQ チャンネル、WeChatGewechat)、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://astrbot.app/others/dify.html)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://dify.ai/)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
6. **高い安定性と高いモジュール性**。イベントバスとパイプラインに基づくアーキテクチャ設計により、高度にモジュール化され、低結合です。
+2
View File
@@ -5,6 +5,7 @@ from astrbot.core.platform import (
MessageMember,
MessageType,
PlatformMetadata,
Group,
)
from astrbot.core.platform.register import register_platform_adapter
@@ -18,4 +19,5 @@ __all__ = [
"MessageType",
"PlatformMetadata",
"register_platform_adapter",
"Group",
]
+2 -6
View File
@@ -2,11 +2,7 @@ from astrbot.core.star.register import (
register_star as register, # 注册插件(Star
)
from astrbot.core.star import Context, Star
from astrbot.core.star import Context, Star, StarTools
from astrbot.core.star.config import *
__all__ = [
"register",
"Context",
"Star",
]
__all__ = ["register", "Context", "Star", "StarTools"]
+5 -1
View File
@@ -8,6 +8,7 @@ from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.config.default import DB_PATH
from astrbot.core.config import AstrBotConfig
# 初始化数据存储文件夹
os.makedirs("data", exist_ok=True)
astrbot_config = AstrBotConfig()
@@ -19,8 +20,11 @@ if os.environ.get("TESTING", ""):
logger.setLevel("DEBUG")
db_helper = SQLiteDatabase(DB_PATH)
sp = SharedPreferences() # 简单的偏好设置存储
sp = (
SharedPreferences()
) # 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
pip_installer = PipInstaller(astrbot_config.get("pip_install_arg", ""))
web_chat_queue = asyncio.Queue(maxsize=32)
web_chat_back_queue = asyncio.Queue(maxsize=32)
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
DEMO_MODE = os.getenv("DEMO_MODE", False)
+145 -17
View File
@@ -2,7 +2,7 @@
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
"""
VERSION = "3.4.38"
VERSION = "3.5.2"
DB_PATH = "data/data_v3.db"
# 默认配置
@@ -49,6 +49,7 @@ DEFAULT_CONFIG = {
"datetime_system_prompt": True,
"default_personality": "default",
"prompt_prefix": "",
"max_context_length": -1,
},
"provider_stt_settings": {
"enable": False,
@@ -80,22 +81,24 @@ DEFAULT_CONFIG = {
"admins_id": ["astrbot"],
"t2i": False,
"t2i_word_threshold": 150,
"t2i_strategy": "remote",
"t2i_endpoint": "",
"http_proxy": "",
"dashboard": {
"enable": True,
"username": "astrbot",
"password": "77b90590a8945a7d36c963981a307dc9",
"host": "127.0.0.1",
"host": "0.0.0.0",
"port": 6185,
},
"platform": [],
"wake_prefix": ["/"],
"log_level": "INFO",
"t2i_endpoint": "",
"pip_install_arg": "",
"plugin_repo_mirror": "",
"knowledge_db": {},
"persona": [],
"timezone": "",
}
@@ -223,7 +226,7 @@ CONFIG_METADATA_2 = {
"hint": "启用后,机器人可以接收到频道的私聊消息。",
},
"ws_reverse_host": {
"description": "反向 Websocket 主机地址",
"description": "反向 Websocket 主机地址(AstrBot 为服务器端)",
"type": "string",
"hint": "aiocqhttp 适配器的反向 Websocket 服务器 IP 地址,不包含端口号。",
},
@@ -345,7 +348,7 @@ CONFIG_METADATA_2 = {
"type": "list",
"items": {"type": "string"},
"obvious_hint": True,
"hint": "只处理填写的 ID 发来的消息事件为空时不启用白名单过滤。可使用 /sid 指令获取在某个平台上的会话 ID。会话 ID 类似 aiocqhttp:GroupMessage:547540978。管理员可使用 /wl 添加白名单",
"hint": "只处理填写的 ID 发来的消息事件为空时不启用。可使用 /sid 指令获取在平台上的会话 ID(类似 abc:GroupMessage:123)。管理员可使用 /wl 添加白名单",
},
"id_whitelist_log": {
"description": "打印白名单日志",
@@ -517,7 +520,14 @@ CONFIG_METADATA_2 = {
"api_base": "https://generativelanguage.googleapis.com/",
"timeout": 120,
"model_config": {
"model": "gemini-1.5-flash",
"model": "gemini-2.0-flash-exp",
},
"gm_resp_image_modal": False,
"gm_safety_settings": {
"harassment": "BLOCK_MEDIUM_AND_ABOVE",
"hate_speech": "BLOCK_MEDIUM_AND_ABOVE",
"sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE",
"dangerous_content": "BLOCK_MEDIUM_AND_ABOVE",
},
},
"DeepSeek": {
@@ -581,7 +591,7 @@ CONFIG_METADATA_2 = {
"dify_api_type": "chat",
"dify_api_key": "",
"dify_api_base": "https://api.dify.ai/v1",
"dify_workflow_output_key": "",
"dify_workflow_output_key": "astrbot_wf_output",
"dify_query_input_key": "astrbot_text_query",
"variables": {},
"timeout": 60,
@@ -593,6 +603,11 @@ CONFIG_METADATA_2 = {
"dashscope_app_type": "agent",
"dashscope_api_key": "",
"dashscope_app_id": "",
"rag_options": {
"pipeline_ids": [],
"file_ids": [],
"output_reference": False,
},
"variables": {},
"timeout": 60,
},
@@ -663,8 +678,102 @@ CONFIG_METADATA_2 = {
"fishaudio-tts-character": "可莉",
"timeout": "20",
},
"阿里云百炼_TTS(API)": {
"id": "dashscope_tts",
"type": "dashscope_tts",
"enable": False,
"api_key": "",
"model": "cosyvoice-v1",
"dashscope_tts_voice": "loongstella",
"timeout": "20",
},
},
"items": {
"dashscope_tts_voice": {
"description": "语音合成模型",
"type": "string",
"hint": "阿里云百炼语音合成模型名称。具体可参考 https://help.aliyun.com/zh/model-studio/developer-reference/cosyvoice-python-api 等内容",
},
"gm_resp_image_modal": {
"description": "启用图片模态",
"type": "bool",
"hint": "启用后,将支持返回图片内容。需要模型支持,否则会报错。具体支持模型请查看 Google Gemini 官方网站。温馨提示,如果您需要生成图片,请关闭 `启用群员识别` 配置获得更好的效果。",
},
"gm_safety_settings": {
"description": "安全过滤器",
"type": "object",
"hint": "设置模型输入的内容安全过滤级别。过滤级别分类为NONE(不屏蔽)、HIGH(高风险时屏蔽)、MEDIUM_AND_ABOVE(中等风险及以上屏蔽)、LOW_AND_ABOVE(低风险及以上时屏蔽),具体参见Gemini API文档。",
"items": {
"harassment": {
"description": "骚扰内容",
"type": "string",
"hint": "负面或有害评论",
"options": [
"BLOCK_NONE",
"BLOCK_ONLY_HIGH",
"BLOCK_MEDIUM_AND_ABOVE",
"BLOCK_LOW_AND_ABOVE",
],
},
"hate_speech": {
"description": "仇恨言论",
"type": "string",
"hint": "粗鲁、无礼或亵渎性质内容",
"options": [
"BLOCK_NONE",
"BLOCK_ONLY_HIGH",
"BLOCK_MEDIUM_AND_ABOVE",
"BLOCK_LOW_AND_ABOVE",
],
},
"sexually_explicit": {
"description": "露骨色情内容",
"type": "string",
"hint": "包含性行为或其他淫秽内容的引用",
"options": [
"BLOCK_NONE",
"BLOCK_ONLY_HIGH",
"BLOCK_MEDIUM_AND_ABOVE",
"BLOCK_LOW_AND_ABOVE",
],
},
"dangerous_content": {
"description": "危险内容",
"type": "string",
"hint": "宣扬、助长或鼓励有害行为的信息",
"options": [
"BLOCK_NONE",
"BLOCK_ONLY_HIGH",
"BLOCK_MEDIUM_AND_ABOVE",
"BLOCK_LOW_AND_ABOVE",
],
},
},
},
"rag_options": {
"description": "RAG 选项",
"type": "object",
"hint": "检索知识库设置, 非必填。仅 Agent 应用类型支持(智能体应用, 包括 RAG 应用)。阿里云百炼应用开启此功能后将无法多轮对话。",
"items": {
"pipeline_ids": {
"description": "知识库 ID 列表",
"type": "list",
"items": {"type": "string"},
"hint": "对指定知识库内所有文档进行检索, 前往 https://bailian.console.aliyun.com/ 数据应用->知识索引创建和获取 ID。",
},
"file_ids": {
"description": "非结构化文档 ID, 传入该参数将对指定非结构化文档进行检索。",
"type": "list",
"items": {"type": "string"},
"hint": "对指定非结构化文档进行检索。前往 https://bailian.console.aliyun.com/ 数据管理创建和获取 ID。",
},
"output_reference": {
"description": "是否输出知识库/文档的引用",
"type": "bool",
"hint": "在每次回答尾部加上引用源。默认为 False。",
},
},
},
"sensevoice_hint": {
"description": "部署SenseVoice",
"type": "string",
@@ -681,12 +790,14 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "modelscope 上的模型名称。默认:iic/SenseVoiceSmall。",
},
# "variables": {
# "description": "工作流固定输入变量",
# "type": "object",
# "obvious_hint": True,
# "hint": "可选。工作流固定输入变量,将会作为工作流的输入。也可以在对话时使用 /set 指令动态设置变量。如果变量名冲突,优先使用动态设置的变量。",
# },
"variables": {
"description": "工作流固定输入变量",
"type": "object",
"obvious_hint": True,
"items": {},
"hint": "可选。工作流固定输入变量,将会作为工作流的输入。也可以在对话时使用 /set 指令动态设置变量。如果变量名冲突,优先使用动态设置的变量。",
"invisible": True,
},
# "fastgpt_app_type": {
# "description": "应用类型",
# "type": "string",
@@ -697,7 +808,7 @@ CONFIG_METADATA_2 = {
"dashscope_app_type": {
"description": "应用类型",
"type": "string",
"hint": "阿里云百炼应用的应用类型。",
"hint": "百炼应用的应用类型。",
"options": [
"agent",
"agent-arrange",
@@ -877,6 +988,11 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "添加之后,会在每次对话的 Prompt 前加上此文本。",
},
"max_context_length": {
"description": "最多携带对话数量(条)",
"type": "int",
"hint": "超出这个数量时将丢弃最旧的部分,用户和AI的一轮聊天记为 1 条。-1 表示不限制,默认为不限制。",
},
},
},
"persona": {
@@ -970,10 +1086,10 @@ CONFIG_METADATA_2 = {
"hint": "群聊消息最大数量。超过此数量后,会自动清除旧消息。",
},
"image_caption": {
"description": "启用图像转述(需模型支持)",
"description": "群聊图像转述(需模型支持)",
"type": "bool",
"obvious_hint": True,
"hint": "启用后,当接收到图片消息时,会使用模型先将图片转述为文字再进行后续处理。推荐使用 gpt-4o-mini 模型",
"hint": "用模型将群聊中的图片消息转述为文字,推荐 gpt-4o-mini 模型。和机器人的唤醒聊天中的图片消息仍然会直接作为上下文输入",
},
"image_caption_provider_id": {
"description": "图像转述提供商 ID",
@@ -1057,16 +1173,28 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "启用后,会以添加环境变量的方式设置代理。格式为 `http://ip:port`",
},
"timezone": {
"description": "时区",
"type": "string",
"obvious_hint": True,
"hint": "时区设置。请填写 IANA 时区名称, 如 Asia/Shanghai, 为空时使用系统默认时区。所有时区请查看: https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab",
},
"log_level": {
"description": "控制台日志级别",
"type": "string",
"hint": "控制台输出日志的级别。",
"options": ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
},
"t2i_strategy": {
"description": "文本转图像渲染源",
"type": "string",
"hint": "文本转图像策略。`remote` 为使用远程基于 HTML 的渲染服务,`local` 为使用 PIL 本地渲染。当使用 local 时,将 ttf 字体命名为 'font.ttf' 放在 data/ 目录下可自定义字体。",
"options": ["remote", "local"],
},
"t2i_endpoint": {
"description": "文本转图像服务接口",
"type": "string",
"hint": "为空时使用 AstrBot API 服务",
"hint": "当 t2i_strategy 为 remote 时生效。为空时使用 AstrBot API 服务",
},
"pip_install_arg": {
"description": "pip 安装参数",
+79 -9
View File
@@ -1,3 +1,10 @@
"""
AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 格式的shared_preferences, 另外一个是数据库
在 AstrBot 中, 会话和对话是独立的, 会话用于标记对话窗口, 例如群聊"123456789"可以建立一个会话,
在一个会话中可以建立多个对话, 并且支持对话的切换和删除
"""
import uuid
import json
import asyncio
@@ -11,24 +18,34 @@ class ConversationManager:
"""负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""
def __init__(self, db_helper: BaseDatabase):
# session_conversations 字典记录会话ID-对话ID 映射关系
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
self.db = db_helper
self.save_interval = 60 # 每 60 秒保存一次
self._start_periodic_save()
def _start_periodic_save(self):
"""启动定时保存任务"""
asyncio.create_task(self._periodic_save())
async def _periodic_save(self):
"""定时保存会话对话映射关系到存储中"""
while True:
await asyncio.sleep(self.save_interval)
self._save_to_storage()
def _save_to_storage(self):
"""保存会话对话映射关系到存储中"""
sp.put("session_conversation", self.session_conversations)
async def new_conversation(self, unified_msg_origin: str) -> str:
"""新建对话,并将当前会话的对话转移到新对话"""
"""新建对话,并将当前会话的对话转移到新对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
Returns:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
conversation_id = str(uuid.uuid4())
self.db.new_conversation(user_id=unified_msg_origin, cid=conversation_id)
self.session_conversations[unified_msg_origin] = conversation_id
@@ -36,14 +53,24 @@ class ConversationManager:
return conversation_id
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
"""切换会话的对话"""
"""切换会话的对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
self.session_conversations[unified_msg_origin] = conversation_id
sp.put("session_conversation", self.session_conversations)
async def delete_conversation(
self, unified_msg_origin: str, conversation_id: str = None
):
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话"""
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.delete_conversation(user_id=unified_msg_origin, cid=conversation_id)
@@ -51,23 +78,48 @@ class ConversationManager:
sp.put("session_conversation", self.session_conversations)
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
"""获取会话当前的对话 ID"""
"""获取会话当前的对话 ID
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
Returns:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
return self.session_conversations.get(unified_msg_origin, None)
async def get_conversation(
self, unified_msg_origin: str, conversation_id: str
) -> Conversation:
"""获取会话的对话"""
"""获取会话的对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
Returns:
conversation (Conversation): 对话对象
"""
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
"""获取会话的所有对话"""
"""获取会话的所有对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
Returns:
conversations (List[Conversation]): 对话对象列表
"""
return self.db.get_conversations(unified_msg_origin)
async def update_conversation(
self, unified_msg_origin: str, conversation_id: str, history: List[Dict]
):
"""更新会话的对话"""
"""更新会话的对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
"""
if conversation_id:
self.db.update_conversation(
user_id=unified_msg_origin,
@@ -76,7 +128,12 @@ class ConversationManager:
)
async def update_conversation_title(self, unified_msg_origin: str, title: str):
"""更新会话的对话标题"""
"""更新会话的对话标题
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
title (str): 对话标题
"""
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.update_conversation_title(
@@ -86,7 +143,12 @@ class ConversationManager:
async def update_conversation_persona_id(
self, unified_msg_origin: str, persona_id: str
):
"""更新会话的对话 Persona ID"""
"""更新会话的对话 Persona ID
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
persona_id (str): 对话 Persona ID
"""
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.update_conversation_persona_id(
@@ -96,6 +158,14 @@ class ConversationManager:
async def get_human_readable_context(
self, unified_msg_origin, conversation_id, page=1, page_size=10
):
"""获取人类可读的上下文
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
page (int): 页码
page_size (int): 每页大小
"""
conversation = await self.get_conversation(unified_msg_origin, conversation_id)
history = json.loads(conversation.history)
+89 -17
View File
@@ -1,3 +1,14 @@
"""
Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
该类还负责加载和执行插件, 以及处理事件总线的分发。
工作流程:
1. 初始化所有组件
2. 启动事件总线和任务, 所有任务都在这里运行
3. 执行启动完成事件钩子
"""
import traceback
import asyncio
import time
@@ -24,32 +35,51 @@ from astrbot.core.star.star_handler import star_map
class AstrBotCoreLifecycle:
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
self.log_broker = log_broker
self.astrbot_config = astrbot_config
self.db = db
"""
AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、
EventBus 等。
该类还负责加载和执行插件, 以及处理事件总线的分发。
"""
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
self.log_broker = log_broker # 初始化日志代理
self.astrbot_config = astrbot_config # 初始化配置
self.db = db # 初始化数据库
# 根据环境变量设置代理
os.environ["https_proxy"] = self.astrbot_config["http_proxy"]
os.environ["http_proxy"] = self.astrbot_config["http_proxy"]
os.environ["no_proxy"] = "localhost"
async def initialize(self):
"""
初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
"""
# 初始化日志代理
logger.info("AstrBot v" + VERSION)
if os.environ.get("TESTING", ""):
logger.setLevel("DEBUG")
logger.setLevel("DEBUG") # 测试模式下设置日志级别为 DEBUG
else:
logger.setLevel(self.astrbot_config["log_level"])
self.event_queue = Queue()
self.event_queue.closed = False
logger.setLevel(self.astrbot_config["log_level"]) # 设置日志级别
# 初始化事件队列
self.event_queue = Queue()
# 初始化供应商管理器
self.provider_manager = ProviderManager(self.astrbot_config, self.db)
# 初始化平台管理器
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
# 初始化知识库管理器
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
# 初始化对话管理器
self.conversation_manager = ConversationManager(self.db)
# 初始化提供给插件的上下文
self.star_context = Context(
self.event_queue,
self.astrbot_config,
@@ -59,33 +89,50 @@ class AstrBotCoreLifecycle:
self.conversation_manager,
self.knowledge_db_manager,
)
# 初始化插件管理器
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
# 扫描、注册插件、实例化插件类
await self.plugin_manager.reload()
"""扫描、注册插件、实例化插件类"""
# 根据配置实例化各个 Provider
await self.provider_manager.initialize()
"""根据配置实例化各个 Provider"""
# 初始化消息事件流水线调度器
self.pipeline_scheduler = PipelineScheduler(
PipelineContext(self.astrbot_config, self.plugin_manager)
)
await self.pipeline_scheduler.initialize()
"""初始化消息事件流水线调度器"""
# 初始化更新器
self.astrbot_updator = AstrBotUpdator(self.astrbot_config["plugin_repo_mirror"])
# 初始化事件总线
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
# 记录启动时间
self.start_time = int(time.time())
# 初始化当前任务列表
self.curr_tasks: List[asyncio.Task] = []
# 根据配置实例化各个平台适配器
await self.platform_manager.initialize()
"""根据配置实例化各个平台适配器"""
# 初始化关闭控制面板的事件
self.dashboard_shutdown_event = asyncio.Event()
def _load(self):
"""加载事件总线和任务并初始化"""
# 创建一个异步任务来执行事件总线的 dispatch() 方法
# dispatch是一个无限循环的协程, 从事件队列中获取事件并处理
event_bus_task = asyncio.create_task(
self.event_bus.dispatch(), name="event_bus"
)
# 把插件中注册的所有协程函数注册到事件总线中并执行
extra_tasks = []
for task in self.star_context._register_tasks:
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
@@ -99,17 +146,24 @@ class AstrBotCoreLifecycle:
self.start_time = int(time.time())
async def _task_wrapper(self, task: asyncio.Task):
"""异步任务包装器, 用于处理异步任务执行中出现的各种异常
Args:
task (asyncio.Task): 要执行的异步任务
"""
try:
await task
except asyncio.CancelledError:
pass
pass # 任务被取消, 静默处理
except Exception as e:
# 获取完整的异常堆栈信息, 按行分割并记录到日志中
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
for line in traceback.format_exc().split("\n"):
logger.error(f"| {line}")
logger.error("-------")
async def start(self):
"""启动 AstrBot 核心生命周期管理类, 用load加载事件总线和任务并初始化, 执行启动完成事件钩子"""
self._load()
logger.info("AstrBot 启动完成。")
@@ -126,15 +180,29 @@ class AstrBotCoreLifecycle:
except BaseException:
logger.error(traceback.format_exc())
# 同时运行curr_tasks中的所有任务
await asyncio.gather(*self.curr_tasks, return_exceptions=True)
async def stop(self):
self.event_queue.closed = True
"""停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器"""
# 请求停止所有正在运行的异步任务
for task in self.curr_tasks:
task.cancel()
await self.provider_manager.terminate()
for plugin in self.plugin_manager.context.get_all_stars():
try:
await self.plugin_manager._terminate_plugin(plugin)
except Exception as e:
logger.warning(traceback.format_exc())
logger.warning(
f"插件 {plugin.name} 未被正常终止 {e!s}, 可能会导致资源泄露等问题。"
)
await self.provider_manager.terminate()
await self.platform_manager.terminate()
self.dashboard_shutdown_event.set()
# 再次遍历curr_tasks等待每个任务真正结束
for task in self.curr_tasks:
try:
await task
@@ -143,13 +211,17 @@ class AstrBotCoreLifecycle:
except Exception as e:
logger.error(f"任务 {task.get_name()} 发生错误: {e}")
def restart(self):
self.event_queue.closed = True
async def restart(self):
"""重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例"""
await self.provider_manager.terminate()
await self.platform_manager.terminate()
self.dashboard_shutdown_event.set()
threading.Thread(
target=self.astrbot_updator._reboot, name="restart", daemon=True
).start()
def load_platform(self) -> List[asyncio.Task]:
"""加载平台实例并返回所有平台实例的异步任务列表"""
tasks = []
platform_insts = self.platform_manager.get_insts()
for platform_inst in platform_insts:
+43 -1
View File
@@ -1,6 +1,6 @@
import abc
from dataclasses import dataclass
from typing import List
from typing import List, Dict, Any, Tuple
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation
@@ -117,3 +117,45 @@ class BaseDatabase(abc.ABC):
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
"""更新 Conversation Persona ID"""
raise NotImplementedError
@abc.abstractmethod
def get_all_conversations(
self, page: int = 1, page_size: int = 20
) -> Tuple[List[Dict[str, Any]], int]:
"""获取所有对话,支持分页
Args:
page: 页码,从1开始
page_size: 每页数量
Returns:
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
"""
raise NotImplementedError
@abc.abstractmethod
def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platforms: List[str] = None,
message_types: List[str] = None,
search_query: str = None,
exclude_ids: List[str] = None,
exclude_platforms: List[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
"""获取筛选后的对话列表
Args:
page: 页码
page_size: 每页数量
platforms: 平台筛选列表
message_types: 消息类型筛选列表
search_query: 搜索关键词
exclude_ids: 排除的用户ID列表
exclude_platforms: 排除的平台列表
Returns:
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
"""
raise NotImplementedError
+112
View File
@@ -0,0 +1,112 @@
import json
import aiosqlite
import os
from typing import Any
from .plugin_storage import PluginStorage
DBPATH = "data/plugin_data/sqlite/plugin_data.db"
class SQLitePluginStorage(PluginStorage):
"""插件数据的 SQLite 存储实现类。
该类提供异步方式将插件数据存储到 SQLite 数据库中,支持数据的增删改查操作。
所有数据以 (plugin, key) 作为复合主键进行索引。
"""
_instance = None # Standalone instance of the class
_db_conn = None
db_path = None
def __new__(cls):
"""
创建或获取 SQLitePluginStorage 的单例实例。
如果实例已存在,则返回现有实例;否则创建一个新实例。
数据在 `data/plugin_data/sqlite/plugin_data.db` 下。
"""
os.makedirs(os.path.dirname(DBPATH), exist_ok=True)
if cls._instance is None:
cls._instance = super(SQLitePluginStorage, cls).__new__(cls)
cls._instance.db_path = DBPATH
return cls._instance
async def _init_db(self):
"""初始化数据库连接(只执行一次)"""
if SQLitePluginStorage._db_conn is None:
SQLitePluginStorage._db_conn = await aiosqlite.connect(self.db_path)
await self._setup_db()
async def _setup_db(self):
"""
异步初始化数据库。
创建插件数据表,如果表不存在则创建,表结构包含 plugin、key 和 value 字段,
其中 plugin 和 key 组合作为主键。
"""
await self._db_conn.execute("""
CREATE TABLE IF NOT EXISTS plugin_data (
plugin TEXT,
key TEXT,
value TEXT,
PRIMARY KEY (plugin, key)
)
""")
await self._db_conn.commit()
async def set(self, plugin: str, key: str, value: Any):
"""
异步存储数据。
将指定插件的键值对存入数据库,如果键已存在则更新值。
值会被序列化为 JSON 字符串后存储。
Args:
plugin: 插件标识符
key: 数据键名
value: 要存储的数据值(任意类型,将被 JSON 序列化)
"""
await self._init_db()
await self._db_conn.execute(
"INSERT INTO plugin_data (plugin, key, value) VALUES (?, ?, ?) "
"ON CONFLICT(plugin, key) DO UPDATE SET value = excluded.value",
(plugin, key, json.dumps(value)),
)
await self._db_conn.commit()
async def get(self, plugin: str, key: str) -> Any:
"""
异步获取数据。
从数据库中获取指定插件和键名对应的值,
返回的值会从 JSON 字符串反序列化为原始数据类型。
Args:
plugin: 插件标识符
key: 数据键名
Returns:
Any: 存储的数据值,如果未找到则返回 None
"""
await self._init_db()
async with self._db_conn.execute(
"SELECT value FROM plugin_data WHERE plugin = ? AND key = ?",
(plugin, key),
) as cursor:
row = await cursor.fetchone()
return json.loads(row[0]) if row else None
async def delete(self, plugin: str, key: str):
"""
异步删除数据。
从数据库中删除指定插件和键名对应的数据项。
Args:
plugin: 插件标识符
key: 要删除的数据键名
"""
await self._init_db()
await self._db_conn.execute(
"DELETE FROM plugin_data WHERE plugin = ? AND key = ?", (plugin, key)
)
await self._db_conn.commit()
+8
View File
@@ -6,6 +6,8 @@ from typing import List
@dataclass
class Platform:
"""平台使用统计数据"""
name: str
count: int
timestamp: int
@@ -13,6 +15,8 @@ class Platform:
@dataclass
class Provider:
"""供应商使用统计数据"""
name: str
count: int
timestamp: int
@@ -20,6 +24,8 @@ class Provider:
@dataclass
class Plugin:
"""插件使用统计数据"""
name: str
count: int
timestamp: int
@@ -27,6 +33,8 @@ class Plugin:
@dataclass
class Command:
"""命令使用统计数据"""
name: str
count: int
timestamp: int
+192 -18
View File
@@ -3,7 +3,7 @@ import os
import time
from astrbot.core.db.po import Platform, Stats, LLMHistory, ATRIVision, Conversation
from . import BaseDatabase
from typing import Tuple
from typing import Tuple, List, Dict, Any
class SQLiteDatabase(BaseDatabase):
@@ -128,24 +128,23 @@ class SQLiteDatabase(BaseDatabase):
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
where_clause = ""
if session_id or provider_type:
where_clause += " WHERE "
has = False
if session_id:
where_clause += f"session_id = '{session_id}'"
has = True
if provider_type:
if has:
where_clause += " AND "
where_clause += f"provider_type = '{provider_type}'"
conditions = []
params = []
if session_id:
conditions.append("session_id = ?")
params.append(session_id)
if provider_type:
conditions.append("provider_type = ?")
params.append(provider_type)
sql = "SELECT * FROM llm_history"
if conditions:
sql += " WHERE " + " AND ".join(conditions)
c.execute(sql, params)
c.execute(
"""
SELECT * FROM llm_history
"""
+ where_clause
)
res = c.fetchall()
histories = []
for row in res:
@@ -389,3 +388,178 @@ class SQLiteDatabase(BaseDatabase):
if res:
return ATRIVision(*res)
return None
def get_all_conversations(
self, page: int = 1, page_size: int = 20
) -> Tuple[List[Dict[str, Any]], int]:
"""获取所有对话,支持分页,按更新时间降序排序"""
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
try:
# 获取总记录数
c.execute("""
SELECT COUNT(*) FROM webchat_conversation
""")
total_count = c.fetchone()[0]
# 计算偏移量
offset = (page - 1) * page_size
# 获取分页数据,按更新时间降序排序
c.execute(
"""
SELECT user_id, cid, created_at, updated_at, title, persona_id
FROM webchat_conversation
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
""",
(page_size, offset),
)
rows = c.fetchall()
conversations = []
for row in rows:
user_id, cid, created_at, updated_at, title, persona_id = row
# 确保 cid 是字符串类型且至少有8个字符,否则使用一个默认值
safe_cid = str(cid) if cid else "unknown"
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
conversations.append(
{
"user_id": user_id or "",
"cid": safe_cid,
"title": title or f"对话 {display_cid}",
"persona_id": persona_id or "",
"created_at": created_at or 0,
"updated_at": updated_at or 0,
}
)
return conversations, total_count
except Exception as _:
# 返回空列表和0,确保即使出错也有有效的返回值
return [], 0
finally:
c.close()
def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platforms: List[str] = None,
message_types: List[str] = None,
search_query: str = None,
exclude_ids: List[str] = None,
exclude_platforms: List[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
"""获取筛选后的对话列表"""
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
try:
# 构建查询条件
where_clauses = []
params = []
# 平台筛选
if platforms and len(platforms) > 0:
platform_conditions = []
for platform in platforms:
platform_conditions.append("user_id LIKE ?")
params.append(f"{platform}:%")
if platform_conditions:
where_clauses.append(f"({' OR '.join(platform_conditions)})")
# 消息类型筛选
if message_types and len(message_types) > 0:
message_type_conditions = []
for msg_type in message_types:
message_type_conditions.append("user_id LIKE ?")
params.append(f"%:{msg_type}:%")
if message_type_conditions:
where_clauses.append(f"({' OR '.join(message_type_conditions)})")
# 搜索关键词
if search_query:
search_query = search_query.encode("unicode_escape").decode("utf-8")
where_clauses.append(
"(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)"
)
search_param = f"%{search_query}%"
params.extend([search_param, search_param, search_param, search_param])
# 排除特定用户ID
if exclude_ids and len(exclude_ids) > 0:
for exclude_id in exclude_ids:
where_clauses.append("user_id NOT LIKE ?")
params.append(f"{exclude_id}%")
# 排除特定平台
if exclude_platforms and len(exclude_platforms) > 0:
for exclude_platform in exclude_platforms:
where_clauses.append("user_id NOT LIKE ?")
params.append(f"{exclude_platform}:%")
# 构建完整的 WHERE 子句
where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
# 构建计数查询
count_sql = f"SELECT COUNT(*) FROM webchat_conversation{where_sql}"
# 获取总记录数
c.execute(count_sql, params)
total_count = c.fetchone()[0]
# 计算偏移量
offset = (page - 1) * page_size
# 构建分页数据查询
data_sql = f"""
SELECT user_id, cid, created_at, updated_at, title, persona_id
FROM webchat_conversation
{where_sql}
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
"""
query_params = params + [page_size, offset]
# 获取分页数据
c.execute(data_sql, query_params)
rows = c.fetchall()
conversations = []
for row in rows:
user_id, cid, created_at, updated_at, title, persona_id = row
# 确保 cid 是字符串类型,否则使用一个默认值
safe_cid = str(cid) if cid else "unknown"
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
conversations.append(
{
"user_id": user_id or "",
"cid": safe_cid,
"title": title or f"对话 {display_cid}",
"persona_id": persona_id or "",
"created_at": created_at or 0,
"updated_at": updated_at or 0,
}
)
return conversations, total_count
except Exception as _:
# 返回空列表和0,确保即使出错也有有效的返回值
return [], 0
finally:
c.close()
+5 -3
View File
@@ -38,11 +38,13 @@ CREATE TABLE IF NOT EXISTS atri_vision(
);
CREATE TABLE IF NOT EXISTS webchat_conversation(
user_id TEXT,
cid TEXT,
user_id TEXT, -- 会话 id
cid TEXT, -- 对话 id
history TEXT,
created_at INTEGER,
updated_at INTEGER,
title TEXT,
persona_id TEXT
);
);
PRAGMA encoding = 'UTF-8';
+35 -5
View File
@@ -1,3 +1,16 @@
"""
事件总线, 用于处理事件的分发和处理
事件总线是一个异步队列, 用于接收各种消息事件, 并将其发送到Scheduler调度器进行处理
其中包含了一个无限循环的调度函数, 用于从事件队列中获取新的事件, 并创建一个新的异步任务来执行管道调度器的处理逻辑
class:
EventBus: 事件总线, 用于处理事件的分发和处理
工作流程:
1. 维护一个异步队列, 来接受各种消息事件
2. 无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑
"""
import asyncio
from asyncio import Queue
from astrbot.core.pipeline.scheduler import PipelineScheduler
@@ -6,21 +19,38 @@ from .platform import AstrMessageEvent
class EventBus:
"""事件总线: 用于处理事件的分发和处理
维护一个异步队列, 来接受各种消息事件
"""
def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler):
self.event_queue = event_queue
self.pipeline_scheduler = pipeline_scheduler
self.event_queue = event_queue # 事件队列
self.pipeline_scheduler = pipeline_scheduler # 管道调度器
async def dispatch(self):
"""无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑"""
while True:
event: AstrMessageEvent = await self.event_queue.get()
self._print_event(event)
asyncio.create_task(self.pipeline_scheduler.execute(event))
event: AstrMessageEvent = (
await self.event_queue.get()
) # 从事件队列中获取新的事件
self._print_event(event) # 打印日志
asyncio.create_task(
self.pipeline_scheduler.execute(event)
) # 创建新的异步任务来执行管道调度器的处理逻辑
def _print_event(self, event: AstrMessageEvent):
"""用于记录事件信息
Args:
event (AstrMessageEvent): 事件对象
"""
# 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要
if event.get_sender_name():
logger.info(
f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
)
# 没有发送者名称: [平台名] 发送者ID: 消息概要
else:
logger.info(
f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}"
@@ -1,18 +1,27 @@
"""
AstrBot 启动器负责初始化和启动核心组件和仪表板服务器
工作流程:
1. 初始化核心生命周期, 传递数据库和日志代理实例到核心生命周期
2. 运行核心生命周期任务和仪表板服务器
"""
import asyncio
import traceback
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from .server import AstrBotDashboard
from astrbot.core.db import BaseDatabase
from astrbot.core import LogBroker
from astrbot.dashboard.server import AstrBotDashboard
class AstrBotDashBoardLifecycle:
class InitialLoader:
"""AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。"""
def __init__(self, db: BaseDatabase, log_broker: LogBroker):
self.db = db
self.logger = logger
self.log_broker = log_broker
self.dashboard_server = None
async def start(self):
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
@@ -25,11 +34,15 @@ class AstrBotDashBoardLifecycle:
logger.critical(traceback.format_exc())
logger.critical(f"😭 初始化 AstrBot 失败:{e} !!!")
self.dashboard_server = AstrBotDashboard(core_lifecycle, self.db)
task = asyncio.gather(core_task, self.dashboard_server.run())
self.dashboard_server = AstrBotDashboard(
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event
)
task = asyncio.gather(
core_task, self.dashboard_server.run()
) # 启动核心任务和仪表板服务器
try:
await task
await task # 整个AstrBot在这里运行
except asyncio.CancelledError:
logger.info("🌈 正在关闭 AstrBot...")
await core_lifecycle.stop()
+117 -18
View File
@@ -1,3 +1,26 @@
"""
日志系统, 用于支持核心组件和插件的日志记录, 提供了日志订阅功能
const:
CACHED_SIZE: 日志缓存大小, 用于限制缓存的日志数量
log_color_config: 日志颜色配置, 定义了不同日志级别的颜色
class:
LogBroker: 日志代理类, 用于缓存和分发日志消息
LogQueueHandler: 日志处理器, 用于将日志消息发送到 LogBroker
LogManager: 日志管理器, 用于创建和配置日志记录器
function:
is_plugin_path: 检查文件路径是否来自插件目录
get_short_level_name: 将日志级别名称转换为四个字母的缩写
工作流程:
1. 通过 LogManager.GetLogger() 获取日志器, 配置了控制台输出和多个格式化过滤器
2. 通过 set_queue_handler() 设置日志处理器, 将日志消息发送到 LogBroker
3. logBroker 维护一个订阅者列表, 负责将日志分发给所有订阅者
4. 订阅者可以使用 register() 方法注册到 LogBroker, 订阅日志流
"""
import logging
import colorlog
import asyncio
@@ -6,7 +29,9 @@ from collections import deque
from asyncio import Queue
from typing import List
# 日志缓存大小
CACHED_SIZE = 200
# 日志颜色配置
log_color_config = {
"DEBUG": "green",
"INFO": "bold_cyan",
@@ -19,8 +44,13 @@ log_color_config = {
def is_plugin_path(pathname):
"""
检查文件路径是否来自插件目录
"""检查文件路径是否来自插件目录
Args:
pathname (str): 文件路径
Returns:
bool: 如果路径来自插件目录,则返回 True,否则返回 False
"""
if not pathname:
return False
@@ -30,8 +60,13 @@ def is_plugin_path(pathname):
def get_short_level_name(level_name):
"""
将日志级别名称转换为四个字母的缩写
"""将日志级别名称转换为四个字母的缩写
Args:
level_name (str): 日志级别名称, 如 "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"
Returns:
str: 四个字母的日志级别缩写
"""
level_map = {
"DEBUG": "DBUG",
@@ -44,12 +79,21 @@ def get_short_level_name(level_name):
class LogBroker:
"""日志代理类, 用于缓存和分发日志消息
发布-订阅模式
"""
def __init__(self):
self.log_cache = deque(maxlen=CACHED_SIZE)
self.subscribers: List[Queue] = []
self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志
self.subscribers: List[Queue] = [] # 订阅者列表
def register(self) -> Queue:
"""给每个订阅者返回一个带有日志缓存的队列"""
"""注册新的订阅者, 并给每个订阅者返回一个带有日志缓存的队列
Returns:
Queue: 订阅者的队列, 可用于接收日志消息
"""
q = Queue(maxsize=CACHED_SIZE + 10)
for log in self.log_cache:
q.put_nowait(log)
@@ -57,11 +101,20 @@ class LogBroker:
return q
def unregister(self, q: Queue):
"""取消订阅"""
"""取消订阅
Args:
q (Queue): 需要取消订阅的队列
"""
self.subscribers.remove(q)
def publish(self, log_entry: str):
"""发布消息"""
def publish(self, log_entry: dict):
"""发布新日志到所有订阅者, 使用非阻塞方式投递, 避免一个订阅者阻塞整个系统
Args:
log_entry (dict): 日志消息, 包含日志级别和日志内容.
example: {"level": "INFO", "data": "This is a log message.", "time": "2023-10-01 12:00:00"}
"""
self.log_cache.append(log_entry)
for q in self.subscribers:
try:
@@ -71,24 +124,57 @@ class LogBroker:
class LogQueueHandler(logging.Handler):
"""日志处理器, 用于将日志消息发送到 LogBroker
继承自 logging.Handler
"""
def __init__(self, log_broker: LogBroker):
super().__init__()
self.log_broker = log_broker
def emit(self, record):
"""日志处理的入口方法, 接受一个日志记录, 转换为字符串后由 LogBroker 发布
这个方法会在每次日志记录时被调用
Args:
record (logging.LogRecord): 日志记录对象, 包含日志信息
"""
log_entry = self.format(record)
self.log_broker.publish(log_entry)
self.log_broker.publish({
"level": record.levelname,
"time": record.asctime,
"data": log_entry,
})
class LogManager:
"""日志管理器, 用于创建和配置日志记录器
提供了获取默认日志记录器logger和设置队列处理器的方法
"""
@classmethod
def GetLogger(cls, log_name: str = "default"):
"""获取指定名称的日志记录器logger
Args:
log_name (str): 日志记录器的名称, 默认为 "default"
Returns:
logging.Logger: 返回配置好的日志记录器
"""
logger = logging.getLogger(log_name)
# 检查该logger或父级logger是否已经有处理器, 如果已经有处理器, 直接返回该logger, 避免重复配置
if logger.hasHandlers():
return logger
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
# 如果logger没有处理器
console_handler = logging.StreamHandler() # 创建一个StreamHandler用于控制台输出
console_handler.setLevel(
logging.DEBUG
) # 将日志级别设置为DEBUG(最低级别, 显示所有日志), *如果插件没有设置级别, 默认为DEBUG
# 创建彩色日志格式化器, 输出日志格式为: [时间] [插件标签] [日志级别] [文件名:行号]: 日志消息
console_formatter = colorlog.ColoredFormatter(
fmt="%(log_color)s [%(asctime)s] %(plugin_tag)s [%(short_levelname)-4s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s",
datefmt="%H:%M:%S",
@@ -96,6 +182,8 @@ class LogManager:
)
class PluginFilter(logging.Filter):
"""插件过滤器类, 用于标记日志来源是插件还是核心组件"""
def filter(self, record):
record.plugin_tag = (
"[Plug]" if is_plugin_path(record.pathname) else "[Core]"
@@ -103,6 +191,9 @@ class LogManager:
return True
class FileNameFilter(logging.Filter):
"""文件名过滤器类, 用于修改日志记录的文件名格式
例如: 将文件路径 /path/to/file.py 转换为 file.<file> 格式"""
# 获取这个文件和父文件夹的名字:<folder>.<file> 并且去除 .py
def filter(self, record):
dirname = os.path.dirname(record.pathname)
@@ -114,22 +205,30 @@ class LogManager:
return True
class LevelNameFilter(logging.Filter):
"""短日志级别名称过滤器类, 用于将日志级别名称转换为四个字母的缩写"""
# 添加短日志级别名称
def filter(self, record):
record.short_levelname = get_short_level_name(record.levelname)
return True
console_handler.setFormatter(console_formatter)
logger.addFilter(PluginFilter())
logger.addFilter(FileNameFilter())
console_handler.setFormatter(console_formatter) # 设置处理器的格式化器
logger.addFilter(PluginFilter()) # 添加插件过滤器
logger.addFilter(FileNameFilter()) # 添加文件名过滤器
logger.addFilter(LevelNameFilter()) # 添加级别名称过滤器
logger.setLevel(logging.DEBUG)
logger.addHandler(console_handler)
logger.setLevel(logging.DEBUG) # 设置日志级别为DEBUG
logger.addHandler(console_handler) # 添加处理器到logger
return logger
@classmethod
def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker):
"""设置队列处理器, 用于将日志消息发送到 LogBroker
Args:
logger (logging.Logger): 日志记录器
log_broker (LogBroker): 日志代理类, 用于缓存和分发日志消息
"""
handler = LogQueueHandler(log_broker)
handler.setLevel(logging.DEBUG)
if logger.handlers:
+109 -4
View File
@@ -25,9 +25,11 @@ SOFTWARE.
import base64
import json
import os
import uuid
import typing as T
from enum import Enum
from pydantic.v1 import BaseModel
from astrbot.core.utils.io import download_image_by_url, file_to_base64
class ComponentType(Enum):
@@ -59,6 +61,8 @@ class ComponentType(Enum):
TTS = "TTS"
Unknown = "Unknown"
WechatEmoji = "WechatEmoji" # Wechat 下的 emoji 表情包
class BaseMessageComponent(BaseModel):
type: ComponentType
@@ -146,6 +150,51 @@ class Record(BaseMessageComponent):
return Record(file=url, **_)
raise Exception("not a valid url")
async def convert_to_file_path(self) -> str:
"""将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。
Returns:
str: 语音的本地路径,以绝对路径表示。
"""
if self.file and self.file.startswith("file:///"):
file_path = self.file[8:]
return file_path
elif self.file and self.file.startswith("http"):
file_path = await download_image_by_url(self.file)
return os.path.abspath(file_path)
elif self.file and self.file.startswith("base64://"):
bs64_data = self.file.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
file_path = f"data/temp/{uuid.uuid4()}.jpg"
with open(file_path, "wb") as f:
f.write(image_bytes)
return os.path.abspath(file_path)
elif os.path.exists(self.file):
file_path = self.file
return os.path.abspath(file_path)
else:
raise Exception(f"not a valid file: {self.file}")
async def convert_to_base64(self) -> str:
"""将语音统一转换为 base64 编码。这个方法避免了手动判断语音数据类型,直接返回语音数据的 base64 编码。
Returns:
str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
"""
# convert to base64
if self.file and self.file.startswith("file:///"):
bs64_data = file_to_base64(self.file[8:])
elif self.file and self.file.startswith("http"):
file_path = await download_image_by_url(self.file)
bs64_data = file_to_base64(file_path)
elif self.file and self.file.startswith("base64://"):
bs64_data = self.file
elif os.path.exists(self.file):
bs64_data = file_to_base64(self.file)
else:
raise Exception(f"not a valid file: {self.file}")
return bs64_data
class Video(BaseMessageComponent):
type: ComponentType = "Video"
@@ -279,10 +328,6 @@ class Image(BaseMessageComponent):
file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识
def __init__(self, file: T.Optional[str], **_):
# for k in _.keys():
# if (k == "_type" and _[k] not in ["flash", "show", None]) or \
# (k == "c" and _[k] not in [2, 3]):
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
super().__init__(file=file, **_)
@staticmethod
@@ -307,6 +352,53 @@ class Image(BaseMessageComponent):
def fromIO(IO):
return Image.fromBytes(IO.read())
async def convert_to_file_path(self) -> str:
"""将这个图片统一转换为本地文件路径。这个方法避免了手动判断图片数据类型,直接返回图片数据的本地路径(如果是网络 URL, 则会自动进行下载)。
Returns:
str: 图片的本地路径,以绝对路径表示。
"""
url = self.url if self.url else self.file
if url and url.startswith("file:///"):
image_file_path = url[8:]
return image_file_path
elif url and url.startswith("http"):
image_file_path = await download_image_by_url(url)
return os.path.abspath(image_file_path)
elif url and url.startswith("base64://"):
bs64_data = url.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
image_file_path = f"data/temp/{uuid.uuid4()}.jpg"
with open(image_file_path, "wb") as f:
f.write(image_bytes)
return os.path.abspath(image_file_path)
elif os.path.exists(url):
image_file_path = url
return os.path.abspath(image_file_path)
else:
raise Exception(f"not a valid file: {url}")
async def convert_to_base64(self) -> str:
"""将这个图片统一转换为 base64 编码。这个方法避免了手动判断图片数据类型,直接返回图片数据的 base64 编码。
Returns:
str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
"""
# convert to base64
url = self.url if self.url else self.file
if url and url.startswith("file:///"):
bs64_data = file_to_base64(url[8:])
elif url and url.startswith("http"):
image_file_path = await download_image_by_url(url)
bs64_data = file_to_base64(image_file_path)
elif url and url.startswith("base64://"):
bs64_data = url
elif os.path.exists(url):
bs64_data = file_to_base64(url)
else:
raise Exception(f"not a valid file: {url}")
return bs64_data
class Reply(BaseMessageComponent):
type: ComponentType = "Reply"
@@ -322,6 +414,8 @@ class Reply(BaseMessageComponent):
"""引用的消息发送时间"""
message_str: T.Optional[str] = ""
"""解析后的纯文本消息字符串"""
sender_str: T.Optional[str] = ""
"""被引用的消息纯文本"""
text: T.Optional[str] = ""
"""deprecated"""
@@ -469,6 +563,16 @@ class File(BaseMessageComponent):
super().__init__(name=name, file=file)
class WechatEmoji(BaseMessageComponent):
type: ComponentType = "WechatEmoji"
md5: T.Optional[str] = ""
md5_len: T.Optional[int] = 0
cdnurl: T.Optional[str] = ""
def __init__(self, **_):
super().__init__(**_)
ComponentTypes = {
"plain": Plain,
"text": Plain,
@@ -497,4 +601,5 @@ ComponentTypes = {
"tts": TTS,
"unknown": Unknown,
"file": File,
"WechatEmoji": WechatEmoji,
}
+37 -6
View File
@@ -1,8 +1,14 @@
import enum
from typing import List, Optional
from typing import List, Optional, Union
from dataclasses import dataclass, field
from astrbot.core.message.components import BaseMessageComponent, Plain, Image
from astrbot.core.message.components import (
BaseMessageComponent,
Plain,
Image,
At,
AtAll,
)
from typing_extensions import deprecated
@@ -31,6 +37,30 @@ class MessageChain:
self.chain.append(Plain(message))
return self
def at(self, name: str, qq: Union[str, int]):
"""添加一条 At 消息到消息链 `chain` 中。
Example:
CommandResult().at("张三", "12345678910")
# 输出 @张三
"""
self.chain.append(At(name=name, qq=qq))
return self
def at_all(self):
"""添加一条 AtAll 消息到消息链 `chain` 中。
Example:
CommandResult().at_all()
# 输出 @所有人
"""
self.chain.append(AtAll())
return self
@deprecated("请使用 message 方法代替。")
def error(self, message: str):
"""添加一条错误消息到消息链 `chain` 中
@@ -77,6 +107,10 @@ class MessageChain:
self.use_t2i_ = use_t2i
return self
def get_plain_text(self) -> str:
"""获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。"""
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
class EventResultType(enum.Enum):
"""用于描述事件处理的结果类型。
@@ -147,9 +181,6 @@ class MessageEventResult(MessageChain):
"""是否为 LLM 结果。"""
return self.result_content_type == ResultContentType.LLM_RESULT
def get_plain_text(self) -> str:
"""获取纯文本消息。这个方法将获取所有 Plain 组件的文本并拼接成一条消息。空格分隔。"""
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
# 为了兼容旧版代码,保留 CommandResult 的别名
CommandResult = MessageEventResult
+1
View File
@@ -12,6 +12,7 @@ from .process_stage.stage import ProcessStage
from .result_decorate.stage import ResultDecorateStage
from .respond.stage import RespondStage
# 管道阶段顺序
STAGES_ORDER = [
"WakingCheckStage", # 检查是否需要唤醒
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
+4 -2
View File
@@ -5,5 +5,7 @@ from astrbot.core.star import PluginManager
@dataclass
class PipelineContext:
astrbot_config: AstrBotConfig
plugin_manager: PluginManager
"""上下文对象,包含管道执行所需的上下文信息"""
astrbot_config: AstrBotConfig # AstrBot 配置对象
plugin_manager: PluginManager # 插件管理器对象
@@ -16,7 +16,13 @@ from astrbot.core.message.message_event_result import (
from astrbot.core.message.components import Image
from astrbot.core import logger
from astrbot.core.utils.metrics import Metric
from astrbot.core.provider.entites import ProviderRequest, LLMResponse
from astrbot.core.provider.entites import (
ProviderRequest,
LLMResponse,
ToolCallMessageSegment,
AssistantMessageSegment,
ToolCallsResult,
)
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
@@ -28,6 +34,9 @@ class LLMRequestSubStage(Stage):
self.provider_wake_prefix = ctx.astrbot_config["provider_settings"][
"wake_prefix"
] # str
self.max_context_length = ctx.astrbot_config["provider_settings"][
"max_context_length"
] # int
for bwp in self.bot_wake_prefixs:
if self.provider_wake_prefix.startswith(bwp):
@@ -64,8 +73,8 @@ class LLMRequestSubStage(Stage):
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
req.image_urls.append(image_url)
image_path = await comp.convert_to_file_path()
req.image_urls.append(image_path)
# 获取对话上下文
conversation_id = await self.conv_manager.get_curr_conversation_id(
@@ -75,10 +84,16 @@ class LLMRequestSubStage(Stage):
conversation_id = await self.conv_manager.new_conversation(
event.unified_msg_origin
)
req.session_id = event.unified_msg_origin
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin, conversation_id
)
if not conversation:
conversation_id = await self.conv_manager.new_conversation(
event.unified_msg_origin
)
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin, conversation_id
)
req.conversation = conversation
req.contexts = json.loads(conversation.history)
@@ -110,33 +125,51 @@ class LLMRequestSubStage(Stage):
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
# max context length
if (
self.max_context_length != -1 # -1 为不限制
and len(req.contexts) // 2 > self.max_context_length
):
logger.debug("上下文长度超过限制,将截断。")
req.contexts = req.contexts[-self.max_context_length * 2 :]
# session_id
if not req.session_id:
req.session_id = event.unified_msg_origin
try:
logger.debug(f"提供商请求 Payload: {req}")
if _nested:
req.func_tool = None # 暂时不支持递归工具调用
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
need_loop = True
while need_loop:
need_loop = False
logger.debug(f"提供商请求 Payload: {req}")
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
# 执行 LLM 响应后的事件钩子。
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnLLMResponseEvent
)
for handler in handlers:
try:
logger.debug(
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event, llm_response)
except BaseException:
logger.error(traceback.format_exc())
# 执行 LLM 响应后的事件钩子。
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnLLMResponseEvent
)
for handler in handlers:
try:
logger.debug(
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event, llm_response)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return
# 保存到历史记录
await self._save_to_history(event, req, llm_response)
async for result in self._handle_llm_response(event, req, llm_response):
if isinstance(result, ProviderRequest):
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
req = result
need_loop = True
else:
yield
asyncio.create_task(
Metric.upload(
@@ -146,72 +179,8 @@ class LLMRequestSubStage(Stage):
)
)
if llm_response.role == "assistant":
# text completion
if llm_response.result_chain:
event.set_result(
MessageEventResult(
chain=llm_response.result_chain.chain
).set_result_content_type(ResultContentType.LLM_RESULT)
)
else:
event.set_result(
MessageEventResult()
.message(llm_response.completion_text)
.set_result_content_type(ResultContentType.LLM_RESULT)
)
elif llm_response.role == "err":
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
)
)
elif llm_response.role == "tool":
# function calling
function_calling_result = {}
logger.info(
f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}"
)
for func_tool_name, func_tool_args in zip(
llm_response.tools_call_name, llm_response.tools_call_args
):
func_tool = req.func_tool.get_func(func_tool_name)
logger.info(
f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
)
try:
# 尝试调用工具函数
wrapper = self._call_handler(
self.ctx, event, func_tool.handler, **func_tool_args
)
async for resp in wrapper:
if resp is not None: # 有 return 返回
function_calling_result[func_tool_name] = resp
else:
yield # 有生成器返回
event.clear_result() # 清除上一个 handler 的结果
except BaseException as e:
logger.warning(traceback.format_exc())
function_calling_result[func_tool_name] = (
"When calling the function, an error occurred: " + str(e)
)
if function_calling_result:
# 工具返回 LLM 资源。比如 RAG、网页 得到的相关结果等。
# 我们重新执行一遍这个 stage
req.func_tool = None # 暂时不支持递归工具调用
extra_prompt = "\n\nSystem executed some external tools for this task and here are the results:\n"
for tool_name, tool_result in function_calling_result.items():
extra_prompt += (
f"Tool: {tool_name}\nTool Result: {tool_result}\n"
)
req.prompt += extra_prompt
async for _ in self.process(event, _nested=True):
yield
else:
if llm_response.completion_text:
event.set_result(
MessageEventResult().message(llm_response.completion_text)
)
# 保存到历史记录
await self._save_to_history(event, req, llm_response)
except BaseException as e:
logger.error(traceback.format_exc())
@@ -222,6 +191,116 @@ class LLMRequestSubStage(Stage):
)
return
async def _handle_llm_response(
self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse
) -> AsyncGenerator[None, None]:
"""处理 LLM 响应。
Returns:
bool: 是否需要继续调用 LLM
Yields:
Iterator[bool]: 将 event 交付给下一个 stage
"""
if llm_response.role == "assistant":
# text completion
if llm_response.result_chain:
event.set_result(
MessageEventResult(
chain=llm_response.result_chain.chain
).set_result_content_type(ResultContentType.LLM_RESULT)
)
else:
event.set_result(
MessageEventResult()
.message(llm_response.completion_text)
.set_result_content_type(ResultContentType.LLM_RESULT)
)
elif llm_response.role == "err":
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
)
)
elif llm_response.role == "tool":
# function calling
tool_call_result: list[ToolCallMessageSegment] = []
logger.info(
f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}"
)
for func_tool_name, func_tool_args, func_tool_id in zip(
llm_response.tools_call_name,
llm_response.tools_call_args,
llm_response.tools_call_ids,
):
try:
func_tool = req.func_tool.get_func(func_tool_name)
if func_tool.origin == "mcp":
logger.info(
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
)
client = req.func_tool.mcp_client_dict[
func_tool.mcp_server_name
]
res = await client.session.call_tool(
func_tool.name, func_tool_args
)
if res:
# TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=res.content[0].text,
)
)
else:
logger.info(
f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
)
# 尝试调用工具函数
wrapper = self._call_handler(
self.ctx, event, func_tool.handler, **func_tool_args
)
async for resp in wrapper:
if resp is not None: # 有 return 返回
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resp,
)
)
else:
yield # 有生成器返回
event.clear_result() # 清除上一个 handler 的结果
except BaseException as e:
logger.warning(traceback.format_exc())
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=f"error: {str(e)}",
)
)
if tool_call_result:
# 函数调用结果
req.func_tool = None # 暂时不支持递归工具调用
assistant_msg_seg = AssistantMessageSegment(
role="assistant", tool_calls=llm_response.to_openai_tool_calls()
)
# 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。
req.tool_calls_result = ToolCallsResult(
tool_calls_info=assistant_msg_seg,
tool_calls_result=tool_call_result,
)
yield req # 再次执行 LLM 请求
else:
if llm_response.completion_text:
event.set_result(
MessageEventResult().message(llm_response.completion_text)
)
async def _save_to_history(
self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse
):
@@ -231,8 +310,12 @@ class LLMRequestSubStage(Stage):
if llm_response.role == "assistant":
# 文本回复
contexts = req.contexts
new_record = {"role": "user", "content": req.prompt}
contexts.append(new_record)
contexts.append(await req.assemble_context())
# tool calls result
if req.tool_calls_result:
contexts.extend(req.tool_calls_result.to_openai_messages())
contexts.append(
{"role": "assistant", "content": llm_response.completion_text}
)
+77 -6
View File
@@ -2,6 +2,7 @@ import random
import asyncio
import math
import traceback
import astrbot.core.message.components as Comp
from typing import Union, AsyncGenerator
from ..stage import register_stage, Stage
from ..context import PipelineContext
@@ -11,11 +12,42 @@ from astrbot.core import logger
from astrbot.core.message.message_event_result import BaseMessageComponent
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.message.components import Plain, Reply, At
@register_stage
class RespondStage(Stage):
# 组件类型到其非空判断函数的映射
_component_validators = {
Comp.Plain: lambda comp: bool(comp.text and comp.text.strip()), # 纯文本消息需要strip
Comp.Face: lambda comp: comp.id is not None, # QQ表情
Comp.Record: lambda comp: bool(comp.file), # 语音
Comp.Video: lambda comp: bool(comp.file), # 视频
Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), # @
Comp.AtAll: lambda comp: True, # @所有人
Comp.RPS: lambda comp: True, # 不知道是啥(未完成)
Comp.Dice: lambda comp: True, # 骰子(未完成)
Comp.Shake: lambda comp: True, # 摇一摇(未完成)
Comp.Anonymous: lambda comp: True, # 匿名(未完成)
Comp.Share: lambda comp: bool(comp.url) and bool(comp.title), # 分享
Comp.Contact: lambda comp: True, # 联系人(未完成)
Comp.Location: lambda comp: bool(comp.lat and comp.lon), # 位置
Comp.Music: lambda comp: bool(comp._type) and bool(comp.url) and bool(comp.audio), # 音乐
Comp.Image: lambda comp: bool(comp.file), # 图片
Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复
Comp.RedBag: lambda comp: bool(comp.title), # 红包
Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳
Comp.Forward: lambda comp: bool(comp.id and comp.id.strip()), # 转发
Comp.Node: lambda comp: bool(comp.name) and comp.uin != 0 and bool(comp.content), # 一个转发节点
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
Comp.Xml: lambda comp: bool(comp.data and comp.data.strip()), # XML
Comp.Json: lambda comp: bool(comp.data), # JSON
Comp.CardImage: lambda comp: bool(comp.file), # 卡片图片
Comp.TTS: lambda comp: bool(comp.text and comp.text.strip()), # 语音合成
Comp.Unknown: lambda comp: bool(comp.text and comp.text.strip()), # 未知消息
Comp.File: lambda comp: bool(comp.file), # 文件
Comp.WechatEmoji: lambda comp: bool(comp.md5), # 微信表情
}
async def initialize(self, ctx: PipelineContext):
self.ctx = ctx
@@ -62,7 +94,7 @@ class RespondStage(Stage):
async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float:
"""分段回复 计算间隔时间"""
if self.interval_method == "log":
if isinstance(comp, Plain):
if isinstance(comp, Comp.Plain):
wc = await self._word_cnt(comp.text)
i = math.log(wc + 1, self.log_base)
return random.uniform(i, i + 0.5)
@@ -72,6 +104,28 @@ class RespondStage(Stage):
# random
return random.uniform(self.interval[0], self.interval[1])
async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]):
"""检查消息链是否为空
Args:
chain (list[BaseMessageComponent]): 包含消息对象的列表
"""
if not chain:
return True
for comp in chain:
comp_type = type(comp)
# 检查组件类型是否在字典中
if comp_type in self._component_validators:
if self._component_validators[comp_type](comp):
return False
else:
logger.info(f"空内容检查: 无法识别的组件类型: {comp_type.__name__}")
# 如果所有组件都为空
return True
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
@@ -82,6 +136,16 @@ class RespondStage(Stage):
if len(result.chain) > 0:
await event._pre_send()
# 检查消息链是否为空
try:
if await self._is_empty_message_chain(result.chain):
logger.info("消息为空,跳过发送阶段")
event.clear_result()
event.stop_event()
return
except Exception as e:
logger.warning(f"空内容检查异常: {e}")
if self.enable_seg and (
(self.only_llm_result and result.is_llm_result())
or not self.only_llm_result
@@ -89,13 +153,13 @@ class RespondStage(Stage):
decorated_comps = []
if self.reply_with_mention:
for comp in result.chain:
if isinstance(comp, At):
if isinstance(comp, Comp.At):
decorated_comps.append(comp)
result.chain.remove(comp)
break
if self.reply_with_quote:
for comp in result.chain:
if isinstance(comp, Reply):
if isinstance(comp, Comp.Reply):
decorated_comps.append(comp)
result.chain.remove(comp)
break
@@ -103,9 +167,16 @@ class RespondStage(Stage):
for comp in result.chain:
i = await self._calc_comp_interval(comp)
await asyncio.sleep(i)
await event.send(MessageChain([*decorated_comps, comp]))
try:
await event.send(MessageChain([*decorated_comps, comp]))
except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
break
else:
await event.send(result)
try:
await event.send(result)
except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
await event._post_send()
logger.info(
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
@@ -31,6 +31,8 @@ class ResultDecorateStage(Stage):
self.t2i_word_threshold = 50
except BaseException:
self.t2i_word_threshold = 150
self.t2i_strategy = ctx.astrbot_config["t2i_strategy"]
self.t2i_use_network = self.t2i_strategy == "remote"
self.forward_threshold = ctx.astrbot_config["platform_settings"][
"forward_threshold"
@@ -192,7 +194,9 @@ class ResultDecorateStage(Stage):
if plain_str and len(plain_str) > self.t2i_word_threshold:
render_start = time.time()
try:
url = await html_renderer.render_t2i(plain_str, return_url=True)
url = await html_renderer.render_t2i(
plain_str, return_url=True, use_network=self.t2i_use_network
)
except BaseException:
logger.error("文本转图片失败,使用文本发送。")
return
@@ -201,7 +205,10 @@ class ResultDecorateStage(Stage):
"文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。"
)
if url:
result.chain = [Image.fromURL(url)]
if url.startswith("http"):
result.chain = [Image.fromURL(url)]
else:
result.chain = [Image.fromFileSystem(url)]
# 触发转发消息
has_forwarded = False
+35 -12
View File
@@ -7,49 +7,72 @@ from astrbot.core import logger
class PipelineScheduler:
"""管道调度器,负责调度各个阶段的执行"""
def __init__(self, context: PipelineContext):
registered_stages.sort(key=lambda x: STAGES_ORDER.index(x.__class__.__name__))
self.ctx = context
registered_stages.sort(
key=lambda x: STAGES_ORDER.index(x.__class__.__name__)
) # 按照顺序排序
self.ctx = context # 上下文对象
async def initialize(self):
"""初始化管道调度器时, 初始化所有阶段"""
for stage in registered_stages:
# logger.debug(f"初始化阶段 {stage.__class__ .__name__}")
await stage.initialize(self.ctx)
async def _process_stages(self, event: AstrMessageEvent, from_stage=0):
"""依次执行各个阶段
Args:
event (AstrMessageEvent): 事件对象
from_stage (int): 从第几个阶段开始执行, 默认从0开始
"""
for i in range(from_stage, len(registered_stages)):
stage = registered_stages[i]
stage = registered_stages[i] # 获取当前要执行的阶段
# logger.debug(f"执行阶段 {stage.__class__ .__name__}")
coro = stage.process(event)
if isinstance(coro, AsyncGenerator):
async for _ in coro:
coroutine = stage.process(
event
) # 调用阶段的process方法, 返回协程或者异步生成器
if isinstance(coroutine, AsyncGenerator):
# 如果返回的是异步生成器, 实现洋葱模型的核心
async for _ in coroutine:
# 此处是前置处理完成后的暂停点(yield), 下面开始执行后续阶段
if event.is_stopped():
logger.debug(
f"阶段 {stage.__class__.__name__} 已终止事件传播。"
)
break
# 递归调用, 处理所有后续阶段
await self._process_stages(event, i + 1)
# 此处是后续所有阶段处理完毕后返回的点, 执行后置处理
if event.is_stopped():
logger.debug(
f"阶段 {stage.__class__.__name__} 已终止事件传播。"
)
break
else:
await coro
# 如果返回的是普通协程(不含yield的async函数), 则不进入下一层(基线条件)
# 简单地等待它执行完成, 然后继续执行下一个阶段
await coroutine
if event.is_stopped():
logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。")
break
if event.is_stopped():
logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。")
break
async def execute(self, event: AstrMessageEvent):
"""执行 pipeline"""
"""执行 pipeline
Args:
event (AstrMessageEvent): 事件对象
"""
await self._process_stages(event)
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
if not event._has_send_oper and event.get_platform_name() == "webchat":
await event.send(None)
+44 -14
View File
@@ -8,8 +8,7 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
from .context import PipelineContext
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
registered_stages: List[Stage] = []
"""维护了所有已注册的 Stage 实现类"""
registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类
def register_stage(cls):
@@ -23,14 +22,24 @@ class Stage(abc.ABC):
@abc.abstractmethod
async def initialize(self, ctx: PipelineContext) -> None:
"""初始化阶段"""
"""初始化阶段
Args:
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
"""
raise NotImplementedError
@abc.abstractmethod
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
"""处理事件"""
"""处理事件
Args:
event (AstrMessageEvent): 事件对象,包含事件的相关信息
Returns:
Union[None, AsyncGenerator[None, None]]: 处理结果,可能是 None 或者异步生成器, 如果为 None 则表示不需要继续处理, 如果为异步生成器则表示需要继续处理(进入下一个阶段)
"""
raise NotImplementedError
async def _call_handler(
@@ -41,9 +50,23 @@ class Stage(abc.ABC):
*args,
**kwargs,
) -> AsyncGenerator[None, None]:
"""调用 Handler。"""
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
ready_to_call = None
"""执行事件处理函数并处理其返回结果
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
1. 异步生成器: 实现洋葱模型,每次yield都会将控制权交回上层
2. 协程: 执行一次并处理返回值
Args:
ctx (PipelineContext): 消息管道上下文对象
event (AstrMessageEvent): 待处理的事件对象
handler (Awaitable): 事件处理函数
*args: 传递给handler的位置参数
**kwargs: 传递给handler的关键字参数
Returns:
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
"""
ready_to_call = None # 一个协程或者异步生成器(async def)
trace_ = None
@@ -52,29 +75,36 @@ class Stage(abc.ABC):
except TypeError as _:
# 向下兼容
trace_ = traceback.format_exc()
# 以前的handler会额外传入一个参数, 但是context对象实际上在插件实例中有一份
ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs)
if isinstance(ready_to_call, AsyncGenerator):
_has_yielded = False
# 如果是一个异步生成器, 进入洋葱模型
_has_yielded = False # 是否返回过值
try:
async for ret in ready_to_call:
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值)
# 这里逐步执行异步生成器, 对于每个yield返回的ret, 执行下面的代码
# 返回值只能是 MessageEventResult 或者 None(无返回值)
_has_yielded = True
if isinstance(ret, (MessageEventResult, CommandResult)):
# 如果返回值是 MessageEventResult, 设置结果并继续
event.set_result(ret)
yield
yield # 传递控制权给上一层的process函数
else:
yield ret
# 如果返回值是 None, 则不设置结果并继续
# 继续执行后续阶段
yield ret # 传递控制权给上一层的process函数
if not _has_yielded:
# 如果这个异步生成器没有执行到yield分支
yield
except Exception as e:
logger.error(f"Previous Error: {trace_}")
raise e
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个 coroutine
# 如果只是一个协程, 直接执行
ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
yield # 传递控制权给上一层的process函数
else:
yield ret
yield ret # 传递控制权给上一层的process函数
@@ -21,6 +21,11 @@ class WakingCheckStage(Stage):
"""
async def initialize(self, ctx: PipelineContext) -> None:
"""初始化唤醒检查阶段
Args:
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
"""
self.ctx = ctx
self.no_permission_reply = self.ctx.astrbot_config["platform_settings"].get(
"no_permission_reply", True
@@ -15,6 +15,9 @@ class WhitelistCheckStage(Stage):
"enable_id_white_list"
]
self.whitelist = ctx.astrbot_config["platform_settings"]["id_whitelist"]
self.whitelist = [
str(i).strip() for i in self.whitelist if str(i).strip() != ""
]
self.wl_ignore_admin_on_group = ctx.astrbot_config["platform_settings"][
"wl_ignore_admin_on_group"
]
@@ -53,7 +56,7 @@ class WhitelistCheckStage(Stage):
return
if (
event.unified_msg_origin not in self.whitelist
and event.get_group_id() not in self.whitelist
and str(event.get_group_id()).strip() not in self.whitelist
):
if self.wl_log:
logger.info(
+2 -1
View File
@@ -1,7 +1,7 @@
from .platform import Platform
from .astr_message_event import AstrMessageEvent
from .platform_metadata import PlatformMetadata
from .astrbot_message import AstrBotMessage, MessageMember, MessageType
from .astrbot_message import AstrBotMessage, MessageMember, MessageType, Group
__all__ = [
"Platform",
@@ -10,4 +10,5 @@ __all__ = [
"AstrBotMessage",
"MessageMember",
"MessageType",
"Group",
]
+31 -16
View File
@@ -1,11 +1,9 @@
import abc
import asyncio
from dataclasses import dataclass
from .astrbot_message import AstrBotMessage
from .platform_metadata import PlatformMetadata
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
from astrbot.core.platform.message_type import MessageType
from typing import List, Union
from typing import List, Union, Optional
from astrbot.core.db.po import Conversation
from astrbot.core.message.components import (
Plain,
Image,
@@ -16,9 +14,12 @@ from astrbot.core.message.components import (
Forward,
Reply,
)
from astrbot.core.utils.metrics import Metric
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
from astrbot.core.platform.message_type import MessageType
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core.db.po import Conversation
from astrbot.core.utils.metrics import Metric
from .astrbot_message import AstrBotMessage, Group
from .platform_metadata import PlatformMetadata
@dataclass
@@ -201,15 +202,6 @@ class AstrMessageEvent(abc.ABC):
"""
return self.role == "admin"
async def send(self, message: MessageChain):
"""
发送消息到消息平台。
"""
asyncio.create_task(
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
)
self._has_send_oper = True
async def _pre_send(self):
"""调度器会在执行 send() 前调用该方法"""
@@ -371,3 +363,26 @@ class AstrMessageEvent(abc.ABC):
system_prompt=system_prompt,
conversation=conversation,
)
"""平台适配器"""
async def send(self, message: MessageChain):
"""发送消息到消息平台。
Args:
message (MessageChain): 消息链,具体使用方式请参考文档。
"""
asyncio.create_task(
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
)
self._has_send_oper = True
async def get_group(self, group_id: str = None, **kwargs) -> Optional[Group]:
"""获取一个群聊的数据, 如果不填写 group_id: 如果是私聊消息,返回 None。如果是群聊消息,返回当前群聊的数据。
适配情况:
- gewechat
- aiocqhttp(OneBotv11)
"""
...
+35
View File
@@ -10,6 +10,41 @@ class MessageMember:
user_id: str # 发送者id
nickname: str = None
def __str__(self):
# 使用 f-string 来构建返回的字符串表示形式
return (
f"User ID: {self.user_id},"
f"Nickname: {self.nickname if self.nickname else 'N/A'}"
)
@dataclass
class Group:
group_id: str
"""群号"""
group_name: str = None
"""群名称"""
group_avatar: str = None
"""群头像"""
group_owner: str = None
"""群主 id"""
group_admins: List[str] = None
"""群管理员 id"""
members: List[MessageMember] = None
"""所有群成员"""
def __str__(self):
# 使用 f-string 来构建返回的字符串表示形式
return (
f"Group ID: {self.group_id}\n"
f"Name: {self.group_name if self.group_name else 'N/A'}\n"
f"Avatar: {self.group_avatar if self.group_avatar else 'N/A'}\n"
f"Owner ID: {self.group_owner if self.group_owner else 'N/A'}\n"
f"Admin IDs: {self.group_admins if self.group_admins else 'N/A'}\n"
f"Members Len: {len(self.members) if self.members else 0}\n"
f"First Member: {self.members[0] if self.members else 'N/A'}\n"
)
class AstrBotMessage:
"""
+40 -32
View File
@@ -85,14 +85,18 @@ class PlatformManager:
)
return
cls_type = platform_cls_map[platform_config["type"]]
inst = cls_type(platform_config, self.settings, self.event_queue)
self._inst_map[platform_config["id"]] = inst
inst: Platform = cls_type(platform_config, self.settings, self.event_queue)
self._inst_map[platform_config["id"]] = {
"inst": inst,
"client_id": inst.client_self_id,
}
self.platform_insts.append(inst)
asyncio.create_task(
self._task_wrapper(
asyncio.create_task(
inst.run(), name=platform_config["id"] + "_platform"
inst.run(),
name=f"platform_{platform_config['type']}_{platform_config['id']}",
)
)
)
@@ -109,38 +113,42 @@ class PlatformManager:
logger.error("-------")
async def reload(self, platform_config: dict):
# 还未实现完成,不要调用此方法
if platform_config["id"] in self._inst_map:
# 正在运行
if getattr(self._inst_map[platform_config["id"]], "terminate", None):
logger.info(f"正在尝试终止 {platform_config['id']} 平台适配器 ...")
await self._inst_map[platform_config["id"]].terminate()
logger.info(f"{platform_config['id']} 平台适配器已终止。")
del self._inst_map[platform_config["id"]]
self.platform_insts.remove(self._inst_map[platform_config["id"]])
else:
logger.warning(f"可能无法正常终止 {platform_config['id']} 平台适配器。")
# 再启动新的实例
await self.terminate_platform(platform_config["id"])
if platform_config["enable"]:
await self.load_platform(platform_config)
else:
# 先将 _inst_map 中在 platform_config 中不存在的实例删除
config_ids = [platform["id"] for platform in self.platforms_config]
for key in list(self._inst_map.keys()):
if key not in config_ids:
if getattr(self._inst_map[key], "terminate", None):
logger.info(f"正在尝试终止 {key} 平台适配器 ...")
await self._inst_map[key].terminate()
logger.info(f"{key} 平台适配器已终止。")
del self._inst_map[key]
self.platform_insts.remove(self._inst_map[key])
else:
logger.warning(f"可能无法正常终止 {key} 平台适配器。")
# 和配置文件保持同步
config_ids = [provider["id"] for provider in self.platforms_config]
for key in list(self._inst_map.keys()):
if key not in config_ids:
await self.terminate_platform(key)
# 再启动新的实例
await self.load_platform(platform_config)
async def terminate_platform(self, platform_id: str):
if platform_id in self._inst_map:
logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...")
# client_id = self._inst_map.pop(platform_id, None)
info = self._inst_map.pop(platform_id, None)
client_id = info["client_id"]
inst = info["inst"]
try:
self.platform_insts.remove(
next(
inst
for inst in self.platform_insts
if inst.client_self_id == client_id
)
)
except Exception:
logger.warning(f"可能未完全移除 {platform_id} 平台适配器")
if getattr(inst, "terminate", None):
await inst.terminate()
async def terminate(self):
for inst in self.platform_insts:
if getattr(inst, "terminate", None):
await inst.terminate()
def get_insts(self):
return self.platform_insts
+3 -1
View File
@@ -1,4 +1,5 @@
import abc
import uuid
from typing import Awaitable, Any
from asyncio import Queue
from .platform_metadata import PlatformMetadata
@@ -13,6 +14,7 @@ class Platform(abc.ABC):
super().__init__()
# 维护了消息平台的事件队列,EventBus 会从这里取出事件并处理。
self._event_queue = event_queue
self.client_self_id = uuid.uuid4().hex
@abc.abstractmethod
def run(self) -> Awaitable[Any]:
@@ -25,7 +27,7 @@ class Platform(abc.ABC):
"""
终止一个平台的运行实例。
"""
pass
...
@abc.abstractmethod
def meta(self) -> PlatformMetadata:
@@ -1,9 +1,9 @@
import asyncio
import typing
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import Group, MessageMember
from astrbot.api.message_components import Plain, Image, Record, At, Node, Nodes
from aiocqhttp import CQHttp
from astrbot.core.utils.io import file_to_base64, download_image_by_url
class AiocqhttpMessageEvent(AstrMessageEvent):
@@ -22,20 +22,14 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
if isinstance(segment, Plain):
d["type"] = "text"
d["data"]["text"] = segment.text.strip()
# 如果是空文本或者只带换行符的文本,不发送
if not d["data"]["text"]:
continue
elif isinstance(segment, (Image, Record)):
# convert to base64
if segment.file and segment.file.startswith("file:///"):
bs64_data = file_to_base64(segment.file[8:])
image_file_path = segment.file[8:]
elif segment.file and segment.file.startswith("http"):
image_file_path = await download_image_by_url(segment.file)
bs64_data = file_to_base64(image_file_path)
elif segment.file and segment.file.startswith("base64://"):
bs64_data = segment.file
else:
bs64_data = file_to_base64(segment.file)
bs64 = await segment.convert_to_base64()
d["data"] = {
"file": bs64_data,
"file": bs64,
}
elif isinstance(segment, At):
d["data"] = {
@@ -47,6 +41,9 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
async def send(self, message: MessageChain):
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
if not ret:
return
send_one_by_one = False
for seg in message.chain:
if isinstance(seg, (Node, Nodes)):
@@ -84,3 +81,46 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
await self.bot.send(self.message_obj.raw_message, ret)
await super().send(message)
async def get_group(self, group_id=None, **kwargs):
if isinstance(group_id, str) and group_id.isdigit():
group_id = int(group_id)
elif self.get_group_id():
group_id = int(self.get_group_id())
else:
return None
info: dict = await self.bot.call_action(
"get_group_info",
group_id=group_id,
)
members: typing.List[typing.Dict] = await self.bot.call_action(
"get_group_member_list",
group_id=group_id,
)
owner_id = None
admin_ids = []
for member in members:
if member["role"] == "owner":
owner_id = member["user_id"]
if member["role"] == "admin":
admin_ids.append(member["user_id"])
group = Group(
group_id=str(group_id),
group_name=info.get("group_name"),
group_avatar="",
group_admins=admin_ids,
group_owner=str(owner_id),
members=[
MessageMember(
user_id=member["user_id"],
nickname=member.get("nickname") or member.get("card"),
)
for member in members
],
)
return group
@@ -43,8 +43,6 @@ class AiocqhttpAdapter(Platform):
"适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
)
self.stop = False
self.bot = CQHttp(
use_ws_reverse=True, import_name="aiocqhttp", api_timeout_sec=180
)
@@ -303,22 +301,19 @@ class AiocqhttpAdapter(Platform):
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
logging.getLogger("aiocqhttp").setLevel(logging.ERROR)
self.shutdown_event = asyncio.Event()
return coro
async def terminate(self):
self.stop = True
await asyncio.sleep(1)
self.shutdown_event.set()
async def shutdown_trigger_placeholder(self):
await self.shutdown_event.wait()
logger.info("aiocqhttp 适配器已被优雅地关闭")
def meta(self) -> PlatformMetadata:
return self.metadata
async def shutdown_trigger_placeholder(self):
# TODO: use asyncio.Event
while not self._event_queue.closed and not self.stop: # noqa: ASYNC110
await asyncio.sleep(1)
logger.info("aiocqhttp 适配器已关闭。")
async def handle_msg(self, message: AstrBotMessage):
message_event = AiocqhttpMessageEvent(
message_str=message.message_str,
@@ -2,6 +2,7 @@ import asyncio
import uuid
import aiohttp
import dingtalk_stream
import threading
from astrbot.api.platform import (
Platform,
@@ -196,7 +197,31 @@ class DingtalkPlatformAdapter(Platform):
self._event_queue.put_nowait(event)
async def run(self):
await self.client_.start()
# await self.client_.start()
# 钉钉的 SDK 并没有实现真正的异步,start() 里面有堵塞方法。
def start_client(loop: asyncio.AbstractEventLoop):
try:
self._shutdown_event = threading.Event()
task = loop.create_task(self.client_.start())
self._shutdown_event.wait()
if task.done():
task.result()
except Exception as e:
if "Graceful shutdown" in str(e):
logger.info("钉钉适配器已被优雅地关闭")
return
logger.error(f"钉钉机器人启动失败: {e}")
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, start_client, loop)
async def terminate(self):
def monkey_patch_close():
raise Exception("Graceful shutdown")
self.client_.open_connection = monkey_patch_close
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
self._shutdown_event.set()
def get_client(self):
return self.client
@@ -24,7 +24,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
if isinstance(segment, Comp.Plain):
segment.text = segment.text.strip()
await asyncio.get_event_loop().run_in_executor(
None, client.reply_text, segment.text, self.message_obj.raw_message
None, client.reply_markdown, "AstrBot", segment.text, self.message_obj.raw_message
)
elif isinstance(segment, Comp.Image):
markdown_str = ""
+264 -33
View File
@@ -1,17 +1,26 @@
import threading
import asyncio
import aiohttp
import quart
import base64
import datetime
import re
import os
import re
import threading
import aiohttp
import anyio
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
from astrbot.api.message_components import Plain, Image, At, Record
import quart
from astrbot.api import logger, sp
from .downloader import GeweDownloader
from astrbot.api.message_components import Plain, Image, At, Record, Video
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
from astrbot.core.utils.io import download_image_by_url
from .downloader import GeweDownloader
try:
from .xml_data_parser import GeweDataParser
except (ImportError, ModuleNotFoundError) as e:
logger.warning(
f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {str(e)}"
)
class SimpleGewechatClient:
@@ -51,11 +60,11 @@ class SimpleGewechatClient:
self.server = quart.Quart(__name__)
self.server.add_url_rule(
"/astrbot-gewechat/callback", view_func=self.callback, methods=["POST"]
"/astrbot-gewechat/callback", view_func=self._callback, methods=["POST"]
)
self.server.add_url_rule(
"/astrbot-gewechat/file/<file_id>",
view_func=self.handle_file,
view_func=self._handle_file,
methods=["GET"],
)
@@ -70,9 +79,10 @@ class SimpleGewechatClient:
self.userrealnames = {}
self.stop = False
self.shutdown_event = asyncio.Event()
async def get_token_id(self):
"""获取 Gewechat Token。"""
async with aiohttp.ClientSession() as session:
async with session.post(f"{self.base_url}/tools/getTokenId") as resp:
json_blob = await resp.json()
@@ -192,6 +202,11 @@ class SimpleGewechatClient:
abm.sender = MessageMember(user_id, user_real_name)
abm.raw_message = d
abm.message_str = ""
if user_id == "weixin":
# 忽略微信团队消息
return
# 不同消息类型
match d["MsgType"]:
case 1:
@@ -209,15 +224,10 @@ class SimpleGewechatClient:
case 34:
# 语音消息
# data = await self.multimedia_downloader.download_voice(
# self.appid,
# content,
# abm.message_id
# )
# print(data)
if "ImgBuf" in d and "buffer" in d["ImgBuf"]:
voice_data = base64.b64decode(d["ImgBuf"]["buffer"])
file_path = f"data/temp/gewe_voice_{abm.message_id}.silk"
async with await anyio.open_file(file_path, "wb") as f:
await f.write(voice_data)
abm.message.append(Record(file=file_path, url=file_path))
@@ -228,15 +238,19 @@ class SimpleGewechatClient:
case 42: # 名片
logger.info("消息类型(42):名片")
case 43: # 视频
logger.info("消息类型(43):视频")
video = Video(file="", cover=content)
abm.message.append(video)
case 47: # emoji
logger.info("消息类型(47)emoji")
data_parser = GeweDataParser(content, abm.group_id == "")
emoji = data_parser.parse_emoji()
abm.message.append(emoji)
case 48: # 地理位置
logger.info("消息类型(48):地理位置")
case 49: # 公众号/文件/小程序/引用/转账/红包/视频号/群聊邀请
logger.info(
"消息类型(49):公众号/文件/小程序/引用/转账/红包/视频号/群聊邀请"
)
data_parser = GeweDataParser(content, abm.group_id == "")
abm_data = data_parser.parse_mutil_49()
if abm_data:
abm.message.append(abm_data)
case 51: # 帐号消息同步?
logger.info("消息类型(51):帐号消息同步?")
case 10000: # 被踢出群聊/更换群主/修改群名称
@@ -253,7 +267,7 @@ class SimpleGewechatClient:
logger.debug(f"abm: {abm}")
return abm
async def callback(self):
async def _callback(self):
data = await quart.request.json
logger.debug(f"收到 gewechat 回调: {data}")
@@ -275,7 +289,7 @@ class SimpleGewechatClient:
return quart.jsonify({"r": "AstrBot ACK"})
async def handle_file(self, file_id):
async def _handle_file(self, file_id):
file_path = f"data/temp/{file_id}"
return await quart.send_file(file_path)
@@ -301,17 +315,14 @@ class SimpleGewechatClient:
await self.server.run_task(
host="0.0.0.0",
port=self.port,
shutdown_trigger=self.shutdown_trigger_placeholder,
shutdown_trigger=self.shutdown_trigger,
)
async def shutdown_trigger_placeholder(self):
# TODO: use asyncio.Event
while not self.event_queue.closed and not self.stop: # noqa: ASYNC110
await asyncio.sleep(1)
logger.info("gewechat 适配器已关闭。")
async def shutdown_trigger(self):
await self.shutdown_event.wait()
async def check_online(self, appid: str):
# /login/checkOnline
"""检查 APPID 对应的设备是否在线。"""
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/login/checkOnline",
@@ -322,6 +333,7 @@ class SimpleGewechatClient:
return json_blob["data"]
async def logout(self):
"""登出 gewechat。"""
if self.appid:
online = await self.check_online(self.appid)
if online:
@@ -335,6 +347,7 @@ class SimpleGewechatClient:
logger.info(f"登出结果: {json_blob}")
async def login(self):
"""登录 gewechat。一般来说插件用不到这个方法。"""
if self.token is None:
await self.get_token_id()
@@ -446,9 +459,18 @@ class SimpleGewechatClient:
self.appid = appid
logger.info(f"已保存 APPID: {appid}")
"""API"""
"""API 部分。Gewechat 的 API 文档请参考: https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1
"""
async def get_chatroom_member_list(self, chatroom_wxid: str):
async def get_chatroom_member_list(self, chatroom_wxid: str) -> dict:
"""获取群成员列表。
Args:
chatroom_wxid (str): 微信群聊的id。可以通过 event.get_group_id() 获取。
Returns:
dict: 返回群成员列表字典。其中键为 memberList 的值为群成员列表。
"""
payload = {"appId": self.appid, "chatroomId": chatroom_wxid}
async with aiohttp.ClientSession() as session:
@@ -461,6 +483,7 @@ class SimpleGewechatClient:
return json_blob["data"]
async def post_text(self, to_wxid, content: str, ats: str = ""):
"""发送纯文本消息"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
@@ -477,6 +500,7 @@ class SimpleGewechatClient:
logger.debug(f"发送消息结果: {json_blob}")
async def post_image(self, to_wxid, image_url: str):
"""发送图片消息"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
@@ -490,7 +514,79 @@ class SimpleGewechatClient:
json_blob = await resp.json()
logger.debug(f"发送图片结果: {json_blob}")
async def post_emoji(self, to_wxid, emoji_md5, emoji_size, cdnurl=""):
"""发送emoji消息"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"emojiMd5": emoji_md5,
"emojiSize": emoji_size,
}
# 优先表情包,若拿不到表情包的md5,就用当作图片发
try:
if emoji_md5 != "" and emoji_size != "":
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postEmoji",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.info(
f"发送emoji消息结果: {json_blob.get('msg', '操作失败')}"
)
else:
await self.post_image(to_wxid, cdnurl)
except Exception as e:
logger.error(e)
async def post_video(
self, to_wxid, video_url: str, thumb_url: str, video_duration: int
):
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"videoUrl": video_url,
"thumbUrl": thumb_url,
"videoDuration": video_duration,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postVideo", headers=self.headers, json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送视频结果: {json_blob}")
async def forward_video(self, to_wxid, cnd_xml: str):
"""转发视频
Args:
to_wxid (str): 发送给谁
cnd_xml (str): 视频消息的cdn信息
"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"xml": cnd_xml,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/forwardVideo",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"转发视频结果: {json_blob}")
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
"""发送语音信息
Args:
voice_url (str): 语音文件的网络链接
voice_duration (int): 语音时长,毫秒
"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
@@ -505,9 +601,16 @@ class SimpleGewechatClient:
f"{self.base_url}/message/postVoice", headers=self.headers, json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送语音结果: {json_blob}")
logger.info(f"发送语音结果: {json_blob.get('msg', '操作失败')}")
async def post_file(self, to_wxid, file_url: str, file_name: str):
"""发送文件
Args:
to_wxid (string): 微信ID
file_url (str): 文件的网络链接
file_name (str): 文件名
"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
@@ -521,3 +624,131 @@ class SimpleGewechatClient:
) as resp:
json_blob = await resp.json()
logger.debug(f"发送文件结果: {json_blob}")
async def add_friend(self, v3: str, v4: str, content: str):
"""申请添加好友"""
payload = {
"appId": self.appid,
"scene": 3,
"content": content,
"v4": v4,
"v3": v3,
"option": 2,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/contacts/addContacts",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"申请添加好友结果: {json_blob}")
return json_blob
async def get_group(self, group_id: str):
payload = {
"appId": self.appid,
"chatroomId": group_id,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/getChatroomInfo",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def get_group_member(self, group_id: str):
payload = {
"appId": self.appid,
"chatroomId": group_id,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/getChatroomMemberList",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def accept_group_invite(self, url: str):
"""同意进群"""
payload = {"appId": self.appid, "url": url}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/agreeJoinRoom",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def add_group_member_to_friend(
self, group_id: str, to_wxid: str, content: str
):
payload = {
"appId": self.appid,
"chatroomId": group_id,
"content": content,
"memberWxid": to_wxid,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/addGroupMemberAsFriend",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def get_user_or_group_info(self, *ids):
"""
获取用户或群组信息。
:param ids: 可变数量的 wxid 参数
"""
wxids_str = list(ids)
payload = {
"appId": self.appid,
"wxids": wxids_str, # 使用逗号分隔的字符串
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/contacts/getDetailInfo",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def get_contacts_list(self):
"""
获取通讯录列表
见 https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1/api-196794504
"""
payload = {"appId": self.appid}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/contacts/fetchContactsList",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取通讯录列表结果: {json_blob}")
return json_blob
@@ -39,3 +39,17 @@ class GeweDownloader:
continue
raise Exception("无法下载图片")
async def download_emoji_md5(self, app_id, emoji_md5):
"""下载emoji"""
try:
payload = {"appId": app_id, "emojiMd5": emoji_md5}
# gewe 计划中的接口,暂时没有实现。返回代码404
data = await self._post_json(
self.base_url, "/message/downloadEmojiMd5", payload
)
json_blob = json.loads(data)
return json_blob
except BaseException as e:
logger.error(f"gewe download emoji: {e}")
@@ -2,12 +2,21 @@ import wave
import uuid
import traceback
import os
from astrbot.core.utils.io import save_temp_img, download_image_by_url, download_file
from astrbot.core.utils.io import save_temp_img, download_file
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, Record, At, File
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, Group, MessageMember
from astrbot.api.message_components import (
Plain,
Image,
Record,
At,
File,
Video,
WechatEmoji as Emoji,
)
from .client import SimpleGewechatClient
@@ -70,18 +79,10 @@ class GewechatPlatformEvent(AstrMessageEvent):
await client.post_text(**payload)
elif isinstance(comp, Image):
img_url = comp.file
img_path = ""
if img_url.startswith("file:///"):
img_path = img_url[8:]
elif comp.file and comp.file.startswith("http"):
img_path = await download_image_by_url(comp.file)
else:
img_path = img_url
img_path = await comp.convert_to_file_path()
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
# 检查 record_path 是否在 data/temp 目录中
temp_directory = os.path.abspath("data/temp")
img_path = os.path.abspath(img_path)
if os.path.commonpath([temp_directory, img_path]) != temp_directory:
with open(img_path, "rb") as f:
img_path = save_temp_img(f.read())
@@ -90,17 +91,65 @@ class GewechatPlatformEvent(AstrMessageEvent):
img_url = f"{client.file_server_url}/{file_id}"
logger.debug(f"gewe callback img url: {img_url}")
await client.post_image(to_wxid, img_url)
elif isinstance(comp, Video):
if comp.cover != "":
await client.forward_video(to_wxid, comp.cover)
else:
try:
from pyffmpeg import FFmpeg
except (ImportError, ModuleNotFoundError):
logger.error(
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
)
raise ModuleNotFoundError(
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
)
video_url = comp.file
# 根据 url 下载视频
video_filename = f"{uuid.uuid4()}.mp4"
video_path = f"data/temp/{video_filename}"
await download_file(video_url, video_path)
# 获取视频第一帧
thumb_path = f"data/temp/{uuid.uuid4()}.jpg"
try:
ff = FFmpeg()
command = f'-i "{video_path}" -ss 0 -vframes 1 "{thumb_path}"'
ff.options(command)
thumb_file_id = os.path.basename(thumb_path)
thumb_url = f"{client.file_server_url}/{thumb_file_id}"
except Exception as e:
logger.error(f"获取视频第一帧失败: {e}")
# 获取视频时长
try:
from pyffmpeg import FFprobe
# 创建 FFprobe 实例
ffprobe = FFprobe(video_url)
# 获取时长字符串
duration_str = ffprobe.duration
# 处理时长字符串
video_duration = float(duration_str.replace(":", ""))
except Exception as e:
logger.error(f"获取时长失败: {e}")
video_duration = 10
file_id = os.path.basename(video_path)
video_url = f"{client.file_server_url}/{file_id}"
await client.post_video(
to_wxid, video_url, thumb_url, video_duration
)
# 删除临时视频和缩略图文件
if os.path.exists(video_path):
os.remove(video_path)
if os.path.exists(thumb_path):
os.remove(thumb_path)
elif isinstance(comp, Record):
# 默认已经存在 data/temp 中
record_url = comp.file
record_path = ""
if record_url.startswith("file:///"):
record_path = record_url[8:]
elif record_url.startswith("http"):
await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav")
else:
record_path = record_url
record_path = await comp.convert_to_file_path()
silk_path = f"data/temp/{uuid.uuid4()}.silk"
try:
@@ -129,6 +178,8 @@ class GewechatPlatformEvent(AstrMessageEvent):
file_url = f"{client.file_server_url}/{file_id}"
logger.debug(f"gewe callback file url: {file_url}")
await client.post_file(to_wxid, file_url, file_id)
elif isinstance(comp, Emoji):
await client.post_emoji(to_wxid, comp.md5, comp.md5_len, comp.cdnurl)
elif isinstance(comp, At):
pass
else:
@@ -138,3 +189,30 @@ class GewechatPlatformEvent(AstrMessageEvent):
to_wxid = self.message_obj.raw_message.get("to_wxid", None)
await GewechatPlatformEvent.send_with_client(message, to_wxid, self.client)
await super().send(message)
async def get_group(self, group_id=None, **kwargs):
# 确定有效的 group_id
if group_id is None:
group_id = self.get_group_id()
if not group_id:
return None
res = await self.client.get_group(group_id)
data: dict = res["data"]
if not data["chatroomId"]:
return None
members = [
MessageMember(user_id=member["wxid"], nickname=member["nickName"])
for member in data.get("memberList", [])
]
return Group(
group_id=data["chatroomId"],
group_name=data.get("nickName"),
group_avatar=data.get("smallHeadImgUrl"),
group_owner=data.get("chatRoomOwner"),
members=members,
)
@@ -8,6 +8,7 @@ from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from .gewechat_event import GewechatPlatformEvent
from .client import SimpleGewechatClient
from astrbot import logger
if sys.version_info >= (3, 12):
from typing import override
@@ -64,8 +65,9 @@ class GewechatPlatformAdapter(Platform):
)
async def terminate(self):
self.client.stop = True
await asyncio.sleep(1)
self.client.shutdown_event.set()
await self.client.server.shutdown()
logger.info("Gewechat 适配器已被优雅地关闭。")
async def logout(self):
await self.client.logout()
@@ -0,0 +1,78 @@
from defusedxml import ElementTree as eT
from astrbot.api import logger
from astrbot.api.message_components import WechatEmoji as Emoji, Reply, Plain
class GeweDataParser:
def __init__(self, data, is_private_chat):
self.data = data
self.is_private_chat = is_private_chat
def _format_to_xml(self):
return eT.fromstring(self.data)
def parse_mutil_49(self):
appmsg_type = self._format_to_xml().find(".//appmsg/type")
if appmsg_type is None:
return
match appmsg_type.text:
case "57":
return self.parse_reply()
def parse_emoji(self) -> Emoji | None:
try:
emoji_element = self._format_to_xml().find(".//emoji")
# 提取 md5 和 len 属性
if emoji_element is not None:
md5_value = emoji_element.get("md5")
emoji_size = emoji_element.get("len")
cdnurl = emoji_element.get("cdnurl")
return Emoji(md5=md5_value, md5_len=emoji_size, cdnurl=cdnurl)
except Exception as e:
logger.error(f"gewechat: parse_emoji failed, {e}")
def parse_reply(self) -> Reply | None:
try:
replied_id = -1
replied_uid = 0
replied_nickname = ""
replied_content = ""
content = ""
root = self._format_to_xml()
refermsg = root.find(".//refermsg")
if refermsg is not None:
# 被引用的信息
svrid = refermsg.find("svrid")
fromusr = refermsg.find("fromusr")
displayname = refermsg.find("displayname")
refermsg_content = refermsg.find("content")
if svrid is not None:
replied_id = svrid.text
if fromusr is not None:
replied_uid = fromusr.text
if displayname is not None:
replied_nickname = displayname.text
if refermsg_content is not None:
replied_content = refermsg_content.text
# 提取引用者说的内容
title = root.find(".//appmsg/title")
if title is not None:
content = title.text
r = Reply(
id=replied_id,
chain=[Plain(content)],
sender_id=replied_uid,
sender_nickname=replied_nickname,
sender_str=replied_content,
message_str=content,
)
return r
except Exception as e:
logger.error(f"gewechat: parse_reply failed, {e}")
@@ -2,6 +2,7 @@ import base64
import asyncio
import json
import re
import astrbot.api.message_components as Comp
from astrbot.api.platform import (
Platform,
@@ -11,7 +12,6 @@ from astrbot.api.platform import (
PlatformMetadata,
)
from astrbot.api.event import MessageChain
from astrbot.api.message_components import Image, Plain, At
from astrbot.core.platform.astr_message_event import MessageSesion
from .lark_event import LarkMessageEvent
from ...register import register_platform_adapter
@@ -92,7 +92,7 @@ class LarkPlatformAdapter(Platform):
at_list = {}
if message.mentions:
for m in message.mentions:
at_list[m.key] = At(qq=m.id.open_id, name=m.name)
at_list[m.key] = Comp.At(qq=m.id.open_id, name=m.name)
if m.name == self.bot_name:
abm.self_id = m.id.open_id
@@ -111,7 +111,7 @@ class LarkPlatformAdapter(Platform):
if s in at_list:
abm.message.append(at_list[s])
else:
abm.message.append(Plain(parts[i].strip()))
abm.message.append(Comp.Plain(parts[i].strip()))
elif message.message_type == "post":
_ls = []
@@ -132,7 +132,7 @@ class LarkPlatformAdapter(Platform):
if comp["tag"] == "at":
abm.message.append(at_list[comp["user_id"]])
elif comp["tag"] == "text" and comp["text"].strip():
abm.message.append(Plain(comp["text"].strip()))
abm.message.append(Comp.Plain(comp["text"].strip()))
elif comp["tag"] == "img":
image_key = comp["image_key"]
request = (
@@ -147,10 +147,10 @@ class LarkPlatformAdapter(Platform):
logger.error(f"无法下载飞书图片: {image_key}")
image_bytes = response.file.read()
image_base64 = base64.b64encode(image_bytes).decode()
abm.message.append(Image.fromBase64(image_base64))
abm.message.append(Comp.Image.fromBase64(image_base64))
for comp in abm.message:
if isinstance(comp, Plain):
if isinstance(comp, Comp.Plain):
abm.message_str += comp.text
abm.message_id = message.message_id
abm.raw_message = message
@@ -185,5 +185,9 @@ class LarkPlatformAdapter(Platform):
# self.client.start()
await self.client._connect()
async def terminate(self):
await self.client._disconnect()
logger.info("飞书(Lark) 适配器已被优雅地关闭")
def get_client(self) -> lark.Client:
return self.client
@@ -17,6 +17,7 @@ from astrbot.api.platform import (
MessageType,
PlatformMetadata,
)
from astrbot import logger
from astrbot.api.event import MessageChain
from typing import Union, List
from astrbot.api.message_components import Image, Plain, At
@@ -204,3 +205,7 @@ class QQOfficialPlatformAdapter(Platform):
def get_client(self) -> botClient:
return self.client
async def terminate(self):
await self.client.close()
logger.info("QQ 官方机器人接口 适配器已被优雅地关闭")
@@ -13,6 +13,7 @@ from .qo_webhook_event import QQOfficialWebhookMessageEvent
from ...register import register_platform_adapter
from .qo_webhook_server import QQOfficialWebhook
from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter
from astrbot import logger
# remove logger handler
for handler in logging.root.handlers[:]:
@@ -111,3 +112,9 @@ class QQOfficialWebhookPlatformAdapter(Platform):
def get_client(self) -> botClient:
return self.client
async def terminate(self):
self.webhook_helper.shutdown_event.set()
await self.client.close()
await self.webhook_helper.server.shutdown()
logger.info("QQ 机器人官方 API 适配器已经被优雅地关闭")
@@ -30,6 +30,7 @@ class QQOfficialWebhook:
)
self.client = botpy_client
self.event_queue = event_queue
self.shutdown_event = asyncio.Event()
async def initialize(self):
logger.info("正在登录到 QQ 官方机器人...")
@@ -102,10 +103,8 @@ class QQOfficialWebhook:
await self.server.run_task(
host=self.callback_server_host,
port=self.port,
shutdown_trigger=self.shutdown_trigger_placeholder,
shutdown_trigger=self.shutdown_trigger,
)
async def shutdown_trigger_placeholder(self):
while not self.event_queue.closed: # noqa: ASYNC110
await asyncio.sleep(1)
logger.info("qq_official_webhook 适配器已关闭。")
async def shutdown_trigger(self):
await self.shutdown_event.wait()
@@ -1,6 +1,7 @@
import sys
import uuid
import asyncio
import astrbot.api.message_components as Comp
from astrbot.api.platform import (
Platform,
@@ -10,15 +11,6 @@ from astrbot.api.platform import (
MessageType,
)
from astrbot.api.event import MessageChain
from astrbot.api.message_components import (
Plain,
Image,
Record,
File as AstrBotFile,
Video,
At,
Reply,
)
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.api.platform import register_platform_adapter
@@ -108,7 +100,8 @@ class TelegramPlatformAdapter(Platform):
async def message_handler(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
logger.debug(f"Telegram message: {update.message}")
abm = await self.convert_message(update, context)
await self.handle_msg(abm)
if abm:
await self.handle_msg(abm)
async def convert_message(
self, update: Update, context: ContextTypes.DEFAULT_TYPE, get_reply=True
@@ -120,6 +113,7 @@ class TelegramPlatformAdapter(Platform):
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
"""
message = AstrBotMessage()
message.session_id = str(update.message.chat.id)
# 获得是群聊还是私聊
if update.message.chat.type == ChatType.PRIVATE:
message.type = MessageType.FRIEND_MESSAGE
@@ -129,9 +123,9 @@ class TelegramPlatformAdapter(Platform):
if update.message.message_thread_id:
# Topic Group
message.group_id += "#" + str(update.message.message_thread_id)
message.session_id = message.group_id
message.message_id = str(update.message.message_id)
message.session_id = str(update.message.chat.id)
message.sender = MessageMember(
str(update.message.from_user.id), update.message.from_user.username
)
@@ -140,7 +134,11 @@ class TelegramPlatformAdapter(Platform):
message.message_str = ""
message.message = []
if update.message.reply_to_message:
if update.message.reply_to_message and not (
update.message.is_topic_message
and update.message.message_thread_id
== update.message.reply_to_message.message_id
):
# 获取回复消息
reply_update = Update(
update_id=1,
@@ -149,7 +147,7 @@ class TelegramPlatformAdapter(Platform):
reply_abm = await self.convert_message(reply_update, context, False)
message.message.append(
Reply(
Comp.Reply(
id=reply_abm.message_id,
chain=reply_abm.message,
sender_id=reply_abm.sender.user_id,
@@ -171,43 +169,60 @@ class TelegramPlatformAdapter(Platform):
name = plain_text[
entity.offset + 1 : entity.offset + entity.length
]
message.message.append(At(qq=name, name=name))
message.message.append(Comp.At(qq=name, name=name))
plain_text = (
plain_text[: entity.offset]
+ plain_text[entity.offset + entity.length :]
)
if plain_text:
message.message.append(Plain(plain_text))
message.message.append(Comp.Plain(plain_text))
message.message_str = plain_text
if message.message_str == "/start":
if message.message_str.strip() == "/start":
await self.start(update, context)
return
elif update.message.voice:
file = await update.message.voice.get_file()
message.message = [
Record(file=file.file_path, url=file.file_path),
Comp.Record(file=file.file_path, url=file.file_path),
]
elif update.message.photo:
photo = update.message.photo[-1] # get the largest photo
file = await photo.get_file()
message.message.append(Image(file=file.file_path, url=file.file_path))
message.message.append(Comp.Image(file=file.file_path, url=file.file_path))
if update.message.caption:
message.message_str = update.message.caption
message.message.append(Comp.Plain(message.message_str))
if update.message.caption_entities:
for entity in update.message.caption_entities:
if entity.type == "mention":
name = message.message_str[
entity.offset + 1 : entity.offset + entity.length
]
message.message.append(Comp.At(qq=name, name=name))
elif update.message.sticker:
# 将sticker当作图片处理
file = await update.message.sticker.get_file()
message.message.append(Comp.Image(file=file.file_path, url=file.file_path))
if update.message.sticker.emoji:
sticker_text = f"Sticker: {update.message.sticker.emoji}"
message.message_str = sticker_text
message.message.append(Comp.Plain(sticker_text))
elif update.message.document:
file = await update.message.document.get_file()
message.message = [
AstrBotFile(
file=file.file_path, name=update.message.document.file_name
),
Comp.File(file=file.file_path, name=update.message.document.file_name),
]
elif update.message.video:
file = await update.message.video.get_file()
message.message = [
Video(file=file.file_path, path=file.file_path),
Comp.Video(file=file.file_path, path=file.file_path),
]
return message
@@ -224,3 +239,15 @@ class TelegramPlatformAdapter(Platform):
def get_client(self) -> ExtBot:
return self.client
async def terminate(self):
try:
await self.application.stop()
# 保险起见先判断是否存在updater对象
if self.application.updater is not None:
await self.application.updater.stop()
logger.info("Telegram 适配器已被优雅地关闭")
except Exception as e:
logger.error(f"Telegram 适配器关闭时出错: {e}")
@@ -1,8 +1,10 @@
import telegramify_markdown
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, MessageType
from astrbot.api.message_components import Plain, Image, Reply, At, File, Record
from telegram.ext import ExtBot
from astrbot.core.utils.io import download_file
from astrbot import logger
class TelegramPlatformEvent(AstrMessageEvent):
@@ -43,27 +45,26 @@ class TelegramPlatformEvent(AstrMessageEvent):
if has_reply:
payload["reply_to_message_id"] = reply_message_id
if message_thread_id:
payload["reply_to_message_id"] = message_thread_id
payload["message_thread_id"] = message_thread_id
if isinstance(i, Plain):
if at_user_id and not at_flag:
i.text = f"@{at_user_id} " + i.text
at_flag = True
await client.send_message(text=i.text, **payload)
text = i.text
try:
text = telegramify_markdown.markdownify(
i.text, max_line_length=None, normalize_whitespace=False
)
except Exception as e:
logger.warning(
f"MarkdownV2 conversion failed: {e}. Using plain text instead."
)
return
await client.send_message(text=text, parse_mode="MarkdownV2", **payload)
elif isinstance(i, Image):
if i.path:
image_path = i.path
else:
image_path = i.file
if image_path.startswith("base64://"):
import base64
base64_data = image_path[9:]
image_bytes = base64.b64decode(base64_data)
await client.send_photo(photo=image_bytes, **payload)
else:
await client.send_photo(photo=image_path, **payload)
image_path = await i.convert_to_file_path()
await client.send_photo(photo=image_path, **payload)
elif isinstance(i, File):
if i.file.startswith("https://"):
path = "data/temp/" + i.name
@@ -72,7 +73,8 @@ class TelegramPlatformEvent(AstrMessageEvent):
await client.send_document(document=i.file, filename=i.name, **payload)
elif isinstance(i, Record):
await client.send_voice(voice=i.file, **payload)
path = await i.convert_to_file_path()
await client.send_voice(voice=path, **payload)
async def send(self, message: MessageChain):
if self.get_message_type() == MessageType.GROUP_MESSAGE:
@@ -119,3 +119,7 @@ class WebChatAdapter(Platform):
)
self.commit_event(message_event)
async def terminate(self):
# Do nothing
pass
@@ -3,7 +3,7 @@ import uuid
import base64
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image
from astrbot.api.message_components import Plain, Image, Record
from astrbot.core.utils.io import download_image_by_url
from astrbot.core import web_chat_back_queue
@@ -47,6 +47,22 @@ class WebChatMessageEvent(AstrMessageEvent):
with open(comp.file, "rb") as f2:
f.write(f2.read())
web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid))
elif isinstance(comp, Record):
# save record to local
filename = str(uuid.uuid4()) + ".wav"
path = os.path.join(imgs_dir, filename)
if comp.file and comp.file.startswith("file:///"):
ph = comp.file[8:]
with open(path, "wb") as f:
with open(ph, "rb") as f2:
f.write(f2.read())
elif comp.file and comp.file.startswith("http"):
await download_image_by_url(comp.file, path=path)
else:
with open(path, "wb") as f:
with open(comp.file, "rb") as f2:
f.write(f2.read())
web_chat_back_queue.put_nowait((f"[RECORD]{filename}", cid))
else:
logger.debug(f"webchat 忽略: {comp.type}")
web_chat_back_queue.put_nowait(None)
@@ -50,6 +50,7 @@ class WecomServer:
)
self.callback = None
self.shutdown_event = asyncio.Event()
async def verify(self):
logger.info(f"验证请求有效性: {quart.request.args}")
@@ -93,13 +94,11 @@ class WecomServer:
await self.server.run_task(
host=self.callback_server_host,
port=self.port,
shutdown_trigger=self.shutdown_trigger_placeholder,
shutdown_trigger=self.shutdown_trigger,
)
async def shutdown_trigger_placeholder(self):
while not self.event_queue.closed: # noqa: ASYNC110
await asyncio.sleep(1)
logger.info("企业微信 适配器已关闭。")
async def shutdown_trigger(self):
await self.shutdown_event.wait()
@register_platform_adapter("wecom", "wecom 适配器")
@@ -235,3 +234,8 @@ class WecomPlatformAdapter(Platform):
def get_client(self) -> WeChatClient:
return self.client
async def terminate(self):
self.server.shutdown_event.set()
await self.server.server.shutdown()
logger.info("企业微信 适配器已被优雅地关闭")
@@ -3,7 +3,6 @@ from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, Record
from wechatpy.enterprise import WeChatClient
from astrbot.core.utils.io import download_image_by_url, download_file
from astrbot.api import logger
@@ -43,14 +42,7 @@ class WecomPlatformEvent(AstrMessageEvent):
message_obj.self_id, message_obj.session_id, comp.text
)
elif isinstance(comp, Image):
img_url = comp.file
img_path = ""
if img_url.startswith("file:///"):
img_path = img_url[8:]
elif comp.file and comp.file.startswith("http"):
img_path = await download_image_by_url(comp.file)
else:
img_path = img_url
img_path = await comp.convert_to_file_path()
with open(img_path, "rb") as f:
try:
@@ -68,16 +60,7 @@ class WecomPlatformEvent(AstrMessageEvent):
response["media_id"],
)
elif isinstance(comp, Record):
record_url = comp.file
record_path = ""
if record_url.startswith("file:///"):
record_path = record_url[8:]
elif record_url.startswith("http"):
await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav")
else:
record_path = record_url
record_path = await comp.convert_to_file_path()
# 转成amr
record_path_amr = f"data/temp/{uuid.uuid4()}.amr"
pydub.AudioSegment.from_wav(record_path).export(
+203 -4
View File
@@ -1,10 +1,18 @@
import enum
import base64
import json
from astrbot.core.utils.io import download_image_by_url
from astrbot import logger
from dataclasses import dataclass, field
from typing import List, Dict, Type
from .func_tool_manager import FuncCall
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
)
from astrbot.core.db.po import Conversation
from astrbot.core.message.message_event_result import MessageChain
import astrbot.core.message.components as Comp
class ProviderType(enum.Enum):
@@ -28,6 +36,58 @@ class ProviderMetaData:
"""显示在 WebUI 配置页中的提供商名称,如空则是 type"""
@dataclass
class ToolCallMessageSegment:
"""OpenAI 格式的上下文中 role 为 tool 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
tool_call_id: str
content: str
role: str = "tool"
def to_dict(self):
return {
"tool_call_id": self.tool_call_id,
"content": self.content,
"role": self.role,
}
@dataclass
class AssistantMessageSegment:
"""OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
content: str = None
tool_calls: List[ChatCompletionMessageToolCall | Dict] = None
role: str = "assistant"
def to_dict(self):
ret = {
"role": self.role,
}
if self.content:
ret["content"] = self.content
elif self.tool_calls:
ret["tool_calls"] = self.tool_calls
return ret
@dataclass
class ToolCallsResult:
"""工具调用结果"""
tool_calls_info: AssistantMessageSegment
"""函数调用的信息"""
tool_calls_result: List[ToolCallMessageSegment]
"""函数调用的结果"""
def to_openai_messages(self) -> List[Dict]:
ret = [
self.tool_calls_info.to_dict(),
*[item.to_dict() for item in self.tool_calls_result],
]
return ret
@dataclass
class ProviderRequest:
prompt: str
@@ -37,7 +97,7 @@ class ProviderRequest:
image_urls: List[str] = None
"""图片 URL 列表"""
func_tool: FuncCall = None
"""工具"""
"""可用的函数工具"""
contexts: List = None
"""上下文。格式与 openai 的上下文格式一致:
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
@@ -46,12 +106,85 @@ class ProviderRequest:
"""系统提示词"""
conversation: Conversation = None
tool_calls_result: ToolCallsResult = None
"""附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls"""
def __repr__(self):
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self.contexts}, system_prompt={self.system_prompt.strip()})"
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
def __str__(self):
return self.__repr__()
def _print_friendly_context(self):
"""打印友好的消息上下文。将 image_url 的值替换为 <Image>"""
if not self.contexts:
return f"prompt: {self.prompt}, image_count: {len(self.image_urls or [])}"
result_parts = []
for ctx in self.contexts:
role = ctx.get("role", "unknown")
content = ctx.get("content", "")
if isinstance(content, str):
result_parts.append(f"{role}: {content}")
elif isinstance(content, list):
msg_parts = []
image_count = 0
for item in content:
item_type = item.get("type", "")
if item_type == "text":
msg_parts.append(item.get("text", ""))
elif item_type == "image_url":
image_count += 1
if image_count > 0:
if msg_parts:
msg_parts.append(f"[+{image_count} images]")
else:
msg_parts.append(f"[{image_count} images]")
result_parts.append(f"{role}: {''.join(msg_parts)}")
return result_parts
async def assemble_context(self) -> Dict:
"""将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。"""
if self.image_urls:
user_content = {
"role": "user",
"content": [{"type": "text", "text": self.prompt}],
}
for image_url in self.image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self._encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self._encode_image_bs64(image_path)
else:
image_data = await self._encode_image_bs64(image_url)
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
user_content["content"].append(
{"type": "image_url", "image_url": {"url": image_data}}
)
return user_content
else:
return {"role": "user", "content": self.prompt}
async def _encode_image_bs64(self, image_url: str) -> str:
"""将图片转换为 base64"""
if image_url.startswith("base64://"):
return image_url.replace("base64://", "data:image/jpeg;base64,")
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
return "data:image/jpeg;base64," + image_bs64
return ""
@dataclass
class LLMResponse:
@@ -59,12 +192,78 @@ class LLMResponse:
"""角色, assistant, tool, err"""
result_chain: MessageChain = None
"""返回的消息链"""
completion_text: str = ""
"""LLM 返回的文本, 已经废弃但仍然兼容。使用 result_chain 替代"""
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
"""工具调用参数"""
tools_call_name: List[str] = field(default_factory=list)
"""工具调用名称"""
tools_call_ids: List[str] = field(default_factory=list)
"""工具调用 ID"""
raw_completion: ChatCompletion = None
_new_record: Dict[str, any] = None
_completion_text: str = ""
def __init__(
self,
role: str,
completion_text: str = "",
result_chain: MessageChain = None,
tools_call_args: List[Dict[str, any]] = [],
tools_call_name: List[str] = [],
tools_call_ids: List[str] = [],
raw_completion: ChatCompletion = None,
_new_record: Dict[str, any] = None,
):
"""初始化 LLMResponse
Args:
role (str): 角色, assistant, tool, err
completion_text (str, optional): 返回的结果文本,已经过时,推荐使用 result_chain. Defaults to "".
result_chain (MessageChain, optional): 返回的消息链. Defaults to None.
tools_call_args (List[Dict[str, any]], optional): 工具调用参数. Defaults to None.
tools_call_name (List[str], optional): 工具调用名称. Defaults to None.
raw_completion (ChatCompletion, optional): 原始响应, OpenAI 格式. Defaults to None.
"""
self.role = role
self.completion_text = completion_text
self.result_chain = result_chain
self.tools_call_args = tools_call_args
self.tools_call_name = tools_call_name
self.tools_call_ids = tools_call_ids
self.raw_completion = raw_completion
self._new_record = _new_record
@property
def completion_text(self):
if self.result_chain:
return self.result_chain.get_plain_text()
return self._completion_text
@completion_text.setter
def completion_text(self, value):
if self.result_chain:
self.result_chain.chain = [
comp
for comp in self.result_chain.chain
if not isinstance(comp, Comp.Plain)
] # 清空 Plain 组件
self.result_chain.chain.insert(0, Comp.Plain(value))
else:
self._completion_text = value
def to_openai_tool_calls(self) -> List[Dict]:
"""将工具调用信息转换为 OpenAI 格式"""
ret = []
for idx, tool_call_arg in enumerate(self.tools_call_args):
ret.append(
{
"id": self.tools_call_ids[idx],
"function": {
"name": self.tools_call_name[idx],
"arguments": json.dumps(tool_call_arg),
},
"type": "function",
}
)
return ret
+279 -21
View File
@@ -1,9 +1,31 @@
from __future__ import annotations
import json
import textwrap
from typing import Dict, List, Awaitable
import os
import asyncio
import copy
from typing import Dict, List, Awaitable, Literal, Any
from dataclasses import dataclass
from typing import Optional
from contextlib import AsyncExitStack
from astrbot import logger
try:
import mcp
except (ModuleNotFoundError, ImportError):
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
SUPPORTED_TYPES = [
"string",
"number",
"object",
"array",
"boolean",
] # json schema 支持的数据类型
@dataclass
class FuncTool:
@@ -14,28 +36,101 @@ class FuncTool:
name: str
parameters: Dict
description: str
handler: Awaitable
handler_module_path: str = None # 必须要保留这个,handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
handler: Awaitable = None
"""处理函数, 当 origin 为 mcp 时,这个为空"""
handler_module_path: str = None
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
"""
active: bool = True
"""是否激活"""
origin: Literal["local", "mcp"] = "local"
"""函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
# MCP 相关字段
mcp_server_name: str = None
"""MCP 服务名称,当 origin 为 mcp 时有效"""
mcp_client: MCPClient = None
"""MCP 客户端,当 origin 为 mcp 时有效"""
def __repr__(self):
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}), active={self.active})"
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
async def execute(self, **args) -> Any:
"""执行函数调用"""
if self.origin == "local":
if not self.handler:
raise Exception(f"Local function {self.name} has no handler")
return await self.handler(**args)
elif self.origin == "mcp":
if not self.mcp_client or not self.mcp_client.session:
raise Exception(f"MCP client for {self.name} is not available")
# 使用name属性而不是额外的mcp_tool_name
if ":" in self.name:
# 如果名字是格式为 mcp:server:tool_name,提取实际的工具名
actual_tool_name = self.name.split(":")[-1]
return await self.mcp_client.session.call_tool(actual_tool_name, args)
else:
return await self.mcp_client.session.call_tool(self.name, args)
else:
raise Exception(f"Unknown function origin: {self.origin}")
SUPPORTED_TYPES = [
"string",
"number",
"object",
"array",
"boolean",
] # json schema 支持的数据类型
class MCPClient:
def __init__(self):
# Initialize session and client objects
self.session: Optional[mcp.ClientSession] = None
self.exit_stack = AsyncExitStack()
self.name = None
self.active: bool = True
self.tools: List[mcp.Tool] = []
async def connect_to_server(self, mcp_server_config: dict):
"""Connect to an MCP server
Args:
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
"""
cfg = mcp_server_config.copy()
cfg.pop("active", None)
server_params = mcp.StdioServerParameters(
**cfg,
)
stdio_transport = await self.exit_stack.enter_async_context(
mcp.stdio_client(server_params)
)
self.stdio, self.write = stdio_transport
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(self.stdio, self.write)
)
await self.session.initialize()
async def list_tools_and_save(self) -> mcp.ListToolsResult:
"""List all tools from the server and save them to self.tools"""
response = await self.session.list_tools()
logger.debug(f"MCP server {self.name} list tools response: {response}")
self.tools = response.tools
return response
async def cleanup(self):
"""Clean up resources"""
await self.exit_stack.aclose()
class FuncCall:
def __init__(self) -> None:
self.func_list: List[FuncTool] = []
"""内部加载的 func tools"""
self.mcp_client_dict: Dict[str, MCPClient] = {}
"""MCP 服务列表"""
self.mcp_service_queue = asyncio.Queue()
"""用于外部控制 MCP 服务的启停"""
self.mcp_client_event: Dict[str, asyncio.Event] = {}
def empty(self) -> bool:
return len(self.func_list) == 0
@@ -90,11 +185,166 @@ class FuncCall:
return f
return None
async def _init_mcp_clients(self) -> None:
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
```
{
"mcpServers": {
"weather": {
"command": "uv",
"args": [
"--directory",
"/ABSOLUTE/PATH/TO/PARENT/FOLDER/weather",
"run",
"weather.py"
]
}
}
...
}
```
"""
current_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.abspath(os.path.join(current_dir, "../../../data"))
mcp_json_file = os.path.join(data_dir, "mcp_server.json")
if not os.path.exists(mcp_json_file):
# 配置文件不存在错误处理
with open(mcp_json_file, "w", encoding="utf-8") as f:
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}")
return
mcp_server_json_obj: Dict[str, Dict] = json.load(
open(mcp_json_file, "r", encoding="utf-8")
)["mcpServers"]
for name in mcp_server_json_obj.keys():
cfg = mcp_server_json_obj[name]
if cfg.get("active", True):
event = asyncio.Event()
asyncio.create_task(
self._init_mcp_client_task_wrapper(name, cfg, event)
)
self.mcp_client_event[name] = event
async def mcp_service_selector(self):
"""为了避免在不同异步任务中控制 MCP 服务导致的报错,整个项目统一通过这个 Task 来控制
使用 self.mcp_service_queue.put_nowait() 来控制 MCP 服务的启停,数据格式如下:
{"type": "init"} 初始化所有MCP客户端
{"type": "init", "name": "mcp_server_name", "cfg": {...}} 初始化指定的MCP客户端
{"type": "terminate"} 终止所有MCP客户端
{"type": "terminate", "name": "mcp_server_name"} 终止指定的MCP客户端
"""
while True:
data = await self.mcp_service_queue.get()
if data["type"] == "init":
if "name" in data:
event = asyncio.Event()
asyncio.create_task(
self._init_mcp_client_task_wrapper(
data["name"], data["cfg"], event
)
)
self.mcp_client_event[data["name"]] = event
else:
await self._init_mcp_clients()
elif data["type"] == "terminate":
if "name" in data:
# await self._terminate_mcp_client(data["name"])
if data["name"] in self.mcp_client_event:
self.mcp_client_event[data["name"]].set()
self.mcp_client_event.pop(data["name"], None)
else:
for name in self.mcp_client_dict.keys():
# await self._terminate_mcp_client(name)
# self.mcp_client_event[name].set()
if name in self.mcp_client_event:
self.mcp_client_event[name].set()
self.mcp_client_event.pop(name, None)
async def _init_mcp_client_task_wrapper(
self, name: str, cfg: dict, event: asyncio.Event
) -> None:
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
try:
await self._init_mcp_client(name, cfg)
await event.wait()
logger.info(f"收到 MCP 客户端 {name} 终止信号")
await self._terminate_mcp_client(name)
except Exception as e:
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
async def _init_mcp_client(self, name: str, config: dict) -> None:
"""初始化单个MCP客户端"""
try:
# 先清理之前的客户端,如果存在
if name in self.mcp_client_dict:
await self._terminate_mcp_client(name)
mcp_client = MCPClient()
mcp_client.name = name
await mcp_client.connect_to_server(config)
tools_res = await mcp_client.list_tools_and_save()
tool_names = [tool.name for tool in tools_res.tools]
self.mcp_client_dict[name] = mcp_client
# 移除该MCP服务之前的工具(如有)
self.func_list = [
f
for f in self.func_list
if not (f.origin == "mcp" and f.mcp_server_name == name)
]
# 将 MCP 工具转换为 FuncTool 并添加到 func_list
for tool in mcp_client.tools:
func_tool = FuncTool(
name=tool.name,
parameters=tool.inputSchema,
description=tool.description,
origin="mcp",
mcp_server_name=name,
mcp_client=mcp_client,
)
self.func_list.append(func_tool)
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
return True
except Exception as e:
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
# 发生错误时确保客户端被清理
if name in self.mcp_client_dict:
await self._terminate_mcp_client(name)
return False
async def _terminate_mcp_client(self, name: str) -> None:
"""关闭并清理MCP客户端"""
if name in self.mcp_client_dict:
try:
# 关闭MCP连接
await self.mcp_client_dict[name].cleanup()
del self.mcp_client_dict[name]
except Exception as e:
logger.info(f"清空 MCP 客户端资源 {name}: {e}")
# 移除关联的FuncTool
self.func_list = [
f
for f in self.func_list
if not (f.origin == "mcp" and f.mcp_server_name == name)
]
logger.info(f"已关闭 MCP 服务 {name}")
def get_func_desc_openai_style(self) -> list:
"""
获得 OpenAI API 风格的**已经激活**的工具描述
"""
_l = []
# 处理所有工具(包括本地和MCP工具)
for f in self.func_list:
if not f.active:
continue
@@ -144,7 +394,13 @@ class FuncCall:
# 检查并添加非空的properties参数
params = f.parameters if isinstance(f.parameters, dict) else {}
params = copy.deepcopy(params)
if params.get("properties", {}):
properties = params["properties"]
for key, value in properties.items():
if "default" in value:
del value["default"]
params["properties"] = properties
func_declaration["parameters"] = params
tools.append(func_declaration)
@@ -160,9 +416,9 @@ class FuncCall:
continue
_l.append(
{
"name": f["name"],
"parameters": f["parameters"],
"description": f["description"],
"name": f.name,
"parameters": f.parameters,
"description": f.description,
}
)
func_definition = json.dumps(_l, ensure_ascii=False)
@@ -212,14 +468,11 @@ class FuncCall:
func_name = tool["name"]
args = tool["args"]
# 调用函数
tool_callable = None
for func in self.func_list:
if func.name == func_name:
tool_callable = func.star_handler_metadata.handler
break
if not tool_callable:
func_tool = self.get_func(func_name)
if not func_tool:
raise Exception(f"Request function {func_name} not found.")
ret = await tool_callable(**args)
ret = await func_tool.execute(**args)
if ret:
tool_call_result.append(str(ret))
return tool_call_result, True
@@ -229,3 +482,8 @@ class FuncCall:
def __repr__(self):
return str(self.func_list)
async def terminate(self):
for name in self.mcp_client_dict.keys():
await self._terminate_mcp_client(name)
logger.debug(f"清理 MCP 客户端 {name} 资源")
+45
View File
@@ -1,4 +1,5 @@
import traceback
import asyncio
from astrbot.core.config.astrbot_config import AstrBotConfig
from .provider import Provider, STTProvider, TTSProvider, Personality
from .entites import ProviderType
@@ -127,6 +128,12 @@ class ProviderManager:
if self.tts_enabled and not self.curr_tts_provider_inst:
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
# 初始化 MCP Client 连接
asyncio.create_task(
self.llm_tools.mcp_service_selector(), name="mcp-service-handler"
)
self.llm_tools.mcp_service_queue.put_nowait({"type": "init"})
async def load_provider(self, provider_config: dict):
if not provider_config["enable"]:
return
@@ -191,6 +198,10 @@ class ProviderManager:
from .sources.fishaudio_tts_api_source import (
ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI,
)
case "dashscope_tts":
from .sources.dashscope_tts import (
ProviderDashscopeTTSAPI as ProviderDashscopeTTSAPI,
)
except (ImportError, ModuleNotFoundError) as e:
logger.critical(
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
@@ -299,10 +310,42 @@ class ProviderManager:
if len(self.provider_insts) == 0:
self.curr_provider_inst = None
elif (
self.curr_provider_inst is None
and len(self.provider_insts) > 0
and self.provider_enabled
):
self.curr_provider_inst = self.provider_insts[0]
self.selected_provider_id = self.curr_provider_inst.meta().id
logger.info(
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。"
)
if len(self.stt_provider_insts) == 0:
self.curr_stt_provider_inst = None
elif (
self.curr_stt_provider_inst is None
and len(self.stt_provider_insts) > 0
and self.stt_enabled
):
self.curr_stt_provider_inst = self.stt_provider_insts[0]
self.selected_stt_provider_id = self.curr_stt_provider_inst.meta().id
logger.info(
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。"
)
if len(self.tts_provider_insts) == 0:
self.curr_tts_provider_inst = None
elif (
self.curr_tts_provider_inst is None
and len(self.tts_provider_insts) > 0
and self.tts_enabled
):
self.curr_tts_provider_inst = self.tts_provider_insts[0]
self.selected_tts_provider_id = self.curr_tts_provider_inst.meta().id
logger.info(
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。"
)
def get_insts(self):
return self.provider_insts
@@ -339,3 +382,5 @@ class ProviderManager:
for provider_inst in self.provider_insts:
if hasattr(provider_inst, "terminate"):
await provider_inst.terminate()
# 清理 MCP Client 连接
await self.llm_tools.mcp_service_queue.put({"type": "terminate"})
+3 -1
View File
@@ -3,7 +3,7 @@ from typing import List
from astrbot.core.db import BaseDatabase
from typing import TypedDict
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.entites import LLMResponse
from astrbot.core.provider.entites import LLMResponse, ToolCallsResult
from dataclasses import dataclass
@@ -90,6 +90,7 @@ class Provider(AbstractProvider):
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult = None,
**kwargs,
) -> LLMResponse:
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
@@ -100,6 +101,7 @@ class Provider(AbstractProvider):
image_urls: 图片 URL 列表
tools: Function-calling 工具
contexts: 上下文
tool_calls_result: 回传给 LLM 的工具调用结果参考: https://platform.openai.com/docs/guides/function-calling
kwargs: 其他参数
Notes:
@@ -10,7 +10,7 @@ from astrbot.api.provider import Provider, Personality
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from ..register import register_provider_adapter
from astrbot.core.provider.entites import LLMResponse
from astrbot.core.provider.entites import LLMResponse, ToolCallsResult
from .openai_source import ProviderOpenAIOfficial
@@ -79,11 +79,14 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
# tools call (function calling)
args_ls = []
func_name_ls = []
tool_use_ids = []
func_name_ls.append(content.name)
args_ls.append(content.input)
tool_use_ids.append(content.id)
llm_response.role = "tool"
llm_response.tools_call_args = args_ls
llm_response.tools_call_name = func_name_ls
llm_response.tools_call_ids = tool_use_ids
if not llm_response.completion_text and not llm_response.tools_call_args:
logger.error(f"API 返回的 completion 无法解析:{completion}")
@@ -101,6 +104,7 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
func_tool: FuncCall = None,
contexts=[],
system_prompt=None,
tool_calls_result: ToolCallsResult = None,
**kwargs,
) -> LLMResponse:
if not prompt:
@@ -113,6 +117,10 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
if "_no_save" in part:
del part["_no_save"]
if tool_calls_result:
# 暂时这样写。
prompt += f"Here are the related results via using tools: {str(tool_calls_result.tool_calls_result)}"
model_config = self.provider_config.get("model_config", {})
payloads = {"messages": context_query, **model_config}
@@ -1,3 +1,4 @@
import re
import asyncio
import functools
from typing import List
@@ -40,11 +41,28 @@ class ProviderDashscope(ProviderOpenAIOfficial):
raise Exception("阿里云百炼 APP 类型不能为空。")
self.model_name = "dashscope"
self.variables: dict = provider_config.get("variables", {})
self.rag_options: dict = provider_config.get("rag_options", {})
self.output_reference = self.rag_options.get("output_reference", False)
self.rag_options = self.rag_options.copy()
self.rag_options.pop("output_reference", None)
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
def has_rag_options(self):
"""判断是否有 RAG 选项
Returns:
bool: 是否有 RAG 选项
"""
if self.rag_options and (
len(self.rag_options.get("pipeline_ids", [])) > 0
or len(self.rag_options.get("file_ids", [])) > 0
):
return True
return False
async def text_chat(
self,
prompt: str,
@@ -62,7 +80,10 @@ class ProviderDashscope(ProviderOpenAIOfficial):
session_var = session_vars.get(session_id, {})
payload_vars.update(session_var)
if self.dashscope_app_type in ["agent", "dialog-workflow"]:
if (
self.dashscope_app_type in ["agent", "dialog-workflow"]
and not self.has_rag_options()
):
# 支持多轮对话的
new_record = {"role": "user", "content": prompt}
if image_urls:
@@ -75,23 +96,31 @@ class ProviderDashscope(ProviderOpenAIOfficial):
if "_no_save" in part:
del part["_no_save"]
# 调用阿里云百炼 API
payload = {
"app_id": self.app_id,
"api_key": self.api_key,
"messages": context_query,
"biz_params": payload_vars or None,
}
partial = functools.partial(
Application.call,
app_id=self.app_id,
api_key=self.api_key,
messages=context_query,
biz_params=payload_vars or None,
**payload,
)
response = await asyncio.get_event_loop().run_in_executor(None, partial)
else:
# 不支持多轮对话的
# 调用阿里云百炼 API
payload = {
"app_id": self.app_id,
"prompt": prompt,
"api_key": self.api_key,
"biz_params": payload_vars or None,
}
if self.rag_options:
payload["rag_options"] = self.rag_options
partial = functools.partial(
Application.call,
app_id=self.app_id,
promtp=prompt,
api_key=self.api_key,
biz_params=payload_vars or None,
**payload,
)
response = await asyncio.get_event_loop().run_in_executor(None, partial)
@@ -107,6 +136,15 @@ class ProviderDashscope(ProviderOpenAIOfficial):
)
output_text = response.output.get("text", "")
# RAG 引用脚标格式化
output_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", output_text)
if self.output_reference and response.output.get("doc_references", None):
ref_str = ""
for ref in response.output.get("doc_references", []):
ref_title = ref.get("title", "") if ref.get("title") else ref.get("doc_name", "")
ref_str += f"{ref['index_id']}. {ref_title}\n"
output_text += f"\n\n回答来源:\n{ref_str}"
return LLMResponse(role="assistant", completion_text=output_text)
async def forget(self, session_id):
@@ -0,0 +1,39 @@
import dashscope
import uuid
import asyncio
from dashscope.audio.tts_v2 import *
from ..provider import TTSProvider
from ..entites import ProviderType
from ..register import register_provider_adapter
@register_provider_adapter(
"dashscope_tts", "Dashscope TTS API", provider_type=ProviderType.TEXT_TO_SPEECH
)
class ProviderDashscopeTTSAPI(TTSProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.chosen_api_key: str = provider_config.get("api_key", "")
self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella")
self.set_model(provider_config.get("model", None))
self.timeout_ms = float(provider_config.get("timeout", 20))*1000
dashscope.api_key = self.chosen_api_key
self.synthesizer = SpeechSynthesizer(
model=self.get_model(),
voice=self.voice,
format=AudioFormat.WAV_24000HZ_MONO_16BIT,
)
async def get_audio(self, text: str) -> str:
path = f"data/temp/dashscope_tts_{uuid.uuid4()}.wav"
audio = await asyncio.get_event_loop().run_in_executor(
None, self.synthesizer.call, text, self.timeout_ms
)
with open(path, "wb") as f:
f.write(audio)
return path
+25 -21
View File
@@ -33,7 +33,6 @@ class ProviderDify(Provider):
if not self.api_key:
raise Exception("Dify API Key 不能为空。")
api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1")
self.api_client = DifyAPIClient(self.api_key, api_base)
self.api_type = provider_config.get("dify_api_type", "")
if not self.api_type:
raise Exception("Dify API 类型不能为空。")
@@ -44,15 +43,19 @@ class ProviderDify(Provider):
self.dify_query_input_key = provider_config.get(
"dify_query_input_key", "astrbot_text_query"
)
self.variables: dict = provider_config.get("variables", {})
if not self.dify_query_input_key:
self.dify_query_input_key = "astrbot_text_query"
if not self.workflow_output_key:
self.workflow_output_key = "astrbot_wf_output"
self.variables: dict = provider_config.get("variables", {})
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.conversation_ids = {}
"""记录当前 session id 的对话 ID"""
self.api_client = DifyAPIClient(self.api_key, api_base)
async def text_chat(
self,
prompt: str,
@@ -68,26 +71,27 @@ class ProviderDify(Provider):
files_payload = []
for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
file_response = await self.api_client.file_upload(
image_path, user=session_id
image_path = (
await download_image_by_url(image_url)
if image_url.startswith("http")
else image_url
)
file_response = await self.api_client.file_upload(
image_path, user=session_id
)
logger.debug(f"Dify 上传图片响应:{file_response}")
if "id" not in file_response:
logger.warning(
f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。"
)
if "id" not in file_response:
logger.warning(
f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。"
)
continue
files_payload.append(
{
"type": "image",
"transfer_method": "local_file",
"upload_file_id": file_response["id"],
}
)
else:
# TODO: 处理更多情况
logger.warning(f"未知的图片链接:{image_url},图片将忽略。")
continue
files_payload.append(
{
"type": "image",
"transfer_method": "local_file",
"upload_file_id": file_response["id"],
}
)
# 获得会话变量
payload_vars = self.variables.copy()
@@ -35,6 +35,8 @@ class ProviderEdgeTTS(TTSProvider):
self.pitch = provider_config.get("pitch", None)
self.timeout = provider_config.get("timeout", 30)
self.proxy = os.getenv("https_proxy", None)
self.set_model("edge_tts")
async def get_audio(self, text: str) -> str:
@@ -42,7 +44,7 @@ class ProviderEdgeTTS(TTSProvider):
mp3_path = f"data/temp/edge_tts_temp_{uuid.uuid4()}.mp3"
wav_path = f"data/temp/edge_tts_{uuid.uuid4()}.wav"
# 构建Edge TTS参数
# 构建 Edge TTS 参数
kwargs = {"text": text, "voice": self.voice}
if self.rate:
kwargs["rate"] = self.rate
@@ -52,35 +54,45 @@ class ProviderEdgeTTS(TTSProvider):
kwargs["pitch"] = self.pitch
try:
communicate = edge_tts.Communicate(**kwargs)
communicate = edge_tts.Communicate(proxy=self.proxy, **kwargs)
await communicate.save(mp3_path)
# 使用ffmpeg将MP3转换为标准WAV格式
_ = await asyncio.create_subprocess_exec(
"ffmpeg",
"-y", # 覆盖输出文件
"-i",
mp3_path, # 输入文件
"-acodec",
"pcm_s16le", # 16位PCM编码
"-ar",
"24000", # 采样率24kHz (适合微信语音)
"-ac",
"1", # 单声道
"-af",
"apad=pad_dur=2", # 确保输出时长准确
"-fflags",
"+genpts", # 强制生成时间戳
"-hide_banner", # 隐藏版本信息
wav_path, # 输出文件
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
# 等待进程完成并获取输出
stdout, stderr = await _.communicate()
logger.info(f"[EdgeTTS] FFmpeg 标准输出: {stdout.decode().strip()}")
logger.debug(f"FFmpeg错误输出: {stderr.decode().strip()}")
logger.info(f"[EdgeTTS] 返回值(0代表成功): {_.returncode}")
try:
from pyffmpeg import FFmpeg
ff = FFmpeg()
ff.convert(input=mp3_path, output=wav_path)
except Exception as e:
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
# use ffmpeg command line
# 使用ffmpeg将MP3转换为标准WAV格式
p = await asyncio.create_subprocess_exec(
"ffmpeg",
"-y", # 覆盖输出文件
"-i",
mp3_path, # 输入文件
"-acodec",
"pcm_s16le", # 16位PCM编码
"-ar",
"24000", # 采样率24kHz (适合微信语音)
"-ac",
"1", # 单声道
"-af",
"apad=pad_dur=2", # 确保输出时长准确
"-fflags",
"+genpts", # 强制生成时间戳
"-hide_banner", # 隐藏版本信息
wav_path, # 输出文件
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
# 等待进程完成并获取输出
stdout, stderr = await p.communicate()
logger.info(f"[EdgeTTS] FFmpeg 标准输出: {stdout.decode().strip()}")
logger.debug(f"FFmpeg错误输出: {stderr.decode().strip()}")
logger.info(f"[EdgeTTS] 返回值(0代表成功): {p.returncode}")
os.remove(mp3_path)
if os.path.exists(wav_path) and os.path.getsize(wav_path) > 0:
return wav_path
+122 -46
View File
@@ -1,6 +1,10 @@
import base64
import aiohttp
import json
import random
import asyncio
import astrbot.core.message.components as Comp
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.db import BaseDatabase
from astrbot.api.provider import Provider, Personality
@@ -38,6 +42,8 @@ class SimpleGoogleGenAIClient:
model: str = "gemini-1.5-flash",
system_instruction: str = "",
tools: dict = None,
modalities: List[str] = ["Text"],
safety_settings: List[dict] = [],
):
payload = {}
if system_instruction:
@@ -45,6 +51,13 @@ class SimpleGoogleGenAIClient:
if tools:
payload["tools"] = [tools]
payload["contents"] = contents
payload["generationConfig"] = {
"responseModalities": modalities,
}
payload["safetySettings"] = [
{"category": s["category"], "threshold": s["threshold"]}
for s in safety_settings
]
logger.debug(f"payload: {payload}")
request_url = (
f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}"
@@ -98,6 +111,21 @@ class ProviderGoogleGenAI(Provider):
)
self.set_model(provider_config["model_config"]["model"])
safety_mapping = {
"harassment": "HARM_CATEGORY_HARASSMENT",
"hate_speech": "HARM_CATEGORY_HATE_SPEECH",
"sexually_explicit": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"dangerous_content": "HARM_CATEGORY_DANGEROUS_CONTENT",
}
self.safety_settings = []
user_safety_config = self.provider_config.get("gm_safety_settings", {})
for config_key, harm_category in safety_mapping.items():
if threshold := user_safety_config.get(config_key):
self.safety_settings.append(
{"category": harm_category, "threshold": threshold}
)
async def get_models(self):
return await self.client.models_list()
@@ -119,7 +147,7 @@ class ProviderGoogleGenAI(Provider):
if message["role"] == "user":
if isinstance(message["content"], str):
if not message["content"]:
message["content"] = "<empty_content>"
message["content"] = ""
google_genai_conversation.append(
{"role": "user", "parts": [{"text": message["content"]}]}
@@ -130,7 +158,7 @@ class ProviderGoogleGenAI(Provider):
for part in message["content"]:
if part["type"] == "text":
if not part["text"]:
part["text"] = "<empty_content>"
part["text"] = ""
parts.append({"text": part["text"]})
elif part["type"] == "image_url":
parts.append(
@@ -146,36 +174,105 @@ class ProviderGoogleGenAI(Provider):
google_genai_conversation.append({"role": "user", "parts": parts})
elif message["role"] == "assistant":
if not message["content"]:
message["content"] = "<empty_content>"
google_genai_conversation.append(
{"role": "model", "parts": [{"text": message["content"]}]}
if "content" in message:
if not message["content"]:
message["content"] = ""
google_genai_conversation.append(
{"role": "model", "parts": [{"text": message["content"]}]}
)
elif "tool_calls" in message:
# tool calls in the last turn
parts = []
for tool_call in message["tool_calls"]:
parts.append(
{
"functionCall": {
"name": tool_call["function"]["name"],
"args": json.loads(
tool_call["function"]["arguments"]
),
}
}
)
google_genai_conversation.append({"role": "model", "parts": parts})
elif message["role"] == "tool":
parts = []
parts.append(
{
"functionResponse": {
"name": message["tool_call_id"],
"response": {
"name": message["tool_call_id"],
"content": message["content"],
},
}
}
)
google_genai_conversation.append({"role": "user", "parts": parts})
logger.debug(f"google_genai_conversation: {google_genai_conversation}")
result = await self.client.generate_content(
contents=google_genai_conversation,
model=self.get_model(),
system_instruction=system_instruction,
tools=tool,
)
logger.debug(f"result: {result}")
modalites = ["Text"]
if self.provider_config.get("gm_resp_image_modal", False):
modalites.append("Image")
if "candidates" not in result:
raise Exception("Gemini 返回异常结果: " + str(result))
loop = True
while loop:
loop = False
result = await self.client.generate_content(
contents=google_genai_conversation,
model=self.get_model(),
system_instruction=system_instruction,
tools=tool,
modalities=modalites,
safety_settings=self.safety_settings,
)
logger.debug(f"result: {result}")
# Developer instruction is not enabled for models/gemini-2.0-flash-exp
if "Developer instruction is not enabled" in str(result):
logger.warning(
f"{self.get_model()} 不支持 system prompt, 已自动去除, 将会影响人格设置。"
)
system_instruction = ""
loop = True
elif "Function calling is not enabled" in str(result):
logger.warning(
f"{self.get_model()} 不支持函数调用,已自动去除,不影响使用。"
)
tool = None
loop = True
elif "Multi-modal output is not supported" in str(result):
logger.warning(
f"{self.get_model()} 不支持多模态输出,降级为文本模态重新请求。"
)
modalites = ["Text"]
loop = True
elif "candidates" not in result:
raise Exception("Gemini 返回异常结果: " + str(result))
candidates = result["candidates"][0]["content"]["parts"]
llm_response = LLMResponse("assistant")
chain = []
for candidate in candidates:
if "text" in candidate:
llm_response.completion_text += candidate["text"]
chain.append(Comp.Plain(candidate["text"]))
elif "functionCall" in candidate:
llm_response.role = "tool"
llm_response.tools_call_args.append(candidate["functionCall"]["args"])
llm_response.tools_call_name.append(candidate["functionCall"]["name"])
llm_response.tools_call_ids.append(
candidate["functionCall"]["name"]
) # 没有 tool id
elif "inlineData" in candidate:
mime_type: str = candidate["inlineData"]["mimeType"]
if mime_type.startswith("image/"):
chain.append(Comp.Image.fromBase64(candidate["inlineData"]["data"]))
llm_response.completion_text = llm_response.completion_text.strip()
llm_response.result_chain = MessageChain(chain=chain)
return llm_response
async def text_chat(
@@ -186,6 +283,7 @@ class ProviderGoogleGenAI(Provider):
func_tool: FuncCall = None,
contexts=[],
system_prompt=None,
tool_calls_result=None,
**kwargs,
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
@@ -198,6 +296,10 @@ class ProviderGoogleGenAI(Provider):
if "_no_save" in part:
del part["_no_save"]
# tool calls result
if tool_calls_result:
context_query.extend(tool_calls_result.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
model_config["model"] = self.get_model()
@@ -214,46 +316,20 @@ class ProviderGoogleGenAI(Provider):
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
if "maximum context length" in str(e):
retry_cnt = 20
while retry_cnt > 0:
logger.warning(
f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
)
try:
await self.pop_record(context_query)
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
if "maximum context length" in str(e):
retry_cnt -= 1
else:
raise e
if retry_cnt == 0:
llm_response = LLMResponse(
"err", "err: 请尝试 /reset 重置会话"
)
elif "Function calling is not enabled" in str(e):
logger.info(
f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。"
)
if "tools" in payloads:
del payloads["tools"]
llm_response = await self._query(payloads, None)
break
elif "429" in str(e) or "API key not valid" in str(e):
if "429" in str(e) or "API key not valid" in str(e):
keys.remove(chosen_key)
if len(keys) > 0:
chosen_key = random.choice(keys)
logger.info(
f"检测到 Key 异常({str(e)}),正在尝试更换 API Key 重试... 当前 Key: {chosen_key[:12]}..."
)
await asyncio.sleep(1)
continue
else:
logger.error(
f"检测到 Key 异常({str(e)}),且已没有可用的 Key。 当前 Key: {chosen_key[:12]}..."
)
raise Exception("API 资源已耗尽,且没有可用的 Key 重试...")
raise Exception("达到了 Gemini 速率限制, 请稍后再试...")
else:
logger.error(
f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}"
+78 -65
View File
@@ -2,6 +2,8 @@ import base64
import json
import os
import inspect
import random
import asyncio
from openai import AsyncOpenAI, AsyncAzureOpenAI
from openai.types.chat.chat_completion import ChatCompletion
@@ -120,15 +122,18 @@ class ProviderOpenAIOfficial(Provider):
# tools call (function calling)
args_ls = []
func_name_ls = []
tool_call_ids = []
for tool_call in choice.message.tool_calls:
for tool in tools.func_list:
if tool.name == tool_call.function.name:
args = json.loads(tool_call.function.arguments)
args_ls.append(args)
func_name_ls.append(tool_call.function.name)
tool_call_ids.append(tool_call.id)
llm_response.role = "tool"
llm_response.tools_call_args = args_ls
llm_response.tools_call_name = func_name_ls
llm_response.tools_call_ids = tool_call_ids
if choice.finish_reason == "content_filter":
raise Exception(
@@ -151,6 +156,7 @@ class ProviderOpenAIOfficial(Provider):
func_tool: FuncCall = None,
contexts=[],
system_prompt=None,
tool_calls_result=None,
**kwargs,
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
@@ -162,82 +168,91 @@ class ProviderOpenAIOfficial(Provider):
if "_no_save" in part:
del part["_no_save"]
# tool calls result
if tool_calls_result:
context_query.extend(tool_calls_result.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
model_config["model"] = self.get_model()
payloads = {"messages": context_query, **model_config}
llm_response = None
try:
llm_response = await self._query(payloads, func_tool)
except UnprocessableEntityError as e:
logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
# 尝试删除所有 image
new_contexts = await self._remove_image_from_context(context_query)
payloads["messages"] = new_contexts
context_query = new_contexts
llm_response = await self._query(payloads, func_tool)
except Exception as e:
if "maximum context length" in str(e):
# 重试 10 次
retry_cnt = 20
while retry_cnt > 0:
logger.warning(
f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
)
try:
await self.pop_record(context_query)
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
if "maximum context length" in str(e):
retry_cnt -= 1
else:
raise e
if retry_cnt == 0:
llm_response = LLMResponse(
"err", "err: 请尝试 /reset 清除会话记录。"
)
elif "The model is not a VLM" in str(e): # siliconcloud
max_retries = 10
available_api_keys = self.api_keys.copy()
chosen_key = random.choice(available_api_keys)
e = None
retry_cnt = 0
for retry_cnt in range(max_retries):
try:
self.client.api_key = chosen_key
llm_response = await self._query(payloads, func_tool)
break
except UnprocessableEntityError as e:
logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
# 尝试删除所有 image
new_contexts = await self._remove_image_from_context(context_query)
payloads["messages"] = new_contexts
llm_response = await self._query(payloads, func_tool)
# openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配
elif (
"does not support Function Calling" in str(e)
or "does not support tools" in str(e)
or "Function call is not supported" in str(e)
or "Function calling is not enabled" in str(e)
or "Tool calling is not supported" in str(e)
or "No endpoints found that support tool use" in str(e)
or "model does not support function calling" in str(e)
or ("tool" in str(e) and "support" in str(e).lower())
or ("function" in str(e) and "support" in str(e).lower())
):
logger.info(
f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。"
)
if "tools" in payloads:
del payloads["tools"]
llm_response = await self._query(payloads, None)
else:
logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
if "tool" in str(e).lower() and "support" in str(e).lower():
context_query = new_contexts
except Exception as e:
if "429" in str(e):
logger.warning(
f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}"
)
# 最后一次不等待
if retry_cnt < max_retries - 1:
await asyncio.sleep(1)
available_api_keys.remove(chosen_key)
if len(available_api_keys) > 0:
chosen_key = random.choice(available_api_keys)
continue
else:
raise e
elif "maximum context length" in str(e):
logger.warning(
f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
)
await self.pop_record(context_query)
elif "The model is not a VLM" in str(e): # siliconcloud
# 尝试删除所有 image
new_contexts = await self._remove_image_from_context(context_query)
payloads["messages"] = new_contexts
elif (
"Function calling is not enabled" in str(e)
or ("tool" in str(e).lower() and "support" in str(e).lower())
or ("function" in str(e).lower() and "support" in str(e).lower())
):
# openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配
logger.info(
f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。"
)
if "tools" in payloads:
del payloads["tools"]
func_tool = None
else:
logger.error(
"疑似该模型不支持函数调用工具调用。请输入 /tool off_all"
f"发生了错误。Provider 配置如下: {self.provider_config}"
)
if "Connection error." in str(e):
proxy = os.environ.get("http_proxy", None)
if proxy:
if "tool" in str(e).lower() and "support" in str(e).lower():
logger.error(
f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}"
"疑似该模型不支持函数调用工具调用。请输入 /tool off_all"
)
raise e
if "Connection error." in str(e):
proxy = os.environ.get("http_proxy", None)
if proxy:
logger.error(
f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}"
)
raise e
if retry_cnt == max_retries - 1:
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
raise e
return llm_response
async def _remove_image_from_context(self, contexts: List):
@@ -275,10 +290,8 @@ class ProviderOpenAIOfficial(Provider):
def set_key(self, key):
self.client.api_key = key
async def assemble_context(self, text: str, image_urls: List[str] = None):
"""
组装上下文
"""
async def assemble_context(self, text: str, image_urls: List[str] = None) -> dict:
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
if image_urls:
user_content = {"role": "user", "content": [{"type": "text", "text": text}]}
for image_url in image_urls:
@@ -48,14 +48,6 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
return os.path.join("data", "temp", f"{timestamp}")
async def _convert_audio(self, path: str) -> str:
from pyffmpeg import FFmpeg
filename = await self.get_timestamped_path() + ".mp3"
ff = FFmpeg()
output_path = ff.convert(path, os.path.join('data","temp', filename))
return output_path
async def _is_silk_file(self, file_path):
silk_header = b"SILK"
with open(file_path, "rb") as f:
@@ -31,14 +31,6 @@ class ProviderOpenAIWhisperAPI(STTProvider):
self.set_model(provider_config.get("model", None))
async def _convert_audio(self, path: str) -> str:
from pyffmpeg import FFmpeg
filename = str(uuid.uuid4()) + ".mp3"
ff = FFmpeg()
output_path = ff.convert(path, os.path.join("data/temp", filename))
return output_path
async def _is_silk_file(self, file_path):
silk_header = b"SILK"
with open(file_path, "rb") as f:
@@ -33,14 +33,6 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
)
logger.info("Whisper 模型加载完成。")
async def _convert_audio(self, path: str) -> str:
from pyffmpeg import FFmpeg
filename = str(uuid.uuid4()) + ".mp3"
ff = FFmpeg()
output_path = ff.convert(path, os.path.join("data/temp", filename))
return output_path
async def _is_silk_file(self, file_path):
silk_header = b"SILK"
with open(file_path, "rb") as f:
+3 -1
View File
@@ -4,12 +4,14 @@ from .context import Context
from astrbot.core.provider import Provider
from astrbot.core.utils.command_parser import CommandParserMixin
from astrbot.core import html_renderer
from astrbot.core.star.star_tools import StarTools
class Star(CommandParserMixin):
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
def __init__(self, context: Context):
StarTools.initialize(context)
self.context = context
async def text_to_image(self, text: str, return_url=True) -> str:
@@ -27,4 +29,4 @@ class Star(CommandParserMixin):
pass
__all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider"]
__all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider", "StarTools"]
+67 -7
View File
@@ -332,7 +332,10 @@ class PluginManager:
)
# 绑定 llm_tool handler
for func_tool in llm_tools.func_list:
if func_tool.handler.__module__ == metadata.module_path:
if (
func_tool.handler
and func_tool.handler.__module__ == metadata.module_path
):
func_tool.handler_module_path = metadata.module_path
func_tool.handler = functools.partial(
func_tool.handler, metadata.star_cls
@@ -448,7 +451,34 @@ class PluginManager:
# reload the plugin
dir_name = os.path.basename(plugin_path)
await self.load(specified_dir_name=dir_name)
return plugin_path
# Get the plugin metadata to return repo info
plugin = self.context.get_registered_star(dir_name)
if not plugin:
# Try to find by other name if directory name doesn't match plugin name
for star in self.context.get_all_stars():
if star.root_dir_name == dir_name:
plugin = star
break
# Extract README.md content if exists
readme_content = None
readme_path = os.path.join(plugin_path, "README.md")
if not os.path.exists(readme_path):
readme_path = os.path.join(plugin_path, "readme.md")
if os.path.exists(readme_path):
try:
with open(readme_path, "r", encoding="utf-8") as f:
readme_content = f.read()
except Exception as e:
logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {str(e)}")
plugin_info = None
if plugin:
plugin_info = {"repo": plugin.repo, "readme": readme_content}
return plugin_info
async def uninstall_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
@@ -471,9 +501,11 @@ class PluginManager:
# 从 star_registry 和 star_map 中删除
await self._unbind_plugin(plugin_name, plugin.module_path)
if not remove_dir(os.path.join(ppath, root_dir_name)):
try:
remove_dir(os.path.join(ppath, root_dir_name))
except Exception as e:
raise Exception(
"移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。"
f"移除插件成功,但是删除插件文件夹失败: {str(e)}。您可以手动删除该文件夹,位于 addons/plugins/ 下。"
)
async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str):
@@ -553,7 +585,7 @@ class PluginManager:
async def _terminate_plugin(self, star_metadata: StarMetadata):
"""终止插件,调用插件的 terminate() 和 __del__() 方法"""
logging.info(f"正在终止插件 {star_metadata.name} ...")
logger.info(f"正在终止插件 {star_metadata.name} ...")
if not star_metadata.activated:
# 说明之前已经被禁用了
@@ -564,7 +596,7 @@ class PluginManager:
asyncio.get_event_loop().run_in_executor(
None, star_metadata.star_cls.__del__
)
else:
elif hasattr(star_metadata.star_cls, "terminate"):
await star_metadata.star_cls.terminate()
async def turn_on_plugin(self, plugin_name: str):
@@ -601,4 +633,32 @@ class PluginManager:
except BaseException as e:
logger.warning(f"删除插件压缩包失败: {str(e)}")
# await self.reload()
await self.load(desti_dir)
await self.load(specified_dir_name=dir_name)
# Get the plugin metadata to return repo info
plugin = self.context.get_registered_star(dir_name)
if not plugin:
# Try to find by other name if directory name doesn't match plugin name
for star in self.context.get_all_stars():
if star.root_dir_name == dir_name:
plugin = star
break
# Extract README.md content if exists
readme_content = None
readme_path = os.path.join(desti_dir, "README.md")
if not os.path.exists(readme_path):
readme_path = os.path.join(desti_dir, "readme.md")
if os.path.exists(readme_path):
try:
with open(readme_path, "r", encoding="utf-8") as f:
readme_content = f.read()
except Exception as e:
logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {str(e)}")
plugin_info = None
if plugin:
plugin_info = {"repo": plugin.repo, "readme": readme_content}
return plugin_info
+144
View File
@@ -0,0 +1,144 @@
from typing import Union, Awaitable, List, Optional, ClassVar
from astrbot.core.message.components import BaseMessageComponent
from astrbot.core.message.message_event_result import MessageChain
from astrbot.api.platform import MessageMember, AstrBotMessage
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.star.context import Context
class StarTools:
"""
提供给插件使用的便捷工具函数集合
这些方法封装了一些常用操作使插件开发更加简单便捷!
"""
_context: ClassVar[Optional[Context]] = None
@classmethod
def initialize(cls, context: Context) -> None:
"""
初始化StarTools设置context引用
Args:
context: 暴露给插件的上下文
"""
cls._context = context
@classmethod
async def send_message(
cls, session: Union[str, MessageSesion], message_chain: MessageChain
) -> bool:
"""
根据session(unified_msg_origin)主动发送消息
Args:
session: 消息会话通过event.session或者event.unified_msg_origin获取
message_chain: 消息链
Returns:
bool: 是否找到匹配的平台
Raises:
ValueError: 当session为字符串且解析失败时抛出
Note:
qq_official(QQ官方API平台)不支持此方法
"""
return await cls._context.send_message(session, message_chain)
@classmethod
async def create_message(
cls,
type: str,
self_id: str,
session_id: str,
message_id: str,
sender: MessageMember,
message: List[BaseMessageComponent],
message_str: str,
raw_message: object,
group_id: str = "",
):
"""
创建一个AstrBot消息对象
Args:
type (str): 消息类型
self_id (str): 机器人自身ID
session_id (str): 会话ID(通常为用户ID)(QQ号, 群号等)
message_id (str): 消息ID
sender (MessageMember): 发送者信息
message (List[BaseMessageComponent]): 消息组件列表
message_str (str): 消息字符串
raw_message (object): 原始消息对象
group_id (str, optional): 群组ID, 如果为私聊则为空. Defaults to "".
Returns:
AstrBotMessage: 创建的消息对象
"""
abm = AstrBotMessage()
abm.type = type
abm.self_id = self_id
abm.session_id = session_id
abm.message_id = message_id
abm.sender = sender
abm.message = message
abm.message_str = message_str
abm.raw_message = raw_message
abm.group_id = group_id
return abm
# todo: 添加构造事件的方法
# async def create_event(
# self, platform: str, umo: str, sender_id: str, session_id: str
# ):
# platform = self._context.get_platform(platform)
# todo: 添加找到对应平台并提交对应事件的方法
@classmethod
def activate_llm_tool(cls, name: str) -> bool:
"""
激活一个已经注册的函数调用工具
注册的工具默认是激活状态
Args:
name (str): 工具名称
"""
return cls._context.activate_llm_tool(name)
@classmethod
def deactivate_llm_tool(cls, name: str) -> bool:
"""
停用一个已经注册的函数调用工具
Args:
name (str): 工具名称
"""
return cls._context.deactivate_llm_tool(name)
@classmethod
def register_llm_tool(
cls, name: str, func_args: list, desc: str, func_obj: Awaitable
) -> None:
"""
为函数调用function-calling/tools-use添加工具
Args:
name (str): 工具名称
func_args (list): 函数参数列表
desc (str): 工具描述
func_obj (Awaitable): 函数对象必须是异步函数
"""
cls._context.register_llm_tool(name, func_args, desc, func_obj)
@classmethod
def unregister_llm_tool(cls, name: str) -> None:
"""
删除一个函数调用工具
如果再要启用需要重新注册
Args:
name (str): 工具名称
"""
cls._context.unregister_llm_tool(name)
+1 -1
View File
@@ -41,7 +41,7 @@ class PluginUpdator(RepoZipUpdator):
plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}")
await self.download_from_repo_url(plugin_path, repo_url)
await self.download_from_repo_url(plugin_path, repo_url, proxy=proxy)
try:
remove_dir(plugin_path)
+12
View File
@@ -9,6 +9,11 @@ from astrbot.core.utils.io import download_file
class AstrBotUpdator(RepoZipUpdator):
"""AstrBot 更新器,继承自 RepoZipUpdator 类
该类用于处理 AstrBot 的更新操作
功能包括检查更新下载更新文件解压缩更新文件等
"""
def __init__(self, repo_mirror: str = "") -> None:
super().__init__(repo_mirror)
self.MAIN_PATH = os.path.abspath(
@@ -17,6 +22,9 @@ class AstrBotUpdator(RepoZipUpdator):
self.ASTRBOT_RELEASE_API = "https://api.soulter.top/releases"
def terminate_child_processes(self):
"""终止当前进程的所有子进程
使用 psutil 库获取当前进程的所有子进程并尝试终止它们
"""
try:
parent = psutil.Process(os.getpid())
children = parent.children(recursive=True)
@@ -35,6 +43,9 @@ class AstrBotUpdator(RepoZipUpdator):
pass
def _reboot(self, delay: int = 3):
"""重启当前程序
在指定的延迟后终止所有子进程并重新启动程序
"""
py = sys.executable
time.sleep(delay)
self.terminate_child_processes()
@@ -46,6 +57,7 @@ class AstrBotUpdator(RepoZipUpdator):
raise e
async def check_update(self, url: str, current_version: str) -> ReleaseInfo:
"""检查更新"""
return await super().check_update(self.ASTRBOT_RELEASE_API, VERSION)
async def get_releases(self) -> list:
+22 -11
View File
@@ -8,6 +8,9 @@ import base64
import zipfile
import uuid
import psutil
import certifi
from typing import Union
from PIL import Image
@@ -17,24 +20,20 @@ def on_error(func, path, exc_info):
"""
a callback of the rmtree function.
"""
print(f"remove {path} failed.")
import stat
if not os.access(path, os.W_OK):
os.chmod(path, stat.S_IWUSR)
func(path)
else:
raise
raise exc_info[1]
def remove_dir(file_path) -> bool:
if not os.path.exists(file_path):
return True
try:
shutil.rmtree(file_path, onerror=on_error)
return True
except BaseException:
return False
shutil.rmtree(file_path, onerror=on_error)
return True
def port_checker(port: int, host: str = "localhost"):
@@ -81,7 +80,13 @@ async def download_image_by_url(
下载图片, 返回 path
"""
try:
async with aiohttp.ClientSession(trust_env=True) as session:
ssl_context = ssl.create_default_context(
cafile=certifi.where()
) # 使用 certifi 提供的 CA 证书
connector = aiohttp.TCPConnector(ssl=ssl_context) # 使用 certifi 的根证书
async with aiohttp.ClientSession(
trust_env=True, connector=connector
) as session:
if post:
async with session.post(url, json=post_data) as resp:
if not path:
@@ -98,7 +103,7 @@ async def download_image_by_url(
with open(path, "wb") as f:
f.write(await resp.read())
return path
except aiohttp.client.ClientConnectorSSLError:
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
# 关闭SSL验证
ssl_context = ssl.create_default_context()
ssl_context.set_ciphers("DEFAULT")
@@ -118,7 +123,13 @@ async def download_file(url: str, path: str, show_progress: bool = False):
从指定 url 下载文件到指定路径 path
"""
try:
async with aiohttp.ClientSession(trust_env=True) as session:
ssl_context = ssl.create_default_context(
cafile=certifi.where()
) # 使用 certifi 提供的 CA 证书
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(
trust_env=True, connector=connector
) as session:
async with session.get(url, timeout=1800) as resp:
if resp.status != 200:
raise Exception(f"下载文件失败: {resp.status}")
@@ -141,7 +152,7 @@ async def download_file(url: str, path: str, show_progress: bool = False):
f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s",
end="",
)
except aiohttp.client.ClientConnectorSSLError:
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
# 关闭SSL验证
ssl_context = ssl.create_default_context()
ssl_context.set_ciphers("DEFAULT")
+1
View File
@@ -16,6 +16,7 @@ class SharedPreferences:
def _save_preferences(self):
with open(self.path, "w") as f:
json.dump(self._data, f, indent=4)
f.flush()
def get(self, key, default=None):
return self._data.get(key, default)
File diff suppressed because it is too large Load Diff
+7 -1
View File
@@ -1,5 +1,7 @@
import aiohttp
import os
import ssl
import certifi
from . import RenderStrategy
from astrbot.core.config import VERSION
@@ -46,7 +48,11 @@ class NetworkRenderStrategy(RenderStrategy):
},
}
if return_url:
async with aiohttp.ClientSession(trust_env=True) as session:
ssl_context = ssl.create_default_context(cafile=certifi.where())
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(
trust_env=True, connector=connector
) as session:
async with session.post(
f"{self.BASE_RENDER_URL}/generate", json=post_data
) as resp:
+23 -3
View File
@@ -2,6 +2,10 @@ import aiohttp
import os
import zipfile
import shutil
import ssl
import certifi
from astrbot.core.utils.io import on_error, download_file
from astrbot.core import logger
@@ -19,7 +23,7 @@ class ReleaseInfo:
self.body = body
def __str__(self) -> str:
return f"新版本: {self.version}, 发布于: {self.published_at}, 详细内容: {self.body}"
return f"\n{self.body}\n\n版本: {self.version} | 发布于: {self.published_at}"
class RepoZipUpdator:
@@ -33,8 +37,23 @@ class RepoZipUpdator:
返回一个列表每个元素是一个字典包含版本号发布时间更新内容commit hash等信息
"""
try:
async with aiohttp.ClientSession(trust_env=True) as session:
ssl_context = ssl.create_default_context(
cafile=certifi.where()
) # 新增:创建基于 certifi 的 SSL 上下文
connector = aiohttp.TCPConnector(
ssl=ssl_context
) # 新增:使用 TCPConnector 指定 SSL 上下文
async with aiohttp.ClientSession(
trust_env=True, connector=connector
) as session:
async with session.get(url) as response:
# 检查 HTTP 状态码
if response.status != 200:
text = await response.text()
logger.error(
f"请求 {url} 失败,状态码: {response.status}, 内容: {text}"
)
raise Exception(f"请求失败,状态码: {response.status}")
result = await response.json()
if not result:
return []
@@ -53,7 +72,8 @@ class RepoZipUpdator:
"zipball_url": release["zipball_url"],
}
)
except BaseException:
except Exception as e:
logger.error(f"解析版本信息时发生异常: {e}")
raise Exception("解析版本信息失败")
return ret
-3
View File
@@ -1,3 +0,0 @@
from .dashboard_lifecycle import AstrBotDashBoardLifecycle
__all__ = ["AstrBotDashBoardLifecycle"]
+4
View File
@@ -6,6 +6,8 @@ from .stat import StatRoute
from .log import LogRoute
from .static_file import StaticFileRoute
from .chat import ChatRoute
from .tools import ToolsRoute # 导入新的ToolsRoute
from .conversation import ConversationRoute
__all__ = [
@@ -17,4 +19,6 @@ __all__ = [
"LogRoute",
"StaticFileRoute",
"ChatRoute",
"ToolsRoute", # 添加新的ToolsRoute
"ConversationRoute",
]
+21 -2
View File
@@ -2,7 +2,8 @@ import jwt
import datetime
from .route import Route, Response, RouteContext
from quart import request
from astrbot.core import WEBUI_SK
from astrbot.core import WEBUI_SK, DEMO_MODE
from astrbot import logger
class AuthRoute(Route):
@@ -19,15 +20,33 @@ class AuthRoute(Route):
password = self.config["dashboard"]["password"]
post_data = await request.json
if post_data["username"] == username and post_data["password"] == password:
change_pwd_hint = False
if username == "astrbot" and password == "77b90590a8945a7d36c963981a307dc9":
change_pwd_hint = True
logger.warning("为了保证安全,请尽快修改默认密码。")
return (
Response()
.ok({"token": self.generate_jwt(username), "username": username})
.ok(
{
"token": self.generate_jwt(username),
"username": username,
"change_pwd_hint": change_pwd_hint,
}
)
.__dict__
)
else:
return Response().error("用户名或密码错误").__dict__
async def edit_account(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
password = self.config["dashboard"]["password"]
post_data = await request.json
+34 -11
View File
@@ -12,8 +12,11 @@ from astrbot.core import logger
def try_cast(value: str, type_: str):
if type_ == "int" and value.isdigit():
return int(value)
if type_ == "int":
try:
return int(value)
except (ValueError, TypeError):
return None
elif (
type_ == "float"
and isinstance(value, str)
@@ -22,6 +25,11 @@ def try_cast(value: str, type_: str):
return float(value)
elif type_ == "float" and isinstance(value, int):
return float(value)
elif type_ == "float":
try:
return float(value)
except (ValueError, TypeError):
return None
def validate_config(
@@ -31,17 +39,24 @@ def validate_config(
def validate(data: dict, metadata: dict = schema, path=""):
for key, value in data.items():
print(key, value)
if key not in metadata:
# 无 schema 的配置项,执行类型猜测
if isinstance(value, str):
if value.isdigit():
try:
data[key] = int(value)
elif value.replace(".", "", 1).isdigit():
continue
except ValueError:
pass
try:
data[key] = float(value)
elif value == "true":
continue
except ValueError:
pass
if value.lower() == "true":
data[key] = True
elif value == "false":
elif value.lower() == "false":
data[key] = False
continue
meta = metadata[key]
@@ -97,7 +112,6 @@ def validate_config(
errors.append(
f"错误的类型 {path}{key}: 期望是 dict, 得到了 {type(value).__name__}"
)
validate(value, meta["items"], path=f"{path}{key}.")
if is_core:
for key, group in schema.items():
@@ -147,6 +161,7 @@ class ConfigRoute(Route):
"/config/provider/new": ("POST", self.post_new_provider),
"/config/provider/update": ("POST", self.post_update_provider),
"/config/provider/delete": ("POST", self.post_delete_provider),
"/config/llmtools": ("GET", self.get_llm_tools),
}
self.register_routes()
@@ -220,7 +235,8 @@ class ConfigRoute(Route):
return Response().error("未找到对应平台").__dict__
try:
await self._save_astrbot_configs(self.config)
save_config(self.config, self.config, is_core=True)
await self.core_lifecycle.platform_manager.reload(new_config)
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "更新平台配置成功~").__dict__
@@ -256,7 +272,8 @@ class ConfigRoute(Route):
else:
return Response().error("未找到对应平台").__dict__
try:
await self._save_astrbot_configs(self.config)
save_config(self.config, self.config, is_core=True)
await self.core_lifecycle.platform_manager.terminate_platform(platform_id)
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "删除平台配置成功~").__dict__
@@ -277,6 +294,12 @@ class ConfigRoute(Route):
return Response().error(str(e)).__dict__
return Response().ok(None, "删除成功,已经实时生效~").__dict__
async def get_llm_tools(self):
"""获取函数调用工具。包含了本地加载的以及 MCP 服务的工具"""
tool_mgr = self.core_lifecycle.provider_manager.llm_tools
tools = tool_mgr.get_func_desc_openai_style()
return Response().ok(tools).__dict__
async def _get_astrbot_config(self):
config = self.config
@@ -322,7 +345,7 @@ class ConfigRoute(Route):
async def _save_astrbot_configs(self, post_configs: dict):
try:
save_config(post_configs, self.config, is_core=True)
self.core_lifecycle.restart()
await self.core_lifecycle.restart()
except Exception as e:
raise e
+215
View File
@@ -0,0 +1,215 @@
import traceback
import json
from .route import Route, Response, RouteContext
from astrbot.core import logger
from quart import request
from astrbot.core.db import BaseDatabase
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
class ConversationRoute(Route):
def __init__(
self,
context: RouteContext,
db_helper: BaseDatabase,
core_lifecycle: AstrBotCoreLifecycle,
) -> None:
super().__init__(context)
self.routes = {
"/conversation/list": ("GET", self.list_conversations),
"/conversation/detail": (
"POST",
self.get_conv_detail,
),
"/conversation/update": ("POST", self.upd_conv),
"/conversation/delete": ("POST", self.del_conv),
"/conversation/update_history": (
"POST",
self.update_history,
),
}
self.db_helper = db_helper
self.register_routes()
async def list_conversations(self):
"""获取对话列表,支持分页、排序和筛选"""
try:
# 获取分页参数
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 20, type=int)
# 获取筛选参数
platforms = request.args.get("platforms", "")
message_types = request.args.get("message_types", "")
search_query = request.args.get("search", "")
exclude_ids = request.args.get("exclude_ids", "")
exclude_platforms = request.args.get("exclude_platforms", "")
# 转换为列表
platform_list = platforms.split(",") if platforms else []
message_type_list = message_types.split(",") if message_types else []
exclude_id_list = exclude_ids.split(",") if exclude_ids else []
exclude_platform_list = (
exclude_platforms.split(",") if exclude_platforms else []
)
# 限制页面大小,防止请求过大数据
if page < 1:
page = 1
if page_size < 1:
page_size = 20
if page_size > 100:
page_size = 100
# 使用数据库的分页方法获取会话列表和总数,传入筛选条件
try:
conversations, total_count = self.db_helper.get_filtered_conversations(
page=page,
page_size=page_size,
platforms=platform_list,
message_types=message_type_list,
search_query=search_query,
exclude_ids=exclude_id_list,
exclude_platforms=exclude_platform_list,
)
except Exception as e:
logger.error(f"数据库查询出错: {str(e)}\n{traceback.format_exc()}")
return Response().error(f"数据库查询出错: {str(e)}").__dict__
# 计算总页数
total_pages = (
(total_count + page_size - 1) // page_size if total_count > 0 else 1
)
result = {
"conversations": conversations,
"pagination": {
"page": page,
"page_size": page_size,
"total": total_count,
"total_pages": total_pages,
},
}
return Response().ok(result).__dict__
except Exception as e:
error_msg = f"获取对话列表失败: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
return Response().error(f"获取对话列表失败: {str(e)}").__dict__
async def get_conv_detail(self):
"""获取指定对话详情(通过POST请求)"""
try:
data = await request.get_json()
user_id = data.get("user_id")
cid = data.get("cid")
if not user_id or not cid:
return Response().error("缺少必要参数: user_id 和 cid").__dict__
conversation = self.db_helper.get_conversation_by_user_id(user_id, cid)
if not conversation:
return Response().error("对话不存在").__dict__
return (
Response()
.ok(
{
"user_id": user_id,
"cid": cid,
"title": conversation.title,
"persona_id": conversation.persona_id,
"history": conversation.history,
"created_at": conversation.created_at,
"updated_at": conversation.updated_at,
}
)
.__dict__
)
except Exception as e:
logger.error(f"获取对话详情失败: {str(e)}\n{traceback.format_exc()}")
return Response().error(f"获取对话详情失败: {str(e)}").__dict__
async def upd_conv(self):
"""更新对话信息(标题和角色ID)"""
try:
data = await request.get_json()
user_id = data.get("user_id")
cid = data.get("cid")
title = data.get("title")
persona_id = data.get("persona_id", "")
if not user_id or not cid:
return Response().error("缺少必要参数: user_id 和 cid").__dict__
conversation = self.db_helper.get_conversation_by_user_id(user_id, cid)
if not conversation:
return Response().error("对话不存在").__dict__
if title is not None:
self.db_helper.update_conversation_title(user_id, cid, title)
if persona_id is not None:
self.db_helper.update_conversation_persona_id(user_id, cid, persona_id)
return Response().ok({"message": "对话信息更新成功"}).__dict__
except Exception as e:
logger.error(f"更新对话信息失败: {str(e)}\n{traceback.format_exc()}")
return Response().error(f"更新对话信息失败: {str(e)}").__dict__
async def del_conv(self):
"""删除对话"""
try:
data = await request.get_json()
user_id = data.get("user_id")
cid = data.get("cid")
if not user_id or not cid:
return Response().error("缺少必要参数: user_id 和 cid").__dict__
conversation = self.db_helper.get_conversation_by_user_id(user_id, cid)
if not conversation:
return Response().error("对话不存在").__dict__
self.db_helper.delete_conversation(user_id, cid)
return Response().ok({"message": "对话删除成功"}).__dict__
except Exception as e:
logger.error(f"删除对话失败: {str(e)}\n{traceback.format_exc()}")
return Response().error(f"删除对话失败: {str(e)}").__dict__
async def update_history(self):
"""更新对话历史内容"""
try:
data = await request.get_json()
user_id = data.get("user_id")
cid = data.get("cid")
history = data.get("history")
if not user_id or not cid:
return Response().error("缺少必要参数: user_id 和 cid").__dict__
if history is None:
return Response().error("缺少必要参数: history").__dict__
# 历史记录必须是合法的 JSON 字符串
try:
if isinstance(history, list):
history = json.dumps(history)
else:
# 验证是否为有效的 JSON 字符串
json.loads(history)
except json.JSONDecodeError:
return (
Response().error("history 必须是有效的 JSON 字符串或数组").__dict__
)
conversation = self.db_helper.get_conversation_by_user_id(user_id, cid)
if not conversation:
return Response().error("对话不存在").__dict__
self.db_helper.update_conversation(user_id, cid, history)
return Response().ok({"message": "对话历史更新成功"}).__dict__
except Exception as e:
logger.error(f"更新对话历史失败: {str(e)}\n{traceback.format_exc()}")
return Response().error(f"更新对话历史失败: {str(e)}").__dict__
+33 -17
View File
@@ -1,5 +1,6 @@
import asyncio
from quart import websocket
import json
from quart import make_response
from astrbot.core import logger, LogBroker
from .route import Route, RouteContext
@@ -8,21 +9,36 @@ class LogRoute(Route):
def __init__(self, context: RouteContext, log_broker: LogBroker) -> None:
super().__init__(context)
self.log_broker = log_broker
self.app.add_url_rule(
"/api/live-log", view_func=self.log, methods=["GET"], websocket=True
)
self.app.add_url_rule("/api/live-log", view_func=self.log, methods=["GET"])
async def log(self):
queue = None
try:
queue = self.log_broker.register()
while True:
message = await queue.get()
await websocket.send(message)
except asyncio.CancelledError:
pass
except BaseException as e:
logger.error(f"WebSocket 连接错误: {e}")
finally:
if queue:
self.log_broker.unregister(queue)
async def stream():
queue = None
try:
queue = self.log_broker.register()
while True:
message = await queue.get()
payload = {
"type": "log",
**message # see astrbot/core/log.py
}
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
except asyncio.CancelledError:
pass
except BaseException as e:
logger.error(f"Log SSE 连接错误: {e}")
finally:
if queue:
self.log_broker.unregister(queue)
response = await make_response(
stream(),
{
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
},
)
response.timeout = None
return response
+64 -5
View File
@@ -1,5 +1,9 @@
import traceback
import aiohttp
import ssl
import certifi
from .route import Route, Response, RouteContext
from astrbot.core import logger
from quart import request
@@ -11,6 +15,7 @@ from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.star.filter.regex import RegexFilter
from astrbot.core.star.star_handler import EventType
from astrbot.core import DEMO_MODE
class PluginRoute(Route):
@@ -46,6 +51,13 @@ class PluginRoute(Route):
}
async def reload_plugins(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
data = await request.json
plugin_name = data.get("name", None)
try:
@@ -65,9 +77,14 @@ class PluginRoute(Route):
else:
urls = ["https://api.soulter.top/astrbot/plugins"]
# 新增:创建 SSL 上下文,使用 certifi 提供的根证书
ssl_context = ssl.create_default_context(cafile=certifi.where())
connector = aiohttp.TCPConnector(ssl=ssl_context)
for url in urls:
try:
async with aiohttp.ClientSession(trust_env=True) as session:
async with aiohttp.ClientSession(
trust_env=True, connector=connector
) as session:
async with session.get(url) as response:
if response.status == 200:
result = await response.json()
@@ -178,6 +195,13 @@ class PluginRoute(Route):
return handlers
async def install_plugin(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
post_data = await request.json
repo_url = post_data["url"]
@@ -187,30 +211,44 @@ class PluginRoute(Route):
try:
logger.info(f"正在安装插件 {repo_url}")
await self.plugin_manager.install_plugin(repo_url, proxy)
plugin_info = await self.plugin_manager.install_plugin(repo_url, proxy)
# self.core_lifecycle.restart()
logger.info(f"安装插件 {repo_url} 成功。")
return Response().ok(None, "安装成功。").__dict__
return Response().ok(plugin_info, "安装成功。").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
async def install_plugin_upload(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
try:
file = await request.files
file = file["file"]
logger.info(f"正在安装用户上传的插件 {file.filename}")
file_path = f"data/temp/{file.filename}"
await file.save(file_path)
await self.plugin_manager.install_plugin_from_file(file_path)
plugin_info = await self.plugin_manager.install_plugin_from_file(file_path)
# self.core_lifecycle.restart()
logger.info(f"安装插件 {file.filename} 成功")
return Response().ok(None, "安装成功。").__dict__
return Response().ok(plugin_info, "安装成功。").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
async def uninstall_plugin(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
post_data = await request.json
plugin_name = post_data["name"]
try:
@@ -223,6 +261,13 @@ class PluginRoute(Route):
return Response().error(str(e)).__dict__
async def update_plugin(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
post_data = await request.json
plugin_name = post_data["name"]
proxy: str = post_data.get("proxy", None)
@@ -238,6 +283,13 @@ class PluginRoute(Route):
return Response().error(str(e)).__dict__
async def off_plugin(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
post_data = await request.json
plugin_name = post_data["name"]
try:
@@ -249,6 +301,13 @@ class PluginRoute(Route):
return Response().error(str(e)).__dict__
async def on_plugin(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
post_data = await request.json
plugin_name = post_data["name"]
try:
+29 -4
View File
@@ -1,12 +1,14 @@
import traceback
import psutil
import time
import threading
from .route import Route, Response, RouteContext
from astrbot.core import logger
from quart import request
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.config import VERSION
from astrbot.core import DEMO_MODE
class StatRoute(Route):
@@ -28,7 +30,14 @@ class StatRoute(Route):
self.core_lifecycle = core_lifecycle
async def restart_core(self):
self.core_lifecycle.restart()
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
await self.core_lifecycle.restart()
return Response().ok().__dict__
def format_sec(self, sec: int):
@@ -64,6 +73,20 @@ class StatRoute(Route):
stat_dict = stat.__dict__
cpu_percent = psutil.cpu_percent(interval=0.5)
thread_count = threading.active_count()
# 获取插件信息
plugins = self.core_lifecycle.star_context.get_all_stars()
plugin_info = []
for plugin in plugins:
info = {
"name": getattr(plugin, "name", plugin.__class__.__name__),
"version": getattr(plugin, "version", "1.0.0"),
"is_enabled": True,
}
plugin_info.append(info)
stat_dict.update(
{
"platform": self.db_helper.get_grouped_base_stats(
@@ -73,9 +96,8 @@ class StatRoute(Route):
"platform_count": len(
self.core_lifecycle.platform_manager.get_insts()
),
"plugin_count": len(
self.core_lifecycle.star_context.get_all_stars()
),
"plugin_count": len(plugins),
"plugins": plugin_info,
"message_time_series": message_time_based_stats,
"running": self.format_sec(
int(time.time()) - self.core_lifecycle.start_time
@@ -84,6 +106,9 @@ class StatRoute(Route):
"process": psutil.Process().memory_info().rss >> 20,
"system": psutil.virtual_memory().total >> 20,
},
"cpu_percent": round(cpu_percent, 1),
"thread_count": thread_count,
"start_time": self.core_lifecycle.start_time,
}
)
+2
View File
@@ -20,6 +20,8 @@ class StaticFileRoute(Route):
"/providers",
"/about",
"/extension-marketplace",
"/conversation",
"/tool-use",
]
for i in index_:
self.app.add_url_rule(i, view_func=self.index)
+252
View File
@@ -0,0 +1,252 @@
import os
import json
import traceback
from .route import Route, Response, RouteContext
from quart import request
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core import logger
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
class ToolsRoute(Route):
def __init__(
self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle
) -> None:
super().__init__(context)
self.core_lifecycle = core_lifecycle
self.routes = {
"/tools/mcp/servers": ("GET", self.get_mcp_servers),
"/tools/mcp/add": ("POST", self.add_mcp_server),
"/tools/mcp/update": ("POST", self.update_mcp_server),
"/tools/mcp/delete": ("POST", self.delete_mcp_server),
}
self.register_routes()
self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools
@property
def mcp_config_path(self):
current_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.abspath(os.path.join(current_dir, "../../../data"))
return os.path.join(data_dir, "mcp_server.json")
def load_mcp_config(self):
if not os.path.exists(self.mcp_config_path):
# 配置文件不存在,创建默认配置
os.makedirs(os.path.dirname(self.mcp_config_path), exist_ok=True)
with open(self.mcp_config_path, "w", encoding="utf-8") as f:
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
return DEFAULT_MCP_CONFIG
try:
with open(self.mcp_config_path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
logger.error(f"加载 MCP 配置失败: {e}")
return DEFAULT_MCP_CONFIG
def save_mcp_config(self, config):
try:
with open(self.mcp_config_path, "w", encoding="utf-8") as f:
json.dump(config, f, ensure_ascii=False, indent=4)
return True
except Exception as e:
logger.error(f"保存 MCP 配置失败: {e}")
return False
async def get_mcp_servers(self):
try:
config = self.load_mcp_config()
servers = []
# 获取所有服务器并添加它们的工具列表
for name, server_config in config["mcpServers"].items():
server_info = {
"name": name,
"active": server_config.get("active", True),
}
# 复制所有配置字段
for key, value in server_config.items():
if key != "active": # active 已经处理
server_info[key] = value
# 如果MCP客户端已初始化,从客户端获取工具名称
for (
name_key,
mcp_client,
) in self.tool_mgr.mcp_client_dict.items():
if name_key == name:
server_info["tools"] = [tool.name for tool in mcp_client.tools]
break
else:
server_info["tools"] = []
servers.append(server_info)
return Response().ok(servers).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"获取 MCP 服务器列表失败: {str(e)}").__dict__
async def add_mcp_server(self):
try:
server_data = await request.json
name = server_data.get("name", "")
# 检查必填字段
if not name:
return Response().error("服务器名称不能为空").__dict__
# 移除特殊字段并检查配置是否有效
has_valid_config = False
server_config = {"active": server_data.get("active", True)}
# 复制所有配置字段
for key, value in server_data.items():
if key not in ["name", "active", "tools"]: # 排除特殊字段
server_config[key] = value
has_valid_config = True
if not has_valid_config:
return Response().error("必须提供有效的服务器配置").__dict__
config = self.load_mcp_config()
if name in config["mcpServers"]:
return Response().error(f"服务器 {name} 已存在").__dict__
config["mcpServers"][name] = server_config
if self.save_mcp_config(config):
# 动态初始化新MCP客户端
self.tool_mgr.mcp_service_queue.put_nowait(
{
"type": "init",
"name": name,
"cfg": config["mcpServers"][name],
}
)
return Response().ok(None, f"成功添加 MCP 服务器 {name}").__dict__
else:
return Response().error("保存配置失败").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"添加 MCP 服务器失败: {str(e)}").__dict__
async def update_mcp_server(self):
try:
server_data = await request.json
name = server_data.get("name", "")
if not name:
return Response().error("服务器名称不能为空").__dict__
config = self.load_mcp_config()
if name not in config["mcpServers"]:
return Response().error(f"服务器 {name} 不存在").__dict__
# 获取活动状态
active = server_data.get(
"active", config["mcpServers"][name].get("active", True)
)
# 创建新的配置对象
server_config = {"active": active}
# 仅更新活动状态的特殊处理
only_update_active = True
# 复制所有配置字段
for key, value in server_data.items():
if key not in ["name", "active", "tools"]: # 排除特殊字段
server_config[key] = value
only_update_active = False
# 如果只更新活动状态,保留原始配置
if only_update_active:
for key, value in config["mcpServers"][name].items():
if key != "active": # 除了active之外的所有字段都保留
server_config[key] = value
config["mcpServers"][name] = server_config
if self.save_mcp_config(config):
# 处理MCP客户端状态变化
if active:
# 如果要激活服务器或者配置已更改
if name in self.tool_mgr.mcp_client_dict or not only_update_active:
await self.tool_mgr.mcp_service_queue.put(
{
"type": "terminate",
"name": name,
}
)
await self.tool_mgr.mcp_service_queue.put(
{
"type": "init",
"name": name,
"cfg": config["mcpServers"][name],
}
)
else:
# 客户端不存在,初始化
self.tool_mgr.mcp_service_queue.put_nowait(
{
"type": "init",
"name": name,
"cfg": config["mcpServers"][name],
}
)
else:
# 如果要停用服务器
if name in self.tool_mgr.mcp_client_dict:
self.tool_mgr.mcp_service_queue.put_nowait(
{
"type": "terminate",
"name": name,
}
)
return Response().ok(None, f"成功更新 MCP 服务器 {name}").__dict__
else:
return Response().error("保存配置失败").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"更新 MCP 服务器失败: {str(e)}").__dict__
async def delete_mcp_server(self):
try:
server_data = await request.json
name = server_data.get("name", "")
if not name:
return Response().error("服务器名称不能为空").__dict__
config = self.load_mcp_config()
if name not in config["mcpServers"]:
return Response().error(f"服务器 {name} 不存在").__dict__
# 删除服务器配置
del config["mcpServers"][name]
if self.save_mcp_config(config):
# 关闭并删除MCP客户端
if name in self.tool_mgr.mcp_client_dict:
self.tool_mgr.mcp_service_queue.put_nowait(
{
"type": "terminate",
"name": name,
}
)
return Response().ok(None, f"成功删除 MCP 服务器 {name}").__dict__
else:
return Response().error("保存配置失败").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"删除 MCP 服务器失败: {str(e)}").__dict__
+9 -2
View File
@@ -6,6 +6,7 @@ from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger, pip_installer
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
from astrbot.core.config.default import VERSION
from astrbot.core import DEMO_MODE
class UpdateRoute(Route):
@@ -95,8 +96,7 @@ class UpdateRoute(Route):
logger.error(f"更新依赖失败: {e}")
if reboot:
# threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
self.core_lifecycle.restart()
await self.core_lifecycle.restart()
return (
Response()
.ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。")
@@ -127,6 +127,13 @@ class UpdateRoute(Route):
return Response().error(e.__str__()).__dict__
async def install_pip_package(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
data = await request.json
package = data.get("package", "")
if not package:
+25 -11
View File
@@ -20,7 +20,12 @@ DATAPATH = os.path.abspath(
class AstrBotDashboard:
def __init__(self, core_lifecycle: AstrBotCoreLifecycle, db: BaseDatabase) -> None:
def __init__(
self,
core_lifecycle: AstrBotCoreLifecycle,
db: BaseDatabase,
shutdown_event: asyncio.Event,
) -> None:
self.core_lifecycle = core_lifecycle
self.config = core_lifecycle.astrbot_config
self.data_path = os.path.abspath(os.path.join(DATAPATH, "dist"))
@@ -45,6 +50,10 @@ class AstrBotDashboard:
self.sfr = StaticFileRoute(self.context)
self.ar = AuthRoute(self.context)
self.chat_route = ChatRoute(self.context, db, core_lifecycle)
self.tools_root = ToolsRoute(self.context, core_lifecycle)
self.conversation_route = ConversationRoute(self.context, db, core_lifecycle)
self.shutdown_event = shutdown_event
async def auth_middleware(self):
if not request.path.startswith("/api"):
@@ -73,11 +82,6 @@ class AstrBotDashboard:
r.status_code = 401
return r
async def shutdown_trigger_placeholder(self):
while not self.core_lifecycle.event_queue.closed: # noqa: ASYNC110
await asyncio.sleep(1)
logger.info("管理面板已关闭。")
def check_port_in_use(self, port: int) -> bool:
"""
跨平台检测端口是否被占用
@@ -122,7 +126,15 @@ class AstrBotDashboard:
def run(self):
ip_addr = []
port = self.core_lifecycle.astrbot_config["dashboard"].get("port", 6185)
host = self.core_lifecycle.astrbot_config["dashboard"].get("host", "127.0.0.1")
host = self.core_lifecycle.astrbot_config["dashboard"].get("host", "0.0.0.0")
logger.info(f"正在启动 WebUI, 监听地址: http://{host}:{port}")
if host == "0.0.0.0":
logger.info(
"提示: WebUI 将监听所有网络接口,请注意安全。(可在 data/cmd_config.json 中配置 dashboard.host 以修改 host"
)
if host not in ["localhost", "127.0.0.1"]:
try:
ip_addr = get_local_ip_addresses()
@@ -144,7 +156,7 @@ class AstrBotDashboard:
raise Exception(f"端口 {port} 已被占用")
display = f"\n ✨✨✨\n AstrBot v{VERSION} 管理面板已启动,可访问\n\n"
display = f"\n ✨✨✨\n AstrBot v{VERSION} WebUI 已启动,可访问\n\n"
display += f" ➜ 本地: http://localhost:{port}\n"
for ip in ip_addr:
display += f" ➜ 网络: http://{ip}:{port}\n"
@@ -158,7 +170,9 @@ class AstrBotDashboard:
logger.info(display)
return self.app.run_task(
host=host,
port=port,
shutdown_trigger=self.shutdown_trigger_placeholder,
host=host, port=port, shutdown_trigger=self.shutdown_trigger
)
async def shutdown_trigger(self):
await self.shutdown_event.wait()
logger.info("AstrBot WebUI 已经被优雅地关闭")
+4
View File
@@ -0,0 +1,4 @@
# What's Changed
1. 默认账户密码登录成功后弹出修改警告
2. 将 WebUI 默认 host 改变回 v3.4.38 之前的版本以减少兼容性问题。
+59
View File
@@ -0,0 +1,59 @@
# What's Changed
> 📢 AstrBot 上架宝塔面板 Docker 应用商店了!
> 📢 在升级前,请完整阅读本次更新日志。
## ✨ 新增的功能
1. ‼️ 新增支持接入 MCP 服务器 @Soulter @AraragiEro
1. ‼️ 新增支持本地渲染 Markdown,并支持自定义字体,详见 -> [#957](https://github.com/Soulter/AstrBot/issues/957#issuecomment-2749981802)
2. 新增支持在 WebUI 管理所有与大模型的对话
3. 适配完整的 function-calling 流程。[#804](https://github.com/Soulter/AstrBot/issues/804) [#566](https://github.com/Soulter/AstrBot/issues/566)
4. 新增支持消息平台热重载,不再需要重启 AstrBot
5. 新增支持阿里云百炼应用的 RAG 应用 [#878](https://github.com/Soulter/AstrBot/issues/878)
6. 新增 `/plugin get` OP 指令下载插件。如 `/plugin get Raven95676/astrbot_plugin_wordle`
7. 新增 `/newgroup` OP 指令,支持私聊 bot 给指定群聊创建新的对话。by @LunarMeal
8. Gewechat 下支持 `添加好友`, `接收/发送视频`, `获取群信息`, `接收/发送表情包` by @Moyuyanli @Soulter @XuYingJie-cmd @NiceAir
9. Telegram 下支持接收和处理表情包(Sticker) @Raven95676
## 🎈 功能性优化
0. 更加美观的 WebUI 设计,降低疲劳程度。
1. 微信下,忽略 `微信团队` 的消息 [#859](https://github.com/Soulter/AstrBot/issues/859)
2. 完善 Dify 的图片输入功能 [#893](https://github.com/Soulter/AstrBot/issues/893)
3. 消息平台和配置提供商配置页中,自动更新旧的配置项
4. 优化钉钉在配置错误之后堵塞整个线程的问题 [#885](https://github.com/Soulter/AstrBot/issues/885)
5. WebUI 删除插件时提供二次确认避免误删 @zhx8702
6. WebUI 优化新版本时的信息显示
7. 发送消息失败时的报错回显优化
8. 改善所有消息平台的优雅退出逻辑
9. 空 @ 时调用 LLM 获得更加富有人格的回复 by @advent259141
## 🐛 修复的 Bug
1. 修复图片没有被存储到聊天上下文历史记录
2. 修复 Telegram 下无法识别图片描述(Caption) [#910](https://github.com/Soulter/AstrBot/issues/910)
3. 修复 Telegram Topic 群组下引用消息来源错误的问题 [#908](https://github.com/Soulter/AstrBot/issues/908)
4. 修复 Telegram 下 `/start` 指令的一些问题 [#751](https://github.com/Soulter/AstrBot/issues/751)
5. WebUI 插件市场卡片显示风格的过滤问题。[#927](https://github.com/Soulter/AstrBot/issues/927)
6. 统一 SSL 证书验证逻辑,修复 `SSLCertVerificationError` 的问题。by @IGCrystal [#950](https://github.com/Soulter/AstrBot/issues/950)
7. 修复可能形成 SQL 注入的风险
8. 修复本地上传插件时无法重载插件的问题 [#995](https://github.com/Soulter/AstrBot/issues/995) by @zhx8702
## 🧩 新增的插件
1. astrbot_plugin_majsoul-master - 雀魂多功能插件 - by @kterna
2. astrbot_plugin_server - 可视化服务器状态卡片,/status 或 /状态查询 查看 - by @yanfd @Meguminlove
3. astrbot_plugin_Getcwm - 刺猬猫小说数据获取与画图插件 - by @Li-shi-ling
4. astrbot_plugin_anti_withdrawal - 防撤回插件,目前只支持微信私聊群聊的文本消息,将撤回的消息记录并发送给设定的人 - by @NiceAir
5. astrbot_plugin_hello77 - 游戏梗自动回复插件 - by @ttq7
6. astrbot_plugin_push_lite - Webhook 轻量级推送插件 - @Raven95676
7. astrbot_plugin_pokecheck - 检测“戳”关键词的插件 - @huanyan434
8. astrbot_plugin_MultiAI_PollPad - 轮询调用配置的大语言模型输出多个结果。同时将 AI 结果拷贝至在线文本编辑器 - by @Ynkcc
9. astrbot_plugin_box - / - by @Zhalslar
10. astrbot_plugin_Translation - 通过调用百度翻译 API 实现翻译文本 - by @zengweis
11. astrbot_plugin_wordle_2 - Wordle 游戏插件 - by @Raven95676 @whzcc
12. astrbot_plugin_mai_sgin - 舞萌出勤与退勤签到插件 - by @Rinyin
13. astrbot_plugin_Lolicon - Lolicon API 随机动漫图片插件 - by @ttq7
14. astrbot_plugin_aiocensor - 综合内容安全+群管插件 - by @Raven95676
+30
View File
@@ -0,0 +1,30 @@
# What's Changed
> 📢 在升级前,请完整阅读本次更新日志。
## ✨ 新增的功能
1. 适配 `gemini-2.0-flash-exp-image-generation` 对图片模态的输入 [#1017](https://github.com/Soulter/AstrBot/issues/1017)
2. 在 MessageChain 类中添加 at 和 at_all 方法,用于快速添加 At 消息 @left666
3. Gewechat Client 增加获取通讯录列表接口
4. 支持 /llm 指令快捷启停 LLM 功能 [#296](https://github.com/Soulter/AstrBot/issues/296)
## 🎈 功能性优化
1. Edge TTS 支持使用代理
2. 在 Lifecycle 新增插件资源清理逻辑 @Raven95676
3. Docker 镜像提供内置 FFmpeg [#979](https://github.com/Soulter/AstrBot/issues/979)
4. 优化无对话情况下设置人格的反馈 @Raven95676
5. 若禁用提供商,自动切换到另一个可用的提供商 @Raven95676
6. openai_source 同步支持随机请求均衡,同时优化 LLM 请求逻辑的异常处理
7. 保存 shared_preferences 时强制刷新文件缓冲区
8. 优化空 At 回复 @advent259141
## 🐛 修复的 Bug
1. 插件更新时没有正确应用加速地址
2. newgroup 指令名显示错误
## 🧩 新增的插件
待补充
+31
View File
@@ -0,0 +1,31 @@
# What's Changed
> 📢 在升级前,请完整阅读本次更新日志。
## ✨ 新增的功能
1. 安装完插件后自动弹出插件仓库 README 对话框 @zhx8702
4. 支持阿里云百炼 TTS@Soulter
5. 支持 Telegram MarkdownV2 渲染 @Soulter
6. 支持 钉钉 Markdown 渲染 @Soulter
6. 增加对 Gemini 系列模型的输入安全设置参数支持 @AliveGh0st
7. 支持手动设置时区以应对容器、国外用户的时区问题 @anka-afk @Raven95676 @Soulter
8. 插件市场显示帮助按钮 @Soulter
## 🎈 功能性优化
1. WebUI 的日志通信使用 SSE 替代 Websockets @Soulter
2. 在发送消息之前统一检查消息内容是否为空, 不允许发送空消息, 以解决该消息内容不支持查看以及 Gemini 返回 `<empty content>` 问题 @anka-afk
3. 更新 Dify 平台链接为官方域名 by @Captain-Slacker-OwO
4. 人格 prompt 输入框支持调节高度 @Soulter
## 🐛 修复的 Bug
1. 将最多携带对话数量修改回 `-1` 时出现报错 #1074 @anka-afk
2. 修复无法识别到函数调用异常的问题 by @Soulter
3. 修复 aiocqhttp 适配器下空白 plain 导致的 `the object is not a proper segment chain` 报错问题 @Soulter
4. 修复阿里百炼应用无法多轮会话的问题 @Soulter
## 🧩 新增的插件
待补充
+11 -6
View File
@@ -1,16 +1,21 @@
version: '3.8'
# 当接入 QQ NapCat 时,请使用这个 compose 文件一键部署: https://github.com/NapNeko/NapCat-Docker/blob/main/compose/astrbot.yml
services:
astrbot:
image: soulter/astrbot:latest
container_name: astrbot
restart: always
ports: # mappings description: https://github.com/Soulter/AstrBot/issues/497
- "6185:6185"
- "6195:6195" # optional, wecom default port
- "6199:6199" # optional, aiocqhttp default port
- "6196:6196" # optional, qq official webhook default port
- "11451:11451" # optional, gewechat default port
- "6185:6185" # 必选,AstrBot WebUI 端口
- "6195:6195" # 可选, 企业微信 Webhook 端口
- "6199:6199" # 可选, QQ 个人号 WebSocket 端口
- "6196:6196" # 可选, QQ 官方接口 Webhook 端口
- "11451:11451" # 可选, 微信个人号 Webhook 端口
environment:
- TZ=Asia/Shanghai
volumes:
- ./data:/AstrBot/data
- /etc/timezone:/etc/timezone:ro
# - /etc/timezone:/etc/timezone:ro
- /etc/localtime:/etc/localtime:ro
+2 -1
View File
@@ -21,9 +21,10 @@
"axios-mock-adapter": "^1.22.0",
"chance": "1.1.11",
"date-fns": "2.30.0",
"highlight.js": "^11.11.1",
"js-md5": "^0.8.3",
"lodash": "4.17.21",
"marked": "^15.0.6",
"marked": "^15.0.7",
"pinia": "2.1.6",
"remixicon": "3.5.0",
"vee-validate": "4.11.3",
@@ -0,0 +1,44 @@
<template>
<v-dialog v-model="isOpen" max-width="400">
<v-card>
<v-card-title class="text-h6">{{ title }}</v-card-title>
<v-card-text>{{ message }}</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="gray" @click="handleCancel">取消</v-btn>
<v-btn color="red" @click="handleConfirm">确定</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
</template>
<script setup>
import { ref } from "vue";
const isOpen = ref(false);
const title = ref("");
const message = ref("");
let resolvePromise = null; // Promise
const open = (options) => {
title.value = options.title || "确认操作";
message.value = options.message || "你确定要执行此操作吗?";
isOpen.value = true;
return new Promise((resolve) => {
resolvePromise = resolve; // Promise
});
};
const handleConfirm = () => {
isOpen.value = false;
if (resolvePromise) resolvePromise(true); // Promise
};
const handleCancel = () => {
isOpen.value = false;
if (resolvePromise) resolvePromise(false); // Promise
};
defineExpose({ open }); // `confirmPlugin.ts` 访 `open`
</script>

Some files were not shown because too many files have changed in this diff Show More