Compare commits

...

119 Commits

Author SHA1 Message Date
Soulter e7e97730af chore: bump version to 4.9.0 2025-12-13 18:49:07 +08:00
Soulter 467ca1eb5c fix: webui log output incompletely (#4029)
* fix: webui log output incompletely

* fix: improve SSE log parsing to handle partial data chunks

* fix: enhance log handling by implementing local cache and fetching history

* fix: log time handling to use epoch time
2025-12-13 18:46:16 +08:00
RC-CHN 46528391c2 feat: add pre-chunk import strategy for knowledge base (#3973)
* feat: 添加文档导入功能及相关测试

* feat: 优化文档上传功能,支持从文件名推断文件类型,并增强文档切片验证

* feat: 添加文档导入功能的无效输入测试,验证 chunks 类型和内容的错误处理

* refactor: 重构文档上传和导入任务的状态管理,添加任务初始化、结果设置和进度更新方法
2025-12-12 23:15:11 +08:00
Soulter 8a0b7717cc feat: supports webhook mode for Lark platform (#4016)
* feat: add Lark platform support with unified webhook configuration

* fix: update token verification logic in LarkWebhookServer

* feat: implement event deduplication and cleanup for Lark webhook events
2025-12-12 22:12:13 +08:00
Copilot 3b81fb4985 fix: mobile dialog close button visibility (#4010)
* Initial plan

* Fix mobile dialog close button visibility by adding max-height and scrollable content

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-12-12 16:02:24 +08:00
Soulter c09d57a820 refactor: improve UI layout and interaction for list item management (#4002)
* refactor: improve UI layout and interaction for list item management

* feat: enhance list configuration UI with batch import functionality

* feat: add internationalization support for list configuration UI
2025-12-11 18:55:56 +08:00
Soulter ec408a2aff fix: lark message timestamp 2025-12-11 18:20:50 +08:00
Soulter 417179a6b9 ci: add smoke test 2025-12-11 10:44:15 +08:00
Soulter fcd29445c7 refactor: remove unused current provider initialization in StarRequestSubStage 2025-12-11 10:36:33 +08:00
BiDuang 5f535001db fix: incorrect modalities enum of gemini api provider (#3993) 2025-12-10 20:27:51 +08:00
PaloMiku 750d245b16 docs: Update README with new Zread link and badges (#3992)
ZRead 是由智谱 AI 推出的 DeepWiki 类似平替品。
2025-12-10 20:22:56 +08:00
Dt8333 f624971613 chore: fix bunches of type checking errors (#3213)
* chore(core.utils): 🚨 修正错误Lint

* chore(core.provider): 🚨 修复基类错误Lint

* chore(core.utils): 补全session_get()的重载

* chore(core.provider): 🚨 修正实现错误Lint

* chore(core.platform): 🚨 修正platform基类和webchat的错误Lint

* chore(core.platform): 修正错误实现Lint

* fix(core.provider): 修复循环调用和错误assert

* chore(core.platform): 修复部分实现Lint

* chore(core.provider): 补充Dify.text_chat_stream的参数类型

* chore(core.pipeline): 🚨 修复错误Lint

* fix(core.slack): 补充遗漏导入

* chore(core.utils): 修复错误的session_get声明

* chore(core.platform): 移除Lark adapter import中的wildcard

* chore(core.db): 修复声明和部分逻辑

* chore(core.db): 添加typings,使faiss参数能被正确识别。

* chore(core): 修复声明

* chore(core): 修改声明

* chore: 补充faiss声明

* chore(dashboard): 修改实现,减少报错

* chore(package): 修改部分声明与实现,减少报错

* chore(core): 添加Handler的overload,以去除部分assert同时通过类型检查

* chore(core.pipeline): 修改Pipeline Scheduler的execute,将判断属性改为判断类型,通过静态类型检查

* chore(core.config): 添加类型标注,通过类型检查

* chore(core.message): 为File._download_file添加检查,通过类型检查

* fix: 将断言改为条件判断以实现优雅关闭的容错性

* refactor: 移除 discord 客户端中的 assert,改用 if None 判断并抛出异常

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: DiscordPlatformAdapter 对 self.client.user 为 None 做日志并返回,移除断言

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 增强 Lark 相关空值/异常检查并完善日志输出

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 将断言替换为条件检查并加入日志与错误处理

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* chore: 移除LLM生成的无用注释

* refactor: 使用 File.get_file 替换下载逻辑并移除 assert,提供默认 filename

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: Slack Socket 未初始化抛出运行时异常,图片 URL 判空改为非空判断

* refactor: 将 WeChatPadProAdapter 的断言改为空值判断并添加日志

* refactor: 使用 isinstance 替代断言实现类型判断,便于静态检查

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 去除cast,直接使用字段与字典访问,修正端口解析

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 使用 match-case 重构 ProviderManager 加载并通过类型检查抛出 TypeError

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: group_name_display 时若 group 对象为空则记录错误并返回

* fix: 将 _get_current_persona_id 的 assert 替换成 if guard 并返回 None

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 优化插件目录存在性检查及图片URL非空验证,更新JSON排序配置

* fix: 将 datetime_str 的 assert 替换为显式检查并抛出异常

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除 cast,改为运行时检查并在找不到调度器时跳过

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除 cast,改用 isinstance 检查 FaissVecDB 并警告

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 删除 typing.cast 导入,并在获取文件绝对路径前校验 file_

* refactor: 移除 typing.cast,简化内容安全检查调用

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 将 PlatformMetadata.id 设为必填并在注册时传入 id,移除 cast

* refactor: 移除 cast,改用 HasInitialize 与 isinstance 进行初始化

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 为 ProviderManager.initialize 增加ID类型判断,避免 None 导致 get 失败

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 为 OTTSProvider 与 AzureNativeProvider 引入 _client 与 client 属性改进上下文管理

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 为 Whisper 自托管源添加模型未初始化校验并直接调用 transcribe

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除未使用的 cast 导入并简化 platform_name 赋值

* refactor: 引入 cast 并对 id 使用 cast(str, ...) 提升类型安全

* fix: 将 _id_to_sid 返回改为 str,空值返回空串;对 id 与 message_id 使用 cast

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 重构 Discord 处理逻辑:强制 类型转换、优先斜杠指令并优化提及判断

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 统一对 id 获取执行 cast,并在微信消息解析失败时抛错

* Revert "fix: 去除cast,直接使用字段与字典访问,修正端口解析"

This reverts commit 1cbfdf9d1b.

* fix: 百炼 Rerank 会话关闭时返回空结果;初始化 request.prompt 避免空值拼接

* fix: 统一处理搜索结果链接为字符串,新增 _get_url 助手并适配 Bing/Sogo

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 调整 call_handler 泛型、Discord 通道注解及 FishAudioTTS API 请求类型

* refactor: 使用 col(...) 替代列引用并对结果进行 CursorResult 强转

* chore: ruff format

---------

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
Co-authored-by: Soulter <905617992@qq.com>
2025-12-09 14:13:47 +08:00
Soulter aa6d07afcc refactor: move all internal commands from astrbot plugin to default_command plugin (#3960)
* refactor: move all internal commands from astrbot plugin to default_command plugin

* ruff check

* feat: add config

* ruff check
2025-12-08 22:17:32 +08:00
Soulter 2c36649874 feat: add Agent Runner test prompt dialog in ProviderPage (#3968) 2025-12-08 21:46:47 +08:00
Soulter c95735dcc0 docs: update readme 2025-12-08 12:05:57 +08:00
Soulter 03bb278f50 chore: ruff check 2025-12-08 11:00:43 +08:00
Soulter a5e0974da3 chore: ruff format 2025-12-08 00:36:56 +08:00
vmoranv f0fb447fbc feat: custom plugin api source manager (#3956)
* feat: custom plugin api source manager

* fix: rename plugin source file in a safer way

* chore: turned the way of saving plugin source to backend and refacted some components

* style: clean up whitespace and improve logging message formatting

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-12-08 00:32:50 +08:00
Soulter 37566182b0 feat: segment reply supports segmentation words (#3959)
* feat: segment reply supports segmentation words

* chore: ruff format

* feat: enhance segmented reply processing by refining word extraction logic

* ruff format
2025-12-08 00:27:17 +08:00
Soulter e460b411da chore: remove dev version from webui (#3951)
* chore: remove dev version

* chore: remove development version references from header localization files
2025-12-07 15:23:30 +08:00
Soulter e14ed804da chore: bump version to 4.8.0 2025-12-05 19:09:56 +08:00
Oscar Shaw 8e4e49df20 fix: not invoke on_llm_response hook when LLM request has error (#3871)
* fix: handle on_agent_done in error responses

- Introduced an LLMResponse for error messages to be processed by agent hooks, ensuring better error reporting and handling.

* fix: improve error logging in on_agent_done hook

---------

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-12-05 16:13:46 +08:00
Oscar Shaw 5d856900ef perf: some UI/UX fixes, change Console to Platform Logs (#3873)
* refactor: 统一‘平台日志’文案

* perf: 优化自动滚动开关键操作逻辑

* perf: add tooltips to save and code editor buttons
2025-12-05 16:02:20 +08:00
Soulter 380a68b96c chore: add CONTRIBUTING.md 2025-12-05 15:59:18 +08:00
易推倒白毛 8879bd7e9d fix: add supports for Whisper with QQ amr audio file
* fix: Whisper API对QQ语音amr文件的支持

* Update whisper_api_source.py

* fix: cleanup temporary files in Whisper API

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-12-05 15:41:37 +08:00
RC-CHN 2cce09400f feat: add Kubernetes manifests for astrbot and napcat deployment with services and persistent storage (#3901)
* feat: add Kubernetes manifests for astrbot and napcat deployment with services and persistent storage

* chore: remove 11451 port

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-12-04 20:36:35 +08:00
Oscar Shaw 54d26dcd38 perf: integrate Pinia store for log cache management (#3852)
* perf: integrate Pinia store for log cache management

* perf: remove unused code
2025-12-04 14:26:05 +08:00
Soulter 205024f27a fix: correct SQL query syntax in SQLiteDatabase class 2025-12-04 12:51:22 +08:00
Soulter efde994907 chore: revise badges and language links
Updated badge links and language options in README.
2025-12-03 17:21:09 +08:00
Soulter 8ca4f9cb74 feat: update README files for multilingual support and enhanced descriptions
- Added French, Russian, and Traditional Chinese README files to support a wider audience.
- Updated English and Japanese README files with improved descriptions of AstrBot's capabilities and features.
- Enhanced community section in all README files to include QQ, Telegram, and Discord group information.
- Adjusted plugin marketplace badge and key features list for clarity and consistency across languages.
2025-12-03 17:01:56 +08:00
Soulter 54e49b997b feat: enhance platform management with status tracking and error handling
- Introduced PlatformStatus enum to manage platform states (pending, running, error, stopped).
- Added error recording and retrieval functionality in the Platform class.
- Implemented a new method in PlatformManager to gather statistics for all platforms.
- Updated the dashboard to display platform statuses and error details, including a dialog for error insights.
- Enhanced localization for runtime statuses and error dialogs in both English and Chinese.
2025-12-03 16:48:57 +08:00
Soulter 5714944eef feat: unified platform webhook url (#3889)
* feat: unified platform webhook url

* chore: ruff format

* fix: 修复 Telegram 语音使用 Whisper API 报错 (#3884)

* Update whisper_api_source.py

* chore: ruff format

---------

Co-authored-by: Soulter <905617992@qq.com>

* Update astrbot/dashboard/routes/platform.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* chore: ruff format

* fix: update webhook dialog descriptions for clarity in English and Chinese locales

* fix: update webhook URL paths to include '/api' prefix for consistency across the application

---------

Co-authored-by: 易推倒白毛 <zhaixingbi@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-12-03 15:44:52 +08:00
Soulter defc46b6c9 fix: remove unnecessary blocks in Slack reply message (#3897) 2025-12-03 13:59:41 +08:00
Soulter 4d819546b0 fix: handle message sending in QQOfficialMessageEvent class (#3894)
- Added a fallback to the `_post_send` method without parameters when the stream payload is not set, ensuring proper message handling in all scenarios.

fixes: #3893
2025-12-03 13:15:12 +08:00
易推倒白毛 8006981976 fix: 修复 Telegram 语音使用 Whisper API 报错 (#3884)
* Update whisper_api_source.py

* chore: ruff format

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-12-03 02:50:50 +08:00
Soulter f7a716af43 refactor: message storage format of webchat, support reply and file message segment (#3845)
* refactor: message storage format of webchat

* refactor: update image and record handling in webchat event processing

* fix: thinking placeholder in webchat

* feat: supports file upload in webchat

* feat: supports to delete attachments when webchat session is deleted

* perf: improve performance of file downloading

* refactor: remove unused import in chat route

* feat: add message timestamp formatting and localization support in chat

* fix: handle missing filename in file upload for chat route

* feat: enhance file handling in chat and webchat, supporting video uploads and improved attachment management

* fix: update property name for embedded files in message handling

* fix: compute variable errors after uninstalling plugins

* feat: supported for reply message and standarlize the message param

* fix: ensure message actions are displayed for the last message in the list
2025-12-02 17:11:08 +08:00
Copilot a708901e7f fix: fix dark mode white background in conversation preview dialog (#3881)
* Initial plan

* Fix dark mode background issue in conversation data preview

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

* style: update conversation messages container background color and add debug log for dark mode detection

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
Co-authored-by: Soulter <905617992@qq.com>
2025-12-02 17:03:59 +08:00
Soulter e9be8cf69f chore: bump version to 4.7.4 2025-12-01 18:42:07 +08:00
Soulter 31d53edb9d refactor: standardize provider test method implementation
- Updated the `test` method in all provider classes to remove return values and raise exceptions for failure cases, enhancing clarity and consistency.
- Adjusted related logic in the dashboard and command routes to align with the new `test` method behavior, simplifying error handling.
2025-12-01 18:37:08 +08:00
Soulter 2ba0460f19 feat: introduce file extract capability (#3870)
* feat: introduce file extract capability

powered by MoonshotAI

* fix: correct indentation in default configuration file

* fix: add error handling for file extract application in InternalAgentSubStage

* fix: update file name handling in InternalAgentSubStage to correctly associate file names with extracted content

* feat: add condition settings for local agent runner in default configuration

* fix: enhance file naming logic in File component and update prompt handling in InternalAgentSubStage
2025-12-01 18:12:39 +08:00
雪語 0e034f0fbd fix: aiocqhttp 适配器 NapCat 文件名获取为空 (#3853)
* aiocqhttp 适配器 NapCat 文件名获取为空

修复使用 NapCat 时,文件消息的 File.name 为空的问题。原代码硬编码 name="",导致下游插件无法获取文件名和扩展名

* Enhance file name retrieval from message data

Updated file name extraction logic to check multiple fields for better accuracy.
2025-12-01 13:36:19 +08:00
Soulter 2a7d03f9e1 fix: fit language and log AI responses more clearly (#3864)
* fix: fit language and log AI responses more clearly

* chore: ruff format
2025-12-01 13:24:52 +08:00
Soulter 72fac4b9f1 feat: implement unified provider availability testing across components (#3865)
- Added a `test` method to each provider class to standardize availability checks.
- Updated the dashboard and command routes to utilize the new `test` method for provider reachability verification, simplifying the logic and improving maintainability.
- Removed redundant reachability check logic from the command handler.
2025-12-01 13:17:20 +08:00
Soulter 38281ba2cf refactor: restore reachability check configuration in default settings and localization files 2025-12-01 00:38:30 +08:00
Soulter 21aa3174f4 fix: disable reachability check in default configuration 2025-12-01 00:16:11 +08:00
邹永赫 dcda871fc0 feat: provider availability reachability improvements (#3708) 2025-12-01 01:06:10 +09:00
Soulter c13c51f499 fix: assistant message validation error when tool_call exists but content not exists (#3862)
* fix: assistant message validation error when tool_call exists but content not exists

* fix: enhance content validation in Message model to allow None for assistant role with tool_calls
2025-11-30 23:42:37 +08:00
Dt8333 a130db5cf4 fix: 将 Graceful shutdown 的异常改为 KeyboardInterrupt (#3855) 2025-11-30 20:31:17 +08:00
邹永赫 7faeb5cea8 Merge pull request #3850 from zouyonghe/feature/plugin-upgrade-all
增加升级所有插件按钮
2025-11-30 15:12:36 +09:00
ZouYonghe 8d3ff61e0d Format plugin route with ruff 2025-11-30 11:56:24 +08:00
ZouYonghe 4c03e82570 Fix plugin update JSON parsing and concurrency handling 2025-11-30 11:50:46 +08:00
ZouYonghe e7e8664ab4 chore: tweak update all label 2025-11-30 11:18:30 +08:00
ZouYonghe 1dd1623e7d feat: batch update plugins via new api 2025-11-30 11:11:36 +08:00
ZouYonghe 80d8161d58 feat: add update all plugins action 2025-11-30 10:40:46 +08:00
Soulter fc80d7d681 chore: bump version to 4.7.3 2025-11-30 00:42:49 +08:00
Soulter c2f036b27c chore: bump vertion to 4.7.2 2025-11-30 00:33:07 +08:00
Soulter 4087bbb512 perf: set content attribute optional to AssistantMessageSegment for enhanced message handling
fixes: #3843
2025-11-30 00:32:00 +08:00
Soulter e1c728582d chore: bump version to 4.7.2 2025-11-30 00:18:23 +08:00
Oscar Shaw 93c69a639a feat: 新增群聊模式下的专用图片转述模型配置 (#3822)
* feat: add image caption provider configuration for group chat

- Introduced `image_caption_provider_id` to allow separate configuration for group chat image understanding.
- Updated metadata and hints in English and Chinese for clarity on new settings.
- Adjusted logic in long term memory to utilize the new provider ID for image captioning.

* fix: format

* Fix logic for image caption and active reply settings

* Fix indentation and formatting in long_term_memory.py

* chore: ruff format

---------

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
Co-authored-by: Soulter <905617992@qq.com>
2025-11-29 23:53:32 +08:00
Soulter a7fdc98b29 fix: third party agent runner cannot run properly when using non-default config file
fix: #3815
2025-11-29 23:45:12 +08:00
Soulter 85b7f104df fix: remove unnecessary provider check (#3846)
fixes: #3815
2025-11-29 23:15:19 +08:00
Oscar Shaw d76d1bd7fe perf: adjust padding for PlatformPage and ProviderPage log sections (#3825)
- Added bottom margin to log card for better spacing.
2025-11-29 19:15:35 +08:00
Soulter df4412aa80 style: adjust bot-embedded-image max-width and remove hover effect for improved layout 2025-11-29 01:31:25 +08:00
Soulter ab2c94e19a chore: comment out error logging in provider sources to reduce verbosity 2025-11-28 19:59:33 +08:00
Oscar Shaw 37cc4e2121 perf: console tag UI improve (#3816)
- Added yarn.lock to .gitignore to prevent tracking of Yarn lock files.
- Updated ConsoleDisplayer.vue to improve chip styling
2025-11-28 17:17:11 +08:00
Soulter 60dfdd0a66 chore: update astrbot cli version 2025-11-28 16:53:20 +08:00
Soulter bb8b2cb194 chore: bump version to 4.7.1 2025-11-28 15:13:35 +08:00
Soulter 4e29684aa3 fix: add plugin set and knowledge bases selection in custom rules page (#3813)
fixes: #3806
2025-11-28 13:29:50 +08:00
Soulter 0e17e3553d chore: bump version to 4.7.0 2025-11-27 23:50:05 +08:00
Soulter 0a55060e89 fix: session controller in webchat 2025-11-27 22:32:35 +08:00
Soulter 77859c7daa feat: enhance provider status display in ProviderPage
- Added a tooltip to show detailed provider status, including availability and error messages.
- Refactored item details template to include status chips for better visual representation.
- Removed unused status section to streamline the UI.
2025-11-27 16:39:51 +08:00
Soulter ba39c393a0 perf: enhance provider management with reload locking and logging (#3793)
- Introduced a reload lock to prevent concurrent reloads of providers.
- Added logging to indicate when a provider is disabled and when providers are being synchronized with the configuration.
- Refactored the reload method to improve clarity and maintainability.


Co-authored-by: anka <1350989414@qq.com>
2025-11-27 16:25:31 +08:00
Soulter 6a50d316d9 fix: mcp server cannot reload successfully after updating mcp server config (#3797)
fixes: #3780
2025-11-27 16:22:26 +08:00
Soulter 88c1d77f0b perf: add at message to group chat history (#3796)
* feat: enhance long-term memory message formatting

- Added support for 'At' message components in long-term memory, allowing for better representation of mentions in messages.

* chore: ruff check
2025-11-27 15:59:07 +08:00
Dt8333 758ce40cc1 chore: fix test (#3787) 2025-11-27 14:02:42 +08:00
Soulter 3e7bb80492 chore: ruff format 2025-11-27 14:01:25 +08:00
Soulter 75e95aa9ca fix: update session management icon in sidebar
- Changed the icon for the session management sidebar item from 'mdi-account-group' to 'mdi-pencil-ruler' for better representation.
2025-11-27 14:00:05 +08:00
Soulter a389842e25 feat: update session management UI with information button and layout adjustments
- Added an information button linking to custom rules documentation.
- Adjusted layout for improved spacing and readability in the session management page.
- Minor refactoring of the data table component for better alignment.
2025-11-27 13:58:37 +08:00
Soulter 0f6a3c3f5a refactor: session management custom rules (#3792)
* refactor: umo custom rules

* feat(i18n): update session management translations and improve provider configuration handling

- Updated English and Chinese translations for session management, including "Unified Message Origin" and "Follow Config".
- Enhanced provider configuration options to include "Follow Config" as a selectable item.
- Removed unused clear buttons and refactored provider configuration saving logic to handle updates and deletions more efficiently.
2025-11-27 13:30:43 +08:00
Soulter 133f27422d feat: implement i18n of astrbot config (#3772)
* feat: implement i18n of astrbot config

* feat(config): update configuration metadata with i18n details and future deprecation notes
2025-11-26 16:40:58 +08:00
RC-CHN abc6deb244 feat: add plugin logo placeholder (#3784) 2025-11-26 16:22:11 +08:00
teapot1de 06869b4597 docs: clarify segmented_reply words_count_threshold hint (#3779)
Update the configuration hint for `words_count_threshold` to explicitly state that it acts as a maximum limit for segmentation, preventing user confusion about it being a minimum trigger.
2025-11-26 16:15:09 +08:00
dependabot[bot] d32cea9870 chore(deps): bump actions/checkout in the github-actions group (#3775)
Bumps the github-actions group with 1 update: [actions/checkout](https://github.com/actions/checkout).


Updates `actions/checkout` from 5 to 6
- [Release notes](https://github.com/actions/checkout/releases)
- [Changelog](https://github.com/actions/checkout/blob/main/CHANGELOG.md)
- [Commits](https://github.com/actions/checkout/compare/v5...v6)

---
updated-dependencies:
- dependency-name: actions/checkout
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-11-26 16:13:42 +08:00
Soulter 4b68100f16 feat(chat): add standalone chat component and integrate with config page for testing configurations (#3767)
* feat(chat): add standalone chat component and integrate with config page for testing configurations

* feat(chat): add error handling for message sending and session creation
2025-11-24 22:06:02 +08:00
Soulter 5c5515d462 fix: segmented reply regex error handling (#3771)
* fix: segmented reply regex error handling

closes: #3761

* fix: improve regex handling for segmented replies to support multiline input

* fix: update regex handling in ResultDecorateStage to use findall for segmented replies

* fix: update error logging message for segmented reply regex handling
2025-11-24 22:00:59 +08:00
Soulter 3932b8f982 Merge pull request #3760 from AstrBotDevs/feat/agent-runner
refactor: transfer dify, coze and alibaba dashscope from chat provider to agent runner
2025-11-24 15:33:20 +08:00
Soulter 82488ca900 feat(api): enhance file upload method to support mime type and file name 2025-11-24 15:30:49 +08:00
Soulter 29d9b9b2d6 feat(config): add condition for display_reasoning_text based on agent_runner_type 2025-11-24 15:10:17 +08:00
Soulter 02215e9b7b feat(config): update hint for agent_runner execution method to clarify third-party integration 2025-11-24 15:07:33 +08:00
Soulter 7160b7a18b fix: dify workflow streaming mode 2025-11-24 15:04:15 +08:00
Soulter ea8dac837a feat(config): enhance hint for agent_runner execution method in configuration 2025-11-24 14:42:36 +08:00
Soulter e2a7a028bd feat(migration): enhance migration process with error handling and agent runner config updates 2025-11-24 14:37:25 +08:00
Soulter 70db8d264b fix(config): disable auto_save_history option in configuration 2025-11-24 14:25:14 +08:00
Soulter 0518e6d487 feat(config): add hint for agent_runner execution method in configuration 2025-11-24 14:23:53 +08:00
Soulter 39eb367866 perf: improve file structure
- Implemented CozeAPIClient for file upload, image download, chat messaging, and context management.
- Developed DashscopeAgentRunner for handling requests to the Dashscope API with streaming support.
- Created DifyAgentRunner to manage interactions with the Dify API, including file uploads and workflow execution.
- Introduced DifyAPIClient for making asynchronous requests to the Dify API.
- Updated third-party agent imports to reflect new module structure.
2025-11-24 14:00:16 +08:00
Soulter f1d51a22ad feat(dashscope_agent_runner): refactor request payload construction and enhance streaming response handling 2025-11-24 13:21:34 +08:00
Soulter 77fb554e8f feat(dashscope_agent_runner): implement streaming response handling and request payload construction 2025-11-24 13:09:57 +08:00
Soulter 91f8a0ae09 fix(provider_manager): use get method for provider_type check in load_provider 2025-11-24 10:57:13 +08:00
Soulter 370cda7cf0 feat(dify_api_client): add docstring for file_upload method 2025-11-24 10:53:50 +08:00
Soulter 66b3eed273 fix: correct typo in agent state transition log message 2025-11-24 00:03:22 +08:00
Soulter 99b061a143 fix: make session properties required in Session interface 2025-11-23 23:25:29 +08:00
Soulter 5f3c7ed673 feat(conversation): update agent runner type configuration path to provider_settings 2025-11-23 23:05:36 +08:00
Soulter a6dc458212 feat(third-party-agent): implement streaming response handling and enhance agent execution flow 2025-11-23 23:03:56 +08:00
Soulter 520f521887 feat(provider): enhance agent runner provider selection with subtype filtering 2025-11-23 22:23:23 +08:00
Soulter 01427d9969 feat(config): add hint for non-built-in agent execution model configuration 2025-11-23 22:13:52 +08:00
Soulter 34c03ce983 Merge remote-tracking branch 'origin/master' into feat/agent-runner 2025-11-23 22:06:52 +08:00
Soulter 95e9da42d6 fix(webchat): webchat session cannot be deleted (#3759) 2025-11-23 22:03:07 +08:00
Soulter 1338cab61b feat: add configuration selector for session management and enhance session handling in chat components 2025-11-23 21:53:56 +08:00
Soulter 7ba98c1e91 feat: enhance provider display with grouped categorization and improved filtering 2025-11-23 21:06:16 +08:00
Soulter 9a5f507cbe feat: enable agent runner providers in configuration 2025-11-23 20:58:18 +08:00
Soulter d560671d1f feat: agent runner config migration 2025-11-23 20:54:19 +08:00
Soulter 82c9cf4db6 chore: remove legacy coze and dashscope provider 2025-11-23 20:18:51 +08:00
Soulter 910ec6c695 feat: implement third party agent sub stage and refactor provider management
- Added `ThirdPartyAgentSubStage` to handle interactions with third-party agent runners (Dify, Coze, Dashscope).
- Refactored `star_request.py` to ensure consistent return types in the `process` method.
- Updated `stage.py` to initialize and utilize the new `AgentRequestSubStage`.
- Modified `ProviderManager` to skip loading agent runner providers.
- Removed `Dify` source implementation as it is now handled by the new agent runner structure.
- Enhanced `DifyAPIClient` to support file uploads via both file path and file data.
- Cleaned up shared preferences handling to simplify session preference retrieval.
- Updated dashboard configuration to reflect changes in agent runner provider selection.
- Refactored conversation commands to accommodate the new agent runner structure and remove direct dependencies on Dify.
- Adjusted main application logic to ensure compatibility with the new conversation management approach.
2025-11-23 20:18:51 +08:00
Soulter 766d6f2bec fix(conversation): update session configuration retrieval to use unified message origin 2025-11-23 20:18:51 +08:00
Soulter 9f39140987 fix(conversation): update session configuration retrieval to use unified message origin 2025-11-23 19:59:21 +08:00
Soulter 89716ef4da Merge remote-tracking branch 'origin/master' into feat/agent-runner 2025-11-23 14:48:08 +08:00
Soulter 61a68477d0 stage 2025-10-21 14:19:38 +08:00
Soulter e74f626383 stage 2025-10-21 09:55:14 +08:00
Soulter ef99f64291 feat(config): 添加 agent 运行器类型及相关配置支持 2025-10-21 00:47:04 +08:00
249 changed files with 13356 additions and 5923 deletions
+2 -2
View File
@@ -13,7 +13,7 @@ jobs:
contents: write
steps:
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v6
- name: Dashboard Build
run: |
@@ -70,7 +70,7 @@ jobs:
needs: build-and-publish-to-github-release
steps:
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v6
+1 -1
View File
@@ -12,7 +12,7 @@ jobs:
steps:
- name: Checkout code
uses: actions/checkout@v5
uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v6
+1 -1
View File
@@ -56,7 +56,7 @@ jobs:
# your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
steps:
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v6
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
+1 -1
View File
@@ -17,7 +17,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v5
uses: actions/checkout@v6
with:
fetch-depth: 0
+1 -1
View File
@@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@v5
uses: actions/checkout@v6
- name: Setup Node.js
uses: actions/setup-node@v6
+2 -2
View File
@@ -20,7 +20,7 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v5
uses: actions/checkout@v6
with:
fetch-depth: 1
fetch-tag: true
@@ -118,7 +118,7 @@ jobs:
steps:
- name: Checkout
uses: actions/checkout@v5
uses: actions/checkout@v6
with:
fetch-depth: 1
fetch-tag: true
+58
View File
@@ -0,0 +1,58 @@
name: Smoke Test
on:
push:
branches:
- master
paths-ignore:
- 'README*.md'
- 'changelogs/**'
- 'dashboard/**'
pull_request:
workflow_dispatch:
jobs:
smoke-test:
name: Run smoke tests
runs-on: ubuntu-latest
timeout-minutes: 10
steps:
- name: Checkout
uses: actions/checkout@v6
with:
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.12'
- name: Install UV package manager
run: |
pip install uv
- name: Install dependencies
run: |
uv sync
timeout-minutes: 15
- name: Run smoke tests
run: |
uv run main.py &
APP_PID=$!
echo "Waiting for application to start..."
for i in {1..60}; do
if curl -f http://localhost:6185 > /dev/null 2>&1; then
echo "Application started successfully!"
kill $APP_PID
exit 0
fi
sleep 1
done
echo "Application failed to start within 30 seconds"
kill $APP_PID 2>/dev/null || true
exit 1
timeout-minutes: 2
+3
View File
@@ -34,6 +34,7 @@ dashboard/node_modules/
dashboard/dist/
package-lock.json
package.json
yarn.lock
# Operating System
**/.DS_Store
@@ -47,3 +48,5 @@ astrbot.lock
chroma
venv/*
pytest.ini
AGENTS.md
IFLOW.md
+65
View File
@@ -0,0 +1,65 @@
# CONTRIBUTING
## 贡献指南
首先,感谢您花时间做出贡献!❤️
所有类型的贡献都受到鼓励和重视。有关不同的帮助方式和处理方式的详细信息,请参阅[目录](#目录)。在做出贡献之前,请确保阅读相关部分。这将使我们维护人员的工作变得更加容易,并为所有参与者带来顺畅的体验。社区期待您的贡献。🎉
### 目录
- [报告问题](#报告问题)
- [提交代码更改](#提交代码更改)
### 报告问题
如果您在使用 AstrBot 时遇到任何问题,请按照以下步骤报告:
1. **检查现有问题**:在提交新问题之前,请先检查 [Issues](https://github.com/AstrBotDevs/AstrBot/issues) 中是否已经存在类似的问题。
2. **创建新问题**:如果没有类似的问题,请创建一个新问题。请确保提供以下信息:
- 问题的简要描述
- 重现问题的步骤
- 预期结果和实际结果
- 相关日志或错误消息
### 提交代码更改
#### 分支命名
我们使用 `fix/` 前缀来修复错误,使用 `feat/` 前缀来添加新功能。对于 `fix/` 分支,请使用简短的描述,或者直接使用 Issue 编号。例如:`fix/1234` 或者 `fix/1234-login-typo`。对于 `feat/` 分支,请使用简短的描述,例如:`feat/add-user-profile`
#### PR 描述
- 请使用英文描述您的 PR。
- 标题请使用 `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` 等语义化前缀,并简要描述更改内容。如:`fix: correct login page typo`
## Contributing Guide
First off, thanks for taking the time to contribute! ❤️
All types of contributions are encouraged and valued. See the [Table of Contents](#table-of-contents) for different ways to help and details about how this project handles them. Please make sure to read the relevant section before making your contribution. It will make it a lot easier for us maintainers and smooth out the experience for all involved. The community looks forward to your contributions. 🎉
### Table of Contents
- [Reporting Issues](#reporting-issues)
- [Pull Requests](#pull-requests)
### Reporting Issues
If you encounter any issues while using AstrBot, please follow these steps to report them:
1. **Check Existing Issues**: Before submitting a new issue, please check if a similar issue already exists in the [Issues](https://github.com/AstrBotDevs/AstrBot/issues) section of the repository.
2. **Create a New Issue**: If no similar issue exists, please create a new issue. Make sure to provide the following information:
- A brief description of the issue
- Steps to reproduce the issue
- Expected and actual results
- Relevant logs or error messages
### Pull Requests
#### Branch Naming
We use the `fix/` prefix for bug fixes and the `feat/` prefix for new features. For `fix/` branches, please use a short description or directly use the Issue number, e.g., `fix/1234` or `fix/1234-login-typo`. For `feat/` branches, please use a short description, e.g., `feat/add-user-profile`.
#### PR Description
- Please use English to describe your PR.
- Use semantic prefixes like `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` in the title, followed by a brief description of the changes, e.g., `fix: correct login page typo`.
+49 -36
View File
@@ -1,10 +1,13 @@
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9)
</p>
<div align="center">
<br>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
<div>
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
@@ -14,35 +17,38 @@
<br>
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
<img src="https://gitcode.com/Soulter/AstrBot/star/badge.svg" href="https://gitcode.com/Soulter/AstrBot">
</div>
<br>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://astrbot.app/">文档</a>
<a href="https://blog.astrbot.app/">Blog</a>
<a href="https://astrbot.featurebase.app/roadmap">路线图</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">问题提交</a>
</div>
AstrBot 是一个开源的一站式 Agent 聊天机器人平台,可无缝接入主流即时通讯软件,为个人、开发者和团队打造可靠、可扩展的对话式智能基础设施。无论是个人 AI 伙伴、智能客服、自动化助手,还是企业知识库,AstrBot 都能在你的即时通讯软件平台的工作流中快速构建生产可用的 AI 应用。
AstrBot 是一个开源的一站式 Agent 聊天机器人平台,可接入主流即时通讯软件,为个人、开发者和团队打造可靠、可扩展的对话式智能基础设施。无论是个人 AI 伙伴、智能客服、自动化助手,还是企业知识库,AstrBot 都能在你的即时通讯软件平台的工作流中快速构建生产可用的 AI 应用。
<img width="1776" height="1080" alt="image" src="https://github.com/user-attachments/assets/00782c4c-4437-4d97-aabc-605e3738da5c" />
## 主要功能
1. **大模型对话**。支持接入多种大模型服务。支持多模态、工具调用、MCP、原生知识库、人设等功能
2. **多消息平台支持**。支持接入 QQ、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK 等平台。支持速率限制、白名单、百度内容审核
3. **Agent**。完善适配的 Agentic 能力。支持多轮工具调用、内置沙盒代码执行器、网页搜索等功能
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富
5. **WebUI**。可视化配置和管理机器人,功能齐全
1. 💯 免费 & 开源
1. ✨ AI 大模型对话,多模态,Agent,MCP,知识库,人格设定
2. 🤖 支持接入 Dify、阿里云百炼、Coze 等智能体平台
2. 🌐 多平台,支持 QQ、企业微信、飞书、钉钉、微信公众号、Telegram、Slack 以及[更多](#支持的消息平台)
3. 📦 插件扩展,已有近 800 个插件可一键安装
5. 💻 WebUI 支持。
6. 🌐 国际化(i18n)支持。
## 部署方式
## 快速开始
#### Docker 部署(推荐 🥳)
@@ -50,6 +56,12 @@ AstrBot 是一个开源的一站式 Agent 聊天机器人平台,可无缝接
请参阅官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) 。
#### uv 部署
```bash
uvx astrbot
```
#### 宝塔面板部署
AstrBot 与宝塔面板合作,已上架至宝塔面板。
@@ -101,24 +113,6 @@ uv run main.py
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
## 🌍 社区
### QQ 群组
- 1 群:322154837
- 3 群:630166526
- 5 群:822130018
- 6 群:753075035
- 开发者群:975206796
### Telegram 群组
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
### Discord 群组
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
## 支持的消息平台
**官方维护**
@@ -205,6 +199,25 @@ pip install pre-commit
pre-commit install
```
## 🌍 社区
### QQ 群组
- 1 群:322154837
- 3 群:630166526
- 5 群:822130018
- 6 群:753075035
- 7 群:743746109
- 开发者群:975206796
### Telegram 群组
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
### Discord 群组
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
## ❤️ Special Thanks
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
+40 -26
View File
@@ -19,30 +19,38 @@
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%20plugins&style=for-the-badge&label=Marketplace&cacheSeconds=3600">
</div>
<br>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
<a href="https://astrbot.app/">Documentation</a>
<a href="https://blog.astrbot.app/">Blog</a>
<a href="https://astrbot.featurebase.app/roadmap">Roadmap</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue Tracker</a>
</div>
AstrBot is an open-source all-in-one Agent chatbot platform and development framework.
AstrBot is an open-source all-in-one Agent chatbot platform that integrates with mainstream instant messaging apps. It provides reliable and scalable conversational AI infrastructure for individuals, developers, and teams. Whether you're building a personal AI companion, intelligent customer service, automation assistant, or enterprise knowledge base, AstrBot enables you to quickly build production-ready AI applications within your IM platform workflows.
<img width="1776" height="1080" alt="image" src="https://github.com/user-attachments/assets/00782c4c-4437-4d97-aabc-605e3738da5c" />
## Key Features
1. **LLM Conversations**. Supports integration with various large language model services. Features include multimodal capabilities, tool calling, MCP, native knowledge base, character personas, and more.
2. **Multi-Platform Support**. Integrates with QQ, WeChat Work, WeChat Official Accounts, Feishu, Telegram, DingTalk, Discord, KOOK, and other platforms. Supports rate limiting, whitelisting, and Baidu content moderation.
3. **Agent Capabilities**. Fully optimized agentic features including multi-turn tool calling, built-in sandboxed code executor, web search, and more.
4. **Plugin Extensions**. Deeply optimized plugin mechanism supporting [plugin development](https://astrbot.app/dev/plugin.html) to extend functionality, with a rich community plugin ecosystem.
5. **Web UI**. Visual configuration and management of your bot with comprehensive features.
1. 💯 Free & Open Source.
2. ✨ AI LLM Conversations, Multimodal, Agent, MCP, Knowledge Base, Persona Settings.
3. 🤖 Supports integration with Dify, Alibaba Cloud Bailian, Coze and other agent platforms.
4. 🌐 Multi-Platform: QQ, WeChat Work, Feishu, DingTalk, WeChat Official Accounts, Telegram, Slack, and [more](#supported-messaging-platforms).
5. 📦 Plugin Extensions with nearly 800 plugins available for one-click installation.
6. 💻 WebUI Support.
7. 🌐 Internationalization (i18n) Support.
## Deployment Methods
## Quick Start
#### Docker Deployment (Recommended 🥳)
@@ -50,6 +58,12 @@ We recommend deploying AstrBot using Docker or Docker Compose.
Please refer to the official documentation: [Deploy AstrBot with Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
#### uv Deployment
```bash
uvx astrbot
```
#### BT-Panel Deployment
AstrBot has partnered with BT-Panel and is now available in their marketplace.
@@ -101,24 +115,6 @@ uv run main.py
Or refer to the official documentation: [Deploy AstrBot from Source](https://astrbot.app/deploy/astrbot/cli.html).
## 🌍 Community
### QQ Groups
- Group 1: 322154837
- Group 3: 630166526
- Group 5: 822130018
- Group 6: 753075035
- Developer Group: 975206796
### Telegram Group
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
### Discord Server
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
## Supported Messaging Platforms
**Officially Maintained**
@@ -205,6 +201,24 @@ pip install pre-commit
pre-commit install
```
## 🌍 Community
### QQ Groups
- Group 1: 322154837
- Group 3: 630166526
- Group 5: 822130018
- Group 6: 753075035
- Developer Group: 975206796
### Telegram Group
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
### Discord Server
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
## ❤️ Special Thanks
Special thanks to all Contributors and plugin developers for their contributions to AstrBot ❤️
+248
View File
@@ -0,0 +1,248 @@
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9)
</p>
<div align="center">
<br>
<div>
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=2" alt="FeaturedHelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
</div>
<br>
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%20plugins&style=for-the-badge&label=Marketplace&cacheSeconds=3600">
</div>
<br>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
<a href="https://astrbot.app/">Documentation</a>
<a href="https://blog.astrbot.app/">Blog</a>
<a href="https://astrbot.featurebase.app/roadmap">Feuille de route</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Signaler un problème</a>
</div>
AstrBot est une plateforme de chatbot Agent tout-en-un open source qui s'intègre aux principales applications de messagerie instantanée. Elle fournit une infrastructure d'IA conversationnelle fiable et évolutive pour les particuliers, les développeurs et les équipes. Que vous construisiez un compagnon IA personnel, un service client intelligent, un assistant d'automatisation ou une base de connaissances d'entreprise, AstrBot vous permet de créer rapidement des applications d'IA prêtes pour la production dans les flux de travail de votre plateforme de messagerie.
<img width="1776" height="1080" alt="image" src="https://github.com/user-attachments/assets/00782c4c-4437-4d97-aabc-605e3738da5c" />
## Fonctionnalités principales
1. 💯 Gratuit & Open Source.
2. ✨ Conversations avec LLM IA, Multimodal, Agent, MCP, Base de connaissances, Paramètres de personnalité.
3. 🤖 Prise en charge de l'intégration avec Dify, Alibaba Cloud Bailian, Coze et autres plateformes d'agents.
4. 🌐 Multi-plateforme : QQ, WeChat Work, Feishu, DingTalk, Comptes officiels WeChat, Telegram, Slack, et [plus encore](#plateformes-de-messagerie-prises-en-charge).
5. 📦 Extensions de plugins avec près de 800 plugins disponibles pour une installation en un clic.
6. 💻 Support WebUI.
7. 🌐 Support de l'internationalisation (i18n).
## Démarrage rapide
#### Déploiement Docker (Recommandé 🥳)
Nous recommandons de déployer AstrBot en utilisant Docker ou Docker Compose.
Veuillez consulter la documentation officielle : [Déployer AstrBot avec Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
#### Déploiement uv
```bash
uvx astrbot
```
#### Déploiement BT-Panel
AstrBot s'est associé à BT-Panel et est maintenant disponible sur leur marketplace.
Veuillez consulter la documentation officielle : [Déploiement BT-Panel](https://astrbot.app/deploy/astrbot/btpanel.html).
#### Déploiement 1Panel
AstrBot a été officiellement listé sur le marketplace 1Panel.
Veuillez consulter la documentation officielle : [Déploiement 1Panel](https://astrbot.app/deploy/astrbot/1panel.html).
#### Déployer sur RainYun
AstrBot a été officiellement listé sur la plateforme d'applications cloud de RainYun avec un déploiement en un clic.
[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
#### Déployer sur Replit
Méthode de déploiement contribuée par la communauté.
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot)
#### Installateur Windows en un clic
Veuillez consulter la documentation officielle : [Déployer AstrBot avec l'installateur Windows en un clic](https://astrbot.app/deploy/astrbot/windows.html).
#### Déploiement CasaOS
Méthode de déploiement contribuée par la communauté.
Veuillez consulter la documentation officielle : [Déploiement CasaOS](https://astrbot.app/deploy/astrbot/casaos.html).
#### Déploiement manuel
Tout d'abord, installez uv :
```bash
pip install uv
```
Installez AstrBot via Git Clone :
```bash
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
uv run main.py
```
Ou consultez la documentation officielle : [Déployer AstrBot depuis les sources](https://astrbot.app/deploy/astrbot/cli.html).
## Plateformes de messagerie prises en charge
**Maintenues officiellement**
- QQ (Plateforme officielle & OneBot)
- Telegram
- Application WeChat Work & Bot intelligent WeChat Work
- Service client WeChat & Comptes officiels WeChat
- Feishu (Lark)
- DingTalk
- Slack
- Discord
- Satori
- Misskey
- WhatsApp (Bientôt disponible)
- LINE (Bientôt disponible)
**Maintenues par la communauté**
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Messages directs Bilibili](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
## Services de modèles pris en charge
**Services LLM**
- OpenAI et services compatibles
- Anthropic
- Google Gemini
- Moonshot AI
- Zhipu AI
- DeepSeek
- Ollama (Auto-hébergé)
- LM Studio (Auto-hébergé)
- [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
- [302.AI](https://share.302.ai/rr1M3l)
- [TokenPony](https://www.tokenpony.cn/3YPyf)
- [SiliconFlow](https://docs.siliconflow.cn/cn/usecases/use-siliconcloud-in-astrbot)
- [PPIO Cloud](https://ppio.com/user/register?invited_by=AIOONE)
- ModelScope
- OneAPI
**Plateformes LLMOps**
- Dify
- Applications Alibaba Cloud Bailian
- Coze
**Services de reconnaissance vocale**
- OpenAI Whisper
- SenseVoice
**Services de synthèse vocale**
- OpenAI TTS
- Gemini TTS
- GPT-Sovits-Inference
- GPT-Sovits
- FishAudio
- Edge TTS
- Alibaba Cloud Bailian TTS
- Azure TTS
- Minimax TTS
- Volcano Engine TTS
## ❤️ Contribuer
Les Issues et Pull Requests sont toujours les bienvenues ! N'hésitez pas à soumettre vos modifications à ce projet :)
### Comment contribuer
Vous pouvez contribuer en examinant les issues ou en aidant à la revue des pull requests. Toutes les issues ou PRs sont les bienvenues pour encourager la participation de la communauté. Bien sûr, ce ne sont que des suggestions - vous pouvez contribuer de la manière que vous souhaitez. Pour l'ajout de nouvelles fonctionnalités, veuillez d'abord en discuter via une Issue.
### Environnement de développement
AstrBot utilise `ruff` pour le formatage et le linting du code.
```bash
git clone https://github.com/AstrBotDevs/AstrBot
pip install pre-commit
pre-commit install
```
## 🌍 Communauté
### Groupes QQ
- Groupe 1 : 322154837
- Groupe 3 : 630166526
- Groupe 5 : 822130018
- Groupe 6 : 753075035
- Groupe développeurs : 975206796
### Groupe Telegram
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
### Serveur Discord
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
## ❤️ Remerciements spéciaux
Un grand merci à tous les contributeurs et développeurs de plugins pour leurs contributions à AstrBot ❤️
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
</a>
De plus, la naissance de ce projet n'aurait pas été possible sans l'aide des projets open source suivants :
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - L'incroyable framework chat
## ⭐ Historique des étoiles
> [!TIP]
> Si ce projet vous a aidé dans votre vie ou votre travail, ou si vous êtes intéressé par son développement futur, veuillez donner une étoile au projet. C'est la force motrice derrière la maintenance de ce projet open source <3
<div align="center">
[![Star History Chart](https://api.star-history.com/svg?repos=astrbotdevs/astrbot&type=Date)](https://star-history.com/#astrbotdevs/astrbot&Date)
</div>
</details>
_私は、高性能ですから!_
+40 -26
View File
@@ -19,30 +19,38 @@
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E5%80%8B&style=for-the-badge&label=%E3%83%97%E3%83%A9%E3%82%B0%E3%82%A4%E3%83%B3&cacheSeconds=3600">
</div>
<br>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
<a href="https://astrbot.app/">ドキュメント</a>
<a href="https://blog.astrbot.app/">Blog</a>
<a href="https://astrbot.featurebase.app/roadmap">ロードマップ</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Issue</a>
</div>
AstrBot は、オープンソースのオールインワン Agent チャットボットプラットフォーム及び開発フレームワークす。
AstrBot は、主要なインスタントメッセージングアプリと統合できるオープンソースのオールインワン Agent チャットボットプラットフォームです。個人、開発者、チームに信頼性が高くスケーラブルな会話型 AI インフラストラクチャを提供します。パーソナル AI コンパニオン、インテリジェントカスタマーサービス、オートメーションアシスタント、エンタープライズナレッジベースなど、AstrBot を使用すると、IM プラットフォームワークフロー内で本番環境対応の AI アプリケーションを迅速に構築できます。
<img width="1776" height="1080" alt="image" src="https://github.com/user-attachments/assets/00782c4c-4437-4d97-aabc-605e3738da5c" />
## 主な機能
1. **大規模言語モデル対話**。多様な大規模言語モデルサービスとの統合をサポート。マルチモーダル、ツール呼び出し、MCP、ネイティブナレッジベース、キャラクター設定などの機能を搭載
2. **マルチメッセージプラットフォームサポート**。QQ、WeChat Work、WeChat公式アカウント、Feishu、Telegram、DingTalk、Discord、KOOK などのプラットフォームと統合可能。レート制限、ホワイトリスト、Baidu コンテンツ審査をサポート
3. **Agent**。完全に最適化された Agentic 機能。マルチターンツール呼び出し、内蔵サンドボックスコード実行環境、Web 検索などの機能をサポート。
4. **プラグイン拡張**。深く最適化されたプラグインメカニズムで、[プラグイン開発](https://astrbot.app/dev/plugin.html)による機能拡張をサポート。豊富なコミュニティプラグインエコシステム
5. **WebUI**。ビジュアル設定とボット管理、充実した機能。
1. 💯 無料 & オープンソース
2. ✨ AI 大規模言語モデル対話、マルチモーダル、Agent、MCP、ナレッジベース、ペルソナ設定
3. 🤖 Dify、Alibaba Cloud 百炼、Coze などの Agent プラットフォームとの統合をサポート。
4. 🌐 マルチプラットフォーム:QQ、WeChat Work、Feishu、DingTalk、WeChat 公式アカウント、Telegram、Slack、[その他](#サポートされているメッセージプラットフォーム)
5. 📦 約800個のプラグインをワンクリックでインストール可能なプラグイン拡張機能。
6. 💻 WebUI サポート。
7. 🌐 国際化(i18n)サポート。
## デプロイ方法
## クイックスタート
#### Docker デプロイ(推奨 🥳)
@@ -50,6 +58,12 @@ Docker / Docker Compose を使用した AstrBot のデプロイを推奨しま
公式ドキュメント [Docker を使用した AstrBot のデプロイ](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) をご参照ください。
#### uv デプロイ
```bash
uvx astrbot
```
#### 宝塔パネルデプロイ
AstrBot は宝塔パネルと提携し、宝塔パネルに公開されています。
@@ -101,24 +115,6 @@ uv run main.py
または、公式ドキュメント [ソースコードから AstrBot をデプロイ](https://astrbot.app/deploy/astrbot/cli.html) をご参照ください。
## 🌍 コミュニティ
### QQ グループ
- 1群:322154837
- 3群:630166526
- 5群:822130018
- 6群:753075035
- 開発者群:975206796
### Telegram グループ
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
### Discord サーバー
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
## サポートされているメッセージプラットフォーム
**公式メンテナンス**
@@ -205,6 +201,24 @@ pip install pre-commit
pre-commit install
```
## 🌍 コミュニティ
### QQ グループ
- 1群: 322154837
- 3群: 630166526
- 5群: 822130018
- 6群: 753075035
- 開発者群: 975206796
### Telegram グループ
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
### Discord サーバー
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
## ❤️ Special Thanks
AstrBot への貢献をしていただいたすべてのコントリビューターとプラグイン開発者に特別な感謝を ❤️
+248
View File
@@ -0,0 +1,248 @@
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9)
</p>
<div align="center">
<br>
<div>
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=2" alt="FeaturedHelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
</div>
<br>
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%20%D0%BF%D0%BB%D0%B0%D0%B3%D0%B8%D0%BD%D0%BE%D0%B2&style=for-the-badge&label=%D0%9C%D0%B0%D0%B3%D0%B0%D0%B7%D0%B8%D0%BD&cacheSeconds=3600">
</div>
<br>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a>
<a href="https://astrbot.app/">Документация</a>
<a href="https://blog.astrbot.app/">Блог</a>
<a href="https://astrbot.featurebase.app/roadmap">Дорожная карта</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">Сообщить о проблеме</a>
</div>
AstrBot — это универсальная платформа Agent-чатботов с открытым исходным кодом, которая интегрируется с основными приложениями для обмена мгновенными сообщениями. Она предоставляет надёжную и масштабируемую инфраструктуру разговорного ИИ для частных лиц, разработчиков и команд. Будь то персональный ИИ-компаньон, интеллектуальная служба поддержки, автоматизированный помощник или корпоративная база знаний — AstrBot позволяет быстро создавать готовые к использованию ИИ-приложения в рабочих процессах вашей платформы обмена сообщениями.
<img width="1776" height="1080" alt="image" src="https://github.com/user-attachments/assets/00782c4c-4437-4d97-aabc-605e3738da5c" />
## Основные возможности
1. 💯 Бесплатно и с открытым исходным кодом.
2. ✨ ИИ-диалоги с LLM, мультимодальность, Agent, MCP, база знаний, настройки личности.
3. 🤖 Поддержка интеграции с Dify, Alibaba Cloud Bailian, Coze и другими платформами агентов.
4. 🌐 Мультиплатформенность: QQ, WeChat Work, Feishu, DingTalk, официальные аккаунты WeChat, Telegram, Slack и [другие](#поддерживаемые-платформы-обмена-сообщениями).
5. 📦 Расширения плагинов с почти 800 плагинами, доступными для установки в один клик.
6. 💻 Поддержка WebUI.
7. 🌐 Поддержка интернационализации (i18n).
## Быстрый старт
#### Развёртывание Docker (Рекомендуется 🥳)
Мы рекомендуем развёртывать AstrBot с помощью Docker или Docker Compose.
См. официальную документацию: [Развёртывание AstrBot с Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot).
#### Развёртывание uv
```bash
uvx astrbot
```
#### Развёртывание BT-Panel
AstrBot в партнёрстве с BT-Panel теперь доступен на их маркетплейсе.
См. официальную документацию: [Развёртывание BT-Panel](https://astrbot.app/deploy/astrbot/btpanel.html).
#### Развёртывание 1Panel
AstrBot официально размещён на маркетплейсе 1Panel.
См. официальную документацию: [Развёртывание 1Panel](https://astrbot.app/deploy/astrbot/1panel.html).
#### Развёртывание на RainYun
AstrBot официально размещён на облачной платформе приложений RainYun с развёртыванием в один клик.
[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
#### Развёртывание на Replit
Метод развёртывания от сообщества.
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot)
#### Установщик Windows в один клик
См. официальную документацию: [Развёртывание AstrBot с установщиком Windows в один клик](https://astrbot.app/deploy/astrbot/windows.html).
#### Развёртывание CasaOS
Метод развёртывания от сообщества.
См. официальную документацию: [Развёртывание CasaOS](https://astrbot.app/deploy/astrbot/casaos.html).
#### Ручное развёртывание
Сначала установите uv:
```bash
pip install uv
```
Установите AstrBot через Git Clone:
```bash
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
uv run main.py
```
Или см. официальную документацию: [Развёртывание AstrBot из исходного кода](https://astrbot.app/deploy/astrbot/cli.html).
## Поддерживаемые платформы обмена сообщениями
**Официально поддерживаемые**
- QQ (Официальная платформа и OneBot)
- Telegram
- Приложение WeChat Work и интеллектуальный бот WeChat Work
- Служба поддержки WeChat и официальные аккаунты WeChat
- Feishu (Lark)
- DingTalk
- Slack
- Discord
- Satori
- Misskey
- WhatsApp (Скоро)
- LINE (Скоро)
**Поддерживаемые сообществом**
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Личные сообщения Bilibili](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
## Поддерживаемые сервисы моделей
**Сервисы LLM**
- OpenAI и совместимые сервисы
- Anthropic
- Google Gemini
- Moonshot AI
- Zhipu AI
- DeepSeek
- Ollama (Самостоятельное размещение)
- LM Studio (Самостоятельное размещение)
- [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
- [302.AI](https://share.302.ai/rr1M3l)
- [TokenPony](https://www.tokenpony.cn/3YPyf)
- [SiliconFlow](https://docs.siliconflow.cn/cn/usecases/use-siliconcloud-in-astrbot)
- [PPIO Cloud](https://ppio.com/user/register?invited_by=AIOONE)
- ModelScope
- OneAPI
**Платформы LLMOps**
- Dify
- Приложения Alibaba Cloud Bailian
- Coze
**Сервисы распознавания речи**
- OpenAI Whisper
- SenseVoice
**Сервисы синтеза речи**
- OpenAI TTS
- Gemini TTS
- GPT-Sovits-Inference
- GPT-Sovits
- FishAudio
- Edge TTS
- Alibaba Cloud Bailian TTS
- Azure TTS
- Minimax TTS
- Volcano Engine TTS
## ❤️ Вклад в проект
Issues и Pull Request всегда приветствуются! Не стесняйтесь отправлять свои изменения в этот проект :)
### Как внести вклад
Вы можете внести вклад, просматривая issues или помогая с ревью pull request. Любые issues или PR приветствуются для поощрения участия сообщества. Конечно, это лишь предложения — вы можете вносить вклад любым удобным для вас способом. Для добавления новых функций сначала обсудите это через Issue.
### Среда разработки
AstrBot использует `ruff` для форматирования и линтинга кода.
```bash
git clone https://github.com/AstrBotDevs/AstrBot
pip install pre-commit
pre-commit install
```
## 🌍 Сообщество
### Группы QQ
- Группа 1: 322154837
- Группа 3: 630166526
- Группа 5: 822130018
- Группа 6: 753075035
- Группа разработчиков: 975206796
### Группа Telegram
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
### Сервер Discord
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
## ❤️ Особая благодарность
Особая благодарность всем контрибьюторам и разработчикам плагинов за их вклад в AstrBot ❤️
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
</a>
Кроме того, рождение этого проекта было бы невозможно без помощи следующих проектов с открытым исходным кодом:
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - Замечательный кошачий фреймворк
## ⭐ История звёзд
> [!TIP]
> Если этот проект помог вам в жизни или работе, или если вас интересует его будущее развитие, пожалуйста, поставьте проекту звезду. Это движущая сила поддержки этого проекта с открытым исходным кодом <3
<div align="center">
[![Star History Chart](https://api.star-history.com/svg?repos=astrbotdevs/astrbot&type=Date)](https://star-history.com/#astrbotdevs/astrbot&Date)
</div>
</details>
_私は、高性能ですから!_
+248
View File
@@ -0,0 +1,248 @@
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9)
</p>
<div align="center">
<br>
<div>
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
<a href="https://hellogithub.com/repository/AstrBotDevs/AstrBot" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=d127d50cd5e54c5382328acc3bb25483&claim_uid=ZO9by7qCXgSd6Lp&t=2" alt="FeaturedHelloGitHub" style="width: 250px; height: 54px;" width="250" height="54" /></a>
</div>
<br>
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E5%80%8B&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%A0%B4&cacheSeconds=3600">
</div>
<br>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">简体中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
<a href="https://astrbot.app/">文件</a>
<a href="https://blog.astrbot.app/">Blog</a>
<a href="https://astrbot.featurebase.app/roadmap">路線圖</a>
<a href="https://github.com/AstrBotDevs/AstrBot/issues">問題回報</a>
</div>
AstrBot 是一個開源的一站式 Agent 聊天機器人平台,可接入主流即時通訊軟體,為個人、開發者和團隊打造可靠、可擴展的對話式智慧基礎設施。無論是個人 AI 夥伴、智慧客服、自動化助手,還是企業知識庫,AstrBot 都能在您的即時通訊軟體平台的工作流程中快速構建生產可用的 AI 應用程式。
<img width="1776" height="1080" alt="image" src="https://github.com/user-attachments/assets/00782c4c-4437-4d97-aabc-605e3738da5c" />
## 主要功能
1. 💯 免費 & 開源。
2. ✨ AI 大型模型對話,多模態,Agent,MCP,知識庫,人格設定。
3. 🤖 支援接入 Dify、阿里雲百煉、Coze 等智慧體平台。
4. 🌐 多平台:QQ、企業微信、飛書、釘釘、微信公眾號、Telegram、Slack 以及[更多](#支援的訊息平台)。
5. 📦 外掛擴充,已有近 800 個外掛可一鍵安裝。
6. 💻 WebUI 支援。
7. 🌐 國際化(i18n)支援。
## 快速開始
#### Docker 部署(推薦 🥳)
推薦使用 Docker / Docker Compose 方式部署 AstrBot。
請參閱官方文件 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot)。
#### uv 部署
```bash
uvx astrbot
```
#### 寶塔面板部署
AstrBot 與寶塔面板合作,已上架至寶塔面板。
請參閱官方文件 [寶塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html)。
#### 1Panel 部署
AstrBot 已由 1Panel 官方上架至 1Panel 面板。
請參閱官方文件 [1Panel 部署](https://astrbot.app/deploy/astrbot/1panel.html)。
#### 在雨雲上部署
AstrBot 已由雨雲官方上架至雲端應用程式平台,可一鍵部署。
[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0)
#### 在 Replit 上部署
社群貢獻的部署方式。
[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot)
#### Windows 一鍵安裝器部署
請參閱官方文件 [使用 Windows 一鍵安裝器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html)。
#### CasaOS 部署
社群貢獻的部署方式。
請參閱官方文件 [CasaOS 部署](https://astrbot.app/deploy/astrbot/casaos.html)。
#### 手動部署
首先安裝 uv
```bash
pip install uv
```
透過 Git Clone 安裝 AstrBot
```bash
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
uv run main.py
```
或者請參閱官方文件 [透過原始碼部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html)。
## 支援的訊息平台
**官方維護**
- QQ(官方平台 & OneBot
- Telegram
- 企微應用 & 企微智慧機器人
- 微信客服 & 微信公眾號
- 飛書
- 釘釘
- Slack
- Discord
- Satori
- Misskey
- Whatsapp(即將支援)
- LINE(即將支援)
**社群維護**
- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter)
- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat)
- [Bilibili 私訊](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter)
- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11)
## 支援的模型服務
**大型模型服務**
- OpenAI 及相容服務
- Anthropic
- Google Gemini
- Moonshot AI
- 智譜 AI
- DeepSeek
- Ollama(本機部署)
- LM Studio(本機部署)
- [優雲智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
- [302.AI](https://share.302.ai/rr1M3l)
- [小馬算力](https://www.tokenpony.cn/3YPyf)
- [矽基流動](https://docs.siliconflow.cn/cn/usercases/use-siliconcloud-in-astrbot)
- [PPIO 派歐雲](https://ppio.com/user/register?invited_by=AIOONE)
- ModelScope
- OneAPI
**LLMOps 平台**
- Dify
- 阿里雲百煉應用
- Coze
**語音轉文字服務**
- OpenAI Whisper
- SenseVoice
**文字轉語音服務**
- OpenAI TTS
- Gemini TTS
- GPT-Sovits-Inference
- GPT-Sovits
- FishAudio
- Edge TTS
- 阿里雲百煉 TTS
- Azure TTS
- Minimax TTS
- 火山引擎 TTS
## ❤️ 貢獻
歡迎任何 Issues/Pull Requests!只需要將您的變更提交到此專案 :)
### 如何貢獻
您可以透過檢視問題或協助審核 PR(拉取請求)來貢獻。任何問題或 PR 都歡迎參與,以促進社群貢獻。當然,這些只是建議,您可以以任何方式進行貢獻。對於新功能的新增,請先透過 Issue 討論。
### 開發環境
AstrBot 使用 `ruff` 進行程式碼格式化和檢查。
```bash
git clone https://github.com/AstrBotDevs/AstrBot
pip install pre-commit
pre-commit install
```
## 🌍 社群
### QQ 群組
- 1 群:322154837
- 3 群:630166526
- 5 群:822130018
- 6 群:753075035
- 開發者群:975206796
### Telegram 群組
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
### Discord 群組
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
## ❤️ Special Thanks
特別感謝所有 Contributors 和外掛開發者對 AstrBot 的貢獻 ❤️
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
</a>
此外,本專案的誕生離不開以下開源專案的幫助:
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 偉大的貓貓框架
## ⭐ Star History
> [!TIP]
> 如果本專案對您的生活 / 工作產生了幫助,或者您關注本專案的未來發展,請給專案 Star,這是我們維護這個開源專案的動力 <3
<div align="center">
[![Star History Chart](https://api.star-history.com/svg?repos=astrbotdevs/astrbot&type=Date)](https://star-history.com/#astrbotdevs/astrbot&Date)
</div>
</details>
_私は、高性能ですから!_
+1 -1
View File
@@ -1 +1 @@
__version__ = "3.5.23"
__version__ = "4.9.0"
+3 -3
View File
@@ -345,9 +345,6 @@ class MCPClient:
async def cleanup(self):
"""Clean up resources including old exit stacks from reconnections"""
# Set running_event first to unblock any waiting tasks
self.running_event.set()
# Close current exit stack
try:
await self.exit_stack.aclose()
@@ -359,6 +356,9 @@ class MCPClient:
# Just clear the list to release references
self._old_exit_stacks.clear()
# Set running_event first to unblock any waiting tasks
self.running_event.set()
class MCPTool(FunctionTool, Generic[TContext]):
"""A function tool that calls an MCP service."""
+21 -4
View File
@@ -3,7 +3,7 @@
from typing import Any, ClassVar, Literal, cast
from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic import BaseModel, GetCoreSchemaHandler, model_validator
from pydantic_core import core_schema
@@ -145,22 +145,39 @@ class Message(BaseModel):
"tool",
]
content: str | list[ContentPart]
content: str | list[ContentPart] | None = None
"""The content of the message."""
tool_calls: list[ToolCall] | list[dict] | None = None
"""The tool calls of the message."""
tool_call_id: str | None = None
"""The ID of the tool call."""
@model_validator(mode="after")
def check_content_required(self):
# assistant + tool_calls is not None: allow content to be None
if self.role == "assistant" and self.tool_calls is not None:
return self
# other all cases: content is required
if self.content is None:
raise ValueError(
"content is required unless role='assistant' and tool_calls is not None"
)
return self
class AssistantMessageSegment(Message):
"""A message segment from the assistant."""
role: Literal["assistant"] = "assistant"
tool_calls: list[ToolCall] | list[dict] | None = None
class ToolCallMessageSegment(Message):
"""A message segment representing a tool call."""
role: Literal["tool"] = "tool"
tool_call_id: str
class UserMessageSegment(Message):
+7 -4
View File
@@ -2,13 +2,12 @@ import abc
import typing as T
from enum import Enum, auto
from astrbot.core.provider import Provider
from astrbot import logger
from astrbot.core.provider.entities import LLMResponse
from ..hooks import BaseAgentRunHooks
from ..response import AgentResponse
from ..run_context import ContextWrapper, TContext
from ..tool_executor import BaseFunctionToolExecutor
class AgentState(Enum):
@@ -24,9 +23,7 @@ class BaseAgentRunner(T.Generic[TContext]):
@abc.abstractmethod
async def reset(
self,
provider: Provider,
run_context: ContextWrapper[TContext],
tool_executor: BaseFunctionToolExecutor[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
**kwargs: T.Any,
) -> None:
@@ -60,3 +57,9 @@ class BaseAgentRunner(T.Generic[TContext]):
This method should be called after the agent is done.
"""
...
def _transition_state(self, new_state: AgentState) -> None:
"""Transition the agent state."""
if self._state != new_state:
logger.debug(f"Agent state transition: {self._state} -> {new_state}")
self._state = new_state
@@ -0,0 +1,367 @@
import base64
import json
import sys
import typing as T
import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.core import sp
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from ...hooks import BaseAgentRunHooks
from ...response import AgentResponseData
from ...run_context import ContextWrapper, TContext
from ..base import AgentResponse, AgentState, BaseAgentRunner
from .coze_api_client import CozeAPIClient
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
class CozeAgentRunner(BaseAgentRunner[TContext]):
"""Coze Agent Runner"""
@override
async def reset(
self,
request: ProviderRequest,
run_context: ContextWrapper[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
provider_config: dict,
**kwargs: T.Any,
) -> None:
self.req = request
self.streaming = kwargs.get("streaming", False)
self.final_llm_resp = None
self._state = AgentState.IDLE
self.agent_hooks = agent_hooks
self.run_context = run_context
self.api_key = provider_config.get("coze_api_key", "")
if not self.api_key:
raise Exception("Coze API Key 不能为空。")
self.bot_id = provider_config.get("bot_id", "")
if not self.bot_id:
raise Exception("Coze Bot ID 不能为空。")
self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
if not isinstance(self.api_base, str) or not self.api_base.startswith(
("http://", "https://"),
):
raise Exception(
"Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。",
)
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.auto_save_history = provider_config.get("auto_save_history", True)
# 创建 API 客户端
self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base)
# 会话相关缓存
self.file_id_cache: dict[str, dict[str, str]] = {}
@override
async def step(self):
"""
执行 Coze Agent 的一个步骤
"""
if not self.req:
raise ValueError("Request is not set. Please call reset() first.")
if self._state == AgentState.IDLE:
try:
await self.agent_hooks.on_agent_begin(self.run_context)
except Exception as e:
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
# 开始处理,转换到运行状态
self._transition_state(AgentState.RUNNING)
try:
# 执行 Coze 请求并处理结果
async for response in self._execute_coze_request():
yield response
except Exception as e:
logger.error(f"Coze 请求失败:{str(e)}")
self._transition_state(AgentState.ERROR)
self.final_llm_resp = LLMResponse(
role="err", completion_text=f"Coze 请求失败:{str(e)}"
)
yield AgentResponse(
type="err",
data=AgentResponseData(
chain=MessageChain().message(f"Coze 请求失败:{str(e)}")
),
)
finally:
await self.api_client.close()
@override
async def step_until_done(
self, max_step: int = 30
) -> T.AsyncGenerator[AgentResponse, None]:
while not self.done():
async for resp in self.step():
yield resp
async def _execute_coze_request(self):
"""执行 Coze 请求的核心逻辑"""
prompt = self.req.prompt or ""
session_id = self.req.session_id or "unknown"
image_urls = self.req.image_urls or []
contexts = self.req.contexts or []
system_prompt = self.req.system_prompt
# 用户ID参数
user_id = session_id
# 获取或创建会话ID
conversation_id = await sp.get_async(
scope="umo",
scope_id=user_id,
key="coze_conversation_id",
default="",
)
# 构建消息
additional_messages = []
if system_prompt:
if not self.auto_save_history or not conversation_id:
additional_messages.append(
{
"role": "system",
"content": system_prompt,
"content_type": "text",
},
)
# 处理历史上下文
if not self.auto_save_history and contexts:
for ctx in contexts:
if isinstance(ctx, dict) and "role" in ctx and "content" in ctx:
# 处理上下文中的图片
content = ctx["content"]
if isinstance(content, list):
# 多模态内容,需要处理图片
processed_content = []
for item in content:
if isinstance(item, dict):
if item.get("type") == "text":
processed_content.append(item)
elif item.get("type") == "image_url":
# 处理图片上传
try:
image_data = item.get("image_url", {})
url = image_data.get("url", "")
if url:
file_id = (
await self._download_and_upload_image(
url, session_id
)
)
processed_content.append(
{
"type": "file",
"file_id": file_id,
"file_url": url,
}
)
except Exception as e:
logger.warning(f"处理上下文图片失败: {e}")
continue
if processed_content:
additional_messages.append(
{
"role": ctx["role"],
"content": processed_content,
"content_type": "object_string",
}
)
else:
# 纯文本内容
additional_messages.append(
{
"role": ctx["role"],
"content": content,
"content_type": "text",
}
)
# 构建当前消息
if prompt or image_urls:
if image_urls:
# 多模态
object_string_content = []
if prompt:
object_string_content.append({"type": "text", "text": prompt})
for url in image_urls:
# the url is a base64 string
try:
image_data = base64.b64decode(url)
file_id = await self.api_client.upload_file(image_data)
object_string_content.append(
{
"type": "image",
"file_id": file_id,
}
)
except Exception as e:
logger.warning(f"处理图片失败 {url}: {e}")
continue
if object_string_content:
content = json.dumps(object_string_content, ensure_ascii=False)
additional_messages.append(
{
"role": "user",
"content": content,
"content_type": "object_string",
}
)
elif prompt:
# 纯文本
additional_messages.append(
{
"role": "user",
"content": prompt,
"content_type": "text",
},
)
# 执行 Coze API 请求
accumulated_content = ""
message_started = False
async for chunk in self.api_client.chat_messages(
bot_id=self.bot_id,
user_id=user_id,
additional_messages=additional_messages,
conversation_id=conversation_id,
auto_save_history=self.auto_save_history,
stream=True,
timeout=self.timeout,
):
event_type = chunk.get("event")
data = chunk.get("data", {})
if event_type == "conversation.chat.created":
if isinstance(data, dict) and "conversation_id" in data:
await sp.put_async(
scope="umo",
scope_id=user_id,
key="coze_conversation_id",
value=data["conversation_id"],
)
if event_type == "conversation.message.delta":
# 增量消息
content = data.get("content", "")
if not content and "delta" in data:
content = data["delta"].get("content", "")
if not content and "text" in data:
content = data.get("text", "")
if content:
accumulated_content += content
message_started = True
# 如果是流式响应,发送增量数据
if self.streaming:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain().message(content)
),
)
elif event_type == "conversation.message.completed":
# 消息完成
logger.debug("Coze message completed")
message_started = True
elif event_type == "conversation.chat.completed":
# 对话完成
logger.debug("Coze chat completed")
break
elif event_type == "error":
# 错误处理
error_msg = data.get("msg", "未知错误")
error_code = data.get("code", "UNKNOWN")
logger.error(f"Coze 出现错误: {error_code} - {error_msg}")
raise Exception(f"Coze 出现错误: {error_code} - {error_msg}")
if not message_started and not accumulated_content:
logger.warning("Coze 未返回任何内容")
accumulated_content = ""
# 创建最终响应
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain)
self._transition_state(AgentState.DONE)
try:
await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp)
except Exception as e:
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
# 返回最终结果
yield AgentResponse(
type="llm_result",
data=AgentResponseData(chain=chain),
)
async def _download_and_upload_image(
self,
image_url: str,
session_id: str | None = None,
) -> str:
"""下载图片并上传到 Coze,返回 file_id"""
import hashlib
# 计算哈希实现缓存
cache_key = hashlib.md5(image_url.encode("utf-8")).hexdigest()
if session_id:
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
if cache_key in self.file_id_cache[session_id]:
file_id = self.file_id_cache[session_id][cache_key]
logger.debug(f"[Coze] 使用缓存的 file_id: {file_id}")
return file_id
try:
image_data = await self.api_client.download_image(image_url)
file_id = await self.api_client.upload_file(image_data)
if session_id:
self.file_id_cache[session_id][cache_key] = file_id
logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}")
return file_id
except Exception as e:
logger.error(f"处理图片失败 {image_url}: {e!s}")
raise Exception(f"处理图片失败: {e!s}")
@override
def done(self) -> bool:
"""检查 Agent 是否已完成工作"""
return self._state in (AgentState.DONE, AgentState.ERROR)
@override
def get_final_llm_resp(self) -> LLMResponse | None:
return self.final_llm_resp
@@ -0,0 +1,403 @@
import asyncio
import functools
import queue
import re
import sys
import threading
import typing as T
from dashscope import Application
from dashscope.app.application_response import ApplicationResponse
import astrbot.core.message.components as Comp
from astrbot.core import logger, sp
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from ...hooks import BaseAgentRunHooks
from ...response import AgentResponseData
from ...run_context import ContextWrapper, TContext
from ..base import AgentResponse, AgentState, BaseAgentRunner
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
class DashscopeAgentRunner(BaseAgentRunner[TContext]):
"""Dashscope Agent Runner"""
@override
async def reset(
self,
request: ProviderRequest,
run_context: ContextWrapper[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
provider_config: dict,
**kwargs: T.Any,
) -> None:
self.req = request
self.streaming = kwargs.get("streaming", False)
self.final_llm_resp = None
self._state = AgentState.IDLE
self.agent_hooks = agent_hooks
self.run_context = run_context
self.api_key = provider_config.get("dashscope_api_key", "")
if not self.api_key:
raise Exception("阿里云百炼 API Key 不能为空。")
self.app_id = provider_config.get("dashscope_app_id", "")
if not self.app_id:
raise Exception("阿里云百炼 APP ID 不能为空。")
self.dashscope_app_type = provider_config.get("dashscope_app_type", "")
if not self.dashscope_app_type:
raise Exception("阿里云百炼 APP 类型不能为空。")
self.variables: dict = provider_config.get("variables", {}) or {}
self.rag_options: dict = provider_config.get("rag_options", {})
self.output_reference = self.rag_options.get("output_reference", False)
self.rag_options = self.rag_options.copy()
self.rag_options.pop("output_reference", None)
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
def has_rag_options(self):
"""判断是否有 RAG 选项
Returns:
bool: 是否有 RAG 选项
"""
if self.rag_options and (
len(self.rag_options.get("pipeline_ids", [])) > 0
or len(self.rag_options.get("file_ids", [])) > 0
):
return True
return False
@override
async def step(self):
"""
执行 Dashscope Agent 的一个步骤
"""
if not self.req:
raise ValueError("Request is not set. Please call reset() first.")
if self._state == AgentState.IDLE:
try:
await self.agent_hooks.on_agent_begin(self.run_context)
except Exception as e:
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
# 开始处理,转换到运行状态
self._transition_state(AgentState.RUNNING)
try:
# 执行 Dashscope 请求并处理结果
async for response in self._execute_dashscope_request():
yield response
except Exception as e:
logger.error(f"阿里云百炼请求失败:{str(e)}")
self._transition_state(AgentState.ERROR)
self.final_llm_resp = LLMResponse(
role="err", completion_text=f"阿里云百炼请求失败:{str(e)}"
)
yield AgentResponse(
type="err",
data=AgentResponseData(
chain=MessageChain().message(f"阿里云百炼请求失败:{str(e)}")
),
)
@override
async def step_until_done(
self, max_step: int = 30
) -> T.AsyncGenerator[AgentResponse, None]:
while not self.done():
async for resp in self.step():
yield resp
def _consume_sync_generator(
self, response: T.Any, response_queue: queue.Queue
) -> None:
"""在线程中消费同步generator,将结果放入队列
Args:
response: 同步generator对象
response_queue: 用于传递数据的队列
"""
try:
if self.streaming:
for chunk in response:
response_queue.put(("data", chunk))
else:
response_queue.put(("data", response))
except Exception as e:
response_queue.put(("error", e))
finally:
response_queue.put(("done", None))
async def _process_stream_chunk(
self, chunk: ApplicationResponse, output_text: str
) -> tuple[str, list | None, AgentResponse | None]:
"""处理流式响应的单个chunk
Args:
chunk: Dashscope响应chunk
output_text: 当前累积的输出文本
Returns:
(更新后的output_text, doc_references, AgentResponse或None)
"""
logger.debug(f"dashscope stream chunk: {chunk}")
if chunk.status_code != 200:
logger.error(
f"阿里云百炼请求失败: request_id={chunk.request_id}, code={chunk.status_code}, message={chunk.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code",
)
self._transition_state(AgentState.ERROR)
error_msg = (
f"阿里云百炼请求失败: message={chunk.message} code={chunk.status_code}"
)
self.final_llm_resp = LLMResponse(
role="err",
result_chain=MessageChain().message(error_msg),
)
return (
output_text,
None,
AgentResponse(
type="err",
data=AgentResponseData(chain=MessageChain().message(error_msg)),
),
)
chunk_text = chunk.output.get("text", "") or ""
# RAG 引用脚标格式化
chunk_text = re.sub(r"<ref>\[(\d+)\]</ref>", r"[\1]", chunk_text)
response = None
if chunk_text:
output_text += chunk_text
response = AgentResponse(
type="streaming_delta",
data=AgentResponseData(chain=MessageChain().message(chunk_text)),
)
# 获取文档引用
doc_references = chunk.output.get("doc_references", None)
return output_text, doc_references, response
def _format_doc_references(self, doc_references: list) -> str:
"""格式化文档引用为文本
Args:
doc_references: 文档引用列表
Returns:
格式化后的引用文本
"""
ref_parts = []
for ref in doc_references:
ref_title = (
ref.get("title", "") if ref.get("title") else ref.get("doc_name", "")
)
ref_parts.append(f"{ref['index_id']}. {ref_title}\n")
ref_str = "".join(ref_parts)
return f"\n\n回答来源:\n{ref_str}"
async def _build_request_payload(
self, prompt: str, session_id: str, contexts: list, system_prompt: str
) -> dict:
"""构建请求payload
Args:
prompt: 用户输入
session_id: 会话ID
contexts: 上下文列表
system_prompt: 系统提示词
Returns:
请求payload字典
"""
conversation_id = await sp.get_async(
scope="umo",
scope_id=session_id,
key="dashscope_conversation_id",
default="",
)
# 获得会话变量
payload_vars = self.variables.copy()
session_var = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_variables",
default={},
)
payload_vars.update(session_var)
if (
self.dashscope_app_type in ["agent", "dialog-workflow"]
and not self.has_rag_options()
):
# 支持多轮对话的
p = {
"app_id": self.app_id,
"api_key": self.api_key,
"prompt": prompt,
"biz_params": payload_vars or None,
"stream": self.streaming,
"incremental_output": True,
}
if conversation_id:
p["session_id"] = conversation_id
return p
else:
# 不支持多轮对话的
payload = {
"app_id": self.app_id,
"prompt": prompt,
"api_key": self.api_key,
"biz_params": payload_vars or None,
"stream": self.streaming,
"incremental_output": True,
}
if self.rag_options:
payload["rag_options"] = self.rag_options
return payload
async def _handle_streaming_response(
self, response: T.Any, session_id: str
) -> T.AsyncGenerator[AgentResponse, None]:
"""处理流式响应
Args:
response: Dashscope 流式响应 generator
Yields:
AgentResponse 对象
"""
response_queue = queue.Queue()
consumer_thread = threading.Thread(
target=self._consume_sync_generator,
args=(response, response_queue),
daemon=True,
)
consumer_thread.start()
output_text = ""
doc_references = None
while True:
try:
item_type, item_data = await asyncio.get_event_loop().run_in_executor(
None, response_queue.get, True, 1
)
except queue.Empty:
continue
if item_type == "done":
break
elif item_type == "error":
raise item_data
elif item_type == "data":
chunk = item_data
assert isinstance(chunk, ApplicationResponse)
(
output_text,
chunk_doc_refs,
response,
) = await self._process_stream_chunk(chunk, output_text)
if response:
if response.type == "err":
yield response
return
yield response
if chunk_doc_refs:
doc_references = chunk_doc_refs
if chunk.output.session_id:
await sp.put_async(
scope="umo",
scope_id=session_id,
key="dashscope_conversation_id",
value=chunk.output.session_id,
)
# 添加 RAG 引用
if self.output_reference and doc_references:
ref_text = self._format_doc_references(doc_references)
output_text += ref_text
if self.streaming:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(chain=MessageChain().message(ref_text)),
)
# 创建最终响应
chain = MessageChain(chain=[Comp.Plain(output_text)])
self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain)
self._transition_state(AgentState.DONE)
try:
await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp)
except Exception as e:
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
# 返回最终结果
yield AgentResponse(
type="llm_result",
data=AgentResponseData(chain=chain),
)
async def _execute_dashscope_request(self):
"""执行 Dashscope 请求的核心逻辑"""
prompt = self.req.prompt or ""
session_id = self.req.session_id or "unknown"
image_urls = self.req.image_urls or []
contexts = self.req.contexts or []
system_prompt = self.req.system_prompt
# 检查图片输入
if image_urls:
logger.warning("阿里云百炼暂不支持图片输入,将自动忽略图片内容。")
# 构建请求payload
payload = await self._build_request_payload(
prompt, session_id, contexts, system_prompt
)
if not self.streaming:
payload["incremental_output"] = False
# 发起请求
partial = functools.partial(Application.call, **payload)
response = await asyncio.get_event_loop().run_in_executor(None, partial)
async for resp in self._handle_streaming_response(response, session_id):
yield resp
@override
def done(self) -> bool:
"""检查 Agent 是否已完成工作"""
return self._state in (AgentState.DONE, AgentState.ERROR)
@override
def get_final_llm_resp(self) -> LLMResponse | None:
return self.final_llm_resp
@@ -0,0 +1,336 @@
import base64
import os
import sys
import typing as T
import astrbot.core.message.components as Comp
from astrbot.core import logger, sp
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file
from ...hooks import BaseAgentRunHooks
from ...response import AgentResponseData
from ...run_context import ContextWrapper, TContext
from ..base import AgentResponse, AgentState, BaseAgentRunner
from .dify_api_client import DifyAPIClient
if sys.version_info >= (3, 12):
from typing import override
else:
from typing_extensions import override
class DifyAgentRunner(BaseAgentRunner[TContext]):
"""Dify Agent Runner"""
@override
async def reset(
self,
request: ProviderRequest,
run_context: ContextWrapper[TContext],
agent_hooks: BaseAgentRunHooks[TContext],
provider_config: dict,
**kwargs: T.Any,
) -> None:
self.req = request
self.streaming = kwargs.get("streaming", False)
self.final_llm_resp = None
self._state = AgentState.IDLE
self.agent_hooks = agent_hooks
self.run_context = run_context
self.api_key = provider_config.get("dify_api_key", "")
self.api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1")
self.api_type = provider_config.get("dify_api_type", "chat")
self.workflow_output_key = provider_config.get(
"dify_workflow_output_key",
"astrbot_wf_output",
)
self.dify_query_input_key = provider_config.get(
"dify_query_input_key",
"astrbot_text_query",
)
self.variables: dict = provider_config.get("variables", {}) or {}
self.timeout = provider_config.get("timeout", 60)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.api_client = DifyAPIClient(self.api_key, self.api_base)
@override
async def step(self):
"""
执行 Dify Agent 的一个步骤
"""
if not self.req:
raise ValueError("Request is not set. Please call reset() first.")
if self._state == AgentState.IDLE:
try:
await self.agent_hooks.on_agent_begin(self.run_context)
except Exception as e:
logger.error(f"Error in on_agent_begin hook: {e}", exc_info=True)
# 开始处理,转换到运行状态
self._transition_state(AgentState.RUNNING)
try:
# 执行 Dify 请求并处理结果
async for response in self._execute_dify_request():
yield response
except Exception as e:
logger.error(f"Dify 请求失败:{str(e)}")
self._transition_state(AgentState.ERROR)
self.final_llm_resp = LLMResponse(
role="err", completion_text=f"Dify 请求失败:{str(e)}"
)
yield AgentResponse(
type="err",
data=AgentResponseData(
chain=MessageChain().message(f"Dify 请求失败:{str(e)}")
),
)
finally:
await self.api_client.close()
@override
async def step_until_done(
self, max_step: int = 30
) -> T.AsyncGenerator[AgentResponse, None]:
while not self.done():
async for resp in self.step():
yield resp
async def _execute_dify_request(self):
"""执行 Dify 请求的核心逻辑"""
prompt = self.req.prompt or ""
session_id = self.req.session_id or "unknown"
image_urls = self.req.image_urls or []
system_prompt = self.req.system_prompt
conversation_id = await sp.get_async(
scope="umo",
scope_id=session_id,
key="dify_conversation_id",
default="",
)
result = ""
# 处理图片上传
files_payload = []
for image_url in image_urls:
# image_url is a base64 string
try:
image_data = base64.b64decode(image_url)
file_response = await self.api_client.file_upload(
file_data=image_data,
user=session_id,
mime_type="image/png",
file_name="image.png",
)
logger.debug(f"Dify 上传图片响应:{file_response}")
if "id" not in file_response:
logger.warning(
f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。"
)
continue
files_payload.append(
{
"type": "image",
"transfer_method": "local_file",
"upload_file_id": file_response["id"],
}
)
except Exception as e:
logger.warning(f"上传图片失败:{e}")
continue
# 获得会话变量
payload_vars = self.variables.copy()
# 动态变量
session_var = await sp.get_async(
scope="umo",
scope_id=session_id,
key="session_variables",
default={},
)
payload_vars.update(session_var)
payload_vars["system_prompt"] = system_prompt
# 处理不同的 API 类型
match self.api_type:
case "chat" | "agent" | "chatflow":
if not prompt:
prompt = "请描述这张图片。"
async for chunk in self.api_client.chat_messages(
inputs={
**payload_vars,
},
query=prompt,
user=session_id,
conversation_id=conversation_id,
files=files_payload,
timeout=self.timeout,
):
logger.debug(f"dify resp chunk: {chunk}")
if chunk["event"] == "message" or chunk["event"] == "agent_message":
result += chunk["answer"]
if not conversation_id:
await sp.put_async(
scope="umo",
scope_id=session_id,
key="dify_conversation_id",
value=chunk["conversation_id"],
)
conversation_id = chunk["conversation_id"]
# 如果是流式响应,发送增量数据
if self.streaming and chunk["answer"]:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain().message(chunk["answer"])
),
)
elif chunk["event"] == "message_end":
logger.debug("Dify message end")
break
elif chunk["event"] == "error":
logger.error(f"Dify 出现错误:{chunk}")
raise Exception(
f"Dify 出现错误 status: {chunk['status']} message: {chunk['message']}"
)
case "workflow":
async for chunk in self.api_client.workflow_run(
inputs={
self.dify_query_input_key: prompt,
"astrbot_session_id": session_id,
**payload_vars,
},
user=session_id,
files=files_payload,
timeout=self.timeout,
):
logger.debug(f"dify workflow resp chunk: {chunk}")
match chunk["event"]:
case "workflow_started":
logger.info(
f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。"
)
case "node_finished":
logger.debug(
f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。"
)
case "text_chunk":
if self.streaming and chunk["data"]["text"]:
yield AgentResponse(
type="streaming_delta",
data=AgentResponseData(
chain=MessageChain().message(
chunk["data"]["text"]
)
),
)
case "workflow_finished":
logger.info(
f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束"
)
logger.debug(f"Dify 工作流结果:{chunk}")
if chunk["data"]["error"]:
logger.error(
f"Dify 工作流出现错误:{chunk['data']['error']}"
)
raise Exception(
f"Dify 工作流出现错误:{chunk['data']['error']}"
)
if self.workflow_output_key not in chunk["data"]["outputs"]:
raise Exception(
f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}"
)
result = chunk
case _:
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
if not result:
logger.warning("Dify 请求结果为空,请查看 Debug 日志。")
# 解析结果
chain = await self.parse_dify_result(result)
# 创建最终响应
self.final_llm_resp = LLMResponse(role="assistant", result_chain=chain)
self._transition_state(AgentState.DONE)
try:
await self.agent_hooks.on_agent_done(self.run_context, self.final_llm_resp)
except Exception as e:
logger.error(f"Error in on_agent_done hook: {e}", exc_info=True)
# 返回最终结果
yield AgentResponse(
type="llm_result",
data=AgentResponseData(chain=chain),
)
async def parse_dify_result(self, chunk: dict | str) -> MessageChain:
"""解析 Dify 的响应结果"""
if isinstance(chunk, str):
# Chat
return MessageChain(chain=[Comp.Plain(chunk)])
async def parse_file(item: dict):
match item["type"]:
case "image":
return Comp.Image(file=item["url"], url=item["url"])
case "audio":
# 仅支持 wav
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, f"{item['filename']}.wav")
await download_file(item["url"], path)
return Comp.Image(file=item["url"], url=item["url"])
case "video":
return Comp.Video(file=item["url"])
case _:
return Comp.File(name=item["filename"], file=item["url"])
output = chunk["data"]["outputs"][self.workflow_output_key]
chains = []
if isinstance(output, str):
# 纯文本输出
chains.append(Comp.Plain(output))
elif isinstance(output, list):
# 主要适配 Dify 的 HTTP 请求结点的多模态输出
for item in output:
# handle Array[File]
if (
not isinstance(item, dict)
or item.get("dify_model_identity", "") != "__dify__file__"
):
chains.append(Comp.Plain(str(output)))
break
else:
chains.append(Comp.Plain(str(output)))
# scan file
files = chunk["data"].get("files", [])
for item in files:
comp = await parse_file(item)
chains.append(comp)
return MessageChain(chain=chains)
@override
def done(self) -> bool:
"""检查 Agent 是否已完成工作"""
return self._state in (AgentState.DONE, AgentState.ERROR)
@override
def get_final_llm_resp(self) -> LLMResponse | None:
return self.final_llm_resp
@@ -3,7 +3,7 @@ import json
from collections.abc import AsyncGenerator
from typing import Any
from aiohttp import ClientResponse, ClientSession
from aiohttp import ClientResponse, ClientSession, FormData
from astrbot.core import logger
@@ -101,21 +101,59 @@ class DifyAPIClient:
async def file_upload(
self,
file_path: str,
user: str,
file_path: str | None = None,
file_data: bytes | None = None,
file_name: str | None = None,
mime_type: str | None = None,
) -> dict[str, Any]:
"""Upload a file to Dify. Must provide either file_path or file_data.
Args:
user: The user ID.
file_path: The path to the file to upload.
file_data: The file data in bytes.
file_name: Optional file name when using file_data.
Returns:
A dictionary containing the uploaded file information.
"""
url = f"{self.api_base}/files/upload"
with open(file_path, "rb") as f:
payload = {
"user": user,
"file": f,
}
async with self.session.post(
url,
data=payload,
headers=self.headers,
) as resp:
return await resp.json() # {"id": "xxx", ...}
form = FormData()
form.add_field("user", user)
if file_data is not None:
# 使用 bytes 数据
form.add_field(
"file",
file_data,
filename=file_name or "uploaded_file",
content_type=mime_type or "application/octet-stream",
)
elif file_path is not None:
# 使用文件路径
import os
with open(file_path, "rb") as f:
file_content = f.read()
form.add_field(
"file",
file_content,
filename=os.path.basename(file_path),
content_type=mime_type or "application/octet-stream",
)
else:
raise ValueError("file_path 和 file_data 不能同时为 None")
async with self.session.post(
url,
data=form,
headers=self.headers, # 不包含 Content-Type,让 aiohttp 自动设置
) as resp:
if resp.status != 200 and resp.status != 201:
text = await resp.text()
raise Exception(f"Dify 文件上传失败:{resp.status}. {text}")
return await resp.json() # {"id": "xxx", ...}
async def close(self):
await self.session.close()
@@ -69,12 +69,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
)
self.run_context.messages = messages
def _transition_state(self, new_state: AgentState) -> None:
"""转换 Agent 状态"""
if self._state != new_state:
logger.debug(f"Agent state transition: {self._state} -> {new_state}")
self._state = new_state
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
"""Yields chunks *and* a final LLMResponse."""
if self.streaming:
@@ -103,7 +97,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
llm_resp_result = None
async for llm_response in self._iter_llm_responses():
assert isinstance(llm_response, LLMResponse)
if llm_response.is_chunk:
if llm_response.result_chain:
yield AgentResponse(
+7 -2
View File
@@ -1,4 +1,4 @@
from collections.abc import Awaitable, Callable
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any, Generic
import jsonschema
@@ -7,6 +7,8 @@ from deprecated import deprecated
from pydantic import Field, model_validator
from pydantic.dataclasses import dataclass
from astrbot.core.message.message_event_result import MessageEventResult
from .run_context import ContextWrapper, TContext
ParametersType = dict[str, Any]
@@ -38,7 +40,10 @@ class ToolSchema:
class FunctionTool(ToolSchema, Generic[TContext]):
"""A callable tool, for function calling."""
handler: Callable[..., Awaitable[Any]] | None = None
handler: (
Callable[..., Awaitable[str | None] | AsyncGenerator[MessageEventResult, None]]
| None
) = None
"""a callable that implements the tool's functionality. It should be an async function."""
handler_module_path: str | None = None
+15 -1
View File
@@ -9,6 +9,7 @@ from astrbot.core.message.message_event_result import (
MessageEventResult,
ResultContentType,
)
from astrbot.core.provider.entities import LLMResponse
AgentRunner = ToolLoopAgentRunner[AstrAgentContext]
@@ -72,7 +73,20 @@ async def run_agent(
except Exception as e:
logger.error(traceback.format_exc())
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n"
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在平台日志查看和分享错误详情。\n"
error_llm_response = LLMResponse(
role="err",
completion_text=err_msg,
)
try:
await agent_runner.agent_hooks.on_agent_done(
agent_runner.run_context, error_llm_response
)
except Exception:
logger.exception("Error in on_agent_done hook")
if agent_runner.streaming:
yield MessageChain().message(err_msg)
else:
+5 -1
View File
@@ -185,7 +185,11 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
async def call_local_llm_tool(
context: ContextWrapper[AstrAgentContext],
handler: T.Callable[..., T.Awaitable[T.Any]],
handler: T.Callable[
...,
T.Awaitable[MessageEventResult | mcp.types.CallToolResult | str | None]
| T.AsyncGenerator[MessageEventResult | CommandResult | str | None, None],
],
method_name: str,
*args,
**kwargs,
+4
View File
@@ -24,6 +24,10 @@ class AstrBotConfig(dict):
- 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。
"""
config_path: str
default_config: dict
schema: dict | None
def __init__(
self,
config_path: str = ASTRBOT_CONFIG_PATH,
+325 -27
View File
@@ -4,9 +4,18 @@ import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.6.1"
VERSION = "4.9.0"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
WEBHOOK_SUPPORTED_PLATFORMS = [
"qq_official_webhook",
"weixin_official_account",
"wecom",
"wecom_ai_bot",
"slack",
"lark",
]
# 默认配置
DEFAULT_CONFIG = {
"config_version": 2,
@@ -34,7 +43,15 @@ DEFAULT_CONFIG = {
"interval": "1.5,3.5",
"log_base": 2.6,
"words_count_threshold": 150,
"split_mode": "regex", # regex 或 words
"regex": ".*?[。?!~…]+|.+$",
"split_words": [
"",
"",
"",
"~",
"",
], # 当 split_mode 为 words 时使用
"content_cleanup_rule": "",
},
"no_permission_reply": True,
@@ -68,9 +85,19 @@ DEFAULT_CONFIG = {
"dequeue_context_length": 1,
"streaming_response": False,
"show_tool_use_status": False,
"agent_runner_type": "local",
"dify_agent_runner_provider_id": "",
"coze_agent_runner_provider_id": "",
"dashscope_agent_runner_provider_id": "",
"unsupported_streaming_strategy": "realtime_segmenting",
"reachability_check": False,
"max_agent_step": 30,
"tool_call_timeout": 60,
"file_extract": {
"enable": False,
"provider": "moonshotai",
"moonshotai_api_key": "",
},
},
"provider_stt_settings": {
"enable": False,
@@ -86,6 +113,7 @@ DEFAULT_CONFIG = {
"group_icl_enable": False,
"group_message_max_cnt": 300,
"image_caption": False,
"image_caption_provider_id": "",
"active_reply": {
"enable": False,
"method": "possibility_reply",
@@ -138,10 +166,20 @@ DEFAULT_CONFIG = {
"kb_fusion_top_k": 20, # 知识库检索融合阶段返回结果数量
"kb_final_top_k": 5, # 知识库检索最终返回结果数量
"kb_agentic_mode": False,
"disable_builtin_commands": False,
}
# 配置项的中文描述、值类型
"""
AstrBot v3 时代的配置元数据,目前仅承担以下功能:
1. 保存配置时,配置项的类型验证
2. WebUI 展示提供商和平台适配器模版
WebUI 的配置文件在 `CONFIG_METADATA_3` 中。
未来将会逐步淘汰此配置元数据。
"""
CONFIG_METADATA_2 = {
"platform_group": {
"metadata": {
@@ -165,6 +203,8 @@ CONFIG_METADATA_2 = {
"appid": "",
"secret": "",
"is_sandbox": False,
"unified_webhook_mode": True,
"webhook_uuid": "",
"callback_server_host": "0.0.0.0",
"port": 6196,
},
@@ -195,6 +235,8 @@ CONFIG_METADATA_2 = {
"token": "",
"encoding_aes_key": "",
"api_base_url": "https://api.weixin.qq.com/cgi-bin/",
"unified_webhook_mode": True,
"webhook_uuid": "",
"callback_server_host": "0.0.0.0",
"port": 6194,
"active_send_mode": False,
@@ -209,6 +251,8 @@ CONFIG_METADATA_2 = {
"encoding_aes_key": "",
"kf_name": "",
"api_base_url": "https://qyapi.weixin.qq.com/cgi-bin/",
"unified_webhook_mode": True,
"webhook_uuid": "",
"callback_server_host": "0.0.0.0",
"port": 6195,
},
@@ -221,6 +265,8 @@ CONFIG_METADATA_2 = {
"wecom_ai_bot_name": "",
"token": "",
"encoding_aes_key": "",
"unified_webhook_mode": True,
"webhook_uuid": "",
"callback_server_host": "0.0.0.0",
"port": 6198,
},
@@ -232,6 +278,10 @@ CONFIG_METADATA_2 = {
"app_id": "",
"app_secret": "",
"domain": "https://open.feishu.cn",
"lark_connection_mode": "socket", # webhook, socket
"webhook_uuid": "",
"lark_encrypt_key": "",
"lark_verification_token": "",
},
"钉钉(DingTalk)": {
"id": "dingtalk",
@@ -288,6 +338,8 @@ CONFIG_METADATA_2 = {
"app_token": "",
"signing_secret": "",
"slack_connection_mode": "socket", # webhook, socket
"unified_webhook_mode": True,
"webhook_uuid": "",
"slack_webhook_host": "0.0.0.0",
"slack_webhook_port": 6197,
"slack_webhook_path": "/astrbot-slack-webhook/callback",
@@ -323,6 +375,28 @@ CONFIG_METADATA_2 = {
# "type": "string",
# "options": ["fullscreen", "embedded"],
# },
"lark_connection_mode": {
"description": "订阅方式",
"type": "string",
"options": ["socket", "webhook"],
"labels": ["长连接模式", "推送至服务器模式"],
},
"lark_encrypt_key": {
"description": "Encrypt Key",
"type": "string",
"hint": "用于解密飞书回调数据的加密密钥",
"condition": {
"lark_connection_mode": "webhook",
},
},
"lark_verification_token": {
"description": "Verification Token",
"type": "string",
"hint": "用于验证飞书回调请求的令牌",
"condition": {
"lark_connection_mode": "webhook",
},
},
"is_sandbox": {
"description": "沙箱模式",
"type": "bool",
@@ -367,16 +441,28 @@ CONFIG_METADATA_2 = {
"description": "Slack Webhook Host",
"type": "string",
"hint": "Only valid when Slack connection mode is `webhook`.",
"condition": {
"slack_connection_mode": "webhook",
"unified_webhook_mode": False,
},
},
"slack_webhook_port": {
"description": "Slack Webhook Port",
"type": "int",
"hint": "Only valid when Slack connection mode is `webhook`.",
"condition": {
"slack_connection_mode": "webhook",
"unified_webhook_mode": False,
},
},
"slack_webhook_path": {
"description": "Slack Webhook Path",
"type": "string",
"hint": "Only valid when Slack connection mode is `webhook`.",
"condition": {
"slack_connection_mode": "webhook",
"unified_webhook_mode": False,
},
},
"active_send_mode": {
"description": "是否换用主动发送接口",
@@ -567,6 +653,33 @@ CONFIG_METADATA_2 = {
"type": "string",
"hint": "可选的 Discord 活动名称。留空则不设置活动。",
},
"port": {
"description": "回调服务器端口",
"type": "int",
"hint": "回调服务器端口。留空则不启用回调服务器。",
"condition": {
"unified_webhook_mode": False,
},
},
"callback_server_host": {
"description": "回调服务器主机",
"type": "string",
"hint": "回调服务器主机。留空则不启用回调服务器。",
"condition": {
"unified_webhook_mode": False,
},
},
"unified_webhook_mode": {
"description": "统一 Webhook 模式",
"type": "bool",
"hint": "启用后,将使用 AstrBot 统一 Webhook 入口,无需单独开启端口。回调地址为 /api/platform/webhook/{webhook_uuid}",
},
"webhook_uuid": {
"invisible": True,
"description": "Webhook UUID",
"type": "string",
"hint": "统一 Webhook 模式下的唯一标识符,创建平台时自动生成。",
},
},
},
"platform_settings": {
@@ -634,7 +747,7 @@ CONFIG_METADATA_2 = {
},
"words_count_threshold": {
"type": "int",
"hint": "超过这个字数的消息会被分段回复。默认为 150",
"hint": "分段回复的字数上限。只有字数小于此值的消息会被分段,超过此值的长消息将直接发送(不分段)。默认为 150",
},
"regex": {
"type": "string",
@@ -1011,7 +1124,7 @@ CONFIG_METADATA_2 = {
"id": "dify_app_default",
"provider": "dify",
"type": "dify",
"provider_type": "chat_completion",
"provider_type": "agent_runner",
"enable": True,
"dify_api_type": "chat",
"dify_api_key": "",
@@ -1025,20 +1138,20 @@ CONFIG_METADATA_2 = {
"Coze": {
"id": "coze",
"provider": "coze",
"provider_type": "chat_completion",
"provider_type": "agent_runner",
"type": "coze",
"enable": True,
"coze_api_key": "",
"bot_id": "",
"coze_api_base": "https://api.coze.cn",
"timeout": 60,
"auto_save_history": True,
# "auto_save_history": True,
},
"阿里云百炼应用": {
"id": "dashscope",
"provider": "dashscope",
"type": "dashscope",
"provider_type": "chat_completion",
"provider_type": "agent_runner",
"enable": True,
"dashscope_app_type": "agent",
"dashscope_api_key": "",
@@ -1087,7 +1200,7 @@ CONFIG_METADATA_2 = {
"api_base": "",
"model": "whisper-1",
},
"Whisper(本地加载)": {
"Whisper(Local)": {
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cudaCPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
"provider": "openai",
"type": "openai_whisper_selfhost",
@@ -1096,7 +1209,7 @@ CONFIG_METADATA_2 = {
"id": "whisper_selfhost",
"model": "tiny",
},
"SenseVoice(本地加载)": {
"SenseVoice(Local)": {
"hint": "启用前请 pip 安装 funasr、funasr_onnx、torchaudio、torch、modelscope、jieba 库(默认使用CPU,大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
"type": "sensevoice_stt_selfhost",
"provider": "sensevoice",
@@ -1131,7 +1244,7 @@ CONFIG_METADATA_2 = {
"pitch": "+0Hz",
"timeout": 20,
},
"GSV TTS(本地加载)": {
"GSV TTS(Local)": {
"id": "gsv_tts",
"enable": False,
"provider": "gpt_sovits",
@@ -1907,7 +2020,6 @@ CONFIG_METADATA_2 = {
"enable": {
"description": "启用",
"type": "bool",
"hint": "是否启用。",
},
"key": {
"description": "API Key",
@@ -2037,14 +2149,38 @@ CONFIG_METADATA_2 = {
"unsupported_streaming_strategy": {
"type": "string",
},
"agent_runner_type": {
"type": "string",
},
"dify_agent_runner_provider_id": {
"type": "string",
},
"coze_agent_runner_provider_id": {
"type": "string",
},
"dashscope_agent_runner_provider_id": {
"type": "string",
},
"max_agent_step": {
"description": "工具调用轮数上限",
"type": "int",
},
"tool_call_timeout": {
"description": "工具调用超时时间(秒)",
"type": "int",
},
"file_extract": {
"type": "object",
"items": {
"enable": {
"type": "bool",
},
"provider": {
"type": "string",
},
"moonshotai_api_key": {
"type": "string",
},
},
},
},
},
"provider_stt_settings": {
@@ -2087,6 +2223,9 @@ CONFIG_METADATA_2 = {
"image_caption": {
"type": "bool",
},
"image_caption_provider_id": {
"type": "string",
},
"image_caption_prompt": {
"type": "string",
},
@@ -2176,34 +2315,87 @@ CONFIG_METADATA_2 = {
}
"""
v4.7.0 之后,name, description, hint 等字段已经实现 i18n 国际化。国际化资源文件位于:
- dashboard/src/i18n/locales/en-US/features/config-metadata.json
- dashboard/src/i18n/locales/zh-CN/features/config-metadata.json
如果在此文件中添加了新的配置字段,请务必同步更新上述两个国际化资源文件。
"""
CONFIG_METADATA_3 = {
"ai_group": {
"name": "AI 配置",
"metadata": {
"ai": {
"description": "模型",
"agent_runner": {
"description": "Agent 执行方式",
"hint": "选择 AI 对话的执行器,默认为 AstrBot 内置 Agent 执行器,可使用 AstrBot 内的知识库、人格、工具调用功能。如果不打算接入 Dify 或 Coze 等第三方 Agent 执行器,不需要修改此节。",
"type": "object",
"items": {
"provider_settings.enable": {
"description": "启用大语言模型聊天",
"description": "启用",
"type": "bool",
"hint": "AI 对话总开关",
},
"provider_settings.agent_runner_type": {
"description": "执行器",
"type": "string",
"options": ["local", "dify", "coze", "dashscope"],
"labels": ["内置 Agent", "Dify", "Coze", "阿里云百炼应用"],
"condition": {
"provider_settings.enable": True,
},
},
"provider_settings.coze_agent_runner_provider_id": {
"description": "Coze Agent 执行器提供商 ID",
"type": "string",
"_special": "select_agent_runner_provider:coze",
"condition": {
"provider_settings.agent_runner_type": "coze",
"provider_settings.enable": True,
},
},
"provider_settings.dify_agent_runner_provider_id": {
"description": "Dify Agent 执行器提供商 ID",
"type": "string",
"_special": "select_agent_runner_provider:dify",
"condition": {
"provider_settings.agent_runner_type": "dify",
"provider_settings.enable": True,
},
},
"provider_settings.dashscope_agent_runner_provider_id": {
"description": "阿里云百炼应用 Agent 执行器提供商 ID",
"type": "string",
"_special": "select_agent_runner_provider:dashscope",
"condition": {
"provider_settings.agent_runner_type": "dashscope",
"provider_settings.enable": True,
},
},
},
},
"ai": {
"description": "模型",
"hint": "当使用非内置 Agent 执行器时,默认聊天模型和默认图片转述模型可能会无效,但某些插件会依赖此配置项来调用 AI 能力。",
"type": "object",
"items": {
"provider_settings.default_provider_id": {
"description": "默认聊天模型",
"type": "string",
"_special": "select_provider",
"hint": "留空时使用第一个模型",
"hint": "留空时使用第一个模型",
},
"provider_settings.default_image_caption_provider_id": {
"description": "默认图片转述模型",
"type": "string",
"_special": "select_provider",
"hint": "留空代表不使用可用于不支持视觉模态的聊天模型",
"hint": "留空代表不使用可用于非多模态模型",
},
"provider_stt_settings.enable": {
"description": "启用语音转文本",
"type": "bool",
"hint": "STT 总开关",
"hint": "STT 总开关",
},
"provider_stt_settings.provider_id": {
"description": "默认语音转文本模型",
@@ -2217,12 +2409,11 @@ CONFIG_METADATA_3 = {
"provider_tts_settings.enable": {
"description": "启用文本转语音",
"type": "bool",
"hint": "TTS 总开关。当关闭时,会话启用 TTS 也不会生效。",
"hint": "TTS 总开关",
},
"provider_tts_settings.provider_id": {
"description": "默认文本转语音模型",
"type": "string",
"hint": "用户也可使用 /provider 单独选择会话的 TTS 模型。",
"_special": "select_provider_tts",
"condition": {
"provider_tts_settings.enable": True,
@@ -2233,6 +2424,9 @@ CONFIG_METADATA_3 = {
"type": "text",
},
},
"condition": {
"provider_settings.enable": True,
},
},
"persona": {
"description": "人格",
@@ -2244,6 +2438,10 @@ CONFIG_METADATA_3 = {
"_special": "select_persona",
},
},
"condition": {
"provider_settings.agent_runner_type": "local",
"provider_settings.enable": True,
},
},
"knowledgebase": {
"description": "知识库",
@@ -2272,6 +2470,10 @@ CONFIG_METADATA_3 = {
"hint": "启用后,知识库检索将作为 LLM Tool,由模型自主决定何时调用知识库进行查询。需要模型支持函数调用能力。",
},
},
"condition": {
"provider_settings.agent_runner_type": "local",
"provider_settings.enable": True,
},
},
"websearch": {
"description": "网页搜索",
@@ -2308,7 +2510,41 @@ CONFIG_METADATA_3 = {
"type": "bool",
},
},
"condition": {
"provider_settings.agent_runner_type": "local",
"provider_settings.enable": True,
},
},
# "file_extract": {
# "description": "文档解析能力 [beta]",
# "type": "object",
# "items": {
# "provider_settings.file_extract.enable": {
# "description": "启用文档解析能力",
# "type": "bool",
# },
# "provider_settings.file_extract.provider": {
# "description": "文档解析提供商",
# "type": "string",
# "options": ["moonshotai"],
# "condition": {
# "provider_settings.file_extract.enable": True,
# },
# },
# "provider_settings.file_extract.moonshotai_api_key": {
# "description": "Moonshot AI API Key",
# "type": "string",
# "condition": {
# "provider_settings.file_extract.provider": "moonshotai",
# "provider_settings.file_extract.enable": True,
# },
# },
# },
# "condition": {
# "provider_settings.agent_runner_type": "local",
# "provider_settings.enable": True,
# },
# },
"others": {
"description": "其他配置",
"type": "object",
@@ -2316,34 +2552,51 @@ CONFIG_METADATA_3 = {
"provider_settings.display_reasoning_text": {
"description": "显示思考内容",
"type": "bool",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.identifier": {
"description": "用户识别",
"type": "bool",
"hint": "启用后,会在提示词前包含用户 ID 信息。",
},
"provider_settings.group_name_display": {
"description": "显示群名称",
"type": "bool",
"hint": "启用后,在支持的平台(aiocqhttp)上会在 prompt 中包含群名称信息。",
"hint": "启用后,在支持的平台(OneBot v11)上会在提示词前包含群名称信息。",
},
"provider_settings.datetime_system_prompt": {
"description": "现实世界时间感知",
"type": "bool",
"hint": "启用后,会在系统提示词中附带当前时间信息。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.show_tool_use_status": {
"description": "输出函数调用状态",
"type": "bool",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.max_agent_step": {
"description": "工具调用轮数上限",
"type": "int",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.tool_call_timeout": {
"description": "工具调用超时时间(秒)",
"type": "int",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.streaming_response": {
"description": "流式回复",
"description": "流式输出",
"type": "bool",
},
"provider_settings.unsupported_streaming_strategy": {
@@ -2359,17 +2612,23 @@ CONFIG_METADATA_3 = {
"provider_settings.max_context_length": {
"description": "最多携带对话轮数",
"type": "int",
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条-1 为不限制",
"hint": "超出这个数量时丢弃最旧的部分,一轮聊天记为 1 条-1 为不限制",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.dequeue_context_length": {
"description": "丢弃对话轮数",
"type": "int",
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
"hint": "超出最多携带对话轮数时, 一次丢弃的聊天轮数",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.wake_prefix": {
"description": "LLM 聊天额外唤醒前缀 ",
"type": "string",
"hint": "如果唤醒前缀为 `/`, 额外聊天唤醒前缀为 `chat`,则需要 `/chat` 才会触发 LLM 请求。默认为空。",
"hint": "如果唤醒前缀为 /, 额外聊天唤醒前缀为 chat,则需要 /chat 才会触发 LLM 请求",
},
"provider_settings.prompt_prefix": {
"description": "用户提示词",
@@ -2380,6 +2639,14 @@ CONFIG_METADATA_3 = {
"description": "开启 TTS 时同时输出语音和文字内容",
"type": "bool",
},
"provider_settings.reachability_check": {
"description": "提供商可达性检测",
"type": "bool",
"hint": "/provider 命令列出模型时是否并发检测连通性。开启后会主动调用模型测试连通性,可能产生额外 token 消耗。",
},
},
"condition": {
"provider_settings.enable": True,
},
},
},
@@ -2430,6 +2697,11 @@ CONFIG_METADATA_3 = {
"description": "只 @ 机器人是否触发等待",
"type": "bool",
},
"disable_builtin_commands": {
"description": "禁用自带指令",
"type": "bool",
"hint": "禁用所有 AstrBot 的自带指令,如 help, provider, model 等。",
},
},
},
"whitelist": {
@@ -2644,9 +2916,26 @@ CONFIG_METADATA_3 = {
"description": "分段回复字数阈值",
"type": "int",
},
"platform_settings.segmented_reply.split_mode": {
"description": "分段模式",
"type": "string",
"options": ["regex", "words"],
"labels": ["正则表达式", "分段词列表"],
},
"platform_settings.segmented_reply.regex": {
"description": "分段正则表达式",
"type": "string",
"condition": {
"platform_settings.segmented_reply.split_mode": "regex",
},
},
"platform_settings.segmented_reply.split_words": {
"description": "分段词列表",
"type": "list",
"hint": "检测到列表中的任意词时进行分段,如:。、?、!等",
"condition": {
"platform_settings.segmented_reply.split_mode": "words",
},
},
"platform_settings.segmented_reply.content_cleanup_rule": {
"description": "内容过滤正则表达式",
@@ -2670,7 +2959,16 @@ CONFIG_METADATA_3 = {
"provider_ltm_settings.image_caption": {
"description": "自动理解图片",
"type": "bool",
"hint": "需要设置默认图片转述模型。",
"hint": "需要设置群聊图片转述模型。",
},
"provider_ltm_settings.image_caption_provider_id": {
"description": "群聊图片转述模型",
"type": "string",
"_special": "select_provider",
"hint": "用于群聊上下文感知的图片理解,与默认图片转述模型分开配置。",
"condition": {
"provider_ltm_settings.image_caption": True,
},
},
"provider_ltm_settings.active_reply.enable": {
"description": "主动回复",
+110
View File
@@ -0,0 +1,110 @@
"""
配置元数据国际化工具
提供配置元数据的国际化键转换功能
"""
from typing import Any
class ConfigMetadataI18n:
"""配置元数据国际化转换器"""
@staticmethod
def _get_i18n_key(group: str, section: str, field: str, attr: str) -> str:
"""
生成国际化键
Args:
group: 配置组,如 'ai_group', 'platform_group'
section: 配置节,如 'agent_runner', 'general'
field: 字段名,如 'enable', 'default_provider'
attr: 属性类型,如 'description', 'hint', 'labels'
Returns:
国际化键,格式如: 'ai_group.agent_runner.enable.description'
"""
if field:
return f"{group}.{section}.{field}.{attr}"
else:
return f"{group}.{section}.{attr}"
@staticmethod
def convert_to_i18n_keys(metadata: dict[str, Any]) -> dict[str, Any]:
"""
将配置元数据转换为使用国际化键
Args:
metadata: 原始配置元数据字典
Returns:
使用国际化键的配置元数据字典
"""
result = {}
for group_key, group_data in metadata.items():
group_result = {
"name": f"{group_key}.name",
"metadata": {},
}
for section_key, section_data in group_data.get("metadata", {}).items():
section_result = {
"description": f"{group_key}.{section_key}.description",
"type": section_data.get("type"),
}
# 复制其他属性
for key in ["items", "condition", "_special", "invisible"]:
if key in section_data:
section_result[key] = section_data[key]
# 处理 hint
if "hint" in section_data:
section_result["hint"] = f"{group_key}.{section_key}.hint"
# 处理 items 中的字段
if "items" in section_data and isinstance(section_data["items"], dict):
items_result = {}
for field_key, field_data in section_data["items"].items():
# 处理嵌套的点号字段名(如 provider_settings.enable
field_name = field_key
field_result = {}
# 复制基本属性
for attr in [
"type",
"condition",
"_special",
"invisible",
"options",
]:
if attr in field_data:
field_result[attr] = field_data[attr]
# 转换文本属性为国际化键
if "description" in field_data:
field_result["description"] = (
f"{group_key}.{section_key}.{field_name}.description"
)
if "hint" in field_data:
field_result["hint"] = (
f"{group_key}.{section_key}.{field_name}.hint"
)
if "labels" in field_data:
field_result["labels"] = (
f"{group_key}.{section_key}.{field_name}.labels"
)
items_result[field_key] = field_result
section_result["items"] = items_result
group_result["metadata"][section_key] = section_result
result[group_key] = group_result
return result
+12 -14
View File
@@ -16,13 +16,12 @@ import time
import traceback
from asyncio import Queue
from astrbot.core import LogBroker, logger, sp
from astrbot.api import logger, sp
from astrbot.core import LogBroker
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.config.default import VERSION
from astrbot.core.conversation_mgr import ConversationManager
from astrbot.core.db import BaseDatabase
from astrbot.core.db.migration.migra_45_to_46 import migrate_45_to_46
from astrbot.core.db.migration.migra_webchat_session import migrate_webchat_session
from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager
from astrbot.core.persona_mgr import PersonaManager
from astrbot.core.pipeline.scheduler import PipelineContext, PipelineScheduler
@@ -34,6 +33,7 @@ from astrbot.core.star.context import Context
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
from astrbot.core.umop_config_router import UmopConfigRouter
from astrbot.core.updator import AstrBotUpdator
from astrbot.core.utils.migra_helper import migra
from . import astrbot_config, html_renderer
from .event_bus import EventBus
@@ -97,18 +97,16 @@ class AstrBotCoreLifecycle:
sp=sp,
)
# 4.5 to 4.6 migration for umop_config_router
# apply migration
try:
await migrate_45_to_46(self.astrbot_config_mgr, self.umop_config_router)
await migra(
self.db,
self.astrbot_config_mgr,
self.umop_config_router,
self.astrbot_config_mgr,
)
except Exception as e:
logger.error(f"Migration from version 4.5 to 4.6 failed: {e!s}")
logger.error(traceback.format_exc())
# migration for webchat session
try:
await migrate_webchat_session(self.db)
except Exception as e:
logger.error(f"Migration for webchat session failed: {e!s}")
logger.error(f"AstrBot migration failed: {e!s}")
logger.error(traceback.format_exc())
# 初始化事件队列
@@ -199,7 +197,7 @@ class AstrBotCoreLifecycle:
# 把插件中注册的所有协程函数注册到事件总线中并执行
extra_tasks = []
for task in self.star_context._register_tasks:
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
extra_tasks.append(asyncio.create_task(task, name=task.__name__)) # type: ignore
tasks_ = [event_bus_task, *extra_tasks]
for task in tasks_:
+32 -4
View File
@@ -5,8 +5,7 @@ from contextlib import asynccontextmanager
from dataclasses import dataclass
from deprecated import deprecated
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from astrbot.core.db.po import (
Attachment,
@@ -32,7 +31,7 @@ class BaseDatabase(abc.ABC):
echo=False,
future=True,
)
self.AsyncSessionLocal = sessionmaker(
self.AsyncSessionLocal = async_sessionmaker(
self.engine,
class_=AsyncSession,
expire_on_commit=False,
@@ -173,7 +172,7 @@ class BaseDatabase(abc.ABC):
content: dict,
sender_id: str | None = None,
sender_name: str | None = None,
) -> None:
) -> PlatformMessageHistory:
"""Insert a new platform message history record."""
...
@@ -198,6 +197,14 @@ class BaseDatabase(abc.ABC):
"""Get platform message history for a specific user."""
...
@abc.abstractmethod
async def get_platform_message_history_by_id(
self,
message_id: int,
) -> PlatformMessageHistory | None:
"""Get a platform message history record by its ID."""
...
@abc.abstractmethod
async def insert_attachment(
self,
@@ -213,6 +220,27 @@ class BaseDatabase(abc.ABC):
"""Get an attachment by its ID."""
...
@abc.abstractmethod
async def get_attachments(self, attachment_ids: list[str]) -> list[Attachment]:
"""Get multiple attachments by their IDs."""
...
@abc.abstractmethod
async def delete_attachment(self, attachment_id: str) -> bool:
"""Delete an attachment by its ID.
Returns True if the attachment was deleted, False if it was not found.
"""
...
@abc.abstractmethod
async def delete_attachments(self, attachment_ids: list[str]) -> int:
"""Delete multiple attachments by their IDs.
Returns the number of attachments deleted.
"""
...
@abc.abstractmethod
async def insert_persona(
self,
@@ -70,6 +70,7 @@ async def migration_conversation_table(
logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
)
continue
if ":" not in conv.user_id:
continue
session = MessageSesion.from_str(session_str=conv.user_id)
@@ -207,6 +208,7 @@ async def migration_webchat_data(
logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
)
continue
if ":" in conv.user_id:
continue
platform_id = "webchat"
+6 -4
View File
@@ -127,7 +127,7 @@ class SQLiteDatabase:
conn.text_factory = str
return conn
def _exec_sql(self, sql: str, params: tuple = None):
def _exec_sql(self, sql: str, params: tuple | None = None):
conn = self.conn
try:
c = self.conn.cursor()
@@ -224,9 +224,11 @@ class SQLiteDatabase:
c.close()
return Stats(platform, [], [])
return Stats(platform)
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
def get_conversation_by_user_id(
self, user_id: str, cid: str
) -> Conversation | None:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
@@ -258,7 +260,7 @@ class SQLiteDatabase:
(user_id, cid, history, updated_at, created_at),
)
def get_conversations(self, user_id: str) -> tuple:
def get_conversations(self, user_id: str) -> list[Conversation]:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
+17 -16
View File
@@ -12,7 +12,7 @@ class PlatformStat(SQLModel, table=True):
Note: In astrbot v4, we moved `platform` table to here.
"""
__tablename__ = "platform_stats" # type: ignore
__tablename__: str = "platform_stats"
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
timestamp: datetime = Field(nullable=False)
@@ -31,9 +31,10 @@ class PlatformStat(SQLModel, table=True):
class ConversationV2(SQLModel, table=True):
__tablename__ = "conversations" # type: ignore
__tablename__: str = "conversations"
inner_conversation_id: int = Field(
inner_conversation_id: int | None = Field(
default=None,
primary_key=True,
sa_column_kwargs={"autoincrement": True},
)
@@ -68,7 +69,7 @@ class Persona(SQLModel, table=True):
It can be used to customize the behavior of LLMs.
"""
__tablename__ = "personas" # type: ignore
__tablename__: str = "personas"
id: int | None = Field(
primary_key=True,
@@ -98,7 +99,7 @@ class Persona(SQLModel, table=True):
class Preference(SQLModel, table=True):
"""This class represents preferences for bots."""
__tablename__ = "preferences" # type: ignore
__tablename__: str = "preferences"
id: int | None = Field(
default=None,
@@ -134,7 +135,7 @@ class PlatformMessageHistory(SQLModel, table=True):
or platform-specific messages.
"""
__tablename__ = "platform_message_history" # type: ignore
__tablename__: str = "platform_message_history"
id: int | None = Field(
primary_key=True,
@@ -162,7 +163,7 @@ class PlatformSession(SQLModel, table=True):
Each session can have multiple conversations (对话) associated with it.
"""
__tablename__ = "platform_sessions" # type: ignore
__tablename__: str = "platform_sessions"
inner_id: int | None = Field(
primary_key=True,
@@ -173,7 +174,7 @@ class PlatformSession(SQLModel, table=True):
max_length=100,
nullable=False,
unique=True,
default_factory=lambda: f"webchat_{uuid.uuid4()}",
default_factory=lambda: str(uuid.uuid4()),
)
platform_id: str = Field(default="webchat", nullable=False)
"""Platform identifier (e.g., 'webchat', 'qq', 'discord')"""
@@ -203,7 +204,7 @@ class Attachment(SQLModel, table=True):
Attachments can be images, files, or other media types.
"""
__tablename__ = "attachments" # type: ignore
__tablename__: str = "attachments"
inner_attachment_id: int | None = Field(
primary_key=True,
@@ -261,17 +262,17 @@ class Personality(TypedDict):
在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。
"""
prompt: str = ""
name: str = ""
begin_dialogs: list[str] = []
mood_imitation_dialogs: list[str] = []
prompt: str
name: str
begin_dialogs: list[str]
mood_imitation_dialogs: list[str]
"""情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。"""
tools: list[str] | None = None
tools: list[str] | None
"""工具列表。None 表示使用所有工具,空列表表示不使用任何工具"""
# cache
_begin_dialogs_processed: list[dict] = []
_mood_imitation_dialogs_processed: str = ""
_begin_dialogs_processed: list[dict]
_mood_imitation_dialogs_processed: str
# ====
+58 -3
View File
@@ -3,6 +3,7 @@ import threading
import typing as T
from datetime import datetime, timedelta, timezone
from sqlalchemy import CursorResult
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col, delete, desc, func, or_, select, text, update
@@ -105,8 +106,8 @@ class SQLiteDatabase(BaseDatabase):
text("""
SELECT * FROM platform_stats
WHERE timestamp >= :start_time
ORDER BY timestamp DESC
GROUP BY platform_id
ORDER BY timestamp DESC
"""),
{"start_time": start_time},
)
@@ -449,6 +450,18 @@ class SQLiteDatabase(BaseDatabase):
result = await session.execute(query.offset(offset).limit(page_size))
return result.scalars().all()
async def get_platform_message_history_by_id(
self, message_id: int
) -> PlatformMessageHistory | None:
"""Get a platform message history record by its ID."""
async with self.get_db() as session:
session: AsyncSession
query = select(PlatformMessageHistory).where(
PlatformMessageHistory.id == message_id
)
result = await session.execute(query)
return result.scalar_one_or_none()
async def insert_attachment(self, path, type, mime_type):
"""Insert a new attachment record."""
async with self.get_db() as session:
@@ -470,6 +483,48 @@ class SQLiteDatabase(BaseDatabase):
result = await session.execute(query)
return result.scalar_one_or_none()
async def get_attachments(self, attachment_ids: list[str]) -> list:
"""Get multiple attachments by their IDs."""
if not attachment_ids:
return []
async with self.get_db() as session:
session: AsyncSession
query = select(Attachment).where(
col(Attachment.attachment_id).in_(attachment_ids)
)
result = await session.execute(query)
return list(result.scalars().all())
async def delete_attachment(self, attachment_id: str) -> bool:
"""Delete an attachment by its ID.
Returns True if the attachment was deleted, False if it was not found.
"""
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
query = delete(Attachment).where(
col(Attachment.attachment_id) == attachment_id
)
result = T.cast(CursorResult, await session.execute(query))
return result.rowcount > 0
async def delete_attachments(self, attachment_ids: list[str]) -> int:
"""Delete multiple attachments by their IDs.
Returns the number of attachments deleted.
"""
if not attachment_ids:
return 0
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
query = delete(Attachment).where(
col(Attachment.attachment_id).in_(attachment_ids)
)
result = T.cast(CursorResult, await session.execute(query))
return result.rowcount
async def insert_persona(
self,
persona_id,
@@ -794,7 +849,7 @@ class SQLiteDatabase(BaseDatabase):
await session.execute(
update(PlatformSession)
.where(col(PlatformSession.session_id == session_id))
.where(col(PlatformSession.session_id) == session_id)
.values(**values),
)
@@ -805,6 +860,6 @@ class SQLiteDatabase(BaseDatabase):
async with session.begin():
await session.execute(
delete(PlatformSession).where(
col(PlatformSession.session_id == session_id),
col(PlatformSession.session_id) == session_id,
),
)
@@ -90,4 +90,6 @@ class EmbeddingStorage:
path (str): 保存索引的路径
"""
if self.index is None:
return
faiss.write_index(self.index, self.path)
+6 -1
View File
@@ -27,7 +27,7 @@ class EventBus:
self,
event_queue: Queue,
pipeline_scheduler_mapping: dict[str, PipelineScheduler],
astrbot_config_mgr: AstrBotConfigManager = None,
astrbot_config_mgr: AstrBotConfigManager,
):
self.event_queue = event_queue # 事件队列
# abconf uuid -> scheduler
@@ -40,6 +40,11 @@ class EventBus:
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
self._print_event(event, conf_info["name"])
scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"])
if not scheduler:
logger.error(
f"PipelineScheduler not found for id: {conf_info['id']}, event ignored."
)
continue
asyncio.create_task(scheduler.execute(event))
def _print_event(self, event: AstrMessageEvent, conf_name: str):
@@ -166,7 +166,11 @@ class RetrievalManager:
# 5. Rerank
first_rerank = None
for kb_id in kb_ids:
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
vec_db = kb_options[kb_id]["vec_db"]
if not isinstance(vec_db, FaissVecDB):
logger.warning(f"vec_db for kb_id {kb_id} is not FaissVecDB")
continue
rerank_pi = kb_options[kb_id]["rerank_provider_id"]
if (
vec_db
+2 -1
View File
@@ -24,6 +24,7 @@ import asyncio
import logging
import os
import sys
import time
from asyncio import Queue
from collections import deque
@@ -148,7 +149,7 @@ class LogQueueHandler(logging.Handler):
self.log_broker.publish(
{
"level": record.levelname,
"time": record.asctime,
"time": time.time(),
"data": log_entry,
},
)
+15 -4
View File
@@ -66,6 +66,9 @@ class ComponentType(str, Enum):
class BaseMessageComponent(BaseModel):
type: ComponentType
def __init__(self, **kwargs):
super().__init__(**kwargs)
def toDict(self):
data = {}
for k, v in self.__dict__.items():
@@ -551,7 +554,7 @@ class Node(BaseMessageComponent):
id: int | None = 0 # 忽略
name: str | None = "" # qq昵称
uin: str | None = "0" # qq号
content: list[BaseMessageComponent] | None = []
content: list[BaseMessageComponent] = []
seq: str | list | None = "" # 忽略
time: int | None = 0 # 忽略
@@ -615,7 +618,7 @@ class Nodes(BaseMessageComponent):
ret["messages"].append(d)
return ret
async def to_dict(self):
async def to_dict(self) -> dict:
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
ret = {"messages": []}
for node in self.nodes:
@@ -714,15 +717,23 @@ class File(BaseMessageComponent):
if self.url:
await self._download_file()
return os.path.abspath(self.file_)
if self.file_:
return os.path.abspath(self.file_)
return ""
async def _download_file(self):
"""下载文件"""
if not self.url:
raise ValueError("Download failed: No URL provided in File component.")
download_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(download_dir, exist_ok=True)
file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
if self.name:
name, ext = os.path.splitext(self.name)
filename = f"{name}_{uuid.uuid4().hex[:8]}{ext}"
else:
filename = f"{uuid.uuid4().hex}"
file_path = os.path.join(download_dir, filename)
await download_file(self.url, file_path)
self.file_ = os.path.abspath(file_path)
+2 -2
View File
@@ -98,8 +98,8 @@ class PersonaManager:
self,
persona_id: str,
system_prompt: str,
begin_dialogs: list[str] = None,
tools: list[str] = None,
begin_dialogs: list[str] | None = None,
tools: list[str] | None = None,
) -> Persona:
"""创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
if await self.db.get_persona_by_id(persona_id):
@@ -24,7 +24,7 @@ class ContentSafetyCheckStage(Stage):
self,
event: AstrMessageEvent,
check_text: str | None = None,
) -> None | AsyncGenerator[None, None]:
) -> AsyncGenerator[None, None]:
"""检查内容安全"""
text = check_text if check_text else event.get_message_str()
ok, info = self.strategy_selector.check(text)
+2 -1
View File
@@ -11,7 +11,7 @@ from astrbot.core.star.star_handler import EventType, star_handlers_registry
async def call_handler(
event: AstrMessageEvent,
handler: T.Callable[..., T.Awaitable[T.Any]],
handler: T.Callable[..., T.Awaitable[T.Any] | T.AsyncGenerator[T.Any, None]],
*args,
**kwargs,
) -> T.AsyncGenerator[T.Any, None]:
@@ -91,6 +91,7 @@ async def call_event_hook(
)
for handler in handlers:
try:
assert inspect.iscoroutinefunction(handler.handler)
logger.debug(
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
)
@@ -0,0 +1,48 @@
from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.session_llm_manager import SessionServiceManager
from ...context import PipelineContext
from ..stage import Stage
from .agent_sub_stages.internal import InternalAgentSubStage
from .agent_sub_stages.third_party import ThirdPartyAgentSubStage
class AgentRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.config = ctx.astrbot_config
self.bot_wake_prefixs: list[str] = self.config["wake_prefix"]
self.prov_wake_prefix: str = self.config["provider_settings"]["wake_prefix"]
for bwp in self.bot_wake_prefixs:
if self.prov_wake_prefix.startswith(bwp):
logger.info(
f"识别 LLM 聊天额外唤醒前缀 {self.prov_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。",
)
self.prov_wake_prefix = self.prov_wake_prefix[len(bwp) :]
agent_runner_type = self.config["provider_settings"]["agent_runner_type"]
if agent_runner_type == "local":
self.agent_sub_stage = InternalAgentSubStage()
else:
self.agent_sub_stage = ThirdPartyAgentSubStage()
await self.agent_sub_stage.initialize(ctx)
async def process(self, event: AstrMessageEvent) -> AsyncGenerator[None, None]:
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
logger.debug(
"This pipeline does not enable AI capability, skip processing."
)
return
if not SessionServiceManager.should_process_llm_request(event):
logger.debug(
f"The session {event.unified_msg_origin} has disabled AI capability, skipping processing."
)
return
async for resp in self.agent_sub_stage.process(event, self.prov_wake_prefix):
yield resp
@@ -9,7 +9,7 @@ from astrbot.core import logger
from astrbot.core.agent.tool import ToolSet
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.conversation_mgr import Conversation
from astrbot.core.message.components import Image
from astrbot.core.message.components import File, Image, Reply
from astrbot.core.message.message_event_result import (
MessageChain,
MessageEventResult,
@@ -21,27 +21,25 @@ from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from astrbot.core.star.session_llm_manager import SessionServiceManager
from astrbot.core.star.star_handler import EventType, star_map
from astrbot.core.utils.file_extract import extract_file_moonshotai
from astrbot.core.utils.metrics import Metric
from astrbot.core.utils.session_lock import session_lock_manager
from ....astr_agent_context import AgentContextWrapper
from ....astr_agent_hooks import MAIN_AGENT_HOOKS
from ....astr_agent_run_util import AgentRunner, run_agent
from ....astr_agent_tool_exec import FunctionToolExecutor
from ...context import PipelineContext, call_event_hook
from ..stage import Stage
from ..utils import KNOWLEDGE_BASE_QUERY_TOOL, retrieve_knowledge_base
from .....astr_agent_context import AgentContextWrapper
from .....astr_agent_hooks import MAIN_AGENT_HOOKS
from .....astr_agent_run_util import AgentRunner, run_agent
from .....astr_agent_tool_exec import FunctionToolExecutor
from ....context import PipelineContext, call_event_hook
from ...stage import Stage
from ...utils import KNOWLEDGE_BASE_QUERY_TOOL, retrieve_knowledge_base
class LLMRequestSubStage(Stage):
class InternalAgentSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
conf = ctx.astrbot_config
settings = conf["provider_settings"]
self.bot_wake_prefixs: list[str] = conf["wake_prefix"] # list
self.provider_wake_prefix: str = settings["wake_prefix"] # str
self.max_context_length = settings["max_context_length"] # int
self.dequeue_context_length: int = min(
max(1, settings["dequeue_context_length"]),
@@ -59,12 +57,12 @@ class LLMRequestSubStage(Stage):
self.show_reasoning = settings.get("display_reasoning_text", False)
self.kb_agentic_mode: bool = conf.get("kb_agentic_mode", False)
for bwp in self.bot_wake_prefixs:
if self.provider_wake_prefix.startswith(bwp):
logger.info(
f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。",
)
self.provider_wake_prefix = self.provider_wake_prefix[len(bwp) :]
file_extract_conf: dict = settings.get("file_extract", {})
self.file_extract_enabled: bool = file_extract_conf.get("enable", False)
self.file_extract_prov: str = file_extract_conf.get("provider", "moonshotai")
self.file_extract_msh_api_key: str = file_extract_conf.get(
"moonshotai_api_key", ""
)
self.conv_manager = ctx.plugin_manager.context.conversation_manager
@@ -124,6 +122,50 @@ class LLMRequestSubStage(Stage):
req.func_tool = ToolSet()
req.func_tool.add_tool(KNOWLEDGE_BASE_QUERY_TOOL)
async def _apply_file_extract(
self,
event: AstrMessageEvent,
req: ProviderRequest,
):
"""Apply file extract to the provider request"""
file_paths = []
file_names = []
for comp in event.message_obj.message:
if isinstance(comp, File):
file_paths.append(await comp.get_file())
file_names.append(comp.name)
elif isinstance(comp, Reply) and comp.chain:
for reply_comp in comp.chain:
if isinstance(reply_comp, File):
file_paths.append(await reply_comp.get_file())
file_names.append(reply_comp.name)
if not file_paths:
return
if not req.prompt:
req.prompt = "总结一下文件里面讲了什么?"
if self.file_extract_prov == "moonshotai":
if not self.file_extract_msh_api_key:
logger.error("Moonshot AI API key for file extract is not set")
return
file_contents = await asyncio.gather(
*[
extract_file_moonshotai(file_path, self.file_extract_msh_api_key)
for file_path in file_paths
]
)
else:
logger.error(f"Unsupported file extract provider: {self.file_extract_prov}")
return
# add file extract results to contexts
for file_content, file_name in zip(file_contents, file_names):
req.contexts.append(
{
"role": "system",
"content": f"File Extract Results of user uploaded files:\n{file_content}\nFile Name: {file_name or 'Unknown'}",
},
)
def _truncate_contexts(
self,
contexts: list[dict],
@@ -304,21 +346,10 @@ class LLMRequestSubStage(Stage):
return fixed_messages
async def process(
self,
event: AstrMessageEvent,
_nested: bool = False,
) -> None | AsyncGenerator[None, None]:
self, event: AstrMessageEvent, provider_wake_prefix: str
) -> AsyncGenerator[None, None]:
req: ProviderRequest | None = None
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
logger.debug("未启用 LLM 能力,跳过处理。")
return
# 检查会话级别的LLM启停状态
if not SessionServiceManager.should_process_llm_request(event):
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。")
return
provider = self._select_provider(event)
if provider is None:
return
@@ -348,12 +379,12 @@ class LLMRequestSubStage(Stage):
req.image_urls = []
if sel_model := event.get_extra("selected_model"):
req.model = sel_model
if self.provider_wake_prefix and not event.message_str.startswith(
self.provider_wake_prefix
if provider_wake_prefix and not event.message_str.startswith(
provider_wake_prefix
):
return
req.prompt = event.message_str[len(self.provider_wake_prefix) :]
req.prompt = event.message_str[len(provider_wake_prefix) :]
# func_tool selection 现在已经转移到 packages/astrbot 插件中进行选择。
# req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message:
@@ -367,6 +398,17 @@ class LLMRequestSubStage(Stage):
event.set_extra("provider_request", req)
# fix contexts json str
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
# apply file extract
if self.file_extract_enabled:
try:
await self._apply_file_extract(event, req)
except Exception as e:
logger.error(f"Error occurred while applying file extract: {e}")
if not req.prompt and not req.image_urls:
return
@@ -377,10 +419,6 @@ class LLMRequestSubStage(Stage):
# apply knowledge base feature
await self._apply_kb(event, req)
# fix contexts json str
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
# truncate contexts to fit max length
if req.contexts:
req.contexts = self._truncate_contexts(req.contexts)
@@ -0,0 +1,205 @@
import asyncio
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING
from astrbot.core import astrbot_config, logger
from astrbot.core.agent.runners.coze.coze_agent_runner import CozeAgentRunner
from astrbot.core.agent.runners.dashscope.dashscope_agent_runner import (
DashscopeAgentRunner,
)
from astrbot.core.agent.runners.dify.dify_agent_runner import DifyAgentRunner
from astrbot.core.message.components import Image
from astrbot.core.message.message_event_result import (
MessageChain,
MessageEventResult,
ResultContentType,
)
if TYPE_CHECKING:
from astrbot.core.agent.runners.base import BaseAgentRunner
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider.entities import (
ProviderRequest,
)
from astrbot.core.star.star_handler import EventType
from astrbot.core.utils.metrics import Metric
from .....astr_agent_context import AgentContextWrapper, AstrAgentContext
from .....astr_agent_hooks import MAIN_AGENT_HOOKS
from ....context import PipelineContext, call_event_hook
from ...stage import Stage
AGENT_RUNNER_TYPE_KEY = {
"dify": "dify_agent_runner_provider_id",
"coze": "coze_agent_runner_provider_id",
"dashscope": "dashscope_agent_runner_provider_id",
}
async def run_third_party_agent(
runner: "BaseAgentRunner",
stream_to_general: bool = False,
) -> AsyncGenerator[MessageChain | None, None]:
"""
运行第三方 agent runner 并转换响应格式
类似于 run_agent 函数,但专门处理第三方 agent runner
"""
try:
async for resp in runner.step_until_done(max_step=30): # type: ignore[misc]
if resp.type == "streaming_delta":
if stream_to_general:
continue
yield resp.data["chain"]
elif resp.type == "llm_result":
if stream_to_general:
yield resp.data["chain"]
except Exception as e:
logger.error(f"Third party agent runner error: {e}")
err_msg = (
f"\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n"
f"错误信息: {e!s}\n\n请在平台日志查看和分享错误详情。\n"
)
yield MessageChain().message(err_msg)
class ThirdPartyAgentSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.conf = ctx.astrbot_config
self.runner_type = self.conf["provider_settings"]["agent_runner_type"]
self.prov_id = self.conf["provider_settings"].get(
AGENT_RUNNER_TYPE_KEY.get(self.runner_type, ""),
"",
)
settings = ctx.astrbot_config["provider_settings"]
self.streaming_response: bool = settings["streaming_response"]
self.unsupported_streaming_strategy: str = settings[
"unsupported_streaming_strategy"
]
async def process(
self, event: AstrMessageEvent, provider_wake_prefix: str
) -> AsyncGenerator[None, None]:
req: ProviderRequest | None = None
if provider_wake_prefix and not event.message_str.startswith(
provider_wake_prefix
):
return
self.prov_cfg: dict = next(
(p for p in astrbot_config["provider"] if p["id"] == self.prov_id),
{},
)
if not self.prov_id:
logger.error("没有填写 Agent Runner 提供商 ID,请前往配置页面配置。")
return
if not self.prov_cfg:
logger.error(
f"Agent Runner 提供商 {self.prov_id} 配置不存在,请前往配置页面修改配置。"
)
return
# make provider request
req = ProviderRequest()
req.session_id = event.unified_msg_origin
req.prompt = event.message_str[len(provider_wake_prefix) :]
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_path = await comp.convert_to_base64()
req.image_urls.append(image_path)
if not req.prompt and not req.image_urls:
return
# call event hook
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
return
if self.runner_type == "dify":
runner = DifyAgentRunner[AstrAgentContext]()
elif self.runner_type == "coze":
runner = CozeAgentRunner[AstrAgentContext]()
elif self.runner_type == "dashscope":
runner = DashscopeAgentRunner[AstrAgentContext]()
else:
raise ValueError(
f"Unsupported third party agent runner type: {self.runner_type}",
)
astr_agent_ctx = AstrAgentContext(
context=self.ctx.plugin_manager.context,
event=event,
)
streaming_response = self.streaming_response
if (enable_streaming := event.get_extra("enable_streaming")) is not None:
streaming_response = bool(enable_streaming)
stream_to_general = (
self.unsupported_streaming_strategy == "turn_off"
and not event.platform_meta.support_streaming_message
)
await runner.reset(
request=req,
run_context=AgentContextWrapper(
context=astr_agent_ctx,
tool_call_timeout=60,
),
agent_hooks=MAIN_AGENT_HOOKS,
provider_config=self.prov_cfg,
streaming=streaming_response,
)
if streaming_response and not stream_to_general:
# 流式响应
event.set_result(
MessageEventResult()
.set_result_content_type(ResultContentType.STREAMING_RESULT)
.set_async_stream(
run_third_party_agent(
runner,
stream_to_general=False,
),
),
)
yield
if runner.done():
final_resp = runner.get_final_llm_resp()
if final_resp and final_resp.result_chain:
event.set_result(
MessageEventResult(
chain=final_resp.result_chain.chain or [],
result_content_type=ResultContentType.STREAMING_FINISH,
),
)
else:
# 非流式响应或转换为普通响应
async for _ in run_third_party_agent(
runner,
stream_to_general=stream_to_general,
):
yield
final_resp = runner.get_final_llm_resp()
if not final_resp or not final_resp.result_chain:
logger.warning("Agent Runner 未返回最终结果。")
return
event.set_result(
MessageEventResult(
chain=final_resp.result_chain.chain or [],
result_content_type=ResultContentType.LLM_RESULT,
),
)
yield
asyncio.create_task(
Metric.upload(
llm_tick=1,
model_name=self.runner_type,
provider_type=self.runner_type,
),
)
@@ -16,7 +16,6 @@ from ..stage import Stage
class StarRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.curr_provider = ctx.plugin_manager.context.get_using_provider()
self.prompt_prefix = ctx.astrbot_config["provider_settings"]["prompt_prefix"]
self.identifier = ctx.astrbot_config["provider_settings"]["identifier"]
self.ctx = ctx
@@ -24,7 +23,7 @@ class StarRequestSubStage(Stage):
async def process(
self,
event: AstrMessageEvent,
) -> None | AsyncGenerator[None, None]:
) -> AsyncGenerator[Any, None]:
activated_handlers: list[StarHandlerMetadata] = event.get_extra(
"activated_handlers",
)
+9 -14
View File
@@ -1,13 +1,12 @@
from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider.entities import ProviderRequest
from astrbot.core.star.star_handler import StarHandlerMetadata
from ..context import PipelineContext
from ..stage import Stage, register_stage
from .method.llm_request import LLMRequestSubStage
from .method.agent_request import AgentRequestSubStage
from .method.star_request import StarRequestSubStage
@@ -17,9 +16,12 @@ class ProcessStage(Stage):
self.ctx = ctx
self.config = ctx.astrbot_config
self.plugin_manager = ctx.plugin_manager
self.llm_request_sub_stage = LLMRequestSubStage()
await self.llm_request_sub_stage.initialize(ctx)
# initialize agent sub stage
self.agent_sub_stage = AgentRequestSubStage()
await self.agent_sub_stage.initialize(ctx)
# initialize star request sub stage
self.star_request_sub_stage = StarRequestSubStage()
await self.star_request_sub_stage.initialize(ctx)
@@ -39,7 +41,7 @@ class ProcessStage(Stage):
# Handler 的 LLM 请求
event.set_extra("provider_request", resp)
_t = False
async for _ in self.llm_request_sub_stage.process(event):
async for _ in self.agent_sub_stage.process(event):
_t = True
yield
if not _t:
@@ -58,14 +60,7 @@ class ProcessStage(Stage):
):
# 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀
if (
event.get_result() and not event.get_result().is_stopped()
event.get_result() and not event.is_stopped()
) or not event.get_result():
# 事件没有终止传播
provider = self.ctx.plugin_manager.context.get_using_provider()
if not provider:
logger.info("未找到可用的 LLM 提供商,请先前往配置服务提供商。")
return
async for _ in self.llm_request_sub_stage.process(event):
async for _ in self.agent_sub_stage.process(event):
yield
+4 -2
View File
@@ -117,7 +117,9 @@ class RespondStage(Stage):
if not self.enable_seg:
return False
if self.only_llm_result and not event.get_result().is_llm_result():
if (result := event.get_result()) is None:
return False
if self.only_llm_result and result.is_llm_result():
return False
if event.get_platform_name() in [
@@ -185,7 +187,7 @@ class RespondStage(Stage):
if isinstance(component, Comp.File) and component.file:
# 支持 File 消息段的路径映射。
component.file = path_Mapping(mappings, component.file)
event.get_result().chain[idx] = component
result.chain[idx] = component
# 检查消息链是否为空
try:
+68 -11
View File
@@ -6,6 +6,7 @@ from collections.abc import AsyncGenerator
from astrbot.core import file_token_service, html_renderer, logger
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
from astrbot.core.message.message_event_result import ResultContentType
from astrbot.core.pipeline.content_safety_check.stage import ContentSafetyCheckStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.message_type import MessageType
from astrbot.core.star.session_llm_manager import SessionServiceManager
@@ -53,7 +54,22 @@ class ResultDecorateStage(Stage):
self.only_llm_result = ctx.astrbot_config["platform_settings"][
"segmented_reply"
]["only_llm_result"]
self.split_mode = ctx.astrbot_config["platform_settings"][
"segmented_reply"
].get("split_mode", "regex")
self.regex = ctx.astrbot_config["platform_settings"]["segmented_reply"]["regex"]
self.split_words = ctx.astrbot_config["platform_settings"][
"segmented_reply"
].get("split_words", ["", "", "", "~", ""])
if self.split_words:
escaped_words = sorted(
[re.escape(word) for word in self.split_words], key=len, reverse=True
)
self.split_words_pattern = re.compile(
f"(.*?({'|'.join(escaped_words)})|.+$)", re.DOTALL
)
else:
self.split_words_pattern = None
self.content_cleanup_rule = ctx.astrbot_config["platform_settings"][
"segmented_reply"
]["content_cleanup_rule"]
@@ -69,6 +85,28 @@ class ResultDecorateStage(Stage):
self.content_safe_check_stage = stage_cls()
await self.content_safe_check_stage.initialize(ctx)
def _split_text_by_words(self, text: str) -> list[str]:
"""使用分段词列表分段文本"""
if not self.split_words_pattern:
return [text]
segments = self.split_words_pattern.findall(text)
result = []
for seg in segments:
if isinstance(seg, tuple):
content = seg[0]
if not isinstance(content, str):
continue
for word in self.split_words:
if content.endswith(word):
content = content[: -len(word)]
break
if content.strip():
result.append(content)
elif seg and seg.strip():
result.append(seg)
return result if result else [text]
async def process(
self,
event: AstrMessageEvent,
@@ -93,11 +131,13 @@ class ResultDecorateStage(Stage):
for comp in result.chain:
if isinstance(comp, Plain):
text += comp.text
async for _ in self.content_safe_check_stage.process(
event,
check_text=text,
):
yield
if isinstance(self.content_safe_check_stage, ContentSafetyCheckStage):
async for _ in self.content_safe_check_stage.process(
event,
check_text=text,
):
yield
# 发送消息前事件钩子
handlers = star_handlers_registry.get_handlers_by_event_type(
@@ -114,7 +154,8 @@ class ResultDecorateStage(Stage):
"启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作",
)
await handler.handler(event)
if event.get_result() is None or not event.get_result().chain:
if (result := event.get_result()) is None or not result.chain:
logger.debug(
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。",
)
@@ -161,11 +202,27 @@ class ResultDecorateStage(Stage):
# 不分段回复
new_chain.append(comp)
continue
split_response = re.findall(
self.regex,
comp.text,
re.DOTALL | re.MULTILINE,
)
# 根据 split_mode 选择分段方式
if self.split_mode == "words":
split_response = self._split_text_by_words(comp.text)
else: # regex 模式
try:
split_response = re.findall(
self.regex,
comp.text,
re.DOTALL | re.MULTILINE,
)
except re.error:
logger.error(
f"分段回复正则表达式错误,使用默认分段方式: {traceback.format_exc()}",
)
split_response = re.findall(
r".*?[。?!~…]+|.+$",
comp.text,
re.DOTALL | re.MULTILINE,
)
if not split_response:
new_chain.append(comp)
continue
+5 -1
View File
@@ -2,6 +2,10 @@ from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.platform import AstrMessageEvent
from astrbot.core.platform.sources.webchat.webchat_event import WebChatMessageEvent
from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import (
WecomAIBotMessageEvent,
)
from . import STAGES_ORDER
from .context import PipelineContext
@@ -78,7 +82,7 @@ class PipelineScheduler:
await self._process_stages(event)
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
if event.get_platform_name() in ["webchat", "wecom_ai_bot"]:
if isinstance(event, (WebChatMessageEvent, WecomAIBotMessageEvent)):
await event.send(None)
logger.debug("pipeline 执行完毕。")
@@ -50,6 +50,9 @@ class WakingCheckStage(Stage):
"ignore_at_all",
False,
)
self.disable_builtin_commands = self.ctx.astrbot_config.get(
"disable_builtin_commands", False
)
async def process(
self,
@@ -131,6 +134,13 @@ class WakingCheckStage(Stage):
EventType.AdapterMessageEvent,
plugins_name=event.plugins_name,
):
if (
self.disable_builtin_commands
and handler.handler_module_path == "packages.builtin_commands.main"
):
logger.debug("skipping builtin command")
continue
# filter 需满足 AND 逻辑关系
passed = True
permission_not_pass = False
+5 -3
View File
@@ -153,7 +153,9 @@ class AstrMessageEvent(abc.ABC):
def get_sender_name(self) -> str:
"""获取消息发送者的名称。(可能会返回空字符串)"""
return self.message_obj.sender.nickname
if isinstance(self.message_obj.sender.nickname, str):
return self.message_obj.sender.nickname
return ""
def set_extra(self, key, value):
"""设置额外的信息。"""
@@ -270,7 +272,7 @@ class AstrMessageEvent(abc.ABC):
"""
self.call_llm = call_llm
def get_result(self) -> MessageEventResult:
def get_result(self) -> MessageEventResult | None:
"""获取消息事件的结果。"""
return self._result
@@ -320,7 +322,7 @@ class AstrMessageEvent(abc.ABC):
self,
prompt: str,
func_tool_manager=None,
session_id: str = None,
session_id: str = "",
image_urls: list[str] | None = None,
contexts: list | None = None,
system_prompt: str = "",
+2 -2
View File
@@ -54,7 +54,7 @@ class AstrBotMessage:
self_id: str # 机器人的识别id
session_id: str # 会话id。取决于 unique_session 的设置。
message_id: str # 消息id
group: Group # 群组
group: Group | None # 群组
sender: MessageMember # 发送者
message: list[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
message_str: str # 最直观的纯文本消息字符串
@@ -78,7 +78,7 @@ class AstrBotMessage:
return ""
@group_id.setter
def group_id(self, value: str):
def group_id(self, value: str | None):
"""设置 group_id"""
if value:
if self.group:
+71 -9
View File
@@ -5,8 +5,9 @@ from asyncio import Queue
from astrbot.core import logger
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
from astrbot.core.utils.webhook_utils import ensure_platform_webhook_config
from .platform import Platform
from .platform import Platform, PlatformStatus
from .register import platform_cls_map
from .sources.webchat.webchat_adapter import WebChatAdapter
@@ -16,8 +17,9 @@ class PlatformManager:
self.platform_insts: list[Platform] = []
"""加载的 Platform 的实例"""
self._inst_map = {}
self._inst_map: dict[str, dict] = {}
self.astrbot_config = config
self.platforms_config = config["platform"]
self.settings = config["platform_settings"]
"""NOTE: 这里是 default 的配置文件,以保证最大的兼容性;
@@ -29,6 +31,8 @@ class PlatformManager:
"""初始化所有平台适配器"""
for platform in self.platforms_config:
try:
if ensure_platform_webhook_config(platform):
self.astrbot_config.save_config()
await self.load_platform(platform)
except Exception as e:
logger.error(f"初始化 {platform} 平台适配器失败: {e}")
@@ -37,7 +41,10 @@ class PlatformManager:
webchat_inst = WebChatAdapter({}, self.settings, self.event_queue)
self.platform_insts.append(webchat_inst)
asyncio.create_task(
self._task_wrapper(asyncio.create_task(webchat_inst.run(), name="webchat")),
self._task_wrapper(
asyncio.create_task(webchat_inst.run(), name="webchat"),
platform=webchat_inst,
),
)
async def load_platform(self, platform_config: dict):
@@ -107,7 +114,7 @@ class PlatformManager:
)
except (ImportError, ModuleNotFoundError) as e:
logger.error(
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。",
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
)
except Exception as e:
logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}")
@@ -131,6 +138,7 @@ class PlatformManager:
inst.run(),
name=f"platform_{platform_config['type']}_{platform_config['id']}",
),
platform=inst,
),
)
handlers = star_handlers_registry.get_handlers_by_event_type(
@@ -145,17 +153,28 @@ class PlatformManager:
except Exception:
logger.error(traceback.format_exc())
async def _task_wrapper(self, task: asyncio.Task):
async def _task_wrapper(self, task: asyncio.Task, platform: Platform | None = None):
# 设置平台状态为运行中
if platform:
platform.status = PlatformStatus.RUNNING
try:
await task
except asyncio.CancelledError:
pass
if platform:
platform.status = PlatformStatus.STOPPED
except Exception as e:
error_msg = str(e)
tb_str = traceback.format_exc()
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
for line in traceback.format_exc().split("\n"):
for line in tb_str.split("\n"):
logger.error(f"| {line}")
logger.error("-------")
# 记录错误到平台实例
if platform:
platform.record_error(error_msg, tb_str)
async def reload(self, platform_config: dict):
await self.terminate_platform(platform_config["id"])
if platform_config["enable"]:
@@ -172,9 +191,9 @@ class PlatformManager:
logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...")
# client_id = self._inst_map.pop(platform_id, None)
info = self._inst_map.pop(platform_id, None)
info = self._inst_map.pop(platform_id)
client_id = info["client_id"]
inst = info["inst"]
inst: Platform = info["inst"]
try:
self.platform_insts.remove(
next(
@@ -196,3 +215,46 @@ class PlatformManager:
def get_insts(self):
return self.platform_insts
def get_all_stats(self) -> dict:
"""获取所有平台的统计信息
Returns:
包含所有平台统计信息的字典
"""
stats_list = []
total_errors = 0
running_count = 0
error_count = 0
for inst in self.platform_insts:
try:
stat = inst.get_stats()
stats_list.append(stat)
total_errors += stat.get("error_count", 0)
if stat.get("status") == PlatformStatus.RUNNING.value:
running_count += 1
elif stat.get("status") == PlatformStatus.ERROR.value:
error_count += 1
except Exception as e:
# 如果获取统计信息失败,记录基本信息
logger.warning(f"获取平台统计信息失败: {e}")
stats_list.append(
{
"id": getattr(inst, "config", {}).get("id", "unknown"),
"type": "unknown",
"status": "unknown",
"error_count": 0,
"last_error": None,
}
)
return {
"platforms": stats_list,
"summary": {
"total": len(stats_list),
"running": running_count,
"error": error_count,
"total_errors": total_errors,
},
}
+109 -4
View File
@@ -1,7 +1,10 @@
import abc
import uuid
from asyncio import Queue
from collections.abc import Awaitable
from collections.abc import Coroutine
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any
from astrbot.core.message.message_event_result import MessageChain
@@ -12,15 +15,100 @@ from .message_session import MessageSesion
from .platform_metadata import PlatformMetadata
class PlatformStatus(Enum):
"""平台运行状态"""
PENDING = "pending" # 待启动
RUNNING = "running" # 运行中
ERROR = "error" # 发生错误
STOPPED = "stopped" # 已停止
@dataclass
class PlatformError:
"""平台错误信息"""
message: str
timestamp: datetime = field(default_factory=datetime.now)
traceback: str | None = None
class Platform(abc.ABC):
def __init__(self, event_queue: Queue):
def __init__(self, config: dict, event_queue: Queue):
super().__init__()
# 平台配置
self.config = config
# 维护了消息平台的事件队列,EventBus 会从这里取出事件并处理。
self._event_queue = event_queue
self.client_self_id = uuid.uuid4().hex
# 平台运行状态
self._status: PlatformStatus = PlatformStatus.PENDING
self._errors: list[PlatformError] = []
self._started_at: datetime | None = None
@property
def status(self) -> PlatformStatus:
"""获取平台运行状态"""
return self._status
@status.setter
def status(self, value: PlatformStatus):
"""设置平台运行状态"""
self._status = value
if value == PlatformStatus.RUNNING and self._started_at is None:
self._started_at = datetime.now()
@property
def errors(self) -> list[PlatformError]:
"""获取错误列表"""
return self._errors
@property
def last_error(self) -> PlatformError | None:
"""获取最近的错误"""
return self._errors[-1] if self._errors else None
def record_error(self, message: str, traceback_str: str | None = None):
"""记录一个错误"""
self._errors.append(PlatformError(message=message, traceback=traceback_str))
self._status = PlatformStatus.ERROR
def clear_errors(self):
"""清除错误记录"""
self._errors.clear()
if self._status == PlatformStatus.ERROR:
self._status = PlatformStatus.RUNNING
def unified_webhook(self) -> bool:
"""是否正在使用统一 Webhook 模式"""
return bool(
self.config.get("unified_webhook_mode", False)
and self.config.get("webhook_uuid")
)
def get_stats(self) -> dict:
"""获取平台统计信息"""
meta = self.meta()
return {
"id": meta.id or self.config.get("id"),
"type": meta.name,
"display_name": meta.adapter_display_name or meta.name,
"status": self._status.value,
"started_at": self._started_at.isoformat() if self._started_at else None,
"error_count": len(self._errors),
"last_error": {
"message": self.last_error.message,
"timestamp": self.last_error.timestamp.isoformat(),
"traceback": self.last_error.traceback,
}
if self.last_error
else None,
"unified_webhook": self.unified_webhook(),
}
@abc.abstractmethod
def run(self) -> Awaitable[Any]:
def run(self) -> Coroutine[Any, Any, None]:
"""得到一个平台的运行实例,需要返回一个协程对象。"""
raise NotImplementedError
@@ -36,7 +124,7 @@ class Platform(abc.ABC):
self,
session: MessageSesion,
message_chain: MessageChain,
) -> Awaitable[Any]:
) -> None:
"""通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。
异步方法
@@ -49,3 +137,20 @@ class Platform(abc.ABC):
def get_client(self):
"""获取平台的客户端对象。"""
async def webhook_callback(self, request: Any) -> Any:
"""统一 Webhook 回调入口。
支持统一 Webhook 模式的平台需要实现此方法
Dashboard 收到 /api/platform/webhook/{uuid} 请求时会调用此方法
Args:
request: Quart 请求对象
Returns:
响应内容格式取决于具体平台的要求
Raises:
NotImplementedError: 平台未实现统一 Webhook 模式
"""
raise NotImplementedError(f"平台 {self.meta().name} 未实现统一 Webhook 模式")
+1 -1
View File
@@ -7,7 +7,7 @@ class PlatformMetadata:
"""平台的名称,即平台的类型,如 aiocqhttp, discord, slack"""
description: str
"""平台的描述"""
id: str | None = None
id: str
"""平台的唯一标识符,用于配置中识别特定平台"""
default_config_tmpl: dict | None = None
+1
View File
@@ -40,6 +40,7 @@ def register_platform_adapter(
pm = PlatformMetadata(
name=adapter_name,
description=desc,
id=adapter_name,
default_config_tmpl=default_config_tmpl,
adapter_display_name=adapter_display_name,
logo_path=logo_path,
@@ -70,16 +70,18 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
bot: CQHttp,
event: Event | None,
is_group: bool,
session_id: str,
session_id: str | None,
messages: list[dict],
):
# session_id 必须是纯数字字符串
session_id = int(session_id) if session_id.isdigit() else None
session_id_int = (
int(session_id) if session_id and session_id.isdigit() else None
)
if is_group and isinstance(session_id, int):
await bot.send_group_msg(group_id=session_id, message=messages)
elif not is_group and isinstance(session_id, int):
await bot.send_private_msg(user_id=session_id, message=messages)
if is_group and isinstance(session_id_int, int):
await bot.send_group_msg(group_id=session_id_int, message=messages)
elif not is_group and isinstance(session_id_int, int):
await bot.send_private_msg(user_id=session_id_int, message=messages)
elif isinstance(event, Event): # 最后兜底
await bot.send(event=event, message=messages)
else:
@@ -4,7 +4,7 @@ import logging
import time
import uuid
from collections.abc import Awaitable
from typing import Any
from typing import Any, cast
from aiocqhttp import CQHttp, Event
from aiocqhttp.exceptions import ActionFailed
@@ -38,9 +38,8 @@ class AiocqhttpAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
super().__init__(event_queue)
super().__init__(platform_config, event_queue)
self.config = platform_config
self.settings = platform_settings
self.unique_session = platform_settings["unique_session"]
self.host = platform_config["ws_reverse_host"]
@@ -49,7 +48,7 @@ class AiocqhttpAdapter(Platform):
self.metadata = PlatformMetadata(
name="aiocqhttp",
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
id=self.config.get("id"),
id=cast(str, self.config.get("id")),
support_streaming_message=False,
)
@@ -128,7 +127,9 @@ class AiocqhttpAdapter(Platform):
"""OneBot V11 请求类事件"""
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
abm.sender = MessageMember(
user_id=str(event.user_id), nickname=str(event.user_id)
)
abm.type = MessageType.OTHER_MESSAGE
if event.get("group_id"):
abm.type = MessageType.GROUP_MESSAGE
@@ -154,7 +155,9 @@ class AiocqhttpAdapter(Platform):
"""OneBot V11 通知类事件"""
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
abm.sender = MessageMember(
user_id=str(event.user_id), nickname=str(event.user_id)
)
abm.type = MessageType.OTHER_MESSAGE
if event.get("group_id"):
abm.group_id = str(event.group_id)
@@ -193,6 +196,7 @@ class AiocqhttpAdapter(Platform):
@param event: 事件对象
@param get_reply: 是否获取回复消息这个参数是为了防止多个回复嵌套
"""
assert event.sender is not None
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.sender = MessageMember(
@@ -202,6 +206,7 @@ class AiocqhttpAdapter(Platform):
if event["message_type"] == "group":
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = str(event.group_id)
abm.group = Group(str(event.group_id))
abm.group.group_name = event.get("group_name", "N/A")
elif event["message_type"] == "private":
abm.type = MessageType.FRIEND_MESSAGE
@@ -227,7 +232,7 @@ class AiocqhttpAdapter(Platform):
await self.bot.send(event, err)
except BaseException as e:
logger.error(f"回复消息失败: {e}")
return None
raise ValueError(err)
# 按消息段类型类型适配
for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]):
@@ -246,7 +251,13 @@ class AiocqhttpAdapter(Platform):
if m["data"].get("url") and m["data"].get("url").startswith("http"):
# Lagrange
logger.info("guessing lagrange")
file_name = m["data"].get("file_name", "file")
# 检查多个可能的文件名字段
file_name = (
m["data"].get("file_name", "")
or m["data"].get("name", "")
or m["data"].get("file", "")
or "file"
)
abm.message.append(File(name=file_name, url=m["data"]["url"]))
else:
try:
@@ -265,7 +276,14 @@ class AiocqhttpAdapter(Platform):
)
if ret and "url" in ret:
file_url = ret["url"] # https
a = File(name="", url=file_url)
# 优先从 API 返回值获取文件名,其次从原始消息数据获取
file_name = (
ret.get("file_name", "")
or ret.get("name", "")
or m["data"].get("file", "")
or m["data"].get("file_name", "")
)
a = File(name=file_name, url=file_url)
abm.message.append(a)
else:
logger.error(f"获取文件失败: {ret}")
@@ -403,7 +421,7 @@ class AiocqhttpAdapter(Platform):
async def shutdown_trigger_placeholder(self):
await self.shutdown_event.wait()
logger.info("aiocqhttp 适配器已被优雅地关闭")
logger.info("aiocqhttp 适配器已被关闭")
def meta(self) -> PlatformMetadata:
return self.metadata
@@ -2,6 +2,7 @@ import asyncio
import os
import threading
import uuid
from typing import cast
import aiohttp
import dingtalk_stream
@@ -47,21 +48,21 @@ class DingtalkPlatformAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
super().__init__(event_queue)
self.config = platform_config
super().__init__(platform_config, event_queue)
self.unique_session = platform_settings["unique_session"]
self.client_id = platform_config["client_id"]
self.client_secret = platform_config["client_secret"]
outer_self = self
class AstrCallbackClient(dingtalk_stream.ChatbotHandler):
async def process(self_, message: dingtalk_stream.CallbackMessage):
async def process(self, message: dingtalk_stream.CallbackMessage):
logger.debug(f"dingtalk: {message.data}")
im = dingtalk_stream.ChatbotMessage.from_dict(message.data)
abm = await self.convert_msg(im)
await self.handle_msg(abm)
abm = await outer_self.convert_msg(im)
await outer_self.handle_msg(abm)
return AckMessage.STATUS_OK, "OK"
@@ -75,14 +76,15 @@ class DingtalkPlatformAdapter(Platform):
self.client,
)
self.client_ = client # 用于 websockets 的 client
self._shutdown_event: threading.Event | None = None
def _id_to_sid(self, dingtalk_id: str | None) -> str | None:
def _id_to_sid(self, dingtalk_id: str | None) -> str:
if not dingtalk_id:
return dingtalk_id
return dingtalk_id or "unknown"
prefix = "$:LWCP_v1:$"
if dingtalk_id.startswith(prefix):
return dingtalk_id[len(prefix) :]
return dingtalk_id
return dingtalk_id or "unknown"
async def send_by_session(
self,
@@ -95,7 +97,7 @@ class DingtalkPlatformAdapter(Platform):
return PlatformMetadata(
name="dingtalk",
description="钉钉机器人官方 API 适配器",
id=self.config.get("id"),
id=cast(str, self.config.get("id")),
support_streaming_message=False,
)
@@ -106,7 +108,7 @@ class DingtalkPlatformAdapter(Platform):
abm = AstrBotMessage()
abm.message = []
abm.message_str = ""
abm.timestamp = int(message.create_at / 1000)
abm.timestamp = int(cast(int, message.create_at) / 1000)
abm.type = (
MessageType.GROUP_MESSAGE
if message.conversation_type == "2"
@@ -117,7 +119,7 @@ class DingtalkPlatformAdapter(Platform):
nickname=message.sender_nick,
)
abm.self_id = self._id_to_sid(message.chatbot_user_id)
abm.message_id = message.message_id
abm.message_id = cast(str, message.message_id)
abm.raw_message = message
if abm.type == MessageType.GROUP_MESSAGE:
@@ -134,14 +136,16 @@ class DingtalkPlatformAdapter(Platform):
else:
abm.session_id = abm.sender.user_id
message_type: str = message.message_type
message_type: str = cast(str, message.message_type)
match message_type:
case "text":
abm.message_str = message.text.content.strip()
abm.message.append(Plain(abm.message_str))
case "richText":
rtc: dingtalk_stream.RichTextContent = message.rich_text_content
contents: list[dict] = rtc.rich_text_list
rtc: dingtalk_stream.RichTextContent = cast(
dingtalk_stream.RichTextContent, message.rich_text_content
)
contents: list[dict] = cast(list[dict], rtc.rich_text_list)
for content in contents:
plains = ""
if "text" in content:
@@ -150,7 +154,7 @@ class DingtalkPlatformAdapter(Platform):
elif "type" in content and content["type"] == "picture":
f_path = await self.download_ding_file(
content["downloadCode"],
message.robot_code,
cast(str, message.robot_code),
"jpg",
)
abm.message.append(Image.fromFileSystem(f_path))
@@ -195,7 +199,7 @@ class DingtalkPlatformAdapter(Platform):
logger.error(
f"下载钉钉文件失败: {resp.status}, {await resp.text()}",
)
return None
return ""
resp_data = await resp.json()
download_url = resp_data["data"]["downloadUrl"]
await download_file(download_url, f_path)
@@ -215,7 +219,7 @@ class DingtalkPlatformAdapter(Platform):
logger.error(
f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}",
)
return None
return ""
return (await resp.json())["data"]["accessToken"]
async def handle_msg(self, abm: AstrBotMessage):
@@ -241,7 +245,7 @@ class DingtalkPlatformAdapter(Platform):
task.result()
except Exception as e:
if "Graceful shutdown" in str(e):
logger.info("钉钉适配器已被优雅地关闭")
logger.info("钉钉适配器已被关闭")
return
logger.error(f"钉钉机器人启动失败: {e}")
@@ -250,11 +254,13 @@ class DingtalkPlatformAdapter(Platform):
async def terminate(self):
def monkey_patch_close():
raise Exception("Graceful shutdown")
raise KeyboardInterrupt("Graceful shutdown")
self.client_.open_connection = monkey_patch_close
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
self._shutdown_event.set()
if self.client_.websocket is not None:
self.client_.open_connection = monkey_patch_close
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
if self._shutdown_event is not None:
self._shutdown_event.set()
def get_client(self):
return self.client
@@ -1,4 +1,5 @@
import asyncio
from typing import cast
import dingtalk_stream
@@ -32,7 +33,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
client.reply_markdown,
segment.text,
segment.text,
self.message_obj.raw_message,
cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message),
)
elif isinstance(segment, Comp.Image):
markdown_str = ""
@@ -53,7 +54,9 @@ class DingtalkMessageEvent(AstrMessageEvent):
client.reply_markdown,
"😄",
markdown_str,
self.message_obj.raw_message,
cast(
dingtalk_stream.ChatbotMessage, self.message_obj.raw_message
),
)
logger.debug(f"send image: {ret}")
@@ -1,4 +1,5 @@
import sys
from collections.abc import Awaitable, Callable
import discord
@@ -27,13 +28,16 @@ class DiscordBotClient(discord.Bot):
super().__init__(intents=intents, proxy=proxy)
# 回调函数
self.on_message_received = None
self.on_ready_once_callback = None
self.on_message_received: Callable[[dict], Awaitable[None]] | None = None
self.on_ready_once_callback: Callable[[], Awaitable[None]] | None = None
self._ready_once_fired = False
@override
async def on_ready(self):
"""当机器人成功连接并准备就绪时触发"""
if self.user is None:
logger.error("[Discord] 客户端未正确加载用户信息 (self.user is None)")
return
logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录")
logger.info("[Discord] 客户端已准备就绪。")
@@ -49,6 +53,9 @@ class DiscordBotClient(discord.Bot):
def _create_message_data(self, message: discord.Message) -> dict:
"""从 discord.Message 创建数据字典"""
if self.user is None:
raise RuntimeError("Bot is not ready: self.user is None")
is_mentioned = self.user in message.mentions
return {
"message": message,
@@ -66,6 +73,12 @@ class DiscordBotClient(discord.Bot):
def _create_interaction_data(self, interaction: discord.Interaction) -> dict:
"""从 discord.Interaction 创建数据字典"""
if self.user is None:
raise RuntimeError("Bot is not ready: self.user is None")
if interaction.user is None:
raise ValueError("Interaction received without a valid user")
return {
"interaction": interaction,
"bot_id": str(self.user.id),
@@ -80,7 +93,6 @@ class DiscordBotClient(discord.Bot):
"type": "interaction",
}
@override
async def on_message(self, message: discord.Message):
"""当接收到消息时触发"""
if message.author.bot:
@@ -97,8 +97,8 @@ class DiscordView(BaseMessageComponent):
def __init__(
self,
components: list[BaseMessageComponent] = None,
timeout: float = None,
components: list[BaseMessageComponent] | None = None,
timeout: float | None = None,
):
self.components = components or []
self.timeout = timeout
@@ -1,10 +1,10 @@
import asyncio
import re
import sys
from typing import Any
from typing import Any, cast
import discord
from discord.abc import Messageable
from discord.abc import GuildChannel, Messageable, PrivateChannel
from discord.channel import DMChannel
from astrbot import logger
@@ -44,10 +44,9 @@ class DiscordPlatformAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
super().__init__(event_queue)
self.config = platform_config
super().__init__(platform_config, event_queue)
self.settings = platform_settings
self.client_self_id = None
self.client_self_id: str | None = None
self.registered_handlers = []
# 指令注册相关
self.enable_command_register = self.config.get("discord_command_register", True)
@@ -63,6 +62,12 @@ class DiscordPlatformAdapter(Platform):
message_chain: MessageChain,
):
"""通过会话发送消息"""
if self.client.user is None:
logger.error(
"[Discord] 客户端未就绪 (self.client.user is None),无法发送消息"
)
return
# 创建一个 message_obj 以便在 event 中使用
message_obj = AstrBotMessage()
if "_" in session.session_id:
@@ -90,7 +95,7 @@ class DiscordPlatformAdapter(Platform):
user_id=str(self.client_self_id),
nickname=self.client.user.display_name,
)
message_obj.self_id = self.client_self_id
message_obj.self_id = cast(str, self.client_self_id)
message_obj.session_id = session.session_id
message_obj.message = message_chain.chain
@@ -111,7 +116,7 @@ class DiscordPlatformAdapter(Platform):
return PlatformMetadata(
"discord",
"Discord 适配器",
id=self.config.get("id"),
id=cast(str, self.config.get("id")),
default_config_tmpl=self.config,
support_streaming_message=False,
)
@@ -161,7 +166,7 @@ class DiscordPlatformAdapter(Platform):
def _get_message_type(
self,
channel: Messageable,
channel: Messageable | GuildChannel | PrivateChannel,
guild_id: int | None = None,
) -> MessageType:
"""根据 channel 对象和 guild_id 判断消息类型"""
@@ -171,13 +176,15 @@ class DiscordPlatformAdapter(Platform):
return MessageType.FRIEND_MESSAGE
return MessageType.GROUP_MESSAGE
def _get_channel_id(self, channel: Messageable) -> str:
def _get_channel_id(
self, channel: Messageable | GuildChannel | PrivateChannel
) -> str:
"""根据 channel 对象获取ID"""
return str(getattr(channel, "id", None))
def _convert_message_to_abm(self, data: dict) -> AstrBotMessage:
"""将普通消息转换为 AstrBotMessage"""
message: discord.Message = data["message"]
message = data["message"]
content = message.content
@@ -234,7 +241,7 @@ class DiscordPlatformAdapter(Platform):
)
abm.message = message_chain
abm.raw_message = message
abm.self_id = self.client_self_id
abm.self_id = cast(str, self.client_self_id)
abm.session_id = str(message.channel.id)
abm.message_id = str(message.id)
return abm
@@ -255,32 +262,52 @@ class DiscordPlatformAdapter(Platform):
interaction_followup_webhook=followup_webhook,
)
if self.client.user is None:
logger.error(
"[Discord] 客户端未就绪 (self.client.user is None),无法处理消息"
)
return
# 检查是否为斜杠指令
is_slash_command = message_event.interaction_followup_webhook is not None
# 1. 优先处理斜杠指令
if is_slash_command:
message_event.is_wake = True
message_event.is_at_or_wake_command = True
self.commit_event(message_event)
return
# 2. 处理普通消息(提及检测)
# 确保 raw_message 是 discord.Message 类型,以便静态检查通过
raw_message = message.raw_message
if not isinstance(raw_message, discord.Message):
logger.warning(
f"[Discord] 收到非 Message 类型的消息: {type(raw_message)},已忽略。"
)
return
# 检查是否被@User Mention 或 Bot 拥有的 Role Mention
is_mention = False
# User Mention
if (
self.client
and self.client.user
and hasattr(message.raw_message, "mentions")
):
if self.client.user in message.raw_message.mentions:
is_mention = True
# 此时 Pylance 知道 raw_message 是 discord.Message,具有 mentions 属性
if self.client.user in raw_message.mentions:
is_mention = True
# Role MentionBot 拥有的角色被提及)
if not is_mention and hasattr(message.raw_message, "role_mentions"):
if not is_mention and raw_message.role_mentions:
bot_member = None
if hasattr(message.raw_message, "guild") and message.raw_message.guild:
if raw_message.guild:
try:
bot_member = message.raw_message.guild.get_member(
bot_member = raw_message.guild.get_member(
self.client.user.id,
)
except Exception:
bot_member = None
if bot_member and hasattr(bot_member, "roles"):
bot_roles = set(bot_member.roles)
mentioned_roles = set(message.raw_message.role_mentions)
mentioned_roles = set(raw_message.role_mentions)
if (
bot_roles
and mentioned_roles
@@ -288,8 +315,8 @@ class DiscordPlatformAdapter(Platform):
):
is_mention = True
# 如果是斜杠指令或被@的消息,设置为唤醒状态
if is_slash_command or is_mention:
# 如果是被@的消息,设置为唤醒状态
if is_mention:
message_event.is_wake = True
message_event.is_at_or_wake_command = True
@@ -425,7 +452,7 @@ class DiscordPlatformAdapter(Platform):
)
abm.message = [Plain(text=message_str_for_filter)]
abm.raw_message = ctx.interaction
abm.self_id = self.client_self_id
abm.self_id = cast(str, self.client_self_id)
abm.session_id = str(ctx.channel_id)
abm.message_id = str(ctx.interaction.id)
@@ -438,7 +465,7 @@ class DiscordPlatformAdapter(Platform):
def _extract_command_info(
event_filter: Any,
handler_metadata: StarHandlerMetadata,
) -> tuple[str, str, CommandFilter] | None:
) -> tuple[str, str, CommandFilter | None] | None:
"""从事件过滤器中提取指令信息"""
cmd_name = None
# is_group = False
@@ -4,8 +4,10 @@ import binascii
from collections.abc import AsyncGenerator
from io import BytesIO
from pathlib import Path
from typing import cast
import discord
from discord.types.interactions import ComponentInteractionData
from astrbot import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
@@ -85,6 +87,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
channel = await self._get_channel()
if not channel:
return
if not isinstance(channel, discord.abc.Messageable):
logger.error(f"[Discord] 频道 {channel.id} 不是可发送消息的类型")
return
await channel.send(**kwargs)
except Exception as e:
@@ -107,7 +112,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
async def _get_channel(self) -> discord.abc.Messageable | None:
async def _get_channel(
self,
) -> discord.Thread | discord.abc.GuildChannel | discord.abc.PrivateChannel | None:
"""获取当前事件对应的频道对象"""
try:
channel_id = int(self.session_id)
@@ -121,7 +128,13 @@ class DiscordPlatformEvent(AstrMessageEvent):
async def _parse_to_discord(
self,
message: MessageChain,
) -> tuple[str, list[discord.File], discord.ui.View | None, list[discord.Embed]]:
) -> tuple[
str,
list[discord.File],
discord.ui.View | None,
list[discord.Embed],
str | int | None,
]:
"""将 MessageChain 解析为 Discord 发送所需的内容"""
content_parts = []
files = []
@@ -261,7 +274,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
self.message_obj.raw_message,
"add_reaction",
):
await self.message_obj.raw_message.add_reaction(emoji)
await cast(discord.Message, self.message_obj.raw_message).add_reaction(
emoji
)
except Exception as e:
logger.error(f"[Discord] 添加反应失败: {e}")
@@ -270,7 +285,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
return (
hasattr(self.message_obj, "raw_message")
and hasattr(self.message_obj.raw_message, "type")
and self.message_obj.raw_message.type
and cast(discord.Interaction, self.message_obj.raw_message).type
== discord.InteractionType.application_command
)
@@ -279,14 +294,18 @@ class DiscordPlatformEvent(AstrMessageEvent):
return (
hasattr(self.message_obj, "raw_message")
and hasattr(self.message_obj.raw_message, "type")
and self.message_obj.raw_message.type == discord.InteractionType.component
and cast(discord.Interaction, self.message_obj.raw_message).type
== discord.InteractionType.component
)
def get_interaction_custom_id(self) -> str:
"""获取交互组件的custom_id"""
if self.is_button_interaction():
try:
return self.message_obj.raw_message.data.get("custom_id", "")
return cast(
ComponentInteractionData,
cast(discord.Interaction, self.message_obj.raw_message).data,
).get("custom_id", "")
except Exception:
pass
return ""
@@ -299,7 +318,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
):
return any(
mention.id == int(self.message_obj.self_id)
for mention in self.message_obj.raw_message.mentions
for mention in cast(
discord.Message, self.message_obj.raw_message
).mentions
)
return False
@@ -309,5 +330,5 @@ class DiscordPlatformEvent(AstrMessageEvent):
self.message_obj.raw_message,
"clean_content",
):
return self.message_obj.raw_message.clean_content
return cast(discord.Message, self.message_obj.raw_message).clean_content
return self.message_str
@@ -2,10 +2,17 @@ import asyncio
import base64
import json
import re
import time
import uuid
from typing import Any, cast
import lark_oapi as lark
from lark_oapi.api.im.v1 import *
from lark_oapi.api.im.v1 import (
CreateMessageRequest,
CreateMessageRequestBody,
GetMessageResourceRequest,
)
from lark_oapi.api.im.v1.processor import P2ImMessageReceiveV1Processor
import astrbot.api.message_components as Comp
from astrbot import logger
@@ -18,9 +25,11 @@ from astrbot.api.platform import (
PlatformMetadata,
)
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.webhook_utils import log_webhook_info
from ...register import register_platform_adapter
from .lark_event import LarkMessageEvent
from .server import LarkWebhookServer
@register_platform_adapter(
@@ -33,9 +42,7 @@ class LarkPlatformAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
super().__init__(event_queue)
self.config = platform_config
super().__init__(platform_config, event_queue)
self.unique_session = platform_settings["unique_session"]
@@ -44,9 +51,13 @@ class LarkPlatformAdapter(Platform):
self.domain = platform_config.get("domain", lark.FEISHU_DOMAIN)
self.bot_name = platform_config.get("lark_bot_name", "astrbot")
# socket or webhook
self.connection_mode = platform_config.get("lark_connection_mode", "socket")
if not self.bot_name:
logger.warning("未设置飞书机器人名称,@ 机器人可能得不到回复。")
# 初始化 WebSocket 长连接相关配置
async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1):
await self.convert_msg(event)
@@ -59,6 +70,8 @@ class LarkPlatformAdapter(Platform):
.build()
)
self.do_v2_msg_event = do_v2_msg_event
self.client = lark.ws.Client(
app_id=self.appid,
app_secret=self.appsecret,
@@ -71,11 +84,48 @@ class LarkPlatformAdapter(Platform):
lark.Client.builder().app_id(self.appid).app_secret(self.appsecret).build()
)
self.webhook_server = None
if self.connection_mode == "webhook":
self.webhook_server = LarkWebhookServer(platform_config, event_queue)
self.webhook_server.set_callback(self.handle_webhook_event)
self.event_id_timestamps: dict[str, float] = {}
def _clean_expired_events(self):
"""清理超过 30 分钟的事件记录"""
current_time = time.time()
expired_keys = [
event_id
for event_id, timestamp in self.event_id_timestamps.items()
if current_time - timestamp > 1800
]
for event_id in expired_keys:
del self.event_id_timestamps[event_id]
def _is_duplicate_event(self, event_id: str) -> bool:
"""检查事件是否重复
Args:
event_id: 事件ID
Returns:
True 表示重复事件False 表示新事件
"""
self._clean_expired_events()
if event_id in self.event_id_timestamps:
return True
self.event_id_timestamps[event_id] = time.time()
return False
async def send_by_session(
self,
session: MessageSesion,
message_chain: MessageChain,
):
if self.lark_api.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法发送消息")
return
res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api)
wrapped = {
"zh_cn": {
@@ -116,14 +166,25 @@ class LarkPlatformAdapter(Platform):
return PlatformMetadata(
name="lark",
description="飞书机器人官方 API 适配器",
id=self.config.get("id"),
id=cast(str, self.config.get("id")),
support_streaming_message=False,
)
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
if event.event is None:
logger.debug("[Lark] 收到空事件(event.event is None)")
return
message = event.event.message
if message is None:
logger.debug("[Lark] 事件中没有消息体(message is None)")
return
abm = AstrBotMessage()
abm.timestamp = int(message.create_time) / 1000
if message.create_time:
abm.timestamp = int(message.create_time) // 1000
else:
abm.timestamp = int(time.time())
abm.message = []
abm.type = (
MessageType.GROUP_MESSAGE
@@ -138,14 +199,28 @@ class LarkPlatformAdapter(Platform):
at_list = {}
if message.mentions:
for m in message.mentions:
at_list[m.key] = Comp.At(qq=m.id.open_id, name=m.name)
if m.name == self.bot_name:
abm.self_id = m.id.open_id
if m.id is None:
continue
# 飞书 open_id 可能是 None,这里做个防护
open_id = m.id.open_id if m.id.open_id else ""
at_list[m.key] = Comp.At(qq=open_id, name=m.name)
content_json_b = json.loads(message.content)
if m.name == self.bot_name:
if m.id.open_id is not None:
abm.self_id = m.id.open_id
if message.content is None:
logger.warning("[Lark] 消息内容为空")
return
try:
content_json_b = json.loads(message.content)
except json.JSONDecodeError:
logger.error(f"[Lark] 解析消息内容失败: {message.content}")
return
if message.message_type == "text":
message_str_raw = content_json_b["text"] # 带有 @ 的消息
message_str_raw = content_json_b.get("text", "") # 带有 @ 的消息
at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则
# at_users = re.findall(at_pattern, message_str_raw)
# 拆分文本,去掉AT符号部分
@@ -170,27 +245,47 @@ class LarkPlatformAdapter(Platform):
content_json_b = _ls
elif message.message_type == "image":
content_json_b = [
{"tag": "img", "image_key": content_json_b["image_key"], "style": []},
{
"tag": "img",
"image_key": content_json_b.get("image_key"),
"style": [],
},
]
if message.message_type in ("post", "image"):
for comp in content_json_b:
if comp["tag"] == "at":
abm.message.append(at_list[comp["user_id"]])
elif comp["tag"] == "text" and comp["text"].strip():
if comp.get("tag") == "at":
user_id = comp.get("user_id")
if user_id in at_list:
abm.message.append(at_list[user_id])
elif comp.get("tag") == "text" and comp.get("text", "").strip():
abm.message.append(Comp.Plain(comp["text"].strip()))
elif comp["tag"] == "img":
image_key = comp["image_key"]
elif comp.get("tag") == "img":
image_key = comp.get("image_key")
if not image_key:
continue
request = (
GetMessageResourceRequest.builder()
.message_id(message.message_id)
.message_id(cast(str, message.message_id))
.file_key(image_key)
.type("image")
.build()
)
if self.lark_api.im is None:
logger.error("[Lark] API Client im 模块未初始化")
continue
response = await self.lark_api.im.v1.message_resource.aget(request)
if not response.success():
logger.error(f"无法下载飞书图片: {image_key}")
continue
if response.file is None:
logger.error(f"飞书图片响应中不包含文件流: {image_key}")
continue
image_bytes = response.file.read()
image_base64 = base64.b64encode(image_bytes).decode()
abm.message.append(Comp.Image.fromBase64(image_base64))
@@ -198,6 +293,19 @@ class LarkPlatformAdapter(Platform):
for comp in abm.message:
if isinstance(comp, Comp.Plain):
abm.message_str += comp.text
if message.message_id is None:
logger.error("[Lark] 消息缺少 message_id")
return
if (
event.event.sender is None
or event.event.sender.sender_id is None
or event.event.sender.sender_id.open_id is None
):
logger.error("[Lark] 消息发送者信息不完整")
return
abm.message_id = message.message_id
abm.raw_message = message
abm.sender = MessageMember(
@@ -229,13 +337,61 @@ class LarkPlatformAdapter(Platform):
self._event_queue.put_nowait(event)
async def handle_webhook_event(self, event_data: dict):
"""处理 Webhook 事件
Args:
event_data: Webhook 事件数据
"""
try:
header = event_data.get("header", {})
event_id = header.get("event_id", "")
if event_id and self._is_duplicate_event(event_id):
logger.debug(f"[Lark Webhook] 跳过重复事件: {event_id}")
return
event_type = header.get("event_type", "")
if event_type == "im.message.receive_v1":
processor = P2ImMessageReceiveV1Processor(self.do_v2_msg_event)
data = (processor.type())(event_data)
processor.do(data)
else:
logger.debug(f"[Lark Webhook] 未处理的事件类型: {event_type}")
except Exception as e:
logger.error(f"[Lark Webhook] 处理事件失败: {e}", exc_info=True)
async def run(self):
# self.client.start()
await self.client._connect()
if self.connection_mode == "webhook":
# Webhook 模式
if self.webhook_server is None:
logger.error("[Lark] Webhook 模式已启用,但 webhook_server 未初始化")
return
webhook_uuid = self.config.get("webhook_uuid")
if webhook_uuid:
log_webhook_info(f"{self.meta().id}(飞书 Webhook)", webhook_uuid)
else:
logger.warning("[Lark] Webhook 模式已启用,但未配置 webhook_uuid")
else:
# 长连接模式
await self.client._connect()
async def webhook_callback(self, request: Any) -> Any:
"""统一 Webhook 回调入口"""
if not self.webhook_server:
return {"error": "Webhook server not initialized"}, 500
return await self.webhook_server.handle_callback(request)
async def terminate(self):
await self.client._disconnect()
logger.info("飞书(Lark) 适配器已被优雅地关闭")
if self.connection_mode == "socket":
await self.client._disconnect()
logger.info("飞书(Lark) 适配器已关闭")
def get_client(self) -> lark.Client:
def get_client(self) -> lark.ws.Client:
return self.client
def unified_webhook(self) -> bool:
return bool(
self.config.get("lark_connection_mode", "") == "webhook"
and self.config.get("webhook_uuid")
)
@@ -5,7 +5,15 @@ import uuid
from io import BytesIO
import lark_oapi as lark
from lark_oapi.api.im.v1 import *
from lark_oapi.api.im.v1 import (
CreateImageRequest,
CreateImageRequestBody,
CreateMessageReactionRequest,
CreateMessageReactionRequestBody,
Emoji,
ReplyMessageRequest,
ReplyMessageRequestBody,
)
from astrbot import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
@@ -44,7 +52,7 @@ class LarkMessageEvent(AstrMessageEvent):
file_path = comp.file.replace("file:///", "")
elif comp.file and comp.file.startswith("http"):
image_file_path = await download_image_by_url(comp.file)
file_path = image_file_path
file_path = image_file_path if image_file_path else ""
elif comp.file and comp.file.startswith("base64://"):
base64_str = comp.file.removeprefix("base64://")
image_data = base64.b64decode(base64_str)
@@ -54,10 +62,17 @@ class LarkMessageEvent(AstrMessageEvent):
with open(file_path, "wb") as f:
f.write(BytesIO(image_data).getvalue())
else:
file_path = comp.file
file_path = comp.file if comp.file else ""
if image_file is None:
image_file = open(file_path, "rb")
if not file_path:
logger.error("[Lark] 图片路径为空,无法上传")
continue
try:
image_file = open(file_path, "rb")
except Exception as e:
logger.error(f"[Lark] 无法打开图片文件: {e}")
continue
request = (
CreateImageRequest.builder()
@@ -69,9 +84,20 @@ class LarkMessageEvent(AstrMessageEvent):
)
.build()
)
if lark_client.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法上传图片")
continue
response = await lark_client.im.v1.image.acreate(request)
if not response.success():
logger.error(f"无法上传飞书图片({response.code}): {response.msg}")
continue
if response.data is None:
logger.error("[Lark] 上传图片成功但未返回数据(data is None)")
continue
image_key = response.data.image_key
logger.debug(image_key)
ret.append(_stage)
@@ -107,6 +133,10 @@ class LarkMessageEvent(AstrMessageEvent):
.build()
)
if self.bot.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法回复消息")
return
response = await self.bot.im.v1.message.areply(request)
if not response.success():
@@ -115,6 +145,10 @@ class LarkMessageEvent(AstrMessageEvent):
await super().send(message)
async def react(self, emoji: str):
if self.bot.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法发送表情")
return
request = (
CreateMessageReactionRequest.builder()
.message_id(self.message_obj.message_id)
@@ -125,6 +159,7 @@ class LarkMessageEvent(AstrMessageEvent):
)
.build()
)
response = await self.bot.im.v1.message_reaction.acreate(request)
if not response.success():
logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}")
@@ -0,0 +1,206 @@
"""飞书(Lark) Webhook 服务器实现
实现飞书事件订阅的 Webhook 模式支持:
1. 请求 URL 验证 (challenge 验证)
2. 事件加密/解密 (AES-256-CBC)
3. 签名校验 (SHA256)
4. 事件接收和处理
"""
import asyncio
import base64
import hashlib
import json
from collections.abc import Awaitable, Callable
from Crypto.Cipher import AES
from astrbot.api import logger
class AESCipher:
"""AES 加密/解密工具类"""
def __init__(self, key: str):
self.bs = AES.block_size
self.key = hashlib.sha256(self.str_to_bytes(key)).digest()
@staticmethod
def str_to_bytes(data):
u_type = type(b"".decode("utf8"))
if isinstance(data, u_type):
return data.encode("utf8")
return data
@staticmethod
def _unpad(s):
return s[: -ord(s[len(s) - 1 :])]
def decrypt(self, enc):
iv = enc[: AES.block_size]
cipher = AES.new(self.key, AES.MODE_CBC, iv)
return self._unpad(cipher.decrypt(enc[AES.block_size :]))
def decrypt_string(self, enc):
enc = base64.b64decode(enc)
return self.decrypt(enc).decode("utf8")
class LarkWebhookServer:
"""飞书 Webhook 服务器
仅支持统一 Webhook 模式
"""
def __init__(self, config: dict, event_queue: asyncio.Queue):
"""初始化 Webhook 服务器
Args:
config: 飞书配置
event_queue: 事件队列
"""
self.app_id = config["app_id"]
self.app_secret = config["app_secret"]
self.encrypt_key = config.get("lark_encrypt_key", "")
self.verification_token = config.get("lark_verification_token", "")
self.event_queue = event_queue
self.callback: Callable[[dict], Awaitable[None]] | None = None
# 初始化加密工具
self.cipher = None
if self.encrypt_key:
self.cipher = AESCipher(self.encrypt_key)
def verify_signature(
self,
timestamp: str,
nonce: str,
encrypt_key: str,
body: bytes,
signature: str,
) -> bool:
"""验证签名
Args:
timestamp: 请求时间戳
nonce: 随机数
encrypt_key: 加密密钥
body: 请求体
signature: 签名
Returns:
签名是否有效
"""
# 拼接字符串: timestamp + nonce + encrypt_key + body
bytes_b1 = (timestamp + nonce + encrypt_key).encode("utf-8")
bytes_b = bytes_b1 + body
h = hashlib.sha256(bytes_b)
calculated_signature = h.hexdigest()
return calculated_signature == signature
def decrypt_event(self, encrypted_data: str) -> dict:
"""解密事件数据
Args:
encrypted_data: 加密的事件数据
Returns:
解密后的事件字典
"""
if not self.cipher:
raise ValueError("未配置 encrypt_key,无法解密事件")
decrypted_str = self.cipher.decrypt_string(encrypted_data)
return json.loads(decrypted_str)
async def handle_challenge(self, event_data: dict) -> dict:
"""处理 challenge 验证请求
Args:
event_data: 事件数据
Returns:
包含 challenge 的响应
"""
challenge = event_data.get("challenge", "")
logger.info(f"[Lark Webhook] 收到 challenge 验证请求: {challenge}")
return {"challenge": challenge}
async def handle_callback(self, request) -> tuple[dict, int] | dict:
"""处理 webhook 回调,可被统一 webhook 入口复用
Args:
request: Quart 请求对象
Returns:
响应数据
"""
# 获取原始请求体
body = await request.get_data()
try:
event_data = await request.json
except Exception as e:
logger.error(f"[Lark Webhook] 解析请求体失败: {e}")
return {"error": "Invalid JSON"}, 400
if not event_data:
logger.error("[Lark Webhook] 请求体为空")
return {"error": "Empty request body"}, 400
# 如果配置了 encrypt_key,进行签名验证
if self.encrypt_key:
timestamp = request.headers.get("X-Lark-Request-Timestamp", "")
nonce = request.headers.get("X-Lark-Request-Nonce", "")
signature = request.headers.get("X-Lark-Signature", "")
if timestamp and nonce and signature:
if not self.verify_signature(
timestamp, nonce, self.encrypt_key, body, signature
):
logger.error("[Lark Webhook] 签名验证失败")
return {"error": "Invalid signature"}, 401
# 检查是否是加密事件
if "encrypt" in event_data:
try:
event_data = self.decrypt_event(event_data["encrypt"])
logger.debug(f"[Lark Webhook] 解密后的事件: {event_data}")
except Exception as e:
logger.error(f"[Lark Webhook] 解密事件失败: {e}")
return {"error": "Decryption failed"}, 400
# 验证 token
if self.verification_token:
header = event_data.get("header", {})
if header:
token = header.get("token", "")
else:
token = event_data.get("token", "")
if token != self.verification_token:
logger.error("[Lark Webhook] Verification Token 不匹配。")
return {"error": "Invalid verification token"}, 401
# 处理 URL 验证 (challenge)
if event_data.get("type") == "url_verification":
return await self.handle_challenge(event_data)
# 调用回调函数处理事件
if self.callback:
try:
await self.callback(event_data)
except Exception as e:
logger.error(f"[Lark Webhook] 处理事件回调失败: {e}", exc_info=True)
return {"error": "Event processing failed"}, 500
return {}
def set_callback(self, callback: Callable[[dict], Awaitable[None]]):
"""设置事件回调函数
Args:
callback: 处理事件的异步函数
"""
self.callback = callback
@@ -1,7 +1,6 @@
import asyncio
import os
import random
from collections.abc import Awaitable
from typing import Any
import astrbot.api.message_components as Comp
@@ -55,8 +54,7 @@ class MisskeyPlatformAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
super().__init__(event_queue)
self.config = platform_config or {}
super().__init__(platform_config or {}, event_queue)
self.settings = platform_settings or {}
self.instance_url = self.config.get("misskey_instance_url", "")
self.access_token = self.config.get("misskey_token", "")
@@ -204,7 +202,7 @@ class MisskeyPlatformAdapter(Platform):
if not isinstance(message.raw_message, dict):
message.raw_message = {}
message.raw_message["poll"] = poll
message.poll = poll
message.__setattr__("poll", poll)
except Exception:
pass
@@ -373,7 +371,7 @@ class MisskeyPlatformAdapter(Platform):
self,
session: MessageSession,
message_chain: MessageChain,
) -> Awaitable[Any]:
) -> None:
if not self.api:
logger.error("[Misskey] API 客户端未初始化")
return await super().send_by_session(session, message_chain)
@@ -3,6 +3,7 @@ import base64
import os
import random
import uuid
from typing import cast
import aiofiles
import botpy
@@ -60,7 +61,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
time_since_last_edit = current_time - last_edit_time
if time_since_last_edit >= throttle_interval:
ret = await self._post_send(stream=stream_payload)
ret = cast(
message.Message,
await self._post_send(stream=stream_payload),
)
stream_payload["index"] += 1
stream_payload["id"] = ret["id"]
last_edit_time = asyncio.get_event_loop().time()
@@ -69,6 +73,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
# 结束流式对话,并且传输 buffer 中剩余的消息
stream_payload["state"] = 10
ret = await self._post_send(stream=stream_payload)
else:
ret = await self._post_send()
except Exception as e:
logger.error(f"发送流式消息时出错: {e}", exc_info=True)
@@ -81,7 +87,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
return None
source = self.message_obj.raw_message
assert isinstance(
if not isinstance(
source,
(
botpy.message.Message,
@@ -89,7 +96,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
botpy.message.DirectMessage,
botpy.message.C2CMessage,
),
)
):
logger.warning(f"[QQOfficial] 不支持的消息源类型: {type(source)}")
return None
(
plain_text,
@@ -106,7 +115,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
):
return None
payload = {
payload: dict = {
"content": plain_text,
"msg_id": self.message_obj.message_id,
}
@@ -116,8 +125,12 @@ class QQOfficialMessageEvent(AstrMessageEvent):
ret = None
match type(source):
case botpy.message.GroupMessage:
match source:
case botpy.message.GroupMessage():
if not source.group_openid:
logger.error("[QQOfficial] GroupMessage 缺少 group_openid")
return None
if image_base64:
media = await self.upload_group_and_c2c_image(
image_base64,
@@ -138,7 +151,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
group_openid=source.group_openid,
**payload,
)
case botpy.message.C2CMessage:
case botpy.message.C2CMessage():
if image_base64:
media = await self.upload_group_and_c2c_image(
image_base64,
@@ -167,18 +181,23 @@ class QQOfficialMessageEvent(AstrMessageEvent):
**payload,
)
logger.debug(f"Message sent to C2C: {ret}")
case botpy.message.Message:
case botpy.message.Message():
if image_path:
payload["file_image"] = image_path
ret = await self.bot.api.post_message(
channel_id=source.channel_id,
**payload,
)
case botpy.message.DirectMessage:
case botpy.message.DirectMessage():
if image_path:
payload["file_image"] = image_path
ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
case _:
pass
await super().send(self.send_buffer)
self.send_buffer = None
@@ -196,18 +215,33 @@ class QQOfficialMessageEvent(AstrMessageEvent):
"file_type": file_type,
"srv_send_msg": False,
}
result = None
if "openid" in kwargs:
payload["openid"] = kwargs["openid"]
route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"])
return await self.bot.api._http.request(route, json=payload)
if "group_openid" in kwargs:
result = await self.bot.api._http.request(route, json=payload)
elif "group_openid" in kwargs:
payload["group_openid"] = kwargs["group_openid"]
route = Route(
"POST",
"/v2/groups/{group_openid}/files",
group_openid=kwargs["group_openid"],
)
return await self.bot.api._http.request(route, json=payload)
result = await self.bot.api._http.request(route, json=payload)
else:
raise ValueError("Invalid upload parameters")
if not isinstance(result, dict):
raise RuntimeError(
f"Failed to upload image, response is not dict: {result}"
)
return Media(
file_uuid=result["file_uuid"],
file_info=result["file_info"],
ttl=result.get("ttl", 0),
)
async def upload_group_and_c2c_record(
self,
@@ -250,11 +284,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
result = await self.bot.api._http.request(route, json=payload)
if result:
if not isinstance(result, dict):
logger.error(f"上传文件响应格式错误: {result}")
return None
return Media(
file_uuid=result.get("file_uuid"),
file_info=result.get("file_info"),
file_uuid=result["file_uuid"],
file_info=result["file_info"],
ttl=result.get("ttl", 0),
file_id=result.get("id", ""),
)
except Exception as e:
logger.error(f"上传请求错误: {e}")
@@ -271,7 +308,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
message_reference: message.Reference | None = None,
media: message.Media | None = None,
msg_id: str | None = None,
msg_seq: str = 1,
msg_seq: int | None = 1,
event_id: str | None = None,
markdown: message.MarkdownPayload | None = None,
keyboard: message.Keyboard | None = None,
@@ -280,7 +317,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
payload = locals()
payload.pop("self", None)
route = Route("POST", "/v2/users/{openid}/messages", openid=openid)
return await self.bot.api._http.request(route, json=payload)
result = await self.bot.api._http.request(route, json=payload)
if not isinstance(result, dict):
raise RuntimeError(
f"Failed to post c2c message, response is not dict: {result}"
)
return message.Message(**result)
@staticmethod
async def _parse_to_qqofficial(message: MessageChain):
@@ -300,8 +344,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
image_base64 = file_to_base64(image_file_path)
elif i.file and i.file.startswith("base64://"):
image_base64 = i.file
else:
elif i.file:
image_base64 = file_to_base64(i.file)
else:
raise ValueError("Unsupported image file format")
image_base64 = image_base64.removeprefix("base64://")
elif isinstance(i, Record):
if i.file:
@@ -4,6 +4,7 @@ import asyncio
import logging
import os
import time
from typing import cast
import botpy
import botpy.message
@@ -44,7 +45,9 @@ class botClient(Client):
MessageType.GROUP_MESSAGE,
)
abm.session_id = (
abm.sender.user_id if self.platform.unique_session else message.group_openid
abm.sender.user_id
if self.platform.unique_session
else cast(str, message.group_openid)
)
self._commit(abm)
@@ -97,13 +100,11 @@ class QQOfficialPlatformAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
super().__init__(event_queue)
self.config = platform_config
super().__init__(platform_config, event_queue)
self.appid = platform_config["appid"]
self.secret = platform_config["secret"]
self.unique_session = platform_settings["unique_session"]
self.unique_session: bool = platform_settings["unique_session"]
qq_group = platform_config["enable_group_c2c"]
guild_dm = platform_config["enable_guild_direct_message"]
@@ -139,12 +140,15 @@ class QQOfficialPlatformAdapter(Platform):
return PlatformMetadata(
name="qq_official",
description="QQ 机器人官方 API 适配器",
id=self.config.get("id"),
id=cast(str, self.config.get("id")),
)
@staticmethod
def _parse_from_qqofficial(
message: botpy.message.Message | botpy.message.GroupMessage,
message: botpy.message.Message
| botpy.message.GroupMessage
| botpy.message.DirectMessage
| botpy.message.C2CMessage,
message_type: MessageType,
):
abm = AstrBotMessage()
@@ -152,7 +156,7 @@ class QQOfficialPlatformAdapter(Platform):
abm.timestamp = int(time.time())
abm.raw_message = message
abm.message_id = message.id
abm.tag = "qq_official"
# abm.tag = "qq_official"
msg: list[BaseMessageComponent] = []
if isinstance(message, botpy.message.GroupMessage) or isinstance(
@@ -182,9 +186,9 @@ class QQOfficialPlatformAdapter(Platform):
message,
botpy.message.DirectMessage,
):
try:
if isinstance(message, botpy.message.Message):
abm.self_id = str(message.mentions[0].id)
except BaseException as _:
else:
abm.self_id = ""
plain_content = message.content.replace(
@@ -1,5 +1,6 @@
import asyncio
import logging
from typing import Any, cast
import botpy
import botpy.message
@@ -11,6 +12,7 @@ from astrbot import logger
from astrbot.api.event import MessageChain
from astrbot.api.platform import AstrBotMessage, MessageType, Platform, PlatformMetadata
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.webhook_utils import log_webhook_info
from ...register import register_platform_adapter
from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter
@@ -34,7 +36,9 @@ class botClient(Client):
MessageType.GROUP_MESSAGE,
)
abm.session_id = (
abm.sender.user_id if self.platform.unique_session else message.group_openid
abm.sender.user_id
if self.platform.unique_session
else cast(str, message.group_openid)
)
self._commit(abm)
@@ -87,13 +91,12 @@ class QQOfficialWebhookPlatformAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
super().__init__(event_queue)
self.config = platform_config
super().__init__(platform_config, event_queue)
self.appid = platform_config["appid"]
self.secret = platform_config["secret"]
self.unique_session = platform_settings["unique_session"]
self.unified_webhook_mode = platform_config.get("unified_webhook_mode", False)
intents = botpy.Intents(
public_messages=True,
@@ -106,6 +109,7 @@ class QQOfficialWebhookPlatformAdapter(Platform):
timeout=20,
)
self.client.set_platform(self)
self.webhook_helper = None
async def send_by_session(
self,
@@ -118,7 +122,7 @@ class QQOfficialWebhookPlatformAdapter(Platform):
return PlatformMetadata(
name="qq_official_webhook",
description="QQ 机器人官方 API 适配器",
id=self.config.get("id"),
id=cast(str, self.config.get("id")),
)
async def run(self):
@@ -128,16 +132,37 @@ class QQOfficialWebhookPlatformAdapter(Platform):
self.client,
)
await self.webhook_helper.initialize()
await self.webhook_helper.start_polling()
# 如果启用统一 webhook 模式,则不启动独立服务器
webhook_uuid = self.config.get("webhook_uuid")
if self.unified_webhook_mode and webhook_uuid:
log_webhook_info(f"{self.meta().id}(QQ 官方机器人 Webhook)", webhook_uuid)
# 保持运行状态,等待 shutdown
await self.webhook_helper.shutdown_event.wait()
else:
await self.webhook_helper.start_polling()
def get_client(self) -> botClient:
return self.client
async def webhook_callback(self, request: Any) -> Any:
"""统一 Webhook 回调入口"""
if not self.webhook_helper:
return {"error": "Webhook helper not initialized"}, 500
# 复用 webhook_helper 的回调处理逻辑
return await self.webhook_helper.handle_callback(request)
async def terminate(self):
self.webhook_helper.shutdown_event.set()
if self.webhook_helper:
self.webhook_helper.shutdown_event.set()
await self.client.close()
try:
await self.webhook_helper.server.shutdown()
except Exception as _:
pass
if self.webhook_helper and not self.unified_webhook_mode:
try:
await self.webhook_helper.server.shutdown()
except Exception as exc:
logger.warning(
f"Exception occurred during QQOfficialWebhook server shutdown: {exc}",
exc_info=True,
)
logger.info("QQ 机器人官方 API 适配器已经被优雅地关闭")
@@ -1,5 +1,6 @@
import asyncio
import logging
from typing import cast
import quart
from botpy import BotAPI, BotHttp, BotWebSocket, Client, ConnectionSession, Token
@@ -78,7 +79,19 @@ class QQOfficialWebhook:
return response
async def callback(self):
msg: dict = await quart.request.json
"""内部服务器的回调入口"""
return await self.handle_callback(quart.request)
async def handle_callback(self, request) -> dict:
"""处理 webhook 回调,可被统一 webhook 入口复用
Args:
request: Quart 请求对象
Returns:
响应数据
"""
msg: dict = await request.json
logger.debug(f"收到 qq_official_webhook 回调: {msg}")
event = msg.get("t")
@@ -87,7 +100,7 @@ class QQOfficialWebhook:
if opcode == 13:
# validation
signed = await self.webhook_validation(data)
signed = await self.webhook_validation(cast(dict, data))
print(signed)
return signed
@@ -38,8 +38,7 @@ class SatoriPlatformAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
super().__init__(event_queue)
self.config = platform_config
super().__init__(platform_config, event_queue)
self.settings = platform_settings
self.api_base_url = self.config.get(
+58 -40
View File
@@ -4,9 +4,11 @@ import hmac
import json
import logging
from collections.abc import Callable
from typing import cast
from quart import Quart, Response, request
from slack_sdk.socket_mode.aiohttp import SocketModeClient
from slack_sdk.socket_mode.async_client import AsyncBaseSocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from slack_sdk.web.async_client import AsyncWebClient
@@ -47,51 +49,62 @@ class SlackWebhookClient:
@self.app.route(self.path, methods=["POST"])
async def slack_events():
"""处理 Slack 事件"""
try:
# 获取请求体和头部
body = await request.get_data()
event_data = json.loads(body.decode("utf-8"))
# Verify Slack request signature
timestamp = request.headers.get("X-Slack-Request-Timestamp")
signature = request.headers.get("X-Slack-Signature")
if not timestamp or not signature:
return Response("Missing headers", status=400)
# Calculate the HMAC signature
sig_basestring = f"v0:{timestamp}:{body.decode('utf-8')}"
my_signature = (
"v0="
+ hmac.new(
self.signing_secret.encode("utf-8"),
sig_basestring.encode("utf-8"),
hashlib.sha256,
).hexdigest()
)
# Verify the signature
if not hmac.compare_digest(my_signature, signature):
logger.warning("Slack request signature verification failed")
return Response("Invalid signature", status=400)
logger.info(f"Received Slack event: {event_data}")
# 处理 URL 验证事件
if event_data.get("type") == "url_verification":
return {"challenge": event_data.get("challenge")}
# 处理事件
if self.event_handler and event_data.get("type") == "event_callback":
await self.event_handler(event_data)
return Response("", status=200)
except Exception as e:
logger.error(f"处理 Slack 事件时出错: {e}")
return Response("Internal Server Error", status=500)
"""内部服务器的 POST 回调入口"""
return await self.handle_callback(request)
@self.app.route("/health", methods=["GET"])
async def health_check():
"""健康检查端点"""
return {"status": "ok", "service": "slack-webhook"}
async def handle_callback(self, req):
"""处理 Slack 回调请求,可被统一 webhook 入口复用
Args:
req: Quart 请求对象
Returns:
Response 对象或字典
"""
try:
# 获取请求体和头部
body = cast(bytes, await req.get_data())
event_data = json.loads(body.decode("utf-8"))
# Verify Slack request signature
timestamp = req.headers.get("X-Slack-Request-Timestamp")
signature = req.headers.get("X-Slack-Signature")
if not timestamp or not signature:
return Response("Missing headers", status=400)
# Calculate the HMAC signature
sig_basestring = f"v0:{timestamp}:{body.decode('utf-8')}"
my_signature = (
"v0="
+ hmac.new(
self.signing_secret.encode("utf-8"),
sig_basestring.encode("utf-8"),
hashlib.sha256,
).hexdigest()
)
# Verify the signature
if not hmac.compare_digest(my_signature, signature):
logger.warning("Slack request signature verification failed")
return Response("Invalid signature", status=400)
logger.info(f"Received Slack event: {event_data}")
# 处理 URL 验证事件
if event_data.get("type") == "url_verification":
return {"challenge": event_data.get("challenge")}
# 处理事件
if self.event_handler and event_data.get("type") == "event_callback":
await self.event_handler(event_data)
return Response("", status=200)
except Exception as e:
logger.error(f"处理 Slack 事件时出错: {e}")
return Response("Internal Server Error", status=500)
async def start(self):
"""启动 Webhook 服务器"""
logger.info(
@@ -128,9 +141,14 @@ class SlackSocketClient:
self.event_handler = event_handler
self.socket_client = None
async def _handle_events(self, _: SocketModeClient, req: SocketModeRequest):
async def _handle_events(
self, _: AsyncBaseSocketModeClient, req: SocketModeRequest
):
"""处理 Socket Mode 事件"""
try:
if self.socket_client is None:
raise RuntimeError("Socket client is not initialized")
# 确认收到事件
response = SocketModeResponse(envelope_id=req.envelope_id)
await self.socket_client.send_socket_mode_response(response)
@@ -3,8 +3,7 @@ import base64
import re
import time
import uuid
from collections.abc import Awaitable
from typing import Any
from typing import Any, cast
import aiohttp
from slack_sdk.socket_mode.request import SocketModeRequest
@@ -21,6 +20,7 @@ from astrbot.api.platform import (
PlatformMetadata,
)
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.webhook_utils import log_webhook_info
from ...register import register_platform_adapter
from .client import SlackSocketClient, SlackWebhookClient
@@ -39,9 +39,7 @@ class SlackAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
super().__init__(event_queue)
self.config = platform_config
super().__init__(platform_config, event_queue)
self.settings = platform_settings
self.unique_session = platform_settings.get("unique_session", False)
@@ -49,6 +47,7 @@ class SlackAdapter(Platform):
self.app_token = platform_config.get("app_token")
self.signing_secret = platform_config.get("signing_secret")
self.connection_mode = platform_config.get("slack_connection_mode", "socket")
self.unified_webhook_mode = platform_config.get("unified_webhook_mode", False)
self.webhook_host = platform_config.get("slack_webhook_host", "0.0.0.0")
self.webhook_port = platform_config.get("slack_webhook_port", 3000)
self.webhook_path = platform_config.get(
@@ -68,7 +67,7 @@ class SlackAdapter(Platform):
self.metadata = PlatformMetadata(
name="slack",
description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
id=self.config.get("id"),
id=cast(str, self.config.get("id")),
support_streaming_message=False,
)
@@ -118,13 +117,13 @@ class SlackAdapter(Platform):
logger.debug(f"[slack] RawMessage {event}")
abm = AstrBotMessage()
abm.self_id = self.bot_self_id
abm.self_id = cast(str, self.bot_self_id)
# 获取用户信息
user_id = event.get("user", "")
try:
user_info = await self.web_client.users_info(user=user_id)
user_data = user_info["user"]
user_data = cast(dict, user_info["user"])
user_name = user_data.get("real_name") or user_data.get("name", user_id)
except Exception:
user_name = user_id
@@ -135,7 +134,7 @@ class SlackAdapter(Platform):
channel_id = event.get("channel", "")
try:
channel_info = await self.web_client.conversations_info(channel=channel_id)
is_im = channel_info["channel"]["is_im"]
is_im = cast(dict, channel_info["channel"])["is_im"]
if is_im:
abm.type = MessageType.FRIEND_MESSAGE
@@ -178,7 +177,7 @@ class SlackAdapter(Platform):
for mention in mentions:
try:
mentioned_user = await self.web_client.users_info(user=mention)
user_data = mentioned_user["user"]
user_data = cast(dict, mentioned_user["user"])
user_name = user_data.get("real_name") or user_data.get(
"name",
mention,
@@ -329,7 +328,7 @@ class SlackAdapter(Platform):
)
raise Exception(f"下载文件失败: {resp.status}")
async def run(self) -> Awaitable[Any]:
async def run(self) -> None:
self.bot_self_id = await self.get_bot_user_id()
logger.info(f"Slack auth test OK. Bot ID: {self.bot_self_id}")
@@ -361,10 +360,17 @@ class SlackAdapter(Platform):
self._handle_webhook_event,
)
logger.info(
f"Slack 适配器 (Webhook Mode) 启动中,监听 {self.webhook_host}:{self.webhook_port}{self.webhook_path}...",
)
await self.webhook_client.start()
# 如果启用统一 webhook 模式,则不启动独立服务器
webhook_uuid = self.config.get("webhook_uuid")
if self.unified_webhook_mode and webhook_uuid:
log_webhook_info(f"{self.meta().id}(Slack)", webhook_uuid)
# 保持运行状态,等待 shutdown
await self.webhook_client.shutdown_event.wait()
else:
logger.info(
f"Slack 适配器 (Webhook Mode) 启动中,监听 {self.webhook_host}:{self.webhook_port}{self.webhook_path}...",
)
await self.webhook_client.start()
else:
raise ValueError(
@@ -391,12 +397,19 @@ class SlackAdapter(Platform):
if abm:
await self.handle_msg(abm)
async def webhook_callback(self, request: Any) -> Any:
"""统一 Webhook 回调入口"""
if self.connection_mode != "webhook" or not self.webhook_client:
return {"error": "Slack adapter is not in webhook mode"}, 400
return await self.webhook_client.handle_callback(request)
async def terminate(self):
if self.socket_client:
await self.socket_client.stop()
if self.webhook_client:
await self.webhook_client.stop()
logger.info("Slack 适配器已被优雅地关闭")
logger.info("Slack 适配器已被关闭")
def meta(self) -> PlatformMetadata:
return self.metadata
@@ -414,3 +427,10 @@ class SlackAdapter(Platform):
def get_client(self):
return self.web_client
def unified_webhook(self) -> bool:
return bool(
self.config.get("unified_webhook_mode", False)
and self.config.get("slack_connection_mode", "") == "webhook"
and self.config.get("webhook_uuid")
)
@@ -1,6 +1,7 @@
import asyncio
import re
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Iterable
from typing import cast
from slack_sdk.web.async_client import AsyncWebClient
@@ -31,14 +32,14 @@ class SlackMessageEvent(AstrMessageEvent):
async def _from_segment_to_slack_block(
segment: BaseMessageComponent,
web_client: AsyncWebClient,
) -> dict:
) -> dict | None:
"""将消息段转换为 Slack 块格式"""
if isinstance(segment, Plain):
return {"type": "section", "text": {"type": "mrkdwn", "text": segment.text}}
if isinstance(segment, Image):
# upload file
url = segment.url or segment.file
if url.startswith("http"):
if url and url.startswith("http"):
return {
"type": "image",
"image_url": url,
@@ -55,7 +56,7 @@ class SlackMessageEvent(AstrMessageEvent):
"type": "section",
"text": {"type": "mrkdwn", "text": "图片上传失败"},
}
image_url = response["files"][0]["url_private"]
image_url = cast(list, response["files"])[0]["url_private"]
logger.debug(f"Slack file upload response: {response}")
return {
"type": "image",
@@ -77,7 +78,7 @@ class SlackMessageEvent(AstrMessageEvent):
"type": "section",
"text": {"type": "mrkdwn", "text": "文件上传失败"},
}
file_url = response["files"][0]["permalink"]
file_url = cast(list, response["files"])[0]["permalink"]
return {
"type": "section",
"text": {
@@ -85,7 +86,6 @@ class SlackMessageEvent(AstrMessageEvent):
"text": f"文件: <{file_url}|{segment.name or '文件'}>",
},
}
return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}}
@staticmethod
async def _parse_slack_blocks(
@@ -115,7 +115,8 @@ class SlackMessageEvent(AstrMessageEvent):
segment,
web_client,
)
blocks.append(block)
if block:
blocks.append(block)
# 如果最后还有文本内容
if text_content.strip():
@@ -225,10 +226,10 @@ class SlackMessageEvent(AstrMessageEvent):
)
members = []
for member_id in members_response["members"]:
for member_id in cast(Iterable, members_response["members"]):
try:
user_info = await self.web_client.users_info(user=member_id)
user_data = user_info["user"]
user_data = cast(dict, user_info["user"])
members.append(
MessageMember(
user_id=member_id,
@@ -240,7 +241,7 @@ class SlackMessageEvent(AstrMessageEvent):
# 如果获取用户信息失败,使用默认信息
members.append(MessageMember(user_id=member_id, nickname=member_id))
channel_data = channel_info["channel"]
channel_data = cast(dict, channel_info["channel"])
return Group(
group_id=channel_id,
group_name=channel_data.get("name", ""),
@@ -42,8 +42,7 @@ class TelegramPlatformAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
super().__init__(event_queue)
self.config = platform_config
super().__init__(platform_config, event_queue)
self.settings = platform_settings
self.client_self_id = uuid.uuid4().hex[:8]
@@ -381,7 +380,9 @@ class TelegramPlatformAdapter(Platform):
f"Telegram document file_path is None, cannot save the file {file_name}.",
)
else:
message.message.append(Comp.File(file=file_path, name=file_name))
message.message.append(
Comp.File(file=file_path, name=file_name, url=file_path)
)
elif update.message.video:
file = await update.message.video.get_file()
@@ -423,6 +424,6 @@ class TelegramPlatformAdapter(Platform):
if self.application.updater is not None:
await self.application.updater.stop()
logger.info("Telegram 适配器已被优雅地关闭")
logger.info("Telegram 适配器已被关闭")
except Exception as e:
logger.error(f"Telegram 适配器关闭时出错: {e}")
@@ -1,6 +1,7 @@
import asyncio
import os
import re
from typing import Any, cast
import telegramify_markdown
from telegram import ReactionTypeCustomEmoji, ReactionTypeEmoji
@@ -17,8 +18,6 @@ from astrbot.api.message_components import (
Reply,
)
from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file
class TelegramPlatformEvent(AstrMessageEvent):
@@ -97,7 +96,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
"chat_id": user_name,
}
if has_reply:
payload["reply_to_message_id"] = reply_message_id
payload["reply_to_message_id"] = str(reply_message_id)
if message_thread_id:
payload["message_thread_id"] = message_thread_id
@@ -110,33 +109,30 @@ class TelegramPlatformEvent(AstrMessageEvent):
try:
md_text = telegramify_markdown.markdownify(
chunk,
max_line_length=None,
normalize_whitespace=False,
)
await client.send_message(
text=md_text,
parse_mode="MarkdownV2",
**payload,
**cast(Any, payload),
)
except Exception as e:
logger.warning(
f"MarkdownV2 send failed: {e}. Using plain text instead.",
)
await client.send_message(text=chunk, **payload)
await client.send_message(text=chunk, **cast(Any, payload))
elif isinstance(i, Image):
image_path = await i.convert_to_file_path()
await client.send_photo(photo=image_path, **payload)
await client.send_photo(photo=image_path, **cast(Any, payload))
elif isinstance(i, File):
if i.file.startswith("https://"):
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, i.name)
await download_file(i.file, path)
i.file = path
await client.send_document(document=i.file, filename=i.name, **payload)
path = await i.get_file()
name = i.name or os.path.basename(path)
await client.send_document(
document=path, filename=name, **cast(Any, payload)
)
elif isinstance(i, Record):
path = await i.convert_to_file_path()
await client.send_voice(voice=path, **payload)
await client.send_voice(voice=path, **cast(Any, payload))
async def send(self, message: MessageChain):
if self.get_message_type() == MessageType.GROUP_MESSAGE:
@@ -214,24 +210,23 @@ class TelegramPlatformEvent(AstrMessageEvent):
delta += i.text
elif isinstance(i, Image):
image_path = await i.convert_to_file_path()
await self.client.send_photo(photo=image_path, **payload)
await self.client.send_photo(
photo=image_path, **cast(Any, payload)
)
continue
elif isinstance(i, File):
if i.file.startswith("https://"):
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, i.name)
await download_file(i.file, path)
i.file = path
path = await i.get_file()
name = i.name or os.path.basename(path)
await self.client.send_document(
document=i.file,
filename=i.name,
**payload,
document=path,
filename=name,
**cast(Any, payload),
)
continue
elif isinstance(i, Record):
path = await i.convert_to_file_path()
await self.client.send_voice(voice=path, **payload)
await self.client.send_voice(voice=path, **cast(Any, payload))
continue
else:
logger.warning(f"不支持的消息类型: {type(i)}")
@@ -260,7 +255,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
else:
# delta 长度一般不会大于 4096,因此这里直接发送
try:
msg = await self.client.send_message(text=delta, **payload)
msg = await self.client.send_message(
text=delta, **cast(Any, payload)
)
current_content = delta
except Exception as e:
logger.warning(f"发送消息失败(streaming): {e!s}")
@@ -274,7 +271,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
try:
markdown_text = telegramify_markdown.markdownify(
delta,
max_line_length=None,
normalize_whitespace=False,
)
await self.client.edit_message_text(
@@ -2,11 +2,13 @@ import asyncio
import os
import time
import uuid
from collections.abc import Awaitable, Callable
from collections.abc import Callable, Coroutine
from typing import Any
from astrbot import logger
from astrbot.core.message.components import Image, Plain, Record
from astrbot.core import db_helper
from astrbot.core.db.po import PlatformMessageHistory
from astrbot.core.message.components import File, Image, Plain, Record, Reply, Video
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.platform import (
AstrBotMessage,
@@ -74,9 +76,8 @@ class WebChatAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
super().__init__(event_queue)
super().__init__(platform_config, event_queue)
self.config = platform_config
self.settings = platform_settings
self.unique_session = platform_settings["unique_session"]
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
@@ -96,6 +97,92 @@ class WebChatAdapter(Platform):
await WebChatMessageEvent._send(message_chain, session.session_id)
await super().send_by_session(session, message_chain)
async def _get_message_history(
self, message_id: int
) -> PlatformMessageHistory | None:
return await db_helper.get_platform_message_history_by_id(message_id)
async def _parse_message_parts(
self,
message_parts: list,
depth: int = 0,
max_depth: int = 1,
) -> tuple[list, list[str]]:
"""解析消息段列表,返回消息组件列表和纯文本列表
Args:
message_parts: 消息段列表
depth: 当前递归深度
max_depth: 最大递归深度用于处理 reply
Returns:
tuple[list, list[str]]: (消息组件列表, 纯文本列表)
"""
components = []
text_parts = []
for part in message_parts:
part_type = part.get("type")
if part_type == "plain":
text = part.get("text", "")
components.append(Plain(text))
text_parts.append(text)
elif part_type == "reply":
message_id = part.get("message_id")
reply_chain = []
reply_message_str = ""
sender_id = None
sender_name = None
# recursively get the content of the referenced message
if depth < max_depth and message_id:
history = await self._get_message_history(message_id)
if history and history.content:
reply_parts = history.content.get("message", [])
if isinstance(reply_parts, list):
(
reply_chain,
reply_text_parts,
) = await self._parse_message_parts(
reply_parts,
depth=depth + 1,
max_depth=max_depth,
)
reply_message_str = "".join(reply_text_parts)
sender_id = history.sender_id
sender_name = history.sender_name
components.append(
Reply(
id=message_id,
chain=reply_chain,
message_str=reply_message_str,
sender_id=sender_id,
sender_nickname=sender_name,
)
)
elif part_type == "image":
path = part.get("path")
if path:
components.append(Image.fromFileSystem(path))
elif part_type == "record":
path = part.get("path")
if path:
components.append(Record.fromFileSystem(path))
elif part_type == "file":
path = part.get("path")
if path:
filename = part.get("filename") or (
os.path.basename(path) if path else "file"
)
components.append(File(name=filename, file=path))
elif part_type == "video":
path = part.get("path")
if path:
components.append(Video.fromFileSystem(path))
return components, text_parts
async def convert_message(self, data: tuple) -> AstrBotMessage:
username, cid, payload = data
@@ -108,40 +195,19 @@ class WebChatAdapter(Platform):
abm.session_id = f"webchat!{username}!{cid}"
abm.message_id = str(uuid.uuid4())
abm.message = []
if payload["message"]:
abm.message.append(Plain(payload["message"]))
if payload["image_url"]:
if isinstance(payload["image_url"], list):
for img in payload["image_url"]:
abm.message.append(
Image.fromFileSystem(os.path.join(self.imgs_dir, img)),
)
else:
abm.message.append(
Image.fromFileSystem(
os.path.join(self.imgs_dir, payload["image_url"]),
),
)
if payload["audio_url"]:
if isinstance(payload["audio_url"], list):
for audio in payload["audio_url"]:
path = os.path.join(self.imgs_dir, audio)
abm.message.append(Record(file=path, path=path))
else:
path = os.path.join(self.imgs_dir, payload["audio_url"])
abm.message.append(Record(file=path, path=path))
# 处理消息段列表
message_parts = payload.get("message", [])
abm.message, message_str_parts = await self._parse_message_parts(message_parts)
logger.debug(f"WebChatAdapter: {abm.message}")
message_str = payload["message"]
abm.timestamp = int(time.time())
abm.message_str = message_str
abm.message_str = "".join(message_str_parts)
abm.raw_message = data
return abm
def run(self) -> Awaitable[Any]:
def run(self) -> Coroutine[Any, Any, None]:
async def callback(data: tuple):
abm = await self.convert_message(data)
await self.handle_msg(abm)
@@ -1,12 +1,12 @@
import base64
import os
import shutil
import uuid
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Image, Plain, Record
from astrbot.api.message_components import File, Image, Plain, Record
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_image_by_url
from .webchat_queue_mgr import webchat_queue_mgr
@@ -19,7 +19,9 @@ class WebChatMessageEvent(AstrMessageEvent):
os.makedirs(imgs_dir, exist_ok=True)
@staticmethod
async def _send(message: MessageChain, session_id: str, streaming: bool = False):
async def _send(
message: MessageChain | None, session_id: str, streaming: bool = False
) -> str | None:
cid = session_id.split("!")[-1]
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
if not message:
@@ -30,7 +32,7 @@ class WebChatMessageEvent(AstrMessageEvent):
"streaming": False,
}, # end means this request is finished
)
return ""
return
data = ""
for comp in message.chain:
@@ -47,24 +49,11 @@ class WebChatMessageEvent(AstrMessageEvent):
)
elif isinstance(comp, Image):
# save image to local
filename = str(uuid.uuid4()) + ".jpg"
filename = f"{str(uuid.uuid4())}.jpg"
path = os.path.join(imgs_dir, filename)
if comp.file and comp.file.startswith("file:///"):
ph = comp.file[8:]
with open(path, "wb") as f:
with open(ph, "rb") as f2:
f.write(f2.read())
elif comp.file.startswith("base64://"):
base64_str = comp.file[9:]
image_data = base64.b64decode(base64_str)
with open(path, "wb") as f:
f.write(image_data)
elif comp.file and comp.file.startswith("http"):
await download_image_by_url(comp.file, path=path)
else:
with open(path, "wb") as f:
with open(comp.file, "rb") as f2:
f.write(f2.read())
image_base64 = await comp.convert_to_base64()
with open(path, "wb") as f:
f.write(base64.b64decode(image_base64))
data = f"[IMAGE]{filename}"
await web_chat_back_queue.put(
{
@@ -76,19 +65,11 @@ class WebChatMessageEvent(AstrMessageEvent):
)
elif isinstance(comp, Record):
# save record to local
filename = str(uuid.uuid4()) + ".wav"
filename = f"{str(uuid.uuid4())}.wav"
path = os.path.join(imgs_dir, filename)
if comp.file and comp.file.startswith("file:///"):
ph = comp.file[8:]
with open(path, "wb") as f:
with open(ph, "rb") as f2:
f.write(f2.read())
elif comp.file and comp.file.startswith("http"):
await download_image_by_url(comp.file, path=path)
else:
with open(path, "wb") as f:
with open(comp.file, "rb") as f2:
f.write(f2.read())
record_base64 = await comp.convert_to_base64()
with open(path, "wb") as f:
f.write(base64.b64decode(record_base64))
data = f"[RECORD]{filename}"
await web_chat_back_queue.put(
{
@@ -98,14 +79,31 @@ class WebChatMessageEvent(AstrMessageEvent):
"streaming": streaming,
},
)
elif isinstance(comp, File):
# save file to local
file_path = await comp.get_file()
original_name = comp.name or os.path.basename(file_path)
ext = os.path.splitext(original_name)[1] or ""
filename = f"{uuid.uuid4()!s}{ext}"
dest_path = os.path.join(imgs_dir, filename)
shutil.copy2(file_path, dest_path)
data = f"[FILE]{filename}|{original_name}"
await web_chat_back_queue.put(
{
"type": "file",
"cid": cid,
"data": data,
"streaming": streaming,
},
)
else:
logger.debug(f"webchat 忽略: {comp.type}")
return data
async def send(self, message: MessageChain):
async def send(self, message: MessageChain | None):
await WebChatMessageEvent._send(message, session_id=self.session_id)
await super().send(message)
await super().send(MessageChain([]))
async def send_streaming(self, generator, use_fallback: bool = False):
final_data = ""
@@ -131,6 +129,8 @@ class WebChatMessageEvent(AstrMessageEvent):
session_id=self.session_id,
streaming=True,
)
if not r:
continue
if chain.type == "reasoning":
reasoning_content += chain.get_plain_text()
else:
@@ -4,6 +4,7 @@ import json
import os
import time
import traceback
from typing import cast
import aiohttp
import anyio
@@ -42,10 +43,9 @@ class WeChatPadProAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
super().__init__(event_queue)
super().__init__(platform_config, event_queue)
self._shutdown_event = None
self.wxnewpass = None
self.config = platform_config
self.settings = platform_settings
self.unique_session = platform_settings.get("unique_session", False)
@@ -70,7 +70,7 @@ class WeChatPadProAdapter(Platform):
)
self.base_url = f"http://{self.host}:{self.port}"
self.auth_key = None # 用于保存生成的授权码
self.wxid = None # 用于保存登录成功后的 wxid
self.wxid: str | None = None # 用于保存登录成功后的 wxid
self.credentials_file = os.path.join(
get_astrbot_data_path(),
"wechatpadpro_credentials.json",
@@ -399,7 +399,7 @@ class WeChatPadProAdapter(Platform):
)
await asyncio.sleep(5)
async def handle_websocket_message(self, message: str):
async def handle_websocket_message(self, message: str | bytes):
"""处理从 WebSocket 接收到的消息。"""
logger.debug(f"收到 WebSocket 消息: {message}")
try:
@@ -431,10 +431,13 @@ class WeChatPadProAdapter(Platform):
async def convert_message(self, raw_message: dict) -> AstrBotMessage | None:
"""将 WeChatPadPro 原始消息转换为 AstrBotMessage。"""
if self.wxid is None:
logger.error("WeChatPadPro 适配器未登录或未获取到 wxid,无法处理消息。")
return None
abm = AstrBotMessage()
abm.raw_message = raw_message
abm.message_id = str(raw_message.get("msg_id"))
abm.timestamp = raw_message.get("create_time")
abm.timestamp = cast(int, raw_message.get("create_time"))
abm.self_id = self.wxid
if int(time.time()) - abm.timestamp > 180:
@@ -447,7 +450,7 @@ class WeChatPadProAdapter(Platform):
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
content = raw_message.get("content", {}).get("str", "")
push_content = raw_message.get("push_content", "")
msg_type = raw_message.get("msg_type")
msg_type = cast(int, raw_message.get("msg_type"))
abm.message_str = ""
abm.message = []
@@ -575,7 +578,7 @@ class WeChatPadProAdapter(Platform):
from_user_name: str,
to_user_name: str,
msg_id: int,
):
) -> dict | None:
"""下载原始图片。"""
url = f"{self.base_url}/message/GetMsgBigImg"
params = {"key": self.auth_key}
@@ -726,12 +729,15 @@ class WeChatPadProAdapter(Platform):
# 图片消息
from_user_name = raw_message.get("from_user_name", {}).get("str", "")
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
msg_id = raw_message.get("msg_id")
msg_id = cast(int, raw_message.get("msg_id"))
image_resp = await self._download_raw_image(
from_user_name,
to_user_name,
msg_id,
)
if image_resp is None:
logger.error(f"下载图片失败: msg_id={msg_id}")
return
image_bs64_data = (
image_resp.get("Data", {}).get("Data", {}).get("Buffer", None)
)
@@ -772,6 +778,9 @@ class WeChatPadProAdapter(Platform):
bufid = 0
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
new_msg_id = raw_message.get("new_msg_id")
if new_msg_id is None:
logger.error("语音消息缺少 new_msg_id")
return
data_parser = GeweDataParser(
content=content,
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
@@ -779,6 +788,9 @@ class WeChatPadProAdapter(Platform):
)
voicemsg = data_parser._format_to_xml().find("voicemsg")
if voicemsg is None:
logger.error("无法从 XML 解析 voicemsg 节点")
return
bufid = voicemsg.get("bufid") or "0"
length = int(voicemsg.get("length") or 0)
voice_resp = await self.download_voice(
@@ -787,6 +799,9 @@ class WeChatPadProAdapter(Platform):
bufid=bufid,
length=length,
)
if voice_resp is None:
logger.error(f"下载语音失败: new_msg_id={new_msg_id}")
return
voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None)
if voice_bs64_data:
voice_bs64_data = base64.b64decode(voice_bs64_data)
@@ -828,7 +843,8 @@ class WeChatPadProAdapter(Platform):
try:
if self.ws_handle_task:
self.ws_handle_task.cancel()
self._shutdown_event.set()
if self._shutdown_event is not None:
self._shutdown_event.set()
except Exception:
pass
@@ -895,8 +911,8 @@ class WeChatPadProAdapter(Platform):
async def get_contact_details_list(
self,
room_wx_id_list: list[str] = None,
user_names: list[str] = None,
room_wx_id_list: list[str] | None = None,
user_names: list[str] | None = None,
) -> dict | None:
"""获取联系人详情列表。"""
if room_wx_id_list is None:
@@ -2,6 +2,8 @@ import asyncio
import os
import sys
import uuid
from collections.abc import Awaitable, Callable
from typing import Any, cast
import quart
from requests import Response
@@ -24,6 +26,7 @@ from astrbot.api.platform import (
from astrbot.core import logger
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.webhook_utils import log_webhook_info
from .wecom_event import WecomPlatformEvent
from .wecom_kf import WeChatKF
@@ -38,7 +41,7 @@ else:
class WecomServer:
def __init__(self, event_queue: asyncio.Queue, config: dict):
self.server = quart.Quart(__name__)
self.port = int(config.get("port"))
self.port = int(cast(str, config.get("port")))
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
self.server.add_url_rule(
"/callback/command",
@@ -58,12 +61,24 @@ class WecomServer:
config["corpid"].strip(),
)
self.callback = None
self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
self.shutdown_event = asyncio.Event()
async def verify(self):
logger.info(f"验证请求有效性: {quart.request.args}")
args = quart.request.args
"""内部服务器的 GET 验证入口"""
return await self.handle_verify(quart.request)
async def handle_verify(self, request) -> str:
"""处理验证请求,可被统一 webhook 入口复用
Args:
request: Quart 请求对象
Returns:
验证响应
"""
logger.info(f"验证请求有效性: {request.args}")
args = request.args
try:
echo_str = self.crypto.check_signature(
args.get("msg_signature"),
@@ -78,17 +93,29 @@ class WecomServer:
raise
async def callback_command(self):
data = await quart.request.get_data()
msg_signature = quart.request.args.get("msg_signature")
timestamp = quart.request.args.get("timestamp")
nonce = quart.request.args.get("nonce")
"""内部服务器的 POST 回调入口"""
return await self.handle_callback(quart.request)
async def handle_callback(self, request) -> str:
"""处理回调请求,可被统一 webhook 入口复用
Args:
request: Quart 请求对象
Returns:
响应内容
"""
data = await request.get_data()
msg_signature = request.args.get("msg_signature")
timestamp = request.args.get("timestamp")
nonce = request.args.get("nonce")
try:
xml = self.crypto.decrypt_message(data, msg_signature, timestamp, nonce)
except InvalidSignatureException:
logger.error("解密失败,签名异常,请检查配置。")
raise
else:
msg = parse_message(xml)
msg = cast(BaseMessage, parse_message(xml))
logger.info(f"解析成功: {msg}")
if self.callback:
@@ -118,14 +145,14 @@ class WecomPlatformAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
super().__init__(event_queue)
self.config = platform_config
super().__init__(platform_config, event_queue)
self.settingss = platform_settings
self.client_self_id = uuid.uuid4().hex[:8]
self.api_base_url = platform_config.get(
"api_base_url",
"https://qyapi.weixin.qq.com/cgi-bin/",
)
self.unified_webhook_mode = platform_config.get("unified_webhook_mode", False)
if not self.api_base_url:
self.api_base_url = "https://qyapi.weixin.qq.com/cgi-bin/"
@@ -150,10 +177,10 @@ class WecomPlatformAdapter(Platform):
# inject
self.wechat_kf_api = WeChatKF(client=self.client)
self.wechat_kf_message_api = WeChatKFMessage(self.client)
self.client.kf = self.wechat_kf_api
self.client.kf_message = self.wechat_kf_message_api
self.client.__setattr__("kf", self.wechat_kf_api)
self.client.__setattr__("kf_message", self.wechat_kf_message_api)
self.client.API_BASE_URL = self.api_base_url
self.client.__setattr__("API_BASE_URL", self.api_base_url)
async def callback(msg: BaseMessage):
if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event":
@@ -232,41 +259,53 @@ class WecomPlatformAdapter(Platform):
)
except Exception as e:
logger.error(e)
await self.server.start_polling()
# 如果启用统一 webhook 模式,则不启动独立服务器
webhook_uuid = self.config.get("webhook_uuid")
if self.unified_webhook_mode and webhook_uuid:
log_webhook_info(f"{self.meta().id}(企业微信)", webhook_uuid)
# 保持运行状态,等待 shutdown
await self.server.shutdown_event.wait()
else:
await self.server.start_polling()
async def webhook_callback(self, request: Any) -> Any:
"""统一 Webhook 回调入口"""
# 根据请求方法分发到不同的处理函数
if request.method == "GET":
return await self.server.handle_verify(request)
else:
return await self.server.handle_callback(request)
async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None:
abm = AstrBotMessage()
if msg.type == "text":
assert isinstance(msg, TextMessage)
if isinstance(msg, TextMessage):
abm.message_str = msg.content
abm.self_id = str(msg.agent)
abm.message = [Plain(msg.content)]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
msg.source,
msg.source,
cast(str, msg.source),
cast(str, msg.source),
)
abm.message_id = msg.id
abm.timestamp = msg.time
abm.message_id = str(msg.id)
abm.timestamp = int(cast(int | str, msg.time))
abm.session_id = abm.sender.user_id
abm.raw_message = msg
elif msg.type == "image":
assert isinstance(msg, ImageMessage)
elif isinstance(msg, ImageMessage):
abm.message_str = "[图片]"
abm.self_id = str(msg.agent)
abm.message = [Image(file=msg.image, url=msg.image)]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
msg.source,
msg.source,
cast(str, msg.source),
cast(str, msg.source),
)
abm.message_id = msg.id
abm.timestamp = msg.time
abm.message_id = str(msg.id)
abm.timestamp = int(cast(int | str, msg.time))
abm.session_id = abm.sender.user_id
abm.raw_message = msg
elif msg.type == "voice":
assert isinstance(msg, VoiceMessage)
elif isinstance(msg, VoiceMessage):
resp: Response = await asyncio.get_event_loop().run_in_executor(
None,
self.client.media.download,
@@ -293,11 +332,11 @@ class WecomPlatformAdapter(Platform):
abm.message = [Record(file=path_wav, url=path_wav)]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
msg.source,
msg.source,
cast(str, msg.source),
cast(str, msg.source),
)
abm.message_id = msg.id
abm.timestamp = msg.time
abm.message_id = str(msg.id)
abm.timestamp = int(cast(int | str, msg.time))
abm.session_id = abm.sender.user_id
abm.raw_message = msg
else:
@@ -309,7 +348,7 @@ class WecomPlatformAdapter(Platform):
async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None:
msgtype = msg.get("msgtype")
external_userid = msg.get("external_userid")
external_userid = cast(str, msg.get("external_userid"))
abm = AstrBotMessage()
abm.raw_message = msg
abm.raw_message["_wechat_kf_flag"] = None # 方便处理
@@ -383,4 +422,4 @@ class WecomPlatformAdapter(Platform):
await self.server.server.shutdown()
except Exception as _:
pass
logger.info("企业微信 适配器已被优雅地关闭")
logger.info("企业微信 适配器已被关闭")
@@ -16,7 +16,7 @@ try:
import pydub
except Exception:
logger.warning(
"检测到 pydub 库未安装,企业微信将无法语音收发。如需使用语音,请前往管理面板 -> 控制台 -> 安装 Pip 库安装 pydub。",
"检测到 pydub 库未安装,企业微信将无法语音收发。如需使用语音,请前往管理面板 -> 平台日志 -> 安装 Pip 库安装 pydub。",
)
@@ -93,10 +93,10 @@ class WecomPlatformEvent(AstrMessageEvent):
if is_wechat_kf:
# 微信客服
kf_message_api = getattr(self.client, "kf_message", None)
if not kf_message_api:
if not isinstance(kf_message_api, WeChatKFMessage):
logger.warning("未找到微信客服发送消息方法。")
return
assert isinstance(kf_message_api, WeChatKFMessage)
user_id = self.get_sender_id()
for comp in message.chain:
if isinstance(comp, Plain):
@@ -22,6 +22,7 @@ from astrbot.api.platform import (
PlatformMetadata,
)
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.webhook_utils import log_webhook_info
from ...register import register_platform_adapter
from .wecomai_api import (
@@ -103,9 +104,7 @@ class WecomAIBotAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
super().__init__(event_queue)
self.config = platform_config
super().__init__(platform_config, event_queue)
self.settings = platform_settings
# 初始化配置参数
@@ -122,6 +121,7 @@ class WecomAIBotAdapter(Platform):
"wecomaibot_friend_message_welcome_text",
"",
)
self.unified_webhook_mode = self.config.get("unified_webhook_mode", False)
# 平台元数据
self.metadata = PlatformMetadata(
@@ -425,17 +425,34 @@ class WecomAIBotAdapter(Platform):
def run(self) -> Awaitable[Any]:
"""运行适配器,同时启动HTTP服务器和队列监听器"""
logger.info("启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port)
async def run_both():
# 同时运行HTTP服务器和队列监听
await asyncio.gather(
self.server.start_server(),
self.queue_listener.run(),
)
# 如果启用统一 webhook 模式,则不启动独立服务
webhook_uuid = self.config.get("webhook_uuid")
if self.unified_webhook_mode and webhook_uuid:
log_webhook_info(f"{self.meta().id}(企业微信智能机器人)", webhook_uuid)
# 只运行队列监听器
await self.queue_listener.run()
else:
logger.info(
"启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port
)
# 同时运行HTTP服务器和队列监听器
await asyncio.gather(
self.server.start_server(),
self.queue_listener.run(),
)
return run_both()
async def webhook_callback(self, request: Any) -> Any:
"""统一 Webhook 回调入口"""
# 根据请求方法分发到不同的处理函数
if request.method == "GET":
return await self.server.handle_verify(request)
else:
return await self.server.handle_callback(request)
async def terminate(self):
"""终止适配器"""
logger.info("企业微信智能机器人适配器正在关闭...")
@@ -39,7 +39,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
@staticmethod
async def _send(
message_chain: MessageChain,
message_chain: MessageChain | None,
stream_id: str,
queue_mgr: WecomAIQueueMgr,
streaming: bool = False,
@@ -90,7 +90,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
return data
async def send(self, message: MessageChain):
async def send(self, message: MessageChain | None):
"""发送消息"""
raw = self.message_obj.raw_message
assert isinstance(raw, dict), (
@@ -98,7 +98,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
)
stream_id = raw.get("stream_id", self.session_id)
await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr)
await super().send(message)
await super().send(MessageChain([]))
async def send_streaming(self, generator, use_fallback=False):
"""流式发送消息,参考webchat的send_streaming设计"""
@@ -59,8 +59,19 @@ class WecomAIBotServer:
)
async def verify_url(self):
"""验证回调 URL"""
args = quart.request.args
"""内部服务器的 GET 验证入口"""
return await self.handle_verify(quart.request)
async def handle_verify(self, request):
"""处理 URL 验证请求,可被统一 webhook 入口复用
Args:
request: Quart 请求对象
Returns:
验证响应元组 (content, status_code, headers)
"""
args = request.args
msg_signature = args.get("msg_signature")
timestamp = args.get("timestamp")
nonce = args.get("nonce")
@@ -81,8 +92,19 @@ class WecomAIBotServer:
return result, 200, {"Content-Type": "text/plain"}
async def handle_message(self):
"""处理消息回调"""
args = quart.request.args
"""内部服务器的 POST 消息回调入口"""
return await self.handle_callback(quart.request)
async def handle_callback(self, request):
"""处理消息回调,可被统一 webhook 入口复用
Args:
request: Quart 请求对象
Returns:
响应元组 (content, status_code, headers)
"""
args = request.args
msg_signature = args.get("msg_signature")
timestamp = args.get("timestamp")
nonce = args.get("nonce")
@@ -102,7 +124,7 @@ class WecomAIBotServer:
try:
# 获取请求体
post_data = await quart.request.get_data()
post_data = await request.get_data()
# 确保 post_data 是 bytes 类型
if isinstance(post_data, str):
@@ -1,6 +1,8 @@
import asyncio
import sys
import uuid
from collections.abc import Awaitable, Callable
from typing import Any, cast
import quart
from requests import Response
@@ -22,6 +24,7 @@ from astrbot.api.platform import (
)
from astrbot.core import logger
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.webhook_utils import log_webhook_info
from .weixin_offacc_event import WeixinOfficialAccountPlatformEvent
@@ -31,10 +34,10 @@ else:
from typing_extensions import override
class WecomServer:
class WeixinOfficialAccountServer:
def __init__(self, event_queue: asyncio.Queue, config: dict):
self.server = quart.Quart(__name__)
self.port = int(config.get("port"))
self.port = int(cast(int | str, config.get("port")))
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
self.token = config.get("token")
self.encoding_aes_key = config.get("encoding_aes_key")
@@ -53,13 +56,25 @@ class WecomServer:
self.event_queue = event_queue
self.callback = None
self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
self.shutdown_event = asyncio.Event()
async def verify(self):
logger.info(f"验证请求有效性: {quart.request.args}")
"""内部服务器的 GET 验证入口"""
return await self.handle_verify(quart.request)
args = quart.request.args
async def handle_verify(self, request) -> str:
"""处理验证请求,可被统一 webhook 入口复用
Args:
request: Quart 请求对象
Returns:
验证响应
"""
logger.info(f"验证请求有效性: {request.args}")
args = request.args
if not args.get("signature", None):
logger.error("未知的响应,请检查回调地址是否填写正确。")
return "err"
@@ -77,10 +92,22 @@ class WecomServer:
return "err"
async def callback_command(self):
data = await quart.request.get_data()
msg_signature = quart.request.args.get("msg_signature")
timestamp = quart.request.args.get("timestamp")
nonce = quart.request.args.get("nonce")
"""内部服务器的 POST 回调入口"""
return await self.handle_callback(quart.request)
async def handle_callback(self, request) -> str:
"""处理回调请求,可被统一 webhook 入口复用
Args:
request: Quart 请求对象
Returns:
响应内容
"""
data = await request.get_data()
msg_signature = request.args.get("msg_signature")
timestamp = request.args.get("timestamp")
nonce = request.args.get("nonce")
try:
xml = self.crypto.decrypt_message(data, msg_signature, timestamp, nonce)
except InvalidSignatureException:
@@ -88,6 +115,9 @@ class WecomServer:
raise
else:
msg = parse_message(xml)
if not msg:
logger.error("解析失败。msg为None。")
raise
logger.info(f"解析成功: {msg}")
if self.callback:
@@ -123,8 +153,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
super().__init__(event_queue)
self.config = platform_config
super().__init__(platform_config, event_queue)
self.settingss = platform_settings
self.client_self_id = uuid.uuid4().hex[:8]
self.api_base_url = platform_config.get(
@@ -132,6 +161,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
"https://api.weixin.qq.com/cgi-bin/",
)
self.active_send_mode = self.config.get("active_send_mode", False)
self.unified_webhook_mode = platform_config.get("unified_webhook_mode", False)
if not self.api_base_url:
self.api_base_url = "https://api.weixin.qq.com/cgi-bin/"
@@ -143,14 +173,14 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
if not self.api_base_url.endswith("/"):
self.api_base_url += "/"
self.server = WecomServer(self._event_queue, self.config)
self.server = WeixinOfficialAccountServer(self._event_queue, self.config)
self.client = WeChatClient(
self.config["appid"].strip(),
self.config["secret"].strip(),
)
self.client.API_BASE_URL = self.api_base_url
self.client.__setattr__("API_BASE_URL", self.api_base_url)
# 微信公众号必须 5 秒内进行回复,否则会重试 3 次,我们需要对其进行消息排重
# msgid -> Future
@@ -162,11 +192,11 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
await self.convert_message(msg, None)
else:
if msg.id in self.wexin_event_workers:
future = self.wexin_event_workers[msg.id]
future = self.wexin_event_workers[str(cast(str | int, msg.id))]
logger.debug(f"duplicate message id checked: {msg.id}")
else:
future = asyncio.get_event_loop().create_future()
self.wexin_event_workers[msg.id] = future
self.wexin_event_workers[str(cast(str | int, msg.id))] = future
await self.convert_message(msg, future)
# I love shield so much!
result = await asyncio.wait_for(
@@ -174,7 +204,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
60,
) # wait for 60s
logger.debug(f"Got future result: {result}")
self.wexin_event_workers.pop(msg.id, None)
self.wexin_event_workers.pop(str(cast(str | int, msg.id)), None)
return result # xml. see weixin_offacc_event.py
except asyncio.TimeoutError:
pass
@@ -202,38 +232,53 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
@override
async def run(self):
await self.server.start_polling()
# 如果启用统一 webhook 模式,则不启动独立服务器
webhook_uuid = self.config.get("webhook_uuid")
if self.unified_webhook_mode and webhook_uuid:
log_webhook_info(f"{self.meta().id}(微信公众平台)", webhook_uuid)
# 保持运行状态,等待 shutdown
await self.server.shutdown_event.wait()
else:
await self.server.start_polling()
async def webhook_callback(self, request: Any) -> Any:
"""统一 Webhook 回调入口"""
# 根据请求方法分发到不同的处理函数
if request.method == "GET":
return await self.server.handle_verify(request)
else:
return await self.server.handle_callback(request)
async def convert_message(
self,
msg,
future: asyncio.Future = None,
future: asyncio.Future | None = None,
) -> AstrBotMessage | None:
abm = AstrBotMessage()
if isinstance(msg, TextMessage):
abm.message_str = msg.content
abm.message_str = cast(str, msg.content)
abm.self_id = str(msg.target)
abm.message = [Plain(msg.content)]
abm.message = [Plain(cast(str, msg.content))]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
msg.source,
msg.source,
cast(str, msg.source),
cast(str, msg.source),
)
abm.message_id = msg.id
abm.timestamp = msg.time
abm.message_id = str(cast(str | int, msg.id))
abm.timestamp = cast(int, msg.time)
abm.session_id = abm.sender.user_id
elif msg.type == "image":
assert isinstance(msg, ImageMessage)
abm.message_str = "[图片]"
abm.self_id = str(msg.target)
abm.message = [Image(file=msg.image, url=msg.image)]
abm.message = [Image(file=cast(str, msg.image), url=cast(str, msg.image))]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
msg.source,
msg.source,
cast(str, msg.source),
cast(str, msg.source),
)
abm.message_id = msg.id
abm.timestamp = msg.time
abm.message_id = str(cast(str | int, msg.id))
abm.timestamp = cast(int, msg.time)
abm.session_id = abm.sender.user_id
elif msg.type == "voice":
assert isinstance(msg, VoiceMessage)
@@ -265,15 +310,16 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
abm.message = [Record(file=path_wav, url=path_wav)]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
msg.source,
msg.source,
cast(str, msg.source),
cast(str, msg.source),
)
abm.message_id = msg.id
abm.timestamp = msg.time
abm.message_id = str(cast(str | int, msg.id))
abm.timestamp = cast(int, msg.time)
abm.session_id = abm.sender.user_id
else:
logger.warning(f"暂未实现的事件: {msg.type}")
future.set_result(None)
if future:
future.set_result(None)
return
# 很不优雅 :(
abm.raw_message = {
@@ -303,4 +349,4 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
await self.server.server.shutdown()
except Exception as _:
pass
logger.info("微信公众平台 适配器已被优雅地关闭")
logger.info("微信公众平台 适配器已被关闭")
@@ -1,5 +1,6 @@
import asyncio
import uuid
from typing import cast
from wechatpy import WeChatClient
from wechatpy.replies import ImageReply, TextReply, VoiceReply
@@ -13,7 +14,7 @@ try:
import pydub
except Exception:
logger.warning(
"检测到 pydub 库未安装,微信公众平台将无法语音收发。如需使用语音,请前往管理面板 -> 控制台 -> 安装 Pip 库安装 pydub。",
"检测到 pydub 库未安装,微信公众平台将无法语音收发。如需使用语音,请前往管理面板 -> 平台日志 -> 安装 Pip 库安装 pydub。",
)
@@ -85,7 +86,9 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
async def send(self, message: MessageChain):
message_obj = self.message_obj
active_send_mode = message_obj.raw_message.get("active_send_mode", False)
active_send_mode = cast(dict, message_obj.raw_message).get(
"active_send_mode", False
)
for comp in message.chain:
if isinstance(comp, Plain):
# Split long text messages if needed
@@ -96,10 +99,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
else:
reply = TextReply(
content=chunk,
message=self.message_obj.raw_message["message"],
message=cast(dict, self.message_obj.raw_message)["message"],
)
xml = reply.render()
future = self.message_obj.raw_message["future"]
future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future)
future.set_result(xml)
await asyncio.sleep(0.5) # Avoid sending too fast
@@ -125,10 +128,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
else:
reply = ImageReply(
media_id=response["media_id"],
message=self.message_obj.raw_message["message"],
message=cast(dict, self.message_obj.raw_message)["message"],
)
xml = reply.render()
future = self.message_obj.raw_message["future"]
future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future)
future.set_result(xml)
@@ -160,10 +163,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
else:
reply = VoiceReply(
media_id=response["media_id"],
message=self.message_obj.raw_message["message"],
message=cast(dict, self.message_obj.raw_message)["message"],
)
xml = reply.render()
future = self.message_obj.raw_message["future"]
future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future)
future.set_result(xml)
+3 -3
View File
@@ -10,12 +10,12 @@ class PlatformMessageHistoryManager:
self,
platform_id: str,
user_id: str,
content: list[dict], # TODO: parse from message chain
content: dict, # TODO: parse from message chain
sender_id: str | None = None,
sender_name: str | None = None,
):
) -> PlatformMessageHistory:
"""Insert a new platform message history record."""
await self.db.insert_platform_message_history(
return await self.db.insert_platform_message_history(
platform_id=platform_id,
user_id=user_id,
content=content,
+3 -3
View File
@@ -4,7 +4,7 @@ import asyncio
import copy
import json
import os
from collections.abc import Awaitable, Callable
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any
import aiohttp
@@ -118,7 +118,7 @@ class FunctionToolManager:
name: str,
func_args: list[dict],
desc: str,
handler: Callable[..., Awaitable[Any]],
handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
) -> FuncTool:
params = {
"type": "object", # hard-coded here
@@ -140,7 +140,7 @@ class FunctionToolManager:
name: str,
func_args: list,
desc: str,
handler: Callable[..., Awaitable[Any]],
handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
) -> None:
"""添加函数调用工具
+167 -109
View File
@@ -1,7 +1,8 @@
import asyncio
import traceback
from typing import Protocol, runtime_checkable
from astrbot.core import logger, sp
from astrbot.core import astrbot_config, logger, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.db import BaseDatabase
@@ -10,6 +11,7 @@ from .entities import ProviderType
from .provider import (
EmbeddingProvider,
Provider,
Providers,
RerankProvider,
STTProvider,
TTSProvider,
@@ -17,6 +19,11 @@ from .provider import (
from .register import llm_tools, provider_cls_map
@runtime_checkable
class HasInitialize(Protocol):
async def initialize(self) -> None: ...
class ProviderManager:
def __init__(
self,
@@ -24,6 +31,7 @@ class ProviderManager:
db_helper: BaseDatabase,
persona_mgr: PersonaManager,
):
self.reload_lock = asyncio.Lock()
self.persona_mgr = persona_mgr
self.acm = acm
config = acm.confs["default"]
@@ -47,7 +55,7 @@ class ProviderManager:
"""加载的 Rerank Provider 的实例"""
self.inst_map: dict[
str,
Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider,
Providers,
] = {}
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
self.llm_tools = llm_tools
@@ -122,15 +130,13 @@ class ProviderManager:
self.curr_provider_inst = prov
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
async def get_provider_by_id(self, provider_id: str) -> Providers | None:
"""根据提供商 ID 获取提供商实例"""
return self.inst_map.get(provider_id)
def get_using_provider(
self,
provider_type: ProviderType,
umo=None,
) -> Provider | STTProvider | TTSProvider | None:
self, provider_type: ProviderType, umo=None
) -> Providers | None:
"""获取正在使用的提供商实例。
Args:
@@ -190,7 +196,6 @@ class ProviderManager:
logger.error(traceback.format_exc())
logger.error(e)
# 设置默认提供商
selected_provider_id = sp.get(
"curr_provider",
self.provider_settings.get("default_provider_id"),
@@ -209,15 +214,37 @@ class ProviderManager:
scope="global",
scope_id="global",
)
self.curr_provider_inst = self.inst_map.get(selected_provider_id)
temp_provider = (
self.inst_map.get(selected_provider_id)
if isinstance(selected_provider_id, str)
else None
)
self.curr_provider_inst = (
temp_provider if isinstance(temp_provider, Provider) else None
)
if not self.curr_provider_inst and self.provider_insts:
self.curr_provider_inst = self.provider_insts[0]
self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id)
temp_stt = (
self.inst_map.get(selected_stt_provider_id)
if isinstance(selected_stt_provider_id, str)
else None
)
self.curr_stt_provider_inst = (
temp_stt if isinstance(temp_stt, STTProvider) else None
)
if not self.curr_stt_provider_inst and self.stt_provider_insts:
self.curr_stt_provider_inst = self.stt_provider_insts[0]
self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id)
temp_tts = (
self.inst_map.get(selected_tts_provider_id)
if isinstance(selected_tts_provider_id, str)
else None
)
self.curr_tts_provider_inst = (
temp_tts if isinstance(temp_tts, TTSProvider) else None
)
if not self.curr_tts_provider_inst and self.tts_provider_insts:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
@@ -226,6 +253,9 @@ class ProviderManager:
async def load_provider(self, provider_config: dict):
if not provider_config["enable"]:
logger.info(f"Provider {provider_config['id']} is disabled, skipping")
return
if provider_config.get("provider_type", "") == "agent_runner":
return
logger.info(
@@ -247,14 +277,6 @@ class ProviderManager:
from .sources.anthropic_source import (
ProviderAnthropic as ProviderAnthropic,
)
case "dify":
from .sources.dify_source import ProviderDify as ProviderDify
case "coze":
from .sources.coze_source import ProviderCoze as ProviderCoze
case "dashscope":
from .sources.dashscope_source import (
ProviderDashscope as ProviderDashscope,
)
case "googlegenai_chat_completion":
from .sources.gemini_source import (
ProviderGoogleGenAI as ProviderGoogleGenAI,
@@ -362,73 +384,103 @@ class ProviderManager:
provider_metadata.id = provider_config["id"]
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
# STT 任务
inst = cls_type(provider_config, self.provider_settings)
match provider_metadata.provider_type:
case ProviderType.SPEECH_TO_TEXT:
# STT 任务
if not issubclass(cls_type, STTProvider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of STTProvider"
)
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
if isinstance(inst, HasInitialize):
await inst.initialize()
self.stt_provider_insts.append(inst)
if (
self.provider_stt_settings.get("provider_id")
== provider_config["id"]
):
self.curr_stt_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
self.stt_provider_insts.append(inst)
if (
self.provider_stt_settings.get("provider_id")
== provider_config["id"]
):
self.curr_stt_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
)
if not self.curr_stt_provider_inst:
self.curr_stt_provider_inst = inst
case ProviderType.TEXT_TO_SPEECH:
# TTS 任务
if not issubclass(cls_type, TTSProvider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of TTSProvider"
)
inst = cls_type(provider_config, self.provider_settings)
if isinstance(inst, HasInitialize):
await inst.initialize()
self.tts_provider_insts.append(inst)
if (
self.provider_settings.get("provider_id")
== provider_config["id"]
):
self.curr_tts_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
)
if not self.curr_tts_provider_inst:
self.curr_tts_provider_inst = inst
case ProviderType.CHAT_COMPLETION:
# 文本生成任务
if not issubclass(cls_type, Provider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of Provider"
)
inst = cls_type(
provider_config,
self.provider_settings,
)
if not self.curr_stt_provider_inst:
self.curr_stt_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
# TTS 任务
inst = cls_type(provider_config, self.provider_settings)
if isinstance(inst, HasInitialize):
await inst.initialize()
if getattr(inst, "initialize", None):
await inst.initialize()
self.provider_insts.append(inst)
if (
self.provider_settings.get("default_provider_id")
== provider_config["id"]
):
self.curr_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
)
if not self.curr_provider_inst:
self.curr_provider_inst = inst
self.tts_provider_insts.append(inst)
if self.provider_settings.get("provider_id") == provider_config["id"]:
self.curr_tts_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
case ProviderType.EMBEDDING:
if not issubclass(cls_type, EmbeddingProvider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of EmbeddingProvider"
)
inst = cls_type(provider_config, self.provider_settings)
if isinstance(inst, HasInitialize):
await inst.initialize()
self.embedding_provider_insts.append(inst)
case ProviderType.RERANK:
if not issubclass(cls_type, RerankProvider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of RerankProvider"
)
inst = cls_type(provider_config, self.provider_settings)
if isinstance(inst, HasInitialize):
await inst.initialize()
self.rerank_provider_insts.append(inst)
case _:
# 未知供应商抛出异常,确保inst初始化
# Should be unreachable
raise Exception(
f"未知的提供商类型:{provider_metadata.provider_type}"
)
if not self.curr_tts_provider_inst:
self.curr_tts_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
# 文本生成任务
inst = cls_type(
provider_config,
self.provider_settings,
)
if getattr(inst, "initialize", None):
await inst.initialize()
self.provider_insts.append(inst)
if (
self.provider_settings.get("default_provider_id")
== provider_config["id"]
):
self.curr_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
)
if not self.curr_provider_inst:
self.curr_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.embedding_provider_insts.append(inst)
elif provider_metadata.provider_type == ProviderType.RERANK:
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.rerank_provider_insts.append(inst)
self.inst_map[provider_config["id"]] = inst
except Exception as e:
@@ -440,40 +492,46 @@ class ProviderManager:
)
async def reload(self, provider_config: dict):
await self.terminate_provider(provider_config["id"])
if provider_config["enable"]:
await self.load_provider(provider_config)
async with self.reload_lock:
await self.terminate_provider(provider_config["id"])
if provider_config["enable"]:
await self.load_provider(provider_config)
# 和配置文件保持同步
config_ids = [provider["id"] for provider in self.providers_config]
logger.debug(f"providers in user's config: {config_ids}")
for key in list(self.inst_map.keys()):
if key not in config_ids:
await self.terminate_provider(key)
# 和配置文件保持同步
self.providers_config = astrbot_config["provider"]
config_ids = [provider["id"] for provider in self.providers_config]
logger.info(f"providers in user's config: {config_ids}")
for key in list(self.inst_map.keys()):
if key not in config_ids:
await self.terminate_provider(key)
if len(self.provider_insts) == 0:
self.curr_provider_inst = None
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
self.curr_provider_inst = self.provider_insts[0]
logger.info(
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。",
)
if len(self.provider_insts) == 0:
self.curr_provider_inst = None
elif self.curr_provider_inst is None and len(self.provider_insts) > 0:
self.curr_provider_inst = self.provider_insts[0]
logger.info(
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。",
)
if len(self.stt_provider_insts) == 0:
self.curr_stt_provider_inst = None
elif self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0:
self.curr_stt_provider_inst = self.stt_provider_insts[0]
logger.info(
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。",
)
if len(self.stt_provider_insts) == 0:
self.curr_stt_provider_inst = None
elif (
self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0
):
self.curr_stt_provider_inst = self.stt_provider_insts[0]
logger.info(
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。",
)
if len(self.tts_provider_insts) == 0:
self.curr_tts_provider_inst = None
elif self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
logger.info(
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。",
)
if len(self.tts_provider_insts) == 0:
self.curr_tts_provider_inst = None
elif (
self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0
):
self.curr_tts_provider_inst = self.tts_provider_insts[0]
logger.info(
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。",
)
def get_insts(self):
return self.provider_insts
+47 -1
View File
@@ -1,6 +1,8 @@
import abc
import asyncio
import os
from collections.abc import AsyncGenerator
from typing import TypeAlias, Union
from astrbot.core.agent.message import Message
from astrbot.core.agent.tool import ToolSet
@@ -11,6 +13,15 @@ from astrbot.core.provider.entities import (
ToolCallsResult,
)
from astrbot.core.provider.register import provider_cls_map
from astrbot.core.utils.astrbot_path import get_astrbot_path
Providers: TypeAlias = Union[
"Provider",
"STTProvider",
"TTSProvider",
"EmbeddingProvider",
"RerankProvider",
]
class AbstractProvider(abc.ABC):
@@ -43,6 +54,14 @@ class AbstractProvider(abc.ABC):
)
return meta
async def test(self):
"""test the provider is a
raises:
Exception: if the provider is not available
"""
...
class Provider(AbstractProvider):
"""Chat Provider"""
@@ -132,7 +151,9 @@ class Provider(AbstractProvider):
- 如果传入了 tools将会使用 tools 进行 Function-calling如果模型不支持 Function-calling将会抛出错误
"""
...
if False: # pragma: no cover - make this an async generator for typing
yield None # type: ignore
raise NotImplementedError()
async def pop_record(self, context: list):
"""弹出 context 第一条非系统提示词对话记录"""
@@ -165,6 +186,12 @@ class Provider(AbstractProvider):
return dicts
async def test(self, timeout: float = 45.0):
await asyncio.wait_for(
self.text_chat(prompt="REPLY `PONG` ONLY"),
timeout=timeout,
)
class STTProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -177,6 +204,14 @@ class STTProvider(AbstractProvider):
"""获取音频的文本"""
raise NotImplementedError
async def test(self):
sample_audio_path = os.path.join(
get_astrbot_path(),
"samples",
"stt_health_check.wav",
)
await self.get_text(sample_audio_path)
class TTSProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -189,6 +224,9 @@ class TTSProvider(AbstractProvider):
"""获取文本的音频,返回音频文件路径"""
raise NotImplementedError
async def test(self):
await self.get_audio("hi")
class EmbeddingProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
@@ -211,6 +249,9 @@ class EmbeddingProvider(AbstractProvider):
"""获取向量的维度"""
...
async def test(self):
await self.get_embedding("astrbot")
async def get_embeddings_batch(
self,
texts: list[str],
@@ -294,3 +335,8 @@ class RerankProvider(AbstractProvider):
) -> list[RerankResult]:
"""获取查询和文档的重排序分数"""
...
async def test(self):
result = await self.rerank("Apple", documents=["apple", "banana"])
if not result:
raise Exception("Rerank provider test failed, no results returned")
@@ -290,7 +290,7 @@ class ProviderAnthropic(Provider):
try:
llm_response = await self._query(payloads, func_tool)
except Exception as e:
logger.error(f"发生了错误。Provider 配置如下: {model_config}")
# logger.error(f"发生了错误。Provider 配置如下: {model_config}")
raise e
return llm_response
@@ -29,15 +29,24 @@ class OTTSProvider:
self.last_sync_time = 0
self.timeout = Timeout(10.0)
self.retry_count = 3
self.client = None
self._client: AsyncClient | None = None
@property
def client(self) -> AsyncClient:
if self._client is None:
raise RuntimeError(
"Client not initialized. Please use 'async with' context."
)
return self._client
async def __aenter__(self):
self.client = AsyncClient(timeout=self.timeout)
self._client = AsyncClient(timeout=self.timeout)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.client:
await self.client.aclose()
if self._client:
await self._client.aclose()
self._client = None
async def _sync_time(self):
try:
@@ -90,6 +99,7 @@ class OTTSProvider:
if attempt == self.retry_count - 1:
raise RuntimeError(f"OTTS请求失败: {e!s}") from e
await asyncio.sleep(0.5 * (attempt + 1))
raise RuntimeError("OTTS未返回音频文件")
class AzureNativeProvider(TTSProvider):
@@ -105,7 +115,7 @@ class AzureNativeProvider(TTSProvider):
self.endpoint = (
f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
)
self.client = None
self._client: AsyncClient | None = None
self.token = None
self.token_expire = 0
self.voice_params = {
@@ -116,8 +126,16 @@ class AzureNativeProvider(TTSProvider):
"volume": provider_config.get("azure_tts_volume", "100"),
}
@property
def client(self) -> AsyncClient:
if self._client is None:
raise RuntimeError(
"Client not initialized. Please use 'async with' context."
)
return self._client
async def __aenter__(self):
self.client = AsyncClient(
self._client = AsyncClient(
headers={
"User-Agent": f"AstrBot/{VERSION}",
"Content-Type": "application/ssml+xml",
@@ -127,8 +145,9 @@ class AzureNativeProvider(TTSProvider):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.client:
await self.client.aclose()
if self._client:
await self._client.aclose()
self._client = None
async def _refresh_token(self):
token_url = (
@@ -181,8 +200,11 @@ class AzureTTSProvider(TTSProvider):
key_value = provider_config.get("azure_tts_subscription_key", "")
self.provider = self._parse_provider(key_value, provider_config)
def _parse_provider(self, key_value: str, config: dict) -> TTSProvider:
def _parse_provider(
self, key_value: str, config: dict
) -> OTTSProvider | AzureNativeProvider:
if key_value.lower().startswith("other["):
json_str = ""
try:
match = re.match(r"other\[(.*)\]", key_value, re.DOTALL)
if not match:
@@ -177,6 +177,10 @@ class BailianRerankProvider(RerankProvider):
Returns:
重排序结果列表
"""
if not self.client:
logger.error("百炼 Rerank 客户端会话已关闭,返回空结果")
return []
if not documents:
logger.warning("文档列表为空,返回空结果")
return []
@@ -1,650 +0,0 @@
import base64
import hashlib
import json
import os
from collections.abc import AsyncGenerator
import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.api.provider import Provider
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse
from ..register import register_provider_adapter
from .coze_api_client import CozeAPIClient
@register_provider_adapter("coze", "Coze (扣子) 智能体适配器")
class ProviderCoze(Provider):
def __init__(
self,
provider_config,
provider_settings,
) -> None:
super().__init__(
provider_config,
provider_settings,
)
self.api_key = provider_config.get("coze_api_key", "")
if not self.api_key:
raise Exception("Coze API Key 不能为空。")
self.bot_id = provider_config.get("bot_id", "")
if not self.bot_id:
raise Exception("Coze Bot ID 不能为空。")
self.api_base: str = provider_config.get("coze_api_base", "https://api.coze.cn")
if not isinstance(self.api_base, str) or not self.api_base.startswith(
("http://", "https://"),
):
raise Exception(
"Coze API Base URL 格式不正确,必须以 http:// 或 https:// 开头。",
)
self.timeout = provider_config.get("timeout", 120)
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
self.auto_save_history = provider_config.get("auto_save_history", True)
self.conversation_ids: dict[str, str] = {}
self.file_id_cache: dict[str, dict[str, str]] = {}
# 创建 API 客户端
self.api_client = CozeAPIClient(api_key=self.api_key, api_base=self.api_base)
def _generate_cache_key(self, data: str, is_base64: bool = False) -> str:
"""生成统一的缓存键
Args:
data: 图片数据或路径
is_base64: 是否是 base64 数据
Returns:
str: 缓存键
"""
try:
if is_base64 and data.startswith("data:image/"):
try:
header, encoded = data.split(",", 1)
image_bytes = base64.b64decode(encoded)
cache_key = hashlib.md5(image_bytes).hexdigest()
return cache_key
except Exception:
cache_key = hashlib.md5(encoded.encode("utf-8")).hexdigest()
return cache_key
elif data.startswith(("http://", "https://")):
# URL图片,使用URL作为缓存键
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
return cache_key
else:
clean_path = (
data.split("_")[0]
if "_" in data and len(data.split("_")) >= 3
else data
)
if os.path.exists(clean_path):
with open(clean_path, "rb") as f:
file_content = f.read()
cache_key = hashlib.md5(file_content).hexdigest()
return cache_key
cache_key = hashlib.md5(clean_path.encode("utf-8")).hexdigest()
return cache_key
except Exception as e:
cache_key = hashlib.md5(data.encode("utf-8")).hexdigest()
logger.debug(f"[Coze] 异常文件缓存键: {cache_key}, error={e}")
return cache_key
async def _upload_file(
self,
file_data: bytes,
session_id: str | None = None,
cache_key: str | None = None,
) -> str:
"""上传文件到 Coze 并返回 file_id"""
# 使用 API 客户端上传文件
file_id = await self.api_client.upload_file(file_data)
# 缓存 file_id
if session_id and cache_key:
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
self.file_id_cache[session_id][cache_key] = file_id
logger.debug(f"[Coze] 图片上传成功并缓存,file_id: {file_id}")
return file_id
async def _download_and_upload_image(
self,
image_url: str,
session_id: str | None = None,
) -> str:
"""下载图片并上传到 Coze,返回 file_id"""
# 计算哈希实现缓存
cache_key = self._generate_cache_key(image_url) if session_id else None
if session_id and cache_key:
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
if cache_key in self.file_id_cache[session_id]:
file_id = self.file_id_cache[session_id][cache_key]
return file_id
try:
image_data = await self.api_client.download_image(image_url)
file_id = await self._upload_file(image_data, session_id, cache_key)
if session_id and cache_key:
self.file_id_cache[session_id][cache_key] = file_id
return file_id
except Exception as e:
logger.error(f"处理图片失败 {image_url}: {e!s}")
raise Exception(f"处理图片失败: {e!s}")
async def _process_context_images(
self,
content: str | list,
session_id: str,
) -> str:
"""处理上下文中的图片内容,将 base64 图片上传并替换为 file_id"""
try:
if isinstance(content, str):
return content
processed_content = []
if session_id not in self.file_id_cache:
self.file_id_cache[session_id] = {}
for item in content:
if not isinstance(item, dict):
processed_content.append(item)
continue
if item.get("type") == "text":
processed_content.append(item)
elif item.get("type") == "image_url":
# 处理图片逻辑
if "file_id" in item:
# 已经有 file_id
logger.debug(f"[Coze] 图片已有file_id: {item['file_id']}")
processed_content.append(item)
else:
# 获取图片数据
image_data = ""
if "image_url" in item and isinstance(item["image_url"], dict):
image_data = item["image_url"].get("url", "")
elif "data" in item:
image_data = item.get("data", "")
elif "url" in item:
image_data = item.get("url", "")
if not image_data:
continue
# 计算哈希用于缓存
cache_key = self._generate_cache_key(
image_data,
is_base64=image_data.startswith("data:image/"),
)
# 检查缓存
if cache_key in self.file_id_cache[session_id]:
file_id = self.file_id_cache[session_id][cache_key]
processed_content.append(
{"type": "image", "file_id": file_id},
)
else:
# 上传图片并缓存
if image_data.startswith("data:image/"):
# base64 处理
_, encoded = image_data.split(",", 1)
image_bytes = base64.b64decode(encoded)
file_id = await self._upload_file(
image_bytes,
session_id,
cache_key,
)
elif image_data.startswith(("http://", "https://")):
# URL 图片
file_id = await self._download_and_upload_image(
image_data,
session_id,
)
# 为URL图片也添加缓存
self.file_id_cache[session_id][cache_key] = file_id
elif os.path.exists(image_data):
# 本地文件
with open(image_data, "rb") as f:
image_bytes = f.read()
file_id = await self._upload_file(
image_bytes,
session_id,
cache_key,
)
else:
logger.warning(
f"无法处理的图片格式: {image_data[:50]}...",
)
continue
processed_content.append(
{"type": "image", "file_id": file_id},
)
result = json.dumps(processed_content, ensure_ascii=False)
return result
except Exception as e:
logger.error(f"处理上下文图片失败: {e!s}")
if isinstance(content, str):
return content
return json.dumps(content, ensure_ascii=False)
async def text_chat(
self,
prompt: str,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> LLMResponse:
"""文本对话, 内部使用流式接口实现非流式
Args:
prompt (str): 用户提示词
session_id (str): 会话ID
image_urls (List[str]): 图片URL列表
func_tool (FuncCall): 函数调用工具(不支持)
contexts (List): 上下文列表
system_prompt (str): 系统提示语
tool_calls_result (ToolCallsResult | List[ToolCallsResult]): 工具调用结果(不支持)
model (str): 模型名称(不支持)
Returns:
LLMResponse: LLM响应对象
"""
accumulated_content = ""
final_response = None
async for llm_response in self.text_chat_stream(
prompt=prompt,
session_id=session_id,
image_urls=image_urls,
func_tool=func_tool,
contexts=contexts,
system_prompt=system_prompt,
tool_calls_result=tool_calls_result,
model=model,
**kwargs,
):
if llm_response.is_chunk:
if llm_response.completion_text:
accumulated_content += llm_response.completion_text
else:
final_response = llm_response
if final_response:
return final_response
if accumulated_content:
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
return LLMResponse(role="assistant", result_chain=chain)
return LLMResponse(role="assistant", completion_text="")
async def text_chat_stream(
self,
prompt: str,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""流式对话接口"""
# 用户ID参数(参考文档, 可以自定义)
user_id = session_id or kwargs.get("user", "default_user")
# 获取或创建会话ID
conversation_id = self.conversation_ids.get(user_id)
# 构建消息
additional_messages = []
if system_prompt:
if not self.auto_save_history or not conversation_id:
additional_messages.append(
{
"role": "system",
"content": system_prompt,
"content_type": "text",
},
)
contexts = self._ensure_message_to_dicts(contexts)
if not self.auto_save_history and contexts:
# 如果关闭了自动保存历史,传入上下文
for ctx in contexts:
if isinstance(ctx, dict) and "role" in ctx and "content" in ctx:
content = ctx["content"]
content_type = ctx.get("content_type", "text")
# 处理可能包含图片的上下文
if (
content_type == "object_string"
or (isinstance(content, str) and content.startswith("["))
or (
isinstance(content, list)
and any(
isinstance(item, dict)
and item.get("type") == "image_url"
for item in content
)
)
):
processed_content = await self._process_context_images(
content,
user_id,
)
additional_messages.append(
{
"role": ctx["role"],
"content": processed_content,
"content_type": "object_string",
},
)
else:
# 纯文本
additional_messages.append(
{
"role": ctx["role"],
"content": (
content
if isinstance(content, str)
else json.dumps(content, ensure_ascii=False)
),
"content_type": "text",
},
)
else:
logger.info(f"[Coze] 跳过格式不正确的上下文: {ctx}")
if prompt or image_urls:
if image_urls:
# 多模态
object_string_content = []
if prompt:
object_string_content.append({"type": "text", "text": prompt})
for url in image_urls:
try:
if url.startswith(("http://", "https://")):
# 网络图片
file_id = await self._download_and_upload_image(
url,
user_id,
)
else:
# 本地文件或 base64
if url.startswith("data:image/"):
# base64
_, encoded = url.split(",", 1)
image_data = base64.b64decode(encoded)
cache_key = self._generate_cache_key(
url,
is_base64=True,
)
file_id = await self._upload_file(
image_data,
user_id,
cache_key,
)
# 本地文件
elif os.path.exists(url):
with open(url, "rb") as f:
image_data = f.read()
# 用文件路径和修改时间来缓存
file_stat = os.stat(url)
cache_key = self._generate_cache_key(
f"{url}_{file_stat.st_mtime}_{file_stat.st_size}",
is_base64=False,
)
file_id = await self._upload_file(
image_data,
user_id,
cache_key,
)
else:
logger.warning(f"图片文件不存在: {url}")
continue
object_string_content.append(
{
"type": "image",
"file_id": file_id,
},
)
except Exception as e:
logger.error(f"处理图片失败 {url}: {e!s}")
continue
if object_string_content:
content = json.dumps(object_string_content, ensure_ascii=False)
additional_messages.append(
{
"role": "user",
"content": content,
"content_type": "object_string",
},
)
# 纯文本
elif prompt:
additional_messages.append(
{
"role": "user",
"content": prompt,
"content_type": "text",
},
)
try:
accumulated_content = ""
message_started = False
async for chunk in self.api_client.chat_messages(
bot_id=self.bot_id,
user_id=user_id,
additional_messages=additional_messages,
conversation_id=conversation_id,
auto_save_history=self.auto_save_history,
stream=True,
timeout=self.timeout,
):
event_type = chunk.get("event")
data = chunk.get("data", {})
if event_type == "conversation.chat.created":
if isinstance(data, dict) and "conversation_id" in data:
self.conversation_ids[user_id] = data["conversation_id"]
elif event_type == "conversation.message.delta":
if isinstance(data, dict):
content = data.get("content", "")
if not content and "delta" in data:
content = data["delta"].get("content", "")
if not content and "text" in data:
content = data.get("text", "")
if content:
message_started = True
accumulated_content += content
yield LLMResponse(
role="assistant",
completion_text=content,
is_chunk=True,
)
elif event_type == "conversation.message.completed":
if isinstance(data, dict):
msg_type = data.get("type")
if msg_type == "answer" and data.get("role") == "assistant":
final_content = data.get("content", "")
if not accumulated_content and final_content:
chain = MessageChain(chain=[Comp.Plain(final_content)])
yield LLMResponse(
role="assistant",
result_chain=chain,
is_chunk=False,
)
elif event_type == "conversation.chat.completed":
if accumulated_content:
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
yield LLMResponse(
role="assistant",
result_chain=chain,
is_chunk=False,
)
break
elif event_type == "done":
break
elif event_type == "error":
error_msg = (
data.get("message", "未知错误")
if isinstance(data, dict)
else str(data)
)
logger.error(f"Coze 流式响应错误: {error_msg}")
yield LLMResponse(
role="err",
completion_text=f"Coze 错误: {error_msg}",
is_chunk=False,
)
break
if not message_started and not accumulated_content:
yield LLMResponse(
role="assistant",
completion_text="LLM 未响应任何内容。",
is_chunk=False,
)
elif message_started and accumulated_content:
chain = MessageChain(chain=[Comp.Plain(accumulated_content)])
yield LLMResponse(
role="assistant",
result_chain=chain,
is_chunk=False,
)
except Exception as e:
logger.error(f"Coze 流式请求失败: {e!s}")
yield LLMResponse(
role="err",
completion_text=f"Coze 流式请求失败: {e!s}",
is_chunk=False,
)
async def forget(self, session_id: str):
"""清空指定会话的上下文"""
user_id = session_id
conversation_id = self.conversation_ids.get(user_id)
if user_id in self.file_id_cache:
self.file_id_cache.pop(user_id, None)
if not conversation_id:
return True
try:
response = await self.api_client.clear_context(conversation_id)
if "code" in response and response["code"] == 0:
self.conversation_ids.pop(user_id, None)
return True
logger.warning(f"清空 Coze 会话上下文失败: {response}")
return False
except Exception as e:
logger.error(f"清空 Coze 会话失败: {e!s}")
return False
async def get_current_key(self):
"""获取当前API Key"""
return self.api_key
async def set_key(self, key: str):
"""设置新的API Key"""
raise NotImplementedError("Coze 适配器不支持设置 API Key。")
async def get_models(self):
"""获取可用模型列表"""
return [f"bot_{self.bot_id}"]
def get_model(self):
"""获取当前模型"""
return f"bot_{self.bot_id}"
def set_model(self, model: str):
"""设置模型(在Coze中是Bot ID"""
if model.startswith("bot_"):
self.bot_id = model[4:]
else:
self.bot_id = model
async def get_human_readable_context(
self,
session_id: str,
page: int = 1,
page_size: int = 10,
):
"""获取人类可读的上下文历史"""
user_id = session_id
conversation_id = self.conversation_ids.get(user_id)
if not conversation_id:
return []
try:
data = await self.api_client.get_message_list(
conversation_id=conversation_id,
order="desc",
limit=page_size,
offset=(page - 1) * page_size,
)
if data.get("code") != 0:
logger.warning(f"获取 Coze 消息历史失败: {data}")
return []
messages = data.get("data", {}).get("messages", [])
readable_history = []
for msg in messages:
role = msg.get("role", "unknown")
content = msg.get("content", "")
msg_type = msg.get("type", "")
if role == "user":
readable_history.append(f"用户: {content}")
elif role == "assistant" and msg_type == "answer":
readable_history.append(f"助手: {content}")
return readable_history
except Exception as e:
logger.error(f"获取 Coze 消息历史失败: {e!s}")
return []
async def terminate(self):
"""清理资源"""
await self.api_client.close()

Some files were not shown because too many files have changed in this diff Show More