Compare commits

..

24 Commits

Author SHA1 Message Date
Soulter 7d31140c14 chore: bump version to 4.19.4 2026-03-09 11:13:39 +08:00
Soulter 654112ca86 feat(wecomai): implement long connection mode and update configuration options (#5930) 2026-03-09 11:10:32 +08:00
Soulter 5dd30f9a45 chore: bump version to 4.19.3 2026-03-09 00:20:33 +08:00
Jason a53a1ca49b fix(provider): handle MiniMax ThinkingBlock when max_tokens reached (#5913)
* fix(provider): handle MiniMax ThinkingBlock when max_tokens reached

Fixes #5912

Problem: MiniMax API returns ThinkingBlock when stop_reason='max_tokens',
but AstrBot throws 'completion 无法解析' exception because both
completion_text and tools_call_args are empty.

Root cause: The validation logic didn't consider ThinkingBlock
(reasoning_content) as valid content.

Fix: When completion_text and tools_call_args are empty but
reasoning_content is present, treat it as valid instead of throwing
exception. This happens when the model thinks but runs out of tokens
before generating the actual response.

Impact: MiniMax models now work correctly when responses are truncated
due to max_tokens limit.

* refactor: address review feedback

1. Use getattr for safe stop_reason access (prevent AttributeError)
2. Use ValueError instead of generic Exception for better error handling

Thanks @gemini-code-assist and @sourcery-ai for the review!

* refactor: flatten nested if/else with guard clause

Address Gemini Code Assist feedback:
- Use guard clause for early return
- Flattened nested conditional for better readability

Logic unchanged, just cleaner code structure.

* fix(provider): improve logging for ThinkingBlock completions in ProviderAnthropic

---------

Co-authored-by: Soulter <905617992@qq.com>
2026-03-09 00:17:11 +08:00
whatevertogo 3fd6c4c8a6 fix: 修复 asyncio 事件循环相关问题 (#5774)
* fix: 修复 asyncio 事件循环相关的问题

1. components.py: 修复异常处理结构错误
   - 将 except Exception 移到正确的内部 try 块
   - 确保 _download_file() 异常能被正确捕获和记录

2. session_lock.py: 修复跨事件循环 Lock 绑定问题
   - 添加 _access_lock_loop_id 追踪事件循环
   - 当事件循环变化时重新创建 Lock

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: 根据代码审查反馈修复问题

1. components.py: 移除 asyncio.set_event_loop() 调用
   - 创建临时 event loop 时不再设置为全局
   - 避免干扰其他 asyncio 使用

2. session_lock.py: 简化延迟初始化逻辑
   - 移除 loop-ID 追踪和 _get_lock 方法
   - 使用 setdefault 简化 session lock 创建
   - 保留延迟初始化行为

3. wecomai_queue_mgr.py: 使用 time.monotonic() 替代 loop.time()
   - 同步方法不再依赖活动的 event loop
   - 避免在非异步上下文中抛出 RuntimeError

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: 优化 asyncio 事件循环管理,使用安全的方式创建和关闭事件循环

* fix: 根据代码审查反馈改进异常处理和事件循环使用

- main.py: 显式处理 check_dashboard_files() 返回 None 的情况
- components.py: 使用 logger.exception 保留异常堆栈信息
- star_manager.py: 添加 Future 异常回调处理 __del__ 执行异常
- bay_manager.py: 缓存事件循环引用避免重复调用

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* refactor: 简化 SessionLockManager 使用 defaultdict 和 setdefault

- 使用 defaultdict(asyncio.Lock) 简化锁的懒创建
- 使用 setdefault 简化 _get_loop_state 逻辑
- 减少 get + if 分支,提升可读性

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: 降低 webui_dir 检查失败时的日志级别为 warning

改为警告而非退出,允许程序在无 WebUI 的情况下继续运行

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* refactor: 重构事件循环锁管理,简化锁状态管理逻辑

* 新增对 SessionLockManager 的多事件循环隔离测试

* fix: 修复测试中的变量声明和断言,确保事件循环管理器的正确性

* fix: 修复插件删除时异常处理逻辑,确保正确记录错误信息

* fix: 新增针对多个事件循环的 OneBot 实例的测试,确保锁对象在不同事件循环间不共享

---------

Co-authored-by: whatevertogo <whatevertogo@users.noreply.github.com>
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-09 01:00:13 +09:00
sanyekana 5808784f07 fix: prevent crash on malformed MCP server config (#5666) (#5673)
* fix: prevent crash on malformed MCP server config (#5666)

* fix: prevent crash on malformed MCP server config (#5666)

* fix: validate MCP connection before persisting server config

* fix: guard mcpServers type before iterating server list

* refactor: use typed empty-config error and extract MCP rollback helper

* fix: translate error messages and comments to English for consistency

---------

Co-authored-by: Soulter <905617992@qq.com>
2026-03-08 23:46:32 +08:00
Soulter 537849c1e7 fix(dingtalk): text is ignored; cannot send file actively (#5921) 2026-03-08 23:31:11 +08:00
Soulter 7f3c0fdeb2 fix: cannot receive image, file in dingtalk (#5920)
fixes: #5916 #5786
2026-03-08 23:18:56 +08:00
Windy_cold 8e431e2076 correct openrouter api_base (#5911) 2026-03-08 21:53:56 +09:00
ChuwuYo 89c11fd683 fix(extension): support searching installed plugins by display name (#5806) (#5811)
* fix(extension): support searching installed plugins by display name

* fix: unify plugin search matching across installed and market tabs

* refactor(extension): optimize plugin search matcher and remove redundant checks

* refactor(extension-page): centralize search query normalization and text matching logic

- Extract `buildSearchQuery` to create normalized query objects from raw input
- Extract `matchesText` as a reusable text matching helper for normalized/loose/pinyin/initials matching
- Remove unused `marketCustomFilter` to eliminate dead code
- Simplify `matchesPluginSearch` to accept query object instead of pre-normalized string
- Replace Set with Array for candidates to simplify control flow
- Avoid redundant normalization by having callers pass raw strings to `buildSearchQuery`

* refactor: remove unused marketCustomFilter from extension page components

- Remove marketCustomFilter from destructuring in ExtensionPage.vue, InstalledPluginsTab.vue, and MarketPluginsTab.vue

* refactor(extension): extract plugin search utilities into shared module

- Create pluginSearch.js to centralize plugin search helpers
- Move `normalizeStr`, `normalizeLoose`, `toPinyinText`, and `toInitials` into the shared module
- Add `buildSearchQuery`, `matchesText`, and `matchesPluginSearch` for reusable search matching
- Refactor useExtensionPage.js to consume the shared utilities
- Simplify plugin search logic by consolidating normalization and matching in one place

* refactor(extension): add caching to pinyin utilities and extract search fields helper

- Add Map-based caching for `toPinyinText` and `toInitials` to avoid redundant pinyin computation
- Extract `getPluginSearchFields` function to retrieve plugin fields for searching
- Improve plugin search performance with caching and better code organization

* perf(extension): add bounded caching for plugin search

- cap normalization and pinyin caches with `MAX_SEARCH_CACHE_SIZE`
- add `setCacheValue()` for oldest-entry eviction
- cache normalized and loose text values to avoid repeated string processing
- skip pinyin matching for non-CJK text using Unicode `\p{Unified_Ideograph}` property
- improve search performance while keeping memory usage bounded

* refactor(extension): extract memoizeLRU helper for cache management

- Create `memoizeLRU` higher-order function to generate LRU-cached functions
- Replace manual cache implementation with `memoizeLRU` for cleaner code
- Optimize `matchesText` to lazily compute looseValue only when needed
- Simplify caching logic while maintaining bounded cache size

* refactor(extension): simplify memoization and remove LRU logic

- Rename `memoizeLRU` to `memoizeStringFn` and remove bounded cache size
- Simplify cache hit logic for cleaner code
- Remove `MAX_SEARCH_CACHE_SIZE` constant as it's no longer needed
2026-03-08 17:41:45 +09:00
時壹 7cfe2aca99 fix: apply reply_with_quote and reply_with_mention to image-only response (#5219)
* fix: apply reply_with_quote and reply_with_mention to image-only responses

* fix: restrict reply_with_quote and reply_with_mention to plain-text/image chains
2026-03-08 17:41:12 +09:00
時壹 3a938d2a13 fix: use re.search instead of re.match in RegexFilter (#5368) 2026-03-08 17:40:40 +09:00
whatevertogo 812834bc9f feat(skills): add batch upload functionality for multiple skill ZIP files (#5804)
* feat(skills): add batch upload functionality for multiple skill ZIP files

- Implemented a new endpoint for batch uploading skills.
- Enhanced the SkillsSection component to support multiple file selection and drag-and-drop functionality.
- Updated localization files for new upload features and messages.
- Added tests to validate batch upload behavior and error handling.

* feat(skills): improve batch upload handling and enhance accessibility for dropzone

* feat(skills): enhance batch upload process and improve UI for better user experience

* feat(skills): enhance skills upload dialog layout and styling for improved usability

* feat(skills): update upload dialog description styling for better visibility and usability

* feat(skills): improve upload dialog button styling and layout for enhanced usability

* feat(skills): refine upload dialog text for clarity and consistency

* feat(skills): enhance batch upload functionality by ignoring __MACOSX entries and improving upload dialog styling

* feat(skills): refactor upload dialog and button styles for improved consistency and usability

---------

Co-authored-by: whatevertogo <whatevertogo@users.noreply.github.com>
2026-03-07 23:18:01 +08:00
エイカク 51ff4f6e46 fix: detect desktop runtime without frozen python (#5859)
* fix: detect desktop runtime without frozen python

* chore: drop planning docs from runtime fix pr
2026-03-07 21:42:56 +09:00
Soulter 7ac169c5e8 docs: add macOS usage note and update instructions for astrbot in multiple README files 2026-03-06 14:34:29 +08:00
Soulter 61648ebe3e docs: add new QQ group entries to README files 2026-03-06 11:11:12 +08:00
Soulter 0610f0db0a fix: pipeline scheduler not found after creating platform bot via using 'create new config' (#5776) 2026-03-05 23:53:53 +08:00
whatevertogo 8c935981bb fix: align aiocqhttp poke payload with onebot v11 (#5773)
Co-authored-by: whatevertogo <whatevertogo@users.noreply.github.com>
2026-03-05 23:02:26 +08:00
Ruochen Pan 3f3b4e4924 test(skill_manager): update sandbox cache path expectations (#5706)
* test(skill_manager): update sandbox cache path expectations

adjust sandbox cache tests to match absolute path resolution in
list_skills for sandbox runtime.

verify sandbox-cached skills cannot be deactivated via set_skill_active
by asserting a PermissionError, and keep active-only listing behavior
intact.

add coverage for show_sandbox_path=false to ensure local skills still
override cached metadata while sandbox-only skills retain cached paths.

* test(skill_manager): tighten local skill path assertions
2026-03-05 22:47:20 +08:00
Soulter af581e7f21 chore: bump version to 4.19.2 2026-03-05 16:10:09 +08:00
Soulter 9e371ee10b chore: update shipyard-neo-sdk dependency to version 0.2.0 2026-03-05 16:07:10 +08:00
camera-2018 7cf77adbc8 feat(telegram): supports sendMessageDraft API (#5726)
* feat(telegram): 使用 sendMessageDraft API 实现私聊流式输出

- 新增 _send_message_draft 方法封装 Telegram Bot API sendMessageDraft
- 私聊流式输出使用 sendMessageDraft 推送草稿动画,群聊保留 edit_message_text 回退
- 使用独立异步发送循环 (_draft_sender_loop) 按固定间隔推送最新缓冲区内容,
  完全解耦 token 到达速度与 API 网络延迟
- 流式结束后发送真实消息保留最终内容(draft 是临时的)
- 使用模块级递增 draft_id 替代随机生成,确保 Telegram 端动画连续性

* fix(telegram): convert draft text to Markdown before sending message draft

* chore(telegram): telegram 适配器重构

- 提取公共方法
- 有新 token 到达时触发流式
- 生成结束后清除draft内容
- 默认draft发送md格式

* style(telegram): ruff format

* style(telegram): ruff check

---------

Co-authored-by: Soulter <905617992@qq.com>
2026-03-05 11:20:28 +08:00
Soulter 31673ee521 fix: require node.js env when uv sync 2026-03-05 11:20:16 +08:00
エイカク ff22030dde docs: align deployment sections across multilingual readmes (#5734)
* docs: align deployment sections across multilingual readmes

* docs: normalize deployment punctuation and AUR guidance

* docs: fix french and russian deployment wording
2026-03-05 11:19:27 +08:00
72 changed files with 3602 additions and 1186 deletions
+12
View File
@@ -83,6 +83,15 @@ astrbot
> Requires [uv](https://docs.astral.sh/uv/) to be installed.
> [!NOTE]
> For macOS user: due to macOS security checks, the first run of the `astrbot` command may take longer (about 10-20s).
Update `astrbot`:
```bash
uv tool upgrade astrbot
```
### Docker Deployment
For users familiar with containers and looking for a more stable, production-ready deployment method, we recommend deploying AstrBot with Docker / Docker Compose.
@@ -216,12 +225,15 @@ pre-commit install
### QQ Groups
- Group 9: 1076659624 (New)
- Group 10: 1078079676 (New)
- Group 1: 322154837
- Group 3: 630166526
- Group 5: 822130018
- Group 6: 753075035
- Group 7: 743746109
- Group 8: 1030353265
- Developer Group: 975206796
### Discord Server
+9
View File
@@ -83,6 +83,15 @@ astrbot
> [uv](https://docs.astral.sh/uv/) doit être installé.
> [!NOTE]
> Pour les utilisateurs macOS : en raison des vérifications de sécurité de macOS, la première exécution de la commande `astrbot` peut prendre plus de temps (environ 10-20s).
Mettre à jour `astrbot` :
```bash
uv tool upgrade astrbot
```
### Déploiement Docker
Pour les utilisateurs familiers avec les conteneurs et qui souhaitent une méthode plus stable et adaptée à la production, nous recommandons de déployer AstrBot avec Docker / Docker Compose.
+9
View File
@@ -83,6 +83,15 @@ astrbot
> [uv](https://docs.astral.sh/uv/) のインストールが必要です。
> [!NOTE]
> macOS ユーザーの場合:macOS のセキュリティチェックにより、`astrbot` コマンドの初回実行に時間がかかる場合があります(約 10〜20 秒)。
`astrbot` の更新:
```bash
uv tool upgrade astrbot
```
### Docker デプロイ
コンテナ運用に慣れており、より安定した本番向けのデプロイ方法を求めるユーザーには、Docker / Docker Compose での AstrBot デプロイをおすすめします。
+9
View File
@@ -83,6 +83,15 @@ astrbot
> Требуется установленный [uv](https://docs.astral.sh/uv/).
> [!NOTE]
> Для пользователей macOS: из-за проверок безопасности macOS первый запуск команды `astrbot` может занять больше времени (около 10-20 секунд).
Обновить `astrbot`:
```bash
uv tool upgrade astrbot
```
### Развёртывание Docker
Для пользователей, знакомых с контейнерами и которым нужен более стабильный и подходящий для production способ, мы рекомендуем разворачивать AstrBot через Docker / Docker Compose.
+13
View File
@@ -83,6 +83,15 @@ astrbot
> 需要安裝 [uv](https://docs.astral.sh/uv/)。
> [!NOTE]
> 對於 macOS 使用者:由於 macOS 安全性檢查,首次執行 `astrbot` 指令可能需要較長時間(約 10-20 秒)。
更新 `astrbot`
```bash
uv tool upgrade astrbot
```
### Docker 部署
對於熟悉容器、希望獲得更穩定且更適合正式環境部署方式的使用者,我們推薦使用 Docker / Docker Compose 部署 AstrBot。
@@ -208,10 +217,14 @@ pre-commit install
### QQ 群組
- 9 群: 1076659624 (新)
- 10 群: 1078079676 (新)
- 1 群:322154837
- 3 群:630166526
- 5 群:822130018
- 6 群:753075035
- 7 群:743746109
- 8 群:1030353265
- 開發者群:975206796
### Discord 群組
+11
View File
@@ -83,6 +83,15 @@ astrbot
> 需要安装 [uv](https://docs.astral.sh/uv/)。
> [!NOTE]
> 对于 macOS 用户:由于 macOS 安全检查,首次运行 `astrbot` 命令可能需要较长时间(约 10-20 秒)。
更新 `astrbot`
```bash
uv tool upgrade astrbot
```
### Docker 部署
对于熟悉容器、希望获得更稳定且更适合生产环境部署方式的用户,我们推荐使用 Docker / Docker Compose 部署 AstrBot。
@@ -209,6 +218,8 @@ pre-commit install
### QQ 群组
- 9 群: 1076659624 (新)
- 10 群: 1078079676 (新)
- 1 群:322154837
- 3 群:630166526
- 5 群:822130018
+1 -1
View File
@@ -1 +1 @@
__version__ = "4.19.1"
__version__ = "4.19.4"
@@ -302,7 +302,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
while True:
try:
item_type, item_data = await asyncio.get_event_loop().run_in_executor(
item_type, item_data = await asyncio.get_running_loop().run_in_executor(
None, response_queue.get, True, 1
)
except queue.Empty:
@@ -388,7 +388,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
# 发起请求
partial = functools.partial(Application.call, **payload)
response = await asyncio.get_event_loop().run_in_executor(None, partial)
response = await asyncio.get_running_loop().run_in_executor(None, partial)
async for resp in self._handle_streaming_response(response, session_id):
yield resp
+3 -2
View File
@@ -121,11 +121,12 @@ class BayContainerManager:
async def wait_healthy(self, timeout: int = HEALTH_TIMEOUT_S) -> None:
"""Block until Bay's ``/health`` endpoint returns 200."""
url = f"http://127.0.0.1:{self._host_port}/health"
deadline = asyncio.get_event_loop().time() + timeout
loop = asyncio.get_running_loop()
deadline = loop.time() + timeout
last_error: str = ""
async with aiohttp.ClientSession() as session:
while asyncio.get_event_loop().time() < deadline:
while loop.time() < deadline:
try:
async with session.get(
url, timeout=aiohttp.ClientTimeout(total=3)
+69 -6
View File
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.19.1"
VERSION = "4.19.4"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
WEBHOOK_SUPPORTED_PLATFORMS = [
@@ -342,14 +342,20 @@ CONFIG_METADATA_2 = {
"企业微信智能机器人": {
"id": "wecom_ai_bot",
"type": "wecom_ai_bot",
"hint": "如果发现字段有异常,请重新创建",
"enable": True,
"wecom_ai_bot_connection_mode": "long_connection", # long_connection, webhook
"wecom_ai_bot_name": "",
"wecomaibot_ws_bot_id": "",
"wecomaibot_ws_secret": "",
"wecomaibot_token": "",
"wecomaibot_encoding_aes_key": "",
"wecomaibot_init_respond_text": "",
"wecomaibot_friend_message_welcome_text": "",
"wecom_ai_bot_name": "",
"msg_push_webhook_url": "",
"only_use_webhook_url_to_send": False,
"token": "",
"encoding_aes_key": "",
"wecomaibot_ws_url": "wss://openws.work.weixin.qq.com",
"wecomaibot_heartbeat_interval": 30,
"unified_webhook_mode": True,
"webhook_uuid": "",
"callback_server_host": "0.0.0.0",
@@ -732,6 +738,13 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "请务必填写正确,否则无法使用一些指令。",
},
"wecom_ai_bot_connection_mode": {
"description": "企业微信智能机器人连接模式",
"type": "string",
"options": ["webhook", "long_connection"],
"labels": ["Webhook 回调", "长连接"],
"hint": "Webhook 回调模式需要配置 Token/EncodingAESKey。长连接模式需要配置 BotID/Secret。",
},
"wecomaibot_init_respond_text": {
"description": "企业微信智能机器人初始响应文本",
"type": "string",
@@ -742,6 +755,22 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "当用户当天进入智能机器人单聊会话,回复欢迎语,留空则不回复。",
},
"wecomaibot_token": {
"description": "企业微信智能机器人 Token",
"type": "string",
"hint": "用于 Webhook 回调模式的身份验证。",
"condition": {
"wecom_ai_bot_connection_mode": "webhook",
},
},
"wecomaibot_encoding_aes_key": {
"description": "企业微信智能机器人 EncodingAESKey",
"type": "string",
"hint": "用于 Webhook 回调模式的消息加密解密。",
"condition": {
"wecom_ai_bot_connection_mode": "webhook",
},
},
"msg_push_webhook_url": {
"description": "企业微信消息推送 Webhook URL",
"type": "string",
@@ -752,6 +781,40 @@ CONFIG_METADATA_2 = {
"type": "bool",
"hint": "启用后,企业微信智能机器人的所有回复都改为通过消息推送 Webhook 发送。消息推送 Webhook 支持更多的消息类型(如图片、文件等)。",
},
"wecomaibot_ws_bot_id": {
"description": "长连接 BotID",
"type": "string",
"hint": "企业微信智能机器人长连接模式凭证 BotID。",
"condition": {
"wecom_ai_bot_connection_mode": "long_connection",
},
},
"wecomaibot_ws_secret": {
"description": "长连接 Secret",
"type": "string",
"hint": "企业微信智能机器人长连接模式凭证 Secret。",
"condition": {
"wecom_ai_bot_connection_mode": "long_connection",
},
},
"wecomaibot_ws_url": {
"description": "长连接 WebSocket 地址",
"type": "string",
"invisible": True,
"hint": "默认值为 wss://openws.work.weixin.qq.com,一般无需修改。",
"condition": {
"wecom_ai_bot_connection_mode": "long_connection",
},
},
"wecomaibot_heartbeat_interval": {
"description": "长连接心跳间隔",
"type": "int",
"invisible": True,
"hint": "长连接模式心跳间隔(秒),建议 30 秒。",
"condition": {
"wecom_ai_bot_connection_mode": "long_connection",
},
},
"lark_bot_name": {
"description": "飞书机器人的名字",
"type": "string",
@@ -796,7 +859,7 @@ CONFIG_METADATA_2 = {
"unified_webhook_mode": {
"description": "统一 Webhook 模式",
"type": "bool",
"hint": "启用后,将使用 AstrBot 统一 Webhook 入口,无需单独开启端口。回调地址为 /api/platform/webhook/{webhook_uuid}",
"hint": "Webhook 模式下使用 AstrBot 统一 Webhook 入口,无需单独开启端口。回调地址为 /api/platform/webhook/{webhook_uuid}",
},
"webhook_uuid": {
"invisible": True,
@@ -1123,7 +1186,7 @@ CONFIG_METADATA_2 = {
"enable": True,
"key": [],
"timeout": 120,
"api_base": "https://openrouter.ai/v1",
"api_base": "https://openrouter.ai/api/v1",
"proxy": "",
"custom_headers": {},
},
+2 -6
View File
@@ -97,11 +97,7 @@ class AstrBotCoreLifecycle:
except Exception as e:
logger.error(f"Subagent orchestrator init failed: {e}", exc_info=True)
async def initialize(
self,
*,
mcp_init_timeout: float | int | str | None = None,
) -> None:
async def initialize(self) -> None:
"""初始化 AstrBot 核心生命周期管理类.
负责初始化各个组件, 包括 ProviderManager、PlatformManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
@@ -205,7 +201,7 @@ class AstrBotCoreLifecycle:
await self.plugin_manager.reload()
# 根据配置实例化各个 Provider
await self.provider_manager.initialize(init_timeout=mcp_init_timeout)
await self.provider_manager.initialize()
await self.kb_manager.initialize()
+44 -18
View File
@@ -539,13 +539,36 @@ class Reply(BaseMessageComponent):
class Poke(BaseMessageComponent):
type: str = ComponentType.Poke
id: int | None = 0
qq: int | None = 0
type: ComponentType = ComponentType.Poke
_type: str | int = "126"
id: int | str | None = 0
qq: int | str | None = 0 # deprecated: legacy field, kept for compatibility
def __init__(self, type: str, **_) -> None:
type = f"Poke:{type}"
super().__init__(type=type, **_)
def __init__(self, poke_type: str | int | None = None, **_) -> None:
# Backward compatible with old signature: Poke(type="poke", ...)
legacy_type = _.pop("type", None)
if poke_type is None:
poke_type = legacy_type
if poke_type in (None, "", "poke", "Poke"):
poke_type = "126"
super().__init__(_type=str(poke_type), **_)
def target_id(self) -> str | None:
"""Return normalized target id, compatible with old `qq` field."""
for value in (self.id, self.qq):
if value is None:
continue
text = str(value).strip()
if text and text != "0":
return text
return None
def toDict(self):
target_id = self.target_id()
data = {"type": str(self._type or "126")}
if target_id:
data["id"] = target_id
return {"type": "poke", "data": data}
class Forward(BaseMessageComponent):
@@ -676,21 +699,24 @@ class File(BaseMessageComponent):
if self.url:
try:
loop = asyncio.get_event_loop()
if loop.is_running():
logger.warning(
"不可以在异步上下文中同步等待下载! "
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
"请使用 await get_file() 代替直接获取 <File>.file 字段",
)
return ""
# 等待下载完成
loop.run_until_complete(self._download_file())
# 检查是否有正在运行的 event loop
asyncio.get_running_loop()
logger.warning(
"不可以在异步上下文中同步等待下载! "
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
"请使用 await get_file() 代替直接获取 <File>.file 字段",
)
return ""
except RuntimeError:
# 没有运行中的 event loop,可以同步执行
try:
# 使用 asyncio.run 安全地创建和关闭事件循环
asyncio.run(self._download_file())
except Exception:
logger.exception("文件下载失败")
if self.file_ and os.path.exists(self.file_):
return os.path.abspath(self.file_)
except Exception as e:
logger.error(f"文件下载失败: {e}")
return ""
+1 -1
View File
@@ -28,7 +28,7 @@ class RespondStage(Stage):
Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), # @
Comp.Image: lambda comp: bool(comp.file), # 图片
Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复
Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳
Comp.Poke: lambda comp: comp.target_id() is not None, # 戳一戳
Comp.Node: lambda comp: bool(comp.content), # 转发节点
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
Comp.File: lambda comp: bool(comp.file_ or comp.url),
@@ -5,7 +5,7 @@ import traceback
from collections.abc import AsyncGenerator
from astrbot.core import file_token_service, html_renderer, logger
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
from astrbot.core.message.components import At, Image, Node, Plain, Record, Reply
from astrbot.core.message.message_event_result import ResultContentType
from astrbot.core.pipeline.content_safety_check.stage import ContentSafetyCheckStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
@@ -383,8 +383,11 @@ class ResultDecorateStage(Stage):
)
result.chain = [node]
has_plain = any(isinstance(item, Plain) for item in result.chain)
if has_plain:
# at 回复 / 引用回复仅适用于纯文本或图文消息
can_decorate = all(
isinstance(item, (Plain, Image)) for item in result.chain
)
if can_decorate:
# at 回复
if (
self.reply_with_mention
@@ -399,5 +402,4 @@ class ResultDecorateStage(Stage):
# 引用回复
if self.reply_with_quote:
if not any(isinstance(item, File) for item in result.chain):
result.chain.insert(0, Reply(id=event.message_obj.message_id))
result.chain.insert(0, Reply(id=event.message_obj.message_id))
@@ -191,7 +191,7 @@ class AiocqhttpAdapter(Platform):
if "sub_type" in event:
if event["sub_type"] == "poke" and "target_id" in event:
abm.message.append(Poke(qq=str(event["target_id"]), type="poke"))
abm.message.append(Poke(id=str(event["target_id"])))
return abm
@@ -11,7 +11,7 @@ from dingtalk_stream import AckMessage
from astrbot import logger
from astrbot.api.event import MessageChain
from astrbot.api.message_components import At, Image, Plain, Record, Video
from astrbot.api.message_components import At, File, Image, Plain, Record, Video
from astrbot.api.platform import (
AstrBotMessage,
MessageMember,
@@ -178,29 +178,110 @@ class DingtalkPlatformAdapter(Platform):
abm.session_id = abm.sender.user_id
message_type: str = cast(str, message.message_type)
robot_code = cast(str, message.robot_code or "")
raw_content = cast(dict, message.extensions.get("content") or {})
if not isinstance(raw_content, dict):
raw_content = {}
match message_type:
case "text":
abm.message_str = message.text.content.strip()
abm.message.append(Plain(abm.message_str))
case "picture":
if not robot_code:
logger.error("钉钉图片消息解析失败: 回调中缺少 robotCode")
await self._remember_sender_binding(message, abm)
return abm
image_content = cast(
dingtalk_stream.ImageContent | None,
message.image_content,
)
download_code = cast(
str, (image_content.download_code if image_content else "") or ""
)
if not download_code:
logger.warning("钉钉图片消息缺少 downloadCode,已跳过")
else:
f_path = await self.download_ding_file(
download_code,
robot_code,
"jpg",
)
if f_path:
abm.message.append(Image.fromFileSystem(f_path))
else:
logger.warning("钉钉图片消息下载失败,无法解析为图片")
case "richText":
rtc: dingtalk_stream.RichTextContent = cast(
dingtalk_stream.RichTextContent, message.rich_text_content
)
contents: list[dict] = cast(list[dict], rtc.rich_text_list)
plain_parts: list[str] = []
for content in contents:
plains = ""
if "text" in content:
plains += content["text"]
abm.message.append(Plain(plains))
plain_text = cast(str, content.get("text") or "")
if plain_text:
plain_parts.append(plain_text)
abm.message.append(Plain(plain_text))
elif "type" in content and content["type"] == "picture":
download_code = cast(str, content.get("downloadCode") or "")
if not download_code:
logger.warning(
"钉钉富文本图片消息缺少 downloadCode,已跳过"
)
continue
if not robot_code:
logger.error(
"钉钉富文本图片消息解析失败: 回调中缺少 robotCode"
)
continue
f_path = await self.download_ding_file(
content["downloadCode"],
cast(str, message.robot_code),
download_code,
robot_code,
"jpg",
)
abm.message.append(Image.fromFileSystem(f_path))
case "audio":
pass
if f_path:
abm.message.append(Image.fromFileSystem(f_path))
abm.message_str = "".join(plain_parts).strip()
case "audio" | "voice":
download_code = cast(str, raw_content.get("downloadCode") or "")
if not download_code:
logger.warning("钉钉语音消息缺少 downloadCode,已跳过")
elif not robot_code:
logger.error("钉钉语音消息解析失败: 回调中缺少 robotCode")
else:
voice_ext = cast(str, raw_content.get("fileExtension") or "")
if not voice_ext:
voice_ext = "amr"
voice_ext = voice_ext.lstrip(".")
f_path = await self.download_ding_file(
download_code,
robot_code,
voice_ext,
)
if f_path:
abm.message.append(Record.fromFileSystem(f_path))
case "file":
download_code = cast(str, raw_content.get("downloadCode") or "")
if not download_code:
logger.warning("钉钉文件消息缺少 downloadCode,已跳过")
elif not robot_code:
logger.error("钉钉文件消息解析失败: 回调中缺少 robotCode")
else:
file_name = cast(str, raw_content.get("fileName") or "")
file_ext = Path(file_name).suffix.lstrip(".") if file_name else ""
if not file_ext:
file_ext = cast(str, raw_content.get("fileExtension") or "")
if not file_ext:
file_ext = "file"
f_path = await self.download_ding_file(
download_code,
robot_code,
file_ext,
)
if f_path:
if not file_name:
file_name = Path(f_path).name
abm.message.append(File(name=file_name, file=f_path))
await self._remember_sender_binding(message, abm)
return abm # 别忘了返回转换后的消息对象
@@ -270,13 +351,23 @@ class DingtalkPlatformAdapter(Platform):
)
return ""
resp_data = await resp.json()
download_url = resp_data["data"]["downloadUrl"]
download_url = cast(
str,
(
resp_data.get("downloadUrl")
or resp_data.get("data", {}).get("downloadUrl")
or ""
),
)
if not download_url:
logger.error(f"下载钉钉文件失败: 未找到 downloadUrl, 响应: {resp_data}")
return ""
await download_file(download_url, str(f_path))
return str(f_path)
async def get_access_token(self) -> str:
try:
access_token = await asyncio.get_event_loop().run_in_executor(
access_token = await asyncio.get_running_loop().run_in_executor(
None,
self.client_.get_access_token,
)
@@ -541,6 +632,28 @@ class DingtalkPlatformAdapter(Platform):
self._safe_remove_file(cover_path)
if converted_video:
self._safe_remove_file(video_path)
elif isinstance(segment, File):
try:
file_path = await segment.get_file()
if not file_path:
logger.warning("钉钉文件发送失败: 无法解析文件路径")
continue
media_id = await self.upload_media(file_path, "file")
if not media_id:
continue
file_name = segment.name or Path(file_path).name
file_type = Path(file_name).suffix.lstrip(".")
await send_message(
msg_key="sampleFile",
msg_param={
"mediaId": media_id,
"fileName": file_name,
"fileType": file_type,
},
)
except Exception as e:
logger.warning(f"钉钉文件发送失败: {e}")
continue
async def send_message_chain_to_group(
self,
@@ -647,7 +760,7 @@ class DingtalkPlatformAdapter(Platform):
return
logger.error(f"钉钉机器人启动失败: {e}")
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, start_client, loop)
async def terminate(self) -> None:
@@ -80,7 +80,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
if isinstance(source, botpy.message.C2CMessage):
# 真流式传输
current_time = asyncio.get_event_loop().time()
current_time = asyncio.get_running_loop().time()
time_since_last_edit = current_time - last_edit_time
if time_since_last_edit >= throttle_interval:
@@ -90,7 +90,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
)
stream_payload["index"] += 1
stream_payload["id"] = ret["id"]
last_edit_time = asyncio.get_event_loop().time()
last_edit_time = asyncio.get_running_loop().time()
if isinstance(source, botpy.message.C2CMessage):
# 结束流式对话,并且传输 buffer 中剩余的消息
@@ -55,7 +55,7 @@ class QQOfficialWebhook:
max_async=1,
connect=bot_connect,
dispatch=self.client.ws_dispatch,
loop=asyncio.get_event_loop(),
loop=asyncio.get_running_loop(),
api=self.api,
)
@@ -626,7 +626,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
# 发送初始 typing 状态
await self._ensure_typing(user_name, message_thread_id)
last_chat_action_time = asyncio.get_event_loop().time()
last_chat_action_time = asyncio.get_running_loop().time()
def _append_text(t: str) -> None:
nonlocal delta
@@ -657,11 +657,11 @@ class TelegramPlatformEvent(AstrMessageEvent):
# 编辑或发送消息
if message_id and len(delta) <= self.MAX_MESSAGE_LENGTH:
current_time = asyncio.get_event_loop().time()
current_time = asyncio.get_running_loop().time()
time_since_last_edit = current_time - last_edit_time
if time_since_last_edit >= throttle_interval:
current_time = asyncio.get_event_loop().time()
current_time = asyncio.get_running_loop().time()
if current_time - last_chat_action_time >= chat_action_interval:
await self._ensure_typing(user_name, message_thread_id)
last_chat_action_time = current_time
@@ -674,9 +674,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
current_content = delta
except Exception as e:
logger.warning(f"编辑消息失败(streaming): {e!s}")
last_edit_time = asyncio.get_event_loop().time()
last_edit_time = asyncio.get_running_loop().time()
else:
current_time = asyncio.get_event_loop().time()
current_time = asyncio.get_running_loop().time()
if current_time - last_chat_action_time >= chat_action_interval:
await self._ensure_typing(user_name, message_thread_id)
last_chat_action_time = current_time
@@ -688,7 +688,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
except Exception as e:
logger.warning(f"发送消息失败(streaming): {e!s}")
message_id = msg.message_id
last_edit_time = asyncio.get_event_loop().time()
last_edit_time = asyncio.get_running_loop().time()
try:
if delta and current_content != delta:
@@ -200,7 +200,7 @@ class WecomPlatformAdapter(Platform):
return msg_list[-1]
return None
msg_new = await asyncio.get_event_loop().run_in_executor(
msg_new = await asyncio.get_running_loop().run_in_executor(
None,
get_latest_msg_item,
)
@@ -261,7 +261,7 @@ class WecomPlatformAdapter(Platform):
@override
async def run(self) -> None:
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
if self.kf_name:
try:
acc_list = (
@@ -339,7 +339,7 @@ class WecomPlatformAdapter(Platform):
abm.session_id = abm.sender.user_id
abm.raw_message = msg
elif isinstance(msg, VoiceMessage):
resp: Response = await asyncio.get_event_loop().run_in_executor(
resp: Response = await asyncio.get_running_loop().run_in_executor(
None,
self.client.media.download,
msg.media_id,
@@ -395,7 +395,7 @@ class WecomPlatformAdapter(Platform):
abm.message_str = text
elif msgtype == "image":
media_id = msg.get("image", {}).get("media_id", "")
resp: Response = await asyncio.get_event_loop().run_in_executor(
resp: Response = await asyncio.get_running_loop().run_in_executor(
None,
self.client.media.download,
media_id,
@@ -407,7 +407,7 @@ class WecomPlatformAdapter(Platform):
abm.message = [Image(file=path, url=path)]
elif msgtype == "voice":
media_id = msg.get("voice", {}).get("media_id", "")
resp: Response = await asyncio.get_event_loop().run_in_executor(
resp: Response = await asyncio.get_running_loop().run_in_executor(
None,
self.client.media.download,
media_id,
@@ -1,5 +1,5 @@
"""企业微信智能机器人平台适配器
基于企业微信智能机器人 API 的消息平台适配器支持 HTTP 回调
基于企业微信智能机器人 API 的消息平台适配器支持 HTTP 回调与长连接
参考webchat_adapter.py的队列机制实现异步消息处理和流式响应
"""
@@ -31,6 +31,7 @@ from .wecomai_api import (
WecomAIBotStreamMessageBuilder,
)
from .wecomai_event import WecomAIBotMessageEvent
from .wecomai_long_connection import WecomAIBotLongConnectionClient
from .wecomai_queue_mgr import WecomAIQueueMgr
from .wecomai_server import WecomAIBotServer
from .wecomai_utils import (
@@ -78,8 +79,13 @@ class WecomAIBotAdapter(Platform):
self.settings = platform_settings
# 初始化配置参数
self.token = self.config["token"]
self.encoding_aes_key = self.config["encoding_aes_key"]
self.connection_mode = self.config.get(
"wecom_ai_bot_connection_mode", "webhook"
)
self.token = self.config.get("token", self.config.get("wecomaibot_token", ""))
self.encoding_aes_key = self.config.get(
"encoding_aes_key", self.config.get("wecomaibot_encoding_aes_key", "")
)
self.port = int(self.config["port"])
self.host = self.config.get("callback_server_host", "0.0.0.0")
self.bot_name = self.config.get("wecom_ai_bot_name", "")
@@ -96,25 +102,52 @@ class WecomAIBotAdapter(Platform):
self.only_use_webhook_url_to_send = bool(
self.config.get("only_use_webhook_url_to_send", False),
)
self.long_connection_bot_id = self.config.get(
"wecomaibot_ws_bot_id", self.config.get("long_connection_bot_id", "")
)
self.long_connection_secret = self.config.get(
"wecomaibot_ws_secret", self.config.get("long_connection_secret", "")
)
self.long_connection_ws_url = self.config.get(
"wecomaibot_ws_url",
"wss://openws.work.weixin.qq.com",
)
self.long_connection_heartbeat_interval = int(
self.config.get("wecomaibot_heartbeat_interval", 30),
)
# 平台元数据
self.metadata = PlatformMetadata(
name="wecom_ai_bot",
description="企业微信智能机器人适配器,支持 HTTP 回调接收消息",
description="企业微信智能机器人适配器,支持 HTTP 回调和长连接模式",
id=self.config.get("id", "wecom_ai_bot"),
support_proactive_message=bool(self.msg_push_webhook_url),
)
# 初始化 API 客户端
self.api_client = WecomAIBotAPIClient(self.token, self.encoding_aes_key)
self.api_client: WecomAIBotAPIClient | None = None
self.server: WecomAIBotServer | None = None
self.long_connection_client: WecomAIBotLongConnectionClient | None = None
# 初始化 HTTP 服务器
self.server = WecomAIBotServer(
host=self.host,
port=self.port,
api_client=self.api_client,
message_handler=self._process_message,
)
if self.connection_mode == "long_connection":
if not self.long_connection_bot_id or not self.long_connection_secret:
logger.warning(
"企业微信智能机器人长连接模式缺少 BotID 或 Secret,连接可能失败"
)
self.long_connection_client = WecomAIBotLongConnectionClient(
bot_id=self.long_connection_bot_id,
secret=self.long_connection_secret,
ws_url=self.long_connection_ws_url,
heartbeat_interval=self.long_connection_heartbeat_interval,
message_handler=self._process_long_connection_payload,
)
else:
self.api_client = WecomAIBotAPIClient(self.token, self.encoding_aes_key)
self.server = WecomAIBotServer(
host=self.host,
port=self.port,
api_client=self.api_client,
message_handler=self._process_message,
)
# 事件循环和关闭信号
self.shutdown_event = asyncio.Event()
@@ -161,6 +194,9 @@ class WecomAIBotAdapter(Platform):
加密后的响应消息无需响应时返回 None
"""
if not self.api_client:
logger.error("Webhook 消息处理失败: API 客户端未初始化")
return None
msgtype = message_data.get("msgtype")
if not msgtype:
logger.warning(f"消息类型未知,忽略: {message_data}")
@@ -320,6 +356,89 @@ class WecomAIBotAdapter(Platform):
logger.error("处理欢迎消息时发生异常: %s", e)
return None
async def _process_long_connection_payload(
self,
payload: dict[str, Any],
) -> None:
"""处理长连接回调消息。"""
cmd = payload.get("cmd")
headers = payload.get("headers") or {}
body = payload.get("body") or {}
req_id = headers.get("req_id")
if not isinstance(body, dict):
return
if cmd == "aibot_msg_callback":
session_id = self._extract_session_id(body)
stream_id = f"{session_id}_{generate_random_string(10)}"
await self._enqueue_message(
body, {"req_id": req_id or ""}, stream_id, session_id
)
self.queue_mgr.set_pending_response(
stream_id,
{
"req_id": req_id or "",
"connection_mode": "long_connection",
},
)
if self.initial_respond_text and req_id:
await self._send_long_connection_respond_msg(
req_id=req_id,
body={
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": False,
"content": self.initial_respond_text,
},
},
)
return
if cmd == "aibot_event_callback":
event = body.get("event") or {}
event_type = event.get("eventtype")
if (
event_type == "enter_chat"
and self.friend_message_welcome_text
and req_id
):
await self._send_long_connection_respond_welcome(req_id)
elif event_type == "disconnected_event":
logger.warning(
"[WecomAI][LongConn] 收到 disconnected_event,旧连接将被关闭"
)
async def _send_long_connection_respond_welcome(self, req_id: str) -> bool:
client = self.long_connection_client
if not client:
return False
return await client.send_command(
cmd="aibot_respond_welcome_msg",
req_id=req_id,
body={
"msgtype": "text",
"text": {
"content": self.friend_message_welcome_text,
},
},
)
async def _send_long_connection_respond_msg(
self,
req_id: str,
body: dict[str, Any],
) -> bool:
client = self.long_connection_client
if not client:
return False
return await client.send_command(
cmd="aibot_respond_msg",
req_id=req_id,
body=body,
)
def _extract_session_id(self, message_data: dict[str, Any]) -> str:
"""从消息数据中提取会话ID"""
user_id = message_data.get("from", {}).get("userid", "default_user")
@@ -355,15 +474,16 @@ class WecomAIBotAdapter(Platform):
content = ""
image_base64 = []
_img_url_to_process = []
_img_url_to_process: list[tuple[str, str | None]] = []
msg_items = []
if msgtype == WecomAIBotConstants.MSG_TYPE_TEXT:
content = WecomAIBotMessageParser.parse_text_message(message_data)
elif msgtype == WecomAIBotConstants.MSG_TYPE_IMAGE:
_img_url_to_process.append(
WecomAIBotMessageParser.parse_image_message(message_data),
)
image_payload = message_data.get("image", {})
image_url = image_payload.get("url", "")
if image_url:
_img_url_to_process.append((image_url, image_payload.get("aeskey")))
elif msgtype == WecomAIBotConstants.MSG_TYPE_MIXED:
# 提取混合消息中的文本内容
msg_items = WecomAIBotMessageParser.parse_mixed_message(message_data)
@@ -374,9 +494,12 @@ class WecomAIBotAdapter(Platform):
if text_content:
text_parts.append(text_content)
elif item.get("msgtype") == WecomAIBotConstants.MSG_TYPE_IMAGE:
image_url = item.get("image", {}).get("url", "")
image_payload = item.get("image", {})
image_url = image_payload.get("url", "")
if image_url:
_img_url_to_process.append(image_url)
_img_url_to_process.append(
(image_url, image_payload.get("aeskey"))
)
content = " ".join(text_parts) if text_parts else ""
else:
content = f"[{msgtype}消息]"
@@ -384,8 +507,8 @@ class WecomAIBotAdapter(Platform):
# 并行处理图片下载和解密
if _img_url_to_process:
tasks = [
process_encrypted_image(url, self.encoding_aes_key)
for url in _img_url_to_process
process_encrypted_image(url, aes_key or self.encoding_aes_key)
for url, aes_key in _img_url_to_process
]
results = await asyncio.gather(*tasks)
for success, result in results:
@@ -459,26 +582,43 @@ class WecomAIBotAdapter(Platform):
"""运行适配器,同时启动HTTP服务器和队列监听器"""
async def run_both() -> None:
# 如果启用统一 webhook 模式,则不启动独立服务器
webhook_uuid = self.config.get("webhook_uuid")
if self.unified_webhook_mode and webhook_uuid:
log_webhook_info(f"{self.meta().id}(企业微信智能机器人)", webhook_uuid)
# 只运行队列监听器
await self.queue_listener.run()
else:
if self.connection_mode == "long_connection":
if not self.long_connection_client:
raise RuntimeError("长连接客户端未初始化")
logger.info(
"启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port
"启动企业微信智能机器人长连接模式: %s", self.long_connection_ws_url
)
# 同时运行HTTP服务器和队列监听器
await asyncio.gather(
self.server.start_server(),
self.long_connection_client.start(),
self.queue_listener.run(),
)
else:
# 如果启用统一 webhook 模式,则不启动独立服务器
webhook_uuid = self.config.get("webhook_uuid")
if self.unified_webhook_mode and webhook_uuid:
log_webhook_info(
f"{self.meta().id}(企业微信智能机器人)", webhook_uuid
)
# 只运行队列监听器
await self.queue_listener.run()
else:
if not self.server:
raise RuntimeError("Webhook 服务器未初始化")
logger.info(
"启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port
)
# 同时运行HTTP服务器和队列监听器
await asyncio.gather(
self.server.start_server(),
self.queue_listener.run(),
)
return run_both()
async def webhook_callback(self, request: Any) -> Any:
"""统一 Webhook 回调入口"""
if self.connection_mode == "long_connection" or not self.server:
return "long_connection mode does not accept webhook callbacks", 400
# 根据请求方法分发到不同的处理函数
if request.method == "GET":
return await self.server.handle_verify(request)
@@ -489,7 +629,10 @@ class WecomAIBotAdapter(Platform):
"""终止适配器"""
logger.info("企业微信智能机器人适配器正在关闭...")
self.shutdown_event.set()
await self.server.shutdown()
if self.long_connection_client:
await self.long_connection_client.shutdown()
if self.server:
await self.server.shutdown()
def meta(self) -> PlatformMetadata:
"""获取平台元数据"""
@@ -507,17 +650,22 @@ class WecomAIBotAdapter(Platform):
queue_mgr=self.queue_mgr,
webhook_client=self.webhook_client,
only_use_webhook_url_to_send=self.only_use_webhook_url_to_send,
long_connection_sender=self._send_long_connection_respond_msg,
)
message_event.is_at_or_wake_command = (
True # 企业微信智能机器人默认消息都是 at 或唤醒命令
)
message_event.is_wake = True # 企业微信智能机器人消息默认当做唤醒命令处理
self.commit_event(message_event)
except Exception as e:
logger.error("处理消息时发生异常: %s", e)
def get_client(self) -> WecomAIBotAPIClient:
def get_client(self) -> WecomAIBotAPIClient | None:
"""获取 API 客户端"""
return self.api_client
def get_server(self) -> WecomAIBotServer:
def get_server(self) -> WecomAIBotServer | None:
"""获取 HTTP 服务器实例"""
return self.server
@@ -1,5 +1,7 @@
"""企业微信智能机器人事件处理模块,处理消息事件的发送和接收"""
from collections.abc import Awaitable, Callable
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import At, Image, Plain
@@ -18,10 +20,11 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
message_obj,
platform_meta,
session_id: str,
api_client: WecomAIBotAPIClient,
api_client: WecomAIBotAPIClient | None,
queue_mgr: WecomAIQueueMgr,
webhook_client: WecomAIBotWebhookClient | None = None,
only_use_webhook_url_to_send: bool = False,
long_connection_sender: (Callable[[str, dict], Awaitable[bool]] | None) = None,
) -> None:
"""初始化消息事件
@@ -38,6 +41,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
self.queue_mgr = queue_mgr
self.webhook_client = webhook_client
self.only_use_webhook_url_to_send = only_use_webhook_url_to_send
self.long_connection_sender = long_connection_sender
async def _mark_stream_complete(self, stream_id: str) -> None:
back_queue = self.queue_mgr.get_or_create_back_queue(stream_id)
@@ -117,6 +121,18 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
return data
@staticmethod
def _extract_plain_text_from_chain(message_chain: MessageChain | None) -> str:
if not message_chain:
return ""
plain_parts: list[str] = []
for comp in message_chain.chain:
if isinstance(comp, At):
plain_parts.append(f"@{comp.name} ")
elif isinstance(comp, Plain):
plain_parts.append(comp.text)
return "".join(plain_parts).strip()
async def send(self, message: MessageChain | None) -> None:
"""发送消息"""
raw = self.message_obj.raw_message
@@ -124,6 +140,44 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
"wecom_ai_bot platform event raw_message should be a dict"
)
stream_id = raw.get("stream_id", self.session_id)
pending_response = self.queue_mgr.get_pending_response(stream_id) or {}
connection_mode = pending_response.get("callback_params", {}).get(
"connection_mode"
)
req_id = pending_response.get("callback_params", {}).get("req_id")
if (
connection_mode == "long_connection"
and self.long_connection_sender
and isinstance(req_id, str)
and req_id
):
if self.only_use_webhook_url_to_send and self.webhook_client and message:
await self.webhook_client.send_message_chain(message)
await super().send(MessageChain([]))
return
if self.webhook_client and message:
await self.webhook_client.send_message_chain(
message,
unsupported_only=True,
)
content = self._extract_plain_text_from_chain(message)
await self.long_connection_sender(
req_id,
{
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": True,
"content": content,
},
},
)
await super().send(MessageChain([]))
return
if self.only_use_webhook_url_to_send and self.webhook_client and message:
await self.webhook_client.send_message_chain(message)
await self._mark_stream_complete(stream_id)
@@ -152,8 +206,77 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
"wecom_ai_bot platform event raw_message should be a dict"
)
stream_id = raw.get("stream_id", self.session_id)
pending_response = self.queue_mgr.get_pending_response(stream_id) or {}
connection_mode = pending_response.get("callback_params", {}).get(
"connection_mode"
)
req_id = pending_response.get("callback_params", {}).get("req_id")
back_queue = self.queue_mgr.get_or_create_back_queue(stream_id)
if (
connection_mode == "long_connection"
and self.long_connection_sender
and isinstance(req_id, str)
and req_id
):
if self.only_use_webhook_url_to_send and self.webhook_client:
merged_chain = MessageChain([])
async for chain in generator:
merged_chain.chain.extend(chain.chain)
merged_chain.squash_plain()
await self.webhook_client.send_message_chain(merged_chain)
await self.long_connection_sender(
req_id,
{
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": True,
"content": "",
},
},
)
await super().send_streaming(generator, use_fallback)
return
increment_plain = ""
async for chain in generator:
if self.webhook_client:
await self.webhook_client.send_message_chain(
chain,
unsupported_only=True,
)
chain.squash_plain()
chunk_text = self._extract_plain_text_from_chain(chain)
if chunk_text:
increment_plain += chunk_text
await self.long_connection_sender(
req_id,
{
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": False,
"content": increment_plain,
},
},
)
await self.long_connection_sender(
req_id,
{
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": True,
"content": increment_plain,
},
},
)
await super().send_streaming(generator, use_fallback)
return
if self.only_use_webhook_url_to_send and self.webhook_client:
merged_chain = MessageChain([])
async for chain in generator:
@@ -0,0 +1,236 @@
"""企业微信智能机器人长连接客户端。"""
import asyncio
import json
import uuid
from collections.abc import Awaitable, Callable
from typing import Any
import aiohttp
from astrbot.api import logger
class WecomAIBotLongConnectionClient:
"""企业微信智能机器人 WebSocket 长连接客户端。"""
def __init__(
self,
bot_id: str,
secret: str,
ws_url: str,
heartbeat_interval: int,
message_handler: Callable[[dict[str, Any]], Awaitable[None]],
) -> None:
self.bot_id = bot_id
self.secret = secret
self.ws_url = ws_url
self.heartbeat_interval = max(5, int(heartbeat_interval))
self.message_handler = message_handler
self._session: aiohttp.ClientSession | None = None
self._ws: aiohttp.ClientWebSocketResponse | None = None
self._shutdown_event = asyncio.Event()
self._send_lock = asyncio.Lock()
self._command_lock = asyncio.Lock()
self._response_waiters: dict[str, asyncio.Future[dict[str, Any]]] = {}
@staticmethod
def gen_req_id() -> str:
return uuid.uuid4().hex
async def start(self) -> None:
"""启动长连接并自动重连。"""
reconnect_delay = 1
while not self._shutdown_event.is_set():
try:
await self._run_once()
reconnect_delay = 1
except asyncio.CancelledError:
raise
except Exception as e:
logger.error("[WecomAI][LongConn] 长连接异常: %s", e)
if self._shutdown_event.is_set():
break
await asyncio.sleep(reconnect_delay)
reconnect_delay = min(reconnect_delay * 2, 30)
async def _run_once(self) -> None:
timeout = aiohttp.ClientTimeout(total=None, sock_connect=15, sock_read=None)
async with aiohttp.ClientSession(timeout=timeout) as session:
self._session = session
logger.info("[WecomAI][LongConn] 正在连接: %s", self.ws_url)
async with session.ws_connect(
self.ws_url, heartbeat=None, autoping=True
) as ws:
self._ws = ws
await self._subscribe()
logger.info("[WecomAI][LongConn] 订阅成功,已建立长连接")
heartbeat_task = asyncio.create_task(self._heartbeat_loop())
try:
while not self._shutdown_event.is_set():
message = await ws.receive()
if message.type == aiohttp.WSMsgType.TEXT:
await self._handle_text_message(message.data)
elif message.type in {
aiohttp.WSMsgType.CLOSED,
aiohttp.WSMsgType.CLOSE,
aiohttp.WSMsgType.ERROR,
}:
break
finally:
heartbeat_task.cancel()
try:
await heartbeat_task
except asyncio.CancelledError:
pass
self._ws = None
async def _subscribe(self) -> None:
"""发送 aibot_subscribe,并等待响应。"""
req_id = self.gen_req_id()
payload = {
"cmd": "aibot_subscribe",
"headers": {"req_id": req_id},
"body": {"bot_id": self.bot_id, "secret": self.secret},
}
await self._send_json(payload)
if not self._ws:
raise RuntimeError("WebSocket 未建立")
reply = await self._ws.receive(timeout=10)
if reply.type != aiohttp.WSMsgType.TEXT:
raise RuntimeError(f"订阅失败: 非文本响应 {reply.type}")
data = json.loads(reply.data)
if data.get("errcode") != 0:
raise RuntimeError(
f"订阅失败 errcode={data.get('errcode')} errmsg={data.get('errmsg')}"
)
async def _heartbeat_loop(self) -> None:
while not self._shutdown_event.is_set():
await asyncio.sleep(self.heartbeat_interval)
if self._shutdown_event.is_set():
break
try:
await self.send_command("ping", self.gen_req_id(), None)
except Exception as e:
logger.warning("[WecomAI][LongConn] 发送心跳失败: %s", e)
return
async def _handle_text_message(self, text: str) -> None:
try:
payload = json.loads(text)
except json.JSONDecodeError:
logger.warning("[WecomAI][LongConn] 收到非 JSON 消息: %s", text)
return
headers = payload.get("headers") or {}
req_id = headers.get("req_id")
if isinstance(req_id, str):
waiter = self._response_waiters.get(req_id)
if waiter and not waiter.done():
waiter.set_result(payload)
return
cmd = payload.get("cmd")
if cmd in {"aibot_msg_callback", "aibot_event_callback"}:
await self.message_handler(payload)
return
if payload.get("errcode") not in (None, 0):
logger.warning(
"[WecomAI][LongConn] 服务端返回错误: errcode=%s errmsg=%s",
payload.get("errcode"),
payload.get("errmsg"),
)
async def send_command(
self,
cmd: str,
req_id: str,
body: dict[str, Any] | None,
) -> bool:
"""发送长连接命令。"""
headers = {"req_id": req_id}
payload: dict[str, Any] = {"cmd": cmd, "headers": headers}
if body is not None:
payload["body"] = body
async with self._command_lock:
max_retries = 3
for attempt in range(max_retries + 1):
response = await self._send_and_wait_response(req_id, payload)
if not response:
if attempt < max_retries:
await asyncio.sleep(min(0.2 * (2**attempt), 2.0))
continue
return False
errcode = response.get("errcode")
if errcode in (0, None):
return True
if errcode == 6000 and attempt < max_retries:
backoff = min(0.2 * (2**attempt), 2.0)
logger.warning(
"[WecomAI][LongConn] 命令冲突(errcode=6000),将重试。cmd=%s req_id=%s attempt=%d",
cmd,
req_id,
attempt + 1,
)
await asyncio.sleep(backoff)
continue
logger.warning(
"[WecomAI][LongConn] 命令失败: cmd=%s req_id=%s errcode=%s errmsg=%s",
cmd,
req_id,
errcode,
response.get("errmsg"),
)
return False
return False
async def _send_and_wait_response(
self,
req_id: str,
payload: dict[str, Any],
timeout: float = 10.0,
) -> dict[str, Any] | None:
loop = asyncio.get_running_loop()
waiter: asyncio.Future[dict[str, Any]] = loop.create_future()
self._response_waiters[req_id] = waiter
try:
await self._send_json(payload)
return await asyncio.wait_for(waiter, timeout=timeout)
except TimeoutError:
logger.warning(
"[WecomAI][LongConn] 等待命令响应超时: cmd=%s req_id=%s",
payload.get("cmd"),
req_id,
)
return None
finally:
self._response_waiters.pop(req_id, None)
async def _send_json(self, payload: dict[str, Any]) -> None:
ws = self._ws
if ws is None or ws.closed:
raise RuntimeError("长连接尚未建立")
async with self._send_lock:
await ws.send_json(payload)
async def shutdown(self) -> None:
self._shutdown_event.set()
ws = self._ws
if ws is not None and not ws.closed:
await ws.close()
session = self._session
if session is not None and not session.closed:
await session.close()
@@ -4,6 +4,7 @@
"""
import asyncio
import time
from collections.abc import Awaitable, Callable
from typing import Any
@@ -82,7 +83,7 @@ class WecomAIQueueMgr:
del self.pending_responses[session_id]
logger.debug(f"[WecomAI] 移除待处理响应: {session_id}")
if mark_finished:
self.completed_streams[session_id] = asyncio.get_event_loop().time()
self.completed_streams[session_id] = time.monotonic()
logger.debug(f"[WecomAI] 标记流已结束: {session_id}")
def remove_queue(self, session_id: str):
@@ -135,7 +136,7 @@ class WecomAIQueueMgr:
"""
self.pending_responses[session_id] = {
"callback_params": callback_params,
"timestamp": asyncio.get_event_loop().time(),
"timestamp": time.monotonic(),
}
logger.debug(f"[WecomAI] 设置待处理响应: {session_id}")
@@ -160,7 +161,7 @@ class WecomAIQueueMgr:
finished_at = self.completed_streams.get(session_id)
if finished_at is None:
return False
if asyncio.get_event_loop().time() - finished_at > max_age_seconds:
if time.monotonic() - finished_at > max_age_seconds:
self.completed_streams.pop(session_id, None)
return False
return True
@@ -172,7 +173,7 @@ class WecomAIQueueMgr:
max_age_seconds: 最大存活时间
"""
current_time = asyncio.get_event_loop().time()
current_time = time.monotonic()
expired_sessions = []
for session_id, response_data in self.pending_responses.items():
@@ -369,7 +369,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
if future:
logger.debug(f"duplicate message id checked: {msg.id}")
else:
future = asyncio.get_event_loop().create_future()
future = asyncio.get_running_loop().create_future()
self.wexin_event_workers[msg_id] = future
await self.convert_message(msg, future)
# I love shield so much!
@@ -461,7 +461,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
elif msg.type == "voice":
assert isinstance(msg, VoiceMessage)
resp: Response = await asyncio.get_event_loop().run_in_executor(
resp: Response = await asyncio.get_running_loop().run_in_executor(
None,
self.client.media.download,
msg.media_id,
+4 -12
View File
@@ -346,10 +346,7 @@ class FunctionToolManager:
logger.debug(f" 主机: {scheme}://{host}{port}")
async def init_mcp_clients(
self,
raise_on_all_failed: bool = False,
*,
init_timeout: float | int | str | None = None,
self, raise_on_all_failed: bool = False
) -> MCPInitSummary:
"""从项目根目录读取 mcp_server.json 文件,初始化 MCP 服务列表。文件格式如下:
```
@@ -370,7 +367,6 @@ class FunctionToolManager:
```
Timeout behavior:
- 显式 `init_timeout` 参数优先用于测试或调用方覆盖
- 初始化超时使用环境变量 ASTRBOT_MCP_INIT_TIMEOUT 或默认值
- 动态启用超时使用 ASTRBOT_MCP_ENABLE_TIMEOUT独立于初始化超时
"""
@@ -387,12 +383,8 @@ class FunctionToolManager:
with open(mcp_json_file, encoding="utf-8") as f:
mcp_server_json_obj: dict[str, dict] = json.load(f)["mcpServers"]
init_timeout_value = _resolve_timeout(
timeout=init_timeout,
env_name=MCP_INIT_TIMEOUT_ENV,
default=self._init_timeout_default,
)
timeout_display = f"{init_timeout_value:g}"
init_timeout = self._init_timeout_default
timeout_display = f"{init_timeout:g}"
active_configs: list[tuple[str, dict, asyncio.Event]] = []
for name, cfg in mcp_server_json_obj.items():
@@ -411,7 +403,7 @@ class FunctionToolManager:
name=name,
cfg=cfg,
shutdown_event=shutdown_event,
timeout_seconds=init_timeout_value,
timeout=init_timeout,
),
name=f"mcp-init:{name}",
)
+2 -7
View File
@@ -269,11 +269,7 @@ class ProviderManager:
return provider
async def initialize(
self,
*,
init_timeout: float | int | str | None = None,
) -> None:
async def initialize(self) -> None:
# 逐个初始化提供商
for provider_config in self.providers_config:
try:
@@ -342,8 +338,7 @@ class ProviderManager:
"on",
}
mcp_init_summary = await self.llm_tools.init_mcp_clients(
raise_on_all_failed=strict_mcp_init,
init_timeout=init_timeout,
raise_on_all_failed=strict_mcp_init
)
if (
mcp_init_summary.total > 0
@@ -276,9 +276,24 @@ class ProviderAnthropic(Provider):
llm_response.id = completion.id
llm_response.usage = self._extract_usage(completion.usage)
# TODO(Soulter): 处理 end_turn 情况
# Handle cases where completion only contains ThinkingBlock (e.g., MiniMax max_tokens)
# When stop_reason='max_tokens', the model may return only thinking content
# This is valid and should not raise an exception
if not llm_response.completion_text and not llm_response.tools_call_args:
raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}")
# Guard clause: raise early if no valid content at all
if not llm_response.reasoning_content:
raise ValueError(
f"Anthropic API returned unparsable completion: "
f"no text, tool_use, or thinking content found. "
f"Completion: {completion}"
)
# We have reasoning content (ThinkingBlock) - this is valid
stop_reason = getattr(completion, "stop_reason", "unknown")
logger.debug(
f"Completion contains only ThinkingBlock (stop_reason={stop_reason})"
)
llm_response.completion_text = "" # Ensure empty string, not None
return llm_response
@@ -87,7 +87,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
model: str,
text: str,
) -> tuple[bytes | None, str]:
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
response = await loop.run_in_executor(None, self._call_qwen_tts, model, text)
audio_bytes = await self._extract_audio_from_response(response)
if not audio_bytes:
@@ -143,7 +143,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
voice=self.voice,
format=AudioFormat.WAV_24000HZ_MONO_16BIT,
)
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
audio_bytes = await loop.run_in_executor(
None,
synthesizer.call,
+2 -2
View File
@@ -59,7 +59,7 @@ class GenieTTSProvider(TTSProvider):
filename = f"genie_tts_{uuid.uuid4()}.wav"
path = os.path.join(temp_dir, filename)
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
def _generate(save_path: str) -> None:
assert genie is not None
@@ -85,7 +85,7 @@ class GenieTTSProvider(TTSProvider):
text_queue: asyncio.Queue[str | None],
audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]",
) -> None:
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
while True:
text = await text_queue.get()
@@ -43,7 +43,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
logger.info("下载或者加载 SenseVoice 模型中,这可能需要一些时间 ...")
# 将模型加载放到线程池中执行
self.model = await asyncio.get_event_loop().run_in_executor(
self.model = await asyncio.get_running_loop().run_in_executor(
None,
lambda: SenseVoiceSmall(self.model_name, quantize=True, batch_size=16),
)
@@ -88,7 +88,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
audio_url = output_path
# 使用 run_in_executor 来调用模型进行识别
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
res = await loop.run_in_executor(
None, # 使用默认的线程池
lambda: cast(SenseVoiceSmall, self.model)(
@@ -31,7 +31,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
self.model = None
async def initialize(self) -> None:
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...")
self.model = await loop.run_in_executor(
None,
@@ -50,7 +50,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
return False
async def get_text(self, audio_url: str) -> str:
loop = asyncio.get_event_loop()
loop = asyncio.get_running_loop()
is_tencent = False
+17 -2
View File
@@ -26,6 +26,13 @@ _SANDBOX_SKILLS_CACHE_VERSION = 1
_SKILL_NAME_RE = re.compile(r"^[A-Za-z0-9._-]+$")
def _is_ignored_zip_entry(name: str) -> bool:
parts = PurePosixPath(name).parts
if not parts:
return True
return parts[0] == "__MACOSX"
@dataclass
class SkillInfo:
name: str
@@ -401,7 +408,11 @@ class SkillManager:
raise ValueError("Uploaded file is not a valid zip archive.")
with zipfile.ZipFile(zip_path) as zf:
names = [name.replace("\\", "/") for name in zf.namelist()]
names = [
name
for name in (entry.replace("\\", "/") for entry in zf.namelist())
if name and not _is_ignored_zip_entry(name)
]
file_names = [name for name in names if name and not name.endswith("/")]
if not file_names:
raise ValueError("Zip archive is empty.")
@@ -436,7 +447,11 @@ class SkillManager:
raise ValueError("SKILL.md not found in the skill folder.")
with tempfile.TemporaryDirectory(dir=get_astrbot_temp_path()) as tmp_dir:
zf.extractall(tmp_dir)
for member in zf.infolist():
member_name = member.filename.replace("\\", "/")
if not member_name or _is_ignored_zip_entry(member_name):
continue
zf.extract(member, tmp_dir)
src_dir = Path(tmp_dir) / skill_name
if not src_dir.exists():
raise ValueError("Skill folder not found after extraction.")
+1 -1
View File
@@ -15,4 +15,4 @@ class RegexFilter(HandlerFilter):
self.regex = re.compile(regex)
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
return bool(self.regex.match(event.get_message_str().strip()))
return bool(self.regex.search(event.get_message_str().strip()))
+14 -1
View File
@@ -1374,10 +1374,23 @@ class PluginManager:
return
if "__del__" in star_metadata.star_cls_type.__dict__:
asyncio.get_event_loop().run_in_executor(
loop = asyncio.get_running_loop()
future = loop.run_in_executor(
None,
star_metadata.star_cls.__del__,
)
def _log_del_exception(fut: asyncio.Future) -> None:
if fut.cancelled():
return
if (exc := fut.exception()) is not None:
logger.error(
"插件 %s 在 __del__ 中抛出了异常:%r",
star_metadata.name,
exc,
)
future.add_done_callback(_log_del_exception)
elif "terminate" in star_metadata.star_cls_type.__dict__:
await star_metadata.star_cls.terminate()
+1 -1
View File
@@ -7,4 +7,4 @@ def is_frozen_runtime() -> bool:
def is_packaged_desktop_runtime() -> bool:
return is_frozen_runtime() and os.environ.get("ASTRBOT_DESKTOP_CLIENT") == "1"
return os.environ.get("ASTRBOT_DESKTOP_CLIENT") == "1"
+27 -1
View File
@@ -1,9 +1,13 @@
import asyncio
import threading
import weakref
from collections import defaultdict
from contextlib import asynccontextmanager
class SessionLockManager:
class _PerLoopSessionLockManager:
"""Per-event-loop session lock manager; keeps original simple semantics."""
def __init__(self) -> None:
self._locks: dict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
self._lock_count: dict[str, int] = defaultdict(int)
@@ -26,4 +30,26 @@ class SessionLockManager:
self._lock_count.pop(session_id, None)
class SessionLockManager:
"""Thread-safe session lock manager with per-event-loop isolation."""
def __init__(self) -> None:
self._state_guard = threading.Lock()
self._loop_managers: weakref.WeakKeyDictionary[
asyncio.AbstractEventLoop, _PerLoopSessionLockManager
] = weakref.WeakKeyDictionary()
def _get_loop_manager(self) -> _PerLoopSessionLockManager:
"""Get the lock manager for the current event loop."""
loop = asyncio.get_running_loop()
with self._state_guard:
return self._loop_managers.setdefault(loop, _PerLoopSessionLockManager())
@asynccontextmanager
async def acquire_lock(self, session_id: str):
manager = self._get_loop_manager()
async with manager.acquire_lock(session_id):
yield
session_lock_manager = SessionLockManager()
+2
View File
@@ -610,6 +610,7 @@ class ConfigRoute(Route):
try:
conf_id = self.acm.create_conf(name=name, config=config)
await self.core_lifecycle.reload_pipeline_scheduler(conf_id)
return Response().ok(message="创建成功", data={"conf_id": conf_id}).__dict__
except ValueError as e:
return Response().error(str(e)).__dict__
@@ -649,6 +650,7 @@ class ConfigRoute(Route):
try:
success = self.acm.delete_conf(conf_id)
if success:
self.core_lifecycle.pipeline_scheduler_mapping.pop(conf_id, None)
return Response().ok(message="删除成功").__dict__
return Response().error("删除失败").__dict__
except ValueError as e:
+110
View File
@@ -2,6 +2,7 @@ import os
import re
import shutil
import traceback
import uuid
from collections.abc import Awaitable, Callable
from pathlib import Path
from typing import Any
@@ -50,6 +51,7 @@ class SkillsRoute(Route):
self.routes = {
"/skills": ("GET", self.get_skills),
"/skills/upload": ("POST", self.upload_skill),
"/skills/batch-upload": ("POST", self.batch_upload_skills),
"/skills/download": ("GET", self.download_skill),
"/skills/update": ("POST", self.update_skill),
"/skills/delete": ("POST", self.delete_skill),
@@ -188,6 +190,114 @@ class SkillsRoute(Route):
except Exception:
logger.warning(f"Failed to remove temp skill file: {temp_path}")
async def batch_upload_skills(self):
"""批量上传多个 skill ZIP 文件"""
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
try:
files = await request.files
file_list = files.getlist("files")
if not file_list:
return Response().error("No files provided").__dict__
succeeded = []
failed = []
skill_mgr = SkillManager()
temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True)
for file in file_list:
filename = os.path.basename(file.filename or "unknown.zip")
temp_path = None
try:
if not filename.lower().endswith(".zip"):
failed.append(
{
"filename": filename,
"error": "Only .zip files are supported",
}
)
continue
temp_path = os.path.join(
temp_dir, f"batch_{uuid.uuid4().hex}_{filename}"
)
await file.save(temp_path)
skill_name = skill_mgr.install_skill_from_zip(
temp_path, overwrite=True
)
succeeded.append({"filename": filename, "name": skill_name})
except Exception as e:
failed.append({"filename": filename, "error": str(e)})
finally:
if temp_path and os.path.exists(temp_path):
try:
os.remove(temp_path)
except Exception:
pass
if succeeded:
try:
await sync_skills_to_active_sandboxes()
except Exception:
logger.warning(
"Failed to sync uploaded skills to active sandboxes."
)
total = len(file_list)
success_count = len(succeeded)
if success_count == total:
message = f"All {total} skill(s) uploaded successfully."
return (
Response()
.ok(
{
"total": total,
"succeeded": succeeded,
"failed": failed,
},
message,
)
.__dict__
)
if success_count == 0:
message = f"Upload failed for all {total} file(s)."
resp = Response().error(message)
resp.data = {
"total": total,
"succeeded": succeeded,
"failed": failed,
}
return resp.__dict__
message = f"Partial success: {success_count}/{total} skill(s) uploaded."
return (
Response()
.ok(
{
"total": total,
"succeeded": succeeded,
"failed": failed,
},
message,
)
.__dict__
)
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(str(e)).__dict__
async def download_skill(self):
try:
name = str(request.args.get("name") or "").strip()
+185 -68
View File
@@ -12,6 +12,32 @@ from .route import Response, Route, RouteContext
DEFAULT_MCP_CONFIG = {"mcpServers": {}}
class EmptyMcpServersError(ValueError):
"""Raised when mcpServers is empty."""
pass
def _extract_mcp_server_config(mcp_servers_value: object) -> dict:
"""Extract server configuration from user-submitted mcpServers field.
Raises:
ValueError: Invalid configuration
"""
if not isinstance(mcp_servers_value, dict):
raise ValueError("mcpServers must be a JSON object")
if not mcp_servers_value:
raise EmptyMcpServersError("mcpServers configuration cannot be empty")
key_0 = next(iter(mcp_servers_value))
extracted = mcp_servers_value[key_0]
if not isinstance(extracted, dict):
raise ValueError(
"Invalid mcpServers format. Ensure each key in mcpServers is a server name, "
"and each value is an object containing fields like command/url."
)
return extracted
class ToolsRoute(Route):
def __init__(
self,
@@ -33,13 +59,37 @@ class ToolsRoute(Route):
self.register_routes()
self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools
def _rollback_mcp_server(self, name: str) -> bool:
try:
rollback_config = self.tool_mgr.load_mcp_config()
if name in rollback_config["mcpServers"]:
rollback_config["mcpServers"].pop(name)
return self.tool_mgr.save_mcp_config(rollback_config)
return True
except Exception:
logger.error(traceback.format_exc())
return False
async def get_mcp_servers(self):
try:
config = self.tool_mgr.load_mcp_config()
servers = []
mcp_servers = config.get("mcpServers", {})
if not isinstance(mcp_servers, dict):
logger.warning(
f"Invalid MCP server config type: {type(mcp_servers).__name__}. Expected object/dict; skipped all MCP servers."
)
mcp_servers = {}
# 获取所有服务器并添加它们的工具列表
for name, server_config in config["mcpServers"].items():
for name, server_config in mcp_servers.items():
if not isinstance(server_config, dict):
logger.warning(
f"Invalid config for MCP server '{name}' (type: {type(server_config).__name__}); skipped."
)
continue
server_info = {
"name": name,
"active": server_config.get("active", True),
@@ -65,7 +115,7 @@ class ToolsRoute(Route):
return Response().ok(servers).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"获取 MCP 服务器列表失败: {e!s}").__dict__
return Response().error(f"Failed to get MCP server list: {e!s}").__dict__
async def add_mcp_server(self):
try:
@@ -75,7 +125,7 @@ class ToolsRoute(Route):
# 检查必填字段
if not name:
return Response().error("服务器名称不能为空").__dict__
return Response().error("Server name cannot be empty").__dict__
# 移除特殊字段并检查配置是否有效
has_valid_config = False
@@ -85,21 +135,33 @@ class ToolsRoute(Route):
for key, value in server_data.items():
if key not in ["name", "active", "tools", "errlogs"]: # 排除特殊字段
if key == "mcpServers":
key_0 = list(server_data["mcpServers"].keys())[
0
] # 不考虑为空的情况
server_config = server_data["mcpServers"][key_0]
try:
server_config = _extract_mcp_server_config(
server_data["mcpServers"]
)
except ValueError as e:
return Response().error(f"{e!s}").__dict__
else:
server_config[key] = value
has_valid_config = True
if not has_valid_config:
return Response().error("必须提供有效的服务器配置").__dict__
return (
Response()
.error("A valid server configuration is required")
.__dict__
)
config = self.tool_mgr.load_mcp_config()
if name in config["mcpServers"]:
return Response().error(f"服务器 {name} 已存在").__dict__
return Response().error(f"Server {name} already exists").__dict__
try:
await self.tool_mgr.test_mcp_server_connection(server_config)
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"MCP connection test failed: {e!s}").__dict__
config["mcpServers"][name] = server_config
@@ -111,17 +173,27 @@ class ToolsRoute(Route):
timeout=30,
)
except TimeoutError:
return Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__
rollback_ok = self._rollback_mcp_server(name)
err_msg = f"Timed out while enabling MCP server {name}."
if not rollback_ok:
err_msg += " Configuration rollback failed. Please check the config manually."
return Response().error(err_msg).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return (
Response().error(f"启用 MCP 服务器 {name} 失败: {e!s}").__dict__
)
return Response().ok(None, f"成功添加 MCP 服务器 {name}").__dict__
return Response().error("保存配置失败").__dict__
rollback_ok = self._rollback_mcp_server(name)
err_msg = f"Failed to enable MCP server {name}: {e!s}"
if not rollback_ok:
err_msg += " Configuration rollback failed. Please check the config manually."
return Response().error(err_msg).__dict__
return (
Response()
.ok(None, f"Successfully added MCP server {name}")
.__dict__
)
return Response().error("Failed to save configuration").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"添加 MCP 服务器失败: {e!s}").__dict__
return Response().error(f"Failed to add MCP server: {e!s}").__dict__
async def update_mcp_server(self):
try:
@@ -131,23 +203,25 @@ class ToolsRoute(Route):
old_name = server_data.get("oldName") or name
if not name:
return Response().error("服务器名称不能为空").__dict__
return Response().error("Server name cannot be empty").__dict__
config = self.tool_mgr.load_mcp_config()
if old_name not in config["mcpServers"]:
return Response().error(f"服务器 {old_name} 不存在").__dict__
return Response().error(f"Server {old_name} does not exist").__dict__
is_rename = name != old_name
if name in config["mcpServers"] and is_rename:
return Response().error(f"服务器 {name} 已存在").__dict__
return Response().error(f"Server {name} already exists").__dict__
# 获取活动状态
active = server_data.get(
"active",
config["mcpServers"][old_name].get("active", True),
)
old_config = config["mcpServers"][old_name]
if isinstance(old_config, dict):
old_active = old_config.get("active", True)
else:
old_active = True
active = server_data.get("active", old_active)
# 创建新的配置对象
server_config = {"active": active}
@@ -165,17 +239,19 @@ class ToolsRoute(Route):
"oldName",
]: # 排除特殊字段
if key == "mcpServers":
key_0 = list(server_data["mcpServers"].keys())[
0
] # 不考虑为空的情况
server_config = server_data["mcpServers"][key_0]
try:
server_config = _extract_mcp_server_config(
server_data["mcpServers"]
)
except ValueError as e:
return Response().error(f"{e!s}").__dict__
else:
server_config[key] = value
only_update_active = False
# 如果只更新活动状态,保留原始配置
if only_update_active:
for key, value in config["mcpServers"][old_name].items():
if only_update_active and isinstance(old_config, dict):
for key, value in old_config.items():
if key != "active": # 除了active之外的所有字段都保留
server_config[key] = value
@@ -200,7 +276,7 @@ class ToolsRoute(Route):
return (
Response()
.error(
f"启用前停用 MCP 服务器时 {old_name} 超时: {e!s}"
f"Timed out while disabling MCP server {old_name} before enabling: {e!s}"
)
.__dict__
)
@@ -209,7 +285,7 @@ class ToolsRoute(Route):
return (
Response()
.error(
f"启用前停用 MCP 服务器时 {old_name} 失败: {e!s}"
f"Failed to disable MCP server {old_name} before enabling: {e!s}"
)
.__dict__
)
@@ -221,13 +297,15 @@ class ToolsRoute(Route):
)
except TimeoutError:
return (
Response().error(f"启用 MCP 服务器 {name} 超时。").__dict__
Response()
.error(f"Timed out while enabling MCP server {name}.")
.__dict__
)
except Exception as e:
logger.error(traceback.format_exc())
return (
Response()
.error(f"启用 MCP 服务器 {name} 失败: {e!s}")
.error(f"Failed to enable MCP server {name}: {e!s}")
.__dict__
)
# 如果要停用服务器
@@ -237,22 +315,26 @@ class ToolsRoute(Route):
except TimeoutError:
return (
Response()
.error(f"停用 MCP 服务器 {old_name} 超时。")
.error(f"Timed out while disabling MCP server {old_name}.")
.__dict__
)
except Exception as e:
logger.error(traceback.format_exc())
return (
Response()
.error(f"停用 MCP 服务器 {old_name} 失败: {e!s}")
.error(f"Failed to disable MCP server {old_name}: {e!s}")
.__dict__
)
return Response().ok(None, f"成功更新 MCP 服务器 {name}").__dict__
return Response().error("保存配置失败").__dict__
return (
Response()
.ok(None, f"Successfully updated MCP server {name}")
.__dict__
)
return Response().error("Failed to save configuration").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"更新 MCP 服务器失败: {e!s}").__dict__
return Response().error(f"Failed to update MCP server: {e!s}").__dict__
async def delete_mcp_server(self):
try:
@@ -260,12 +342,12 @@ class ToolsRoute(Route):
name = server_data.get("name", "")
if not name:
return Response().error("服务器名称不能为空").__dict__
return Response().error("Server name cannot be empty").__dict__
config = self.tool_mgr.load_mcp_config()
if name not in config["mcpServers"]:
return Response().error(f"服务器 {name} 不存在").__dict__
return Response().error(f"Server {name} does not exist").__dict__
del config["mcpServers"][name]
@@ -275,51 +357,76 @@ class ToolsRoute(Route):
await self.tool_mgr.disable_mcp_server(name, timeout=10)
except TimeoutError:
return (
Response().error(f"停用 MCP 服务器 {name} 超时。").__dict__
Response()
.error(f"Timed out while disabling MCP server {name}.")
.__dict__
)
except Exception as e:
logger.error(traceback.format_exc())
return (
Response()
.error(f"停用 MCP 服务器 {name} 失败: {e!s}")
.error(f"Failed to disable MCP server {name}: {e!s}")
.__dict__
)
return Response().ok(None, f"成功删除 MCP 服务器 {name}").__dict__
return Response().error("保存配置失败").__dict__
return (
Response()
.ok(None, f"Successfully deleted MCP server {name}")
.__dict__
)
return Response().error("Failed to save configuration").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"删除 MCP 服务器失败: {e!s}").__dict__
return Response().error(f"Failed to delete MCP server: {e!s}").__dict__
async def test_mcp_connection(self):
"""测试 MCP 服务器连接"""
"""Test MCP server connection."""
try:
server_data = await request.json
config = server_data.get("mcp_server_config", None)
if not isinstance(config, dict) or not config:
return Response().error("无效的 MCP 服务器配置").__dict__
return Response().error("Invalid MCP server configuration").__dict__
if "mcpServers" in config:
keys = list(config["mcpServers"].keys())
if not keys:
return Response().error("MCP 服务器配置不能为空").__dict__
if len(keys) > 1:
return Response().error("一次只能配置一个 MCP 服务器配置").__dict__
config = config["mcpServers"][keys[0]]
mcp_servers = config["mcpServers"]
if isinstance(mcp_servers, dict) and len(mcp_servers) > 1:
return (
Response()
.error(
"Only one MCP server configuration can be tested at a time"
)
.__dict__
)
try:
config = _extract_mcp_server_config(mcp_servers)
except EmptyMcpServersError:
return (
Response()
.error("MCP server configuration cannot be empty")
.__dict__
)
except ValueError as e:
return Response().error(f"{e!s}").__dict__
elif not config:
return Response().error("MCP 服务器配置不能为空").__dict__
return (
Response()
.error("MCP server configuration cannot be empty")
.__dict__
)
tools_name = await self.tool_mgr.test_mcp_server_connection(config)
return (
Response().ok(data=tools_name, message="🎉 MCP 服务器可用!").__dict__
Response()
.ok(data=tools_name, message="🎉 MCP server is available!")
.__dict__
)
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"测试 MCP 连接失败: {e!s}").__dict__
return Response().error(f"Failed to test MCP connection: {e!s}").__dict__
async def get_tool_list(self):
"""获取所有注册的工具列表"""
"""Get all registered tools."""
try:
tools = self.tool_mgr.func_list
tools_dict = []
@@ -349,36 +456,44 @@ class ToolsRoute(Route):
return Response().ok(data=tools_dict).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"获取工具列表失败: {e!s}").__dict__
return Response().error(f"Failed to get tool list: {e!s}").__dict__
async def toggle_tool(self):
"""启用或停用指定的工具"""
"""Activate or deactivate a specified tool."""
try:
data = await request.json
tool_name = data.get("name")
action = data.get("activate") # True or False
if not tool_name or action is None:
return Response().error("缺少必要参数: name 或 action").__dict__
return (
Response()
.error("Missing required parameters: name or activate")
.__dict__
)
if action:
try:
ok = self.tool_mgr.activate_llm_tool(tool_name, star_map=star_map)
except ValueError as e:
return Response().error(f"启用工具失败: {e!s}").__dict__
return Response().error(f"Failed to activate tool: {e!s}").__dict__
else:
ok = self.tool_mgr.deactivate_llm_tool(tool_name)
if ok:
return Response().ok(None, "操作成功。").__dict__
return Response().error(f"工具 {tool_name} 不存在或操作失败。").__dict__
return Response().ok(None, "Operation successful.").__dict__
return (
Response()
.error(f"Tool {tool_name} does not exist or the operation failed.")
.__dict__
)
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"操作工具失败: {e!s}").__dict__
return Response().error(f"Failed to operate tool: {e!s}").__dict__
async def sync_provider(self):
"""同步 MCP 提供者配置"""
"""Sync MCP provider configuration."""
try:
data = await request.json
provider_name = data.get("name") # modelscope, or others
@@ -387,9 +502,11 @@ class ToolsRoute(Route):
access_token = data.get("access_token", "")
await self.tool_mgr.sync_modelscope_mcp_servers(access_token)
case _:
return Response().error(f"未知: {provider_name}").__dict__
return (
Response().error(f"Unknown provider: {provider_name}").__dict__
)
return Response().ok(message="同步成功").__dict__
return Response().ok(message="Sync completed").__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"同步失败: {e!s}").__dict__
return Response().error(f"Sync failed: {e!s}").__dict__
@@ -3,10 +3,10 @@
### 新增
- 集成 KOOK 平台适配器 ([#5658](https://github.com/AstrBotDevs/AstrBot/pull/5658))。
- 集成 DeerFlow Agent Runner 并优化流式处理 ([#5581](https://github.com/AstrBotDevs/AstrBot/pull/5581))。
- 新增 Discord pre-react Emoji 支持 ([#5609](https://github.com/AstrBotDevs/AstrBot/pull/5609))。
- 新增 Telegram 支持 `sendMessageDraft` 流式实时输出 API ([#5726](https://github.com/AstrBotDevs/AstrBot/issues/5726))
- 支持在 Agent 运行时进行消息跟进能力,跟进的消息实时注入给 Agent ([#5484](https://github.com/AstrBotDevs/AstrBot/pull/5484))。
- 集成 DeerFlow Agent Runner 并优化流式处理 ([#5581](https://github.com/AstrBotDevs/AstrBot/pull/5581))。
- 新增 shell, ipython tool 中包含操作系统信息,提高 windows 下 tool call 成功率 ([#5677](https://github.com/AstrBotDevs/AstrBot/pull/5677))。
- Sandbox 支持 Shipyard-neo - 支持 Skills 自迭代 ([#5028](https://github.com/AstrBotDevs/AstrBot/pull/5028))。
- 新增 ChatUI WebSocket 传输模式选择,OpenAPI Chat API 支持 WebSocket 连接 ([#5410](https://github.com/AstrBotDevs/AstrBot/pull/5410))。
+40
View File
@@ -0,0 +1,40 @@
## What's Changed
### 新增
- 新增技能 ZIP 批量上传能力 ([#5804](https://github.com/AstrBotDevs/AstrBot/pull/5804))。
### 修复
- 修复 MCP Server 配置异常时可能导致崩溃的问题 ([#5666](https://github.com/AstrBotDevs/AstrBot/pull/5666), [#5673](https://github.com/AstrBotDevs/AstrBot/pull/5673))。
- 修复钉钉适配器文本消息被忽略、无法主动发送文件的问题 ([#5921](https://github.com/AstrBotDevs/AstrBot/pull/5921))。
- 修复钉钉适配器无法接收图片与文件的问题 ([#5920](https://github.com/AstrBotDevs/AstrBot/pull/5920))。
- fix(provider): handle MiniMax ThinkingBlock when max_tokens reached ([#5913](https://github.com/AstrBotDevs/AstrBot/pull/5913))。
- 修复 OpenRouter `api_base` 配置错误的问题 ([#5911](https://github.com/AstrBotDevs/AstrBot/pull/5911))。
- 修复插件市场中按展示名搜索已安装插件不生效的问题 ([#5806](https://github.com/AstrBotDevs/AstrBot/pull/5806), [#5811](https://github.com/AstrBotDevs/AstrBot/pull/5811))。
- 修复仅图片响应未应用 `reply_with_quote``reply_with_mention` 的问题 ([#5219](https://github.com/AstrBotDevs/AstrBot/pull/5219))。
- 修复 `RegexFilter` 使用 `re.match` 导致匹配范围不正确的问题 ([#5368](https://github.com/AstrBotDevs/AstrBot/pull/5368))。
- 修复桌面运行环境检测依赖 frozen Python 的问题 ([#5859](https://github.com/AstrBotDevs/AstrBot/pull/5859))。
- 修复通过“创建新配置”创建平台机器人后找不到 pipeline scheduler 的问题 ([#5776](https://github.com/AstrBotDevs/AstrBot/pull/5776))。
---
## What's Changed (EN)
### New Features
- Added batch upload support for multiple skill ZIP files ([#5804](https://github.com/AstrBotDevs/AstrBot/pull/5804)).
### Bug Fixes
- Fixed potential crash on malformed MCP server config ([#5666](https://github.com/AstrBotDevs/AstrBot/pull/5666), [#5673](https://github.com/AstrBotDevs/AstrBot/pull/5673)).
- Fixed DingTalk adapter issue where text messages were ignored and files could not be sent proactively ([#5921](https://github.com/AstrBotDevs/AstrBot/pull/5921)).
- Fixed DingTalk adapter issue where image and file messages could not be received ([#5920](https://github.com/AstrBotDevs/AstrBot/pull/5920)).
- Fixed incorrect OpenRouter `api_base` configuration ([#5911](https://github.com/AstrBotDevs/AstrBot/pull/5911)).
- Fixed searching installed plugins by display name in extensions ([#5806](https://github.com/AstrBotDevs/AstrBot/pull/5806), [#5811](https://github.com/AstrBotDevs/AstrBot/pull/5811)).
- Fixed image-only responses not applying `reply_with_quote` and `reply_with_mention` ([#5219](https://github.com/AstrBotDevs/AstrBot/pull/5219)).
- Fixed `RegexFilter` using `re.match` instead of `re.search` for expected matching behavior ([#5368](https://github.com/AstrBotDevs/AstrBot/pull/5368)).
- Fixed desktop runtime detection requiring frozen Python ([#5859](https://github.com/AstrBotDevs/AstrBot/pull/5859)).
- Fixed missing pipeline scheduler after creating a platform bot via "create new config" ([#5776](https://github.com/AstrBotDevs/AstrBot/pull/5776)).
- fix(provider): handle MiniMax ThinkingBlock when max_tokens reached ([#5913](https://github.com/AstrBotDevs/AstrBot/pull/5913))
+9
View File
@@ -0,0 +1,9 @@
## What's Changed
### 新增
- 企业微信智能机器人支持长连接模式。[#5930](https://github.com/AstrBotDevs/AstrBot/pull/5930)
### New
- Wecom AI Bot supports long-connection mode(Websockets). [#5930](https://github.com/AstrBotDevs/AstrBot/pull/5930)
@@ -300,6 +300,10 @@ export default {
this.loadingGettingServers = true;
axios.get('/api/tools/mcp/servers')
.then(response => {
if (response.data.status === 'error') {
this.showError(response.data.message || this.tm('messages.getServersError', { error: 'Unknown error' }));
return;
}
this.mcpServers = response.data.data || [];
this.mcpServers.forEach(server => {
if (!this.mcpServerUpdateLoaders[server.name]) {
@@ -372,6 +376,10 @@ export default {
axios.post(endpoint, serverData)
.then(response => {
this.loading = false;
if (response.data.status === 'error') {
this.showError(response.data.message || this.tm('messages.saveError', { error: 'Unknown error' }));
return;
}
this.showMcpServerDialog = false;
this.addServerDialogMessage = '';
this.getServers();
File diff suppressed because it is too large Load Diff
@@ -550,6 +550,10 @@
"description": "WeCom AI Bot Name",
"hint": "Must be correct; otherwise some commands won't work."
},
"wecom_ai_bot_connection_mode": {
"description": "WeCom AI Bot Connection Mode",
"hint": "Webhook mode requires Token/EncodingAESKey; long_connection mode requires BotID/Secret."
},
"wecomaibot_friend_message_welcome_text": {
"description": "WeCom AI Bot DM Welcome Message",
"hint": "When a user enters a DM session on that day, reply with a welcome message. Leave empty to disable."
@@ -558,6 +562,30 @@
"description": "WeCom AI Bot Initial Response Text",
"hint": "First reply when the bot receives a message. Leave empty to disable."
},
"wecomaibot_token": {
"description": "WeCom AI Bot Token",
"hint": "Used for authentication in webhook callback mode."
},
"wecomaibot_encoding_aes_key": {
"description": "WeCom AI Bot EncodingAESKey",
"hint": "Used for message encryption/decryption in webhook callback mode."
},
"wecomaibot_ws_bot_id": {
"description": "Long Connection BotID",
"hint": "BotID credential for WeCom AI Bot long connection mode."
},
"wecomaibot_ws_secret": {
"description": "Long Connection Secret",
"hint": "Secret credential for WeCom AI Bot long connection mode."
},
"wecomaibot_ws_url": {
"description": "Long Connection WebSocket URL",
"hint": "Default is wss://openws.work.weixin.qq.com and usually does not need changes."
},
"wecomaibot_heartbeat_interval": {
"description": "Long Connection Heartbeat Interval",
"hint": "Heartbeat interval (seconds) in long connection mode. 30 seconds is recommended."
},
"wpp_active_message_poll": {
"description": "Enable Proactive Message Polling",
"hint": "Only enable if WeChat messages are not syncing to AstrBot on time. Disabled by default."
@@ -224,10 +224,43 @@
"empty": "No Skills found",
"emptyHint": "Upload a Skills zip to get started",
"uploadDialogTitle": "Upload Skills",
"uploadHint": "Upload a zip file that contains skill_name/ and a SKILL.md inside.",
"uploadHint": "Upload multiple zip skill packages or drag them in. The system validates the structure automatically and shows a result for each file.",
"structureRequirement": "The most common failure is an invalid archive structure. Each zip must contain exactly one top-level folder such as `skillname/`, and that folder must include `SKILL.md`.",
"abilityMultiple": "Upload multiple zip files at once",
"abilityValidate": "Validate `SKILL.md` automatically",
"abilitySkip": "Automatically skip duplicate files.",
"selectFile": "Select file",
"confirmUpload": "Upload",
"selectFiles": "Select files (multiple allowed)",
"dropzoneTitle": "Drag multiple zip files here",
"dropzoneAction": "or click to pick multiple files from a folder",
"dropzoneHint": "Batch upload is supported and the structure will be validated automatically",
"fileListTitle": "Files in queue",
"fileListEmpty": "Selected files will appear here with validation feedback and upload status",
"uploading": "Uploading...",
"batchResultTitle": "Batch Upload Results",
"batchResultSummary": "{success} of {total} files uploaded successfully",
"batchSuccessList": "Successfully uploaded",
"batchFailedList": "Failed to upload",
"confirm": "OK",
"confirmUpload": "Start Upload",
"cancel": "Cancel",
"statusWaiting": "Waiting",
"statusUploading": "Uploading",
"statusSuccess": "Uploaded",
"statusError": "Failed",
"statusSkipped": "Skipped",
"summaryTotal": "{count} file(s)",
"summaryReady": "Pending {count}",
"summarySuccess": "Success {count}",
"summaryFailed": "Failed {count}",
"summarySkipped": "Skipped {count}",
"validationReady": "Ready to upload. The archive structure will be checked during upload.",
"validationZipOnly": "Only zip skill packages are supported",
"validationDuplicate": "A file with the same name is already in the queue and has been skipped",
"validationUploading": "Validating and uploading...",
"validationUploadFailed": "Upload failed. Please try again.",
"validationUploadedAs": "Installed as {name}",
"validationNoResult": "No validation result was returned. Check the platform logs.",
"noDescription": "No description",
"path": "Path",
"uploadSuccess": "Upload succeeded",
@@ -543,7 +543,7 @@
},
"unified_webhook_mode": {
"description": "统一 Webhook 模式",
"hint": "启用后,将使用 AstrBot 统一 Webhook 入口,无需单独开启端口。回调地址为 /api/platform/webhook/{webhook_uuid}。"
"hint": "Webhook 模式下使用 AstrBot 统一 Webhook 入口,无需单独开启端口。回调地址为 /api/platform/webhook/{webhook_uuid}。"
},
"webhook_uuid": {
"description": "Webhook UUID",
@@ -553,13 +553,41 @@
"description": "企业微信智能机器人的名字",
"hint": "请务必填写正确,否则无法使用一些指令。"
},
"wecom_ai_bot_connection_mode": {
"description": "企业微信智能机器人连接模式",
"hint": "Webhook 回调模式需要配置 Token/EncodingAESKey;长连接模式需要配置 BotID/Secret。"
},
"wecomaibot_friend_message_welcome_text": {
"description": "企业微信智能机器人私聊欢迎语",
"hint": "当用户当天进入智能机器人单聊会话,回复欢迎语,如 “💭 思考中...”。留空则不回复。"
"hint": "可选。当用户当天进入智能机器人单聊会话,回复欢迎语,如 “💭 思考中...”。留空则不回复。"
},
"wecomaibot_init_respond_text": {
"description": "企业微信智能机器人初始响应文本",
"hint": "当机器人收到消息时,首先回复的文本内容。留空则不设置。"
"hint": "可选。当机器人收到消息时,首先回复的文本内容。留空则不设置。"
},
"wecomaibot_token": {
"description": "企业微信智能机器人 Token",
"hint": "用于 Webhook 回调模式的身份验证。"
},
"wecomaibot_encoding_aes_key": {
"description": "企业微信智能机器人 EncodingAESKey",
"hint": "用于 Webhook 回调模式的消息加密解密。"
},
"wecomaibot_ws_bot_id": {
"description": "长连接 BotID",
"hint": "企业微信智能机器人长连接模式凭证 BotID。"
},
"wecomaibot_ws_secret": {
"description": "长连接 Secret",
"hint": "企业微信智能机器人长连接模式凭证 Secret。"
},
"wecomaibot_ws_url": {
"description": "长连接 WebSocket 地址",
"hint": "默认值为 wss://openws.work.weixin.qq.com,一般无需修改。"
},
"wecomaibot_heartbeat_interval": {
"description": "长连接心跳间隔",
"hint": "长连接模式心跳间隔(秒),建议 30 秒。"
},
"wpp_active_message_poll": {
"description": "是否启用主动消息轮询",
@@ -582,11 +610,11 @@
},
"msg_push_webhook_url": {
"description": "企业微信消息推送 Webhook URL",
"hint": "用于主动消息推送,请在企微群->消息推送得到 URL。强烈建议设置此项以带来更好的消息发送体验。"
"hint": "可选。用于主动消息推送,请在企微群->消息推送得到 URL。建议设置此项以带来更好的消息发送体验。"
},
"only_use_webhook_url_to_send": {
"description": "仅使用 Webhook 发送消息",
"hint": "启用后,企业微信智能机器人的所有回复都改为通过消息推送 Webhook 发送。消息推送 Webhook 支持更多的消息类型(如图片、文件等)。如果不需要打字机效果,强烈建议使用此选项。"
"hint": "可选。启用后,企业微信智能机器人的所有回复都改为通过消息推送 Webhook 发送。消息推送 Webhook 支持更多的消息类型(如图片、文件等)。如果不需要打字机效果,强烈建议使用此选项。"
},
"kook_bot_token": {
"description": "机器人 Token",
@@ -224,10 +224,43 @@
"empty": "暂无 Skills",
"emptyHint": "请上传 Skills 压缩包",
"uploadDialogTitle": "上传 Skills",
"uploadHint": "上传 zip 压缩包,解压后为 skill_name/ 目录,且包含 SKILL.md",
"uploadHint": "支持批量上传 zip 技能包,也支持拖拽批量上传 zip 技能包。系统会自动校验目录结构,并给出逐个文件的结果。",
"structureRequirement": "常见失败原因是压缩包结构不正确。每个 zip 必须只包含一个顶层目录,例如 `skillname/`,且该目录下必须存在 `SKILL.md`。",
"abilityMultiple": "支持一次上传多个zip文件",
"abilityValidate": "自动校验 `SKILL.md`",
"abilitySkip": "自动跳过重复文件",
"selectFile": "选择文件",
"confirmUpload": "上传",
"selectFiles": "选择文件(可多选)",
"dropzoneTitle": "拖拽多个 zip 文件到这里",
"dropzoneAction": "或者点击之后在文件夹中选择多个文件",
"dropzoneHint": "支持批量上传,系统会自动校验目录结构",
"fileListTitle": "待处理文件",
"fileListEmpty": "选择文件后会在这里显示校验结果与上传状态",
"uploading": "正在上传...",
"batchResultTitle": "批量上传结果",
"batchResultSummary": "共 {total} 个文件,成功 {success} 个",
"batchSuccessList": "上传成功",
"batchFailedList": "上传失败",
"confirm": "确定",
"confirmUpload": "开始上传",
"cancel": "取消",
"statusWaiting": "待上传",
"statusUploading": "上传中",
"statusSuccess": "已上传",
"statusError": "校验失败",
"statusSkipped": "已跳过",
"summaryTotal": "共 {count} 个文件",
"summaryReady": "待处理 {count}",
"summarySuccess": "成功 {count}",
"summaryFailed": "失败 {count}",
"summarySkipped": "跳过 {count}",
"validationReady": "等待上传,上传时会自动校验目录结构",
"validationZipOnly": "仅支持 zip 技能包",
"validationDuplicate": "同名文件已在列表中,已跳过",
"validationUploading": "正在校验并上传...",
"validationUploadFailed": "上传失败,请重试",
"validationUploadedAs": "已安装为 {name}",
"validationNoResult": "未收到校验结果,请检查平台日志",
"noDescription": "无描述",
"path": "路径",
"uploadSuccess": "上传成功",
+102
View File
@@ -0,0 +1,102 @@
import { pinyin } from "pinyin-pro";
const HAN_IDEOGRAPH_RE = /\p{Unified_Ideograph}/u;
export const normalizeStr = (s) => (s ?? "").toString().toLowerCase().trim();
const normalizeLooseFromNormalized = (normalized) =>
normalized.replace(/[\s_-]+/g, "").replace(/[()()【】\[\]{}·•]+/g, "");
export const normalizeLoose = (s) =>
normalizeLooseFromNormalized(normalizeStr(s));
const memoizeStringFn = (fn) => {
const cache = new Map();
return (raw) => {
const key = (raw ?? "").toString();
if (cache.has(key)) {
return cache.get(key);
}
const value = fn(key);
cache.set(key, value);
return value;
};
};
const getNormalizedText = memoizeStringFn(normalizeStr);
const getLooseText = memoizeStringFn((text) =>
normalizeLooseFromNormalized(getNormalizedText(text)),
);
export const toPinyinText = memoizeStringFn((text) =>
pinyin(text, { toneType: "none" })
.toLowerCase()
.replace(/\s+/g, ""),
);
export const toInitials = memoizeStringFn((text) =>
pinyin(text, { pattern: "first", toneType: "none" })
.toLowerCase()
.replace(/\s+/g, ""),
);
export const buildSearchQuery = (raw) => {
const norm = getNormalizedText(raw);
if (!norm) return null;
return {
norm,
loose: getLooseText(raw),
};
};
export const matchesText = (value, query) => {
if (value == null || !query?.norm) return false;
const text = String(value);
const normalizedValue = getNormalizedText(text);
const looseValue = query.loose ? getLooseText(text) : null;
if (normalizedValue.includes(query.norm)) return true;
if (query.loose && looseValue?.includes(query.loose)) return true;
if (!HAN_IDEOGRAPH_RE.test(text)) return false;
const pinyinValue = toPinyinText(text);
if (pinyinValue.includes(query.norm)) return true;
const initialsValue = toInitials(text);
if (initialsValue.includes(query.norm)) return true;
return false;
};
export const getPluginSearchFields = (plugin) => {
const supportPlatforms = Array.isArray(plugin?.support_platforms)
? plugin.support_platforms.join(" ")
: "";
const tags = Array.isArray(plugin?.tags) ? plugin.tags.join(" ") : "";
return [
plugin?.name,
plugin?.trimmedName,
plugin?.display_name,
plugin?.desc,
plugin?.author,
plugin?.repo,
plugin?.version,
plugin?.astrbot_version,
supportPlatforms,
tags,
];
};
export const matchesPluginSearch = (plugin, query) => {
if (!query) return true;
return getPluginSearchFields(plugin).some((candidate) =>
matchesText(candidate, query),
);
};
-1
View File
@@ -84,7 +84,6 @@ const {
normalizeStr,
toPinyinText,
toInitials,
marketCustomFilter,
plugin_handler_info_headers,
pluginHeaders,
filteredExtensions,
@@ -81,7 +81,6 @@ const {
normalizeStr,
toPinyinText,
toInitials,
marketCustomFilter,
plugin_handler_info_headers,
pluginHeaders,
filteredExtensions,
@@ -82,7 +82,6 @@ const {
normalizeStr,
toPinyinText,
toInitials,
marketCustomFilter,
plugin_handler_info_headers,
pluginHeaders,
filteredExtensions,
@@ -1,9 +1,15 @@
import axios from "axios";
import { pinyin } from "pinyin-pro";
import { useCommonStore } from "@/stores/common";
import { useI18n, useModuleI18n } from "@/i18n/composables";
import { getPlatformDisplayName } from "@/utils/platformUtils";
import { resolveErrorMessage } from "@/utils/errorUtils";
import {
buildSearchQuery,
matchesPluginSearch,
normalizeStr,
toInitials,
toPinyinText,
} from "@/utils/pluginSearch";
import { ref, computed, onMounted, onUnmounted, reactive, watch } from "vue";
import { useRoute, useRouter } from "vue-router";
import { useDisplay } from "vuetify";
@@ -240,37 +246,6 @@ export const useExtensionPage = () => {
});
// 插件市场拼音搜索
const normalizeStr = (s) => (s ?? "").toString().toLowerCase().trim();
const toPinyinText = (s) =>
pinyin(s ?? "", { toneType: "none" })
.toLowerCase()
.replace(/\s+/g, "");
const toInitials = (s) =>
pinyin(s ?? "", { pattern: "first", toneType: "none" })
.toLowerCase()
.replace(/\s+/g, "");
const marketCustomFilter = (value, query, item) => {
const q = normalizeStr(query);
if (!q) return true;
const candidates = new Set();
if (value != null) candidates.add(String(value));
if (item?.name) candidates.add(String(item.name));
if (item?.trimmedName) candidates.add(String(item.trimmedName));
if (item?.display_name) candidates.add(String(item.display_name));
if (item?.desc) candidates.add(String(item.desc));
if (item?.author) candidates.add(String(item.author));
for (const v of candidates) {
const nv = normalizeStr(v);
if (nv.includes(q)) return true;
const pv = toPinyinText(v);
if (pv.includes(q)) return true;
const iv = toInitials(v);
if (iv.includes(q)) return true;
}
return false;
};
const plugin_handler_info_headers = computed(() => [
{ title: tm("table.headers.eventType"), key: "event_type_h" },
@@ -347,47 +322,24 @@ export const useExtensionPage = () => {
// 通过搜索过滤插件
const filteredPlugins = computed(() => {
const plugins = filteredExtensions.value;
let filtered = plugins;
if (pluginSearch.value) {
const search = pluginSearch.value.toLowerCase();
filtered = plugins.filter((plugin) => {
const pluginName = (plugin.name ?? "").toLowerCase();
const pluginDesc = (plugin.desc ?? "").toLowerCase();
const pluginAuthor = (plugin.author ?? "").toLowerCase();
const supportPlatforms = Array.isArray(plugin.support_platforms)
? plugin.support_platforms.join(" ").toLowerCase()
: "";
const astrbotVersion = (plugin.astrbot_version ?? "").toLowerCase();
return (
pluginName.includes(search) ||
pluginDesc.includes(search) ||
pluginAuthor.includes(search) ||
supportPlatforms.includes(search) ||
astrbotVersion.includes(search)
);
});
}
const query = buildSearchQuery(pluginSearch.value);
const filtered = query
? plugins.filter((plugin) => matchesPluginSearch(plugin, query))
: plugins;
return sortPluginsByName([...filtered]);
});
// 过滤后的插件市场数据(带搜索)
const filteredMarketPlugins = computed(() => {
if (!debouncedMarketSearch.value) {
const query = buildSearchQuery(debouncedMarketSearch.value);
if (!query) {
return pluginMarketData.value;
}
const search = debouncedMarketSearch.value.toLowerCase();
return pluginMarketData.value.filter((plugin) => {
// 使用自定义过滤器
return (
marketCustomFilter(plugin.name, search, plugin) ||
marketCustomFilter(plugin.desc, search, plugin) ||
marketCustomFilter(plugin.author, search, plugin)
);
});
return pluginMarketData.value.filter((plugin) =>
matchesPluginSearch(plugin, query),
);
});
// 所有插件列表,推荐插件排在前面
@@ -1563,7 +1515,6 @@ export const useExtensionPage = () => {
normalizeStr,
toPinyinText,
toInitials,
marketCustomFilter,
plugin_handler_info_headers,
pluginHeaders,
filteredExtensions,
+26 -11
View File
@@ -5,6 +5,10 @@ import os
import sys
from pathlib import Path
import runtime_bootstrap
runtime_bootstrap.initialize_runtime_bootstrap()
from astrbot.core import LogBroker, LogManager, db_helper, logger # noqa: E402
from astrbot.core.config.default import VERSION # noqa: E402
from astrbot.core.initial_loader import InitialLoader # noqa: E402
@@ -97,6 +101,26 @@ async def check_dashboard_files(webui_dir: str | None = None):
return data_dist_path
async def main_async(webui_dir_arg: str | None) -> None:
"""主异步入口"""
# 检查仪表板文件
webui_dir = await check_dashboard_files(webui_dir_arg)
if webui_dir is None:
logger.warning(
"管理面板文件检查失败,WebUI 功能将不可用。"
"请检查网络连接或手动指定 --webui-dir 参数。"
)
db = db_helper
# 打印 logo
logger.info(logo_tmpl)
core_lifecycle = InitialLoader(db, log_broker)
core_lifecycle.webui_dir = webui_dir
await core_lifecycle.start()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="AstrBot")
parser.add_argument(
@@ -113,14 +137,5 @@ if __name__ == "__main__":
log_broker = LogBroker()
LogManager.set_queue_handler(logger, log_broker)
# 检查仪表板文件
webui_dir = asyncio.run(check_dashboard_files(args.webui_dir))
db = db_helper
# 打印 logo
logger.info(logo_tmpl)
core_lifecycle = InitialLoader(db, log_broker)
core_lifecycle.webui_dir = webui_dir
asyncio.run(core_lifecycle.start())
# 只使用一次 asyncio.run()
asyncio.run(main_async(args.webui_dir))
+3 -3
View File
@@ -1,9 +1,9 @@
[project]
name = "AstrBot"
version = "4.19.1"
version = "4.19.4"
description = "Easy-to-use multi-platform LLM chatbot and development framework"
readme = "README.md"
requires-python = ">=3.11"
requires-python = ">=3.12"
keywords = ["Astrbot", "Astrbot Module", "Astrbot Plugin"]
@@ -61,7 +61,7 @@ dependencies = [
"xinference-client",
"tenacity>=9.1.2",
"shipyard-python-sdk>=0.2.4",
"shipyard-neo-sdk @ git+https://github.com/AstrBotDevs/shipyard-neo.git#subdirectory=shipyard-neo-sdk",
"shipyard-neo-sdk>=0.2.0",
"python-socks>=2.8.0",
"packaging>=24.2",
]
+1 -1
View File
@@ -54,5 +54,5 @@ markitdown-no-magika[docx,xls,xlsx]>=0.1.2
xinference-client
tenacity>=9.1.2
shipyard-python-sdk>=0.2.4
shipyard-neo-sdk @ git+https://github.com/AstrBotDevs/shipyard-neo.git#subdirectory=shipyard-neo-sdk
shipyard-neo-sdk>=0.2.0
packaging>=24.2
+50
View File
@@ -0,0 +1,50 @@
import logging
import ssl
from typing import Any
import aiohttp.connector as aiohttp_connector
from astrbot.utils.http_ssl_common import build_ssl_context_with_certifi
logger = logging.getLogger(__name__)
def _try_patch_aiohttp_ssl_context(
ssl_context: ssl.SSLContext,
log_obj: Any | None = None,
) -> bool:
log = log_obj or logger
attr_name = "_SSL_CONTEXT_VERIFIED"
if not hasattr(aiohttp_connector, attr_name):
log.warning(
"aiohttp connector does not expose _SSL_CONTEXT_VERIFIED; skipped patch.",
)
return False
current_value = getattr(aiohttp_connector, attr_name, None)
if current_value is not None and not isinstance(current_value, ssl.SSLContext):
log.warning(
"aiohttp connector exposes _SSL_CONTEXT_VERIFIED with unexpected type; skipped patch.",
)
return False
setattr(aiohttp_connector, attr_name, ssl_context)
log.info("Configured aiohttp verified SSL context with system+certifi trust chain.")
return True
def configure_runtime_ca_bundle(log_obj: Any | None = None) -> bool:
log = log_obj or logger
try:
log.info("Bootstrapping runtime CA bundle.")
ssl_context = build_ssl_context_with_certifi(log_obj=log)
return _try_patch_aiohttp_ssl_context(ssl_context, log_obj=log)
except Exception as exc:
log.error("Failed to configure runtime CA bundle for aiohttp: %r", exc)
return False
def initialize_runtime_bootstrap(log_obj: Any | None = None) -> bool:
return configure_runtime_ca_bundle(log_obj=log_obj)
-20
View File
@@ -7,7 +7,6 @@ import shutil
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable
from urllib.parse import urlparse
from unittest.mock import AsyncMock, MagicMock
from astrbot.core.message.components import BaseMessageComponent
@@ -25,25 +24,6 @@ class NoopAwaitable:
return None
def get_bound_tcp_port(site: Any) -> int:
"""Resolve the bound aiohttp TCP site port for tests.
We prefer the public ``site.name`` first. Some aiohttp test setups with
ephemeral ports may not expose a usable port there, so we fall back to
``site._server.sockets`` as a test-only compatibility path.
"""
parsed = urlparse(getattr(site, "name", ""))
if parsed.port is not None and parsed.port > 0:
return parsed.port
server = getattr(site, "_server", None)
sockets = getattr(server, "sockets", None) if server else None
if sockets:
return sockets[0].getsockname()[1]
raise RuntimeError("Unable to resolve bound TCP port from aiohttp site")
# ============================================================
# 平台配置工厂
# ============================================================
-268
View File
@@ -1,268 +0,0 @@
"""Performance benchmark tests for core AstrBot execution paths.
Run with:
uv run pytest tests/performance/test_benchmarks.py -q -s
Optional output:
ASTRBOT_BENCHMARK_OUTPUT=/tmp/astrbot_benchmark.json
"""
from __future__ import annotations
import asyncio
import json
import math
import os
import time
import zipfile
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Awaitable, Callable
from unittest.mock import MagicMock
import pytest
from aiohttp import web
from astrbot.core.backup.exporter import AstrBotExporter
from astrbot.core.message.components import File, Image, Record
from astrbot.core.utils.io import download_file, file_to_base64
from tests.fixtures.helpers import get_bound_tcp_port
@dataclass(slots=True)
class BenchmarkResult:
name: str
iterations: int
warmup: int
min_ms: float
max_ms: float
mean_ms: float
p50_ms: float
p95_ms: float
ops_per_sec: float
def _percentile(values: list[float], q: float) -> float:
if not values:
return 0.0
sorted_values = sorted(values)
if len(sorted_values) == 1:
return sorted_values[0]
rank = (len(sorted_values) - 1) * q
lower = math.floor(rank)
upper = math.ceil(rank)
if lower == upper:
return sorted_values[lower]
weight = rank - lower
return sorted_values[lower] * (1 - weight) + sorted_values[upper] * weight
async def run_async_benchmark(
name: str,
func: Callable[[], Awaitable[None]],
*,
iterations: int,
warmup: int = 5,
) -> BenchmarkResult:
for _ in range(warmup):
await func()
samples_ms: list[float] = []
for _ in range(iterations):
start_ns = time.perf_counter_ns()
await func()
elapsed_ms = (time.perf_counter_ns() - start_ns) / 1_000_000
samples_ms.append(elapsed_ms)
mean_ms = sum(samples_ms) / len(samples_ms)
return BenchmarkResult(
name=name,
iterations=iterations,
warmup=warmup,
min_ms=min(samples_ms),
max_ms=max(samples_ms),
mean_ms=mean_ms,
p50_ms=_percentile(samples_ms, 0.50),
p95_ms=_percentile(samples_ms, 0.95),
ops_per_sec=1000 / mean_ms if mean_ms > 0 else 0.0,
)
def _print_report(results: list[BenchmarkResult]) -> None:
print("\nAstrBot Benchmark Report")
print("-" * 84)
print(
f"{'case':35} {'iters':>7} {'mean(ms)':>10} {'p50(ms)':>10} "
f"{'p95(ms)':>10} {'ops/s':>10}"
)
print("-" * 84)
for result in results:
print(
f"{result.name:35} {result.iterations:7d} "
f"{result.mean_ms:10.4f} {result.p50_ms:10.4f} "
f"{result.p95_ms:10.4f} {result.ops_per_sec:10.1f}"
)
def _scaled_iterations(value: int) -> int:
scale = int(os.environ.get("ASTRBOT_BENCHMARK_SCALE", "1"))
return max(1, value * scale)
@pytest.mark.asyncio
@pytest.mark.slow
async def test_core_performance_benchmarks(tmp_path: Path) -> None:
"""Measure representative performance paths across core modules."""
data = os.urandom(256 * 1024)
payload_path = tmp_path / "payload.bin"
payload_path.write_bytes(data)
image = Image.fromFileSystem(str(payload_path))
record = Record.fromFileSystem(str(payload_path))
file_component = File(name="payload.bin", file=str(payload_path))
exists_path = tmp_path / "exists_target.txt"
exists_path.write_text("ok", encoding="utf-8")
attachments_dir = tmp_path / "attachments"
attachments_dir.mkdir()
attachments: list[dict[str, str]] = []
attachments_with_missing: list[dict[str, str]] = []
for i in range(64):
file_path = attachments_dir / f"attachment_{i}.bin"
file_path.write_bytes(data[:2048])
attachments.append({"attachment_id": f"att_{i}", "path": str(file_path)})
if i % 4 == 0:
missing_path = attachments_dir / f"missing_{i}.bin"
attachments_with_missing.append(
{"attachment_id": f"att_missing_{i}", "path": str(missing_path)}
)
attachments_with_missing.append(
{"attachment_id": f"att_existing_{i}", "path": str(file_path)}
)
exporter = AstrBotExporter(main_db=MagicMock())
zip_path = tmp_path / "attachments_bench.zip"
micro_batch = 32
download_target = tmp_path / "download_target.bin"
download_payload = os.urandom(512 * 1024)
async def handle_download(_request):
return web.Response(body=download_payload)
app = web.Application()
app.router.add_get("/download.bin", handle_download)
runner = web.AppRunner(app, access_log=None)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", 0)
await site.start()
port = get_bound_tcp_port(site)
download_url = f"http://127.0.0.1:{port}/download.bin"
async def bench_file_to_base64() -> None:
await file_to_base64(str(payload_path))
async def bench_image_convert_to_base64() -> None:
await image.convert_to_base64()
async def bench_record_convert_to_base64() -> None:
await record.convert_to_base64()
async def bench_image_convert_to_file_path() -> None:
for _ in range(micro_batch):
await image.convert_to_file_path()
async def bench_file_component_get_file() -> None:
await file_component.get_file()
async def bench_to_thread_exists() -> None:
await asyncio.to_thread(exists_path.exists)
async def bench_export_attachments_existing() -> None:
if zip_path.exists():
zip_path.unlink()
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
await exporter._export_attachments(zf, attachments)
zip_path.unlink(missing_ok=True)
async def bench_export_attachments_with_missing() -> None:
if zip_path.exists():
zip_path.unlink()
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
await exporter._export_attachments(zf, attachments_with_missing)
zip_path.unlink(missing_ok=True)
async def bench_download_file_local_http() -> None:
await download_file(download_url, str(download_target))
download_target.unlink(missing_ok=True)
try:
results = [
await run_async_benchmark(
"utils.io.file_to_base64(256KB)",
bench_file_to_base64,
iterations=_scaled_iterations(120),
),
await run_async_benchmark(
"components.Image.convert_to_base64",
bench_image_convert_to_base64,
iterations=_scaled_iterations(120),
),
await run_async_benchmark(
"components.Record.convert_to_base64",
bench_record_convert_to_base64,
iterations=_scaled_iterations(120),
),
await run_async_benchmark(
f"components.Image.convert_to_file_path(x{micro_batch})",
bench_image_convert_to_file_path,
iterations=_scaled_iterations(140),
),
await run_async_benchmark(
"components.File.get_file(local)",
bench_file_component_get_file,
iterations=_scaled_iterations(140),
),
await run_async_benchmark(
"asyncio.to_thread(Path.exists)",
bench_to_thread_exists,
iterations=_scaled_iterations(240),
),
await run_async_benchmark(
"backup.exporter._export_attachments(existing)",
bench_export_attachments_existing,
iterations=_scaled_iterations(20),
warmup=2,
),
await run_async_benchmark(
"backup.exporter._export_attachments(mixed)",
bench_export_attachments_with_missing,
iterations=_scaled_iterations(20),
warmup=2,
),
await run_async_benchmark(
"utils.io.download_file(local_http_512KB)",
bench_download_file_local_http,
iterations=_scaled_iterations(12),
warmup=2,
),
]
finally:
await runner.cleanup()
_print_report(results)
output_path = os.environ.get("ASTRBOT_BENCHMARK_OUTPUT")
if output_path:
Path(output_path).write_text(
json.dumps([asdict(result) for result in results], indent=2),
encoding="utf-8",
)
# Keep assertions broad: benchmarks are for measurement, not strict gating.
assert len(results) == 9
for result in results:
assert result.iterations > 0
assert result.mean_ms > 0
assert result.max_ms >= result.min_ms
assert result.p95_ms >= result.p50_ms
-98
View File
@@ -172,15 +172,6 @@ class TestAstrBotExporter:
assert "test.json" in exporter._checksums
assert exporter._checksums["test.json"].startswith("sha256:")
def test_read_text_if_exists(self, tmp_path):
"""测试 _read_text_if_exists 行为。"""
exporter = AstrBotExporter(main_db=MagicMock())
file_path = tmp_path / "config.json"
file_path.write_text('{"k":"v"}', encoding="utf-8")
assert exporter._read_text_if_exists(str(file_path)) == '{"k":"v"}'
assert exporter._read_text_if_exists(str(tmp_path / "missing.json")) is None
def test_generate_manifest(self, mock_main_db, mock_kb_manager):
"""测试生成清单"""
exporter = AstrBotExporter(
@@ -249,95 +240,6 @@ class TestAstrBotExporter:
assert "databases/main_db.json" in namelist
assert "config/cmd_config.json" in namelist
@pytest.mark.asyncio
async def test_export_attachments_exports_existing_and_skips_missing(
self, mock_main_db, tmp_path
):
"""测试附件导出:存在文件写入 ZIP,不存在文件跳过。"""
exporter = AstrBotExporter(main_db=mock_main_db, kb_manager=None)
existing_file = tmp_path / "exists.txt"
existing_file.write_text("hello", encoding="utf-8")
missing_file = tmp_path / "missing.txt"
zip_path = tmp_path / "attachments.zip"
attachments = [
{"attachment_id": "att_ok", "path": str(existing_file)},
{"attachment_id": "att_missing", "path": str(missing_file)},
]
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
await exporter._export_attachments(zf, attachments)
with zipfile.ZipFile(zip_path, "r") as zf:
namelist = zf.namelist()
assert "files/attachments/att_ok.txt" in namelist
assert "files/attachments/att_missing.txt" not in namelist
@pytest.mark.asyncio
async def test_export_attachments_skips_empty_attachment_id(
self, mock_main_db, tmp_path
):
"""测试附件导出:attachment_id 为空时跳过,避免覆盖冲突。"""
exporter = AstrBotExporter(main_db=mock_main_db, kb_manager=None)
file_a = tmp_path / "a.txt"
file_b = tmp_path / "b.txt"
file_a.write_text("a", encoding="utf-8")
file_b.write_text("b", encoding="utf-8")
zip_path = tmp_path / "attachments_empty_id.zip"
attachments = [
{"attachment_id": "", "path": str(file_a)},
{"path": str(file_b)},
{"attachment_id": "att_ok", "path": str(file_a)},
]
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
await exporter._export_attachments(zf, attachments)
with zipfile.ZipFile(zip_path, "r") as zf:
namelist = zf.namelist()
assert "files/attachments/att_ok.txt" in namelist
assert "files/attachments/.txt" not in namelist
@pytest.mark.asyncio
async def test_export_attachments_keeps_best_effort_on_unexpected_write_error(
self, mock_main_db, tmp_path
):
"""测试附件导出:单个非 OSError 写入异常不会中断后续附件导出。"""
exporter = AstrBotExporter(main_db=mock_main_db, kb_manager=None)
file_a = tmp_path / "a.txt"
file_b = tmp_path / "b.txt"
file_a.write_text("a", encoding="utf-8")
file_b.write_text("b", encoding="utf-8")
zip_path = tmp_path / "attachments_best_effort.zip"
attachments = [
{"attachment_id": "att_boom", "path": str(file_a)},
{"attachment_id": "att_ok", "path": str(file_b)},
]
with zipfile.ZipFile(zip_path, "w", zipfile.ZIP_DEFLATED) as zf:
original_write = zf.write
def flaky_write(filename, arcname=None, *args, **kwargs):
if arcname == "files/attachments/att_boom.txt":
raise RuntimeError("boom")
return original_write(filename, arcname, *args, **kwargs)
with patch.object(zf, "write", side_effect=flaky_write):
await exporter._export_attachments(zf, attachments)
with zipfile.ZipFile(zip_path, "r") as zf:
namelist = zf.namelist()
assert "files/attachments/att_boom.txt" not in namelist
assert "files/attachments/att_ok.txt" in namelist
class TestAstrBotImporter:
"""AstrBotImporter 类测试"""
+227 -5
View File
@@ -1,11 +1,14 @@
import asyncio
import io
import os
import sys
import zipfile
from types import SimpleNamespace
import pytest
import pytest_asyncio
from quart import Quart
from werkzeug.datastructures import FileStorage
from astrbot.core import LogBroker
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
@@ -15,7 +18,6 @@ from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.dashboard.server import AstrBotDashboard
from tests.fixtures.helpers import (
MockPluginBuilder,
MockPluginConfig,
create_mock_updater_install,
create_mock_updater_update,
)
@@ -145,9 +147,7 @@ async def test_plugins(
monkeypatch.setattr(
core_lifecycle_td.plugin_manager.updator, "install", mock_install
)
monkeypatch.setattr(
core_lifecycle_td.plugin_manager.updator, "update", mock_update
)
monkeypatch.setattr(core_lifecycle_td.plugin_manager.updator, "update", mock_update)
try:
# 插件安装
@@ -158,7 +158,9 @@ async def test_plugins(
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok", f"安装失败: {data.get('message', 'unknown error')}"
assert data["status"] == "ok", (
f"安装失败: {data.get('message', 'unknown error')}"
)
# 验证插件已注册
exists = any(md.name == test_plugin_name for md in star_registry)
@@ -493,3 +495,223 @@ async def test_neo_skills_routes(
data = await response.get_json()
assert data["status"] == "ok"
assert data["data"]["skill_key"] == "neo.demo"
@pytest.mark.asyncio
async def test_batch_upload_skills_returns_error_when_all_files_invalid(
app: Quart,
authenticated_header: dict,
):
test_client = app.test_client()
response = await test_client.post(
"/api/skills/batch-upload",
headers=authenticated_header,
files={
"files": FileStorage(
stream=io.BytesIO(b"not-a-zip"),
filename="invalid.txt",
content_type="text/plain",
),
},
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "error"
assert data["message"] == "Upload failed for all 1 file(s)."
@pytest.mark.asyncio
async def test_batch_upload_skills_accepts_zip_files(
app: Quart,
authenticated_header: dict,
monkeypatch,
):
async def _fake_sync_skills_to_active_sandboxes():
return
def _fake_install_skill_from_zip(
self,
zip_path: str,
*,
overwrite: bool = True,
):
_ = self, overwrite
assert zip_path.endswith(".zip")
return "demo_skill"
monkeypatch.setattr(
"astrbot.dashboard.routes.skills.sync_skills_to_active_sandboxes",
_fake_sync_skills_to_active_sandboxes,
)
monkeypatch.setattr(
"astrbot.dashboard.routes.skills.SkillManager.install_skill_from_zip",
_fake_install_skill_from_zip,
)
test_client = app.test_client()
response = await test_client.post(
"/api/skills/batch-upload",
headers=authenticated_header,
files={
"files": FileStorage(
stream=io.BytesIO(b"fake-zip"),
filename="demo_skill.zip",
content_type="application/zip",
),
},
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
assert data["message"] == "All 1 skill(s) uploaded successfully."
assert data["data"]["total"] == 1
assert data["data"]["succeeded"] == [
{"filename": "demo_skill.zip", "name": "demo_skill"}
]
assert data["data"]["failed"] == []
@pytest.mark.asyncio
async def test_batch_upload_skills_accepts_valid_skill_archive(
app: Quart,
authenticated_header: dict,
monkeypatch,
tmp_path,
):
data_dir = tmp_path / "data"
skills_dir = tmp_path / "skills"
temp_dir = tmp_path / "temp"
data_dir.mkdir()
skills_dir.mkdir()
temp_dir.mkdir()
async def _fake_sync_skills_to_active_sandboxes():
return
monkeypatch.setattr(
"astrbot.dashboard.routes.skills.sync_skills_to_active_sandboxes",
_fake_sync_skills_to_active_sandboxes,
)
monkeypatch.setattr(
"astrbot.core.skills.skill_manager.get_astrbot_data_path",
lambda: str(data_dir),
)
monkeypatch.setattr(
"astrbot.core.skills.skill_manager.get_astrbot_skills_path",
lambda: str(skills_dir),
)
monkeypatch.setattr(
"astrbot.core.skills.skill_manager.get_astrbot_temp_path",
lambda: str(temp_dir),
)
monkeypatch.setattr(
"astrbot.dashboard.routes.skills.get_astrbot_temp_path",
lambda: str(temp_dir),
)
archive = io.BytesIO()
with zipfile.ZipFile(archive, "w", zipfile.ZIP_DEFLATED) as zf:
zf.writestr(
"demo_skill/SKILL.md",
"---\nname: demo-skill\ndescription: Demo skill\n---\n",
)
zf.writestr("demo_skill/notes.txt", "hello")
zf.writestr("__MACOSX/demo_skill/._SKILL.md", "")
zf.writestr("__MACOSX/._demo_skill", "")
archive.seek(0)
test_client = app.test_client()
response = await test_client.post(
"/api/skills/batch-upload",
headers=authenticated_header,
files={
"files": FileStorage(
stream=archive,
filename="demo_skill.zip",
content_type="application/zip",
),
},
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
assert data["data"]["succeeded"] == [
{"filename": "demo_skill.zip", "name": "demo_skill"}
]
assert data["data"]["failed"] == []
assert (skills_dir / "demo_skill" / "SKILL.md").exists()
@pytest.mark.asyncio
async def test_batch_upload_skills_partial_success(
app: Quart,
authenticated_header: dict,
monkeypatch,
):
async def _fake_sync_skills_to_active_sandboxes():
return
def _fake_install_skill_from_zip(
self,
zip_path: str,
*,
overwrite: bool = True,
):
_ = self, overwrite
if "ok_skill" in zip_path:
return "ok_skill"
raise RuntimeError("install failed")
monkeypatch.setattr(
"astrbot.dashboard.routes.skills.sync_skills_to_active_sandboxes",
_fake_sync_skills_to_active_sandboxes,
)
monkeypatch.setattr(
"astrbot.dashboard.routes.skills.SkillManager.install_skill_from_zip",
_fake_install_skill_from_zip,
)
test_client = app.test_client()
boundary = "----AstrBotBatchBoundary"
body = (
(
f"--{boundary}\r\n"
'Content-Disposition: form-data; name="files"; filename="ok_skill.zip"\r\n'
"Content-Type: application/zip\r\n\r\n"
).encode()
+ b"fake-zip-1\r\n"
+ (
f"--{boundary}\r\n"
'Content-Disposition: form-data; name="files"; filename="bad_skill.zip"\r\n'
"Content-Type: application/zip\r\n\r\n"
).encode()
+ b"fake-zip-2\r\n"
+ f"--{boundary}--\r\n".encode()
)
headers = dict(authenticated_header)
headers["Content-Type"] = f"multipart/form-data; boundary={boundary}"
response = await test_client.post(
"/api/skills/batch-upload",
headers=headers,
data=body,
)
assert response.status_code == 200
data = await response.get_json()
assert data["status"] == "ok"
assert data["message"] == "Partial success: 1/2 skill(s) uploaded."
assert data["data"]["total"] == 2
assert data["data"]["succeeded"] == [
{"filename": "ok_skill.zip", "name": "ok_skill"}
]
assert data["data"]["failed"] == [
{"filename": "bad_skill.zip", "error": "install failed"}
]
+41
View File
@@ -0,0 +1,41 @@
from unittest.mock import AsyncMock
import pytest
from astrbot.core.utils.pip_installer import PipInstaller
@pytest.mark.asyncio
async def test_install_targets_site_packages_for_desktop_client(monkeypatch, tmp_path):
monkeypatch.setenv("ASTRBOT_DESKTOP_CLIENT", "1")
monkeypatch.delattr("sys.frozen", raising=False)
site_packages_path = tmp_path / "site-packages"
run_pip = AsyncMock(return_value=0)
prepend_sys_path_calls = []
ensure_preferred_calls = []
monkeypatch.setattr(PipInstaller, "_run_pip_in_process", run_pip)
monkeypatch.setattr(
"astrbot.core.utils.pip_installer.get_astrbot_site_packages_path",
lambda: str(site_packages_path),
)
monkeypatch.setattr(
"astrbot.core.utils.pip_installer._prepend_sys_path",
lambda path: prepend_sys_path_calls.append(path),
)
monkeypatch.setattr(
"astrbot.core.utils.pip_installer._ensure_plugin_dependencies_preferred",
lambda path, requirements: ensure_preferred_calls.append((path, requirements)),
)
installer = PipInstaller("")
await installer.install(package_name="demo-package")
run_pip.assert_awaited_once()
recorded_args = run_pip.await_args_list[0].args[0]
assert "--target" in recorded_args
assert str(site_packages_path) in recorded_args
assert prepend_sys_path_calls == [str(site_packages_path), str(site_packages_path)]
assert ensure_preferred_calls == [(str(site_packages_path), {"demo-package"})]
+26
View File
@@ -0,0 +1,26 @@
from astrbot.core.utils.astrbot_path import get_astrbot_root
from astrbot.core.utils.runtime_env import is_packaged_desktop_runtime
def test_desktop_client_env_marks_desktop_runtime_without_frozen(monkeypatch):
monkeypatch.setenv("ASTRBOT_DESKTOP_CLIENT", "1")
monkeypatch.delattr("sys.frozen", raising=False)
assert is_packaged_desktop_runtime() is True
def test_desktop_client_uses_home_root_without_explicit_astrbot_root(monkeypatch):
monkeypatch.setenv("ASTRBOT_DESKTOP_CLIENT", "1")
monkeypatch.delenv("ASTRBOT_ROOT", raising=False)
monkeypatch.delattr("sys.frozen", raising=False)
assert get_astrbot_root().endswith(".astrbot")
def test_explicit_astrbot_root_overrides_desktop_default(monkeypatch, tmp_path):
explicit_root = tmp_path / "astrbot-root"
monkeypatch.setenv("ASTRBOT_DESKTOP_CLIENT", "1")
monkeypatch.setenv("ASTRBOT_ROOT", str(explicit_root))
monkeypatch.delattr("sys.frozen", raising=False)
assert get_astrbot_root() == str(explicit_root.resolve())
+57 -4
View File
@@ -2,6 +2,8 @@ from __future__ import annotations
from pathlib import Path
import pytest
from astrbot.core.skills.skill_manager import SkillManager
@@ -56,7 +58,7 @@ def test_list_skills_merges_local_and_sandbox_cache(monkeypatch, tmp_path: Path)
assert by_name["custom-local"].description == "local description"
assert by_name["custom-local"].path == "skills/custom-local/SKILL.md"
assert by_name["python-sandbox"].description == "ship built-in"
assert by_name["python-sandbox"].path == "skills/python-sandbox/SKILL.md"
assert by_name["python-sandbox"].path == "/workspace/skills/python-sandbox/SKILL.md"
def test_sandbox_cached_skill_respects_active_and_display_path(
@@ -98,7 +100,58 @@ def test_sandbox_cached_skill_respects_active_and_display_path(
assert len(all_skills) == 1
assert all_skills[0].path == "/app/skills/browser-automation/SKILL.md"
mgr.set_skill_active("browser-automation", False)
active_skills = mgr.list_skills(runtime="sandbox", active_only=True)
assert active_skills == []
with pytest.raises(PermissionError):
mgr.set_skill_active("browser-automation", False)
active_skills = mgr.list_skills(runtime="sandbox", active_only=True)
assert len(active_skills) == 1
assert active_skills[0].name == "browser-automation"
def test_sandbox_and_local_path_resolution_with_show_sandbox_path_false(
monkeypatch,
tmp_path: Path,
):
data_dir = tmp_path / "data"
temp_dir = tmp_path / "temp"
skills_root = tmp_path / "skills"
data_dir.mkdir(parents=True, exist_ok=True)
temp_dir.mkdir(parents=True, exist_ok=True)
skills_root.mkdir(parents=True, exist_ok=True)
monkeypatch.setattr(
"astrbot.core.skills.skill_manager.get_astrbot_data_path",
lambda: str(data_dir),
)
monkeypatch.setattr(
"astrbot.core.skills.skill_manager.get_astrbot_temp_path",
lambda: str(temp_dir),
)
mgr = SkillManager(skills_root=str(skills_root))
_write_skill(skills_root, "custom-local", "local description")
mgr.set_sandbox_skills_cache(
[
{
"name": "custom-local",
"description": "cached description should be overridden",
"path": "/app/skills/custom-local/SKILL.md",
},
{
"name": "python-sandbox",
"description": "ship built-in",
"path": "/app/skills/python-sandbox/SKILL.md",
},
]
)
skills = mgr.list_skills(runtime="sandbox", show_sandbox_path=False)
by_name = {item.name: item for item in skills}
assert sorted(by_name) == ["custom-local", "python-sandbox"]
assert by_name["custom-local"].description == "local description"
local_skill_path = Path(by_name["custom-local"].path)
assert local_skill_path.is_relative_to(skills_root)
assert local_skill_path == skills_root / "custom-local" / "SKILL.md"
assert by_name["python-sandbox"].path == "/app/skills/python-sandbox/SKILL.md"
+59
View File
@@ -0,0 +1,59 @@
from unittest.mock import AsyncMock
import pytest
import astrbot.core.message.components as Comp
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.pipeline.respond.stage import RespondStage
from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import (
AiocqhttpMessageEvent,
)
def test_poke_to_dict_matches_onebot_v11_segment_format():
poke = Comp.Poke(type="126", id=2003)
assert poke.toDict() == {
"type": "poke",
"data": {"type": "126", "id": "2003"},
}
def test_poke_to_dict_keeps_legacy_qq_compatible():
poke = Comp.Poke(type="poke", qq=2916963017)
assert poke.toDict() == {
"type": "poke",
"data": {"type": "126", "id": "2916963017"},
}
@pytest.mark.asyncio
async def test_respond_stage_treats_poke_with_target_as_non_empty():
stage = RespondStage()
chain = [Comp.Poke(type="126", id=2003)]
assert await stage._is_empty_message_chain(chain) is False
@pytest.mark.asyncio
async def test_aiocqhttp_parse_json_outputs_standard_poke_data():
chain = MessageChain([Comp.Poke(type="126", id=2003)])
data = await AiocqhttpMessageEvent._parse_onebot_json(chain)
assert data == [{"type": "poke", "data": {"type": "126", "id": "2003"}}]
@pytest.mark.asyncio
async def test_aiocqhttp_send_message_dispatches_onebot_v11_poke_payload():
bot = AsyncMock()
chain = MessageChain([Comp.Poke(type="126", id=2003)])
await AiocqhttpMessageEvent.send_message(
bot=bot,
message_chain=chain,
event=None,
is_group=True,
session_id="123456",
)
bot.send_group_msg.assert_awaited_once_with(
group_id=123456,
message=[{"type": "poke", "data": {"type": "126", "id": "2003"}}],
)
+2 -2
View File
@@ -373,7 +373,7 @@ class TestAstrBotCoreLifecycleInitialize:
new_callable=AsyncMock,
),
):
await lifecycle.initialize(mcp_init_timeout=3.5)
await lifecycle.initialize()
# Verify database initialized
mock_db.initialize.assert_awaited_once()
@@ -388,7 +388,7 @@ class TestAstrBotCoreLifecycleInitialize:
mock_persona_mgr.initialize.assert_awaited_once()
# Verify provider manager initialized
mock_provider_manager.initialize.assert_awaited_once_with(init_timeout=3.5)
mock_provider_manager.initialize.assert_awaited_once()
# Verify platform manager initialized
mock_platform_manager.initialize.assert_awaited_once()
-43
View File
@@ -1,43 +0,0 @@
import pytest
from tests.fixtures.helpers import get_bound_tcp_port
class _DummySiteNoAttrs:
pass
class _DummySocket:
def __init__(self, port: int) -> None:
self._port = port
def getsockname(self):
return ("127.0.0.1", self._port)
class _DummyServer:
def __init__(self, port: int) -> None:
self.sockets = [_DummySocket(port)]
class _DummySiteWithName:
def __init__(self, port: int) -> None:
self.name = f"http://localhost:{port}"
class _DummySiteWithServer:
def __init__(self, port: int) -> None:
self._server = _DummyServer(port)
def test_get_bound_tcp_port_raises_on_unresolvable_site():
with pytest.raises(RuntimeError, match="Unable to resolve bound TCP port"):
get_bound_tcp_port(_DummySiteNoAttrs())
def test_get_bound_tcp_port_uses_name_port_when_available():
assert get_bound_tcp_port(_DummySiteWithName(8081)) == 8081
def test_get_bound_tcp_port_falls_back_to_server_sockets():
assert get_bound_tcp_port(_DummySiteWithServer(9092)) == 9092
-98
View File
@@ -1,98 +0,0 @@
import json
import pytest
from astrbot.core.provider import func_tool_manager
from astrbot.core.provider.func_tool_manager import FunctionToolManager
@pytest.fixture
def mcp_init_harness(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
):
manager = FunctionToolManager()
data_dir = tmp_path / "data"
data_dir.mkdir()
(data_dir / "mcp_server.json").write_text(
json.dumps({"mcpServers": {"demo": {"active": True}}}),
encoding="utf-8",
)
monkeypatch.setattr(
func_tool_manager,
"get_astrbot_data_path",
lambda: data_dir,
)
called = {}
async def fake_start_mcp_server(*, name, cfg, shutdown_event, timeout_seconds):
called[name] = {
"cfg": cfg,
"shutdown_event_type": type(shutdown_event).__name__,
"timeout_seconds": timeout_seconds,
}
monkeypatch.setattr(manager, "_start_mcp_server", fake_start_mcp_server)
return manager, called
def assert_demo_init_result(summary, called, *, timeout_seconds: float) -> None:
assert summary.total == 1
assert summary.success == 1
assert summary.failed == []
assert called["demo"]["cfg"] == {"active": True}
assert called["demo"]["shutdown_event_type"] == "Event"
assert called["demo"]["timeout_seconds"] == timeout_seconds
@pytest.mark.asyncio
async def test_init_mcp_clients_passes_timeout_seconds_keyword(mcp_init_harness):
manager, called = mcp_init_harness
summary = await manager.init_mcp_clients()
assert_demo_init_result(
summary,
called,
timeout_seconds=manager._init_timeout_default,
)
@pytest.mark.asyncio
async def test_init_mcp_clients_passes_overridden_init_timeout(
mcp_init_harness,
):
manager, called = mcp_init_harness
summary = await manager.init_mcp_clients(init_timeout=3.5)
assert_demo_init_result(summary, called, timeout_seconds=3.5)
@pytest.mark.asyncio
async def test_init_mcp_clients_reads_env_timeout_when_not_overridden(
mcp_init_harness,
monkeypatch: pytest.MonkeyPatch,
):
manager, called = mcp_init_harness
manager._init_timeout_default = 20.0 # ensure env override is observable
monkeypatch.setenv("ASTRBOT_MCP_INIT_TIMEOUT", "3.5")
summary = await manager.init_mcp_clients()
assert_demo_init_result(summary, called, timeout_seconds=3.5)
@pytest.mark.asyncio
async def test_init_mcp_clients_prefers_explicit_timeout_over_env(
mcp_init_harness,
monkeypatch: pytest.MonkeyPatch,
):
manager, called = mcp_init_harness
monkeypatch.setenv("ASTRBOT_MCP_INIT_TIMEOUT", "7.0")
summary = await manager.init_mcp_clients(init_timeout=3.5)
assert_demo_init_result(summary, called, timeout_seconds=3.5)
-71
View File
@@ -1,71 +0,0 @@
import pytest
from aiohttp import web
from astrbot.core.utils import io as io_module
from astrbot.core.utils.io import _stream_to_file, download_file
from tests.fixtures.helpers import get_bound_tcp_port
@pytest.mark.asyncio
async def test_download_file_downloads_content(tmp_path):
payload = b"astrbot-download-payload" * 256
async def handle(_request):
return web.Response(body=payload)
app = web.Application()
app.router.add_get("/file.bin", handle)
runner = web.AppRunner(app, access_log=None)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", 0)
await site.start()
try:
port = get_bound_tcp_port(site)
url = f"http://127.0.0.1:{port}/file.bin"
out = tmp_path / "downloaded.bin"
await download_file(url, str(out))
assert out.read_bytes() == payload
finally:
await runner.cleanup()
class _DummyStream:
def __init__(self, chunks: list[bytes]) -> None:
self._chunks = chunks
async def read(self, _size: int) -> bytes:
if not self._chunks:
return b""
return self._chunks.pop(0)
class _RecordingFile:
def __init__(self) -> None:
self.writes: list[bytes] = []
def write(self, data: bytes) -> int:
self.writes.append(data)
return len(data)
@pytest.mark.asyncio
async def test_stream_to_file_batches_multiple_chunks_per_write(monkeypatch):
monkeypatch.setattr(io_module, "_DOWNLOAD_READ_CHUNK_SIZE", 4)
monkeypatch.setattr(io_module, "_DOWNLOAD_FLUSH_THRESHOLD", 10)
stream = _DummyStream([b"aaaa", b"bbbb", b"cccc"])
file_obj = _RecordingFile()
await _stream_to_file(
stream,
file_obj,
total_size=12,
start_time=0.0,
show_progress=False,
)
assert len(file_obj.writes) == 1
assert file_obj.writes[0] == b"aaaabbbbcccc"
-178
View File
@@ -1,178 +0,0 @@
import base64
import os
from pathlib import Path
import pytest
from aiohttp import web
from astrbot.core.message import components as components_module
from astrbot.core.message.components import File, Image, Record
from tests.fixtures.helpers import get_bound_tcp_port
@pytest.mark.asyncio
async def test_image_convert_to_file_path_returns_absolute_path(tmp_path):
file_path = tmp_path / "img.bin"
file_path.write_bytes(b"img")
image = Image(file=str(file_path))
resolved = await image.convert_to_file_path()
assert resolved == os.path.abspath(str(file_path))
@pytest.mark.asyncio
async def test_record_convert_to_file_path_returns_absolute_path(tmp_path):
file_path = tmp_path / "record.bin"
file_path.write_bytes(b"record")
record = Record(file=str(file_path))
resolved = await record.convert_to_file_path()
assert resolved == os.path.abspath(str(file_path))
@pytest.mark.asyncio
async def test_file_component_get_file_returns_absolute_path(tmp_path):
file_path = tmp_path / "file.bin"
file_path.write_bytes(b"file")
file_component = File(name="file.bin", file=str(file_path))
resolved = await file_component.get_file()
assert resolved == os.path.abspath(str(file_path))
@pytest.mark.asyncio
async def test_image_convert_to_base64_raises_on_missing_file(tmp_path):
image = Image(file=str(tmp_path / "missing.bin"))
with pytest.raises(Exception, match="not a valid file"):
await image.convert_to_base64()
@pytest.mark.asyncio
async def test_record_convert_to_base64_raises_on_missing_file(tmp_path):
record = Record(file=str(tmp_path / "missing.bin"))
with pytest.raises(Exception, match="not a valid file"):
await record.convert_to_base64()
@pytest.mark.asyncio
async def test_image_convert_to_base64_reads_existing_local_file(tmp_path):
raw = b"image-bytes"
file_path = tmp_path / "exists_image.bin"
file_path.write_bytes(raw)
image = Image(file=str(file_path))
encoded = await image.convert_to_base64()
assert base64.b64decode(encoded) == raw
@pytest.mark.asyncio
async def test_record_convert_to_base64_reads_existing_local_file(tmp_path):
raw = b"record-bytes"
file_path = tmp_path / "exists_record.bin"
file_path.write_bytes(raw)
record = Record(file=str(file_path))
encoded = await record.convert_to_base64()
assert base64.b64decode(encoded) == raw
@pytest.mark.asyncio
async def test_image_convert_to_base64_maps_permission_error(monkeypatch):
async def _raise_permission_error(_path: str) -> str:
raise PermissionError("permission denied")
monkeypatch.setattr(components_module, "file_to_base64", _raise_permission_error)
image = Image(file="/tmp/forbidden-image")
with pytest.raises(Exception, match="not a valid file"):
await image.convert_to_base64()
@pytest.mark.asyncio
async def test_record_convert_to_base64_maps_permission_error(monkeypatch):
async def _raise_permission_error(_path: str) -> str:
raise PermissionError("permission denied")
monkeypatch.setattr(components_module, "file_to_base64", _raise_permission_error)
record = Record(file="/tmp/forbidden-record")
with pytest.raises(Exception, match="not a valid file"):
await record.convert_to_base64()
@pytest.mark.asyncio
async def test_image_convert_to_file_path_from_base64_creates_absolute_file():
payload = b"image-base64-payload"
image = Image(file=f"base64://{base64.b64encode(payload).decode()}")
resolved = await image.convert_to_file_path()
resolved_path = Path(resolved)
assert resolved_path.is_absolute()
assert resolved_path.exists()
assert resolved_path.read_bytes() == payload
@pytest.mark.asyncio
async def test_record_convert_to_file_path_from_base64_creates_absolute_file():
payload = b"record-base64-payload"
record = Record(file=f"base64://{base64.b64encode(payload).decode()}")
resolved = await record.convert_to_file_path()
resolved_path = Path(resolved)
assert resolved_path.is_absolute()
assert resolved_path.exists()
assert resolved_path.read_bytes() == payload
async def _serve_payload(payload: bytes, route: str):
async def handle(_request):
return web.Response(body=payload)
app = web.Application()
app.router.add_get(route, handle)
runner = web.AppRunner(app, access_log=None)
await runner.setup()
site = web.TCPSite(runner, "127.0.0.1", 0)
await site.start()
return runner, get_bound_tcp_port(site)
@pytest.mark.asyncio
async def test_image_convert_to_file_path_from_http_creates_absolute_file():
payload = b"image-http-payload"
runner, port = await _serve_payload(payload, "/img.bin")
try:
image = Image(file=f"http://127.0.0.1:{port}/img.bin")
resolved = await image.convert_to_file_path()
resolved_path = Path(resolved)
assert resolved_path.is_absolute()
assert resolved_path.exists()
assert resolved_path.read_bytes() == payload
finally:
await runner.cleanup()
@pytest.mark.asyncio
async def test_record_convert_to_file_path_from_http_creates_absolute_file():
payload = b"record-http-payload"
runner, port = await _serve_payload(payload, "/record.bin")
try:
record = Record(file=f"http://127.0.0.1:{port}/record.bin")
resolved = await record.convert_to_file_path()
resolved_path = Path(resolved)
assert resolved_path.is_absolute()
assert resolved_path.exists()
assert resolved_path.read_bytes() == payload
finally:
await runner.cleanup()
+545
View File
@@ -0,0 +1,545 @@
"""Tests for SessionLockManager with multi-event-loop isolation."""
import asyncio
import threading
import time
import weakref
from concurrent.futures import ThreadPoolExecutor
import pytest
from astrbot.core.utils.session_lock import SessionLockManager
class TestSessionLockManagerBasic:
"""Basic functionality tests."""
def test_init(self):
"""Test manager initialization."""
manager = SessionLockManager()
assert manager._state_guard is not None
assert manager._loop_managers is not None
@pytest.mark.asyncio
async def test_acquire_release_lock(self):
"""Test basic lock acquire and release."""
manager = SessionLockManager()
session_id = "test-session"
async with manager.acquire_lock(session_id):
# Lock acquired successfully
pass
# Lock should be released and cleaned up
state = manager._get_loop_manager()
assert session_id not in state._locks
assert session_id not in state._lock_count
@pytest.mark.asyncio
async def test_lock_is_reusable(self):
"""Test that locks can be acquired multiple times."""
manager = SessionLockManager()
session_id = "test-session"
async with manager.acquire_lock(session_id):
pass
async with manager.acquire_lock(session_id):
pass
# Both acquisitions should succeed
class TestCrossLoopIsolation:
"""Tests for event loop isolation."""
@pytest.mark.asyncio
async def test_different_loops_have_different_managers(self):
"""Test that different event loops get different per-loop managers."""
manager = SessionLockManager()
# Get manager for current loop
manager1 = manager._get_loop_manager()
# Run in a different event loop
def run_in_new_loop():
new_loop = asyncio.new_event_loop()
try:
asyncio.set_event_loop(new_loop)
async def get_manager():
return manager._get_loop_manager()
return new_loop.run_until_complete(get_manager())
finally:
new_loop.close()
asyncio.set_event_loop(None)
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(run_in_new_loop)
manager2 = future.result()
# Should be different manager instances
assert manager1 is not manager2
@pytest.mark.asyncio
async def test_locks_isolated_across_loops(self):
"""Test that locks from different loops are isolated."""
manager = SessionLockManager()
session_id = "shared-session"
results = []
async def acquire_in_loop(loop_id: int):
"""Acquire lock in a new event loop."""
async with manager.acquire_lock(session_id):
results.append(f"loop-{loop_id}-acquired")
await asyncio.sleep(0.05)
results.append(f"loop-{loop_id}-released")
def run_in_thread(loop_id: int):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(acquire_in_loop(loop_id))
finally:
loop.close()
asyncio.set_event_loop(None)
# Run two loops concurrently - they should NOT block each other
# because locks are isolated per-loop
with ThreadPoolExecutor(max_workers=2) as executor:
futures = [executor.submit(run_in_thread, i) for i in range(2)]
for f in futures:
f.result()
# Both loops should acquire immediately (no blocking between loops)
# Order should show interleaved acquisitions, not sequential
assert len(results) == 4
@pytest.mark.asyncio
async def test_same_loop_blocks_on_same_session(self):
"""Test that same loop blocks when acquiring same session lock."""
manager = SessionLockManager()
session_id = "test-session"
execution_order = []
async def task1():
async with manager.acquire_lock(session_id):
execution_order.append("task1-start")
await asyncio.sleep(0.1)
execution_order.append("task1-end")
async def task2():
await asyncio.sleep(0.01) # Let task1 start first
async with manager.acquire_lock(session_id):
execution_order.append("task2-start")
execution_order.append("task2-end")
await asyncio.gather(task1(), task2())
# task2 should wait for task1 to finish
assert execution_order.index("task1-start") < execution_order.index("task1-end")
assert execution_order.index("task1-end") < execution_order.index("task2-start")
class TestConcurrency:
"""Tests for concurrent access."""
@pytest.mark.asyncio
async def test_concurrent_acquisitions_same_loop(self):
"""Test concurrent lock acquisitions on the same loop."""
manager = SessionLockManager()
session_id = "concurrent-session"
acquired_count = 0
max_concurrent = 0
lock = asyncio.Lock()
async def acquire_and_check():
nonlocal acquired_count, max_concurrent
async with manager.acquire_lock(session_id):
async with lock:
acquired_count += 1
max_concurrent = max(max_concurrent, acquired_count)
await asyncio.sleep(0.01)
async with lock:
acquired_count -= 1
# Run multiple concurrent tasks
tasks = [acquire_and_check() for _ in range(5)]
await asyncio.gather(*tasks)
# Max concurrent should be 1 (lock serializes access)
assert max_concurrent == 1
@pytest.mark.asyncio
async def test_thread_safety_of_loop_manager_creation(self):
"""Test that _get_loop_manager is thread-safe."""
manager = SessionLockManager()
managers = []
errors = []
def create_loop_and_get_manager():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
async def get_mgr():
return manager._get_loop_manager()
mgr = loop.run_until_complete(get_mgr())
managers.append(mgr)
except Exception as e:
errors.append(e)
finally:
loop.close()
asyncio.set_event_loop(None)
threads = [threading.Thread(target=create_loop_and_get_manager) for _ in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(errors) == 0
# All managers should be valid
for m in managers:
assert hasattr(m, "_locks")
assert hasattr(m, "_access_lock")
class TestEventLoopCleanup:
"""Tests for event loop cleanup behavior."""
@pytest.mark.asyncio
async def test_weakref_cleanup_on_loop_close(self):
"""Test that per-loop managers are cleaned up when loop is closed."""
manager = SessionLockManager()
loop_ref: weakref.ref[asyncio.AbstractEventLoop] | None = None
def run_in_new_loop():
nonlocal loop_ref
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop_ref = weakref.ref(loop)
async def use_lock():
async with manager.acquire_lock("test-session"):
pass
return manager._get_loop_manager()
try:
per_loop_mgr = loop.run_until_complete(use_lock())
# Keep a weak ref to the per-loop manager
return weakref.ref(per_loop_mgr)
finally:
loop.close()
asyncio.set_event_loop(None)
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(run_in_new_loop)
per_loop_mgr_ref = future.result()
# Give time for weakref cleanup
import gc
gc.collect()
# The per-loop manager should be cleaned up when the loop is closed
# because WeakKeyDictionary removes entries when the key (loop) is gone
per_loop_mgr = per_loop_mgr_ref()
loop = loop_ref() if loop_ref is not None else None
assert per_loop_mgr is None or loop is None
@pytest.mark.asyncio
async def test_access_after_loop_close_in_new_loop_works(self):
"""Test that accessing from a new loop after old loop closes works."""
manager = SessionLockManager()
# Use lock in current loop
async with manager.acquire_lock("session-1"):
pass
# Simulate old loop being closed and new loop being created
def run_in_new_loop():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
async def use_lock():
# Should work without issues in new loop
async with manager.acquire_lock("session-2"):
return "success"
return loop.run_until_complete(use_lock())
finally:
loop.close()
asyncio.set_event_loop(None)
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(run_in_new_loop)
result = future.result()
assert result == "success"
class TestIssue5464:
"""Tests for issue #5464: Multiple OneBot instances with different event loops.
Issue: Running multiple OneBot adapter instances causes
"is bound to a different event loop" error.
"""
@pytest.mark.asyncio
async def test_multiple_event_loops_no_cross_loop_error(self):
"""Test that multiple event loops don't cause cross-loop binding errors.
This simulates the scenario where multiple OneBot instances
(each potentially running in different event loops) access the
same SessionLockManager concurrently.
"""
from astrbot.core.utils.session_lock import session_lock_manager
errors: list[Exception] = []
results: list[str] = []
def simulate_onebot_instance(instance_id: int, session_ids: list[str]):
"""Simulate a OneBot instance running in its own event loop."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
async def process_messages():
for session_id in session_ids:
try:
async with session_lock_manager.acquire_lock(session_id):
# Simulate message processing
await asyncio.sleep(0.01)
results.append(f"instance-{instance_id}-{session_id}")
except Exception as e:
errors.append(e)
loop.run_until_complete(process_messages())
finally:
loop.close()
asyncio.set_event_loop(None)
# Simulate 4 OneBot instances (as in the issue report)
# Each handles multiple sessions concurrently
threads = []
for i in range(4):
sessions = [f"session-{i}-1", f"session-{i}-2", f"session-{i}-3"]
t = threading.Thread(target=simulate_onebot_instance, args=(i, sessions))
threads.append(t)
for t in threads:
t.start()
for t in threads:
t.join()
# Should have no errors (especially no "bound to a different event loop")
assert len(errors) == 0, f"Errors occurred: {errors}"
assert len(results) == 12 # 4 instances * 3 sessions each
@pytest.mark.asyncio
async def test_lock_object_not_shared_across_loops(self):
"""Verify that asyncio.Lock objects are not shared across event loops.
The root cause of issue #5464 was that Lock objects created in one
event loop were being used in another, causing the error.
"""
manager = SessionLockManager()
session_id = "shared-session-id"
lock_ids: set[int] = set()
lock_id_lock = threading.Lock()
def get_lock_in_new_loop():
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
async def acquire_and_capture():
# Get the per-loop manager
per_loop_mgr = manager._get_loop_manager()
# Capture the lock object id before acquiring
async with per_loop_mgr._access_lock:
lock = per_loop_mgr._locks[session_id]
with lock_id_lock:
lock_ids.add(id(lock))
async with manager.acquire_lock(session_id):
await asyncio.sleep(0.01)
loop.run_until_complete(acquire_and_capture())
finally:
loop.close()
asyncio.set_event_loop(None)
# Run multiple loops concurrently
threads = [threading.Thread(target=get_lock_in_new_loop) for _ in range(5)]
for t in threads:
t.start()
for t in threads:
t.join()
# Each loop should have its own Lock object
# If locks were shared, we'd only have 1 lock_id
assert len(lock_ids) == 5, "Each event loop should have its own Lock object"
@pytest.mark.asyncio
async def test_concurrent_access_same_session_different_loops(self):
"""Test that same session ID accessed from different loops doesn't block.
This verifies the fix: locks are isolated per event loop,
so different loops can acquire the "same" session lock concurrently.
"""
from astrbot.core.utils.session_lock import session_lock_manager
session_id = "global-session"
acquisition_times: list[float] = []
time_lock = threading.Lock()
def acquire_lock_in_loop(loop_id: int):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
async def acquire():
import time
start = time.time()
async with session_lock_manager.acquire_lock(session_id):
with time_lock:
acquisition_times.append(start)
await asyncio.sleep(0.1) # Hold the lock
loop.run_until_complete(acquire())
finally:
loop.close()
asyncio.set_event_loop(None)
# Start 3 threads nearly simultaneously
threads = [threading.Thread(target=acquire_lock_in_loop, args=(i,)) for i in range(3)]
start_time = time.time()
for t in threads:
t.start()
for t in threads:
t.join()
total_time = time.time() - start_time
# If locks were NOT isolated, we'd need ~0.3s (3 * 0.1s serial)
# With isolation, all should complete in ~0.1s (parallel)
# Allow some overhead, but should be much less than 0.3s
assert total_time < 0.25, (
f"Locks should be isolated per loop, but took {total_time:.2f}s"
)
class TestEdgeCases:
"""Tests for edge cases."""
@pytest.mark.asyncio
async def test_empty_session_id(self):
"""Test with empty session ID."""
manager = SessionLockManager()
async with manager.acquire_lock(""):
pass
# Should work without issues
@pytest.mark.asyncio
async def test_special_characters_in_session_id(self):
"""Test with special characters in session ID."""
manager = SessionLockManager()
session_id = "session-with-special-chars!@#$%^&*()"
async with manager.acquire_lock(session_id):
pass
# Should work without issues
@pytest.mark.asyncio
async def test_very_long_session_id(self):
"""Test with very long session ID."""
manager = SessionLockManager()
session_id = "a" * 10000
async with manager.acquire_lock(session_id):
pass
# Should work without issues
@pytest.mark.asyncio
async def test_lock_not_held_after_context_exit(self):
"""Test that lock is released after context manager exit."""
manager = SessionLockManager()
session_id = "test-session"
async with manager.acquire_lock(session_id):
state = manager._get_loop_manager()
# Lock should exist and have count 1
assert session_id in state._locks
assert state._lock_count[session_id] == 1
# After exit, lock should be cleaned up
state = manager._get_loop_manager()
assert session_id not in state._locks
assert session_id not in state._lock_count
@pytest.mark.asyncio
async def test_exception_during_lock(self):
"""Test that lock is released even if exception occurs."""
manager = SessionLockManager()
session_id = "test-session"
with pytest.raises(ValueError):
async with manager.acquire_lock(session_id):
raise ValueError("test error")
# Lock should still be released
state = manager._get_loop_manager()
assert session_id not in state._locks
assert session_id not in state._lock_count
@pytest.mark.asyncio
async def test_nested_lock_different_sessions(self):
"""Test nested locks on different sessions."""
manager = SessionLockManager()
async with manager.acquire_lock("session-1"):
async with manager.acquire_lock("session-2"):
state = manager._get_loop_manager()
assert "session-1" in state._locks
assert "session-2" in state._locks
assert state._lock_count["session-1"] == 1
assert state._lock_count["session-2"] == 1
state = manager._get_loop_manager()
assert "session-1" not in state._locks
assert "session-2" not in state._locks
@pytest.mark.asyncio
async def test_reentrant_lock_same_session(self):
"""Test reentrant locking on same session (should block)."""
manager = SessionLockManager()
session_id = "test-session"
order = []
async def outer():
async with manager.acquire_lock(session_id):
order.append("outer-acquired")
await asyncio.sleep(0.1)
order.append("outer-done")
async def inner():
await asyncio.sleep(0.01) # Let outer acquire first
order.append("inner-attempt")
async with manager.acquire_lock(session_id):
order.append("inner-acquired")
order.append("inner-done")
await asyncio.gather(outer(), inner())
# Inner should wait for outer to complete
assert order.index("outer-acquired") < order.index("outer-done")
assert order.index("outer-done") < order.index("inner-acquired")