Compare commits

...

135 Commits

Author SHA1 Message Date
Soulter a21bb5b234 chore: bump version to 4.20.0 2026-03-13 00:33:36 +08:00
Soulter 994d39241e chore: ruff format 2026-03-13 00:26:40 +08:00
2ndelement e6c1164755 perf(QQ Official API): improve streaming message delivery reliability and proactive media sending (#6131)
* fix(qqofficial): fix streaming message delivery for C2C

* fix(qqofficial): rewrite send_streaming for C2C vs non-C2C split

* fix(qqofficial): add _extract_response_message_id for safe id extraction

* fix(qqofficial): flush stream segment on tool-call break signal

* fix(qqofficial): downgrade rich-media to non-stream send in C2C

* fix(qqofficial): auto-append \n to final stream chunk (state=10)

* fix(qqofficial): propagate stream param to all _send_with_markdown_fallback call sites

* fix(qqofficial): retry on STREAM_MARKDOWN_NEWLINE_ERROR with newline fix

* fix(qqofficial): handle None/non-dict response in post_c2c_message gracefully

* fix(qqofficial): remove msg_id from video/file media payloads in send_by_session

QQ API rejects msg_id on proactive media (video/file, msg_type=7) messages
sent via the tool-call path, returning "请求参数msg_id无效或越权". The
msg_id passive-reply credential is consumed by the first send and cannot be
reused for subsequent media uploads in the same session.

Remove msg_id from the payload after setting msg_type=7 for video and file
sends, for both FRIEND_MESSAGE (C2C) and GROUP_MESSAGE paths.

* fix(qqofficial): replace deprecated get_event_loop() with get_running_loop()

asyncio.get_event_loop() is deprecated since Python 3.10 and raises a
DeprecationWarning (or errors) when called from inside a running coroutine
without a current event loop set on the thread.  Replace both call-sites
in the streaming throttle logic with asyncio.get_running_loop(), which is
the correct API to use inside an already-running async context.

Co-Authored-By: Claude Sonnet <noreply@anthropic.com>

---------

Co-authored-by: 2ndelement <2ndelement@users.noreply.github.com>
Co-authored-by: Claude Sonnet <noreply@anthropic.com>
2026-03-13 00:24:15 +08:00
Aleksandr 89cc8a1a65 feat: add Russian translation (#6081)
* feat: add Russian translation

* revert: remove auth route changes from PR
2026-03-13 00:08:37 +08:00
Stable Genius c0e4f1e114 fix(dashboard): restore README dialog anchor navigation (#6083)
Co-authored-by: stablegenius49 <185121704+stablegenius49@users.noreply.github.com>
2026-03-13 00:02:45 +08:00
Stable Genius 7b43448ce4 fix: prefer named weekday cron examples (#6091)
Co-authored-by: stablegenius49 <185121704+stablegenius49@users.noreply.github.com>
2026-03-12 23:57:45 +08:00
orbisai0security bdac0b65f4 fix: resolve critical vulnerability V-004 (#6093)
Automatically generated security fix

Co-authored-by: orbisai0security <orbisai0security@users.noreply.github.com>
2026-03-12 23:53:47 +08:00
Gao Jinzhe cf9ee6f20c Merge pull request #6135 from advent259141/feat/add-community-links
docs: 添加 Astrbook 和玖帕喵社区链接
2026-03-12 23:11:19 +08:00
advent259141 01eae72a64 docs: 添加 Astrbook 和玖帕喵社区链接 2026-03-12 23:05:00 +08:00
letr bca1476eab fix(extension): refresh plugin market install state after install (#6124)
* fix(extension): refresh market install state after plugin install

* chore: remove redundant call

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-03-12 20:19:00 +08:00
エイカク fbcbde0a4b chore: update dependency and workflow versions (#6119) 2026-03-12 20:18:23 +09:00
エイカク 3914d766db fix: install only missing plugin dependencies (#6088)
* chore: ignore local worktrees

* fix: install only missing plugin dependencies

* fix: harden missing dependency install fallback

* fix: clarify dependency install fallback logging

* refactor: simplify dependency install test helpers

* refactor: reuse requirements precheck planning
2026-03-12 11:50:29 +09:00
DOHEX 3e2cb6a2ab fix(telegram): remove deprecated normalize_whitespace param from (#6044)
telegramify_markdown.markdownify calls
2026-03-12 00:34:07 +08:00
莫思潋 25830524f3 fix(docs): typo in docker.md & napcat.md (#6048)
* Fix wording in admin ID configuration instructions

* Update docker.md
2026-03-12 00:30:31 +08:00
Soulter 304094630c perf: optimize booter selection for edge cases and message sending tool (#6064)
* feat: add video message support and enhance message type descriptions in SendMessageToUserTool

* feat: add error handling for disabled sandbox runtime in get_booter function
2026-03-12 00:29:52 +08:00
Soulter 5c3643c54c feat: added support for file, voice, and video messages for QQ Official Bot (including WebSocket mode). (#6063) 2026-03-12 00:26:08 +08:00
エイカク 589cce18af fix: improve Windows local skill file reading (#6028)
* chore: ignore local worktrees

* fix: improve Windows local skill file reading

* fix: address Windows path and decoding review feedback

* fix: simplify shell decoding follow-up

* fix: harden sandbox skill prompt metadata

* fix: preserve safe sandbox skill summaries

* fix: relax sandbox summary sanitization

* fix: tighten path sanitization for skill prompts

* fix: harden sandbox skill display metadata

* fix: preserve Unicode skill paths in prompts

* fix: quote Windows skill prompt paths

* fix: simplify local shell output decoding

* fix: localize Windows prompt path handling

* fix: normalize Windows-style skill paths in prompts

* fix: align prompt and shell decoding behavior
2026-03-11 23:58:28 +09:00
Soulter e254caf82d fix(docs): add official developer group ID to multiple language READMEs and enhance regex description in config metadata 2026-03-11 21:26:11 +08:00
Soulter 7efcd242d6 fix(docs): update edit link patterns and remove obsolete repository reference 2026-03-11 17:42:42 +08:00
JIANG Zijun 5d811d3949 fix: Persist Discord pre-ack emoji config across restart by adding missing default key (#6031)
* Initial plan

* fix: add discord default platform_specific pre-ack config

Co-authored-by: Jzjerry <20167827+Jzjerry@users.noreply.github.com>

* Delete tests/unit/test_config.py

we don't need to add tests

* fix: use 🤔 as default discord pre-ack emoji

Co-authored-by: Jzjerry <20167827+Jzjerry@users.noreply.github.com>

* add back old test config

* doc: discord pre-ack-emoji doc

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Jzjerry <20167827+Jzjerry@users.noreply.github.com>
2026-03-11 16:41:08 +08:00
Flartiny 8e6aaee10c fix(webui): unify search input clear behavior (#6017)
* fix(webui): unify search input clear behavior

* fix: centralize search input normalization
2026-03-11 15:14:16 +08:00
エイカク 6da59cfb07 fix: 插件依赖自动安装逻辑与 Dashboard 安装体验优化 (#5954)
* fix: install plugin requirements before first load

* fix: handle pip option arguments correctly

* fix: harden pip install input parsing

* refactor: simplify pip install input parsing

* fix: align plugin dependency install handling

* fix: respect configured pip index overrides

* test: parameterize plugin dependency install flows

* refactor: simplify multiline pip input parsing

* fix: install plugin dependencies before loading

* fix: protect core dependencies from downgrades and simplify package input splitting

* fix: enhance dependency conflict reporting and improve user-facing warnings

* refactor: preserve pip log indentation and fix CodeQL URL sanitization alert

* fix: explicit re-export for DependencyConflictError to satisfy ruff F401

* test: enhance index override verification in pip installer tests

* fix: correctly map pip ERROR and WARNING outputs to proper log levels

* refactor: show specific version conflicts in DependencyConflictError and revert log level mapping

* refactor: simplify install() by decoupling pip logging, failure classification and constraint file management

* refactor: further simplify pip installer and requirement parsing logic

* refactor: simplify dependency installation logic and improve circular requirement reporting

* style: organize imports in astrbot/core/__init__.py

* refactor: optimize requirement parsing efficiency and flatten pip installer API

* style: fix import sorting in astrbot/core/__init__.py

* refactor: consolidate requirement parsing, optimize core protection, and improve exception propagation

* fix: preserve valid pip requirement parsing

* fix: skip empty pip installs and preserve blank output

* chore: normalize gitignore entry style

* fix: tighten pip trust and requirement parsing

* refactor: centralize pip install parsing and failure handling

* fix: redact pip argv credentials in logs

* fix: surface plugin dependency install errors

* fix: cache core constraints and clarify requirement installs

* fix: harden pip requirement parsing for plugin installs

* fix: simplify pip installer parsing internals

* fix: tighten pip installer parsing and redaction

* refactor: simplify plugin dependency install flow

* fix: preserve core constraint conflict errors

* fix: harden pip installer fallback resolution

* refactor: split pip requirement and constraint helpers

* refactor: simplify pip installer helper flow

* refactor: streamline requirement precheck helpers

* refactor: clarify core constraint resolution

* fix: surface pip install failures explicitly

* refactor: separate pip conflict context parsing

* fix: harden core constraint resolution

* test: cover pip installer failure call sites

* refactor: remove dead requirements fallback helper

* refactor: narrow core constraint error handling

* refactor: unify requirement iteration

* refactor: share requirement name parsing

* test: align pip helper coverage

* fix: bind pip output limit at runtime

* refactor: reuse core requirement parser for tokens
2026-03-11 14:21:55 +09:00
Soulter 10ceacfbb1 chore: bump version to 4.19.5 2026-03-11 00:17:14 +08:00
ChuwuYo 66f5ccd902 fix: add file size validation to TTS provider test and MiniMax empty audio detection (#5999)
- Add audio data validation in MiniMax TTS get_audio() method to detect empty responses
- Validate generated audio file size in TTSProvider.test() to ensure valid output
- Provide detailed error messages guiding users to check group_id configuration
- Auto-cleanup test audio files after validation
- Fixes issue where 0KB audio files would pass TTS detection when group_id is not configured
2026-03-11 00:07:19 +08:00
Soulter 3379587223 feat(mcp): enhance logging and initialize MCP clients in background (#5993)
* feat(mcp): enhance logging and initialize MCP clients in background

fixes: #5777

* rf

* fix(mcp): simplify MCP client initialization in background

* fix(mcp): update error message for MCP background initialization failure
2026-03-11 00:00:48 +08:00
邹永赫 e25a1a42cf Revert "fix: clarify missing MCP stdio command errors (#5992)"
This reverts commit 0c771e4a77.
2026-03-11 00:08:06 +09:00
エイカク 0c771e4a77 fix: clarify missing MCP stdio command errors (#5992)
* fix: clarify missing MCP stdio command errors

* refactor: tighten MCP error presentation helpers

* fix: improve MCP test connection feedback

* fix: structure MCP test connection errors

* refactor: share MCP test error codes
2026-03-10 23:05:50 +09:00
camera-2018 ec21cb13d3 feat(lark): supports CardKit streaming output for feishu (#5777)
* feat(lark): 支持飞书 CardKit 流式输出

* refactor(lark): extract streaming fallback logic and deduplicate final text update

* fix(lark): 修复流式输出竞态条件及增强健壮性

- 修复 sender loop 中 delta 快照竟态: await 期间 delta 被 generator
  更新导致 last_sent 记录了未发送的值, 造成输出卡在最后一段
- send_streaming 入口增加 platform_meta 守卫, 未启用时直接回退
- _fallback_send_streaming 移除对已耗尽 generator 的 super() 调用,
  改为内联父类副作用 (Metric.upload + _has_send_oper)
- Metric.upload 统一改为 await, 确保指标上报在方法返回前完成
- 装饰器 support_streaming_message 改为 False, 与 meta() 动态配置对齐
- i18n hint 补充提示: 需在「AI 配置 → 其他配置」中开启流式输出

* chore(lark): 收口配置

* docs(lark): update streaming output instructions and client version requirements

---------

Co-authored-by: bread-ovo <2570425204@qq.com>
Co-authored-by: Soulter <905617992@qq.com>
2026-03-10 19:40:46 +08:00
Soulter 1d26b96d90 fix(workflow): update build-docs.yml to trigger on version tags instead of master branch 2026-03-10 17:16:56 +08:00
一袋米要扛幾樓 be017c87f4 fix: 前端修正切換到 chat 切換後回 welcome 的配置保存最終切換頁面 (#5792)
* 前端修正切換到chat切換後回 welcome 的配置保存最終切換頁面

* 修復 SSR 不含localStorage 環境驗證
2026-03-10 17:14:28 +08:00
lustresixx 23fffa95c8 fix(provider): support 84-char Azure TTS subscription keys (#5813)
* fix(provider): support 84-char Azure TTS subscription keys

* test(provider): add negative Azure TTS key validation cases

* chore: delete test

---------

Co-authored-by: Soulter <905617992@qq.com>
2026-03-10 17:09:13 +08:00
dependabot[bot] 5b303e2e6d chore(deps): bump the github-actions group with 7 updates (#5966)
Bumps the github-actions group with 7 updates:

| Package | From | To |
| --- | --- | --- |
| [actions/setup-node](https://github.com/actions/setup-node) | `2` | `6` |
| [actions/checkout](https://github.com/actions/checkout) | `4` | `6` |
| [actions/setup-python](https://github.com/actions/setup-python) | `5` | `6` |
| [docker/setup-qemu-action](https://github.com/docker/setup-qemu-action) | `3` | `4` |
| [docker/setup-buildx-action](https://github.com/docker/setup-buildx-action) | `3` | `4` |
| [docker/login-action](https://github.com/docker/login-action) | `3` | `4` |
| [docker/build-push-action](https://github.com/docker/build-push-action) | `6` | `7` |


Updates `actions/setup-node` from 2 to 6
- [Release notes](https://github.com/actions/setup-node/releases)
- [Commits](https://github.com/actions/setup-node/compare/v2...v6)

Updates `actions/checkout` from 4 to 6
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v4...v6)

Updates `actions/setup-python` from 5 to 6
- [Release notes](https://github.com/actions/setup-python/releases)
- [Commits](https://github.com/actions/setup-python/compare/v5...v6)

Updates `docker/setup-qemu-action` from 3 to 4
- [Release notes](https://github.com/docker/setup-qemu-action/releases)
- [Commits](https://github.com/docker/setup-qemu-action/compare/v3...v4)

Updates `docker/setup-buildx-action` from 3 to 4
- [Release notes](https://github.com/docker/setup-buildx-action/releases)
- [Commits](https://github.com/docker/setup-buildx-action/compare/v3...v4)

Updates `docker/login-action` from 3 to 4
- [Release notes](https://github.com/docker/login-action/releases)
- [Commits](https://github.com/docker/login-action/compare/v3...v4)

Updates `docker/build-push-action` from 6 to 7
- [Release notes](https://github.com/docker/build-push-action/releases)
- [Commits](https://github.com/docker/build-push-action/compare/v6...v7)

---
updated-dependencies:
- dependency-name: actions/setup-node
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: actions/checkout
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: actions/setup-python
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: docker/setup-qemu-action
  dependency-version: '4'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: docker/setup-buildx-action
  dependency-version: '4'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: docker/login-action
  dependency-version: '4'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: docker/build-push-action
  dependency-version: '7'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-10 16:56:52 +08:00
Soulter fc33b3eb68 docs: transfer AstrBotDevs/AstrBot-docs to AstrBotDevs/AstrBot (#5960)
* docs: transfer AstrBotDevs/AstrBot-docs to AstrBotDevs/AstrBot
* refactor: reorder imports and improve type hints in sync_docs_to_wiki.py and upload_doc_images_to_r2.py
* feat: add GitHub Actions workflow to sync wiki with documentation

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
Co-authored-by: anka-afk <110004162+anka-afk@users.noreply.github.com>
Co-authored-by: zouyonghe <62183434+zouyonghe@users.noreply.github.com>
Co-authored-by: shuiping233 <49360196+shuiping233@users.noreply.github.com>
Co-authored-by: LIghtJUNction <106986785+LIghtJUNction@users.noreply.github.com>
Co-authored-by: Sjshi763 <179909421+Sjshi763@users.noreply.github.com>
Co-authored-by: xiewoc <70128845+xiewoc@users.noreply.github.com>
Co-authored-by: QingFeng-awa <151742581+QingFeng-awa@users.noreply.github.com>
Co-authored-by: PaloMiku <96452465+PaloMiku@users.noreply.github.com>
Co-authored-by: shangxueink <138397030+shangxueink@users.noreply.github.com>
Co-authored-by: IGCrystal-A <244300990+IGCrystal-A@users.noreply.github.com>
Co-authored-by: RC-CHN <67079377+RC-CHN@users.noreply.github.com>
Co-authored-by: MC090610 <113341105+MC090610@users.noreply.github.com>
Co-authored-by: Waterwzy <196913419+Waterwzy@users.noreply.github.com>
Co-authored-by: Lanhuace-Wan <186303160+Lanhuace-Wan@users.noreply.github.com>
Co-authored-by: LiAlH4qwq <61769640+LiAlH4qwq@users.noreply.github.com>
Co-authored-by: HSOS6 <209910899+HSOS6@users.noreply.github.com>
Co-authored-by: th-dd <162813557+th-dd@users.noreply.github.com>
Co-authored-by: miaoxutao123 <81676466+miaoxutao123@users.noreply.github.com>
Co-authored-by: nuomicici <143102889+nuomicici@users.noreply.github.com>
Co-authored-by: nasyt233 <210103278+nasyt233@users.noreply.github.com>
Co-authored-by: jlugjb <7426462+jlugjb@users.noreply.github.com>
Co-authored-by: Raven95676 <176760093+Raven95676@users.noreply.github.com>
Co-authored-by: Futureppo <180109455+Futureppo@users.noreply.github.com>
Co-authored-by: MliKiowa <61873808+MliKiowa@users.noreply.github.com>
Co-authored-by: Fridemn <150212937+Fridemn@users.noreply.github.com>
Co-authored-by: BakaCookie520 <138355736+BakaCookie520@users.noreply.github.com>
Co-authored-by: YumeYuka <125112916+YumeYuka@users.noreply.github.com>
Co-authored-by: xming521 <32786500+xming521@users.noreply.github.com>
Co-authored-by: ywh555hhh <121592812+ywh555hhh@users.noreply.github.com>
Co-authored-by: stevessr <89645372+stevessr@users.noreply.github.com>
Co-authored-by: roeseth <41995115+roeseth@users.noreply.github.com>
Co-authored-by: ikun-1145141 <265925499+ikun-1145141@users.noreply.github.com>
Co-authored-by: evpeople <54983536+evpeople@users.noreply.github.com>
Co-authored-by: Yue-bin <60509781+Yue-bin@users.noreply.github.com>
Co-authored-by: W1ndys <109416673+W1ndys@users.noreply.github.com>
Co-authored-by: TheFurina <218887821+TheFurina@users.noreply.github.com>
Co-authored-by: Seayon <12275933+Seayon@users.noreply.github.com>
Co-authored-by: OnlyblackTea <38585636+OnlyblackTea@users.noreply.github.com>
Co-authored-by: ocetars <74854972+ocetars@users.noreply.github.com>
Co-authored-by: railgun19457 <117180744+railgun19457@users.noreply.github.com>
Co-authored-by: JunieXD <107397009+JunieXD@users.noreply.github.com>
Co-authored-by: advent259141 <197440256+advent259141@users.noreply.github.com>
Co-authored-by: Doge2077 <91442300+Doge2077@users.noreply.github.com>
Co-authored-by: Bocity <23430545+Bocity@users.noreply.github.com>
Co-authored-by: Aurora-xk <192227833+Aurora-xk@users.noreply.github.com>
2026-03-09 23:38:21 +08:00
ChuwuYo 795aec9578 feat(extension): add filtering and sorting for installed plugins in WebUI (#5923)
* feat(extension): add PluginSortControl reusable component for sorting

* i18n: add i18n keys for plugin sorting and filtering features

* feat(extension): add sorting and status filtering for installed plugins

Backend changes (plugin.py):
- Add _resolve_plugin_dir method to resolve plugin directory path
- Add _get_plugin_installed_at method to get installation time from file mtime
- Add installed_at field to plugin API response

Frontend changes (InstalledPluginsTab.vue):
- Import PluginSortControl component
- Add status filter toggle (all/enabled/disabled) using v-btn-toggle
- Integrate PluginSortControl for sorting options
- Add toolbar layout with actions and controls sections

Frontend changes (MarketPluginsTab.vue):
- Import PluginSortControl component
- Replace v-select + v-btn combination with unified PluginSortControl

Frontend changes (useExtensionPage.js):
- Add installedStatusFilter, installedSortBy, installedSortOrder refs
- Add installedSortItems and installedSortUsesOrder computed properties
- Add sortInstalledPlugins function with multi-criteria support
- Support sorting by install time, name, author, and update status
- Add status filtering in filteredPlugins computed property
- Disable default table sorting by setting sortable: false

* test: add tests for installed_at field in plugin API

- Assert all plugins have installed_at field in get_plugins response
- Assert installed_at is not null after plugin installation

* fix(extension): add explicit fallbacks for installed plugin sort comparisons

* i18n(extension): rename install time label to last modified

* fix(extension): cache installed_at parsing and validate timestamp format in tests

* test(dashboard): strengthen installed_at coverage for plugin API
2026-03-09 17:12:22 +09:00
Soulter 7d31140c14 chore: bump version to 4.19.4 2026-03-09 11:13:39 +08:00
Soulter 654112ca86 feat(wecomai): implement long connection mode and update configuration options (#5930) 2026-03-09 11:10:32 +08:00
Soulter 5dd30f9a45 chore: bump version to 4.19.3 2026-03-09 00:20:33 +08:00
Jason a53a1ca49b fix(provider): handle MiniMax ThinkingBlock when max_tokens reached (#5913)
* fix(provider): handle MiniMax ThinkingBlock when max_tokens reached

Fixes #5912

Problem: MiniMax API returns ThinkingBlock when stop_reason='max_tokens',
but AstrBot throws 'completion 无法解析' exception because both
completion_text and tools_call_args are empty.

Root cause: The validation logic didn't consider ThinkingBlock
(reasoning_content) as valid content.

Fix: When completion_text and tools_call_args are empty but
reasoning_content is present, treat it as valid instead of throwing
exception. This happens when the model thinks but runs out of tokens
before generating the actual response.

Impact: MiniMax models now work correctly when responses are truncated
due to max_tokens limit.

* refactor: address review feedback

1. Use getattr for safe stop_reason access (prevent AttributeError)
2. Use ValueError instead of generic Exception for better error handling

Thanks @gemini-code-assist and @sourcery-ai for the review!

* refactor: flatten nested if/else with guard clause

Address Gemini Code Assist feedback:
- Use guard clause for early return
- Flattened nested conditional for better readability

Logic unchanged, just cleaner code structure.

* fix(provider): improve logging for ThinkingBlock completions in ProviderAnthropic

---------

Co-authored-by: Soulter <905617992@qq.com>
2026-03-09 00:17:11 +08:00
whatevertogo 3fd6c4c8a6 fix: 修复 asyncio 事件循环相关问题 (#5774)
* fix: 修复 asyncio 事件循环相关的问题

1. components.py: 修复异常处理结构错误
   - 将 except Exception 移到正确的内部 try 块
   - 确保 _download_file() 异常能被正确捕获和记录

2. session_lock.py: 修复跨事件循环 Lock 绑定问题
   - 添加 _access_lock_loop_id 追踪事件循环
   - 当事件循环变化时重新创建 Lock

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: 根据代码审查反馈修复问题

1. components.py: 移除 asyncio.set_event_loop() 调用
   - 创建临时 event loop 时不再设置为全局
   - 避免干扰其他 asyncio 使用

2. session_lock.py: 简化延迟初始化逻辑
   - 移除 loop-ID 追踪和 _get_lock 方法
   - 使用 setdefault 简化 session lock 创建
   - 保留延迟初始化行为

3. wecomai_queue_mgr.py: 使用 time.monotonic() 替代 loop.time()
   - 同步方法不再依赖活动的 event loop
   - 避免在非异步上下文中抛出 RuntimeError

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: 优化 asyncio 事件循环管理,使用安全的方式创建和关闭事件循环

* fix: 根据代码审查反馈改进异常处理和事件循环使用

- main.py: 显式处理 check_dashboard_files() 返回 None 的情况
- components.py: 使用 logger.exception 保留异常堆栈信息
- star_manager.py: 添加 Future 异常回调处理 __del__ 执行异常
- bay_manager.py: 缓存事件循环引用避免重复调用

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* refactor: 简化 SessionLockManager 使用 defaultdict 和 setdefault

- 使用 defaultdict(asyncio.Lock) 简化锁的懒创建
- 使用 setdefault 简化 _get_loop_state 逻辑
- 减少 get + if 分支,提升可读性

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: 降低 webui_dir 检查失败时的日志级别为 warning

改为警告而非退出,允许程序在无 WebUI 的情况下继续运行

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* refactor: 重构事件循环锁管理,简化锁状态管理逻辑

* 新增对 SessionLockManager 的多事件循环隔离测试

* fix: 修复测试中的变量声明和断言,确保事件循环管理器的正确性

* fix: 修复插件删除时异常处理逻辑,确保正确记录错误信息

* fix: 新增针对多个事件循环的 OneBot 实例的测试,确保锁对象在不同事件循环间不共享

---------

Co-authored-by: whatevertogo <whatevertogo@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 01:00:13 +09:00
sanyekana 5808784f07 fix: prevent crash on malformed MCP server config (#5666) (#5673)
* fix: prevent crash on malformed MCP server config (#5666)

* fix: prevent crash on malformed MCP server config (#5666)

* fix: validate MCP connection before persisting server config

* fix: guard mcpServers type before iterating server list

* refactor: use typed empty-config error and extract MCP rollback helper

* fix: translate error messages and comments to English for consistency

---------

Co-authored-by: Soulter <905617992@qq.com>
2026-03-08 23:46:32 +08:00
Soulter 537849c1e7 fix(dingtalk): text is ignored; cannot send file actively (#5921) 2026-03-08 23:31:11 +08:00
Soulter 7f3c0fdeb2 fix: cannot receive image, file in dingtalk (#5920)
fixes: #5916 #5786
2026-03-08 23:18:56 +08:00
Windy_cold 8e431e2076 correct openrouter api_base (#5911) 2026-03-08 21:53:56 +09:00
ChuwuYo 89c11fd683 fix(extension): support searching installed plugins by display name (#5806) (#5811)
* fix(extension): support searching installed plugins by display name

* fix: unify plugin search matching across installed and market tabs

* refactor(extension): optimize plugin search matcher and remove redundant checks

* refactor(extension-page): centralize search query normalization and text matching logic

- Extract `buildSearchQuery` to create normalized query objects from raw input
- Extract `matchesText` as a reusable text matching helper for normalized/loose/pinyin/initials matching
- Remove unused `marketCustomFilter` to eliminate dead code
- Simplify `matchesPluginSearch` to accept query object instead of pre-normalized string
- Replace Set with Array for candidates to simplify control flow
- Avoid redundant normalization by having callers pass raw strings to `buildSearchQuery`

* refactor: remove unused marketCustomFilter from extension page components

- Remove marketCustomFilter from destructuring in ExtensionPage.vue, InstalledPluginsTab.vue, and MarketPluginsTab.vue

* refactor(extension): extract plugin search utilities into shared module

- Create pluginSearch.js to centralize plugin search helpers
- Move `normalizeStr`, `normalizeLoose`, `toPinyinText`, and `toInitials` into the shared module
- Add `buildSearchQuery`, `matchesText`, and `matchesPluginSearch` for reusable search matching
- Refactor useExtensionPage.js to consume the shared utilities
- Simplify plugin search logic by consolidating normalization and matching in one place

* refactor(extension): add caching to pinyin utilities and extract search fields helper

- Add Map-based caching for `toPinyinText` and `toInitials` to avoid redundant pinyin computation
- Extract `getPluginSearchFields` function to retrieve plugin fields for searching
- Improve plugin search performance with caching and better code organization

* perf(extension): add bounded caching for plugin search

- cap normalization and pinyin caches with `MAX_SEARCH_CACHE_SIZE`
- add `setCacheValue()` for oldest-entry eviction
- cache normalized and loose text values to avoid repeated string processing
- skip pinyin matching for non-CJK text using Unicode `\p{Unified_Ideograph}` property
- improve search performance while keeping memory usage bounded

* refactor(extension): extract memoizeLRU helper for cache management

- Create `memoizeLRU` higher-order function to generate LRU-cached functions
- Replace manual cache implementation with `memoizeLRU` for cleaner code
- Optimize `matchesText` to lazily compute looseValue only when needed
- Simplify caching logic while maintaining bounded cache size

* refactor(extension): simplify memoization and remove LRU logic

- Rename `memoizeLRU` to `memoizeStringFn` and remove bounded cache size
- Simplify cache hit logic for cleaner code
- Remove `MAX_SEARCH_CACHE_SIZE` constant as it's no longer needed
2026-03-08 17:41:45 +09:00
時壹 7cfe2aca99 fix: apply reply_with_quote and reply_with_mention to image-only response (#5219)
* fix: apply reply_with_quote and reply_with_mention to image-only responses

* fix: restrict reply_with_quote and reply_with_mention to plain-text/image chains
2026-03-08 17:41:12 +09:00
時壹 3a938d2a13 fix: use re.search instead of re.match in RegexFilter (#5368) 2026-03-08 17:40:40 +09:00
whatevertogo 812834bc9f feat(skills): add batch upload functionality for multiple skill ZIP files (#5804)
* feat(skills): add batch upload functionality for multiple skill ZIP files

- Implemented a new endpoint for batch uploading skills.
- Enhanced the SkillsSection component to support multiple file selection and drag-and-drop functionality.
- Updated localization files for new upload features and messages.
- Added tests to validate batch upload behavior and error handling.

* feat(skills): improve batch upload handling and enhance accessibility for dropzone

* feat(skills): enhance batch upload process and improve UI for better user experience

* feat(skills): enhance skills upload dialog layout and styling for improved usability

* feat(skills): update upload dialog description styling for better visibility and usability

* feat(skills): improve upload dialog button styling and layout for enhanced usability

* feat(skills): refine upload dialog text for clarity and consistency

* feat(skills): enhance batch upload functionality by ignoring __MACOSX entries and improving upload dialog styling

* feat(skills): refactor upload dialog and button styles for improved consistency and usability

---------

Co-authored-by: whatevertogo <whatevertogo@users.noreply.github.com>
2026-03-07 23:18:01 +08:00
エイカク 51ff4f6e46 fix: detect desktop runtime without frozen python (#5859)
* fix: detect desktop runtime without frozen python

* chore: drop planning docs from runtime fix pr
2026-03-07 21:42:56 +09:00
Soulter 7ac169c5e8 docs: add macOS usage note and update instructions for astrbot in multiple README files 2026-03-06 14:34:29 +08:00
Soulter 61648ebe3e docs: add new QQ group entries to README files 2026-03-06 11:11:12 +08:00
Soulter 0610f0db0a fix: pipeline scheduler not found after creating platform bot via using 'create new config' (#5776) 2026-03-05 23:53:53 +08:00
whatevertogo 8c935981bb fix: align aiocqhttp poke payload with onebot v11 (#5773)
Co-authored-by: whatevertogo <whatevertogo@users.noreply.github.com>
2026-03-05 23:02:26 +08:00
Ruochen Pan 3f3b4e4924 test(skill_manager): update sandbox cache path expectations (#5706)
* test(skill_manager): update sandbox cache path expectations

adjust sandbox cache tests to match absolute path resolution in
list_skills for sandbox runtime.

verify sandbox-cached skills cannot be deactivated via set_skill_active
by asserting a PermissionError, and keep active-only listing behavior
intact.

add coverage for show_sandbox_path=false to ensure local skills still
override cached metadata while sandbox-only skills retain cached paths.

* test(skill_manager): tighten local skill path assertions
2026-03-05 22:47:20 +08:00
Soulter af581e7f21 chore: bump version to 4.19.2 2026-03-05 16:10:09 +08:00
Soulter 9e371ee10b chore: update shipyard-neo-sdk dependency to version 0.2.0 2026-03-05 16:07:10 +08:00
camera-2018 7cf77adbc8 feat(telegram): supports sendMessageDraft API (#5726)
* feat(telegram): 使用 sendMessageDraft API 实现私聊流式输出

- 新增 _send_message_draft 方法封装 Telegram Bot API sendMessageDraft
- 私聊流式输出使用 sendMessageDraft 推送草稿动画,群聊保留 edit_message_text 回退
- 使用独立异步发送循环 (_draft_sender_loop) 按固定间隔推送最新缓冲区内容,
  完全解耦 token 到达速度与 API 网络延迟
- 流式结束后发送真实消息保留最终内容(draft 是临时的)
- 使用模块级递增 draft_id 替代随机生成,确保 Telegram 端动画连续性

* fix(telegram): convert draft text to Markdown before sending message draft

* chore(telegram): telegram 适配器重构

- 提取公共方法
- 有新 token 到达时触发流式
- 生成结束后清除draft内容
- 默认draft发送md格式

* style(telegram): ruff format

* style(telegram): ruff check

---------

Co-authored-by: Soulter <905617992@qq.com>
2026-03-05 11:20:28 +08:00
Soulter 31673ee521 fix: require node.js env when uv sync 2026-03-05 11:20:16 +08:00
エイカク ff22030dde docs: align deployment sections across multilingual readmes (#5734)
* docs: align deployment sections across multilingual readmes

* docs: normalize deployment punctuation and AUR guidance

* docs: fix french and russian deployment wording
2026-03-05 11:19:27 +08:00
Soulter 101580fd77 chore: add sponsors section to README
Added a sponsors section with an image link.
2026-03-03 19:08:12 +08:00
Soulter 418f05f6e4 fix: tests 2026-03-03 16:06:49 +08:00
Soulter df421e5554 fix: test 2026-03-03 16:04:08 +08:00
shuiping233 ed84074a60 unittest: 添加之前遗漏的kook_card_data.json (#5703) 2026-03-03 16:01:26 +08:00
Soulter bbf61239ad fix(kook): remove debug logging for received messages and heartbeat responses 2026-03-03 15:54:45 +08:00
miaoxutao123 92ee534a2c feat: add OS information to tool descriptions and implement unit tests (#5677)
* feat: add OS information to tool descriptions and implement unit tests

* refactor: use module-level constant for OS name as suggested in PR review
2026-03-03 15:16:38 +08:00
L1ngg fa4df0b5f3 fix(core): correctly parse DEMO_MODE as boolean from env var. (#5676)
* fix(core): correctly parse DEMO_MODE as boolean from env var.

* Update astrbot/core/__init__.py

fix(core): 添加.strip()以确保代码健壮性

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>

---------

Co-authored-by: sourcery-ai[bot] <58596630+sourcery-ai[bot]@users.noreply.github.com>
2026-03-03 15:15:20 +08:00
dependabot[bot] e5ac31efe7 chore(deps): bump the github-actions group with 2 updates (#5694)
Bumps the github-actions group with 2 updates: [actions/upload-artifact](https://github.com/actions/upload-artifact) and [actions/download-artifact](https://github.com/actions/download-artifact).


Updates `actions/upload-artifact` from 6 to 7
- [Release notes](https://github.com/actions/upload-artifact/releases)
- [Commits](https://github.com/actions/upload-artifact/compare/v6...v7)

Updates `actions/download-artifact` from 7 to 8
- [Release notes](https://github.com/actions/download-artifact/releases)
- [Commits](https://github.com/actions/download-artifact/compare/v7...v8)

---
updated-dependencies:
- dependency-name: actions/upload-artifact
  dependency-version: '7'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: actions/download-artifact
  dependency-version: '8'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-03 15:14:28 +08:00
時壹 2a7745c767 fix: only allow HTTPS URLs to pass through directly in LINE adapter (#5697) 2026-03-03 15:14:08 +08:00
Gargantua 82e7502f74 fix(dashboard): stabilize sidebar customization state (#5405) (#5670)
- use stable sidebar list keys to avoid vnode reuse drift

- sanitize persisted opened groups against current sidebar menu

- guard non-array customization keys from localStorage

Co-authored-by: Gargantua <22532097@zju.edu.cn>
2026-03-03 15:12:15 +08:00
shuiping233 866e546b59 feat: integrates KOOK platform adapter (#5658)
* feat: 将kook适配器插件并入astrbot官方适配器目录中

* refactor: 重命名函数名为 _warp_message

* refactor: 使用Protocol替换Union类型

* bugfix: 修复base64前缀处理问题

* refactor: 抛出的错误不再包含"[kook]"

* refactor: 添加读取本地文件时的路径安全检查

* refactor: 卡片消息解析失败时会打印错误信息

* refactor: 添加处理接收卡片消息内的图片url时的安全校验

* refactor: 安全处理ws需要重连的情况

* Revert "refactor: 使用Protocol替换Union类型"

This reverts commit 58e0dceeb20c3d7dddb16f623fd3bbdcfa632173.

* feat: 添加获取机器人名称的实现

* refactor: 让send_by_session发送主动消息时正确传入当前消息链的文本消息内容

* refactor: 统一处理适配器配置相关内容,处理仪表盘出传入配置,并添加仪表盘的kook适配器配置页面的i18n文本

* unittest: 添加kook适配器的单元测试,虽然没覆盖多少单测

* unittest: TEST_DATA_DIR用更安全的路径

* refactor: KookConfig使用了更好的默认值处理方式

* refactor: 移除kook_adapter 的config字段重复定义

* refactor: 隐藏获取kook gateway时url里的token,防止把token打印出来

* refactor: KookConfig.pretty_jsons使用*来屏蔽token内容

* bugfix: 修复主动发送消息时,调用了父方法`send_by_session`可能导致指标被重复上传的bug

* refactor: 优化upload_asset的路径处理报错

* bugfix: 修复kook ws心跳间隔可能会出现负数时间的bug

* refactor: KookClient移到KookPlatformAdapter.__init__里初始化

* bugfix: 修复处理base64 url 多替换了/而报错的bug

* refactor: kook适配器上传文件失败时,会抛出错误

* chore: 移除一条注释

* refactor: 移除没用的return

* refactor: 即使消息链中有消息发送失败了,也尽可能将其他消息发送出去,并把报错信息也发送出去

* refactor: 增强上传任务失败时的错误处理,使其发生错误时尽力而为发送其余消息

* refactor: 发送到消息频道的报错消息加了个⚠️,小巧思这块?

* refactor: 咱们在写适配器啊,要什么小巧思呢,小巧思给上游插件开发弄不好么)

* refactor: enhance Kook adapter with kmarkdown parsing and improve file URL handling

* refactor: extract card message parsing logic into a separate method

* feat: add kook_bot_nickname configuration to ignore messages from specific nicknames

* refactor: remove commented-out code and clean up file upload error handling

* fix: remove redundant prefix handling for file URLs in asset upload

---------

Co-authored-by: Soulter <905617992@qq.com>
2026-03-03 15:08:16 +08:00
Soulter 6b642d7674 refactor: bundled webui static files into wheel and replace astrbot cli log with English (#5665)
* refactor: bundled webui static files into wheel and replace astrbot cli log with English

- Translated and standardized log messages in cmd_conf.py for better clarity.
- Updated initialization logic in cmd_init.py to provide clearer user prompts and error handling.
- Improved plugin management commands in cmd_plug.py with consistent language and error messages.
- Enhanced run command in cmd_run.py with clearer status messages and error handling.
- Updated utility functions in basic.py and plugin.py to improve readability and maintainability.
- Added version comparison logic in version_comparator.py with clearer comments.
- Enhanced logging configuration in log.py to suppress noisy loggers.
- Updated the updater logic in updator.py to provide clearer error messages for users.
- Improved IO utility functions in io.py to handle dashboard versioning more effectively.
- Enhanced dashboard server logic in server.py to prioritize bundled assets and improve user feedback.
- Updated pyproject.toml to include bundled dashboard assets and custom build hooks.
- Added a custom build script (hatch_build.py) to automate dashboard builds during package creation.

* refactor: improve exception messages and formatting in CLI command validation

* perf: change npm install to npm ci for consistent dependency installation

* fix
2026-03-03 12:58:59 +08:00
SJ 0711ec346f Fix/fix: resolve MCP tools race condition causing 'completion 无法解析' error (#5534)
* fix: resolve MCP tools race condition causing 'completion 无法解析' error

- Wait for MCP client initialization to complete before accepting requests
- Add Future-based synchronization in init_mcp_clients()
- Prevent tool_calls from being rejected due to empty func_list
- Improve error logging for MCP initialization failures

Fixes race condition where AI attempts to call MCP tools before they are
registered, resulting in 'API 返回的 completion 无法解析' exceptions.

The issue occurred because:
1. MCP clients were initialized asynchronously without waiting
2. System accepted user requests immediately after startup
3. AI received empty tool list and attempted to call non-existent tools
4. Tool matching failed, causing parsing errors

This fix ensures all MCP tools are loaded before the system processes
any requests that might use them.

* perf: add timeout and better error handling for MCP initialization

- Add 20-second total timeout to prevent slow MCP servers from blocking startup
- Show detailed configuration info when MCP initialization fails
- List all failed services in a summary warning
- Gracefully handle timeout by using already-completed services

This ensures that even if some MCP servers are slow or unreachable,
the system will start within a reasonable time and provide clear
feedback about which services failed and why.

* refactor: simplify MCP init orchestration and improve log security

- Replace Future-based sync with asyncio.wait + name→task mapping
- Explicitly cancel timed-out tasks after 20s timeout
- Downgrade sensitive config details (command/args/URL) to debug level
- Move urllib.parse import to top-level

* fix: prevent initialized MCP clients from being cleaned up on timeout

- Do not cancel pending tasks on timeout; let them continue running
  in the background waiting for the termination signal (event.set()),
  so successfully initialized services remain available
- Track initialization state with a flag to distinguish init failures
  from post-init cancellations in _init_mcp_client_task_wrapper

* fix: restore task cancellation on timeout per review feedback

Pending tasks in asyncio.wait are tasks that have NOT completed
initialization within 20s, so cancelling them is safe and correct.

* fix: separate init signal from client lifetime in MCP task wrapper

The previous design awaited task completion, but tasks only finish
on shutdown (after event.wait()), causing asyncio.wait to always
hit the 20s timeout and cancel all clients.

Fix: introduce a dedicated ready_event that is set immediately after
_init_mcp_client completes. init_mcp_clients now waits only for
ready_event (with 20s timeout), while the long-lived client task
continues running in the background until shutdown_event is set.

This ensures startup returns promptly once clients are ready.

* security: redact sensitive MCP config from debug logs

Only log executable name and argument count instead of full
command/args to avoid leaking tokens or credentials even at
debug level.

* refactor: use McpClientInfo dataclass and MCP_INIT_TIMEOUT constant

- Extract MCP_INIT_TIMEOUT = 20.0 as a named module-level constant
- Replace tuple-based client_info with _McpClientInfo dataclass to
  eliminate index-based access and improve readability
- Remove _wait_ready helper; use asyncio.create_task(event.wait()) directly
- Await cancelled tasks after timeout to prevent lingering background
  tasks and unobserved exceptions

* fix: handle CancelledError and clean up wait_tasks on timeout

- Catch asyncio.CancelledError separately in _init_mcp_client_task_wrapper
  so ready_event.set() is always called (Python 3.8+ CancelledError
  inherits BaseException, not Exception)
- Cancel and await lingering wait_tasks after timeout to prevent
  them from hanging indefinitely when ready_event is never set

* fix: align enable_mcp_server with new wrapper API and fix security/config issues

- Fix enable_mcp_server to pass shutdown_event + ready_event instead of
  ready_future, matching _init_mcp_client_task_wrapper's current signature
- Cancel and await init_task on timeout; clean up mcp_client_event on failure
- Read MCP_INIT_TIMEOUT from env var ASTRBOT_MCP_INIT_TIMEOUT (default 20s)
  so operators can tune it without code changes
- Strip userinfo from URL in debug log (use hostname+port only, not netloc)
  to avoid leaking credentials embedded in URLs

* refactor: register mcp_client_event only after successful init in enable_mcp_server

Move self.mcp_client_event[name] assignment to after initialization
succeeds, so callers never observe a stale event for a failed client.

* fix: harden MCP init state handling and timeout parsing

* fix: improve MCP timeout and post-init error observability

* refactor: simplify MCP init lifecycle orchestration

* refactor: simplify MCP init flow and cap timeout values

* fix: refine mcp timeout handling and lifecycle task tracking

* fix: harden mcp shutdown and timeout source logging

* refactor: simplify mcp runtime registry and timeout flow

* fix: keep mcp init summary return contract

* refactor: streamline mcp lifecycle and init errors

* refactor: unify mcp lifecycle wait handling

* refactor: simplify mcp runtime ownership and timeout resolution

* fix: harden mcp shutdown waiting and startup signaling

* refactor: streamline mcp lifecycle and shutdown errors

* refactor: harden mcp runtime access and shutdown

* fix: ensure mcp client cleanup and clarify views

* refactor: cache mcp client view and guard startup

* refactor: simplify mcp init cleanup and runtime lock

* refactor: reduce mcp runtime duplication

* refactor: reuse mcp cleanup and client view

---------

Co-authored-by: idiotsj <idiotsj@users.noreply.github.com>
Co-authored-by: 邹永赫 <1259085392@qq.com>
2026-03-03 01:09:45 +09:00
Copilot 0dbe32e2dc feat: add Discord pre-ack emoji support (#5609)
* Initial plan

* feat: add Discord pre-ack emoji support

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

* feat: add Discord pre-acknowledgment emoji configuration in English and Chinese locales

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
Co-authored-by: Soulter <905617992@qq.com>
2026-03-02 14:38:12 +08:00
Soulter 4e855a17bc fix: update Discord command registration descriptions and hints in config metadata 2026-03-02 14:31:36 +08:00
Soulter f2fc724e0f fix: update tutorial links to use the correct path format 2026-03-02 14:22:56 +08:00
Copilot 460acf40c0 fix: apply max_agent_step config to subagents (#5608)
* Initial plan

* fix: apply max_agent_step config to subagents

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

* fix: streamline max_agent_step and streaming_response retrieval in FunctionToolExecutor

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
Co-authored-by: Soulter <905617992@qq.com>
2026-03-02 14:16:14 +08:00
Soulter cf29d9390f chore: reorganize provider settings for quoted message parsing 2026-03-02 12:35:35 +08:00
Soulter ac44d1fdef feat: enhance chat interface and mobile responsiveness (#5635) 2026-03-02 12:26:55 +08:00
Soulter 66d0f0afd4 chore: remove deprecated websearch command from event filter 2026-03-02 11:51:58 +08:00
Ruochen Pan 2a7b4f6e64 Merge pull request #5028 from w31r4/feat/neo-skill-self-iteration
feat: 接入 Shipyard Neo 自迭代 Skill 闭环与管理能力
2026-03-02 09:40:32 +08:00
RC-CHN 6e1be64aef Merge branch 'feat/neo-skill-self-iteration' of https://github.com/w31r4/AstrBot into feat/neo-skill-self-iteration 2026-03-02 09:37:42 +08:00
RC-CHN f818ad0758 Merge remote-tracking branch 'origin/master' into feat/neo-skill-self-iteration 2026-03-02 09:37:06 +08:00
sanyekana 4abea2bd30 fix: harden backup import for duplicate platform stats (#5594)
* fix: harden backup import for duplicate platform stats

- 修复 replace 模式下主库清空失败仍继续导入的问题。
- 导入前对 platform_stats 重复键做聚合(count 累加),并统一时间戳判重格式。
- 非法 count 按 0 处理并告警(限流),补充对应测试。

* refactor: improve robustness and readability of platform stats import

- 告警上限魔法数字提取为模块常量 PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
- 抽取 parse_count 内联函数,消除重复的 try/except 分支
- 存储行的 timestamp 同步写入规范化值,避免落库格式混用
- 补充测试:已有行 count 非法、告警限流、replace 模式中断断言

* fix: normalize invalid platform_stats count for non-duplicate rows

* fix: avoid merging invalid platform_stats timestamps

* refactor: simplify platform stats merge and normalize naive UTC

* refactor: inline platform stats merge helpers

* refactor: flatten platform stats merge flow

* refactor: harden platform stats merge key handling

* refactor: streamline platform stats preprocessing

* refactor: simplify platform stats merge helpers

* refactor: inline platform stats merge normalization

* refactor: extract platform stats merge helpers

* refactor: simplify platform stats preprocessing flow

* refactor: flatten platform stats preprocess helpers

* refactor: streamline platform stats merge helpers

* refactor: isolate platform stats warning limiter

---------

Co-authored-by: 邹永赫 <1259085392@qq.com>
2026-03-01 20:46:35 +09:00
pandyzhou 267abfd552 fix: resolve /model command misleading behavior when switching to model from different provider (#5578)
* fix: /model command now auto-switches provider when model exists elsewhere

Made-with: Cursor

* fix: address Sourcery review - log get_models() failures in cross-provider lookup

Made-with: Cursor

* fix: integer branch exception handling and API key masking in model command

Made-with: Cursor

* fix: harden cross-provider model resolution

* fix: improve model lookup resilience and cache hygiene

* refactor: simplify model switch lookup flow

* refactor: streamline provider model cache updates

* fix: align provider annotations and key error flow

* fix: narrow provider command exception handling

* refactor: harden provider command error redaction and flow

* fix: improve provider model lookup and secret redaction

* refactor: cache normalized model names in provider lookup

* refactor: simplify provider model lookup helpers

* refactor: extract provider model lookup helpers

* fix: harden provider lookup cancellation and redaction

* refactor: streamline provider cache and lookup settings

* refactor: simplify provider command setting and update helpers

* refactor: streamline provider model lookup config usage

* refactor: flatten provider lookup settings and filter model lookup providers

* refactor: simplify provider cache and callback flow

* refactor: simplify provider command model cache flow

* refactor: scope provider model cache by session

* fix: preserve redaction context and restore provider hooks

* refactor: unify provider model lookup config flow

* refactor: inline provider model cache access flow

* fix: align provider lookup cache and callback semantics

* refactor: centralize provider model fetch error handling

* refactor: simplify provider model cache and lookup flow

---------

Co-authored-by: 邹永赫 <1259085392@qq.com>
2026-03-01 19:11:31 +09:00
Soulter 064495698f feat(i18n): add neoDeactivate messages for extension management 2026-02-28 15:45:41 +08:00
Soulter 7c913093b0 feat(i18n): add neoFilterHint for filtering candidates and release records 2026-02-28 15:27:50 +08:00
Soulter edf0982ce4 feat(skills): enhance candidate promotion buttons with loading and disabled states 2026-02-28 15:25:14 +08:00
RC-CHN a219a8b70d Merge remote-tracking branch 'origin/master' into feat/neo-skill-self-iteration 2026-02-27 15:25:50 +08:00
RC-CHN c1de265baf feat(skills): mark sandbox preset skills readonly
expose skill source metadata and sandbox cache status in the skills API
response so the dashboard can distinguish local, sandbox-only, and
synced skills.

prevent enabling, disabling, or deleting sandbox-only preset skills in
both backend guards and UI actions to avoid invalid local operations.

add source badges, discovery-pending hinting for sandbox runtime, and
new i18n strings for source labels and readonly warnings.
2026-02-27 15:22:07 +08:00
RC-CHN 13c8fa3f92 fix(skills): use workspace path for sandbox skills
default sandbox skill paths to /workspace/skills/<name>/SKILL.md
when loading config and when exposing sandbox paths.
preserve cached sandbox paths when available to avoid losing
resolved locations for existing skills.
2026-02-27 14:08:59 +08:00
RC-CHN 4ff4c5f1bf fix(skills): remove deleted skills from sandbox cache
keep sandbox skill cache in sync when deleting a skill from disk.
this prevents stale entries in the UI when no sandbox session is
active to refresh runtime cache
2026-02-26 16:52:02 +08:00
RC-CHN 73e665bef7 feat(neo): guide skill lifecycle tool workflow
Add explicit Neo lifecycle instructions to the main agent prompt so
skill creation and updates follow payload -> candidate -> promotion
instead of direct local folder writes.

Clarify lifecycle tool descriptions and parameter semantics, including
skill_key/source_execution_ids usage and stable release sync_to_local
behavior, to reduce ambiguity and improve consistent skill publishing.
2026-02-26 16:14:16 +08:00
w31rd 4b1bda5f2e Merge pull request #2 from camera-2018/feat/neo-skill-self-iteration
Feat/neo skill self iteration
2026-02-26 16:13:42 +08:00
camera-2018 18114eafda fix(neo): sanitize skill name in frontmatter to prevent injection
Sanitized the name field in SKILL.md frontmatter within astrbot/core/skills/neo_skill_sync.py. This prevents potential frontmatter injection vulnerabilities by removing newlines and control characters from the skill name. Verified the fix with a reproduction script and ensured existing tests pass.
2026-02-26 16:04:42 +08:00
camera-2018 87cbcc9875 fix(neo): sanitize skill name in frontmatter to prevent injection
Sanitize the `name` field in `SKILL.md` frontmatter to remove newlines and control characters. This prevents potential frontmatter injection vulnerabilities where a malicious skill name could introduce arbitrary YAML fields or corrupt the file structure.

- Modified `_ensure_skill_frontmatter` in `astrbot/core/skills/neo_skill_sync.py` to normalize whitespace in `name`.
- Ensured `name` is cast to string before splitting to handle non-string inputs safely.
2026-02-26 08:03:44 +00:00
RC-CHN 1ebc2070c0 fix(skills): gate neo mode by runtime config
Disable the Neo mode toggle unless runtime is sandbox with
shipyard_neo configured, and show a warning when Neo is unavailable.

Also avoid loading Neo data when the environment is not compatible and
fall back to local mode to prevent invalid requests and confusion.
2026-02-26 15:50:39 +08:00
RC-CHN e95bd8d3a6 style: format code 2026-02-26 15:27:37 +08:00
RC-CHN d5a3107f8f style: format code 2026-02-26 15:24:10 +08:00
RC-CHN 8d5841b71f feat(skills): add neo candidate and release deletion
Add backend routes to delete neo candidates and releases with optional
reason support and demo mode protection.

Expose delete actions in the Skills dashboard for candidate and release
rows, refresh data after success, and add localized success/failure
messages in en-US and zh-CN.
2026-02-26 14:48:20 +08:00
RC-CHN 8faed949c2 fix(skills): ensure synced markdown has frontmatter
Normalize SKILL.md content during sync so each file includes name and
description metadata in a frontmatter block.

Preserve existing frontmatter values when present, derive description
from markdown content when missing, and fallback to a default
description to keep metadata complete and consistent.
2026-02-26 11:10:09 +08:00
RC-CHN e1719efbc8 fix(skills): normalize release stage and handle rollback skip
Normalize release stage values before stability checks so enum-like
objects and mixed-case strings are handled consistently.

When stable sync fails, treat "no previous release exists" during
auto-rollback as a skipped rollback instead of raising a secondary
runtime error
2026-02-26 10:45:03 +08:00
RC-CHN f01c23ad40 fix(agent): enforce relative paths for neo sandbox tools
append a Shipyard Neo-specific system prompt note for filesystem
tool calls so paths are provided relative to the workspace root.
this prevents models from prepending `/workspace` and causing tool
path resolution failures
2026-02-26 10:33:22 +08:00
RC-CHN 847ef0f3f4 Merge remote-tracking branch 'origin/master' into feat/neo-skill-self-iteration 2026-02-26 10:04:48 +08:00
zenfun 48a0b97ac0 test(skills): add skill metadata enrichment tests
11 tests covering:
- _parse_frontmatter_description: standard, description-only, empty,
  missing delimiter, quoted values
- build_skills_prompt: format, absolute path in example, progressive
  disclosure rules, absence of legacy custom fields
- SkillManager.list_skills: local frontmatter parsing, sandbox cache
  description passthrough
2026-02-21 01:06:39 +08:00
zenfun d21212d0e4 test(computer): add profile-aware sandbox selection tests
17 tests covering:
- ShipyardNeoBooter.capabilities property (tuple, immutability, pre/post boot)
- _apply_sandbox_tools conditional browser tool registration
- _resolve_profile smart selection (user-specified, browser preference,
  API error fallback, empty profiles, auth error pass-through)
- ComputerBooter base class defaults
2026-02-21 01:03:58 +08:00
zenfun c1917ebf4f fix(computer): resolve absolute skill paths at runtime in scan command
- Resolve skills root via Path.resolve() so LLM prompts always
  reference absolute paths regardless of sandbox cwd
- Use resolved path in skill metadata for reliable cat/head commands
- Add DRY cross-reference comment for frontmatter parser
- Remove dead skills_root_abs field from JSON output (no consumer)
- Remove unnecessary os import and fake resolve/abspath branch
2026-02-21 01:03:45 +08:00
zenfun b816045f37 refactor(skills): rewrite skills prompt and sanitize example paths
- Rewrite build_skills_prompt() with structured numbered rules and
  markdown formatting for better LLM comprehension
- Sanitize example_path with _SAFE_PATH_RE before embedding in system
  prompt to prevent prompt injection via crafted skill paths
- Add docstring to _parse_frontmatter_description()
- Remove debug print(top_dirs) from install_skill_from_zip()
- Remove stale commented-out SANDBOX_SKILLS_ROOT line
2026-02-21 01:03:32 +08:00
zenfun 1df1138d04 feat(agent): conditionally register browser tools based on sandbox capabilities
_apply_sandbox_tools now checks the booted session's capabilities
before registering browser tools (BrowserExecTool, BrowserBatchExecTool,
RunBrowserSkillTool).

- If no session exists yet (first request), all tools are registered
  conservatively to avoid breaking the initial interaction
- If a session exists without browser capability, browser tools are
  omitted, preventing CapabilityNotSupportedError from Bay
- Skill lifecycle tools remain unconditionally registered
2026-02-21 01:03:19 +08:00
zenfun 1962ff2def feat(computer): expose sandbox capabilities and smart profile selection
Add capabilities property to ComputerBooter base class (returns None)
and ShipyardNeoBooter (returns immutable tuple from sandbox).

- Extract DEFAULT_PROFILE class constant to replace scattered magic string
- Use tuple[str, ...] for immutability (no defensive copy needed)
- Add _resolve_profile() for smart profile selection:
  - honour user-specified profile
  - query Bay API, prefer browser-capable profiles
  - re-raise auth errors (401/403), fallback on transient failures
- Conditionally create NeoBrowserComponent only when profile has browser
- Log resolved profile and capabilities at boot
2026-02-21 01:03:05 +08:00
zenfun 92a8e40cde feat(computer): auto-start Bay container for zero-config Neo integration
Add BayContainerManager to manage Bay container lifecycle via Docker
Engine API, similar to how BoxliteBooter manages Ship containers.

When ShipyardNeoBooter endpoint is empty or set to '__auto__', Bay is
automatically pulled, started, health-checked, and credentials are
read from the container.

- New bay_manager.py: ensure_running, wait_healthy, read_credentials
- Integrate auto-start into ShipyardNeoBooter boot/shutdown
- Reuse Bay container across sessions (unless-stopped policy)
- Friendly error messages for Docker and credential failures
2026-02-20 23:11:19 +08:00
zenfun 3769f145ee feat(dashboard): validate Bay connectivity on config save
When saving config with shipyard_neo sandbox, _validate_neo_connectivity()
performs an async /health check against the Bay endpoint. If Bay is
unreachable, a ⚠️ warning is appended to the success snackbar message.
Config still saves successfully — the warning is informational only.
2026-02-19 01:41:29 +08:00
zenfun 18ebeae318 test(computer): add tests for credentials discovery and config logging
19 tests in test_computer_config.py:
- TestDiscoverBayCredentials (9 tests): env priority, cwd fallback,
  missing file, empty key, malformed JSON, endpoint mismatch, slash normalization
- TestLogComputerConfigChanges (10 tests): runtime change, sandbox key change,
  token masking, empty token label, missing provider_settings, add/remove keys

Uses unittest.mock.patch on AstrBot custom logger for reliable assertions.
2026-02-19 01:26:04 +08:00
zenfun 7e246477f0 fix(dashboard): graceful error handling for Neo skills when unconfigured
- Add _discover_bay_credentials() auto-discovery in _get_neo_client_config()
- Catch ValueError separately in _with_neo_client(), log at DEBUG instead of
  ERROR with full traceback — prevents log spam when visiting Skills page
  without Bay configured
2026-02-19 01:25:50 +08:00
RC-CHN bc3e09f47b refactor(computer): split sandbox skill sync phases
separate sandbox skill syncing into distinct apply and scan steps
while keeping the legacy combined command for compatibility

improve observability by adding phase-based logs and richer shell
error details that include exit code, stderr, and stdout tail

reuse a shared python-exec command builder to reduce duplication
and keep command generation consistent
2026-02-18 13:35:17 +08:00
RC-CHN 707db768ea style: format code 2026-02-17 17:26:37 +08:00
RC-CHN 591803d407 refactor(skills): centralize neo promote and sync flow
extract shared promote/sync orchestration into `NeoSkillSyncManager` so
computer tools and dashboard routes use the same rollback and error logic

add a reusable neo tool base runner to remove duplicated admin checks and
try/catch handling across skill-related tools, keeping responses consistent

factor sync result serialization into a single helper and reuse it where
stable release sync output is returned
2026-02-17 17:20:42 +08:00
RC-CHN b48919246d refactor(api): centralize neo client lifecycle in skills route
extract a shared `_with_neo_client` wrapper to handle neo client
setup, teardown, and error responses in one place.

reduce duplicated try/except and `BayClient` context boilerplate across
neo skills endpoints while preserving existing request validation and
response payloads.
2026-02-17 17:06:11 +08:00
RC-CHN cf9a7235f7 fix(computer): return none for unsupported browser capability
set the base booter browser property to return None instead of
raising NotImplementedError so callers can handle missing browser
support through capability checks
2026-02-17 16:59:05 +08:00
RC-CHN d62a6f107b fix(computer): mask bay api key in logs
Also add shipyard-neo-sdk dependency for neo support
2026-02-17 16:40:55 +08:00
Ruochen Pan 1a539830f8 Merge branch 'master' into feat/neo-skill-self-iteration 2026-02-17 16:23:14 +08:00
zenfun 418913aa53 docs: add PR verification workflow to CONTRIBUTING.md
Document make pr-test-neo and make pr-test-full commands for local
CI-equivalent verification before submitting PRs.
2026-02-17 04:25:06 +08:00
zenfun 4b07aa2bc3 test(computer): add tests for credentials discovery and config logging
19 new tests in test_computer_config.py:
- TestDiscoverBayCredentials (9 tests): env priority, cwd fallback,
  missing file, empty key, malformed JSON, endpoint mismatch, slash normalization
- TestLogComputerConfigChanges (10 tests): runtime change, sandbox key change,
  token masking, empty token label, missing provider_settings, add/remove keys
2026-02-17 04:24:55 +08:00
zenfun 64d8daa67d feat(scripts): update start-with-neo.sh for auto-provisioned API key
- Generated config uses allow_anonymous: false (triggers auto-provision)
- Set BAY_DATA_DIR so credentials.json writes to pkgs/bay/
- Add read_bay_credentials() to extract auto-generated key after boot
- Display API key in config hints for easy AstrBot setup
2026-02-17 04:24:44 +08:00
zenfun 9d44947500 feat(dashboard): update Shipyard Neo config hints
- Endpoint hint: mention default port 8114
- Access Token hint: mention sk-bay-* format and credentials.json auto-discovery
- Updated in default.py, zh-CN, and en-US i18n files
2026-02-17 04:24:34 +08:00
zenfun 4043a10531 fix(computer): improve ShipyardNeoBooter error message
Include default endpoint URL (http://127.0.0.1:8114) and credentials.json
auto-discovery hint in the ValueError message when config is incomplete.
2026-02-17 04:24:24 +08:00
zenfun 7c8dac2fd5 feat(computer): add Bay credentials.json auto-discovery
When shipyard_neo_access_token is not configured, _discover_bay_credentials()
searches for Bay's credentials.json in:
1. BAY_DATA_DIR env var
2. Mono-repo relative path ../pkgs/bay/
3. Current working directory

Enables zero-config dev mode when Bay runs locally alongside AstrBot.
2026-02-17 04:24:12 -06:00
zenfun 963122b916 chore: update gitignore, Makefile, skills route, and test scaffolding 2026-02-16 02:38:01 +08:00
zenfun aa3b012d60 feat: add Shipyard Neo quick-start script
Add scripts/start-with-neo.sh: one-click launcher that auto-generates
Bay config.yaml (anonymous mode, host_port), pulls Ship image, starts
Bay (port 8114) with health check, then starts AstrBot in foreground.
Ctrl+C stops both services. Supports BAY_PORT env var override.
2026-02-16 02:37:48 +08:00
zenfun 401dfb9ee2 feat(dashboard): log Computer/sandbox config changes on save
Add _log_computer_config_changes() to detect and log modifications to
computer_use_runtime and sandbox.* keys when saving config via Dashboard.
Sensitive fields (tokens/secrets) are masked in log output.
2026-02-16 02:37:24 +08:00
zenfun 1d81c52950 feat(computer): add INFO-level lifecycle logging to booter implementations
Add [Computer] prefixed INFO logs to:
- shipyard_neo.py: shutdown, upload_file, download_file, available
- shipyard.py: shutdown, upload_file, download_file, available
- boxlite.py: upload_file success path
- computer_client.py: sync_skills_to_active_sandboxes, _sync_skills_to_sandbox

Improves traceability of sandbox lifecycle events.
2026-02-16 02:37:14 +08:00
zenfun 40c7cf3901 feat(skills): merge sandbox built-ins with uploaded skill sync 2026-02-13 03:20:51 +08:00
zenfun afe292de35 fix: address neo skill review findings 2026-02-11 19:35:01 +08:00
zenfun d4dcc6430f chore: apply pre-commit formatting fixes for neo integration 2026-02-11 17:34:07 +08:00
zenfun a8cc995633 feat(dashboard): add neo skills APIs and management UI 2026-02-11 17:14:55 +08:00
zenfun 73251db1da feat(skills): add neo lifecycle tools and stable sync manager 2026-02-11 17:14:47 +08:00
zenfun d16398a0e8 feat(computer): add shipyard_neo booter runtime and sandbox config 2026-02-11 17:14:38 +08:00
435 changed files with 46031 additions and 2120 deletions
+43
View File
@@ -0,0 +1,43 @@
name: release
on:
push:
tags:
- 'v*'
workflow_dispatch:
jobs:
build:
runs-on: ubuntu-latest # 运行环境
steps:
- name: checkout
uses: actions/checkout@v6
- name: nodejs installation
uses: actions/setup-node@v6
with:
node-version: "18"
- name: npm install
run: npm add -D vitepress
working-directory: './docs' # working-directory 指定 shell 命令运行目录
- name: npm run build
run: npm run docs:build
working-directory: './docs'
- name: scp
uses: appleboy/scp-action@v1.0.0
with:
host: ${{ secrets.HOST_NEKO }}
username: ${{ secrets.USERNAME }}
password: ${{ secrets.PASSWORDNEKO }}
source: 'docs/.vitepress/dist/*'
target: '/tmp/'
- name: script
uses: appleboy/ssh-action@v1.2.5
with:
host: ${{ secrets.HOST_NEKO }}
username: ${{ secrets.USERNAME }}
password: ${{ secrets.PASSWORDNEKO }}
script: |
mkdir -p /root/docker_data/caddy/caddy_data/static_site/abv4/
rm -rf /root/docker_data/caddy/caddy_data/static_site/abv4/*
mv /tmp/docs/.vitepress/dist/* /root/docker_data/caddy/caddy_data/static_site/abv4/
rm -rf /tmp/docs/
+2 -2
View File
@@ -36,7 +36,7 @@ jobs:
zip -r dist.zip dist
- name: Archive production artifacts
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v7
with:
name: dist-without-markdown
path: |
@@ -45,7 +45,7 @@ jobs:
- name: Create GitHub Release
if: github.event_name == 'push'
uses: ncipollo/release-action@v1
uses: ncipollo/release-action@v1.20.0
with:
tag: release-${{ github.sha }}
owner: AstrBotDevs
+10 -10
View File
@@ -64,20 +64,20 @@ jobs:
echo "build_date=$build_date" >> $GITHUB_OUTPUT
- name: Set QEMU
uses: docker/setup-qemu-action@v3
uses: docker/setup-qemu-action@v4.0.0
- name: Set Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@v4.0.0
- name: Log in to DockerHub
uses: docker/login-action@v3
uses: docker/login-action@v4.0.0
with:
username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
- name: Login to GitHub Container Registry
if: env.HAS_GHCR_TOKEN == 'true'
uses: docker/login-action@v3
uses: docker/login-action@v4.0.0
with:
registry: ghcr.io
username: ${{ env.GHCR_OWNER }}
@@ -98,7 +98,7 @@ jobs:
echo "EOF" >> $GITHUB_OUTPUT
- name: Build and Push Nightly Image
uses: docker/build-push-action@v6
uses: docker/build-push-action@v7.0.0
with:
context: .
platforms: linux/amd64,linux/arm64
@@ -163,27 +163,27 @@ jobs:
cp -r dashboard/dist data/
- name: Set QEMU
uses: docker/setup-qemu-action@v3
uses: docker/setup-qemu-action@v4.0.0
- name: Set Docker Buildx
uses: docker/setup-buildx-action@v3
uses: docker/setup-buildx-action@v4.0.0
- name: Log in to DockerHub
uses: docker/login-action@v3
uses: docker/login-action@v4.0.0
with:
username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
- name: Login to GitHub Container Registry
if: env.HAS_GHCR_TOKEN == 'true'
uses: docker/login-action@v3
uses: docker/login-action@v4.0.0
with:
registry: ghcr.io
username: ${{ env.GHCR_OWNER }}
password: ${{ secrets.GHCR_GITHUB_TOKEN }}
- name: Build and Push Release Image
uses: docker/build-push-action@v6
uses: docker/build-push-action@v7.0.0
with:
context: .
platforms: linux/amd64,linux/arm64
+37 -4
View File
@@ -50,7 +50,7 @@ jobs:
echo "tag=$tag" >> "$GITHUB_OUTPUT"
- name: Setup pnpm
uses: pnpm/action-setup@v4
uses: pnpm/action-setup@v4.3.0
with:
version: 10.28.2
@@ -71,7 +71,7 @@ jobs:
zip -r "AstrBot-${{ steps.tag.outputs.tag }}-dashboard.zip" dist
- name: Upload dashboard artifact
uses: actions/upload-artifact@v6
uses: actions/upload-artifact@v7
with:
name: Dashboard-${{ steps.tag.outputs.tag }}
if-no-files-found: error
@@ -132,7 +132,7 @@ jobs:
echo "tag=$tag" >> "$GITHUB_OUTPUT"
- name: Download dashboard artifact
uses: actions/download-artifact@v7
uses: actions/download-artifact@v8
with:
name: Dashboard-${{ steps.tag.outputs.tag }}
path: release-assets
@@ -184,7 +184,8 @@ jobs:
publish-pypi:
name: Publish PyPI
runs-on: ubuntu-24.04
needs: publish-release
needs:
- publish-release
steps:
- name: Checkout repository
uses: actions/checkout@v6
@@ -192,6 +193,36 @@ jobs:
fetch-depth: 0
ref: ${{ inputs.ref || github.ref }}
- name: Resolve tag
id: tag
shell: bash
run: |
if [ "${{ github.event_name }}" = "push" ]; then
tag="${GITHUB_REF_NAME}"
elif [ -n "${{ inputs.tag }}" ]; then
tag="${{ inputs.tag }}"
else
tag="$(git describe --tags --abbrev=0)"
fi
if [ -z "$tag" ]; then
echo "Failed to resolve tag." >&2
exit 1
fi
echo "tag=$tag" >> "$GITHUB_OUTPUT"
- name: Download dashboard artifact
uses: actions/download-artifact@v8
with:
name: Dashboard-${{ steps.tag.outputs.tag }}
path: dashboard-artifact
- name: Unpack dashboard dist into package tree
shell: bash
run: |
mkdir -p astrbot/dashboard/dist
unzip -q "dashboard-artifact/AstrBot-${{ steps.tag.outputs.tag }}-dashboard.zip" -d dashboard-artifact/unpacked
cp -r dashboard-artifact/unpacked/dist/. astrbot/dashboard/dist/
- name: Set up Python
uses: actions/setup-python@v6
with:
@@ -203,6 +234,8 @@ jobs:
- name: Build package
shell: bash
# Dashboard assets are already in astrbot/dashboard/dist/;
# ASTRBOT_BUILD_DASHBOARD is intentionally unset so the hatch hook skips npm.
run: uv build
- name: Publish to PyPI
+68
View File
@@ -0,0 +1,68 @@
name: sync wiki
on:
workflow_dispatch:
push:
branches:
- master
paths:
- '.github/workflows/sync-wiki.yml'
- 'docs/scripts/sync_docs_to_wiki.py'
- 'docs/tests/test_sync_docs_to_wiki.py'
- 'docs/zh/**'
- 'docs/en/**'
concurrency:
group: sync-wiki-${{ github.ref }}
cancel-in-progress: true
jobs:
sync:
runs-on: ubuntu-latest
permissions:
contents: read
steps:
- name: Validate manual ref
if: github.event_name == 'workflow_dispatch' && github.ref != 'refs/heads/master'
run: |
echo "This workflow only publishes from refs/heads/master. Re-run it from the master branch."
exit 1
- name: Check out docs repository
uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.11'
- name: Run sync unit tests
working-directory: docs
run: python -m unittest discover -s tests -p 'test_sync_docs_to_wiki.py' -v
- name: Validate internal doc links
run: python docs/scripts/sync_docs_to_wiki.py --source-root docs --check-links-only
- name: Clone AstrBot wiki
env:
WIKI_TOKEN: ${{ secrets.ASTRBOT_WIKI_TOKEN }}
run: |
test -n "$WIKI_TOKEN"
git clone "https://x-access-token:${WIKI_TOKEN}@github.com/AstrBotDevs/AstrBot.wiki.git" wiki
- name: Generate wiki pages
run: python docs/scripts/sync_docs_to_wiki.py --source-root docs --wiki-root wiki
- name: Commit and push wiki changes
working-directory: wiki
run: |
git config user.name "github-actions[bot]"
git config user.email "41898282+github-actions[bot]@users.noreply.github.com"
git add .
if git diff --cached --quiet; then
echo "No wiki changes to push"
exit 0
fi
git commit -m "docs: sync wiki from AstrBot-1/docs"
git push
+9
View File
@@ -36,6 +36,9 @@ dashboard/dist/
package-lock.json
yarn.lock
# Bundled dashboard dist (generated by hatch_build.py during pip wheel build)
astrbot/dashboard/dist/
# Operating System
**/.DS_Store
.DS_Store
@@ -54,3 +57,9 @@ IFLOW.md
# genie_tts data
CharacterModels/
GenieData/
.agent/
.codex/
.opencode/
.kilocode/
.worktrees/
+52
View File
@@ -46,6 +46,32 @@ ruff check .
如果您使用 VSCode,可以安装 `Ruff` 插件。
##### PR 功能完整性验证(推荐)
如果您希望在本地做一套接近 CI 的完整验证,可使用:
```bash
make pr-test-neo
```
该命令会执行:
- `uv sync --group dev`
- `ruff format --check .``ruff check .`
- Neo 相关关键测试
- `main.py` 启动 smoke test(检测 `http://localhost:6185`
需要全量验证时可使用:
```bash
make pr-test-full
```
如果只想快速重复执行(跳过依赖同步和 dashboard 构建):
```bash
make pr-test-full-fast
```
## Contributing Guide
@@ -88,3 +114,29 @@ We use Ruff as our code formatter and static analysis tool. Before submitting yo
ruff format .
ruff check .
```
##### PR completeness checks (recommended)
To run a local validation flow close to CI, use:
```bash
make pr-test-neo
```
This command runs:
- `uv sync --group dev`
- `ruff format --check .` and `ruff check .`
- Neo-related critical tests
- a startup smoke test against `http://localhost:6185`
For full validation, use:
```bash
make pr-test-full
```
For faster repeated runs (skip dependency sync and dashboard build), use:
```bash
make pr-test-full-fast
```
+10 -1
View File
@@ -1,4 +1,4 @@
.PHONY: worktree worktree-add worktree-rm
.PHONY: worktree worktree-add worktree-rm pr-test-neo pr-test-full pr-test-full-fast
WORKTREE_DIR ?= ../astrbot_worktree
BRANCH ?= $(word 2,$(MAKECMDGOALS))
@@ -27,6 +27,15 @@ endif
echo "Worktree $(WORKTREE_DIR)/$(BRANCH) not found."; \
fi
pr-test-neo:
./scripts/pr_test_env.sh --profile neo
pr-test-full:
./scripts/pr_test_env.sh --profile full
pr-test-full-fast:
./scripts/pr_test_env.sh --profile full --skip-sync --no-dashboard
# Swallow extra args (branch/base) so make doesn't treat them as targets
%:
@true
+39 -16
View File
@@ -73,7 +73,7 @@ AstrBot is an open-source all-in-one Agent chatbot platform that integrates with
### One-Click Deployment
For users who want to quickly experience AstrBot, we recommend using the one-click deployment method with `uv` ⚡️:
For users who want to quickly experience AstrBot, are familiar with command-line usage, and can install a `uv` environment on their own, we recommend the `uv` one-click deployment method ⚡️:
```bash
uv tool install astrbot
@@ -83,47 +83,58 @@ astrbot
> Requires [uv](https://docs.astral.sh/uv/) to be installed.
> [!NOTE]
> For macOS user: due to macOS security checks, the first run of the `astrbot` command may take longer (about 10-20s).
Update `astrbot`:
```bash
uv tool upgrade astrbot
```
### Docker Deployment
For users who want a more stable and production-ready deployment, we recommend using Docker / Docker Compose to deploy AstrBot.
For users familiar with containers and looking for a more stable, production-ready deployment method, we recommend deploying AstrBot with Docker / Docker Compose.
Please refer to the official documentation: [Deploy AstrBot with Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
### Deploy on RainYun
For users who want to deploy AstrBot with one-click and don't want to manage the server, we recommend using RainYun's one-click cloud deployment service ☁️:
For users who want one-click deployment and do not want to manage servers themselves, we recommend RainYun's one-click cloud deployment service ☁️:
[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
### Desktop Application (Tauri)
### Desktop Application Deployment
For users who want to deploy AstrBot on their desktop, primarily using AstrBot ChatUI, rarely use AstrBot plugins, we recommend using the AstrBot App:
For users who want to use AstrBot on desktop and mainly use ChatUI, we recommend AstrBot App.
Desktop repository: [AstrBot-desktop](https://github.com/AstrBotDevs/AstrBot-desktop).
Visit [AstrBot-desktop](https://github.com/AstrBotDevs/AstrBot-desktop) to download and install; this method is designed for desktop usage and is not recommended for server scenarios.
Supports multiple system architectures, direct package installation, and out-of-the-box usage. A convenient one-click desktop deployment option for beginners.
### Launcher Deployment
### One-Click Launcher Deployment (AstrBot Launcher)
For desktop users who also want fast deployment and isolated multi-instance usage, we recommend AstrBot Launcher.
For users who want a quick deployment and multi-instance solution with environment isolation, we recommend using the AstrBot Launcher:
Visit the [AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher) repository and install the package for your OS from the latest release.
A quick deployment and multi-instance solution with environment isolation.
Visit [AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher) to download and install.
### Deploy on Replit
Community-contributed deployment method.
Replit deployment is maintained by the community and is suitable for online demos and lightweight trials.
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot)
### AUR
AUR deployment targets Arch Linux users who prefer installing AstrBot through the system package workflow.
Run the command below to install `astrbot-git`, then start AstrBot in your local environment.
```bash
yay -S astrbot-git
```
**More deployment methods**: [BT-Panel Deployment](https://astrbot.app/deploy/astrbot/btpanel.html) | [1Panel Deployment](https://astrbot.app/deploy/astrbot/1panel.html) | [CasaOS Deployment](https://astrbot.app/deploy/astrbot/casaos.html) | [Manual Deployment](https://astrbot.app/deploy/astrbot/cli.html)
**More deployment methods**
If you need panel-based management or deeper customization, see [BT-Panel Deployment](https://astrbot.app/deploy/astrbot/btpanel.html) for BT Panel app-store setup, [1Panel Deployment](https://astrbot.app/deploy/astrbot/1panel.html) for 1Panel app-market deployment, [CasaOS Deployment](https://astrbot.app/deploy/astrbot/casaos.html) for NAS/home-server visual deployment, and [Manual Deployment](https://astrbot.app/deploy/astrbot/cli.html) for fully custom source-based installation with `uv`.
## Supported Messaging Platforms
@@ -184,6 +195,13 @@ Connect AstrBot to your favorite chat platform.
| Minimax TTS | Text-to-Speech Services |
| Volcano Engine TTS | Text-to-Speech Services |
## ❤️ Sponsors
<p align="center">
<img alt="sponsors" src="https://sponsors.astrbot.app/?v=1">
</p>
## ❤️ Contributing
Issues and Pull Requests are always welcome! Feel free to submit your changes to this project :)
@@ -202,17 +220,22 @@ pip install pre-commit
pre-commit install
```
## 🌍 Community
### QQ Groups
- Group 9: 1076659624 (New)
- Group 10: 1078079676 (New)
- Group 1: 322154837
- Group 3: 630166526
- Group 5: 822130018
- Group 6: 753075035
- Group 7: 743746109
- Group 8: 1030353265
- Developer Group: 975206796
- Developer Group(Chit-chat): 975206796
- Developer Group(Formal): 1039761811
### Discord Server
+28 -16
View File
@@ -73,7 +73,7 @@ AstrBot est une plateforme de chatbot Agent tout-en-un open source qui s'intègr
### Déploiement en un clic
Pour les utilisateurs qui souhaitent découvrir AstrBot rapidement, nous recommandons la méthode de déploiement en un clic avec `uv` ⚡️ :
Pour les utilisateurs qui veulent découvrir AstrBot rapidement, qui sont familiers avec la ligne de commande et peuvent installer eux-mêmes l'environnement `uv`, nous recommandons la méthode de déploiement en un clic avec `uv` ⚡️ :
```bash
uv tool install astrbot
@@ -83,47 +83,58 @@ astrbot
> [uv](https://docs.astral.sh/uv/) doit être installé.
> [!NOTE]
> Pour les utilisateurs macOS : en raison des vérifications de sécurité de macOS, la première exécution de la commande `astrbot` peut prendre plus de temps (environ 10-20s).
Mettre à jour `astrbot` :
```bash
uv tool upgrade astrbot
```
### Déploiement Docker
Pour les utilisateurs qui veulent un déploiement plus stable et prêt pour la production, nous recommandons d'utiliser Docker / Docker Compose pour déployer AstrBot.
Pour les utilisateurs familiers avec les conteneurs et qui souhaitent une méthode plus stable et adaptée à la production, nous recommandons de déployer AstrBot avec Docker / Docker Compose.
Veuillez consulter la documentation officielle : [Déployer AstrBot avec Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
Veuillez consulter la documentation officielle [Déployer AstrBot avec Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
### Déployer sur RainYun
Pour les utilisateurs qui souhaitent déployer AstrBot en un clic sans gérer le serveur, nous recommandons le service de déploiement cloud en un clic de RainYun ☁️ :
Pour les utilisateurs qui souhaitent déployer AstrBot en un clic sans gérer le serveur eux-mêmes, nous recommandons le service de déploiement cloud en un clic de RainYun ☁️ :
[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
### Application de bureau (Tauri)
### Déploiement de l'application de bureau
Pour les utilisateurs qui veulent déployer AstrBot sur desktop, utilisent principalement AstrBot ChatUI et utilisent rarement les plugins AstrBot, nous recommandons AstrBot App :
Pour les utilisateurs qui veulent utiliser AstrBot sur desktop et passer principalement par ChatUI, nous recommandons AstrBot App.
Dépôt de l'application de bureau : [AstrBot-desktop](https://github.com/AstrBotDevs/AstrBot-desktop).
Accédez à [AstrBot-desktop](https://github.com/AstrBotDevs/AstrBot-desktop) pour télécharger et installer l'application ; cette méthode est conçue pour un usage desktop et n'est pas recommandée pour les scénarios serveur.
Prend en charge plusieurs architectures système, installation directe, prête à l'emploi. Solution de déploiement bureau en un clic, particulièrement adaptée aux débutants. Non recommandée pour les serveurs.
### Déploiement avec le lanceur
### Déploiement en un clic avec le lanceur (AstrBot Launcher)
Également sur desktop, pour les utilisateurs qui souhaitent un déploiement rapide avec isolation d'environnement et multi-instances, nous recommandons AstrBot Launcher.
Pour les utilisateurs qui veulent une solution de déploiement rapide et multi-instances avec isolation d'environnement, nous recommandons d'utiliser AstrBot Launcher :
Accédez au dépôt [AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher) et installez le package correspondant à votre système depuis la dernière release.
Une solution de déploiement rapide et multi-instances avec isolation d'environnement.
Accédez à [AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher) pour télécharger et installer.
### Déployer sur Replit
Méthode de déploiement contribuée par la communauté.
Le déploiement sur Replit est maintenu par la communauté et convient aux démonstrations en ligne et aux essais légers.
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot)
### AUR
Le mode AUR s'adresse aux utilisateurs Arch Linux qui préfèrent installer AstrBot via le gestionnaire de paquets système.
Exécutez la commande ci-dessous pour installer `astrbot-git`, puis lancez AstrBot localement.
```bash
yay -S astrbot-git
```
**Autres méthodes de déploiement** : [Déploiement BT-Panel](https://astrbot.app/deploy/astrbot/btpanel.html) | [Déploiement 1Panel](https://astrbot.app/deploy/astrbot/1panel.html) | [Déploiement CasaOS](https://astrbot.app/deploy/astrbot/casaos.html) | [Déploiement manuel](https://astrbot.app/deploy/astrbot/cli.html)
**Autres méthodes de déploiement**
Si vous avez besoin d'une gestion par panneau ou d'une personnalisation plus poussée, consultez [Déploiement BT-Panel](https://astrbot.app/deploy/astrbot/btpanel.html) pour une installation via BT Panel, [Déploiement 1Panel](https://astrbot.app/deploy/astrbot/1panel.html) pour le marketplace 1Panel, [Déploiement CasaOS](https://astrbot.app/deploy/astrbot/casaos.html) pour un déploiement visuel sur NAS/serveur domestique, et [Déploiement manuel](https://astrbot.app/deploy/astrbot/cli.html) pour une installation complète depuis les sources avec `uv`.
## Plateformes de messagerie prises en charge
@@ -211,6 +222,7 @@ pre-commit install
- Groupe 5 : 822130018
- Groupe 6 : 753075035
- Groupe développeurs : 975206796
- Groupe développeurs (officiel) : 1039761811
### Serveur Discord
+27 -15
View File
@@ -73,7 +73,7 @@ AstrBot は、主要なインスタントメッセージングアプリと統合
### ワンクリックデプロイ
AstrBot を素早く試したいユーザーは、`uv` を使ったワンクリックデプロイをおすすめします ⚡️:
AstrBot を素早く試したいユーザーで、コマンドラインに慣れており `uv` 環境を自分でインストールできる場合は、`uv` ワンクリックデプロイをおすすめします ⚡️:
```bash
uv tool install astrbot
@@ -83,47 +83,58 @@ astrbot
> [uv](https://docs.astral.sh/uv/) のインストールが必要です。
> [!NOTE]
> macOS ユーザーの場合:macOS のセキュリティチェックにより、`astrbot` コマンドの初回実行に時間がかかる場合があります(約 10〜20 秒)。
`astrbot` の更新:
```bash
uv tool upgrade astrbot
```
### Docker デプロイ
より安定した本番向けのデプロイを求めるユーザーには、Docker / Docker Compose で AstrBot デプロイすることをおすすめします。
コンテナ運用に慣れており、より安定した本番向けのデプロイ方法を求めるユーザーには、Docker / Docker Compose で AstrBot デプロイをおすすめします。
公式ドキュメント [Docker を使用した AstrBot のデプロイ](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) をご参照ください。
### 雨云でのデプロイ
サーバー管理をせずに AstrBot をワンクリックでデプロイしたいユーザーには、雨云のワンクリッククラウドデプロイサービスをおすすめします ☁️:
AstrBot をワンクリックでデプロイしたく、サーバーを自分で管理したくないユーザーには、雨云のワンクリッククラウドデプロイサービスをおすすめします ☁️:
[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
### デスクトップクライアント(Tauri
### デスクトップアプリのデプロイ
デスクトップで AstrBot を使いたいユーザーで、主に AstrBot ChatUI を利用し、AstrBot プラグインの利用頻度が低い場合は、AstrBot App の利用をおすすめします:
デスクトップで AstrBot を使い、主に ChatUI を入口として利用するユーザーには、AstrBot App をおすすめします
デスクトップアプリのリポジトリ [AstrBot-desktop](https://github.com/AstrBotDevs/AstrBot-desktop)。
[AstrBot-desktop](https://github.com/AstrBotDevs/AstrBot-desktop) からダウンロードしてインストールしてください。この方式はデスクトップ向けであり、サーバー用途には推奨されません
マルチシステムアーキテクチャに対応し、インストーラーですぐ利用可能。初心者にも使いやすいワンクリックのデスクトップデプロイ方式です。サーバー用途には推奨されません。
### ランチャーのデプロイ
### ランチャーによるワンクリックデプロイ(AstrBot Launcher
同じくデスクトップで、素早くデプロイしつつ環境を分離して多重起動したいユーザーには、AstrBot Launcher をおすすめします。
高速デプロイと環境分離されたマルチインスタンス運用を求めるユーザーには、AstrBot Launcher の利用をおすすめします:
[AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher) リポジトリにアクセスし、最新リリースからお使いの OS 向けパッケージをインストールしてください。
高速デプロイと環境分離されたマルチインスタンス運用を実現できます。
[AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher) からダウンロードしてインストールしてください。
### Replit でのデプロイ
コミュニティ貢献によるデプロイ方法
Replit デプロイはコミュニティ提供の方式で、オンラインデモや軽量な試用に向いています
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot)
### AUR
AUR 方式は Arch Linux ユーザー向けで、システムのパッケージ運用に合わせて AstrBot を導入したい場合に適しています。
次のコマンドで `astrbot-git` をインストールし、ローカル環境で AstrBot を起動してください。
```bash
yay -S astrbot-git
```
**その他のデプロイ方法**[宝塔パネルデプロイ](https://astrbot.app/deploy/astrbot/btpanel.html) | [1Panel デプロイ](https://astrbot.app/deploy/astrbot/1panel.html) | [CasaOS デプロイ](https://astrbot.app/deploy/astrbot/casaos.html) | [手動デプロイ](https://astrbot.app/deploy/astrbot/cli.html)
**その他のデプロイ方法**
パネル操作での導入やより高度なカスタマイズが必要な場合は、[宝塔パネルデプロイ](https://astrbot.app/deploy/astrbot/btpanel.html)BT Panel 経由の導入)、[1Panel デプロイ](https://astrbot.app/deploy/astrbot/1panel.html)(1Panel アプリマーケット経由)、[CasaOS デプロイ](https://astrbot.app/deploy/astrbot/casaos.html)(NAS / ホームサーバー向け可視化導入)、[手動デプロイ](https://astrbot.app/deploy/astrbot/cli.html)`uv` とソースベースのフルカスタム導入)を参照してください。
## サポートされているメッセージプラットフォーム
@@ -212,6 +223,7 @@ pre-commit install
- 5群: 822130018
- 6群: 753075035
- 開発者群: 975206796
- 開発者群(正式): 1039761811
### Discord サーバー
+28 -16
View File
@@ -73,7 +73,7 @@ AstrBot — это универсальная платформа Agent-чатб
### Развёртывание в один клик
Для пользователей, которые хотят быстро попробовать AstrBot, мы рекомендуем использовать развёртывание в один клик через `uv` ⚡️:
Для пользователей, которые хотят быстро попробовать AstrBot, знакомы с командной строкой и могут самостоятельно установить окружение `uv`, мы рекомендуем использовать развёртывание в один клик через `uv` ⚡️:
```bash
uv tool install astrbot
@@ -83,47 +83,58 @@ astrbot
> Требуется установленный [uv](https://docs.astral.sh/uv/).
> [!NOTE]
> Для пользователей macOS: из-за проверок безопасности macOS первый запуск команды `astrbot` может занять больше времени (около 10-20 секунд).
Обновить `astrbot`:
```bash
uv tool upgrade astrbot
```
### Развёртывание Docker
Для пользователей, которым нужен более стабильный и готовый к production вариант, мы рекомендуем развёртывать AstrBot через Docker / Docker Compose.
Для пользователей, знакомых с контейнерами и которым нужен более стабильный и подходящий для production способ, мы рекомендуем разворачивать AstrBot через Docker / Docker Compose.
См. официальную документацию: [Развёртывание AstrBot с Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
См. официальную документацию [Развёртывание AstrBot с Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
### Развёртывание на RainYun
Для пользователей, которые хотят развернуть AstrBot в один клик и не управлять сервером самостоятельно, мы рекомендуем облачный сервис развёртывания в один клик от RainYun ☁️:
Для пользователей, которые хотят развернуть AstrBot в один клик и не хотят самостоятельно управлять сервером, мы рекомендуем облачный сервис развёртывания в один клик от RainYun ☁️:
[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
### Десктопное приложение (Tauri)
### Развёртывание десктопного приложения
Для пользователей, которые хотят использовать AstrBot на десктопе, в основном работают с AstrBot ChatUI и редко используют плагины AstrBot, мы рекомендуем AstrBot App:
Для пользователей, которые хотят использовать AstrBot на десктопе и в основном работают через ChatUI, мы рекомендуем AstrBot App.
Репозиторий десктопного приложения: [AstrBot-desktop](https://github.com/AstrBotDevs/AstrBot-desktop).
Перейдите в [AstrBot-desktop](https://github.com/AstrBotDevs/AstrBot-desktop), скачайте и установите приложение; этот вариант предназначен для десктопа и не рекомендуется для серверных сценариев.
Поддерживает разные архитектуры систем, устанавливается напрямую и работает сразу после установки. Удобное настольное развёртывание в один клик для новичков. Не рекомендуется для серверных сценариев.
### Развёртывание через лаунчер
### Установка в один клик через лаунчер (AstrBot Launcher)
Также на десктопе, для пользователей, которым нужен быстрый запуск и мультиинстанс с изоляцией окружений, мы рекомендуем AstrBot Launcher.
Для пользователей, которым нужно быстрое развёртывание и мультиинстанс с изоляцией окружений, мы рекомендуем использовать AstrBot Launcher:
Перейдите в репозиторий [AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher), откройте Releases и установите пакет для вашей системы из последней версии.
Быстрое развёртывание и мультиинстанс-решение с изоляцией окружений.
Перейдите в [AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher), чтобы скачать и установить.
### Развёртывание на Replit
Метод развёртывания от сообщества.
Развёртывание через Replit поддерживается сообществом и подходит для онлайн-демо и лёгких тестовых запусков.
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot)
### AUR
AUR-вариант предназначен для пользователей Arch Linux, которым удобна установка через системный менеджер пакетов.
Выполните команду ниже для установки `astrbot-git`, затем запустите AstrBot локально.
```bash
yay -S astrbot-git
```
**Другие способы развёртывания**: [Развёртывание BT-Panel](https://astrbot.app/deploy/astrbot/btpanel.html) | [Развёртывание 1Panel](https://astrbot.app/deploy/astrbot/1panel.html) | [Развёртывание CasaOS](https://astrbot.app/deploy/astrbot/casaos.html) | [Ручное развёртывание](https://astrbot.app/deploy/astrbot/cli.html)
**Другие способы развёртывания**
Если вам нужна панельная установка или более глубокая кастомизация, смотрите [Развёртывание BT-Panel](https://astrbot.app/deploy/astrbot/btpanel.html) (установка через BT Panel), [Развёртывание 1Panel](https://astrbot.app/deploy/astrbot/1panel.html) (развёртывание через маркетплейс 1Panel), [Развёртывание CasaOS](https://astrbot.app/deploy/astrbot/casaos.html) (визуальный вариант для NAS и домашних серверов) и [Ручное развёртывание](https://astrbot.app/deploy/astrbot/cli.html) (полностью настраиваемая установка из исходников через `uv`).
## Поддерживаемые платформы обмена сообщениями
@@ -211,6 +222,7 @@ pre-commit install
- Группа 5: 822130018
- Группа 6: 753075035
- Группа разработчиков: 975206796
- Группа разработчиков (официальная): 1039761811
### Сервер Discord
+32 -16
View File
@@ -73,7 +73,7 @@ AstrBot 是一個開源的一站式 Agent 聊天機器人平台,可接入主
### 一鍵部署
對於想快速體驗 AstrBot 的使用者,我們推薦使用 `uv` 一鍵部署方式 ⚡️
對於想快速體驗 AstrBot、且熟悉命令列並能自行安裝 `uv` 環境的使用者,我們推薦使用 `uv` 一鍵部署方式 ⚡️
```bash
uv tool install astrbot
@@ -83,11 +83,20 @@ astrbot
> 需要安裝 [uv](https://docs.astral.sh/uv/)。
> [!NOTE]
> 對於 macOS 使用者:由於 macOS 安全性檢查,首次執行 `astrbot` 指令可能需要較長時間(約 10-20 秒)。
更新 `astrbot`
```bash
uv tool upgrade astrbot
```
### Docker 部署
對於希望獲得更穩定更適合正式環境部署方式的使用者,我們推薦使用 Docker / Docker Compose 部署 AstrBot。
對於熟悉容器、希望獲得更穩定更適合正式環境部署方式的使用者,我們推薦使用 Docker / Docker Compose 部署 AstrBot。
請參官方文件 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot)。
請參官方文件 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot)。
### 在雨雲上部署
@@ -95,35 +104,37 @@ astrbot
[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
### 桌面客戶端Tauri
### 桌面客戶端部署
對於希望在桌面部署 AstrBot、以 AstrBot ChatUI 為主要使用方式、較少使用 AstrBot 外掛的使用者,我們推薦使用 AstrBot App
對於希望在桌面端使用 AstrBot、以 ChatUI 為主要入口的使用者,我們推薦使用 AstrBot App
桌面應用倉庫 [AstrBot-desktop](https://github.com/AstrBotDevs/AstrBot-desktop)。
前往 [AstrBot-desktop](https://github.com/AstrBotDevs/AstrBot-desktop) 下載並安裝;此方式面向桌面使用,不建議伺服器場景
支援多系統架構,安裝包直接安裝,開箱即用,最適合新手和懶人的一鍵桌面部署方案,不推薦伺服器場景。
### 啟動器部署
### 啟動器一鍵部署(AstrBot Launcher
同樣在桌面端,對於希望快速部署並實現環境隔離多開的使用者,我們推薦使用 AstrBot Launcher
對於希望快速部署並實現環境隔離多開的使用者,我們推薦使用 AstrBot Launcher
進入 [AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher) 倉庫,在 Releases 頁最新版本下找到對應的系統安裝包安裝即可。
一個快速部署和多開方案,實現環境隔離。
前往 [AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher) 下載並安裝。
### 在 Replit 上部署
社群貢獻的部署方式
Replit 部署由社群維護,適合線上示範與輕量試用情境
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot)
### AUR
AUR 方式面向 Arch Linux 使用者,適合希望透過系統套件管理器安裝 AstrBot 的場景。
在終端執行下方命令安裝 `astrbot-git` 套件,安裝完成後即可啟動使用。
```bash
yay -S astrbot-git
```
**更多部署方式**[寶塔面板](https://astrbot.app/deploy/astrbot/btpanel.html) | [1Panel](https://astrbot.app/deploy/astrbot/1panel.html) | [CasaOS](https://astrbot.app/deploy/astrbot/casaos.html) | [手動部署](https://astrbot.app/deploy/astrbot/cli.html)
**更多部署方式**
若你需要面板化或更高自訂程度的部署,可參考 [寶塔面板](https://astrbot.app/deploy/astrbot/btpanel.html)(BT Panel 應用商店安裝)、[1Panel](https://astrbot.app/deploy/astrbot/1panel.html)1Panel 應用商店安裝)、[CasaOS](https://astrbot.app/deploy/astrbot/casaos.html)(NAS / 家用伺服器可視化部署)與 [手動部署](https://astrbot.app/deploy/astrbot/cli.html)(基於原始碼與 `uv` 的完整自訂安裝)。
## 支援的訊息平台
@@ -206,11 +217,16 @@ pre-commit install
### QQ 群組
- 9 群: 1076659624 (新)
- 10 群: 1078079676 (新)
- 1 群:322154837
- 3 群:630166526
- 5 群:822130018
- 6 群:753075035
- 開發者群:975206796
- 7 群:743746109
- 8 群:1030353265
- 開發者群(闲聊吹水):975206796
- 開發者群(正式):1039761811
### Discord 群組
+30 -16
View File
@@ -73,7 +73,7 @@ AstrBot 是一个开源的一站式 Agentic 个人和群聊助手,可在 QQ、
### 一键部署
对于想快速体验 AstrBot 的用户,我们推荐使用 `uv` 一键部署方式 ⚡️
对于想快速体验 AstrBot、且熟悉命令行并能够自行安装 `uv` 环境的用户,我们推荐使用 `uv` 一键部署方式 ⚡️
```bash
uv tool install astrbot
@@ -83,11 +83,20 @@ astrbot
> 需要安装 [uv](https://docs.astral.sh/uv/)。
> [!NOTE]
> 对于 macOS 用户:由于 macOS 安全检查,首次运行 `astrbot` 命令可能需要较长时间(约 10-20 秒)。
更新 `astrbot`
```bash
uv tool upgrade astrbot
```
### Docker 部署
对于希望获得更稳定更适合生产环境部署方式的用户,我们推荐使用 Docker / Docker Compose 部署 AstrBot。
对于熟悉容器、希望获得更稳定更适合生产环境部署方式的用户,我们推荐使用 Docker / Docker Compose 部署 AstrBot。
请参官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot)
请参官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot)。
### 在 雨云 上部署
@@ -95,35 +104,37 @@ astrbot
[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
### 桌面客户端Tauri
### 桌面客户端部署
对于希望在桌面部署 AstrBot、以 AstrBot ChatUI 为主要使用方式、较少使用 AstrBot 插件的用户,我们推荐使用 AstrBot App
对于希望在桌面端使用 AstrBot、以 ChatUI 为主要入口的用户,我们推荐使用 AstrBot App
桌面应用仓库 [AstrBot-desktop](https://github.com/AstrBotDevs/AstrBot-desktop)。
前往 [AstrBot-desktop](https://github.com/AstrBotDevs/AstrBot-desktop) 下载并安装;该方式面向桌面使用,不推荐服务器场景
支持多系统架构,安装包直接安装,开箱即用,最适合新手和懒人的一键桌面部署方案,不推荐服务器场景。
### 启动器部署
### 启动器一键部署(AstrBot Launcher
同样在桌面端,希望快速部署并实现环境隔离多开的用户,我们推荐使用 AstrBot Launcher
对于希望快速部署并实现环境隔离多开的用户,我们推荐使用 AstrBot Launcher
进入 [AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher) 仓库,在 Releases 页最新版本下找到对应的系统安装包安装即可。
一个快速部署和多开方案,实现环境隔离。
前往 [AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher) 下载并安装。
### 在 Replit 上部署
社区贡献的部署方式
Replit 部署由社区维护,适合在线演示和轻量试用场景
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot)
### AUR
AUR 方式面向 Arch Linux 用户,适合希望通过系统包管理器安装 AstrBot 的场景。
在终端执行下方命令安装 `astrbot-git` 包,安装完成后即可启动使用。
```bash
yay -S astrbot-git
```
**更多部署方式**[宝塔面板](https://astrbot.app/deploy/astrbot/btpanel.html) | [1Panel](https://astrbot.app/deploy/astrbot/1panel.html) | [CasaOS](https://astrbot.app/deploy/astrbot/casaos.html) | [手动部署](https://astrbot.app/deploy/astrbot/cli.html)
**更多部署方式**
若你需要面板化或更高自定义部署,可参考 [宝塔面板](https://astrbot.app/deploy/astrbot/btpanel.html)(BT Panel 应用商店安装)、[1Panel](https://astrbot.app/deploy/astrbot/1panel.html)1Panel 应用商店安装)、[CasaOS](https://astrbot.app/deploy/astrbot/casaos.html)(NAS / 家庭服务器可视化部署)和 [手动部署](https://astrbot.app/deploy/astrbot/cli.html)(基于源码与 `uv` 的完整自定义安装)。
## 支持的消息平台
@@ -207,13 +218,16 @@ pre-commit install
### QQ 群组
- 9 群: 1076659624 (新)
- 10 群: 1078079676 (新)
- 1 群:322154837
- 3 群:630166526
- 5 群:822130018
- 6 群:753075035
- 7 群:743746109
- 8 群:1030353265
- 开发者群:975206796
- 开发者群(偏闲聊吹水)975206796
- 开发者群(正式):1039761811
### Discord 频道
@@ -1,15 +1,262 @@
from __future__ import annotations
import asyncio
import re
import time
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING
from astrbot import logger
from astrbot.api import star
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.core.provider.entities import ProviderType
from astrbot.core.utils.error_redaction import safe_error
if TYPE_CHECKING:
from astrbot.core.provider.provider import Provider
MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT = 30.0
MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT = 4
MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND = 16
MODEL_LIST_CACHE_TTL_KEY = "model_list_cache_ttl_seconds"
MODEL_LOOKUP_MAX_CONCURRENCY_KEY = "model_lookup_max_concurrency"
MODEL_CACHE_MAX_ENTRIES = 512
@dataclass(frozen=True)
class _ModelLookupConfig:
umo: str | None
cache_ttl_seconds: float
max_concurrency: int
class _ModelCache:
def __init__(self) -> None:
self._store: dict[tuple[str, str | None], tuple[float, list[str]]] = {}
def get(self, provider_id: str, umo: str | None, ttl: float) -> list[str] | None:
if ttl <= 0:
return None
entry = self._store.get((provider_id, umo))
if not entry:
return None
timestamp, models = entry
if time.monotonic() - timestamp > ttl:
self._store.pop((provider_id, umo), None)
return None
return models
def set(
self, provider_id: str, umo: str | None, models: list[str], ttl: float
) -> None:
if ttl <= 0:
return
self._store[(provider_id, umo)] = (time.monotonic(), list(models))
self._evict_if_needed()
def _evict_if_needed(self) -> None:
if len(self._store) <= MODEL_CACHE_MAX_ENTRIES:
return
# Drop oldest entries first when cache grows too large.
overflow = len(self._store) - MODEL_CACHE_MAX_ENTRIES
for key, _ in sorted(
self._store.items(),
key=lambda item: item[1][0],
)[:overflow]:
self._store.pop(key, None)
def invalidate(
self, provider_id: str | None = None, *, umo: str | None = None
) -> None:
if provider_id is None:
self._store.clear()
return
if umo is not None:
self._store.pop((provider_id, umo), None)
return
stale_keys = [
cache_key for cache_key in self._store if cache_key[0] == provider_id
]
for cache_key in stale_keys:
self._store.pop(cache_key, None)
class ProviderCommands:
def __init__(self, context: star.Context) -> None:
self.context = context
self._model_cache = _ModelCache()
self._register_provider_change_hook()
def _register_provider_change_hook(self) -> None:
set_change_callback = getattr(
self.context.provider_manager,
"set_provider_change_callback",
None,
)
if callable(set_change_callback):
set_change_callback(self._on_provider_manager_changed)
return
register_change_hook = getattr(
self.context.provider_manager,
"register_provider_change_hook",
None,
)
if callable(register_change_hook):
register_change_hook(self._on_provider_manager_changed)
def invalidate_provider_models_cache(
self, provider_id: str | None = None, *, umo: str | None = None
) -> None:
"""Public hook for cache invalidation on external provider config changes."""
self._model_cache.invalidate(provider_id, umo=umo)
def _on_provider_manager_changed(
self,
provider_id: str,
provider_type: ProviderType,
umo: str | None,
) -> None:
if provider_type == ProviderType.CHAT_COMPLETION:
self.invalidate_provider_models_cache(provider_id, umo=umo)
def _get_provider_settings(self, umo: str | None) -> dict:
if not umo:
return {}
try:
return self.context.get_config(umo).get("provider_settings", {}) or {}
except Exception as e:
logger.debug(
"读取 provider_settings 失败,使用默认值: %s",
safe_error("", e),
)
return {}
def _get_model_cache_ttl(self, umo: str | None) -> float:
settings = self._get_provider_settings(umo)
raw = settings.get(
MODEL_LIST_CACHE_TTL_KEY,
MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT,
)
try:
return max(float(raw), 0.0)
except Exception as e:
logger.debug(
"读取 %s 失败,回退默认值 %r: %s",
MODEL_LIST_CACHE_TTL_KEY,
MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT,
safe_error("", e),
)
return MODEL_LIST_CACHE_TTL_SECONDS_DEFAULT
def _get_model_lookup_concurrency(self, umo: str | None) -> int:
settings = self._get_provider_settings(umo)
raw = settings.get(
MODEL_LOOKUP_MAX_CONCURRENCY_KEY,
MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT,
)
try:
value = int(raw)
except Exception as e:
logger.debug(
"读取 %s 失败,回退默认值 %r: %s",
MODEL_LOOKUP_MAX_CONCURRENCY_KEY,
MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT,
safe_error("", e),
)
value = MODEL_LOOKUP_MAX_CONCURRENCY_DEFAULT
return min(max(value, 1), MODEL_LOOKUP_MAX_CONCURRENCY_UPPER_BOUND)
def _get_model_lookup_config(self, umo: str | None) -> _ModelLookupConfig:
return _ModelLookupConfig(
umo=umo,
cache_ttl_seconds=self._get_model_cache_ttl(umo),
max_concurrency=self._get_model_lookup_concurrency(umo),
)
def _resolve_model_name(
self,
model_name: str,
models: Sequence[str],
) -> str | None:
"""Resolve model name with precedence:
exact > case-insensitive > provider-qualified suffix.
"""
requested = model_name.strip()
if not requested:
return None
requested_norm = requested.casefold()
# exact / case-insensitive match
for candidate in models:
if candidate == requested or candidate.casefold() == requested_norm:
return candidate
# provider-qualified suffix match:
# e.g. candidate `openai/gpt-4o` should match requested `gpt-4o`.
for candidate in models:
cand_norm = candidate.casefold()
if cand_norm.endswith(f"/{requested_norm}") or cand_norm.endswith(
f":{requested_norm}"
):
return candidate
return None
def _apply_model(
self, prov: Provider, model_name: str, *, umo: str | None = None
) -> str:
prov.set_model(model_name)
self.invalidate_provider_models_cache(prov.meta().id, umo=umo)
return f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]"
async def _get_provider_models(
self,
provider: Provider,
*,
config: _ModelLookupConfig,
use_cache: bool = True,
) -> list[str]:
provider_id = provider.meta().id
ttl_seconds = config.cache_ttl_seconds
umo = config.umo
if use_cache:
cached = self._model_cache.get(provider_id, umo, ttl_seconds)
if cached is not None:
return cached
models = list(await provider.get_models())
if use_cache:
self._model_cache.set(provider_id, umo, models, ttl_seconds)
return models
async def _get_models_or_reply_error(
self,
message: AstrMessageEvent,
prov: Provider,
config: _ModelLookupConfig,
*,
error_prefix: str,
disable_t2i: bool = False,
warning_log: str | None = None,
) -> list[str] | None:
try:
return await self._get_provider_models(prov, config=config)
except asyncio.CancelledError:
raise
except Exception as e:
if warning_log is not None:
logger.warning(
warning_log,
prov.meta().id,
safe_error("", e),
)
result = MessageEventResult().message(safe_error(error_prefix, e))
if disable_t2i:
result = result.use_t2i(False)
message.set_result(result)
return None
def _log_reachability_failure(
self,
@@ -38,12 +285,96 @@ class ProviderCommands:
return True, None, None
except Exception as e:
err_code = "TEST_FAILED"
err_reason = str(e)
err_reason = safe_error("", e)
self._log_reachability_failure(
provider, provider_capability_type, err_code, err_reason
)
return False, err_code, err_reason
async def _find_provider_for_model(
self,
model_name: str,
*,
exclude_provider_id: str | None = None,
config: _ModelLookupConfig,
use_cache: bool = True,
) -> tuple[Provider | None, str | None]:
all_providers = []
for provider in self.context.get_all_providers():
provider_meta = provider.meta()
if provider_meta.provider_type != ProviderType.CHAT_COMPLETION:
continue
if (
exclude_provider_id is not None
and provider_meta.id == exclude_provider_id
):
continue
all_providers.append(provider)
if not all_providers:
return None, None
semaphore = asyncio.Semaphore(config.max_concurrency)
async def fetch_models(
provider: Provider,
) -> tuple[Provider, list[str] | None, str | None]:
async with semaphore:
try:
models = await self._get_provider_models(
provider,
config=config,
use_cache=use_cache,
)
return provider, models, None
except asyncio.CancelledError:
raise
except Exception as e:
err = safe_error("", e)
logger.debug(
"跨提供商查找模型 %s 获取 %s 模型列表失败: %s",
model_name,
provider.meta().id,
err,
)
return provider, None, err
results = await asyncio.gather(
*(fetch_models(provider) for provider in all_providers)
)
failed_provider_errors: list[tuple[str, str]] = []
for provider, models, err in results:
if err is not None:
failed_provider_errors.append((provider.meta().id, err))
continue
if models is None:
continue
matched_model_name = self._resolve_model_name(model_name, models)
if matched_model_name is not None:
return provider, matched_model_name
if failed_provider_errors and len(failed_provider_errors) == len(all_providers):
failed_ids = ",".join(
provider_id for provider_id, _ in failed_provider_errors
)
logger.error(
"跨提供商查找模型 %s 时,所有 %d 个提供商的 get_models() 均失败: %s。请检查配置或网络",
model_name,
len(all_providers),
failed_ids,
)
elif failed_provider_errors:
logger.debug(
"跨提供商查找模型 %s 时有 %d 个提供商获取模型失败: %s",
model_name,
len(failed_provider_errors),
",".join(
f"{provider_id}({error})"
for provider_id, error in failed_provider_errors
),
)
return None, None
async def provider(
self,
event: AstrMessageEvent,
@@ -92,13 +423,15 @@ class ProviderCommands:
id_ = meta.id
error_code = None
if isinstance(reachable, asyncio.CancelledError):
raise reachable
if isinstance(reachable, Exception):
# 异常情况下兜底处理,避免单个 provider 导致列表失败
self._log_reachability_failure(
p,
None,
reachable.__class__.__name__,
str(reachable),
safe_error("", reachable),
)
reachable_flag = False
error_code = reachable.__class__.__name__
@@ -224,6 +557,73 @@ class ProviderCommands:
else:
event.set_result(MessageEventResult().message("无效的参数。"))
async def _switch_model_by_name(
self, message: AstrMessageEvent, model_name: str, prov: Provider
) -> None:
model_name = model_name.strip()
if not model_name:
message.set_result(MessageEventResult().message("模型名不能为空。"))
return
umo = message.unified_msg_origin
config = self._get_model_lookup_config(umo)
curr_provider_id = prov.meta().id
models = await self._get_models_or_reply_error(
message,
prov,
config,
error_prefix="获取当前提供商模型列表失败: ",
warning_log="获取当前提供商 %s 模型列表失败,停止跨提供商查找: %s",
)
if models is None:
return
matched_model_name = self._resolve_model_name(model_name, models)
if matched_model_name is not None:
message.set_result(
MessageEventResult().message(
self._apply_model(prov, matched_model_name, umo=umo)
),
)
return
target_prov, matched_target_model_name = await self._find_provider_for_model(
model_name,
exclude_provider_id=curr_provider_id,
config=config,
)
if target_prov is None or matched_target_model_name is None:
message.set_result(
MessageEventResult().message(
f"模型 [{model_name}] 未在任何已配置的提供商中找到,或所有提供商模型列表获取失败,请检查配置或网络后重试。",
),
)
return
target_id = target_prov.meta().id
try:
await self.context.provider_manager.set_provider(
provider_id=target_id,
provider_type=ProviderType.CHAT_COMPLETION,
umo=umo,
)
self._apply_model(target_prov, matched_target_model_name, umo=umo)
message.set_result(
MessageEventResult().message(
f"检测到模型 [{matched_target_model_name}] 属于提供商 [{target_id}],已自动切换提供商并设置模型。",
),
)
except asyncio.CancelledError:
raise
except Exception as e:
message.set_result(
MessageEventResult().message(
safe_error("跨提供商切换并设置模型失败: ", e)
),
)
async def model_ls(
self,
message: AstrMessageEvent,
@@ -236,20 +636,17 @@ class ProviderCommands:
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。"),
)
return
# 定义正则表达式匹配 API 密钥
api_key_pattern = re.compile(r"key=[^&'\" ]+")
config = self._get_model_lookup_config(message.unified_msg_origin)
if idx_or_name is None:
models = []
try:
models = await prov.get_models()
except BaseException as e:
err_msg = api_key_pattern.sub("key=***", str(e))
message.set_result(
MessageEventResult()
.message("获取模型列表失败: " + err_msg)
.use_t2i(False),
)
models = await self._get_models_or_reply_error(
message,
prov,
config,
error_prefix="获取模型列表失败: ",
disable_t2i=True,
)
if models is None:
return
parts = ["下面列出了此模型提供商可用模型:"]
for i, model in enumerate(models, 1):
@@ -258,40 +655,43 @@ class ProviderCommands:
curr_model = prov.get_model() or ""
parts.append(f"\n当前模型: [{curr_model}]")
parts.append(
"\nTips: 使用 /model <模型名/编号>,即可实时更换模型。如目标模型不存在于上表,请输入模型名"
"\nTips: 使用 /model <模型名/编号> 切换模型。输入模型名时可自动跨提供商查找并切换;跨提供商也可使用 /provider 切换"
)
ret = "".join(parts)
message.set_result(MessageEventResult().message(ret).use_t2i(False))
elif isinstance(idx_or_name, int):
models = []
try:
models = await prov.get_models()
except BaseException as e:
message.set_result(
MessageEventResult().message("获取模型列表失败: " + str(e)),
)
models = await self._get_models_or_reply_error(
message,
prov,
config,
error_prefix="获取模型列表失败: ",
)
if models is None:
return
if idx_or_name > len(models) or idx_or_name < 1:
message.set_result(MessageEventResult().message("模型序号错误。"))
else:
try:
new_model = models[idx_or_name - 1]
prov.set_model(new_model)
except BaseException as e:
message.set_result(
MessageEventResult().message("切换模型未知错误: " + str(e)),
MessageEventResult().message(
self._apply_model(
prov,
new_model,
umo=message.unified_msg_origin,
)
),
)
message.set_result(
MessageEventResult().message(
f"切换模型成功。当前提供商: [{prov.meta().id}] 当前模型: [{prov.get_model()}]",
),
)
except Exception as e:
message.set_result(
MessageEventResult().message(
safe_error("切换模型未知错误: ", e)
),
)
return
else:
prov.set_model(idx_or_name)
message.set_result(
MessageEventResult().message(f"切换模型到 {prov.get_model()}"),
)
await self._switch_model_by_name(message, idx_or_name, prov)
async def key(self, message: AstrMessageEvent, index: int | None = None) -> None:
prov = self.context.get_using_provider(message.unified_msg_origin)
@@ -322,8 +722,15 @@ class ProviderCommands:
try:
new_key = keys_data[index - 1]
prov.set_key(new_key)
except BaseException as e:
message.set_result(
MessageEventResult().message(f"切换 Key 未知错误: {e!s}"),
self.invalidate_provider_models_cache(
prov.meta().id,
umo=message.unified_msg_origin,
)
message.set_result(MessageEventResult().message("切换 Key 成功。"))
message.set_result(MessageEventResult().message("切换 Key 成功。"))
except Exception as e:
message.set_result(
MessageEventResult().message(
safe_error("切换 Key 未知错误: ", e)
),
)
return
+1 -10
View File
@@ -8,7 +8,7 @@ from bs4 import BeautifulSoup
from readability import Document
from astrbot.api import AstrBotConfig, llm_tool, logger, sp, star
from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter
from astrbot.api.event import AstrMessageEvent, filter
from astrbot.api.provider import ProviderRequest
from astrbot.core.provider.func_tool_manager import FunctionToolManager
@@ -196,15 +196,6 @@ class Main(star.Star):
)
return results
@filter.command("websearch")
async def websearch(self, event: AstrMessageEvent, oper: str | None = None) -> None:
"""网页搜索指令(已废弃)"""
event.set_result(
MessageEventResult().message(
"此指令已经被废弃,请在 WebUI 中开启或关闭网页搜索功能。",
),
)
@llm_tool(name="web_search")
async def search_from_search_engine(
self,
+1 -1
View File
@@ -1 +1 @@
__version__ = "4.18.3"
__version__ = "4.20.0"
+7 -7
View File
@@ -1,4 +1,4 @@
"""AstrBot CLI入口"""
"""AstrBot CLI entry point"""
import sys
@@ -29,23 +29,23 @@ def cli() -> None:
@click.command()
@click.argument("command_name", required=False, type=str)
def help(command_name: str | None) -> None:
"""显示命令的帮助信息
"""Display help information for commands
如果提供了 COMMAND_NAME,则显示该命令的详细帮助信息。
否则,显示通用帮助信息。
If COMMAND_NAME is provided, display detailed help for that command.
Otherwise, display general help information.
"""
ctx = click.get_current_context()
if command_name:
# 查找指定命令
# Find the specified command
command = cli.get_command(ctx, command_name)
if command:
# 显示特定命令的帮助信息
# Display help for the specific command
click.echo(command.get_help(ctx))
else:
click.echo(f"Unknown command: {command_name}")
sys.exit(1)
else:
# 显示通用帮助信息
# Display general help information
click.echo(cli.get_help(ctx))
+47 -43
View File
@@ -10,57 +10,61 @@ from ..utils import check_astrbot_root, get_astrbot_root
def _validate_log_level(value: str) -> str:
"""验证日志级别"""
"""Validate log level"""
value = value.upper()
if value not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
raise click.ClickException(
"日志级别必须是 DEBUG/INFO/WARNING/ERROR/CRITICAL 之一",
"Log level must be one of DEBUG/INFO/WARNING/ERROR/CRITICAL",
)
return value
def _validate_dashboard_port(value: str) -> int:
"""验证 Dashboard 端口"""
"""Validate Dashboard port"""
try:
port = int(value)
if port < 1 or port > 65535:
raise click.ClickException("端口必须在 1-65535 范围内")
raise click.ClickException("Port must be in range 1-65535")
return port
except ValueError:
raise click.ClickException("端口必须是数字")
raise click.ClickException("Port must be a number")
def _validate_dashboard_username(value: str) -> str:
"""验证 Dashboard 用户名"""
"""Validate Dashboard username"""
if not value:
raise click.ClickException("用户名不能为空")
raise click.ClickException("Username cannot be empty")
return value
def _validate_dashboard_password(value: str) -> str:
"""验证 Dashboard 密码"""
"""Validate Dashboard password"""
if not value:
raise click.ClickException("密码不能为空")
raise click.ClickException("Password cannot be empty")
return hashlib.md5(value.encode()).hexdigest()
def _validate_timezone(value: str) -> str:
"""验证时区"""
"""Validate timezone"""
try:
zoneinfo.ZoneInfo(value)
except Exception:
raise click.ClickException(f"无效的时区: {value},请使用有效的IANA时区名称")
raise click.ClickException(
f"Invalid timezone: {value}. Please use a valid IANA timezone name"
)
return value
def _validate_callback_api_base(value: str) -> str:
"""验证回调接口基址"""
"""Validate callback API base URL"""
if not value.startswith("http://") and not value.startswith("https://"):
raise click.ClickException("回调接口基址必须以 http:// 或 https:// 开头")
raise click.ClickException(
"Callback API base must start with http:// or https://"
)
return value
# 可通过CLI设置的配置项,配置键到验证器函数的映射
# Configuration items settable via CLI, mapping config keys to validator functions
CONFIG_VALIDATORS: dict[str, Callable[[str], Any]] = {
"timezone": _validate_timezone,
"log_level": _validate_log_level,
@@ -72,11 +76,11 @@ CONFIG_VALIDATORS: dict[str, Callable[[str], Any]] = {
def _load_config() -> dict[str, Any]:
"""加载或初始化配置文件"""
"""Load or initialize config file"""
root = get_astrbot_root()
if not check_astrbot_root(root):
raise click.ClickException(
f"{root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init",
f"{root} is not a valid AstrBot root directory. Use 'astrbot init' to initialize",
)
config_path = root / "data" / "cmd_config.json"
@@ -91,11 +95,11 @@ def _load_config() -> dict[str, Any]:
try:
return json.loads(config_path.read_text(encoding="utf-8-sig"))
except json.JSONDecodeError as e:
raise click.ClickException(f"配置文件解析失败: {e!s}")
raise click.ClickException(f"Failed to parse config file: {e!s}")
def _save_config(config: dict[str, Any]) -> None:
"""保存配置文件"""
"""Save config file"""
config_path = get_astrbot_root() / "data" / "cmd_config.json"
config_path.write_text(
@@ -105,21 +109,21 @@ def _save_config(config: dict[str, Any]) -> None:
def _set_nested_item(obj: dict[str, Any], path: str, value: Any) -> None:
"""设置嵌套字典中的值"""
"""Set a value in a nested dictionary"""
parts = path.split(".")
for part in parts[:-1]:
if part not in obj:
obj[part] = {}
elif not isinstance(obj[part], dict):
raise click.ClickException(
f"配置路径冲突: {'.'.join(parts[: parts.index(part) + 1])} 不是字典",
f"Config path conflict: {'.'.join(parts[: parts.index(part) + 1])} is not a dict",
)
obj = obj[part]
obj[parts[-1]] = value
def _get_nested_item(obj: dict[str, Any], path: str) -> Any:
"""获取嵌套字典中的值"""
"""Get a value from a nested dictionary"""
parts = path.split(".")
for part in parts:
obj = obj[part]
@@ -128,21 +132,21 @@ def _get_nested_item(obj: dict[str, Any], path: str) -> Any:
@click.group(name="conf")
def conf() -> None:
"""配置管理命令
"""Configuration management commands
支持的配置项:
Supported config keys:
- timezone: 时区设置 (例如: Asia/Shanghai)
- timezone: Timezone setting (e.g. Asia/Shanghai)
- log_level: 日志级别 (DEBUG/INFO/WARNING/ERROR/CRITICAL)
- log_level: Log level (DEBUG/INFO/WARNING/ERROR/CRITICAL)
- dashboard.port: Dashboard 端口
- dashboard.port: Dashboard port
- dashboard.username: Dashboard 用户名
- dashboard.username: Dashboard username
- dashboard.password: Dashboard 密码
- dashboard.password: Dashboard password
- callback_api_base: 回调接口基址
- callback_api_base: Callback API base URL
"""
@@ -150,9 +154,9 @@ def conf() -> None:
@click.argument("key")
@click.argument("value")
def set_config(key: str, value: str) -> None:
"""设置配置项的值"""
"""Set the value of a config item"""
if key not in CONFIG_VALIDATORS:
raise click.ClickException(f"不支持的配置项: {key}")
raise click.ClickException(f"Unsupported config key: {key}")
config = _load_config()
@@ -162,29 +166,29 @@ def set_config(key: str, value: str) -> None:
_set_nested_item(config, key, validated_value)
_save_config(config)
click.echo(f"配置已更新: {key}")
click.echo(f"Config updated: {key}")
if key == "dashboard.password":
click.echo(" 原值: ********")
click.echo(" 新值: ********")
click.echo(" Old value: ********")
click.echo(" New value: ********")
else:
click.echo(f" 原值: {old_value}")
click.echo(f" 新值: {validated_value}")
click.echo(f" Old value: {old_value}")
click.echo(f" New value: {validated_value}")
except KeyError:
raise click.ClickException(f"未知的配置项: {key}")
raise click.ClickException(f"Unknown config key: {key}")
except Exception as e:
raise click.UsageError(f"设置配置失败: {e!s}")
raise click.UsageError(f"Failed to set config: {e!s}")
@conf.command(name="get")
@click.argument("key", required=False)
def get_config(key: str | None = None) -> None:
"""获取配置项的值,不提供key则显示所有可配置项"""
"""Get the value of a config item. If no key is provided, show all configurable items"""
config = _load_config()
if key:
if key not in CONFIG_VALIDATORS:
raise click.ClickException(f"不支持的配置项: {key}")
raise click.ClickException(f"Unsupported config key: {key}")
try:
value = _get_nested_item(config, key)
@@ -192,11 +196,11 @@ def get_config(key: str | None = None) -> None:
value = "********"
click.echo(f"{key}: {value}")
except KeyError:
raise click.ClickException(f"未知的配置项: {key}")
raise click.ClickException(f"Unknown config key: {key}")
except Exception as e:
raise click.UsageError(f"获取配置失败: {e!s}")
raise click.UsageError(f"Failed to get config: {e!s}")
else:
click.echo("当前配置:")
click.echo("Current config:")
for key in CONFIG_VALIDATORS:
try:
value = (
+8 -9
View File
@@ -8,16 +8,12 @@ from ..utils import check_dashboard, get_astrbot_root
async def initialize_astrbot(astrbot_root: Path) -> None:
"""执行 AstrBot 初始化逻辑"""
"""Execute AstrBot initialization logic"""
dot_astrbot = astrbot_root / ".astrbot"
if not dot_astrbot.exists():
click.echo(f"Current Directory: {astrbot_root}")
click.echo(
"如果你确认这是 Astrbot root directory, 你需要在当前目录下创建一个 .astrbot 文件标记该目录为 AstrBot 的数据目录。",
)
if click.confirm(
f"请检查当前目录是否正确,确认正确请回车: {astrbot_root}",
f"Install AstrBot to this directory? {astrbot_root}",
default=True,
abort=True,
):
@@ -40,7 +36,7 @@ async def initialize_astrbot(astrbot_root: Path) -> None:
@click.command()
def init() -> None:
"""初始化 AstrBot"""
"""Initialize AstrBot"""
click.echo("Initializing AstrBot...")
astrbot_root = get_astrbot_root()
lock_file = astrbot_root / "astrbot.lock"
@@ -49,8 +45,11 @@ def init() -> None:
try:
with lock.acquire():
asyncio.run(initialize_astrbot(astrbot_root))
click.echo("Done! You can now run 'astrbot run' to start AstrBot")
except Timeout:
raise click.ClickException("无法获取锁文件,请检查是否有其他实例正在运行")
raise click.ClickException(
"Cannot acquire lock file. Please check if another instance is running"
)
except Exception as e:
raise click.ClickException(f"初始化失败: {e!s}")
raise click.ClickException(f"Initialization failed: {e!s}")
+54 -46
View File
@@ -16,14 +16,14 @@ from ..utils import (
@click.group()
def plug() -> None:
"""插件管理"""
"""Plugin management"""
def _get_data_path() -> Path:
base = get_astrbot_root()
if not check_astrbot_root(base):
raise click.ClickException(
f"{base}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init",
f"{base} is not a valid AstrBot root directory. Use 'astrbot init' to initialize",
)
return (base / "data").resolve()
@@ -32,7 +32,9 @@ def display_plugins(plugins, title=None, color=None) -> None:
if title:
click.echo(click.style(title, fg=color, bold=True))
click.echo(f"{'名称':<20} {'版本':<10} {'状态':<10} {'作者':<15} {'描述':<30}")
click.echo(
f"{'Name':<20} {'Version':<10} {'Status':<10} {'Author':<15} {'Description':<30}"
)
click.echo("-" * 85)
for p in plugins:
@@ -46,30 +48,30 @@ def display_plugins(plugins, title=None, color=None) -> None:
@plug.command()
@click.argument("name")
def new(name: str) -> None:
"""创建新插件"""
"""Create a new plugin"""
base_path = _get_data_path()
plug_path = base_path / "plugins" / name
if plug_path.exists():
raise click.ClickException(f"插件 {name} 已存在")
raise click.ClickException(f"Plugin {name} already exists")
author = click.prompt("请输入插件作者", type=str)
desc = click.prompt("请输入插件描述", type=str)
version = click.prompt("请输入插件版本", type=str)
author = click.prompt("Enter plugin author", type=str)
desc = click.prompt("Enter plugin description", type=str)
version = click.prompt("Enter plugin version", type=str)
if not re.match(r"^\d+\.\d+(\.\d+)?$", version.lower().lstrip("v")):
raise click.ClickException("版本号必须为 x.y x.y.z 格式")
repo = click.prompt("请输入插件仓库:", type=str)
raise click.ClickException("Version must be in x.y or x.y.z format")
repo = click.prompt("Enter plugin repository URL:", type=str)
if not repo.startswith("http"):
raise click.ClickException("仓库地址必须以 http 开头")
raise click.ClickException("Repository URL must start with http")
click.echo("下载插件模板...")
click.echo("Downloading plugin template...")
get_git_repo(
"https://github.com/Soulter/helloworld",
plug_path,
)
click.echo("重写插件信息...")
# 重写 metadata.yaml
click.echo("Rewriting plugin metadata...")
# Rewrite metadata.yaml
with open(plug_path / "metadata.yaml", "w", encoding="utf-8") as f:
f.write(
f"name: {name}\n"
@@ -79,11 +81,13 @@ def new(name: str) -> None:
f"repo: {repo}\n",
)
# 重写 README.md
# Rewrite README.md
with open(plug_path / "README.md", "w", encoding="utf-8") as f:
f.write(f"# {name}\n\n{desc}\n\n# 支持\n\n[帮助文档](https://astrbot.app)\n")
f.write(
f"# {name}\n\n{desc}\n\n# Support\n\n[Documentation](https://astrbot.app)\n"
)
# 重写 main.py
# Rewrite main.py
with open(plug_path / "main.py", encoding="utf-8") as f:
content = f.read()
@@ -95,54 +99,54 @@ def new(name: str) -> None:
with open(plug_path / "main.py", "w", encoding="utf-8") as f:
f.write(new_content)
click.echo(f"插件 {name} 创建成功")
click.echo(f"Plugin {name} created successfully")
@plug.command()
@click.option("--all", "-a", is_flag=True, help="列出未安装的插件")
@click.option("--all", "-a", is_flag=True, help="List uninstalled plugins")
def list(all: bool) -> None:
"""列出插件"""
"""List plugins"""
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
# 未发布的插件
# Unpublished plugins
not_published_plugins = [
p for p in plugins if p["status"] == PluginStatus.NOT_PUBLISHED
]
if not_published_plugins:
display_plugins(not_published_plugins, "未发布的插件", "red")
display_plugins(not_published_plugins, "Unpublished Plugins", "red")
# 需要更新的插件
# Plugins needing update
need_update_plugins = [
p for p in plugins if p["status"] == PluginStatus.NEED_UPDATE
]
if need_update_plugins:
display_plugins(need_update_plugins, "需要更新的插件", "yellow")
display_plugins(need_update_plugins, "Plugins Needing Update", "yellow")
# 已安装的插件
# Installed plugins
installed_plugins = [p for p in plugins if p["status"] == PluginStatus.INSTALLED]
if installed_plugins:
display_plugins(installed_plugins, "已安装的插件", "green")
display_plugins(installed_plugins, "Installed Plugins", "green")
# 未安装的插件
# Uninstalled plugins
not_installed_plugins = [
p for p in plugins if p["status"] == PluginStatus.NOT_INSTALLED
]
if not_installed_plugins and all:
display_plugins(not_installed_plugins, "未安装的插件", "blue")
display_plugins(not_installed_plugins, "Uninstalled Plugins", "blue")
if (
not any([not_published_plugins, need_update_plugins, installed_plugins])
and not all
):
click.echo("未安装任何插件")
click.echo("No plugins installed")
@plug.command()
@click.argument("name")
@click.option("--proxy", help="代理服务器地址")
@click.option("--proxy", help="Proxy server address")
def install(name: str, proxy: str | None) -> None:
"""安装插件"""
"""Install a plugin"""
base_path = _get_data_path()
plug_path = base_path / "plugins"
plugins = build_plug_list(base_path / "plugins")
@@ -157,7 +161,7 @@ def install(name: str, proxy: str | None) -> None:
)
if not plugin:
raise click.ClickException(f"未找到可安装的插件 {name},可能是不存在或已安装")
raise click.ClickException(f"Plugin {name} not found or already installed")
manage_plugin(plugin, plug_path, is_update=False, proxy=proxy)
@@ -165,30 +169,32 @@ def install(name: str, proxy: str | None) -> None:
@plug.command()
@click.argument("name")
def remove(name: str) -> None:
"""卸载插件"""
"""Uninstall a plugin"""
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
plugin = next((p for p in plugins if p["name"] == name), None)
if not plugin or not plugin.get("local_path"):
raise click.ClickException(f"插件 {name} 不存在或未安装")
raise click.ClickException(f"Plugin {name} does not exist or is not installed")
plugin_path = plugin["local_path"]
click.confirm(f"确定要卸载插件 {name} 吗?", default=False, abort=True)
click.confirm(
f"Are you sure you want to uninstall plugin {name}?", default=False, abort=True
)
try:
shutil.rmtree(plugin_path)
click.echo(f"插件 {name} 已卸载")
click.echo(f"Plugin {name} has been uninstalled")
except Exception as e:
raise click.ClickException(f"卸载插件 {name} 失败: {e}")
raise click.ClickException(f"Failed to uninstall plugin {name}: {e}")
@plug.command()
@click.argument("name", required=False)
@click.option("--proxy", help="Github代理地址")
@click.option("--proxy", help="GitHub proxy address")
def update(name: str, proxy: str | None) -> None:
"""更新插件"""
"""Update plugins"""
base_path = _get_data_path()
plug_path = base_path / "plugins"
plugins = build_plug_list(base_path / "plugins")
@@ -204,7 +210,9 @@ def update(name: str, proxy: str | None) -> None:
)
if not plugin:
raise click.ClickException(f"插件 {name} 不需要更新或无法更新")
raise click.ClickException(
f"Plugin {name} does not need updating or cannot be updated"
)
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
else:
@@ -213,20 +221,20 @@ def update(name: str, proxy: str | None) -> None:
]
if not need_update_plugins:
click.echo("没有需要更新的插件")
click.echo("No plugins need updating")
return
click.echo(f"发现 {len(need_update_plugins)} 个插件需要更新")
click.echo(f"Found {len(need_update_plugins)} plugin(s) needing update")
for plugin in need_update_plugins:
plugin_name = plugin["name"]
click.echo(f"正在更新插件 {plugin_name}...")
click.echo(f"Updating plugin {plugin_name}...")
manage_plugin(plugin, plug_path, is_update=True, proxy=proxy)
@plug.command()
@click.argument("query")
def search(query: str) -> None:
"""搜索插件"""
"""Search for plugins"""
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
@@ -239,7 +247,7 @@ def search(query: str) -> None:
]
if not matched_plugins:
click.echo(f"未找到匹配 '{query}' 的插件")
click.echo(f"No plugins matching '{query}' found")
return
display_plugins(matched_plugins, f"搜索结果: '{query}'", "cyan")
display_plugins(matched_plugins, f"Search results: '{query}'", "cyan")
+11 -9
View File
@@ -11,7 +11,7 @@ from ..utils import check_astrbot_root, check_dashboard, get_astrbot_root
async def run_astrbot(astrbot_root: Path) -> None:
"""运行 AstrBot"""
"""Run AstrBot"""
from astrbot.core import LogBroker, LogManager, db_helper, logger
from astrbot.core.initial_loader import InitialLoader
@@ -26,18 +26,18 @@ async def run_astrbot(astrbot_root: Path) -> None:
await core_lifecycle.start()
@click.option("--reload", "-r", is_flag=True, help="插件自动重载")
@click.option("--port", "-p", help="Astrbot Dashboard端口", required=False, type=str)
@click.option("--reload", "-r", is_flag=True, help="Auto-reload plugins")
@click.option("--port", "-p", help="AstrBot Dashboard port", required=False, type=str)
@click.command()
def run(reload: bool, port: str) -> None:
"""运行 AstrBot"""
"""Run AstrBot"""
try:
os.environ["ASTRBOT_CLI"] = "1"
astrbot_root = get_astrbot_root()
if not check_astrbot_root(astrbot_root):
raise click.ClickException(
f"{astrbot_root}不是有效的 AstrBot 根目录,如需初始化请使用 astrbot init",
f"{astrbot_root} is not a valid AstrBot root directory. Use 'astrbot init' to initialize",
)
os.environ["ASTRBOT_ROOT"] = str(astrbot_root)
@@ -47,7 +47,7 @@ def run(reload: bool, port: str) -> None:
os.environ["DASHBOARD_PORT"] = port
if reload:
click.echo("启用插件自动重载")
click.echo("Plugin auto-reload enabled")
os.environ["ASTRBOT_RELOAD"] = "1"
lock_file = astrbot_root / "astrbot.lock"
@@ -55,8 +55,10 @@ def run(reload: bool, port: str) -> None:
with lock.acquire():
asyncio.run(run_astrbot(astrbot_root))
except KeyboardInterrupt:
click.echo("AstrBot 已关闭...")
click.echo("AstrBot has been shut down.")
except Timeout:
raise click.ClickException("无法获取锁文件,请检查是否有其他实例正在运行")
raise click.ClickException(
"Cannot acquire lock file. Please check if another instance is running"
)
except Exception as e:
raise click.ClickException(f"运行时出现错误: {e}\n{traceback.format_exc()}")
raise click.ClickException(f"Runtime error: {e}\n{traceback.format_exc()}")
+21 -13
View File
@@ -2,9 +2,12 @@ from pathlib import Path
import click
# Static assets bundled inside the installed wheel (built by hatch_build.py).
_BUNDLED_DIST = Path(__file__).parent.parent.parent / "dashboard" / "dist"
def check_astrbot_root(path: str | Path) -> bool:
"""检查路径是否为 AstrBot 根目录"""
"""Check if the path is an AstrBot root directory"""
if not isinstance(path, Path):
path = Path(path)
if not path.exists() or not path.is_dir():
@@ -15,43 +18,48 @@ def check_astrbot_root(path: str | Path) -> bool:
def get_astrbot_root() -> Path:
"""获取Astrbot根目录路径"""
"""Get the AstrBot root directory path"""
return Path.cwd()
async def check_dashboard(astrbot_root: Path) -> None:
"""检查是否安装了dashboard"""
"""Check if the dashboard is installed"""
from astrbot.core.config.default import VERSION
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
from .version_comparator import VersionComparator
# If the wheel ships bundled dashboard assets, no network download is needed.
if _BUNDLED_DIST.exists():
click.echo("Dashboard is bundled with the package skipping download.")
return
try:
dashboard_version = await get_dashboard_version()
match dashboard_version:
case None:
click.echo("未安装管理面板")
click.echo("Dashboard is not installed")
if click.confirm(
"是否安装管理面板?",
"Install dashboard?",
default=True,
abort=True,
):
click.echo("正在安装管理面板...")
click.echo("Installing dashboard...")
await download_dashboard(
path="data/dashboard.zip",
extract_path=str(astrbot_root),
version=f"v{VERSION}",
latest=False,
)
click.echo("管理面板安装完成")
click.echo("Dashboard installed successfully")
case str():
if VersionComparator.compare_version(VERSION, dashboard_version) <= 0:
click.echo("管理面板已是最新版本")
click.echo("Dashboard is already up to date")
return
try:
version = dashboard_version.split("v")[1]
click.echo(f"管理面板版本: {version}")
click.echo(f"Dashboard version: {version}")
await download_dashboard(
path="data/dashboard.zip",
extract_path=str(astrbot_root),
@@ -59,10 +67,10 @@ async def check_dashboard(astrbot_root: Path) -> None:
latest=False,
)
except Exception as e:
click.echo(f"下载管理面板失败: {e}")
click.echo(f"Failed to download dashboard: {e}")
return
except FileNotFoundError:
click.echo("初始化管理面板目录...")
click.echo("Initializing dashboard directory...")
try:
await download_dashboard(
path=str(astrbot_root / "dashboard.zip"),
@@ -70,7 +78,7 @@ async def check_dashboard(astrbot_root: Path) -> None:
version=f"v{VERSION}",
latest=False,
)
click.echo("管理面板初始化完成")
click.echo("Dashboard initialized successfully")
except Exception as e:
click.echo(f"下载管理面板失败: {e}")
click.echo(f"Failed to download dashboard: {e}")
return
+47 -43
View File
@@ -13,22 +13,22 @@ from .version_comparator import VersionComparator
class PluginStatus(str, Enum):
INSTALLED = "已安装"
NEED_UPDATE = "需更新"
NOT_INSTALLED = "未安装"
NOT_PUBLISHED = "未发布"
INSTALLED = "installed"
NEED_UPDATE = "needs-update"
NOT_INSTALLED = "not-installed"
NOT_PUBLISHED = "unpublished"
def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None:
"""从 Git 仓库下载代码并解压到指定路径"""
"""Download code from a Git repository and extract to the specified path"""
temp_dir = Path(tempfile.mkdtemp())
try:
# 解析仓库信息
# Parse repository info
repo_namespace = url.split("/")[-2:]
author = repo_namespace[0]
repo = repo_namespace[1]
# 尝试获取最新的 release
# Try to get the latest release
release_url = f"https://api.github.com/repos/{author}/{repo}/releases"
try:
with httpx.Client(
@@ -40,21 +40,21 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None:
releases = resp.json()
if releases:
# 使用最新的 release
# Use the latest release
download_url = releases[0]["zipball_url"]
else:
# 没有 release,使用默认分支
click.echo(f"正在从默认分支下载 {author}/{repo}")
# No release found, use default branch
click.echo(f"Downloading {author}/{repo} from default branch")
download_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
except Exception as e:
click.echo(f"获取 release 信息失败: {e},将直接使用提供的 URL")
click.echo(f"Failed to get release info: {e}. Using provided URL directly")
download_url = url
# 应用代理
# Apply proxy
if proxy:
download_url = f"{proxy}/{download_url}"
# 下载并解压
# Download and extract
with httpx.Client(
proxy=proxy if proxy else None,
follow_redirects=True,
@@ -65,7 +65,7 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None:
and "archive/refs/heads/master.zip" in download_url
):
alt_url = download_url.replace("master.zip", "main.zip")
click.echo("master 分支不存在,尝试下载 main 分支")
click.echo("Branch 'master' not found, trying 'main' branch")
resp = client.get(alt_url)
resp.raise_for_status()
else:
@@ -84,13 +84,13 @@ def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None:
def load_yaml_metadata(plugin_dir: Path) -> dict:
""" metadata.yaml 文件加载插件元数据
"""Load plugin metadata from metadata.yaml file
Args:
plugin_dir: 插件目录路径
plugin_dir: Plugin directory path
Returns:
dict: 包含元数据的字典,如果读取失败则返回空字典
dict: Dictionary containing metadata, or empty dict if loading fails
"""
yaml_path = plugin_dir / "metadata.yaml"
@@ -98,33 +98,33 @@ def load_yaml_metadata(plugin_dir: Path) -> dict:
try:
return yaml.safe_load(yaml_path.read_text(encoding="utf-8")) or {}
except Exception as e:
click.echo(f"读取 {yaml_path} 失败: {e}", err=True)
click.echo(f"Failed to read {yaml_path}: {e}", err=True)
return {}
def build_plug_list(plugins_dir: Path) -> list:
"""构建插件列表,包含本地和在线插件信息
"""Build plugin list containing local and online plugin information
Args:
plugins_dir (Path): 插件目录路径
plugins_dir (Path): Plugin directory path
Returns:
list: 包含插件信息的字典列表
list: List of dicts containing plugin information
"""
# 获取本地插件信息
# Get local plugin info
result = []
if plugins_dir.exists():
for plugin_name in [d.name for d in plugins_dir.glob("*") if d.is_dir()]:
plugin_dir = plugins_dir / plugin_name
# metadata.yaml 加载元数据
# Load metadata from metadata.yaml
metadata = load_yaml_metadata(plugin_dir)
if "desc" not in metadata and "description" in metadata:
metadata["desc"] = metadata["description"]
# 如果成功加载元数据,添加到结果列表
# If metadata loaded successfully, add to result list
if metadata and all(
k in metadata for k in ["name", "desc", "version", "author", "repo"]
):
@@ -140,7 +140,7 @@ def build_plug_list(plugins_dir: Path) -> list:
},
)
# 获取在线插件列表
# Get online plugin list
online_plugins = []
try:
with httpx.Client() as client:
@@ -160,13 +160,13 @@ def build_plug_list(plugins_dir: Path) -> list:
},
)
except Exception as e:
click.echo(f"获取在线插件列表失败: {e}", err=True)
click.echo(f"Failed to get online plugin list: {e}", err=True)
# 与在线插件比对,更新状态
# Compare with online plugins and update status
online_plugin_names = {plugin["name"] for plugin in online_plugins}
for local_plugin in result:
if local_plugin["name"] in online_plugin_names:
# 查找对应的在线插件
# Find the corresponding online plugin
online_plugin = next(
p for p in online_plugins if p["name"] == local_plugin["name"]
)
@@ -179,10 +179,10 @@ def build_plug_list(plugins_dir: Path) -> list:
):
local_plugin["status"] = PluginStatus.NEED_UPDATE
else:
# 本地插件未在线上发布
# Local plugin is not published online
local_plugin["status"] = PluginStatus.NOT_PUBLISHED
# 添加未安装的在线插件
# Add uninstalled online plugins
for online_plugin in online_plugins:
if not any(plugin["name"] == online_plugin["name"] for plugin in result):
result.append(online_plugin)
@@ -196,19 +196,19 @@ def manage_plugin(
is_update: bool = False,
proxy: str | None = None,
) -> None:
"""安装或更新插件
"""Install or update a plugin
Args:
plugin (dict): 插件信息字典
plugins_dir (Path): 插件目录
is_update (bool, optional): 是否为更新操作. 默认为 False
proxy (str, optional): 代理服务器地址
plugin (dict): Plugin info dict
plugins_dir (Path): Plugins directory
is_update (bool, optional): Whether this is an update operation. Defaults to False
proxy (str, optional): Proxy server address
"""
plugin_name = plugin["name"]
repo_url = plugin["repo"]
# 如果是更新且有本地路径,直接使用本地路径
# If updating and local path exists, use it directly
if is_update and plugin.get("local_path"):
target_path = Path(plugin["local_path"])
else:
@@ -216,11 +216,13 @@ def manage_plugin(
backup_path = Path(f"{target_path}_backup") if is_update else None
# 检查插件是否存在
# Check if plugin exists
if is_update and not target_path.exists():
raise click.ClickException(f"插件 {plugin_name} 未安装,无法更新")
raise click.ClickException(
f"Plugin {plugin_name} is not installed and cannot be updated"
)
# 备份现有插件
# Backup existing plugin
if is_update and backup_path is not None and backup_path.exists():
shutil.rmtree(backup_path)
if is_update and backup_path is not None:
@@ -228,19 +230,21 @@ def manage_plugin(
try:
click.echo(
f"正在从 {repo_url} {'更新' if is_update else '下载'}插件 {plugin_name}...",
f"{'Updating' if is_update else 'Downloading'} plugin {plugin_name} from {repo_url}...",
)
get_git_repo(repo_url, target_path, proxy)
# 更新成功,删除备份
# Update succeeded, delete backup
if is_update and backup_path is not None and backup_path.exists():
shutil.rmtree(backup_path)
click.echo(f"插件 {plugin_name} {'更新' if is_update else '安装'}成功")
click.echo(
f"Plugin {plugin_name} {'updated' if is_update else 'installed'} successfully"
)
except Exception as e:
if target_path.exists():
shutil.rmtree(target_path, ignore_errors=True)
if is_update and backup_path is not None and backup_path.exists():
shutil.move(backup_path, target_path)
raise click.ClickException(
f"{'更新' if is_update else '安装'}插件 {plugin_name} 时出错: {e}",
f"Error {'updating' if is_update else 'installing'} plugin {plugin_name}: {e}",
)
+11 -11
View File
@@ -1,4 +1,4 @@
"""拷贝自 astrbot.core.utils.version_comparator"""
"""Copied from astrbot.core.utils.version_comparator"""
import re
@@ -6,11 +6,11 @@ import re
class VersionComparator:
@staticmethod
def compare_version(v1: str, v2: str) -> int:
"""根据 Semver 语义版本规范来比较版本号的大小。支持不仅局限于 3 个数字的版本号,并处理预发布标签。
"""Compare version numbers according to Semver semantics. Supports version numbers with more than 3 digits and handles pre-release tags.
参考: https://semver.org/lang/zh-CN/
Reference: https://semver.org/
返回 1 表示 v1 > v2,返回 -1 表示 v1 < v2,返回 0 表示 v1 = v2
Returns 1 if v1 > v2, -1 if v1 < v2, 0 if v1 == v2.
"""
v1 = v1.lower().replace("v", "")
v2 = v2.lower().replace("v", "")
@@ -24,7 +24,7 @@ class VersionComparator:
return [], None
major_minor_patch = match.group(1).split(".")
prerelease = match.group(2)
# buildmetadata = match.group(3) # 构建元数据在比较时忽略
# buildmetadata = match.group(3) # Build metadata is ignored in comparison
parts = [int(x) for x in major_minor_patch]
prerelease = VersionComparator._split_prerelease(prerelease)
return parts, prerelease
@@ -32,7 +32,7 @@ class VersionComparator:
v1_parts, v1_prerelease = split_version(v1)
v2_parts, v2_prerelease = split_version(v2)
# 比较数字部分
# Compare numeric parts
length = max(len(v1_parts), len(v2_parts))
v1_parts.extend([0] * (length - len(v1_parts)))
v2_parts.extend([0] * (length - len(v2_parts)))
@@ -43,11 +43,11 @@ class VersionComparator:
if v1_parts[i] < v2_parts[i]:
return -1
# 比较预发布标签
# Compare pre-release tags
if v1_prerelease is None and v2_prerelease is not None:
return 1 # 没有预发布标签的版本高于有预发布标签的版本
return 1 # Version without pre-release tag is higher than one with it
if v1_prerelease is not None and v2_prerelease is None:
return -1 # 有预发布标签的版本低于没有预发布标签的版本
return -1 # Version with pre-release tag is lower than one without it
if v1_prerelease is not None and v2_prerelease is not None:
len_pre = max(len(v1_prerelease), len(v2_prerelease))
for i in range(len_pre):
@@ -72,9 +72,9 @@ class VersionComparator:
return 1
if p1 < p2:
return -1
return 0 # 预发布标签完全相同
return 0 # Pre-release tags are identical
return 0 # 数字部分和预发布标签都相同
return 0 # Both numeric parts and pre-release tags are equal
@staticmethod
def _split_prerelease(prerelease):
+16 -2
View File
@@ -4,7 +4,21 @@ from astrbot.core.config import AstrBotConfig
from astrbot.core.config.default import DB_PATH
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.file_token_service import FileTokenService
from astrbot.core.utils.pip_installer import PipInstaller
from astrbot.core.utils.pip_installer import (
DependencyConflictError as DependencyConflictError,
)
from astrbot.core.utils.pip_installer import (
PipInstaller,
)
from astrbot.core.utils.requirements_utils import (
RequirementsPrecheckFailed as RequirementsPrecheckFailed,
)
from astrbot.core.utils.requirements_utils import (
find_missing_requirements as find_missing_requirements,
)
from astrbot.core.utils.requirements_utils import (
find_missing_requirements_or_raise as find_missing_requirements_or_raise,
)
from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.utils.t2i.renderer import HtmlRenderer
@@ -14,7 +28,7 @@ from .utils.astrbot_path import get_astrbot_data_path
# 初始化数据存储文件夹
os.makedirs(get_astrbot_data_path(), exist_ok=True)
DEMO_MODE = os.getenv("DEMO_MODE", False)
DEMO_MODE = os.getenv("DEMO_MODE", "False").strip().lower() in ("true", "1", "t")
astrbot_config = AstrBotConfig()
t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img")
+19 -6
View File
@@ -144,10 +144,14 @@ class MCPClient:
cfg = _prepare_config(mcp_server_config.copy())
def logging_callback(msg: str) -> None:
def logging_callback(
msg: str | mcp.types.LoggingMessageNotificationParams,
) -> None:
# Handle MCP service error logs
print(f"MCP Server {name} Error: {msg}")
self.server_errlogs.append(msg)
if isinstance(msg, mcp.types.LoggingMessageNotificationParams):
if msg.level in ("warning", "error", "critical", "alert", "emergency"):
log_msg = f"[{msg.level.upper()}] {str(msg.data)}"
self.server_errlogs.append(log_msg)
if "url" in cfg:
success, error_msg = await _quick_test_mcp_connection(cfg)
@@ -214,15 +218,24 @@ class MCPClient:
**cfg,
)
def callback(msg: str) -> None:
def callback(msg: str | mcp.types.LoggingMessageNotificationParams) -> None:
# Handle MCP service error logs
self.server_errlogs.append(msg)
if isinstance(msg, mcp.types.LoggingMessageNotificationParams):
if msg.level in (
"warning",
"error",
"critical",
"alert",
"emergency",
):
log_msg = f"[{msg.level.upper()}] {str(msg.data)}"
self.server_errlogs.append(log_msg)
stdio_transport = await self.exit_stack.enter_async_context(
mcp.stdio_client(
server_params,
errlog=LogPipe(
level=logging.ERROR,
level=logging.INFO,
logger=logger,
identifier=f"MCPServer-{name}",
callback=callback,
@@ -302,7 +302,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
while True:
try:
item_type, item_data = await asyncio.get_event_loop().run_in_executor(
item_type, item_data = await asyncio.get_running_loop().run_in_executor(
None, response_queue.get, True, 1
)
except queue.Empty:
@@ -388,7 +388,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
# 发起请求
partial = functools.partial(Application.call, **payload)
response = await asyncio.get_event_loop().run_in_executor(None, partial)
response = await asyncio.get_running_loop().run_in_executor(None, partial)
async for resp in self._handle_streaming_response(response, session_id):
yield resp
+5 -3
View File
@@ -291,6 +291,9 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
except Exception:
continue
prov_settings: dict = ctx.get_config(umo=umo).get("provider_settings", {})
agent_max_step = int(prov_settings.get("max_agent_step", 30))
stream = prov_settings.get("streaming_response", False)
llm_resp = await ctx.tool_loop_agent(
event=event,
chat_provider_id=prov_id,
@@ -299,9 +302,8 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
system_prompt=tool.agent.instructions,
tools=toolset,
contexts=contexts,
max_steps=30,
run_hooks=tool.agent.run_hooks,
stream=ctx.get_config().get("provider_settings", {}).get("stream", False),
max_steps=agent_max_step,
stream=stream,
)
yield mcp.types.CallToolResult(
content=[mcp.types.TextContent(type="text", text=llm_resp.completion_text)]
+72 -2
View File
@@ -20,18 +20,32 @@ from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS
from astrbot.core.astr_agent_run_util import AgentRunner
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
from astrbot.core.astr_main_agent_resources import (
ANNOTATE_EXECUTION_TOOL,
BROWSER_BATCH_EXEC_TOOL,
BROWSER_EXEC_TOOL,
CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT,
CREATE_SKILL_CANDIDATE_TOOL,
CREATE_SKILL_PAYLOAD_TOOL,
EVALUATE_SKILL_CANDIDATE_TOOL,
EXECUTE_SHELL_TOOL,
FILE_DOWNLOAD_TOOL,
FILE_UPLOAD_TOOL,
GET_EXECUTION_HISTORY_TOOL,
GET_SKILL_PAYLOAD_TOOL,
KNOWLEDGE_BASE_QUERY_TOOL,
LIST_SKILL_CANDIDATES_TOOL,
LIST_SKILL_RELEASES_TOOL,
LIVE_MODE_SYSTEM_PROMPT,
LLM_SAFETY_MODE_SYSTEM_PROMPT,
LOCAL_EXECUTE_SHELL_TOOL,
LOCAL_PYTHON_TOOL,
PROMOTE_SKILL_CANDIDATE_TOOL,
PYTHON_TOOL,
ROLLBACK_SKILL_RELEASE_TOOL,
RUN_BROWSER_SKILL_TOOL,
SANDBOX_MODE_PROMPT,
SEND_MESSAGE_TO_USER_TOOL,
SYNC_SKILL_RELEASE_TOOL,
TOOL_CALL_PROMPT,
TOOL_CALL_PROMPT_SKILLS_LIKE_MODE,
retrieve_knowledge_base,
@@ -832,7 +846,10 @@ def _apply_sandbox_tools(
) -> None:
if req.func_tool is None:
req.func_tool = ToolSet()
if config.sandbox_cfg.get("booter") == "shipyard":
if req.system_prompt is None:
req.system_prompt = ""
booter = config.sandbox_cfg.get("booter", "shipyard_neo")
if booter == "shipyard":
ep = config.sandbox_cfg.get("shipyard_endpoint", "")
at = config.sandbox_cfg.get("shipyard_access_token", "")
if not ep or not at:
@@ -840,11 +857,64 @@ def _apply_sandbox_tools(
return
os.environ["SHIPYARD_ENDPOINT"] = ep
os.environ["SHIPYARD_ACCESS_TOKEN"] = at
req.func_tool.add_tool(EXECUTE_SHELL_TOOL)
req.func_tool.add_tool(PYTHON_TOOL)
req.func_tool.add_tool(FILE_UPLOAD_TOOL)
req.func_tool.add_tool(FILE_DOWNLOAD_TOOL)
req.system_prompt = f"{req.system_prompt}\n{SANDBOX_MODE_PROMPT}\n"
if booter == "shipyard_neo":
# Neo-specific path rule: filesystem tools operate relative to sandbox
# workspace root. Do not prepend "/workspace".
req.system_prompt += (
"\n[Shipyard Neo File Path Rule]\n"
"When using sandbox filesystem tools (upload/download/read/write/list/delete), "
"always pass paths relative to the sandbox workspace root. "
"Example: use `baidu_homepage.png` instead of `/workspace/baidu_homepage.png`.\n"
)
req.system_prompt += (
"\n[Neo Skill Lifecycle Workflow]\n"
"When user asks to create/update a reusable skill in Neo mode, use lifecycle tools instead of directly writing local skill folders.\n"
"Preferred sequence:\n"
"1) Use `astrbot_create_skill_payload` to store canonical payload content and get `payload_ref`.\n"
"2) Use `astrbot_create_skill_candidate` with `skill_key` + `source_execution_ids` (and optional `payload_ref`) to create a candidate.\n"
"3) Use `astrbot_promote_skill_candidate` to release: `stage=canary` for trial; `stage=stable` for production.\n"
"For stable release, set `sync_to_local=true` to sync `payload.skill_markdown` into local `SKILL.md`.\n"
"Do not treat ad-hoc generated files as reusable Neo skills unless they are captured via payload/candidate/release.\n"
"To update an existing skill, create a new payload/candidate and promote a new release version; avoid patching old local folders directly.\n"
)
# Determine sandbox capabilities from an already-booted session.
# If no session exists yet (first request), capabilities is None
# and we register all tools conservatively.
from astrbot.core.computer.computer_client import session_booter
sandbox_capabilities: list[str] | None = None
existing_booter = session_booter.get(session_id)
if existing_booter is not None:
sandbox_capabilities = getattr(existing_booter, "capabilities", None)
# Browser tools: only register if profile supports browser
# (or if capabilities are unknown because sandbox hasn't booted yet)
if sandbox_capabilities is None or "browser" in sandbox_capabilities:
req.func_tool.add_tool(BROWSER_EXEC_TOOL)
req.func_tool.add_tool(BROWSER_BATCH_EXEC_TOOL)
req.func_tool.add_tool(RUN_BROWSER_SKILL_TOOL)
# Neo-specific tools (always available for shipyard_neo)
req.func_tool.add_tool(GET_EXECUTION_HISTORY_TOOL)
req.func_tool.add_tool(ANNOTATE_EXECUTION_TOOL)
req.func_tool.add_tool(CREATE_SKILL_PAYLOAD_TOOL)
req.func_tool.add_tool(GET_SKILL_PAYLOAD_TOOL)
req.func_tool.add_tool(CREATE_SKILL_CANDIDATE_TOOL)
req.func_tool.add_tool(LIST_SKILL_CANDIDATES_TOOL)
req.func_tool.add_tool(EVALUATE_SKILL_CANDIDATE_TOOL)
req.func_tool.add_tool(PROMOTE_SKILL_CANDIDATE_TOOL)
req.func_tool.add_tool(LIST_SKILL_RELEASES_TOOL)
req.func_tool.add_tool(ROLLBACK_SKILL_RELEASE_TOOL)
req.func_tool.add_tool(SYNC_SKILL_RELEASE_TOOL)
req.system_prompt = f"{req.system_prompt or ''}\n{SANDBOX_MODE_PROMPT}\n"
def _proactive_cron_job_tools(req: ProviderRequest) -> None:
+42 -1
View File
@@ -13,11 +13,25 @@ from astrbot.core.agent.tool import FunctionTool, ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.computer.computer_client import get_booter
from astrbot.core.computer.tools import (
AnnotateExecutionTool,
BrowserBatchExecTool,
BrowserExecTool,
CreateSkillCandidateTool,
CreateSkillPayloadTool,
EvaluateSkillCandidateTool,
ExecuteShellTool,
FileDownloadTool,
FileUploadTool,
GetExecutionHistoryTool,
GetSkillPayloadTool,
ListSkillCandidatesTool,
ListSkillReleasesTool,
LocalPythonTool,
PromoteSkillCandidateTool,
PythonTool,
RollbackSkillReleaseTool,
RunBrowserSkillTool,
SyncSkillReleaseTool,
)
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform.message_session import MessageSession
@@ -190,7 +204,7 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
"type": "string",
"description": (
"Component type. One of: "
"plain, image, record, file, mention_user"
"plain, image, record, video, file, mention_user. Record is voice message."
),
},
"text": {
@@ -306,6 +320,19 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
components.append(Comp.Record.fromURL(url=url))
else:
return f"error: messages[{idx}] must include path or url for record component."
elif msg_type == "video":
path = msg.get("path")
url = msg.get("url")
if path:
(
local_path,
file_from_sandbox,
) = await self._resolve_path_from_sandbox(context, path)
components.append(Comp.Video.fromFileSystem(path=local_path))
elif url:
components.append(Comp.Video.fromURL(url=url))
else:
return f"error: messages[{idx}] must include path or url for video component."
elif msg_type == "file":
path = msg.get("path")
url = msg.get("url")
@@ -449,6 +476,20 @@ PYTHON_TOOL = PythonTool()
LOCAL_PYTHON_TOOL = LocalPythonTool()
FILE_UPLOAD_TOOL = FileUploadTool()
FILE_DOWNLOAD_TOOL = FileDownloadTool()
BROWSER_EXEC_TOOL = BrowserExecTool()
BROWSER_BATCH_EXEC_TOOL = BrowserBatchExecTool()
RUN_BROWSER_SKILL_TOOL = RunBrowserSkillTool()
GET_EXECUTION_HISTORY_TOOL = GetExecutionHistoryTool()
ANNOTATE_EXECUTION_TOOL = AnnotateExecutionTool()
CREATE_SKILL_PAYLOAD_TOOL = CreateSkillPayloadTool()
GET_SKILL_PAYLOAD_TOOL = GetSkillPayloadTool()
CREATE_SKILL_CANDIDATE_TOOL = CreateSkillCandidateTool()
LIST_SKILL_CANDIDATES_TOOL = ListSkillCandidatesTool()
EVALUATE_SKILL_CANDIDATE_TOOL = EvaluateSkillCandidateTool()
PROMOTE_SKILL_CANDIDATE_TOOL = PromoteSkillCandidateTool()
LIST_SKILL_RELEASES_TOOL = ListSkillReleasesTool()
ROLLBACK_SKILL_RELEASE_TOOL = RollbackSkillReleaseTool()
SYNC_SKILL_RELEASE_TOOL = SyncSkillReleaseTool()
# we prevent astrbot from connecting to known malicious hosts
# these hosts are base64 encoded
+188 -3
View File
@@ -12,7 +12,7 @@ import os
import shutil
import zipfile
from dataclasses import dataclass, field
from datetime import datetime
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING, Any
@@ -61,6 +61,69 @@ def _get_major_version(version_str: str) -> str:
CMD_CONFIG_FILE_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
KB_PATH = get_astrbot_knowledge_base_path()
DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = 5
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV = (
"ASTRBOT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT"
)
def _load_platform_stats_invalid_count_warn_limit() -> int:
raw_value = os.getenv(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV)
if raw_value is None:
return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
try:
value = int(raw_value)
if value < 0:
raise ValueError("negative")
return value
except (TypeError, ValueError):
logger.warning(
"Invalid env %s=%r, fallback to default %d",
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT_ENV,
raw_value,
DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT,
)
return DEFAULT_PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT
PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT = (
_load_platform_stats_invalid_count_warn_limit()
)
class _InvalidCountWarnLimiter:
"""Rate-limit warnings for invalid platform_stats count values."""
def __init__(self, limit: int) -> None:
self.limit = limit
self._count = 0
self._suppression_logged = False
def warn_invalid_count(self, value: Any, key_for_log: tuple[Any, ...]) -> None:
if self.limit > 0:
if self._count < self.limit:
logger.warning(
"platform_stats count 非法,已按 0 处理: value=%r, key=%s",
value,
key_for_log,
)
self._count += 1
if self._count == self.limit and not self._suppression_logged:
logger.warning(
"platform_stats 非法 count 告警已达到上限 (%d),后续将抑制",
self.limit,
)
self._suppression_logged = True
return
if not self._suppression_logged:
# limit <= 0: emit only one suppression warning.
logger.warning(
"platform_stats 非法 count 告警已达到上限 (%d),后续将抑制",
self.limit,
)
self._suppression_logged = True
@dataclass
@@ -138,6 +201,10 @@ class ImportResult:
}
class DatabaseClearError(RuntimeError):
"""Raised when clearing the main database in replace mode fails."""
class AstrBotImporter:
"""AstrBot 数据导入器
@@ -342,6 +409,9 @@ class AstrBotImporter:
imported = await self._import_main_database(main_data)
result.imported_tables.update(imported)
except DatabaseClearError as e:
result.add_error(f"清空主数据库失败: {e}")
return result
except Exception as e:
result.add_error(f"导入主数据库失败: {e}")
return result
@@ -452,7 +522,9 @@ class AstrBotImporter:
await session.execute(delete(model_class))
logger.debug(f"已清空表 {table_name}")
except Exception as e:
logger.warning(f"清空表 {table_name} 失败: {e}")
raise DatabaseClearError(
f"清空表 {table_name} 失败: {e}"
) from e
async def _clear_kb_data(self) -> None:
"""清空知识库数据"""
@@ -494,9 +566,10 @@ class AstrBotImporter:
if not model_class:
logger.warning(f"未知的表: {table_name}")
continue
normalized_rows = self._preprocess_main_table_rows(table_name, rows)
count = 0
for row in rows:
for row in normalized_rows:
try:
# 转换 datetime 字符串为 datetime 对象
row = self._convert_datetime_fields(row, model_class)
@@ -511,6 +584,118 @@ class AstrBotImporter:
return imported
def _preprocess_main_table_rows(
self, table_name: str, rows: list[dict[str, Any]]
) -> list[dict[str, Any]]:
if table_name == "platform_stats":
normalized_rows = self._merge_platform_stats_rows(rows)
duplicate_count = len(rows) - len(normalized_rows)
if duplicate_count > 0:
logger.warning(
"检测到 %s 重复键 %d 条,已在导入前聚合",
table_name,
duplicate_count,
)
return normalized_rows
return rows
def _merge_platform_stats_rows(
self, rows: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""Merge duplicate platform_stats rows by normalized timestamp/platform key.
Note:
- Invalid/empty timestamps are kept as distinct rows to avoid accidental merging.
- Non-string platform_id/platform_type are kept as distinct rows.
- Invalid count warnings are rate-limited per function invocation.
"""
merged: dict[tuple[str, str, str], dict[str, Any]] = {}
result: list[dict[str, Any]] = []
warn_limiter = _InvalidCountWarnLimiter(PLATFORM_STATS_INVALID_COUNT_WARN_LIMIT)
for row in rows:
normalized_row, normalized_timestamp, count = (
self._normalize_platform_stats_entry(row, warn_limiter)
)
platform_id = normalized_row.get("platform_id")
platform_type = normalized_row.get("platform_type")
if (
normalized_timestamp is None
or not isinstance(platform_id, str)
or not isinstance(platform_type, str)
):
result.append(normalized_row)
continue
merge_key = (normalized_timestamp, platform_id, platform_type)
existing = merged.get(merge_key)
if existing is None:
merged[merge_key] = normalized_row
result.append(normalized_row)
else:
existing["count"] += count
return result
def _normalize_platform_stats_entry(
self,
row: dict[str, Any],
warn_limiter: _InvalidCountWarnLimiter,
) -> tuple[dict[str, Any], str | None, int]:
normalized_row = dict(row)
raw_timestamp = normalized_row.get("timestamp")
normalized_timestamp = self._normalize_platform_stats_timestamp(raw_timestamp)
if normalized_timestamp is not None:
normalized_row["timestamp"] = normalized_timestamp
elif isinstance(raw_timestamp, str):
normalized_row["timestamp"] = raw_timestamp.strip()
elif raw_timestamp is None:
normalized_row["timestamp"] = ""
else:
normalized_row["timestamp"] = str(raw_timestamp)
raw_count = normalized_row.get("count", 0)
try:
count = int(raw_count)
except (TypeError, ValueError):
key_for_log = (
normalized_row.get("timestamp"),
repr(normalized_row.get("platform_id")),
repr(normalized_row.get("platform_type")),
)
warn_limiter.warn_invalid_count(raw_count, key_for_log)
count = 0
normalized_row["count"] = count
return normalized_row, normalized_timestamp, count
def _normalize_platform_stats_timestamp(self, value: Any) -> str | None:
if isinstance(value, datetime):
dt = value
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
else:
dt = dt.astimezone(timezone.utc)
return dt.isoformat()
if isinstance(value, str):
timestamp = value.strip()
if not timestamp:
return None
if timestamp.endswith("Z"):
timestamp = f"{timestamp[:-1]}+00:00"
try:
dt = datetime.fromisoformat(timestamp)
if dt.tzinfo is None:
dt = dt.replace(tzinfo=timezone.utc)
else:
dt = dt.astimezone(timezone.utc)
return dt.isoformat()
except ValueError:
return None
return None
async def _import_knowledge_bases(
self,
zf: zipfile.ZipFile,
+19 -1
View File
@@ -1,4 +1,9 @@
from ..olayer import FileSystemComponent, PythonComponent, ShellComponent
from ..olayer import (
BrowserComponent,
FileSystemComponent,
PythonComponent,
ShellComponent,
)
class ComputerBooter:
@@ -11,6 +16,19 @@ class ComputerBooter:
@property
def shell(self) -> ShellComponent: ...
@property
def capabilities(self) -> tuple[str, ...] | None:
"""Sandbox capabilities (e.g. ('python', 'shell', 'filesystem', 'browser')).
Returns None if the booter doesn't support capability introspection
(backward-compatible default). Subclasses override after boot.
"""
return None
@property
def browser(self) -> BrowserComponent | None:
return None
async def boot(self, session_id: str) -> None: ...
async def shutdown(self) -> None: ...
@@ -0,0 +1,259 @@
"""Manage Bay container lifecycle for zero-config Shipyard Neo integration.
When no Bay endpoint is configured, AstrBot can automatically start a Bay
container using the Docker socket (like BoxliteBooter does for Ship
containers).
"""
from __future__ import annotations
import asyncio
import io
import json
import tarfile
from typing import Any
import aiodocker
import aiohttp
from astrbot.api import logger
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
BAY_IMAGE = "ghcr.io/astrbotdevs/shipyard-neo-bay:latest"
BAY_CONTAINER_NAME = "astrbot-bay"
BAY_LABEL = "astrbot.bay.managed"
BAY_PORT = 8114
HEALTH_TIMEOUT_S = 60
HEALTH_POLL_INTERVAL_S = 2
class BayContainerManager:
"""Start / reuse / stop a Bay container via Docker Engine API."""
def __init__(
self,
image: str = BAY_IMAGE,
host_port: int = BAY_PORT,
) -> None:
self._image = image
self._host_port = host_port
self._docker: aiodocker.Docker | None = None
self._container: Any = None
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
async def ensure_running(self) -> str:
"""Make sure a Bay container is running. Returns the endpoint URL.
If a container labelled ``astrbot.bay.managed`` already exists
and is running, it will be reused. Otherwise a new container is
created from *self._image*.
"""
try:
self._docker = aiodocker.Docker()
except Exception as exc:
raise RuntimeError(
"Failed to connect to Docker daemon. "
"Ensure Docker is installed and running, or configure "
"an explicit Bay endpoint instead of auto-start mode."
) from exc
# 1. Look for an existing managed container
existing = await self._find_managed_container()
if existing is not None:
state = existing["State"]
if state.get("Running"):
cid = existing["Id"][:12]
logger.info("[BayManager] Reusing existing Bay container: %s", cid)
self._container = await self._docker.containers.get(existing["Id"])
return f"http://127.0.0.1:{self._host_port}"
else:
# Container exists but stopped — restart it
logger.info("[BayManager] Restarting stopped Bay container")
container = await self._docker.containers.get(existing["Id"])
await container.start()
self._container = container
return f"http://127.0.0.1:{self._host_port}"
# 2. Pull image if needed
await self._pull_image_if_needed()
# 3. Create and start container
logger.info(
"[BayManager] Starting Bay container: image=%s, port=%d",
self._image,
self._host_port,
)
config = {
"Image": self._image,
"Labels": {BAY_LABEL: "true"},
"Env": [
"BAY_SERVER__HOST=0.0.0.0",
f"BAY_SERVER__PORT={BAY_PORT}",
"BAY_DATA_DIR=/app/data",
# allow_anonymous=false → auto-provisions API key
"BAY_SECURITY__ALLOW_ANONYMOUS=false",
],
"HostConfig": {
"PortBindings": {
f"{BAY_PORT}/tcp": [{"HostPort": str(self._host_port)}],
},
"Binds": [
# Bay needs Docker socket to create sandbox containers
"/var/run/docker.sock:/var/run/docker.sock",
],
"RestartPolicy": {"Name": "unless-stopped"},
},
}
self._container = await self._docker.containers.create_or_replace(
BAY_CONTAINER_NAME, config
)
await self._container.start()
logger.info("[BayManager] Bay container started: %s", BAY_CONTAINER_NAME)
return f"http://127.0.0.1:{self._host_port}"
async def wait_healthy(self, timeout: int = HEALTH_TIMEOUT_S) -> None:
"""Block until Bay's ``/health`` endpoint returns 200."""
url = f"http://127.0.0.1:{self._host_port}/health"
loop = asyncio.get_running_loop()
deadline = loop.time() + timeout
last_error: str = ""
async with aiohttp.ClientSession() as session:
while loop.time() < deadline:
try:
async with session.get(
url, timeout=aiohttp.ClientTimeout(total=3)
) as resp:
if resp.status == 200:
logger.info("[BayManager] Bay is healthy")
return
last_error = f"HTTP {resp.status}"
except Exception as exc:
last_error = str(exc)
await asyncio.sleep(HEALTH_POLL_INTERVAL_S)
raise TimeoutError(
f"Bay did not become healthy within {timeout}s (last error: {last_error})"
)
async def read_credentials(self) -> str:
"""Read auto-provisioned API key from Bay container.
Bay writes ``credentials.json`` to its data directory when
``allow_anonymous=false`` and no explicit API key is set.
"""
if self._container is None:
return ""
try:
# Read credentials.json from container filesystem
tar_stream = await self._container.get_archive("/app/data/credentials.json")
# get_archive returns (tar_data, stat)
tar_data = tar_stream
if isinstance(tar_data, dict):
raw = tar_data.get("data", b"")
elif isinstance(tar_data, tuple):
# (stream, stat_info)
raw = b""
stream = tar_data[0]
if hasattr(stream, "read"):
raw = await stream.read()
elif isinstance(stream, bytes):
raw = stream
else:
# It might be a chunked response
chunks = []
async for chunk in stream:
chunks.append(chunk)
raw = b"".join(chunks)
else:
raw = tar_data if isinstance(tar_data, bytes) else b""
if not raw:
logger.debug("[BayManager] Empty tar response from container")
return ""
tario = io.BytesIO(raw)
with tarfile.open(fileobj=tario) as tar:
for member in tar.getmembers():
f = tar.extractfile(member)
if f:
creds = json.loads(f.read().decode("utf-8"))
api_key = creds.get("api_key", "")
if api_key:
masked = (
f"{api_key[:8]}..."
if len(api_key) >= 10
else "redacted"
)
logger.info(
"[BayManager] Auto-discovered Bay API key: %s",
masked,
)
return api_key
except Exception as exc:
logger.debug(
"[BayManager] Failed to read credentials from container: %s", exc
)
return ""
async def close_client(self) -> None:
"""Close the Docker client without stopping the container.
The Bay container stays running for reuse by future sessions.
"""
if self._docker is not None:
await self._docker.close()
self._docker = None
async def stop(self) -> None:
"""Stop and remove the managed Bay container."""
if self._container is not None:
try:
await self._container.stop()
await self._container.delete(force=True)
logger.info("[BayManager] Bay container stopped and removed")
except Exception as exc:
logger.debug("[BayManager] Error stopping Bay container: %s", exc)
finally:
self._container = None
await self.close_client()
# ------------------------------------------------------------------
# Private helpers
# ------------------------------------------------------------------
async def _find_managed_container(self) -> dict | None:
"""Find an existing container with our management label."""
assert self._docker is not None
containers = await self._docker.containers.list(
all=True,
filters=json.dumps({"label": [f"{BAY_LABEL}=true"]}),
)
if containers:
# Inspect first match to get full state
return await containers[0].show()
return None
async def _pull_image_if_needed(self) -> None:
"""Pull the Bay image if it doesn't exist locally."""
assert self._docker is not None
try:
await self._docker.images.inspect(self._image)
logger.debug("[BayManager] Image %s already exists", self._image)
except aiodocker.exceptions.DockerError:
logger.info("[BayManager] Pulling image %s ...", self._image)
# Pull with progress logging
await self._docker.images.pull(self._image)
logger.info("[BayManager] Image %s pulled successfully", self._image)
+4
View File
@@ -64,6 +64,10 @@ class MockShipyardSandboxClient:
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.post(url, data=data) as response:
if response.status == 200:
logger.info(
"[Computer] File uploaded to Boxlite sandbox: %s",
remote_path,
)
return {
"success": True,
"message": "File uploaded successfully",
+38 -8
View File
@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import locale
import os
import shutil
import subprocess
@@ -52,6 +53,31 @@ def _ensure_safe_path(path: str) -> str:
return abs_path
def _decode_shell_output(output: bytes | None) -> str:
if output is None:
return ""
preferred = locale.getpreferredencoding(False) or "utf-8"
try:
return output.decode("utf-8")
except (LookupError, UnicodeDecodeError):
pass
if os.name == "nt":
for encoding in ("mbcs", "cp936", "gbk", "gb18030"):
try:
return output.decode(encoding)
except (LookupError, UnicodeDecodeError):
continue
try:
return output.decode(preferred)
except (LookupError, UnicodeDecodeError):
pass
return output.decode("utf-8", errors="replace")
@dataclass
class LocalShellComponent(ShellComponent):
async def exec(
@@ -72,28 +98,32 @@ class LocalShellComponent(ShellComponent):
run_env.update({str(k): str(v) for k, v in env.items()})
working_dir = _ensure_safe_path(cwd) if cwd else get_astrbot_root()
if background:
proc = subprocess.Popen(
# `command` is intentionally executed through the current shell so
# local computer-use behavior matches existing tool semantics.
# Safety relies on `_is_safe_command()` and the allowed-root checks.
proc = subprocess.Popen( # noqa: S602 # nosemgrep: python.lang.security.audit.dangerous-subprocess-use-audit
command,
shell=shell,
cwd=working_dir,
env=run_env,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
text=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
return {"pid": proc.pid, "stdout": "", "stderr": "", "exit_code": None}
result = subprocess.run(
# `command` is intentionally executed through the current shell so
# local computer-use behavior matches existing tool semantics.
# Safety relies on `_is_safe_command()` and the allowed-root checks.
result = subprocess.run( # noqa: S602 # nosemgrep: python.lang.security.audit.dangerous-subprocess-use-audit
command,
shell=shell,
cwd=working_dir,
env=run_env,
timeout=timeout,
capture_output=True,
text=True,
)
return {
"stdout": result.stdout,
"stderr": result.stderr,
"stdout": _decode_shell_output(result.stdout),
"stderr": _decode_shell_output(result.stderr),
"exit_code": result.returncode,
}
+20 -3
View File
@@ -31,7 +31,7 @@ class ShipyardBooter(ComputerBooter):
self._ship = ship
async def shutdown(self) -> None:
pass
logger.info("[Computer] Shipyard booter shutdown.")
@property
def fs(self) -> FileSystemComponent:
@@ -47,11 +47,19 @@ class ShipyardBooter(ComputerBooter):
async def upload_file(self, path: str, file_name: str) -> dict:
"""Upload file to sandbox"""
return await self._ship.upload_file(path, file_name)
result = await self._ship.upload_file(path, file_name)
logger.info("[Computer] File uploaded to Shipyard sandbox: %s", file_name)
return result
async def download_file(self, remote_path: str, local_path: str):
"""Download file from sandbox."""
return await self._ship.download_file(remote_path, local_path)
result = await self._ship.download_file(remote_path, local_path)
logger.info(
"[Computer] File downloaded from Shipyard sandbox: %s -> %s",
remote_path,
local_path,
)
return result
async def available(self) -> bool:
"""Check if the sandbox is available."""
@@ -59,8 +67,17 @@ class ShipyardBooter(ComputerBooter):
ship_id = self._ship.id
data = await self._sandbox_client.get_ship(ship_id)
if not data:
logger.info(
"[Computer] Shipyard sandbox health check: id=%s, healthy=False (no data)",
ship_id,
)
return False
health = bool(data.get("status", 0) == 1)
logger.info(
"[Computer] Shipyard sandbox health check: id=%s, healthy=%s",
ship_id,
health,
)
return health
except Exception as e:
logger.error(f"Error checking Shipyard sandbox availability: {e}")
@@ -0,0 +1,513 @@
from __future__ import annotations
import os
import shlex
from typing import Any, cast
from astrbot.api import logger
from ..olayer import (
BrowserComponent,
FileSystemComponent,
PythonComponent,
ShellComponent,
)
from .base import ComputerBooter
def _maybe_model_dump(value: Any) -> dict[str, Any]:
if isinstance(value, dict):
return value
if hasattr(value, "model_dump"):
dumped = value.model_dump()
if isinstance(dumped, dict):
return dumped
return {}
class NeoPythonComponent(PythonComponent):
def __init__(self, sandbox: Any) -> None:
self._sandbox = sandbox
async def exec(
self,
code: str,
kernel_id: str | None = None,
timeout: int = 30,
silent: bool = False,
) -> dict[str, Any]:
_ = kernel_id # Bay runtime does not expose kernel_id in current SDK.
result = await self._sandbox.python.exec(code, timeout=timeout)
payload = _maybe_model_dump(result)
output_text = payload.get("output", "") or ""
error_text = payload.get("error", "") or ""
data = payload.get("data") if isinstance(payload.get("data"), dict) else {}
rich_output = data.get("output") if isinstance(data.get("output"), dict) else {}
if not isinstance(rich_output.get("images"), list):
rich_output["images"] = []
if "text" not in rich_output:
rich_output["text"] = output_text
if silent:
rich_output["text"] = ""
return {
"success": bool(payload.get("success", error_text == "")),
"data": {
"output": rich_output,
"error": error_text,
},
"execution_id": payload.get("execution_id"),
"execution_time_ms": payload.get("execution_time_ms"),
"code": payload.get("code"),
"output": output_text,
"error": error_text,
}
class NeoShellComponent(ShellComponent):
def __init__(self, sandbox: Any) -> None:
self._sandbox = sandbox
async def exec(
self,
command: str,
cwd: str | None = None,
env: dict[str, str] | None = None,
timeout: int | None = 30,
shell: bool = True,
background: bool = False,
) -> dict[str, Any]:
if not shell:
return {
"stdout": "",
"stderr": "error: only shell mode is supported in shipyard_neo booter.",
"exit_code": 2,
"success": False,
}
run_command = command
if env:
env_prefix = " ".join(
f"{k}={shlex.quote(str(v))}" for k, v in sorted(env.items())
)
run_command = f"{env_prefix} {run_command}"
if background:
run_command = f"nohup sh -lc {shlex.quote(run_command)} >/tmp/astrbot_bg.log 2>&1 & echo $!"
result = await self._sandbox.shell.exec(
run_command,
timeout=timeout or 30,
cwd=cwd,
)
payload = _maybe_model_dump(result)
stdout = payload.get("output", "") or ""
stderr = payload.get("error", "") or ""
exit_code = payload.get("exit_code")
if background:
pid: int | None = None
try:
pid = int(stdout.strip().splitlines()[-1])
except Exception:
pid = None
return {
"pid": pid,
"stdout": stdout,
"stderr": stderr,
"exit_code": exit_code,
"success": bool(payload.get("success", not stderr)),
"execution_id": payload.get("execution_id"),
"execution_time_ms": payload.get("execution_time_ms"),
"command": payload.get("command"),
}
return {
"stdout": stdout,
"stderr": stderr,
"exit_code": exit_code,
"success": bool(payload.get("success", not stderr)),
"execution_id": payload.get("execution_id"),
"execution_time_ms": payload.get("execution_time_ms"),
"command": payload.get("command"),
}
class NeoFileSystemComponent(FileSystemComponent):
def __init__(self, sandbox: Any) -> None:
self._sandbox = sandbox
async def create_file(
self,
path: str,
content: str = "",
mode: int = 0o644,
) -> dict[str, Any]:
_ = mode
await self._sandbox.filesystem.write_file(path, content)
return {"success": True, "path": path}
async def read_file(self, path: str, encoding: str = "utf-8") -> dict[str, Any]:
_ = encoding
content = await self._sandbox.filesystem.read_file(path)
return {"success": True, "path": path, "content": content}
async def write_file(
self,
path: str,
content: str,
mode: str = "w",
encoding: str = "utf-8",
) -> dict[str, Any]:
_ = mode
_ = encoding
await self._sandbox.filesystem.write_file(path, content)
return {"success": True, "path": path}
async def delete_file(self, path: str) -> dict[str, Any]:
await self._sandbox.filesystem.delete(path)
return {"success": True, "path": path}
async def list_dir(
self,
path: str = ".",
show_hidden: bool = False,
) -> dict[str, Any]:
entries = await self._sandbox.filesystem.list_dir(path)
data = []
for entry in entries:
item = _maybe_model_dump(entry)
if not show_hidden and str(item.get("name", "")).startswith("."):
continue
data.append(item)
return {"success": True, "path": path, "entries": data}
class NeoBrowserComponent(BrowserComponent):
def __init__(self, sandbox: Any) -> None:
self._sandbox = sandbox
async def exec(
self,
cmd: str,
timeout: int = 30,
description: str | None = None,
tags: str | None = None,
learn: bool = False,
include_trace: bool = False,
) -> dict[str, Any]:
result = await self._sandbox.browser.exec(
cmd,
timeout=timeout,
description=description,
tags=tags,
learn=learn,
include_trace=include_trace,
)
return _maybe_model_dump(result)
async def exec_batch(
self,
commands: list[str],
timeout: int = 60,
stop_on_error: bool = True,
description: str | None = None,
tags: str | None = None,
learn: bool = False,
include_trace: bool = False,
) -> dict[str, Any]:
result = await self._sandbox.browser.exec_batch(
commands,
timeout=timeout,
stop_on_error=stop_on_error,
description=description,
tags=tags,
learn=learn,
include_trace=include_trace,
)
return _maybe_model_dump(result)
async def run_skill(
self,
skill_key: str,
timeout: int = 60,
stop_on_error: bool = True,
include_trace: bool = False,
description: str | None = None,
tags: str | None = None,
) -> dict[str, Any]:
result = await self._sandbox.browser.run_skill(
skill_key=skill_key,
timeout=timeout,
stop_on_error=stop_on_error,
include_trace=include_trace,
description=description,
tags=tags,
)
return _maybe_model_dump(result)
class ShipyardNeoBooter(ComputerBooter):
"""Booter backed by Shipyard Neo (Bay).
If *endpoint_url* is empty or set to ``"__auto__"``, Bay will be
started automatically as a Docker container (like Boxlite does for
Ship containers).
"""
AUTO_SENTINEL = "__auto__"
DEFAULT_PROFILE = "python-default"
def __init__(
self,
endpoint_url: str,
access_token: str,
profile: str = DEFAULT_PROFILE,
ttl: int = 3600,
) -> None:
self._endpoint_url = endpoint_url
self._access_token = access_token
self._profile = profile
self._ttl = ttl
self._client: Any = None
self._sandbox: Any = None
self._bay_manager: Any = None # BayContainerManager when auto-started
self._fs: FileSystemComponent | None = None
self._python: PythonComponent | None = None
self._shell: ShellComponent | None = None
self._browser: BrowserComponent | None = None
@property
def bay_client(self) -> Any:
return self._client
@property
def sandbox(self) -> Any:
return self._sandbox
@property
def capabilities(self) -> tuple[str, ...] | None:
"""Sandbox capabilities from the Bay profile.
Returns an immutable tuple after :meth:`boot`; ``None`` before boot.
"""
if self._sandbox is None:
return None
caps = getattr(self._sandbox, "capabilities", None)
return tuple(caps) if caps is not None else None
@property
def is_auto_mode(self) -> bool:
"""True when Bay should be auto-started."""
ep = (self._endpoint_url or "").strip()
return not ep or ep == self.AUTO_SENTINEL
async def boot(self, session_id: str) -> None:
_ = session_id
# --- Auto-start Bay if needed ---
if self.is_auto_mode:
from .bay_manager import BayContainerManager
# Clean up previous manager if re-booting
if self._bay_manager is not None:
await self._bay_manager.close_client()
logger.info("[Computer] Neo auto-start mode: launching Bay container")
self._bay_manager = BayContainerManager()
self._endpoint_url = await self._bay_manager.ensure_running()
await self._bay_manager.wait_healthy()
# Read auto-provisioned credentials
if not self._access_token:
self._access_token = await self._bay_manager.read_credentials()
logger.info("[Computer] Bay auto-started at %s", self._endpoint_url)
if not self._endpoint_url or not self._access_token:
if self._bay_manager is not None:
raise ValueError(
"Bay container started but credentials could not be read. "
"Ensure Bay generated credentials.json, or set access_token manually."
)
raise ValueError(
"Shipyard Neo sandbox configuration is incomplete. "
"Set endpoint (default http://127.0.0.1:8114) and access token, "
"or ensure Bay's credentials.json is accessible for auto-discovery."
)
from shipyard_neo import BayClient
self._client = BayClient(
endpoint_url=self._endpoint_url,
access_token=self._access_token,
)
await self._client.__aenter__()
# Resolve profile: user-specified > smart selection > default
resolved_profile = await self._resolve_profile(self._client)
self._sandbox = await self._client.create_sandbox(
profile=resolved_profile,
ttl=self._ttl,
)
self._fs = NeoFileSystemComponent(self._sandbox)
self._python = NeoPythonComponent(self._sandbox)
self._shell = NeoShellComponent(self._sandbox)
caps = self.capabilities or ()
self._browser = (
NeoBrowserComponent(self._sandbox) if "browser" in caps else None
)
logger.info(
"Got Shipyard Neo sandbox: %s (profile=%s, capabilities=%s, auto=%s)",
self._sandbox.id,
resolved_profile,
list(caps),
bool(self._bay_manager),
)
async def _resolve_profile(self, client: Any) -> str:
"""Pick the best profile for this session.
Resolution order:
1. User-specified profile (non-empty, non-default) use as-is.
2. Query ``GET /v1/profiles`` and pick the profile with the most
capabilities, preferring profiles that include ``"browser"``.
3. Fall back to :attr:`DEFAULT_PROFILE`.
Auth errors (401/403) are re-raised immediately they indicate a
misconfigured token, and silently falling back would just delay the
real failure to ``create_sandbox``.
"""
# User explicitly set a profile → honour it
if self._profile and self._profile != self.DEFAULT_PROFILE:
logger.info("[Computer] Using user-specified profile: %s", self._profile)
return self._profile
# Query Bay for available profiles
from shipyard_neo.errors import ForbiddenError, UnauthorizedError
try:
profile_list = await client.list_profiles()
profiles = profile_list.items
except (UnauthorizedError, ForbiddenError):
raise # auth errors must not be silenced
except Exception as exc:
logger.warning(
"[Computer] Failed to query Bay profiles, falling back to %s: %s",
self.DEFAULT_PROFILE,
exc,
)
return self.DEFAULT_PROFILE
if not profiles:
return self.DEFAULT_PROFILE
def _score(p: Any) -> tuple[int, int]:
"""(has_browser, capability_count) — higher is better."""
caps = getattr(p, "capabilities", []) or []
return (1 if "browser" in caps else 0, len(caps))
best = max(profiles, key=_score)
chosen = getattr(best, "id", self.DEFAULT_PROFILE)
if chosen != self.DEFAULT_PROFILE:
caps = getattr(best, "capabilities", [])
logger.info(
"[Computer] Auto-selected profile %s (capabilities=%s)",
chosen,
caps,
)
return chosen
async def shutdown(self) -> None:
if self._client is not None:
sandbox_id = getattr(self._sandbox, "id", "unknown")
logger.info(
"[Computer] Shutting down Shipyard Neo sandbox: id=%s", sandbox_id
)
await self._client.__aexit__(None, None, None)
self._client = None
self._sandbox = None
logger.info("[Computer] Shipyard Neo sandbox shut down: id=%s", sandbox_id)
# NOTE: We intentionally do NOT stop the Bay container here.
# It stays running for reuse by future sessions. The user can
# stop it manually or via ``BayContainerManager.stop()``.
if self._bay_manager is not None:
await self._bay_manager.close_client()
@property
def fs(self) -> FileSystemComponent:
if self._fs is None:
raise RuntimeError("ShipyardNeoBooter is not initialized.")
return self._fs
@property
def python(self) -> PythonComponent:
if self._python is None:
raise RuntimeError("ShipyardNeoBooter is not initialized.")
return self._python
@property
def shell(self) -> ShellComponent:
if self._shell is None:
raise RuntimeError("ShipyardNeoBooter is not initialized.")
return self._shell
@property
def browser(self) -> BrowserComponent:
if self._browser is None:
raise RuntimeError("ShipyardNeoBooter is not initialized.")
return self._browser
async def upload_file(self, path: str, file_name: str) -> dict:
if self._sandbox is None:
raise RuntimeError("ShipyardNeoBooter is not initialized.")
with open(path, "rb") as f:
content = f.read()
remote_path = file_name.lstrip("/")
await self._sandbox.filesystem.upload(remote_path, content)
logger.info("[Computer] File uploaded to Neo sandbox: %s", remote_path)
return {
"success": True,
"message": "File uploaded successfully",
"file_path": remote_path,
}
async def download_file(self, remote_path: str, local_path: str) -> None:
if self._sandbox is None:
raise RuntimeError("ShipyardNeoBooter is not initialized.")
content = await self._sandbox.filesystem.download(remote_path.lstrip("/"))
local_dir = os.path.dirname(local_path)
if local_dir:
os.makedirs(local_dir, exist_ok=True)
with open(local_path, "wb") as f:
f.write(cast(bytes, content))
logger.info(
"[Computer] File downloaded from Neo sandbox: %s -> %s",
remote_path,
local_path,
)
async def available(self) -> bool:
if self._sandbox is None:
return False
try:
await self._sandbox.refresh()
status = getattr(self._sandbox.status, "value", str(self._sandbox.status))
healthy = status not in {"failed", "expired"}
logger.info(
"[Computer] Neo sandbox health check: id=%s, status=%s, healthy=%s",
getattr(self._sandbox, "id", "unknown"),
status,
healthy,
)
return healthy
except Exception as e:
logger.error(f"Error checking Shipyard Neo sandbox availability: {e}")
return False
+440 -32
View File
@@ -1,10 +1,11 @@
import json
import os
import shutil
import uuid
from pathlib import Path
from astrbot.api import logger
from astrbot.core.skills.skill_manager import SANDBOX_SKILLS_ROOT
from astrbot.core.skills.skill_manager import SANDBOX_SKILLS_ROOT, SkillManager
from astrbot.core.star.context import Context
from astrbot.core.utils.astrbot_path import (
get_astrbot_skills_path,
@@ -16,45 +17,401 @@ from .booters.local import LocalBooter
session_booter: dict[str, ComputerBooter] = {}
local_booter: ComputerBooter | None = None
_MANAGED_SKILLS_FILE = ".astrbot_managed_skills.json"
def _list_local_skill_dirs(skills_root: Path) -> list[Path]:
skills: list[Path] = []
for entry in sorted(skills_root.iterdir()):
if not entry.is_dir():
continue
skill_md = entry / "SKILL.md"
if skill_md.exists():
skills.append(entry)
return skills
def _discover_bay_credentials(endpoint: str) -> str:
"""Try to auto-discover Bay API key from credentials.json.
Search order:
1. BAY_DATA_DIR env var
2. Mono-repo relative path: ../pkgs/bay/ (dev layout)
3. Current working directory
Returns:
API key string, or empty string if not found.
"""
candidates: list[Path] = []
# 1. BAY_DATA_DIR env var
bay_data_dir = os.environ.get("BAY_DATA_DIR")
if bay_data_dir:
candidates.append(Path(bay_data_dir) / "credentials.json")
# 2. Mono-repo layout: AstrBot/../pkgs/bay/credentials.json
astrbot_root = Path(__file__).resolve().parents[3] # astrbot/core/computer/ → root
candidates.append(astrbot_root.parent / "pkgs" / "bay" / "credentials.json")
# 3. Current working directory
candidates.append(Path.cwd() / "credentials.json")
for cred_path in candidates:
if not cred_path.is_file():
continue
try:
data = json.loads(cred_path.read_text())
api_key = data.get("api_key", "")
if api_key:
# Optionally verify endpoint matches
cred_endpoint = data.get("endpoint", "")
if (
cred_endpoint
and endpoint
and cred_endpoint.rstrip("/") != endpoint.rstrip("/")
):
logger.warning(
"[Computer] credentials.json endpoint mismatch: "
"file=%s, configured=%s — using key anyway",
cred_endpoint,
endpoint,
)
masked_key = f"{api_key[:4]}..." if len(api_key) >= 6 else "redacted"
logger.info(
"[Computer] Auto-discovered Bay API key from %s (prefix=%s)",
cred_path,
masked_key,
)
return api_key
except (json.JSONDecodeError, OSError) as exc:
logger.debug("[Computer] Failed to read %s: %s", cred_path, exc)
logger.debug("[Computer] No Bay credentials.json found in search paths")
return ""
def _build_python_exec_command(script: str) -> str:
return (
"if command -v python3 >/dev/null 2>&1; then PYBIN=python3; "
"elif command -v python >/dev/null 2>&1; then PYBIN=python; "
"else echo 'python not found in sandbox' >&2; exit 127; fi; "
"$PYBIN - <<'PY'\n"
f"{script}\n"
"PY"
)
def _build_apply_sync_command() -> str:
"""Build shell command for sync stage only.
This stage mutates sandbox files (managed skill replacement) but does not scan
metadata. Keeping it separate allows callers to preserve old behavior while
reusing the apply step independently.
"""
script = f"""
import json
import shutil
import zipfile
from pathlib import Path
root = Path({SANDBOX_SKILLS_ROOT!r})
zip_path = root / "skills.zip"
tmp_extract = Path(f"{{root}}_tmp_extract")
managed_file = root / {_MANAGED_SKILLS_FILE!r}
def remove_tree(path: Path) -> None:
if not path.exists():
return
if path.is_dir():
shutil.rmtree(path, ignore_errors=True)
else:
path.unlink(missing_ok=True)
def load_managed_skills() -> list[str]:
if not managed_file.exists():
return []
try:
payload = json.loads(managed_file.read_text(encoding="utf-8"))
except Exception:
return []
if not isinstance(payload, dict):
return []
items = payload.get("managed_skills", [])
if not isinstance(items, list):
return []
result: list[str] = []
for item in items:
if isinstance(item, str) and item.strip():
result.append(item.strip())
return result
root.mkdir(parents=True, exist_ok=True)
for managed_name in load_managed_skills():
remove_tree(root / managed_name)
current_managed: list[str] = []
if zip_path.exists():
remove_tree(tmp_extract)
tmp_extract.mkdir(parents=True, exist_ok=True)
with zipfile.ZipFile(zip_path) as zf:
zf.extractall(tmp_extract)
for entry in sorted(tmp_extract.iterdir()):
if not entry.is_dir():
continue
target = root / entry.name
remove_tree(target)
shutil.copytree(entry, target)
current_managed.append(entry.name)
remove_tree(tmp_extract)
remove_tree(zip_path)
managed_file.write_text(
json.dumps({{"managed_skills": current_managed}}, ensure_ascii=False, indent=2),
encoding="utf-8",
)
print(json.dumps({{"managed_skills": current_managed}}, ensure_ascii=False))
""".strip()
return _build_python_exec_command(script)
def _build_scan_command() -> str:
"""Build shell command for scan stage only.
This stage is read-oriented: it scans SKILL.md metadata and returns the
historical payload shape consumed by cache update logic.
The scan resolves the absolute path of the skills root at runtime so
that the LLM can reliably ``cat`` skill files regardless of cwd.
Only the ``description`` field is extracted from frontmatter.
"""
script = f"""
import json
from pathlib import Path
root = Path({SANDBOX_SKILLS_ROOT!r})
managed_file = root / {_MANAGED_SKILLS_FILE!r}
# Resolve absolute path at runtime so prompts always have a reliable path
root_abs = str(root.resolve())
# NOTE: This parser mirrors skill_manager._parse_frontmatter_description.
# Keep the two implementations in sync when changing parsing logic.
def parse_description(text: str) -> str:
if not text.startswith("---"):
return ""
lines = text.splitlines()
if not lines or lines[0].strip() != "---":
return ""
end_idx = None
for i in range(1, len(lines)):
if lines[i].strip() == "---":
end_idx = i
break
if end_idx is None:
return ""
for line in lines[1:end_idx]:
if ":" not in line:
continue
key, value = line.split(":", 1)
if key.strip().lower() == "description":
return value.strip().strip('"').strip("'")
return ""
def load_managed_skills() -> list[str]:
if not managed_file.exists():
return []
try:
payload = json.loads(managed_file.read_text(encoding="utf-8"))
except Exception:
return []
if not isinstance(payload, dict):
return []
items = payload.get("managed_skills", [])
if not isinstance(items, list):
return []
result: list[str] = []
for item in items:
if isinstance(item, str) and item.strip():
result.append(item.strip())
return result
def collect_skills() -> list[dict[str, str]]:
skills: list[dict[str, str]] = []
if not root.exists():
return skills
for skill_dir in sorted(root.iterdir()):
if not skill_dir.is_dir():
continue
skill_md = skill_dir / "SKILL.md"
if not skill_md.is_file():
continue
description = ""
try:
text = skill_md.read_text(encoding="utf-8")
description = parse_description(text)
except Exception:
description = ""
skills.append(
{{
"name": skill_dir.name,
"description": description,
"path": f"{{root_abs}}/{{skill_dir.name}}/SKILL.md",
}}
)
return skills
print(
json.dumps(
{{
"managed_skills": load_managed_skills(),
"skills": collect_skills(),
}},
ensure_ascii=False,
)
)
""".strip()
return _build_python_exec_command(script)
def _build_sync_and_scan_command() -> str:
"""Legacy combined command kept for backward compatibility.
New code paths should prefer apply + scan split helpers.
"""
return f"{_build_apply_sync_command()}\n{_build_scan_command()}"
def _shell_exec_succeeded(result: dict) -> bool:
if "success" in result:
return bool(result.get("success"))
exit_code = result.get("exit_code")
return exit_code in (0, None)
def _format_exec_error_detail(result: dict) -> str:
"""Format shell execution details for better observability.
Keep the message compact while still surfacing exit code and stderr/stdout.
"""
exit_code = result.get("exit_code")
stderr = str(result.get("stderr", "") or "").strip()
stdout = str(result.get("stdout", "") or "").strip()
stderr_text = stderr[:500]
stdout_text = stdout[:300]
return f"exit_code={exit_code}, stderr={stderr_text!r}, stdout_tail={stdout_text!r}"
def _decode_sync_payload(stdout: str) -> dict | None:
text = stdout.strip()
if not text:
return None
candidates = [text]
candidates.extend([line.strip() for line in text.splitlines() if line.strip()])
for candidate in reversed(candidates):
try:
payload = json.loads(candidate)
except Exception:
continue
if isinstance(payload, dict):
return payload
return None
def _update_sandbox_skills_cache(payload: dict | None) -> None:
if not isinstance(payload, dict):
return
skills = payload.get("skills", [])
if not isinstance(skills, list):
return
SkillManager().set_sandbox_skills_cache(skills)
async def _apply_skills_to_sandbox(booter: ComputerBooter) -> None:
"""Apply local skill bundle to sandbox filesystem only.
This function is intentionally limited to file mutation. Metadata scanning is
executed in a separate phase to keep failure domains clear.
"""
logger.info("[Computer] Skill sync phase=apply start")
apply_result = await booter.shell.exec(_build_apply_sync_command())
if not _shell_exec_succeeded(apply_result):
detail = _format_exec_error_detail(apply_result)
logger.error("[Computer] Skill sync phase=apply failed: %s", detail)
raise RuntimeError(f"Failed to apply sandbox skill sync strategy: {detail}")
logger.info("[Computer] Skill sync phase=apply done")
async def _scan_sandbox_skills(booter: ComputerBooter) -> dict | None:
"""Scan sandbox skills and return normalized payload for cache update."""
logger.info("[Computer] Skill sync phase=scan start")
scan_result = await booter.shell.exec(_build_scan_command())
if not _shell_exec_succeeded(scan_result):
detail = _format_exec_error_detail(scan_result)
logger.error("[Computer] Skill sync phase=scan failed: %s", detail)
raise RuntimeError(f"Failed to scan sandbox skills after sync: {detail}")
payload = _decode_sync_payload(str(scan_result.get("stdout", "") or ""))
if payload is None:
logger.warning("[Computer] Skill sync phase=scan returned empty payload")
else:
logger.info("[Computer] Skill sync phase=scan done")
return payload
async def _sync_skills_to_sandbox(booter: ComputerBooter) -> None:
skills_root = get_astrbot_skills_path()
if not os.path.isdir(skills_root):
return
if not any(Path(skills_root).iterdir()):
return
"""Sync local skills to sandbox and refresh cache.
temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True)
zip_base = os.path.join(temp_dir, "skills_bundle")
zip_path = f"{zip_base}.zip"
Backward-compatible orchestrator: keep historical behavior while internally
splitting into `apply` and `scan` phases.
"""
skills_root = Path(get_astrbot_skills_path())
if not skills_root.is_dir():
return
local_skill_dirs = _list_local_skill_dirs(skills_root)
temp_dir = Path(get_astrbot_temp_path())
temp_dir.mkdir(parents=True, exist_ok=True)
zip_base = temp_dir / "skills_bundle"
zip_path = zip_base.with_suffix(".zip")
try:
if os.path.exists(zip_path):
os.remove(zip_path)
shutil.make_archive(zip_base, "zip", skills_root)
remote_zip = Path(SANDBOX_SKILLS_ROOT) / "skills.zip"
logger.info("Uploading skills bundle to sandbox...")
await booter.shell.exec(f"mkdir -p {SANDBOX_SKILLS_ROOT}")
upload_result = await booter.upload_file(zip_path, str(remote_zip))
if not upload_result.get("success", False):
raise RuntimeError("Failed to upload skills bundle to sandbox.")
# Use -n flag to never overwrite existing files, fallback to Python if unzip unavailable
await booter.shell.exec(
f"unzip -n {remote_zip} -d {SANDBOX_SKILLS_ROOT} || "
f"python3 -c \"import zipfile, os, pathlib; z=zipfile.ZipFile('{remote_zip}'); "
f"[z.extract(m, '{SANDBOX_SKILLS_ROOT}') for m in z.namelist() "
f"if not os.path.exists(os.path.join('{SANDBOX_SKILLS_ROOT}', m))]\" || "
f"python -c \"import zipfile, os, pathlib; z=zipfile.ZipFile('{remote_zip}'); "
f"[z.extract(m, '{SANDBOX_SKILLS_ROOT}') for m in z.namelist() "
f"if not os.path.exists(os.path.join('{SANDBOX_SKILLS_ROOT}', m))]\"; "
f"rm -f {remote_zip}"
if local_skill_dirs:
if zip_path.exists():
zip_path.unlink()
shutil.make_archive(str(zip_base), "zip", str(skills_root))
remote_zip = Path(SANDBOX_SKILLS_ROOT) / "skills.zip"
logger.info("Uploading skills bundle to sandbox...")
await booter.shell.exec(f"mkdir -p {SANDBOX_SKILLS_ROOT}")
upload_result = await booter.upload_file(str(zip_path), str(remote_zip))
if not upload_result.get("success", False):
raise RuntimeError("Failed to upload skills bundle to sandbox.")
else:
logger.info(
"No local skills found. Keeping sandbox built-ins and refreshing metadata."
)
await booter.shell.exec(f"rm -f {SANDBOX_SKILLS_ROOT}/skills.zip")
# Keep backward-compatible behavior while splitting lifecycle into two
# observable phases: apply (filesystem mutation) + scan (metadata read).
await _apply_skills_to_sandbox(booter)
payload = await _scan_sandbox_skills(booter)
_update_sandbox_skills_cache(payload)
managed = payload.get("managed_skills", []) if isinstance(payload, dict) else []
logger.info(
"[Computer] Sandbox skill sync complete: managed=%d",
len(managed),
)
finally:
if os.path.exists(zip_path):
if zip_path.exists():
try:
os.remove(zip_path)
zip_path.unlink()
except Exception:
logger.warning(f"Failed to remove temp skills zip: {zip_path}")
@@ -65,8 +422,14 @@ async def get_booter(
) -> ComputerBooter:
config = context.get_config(umo=session_id)
runtime = config.get("provider_settings", {}).get("computer_use_runtime", "local")
if runtime == "local":
return get_local_booter()
elif runtime == "none":
raise RuntimeError("Sandbox runtime is disabled by configuration.")
sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {})
booter_type = sandbox_cfg.get("booter", "shipyard")
booter_type = sandbox_cfg.get("booter", "shipyard_neo")
if session_id in session_booter:
booter = session_booter[session_id]
@@ -75,6 +438,9 @@ async def get_booter(
session_booter.pop(session_id, None)
if session_id not in session_booter:
uuid_str = uuid.uuid5(uuid.NAMESPACE_DNS, session_id).hex
logger.info(
f"[Computer] Initializing booter: type={booter_type}, session={session_id}"
)
if booter_type == "shipyard":
from .booters.shipyard import ShipyardBooter
@@ -86,6 +452,27 @@ async def get_booter(
client = ShipyardBooter(
endpoint_url=ep, access_token=token, ttl=ttl, session_num=max_sessions
)
elif booter_type == "shipyard_neo":
from .booters.shipyard_neo import ShipyardNeoBooter
ep = sandbox_cfg.get("shipyard_neo_endpoint", "")
token = sandbox_cfg.get("shipyard_neo_access_token", "")
ttl = sandbox_cfg.get("shipyard_neo_ttl", 3600)
profile = sandbox_cfg.get("shipyard_neo_profile", "python-default")
# Auto-discover token from Bay's credentials.json if not configured
if not token:
token = _discover_bay_credentials(ep)
logger.info(
f"[Computer] Shipyard Neo config: endpoint={ep}, profile={profile}, ttl={ttl}"
)
client = ShipyardNeoBooter(
endpoint_url=ep,
access_token=token,
profile=profile,
ttl=ttl,
)
elif booter_type == "boxlite":
from .booters.boxlite import BoxliteBooter
@@ -95,6 +482,9 @@ async def get_booter(
try:
await client.boot(uuid_str)
logger.info(
f"[Computer] Sandbox booted successfully: type={booter_type}, session={session_id}"
)
await _sync_skills_to_sandbox(client)
except Exception as e:
logger.error(f"Error booting sandbox for session {session_id}: {e}")
@@ -104,6 +494,24 @@ async def get_booter(
return session_booter[session_id]
async def sync_skills_to_active_sandboxes() -> None:
"""Best-effort skills synchronization for all active sandbox sessions."""
logger.info(
"[Computer] Syncing skills to %d active sandbox(es)", len(session_booter)
)
for session_id, booter in list(session_booter.items()):
try:
if not await booter.available():
continue
await _sync_skills_to_sandbox(booter)
except Exception as e:
logger.warning(
"Failed to sync skills to sandbox for session %s: %s",
session_id,
e,
)
def get_local_booter() -> ComputerBooter:
global local_booter
if local_booter is None:
+7 -1
View File
@@ -1,5 +1,11 @@
from .browser import BrowserComponent
from .filesystem import FileSystemComponent
from .python import PythonComponent
from .shell import ShellComponent
__all__ = ["PythonComponent", "ShellComponent", "FileSystemComponent"]
__all__ = [
"PythonComponent",
"ShellComponent",
"FileSystemComponent",
"BrowserComponent",
]
+46
View File
@@ -0,0 +1,46 @@
"""
Browser automation component
"""
from typing import Any, Protocol
class BrowserComponent(Protocol):
"""Browser operations component"""
async def exec(
self,
cmd: str,
timeout: int = 30,
description: str | None = None,
tags: str | None = None,
learn: bool = False,
include_trace: bool = False,
) -> dict[str, Any]:
"""Execute a browser automation command"""
...
async def exec_batch(
self,
commands: list[str],
timeout: int = 60,
stop_on_error: bool = True,
description: str | None = None,
tags: str | None = None,
learn: bool = False,
include_trace: bool = False,
) -> dict[str, Any]:
"""Execute a browser automation command batch"""
...
async def run_skill(
self,
skill_key: str,
timeout: int = 60,
stop_on_error: bool = True,
include_trace: bool = False,
description: str | None = None,
tags: str | None = None,
) -> dict[str, Any]:
"""Run a browser skill by skill key"""
...
+28
View File
@@ -1,8 +1,36 @@
from .browser import BrowserBatchExecTool, BrowserExecTool, RunBrowserSkillTool
from .fs import FileDownloadTool, FileUploadTool
from .neo_skills import (
AnnotateExecutionTool,
CreateSkillCandidateTool,
CreateSkillPayloadTool,
EvaluateSkillCandidateTool,
GetExecutionHistoryTool,
GetSkillPayloadTool,
ListSkillCandidatesTool,
ListSkillReleasesTool,
PromoteSkillCandidateTool,
RollbackSkillReleaseTool,
SyncSkillReleaseTool,
)
from .python import LocalPythonTool, PythonTool
from .shell import ExecuteShellTool
__all__ = [
"BrowserExecTool",
"BrowserBatchExecTool",
"RunBrowserSkillTool",
"GetExecutionHistoryTool",
"AnnotateExecutionTool",
"CreateSkillPayloadTool",
"GetSkillPayloadTool",
"CreateSkillCandidateTool",
"ListSkillCandidatesTool",
"EvaluateSkillCandidateTool",
"PromoteSkillCandidateTool",
"ListSkillReleasesTool",
"RollbackSkillReleaseTool",
"SyncSkillReleaseTool",
"FileUploadTool",
"PythonTool",
"LocalPythonTool",
+204
View File
@@ -0,0 +1,204 @@
import json
from dataclasses import dataclass, field
from typing import Any
from astrbot.api import FunctionTool
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext
from ..computer_client import get_booter
def _to_json(data: Any) -> str:
return json.dumps(data, ensure_ascii=False, default=str)
def _ensure_admin(context: ContextWrapper[AstrAgentContext]) -> str | None:
if context.context.event.role != "admin":
return (
"error: Permission denied. Browser and skill lifecycle tools are only allowed "
"for admin users."
)
return None
async def _get_browser_component(context: ContextWrapper[AstrAgentContext]) -> Any:
booter = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
)
browser = getattr(booter, "browser", None)
if browser is None:
raise RuntimeError(
"Current sandbox booter does not support browser capability. "
"Please switch to shipyard_neo."
)
return browser
@dataclass
class BrowserExecTool(FunctionTool):
name: str = "astrbot_execute_browser"
description: str = "Execute one browser automation command in the sandbox."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"cmd": {"type": "string", "description": "Browser command to execute."},
"timeout": {"type": "integer", "default": 30},
"description": {
"type": "string",
"description": "Optional execution description.",
},
"tags": {"type": "string", "description": "Optional tags."},
"learn": {
"type": "boolean",
"description": "Whether to mark execution as learn evidence.",
"default": False,
},
"include_trace": {
"type": "boolean",
"description": "Whether to include trace_ref in response.",
"default": False,
},
},
"required": ["cmd"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
cmd: str,
timeout: int = 30,
description: str | None = None,
tags: str | None = None,
learn: bool = False,
include_trace: bool = False,
) -> ToolExecResult:
if err := _ensure_admin(context):
return err
try:
browser = await _get_browser_component(context)
result = await browser.exec(
cmd=cmd,
timeout=timeout,
description=description,
tags=tags,
learn=learn,
include_trace=include_trace,
)
return _to_json(result)
except Exception as e:
return f"Error executing browser command: {str(e)}"
@dataclass
class BrowserBatchExecTool(FunctionTool):
name: str = "astrbot_execute_browser_batch"
description: str = "Execute a browser command batch in the sandbox."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"commands": {
"type": "array",
"items": {"type": "string"},
"description": "Ordered browser commands.",
},
"timeout": {"type": "integer", "default": 60},
"stop_on_error": {"type": "boolean", "default": True},
"description": {
"type": "string",
"description": "Optional execution description.",
},
"tags": {"type": "string", "description": "Optional tags."},
"learn": {
"type": "boolean",
"description": "Whether to mark execution as learn evidence.",
"default": False,
},
"include_trace": {
"type": "boolean",
"description": "Whether to include trace_ref in response.",
"default": False,
},
},
"required": ["commands"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
commands: list[str],
timeout: int = 60,
stop_on_error: bool = True,
description: str | None = None,
tags: str | None = None,
learn: bool = False,
include_trace: bool = False,
) -> ToolExecResult:
if err := _ensure_admin(context):
return err
try:
browser = await _get_browser_component(context)
result = await browser.exec_batch(
commands=commands,
timeout=timeout,
stop_on_error=stop_on_error,
description=description,
tags=tags,
learn=learn,
include_trace=include_trace,
)
return _to_json(result)
except Exception as e:
return f"Error executing browser batch command: {str(e)}"
@dataclass
class RunBrowserSkillTool(FunctionTool):
name: str = "astrbot_run_browser_skill"
description: str = "Run a released browser skill in the sandbox by skill_key."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"skill_key": {"type": "string"},
"timeout": {"type": "integer", "default": 60},
"stop_on_error": {"type": "boolean", "default": True},
"include_trace": {"type": "boolean", "default": False},
"description": {"type": "string"},
"tags": {"type": "string"},
},
"required": ["skill_key"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
skill_key: str,
timeout: int = 60,
stop_on_error: bool = True,
include_trace: bool = False,
description: str | None = None,
tags: str | None = None,
) -> ToolExecResult:
if err := _ensure_admin(context):
return err
try:
browser = await _get_browser_component(context)
result = await browser.run_skill(
skill_key=skill_key,
timeout=timeout,
stop_on_error=stop_on_error,
include_trace=include_trace,
description=description,
tags=tags,
)
return _to_json(result)
except Exception as e:
return f"Error running browser skill: {str(e)}"
+542
View File
@@ -0,0 +1,542 @@
import json
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any
from astrbot.api import FunctionTool
from astrbot.core.agent.run_context import ContextWrapper
from astrbot.core.agent.tool import ToolExecResult
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.skills.neo_skill_sync import NeoSkillSyncManager
from ..computer_client import get_booter
def _to_jsonable(model_like: Any) -> Any:
if isinstance(model_like, dict):
return model_like
if isinstance(model_like, list):
return [_to_jsonable(i) for i in model_like]
if hasattr(model_like, "model_dump"):
return _to_jsonable(model_like.model_dump())
return model_like
def _to_json_text(data: Any) -> str:
return json.dumps(_to_jsonable(data), ensure_ascii=False, default=str)
def _ensure_admin(context: ContextWrapper[AstrAgentContext]) -> str | None:
if context.context.event.role != "admin":
return "error: Permission denied. Skill lifecycle tools are only allowed for admin users."
return None
async def _get_neo_context(
context: ContextWrapper[AstrAgentContext],
) -> tuple[Any, Any]:
booter = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
)
client = getattr(booter, "bay_client", None)
sandbox = getattr(booter, "sandbox", None)
if client is None or sandbox is None:
raise RuntimeError(
"Current sandbox booter does not support Neo skill lifecycle APIs. "
"Please switch to shipyard_neo."
)
return client, sandbox
@dataclass
class NeoSkillToolBase(FunctionTool):
error_prefix: str = "Error"
async def _run(
self,
context: ContextWrapper[AstrAgentContext],
neo_call: Callable[[Any, Any], Awaitable[Any]],
error_action: str,
) -> ToolExecResult:
if err := _ensure_admin(context):
return err
try:
client, sandbox = await _get_neo_context(context)
result = await neo_call(client, sandbox)
return _to_json_text(result)
except Exception as e:
return f"{self.error_prefix} {error_action}: {str(e)}"
@dataclass
class GetExecutionHistoryTool(NeoSkillToolBase):
name: str = "astrbot_get_execution_history"
description: str = "Get execution history from current sandbox."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"exec_type": {"type": "string"},
"success_only": {"type": "boolean", "default": False},
"limit": {"type": "integer", "default": 100},
"offset": {"type": "integer", "default": 0},
"tags": {"type": "string"},
"has_notes": {"type": "boolean", "default": False},
"has_description": {"type": "boolean", "default": False},
},
"required": [],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
exec_type: str | None = None,
success_only: bool = False,
limit: int = 100,
offset: int = 0,
tags: str | None = None,
has_notes: bool = False,
has_description: bool = False,
) -> ToolExecResult:
return await self._run(
context,
lambda _client, sandbox: sandbox.get_execution_history(
exec_type=exec_type,
success_only=success_only,
limit=limit,
offset=offset,
tags=tags,
has_notes=has_notes,
has_description=has_description,
),
error_action="getting execution history",
)
@dataclass
class AnnotateExecutionTool(NeoSkillToolBase):
name: str = "astrbot_annotate_execution"
description: str = "Annotate one execution history record."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"execution_id": {"type": "string"},
"description": {"type": "string"},
"tags": {"type": "string"},
"notes": {"type": "string"},
},
"required": ["execution_id"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
execution_id: str,
description: str | None = None,
tags: str | None = None,
notes: str | None = None,
) -> ToolExecResult:
return await self._run(
context,
lambda _client, sandbox: sandbox.annotate_execution(
execution_id=execution_id,
description=description,
tags=tags,
notes=notes,
),
error_action="annotating execution",
)
@dataclass
class CreateSkillPayloadTool(NeoSkillToolBase):
name: str = "astrbot_create_skill_payload"
description: str = (
"Step 1/3 for Neo skill authoring: create immutable payload content and return payload_ref. "
"Use this to store skill_markdown and structured metadata; do NOT write local skill folders directly."
)
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"payload": {
"anyOf": [{"type": "object"}, {"type": "array"}],
"description": (
"Skill payload JSON. Typical schema: {skill_markdown, inputs, outputs, meta}. "
"This only stores content and returns payload_ref; it does not create a candidate or release."
),
},
"kind": {
"type": "string",
"description": "Payload kind.",
"default": "astrbot_skill_v1",
},
},
"required": ["payload"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
payload: dict[str, Any] | list[Any],
kind: str = "astrbot_skill_v1",
) -> ToolExecResult:
return await self._run(
context,
lambda client, _sandbox: client.skills.create_payload(
payload=payload,
kind=kind,
),
error_action="creating skill payload",
)
@dataclass
class GetSkillPayloadTool(NeoSkillToolBase):
name: str = "astrbot_get_skill_payload"
description: str = "Get one skill payload by payload_ref."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"payload_ref": {"type": "string"},
},
"required": ["payload_ref"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
payload_ref: str,
) -> ToolExecResult:
return await self._run(
context,
lambda client, _sandbox: client.skills.get_payload(payload_ref),
error_action="getting skill payload",
)
@dataclass
class CreateSkillCandidateTool(NeoSkillToolBase):
name: str = "astrbot_create_skill_candidate"
description: str = (
"Step 2/3 for Neo skill authoring: create a candidate by binding execution evidence "
"(source_execution_ids) with skill identity (skill_key) and optional payload_ref."
)
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"skill_key": {
"type": "string",
"description": "Stable logical identifier, e.g. image-collage-9grid.",
},
"source_execution_ids": {
"type": "array",
"items": {"type": "string"},
"description": "Execution evidence IDs captured from sandbox history.",
},
"scenario_key": {
"type": "string",
"description": "Optional scenario namespace for grouping candidates.",
},
"payload_ref": {
"type": "string",
"description": "Optional payload reference created by astrbot_create_skill_payload.",
},
},
"required": ["skill_key", "source_execution_ids"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
skill_key: str,
source_execution_ids: list[str],
scenario_key: str | None = None,
payload_ref: str | None = None,
) -> ToolExecResult:
return await self._run(
context,
lambda client, _sandbox: client.skills.create_candidate(
skill_key=skill_key,
source_execution_ids=source_execution_ids,
scenario_key=scenario_key,
payload_ref=payload_ref,
),
error_action="creating skill candidate",
)
@dataclass
class ListSkillCandidatesTool(NeoSkillToolBase):
name: str = "astrbot_list_skill_candidates"
description: str = "List skill candidates."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"status": {"type": "string"},
"skill_key": {"type": "string"},
"limit": {"type": "integer", "default": 100},
"offset": {"type": "integer", "default": 0},
},
"required": [],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
status: str | None = None,
skill_key: str | None = None,
limit: int = 100,
offset: int = 0,
) -> ToolExecResult:
return await self._run(
context,
lambda client, _sandbox: client.skills.list_candidates(
status=status,
skill_key=skill_key,
limit=limit,
offset=offset,
),
error_action="listing skill candidates",
)
@dataclass
class EvaluateSkillCandidateTool(NeoSkillToolBase):
name: str = "astrbot_evaluate_skill_candidate"
description: str = "Evaluate a skill candidate."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"candidate_id": {"type": "string"},
"passed": {"type": "boolean"},
"score": {"type": "number"},
"benchmark_id": {"type": "string"},
"report": {"type": "string"},
},
"required": ["candidate_id", "passed"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
candidate_id: str,
passed: bool,
score: float | None = None,
benchmark_id: str | None = None,
report: str | None = None,
) -> ToolExecResult:
return await self._run(
context,
lambda client, _sandbox: client.skills.evaluate_candidate(
candidate_id,
passed=passed,
score=score,
benchmark_id=benchmark_id,
report=report,
),
error_action="evaluating skill candidate",
)
@dataclass
class PromoteSkillCandidateTool(NeoSkillToolBase):
name: str = "astrbot_promote_skill_candidate"
description: str = (
"Step 3/3 for Neo skill authoring: promote candidate to canary/stable release. "
"If stage=stable and sync_to_local=true, payload.skill_markdown is synced to local SKILL.md automatically."
)
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"candidate_id": {"type": "string"},
"stage": {
"type": "string",
"description": "Release stage: canary/stable",
"default": "canary",
},
"sync_to_local": {
"type": "boolean",
"description": (
"Only used with stage=stable. true means sync payload.skill_markdown to local SKILL.md; "
"false means release remains Neo-side only."
),
"default": True,
},
},
"required": ["candidate_id"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
candidate_id: str,
stage: str = "canary",
sync_to_local: bool = True,
) -> ToolExecResult:
if err := _ensure_admin(context):
return err
if stage not in {"canary", "stable"}:
return "Error promoting skill candidate: stage must be canary or stable."
try:
client, _sandbox = await _get_neo_context(context)
sync_mgr = NeoSkillSyncManager()
result = await sync_mgr.promote_with_optional_sync(
client,
candidate_id=candidate_id,
stage=stage,
sync_to_local=sync_to_local,
)
if result.get("sync_error"):
rollback_json = result.get("rollback")
if rollback_json:
return (
"Error promoting skill candidate: stable release synced failed; "
f"auto rollback succeeded. sync_error={result['sync_error']}; "
f"rollback={_to_json_text(rollback_json)}"
)
return _to_json_text(
{
"release": result.get("release"),
"sync": result.get("sync"),
"rollback": result.get("rollback"),
}
)
except Exception as e:
return f"Error promoting skill candidate: {str(e)}"
@dataclass
class ListSkillReleasesTool(NeoSkillToolBase):
name: str = "astrbot_list_skill_releases"
description: str = "List skill releases."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"skill_key": {"type": "string"},
"active_only": {"type": "boolean", "default": False},
"stage": {"type": "string"},
"limit": {"type": "integer", "default": 100},
"offset": {"type": "integer", "default": 0},
},
"required": [],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
skill_key: str | None = None,
active_only: bool = False,
stage: str | None = None,
limit: int = 100,
offset: int = 0,
) -> ToolExecResult:
return await self._run(
context,
lambda client, _sandbox: client.skills.list_releases(
skill_key=skill_key,
active_only=active_only,
stage=stage,
limit=limit,
offset=offset,
),
error_action="listing skill releases",
)
@dataclass
class RollbackSkillReleaseTool(NeoSkillToolBase):
name: str = "astrbot_rollback_skill_release"
description: str = "Rollback one skill release."
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"release_id": {"type": "string"},
},
"required": ["release_id"],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
release_id: str,
) -> ToolExecResult:
return await self._run(
context,
lambda client, _sandbox: client.skills.rollback_release(release_id),
error_action="rolling back skill release",
)
@dataclass
class SyncSkillReleaseTool(NeoSkillToolBase):
name: str = "astrbot_sync_skill_release"
description: str = (
"Sync stable Neo release payload to local SKILL.md and update mapping metadata."
)
parameters: dict = field(
default_factory=lambda: {
"type": "object",
"properties": {
"release_id": {"type": "string"},
"skill_key": {"type": "string"},
"require_stable": {"type": "boolean", "default": True},
},
"required": [],
}
)
async def call(
self,
context: ContextWrapper[AstrAgentContext],
release_id: str | None = None,
skill_key: str | None = None,
require_stable: bool = True,
) -> ToolExecResult:
return await self._run(
context,
lambda client, _sandbox: _sync_release_to_dict(
client,
release_id=release_id,
skill_key=skill_key,
require_stable=require_stable,
),
error_action="syncing skill release",
)
async def _sync_release_to_dict(
client: Any,
*,
release_id: str | None,
skill_key: str | None,
require_stable: bool,
) -> dict[str, str]:
sync_mgr = NeoSkillSyncManager()
result = await sync_mgr.sync_release(
client,
release_id=release_id,
skill_key=skill_key,
require_stable=require_stable,
)
return sync_mgr.sync_result_to_dict(result)
+8 -2
View File
@@ -1,3 +1,4 @@
import platform
from dataclasses import dataclass, field
import mcp
@@ -10,6 +11,8 @@ from astrbot.core.computer.computer_client import get_booter, get_local_booter
from astrbot.core.computer.tools.permissions import check_admin_permission
from astrbot.core.message.message_event_result import MessageChain
_OS_NAME = platform.system()
param_schema = {
"type": "object",
"properties": {
@@ -61,7 +64,7 @@ async def handle_result(result: dict, event: AstrMessageEvent) -> ToolExecResult
@dataclass
class PythonTool(FunctionTool):
name: str = "astrbot_execute_ipython"
description: str = "Run codes in an IPython shell."
description: str = f"Run codes in an IPython shell. Current OS: {_OS_NAME}."
parameters: dict = field(default_factory=lambda: param_schema)
async def call(
@@ -83,7 +86,10 @@ class PythonTool(FunctionTool):
@dataclass
class LocalPythonTool(FunctionTool):
name: str = "astrbot_execute_python"
description: str = "Execute codes in a Python environment."
description: str = (
f"Execute codes in a Python environment. Current OS: {_OS_NAME}. "
"Use system-compatible commands."
)
parameters: dict = field(default_factory=lambda: param_schema)
+229 -51
View File
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.18.3"
VERSION = "4.20.0"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
WEBHOOK_SUPPORTED_PLATFORMS = [
@@ -132,11 +132,15 @@ DEFAULT_CONFIG = {
"computer_use_runtime": "none",
"computer_use_require_admin": True,
"sandbox": {
"booter": "shipyard",
"booter": "shipyard_neo",
"shipyard_endpoint": "",
"shipyard_access_token": "",
"shipyard_ttl": 3600,
"shipyard_max_sessions": 10,
"shipyard_neo_endpoint": "",
"shipyard_neo_access_token": "",
"shipyard_neo_profile": "python-default",
"shipyard_neo_ttl": 3600,
},
},
# SubAgent orchestrator mode:
@@ -215,6 +219,9 @@ DEFAULT_CONFIG = {
"telegram": {
"pre_ack_emoji": {"enable": False, "emojis": ["✍️"]},
},
"discord": {
"pre_ack_emoji": {"enable": False, "emojis": ["🤔"]},
},
},
"wake_prefix": ["/"],
"log_level": "INFO",
@@ -338,14 +345,20 @@ CONFIG_METADATA_2 = {
"企业微信智能机器人": {
"id": "wecom_ai_bot",
"type": "wecom_ai_bot",
"hint": "如果发现字段有异常,请重新创建",
"enable": True,
"wecom_ai_bot_connection_mode": "long_connection", # long_connection, webhook
"wecom_ai_bot_name": "",
"wecomaibot_ws_bot_id": "",
"wecomaibot_ws_secret": "",
"wecomaibot_token": "",
"wecomaibot_encoding_aes_key": "",
"wecomaibot_init_respond_text": "",
"wecomaibot_friend_message_welcome_text": "",
"wecom_ai_bot_name": "",
"msg_push_webhook_url": "",
"only_use_webhook_url_to_send": False,
"token": "",
"encoding_aes_key": "",
"wecomaibot_ws_url": "wss://openws.work.weixin.qq.com",
"wecomaibot_heartbeat_interval": 30,
"unified_webhook_mode": True,
"webhook_uuid": "",
"callback_server_host": "0.0.0.0",
@@ -391,7 +404,6 @@ CONFIG_METADATA_2 = {
"discord_token": "",
"discord_proxy": "",
"discord_command_register": True,
"discord_guild_id_for_debug": "",
"discord_activity_name": "",
},
"Misskey": {
@@ -446,6 +458,20 @@ CONFIG_METADATA_2 = {
"satori_heartbeat_interval": 10,
"satori_reconnect_delay": 5,
},
"kook": {
"id": "kook",
"type": "kook",
"enable": False,
"kook_bot_token": "",
"kook_bot_nickname": "",
"kook_reconnect_delay": 1,
"kook_max_reconnect_delay": 60,
"kook_max_retry_delay": 60,
"kook_heartbeat_interval": 30,
"kook_heartbeat_timeout": 6,
"kook_max_heartbeat_failures": 3,
"kook_max_consecutive_failures": 5,
},
# "WebChat": {
# "id": "webchat",
# "type": "webchat",
@@ -715,6 +741,13 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "请务必填写正确,否则无法使用一些指令。",
},
"wecom_ai_bot_connection_mode": {
"description": "企业微信智能机器人连接模式",
"type": "string",
"options": ["webhook", "long_connection"],
"labels": ["Webhook 回调", "长连接"],
"hint": "Webhook 回调模式需要配置 Token/EncodingAESKey。长连接模式需要配置 BotID/Secret。",
},
"wecomaibot_init_respond_text": {
"description": "企业微信智能机器人初始响应文本",
"type": "string",
@@ -725,6 +758,22 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "当用户当天进入智能机器人单聊会话,回复欢迎语,留空则不回复。",
},
"wecomaibot_token": {
"description": "企业微信智能机器人 Token",
"type": "string",
"hint": "用于 Webhook 回调模式的身份验证。",
"condition": {
"wecom_ai_bot_connection_mode": "webhook",
},
},
"wecomaibot_encoding_aes_key": {
"description": "企业微信智能机器人 EncodingAESKey",
"type": "string",
"hint": "用于 Webhook 回调模式的消息加密解密。",
"condition": {
"wecom_ai_bot_connection_mode": "webhook",
},
},
"msg_push_webhook_url": {
"description": "企业微信消息推送 Webhook URL",
"type": "string",
@@ -735,6 +784,40 @@ CONFIG_METADATA_2 = {
"type": "bool",
"hint": "启用后,企业微信智能机器人的所有回复都改为通过消息推送 Webhook 发送。消息推送 Webhook 支持更多的消息类型(如图片、文件等)。",
},
"wecomaibot_ws_bot_id": {
"description": "长连接 BotID",
"type": "string",
"hint": "企业微信智能机器人长连接模式凭证 BotID。",
"condition": {
"wecom_ai_bot_connection_mode": "long_connection",
},
},
"wecomaibot_ws_secret": {
"description": "长连接 Secret",
"type": "string",
"hint": "企业微信智能机器人长连接模式凭证 Secret。",
"condition": {
"wecom_ai_bot_connection_mode": "long_connection",
},
},
"wecomaibot_ws_url": {
"description": "长连接 WebSocket 地址",
"type": "string",
"invisible": True,
"hint": "默认值为 wss://openws.work.weixin.qq.com,一般无需修改。",
"condition": {
"wecom_ai_bot_connection_mode": "long_connection",
},
},
"wecomaibot_heartbeat_interval": {
"description": "长连接心跳间隔",
"type": "int",
"invisible": True,
"hint": "长连接模式心跳间隔(秒),建议 30 秒。",
"condition": {
"wecom_ai_bot_connection_mode": "long_connection",
},
},
"lark_bot_name": {
"description": "飞书机器人的名字",
"type": "string",
@@ -751,7 +834,8 @@ CONFIG_METADATA_2 = {
"hint": "可选的代理地址:http://ip:port",
},
"discord_command_register": {
"description": "是否自动将插件指令注册 Discord 斜杠指令",
"description": "注册 Discord 指令",
"hint": "启用后,自动将插件指令注册为 Discord 斜杠指令",
"type": "bool",
},
"discord_activity_name": {
@@ -778,7 +862,7 @@ CONFIG_METADATA_2 = {
"unified_webhook_mode": {
"description": "统一 Webhook 模式",
"type": "bool",
"hint": "启用后,将使用 AstrBot 统一 Webhook 入口,无需单独开启端口。回调地址为 /api/platform/webhook/{webhook_uuid}",
"hint": "Webhook 模式下使用 AstrBot 统一 Webhook 入口,无需单独开启端口。回调地址为 /api/platform/webhook/{webhook_uuid}",
},
"webhook_uuid": {
"invisible": True,
@@ -786,6 +870,51 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "统一 Webhook 模式下的唯一标识符,创建平台时自动生成。",
},
"kook_bot_token": {
"description": "机器人 Token",
"type": "string",
"hint": "必填项。从 KOOK 开发者平台获取的机器人 Token。",
},
"kook_bot_nickname": {
"description": "Bot Nickname",
"type": "string",
"hint": "可选项。若发送者昵称与此值一致,将忽略该消息以避免广播风暴。",
},
"kook_reconnect_delay": {
"description": "重连延迟",
"type": "int",
"hint": "重连延迟时间(秒),使用指数退避策略。",
},
"kook_max_reconnect_delay": {
"description": "最大重连延迟",
"type": "int",
"hint": "重连延迟的最大值(秒)。",
},
"kook_max_retry_delay": {
"description": "最大重试延迟",
"type": "int",
"hint": "重试的最大延迟时间(秒)。",
},
"kook_heartbeat_interval": {
"description": "心跳间隔",
"type": "int",
"hint": "心跳检测间隔时间(秒)。",
},
"kook_heartbeat_timeout": {
"description": "心跳超时时间",
"type": "int",
"hint": "心跳检测超时时间(秒)。",
},
"kook_max_heartbeat_failures": {
"description": "最大心跳失败次数",
"type": "int",
"hint": "允许的最大心跳失败次数,超过后断开连接。",
},
"kook_max_consecutive_failures": {
"description": "最大连续失败次数",
"type": "int",
"hint": "允许的最大连续失败次数,超过后停止重试。",
},
},
},
"platform_settings": {
@@ -1060,7 +1189,7 @@ CONFIG_METADATA_2 = {
"enable": True,
"key": [],
"timeout": 120,
"api_base": "https://openrouter.ai/v1",
"api_base": "https://openrouter.ai/api/v1",
"proxy": "",
"custom_headers": {},
},
@@ -2871,12 +3000,48 @@ CONFIG_METADATA_3 = {
"provider_settings.sandbox.booter": {
"description": "沙箱环境驱动器",
"type": "string",
"options": ["shipyard"],
"labels": ["Shipyard"],
"options": ["shipyard_neo", "shipyard"],
"labels": ["Shipyard Neo", "Shipyard"],
"condition": {
"provider_settings.computer_use_runtime": "sandbox",
},
},
"provider_settings.sandbox.shipyard_neo_endpoint": {
"description": "Shipyard Neo API Endpoint",
"type": "string",
"hint": "Shipyard Neo(Bay) 服务的 API 地址,默认 http://127.0.0.1:8114。",
"condition": {
"provider_settings.computer_use_runtime": "sandbox",
"provider_settings.sandbox.booter": "shipyard_neo",
},
},
"provider_settings.sandbox.shipyard_neo_access_token": {
"description": "Shipyard Neo Access Token",
"type": "string",
"hint": "Bay 的 API Keysk-bay-...)。留空时自动从 credentials.json 发现。",
"condition": {
"provider_settings.computer_use_runtime": "sandbox",
"provider_settings.sandbox.booter": "shipyard_neo",
},
},
"provider_settings.sandbox.shipyard_neo_profile": {
"description": "Shipyard Neo Profile",
"type": "string",
"hint": "Shipyard Neo 沙箱 profile,如 python-default。",
"condition": {
"provider_settings.computer_use_runtime": "sandbox",
"provider_settings.sandbox.booter": "shipyard_neo",
},
},
"provider_settings.sandbox.shipyard_neo_ttl": {
"description": "Shipyard Neo Sandbox TTL",
"type": "int",
"hint": "Shipyard Neo 沙箱生存时间(秒)。",
"condition": {
"provider_settings.computer_use_runtime": "sandbox",
"provider_settings.sandbox.booter": "shipyard_neo",
},
},
"provider_settings.sandbox.shipyard_endpoint": {
"description": "Shipyard API Endpoint",
"type": "string",
@@ -3112,46 +3277,6 @@ CONFIG_METADATA_3 = {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.max_quoted_fallback_images": {
"description": "引用图片回退解析上限",
"type": "int",
"hint": "引用/转发消息回退解析图片时的最大注入数量,超出会截断。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.quoted_message_parser.max_component_chain_depth": {
"description": "引用解析组件链深度",
"type": "int",
"hint": "解析 Reply 组件链时允许的最大递归深度。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.quoted_message_parser.max_forward_node_depth": {
"description": "引用解析转发节点深度",
"type": "int",
"hint": "解析合并转发节点时允许的最大递归深度。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.quoted_message_parser.max_forward_fetch": {
"description": "引用解析转发拉取上限",
"type": "int",
"hint": "递归拉取 get_forward_msg 的最大次数。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.quoted_message_parser.warn_on_action_failure": {
"description": "引用解析 action 失败告警",
"type": "bool",
"hint": "开启后,get_msg/get_forward_msg 全部尝试失败时输出 warning 日志。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.max_agent_step": {
"description": "工具调用轮数上限",
"type": "int",
@@ -3195,6 +3320,46 @@ CONFIG_METADATA_3 = {
"type": "bool",
"hint": "/provider 命令列出模型时是否并发检测连通性。开启后会主动调用模型测试连通性,可能产生额外 token 消耗。",
},
"provider_settings.max_quoted_fallback_images": {
"description": "引用图片回退解析上限",
"type": "int",
"hint": "引用/转发消息回退解析图片时的最大注入数量,超出会截断。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.quoted_message_parser.max_component_chain_depth": {
"description": "引用解析组件链深度",
"type": "int",
"hint": "解析 Reply 组件链时允许的最大递归深度。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.quoted_message_parser.max_forward_node_depth": {
"description": "引用解析转发节点深度",
"type": "int",
"hint": "解析合并转发节点时允许的最大递归深度。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.quoted_message_parser.max_forward_fetch": {
"description": "引用解析转发拉取上限",
"type": "int",
"hint": "递归拉取 get_forward_msg 的最大次数。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.quoted_message_parser.warn_on_action_failure": {
"description": "引用解析 action 失败告警",
"type": "bool",
"hint": "开启后,get_msg/get_forward_msg 全部尝试失败时输出 warning 日志。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
},
"condition": {
"provider_settings.enable": True,
@@ -3406,6 +3571,19 @@ CONFIG_METADATA_3 = {
"platform_specific.telegram.pre_ack_emoji.enable": True,
},
},
"platform_specific.discord.pre_ack_emoji.enable": {
"description": "[Discord] 启用预回应表情",
"type": "bool",
},
"platform_specific.discord.pre_ack_emoji.emojis": {
"description": "表情列表(Unicode 或自定义表情名)",
"type": "list",
"items": {"type": "string"},
"hint": "填写 Unicode 表情符号,例如:👍、🤔、⏳",
"condition": {
"platform_specific.discord.pre_ack_emoji.enable": True,
},
},
},
},
},
+4
View File
@@ -175,6 +175,10 @@ class LogManager:
_trace_sink_id: int | None = None
_NOISY_LOGGER_LEVELS: dict[str, int] = {
"aiosqlite": logging.WARNING,
"filelock": logging.WARNING,
"asyncio": logging.WARNING,
"tzlocal": logging.WARNING,
"apscheduler": logging.WARNING,
}
@classmethod
+44 -18
View File
@@ -539,13 +539,36 @@ class Reply(BaseMessageComponent):
class Poke(BaseMessageComponent):
type: str = ComponentType.Poke
id: int | None = 0
qq: int | None = 0
type: ComponentType = ComponentType.Poke
_type: str | int = "126"
id: int | str | None = 0
qq: int | str | None = 0 # deprecated: legacy field, kept for compatibility
def __init__(self, type: str, **_) -> None:
type = f"Poke:{type}"
super().__init__(type=type, **_)
def __init__(self, poke_type: str | int | None = None, **_) -> None:
# Backward compatible with old signature: Poke(type="poke", ...)
legacy_type = _.pop("type", None)
if poke_type is None:
poke_type = legacy_type
if poke_type in (None, "", "poke", "Poke"):
poke_type = "126"
super().__init__(_type=str(poke_type), **_)
def target_id(self) -> str | None:
"""Return normalized target id, compatible with old `qq` field."""
for value in (self.id, self.qq):
if value is None:
continue
text = str(value).strip()
if text and text != "0":
return text
return None
def toDict(self):
target_id = self.target_id()
data = {"type": str(self._type or "126")}
if target_id:
data["id"] = target_id
return {"type": "poke", "data": data}
class Forward(BaseMessageComponent):
@@ -676,21 +699,24 @@ class File(BaseMessageComponent):
if self.url:
try:
loop = asyncio.get_event_loop()
if loop.is_running():
logger.warning(
"不可以在异步上下文中同步等待下载! "
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
"请使用 await get_file() 代替直接获取 <File>.file 字段",
)
return ""
# 等待下载完成
loop.run_until_complete(self._download_file())
# 检查是否有正在运行的 event loop
asyncio.get_running_loop()
logger.warning(
"不可以在异步上下文中同步等待下载! "
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
"请使用 await get_file() 代替直接获取 <File>.file 字段",
)
return ""
except RuntimeError:
# 没有运行中的 event loop,可以同步执行
try:
# 使用 asyncio.run 安全地创建和关闭事件循环
asyncio.run(self._download_file())
except Exception:
logger.exception("文件下载失败")
if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_)
except Exception as e:
logger.error(f"文件下载失败: {e}")
return ""
@@ -27,7 +27,7 @@ class PreProcessStage(Stage):
) -> None | AsyncGenerator[None, None]:
"""在处理事件之前的预处理"""
# 平台特异配置:platform_specific.<platform>.pre_ack_emoji
supported = {"telegram", "lark"}
supported = {"telegram", "lark", "discord"}
platform = event.get_platform_name()
cfg = (
self.config.get("platform_specific", {})
+1 -1
View File
@@ -28,7 +28,7 @@ class RespondStage(Stage):
Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), # @
Comp.Image: lambda comp: bool(comp.file), # 图片
Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复
Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳
Comp.Poke: lambda comp: comp.target_id() is not None, # 戳一戳
Comp.Node: lambda comp: bool(comp.content), # 转发节点
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
Comp.File: lambda comp: bool(comp.file_ or comp.url),
@@ -5,7 +5,7 @@ import traceback
from collections.abc import AsyncGenerator
from astrbot.core import file_token_service, html_renderer, logger
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
from astrbot.core.message.components import At, Image, Node, Plain, Record, Reply
from astrbot.core.message.message_event_result import ResultContentType
from astrbot.core.pipeline.content_safety_check.stage import ContentSafetyCheckStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
@@ -383,8 +383,11 @@ class ResultDecorateStage(Stage):
)
result.chain = [node]
has_plain = any(isinstance(item, Plain) for item in result.chain)
if has_plain:
# at 回复 / 引用回复仅适用于纯文本或图文消息
can_decorate = all(
isinstance(item, (Plain, Image)) for item in result.chain
)
if can_decorate:
# at 回复
if (
self.reply_with_mention
@@ -399,5 +402,4 @@ class ResultDecorateStage(Stage):
# 引用回复
if self.reply_with_quote:
if not any(isinstance(item, File) for item in result.chain):
result.chain.insert(0, Reply(id=event.message_obj.message_id))
result.chain.insert(0, Reply(id=event.message_obj.message_id))
+4
View File
@@ -180,6 +180,10 @@ class PlatformManager:
from .sources.line.line_adapter import (
LinePlatformAdapter, # noqa: F401
)
case "kook":
from .sources.kook.kook_adapter import (
KookPlatformAdapter, # noqa: F401
)
except (ImportError, ModuleNotFoundError) as e:
logger.error(
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
@@ -191,7 +191,7 @@ class AiocqhttpAdapter(Platform):
if "sub_type" in event:
if event["sub_type"] == "poke" and "target_id" in event:
abm.message.append(Poke(qq=str(event["target_id"]), type="poke"))
abm.message.append(Poke(id=str(event["target_id"])))
return abm
@@ -11,7 +11,7 @@ from dingtalk_stream import AckMessage
from astrbot import logger
from astrbot.api.event import MessageChain
from astrbot.api.message_components import At, Image, Plain, Record, Video
from astrbot.api.message_components import At, File, Image, Plain, Record, Video
from astrbot.api.platform import (
AstrBotMessage,
MessageMember,
@@ -178,29 +178,110 @@ class DingtalkPlatformAdapter(Platform):
abm.session_id = abm.sender.user_id
message_type: str = cast(str, message.message_type)
robot_code = cast(str, message.robot_code or "")
raw_content = cast(dict, message.extensions.get("content") or {})
if not isinstance(raw_content, dict):
raw_content = {}
match message_type:
case "text":
abm.message_str = message.text.content.strip()
abm.message.append(Plain(abm.message_str))
case "picture":
if not robot_code:
logger.error("钉钉图片消息解析失败: 回调中缺少 robotCode")
await self._remember_sender_binding(message, abm)
return abm
image_content = cast(
dingtalk_stream.ImageContent | None,
message.image_content,
)
download_code = cast(
str, (image_content.download_code if image_content else "") or ""
)
if not download_code:
logger.warning("钉钉图片消息缺少 downloadCode,已跳过")
else:
f_path = await self.download_ding_file(
download_code,
robot_code,
"jpg",
)
if f_path:
abm.message.append(Image.fromFileSystem(f_path))
else:
logger.warning("钉钉图片消息下载失败,无法解析为图片")
case "richText":
rtc: dingtalk_stream.RichTextContent = cast(
dingtalk_stream.RichTextContent, message.rich_text_content
)
contents: list[dict] = cast(list[dict], rtc.rich_text_list)
plain_parts: list[str] = []
for content in contents:
plains = ""
if "text" in content:
plains += content["text"]
abm.message.append(Plain(plains))
plain_text = cast(str, content.get("text") or "")
if plain_text:
plain_parts.append(plain_text)
abm.message.append(Plain(plain_text))
elif "type" in content and content["type"] == "picture":
download_code = cast(str, content.get("downloadCode") or "")
if not download_code:
logger.warning(
"钉钉富文本图片消息缺少 downloadCode,已跳过"
)
continue
if not robot_code:
logger.error(
"钉钉富文本图片消息解析失败: 回调中缺少 robotCode"
)
continue
f_path = await self.download_ding_file(
content["downloadCode"],
cast(str, message.robot_code),
download_code,
robot_code,
"jpg",
)
abm.message.append(Image.fromFileSystem(f_path))
case "audio":
pass
if f_path:
abm.message.append(Image.fromFileSystem(f_path))
abm.message_str = "".join(plain_parts).strip()
case "audio" | "voice":
download_code = cast(str, raw_content.get("downloadCode") or "")
if not download_code:
logger.warning("钉钉语音消息缺少 downloadCode,已跳过")
elif not robot_code:
logger.error("钉钉语音消息解析失败: 回调中缺少 robotCode")
else:
voice_ext = cast(str, raw_content.get("fileExtension") or "")
if not voice_ext:
voice_ext = "amr"
voice_ext = voice_ext.lstrip(".")
f_path = await self.download_ding_file(
download_code,
robot_code,
voice_ext,
)
if f_path:
abm.message.append(Record.fromFileSystem(f_path))
case "file":
download_code = cast(str, raw_content.get("downloadCode") or "")
if not download_code:
logger.warning("钉钉文件消息缺少 downloadCode,已跳过")
elif not robot_code:
logger.error("钉钉文件消息解析失败: 回调中缺少 robotCode")
else:
file_name = cast(str, raw_content.get("fileName") or "")
file_ext = Path(file_name).suffix.lstrip(".") if file_name else ""
if not file_ext:
file_ext = cast(str, raw_content.get("fileExtension") or "")
if not file_ext:
file_ext = "file"
f_path = await self.download_ding_file(
download_code,
robot_code,
file_ext,
)
if f_path:
if not file_name:
file_name = Path(f_path).name
abm.message.append(File(name=file_name, file=f_path))
await self._remember_sender_binding(message, abm)
return abm # 别忘了返回转换后的消息对象
@@ -270,13 +351,23 @@ class DingtalkPlatformAdapter(Platform):
)
return ""
resp_data = await resp.json()
download_url = resp_data["data"]["downloadUrl"]
download_url = cast(
str,
(
resp_data.get("downloadUrl")
or resp_data.get("data", {}).get("downloadUrl")
or ""
),
)
if not download_url:
logger.error(f"下载钉钉文件失败: 未找到 downloadUrl, 响应: {resp_data}")
return ""
await download_file(download_url, str(f_path))
return str(f_path)
async def get_access_token(self) -> str:
try:
access_token = await asyncio.get_event_loop().run_in_executor(
access_token = await asyncio.get_running_loop().run_in_executor(
None,
self.client_.get_access_token,
)
@@ -541,6 +632,28 @@ class DingtalkPlatformAdapter(Platform):
self._safe_remove_file(cover_path)
if converted_video:
self._safe_remove_file(video_path)
elif isinstance(segment, File):
try:
file_path = await segment.get_file()
if not file_path:
logger.warning("钉钉文件发送失败: 无法解析文件路径")
continue
media_id = await self.upload_media(file_path, "file")
if not media_id:
continue
file_name = segment.name or Path(file_path).name
file_type = Path(file_name).suffix.lstrip(".")
await send_message(
msg_key="sampleFile",
msg_param={
"mediaId": media_id,
"fileName": file_name,
"fileType": file_type,
},
)
except Exception as e:
logger.warning(f"钉钉文件发送失败: {e}")
continue
async def send_message_chain_to_group(
self,
@@ -647,7 +760,7 @@ class DingtalkPlatformAdapter(Platform):
return
logger.error(f"钉钉机器人启动失败: {e}")
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, start_client, loop)
async def terminate(self) -> None:
@@ -0,0 +1,371 @@
import asyncio
import json
import re
from astrbot import logger
from astrbot.api.event import MessageChain
from astrbot.api.message_components import At, AtAll, Image, Plain
from astrbot.api.platform import (
AstrBotMessage,
MessageMember,
MessageType,
Platform,
PlatformMetadata,
register_platform_adapter,
)
from astrbot.core.platform.astr_message_event import MessageSesion
from .kook_client import KookClient
from .kook_config import KookConfig
from .kook_event import KookEvent
@register_platform_adapter(
"kook",
"KOOK 适配器",
)
class KookPlatformAdapter(Platform):
def __init__(
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
) -> None:
super().__init__(platform_config, event_queue)
self.kook_config = KookConfig.from_dict(platform_config)
logger.debug(f"[KOOK] 配置: {self.kook_config.pretty_jsons()}")
self.settings = platform_settings
self.client = KookClient(self.kook_config, self._on_received)
self._reconnect_task = None
self.running = False
self._main_task = None
async def send_by_session(
self, session: MessageSesion, message_chain: MessageChain
):
inner_message = AstrBotMessage()
inner_message.session_id = session.session_id
inner_message.type = session.message_type
message_event = KookEvent(
message_str=message_chain.get_plain_text(),
message_obj=inner_message,
platform_meta=self.meta(),
session_id=session.session_id,
client=self.client,
)
await message_event.send(message_chain)
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
name="kook", description="KOOK 适配器", id=self.kook_config.id
)
def _should_ignore_event_by_bot_nickname(self, payload: dict) -> bool:
bot_nickname = self.kook_config.bot_nickname.strip()
if not bot_nickname:
return False
author = payload.get("extra", {}).get("author", {})
if not isinstance(author, dict):
return False
author_nickname = author.get("nickname") or author.get("username") or ""
if not isinstance(author_nickname, str):
author_nickname = str(author_nickname)
return author_nickname.strip().casefold() == bot_nickname.casefold()
async def _on_received(self, data: dict):
logger.debug(f"KOOK 收到数据: {data}")
if "d" in data and data["s"] == 0:
payload = data["d"]
event_type = payload.get("type")
# 支持type=9(文本)和type=10(卡片)
if event_type in (9, 10):
if self._should_ignore_event_by_bot_nickname(payload):
return
try:
abm = await self.convert_message(payload)
await self.handle_msg(abm)
except Exception as e:
logger.error(f"[KOOK] 消息处理异常: {e}")
async def run(self):
"""主运行循环"""
self.running = True
logger.info("[KOOK] 启动KOOK适配器")
# 启动主循环
self._main_task = asyncio.create_task(self._main_loop())
try:
await self._main_task
except asyncio.CancelledError:
logger.info("[KOOK] 适配器被取消")
except Exception as e:
logger.error(f"[KOOK] 适配器运行异常: {e}")
finally:
self.running = False
await self._cleanup()
async def _main_loop(self):
"""主循环,处理连接和重连"""
consecutive_failures = 0
max_consecutive_failures = self.kook_config.max_consecutive_failures
max_retry_delay = self.kook_config.max_retry_delay
while self.running:
try:
logger.info("[KOOK] 尝试连接KOOK服务器...")
# 尝试连接
success = await self.client.connect()
if success:
logger.info("[KOOK] 连接成功,开始监听消息")
consecutive_failures = 0 # 重置失败计数
# 等待连接结束(可能是正常关闭或异常)
while self.client.running and self.running:
try:
# 等待 client 内部触发 _stop_event,或者超时 1 秒后重试
# 使用 wait_for 配合 timeout 是为了防止极端情况下 self.running 变化没被察觉
await asyncio.wait_for(
self.client.wait_until_closed(), timeout=1.0
)
except asyncio.TimeoutError:
# 正常超时,继续下一轮 while 检查
continue
if self.running:
logger.warning("[KOOK] 连接断开,准备重连")
else:
consecutive_failures += 1
logger.error(
f"[KOOK] 连接失败,连续失败次数: {consecutive_failures}"
)
if consecutive_failures >= max_consecutive_failures:
logger.error("[KOOK] 连续失败次数过多,停止重连")
break
# 等待一段时间后重试
wait_time = min(
2**consecutive_failures, max_retry_delay
) # 指数退避
logger.info(f"[KOOK] 等待 {wait_time} 秒后重试...")
await asyncio.sleep(wait_time)
except Exception as e:
consecutive_failures += 1
logger.error(f"[KOOK] 主循环异常: {e}")
if consecutive_failures >= max_consecutive_failures:
logger.error("[KOOK] 连续异常次数过多,停止重连")
break
await asyncio.sleep(5)
async def _cleanup(self):
"""清理资源"""
logger.info("[KOOK] 开始清理资源")
if self.client:
try:
await self.client.close()
except Exception as e:
logger.error(f"[KOOK] 关闭客户端异常: {e}")
if self._main_task and not self._main_task.done():
self._main_task.cancel()
try:
await self._main_task
except asyncio.CancelledError:
pass
logger.info("[KOOK] 资源清理完成")
def _parse_kmarkdown_text_message(
self, data: dict, self_id: str
) -> tuple[list, str]:
kmarkdown = data.get("extra", {}).get("kmarkdown", {})
content = data.get("content") or ""
raw_content = kmarkdown.get("raw_content") or content
if not isinstance(content, str):
content = str(content)
if not isinstance(raw_content, str):
raw_content = str(raw_content)
mention_name_map: dict[str, str] = {}
mention_part = kmarkdown.get("mention_part", [])
if isinstance(mention_part, list):
for item in mention_part:
if not isinstance(item, dict):
continue
mention_id = item.get("id")
if mention_id is None:
continue
mention_name_map[str(mention_id)] = str(item.get("username", ""))
components = []
cursor = 0
for match in re.finditer(r"\(met\)([^()]+)\(met\)", content):
if match.start() > cursor:
plain_text = content[cursor : match.start()]
if plain_text:
components.append(Plain(text=plain_text))
mention_target = match.group(1).strip()
if mention_target == "all":
components.append(AtAll())
elif mention_target:
components.append(
At(
qq=mention_target,
name=mention_name_map.get(mention_target, ""),
)
)
cursor = match.end()
if cursor < len(content):
tail_text = content[cursor:]
if tail_text:
components.append(Plain(text=tail_text))
message_str = raw_content
if components:
for comp in components:
if isinstance(comp, Plain):
if not comp.text.strip():
continue
break
if isinstance(comp, At):
if str(comp.qq) == str(self_id):
message_str = re.sub(
r"^@[^\s]+(\s*-\s*[^\s]+)?\s*",
"",
message_str,
count=1,
).strip()
break
if not components:
if message_str:
components = [Plain(text=message_str)]
else:
components = []
return components, message_str
def _parse_card_message(self, data: dict) -> tuple[list, str]:
content = data.get("content", "[]")
if not isinstance(content, str):
content = str(content)
card_list = json.loads(content)
text_parts: list[str] = []
images: list[str] = []
for card in card_list:
if not isinstance(card, dict):
continue
for module in card.get("modules", []):
if not isinstance(module, dict):
continue
module_type = module.get("type")
if module_type == "section":
section_text = module.get("text", {}).get("content", "")
if section_text:
text_parts.append(str(section_text))
continue
if module_type != "container":
continue
for element in module.get("elements", []):
if not isinstance(element, dict):
continue
if element.get("type") != "image":
continue
image_src = element.get("src")
if not isinstance(image_src, str):
logger.warning(
f'[KOOK] 处理卡片中的图片时发生错误,图片url "{image_src}" 应该为str类型, 而不是 "{type(image_src)}" '
)
continue
if not image_src.startswith(("http://", "https://")):
logger.warning(f"[KOOK] 屏蔽非http图片url: {image_src}")
continue
images.append(image_src)
text = "".join(text_parts)
message = []
if text:
message.append(Plain(text=text))
for img_url in images:
message.append(Image(file=img_url))
return message, text
async def convert_message(self, data: dict) -> AstrBotMessage:
abm = AstrBotMessage()
abm.raw_message = data
abm.self_id = self.client.bot_id
channel_type = data.get("channel_type")
author_id = data.get("author_id", "unknown")
# channel_type定义: https://developer.kookapp.cn/doc/event/event-introduction
match channel_type:
case "GROUP":
session_id = data.get("target_id") or "unknown"
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = session_id
abm.session_id = session_id
case "PERSON":
abm.type = MessageType.FRIEND_MESSAGE
abm.group_id = ""
abm.session_id = data.get("author_id", "unknown")
case "BROADCAST":
session_id = data.get("target_id") or "unknown"
abm.type = MessageType.OTHER_MESSAGE
abm.group_id = session_id
abm.session_id = session_id
case _:
raise ValueError(f"不支持的频道类型: {channel_type}")
abm.sender = MessageMember(
user_id=author_id,
nickname=data.get("extra", {}).get("author", {}).get("username", ""),
)
abm.message_id = data.get("msg_id", "unknown")
# 普通文本消息
if data.get("type") == 9:
message, message_str = self._parse_kmarkdown_text_message(
data, str(abm.self_id)
)
abm.message = message
abm.message_str = message_str
# 卡片消息
elif data.get("type") == 10:
try:
abm.message, abm.message_str = self._parse_card_message(data)
except Exception as exp:
logger.error(f"[KOOK] 卡片消息解析失败: {exp}")
abm.message_str = "[卡片消息解析失败]"
abm.message = [Plain(text="[卡片消息解析失败]")]
else:
logger.warning(f'[KOOK] 不支持的kook消息类型: "{data.get("type")}"')
abm.message_str = "[不支持的消息类型]"
abm.message = [Plain(text="[不支持的消息类型]")]
return abm
async def handle_msg(self, message: AstrBotMessage):
message_event = KookEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self.client,
)
self.commit_event(message_event)
@@ -0,0 +1,437 @@
import asyncio
import base64
import json
import os
import random
import time
import zlib
from pathlib import Path
import aiofiles
import aiohttp
import websockets
from astrbot import logger
from astrbot.core.platform.message_type import MessageType
from .kook_config import KookConfig
from .kook_types import KookApiPaths, KookMessageType
class KookClient:
def __init__(self, config: KookConfig, event_callback):
# 数据字段
self.config = config
self._bot_id = ""
self._bot_name = ""
# 资源字段
self._http_client = aiohttp.ClientSession(
headers={
"Authorization": f"Bot {self.config.token}",
}
)
self.event_callback = event_callback # 回调函数,用于处理接收到的事件
self.ws = None
self.heartbeat_task = None
self._stop_event = asyncio.Event() # 用于通知连接结束
# 状态/计算字段
self.running = False
self.session_id = None
self.last_sn = 0 # 记录最后处理的消息序号
self.last_heartbeat_time = 0
self.heartbeat_failed_count = 0
@property
def bot_id(self):
return self._bot_id
@property
def bot_name(self):
return self._bot_name
async def get_bot_info(self) -> str:
"""获取机器人账号ID"""
url = KookApiPaths.USER_ME
try:
async with self._http_client.get(url) as resp:
if resp.status != 200:
logger.error(f"[KOOK] 获取机器人账号ID失败,状态码: {resp.status}")
return ""
data = await resp.json()
if data.get("code") != 0:
logger.error(f"[KOOK] 获取机器人账号ID失败: {data}")
return ""
bot_id: str = data["data"]["id"]
self._bot_id = bot_id
logger.info(f"[KOOK] 获取机器人账号ID成功: {bot_id}")
bot_name: str = data["data"]["nickname"] or data["data"]["username"]
self._bot_name = bot_name
logger.info(f"[KOOK] 获取机器人名称成功: {self._bot_name}")
return bot_id
except Exception as e:
logger.error(f"[KOOK] 获取机器人账号ID异常: {e}")
return ""
async def get_gateway_url(self, resume=False, sn=0, session_id=None):
"""获取网关连接地址"""
url = KookApiPaths.GATEWAY_INDEX
# 构建连接参数
params = {}
if resume:
params["resume"] = 1
params["sn"] = sn
if session_id:
params["session_id"] = session_id
try:
async with self._http_client.get(url, params=params) as resp:
if resp.status != 200:
logger.error(f"[KOOK] 获取gateway失败,状态码: {resp.status}")
return None
data = await resp.json()
if data.get("code") != 0:
logger.error(f"[KOOK] 获取gateway失败: {data}")
return None
gateway_url: str = data["data"]["url"]
logger.info(f"[KOOK] 获取gateway成功: {gateway_url.split('?')[0]}")
return gateway_url
except Exception as e:
logger.error(f"[KOOK] 获取gateway异常: {e}")
return None
async def connect(self, resume=False):
"""连接WebSocket"""
if self.ws:
try:
await self.ws.close()
except Exception:
pass
self.ws = None
self._stop_event.clear()
try:
# 获取gateway地址
gateway_url = await self.get_gateway_url(
resume=resume, sn=self.last_sn, session_id=self.session_id
)
await self.get_bot_info()
if not gateway_url:
return False
# 连接WebSocket
self.ws = await websockets.connect(gateway_url)
self.running = True
logger.info("[KOOK] WebSocket 连接成功")
# 启动心跳任务
if self.heartbeat_task:
self.heartbeat_task.cancel()
self.heartbeat_task = asyncio.create_task(self._heartbeat_loop())
# 开始监听消息
await self.listen()
return True
except Exception as e:
logger.error(f"[KOOK] WebSocket 连接失败: {e}")
if self.ws:
try:
await self.ws.close()
except Exception:
pass
self.ws = None
return False
async def listen(self):
"""监听WebSocket消息"""
try:
while self.running:
try:
msg = await asyncio.wait_for(self.ws.recv(), timeout=10) # type: ignore
if isinstance(msg, bytes):
try:
msg = zlib.decompress(msg)
except Exception as e:
logger.error(f"[KOOK] 解压消息失败: {e}")
continue
msg = msg.decode("utf-8")
data = json.loads(msg)
# 处理不同类型的信令
await self._handle_signal(data)
except asyncio.TimeoutError:
# 超时检查,继续循环
continue
except websockets.exceptions.ConnectionClosed:
logger.warning("[KOOK] WebSocket连接已关闭")
break
except Exception as e:
logger.error(f"[KOOK] 消息处理异常: {e}")
break
except Exception as e:
logger.error(f"[KOOK] WebSocket 监听异常: {e}")
finally:
self.running = False
self._stop_event.set()
async def _handle_signal(self, data):
"""处理不同类型的信令"""
signal_type = data.get("s")
if signal_type == 0: # 事件消息
# 更新消息序号
if "sn" in data:
self.last_sn = data["sn"]
await self.event_callback(data)
elif signal_type == 1: # HELLO握手
await self._handle_hello(data)
elif signal_type == 3: # PONG心跳响应
await self._handle_pong(data)
elif signal_type == 5: # RECONNECT重连指令
await self._handle_reconnect(data)
elif signal_type == 6: # RESUME ACK
await self._handle_resume_ack(data)
else:
logger.debug(f"[KOOK] 未处理的信令类型: {signal_type}")
async def _handle_hello(self, data):
"""处理HELLO握手"""
hello_data = data.get("d", {})
code = hello_data.get("code", 0)
if code == 0:
self.session_id = hello_data.get("session_id")
logger.info(f"[KOOK] 握手成功,session_id: {self.session_id}")
# TODO 重置重连延迟
# self.reconnect_delay = 1
else:
logger.error(f"[KOOK] 握手失败,错误码: {code}")
if code == 40103: # token过期
logger.error("[KOOK] Token已过期,需要重新获取")
self.running = False
async def _handle_pong(self, data):
"""处理PONG心跳响应"""
self.last_heartbeat_time = time.time()
self.heartbeat_failed_count = 0
async def _handle_reconnect(self, data):
"""处理重连指令"""
logger.warning("[KOOK] 收到重连指令")
# 清空本地状态
self.last_sn = 0
self.session_id = None
self.running = False
async def _handle_resume_ack(self, data):
"""处理RESUME确认"""
resume_data = data.get("d", {})
self.session_id = resume_data.get("session_id")
logger.info(f"[KOOK] Resume成功,session_id: {self.session_id}")
async def _heartbeat_loop(self):
"""心跳循环"""
while self.running:
try:
# 随机化心跳间隔 (±5秒)
interval = max(
1, self.config.heartbeat_interval + random.randint(-5, 5)
)
await asyncio.sleep(interval)
if not self.running:
break
# 发送心跳
await self._send_ping()
# 等待PONG响应
await asyncio.sleep(self.config.heartbeat_timeout)
# 检查是否收到PONG响应
if (
time.time() - self.last_heartbeat_time
> self.config.heartbeat_timeout
):
self.heartbeat_failed_count += 1
logger.warning(
f"[KOOK] 心跳超时,失败次数: {self.heartbeat_failed_count}"
)
if (
self.heartbeat_failed_count
>= self.config.max_heartbeat_failures
):
logger.error("[KOOK] 心跳失败次数过多,准备重连")
self.running = False
break
except asyncio.CancelledError:
break
except Exception as e:
logger.error(f"[KOOK] 心跳异常: {e}")
self.heartbeat_failed_count += 1
async def _send_ping(self):
"""发送心跳PING"""
try:
ping_data = {"s": 2, "sn": self.last_sn}
await self.ws.send(json.dumps(ping_data)) # type: ignore
except Exception as e:
logger.error(f"[KOOK] 发送心跳失败: {e}")
async def send_text(
self,
target_id: str,
content: str,
astrbot_message_type: MessageType,
kook_message_type: KookMessageType,
reply_message_id: str | int = "",
):
"""发送文本消息
消息发送接口文档参见: https://developer.kookapp.cn/doc/http/message#%E5%8F%91%E9%80%81%E9%A2%91%E9%81%93%E8%81%8A%E5%A4%A9%E6%B6%88%E6%81%AF
KMarkdown格式参见: https://developer.kookapp.cn/doc/kmarkdown-desc
"""
url = KookApiPaths.CHANNEL_MESSAGE_CREATE
if astrbot_message_type == MessageType.FRIEND_MESSAGE:
url = KookApiPaths.DIRECT_MESSAGE_CREATE
payload = {
"target_id": target_id,
"content": content,
"type": kook_message_type,
}
if reply_message_id:
payload["quote"] = reply_message_id
payload["reply_msg_id"] = reply_message_id
try:
async with self._http_client.post(url, json=payload) as resp:
if resp.status == 200:
result = await resp.json()
if result.get("code") != 0:
raise RuntimeError(
f'发送kook消息类型 "{kook_message_type.name}" 失败: {result}'
)
# else:
# logger.info("[KOOK] 发送消息成功")
else:
raise RuntimeError(
f'发送kook消息类型 "{kook_message_type.name}" HTTP错误: {resp.status} , 响应内容 : {await resp.text()}'
)
except RuntimeError:
raise
except Exception as e:
logger.error(
f'[KOOK] 发送kook消息类型 "{kook_message_type.name}" 异常: {e}'
)
async def upload_asset(self, file_url: str | None) -> str:
"""上传文件到kook,获得远端资源url
接口定义参见: https://developer.kookapp.cn/doc/http/asset
"""
if not file_url:
return ""
bytes_data: bytes | None = None
filename = "unknown"
if file_url.startswith(("http://", "https://")):
filename = file_url.split("/")[-1]
return file_url
if file_url.startswith("base64:///"):
# b64decode的时候得开头留一个'/'的, 不然会报错
b64_str = file_url.removeprefix("base64://")
bytes_data = base64.b64decode(b64_str)
elif file_url.startswith("file://") or os.path.exists(file_url):
file_url = file_url.removeprefix("file:///")
file_url = file_url.removeprefix("file://")
try:
target_path = Path(file_url).resolve()
except Exception as exp:
logger.error(f'[KOOK] 获取文件 "{file_url}" 绝对路径失败: "{exp}"')
raise FileNotFoundError(
f'获取文件 "{file_url}" 绝对路径失败: "{exp}"'
) from exp
if not target_path.is_file():
raise FileNotFoundError(f"文件不存在: {target_path.name}")
filename = target_path.name
async with aiofiles.open(target_path, "rb") as f:
bytes_data = await f.read()
else:
raise ValueError(f'[KOOK] 不支持的文件资源类型: "{file_url}"')
data = aiohttp.FormData()
data.add_field("file", bytes_data, filename=filename)
url = KookApiPaths.ASSET_CREATE
try:
async with self._http_client.post(url, data=data) as resp:
if resp.status == 200:
result: dict = await resp.json()
logger.debug(f"[KOOK] 上传文件响应: {result}")
if result.get("code") == 0:
logger.info("[KOOK] 上传文件到kook服务器成功")
remote_url = result["data"]["url"]
logger.debug(f"[KOOK] 文件远端URL: {remote_url}")
return remote_url
else:
raise RuntimeError(f"上传文件到kook服务器失败: {result}")
else:
raise RuntimeError(
f"上传文件到kook服务器 HTTP错误: {resp.status} , {await resp.text()}"
)
except RuntimeError:
raise
except Exception as e:
raise RuntimeError(f"上传文件到kook服务器异常: {e}") from e
async def wait_until_closed(self):
"""提供给外部调用的等待方法"""
await self._stop_event.wait()
async def close(self):
"""关闭连接"""
self.running = False
self._stop_event.set()
if self.heartbeat_task:
self.heartbeat_task.cancel()
try:
await self.heartbeat_task
except asyncio.CancelledError:
pass
if self.ws:
try:
await self.ws.close()
except Exception as e:
logger.error(f"[KOOK] 关闭WebSocket异常: {e}")
if self._http_client:
await self._http_client.close()
logger.info("[KOOK] 连接已关闭")
@@ -0,0 +1,133 @@
import json
from dataclasses import asdict, dataclass
from typing import Any
@dataclass
class KookConfig:
"""KOOK 适配器配置类"""
# 基础配置
token: str
bot_nickname: str = ""
enable: bool = False
id: str = "kook"
# 重连配置
reconnect_delay: int = 1
"""重连延迟基数(秒),指数退避"""
max_reconnect_delay: int = 60
"""最大重连延迟(秒)"""
max_retry_delay: int = 60
"""最大重试延迟(秒)"""
# 心跳配置
heartbeat_interval: int = 30
"""心跳间隔(秒)"""
heartbeat_timeout: int = 6
"""心跳超时时间(秒)"""
max_heartbeat_failures: int = 3
"""最大心跳失败次数"""
# 失败处理
max_consecutive_failures: int = 5
"""最大连续失败次数"""
@classmethod
def from_dict(cls, config_dict: dict) -> "KookConfig":
"""从字典创建配置对象"""
return cls(
# 适配器id 应该是不能改的
# id=config_dict.get("id", "kook"),
enable=config_dict.get("enable", False),
token=config_dict.get("kook_bot_token", ""),
bot_nickname=config_dict.get("kook_bot_nickname", ""),
reconnect_delay=config_dict.get(
"kook_reconnect_delay",
KookConfig.reconnect_delay,
),
max_reconnect_delay=config_dict.get(
"kook_max_reconnect_delay",
KookConfig.max_reconnect_delay,
),
max_retry_delay=config_dict.get(
"kook_max_retry_delay",
KookConfig.max_retry_delay,
),
heartbeat_interval=config_dict.get(
"kook_heartbeat_interval",
KookConfig.heartbeat_interval,
),
heartbeat_timeout=config_dict.get(
"kook_heartbeat_timeout",
KookConfig.heartbeat_timeout,
),
max_heartbeat_failures=config_dict.get(
"kook_max_heartbeat_failures",
KookConfig.max_heartbeat_failures,
),
max_consecutive_failures=config_dict.get(
"kook_max_consecutive_failures",
KookConfig.max_consecutive_failures,
),
)
def to_dict(self) -> dict[str, Any]:
return asdict(self)
def pretty_jsons(self, indent=2) -> str:
dict_config = self.to_dict()
dict_config["token"] = "*" * len(self.token) if self.token else "MISSING"
return json.dumps(dict_config, indent=indent, ensure_ascii=False)
# TODO 没用上的config配置,未来有空会实现这些配置描述的功能?
# # 连接配置
# CONNECTION_CONFIG = {
# # 心跳配置
# "heartbeat_interval": 30, # 心跳间隔(秒)
# "heartbeat_timeout": 6, # 心跳超时时间(秒)
# "max_heartbeat_failures": 3, # 最大心跳失败次数
# # 重连配置
# "initial_reconnect_delay": 1, # 初始重连延迟(秒)
# "max_reconnect_delay": 60, # 最大重连延迟(秒)
# "max_consecutive_failures": 5, # 最大连续失败次数
# # WebSocket配置
# "websocket_timeout": 10, # WebSocket接收超时(秒)
# "connection_timeout": 30, # 连接超时(秒)
# # 消息处理配置
# "enable_compression": True, # 是否启用消息压缩
# "max_message_size": 1024 * 1024, # 最大消息大小(字节)
# }
# # 日志配置
# LOGGING_CONFIG = {
# "level": "INFO", # 日志级别:DEBUG, INFO, WARNING, ERROR
# "format": "[KOOK] %(message)s",
# "enable_heartbeat_logs": False, # 是否启用心跳日志
# "enable_message_logs": False, # 是否启用消息日志
# }
# # 错误处理配置
# ERROR_HANDLING_CONFIG = {
# "retry_on_network_error": True, # 网络错误时是否重试
# "retry_on_token_expired": True, # Token过期时是否重试
# "max_retry_attempts": 3, # 最大重试次数
# "retry_delay_base": 2, # 重试延迟基数(秒)
# }
# # 性能配置
# PERFORMANCE_CONFIG = {
# "enable_message_buffering": True, # 是否启用消息缓冲
# "buffer_size": 100, # 缓冲区大小
# "enable_connection_pooling": True, # 是否启用连接池
# "max_concurrent_requests": 10, # 最大并发请求数
# }
# # 安全配置
# SECURITY_CONFIG = {
# "verify_ssl": True, # 是否验证SSL证书
# "enable_rate_limiting": True, # 是否启用速率限制
# "rate_limit_requests": 100, # 速率限制请求数
# "rate_limit_window": 60, # 速率限制窗口(秒)
# }
@@ -0,0 +1,209 @@
import asyncio
import json
from collections.abc import Coroutine
from pathlib import Path
from typing import Any
from astrbot import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.core.message.components import (
At,
AtAll,
BaseMessageComponent,
File,
Image,
Json,
Plain,
Record,
Reply,
Video,
)
from astrbot.core.platform import MessageType
from .kook_client import KookClient
from .kook_types import (
FileModule,
KookCardMessage,
KookCardMessageContainer,
KookMessageType,
OrderMessage,
)
class KookEvent(AstrMessageEvent):
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
client: KookClient,
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
self.channel_id = message_obj.group_id or message_obj.session_id
self.astrbot_message_type: MessageType = message_obj.type
self._file_message_counter = 0
def _wrap_message(
self, index: int, message_component: BaseMessageComponent
) -> Coroutine[Any, Any, OrderMessage]:
async def wrap_upload(
index: int, message_type: KookMessageType, upload_coro
) -> OrderMessage:
url = await upload_coro
return OrderMessage(index=index, text=url, type=message_type)
async def handle_plain(
index: int,
text: str | None,
reply_id: str | int = "",
type: KookMessageType = KookMessageType.KMARKDOWN,
):
if not text:
text = ""
return OrderMessage(
index=index,
text=text,
type=type,
reply_id=reply_id,
)
match message_component:
case Image():
self._file_message_counter += 1
return wrap_upload(
index,
KookMessageType.IMAGE,
self.client.upload_asset(message_component.file),
)
case Video():
self._file_message_counter += 1
return wrap_upload(
index,
KookMessageType.VIDEO,
self.client.upload_asset(message_component.file),
)
case File():
async def handle_file(index: int, f_item: File):
f_data = await f_item.get_file()
url = await self.client.upload_asset(f_data)
return OrderMessage(
index=index, text=url, type=KookMessageType.FILE
)
self._file_message_counter += 1
return handle_file(index, message_component)
case Record():
async def handle_audio(index: int, f_item: Record):
file_path = await f_item.convert_to_file_path()
url = await self.client.upload_asset(file_path)
title = f_item.text or Path(file_path).name
return OrderMessage(
index=index,
text=KookCardMessageContainer(
[
KookCardMessage(
modules=[
FileModule(
type="audio",
title=title,
src=url,
)
]
)
]
).to_json(),
type=KookMessageType.CARD,
)
return handle_audio(index, message_component)
case Plain():
return handle_plain(index, message_component.text)
case At():
return handle_plain(index, f"(met){message_component.qq}(met)")
case AtAll():
return handle_plain(index, "(met)all(met)")
case Reply():
return handle_plain(index, "", reply_id=message_component.id)
case Json():
json_data = message_component.data
# kook卡片json外层得是一个列表
if isinstance(json_data, dict):
json_data = [json_data]
return handle_plain(
index,
# 考虑到kook可能会更改消息结构,为了能让插件开发者
# 自行根据kook文档描述填卡片json内容,故不做模型校验
# KookCardMessage().model_validate(message_component.data).to_json(),
text=json.dumps(json_data),
type=KookMessageType.CARD,
)
case _:
raise NotImplementedError(
f'kook适配器尚未实现对 "{message_component.type}" 消息类型的支持'
)
async def send(self, message: MessageChain):
file_upload_tasks: list[Coroutine[Any, Any, OrderMessage]] = []
for index, item in enumerate(message.chain):
file_upload_tasks.append(self._wrap_message(index, item))
if self._file_message_counter > 0:
logger.debug("[Kook] 正在向kook服务器上传文件")
tasks_result = await asyncio.gather(*file_upload_tasks, return_exceptions=True)
order_messages: list[OrderMessage] = []
for index, result in enumerate(tasks_result):
if isinstance(result, BaseException):
logger.error(f"[Kook] {result}")
# 构造一个虚假的 OrderMessage,让用户知道这里本来有张图但坏了
# 这样后面的 for 循环就能把它当成普通文本发出去
err_node = OrderMessage(
index=index,
text=str(result),
type=KookMessageType.TEXT,
)
order_messages.append(err_node)
else:
order_messages.append(result)
order_messages.sort(key=lambda x: x.index)
reply_id: str | int = ""
errors: list[Exception] = []
for item in order_messages:
if item.reply_id:
reply_id = item.reply_id
if not item.text:
logger.debug(f'[Kook] 跳过空消息,类型为"{item.type}"')
continue
try:
await self.client.send_text(
self.channel_id,
item.text,
self.astrbot_message_type,
item.type,
reply_id,
)
except RuntimeError as exp:
await self.client.send_text(
self.channel_id,
str(exp),
self.astrbot_message_type,
KookMessageType.TEXT,
reply_id,
)
errors.append(exp)
if errors:
err_msg = "\n".join([str(err) for err in errors])
logger.error(f"[kook] {err_msg}")
await super().send(message)
@@ -0,0 +1,241 @@
import json
from dataclasses import field
from enum import IntEnum
from typing import Literal
from pydantic import BaseModel, ConfigDict
from pydantic.dataclasses import dataclass
class KookApiPaths:
"""Kook Api 路径"""
BASE_URL = "https://www.kookapp.cn"
API_VERSION_PATH = "/api/v3"
# 初始化相关
USER_ME = f"{BASE_URL}{API_VERSION_PATH}/user/me"
GATEWAY_INDEX = f"{BASE_URL}{API_VERSION_PATH}/gateway/index"
# 消息相关
ASSET_CREATE = f"{BASE_URL}{API_VERSION_PATH}/asset/create"
## 频道消息
CHANNEL_MESSAGE_CREATE = f"{BASE_URL}{API_VERSION_PATH}/message/create"
## 私聊消息
DIRECT_MESSAGE_CREATE = f"{BASE_URL}{API_VERSION_PATH}/direct-message/create"
# 定义参见kook事件结构文档: https://developer.kookapp.cn/doc/event/event-introduction
class KookMessageType(IntEnum):
TEXT = 1
IMAGE = 2
VIDEO = 3
FILE = 4
AUDIO = 8
KMARKDOWN = 9
CARD = 10
SYSTEM = 255
ThemeType = Literal[
"primary", "success", "danger", "warning", "info", "secondary", "none", "invisible"
]
"""主题,可选的值为:primary, success, danger, warning, info, secondary, none.默认为 primary,为 none 时不显示侧边框。"""
SizeType = Literal["xs", "sm", "md", "lg"]
"""大小,可选值为:xs, sm, md, lg, 一般默认为 lg"""
SectionMode = Literal["left", "right"]
CountdownMode = Literal["day", "hour", "second"]
class KookCardColor(str):
"""16 进制色值"""
class KookCardModelBase:
"""卡片模块基类"""
type: str
@dataclass
class PlainTextElement(KookCardModelBase):
content: str
type: str = "plain-text"
emoji: bool = True
@dataclass
class KmarkdownElement(KookCardModelBase):
content: str
type: str = "kmarkdown"
@dataclass
class ImageElement(KookCardModelBase):
src: str
type: str = "image"
alt: str = ""
size: SizeType = "lg"
circle: bool = False
fallbackUrl: str | None = None
@dataclass
class ButtonElement(KookCardModelBase):
text: str
type: str = "button"
theme: ThemeType = "primary"
value: str = ""
"""当为 link 时,会跳转到 value 代表的链接;
当为 return-val 系统会通过系统消息将消息 id,点击用户 id value 发回给发送者发送者可以根据自己的需求进行处理,消息事件参见button 点击事件私聊和频道内均可使用按钮点击事件"""
click: Literal["", "link", "return-val"] = ""
"""click 代表用户点击的事件,默认为"",代表无任何事件。"""
AnyElement = PlainTextElement | KmarkdownElement | ImageElement | ButtonElement | str
@dataclass
class ParagraphStructure(KookCardModelBase):
fields: list[PlainTextElement | KmarkdownElement]
type: str = "paragraph"
cols: int = 1
"""范围是 1-3 , 移动端忽略此参数"""
@dataclass
class HeaderModule(KookCardModelBase):
text: PlainTextElement
type: str = "header"
@dataclass
class SectionModule(KookCardModelBase):
text: PlainTextElement | KmarkdownElement | ParagraphStructure
type: str = "section"
mode: SectionMode = "left"
accessory: ImageElement | ButtonElement | None = None
@dataclass
class ImageGroupModule(KookCardModelBase):
"""1 到多张图片的组合"""
elements: list[ImageElement]
type: str = "image-group"
@dataclass
class ContainerModule(KookCardModelBase):
"""1 到多张图片的组合,与图片组模块(ImageGroupModule)不同,图片并不会裁切为正方形。多张图片会纵向排列。"""
elements: list[ImageElement]
type: str = "container"
@dataclass
class ActionGroupModule(KookCardModelBase):
elements: list[ButtonElement]
type: str = "action-group"
@dataclass
class ContextModule(KookCardModelBase):
elements: list[PlainTextElement | KmarkdownElement | ImageElement]
"""最多包含10个元素"""
type: str = "context"
@dataclass
class DividerModule(KookCardModelBase):
type: str = "divider"
@dataclass
class FileModule(KookCardModelBase):
src: str
title: str = ""
type: Literal["file", "audio", "video"] = "file"
cover: str | None = None
"""cover 仅音频有效, 是音频的封面图"""
@dataclass
class CountdownModule(KookCardModelBase):
"""startTime 和 endTime 为毫秒时间戳,startTime 和 endTime 不能小于服务器当前时间戳。"""
endTime: int
"""毫秒时间戳"""
type: str = "countdown"
startTime: int | None = None
"""毫秒时间戳, 仅当mode为second才有这个字段"""
mode: CountdownMode = "day"
"""mode 主要是倒计时的样式"""
@dataclass
class InviteModule(KookCardModelBase):
code: str
"""邀请链接或者邀请码"""
type: str = "invite"
# 所有模块的联合类型
AnyModule = (
HeaderModule
| SectionModule
| ImageGroupModule
| ContainerModule
| ActionGroupModule
| ContextModule
| DividerModule
| FileModule
| CountdownModule
| InviteModule
)
class KookCardMessage(BaseModel):
"""卡片定义文档详见 : https://developer.kookapp.cn/doc/cardmessage
此类型不能直接to_json后发送,因为kook要求卡片容器json顶层必须是**列表**
若要发送卡片消息请使用KookCardMessageContainer
"""
model_config = ConfigDict(arbitrary_types_allowed=True)
type: str = "card"
theme: ThemeType | None = None
size: SizeType | None = None
color: KookCardColor | None = None
modules: list[AnyModule] = field(default_factory=list)
"""单个 card 模块数量不限制,但是一条消息中所有卡片的模块数量之和最多是 50"""
def add_module(self, module: AnyModule):
self.modules.append(module)
def to_dict(self, exclude_none: bool = True):
"""exclude_none:去掉值为 None 字段,保留结构"""
return self.model_dump(exclude_none=exclude_none)
def to_json(self, indent: int | None = None, ensure_ascii: bool = True):
return json.dumps(self.to_dict(), indent=indent, ensure_ascii=ensure_ascii)
class KookCardMessageContainer(list[KookCardMessage]):
"""卡片消息容器(列表),此类型可以直接to_json后发送出去"""
def append(self, object: KookCardMessage) -> None:
return super().append(object)
def to_json(self, indent: int | None = None, ensure_ascii: bool = True) -> str:
return json.dumps(
[i.to_dict() for i in self], indent=indent, ensure_ascii=ensure_ascii
)
@dataclass
class OrderMessage:
index: int
text: str
type: KookMessageType
reply_id: str | int = ""
@@ -34,7 +34,7 @@ from .server import LarkWebhookServer
@register_platform_adapter(
"lark", "飞书机器人官方 API 适配器", support_streaming_message=False
"lark", "飞书机器人官方 API 适配器", support_streaming_message=True
)
class LarkPlatformAdapter(Platform):
def __init__(
@@ -491,7 +491,7 @@ class LarkPlatformAdapter(Platform):
name="lark",
description="飞书机器人官方 API 适配器",
id=cast(str, self.config.get("id")),
support_streaming_message=False,
support_streaming_message=True,
)
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1) -> None:
@@ -1,3 +1,4 @@
import asyncio
import base64
import json
import os
@@ -5,6 +6,14 @@ import uuid
from io import BytesIO
import lark_oapi as lark
from lark_oapi.api.cardkit.v1 import (
ContentCardElementRequest,
ContentCardElementRequestBody,
CreateCardRequest,
CreateCardRequestBody,
SettingsCardRequest,
SettingsCardRequestBody,
)
from lark_oapi.api.im.v1 import (
CreateFileRequest,
CreateFileRequestBody,
@@ -28,6 +37,7 @@ from astrbot.core.utils.media_utils import (
convert_video_format,
get_media_duration,
)
from astrbot.core.utils.metrics import Metric
class LarkMessageEvent(AstrMessageEvent):
@@ -555,15 +565,257 @@ class LarkMessageEvent(AstrMessageEvent):
logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}")
return
async def send_streaming(self, generator, use_fallback: bool = False):
async def _create_streaming_card(self) -> str | None:
"""创建一个开启流式更新模式的卡片实体,返回 card_id。"""
if self.bot.cardkit is None:
logger.error("[Lark] API Client cardkit 模块未初始化")
return None
card_json = {
"schema": "2.0",
"header": {
"title": {"content": "", "tag": "plain_text"},
},
"config": {
"streaming_mode": True,
"summary": {"content": ""},
"streaming_config": {
"print_frequency_ms": {"default": 50},
"print_step": {"default": 2},
"print_strategy": "fast",
},
},
"body": {
"elements": [
{
"tag": "markdown",
"content": "",
"element_id": "markdown_1",
}
]
},
}
request = (
CreateCardRequest.builder()
.request_body(
CreateCardRequestBody.builder()
.type("card_json")
.data(json.dumps(card_json, ensure_ascii=False))
.build()
)
.build()
)
try:
response = await self.bot.cardkit.v1.card.acreate(request)
except Exception as e:
logger.error(f"[Lark] 创建流式卡片实体失败: {e}")
return None
if not response.success():
logger.error(
f"[Lark] 创建流式卡片实体失败({response.code}): {response.msg}"
)
return None
if response.data is None or not response.data.card_id:
logger.error("[Lark] 创建流式卡片实体成功但未返回 card_id")
return None
card_id = response.data.card_id
logger.debug(f"[Lark] 创建流式卡片实体成功: {card_id}")
return card_id
async def _send_card_message(
self,
card_id: str,
reply_message_id: str | None = None,
receive_id: str | None = None,
receive_id_type: str | None = None,
) -> bool:
"""将卡片实体作为 interactive 消息发送。"""
content = json.dumps(
{"type": "card", "data": {"card_id": card_id}},
ensure_ascii=False,
)
return await self._send_im_message(
self.bot,
content=content,
msg_type="interactive",
reply_message_id=reply_message_id,
receive_id=receive_id,
receive_id_type=receive_id_type,
)
async def _update_streaming_text(
self,
card_id: str,
content: str,
sequence: int,
) -> bool:
"""调用 CardKit 流式更新文本接口,向 markdown_1 组件推送全量文本。"""
if self.bot.cardkit is None:
logger.error("[Lark] API Client cardkit 模块未初始化")
return False
request = (
ContentCardElementRequest.builder()
.card_id(card_id)
.element_id("markdown_1")
.request_body(
ContentCardElementRequestBody.builder()
.content(content)
.sequence(sequence)
.uuid(str(uuid.uuid4()))
.build()
)
.build()
)
try:
response = await self.bot.cardkit.v1.card_element.acontent(request)
except Exception as e:
logger.debug(f"[Lark] 流式更新文本失败 (ignored): {e}")
return False
if not response.success():
logger.debug(f"[Lark] 流式更新文本失败({response.code}): {response.msg}")
return False
return True
async def _close_streaming_mode(
self,
card_id: str,
sequence: int,
) -> None:
"""关闭卡片的流式更新模式,使其可正常转发、摘要恢复。"""
if self.bot.cardkit is None:
logger.error("[Lark] API Client cardkit 模块未初始化")
return
settings_json = json.dumps(
{"config": {"streaming_mode": False}},
ensure_ascii=False,
)
request = (
SettingsCardRequest.builder()
.card_id(card_id)
.request_body(
SettingsCardRequestBody.builder()
.settings(settings_json)
.sequence(sequence)
.uuid(str(uuid.uuid4()))
.build()
)
.build()
)
try:
response = await self.bot.cardkit.v1.card.asettings(request)
except Exception as e:
logger.error(f"[Lark] 关闭流式模式失败: {e}")
return
if not response.success():
logger.error(f"[Lark] 关闭流式模式失败({response.code}): {response.msg}")
else:
logger.debug(f"[Lark] 流式模式已关闭: {card_id}")
async def _fallback_send_streaming(self, generator, use_fallback: bool = False):
"""回退到非流式发送:缓冲全部文本后一次性发送,并保留父类副作用。"""
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return None
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
if buffer:
buffer.squash_plain()
await self.send(buffer)
await Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
self._has_send_oper = True
async def send_streaming(self, generator, use_fallback: bool = False):
"""使用 CardKit 流式卡片实现打字机效果。
流程创建卡片实体 发送消息 流式更新文本 关闭流式模式
使用解耦发送循环LLM token 到达时只更新 buffer 并唤醒发送协程
发送频率由网络 RTT 自然限流
"""
# Step 1: 创建流式卡片实体
card_id = await self._create_streaming_card()
if not card_id:
logger.warning("[Lark] 无法创建流式卡片,回退到非流式发送")
await self._fallback_send_streaming(generator, use_fallback)
return
# Step 2: 发送卡片消息
sent = await self._send_card_message(
card_id,
reply_message_id=self.message_obj.message_id,
)
if not sent:
logger.error("[Lark] 发送流式卡片消息失败,回退到非流式发送")
await self._fallback_send_streaming(generator, use_fallback)
return
logger.info("[Lark] 流式输出: 使用 CardKit 流式卡片")
# Step 3: 解耦发送循环 (Event-driven, 参考 Telegram Draft 路径)
sequence = 0
delta = ""
last_sent = ""
done = False
text_changed = asyncio.Event()
async def _sender_loop() -> None:
"""信号驱动的文本发送循环,有新内容就发,RTT 自然限流。"""
nonlocal sequence, last_sent
while not done:
await text_changed.wait()
text_changed.clear()
snapshot = delta
if snapshot and snapshot != last_sent:
sequence += 1
ok = await self._update_streaming_text(card_id, snapshot, sequence)
if ok:
last_sent = snapshot
if delta != snapshot:
text_changed.set()
sender_task = asyncio.create_task(_sender_loop())
try:
async for chain in generator:
if not isinstance(chain, MessageChain):
continue
if chain.type == "break":
# 飞书卡片不支持分段,忽略 break
continue
for comp in chain.chain:
if isinstance(comp, Plain):
delta += comp.text
text_changed.set()
finally:
done = True
text_changed.set()
await sender_task
# Step 4: 必要时补发最终文本 + 关闭流式模式
if delta and delta != last_sent:
sequence += 1
await self._update_streaming_text(card_id, delta, sequence)
sequence += 1
await self._close_streaming_mode(card_id, sequence)
# Step 5: 内联父类 send_streaming 的副作用
await Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
self._has_send_oper = True
@@ -104,7 +104,7 @@ class LineMessageEvent(AstrMessageEvent):
@staticmethod
async def _resolve_image_url(segment: Image) -> str:
candidate = (segment.url or segment.file or "").strip()
if candidate.startswith("http://") or candidate.startswith("https://"):
if candidate.startswith("https://"):
return candidate
try:
return await segment.register_to_file_service()
@@ -115,7 +115,7 @@ class LineMessageEvent(AstrMessageEvent):
@staticmethod
async def _resolve_record_url(segment: Record) -> str:
candidate = (segment.url or segment.file or "").strip()
if candidate.startswith("http://") or candidate.startswith("https://"):
if candidate.startswith("https://"):
return candidate
try:
return await segment.register_to_file_service()
@@ -137,7 +137,7 @@ class LineMessageEvent(AstrMessageEvent):
@staticmethod
async def _resolve_video_url(segment: Video) -> str:
candidate = (segment.file or "").strip()
if candidate.startswith("http://") or candidate.startswith("https://"):
if candidate.startswith("https://"):
return candidate
try:
return await segment.register_to_file_service()
@@ -148,9 +148,7 @@ class LineMessageEvent(AstrMessageEvent):
@staticmethod
async def _resolve_video_preview_url(segment: Video) -> str:
cover_candidate = (segment.cover or "").strip()
if cover_candidate.startswith("http://") or cover_candidate.startswith(
"https://"
):
if cover_candidate.startswith("https://"):
return cover_candidate
if cover_candidate:
@@ -191,7 +189,7 @@ class LineMessageEvent(AstrMessageEvent):
@staticmethod
async def _resolve_file_url(segment: File) -> str:
if segment.url and segment.url.startswith(("http://", "https://")):
if segment.url and segment.url.startswith("https://"):
return segment.url
try:
return await segment.register_to_file_service()
@@ -18,7 +18,7 @@ from botpy.types.message import MarkdownPayload, Media
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Image, Plain, Record
from astrbot.api.message_components import File, Image, Plain, Record, Video
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.io import download_image_by_url, file_to_base64
@@ -47,6 +47,11 @@ _patch_qq_botpy_formdata()
class QQOfficialMessageEvent(AstrMessageEvent):
MARKDOWN_NOT_ALLOWED_ERROR = "不允许发送原生 markdown"
IMAGE_FILE_TYPE = 1
VIDEO_FILE_TYPE = 2
VOICE_FILE_TYPE = 3
FILE_FILE_TYPE = 4
STREAM_MARKDOWN_NEWLINE_ERROR = "流式消息md分片需要\\n结束"
def __init__(
self,
@@ -65,35 +70,71 @@ class QQOfficialMessageEvent(AstrMessageEvent):
await self._post_send()
async def send_streaming(self, generator, use_fallback: bool = False):
"""流式输出仅支持消息列表私聊"""
"""流式输出仅支持消息列表私聊C2C),其他消息源退化为普通发送"""
# 先标记事件层“已执行发送操作”,避免异常路径遗漏
await super().send_streaming(generator, use_fallback)
# QQ C2C 流式协议:开始/中间分片使用 state=1,结束分片使用 state=10
stream_payload = {"state": 1, "id": None, "index": 0, "reset": False}
last_edit_time = 0 # 上次编辑消息的时间
throttle_interval = 1 # 编辑消息的间隔时间 (秒)
last_edit_time = 0 # 上次发送分片的时间
throttle_interval = 1 # 分片间最短间隔 (秒)
ret = None
source = (
self.message_obj.raw_message
) # 提前获取,避免 generator 为空时 NameError
try:
async for chain in generator:
source = self.message_obj.raw_message
if not isinstance(source, botpy.message.C2CMessage):
# 非 C2C 场景:直接累积,最后统一发
if not self.send_buffer:
self.send_buffer = chain
else:
self.send_buffer.chain.extend(chain.chain)
continue
# ---- C2C 流式场景 ----
# tool_call break 信号:工具开始执行,先把已有 buffer 以 state=10 结束当前流式段
if chain.type == "break":
if self.send_buffer:
stream_payload["state"] = 10
ret = await self._post_send(stream=stream_payload)
ret_id = self._extract_response_message_id(ret)
if ret_id is not None:
stream_payload["id"] = ret_id
# 重置 stream_payload,为下一段流式做准备
stream_payload = {
"state": 1,
"id": None,
"index": 0,
"reset": False,
}
last_edit_time = 0
continue
# 累积内容
if not self.send_buffer:
self.send_buffer = chain
else:
self.send_buffer.chain.extend(chain.chain)
if isinstance(source, botpy.message.C2CMessage):
# 真流式传输
current_time = asyncio.get_event_loop().time()
time_since_last_edit = current_time - last_edit_time
if time_since_last_edit >= throttle_interval:
ret = cast(
message.Message,
await self._post_send(stream=stream_payload),
)
stream_payload["index"] += 1
stream_payload["id"] = ret["id"]
last_edit_time = asyncio.get_event_loop().time()
# 节流:按时间间隔发送中间分片
current_time = asyncio.get_running_loop().time()
if current_time - last_edit_time >= throttle_interval:
ret = cast(
message.Message,
await self._post_send(stream=stream_payload),
)
stream_payload["index"] += 1
ret_id = self._extract_response_message_id(ret)
if ret_id is not None:
stream_payload["id"] = ret_id
last_edit_time = asyncio.get_running_loop().time()
self.send_buffer = None # 清空已发送的分片,避免下次重复发送旧内容
if isinstance(source, botpy.message.C2CMessage):
# 结束流式对话,并且传输 buffer 中剩余的消息
# 结束流式对话,发送 buffer 中剩余内容
stream_payload["state"] = 10
ret = await self._post_send(stream=stream_payload)
else:
@@ -101,9 +142,22 @@ class QQOfficialMessageEvent(AstrMessageEvent):
except Exception as e:
logger.error(f"发送流式消息时出错: {e}", exc_info=True)
# 避免累计内容在异常后被整包重复发送:仅清理缓存,不做非流式整包兜底
# 如需兜底,应该只发送未发送 delta(后续可继续优化)
self.send_buffer = None
return await super().send_streaming(generator, use_fallback)
return None
@staticmethod
def _extract_response_message_id(ret) -> str | None:
"""兼容 qq-botpy 返回 Message 对象或 dict 两种形态。"""
if ret is None:
return None
if isinstance(ret, dict):
ret_id = ret.get("id")
return str(ret_id) if ret_id is not None else None
ret_id = getattr(ret, "id", None)
return str(ret_id) if ret_id is not None else None
async def _post_send(self, stream: dict | None = None):
if not self.send_buffer:
@@ -126,16 +180,37 @@ class QQOfficialMessageEvent(AstrMessageEvent):
image_base64,
image_path,
record_file_path,
video_file_source,
file_source,
file_name,
) = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer)
# C2C 流式仅用于文本分片,富媒体时降级为普通发送,避免平台侧流式校验报错。
if stream and (image_base64 or record_file_path):
logger.debug("[QQOfficial] 检测到富媒体,降级为非流式发送。")
stream = None
if (
not plain_text
and not image_base64
and not image_path
and not record_file_path
and not video_file_source
and not file_source
):
return None
# QQ C2C 流式 API 说明:
# - 开始/中间分片(state=1):增量追加内容,不需要 \n(加了会导致强制换行)
# - 最终分片(state=10):结束流,content 必须以 \n 结尾(QQ API 要求)
if (
stream
and stream.get("state") == 10
and plain_text
and not plain_text.endswith("\n")
):
plain_text = plain_text + "\n"
payload: dict = {
# "content": plain_text,
"markdown": MarkdownPayload(content=plain_text) if plain_text else None,
@@ -157,7 +232,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
if image_base64:
media = await self.upload_group_and_c2c_image(
image_base64,
1,
self.IMAGE_FILE_TYPE,
group_openid=source.group_openid,
)
payload["media"] = media
@@ -165,15 +240,39 @@ class QQOfficialMessageEvent(AstrMessageEvent):
payload.pop("markdown", None)
payload["content"] = plain_text or None
if record_file_path: # group record msg
media = await self.upload_group_and_c2c_record(
media = await self.upload_group_and_c2c_media(
record_file_path,
3,
self.VOICE_FILE_TYPE,
group_openid=source.group_openid,
)
payload["media"] = media
payload["msg_type"] = 7
payload.pop("markdown", None)
payload["content"] = plain_text or None
if media:
payload["media"] = media
payload["msg_type"] = 7
payload.pop("markdown", None)
payload["content"] = plain_text or None
if video_file_source:
media = await self.upload_group_and_c2c_media(
video_file_source,
self.VIDEO_FILE_TYPE,
group_openid=source.group_openid,
)
if media:
payload["media"] = media
payload["msg_type"] = 7
payload.pop("markdown", None)
payload["content"] = plain_text or None
if file_source:
media = await self.upload_group_and_c2c_media(
file_source,
self.FILE_FILE_TYPE,
file_name=file_name,
group_openid=source.group_openid,
)
if media:
payload["media"] = media
payload["msg_type"] = 7
payload.pop("markdown", None)
payload["content"] = plain_text or None
ret = await self._send_with_markdown_fallback(
send_func=lambda retry_payload: self.bot.api.post_group_message(
group_openid=source.group_openid, # type: ignore
@@ -181,13 +280,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
),
payload=payload,
plain_text=plain_text,
stream=stream,
)
case botpy.message.C2CMessage():
if image_base64:
media = await self.upload_group_and_c2c_image(
image_base64,
1,
self.IMAGE_FILE_TYPE,
openid=source.author.user_openid,
)
payload["media"] = media
@@ -195,15 +295,39 @@ class QQOfficialMessageEvent(AstrMessageEvent):
payload.pop("markdown", None)
payload["content"] = plain_text or None
if record_file_path: # c2c record
media = await self.upload_group_and_c2c_record(
media = await self.upload_group_and_c2c_media(
record_file_path,
3,
self.VOICE_FILE_TYPE,
openid=source.author.user_openid,
)
payload["media"] = media
payload["msg_type"] = 7
payload.pop("markdown", None)
payload["content"] = plain_text or None
if media:
payload["media"] = media
payload["msg_type"] = 7
payload.pop("markdown", None)
payload["content"] = plain_text or None
if video_file_source:
media = await self.upload_group_and_c2c_media(
video_file_source,
self.VIDEO_FILE_TYPE,
openid=source.author.user_openid,
)
if media:
payload["media"] = media
payload["msg_type"] = 7
payload.pop("markdown", None)
payload["content"] = plain_text or None
if file_source:
media = await self.upload_group_and_c2c_media(
file_source,
self.FILE_FILE_TYPE,
file_name=file_name,
openid=source.author.user_openid,
)
if media:
payload["media"] = media
payload["msg_type"] = 7
payload.pop("markdown", None)
payload["content"] = plain_text or None
if stream:
ret = await self._send_with_markdown_fallback(
send_func=lambda retry_payload: self.post_c2c_message(
@@ -213,6 +337,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
),
payload=payload,
plain_text=plain_text,
stream=stream,
)
else:
ret = await self._send_with_markdown_fallback(
@@ -222,6 +347,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
),
payload=payload,
plain_text=plain_text,
stream=stream,
)
logger.debug(f"Message sent to C2C: {ret}")
@@ -237,6 +363,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
),
payload=payload,
plain_text=plain_text,
stream=stream,
)
case botpy.message.DirectMessage():
@@ -251,6 +378,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
),
payload=payload,
plain_text=plain_text,
stream=stream,
)
case _:
@@ -267,10 +395,31 @@ class QQOfficialMessageEvent(AstrMessageEvent):
send_func,
payload: dict,
plain_text: str,
stream: dict | None = None,
):
try:
return await send_func(payload)
except botpy.errors.ServerError as err:
# QQ 流式 markdown 分片校验:内容必须以换行结尾。
# 某些边界场景服务端仍可能判定失败,这里做一次修正重试。
if stream and self.STREAM_MARKDOWN_NEWLINE_ERROR in str(err):
retry_payload = payload.copy()
markdown_payload = retry_payload.get("markdown")
if isinstance(markdown_payload, dict):
md_content = cast(str, markdown_payload.get("content", "") or "")
if md_content and not md_content.endswith("\n"):
retry_payload["markdown"] = {"content": md_content + "\n"}
content = cast(str | None, retry_payload.get("content"))
if content and not content.endswith("\n"):
retry_payload["content"] = content + "\n"
logger.warning(
"[QQOfficial] 流式 markdown 分片换行校验失败,已修正后重试一次。"
)
return await send_func(retry_payload)
if (
self.MARKDOWN_NOT_ALLOWED_ERROR not in str(err)
or not payload.get("markdown")
@@ -282,10 +431,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
"[QQOfficial] markdown 发送被拒绝,回退到 content 模式重试。"
)
fallback_payload = payload.copy()
fallback_payload["markdown"] = None
fallback_payload.pop("markdown", None)
fallback_payload["content"] = plain_text
if fallback_payload.get("msg_type") == 2:
fallback_payload["msg_type"] = 0
if stream:
fallback_content = cast(str, fallback_payload.get("content") or "")
if fallback_content and not fallback_content.endswith("\n"):
fallback_payload["content"] = fallback_content + "\n"
return await send_func(fallback_payload)
async def upload_group_and_c2c_image(
@@ -327,16 +480,19 @@ class QQOfficialMessageEvent(AstrMessageEvent):
ttl=result.get("ttl", 0),
)
async def upload_group_and_c2c_record(
async def upload_group_and_c2c_media(
self,
file_source: str,
file_type: int,
srv_send_msg: bool = False,
file_name: str | None = None,
**kwargs,
) -> Media | None:
"""上传媒体文件"""
# 构建基础payload
payload = {"file_type": file_type, "srv_send_msg": srv_send_msg}
if file_name:
payload["file_name"] = file_name
# 处理文件数据
if os.path.exists(file_source):
@@ -400,13 +556,21 @@ class QQOfficialMessageEvent(AstrMessageEvent):
) -> message.Message:
payload = locals()
payload.pop("self", None)
# QQ API does not accept stream.id=None; remove it when not yet assigned
if "stream" in payload and payload["stream"] is not None:
stream_data = dict(payload["stream"])
if stream_data.get("id") is None:
stream_data.pop("id", None)
payload["stream"] = stream_data
route = Route("POST", "/v2/users/{openid}/messages", openid=openid)
result = await self.bot.api._http.request(route, json=payload)
if result is None:
logger.warning("[QQOfficial] post_c2c_message: API 返回 None,跳过本次发送")
return None
if not isinstance(result, dict):
raise RuntimeError(
f"Failed to post c2c message, response is not dict: {result}"
)
logger.error(f"[QQOfficial] post_c2c_message: 响应不是 dict: {result}")
return None
return message.Message(**result)
@@ -416,6 +580,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
image_base64 = None # only one img supported
image_file_path = None
record_file_path = None
video_file_source = None
file_source = None
file_name = None
for i in message.chain:
if isinstance(i, Plain):
plain_text += i.text
@@ -454,6 +621,30 @@ class QQOfficialMessageEvent(AstrMessageEvent):
except Exception as e:
logger.error(f"处理语音时出错: {e}")
record_file_path = None
elif isinstance(i, Video) and not video_file_source:
if i.file.startswith("file:///"):
video_file_source = i.file[8:]
else:
video_file_source = i.file
elif isinstance(i, File) and not file_source:
file_name = i.name
if i.file_:
file_path = i.file_
if file_path.startswith("file:///"):
file_path = file_path[8:]
elif file_path.startswith("file://"):
file_path = file_path[7:]
file_source = file_path
elif i.url:
file_source = i.url
else:
logger.debug(f"qq_official 忽略 {i.type}")
return plain_text, image_base64, image_file_path, record_file_path
return (
plain_text,
image_base64,
image_file_path,
record_file_path,
video_file_source,
file_source,
file_name,
)
@@ -3,8 +3,10 @@ from __future__ import annotations
import asyncio
import logging
import os
import random
import time
from typing import cast
from types import SimpleNamespace
from typing import Any, cast
import botpy
import botpy.message
@@ -12,7 +14,7 @@ from botpy import Client
from astrbot import logger
from astrbot.api.event import MessageChain
from astrbot.api.message_components import At, File, Image, Plain
from astrbot.api.message_components import At, File, Image, Plain, Record, Video
from astrbot.api.platform import (
AstrBotMessage,
MessageMember,
@@ -46,6 +48,7 @@ class botClient(Client):
)
abm.group_id = cast(str, message.group_openid)
abm.session_id = abm.group_id
self.platform.remember_session_scene(abm.session_id, "group")
self._commit(abm)
# 收到频道消息
@@ -56,6 +59,7 @@ class botClient(Client):
)
abm.group_id = message.channel_id
abm.session_id = abm.group_id
self.platform.remember_session_scene(abm.session_id, "channel")
self._commit(abm)
# 收到私聊消息
@@ -67,6 +71,7 @@ class botClient(Client):
MessageType.FRIEND_MESSAGE,
)
abm.session_id = abm.sender.user_id
self.platform.remember_session_scene(abm.session_id, "friend")
self._commit(abm)
# 收到 C2C 消息
@@ -76,9 +81,11 @@ class botClient(Client):
MessageType.FRIEND_MESSAGE,
)
abm.session_id = abm.sender.user_id
self.platform.remember_session_scene(abm.session_id, "friend")
self._commit(abm)
def _commit(self, abm: AstrBotMessage) -> None:
self.platform.remember_session_message_id(abm.session_id, abm.message_id)
self.platform.commit_event(
QQOfficialMessageEvent(
abm.message_str,
@@ -124,6 +131,9 @@ class QQOfficialPlatformAdapter(Platform):
self.client.set_platform(self)
self._session_last_message_id: dict[str, str] = {}
self._session_scene: dict[str, str] = {}
self.test_mode = os.environ.get("TEST_MODE", "off") == "on"
async def send_by_session(
@@ -131,14 +141,191 @@ class QQOfficialPlatformAdapter(Platform):
session: MessageSesion,
message_chain: MessageChain,
) -> None:
raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session")
await self._send_by_session_common(session, message_chain)
async def _send_by_session_common(
self,
session: MessageSesion,
message_chain: MessageChain,
) -> None:
(
plain_text,
image_base64,
image_path,
record_file_path,
video_file_source,
file_source,
file_name,
) = await QQOfficialMessageEvent._parse_to_qqofficial(message_chain)
if (
not plain_text
and not image_path
and not image_base64
and not record_file_path
and not video_file_source
and not file_source
):
return
msg_id = self._session_last_message_id.get(session.session_id)
if not msg_id:
logger.warning(
"[QQOfficial] No cached msg_id for session: %s, skip send_by_session",
session.session_id,
)
return
payload: dict[str, Any] = {"content": plain_text, "msg_id": msg_id}
ret: Any = None
send_helper = SimpleNamespace(bot=self.client)
if session.message_type == MessageType.GROUP_MESSAGE:
scene = self._session_scene.get(session.session_id)
if scene == "group":
payload["msg_seq"] = random.randint(1, 10000)
if image_base64:
media = await QQOfficialMessageEvent.upload_group_and_c2c_image(
send_helper, # type: ignore
image_base64,
QQOfficialMessageEvent.IMAGE_FILE_TYPE,
group_openid=session.session_id,
)
payload["media"] = media
payload["msg_type"] = 7
if record_file_path:
media = await QQOfficialMessageEvent.upload_group_and_c2c_media(
send_helper, # type: ignore
record_file_path,
QQOfficialMessageEvent.VOICE_FILE_TYPE,
group_openid=session.session_id,
)
if media:
payload["media"] = media
payload["msg_type"] = 7
if video_file_source:
media = await QQOfficialMessageEvent.upload_group_and_c2c_media(
send_helper, # type: ignore
video_file_source,
QQOfficialMessageEvent.VIDEO_FILE_TYPE,
group_openid=session.session_id,
)
if media:
payload["media"] = media
payload["msg_type"] = 7
payload.pop("msg_id", None)
if file_source:
media = await QQOfficialMessageEvent.upload_group_and_c2c_media(
send_helper, # type: ignore
file_source,
QQOfficialMessageEvent.FILE_FILE_TYPE,
file_name=file_name,
group_openid=session.session_id,
)
if media:
payload["media"] = media
payload["msg_type"] = 7
payload.pop("msg_id", None)
ret = await self.client.api.post_group_message(
group_openid=session.session_id,
**payload,
)
else:
if image_path:
payload["file_image"] = image_path
ret = await self.client.api.post_message(
channel_id=session.session_id,
**payload,
)
elif session.message_type == MessageType.FRIEND_MESSAGE:
payload["msg_seq"] = random.randint(1, 10000)
if image_base64:
media = await QQOfficialMessageEvent.upload_group_and_c2c_image(
send_helper, # type: ignore
image_base64,
QQOfficialMessageEvent.IMAGE_FILE_TYPE,
openid=session.session_id,
)
payload["media"] = media
payload["msg_type"] = 7
if record_file_path:
media = await QQOfficialMessageEvent.upload_group_and_c2c_media(
send_helper, # type: ignore
record_file_path,
QQOfficialMessageEvent.VOICE_FILE_TYPE,
openid=session.session_id,
)
if media:
payload["media"] = media
payload["msg_type"] = 7
if video_file_source:
media = await QQOfficialMessageEvent.upload_group_and_c2c_media(
send_helper, # type: ignore
video_file_source,
QQOfficialMessageEvent.VIDEO_FILE_TYPE,
openid=session.session_id,
)
if media:
payload["media"] = media
payload["msg_type"] = 7
# QQ API rejects msg_id for media (video/file) messages sent
# via the proactive tool-call path; remove it to avoid 越权 error.
payload.pop("msg_id", None)
if file_source:
media = await QQOfficialMessageEvent.upload_group_and_c2c_media(
send_helper, # type: ignore
file_source,
QQOfficialMessageEvent.FILE_FILE_TYPE,
file_name=file_name,
openid=session.session_id,
)
if media:
payload["media"] = media
payload["msg_type"] = 7
payload.pop("msg_id", None)
ret = await QQOfficialMessageEvent.post_c2c_message(
send_helper, # type: ignore
openid=session.session_id,
**payload,
)
else:
logger.warning(
"[QQOfficial] Unsupported message type for send_by_session: %s",
session.message_type,
)
return
sent_message_id = self._extract_message_id(ret)
if sent_message_id:
self.remember_session_message_id(session.session_id, sent_message_id)
await super().send_by_session(session, message_chain)
def remember_session_message_id(self, session_id: str, message_id: str) -> None:
if not session_id or not message_id:
return
self._session_last_message_id[session_id] = message_id
def remember_session_scene(self, session_id: str, scene: str) -> None:
if not session_id or not scene:
return
self._session_scene[session_id] = scene
def _extract_message_id(self, ret: Any) -> str | None:
if isinstance(ret, dict):
message_id = ret.get("id")
return str(message_id) if message_id else None
message_id = getattr(ret, "id", None)
if message_id:
return str(message_id)
return None
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
name="qq_official",
description="QQ 机器人官方 API 适配器",
id=cast(str, self.config.get("id")),
support_proactive_message=False,
support_proactive_message=True,
)
@staticmethod
@@ -158,7 +345,10 @@ class QQOfficialPlatformAdapter(Platform):
return
for attachment in attachments:
content_type = cast(str, getattr(attachment, "content_type", "") or "")
content_type = cast(
str,
getattr(attachment, "content_type", "") or "",
).lower()
url = QQOfficialPlatformAdapter._normalize_attachment_url(
cast(str | None, getattr(attachment, "url", None))
)
@@ -174,7 +364,32 @@ class QQOfficialPlatformAdapter(Platform):
or getattr(attachment, "name", None)
or "attachment",
)
msg.append(File(name=filename, file=url, url=url))
ext = os.path.splitext(filename)[1].lower()
image_exts = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp"}
audio_exts = {
".mp3",
".wav",
".ogg",
".m4a",
".amr",
".silk",
}
video_exts = {
".mp4",
".mov",
".avi",
".mkv",
".webm",
}
if content_type.startswith("audio") or ext in audio_exts:
msg.append(Record.fromURL(url))
elif content_type.startswith("video") or ext in video_exts:
msg.append(Video.fromURL(url))
elif content_type.startswith("image") or ext in image_exts:
msg.append(Image.fromURL(url))
else:
msg.append(File(name=filename, file=url, url=url))
@staticmethod
def _parse_from_qqofficial(
@@ -1,7 +1,5 @@
import asyncio
import logging
import random
from types import SimpleNamespace
from typing import Any, cast
import botpy
@@ -15,7 +13,6 @@ from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.webhook_utils import log_webhook_info
from ...register import register_platform_adapter
from ..qqofficial.qqofficial_message_event import QQOfficialMessageEvent
from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter
from .qo_webhook_event import QQOfficialWebhookMessageEvent
from .qo_webhook_server import QQOfficialWebhook
@@ -123,95 +120,11 @@ class QQOfficialWebhookPlatformAdapter(Platform):
session: MessageSesion,
message_chain: MessageChain,
) -> None:
(
plain_text,
image_base64,
image_path,
record_file_path,
) = await QQOfficialMessageEvent._parse_to_qqofficial(message_chain)
if not plain_text and not image_path:
return
msg_id = self._session_last_message_id.get(session.session_id)
if not msg_id:
logger.warning(
"[QQOfficialWebhook] No cached msg_id for session: %s, skip send_by_session",
session.session_id,
)
return
payload: dict[str, Any] = {"content": plain_text, "msg_id": msg_id}
ret: Any = None
send_helper = SimpleNamespace(bot=self.client)
if session.message_type == MessageType.GROUP_MESSAGE:
scene = self._session_scene.get(session.session_id)
if scene == "group":
payload["msg_seq"] = random.randint(1, 10000)
if image_base64:
media = await QQOfficialMessageEvent.upload_group_and_c2c_image(
send_helper, # type: ignore
image_base64,
1,
group_openid=session.session_id,
)
payload["media"] = media
payload["msg_type"] = 7
if record_file_path:
media = await QQOfficialMessageEvent.upload_group_and_c2c_record(
send_helper, # type: ignore
record_file_path,
3,
group_openid=session.session_id,
)
payload["media"] = media
payload["msg_type"] = 7
ret = await self.client.api.post_group_message(
group_openid=session.session_id,
**payload,
)
else:
if image_path:
payload["file_image"] = image_path
ret = await self.client.api.post_message(
channel_id=session.session_id,
**payload,
)
elif session.message_type == MessageType.FRIEND_MESSAGE:
payload["msg_seq"] = random.randint(1, 10000)
if image_base64:
media = await QQOfficialMessageEvent.upload_group_and_c2c_image(
send_helper, # type: ignore
image_base64,
1,
openid=session.session_id,
)
payload["media"] = media
payload["msg_type"] = 7
if record_file_path:
media = await QQOfficialMessageEvent.upload_group_and_c2c_record(
send_helper, # type: ignore
record_file_path,
3,
openid=session.session_id,
)
payload["media"] = media
payload["msg_type"] = 7
ret = await QQOfficialMessageEvent.post_c2c_message(
send_helper, # type: ignore
openid=session.session_id,
**payload,
)
else:
logger.warning(
"[QQOfficialWebhook] Unsupported message type for send_by_session: %s",
session.message_type,
)
return
sent_message_id = self._extract_message_id(ret)
if sent_message_id:
self.remember_session_message_id(session.session_id, sent_message_id)
await super().send_by_session(session, message_chain)
await QQOfficialPlatformAdapter._send_by_session_common(
cast(Any, self),
session,
message_chain,
)
def remember_session_message_id(self, session_id: str, message_id: str) -> None:
if not session_id or not message_id:
@@ -55,7 +55,7 @@ class QQOfficialWebhook:
max_async=1,
connect=bot_connect,
dispatch=self.client.ws_dispatch,
loop=asyncio.get_event_loop(),
loop=asyncio.get_running_loop(),
api=self.api,
)
+308 -108
View File
@@ -1,6 +1,7 @@
import asyncio
import os
import re
from collections.abc import Callable
from typing import Any, cast
import telegramify_markdown
@@ -21,6 +22,7 @@ from astrbot.api.message_components import (
Video,
)
from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata
from astrbot.core.utils.metrics import Metric
class TelegramPlatformEvent(AstrMessageEvent):
@@ -34,6 +36,20 @@ class TelegramPlatformEvent(AstrMessageEvent):
"word": re.compile(r"\s"),
}
# sendMessageDraft 的 draft_id 类级递增计数器
_TELEGRAM_DRAFT_ID_MAX = 2_147_483_647
_next_draft_id: int = 0
@classmethod
def _allocate_draft_id(cls) -> int:
"""分配一个递增的 draft_id,溢出时归 1。"""
cls._next_draft_id = (
1
if cls._next_draft_id >= cls._TELEGRAM_DRAFT_ID_MAX
else cls._next_draft_id + 1
)
return cls._next_draft_id
# 消息类型到 chat action 的映射,用于优先级判断
ACTION_BY_TYPE: dict[type, str] = {
Record: ChatAction.UPLOAD_VOICE,
@@ -262,7 +278,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
try:
md_text = telegramify_markdown.markdownify(
chunk,
normalize_whitespace=False,
)
await client.send_message(
text=md_text,
@@ -339,6 +354,117 @@ class TelegramPlatformEvent(AstrMessageEvent):
except Exception as e:
logger.error(f"[Telegram] 添加反应失败: {e}")
async def _send_message_draft(
self,
chat_id: str,
draft_id: int,
text: str,
message_thread_id: str | None = None,
parse_mode: str | None = None,
) -> None:
"""通过 Bot.send_message_draft 发送草稿消息(流式推送部分消息)。
API 仅支持私聊
Args:
chat_id: 目标私聊的 chat_id
draft_id: 草稿唯一标识非零整数相同 draft_id 的变更会以动画展示
text: 消息文本1-4096 字符
message_thread_id: 可选目标消息线程 ID
parse_mode: 可选消息文本的解析模式
"""
kwargs: dict[str, Any] = {}
if message_thread_id:
kwargs["message_thread_id"] = int(message_thread_id)
if parse_mode:
kwargs["parse_mode"] = parse_mode
try:
logger.debug(
f"[Telegram] sendMessageDraft: chat_id={chat_id}, draft_id={draft_id}, text_len={len(text)}"
)
await self.client.send_message_draft(
chat_id=int(chat_id),
draft_id=draft_id,
text=text,
**kwargs,
)
except Exception as e:
logger.warning(f"[Telegram] sendMessageDraft 失败: {e!s}")
async def _process_chain_items(
self,
chain: MessageChain,
payload: dict[str, Any],
user_name: str,
message_thread_id: str | None,
on_text: Callable[[str], None],
) -> None:
"""处理 MessageChain 中的各类组件,文本通过 on_text 回调追加,媒体直接发送。"""
for i in chain.chain:
if isinstance(i, Plain):
on_text(i.text)
elif isinstance(i, Image):
image_path = await i.convert_to_file_path()
await self._send_media_with_action(
self.client,
ChatAction.UPLOAD_PHOTO,
self.client.send_photo,
user_name=user_name,
photo=image_path,
**cast(Any, payload),
)
elif isinstance(i, File):
path = await i.get_file()
name = i.name or os.path.basename(path)
await self._send_media_with_action(
self.client,
ChatAction.UPLOAD_DOCUMENT,
self.client.send_document,
user_name=user_name,
document=path,
filename=name,
**cast(Any, payload),
)
elif isinstance(i, Record):
path = await i.convert_to_file_path()
await self._send_voice_with_fallback(
self.client,
path,
payload,
caption=i.text or None,
user_name=user_name,
message_thread_id=message_thread_id,
use_media_action=True,
)
elif isinstance(i, Video):
path = await i.convert_to_file_path()
await self._send_media_with_action(
self.client,
ChatAction.UPLOAD_VIDEO,
self.client.send_video,
user_name=user_name,
video=path,
**cast(Any, payload),
)
else:
logger.warning(f"不支持的消息类型: {type(i)}")
async def _send_final_segment(self, delta: str, payload: dict[str, Any]) -> None:
"""将累积文本作为 MarkdownV2 真实消息发送,失败时回退到纯文本。"""
try:
markdown_text = telegramify_markdown.markdownify(
delta,
)
await self.client.send_message(
text=markdown_text,
parse_mode="MarkdownV2",
**cast(Any, payload),
)
except Exception as e:
logger.warning(f"Markdown转换失败,使用普通文本: {e!s}")
await self.client.send_message(text=delta, **cast(Any, payload))
async def send_streaming(self, generator, use_fallback: bool = False):
message_thread_id = None
@@ -356,6 +482,137 @@ class TelegramPlatformEvent(AstrMessageEvent):
if message_thread_id:
payload["message_thread_id"] = message_thread_id
# sendMessageDraft 仅支持私聊(显式检查 FRIEND_MESSAGE
is_private = self.get_message_type() == MessageType.FRIEND_MESSAGE
if is_private:
logger.info("[Telegram] 流式输出: 使用 sendMessageDraft (私聊)")
await self._send_streaming_draft(
user_name, message_thread_id, payload, generator
)
else:
logger.info("[Telegram] 流式输出: 使用 edit_message_text fallback (群聊)")
await self._send_streaming_edit(
user_name, message_thread_id, payload, generator
)
# 内联父类 send_streaming 的副作用(避免传入已消费的 generator)
asyncio.create_task(
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name),
)
self._has_send_oper = True
async def _send_streaming_draft(
self,
user_name: str,
message_thread_id: str | None,
payload: dict[str, Any],
generator,
) -> None:
"""使用 sendMessageDraft API 进行流式推送(私聊专用)。
流式过程中使用 sendMessageDraft 推送草稿动画
流式结束后发送一条真实消息保留最终内容draft 是临时的会消失
使用信号驱动的发送循环每次有新 token 到达时唤醒发送
发送频率由网络 RTT 自然限制最多一个请求 in-flight
"""
draft_id = self._allocate_draft_id()
delta = ""
last_sent_text = ""
done = False # 信号:生成器已结束
text_changed = asyncio.Event() # 有新 token 到达时触发
async def _draft_sender_loop() -> None:
"""信号驱动的草稿发送循环,有新内容就发,RTT 自然限流。"""
nonlocal last_sent_text
while not done:
await text_changed.wait()
text_changed.clear()
# 发送最新的缓冲区内容(MarkdownV2 渲染,与真实消息一致)
if delta and delta != last_sent_text:
draft_text = delta[: self.MAX_MESSAGE_LENGTH]
if draft_text != last_sent_text:
try:
md = telegramify_markdown.markdownify(
draft_text,
)
await self._send_message_draft(
user_name,
draft_id,
md,
message_thread_id,
parse_mode="MarkdownV2",
)
last_sent_text = draft_text
except Exception:
# markdownify 对未闭合语法可能失败,回退纯文本
try:
await self._send_message_draft(
user_name,
draft_id,
draft_text,
message_thread_id,
)
last_sent_text = draft_text
except Exception as e2:
logger.debug(
f"[Telegram] sendMessageDraft failed (ignored): {e2!s}"
)
sender_task = asyncio.create_task(_draft_sender_loop())
def _append_text(t: str) -> None:
nonlocal delta
delta += t
text_changed.set() # 唤醒发送循环
try:
async for chain in generator:
if not isinstance(chain, MessageChain):
continue
if chain.type == "break":
# 分割符:发送真实消息保留内容,重置缓冲区
if delta:
# 用 emoji 清空 draft 显示,避免 draft 和真实消息同时可见
await self._send_message_draft(
user_name,
draft_id,
"\u23f3",
message_thread_id,
)
await self._send_final_segment(delta, payload)
delta = ""
last_sent_text = ""
draft_id = self._allocate_draft_id()
continue
await self._process_chain_items(
chain, payload, user_name, message_thread_id, _append_text
)
finally:
done = True
text_changed.set() # 唤醒循环使其退出
await sender_task
# 流式结束:用 emoji 清空 draft,然后发真实消息持久化
if delta:
await self._send_message_draft(
user_name,
draft_id,
"\u23f3",
message_thread_id,
)
await self._send_final_segment(delta, payload)
async def _send_streaming_edit(
self,
user_name: str,
message_thread_id: str | None,
payload: dict[str, Any],
generator,
) -> None:
"""使用 send_message + edit_message_text 进行流式推送(群聊 fallback)。"""
delta = ""
current_content = ""
message_id = None
@@ -366,130 +623,75 @@ class TelegramPlatformEvent(AstrMessageEvent):
# 发送初始 typing 状态
await self._ensure_typing(user_name, message_thread_id)
last_chat_action_time = asyncio.get_event_loop().time()
last_chat_action_time = asyncio.get_running_loop().time()
def _append_text(t: str) -> None:
nonlocal delta
delta += t
async for chain in generator:
if isinstance(chain, MessageChain):
if chain.type == "break":
# 分割符
if message_id:
try:
await self.client.edit_message_text(
text=delta,
chat_id=payload["chat_id"],
message_id=message_id,
)
except Exception as e:
logger.warning(f"编辑消息失败(streaming-break): {e!s}")
message_id = None # 重置消息 ID
delta = "" # 重置 delta
continue
if not isinstance(chain, MessageChain):
continue
# 处理消息链中的每个组件
for i in chain.chain:
if isinstance(i, Plain):
delta += i.text
elif isinstance(i, Image):
image_path = await i.convert_to_file_path()
await self._send_media_with_action(
self.client,
ChatAction.UPLOAD_PHOTO,
self.client.send_photo,
user_name=user_name,
photo=image_path,
**cast(Any, payload),
if chain.type == "break":
# 分割符
if message_id:
try:
await self.client.edit_message_text(
text=delta,
chat_id=payload["chat_id"],
message_id=message_id,
)
continue
elif isinstance(i, File):
path = await i.get_file()
name = i.name or os.path.basename(path)
await self._send_media_with_action(
self.client,
ChatAction.UPLOAD_DOCUMENT,
self.client.send_document,
user_name=user_name,
document=path,
filename=name,
**cast(Any, payload),
)
continue
elif isinstance(i, Record):
path = await i.convert_to_file_path()
await self._send_voice_with_fallback(
self.client,
path,
payload,
caption=i.text or delta or None,
user_name=user_name,
message_thread_id=message_thread_id,
use_media_action=True,
)
continue
elif isinstance(i, Video):
path = await i.convert_to_file_path()
await self._send_media_with_action(
self.client,
ChatAction.UPLOAD_VIDEO,
self.client.send_video,
user_name=user_name,
video=path,
**cast(Any, payload),
)
continue
else:
logger.warning(f"不支持的消息类型: {type(i)}")
continue
except Exception as e:
logger.warning(f"编辑消息失败(streaming-break): {e!s}")
message_id = None
delta = ""
continue
# Plain
if message_id and len(delta) <= self.MAX_MESSAGE_LENGTH:
current_time = asyncio.get_event_loop().time()
time_since_last_edit = current_time - last_edit_time
await self._process_chain_items(
chain, payload, user_name, message_thread_id, _append_text
)
# 如果距离上次编辑的时间 >= 设定的间隔,等待一段时间
if time_since_last_edit >= throttle_interval:
# 发送 typing 状态(带节流)
current_time = asyncio.get_event_loop().time()
if current_time - last_chat_action_time >= chat_action_interval:
await self._ensure_typing(user_name, message_thread_id)
last_chat_action_time = current_time
# 编辑消息
try:
await self.client.edit_message_text(
text=delta,
chat_id=payload["chat_id"],
message_id=message_id,
)
current_content = delta
except Exception as e:
logger.warning(f"编辑消息失败(streaming): {e!s}")
last_edit_time = (
asyncio.get_event_loop().time()
) # 更新上次编辑的时间
else:
# delta 长度一般不会大于 4096,因此这里直接发送
# 发送 typing 状态(带节流)
current_time = asyncio.get_event_loop().time()
# 编辑或发送消息
if message_id and len(delta) <= self.MAX_MESSAGE_LENGTH:
current_time = asyncio.get_running_loop().time()
time_since_last_edit = current_time - last_edit_time
if time_since_last_edit >= throttle_interval:
current_time = asyncio.get_running_loop().time()
if current_time - last_chat_action_time >= chat_action_interval:
await self._ensure_typing(user_name, message_thread_id)
last_chat_action_time = current_time
try:
msg = await self.client.send_message(
text=delta, **cast(Any, payload)
await self.client.edit_message_text(
text=delta,
chat_id=payload["chat_id"],
message_id=message_id,
)
current_content = delta
except Exception as e:
logger.warning(f"发送消息失败(streaming): {e!s}")
message_id = msg.message_id
last_edit_time = (
asyncio.get_event_loop().time()
) # 记录初始消息发送时间
logger.warning(f"编辑消息失败(streaming): {e!s}")
last_edit_time = asyncio.get_running_loop().time()
else:
current_time = asyncio.get_running_loop().time()
if current_time - last_chat_action_time >= chat_action_interval:
await self._ensure_typing(user_name, message_thread_id)
last_chat_action_time = current_time
try:
msg = await self.client.send_message(
text=delta, **cast(Any, payload)
)
current_content = delta
except Exception as e:
logger.warning(f"发送消息失败(streaming): {e!s}")
message_id = msg.message_id
last_edit_time = asyncio.get_running_loop().time()
try:
if delta and current_content != delta:
try:
markdown_text = telegramify_markdown.markdownify(
delta,
normalize_whitespace=False,
)
await self.client.edit_message_text(
text=markdown_text,
@@ -506,5 +708,3 @@ class TelegramPlatformEvent(AstrMessageEvent):
)
except Exception as e:
logger.warning(f"编辑消息失败(streaming): {e!s}")
return await super().send_streaming(generator, use_fallback)
@@ -200,7 +200,7 @@ class WecomPlatformAdapter(Platform):
return msg_list[-1]
return None
msg_new = await asyncio.get_event_loop().run_in_executor(
msg_new = await asyncio.get_running_loop().run_in_executor(
None,
get_latest_msg_item,
)
@@ -261,7 +261,7 @@ class WecomPlatformAdapter(Platform):
@override
async def run(self) -> None:
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
if self.kf_name:
try:
acc_list = (
@@ -339,7 +339,7 @@ class WecomPlatformAdapter(Platform):
abm.session_id = abm.sender.user_id
abm.raw_message = msg
elif isinstance(msg, VoiceMessage):
resp: Response = await asyncio.get_event_loop().run_in_executor(
resp: Response = await asyncio.get_running_loop().run_in_executor(
None,
self.client.media.download,
msg.media_id,
@@ -395,7 +395,7 @@ class WecomPlatformAdapter(Platform):
abm.message_str = text
elif msgtype == "image":
media_id = msg.get("image", {}).get("media_id", "")
resp: Response = await asyncio.get_event_loop().run_in_executor(
resp: Response = await asyncio.get_running_loop().run_in_executor(
None,
self.client.media.download,
media_id,
@@ -407,7 +407,7 @@ class WecomPlatformAdapter(Platform):
abm.message = [Image(file=path, url=path)]
elif msgtype == "voice":
media_id = msg.get("voice", {}).get("media_id", "")
resp: Response = await asyncio.get_event_loop().run_in_executor(
resp: Response = await asyncio.get_running_loop().run_in_executor(
None,
self.client.media.download,
media_id,
@@ -1,5 +1,5 @@
"""企业微信智能机器人平台适配器
基于企业微信智能机器人 API 的消息平台适配器支持 HTTP 回调
基于企业微信智能机器人 API 的消息平台适配器支持 HTTP 回调与长连接
参考webchat_adapter.py的队列机制实现异步消息处理和流式响应
"""
@@ -31,6 +31,7 @@ from .wecomai_api import (
WecomAIBotStreamMessageBuilder,
)
from .wecomai_event import WecomAIBotMessageEvent
from .wecomai_long_connection import WecomAIBotLongConnectionClient
from .wecomai_queue_mgr import WecomAIQueueMgr
from .wecomai_server import WecomAIBotServer
from .wecomai_utils import (
@@ -78,8 +79,13 @@ class WecomAIBotAdapter(Platform):
self.settings = platform_settings
# 初始化配置参数
self.token = self.config["token"]
self.encoding_aes_key = self.config["encoding_aes_key"]
self.connection_mode = self.config.get(
"wecom_ai_bot_connection_mode", "webhook"
)
self.token = self.config.get("token", self.config.get("wecomaibot_token", ""))
self.encoding_aes_key = self.config.get(
"encoding_aes_key", self.config.get("wecomaibot_encoding_aes_key", "")
)
self.port = int(self.config["port"])
self.host = self.config.get("callback_server_host", "0.0.0.0")
self.bot_name = self.config.get("wecom_ai_bot_name", "")
@@ -96,25 +102,52 @@ class WecomAIBotAdapter(Platform):
self.only_use_webhook_url_to_send = bool(
self.config.get("only_use_webhook_url_to_send", False),
)
self.long_connection_bot_id = self.config.get(
"wecomaibot_ws_bot_id", self.config.get("long_connection_bot_id", "")
)
self.long_connection_secret = self.config.get(
"wecomaibot_ws_secret", self.config.get("long_connection_secret", "")
)
self.long_connection_ws_url = self.config.get(
"wecomaibot_ws_url",
"wss://openws.work.weixin.qq.com",
)
self.long_connection_heartbeat_interval = int(
self.config.get("wecomaibot_heartbeat_interval", 30),
)
# 平台元数据
self.metadata = PlatformMetadata(
name="wecom_ai_bot",
description="企业微信智能机器人适配器,支持 HTTP 回调接收消息",
description="企业微信智能机器人适配器,支持 HTTP 回调和长连接模式",
id=self.config.get("id", "wecom_ai_bot"),
support_proactive_message=bool(self.msg_push_webhook_url),
)
# 初始化 API 客户端
self.api_client = WecomAIBotAPIClient(self.token, self.encoding_aes_key)
self.api_client: WecomAIBotAPIClient | None = None
self.server: WecomAIBotServer | None = None
self.long_connection_client: WecomAIBotLongConnectionClient | None = None
# 初始化 HTTP 服务器
self.server = WecomAIBotServer(
host=self.host,
port=self.port,
api_client=self.api_client,
message_handler=self._process_message,
)
if self.connection_mode == "long_connection":
if not self.long_connection_bot_id or not self.long_connection_secret:
logger.warning(
"企业微信智能机器人长连接模式缺少 BotID 或 Secret,连接可能失败"
)
self.long_connection_client = WecomAIBotLongConnectionClient(
bot_id=self.long_connection_bot_id,
secret=self.long_connection_secret,
ws_url=self.long_connection_ws_url,
heartbeat_interval=self.long_connection_heartbeat_interval,
message_handler=self._process_long_connection_payload,
)
else:
self.api_client = WecomAIBotAPIClient(self.token, self.encoding_aes_key)
self.server = WecomAIBotServer(
host=self.host,
port=self.port,
api_client=self.api_client,
message_handler=self._process_message,
)
# 事件循环和关闭信号
self.shutdown_event = asyncio.Event()
@@ -161,6 +194,9 @@ class WecomAIBotAdapter(Platform):
加密后的响应消息无需响应时返回 None
"""
if not self.api_client:
logger.error("Webhook 消息处理失败: API 客户端未初始化")
return None
msgtype = message_data.get("msgtype")
if not msgtype:
logger.warning(f"消息类型未知,忽略: {message_data}")
@@ -320,6 +356,89 @@ class WecomAIBotAdapter(Platform):
logger.error("处理欢迎消息时发生异常: %s", e)
return None
async def _process_long_connection_payload(
self,
payload: dict[str, Any],
) -> None:
"""处理长连接回调消息。"""
cmd = payload.get("cmd")
headers = payload.get("headers") or {}
body = payload.get("body") or {}
req_id = headers.get("req_id")
if not isinstance(body, dict):
return
if cmd == "aibot_msg_callback":
session_id = self._extract_session_id(body)
stream_id = f"{session_id}_{generate_random_string(10)}"
await self._enqueue_message(
body, {"req_id": req_id or ""}, stream_id, session_id
)
self.queue_mgr.set_pending_response(
stream_id,
{
"req_id": req_id or "",
"connection_mode": "long_connection",
},
)
if self.initial_respond_text and req_id:
await self._send_long_connection_respond_msg(
req_id=req_id,
body={
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": False,
"content": self.initial_respond_text,
},
},
)
return
if cmd == "aibot_event_callback":
event = body.get("event") or {}
event_type = event.get("eventtype")
if (
event_type == "enter_chat"
and self.friend_message_welcome_text
and req_id
):
await self._send_long_connection_respond_welcome(req_id)
elif event_type == "disconnected_event":
logger.warning(
"[WecomAI][LongConn] 收到 disconnected_event,旧连接将被关闭"
)
async def _send_long_connection_respond_welcome(self, req_id: str) -> bool:
client = self.long_connection_client
if not client:
return False
return await client.send_command(
cmd="aibot_respond_welcome_msg",
req_id=req_id,
body={
"msgtype": "text",
"text": {
"content": self.friend_message_welcome_text,
},
},
)
async def _send_long_connection_respond_msg(
self,
req_id: str,
body: dict[str, Any],
) -> bool:
client = self.long_connection_client
if not client:
return False
return await client.send_command(
cmd="aibot_respond_msg",
req_id=req_id,
body=body,
)
def _extract_session_id(self, message_data: dict[str, Any]) -> str:
"""从消息数据中提取会话ID"""
user_id = message_data.get("from", {}).get("userid", "default_user")
@@ -355,15 +474,16 @@ class WecomAIBotAdapter(Platform):
content = ""
image_base64 = []
_img_url_to_process = []
_img_url_to_process: list[tuple[str, str | None]] = []
msg_items = []
if msgtype == WecomAIBotConstants.MSG_TYPE_TEXT:
content = WecomAIBotMessageParser.parse_text_message(message_data)
elif msgtype == WecomAIBotConstants.MSG_TYPE_IMAGE:
_img_url_to_process.append(
WecomAIBotMessageParser.parse_image_message(message_data),
)
image_payload = message_data.get("image", {})
image_url = image_payload.get("url", "")
if image_url:
_img_url_to_process.append((image_url, image_payload.get("aeskey")))
elif msgtype == WecomAIBotConstants.MSG_TYPE_MIXED:
# 提取混合消息中的文本内容
msg_items = WecomAIBotMessageParser.parse_mixed_message(message_data)
@@ -374,9 +494,12 @@ class WecomAIBotAdapter(Platform):
if text_content:
text_parts.append(text_content)
elif item.get("msgtype") == WecomAIBotConstants.MSG_TYPE_IMAGE:
image_url = item.get("image", {}).get("url", "")
image_payload = item.get("image", {})
image_url = image_payload.get("url", "")
if image_url:
_img_url_to_process.append(image_url)
_img_url_to_process.append(
(image_url, image_payload.get("aeskey"))
)
content = " ".join(text_parts) if text_parts else ""
else:
content = f"[{msgtype}消息]"
@@ -384,8 +507,8 @@ class WecomAIBotAdapter(Platform):
# 并行处理图片下载和解密
if _img_url_to_process:
tasks = [
process_encrypted_image(url, self.encoding_aes_key)
for url in _img_url_to_process
process_encrypted_image(url, aes_key or self.encoding_aes_key)
for url, aes_key in _img_url_to_process
]
results = await asyncio.gather(*tasks)
for success, result in results:
@@ -459,26 +582,43 @@ class WecomAIBotAdapter(Platform):
"""运行适配器,同时启动HTTP服务器和队列监听器"""
async def run_both() -> None:
# 如果启用统一 webhook 模式,则不启动独立服务器
webhook_uuid = self.config.get("webhook_uuid")
if self.unified_webhook_mode and webhook_uuid:
log_webhook_info(f"{self.meta().id}(企业微信智能机器人)", webhook_uuid)
# 只运行队列监听器
await self.queue_listener.run()
else:
if self.connection_mode == "long_connection":
if not self.long_connection_client:
raise RuntimeError("长连接客户端未初始化")
logger.info(
"启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port
"启动企业微信智能机器人长连接模式: %s", self.long_connection_ws_url
)
# 同时运行HTTP服务器和队列监听器
await asyncio.gather(
self.server.start_server(),
self.long_connection_client.start(),
self.queue_listener.run(),
)
else:
# 如果启用统一 webhook 模式,则不启动独立服务器
webhook_uuid = self.config.get("webhook_uuid")
if self.unified_webhook_mode and webhook_uuid:
log_webhook_info(
f"{self.meta().id}(企业微信智能机器人)", webhook_uuid
)
# 只运行队列监听器
await self.queue_listener.run()
else:
if not self.server:
raise RuntimeError("Webhook 服务器未初始化")
logger.info(
"启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port
)
# 同时运行HTTP服务器和队列监听器
await asyncio.gather(
self.server.start_server(),
self.queue_listener.run(),
)
return run_both()
async def webhook_callback(self, request: Any) -> Any:
"""统一 Webhook 回调入口"""
if self.connection_mode == "long_connection" or not self.server:
return "long_connection mode does not accept webhook callbacks", 400
# 根据请求方法分发到不同的处理函数
if request.method == "GET":
return await self.server.handle_verify(request)
@@ -489,7 +629,10 @@ class WecomAIBotAdapter(Platform):
"""终止适配器"""
logger.info("企业微信智能机器人适配器正在关闭...")
self.shutdown_event.set()
await self.server.shutdown()
if self.long_connection_client:
await self.long_connection_client.shutdown()
if self.server:
await self.server.shutdown()
def meta(self) -> PlatformMetadata:
"""获取平台元数据"""
@@ -507,17 +650,22 @@ class WecomAIBotAdapter(Platform):
queue_mgr=self.queue_mgr,
webhook_client=self.webhook_client,
only_use_webhook_url_to_send=self.only_use_webhook_url_to_send,
long_connection_sender=self._send_long_connection_respond_msg,
)
message_event.is_at_or_wake_command = (
True # 企业微信智能机器人默认消息都是 at 或唤醒命令
)
message_event.is_wake = True # 企业微信智能机器人消息默认当做唤醒命令处理
self.commit_event(message_event)
except Exception as e:
logger.error("处理消息时发生异常: %s", e)
def get_client(self) -> WecomAIBotAPIClient:
def get_client(self) -> WecomAIBotAPIClient | None:
"""获取 API 客户端"""
return self.api_client
def get_server(self) -> WecomAIBotServer:
def get_server(self) -> WecomAIBotServer | None:
"""获取 HTTP 服务器实例"""
return self.server
@@ -1,5 +1,7 @@
"""企业微信智能机器人事件处理模块,处理消息事件的发送和接收"""
from collections.abc import Awaitable, Callable
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import At, Image, Plain
@@ -18,10 +20,11 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
message_obj,
platform_meta,
session_id: str,
api_client: WecomAIBotAPIClient,
api_client: WecomAIBotAPIClient | None,
queue_mgr: WecomAIQueueMgr,
webhook_client: WecomAIBotWebhookClient | None = None,
only_use_webhook_url_to_send: bool = False,
long_connection_sender: (Callable[[str, dict], Awaitable[bool]] | None) = None,
) -> None:
"""初始化消息事件
@@ -38,6 +41,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
self.queue_mgr = queue_mgr
self.webhook_client = webhook_client
self.only_use_webhook_url_to_send = only_use_webhook_url_to_send
self.long_connection_sender = long_connection_sender
async def _mark_stream_complete(self, stream_id: str) -> None:
back_queue = self.queue_mgr.get_or_create_back_queue(stream_id)
@@ -117,6 +121,18 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
return data
@staticmethod
def _extract_plain_text_from_chain(message_chain: MessageChain | None) -> str:
if not message_chain:
return ""
plain_parts: list[str] = []
for comp in message_chain.chain:
if isinstance(comp, At):
plain_parts.append(f"@{comp.name} ")
elif isinstance(comp, Plain):
plain_parts.append(comp.text)
return "".join(plain_parts).strip()
async def send(self, message: MessageChain | None) -> None:
"""发送消息"""
raw = self.message_obj.raw_message
@@ -124,6 +140,44 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
"wecom_ai_bot platform event raw_message should be a dict"
)
stream_id = raw.get("stream_id", self.session_id)
pending_response = self.queue_mgr.get_pending_response(stream_id) or {}
connection_mode = pending_response.get("callback_params", {}).get(
"connection_mode"
)
req_id = pending_response.get("callback_params", {}).get("req_id")
if (
connection_mode == "long_connection"
and self.long_connection_sender
and isinstance(req_id, str)
and req_id
):
if self.only_use_webhook_url_to_send and self.webhook_client and message:
await self.webhook_client.send_message_chain(message)
await super().send(MessageChain([]))
return
if self.webhook_client and message:
await self.webhook_client.send_message_chain(
message,
unsupported_only=True,
)
content = self._extract_plain_text_from_chain(message)
await self.long_connection_sender(
req_id,
{
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": True,
"content": content,
},
},
)
await super().send(MessageChain([]))
return
if self.only_use_webhook_url_to_send and self.webhook_client and message:
await self.webhook_client.send_message_chain(message)
await self._mark_stream_complete(stream_id)
@@ -152,8 +206,77 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
"wecom_ai_bot platform event raw_message should be a dict"
)
stream_id = raw.get("stream_id", self.session_id)
pending_response = self.queue_mgr.get_pending_response(stream_id) or {}
connection_mode = pending_response.get("callback_params", {}).get(
"connection_mode"
)
req_id = pending_response.get("callback_params", {}).get("req_id")
back_queue = self.queue_mgr.get_or_create_back_queue(stream_id)
if (
connection_mode == "long_connection"
and self.long_connection_sender
and isinstance(req_id, str)
and req_id
):
if self.only_use_webhook_url_to_send and self.webhook_client:
merged_chain = MessageChain([])
async for chain in generator:
merged_chain.chain.extend(chain.chain)
merged_chain.squash_plain()
await self.webhook_client.send_message_chain(merged_chain)
await self.long_connection_sender(
req_id,
{
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": True,
"content": "",
},
},
)
await super().send_streaming(generator, use_fallback)
return
increment_plain = ""
async for chain in generator:
if self.webhook_client:
await self.webhook_client.send_message_chain(
chain,
unsupported_only=True,
)
chain.squash_plain()
chunk_text = self._extract_plain_text_from_chain(chain)
if chunk_text:
increment_plain += chunk_text
await self.long_connection_sender(
req_id,
{
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": False,
"content": increment_plain,
},
},
)
await self.long_connection_sender(
req_id,
{
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": True,
"content": increment_plain,
},
},
)
await super().send_streaming(generator, use_fallback)
return
if self.only_use_webhook_url_to_send and self.webhook_client:
merged_chain = MessageChain([])
async for chain in generator:
@@ -0,0 +1,236 @@
"""企业微信智能机器人长连接客户端。"""
import asyncio
import json
import uuid
from collections.abc import Awaitable, Callable
from typing import Any
import aiohttp
from astrbot.api import logger
class WecomAIBotLongConnectionClient:
"""企业微信智能机器人 WebSocket 长连接客户端。"""
def __init__(
self,
bot_id: str,
secret: str,
ws_url: str,
heartbeat_interval: int,
message_handler: Callable[[dict[str, Any]], Awaitable[None]],
) -> None:
self.bot_id = bot_id
self.secret = secret
self.ws_url = ws_url
self.heartbeat_interval = max(5, int(heartbeat_interval))
self.message_handler = message_handler
self._session: aiohttp.ClientSession | None = None
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._shutdown_event = asyncio.Event()
self._send_lock = asyncio.Lock()
self._command_lock = asyncio.Lock()
self._response_waiters: dict[str, asyncio.Future[dict[str, Any]]] = {}
@staticmethod
def gen_req_id() -> str:
return uuid.uuid4().hex
async def start(self) -> None:
"""启动长连接并自动重连。"""
reconnect_delay = 1
while not self._shutdown_event.is_set():
try:
await self._run_once()
reconnect_delay = 1
except asyncio.CancelledError:
raise
except Exception as e:
logger.error("[WecomAI][LongConn] 长连接异常: %s", e)
if self._shutdown_event.is_set():
break
await asyncio.sleep(reconnect_delay)
reconnect_delay = min(reconnect_delay * 2, 30)
async def _run_once(self) -> None:
timeout = aiohttp.ClientTimeout(total=None, sock_connect=15, sock_read=None)
async with aiohttp.ClientSession(timeout=timeout) as session:
self._session = session
logger.info("[WecomAI][LongConn] 正在连接: %s", self.ws_url)
async with session.ws_connect(
self.ws_url, heartbeat=None, autoping=True
) as ws:
self._ws = ws
await self._subscribe()
logger.info("[WecomAI][LongConn] 订阅成功,已建立长连接")
heartbeat_task = asyncio.create_task(self._heartbeat_loop())
try:
while not self._shutdown_event.is_set():
message = await ws.receive()
if message.type == aiohttp.WSMsgType.TEXT:
await self._handle_text_message(message.data)
elif message.type in {
aiohttp.WSMsgType.CLOSED,
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.ERROR,
}:
break
finally:
heartbeat_task.cancel()
try:
await heartbeat_task
except asyncio.CancelledError:
pass
self._ws = None
async def _subscribe(self) -> None:
"""发送 aibot_subscribe,并等待响应。"""
req_id = self.gen_req_id()
payload = {
"cmd": "aibot_subscribe",
"headers": {"req_id": req_id},
"body": {"bot_id": self.bot_id, "secret": self.secret},
}
await self._send_json(payload)
if not self._ws:
raise RuntimeError("WebSocket 未建立")
reply = await self._ws.receive(timeout=10)
if reply.type != aiohttp.WSMsgType.TEXT:
raise RuntimeError(f"订阅失败: 非文本响应 {reply.type}")
data = json.loads(reply.data)
if data.get("errcode") != 0:
raise RuntimeError(
f"订阅失败 errcode={data.get('errcode')} errmsg={data.get('errmsg')}"
)
async def _heartbeat_loop(self) -> None:
while not self._shutdown_event.is_set():
await asyncio.sleep(self.heartbeat_interval)
if self._shutdown_event.is_set():
break
try:
await self.send_command("ping", self.gen_req_id(), None)
except Exception as e:
logger.warning("[WecomAI][LongConn] 发送心跳失败: %s", e)
return
async def _handle_text_message(self, text: str) -> None:
try:
payload = json.loads(text)
except json.JSONDecodeError:
logger.warning("[WecomAI][LongConn] 收到非 JSON 消息: %s", text)
return
headers = payload.get("headers") or {}
req_id = headers.get("req_id")
if isinstance(req_id, str):
waiter = self._response_waiters.get(req_id)
if waiter and not waiter.done():
waiter.set_result(payload)
return
cmd = payload.get("cmd")
if cmd in {"aibot_msg_callback", "aibot_event_callback"}:
await self.message_handler(payload)
return
if payload.get("errcode") not in (None, 0):
logger.warning(
"[WecomAI][LongConn] 服务端返回错误: errcode=%s errmsg=%s",
payload.get("errcode"),
payload.get("errmsg"),
)
async def send_command(
self,
cmd: str,
req_id: str,
body: dict[str, Any] | None,
) -> bool:
"""发送长连接命令。"""
headers = {"req_id": req_id}
payload: dict[str, Any] = {"cmd": cmd, "headers": headers}
if body is not None:
payload["body"] = body
async with self._command_lock:
max_retries = 3
for attempt in range(max_retries + 1):
response = await self._send_and_wait_response(req_id, payload)
if not response:
if attempt < max_retries:
await asyncio.sleep(min(0.2 * (2**attempt), 2.0))
continue
return False
errcode = response.get("errcode")
if errcode in (0, None):
return True
if errcode == 6000 and attempt < max_retries:
backoff = min(0.2 * (2**attempt), 2.0)
logger.warning(
"[WecomAI][LongConn] 命令冲突(errcode=6000),将重试。cmd=%s req_id=%s attempt=%d",
cmd,
req_id,
attempt + 1,
)
await asyncio.sleep(backoff)
continue
logger.warning(
"[WecomAI][LongConn] 命令失败: cmd=%s req_id=%s errcode=%s errmsg=%s",
cmd,
req_id,
errcode,
response.get("errmsg"),
)
return False
return False
async def _send_and_wait_response(
self,
req_id: str,
payload: dict[str, Any],
timeout: float = 10.0,
) -> dict[str, Any] | None:
loop = asyncio.get_running_loop()
waiter: asyncio.Future[dict[str, Any]] = loop.create_future()
self._response_waiters[req_id] = waiter
try:
await self._send_json(payload)
return await asyncio.wait_for(waiter, timeout=timeout)
except TimeoutError:
logger.warning(
"[WecomAI][LongConn] 等待命令响应超时: cmd=%s req_id=%s",
payload.get("cmd"),
req_id,
)
return None
finally:
self._response_waiters.pop(req_id, None)
async def _send_json(self, payload: dict[str, Any]) -> None:
ws = self._ws
if ws is None or ws.closed:
raise RuntimeError("长连接尚未建立")
async with self._send_lock:
await ws.send_json(payload)
async def shutdown(self) -> None:
self._shutdown_event.set()
ws = self._ws
if ws is not None and not ws.closed:
await ws.close()
session = self._session
if session is not None and not session.closed:
await session.close()
@@ -4,6 +4,7 @@
"""
import asyncio
import time
from collections.abc import Awaitable, Callable
from typing import Any
@@ -82,7 +83,7 @@ class WecomAIQueueMgr:
del self.pending_responses[session_id]
logger.debug(f"[WecomAI] 移除待处理响应: {session_id}")
if mark_finished:
self.completed_streams[session_id] = asyncio.get_event_loop().time()
self.completed_streams[session_id] = time.monotonic()
logger.debug(f"[WecomAI] 标记流已结束: {session_id}")
def remove_queue(self, session_id: str):
@@ -135,7 +136,7 @@ class WecomAIQueueMgr:
"""
self.pending_responses[session_id] = {
"callback_params": callback_params,
"timestamp": asyncio.get_event_loop().time(),
"timestamp": time.monotonic(),
}
logger.debug(f"[WecomAI] 设置待处理响应: {session_id}")
@@ -160,7 +161,7 @@ class WecomAIQueueMgr:
finished_at = self.completed_streams.get(session_id)
if finished_at is None:
return False
if asyncio.get_event_loop().time() - finished_at > max_age_seconds:
if time.monotonic() - finished_at > max_age_seconds:
self.completed_streams.pop(session_id, None)
return False
return True
@@ -172,7 +173,7 @@ class WecomAIQueueMgr:
max_age_seconds: 最大存活时间
"""
current_time = asyncio.get_event_loop().time()
current_time = time.monotonic()
expired_sessions = []
for session_id, response_data in self.pending_responses.items():
@@ -369,7 +369,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
if future:
logger.debug(f"duplicate message id checked: {msg.id}")
else:
future = asyncio.get_event_loop().create_future()
future = asyncio.get_running_loop().create_future()
self.wexin_event_workers[msg_id] = future
await self.convert_message(msg, future)
# I love shield so much!
@@ -461,7 +461,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
elif msg.type == "voice":
assert isinstance(msg, VoiceMessage)
resp: Response = await asyncio.get_event_loop().run_in_executor(
resp: Response = await asyncio.get_running_loop().run_in_executor(
None,
self.client.media.download,
msg.media_id,
+454 -116
View File
@@ -4,7 +4,11 @@ import asyncio
import copy
import json
import os
from collections.abc import AsyncGenerator, Awaitable, Callable
import threading
import urllib.parse
from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping
from dataclasses import dataclass
from types import MappingProxyType
from typing import Any
import aiohttp
@@ -17,6 +21,103 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
DEFAULT_MCP_INIT_TIMEOUT_SECONDS = 180.0
DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS = 180.0
MCP_INIT_TIMEOUT_ENV = "ASTRBOT_MCP_INIT_TIMEOUT"
ENABLE_MCP_TIMEOUT_ENV = "ASTRBOT_MCP_ENABLE_TIMEOUT"
MAX_MCP_TIMEOUT_SECONDS = 300.0
class MCPInitError(Exception):
"""Base exception for MCP initialization failures."""
class MCPInitTimeoutError(asyncio.TimeoutError, MCPInitError):
"""Raised when MCP client initialization exceeds the configured timeout."""
class MCPAllServicesFailedError(MCPInitError):
"""Raised when all configured MCP services fail to initialize."""
class MCPShutdownTimeoutError(asyncio.TimeoutError):
"""Raised when MCP shutdown exceeds the configured timeout."""
def __init__(self, names: list[str], timeout: float) -> None:
self.names = names
self.timeout = timeout
message = f"MCP 服务关闭超时({timeout:g} 秒):{', '.join(names)}"
super().__init__(message)
@dataclass
class MCPInitSummary:
total: int
success: int
failed: list[str]
@dataclass
class _MCPServerRuntime:
name: str
client: MCPClient
shutdown_event: asyncio.Event
lifecycle_task: asyncio.Task[None]
class _MCPClientDictView(Mapping[str, MCPClient]):
"""Read-only view of MCP clients derived from runtime state."""
def __init__(self, runtime: dict[str, _MCPServerRuntime]) -> None:
self._runtime = runtime
def __getitem__(self, key: str) -> MCPClient:
return self._runtime[key].client
def __iter__(self):
return iter(self._runtime)
def __len__(self) -> int:
return len(self._runtime)
def _resolve_timeout(
timeout: float | int | str | None = None,
*,
env_name: str = MCP_INIT_TIMEOUT_ENV,
default: float = DEFAULT_MCP_INIT_TIMEOUT_SECONDS,
) -> float:
"""Resolve timeout with precedence: explicit argument > env value > default."""
source = f"环境变量 {env_name}"
if timeout is None:
timeout = os.getenv(env_name, str(default))
else:
source = "显式参数 timeout"
try:
timeout_value = float(timeout)
except (TypeError, ValueError):
logger.warning(
f"超时配置({source}={timeout!r} 无效,使用默认值 {default:g} 秒。"
)
return default
if timeout_value <= 0:
logger.warning(
f"超时配置({source}={timeout_value:g} 必须大于 0,使用默认值 {default:g} 秒。"
)
return default
if timeout_value > MAX_MCP_TIMEOUT_SECONDS:
logger.warning(
f"超时配置({source}={timeout_value:g} 过大,已限制为最大值 "
f"{MAX_MCP_TIMEOUT_SECONDS:g} 秒,以避免长时间等待。"
)
return MAX_MCP_TIMEOUT_SECONDS
return timeout_value
SUPPORTED_TYPES = [
"string",
"number",
@@ -106,9 +207,49 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
class FunctionToolManager:
def __init__(self) -> None:
self.func_list: list[FuncTool] = []
self.mcp_client_dict: dict[str, MCPClient] = {}
"""MCP 服务列表"""
self.mcp_client_event: dict[str, asyncio.Event] = {}
self._mcp_server_runtime: dict[str, _MCPServerRuntime] = {}
"""MCP 服务运行时状态(唯一事实来源)"""
self._mcp_server_runtime_view = MappingProxyType(self._mcp_server_runtime)
self._mcp_client_dict_view = _MCPClientDictView(self._mcp_server_runtime)
self._timeout_mismatch_warned = False
self._timeout_warn_lock = threading.Lock()
self._runtime_lock = asyncio.Lock()
self._mcp_starting: set[str] = set()
self._init_timeout_default = _resolve_timeout(
timeout=None,
env_name=MCP_INIT_TIMEOUT_ENV,
default=DEFAULT_MCP_INIT_TIMEOUT_SECONDS,
)
self._enable_timeout_default = _resolve_timeout(
timeout=None,
env_name=ENABLE_MCP_TIMEOUT_ENV,
default=DEFAULT_ENABLE_MCP_TIMEOUT_SECONDS,
)
self._warn_on_timeout_mismatch(
self._init_timeout_default,
self._enable_timeout_default,
)
@property
def mcp_client_dict(self) -> Mapping[str, MCPClient]:
"""Read-only compatibility view for external callers that still read mcp_client_dict.
Note: Mutating this mapping is unsupported and will raise TypeError.
"""
return self._mcp_client_dict_view
@property
def mcp_server_runtime_view(self) -> Mapping[str, _MCPServerRuntime]:
"""Read-only view of MCP runtime metadata for external callers."""
return self._mcp_server_runtime_view
@property
def mcp_server_runtime(self) -> Mapping[str, _MCPServerRuntime]:
"""Backward-compatible read-only view (deprecated). Do not mutate.
Note: Mutations are not supported and will raise TypeError.
"""
return self._mcp_server_runtime_view
def empty(self) -> bool:
return len(self.func_list) == 0
@@ -179,7 +320,34 @@ class FunctionToolManager:
tool_set = ToolSet(self.func_list.copy())
return tool_set
async def init_mcp_clients(self) -> None:
@staticmethod
def _log_safe_mcp_debug_config(cfg: dict) -> None:
# 仅记录脱敏后的摘要,避免泄露 command/args/url 中的敏感信息
if "command" in cfg:
cmd = cfg["command"]
executable = str(cmd[0] if isinstance(cmd, (list, tuple)) and cmd else cmd)
args_val = cfg.get("args", [])
args_count = (
len(args_val)
if isinstance(args_val, (list, tuple))
else (0 if args_val is None else 1)
)
logger.debug(f" 命令可执行文件: {executable}, 参数数量: {args_count}")
return
if "url" in cfg:
parsed = urllib.parse.urlparse(str(cfg["url"]))
host = parsed.hostname or ""
scheme = parsed.scheme or "unknown"
try:
port = f":{parsed.port}" if parsed.port else ""
except ValueError:
port = ""
logger.debug(f" 主机: {scheme}://{host}{port}")
async def init_mcp_clients(
self, raise_on_all_failed: bool = False
) -> MCPInitSummary:
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
```
{
@@ -197,6 +365,10 @@ class FunctionToolManager:
...
}
```
Timeout behavior:
- 初始化超时使用环境变量 ASTRBOT_MCP_INIT_TIMEOUT 或默认值
- 动态启用超时使用 ASTRBOT_MCP_ENABLE_TIMEOUT独立于初始化超时
"""
data_dir = get_astrbot_data_path()
@@ -206,56 +378,217 @@ class FunctionToolManager:
with open(mcp_json_file, "w", encoding="utf-8") as f:
json.dump(DEFAULT_MCP_CONFIG, f, ensure_ascii=False, indent=4)
logger.info(f"未找到 MCP 服务配置文件,已创建默认配置文件 {mcp_json_file}")
return
return MCPInitSummary(total=0, success=0, failed=[])
mcp_server_json_obj: dict[str, dict] = json.load(
open(mcp_json_file, encoding="utf-8"),
)["mcpServers"]
with open(mcp_json_file, encoding="utf-8") as f:
mcp_server_json_obj: dict[str, dict] = json.load(f)["mcpServers"]
for name in mcp_server_json_obj:
cfg = mcp_server_json_obj[name]
init_timeout = self._init_timeout_default
timeout_display = f"{init_timeout:g}"
active_configs: list[tuple[str, dict, asyncio.Event]] = []
for name, cfg in mcp_server_json_obj.items():
if cfg.get("active", True):
event = asyncio.Event()
asyncio.create_task(
self._init_mcp_client_task_wrapper(name, cfg, event),
)
self.mcp_client_event[name] = event
shutdown_event = asyncio.Event()
active_configs.append((name, cfg, shutdown_event))
async def _init_mcp_client_task_wrapper(
if not active_configs:
return MCPInitSummary(total=0, success=0, failed=[])
logger.info(f"等待 {len(active_configs)} 个 MCP 服务初始化...")
init_tasks = [
asyncio.create_task(
self._start_mcp_server(
name=name,
cfg=cfg,
shutdown_event=shutdown_event,
timeout=init_timeout,
),
name=f"mcp-init:{name}",
)
for (name, cfg, shutdown_event) in active_configs
]
results = await asyncio.gather(*init_tasks, return_exceptions=True)
success_count = 0
failed_services: list[str] = []
for (name, cfg, _), result in zip(active_configs, results, strict=False):
if isinstance(result, Exception):
if isinstance(result, MCPInitTimeoutError):
logger.error(
f"Connected to MCP server {name} timeout ({timeout_display} seconds)"
)
else:
logger.error(f"Failed to initialize MCP server {name}: {result}")
self._log_safe_mcp_debug_config(cfg)
failed_services.append(name)
async with self._runtime_lock:
self._mcp_server_runtime.pop(name, None)
continue
success_count += 1
if failed_services:
logger.warning(
f"The following MCP services failed to initialize: {', '.join(failed_services)}. "
f"Please check the mcp_server.json file and server availability."
)
summary = MCPInitSummary(
total=len(active_configs), success=success_count, failed=failed_services
)
logger.info(
f"MCP services initialization completed: {summary.success}/{summary.total} successful, {len(summary.failed)} failed."
)
if summary.total > 0 and summary.success == 0:
msg = "All MCP services failed to initialize, please check the mcp_server.json and server availability."
if raise_on_all_failed:
raise MCPAllServicesFailedError(msg)
logger.error(msg)
return summary
async def _start_mcp_server(
self,
name: str,
cfg: dict,
event: asyncio.Event,
ready_future: asyncio.Future | None = None,
*,
shutdown_event: asyncio.Event | None = None,
timeout: float,
) -> None:
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
"""Initialize MCP server with timeout and register task/event together.
This method is idempotent. If the server is already running, the existing
runtime is kept and the new config is ignored.
"""
async with self._runtime_lock:
if name in self._mcp_server_runtime or name in self._mcp_starting:
logger.warning(
f"Connected to MCP server {name}, ignoring this startup request (timeout={timeout:g})."
)
self._log_safe_mcp_debug_config(cfg)
return
self._mcp_starting.add(name)
if shutdown_event is None:
shutdown_event = asyncio.Event()
mcp_client: MCPClient | None = None
try:
await self._init_mcp_client(name, cfg)
tools = await self.mcp_client_dict[name].list_tools_and_save()
if ready_future and not ready_future.done():
# tell the caller we are ready
ready_future.set_result(tools)
await event.wait()
logger.info(f"收到 MCP 客户端 {name} 终止信号")
except Exception as e:
logger.error(f"初始化 MCP 客户端 {name} 失败", exc_info=True)
if ready_future and not ready_future.done():
ready_future.set_exception(e)
mcp_client = await asyncio.wait_for(
self._init_mcp_client(name, cfg),
timeout=timeout,
)
except asyncio.TimeoutError as exc:
raise MCPInitTimeoutError(
f"Connected to MCP server {name} timeout ({timeout:g} seconds)"
) from exc
except Exception:
logger.error(f"Failed to initialize MCP client {name}", exc_info=True)
raise
finally:
# 无论如何都能清理
await self._terminate_mcp_client(name)
if mcp_client is None:
async with self._runtime_lock:
self._mcp_starting.discard(name)
async def _init_mcp_client(self, name: str, config: dict) -> None:
async def lifecycle() -> None:
try:
await shutdown_event.wait()
logger.info(f"Received shutdown signal for MCP client {name}")
except asyncio.CancelledError:
logger.debug(f"MCP client {name} task was cancelled")
raise
finally:
await self._terminate_mcp_client(name)
lifecycle_task = asyncio.create_task(lifecycle(), name=f"mcp-client:{name}")
async with self._runtime_lock:
self._mcp_server_runtime[name] = _MCPServerRuntime(
name=name,
client=mcp_client,
shutdown_event=shutdown_event,
lifecycle_task=lifecycle_task,
)
self._mcp_starting.discard(name)
async def _shutdown_runtimes(
self,
runtimes: list[_MCPServerRuntime],
timeout: float,
*,
strict: bool = True,
) -> list[str]:
"""Shutdown runtimes and wait for lifecycle tasks to complete."""
lifecycle_tasks = [
runtime.lifecycle_task
for runtime in runtimes
if not runtime.lifecycle_task.done()
]
if not lifecycle_tasks:
return []
for runtime in runtimes:
runtime.shutdown_event.set()
try:
results = await asyncio.wait_for(
asyncio.gather(*lifecycle_tasks, return_exceptions=True),
timeout=timeout,
)
except asyncio.TimeoutError:
pending_names = [
runtime.name
for runtime in runtimes
if not runtime.lifecycle_task.done()
]
for task in lifecycle_tasks:
if not task.done():
task.cancel()
await asyncio.gather(*lifecycle_tasks, return_exceptions=True)
if strict:
raise MCPShutdownTimeoutError(pending_names, timeout)
logger.warning(
"MCP server shutdown timeout (%s seconds), the following servers were not fully closed: %s",
f"{timeout:g}",
", ".join(pending_names),
)
return pending_names
else:
for result in results:
if isinstance(result, asyncio.CancelledError):
logger.debug("MCP lifecycle task was cancelled during shutdown.")
elif isinstance(result, Exception):
logger.error(
"MCP lifecycle task failed during shutdown.",
exc_info=(type(result), result, result.__traceback__),
)
return []
async def _cleanup_mcp_client_safely(
self, mcp_client: MCPClient, name: str
) -> None:
"""安全清理单个 MCP 客户端,避免清理异常中断主流程。"""
try:
await mcp_client.cleanup()
except Exception as cleanup_exc: # noqa: BLE001 - only log here
logger.error(
f"Failed to cleanup MCP client resources {name}: {cleanup_exc}"
)
async def _init_mcp_client(self, name: str, config: dict) -> MCPClient:
"""初始化单个MCP客户端"""
# 先清理之前的客户端,如果存在
if name in self.mcp_client_dict:
await self._terminate_mcp_client(name)
mcp_client = MCPClient()
mcp_client.name = name
self.mcp_client_dict[name] = mcp_client
await mcp_client.connect_to_server(config, name)
tools_res = await mcp_client.list_tools_and_save()
try:
await mcp_client.connect_to_server(config, name)
tools_res = await mcp_client.list_tools_and_save()
except asyncio.CancelledError:
await self._cleanup_mcp_client_safely(mcp_client, name)
raise
except Exception:
await self._cleanup_mcp_client_safely(mcp_client, name)
raise
logger.debug(f"MCP server {name} list tools response: {tools_res}")
tool_names = [tool.name for tool in tools_res.tools]
@@ -275,27 +608,37 @@ class FunctionToolManager:
)
self.func_list.append(func_tool)
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
logger.info(f"Connected to MCP server {name}, Tools: {tool_names}")
return mcp_client
async def _terminate_mcp_client(self, name: str) -> None:
"""关闭并清理MCP客户端"""
if name in self.mcp_client_dict:
client = self.mcp_client_dict[name]
try:
# 关闭MCP连接
await client.cleanup()
except Exception as e:
logger.error(f"清空 MCP 客户端资源 {name}: {e}")
finally:
# Remove client from dict after cleanup attempt (successful or not)
self.mcp_client_dict.pop(name, None)
# 移除关联的FuncTool
self.func_list = [
f
for f in self.func_list
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
]
logger.info(f"已关闭 MCP 服务 {name}")
async with self._runtime_lock:
runtime = self._mcp_server_runtime.get(name)
if runtime:
client = runtime.client
# 关闭MCP连接
await self._cleanup_mcp_client_safely(client, name)
# 移除关联的FuncTool
self.func_list = [
f
for f in self.func_list
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
]
async with self._runtime_lock:
self._mcp_server_runtime.pop(name, None)
self._mcp_starting.discard(name)
logger.info(f"Disconnected from MCP server {name}")
return
# Runtime missing but stale tools may still exist after failed flows.
self.func_list = [
f
for f in self.func_list
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
]
async with self._runtime_lock:
self._mcp_starting.discard(name)
@staticmethod
async def test_mcp_server_connection(config: dict) -> list[str]:
@@ -319,42 +662,36 @@ class FunctionToolManager:
self,
name: str,
config: dict,
event: asyncio.Event | None = None,
ready_future: asyncio.Future | None = None,
timeout: int = 30,
shutdown_event: asyncio.Event | None = None,
timeout: float | int | str | None = None,
) -> None:
"""Enable_mcp_server a new MCP server to the manager and initialize it.
"""Enable a new MCP server and initialize it.
Args:
name (str): The name of the MCP server.
config (dict): Configuration for the MCP server.
event (asyncio.Event): Event to signal when the MCP client is ready.
ready_future (asyncio.Future): Future to signal when the MCP client is ready.
timeout (int): Timeout for the initialization.
name: The name of the MCP server.
config: Configuration for the MCP server.
shutdown_event: Event to signal when the MCP client should shut down.
timeout: Timeout in seconds for initialization.
Uses ASTRBOT_MCP_ENABLE_TIMEOUT by default (separate from init timeout).
Raises:
TimeoutError: If the initialization does not complete within the specified timeout.
MCPInitTimeoutError: If initialization does not complete within timeout.
Exception: If there is an error during initialization.
"""
if not event:
event = asyncio.Event()
if not ready_future:
ready_future = asyncio.Future()
if name in self.mcp_client_dict:
return
asyncio.create_task(
self._init_mcp_client_task_wrapper(name, config, event, ready_future),
if timeout is None:
timeout_value = self._enable_timeout_default
else:
timeout_value = _resolve_timeout(
timeout=timeout,
env_name=ENABLE_MCP_TIMEOUT_ENV,
default=self._enable_timeout_default,
)
await self._start_mcp_server(
name=name,
cfg=config,
shutdown_event=shutdown_event,
timeout=timeout_value,
)
try:
await asyncio.wait_for(ready_future, timeout=timeout)
finally:
self.mcp_client_event[name] = event
if ready_future.done() and ready_future.exception():
exc = ready_future.exception()
if exc is not None:
raise exc
async def disable_mcp_server(
self,
@@ -367,39 +704,40 @@ class FunctionToolManager:
name (str): The name of the MCP server to disable. If None, ALL MCP servers will be disabled.
timeout (int): Timeout.
Raises:
MCPShutdownTimeoutError: If shutdown does not complete within timeout.
Only raised when disabling a specific server (name is not None).
"""
if name:
if name not in self.mcp_client_event:
async with self._runtime_lock:
runtime = self._mcp_server_runtime.get(name)
if runtime is None:
return
client = self.mcp_client_dict.get(name)
self.mcp_client_event[name].set()
if not client:
return
client_running_event = client.running_event
try:
await asyncio.wait_for(client_running_event.wait(), timeout=timeout)
finally:
self.mcp_client_event.pop(name, None)
self.func_list = [
f
for f in self.func_list
if not (isinstance(f, MCPTool) and f.mcp_server_name == name)
]
await self._shutdown_runtimes([runtime], timeout, strict=True)
else:
running_events = [
client.running_event.wait() for client in self.mcp_client_dict.values()
]
for key, event in self.mcp_client_event.items():
event.set()
# waiting for all clients to finish
try:
await asyncio.wait_for(asyncio.gather(*running_events), timeout=timeout)
finally:
self.mcp_client_event.clear()
self.mcp_client_dict.clear()
self.func_list = [
f for f in self.func_list if not isinstance(f, MCPTool)
]
async with self._runtime_lock:
runtimes = list(self._mcp_server_runtime.values())
await self._shutdown_runtimes(runtimes, timeout, strict=False)
def _warn_on_timeout_mismatch(
self,
init_timeout: float,
enable_timeout: float,
) -> None:
if init_timeout == enable_timeout:
return
with self._timeout_warn_lock:
if self._timeout_mismatch_warned:
return
logger.info(
"检测到 MCP 初始化超时与动态启用超时配置不同:"
"初始化使用 %s 秒,动态启用使用 %s 秒。如需一致,请设置相同值。",
f"{init_timeout:g}",
f"{enable_timeout:g}",
)
self._timeout_mismatch_warned = True
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
"""获得 OpenAI API 风格的**已经激活**的工具描述"""
+75 -2
View File
@@ -2,11 +2,13 @@ import asyncio
import copy
import os
import traceback
from collections.abc import Callable
from typing import Protocol, runtime_checkable
from astrbot.core import astrbot_config, logger, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.error_redaction import safe_error
from ..persona_mgr import PersonaManager
from .entities import ProviderType
@@ -71,6 +73,57 @@ class ProviderManager:
self.curr_tts_provider_inst: TTSProvider | None = None
"""默认的 Text To Speech Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。"""
self.db_helper = db_helper
self._provider_change_callback: (
Callable[[str, ProviderType, str | None], None] | None
) = None
self._provider_change_hooks: list[
Callable[[str, ProviderType, str | None], None]
] = []
self._mcp_init_task: asyncio.Task | None = None
def set_provider_change_callback(
self,
cb: Callable[[str, ProviderType, str | None], None] | None,
) -> None:
# Backward-compatible single-callback setter.
# This callback coexists with register_provider_change_hook subscriptions.
self._provider_change_callback = cb
def register_provider_change_hook(
self,
hook: Callable[[str, ProviderType, str | None], None],
) -> None:
if hook not in self._provider_change_hooks:
self._provider_change_hooks.append(hook)
def _notify_provider_changed(
self,
provider_id: str,
provider_type: ProviderType,
umo: str | None,
) -> None:
if self._provider_change_callback is not None:
try:
self._provider_change_callback(provider_id, provider_type, umo)
except Exception as e:
logger.warning(
"调用 provider 变更回调失败: provider_id=%s, type=%s, err=%s",
provider_id,
provider_type,
safe_error("", e),
)
for hook in list(self._provider_change_hooks):
if hook is self._provider_change_callback:
continue
try:
hook(provider_id, provider_type, umo)
except Exception as e:
logger.warning(
"调用 provider 变更钩子失败: provider_id=%s, type=%s, err=%s",
provider_id,
provider_type,
safe_error("", e),
)
@property
def persona_configs(self) -> list:
@@ -111,6 +164,7 @@ class ProviderManager:
f"provider_perf_{provider_type.value}",
provider_id,
)
self._notify_provider_changed(provider_id, provider_type, umo)
return
# 不启用提供商会话隔离模式的情况
@@ -126,6 +180,7 @@ class ProviderManager:
scope="global",
scope_id="global",
)
self._notify_provider_changed(provider_id, provider_type, umo)
elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance(
prov,
STTProvider,
@@ -137,6 +192,7 @@ class ProviderManager:
scope="global",
scope_id="global",
)
self._notify_provider_changed(provider_id, provider_type, umo)
elif provider_type == ProviderType.CHAT_COMPLETION and isinstance(
prov,
Provider,
@@ -148,6 +204,7 @@ class ProviderManager:
scope="global",
scope_id="global",
)
self._notify_provider_changed(provider_id, provider_type, umo)
async def get_provider_by_id(self, provider_id: str) -> Providers | None:
"""根据提供商 ID 获取提供商实例"""
@@ -274,8 +331,17 @@ class ProviderManager:
if not self.curr_tts_provider_inst and self.tts_provider_insts:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
# 初始化 MCP Client 连接
asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients")
async def _init_mcp_clients_bg() -> None:
try:
await self.llm_tools.init_mcp_clients()
except Exception:
logger.error("MCP init background task failed", exc_info=True)
if self._mcp_init_task is None or self._mcp_init_task.done():
self._mcp_init_task = asyncio.create_task(
_init_mcp_clients_bg(),
name="provider-manager:mcp-init",
)
def dynamic_import_provider(self, type: str) -> None:
"""动态导入提供商适配器模块
@@ -744,6 +810,13 @@ class ProviderManager:
await self.load_provider(new_config)
async def terminate(self) -> None:
if self._mcp_init_task and not self._mcp_init_task.done():
self._mcp_init_task.cancel()
try:
await self._mcp_init_task
except asyncio.CancelledError:
pass
for provider_inst in self.provider_insts:
if hasattr(provider_inst, "terminate"):
await provider_inst.terminate() # type: ignore
+18 -1
View File
@@ -281,7 +281,24 @@ class TTSProvider(AbstractProvider):
accumulated_text += text_part
async def test(self) -> None:
await self.get_audio("hi")
audio_path = await self.get_audio("hi")
# 检查生成的音频文件是否有效
if not os.path.exists(audio_path):
raise Exception("TTS test failed: audio file was not created")
file_size = os.path.getsize(audio_path)
if file_size == 0:
raise Exception(
"TTS test failed: generated audio file is empty (0 bytes). "
"Please check your TTS provider configuration, especially required parameters like group_id for MiniMax."
)
# 清理测试文件
try:
os.remove(audio_path)
except Exception:
pass
class EmbeddingProvider(AbstractProvider):
@@ -276,9 +276,24 @@ class ProviderAnthropic(Provider):
llm_response.id = completion.id
llm_response.usage = self._extract_usage(completion.usage)
# TODO(Soulter): 处理 end_turn 情况
# Handle cases where completion only contains ThinkingBlock (e.g., MiniMax max_tokens)
# When stop_reason='max_tokens', the model may return only thinking content
# This is valid and should not raise an exception
if not llm_response.completion_text and not llm_response.tools_call_args:
raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}")
# Guard clause: raise early if no valid content at all
if not llm_response.reasoning_content:
raise ValueError(
f"Anthropic API returned unparsable completion: "
f"no text, tool_use, or thinking content found. "
f"Completion: {completion}"
)
# We have reasoning content (ThinkingBlock) - this is valid
stop_reason = getattr(completion, "stop_reason", "unknown")
logger.debug(
f"Completion contains only ThinkingBlock (stop_reason={stop_reason})"
)
llm_response.completion_text = "" # Ensure empty string, not None
return llm_response
@@ -20,6 +20,7 @@ from ..register import register_provider_adapter
TEMP_DIR = Path(get_astrbot_temp_path()) / "azure_tts"
TEMP_DIR.mkdir(parents=True, exist_ok=True)
AZURE_TTS_SUBSCRIPTION_KEY_PATTERN = r"^(?:[a-zA-Z0-9]{32}|[a-zA-Z0-9]{84})$"
class OTTSProvider:
@@ -116,7 +117,7 @@ class AzureNativeProvider(TTSProvider):
"azure_tts_subscription_key",
"",
).strip()
if not re.fullmatch(r"^[a-zA-Z0-9]{32}$", self.subscription_key):
if not re.fullmatch(AZURE_TTS_SUBSCRIPTION_KEY_PATTERN, self.subscription_key):
raise ValueError("无效的Azure订阅密钥")
self.region = provider_config.get("azure_tts_region", "eastus").strip()
self.endpoint = (
@@ -235,9 +236,9 @@ class AzureTTSProvider(TTSProvider):
raise ValueError(error_msg) from e
except KeyError as e:
raise ValueError(f"配置错误: 缺少必要参数 {e}") from e
if re.fullmatch(r"^[a-zA-Z0-9]{32}$", key_value):
if re.fullmatch(AZURE_TTS_SUBSCRIPTION_KEY_PATTERN, key_value):
return AzureNativeProvider(config, self.provider_settings)
raise ValueError("订阅密钥格式无效,应为32位字母数字或other[...]格式")
raise ValueError("订阅密钥格式无效,应为32位或84位字母数字或other[...]格式")
async def get_audio(self, text: str) -> str:
if isinstance(self.provider, OTTSProvider):
@@ -87,7 +87,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
model: str,
text: str,
) -> tuple[bytes | None, str]:
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
response = await loop.run_in_executor(None, self._call_qwen_tts, model, text)
audio_bytes = await self._extract_audio_from_response(response)
if not audio_bytes:
@@ -143,7 +143,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
voice=self.voice,
format=AudioFormat.WAV_24000HZ_MONO_16BIT,
)
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
audio_bytes = await loop.run_in_executor(
None,
synthesizer.call,
+2 -2
View File
@@ -59,7 +59,7 @@ class GenieTTSProvider(TTSProvider):
filename = f"genie_tts_{uuid.uuid4()}.wav"
path = os.path.join(temp_dir, filename)
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
def _generate(save_path: str) -> None:
assert genie is not None
@@ -85,7 +85,7 @@ class GenieTTSProvider(TTSProvider):
text_queue: asyncio.Queue[str | None],
audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]",
) -> None:
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
while True:
text = await text_queue.get()
@@ -154,6 +154,14 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
audio_stream = self._call_tts_stream(text)
audio = await self._audio_play(audio_stream)
# 检查音频数据是否为空
if not audio or len(audio) == 0:
raise Exception(
"MiniMax TTS API returned empty audio data. "
"Please verify your configuration, especially the 'group_id' parameter. "
"You can find your group_id in Account Management -> Basic Information on the MiniMax platform."
)
# 结果保存至文件
with open(path, "wb") as file:
file.write(audio)
@@ -161,4 +169,4 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
return path
except aiohttp.ClientError as e:
raise e
raise Exception(f"MiniMax TTS API request failed: {e!s}")
@@ -43,7 +43,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
logger.info("下载或者加载 SenseVoice 模型中,这可能需要一些时间 ...")
# 将模型加载放到线程池中执行
self.model = await asyncio.get_event_loop().run_in_executor(
self.model = await asyncio.get_running_loop().run_in_executor(
None,
lambda: SenseVoiceSmall(self.model_name, quantize=True, batch_size=16),
)
@@ -88,7 +88,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
audio_url = output_path
# 使用 run_in_executor 来调用模型进行识别
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
res = await loop.run_in_executor(
None, # 使用默认的线程池
lambda: cast(SenseVoiceSmall, self.model)(
@@ -31,7 +31,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
self.model = None
async def initialize(self) -> None:
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...")
self.model = await loop.run_in_executor(
None,
@@ -50,7 +50,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
return False
async def get_text(self, audio_url: str) -> str:
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
is_tencent = False
+372
View File
@@ -0,0 +1,372 @@
from __future__ import annotations
import hashlib
import json
import os
import re
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from astrbot.core.computer.computer_client import sync_skills_to_active_sandboxes
from astrbot.core.skills.skill_manager import SkillManager
from astrbot.core.utils.astrbot_path import get_astrbot_skills_path
_MAP_VERSION = 1
_MAP_FILE_NAME = "neo_skill_map.json"
_SKILL_NAME_RE = re.compile(r"[^a-zA-Z0-9._-]+")
def _now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
def _to_jsonable(model_like: Any) -> dict[str, Any]:
if isinstance(model_like, dict):
return model_like
if hasattr(model_like, "model_dump"):
dumped = model_like.model_dump()
if isinstance(dumped, dict):
return dumped
return {}
def _parse_frontmatter(text: str) -> tuple[dict[str, str], str]:
if not text.startswith("---"):
return {}, text
lines = text.splitlines()
if not lines or lines[0].strip() != "---":
return {}, text
end_idx = None
for i in range(1, len(lines)):
if lines[i].strip() == "---":
end_idx = i
break
if end_idx is None:
return {}, text
data: dict[str, str] = {}
for line in lines[1:end_idx]:
if ":" not in line:
continue
key, value = line.split(":", 1)
key = key.strip().lower()
value = value.strip().strip('"').strip("'")
if key in {"name", "description"} and value:
data[key] = value
body = "\n".join(lines[end_idx + 1 :]).lstrip("\n")
return data, body
def _derive_description(markdown_body: str) -> str:
lines = markdown_body.splitlines()
heading_idx = None
for i, line in enumerate(lines):
normalized = line.strip().lower()
if normalized in {"## 描述", "## description"}:
heading_idx = i
break
if heading_idx is not None:
for line in lines[heading_idx + 1 :]:
text = line.strip()
if not text:
continue
if text.startswith("#"):
break
return text
for line in lines:
text = line.strip()
if not text or text.startswith("#"):
continue
return text
return ""
def _ensure_skill_frontmatter(markdown: str, *, skill_name: str, skill_key: str) -> str:
frontmatter, body = _parse_frontmatter(markdown)
name = frontmatter.get("name") or skill_name
name = " ".join(str(name).split())
description = frontmatter.get("description") or _derive_description(body)
if not description:
description = f"Synced skill for `{skill_key}`."
description = " ".join(description.split())
header = f"---\nname: {name}\ndescription: {description}\n---\n\n"
body = body.strip("\n")
return f"{header}{body}\n"
@dataclass
class NeoSkillSyncResult:
skill_key: str
local_skill_name: str
release_id: str
candidate_id: str
payload_ref: str
map_path: str
synced_at: str
class NeoSkillSyncManager:
@staticmethod
def sync_result_to_dict(result: NeoSkillSyncResult) -> dict[str, str]:
return {
"skill_key": result.skill_key,
"local_skill_name": result.local_skill_name,
"release_id": result.release_id,
"candidate_id": result.candidate_id,
"payload_ref": result.payload_ref,
"map_path": result.map_path,
"synced_at": result.synced_at,
}
def __init__(
self,
*,
skills_root: str | None = None,
map_path: str | None = None,
) -> None:
self.skills_root = skills_root or get_astrbot_skills_path()
self.map_path = map_path or str(Path(self.skills_root) / _MAP_FILE_NAME)
os.makedirs(self.skills_root, exist_ok=True)
def _load_map(self) -> dict[str, Any]:
if not os.path.exists(self.map_path):
return {"version": _MAP_VERSION, "items": {}}
try:
with open(self.map_path, encoding="utf-8") as f:
data = json.load(f)
if not isinstance(data, dict):
return {"version": _MAP_VERSION, "items": {}}
items = data.get("items", {})
if not isinstance(items, dict):
items = {}
return {"version": int(data.get("version", _MAP_VERSION)), "items": items}
except Exception:
return {"version": _MAP_VERSION, "items": {}}
def _save_map(self, data: dict[str, Any]) -> None:
os.makedirs(os.path.dirname(self.map_path), exist_ok=True)
with open(self.map_path, "w", encoding="utf-8") as f:
json.dump(data, f, ensure_ascii=False, indent=2)
@staticmethod
def normalize_skill_name(skill_key: str) -> str:
normalized = _SKILL_NAME_RE.sub("-", skill_key.strip().lower())
normalized = normalized.strip("._-")
if not normalized:
normalized = "skill"
return f"neo_{normalized}"
def _resolve_local_skill_name(self, skill_key: str, mapping: dict[str, Any]) -> str:
items = mapping.get("items", {})
if not isinstance(items, dict):
items = {}
existing = items.get(skill_key)
if isinstance(existing, dict):
local_name = existing.get("local_skill_name")
if isinstance(local_name, str) and local_name:
return local_name
base = self.normalize_skill_name(skill_key)
used_names = {
str(v.get("local_skill_name"))
for v in items.values()
if isinstance(v, dict) and v.get("local_skill_name")
}
if base not in used_names:
return base
suffix = hashlib.sha1(skill_key.encode("utf-8")).hexdigest()[:8]
return f"{base}-{suffix}"
async def _find_release(self, client: Any, *, release_id: str) -> dict[str, Any]:
offset = 0
while True:
page = await client.skills.list_releases(limit=100, offset=offset)
page_json = _to_jsonable(page)
items = page_json.get("items", [])
if not isinstance(items, list):
items = []
for item in items:
if isinstance(item, dict) and item.get("id") == release_id:
return item
total = int(page_json.get("total", 0) or 0)
offset += len(items)
if offset >= total or not items:
break
raise ValueError(f"Release not found: {release_id}")
async def _find_active_stable_release(
self,
client: Any,
*,
skill_key: str,
) -> dict[str, Any]:
page = await client.skills.list_releases(
skill_key=skill_key,
active_only=True,
stage="stable",
limit=1,
offset=0,
)
page_json = _to_jsonable(page)
items = page_json.get("items", [])
if not isinstance(items, list) or not items:
raise ValueError(
f"No active stable release found for skill_key: {skill_key}"
)
if not isinstance(items[0], dict):
raise ValueError("Unexpected release payload format.")
return items[0]
async def sync_release(
self,
client: Any,
*,
release_id: str | None = None,
skill_key: str | None = None,
require_stable: bool = True,
) -> NeoSkillSyncResult:
if release_id:
release = await self._find_release(client, release_id=release_id)
elif skill_key:
release = await self._find_active_stable_release(
client, skill_key=skill_key
)
else:
raise ValueError("release_id or skill_key is required for sync.")
release_id_val = str(release.get("id") or "")
release_stage_raw = release.get("stage")
release_stage_value = getattr(release_stage_raw, "value", release_stage_raw)
release_stage = str(release_stage_value or "").strip().lower()
skill_key_val = str(release.get("skill_key") or "")
candidate_id = str(release.get("candidate_id") or "")
if not release_id_val or not skill_key_val or not candidate_id:
raise ValueError("Release payload is incomplete.")
if require_stable and release_stage != "stable":
raise ValueError(
"Only stable releases can be synced to local SKILL.md "
f"(got: {release_stage_raw})."
)
candidate = await client.skills.get_candidate(candidate_id)
candidate_json = _to_jsonable(candidate)
payload_ref = candidate_json.get("payload_ref")
if not isinstance(payload_ref, str) or not payload_ref:
raise ValueError("Candidate payload_ref is missing.")
payload_resp = await client.skills.get_payload(payload_ref)
payload_json = _to_jsonable(payload_resp)
payload = payload_json.get("payload")
if not isinstance(payload, dict):
raise ValueError("Skill payload must be a JSON object.")
skill_markdown = payload.get("skill_markdown")
if not isinstance(skill_markdown, str) or not skill_markdown.strip():
raise ValueError(
"payload.skill_markdown is required for stable sync to local skill."
)
mapping = self._load_map()
local_skill_name = self._resolve_local_skill_name(skill_key_val, mapping)
skill_dir = Path(self.skills_root) / local_skill_name
skill_dir.mkdir(parents=True, exist_ok=True)
normalized_markdown = _ensure_skill_frontmatter(
skill_markdown,
skill_name=local_skill_name,
skill_key=skill_key_val,
)
skill_md_path = skill_dir / "SKILL.md"
skill_md_path.write_text(normalized_markdown, encoding="utf-8")
items = mapping.setdefault("items", {})
items[skill_key_val] = {
"local_skill_name": local_skill_name,
"latest_release_id": release_id_val,
"latest_candidate_id": candidate_id,
"latest_payload_ref": payload_ref,
"updated_at": _now_iso(),
}
mapping["version"] = _MAP_VERSION
self._save_map(mapping)
# Ensure local skill is visible to AstrBot skill manager.
SkillManager().set_skill_active(local_skill_name, True)
# Best-effort synchronization to active sandboxes.
await sync_skills_to_active_sandboxes()
return NeoSkillSyncResult(
skill_key=skill_key_val,
local_skill_name=local_skill_name,
release_id=release_id_val,
candidate_id=candidate_id,
payload_ref=payload_ref,
map_path=self.map_path,
synced_at=_now_iso(),
)
async def promote_with_optional_sync(
self,
client: Any,
*,
candidate_id: str,
stage: str,
sync_to_local: bool,
) -> dict[str, Any]:
release = await client.skills.promote_candidate(candidate_id, stage=stage)
release_json = _to_jsonable(release)
sync_json: dict[str, Any] | None = None
rollback_json: dict[str, Any] | None = None
sync_error: str | None = None
if stage == "stable" and sync_to_local:
try:
sync_result = await self.sync_release(
client,
release_id=str(release_json.get("id", "")),
require_stable=True,
)
sync_json = self.sync_result_to_dict(sync_result)
except Exception as err:
sync_error = str(err)
try:
rollback = await client.skills.rollback_release(
str(release_json.get("id", ""))
)
rollback_json = _to_jsonable(rollback)
except Exception as rollback_err:
rollback_msg = str(rollback_err)
if "no previous release exists" in rollback_msg.lower():
rollback_json = {
"skipped": True,
"reason": rollback_msg,
}
else:
raise RuntimeError(
"stable release synced failed and auto rollback also failed; "
f"sync_error={sync_error}; rollback_error={rollback_err}"
) from rollback_err
return {
"release": release_json,
"sync": sync_json,
"rollback": rollback_json,
"sync_error": sync_error,
}
+342 -39
View File
@@ -3,10 +3,12 @@ from __future__ import annotations
import json
import os
import re
import shlex
import shutil
import tempfile
import zipfile
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path, PurePosixPath
from astrbot.core.utils.astrbot_path import (
@@ -16,22 +18,45 @@ from astrbot.core.utils.astrbot_path import (
)
SKILLS_CONFIG_FILENAME = "skills.json"
SANDBOX_SKILLS_CACHE_FILENAME = "sandbox_skills_cache.json"
DEFAULT_SKILLS_CONFIG: dict[str, dict] = {"skills": {}}
# SANDBOX_SKILLS_ROOT = "/home/shared/skills"
SANDBOX_SKILLS_ROOT = "skills"
SANDBOX_WORKSPACE_ROOT = "/workspace"
_SANDBOX_SKILLS_CACHE_VERSION = 1
_SKILL_NAME_RE = re.compile(r"^[A-Za-z0-9._-]+$")
def _is_ignored_zip_entry(name: str) -> bool:
parts = PurePosixPath(name).parts
if not parts:
return True
return parts[0] == "__MACOSX"
@dataclass
class SkillInfo:
name: str
description: str
path: str
active: bool
source_type: str = "local_only"
source_label: str = "local"
local_exists: bool = True
sandbox_exists: bool = False
def _parse_frontmatter_description(text: str) -> str:
"""Extract the ``description`` value from YAML frontmatter.
Expects the standard SKILL.md format used by OpenAI Codex CLI and
Anthropic Claude Skills::
---
name: my-skill
description: What this skill does and when to use it.
---
"""
if not text.startswith("---"):
return ""
lines = text.splitlines()
@@ -53,45 +78,148 @@ def _parse_frontmatter_description(text: str) -> str:
return ""
# Regex for sanitizing paths used in prompt examples — only allow
# safe path characters to prevent prompt injection via crafted skill paths.
_SAFE_PATH_RE = re.compile(r"[^\w./ ,()'\-]", re.UNICODE)
_WINDOWS_DRIVE_PATH_RE = re.compile(r"^[A-Za-z]:(?:/|\\)")
_WINDOWS_UNC_PATH_RE = re.compile(r"^(//|\\\\)[^/\\]+[/\\][^/\\]+")
_CONTROL_CHARS_RE = re.compile(r"[\x00-\x1F\x7F]")
def _is_windows_prompt_path(path: str) -> bool:
if os.name != "nt":
return False
return bool(_WINDOWS_DRIVE_PATH_RE.match(path) or _WINDOWS_UNC_PATH_RE.match(path))
def _sanitize_prompt_path_for_prompt(path: str) -> str:
if not path:
return ""
if _WINDOWS_DRIVE_PATH_RE.match(path) or _WINDOWS_UNC_PATH_RE.match(path):
path = path.replace("\\", "/")
drive_prefix = ""
if _WINDOWS_DRIVE_PATH_RE.match(path):
drive_prefix = path[:2]
path = path[2:]
path = path.replace("`", "")
path = _CONTROL_CHARS_RE.sub("", path)
sanitized = _SAFE_PATH_RE.sub("", path)
return f"{drive_prefix}{sanitized}"
def _sanitize_prompt_description(description: str) -> str:
description = description.replace("`", "")
description = _CONTROL_CHARS_RE.sub(" ", description)
description = " ".join(description.split())
return description
def _sanitize_skill_display_name(name: str) -> str:
if _SKILL_NAME_RE.fullmatch(name):
return name
return "<invalid_skill_name>"
def _build_skill_read_command_example(path: str) -> str:
if path == "<skills_root>/<skill_name>/SKILL.md":
return f"cat {path}"
if _is_windows_prompt_path(path):
command = "type"
path_arg = f'"{path}"'
else:
command = "cat"
path_arg = shlex.quote(path)
return f"{command} {path_arg}"
def build_skills_prompt(skills: list[SkillInfo]) -> str:
skills_lines = []
"""Build the skills section of the system prompt.
Generates a markdown-formatted skill inventory for the LLM. Only
``name`` and ``description`` are shown upfront; the LLM must read
the full ``SKILL.md`` before execution (progressive disclosure).
"""
skills_lines: list[str] = []
example_path = ""
for skill in skills:
display_name = _sanitize_skill_display_name(skill.name)
description = skill.description or "No description"
skills_lines.append(f"- {skill.name}: {description} (file: {skill.path})")
if skill.source_type == "sandbox_only":
description = _sanitize_prompt_description(description)
if not description:
description = "Read SKILL.md for details."
if skill.source_type == "sandbox_only":
rendered_path = (
f"{str(SANDBOX_WORKSPACE_ROOT)}/{str(SANDBOX_SKILLS_ROOT)}/"
f"{display_name}/SKILL.md"
)
else:
rendered_path = _sanitize_prompt_path_for_prompt(skill.path)
if not rendered_path:
rendered_path = "<skills_root>/<skill_name>/SKILL.md"
skills_lines.append(
f"- **{display_name}**: {description}\n File: `{rendered_path}`"
)
if not example_path:
example_path = rendered_path
skills_block = "\n".join(skills_lines)
# Based on openai/codex
# Sanitize example_path — it may originate from sandbox cache (untrusted)
if example_path == "<skills_root>/<skill_name>/SKILL.md":
example_path = "<skills_root>/<skill_name>/SKILL.md"
else:
example_path = _sanitize_prompt_path_for_prompt(example_path)
example_path = example_path or "<skills_root>/<skill_name>/SKILL.md"
example_command = _build_skill_read_command_example(example_path)
return (
"## Skills\n"
"You have many useful skills that can help you accomplish various tasks.\n"
"A skill is a set of local instructions stored in a `SKILL.md` file.\n"
"### Available skills\n"
f"{skills_block}\n"
"### Skill Rules\n"
"\n"
"- Discovery: The list above shows all skills available in this session. Full instructions live in the referenced `SKILL.md`.\n"
"- Trigger rules: Use a skill if the user names it or the task matches its description. Do not carry skills across turns unless re-mentioned\n"
"### How to use a skill (progressive disclosure):\n"
" 0) Mandatory grounding: Before using any skill, you MUST inspect its `SKILL.md` using shell tools"
" (e.g., `cat`, `head`, `sed`, `awk`, `grep`). Do not rely on assumptions or memory.\n"
" 1) Load only directly referenced files, DO NOT bulk-load everything.\n"
" 2) If `scripts/` exist, prefer running or patching them instead of retyping large blocks of code.\n"
" 3) If `assets/` or templates exist, reuse them rather than recreating everything from scratch.\n"
"- Coordination:\n"
" - If multiple skills apply, choose the minimal set that covers the request and state the order in which you will use them.\n"
" - Announce which skill(s) you are using and why (one short line). If you skip an obvious skill, explain why.\n"
" - Prefer to use `astrbot_*` tools to perform skills that need to run scripts.\n"
"- Context hygiene:\n"
" - Avoid deep reference chasing: unless blocked, open only files that are directly linked from `SKILL.md`.\n"
"- Failure handling: If a skill cannot be applied, state the issue and continue with the best alternative.\n"
"### Example\n"
"When you decided to use a skill, use shell tool to read its `SKILL.md`, e.g., `head -40 skills/code_formatter/SKILL.md`, and you can increase or decrease the number of lines as needed.\n"
"## Skills\n\n"
"You have specialized skills — reusable instruction bundles stored "
"in `SKILL.md` files. Each skill has a **name** and a **description** "
"that tells you what it does and when to use it.\n\n"
"### Available skills\n\n"
f"{skills_block}\n\n"
"### Skill rules\n\n"
"1. **Discovery** — The list above is the complete skill inventory "
"for this session. Full instructions are in the referenced "
"`SKILL.md` file.\n"
"2. **When to trigger** — Use a skill if the user names it "
"explicitly, or if the task clearly matches the skill's description. "
"*Never silently skip a matching skill* — either use it or briefly "
"explain why you chose not to.\n"
"3. **Mandatory grounding** — Before executing any skill you MUST "
"first read its `SKILL.md` by running a shell command compatible "
"with the current runtime shell and using the **absolute path** "
f"shown above (e.g. `{example_command}`). "
"Never rely on memory or assumptions about a skill's content.\n"
"4. **Progressive disclosure** — Load only what is directly "
"referenced from `SKILL.md`:\n"
" - If `scripts/` exist, prefer running or patching them over "
"rewriting code from scratch.\n"
" - If `assets/` or templates exist, reuse them.\n"
" - Do NOT bulk-load every file in the skill directory.\n"
"5. **Coordination** — When multiple skills apply, pick the minimal "
"set needed. Announce which skill(s) you are using and why "
"(one short line). Prefer `astrbot_*` tools when running skill "
"scripts.\n"
"6. **Context hygiene** — Avoid deep reference chasing; open only "
"files that are directly linked from `SKILL.md`.\n"
"7. **Failure handling** — If a skill cannot be applied, state the "
"issue clearly and continue with the best alternative.\n"
)
class SkillManager:
def __init__(self, skills_root: str | None = None) -> None:
self.skills_root = skills_root or get_astrbot_skills_path()
self.config_path = os.path.join(get_astrbot_data_path(), SKILLS_CONFIG_FILENAME)
data_path = Path(get_astrbot_data_path())
self.config_path = str(data_path / SKILLS_CONFIG_FILENAME)
self.sandbox_skills_cache_path = str(data_path / SANDBOX_SKILLS_CACHE_FILENAME)
os.makedirs(self.skills_root, exist_ok=True)
def _load_config(self) -> dict:
@@ -108,6 +236,66 @@ class SkillManager:
with open(self.config_path, "w", encoding="utf-8") as f:
json.dump(config, f, ensure_ascii=False, indent=4)
def _load_sandbox_skills_cache(self) -> dict:
if not os.path.exists(self.sandbox_skills_cache_path):
return {"version": _SANDBOX_SKILLS_CACHE_VERSION, "skills": []}
try:
with open(self.sandbox_skills_cache_path, encoding="utf-8") as f:
data = json.load(f)
if not isinstance(data, dict):
return {"version": _SANDBOX_SKILLS_CACHE_VERSION, "skills": []}
skills = data.get("skills", [])
if not isinstance(skills, list):
skills = []
return {
"version": int(data.get("version", _SANDBOX_SKILLS_CACHE_VERSION)),
"skills": skills,
"updated_at": data.get("updated_at"),
}
except Exception:
return {"version": _SANDBOX_SKILLS_CACHE_VERSION, "skills": []}
def _save_sandbox_skills_cache(self, cache: dict) -> None:
cache["version"] = _SANDBOX_SKILLS_CACHE_VERSION
cache["updated_at"] = datetime.now(timezone.utc).isoformat()
with open(self.sandbox_skills_cache_path, "w", encoding="utf-8") as f:
json.dump(cache, f, ensure_ascii=False, indent=2)
def set_sandbox_skills_cache(self, skills: list[dict]) -> None:
"""Persist sandbox skill metadata discovered from runtime side."""
deduped: dict[str, dict[str, str]] = {}
for item in skills:
if not isinstance(item, dict):
continue
name = str(item.get("name", "")).strip()
if not name or not _SKILL_NAME_RE.match(name):
continue
description = str(item.get("description", "") or "")
path = str(item.get("path", "") or "")
if not path:
path = f"{SANDBOX_WORKSPACE_ROOT}/{SANDBOX_SKILLS_ROOT}/{name}/SKILL.md"
deduped[name] = {
"name": name,
"description": description,
"path": path.replace("\\", "/"),
}
cache = {
"version": _SANDBOX_SKILLS_CACHE_VERSION,
"skills": [deduped[name] for name in sorted(deduped)],
}
self._save_sandbox_skills_cache(cache)
def get_sandbox_skills_cache_status(self) -> dict[str, object]:
cache = self._load_sandbox_skills_cache()
skills = cache.get("skills", [])
count = len(skills) if isinstance(skills, list) else 0
return {
"exists": os.path.exists(self.sandbox_skills_cache_path),
"ready": count > 0,
"count": count,
"updated_at": cache.get("updated_at"),
}
def list_skills(
self,
*,
@@ -124,7 +312,21 @@ class SkillManager:
config = self._load_config()
skill_configs = config.get("skills", {})
modified = False
skills: list[SkillInfo] = []
skills_by_name: dict[str, SkillInfo] = {}
sandbox_cached_paths: dict[str, str] = {}
sandbox_cached_descriptions: dict[str, str] = {}
cache_for_paths = self._load_sandbox_skills_cache()
for item in cache_for_paths.get("skills", []):
if not isinstance(item, dict):
continue
name = str(item.get("name", "") or "").strip()
path = str(item.get("path", "") or "").strip().replace("\\", "/")
if not name or not _SKILL_NAME_RE.match(name):
continue
sandbox_cached_descriptions[name] = str(item.get("description", "") or "")
if path:
sandbox_cached_paths[name] = path
for entry in sorted(Path(self.skills_root).iterdir()):
if not entry.is_dir():
@@ -145,36 +347,129 @@ class SkillManager:
description = _parse_frontmatter_description(content)
except Exception:
description = ""
sandbox_exists = (
runtime == "sandbox" and skill_name in sandbox_cached_descriptions
)
source_type = "both" if sandbox_exists else "local_only"
source_label = "synced" if sandbox_exists else "local"
if runtime == "sandbox" and show_sandbox_path:
path_str = f"{SANDBOX_SKILLS_ROOT}/{skill_name}/SKILL.md"
path_str = sandbox_cached_paths.get(skill_name) or (
f"{SANDBOX_WORKSPACE_ROOT}/{SANDBOX_SKILLS_ROOT}/{skill_name}/SKILL.md"
)
else:
path_str = str(skill_md)
path_str = path_str.replace("\\", "/")
skills.append(
SkillInfo(
skills_by_name[skill_name] = SkillInfo(
name=skill_name,
description=description,
path=path_str,
active=active,
source_type=source_type,
source_label=source_label,
local_exists=True,
sandbox_exists=sandbox_exists,
)
if runtime == "sandbox":
cache = self._load_sandbox_skills_cache()
for item in cache.get("skills", []):
if not isinstance(item, dict):
continue
skill_name = str(item.get("name", "")).strip()
if (
not skill_name
or skill_name in skills_by_name
or not _SKILL_NAME_RE.match(skill_name)
):
continue
active = skill_configs.get(skill_name, {}).get("active", True)
if skill_name not in skill_configs:
skill_configs[skill_name] = {"active": active}
modified = True
if active_only and not active:
continue
description = sandbox_cached_descriptions.get(skill_name, "")
if show_sandbox_path:
path_str = f"{SANDBOX_WORKSPACE_ROOT}/{SANDBOX_SKILLS_ROOT}/{skill_name}/SKILL.md"
else:
path_str = sandbox_cached_paths.get(skill_name, "")
if not path_str:
path_str = f"{SANDBOX_WORKSPACE_ROOT}/{SANDBOX_SKILLS_ROOT}/{skill_name}/SKILL.md"
skills_by_name[skill_name] = SkillInfo(
name=skill_name,
description=description,
path=path_str,
path=path_str.replace("\\", "/"),
active=active,
source_type="sandbox_only",
source_label="sandbox_preset",
local_exists=False,
sandbox_exists=True,
)
)
if modified:
config["skills"] = skill_configs
self._save_config(config)
return skills
return [skills_by_name[name] for name in sorted(skills_by_name)]
def is_sandbox_only_skill(self, name: str) -> bool:
skill_dir = Path(self.skills_root) / name
skill_md_exists = (skill_dir / "SKILL.md").exists()
if skill_md_exists:
return False
cache = self._load_sandbox_skills_cache()
skills = cache.get("skills", [])
if not isinstance(skills, list):
return False
for item in skills:
if not isinstance(item, dict):
continue
if str(item.get("name", "")).strip() == name:
return True
return False
def set_skill_active(self, name: str, active: bool) -> None:
if self.is_sandbox_only_skill(name):
raise PermissionError(
"Sandbox preset skill cannot be enabled/disabled from local skill management."
)
config = self._load_config()
config.setdefault("skills", {})
config["skills"][name] = {"active": bool(active)}
self._save_config(config)
def _remove_skill_from_sandbox_cache(self, name: str) -> None:
cache = self._load_sandbox_skills_cache()
skills = cache.get("skills", [])
if not isinstance(skills, list):
return
filtered = [
item
for item in skills
if not (
isinstance(item, dict) and str(item.get("name", "")).strip() == name
)
]
if len(filtered) != len(skills):
cache["skills"] = filtered
self._save_sandbox_skills_cache(cache)
def delete_skill(self, name: str) -> None:
if self.is_sandbox_only_skill(name):
raise PermissionError(
"Sandbox preset skill cannot be deleted from local skill management."
)
skill_dir = Path(self.skills_root) / name
if skill_dir.exists():
shutil.rmtree(skill_dir)
# Ensure UI consistency even when there is no active sandbox session
# to refresh cache from runtime side.
self._remove_skill_from_sandbox_cache(name)
config = self._load_config()
if name in config.get("skills", {}):
config["skills"].pop(name, None)
@@ -188,7 +483,11 @@ class SkillManager:
raise ValueError("Uploaded file is not a valid zip archive.")
with zipfile.ZipFile(zip_path) as zf:
names = [name.replace("\\", "/") for name in zf.namelist()]
names = [
name
for name in (entry.replace("\\", "/") for entry in zf.namelist())
if name and not _is_ignored_zip_entry(name)
]
file_names = [name for name in names if name and not name.endswith("/")]
if not file_names:
raise ValueError("Zip archive is empty.")
@@ -196,7 +495,7 @@ class SkillManager:
top_dirs = {
PurePosixPath(name).parts[0] for name in file_names if name.strip()
}
print(top_dirs)
if len(top_dirs) != 1:
raise ValueError("Zip archive must contain a single top-level folder.")
skill_name = next(iter(top_dirs))
@@ -223,7 +522,11 @@ class SkillManager:
raise ValueError("SKILL.md not found in the skill folder.")
with tempfile.TemporaryDirectory(dir=get_astrbot_temp_path()) as tmp_dir:
zf.extractall(tmp_dir)
for member in zf.infolist():
member_name = member.filename.replace("\\", "/")
if not member_name or _is_ignored_zip_entry(member_name):
continue
zf.extract(member, tmp_dir)
src_dir = Path(tmp_dir) / skill_name
if not src_dir.exists():
raise ValueError("Skill folder not found after extraction.")
+1 -1
View File
@@ -15,4 +15,4 @@ class RegexFilter(HandlerFilter):
self.regex = re.compile(regex)
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
return bool(self.regex.match(event.get_message_str().strip()))
return bool(self.regex.search(event.get_message_str().strip()))
+161 -10
View File
@@ -1,12 +1,14 @@
"""插件的重载、启停、安装、卸载等操作。"""
import asyncio
import contextlib
import functools
import inspect
import json
import logging
import os
import sys
import tempfile
import traceback
from types import ModuleType
@@ -14,7 +16,12 @@ import yaml
from packaging.specifiers import InvalidSpecifier, SpecifierSet
from packaging.version import InvalidVersion, Version
from astrbot.core import logger, pip_installer, sp
from astrbot.core import (
DependencyConflictError,
logger,
pip_installer,
sp,
)
from astrbot.core.agent.handoff import FunctionTool, HandoffTool
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.config.default import VERSION
@@ -24,9 +31,13 @@ from astrbot.core.utils.astrbot_path import (
get_astrbot_config_path,
get_astrbot_path,
get_astrbot_plugin_path,
get_astrbot_temp_path,
)
from astrbot.core.utils.io import remove_dir
from astrbot.core.utils.metrics import Metric
from astrbot.core.utils.requirements_utils import (
plan_missing_requirements_install,
)
from . import StarMetadata
from .command_management import sync_command_configs
@@ -48,6 +59,97 @@ class PluginVersionIncompatibleError(Exception):
"""Raised when plugin astrbot_version is incompatible with current AstrBot."""
class PluginDependencyInstallError(Exception):
"""Raised when plugin dependency installation fails."""
def __init__(
self,
*,
plugin_label: str,
requirements_path: str,
error: Exception,
) -> None:
message = f"插件 {plugin_label} 依赖安装失败: {error!s}"
super().__init__(message)
self.plugin_label = plugin_label
self.requirements_path = requirements_path
self.error = error
@contextlib.contextmanager
def _temporary_filtered_requirements_file(
*,
install_lines: tuple[str, ...],
):
filtered_requirements_path: str | None = None
temp_dir = get_astrbot_temp_path()
try:
os.makedirs(temp_dir, exist_ok=True)
with tempfile.NamedTemporaryFile(
mode="w",
suffix="_plugin_requirements.txt",
delete=False,
dir=temp_dir,
encoding="utf-8",
) as filtered_requirements_file:
filtered_requirements_file.write("\n".join(install_lines) + "\n")
filtered_requirements_path = filtered_requirements_file.name
yield filtered_requirements_path
finally:
if filtered_requirements_path and os.path.exists(filtered_requirements_path):
try:
os.remove(filtered_requirements_path)
except OSError as exc:
logger.warning(
"删除临时插件依赖文件失败:%s(路径:%s",
exc,
filtered_requirements_path,
)
async def _install_requirements_with_precheck(
*,
plugin_label: str,
requirements_path: str,
) -> None:
install_plan = plan_missing_requirements_install(requirements_path)
if install_plan is None:
logger.info(
f"正在安装插件 {plugin_label} 的依赖库(缺失依赖预检查不可裁剪,回退到完整安装): "
f"{requirements_path}"
)
await pip_installer.install(requirements_path=requirements_path)
return
if not install_plan.missing_names:
logger.info(f"插件 {plugin_label} 的依赖已满足,跳过安装。")
return
if not install_plan.install_lines:
fallback_reason = install_plan.fallback_reason or "unknown reason"
logger.info(
"检测到插件 %s 缺失依赖,但无法安全裁剪 requirements,回退到完整安装: %s (%s)",
plugin_label,
requirements_path,
fallback_reason,
)
await pip_installer.install(requirements_path=requirements_path)
return
logger.info(
f"检测到插件 {plugin_label} 缺失依赖,正在按 requirements.txt 安装: "
f"{requirements_path} -> {sorted(install_plan.missing_names)}"
)
with _temporary_filtered_requirements_file(
install_lines=install_plan.install_lines,
) as filtered_requirements_path:
await pip_installer.install(requirements_path=filtered_requirements_path)
class PluginManager:
def __init__(self, context: Context, config: AstrBotConfig) -> None:
from .star_tools import StarTools
@@ -198,15 +300,37 @@ class PluginManager:
to_update.append(p.root_dir_name)
for p in to_update:
plugin_path = os.path.join(plugin_dir, p)
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
pth = os.path.join(plugin_path, "requirements.txt")
logger.info(f"正在安装插件 {p} 所需的依赖库: {pth}")
try:
await pip_installer.install(requirements_path=pth)
except Exception as e:
logger.error(f"更新插件 {p} 的依赖失败。Code: {e!s}")
await self._ensure_plugin_requirements(plugin_path, p)
return True
async def _ensure_plugin_requirements(
self,
plugin_dir_path: str,
plugin_label: str,
) -> None:
requirements_path = os.path.join(plugin_dir_path, "requirements.txt")
if not os.path.exists(requirements_path):
return
try:
await _install_requirements_with_precheck(
plugin_label=plugin_label,
requirements_path=requirements_path,
)
except asyncio.CancelledError:
raise
except DependencyConflictError as e:
logger.error(f"插件 {plugin_label} 依赖冲突: {e!s}")
raise
except Exception as e:
dependency_error = PluginDependencyInstallError(
plugin_label=plugin_label,
requirements_path=requirements_path,
error=e,
)
logger.exception(str(dependency_error))
raise dependency_error from e
async def _import_plugin_with_dependency_recovery(
self,
path: str,
@@ -422,7 +546,7 @@ class PluginManager:
root_dir_name: str,
plugin_dir_path: str,
reserved: bool,
error: Exception | str,
error: BaseException | str,
error_trace: str,
) -> dict:
record: dict = {
@@ -495,6 +619,9 @@ class PluginManager:
self._cleanup_plugin_state(dir_name)
plugin_path = os.path.join(self.plugin_store_path, dir_name)
await self._ensure_plugin_requirements(plugin_path, dir_name)
success, error = await self.load(specified_dir_name=dir_name)
if success:
self.failed_plugin_dict.pop(dir_name, None)
@@ -1078,6 +1205,10 @@ class PluginManager:
# reload the plugin
dir_name = os.path.basename(plugin_path)
await self._ensure_plugin_requirements(
plugin_path,
dir_name,
)
success, error_message = await self.load(
specified_dir_name=dir_name,
ignore_version_check=ignore_version_check,
@@ -1317,6 +1448,12 @@ class PluginManager:
raise Exception("该插件是 AstrBot 保留插件,无法更新。")
await self.updator.update(plugin, proxy=proxy)
if plugin.root_dir_name:
plugin_dir_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
await self._ensure_plugin_requirements(
plugin_dir_path,
plugin_name,
)
await self.reload(plugin_name)
async def turn_off_plugin(self, plugin_name: str) -> None:
@@ -1374,10 +1511,23 @@ class PluginManager:
return
if "__del__" in star_metadata.star_cls_type.__dict__:
asyncio.get_event_loop().run_in_executor(
loop = asyncio.get_running_loop()
future = loop.run_in_executor(
None,
star_metadata.star_cls.__del__,
)
def _log_del_exception(fut: asyncio.Future) -> None:
if fut.cancelled():
return
if (exc := fut.exception()) is not None:
logger.error(
"插件 %s 在 __del__ 中抛出了异常:%r",
star_metadata.name,
exc,
)
future.add_done_callback(_log_del_exception)
elif "terminate" in star_metadata.star_cls_type.__dict__:
await star_metadata.star_cls.terminate()
@@ -1475,6 +1625,7 @@ class PluginManager:
os.remove(zip_file_path)
except BaseException as e:
logger.warning(f"删除插件压缩包失败: {e!s}")
await self._ensure_plugin_requirements(desti_dir, dir_name)
# await self.reload()
success, error_message = await self.load(
specified_dir_name=dir_name,
+1 -1
View File
@@ -30,7 +30,7 @@ class CreateActiveCronTool(FunctionTool[AstrAgentContext]):
"properties": {
"cron_expression": {
"type": "string",
"description": "Cron expression defining recurring schedule (e.g., '0 8 * * *').",
"description": "Cron expression defining recurring schedule (e.g., '0 8 * * *' or '0 23 * * mon-fri'). Prefer named weekdays like 'mon-fri' or 'sat,sun' instead of numeric day-of-week ranges such as '1-5' to avoid ambiguity across cron implementations.",
},
"run_at": {
"type": "string",
+3 -1
View File
@@ -149,7 +149,9 @@ class AstrBotUpdator(RepoZipUpdator):
file_url = None
if os.environ.get("ASTRBOT_CLI") or os.environ.get("ASTRBOT_LAUNCHER"):
raise Exception("不支持更新此方式启动的AstrBot") # 避免版本管理混乱
raise Exception(
"Error: You are running AstrBot via CLI, please use `pip` or `uv tool upgrade` to update AstrBot."
) # 避免版本管理混乱
if latest:
latest_version = update_data[0]["tag_name"]
+121
View File
@@ -0,0 +1,121 @@
import contextlib
import functools
import importlib.metadata as importlib_metadata
import logging
import os
from collections.abc import Iterator
from packaging.requirements import Requirement
from astrbot.core.utils.requirements_utils import (
canonicalize_distribution_name,
collect_installed_distribution_versions,
get_requirement_check_paths,
)
logger = logging.getLogger("astrbot")
def _resolve_core_dist_name(core_dist_name: str | None) -> str | None:
if core_dist_name:
try:
importlib_metadata.distribution(core_dist_name)
return core_dist_name
except importlib_metadata.PackageNotFoundError:
return None
try:
importlib_metadata.distribution("AstrBot")
return "AstrBot"
except importlib_metadata.PackageNotFoundError:
pass
if not __package__:
return None
top_pkg = __package__.split(".")[0]
for dist in importlib_metadata.distributions():
try:
top_level = dist.read_text("top_level.txt") or ""
except Exception:
continue
if top_pkg in top_level.splitlines():
if "Name" in dist.metadata:
return dist.metadata["Name"]
return None
@functools.cache
def _get_core_constraints(core_dist_name: str | None) -> tuple[str, ...]:
try:
resolved_core_dist_name = _resolve_core_dist_name(core_dist_name)
except Exception as exc:
logger.warning("解析核心分发名称失败: %s", exc)
return ()
if not resolved_core_dist_name:
return ()
try:
dist = importlib_metadata.distribution(resolved_core_dist_name)
except importlib_metadata.PackageNotFoundError:
return ()
except Exception as exc:
logger.warning("读取核心分发元数据失败 (%s): %s", resolved_core_dist_name, exc)
return ()
if not dist or not dist.requires:
return ()
installed = collect_installed_distribution_versions(get_requirement_check_paths())
if not installed:
return ()
constraints: list[str] = []
for req_str in dist.requires:
try:
req = Requirement(req_str)
if req.marker and not req.marker.evaluate():
continue
name = canonicalize_distribution_name(req.name)
if name in installed:
constraints.append(f"{name}=={installed[name]}")
except Exception:
continue
return tuple(constraints)
class CoreConstraintsProvider:
def __init__(self, core_dist_name: str | None) -> None:
self._core_dist_name = core_dist_name
@contextlib.contextmanager
def constraints_file(self) -> Iterator[str | None]:
constraints = _get_core_constraints(self._core_dist_name)
if not constraints:
yield None
return
path: str | None = None
try:
import tempfile
with tempfile.NamedTemporaryFile(
mode="w", suffix="_constraints.txt", delete=False, encoding="utf-8"
) as f:
f.write("\n".join(constraints))
path = f.name
logger.info("已启用核心依赖版本保护 (%d 个约束)", len(constraints))
except Exception as exc:
logger.warning("创建临时约束文件失败: %s", exc)
yield None
return
try:
yield path
finally:
if path and os.path.exists(path):
with contextlib.suppress(Exception):
os.remove(path)
+82
View File
@@ -0,0 +1,82 @@
import re
_SECRET_KEYS = (
r"(?:api_?key|access_?token|auth_?token|refresh_?token|session_?id|secret|password)"
)
_JSON_FIELD_PATTERN = re.compile(
rf"(?i)(?P<prefix>(?P<kq>['\"]){_SECRET_KEYS}(?P=kq)\s*:\s*)(?P<vq>['\"])(?P<value>[^'\"]+)(?P=vq)"
)
_AUTH_JSON_FIELD_PATTERN = re.compile(
r"(?i)(?P<prefix>(?P<kq>['\"])authorization(?P=kq)\s*:\s*)(?P<vq>['\"])bearer\s+[^'\"]+(?P=vq)"
)
_QUERY_FIELD_PATTERN = re.compile(
rf"(?i)(?P<prefix>{_SECRET_KEYS}\s*=\s*)(?P<value>[^&'\" ]+)"
)
_QUERY_PARAM_PATTERN = re.compile(
r"(?i)(?P<prefix>[?&](?:api_?key|key|access_?token|auth_?token)=)(?P<value>[^&'\" ]+)"
)
_AUTH_HEADER_PATTERN = re.compile(
r"(?i)(?P<prefix>\bauthorization\s*:\s*bearer\s+)(?P<token>[A-Za-z0-9._\-]+)"
)
_BEARER_PATTERN = re.compile(r"(?i)(?P<prefix>\bbearer\s+)(?P<token>[A-Za-z0-9._\-]+)")
_SK_PATTERN = re.compile(r"\bsk-[A-Za-z0-9]{16,}\b")
def _redact_json_field(match: re.Match[str]) -> str:
quote = match.group("vq")
return f"{match.group('prefix')}{quote}[REDACTED]{quote}"
def _redact_auth_json_field(match: re.Match[str]) -> str:
quote = match.group("vq")
return f"{match.group('prefix')}{quote}Bearer [REDACTED]{quote}"
def _redact_prefixed_value(match: re.Match[str]) -> str:
return f"{match.group('prefix')}[REDACTED]"
def _redact_bearer_token(match: re.Match[str]) -> str:
return f"{match.group('prefix')}[REDACTED]"
def _redact_json_like(text: str) -> str:
text = _JSON_FIELD_PATTERN.sub(_redact_json_field, text)
return _AUTH_JSON_FIELD_PATTERN.sub(_redact_auth_json_field, text)
def _redact_query_like(text: str) -> str:
text = _QUERY_FIELD_PATTERN.sub(_redact_prefixed_value, text)
return _QUERY_PARAM_PATTERN.sub(_redact_prefixed_value, text)
def _redact_tokens(text: str) -> str:
text = _AUTH_HEADER_PATTERN.sub(_redact_bearer_token, text)
text = _BEARER_PATTERN.sub(_redact_bearer_token, text)
return _SK_PATTERN.sub("[REDACTED]", text)
def redact_sensitive_text(text: str) -> str:
text = _redact_json_like(text)
text = _redact_query_like(text)
text = _redact_tokens(text)
return text
def safe_error(
prefix: str,
error: Exception | BaseException | str,
*,
redact: bool = True,
) -> str:
try:
text = str(error)
except Exception:
try:
text = repr(error)
except Exception:
text = "<unprintable error>"
if redact:
text = redact_sensitive_text(text)
return prefix + text
+7 -1
View File
@@ -14,7 +14,7 @@ import certifi
import psutil
from PIL import Image
from .astrbot_path import get_astrbot_data_path, get_astrbot_temp_path
from .astrbot_path import get_astrbot_data_path, get_astrbot_path, get_astrbot_temp_path
logger = logging.getLogger("astrbot")
@@ -219,7 +219,13 @@ def get_local_ip_addresses():
async def get_dashboard_version():
# First check user data directory (manually updated / downloaded dashboard).
dist_dir = os.path.join(get_astrbot_data_path(), "dist")
if not os.path.exists(dist_dir):
# Fall back to the dist bundled inside the installed wheel.
_bundled = Path(get_astrbot_path()) / "astrbot" / "dashboard" / "dist"
if _bundled.exists():
dist_dir = str(_bundled)
if os.path.exists(dist_dir):
version_file = os.path.join(dist_dir, "assets", "version")
if os.path.exists(version_file):
+428 -96
View File
@@ -7,21 +7,71 @@ import io
import logging
import os
import re
import shlex
import sys
import threading
from collections import deque
from dataclasses import dataclass
from urllib.parse import urlparse
from astrbot.core.utils.astrbot_path import get_astrbot_site_packages_path
from astrbot.core.utils.core_constraints import CoreConstraintsProvider
from astrbot.core.utils.requirements_utils import (
canonicalize_distribution_name as _canonicalize_distribution_name,
)
from astrbot.core.utils.requirements_utils import (
extract_requirement_name,
extract_requirement_names,
parse_package_install_input,
)
from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime
logger = logging.getLogger("astrbot")
_DISTLIB_FINDER_PATCH_ATTEMPTED = False
_SITE_PACKAGES_IMPORT_LOCK = threading.RLock()
_PIP_FAILURE_PATTERNS = {
"error_prefix": re.compile(r"^\s*error:", re.IGNORECASE),
"user_requested": re.compile(r"\bthe user requested\b", re.IGNORECASE),
"resolution_impossible": re.compile(r"\bresolutionimpossible\b", re.IGNORECASE),
"cannot_install": re.compile(r"\bcannot install\b", re.IGNORECASE),
"conflict": re.compile(r"\bconflict(?:ing|s)?\b", re.IGNORECASE),
"constraint": re.compile(r"\(constraint\)", re.IGNORECASE),
"dependency_detail": re.compile(r"\bdepends on\b", re.IGNORECASE),
}
_SENSITIVE_PIP_VALUE_KEYS = frozenset(
{"password", "passwd", "pass", "api_token", "token", "auth_token"}
)
_MAX_PIP_OUTPUT_LINES = 200
def _canonicalize_distribution_name(name: str) -> str:
return re.sub(r"[-_.]+", "-", name).strip("-").lower()
class DependencyConflictError(Exception):
"""Raised when pip encounters a dependency conflict."""
def __init__(
self, message: str, errors: list[str], *, is_core_conflict: bool
) -> None:
super().__init__(message)
self.errors = errors
self.is_core_conflict = is_core_conflict
class PipInstallError(Exception):
"""Raised when pip install fails without a classified dependency conflict."""
def __init__(self, message: str, *, code: int) -> None:
super().__init__(message)
self.code = code
@dataclass
class PipConflictContext:
relevant_lines: list[str]
requested_lines: list[str]
dependency_detail_lines: list[str]
constraint_lines: list[str]
has_strong_conflict_signal: bool
has_contextual_conflict_signal: bool
def _get_pip_main():
@@ -41,11 +91,12 @@ def _get_pip_main():
return pip_main
def _run_pip_main_with_output(pip_main, args: list[str]) -> tuple[int, str]:
stream = io.StringIO()
with contextlib.redirect_stdout(stream), contextlib.redirect_stderr(stream):
result_code = pip_main(args)
return result_code, stream.getvalue()
def _prepend_sys_path(path: str) -> None:
normalized_target = os.path.realpath(path)
sys.path[:] = [
item for item in sys.path if os.path.realpath(item) != normalized_target
]
sys.path.insert(0, normalized_target)
def _cleanup_added_root_handlers(original_handlers: list[logging.Handler]) -> None:
@@ -59,76 +110,258 @@ def _cleanup_added_root_handlers(original_handlers: list[logging.Handler]) -> No
handler.close()
def _prepend_sys_path(path: str) -> None:
normalized_target = os.path.realpath(path)
sys.path[:] = [
item for item in sys.path if os.path.realpath(item) != normalized_target
]
sys.path.insert(0, normalized_target)
def _get_trusted_host_for_index_url(index_url: str) -> str | None:
parsed = urlparse(index_url if "://" in index_url else f"//{index_url}")
host = parsed.hostname
if host == "mirrors.aliyun.com":
return host
return None
def _module_exists_in_site_packages(module_name: str, site_packages_path: str) -> bool:
base_path = os.path.join(site_packages_path, *module_name.split("."))
package_init = os.path.join(base_path, "__init__.py")
module_file = f"{base_path}.py"
return os.path.isfile(package_init) or os.path.isfile(module_file)
def _normalize_sensitive_pip_key(raw_key: str) -> str:
return raw_key.lstrip("-").replace("-", "_").lower()
def _is_module_loaded_from_site_packages(
module_name: str,
site_packages_path: str,
) -> bool:
module = sys.modules.get(module_name)
if module is None:
try:
module = importlib.import_module(module_name)
except Exception:
return False
def _is_sensitive_pip_value_key(raw_key: str) -> bool:
return _normalize_sensitive_pip_key(raw_key) in _SENSITIVE_PIP_VALUE_KEYS
module_file = getattr(module, "__file__", None)
if not module_file:
return False
module_path = os.path.realpath(module_file)
site_packages_real = os.path.realpath(site_packages_path)
try:
return (
os.path.commonpath([module_path, site_packages_real]) == site_packages_real
def _redact_url_credentials(raw_value: str) -> str:
"""Redact URL credentials and known inline secret values for safe logging."""
parsed = urlparse(raw_value)
if parsed.netloc and "@" in parsed.netloc:
hostname = parsed.hostname or ""
port = f":{parsed.port}" if parsed.port else ""
return parsed._replace(netloc=f"<redacted>@{hostname}{port}").geturl()
if raw_value.startswith("--"):
option, separator, _ = raw_value.partition("=")
if separator and _is_sensitive_pip_value_key(option):
return f"{option}=****"
return raw_value
key, separator, _ = raw_value.partition("=")
if separator and _is_sensitive_pip_value_key(key):
return f"{key}=****"
return raw_value
def _redact_pip_args_for_logging(args: list[str]) -> list[str]:
redacted_args: list[str] = []
redact_next_value = False
for arg in args:
if redact_next_value:
redacted_args.append("****")
redact_next_value = False
continue
if arg.startswith("--") and "=" in arg:
option, value = arg.split("=", 1)
if _is_sensitive_pip_value_key(option):
redacted_args.append(f"{option}=****")
else:
redacted_args.append(f"{option}={_redact_url_credentials(value)}")
continue
if arg.startswith("-i") and arg != "-i":
redacted_args.append(f"-i{_redact_url_credentials(arg[2:])}")
continue
if _is_sensitive_pip_value_key(arg):
redacted_args.append(arg)
redact_next_value = True
continue
redacted_args.append(_redact_url_credentials(arg))
return redacted_args
def _package_specs_override_index(package_specs: list[str]) -> bool:
for index, spec in enumerate(package_specs):
if spec == "--no-index":
return True
if spec in {"-i", "--index-url"}:
if index + 1 < len(package_specs):
return True
continue
if spec.startswith("--index-url="):
return True
if spec.startswith("-i") and spec != "-i":
return True
return False
class _StreamingLogWriter(io.TextIOBase):
def __init__(self, log_func, *, max_lines: int | None = None) -> None:
self._log_func = log_func
self._lines = deque(maxlen=max_lines or _MAX_PIP_OUTPUT_LINES)
self._buffer = ""
def write(self, text: str) -> int:
if not text:
return 0
self._buffer += text.replace("\r\n", "\n").replace("\r", "\n")
while "\n" in self._buffer:
raw_line, self._buffer = self._buffer.split("\n", 1)
line = raw_line.rstrip("\r\n")
self._log_func(line)
self._lines.append(line)
return len(text)
def flush(self) -> None:
line = self._buffer.rstrip("\r\n")
if line:
self._log_func(line)
self._lines.append(line)
self._buffer = ""
@property
def lines(self) -> list[str]:
return list(self._lines)
def _run_pip_main_streaming(pip_main, args: list[str]) -> tuple[int, list[str]]:
stream = _StreamingLogWriter(logger.info, max_lines=_MAX_PIP_OUTPUT_LINES)
with (
contextlib.redirect_stdout(stream),
contextlib.redirect_stderr(stream),
):
result_code = pip_main(args)
stream.flush()
return result_code, stream.lines
def _matches_pip_failure_pattern(line: str, *pattern_names: str) -> bool:
names = pattern_names or tuple(_PIP_FAILURE_PATTERNS)
return any(_PIP_FAILURE_PATTERNS[name].search(line) for name in names)
def _normalize_conflict_detail_line(line: str) -> str:
stripped = line.strip()
if _matches_pip_failure_pattern(stripped, "user_requested"):
return re.sub(
r"^\s*The user requested\s+",
"",
stripped,
flags=re.IGNORECASE,
)
except ValueError:
return False
return stripped
def _extract_requirement_name(raw_requirement: str) -> str | None:
line = raw_requirement.split("#", 1)[0].strip()
if not line:
return None
if line.startswith(("-r", "--requirement", "-c", "--constraint")):
return None
if line.startswith("-"):
def _build_pip_conflict_context(output_lines: list[str]) -> PipConflictContext | None:
matched_indices = [
index
for index, line in enumerate(output_lines)
if _matches_pip_failure_pattern(line)
]
if matched_indices:
relevant_index_set: set[int] = set()
for index in matched_indices:
start = max(0, index - 1)
end = min(len(output_lines), index + 2)
relevant_index_set.update(range(start, end))
relevant_output_lines = [
line
for index, line in enumerate(output_lines)
if index in relevant_index_set
]
else:
relevant_output_lines = output_lines[-5:]
if not relevant_output_lines:
return None
egg_match = re.search(r"#egg=([A-Za-z0-9_.-]+)", raw_requirement)
if egg_match:
return _canonicalize_distribution_name(egg_match.group(1))
dependency_detail_lines = [
line.strip()
for line in relevant_output_lines
if _matches_pip_failure_pattern(line, "dependency_detail")
]
requested_lines = [
line.strip()
for line in relevant_output_lines
if _matches_pip_failure_pattern(line, "user_requested")
and not _matches_pip_failure_pattern(line, "constraint")
]
if not requested_lines:
requested_lines = [
line
for line in dependency_detail_lines
if not _matches_pip_failure_pattern(line, "constraint")
]
constraint_lines = [
line.strip()
for line in relevant_output_lines
if _matches_pip_failure_pattern(line, "constraint")
]
candidate = re.split(r"[<>=!~;\s\[]", line, maxsplit=1)[0].strip()
if not candidate:
has_strong_conflict_signal = any(
_matches_pip_failure_pattern(
line,
"resolution_impossible",
"cannot_install",
)
for line in relevant_output_lines
)
has_contextual_conflict_signal = any(
_matches_pip_failure_pattern(line, "conflict") for line in relevant_output_lines
) and bool(dependency_detail_lines or requested_lines or constraint_lines)
return PipConflictContext(
relevant_lines=relevant_output_lines,
requested_lines=requested_lines,
dependency_detail_lines=dependency_detail_lines,
constraint_lines=constraint_lines,
has_strong_conflict_signal=has_strong_conflict_signal,
has_contextual_conflict_signal=has_contextual_conflict_signal,
)
def _classify_pip_failure(output_lines: list[str]) -> DependencyConflictError | None:
context = _build_pip_conflict_context(output_lines)
if context is None:
return None
return _canonicalize_distribution_name(candidate)
if (
not context.has_strong_conflict_signal
and not context.has_contextual_conflict_signal
and not (context.requested_lines and context.constraint_lines)
):
return None
def _extract_requirement_names(requirements_path: str) -> set[str]:
names: set[str] = set()
try:
with open(requirements_path, encoding="utf-8") as requirements_file:
for line in requirements_file:
requirement_name = _extract_requirement_name(line)
if requirement_name:
names.add(requirement_name)
except Exception as exc:
logger.warning("读取依赖文件失败,跳过冲突检测: %s", exc)
return names
is_core_conflict = bool(context.constraint_lines)
detail = ""
if context.constraint_lines and context.requested_lines:
detail = (
" 冲突详情: "
f"{_normalize_conflict_detail_line(context.requested_lines[0])} vs "
f"{_normalize_conflict_detail_line(context.constraint_lines[0])}"
)
elif len(context.dependency_detail_lines) >= 2:
detail = (
" 冲突详情: "
f"{_normalize_conflict_detail_line(context.dependency_detail_lines[0])} vs "
f"{_normalize_conflict_detail_line(context.dependency_detail_lines[1])}"
)
if is_core_conflict:
message = (
f"检测到核心依赖版本保护冲突。{detail}插件要求的依赖版本与 AstrBot 核心不兼容,"
"为了系统稳定,已阻止该降级行为。请联系插件作者或调整 requirements.txt。"
)
else:
message = f"检测到依赖冲突。{detail}"
return DependencyConflictError(
message,
context.relevant_lines,
is_core_conflict=is_core_conflict,
)
def _extract_top_level_modules(
@@ -155,7 +388,11 @@ def _collect_candidate_modules(
by_name: dict[str, list[importlib_metadata.Distribution]] = {}
try:
for distribution in importlib_metadata.distributions(path=[site_packages_path]):
distribution_name = distribution.metadata.get("Name")
distribution_name = (
distribution.metadata["Name"]
if "Name" in distribution.metadata
else None
)
if not distribution_name:
continue
canonical_name = _canonicalize_distribution_name(distribution_name)
@@ -173,7 +410,7 @@ def _collect_candidate_modules(
for distribution in by_name.get(requirement_name, []):
for dependency_line in distribution.requires or []:
dependency_name = _extract_requirement_name(dependency_line)
dependency_name = extract_requirement_name(dependency_line)
if not dependency_name:
continue
if dependency_name in expanded_requirement_names:
@@ -230,6 +467,38 @@ def _ensure_preferred_modules(
raise RuntimeError(conflict_message)
def _module_exists_in_site_packages(module_name: str, site_packages_path: str) -> bool:
base_path = os.path.join(site_packages_path, *module_name.split("."))
package_init = os.path.join(base_path, "__init__.py")
module_file = f"{base_path}.py"
return os.path.isfile(package_init) or os.path.isfile(module_file)
def _is_module_loaded_from_site_packages(
module_name: str,
site_packages_path: str,
) -> bool:
module = sys.modules.get(module_name)
if module is None:
try:
module = importlib.import_module(module_name)
except Exception:
return False
module_file = getattr(module, "__file__", None)
if not module_file:
return False
module_path = os.path.realpath(module_file)
site_packages_real = os.path.realpath(site_packages_path)
try:
return (
os.path.commonpath([module_path, site_packages_real]) == site_packages_real
)
except ValueError:
return False
def _prefer_module_from_site_packages(
module_name: str, site_packages_path: str
) -> bool:
@@ -531,9 +800,63 @@ def _patch_distlib_finder_for_frozen_runtime() -> None:
class PipInstaller:
def __init__(self, pip_install_arg: str, pypi_index_url: str | None = None) -> None:
def __init__(
self,
pip_install_arg: str,
pypi_index_url: str | None = None,
core_dist_name: str | None = "AstrBot",
) -> None:
self.pip_install_arg = pip_install_arg
self.pypi_index_url = pypi_index_url
self.core_dist_name = core_dist_name
self._core_constraints = CoreConstraintsProvider(core_dist_name)
def _build_pip_args(
self,
package_name: str | None,
requirements_path: str | None,
mirror: str | None,
) -> tuple[list[str], set[str]]:
args: list[str] = []
requested_requirements: set[str] = set()
normalized_requirements_path = (
requirements_path.strip() if requirements_path else ""
)
if package_name and normalized_requirements_path:
raise ValueError(
"package_name and requirements_path cannot be used together"
)
if package_name:
parsed_package = parse_package_install_input(package_name)
if parsed_package.specs:
args = ["install", *parsed_package.specs]
requested_requirements = set(parsed_package.requirement_names)
elif normalized_requirements_path:
args = ["install", "-r", normalized_requirements_path]
requested_requirements = extract_requirement_names(
normalized_requirements_path
)
if not args:
return [], requested_requirements
pip_install_args = (
shlex.split(self.pip_install_arg) if self.pip_install_arg else []
)
if not _package_specs_override_index([*args[1:], *pip_install_args]):
index_url = mirror or self.pypi_index_url or "https://pypi.org/simple"
trusted_host = _get_trusted_host_for_index_url(index_url)
if trusted_host:
args.extend(["--trusted-host", trusted_host])
args.extend(["-i", index_url])
if pip_install_args:
args.extend(pip_install_args)
return args, requested_requirements
async def install(
self,
@@ -541,36 +864,37 @@ class PipInstaller:
requirements_path: str | None = None,
mirror: str | None = None,
) -> None:
args = ["install"]
requested_requirements: set[str] = set()
if package_name:
args.append(package_name)
requirement_name = _extract_requirement_name(package_name)
if requirement_name:
requested_requirements.add(requirement_name)
elif requirements_path:
args.extend(["-r", requirements_path])
requested_requirements = _extract_requirement_names(requirements_path)
index_url = mirror or self.pypi_index_url or "https://pypi.org/simple"
args.extend(["--trusted-host", "mirrors.aliyun.com", "-i", index_url])
args, requested_requirements = self._build_pip_args(
package_name, requirements_path, mirror
)
if not args:
logger.info("Pip 包管理器跳过安装:未提供有效的包名或 requirements 文件。")
return
target_site_packages = None
if is_packaged_desktop_runtime():
target_site_packages = get_astrbot_site_packages_path()
os.makedirs(target_site_packages, exist_ok=True)
_prepend_sys_path(target_site_packages)
args.extend(["--target", target_site_packages])
args.extend(["--upgrade", "--force-reinstall"])
args.extend(
[
"--target",
target_site_packages,
"--upgrade",
"--upgrade-strategy",
"only-if-needed",
]
)
if self.pip_install_arg:
args.extend(self.pip_install_arg.split())
with self._core_constraints.constraints_file() as constraints_file_path:
if constraints_file_path:
args.extend(["-c", constraints_file_path])
logger.info(f"Pip 包管理器: pip {' '.join(args)}")
result_code = await self._run_pip_in_process(args)
if result_code != 0:
raise Exception(f"安装失败,错误码:{result_code}")
logger.info(
"Pip 包管理器 argv: %s",
["pip", *_redact_pip_args_for_logging(args)],
)
await self._run_pip_with_classification(args)
if target_site_packages:
_prepend_sys_path(target_site_packages)
@@ -589,7 +913,7 @@ class PipInstaller:
if not os.path.isdir(target_site_packages):
return
requested_requirements = _extract_requirement_names(requirements_path)
requested_requirements = extract_requirement_names(requirements_path)
if not requested_requirements:
return
@@ -605,13 +929,21 @@ class PipInstaller:
_patch_distlib_finder_for_frozen_runtime()
original_handlers = list(logging.getLogger().handlers)
result_code, output = await asyncio.to_thread(
_run_pip_main_with_output, pip_main, args
)
for line in output.splitlines():
line = line.strip()
if line:
logger.info(line)
try:
result_code, output_lines = await asyncio.to_thread(
_run_pip_main_streaming, pip_main, args
)
finally:
_cleanup_added_root_handlers(original_handlers)
if result_code != 0:
conflict = _classify_pip_failure(output_lines)
if conflict:
raise conflict
_cleanup_added_root_handlers(original_handlers)
return result_code
async def _run_pip_with_classification(self, args: list[str]) -> None:
result_code = await self._run_pip_in_process(args)
if result_code != 0:
raise PipInstallError(f"安装失败,错误码:{result_code}", code=result_code)
+486
View File
@@ -0,0 +1,486 @@
import importlib.metadata as importlib_metadata
import logging
import os
import re
import shlex
import sys
from collections.abc import Iterable, Iterator, Sequence
from dataclasses import dataclass
from packaging.requirements import InvalidRequirement, Requirement
from packaging.specifiers import SpecifierSet
from packaging.version import InvalidVersion, Version
from astrbot.core.utils.astrbot_path import get_astrbot_site_packages_path
from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime
logger = logging.getLogger("astrbot")
class RequirementsPrecheckFailed(Exception):
"""Raised when the pre-check of requirements fails."""
pass
@dataclass(frozen=True)
class ParsedPackageInput:
specs: tuple[str, ...]
requirement_names: frozenset[str]
@dataclass(frozen=True)
class MissingRequirementsPlan:
missing_names: frozenset[str]
install_lines: tuple[str, ...]
fallback_reason: str | None = None
def canonicalize_distribution_name(name: str) -> str:
return re.sub(r"[-_.]+", "-", name).strip("-").lower()
def strip_inline_requirement_comment(raw_input: str) -> str:
if raw_input.lstrip().startswith("#"):
return ""
return re.split(r"[ \t]+#", raw_input, maxsplit=1)[0].strip()
def _specifier_contains_version(specifier: SpecifierSet, version: str) -> bool:
try:
parsed_version = Version(version)
except InvalidVersion:
return False
return specifier.contains(parsed_version, prereleases=True)
def _looks_like_local_path_reference(token: str) -> bool:
candidate = token.strip()
if not candidate:
return False
return candidate in {".", ".."} or candidate.startswith(
("./", "../", "/", "~/", ".\\", "..\\", "\\")
)
def looks_like_direct_reference(token: str) -> bool:
candidate = token.strip()
if not candidate:
return False
return (
_looks_like_local_path_reference(candidate)
or candidate.startswith("git+")
or "://" in candidate
)
def extract_requirement_name(raw_requirement: str) -> str | None:
line = raw_requirement.split("#", 1)[0].strip()
if not line:
return None
if line.startswith(("-r", "--requirement", "-c", "--constraint")):
return None
egg_match = re.search(r"#egg=([A-Za-z0-9_.-]+)", raw_requirement)
if egg_match:
return canonicalize_distribution_name(egg_match.group(1))
if line.startswith("-"):
return None
candidate = re.split(r"[<>=!~;\s\[]", line, maxsplit=1)[0].strip()
if not candidate:
return None
return canonicalize_distribution_name(candidate)
def _parse_editable_or_direct_name(target: str) -> str | None:
name = extract_requirement_name(target)
if not name:
return None
if "#egg=" in target or not looks_like_direct_reference(target):
return name
return None
def _parse_requirement_name_and_spec(
line: str,
) -> tuple[str | None, SpecifierSet | None]:
if line.startswith(("-c", "--constraint")):
return None, None
try:
req = Requirement(line)
except InvalidRequirement:
tokens = shlex.split(line)
if not tokens:
return None, None
editable_target: str | None = None
if tokens[0] in {"-e", "--editable"} and len(tokens) > 1:
editable_target = tokens[1]
elif tokens[0].startswith("--editable="):
editable_target = tokens[0].split("=", 1)[1]
if editable_target:
name = _parse_editable_or_direct_name(editable_target)
return (name, None) if name else (None, None)
name = _parse_editable_or_direct_name(line)
return (name, None) if name else (None, None)
if req.marker and not req.marker.evaluate():
return None, None
return canonicalize_distribution_name(req.name), (req.specifier or None)
def _parse_requirement_line(
line: str,
) -> tuple[str, SpecifierSet | None] | None:
name, specifier = _parse_requirement_name_and_spec(line)
return (name, specifier) if name else None
def _extract_requirement_names_from_package_tokens(tokens: list[str]) -> frozenset[str]:
requirement_names: set[str] = set()
skip_next_for: str | None = None
for token in tokens:
if skip_next_for:
if skip_next_for == "editable":
name = _parse_editable_or_direct_name(token)
if name:
requirement_names.add(name)
skip_next_for = None
continue
if token in {"-e", "--editable"}:
skip_next_for = "editable"
continue
if token in {
"-i",
"--index-url",
"--extra-index-url",
"-f",
"--find-links",
"--trusted-host",
"-r",
"--requirement",
"-c",
"--constraint",
}:
skip_next_for = "option-value"
continue
if token.startswith(("--editable=",)):
editable_target = token.split("=", 1)[1]
name = _parse_editable_or_direct_name(editable_target)
if name:
requirement_names.add(name)
continue
if token.startswith(
(
"--index-url=",
"--extra-index-url=",
"--find-links=",
"--trusted-host=",
"--requirement=",
"--constraint=",
)
):
continue
if (
(token.startswith("-i") and token != "-i")
or (token.startswith("-f") and token != "-f")
or token == "--no-index"
):
continue
if token.startswith("-"):
continue
name, _ = _parse_requirement_name_and_spec(token)
if name:
requirement_names.add(name)
return frozenset(requirement_names)
def parse_package_install_input(raw_input: str) -> ParsedPackageInput:
specs: list[str] = []
requirement_names: set[str] = set()
normalized = raw_input.strip()
if not normalized:
return ParsedPackageInput(specs=(), requirement_names=frozenset())
for raw_line in normalized.splitlines():
line = strip_inline_requirement_comment(raw_line)
if not line:
continue
try:
Requirement(line)
except InvalidRequirement:
tokens = shlex.split(line)
if not tokens:
continue
specs.extend(tokens)
requirement_names.update(
_extract_requirement_names_from_package_tokens(tokens)
)
continue
specs.append(line)
name, _ = _parse_requirement_name_and_spec(line)
if name:
requirement_names.add(name)
return ParsedPackageInput(
specs=tuple(specs),
requirement_names=frozenset(requirement_names),
)
def _iter_requirement_lines(
requirements_path: str,
_visited: set[str] | None = None,
) -> Iterator[str]:
visited = _visited or set()
resolved_path = os.path.realpath(requirements_path)
if resolved_path in visited:
logger.warning(
"检测到循环依赖的 requirements 包含: %s,将跳过该文件", resolved_path
)
return
visited.add(resolved_path)
with open(resolved_path, encoding="utf-8") as f:
for raw_line in f:
line = strip_inline_requirement_comment(raw_line)
if not line:
continue
tokens = shlex.split(line)
if not tokens:
continue
nested: str | None = None
if tokens[0] in {"-r", "--requirement"} and len(tokens) > 1:
nested = tokens[1]
elif tokens[0].startswith("--requirement="):
nested = tokens[0].split("=", 1)[1]
if nested:
if not os.path.isabs(nested):
nested = os.path.join(os.path.dirname(resolved_path), nested)
yield from _iter_requirement_lines(nested, _visited=visited)
continue
yield line
def iter_requirements(
requirements_path: str | None = None,
lines: Iterable[str] | None = None,
) -> Iterator[tuple[str, SpecifierSet | None]]:
if lines is None:
if requirements_path is None:
raise ValueError("Either requirements_path or lines must be provided")
lines = _iter_requirement_lines(requirements_path)
for line in lines:
parsed = _parse_requirement_line(line)
if parsed is not None:
yield parsed
def extract_requirement_names(requirements_path: str) -> set[str]:
try:
return {
name for name, _ in iter_requirements(requirements_path=requirements_path)
}
except Exception as exc:
logger.warning("读取依赖文件失败,跳过冲突检测: %s", exc)
return set()
def get_requirement_check_paths() -> list[str]:
paths = list(sys.path)
if is_packaged_desktop_runtime():
target_site_packages = get_astrbot_site_packages_path()
if os.path.isdir(target_site_packages):
paths.insert(0, target_site_packages)
return paths
def _canonical_distribution_identity(distribution) -> tuple[str | None, str | None]:
distribution_name = (
distribution.metadata["Name"] if "Name" in distribution.metadata else None
)
if not distribution_name:
return None, None
return canonicalize_distribution_name(distribution_name), distribution.version
def collect_installed_distribution_versions(paths: list[str]) -> dict[str, str] | None:
installed: dict[str, str] = {}
try:
for distribution in importlib_metadata.distributions(path=paths):
distribution_name, version = _canonical_distribution_identity(distribution)
if not distribution_name or not version:
continue
installed.setdefault(distribution_name, version)
except Exception as exc:
logger.warning("读取已安装依赖失败,跳过缺失依赖预检查: %s", exc)
return None
return installed
def _load_requirement_lines_for_precheck(
requirements_path: str,
) -> tuple[bool, list[str] | None]:
try:
requirement_lines = list(_iter_requirement_lines(requirements_path))
except Exception as exc:
logger.warning(
"预检查缺失依赖失败,将回退到完整安装: %s (%s)",
requirements_path,
exc,
)
return False, None
fallback_line = next(
(
line
for line in requirement_lines
if (
(
line.startswith(("-e ", "--editable ", "--editable="))
and "#egg=" not in line
)
or (
_parse_requirement_line(line) is None
and looks_like_direct_reference(line)
)
)
),
None,
)
if fallback_line is not None:
logger.info(
"缺失依赖预检查发现无法安全裁剪的 option/direct-reference 行,将回退到完整安装: %s (%s)",
requirements_path,
fallback_line,
)
return False, None
return True, requirement_lines
def find_missing_requirements(requirements_path: str) -> set[str] | None:
can_precheck, requirement_lines = _load_requirement_lines_for_precheck(
requirements_path
)
if not can_precheck or requirement_lines is None:
return None
return find_missing_requirements_from_lines(requirement_lines)
def find_missing_requirements_from_lines(
requirement_lines: Sequence[str],
) -> set[str] | None:
required = list(iter_requirements(lines=requirement_lines))
if not required:
return set()
installed = collect_installed_distribution_versions(get_requirement_check_paths())
if installed is None:
return None
missing: set[str] = set()
for name, specifier in required:
installed_version = installed.get(name)
if not installed_version:
missing.add(name)
continue
if specifier and not _specifier_contains_version(specifier, installed_version):
missing.add(name)
return missing
def build_missing_requirements_install_lines(
requirements_path: str,
requirement_lines: Sequence[str],
missing_names: set[str] | frozenset[str],
) -> tuple[str, ...] | None:
wanted_names = set(missing_names)
install_lines: list[str] = []
for line in requirement_lines:
parsed = _parse_requirement_line(line)
if parsed is None:
if looks_like_direct_reference(line) or line.startswith(("-", "--")):
logger.debug(
"缺失依赖行筛选回退到完整安装:requirements 中包含无法安全裁剪的 option/direct-reference 行: %s (%s)",
requirements_path,
line,
)
return None
continue
name, _specifier = parsed
if name in wanted_names:
install_lines.append(line)
return tuple(install_lines)
def plan_missing_requirements_install(
requirements_path: str,
) -> MissingRequirementsPlan | None:
can_precheck, requirement_lines = _load_requirement_lines_for_precheck(
requirements_path
)
if not can_precheck or requirement_lines is None:
return None
missing = find_missing_requirements_from_lines(requirement_lines)
if missing is None:
return None
install_lines = build_missing_requirements_install_lines(
requirements_path,
requirement_lines,
missing,
)
if install_lines is None:
return None
if missing and not install_lines:
logger.warning(
"预检查缺失依赖成功,但无法映射到可安装 requirement 行,将回退到完整安装: %s -> %s",
requirements_path,
sorted(missing),
)
return MissingRequirementsPlan(
missing_names=frozenset(missing),
install_lines=(),
fallback_reason="unmapped missing requirement names",
)
return MissingRequirementsPlan(
missing_names=frozenset(missing),
install_lines=install_lines,
)
def find_missing_requirements_or_raise(requirements_path: str) -> set[str]:
missing = find_missing_requirements(requirements_path)
if missing is None:
raise RequirementsPrecheckFailed(f"预检查失败: {requirements_path}")
return missing
+1 -1
View File
@@ -7,4 +7,4 @@ def is_frozen_runtime() -> bool:
def is_packaged_desktop_runtime() -> bool:
return is_frozen_runtime() and os.environ.get("ASTRBOT_DESKTOP_CLIENT") == "1"
return os.environ.get("ASTRBOT_DESKTOP_CLIENT") == "1"
+27 -1
View File
@@ -1,9 +1,13 @@
import asyncio
import threading
import weakref
from collections import defaultdict
from contextlib import asynccontextmanager
class SessionLockManager:
class _PerLoopSessionLockManager:
"""Per-event-loop session lock manager; keeps original simple semantics."""
def __init__(self) -> None:
self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
self._lock_count: dict[str, int] = defaultdict(int)
@@ -26,4 +30,26 @@ class SessionLockManager:
self._lock_count.pop(session_id, None)
class SessionLockManager:
"""Thread-safe session lock manager with per-event-loop isolation."""
def __init__(self) -> None:
self._state_guard = threading.Lock()
self._loop_managers: weakref.WeakKeyDictionary[
asyncio.AbstractEventLoop, _PerLoopSessionLockManager
] = weakref.WeakKeyDictionary()
def _get_loop_manager(self) -> _PerLoopSessionLockManager:
"""Get the lock manager for the current event loop."""
loop = asyncio.get_running_loop()
with self._state_guard:
return self._loop_managers.setdefault(loop, _PerLoopSessionLockManager())
@asynccontextmanager
async def acquire_lock(self, session_id: str):
manager = self._get_loop_manager()
async with manager.acquire_lock(session_id):
yield
session_lock_manager = SessionLockManager()
+11 -1
View File
@@ -977,7 +977,17 @@ class BackupRoute(Route):
if not jwt_secret:
return Response().error("服务器配置错误").__dict__
jwt.decode(token, jwt_secret, algorithms=["HS256"])
# Verify JWT token with strict security options
jwt.decode(
token,
jwt_secret,
algorithms=["HS256"],
options={
"require": ["exp"], # Require expiration claim
"verify_signature": True, # Explicitly verify signature
"verify_exp": True, # Verify expiration
},
)
except jwt.ExpiredSignatureError:
return Response().error("Token 已过期,请刷新页面后重试").__dict__
except jwt.InvalidTokenError:
+105
View File
@@ -206,12 +206,110 @@ def validate_config(data, schema: dict, is_core: bool) -> tuple[list[str], dict]
return errors, data
def _log_computer_config_changes(old_config: dict, new_config: dict) -> None:
"""Compare and log Computer/sandbox configuration changes."""
old_ps = old_config.get("provider_settings", {})
new_ps = new_config.get("provider_settings", {})
# Check computer_use_runtime
old_runtime = old_ps.get("computer_use_runtime", "none")
new_runtime = new_ps.get("computer_use_runtime", "none")
if old_runtime != new_runtime:
logger.info(
"[Computer] Config changed: computer_use_runtime %s -> %s",
old_runtime,
new_runtime,
)
# Check sandbox sub-keys
old_sandbox = old_ps.get("sandbox", {})
new_sandbox = new_ps.get("sandbox", {})
all_keys = set(old_sandbox.keys()) | set(new_sandbox.keys())
for key in sorted(all_keys):
old_val = old_sandbox.get(key)
new_val = new_sandbox.get(key)
if old_val != new_val:
# Mask tokens/secrets in log output
if "token" in key or "secret" in key:
old_display = "***" if old_val else "(empty)"
new_display = "***" if new_val else "(empty)"
else:
old_display = old_val
new_display = new_val
logger.info(
"[Computer] Config changed: sandbox.%s %s -> %s",
key,
old_display,
new_display,
)
async def _validate_neo_connectivity(
post_config: dict,
) -> str | None:
"""Check if Bay is reachable when Shipyard Neo sandbox is configured.
Returns a warning message string if Bay isn't reachable, or None if
everything looks fine (or Neo isn't configured).
"""
ps = post_config.get("provider_settings", {})
runtime = ps.get("computer_use_runtime", "none")
sandbox = ps.get("sandbox", {})
booter = sandbox.get("booter", "")
# Only check when sandbox mode + shipyard_neo is selected
if runtime != "sandbox" or booter != "shipyard_neo":
return None
endpoint = sandbox.get("shipyard_neo_endpoint", "").rstrip("/")
if not endpoint:
return "⚠️ Shipyard Neo endpoint 未设置"
access_token = sandbox.get("shipyard_neo_access_token", "")
if not access_token:
# Try auto-discovery
from astrbot.core.computer.computer_client import _discover_bay_credentials
access_token = _discover_bay_credentials(endpoint)
if not access_token:
return (
"⚠️ 未找到 Bay API Key。请填写访问令牌,"
"或确保 Bay 的 credentials.json 可被自动发现。"
)
# Connectivity check
import aiohttp
health_url = f"{endpoint}/health"
try:
async with aiohttp.ClientSession() as session:
async with session.get(
health_url,
timeout=aiohttp.ClientTimeout(total=5),
) as resp:
if resp.status != 200:
return (
f"⚠️ Bay 健康检查失败 (HTTP {resp.status})"
f"请确认 Bay 正在运行: {endpoint}"
)
except Exception:
return f"⚠️ 无法连接 Bay ({endpoint}),请确认 Bay 已启动。"
return None
def save_config(
post_config: dict, config: AstrBotConfig, is_core: bool = False
) -> None:
"""验证并保存配置"""
errors = None
logger.info(f"Saving config, is_core={is_core}")
# Snapshot old Computer config for change detection
if is_core:
_log_computer_config_changes(dict(config), post_config)
try:
if is_core:
errors, post_config = validate_config(
@@ -512,6 +610,7 @@ class ConfigRoute(Route):
try:
conf_id = self.acm.create_conf(name=name, config=config)
await self.core_lifecycle.reload_pipeline_scheduler(conf_id)
return Response().ok(message="创建成功", data={"conf_id": conf_id}).__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
@@ -551,6 +650,7 @@ class ConfigRoute(Route):
try:
success = self.acm.delete_conf(conf_id)
if success:
self.core_lifecycle.pipeline_scheduler_mapping.pop(conf_id, None)
return Response().ok(message="删除成功").__dict__
return Response().error("删除失败").__dict__
except ValueError as e:
@@ -928,6 +1028,11 @@ class ConfigRoute(Route):
await self._save_astrbot_configs(config, conf_id)
await self.core_lifecycle.reload_pipeline_scheduler(conf_id)
# Non-blocking Bay connectivity check
warning = await _validate_neo_connectivity(config)
if warning:
return Response().ok(None, f"保存成功。{warning}").__dict__
return Response().ok(None, "保存成功~").__dict__
except Exception as e:
logger.error(traceback.format_exc())
+31 -1
View File
@@ -5,7 +5,8 @@ import os
import ssl
import traceback
from dataclasses import dataclass
from datetime import datetime
from datetime import datetime, timezone
from pathlib import Path
import aiohttp
import certifi
@@ -352,6 +353,34 @@ class PluginRoute(Route):
logger.warning(f"获取插件 Logo 失败: {e}")
return None
def _resolve_plugin_dir(self, plugin) -> Path | None:
if not plugin.root_dir_name:
return None
base_dir = Path(
self.plugin_manager.reserved_plugin_path
if plugin.reserved
else self.plugin_manager.plugin_store_path
)
plugin_dir = base_dir / plugin.root_dir_name
if not plugin_dir.is_dir():
return None
return plugin_dir
def _get_plugin_installed_at(self, plugin) -> str | None:
plugin_dir = self._resolve_plugin_dir(plugin)
if plugin_dir is None:
return None
try:
return datetime.fromtimestamp(
plugin_dir.stat().st_mtime,
timezone.utc,
).isoformat()
except OSError as exc:
logger.warning(f"获取插件安装时间失败 {plugin.name}: {exc!s}")
return None
async def get_plugins(self):
_plugin_resp = []
plugin_name = request.args.get("name")
@@ -377,6 +406,7 @@ class PluginRoute(Route):
"logo": f"/api/file/{logo_url}" if logo_url else None,
"support_platforms": plugin.support_platforms,
"astrbot_version": plugin.astrbot_version,
"installed_at": self._get_plugin_installed_at(plugin),
}
# 检查是否为全空的幽灵插件
if not any(
+517 -2
View File
@@ -1,15 +1,49 @@
import os
import re
import shutil
import traceback
import uuid
from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import Any
from quart import request
from quart import request, send_file
from astrbot.core import DEMO_MODE, logger
from astrbot.core.computer.computer_client import (
_discover_bay_credentials,
sync_skills_to_active_sandboxes,
)
from astrbot.core.skills.neo_skill_sync import NeoSkillSyncManager
from astrbot.core.skills.skill_manager import SkillManager
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from .route import Response, Route, RouteContext
def _to_jsonable(value: Any) -> Any:
if isinstance(value, dict):
return {k: _to_jsonable(v) for k, v in value.items()}
if isinstance(value, list):
return [_to_jsonable(v) for v in value]
if hasattr(value, "model_dump"):
return _to_jsonable(value.model_dump())
return value
def _to_bool(value: Any, default: bool = False) -> bool:
if value is None:
return default
if isinstance(value, bool):
return value
if isinstance(value, str):
return value.strip().lower() in {"1", "true", "yes", "y", "on"}
return bool(value)
_SKILL_NAME_RE = re.compile(r"^[A-Za-z0-9._-]+$")
class SkillsRoute(Route):
def __init__(self, context: RouteContext, core_lifecycle) -> None:
super().__init__(context)
@@ -17,18 +51,82 @@ class SkillsRoute(Route):
self.routes = {
"/skills": ("GET", self.get_skills),
"/skills/upload": ("POST", self.upload_skill),
"/skills/batch-upload": ("POST", self.batch_upload_skills),
"/skills/download": ("GET", self.download_skill),
"/skills/update": ("POST", self.update_skill),
"/skills/delete": ("POST", self.delete_skill),
"/skills/neo/candidates": ("GET", self.get_neo_candidates),
"/skills/neo/releases": ("GET", self.get_neo_releases),
"/skills/neo/payload": ("GET", self.get_neo_payload),
"/skills/neo/evaluate": ("POST", self.evaluate_neo_candidate),
"/skills/neo/promote": ("POST", self.promote_neo_candidate),
"/skills/neo/rollback": ("POST", self.rollback_neo_release),
"/skills/neo/sync": ("POST", self.sync_neo_release),
"/skills/neo/delete-candidate": ("POST", self.delete_neo_candidate),
"/skills/neo/delete-release": ("POST", self.delete_neo_release),
}
self.register_routes()
def _get_neo_client_config(self) -> tuple[str, str]:
provider_settings = self.core_lifecycle.astrbot_config.get(
"provider_settings",
{},
)
sandbox = provider_settings.get("sandbox", {})
endpoint = sandbox.get("shipyard_neo_endpoint", "")
access_token = sandbox.get("shipyard_neo_access_token", "")
# Auto-discover token from Bay's credentials.json if not configured
if not access_token and endpoint:
access_token = _discover_bay_credentials(endpoint)
if not endpoint or not access_token:
raise ValueError(
"Shipyard Neo endpoint or access token not configured. "
"Set them in Dashboard or ensure Bay's credentials.json is accessible."
)
return endpoint, access_token
async def _delete_neo_release(
self, client: Any, release_id: str, reason: str | None
):
return await client.skills.delete_release(release_id, reason=reason)
async def _delete_neo_candidate(
self, client: Any, candidate_id: str, reason: str | None
):
return await client.skills.delete_candidate(candidate_id, reason=reason)
async def _with_neo_client(
self,
operation: Callable[[Any], Awaitable[dict]],
) -> dict:
try:
endpoint, access_token = self._get_neo_client_config()
from shipyard_neo import BayClient
async with BayClient(
endpoint_url=endpoint,
access_token=access_token,
) as client:
return await operation(client)
except ValueError as e:
# Config not ready — expected when Neo isn't set up yet
logger.debug("[Neo] %s", e)
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
async def get_skills(self):
try:
provider_settings = self.core_lifecycle.astrbot_config.get(
"provider_settings", {}
)
runtime = provider_settings.get("computer_use_runtime", "local")
skills = SkillManager().list_skills(
skill_mgr = SkillManager()
skills = skill_mgr.list_skills(
active_only=False, runtime=runtime, show_sandbox_path=False
)
return (
@@ -36,6 +134,8 @@ class SkillsRoute(Route):
.ok(
{
"skills": [skill.__dict__ for skill in skills],
"runtime": runtime,
"sandbox_cache": skill_mgr.get_sandbox_skills_cache_status(),
}
)
.__dict__
@@ -70,6 +170,11 @@ class SkillsRoute(Route):
skill_mgr = SkillManager()
skill_name = skill_mgr.install_skill_from_zip(temp_path, overwrite=True)
try:
await sync_skills_to_active_sandboxes()
except Exception:
logger.warning("Failed to sync uploaded skills to active sandboxes.")
return (
Response()
.ok({"name": skill_name}, "Skill uploaded successfully.")
@@ -85,6 +190,161 @@ class SkillsRoute(Route):
except Exception:
logger.warning(f"Failed to remove temp skill file: {temp_path}")
async def batch_upload_skills(self):
"""批量上传多个 skill ZIP 文件"""
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
try:
files = await request.files
file_list = files.getlist("files")
if not file_list:
return Response().error("No files provided").__dict__
succeeded = []
failed = []
skill_mgr = SkillManager()
temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True)
for file in file_list:
filename = os.path.basename(file.filename or "unknown.zip")
temp_path = None
try:
if not filename.lower().endswith(".zip"):
failed.append(
{
"filename": filename,
"error": "Only .zip files are supported",
}
)
continue
temp_path = os.path.join(
temp_dir, f"batch_{uuid.uuid4().hex}_{filename}"
)
await file.save(temp_path)
skill_name = skill_mgr.install_skill_from_zip(
temp_path, overwrite=True
)
succeeded.append({"filename": filename, "name": skill_name})
except Exception as e:
failed.append({"filename": filename, "error": str(e)})
finally:
if temp_path and os.path.exists(temp_path):
try:
os.remove(temp_path)
except Exception:
pass
if succeeded:
try:
await sync_skills_to_active_sandboxes()
except Exception:
logger.warning(
"Failed to sync uploaded skills to active sandboxes."
)
total = len(file_list)
success_count = len(succeeded)
if success_count == total:
message = f"All {total} skill(s) uploaded successfully."
return (
Response()
.ok(
{
"total": total,
"succeeded": succeeded,
"failed": failed,
},
message,
)
.__dict__
)
if success_count == 0:
message = f"Upload failed for all {total} file(s)."
resp = Response().error(message)
resp.data = {
"total": total,
"succeeded": succeeded,
"failed": failed,
}
return resp.__dict__
message = f"Partial success: {success_count}/{total} skill(s) uploaded."
return (
Response()
.ok(
{
"total": total,
"succeeded": succeeded,
"failed": failed,
},
message,
)
.__dict__
)
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
async def download_skill(self):
try:
name = str(request.args.get("name") or "").strip()
if not name:
return Response().error("Missing skill name").__dict__
if not _SKILL_NAME_RE.match(name):
return Response().error("Invalid skill name").__dict__
skill_mgr = SkillManager()
if skill_mgr.is_sandbox_only_skill(name):
return (
Response()
.error(
"Sandbox preset skill cannot be downloaded from local skill files."
)
.__dict__
)
skill_dir = Path(skill_mgr.skills_root) / name
skill_md = skill_dir / "SKILL.md"
if not skill_dir.is_dir() or not skill_md.exists():
return Response().error("Local skill not found").__dict__
export_dir = Path(get_astrbot_temp_path()) / "skill_exports"
export_dir.mkdir(parents=True, exist_ok=True)
zip_base = export_dir / name
zip_path = zip_base.with_suffix(".zip")
if zip_path.exists():
zip_path.unlink()
shutil.make_archive(
str(zip_base),
"zip",
root_dir=str(skill_mgr.skills_root),
base_dir=name,
)
return await send_file(
str(zip_path),
as_attachment=True,
attachment_filename=f"{name}.zip",
conditional=True,
)
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
async def update_skill(self):
if DEMO_MODE:
return (
@@ -117,7 +377,262 @@ class SkillsRoute(Route):
if not name:
return Response().error("Missing skill name").__dict__
SkillManager().delete_skill(name)
try:
await sync_skills_to_active_sandboxes()
except Exception:
logger.warning("Failed to sync deleted skills to active sandboxes.")
return Response().ok({"name": name}).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
async def get_neo_candidates(self):
logger.info("[Neo] GET /skills/neo/candidates requested.")
status = request.args.get("status")
skill_key = request.args.get("skill_key")
limit = int(request.args.get("limit", 100))
offset = int(request.args.get("offset", 0))
async def _do(client):
candidates = await client.skills.list_candidates(
status=status,
skill_key=skill_key,
limit=limit,
offset=offset,
)
result = _to_jsonable(candidates)
total = result.get("total", "?") if isinstance(result, dict) else "?"
logger.info(f"[Neo] Candidates fetched: total={total}")
return Response().ok(result).__dict__
return await self._with_neo_client(_do)
async def get_neo_releases(self):
logger.info("[Neo] GET /skills/neo/releases requested.")
skill_key = request.args.get("skill_key")
stage = request.args.get("stage")
active_only = _to_bool(request.args.get("active_only"), False)
limit = int(request.args.get("limit", 100))
offset = int(request.args.get("offset", 0))
async def _do(client):
releases = await client.skills.list_releases(
skill_key=skill_key,
active_only=active_only,
stage=stage,
limit=limit,
offset=offset,
)
result = _to_jsonable(releases)
total = result.get("total", "?") if isinstance(result, dict) else "?"
logger.info(f"[Neo] Releases fetched: total={total}")
return Response().ok(result).__dict__
return await self._with_neo_client(_do)
async def get_neo_payload(self):
logger.info("[Neo] GET /skills/neo/payload requested.")
payload_ref = request.args.get("payload_ref", "")
if not payload_ref:
return Response().error("Missing payload_ref").__dict__
async def _do(client):
payload = await client.skills.get_payload(payload_ref)
logger.info(f"[Neo] Payload fetched: ref={payload_ref}")
return Response().ok(_to_jsonable(payload)).__dict__
return await self._with_neo_client(_do)
async def evaluate_neo_candidate(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
logger.info("[Neo] POST /skills/neo/evaluate requested.")
data = await request.get_json()
candidate_id = data.get("candidate_id")
passed_value = data.get("passed")
if not candidate_id or passed_value is None:
return Response().error("Missing candidate_id or passed").__dict__
passed = _to_bool(passed_value, False)
async def _do(client):
result = await client.skills.evaluate_candidate(
candidate_id,
passed=passed,
score=data.get("score"),
benchmark_id=data.get("benchmark_id"),
report=data.get("report"),
)
logger.info(
f"[Neo] Candidate evaluated: id={candidate_id}, passed={passed}"
)
return Response().ok(_to_jsonable(result)).__dict__
return await self._with_neo_client(_do)
async def promote_neo_candidate(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
logger.info("[Neo] POST /skills/neo/promote requested.")
data = await request.get_json()
candidate_id = data.get("candidate_id")
stage = data.get("stage", "canary")
sync_to_local = _to_bool(data.get("sync_to_local"), True)
if not candidate_id:
return Response().error("Missing candidate_id").__dict__
if stage not in {"canary", "stable"}:
return Response().error("Invalid stage, must be canary/stable").__dict__
async def _do(client):
sync_mgr = NeoSkillSyncManager()
result = await sync_mgr.promote_with_optional_sync(
client,
candidate_id=candidate_id,
stage=stage,
sync_to_local=sync_to_local,
)
release_json = result.get("release")
logger.info(f"[Neo] Candidate promoted: id={candidate_id}, stage={stage}")
sync_json = result.get("sync")
did_sync_to_local = bool(sync_json)
if did_sync_to_local:
logger.info(
f"[Neo] Stable release synced to local: skill={sync_json.get('local_skill_name', '')}"
)
if result.get("sync_error"):
resp = Response().error(
"Stable promote synced failed and has been rolled back. "
f"sync_error={result['sync_error']}"
)
resp.data = {
"release": release_json,
"rollback": result.get("rollback"),
}
return resp.__dict__
# Try to push latest local skills to all active sandboxes.
if not did_sync_to_local:
try:
await sync_skills_to_active_sandboxes()
except Exception:
logger.warning("Failed to sync skills to active sandboxes.")
return Response().ok({"release": release_json, "sync": sync_json}).__dict__
return await self._with_neo_client(_do)
async def rollback_neo_release(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
logger.info("[Neo] POST /skills/neo/rollback requested.")
data = await request.get_json()
release_id = data.get("release_id")
if not release_id:
return Response().error("Missing release_id").__dict__
async def _do(client):
result = await client.skills.rollback_release(release_id)
logger.info(f"[Neo] Release rolled back: id={release_id}")
return Response().ok(_to_jsonable(result)).__dict__
return await self._with_neo_client(_do)
async def sync_neo_release(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
logger.info("[Neo] POST /skills/neo/sync requested.")
data = await request.get_json()
release_id = data.get("release_id")
skill_key = data.get("skill_key")
require_stable = _to_bool(data.get("require_stable"), True)
if not release_id and not skill_key:
return Response().error("Missing release_id or skill_key").__dict__
async def _do(client):
sync_mgr = NeoSkillSyncManager()
result = await sync_mgr.sync_release(
client,
release_id=release_id,
skill_key=skill_key,
require_stable=require_stable,
)
logger.info(
f"[Neo] Release synced to local: skill={result.local_skill_name}, "
f"release_id={result.release_id}"
)
return (
Response()
.ok(
{
"skill_key": result.skill_key,
"local_skill_name": result.local_skill_name,
"release_id": result.release_id,
"candidate_id": result.candidate_id,
"payload_ref": result.payload_ref,
"map_path": result.map_path,
"synced_at": result.synced_at,
}
)
.__dict__
)
return await self._with_neo_client(_do)
async def delete_neo_candidate(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
logger.info("[Neo] POST /skills/neo/delete-candidate requested.")
data = await request.get_json()
candidate_id = data.get("candidate_id")
reason = data.get("reason")
if not candidate_id:
return Response().error("Missing candidate_id").__dict__
async def _do(client):
result = await self._delete_neo_candidate(client, candidate_id, reason)
logger.info(f"[Neo] Candidate deleted: id={candidate_id}")
return Response().ok(_to_jsonable(result)).__dict__
return await self._with_neo_client(_do)
async def delete_neo_release(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
logger.info("[Neo] POST /skills/neo/delete-release requested.")
data = await request.get_json()
release_id = data.get("release_id")
reason = data.get("reason")
if not release_id:
return Response().error("Missing release_id").__dict__
async def _do(client):
result = await self._delete_neo_release(client, release_id, reason)
logger.info(f"[Neo] Release deleted: id={release_id}")
return Response().ok(_to_jsonable(result)).__dict__
return await self._with_neo_client(_do)

Some files were not shown because too many files have changed in this diff Show More