Compare commits

..

408 Commits

Author SHA1 Message Date
Soulter 062af1ac08 🎈 perf: 优化 WebUI 日志错误处理 2025-04-07 10:38:03 +08:00
Soulter 79d38f9597 📦release: v3.5.2 2025-04-06 22:36:31 +08:00
Soulter 4d186baa35 Merge pull request #1128 from anka-afk/anka-dev
feature: 实现了 #1127 还有 #1133 还有 #1143
2025-04-06 22:22:01 +08:00
Raven95676 e54eaab842 将验证器字典移到类级别,避免重复创建 2025-04-05 21:19:53 +08:00
Raven95676 43b6297b5d reminder将时区设置移入try块,统一为self.timezone 2025-04-05 21:08:52 +08:00
Raven95676 c20f4f5adf 删除默认值,调整logger逻辑 2025-04-05 21:03:02 +08:00
Soulter dc1f222cd2 fix: 使用 zoneinfo 替代 tzinfo; 默认不设置时区(使用系统默认时区) 2025-04-05 17:27:46 +08:00
Soulter c2b687212c cleanup 2025-04-05 16:51:06 +08:00
Soulter 849913276d 🎈 perf: 钉钉支持 Markdown 渲染输出
fixes: #1104
2025-04-05 16:29:14 +08:00
Soulter 23579c1e4a 🐛 fix: 阿里百炼应用无法多轮会话
fixes: #1123
2025-04-05 16:21:41 +08:00
Soulter e031161fd4 🐛 修复: 移除文本输入框的 auto-grow 属性
fixes: #1038
2025-04-05 15:58:17 +08:00
Soulter 4800ee6c0a Merge pull request #1152 from AstrBotDevs/feat-log-filter
 feat: 更新日志发布机制,支持日志级别和内容的字典格式,增加日志筛选功能
2025-04-05 15:49:09 +08:00
Soulter d3a7fef9b0 🐛 修复: 移除多余的 console 语句 2025-04-05 15:46:45 +08:00
Soulter 40822fe77a feat: 更新日志发布机制,支持日志级别和内容的字典格式,增加日志筛选功能
fixes: #1010
2025-04-05 15:43:40 +08:00
Soulter 837b670213 feat(webui): 支持修改列表项
fixes: #1086
2025-04-05 15:10:44 +08:00
Soulter 57ce69f3fb feat: WebChat 支持语音输出
fixes: #1087
2025-04-05 15:02:34 +08:00
anka be022c4894 fix: add StarTools to api 2025-04-05 11:55:25 +08:00
anka 8a366964bb feature: 增加时区设置支持 2025-04-05 11:52:51 +08:00
anka ee86b68470 fix: 漏加classmethod了! 2025-04-05 01:15:56 +08:00
anka 60352307aa fix: 重生之我要苦读设计模式, 终于知道怎么整了哈哈哈: 使用静态类实现工具集合, 并且正确初始化 2025-04-05 01:11:10 +08:00
anka 3ebd2f746f feature: 添加插件工具类, 暂时这么多 2025-04-05 00:51:52 +08:00
anka 1c1a65b637 fix: 全部消息段的检验弄好了! 2025-04-05 00:21:28 +08:00
anka 010e60d029 Merge remote-tracking branch 'origin/HEAD' into anka-dev 2025-04-04 23:13:43 +08:00
Soulter 7a25568861 Merge pull request #1131 from AliveGh0st/feature/gemini-safety-settings
feature:增加对Gemini系列模型的安全设置参数支持
2025-04-04 21:22:58 +08:00
AliveGh0st 5f4f913661 feat: 增加对 Gemini 系列模型的输入安全设置参数支持
fixes: #216

Squashed:

Update astrbot/core/config/default.py

描述更正.

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

🎨 style: clean up

🐛 fix: 修复安全设置参数的默认值为列表
2025-04-04 21:12:51 +08:00
Soulter ccd0e34a53 Merge pull request #1145 from AstrBotDevs/feat-telegram-markdownv2
 feat: 支持 Telegram MarkdownV2 渲染
2025-04-04 20:54:04 +08:00
Soulter 72f1ffccd3 feat: 支持 Telegram MarkdownV2 渲染
fixes: #649 #907
2025-04-04 20:52:22 +08:00
Soulter ea7a52945f Merge pull request #1132 from Captain-Slacker-OwO/dify-md
docs: 更新 Dify 平台链接为官方域名
2025-04-04 01:12:19 +08:00
Soulter 89d4d1351a Merge pull request #1135 from AstrBotDevs/feat-dashscope-tts
feat: 支持阿里云百炼 TTS
2025-04-04 01:03:36 +08:00
Soulter b757c91d93 🐛 fix: 修复无法识别到函数调用异常的问题 2025-04-04 01:02:39 +08:00
Soulter 27203d7a4d 🐛 fix: update voice key name 2025-04-04 00:47:50 +08:00
Soulter 9ad4e18ac5 feat: 支持阿里云百炼 TTS 2025-04-04 00:32:37 +08:00
anka fcdc8f3ce7 Merge remote-tracking branch 'origin/HEAD' into anka-dev 2025-04-03 21:57:24 +08:00
Captain-Slacker-OwO 78b994b84a docs: 更新 Dify 平台链接为官方域名
将 README 文件中的 Dify 平台链接从旧域名更新为官方域名 dify.ai,确保文档的准确性和权威性。
2025-04-03 19:00:44 +08:00
Soulter 58bfc677e2 🐛 fix: dify error Arg user must be provided
fixes #1073
2025-04-03 16:49:05 +08:00
Soulter 7d17285a0c 🐛 fix: ensure whitelist entries are stripped of whitespace and converted to strings 2025-04-03 16:44:37 +08:00
Soulter e9eb00a0d4 feat: 插件市场帮助按钮 2025-04-03 16:19:01 +08:00
anka 48d07af574 feature(fix?): 在发送消息之前统一检查消息内容是否为空, 不允许发送空消息, 以解决该消息内容不支持查看以及gemini返回<empty content>问题 2025-04-03 11:50:12 +08:00
Soulter 2fc62efd88 Merge pull request #1116 from AstrBotDevs/feat-log-sse
🏗 refactor: log 通信使用 SSE 替代 Websockets
2025-04-02 21:07:40 +08:00
Soulter be516d75bd 🐛 fix: upadte method name 2025-04-02 21:06:59 +08:00
Soulter 951d5fde85 🏗 refactor: log 通信使用 SSE 替代 Websockets 2025-04-02 20:59:25 +08:00
Soulter 1389abc052 Merge pull request #1112 from AstrBotDevs/fix-aiocqhttp-empty-plain
修复 aiocqhttp 适配器下空白 plain 导致的报错
2025-04-02 16:27:12 +08:00
Soulter 19ad67a77f 🐛 fix: 修复 aiocqhttp 适配器下空白 plain 导致的 the object is not a proper segment chain 报错问题 2025-04-02 16:24:36 +08:00
Soulter 641f308344 Update README.md 2025-04-01 11:35:56 +08:00
Soulter 9f097fa4d5 Update README.md 2025-04-01 11:33:38 +08:00
Soulter 5ad362c52b Merge pull request #1081 from anka-afk/anka-dev
fix #1074 and add some comment
2025-04-01 10:57:40 +08:00
Soulter 614f238a61 Merge pull request #1072 from zhx8702/feat-add-plugin-md-dialog
feat: 安装完插件后自动弹出插件仓库 README 对话框
2025-04-01 10:56:24 +08:00
zhx dec91950bc feat: 安装完插件后自动弹出插件仓库 README 对话框 2025-04-01 10:04:04 +08:00
anka 6cef9c23f0 bug fix: #1074 修改最多携带对话数量时出现bug 2025-03-31 22:41:23 +08:00
anka 3f568bf136 Merge remote-tracking branch 'origin/HEAD' into anka-dev 2025-03-31 22:32:40 +08:00
anka 5484b421ce perf: 增加部分注释 2025-03-31 22:30:43 +08:00
Soulter 02f21e07d3 📦 release: v3.5.1 2025-03-31 10:59:32 +08:00
Soulter fff1f23a83 Update README.md 2025-03-31 00:57:23 +08:00
Soulter a056ec0d38 Merge pull request #1065 from AstrBotDevs/perf-openai-source-balance
🎈 perf: OpenAI sources supports api key load balance(random)
2025-03-30 22:53:27 +08:00
Soulter 2eb9e5dde3 perf: 添加重试等待 2025-03-30 22:51:34 +08:00
渡鸦95676 627d2a4701 新增重试间隔 2025-03-30 22:33:21 +08:00
Soulter 76895fe86d chore: improve variable names 2025-03-30 22:12:34 +08:00
Soulter 64c3c85780 Merge pull request #1056 from Raven95676/master
perf: 优化无对话情况下设置人格的反馈;若禁用提供商,自动切换到另一个可用的提供商
2025-03-30 22:10:23 +08:00
Soulter 7288348857 🎈 perf: OpenAI sources supports api key load balance(random) 2025-03-30 22:00:45 +08:00
Soulter 62e73299b1 🐛 fix: forcely write shared preference data
Note: this is a fast fix for recent feedbacks, we'll improve its performance.
2025-03-30 21:33:41 +08:00
Raven95676 fe76c41ed8 perf: 若禁用提供商,自动切换到另一个可用的提供商 2025-03-30 15:18:48 +08:00
Raven95676 1a92edf8be perf: 优化无对话情况下设置人格的反馈 2025-03-30 14:38:40 +08:00
Soulter b63b606a4e docs: 推荐使用 uv 进行手动部署 2025-03-30 10:39:14 +08:00
Soulter 8e2ef3d22b Merge pull request #1050 from advent259141/master
回复空@功能的修复
2025-03-30 00:15:26 +08:00
Gao Jinzhe c6c4a32283 Add files via upload 2025-03-29 22:37:18 +08:00
Soulter b70b3b158e feat: 支持 gemini-2.0-flash-exp-image-generation 对图片模态的输入 #1017 2025-03-29 20:51:27 +08:00
Soulter 3d59ab8108 fix: conversation and tool use page refresh 404 2025-03-29 19:17:56 +08:00
Soulter b6c3089510 🎈 perf: 优化空 at 回复 2025-03-29 19:09:35 +08:00
Soulter bd92aac280 feat: 支持 /llm 指令快捷启停 LLM 功能 #296 2025-03-29 18:31:07 +08:00
Soulter 5299e802e9 Merge pull request #1046 from AstrBotDevs/feat-docker-embedded-ffmpeg
docker 镜像提供内置 ffmpeg
2025-03-29 17:53:40 +08:00
Soulter 8e5a57d7dd Merge pull request #1045 from Raven95676/master
在lifecycle新增插件资源清理逻辑
2025-03-29 17:53:16 +08:00
Soulter beaa324fb6 Merge pull request #1012 from Zhenyi-Wang/master
feat: gewechat client增加获取通讯录列表接口
2025-03-29 17:51:35 +08:00
Soulter 79e64fe206 Merge pull request #1011 from left666/left666
feat(core): 在 MessageChain 类中添加 at 和 at_all 方法
2025-03-29 17:50:55 +08:00
Soulter 93f525e3fe 🎈 perf: edge tts 支持使用代理;移除了一些不需要的方法 2025-03-29 17:48:22 +08:00
Soulter aacb803c64 Merge pull request #999 from Futureppo/master
部分api获取不到model导致key泄露,使用正则表达式过滤掉key内容
2025-03-29 17:43:10 +08:00
Soulter 8a0665b222 🎈 feat: 更新 Dockerfile,添加 Node.js 支持并优化依赖安装 2025-03-29 17:42:31 +08:00
Soulter 20e41a7f73 🐛 fix: newgroup 指令名显示错误 2025-03-29 17:42:31 +08:00
Soulter 93a1699a35 Update README.md 2025-03-29 17:42:31 +08:00
Soulter c33c07e4af Update README.md 2025-03-29 17:42:31 +08:00
Soulter c7484d0cc9 Update README.md 2025-03-29 17:42:31 +08:00
Soulter fb85a7bb35 feat: add demo mode 2025-03-29 17:42:31 +08:00
Soulter 42ff9a4d34 Update README.md 2025-03-29 17:42:31 +08:00
Soulter 005e9eae7c 🐛 fix: 插件更新时没有正确应用加速地址 2025-03-29 17:42:31 +08:00
Soulter 3e325debcc Update README.md 2025-03-29 17:42:31 +08:00
Soulter a221de9a2b 🐛 fix: 修复 LLM 响应后事件钩子无法生效的问题 2025-03-29 17:42:31 +08:00
Soulter 32b0cc1865 Update README.md 2025-03-29 17:42:31 +08:00
Soulter bbf85f8a12 🐛 fix: remove error logging for empty result and refresh extensions after upload 2025-03-29 17:42:31 +08:00
Soulter 67a0172b28 📦 release: v3.5.0 2025-03-29 17:42:31 +08:00
zhx fb19d4d45b fix: install_plugin_from_file 方法load传参数改为文件名 2025-03-29 17:42:31 +08:00
Soulter a156b1af14 feat: 支持通过指令下载插件 /plugin get 2025-03-29 17:42:31 +08:00
Soulter a604b4943c 🎈 perf: 优化新版本时的信息显示 2025-03-29 17:42:31 +08:00
pre-commit-ci[bot] 3f0b6435d9 🎈 auto fixes by pre-commit hooks 2025-03-29 17:42:31 +08:00
Gao Jinzhe e0f029e2cb Add files via upload 2025-03-29 17:42:31 +08:00
Soulter 89d3fd5fab 🎈 perf: 优化 WebUI 对话数据库中文历史检索 2025-03-29 17:42:31 +08:00
Soulter a38b00be6b 🐛 fix: 修复部分可能形成 SQL 注入的风险 2025-03-29 17:42:31 +08:00
Futureppo 0e8d52b591 :ballon: feat: 使用正则表达式过滤掉 /model 可能暴露的 api_key
Squashed:

更新正则表达式

🎈 auto fixes by pre-commit hooks

Update main.py

Update main.py

chore: bugfixes
2025-03-29 17:40:48 +08:00
Soulter 298c77740d feat: docker 镜像提供内置 ffmpeg #979 2025-03-29 17:26:57 +08:00
Raven95676 c681aae8ee 修复日志问题 2025-03-29 17:25:38 +08:00
Raven95676 faef98b089 在lifecycle新增插件资源清理逻辑 2025-03-29 17:07:12 +08:00
Soulter 84a3e0a30b 🎈 feat: 更新 Dockerfile,添加 Node.js 支持并优化依赖安装 2025-03-29 16:36:02 +08:00
Soulter 69bd553ce0 Merge pull request #1035 from AstrBotDevs/fix-1034-bug
🐛 fix: groupnew 指令名显示错误
2025-03-28 23:46:30 +08:00
Soulter fd0c0f8975 🐛 fix: newgroup 指令名显示错误 2025-03-28 23:45:19 +08:00
Zhenyi-Wang 860ceb06b4 Merge branch 'Soulter:master' into master 2025-03-28 21:27:25 +08:00
anka ecf501bf72 Merge remote-tracking branch 'origin/HEAD' into anka-dev 2025-03-28 19:04:35 +08:00
Soulter 81a2ed1e25 Update README.md 2025-03-28 18:20:33 +08:00
Soulter 76ab28338a Update README.md 2025-03-28 13:24:41 +08:00
Soulter 9a56c9630f Update README.md 2025-03-28 13:23:29 +08:00
anka 53b9497c18 perf: 增加部分注释 2025-03-27 21:32:38 +08:00
Soulter 750b16b6ee feat: add demo mode 2025-03-27 15:54:23 +08:00
anka 0ee3e0779a Merge remote-tracking branch 'origin/HEAD' into anka-dev 2025-03-27 15:21:04 +08:00
pre-commit-ci[bot] 333c2d9299 🎈 auto fixes by pre-commit hooks 2025-03-27 03:21:43 +00:00
Zhenyi Wang ad37ff5048 feat: gewechat client增加获取通讯录列表接口 2025-03-27 11:17:52 +08:00
pre-commit-ci[bot] 33f86f3bde 🎈 auto fixes by pre-commit hooks 2025-03-27 02:56:55 +00:00
Soulter 8acb969a49 Update README.md 2025-03-27 10:39:18 +08:00
left666 b74b5933b8 feat(core): 在 MessageChain 类中添加 at 和 at_all 方法
- 新增 at 方法,用于添加 At 消息到消息链中
- 新增 at_all 方法,用于添加 AtAll 消息到消息链中
2025-03-27 10:30:19 +08:00
Soulter 681c556b7e 🐛 fix: 插件更新时没有正确应用加速地址 2025-03-27 10:04:40 +08:00
anka 1746684e52 perf: 修改部分注释 2025-03-26 23:52:03 +08:00
Soulter 0b93d06555 Update README.md 2025-03-26 20:51:53 +08:00
anka 8a8b8c7c27 Merge remote-tracking branch 'origin/master' into anka-dev 2025-03-26 17:59:53 +08:00
anka 6b6577006d perf: 格式化 2025-03-26 17:59:30 +08:00
Soulter 23ee5e81c9 🐛 fix: 修复 LLM 响应后事件钩子无法生效的问题 2025-03-26 17:56:55 +08:00
Soulter 483f55e4b1 Update README.md 2025-03-26 16:16:03 +08:00
Soulter 1bb1bc2553 🐛 fix: remove error logging for empty result and refresh extensions after upload 2025-03-26 15:43:56 +08:00
Soulter a4e4e36f94 📦 release: v3.5.0 2025-03-26 15:30:09 +08:00
Soulter 6849415812 Merge pull request #996 from zhx8702/fix-star-manager
fix: install_plugin_from_file 方法load传参数改为文件名
2025-03-26 15:26:53 +08:00
zhx 86f6cb038e fix: install_plugin_from_file 方法load传参数改为文件名 2025-03-26 15:06:33 +08:00
Soulter 7480a1d6ce feat: 支持通过指令下载插件 /plugin get 2025-03-26 14:33:45 +08:00
Soulter 3cd10117dd 🎈 perf: 优化新版本时的信息显示 2025-03-26 14:14:01 +08:00
Soulter 0caf19d390 Merge pull request #937 from advent259141/master
将对只有一个 @ 的消息内容的处理改成调用llm回复
2025-03-26 13:54:43 +08:00
anka 5c14ebb049 Merge remote-tracking branch 'origin/master' into anka-dev 2025-03-26 13:53:21 +08:00
anka 9717a736b1 perf: 更新部分描述 2025-03-26 13:50:54 +08:00
Soulter 9c9ab50d1a 🎈 perf: 优化 WebUI 对话数据库中文历史检索 2025-03-26 13:50:11 +08:00
Soulter d4bcb8174e 🐛 fix: 修复部分可能形成 SQL 注入的风险 2025-03-26 13:41:18 +08:00
anka 9e7fe773bd perf: 更新部分注释 2025-03-26 11:14:46 +08:00
Soulter aca18fab0f feat: 优化配置文件中的提示信息,增强可读性 2025-03-26 00:56:51 +08:00
Soulter 691de01b79 feat: 支持设置最多携带对话数量 2025-03-26 00:46:15 +08:00
Soulter 3383f15142 Merge pull request #988 from Soulter/NiceAir/master
 feat: Update UI elements and improve layout in various components
2025-03-25 23:17:11 +08:00
Soulter 84c1593889 feat: Update UI elements and improve layout in various components 2025-03-25 21:52:15 +08:00
Soulter 3c80fa1e33 Update README.md 2025-03-25 21:31:23 +08:00
Soulter 06b16a1deb Merge pull request #983 from Soulter/feat-conversation-webui-mgr
 支持 WebUI 对话管理
2025-03-25 21:26:00 +08:00
Soulter 4c4246fb09 Merge pull request #982 from NiceAir/master
添加对gewe的表情包、引用消息、视频的支持
2025-03-25 21:25:00 +08:00
Soulter 364be1e9f6 🐛 fix: Handle missing defusedxml dependency for Gewechat message parsing 2025-03-25 21:21:38 +08:00
NiceAir f959ed71aa feat: Gewechat 支持表情包、引用消息、视频
Co-authored-by: Soulter <905617992@qq.com>
2025-03-25 21:00:12 +08:00
anka 5c4326c302 perf: 部分详细注释, 符合PEP8标准 2025-03-25 20:53:23 +08:00
Soulter 125fc3a622 feat: 支持 WebUI 对话管理 2025-03-25 19:44:46 +08:00
Soulter 6b9e785db3 Merge pull request #968 from Soulter/pre-commit-ci-update-config
🎈 pre-commit autoupdate
2025-03-25 15:03:39 +08:00
Soulter 25d34e9a43 Merge pull request #974 from zhx8702/feat-webui-add-search-keys
feat: 插件市场列表卡片过滤条件提出变量保持一致
2025-03-25 15:03:09 +08:00
Soulter 457d4aa1dc Merge pull request #976 from Raven95676/master
Improves Telegram adapter termination
2025-03-25 15:01:04 +08:00
Raven95676 ff0c0992ff Improves Telegram adapter termination 2025-03-25 14:46:20 +08:00
Soulter d379e012c4 🐛 fix: telegram /start issue #751 2025-03-25 14:03:46 +08:00
zhx 151fff26fd feat: 插件市场列表卡片过滤条件提出变量保持一致 2025-03-25 13:50:16 +08:00
Soulter 3d0d561215 Update compose.yml 2025-03-25 13:24:37 +08:00
Soulter 22d586ed7b Update compose.yml 2025-03-25 13:24:19 +08:00
Soulter 6dc19b29e8 🐛 fix: remove redundant validation call in config validation function #901 2025-03-25 12:56:48 +08:00
Soulter 50975a87d4 🐛 fix: handle message sending failures with error logging 2025-03-25 12:34:43 +08:00
Soulter ce721d9f0f 🐛 fix: platform adapter server blocks ctrl+c 2025-03-25 11:31:46 +08:00
Soulter 20510a33f7 feat: improve pyproject and use uv as package mgr 2025-03-25 11:07:20 +08:00
pre-commit-ci[bot] 3abd9c8763 🎈 pre-commit autoupdate
updates:
- [github.com/astral-sh/ruff-pre-commit: v0.11.0 → v0.11.2](https://github.com/astral-sh/ruff-pre-commit/compare/v0.11.0...v0.11.2)
2025-03-24 17:08:12 +00:00
Soulter e9eff7420b feat: 更加完善和美观的 本地 Markdown 渲染 2025-03-25 00:56:19 +08:00
Soulter 64c250c9d8 🎈perf: 优化可能的 conversation 为 None 的问题 2025-03-25 00:06:25 +08:00
Soulter 8047f82bfd 🎈perf: 优化删除插件目录的逻辑,抛出异常细节;完善 mcp 未安装时的提示 2025-03-24 23:07:56 +08:00
Soulter af6467fb3d Merge pull request #962 from zhx8702/feat-webui-add-double-confirm
feat: 删除插件添加二次确认,插件列表添加非空判断
2025-03-24 23:01:43 +08:00
zhx 3ff1664aec feat: 删除多余代码 2025-03-24 20:27:05 +08:00
zhx 34ea2b44b8 Merge remote-tracking branch 'upstream/master' into feat-webui-add-double-confirm 2025-03-24 19:42:47 +08:00
Soulter 6c8d851109 Merge pull request #955 from Raven95676/master
Telegram适配器消息处理功能增强
2025-03-24 18:10:51 +08:00
Soulter d678299a74 Merge branch 'master' into master 2025-03-24 18:10:27 +08:00
Soulter 7aed0db2b6 Merge pull request #951 from IGCrystal/master
fix: fix SSLCertVerificationError
2025-03-24 18:05:49 +08:00
Soulter 0355524345 Merge branch 'master' into master 2025-03-24 17:58:00 +08:00
Soulter 0a43e4672e style: format codes 2025-03-24 17:57:28 +08:00
zhx 71e0ccdfec feat: 删除插件添加二次确认,插件列表添加非空判断 2025-03-24 16:41:54 +08:00
冰苷晶 1df33ac3c8 fix: fix error 2025-03-24 13:28:14 +08:00
pre-commit-ci[bot] 7334090ac1 🎈 auto fixes by pre-commit hooks 2025-03-24 05:20:37 +00:00
冰苷晶 6b0f044198 fix: fix other errors 2025-03-24 13:20:05 +08:00
pre-commit-ci[bot] ddf54c9cf8 🎈 auto fixes by pre-commit hooks 2025-03-24 04:32:21 +00:00
IGCrystal 7c64e184e2 Merge branch 'Soulter:master' into master 2025-03-24 12:32:16 +08:00
渡鸦95676 a904db033c Merge branch 'Soulter:master' into master 2025-03-24 12:19:17 +08:00
渡鸦95676 b234856b02 Remove unused variable
移除以通过ruff检查
在Ubuntu24.04LTS中,移除未见对现有功能的影响
2025-03-24 11:36:46 +08:00
Soulter 89d51d2afc 🎈 perf: config UI 2025-03-24 11:36:38 +08:00
Soulter 37cb9678e9 Merge pull request #826 from XuYingJie-cmd/master
新增了关于gewe发送视频的功能
2025-03-24 11:25:24 +08:00
pre-commit-ci[bot] 0500ff333a 🎈 auto fixes by pre-commit hooks 2025-03-24 02:50:28 +00:00
Raven95676 08528510ef Fix incorrect handling of reply messages within topics 2025-03-24 10:41:33 +08:00
Raven95676 ddbd03dc1e Adds sticker handling in Telegram adapter 2025-03-24 10:40:20 +08:00
Soulter ade87f378a 🎈 perf: UI 优化 2025-03-24 00:32:40 +08:00
冰苷晶 4db14b905f fix: fix error 2025-03-23 23:40:06 +08:00
pre-commit-ci[bot] b669b31451 🎈 auto fixes by pre-commit hooks 2025-03-23 15:07:22 +00:00
冰苷晶 1cb2b62f81 fix: fix error 2025-03-23 23:02:34 +08:00
Soulter e5828713cf 🎈 perf: improve ChatPage and ConfigPage UI 2025-03-23 22:57:02 +08:00
冰苷晶 d10cb84068 fix: fix SSLCertVerificationError 2025-03-23 22:55:07 +08:00
Soulter 4222f8516f Merge pull request #844 from AraragiEro/mcp_adapt
支持 MCP 服务并优化函数调用流程
2025-03-23 22:35:35 +08:00
Soulter 7f998c7611 chore: remove useless print output 2025-03-23 22:28:00 +08:00
Soulter db46000337 🎨 style: format codes 2025-03-23 22:22:11 +08:00
Soulter 1aac8d8041 feat: 适配完整的 function-calling 流程 2025-03-23 22:21:47 +08:00
Soulter c59c8e05f7 🐛 fix: tools result 2025-03-23 17:03:18 +08:00
Soulter 4942d0a629 feat: 在工具使用页面添加函数调用信息提示和链接功能 2025-03-23 17:00:38 +08:00
Soulter 873b7715f4 🎈 perf: 优化 MCP Client 异步 Event 管理 2025-03-23 16:51:28 +08:00
pre-commit-ci[bot] 98e7ed6920 🎈 auto fixes by pre-commit hooks 2025-03-23 08:34:05 +00:00
Soulter 046f5e645e feat: 完善 MCP 管理和实现 WebUI MCP 相关的页面 2025-03-23 16:33:44 +08:00
pre-commit-ci[bot] f5e5a7094c 🎈 auto fixes by pre-commit hooks 2025-03-23 06:39:13 +00:00
Gao Jinzhe 154125fee6 Add files via upload 2025-03-23 14:35:44 +08:00
pre-commit-ci[bot] 9f8e960ebe 🎈 auto fixes by pre-commit hooks 2025-03-23 03:31:20 +00:00
Soulter 4179b0be0a chore: 优化注解格式和 requirements.txt 2025-03-23 11:31:10 +08:00
Soulter 28bafa38db Merge branch 'master' into mcp_adapt 2025-03-23 11:01:44 +08:00
Soulter b07552565e Merge pull request #926 from Soulter/perf-graceful-shutdown
支持所有消息平台的优雅退出
2025-03-23 10:56:56 +08:00
Soulter c4427471d2 🎨 style: format codes 2025-03-23 00:25:26 +08:00
Soulter 08f81c6784 🐛 fix: 修复图片没有被存储到上下文中的问题 2025-03-23 00:23:42 +08:00
Soulter a471e98aca 🐛 fix: Telegram 下无法识别图片描述(Caption) #910 2025-03-23 00:23:01 +08:00
Soulter 75a8fcc8a0 🐛 fix: 修复 Telegram 下非默认群组话题引用消息异常 #906 2025-03-22 23:39:21 +08:00
Soulter 46ef76c168 feat: 支持消息平台的热重载 2025-03-22 19:54:54 +08:00
Soulter 66637446c9 Merge remote-tracking branch 'origin/master' into perf-graceful-shutdown 2025-03-22 19:26:35 +08:00
Soulter 21efeb888a Merge pull request #904 from LunarMeal/master
新增了newgroup指令
2025-03-22 19:18:06 +08:00
Soulter a4ee8b5322 Merge remote-tracking branch 'origin/master' into LunarMeal/master 2025-03-22 19:17:12 +08:00
Soulter 36519ac47e 🐛 fix: groupnew 设置为管理员指令 2025-03-22 19:14:58 +08:00
Soulter 3f514fceca 🎨 style: format codes 2025-03-22 19:07:47 +08:00
pre-commit-ci[bot] c2249fdfac 🎈 auto fixes by pre-commit hooks 2025-03-22 11:06:42 +00:00
Soulter c610719a44 feat: 为各平台适配器支持优雅关闭 2025-03-22 19:02:49 +08:00
Soulter 36a6c2461a 🐛 fix: 修复 Telegram Topic 群组下LLM 上下文及主动消息混乱的问题 #908 2025-03-22 18:15:43 +08:00
Soulter c29f22c39e Update PLUGIN_PUBLISH.yml 2025-03-22 15:51:35 +08:00
Soulter 30d3062944 🎈 perf: 优化钉钉在配置错误之后堵塞整个线程的问题 #885
a.k.a 帮钉钉擦屁股
2025-03-22 15:44:42 +08:00
Soulter 69ba75abf4 Update README.md 2025-03-22 01:26:03 +08:00
Soulter e4d486fec5 docs: 宝塔面板部署方式 2025-03-22 00:42:04 +08:00
Soulter f242144dcf 更新 README.md 2025-03-21 19:21:35 +08:00
Soulter 02dee2d664 🎈 perf: add error handling for missing pyffmpeg library in video sending functionality 2025-03-21 16:51:23 +08:00
Soulter a3dd2c3069 Merge remote-tracking branch 'origin/master' into XuYingJie-cmd/master 2025-03-21 16:49:15 +08:00
Soulter a23425e8aa Merge pull request #781 from Moyuyanli/master
添加gewe的群相关操作
2025-03-21 16:31:10 +08:00
Moyuyanli be79ddc9a3 fix:去掉跟post_text功能相同的接口方法 2025-03-21 16:24:31 +08:00
Soulter 7d71015e8c Update README.md 2025-03-21 16:12:25 +08:00
Soulter ad54549b51 Update README.md 2025-03-21 15:58:40 +08:00
Soulter 6cf032a164 Update compose.yml 2025-03-21 11:06:22 +08:00
Soulter 6390d796ac Update compose.yml 2025-03-21 11:05:44 +08:00
Soulter 98b8411905 Update compose.yml 2025-03-21 10:53:09 +08:00
LunarMeal ddf1029afa Merge branch 'master' of https://github.com/LunarMeal/AstrBot 2025-03-20 22:53:29 +08:00
LunarMeal 1effbc5cc9 fix 2025-03-20 22:53:21 +08:00
pre-commit-ci[bot] 414b645e9f 🎈 auto fixes by pre-commit hooks 2025-03-20 14:42:37 +00:00
LunarMeal 398c76f496 新增了newgroup指令 2025-03-20 22:39:49 +08:00
Soulter 1bc456dd95 🎈 perf: 改善一些术语描述 2025-03-20 20:31:36 +08:00
Soulter 2e8421884e Merge pull request #864 from Soulter/pre-commit-ci-update-config
🎈 pre-commit autoupdate
2025-03-20 20:23:45 +08:00
Soulter 70d9b193ac 🐛 fix: 修复私聊下 get_group 的一些问题 2025-03-20 20:18:20 +08:00
Moyuyanli b49c11004a fix:还原回原来的依赖信息 2025-03-20 19:57:35 +08:00
Soulter 34843eea90 🎨 style: format codes 2025-03-20 18:07:24 +08:00
pre-commit-ci[bot] 2d6d7f31e8 🎈 auto fixes by pre-commit hooks 2025-03-20 10:06:11 +00:00
Soulter 7a24cbff1c feat: 支持 aiocqhttp 适配器下的获取群消息 2025-03-20 18:05:44 +08:00
pre-commit-ci[bot] 1e7eb2cf1c 🎈 auto fixes by pre-commit hooks 2025-03-20 09:21:32 +00:00
Soulter 361256e016 chore: 添加了一些 gewechat client 的注释 2025-03-20 17:20:32 +08:00
Soulter 8838dbd003 🎨 style: format codes 2025-03-20 16:54:27 +08:00
pre-commit-ci[bot] 13a95e1f2b 🎈 auto fixes by pre-commit hooks 2025-03-20 08:42:40 +00:00
Soulter 1aaa451a3e Merge branch 'master' into Moyuyanli/master 2025-03-20 16:42:13 +08:00
Soulter cbba81e54d 🐛 fix: 无法接收图片 aiocqhttp 2025-03-20 16:03:41 +08:00
Soulter 370868dfac 🎈 perf: 消息平台和配置提供商配置页中,自动更新旧的配置,添加新的配置项 2025-03-20 13:22:49 +08:00
Soulter 77f692aae2 🎈 perf: 配置项显示优化 2025-03-20 13:17:27 +08:00
Soulter 9318e205ea feat: 阿里云百炼应用支持 RAG 应用 #878 2025-03-20 13:17:06 +08:00
Soulter ebcc717c19 🎈 perf: Dify 下支持更多类型的图片输入及提高代码复用性 #893
🐛 fix: 修复飞书下无法进行图片输入的问题
2025-03-20 11:21:45 +08:00
Soulter 4c16b564ee 🎈 perf: 忽略微信团队消息 #859 2025-03-19 01:09:01 +08:00
Soulter e2283d1453 🐛 fix: 修复 dify 下某些修改了 LLM 响应的插件可能不生效的问题 #876 2025-03-19 01:05:28 +08:00
Soulter d891801c5a v3.4.39 2025-03-18 22:43:35 +08:00
Soulter de75386944 🎈 perf: 登录后检查默认密码和弹出修改警告 2025-03-18 22:41:33 +08:00
Soulter 82dc37de50 style: format codes 2025-03-18 22:21:47 +08:00
Soulter b6fa7f62dc chore: 添加安全提示信息 2025-03-18 22:18:01 +08:00
Soulter f9e0a95c5e chore: 默认地址改回 0.0.0.0 2025-03-18 22:15:22 +08:00
pre-commit-ci[bot] b2c6e12647 🎈 auto fixes by pre-commit hooks 2025-03-17 17:10:06 +00:00
pre-commit-ci[bot] caffb83780 🎈 pre-commit autoupdate
updates:
- [github.com/astral-sh/ruff-pre-commit: v0.9.10 → v0.11.0](https://github.com/astral-sh/ruff-pre-commit/compare/v0.9.10...v0.11.0)
2025-03-17 17:09:59 +00:00
Soulter 8882cb5479 v3.4.38 2025-03-18 00:54:51 +08:00
Soulter 75dace2dee 🎈 perf: 优化配置页的显示 2025-03-18 00:16:47 +08:00
Soulter ad6487d042 🐛 fix: 修复部分指令可能造成的配置类型问题 2025-03-17 23:44:04 +08:00
Soulter a91604e8ab Merge pull request #853 from IGCrystal/master
🎈 perf: 优化了iframe窗口,新增跳转按钮
2025-03-17 23:25:26 +08:00
Soulter c364f7c643 🎈 perf: Dify 下当只有图片输入时的默认 prompt #837 2025-03-17 23:17:07 +08:00
Soulter 53435ba184 🐛 fix: 修复 model_config 中自定义的配置项(如温度)类型自动变回 string #854 2025-03-17 23:11:57 +08:00
Soulter 25f8d5519b 🐛 fix: LLOnebot 合并消息转发错误 #842 2025-03-17 22:42:48 +08:00
Moyuyanli 2e4fef6c66 feat:添加消息记录器 2025-03-17 16:02:55 +08:00
冰苷晶 80b2b7dc00 🎈 perf: 优化了iframe窗口 2025-03-16 21:35:30 +08:00
Alero 8585cd8e21 修复codecheck 2025-03-15 20:26:17 +08:00
Alero 9fa2a7eeea 修复codecheck 2025-03-15 20:24:36 +08:00
pre-commit-ci[bot] 2d1f74228d 🎈 auto fixes by pre-commit hooks 2025-03-15 12:10:17 +00:00
Alero 3d6f7aa0e1 修复codecheck 2025-03-15 20:09:49 +08:00
pre-commit-ci[bot] 3dea60366a 🎈 auto fixes by pre-commit hooks 2025-03-15 11:54:09 +00:00
Alero d4d9a1df4c feat:新增MCP服务支持并优化工具调用逻辑
引入MCP客户端支持,增加mcp_server.json配置样例,完善工具描述生成及调用逻辑以支持MCP服务工具功能。同时调整相关逻辑以区分本地工具与MCP工具的调用方式,提升扩展性和灵活性。
2025-03-15 19:47:06 +08:00
Soulter 7d6975fd31 Merge pull request #832 from IGCrystal/master
🎈 perf: 优化iframe窗口,加入了关闭按钮
2025-03-15 14:25:16 +08:00
IGCrystal 08be52ed17 Merge branch 'Soulter:master' into master 2025-03-15 12:05:27 +08:00
邹永赫 682a7700c2 Merge pull request #835 from zouyonghe/master
修改注册函数工具时的打印信息
2025-03-15 12:20:32 +09:00
pre-commit-ci[bot] 9d87009216 🎈 auto fixes by pre-commit hooks 2025-03-15 03:16:51 +00:00
邹永赫 ef86838f62 修改注册函数工具时的打印信息 2025-03-15 12:15:05 +09:00
Soulter 35468233f8 🎈 perf: supports for customizing webui host, wecom webhook server host, qq official webhook server host #821 2025-03-15 01:21:36 +08:00
Soulter 26e229867d 🐛fix: 可能的QQ平台回复消息带有末尾空白的问题 #822 2025-03-15 00:57:17 +08:00
Soulter 3a1578b3c6 feat: 支持 Dify 文件、图片、视频、音频输出。#819 2025-03-15 00:51:32 +08:00
冰苷晶 d5e3d2cbbc 🎈 perf: 优化iframe窗口,加入了关闭按钮 2025-03-14 20:23:15 +08:00
Moyuyanli c095248176 Merge remote-tracking branch 'origin/master' 2025-03-14 18:30:42 +08:00
Moyuyanli 44601c8954 fix:修复gewe的ModContacts消息类型 2025-03-14 18:30:27 +08:00
Soulter 135dbb8f07 style: clean codes 2025-03-14 18:02:00 +08:00
pre-commit-ci[bot] c95682a0c7 🎈 auto fixes by pre-commit hooks 2025-03-14 09:11:21 +00:00
Moyuyanli d177b9f7fa feat:添加主动添加好友事件 2025-03-14 17:11:10 +08:00
徐英杰 9b57615d94 新增了关于gewe发送视频的功能 2025-03-14 16:19:41 +08:00
Soulter c03f3eacd1 Update README.md 2025-03-13 23:03:36 +08:00
Soulter a26e395932 Merge pull request #817 from Soulter/feat-parse-reply
[Feature] 添加了 LLM 对消息平台引用回复内容的感知
2025-03-13 21:06:44 +08:00
Soulter 0870b87c96 🐛 fix: 获取引用消息失败时没有将引用消息段加入消息链 2025-03-13 20:59:52 +08:00
Soulter b52a44a7dd 🎨 stype: format codes 2025-03-13 20:44:08 +08:00
Soulter 0a290aafef Merge pull request #815 from diudiu62/perf-gewechat
微信有未处理的消息类型,导致控制台打印太多的日志
2025-03-13 20:39:39 +08:00
Soulter 9014d4c410 🎨 style: format codes 2025-03-13 20:36:41 +08:00
pre-commit-ci[bot] 60e58b4f5f 🎈 auto fixes by pre-commit hooks 2025-03-13 09:52:03 +00:00
Soulter 620e74a6aa Merge branch 'master' into feat-parse-reply 2025-03-13 17:51:12 +08:00
Soulter efa287ed35 feat: 支持 LLM 对引用消息的感知 #783 2025-03-13 17:40:28 +08:00
Soulter a24eb9d9b0 🏗 refactor: clean up AstrBotConfig component markup for improved readability 2025-03-13 17:02:58 +08:00
Soulter bd3dab8aae 🐛 fix: 插件管理的插件简介太长 “帮助”“操作”图标不显示 #790 2025-03-13 17:02:58 +08:00
Soulter 4fe1ebaa5b 🏗 refactor: improve styling and layout of AstrBotConfig component for enhanced readability 2025-03-13 17:02:58 +08:00
Soulter c5e944744b 🏗 refactor: enhance ConfigPage layout and styling for better user experience 2025-03-13 17:02:58 +08:00
Soulter 0c396181f7 🏗 refactor: 配置页样式重写 2025-03-13 17:02:58 +08:00
Soulter 0034474219 🐛 fix: sent message to wrong topic in topic group #801 2025-03-13 17:02:58 +08:00
shuiping233 8136ad8287 修复命令参数报错信息无法发送至qq官方机器人平台的bug 2025-03-13 17:02:58 +08:00
Soulter 681940d466 🐛 fix: 修复重载插件时函数工具可能多次家在的问题 2025-03-13 17:02:58 +08:00
Soulter 16488506e8 🐛 fix: 修复部分情况下文件无法上传到 Telegram 群组的问题 #601 2025-03-13 17:02:58 +08:00
邹永赫 122fccc041 修复无法发送非嵌套的转发消息的问题 2025-03-13 17:02:58 +08:00
邹永赫 9d0ad35403 支持嵌套转发,里层包含多条信息 2025-03-13 17:02:58 +08:00
邹永赫 f9ec97e026 支持嵌套转发 2025-03-13 17:02:58 +08:00
Soulter 95495a2647 🏗 refactor: clean up AstrBotConfig component markup for improved readability 2025-03-13 16:40:59 +08:00
Soulter e3310a605c 🐛 fix: 插件管理的插件简介太长 “帮助”“操作”图标不显示 #790 2025-03-13 16:36:35 +08:00
Soulter b55719bf28 🏗 refactor: improve styling and layout of AstrBotConfig component for enhanced readability 2025-03-13 15:59:20 +08:00
diudiu62 b957b51279 已知消息类型,没有业务处理,只是避免控制台打印太多的日志 2025-03-13 15:55:22 +08:00
Soulter 90bcfab369 🏗 refactor: enhance ConfigPage layout and styling for better user experience 2025-03-13 15:44:52 +08:00
Soulter f8a8e30641 🏗 refactor: 配置页样式重写 2025-03-13 15:37:53 +08:00
Soulter 25cb98e7a7 🐛 fix: sent message to wrong topic in topic group #801 2025-03-13 13:02:22 +08:00
Soulter 03e1bb7cf9 Merge pull request #807 from shuiping233/fix-#806
修复命令参数报错信息无法发送至qq官方机器人平台的bug
2025-03-13 10:05:24 +08:00
Soulter 85dbb24f3a 🐛 fix: 修复重载插件时函数工具可能多次家在的问题 2025-03-12 23:37:24 +08:00
shuiping233 d817635782 修复命令参数报错信息无法发送至qq官方机器人平台的bug 2025-03-12 18:09:25 +08:00
Soulter 2f4f237810 🐛 fix: 修复部分情况下文件无法上传到 Telegram 群组的问题 #601 2025-03-12 14:14:45 +08:00
邹永赫 5ac94d810f Merge pull request #794 from zouyonghe/dev/nested-forward
修复无法发送非嵌套的转发消息的问题
2025-03-12 12:01:33 +09:00
邹永赫 39dc46dc25 修复无法发送非嵌套的转发消息的问题 2025-03-12 11:59:53 +09:00
邹永赫 0d9cf725f7 Merge pull request #792 from zouyonghe/dev/nested-forward
支持嵌套转发,里层包含多条信息
2025-03-12 11:17:16 +09:00
邹永赫 e55dbead5b 支持嵌套转发,里层包含多条信息 2025-03-12 11:14:54 +09:00
邹永赫 7d046e5b30 Merge pull request #788 from zouyonghe/dev/nested-forward
支持嵌套转发
2025-03-12 08:50:50 +09:00
邹永赫 8b4693cf66 支持嵌套转发 2025-03-12 08:39:54 +09:00
Soulter a1172c9a82 feat: 支持解析回复消息 #783 2025-03-11 23:27:10 +08:00
Soulter 1ed2bd33f0 🐛 fix: 修复插件更新时显示未知更新的问题 2025-03-11 22:38:25 +08:00
Soulter 4c159bd0ba Merge pull request #785 from shuiping233/fix-qq-offical-image-upload-issue
修复了使用Image.fromBytes等包装的图片消息链无法通过qq官方机器人适配器发送的bug
2025-03-11 22:10:27 +08:00
Soulter 050654b2a9 🐛 fix: 修复 QQ 官方机器人适配器下发送base64图片消息段报错的问题。
Co-authored-by: shuiping233 <1944680304@qq.com>
2025-03-11 22:08:13 +08:00
Soulter 61b261e1b2 Merge pull request #780 from beat4ocean/master
fix: 修复gewechat平台用户本人发消息触发消息回复的bug
2025-03-11 21:55:44 +08:00
shuiping233 017b010206 修复了使用Image.fromBytes等包装的图片消息链无法通过qq官方机器人适配器发送的bug 2025-03-11 21:17:08 +08:00
pre-commit-ci[bot] 00f5189f58 🎈 auto fixes by pre-commit hooks 2025-03-11 09:16:43 +00:00
Moyuyanli 4a8309ed1f style:idea默认格式化了部分代码
feat:添加根据消息事件获取群信息的接口
2025-03-11 17:10:55 +08:00
Moyuyanli 76cfc31a1d feat:添加 Group 类型 2025-03-11 17:10:04 +08:00
Moyuyanli d9ec434699 feat:gewe的client添加 添加好友接口
feat:gewe的client添加 获取群信息/群成员接口
feat:gewe的client添加 添加群成员为好友接口
2025-03-11 17:08:33 +08:00
Soulter 239f3c40be 🎈 perf: 优化 WebUI 边栏宽度 2025-03-11 16:11:34 +08:00
Soulter 09c8c6e670 🐛 fix: 修复 aiocqhttp 下可能的设置管理员无效的问题 2025-03-11 15:52:30 +08:00
beat4ocean 7e4ad01c94 Merge branch 'Soulter:master' into master 2025-03-11 15:52:23 +08:00
beat4ocean ed98e269ef Merge remote-tracking branch 'origin/master' 2025-03-11 15:48:44 +08:00
beat4ocean b47d63334f fix: 修复gewechat平台用户本人发消息触发消息回复的bug 2025-03-11 15:48:28 +08:00
Soulter 5e2a3a5aea fix: 修复部分情况下 EdgeTTS 无法使用的问题
Co-authored-by: 需要哦 <2687427560@qq.com>
2025-03-11 15:29:51 +08:00
Soulter 1a7eb21fc7 Revert "🐛 fix: 修复 gewechat 部分场景下下载图片报错 #700"
This reverts commit c38fa77ce6.
2025-03-11 14:54:41 +08:00
Soulter 834a51cdc9 🐛 fix: 修复 OpenAI TTS API TypeError 报错 #755 2025-03-11 14:30:59 +08:00
Soulter 1b69d99c06 🐛 fix: 修复更新插件后插件重载不完全的问题 2025-03-11 14:20:24 +08:00
Soulter ad189933c6 Merge pull request #775 from roeseth/master
update compose.yml to mount system time and tz
2025-03-11 12:49:38 +08:00
Soulter 9d86ff32de Merge pull request #774 from Soulter/pre-commit-ci-update-config
🎈 pre-commit autoupdate
2025-03-11 11:40:57 +08:00
Soulter 278bb57a58 Merge pull request #772 from beat4ocean/master
fix: 修复个人微信非第一次登陆情况,已记录gewechat的appid失效设备不存在导致无法重新登陆个人微信的bug
2025-03-11 11:40:07 +08:00
pre-commit-ci[bot] 0ba494e0ba 🎈 auto fixes by pre-commit hooks 2025-03-11 02:11:25 +00:00
roeseth 8b247054bb update compose.yml to mount system time and tz 2025-03-10 19:07:45 -07:00
pre-commit-ci[bot] 7c5c8e4e0d 🎈 auto fixes by pre-commit hooks 2025-03-11 00:55:01 +00:00
beat4ocean ad106a27f3 Merge branch 'Soulter:master' into master 2025-03-11 08:54:55 +08:00
beat4ocean 9d6f61b49e fix: 修复非第一次登陆情况,已记录的gewechat的appid失效设备不存在导致无法重新登陆的bug 2025-03-11 08:48:37 +08:00
pre-commit-ci[bot] 02368954a0 🎈 auto fixes by pre-commit hooks 2025-03-10 17:09:25 +00:00
pre-commit-ci[bot] b477a35a01 🎈 pre-commit autoupdate
updates:
- [github.com/astral-sh/ruff-pre-commit: v0.9.9 → v0.9.10](https://github.com/astral-sh/ruff-pre-commit/compare/v0.9.9...v0.9.10)
2025-03-10 17:09:18 +00:00
Soulter 16622887de perf: 在调用插件异常时更完整的报错信息 2025-03-11 00:47:37 +08:00
Soulter 9059d1fb17 feat: 支持在对话隔离情况下可以将群聊加入白名单 #746 2025-03-11 00:34:29 +08:00
Soulter df2b008d82 Merge pull request #744 from roeseth/fix-local-timezone
Use system local time zone instead of hardcoded UTC+8
2025-03-11 00:21:43 +08:00
Soulter 0da871efd0 chore: 日志完善 2025-03-10 23:58:42 +08:00
Soulter 1c55349f81 fix: 钉钉 webui 文档 2025-03-10 23:58:42 +08:00
Soulter 9309fa1e81 修复fishaudio默认baseurl不可用的问题 2025-03-10 01:32:26 +08:00
Soulter 5996189f91 Update README.md 2025-03-09 22:25:45 +08:00
Soulter bd2b984bfb v3.4.37 2025-03-09 22:14:23 +08:00
pre-commit-ci[bot] 194409a117 🎈 auto fixes by pre-commit hooks 2025-03-09 13:23:52 +00:00
roeseth 27978b216d use system local timezone instead of hardcoded UTC+8 2025-03-09 06:18:53 -07:00
Soulter c38fa77ce6 🐛 fix: 修复 gewechat 部分场景下下载图片报错 #700 2025-03-09 18:10:38 +08:00
Soulter 3eb49f7422 feat: 支持设置私聊是否需要唤醒前缀唤醒 #735 2025-03-09 18:03:23 +08:00
Soulter 1989d615d2 🌈 style: format codes 2025-03-09 17:48:59 +08:00
Soulter 239412d265 feat: 支持接入钉钉 #643 2025-03-09 17:47:51 +08:00
Soulter 375a419a9e Merge pull request #732 from xiewoc/master
Update aiocqhttp_platform_adapter.py
2025-03-09 12:36:48 +08:00
Soulter 875c8ab424 ci: upate astrbot webui build cis 2025-03-09 11:31:10 +08:00
Soulter c9bfc810ce ci: upload astrbot webui build ci 2025-03-09 11:26:10 +08:00
Soulter 46ecb16949 🐛 fix: 无法正常保存插件的 list 类型配置 #737 2025-03-09 11:12:24 +08:00
Soulter f6dc16f17b style: format codes 2025-03-08 20:55:25 +08:00
Soulter 4eef42f730 refactor: 移除未使用的 defineEmits 导入 2025-03-08 20:53:43 +08:00
Soulter 8612d9a771 docs: update changelogs 2025-03-08 20:37:46 +08:00
Soulter 0caff054f5 feat: 会话控制器支持自定义会话ID算子 2025-03-08 20:29:42 +08:00
Soulter 4aa91ad599 feat: 支持当消息只有@bot时,下一条发送人的消息直接唤醒机器人 2025-03-08 19:55:24 +08:00
Soulter 7a0864f5c2 feat: 推荐插件页面 2025-03-08 18:58:50 +08:00
Soulter 73dc0dfcf6 perf: 插件市场支持显示插件 logo 2025-03-08 17:31:08 +08:00
Soulter 1ff9a69339 chore: plugin logo 2025-03-08 17:23:25 +08:00
Soulter 179eb5d847 feat: 优化了插件卡片的 UI,插件卡片支持显示 logo 2025-03-08 17:13:36 +08:00
Soulter 52c868828c perf: 插件更新、保存配置均支持热重载 2025-03-08 15:22:56 +08:00
Soulter 7eea4615b6 perf: 优化了日志显示 2025-03-08 15:22:22 +08:00
Soulter d9b351df1a fix: 修复主动人格情况下人格失效的问题 #719 #712 2025-03-08 14:14:14 +08:00
pre-commit-ci[bot] d6a785b645 🎈 auto fixes by pre-commit hooks 2025-03-08 04:33:19 +00:00
xiewoc 79db828a01 Update aiocqhttp_platform_adapter.py 2025-03-08 12:30:49 +08:00
Soulter a5ffb0f8dc perf: 安装/更新插件后直接热重载而不重启;更新 plugin 指令 2025-03-08 00:20:48 +08:00
Soulter 9492fcde74 perf: 完善了插件的启用和禁用的生命周期管理 2025-03-07 23:44:07 +08:00
Soulter d2456ce4cd Update README.md 2025-03-07 10:52:09 +08:00
Soulter 7de27abc8d 🐛 fix: Telegram适配器使用代理地址无法获取图片 #723 2025-03-07 09:05:00 +08:00
Soulter d8155bc8eb 🐛 fix: Telegram适配器使用代理地址无法获取图片 #723 2025-03-07 00:42:15 +08:00
Soulter cf08e52a92 style: cleanup 2025-03-06 23:52:15 +08:00
Soulter 768398b991 feat: 支持 gewechat 图片等更多类型的主动消息 #710 2025-03-06 22:26:58 +08:00
Soulter 24c20a19f1 feat: 支持插件会话控制 API 2025-03-06 22:13:14 +08:00
Soulter 8fbcbcd4c0 🐛 fix: webchat cannot send active image message #710 2025-03-05 22:34:37 +08:00
Soulter e0da5bb943 chore: delete some files for project safety 2025-03-05 19:05:50 +08:00
Soulter 36fbc4fb82 Update README.md 2025-03-05 18:55:40 +08:00
Soulter cb11051f42 Update README.md 2025-03-05 17:56:23 +08:00
Soulter a824781d14 Update README.md 2025-03-05 17:55:06 +08:00
Soulter 600a2c6748 🐛 fix: context.get_platform() error 2025-03-05 13:28:55 +08:00
Soulter 77df64bfb5 🐛 fix: 修复插件在带了 __del__ 之后无法被禁用和重载的问题 2025-03-05 11:33:01 +08:00
Soulter 2d6e54903c Update README.md 2025-03-05 00:58:44 +08:00
Soulter baa2b83df9 🐛 fix: telegram cannot handle /start #620 2025-03-05 00:40:38 +08:00
Soulter 1ff02446af 🐛 fix: 404 error after installing plugins 2025-03-04 23:39:01 +08:00
Soulter b58c6ba762 feat: add template of lmstudio #691 2025-03-04 23:38:33 +08:00
155 changed files with 14897 additions and 3022 deletions
+5 -4
View File
@@ -6,7 +6,7 @@ body:
- type: markdown
attributes:
value: |
欢迎发布插件到插件市场!
欢迎发布插件到插件市场!请确保您的插件经过**完整的**测试。
- type: textarea
attributes:
@@ -22,9 +22,10 @@ body:
插件名:
插件作者:
插件简介:
标签: (可选)
社交链接: (可选, 将会在插件市场作者名称上作为可点击的链接)
description: 必填。请以列表的字段按顺序将插件名、插件作者、插件简介放在这里。
支持的消息平台:(必填,如 QQ、微信、飞书)
标签:(可选)
社交链接:(可选, 将会在插件市场作者名称上作为可点击的链接)
description: 必填。请以列表的字段按顺序将插件名、插件作者、插件简介放在这里。如果您不知道支持哪些消息平台,请填写测试过的消息平台。
- type: checkboxes
attributes:
+31
View File
@@ -0,0 +1,31 @@
name: AstrBot Dashboard CI
on: [push]
jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: npm install, build
run: |
cd dashboard
npm install
npm run build
- name: Inject Commit SHA
id: get_sha
run: |
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
mkdir -p dashboard/dist/assets
echo $COMMIT_SHA > dashboard/dist/assets/version
- name: Archive production artifacts
uses: actions/upload-artifact@v4
with:
name: dist-without-markdown
path: |
dashboard/dist
!dist/**/*.md
+4
View File
@@ -1,6 +1,8 @@
__pycache__
botpy.log
.vscode
.venv*
.idea
data_v2.db
data_v3.db
configs/session
@@ -26,3 +28,5 @@ venv/*
packages/python_interpreter/workplace
.venv/*
.conda/
.idea
pytest.ini
+1 -1
View File
@@ -7,7 +7,7 @@ ci:
autoupdate_commit_msg: ":balloon: pre-commit autoupdate"
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.9
rev: v0.11.2
hooks:
- id: ruff
- id: ruff-format
+10 -2
View File
@@ -9,12 +9,20 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
python3-dev \
libffi-dev \
libssl-dev \
ca-certificates \
bash \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
RUN python -m pip install -r requirements.txt --no-cache-dir
RUN python -m pip install uv
RUN uv pip install -r requirements.txt --no-cache-dir --system
RUN uv pip install socksio uv pyffmpeg pilk --no-cache-dir --system
RUN python -m pip install socksio wechatpy cryptography --no-cache-dir
# 释出 ffmpeg
RUN python -c "from pyffmpeg import FFmpeg; ff = FFmpeg();"
# add /root/.pyffmpeg/bin/ffmpeg to PATH, inorder to use ffmpeg
RUN echo 'export PATH=$PATH:/root/.pyffmpeg/bin' >> ~/.bashrc
EXPOSE 6185
EXPOSE 6186
+35
View File
@@ -0,0 +1,35 @@
FROM python:3.10-slim
WORKDIR /AstrBot
COPY . /AstrBot/
RUN apt-get update && apt-get install -y --no-install-recommends \
gcc \
build-essential \
python3-dev \
libffi-dev \
libssl-dev \
curl \
unzip \
ca-certificates \
bash \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
# Installation of Node.js
ENV NVM_DIR="/root/.nvm"
RUN curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.2/install.sh | bash && \
. "$NVM_DIR/nvm.sh" && \
nvm install 22 && \
nvm use 22
RUN /bin/bash -c ". \"$NVM_DIR/nvm.sh\" && node -v && npm -v"
RUN python -m pip install uv
RUN uv pip install -r requirements.txt --no-cache-dir --system
RUN uv pip install socksio uv pyffmpeg --no-cache-dir --system
EXPOSE 6185
EXPOSE 6186
CMD ["python", "main.py"]
+55 -44
View File
@@ -1,6 +1,6 @@
<p align="center">
![6e1279651f16d7fdf4727558b72bbaf1](https://github.com/user-attachments/assets/ead4c551-fc3c-48f7-a6f7-afbfdb820512)
![yjtp](https://github.com/user-attachments/assets/dcc74009-c57e-4b66-9ae3-0a81fc001255)
</p>
@@ -10,13 +10,12 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot)](https://github.com/Soulter/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-630166526-purple"></a>
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B6%88%E6%81%AF%E4%B8%8A%E8%A1%8C%E9%87%8F&cacheSeconds=3600)
[![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot)
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot?style=for-the-badge&color=76bad9)](https://github.com/Soulter/AstrBot/releases/latest)
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg?style=for-the-badge&color=76bad9)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
![Dynamic JSON Badge](https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fstats&query=v&label=7%E6%97%A5%E6%B4%BB%E8%B7%83%E9%87%8F&cacheSeconds=10800&style=for-the-badge&color=3b618e)
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a>
@@ -26,19 +25,31 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型(LLM)接入功能的聊天机器人及开发框架。
[![star](https://gitcode.com/Soulter/AstrBot/star/badge.svg?style=for-the-badge)](https://gitcode.com/Soulter/AstrBot)
<!-- [![codecov](https://img.shields.io/codecov/c/github/soulter/astrbot?style=for-the-badge)](https://codecov.io/gh/Soulter/AstrBot)
-->
## ✨ 近期更新
1. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器!
## ✨ 主要功能
> [!NOTE]
> 🪧 我们正基于前沿科研成果,设计并实现适用于角色扮演和情感陪伴的长短期记忆模型及情绪控制模型,旨在提升对话的真实性与情感表达能力。敬请期待 `v3.6.0` 版本!
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力,支持图片理解、语音转文字(Whisper)。
2. **多消息平台接入**。支持接入 QQ(OneBot)、QQ 频道、微信(Gewechat)、飞书、Telegram。后续将支持钉钉、Discord、WhatsApp、小爱音响。支持速率限制、白名单、关键词过滤、百度内容审核。
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://astrbot.app/others/dify.html),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://dify.ai/),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat,可在面板上与大模型对话。
6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
> [!TIP]
> 管理面板在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
> WebUI 在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
>
> 用户名: `astrbot`, 密码: `astrbot`。未配置 LLM,无法在聊天页使用大模型。(不要再修改 demo 的登录密码了 😭)
> 用户名: `astrbot`, 密码: `astrbot`。
## ✨ 使用方式
@@ -48,22 +59,33 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
#### Windows 一键安装器部署
需要电脑上安装有 Python>3.10)。请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
#### Replit 部署
#### 宝塔面板部署
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
请参阅官方文档 [宝塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html) 。
#### CasaOS 部署
社区贡献的部署方式。
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/casaos.html) 。
请参阅官方文档 [CasaOS 部署](https://astrbot.app/deploy/astrbot/casaos.html) 。
#### 手动部署
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html)
推荐使用 `uv`
```bash
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
pip install uv
uv run main.py
```
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
#### Replit 部署
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
## ⚡ 消息平台支持情况
@@ -74,7 +96,8 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
| 微信(个人号) | ✔ | 微信个人号私聊、群聊 | 文字、图片、语音 |
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 |
| [微信(企业微信)](https://github.com/Soulter/astrbot_plugin_wecom) | ✔ | 私聊 | 文字、图片、语音 |
| 飞书 | ✔ | 群聊 | 文字、图片 |
| 飞书 | ✔ | 私聊、群聊 | 文字、图片 |
| 钉钉 | ✔ | 私聊、群聊 | 文字、图片 |
| 微信对话开放平台 | 🚧 | 计划内 | - |
| Discord | 🚧 | 计划内 | - |
| WhatsApp | 🚧 | 计划内 | - |
@@ -84,7 +107,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
| 名称 | 支持性 | 类型 | 备注 |
| -------- | ------- | ------- | ------- |
| OpenAI API | ✔ | 文本生成 | 同时也支持 DeepSeek、Google Gemini、GLM(智谱)、Moonshot(月之暗面)、阿里云百炼、硅基流动、xAI 等所有兼容 OpenAI API 的服务 |
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Google Gemini、GLM、Kimi、硅基流动、xAI 等兼容 OpenAI API 的服务 |
| Claude API | ✔ | 文本生成 | |
| Google Gemini API | ✔ | 文本生成 | |
| Dify | ✔ | LLMOps | |
@@ -96,6 +119,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
| Whisper | ✔ | 语音转文本 | 支持 API、本地部署 |
| SenseVoice | ✔ | 语音转文本 | 本地部署 |
| OpenAI TTS API | ✔ | 文本转语音 | |
| GSVI | ✔ | 文本转语音 | GPT-Sovits-Inference |
| Fishaudio | ✔ | 文本转语音 | GPT-Sovits 作者参与的项目 |
| Edge-TTS | ✔ | 文本转语音 | Edge 浏览器的免费 TTS |
@@ -125,38 +149,36 @@ pre-commit install
## ✨ Demo
> [!NOTE]
> 代码执行器的文件输入/输出目前仅测试了 Napcat(QQ), Lagrange(QQ)
<div align='center'>
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
_✨基于 Docker 的沙箱化代码执行器(Beta 测试)✨_
_✨基于 Docker 的沙箱化代码执行器(Beta 测试)✨_
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
_✨ 自然语言待办事项 ✨_
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
_✨ 插件系统——部分插件展示 ✨_
<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width=600>
<img src="https://github.com/user-attachments/assets/0cdbf564-2f59-4da5-b524-ce0e7ef3d978" width=600>
_管理面板_
![webchat](https://drive.soulter.top/f/vlsA/ezgif-5-fb044b2542.gif)
_✨ 内置 Web Chat,在线与机器人交互 ✨_
_WebUI_
</div>
## ❤️ Special Thanks
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
</a>
## ⭐ Star History
> [!TIP]
@@ -174,16 +196,5 @@ _✨ 内置 Web Chat,在线与机器人交互 ✨_
2. The deployment of WeChat (personal account) utilizes [Gewechat](https://github.com/Devo919/Gewechat) service. AstrBot only guarantees connectivity with Gewechat and recommends using a WeChat account that is not frequently used. In the event of account risk control, the author of this project shall not bear any responsibility.
3. Please ensure compliance with local laws and regulations when using this project.
<!-- ## ✨ ATRI [Beta 测试]
该功能作为插件载入。插件仓库地址:[astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
1. 基于《ATRI ~ My Dear Moments》主角 ATRI 角色台词作为微调数据集的 `Qwen1.5-7B-Chat Lora` 微调模型。
2. 长期记忆
3. 表情包理解与回复
4. TTS
-->
_私は、高性能ですから!_
+1 -1
View File
@@ -28,7 +28,7 @@ AstrBot is a loosely coupled, asynchronous chatbot and development framework tha
1. **LLM Conversations** - Supports various LLMs including OpenAI API, Google Gemini, Llama, Deepseek, ChatGLM, etc. Enables local model deployment via Ollama/LLMTuner. Features multi-turn dialogues, personality contexts, multimodal capabilities (image understanding), and speech-to-text (Whisper).
2. **Multi-platform Integration** - Supports QQ (OneBot), QQ Channels, WeChat (Gewechat), Feishu, and Telegram. Planned support for DingTalk, Discord, WhatsApp, and Xiaomi Smart Speakers. Includes rate limiting, whitelisting, keyword filtering, and Baidu content moderation.
3. **Agent Capabilities** - Native support for code execution, natural language TODO lists, web search. Integrates with [Dify Platform](https://astrbot.app/others/dify.html) for easy access to Dify assistants/knowledge bases/workflows.
3. **Agent Capabilities** - Native support for code execution, natural language TODO lists, web search. Integrates with [Dify Platform](https://dify.ai/) for easy access to Dify assistants/knowledge bases/workflows.
4. **Plugin System** - Optimized plugin mechanism with minimal development effort. Supports multiple installed plugins.
5. **Web Dashboard** - Visual configuration management, plugin controls, logging, and WebChat interface for direct LLM interaction.
6. **High Stability & Modularity** - Event bus and pipeline architecture ensures high modularization and loose coupling.
+1 -1
View File
@@ -28,7 +28,7 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ
1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換(Whisper)をサポートします。
2. **複数のメッセージプラットフォームの接続**。QQOneBot)、QQ チャンネル、WeChatGewechat)、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://astrbot.app/others/dify.html)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://dify.ai/)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
6. **高い安定性と高いモジュール性**。イベントバスとパイプラインに基づくアーキテクチャ設計により、高度にモジュール化され、低結合です。
+2
View File
@@ -5,6 +5,7 @@ from astrbot.core.platform import (
MessageMember,
MessageType,
PlatformMetadata,
Group,
)
from astrbot.core.platform.register import register_platform_adapter
@@ -18,4 +19,5 @@ __all__ = [
"MessageType",
"PlatformMetadata",
"register_platform_adapter",
"Group",
]
+2 -6
View File
@@ -2,11 +2,7 @@ from astrbot.core.star.register import (
register_star as register, # 注册插件(Star
)
from astrbot.core.star import Context, Star
from astrbot.core.star import Context, Star, StarTools
from astrbot.core.star.config import *
__all__ = [
"register",
"Context",
"Star",
]
__all__ = ["register", "Context", "Star", "StarTools"]
+7
View File
@@ -0,0 +1,7 @@
from astrbot.core.utils.session_waiter import (
SessionWaiter,
SessionController,
session_waiter,
)
__all__ = ["SessionWaiter", "SessionController", "session_waiter"]
+5 -1
View File
@@ -8,6 +8,7 @@ from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.config.default import DB_PATH
from astrbot.core.config import AstrBotConfig
# 初始化数据存储文件夹
os.makedirs("data", exist_ok=True)
astrbot_config = AstrBotConfig()
@@ -19,8 +20,11 @@ if os.environ.get("TESTING", ""):
logger.setLevel("DEBUG")
db_helper = SQLiteDatabase(DB_PATH)
sp = SharedPreferences() # 简单的偏好设置存储
sp = (
SharedPreferences()
) # 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
pip_installer = PipInstaller(astrbot_config.get("pip_install_arg", ""))
web_chat_queue = asyncio.Queue(maxsize=32)
web_chat_back_queue = asyncio.Queue(maxsize=32)
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
DEMO_MODE = os.getenv("DEMO_MODE", False)
+179 -18
View File
@@ -2,7 +2,7 @@
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
"""
VERSION = "3.4.35"
VERSION = "3.5.2"
DB_PATH = "data/data_v3.db"
# 默认配置
@@ -36,6 +36,8 @@ DEFAULT_CONFIG = {
"content_cleanup_rule": "",
},
"no_permission_reply": True,
"empty_mention_waiting": True,
"friend_message_needs_wake_prefix": False,
},
"provider": [],
"provider_settings": {
@@ -47,6 +49,7 @@ DEFAULT_CONFIG = {
"datetime_system_prompt": True,
"default_personality": "default",
"prompt_prefix": "",
"max_context_length": -1,
},
"provider_stt_settings": {
"enable": False,
@@ -78,21 +81,24 @@ DEFAULT_CONFIG = {
"admins_id": ["astrbot"],
"t2i": False,
"t2i_word_threshold": 150,
"t2i_strategy": "remote",
"t2i_endpoint": "",
"http_proxy": "",
"dashboard": {
"enable": True,
"username": "astrbot",
"password": "77b90590a8945a7d36c963981a307dc9",
"host": "0.0.0.0",
"port": 6185,
},
"platform": [],
"wake_prefix": ["/"],
"log_level": "INFO",
"t2i_endpoint": "",
"pip_install_arg": "",
"plugin_repo_mirror": "",
"knowledge_db": {},
"persona": [],
"timezone": "",
}
@@ -120,6 +126,7 @@ CONFIG_METADATA_2 = {
"enable": False,
"appid": "",
"secret": "",
"callback_server_host": "0.0.0.0",
"port": 6196,
},
"aiocqhttp(OneBotv11)": {
@@ -144,10 +151,11 @@ CONFIG_METADATA_2 = {
"enable": False,
"corpid": "",
"secret": "",
"port": 6195,
"token": "",
"encoding_aes_key": "",
"api_base_url": "https://qyapi.weixin.qq.com/cgi-bin/",
"callback_server_host": "0.0.0.0",
"port": 6195,
},
"lark(飞书)": {
"id": "lark",
@@ -158,6 +166,13 @@ CONFIG_METADATA_2 = {
"app_secret": "",
"domain": "https://open.feishu.cn",
},
"dingtalk(钉钉)": {
"id": "dingtalk",
"type": "dingtalk",
"enable": False,
"client_id": "",
"client_secret": "",
},
"telegram": {
"id": "telegram",
"type": "telegram",
@@ -165,6 +180,7 @@ CONFIG_METADATA_2 = {
"telegram_token": "your_bot_token",
"start_message": "Hello, I'm AstrBot!",
"telegram_api_base_url": "https://api.telegram.org/bot",
"telegram_file_base_url": "https://api.telegram.org/file/bot",
},
},
"items": {
@@ -210,7 +226,7 @@ CONFIG_METADATA_2 = {
"hint": "启用后,机器人可以接收到频道的私聊消息。",
},
"ws_reverse_host": {
"description": "反向 Websocket 主机地址",
"description": "反向 Websocket 主机地址(AstrBot 为服务器端)",
"type": "string",
"hint": "aiocqhttp 适配器的反向 Websocket 服务器 IP 地址,不包含端口号。",
},
@@ -256,6 +272,16 @@ CONFIG_METADATA_2 = {
"type": "bool",
"hint": "启用后,当用户没有权限执行某个操作时,机器人会回复一条消息。",
},
"empty_mention_waiting": {
"description": "只 @ 机器人是否触发等待回复",
"type": "bool",
"hint": "启用后,当消息内容只有 @ 机器人时,会触发等待回复,在 60 秒内的该用户的任意一条消息均会唤醒机器人。这在某些平台不支持 @ 和语音/图片等消息同时发送时特别有用。",
},
"friend_message_needs_wake_prefix": {
"description": "私聊消息是否需要唤醒前缀",
"type": "bool",
"hint": "启用后,私聊消息需要唤醒前缀才会被处理,同群聊一样。",
},
"segmented_reply": {
"description": "分段回复",
"type": "object",
@@ -322,7 +348,7 @@ CONFIG_METADATA_2 = {
"type": "list",
"items": {"type": "string"},
"obvious_hint": True,
"hint": "AstrBot 只处理填写的 ID 发来的消息事件为空时不启用白名单过滤。可使用 /sid 指令获取在某个平台上的会话 ID。也可在 AstrBot 日志内获取会话 ID,当一条消息没通过白名单时,会输出 INFO 级别的日志。会话 ID 类似 aiocqhttp:GroupMessage:547540978。管理员可使用 /wl 添加白名单",
"hint": "只处理填写的 ID 发来的消息事件为空时不启用。可使用 /sid 指令获取在平台上的会话 ID(类似 abc:GroupMessage:123)。管理员可使用 /wl 添加白名单",
},
"id_whitelist_log": {
"description": "打印白名单日志",
@@ -465,6 +491,16 @@ CONFIG_METADATA_2 = {
"model": "llama3.1-8b",
},
},
"LM_Studio": {
"id": "lm_studio",
"type": "openai_chat_completion",
"enable": True,
"key": ["lmstudio"],
"api_base": "http://localhost:1234/v1",
"model_config": {
"model": "llama-3.1-8b",
},
},
"Gemini(OpenAI兼容)": {
"id": "gemini_default",
"type": "openai_chat_completion",
@@ -484,7 +520,14 @@ CONFIG_METADATA_2 = {
"api_base": "https://generativelanguage.googleapis.com/",
"timeout": 120,
"model_config": {
"model": "gemini-1.5-flash",
"model": "gemini-2.0-flash-exp",
},
"gm_resp_image_modal": False,
"gm_safety_settings": {
"harassment": "BLOCK_MEDIUM_AND_ABOVE",
"hate_speech": "BLOCK_MEDIUM_AND_ABOVE",
"sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE",
"dangerous_content": "BLOCK_MEDIUM_AND_ABOVE",
},
},
"DeepSeek": {
@@ -548,7 +591,7 @@ CONFIG_METADATA_2 = {
"dify_api_type": "chat",
"dify_api_key": "",
"dify_api_base": "https://api.dify.ai/v1",
"dify_workflow_output_key": "",
"dify_workflow_output_key": "astrbot_wf_output",
"dify_query_input_key": "astrbot_text_query",
"variables": {},
"timeout": 60,
@@ -560,6 +603,11 @@ CONFIG_METADATA_2 = {
"dashscope_app_type": "agent",
"dashscope_api_key": "",
"dashscope_app_id": "",
"rag_options": {
"pipeline_ids": [],
"file_ids": [],
"output_reference": False,
},
"variables": {},
"timeout": 60,
},
@@ -626,12 +674,106 @@ CONFIG_METADATA_2 = {
"type": "fishaudio_tts_api",
"enable": False,
"api_key": "",
"api_base": "https://api.fish-audio.cn/v1",
"api_base": "https://api.fish.audio/v1",
"fishaudio-tts-character": "可莉",
"timeout": "20",
},
"阿里云百炼_TTS(API)": {
"id": "dashscope_tts",
"type": "dashscope_tts",
"enable": False,
"api_key": "",
"model": "cosyvoice-v1",
"dashscope_tts_voice": "loongstella",
"timeout": "20",
},
},
"items": {
"dashscope_tts_voice": {
"description": "语音合成模型",
"type": "string",
"hint": "阿里云百炼语音合成模型名称。具体可参考 https://help.aliyun.com/zh/model-studio/developer-reference/cosyvoice-python-api 等内容",
},
"gm_resp_image_modal": {
"description": "启用图片模态",
"type": "bool",
"hint": "启用后,将支持返回图片内容。需要模型支持,否则会报错。具体支持模型请查看 Google Gemini 官方网站。温馨提示,如果您需要生成图片,请关闭 `启用群员识别` 配置获得更好的效果。",
},
"gm_safety_settings": {
"description": "安全过滤器",
"type": "object",
"hint": "设置模型输入的内容安全过滤级别。过滤级别分类为NONE(不屏蔽)、HIGH(高风险时屏蔽)、MEDIUM_AND_ABOVE(中等风险及以上屏蔽)、LOW_AND_ABOVE(低风险及以上时屏蔽),具体参见Gemini API文档。",
"items": {
"harassment": {
"description": "骚扰内容",
"type": "string",
"hint": "负面或有害评论",
"options": [
"BLOCK_NONE",
"BLOCK_ONLY_HIGH",
"BLOCK_MEDIUM_AND_ABOVE",
"BLOCK_LOW_AND_ABOVE",
],
},
"hate_speech": {
"description": "仇恨言论",
"type": "string",
"hint": "粗鲁、无礼或亵渎性质内容",
"options": [
"BLOCK_NONE",
"BLOCK_ONLY_HIGH",
"BLOCK_MEDIUM_AND_ABOVE",
"BLOCK_LOW_AND_ABOVE",
],
},
"sexually_explicit": {
"description": "露骨色情内容",
"type": "string",
"hint": "包含性行为或其他淫秽内容的引用",
"options": [
"BLOCK_NONE",
"BLOCK_ONLY_HIGH",
"BLOCK_MEDIUM_AND_ABOVE",
"BLOCK_LOW_AND_ABOVE",
],
},
"dangerous_content": {
"description": "危险内容",
"type": "string",
"hint": "宣扬、助长或鼓励有害行为的信息",
"options": [
"BLOCK_NONE",
"BLOCK_ONLY_HIGH",
"BLOCK_MEDIUM_AND_ABOVE",
"BLOCK_LOW_AND_ABOVE",
],
},
},
},
"rag_options": {
"description": "RAG 选项",
"type": "object",
"hint": "检索知识库设置, 非必填。仅 Agent 应用类型支持(智能体应用, 包括 RAG 应用)。阿里云百炼应用开启此功能后将无法多轮对话。",
"items": {
"pipeline_ids": {
"description": "知识库 ID 列表",
"type": "list",
"items": {"type": "string"},
"hint": "对指定知识库内所有文档进行检索, 前往 https://bailian.console.aliyun.com/ 数据应用->知识索引创建和获取 ID。",
},
"file_ids": {
"description": "非结构化文档 ID, 传入该参数将对指定非结构化文档进行检索。",
"type": "list",
"items": {"type": "string"},
"hint": "对指定非结构化文档进行检索。前往 https://bailian.console.aliyun.com/ 数据管理创建和获取 ID。",
},
"output_reference": {
"description": "是否输出知识库/文档的引用",
"type": "bool",
"hint": "在每次回答尾部加上引用源。默认为 False。",
},
},
},
"sensevoice_hint": {
"description": "部署SenseVoice",
"type": "string",
@@ -648,12 +790,14 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "modelscope 上的模型名称。默认:iic/SenseVoiceSmall。",
},
# "variables": {
# "description": "工作流固定输入变量",
# "type": "object",
# "obvious_hint": True,
# "hint": "可选。工作流固定输入变量,将会作为工作流的输入。也可以在对话时使用 /set 指令动态设置变量。如果变量名冲突,优先使用动态设置的变量。",
# },
"variables": {
"description": "工作流固定输入变量",
"type": "object",
"obvious_hint": True,
"items": {},
"hint": "可选。工作流固定输入变量,将会作为工作流的输入。也可以在对话时使用 /set 指令动态设置变量。如果变量名冲突,优先使用动态设置的变量。",
"invisible": True,
},
# "fastgpt_app_type": {
# "description": "应用类型",
# "type": "string",
@@ -664,7 +808,7 @@ CONFIG_METADATA_2 = {
"dashscope_app_type": {
"description": "应用类型",
"type": "string",
"hint": "阿里云百炼应用的应用类型。",
"hint": "百炼应用的应用类型。",
"options": [
"agent",
"agent-arrange",
@@ -844,6 +988,11 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "添加之后,会在每次对话的 Prompt 前加上此文本。",
},
"max_context_length": {
"description": "最多携带对话数量(条)",
"type": "int",
"hint": "超出这个数量时将丢弃最旧的部分,用户和AI的一轮聊天记为 1 条。-1 表示不限制,默认为不限制。",
},
},
},
"persona": {
@@ -937,10 +1086,10 @@ CONFIG_METADATA_2 = {
"hint": "群聊消息最大数量。超过此数量后,会自动清除旧消息。",
},
"image_caption": {
"description": "启用图像转述(需模型支持)",
"description": "群聊图像转述(需模型支持)",
"type": "bool",
"obvious_hint": True,
"hint": "启用后,当接收到图片消息时,会使用模型先将图片转述为文字再进行后续处理。推荐使用 gpt-4o-mini 模型",
"hint": "用模型将群聊中的图片消息转述为文字,推荐 gpt-4o-mini 模型。和机器人的唤醒聊天中的图片消息仍然会直接作为上下文输入",
},
"image_caption_provider_id": {
"description": "图像转述提供商 ID",
@@ -1024,16 +1173,28 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "启用后,会以添加环境变量的方式设置代理。格式为 `http://ip:port`",
},
"timezone": {
"description": "时区",
"type": "string",
"obvious_hint": True,
"hint": "时区设置。请填写 IANA 时区名称, 如 Asia/Shanghai, 为空时使用系统默认时区。所有时区请查看: https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab",
},
"log_level": {
"description": "控制台日志级别",
"type": "string",
"hint": "控制台输出日志的级别。",
"options": ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
},
"t2i_strategy": {
"description": "文本转图像渲染源",
"type": "string",
"hint": "文本转图像策略。`remote` 为使用远程基于 HTML 的渲染服务,`local` 为使用 PIL 本地渲染。当使用 local 时,将 ttf 字体命名为 'font.ttf' 放在 data/ 目录下可自定义字体。",
"options": ["remote", "local"],
},
"t2i_endpoint": {
"description": "文本转图像服务接口",
"type": "string",
"hint": "为空时使用 AstrBot API 服务",
"hint": "当 t2i_strategy 为 remote 时生效。为空时使用 AstrBot API 服务",
},
"pip_install_arg": {
"description": "pip 安装参数",
+79 -9
View File
@@ -1,3 +1,10 @@
"""
AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 格式的shared_preferences, 另外一个是数据库
在 AstrBot 中, 会话和对话是独立的, 会话用于标记对话窗口, 例如群聊"123456789"可以建立一个会话,
在一个会话中可以建立多个对话, 并且支持对话的切换和删除
"""
import uuid
import json
import asyncio
@@ -11,24 +18,34 @@ class ConversationManager:
"""负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""
def __init__(self, db_helper: BaseDatabase):
# session_conversations 字典记录会话ID-对话ID 映射关系
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
self.db = db_helper
self.save_interval = 60 # 每 60 秒保存一次
self._start_periodic_save()
def _start_periodic_save(self):
"""启动定时保存任务"""
asyncio.create_task(self._periodic_save())
async def _periodic_save(self):
"""定时保存会话对话映射关系到存储中"""
while True:
await asyncio.sleep(self.save_interval)
self._save_to_storage()
def _save_to_storage(self):
"""保存会话对话映射关系到存储中"""
sp.put("session_conversation", self.session_conversations)
async def new_conversation(self, unified_msg_origin: str) -> str:
"""新建对话,并将当前会话的对话转移到新对话"""
"""新建对话,并将当前会话的对话转移到新对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
Returns:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
conversation_id = str(uuid.uuid4())
self.db.new_conversation(user_id=unified_msg_origin, cid=conversation_id)
self.session_conversations[unified_msg_origin] = conversation_id
@@ -36,14 +53,24 @@ class ConversationManager:
return conversation_id
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
"""切换会话的对话"""
"""切换会话的对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
self.session_conversations[unified_msg_origin] = conversation_id
sp.put("session_conversation", self.session_conversations)
async def delete_conversation(
self, unified_msg_origin: str, conversation_id: str = None
):
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话"""
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.delete_conversation(user_id=unified_msg_origin, cid=conversation_id)
@@ -51,23 +78,48 @@ class ConversationManager:
sp.put("session_conversation", self.session_conversations)
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
"""获取会话当前的对话 ID"""
"""获取会话当前的对话 ID
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
Returns:
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
"""
return self.session_conversations.get(unified_msg_origin, None)
async def get_conversation(
self, unified_msg_origin: str, conversation_id: str
) -> Conversation:
"""获取会话的对话"""
"""获取会话的对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
Returns:
conversation (Conversation): 对话对象
"""
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
"""获取会话的所有对话"""
"""获取会话的所有对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
Returns:
conversations (List[Conversation]): 对话对象列表
"""
return self.db.get_conversations(unified_msg_origin)
async def update_conversation(
self, unified_msg_origin: str, conversation_id: str, history: List[Dict]
):
"""更新会话的对话"""
"""更新会话的对话
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
"""
if conversation_id:
self.db.update_conversation(
user_id=unified_msg_origin,
@@ -76,7 +128,12 @@ class ConversationManager:
)
async def update_conversation_title(self, unified_msg_origin: str, title: str):
"""更新会话的对话标题"""
"""更新会话的对话标题
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
title (str): 对话标题
"""
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.update_conversation_title(
@@ -86,7 +143,12 @@ class ConversationManager:
async def update_conversation_persona_id(
self, unified_msg_origin: str, persona_id: str
):
"""更新会话的对话 Persona ID"""
"""更新会话的对话 Persona ID
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
persona_id (str): 对话 Persona ID
"""
conversation_id = self.session_conversations.get(unified_msg_origin)
if conversation_id:
self.db.update_conversation_persona_id(
@@ -96,6 +158,14 @@ class ConversationManager:
async def get_human_readable_context(
self, unified_msg_origin, conversation_id, page=1, page_size=10
):
"""获取人类可读的上下文
Args:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
page (int): 页码
page_size (int): 每页大小
"""
conversation = await self.get_conversation(unified_msg_origin, conversation_id)
history = json.loads(conversation.history)
+89 -17
View File
@@ -1,3 +1,14 @@
"""
Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
该类还负责加载和执行插件, 以及处理事件总线的分发。
工作流程:
1. 初始化所有组件
2. 启动事件总线和任务, 所有任务都在这里运行
3. 执行启动完成事件钩子
"""
import traceback
import asyncio
import time
@@ -24,32 +35,51 @@ from astrbot.core.star.star_handler import star_map
class AstrBotCoreLifecycle:
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
self.log_broker = log_broker
self.astrbot_config = astrbot_config
self.db = db
"""
AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、
EventBus 等。
该类还负责加载和执行插件, 以及处理事件总线的分发。
"""
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
self.log_broker = log_broker # 初始化日志代理
self.astrbot_config = astrbot_config # 初始化配置
self.db = db # 初始化数据库
# 根据环境变量设置代理
os.environ["https_proxy"] = self.astrbot_config["http_proxy"]
os.environ["http_proxy"] = self.astrbot_config["http_proxy"]
os.environ["no_proxy"] = "localhost"
async def initialize(self):
"""
初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
"""
# 初始化日志代理
logger.info("AstrBot v" + VERSION)
if os.environ.get("TESTING", ""):
logger.setLevel("DEBUG")
logger.setLevel("DEBUG") # 测试模式下设置日志级别为 DEBUG
else:
logger.setLevel(self.astrbot_config["log_level"])
self.event_queue = Queue()
self.event_queue.closed = False
logger.setLevel(self.astrbot_config["log_level"]) # 设置日志级别
# 初始化事件队列
self.event_queue = Queue()
# 初始化供应商管理器
self.provider_manager = ProviderManager(self.astrbot_config, self.db)
# 初始化平台管理器
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
# 初始化知识库管理器
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
# 初始化对话管理器
self.conversation_manager = ConversationManager(self.db)
# 初始化提供给插件的上下文
self.star_context = Context(
self.event_queue,
self.astrbot_config,
@@ -59,33 +89,50 @@ class AstrBotCoreLifecycle:
self.conversation_manager,
self.knowledge_db_manager,
)
# 初始化插件管理器
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
# 扫描、注册插件、实例化插件类
await self.plugin_manager.reload()
"""扫描、注册插件、实例化插件类"""
# 根据配置实例化各个 Provider
await self.provider_manager.initialize()
"""根据配置实例化各个 Provider"""
# 初始化消息事件流水线调度器
self.pipeline_scheduler = PipelineScheduler(
PipelineContext(self.astrbot_config, self.plugin_manager)
)
await self.pipeline_scheduler.initialize()
"""初始化消息事件流水线调度器"""
# 初始化更新器
self.astrbot_updator = AstrBotUpdator(self.astrbot_config["plugin_repo_mirror"])
# 初始化事件总线
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
# 记录启动时间
self.start_time = int(time.time())
# 初始化当前任务列表
self.curr_tasks: List[asyncio.Task] = []
# 根据配置实例化各个平台适配器
await self.platform_manager.initialize()
"""根据配置实例化各个平台适配器"""
# 初始化关闭控制面板的事件
self.dashboard_shutdown_event = asyncio.Event()
def _load(self):
"""加载事件总线和任务并初始化"""
# 创建一个异步任务来执行事件总线的 dispatch() 方法
# dispatch是一个无限循环的协程, 从事件队列中获取事件并处理
event_bus_task = asyncio.create_task(
self.event_bus.dispatch(), name="event_bus"
)
# 把插件中注册的所有协程函数注册到事件总线中并执行
extra_tasks = []
for task in self.star_context._register_tasks:
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
@@ -99,17 +146,24 @@ class AstrBotCoreLifecycle:
self.start_time = int(time.time())
async def _task_wrapper(self, task: asyncio.Task):
"""异步任务包装器, 用于处理异步任务执行中出现的各种异常
Args:
task (asyncio.Task): 要执行的异步任务
"""
try:
await task
except asyncio.CancelledError:
pass
pass # 任务被取消, 静默处理
except Exception as e:
# 获取完整的异常堆栈信息, 按行分割并记录到日志中
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
for line in traceback.format_exc().split("\n"):
logger.error(f"| {line}")
logger.error("-------")
async def start(self):
"""启动 AstrBot 核心生命周期管理类, 用load加载事件总线和任务并初始化, 执行启动完成事件钩子"""
self._load()
logger.info("AstrBot 启动完成。")
@@ -126,15 +180,29 @@ class AstrBotCoreLifecycle:
except BaseException:
logger.error(traceback.format_exc())
# 同时运行curr_tasks中的所有任务
await asyncio.gather(*self.curr_tasks, return_exceptions=True)
async def stop(self):
self.event_queue.closed = True
"""停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器"""
# 请求停止所有正在运行的异步任务
for task in self.curr_tasks:
task.cancel()
await self.provider_manager.terminate()
for plugin in self.plugin_manager.context.get_all_stars():
try:
await self.plugin_manager._terminate_plugin(plugin)
except Exception as e:
logger.warning(traceback.format_exc())
logger.warning(
f"插件 {plugin.name} 未被正常终止 {e!s}, 可能会导致资源泄露等问题。"
)
await self.provider_manager.terminate()
await self.platform_manager.terminate()
self.dashboard_shutdown_event.set()
# 再次遍历curr_tasks等待每个任务真正结束
for task in self.curr_tasks:
try:
await task
@@ -143,13 +211,17 @@ class AstrBotCoreLifecycle:
except Exception as e:
logger.error(f"任务 {task.get_name()} 发生错误: {e}")
def restart(self):
self.event_queue.closed = True
async def restart(self):
"""重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例"""
await self.provider_manager.terminate()
await self.platform_manager.terminate()
self.dashboard_shutdown_event.set()
threading.Thread(
target=self.astrbot_updator._reboot, name="restart", daemon=True
).start()
def load_platform(self) -> List[asyncio.Task]:
"""加载平台实例并返回所有平台实例的异步任务列表"""
tasks = []
platform_insts = self.platform_manager.get_insts()
for platform_inst in platform_insts:
+43 -1
View File
@@ -1,6 +1,6 @@
import abc
from dataclasses import dataclass
from typing import List
from typing import List, Dict, Any, Tuple
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, Conversation
@@ -117,3 +117,45 @@ class BaseDatabase(abc.ABC):
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
"""更新 Conversation Persona ID"""
raise NotImplementedError
@abc.abstractmethod
def get_all_conversations(
self, page: int = 1, page_size: int = 20
) -> Tuple[List[Dict[str, Any]], int]:
"""获取所有对话,支持分页
Args:
page: 页码,从1开始
page_size: 每页数量
Returns:
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
"""
raise NotImplementedError
@abc.abstractmethod
def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platforms: List[str] = None,
message_types: List[str] = None,
search_query: str = None,
exclude_ids: List[str] = None,
exclude_platforms: List[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
"""获取筛选后的对话列表
Args:
page: 页码
page_size: 每页数量
platforms: 平台筛选列表
message_types: 消息类型筛选列表
search_query: 搜索关键词
exclude_ids: 排除的用户ID列表
exclude_platforms: 排除的平台列表
Returns:
Tuple[List[Dict[str, Any]], int]: 返回一个元组,包含对话列表和总对话数
"""
raise NotImplementedError
+112
View File
@@ -0,0 +1,112 @@
import json
import aiosqlite
import os
from typing import Any
from .plugin_storage import PluginStorage
DBPATH = "data/plugin_data/sqlite/plugin_data.db"
class SQLitePluginStorage(PluginStorage):
"""插件数据的 SQLite 存储实现类。
该类提供异步方式将插件数据存储到 SQLite 数据库中,支持数据的增删改查操作。
所有数据以 (plugin, key) 作为复合主键进行索引。
"""
_instance = None # Standalone instance of the class
_db_conn = None
db_path = None
def __new__(cls):
"""
创建或获取 SQLitePluginStorage 的单例实例。
如果实例已存在,则返回现有实例;否则创建一个新实例。
数据在 `data/plugin_data/sqlite/plugin_data.db` 下。
"""
os.makedirs(os.path.dirname(DBPATH), exist_ok=True)
if cls._instance is None:
cls._instance = super(SQLitePluginStorage, cls).__new__(cls)
cls._instance.db_path = DBPATH
return cls._instance
async def _init_db(self):
"""初始化数据库连接(只执行一次)"""
if SQLitePluginStorage._db_conn is None:
SQLitePluginStorage._db_conn = await aiosqlite.connect(self.db_path)
await self._setup_db()
async def _setup_db(self):
"""
异步初始化数据库。
创建插件数据表,如果表不存在则创建,表结构包含 plugin、key 和 value 字段,
其中 plugin 和 key 组合作为主键。
"""
await self._db_conn.execute("""
CREATE TABLE IF NOT EXISTS plugin_data (
plugin TEXT,
key TEXT,
value TEXT,
PRIMARY KEY (plugin, key)
)
""")
await self._db_conn.commit()
async def set(self, plugin: str, key: str, value: Any):
"""
异步存储数据。
将指定插件的键值对存入数据库,如果键已存在则更新值。
值会被序列化为 JSON 字符串后存储。
Args:
plugin: 插件标识符
key: 数据键名
value: 要存储的数据值(任意类型,将被 JSON 序列化)
"""
await self._init_db()
await self._db_conn.execute(
"INSERT INTO plugin_data (plugin, key, value) VALUES (?, ?, ?) "
"ON CONFLICT(plugin, key) DO UPDATE SET value = excluded.value",
(plugin, key, json.dumps(value)),
)
await self._db_conn.commit()
async def get(self, plugin: str, key: str) -> Any:
"""
异步获取数据。
从数据库中获取指定插件和键名对应的值,
返回的值会从 JSON 字符串反序列化为原始数据类型。
Args:
plugin: 插件标识符
key: 数据键名
Returns:
Any: 存储的数据值,如果未找到则返回 None
"""
await self._init_db()
async with self._db_conn.execute(
"SELECT value FROM plugin_data WHERE plugin = ? AND key = ?",
(plugin, key),
) as cursor:
row = await cursor.fetchone()
return json.loads(row[0]) if row else None
async def delete(self, plugin: str, key: str):
"""
异步删除数据。
从数据库中删除指定插件和键名对应的数据项。
Args:
plugin: 插件标识符
key: 要删除的数据键名
"""
await self._init_db()
await self._db_conn.execute(
"DELETE FROM plugin_data WHERE plugin = ? AND key = ?", (plugin, key)
)
await self._db_conn.commit()
+8
View File
@@ -6,6 +6,8 @@ from typing import List
@dataclass
class Platform:
"""平台使用统计数据"""
name: str
count: int
timestamp: int
@@ -13,6 +15,8 @@ class Platform:
@dataclass
class Provider:
"""供应商使用统计数据"""
name: str
count: int
timestamp: int
@@ -20,6 +24,8 @@ class Provider:
@dataclass
class Plugin:
"""插件使用统计数据"""
name: str
count: int
timestamp: int
@@ -27,6 +33,8 @@ class Plugin:
@dataclass
class Command:
"""命令使用统计数据"""
name: str
count: int
timestamp: int
+192 -18
View File
@@ -3,7 +3,7 @@ import os
import time
from astrbot.core.db.po import Platform, Stats, LLMHistory, ATRIVision, Conversation
from . import BaseDatabase
from typing import Tuple
from typing import Tuple, List, Dict, Any
class SQLiteDatabase(BaseDatabase):
@@ -128,24 +128,23 @@ class SQLiteDatabase(BaseDatabase):
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
where_clause = ""
if session_id or provider_type:
where_clause += " WHERE "
has = False
if session_id:
where_clause += f"session_id = '{session_id}'"
has = True
if provider_type:
if has:
where_clause += " AND "
where_clause += f"provider_type = '{provider_type}'"
conditions = []
params = []
if session_id:
conditions.append("session_id = ?")
params.append(session_id)
if provider_type:
conditions.append("provider_type = ?")
params.append(provider_type)
sql = "SELECT * FROM llm_history"
if conditions:
sql += " WHERE " + " AND ".join(conditions)
c.execute(sql, params)
c.execute(
"""
SELECT * FROM llm_history
"""
+ where_clause
)
res = c.fetchall()
histories = []
for row in res:
@@ -389,3 +388,178 @@ class SQLiteDatabase(BaseDatabase):
if res:
return ATRIVision(*res)
return None
def get_all_conversations(
self, page: int = 1, page_size: int = 20
) -> Tuple[List[Dict[str, Any]], int]:
"""获取所有对话,支持分页,按更新时间降序排序"""
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
try:
# 获取总记录数
c.execute("""
SELECT COUNT(*) FROM webchat_conversation
""")
total_count = c.fetchone()[0]
# 计算偏移量
offset = (page - 1) * page_size
# 获取分页数据,按更新时间降序排序
c.execute(
"""
SELECT user_id, cid, created_at, updated_at, title, persona_id
FROM webchat_conversation
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
""",
(page_size, offset),
)
rows = c.fetchall()
conversations = []
for row in rows:
user_id, cid, created_at, updated_at, title, persona_id = row
# 确保 cid 是字符串类型且至少有8个字符,否则使用一个默认值
safe_cid = str(cid) if cid else "unknown"
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
conversations.append(
{
"user_id": user_id or "",
"cid": safe_cid,
"title": title or f"对话 {display_cid}",
"persona_id": persona_id or "",
"created_at": created_at or 0,
"updated_at": updated_at or 0,
}
)
return conversations, total_count
except Exception as _:
# 返回空列表和0,确保即使出错也有有效的返回值
return [], 0
finally:
c.close()
def get_filtered_conversations(
self,
page: int = 1,
page_size: int = 20,
platforms: List[str] = None,
message_types: List[str] = None,
search_query: str = None,
exclude_ids: List[str] = None,
exclude_platforms: List[str] = None,
) -> Tuple[List[Dict[str, Any]], int]:
"""获取筛选后的对话列表"""
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
try:
# 构建查询条件
where_clauses = []
params = []
# 平台筛选
if platforms and len(platforms) > 0:
platform_conditions = []
for platform in platforms:
platform_conditions.append("user_id LIKE ?")
params.append(f"{platform}:%")
if platform_conditions:
where_clauses.append(f"({' OR '.join(platform_conditions)})")
# 消息类型筛选
if message_types and len(message_types) > 0:
message_type_conditions = []
for msg_type in message_types:
message_type_conditions.append("user_id LIKE ?")
params.append(f"%:{msg_type}:%")
if message_type_conditions:
where_clauses.append(f"({' OR '.join(message_type_conditions)})")
# 搜索关键词
if search_query:
search_query = search_query.encode("unicode_escape").decode("utf-8")
where_clauses.append(
"(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)"
)
search_param = f"%{search_query}%"
params.extend([search_param, search_param, search_param, search_param])
# 排除特定用户ID
if exclude_ids and len(exclude_ids) > 0:
for exclude_id in exclude_ids:
where_clauses.append("user_id NOT LIKE ?")
params.append(f"{exclude_id}%")
# 排除特定平台
if exclude_platforms and len(exclude_platforms) > 0:
for exclude_platform in exclude_platforms:
where_clauses.append("user_id NOT LIKE ?")
params.append(f"{exclude_platform}:%")
# 构建完整的 WHERE 子句
where_sql = " WHERE " + " AND ".join(where_clauses) if where_clauses else ""
# 构建计数查询
count_sql = f"SELECT COUNT(*) FROM webchat_conversation{where_sql}"
# 获取总记录数
c.execute(count_sql, params)
total_count = c.fetchone()[0]
# 计算偏移量
offset = (page - 1) * page_size
# 构建分页数据查询
data_sql = f"""
SELECT user_id, cid, created_at, updated_at, title, persona_id
FROM webchat_conversation
{where_sql}
ORDER BY updated_at DESC
LIMIT ? OFFSET ?
"""
query_params = params + [page_size, offset]
# 获取分页数据
c.execute(data_sql, query_params)
rows = c.fetchall()
conversations = []
for row in rows:
user_id, cid, created_at, updated_at, title, persona_id = row
# 确保 cid 是字符串类型,否则使用一个默认值
safe_cid = str(cid) if cid else "unknown"
display_cid = safe_cid[:8] if len(safe_cid) >= 8 else safe_cid
conversations.append(
{
"user_id": user_id or "",
"cid": safe_cid,
"title": title or f"对话 {display_cid}",
"persona_id": persona_id or "",
"created_at": created_at or 0,
"updated_at": updated_at or 0,
}
)
return conversations, total_count
except Exception as _:
# 返回空列表和0,确保即使出错也有有效的返回值
return [], 0
finally:
c.close()
+5 -3
View File
@@ -38,11 +38,13 @@ CREATE TABLE IF NOT EXISTS atri_vision(
);
CREATE TABLE IF NOT EXISTS webchat_conversation(
user_id TEXT,
cid TEXT,
user_id TEXT, -- 会话 id
cid TEXT, -- 对话 id
history TEXT,
created_at INTEGER,
updated_at INTEGER,
title TEXT,
persona_id TEXT
);
);
PRAGMA encoding = 'UTF-8';
+35 -5
View File
@@ -1,3 +1,16 @@
"""
事件总线, 用于处理事件的分发和处理
事件总线是一个异步队列, 用于接收各种消息事件, 并将其发送到Scheduler调度器进行处理
其中包含了一个无限循环的调度函数, 用于从事件队列中获取新的事件, 并创建一个新的异步任务来执行管道调度器的处理逻辑
class:
EventBus: 事件总线, 用于处理事件的分发和处理
工作流程:
1. 维护一个异步队列, 来接受各种消息事件
2. 无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑
"""
import asyncio
from asyncio import Queue
from astrbot.core.pipeline.scheduler import PipelineScheduler
@@ -6,21 +19,38 @@ from .platform import AstrMessageEvent
class EventBus:
"""事件总线: 用于处理事件的分发和处理
维护一个异步队列, 来接受各种消息事件
"""
def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler):
self.event_queue = event_queue
self.pipeline_scheduler = pipeline_scheduler
self.event_queue = event_queue # 事件队列
self.pipeline_scheduler = pipeline_scheduler # 管道调度器
async def dispatch(self):
"""无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑"""
while True:
event: AstrMessageEvent = await self.event_queue.get()
self._print_event(event)
asyncio.create_task(self.pipeline_scheduler.execute(event))
event: AstrMessageEvent = (
await self.event_queue.get()
) # 从事件队列中获取新的事件
self._print_event(event) # 打印日志
asyncio.create_task(
self.pipeline_scheduler.execute(event)
) # 创建新的异步任务来执行管道调度器的处理逻辑
def _print_event(self, event: AstrMessageEvent):
"""用于记录事件信息
Args:
event (AstrMessageEvent): 事件对象
"""
# 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要
if event.get_sender_name():
logger.info(
f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
)
# 没有发送者名称: [平台名] 发送者ID: 消息概要
else:
logger.info(
f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}"
@@ -1,18 +1,27 @@
"""
AstrBot 启动器负责初始化和启动核心组件和仪表板服务器
工作流程:
1. 初始化核心生命周期, 传递数据库和日志代理实例到核心生命周期
2. 运行核心生命周期任务和仪表板服务器
"""
import asyncio
import traceback
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from .server import AstrBotDashboard
from astrbot.core.db import BaseDatabase
from astrbot.core import LogBroker
from astrbot.dashboard.server import AstrBotDashboard
class AstrBotDashBoardLifecycle:
class InitialLoader:
"""AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。"""
def __init__(self, db: BaseDatabase, log_broker: LogBroker):
self.db = db
self.logger = logger
self.log_broker = log_broker
self.dashboard_server = None
async def start(self):
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
@@ -25,11 +34,15 @@ class AstrBotDashBoardLifecycle:
logger.critical(traceback.format_exc())
logger.critical(f"😭 初始化 AstrBot 失败:{e} !!!")
self.dashboard_server = AstrBotDashboard(core_lifecycle, self.db)
task = asyncio.gather(core_task, self.dashboard_server.run())
self.dashboard_server = AstrBotDashboard(
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event
)
task = asyncio.gather(
core_task, self.dashboard_server.run()
) # 启动核心任务和仪表板服务器
try:
await task
await task # 整个AstrBot在这里运行
except asyncio.CancelledError:
logger.info("🌈 正在关闭 AstrBot...")
await core_lifecycle.stop()
+169 -14
View File
@@ -1,11 +1,37 @@
"""
日志系统, 用于支持核心组件和插件的日志记录, 提供了日志订阅功能
const:
CACHED_SIZE: 日志缓存大小, 用于限制缓存的日志数量
log_color_config: 日志颜色配置, 定义了不同日志级别的颜色
class:
LogBroker: 日志代理类, 用于缓存和分发日志消息
LogQueueHandler: 日志处理器, 用于将日志消息发送到 LogBroker
LogManager: 日志管理器, 用于创建和配置日志记录器
function:
is_plugin_path: 检查文件路径是否来自插件目录
get_short_level_name: 将日志级别名称转换为四个字母的缩写
工作流程:
1. 通过 LogManager.GetLogger() 获取日志器, 配置了控制台输出和多个格式化过滤器
2. 通过 set_queue_handler() 设置日志处理器, 将日志消息发送到 LogBroker
3. logBroker 维护一个订阅者列表, 负责将日志分发给所有订阅者
4. 订阅者可以使用 register() 方法注册到 LogBroker, 订阅日志流
"""
import logging
import colorlog
import asyncio
import os
from collections import deque
from asyncio import Queue
from typing import List
# 日志缓存大小
CACHED_SIZE = 200
# 日志颜色配置
log_color_config = {
"DEBUG": "green",
"INFO": "bold_cyan",
@@ -17,13 +43,57 @@ log_color_config = {
}
def is_plugin_path(pathname):
"""检查文件路径是否来自插件目录
Args:
pathname (str): 文件路径
Returns:
bool: 如果路径来自插件目录,则返回 True,否则返回 False
"""
if not pathname:
return False
norm_path = os.path.normpath(pathname)
return ("data/plugins" in norm_path) or ("packages/" in norm_path)
def get_short_level_name(level_name):
"""将日志级别名称转换为四个字母的缩写
Args:
level_name (str): 日志级别名称, 如 "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"
Returns:
str: 四个字母的日志级别缩写
"""
level_map = {
"DEBUG": "DBUG",
"INFO": "INFO",
"WARNING": "WARN",
"ERROR": "ERRO",
"CRITICAL": "CRIT",
}
return level_map.get(level_name, level_name[:4].upper())
class LogBroker:
"""日志代理类, 用于缓存和分发日志消息
发布-订阅模式
"""
def __init__(self):
self.log_cache = deque(maxlen=CACHED_SIZE)
self.subscribers: List[Queue] = []
self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志
self.subscribers: List[Queue] = [] # 订阅者列表
def register(self) -> Queue:
"""给每个订阅者返回一个带有日志缓存的队列"""
"""注册新的订阅者, 并给每个订阅者返回一个带有日志缓存的队列
Returns:
Queue: 订阅者的队列, 可用于接收日志消息
"""
q = Queue(maxsize=CACHED_SIZE + 10)
for log in self.log_cache:
q.put_nowait(log)
@@ -31,11 +101,20 @@ class LogBroker:
return q
def unregister(self, q: Queue):
"""取消订阅"""
"""取消订阅
Args:
q (Queue): 需要取消订阅的队列
"""
self.subscribers.remove(q)
def publish(self, log_entry: str):
"""发布消息"""
def publish(self, log_entry: dict):
"""发布新日志到所有订阅者, 使用非阻塞方式投递, 避免一个订阅者阻塞整个系统
Args:
log_entry (dict): 日志消息, 包含日志级别和日志内容.
example: {"level": "INFO", "data": "This is a log message.", "time": "2023-10-01 12:00:00"}
"""
self.log_cache.append(log_entry)
for q in self.subscribers:
try:
@@ -45,44 +124,120 @@ class LogBroker:
class LogQueueHandler(logging.Handler):
"""日志处理器, 用于将日志消息发送到 LogBroker
继承自 logging.Handler
"""
def __init__(self, log_broker: LogBroker):
super().__init__()
self.log_broker = log_broker
def emit(self, record):
"""日志处理的入口方法, 接受一个日志记录, 转换为字符串后由 LogBroker 发布
这个方法会在每次日志记录时被调用
Args:
record (logging.LogRecord): 日志记录对象, 包含日志信息
"""
log_entry = self.format(record)
self.log_broker.publish(log_entry)
self.log_broker.publish({
"level": record.levelname,
"time": record.asctime,
"data": log_entry,
})
class LogManager:
"""日志管理器, 用于创建和配置日志记录器
提供了获取默认日志记录器logger和设置队列处理器的方法
"""
@classmethod
def GetLogger(cls, log_name: str = "default"):
"""获取指定名称的日志记录器logger
Args:
log_name (str): 日志记录器的名称, 默认为 "default"
Returns:
logging.Logger: 返回配置好的日志记录器
"""
logger = logging.getLogger(log_name)
# 检查该logger或父级logger是否已经有处理器, 如果已经有处理器, 直接返回该logger, 避免重复配置
if logger.hasHandlers():
return logger
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
# 如果logger没有处理器
console_handler = logging.StreamHandler() # 创建一个StreamHandler用于控制台输出
console_handler.setLevel(
logging.DEBUG
) # 将日志级别设置为DEBUG(最低级别, 显示所有日志), *如果插件没有设置级别, 默认为DEBUG
# 创建彩色日志格式化器, 输出日志格式为: [时间] [插件标签] [日志级别] [文件名:行号]: 日志消息
console_formatter = colorlog.ColoredFormatter(
fmt="%(log_color)s [%(asctime)s] [%(levelname)-5s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s",
fmt="%(log_color)s [%(asctime)s] %(plugin_tag)s [%(short_levelname)-4s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s",
datefmt="%H:%M:%S",
log_colors=log_color_config,
)
console_handler.setFormatter(console_formatter)
logger.setLevel(logging.DEBUG)
logger.addHandler(console_handler)
class PluginFilter(logging.Filter):
"""插件过滤器类, 用于标记日志来源是插件还是核心组件"""
def filter(self, record):
record.plugin_tag = (
"[Plug]" if is_plugin_path(record.pathname) else "[Core]"
)
return True
class FileNameFilter(logging.Filter):
"""文件名过滤器类, 用于修改日志记录的文件名格式
例如: 将文件路径 /path/to/file.py 转换为 file.<file> 格式"""
# 获取这个文件和父文件夹的名字:<folder>.<file> 并且去除 .py
def filter(self, record):
dirname = os.path.dirname(record.pathname)
record.filename = (
os.path.basename(dirname)
+ "."
+ os.path.basename(record.pathname).replace(".py", "")
)
return True
class LevelNameFilter(logging.Filter):
"""短日志级别名称过滤器类, 用于将日志级别名称转换为四个字母的缩写"""
# 添加短日志级别名称
def filter(self, record):
record.short_levelname = get_short_level_name(record.levelname)
return True
console_handler.setFormatter(console_formatter) # 设置处理器的格式化器
logger.addFilter(PluginFilter()) # 添加插件过滤器
logger.addFilter(FileNameFilter()) # 添加文件名过滤器
logger.addFilter(LevelNameFilter()) # 添加级别名称过滤器
logger.setLevel(logging.DEBUG) # 设置日志级别为DEBUG
logger.addHandler(console_handler) # 添加处理器到logger
return logger
@classmethod
def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker):
"""设置队列处理器, 用于将日志消息发送到 LogBroker
Args:
logger (logging.Logger): 日志记录器
log_broker (LogBroker): 日志代理类, 用于缓存和分发日志消息
"""
handler = LogQueueHandler(log_broker)
handler.setLevel(logging.DEBUG)
if logger.handlers:
handler.setFormatter(logger.handlers[0].formatter)
else:
# 为队列处理器设置相同格式的formatter
handler.setFormatter(
logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
"[%(asctime)s] [%(short_levelname)s] %(plugin_tag)s[%(filename)s:%(lineno)d]: %(message)s"
)
)
logger.addHandler(handler)
+136 -11
View File
@@ -25,9 +25,11 @@ SOFTWARE.
import base64
import json
import os
import uuid
import typing as T
from enum import Enum
from pydantic.v1 import BaseModel
from astrbot.core.utils.io import download_image_by_url, file_to_base64
class ComponentType(Enum):
@@ -59,6 +61,8 @@ class ComponentType(Enum):
TTS = "TTS"
Unknown = "Unknown"
WechatEmoji = "WechatEmoji" # Wechat 下的 emoji 表情包
class BaseMessageComponent(BaseModel):
type: ComponentType
@@ -146,6 +150,51 @@ class Record(BaseMessageComponent):
return Record(file=url, **_)
raise Exception("not a valid url")
async def convert_to_file_path(self) -> str:
"""将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。
Returns:
str: 语音的本地路径,以绝对路径表示。
"""
if self.file and self.file.startswith("file:///"):
file_path = self.file[8:]
return file_path
elif self.file and self.file.startswith("http"):
file_path = await download_image_by_url(self.file)
return os.path.abspath(file_path)
elif self.file and self.file.startswith("base64://"):
bs64_data = self.file.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
file_path = f"data/temp/{uuid.uuid4()}.jpg"
with open(file_path, "wb") as f:
f.write(image_bytes)
return os.path.abspath(file_path)
elif os.path.exists(self.file):
file_path = self.file
return os.path.abspath(file_path)
else:
raise Exception(f"not a valid file: {self.file}")
async def convert_to_base64(self) -> str:
"""将语音统一转换为 base64 编码。这个方法避免了手动判断语音数据类型,直接返回语音数据的 base64 编码。
Returns:
str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
"""
# convert to base64
if self.file and self.file.startswith("file:///"):
bs64_data = file_to_base64(self.file[8:])
elif self.file and self.file.startswith("http"):
file_path = await download_image_by_url(self.file)
bs64_data = file_to_base64(file_path)
elif self.file and self.file.startswith("base64://"):
bs64_data = self.file
elif os.path.exists(self.file):
bs64_data = file_to_base64(self.file)
else:
raise Exception(f"not a valid file: {self.file}")
return bs64_data
class Video(BaseMessageComponent):
type: ComponentType = "Video"
@@ -279,10 +328,6 @@ class Image(BaseMessageComponent):
file_unique: T.Optional[str] = "" # 某些平台可能有图片缓存的唯一标识
def __init__(self, file: T.Optional[str], **_):
# for k in _.keys():
# if (k == "_type" and _[k] not in ["flash", "show", None]) or \
# (k == "c" and _[k] not in [2, 3]):
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
super().__init__(file=file, **_)
@staticmethod
@@ -307,14 +352,77 @@ class Image(BaseMessageComponent):
def fromIO(IO):
return Image.fromBytes(IO.read())
async def convert_to_file_path(self) -> str:
"""将这个图片统一转换为本地文件路径。这个方法避免了手动判断图片数据类型,直接返回图片数据的本地路径(如果是网络 URL, 则会自动进行下载)。
Returns:
str: 图片的本地路径,以绝对路径表示。
"""
url = self.url if self.url else self.file
if url and url.startswith("file:///"):
image_file_path = url[8:]
return image_file_path
elif url and url.startswith("http"):
image_file_path = await download_image_by_url(url)
return os.path.abspath(image_file_path)
elif url and url.startswith("base64://"):
bs64_data = url.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
image_file_path = f"data/temp/{uuid.uuid4()}.jpg"
with open(image_file_path, "wb") as f:
f.write(image_bytes)
return os.path.abspath(image_file_path)
elif os.path.exists(url):
image_file_path = url
return os.path.abspath(image_file_path)
else:
raise Exception(f"not a valid file: {url}")
async def convert_to_base64(self) -> str:
"""将这个图片统一转换为 base64 编码。这个方法避免了手动判断图片数据类型,直接返回图片数据的 base64 编码。
Returns:
str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
"""
# convert to base64
url = self.url if self.url else self.file
if url and url.startswith("file:///"):
bs64_data = file_to_base64(url[8:])
elif url and url.startswith("http"):
image_file_path = await download_image_by_url(url)
bs64_data = file_to_base64(image_file_path)
elif url and url.startswith("base64://"):
bs64_data = url
elif os.path.exists(url):
bs64_data = file_to_base64(url)
else:
raise Exception(f"not a valid file: {url}")
return bs64_data
class Reply(BaseMessageComponent):
type: ComponentType = "Reply"
id: T.Union[str, int]
text: T.Optional[str] = ""
qq: T.Optional[int] = 0
"""所引用的消息 ID"""
chain: T.Optional[T.List["BaseMessageComponent"]] = []
"""引用的消息段列表"""
sender_id: T.Optional[int] | T.Optional[str] = 0
"""引用的消息发送者 ID"""
sender_nickname: T.Optional[str] = ""
"""引用的消息发送者昵称"""
time: T.Optional[int] = 0
"""引用的消息发送时间"""
message_str: T.Optional[str] = ""
"""解析后的纯文本消息字符串"""
sender_str: T.Optional[str] = ""
"""被引用的消息纯文本"""
text: T.Optional[str] = ""
"""deprecated"""
qq: T.Optional[int] = 0
"""deprecated"""
seq: T.Optional[int] = 0
"""deprecated"""
def __init__(self, **_):
super().__init__(**_)
@@ -353,16 +461,22 @@ class Node(BaseMessageComponent):
id: T.Optional[int] = 0 # 忽略
name: T.Optional[str] = "" # qq昵称
uin: T.Optional[int] = 0 # qq号
content: T.Optional[T.Union[str, list]] = "" # 子消息段列表
content: T.Optional[T.Union[str, list, dict]] = "" # 子消息段列表
seq: T.Optional[T.Union[str, list]] = "" # 忽略
time: T.Optional[int] = 0
def __init__(self, content: T.Union[str, list], **_):
def __init__(self, content: T.Union[str, list, dict, "Node", T.List["Node"]], **_):
if isinstance(content, list):
_content = ""
for chain in content:
_content += chain.toString()
_content = None
if all(isinstance(item, Node) for item in content):
_content = [node.toDict() for node in content]
else:
_content = ""
for chain in content:
_content += chain.toString()
content = _content
elif isinstance(content, Node):
content = content.toDict()
super().__init__(content=content, **_)
def toString(self):
@@ -449,6 +563,16 @@ class File(BaseMessageComponent):
super().__init__(name=name, file=file)
class WechatEmoji(BaseMessageComponent):
type: ComponentType = "WechatEmoji"
md5: T.Optional[str] = ""
md5_len: T.Optional[int] = 0
cdnurl: T.Optional[str] = ""
def __init__(self, **_):
super().__init__(**_)
ComponentTypes = {
"plain": Plain,
"text": Plain,
@@ -477,4 +601,5 @@ ComponentTypes = {
"tts": TTS,
"unknown": Unknown,
"file": File,
"WechatEmoji": WechatEmoji,
}
+37 -6
View File
@@ -1,8 +1,14 @@
import enum
from typing import List, Optional
from typing import List, Optional, Union
from dataclasses import dataclass, field
from astrbot.core.message.components import BaseMessageComponent, Plain, Image
from astrbot.core.message.components import (
BaseMessageComponent,
Plain,
Image,
At,
AtAll,
)
from typing_extensions import deprecated
@@ -31,6 +37,30 @@ class MessageChain:
self.chain.append(Plain(message))
return self
def at(self, name: str, qq: Union[str, int]):
"""添加一条 At 消息到消息链 `chain` 中。
Example:
CommandResult().at("张三", "12345678910")
# 输出 @张三
"""
self.chain.append(At(name=name, qq=qq))
return self
def at_all(self):
"""添加一条 AtAll 消息到消息链 `chain` 中。
Example:
CommandResult().at_all()
# 输出 @所有人
"""
self.chain.append(AtAll())
return self
@deprecated("请使用 message 方法代替。")
def error(self, message: str):
"""添加一条错误消息到消息链 `chain` 中
@@ -77,6 +107,10 @@ class MessageChain:
self.use_t2i_ = use_t2i
return self
def get_plain_text(self) -> str:
"""获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。"""
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
class EventResultType(enum.Enum):
"""用于描述事件处理的结果类型。
@@ -147,9 +181,6 @@ class MessageEventResult(MessageChain):
"""是否为 LLM 结果。"""
return self.result_content_type == ResultContentType.LLM_RESULT
def get_plain_text(self) -> str:
"""获取纯文本消息。这个方法将获取所有 Plain 组件的文本并拼接成一条消息。空格分隔。"""
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
# 为了兼容旧版代码,保留 CommandResult 的别名
CommandResult = MessageEventResult
+1
View File
@@ -12,6 +12,7 @@ from .process_stage.stage import ProcessStage
from .result_decorate.stage import ResultDecorateStage
from .respond.stage import RespondStage
# 管道阶段顺序
STAGES_ORDER = [
"WakingCheckStage", # 检查是否需要唤醒
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
@@ -1,7 +1,4 @@
import re
import os
import json
import base64
from . import ContentSafetyStrategy
@@ -11,13 +8,13 @@ class KeywordsStrategy(ContentSafetyStrategy):
if extra_keywords is None:
extra_keywords = []
self.keywords.extend(extra_keywords)
keywords_path = os.path.join(os.path.dirname(__file__), "unfit_words")
# keywords_path = os.path.join(os.path.dirname(__file__), "unfit_words")
# internal keywords
if os.path.exists(keywords_path):
with open(keywords_path, "r", encoding="utf-8") as f:
self.keywords.extend(
json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"]
)
# if os.path.exists(keywords_path):
# with open(keywords_path, "r", encoding="utf-8") as f:
# self.keywords.extend(
# json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"]
# )
def check(self, content: str) -> bool:
for keyword in self.keywords:
@@ -1 +0,0 @@
ewogICAgImtleXdvcmRzIjogWwogICAgICAgICLkuaDov5HlubMiLAogICAgICAgICLog6HplKbmtpsiLAogICAgICAgICLmsZ/ms73msJEiLAogICAgICAgICLmuKnlrrblrp0iLAogICAgICAgICLmnY7lhYvlvLoiLAogICAgICAgICLmnY7plb/mmKUiLAogICAgICAgICLmr5vms73kuJwiLAogICAgICAgICLpgpPlsI/lubMiLAogICAgICAgICLlkajmganmnaUiLAogICAgICAgICLnpL7kvJrkuLvkuYkiLAogICAgICAgICLlhbHkuqflhZoiLAogICAgICAgICLlhbHkuqfkuLvkuYkiLAogICAgICAgICLlpKfpmYblrpjmlrkiLAogICAgICAgICLljJfkuqzmlL/mnYMiLAogICAgICAgICLkuK3ljY7luJ3lm70iLAogICAgICAgICLkuK3lm73mlL/lupwiLAogICAgICAgICLlhbHni5ciLAogICAgICAgICLlha3lm5vkuovku7YiLAogICAgICAgICLlpKnlronpl6giLAogICAgICAgICLlha3lm5siLAogICAgICAgICLmlL/msrvlsYDluLjlp5QiLAogICAgICAgICLlrabmva4iLAogICAgICAgICLlhavkuZ0iLAogICAgICAgICLkuozljYHlpKciLAogICAgICAgICLmsJHov5vlhZoiLAogICAgICAgICLlj7Dni6wiLAogICAgICAgICLlj7Dmub7ni6znq4siLAogICAgICAgICLlj7Dmub7lm70iLAogICAgICAgICLlm73msJHlhZoiLAogICAgICAgICLlj7Dmub7msJHlm70iLAogICAgICAgICLkuK3ljY7msJHlm70iLAogICAgICAgICJwb3JuaHViIiwKICAgICAgICAiUG9ybmh1YiIsCiAgICAgICAgIuS9nOeIsSIsCiAgICAgICAgIuWBmueIsSIsCiAgICAgICAgIuaAp+S6pCIsCiAgICAgICAgIuiHquaFsCIsCiAgICAgICAgIumYtOiMjiIsCiAgICAgICAgIua3q+WmhyIsCiAgICAgICAgIuiCm+S6pCIsCiAgICAgICAgIuS6pOmFjSIsCiAgICAgICAgIuaAp+WFs+ezuyIsCiAgICAgICAgIuaAp+a0u+WKqCIsCiAgICAgICAgIuiJsuaDhSIsCiAgICAgICAgIuiJsuWbviIsCiAgICAgICAgIuijuOS9kyIsCiAgICAgICAgIuWwj+eptCIsCiAgICAgICAgIua3q+iNoSIsCiAgICAgICAgIuaAp+eIsSIsCiAgICAgICAgIua4r+eLrCIsCiAgICAgICAgIuazlei9ruWKnyIsCiAgICAgICAgIuWFreWbmyIKICAgIF0KfQ==
+4 -2
View File
@@ -5,5 +5,7 @@ from astrbot.core.star import PluginManager
@dataclass
class PipelineContext:
astrbot_config: AstrBotConfig
plugin_manager: PluginManager
"""上下文对象,包含管道执行所需的上下文信息"""
astrbot_config: AstrBotConfig # AstrBot 配置对象
plugin_manager: PluginManager # 插件管理器对象
@@ -3,6 +3,7 @@
"""
import traceback
import asyncio
import json
from typing import Union, AsyncGenerator
from ...context import PipelineContext
@@ -15,7 +16,13 @@ from astrbot.core.message.message_event_result import (
from astrbot.core.message.components import Image
from astrbot.core import logger
from astrbot.core.utils.metrics import Metric
from astrbot.core.provider.entites import ProviderRequest, LLMResponse
from astrbot.core.provider.entites import (
ProviderRequest,
LLMResponse,
ToolCallMessageSegment,
AssistantMessageSegment,
ToolCallsResult,
)
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
@@ -27,6 +34,9 @@ class LLMRequestSubStage(Stage):
self.provider_wake_prefix = ctx.astrbot_config["provider_settings"][
"wake_prefix"
] # str
self.max_context_length = ctx.astrbot_config["provider_settings"][
"max_context_length"
] # int
for bwp in self.bot_wake_prefixs:
if self.provider_wake_prefix.startswith(bwp):
@@ -63,8 +73,8 @@ class LLMRequestSubStage(Stage):
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
req.image_urls.append(image_url)
image_path = await comp.convert_to_file_path()
req.image_urls.append(image_path)
# 获取对话上下文
conversation_id = await self.conv_manager.get_curr_conversation_id(
@@ -74,10 +84,16 @@ class LLMRequestSubStage(Stage):
conversation_id = await self.conv_manager.new_conversation(
event.unified_msg_origin
)
req.session_id = event.unified_msg_origin
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin, conversation_id
)
if not conversation:
conversation_id = await self.conv_manager.new_conversation(
event.unified_msg_origin
)
conversation = await self.conv_manager.get_conversation(
event.unified_msg_origin, conversation_id
)
req.conversation = conversation
req.contexts = json.loads(conversation.history)
@@ -109,100 +125,63 @@ class LLMRequestSubStage(Stage):
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
# max context length
if (
self.max_context_length != -1 # -1 为不限制
and len(req.contexts) // 2 > self.max_context_length
):
logger.debug("上下文长度超过限制,将截断。")
req.contexts = req.contexts[-self.max_context_length * 2 :]
# session_id
if not req.session_id:
req.session_id = event.unified_msg_origin
try:
logger.debug(f"提供商请求 Payload: {req}")
if _nested:
req.func_tool = None # 暂时不支持递归工具调用
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
need_loop = True
while need_loop:
need_loop = False
logger.debug(f"提供商请求 Payload: {req}")
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
# 执行 LLM 响应后的事件钩子。
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnLLMResponseEvent
# 执行 LLM 响应后的事件钩子。
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnLLMResponseEvent
)
for handler in handlers:
try:
logger.debug(
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event, llm_response)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return
async for result in self._handle_llm_response(event, req, llm_response):
if isinstance(result, ProviderRequest):
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
req = result
need_loop = True
else:
yield
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=provider.get_model(),
provider_type=provider.meta().type,
)
)
for handler in handlers:
try:
logger.debug(
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
)
await handler.handler(event, llm_response)
except BaseException:
logger.error(traceback.format_exc())
if event.is_stopped():
logger.info(
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
)
return
# 保存到历史记录
await self._save_to_history(event, req, llm_response)
await Metric.upload(
llm_tick=1,
model_name=provider.get_model(),
provider_type=provider.meta().type,
)
if llm_response.role == "assistant":
# text completion
event.set_result(
MessageEventResult()
.message(llm_response.completion_text)
.set_result_content_type(ResultContentType.LLM_RESULT)
)
elif llm_response.role == "err":
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
)
)
elif llm_response.role == "tool":
# function calling
function_calling_result = {}
logger.info(
f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}"
)
for func_tool_name, func_tool_args in zip(
llm_response.tools_call_name, llm_response.tools_call_args
):
func_tool = req.func_tool.get_func(func_tool_name)
logger.info(
f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
)
try:
# 尝试调用工具函数
wrapper = self._call_handler(
self.ctx, event, func_tool.handler, **func_tool_args
)
async for resp in wrapper:
if resp is not None: # 有 return 返回
function_calling_result[func_tool_name] = resp
else:
yield # 有生成器返回
event.clear_result() # 清除上一个 handler 的结果
except BaseException as e:
logger.warning(traceback.format_exc())
function_calling_result[func_tool_name] = (
"When calling the function, an error occurred: " + str(e)
)
if function_calling_result:
# 工具返回 LLM 资源。比如 RAG、网页 得到的相关结果等。
# 我们重新执行一遍这个 stage
req.func_tool = None # 暂时不支持递归工具调用
extra_prompt = "\n\nSystem executed some external tools for this task and here are the results:\n"
for tool_name, tool_result in function_calling_result.items():
extra_prompt += (
f"Tool: {tool_name}\nTool Result: {tool_result}\n"
)
req.prompt += extra_prompt
async for _ in self.process(event, _nested=True):
yield
else:
if llm_response.completion_text:
event.set_result(
MessageEventResult().message(llm_response.completion_text)
)
except BaseException as e:
logger.error(traceback.format_exc())
event.set_result(
@@ -212,6 +191,116 @@ class LLMRequestSubStage(Stage):
)
return
async def _handle_llm_response(
self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse
) -> AsyncGenerator[None, None]:
"""处理 LLM 响应。
Returns:
bool: 是否需要继续调用 LLM
Yields:
Iterator[bool]: 将 event 交付给下一个 stage
"""
if llm_response.role == "assistant":
# text completion
if llm_response.result_chain:
event.set_result(
MessageEventResult(
chain=llm_response.result_chain.chain
).set_result_content_type(ResultContentType.LLM_RESULT)
)
else:
event.set_result(
MessageEventResult()
.message(llm_response.completion_text)
.set_result_content_type(ResultContentType.LLM_RESULT)
)
elif llm_response.role == "err":
event.set_result(
MessageEventResult().message(
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
)
)
elif llm_response.role == "tool":
# function calling
tool_call_result: list[ToolCallMessageSegment] = []
logger.info(
f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}"
)
for func_tool_name, func_tool_args, func_tool_id in zip(
llm_response.tools_call_name,
llm_response.tools_call_args,
llm_response.tools_call_ids,
):
try:
func_tool = req.func_tool.get_func(func_tool_name)
if func_tool.origin == "mcp":
logger.info(
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
)
client = req.func_tool.mcp_client_dict[
func_tool.mcp_server_name
]
res = await client.session.call_tool(
func_tool.name, func_tool_args
)
if res:
# TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=res.content[0].text,
)
)
else:
logger.info(
f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
)
# 尝试调用工具函数
wrapper = self._call_handler(
self.ctx, event, func_tool.handler, **func_tool_args
)
async for resp in wrapper:
if resp is not None: # 有 return 返回
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=resp,
)
)
else:
yield # 有生成器返回
event.clear_result() # 清除上一个 handler 的结果
except BaseException as e:
logger.warning(traceback.format_exc())
tool_call_result.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content=f"error: {str(e)}",
)
)
if tool_call_result:
# 函数调用结果
req.func_tool = None # 暂时不支持递归工具调用
assistant_msg_seg = AssistantMessageSegment(
role="assistant", tool_calls=llm_response.to_openai_tool_calls()
)
# 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。
req.tool_calls_result = ToolCallsResult(
tool_calls_info=assistant_msg_seg,
tool_calls_result=tool_call_result,
)
yield req # 再次执行 LLM 请求
else:
if llm_response.completion_text:
event.set_result(
MessageEventResult().message(llm_response.completion_text)
)
async def _save_to_history(
self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse
):
@@ -221,8 +310,12 @@ class LLMRequestSubStage(Stage):
if llm_response.role == "assistant":
# 文本回复
contexts = req.contexts
new_record = {"role": "user", "content": req.prompt}
contexts.append(new_record)
contexts.append(await req.assemble_context())
# tool calls result
if req.tool_calls_result:
contexts.extend(req.tool_calls_result.to_openai_messages())
contexts.append(
{"role": "assistant", "content": llm_response.completion_text}
)
+77 -6
View File
@@ -2,6 +2,7 @@ import random
import asyncio
import math
import traceback
import astrbot.core.message.components as Comp
from typing import Union, AsyncGenerator
from ..stage import register_stage, Stage
from ..context import PipelineContext
@@ -11,11 +12,42 @@ from astrbot.core import logger
from astrbot.core.message.message_event_result import BaseMessageComponent
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.star import star_map
from astrbot.core.message.components import Plain, Reply, At
@register_stage
class RespondStage(Stage):
# 组件类型到其非空判断函数的映射
_component_validators = {
Comp.Plain: lambda comp: bool(comp.text and comp.text.strip()), # 纯文本消息需要strip
Comp.Face: lambda comp: comp.id is not None, # QQ表情
Comp.Record: lambda comp: bool(comp.file), # 语音
Comp.Video: lambda comp: bool(comp.file), # 视频
Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), # @
Comp.AtAll: lambda comp: True, # @所有人
Comp.RPS: lambda comp: True, # 不知道是啥(未完成)
Comp.Dice: lambda comp: True, # 骰子(未完成)
Comp.Shake: lambda comp: True, # 摇一摇(未完成)
Comp.Anonymous: lambda comp: True, # 匿名(未完成)
Comp.Share: lambda comp: bool(comp.url) and bool(comp.title), # 分享
Comp.Contact: lambda comp: True, # 联系人(未完成)
Comp.Location: lambda comp: bool(comp.lat and comp.lon), # 位置
Comp.Music: lambda comp: bool(comp._type) and bool(comp.url) and bool(comp.audio), # 音乐
Comp.Image: lambda comp: bool(comp.file), # 图片
Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复
Comp.RedBag: lambda comp: bool(comp.title), # 红包
Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳
Comp.Forward: lambda comp: bool(comp.id and comp.id.strip()), # 转发
Comp.Node: lambda comp: bool(comp.name) and comp.uin != 0 and bool(comp.content), # 一个转发节点
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
Comp.Xml: lambda comp: bool(comp.data and comp.data.strip()), # XML
Comp.Json: lambda comp: bool(comp.data), # JSON
Comp.CardImage: lambda comp: bool(comp.file), # 卡片图片
Comp.TTS: lambda comp: bool(comp.text and comp.text.strip()), # 语音合成
Comp.Unknown: lambda comp: bool(comp.text and comp.text.strip()), # 未知消息
Comp.File: lambda comp: bool(comp.file), # 文件
Comp.WechatEmoji: lambda comp: bool(comp.md5), # 微信表情
}
async def initialize(self, ctx: PipelineContext):
self.ctx = ctx
@@ -62,7 +94,7 @@ class RespondStage(Stage):
async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float:
"""分段回复 计算间隔时间"""
if self.interval_method == "log":
if isinstance(comp, Plain):
if isinstance(comp, Comp.Plain):
wc = await self._word_cnt(comp.text)
i = math.log(wc + 1, self.log_base)
return random.uniform(i, i + 0.5)
@@ -72,6 +104,28 @@ class RespondStage(Stage):
# random
return random.uniform(self.interval[0], self.interval[1])
async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]):
"""检查消息链是否为空
Args:
chain (list[BaseMessageComponent]): 包含消息对象的列表
"""
if not chain:
return True
for comp in chain:
comp_type = type(comp)
# 检查组件类型是否在字典中
if comp_type in self._component_validators:
if self._component_validators[comp_type](comp):
return False
else:
logger.info(f"空内容检查: 无法识别的组件类型: {comp_type.__name__}")
# 如果所有组件都为空
return True
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
@@ -82,6 +136,16 @@ class RespondStage(Stage):
if len(result.chain) > 0:
await event._pre_send()
# 检查消息链是否为空
try:
if await self._is_empty_message_chain(result.chain):
logger.info("消息为空,跳过发送阶段")
event.clear_result()
event.stop_event()
return
except Exception as e:
logger.warning(f"空内容检查异常: {e}")
if self.enable_seg and (
(self.only_llm_result and result.is_llm_result())
or not self.only_llm_result
@@ -89,13 +153,13 @@ class RespondStage(Stage):
decorated_comps = []
if self.reply_with_mention:
for comp in result.chain:
if isinstance(comp, At):
if isinstance(comp, Comp.At):
decorated_comps.append(comp)
result.chain.remove(comp)
break
if self.reply_with_quote:
for comp in result.chain:
if isinstance(comp, Reply):
if isinstance(comp, Comp.Reply):
decorated_comps.append(comp)
result.chain.remove(comp)
break
@@ -103,9 +167,16 @@ class RespondStage(Stage):
for comp in result.chain:
i = await self._calc_comp_interval(comp)
await asyncio.sleep(i)
await event.send(MessageChain([*decorated_comps, comp]))
try:
await event.send(MessageChain([*decorated_comps, comp]))
except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
break
else:
await event.send(result)
try:
await event.send(result)
except Exception as e:
logger.error(f"发送消息失败: {e} chain: {result.chain}")
await event._post_send()
logger.info(
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
@@ -31,6 +31,8 @@ class ResultDecorateStage(Stage):
self.t2i_word_threshold = 50
except BaseException:
self.t2i_word_threshold = 150
self.t2i_strategy = ctx.astrbot_config["t2i_strategy"]
self.t2i_use_network = self.t2i_strategy == "remote"
self.forward_threshold = ctx.astrbot_config["platform_settings"][
"forward_threshold"
@@ -192,7 +194,9 @@ class ResultDecorateStage(Stage):
if plain_str and len(plain_str) > self.t2i_word_threshold:
render_start = time.time()
try:
url = await html_renderer.render_t2i(plain_str, return_url=True)
url = await html_renderer.render_t2i(
plain_str, return_url=True, use_network=self.t2i_use_network
)
except BaseException:
logger.error("文本转图片失败,使用文本发送。")
return
@@ -201,7 +205,10 @@ class ResultDecorateStage(Stage):
"文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。"
)
if url:
result.chain = [Image.fromURL(url)]
if url.startswith("http"):
result.chain = [Image.fromURL(url)]
else:
result.chain = [Image.fromFileSystem(url)]
# 触发转发消息
has_forwarded = False
+35 -12
View File
@@ -7,49 +7,72 @@ from astrbot.core import logger
class PipelineScheduler:
"""管道调度器,负责调度各个阶段的执行"""
def __init__(self, context: PipelineContext):
registered_stages.sort(key=lambda x: STAGES_ORDER.index(x.__class__.__name__))
self.ctx = context
registered_stages.sort(
key=lambda x: STAGES_ORDER.index(x.__class__.__name__)
) # 按照顺序排序
self.ctx = context # 上下文对象
async def initialize(self):
"""初始化管道调度器时, 初始化所有阶段"""
for stage in registered_stages:
# logger.debug(f"初始化阶段 {stage.__class__ .__name__}")
await stage.initialize(self.ctx)
async def _process_stages(self, event: AstrMessageEvent, from_stage=0):
"""依次执行各个阶段
Args:
event (AstrMessageEvent): 事件对象
from_stage (int): 从第几个阶段开始执行, 默认从0开始
"""
for i in range(from_stage, len(registered_stages)):
stage = registered_stages[i]
stage = registered_stages[i] # 获取当前要执行的阶段
# logger.debug(f"执行阶段 {stage.__class__ .__name__}")
coro = stage.process(event)
if isinstance(coro, AsyncGenerator):
async for _ in coro:
coroutine = stage.process(
event
) # 调用阶段的process方法, 返回协程或者异步生成器
if isinstance(coroutine, AsyncGenerator):
# 如果返回的是异步生成器, 实现洋葱模型的核心
async for _ in coroutine:
# 此处是前置处理完成后的暂停点(yield), 下面开始执行后续阶段
if event.is_stopped():
logger.debug(
f"阶段 {stage.__class__.__name__} 已终止事件传播。"
)
break
# 递归调用, 处理所有后续阶段
await self._process_stages(event, i + 1)
# 此处是后续所有阶段处理完毕后返回的点, 执行后置处理
if event.is_stopped():
logger.debug(
f"阶段 {stage.__class__.__name__} 已终止事件传播。"
)
break
else:
await coro
# 如果返回的是普通协程(不含yield的async函数), 则不进入下一层(基线条件)
# 简单地等待它执行完成, 然后继续执行下一个阶段
await coroutine
if event.is_stopped():
logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。")
break
if event.is_stopped():
logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。")
break
async def execute(self, event: AstrMessageEvent):
"""执行 pipeline"""
"""执行 pipeline
Args:
event (AstrMessageEvent): 事件对象
"""
await self._process_stages(event)
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
if not event._has_send_oper and event.get_platform_name() == "webchat":
await event.send(None)
+60 -22
View File
@@ -1,14 +1,14 @@
from __future__ import annotations
import abc
import inspect
import traceback
from astrbot.api import logger
from typing import List, AsyncGenerator, Union, Awaitable
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from .context import PipelineContext
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
registered_stages: List[Stage] = []
"""维护了所有已注册的 Stage 实现类"""
registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类
def register_stage(cls):
@@ -22,14 +22,24 @@ class Stage(abc.ABC):
@abc.abstractmethod
async def initialize(self, ctx: PipelineContext) -> None:
"""初始化阶段"""
"""初始化阶段
Args:
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
"""
raise NotImplementedError
@abc.abstractmethod
async def process(
self, event: AstrMessageEvent
) -> Union[None, AsyncGenerator[None, None]]:
"""处理事件"""
"""处理事件
Args:
event (AstrMessageEvent): 事件对象,包含事件的相关信息
Returns:
Union[None, AsyncGenerator[None, None]]: 处理结果,可能是 None 或者异步生成器, 如果为 None 则表示不需要继续处理, 如果为异步生成器则表示需要继续处理(进入下一个阶段)
"""
raise NotImplementedError
async def _call_handler(
@@ -40,33 +50,61 @@ class Stage(abc.ABC):
*args,
**kwargs,
) -> AsyncGenerator[None, None]:
"""调用 Handler。"""
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
ready_to_call = None
"""执行事件处理函数并处理其返回结果
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
1. 异步生成器: 实现洋葱模型,每次yield都会将控制权交回上层
2. 协程: 执行一次并处理返回值
Args:
ctx (PipelineContext): 消息管道上下文对象
event (AstrMessageEvent): 待处理的事件对象
handler (Awaitable): 事件处理函数
*args: 传递给handler的位置参数
**kwargs: 传递给handler的关键字参数
Returns:
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
"""
ready_to_call = None # 一个协程或者异步生成器(async def)
trace_ = None
try:
ready_to_call = handler(event, *args, **kwargs)
except TypeError as e:
except TypeError as _:
# 向下兼容
logger.debug(str(e))
trace_ = traceback.format_exc()
# 以前的handler会额外传入一个参数, 但是context对象实际上在插件实例中有一份
ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs)
if isinstance(ready_to_call, AsyncGenerator):
_has_yielded = False
async for ret in ready_to_call:
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值)
_has_yielded = True
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
# 如果是一个异步生成器, 进入洋葱模型
_has_yielded = False # 是否返回过值
try:
async for ret in ready_to_call:
# 这里逐步执行异步生成器, 对于每个yield返回的ret, 执行下面的代码
# 返回值只能是 MessageEventResult 或者 None(无返回值)
_has_yielded = True
if isinstance(ret, (MessageEventResult, CommandResult)):
# 如果返回值是 MessageEventResult, 设置结果并继续
event.set_result(ret)
yield # 传递控制权给上一层的process函数
else:
# 如果返回值是 None, 则不设置结果并继续
# 继续执行后续阶段
yield ret # 传递控制权给上一层的process函数
if not _has_yielded:
# 如果这个异步生成器没有执行到yield分支
yield
else:
yield ret
if not _has_yielded:
yield
except Exception as e:
logger.error(f"Previous Error: {trace_}")
raise e
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个 coroutine
# 如果只是一个协程, 直接执行
ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
yield # 传递控制权给上一层的process函数
else:
yield ret
yield ret # 传递控制权给上一层的process函数
+13 -2
View File
@@ -21,10 +21,19 @@ class WakingCheckStage(Stage):
"""
async def initialize(self, ctx: PipelineContext) -> None:
"""初始化唤醒检查阶段
Args:
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
"""
self.ctx = ctx
self.no_permission_reply = self.ctx.astrbot_config["platform_settings"].get(
"no_permission_reply", True
)
# 私聊是否需要 wake_prefix 才能唤醒机器人
self.friend_message_needs_wake_prefix = self.ctx.astrbot_config[
"platform_settings"
].get("friend_message_needs_wake_prefix", False)
async def process(
self, event: AstrMessageEvent
@@ -32,7 +41,7 @@ class WakingCheckStage(Stage):
# 设置 sender 身份
event.message_str = event.message_str.strip()
for admin_id in self.ctx.astrbot_config["admins_id"]:
if event.get_sender_id() == admin_id:
if str(event.get_sender_id()) == admin_id:
event.role = "admin"
break
@@ -68,7 +77,7 @@ class WakingCheckStage(Stage):
event.is_at_or_wake_command = True
break
# 检查是否是私聊
if event.is_private_chat():
if event.is_private_chat() and not self.friend_message_needs_wake_prefix:
is_wake = True
event.is_wake = True
event.is_at_or_wake_command = True
@@ -102,6 +111,7 @@ class WakingCheckStage(Stage):
f"插件 {star_map[handler.handler_module_path].name}: {e}"
)
)
await event._post_send()
event.stop_event()
passed = False
break
@@ -113,6 +123,7 @@ class WakingCheckStage(Stage):
f"ID {event.get_sender_id()} 权限不足。通过 /sid 获取 ID 并请管理员添加。"
)
)
await event._post_send()
event.stop_event()
return
@@ -15,6 +15,9 @@ class WhitelistCheckStage(Stage):
"enable_id_white_list"
]
self.whitelist = ctx.astrbot_config["platform_settings"]["id_whitelist"]
self.whitelist = [
str(i).strip() for i in self.whitelist if str(i).strip() != ""
]
self.wl_ignore_admin_on_group = ctx.astrbot_config["platform_settings"][
"wl_ignore_admin_on_group"
]
@@ -51,7 +54,10 @@ class WhitelistCheckStage(Stage):
and event.get_message_type() == MessageType.FRIEND_MESSAGE
):
return
if event.unified_msg_origin not in self.whitelist:
if (
event.unified_msg_origin not in self.whitelist
and str(event.get_group_id()).strip() not in self.whitelist
):
if self.wl_log:
logger.info(
f"会话 ID {event.unified_msg_origin} 不在会话白名单中,已终止事件传播。请在配置文件中添加该会话 ID 到白名单。"
+2 -1
View File
@@ -1,7 +1,7 @@
from .platform import Platform
from .astr_message_event import AstrMessageEvent
from .platform_metadata import PlatformMetadata
from .astrbot_message import AstrBotMessage, MessageMember, MessageType
from .astrbot_message import AstrBotMessage, MessageMember, MessageType, Group
__all__ = [
"Platform",
@@ -10,4 +10,5 @@ __all__ = [
"AstrBotMessage",
"MessageMember",
"MessageType",
"Group",
]
+40 -14
View File
@@ -1,10 +1,9 @@
import abc
import asyncio
from dataclasses import dataclass
from .astrbot_message import AstrBotMessage
from .platform_metadata import PlatformMetadata
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
from astrbot.core.platform.message_type import MessageType
from typing import List, Union
from typing import List, Union, Optional
from astrbot.core.db.po import Conversation
from astrbot.core.message.components import (
Plain,
Image,
@@ -13,10 +12,14 @@ from astrbot.core.message.components import (
At,
AtAll,
Forward,
Reply,
)
from astrbot.core.utils.metrics import Metric
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
from astrbot.core.platform.message_type import MessageType
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core.db.po import Conversation
from astrbot.core.utils.metrics import Metric
from .astrbot_message import AstrBotMessage, Group
from .platform_metadata import PlatformMetadata
@dataclass
@@ -100,8 +103,15 @@ class AstrMessageEvent(abc.ABC):
elif isinstance(i, Forward):
# 转发消息
outline += "[转发消息]"
elif isinstance(i, Reply):
# 引用回复
if i.message_str:
outline += f"[引用消息({i.sender_nickname}: {i.message_str})]"
else:
outline += "[引用消息]"
else:
outline += f"[{i.type}]"
outline += " "
return outline
def get_message_outline(self) -> str:
@@ -192,13 +202,6 @@ class AstrMessageEvent(abc.ABC):
"""
return self.role == "admin"
async def send(self, message: MessageChain):
"""
发送消息到消息平台。
"""
await Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
self._has_send_oper = True
async def _pre_send(self):
"""调度器会在执行 send() 前调用该方法"""
@@ -360,3 +363,26 @@ class AstrMessageEvent(abc.ABC):
system_prompt=system_prompt,
conversation=conversation,
)
"""平台适配器"""
async def send(self, message: MessageChain):
"""发送消息到消息平台。
Args:
message (MessageChain): 消息链,具体使用方式请参考文档。
"""
asyncio.create_task(
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
)
self._has_send_oper = True
async def get_group(self, group_id: str = None, **kwargs) -> Optional[Group]:
"""获取一个群聊的数据, 如果不填写 group_id: 如果是私聊消息,返回 None。如果是群聊消息,返回当前群聊的数据。
适配情况:
- gewechat
- aiocqhttp(OneBotv11)
"""
...
+35
View File
@@ -10,6 +10,41 @@ class MessageMember:
user_id: str # 发送者id
nickname: str = None
def __str__(self):
# 使用 f-string 来构建返回的字符串表示形式
return (
f"User ID: {self.user_id},"
f"Nickname: {self.nickname if self.nickname else 'N/A'}"
)
@dataclass
class Group:
group_id: str
"""群号"""
group_name: str = None
"""群名称"""
group_avatar: str = None
"""群头像"""
group_owner: str = None
"""群主 id"""
group_admins: List[str] = None
"""群管理员 id"""
members: List[MessageMember] = None
"""所有群成员"""
def __str__(self):
# 使用 f-string 来构建返回的字符串表示形式
return (
f"Group ID: {self.group_id}\n"
f"Name: {self.group_name if self.group_name else 'N/A'}\n"
f"Avatar: {self.group_avatar if self.group_avatar else 'N/A'}\n"
f"Owner ID: {self.group_owner if self.group_owner else 'N/A'}\n"
f"Admin IDs: {self.group_admins if self.group_admins else 'N/A'}\n"
f"Members Len: {len(self.members) if self.members else 0}\n"
f"First Member: {self.members[0] if self.members else 'N/A'}\n"
)
class AstrBotMessage:
"""
+44 -32
View File
@@ -64,6 +64,10 @@ class PlatformManager:
)
case "lark":
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
case "dingtalk":
from .sources.dingtalk.dingtalk_adapter import (
DingtalkPlatformAdapter, # noqa: F401
)
case "telegram":
from .sources.telegram.tg_adapter import TelegramPlatformAdapter # noqa: F401
case "wecom":
@@ -81,14 +85,18 @@ class PlatformManager:
)
return
cls_type = platform_cls_map[platform_config["type"]]
inst = cls_type(platform_config, self.settings, self.event_queue)
self._inst_map[platform_config["id"]] = inst
inst: Platform = cls_type(platform_config, self.settings, self.event_queue)
self._inst_map[platform_config["id"]] = {
"inst": inst,
"client_id": inst.client_self_id,
}
self.platform_insts.append(inst)
asyncio.create_task(
self._task_wrapper(
asyncio.create_task(
inst.run(), name=platform_config["id"] + "_platform"
inst.run(),
name=f"platform_{platform_config['type']}_{platform_config['id']}",
)
)
)
@@ -105,38 +113,42 @@ class PlatformManager:
logger.error("-------")
async def reload(self, platform_config: dict):
# 还未实现完成,不要调用此方法
if platform_config["id"] in self._inst_map:
# 正在运行
if getattr(self._inst_map[platform_config["id"]], "terminate", None):
logger.info(f"正在尝试终止 {platform_config['id']} 平台适配器 ...")
await self._inst_map[platform_config["id"]].terminate()
logger.info(f"{platform_config['id']} 平台适配器已终止。")
del self._inst_map[platform_config["id"]]
self.platform_insts.remove(self._inst_map[platform_config["id"]])
else:
logger.warning(f"可能无法正常终止 {platform_config['id']} 平台适配器。")
# 再启动新的实例
await self.terminate_platform(platform_config["id"])
if platform_config["enable"]:
await self.load_platform(platform_config)
else:
# 先将 _inst_map 中在 platform_config 中不存在的实例删除
config_ids = [platform["id"] for platform in self.platforms_config]
for key in list(self._inst_map.keys()):
if key not in config_ids:
if getattr(self._inst_map[key], "terminate", None):
logger.info(f"正在尝试终止 {key} 平台适配器 ...")
await self._inst_map[key].terminate()
logger.info(f"{key} 平台适配器已终止。")
del self._inst_map[key]
self.platform_insts.remove(self._inst_map[key])
else:
logger.warning(f"可能无法正常终止 {key} 平台适配器。")
# 和配置文件保持同步
config_ids = [provider["id"] for provider in self.platforms_config]
for key in list(self._inst_map.keys()):
if key not in config_ids:
await self.terminate_platform(key)
# 再启动新的实例
await self.load_platform(platform_config)
async def terminate_platform(self, platform_id: str):
if platform_id in self._inst_map:
logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...")
# client_id = self._inst_map.pop(platform_id, None)
info = self._inst_map.pop(platform_id, None)
client_id = info["client_id"]
inst = info["inst"]
try:
self.platform_insts.remove(
next(
inst
for inst in self.platform_insts
if inst.client_self_id == client_id
)
)
except Exception:
logger.warning(f"可能未完全移除 {platform_id} 平台适配器")
if getattr(inst, "terminate", None):
await inst.terminate()
async def terminate(self):
for inst in self.platform_insts:
if getattr(inst, "terminate", None):
await inst.terminate()
def get_insts(self):
return self.platform_insts
+3 -1
View File
@@ -1,4 +1,5 @@
import abc
import uuid
from typing import Awaitable, Any
from asyncio import Queue
from .platform_metadata import PlatformMetadata
@@ -13,6 +14,7 @@ class Platform(abc.ABC):
super().__init__()
# 维护了消息平台的事件队列,EventBus 会从这里取出事件并处理。
self._event_queue = event_queue
self.client_self_id = uuid.uuid4().hex
@abc.abstractmethod
def run(self) -> Awaitable[Any]:
@@ -25,7 +27,7 @@ class Platform(abc.ABC):
"""
终止一个平台的运行实例。
"""
pass
...
@abc.abstractmethod
def meta(self) -> PlatformMetadata:
@@ -1,9 +1,9 @@
import asyncio
import typing
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import Group, MessageMember
from astrbot.api.message_components import Plain, Image, Record, At, Node, Nodes
from aiocqhttp import CQHttp
from astrbot.core.utils.io import file_to_base64, download_image_by_url
class AiocqhttpMessageEvent(AstrMessageEvent):
@@ -21,20 +21,15 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
d = segment.toDict()
if isinstance(segment, Plain):
d["type"] = "text"
d["data"]["text"] = segment.text.strip()
# 如果是空文本或者只带换行符的文本,不发送
if not d["data"]["text"]:
continue
elif isinstance(segment, (Image, Record)):
# convert to base64
if segment.file and segment.file.startswith("file:///"):
bs64_data = file_to_base64(segment.file[8:])
image_file_path = segment.file[8:]
elif segment.file and segment.file.startswith("http"):
image_file_path = await download_image_by_url(segment.file)
bs64_data = file_to_base64(image_file_path)
elif segment.file and segment.file.startswith("base64://"):
bs64_data = segment.file
else:
bs64_data = file_to_base64(segment.file)
bs64 = await segment.convert_to_base64()
d["data"] = {
"file": bs64_data,
"file": bs64,
}
elif isinstance(segment, At):
d["data"] = {
@@ -46,6 +41,9 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
async def send(self, message: MessageChain):
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
if not ret:
return
send_one_by_one = False
for seg in message.chain:
if isinstance(seg, (Node, Nodes)):
@@ -55,8 +53,13 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
if send_one_by_one:
for seg in message.chain:
if isinstance(seg, Nodes):
# 带有多个节点的合并转发消息
if isinstance(seg, (Node, Nodes)):
# 合并转发消息
if isinstance(seg, Node):
nodes = Nodes([seg])
seg = nodes
payload = seg.toDict()
if self.get_group_id():
payload["group_id"] = self.get_group_id()
@@ -78,3 +81,46 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
await self.bot.send(self.message_obj.raw_message, ret)
await super().send(message)
async def get_group(self, group_id=None, **kwargs):
if isinstance(group_id, str) and group_id.isdigit():
group_id = int(group_id)
elif self.get_group_id():
group_id = int(self.get_group_id())
else:
return None
info: dict = await self.bot.call_action(
"get_group_info",
group_id=group_id,
)
members: typing.List[typing.Dict] = await self.bot.call_action(
"get_group_member_list",
group_id=group_id,
)
owner_id = None
admin_ids = []
for member in members:
if member["role"] == "owner":
owner_id = member["user_id"]
if member["role"] == "admin":
admin_ids.append(member["user_id"])
group = Group(
group_id=str(group_id),
group_name=info.get("group_name"),
group_avatar="",
group_admins=admin_ids,
group_owner=str(owner_id),
members=[
MessageMember(
user_id=member["user_id"],
nickname=member.get("nickname") or member.get("card"),
)
for member in members
],
)
return group
@@ -43,8 +43,6 @@ class AiocqhttpAdapter(Platform):
"适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
)
self.stop = False
self.bot = CQHttp(
use_ws_reverse=True, import_name="aiocqhttp", api_timeout_sec=180
)
@@ -140,7 +138,7 @@ class AiocqhttpAdapter(Platform):
abm.type = MessageType.FRIEND_MESSAGE
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = (
abm.sender.user_id + "_" + str(event.group_id)
str(abm.sender.user_id) + "_" + str(event.group_id)
) # 也保留群组 id
else:
abm.session_id = (
@@ -160,8 +158,14 @@ class AiocqhttpAdapter(Platform):
return abm
async def _convert_handle_message_event(self, event: Event) -> AstrBotMessage:
"""OneBot V11 消息类事件"""
async def _convert_handle_message_event(
self, event: Event, get_reply=True
) -> AstrBotMessage:
"""OneBot V11 消息类事件
@param event: 事件对象
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
"""
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.sender = MessageMember(
@@ -240,6 +244,36 @@ class AiocqhttpAdapter(Platform):
except BaseException as e:
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
elif t == "reply":
if not get_reply:
a = ComponentTypes[t](**m["data"]) # noqa: F405
abm.message.append(a)
else:
try:
reply_event_data = await self.bot.call_action(
action="get_msg",
message_id=int(m["data"]["id"]),
)
abm_reply = await self._convert_handle_message_event(
Event.from_payload(reply_event_data), get_reply=False
)
reply_seg = Reply(
id=abm_reply.message_id,
chain=abm_reply.message,
sender_id=abm_reply.sender.user_id,
sender_nickname=abm_reply.sender.nickname,
time=abm_reply.timestamp,
message_str=abm_reply.message_str,
text=abm_reply.message_str, # for compatibility
qq=abm_reply.sender.user_id, # for compatibility
)
abm.message.append(reply_seg)
except BaseException as e:
logger.error(f"获取引用消息失败: {e}")
a = ComponentTypes[t](**m["data"]) # noqa: F405
abm.message.append(a)
else:
a = ComponentTypes[t](**m["data"]) # noqa: F405
abm.message.append(a)
@@ -267,22 +301,19 @@ class AiocqhttpAdapter(Platform):
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
logging.getLogger("aiocqhttp").setLevel(logging.ERROR)
self.shutdown_event = asyncio.Event()
return coro
async def terminate(self):
self.stop = True
await asyncio.sleep(1)
self.shutdown_event.set()
async def shutdown_trigger_placeholder(self):
await self.shutdown_event.wait()
logger.info("aiocqhttp 适配器已被优雅地关闭")
def meta(self) -> PlatformMetadata:
return self.metadata
async def shutdown_trigger_placeholder(self):
# TODO: use asyncio.Event
while not self._event_queue.closed and not self.stop: # noqa: ASYNC110
await asyncio.sleep(1)
logger.info("aiocqhttp 适配器已关闭。")
async def handle_msg(self, message: AstrBotMessage):
message_event = AiocqhttpMessageEvent(
message_str=message.message_str,
@@ -0,0 +1,227 @@
import asyncio
import uuid
import aiohttp
import dingtalk_stream
import threading
from astrbot.api.platform import (
Platform,
AstrBotMessage,
MessageMember,
MessageType,
PlatformMetadata,
)
from astrbot.api.event import MessageChain
from astrbot.api.message_components import Image, Plain, At
from astrbot.core.platform.astr_message_event import MessageSesion
from .dingtalk_event import DingtalkMessageEvent
from ...register import register_platform_adapter
from astrbot import logger
from dingtalk_stream import AckMessage
from astrbot.core.utils.io import download_file
class MyEventHandler(dingtalk_stream.EventHandler):
async def process(self, event: dingtalk_stream.EventMessage):
print(
"2",
event.headers.event_type,
event.headers.event_id,
event.headers.event_born_time,
event.data,
)
return AckMessage.STATUS_OK, "OK"
@register_platform_adapter("dingtalk", "钉钉机器人官方 API 适配器")
class DingtalkPlatformAdapter(Platform):
def __init__(
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
) -> None:
super().__init__(event_queue)
self.config = platform_config
self.unique_session = platform_settings["unique_session"]
self.client_id = platform_config["client_id"]
self.client_secret = platform_config["client_secret"]
class AstrCallbackClient(dingtalk_stream.ChatbotHandler):
async def process(self_, message: dingtalk_stream.CallbackMessage):
logger.debug(f"dingtalk: {message.data}")
im = dingtalk_stream.ChatbotMessage.from_dict(message.data)
abm = await self.convert_msg(im)
await self.handle_msg(abm)
return AckMessage.STATUS_OK, "OK"
self.client = AstrCallbackClient()
credential = dingtalk_stream.Credential(self.client_id, self.client_secret)
client = dingtalk_stream.DingTalkStreamClient(credential, logger=logger)
client.register_all_event_handler(MyEventHandler())
client.register_callback_handler(
dingtalk_stream.ChatbotMessage.TOPIC, self.client
)
self.client_ = client # 用于 websockets 的 client
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
):
raise NotImplementedError("钉钉机器人适配器不支持 send_by_session")
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"dingtalk",
"钉钉机器人官方 API 适配器",
)
async def convert_msg(
self, message: dingtalk_stream.ChatbotMessage
) -> AstrBotMessage:
abm = AstrBotMessage()
abm.message = []
abm.message_str = ""
abm.timestamp = int(message.create_at / 1000)
abm.type = (
MessageType.GROUP_MESSAGE
if message.conversation_type == "2"
else MessageType.FRIEND_MESSAGE
)
abm.sender = MessageMember(
user_id=message.sender_id, nickname=message.sender_nick
)
abm.self_id = message.chatbot_user_id
abm.message_id = message.message_id
abm.raw_message = message
if abm.type == MessageType.GROUP_MESSAGE:
if message.is_in_at_list:
abm.message.append(At(qq=abm.self_id))
abm.group_id = message.conversation_id
if self.unique_session:
abm.session_id = abm.sender.user_id
else:
abm.session_id = abm.group_id
else:
abm.session_id = abm.sender.user_id
message_type: str = message.message_type
match message_type:
case "text":
abm.message_str = message.text.content.strip()
abm.message.append(Plain(abm.message_str))
case "richText":
rtc: dingtalk_stream.RichTextContent = message.rich_text_content
contents: list[dict] = rtc.rich_text_list
for content in contents:
plains = ""
if "text" in content:
plains += content["text"]
abm.message.append(Plain(plains))
elif "type" in content and content["type"] == "picture":
f_path = await self.download_ding_file(
content["downloadCode"],
message.robot_code,
"jpg",
)
abm.message.append(Image.fromFileSystem(f_path))
case "audio":
pass
return abm # 别忘了返回转换后的消息对象
async def download_ding_file(
self, download_code: str, robot_code: str, ext: str
) -> str:
"""下载钉钉文件
:param access_token: 钉钉机器人的 access_token
:param download_code: 下载码
:param robot_code: 机器人码
:param ext: 文件后缀
:return: 文件路径
"""
access_token = await self.get_access_token()
headers = {
"x-acs-dingtalk-access-token": access_token,
}
payload = {
"downloadCode": download_code,
"robotCode": robot_code,
}
f_path = f"data/dingtalk_file_{uuid.uuid4()}.{ext}"
async with aiohttp.ClientSession() as session:
async with session.post(
"https://api.dingtalk.com/v1.0/robot/messageFiles/download",
headers=headers,
json=payload,
) as resp:
if resp.status != 200:
logger.error(
f"下载钉钉文件失败: {resp.status}, {await resp.text()}"
)
return None
resp_data = await resp.json()
download_url = resp_data["data"]["downloadUrl"]
await download_file(download_url, f_path)
return f_path
async def get_access_token(self) -> str:
payload = {
"appKey": self.client_id,
"appSecret": self.client_secret,
}
async with aiohttp.ClientSession() as session:
async with session.post(
"https://api.dingtalk.com/v1.0/oauth2/accessToken",
json=payload,
) as resp:
if resp.status != 200:
logger.error(
f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}"
)
return None
return (await resp.json())["data"]["accessToken"]
async def handle_msg(self, abm: AstrBotMessage):
event = DingtalkMessageEvent(
message_str=abm.message_str,
message_obj=abm,
platform_meta=self.meta(),
session_id=abm.session_id,
client=self.client,
)
self._event_queue.put_nowait(event)
async def run(self):
# await self.client_.start()
# 钉钉的 SDK 并没有实现真正的异步,start() 里面有堵塞方法。
def start_client(loop: asyncio.AbstractEventLoop):
try:
self._shutdown_event = threading.Event()
task = loop.create_task(self.client_.start())
self._shutdown_event.wait()
if task.done():
task.result()
except Exception as e:
if "Graceful shutdown" in str(e):
logger.info("钉钉适配器已被优雅地关闭")
return
logger.error(f"钉钉机器人启动失败: {e}")
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, start_client, loop)
async def terminate(self):
def monkey_patch_close():
raise Exception("Graceful shutdown")
self.client_.open_connection = monkey_patch_close
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
self._shutdown_event.set()
def get_client(self):
return self.client
@@ -0,0 +1,58 @@
import asyncio
import dingtalk_stream
import astrbot.api.message_components as Comp
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot import logger
class DingtalkMessageEvent(AstrMessageEvent):
def __init__(
self,
message_str,
message_obj,
platform_meta,
session_id,
client: dingtalk_stream.ChatbotHandler,
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
async def send_with_client(
self, client: dingtalk_stream.ChatbotHandler, message: MessageChain
):
for segment in message.chain:
if isinstance(segment, Comp.Plain):
segment.text = segment.text.strip()
await asyncio.get_event_loop().run_in_executor(
None, client.reply_markdown, "AstrBot", segment.text, self.message_obj.raw_message
)
elif isinstance(segment, Comp.Image):
markdown_str = ""
if segment.file and segment.file.startswith("file:///"):
logger.warning(
"dingtalk only support url image, not: " + segment.file
)
continue
elif segment.file and segment.file.startswith("http"):
markdown_str += f"![image]({segment.file})\n\n"
elif segment.file and segment.file.startswith("base64://"):
logger.warning("dingtalk only support url image, not base64")
continue
else:
logger.warning(
"dingtalk only support url image, not: " + segment.file
)
continue
ret = await asyncio.get_event_loop().run_in_executor(
None,
client.reply_markdown,
"😄",
markdown_str,
self.message_obj.raw_message,
)
logger.debug(f"send image: {ret}")
async def send(self, message: MessageChain):
await self.send_with_client(self.client, message)
await super().send(message)
+335 -48
View File
@@ -1,17 +1,26 @@
import threading
import asyncio
import aiohttp
import quart
import base64
import datetime
import re
import os
import re
import threading
import aiohttp
import anyio
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
from astrbot.api.message_components import Plain, Image, At, Record
import quart
from astrbot.api import logger, sp
from .downloader import GeweDownloader
from astrbot.api.message_components import Plain, Image, At, Record, Video
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
from astrbot.core.utils.io import download_image_by_url
from .downloader import GeweDownloader
try:
from .xml_data_parser import GeweDataParser
except (ImportError, ModuleNotFoundError) as e:
logger.warning(
f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {str(e)}"
)
class SimpleGewechatClient:
@@ -51,11 +60,11 @@ class SimpleGewechatClient:
self.server = quart.Quart(__name__)
self.server.add_url_rule(
"/astrbot-gewechat/callback", view_func=self.callback, methods=["POST"]
"/astrbot-gewechat/callback", view_func=self._callback, methods=["POST"]
)
self.server.add_url_rule(
"/astrbot-gewechat/file/<file_id>",
view_func=self.handle_file,
view_func=self._handle_file,
methods=["GET"],
)
@@ -70,9 +79,10 @@ class SimpleGewechatClient:
self.userrealnames = {}
self.stop = False
self.shutdown_event = asyncio.Event()
async def get_token_id(self):
"""获取 Gewechat Token。"""
async with aiohttp.ClientSession() as session:
async with session.post(f"{self.base_url}/tools/getTokenId") as resp:
json_blob = await resp.json()
@@ -87,6 +97,15 @@ class SimpleGewechatClient:
type_name = data["type_name"]
else:
raise Exception("无法识别的消息类型")
# 以下没有业务处理,只是避免控制台打印太多的日志
if type_name == "ModContacts":
logger.info("gewechat下发:ModContacts消息通知。")
return
if type_name == "DelContacts":
logger.info("gewechat下发:DelContacts消息通知。")
return
if type_name == "Offline":
logger.critical("收到 gewechat 下线通知。")
return
@@ -147,6 +166,11 @@ class SimpleGewechatClient:
abm.type = MessageType.FRIEND_MESSAGE
user_id = from_user_name
# 检查消息是否由自己发送,若是则忽略
if user_id == abm.self_id:
logger.info("忽略自己发送的消息")
return None
abm.message = []
if at_me:
abm.message.insert(0, At(qq=abm.self_id))
@@ -178,6 +202,11 @@ class SimpleGewechatClient:
abm.sender = MessageMember(user_id, user_real_name)
abm.raw_message = d
abm.message_str = ""
if user_id == "weixin":
# 忽略微信团队消息
return
# 不同消息类型
match d["MsgType"]:
case 1:
@@ -195,18 +224,42 @@ class SimpleGewechatClient:
case 34:
# 语音消息
# data = await self.multimedia_downloader.download_voice(
# self.appid,
# content,
# abm.message_id
# )
# print(data)
if "ImgBuf" in d and "buffer" in d["ImgBuf"]:
voice_data = base64.b64decode(d["ImgBuf"]["buffer"])
file_path = f"data/temp/gewe_voice_{abm.message_id}.silk"
async with await anyio.open_file(file_path, "wb") as f:
await f.write(voice_data)
abm.message.append(Record(file=file_path, url=file_path))
# 以下已知消息类型,没有业务处理,只是避免控制台打印太多的日志
case 37: # 好友申请
logger.info("消息类型(37):好友申请")
case 42: # 名片
logger.info("消息类型(42):名片")
case 43: # 视频
video = Video(file="", cover=content)
abm.message.append(video)
case 47: # emoji
data_parser = GeweDataParser(content, abm.group_id == "")
emoji = data_parser.parse_emoji()
abm.message.append(emoji)
case 48: # 地理位置
logger.info("消息类型(48):地理位置")
case 49: # 公众号/文件/小程序/引用/转账/红包/视频号/群聊邀请
data_parser = GeweDataParser(content, abm.group_id == "")
abm_data = data_parser.parse_mutil_49()
if abm_data:
abm.message.append(abm_data)
case 51: # 帐号消息同步?
logger.info("消息类型(51):帐号消息同步?")
case 10000: # 被踢出群聊/更换群主/修改群名称
logger.info("消息类型(10000):被踢出群聊/更换群主/修改群名称")
case 10002: # 撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办
logger.info(
"消息类型(10002):撤回/拍一拍/成员邀请/被移出群聊/解散群聊/群公告/群待办"
)
case _:
logger.info(f"未实现的消息类型: {d['MsgType']}")
abm.raw_message = d
@@ -214,7 +267,7 @@ class SimpleGewechatClient:
logger.debug(f"abm: {abm}")
return abm
async def callback(self):
async def _callback(self):
data = await quart.request.json
logger.debug(f"收到 gewechat 回调: {data}")
@@ -236,7 +289,7 @@ class SimpleGewechatClient:
return quart.jsonify({"r": "AstrBot ACK"})
async def handle_file(self, file_id):
async def _handle_file(self, file_id):
file_path = f"data/temp/{file_id}"
return await quart.send_file(file_path)
@@ -262,17 +315,14 @@ class SimpleGewechatClient:
await self.server.run_task(
host="0.0.0.0",
port=self.port,
shutdown_trigger=self.shutdown_trigger_placeholder,
shutdown_trigger=self.shutdown_trigger,
)
async def shutdown_trigger_placeholder(self):
# TODO: use asyncio.Event
while not self.event_queue.closed and not self.stop: # noqa: ASYNC110
await asyncio.sleep(1)
logger.info("gewechat 适配器已关闭。")
async def shutdown_trigger(self):
await self.shutdown_event.wait()
async def check_online(self, appid: str):
# /login/checkOnline
"""检查 APPID 对应的设备是否在线。"""
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/login/checkOnline",
@@ -283,6 +333,7 @@ class SimpleGewechatClient:
return json_blob["data"]
async def logout(self):
"""登出 gewechat。"""
if self.appid:
online = await self.check_online(self.appid)
if online:
@@ -296,6 +347,7 @@ class SimpleGewechatClient:
logger.info(f"登出结果: {json_blob}")
async def login(self):
"""登录 gewechat。一般来说插件用不到这个方法。"""
if self.token is None:
await self.get_token_id()
@@ -304,32 +356,49 @@ class SimpleGewechatClient:
)
if self.appid:
online = await self.check_online(self.appid)
if online:
logger.info(f"APPID: {self.appid} 已在线")
return
try:
online = await self.check_online(self.appid)
if online:
logger.info(f"APPID: {self.appid} 已在线")
return
except Exception as e:
logger.error(f"检查在线状态失败: {e}")
sp.put(f"gewechat-appid-{self.nickname}", "")
self.appid = None
payload = {"appId": self.appid}
if self.appid:
logger.info(f"使用 APPID: {self.appid}, {self.nickname}")
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/login/getLoginQrCode",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
if json_blob["ret"] != 200:
raise Exception(f"获取二维码失败: {json_blob}")
qr_data = json_blob["data"]["qrData"]
qr_uuid = json_blob["data"]["uuid"]
appid = json_blob["data"]["appId"]
logger.info(f"APPID: {appid}")
logger.warning(
f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}"
)
try:
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/login/getLoginQrCode",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
if json_blob["ret"] != 200:
error_msg = json_blob.get("data", {}).get("msg", "")
if "设备不存在" in error_msg:
logger.error(
f"检测到无效的appid: {self.appid},将清除并重新登录。"
)
sp.put(f"gewechat-appid-{self.nickname}", "")
self.appid = None
return await self.login()
else:
raise Exception(f"获取二维码失败: {json_blob}")
qr_data = json_blob["data"]["qrData"]
qr_uuid = json_blob["data"]["uuid"]
appid = json_blob["data"]["appId"]
logger.info(f"APPID: {appid}")
logger.warning(
f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}"
)
except Exception as e:
raise e
# 执行登录
retry_cnt = 64
@@ -390,9 +459,18 @@ class SimpleGewechatClient:
self.appid = appid
logger.info(f"已保存 APPID: {appid}")
"""API"""
"""API 部分。Gewechat 的 API 文档请参考: https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1
"""
async def get_chatroom_member_list(self, chatroom_wxid: str):
async def get_chatroom_member_list(self, chatroom_wxid: str) -> dict:
"""获取群成员列表。
Args:
chatroom_wxid (str): 微信群聊的id。可以通过 event.get_group_id() 获取。
Returns:
dict: 返回群成员列表字典。其中键为 memberList 的值为群成员列表。
"""
payload = {"appId": self.appid, "chatroomId": chatroom_wxid}
async with aiohttp.ClientSession() as session:
@@ -405,6 +483,7 @@ class SimpleGewechatClient:
return json_blob["data"]
async def post_text(self, to_wxid, content: str, ats: str = ""):
"""发送纯文本消息"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
@@ -421,6 +500,7 @@ class SimpleGewechatClient:
logger.debug(f"发送消息结果: {json_blob}")
async def post_image(self, to_wxid, image_url: str):
"""发送图片消息"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
@@ -434,7 +514,79 @@ class SimpleGewechatClient:
json_blob = await resp.json()
logger.debug(f"发送图片结果: {json_blob}")
async def post_emoji(self, to_wxid, emoji_md5, emoji_size, cdnurl=""):
"""发送emoji消息"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"emojiMd5": emoji_md5,
"emojiSize": emoji_size,
}
# 优先表情包,若拿不到表情包的md5,就用当作图片发
try:
if emoji_md5 != "" and emoji_size != "":
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postEmoji",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.info(
f"发送emoji消息结果: {json_blob.get('msg', '操作失败')}"
)
else:
await self.post_image(to_wxid, cdnurl)
except Exception as e:
logger.error(e)
async def post_video(
self, to_wxid, video_url: str, thumb_url: str, video_duration: int
):
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"videoUrl": video_url,
"thumbUrl": thumb_url,
"videoDuration": video_duration,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postVideo", headers=self.headers, json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送视频结果: {json_blob}")
async def forward_video(self, to_wxid, cnd_xml: str):
"""转发视频
Args:
to_wxid (str): 发送给谁
cnd_xml (str): 视频消息的cdn信息
"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"xml": cnd_xml,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/forwardVideo",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"转发视频结果: {json_blob}")
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
"""发送语音信息
Args:
voice_url (str): 语音文件的网络链接
voice_duration (int): 语音时长,毫秒
"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
@@ -449,9 +601,16 @@ class SimpleGewechatClient:
f"{self.base_url}/message/postVoice", headers=self.headers, json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送语音结果: {json_blob}")
logger.info(f"发送语音结果: {json_blob.get('msg', '操作失败')}")
async def post_file(self, to_wxid, file_url: str, file_name: str):
"""发送文件
Args:
to_wxid (string): 微信ID
file_url (str): 文件的网络链接
file_name (str): 文件名
"""
payload = {
"appId": self.appid,
"toWxid": to_wxid,
@@ -465,3 +624,131 @@ class SimpleGewechatClient:
) as resp:
json_blob = await resp.json()
logger.debug(f"发送文件结果: {json_blob}")
async def add_friend(self, v3: str, v4: str, content: str):
"""申请添加好友"""
payload = {
"appId": self.appid,
"scene": 3,
"content": content,
"v4": v4,
"v3": v3,
"option": 2,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/contacts/addContacts",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"申请添加好友结果: {json_blob}")
return json_blob
async def get_group(self, group_id: str):
payload = {
"appId": self.appid,
"chatroomId": group_id,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/getChatroomInfo",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def get_group_member(self, group_id: str):
payload = {
"appId": self.appid,
"chatroomId": group_id,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/getChatroomMemberList",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def accept_group_invite(self, url: str):
"""同意进群"""
payload = {"appId": self.appid, "url": url}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/agreeJoinRoom",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def add_group_member_to_friend(
self, group_id: str, to_wxid: str, content: str
):
payload = {
"appId": self.appid,
"chatroomId": group_id,
"content": content,
"memberWxid": to_wxid,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/group/addGroupMemberAsFriend",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def get_user_or_group_info(self, *ids):
"""
获取用户或群组信息。
:param ids: 可变数量的 wxid 参数
"""
wxids_str = list(ids)
payload = {
"appId": self.appid,
"wxids": wxids_str, # 使用逗号分隔的字符串
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/contacts/getDetailInfo",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取群信息结果: {json_blob}")
return json_blob
async def get_contacts_list(self):
"""
获取通讯录列表
见 https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1/api-196794504
"""
payload = {"appId": self.appid}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/contacts/fetchContactsList",
headers=self.headers,
json=payload,
) as resp:
json_blob = await resp.json()
logger.debug(f"获取通讯录列表结果: {json_blob}")
return json_blob
@@ -39,3 +39,17 @@ class GeweDownloader:
continue
raise Exception("无法下载图片")
async def download_emoji_md5(self, app_id, emoji_md5):
"""下载emoji"""
try:
payload = {"appId": app_id, "emojiMd5": emoji_md5}
# gewe 计划中的接口,暂时没有实现。返回代码404
data = await self._post_json(
self.base_url, "/message/downloadEmojiMd5", payload
)
json_blob = json.loads(data)
return json_blob
except BaseException as e:
logger.error(f"gewe download emoji: {e}")
@@ -2,12 +2,21 @@ import wave
import uuid
import traceback
import os
from astrbot.core.utils.io import save_temp_img, download_image_by_url, download_file
from astrbot.core.utils.io import save_temp_img, download_file
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, Record, At, File
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, Group, MessageMember
from astrbot.api.message_components import (
Plain,
Image,
Record,
At,
File,
Video,
WechatEmoji as Emoji,
)
from .client import SimpleGewechatClient
@@ -37,12 +46,9 @@ class GewechatPlatformEvent(AstrMessageEvent):
self.client = client
@staticmethod
async def send_with_client(message: MessageChain, user_name: str):
pass
async def send(self, message: MessageChain):
to_wxid = self.message_obj.raw_message.get("to_wxid", None)
async def send_with_client(
message: MessageChain, to_wxid: str, client: SimpleGewechatClient
):
if not to_wxid:
logger.error("无法获取到 to_wxid。")
return
@@ -70,56 +76,94 @@ class GewechatPlatformEvent(AstrMessageEvent):
payload["content"] = text
payload["ats"] = ats
has_at = True
await self.client.post_text(**payload)
await client.post_text(**payload)
elif isinstance(comp, Image):
img_url = comp.file
img_path = ""
if img_url.startswith("file:///"):
img_path = img_url[8:]
elif comp.file and comp.file.startswith("http"):
img_path = await download_image_by_url(comp.file)
else:
img_path = img_url
img_path = await comp.convert_to_file_path()
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
# 检查 record_path 是否在 data/temp 目录中
temp_directory = os.path.abspath("data/temp")
img_path = os.path.abspath(img_path)
if os.path.commonpath([temp_directory, img_path]) != temp_directory:
with open(img_path, "rb") as f:
img_path = save_temp_img(f.read())
file_id = os.path.basename(img_path)
img_url = f"{self.client.file_server_url}/{file_id}"
img_url = f"{client.file_server_url}/{file_id}"
logger.debug(f"gewe callback img url: {img_url}")
await self.client.post_image(to_wxid, img_url)
await client.post_image(to_wxid, img_url)
elif isinstance(comp, Video):
if comp.cover != "":
await client.forward_video(to_wxid, comp.cover)
else:
try:
from pyffmpeg import FFmpeg
except (ImportError, ModuleNotFoundError):
logger.error(
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
)
raise ModuleNotFoundError(
"需要安装 pyffmpeg 库才能发送视频: pip install pyffmpeg"
)
video_url = comp.file
# 根据 url 下载视频
video_filename = f"{uuid.uuid4()}.mp4"
video_path = f"data/temp/{video_filename}"
await download_file(video_url, video_path)
# 获取视频第一帧
thumb_path = f"data/temp/{uuid.uuid4()}.jpg"
try:
ff = FFmpeg()
command = f'-i "{video_path}" -ss 0 -vframes 1 "{thumb_path}"'
ff.options(command)
thumb_file_id = os.path.basename(thumb_path)
thumb_url = f"{client.file_server_url}/{thumb_file_id}"
except Exception as e:
logger.error(f"获取视频第一帧失败: {e}")
# 获取视频时长
try:
from pyffmpeg import FFprobe
# 创建 FFprobe 实例
ffprobe = FFprobe(video_url)
# 获取时长字符串
duration_str = ffprobe.duration
# 处理时长字符串
video_duration = float(duration_str.replace(":", ""))
except Exception as e:
logger.error(f"获取时长失败: {e}")
video_duration = 10
file_id = os.path.basename(video_path)
video_url = f"{client.file_server_url}/{file_id}"
await client.post_video(
to_wxid, video_url, thumb_url, video_duration
)
# 删除临时视频和缩略图文件
if os.path.exists(video_path):
os.remove(video_path)
if os.path.exists(thumb_path):
os.remove(thumb_path)
elif isinstance(comp, Record):
# 默认已经存在 data/temp 中
record_url = comp.file
record_path = ""
if record_url.startswith("file:///"):
record_path = record_url[8:]
elif record_url.startswith("http"):
await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav")
else:
record_path = record_url
record_path = await comp.convert_to_file_path()
silk_path = f"data/temp/{uuid.uuid4()}.silk"
try:
duration = await wav_to_tencent_silk(record_path, silk_path)
except Exception as e:
logger.error(traceback.format_exc())
await self.send(
MessageChain().message(f"语音文件转换失败。{str(e)}")
)
await client.post_text(to_wxid, f"语音文件转换失败。{str(e)}")
logger.info("Silk 语音文件格式转换至: " + record_path)
if duration == 0:
duration = get_wav_duration(record_path)
file_id = os.path.basename(silk_path)
record_url = f"{self.client.file_server_url}/{file_id}"
record_url = f"{client.file_server_url}/{file_id}"
logger.debug(f"gewe callback record url: {record_url}")
await self.client.post_voice(to_wxid, record_url, duration * 1000)
await client.post_voice(to_wxid, record_url, duration * 1000)
elif isinstance(comp, File):
file_path = comp.file
file_name = comp.name
@@ -131,12 +175,44 @@ class GewechatPlatformEvent(AstrMessageEvent):
file_path = file_path
file_id = os.path.basename(file_path)
file_url = f"{self.client.file_server_url}/{file_id}"
file_url = f"{client.file_server_url}/{file_id}"
logger.debug(f"gewe callback file url: {file_url}")
await self.client.post_file(to_wxid, file_url, file_id)
await client.post_file(to_wxid, file_url, file_id)
elif isinstance(comp, Emoji):
await client.post_emoji(to_wxid, comp.md5, comp.md5_len, comp.cdnurl)
elif isinstance(comp, At):
pass
else:
logger.debug(f"gewechat 忽略: {comp.type}")
async def send(self, message: MessageChain):
to_wxid = self.message_obj.raw_message.get("to_wxid", None)
await GewechatPlatformEvent.send_with_client(message, to_wxid, self.client)
await super().send(message)
async def get_group(self, group_id=None, **kwargs):
# 确定有效的 group_id
if group_id is None:
group_id = self.get_group_id()
if not group_id:
return None
res = await self.client.get_group(group_id)
data: dict = res["data"]
if not data["chatroomId"]:
return None
members = [
MessageMember(user_id=member["wxid"], nickname=member["nickName"])
for member in data.get("memberList", [])
]
return Group(
group_id=data["chatroomId"],
group_name=data.get("nickName"),
group_avatar=data.get("smallHeadImgUrl"),
group_owner=data.get("chatRoomOwner"),
members=members,
)
@@ -4,12 +4,11 @@ import os
from astrbot.api.platform import Platform, AstrBotMessage, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from astrbot.api import logger
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from .gewechat_event import GewechatPlatformEvent
from .client import SimpleGewechatClient
from astrbot.core.message.components import Plain
from astrbot import logger
if sys.version_info >= (3, 12):
from typing import override
@@ -45,14 +44,16 @@ class GewechatPlatformAdapter(Platform):
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
):
to_wxid = session.session_id
if not to_wxid:
logger.error("无法获取到 to_wxid。")
return
session_id = session.session_id
if "#" in session_id:
# unique session
to_wxid = session_id.split("#")[1]
else:
to_wxid = session_id
for comp in message_chain.chain:
if isinstance(comp, Plain):
await self.client.post_text(to_wxid, comp.text)
await GewechatPlatformEvent.send_with_client(
message_chain, to_wxid, self.client
)
await super().send_by_session(session, message_chain)
@@ -64,8 +65,9 @@ class GewechatPlatformAdapter(Platform):
)
async def terminate(self):
self.client.stop = True
await asyncio.sleep(1)
self.client.shutdown_event.set()
await self.client.server.shutdown()
logger.info("Gewechat 适配器已被优雅地关闭。")
async def logout(self):
await self.client.logout()
@@ -81,7 +83,7 @@ class GewechatPlatformAdapter(Platform):
async def handle_msg(self, message: AstrBotMessage):
if message.type == MessageType.GROUP_MESSAGE:
if self.settingss["unique_session"]:
message.session_id = message.sender.user_id + "_" + message.group_id
message.session_id = message.sender.user_id + "#" + message.group_id
message_event = GewechatPlatformEvent(
message_str=message.message_str,
@@ -0,0 +1,78 @@
from defusedxml import ElementTree as eT
from astrbot.api import logger
from astrbot.api.message_components import WechatEmoji as Emoji, Reply, Plain
class GeweDataParser:
def __init__(self, data, is_private_chat):
self.data = data
self.is_private_chat = is_private_chat
def _format_to_xml(self):
return eT.fromstring(self.data)
def parse_mutil_49(self):
appmsg_type = self._format_to_xml().find(".//appmsg/type")
if appmsg_type is None:
return
match appmsg_type.text:
case "57":
return self.parse_reply()
def parse_emoji(self) -> Emoji | None:
try:
emoji_element = self._format_to_xml().find(".//emoji")
# 提取 md5 和 len 属性
if emoji_element is not None:
md5_value = emoji_element.get("md5")
emoji_size = emoji_element.get("len")
cdnurl = emoji_element.get("cdnurl")
return Emoji(md5=md5_value, md5_len=emoji_size, cdnurl=cdnurl)
except Exception as e:
logger.error(f"gewechat: parse_emoji failed, {e}")
def parse_reply(self) -> Reply | None:
try:
replied_id = -1
replied_uid = 0
replied_nickname = ""
replied_content = ""
content = ""
root = self._format_to_xml()
refermsg = root.find(".//refermsg")
if refermsg is not None:
# 被引用的信息
svrid = refermsg.find("svrid")
fromusr = refermsg.find("fromusr")
displayname = refermsg.find("displayname")
refermsg_content = refermsg.find("content")
if svrid is not None:
replied_id = svrid.text
if fromusr is not None:
replied_uid = fromusr.text
if displayname is not None:
replied_nickname = displayname.text
if refermsg_content is not None:
replied_content = refermsg_content.text
# 提取引用者说的内容
title = root.find(".//appmsg/title")
if title is not None:
content = title.text
r = Reply(
id=replied_id,
chain=[Plain(content)],
sender_id=replied_uid,
sender_nickname=replied_nickname,
sender_str=replied_content,
message_str=content,
)
return r
except Exception as e:
logger.error(f"gewechat: parse_reply failed, {e}")
@@ -2,6 +2,7 @@ import base64
import asyncio
import json
import re
import astrbot.api.message_components as Comp
from astrbot.api.platform import (
Platform,
@@ -11,7 +12,6 @@ from astrbot.api.platform import (
PlatformMetadata,
)
from astrbot.api.event import MessageChain
from astrbot.api.message_components import Image, Plain, At
from astrbot.core.platform.astr_message_event import MessageSesion
from .lark_event import LarkMessageEvent
from ...register import register_platform_adapter
@@ -66,7 +66,7 @@ class LarkPlatformAdapter(Platform):
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
):
raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session")
raise NotImplementedError("Lark 适配器不支持 send_by_session")
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
@@ -92,7 +92,7 @@ class LarkPlatformAdapter(Platform):
at_list = {}
if message.mentions:
for m in message.mentions:
at_list[m.key] = At(qq=m.id.open_id, name=m.name)
at_list[m.key] = Comp.At(qq=m.id.open_id, name=m.name)
if m.name == self.bot_name:
abm.self_id = m.id.open_id
@@ -111,7 +111,7 @@ class LarkPlatformAdapter(Platform):
if s in at_list:
abm.message.append(at_list[s])
else:
abm.message.append(Plain(parts[i].strip()))
abm.message.append(Comp.Plain(parts[i].strip()))
elif message.message_type == "post":
_ls = []
@@ -132,7 +132,7 @@ class LarkPlatformAdapter(Platform):
if comp["tag"] == "at":
abm.message.append(at_list[comp["user_id"]])
elif comp["tag"] == "text" and comp["text"].strip():
abm.message.append(Plain(comp["text"].strip()))
abm.message.append(Comp.Plain(comp["text"].strip()))
elif comp["tag"] == "img":
image_key = comp["image_key"]
request = (
@@ -147,10 +147,10 @@ class LarkPlatformAdapter(Platform):
logger.error(f"无法下载飞书图片: {image_key}")
image_bytes = response.file.read()
image_base64 = base64.b64encode(image_bytes).decode()
abm.message.append(Image.fromBase64(image_base64))
abm.message.append(Comp.Image.fromBase64(image_base64))
for comp in abm.message:
if isinstance(comp, Plain):
if isinstance(comp, Comp.Plain):
abm.message_str += comp.text
abm.message_id = message.message_id
abm.raw_message = message
@@ -185,5 +185,9 @@ class LarkPlatformAdapter(Platform):
# self.client.start()
await self.client._connect()
async def terminate(self):
await self.client._disconnect()
logger.info("飞书(Lark) 适配器已被优雅地关闭")
def get_client(self) -> lark.Client:
return self.client
@@ -122,16 +122,16 @@ class QQOfficialMessageEvent(AstrMessageEvent):
plain_text += i.text
elif isinstance(i, Image) and not image_base64:
if i.file and i.file.startswith("file:///"):
image_base64 = file_to_base64(i.file[8:]).replace("base64://", "")
image_base64 = file_to_base64(i.file[8:])
image_file_path = i.file[8:]
elif i.file and i.file.startswith("http"):
image_file_path = await download_image_by_url(i.file)
image_base64 = file_to_base64(image_file_path).replace(
"base64://", ""
)
image_base64 = file_to_base64(image_file_path)
elif i.file and i.file.startswith("base64://"):
image_base64 = i.file
else:
image_base64 = file_to_base64(i.file).replace("base64://", "")
image_file_path = i.file
image_base64 = file_to_base64(i.file)
image_base64 = image_base64.removeprefix("base64://")
else:
logger.debug(f"qq_official 忽略 {i.type}")
return plain_text, image_base64, image_file_path
@@ -17,6 +17,7 @@ from astrbot.api.platform import (
MessageType,
PlatformMetadata,
)
from astrbot import logger
from astrbot.api.event import MessageChain
from typing import Union, List
from astrbot.api.message_components import Image, Plain, At
@@ -204,3 +205,7 @@ class QQOfficialPlatformAdapter(Platform):
def get_client(self) -> botClient:
return self.client
async def terminate(self):
await self.client.close()
logger.info("QQ 官方机器人接口 适配器已被优雅地关闭")
@@ -13,6 +13,7 @@ from .qo_webhook_event import QQOfficialWebhookMessageEvent
from ...register import register_platform_adapter
from .qo_webhook_server import QQOfficialWebhook
from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter
from astrbot import logger
# remove logger handler
for handler in logging.root.handlers[:]:
@@ -111,3 +112,9 @@ class QQOfficialWebhookPlatformAdapter(Platform):
def get_client(self) -> botClient:
return self.client
async def terminate(self):
self.webhook_helper.shutdown_event.set()
await self.client.close()
await self.webhook_helper.server.shutdown()
logger.info("QQ 机器人官方 API 适配器已经被优雅地关闭")
@@ -15,6 +15,7 @@ class QQOfficialWebhook:
self.appid = config["appid"]
self.secret = config["secret"]
self.port = config.get("port", 6196)
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
if isinstance(self.port, str):
self.port = int(self.port)
@@ -29,6 +30,7 @@ class QQOfficialWebhook:
)
self.client = botpy_client
self.event_queue = event_queue
self.shutdown_event = asyncio.Event()
async def initialize(self):
logger.info("正在登录到 QQ 官方机器人...")
@@ -95,13 +97,14 @@ class QQOfficialWebhook:
return {"opcode": 12}
async def start_polling(self):
logger.info(
f"将在 {self.callback_server_host}:{self.port} 端口启动 QQ 官方机器人 webhook 适配器。"
)
await self.server.run_task(
host="0.0.0.0",
host=self.callback_server_host,
port=self.port,
shutdown_trigger=self.shutdown_trigger_placeholder,
shutdown_trigger=self.shutdown_trigger,
)
async def shutdown_trigger_placeholder(self):
while not self.event_queue.closed: # noqa: ASYNC110
await asyncio.sleep(1)
logger.info("qq_official_webhook 适配器已关闭。")
async def shutdown_trigger(self):
await self.shutdown_event.wait()
@@ -1,6 +1,7 @@
import sys
import uuid
import asyncio
import astrbot.api.message_components as Comp
from astrbot.api.platform import (
Platform,
@@ -10,14 +11,6 @@ from astrbot.api.platform import (
MessageType,
)
from astrbot.api.event import MessageChain
from astrbot.api.message_components import (
Plain,
Image,
Record,
File as AstrBotFile,
Video,
At,
)
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.api.platform import register_platform_adapter
@@ -50,18 +43,29 @@ class TelegramPlatformAdapter(Platform):
)
if not base_url:
base_url = "https://api.telegram.org/bot"
file_base_url = self.config.get(
"telegram_file_base_url", "https://api.telegram.org/file/bot"
)
if not file_base_url:
file_base_url = "https://api.telegram.org/file/bot"
self.base_url = base_url
self.application = (
ApplicationBuilder()
.token(self.config["telegram_token"])
.base_url(base_url)
.base_file_url(file_base_url)
.build()
)
message_handler = TelegramMessageHandler(
filters=filters.ALL, # receive all messages
callback=self.convert_message,
callback=self.message_handler,
)
self.application.add_handler(message_handler)
self.client = self.application.bot
logger.debug(f"Telegram base url: {self.client.base_url}")
@override
async def send_by_session(
@@ -93,70 +97,135 @@ class TelegramPlatformAdapter(Platform):
chat_id=update.effective_chat.id, text=self.config["start_message"]
)
async def message_handler(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
logger.debug(f"Telegram message: {update.message}")
abm = await self.convert_message(update, context)
if abm:
await self.handle_msg(abm)
async def convert_message(
self, update: Update, context: ContextTypes.DEFAULT_TYPE
self, update: Update, context: ContextTypes.DEFAULT_TYPE, get_reply=True
) -> AstrBotMessage:
"""转换 Telegram 的消息对象为 AstrBotMessage 对象。
@param update: Telegram 的 Update 对象。
@param context: Telegram 的 Context 对象。
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
"""
message = AstrBotMessage()
message.session_id = str(update.message.chat.id)
# 获得是群聊还是私聊
if update.effective_chat.type == ChatType.PRIVATE:
if update.message.chat.type == ChatType.PRIVATE:
message.type = MessageType.FRIEND_MESSAGE
else:
message.type = MessageType.GROUP_MESSAGE
message.group_id = update.effective_chat.id
message.group_id = str(update.message.chat.id)
if update.message.message_thread_id:
# Topic Group
message.group_id += "#" + str(update.message.message_thread_id)
message.session_id = message.group_id
message.message_id = str(update.message.message_id)
message.session_id = str(update.effective_chat.id)
message.sender = MessageMember(
str(update.effective_user.id), update.effective_user.username
str(update.message.from_user.id), update.message.from_user.username
)
message.self_id = str(context.bot.username)
message.raw_message = update
message.message_str = ""
message.message = []
logger.debug(f"Telegram message: {update.message}")
if update.message.reply_to_message and not (
update.message.is_topic_message
and update.message.message_thread_id
== update.message.reply_to_message.message_id
):
# 获取回复消息
reply_update = Update(
update_id=1,
message=update.message.reply_to_message,
)
reply_abm = await self.convert_message(reply_update, context, False)
message.message.append(
Comp.Reply(
id=reply_abm.message_id,
chain=reply_abm.message,
sender_id=reply_abm.sender.user_id,
sender_nickname=reply_abm.sender.nickname,
time=reply_abm.timestamp,
message_str=reply_abm.message_str,
text=reply_abm.message_str,
qq=reply_abm.sender.user_id,
)
)
if update.message.text:
# 处理文本消息
plain_text = update.message.text
if update.message.entities:
for entity in update.message.entities:
if entity.type == "mention":
name = plain_text[entity.offset+1 : entity.offset + entity.length]
message.message.append(At(qq=name, name=name))
name = plain_text[
entity.offset + 1 : entity.offset + entity.length
]
message.message.append(Comp.At(qq=name, name=name))
plain_text = (
plain_text[: entity.offset]
+ plain_text[entity.offset + entity.length :]
)
message.message.append(Plain(plain_text))
if plain_text:
message.message.append(Comp.Plain(plain_text))
message.message_str = plain_text
if message.message_str.strip() == "/start":
await self.start(update, context)
return
elif update.message.voice:
file = await update.message.voice.get_file()
message.message = [
Record(file=file.file_path, url=file.file_path),
Comp.Record(file=file.file_path, url=file.file_path),
]
elif update.message.photo:
photo = update.message.photo[-1] # get the largest photo
file = await photo.get_file()
message.message.append(Image(file=file.file_path, url=file.file_path))
message.message.append(Comp.Image(file=file.file_path, url=file.file_path))
if update.message.caption:
message.message_str = update.message.caption
message.message.append(Comp.Plain(message.message_str))
if update.message.caption_entities:
for entity in update.message.caption_entities:
if entity.type == "mention":
name = message.message_str[
entity.offset + 1 : entity.offset + entity.length
]
message.message.append(Comp.At(qq=name, name=name))
elif update.message.sticker:
# 将sticker当作图片处理
file = await update.message.sticker.get_file()
message.message.append(Comp.Image(file=file.file_path, url=file.file_path))
if update.message.sticker.emoji:
sticker_text = f"Sticker: {update.message.sticker.emoji}"
message.message_str = sticker_text
message.message.append(Comp.Plain(sticker_text))
elif update.message.document:
file = await update.message.document.get_file()
message.message = [
AstrBotFile(
file=file.file_path, name=update.message.document.file_name
),
Comp.File(file=file.file_path, name=update.message.document.file_name),
]
elif update.message.video:
file = await update.message.video.get_file()
message.message = [
Video(file=file.file_path, path=file.file_path),
Comp.Video(file=file.file_path, path=file.file_path),
]
await self.handle_msg(message)
return message
async def handle_msg(self, message: AstrBotMessage):
message_event = TelegramPlatformEvent(
@@ -170,3 +239,15 @@ class TelegramPlatformAdapter(Platform):
def get_client(self) -> ExtBot:
return self.client
async def terminate(self):
try:
await self.application.stop()
# 保险起见先判断是否存在updater对象
if self.application.updater is not None:
await self.application.updater.stop()
logger.info("Telegram 适配器已被优雅地关闭")
except Exception as e:
logger.error(f"Telegram 适配器关闭时出错: {e}")
@@ -1,7 +1,10 @@
import telegramify_markdown
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, MessageType
from astrbot.api.message_components import Plain, Image, Reply, At, File, Record
from telegram.ext import ExtBot
from astrbot.core.utils.io import download_file
from astrbot import logger
class TelegramPlatformEvent(AstrMessageEvent):
@@ -31,36 +34,47 @@ class TelegramPlatformEvent(AstrMessageEvent):
at_user_id = i.name
at_flag = False
message_thread_id = None
if "#" in user_name:
# it's a supergroup chat with message_thread_id
user_name, message_thread_id = user_name.split("#")
for i in message.chain:
payload = {
"chat_id": user_name,
}
if has_reply:
payload["reply_to_message_id"] = reply_message_id
if message_thread_id:
payload["message_thread_id"] = message_thread_id
if isinstance(i, Plain):
if at_user_id and not at_flag:
i.text = f"@{at_user_id} " + i.text
at_flag = True
await client.send_message(text=i.text, **payload)
text = i.text
try:
text = telegramify_markdown.markdownify(
i.text, max_line_length=None, normalize_whitespace=False
)
except Exception as e:
logger.warning(
f"MarkdownV2 conversion failed: {e}. Using plain text instead."
)
return
await client.send_message(text=text, parse_mode="MarkdownV2", **payload)
elif isinstance(i, Image):
if i.path:
image_path = i.path
else:
image_path = i.file
if image_path.startswith("base64://"):
import base64
base64_data = image_path[9:]
image_bytes = base64.b64decode(base64_data)
await client.send_photo(photo=image_bytes, **payload)
else:
await client.send_photo(photo=image_path, **payload)
image_path = await i.convert_to_file_path()
await client.send_photo(photo=image_path, **payload)
elif isinstance(i, File):
if i.file.startswith("https://"):
path = "data/temp/" + i.name
await download_file(i.file, path)
i.file = path
await client.send_document(document=i.file, filename=i.name, **payload)
elif isinstance(i, Record):
await client.send_voice(voice=i.file, **payload)
path = await i.convert_to_file_path()
await client.send_voice(voice=path, **payload)
async def send(self, message: MessageChain):
if self.get_message_type() == MessageType.GROUP_MESSAGE:
@@ -13,7 +13,7 @@ from astrbot.core.platform import (
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.message.components import Plain, Image, Record # noqa: F403
from astrbot import logger
from astrbot.core import web_chat_queue, web_chat_back_queue
from astrbot.core import web_chat_queue
from .webchat_event import WebChatMessageEvent
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
@@ -50,14 +50,7 @@ class WebChatAdapter(Platform):
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
):
# abm.session_id = f"webchat!{username}!{cid}"
plain = ""
cid = session.session_id.split("!")[-1]
for comp in message_chain.chain:
if isinstance(comp, Plain):
plain += comp.text
web_chat_back_queue.put_nowait((plain, cid))
await WebChatMessageEvent._send(message_chain, session.session_id)
await super().send_by_session(session, message_chain)
async def convert_message(self, data: tuple) -> AstrBotMessage:
@@ -126,3 +119,7 @@ class WebChatAdapter(Platform):
)
self.commit_event(message_event)
async def terminate(self):
# Do nothing
pass
@@ -3,23 +3,25 @@ import uuid
import base64
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image
from astrbot.api.message_components import Plain, Image, Record
from astrbot.core.utils.io import download_image_by_url
from astrbot.core import web_chat_back_queue
imgs_dir = "data/webchat/imgs"
class WebChatMessageEvent(AstrMessageEvent):
def __init__(self, message_str, message_obj, platform_meta, session_id):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.imgs_dir = "data/webchat/imgs"
os.makedirs(self.imgs_dir, exist_ok=True)
os.makedirs(imgs_dir, exist_ok=True)
async def send(self, message: MessageChain):
@staticmethod
async def _send(message: MessageChain, session_id: str):
if not message:
web_chat_back_queue.put_nowait(None)
return
cid = self.session_id.split("!")[-1]
cid = session_id.split("!")[-1]
for comp in message.chain:
if isinstance(comp, Plain):
@@ -27,7 +29,7 @@ class WebChatMessageEvent(AstrMessageEvent):
elif isinstance(comp, Image):
# save image to local
filename = str(uuid.uuid4()) + ".jpg"
path = os.path.join(self.imgs_dir, filename)
path = os.path.join(imgs_dir, filename)
if comp.file and comp.file.startswith("file:///"):
ph = comp.file[8:]
with open(path, "wb") as f:
@@ -45,7 +47,26 @@ class WebChatMessageEvent(AstrMessageEvent):
with open(comp.file, "rb") as f2:
f.write(f2.read())
web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid))
elif isinstance(comp, Record):
# save record to local
filename = str(uuid.uuid4()) + ".wav"
path = os.path.join(imgs_dir, filename)
if comp.file and comp.file.startswith("file:///"):
ph = comp.file[8:]
with open(path, "wb") as f:
with open(ph, "rb") as f2:
f.write(f2.read())
elif comp.file and comp.file.startswith("http"):
await download_image_by_url(comp.file, path=path)
else:
with open(path, "wb") as f:
with open(comp.file, "rb") as f2:
f.write(f2.read())
web_chat_back_queue.put_nowait((f"[RECORD]{filename}", cid))
else:
logger.debug(f"webchat 忽略: {comp.type}")
web_chat_back_queue.put_nowait(None)
async def send(self, message: MessageChain):
await WebChatMessageEvent._send(message, session_id=self.session_id)
await super().send(message)
@@ -34,6 +34,7 @@ class WecomServer:
def __init__(self, event_queue: asyncio.Queue, config: dict):
self.server = quart.Quart(__name__)
self.port = int(config.get("port"))
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
self.server.add_url_rule(
"/callback/command", view_func=self.verify, methods=["GET"]
)
@@ -49,6 +50,7 @@ class WecomServer:
)
self.callback = None
self.shutdown_event = asyncio.Event()
async def verify(self):
logger.info(f"验证请求有效性: {quart.request.args}")
@@ -86,17 +88,17 @@ class WecomServer:
return "success"
async def start_polling(self):
logger.info(f"将在 0.0.0.0:{self.port} 端口启动 企业微信 适配器。")
logger.info(
f"将在 {self.callback_server_host}:{self.port} 端口启动 企业微信 适配器。"
)
await self.server.run_task(
host="0.0.0.0",
host=self.callback_server_host,
port=self.port,
shutdown_trigger=self.shutdown_trigger_placeholder,
shutdown_trigger=self.shutdown_trigger,
)
async def shutdown_trigger_placeholder(self):
while not self.event_queue.closed: # noqa: ASYNC110
await asyncio.sleep(1)
logger.info("企业微信 适配器已关闭。")
async def shutdown_trigger(self):
await self.shutdown_event.wait()
@register_platform_adapter("wecom", "wecom 适配器")
@@ -232,3 +234,8 @@ class WecomPlatformAdapter(Platform):
def get_client(self) -> WeChatClient:
return self.client
async def terminate(self):
self.server.shutdown_event.set()
await self.server.server.shutdown()
logger.info("企业微信 适配器已被优雅地关闭")
@@ -3,7 +3,6 @@ from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, Record
from wechatpy.enterprise import WeChatClient
from astrbot.core.utils.io import download_image_by_url, download_file
from astrbot.api import logger
@@ -43,14 +42,7 @@ class WecomPlatformEvent(AstrMessageEvent):
message_obj.self_id, message_obj.session_id, comp.text
)
elif isinstance(comp, Image):
img_url = comp.file
img_path = ""
if img_url.startswith("file:///"):
img_path = img_url[8:]
elif comp.file and comp.file.startswith("http"):
img_path = await download_image_by_url(comp.file)
else:
img_path = img_url
img_path = await comp.convert_to_file_path()
with open(img_path, "rb") as f:
try:
@@ -68,16 +60,7 @@ class WecomPlatformEvent(AstrMessageEvent):
response["media_id"],
)
elif isinstance(comp, Record):
record_url = comp.file
record_path = ""
if record_url.startswith("file:///"):
record_path = record_url[8:]
elif record_url.startswith("http"):
await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav")
else:
record_path = record_url
record_path = await comp.convert_to_file_path()
# 转成amr
record_path_amr = f"data/temp/{uuid.uuid4()}.amr"
pydub.AudioSegment.from_wav(record_path).export(
+206 -4
View File
@@ -1,9 +1,18 @@
import enum
import base64
import json
from astrbot.core.utils.io import download_image_by_url
from astrbot import logger
from dataclasses import dataclass, field
from typing import List, Dict, Type
from .func_tool_manager import FuncCall
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
)
from astrbot.core.db.po import Conversation
from astrbot.core.message.message_event_result import MessageChain
import astrbot.core.message.components as Comp
class ProviderType(enum.Enum):
@@ -27,6 +36,58 @@ class ProviderMetaData:
"""显示在 WebUI 配置页中的提供商名称,如空则是 type"""
@dataclass
class ToolCallMessageSegment:
"""OpenAI 格式的上下文中 role 为 tool 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
tool_call_id: str
content: str
role: str = "tool"
def to_dict(self):
return {
"tool_call_id": self.tool_call_id,
"content": self.content,
"role": self.role,
}
@dataclass
class AssistantMessageSegment:
"""OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
content: str = None
tool_calls: List[ChatCompletionMessageToolCall | Dict] = None
role: str = "assistant"
def to_dict(self):
ret = {
"role": self.role,
}
if self.content:
ret["content"] = self.content
elif self.tool_calls:
ret["tool_calls"] = self.tool_calls
return ret
@dataclass
class ToolCallsResult:
"""工具调用结果"""
tool_calls_info: AssistantMessageSegment
"""函数调用的信息"""
tool_calls_result: List[ToolCallMessageSegment]
"""函数调用的结果"""
def to_openai_messages(self) -> List[Dict]:
ret = [
self.tool_calls_info.to_dict(),
*[item.to_dict() for item in self.tool_calls_result],
]
return ret
@dataclass
class ProviderRequest:
prompt: str
@@ -36,7 +97,7 @@ class ProviderRequest:
image_urls: List[str] = None
"""图片 URL 列表"""
func_tool: FuncCall = None
"""工具"""
"""可用的函数工具"""
contexts: List = None
"""上下文。格式与 openai 的上下文格式一致:
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
@@ -45,23 +106,164 @@ class ProviderRequest:
"""系统提示词"""
conversation: Conversation = None
tool_calls_result: ToolCallsResult = None
"""附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls"""
def __repr__(self):
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self.contexts}, system_prompt={self.system_prompt})"
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
def __str__(self):
return self.__repr__()
def _print_friendly_context(self):
"""打印友好的消息上下文。将 image_url 的值替换为 <Image>"""
if not self.contexts:
return f"prompt: {self.prompt}, image_count: {len(self.image_urls or [])}"
result_parts = []
for ctx in self.contexts:
role = ctx.get("role", "unknown")
content = ctx.get("content", "")
if isinstance(content, str):
result_parts.append(f"{role}: {content}")
elif isinstance(content, list):
msg_parts = []
image_count = 0
for item in content:
item_type = item.get("type", "")
if item_type == "text":
msg_parts.append(item.get("text", ""))
elif item_type == "image_url":
image_count += 1
if image_count > 0:
if msg_parts:
msg_parts.append(f"[+{image_count} images]")
else:
msg_parts.append(f"[{image_count} images]")
result_parts.append(f"{role}: {''.join(msg_parts)}")
return result_parts
async def assemble_context(self) -> Dict:
"""将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。"""
if self.image_urls:
user_content = {
"role": "user",
"content": [{"type": "text", "text": self.prompt}],
}
for image_url in self.image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self._encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self._encode_image_bs64(image_path)
else:
image_data = await self._encode_image_bs64(image_url)
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
user_content["content"].append(
{"type": "image_url", "image_url": {"url": image_data}}
)
return user_content
else:
return {"role": "user", "content": self.prompt}
async def _encode_image_bs64(self, image_url: str) -> str:
"""将图片转换为 base64"""
if image_url.startswith("base64://"):
return image_url.replace("base64://", "data:image/jpeg;base64,")
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
return "data:image/jpeg;base64," + image_bs64
return ""
@dataclass
class LLMResponse:
role: str
"""角色, assistant, tool, err"""
completion_text: str = ""
"""LLM 返回的文本"""
result_chain: MessageChain = None
"""返回的消息链"""
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
"""工具调用参数"""
tools_call_name: List[str] = field(default_factory=list)
"""工具调用名称"""
tools_call_ids: List[str] = field(default_factory=list)
"""工具调用 ID"""
raw_completion: ChatCompletion = None
_new_record: Dict[str, any] = None
_completion_text: str = ""
def __init__(
self,
role: str,
completion_text: str = "",
result_chain: MessageChain = None,
tools_call_args: List[Dict[str, any]] = [],
tools_call_name: List[str] = [],
tools_call_ids: List[str] = [],
raw_completion: ChatCompletion = None,
_new_record: Dict[str, any] = None,
):
"""初始化 LLMResponse
Args:
role (str): 角色, assistant, tool, err
completion_text (str, optional): 返回的结果文本,已经过时,推荐使用 result_chain. Defaults to "".
result_chain (MessageChain, optional): 返回的消息链. Defaults to None.
tools_call_args (List[Dict[str, any]], optional): 工具调用参数. Defaults to None.
tools_call_name (List[str], optional): 工具调用名称. Defaults to None.
raw_completion (ChatCompletion, optional): 原始响应, OpenAI 格式. Defaults to None.
"""
self.role = role
self.completion_text = completion_text
self.result_chain = result_chain
self.tools_call_args = tools_call_args
self.tools_call_name = tools_call_name
self.tools_call_ids = tools_call_ids
self.raw_completion = raw_completion
self._new_record = _new_record
@property
def completion_text(self):
if self.result_chain:
return self.result_chain.get_plain_text()
return self._completion_text
@completion_text.setter
def completion_text(self, value):
if self.result_chain:
self.result_chain.chain = [
comp
for comp in self.result_chain.chain
if not isinstance(comp, Comp.Plain)
] # 清空 Plain 组件
self.result_chain.chain.insert(0, Comp.Plain(value))
else:
self._completion_text = value
def to_openai_tool_calls(self) -> List[Dict]:
"""将工具调用信息转换为 OpenAI 格式"""
ret = []
for idx, tool_call_arg in enumerate(self.tools_call_args):
ret.append(
{
"id": self.tools_call_ids[idx],
"function": {
"name": self.tools_call_name[idx],
"arguments": json.dumps(tool_call_arg),
},
"type": "function",
}
)
return ret
+288 -23
View File
@@ -1,7 +1,30 @@
from __future__ import annotations
import json
import textwrap
from typing import Dict, List, Awaitable
import os
import asyncio
import copy
from typing import Dict, List, Awaitable, Literal, Any
from dataclasses import dataclass
from typing import Optional
from contextlib import AsyncExitStack
from astrbot import logger
try:
import mcp
except (ModuleNotFoundError, ImportError):
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
SUPPORTED_TYPES = [
"string",
"number",
"object",
"array",
"boolean",
] # json schema 支持的数据类型
@dataclass
@@ -13,25 +36,101 @@ class FuncTool:
name: str
parameters: Dict
description: str
handler: Awaitable
handler_module_path: str = None # 必须要保留这个,handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
handler: Awaitable = None
"""处理函数, 当 origin 为 mcp 时,这个为空"""
handler_module_path: str = None
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
必须要保留这个字段, handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
"""
active: bool = True
"""是否激活"""
origin: Literal["local", "mcp"] = "local"
"""函数工具的来源, local 为本地函数工具, mcp 为 MCP 服务"""
SUPPORTED_TYPES = [
"string",
"number",
"object",
"array",
"boolean",
] # json schema 支持的数据类型
# MCP 相关字段
mcp_server_name: str = None
"""MCP 服务名称,当 origin 为 mcp 时有效"""
mcp_client: MCPClient = None
"""MCP 客户端,当 origin 为 mcp 时有效"""
def __repr__(self):
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description}, active={self.active}, origin={self.origin})"
async def execute(self, **args) -> Any:
"""执行函数调用"""
if self.origin == "local":
if not self.handler:
raise Exception(f"Local function {self.name} has no handler")
return await self.handler(**args)
elif self.origin == "mcp":
if not self.mcp_client or not self.mcp_client.session:
raise Exception(f"MCP client for {self.name} is not available")
# 使用name属性而不是额外的mcp_tool_name
if ":" in self.name:
# 如果名字是格式为 mcp:server:tool_name,提取实际的工具名
actual_tool_name = self.name.split(":")[-1]
return await self.mcp_client.session.call_tool(actual_tool_name, args)
else:
return await self.mcp_client.session.call_tool(self.name, args)
else:
raise Exception(f"Unknown function origin: {self.origin}")
class MCPClient:
def __init__(self):
# Initialize session and client objects
self.session: Optional[mcp.ClientSession] = None
self.exit_stack = AsyncExitStack()
self.name = None
self.active: bool = True
self.tools: List[mcp.Tool] = []
async def connect_to_server(self, mcp_server_config: dict):
"""Connect to an MCP server
Args:
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
"""
cfg = mcp_server_config.copy()
cfg.pop("active", None)
server_params = mcp.StdioServerParameters(
**cfg,
)
stdio_transport = await self.exit_stack.enter_async_context(
mcp.stdio_client(server_params)
)
self.stdio, self.write = stdio_transport
self.session = await self.exit_stack.enter_async_context(
mcp.ClientSession(self.stdio, self.write)
)
await self.session.initialize()
async def list_tools_and_save(self) -> mcp.ListToolsResult:
"""List all tools from the server and save them to self.tools"""
response = await self.session.list_tools()
logger.debug(f"MCP server {self.name} list tools response: {response}")
self.tools = response.tools
return response
async def cleanup(self):
"""Clean up resources"""
await self.exit_stack.aclose()
class FuncCall:
def __init__(self) -> None:
self.func_list: List[FuncTool] = []
"""内部加载的 func tools"""
self.mcp_client_dict: Dict[str, MCPClient] = {}
"""MCP 服务列表"""
self.mcp_service_queue = asyncio.Queue()
"""用于外部控制 MCP 服务的启停"""
self.mcp_client_event: Dict[str, asyncio.Event] = {}
def empty(self) -> bool:
return len(self.func_list) == 0
@@ -43,14 +142,16 @@ class FuncCall:
desc: str,
handler: Awaitable,
) -> None:
"""
为函数调用(function-calling / tools-use)添加工具。
"""添加函数调用工具
@param name: 函数名
@param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...]
@param desc: 函数描述
@param func_obj: 处理函数
"""
# check if the tool has been added before
self.remove_func(name)
params = {
"type": "object", # hard-coded here
"properties": {},
@@ -67,13 +168,14 @@ class FuncCall:
handler=handler,
)
self.func_list.append(_func)
logger.info(f"添加函数调用工具: {name}")
def remove_func(self, name: str) -> None:
"""
删除一个函数调用工具。
"""
for i, f in enumerate(self.func_list):
if f["name"] == name:
if f.name == name:
self.func_list.pop(i)
break
@@ -83,11 +185,166 @@ class FuncCall:
return f
return None
async def _init_mcp_clients(self) -> None:
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
```
{
"mcpServers": {
"weather": {
"command": "uv",
"args": [
"--directory",
"/ABSOLUTE/PATH/TO/PARENT/FOLDER/weather",
"run",
"weather.py"
]
}
}
...
}
```
"""
current_dir = os.path.dirname(os.path.abspath(__file__))
data_dir = os.path.abspath(os.path.join(current_dir, "../../../data"))
mcp_json_file = os.path.join(data_dir, "mcp_server.json")
if not os.path.exists(mcp_json_file):
# 配置文件不存在错误处理
with open(mcp_json_file, "w", encoding="utf-8") as f:
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}")
return
mcp_server_json_obj: Dict[str, Dict] = json.load(
open(mcp_json_file, "r", encoding="utf-8")
)["mcpServers"]
for name in mcp_server_json_obj.keys():
cfg = mcp_server_json_obj[name]
if cfg.get("active", True):
event = asyncio.Event()
asyncio.create_task(
self._init_mcp_client_task_wrapper(name, cfg, event)
)
self.mcp_client_event[name] = event
async def mcp_service_selector(self):
"""为了避免在不同异步任务中控制 MCP 服务导致的报错,整个项目统一通过这个 Task 来控制
使用 self.mcp_service_queue.put_nowait() 来控制 MCP 服务的启停,数据格式如下:
{"type": "init"} 初始化所有MCP客户端
{"type": "init", "name": "mcp_server_name", "cfg": {...}} 初始化指定的MCP客户端
{"type": "terminate"} 终止所有MCP客户端
{"type": "terminate", "name": "mcp_server_name"} 终止指定的MCP客户端
"""
while True:
data = await self.mcp_service_queue.get()
if data["type"] == "init":
if "name" in data:
event = asyncio.Event()
asyncio.create_task(
self._init_mcp_client_task_wrapper(
data["name"], data["cfg"], event
)
)
self.mcp_client_event[data["name"]] = event
else:
await self._init_mcp_clients()
elif data["type"] == "terminate":
if "name" in data:
# await self._terminate_mcp_client(data["name"])
if data["name"] in self.mcp_client_event:
self.mcp_client_event[data["name"]].set()
self.mcp_client_event.pop(data["name"], None)
else:
for name in self.mcp_client_dict.keys():
# await self._terminate_mcp_client(name)
# self.mcp_client_event[name].set()
if name in self.mcp_client_event:
self.mcp_client_event[name].set()
self.mcp_client_event.pop(name, None)
async def _init_mcp_client_task_wrapper(
self, name: str, cfg: dict, event: asyncio.Event
) -> None:
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
try:
await self._init_mcp_client(name, cfg)
await event.wait()
logger.info(f"收到 MCP 客户端 {name} 终止信号")
await self._terminate_mcp_client(name)
except Exception as e:
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
async def _init_mcp_client(self, name: str, config: dict) -> None:
"""初始化单个MCP客户端"""
try:
# 先清理之前的客户端,如果存在
if name in self.mcp_client_dict:
await self._terminate_mcp_client(name)
mcp_client = MCPClient()
mcp_client.name = name
await mcp_client.connect_to_server(config)
tools_res = await mcp_client.list_tools_and_save()
tool_names = [tool.name for tool in tools_res.tools]
self.mcp_client_dict[name] = mcp_client
# 移除该MCP服务之前的工具(如有)
self.func_list = [
f
for f in self.func_list
if not (f.origin == "mcp" and f.mcp_server_name == name)
]
# 将 MCP 工具转换为 FuncTool 并添加到 func_list
for tool in mcp_client.tools:
func_tool = FuncTool(
name=tool.name,
parameters=tool.inputSchema,
description=tool.description,
origin="mcp",
mcp_server_name=name,
mcp_client=mcp_client,
)
self.func_list.append(func_tool)
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
return True
except Exception as e:
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
# 发生错误时确保客户端被清理
if name in self.mcp_client_dict:
await self._terminate_mcp_client(name)
return False
async def _terminate_mcp_client(self, name: str) -> None:
"""关闭并清理MCP客户端"""
if name in self.mcp_client_dict:
try:
# 关闭MCP连接
await self.mcp_client_dict[name].cleanup()
del self.mcp_client_dict[name]
except Exception as e:
logger.info(f"清空 MCP 客户端资源 {name}: {e}")
# 移除关联的FuncTool
self.func_list = [
f
for f in self.func_list
if not (f.origin == "mcp" and f.mcp_server_name == name)
]
logger.info(f"已关闭 MCP 服务 {name}")
def get_func_desc_openai_style(self) -> list:
"""
获得 OpenAI API 风格的**已经激活**的工具描述
"""
_l = []
# 处理所有工具(包括本地和MCP工具)
for f in self.func_list:
if not f.active:
continue
@@ -137,7 +394,13 @@ class FuncCall:
# 检查并添加非空的properties参数
params = f.parameters if isinstance(f.parameters, dict) else {}
params = copy.deepcopy(params)
if params.get("properties", {}):
properties = params["properties"]
for key, value in properties.items():
if "default" in value:
del value["default"]
params["properties"] = properties
func_declaration["parameters"] = params
tools.append(func_declaration)
@@ -153,9 +416,9 @@ class FuncCall:
continue
_l.append(
{
"name": f["name"],
"parameters": f["parameters"],
"description": f["description"],
"name": f.name,
"parameters": f.parameters,
"description": f.description,
}
)
func_definition = json.dumps(_l, ensure_ascii=False)
@@ -205,14 +468,11 @@ class FuncCall:
func_name = tool["name"]
args = tool["args"]
# 调用函数
tool_callable = None
for func in self.func_list:
if func.name == func_name:
tool_callable = func.star_handler_metadata.handler
break
if not tool_callable:
func_tool = self.get_func(func_name)
if not func_tool:
raise Exception(f"Request function {func_name} not found.")
ret = await tool_callable(**args)
ret = await func_tool.execute(**args)
if ret:
tool_call_result.append(str(ret))
return tool_call_result, True
@@ -222,3 +482,8 @@ class FuncCall:
def __repr__(self):
return str(self.func_list)
async def terminate(self):
for name in self.mcp_client_dict.keys():
await self._terminate_mcp_client(name)
logger.debug(f"清理 MCP 客户端 {name} 资源")
+46 -2
View File
@@ -1,4 +1,5 @@
import traceback
import asyncio
from astrbot.core.config.astrbot_config import AstrBotConfig
from .provider import Provider, STTProvider, TTSProvider, Personality
from .entites import ProviderType
@@ -127,14 +128,19 @@ class ProviderManager:
if self.tts_enabled and not self.curr_tts_provider_inst:
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
# 初始化 MCP Client 连接
asyncio.create_task(
self.llm_tools.mcp_service_selector(), name="mcp-service-handler"
)
self.llm_tools.mcp_service_queue.put_nowait({"type": "init"})
async def load_provider(self, provider_config: dict):
if not provider_config["enable"]:
return
logger.info(
f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商适配器 ..."
f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商 ..."
)
logger.debug(f"Provider Config: {provider_config}")
# 动态导入
try:
@@ -192,6 +198,10 @@ class ProviderManager:
from .sources.fishaudio_tts_api_source import (
ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI,
)
case "dashscope_tts":
from .sources.dashscope_tts import (
ProviderDashscopeTTSAPI as ProviderDashscopeTTSAPI,
)
except (ImportError, ModuleNotFoundError) as e:
logger.critical(
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
@@ -300,10 +310,42 @@ class ProviderManager:
if len(self.provider_insts) == 0:
self.curr_provider_inst = None
elif (
self.curr_provider_inst is None
and len(self.provider_insts) > 0
and self.provider_enabled
):
self.curr_provider_inst = self.provider_insts[0]
self.selected_provider_id = self.curr_provider_inst.meta().id
logger.info(
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。"
)
if len(self.stt_provider_insts) == 0:
self.curr_stt_provider_inst = None
elif (
self.curr_stt_provider_inst is None
and len(self.stt_provider_insts) > 0
and self.stt_enabled
):
self.curr_stt_provider_inst = self.stt_provider_insts[0]
self.selected_stt_provider_id = self.curr_stt_provider_inst.meta().id
logger.info(
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。"
)
if len(self.tts_provider_insts) == 0:
self.curr_tts_provider_inst = None
elif (
self.curr_tts_provider_inst is None
and len(self.tts_provider_insts) > 0
and self.tts_enabled
):
self.curr_tts_provider_inst = self.tts_provider_insts[0]
self.selected_tts_provider_id = self.curr_tts_provider_inst.meta().id
logger.info(
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。"
)
def get_insts(self):
return self.provider_insts
@@ -340,3 +382,5 @@ class ProviderManager:
for provider_inst in self.provider_insts:
if hasattr(provider_inst, "terminate"):
await provider_inst.terminate()
# 清理 MCP Client 连接
await self.llm_tools.mcp_service_queue.put({"type": "terminate"})
+3 -1
View File
@@ -3,7 +3,7 @@ from typing import List
from astrbot.core.db import BaseDatabase
from typing import TypedDict
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.entites import LLMResponse
from astrbot.core.provider.entites import LLMResponse, ToolCallsResult
from dataclasses import dataclass
@@ -90,6 +90,7 @@ class Provider(AbstractProvider):
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
tool_calls_result: ToolCallsResult = None,
**kwargs,
) -> LLMResponse:
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
@@ -100,6 +101,7 @@ class Provider(AbstractProvider):
image_urls: 图片 URL 列表
tools: Function-calling 工具
contexts: 上下文
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
kwargs: 其他参数
Notes:
@@ -10,7 +10,7 @@ from astrbot.api.provider import Provider, Personality
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from ..register import register_provider_adapter
from astrbot.core.provider.entites import LLMResponse
from astrbot.core.provider.entites import LLMResponse, ToolCallsResult
from .openai_source import ProviderOpenAIOfficial
@@ -79,11 +79,14 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
# tools call (function calling)
args_ls = []
func_name_ls = []
tool_use_ids = []
func_name_ls.append(content.name)
args_ls.append(content.input)
tool_use_ids.append(content.id)
llm_response.role = "tool"
llm_response.tools_call_args = args_ls
llm_response.tools_call_name = func_name_ls
llm_response.tools_call_ids = tool_use_ids
if not llm_response.completion_text and not llm_response.tools_call_args:
logger.error(f"API 返回的 completion 无法解析:{completion}")
@@ -101,6 +104,7 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
func_tool: FuncCall = None,
contexts=[],
system_prompt=None,
tool_calls_result: ToolCallsResult = None,
**kwargs,
) -> LLMResponse:
if not prompt:
@@ -113,6 +117,10 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
if "_no_save" in part:
del part["_no_save"]
if tool_calls_result:
# 暂时这样写。
prompt += f"Here are the related results via using tools: {str(tool_calls_result.tool_calls_result)}"
model_config = self.provider_config.get("model_config", {})
payloads = {"messages": context_query, **model_config}
@@ -1,3 +1,4 @@
import re
import asyncio
import functools
from typing import List
@@ -40,11 +41,28 @@ class ProviderDashscope(ProviderOpenAIOfficial):
raise Exception("阿里云百炼 APP 类型不能为空。")
self.model_name = "dashscope"
self.variables: dict = provider_config.get("variables", {})
self.rag_options: dict = provider_config.get("rag_options", {})
self.output_reference = self.rag_options.get("output_reference", False)
self.rag_options = self.rag_options.copy()
self.rag_options.pop("output_reference", None)
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
def has_rag_options(self):
"""判断是否有 RAG 选项
Returns:
bool: 是否有 RAG 选项
"""
if self.rag_options and (
len(self.rag_options.get("pipeline_ids", [])) > 0
or len(self.rag_options.get("file_ids", [])) > 0
):
return True
return False
async def text_chat(
self,
prompt: str,
@@ -62,7 +80,10 @@ class ProviderDashscope(ProviderOpenAIOfficial):
session_var = session_vars.get(session_id, {})
payload_vars.update(session_var)
if self.dashscope_app_type in ["agent", "dialog-workflow"]:
if (
self.dashscope_app_type in ["agent", "dialog-workflow"]
and not self.has_rag_options()
):
# 支持多轮对话的
new_record = {"role": "user", "content": prompt}
if image_urls:
@@ -75,23 +96,31 @@ class ProviderDashscope(ProviderOpenAIOfficial):
if "_no_save" in part:
del part["_no_save"]
# 调用阿里云百炼 API
payload = {
"app_id": self.app_id,
"api_key": self.api_key,
"messages": context_query,
"biz_params": payload_vars or None,
}
partial = functools.partial(
Application.call,
app_id=self.app_id,
api_key=self.api_key,
messages=context_query,
biz_params=payload_vars or None,
**payload,
)
response = await asyncio.get_event_loop().run_in_executor(None, partial)
else:
# 不支持多轮对话的
# 调用阿里云百炼 API
payload = {
"app_id": self.app_id,
"prompt": prompt,
"api_key": self.api_key,
"biz_params": payload_vars or None,
}
if self.rag_options:
payload["rag_options"] = self.rag_options
partial = functools.partial(
Application.call,
app_id=self.app_id,
promtp=prompt,
api_key=self.api_key,
biz_params=payload_vars or None,
**payload,
)
response = await asyncio.get_event_loop().run_in_executor(None, partial)
@@ -107,6 +136,15 @@ class ProviderDashscope(ProviderOpenAIOfficial):
)
output_text = response.output.get("text", "")
# RAG 引用脚标格式化
output_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", output_text)
if self.output_reference and response.output.get("doc_references", None):
ref_str = ""
for ref in response.output.get("doc_references", []):
ref_title = ref.get("title", "") if ref.get("title") else ref.get("doc_name", "")
ref_str += f"{ref['index_id']}. {ref_title}\n"
output_text += f"\n\n回答来源:\n{ref_str}"
return LLMResponse(role="assistant", completion_text=output_text)
async def forget(self, session_id):
@@ -0,0 +1,39 @@
import dashscope
import uuid
import asyncio
from dashscope.audio.tts_v2 import *
from ..provider import TTSProvider
from ..entites import ProviderType
from ..register import register_provider_adapter
@register_provider_adapter(
"dashscope_tts", "Dashscope TTS API", provider_type=ProviderType.TEXT_TO_SPEECH
)
class ProviderDashscopeTTSAPI(TTSProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.chosen_api_key: str = provider_config.get("api_key", "")
self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella")
self.set_model(provider_config.get("model", None))
self.timeout_ms = float(provider_config.get("timeout", 20))*1000
dashscope.api_key = self.chosen_api_key
self.synthesizer = SpeechSynthesizer(
model=self.get_model(),
voice=self.voice,
format=AudioFormat.WAV_24000HZ_MONO_16BIT,
)
async def get_audio(self, text: str) -> str:
path = f"data/temp/dashscope_tts_{uuid.uuid4()}.wav"
audio = await asyncio.get_event_loop().run_in_executor(
None, self.synthesizer.call, text, self.timeout_ms
)
with open(path, "wb") as f:
f.write(audio)
return path
+83 -27
View File
@@ -1,3 +1,5 @@
import astrbot.core.message.components as Comp
from typing import List
from .. import Provider, Personality
from ..entites import LLMResponse
@@ -5,8 +7,9 @@ from ..func_tool_manager import FuncCall
from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter
from astrbot.core.utils.dify_api_client import DifyAPIClient
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.utils.io import download_image_by_url, download_file
from astrbot.core import logger, sp
from astrbot.core.message.message_event_result import MessageChain
@register_provider_adapter("dify", "Dify APP 适配器。")
@@ -30,7 +33,6 @@ class ProviderDify(Provider):
if not self.api_key:
raise Exception("Dify API Key 不能为空。")
api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1")
self.api_client = DifyAPIClient(self.api_key, api_base)
self.api_type = provider_config.get("dify_api_type", "")
if not self.api_type:
raise Exception("Dify API 类型不能为空。")
@@ -41,15 +43,19 @@ class ProviderDify(Provider):
self.dify_query_input_key = provider_config.get(
"dify_query_input_key", "astrbot_text_query"
)
self.variables: dict = provider_config.get("variables", {})
if not self.dify_query_input_key:
self.dify_query_input_key = "astrbot_text_query"
if not self.workflow_output_key:
self.workflow_output_key = "astrbot_wf_output"
self.variables: dict = provider_config.get("variables", {})
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.conversation_ids = {}
"""记录当前 session id 的对话 ID"""
self.api_client = DifyAPIClient(self.api_key, api_base)
async def text_chat(
self,
prompt: str,
@@ -65,26 +71,27 @@ class ProviderDify(Provider):
files_payload = []
for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
file_response = await self.api_client.file_upload(
image_path, user=session_id
image_path = (
await download_image_by_url(image_url)
if image_url.startswith("http")
else image_url
)
file_response = await self.api_client.file_upload(
image_path, user=session_id
)
logger.debug(f"Dify 上传图片响应:{file_response}")
if "id" not in file_response:
logger.warning(
f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。"
)
if "id" not in file_response:
logger.warning(
f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。"
)
continue
files_payload.append(
{
"type": "image",
"transfer_method": "local_file",
"upload_file_id": file_response["id"],
}
)
else:
# TODO: 处理更多情况
logger.warning(f"未知的图片链接:{image_url},图片将忽略。")
continue
files_payload.append(
{
"type": "image",
"transfer_method": "local_file",
"upload_file_id": file_response["id"],
}
)
# 获得会话变量
payload_vars = self.variables.copy()
@@ -96,6 +103,9 @@ class ProviderDify(Provider):
try:
match self.api_type:
case "chat" | "agent":
if not prompt:
prompt = "请描述这张图片。"
async for chunk in self.api_client.chat_messages(
inputs={
**payload_vars,
@@ -148,8 +158,9 @@ class ProviderDify(Provider):
)
case "workflow_finished":
logger.info(
f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束"
f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束"
)
logger.debug(f"Dify 工作流结果:{chunk}")
if chunk["data"]["error"]:
logger.error(
f"Dify 工作流出现错误:{chunk['data']['error']}"
@@ -164,9 +175,7 @@ class ProviderDify(Provider):
raise Exception(
f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}"
)
result = chunk["data"]["outputs"][
self.workflow_output_key
]
result = chunk
case _:
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
except Exception as e:
@@ -176,7 +185,54 @@ class ProviderDify(Provider):
if not result:
logger.warning("Dify 请求结果为空,请查看 Debug 日志。")
return LLMResponse(role="assistant", completion_text=result)
chain = await self.parse_dify_result(result)
return LLMResponse(role="assistant", result_chain=chain)
async def parse_dify_result(self, chunk: dict | str) -> MessageChain:
if isinstance(chunk, str):
# Chat
return MessageChain(chain=[Comp.Plain(chunk)])
async def parse_file(item: dict) -> Comp:
match item["type"]:
case "image":
return Comp.Image(file=item["url"], url=item["url"])
case "audio":
# 仅支持 wav
path = f"data/temp/{item['filename']}.wav"
await download_file(item["url"], path)
return Comp.Image(file=item["url"], url=item["url"])
case "video":
return Comp.Video(file=item["url"])
case _:
return Comp.File(name=item["filename"], file=item["url"])
output = chunk["data"]["outputs"][self.workflow_output_key]
chains = []
if isinstance(output, str):
# 纯文本输出
chains.append(Comp.Plain(output))
elif isinstance(output, list):
# 主要适配 Dify 的 HTTP 请求结点的多模态输出
for item in output:
# handle Array[File]
if (
not isinstance(item, dict)
or item.get("dify_model_identity", "") != "__dify__file__"
):
chains.append(Comp.Plain(str(output)))
break
else:
chains.append(Comp.Plain(str(output)))
# scan file
files = chunk["data"].get("files", [])
for item in files:
comp = await parse_file(item)
chains.append(comp)
return MessageChain(chain=chains)
async def forget(self, session_id):
self.conversation_ids[session_id] = ""
@@ -35,6 +35,8 @@ class ProviderEdgeTTS(TTSProvider):
self.pitch = provider_config.get("pitch", None)
self.timeout = provider_config.get("timeout", 30)
self.proxy = os.getenv("https_proxy", None)
self.set_model("edge_tts")
async def get_audio(self, text: str) -> str:
@@ -42,7 +44,7 @@ class ProviderEdgeTTS(TTSProvider):
mp3_path = f"data/temp/edge_tts_temp_{uuid.uuid4()}.mp3"
wav_path = f"data/temp/edge_tts_{uuid.uuid4()}.wav"
# 构建Edge TTS参数
# 构建 Edge TTS 参数
kwargs = {"text": text, "voice": self.voice}
if self.rate:
kwargs["rate"] = self.rate
@@ -52,12 +54,20 @@ class ProviderEdgeTTS(TTSProvider):
kwargs["pitch"] = self.pitch
try:
communicate = edge_tts.Communicate(**kwargs)
communicate = edge_tts.Communicate(proxy=self.proxy, **kwargs)
await communicate.save(mp3_path)
# 使用ffmpeg将MP3转换为标准WAV格式
_ = await asyncio.create_subprocess_exec(
[
try:
from pyffmpeg import FFmpeg
ff = FFmpeg()
ff.convert(input=mp3_path, output=wav_path)
except Exception as e:
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
# use ffmpeg command line
# 使用ffmpeg将MP3转换为标准WAV格式
p = await asyncio.create_subprocess_exec(
"ffmpeg",
"-y", # 覆盖输出文件
"-i",
@@ -68,11 +78,20 @@ class ProviderEdgeTTS(TTSProvider):
"24000", # 采样率24kHz (适合微信语音)
"-ac",
"1", # 单声道
"-af",
"apad=pad_dur=2", # 确保输出时长准确
"-fflags",
"+genpts", # 强制生成时间戳
"-hide_banner", # 隐藏版本信息
wav_path, # 输出文件
],
capture_output=True,
check=True,
)
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
# 等待进程完成并获取输出
stdout, stderr = await p.communicate()
logger.info(f"[EdgeTTS] FFmpeg 标准输出: {stdout.decode().strip()}")
logger.debug(f"FFmpeg错误输出: {stderr.decode().strip()}")
logger.info(f"[EdgeTTS] 返回值(0代表成功): {p.returncode}")
os.remove(mp3_path)
if os.path.exists(wav_path) and os.path.getsize(wav_path) > 0:
@@ -82,13 +101,15 @@ class ProviderEdgeTTS(TTSProvider):
raise RuntimeError("生成的WAV文件不存在或为空")
except subprocess.CalledProcessError as e:
logger.error(f"FFmpeg转换失败: {e.stderr.decode() if e.stderr else str(e)}")
logger.error(
f"FFmpeg 转换失败: {e.stderr.decode() if e.stderr else str(e)}"
)
try:
if os.path.exists(mp3_path):
os.remove(mp3_path)
except Exception:
pass
raise RuntimeError(f"FFmpeg转换失败: {str(e)}")
raise RuntimeError(f"FFmpeg 转换失败: {str(e)}")
except Exception as e:
logger.error(f"音频生成失败: {str(e)}")
+122 -46
View File
@@ -1,6 +1,10 @@
import base64
import aiohttp
import json
import random
import asyncio
import astrbot.core.message.components as Comp
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.db import BaseDatabase
from astrbot.api.provider import Provider, Personality
@@ -38,6 +42,8 @@ class SimpleGoogleGenAIClient:
model: str = "gemini-1.5-flash",
system_instruction: str = "",
tools: dict = None,
modalities: List[str] = ["Text"],
safety_settings: List[dict] = [],
):
payload = {}
if system_instruction:
@@ -45,6 +51,13 @@ class SimpleGoogleGenAIClient:
if tools:
payload["tools"] = [tools]
payload["contents"] = contents
payload["generationConfig"] = {
"responseModalities": modalities,
}
payload["safetySettings"] = [
{"category": s["category"], "threshold": s["threshold"]}
for s in safety_settings
]
logger.debug(f"payload: {payload}")
request_url = (
f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}"
@@ -98,6 +111,21 @@ class ProviderGoogleGenAI(Provider):
)
self.set_model(provider_config["model_config"]["model"])
safety_mapping = {
"harassment": "HARM_CATEGORY_HARASSMENT",
"hate_speech": "HARM_CATEGORY_HATE_SPEECH",
"sexually_explicit": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
"dangerous_content": "HARM_CATEGORY_DANGEROUS_CONTENT",
}
self.safety_settings = []
user_safety_config = self.provider_config.get("gm_safety_settings", {})
for config_key, harm_category in safety_mapping.items():
if threshold := user_safety_config.get(config_key):
self.safety_settings.append(
{"category": harm_category, "threshold": threshold}
)
async def get_models(self):
return await self.client.models_list()
@@ -119,7 +147,7 @@ class ProviderGoogleGenAI(Provider):
if message["role"] == "user":
if isinstance(message["content"], str):
if not message["content"]:
message["content"] = "<empty_content>"
message["content"] = ""
google_genai_conversation.append(
{"role": "user", "parts": [{"text": message["content"]}]}
@@ -130,7 +158,7 @@ class ProviderGoogleGenAI(Provider):
for part in message["content"]:
if part["type"] == "text":
if not part["text"]:
part["text"] = "<empty_content>"
part["text"] = ""
parts.append({"text": part["text"]})
elif part["type"] == "image_url":
parts.append(
@@ -146,36 +174,105 @@ class ProviderGoogleGenAI(Provider):
google_genai_conversation.append({"role": "user", "parts": parts})
elif message["role"] == "assistant":
if not message["content"]:
message["content"] = "<empty_content>"
google_genai_conversation.append(
{"role": "model", "parts": [{"text": message["content"]}]}
if "content" in message:
if not message["content"]:
message["content"] = ""
google_genai_conversation.append(
{"role": "model", "parts": [{"text": message["content"]}]}
)
elif "tool_calls" in message:
# tool calls in the last turn
parts = []
for tool_call in message["tool_calls"]:
parts.append(
{
"functionCall": {
"name": tool_call["function"]["name"],
"args": json.loads(
tool_call["function"]["arguments"]
),
}
}
)
google_genai_conversation.append({"role": "model", "parts": parts})
elif message["role"] == "tool":
parts = []
parts.append(
{
"functionResponse": {
"name": message["tool_call_id"],
"response": {
"name": message["tool_call_id"],
"content": message["content"],
},
}
}
)
google_genai_conversation.append({"role": "user", "parts": parts})
logger.debug(f"google_genai_conversation: {google_genai_conversation}")
result = await self.client.generate_content(
contents=google_genai_conversation,
model=self.get_model(),
system_instruction=system_instruction,
tools=tool,
)
logger.debug(f"result: {result}")
modalites = ["Text"]
if self.provider_config.get("gm_resp_image_modal", False):
modalites.append("Image")
if "candidates" not in result:
raise Exception("Gemini 返回异常结果: " + str(result))
loop = True
while loop:
loop = False
result = await self.client.generate_content(
contents=google_genai_conversation,
model=self.get_model(),
system_instruction=system_instruction,
tools=tool,
modalities=modalites,
safety_settings=self.safety_settings,
)
logger.debug(f"result: {result}")
# Developer instruction is not enabled for models/gemini-2.0-flash-exp
if "Developer instruction is not enabled" in str(result):
logger.warning(
f"{self.get_model()} 不支持 system prompt, 已自动去除, 将会影响人格设置。"
)
system_instruction = ""
loop = True
elif "Function calling is not enabled" in str(result):
logger.warning(
f"{self.get_model()} 不支持函数调用,已自动去除,不影响使用。"
)
tool = None
loop = True
elif "Multi-modal output is not supported" in str(result):
logger.warning(
f"{self.get_model()} 不支持多模态输出,降级为文本模态重新请求。"
)
modalites = ["Text"]
loop = True
elif "candidates" not in result:
raise Exception("Gemini 返回异常结果: " + str(result))
candidates = result["candidates"][0]["content"]["parts"]
llm_response = LLMResponse("assistant")
chain = []
for candidate in candidates:
if "text" in candidate:
llm_response.completion_text += candidate["text"]
chain.append(Comp.Plain(candidate["text"]))
elif "functionCall" in candidate:
llm_response.role = "tool"
llm_response.tools_call_args.append(candidate["functionCall"]["args"])
llm_response.tools_call_name.append(candidate["functionCall"]["name"])
llm_response.tools_call_ids.append(
candidate["functionCall"]["name"]
) # 没有 tool id
elif "inlineData" in candidate:
mime_type: str = candidate["inlineData"]["mimeType"]
if mime_type.startswith("image/"):
chain.append(Comp.Image.fromBase64(candidate["inlineData"]["data"]))
llm_response.completion_text = llm_response.completion_text.strip()
llm_response.result_chain = MessageChain(chain=chain)
return llm_response
async def text_chat(
@@ -186,6 +283,7 @@ class ProviderGoogleGenAI(Provider):
func_tool: FuncCall = None,
contexts=[],
system_prompt=None,
tool_calls_result=None,
**kwargs,
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
@@ -198,6 +296,10 @@ class ProviderGoogleGenAI(Provider):
if "_no_save" in part:
del part["_no_save"]
# tool calls result
if tool_calls_result:
context_query.extend(tool_calls_result.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
model_config["model"] = self.get_model()
@@ -214,46 +316,20 @@ class ProviderGoogleGenAI(Provider):
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
if "maximum context length" in str(e):
retry_cnt = 20
while retry_cnt > 0:
logger.warning(
f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
)
try:
await self.pop_record(context_query)
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
if "maximum context length" in str(e):
retry_cnt -= 1
else:
raise e
if retry_cnt == 0:
llm_response = LLMResponse(
"err", "err: 请尝试 /reset 重置会话"
)
elif "Function calling is not enabled" in str(e):
logger.info(
f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。"
)
if "tools" in payloads:
del payloads["tools"]
llm_response = await self._query(payloads, None)
break
elif "429" in str(e) or "API key not valid" in str(e):
if "429" in str(e) or "API key not valid" in str(e):
keys.remove(chosen_key)
if len(keys) > 0:
chosen_key = random.choice(keys)
logger.info(
f"检测到 Key 异常({str(e)}),正在尝试更换 API Key 重试... 当前 Key: {chosen_key[:12]}..."
)
await asyncio.sleep(1)
continue
else:
logger.error(
f"检测到 Key 异常({str(e)}),且已没有可用的 Key。 当前 Key: {chosen_key[:12]}..."
)
raise Exception("API 资源已耗尽,且没有可用的 Key 重试...")
raise Exception("达到了 Gemini 速率限制, 请稍后再试...")
else:
logger.error(
f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}"
+78 -65
View File
@@ -2,6 +2,8 @@ import base64
import json
import os
import inspect
import random
import asyncio
from openai import AsyncOpenAI, AsyncAzureOpenAI
from openai.types.chat.chat_completion import ChatCompletion
@@ -120,15 +122,18 @@ class ProviderOpenAIOfficial(Provider):
# tools call (function calling)
args_ls = []
func_name_ls = []
tool_call_ids = []
for tool_call in choice.message.tool_calls:
for tool in tools.func_list:
if tool.name == tool_call.function.name:
args = json.loads(tool_call.function.arguments)
args_ls.append(args)
func_name_ls.append(tool_call.function.name)
tool_call_ids.append(tool_call.id)
llm_response.role = "tool"
llm_response.tools_call_args = args_ls
llm_response.tools_call_name = func_name_ls
llm_response.tools_call_ids = tool_call_ids
if choice.finish_reason == "content_filter":
raise Exception(
@@ -151,6 +156,7 @@ class ProviderOpenAIOfficial(Provider):
func_tool: FuncCall = None,
contexts=[],
system_prompt=None,
tool_calls_result=None,
**kwargs,
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
@@ -162,82 +168,91 @@ class ProviderOpenAIOfficial(Provider):
if "_no_save" in part:
del part["_no_save"]
# tool calls result
if tool_calls_result:
context_query.extend(tool_calls_result.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
model_config["model"] = self.get_model()
payloads = {"messages": context_query, **model_config}
llm_response = None
try:
llm_response = await self._query(payloads, func_tool)
except UnprocessableEntityError as e:
logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
# 尝试删除所有 image
new_contexts = await self._remove_image_from_context(context_query)
payloads["messages"] = new_contexts
context_query = new_contexts
llm_response = await self._query(payloads, func_tool)
except Exception as e:
if "maximum context length" in str(e):
# 重试 10 次
retry_cnt = 20
while retry_cnt > 0:
logger.warning(
f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
)
try:
await self.pop_record(context_query)
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
if "maximum context length" in str(e):
retry_cnt -= 1
else:
raise e
if retry_cnt == 0:
llm_response = LLMResponse(
"err", "err: 请尝试 /reset 清除会话记录。"
)
elif "The model is not a VLM" in str(e): # siliconcloud
max_retries = 10
available_api_keys = self.api_keys.copy()
chosen_key = random.choice(available_api_keys)
e = None
retry_cnt = 0
for retry_cnt in range(max_retries):
try:
self.client.api_key = chosen_key
llm_response = await self._query(payloads, func_tool)
break
except UnprocessableEntityError as e:
logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
# 尝试删除所有 image
new_contexts = await self._remove_image_from_context(context_query)
payloads["messages"] = new_contexts
llm_response = await self._query(payloads, func_tool)
# openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配
elif (
"does not support Function Calling" in str(e)
or "does not support tools" in str(e)
or "Function call is not supported" in str(e)
or "Function calling is not enabled" in str(e)
or "Tool calling is not supported" in str(e)
or "No endpoints found that support tool use" in str(e)
or "model does not support function calling" in str(e)
or ("tool" in str(e) and "support" in str(e).lower())
or ("function" in str(e) and "support" in str(e).lower())
):
logger.info(
f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。"
)
if "tools" in payloads:
del payloads["tools"]
llm_response = await self._query(payloads, None)
else:
logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
if "tool" in str(e).lower() and "support" in str(e).lower():
context_query = new_contexts
except Exception as e:
if "429" in str(e):
logger.warning(
f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}"
)
# 最后一次不等待
if retry_cnt < max_retries - 1:
await asyncio.sleep(1)
available_api_keys.remove(chosen_key)
if len(available_api_keys) > 0:
chosen_key = random.choice(available_api_keys)
continue
else:
raise e
elif "maximum context length" in str(e):
logger.warning(
f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
)
await self.pop_record(context_query)
elif "The model is not a VLM" in str(e): # siliconcloud
# 尝试删除所有 image
new_contexts = await self._remove_image_from_context(context_query)
payloads["messages"] = new_contexts
elif (
"Function calling is not enabled" in str(e)
or ("tool" in str(e).lower() and "support" in str(e).lower())
or ("function" in str(e).lower() and "support" in str(e).lower())
):
# openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配
logger.info(
f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。"
)
if "tools" in payloads:
del payloads["tools"]
func_tool = None
else:
logger.error(
"疑似该模型不支持函数调用工具调用。请输入 /tool off_all"
f"发生了错误。Provider 配置如下: {self.provider_config}"
)
if "Connection error." in str(e):
proxy = os.environ.get("http_proxy", None)
if proxy:
if "tool" in str(e).lower() and "support" in str(e).lower():
logger.error(
f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}"
"疑似该模型不支持函数调用工具调用。请输入 /tool off_all"
)
raise e
if "Connection error." in str(e):
proxy = os.environ.get("http_proxy", None)
if proxy:
logger.error(
f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}"
)
raise e
if retry_cnt == max_retries - 1:
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
raise e
return llm_response
async def _remove_image_from_context(self, contexts: List):
@@ -275,10 +290,8 @@ class ProviderOpenAIOfficial(Provider):
def set_key(self, key):
self.client.api_key = key
async def assemble_context(self, text: str, image_urls: List[str] = None):
"""
组装上下文。
"""
async def assemble_context(self, text: str, image_urls: List[str] = None) -> dict:
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
if image_urls:
user_content = {"role": "user", "content": [{"type": "text", "text": text}]}
for image_url in image_urls:
@@ -18,10 +18,14 @@ class ProviderOpenAITTSAPI(TTSProvider):
self.chosen_api_key = provider_config.get("api_key", "")
self.voice = provider_config.get("openai-tts-voice", "alloy")
timeout = provider_config.get("timeout", NOT_GIVEN)
if isinstance(timeout, str):
timeout = int(timeout)
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base", None),
timeout=provider_config.get("timeout", NOT_GIVEN),
timeout=timeout,
)
self.set_model(provider_config.get("model", None))
@@ -48,14 +48,6 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
return os.path.join("data", "temp", f"{timestamp}")
async def _convert_audio(self, path: str) -> str:
from pyffmpeg import FFmpeg
filename = await self.get_timestamped_path() + ".mp3"
ff = FFmpeg()
output_path = ff.convert(path, os.path.join('data","temp', filename))
return output_path
async def _is_silk_file(self, file_path):
silk_header = b"SILK"
with open(file_path, "rb") as f:
@@ -31,14 +31,6 @@ class ProviderOpenAIWhisperAPI(STTProvider):
self.set_model(provider_config.get("model", None))
async def _convert_audio(self, path: str) -> str:
from pyffmpeg import FFmpeg
filename = str(uuid.uuid4()) + ".mp3"
ff = FFmpeg()
output_path = ff.convert(path, os.path.join("data/temp", filename))
return output_path
async def _is_silk_file(self, file_path):
silk_header = b"SILK"
with open(file_path, "rb") as f:
@@ -33,14 +33,6 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
)
logger.info("Whisper 模型加载完成。")
async def _convert_audio(self, path: str) -> str:
from pyffmpeg import FFmpeg
filename = str(uuid.uuid4()) + ".mp3"
ff = FFmpeg()
output_path = ff.convert(path, os.path.join("data/temp", filename))
return output_path
async def _is_silk_file(self, file_path):
silk_header = b"SILK"
with open(file_path, "rb") as f:
+3 -1
View File
@@ -4,12 +4,14 @@ from .context import Context
from astrbot.core.provider import Provider
from astrbot.core.utils.command_parser import CommandParserMixin
from astrbot.core import html_renderer
from astrbot.core.star.star_tools import StarTools
class Star(CommandParserMixin):
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
def __init__(self, context: Context):
StarTools.initialize(context)
self.context = context
async def text_to_image(self, text: str, return_url=True) -> str:
@@ -27,4 +29,4 @@ class Star(CommandParserMixin):
pass
__all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider"]
__all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider", "StarTools"]
+6 -2
View File
@@ -183,11 +183,15 @@ class Context:
获取指定类型的平台适配器。
"""
for platform in self.platform_manager.platform_insts:
name = platform.meta().name
if isinstance(platform_type, str):
if platform.meta().name == platform_type:
if name == platform_type:
return platform
else:
if platform.meta().name == ADAPTER_NAME_2_TYPE[platform_type]:
if (
name in ADAPTER_NAME_2_TYPE
and ADAPTER_NAME_2_TYPE[name] & platform_type
):
return platform
async def send_message(
+3 -4
View File
@@ -15,7 +15,6 @@ from ..filter.regex import RegexFilter
from typing import Awaitable
from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools
from astrbot.core import logger
def get_handler_full_name(awaitable: Awaitable) -> str:
@@ -359,9 +358,9 @@ def register_llm_tool(name: str = None):
}
)
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
llm_tools.add_func(llm_tool_name, args, docstring.description, md.handler)
logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册")
llm_tools.add_func(
llm_tool_name, args, docstring.description.strip(), md.handler
)
return awaitable
return decorator
+2
View File
@@ -14,6 +14,8 @@ star_map: Dict[str, StarMetadata] = {}
class StarMetadata:
"""
插件的元数据。
当 activated 为 False 时,star_cls 可能为 None,请不要在插件未激活时调用 star_cls 的方法。
"""
name: str
+146 -44
View File
@@ -1,3 +1,7 @@
"""
插件的重载、启停、安装、卸载等操作。
"""
import inspect
import functools
import os
@@ -75,7 +79,7 @@ class PluginManager:
elif os.path.exists(os.path.join(path, d, d + ".py")):
module_str = d
else:
print(f"插件 {d} 未找到 main.py 或者 {d}.py,跳过。")
logger.info(f"插件 {d} 未找到 main.py 或者 {d}.py,跳过。")
continue
if os.path.exists(os.path.join(path, d, "main.py")) or os.path.exists(
os.path.join(path, d, d + ".py")
@@ -164,7 +168,6 @@ class PluginManager:
async def reload(self, specified_plugin_name=None):
"""扫描并加载所有的插件 当 specified_module_path 指定时,重载指定插件"""
specified_module_path = None
if specified_plugin_name:
for smd in star_registry:
@@ -184,6 +187,8 @@ class PluginManager:
f"插件 {smd.name} 未被正常终止: {str(e)}, 可能会导致该插件运行不正常。"
)
await self._unbind_plugin(smd.name, smd.module_path)
star_handlers_registry.clear()
star_map.clear()
star_registry.clear()
@@ -203,10 +208,17 @@ class PluginManager:
)
await self._unbind_plugin(smd.name, specified_module_path)
try:
del sys.modules[specified_module_path]
except KeyError:
logger.warning(f"模块 {specified_module_path} 未载入")
return await self.load(specified_module_path)
async def load(self, specified_module_path=None, specified_dir_name=None):
"""载入插件。
当 specified_module_path 或者 specified_dir_name 不为 None 时,只载入指定的插件。
"""
inactivated_plugins: list = sp.get("inactivated_plugins", [])
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
alter_cmd = sp.get("alter_cmd", {})
plugin_modules = self._get_plugin_modules()
if plugin_modules is None:
@@ -214,11 +226,6 @@ class PluginManager:
fail_rec = ""
inactivated_plugins: list = sp.get("inactivated_plugins", [])
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
alter_cmd = sp.get("alter_cmd", {})
# 导入插件模块,并尝试实例化插件类
for plugin_module in plugin_modules:
try:
@@ -232,8 +239,11 @@ class PluginManager:
path = "data.plugins." if not reserved else "packages."
path += root_dir_name + "." + module_str
# 检查是否需要载入指定的插件
if specified_module_path and path != specified_module_path:
continue
if specified_dir_name and root_dir_name != specified_dir_name:
continue
logger.info(f"正在载入插件 {root_dir_name} ...")
@@ -287,18 +297,24 @@ class PluginManager:
except Exception:
pass
if plugin_config:
metadata.config = plugin_config
try:
metadata.star_cls = metadata.star_cls_type(
context=self.context, config=plugin_config
)
except TypeError as _:
if path not in inactivated_plugins:
# 只有没有禁用插件时才实例化插件类
if plugin_config:
metadata.config = plugin_config
try:
metadata.star_cls = metadata.star_cls_type(
context=self.context, config=plugin_config
)
except TypeError as _:
metadata.star_cls = metadata.star_cls_type(
context=self.context
)
else:
metadata.star_cls = metadata.star_cls_type(
context=self.context
)
else:
metadata.star_cls = metadata.star_cls_type(context=self.context)
logger.info(f"插件 {metadata.name} 已被禁用。")
metadata.module = module
metadata.root_dir_name = root_dir_name
@@ -316,7 +332,10 @@ class PluginManager:
)
# 绑定 llm_tool handler
for func_tool in llm_tools.func_list:
if func_tool.handler.__module__ == metadata.module_path:
if (
func_tool.handler
and func_tool.handler.__module__ == metadata.module_path
):
func_tool.handler_module_path = metadata.module_path
func_tool.handler = functools.partial(
func_tool.handler, metadata.star_cls
@@ -331,19 +350,23 @@ class PluginManager:
)
classes = self._get_classes(module)
if plugin_config:
try:
obj = getattr(module, classes[0])(
context=self.context, config=plugin_config
) # 实例化插件类
except TypeError as _:
if path not in inactivated_plugins:
# 只有没有禁用插件时才实例化插件类
if plugin_config:
try:
obj = getattr(module, classes[0])(
context=self.context, config=plugin_config
) # 实例化插件类
except TypeError as _:
obj = getattr(module, classes[0])(
context=self.context
) # 实例化插件类
else:
obj = getattr(module, classes[0])(
context=self.context
) # 实例化插件类
else:
obj = getattr(module, classes[0])(
context=self.context
) # 实例化插件类
logger.info(f"插件 {metadata.name} 已被禁用。")
metadata = None
metadata = self._load_plugin_metadata(
@@ -426,8 +449,36 @@ class PluginManager:
async def install_plugin(self, repo_url: str, proxy=""):
plugin_path = await self.updator.install(repo_url, proxy)
# reload the plugin
await self.reload()
return plugin_path
dir_name = os.path.basename(plugin_path)
await self.load(specified_dir_name=dir_name)
# Get the plugin metadata to return repo info
plugin = self.context.get_registered_star(dir_name)
if not plugin:
# Try to find by other name if directory name doesn't match plugin name
for star in self.context.get_all_stars():
if star.root_dir_name == dir_name:
plugin = star
break
# Extract README.md content if exists
readme_content = None
readme_path = os.path.join(plugin_path, "README.md")
if not os.path.exists(readme_path):
readme_path = os.path.join(plugin_path, "readme.md")
if os.path.exists(readme_path):
try:
with open(readme_path, "r", encoding="utf-8") as f:
readme_content = f.read()
except Exception as e:
logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {str(e)}")
plugin_info = None
if plugin:
plugin_info = {"repo": plugin.repo, "readme": readme_content}
return plugin_info
async def uninstall_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
@@ -450,9 +501,11 @@ class PluginManager:
# 从 star_registry 和 star_map 中删除
await self._unbind_plugin(plugin_name, plugin.module_path)
if not remove_dir(os.path.join(ppath, root_dir_name)):
try:
remove_dir(os.path.join(ppath, root_dir_name))
except Exception as e:
raise Exception(
"移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。"
f"移除插件成功,但是删除插件文件夹失败: {str(e)}。您可以手动删除该文件夹,位于 addons/plugins/ 下。"
)
async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str):
@@ -464,7 +517,9 @@ class PluginManager:
for handler in star_handlers_registry.get_handlers_by_module_name(
plugin_module_path
):
logger.debug(f"unbind handler {handler.handler_name} from {plugin_name}")
logger.info(
f"移除了插件 {plugin_name} 的处理函数 {handler.handler_name} ({len(star_handlers_registry)})"
)
star_handlers_registry.remove(handler)
keys_to_delete = [
k
@@ -472,9 +527,15 @@ class PluginManager:
if k.startswith(plugin_module_path)
]
for k in keys_to_delete:
v = star_handlers_registry.star_handlers_map[k]
logger.debug(f"unbind handler {v.handler_name} from {plugin_name} (map)")
del star_handlers_registry.star_handlers_map[k]
try:
del star_handlers_registry.star_handlers_map[k]
except KeyError:
pass
try:
del sys.modules[plugin_module_path]
except KeyError:
logger.warning(f"模块 {plugin_module_path} 未载入")
async def update_plugin(self, plugin_name: str, proxy=""):
"""升级一个插件"""
@@ -485,7 +546,7 @@ class PluginManager:
raise Exception("该插件是 AstrBot 保留插件,无法更新。")
await self.updator.update(plugin, proxy=proxy)
await self.reload()
await self.reload(plugin_name)
async def turn_off_plugin(self, plugin_name: str):
"""
@@ -524,11 +585,18 @@ class PluginManager:
async def _terminate_plugin(self, star_metadata: StarMetadata):
"""终止插件,调用插件的 terminate() 和 __del__() 方法"""
logging.info(f"正在终止插件 {star_metadata.name} ...")
logger.info(f"正在终止插件 {star_metadata.name} ...")
if not star_metadata.activated:
# 说明之前已经被禁用了
logger.debug(f"插件 {star_metadata.name} 未被激活,不需要终止,跳过。")
return
if hasattr(star_metadata.star_cls, "__del__"):
asyncio.get_event_loop().run_in_executor(star_metadata.star_cls.__del__)
else:
asyncio.get_event_loop().run_in_executor(
None, star_metadata.star_cls.__del__
)
elif hasattr(star_metadata.star_cls, "terminate"):
await star_metadata.star_cls.terminate()
async def turn_on_plugin(self, plugin_name: str):
@@ -541,12 +609,17 @@ class PluginManager:
# 启用插件启用的 llm_tool
for func_tool in llm_tools.func_list:
if func_tool.handler_module_path == plugin.module_path:
if (
func_tool.handler_module_path == plugin.module_path
and func_tool.name in inactivated_llm_tools
):
inactivated_llm_tools.remove(func_tool.name)
func_tool.active = True
sp.put("inactivated_llm_tools", inactivated_llm_tools)
plugin.activated = True
await self.reload(plugin_name)
# plugin.activated = True
async def install_plugin_from_file(self, zip_file_path: str):
dir_name = os.path.basename(zip_file_path).replace(".zip", "")
@@ -559,4 +632,33 @@ class PluginManager:
os.remove(zip_file_path)
except BaseException as e:
logger.warning(f"删除插件压缩包失败: {str(e)}")
await self.reload()
# await self.reload()
await self.load(specified_dir_name=dir_name)
# Get the plugin metadata to return repo info
plugin = self.context.get_registered_star(dir_name)
if not plugin:
# Try to find by other name if directory name doesn't match plugin name
for star in self.context.get_all_stars():
if star.root_dir_name == dir_name:
plugin = star
break
# Extract README.md content if exists
readme_content = None
readme_path = os.path.join(desti_dir, "README.md")
if not os.path.exists(readme_path):
readme_path = os.path.join(desti_dir, "readme.md")
if os.path.exists(readme_path):
try:
with open(readme_path, "r", encoding="utf-8") as f:
readme_content = f.read()
except Exception as e:
logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {str(e)}")
plugin_info = None
if plugin:
plugin_info = {"repo": plugin.repo, "readme": readme_content}
return plugin_info
+144
View File
@@ -0,0 +1,144 @@
from typing import Union, Awaitable, List, Optional, ClassVar
from astrbot.core.message.components import BaseMessageComponent
from astrbot.core.message.message_event_result import MessageChain
from astrbot.api.platform import MessageMember, AstrBotMessage
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.star.context import Context
class StarTools:
"""
提供给插件使用的便捷工具函数集合
这些方法封装了一些常用操作,使插件开发更加简单便捷!
"""
_context: ClassVar[Optional[Context]] = None
@classmethod
def initialize(cls, context: Context) -> None:
"""
初始化StarTools,设置context引用
Args:
context: 暴露给插件的上下文
"""
cls._context = context
@classmethod
async def send_message(
cls, session: Union[str, MessageSesion], message_chain: MessageChain
) -> bool:
"""
根据session(unified_msg_origin)主动发送消息
Args:
session: 消息会话。通过event.session或者event.unified_msg_origin获取
message_chain: 消息链
Returns:
bool: 是否找到匹配的平台
Raises:
ValueError: 当session为字符串且解析失败时抛出
Note:
qq_official(QQ官方API平台)不支持此方法
"""
return await cls._context.send_message(session, message_chain)
@classmethod
async def create_message(
cls,
type: str,
self_id: str,
session_id: str,
message_id: str,
sender: MessageMember,
message: List[BaseMessageComponent],
message_str: str,
raw_message: object,
group_id: str = "",
):
"""
创建一个AstrBot消息对象
Args:
type (str): 消息类型
self_id (str): 机器人自身ID
session_id (str): 会话ID(通常为用户ID)(QQ号, 群号等)
message_id (str): 消息ID
sender (MessageMember): 发送者信息
message (List[BaseMessageComponent]): 消息组件列表
message_str (str): 消息字符串
raw_message (object): 原始消息对象
group_id (str, optional): 群组ID, 如果为私聊则为空. Defaults to "".
Returns:
AstrBotMessage: 创建的消息对象
"""
abm = AstrBotMessage()
abm.type = type
abm.self_id = self_id
abm.session_id = session_id
abm.message_id = message_id
abm.sender = sender
abm.message = message
abm.message_str = message_str
abm.raw_message = raw_message
abm.group_id = group_id
return abm
# todo: 添加构造事件的方法
# async def create_event(
# self, platform: str, umo: str, sender_id: str, session_id: str
# ):
# platform = self._context.get_platform(platform)
# todo: 添加找到对应平台并提交对应事件的方法
@classmethod
def activate_llm_tool(cls, name: str) -> bool:
"""
激活一个已经注册的函数调用工具
注册的工具默认是激活状态
Args:
name (str): 工具名称
"""
return cls._context.activate_llm_tool(name)
@classmethod
def deactivate_llm_tool(cls, name: str) -> bool:
"""
停用一个已经注册的函数调用工具
Args:
name (str): 工具名称
"""
return cls._context.deactivate_llm_tool(name)
@classmethod
def register_llm_tool(
cls, name: str, func_args: list, desc: str, func_obj: Awaitable
) -> None:
"""
为函数调用(function-calling/tools-use)添加工具
Args:
name (str): 工具名称
func_args (list): 函数参数列表
desc (str): 工具描述
func_obj (Awaitable): 函数对象,必须是异步函数
"""
cls._context.register_llm_tool(name, func_args, desc, func_obj)
@classmethod
def unregister_llm_tool(cls, name: str) -> None:
"""
删除一个函数调用工具
如果再要启用,需要重新注册
Args:
name (str): 工具名称
"""
cls._context.unregister_llm_tool(name)
+1 -1
View File
@@ -41,7 +41,7 @@ class PluginUpdator(RepoZipUpdator):
plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}")
await self.download_from_repo_url(plugin_path, repo_url)
await self.download_from_repo_url(plugin_path, repo_url, proxy=proxy)
try:
remove_dir(plugin_path)
+12
View File
@@ -9,6 +9,11 @@ from astrbot.core.utils.io import download_file
class AstrBotUpdator(RepoZipUpdator):
"""AstrBot 更新器,继承自 RepoZipUpdator 类
该类用于处理 AstrBot 的更新操作
功能包括检查更新、下载更新文件、解压缩更新文件等
"""
def __init__(self, repo_mirror: str = "") -> None:
super().__init__(repo_mirror)
self.MAIN_PATH = os.path.abspath(
@@ -17,6 +22,9 @@ class AstrBotUpdator(RepoZipUpdator):
self.ASTRBOT_RELEASE_API = "https://api.soulter.top/releases"
def terminate_child_processes(self):
"""终止当前进程的所有子进程
使用 psutil 库获取当前进程的所有子进程,并尝试终止它们
"""
try:
parent = psutil.Process(os.getpid())
children = parent.children(recursive=True)
@@ -35,6 +43,9 @@ class AstrBotUpdator(RepoZipUpdator):
pass
def _reboot(self, delay: int = 3):
"""重启当前程序
在指定的延迟后,终止所有子进程并重新启动程序
"""
py = sys.executable
time.sleep(delay)
self.terminate_child_processes()
@@ -46,6 +57,7 @@ class AstrBotUpdator(RepoZipUpdator):
raise e
async def check_update(self, url: str, current_version: str) -> ReleaseInfo:
"""检查更新"""
return await super().check_update(self.ASTRBOT_RELEASE_API, VERSION)
async def get_releases(self) -> list:
+1 -1
View File
@@ -8,7 +8,7 @@ class DifyAPIClient:
def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"):
self.api_key = api_key
self.api_base = api_base
self.session = ClientSession()
self.session = ClientSession(trust_env=True)
self.headers = {
"Authorization": f"Bearer {self.api_key}",
}
+22 -11
View File
@@ -8,6 +8,9 @@ import base64
import zipfile
import uuid
import psutil
import certifi
from typing import Union
from PIL import Image
@@ -17,24 +20,20 @@ def on_error(func, path, exc_info):
"""
a callback of the rmtree function.
"""
print(f"remove {path} failed.")
import stat
if not os.access(path, os.W_OK):
os.chmod(path, stat.S_IWUSR)
func(path)
else:
raise
raise exc_info[1]
def remove_dir(file_path) -> bool:
if not os.path.exists(file_path):
return True
try:
shutil.rmtree(file_path, onerror=on_error)
return True
except BaseException:
return False
shutil.rmtree(file_path, onerror=on_error)
return True
def port_checker(port: int, host: str = "localhost"):
@@ -81,7 +80,13 @@ async def download_image_by_url(
下载图片, 返回 path
"""
try:
async with aiohttp.ClientSession(trust_env=True) as session:
ssl_context = ssl.create_default_context(
cafile=certifi.where()
) # 使用 certifi 提供的 CA 证书
connector = aiohttp.TCPConnector(ssl=ssl_context) # 使用 certifi 的根证书
async with aiohttp.ClientSession(
trust_env=True, connector=connector
) as session:
if post:
async with session.post(url, json=post_data) as resp:
if not path:
@@ -98,7 +103,7 @@ async def download_image_by_url(
with open(path, "wb") as f:
f.write(await resp.read())
return path
except aiohttp.client.ClientConnectorSSLError:
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
# 关闭SSL验证
ssl_context = ssl.create_default_context()
ssl_context.set_ciphers("DEFAULT")
@@ -118,7 +123,13 @@ async def download_file(url: str, path: str, show_progress: bool = False):
从指定 url 下载文件到指定路径 path
"""
try:
async with aiohttp.ClientSession(trust_env=True) as session:
ssl_context = ssl.create_default_context(
cafile=certifi.where()
) # 使用 certifi 提供的 CA 证书
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(
trust_env=True, connector=connector
) as session:
async with session.get(url, timeout=1800) as resp:
if resp.status != 200:
raise Exception(f"下载文件失败: {resp.status}")
@@ -141,7 +152,7 @@ async def download_file(url: str, path: str, show_progress: bool = False):
f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s",
end="",
)
except aiohttp.client.ClientConnectorSSLError:
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
# 关闭SSL验证
ssl_context = ssl.create_default_context()
ssl_context.set_ciphers("DEFAULT")
+198
View File
@@ -0,0 +1,198 @@
"""
会话控制
"""
import abc
import asyncio
import time
import functools
import copy
import astrbot.core.message.components as Comp
from typing import Dict, Any, Callable, Awaitable, List
from astrbot.core.platform import AstrMessageEvent
USER_SESSIONS: Dict[str, "SessionWaiter"] = {} # 存储 SessionWaiter 实例
FILTERS: List["SessionFilter"] = [] # 存储 SessionFilter 实例
class SessionController:
"""
控制一个 Session 是否已经结束
"""
def __init__(self):
self.future = asyncio.Future()
self.current_event: asyncio.Event = None
"""当前正在等待的所用的异步事件"""
self.ts: float = None
"""上次保持(keep)开始时的时间"""
self.timeout: float | int = None
"""上次保持(keep)开始时的超时时间"""
self.history_chains: List[List[Comp.BaseMessageComponent]] = []
def stop(self, error: Exception = None):
"""立即结束这个会话"""
if not self.future.done():
if error:
self.future.set_exception(error)
else:
self.future.set_result(None)
def keep(self, timeout: float | int = 0, reset_timeout=False):
"""保持这个会话
Args:
timeout (float): 必填。会话超时时间。
当 reset_timeout 设置为 True 时, 代表重置超时时间, timeout 必须 > 0, 如果 <= 0 则立即结束会话。
当 reset_timeout 设置为 False 时, 代表继续维持原来的超时时间, 新 timeout = 原来剩余的timeout + timeout (可以 < 0)
"""
new_ts = time.time()
if reset_timeout:
if timeout <= 0:
self.stop()
return
else:
left_timeout = self.timeout - (new_ts - self.ts)
timeout = left_timeout + timeout
if timeout <= 0:
self.stop()
return
if self.current_event and not self.current_event.is_set():
self.current_event.set() # 通知上一个 keep 结束
new_event = asyncio.Event()
self.ts = new_ts
self.current_event = new_event
self.timeout = timeout
asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep
async def _holding(self, event: asyncio.Event, timeout: int):
"""等待事件结束或超时"""
try:
await asyncio.wait_for(event.wait(), timeout)
except asyncio.TimeoutError:
if not self.future.done():
self.future.set_exception(TimeoutError("等待超时"))
except asyncio.CancelledError:
pass # 避免报错
# finally:
def get_history_chains(self) -> List[List[Comp.BaseMessageComponent]]:
"""获取历史消息链"""
return self.history_chains
class SessionFilter:
"""如何界定一个会话"""
@abc.abstractmethod
def filter(self, event: AstrMessageEvent) -> str:
"""根据事件返回一个会话标识符"""
pass
class DefaultSessionFilter(SessionFilter):
def filter(self, event: AstrMessageEvent) -> str:
"""默认实现,返回发送者的 ID 作为会话标识符"""
return event.get_sender_id()
class SessionWaiter:
def __init__(
self,
session_filter: SessionFilter,
session_id: str,
record_history_chains: bool,
):
self.session_id = session_id
self.session_filter = session_filter
self.handler: Callable[[str], Awaitable[Any]] | None = None # 处理函数
self.session_controller = SessionController()
self.record_history_chains = record_history_chains
"""是否记录历史消息链"""
self._lock = asyncio.Lock()
"""需要保证一个 session 同时只有一个 trigger"""
async def register_wait(
self, handler: Callable[[str], Awaitable[Any]], timeout: int = 30
) -> Any:
"""等待外部输入并处理"""
self.handler = handler
USER_SESSIONS[self.session_id] = self
# 开始一个会话保持事件
self.session_controller.keep(timeout, reset_timeout=True)
try:
return await self.session_controller.future
except Exception as e:
self._cleanup(e)
raise e
finally:
self._cleanup()
def _cleanup(self, error: Exception = None):
"""清理会话"""
USER_SESSIONS.pop(self.session_id, None)
try:
FILTERS.remove(self.session_filter)
except ValueError:
pass
self.session_controller.stop(error)
@classmethod
async def trigger(cls, session_id: str, event: AstrMessageEvent):
"""外部输入触发会话处理"""
session = USER_SESSIONS.get(session_id, None)
if not session or session.session_controller.future.done():
return
async with session._lock:
if not session.session_controller.future.done():
if session.record_history_chains:
session.session_controller.history_chains.append(
[copy.deepcopy(comp) for comp in event.get_messages()]
)
try:
# TODO: 这里使用 create_task,跟踪 task,防止超时后这里 handler 仍然在执行
await session.handler(session.session_controller, event)
except Exception as e:
session.session_controller.stop(e)
def session_waiter(timeout: int = 30, record_history_chains: bool = False):
"""
装饰器:自动将函数注册为 SessionWaiter 处理函数,并等待外部输入触发执行。
:param timeout: 超时时间(秒)
:param record_history_chain: 是否自动记录历史消息链。可以通过 controller.get_history_chains() 获取。深拷贝。
"""
def decorator(func: Callable[[str], Awaitable[Any]]):
@functools.wraps(func)
async def wrapper(
event: AstrMessageEvent,
session_filter: SessionFilter = None,
*args,
**kwargs,
):
if not session_filter:
session_filter = DefaultSessionFilter()
if not isinstance(session_filter, SessionFilter):
raise ValueError("session_filter 必须是 SessionFilter")
session_id = session_filter.filter(event)
FILTERS.append(session_filter)
waiter = SessionWaiter(session_filter, session_id, record_history_chains)
return await waiter.register_wait(func, timeout)
return wrapper
return decorator
+1
View File
@@ -16,6 +16,7 @@ class SharedPreferences:
def _save_preferences(self):
with open(self.path, "w") as f:
json.dump(self._data, f, indent=4)
f.flush()
def get(self, key, default=None):
return self._data.get(key, default)
File diff suppressed because it is too large Load Diff
+7 -1
View File
@@ -1,5 +1,7 @@
import aiohttp
import os
import ssl
import certifi
from . import RenderStrategy
from astrbot.core.config import VERSION
@@ -46,7 +48,11 @@ class NetworkRenderStrategy(RenderStrategy):
},
}
if return_url:
async with aiohttp.ClientSession(trust_env=True) as session:
ssl_context = ssl.create_default_context(cafile=certifi.where())
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(
trust_env=True, connector=connector
) as session:
async with session.post(
f"{self.BASE_RENDER_URL}/generate", json=post_data
) as resp:
+23 -3
View File
@@ -2,6 +2,10 @@ import aiohttp
import os
import zipfile
import shutil
import ssl
import certifi
from astrbot.core.utils.io import on_error, download_file
from astrbot.core import logger
@@ -19,7 +23,7 @@ class ReleaseInfo:
self.body = body
def __str__(self) -> str:
return f"新版本: {self.version}, 发布于: {self.published_at}, 详细内容: {self.body}"
return f"\n{self.body}\n\n版本: {self.version} | 发布于: {self.published_at}"
class RepoZipUpdator:
@@ -33,8 +37,23 @@ class RepoZipUpdator:
返回一个列表每个元素是一个字典包含版本号发布时间更新内容commit hash等信息
"""
try:
async with aiohttp.ClientSession(trust_env=True) as session:
ssl_context = ssl.create_default_context(
cafile=certifi.where()
) # 新增:创建基于 certifi 的 SSL 上下文
connector = aiohttp.TCPConnector(
ssl=ssl_context
) # 新增:使用 TCPConnector 指定 SSL 上下文
async with aiohttp.ClientSession(
trust_env=True, connector=connector
) as session:
async with session.get(url) as response:
# 检查 HTTP 状态码
if response.status != 200:
text = await response.text()
logger.error(
f"请求 {url} 失败,状态码: {response.status}, 内容: {text}"
)
raise Exception(f"请求失败,状态码: {response.status}")
result = await response.json()
if not result:
return []
@@ -53,7 +72,8 @@ class RepoZipUpdator:
"zipball_url": release["zipball_url"],
}
)
except BaseException:
except Exception as e:
logger.error(f"解析版本信息时发生异常: {e}")
raise Exception("解析版本信息失败")
return ret
-3
View File
@@ -1,3 +0,0 @@
from .dashboard_lifecycle import AstrBotDashBoardLifecycle
__all__ = ["AstrBotDashBoardLifecycle"]
+4
View File
@@ -6,6 +6,8 @@ from .stat import StatRoute
from .log import LogRoute
from .static_file import StaticFileRoute
from .chat import ChatRoute
from .tools import ToolsRoute # 导入新的ToolsRoute
from .conversation import ConversationRoute
__all__ = [
@@ -17,4 +19,6 @@ __all__ = [
"LogRoute",
"StaticFileRoute",
"ChatRoute",
"ToolsRoute", # 添加新的ToolsRoute
"ConversationRoute",
]
+21 -2
View File
@@ -2,7 +2,8 @@ import jwt
import datetime
from .route import Route, Response, RouteContext
from quart import request
from astrbot.core import WEBUI_SK
from astrbot.core import WEBUI_SK, DEMO_MODE
from astrbot import logger
class AuthRoute(Route):
@@ -19,15 +20,33 @@ class AuthRoute(Route):
password = self.config["dashboard"]["password"]
post_data = await request.json
if post_data["username"] == username and post_data["password"] == password:
change_pwd_hint = False
if username == "astrbot" and password == "77b90590a8945a7d36c963981a307dc9":
change_pwd_hint = True
logger.warning("为了保证安全,请尽快修改默认密码。")
return (
Response()
.ok({"token": self.generate_jwt(username), "username": username})
.ok(
{
"token": self.generate_jwt(username),
"username": username,
"change_pwd_hint": change_pwd_hint,
}
)
.__dict__
)
else:
return Response().error("用户名或密码错误").__dict__
async def edit_account(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
password = self.config["dashboard"]["password"]
post_data = await request.json
+59 -13
View File
@@ -12,8 +12,11 @@ from astrbot.core import logger
def try_cast(value: str, type_: str):
if type_ == "int" and value.isdigit():
return int(value)
if type_ == "int":
try:
return int(value)
except (ValueError, TypeError):
return None
elif (
type_ == "float"
and isinstance(value, str)
@@ -22,6 +25,11 @@ def try_cast(value: str, type_: str):
return float(value)
elif type_ == "float" and isinstance(value, int):
return float(value)
elif type_ == "float":
try:
return float(value)
except (ValueError, TypeError):
return None
def validate_config(
@@ -29,17 +37,46 @@ def validate_config(
) -> typing.Tuple[typing.List[str], typing.Dict]:
errors = []
def validate(data, metadata=schema, path=""):
for key, meta in metadata.items():
if key not in data:
def validate(data: dict, metadata: dict = schema, path=""):
for key, value in data.items():
if key not in metadata:
# 无 schema 的配置项,执行类型猜测
if isinstance(value, str):
try:
data[key] = int(value)
continue
except ValueError:
pass
try:
data[key] = float(value)
continue
except ValueError:
pass
if value.lower() == "true":
data[key] = True
elif value.lower() == "false":
data[key] = False
continue
value = data[key]
meta = metadata[key]
# null 转换
if value is None:
data[key] = DEFAULT_VALUE_MAP[meta["type"]]
continue
# 递归验证
if meta["type"] == "list" and isinstance(value, list):
if meta["type"] == "list" and not isinstance(value, list):
errors.append(
f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}"
)
elif (
meta["type"] == "list"
and isinstance(value, list)
and value
and "items" in meta
and isinstance(value[0], dict)
):
# 当前仅针对 list[dict] 的情况进行类型校验,以适配 AstrBot 中 platform、provider 的配置
for item in value:
validate(item, meta["items"], path=f"{path}{key}.")
elif meta["type"] == "object" and isinstance(value, dict):
@@ -75,7 +112,6 @@ def validate_config(
errors.append(
f"错误的类型 {path}{key}: 期望是 dict, 得到了 {type(value).__name__}"
)
validate(value, meta["items"], path=f"{path}{key}.")
if is_core:
for key, group in schema.items():
@@ -103,6 +139,7 @@ def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False)
except BaseException as e:
logger.error(traceback.format_exc())
logger.warning(f"验证配置时出现异常: {e}")
raise ValueError(f"验证配置时出现异常: {e}")
if errors:
raise ValueError(f"格式校验未通过: {errors}")
config.save_config(post_config)
@@ -124,6 +161,7 @@ class ConfigRoute(Route):
"/config/provider/new": ("POST", self.post_new_provider),
"/config/provider/update": ("POST", self.post_update_provider),
"/config/provider/delete": ("POST", self.post_delete_provider),
"/config/llmtools": ("GET", self.get_llm_tools),
}
self.register_routes()
@@ -149,9 +187,10 @@ class ConfigRoute(Route):
plugin_name = request.args.get("plugin_name", "unknown")
try:
await self._save_plugin_configs(post_configs, plugin_name)
await self.core_lifecycle.plugin_manager.reload(plugin_name)
return (
Response()
.ok(None, f"保存插件 {plugin_name} 成功~ 机器人正在重载配置")
.ok(None, f"保存插件 {plugin_name} 成功~ 机器人正在重载插件")
.__dict__
)
except Exception as e:
@@ -196,7 +235,8 @@ class ConfigRoute(Route):
return Response().error("未找到对应平台").__dict__
try:
await self._save_astrbot_configs(self.config)
save_config(self.config, self.config, is_core=True)
await self.core_lifecycle.platform_manager.reload(new_config)
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "更新平台配置成功~").__dict__
@@ -232,7 +272,8 @@ class ConfigRoute(Route):
else:
return Response().error("未找到对应平台").__dict__
try:
await self._save_astrbot_configs(self.config)
save_config(self.config, self.config, is_core=True)
await self.core_lifecycle.platform_manager.terminate_platform(platform_id)
except Exception as e:
return Response().error(str(e)).__dict__
return Response().ok(None, "删除平台配置成功~").__dict__
@@ -253,6 +294,12 @@ class ConfigRoute(Route):
return Response().error(str(e)).__dict__
return Response().ok(None, "删除成功,已经实时生效~").__dict__
async def get_llm_tools(self):
"""获取函数调用工具。包含了本地加载的以及 MCP 服务的工具"""
tool_mgr = self.core_lifecycle.provider_manager.llm_tools
tools = tool_mgr.get_func_desc_openai_style()
return Response().ok(tools).__dict__
async def _get_astrbot_config(self):
config = self.config
@@ -298,7 +345,7 @@ class ConfigRoute(Route):
async def _save_astrbot_configs(self, post_configs: dict):
try:
save_config(post_configs, self.config, is_core=True)
self.core_lifecycle.restart()
await self.core_lifecycle.restart()
except Exception as e:
raise e
@@ -315,6 +362,5 @@ class ConfigRoute(Route):
try:
save_config(post_configs, md.config)
self.core_lifecycle.restart()
except Exception as e:
raise e
+215
View File
@@ -0,0 +1,215 @@
import traceback
import json
from .route import Route, Response, RouteContext
from astrbot.core import logger
from quart import request
from astrbot.core.db import BaseDatabase
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
class ConversationRoute(Route):
def __init__(
self,
context: RouteContext,
db_helper: BaseDatabase,
core_lifecycle: AstrBotCoreLifecycle,
) -> None:
super().__init__(context)
self.routes = {
"/conversation/list": ("GET", self.list_conversations),
"/conversation/detail": (
"POST",
self.get_conv_detail,
),
"/conversation/update": ("POST", self.upd_conv),
"/conversation/delete": ("POST", self.del_conv),
"/conversation/update_history": (
"POST",
self.update_history,
),
}
self.db_helper = db_helper
self.register_routes()
async def list_conversations(self):
"""获取对话列表,支持分页、排序和筛选"""
try:
# 获取分页参数
page = request.args.get("page", 1, type=int)
page_size = request.args.get("page_size", 20, type=int)
# 获取筛选参数
platforms = request.args.get("platforms", "")
message_types = request.args.get("message_types", "")
search_query = request.args.get("search", "")
exclude_ids = request.args.get("exclude_ids", "")
exclude_platforms = request.args.get("exclude_platforms", "")
# 转换为列表
platform_list = platforms.split(",") if platforms else []
message_type_list = message_types.split(",") if message_types else []
exclude_id_list = exclude_ids.split(",") if exclude_ids else []
exclude_platform_list = (
exclude_platforms.split(",") if exclude_platforms else []
)
# 限制页面大小,防止请求过大数据
if page < 1:
page = 1
if page_size < 1:
page_size = 20
if page_size > 100:
page_size = 100
# 使用数据库的分页方法获取会话列表和总数,传入筛选条件
try:
conversations, total_count = self.db_helper.get_filtered_conversations(
page=page,
page_size=page_size,
platforms=platform_list,
message_types=message_type_list,
search_query=search_query,
exclude_ids=exclude_id_list,
exclude_platforms=exclude_platform_list,
)
except Exception as e:
logger.error(f"数据库查询出错: {str(e)}\n{traceback.format_exc()}")
return Response().error(f"数据库查询出错: {str(e)}").__dict__
# 计算总页数
total_pages = (
(total_count + page_size - 1) // page_size if total_count > 0 else 1
)
result = {
"conversations": conversations,
"pagination": {
"page": page,
"page_size": page_size,
"total": total_count,
"total_pages": total_pages,
},
}
return Response().ok(result).__dict__
except Exception as e:
error_msg = f"获取对话列表失败: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
return Response().error(f"获取对话列表失败: {str(e)}").__dict__
async def get_conv_detail(self):
"""获取指定对话详情(通过POST请求)"""
try:
data = await request.get_json()
user_id = data.get("user_id")
cid = data.get("cid")
if not user_id or not cid:
return Response().error("缺少必要参数: user_id 和 cid").__dict__
conversation = self.db_helper.get_conversation_by_user_id(user_id, cid)
if not conversation:
return Response().error("对话不存在").__dict__
return (
Response()
.ok(
{
"user_id": user_id,
"cid": cid,
"title": conversation.title,
"persona_id": conversation.persona_id,
"history": conversation.history,
"created_at": conversation.created_at,
"updated_at": conversation.updated_at,
}
)
.__dict__
)
except Exception as e:
logger.error(f"获取对话详情失败: {str(e)}\n{traceback.format_exc()}")
return Response().error(f"获取对话详情失败: {str(e)}").__dict__
async def upd_conv(self):
"""更新对话信息(标题和角色ID)"""
try:
data = await request.get_json()
user_id = data.get("user_id")
cid = data.get("cid")
title = data.get("title")
persona_id = data.get("persona_id", "")
if not user_id or not cid:
return Response().error("缺少必要参数: user_id 和 cid").__dict__
conversation = self.db_helper.get_conversation_by_user_id(user_id, cid)
if not conversation:
return Response().error("对话不存在").__dict__
if title is not None:
self.db_helper.update_conversation_title(user_id, cid, title)
if persona_id is not None:
self.db_helper.update_conversation_persona_id(user_id, cid, persona_id)
return Response().ok({"message": "对话信息更新成功"}).__dict__
except Exception as e:
logger.error(f"更新对话信息失败: {str(e)}\n{traceback.format_exc()}")
return Response().error(f"更新对话信息失败: {str(e)}").__dict__
async def del_conv(self):
"""删除对话"""
try:
data = await request.get_json()
user_id = data.get("user_id")
cid = data.get("cid")
if not user_id or not cid:
return Response().error("缺少必要参数: user_id 和 cid").__dict__
conversation = self.db_helper.get_conversation_by_user_id(user_id, cid)
if not conversation:
return Response().error("对话不存在").__dict__
self.db_helper.delete_conversation(user_id, cid)
return Response().ok({"message": "对话删除成功"}).__dict__
except Exception as e:
logger.error(f"删除对话失败: {str(e)}\n{traceback.format_exc()}")
return Response().error(f"删除对话失败: {str(e)}").__dict__
async def update_history(self):
"""更新对话历史内容"""
try:
data = await request.get_json()
user_id = data.get("user_id")
cid = data.get("cid")
history = data.get("history")
if not user_id or not cid:
return Response().error("缺少必要参数: user_id 和 cid").__dict__
if history is None:
return Response().error("缺少必要参数: history").__dict__
# 历史记录必须是合法的 JSON 字符串
try:
if isinstance(history, list):
history = json.dumps(history)
else:
# 验证是否为有效的 JSON 字符串
json.loads(history)
except json.JSONDecodeError:
return (
Response().error("history 必须是有效的 JSON 字符串或数组").__dict__
)
conversation = self.db_helper.get_conversation_by_user_id(user_id, cid)
if not conversation:
return Response().error("对话不存在").__dict__
self.db_helper.update_conversation(user_id, cid, history)
return Response().ok({"message": "对话历史更新成功"}).__dict__
except Exception as e:
logger.error(f"更新对话历史失败: {str(e)}\n{traceback.format_exc()}")
return Response().error(f"更新对话历史失败: {str(e)}").__dict__
+33 -17
View File
@@ -1,5 +1,6 @@
import asyncio
from quart import websocket
import json
from quart import make_response
from astrbot.core import logger, LogBroker
from .route import Route, RouteContext
@@ -8,21 +9,36 @@ class LogRoute(Route):
def __init__(self, context: RouteContext, log_broker: LogBroker) -> None:
super().__init__(context)
self.log_broker = log_broker
self.app.add_url_rule(
"/api/live-log", view_func=self.log, methods=["GET"], websocket=True
)
self.app.add_url_rule("/api/live-log", view_func=self.log, methods=["GET"])
async def log(self):
queue = None
try:
queue = self.log_broker.register()
while True:
message = await queue.get()
await websocket.send(message)
except asyncio.CancelledError:
pass
except BaseException as e:
logger.error(f"WebSocket 连接错误: {e}")
finally:
if queue:
self.log_broker.unregister(queue)
async def stream():
queue = None
try:
queue = self.log_broker.register()
while True:
message = await queue.get()
payload = {
"type": "log",
**message # see astrbot/core/log.py
}
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
except asyncio.CancelledError:
pass
except BaseException as e:
logger.error(f"Log SSE 连接错误: {e}")
finally:
if queue:
self.log_broker.unregister(queue)
response = await make_response(
stream(),
{
"Content-Type": "text/event-stream",
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"Transfer-Encoding": "chunked",
},
)
response.timeout = None
return response
+68 -8
View File
@@ -1,5 +1,9 @@
import traceback
import aiohttp
import ssl
import certifi
from .route import Route, Response, RouteContext
from astrbot.core import logger
from quart import request
@@ -11,6 +15,7 @@ from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.star.filter.regex import RegexFilter
from astrbot.core.star.star_handler import EventType
from astrbot.core import DEMO_MODE
class PluginRoute(Route):
@@ -46,6 +51,13 @@ class PluginRoute(Route):
}
async def reload_plugins(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
data = await request.json
plugin_name = data.get("name", None)
try:
@@ -65,9 +77,14 @@ class PluginRoute(Route):
else:
urls = ["https://api.soulter.top/astrbot/plugins"]
# 新增:创建 SSL 上下文,使用 certifi 提供的根证书
ssl_context = ssl.create_default_context(cafile=certifi.where())
connector = aiohttp.TCPConnector(ssl=ssl_context)
for url in urls:
try:
async with aiohttp.ClientSession(trust_env=True) as session:
async with aiohttp.ClientSession(
trust_env=True, connector=connector
) as session:
async with session.get(url) as response:
if response.status == 200:
result = await response.json()
@@ -178,6 +195,13 @@ class PluginRoute(Route):
return handlers
async def install_plugin(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
post_data = await request.json
repo_url = post_data["url"]
@@ -187,30 +211,44 @@ class PluginRoute(Route):
try:
logger.info(f"正在安装插件 {repo_url}")
await self.plugin_manager.install_plugin(repo_url, proxy)
self.core_lifecycle.restart()
plugin_info = await self.plugin_manager.install_plugin(repo_url, proxy)
# self.core_lifecycle.restart()
logger.info(f"安装插件 {repo_url} 成功。")
return Response().ok(None, "安装成功。").__dict__
return Response().ok(plugin_info, "安装成功。").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
async def install_plugin_upload(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
try:
file = await request.files
file = file["file"]
logger.info(f"正在安装用户上传的插件 {file.filename}")
file_path = f"data/temp/{file.filename}"
await file.save(file_path)
await self.plugin_manager.install_plugin_from_file(file_path)
self.core_lifecycle.restart()
plugin_info = await self.plugin_manager.install_plugin_from_file(file_path)
# self.core_lifecycle.restart()
logger.info(f"安装插件 {file.filename} 成功")
return Response().ok(None, "安装成功。").__dict__
return Response().ok(plugin_info, "安装成功。").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
async def uninstall_plugin(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
post_data = await request.json
plugin_name = post_data["name"]
try:
@@ -223,13 +261,21 @@ class PluginRoute(Route):
return Response().error(str(e)).__dict__
async def update_plugin(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
post_data = await request.json
plugin_name = post_data["name"]
proxy: str = post_data.get("proxy", None)
try:
logger.info(f"正在更新插件 {plugin_name}")
await self.plugin_manager.update_plugin(plugin_name, proxy)
self.core_lifecycle.restart()
# self.core_lifecycle.restart()
await self.plugin_manager.reload(plugin_name)
logger.info(f"更新插件 {plugin_name} 成功。")
return Response().ok(None, "更新成功。").__dict__
except Exception as e:
@@ -237,6 +283,13 @@ class PluginRoute(Route):
return Response().error(str(e)).__dict__
async def off_plugin(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
post_data = await request.json
plugin_name = post_data["name"]
try:
@@ -248,6 +301,13 @@ class PluginRoute(Route):
return Response().error(str(e)).__dict__
async def on_plugin(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
post_data = await request.json
plugin_name = post_data["name"]
try:
+29 -4
View File
@@ -1,12 +1,14 @@
import traceback
import psutil
import time
import threading
from .route import Route, Response, RouteContext
from astrbot.core import logger
from quart import request
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.config import VERSION
from astrbot.core import DEMO_MODE
class StatRoute(Route):
@@ -28,7 +30,14 @@ class StatRoute(Route):
self.core_lifecycle = core_lifecycle
async def restart_core(self):
self.core_lifecycle.restart()
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
await self.core_lifecycle.restart()
return Response().ok().__dict__
def format_sec(self, sec: int):
@@ -64,6 +73,20 @@ class StatRoute(Route):
stat_dict = stat.__dict__
cpu_percent = psutil.cpu_percent(interval=0.5)
thread_count = threading.active_count()
# 获取插件信息
plugins = self.core_lifecycle.star_context.get_all_stars()
plugin_info = []
for plugin in plugins:
info = {
"name": getattr(plugin, "name", plugin.__class__.__name__),
"version": getattr(plugin, "version", "1.0.0"),
"is_enabled": True,
}
plugin_info.append(info)
stat_dict.update(
{
"platform": self.db_helper.get_grouped_base_stats(
@@ -73,9 +96,8 @@ class StatRoute(Route):
"platform_count": len(
self.core_lifecycle.platform_manager.get_insts()
),
"plugin_count": len(
self.core_lifecycle.star_context.get_all_stars()
),
"plugin_count": len(plugins),
"plugins": plugin_info,
"message_time_series": message_time_based_stats,
"running": self.format_sec(
int(time.time()) - self.core_lifecycle.start_time
@@ -84,6 +106,9 @@ class StatRoute(Route):
"process": psutil.Process().memory_info().rss >> 20,
"system": psutil.virtual_memory().total >> 20,
},
"cpu_percent": round(cpu_percent, 1),
"thread_count": thread_count,
"start_time": self.core_lifecycle.start_time,
}
)

Some files were not shown because too many files have changed in this diff Show More