Compare commits

...

61 Commits

Author SHA1 Message Date
Soulter 34921e91f0 chore: bump version to 4.11.2 2026-01-08 15:32:37 +08:00
Soulter 6c15592cbb feat(metrics): add metrics tracking for plugin installation events 2026-01-08 15:29:32 +08:00
letr 8c7a4b87d0 fix(webui): maintain international consistency of the 'repo' button (#4358) 2026-01-08 13:34:13 +08:00
Gao Jinzhe 8ff12e3972 fix: on_waiting_llm_request hook did not check message validity (#4349)
* fix:修复waitingllmrequest没有检查消息有效性的问题

* 进行ruff修复
2026-01-07 12:51:00 +08:00
Gao Jinzhe eefa3f2f00 fix: conversation was still saved to the context after stop_event (#4345) 2026-01-07 12:48:13 +08:00
Oscar Shaw 479284a8dd feat(extension): plugin marketplace search supports matching display names. (#4332) 2026-01-06 12:54:11 +08:00
clown145 9322218880 feat: supports to display plugin CHANGELOG.md (#4337)
* feat: optimize plugin update changelog feature, refactor to reuse ReadmeDialog and support independent view entry

* fix: distinguish error state from empty state in ReadmeDialog
2026-01-06 12:53:14 +08:00
Soulter 399062f14d chore: remove wechatpadpro 2026-01-06 11:14:54 +08:00
Soulter de82df3c33 chore: bump version to 4.11.1 2026-01-05 20:22:18 +08:00
Soulter 9896aebfb5 feat: enhance provider source configuration with custom hints and tooltips 2026-01-05 20:20:09 +08:00
Soulter df7653eb99 fix: 部分情况下选择提供商出现”暂无可用提供商的问题“,即使实际上有 2026-01-05 20:01:54 +08:00
Soulter 8e7b44185d chore: bump version to 4.11.0 2026-01-05 18:05:12 +08:00
RC-CHN ef1c66a92e feat(webui): enable Range request support for backup downloads (#4329) 2026-01-05 17:27:03 +08:00
Soulter 241f1c26d3 feat: context compress (#4322)
* feat: context compressor

Co-authored-by: kawayiYokami <289104862@qq.com>

* Add comprehensive tests for ContextManager and ContextTruncator

- Implemented a full test suite for ContextManager covering initialization, message processing, token-based compression, and error handling.
- Added tests for ContextTruncator focusing on message fixing, truncation by turns, dropping oldest turns, and halving.
- Ensured that both test suites validate edge cases and maintain expected behavior with various message types, including system and tool messages.

* feat: add MockProvider for LLM compression tests

* chore: remove lock

* ruff fix

* fix

* perf

* feat: enhance context compression with token tracking and logging

* feat: update logging for context compression trigger

* feat: implement context compression logic with dynamic threshold and token tracking

* fix: reorder import statements for consistency

* feat: add token_usage tracking to conversations and update related processing logic

---------

Co-authored-by: kawayiYokami <289104862@qq.com>
2026-01-05 17:26:10 +08:00
Soulter 3615b7dde2 fix: token usage is always 0 in anthropic source (#4328) 2026-01-05 17:06:12 +08:00
RC-CHN 9bcf9bf2a0 fix(dashboard): complete i18n support for shared components (#4327)
* fix(dashboard): complete i18n support for shared components

- Replace hardcoded Chinese strings with i18n translations in:
  - PluginSetSelector.vue
  - ProviderSelector.vue
  - PersonaSelector.vue
  - KnowledgeBaseSelector.vue
  - T2ITemplateEditor.vue
  - AstrBotConfigV4.vue
  - ConfigItemRenderer.vue
  - ProxySelector.vue
  - ListConfigItem.vue

- Add missing translations to locale files:
  - core/shared.json: personaSelector, t2iTemplateEditor
  - core/common.json: autoDetect
  - features/settings.json: network.proxySelector

- Change prop defaults from hardcoded Chinese to empty strings,
  allowing components to use i18n fallback translations

* fix(i18n): 修正插件选择器标签的翻译格式,添加冒号

* fix(deployment): 添加持久化 machine-id PVC 和初始化容器,优化资源限制
2026-01-05 09:45:28 +08:00
Gao Jinzhe 7f5cc7cf1a feat: add on_waiting_llm_request event hook (#4319)
* 加入on_waiting_llm_request钩子

* ruff check
2026-01-04 16:11:12 +08:00
Oscar Shaw f26867c77d ci(stale): 增加 stale action 每次运行的操作限制 (#4256) 2026-01-04 11:20:03 +08:00
Soulter a14d588b44 docs: add Matrix adapter to community maintained section in multiple languages 2026-01-04 10:15:16 +08:00
Soulter e236402d92 chore: update platform adapter name for clarity 2026-01-04 10:12:25 +08:00
Soulter 454841de10 fix: database is locked error when invoking tts command (#4313)
* fix: database is locked error when invoking /tts command

fixes: #4311

* chore: rm pnpm lockfile

* perf: 减少操作数据库的次数
2026-01-03 19:12:39 +08:00
clown145 442b5403df feat(webui): supports force update plugins (#4293) 2026-01-03 15:30:50 +08:00
Soulter 9db7bf59b8 docs: add new community group contact 2026-01-03 00:48:55 +08:00
雪語 3622504021 fix: retry failed due to a mismatch in the msg.id data type of a WeChat Official Account (#4292)
问题描述:
- 控制台显示正常发送消息,但公众号未收到
- 处理时间 > 5秒的消息几乎总是失败(如 AI 图片生成)
- 短消息(<5秒)正常工作

根本原因:
msg.id 是整数类型,但字典 key 使用字符串类型,导致类型不匹配。
检查时整数无法匹配字符串 key,导致每次都创建新的 future,
微信重试时无法重用,最终导致响应失败。

修复内容:
将 msg.id 转换为字符串后再检查字典
  if str(msg.id) in self.wexin_event_workers:

影响范围:
- 修复了微信重试时无法正确重用 future 的问题
- AI 图片生成、长文本生成等耗时操作现在可以正常工作
- 仅影响微信公众号适配器,其他平台不受影响

Fixes #1679
2026-01-02 22:16:04 +08:00
Soulter fc42db40ce chore: bump version to 4.10.6 2026-01-02 12:14:59 +08:00
Soulter e413a002c1 perf: list view mode toggle with localStorage support in ExtensionPage (#4288)
closes: #4253
2026-01-02 11:59:41 +08:00
tjc66666666 6437d759a3 fix: reasoning content inject for openai api (#4284) 2026-01-02 01:09:28 +08:00
Soulter c758b2d888 feat: use shell globbing to match umop config router (#4270)
* feat: use shell globbing to match umop config router

* rf

* fix: use fnmatchcase for case-sensitive matching in UmopConfigRouter
2025-12-31 23:10:12 +08:00
Soulter 510290fe0e chore: bump version to 4.10.5 2025-12-31 17:58:28 +08:00
Soulter c61d62edb6 fix: handle null item-meta in ConfigItemRenderer (#4269)
fixes: #4268
2025-12-31 17:55:49 +08:00
Soulter 45bce6fe76 chore: bump version to 4.10.4 2025-12-31 12:50:37 +08:00
Soulter f156adddf8 feat: enhance configuration editor with template schema support and UI improvements (#4267)
- Added support for template schemas in the configuration editor, allowing users to define and manage additional parameters like temperature, top_p, and max_tokens.
- Improved UI components in ProviderModelsPanel and ObjectEditor for better user interaction, including new configuration buttons and enhanced input handling.
- Updated localization files to include new configuration options.
2025-12-31 12:19:29 +08:00
Soulter b5a4b80c36 perf: Add list item add button (#4259)
fixes: #4254
2025-12-30 15:27:17 +08:00
Soulter 792fb69d6d perf: allow zero chunk overlap in recursive chunker (#4258)
* Allow zero chunk overlap

* Validate recursive chunking bounds
2025-12-30 15:23:05 +08:00
Oscar Shaw 300a73ace0 fix(#4188): terminate the same plugin when install the plugin via file (#4250)
* fix(#4188): 从文件安装插件时先终止并解绑已存在的同名插件

* feat(star): 优化从文件安装插件的处理同名冲突逻辑,增加边缘检查
2025-12-30 13:43:44 +08:00
Oscar Shaw a5b9de3695 Update stale.yml 2025-12-30 11:10:21 +08:00
fluidcat 90142bcafe fix: ensure close aiodocker.Docker() (#4251)
* fix: ensure close aiodocker.Docker()

* fix: code formatted
2025-12-30 00:24:29 +08:00
Misaka Mikoto 79d0487c03 feat: add template_list config type to support multiple repeated core/plugin config sets (#4208)
* feat: 添加模板列表配置支持,包含验证和编辑功能

* refactor(dashboard): extract ConfigItemRenderer to eliminate code duplication

- Create ConfigItemRenderer.vue to centralize rendering logic for various config types (string, int, bool, selectors, etc.)
- Refactor TemplateListEditor.vue to use the new renderer for entry fields
- Refactor AstrBotConfig.vue and AstrBotConfigV4.vue to simplify metadata-driven rendering
- Resolve circular dependency by decoupling TemplateListEditor from the base renderer

* ruff format

* refactor: improve config validation and fix unidirection data flow

- Frontend: Fix one-way data flow in TemplateListEditor.vue by cloning entries before applying defaults and emitting updates instead of in-place modification.
- Frontend: Remove unused TemplateListEditor import in ConfigItemRenderer.vue.
- Backend: Refactor validate_config in config.py by extracting _expect_type and _validate_template_list helpers to reduce nesting and complexity.
2025-12-30 00:16:24 +08:00
akuuma 4f15102e79 perf(satori): increase websocket max message size to 10MB (#4238)
* perf(satori): increase websocket max message size to 10MB

Add max_size parameter to websocket connection to handle larger messages
and prevent connection drops when receiving large payloads from Satori platform.

* chore: ruff format

---------

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-12-29 23:59:55 +08:00
Oscar Shaw ef1feb639c fix(utils): optimize pip install output decoding for cross-platform encoding compatibility (#4249) 2025-12-29 23:54:11 +08:00
Oscar Shaw 1039a4f864 chore: update stale issue workflow to target only 'bug' labeled issues and adjust inactivity handling (#4252) 2025-12-29 23:49:40 +08:00
Soulter 66e2f49c11 perf: support extended thinking for Anthropic, DeepSeek reasoning mode, and Gemini text part thought signatures to improve multi-turn reasoning performance. (#4240)
* perf: support extended thinking for Anthropic, DeepSeek reasoning mode, and Gemini text part thought signatures to improve multi-turn reasoning performance.

* chore: remove verbose

* perf

* refactor: remove special tools handling for deepseek-reasoner model in openai source

* fix: improve error handling and logging in InternalAgentSubStage processing

* refactor: remove unused reasoning content from Gemini source processing

* refactor: enhance modality determination logic in useProviderSources

Co-authored-by: kawayiYokami <289104862@qq.com>
2025-12-29 14:22:30 +08:00
fluidcat c5773fe63e feat: add JSON value for custom_extra_body (#4246)
* feat: add JSON value for custom_extra_body

* feat: add invalid format tip
2025-12-29 12:52:10 +08:00
NieiR 4e9ef48af2 fix: handle None values in _extract_usage to prevent TypeError (#4244)
* fix: handle None values in _extract_usage to prevent TypeError

Some LLM providers (especially API proxies) may return None for
prompt_tokens and completion_tokens in the usage response. This
causes a TypeError when attempting arithmetic operations.

Added null checks with fallback to 0 for both prompt_tokens and
completion_tokens before performing calculations.

* refactor: use explicit None check and reuse cached variable

- Use `is None` instead of `or 0` to avoid masking unexpected falsy values
- Reuse `cached` variable for `input_cached` to avoid redundant calculation

* ruff format

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-12-29 12:49:25 +08:00
RC-CHN 9eafd7b44a feat: add features for chunked upload and backup file management to the backup section (#4237)
* feat: 添加分片上传备份文件功能

* feat: 为上传备份文件添加异步并发以提升速度

* feat: 使用浏览器原生下载方式以显示进度条

* feat: 添加从已上传备份列表恢复的功能

* feat: 允许重命名备份文件

* feat: 在后端校验可用备份文件后在前端部分显示备份版本号,添加手动上传提示

* style: format code

* fix: 更新备份部分测试

* fix: 修复浏览器原生下载鉴权问题,通过url传参的方式完成认证

* feat(backup): 改进备份系统的分片上传和下载鉴权

- 修复浏览器原生下载鉴权问题,支持 URL 参数传递 token
- 修复上传会话过期判断,使用 last_activity 避免活跃上传被清理
- 延迟启动后台清理任务,避免 asyncio 事件循环问题
- 统一由后端计算 chunk_size 和 total_chunks,避免前后端不一致
- 更新 generate_unique_filename 文档注释与实际行为一致
- 更新测试用例以验证 origin 字段

修复问题:
- 浏览器下载时显示"需要授权"
- 大文件上传可能因会话过期失败
- __init__ 中 asyncio.create_task 可能失败

* style: format code
2025-12-29 12:30:59 +08:00
Soulter fc61f7ad32 fix: unique session config cannot be applied in non-default astrbot config (#4232)
* fix: unique session config cannot be applied in non-default astrbot config

fixes: #4195

* perf: sesison id
2025-12-28 15:01:43 +08:00
simplify123 f51810997a fix: Xinference STT failed: INVALID (#4231)
* Update xinference_stt_provider.py

* chore: ruff format

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-12-28 14:42:06 +08:00
Soulter fb4baf676f perf: add auto voice emotion for minimax tts (#4228)
* perf: add `auto` voice emotion for minimax tts

* ruff format
2025-12-28 00:34:44 +08:00
Oscar Shaw 71ad974c3c feat: two dashboard persistence optimizations (#4221)
* feat: persist console visibility state in local storage on PlatformPage

* feat: add persistence for sidebar opened items in local storage
2025-12-27 14:06:01 +08:00
Soulter f0fff68947 fix: at sender users not working in dingtalk (#4219)
fixes: #4218
2025-12-27 11:26:39 +08:00
Soulter 3e3599835e chore: bump version to 4.10.3 2025-12-26 22:39:59 +08:00
Soulter 5255388e2d refactor: move builtin stars to astrbot package (#4209)
* refactor: move builtin stars to astrbot package

fixes: #4202

* chore: ruff format

* chore: remove print
2025-12-26 22:31:22 +08:00
Yokami fbdd60b64c feat: add extra user content block support (#4189)
* feat: 多文本块功能

* FIX

* 传递链

* 重命名

* refactor: unify extra_user_content_parts type to ContentPart across providers and update related handling

* claude额外块支持图片模态

* 已经处理过了不用再处理

* feat: enhance image handling in extra content blocks for multiple providers

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-12-26 22:08:20 +08:00
Soulter bd1b0a2836 perf: drop unnecessary none-value fields in tool call loop (#4213) 2025-12-26 21:12:34 +08:00
Soulter 19541d9d07 fix: ensure max_tokens is set and validate tool_calls type in ProviderAnthropic (#4212) 2025-12-26 21:01:05 +08:00
大饼鸡蛋 2a5d574394 fix: failed to initialize FishAudio TTS instance (#4200)
fixes: #4172

* fix: 修复 FishAudio 源的配置加载问题并增强请求鲁棒性

- Fix `KeyError: 'model'``: 适配新版配置结构。
- Add `timeout` support: 防止长文本生成时超时。
- Improve response handling: 使用更标准的 Header 检查方式。

* feat: 使用更安全的类型转换并优化错误信息打印
2025-12-26 20:50:45 +08:00
Soulter f2924fbd1b chore: update readme 2025-12-26 18:04:56 +08:00
Gao Jinzhe 703e208947 fix: handle index out of range error when selecting provider (#4206) 2025-12-26 18:02:43 +08:00
NoctuUFO 9a5cc977c2 fix: fix log loss on SSE reconnect using Last-Event-ID (#4205)
* feat: implement last-event-id handing in log route

* perf: better log handling

* chore: ruff format

* perf: log

* Update ConsoleDisplayer.vue

* Update package.json

* Update ConsoleDisplayer.vue

* Update common.js

* chore: ruff format

* fix: ensure last_event_id is required for log replay

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-12-26 18:01:58 +08:00
RC-CHN aa38fe776a feat: supports data backup (#4105)
* feat: 添加数据迁移功能

* test: 添加迁移相关测试

* feat: 备份插件及相关持久化目录

* fix: 修复版本号比较逻辑,添加相关测试

* fix: 清洗文件名,添加相关测试

* fix: 修复安全文件名测试用例断言

* refactor: 优化代码,为备份模块提取公用常量

* feat: 修改备份版本校验逻辑,允许强制小版本间导入

* fix: 修复备份创建时间读取,修复备份相关i18n

* refactor(backup): 使用 astrbot_path 统一管理备份目录路径

* fix(backup): 清理备份模块中未使用的导入

* refactor(backup): 统一备份路径与参数并移除未用附件目录

- 通过 astrbot_path 动态获取备份/知识库/数据相关路径
- 移除 exporter/importer 未使用的 attachments_dir/data_root 传参
- 更新备份路由与测试用例的构造参数

* fix(dashboard): alias mermaid to dist entry for Vite prebundle

* fix(backup): 放行start-time接口到白名单以处理备份导入后jwt token变化导致无法自动刷新webui的问题

* chore(backup): 统一配置路径以使用动态数据目录

* refactor(backup): 使用 VersionComparator 替代重复的版本比较函数

* style(backup test): format code

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-12-26 15:47:50 +08:00
Soulter 701399c00c docs: update readme xmas 2025-12-24 21:58:04 +08:00
169 changed files with 11575 additions and 3417 deletions
+1 -2
View File
@@ -15,7 +15,6 @@ Always reference these instructions first and fallback to search or bash command
### Running the Application
- Run main application: `uv run main.py` -- starts in ~3 seconds
- Application creates WebUI on http://localhost:6185 (default credentials: `astrbot`/`astrbot`)
- Application loads plugins automatically from `packages/` and `data/plugins/` directories
### Dashboard Build (Vue.js/Node.js)
- **Prerequisites**: Node.js 20+ and npm 10+ required
@@ -35,7 +34,7 @@ Always reference these instructions first and fallback to search or bash command
- **ALWAYS** run `uv run ruff check .` and `uv run ruff format .` before committing changes
### Plugin Development
- Plugins load from `packages/` (built-in) and `data/plugins/` (user-installed)
- Plugins load from `astrbot/builtin_stars/` (built-in) and `data/plugins/` (user-installed)
- Plugin system supports function tools and message handlers
- Key plugins: python_interpreter, web_searcher, astrbot, reminder, session_controller
+52 -15
View File
@@ -1,27 +1,64 @@
# This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time.
# 本工作流用于标记并关闭长期不活跃的 Issue。
# 目前仅针对带 `bug` 标签的 Issue 生效,不会处理 PR。
#
# You can adjust the behavior by modifying this file.
# For more information, see:
# https://github.com/actions/stale
name: Mark stale issues and pull requests
# 文档: https://github.com/actions/stale
name: Mark stale bug issues
on:
schedule:
- cron: '21 23 * * *'
# 每天 UTC 08:30 执行 (北京时间 16:30)
- cron: '30 8 * * *'
workflow_dispatch:
inputs:
dry-run:
description: '仅预览, 不实际执行 (Dry run mode)'
required: false
default: true
type: boolean
jobs:
stale:
runs-on: ubuntu-latest
permissions:
issues: write
pull-requests: write
steps:
- uses: actions/stale@v10
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
stale-issue-message: 'Stale issue message'
stale-pr-message: 'Stale pull request message'
stale-issue-label: 'no-issue-activity'
stale-pr-label: 'no-pr-activity'
- uses: actions/stale@v10
with:
repo-token: ${{ secrets.GITHUB_TOKEN }}
operations-per-run: 200
# 只处理带 bug 标签的 Issue
any-of-labels: 'bug'
# 不处理 PR
days-before-pr-stale: -1
days-before-pr-close: -1
# 不活跃判定与关闭策略: 先标记 stale, 再延迟关闭
days-before-issue-stale: 60
days-before-issue-close: 30
stale-issue-label: 'stale'
stale-issue-message: |
This issue has been automatically marked as **stale** because it has not had any activity.
It will be closed in a certain period of time if no further activity occurs.
If this issue is still relevant, please leave a comment.
---
该 Issue 已较长时间无活动, 已被标记为 `stale`。
如无后续活动, 将在一段时间后自动关闭。
如仍需跟进, 请回复评论。
close-issue-message: |
This issue has been automatically closed due to inactivity.
If the problem still exists, feel free to reopen or create a new issue with updated information.
---
该 Issue 因长期无活动已自动关闭。
如问题仍存在, 欢迎补充复现信息并重新打开或新建 Issue。
remove-stale-when-updated: true
debug-only: ${{ github.event_name == 'workflow_dispatch' && inputs.dry-run }}
+2 -2
View File
@@ -24,9 +24,9 @@ configs/session
configs/config.yaml
cmd_config.json
# Plugins and packages
# Plugins
addons/plugins
packages/python_interpreter/workplace
astrbot/builtin_stars/python_interpreter/workplace
tests/astrbot_plugin_openai
# Dashboard
+2
View File
@@ -132,6 +132,7 @@ uv run main.py
**社区维护**
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Bilibili 私信](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
@@ -208,6 +209,7 @@ pre-commit install
- 5 群:822130018
- 6 群:753075035
- 7 群:743746109
- 8 群:1030353265
- 开发者群:975206796
### Telegram 群组
+1
View File
@@ -134,6 +134,7 @@ Or refer to the official documentation: [Deploy AstrBot from Source](https://ast
**Community Maintained**
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Bilibili Direct Messages](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
+1
View File
@@ -134,6 +134,7 @@ Ou consultez la documentation officielle : [Déployer AstrBot depuis les sources
**Maintenues par la communauté**
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Messages directs Bilibili](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
+1
View File
@@ -134,6 +134,7 @@ uv run main.py
**コミュニティメンテナンス**
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Bilibili ダイレクトメッセージ](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
+1
View File
@@ -134,6 +134,7 @@ uv run main.py
**Поддерживаемые сообществом**
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Личные сообщения Bilibili](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
+1
View File
@@ -134,6 +134,7 @@ uv run main.py
**社群維護**
- [Matrix](https://github.com/stevessr/astrbot_plugin_matrix_adapter)
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Bilibili 私訊](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
+4
View File
@@ -21,6 +21,9 @@ from astrbot.core.star.register import (
from astrbot.core.star.register import register_on_llm_request as on_llm_request
from astrbot.core.star.register import register_on_llm_response as on_llm_response
from astrbot.core.star.register import register_on_platform_loaded as on_platform_loaded
from astrbot.core.star.register import (
register_on_waiting_llm_request as on_waiting_llm_request,
)
from astrbot.core.star.register import register_permission_type as permission_type
from astrbot.core.star.register import (
register_platform_adapter_type as platform_adapter_type,
@@ -46,6 +49,7 @@ __all__ = [
"on_llm_request",
"on_llm_response",
"on_platform_loaded",
"on_waiting_llm_request",
"permission_type",
"platform_adapter_type",
"regex",
@@ -100,16 +100,8 @@ class Main(star.Star):
logger.error(f"ltm: {e}")
@filter.on_llm_response()
async def inject_reasoning(self, event: AstrMessageEvent, resp: LLMResponse):
"""在 LLM 响应后基于配置注入思考过程文本 / 在 LLM 响应后记录对话"""
umo = event.unified_msg_origin
cfg = self.context.get_config(umo).get("provider_settings", {})
show_reasoning = cfg.get("display_reasoning_text", False)
if show_reasoning and resp.reasoning_content:
resp.completion_text = (
f"🤔 思考: {resp.reasoning_content}\n\n{resp.completion_text}"
)
async def record_llm_resp_to_ltm(self, event: AstrMessageEvent, resp: LLMResponse):
"""在 LLM 响应后记录对话"""
if self.ltm and self.ltm_enabled(event):
try:
await self.ltm.after_req_llm(event, resp)
@@ -7,6 +7,7 @@ from astrbot.api import logger, sp, star
from astrbot.api.event import AstrMessageEvent
from astrbot.api.message_components import Image, Reply
from astrbot.api.provider import Provider, ProviderRequest
from astrbot.core.agent.message import TextPart
from astrbot.core.provider.func_tool_manager import ToolSet
@@ -85,7 +86,9 @@ class ProcessLLMRequest:
req.image_urls,
)
if caption:
req.prompt = f"(Image Caption: {caption})\n\n{req.prompt}"
req.extra_user_content_parts.append(
TextPart(text=f"<image_caption>{caption}</image_caption>")
)
req.image_urls = []
except Exception as e:
logger.error(f"处理图片描述失败: {e}")
@@ -129,13 +132,14 @@ class ProcessLLMRequest:
else:
req.prompt = prefix + req.prompt
# 收集系统提醒信息
system_parts = []
# user identifier
if cfg.get("identifier"):
user_id = event.message_obj.sender.user_id
user_nickname = event.message_obj.sender.nickname
req.prompt = (
f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n{req.prompt}"
)
system_parts.append(f"User ID: {user_id}, Nickname: {user_nickname}")
# group name identifier
if cfg.get("group_name_display") and event.message_obj.group_id:
@@ -146,7 +150,7 @@ class ProcessLLMRequest:
return
group_name = event.message_obj.group.group_name
if group_name:
req.system_prompt += f"\nGroup name: {group_name}\n"
system_parts.append(f"Group name: {group_name}")
# time info
if cfg.get("datetime_system_prompt"):
@@ -162,7 +166,7 @@ class ProcessLLMRequest:
current_time = (
datetime.datetime.now().astimezone().strftime("%Y-%m-%d %H:%M (%Z)")
)
req.system_prompt += f"\nCurrent datetime: {current_time}\n"
system_parts.append(f"Current datetime: {current_time}")
img_cap_prov_id: str = cfg.get("default_image_caption_provider_id") or ""
if req.conversation:
@@ -225,10 +229,17 @@ class ProcessLLMRequest:
except BaseException as e:
logger.error(f"处理引用图片失败: {e}")
# 3. 将所有部分组合成文本并直接注入到当前消息
# 3. 将所有部分组合成文本并添加到 extra_user_content_parts
# 确保引用内容被正确的标签包裹
quoted_content = "\n".join(content_parts)
# 确保所有内容都在<Quoted Message>标签内
quoted_text = f"<Quoted Message>\n{quoted_content}\n</Quoted Message>"
req.prompt = f"{quoted_text}\n\n{req.prompt}"
req.extra_user_content_parts.append(TextPart(text=quoted_text))
# 统一包裹所有系统提醒
if system_parts:
system_content = (
"<system_reminder>" + "\n".join(system_parts) + "</system_reminder>"
)
req.extra_user_content_parts.append(TextPart(text=system_content))
@@ -184,7 +184,8 @@ class ProviderCommands:
event.set_result(MessageEventResult().message("请输入序号。"))
return
if idx2 > len(self.context.get_all_tts_providers()) or idx2 < 1:
event.set_result(MessageEventResult().message("无效的序号。"))
event.set_result(MessageEventResult().message("无效的提供商序号。"))
return
provider = self.context.get_all_tts_providers()[idx2 - 1]
id_ = provider.meta().id
await self.context.provider_manager.set_provider(
@@ -198,7 +199,8 @@ class ProviderCommands:
event.set_result(MessageEventResult().message("请输入序号。"))
return
if idx2 > len(self.context.get_all_stt_providers()) or idx2 < 1:
event.set_result(MessageEventResult().message("无效的序号。"))
event.set_result(MessageEventResult().message("无效的提供商序号。"))
return
provider = self.context.get_all_stt_providers()[idx2 - 1]
id_ = provider.meta().id
await self.context.provider_manager.set_provider(
@@ -209,8 +211,8 @@ class ProviderCommands:
event.set_result(MessageEventResult().message(f"成功切换到 {id_}"))
elif isinstance(idx, int):
if idx > len(self.context.get_all_providers()) or idx < 1:
event.set_result(MessageEventResult().message("无效的序号。"))
event.set_result(MessageEventResult().message("无效的提供商序号。"))
return
provider = self.context.get_all_providers()[idx - 1]
id_ = provider.meta().id
await self.context.provider_manager.set_provider(
@@ -14,13 +14,13 @@ class TTSCommand:
async def tts(self, event: AstrMessageEvent):
"""开关文本转语音(会话级别)"""
umo = event.unified_msg_origin
ses_tts = SessionServiceManager.is_tts_enabled_for_session(umo)
ses_tts = await SessionServiceManager.is_tts_enabled_for_session(umo)
cfg = self.context.get_config(umo=umo)
tts_enable = cfg["provider_tts_settings"]["enable"]
# 切换状态
new_status = not ses_tts
SessionServiceManager.set_tts_status_for_session(umo, new_status)
await SessionServiceManager.set_tts_status_for_session(umo, new_status)
status_text = "已开启" if new_status else "已关闭"
@@ -157,9 +157,8 @@ class Main(star.Star):
async def is_docker_available(self) -> bool:
"""Check if docker is available"""
try:
docker = aiodocker.Docker()
await docker.version()
await docker.close()
async with aiodocker.Docker() as docker:
await docker.version()
return True
except BaseException as e:
logger.info(f"检查 Docker 可用性: {e}")
@@ -279,14 +278,14 @@ class Main(star.Star):
@pi.command("repull")
async def pi_repull(self, event: AstrMessageEvent):
"""重新拉取沙箱镜像"""
docker = aiodocker.Docker()
image_name = await self.get_image_name()
try:
await docker.images.get(image_name)
await docker.images.delete(image_name, force=True)
except aiodocker.exceptions.DockerError:
pass
await docker.images.pull(image_name)
async with aiodocker.Docker() as docker:
image_name = await self.get_image_name()
try:
await docker.images.get(image_name)
await docker.images.delete(image_name, force=True)
except aiodocker.exceptions.DockerError:
pass
await docker.images.pull(image_name)
yield event.plain_result("重新拉取沙箱镜像成功。")
@pi.command("file")
@@ -371,137 +370,137 @@ class Main(star.Star):
obs = ""
n = 5
for i in range(n):
if i > 0:
logger.info(f"Try {i + 1}/{n}")
async with aiodocker.Docker() as docker:
for i in range(n):
if i > 0:
logger.info(f"Try {i + 1}/{n}")
PROMPT_ = PROMPT.format(
prompt=plain_text,
extra_input=extra_inputs,
extra_prompt=obs,
)
provider = self.context.get_using_provider()
llm_response = await provider.text_chat(
prompt=PROMPT_,
session_id=f"{event.session_id}_{magic_code}_{i!s}",
)
logger.debug(
"code interpreter llm gened code:" + llm_response.completion_text,
)
# 整理代码并保存
code_clean = await self.tidy_code(llm_response.completion_text)
with open(os.path.join(workplace_path, "exec.py"), "w") as f:
f.write(code_clean)
# 启动容器
docker = aiodocker.Docker()
# 检查有没有image
image_name = await self.get_image_name()
try:
await docker.images.get(image_name)
except aiodocker.exceptions.DockerError:
# 拉取镜像
logger.info(f"未找到沙箱镜像,正在尝试拉取 {image_name}...")
await docker.images.pull(image_name)
yield event.plain_result(
f"使用沙箱执行代码中,请稍等...(尝试次数: {i + 1}/{n})",
)
self.docker_host_astrbot_abs_path = self.config.get(
"docker_host_astrbot_abs_path",
"",
)
if self.docker_host_astrbot_abs_path:
host_shared = os.path.join(
self.docker_host_astrbot_abs_path,
self.shared_path,
PROMPT_ = PROMPT.format(
prompt=plain_text,
extra_input=extra_inputs,
extra_prompt=obs,
)
host_output = os.path.join(
self.docker_host_astrbot_abs_path,
output_path,
)
host_workplace = os.path.join(
self.docker_host_astrbot_abs_path,
workplace_path,
provider = self.context.get_using_provider()
llm_response = await provider.text_chat(
prompt=PROMPT_,
session_id=f"{event.session_id}_{magic_code}_{i!s}",
)
else:
host_shared = os.path.abspath(self.shared_path)
host_output = os.path.abspath(output_path)
host_workplace = os.path.abspath(workplace_path)
logger.debug(
"code interpreter llm gened code:" + llm_response.completion_text,
)
logger.debug(
f"host_shared: {host_shared}, host_output: {host_output}, host_workplace: {host_workplace}",
)
# 整理代码并保存
code_clean = await self.tidy_code(llm_response.completion_text)
with open(os.path.join(workplace_path, "exec.py"), "w") as f:
f.write(code_clean)
container = await docker.containers.run(
{
"Image": image_name,
"Cmd": ["python", "exec.py"],
"Memory": 512 * 1024 * 1024,
"NanoCPUs": 1000000000,
"HostConfig": {
"Binds": [
f"{host_shared}:/astrbot_sandbox/shared:ro",
f"{host_output}:/astrbot_sandbox/output:rw",
f"{host_workplace}:/astrbot_sandbox:rw",
],
},
"Env": [f"MAGIC_CODE={magic_code}"],
"AutoRemove": True,
},
)
# 检查有没有image
image_name = await self.get_image_name()
try:
await docker.images.get(image_name)
except aiodocker.exceptions.DockerError:
# 拉取镜像
logger.info(f"未找到沙箱镜像,正在尝试拉取 {image_name}...")
await docker.images.pull(image_name)
logger.debug(f"Container {container.id} created.")
logs = await self.run_container(container)
yield event.plain_result(
f"使用沙箱执行代码中,请稍等...(尝试次数: {i + 1}/{n})",
)
logger.debug(f"Container {container.id} finished.")
logger.debug(f"Container {container.id} logs: {logs}")
# 发送结果
pattern = r"\[ASTRBOT_(TEXT|IMAGE|FILE)_OUTPUT#\w+\]: (.*)"
ok = False
traceback = ""
for idx, log in enumerate(logs):
match = re.match(pattern, log)
if match:
ok = True
if match.group(1) == "TEXT":
yield event.plain_result(match.group(2))
elif match.group(1) == "IMAGE":
image_path = os.path.join(workplace_path, match.group(2))
logger.debug(f"Sending image: {image_path}")
yield event.image_result(image_path)
elif match.group(1) == "FILE":
file_path = os.path.join(workplace_path, match.group(2))
# logger.debug(f"Sending file: {file_path}")
# file_s3_url = await self.file_upload(file_path)
# logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}")
file_name = os.path.basename(file_path)
chain: list[BaseMessageComponent] = [
File(name=file_name, file=file_path)
]
yield event.set_result(MessageEventResult(chain=chain))
elif "Traceback (most recent call last)" in log or "[Error]: " in log:
traceback = "\n".join(logs[idx:])
if not ok:
if traceback:
obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occurred:\n\n{traceback}\n Need to improve/fix the code."
else:
logger.warning(
f"未从沙箱输出中捕获到合法的输出。沙箱输出日志: {logs}",
self.docker_host_astrbot_abs_path = self.config.get(
"docker_host_astrbot_abs_path",
"",
)
if self.docker_host_astrbot_abs_path:
host_shared = os.path.join(
self.docker_host_astrbot_abs_path,
self.shared_path,
)
break
else:
# 成功了
self.user_file_msg_buffer.pop(event.get_session_id())
return
host_output = os.path.join(
self.docker_host_astrbot_abs_path,
output_path,
)
host_workplace = os.path.join(
self.docker_host_astrbot_abs_path,
workplace_path,
)
else:
host_shared = os.path.abspath(self.shared_path)
host_output = os.path.abspath(output_path)
host_workplace = os.path.abspath(workplace_path)
logger.debug(
f"host_shared: {host_shared}, host_output: {host_output}, host_workplace: {host_workplace}",
)
container = await docker.containers.run(
{
"Image": image_name,
"Cmd": ["python", "exec.py"],
"Memory": 512 * 1024 * 1024,
"NanoCPUs": 1000000000,
"HostConfig": {
"Binds": [
f"{host_shared}:/astrbot_sandbox/shared:ro",
f"{host_output}:/astrbot_sandbox/output:rw",
f"{host_workplace}:/astrbot_sandbox:rw",
],
},
"Env": [f"MAGIC_CODE={magic_code}"],
"AutoRemove": True,
},
)
logger.debug(f"Container {container.id} created.")
logs = await self.run_container(container)
logger.debug(f"Container {container.id} finished.")
logger.debug(f"Container {container.id} logs: {logs}")
# 发送结果
pattern = r"\[ASTRBOT_(TEXT|IMAGE|FILE)_OUTPUT#\w+\]: (.*)"
ok = False
traceback = ""
for idx, log in enumerate(logs):
match = re.match(pattern, log)
if match:
ok = True
if match.group(1) == "TEXT":
yield event.plain_result(match.group(2))
elif match.group(1) == "IMAGE":
image_path = os.path.join(workplace_path, match.group(2))
logger.debug(f"Sending image: {image_path}")
yield event.image_result(image_path)
elif match.group(1) == "FILE":
file_path = os.path.join(workplace_path, match.group(2))
# logger.debug(f"Sending file: {file_path}")
# file_s3_url = await self.file_upload(file_path)
# logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}")
file_name = os.path.basename(file_path)
chain: list[BaseMessageComponent] = [
File(name=file_name, file=file_path)
]
yield event.set_result(MessageEventResult(chain=chain))
elif (
"Traceback (most recent call last)" in log or "[Error]: " in log
):
traceback = "\n".join(logs[idx:])
if not ok:
if traceback:
obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occurred:\n\n{traceback}\n Need to improve/fix the code."
else:
logger.warning(
f"未从沙箱输出中捕获到合法的输出。沙箱输出日志: {logs}",
)
break
else:
# 成功了
self.user_file_msg_buffer.pop(event.get_session_id())
return
yield event.plain_result(
"经过多次尝试后,未从沙箱输出中捕获到合法的输出,请更换问法或者查看日志。",
+1 -1
View File
@@ -1 +1 @@
__version__ = "4.10.2"
__version__ = "4.11.2"
+243
View File
@@ -0,0 +1,243 @@
from typing import TYPE_CHECKING, Protocol, runtime_checkable
from ..message import Message
if TYPE_CHECKING:
from astrbot import logger
else:
try:
from astrbot import logger
except ImportError:
import logging
logger = logging.getLogger("astrbot")
if TYPE_CHECKING:
from astrbot.core.provider.provider import Provider
from ..context.truncator import ContextTruncator
@runtime_checkable
class ContextCompressor(Protocol):
"""
Protocol for context compressors.
Provides an interface for compressing message lists.
"""
def should_compress(
self, messages: list[Message], current_tokens: int, max_tokens: int
) -> bool:
"""Check if compression is needed.
Args:
messages: The message list to evaluate.
current_tokens: The current token count.
max_tokens: The maximum allowed tokens for the model.
Returns:
True if compression is needed, False otherwise.
"""
...
async def __call__(self, messages: list[Message]) -> list[Message]:
"""Compress the message list.
Args:
messages: The original message list.
Returns:
The compressed message list.
"""
...
class TruncateByTurnsCompressor:
"""Truncate by turns compressor implementation.
Truncates the message list by removing older turns.
"""
def __init__(self, truncate_turns: int = 1, compression_threshold: float = 0.82):
"""Initialize the truncate by turns compressor.
Args:
truncate_turns: The number of turns to remove when truncating (default: 1).
compression_threshold: The compression trigger threshold (default: 0.82).
"""
self.truncate_turns = truncate_turns
self.compression_threshold = compression_threshold
def should_compress(
self, messages: list[Message], current_tokens: int, max_tokens: int
) -> bool:
"""Check if compression is needed.
Args:
messages: The message list to evaluate.
current_tokens: The current token count.
max_tokens: The maximum allowed tokens.
Returns:
True if compression is needed, False otherwise.
"""
if max_tokens <= 0 or current_tokens <= 0:
return False
usage_rate = current_tokens / max_tokens
return usage_rate > self.compression_threshold
async def __call__(self, messages: list[Message]) -> list[Message]:
truncator = ContextTruncator()
truncated_messages = truncator.truncate_by_dropping_oldest_turns(
messages,
drop_turns=self.truncate_turns,
)
return truncated_messages
def split_history(
messages: list[Message], keep_recent: int
) -> tuple[list[Message], list[Message], list[Message]]:
"""Split the message list into system messages, messages to summarize, and recent messages.
Ensures that the split point is between complete user-assistant pairs to maintain conversation flow.
Args:
messages: The original message list.
keep_recent: The number of latest messages to keep.
Returns:
tuple: (system_messages, messages_to_summarize, recent_messages)
"""
# keep the system messages
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
if len(non_system_messages) <= keep_recent:
return system_messages, [], non_system_messages
# Find the split point, ensuring recent_messages starts with a user message
# This maintains complete conversation turns
split_index = len(non_system_messages) - keep_recent
# Search backward from split_index to find the first user message
# This ensures recent_messages starts with a user message (complete turn)
while split_index > 0 and non_system_messages[split_index].role != "user":
# TODO: +=1 or -=1 ? calculate by tokens
split_index -= 1
# If we couldn't find a user message, keep all messages as recent
if split_index == 0:
return system_messages, [], non_system_messages
messages_to_summarize = non_system_messages[:split_index]
recent_messages = non_system_messages[split_index:]
return system_messages, messages_to_summarize, recent_messages
class LLMSummaryCompressor:
"""LLM-based summary compressor.
Uses LLM to summarize the old conversation history, keeping the latest messages.
"""
def __init__(
self,
provider: "Provider",
keep_recent: int = 4,
instruction_text: str | None = None,
compression_threshold: float = 0.82,
):
"""Initialize the LLM summary compressor.
Args:
provider: The LLM provider instance.
keep_recent: The number of latest messages to keep (default: 4).
instruction_text: Custom instruction for summary generation.
compression_threshold: The compression trigger threshold (default: 0.82).
"""
self.provider = provider
self.keep_recent = keep_recent
self.compression_threshold = compression_threshold
self.instruction_text = instruction_text or (
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
"1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
"2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
"3. If there was an initial user goal, state it first and describe the current progress/status.\n"
"4. Write the summary in the user's language.\n"
)
def should_compress(
self, messages: list[Message], current_tokens: int, max_tokens: int
) -> bool:
"""Check if compression is needed.
Args:
messages: The message list to evaluate.
current_tokens: The current token count.
max_tokens: The maximum allowed tokens.
Returns:
True if compression is needed, False otherwise.
"""
if max_tokens <= 0 or current_tokens <= 0:
return False
usage_rate = current_tokens / max_tokens
return usage_rate > self.compression_threshold
async def __call__(self, messages: list[Message]) -> list[Message]:
"""Use LLM to generate a summary of the conversation history.
Process:
1. Divide messages: keep the system message and the latest N messages.
2. Send the old messages + the instruction message to the LLM.
3. Reconstruct the message list: [system message, summary message, latest messages].
"""
if len(messages) <= self.keep_recent + 1:
return messages
system_messages, messages_to_summarize, recent_messages = split_history(
messages, self.keep_recent
)
if not messages_to_summarize:
return messages
# build payload
instruction_message = Message(role="user", content=self.instruction_text)
llm_payload = messages_to_summarize + [instruction_message]
# generate summary
try:
response = await self.provider.text_chat(contexts=llm_payload)
summary_content = response.completion_text
except Exception as e:
logger.error(f"Failed to generate summary: {e}")
return messages
# build result
result = []
result.extend(system_messages)
result.append(
Message(
role="user",
content=f"Our previous history conversation summary: {summary_content}",
)
)
result.append(
Message(
role="assistant",
content="Acknowledged the summary of our previous conversation history.",
)
)
result.extend(recent_messages)
return result
+35
View File
@@ -0,0 +1,35 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING
from .compressor import ContextCompressor
from .token_counter import TokenCounter
if TYPE_CHECKING:
from astrbot.core.provider.provider import Provider
@dataclass
class ContextConfig:
"""Context configuration class."""
max_context_tokens: int = 0
"""Maximum number of context tokens. <= 0 means no limit."""
enforce_max_turns: int = -1 # -1 means no limit
"""Maximum number of conversation turns to keep. -1 means no limit. Executed before compression."""
truncate_turns: int = 1
"""Number of conversation turns to discard at once when truncation is triggered.
Two processes will use this value:
1. Enforce max turns truncation.
2. Truncation by turns compression strategy.
"""
llm_compress_instruction: str | None = None
"""Instruction prompt for LLM-based compression."""
llm_compress_keep_recent: int = 0
"""Number of recent messages to keep during LLM-based compression."""
llm_compress_provider: "Provider | None" = None
"""LLM provider used for compression tasks. If None, truncation strategy is used."""
custom_token_counter: TokenCounter | None = None
"""Custom token counting method. If None, the default method is used."""
custom_compressor: ContextCompressor | None = None
"""Custom context compression method. If None, the default method is used."""
+120
View File
@@ -0,0 +1,120 @@
from astrbot import logger
from ..message import Message
from .compressor import LLMSummaryCompressor, TruncateByTurnsCompressor
from .config import ContextConfig
from .token_counter import EstimateTokenCounter
from .truncator import ContextTruncator
class ContextManager:
"""Context compression manager."""
def __init__(
self,
config: ContextConfig,
):
"""Initialize the context manager.
There are two strategies to handle context limit reached:
1. Truncate by turns: remove older messages by turns.
2. LLM-based compression: use LLM to summarize old messages.
Args:
config: The context configuration.
"""
self.config = config
self.token_counter = config.custom_token_counter or EstimateTokenCounter()
self.truncator = ContextTruncator()
if config.custom_compressor:
self.compressor = config.custom_compressor
elif config.llm_compress_provider:
self.compressor = LLMSummaryCompressor(
provider=config.llm_compress_provider,
keep_recent=config.llm_compress_keep_recent,
instruction_text=config.llm_compress_instruction,
)
else:
self.compressor = TruncateByTurnsCompressor(
truncate_turns=config.truncate_turns
)
async def process(
self, messages: list[Message], trusted_token_usage: int = 0
) -> list[Message]:
"""Process the messages.
Args:
messages: The original message list.
Returns:
The processed message list.
"""
try:
result = messages
# 1. 基于轮次的截断 (Enforce max turns)
if self.config.enforce_max_turns != -1:
result = self.truncator.truncate_by_turns(
result,
keep_most_recent_turns=self.config.enforce_max_turns,
drop_turns=self.config.truncate_turns,
)
# 2. 基于 token 的压缩
if self.config.max_context_tokens > 0:
total_tokens = self.token_counter.count_tokens(
result, trusted_token_usage
)
if self.compressor.should_compress(
result, total_tokens, self.config.max_context_tokens
):
result = await self._run_compression(result, total_tokens)
return result
except Exception as e:
logger.error(f"Error during context processing: {e}", exc_info=True)
return messages
async def _run_compression(
self, messages: list[Message], prev_tokens: int
) -> list[Message]:
"""
Compress/truncate the messages.
Args:
messages: The original message list.
prev_tokens: The token count before compression.
Returns:
The compressed/truncated message list.
"""
logger.debug("Compress triggered, starting compression...")
messages = await self.compressor(messages)
# double check
tokens_after_summary = self.token_counter.count_tokens(messages)
# calculate compress rate
compress_rate = (tokens_after_summary / self.config.max_context_tokens) * 100
logger.info(
f"Compress completed."
f" {prev_tokens} -> {tokens_after_summary} tokens,"
f" compression rate: {compress_rate:.2f}%.",
)
# last check
if self.compressor.should_compress(
messages, tokens_after_summary, self.config.max_context_tokens
):
logger.info(
"Context still exceeds max tokens after compression, applying halving truncation..."
)
# still need compress, truncate by half
messages = self.truncator.truncate_by_halving(messages)
return messages
@@ -0,0 +1,64 @@
import json
from typing import Protocol, runtime_checkable
from ..message import Message, TextPart
@runtime_checkable
class TokenCounter(Protocol):
"""
Protocol for token counters.
Provides an interface for counting tokens in message lists.
"""
def count_tokens(
self, messages: list[Message], trusted_token_usage: int = 0
) -> int:
"""Count the total tokens in the message list.
Args:
messages: The message list.
trusted_token_usage: The total token usage that LLM API returned.
For some cases, this value is more accurate.
But some API does not return it, so the value defaults to 0.
Returns:
The total token count.
"""
...
class EstimateTokenCounter:
"""Estimate token counter implementation.
Provides a simple estimation of token count based on character types.
"""
def count_tokens(
self, messages: list[Message], trusted_token_usage: int = 0
) -> int:
if trusted_token_usage > 0:
return trusted_token_usage
total = 0
for msg in messages:
content = msg.content
if isinstance(content, str):
total += self._estimate_tokens(content)
elif isinstance(content, list):
# 处理多模态内容
for part in content:
if isinstance(part, TextPart):
total += self._estimate_tokens(part.text)
# 处理 Tool Calls
if msg.tool_calls:
for tc in msg.tool_calls:
tc_str = json.dumps(tc if isinstance(tc, dict) else tc.model_dump())
total += self._estimate_tokens(tc_str)
return total
def _estimate_tokens(self, text: str) -> int:
chinese_count = len([c for c in text if "\u4e00" <= c <= "\u9fff"])
other_count = len(text) - chinese_count
return int(chinese_count * 0.6 + other_count * 0.3)
+141
View File
@@ -0,0 +1,141 @@
from ..message import Message
class ContextTruncator:
"""Context truncator."""
def fix_messages(self, messages: list[Message]) -> list[Message]:
fixed_messages = []
for message in messages:
if message.role == "tool":
# tool block 前面必须要有 user 和 assistant block
if len(fixed_messages) < 2:
# 这种情况可能是上下文被截断导致的
# 我们直接将之前的上下文都清空
fixed_messages = []
else:
fixed_messages.append(message)
else:
fixed_messages.append(message)
return fixed_messages
def truncate_by_turns(
self,
messages: list[Message],
keep_most_recent_turns: int,
drop_turns: int = 1,
) -> list[Message]:
"""截断上下文列表,确保不超过最大长度。
一个 turn 包含一个 user 消息和一个 assistant 消息。
这个方法会保证截断后的上下文列表符合 OpenAI 的上下文格式。
Args:
messages: 上下文列表
keep_most_recent_turns: 保留最近的对话轮数
drop_turns: 一次性丢弃的对话轮数
Returns:
截断后的上下文列表
"""
if keep_most_recent_turns == -1:
return messages
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
if len(non_system_messages) // 2 <= keep_most_recent_turns:
return messages
num_to_keep = keep_most_recent_turns - drop_turns + 1
if num_to_keep <= 0:
truncated_contexts = []
else:
truncated_contexts = non_system_messages[-num_to_keep * 2 :]
# 找到第一个 role 为 user 的索引,确保上下文格式正确
index = next(
(i for i, item in enumerate(truncated_contexts) if item.role == "user"),
None,
)
if index is not None and index > 0:
truncated_contexts = truncated_contexts[index:]
result = system_messages + truncated_contexts
return self.fix_messages(result)
def truncate_by_dropping_oldest_turns(
self,
messages: list[Message],
drop_turns: int = 1,
) -> list[Message]:
"""丢弃最旧的 N 个对话轮次。"""
if drop_turns <= 0:
return messages
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
if len(non_system_messages) // 2 <= drop_turns:
truncated_non_system = []
else:
truncated_non_system = non_system_messages[drop_turns * 2 :]
index = next(
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
None,
)
if index is not None:
truncated_non_system = truncated_non_system[index:]
elif truncated_non_system:
truncated_non_system = []
result = system_messages + truncated_non_system
return self.fix_messages(result)
def truncate_by_halving(
self,
messages: list[Message],
) -> list[Message]:
"""对半砍策略,删除 50% 的消息"""
if len(messages) <= 2:
return messages
first_non_system = 0
for i, msg in enumerate(messages):
if msg.role != "system":
first_non_system = i
break
system_messages = messages[:first_non_system]
non_system_messages = messages[first_non_system:]
messages_to_delete = len(non_system_messages) // 2
if messages_to_delete == 0:
return messages
truncated_non_system = non_system_messages[messages_to_delete:]
index = next(
(i for i, item in enumerate(truncated_non_system) if item.role == "user"),
None,
)
if index is not None:
truncated_non_system = truncated_non_system[index:]
result = system_messages + truncated_non_system
return self.fix_messages(result)
+32 -1
View File
@@ -12,7 +12,7 @@ class ContentPart(BaseModel):
__content_part_registry: ClassVar[dict[str, type["ContentPart"]]] = {}
type: str
type: Literal["text", "think", "image_url", "audio_url"]
def __init_subclass__(cls, **kwargs: Any) -> None:
super().__init_subclass__(**kwargs)
@@ -63,6 +63,28 @@ class TextPart(ContentPart):
text: str
class ThinkPart(ContentPart):
"""
>>> ThinkPart(think="I think I need to think about this.").model_dump()
{'type': 'think', 'think': 'I think I need to think about this.', 'encrypted': None}
"""
type: str = "think"
think: str
encrypted: str | None = None
"""Encrypted thinking content, or signature."""
def merge_in_place(self, other: Any) -> bool:
if not isinstance(other, ThinkPart):
return False
if self.encrypted:
return False
self.think += other.think
if other.encrypted:
self.encrypted = other.encrypted
return True
class ImageURLPart(ContentPart):
"""
>>> ImageURLPart(image_url="http://example.com/image.jpg").model_dump()
@@ -169,6 +191,15 @@ class Message(BaseModel):
)
return self
@model_serializer(mode="wrap")
def serialize(self, handler):
data = handler(self)
if self.tool_calls is None:
data.pop("tool_calls", None)
if self.tool_call_id is None:
data.pop("tool_call_id", None)
return data
class AssistantMessageSegment(Message):
"""A message segment from the assistant."""
@@ -13,6 +13,7 @@ from mcp.types import (
)
from astrbot import logger
from astrbot.core.agent.message import TextPart, ThinkPart
from astrbot.core.message.components import Json
from astrbot.core.message.message_event_result import (
MessageChain,
@@ -24,6 +25,10 @@ from astrbot.core.provider.entities import (
)
from astrbot.core.provider.provider import Provider
from ..context.compressor import ContextCompressor
from ..context.config import ContextConfig
from ..context.manager import ContextManager
from ..context.token_counter import TokenCounter
from ..hooks import BaseAgentRunHooks
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
from ..response import AgentResponseData, AgentStats
@@ -46,10 +51,47 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
run_context: ContextWrapper[TContext],
tool_executor: BaseFunctionToolExecutor[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
streaming: bool = False,
# enforce max turns, will discard older turns when exceeded BEFORE compression
# -1 means no limit
enforce_max_turns: int = -1,
# llm compressor
llm_compress_instruction: str | None = None,
llm_compress_keep_recent: int = 0,
llm_compress_provider: Provider | None = None,
# truncate by turns compressor
truncate_turns: int = 1,
# customize
custom_token_counter: TokenCounter | None = None,
custom_compressor: ContextCompressor | None = None,
**kwargs: T.Any,
) -> None:
self.req = request
self.streaming = kwargs.get("streaming", False)
self.streaming = streaming
self.enforce_max_turns = enforce_max_turns
self.llm_compress_instruction = llm_compress_instruction
self.llm_compress_keep_recent = llm_compress_keep_recent
self.llm_compress_provider = llm_compress_provider
self.truncate_turns = truncate_turns
self.custom_token_counter = custom_token_counter
self.custom_compressor = custom_compressor
# we will do compress when:
# 1. before requesting LLM
# TODO: 2. after LLM output a tool call
self.context_config = ContextConfig(
# <=0 will never do compress
max_context_tokens=provider.provider_config.get("max_context_tokens", 0),
# enforce max turns before compression
enforce_max_turns=self.enforce_max_turns,
truncate_turns=self.truncate_turns,
llm_compress_instruction=self.llm_compress_instruction,
llm_compress_keep_recent=self.llm_compress_keep_recent,
llm_compress_provider=self.llm_compress_provider,
custom_token_counter=self.custom_token_counter,
custom_compressor=self.custom_compressor,
)
self.context_manager = ContextManager(self.context_config)
self.provider = provider
self.final_llm_resp = None
self._state = AgentState.IDLE
@@ -77,10 +119,11 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
"""Yields chunks *and* a final LLMResponse."""
payload = {
"contexts": self.run_context.messages,
"contexts": self.run_context.messages, # list[Message]
"func_tool": self.req.func_tool,
"model": self.req.model, # NOTE: in fact, this arg is None in most cases
"session_id": self.req.session_id,
"extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart]
}
if self.streaming:
@@ -108,6 +151,12 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self._transition_state(AgentState.RUNNING)
llm_resp_result = None
# do truncate and compress
token_usage = self.req.conversation.token_usage if self.req.conversation else 0
self.run_context.messages = await self.context_manager.process(
self.run_context.messages, trusted_token_usage=token_usage
)
async for llm_response in self._iter_llm_responses():
if llm_response.is_chunk:
# update ttft
@@ -168,13 +217,20 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self.final_llm_resp = llm_resp
self._transition_state(AgentState.DONE)
self.stats.end_time = time.time()
# record the final assistant message
self.run_context.messages.append(
Message(
role="assistant",
content=llm_resp.completion_text or "*No response*",
),
)
parts = []
if llm_resp.reasoning_content or llm_resp.reasoning_signature:
parts.append(
ThinkPart(
think=llm_resp.reasoning_content,
encrypted=llm_resp.reasoning_signature,
)
)
parts.append(TextPart(text=llm_resp.completion_text or "*No response*"))
self.run_context.messages.append(Message(role="assistant", content=parts))
# call the on_agent_done hook
try:
await self.agent_hooks.on_agent_done(self.run_context, llm_resp)
except Exception as e:
@@ -213,10 +269,19 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
data=AgentResponseData(chain=result),
)
# 将结果添加到上下文中
parts = []
if llm_resp.reasoning_content or llm_resp.reasoning_signature:
parts.append(
ThinkPart(
think=llm_resp.reasoning_content,
encrypted=llm_resp.reasoning_signature,
)
)
parts.append(TextPart(text=llm_resp.completion_text or "*No response*"))
tool_calls_result = ToolCallsResult(
tool_calls_info=AssistantMessageSegment(
tool_calls=llm_resp.to_openai_to_calls_model(),
content=llm_resp.completion_text,
content=parts,
),
tool_calls_result=tool_call_result_blocks,
)
+6
View File
@@ -13,6 +13,12 @@ from astrbot.core.star.star_handler import EventType
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
async def on_agent_done(self, run_context, llm_response):
# 执行事件钩子
if llm_response and llm_response.reasoning_content:
# we will use this in result_decorate stage to inject reasoning content to chain
run_context.context.event.set_extra(
"_llm_reasoning_content", llm_response.reasoning_content
)
await call_event_hook(
run_context.context.event,
EventType.OnLLMResponseEvent,
+26
View File
@@ -0,0 +1,26 @@
"""AstrBot 备份与恢复模块
提供数据导出和导入功能,支持用户在服务器迁移时一键备份和恢复所有数据。
"""
# 从 constants 模块导入共享常量
from .constants import (
BACKUP_MANIFEST_VERSION,
KB_METADATA_MODELS,
MAIN_DB_MODELS,
get_backup_directories,
)
# 导入导出器和导入器
from .exporter import AstrBotExporter
from .importer import AstrBotImporter, ImportPreCheckResult
__all__ = [
"AstrBotExporter",
"AstrBotImporter",
"ImportPreCheckResult",
"MAIN_DB_MODELS",
"KB_METADATA_MODELS",
"get_backup_directories",
"BACKUP_MANIFEST_VERSION",
]
+77
View File
@@ -0,0 +1,77 @@
"""AstrBot 备份模块共享常量
此文件定义了导出器和导入器共享的常量,确保两端配置一致。
"""
from sqlmodel import SQLModel
from astrbot.core.db.po import (
Attachment,
CommandConfig,
CommandConflict,
ConversationV2,
Persona,
PlatformMessageHistory,
PlatformSession,
PlatformStat,
Preference,
)
from astrbot.core.knowledge_base.models import (
KBDocument,
KBMedia,
KnowledgeBase,
)
from astrbot.core.utils.astrbot_path import (
get_astrbot_config_path,
get_astrbot_plugin_data_path,
get_astrbot_plugin_path,
get_astrbot_t2i_templates_path,
get_astrbot_temp_path,
get_astrbot_webchat_path,
)
# ============================================================
# 共享常量 - 确保导出和导入端配置一致
# ============================================================
# 主数据库模型类映射
MAIN_DB_MODELS: dict[str, type[SQLModel]] = {
"platform_stats": PlatformStat,
"conversations": ConversationV2,
"personas": Persona,
"preferences": Preference,
"platform_message_history": PlatformMessageHistory,
"platform_sessions": PlatformSession,
"attachments": Attachment,
"command_configs": CommandConfig,
"command_conflicts": CommandConflict,
}
# 知识库元数据模型类映射
KB_METADATA_MODELS: dict[str, type[SQLModel]] = {
"knowledge_bases": KnowledgeBase,
"kb_documents": KBDocument,
"kb_media": KBMedia,
}
def get_backup_directories() -> dict[str, str]:
"""获取需要备份的目录列表
使用 astrbot_path 模块动态获取路径,支持通过环境变量 ASTRBOT_ROOT 自定义根目录。
Returns:
dict: 键为备份文件中的目录名称,值为目录的绝对路径
"""
return {
"plugins": get_astrbot_plugin_path(), # 插件本体
"plugin_data": get_astrbot_plugin_data_path(), # 插件数据
"config": get_astrbot_config_path(), # 配置目录
"t2i_templates": get_astrbot_t2i_templates_path(), # T2I 模板
"webchat": get_astrbot_webchat_path(), # WebChat 数据
"temp": get_astrbot_temp_path(), # 临时文件
}
# 备份清单版本号
BACKUP_MANIFEST_VERSION = "1.1"
+477
View File
@@ -0,0 +1,477 @@
"""AstrBot 数据导出器
负责将所有数据导出为 ZIP 备份文件。
导出格式为 JSON,这是数据库无关的方案,支持未来向 MySQL/PostgreSQL 迁移。
"""
import hashlib
import json
import os
import zipfile
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any
from sqlalchemy import select
from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.astrbot_path import (
get_astrbot_backups_path,
get_astrbot_data_path,
)
# 从共享常量模块导入
from .constants import (
BACKUP_MANIFEST_VERSION,
KB_METADATA_MODELS,
MAIN_DB_MODELS,
get_backup_directories,
)
if TYPE_CHECKING:
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
class AstrBotExporter:
"""AstrBot 数据导出器
导出内容:
- 主数据库所有表(data/data_v4.db
- 知识库元数据(data/knowledge_base/kb.db
- 每个知识库的向量文档数据
- 配置文件(data/cmd_config.json
- 附件文件
- 知识库多媒体文件
- 插件目录(data/plugins
- 插件数据目录(data/plugin_data
- 配置目录(data/config
- T2I 模板目录(data/t2i_templates
- WebChat 数据目录(data/webchat
- 临时文件目录(data/temp
"""
def __init__(
self,
main_db: BaseDatabase,
kb_manager: "KnowledgeBaseManager | None" = None,
config_path: str = CMD_CONFIG_FILE_PATH,
):
self.main_db = main_db
self.kb_manager = kb_manager
self.config_path = config_path
self._checksums: dict[str, str] = {}
async def export_all(
self,
output_dir: str | None = None,
progress_callback: Any | None = None,
) -> str:
"""导出所有数据到 ZIP 文件
Args:
output_dir: 输出目录
progress_callback: 进度回调函数,接收参数 (stage, current, total, message)
Returns:
str: 生成的 ZIP 文件路径
"""
if output_dir is None:
output_dir = get_astrbot_backups_path()
# 确保输出目录存在
Path(output_dir).mkdir(parents=True, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
zip_filename = f"astrbot_backup_{timestamp}.zip"
zip_path = os.path.join(output_dir, zip_filename)
logger.info(f"开始导出备份到 {zip_path}")
try:
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
# 1. 导出主数据库
if progress_callback:
await progress_callback("main_db", 0, 100, "正在导出主数据库...")
main_data = await self._export_main_database()
main_db_json = json.dumps(
main_data, ensure_ascii=False, indent=2, default=str
)
zf.writestr("databases/main_db.json", main_db_json)
self._add_checksum("databases/main_db.json", main_db_json)
if progress_callback:
await progress_callback("main_db", 100, 100, "主数据库导出完成")
# 2. 导出知识库数据
kb_meta_data: dict[str, Any] = {
"knowledge_bases": [],
"kb_documents": [],
"kb_media": [],
}
if self.kb_manager:
if progress_callback:
await progress_callback(
"kb_metadata", 0, 100, "正在导出知识库元数据..."
)
kb_meta_data = await self._export_kb_metadata()
kb_meta_json = json.dumps(
kb_meta_data, ensure_ascii=False, indent=2, default=str
)
zf.writestr("databases/kb_metadata.json", kb_meta_json)
self._add_checksum("databases/kb_metadata.json", kb_meta_json)
if progress_callback:
await progress_callback(
"kb_metadata", 100, 100, "知识库元数据导出完成"
)
# 导出每个知识库的文档数据
kb_insts = self.kb_manager.kb_insts
total_kbs = len(kb_insts)
for idx, (kb_id, kb_helper) in enumerate(kb_insts.items()):
if progress_callback:
await progress_callback(
"kb_documents",
idx,
total_kbs,
f"正在导出知识库 {kb_helper.kb.kb_name} 的文档数据...",
)
doc_data = await self._export_kb_documents(kb_helper)
doc_json = json.dumps(
doc_data, ensure_ascii=False, indent=2, default=str
)
doc_path = f"databases/kb_{kb_id}/documents.json"
zf.writestr(doc_path, doc_json)
self._add_checksum(doc_path, doc_json)
# 导出 FAISS 索引文件
await self._export_faiss_index(zf, kb_helper, kb_id)
# 导出知识库多媒体文件
await self._export_kb_media_files(zf, kb_helper, kb_id)
if progress_callback:
await progress_callback(
"kb_documents", total_kbs, total_kbs, "知识库文档导出完成"
)
# 3. 导出配置文件
if progress_callback:
await progress_callback("config", 0, 100, "正在导出配置文件...")
if os.path.exists(self.config_path):
with open(self.config_path, encoding="utf-8") as f:
config_content = f.read()
zf.writestr("config/cmd_config.json", config_content)
self._add_checksum("config/cmd_config.json", config_content)
if progress_callback:
await progress_callback("config", 100, 100, "配置文件导出完成")
# 4. 导出附件文件
if progress_callback:
await progress_callback("attachments", 0, 100, "正在导出附件...")
await self._export_attachments(zf, main_data.get("attachments", []))
if progress_callback:
await progress_callback("attachments", 100, 100, "附件导出完成")
# 5. 导出插件和其他目录
if progress_callback:
await progress_callback(
"directories", 0, 100, "正在导出插件和数据目录..."
)
dir_stats = await self._export_directories(zf)
if progress_callback:
await progress_callback("directories", 100, 100, "目录导出完成")
# 6. 生成 manifest
if progress_callback:
await progress_callback("manifest", 0, 100, "正在生成清单...")
manifest = self._generate_manifest(main_data, kb_meta_data, dir_stats)
manifest_json = json.dumps(manifest, ensure_ascii=False, indent=2)
zf.writestr("manifest.json", manifest_json)
if progress_callback:
await progress_callback("manifest", 100, 100, "清单生成完成")
logger.info(f"备份导出完成: {zip_path}")
return zip_path
except Exception as e:
logger.error(f"备份导出失败: {e}")
# 清理失败的文件
if os.path.exists(zip_path):
os.remove(zip_path)
raise
async def _export_main_database(self) -> dict[str, list[dict]]:
"""导出主数据库所有表"""
export_data: dict[str, list[dict]] = {}
async with self.main_db.get_db() as session:
for table_name, model_class in MAIN_DB_MODELS.items():
try:
result = await session.execute(select(model_class))
records = result.scalars().all()
export_data[table_name] = [
self._model_to_dict(record) for record in records
]
logger.debug(
f"导出表 {table_name}: {len(export_data[table_name])} 条记录"
)
except Exception as e:
logger.warning(f"导出表 {table_name} 失败: {e}")
export_data[table_name] = []
return export_data
async def _export_kb_metadata(self) -> dict[str, list[dict]]:
"""导出知识库元数据库"""
if not self.kb_manager:
return {"knowledge_bases": [], "kb_documents": [], "kb_media": []}
export_data: dict[str, list[dict]] = {}
async with self.kb_manager.kb_db.get_db() as session:
for table_name, model_class in KB_METADATA_MODELS.items():
try:
result = await session.execute(select(model_class))
records = result.scalars().all()
export_data[table_name] = [
self._model_to_dict(record) for record in records
]
logger.debug(
f"导出知识库表 {table_name}: {len(export_data[table_name])} 条记录"
)
except Exception as e:
logger.warning(f"导出知识库表 {table_name} 失败: {e}")
export_data[table_name] = []
return export_data
async def _export_kb_documents(self, kb_helper: Any) -> dict[str, Any]:
"""导出知识库的文档块数据"""
try:
from astrbot.core.db.vec_db.faiss_impl.vec_db import FaissVecDB
vec_db: FaissVecDB = kb_helper.vec_db
if not vec_db or not vec_db.document_storage:
return {"documents": []}
# 获取所有文档
docs = await vec_db.document_storage.get_documents(
metadata_filters={},
offset=0,
limit=None, # 获取全部
)
return {"documents": docs}
except Exception as e:
logger.warning(f"导出知识库文档失败: {e}")
return {"documents": []}
async def _export_faiss_index(
self,
zf: zipfile.ZipFile,
kb_helper: Any,
kb_id: str,
) -> None:
"""导出 FAISS 索引文件"""
try:
index_path = kb_helper.kb_dir / "index.faiss"
if index_path.exists():
archive_path = f"databases/kb_{kb_id}/index.faiss"
zf.write(str(index_path), archive_path)
logger.debug(f"导出 FAISS 索引: {archive_path}")
except Exception as e:
logger.warning(f"导出 FAISS 索引失败: {e}")
async def _export_kb_media_files(
self, zf: zipfile.ZipFile, kb_helper: Any, kb_id: str
) -> None:
"""导出知识库的多媒体文件"""
try:
media_dir = kb_helper.kb_medias_dir
if not media_dir.exists():
return
for root, _, files in os.walk(media_dir):
for file in files:
file_path = Path(root) / file
# 计算相对路径
rel_path = file_path.relative_to(kb_helper.kb_dir)
archive_path = f"files/kb_media/{kb_id}/{rel_path}"
zf.write(str(file_path), archive_path)
except Exception as e:
logger.warning(f"导出知识库媒体文件失败: {e}")
async def _export_directories(
self, zf: zipfile.ZipFile
) -> dict[str, dict[str, int]]:
"""导出插件和其他数据目录
Returns:
dict: 每个目录的统计信息 {dir_name: {"files": count, "size": bytes}}
"""
stats: dict[str, dict[str, int]] = {}
backup_directories = get_backup_directories()
for dir_name, dir_path in backup_directories.items():
full_path = Path(dir_path)
if not full_path.exists():
logger.debug(f"目录不存在,跳过: {full_path}")
continue
file_count = 0
total_size = 0
try:
for root, dirs, files in os.walk(full_path):
# 跳过 __pycache__ 目录
dirs[:] = [d for d in dirs if d != "__pycache__"]
for file in files:
# 跳过 .pyc 文件
if file.endswith(".pyc"):
continue
file_path = Path(root) / file
try:
# 计算相对路径
rel_path = file_path.relative_to(full_path)
archive_path = f"directories/{dir_name}/{rel_path}"
zf.write(str(file_path), archive_path)
file_count += 1
total_size += file_path.stat().st_size
except Exception as e:
logger.warning(f"导出文件 {file_path} 失败: {e}")
stats[dir_name] = {"files": file_count, "size": total_size}
logger.debug(
f"导出目录 {dir_name}: {file_count} 个文件, {total_size} 字节"
)
except Exception as e:
logger.warning(f"导出目录 {dir_path} 失败: {e}")
stats[dir_name] = {"files": 0, "size": 0}
return stats
async def _export_attachments(
self, zf: zipfile.ZipFile, attachments: list[dict]
) -> None:
"""导出附件文件"""
for attachment in attachments:
try:
file_path = attachment.get("path", "")
if file_path and os.path.exists(file_path):
# 使用 attachment_id 作为文件名
attachment_id = attachment.get("attachment_id", "")
ext = os.path.splitext(file_path)[1]
archive_path = f"files/attachments/{attachment_id}{ext}"
zf.write(file_path, archive_path)
except Exception as e:
logger.warning(f"导出附件失败: {e}")
def _model_to_dict(self, record: Any) -> dict:
"""将 SQLModel 实例转换为字典
这是数据库无关的序列化方式,支持未来迁移到其他数据库。
"""
# 使用 SQLModel 内置的 model_dump 方法(如果可用)
if hasattr(record, "model_dump"):
data = record.model_dump(mode="python")
# 处理 datetime 类型
for key, value in data.items():
if isinstance(value, datetime):
data[key] = value.isoformat()
return data
# 回退到手动提取
data = {}
# 使用 inspect 获取表信息
from sqlalchemy import inspect as sa_inspect
mapper = sa_inspect(record.__class__)
for column in mapper.columns:
value = getattr(record, column.name)
# 处理 datetime 类型 - 统一转为 ISO 格式字符串
if isinstance(value, datetime):
value = value.isoformat()
data[column.name] = value
return data
def _add_checksum(self, path: str, content: str | bytes) -> None:
"""计算并添加文件校验和"""
if isinstance(content, str):
content = content.encode("utf-8")
checksum = hashlib.sha256(content).hexdigest()
self._checksums[path] = f"sha256:{checksum}"
def _generate_manifest(
self,
main_data: dict[str, list[dict]],
kb_meta_data: dict[str, list[dict]],
dir_stats: dict[str, dict[str, int]] | None = None,
) -> dict:
"""生成备份清单"""
if dir_stats is None:
dir_stats = {}
# 收集知识库 ID
kb_document_tables = {}
if self.kb_manager:
for kb_id in self.kb_manager.kb_insts.keys():
kb_document_tables[kb_id] = "documents"
# 收集附件文件列表
attachment_files = []
for attachment in main_data.get("attachments", []):
attachment_id = attachment.get("attachment_id", "")
path = attachment.get("path", "")
if attachment_id and path:
ext = os.path.splitext(path)[1]
attachment_files.append(f"{attachment_id}{ext}")
# 收集知识库媒体文件
kb_media_files: dict[str, list[str]] = {}
if self.kb_manager:
for kb_id, kb_helper in self.kb_manager.kb_insts.items():
media_files: list[str] = []
media_dir = kb_helper.kb_medias_dir
if media_dir.exists():
for root, _, files in os.walk(media_dir):
for file in files:
media_files.append(file)
if media_files:
kb_media_files[kb_id] = media_files
manifest = {
"version": BACKUP_MANIFEST_VERSION,
"astrbot_version": VERSION,
"exported_at": datetime.now(timezone.utc).isoformat(),
"origin": "exported", # 标记备份来源:exported=本实例导出, uploaded=用户上传
"schema_version": {
"main_db": "v4",
"kb_db": "v1",
},
"tables": {
"main_db": list(main_data.keys()),
"kb_metadata": list(kb_meta_data.keys()),
"kb_documents": kb_document_tables,
},
"files": {
"attachments": attachment_files,
"kb_media": kb_media_files,
},
"directories": list(dir_stats.keys()),
"checksums": self._checksums,
"statistics": {
"main_db": {
table: len(records) for table, records in main_data.items()
},
"kb_metadata": {
table: len(records) for table, records in kb_meta_data.items()
},
"directories": dir_stats,
},
}
return manifest
+761
View File
@@ -0,0 +1,761 @@
"""AstrBot 数据导入器
负责从 ZIP 备份文件恢复所有数据。
导入时进行版本校验:
- 主版本(前两位)不同时直接拒绝导入
- 小版本(第三位)不同时提示警告,用户可选择强制导入
- 版本匹配时也需要用户确认
"""
import json
import os
import shutil
import zipfile
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Any
from sqlalchemy import delete
from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.astrbot_path import (
get_astrbot_data_path,
get_astrbot_knowledge_base_path,
)
from astrbot.core.utils.version_comparator import VersionComparator
# 从共享常量模块导入
from .constants import (
KB_METADATA_MODELS,
MAIN_DB_MODELS,
get_backup_directories,
)
if TYPE_CHECKING:
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
def _get_major_version(version_str: str) -> str:
"""提取版本的主版本部分(前两位)
Args:
version_str: 版本字符串,如 "4.9.1", "4.10.0-beta"
Returns:
主版本字符串,如 "4.9", "4.10"
"""
if not version_str:
return "0.0"
# 移除 v 前缀和预发布标签
version = version_str.lower().replace("v", "").split("-")[0].split("+")[0]
parts = [p for p in version.split(".") if p] # 过滤空字符串
if len(parts) >= 2:
return f"{parts[0]}.{parts[1]}"
elif len(parts) == 1 and parts[0]:
return f"{parts[0]}.0"
return "0.0"
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
KB_PATH = get_astrbot_knowledge_base_path()
@dataclass
class ImportPreCheckResult:
"""导入预检查结果
用于在实际导入前检查备份文件的版本兼容性,
并返回确认信息让用户决定是否继续导入。
"""
# 检查是否通过(文件有效且版本可导入)
valid: bool = False
# 是否可以导入(版本兼容)
can_import: bool = False
# 版本状态: match(完全匹配), minor_diff(小版本差异), major_diff(主版本不同,拒绝)
version_status: str = ""
# 备份文件中的 AstrBot 版本
backup_version: str = ""
# 当前运行的 AstrBot 版本
current_version: str = VERSION
# 备份创建时间
backup_time: str = ""
# 确认消息(显示给用户)
confirm_message: str = ""
# 警告消息列表
warnings: list[str] = field(default_factory=list)
# 错误消息(如果检查失败)
error: str = ""
# 备份包含的内容摘要
backup_summary: dict = field(default_factory=dict)
def to_dict(self) -> dict:
return {
"valid": self.valid,
"can_import": self.can_import,
"version_status": self.version_status,
"backup_version": self.backup_version,
"current_version": self.current_version,
"backup_time": self.backup_time,
"confirm_message": self.confirm_message,
"warnings": self.warnings,
"error": self.error,
"backup_summary": self.backup_summary,
}
class ImportResult:
"""导入结果"""
def __init__(self):
self.success = True
self.imported_tables: dict[str, int] = {}
self.imported_files: dict[str, int] = {}
self.imported_directories: dict[str, int] = {}
self.warnings: list[str] = []
self.errors: list[str] = []
def add_warning(self, msg: str) -> None:
self.warnings.append(msg)
logger.warning(msg)
def add_error(self, msg: str) -> None:
self.errors.append(msg)
self.success = False
logger.error(msg)
def to_dict(self) -> dict:
return {
"success": self.success,
"imported_tables": self.imported_tables,
"imported_files": self.imported_files,
"imported_directories": self.imported_directories,
"warnings": self.warnings,
"errors": self.errors,
}
class AstrBotImporter:
"""AstrBot 数据导入器
导入备份文件中的所有数据,包括:
- 主数据库所有表
- 知识库元数据和文档
- 配置文件
- 附件文件
- 知识库多媒体文件
- 插件目录(data/plugins
- 插件数据目录(data/plugin_data
- 配置目录(data/config
- T2I 模板目录(data/t2i_templates
- WebChat 数据目录(data/webchat
- 临时文件目录(data/temp
"""
def __init__(
self,
main_db: BaseDatabase,
kb_manager: "KnowledgeBaseManager | None" = None,
config_path: str = CMD_CONFIG_FILE_PATH,
kb_root_dir: str = KB_PATH,
):
self.main_db = main_db
self.kb_manager = kb_manager
self.config_path = config_path
self.kb_root_dir = kb_root_dir
def pre_check(self, zip_path: str) -> ImportPreCheckResult:
"""预检查备份文件
在实际导入前检查备份文件的有效性和版本兼容性。
返回检查结果供前端显示确认对话框。
Args:
zip_path: ZIP 备份文件路径
Returns:
ImportPreCheckResult: 预检查结果
"""
result = ImportPreCheckResult()
result.current_version = VERSION
if not os.path.exists(zip_path):
result.error = f"备份文件不存在: {zip_path}"
return result
try:
with zipfile.ZipFile(zip_path, "r") as zf:
# 读取 manifest
try:
manifest_data = zf.read("manifest.json")
manifest = json.loads(manifest_data)
except KeyError:
result.error = "备份文件缺少 manifest.json,不是有效的 AstrBot 备份"
return result
except json.JSONDecodeError as e:
result.error = f"manifest.json 格式错误: {e}"
return result
# 提取基本信息
result.backup_version = manifest.get("astrbot_version", "未知")
result.backup_time = manifest.get("exported_at", "未知")
result.valid = True
# 构建备份摘要
result.backup_summary = {
"tables": list(manifest.get("tables", {}).keys()),
"has_knowledge_bases": manifest.get("has_knowledge_bases", False),
"has_config": manifest.get("has_config", False),
"directories": manifest.get("directories", []),
}
# 检查版本兼容性
version_check = self._check_version_compatibility(result.backup_version)
result.version_status = version_check["status"]
result.can_import = version_check["can_import"]
# 版本信息由前端根据 version_status 和 i18n 生成显示
# 不再将版本消息添加到 warnings 列表中,避免中文硬编码
# warnings 列表保留用于其他非版本相关的警告
return result
except zipfile.BadZipFile:
result.error = "无效的 ZIP 文件"
return result
except Exception as e:
result.error = f"检查备份文件失败: {e}"
return result
def _check_version_compatibility(self, backup_version: str) -> dict:
"""检查版本兼容性
规则:
- 主版本(前两位,如 4.9)必须一致,否则拒绝
- 小版本(第三位,如 4.9.1 vs 4.9.2)不同时,警告但允许导入
Returns:
dict: {status, can_import, message}
"""
if not backup_version:
return {
"status": "major_diff",
"can_import": False,
"message": "备份文件缺少版本信息",
}
# 提取主版本(前两位)进行比较
backup_major = _get_major_version(backup_version)
current_major = _get_major_version(VERSION)
# 比较主版本
if VersionComparator.compare_version(backup_major, current_major) != 0:
return {
"status": "major_diff",
"can_import": False,
"message": (
f"主版本不兼容: 备份版本 {backup_version}, 当前版本 {VERSION}"
f"跨主版本导入可能导致数据损坏,请使用相同主版本的 AstrBot。"
),
}
# 比较完整版本
version_cmp = VersionComparator.compare_version(backup_version, VERSION)
if version_cmp != 0:
return {
"status": "minor_diff",
"can_import": True,
"message": (
f"小版本差异: 备份版本 {backup_version}, 当前版本 {VERSION}"
),
}
return {
"status": "match",
"can_import": True,
"message": "版本匹配",
}
async def import_all(
self,
zip_path: str,
mode: str = "replace", # "replace" 清空后导入
progress_callback: Any | None = None,
) -> ImportResult:
"""从 ZIP 文件导入所有数据
Args:
zip_path: ZIP 备份文件路径
mode: 导入模式,目前仅支持 "replace"(清空后导入)
progress_callback: 进度回调函数,接收参数 (stage, current, total, message)
Returns:
ImportResult: 导入结果
"""
result = ImportResult()
if not os.path.exists(zip_path):
result.add_error(f"备份文件不存在: {zip_path}")
return result
logger.info(f"开始从 {zip_path} 导入备份")
try:
with zipfile.ZipFile(zip_path, "r") as zf:
# 1. 读取并验证 manifest
if progress_callback:
await progress_callback("validate", 0, 100, "正在验证备份文件...")
try:
manifest_data = zf.read("manifest.json")
manifest = json.loads(manifest_data)
except KeyError:
result.add_error("备份文件缺少 manifest.json")
return result
except json.JSONDecodeError as e:
result.add_error(f"manifest.json 格式错误: {e}")
return result
# 版本校验
try:
self._validate_version(manifest)
except ValueError as e:
result.add_error(str(e))
return result
if progress_callback:
await progress_callback("validate", 100, 100, "验证完成")
# 2. 导入主数据库
if progress_callback:
await progress_callback("main_db", 0, 100, "正在导入主数据库...")
try:
main_data_content = zf.read("databases/main_db.json")
main_data = json.loads(main_data_content)
if mode == "replace":
await self._clear_main_db()
imported = await self._import_main_database(main_data)
result.imported_tables.update(imported)
except Exception as e:
result.add_error(f"导入主数据库失败: {e}")
return result
if progress_callback:
await progress_callback("main_db", 100, 100, "主数据库导入完成")
# 3. 导入知识库
if self.kb_manager and "databases/kb_metadata.json" in zf.namelist():
if progress_callback:
await progress_callback("kb", 0, 100, "正在导入知识库...")
try:
kb_meta_content = zf.read("databases/kb_metadata.json")
kb_meta_data = json.loads(kb_meta_content)
if mode == "replace":
await self._clear_kb_data()
await self._import_knowledge_bases(zf, kb_meta_data, result)
except Exception as e:
result.add_warning(f"导入知识库失败: {e}")
if progress_callback:
await progress_callback("kb", 100, 100, "知识库导入完成")
# 4. 导入配置文件
if progress_callback:
await progress_callback("config", 0, 100, "正在导入配置文件...")
if "config/cmd_config.json" in zf.namelist():
try:
config_content = zf.read("config/cmd_config.json")
# 备份现有配置
if os.path.exists(self.config_path):
backup_path = f"{self.config_path}.bak"
shutil.copy2(self.config_path, backup_path)
with open(self.config_path, "wb") as f:
f.write(config_content)
result.imported_files["config"] = 1
except Exception as e:
result.add_warning(f"导入配置文件失败: {e}")
if progress_callback:
await progress_callback("config", 100, 100, "配置文件导入完成")
# 5. 导入附件文件
if progress_callback:
await progress_callback("attachments", 0, 100, "正在导入附件...")
attachment_count = await self._import_attachments(
zf, main_data.get("attachments", [])
)
result.imported_files["attachments"] = attachment_count
if progress_callback:
await progress_callback("attachments", 100, 100, "附件导入完成")
# 6. 导入插件和其他目录
if progress_callback:
await progress_callback(
"directories", 0, 100, "正在导入插件和数据目录..."
)
dir_stats = await self._import_directories(zf, manifest, result)
result.imported_directories = dir_stats
if progress_callback:
await progress_callback("directories", 100, 100, "目录导入完成")
logger.info(f"备份导入完成: {result.to_dict()}")
return result
except zipfile.BadZipFile:
result.add_error("无效的 ZIP 文件")
return result
except Exception as e:
result.add_error(f"导入失败: {e}")
return result
def _validate_version(self, manifest: dict) -> None:
"""验证版本兼容性 - 仅允许相同主版本导入
注意:此方法仅在 import_all 中调用,用于双重校验。
前端应先调用 pre_check 获取详细的版本信息并让用户确认。
"""
backup_version = manifest.get("astrbot_version")
if not backup_version:
raise ValueError("备份文件缺少版本信息")
# 使用新的版本兼容性检查
version_check = self._check_version_compatibility(backup_version)
if version_check["status"] == "major_diff":
raise ValueError(version_check["message"])
# minor_diff 和 match 都允许导入
if version_check["status"] == "minor_diff":
logger.warning(f"版本差异警告: {version_check['message']}")
async def _clear_main_db(self) -> None:
"""清空主数据库所有表"""
async with self.main_db.get_db() as session:
async with session.begin():
for table_name, model_class in MAIN_DB_MODELS.items():
try:
await session.execute(delete(model_class))
logger.debug(f"已清空表 {table_name}")
except Exception as e:
logger.warning(f"清空表 {table_name} 失败: {e}")
async def _clear_kb_data(self) -> None:
"""清空知识库数据"""
if not self.kb_manager:
return
# 清空知识库元数据表
async with self.kb_manager.kb_db.get_db() as session:
async with session.begin():
for table_name, model_class in KB_METADATA_MODELS.items():
try:
await session.execute(delete(model_class))
logger.debug(f"已清空知识库表 {table_name}")
except Exception as e:
logger.warning(f"清空知识库表 {table_name} 失败: {e}")
# 删除知识库文件目录
for kb_id in list(self.kb_manager.kb_insts.keys()):
try:
kb_helper = self.kb_manager.kb_insts[kb_id]
await kb_helper.terminate()
if kb_helper.kb_dir.exists():
shutil.rmtree(kb_helper.kb_dir)
except Exception as e:
logger.warning(f"清理知识库 {kb_id} 失败: {e}")
self.kb_manager.kb_insts.clear()
async def _import_main_database(
self, data: dict[str, list[dict]]
) -> dict[str, int]:
"""导入主数据库数据"""
imported: dict[str, int] = {}
async with self.main_db.get_db() as session:
async with session.begin():
for table_name, rows in data.items():
model_class = MAIN_DB_MODELS.get(table_name)
if not model_class:
logger.warning(f"未知的表: {table_name}")
continue
count = 0
for row in rows:
try:
# 转换 datetime 字符串为 datetime 对象
row = self._convert_datetime_fields(row, model_class)
obj = model_class(**row)
session.add(obj)
count += 1
except Exception as e:
logger.warning(f"导入记录到 {table_name} 失败: {e}")
imported[table_name] = count
logger.debug(f"导入表 {table_name}: {count} 条记录")
return imported
async def _import_knowledge_bases(
self,
zf: zipfile.ZipFile,
kb_meta_data: dict[str, list[dict]],
result: ImportResult,
) -> None:
"""导入知识库数据"""
if not self.kb_manager:
return
# 1. 导入知识库元数据
async with self.kb_manager.kb_db.get_db() as session:
async with session.begin():
for table_name, rows in kb_meta_data.items():
model_class = KB_METADATA_MODELS.get(table_name)
if not model_class:
continue
count = 0
for row in rows:
try:
row = self._convert_datetime_fields(row, model_class)
obj = model_class(**row)
session.add(obj)
count += 1
except Exception as e:
logger.warning(f"导入知识库记录到 {table_name} 失败: {e}")
result.imported_tables[f"kb_{table_name}"] = count
# 2. 导入每个知识库的文档和文件
for kb_data in kb_meta_data.get("knowledge_bases", []):
kb_id = kb_data.get("kb_id")
if not kb_id:
continue
# 创建知识库目录
kb_dir = Path(self.kb_root_dir) / kb_id
kb_dir.mkdir(parents=True, exist_ok=True)
# 导入文档数据
doc_path = f"databases/kb_{kb_id}/documents.json"
if doc_path in zf.namelist():
try:
doc_content = zf.read(doc_path)
doc_data = json.loads(doc_content)
# 导入到文档存储数据库
await self._import_kb_documents(kb_id, doc_data)
except Exception as e:
result.add_warning(f"导入知识库 {kb_id} 的文档失败: {e}")
# 导入 FAISS 索引
faiss_path = f"databases/kb_{kb_id}/index.faiss"
if faiss_path in zf.namelist():
try:
target_path = kb_dir / "index.faiss"
with zf.open(faiss_path) as src, open(target_path, "wb") as dst:
dst.write(src.read())
except Exception as e:
result.add_warning(f"导入知识库 {kb_id} 的 FAISS 索引失败: {e}")
# 导入媒体文件
media_prefix = f"files/kb_media/{kb_id}/"
for name in zf.namelist():
if name.startswith(media_prefix):
try:
rel_path = name[len(media_prefix) :]
target_path = kb_dir / rel_path
target_path.parent.mkdir(parents=True, exist_ok=True)
with zf.open(name) as src, open(target_path, "wb") as dst:
dst.write(src.read())
except Exception as e:
result.add_warning(f"导入媒体文件 {name} 失败: {e}")
# 3. 重新加载知识库实例
await self.kb_manager.load_kbs()
async def _import_kb_documents(self, kb_id: str, doc_data: dict) -> None:
"""导入知识库文档到向量数据库"""
from astrbot.core.db.vec_db.faiss_impl.document_storage import DocumentStorage
kb_dir = Path(self.kb_root_dir) / kb_id
doc_db_path = kb_dir / "doc.db"
# 初始化文档存储
doc_storage = DocumentStorage(str(doc_db_path))
await doc_storage.initialize()
try:
documents = doc_data.get("documents", [])
for doc in documents:
try:
await doc_storage.insert_document(
doc_id=doc.get("doc_id", ""),
text=doc.get("text", ""),
metadata=json.loads(doc.get("metadata", "{}")),
)
except Exception as e:
logger.warning(f"导入文档块失败: {e}")
finally:
await doc_storage.close()
async def _import_attachments(
self,
zf: zipfile.ZipFile,
attachments: list[dict],
) -> int:
"""导入附件文件"""
count = 0
attachments_dir = Path(self.config_path).parent / "attachments"
attachments_dir.mkdir(parents=True, exist_ok=True)
attachment_prefix = "files/attachments/"
for name in zf.namelist():
if name.startswith(attachment_prefix) and name != attachment_prefix:
try:
# 从附件记录中找到原始路径
attachment_id = os.path.splitext(os.path.basename(name))[0]
original_path = None
for att in attachments:
if att.get("attachment_id") == attachment_id:
original_path = att.get("path")
break
if original_path:
target_path = Path(original_path)
else:
target_path = attachments_dir / os.path.basename(name)
target_path.parent.mkdir(parents=True, exist_ok=True)
with zf.open(name) as src, open(target_path, "wb") as dst:
dst.write(src.read())
count += 1
except Exception as e:
logger.warning(f"导入附件 {name} 失败: {e}")
return count
async def _import_directories(
self,
zf: zipfile.ZipFile,
manifest: dict,
result: ImportResult,
) -> dict[str, int]:
"""导入插件和其他数据目录
Args:
zf: ZIP 文件对象
manifest: 备份清单
result: 导入结果对象
Returns:
dict: 每个目录导入的文件数量
"""
dir_stats: dict[str, int] = {}
# 检查备份版本是否支持目录备份(需要版本 >= 1.1)
backup_version = manifest.get("version", "1.0")
if VersionComparator.compare_version(backup_version, "1.1") < 0:
logger.info("备份版本不支持目录备份,跳过目录导入")
return dir_stats
backed_up_dirs = manifest.get("directories", [])
backup_directories = get_backup_directories()
for dir_name in backed_up_dirs:
if dir_name not in backup_directories:
result.add_warning(f"未知的目录类型: {dir_name}")
continue
target_dir = Path(backup_directories[dir_name])
archive_prefix = f"directories/{dir_name}/"
file_count = 0
try:
# 获取该目录下的所有文件
dir_files = [
name
for name in zf.namelist()
if name.startswith(archive_prefix) and name != archive_prefix
]
if not dir_files:
continue
# 备份现有目录(如果存在)
if target_dir.exists():
backup_path = Path(f"{target_dir}.bak")
if backup_path.exists():
shutil.rmtree(backup_path)
shutil.move(str(target_dir), str(backup_path))
logger.debug(f"已备份现有目录 {target_dir}{backup_path}")
# 创建目标目录
target_dir.mkdir(parents=True, exist_ok=True)
# 解压文件
for name in dir_files:
try:
# 计算相对路径
rel_path = name[len(archive_prefix) :]
if not rel_path: # 跳过目录条目
continue
target_path = target_dir / rel_path
target_path.parent.mkdir(parents=True, exist_ok=True)
with zf.open(name) as src, open(target_path, "wb") as dst:
dst.write(src.read())
file_count += 1
except Exception as e:
result.add_warning(f"导入文件 {name} 失败: {e}")
dir_stats[dir_name] = file_count
logger.debug(f"导入目录 {dir_name}: {file_count} 个文件")
except Exception as e:
result.add_warning(f"导入目录 {dir_name} 失败: {e}")
dir_stats[dir_name] = 0
return dir_stats
def _convert_datetime_fields(self, row: dict, model_class: type) -> dict:
"""转换 datetime 字符串字段为 datetime 对象"""
result = row.copy()
# 获取模型的 datetime 字段
from sqlalchemy import inspect as sa_inspect
try:
mapper = sa_inspect(model_class)
for column in mapper.columns:
if column.name in result and result[column.name] is not None:
# 检查是否是 datetime 类型的列
from sqlalchemy import DateTime
if isinstance(column.type, DateTime):
value = result[column.name]
if isinstance(value, str):
# 解析 ISO 格式的日期时间字符串
result[column.name] = datetime.fromisoformat(value)
except Exception:
pass
return result
+2
View File
@@ -80,6 +80,8 @@ class AstrBotConfig(dict):
if v["type"] == "object":
conf[k] = {}
_parse_schema(v["items"], conf[k])
elif v["type"] == "template_list":
conf[k] = default
else:
conf[k] = default
+125 -33
View File
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.10.2"
VERSION = "4.11.2"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
WEBHOOK_SUPPORTED_PLATFORMS = [
@@ -83,6 +83,16 @@ DEFAULT_CONFIG = {
"default_personality": "default",
"persona_pool": ["*"],
"prompt_prefix": "{{prompt}}",
"context_limit_reached_strategy": "truncate_by_turns", # or llm_compress
"llm_compress_instruction": (
"Based on our full conversation history, produce a concise summary of key takeaways and/or project progress.\n"
"1. Systematically cover all core topics discussed and the final conclusion/outcome for each; clearly highlight the latest primary focus.\n"
"2. If any tools were used, summarize tool usage (total call count) and extract the most valuable insights from tool outputs.\n"
"3. If there was an initial user goal, state it first and describe the current progress/status.\n"
"4. Write the summary in the user's language.\n"
),
"llm_compress_keep_recent": 4,
"llm_compress_provider_id": "",
"max_context_length": -1,
"dequeue_context_length": 1,
"streaming_response": False,
@@ -179,6 +189,7 @@ class ChatProviderTemplate(TypedDict):
model: str
modalities: list
custom_extra_body: dict[str, Any]
max_context_tokens: int
CHAT_PROVIDER_TEMPLATE = {
@@ -187,6 +198,7 @@ CHAT_PROVIDER_TEMPLATE = {
"model": "",
"modalities": [],
"custom_extra_body": {},
"max_context_tokens": 0,
}
"""
@@ -227,7 +239,7 @@ CONFIG_METADATA_2 = {
"callback_server_host": "0.0.0.0",
"port": 6196,
},
"OneBot v11": {
"OneBot v11 (QQ 个人号等)": {
"id": "default",
"type": "aiocqhttp",
"enable": False,
@@ -235,16 +247,6 @@ CONFIG_METADATA_2 = {
"ws_reverse_port": 6199,
"ws_reverse_token": "",
},
"WeChatPadPro": {
"id": "wechatpadpro",
"type": "wechatpadpro",
"enable": False,
"admin_key": "stay33",
"host": "这里填写你的局域网IP或者公网服务器IP",
"port": 8059,
"wpp_active_message_poll": False,
"wpp_active_message_poll_interval": 3,
},
"微信公众平台": {
"id": "weixin_official_account",
"type": "weixin_official_account",
@@ -905,6 +907,7 @@ CONFIG_METADATA_2 = {
"key": [],
"api_base": "https://api.anthropic.com/v1",
"timeout": 120,
"anth_thinking_config": {"budget": 0},
},
"Moonshot": {
"id": "moonshot",
@@ -920,7 +923,7 @@ CONFIG_METADATA_2 = {
"xAI": {
"id": "xai",
"provider": "xai",
"type": "openai_chat_completion",
"type": "xai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
@@ -1286,7 +1289,7 @@ CONFIG_METADATA_2 = {
"minimax-is-timber-weight": False,
"minimax-voice-id": "female-shaonv",
"minimax-timber-weight": '[\n {\n "voice_id": "Chinese (Mandarin)_Warm_Girl",\n "weight": 25\n },\n {\n "voice_id": "Chinese (Mandarin)_BashfulGirl",\n "weight": 50\n }\n]',
"minimax-voice-emotion": "neutral",
"minimax-voice-emotion": "auto",
"minimax-voice-latex": False,
"minimax-voice-english-normalization": False,
"timeout": 20,
@@ -1450,7 +1453,32 @@ CONFIG_METADATA_2 = {
"description": "自定义请求体参数",
"type": "dict",
"items": {},
"hint": "此处添加的键值对将被合并到发送给 API 的 extra_body 中。值可以是字符串、数字或布尔值",
"hint": "用于在请求时添加额外的参数,如 temperature、top_p、max_tokens 等",
"template_schema": {
"temperature": {
"name": "Temperature",
"description": "温度参数",
"hint": "控制输出的随机性,范围通常为 0-2。值越高越随机。",
"type": "float",
"default": 0.6,
"slider": {"min": 0, "max": 2, "step": 0.1},
},
"top_p": {
"name": "Top-p",
"description": "Top-p 采样",
"hint": "核采样参数,范围通常为 0-1。控制模型考虑的概率质量。",
"type": "float",
"default": 1.0,
"slider": {"min": 0, "max": 1, "step": 0.01},
},
"max_tokens": {
"name": "Max Tokens",
"description": "最大令牌数",
"hint": "生成的最大令牌数。",
"type": "int",
"default": 8192,
},
},
},
"provider": {
"type": "string",
@@ -1787,6 +1815,17 @@ CONFIG_METADATA_2 = {
},
},
},
"anth_thinking_config": {
"description": "Thinking Config",
"type": "object",
"items": {
"budget": {
"description": "Thinking Budget",
"type": "int",
"hint": "Anthropic thinking.budget_tokens param. Must >= 1024. See: https://platform.claude.com/docs/en/build-with-claude/extended-thinking",
},
},
},
"minimax-group-id": {
"type": "string",
"description": "用户组",
@@ -1858,15 +1897,18 @@ CONFIG_METADATA_2 = {
"minimax-voice-emotion": {
"type": "string",
"description": "情绪",
"hint": "控制合成语音的情绪",
"hint": "控制合成语音的情绪。当为 auto 时,将根据文本内容自动选择情绪。",
"options": [
"auto",
"happy",
"sad",
"angry",
"fearful",
"disgusted",
"surprised",
"neutral",
"calm",
"fluent",
"whisper",
],
},
"minimax-voice-latex": {
@@ -1993,6 +2035,11 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
},
"max_context_tokens": {
"description": "模型上下文窗口大小",
"type": "int",
"hint": "模型最大上下文 Token 大小。如果为 0,则会自动从模型元数据填充(如有),也可手动修改。",
},
"dify_api_key": {
"description": "API Key",
"type": "string",
@@ -2500,6 +2547,66 @@ CONFIG_METADATA_3 = {
# "provider_settings.enable": True,
# },
# },
"truncate_and_compress": {
"description": "上下文管理策略",
"type": "object",
"items": {
"provider_settings.max_context_length": {
"description": "最多携带对话轮数",
"type": "int",
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.dequeue_context_length": {
"description": "丢弃对话轮数",
"type": "int",
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.context_limit_reached_strategy": {
"description": "超出模型上下文窗口时的处理方式",
"type": "string",
"options": ["truncate_by_turns", "llm_compress"],
"labels": ["按对话轮数截断", "由 LLM 压缩上下文"],
"condition": {
"provider_settings.agent_runner_type": "local",
},
"hint": "",
},
"provider_settings.llm_compress_instruction": {
"description": "上下文压缩提示词",
"type": "text",
"hint": "如果为空则使用默认提示词。",
"condition": {
"provider_settings.context_limit_reached_strategy": "llm_compress",
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.llm_compress_keep_recent": {
"description": "压缩时保留最近对话轮数",
"type": "int",
"hint": "始终保留的最近 N 轮对话。",
"condition": {
"provider_settings.context_limit_reached_strategy": "llm_compress",
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.llm_compress_provider_id": {
"description": "用于上下文压缩的模型提供商 ID",
"type": "string",
"_special": "select_provider",
"hint": "留空时将降级为“按对话轮数截断”的策略。",
"condition": {
"provider_settings.context_limit_reached_strategy": "llm_compress",
"provider_settings.agent_runner_type": "local",
},
},
},
},
"others": {
"description": "其他配置",
"type": "object",
@@ -2564,22 +2671,6 @@ CONFIG_METADATA_3 = {
"provider_settings.streaming_response": True,
},
},
"provider_settings.max_context_length": {
"description": "最多携带对话轮数",
"type": "int",
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条,-1 为不限制",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.dequeue_context_length": {
"description": "丢弃对话轮数",
"type": "int",
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.wake_prefix": {
"description": "LLM 聊天额外唤醒前缀 ",
"type": "string",
@@ -3049,4 +3140,5 @@ DEFAULT_VALUE_MAP = {
"text": "",
"list": [],
"object": {},
"template_list": [],
}
+4
View File
@@ -69,6 +69,7 @@ class ConversationManager:
persona_id=conv_v2.persona_id,
created_at=created_at,
updated_at=updated_at,
token_usage=conv_v2.token_usage,
)
async def new_conversation(
@@ -256,6 +257,7 @@ class ConversationManager:
history: list[dict] | None = None,
title: str | None = None,
persona_id: str | None = None,
token_usage: int | None = None,
) -> None:
"""更新会话的对话.
@@ -263,6 +265,7 @@ class ConversationManager:
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
token_usage (int | None): token 使用量。None 表示不更新
"""
if not conversation_id:
@@ -274,6 +277,7 @@ class ConversationManager:
title=title,
persona_id=persona_id,
content=history,
token_usage=token_usage,
)
async def update_conversation_title(
+1
View File
@@ -90,6 +90,7 @@ class AstrBotCoreLifecycle:
# 初始化 UMOP 配置路由器
self.umop_config_router = UmopConfigRouter(sp=sp)
await self.umop_config_router.initialize()
# 初始化 AstrBot 配置管理器
self.astrbot_config_mgr = AstrBotConfigManager(
+1
View File
@@ -152,6 +152,7 @@ class BaseDatabase(abc.ABC):
title: str | None = None,
persona_id: str | None = None,
content: list[dict] | None = None,
token_usage: int | None = None,
) -> None:
"""Update a conversation's history."""
...
@@ -0,0 +1,61 @@
"""Migration script to add token_usage column to conversations table.
This migration adds the token_usage field to track token consumption for each conversation.
Changes:
- Adds token_usage column to conversations table (default: 0)
"""
from sqlalchemy import text
from astrbot.api import logger, sp
from astrbot.core.db import BaseDatabase
async def migrate_token_usage(db_helper: BaseDatabase):
"""Add token_usage column to conversations table.
This migration adds a new column to track token consumption in conversations.
"""
# 检查是否已经完成迁移
migration_done = await db_helper.get_preference(
"global", "global", "migration_done_token_usage_1"
)
if migration_done:
return
logger.info("开始执行数据库迁移(添加 conversations.token_usage 列)...")
# 这里只适配了 SQLite。因为截止至这一版本,AstrBot 仅支持 SQLite。
try:
async with db_helper.get_db() as session:
# 检查列是否已存在
result = await session.execute(text("PRAGMA table_info(conversations)"))
columns = result.fetchall()
column_names = [col[1] for col in columns]
if "token_usage" in column_names:
logger.info("token_usage 列已存在,跳过迁移")
await sp.put_async(
"global", "global", "migration_done_token_usage_1", True
)
return
# 添加 token_usage 列
await session.execute(
text(
"ALTER TABLE conversations ADD COLUMN token_usage INTEGER NOT NULL DEFAULT 0"
)
)
await session.commit()
logger.info("token_usage 列添加成功")
# 标记迁移完成
await sp.put_async("global", "global", "migration_done_token_usage_1", True)
logger.info("token_usage 迁移完成")
except Exception as e:
logger.error(f"迁移过程中发生错误: {e}", exc_info=True)
raise
+7
View File
@@ -54,6 +54,11 @@ class ConversationV2(SQLModel, table=True):
)
title: str | None = Field(default=None, max_length=255)
persona_id: str | None = Field(default=None)
token_usage: int = Field(default=0, nullable=False)
"""content is a list of OpenAI-formated messages in list[dict] format.
token_usage is the total token value of the messages.
when 0, will use estimated token counter.
"""
__table_args__ = (
UniqueConstraint(
@@ -313,6 +318,8 @@ class Conversation:
persona_id: str | None = ""
created_at: int = 0
updated_at: int = 0
token_usage: int = 0
"""对话的总 token 数量。AstrBot 会保留最近一次 LLM 请求返回的总 token 数,方便统计。token_usage 可能为 0,表示未知。"""
class Personality(TypedDict):
+5 -1
View File
@@ -241,7 +241,9 @@ class SQLiteDatabase(BaseDatabase):
session.add(new_conversation)
return new_conversation
async def update_conversation(self, cid, title=None, persona_id=None, content=None):
async def update_conversation(
self, cid, title=None, persona_id=None, content=None, token_usage=None
):
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
@@ -255,6 +257,8 @@ class SQLiteDatabase(BaseDatabase):
values["persona_id"] = persona_id
if content is not None:
values["content"] = content
if token_usage is not None:
values["token_usage"] = token_usage
if not values:
return None
query = query.values(**values)
@@ -149,8 +149,16 @@ class RecursiveCharacterChunker(BaseChunker):
分割后的文本块列表
"""
chunk_size = chunk_size or self.chunk_size
overlap = overlap or self.chunk_overlap
if chunk_size is None:
chunk_size = self.chunk_size
if overlap is None:
overlap = self.chunk_overlap
if chunk_size <= 0:
raise ValueError("chunk_size must be greater than 0")
if overlap < 0:
raise ValueError("chunk_overlap must be non-negative")
if overlap >= chunk_size:
raise ValueError("chunk_overlap must be less than chunk_size")
result = []
for i in range(0, len(text), chunk_size - overlap):
end = min(i + chunk_size, len(text))
+1 -1
View File
@@ -58,7 +58,7 @@ def is_plugin_path(pathname):
return False
norm_path = os.path.normpath(pathname)
return ("data/plugins" in norm_path) or ("packages/" in norm_path)
return ("data/plugins" in norm_path) or ("astrbot/builtin_stars/" in norm_path)
def get_short_level_name(level_name):
@@ -38,7 +38,7 @@ class AgentRequestSubStage(Stage):
)
return
if not SessionServiceManager.should_process_llm_request(event):
if not await SessionServiceManager.should_process_llm_request(event):
logger.debug(
f"The session {event.unified_msg_origin} has disabled AI capability, skipping processing."
)
@@ -1,11 +1,12 @@
"""本地 Agent 模式的 LLM 调用 Stage"""
import asyncio
import copy
import json
from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.agent.message import Message
from astrbot.core.agent.response import AgentStats
from astrbot.core.agent.tool import ToolSet
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.conversation_mgr import Conversation
@@ -23,6 +24,7 @@ from astrbot.core.provider.entities import (
)
from astrbot.core.star.star_handler import EventType, star_map
from astrbot.core.utils.file_extract import extract_file_moonshotai
from astrbot.core.utils.llm_metadata import LLM_METADATAS
from astrbot.core.utils.metrics import Metric
from astrbot.core.utils.session_lock import session_lock_manager
@@ -40,11 +42,6 @@ class InternalAgentSubStage(Stage):
self.ctx = ctx
conf = ctx.astrbot_config
settings = conf["provider_settings"]
self.max_context_length = settings["max_context_length"] # int
self.dequeue_context_length: int = min(
max(1, settings["dequeue_context_length"]),
self.max_context_length - 1,
)
self.streaming_response: bool = settings["streaming_response"]
self.unsupported_streaming_strategy: str = settings[
"unsupported_streaming_strategy"
@@ -64,6 +61,25 @@ class InternalAgentSubStage(Stage):
"moonshotai_api_key", ""
)
# 上下文管理相关
self.context_limit_reached_strategy: str = settings.get(
"context_limit_reached_strategy", "truncate_by_turns"
)
self.llm_compress_instruction: str = settings.get(
"llm_compress_instruction", ""
)
self.llm_compress_keep_recent: int = settings.get("llm_compress_keep_recent", 4)
self.llm_compress_provider_id: str = settings.get(
"llm_compress_provider_id", ""
)
self.max_context_length = settings["max_context_length"] # int
self.dequeue_context_length: int = min(
max(1, settings["dequeue_context_length"]),
self.max_context_length - 1,
)
if self.dequeue_context_length <= 0:
self.dequeue_context_length = 1
self.conv_manager = ctx.plugin_manager.context.conversation_manager
def _select_provider(self, event: AstrMessageEvent):
@@ -166,34 +182,6 @@ class InternalAgentSubStage(Stage):
},
)
def _truncate_contexts(
self,
contexts: list[dict],
) -> list[dict]:
"""截断上下文列表,确保不超过最大长度"""
if self.max_context_length == -1:
return contexts
if len(contexts) // 2 <= self.max_context_length:
return contexts
truncated_contexts = contexts[
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
]
# 找到第一个role 为 user 的索引,确保上下文格式正确
index = next(
(
i
for i, item in enumerate(truncated_contexts)
if item.get("role") == "user"
),
None,
)
if index is not None and index > 0:
truncated_contexts = truncated_contexts[index:]
return truncated_contexts
def _modalities_fix(
self,
provider: Provider,
@@ -294,6 +282,8 @@ class InternalAgentSubStage(Stage):
event: AstrMessageEvent,
req: ProviderRequest,
llm_response: LLMResponse | None,
all_messages: list[Message],
runner_stats: AgentStats | None,
):
if (
not req
@@ -307,222 +297,265 @@ class InternalAgentSubStage(Stage):
logger.debug("LLM 响应为空,不保存记录。")
return
if req.contexts is None:
req.contexts = []
# using agent context messages to save to history
message_to_save = []
for message in all_messages:
if message.role == "system":
# we do not save system messages to history
continue
if message.role in ["assistant", "user"] and getattr(
message, "_no_save", None
):
# we do not save user and assistant messages that are marked as _no_save
continue
message_to_save.append(message.model_dump())
# get token usage from agent runner stats
token_usage = None
if runner_stats:
token_usage = runner_stats.token_usage.total
# 历史上下文
messages = copy.deepcopy(req.contexts)
# 这一轮对话请求的用户输入
messages.append(await req.assemble_context())
# 这一轮对话的 LLM 响应
if req.tool_calls_result:
if not isinstance(req.tool_calls_result, list):
messages.extend(req.tool_calls_result.to_openai_messages())
elif isinstance(req.tool_calls_result, list):
for tcr in req.tool_calls_result:
messages.extend(tcr.to_openai_messages())
messages.append(
{
"role": "assistant",
"content": llm_response.completion_text or "*No response*",
}
)
messages = list(filter(lambda item: "_no_save" not in item, messages))
await self.conv_manager.update_conversation(
event.unified_msg_origin,
req.conversation.cid,
history=messages,
history=message_to_save,
token_usage=token_usage,
)
def _fix_messages(self, messages: list[dict]) -> list[dict]:
"""验证并且修复上下文"""
fixed_messages = []
for message in messages:
if message.get("role") == "tool":
# tool block 前面必须要有 user 和 assistant block
if len(fixed_messages) < 2:
# 这种情况可能是上下文被截断导致的
# 我们直接将之前的上下文都清空
fixed_messages = []
else:
fixed_messages.append(message)
else:
fixed_messages.append(message)
return fixed_messages
def _get_compress_provider(self) -> Provider | None:
if not self.llm_compress_provider_id:
return None
if self.context_limit_reached_strategy != "llm_compress":
return None
provider = self.ctx.plugin_manager.context.get_provider_by_id(
self.llm_compress_provider_id,
)
if provider is None:
logger.warning(
f"未找到指定的上下文压缩模型 {self.llm_compress_provider_id},将跳过压缩。",
)
return None
if not isinstance(provider, Provider):
logger.warning(
f"指定的上下文压缩模型 {self.llm_compress_provider_id} 不是对话模型,将跳过压缩。"
)
return None
return provider
async def process(
self, event: AstrMessageEvent, provider_wake_prefix: str
) -> AsyncGenerator[None, None]:
req: ProviderRequest | None = None
provider = self._select_provider(event)
if provider is None:
return
if not isinstance(provider, Provider):
logger.error(f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。")
return
streaming_response = self.streaming_response
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
streaming_response = bool(enable_streaming)
logger.debug("ready to request llm provider")
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
logger.debug("acquired session lock for llm request")
if event.get_extra("provider_request"):
req = event.get_extra("provider_request")
assert isinstance(req, ProviderRequest), (
"provider_request 必须是 ProviderRequest 类型。"
try:
provider = self._select_provider(event)
if provider is None:
return
if not isinstance(provider, Provider):
logger.error(
f"选择的提供商类型无效({type(provider)}),跳过 LLM 请求处理。"
)
return
if req.conversation:
req.contexts = json.loads(req.conversation.history)
streaming_response = self.streaming_response
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
streaming_response = bool(enable_streaming)
else:
req = ProviderRequest()
req.prompt = ""
req.image_urls = []
if sel_model := event.get_extra("selected_model"):
req.model = sel_model
if provider_wake_prefix and not event.message_str.startswith(
provider_wake_prefix
):
# 检查消息内容是否有效,避免空消息触发钩子
has_provider_request = event.get_extra("provider_request") is not None
has_valid_message = bool(event.message_str and event.message_str.strip())
if not has_provider_request and not has_valid_message:
logger.debug("skip llm request: empty message and no provider_request")
return
logger.debug("ready to request llm provider")
# 通知等待调用 LLM(在获取锁之前)
await call_event_hook(event, EventType.OnWaitingLLMRequestEvent)
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
logger.debug("acquired session lock for llm request")
if event.get_extra("provider_request"):
req = event.get_extra("provider_request")
assert isinstance(req, ProviderRequest), (
"provider_request 必须是 ProviderRequest 类型。"
)
if req.conversation:
req.contexts = json.loads(req.conversation.history)
else:
req = ProviderRequest()
req.prompt = ""
req.image_urls = []
if sel_model := event.get_extra("selected_model"):
req.model = sel_model
if provider_wake_prefix and not event.message_str.startswith(
provider_wake_prefix
):
return
req.prompt = event.message_str[len(provider_wake_prefix) :]
# func_tool selection 现在已经转移到 astrbot/builtin_stars/astrbot 插件中进行选择。
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_path = await comp.convert_to_file_path()
req.image_urls.append(image_path)
conversation = await self._get_session_conv(event)
req.conversation = conversation
req.contexts = json.loads(conversation.history)
event.set_extra("provider_request", req)
# fix contexts json str
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
# apply file extract
if self.file_extract_enabled:
try:
await self._apply_file_extract(event, req)
except Exception as e:
logger.error(f"Error occurred while applying file extract: {e}")
if not req.prompt and not req.image_urls:
return
req.prompt = event.message_str[len(provider_wake_prefix) :]
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_path = await comp.convert_to_file_path()
req.image_urls.append(image_path)
# call event hook
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
return
conversation = await self._get_session_conv(event)
req.conversation = conversation
req.contexts = json.loads(conversation.history)
# apply knowledge base feature
await self._apply_kb(event, req)
event.set_extra("provider_request", req)
# truncate contexts to fit max length
# NOW moved to ContextManager inside ToolLoopAgentRunner
# if req.contexts:
# req.contexts = self._truncate_contexts(req.contexts)
# self._fix_messages(req.contexts)
# fix contexts json str
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
# session_id
if not req.session_id:
req.session_id = event.unified_msg_origin
# apply file extract
if self.file_extract_enabled:
try:
await self._apply_file_extract(event, req)
except Exception as e:
logger.error(f"Error occurred while applying file extract: {e}")
# check provider modalities, if provider does not support image/tool_use, clear them in request.
self._modalities_fix(provider, req)
if not req.prompt and not req.image_urls:
return
# filter tools, only keep tools from this pipeline's selected plugins
self._plugin_tool_fix(event, req)
# call event hook
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
return
# apply knowledge base feature
await self._apply_kb(event, req)
# truncate contexts to fit max length
if req.contexts:
req.contexts = self._truncate_contexts(req.contexts)
self._fix_messages(req.contexts)
# session_id
if not req.session_id:
req.session_id = event.unified_msg_origin
# check provider modalities, if provider does not support image/tool_use, clear them in request.
self._modalities_fix(provider, req)
# filter tools, only keep tools from this pipeline's selected plugins
self._plugin_tool_fix(event, req)
stream_to_general = (
self.unsupported_streaming_strategy == "turn_off"
and not event.platform_meta.support_streaming_message
)
# 备份 req.contexts
backup_contexts = copy.deepcopy(req.contexts)
# run agent
agent_runner = AgentRunner()
logger.debug(
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
)
astr_agent_ctx = AstrAgentContext(
context=self.ctx.plugin_manager.context,
event=event,
)
await agent_runner.reset(
provider=provider,
request=req,
run_context=AgentContextWrapper(
context=astr_agent_ctx,
tool_call_timeout=self.tool_call_timeout,
),
tool_executor=FunctionToolExecutor(),
agent_hooks=MAIN_AGENT_HOOKS,
streaming=streaming_response,
)
if streaming_response and not stream_to_general:
# 流式响应
event.set_result(
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(
run_agent(
agent_runner,
self.max_step,
self.show_tool_use,
show_reasoning=self.show_reasoning,
),
),
stream_to_general = (
self.unsupported_streaming_strategy == "turn_off"
and not event.platform_meta.support_streaming_message
)
yield
if agent_runner.done():
if final_llm_resp := agent_runner.get_final_llm_resp():
if final_llm_resp.completion_text:
chain = (
MessageChain()
.message(final_llm_resp.completion_text)
.chain
)
elif final_llm_resp.result_chain:
chain = final_llm_resp.result_chain.chain
else:
chain = MessageChain().chain
event.set_result(
MessageEventResult(
chain=chain,
result_content_type=ResultContentType.STREAMING_FINISH,
# run agent
agent_runner = AgentRunner()
logger.debug(
f"handle provider[id: {provider.provider_config['id']}] request: {req}",
)
astr_agent_ctx = AstrAgentContext(
context=self.ctx.plugin_manager.context,
event=event,
)
# inject model context length limit
if provider.provider_config.get("max_context_tokens", 0) <= 0:
model = provider.get_model()
if model_info := LLM_METADATAS.get(model):
provider.provider_config["max_context_tokens"] = model_info[
"limit"
]["context"]
await agent_runner.reset(
provider=provider,
request=req,
run_context=AgentContextWrapper(
context=astr_agent_ctx,
tool_call_timeout=self.tool_call_timeout,
),
tool_executor=FunctionToolExecutor(),
agent_hooks=MAIN_AGENT_HOOKS,
streaming=streaming_response,
llm_compress_instruction=self.llm_compress_instruction,
llm_compress_keep_recent=self.llm_compress_keep_recent,
llm_compress_provider=self._get_compress_provider(),
truncate_turns=self.dequeue_context_length,
enforce_max_turns=self.max_context_length,
)
if streaming_response and not stream_to_general:
# 流式响应
event.set_result(
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(
run_agent(
agent_runner,
self.max_step,
self.show_tool_use,
show_reasoning=self.show_reasoning,
),
)
else:
async for _ in run_agent(
agent_runner,
self.max_step,
self.show_tool_use,
stream_to_general,
show_reasoning=self.show_reasoning,
):
),
)
yield
if agent_runner.done():
if final_llm_resp := agent_runner.get_final_llm_resp():
if final_llm_resp.completion_text:
chain = (
MessageChain()
.message(final_llm_resp.completion_text)
.chain
)
elif final_llm_resp.result_chain:
chain = final_llm_resp.result_chain.chain
else:
chain = MessageChain().chain
event.set_result(
MessageEventResult(
chain=chain,
result_content_type=ResultContentType.STREAMING_FINISH,
),
)
else:
async for _ in run_agent(
agent_runner,
self.max_step,
self.show_tool_use,
stream_to_general,
show_reasoning=self.show_reasoning,
):
yield
# 恢复备份的 contexts
req.contexts = backup_contexts
# 检查事件是否被停止,如果被停止则不保存历史记录
if not event.is_stopped():
await self._save_to_history(
event,
req,
agent_runner.get_final_llm_resp(),
agent_runner.run_context.messages,
agent_runner.stats,
)
await self._save_to_history(event, req, agent_runner.get_final_llm_resp())
# 异步处理 WebChat 特殊情况
if event.get_platform_name() == "webchat":
asyncio.create_task(self._handle_webchat(event, req, provider))
# 异步处理 WebChat 特殊情况
if event.get_platform_name() == "webchat":
asyncio.create_task(self._handle_webchat(event, req, provider))
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=agent_runner.provider.get_model(),
provider_type=agent_runner.provider.meta().type,
),
)
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=agent_runner.provider.get_model(),
provider_type=agent_runner.provider.meta().type,
),
)
except Exception as e:
logger.error(f"Error occurred while processing agent: {e}")
await event.send(
MessageChain().message(
f"Error occurred while processing agent request: {e}"
)
)
+65 -57
View File
@@ -98,6 +98,9 @@ class ResultDecorateStage(Stage):
self.content_safe_check_stage = stage_cls()
await self.content_safe_check_stage.initialize(ctx)
provider_cfg = ctx.astrbot_config.get("provider_settings", {})
self.show_reasoning = provider_cfg.get("display_reasoning_text", False)
def _split_text_by_words(self, text: str) -> list[str]:
"""使用分段词列表分段文本"""
if not self.split_words_pattern:
@@ -254,70 +257,75 @@ class ResultDecorateStage(Stage):
event.unified_msg_origin,
)
if (
self.ctx.astrbot_config["provider_tts_settings"]["enable"]
should_tts = (
bool(self.ctx.astrbot_config["provider_tts_settings"]["enable"])
and result.is_llm_result()
and SessionServiceManager.should_process_tts_request(event)
):
should_tts = self.tts_trigger_probability >= 1.0 or (
self.tts_trigger_probability > 0.0
and random.random() <= self.tts_trigger_probability
and await SessionServiceManager.should_process_tts_request(event)
and random.random() <= self.tts_trigger_probability
and tts_provider
)
if should_tts and not tts_provider:
logger.warning(
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。",
)
if not should_tts:
logger.debug("跳过 TTS:触发概率未命中。")
elif not tts_provider:
logger.warning(
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。",
)
else:
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain) and len(comp.text) > 1:
try:
logger.info(f"TTS 请求: {comp.text}")
audio_path = await tts_provider.get_audio(comp.text)
logger.info(f"TTS 结果: {audio_path}")
if not audio_path:
logger.error(
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}",
)
new_chain.append(comp)
continue
if (
not should_tts
and self.show_reasoning
and event.get_extra("_llm_reasoning_content")
):
# inject reasoning content to chain
reasoning_content = event.get_extra("_llm_reasoning_content")
result.chain.insert(0, Plain(f"🤔 思考: {reasoning_content}\n"))
use_file_service = self.ctx.astrbot_config[
"provider_tts_settings"
]["use_file_service"]
callback_api_base = self.ctx.astrbot_config[
"callback_api_base"
]
dual_output = self.ctx.astrbot_config[
"provider_tts_settings"
]["dual_output"]
url = None
if use_file_service and callback_api_base:
token = await file_token_service.register_file(
audio_path,
)
url = f"{callback_api_base}/api/file/{token}"
logger.debug(f"已注册:{url}")
new_chain.append(
Record(
file=url or audio_path,
url=url or audio_path,
),
if should_tts and tts_provider:
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain) and len(comp.text) > 1:
try:
logger.info(f"TTS 请求: {comp.text}")
audio_path = await tts_provider.get_audio(comp.text)
logger.info(f"TTS 结果: {audio_path}")
if not audio_path:
logger.error(
f"由于 TTS 音频文件未找到,消息段转语音失败: {comp.text}",
)
if dual_output:
new_chain.append(comp)
except Exception:
logger.error(traceback.format_exc())
logger.error("TTS 失败,使用文本发送。")
new_chain.append(comp)
else:
continue
use_file_service = self.ctx.astrbot_config[
"provider_tts_settings"
]["use_file_service"]
callback_api_base = self.ctx.astrbot_config[
"callback_api_base"
]
dual_output = self.ctx.astrbot_config[
"provider_tts_settings"
]["dual_output"]
url = None
if use_file_service and callback_api_base:
token = await file_token_service.register_file(
audio_path,
)
url = f"{callback_api_base}/api/file/{token}"
logger.debug(f"已注册:{url}")
new_chain.append(
Record(
file=url or audio_path,
url=url or audio_path,
),
)
if dual_output:
new_chain.append(comp)
except Exception:
logger.error(traceback.format_exc())
logger.error("TTS 失败,使用文本发送。")
new_chain.append(comp)
result.chain = new_chain
else:
new_chain.append(comp)
result.chain = new_chain
# 文本转图片
elif (
@@ -21,7 +21,7 @@ class SessionStatusCheckStage(Stage):
event: AstrMessageEvent,
) -> None | AsyncGenerator[None, None]:
# 检查会话是否整体启用
if not SessionServiceManager.is_session_enabled(event.unified_msg_origin):
if not await SessionServiceManager.is_session_enabled(event.unified_msg_origin):
logger.debug(f"会话 {event.unified_msg_origin} 已被关闭,已终止事件传播。")
# workaround for #2309
+31 -4
View File
@@ -1,9 +1,10 @@
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Callable
from astrbot import logger
from astrbot.core.message.components import At, AtAll, Reply
from astrbot.core.message.message_event_result import MessageChain, MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.message_type import MessageType
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.star.session_plugin_manager import SessionPluginManager
@@ -13,6 +14,22 @@ from astrbot.core.star.star_handler import EventType, star_handlers_registry
from ..context import PipelineContext
from ..stage import Stage, register_stage
UNIQUE_SESSION_ID_BUILDERS: dict[str, Callable[[AstrMessageEvent], str | None]] = {
"aiocqhttp": lambda e: f"{e.get_sender_id()}_{e.get_group_id()}",
"slack": lambda e: f"{e.get_sender_id()}_{e.get_group_id()}",
"dingtalk": lambda e: e.get_sender_id(),
"qq_official": lambda e: e.get_sender_id(),
"qq_official_webhook": lambda e: e.get_sender_id(),
"lark": lambda e: f"{e.get_sender_id()}%{e.get_group_id()}",
"misskey": lambda e: f"{e.get_session_id()}_{e.get_sender_id()}",
}
def build_unique_session_id(event: AstrMessageEvent) -> str | None:
platform = event.get_platform_name()
builder = UNIQUE_SESSION_ID_BUILDERS.get(platform)
return builder(event) if builder else None
@register_stage
class WakingCheckStage(Stage):
@@ -53,18 +70,27 @@ class WakingCheckStage(Stage):
self.disable_builtin_commands = self.ctx.astrbot_config.get(
"disable_builtin_commands", False
)
platform_settings = self.ctx.astrbot_config.get("platform_settings", {})
self.unique_session = platform_settings.get("unique_session", False)
async def process(
self,
event: AstrMessageEvent,
) -> None | AsyncGenerator[None, None]:
# apply unique session
if self.unique_session and event.message_obj.type == MessageType.GROUP_MESSAGE:
sid = build_unique_session_id(event)
if sid:
event.session_id = sid
# ignore bot self message
if (
self.ignore_bot_self_message
and event.get_self_id() == event.get_sender_id()
):
# 忽略机器人自己发送的消息
event.stop_event()
return
# 设置 sender 身份
event.message_str = event.message_str.strip()
for admin_id in self.ctx.astrbot_config["admins_id"]:
@@ -136,7 +162,8 @@ class WakingCheckStage(Stage):
):
if (
self.disable_builtin_commands
and handler.handler_module_path == "packages.builtin_commands.main"
and handler.handler_module_path
== "astrbot.builtin_stars.builtin_commands.main"
):
logger.debug("skipping builtin command")
continue
@@ -199,7 +226,7 @@ class WakingCheckStage(Stage):
event._extras.pop("parsed_params", None)
# 根据会话配置过滤插件处理器
activated_handlers = SessionPluginManager.filter_handlers_by_session(
activated_handlers = await SessionPluginManager.filter_handlers_by_session(
event,
activated_handlers,
)
-4
View File
@@ -70,10 +70,6 @@ class PlatformManager:
from .sources.qqofficial_webhook.qo_webhook_adapter import (
QQOfficialWebhookPlatformAdapter, # noqa: F401
)
case "wechatpadpro":
from .sources.wechatpadpro.wechatpadpro_adapter import (
WeChatPadProAdapter, # noqa: F401
)
case "lark":
from .sources.lark.lark_adapter import (
LarkPlatformAdapter, # noqa: F401
@@ -41,7 +41,6 @@ class AiocqhttpAdapter(Platform):
super().__init__(platform_config, event_queue)
self.settings = platform_settings
self.unique_session = platform_settings["unique_session"]
self.host = platform_config["ws_reverse_host"]
self.port = platform_config["ws_reverse_port"]
@@ -136,14 +135,11 @@ class AiocqhttpAdapter(Platform):
abm.group_id = str(event.group_id)
else:
abm.type = MessageType.FRIEND_MESSAGE
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = str(abm.sender.user_id) + "_" + str(event.group_id)
else:
abm.session_id = (
str(event.group_id)
if abm.type == MessageType.GROUP_MESSAGE
else abm.sender.user_id
)
abm.session_id = (
str(event.group_id)
if abm.type == MessageType.GROUP_MESSAGE
else abm.sender.user_id
)
abm.message_str = ""
abm.message = []
abm.timestamp = int(time.time())
@@ -164,16 +160,11 @@ class AiocqhttpAdapter(Platform):
abm.type = MessageType.GROUP_MESSAGE
else:
abm.type = MessageType.FRIEND_MESSAGE
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = (
str(abm.sender.user_id) + "_" + str(event.group_id)
) # 也保留群组 id
else:
abm.session_id = (
str(event.group_id)
if abm.type == MessageType.GROUP_MESSAGE
else abm.sender.user_id
)
abm.session_id = (
str(event.group_id)
if abm.type == MessageType.GROUP_MESSAGE
else abm.sender.user_id
)
abm.message_str = ""
abm.message = []
abm.raw_message = event
@@ -210,16 +201,11 @@ class AiocqhttpAdapter(Platform):
abm.group.group_name = event.get("group_name", "N/A")
elif event["message_type"] == "private":
abm.type = MessageType.FRIEND_MESSAGE
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = (
abm.sender.user_id + "_" + str(event.group_id)
) # 也保留群组 id
else:
abm.session_id = (
str(event.group_id)
if abm.type == MessageType.GROUP_MESSAGE
else abm.sender.user_id
)
abm.session_id = (
str(event.group_id)
if abm.type == MessageType.GROUP_MESSAGE
else abm.sender.user_id
)
abm.message_id = str(event.message_id)
abm.message = []
@@ -50,8 +50,6 @@ class DingtalkPlatformAdapter(Platform):
) -> None:
super().__init__(platform_config, event_queue)
self.unique_session = platform_settings["unique_session"]
self.client_id = platform_config["client_id"]
self.client_secret = platform_config["client_secret"]
@@ -129,10 +127,7 @@ class DingtalkPlatformAdapter(Platform):
if id := self._id_to_sid(user.dingtalk_id):
abm.message.append(At(qq=id))
abm.group_id = message.conversation_id
if self.unique_session:
abm.session_id = abm.sender.user_id
else:
abm.session_id = abm.group_id
abm.session_id = abm.group_id
else:
abm.session_id = abm.sender.user_id
@@ -25,6 +25,20 @@ class DingtalkMessageEvent(AstrMessageEvent):
client: dingtalk_stream.ChatbotHandler,
message: MessageChain,
):
icm = cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message)
ats = []
# fixes: #4218
# 钉钉 at 机器人需要使用 sender_staff_id 而不是 sender_id
for i in message.chain:
if isinstance(i, Comp.At):
print(i.qq, icm.sender_id, icm.sender_staff_id)
if str(i.qq) in str(icm.sender_id or ""):
# 适配器会将开头的 $:LWCP_v1:$ 去掉,因此我们用 in 判断
ats.append(f"@{icm.sender_staff_id}")
else:
ats.append(f"@{i.qq}")
at_str = " ".join(ats)
for segment in message.chain:
if isinstance(segment, Comp.Plain):
segment.text = segment.text.strip()
@@ -32,7 +46,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
None,
client.reply_markdown,
segment.text,
segment.text,
f"{at_str} {segment.text}".strip(),
cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message),
)
elif isinstance(segment, Comp.Image):
@@ -44,8 +44,6 @@ class LarkPlatformAdapter(Platform):
) -> None:
super().__init__(platform_config, event_queue)
self.unique_session = platform_settings["unique_session"]
self.appid = platform_config["app_id"]
self.appsecret = platform_config["app_secret"]
self.domain = platform_config.get("domain", lark.FEISHU_DOMAIN)
@@ -317,14 +315,8 @@ class LarkPlatformAdapter(Platform):
user_id=event.event.sender.sender_id.open_id,
nickname=event.event.sender.sender_id.open_id[:8],
)
# 独立会话
if not self.unique_session:
if abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = abm.group_id
else:
abm.session_id = abm.sender.user_id
elif abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = f"{abm.sender.user_id}%{abm.group_id}" # 也保留群组id
if abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = abm.group_id
else:
abm.session_id = abm.sender.user_id
@@ -91,8 +91,6 @@ class MisskeyPlatformAdapter(Platform):
except Exception:
self.max_download_bytes = None
self.unique_session = platform_settings["unique_session"]
self.api: MisskeyAPI | None = None
self._running = False
self.client_self_id = ""
@@ -641,7 +639,6 @@ class MisskeyPlatformAdapter(Platform):
sender_info,
self.client_self_id,
is_chat=False,
unique_session=self.unique_session,
)
cache_user_info(
self._user_cache,
@@ -690,7 +687,6 @@ class MisskeyPlatformAdapter(Platform):
sender_info,
self.client_self_id,
is_chat=True,
unique_session=self.unique_session,
)
cache_user_info(
self._user_cache,
@@ -720,7 +716,6 @@ class MisskeyPlatformAdapter(Platform):
self.client_self_id,
is_chat=False,
room_id=room_id,
unique_session=self.unique_session,
)
cache_user_info(
@@ -338,7 +338,6 @@ def create_base_message(
client_self_id: str,
is_chat: bool = False,
room_id: str | None = None,
unique_session: bool = False,
) -> AstrBotMessage:
"""创建基础消息对象"""
message = AstrBotMessage()
@@ -353,8 +352,6 @@ def create_base_message(
if room_id:
session_prefix = "room"
session_id = f"{session_prefix}%{room_id}"
if unique_session:
session_id += f"_{sender_info['sender_id']}"
message.type = MessageType.GROUP_MESSAGE
message.group_id = room_id
elif is_chat:
@@ -44,11 +44,8 @@ class botClient(Client):
message,
MessageType.GROUP_MESSAGE,
)
abm.session_id = (
abm.sender.user_id
if self.platform.unique_session
else cast(str, message.group_openid)
)
abm.group_id = cast(str, message.group_openid)
abm.session_id = abm.group_id
self._commit(abm)
# 收到频道消息
@@ -57,9 +54,8 @@ class botClient(Client):
message,
MessageType.GROUP_MESSAGE,
)
abm.session_id = (
abm.sender.user_id if self.platform.unique_session else message.channel_id
)
abm.group_id = message.channel_id
abm.session_id = abm.group_id
self._commit(abm)
# 收到私聊消息
@@ -104,7 +100,6 @@ class QQOfficialPlatformAdapter(Platform):
self.appid = platform_config["appid"]
self.secret = platform_config["secret"]
self.unique_session: bool = platform_settings["unique_session"]
qq_group = platform_config["enable_group_c2c"]
guild_dm = platform_config["enable_guild_direct_message"]
@@ -35,11 +35,8 @@ class botClient(Client):
message,
MessageType.GROUP_MESSAGE,
)
abm.session_id = (
abm.sender.user_id
if self.platform.unique_session
else cast(str, message.group_openid)
)
abm.group_id = cast(str, message.group_openid)
abm.session_id = abm.group_id
self._commit(abm)
# 收到频道消息
@@ -48,9 +45,8 @@ class botClient(Client):
message,
MessageType.GROUP_MESSAGE,
)
abm.session_id = (
abm.sender.user_id if self.platform.unique_session else message.channel_id
)
abm.group_id = message.channel_id
abm.session_id = abm.group_id
self._commit(abm)
# 收到私聊消息
@@ -95,7 +91,6 @@ class QQOfficialWebhookPlatformAdapter(Platform):
self.appid = platform_config["appid"]
self.secret = platform_config["secret"]
self.unique_session = platform_settings["unique_session"]
self.unified_webhook_mode = platform_config.get("unified_webhook_mode", False)
intents = botpy.Intents(
@@ -142,7 +142,12 @@ class SatoriPlatformAdapter(Platform):
raise ValueError(f"WebSocket URL必须以ws://或wss://开头: {self.endpoint}")
try:
websocket = await connect(self.endpoint, additional_headers={})
websocket = await connect(
self.endpoint,
additional_headers={},
max_size=10 * 1024 * 1024, # 10MB
)
self.ws = websocket
await asyncio.sleep(0.1)
@@ -41,7 +41,6 @@ class SlackAdapter(Platform):
) -> None:
super().__init__(platform_config, event_queue)
self.settings = platform_settings
self.unique_session = platform_settings.get("unique_session", False)
self.bot_token = platform_config.get("bot_token")
self.app_token = platform_config.get("app_token")
@@ -147,12 +146,10 @@ class SlackAdapter(Platform):
abm.group_id = channel_id
# 设置会话ID
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = f"{user_id}_{channel_id}"
if abm.type == MessageType.GROUP_MESSAGE:
abm.session_id = abm.group_id
else:
abm.session_id = (
channel_id if abm.type == MessageType.GROUP_MESSAGE else user_id
)
abm.session_id = user_id
abm.message_id = event.get("client_msg_id", uuid.uuid4().hex)
abm.timestamp = int(float(event.get("ts", time.time())))
@@ -79,7 +79,6 @@ class WebChatAdapter(Platform):
super().__init__(platform_config, event_queue)
self.settings = platform_settings
self.unique_session = platform_settings["unique_session"]
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
os.makedirs(self.imgs_dir, exist_ok=True)
@@ -1,942 +0,0 @@
import asyncio
import base64
import json
import os
import time
import traceback
from typing import cast
import aiohttp
import anyio
import websockets
from astrbot import logger
from astrbot.api.message_components import At, Image, Plain, Record
from astrbot.api.platform import Platform, PlatformMetadata
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.platform.astrbot_message import (
AstrBotMessage,
MessageMember,
MessageType,
)
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from ...register import register_platform_adapter
from .wechatpadpro_message_event import WeChatPadProMessageEvent
try:
from .xml_data_parser import GeweDataParser
except ImportError as e:
logger.warning(
f"警告: 可能未安装 defusedxml 依赖库,将导致无法解析微信的 表情包、引用 类型的消息: {e!s}",
)
@register_platform_adapter(
"wechatpadpro", "WeChatPadPro 消息平台适配器", support_streaming_message=False
)
class WeChatPadProAdapter(Platform):
def __init__(
self,
platform_config: dict,
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
super().__init__(platform_config, event_queue)
self._shutdown_event = None
self.wxnewpass = None
self.settings = platform_settings
self.unique_session = platform_settings.get("unique_session", False)
self.metadata = PlatformMetadata(
name="wechatpadpro",
description="WeChatPadPro 消息平台适配器",
id=self.config.get("id", "wechatpadpro"),
support_streaming_message=False,
)
# 保存配置信息
self.admin_key = self.config.get("admin_key")
self.host = self.config.get("host")
self.port = self.config.get("port")
self.active_mesasge_poll: bool = self.config.get(
"wpp_active_message_poll",
False,
)
self.active_message_poll_interval: int = self.config.get(
"wpp_active_message_poll_interval",
5,
)
self.base_url = f"http://{self.host}:{self.port}"
self.auth_key = None # 用于保存生成的授权码
self.wxid: str | None = None # 用于保存登录成功后的 wxid
self.credentials_file = os.path.join(
get_astrbot_data_path(),
"wechatpadpro_credentials.json",
) # 持久化文件路径
self.ws_handle_task = None
# 添加图片消息缓存,用于引用消息处理
self.cached_images = {}
"""缓存图片消息。key是NewMsgId (对应引用消息的svrid)value是图片的base64数据"""
# 设置缓存大小限制,避免内存占用过大
self.max_image_cache = 50
# 添加文本消息缓存,用于引用消息处理
self.cached_texts = {}
"""缓存文本消息。key是NewMsgId (对应引用消息的svrid)value是消息文本内容"""
# 设置文本缓存大小限制
self.max_text_cache = 100
async def run(self) -> None:
"""启动平台适配器的运行实例。"""
logger.info("WeChatPadPro 适配器正在启动...")
if loaded_credentials := self.load_credentials():
self.auth_key = loaded_credentials.get("auth_key")
self.wxid = loaded_credentials.get("wxid")
isLoginIn = await self.check_online_status()
# 检查在线状态
if self.auth_key and isLoginIn:
logger.info("WeChatPadPro 设备已在线,凭据存在,跳过扫码登录。")
# 如果在线,连接 WebSocket 接收消息
self.ws_handle_task = asyncio.create_task(self.connect_websocket())
else:
# 1. 生成授权码
if not self.auth_key:
logger.info("WeChatPadPro 无可用凭据,将生成新的授权码。")
await self.generate_auth_key()
# 2. 获取登录二维码
if not isLoginIn:
logger.info("WeChatPadPro 设备已离线,开始扫码登录。")
qr_code_url = await self.get_login_qr_code()
if qr_code_url:
logger.info(f"请扫描以下二维码登录: {qr_code_url}")
else:
logger.error("无法获取登录二维码。")
return
# 3. 检测扫码状态
login_successful = await self.check_login_status()
if login_successful:
logger.info("登录成功,WeChatPadPro适配器已连接。")
else:
logger.warning("登录失败或超时,WeChatPadPro 适配器将关闭。")
await self.terminate()
return
# 登录成功后,连接 WebSocket 接收消息
self.ws_handle_task = asyncio.create_task(self.connect_websocket())
self._shutdown_event = asyncio.Event()
await self._shutdown_event.wait()
logger.info("WeChatPadPro 适配器已停止。")
def load_credentials(self):
"""从文件中加载 auth_key 和 wxid。"""
if os.path.exists(self.credentials_file):
try:
with open(self.credentials_file) as f:
credentials = json.load(f)
logger.info("成功加载 WeChatPadPro 凭据。")
return credentials
except Exception as e:
logger.error(f"加载 WeChatPadPro 凭据失败: {e}")
return None
def save_credentials(self):
"""将 auth_key 和 wxid 保存到文件。"""
credentials = {
"auth_key": self.auth_key,
"wxid": self.wxid,
}
try:
# 确保数据目录存在
data_dir = os.path.dirname(self.credentials_file)
os.makedirs(data_dir, exist_ok=True)
with open(self.credentials_file, "w") as f:
json.dump(credentials, f)
except Exception as e:
logger.error(f"保存 WeChatPadPro 凭据失败: {e}")
async def check_online_status(self):
"""检查 WeChatPadPro 设备是否在线。"""
if not self.auth_key:
return False
url = f"{self.base_url}/login/GetLoginStatus"
params = {"key": self.auth_key}
async with aiohttp.ClientSession() as session:
try:
async with session.get(url, params=params) as response:
response_data = await response.json()
# 根据提供的在线接口返回示例,成功状态码是 200,loginState 为 1 表示在线
if response.status == 200 and response_data.get("Code") == 200:
login_state = response_data.get("Data", {}).get("loginState")
if login_state == 1:
logger.info("WeChatPadPro 设备当前在线。")
return True
# login_state == 3 为离线状态
if login_state == 3:
logger.info("WeChatPadPro 设备不在线。")
return False
logger.error(f"未知的在线状态: {response_data}")
return False
# Code == 300 为微信退出状态。
if response.status == 200 and response_data.get("Code") == 300:
logger.info("WeChatPadPro 设备已退出。")
return False
if response.status == 200 and response_data.get("Code") == -2:
# 该链接不存在
self.auth_key = None
return False
logger.error(
f"检查在线状态失败: {response.status}, {response_data}",
)
return False
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
return False
except Exception as e:
logger.error(f"检查在线状态时发生错误: {e}")
logger.error(traceback.format_exc())
return False
def _extract_auth_key(self, data):
"""Helper method to extract auth_key from response data."""
if isinstance(data, dict):
auth_keys = data.get("authKeys") # 新接口
if isinstance(auth_keys, list) and auth_keys:
return auth_keys[0]
elif isinstance(data, list) and data: # 旧接口
return data[0]
return None
async def generate_auth_key(self):
"""生成授权码。"""
url = f"{self.base_url}/admin/GenAuthKey1"
params = {"key": self.admin_key}
payload = {"Count": 1, "Days": 365} # 生成一个有效期365天的授权码
self.auth_key = None # Reset auth_key before generating a new one
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, params=params, json=payload) as response:
if response.status != 200:
logger.error(
f"生成授权码失败: {response.status}, {await response.text()}",
)
return
response_data = await response.json()
if response_data.get("Code") == 200:
if data := response_data.get("Data"):
self.auth_key = self._extract_auth_key(data)
if self.auth_key:
logger.info("成功获取授权码")
else:
logger.error(
f"生成授权码成功但未找到授权码: {response_data}",
)
else:
logger.error(f"生成授权码失败: {response_data}")
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
except Exception as e:
logger.error(f"生成授权码时发生错误: {e}")
async def get_login_qr_code(self):
"""获取登录二维码地址。"""
url = f"{self.base_url}/login/GetLoginQrCodeNew"
params = {"key": self.auth_key}
payload = {} # 根据文档,这个接口的 body 可以为空
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, params=params, json=payload) as response:
response_data = await response.json()
if response.status == 200 and response_data.get("Code") == 200:
# 二维码地址在 Data.QrCodeUrl 字段中
if response_data.get("Data") and response_data["Data"].get(
"QrCodeUrl",
):
return response_data["Data"]["QrCodeUrl"]
logger.error(
f"获取登录二维码成功但未找到二维码地址: {response_data}",
)
return None
if "该 key 无效" in response_data.get("Text"):
logger.error(
"授权码无效,已经清除。请重新启动 AstrBot 或者本消息适配器。原因也可能是 WeChatPadPro 的 MySQL 服务没有启动成功,请检查 WeChatPadPro 服务的日志。",
)
self.auth_key = None
self.save_credentials()
return None
logger.error(
f"获取登录二维码失败: {response.status}, {response_data}",
)
return None
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
return None
except Exception as e:
logger.error(f"获取登录二维码时发生错误: {e}")
return None
async def check_login_status(self):
"""循环检测扫码状态。
尝试 6 次后跳出循环添加倒计时
返回 True 如果登录成功否则返回 False
"""
url = f"{self.base_url}/login/CheckLoginStatus"
params = {"key": self.auth_key}
attempts = 0 # 初始化尝试次数
max_attempts = 36 # 最大尝试次数
countdown = 180 # 倒计时时长
logger.info(f"请在 {countdown} 秒内扫码登录。")
while attempts < max_attempts:
async with aiohttp.ClientSession() as session:
try:
async with session.get(url, params=params) as response:
response_data = await response.json()
# 成功判断条件和数据提取路径
if response.status == 200 and response_data.get("Code") == 200:
if (
response_data.get("Data")
and response_data["Data"].get("state") is not None
):
status = response_data["Data"]["state"]
logger.info(
f"{attempts + 1} 次尝试,当前登录状态: {status},还剩{countdown - attempts * 5}",
)
if status == 2: # 状态 2 表示登录成功
self.wxid = response_data["Data"].get("wxid")
self.wxnewpass = response_data["Data"].get(
"wxnewpass",
)
logger.info(
f"登录成功,wxid: {self.wxid}, wxnewpass: {self.wxnewpass}",
)
self.save_credentials() # 登录成功后保存凭据
return True
if status == -2: # 二维码过期
logger.error("二维码已过期,请重新获取。")
return False
else:
logger.error(
f"检测登录状态成功但未找到登录状态: {response_data}",
)
elif response_data.get("Code") == 300:
# "不存在状态"
pass
else:
logger.info(
f"检测登录状态失败: {response.status}, {response_data}",
)
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
await asyncio.sleep(5)
attempts += 1
continue
except Exception as e:
logger.error(f"检测登录状态时发生错误: {e}")
attempts += 1
continue
attempts += 1
await asyncio.sleep(5) # 每隔5秒检测一次
logger.warning("登录检测超过最大尝试次数,退出检测。")
return False
async def connect_websocket(self):
"""建立 WebSocket 连接并处理接收到的消息。"""
os.environ["no_proxy"] = f"localhost,127.0.0.1,{self.host}"
ws_url = f"ws://{self.host}:{self.port}/ws/GetSyncMsg?key={self.auth_key}"
logger.info(
f"正在连接 WebSocket: ws://{self.host}:{self.port}/ws/GetSyncMsg?key=***",
)
while True:
try:
async with websockets.connect(ws_url) as websocket:
logger.debug("WebSocket 连接成功。")
# 设置空闲超时重连
wait_time = (
self.active_message_poll_interval
if self.active_mesasge_poll
else 120
)
while True:
try:
message = await asyncio.wait_for(
websocket.recv(),
timeout=wait_time,
)
# logger.debug(message) # 不显示原始消息内容
asyncio.create_task(self.handle_websocket_message(message))
except asyncio.TimeoutError:
logger.debug(f"WebSocket 连接空闲超过 {wait_time} s")
break
except websockets.exceptions.ConnectionClosedOK:
logger.info("WebSocket 连接正常关闭。")
break
except Exception as e:
logger.error(f"处理 WebSocket 消息时发生错误: {e}")
break
except Exception as e:
logger.error(
f"WebSocket 连接失败: {e}, 请检查WeChatPadPro服务状态,或尝试重启WeChatPadPro适配器。",
)
await asyncio.sleep(5)
async def handle_websocket_message(self, message: str | bytes):
"""处理从 WebSocket 接收到的消息。"""
logger.debug(f"收到 WebSocket 消息: {message}")
try:
message_data = json.loads(message)
if (
message_data.get("msg_id") is not None
and message_data.get("from_user_name") is not None
):
abm = await self.convert_message(message_data)
if abm:
# 创建 WeChatPadProMessageEvent 实例
message_event = WeChatPadProMessageEvent(
message_str=abm.message_str,
message_obj=abm,
platform_meta=self.meta(),
session_id=abm.session_id,
# 传递适配器实例,以便在事件中调用 send 方法
adapter=self,
)
# 提交事件到事件队列
self.commit_event(message_event)
else:
logger.warning(f"收到未知结构的 WebSocket 消息: {message_data}")
except json.JSONDecodeError:
logger.error(f"无法解析 WebSocket 消息为 JSON: {message}")
except Exception as e:
logger.error(f"处理 WebSocket 消息时发生错误: {e}")
async def convert_message(self, raw_message: dict) -> AstrBotMessage | None:
"""将 WeChatPadPro 原始消息转换为 AstrBotMessage。"""
if self.wxid is None:
logger.error("WeChatPadPro 适配器未登录或未获取到 wxid,无法处理消息。")
return None
abm = AstrBotMessage()
abm.raw_message = raw_message
abm.message_id = str(raw_message.get("msg_id"))
abm.timestamp = cast(int, raw_message.get("create_time"))
abm.self_id = self.wxid
if int(time.time()) - abm.timestamp > 180:
logger.warning(
f"忽略 3 分钟前的旧消息:消息时间戳 {abm.timestamp} 超过当前时间 {int(time.time())}",
)
return None
from_user_name = raw_message.get("from_user_name", {}).get("str", "")
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
content = raw_message.get("content", {}).get("str", "")
push_content = raw_message.get("push_content", "")
msg_type = cast(int, raw_message.get("msg_type"))
abm.message_str = ""
abm.message = []
# 如果是机器人自己发送的消息、回显消息或系统消息,忽略
if from_user_name == self.wxid:
logger.info("忽略来自自己的消息。")
return None
if from_user_name in ["weixin", "newsapp", "newsapp_wechat"]:
logger.info("忽略来自微信团队的消息。")
return None
# 先判断群聊/私聊并设置基本属性
if await self._process_chat_type(
abm,
raw_message,
from_user_name,
to_user_name,
content,
push_content,
):
# 再根据消息类型处理消息内容
await self._process_message_content(abm, raw_message, msg_type, content)
return abm
return None
async def _process_chat_type(
self,
abm: AstrBotMessage,
raw_message: dict,
from_user_name: str,
to_user_name: str,
content: str,
push_content: str,
):
"""判断消息是群聊还是私聊,并设置 AstrBotMessage 的基本属性。"""
if from_user_name == "weixin":
return False
at_me = False
if "@chatroom" in from_user_name:
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = from_user_name
parts = content.split(":\n", 1)
sender_wxid = parts[0] if len(parts) == 2 else ""
abm.sender = MessageMember(user_id=sender_wxid, nickname="")
# 获取群聊发送者的nickname
if sender_wxid:
accurate_nickname = await self._get_group_member_nickname(
abm.group_id,
sender_wxid,
)
if accurate_nickname:
abm.sender.nickname = accurate_nickname
# 对于群聊,session_id 可以是群聊 ID 或发送者 ID + 群聊 ID (如果 unique_session 为 True)
if self.unique_session:
abm.session_id = f"{from_user_name}#{abm.sender.user_id}"
else:
abm.session_id = from_user_name
msg_source = raw_message.get("msg_source", "")
if self.wxid in msg_source:
at_me = True
if "在群聊中@了你" in raw_message.get("push_content", ""):
at_me = True
if at_me:
abm.message.insert(0, At(qq=abm.self_id, name=""))
else:
abm.type = MessageType.FRIEND_MESSAGE
abm.group_id = ""
nick_name = ""
if push_content and " : " in push_content:
nick_name = push_content.split(" : ")[0]
abm.sender = MessageMember(user_id=from_user_name, nickname=nick_name)
abm.session_id = from_user_name
return True
async def _get_group_member_nickname(
self,
group_id: str,
member_wxid: str,
) -> str | None:
"""通过接口获取群成员的昵称。"""
url = f"{self.base_url}/group/GetChatroomMemberDetail"
params = {"key": self.auth_key}
payload = {
"ChatRoomName": group_id,
}
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, params=params, json=payload) as response:
response_data = await response.json()
if response.status == 200 and response_data.get("Code") == 200:
# 从返回数据中查找对应成员的昵称
member_list = (
response_data.get("Data", {})
.get("member_data", {})
.get("chatroom_member_list", [])
)
for member in member_list:
if member.get("user_name") == member_wxid:
return member.get("nick_name")
logger.warning(
f"在群 {group_id} 中未找到成员 {member_wxid} 的昵称",
)
else:
logger.error(
f"获取群成员详情失败: {response.status}, {response_data}",
)
return None
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
return None
except Exception as e:
logger.error(f"获取群成员详情时发生错误: {e}")
return None
async def _download_raw_image(
self,
from_user_name: str,
to_user_name: str,
msg_id: int,
) -> dict | None:
"""下载原始图片。"""
url = f"{self.base_url}/message/GetMsgBigImg"
params = {"key": self.auth_key}
payload = {
"CompressType": 0,
"FromUserName": from_user_name,
"MsgId": msg_id,
"Section": {"DataLen": 61440, "StartPos": 0},
"ToUserName": to_user_name,
"TotalLen": 0,
}
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, params=params, json=payload) as response:
if response.status == 200:
return await response.json()
logger.error(f"下载图片失败: {response.status}")
return None
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
return None
except Exception as e:
logger.error(f"下载图片时发生错误: {e}")
return None
async def download_voice(
self,
to_user_name: str,
new_msg_id: str,
bufid: str,
length: int,
):
"""下载原始音频。"""
url = f"{self.base_url}/message/GetMsgVoice"
params = {"key": self.auth_key}
payload = {
"Bufid": bufid,
"ToUserName": to_user_name,
"NewMsgId": new_msg_id,
"Length": length,
}
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, params=params, json=payload) as response:
if response.status == 200:
return await response.json()
logger.error(f"下载音频失败: {response.status}")
return None
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
return None
except Exception as e:
logger.error(f"下载音频时发生错误: {e}")
return None
async def _process_message_content(
self,
abm: AstrBotMessage,
raw_message: dict,
msg_type: int,
content: str,
):
"""根据消息类型处理消息内容,填充 AstrBotMessage 的 message 列表。"""
if msg_type == 1: # 文本消息
abm.message_str = content
if abm.type == MessageType.GROUP_MESSAGE:
parts = content.split(":\n", 1)
if len(parts) == 2:
message_content = parts[1]
abm.message_str = message_content
# 检查是否@了机器人,参考 gewechat 的实现方式
# 微信大部分客户端在@用户昵称后面,紧接着是一个\u2005字符(四分之一空格)
at_me = False
# 检查 msg_source 中是否包含机器人的 wxid
# wechatpadpro 的格式: <atuserlist>wxid</atuserlist>
# gewechat 的格式: <atuserlist><![CDATA[wxid]]></atuserlist>
msg_source = raw_message.get("msg_source", "")
if (
f"<atuserlist>{abm.self_id}</atuserlist>" in msg_source
or f"<atuserlist>{abm.self_id}," in msg_source
or f",{abm.self_id}</atuserlist>" in msg_source
):
at_me = True
# 也检查 push_content 中是否有@提示
push_content = raw_message.get("push_content", "")
if "在群聊中@了你" in push_content:
at_me = True
if at_me:
# 被@了,在消息开头插入At组件(参考gewechat的做法)
bot_nickname = await self._get_group_member_nickname(
abm.group_id,
abm.self_id,
)
abm.message.insert(
0,
At(qq=abm.self_id, name=bot_nickname or abm.self_id),
)
# 只有当消息内容不仅仅是@时才添加Plain组件
if "\u2005" in message_content:
# 检查@之后是否还有其他内容
parts = message_content.split("\u2005")
if len(parts) > 1 and any(
part.strip() for part in parts[1:]
):
abm.message.append(Plain(message_content))
else:
# 检查是否只包含@机器人
is_pure_at = False
if (
bot_nickname
and message_content.strip() == f"@{bot_nickname}"
):
is_pure_at = True
if not is_pure_at:
abm.message.append(Plain(message_content))
else:
# 没有@机器人,作为普通文本处理
abm.message.append(Plain(message_content))
else:
abm.message.append(Plain(abm.message_str))
else: # 私聊消息
abm.message.append(Plain(abm.message_str))
# 缓存文本消息,以便引用消息可以查找
try:
# 获取msg_id作为缓存的key
new_msg_id = raw_message.get("new_msg_id")
if new_msg_id:
# 限制缓存大小
if (
len(self.cached_texts) >= self.max_text_cache
and self.cached_texts
):
# 删除最早的一条缓存
oldest_key = next(iter(self.cached_texts))
self.cached_texts.pop(oldest_key)
logger.debug(f"缓存文本消息,new_msg_id={new_msg_id}")
self.cached_texts[str(new_msg_id)] = content
except Exception as e:
logger.error(f"缓存文本消息失败: {e}")
elif msg_type == 3:
# 图片消息
from_user_name = raw_message.get("from_user_name", {}).get("str", "")
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
msg_id = cast(int, raw_message.get("msg_id"))
image_resp = await self._download_raw_image(
from_user_name,
to_user_name,
msg_id,
)
if image_resp is None:
logger.error(f"下载图片失败: msg_id={msg_id}")
return
image_bs64_data = (
image_resp.get("Data", {}).get("Data", {}).get("Buffer", None)
)
if image_bs64_data:
abm.message.append(Image.fromBase64(image_bs64_data))
# 缓存图片,以便引用消息可以查找
try:
# 获取msg_id作为缓存的key
new_msg_id = raw_message.get("new_msg_id")
if new_msg_id:
# 限制缓存大小
if (
len(self.cached_images) >= self.max_image_cache
and self.cached_images
):
# 删除最早的一条缓存
oldest_key = next(iter(self.cached_images))
self.cached_images.pop(oldest_key)
logger.debug(f"缓存图片消息,new_msg_id={new_msg_id}")
self.cached_images[str(new_msg_id)] = image_bs64_data
except Exception as e:
logger.error(f"缓存图片消息失败: {e}")
elif msg_type == 47:
# 视频消息 (注意:表情消息也是 47,需要区分)
data_parser = GeweDataParser(
content=content,
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
raw_message=raw_message,
)
emoji_message = data_parser.parse_emoji()
if emoji_message is not None:
abm.message.append(emoji_message)
elif msg_type == 50:
logger.warning("收到语音/视频消息,待实现。")
elif msg_type == 34:
# 语音消息
bufid = 0
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
new_msg_id = raw_message.get("new_msg_id")
if new_msg_id is None:
logger.error("语音消息缺少 new_msg_id")
return
data_parser = GeweDataParser(
content=content,
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
raw_message=raw_message,
)
voicemsg = data_parser._format_to_xml().find("voicemsg")
if voicemsg is None:
logger.error("无法从 XML 解析 voicemsg 节点")
return
bufid = voicemsg.get("bufid") or "0"
length = int(voicemsg.get("length") or 0)
voice_resp = await self.download_voice(
to_user_name=to_user_name,
new_msg_id=new_msg_id,
bufid=bufid,
length=length,
)
if voice_resp is None:
logger.error(f"下载语音失败: new_msg_id={new_msg_id}")
return
voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None)
if voice_bs64_data:
voice_bs64_data = base64.b64decode(voice_bs64_data)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
file_path = os.path.join(
temp_dir,
f"wechatpadpro_voice_{abm.message_id}.silk",
)
async with await anyio.open_file(file_path, "wb") as f:
await f.write(voice_bs64_data)
abm.message.append(Record(file=file_path, url=file_path))
elif msg_type == 49:
try:
parser = GeweDataParser(
content=content,
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
cached_texts=self.cached_texts,
cached_images=self.cached_images,
raw_message=raw_message,
downloader=self._download_raw_image,
)
components = await parser.parse_mutil_49()
if components:
abm.message.extend(components)
abm.message_str = "\n".join(
c.text for c in components if isinstance(c, Plain)
)
except Exception as e:
logger.warning(f"msg_type 49 处理失败: {e}")
abm.message.append(Plain("[XML 消息处理失败]"))
abm.message_str = "[XML 消息处理失败]"
else:
logger.warning(f"收到未处理的消息类型: {msg_type}")
async def terminate(self):
"""终止一个平台的运行实例。"""
logger.info("终止 WeChatPadPro 适配器。")
try:
if self.ws_handle_task:
self.ws_handle_task.cancel()
if self._shutdown_event is not None:
self._shutdown_event.set()
except Exception:
pass
def meta(self) -> PlatformMetadata:
"""得到一个平台的元数据。"""
return self.metadata
async def send_by_session(
self,
session: MessageSesion,
message_chain: MessageChain,
):
dummy_message_obj = AstrBotMessage()
dummy_message_obj.session_id = session.session_id
# 根据 session_id 判断消息类型
if "@chatroom" in session.session_id:
dummy_message_obj.type = MessageType.GROUP_MESSAGE
if "#" in session.session_id:
dummy_message_obj.group_id = session.session_id.split("#")[0]
else:
dummy_message_obj.group_id = session.session_id
dummy_message_obj.sender = MessageMember(user_id="", nickname="")
else:
dummy_message_obj.type = MessageType.FRIEND_MESSAGE
dummy_message_obj.group_id = ""
dummy_message_obj.sender = MessageMember(user_id="", nickname="")
sending_event = WeChatPadProMessageEvent(
message_str="",
message_obj=dummy_message_obj,
platform_meta=self.meta(),
session_id=session.session_id,
adapter=self,
)
# 调用实例方法 send
await sending_event.send(message_chain)
async def get_contact_list(self):
"""获取联系人列表。"""
url = f"{self.base_url}/friend/GetContactList"
params = {"key": self.auth_key}
payload = {"CurrentChatRoomContactSeq": 0, "CurrentWxcontactSeq": 0}
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, params=params, json=payload) as response:
if response.status != 200:
logger.error(f"获取联系人列表失败: {response.status}")
return None
result = await response.json()
if result.get("Code") == 200 and result.get("Data"):
contact_list = (
result.get("Data", {})
.get("ContactList", {})
.get("contactUsernameList", [])
)
return contact_list
logger.error(f"获取联系人列表失败: {result}")
return None
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
return None
except Exception as e:
logger.error(f"获取联系人列表时发生错误: {e}")
return None
async def get_contact_details_list(
self,
room_wx_id_list: list[str] | None = None,
user_names: list[str] | None = None,
) -> dict | None:
"""获取联系人详情列表。"""
if room_wx_id_list is None:
room_wx_id_list = []
if user_names is None:
user_names = []
url = f"{self.base_url}/friend/GetContactDetailsList"
params = {"key": self.auth_key}
payload = {"RoomWxIDList": room_wx_id_list, "UserNames": user_names}
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, params=params, json=payload) as response:
if response.status != 200:
logger.error(f"获取联系人详情列表失败: {response.status}")
return None
result = await response.json()
if result.get("Code") == 200 and result.get("Data"):
contact_list = result.get("Data", {}).get("contactList", {})
return contact_list
logger.error(f"获取联系人详情列表失败: {result}")
return None
except aiohttp.ClientConnectorError as e:
logger.error(f"连接到 WeChatPadPro 服务失败: {e}")
return None
except Exception as e:
logger.error(f"获取联系人详情列表时发生错误: {e}")
return None
@@ -1,178 +0,0 @@
import asyncio
import base64
import io
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING
import aiohttp
from PIL import Image as PILImage # 使用别名避免冲突
from astrbot import logger
from astrbot.core.message.components import (
Image,
Plain,
Record,
WechatEmoji,
) # Import Image
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageType
from astrbot.core.platform.platform_metadata import PlatformMetadata
from astrbot.core.utils.tencent_record_helper import audio_to_tencent_silk_base64
if TYPE_CHECKING:
from .wechatpadpro_adapter import WeChatPadProAdapter
class WeChatPadProMessageEvent(AstrMessageEvent):
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
adapter: "WeChatPadProAdapter", # 传递适配器实例
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.message_obj = message_obj # Save the full message object
self.adapter = adapter # Save the adapter instance
async def send(self, message: MessageChain):
async with aiohttp.ClientSession() as session:
for comp in message.chain:
await asyncio.sleep(1)
if isinstance(comp, Plain):
await self._send_text(session, comp.text)
elif isinstance(comp, Image):
await self._send_image(session, comp)
elif isinstance(comp, WechatEmoji):
await self._send_emoji(session, comp)
elif isinstance(comp, Record):
await self._send_voice(session, comp)
await super().send(message)
async def send_streaming(
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
):
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return None
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
async def _send_image(self, session: aiohttp.ClientSession, comp: Image):
b64 = await comp.convert_to_base64()
raw = self._validate_base64(b64)
b64c = self._compress_image(raw)
payload = {
"MsgItem": [
{"ImageContent": b64c, "MsgType": 3, "ToUserName": self.session_id},
],
}
url = f"{self.adapter.base_url}/message/SendImageNewMessage"
await self._post(session, url, payload)
async def _send_text(self, session: aiohttp.ClientSession, text: str):
if (
self.message_obj.type == MessageType.GROUP_MESSAGE # 确保是群聊消息
and self.adapter.settings.get(
"reply_with_mention",
False,
) # 检查适配器设置是否启用 reply_with_mention
and self.message_obj.sender # 确保有发送者信息
and (
self.message_obj.sender.user_id or self.message_obj.sender.nickname
) # 确保发送者有 ID 或昵称
):
# 优先使用 nickname,如果没有则使用 user_id
mention_text = (
self.message_obj.sender.nickname or self.message_obj.sender.user_id
)
message_text = f"@{mention_text} {text}"
# logger.info(f"已添加 @ 信息: {message_text}")
else:
message_text = text
if self.get_group_id() and "#" in self.session_id:
session_id = self.session_id.split("#")[0]
else:
session_id = self.session_id
payload = {
"MsgItem": [
{
"MsgType": 1,
"TextContent": message_text,
"ToUserName": session_id,
},
],
}
url = f"{self.adapter.base_url}/message/SendTextMessage"
await self._post(session, url, payload)
async def _send_emoji(self, session: aiohttp.ClientSession, comp: WechatEmoji):
payload = {
"EmojiList": [
{
"EmojiMd5": comp.md5,
"EmojiSize": comp.md5_len,
"ToUserName": self.session_id,
},
],
}
url = f"{self.adapter.base_url}/message/SendEmojiMessage"
await self._post(session, url, payload)
async def _send_voice(self, session: aiohttp.ClientSession, comp: Record):
record_path = await comp.convert_to_file_path()
# 默认已经存在 data/temp 中
b64, duration = await audio_to_tencent_silk_base64(record_path)
payload = {
"ToUserName": self.session_id,
"VoiceData": b64,
"VoiceFormat": 4,
"VoiceSecond": duration,
}
url = f"{self.adapter.base_url}/message/SendVoice"
await self._post(session, url, payload)
@staticmethod
def _validate_base64(b64: str) -> bytes:
return base64.b64decode(b64, validate=True)
@staticmethod
def _compress_image(data: bytes) -> str:
img = PILImage.open(io.BytesIO(data))
buf = io.BytesIO()
if img.format == "JPEG":
img.save(buf, "JPEG", quality=80)
else:
if img.mode in ("RGBA", "P"):
img = img.convert("RGB")
img.save(buf, "JPEG", quality=80)
# logger.info("图片处理完成!!!")
return base64.b64encode(buf.getvalue()).decode()
async def _post(self, session, url, payload):
params = {"key": self.adapter.auth_key}
try:
async with session.post(url, params=params, json=payload) as resp:
data = await resp.json()
if resp.status != 200 or data.get("Code") != 200:
logger.error(f"{url} failed: {resp.status} {data}")
except Exception as e:
logger.error(f"{url} error: {e}")
# TODO: 添加对其他消息组件类型的处理 (Record, Video, At等)
# elif isinstance(component, Record):
# pass
# elif isinstance(component, Video):
# pass
# elif isinstance(component, At):
# pass
# ...
@@ -1,159 +0,0 @@
from defusedxml import ElementTree as eT
from astrbot.api import logger
from astrbot.api.message_components import (
BaseMessageComponent,
Image,
Plain,
)
from astrbot.api.message_components import (
WechatEmoji as Emoji,
)
class GeweDataParser:
def __init__(
self,
content: str,
is_private_chat: bool = False,
cached_texts=None,
cached_images=None,
raw_message: dict | None = None,
downloader=None,
):
self._xml = None
self.content = content
self.is_private_chat = is_private_chat
self.cached_texts = cached_texts or {}
self.cached_images = cached_images or {}
self.downloader = downloader
raw_message = raw_message or {}
self.from_user_name = raw_message.get("from_user_name", {}).get("str", "")
self.to_user_name = raw_message.get("to_user_name", {}).get("str", "")
self.msg_id = raw_message.get("msg_id", "")
def _format_to_xml(self):
if self._xml:
return self._xml
try:
msg_str = self.content
if not self.is_private_chat:
parts = self.content.split(":\n", 1)
msg_str = parts[1] if len(parts) == 2 else self.content
self._xml = eT.fromstring(msg_str)
return self._xml
except Exception as e:
logger.error(f"[XML解析失败] {e}")
raise
async def parse_mutil_49(self) -> list[BaseMessageComponent] | None:
"""处理 msg_type == 49 的多种 appmsg 类型(目前支持 type==57"""
try:
appmsg_type = self._format_to_xml().findtext(".//appmsg/type")
if appmsg_type == "57":
return await self.parse_reply()
except Exception as e:
logger.warning(f"[parse_mutil_49] 解析失败: {e}")
return None
async def parse_reply(self) -> list[BaseMessageComponent]:
"""处理 type == 57 的引用消息:支持文本(1)、图片(3)、嵌套49(49)"""
components = []
try:
appmsg = self._format_to_xml().find("appmsg")
if appmsg is None:
return [Plain("[引用消息解析失败]")]
refermsg = appmsg.find("refermsg")
if refermsg is None:
return [Plain("[引用消息解析失败]")]
quote_type = int(refermsg.findtext("type", "0"))
nickname = refermsg.findtext("displayname", "未知发送者")
quote_content = refermsg.findtext("content", "")
svrid = refermsg.findtext("svrid")
match quote_type:
case 1: # 文本引用
quoted_text = self.cached_texts.get(str(svrid), quote_content)
components.append(Plain(f"[引用] {nickname}: {quoted_text}"))
case 3: # 图片引用
quoted_image_b64 = self.cached_images.get(str(svrid))
if not quoted_image_b64:
try:
quote_xml = eT.fromstring(quote_content)
img = quote_xml.find("img")
cdn_url = (
img.get("cdnbigimgurl") or img.get("cdnmidimgurl")
if img is not None
else None
)
if cdn_url and self.downloader:
image_resp = await self.downloader(
self.from_user_name,
self.to_user_name,
self.msg_id,
)
quoted_image_b64 = (
image_resp.get("Data", {})
.get("Data", {})
.get("Buffer")
)
except Exception as e:
logger.warning(f"[引用图片解析失败] svrid={svrid} err={e}")
if quoted_image_b64:
components.extend(
[
Image.fromBase64(quoted_image_b64),
Plain(f"[引用] {nickname}: [引用的图片]"),
],
)
else:
components.append(
Plain(f"[引用] {nickname}: [引用的图片 - 未能获取]"),
)
case 49: # 嵌套引用
try:
nested_root = eT.fromstring(quote_content)
nested_title = nested_root.findtext(".//appmsg/title", "")
components.append(Plain(f"[引用] {nickname}: {nested_title}"))
except Exception as e:
logger.warning(f"[嵌套引用解析失败] err={e}")
components.append(Plain(f"[引用] {nickname}: [嵌套引用消息]"))
case _: # 其他未识别类型
logger.info(f"[未知引用类型] quote_type={quote_type}")
components.append(Plain(f"[引用] {nickname}: [不支持的引用类型]"))
# 主消息标题
title = appmsg.findtext("title", "")
if title:
components.append(Plain(title))
except Exception as e:
logger.error(f"[parse_reply] 总体解析失败: {e}")
return [Plain("[引用消息解析失败]")]
return components
def parse_emoji(self) -> Emoji | None:
"""处理 msg_type == 47 的表情消息(emoji"""
try:
emoji_element = self._format_to_xml().find(".//emoji")
if emoji_element is not None:
return Emoji(
md5=emoji_element.get("md5"),
md5_len=emoji_element.get("len"),
cdnurl=emoji_element.get("cdnurl"),
)
except Exception as e:
logger.error(f"[parse_emoji] 解析失败: {e}")
return None
@@ -191,7 +191,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
if self.active_send_mode:
await self.convert_message(msg, None)
else:
if msg.id in self.wexin_event_workers:
if str(msg.id) in self.wexin_event_workers:
future = self.wexin_event_workers[str(cast(str | int, msg.id))]
logger.debug(f"duplicate message id checked: {msg.id}")
else:
+46 -10
View File
@@ -14,6 +14,7 @@ import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.core.agent.message import (
AssistantMessageSegment,
ContentPart,
ToolCall,
ToolCallMessageSegment,
)
@@ -92,6 +93,8 @@ class ProviderRequest:
"""会话 ID"""
image_urls: list[str] = field(default_factory=list)
"""图片 URL 列表"""
extra_user_content_parts: list[ContentPart] = field(default_factory=list)
"""额外的用户消息内容部分列表,用于在用户消息后添加额外的内容块(如系统提醒、指令等)。支持 dict 或 ContentPart 对象"""
func_tool: ToolSet | None = None
"""可用的函数工具"""
contexts: list[dict] = field(default_factory=list)
@@ -166,13 +169,23 @@ class ProviderRequest:
async def assemble_context(self) -> dict:
"""将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。"""
# 构建内容块列表
content_blocks = []
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
if self.prompt and self.prompt.strip():
content_blocks.append({"type": "text", "text": self.prompt})
elif self.image_urls:
# 如果没有文本但有图片,添加占位文本
content_blocks.append({"type": "text", "text": "[图片]"})
# 2. 额外的内容块(系统提醒、指令等)
if self.extra_user_content_parts:
for part in self.extra_user_content_parts:
content_blocks.append(part.model_dump())
# 3. 图片内容
if self.image_urls:
user_content = {
"role": "user",
"content": [
{"type": "text", "text": self.prompt if self.prompt else "[图片]"},
],
}
for image_url in self.image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
@@ -185,11 +198,21 @@ class ProviderRequest:
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
user_content["content"].append(
content_blocks.append(
{"type": "image_url", "image_url": {"url": image_data}},
)
return user_content
return {"role": "user", "content": self.prompt}
# 只有当只有一个来自 prompt 的文本块且没有额外内容块时,才降级为简单格式以保持向后兼容
if (
len(content_blocks) == 1
and content_blocks[0]["type"] == "text"
and not self.extra_user_content_parts
and not self.image_urls
):
return {"role": "user", "content": content_blocks[0]["text"]}
# 否则返回多模态格式
return {"role": "user", "content": content_blocks}
async def _encode_image_bs64(self, image_url: str) -> str:
"""将图片转换为 base64"""
@@ -249,6 +272,8 @@ class LLMResponse:
"""Tool call extra content. tool_call_id -> extra_content dict"""
reasoning_content: str = ""
"""The reasoning content extracted from the LLM, if any."""
reasoning_signature: str | None = None
"""The signature of the reasoning content, if any."""
raw_completion: (
ChatCompletion | GenerateContentResponse | AnthropicMessage | None
@@ -269,12 +294,14 @@ class LLMResponse:
def __init__(
self,
role: str,
completion_text: str = "",
completion_text: str | None = None,
result_chain: MessageChain | None = None,
tools_call_args: list[dict[str, Any]] | None = None,
tools_call_name: list[str] | None = None,
tools_call_ids: list[str] | None = None,
tools_call_extra_content: dict[str, dict[str, Any]] | None = None,
reasoning_content: str | None = None,
reasoning_signature: str | None = None,
raw_completion: ChatCompletion
| GenerateContentResponse
| AnthropicMessage
@@ -294,6 +321,8 @@ class LLMResponse:
raw_completion (ChatCompletion, optional): 原始响应, OpenAI 格式. Defaults to None.
"""
if reasoning_content is None:
reasoning_content = ""
if tools_call_args is None:
tools_call_args = []
if tools_call_name is None:
@@ -310,9 +339,16 @@ class LLMResponse:
self.tools_call_name = tools_call_name
self.tools_call_ids = tools_call_ids
self.tools_call_extra_content = tools_call_extra_content
self.reasoning_content = reasoning_content
self.reasoning_signature = reasoning_signature
self.raw_completion = raw_completion
self.is_chunk = is_chunk
if id is not None:
self.id = id
if usage is not None:
self.usage = usage
@property
def completion_text(self):
if self.result_chain:
+27 -12
View File
@@ -119,19 +119,34 @@ class ProviderManager:
TTSProvider,
):
self.curr_tts_provider_inst = prov
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
await sp.put_async(
key="curr_provider_tts",
value=provider_id,
scope="global",
scope_id="global",
)
elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance(
prov,
STTProvider,
):
self.curr_stt_provider_inst = prov
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
await sp.put_async(
key="curr_provider_stt",
value=provider_id,
scope="global",
scope_id="global",
)
elif provider_type == ProviderType.CHAT_COMPLETION and isinstance(
prov,
Provider,
):
self.curr_provider_inst = prov
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
await sp.put_async(
key="curr_provider",
value=provider_id,
scope="global",
scope_id="global",
)
async def get_provider_by_id(self, provider_id: str) -> Providers | None:
"""根据提供商 ID 获取提供商实例"""
@@ -206,21 +221,21 @@ class ProviderManager:
logger.error(traceback.format_exc())
logger.error(e)
selected_provider_id = sp.get(
"curr_provider",
self.provider_settings.get("default_provider_id"),
selected_provider_id = await sp.get_async(
key="curr_provider",
default=self.provider_settings.get("default_provider_id"),
scope="global",
scope_id="global",
)
selected_stt_provider_id = sp.get(
"curr_provider_stt",
self.provider_stt_settings.get("provider_id"),
selected_stt_provider_id = await sp.get_async(
key="curr_provider_stt",
default=self.provider_stt_settings.get("provider_id"),
scope="global",
scope_id="global",
)
selected_tts_provider_id = sp.get(
"curr_provider_tts",
self.provider_tts_settings.get("provider_id"),
selected_tts_provider_id = await sp.get_async(
key="curr_provider_tts",
default=self.provider_tts_settings.get("provider_id"),
scope="global",
scope_id="global",
)
+3 -1
View File
@@ -4,7 +4,7 @@ import os
from collections.abc import AsyncGenerator
from typing import TypeAlias, Union
from astrbot.core.agent.message import Message
from astrbot.core.agent.message import ContentPart, Message
from astrbot.core.agent.tool import ToolSet
from astrbot.core.provider.entities import (
LLMResponse,
@@ -103,6 +103,7 @@ class Provider(AbstractProvider):
system_prompt: str | None = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
**kwargs,
) -> LLMResponse:
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
@@ -114,6 +115,7 @@ class Provider(AbstractProvider):
tools: tool set
contexts: 上下文 prompt 二选一使用
tool_calls_result: 回传给 LLM 的工具调用结果参考: https://platform.openai.com/docs/guides/function-calling
extra_user_content_parts: 额外的内容块列表用于在用户消息后添加额外的文本块如系统提醒指令等
kwargs: 其他参数
Notes:
+140 -30
View File
@@ -11,6 +11,7 @@ from anthropic.types.usage import Usage
from astrbot import logger
from astrbot.api.provider import Provider
from astrbot.core.agent.message import ContentPart, ImageURLPart, TextPart
from astrbot.core.provider.entities import LLMResponse, TokenUsage
from astrbot.core.provider.func_tool_manager import ToolSet
from astrbot.core.utils.io import download_image_by_url
@@ -47,6 +48,8 @@ class ProviderAnthropic(Provider):
base_url=self.base_url,
)
self.thinking_config = provider_config.get("anth_thinking_config", {})
self.set_model(provider_config.get("model", "unknown"))
def _prepare_payload(self, messages: list[dict]):
@@ -63,12 +66,33 @@ class ProviderAnthropic(Provider):
new_messages = []
for message in messages:
if message["role"] == "system":
system_prompt = message["content"]
system_prompt = message["content"] or "<empty system prompt>"
elif message["role"] == "assistant":
blocks = []
if isinstance(message["content"], str):
reasoning_content = ""
thinking_signature = ""
if isinstance(message["content"], str) and message["content"].strip():
blocks.append({"type": "text", "text": message["content"]})
if "tool_calls" in message:
elif isinstance(message["content"], list):
for part in message["content"]:
if part.get("type") == "think":
# only pick the last think part for now
reasoning_content = part.get("think")
thinking_signature = part.get("encrypted")
else:
blocks.append(part)
if reasoning_content and thinking_signature:
blocks.insert(
0,
{
"type": "thinking",
"thinking": reasoning_content,
"signature": thinking_signature,
},
)
if "tool_calls" in message and isinstance(message["tool_calls"], list):
for tool_call in message["tool_calls"]:
blocks.append( # noqa: PERF401
{
@@ -99,7 +123,7 @@ class ProviderAnthropic(Provider):
{
"type": "tool_result",
"tool_use_id": message["tool_call_id"],
"content": message["content"],
"content": message["content"] or "<empty response>",
},
],
},
@@ -132,6 +156,14 @@ class ProviderAnthropic(Provider):
extra_body = self.provider_config.get("custom_extra_body", {})
if "max_tokens" not in payloads:
payloads["max_tokens"] = 1024
if self.thinking_config.get("budget"):
payloads["thinking"] = {
"budget_tokens": self.thinking_config.get("budget"),
"type": "enabled",
}
completion = await self.client.messages.create(
**payloads, stream=False, extra_body=extra_body
)
@@ -149,6 +181,11 @@ class ProviderAnthropic(Provider):
completion_text = str(content_block.text).strip()
llm_response.completion_text = completion_text
if content_block.type == "thinking":
reasoning_content = str(content_block.thinking).strip()
llm_response.reasoning_content = reasoning_content
llm_response.reasoning_signature = content_block.signature
if content_block.type == "tool_use":
llm_response.tools_call_args.append(content_block.input)
llm_response.tools_call_name.append(content_block.name)
@@ -180,6 +217,16 @@ class ProviderAnthropic(Provider):
id = None
usage = TokenUsage()
extra_body = self.provider_config.get("custom_extra_body", {})
reasoning_content = ""
reasoning_signature = ""
if "max_tokens" not in payloads:
payloads["max_tokens"] = 1024
if self.thinking_config.get("budget"):
payloads["thinking"] = {
"budget_tokens": self.thinking_config.get("budget"),
"type": "enabled",
}
async with self.client.messages.stream(
**payloads, extra_body=extra_body
@@ -219,6 +266,21 @@ class ProviderAnthropic(Provider):
usage=usage,
id=id,
)
elif event.delta.type == "thinking_delta":
# 思考增量
reasoning = event.delta.thinking
if reasoning:
yield LLMResponse(
role="assistant",
reasoning_content=reasoning,
is_chunk=True,
usage=usage,
id=id,
reasoning_signature=reasoning_signature or None,
)
reasoning_content += reasoning
elif event.delta.type == "signature_delta":
reasoning_signature = event.delta.signature
elif event.delta.type == "input_json_delta":
# 工具调用参数增量
if event.index in tool_use_buffer:
@@ -275,6 +337,8 @@ class ProviderAnthropic(Provider):
is_chunk=False,
usage=usage,
id=id,
reasoning_content=reasoning_content,
reasoning_signature=reasoning_signature or None,
)
if final_tool_calls:
@@ -296,13 +360,16 @@ class ProviderAnthropic(Provider):
system_prompt=None,
tool_calls_result=None,
model=None,
extra_user_content_parts=None,
**kwargs,
) -> LLMResponse:
if contexts is None:
contexts = []
new_record = None
if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls)
new_record = await self.assemble_context(
prompt, image_urls, extra_user_content_parts
)
context_query = self._ensure_message_to_dicts(contexts)
if new_record:
context_query.append(new_record)
@@ -342,21 +409,24 @@ class ProviderAnthropic(Provider):
async def text_chat_stream(
self,
prompt,
prompt=None,
session_id=None,
image_urls=...,
image_urls=None,
func_tool=None,
contexts=...,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
extra_user_content_parts=None,
**kwargs,
):
if contexts is None:
contexts = []
new_record = None
if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls)
new_record = await self.assemble_context(
prompt, image_urls, extra_user_content_parts
)
context_query = self._ensure_message_to_dicts(contexts)
if new_record:
context_query.append(new_record)
@@ -388,15 +458,15 @@ class ProviderAnthropic(Provider):
async for llm_response in self._query_stream(payloads, func_tool):
yield llm_response
async def assemble_context(self, text: str, image_urls: list[str] | None = None):
async def assemble_context(
self,
text: str,
image_urls: list[str] | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
):
"""组装上下文,支持文本和图片"""
if not image_urls:
return {"role": "user", "content": text}
content = []
content.append({"type": "text", "text": text})
for image_url in image_urls:
async def resolve_image_url(image_url: str) -> dict | None:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
@@ -408,28 +478,68 @@ class ProviderAnthropic(Provider):
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
return None
# Get mime type for the image
mime_type, _ = guess_type(image_url)
if not mime_type:
mime_type = "image/jpeg" # Default to JPEG if can't determine
content.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": (
image_data.split("base64,")[1]
if "base64," in image_data
else image_data
),
},
return {
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": (
image_data.split("base64,")[1]
if "base64," in image_data
else image_data
),
},
)
}
content = []
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
if text:
content.append({"type": "text", "text": text})
elif image_urls:
# 如果没有文本但有图片,添加占位文本
content.append({"type": "text", "text": "[图片]"})
elif extra_user_content_parts:
# 如果只有额外内容块,也需要添加占位文本
content.append({"type": "text", "text": " "})
# 2. 额外的内容块(系统提醒、指令等)
if extra_user_content_parts:
for block in extra_user_content_parts:
if isinstance(block, TextPart):
content.append({"type": "text", "text": block.text})
elif isinstance(block, ImageURLPart):
image_dict = await resolve_image_url(block.image_url.url)
if image_dict:
content.append(image_dict)
else:
raise ValueError(f"不支持的额外内容块类型: {type(block)}")
# 3. 图片内容
if image_urls:
for image_url in image_urls:
image_dict = await resolve_image_url(image_url)
if image_dict:
content.append(image_dict)
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
if (
text
and not extra_user_content_parts
and not image_urls
and len(content) == 1
and content[0]["type"] == "text"
):
return {"role": "user", "content": content[0]["text"]}
# 否则返回多模态格式
return {"role": "user", "content": content}
async def encode_image_bs64(self, image_url: str) -> str:
@@ -56,10 +56,14 @@ class ProviderFishAudioTTSAPI(TTSProvider):
"api_base",
"https://api.fish-audio.cn/v1",
)
try:
self.timeout: int = int(provider_config.get("timeout", 20))
except ValueError:
self.timeout = 20
self.headers = {
"Authorization": f"Bearer {self.chosen_api_key}",
}
self.set_model(provider_config["model"])
self.set_model(provider_config.get("model", None))
async def _get_reference_id_by_character(self, character: str) -> str | None:
"""获取角色的reference_id
@@ -135,17 +139,21 @@ class ProviderFishAudioTTSAPI(TTSProvider):
path = os.path.join(temp_dir, f"fishaudio_tts_api_{uuid.uuid4()}.wav")
self.headers["content-type"] = "application/msgpack"
request = await self._generate_request(text)
async with AsyncClient(base_url=self.api_base).stream(
async with AsyncClient(base_url=self.api_base, timeout=self.timeout).stream(
"POST",
"/tts",
headers=self.headers,
content=ormsgpack.packb(request, option=ormsgpack.OPT_SERIALIZE_PYDANTIC),
) as response:
if response.headers["content-type"] == "audio/wav":
if response.status_code == 200 and response.headers.get(
"content-type", ""
).startswith("audio/"):
with open(path, "wb") as f:
async for chunk in response.aiter_bytes():
f.write(chunk)
return path
body = await response.aread()
text = body.decode("utf-8", errors="replace")
raise Exception(f"Fish Audio API请求失败: {text}")
error_bytes = await response.aread()
error_text = error_bytes.decode("utf-8", errors="replace")[:1024]
raise Exception(
f"Fish Audio API请求失败: 状态码 {response.status_code}, 响应内容: {error_text}"
)
+112 -29
View File
@@ -13,6 +13,7 @@ from google.genai.errors import APIError
import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.api.provider import Provider
from astrbot.core.agent.message import ContentPart, ImageURLPart, TextPart
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse, TokenUsage
from astrbot.core.provider.func_tool_manager import ToolSet
@@ -320,9 +321,37 @@ class ProviderGoogleGenAI(Provider):
append_or_extend(gemini_contents, parts, types.UserContent)
elif role == "assistant":
if content:
if isinstance(content, str):
parts = [types.Part.from_text(text=content)]
append_or_extend(gemini_contents, parts, types.ModelContent)
elif isinstance(content, list):
parts = []
thinking_signature = None
text = ""
for part in content:
# for most cases, assistant content only contains two parts: think and text
if part.get("type") == "think":
thinking_signature = part.get("encrypted") or None
else:
text += str(part.get("text"))
if thinking_signature and isinstance(thinking_signature, str):
try:
thinking_signature = base64.b64decode(thinking_signature)
except Exception as e:
logger.warning(
f"Failed to decode google gemini thinking signature: {e}",
exc_info=True,
)
thinking_signature = None
parts.append(
types.Part(
text=text,
thought_signature=thinking_signature,
)
)
append_or_extend(gemini_contents, parts, types.ModelContent)
elif not native_tool_enabled and "tool_calls" in message:
parts = []
for tool in message["tool_calls"]:
@@ -440,7 +469,8 @@ class ProviderGoogleGenAI(Provider):
for part in result_parts:
if part.text:
chain.append(Comp.Plain(part.text))
elif (
if (
part.function_call
and part.function_call.name is not None
and part.function_call.args is not None
@@ -457,13 +487,18 @@ class ProviderGoogleGenAI(Provider):
llm_response.tools_call_extra_content[tool_call_id] = {
"google": {"thought_signature": ts_bs64}
}
elif (
if (
part.inline_data
and part.inline_data.mime_type
and part.inline_data.mime_type.startswith("image/")
and part.inline_data.data
):
chain.append(Comp.Image.fromBytes(part.inline_data.data))
if ts := part.thought_signature:
# only keep the last thinking signature
llm_response.reasoning_signature = base64.b64encode(ts).decode("utf-8")
return MessageChain(chain=chain)
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
@@ -680,13 +715,16 @@ class ProviderGoogleGenAI(Provider):
system_prompt=None,
tool_calls_result=None,
model=None,
extra_user_content_parts=None,
**kwargs,
) -> LLMResponse:
if contexts is None:
contexts = []
new_record = None
if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls)
new_record = await self.assemble_context(
prompt, image_urls, extra_user_content_parts
)
context_query = self._ensure_message_to_dicts(contexts)
if new_record:
context_query.append(new_record)
@@ -732,13 +770,16 @@ class ProviderGoogleGenAI(Provider):
system_prompt=None,
tool_calls_result=None,
model=None,
extra_user_content_parts=None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
if contexts is None:
contexts = []
new_record = None
if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls)
new_record = await self.assemble_context(
prompt, image_urls, extra_user_content_parts
)
context_query = self._ensure_message_to_dicts(contexts)
if new_record:
context_query.append(new_record)
@@ -797,33 +838,75 @@ class ProviderGoogleGenAI(Provider):
self.chosen_api_key = key
self._init_client()
async def assemble_context(self, text: str, image_urls: list[str] | None = None):
async def assemble_context(
self,
text: str,
image_urls: list[str] | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
):
"""组装上下文。"""
if image_urls:
user_content = {
"role": "user",
"content": [{"type": "text", "text": text if text else "[图片]"}],
async def resolve_image_part(image_url: str) -> dict | None:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_path)
else:
image_data = await self.encode_image_bs64(image_url)
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
return None
return {
"type": "image_url",
"image_url": {"url": image_data},
}
for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_path)
# 构建内容块列表
content_blocks = []
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
if text:
content_blocks.append({"type": "text", "text": text})
elif image_urls:
# 如果没有文本但有图片,添加占位文本
content_blocks.append({"type": "text", "text": "[图片]"})
elif extra_user_content_parts:
# 如果只有额外内容块,也需要添加占位文本
content_blocks.append({"type": "text", "text": " "})
# 2. 额外的内容块(系统提醒、指令等)
if extra_user_content_parts:
for part in extra_user_content_parts:
if isinstance(part, TextPart):
content_blocks.append({"type": "text", "text": part.text})
elif isinstance(part, ImageURLPart):
image_part = await resolve_image_part(part.image_url.url)
if image_part:
content_blocks.append(image_part)
else:
image_data = await self.encode_image_bs64(image_url)
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
user_content["content"].append(
{
"type": "image_url",
"image_url": {"url": image_data},
},
)
return user_content
return {"role": "user", "content": text}
raise ValueError(f"不支持的额外内容块类型: {type(part)}")
# 3. 图片内容
if image_urls:
for image_url in image_urls:
image_part = await resolve_image_part(image_url)
if image_part:
content_blocks.append(image_part)
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
if (
text
and not extra_user_content_parts
and not image_urls
and len(content_blocks) == 1
and content_blocks[0]["type"] == "text"
):
return {"role": "user", "content": content_blocks[0]["text"]}
# 否则返回多模态格式
return {"role": "user", "content": content_blocks}
async def encode_image_bs64(self, image_url: str) -> str:
"""将图片转换为 base64"""
@@ -51,7 +51,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
"voice_id": ""
if self.is_timber_weight
else provider_config.get("minimax-voice-id", ""),
"emotion": provider_config.get("minimax-voice-emotion", "neutral"),
"emotion": provider_config.get("minimax-voice-emotion", "auto"),
"latex_read": provider_config.get("minimax-voice-latex", False),
"english_normalization": provider_config.get(
"minimax-voice-english-normalization",
@@ -59,6 +59,9 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
),
}
if self.voice_setting["emotion"] == "auto":
self.voice_setting.pop("emotion", None)
self.audio_setting: dict = {
"sample_rate": 32000,
"bitrate": 128000,
+94 -56
View File
@@ -17,7 +17,7 @@ from openai.types.completion_usage import CompletionUsage
import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.api.provider import Provider
from astrbot.core.agent.message import Message
from astrbot.core.agent.message import ContentPart, ImageURLPart, Message, TextPart
from astrbot.core.agent.tool import ToolSet
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult
@@ -74,28 +74,6 @@ class ProviderOpenAIOfficial(Provider):
self.reasoning_key = "reasoning_content"
def _maybe_inject_xai_search(self, payloads: dict, **kwargs):
"""当开启 xAI 原生搜索时,向请求体注入 Live Search 参数。
- 仅在 provider_config.xai_native_search True 时生效
- 默认注入 {"mode": "auto"}
- 允许通过 kwargs 使用 xai_search_mode 覆盖on/auto/off
"""
if not bool(self.provider_config.get("xai_native_search", False)):
return
mode = kwargs.get("xai_search_mode", "auto")
mode = str(mode).lower()
if mode not in ("auto", "on", "off"):
mode = "auto"
# off 时不注入,保持与未开启一致
if mode == "off":
return
# OpenAI SDK 不识别的字段会在 _query/_query_stream 中放入 extra_body
payloads["search_parameters"] = {"mode": mode}
async def get_models(self):
try:
models_str = []
@@ -134,10 +112,6 @@ class ProviderOpenAIOfficial(Provider):
model = payloads.get("model", "").lower()
# 针对 deepseek 模型的特殊处理:deepseek-reasoner调用必须移除 tools ,否则将被切换至 deepseek-chat
if model == "deepseek-reasoner" and "tools" in payloads:
del payloads["tools"]
completion = await self.client.chat.completions.create(
**payloads,
stream=False,
@@ -251,10 +225,14 @@ class ProviderOpenAIOfficial(Provider):
def _extract_usage(self, usage: CompletionUsage) -> TokenUsage:
ptd = usage.prompt_tokens_details
cached = ptd.cached_tokens if ptd and ptd.cached_tokens else 0
prompt_tokens = 0 if usage.prompt_tokens is None else usage.prompt_tokens
completion_tokens = (
0 if usage.completion_tokens is None else usage.completion_tokens
)
return TokenUsage(
input_other=usage.prompt_tokens - cached,
input_cached=ptd.cached_tokens if ptd and ptd.cached_tokens else 0,
output=usage.completion_tokens,
input_other=prompt_tokens - cached,
input_cached=cached,
output=completion_tokens,
)
async def _parse_openai_completion(
@@ -348,6 +326,7 @@ class ProviderOpenAIOfficial(Provider):
system_prompt: str | None = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
**kwargs,
) -> tuple:
"""准备聊天所需的有效载荷和上下文"""
@@ -355,7 +334,9 @@ class ProviderOpenAIOfficial(Provider):
contexts = []
new_record = None
if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls)
new_record = await self.assemble_context(
prompt, image_urls, extra_user_content_parts
)
context_query = self._ensure_message_to_dicts(contexts)
if new_record:
context_query.append(new_record)
@@ -378,11 +359,28 @@ class ProviderOpenAIOfficial(Provider):
payloads = {"messages": context_query, "model": model}
# xAI origin search tool inject
self._maybe_inject_xai_search(payloads, **kwargs)
self._finally_convert_payload(payloads)
return payloads, context_query
def _finally_convert_payload(self, payloads: dict):
"""Finally convert the payload. Such as think part conversion, tool inject."""
for message in payloads.get("messages", []):
if message.get("role") == "assistant" and isinstance(
message.get("content"), list
):
reasoning_content = ""
new_content = [] # not including think part
for part in message["content"]:
if part.get("type") == "think":
reasoning_content += str(part.get("think"))
else:
new_content.append(part)
message["content"] = new_content
# reasoning key is "reasoning_content"
if reasoning_content:
message["reasoning_content"] = reasoning_content
async def _handle_api_error(
self,
e: Exception,
@@ -476,6 +474,7 @@ class ProviderOpenAIOfficial(Provider):
system_prompt=None,
tool_calls_result=None,
model=None,
extra_user_content_parts=None,
**kwargs,
) -> LLMResponse:
payloads, context_query = await self._prepare_chat_payload(
@@ -485,6 +484,7 @@ class ProviderOpenAIOfficial(Provider):
system_prompt,
tool_calls_result,
model=model,
extra_user_content_parts=extra_user_content_parts,
**kwargs,
)
@@ -624,33 +624,71 @@ class ProviderOpenAIOfficial(Provider):
self,
text: str,
image_urls: list[str] | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
) -> dict:
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
if image_urls:
user_content = {
"role": "user",
"content": [{"type": "text", "text": text if text else "[图片]"}],
async def resolve_image_part(image_url: str) -> dict | None:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_path)
else:
image_data = await self.encode_image_bs64(image_url)
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
return None
return {
"type": "image_url",
"image_url": {"url": image_data},
}
for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_path)
# 构建内容块列表
content_blocks = []
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
if text:
content_blocks.append({"type": "text", "text": text})
elif image_urls:
# 如果没有文本但有图片,添加占位文本
content_blocks.append({"type": "text", "text": "[图片]"})
elif extra_user_content_parts:
# 如果只有额外内容块,也需要添加占位文本
content_blocks.append({"type": "text", "text": " "})
# 2. 额外的内容块(系统提醒、指令等)
if extra_user_content_parts:
for part in extra_user_content_parts:
if isinstance(part, TextPart):
content_blocks.append({"type": "text", "text": part.text})
elif isinstance(part, ImageURLPart):
image_part = await resolve_image_part(part.image_url.url)
if image_part:
content_blocks.append(image_part)
else:
image_data = await self.encode_image_bs64(image_url)
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
user_content["content"].append(
{
"type": "image_url",
"image_url": {"url": image_data},
},
)
return user_content
return {"role": "user", "content": text}
raise ValueError(f"不支持的额外内容块类型: {type(part)}")
# 3. 图片内容
if image_urls:
for image_url in image_urls:
image_part = await resolve_image_part(image_url)
if image_part:
content_blocks.append(image_part)
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
if (
text
and not extra_user_content_parts
and not image_urls
and len(content_blocks) == 1
and content_blocks[0]["type"] == "text"
):
return {"role": "user", "content": content_blocks[0]["text"]}
# 否则返回多模态格式
return {"role": "user", "content": content_blocks}
async def encode_image_bs64(self, image_url: str) -> str:
"""将图片转换为 base64"""
@@ -0,0 +1,29 @@
from ..register import register_provider_adapter
from .openai_source import ProviderOpenAIOfficial
@register_provider_adapter(
"xai_chat_completion", "xAI Chat Completion Provider Adapter"
)
class ProviderXAI(ProviderOpenAIOfficial):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
def _maybe_inject_xai_search(self, payloads: dict):
"""当开启 xAI 原生搜索时,向请求体注入 Live Search 参数。
- 仅在 provider_config.xai_native_search True 时生效
- 默认注入 {"mode": "auto"}
"""
if not bool(self.provider_config.get("xai_native_search", False)):
return
# OpenAI SDK 不识别的字段会在 _query/_query_stream 中放入 extra_body
payloads["search_parameters"] = {"mode": "auto"}
def _finally_convert_payload(self, payloads: dict):
self._maybe_inject_xai_search(payloads)
super()._finally_convert_payload(payloads)
@@ -8,7 +8,10 @@ from xinference_client.client.restful.async_restful_client import (
from astrbot.core import logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
from astrbot.core.utils.tencent_record_helper import (
convert_to_pcm_wav,
tencent_silk_to_wav,
)
from ..entities import ProviderType
from ..provider import STTProvider
@@ -111,17 +114,22 @@ class ProviderXinferenceSTT(STTProvider):
return ""
# 2. Check for conversion
needs_conversion = False
if (
audio_url.endswith((".amr", ".silk"))
or is_tencent
or b"SILK" in audio_bytes[:8]
):
needs_conversion = True
conversion_type = None
if b"SILK" in audio_bytes[:8]:
conversion_type = "silk"
elif b"#!AMR" in audio_bytes[:6]:
conversion_type = "amr"
elif audio_url.endswith(".silk") or is_tencent:
conversion_type = "silk"
elif audio_url.endswith(".amr"):
conversion_type = "amr"
# 3. Perform conversion if needed
if needs_conversion:
logger.info("Audio requires conversion, using temporary files...")
if conversion_type:
logger.info(
f"Audio requires conversion ({conversion_type}), using temporary files..."
)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(temp_dir, exist_ok=True)
@@ -132,8 +140,12 @@ class ProviderXinferenceSTT(STTProvider):
with open(input_path, "wb") as f:
f.write(audio_bytes)
logger.info("Converting silk/amr file to wav ...")
await tencent_silk_to_wav(input_path, output_path)
if conversion_type == "silk":
logger.info("Converting silk to wav ...")
await tencent_silk_to_wav(input_path, output_path)
elif conversion_type == "amr":
logger.info("Converting amr to wav ...")
await convert_to_pcm_wav(input_path, output_path)
with open(output_path, "rb") as f:
audio_bytes = f.read()
+15 -2
View File
@@ -149,9 +149,12 @@ class Context:
contexts: context messages for the LLM
max_steps: Maximum number of tool calls before stopping the loop
**kwargs: Additional keyword arguments. The kwargs will not be passed to the LLM directly for now, but can include:
stream: bool - whether to stream the LLM response
agent_hooks: BaseAgentRunHooks[AstrAgentContext] - hooks to run during agent execution
agent_context: AstrAgentContext - context to use for the agent
other kwargs will be DIRECTLY passed to the runner.reset() method
Returns:
The final LLMResponse after tool calls are completed.
@@ -194,6 +197,15 @@ class Context:
)
agent_runner = ToolLoopAgentRunner()
tool_executor = FunctionToolExecutor()
streaming = kwargs.get("stream", False)
other_kwargs = {
k: v
for k, v in kwargs.items()
if k not in ["stream", "agent_hooks", "agent_context"]
}
await agent_runner.reset(
provider=prov,
request=request,
@@ -203,7 +215,8 @@ class Context:
),
tool_executor=tool_executor,
agent_hooks=agent_hooks,
streaming=kwargs.get("stream", False),
streaming=streaming,
**other_kwargs,
)
async for _ in agent_runner.step_until_done(max_steps):
pass
@@ -377,7 +390,7 @@ class Context:
if not module_path:
_parts = []
module_part = tool.__module__.split(".")
flags = ["packages", "plugins"]
flags = ["builtin_stars", "plugins"]
for i, part in enumerate(module_part):
_parts.append(part)
if part in flags and i + 1 < len(module_part):
@@ -12,7 +12,6 @@ class PlatformAdapterType(enum.Flag):
TELEGRAM = enum.auto()
WECOM = enum.auto()
LARK = enum.auto()
WECHATPADPRO = enum.auto()
DINGTALK = enum.auto()
DISCORD = enum.auto()
SLACK = enum.auto()
@@ -27,7 +26,6 @@ class PlatformAdapterType(enum.Flag):
| TELEGRAM
| WECOM
| LARK
| WECHATPADPRO
| DINGTALK
| DISCORD
| SLACK
@@ -49,7 +47,6 @@ ADAPTER_NAME_2_TYPE = {
"discord": PlatformAdapterType.DISCORD,
"slack": PlatformAdapterType.SLACK,
"kook": PlatformAdapterType.KOOK,
"wechatpadpro": PlatformAdapterType.WECHATPADPRO,
"vocechat": PlatformAdapterType.VOCECHAT,
"weixin_official_account": PlatformAdapterType.WEIXIN_OFFICIAL_ACCOUNT,
"satori": PlatformAdapterType.SATORI,

Some files were not shown because too many files have changed in this diff Show More