Compare commits

...

108 Commits

Author SHA1 Message Date
Soulter 7cedf0d587 chore: improve documentation for extra_user_content_parts in Provider classes 2025-12-26 21:55:44 +08:00
kawayiYokami aeb21f719e claude额外块支持图片模态 2025-12-26 21:54:01 +08:00
Soulter 7c1dbecea5 refactor: unify extra_user_content_parts type to ContentPart across providers and update related handling 2025-12-26 21:47:02 +08:00
kawayiYokami 05012af627 重命名 2025-12-26 20:54:38 +08:00
kawayiYokami 17b52ab5dd 传递链 2025-12-26 18:57:51 +08:00
kawayiYokami 9449ff668b FIX 2025-12-25 13:33:40 +08:00
kawayiYokami c5a2827def feat: 多文本块功能 2025-12-25 03:54:05 +08:00
Soulter 701399c00c docs: update readme xmas 2025-12-24 21:58:04 +08:00
Soulter eaee98d4b8 chore: bump version to 4.10.2 2025-12-24 21:55:05 +08:00
Soulter 76c66000a7 chore: restrict psutil version <7.2.0 to avoid compatibility issues
fixes: #4176
2025-12-24 15:48:58 +08:00
Oscar Shaw 4b365143c0 feat: support for managing command aliases (#4170)
* feat(command): persist aliases on rename and apply to runtime filter

* feat(dashboard-api): support aliases in rename command endpoint

* feat(dashboard-ui): add alias editor to rename command dialog

* feat(dashboard-ui): enhance alias editor UI in rename dialog
2025-12-24 15:37:10 +08:00
Soulter 6e4e5011e2 chore: bump version to 4.10.1 2025-12-23 21:35:40 +08:00
Venus Yan d853bfde84 perf: handle unsupported message types with logging in OneBot adapter (#4164)
* Handle unsupported message types with logging

解决else 分支中对未知消息类型毫无防御,直接索引ComponentTypes[t],导致新类型markdown类信息报错并炸掉事件管道,且对应群聊单群永久不响应插件;尝试支持markdown类型进行支持但未经过测试

* chore: ruff format

* chore: ruff format

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-12-23 21:31:32 +08:00
Soulter a0e856f80f fix: provider source id contains slash will lead to 405 (#4162) 2025-12-22 20:28:20 +08:00
Oscar Shaw 8c94a0010c fix(core): improve error handling of command parser and sync (#4161) 2025-12-22 19:54:26 +08:00
Soulter a44fdaaec0 chore: bump version to 4.10.0 2025-12-22 18:10:30 +08:00
Soulter 60105c76f5 feat: implement router loading progress indicator 2025-12-22 13:20:39 +08:00
Soulter bcf87d3ce4 fix: update provider subtitle for clarity in English and Chinese locales
- Revised the subtitle in the provider feature localization files to provide a more detailed description of functionalities, including chat model configuration and third-party service integrations.
2025-12-22 13:13:42 +08:00
Soulter 4d7c8c8453 style: add active background color for provider source list item in dark theme 2025-12-22 12:59:55 +08:00
Soulter a064a9115f fix: omit thinking params for gemini image generation models (#4151)
- Expanded model name checks to include specific Gemini 2.5 and 3 variants, ensuring correct configuration for thinking parameters based on the model used.
2025-12-22 00:09:30 +08:00
Soulter 6ef99e1553 feat: enhance ChatInput and ConversationSidebar dark theme 2025-12-21 21:19:54 +08:00
Soulter c0dbe5cf65 chore: bump version to 4.10.0-alpha.2 2025-12-21 13:11:32 +08:00
Soulter 3598c51eff fix: enhance provider model menu and sidebar session selection handling (#4144)
- Updated `ProviderModelMenu.vue` to manage menu state and load provider configurations dynamically upon opening.
- Filtered provider configurations to exclude those with `enable` set to false.
- Improved session selection logic in `useSessions.ts` to ensure the currently selected session is highlighted and properly managed during navigation.
2025-12-21 13:05:15 +08:00
Soulter b5cdb8f650 fix: improve error handling in tool execution to prevent infinite tool call loops (#4143)
* fix: improve error handling in tool execution to prevent infinite tool call loops

- Enhanced error handling in `call_local_llm_tool` to provide more informative exceptions for ValueError and TypeError, including detailed parameter information.
- Updated `ToolLoopAgentRunner` to yield appropriate messages for cases with no response or unsupported types, ensuring clearer communication to users.
- Improved logging and messaging consistency across tool execution processes.

* refactor: clean up unused router parameter in message retrieval functions

- Removed the unused `router` parameter from `getSessionMessages` and related function calls in `Chat.vue` and `useMessages.ts`.
- Commented out the `tool_calls` dictionary in `chat.py` for clarity, indicating it is not currently in use.

* fix: enhance exception handling in tool execution for clearer error reporting

- Improved exception handling in `call_local_llm_tool` by chaining exceptions for ValueError and TypeError, providing more context in error messages.
- Ensured that traceback information is preserved in raised exceptions for better debugging.
2025-12-21 12:57:54 +08:00
Yokami fc5b520f9b perf(agent): add max step limit to prevent infinite tool call loops (#4110)
* perf(agent): add max step limit to prevent infinite tool call loops

* feat: implement max step limit handling in main agent runner

- Enhanced the agent runner to enforce a maximum step limit, logging a warning and forcing a final response when the limit is reached.
- Updated message handling to append a user prompt when the tool call limit is exceeded.
- Refactored tool response handling to yield appropriate messages based on the response type, including handling cases with no response or unsupported types.
- Improved conversation message formatting to ensure consistent output in the assistant's responses.

* chore: ruff format

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-12-21 12:30:43 +08:00
Soulter 904f56b32f fix: webui conversation traj data display error (#4142)
fixes: #4141
2025-12-20 23:29:40 +08:00
Soulter 2f15fd019c chore: bump version to v4.10.0-alpha.1 2025-12-20 16:35:54 +08:00
Soulter 82330b8d10 feat: add changelog functionality and dialog component (#4135)
* feat: add changelog functionality and dialog component

- Implemented new routes for fetching changelogs and available versions in StatRoute.
- Created ChangelogDialog.vue for displaying changelog content and version selection.
- Updated VerticalSidebar.vue to include a button for opening the changelog dialog.
- Enhanced localization files for English and Chinese to support new changelog features.
- Adjusted styles in VerticalHeader.vue for improved layout consistency.

* chore: ruff format
2025-12-20 16:33:12 +08:00
Soulter 3ee6af7027 feat: add route watcher for viewMode changes in VerticalHeader.vue
- Introduced a watcher to monitor changes in customizer.viewMode, automatically redirecting to the homepage when switching from 'chat' to 'bot' mode.
- Updated imports to include useRoute from vue-router for routing functionality.
- Adjusted button styles for improved layout consistency in bot mode.
2025-12-20 15:38:01 +08:00
Soulter 6e20ebe901 feat: add KaTeX and Mermaid and computation-friendly renderer support (#4118)
* feat: add KaTeX and Mermaid support for enhanced markdown rendering in MessageList.vue

closes: #3747
- Integrated @mdit/plugin-katex and katex for LaTeX rendering.
- Added markstream-vue for improved markdown rendering capabilities.
- Updated MessageList.vue to utilize MarkdownRender component for rendering markdown content.
- Enhanced UI for dark mode compatibility across various components.
- Introduced new styles for file links, reasoning blocks, and tool call cards to improve visual consistency.

* refactor: replace markdown-it with markstream-vue for improved markdown rendering

- Removed markdown-it and related configurations from ReadmeDialog.vue, VerticalHeader.vue, and ConversationPage.vue.
- Integrated markstream-vue for enhanced markdown rendering capabilities, including support for KaTeX and Mermaid.
- Updated components to utilize MarkdownRender for rendering markdown content, improving consistency and performance.

* chore: remove deprecated markdown-it and marked dependencies from pnpm-lock.yaml

- Cleaned up pnpm-lock.yaml by removing markdown-it and marked entries, streamlining the dependency list.
- This change follows the recent integration of markstream-vue for improved markdown rendering capabilities.

* chore: remove d3 dependency and update MessageList.vue for dark mode support

- Removed d3 from package.json and commented out its import in LongTermMemory.vue to clean up unused dependencies.
- Updated MessageList.vue to ensure consistent dark mode styling by passing the isDark prop to MarkdownRender components.

* feat: add loading indicator for message retrieval in Chat and MessageList components

- Introduced a loading overlay in Chat.vue and MessageList.vue to indicate when messages are being loaded.
- Added a new `isLoadingMessages` prop to manage loading state and enhance user experience during message retrieval.
- Updated styles to ensure the loading indicator is visually integrated with the existing UI.

* feat: add provider configuration dialog to chat sidebar

- Introduced a new `ProviderConfigDialog` component for managing provider settings.
- Added a menu item in the `ConversationSidebar` to open the provider configuration dialog.
- Updated English and Chinese localization files to include translations for the new provider configuration feature.

* feat: update dashboard components and styles for improved chat experience

- Replaced font in index.html to use 'Outfit' for a fresh look.
- Changed icon in ConversationSidebar.vue to 'mdi-creation' for better representation.
- Refactored MessageList.vue to streamline loading indicators and enhance styling consistency.
- Updated localization files to change 'Provider Configuration' to 'AI Configuration' for clarity.
- Introduced new styles for loading indicators and chat mode adjustments in FullLayout.vue.
- Added functionality for toggling between bot and chat modes in the header.
- Removed deprecated sidebar item for chat navigation.

* feat: xmas easter egg

* chore: remove pnpm lock file
2025-12-20 15:22:48 +08:00
Yokami 4d6150fd6d fix: handle quoted messages correctly to prevent breaking cache (#4112)
* fix: Handle quoted messages correctly as user context

This change ensures quoted messages, including text and image captions, are appended to the conversation history as a user message rather than being injected into the system prompt.

Fixes #3886

* 注入到req.prompt里
2025-12-20 11:03:27 +08:00
Soulter 544e52191b Merge pull request #4065 from AstrBotDevs/refactor/provider-source
refactor: SUPER AMAZING model provider refactor
2025-12-20 00:09:36 +08:00
Soulter f2c2a6da4a chore: ruff format 2025-12-20 00:07:42 +08:00
Soulter dd3df425ee feat: add warnings for missing provider IDs in manager and context
- Introduced logging warnings in ProviderManager and Context classes when a provider ID is not found, indicating potential issues due to ID modifications.
- Updated the ProviderPage.vue to advise against modifying provider IDs, highlighting possible configuration impacts.
2025-12-20 00:06:42 +08:00
Soulter 40b4a27a3d Merge remote-tracking branch 'origin/master' into refactor/provider-source 2025-12-19 15:48:42 +08:00
Soulter 9d991c7468 perf: enhance chat components with theme and fullscreen toggles (#4116)
* perf: enhance chat components with theme and fullscreen toggles

- Added theme and fullscreen toggle functionality to Chat.vue and ConversationSidebar.vue.
- Introduced a new StyledMenu component for improved dropdown menus.
- Updated MessageList.vue and ChatInput.vue for better mobile responsiveness and UI consistency.
- Enhanced language switcher integration in ConversationSidebar.vue.
- Added new settings translations in English and Chinese locales.

* fix: streamline conversation selection handling in Chat.vue

- Updated handleSelectConversation function to immediately set the current session ID and selected sessions, reducing the need for multiple clicks.
- Adjusted padding in ConversationSidebar.vue for improved layout consistency.
2025-12-19 11:18:01 +08:00
Soulter ad6a8b5c94 Merge remote-tracking branch 'origin/master' into refactor/provider-source 2025-12-18 17:39:27 +08:00
Soulter 1b4bfcbd72 chore: ruff format 2025-12-18 17:37:12 +08:00
Soulter 9d3cc593a1 feat: supports thinking level of google gemini (#4104)
* feat: supports thinking level of google gemini

- Updated google-genai version to >=1.56.0 in pyproject.toml and requirements.txt.
- Changed model configuration from "gemini-1.5-flash" to "gemini-3-flash-preview" in default.py.
- Enhanced thinking configuration handling in gemini_source.py to support new parameters for Gemini 3 models.

* fix: standardize thinking level configuration in default.py and gemini_source.py

- Updated the thinking level values in default.py to uppercase for consistency.
- Enhanced gemini_source.py to validate the thinking level and default to "HIGH" if an invalid value is provided.
2025-12-18 17:37:11 +08:00
Soulter f0dee35ba9 feat: enhance tool call handling and agent stats tracking and UI integration for tool calls render (#4101)
* feat: enhance tool call handling and UI integration for tool calls render

- Added support for tool call messages in the agent runner and webchat event handling.
- Implemented JSON message component for structured tool call data.
- Updated chat route to save tool call information in message history.
- Enhanced frontend to display tool call details in a collapsible format, including status and results.
- Introduced elapsed time tracking for ongoing tool calls in the chat interface.

* fix: improve message handling in agent run utility and tool loop runner

- Refactored message sending logic in `astr_agent_run_util.py` to use `msg_chain` directly for better clarity.
- Added a check in `tool_loop_agent_runner.py` to ensure `tool_call_result_blocks` is not empty before yielding the last tool call result, preventing potential errors.

* refactor: enhance message structure and UI for chat components

- Updated message handling in `MessageList.vue` to support structured message parts, including plain text, images, audio, and files.
- Improved the `Chat.vue` component styles for better visual consistency.
- Refactored message parsing logic in `useMessages.ts` to accommodate new message formats and ensure proper rendering of embedded content.
- Removed deprecated tool call handling from the message structure, streamlining the message display process.

* chore: ruff format

* feat: implement agent statistics tracking and display in chat

- Added `AgentStats` and `TokenUsage` data classes to track agent performance metrics.
- Enhanced `ToolLoopAgentRunner` to collect and update agent statistics during execution.
- Integrated agent statistics sending to webchat for real-time updates.
- Updated chat route to save and display agent statistics in message history.
- Improved frontend components to visualize agent statistics, including token usage and duration metrics.

* fix: improve message handling in Telegram event and agent run utility

- Updated message sending logic in `astr_agent_run_util.py` to send the correct message chain for tool calls.
- Enhanced `tg_event.py` to edit messages during streaming breaks, improving message management and user experience.
- Added error handling for message editing failures to ensure robustness.

* chore: ruff format
2025-12-18 17:36:45 +08:00
Soulter 4135bd84d5 refactor: update OneBot configuration and add platform logo (#4106)
- Renamed "QQ 个人号(OneBot v11)" to "OneBot v11" in the configuration.
- Added a new logo for OneBot in the dashboard assets.
- Updated platform icon retrieval logic to include the new OneBot logo.
2025-12-18 17:34:59 +08:00
Soulter f6da614e5d fix: validation error for ToolCall.extra_content in specific upstream model providers (#4102)
* fix: validation error for ToolCall.extra_content in specific upstream model providers

* fix: handle missing extra_content gracefully in ToolCall serialization
2025-12-18 17:34:59 +08:00
Soulter 5f531c9be5 chore: ruff format 2025-12-18 17:17:17 +08:00
Soulter 94591d965b feat: supports thinking level of google gemini (#4104)
* feat: supports thinking level of google gemini

- Updated google-genai version to >=1.56.0 in pyproject.toml and requirements.txt.
- Changed model configuration from "gemini-1.5-flash" to "gemini-3-flash-preview" in default.py.
- Enhanced thinking configuration handling in gemini_source.py to support new parameters for Gemini 3 models.

* fix: standardize thinking level configuration in default.py and gemini_source.py

- Updated the thinking level values in default.py to uppercase for consistency.
- Enhanced gemini_source.py to validate the thinking level and default to "HIGH" if an invalid value is provided.
2025-12-18 17:15:01 +08:00
Soulter 8a0f865af1 feat: enhance tool call handling and agent stats tracking and UI integration for tool calls render (#4101)
* feat: enhance tool call handling and UI integration for tool calls render

- Added support for tool call messages in the agent runner and webchat event handling.
- Implemented JSON message component for structured tool call data.
- Updated chat route to save tool call information in message history.
- Enhanced frontend to display tool call details in a collapsible format, including status and results.
- Introduced elapsed time tracking for ongoing tool calls in the chat interface.

* fix: improve message handling in agent run utility and tool loop runner

- Refactored message sending logic in `astr_agent_run_util.py` to use `msg_chain` directly for better clarity.
- Added a check in `tool_loop_agent_runner.py` to ensure `tool_call_result_blocks` is not empty before yielding the last tool call result, preventing potential errors.

* refactor: enhance message structure and UI for chat components

- Updated message handling in `MessageList.vue` to support structured message parts, including plain text, images, audio, and files.
- Improved the `Chat.vue` component styles for better visual consistency.
- Refactored message parsing logic in `useMessages.ts` to accommodate new message formats and ensure proper rendering of embedded content.
- Removed deprecated tool call handling from the message structure, streamlining the message display process.

* chore: ruff format

* feat: implement agent statistics tracking and display in chat

- Added `AgentStats` and `TokenUsage` data classes to track agent performance metrics.
- Enhanced `ToolLoopAgentRunner` to collect and update agent statistics during execution.
- Integrated agent statistics sending to webchat for real-time updates.
- Updated chat route to save and display agent statistics in message history.
- Improved frontend components to visualize agent statistics, including token usage and duration metrics.

* fix: improve message handling in Telegram event and agent run utility

- Updated message sending logic in `astr_agent_run_util.py` to send the correct message chain for tool calls.
- Enhanced `tg_event.py` to edit messages during streaming breaks, improving message management and user experience.
- Added error handling for message editing failures to ensure robustness.

* chore: ruff format
2025-12-18 17:11:09 +08:00
Soulter 4aced976a8 refactor: update OneBot configuration and add platform logo (#4106)
- Renamed "QQ 个人号(OneBot v11)" to "OneBot v11" in the configuration.
- Added a new logo for OneBot in the dashboard assets.
- Updated platform icon retrieval logic to include the new OneBot logo.
2025-12-18 15:19:15 +08:00
Soulter 0299aa6e4c fix: validation error for ToolCall.extra_content in specific upstream model providers (#4102)
* fix: validation error for ToolCall.extra_content in specific upstream model providers

* fix: handle missing extra_content gracefully in ToolCall serialization
2025-12-18 11:55:49 +08:00
Soulter e8b54a019e refactor: replace ProviderModelSelector with ProviderModelMenu for improved UI and functionality 2025-12-17 22:57:32 +08:00
Soulter 98ce796275 chore: remove copilot instruction 2025-12-17 17:21:33 +08:00
Soulter b87dcf2275 refactor: improve provider source ID validation to prevent duplicates during configuration updates 2025-12-17 17:19:35 +08:00
Soulter 591a228431 refactor: enhance provider management with resource locking and CRUD operations 2025-12-17 17:08:52 +08:00
Soulter f52f375154 refactor: update provider handling to use new config structure and improve template retrieval 2025-12-17 16:55:12 +08:00
Soulter 975c685a17 chore: ruff format 2025-12-17 16:32:38 +08:00
Soulter 6db80d36a8 fix: prevent platform ID modification during updates and ensure correct routing table handling 2025-12-17 16:16:50 +08:00
Soulter 4651bd2807 feat: implement provider deletion functionality and ensure unique provider IDs 2025-12-17 15:00:22 +08:00
Soulter 94ada3793e Merge remote-tracking branch 'origin/master' into refactor/provider-source 2025-12-17 13:33:23 +08:00
Soulter fd05b0bf09 docs: update contributing guidelines to include code style and formatting instructions 2025-12-17 13:26:22 +08:00
Soulter 4d046f8490 delete: remove backup of ProviderPage.vue 2025-12-17 11:34:12 +08:00
Copilot 58e32b7b70 fix: inverted logic in segmented reply LLM-only filter (#4071)
* Initial plan

* Fix: Correct inverted logic in is_seg_reply_required for only_llm_result option

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-12-17 11:12:05 +08:00
Soulter 903dd0f9f7 feat: add manual model addition functionality and search capability in ProviderPage 2025-12-17 10:56:45 +08:00
Soulter 1acac0cac2 feat: enhance provider selection with a new drawer interface and localization updates 2025-12-17 10:39:16 +08:00
Oscar Shaw 80b89fd2ea feat: implements command management and improve webui feature structure (#3904)
move mcp management to plugin managemanet page

* feat: 新增命令配置数据库模型

* feat: 实现核心命令管理系统

* feat: 将命令管理集成到 Star 框架

* feat: 新增命令管理后台 API

* feat: 新增命令管理界面页面

* feat: 新增命令管理国际化支持

* test: 新增命令管理相关测试

* refactor(command): 移除指令重命名时的别名功能

* fix(command): 修正指令冲突检测逻辑

* fix(command): 排除已禁用指令的冲突检测

- 只有 `effective_command` 存在且 `enabled` 为 `True` 的指令才会被纳入冲突检测范围。

* feat(command): 优化指令冲突显示与提示

- 【功能】新增指令冲突警告提示,当检测到冲突时显示详细信息及解决方案。
- 【优化】调整指令列表排序逻辑,将冲突指令优先显示并分组。
- 【样式】为冲突指令行添加专属高亮样式,提升视觉识别度。
- 【国际化】更新英文和中文多语言文件,增加指令冲突警告相关的翻译文本。

* chore(command-page): 禁用命令表格部分列的排序功能

* style(command-page): 调整命令页面表格样式和图标大小

* refactor(command): 优化指令页面布局并更新冲突警告

- 【布局优化】重新组织指令管理页面布局,将筛选器移至顶部独立行
- 【信息展示】将搜索栏与总指令数、已禁用指令数合并显示,提升页面空间利用率
- 【视觉更新】更新指令冲突警告样式

* style: UI 细节

* refactor(command): 调整指令管理中的成员权限显示与筛选

  - 更新指令筛选逻辑,当选择“所有人”权限筛选时,将同时包含 `everyone` 和 `member` 权限的指令。

* feat(command-management): 新增指令层级管理与UI展示

- 【后端】
  - `CommandDescriptor` 新增 `parent_group_handler` 和 `sub_commands` 字段,支持指令层级结构定义。
  - `list_commands` 函数重构,实现指令的层级收集与构建,将子指令正确挂载到其父指令组下。
  - 新增 `_collect_all_descriptors` 和 `_find_parent_group_handler` 辅助函数,用于全面收集指令并定位父指令组。
  - `_build_descriptor` 优化指令类型判断逻辑,明确区分普通指令、指令组和子指令。
  - `_descriptor_to_dict` 递归处理子指令,确保 API 返回完整的指令层级数据。
- 【前端】
  - 指令管理页面 (`CommandPage.vue`) 增加指令类型筛选器,并支持指令组的展开/折叠功能。
  - 表格展示优化,为指令组和子指令添加不同的样式和缩进,提升层级结构的视觉可读性。
  - 指令详情对话框新增指令类型、所属指令组和子指令列表的展示。
  - 更新 `CommandItem` 接口,以适配后端提供的层级数据结构。
- 【i18n】
  - 新增指令类型(指令、指令组、子指令)的国际化文本。
  - 更新指令管理相关 UI 文本,包括表格头部、详情对话框字段和筛选器选项。

* style(command): 优化指令组子指令数量显示UI

* refactor(command): 修改指令列表排序逻辑

* style(command-page): 优化命令列表UI

* feat(command): 添加系统插件指令过滤与冲突处理

* refactor(command): 更新指令数展示逻辑

* style(command): 更新空状态描述

* feat(extension): 添加插件指令冲突检测与提示

- 在插件安装或启用后,自动检测并提示指令冲突。
- 当检测到指令冲突时,显示警告对话框,告知用户冲突数量及可能的影响。

* refactor(command): 移除指令表格内部加载指示器

* style(extension): 文案修改

* refactor(command): 模块化指令管理面板前端代码

* refactor(commandPanel): 重命名指令模块目录为 commandPanel

* style(commandPanel): 微调指令面板UI

* fix(command): 确保新命令配置的事务提交

* fix(sidebar): 补全新增侧边栏项后的侧边栏位追加逻辑

* refactor(commands): 重构/help指令以动态显示实际命令并补充部分命令描述

* style(builtin_commands): 补充命令描述

* refactor(commandPanel): 移除未使用的 filterState 常量

* perf(dashboard): 删除多余的CommandPage.vue文件(已被模块化引用)

* perf(command): 优化命令冲突计数逻辑

* perf(command): 优化指令管理辅助函数和配置绑定逻辑

* perf(db): 优化重构command相关数据库操作

* refactor(sidebar): 提取侧边栏项目解析逻辑到工具函数复用

* refactor: move mcp and command page to extension page

* refactor: remove unused imports in component panel

* fix: update terminology for handler management in extension localization

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-12-16 20:24:57 +08:00
Soulter 26f863ba81 Revert "fix: omit empty content field for the LLM request after tool calls ar…" (#4068)
This reverts commit f78a90218e.
2025-12-16 20:22:13 +08:00
sctop f78a90218e fix: omit empty content field for the LLM request after tool calls are completed (#4008)
* fix: omit content field for the LLM request after tool calls are completed and content is empy string or none

* chore: ruff format

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-12-16 20:11:11 +08:00
Soulter a3ecebd2aa fix: correct text accumulation logic in webchat (#4066) 2025-12-16 19:35:41 +08:00
Soulter 67c33b842d feat: add new provider icons and improve provider source handling
- Added icons for 'modelstack', 'tokenpony', and 'compshare' in providerUtils.js.
- Updated ProviderPage.vue to display the correct count of displayed provider sources.
- Enhanced the logic for displaying provider sources to include placeholders for unselected templates.
- Improved the display name for provider sources to show template keys for placeholders.
- Adjusted styles for better layout and overflow handling in provider source list and cards.
- Refactored source selection logic to handle placeholder sources correctly.
- Updated error handling in provider testing to provide clearer messages.
2025-12-16 16:11:56 +08:00
Soulter 5431c9f46e refactor: remove unused tab from AddNewProvider and disable button based on provider status in ProviderPage 2025-12-16 12:26:26 +08:00
Soulter 764b91a5f7 chore: ruff check 2025-12-16 12:21:14 +08:00
Soulter c20c1b84bf feat: implement LLM metadata fetching and integrate into provider model selection 2025-12-16 12:19:40 +08:00
Soulter fd66a0ac00 perf: better UI 2025-12-16 11:24:07 +08:00
Soulter aaee283367 fix: type checking of AstrAgentContext 2025-12-16 10:09:57 +08:00
Soulter 4a5b7d1976 fix: type checking of contextwrapper 2025-12-16 09:59:56 +08:00
Sukafon 08244548ab fix: incorrect type assignment when the agent send an image (#4050) 2025-12-16 08:28:10 +08:00
dependabot[bot] b486de6a98 chore(deps): bump actions/upload-artifact in the github-actions group (#4061)
Bumps the github-actions group with 1 update: [actions/upload-artifact](https://github.com/actions/upload-artifact).


Updates `actions/upload-artifact` from 5 to 6
- [Release notes](https://github.com/actions/upload-artifact/releases)
- [Commits](https://github.com/actions/upload-artifact/compare/v5...v6)

---
updated-dependencies:
- dependency-name: actions/upload-artifact
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-16 08:24:03 +08:00
Soulter e2f928a7e5 chore: bump version to 4.9.2 2025-12-15 16:58:32 +08:00
Soulter b8e4068c75 feat: support key-value storage for plugins (#4048)
* feat: support key-value storage for plugins

* fix: remove unnecessary initialization method from Main class
2025-12-15 16:50:44 +08:00
Soulter 0916177a57 chore: bump version to 4.9.1 2025-12-15 16:07:10 +08:00
Soulter 02cd5e396b feat: add trigger probability setting for TTS and support to render slider in schema (#4047)
* feat: add trigger probability setting for TTS and support to render slider in schema

* chore: ruff format
2025-12-15 16:04:27 +08:00
Soulter 56673ad78f fix: prevent duplicate result content type after streaming finishes in RespondStage 2025-12-15 15:33:40 +08:00
Soulter 9a4d05e2b6 fix: remove unnecessary persistent attribute from ReadmeDialog and adjust dialog structure in ExtensionPage 2025-12-15 15:27:42 +08:00
Soulter b2e9dab233 refactor: enhance layout and improve provider source management in ProviderPage 2025-12-15 15:15:17 +08:00
Soulter 45110200ea feat: update provider and provider source configuration handling 2025-12-15 12:31:29 +08:00
Soulter c3f45449e8 docs: readme
wa ta shi wa ko sei no de su ka ra!
2025-12-15 11:47:21 +08:00
Copilot 65da469deb feat: add conversation export feature to JSONL for AI training (#4037)
* Initial plan

* Add conversation export functionality (backend and frontend)

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

* Address code review feedback: move imports, simplify logic, improve i18n

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

* Simplify frontend download logic: remove redundant Blob wrapper and complex filename parsing

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

* fix: update conversation export filename format for consistency

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
Co-authored-by: Soulter <905617992@qq.com>
2025-12-14 21:44:12 +08:00
Soulter 16df64c405 fix: lark domain and log_level of Lark API client (#4038)
fixes: #4035
2025-12-14 21:31:17 +08:00
i0cLiceao 6b73b19e54 fix: support using GitHub Raw content as plugin source (#3975)
* Update plugin.py

* Update plugin.py

* Update plugin.py

* Update plugin.py
2025-12-14 18:23:29 +08:00
Soulter a70088b799 Merge remote-tracking branch 'origin/master' into refactor/provider-source 2025-12-13 23:37:23 +08:00
Soulter e7e97730af chore: bump version to 4.9.0 2025-12-13 18:49:07 +08:00
Soulter 467ca1eb5c fix: webui log output incompletely (#4029)
* fix: webui log output incompletely

* fix: improve SSE log parsing to handle partial data chunks

* fix: enhance log handling by implementing local cache and fetching history

* fix: log time handling to use epoch time
2025-12-13 18:46:16 +08:00
Soulter bb45d9cb54 stage 2025-12-13 17:16:07 +08:00
RC-CHN 46528391c2 feat: add pre-chunk import strategy for knowledge base (#3973)
* feat: 添加文档导入功能及相关测试

* feat: 优化文档上传功能,支持从文件名推断文件类型,并增强文档切片验证

* feat: 添加文档导入功能的无效输入测试,验证 chunks 类型和内容的错误处理

* refactor: 重构文档上传和导入任务的状态管理,添加任务初始化、结果设置和进度更新方法
2025-12-12 23:15:11 +08:00
Soulter 8a0b7717cc feat: supports webhook mode for Lark platform (#4016)
* feat: add Lark platform support with unified webhook configuration

* fix: update token verification logic in LarkWebhookServer

* feat: implement event deduplication and cleanup for Lark webhook events
2025-12-12 22:12:13 +08:00
Copilot 3b81fb4985 fix: mobile dialog close button visibility (#4010)
* Initial plan

* Fix mobile dialog close button visibility by adding max-height and scrollable content

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-12-12 16:02:24 +08:00
Soulter c09d57a820 refactor: improve UI layout and interaction for list item management (#4002)
* refactor: improve UI layout and interaction for list item management

* feat: enhance list configuration UI with batch import functionality

* feat: add internationalization support for list configuration UI
2025-12-11 18:55:56 +08:00
Soulter ec408a2aff fix: lark message timestamp 2025-12-11 18:20:50 +08:00
Soulter 417179a6b9 ci: add smoke test 2025-12-11 10:44:15 +08:00
Soulter fcd29445c7 refactor: remove unused current provider initialization in StarRequestSubStage 2025-12-11 10:36:33 +08:00
BiDuang 5f535001db fix: incorrect modalities enum of gemini api provider (#3993) 2025-12-10 20:27:51 +08:00
PaloMiku 750d245b16 docs: Update README with new Zread link and badges (#3992)
ZRead 是由智谱 AI 推出的 DeepWiki 类似平替品。
2025-12-10 20:22:56 +08:00
Dt8333 f624971613 chore: fix bunches of type checking errors (#3213)
* chore(core.utils): 🚨 修正错误Lint

* chore(core.provider): 🚨 修复基类错误Lint

* chore(core.utils): 补全session_get()的重载

* chore(core.provider): 🚨 修正实现错误Lint

* chore(core.platform): 🚨 修正platform基类和webchat的错误Lint

* chore(core.platform): 修正错误实现Lint

* fix(core.provider): 修复循环调用和错误assert

* chore(core.platform): 修复部分实现Lint

* chore(core.provider): 补充Dify.text_chat_stream的参数类型

* chore(core.pipeline): 🚨 修复错误Lint

* fix(core.slack): 补充遗漏导入

* chore(core.utils): 修复错误的session_get声明

* chore(core.platform): 移除Lark adapter import中的wildcard

* chore(core.db): 修复声明和部分逻辑

* chore(core.db): 添加typings,使faiss参数能被正确识别。

* chore(core): 修复声明

* chore(core): 修改声明

* chore: 补充faiss声明

* chore(dashboard): 修改实现,减少报错

* chore(package): 修改部分声明与实现,减少报错

* chore(core): 添加Handler的overload,以去除部分assert同时通过类型检查

* chore(core.pipeline): 修改Pipeline Scheduler的execute,将判断属性改为判断类型,通过静态类型检查

* chore(core.config): 添加类型标注,通过类型检查

* chore(core.message): 为File._download_file添加检查,通过类型检查

* fix: 将断言改为条件判断以实现优雅关闭的容错性

* refactor: 移除 discord 客户端中的 assert,改用 if None 判断并抛出异常

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: DiscordPlatformAdapter 对 self.client.user 为 None 做日志并返回,移除断言

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 增强 Lark 相关空值/异常检查并完善日志输出

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 将断言替换为条件检查并加入日志与错误处理

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* chore: 移除LLM生成的无用注释

* refactor: 使用 File.get_file 替换下载逻辑并移除 assert,提供默认 filename

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: Slack Socket 未初始化抛出运行时异常,图片 URL 判空改为非空判断

* refactor: 将 WeChatPadProAdapter 的断言改为空值判断并添加日志

* refactor: 使用 isinstance 替代断言实现类型判断,便于静态检查

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 去除cast,直接使用字段与字典访问,修正端口解析

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 使用 match-case 重构 ProviderManager 加载并通过类型检查抛出 TypeError

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: group_name_display 时若 group 对象为空则记录错误并返回

* fix: 将 _get_current_persona_id 的 assert 替换成 if guard 并返回 None

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 优化插件目录存在性检查及图片URL非空验证,更新JSON排序配置

* fix: 将 datetime_str 的 assert 替换为显式检查并抛出异常

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除 cast,改为运行时检查并在找不到调度器时跳过

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除 cast,改用 isinstance 检查 FaissVecDB 并警告

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 删除 typing.cast 导入,并在获取文件绝对路径前校验 file_

* refactor: 移除 typing.cast,简化内容安全检查调用

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 将 PlatformMetadata.id 设为必填并在注册时传入 id,移除 cast

* refactor: 移除 cast,改用 HasInitialize 与 isinstance 进行初始化

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 为 ProviderManager.initialize 增加ID类型判断,避免 None 导致 get 失败

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 为 OTTSProvider 与 AzureNativeProvider 引入 _client 与 client 属性改进上下文管理

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 为 Whisper 自托管源添加模型未初始化校验并直接调用 transcribe

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除未使用的 cast 导入并简化 platform_name 赋值

* refactor: 引入 cast 并对 id 使用 cast(str, ...) 提升类型安全

* fix: 将 _id_to_sid 返回改为 str,空值返回空串;对 id 与 message_id 使用 cast

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 重构 Discord 处理逻辑:强制 类型转换、优先斜杠指令并优化提及判断

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 统一对 id 获取执行 cast,并在微信消息解析失败时抛错

* Revert "fix: 去除cast,直接使用字段与字典访问,修正端口解析"

This reverts commit 1cbfdf9d1b.

* fix: 百炼 Rerank 会话关闭时返回空结果;初始化 request.prompt 避免空值拼接

* fix: 统一处理搜索结果链接为字符串,新增 _get_url 助手并适配 Bing/Sogo

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 调整 call_handler 泛型、Discord 通道注解及 FishAudioTTS API 请求类型

* refactor: 使用 col(...) 替代列引用并对结果进行 CursorResult 强转

* chore: ruff format

---------

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
Co-authored-by: Soulter <905617992@qq.com>
2025-12-09 14:13:47 +08:00
Soulter aa6d07afcc refactor: move all internal commands from astrbot plugin to default_command plugin (#3960)
* refactor: move all internal commands from astrbot plugin to default_command plugin

* ruff check

* feat: add config

* ruff check
2025-12-08 22:17:32 +08:00
Soulter 2c36649874 feat: add Agent Runner test prompt dialog in ProviderPage (#3968) 2025-12-08 21:46:47 +08:00
Soulter c95735dcc0 docs: update readme 2025-12-08 12:05:57 +08:00
Soulter 03bb278f50 chore: ruff check 2025-12-08 11:00:43 +08:00
Soulter a5e0974da3 chore: ruff format 2025-12-08 00:36:56 +08:00
vmoranv f0fb447fbc feat: custom plugin api source manager (#3956)
* feat: custom plugin api source manager

* fix: rename plugin source file in a safer way

* chore: turned the way of saving plugin source to backend and refacted some components

* style: clean up whitespace and improve logging message formatting

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-12-08 00:32:50 +08:00
Soulter 37566182b0 feat: segment reply supports segmentation words (#3959)
* feat: segment reply supports segmentation words

* chore: ruff format

* feat: enhance segmented reply processing by refining word extraction logic

* ruff format
2025-12-08 00:27:17 +08:00
Soulter e460b411da chore: remove dev version from webui (#3951)
* chore: remove dev version

* chore: remove development version references from header localization files
2025-12-07 15:23:30 +08:00
257 changed files with 13754 additions and 4493 deletions
+1 -1
View File
@@ -36,7 +36,7 @@ jobs:
zip -r dist.zip dist
- name: Archive production artifacts
uses: actions/upload-artifact@v5
uses: actions/upload-artifact@v6
with:
name: dist-without-markdown
path: |
+58
View File
@@ -0,0 +1,58 @@
name: Smoke Test
on:
push:
branches:
- master
paths-ignore:
- 'README*.md'
- 'changelogs/**'
- 'dashboard/**'
pull_request:
workflow_dispatch:
jobs:
smoke-test:
name: Run smoke tests
runs-on: ubuntu-latest
timeout-minutes: 10
steps:
- name: Checkout
uses: actions/checkout@v6
with:
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.12'
- name: Install UV package manager
run: |
pip install uv
- name: Install dependencies
run: |
uv sync
timeout-minutes: 15
- name: Run smoke tests
run: |
uv run main.py &
APP_PID=$!
echo "Waiting for application to start..."
for i in {1..60}; do
if curl -f http://localhost:6185 > /dev/null 2>&1; then
echo "Application started successfully!"
kill $APP_PID
exit 0
fi
sleep 1
done
echo "Application failed to start within 30 seconds"
kill $APP_PID 2>/dev/null || true
exit 1
timeout-minutes: 2
+26 -1
View File
@@ -33,6 +33,20 @@
- 请使用英文描述您的 PR。
- 标题请使用 `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` 等语义化前缀,并简要描述更改内容。如:`fix: correct login page typo`
#### 代码规范
##### Core
我们使用 Ruff 作为代码格式化和静态分析工具。在提交代码之前,请运行以下命令以确保代码符合规范:
```bash
ruff format .
ruff check .
```
如果您使用 VSCode,可以安装 `Ruff` 插件。
## Contributing Guide
First off, thanks for taking the time to contribute! ❤️
@@ -62,4 +76,15 @@ We use the `fix/` prefix for bug fixes and the `feat/` prefix for new features.
#### PR Description
- Please use English to describe your PR.
- Use semantic prefixes like `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` in the title, followed by a brief description of the changes, e.g., `fix: correct login page typo`.
- Use semantic prefixes like `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` in the title, followed by a brief description of the changes, e.g., `fix: correct login page typo`.
#### Code Style
##### Core
We use Ruff as our code formatter and static analysis tool. Before submitting your code, please run the following commands to ensure your code adheres to the style guidelines:
```bash
ruff format .
ruff check .
```
+9 -1
View File
@@ -1,4 +1,4 @@
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9)
![astrbot-banner-xmas](https://github.com/user-attachments/assets/bf2341de-ec7a-45a7-a04a-02ad36450e99)
<div align="center">
@@ -20,6 +20,7 @@
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
<img src="https://gitcode.com/Soulter/AstrBot/star/badge.svg" href="https://gitcode.com/Soulter/AstrBot">
@@ -206,6 +207,7 @@ pre-commit install
- 3 群:630166526
- 5 群:822130018
- 6 群:753075035
- 7 群:743746109
- 开发者群:975206796
### Telegram 群组
@@ -241,4 +243,10 @@ pre-commit install
</details>
<div align="center">
_私は、高性能ですから!_
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
</div
+1 -1
View File
@@ -1 +1 @@
__version__ = "4.8.0"
__version__ = "4.10.2"
+6 -4
View File
@@ -3,7 +3,7 @@
from typing import Any, ClassVar, Literal, cast
from pydantic import BaseModel, GetCoreSchemaHandler, model_validator
from pydantic import BaseModel, GetCoreSchemaHandler, model_serializer, model_validator
from pydantic_core import core_schema
@@ -122,10 +122,12 @@ class ToolCall(BaseModel):
extra_content: dict[str, Any] | None = None
"""Extra metadata for the tool call."""
def model_dump(self, **kwargs: Any) -> dict[str, Any]:
@model_serializer(mode="wrap")
def serialize(self, handler):
data = handler(self)
if self.extra_content is None:
kwargs.setdefault("exclude", set()).add("extra_content")
return super().model_dump(**kwargs)
data.pop("extra_content", None)
return data
class ToolCallPart(BaseModel):
+22 -1
View File
@@ -1,7 +1,8 @@
import typing as T
from dataclasses import dataclass
from dataclasses import dataclass, field
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import TokenUsage
class AgentResponseData(T.TypedDict):
@@ -12,3 +13,23 @@ class AgentResponseData(T.TypedDict):
class AgentResponse:
type: str
data: AgentResponseData
@dataclass
class AgentStats:
token_usage: TokenUsage = field(default_factory=TokenUsage)
start_time: float = 0.0
end_time: float = 0.0
time_to_first_token: float = 0.0
@property
def duration(self) -> float:
return self.end_time - self.start_time
def to_dict(self) -> dict:
return {
"token_usage": self.token_usage.__dict__,
"start_time": self.start_time,
"end_time": self.end_time,
"time_to_first_token": self.time_to_first_token,
}
+1 -1
View File
@@ -9,7 +9,7 @@ from .message import Message
TContext = TypeVar("TContext", default=Any)
@dataclass(config={"arbitrary_types_allowed": True})
@dataclass
class ContextWrapper(Generic[TContext]):
"""A context for running an agent, which can be used to pass additional data or state."""
@@ -1,4 +1,5 @@
import sys
import time
import traceback
import typing as T
@@ -12,6 +13,7 @@ from mcp.types import (
)
from astrbot import logger
from astrbot.core.message.components import Json
from astrbot.core.message.message_event_result import (
MessageChain,
)
@@ -24,7 +26,7 @@ from astrbot.core.provider.provider import Provider
from ..hooks import BaseAgentRunHooks
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
from ..response import AgentResponseData
from ..response import AgentResponseData, AgentStats
from ..run_context import ContextWrapper, TContext
from ..tool_executor import BaseFunctionToolExecutor
from .base import AgentResponse, AgentState, BaseAgentRunner
@@ -69,14 +71,25 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
)
self.run_context.messages = messages
self.stats = AgentStats()
self.stats.start_time = time.time()
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
"""Yields chunks *and* a final LLMResponse."""
payload = {
"contexts": self.run_context.messages, # list[Message]
"func_tool": self.req.func_tool,
"model": self.req.model, # NOTE: in fact, this arg is None in most cases
"session_id": self.req.session_id,
"extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart]
}
if self.streaming:
stream = self.provider.text_chat_stream(**self.req.__dict__)
stream = self.provider.text_chat_stream(**payload)
async for resp in stream: # type: ignore
yield resp
else:
yield await self.provider.text_chat(**self.req.__dict__)
yield await self.provider.text_chat(**payload)
@override
async def step(self):
@@ -97,8 +110,11 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
llm_resp_result = None
async for llm_response in self._iter_llm_responses():
assert isinstance(llm_response, LLMResponse)
if llm_response.is_chunk:
# update ttft
if self.stats.time_to_first_token == 0:
self.stats.time_to_first_token = time.time() - self.stats.start_time
if llm_response.result_chain:
yield AgentResponse(
type="streaming_delta",
@@ -122,6 +138,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
)
continue
llm_resp_result = llm_response
if not llm_response.is_chunk and llm_response.usage:
# only count the token usage of the final response for computation purpose
self.stats.token_usage += llm_response.usage
break # got final response
if not llm_resp_result:
@@ -133,6 +153,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
if llm_resp.role == "err":
# 如果 LLM 响应错误,转换到错误状态
self.final_llm_resp = llm_resp
self.stats.end_time = time.time()
self._transition_state(AgentState.ERROR)
yield AgentResponse(
type="err",
@@ -147,11 +168,12 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# 如果没有工具调用,转换到完成状态
self.final_llm_resp = llm_resp
self._transition_state(AgentState.DONE)
self.stats.end_time = time.time()
# record the final assistant message
self.run_context.messages.append(
Message(
role="assistant",
content=llm_resp.completion_text or "",
content=llm_resp.completion_text or "*No response*",
),
)
try:
@@ -176,22 +198,19 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# 如果有工具调用,还需处理工具调用
if llm_resp.tools_call_name:
tool_call_result_blocks = []
for tool_call_name in llm_resp.tools_call_name:
yield AgentResponse(
type="tool_call",
data=AgentResponseData(
chain=MessageChain(type="tool_call").message(
f"🔨 调用工具: {tool_call_name}"
),
),
)
async for result in self._handle_function_tools(self.req, llm_resp):
if isinstance(result, list):
tool_call_result_blocks = result
elif isinstance(result, MessageChain):
result.type = "tool_call_result"
if result.type is None:
# should not happen
continue
if result.type == "tool_direct_result":
ar_type = "tool_call_result"
else:
ar_type = result.type
yield AgentResponse(
type="tool_call_result",
type=ar_type,
data=AgentResponseData(chain=result),
)
# 将结果添加到上下文中
@@ -219,6 +238,25 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
async for resp in self.step():
yield resp
# 如果循环结束了但是 agent 还没有完成,说明是达到了 max_step
if not self.done():
logger.warning(
f"Agent reached max steps ({max_step}), forcing a final response."
)
# 拔掉所有工具
if self.req:
self.req.func_tool = None
# 注入提示词
self.run_context.messages.append(
Message(
role="user",
content="工具调用次数已达到上限,请停止使用工具,并根据已经收集到的信息,对你的任务和发现进行总结,然后直接回复用户。",
)
)
# 再执行最后一步
async for resp in self.step():
yield resp
async def _handle_function_tools(
self,
req: ProviderRequest,
@@ -234,6 +272,19 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
llm_response.tools_call_args,
llm_response.tools_call_ids,
):
yield MessageChain(
type="tool_call",
chain=[
Json(
data={
"id": func_tool_id,
"name": func_tool_name,
"args": func_tool_args,
"ts": time.time(),
}
)
],
)
try:
if not req.func_tool:
return
@@ -307,7 +358,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
content=res.content[0].text,
),
)
yield MessageChain().message(res.content[0].text)
elif isinstance(res.content[0], ImageContent):
tool_call_result_blocks.append(
ToolCallMessageSegment(
@@ -329,7 +379,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
content=resource.text,
),
)
yield MessageChain().message(resource.text)
elif (
isinstance(resource, BlobResourceContents)
and resource.mimeType
@@ -353,20 +402,34 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
content="返回的数据类型不受支持",
),
)
yield MessageChain().message("返回的数据类型不受支持。")
elif resp is None:
# Tool 直接请求发送消息给用户
# 这里我们将直接结束 Agent Loop。
# 发送消息逻辑在 ToolExecutor 中处理了。
logger.warning(
f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户,此工具调用不会被记录到历史中"
f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户。"
)
self._transition_state(AgentState.DONE)
self.stats.end_time = time.time()
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="*工具没有返回值或者将结果直接发送给了用户*",
),
)
else:
# 不应该出现其他类型
logger.warning(
f"Tool 返回了不支持的类型: {type(resp)},将忽略",
f"Tool 返回了不支持的类型: {type(resp)}",
)
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="*工具返回了不支持的类型,请告诉用户检查这个工具的定义和实现。*",
),
)
try:
@@ -388,6 +451,22 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
),
)
# yield the last tool call result
if tool_call_result_blocks:
last_tcr_content = str(tool_call_result_blocks[-1].content)
yield MessageChain(
type="tool_call_result",
chain=[
Json(
data={
"id": func_tool_id,
"ts": time.time(),
"result": last_tcr_content,
}
)
],
)
# 处理函数调用响应
if tool_call_result_blocks:
yield tool_call_result_blocks
+7 -2
View File
@@ -1,4 +1,4 @@
from collections.abc import Awaitable, Callable
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any, Generic
import jsonschema
@@ -7,6 +7,8 @@ from deprecated import deprecated
from pydantic import Field, model_validator
from pydantic.dataclasses import dataclass
from astrbot.core.message.message_event_result import MessageEventResult
from .run_context import ContextWrapper, TContext
ParametersType = dict[str, Any]
@@ -38,7 +40,10 @@ class ToolSchema:
class FunctionTool(ToolSchema, Generic[TContext]):
"""A callable tool, for function calling."""
handler: Callable[..., Awaitable[Any]] | None = None
handler: (
Callable[..., Awaitable[str | None] | AsyncGenerator[MessageEventResult, None]]
| None
) = None
"""a callable that implements the tool's functionality. It should be an async function."""
handler_module_path: str | None = None
+3 -1
View File
@@ -6,8 +6,10 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.context import Context
@dataclass(config={"arbitrary_types_allowed": True})
@dataclass
class AstrAgentContext:
__pydantic_config__ = {"arbitrary_types_allowed": True}
context: Context
"""The star context instance"""
event: AstrMessageEvent
+42 -3
View File
@@ -2,8 +2,10 @@ import traceback
from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.agent.message import Message
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.components import Json
from astrbot.core.message.message_event_result import (
MessageChain,
MessageEventResult,
@@ -23,8 +25,25 @@ async def run_agent(
) -> AsyncGenerator[MessageChain | None, None]:
step_idx = 0
astr_event = agent_runner.run_context.context.event
while step_idx < max_step:
while step_idx < max_step + 1:
step_idx += 1
if step_idx == max_step + 1:
logger.warning(
f"Agent reached max steps ({max_step}), forcing a final response."
)
if not agent_runner.done():
# 拔掉所有工具
if agent_runner.req:
agent_runner.req.func_tool = None
# 注入提示词
agent_runner.run_context.messages.append(
Message(
role="user",
content="工具调用次数已达到上限,请停止使用工具,并根据已经收集到的信息,对你的任务和发现进行总结,然后直接回复用户。",
)
)
try:
async for resp in agent_runner.step():
if astr_event.is_stopped():
@@ -33,16 +52,27 @@ async def run_agent(
msg_chain = resp.data["chain"]
if msg_chain.type == "tool_direct_result":
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
await astr_event.send(resp.data["chain"])
await astr_event.send(msg_chain)
continue
if astr_event.get_platform_id() == "webchat":
await astr_event.send(msg_chain)
# 对于其他情况,暂时先不处理
continue
elif resp.type == "tool_call":
if agent_runner.streaming:
# 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break")
if show_tool_use:
if astr_event.get_platform_name() == "webchat":
await astr_event.send(resp.data["chain"])
elif show_tool_use:
json_comp = resp.data["chain"].chain[0]
if isinstance(json_comp, Json):
m = f"🔨 调用工具: {json_comp.data.get('name')}"
else:
m = "🔨 调用工具..."
chain = MessageChain(type="tool_call").message(m)
await astr_event.send(chain)
continue
if stream_to_general and resp.type == "streaming_delta":
@@ -69,6 +99,15 @@ async def run_agent(
continue
yield resp.data["chain"] # MessageChain
if agent_runner.done():
# send agent stats to webchat
if astr_event.get_platform_name() == "webchat":
await astr_event.send(
MessageChain(
type="agent_stats",
chain=[Json(data=agent_runner.stats.to_dict())],
)
)
break
except Exception as e:
+39 -5
View File
@@ -185,7 +185,11 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
async def call_local_llm_tool(
context: ContextWrapper[AstrAgentContext],
handler: T.Callable[..., T.Awaitable[T.Any]],
handler: T.Callable[
...,
T.Awaitable[MessageEventResult | mcp.types.CallToolResult | str | None]
| T.AsyncGenerator[MessageEventResult | CommandResult | str | None, None],
],
method_name: str,
*args,
**kwargs,
@@ -205,12 +209,42 @@ async def call_local_llm_tool(
else:
raise ValueError(f"未知的方法名: {method_name}")
except ValueError as e:
logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True)
except TypeError:
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
raise Exception(f"Tool execution ValueError: {e}") from e
except TypeError as e:
# 获取函数的签名(包括类型),除了第一个 event/context 参数。
try:
sig = inspect.signature(handler)
params = list(sig.parameters.values())
# 跳过第一个参数(event 或 context
if params:
params = params[1:]
param_strs = []
for param in params:
param_str = param.name
if param.annotation != inspect.Parameter.empty:
# 获取类型注解的字符串表示
if isinstance(param.annotation, type):
type_str = param.annotation.__name__
else:
type_str = str(param.annotation)
param_str += f": {type_str}"
if param.default != inspect.Parameter.empty:
param_str += f" = {param.default!r}"
param_strs.append(param_str)
handler_param_str = (
", ".join(param_strs) if param_strs else "(no additional parameters)"
)
except Exception:
handler_param_str = "(unable to inspect signature)"
raise Exception(
f"Tool handler parameter mismatch, please check the handler definition. Handler parameters: {handler_param_str}"
) from e
except Exception as e:
trace_ = traceback.format_exc()
logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}")
raise Exception(f"Tool execution error: {e}. Traceback: {trace_}") from e
if not ready_to_call:
return
+4
View File
@@ -24,6 +24,10 @@ class AstrBotConfig(dict):
- 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。
"""
config_path: str
default_config: dict
schema: dict | None
def __init__(
self,
config_path: str = ASTRBOT_CONFIG_PATH,
+226 -212
View File
@@ -1,10 +1,11 @@
"""如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。"""
import os
from typing import Any, TypedDict
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.8.0"
VERSION = "4.10.2"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
WEBHOOK_SUPPORTED_PLATFORMS = [
@@ -13,6 +14,7 @@ WEBHOOK_SUPPORTED_PLATFORMS = [
"wecom",
"wecom_ai_bot",
"slack",
"lark",
]
# 默认配置
@@ -42,7 +44,15 @@ DEFAULT_CONFIG = {
"interval": "1.5,3.5",
"log_base": 2.6,
"words_count_threshold": 150,
"split_mode": "regex", # regex 或 words
"regex": ".*?[。?!~…]+|.+$",
"split_words": [
"",
"",
"",
"~",
"",
], # 当 split_mode 为 words 时使用
"content_cleanup_rule": "",
},
"no_permission_reply": True,
@@ -52,7 +62,8 @@ DEFAULT_CONFIG = {
"ignore_bot_self_message": False,
"ignore_at_all": False,
},
"provider": [],
"provider_sources": [], # provider sources
"provider": [], # models from provider_sources
"provider_settings": {
"enable": True,
"default_provider_id": "",
@@ -99,6 +110,7 @@ DEFAULT_CONFIG = {
"provider_id": "",
"dual_output": False,
"use_file_service": False,
"trigger_probability": 1.0,
},
"provider_ltm_settings": {
"group_icl_enable": False,
@@ -157,9 +169,26 @@ DEFAULT_CONFIG = {
"kb_fusion_top_k": 20, # 知识库检索融合阶段返回结果数量
"kb_final_top_k": 5, # 知识库检索最终返回结果数量
"kb_agentic_mode": False,
"disable_builtin_commands": False,
}
class ChatProviderTemplate(TypedDict):
id: str
provider_source_id: str
model: str
modalities: list
custom_extra_body: dict[str, Any]
CHAT_PROVIDER_TEMPLATE = {
"id": "",
"provide_source_id": "",
"model": "",
"modalities": [],
"custom_extra_body": {},
}
"""
AstrBot v3 时代的配置元数据,目前仅承担以下功能:
@@ -198,7 +227,7 @@ CONFIG_METADATA_2 = {
"callback_server_host": "0.0.0.0",
"port": 6196,
},
"QQ 个人号(OneBot v11)": {
"OneBot v11": {
"id": "default",
"type": "aiocqhttp",
"enable": False,
@@ -268,6 +297,10 @@ CONFIG_METADATA_2 = {
"app_id": "",
"app_secret": "",
"domain": "https://open.feishu.cn",
"lark_connection_mode": "socket", # webhook, socket
"webhook_uuid": "",
"lark_encrypt_key": "",
"lark_verification_token": "",
},
"钉钉(DingTalk)": {
"id": "dingtalk",
@@ -361,6 +394,28 @@ CONFIG_METADATA_2 = {
# "type": "string",
# "options": ["fullscreen", "embedded"],
# },
"lark_connection_mode": {
"description": "订阅方式",
"type": "string",
"options": ["socket", "webhook"],
"labels": ["长连接模式", "推送至服务器模式"],
},
"lark_encrypt_key": {
"description": "Encrypt Key",
"type": "string",
"hint": "用于解密飞书回调数据的加密密钥",
"condition": {
"lark_connection_mode": "webhook",
},
},
"lark_verification_token": {
"description": "Verification Token",
"type": "string",
"hint": "用于验证飞书回调请求的令牌",
"condition": {
"lark_connection_mode": "webhook",
},
},
"is_sandbox": {
"description": "沙箱模式",
"type": "bool",
@@ -807,6 +862,7 @@ CONFIG_METADATA_2 = {
"metadata": {
"provider": {
"type": "list",
# provider sources templates
"config_template": {
"OpenAI": {
"id": "openai",
@@ -817,107 +873,10 @@ CONFIG_METADATA_2 = {
"key": [],
"api_base": "https://api.openai.com/v1",
"timeout": 120,
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
"hint": "也兼容所有与 OpenAI API 兼容的服务。",
},
"Azure OpenAI": {
"id": "azure",
"provider": "azure",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"api_version": "2024-05-01-preview",
"key": [],
"api_base": "",
"timeout": 120,
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"xAI": {
"id": "xai",
"provider": "xai",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.x.ai/v1",
"timeout": 120,
"model_config": {"model": "grok-2-latest", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"xai_native_search": False,
"modalities": ["text", "image", "tool_use"],
},
"Anthropic": {
"hint": "注意Claude系列模型的温度调节范围为0到1.0,超出可能导致报错",
"id": "claude",
"provider": "anthropic",
"type": "anthropic_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.anthropic.com/v1",
"timeout": 120,
"model_config": {
"model": "claude-3-5-sonnet-latest",
"max_tokens": 4096,
"temperature": 0.2,
},
"modalities": ["text", "image", "tool_use"],
},
"Ollama": {
"hint": "启用前请确保已正确安装并运行 Ollama 服务端,Ollama默认不带鉴权,无需修改key",
"id": "ollama_default",
"provider": "ollama",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": ["ollama"], # ollama 的 key 默认是 ollama
"api_base": "http://localhost:11434/v1",
"model_config": {"model": "llama3.1-8b", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"LM Studio": {
"id": "lm_studio",
"provider": "lm_studio",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": ["lmstudio"],
"api_base": "http://localhost:1234/v1",
"model_config": {
"model": "llama-3.1-8b",
},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"Gemini(OpenAI兼容)": {
"id": "gemini_default",
"provider": "google",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
"timeout": 120,
"model_config": {
"model": "gemini-1.5-flash",
"temperature": 0.4,
},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"Gemini": {
"id": "gemini_default",
"Google Gemini": {
"id": "google_gemini",
"provider": "google",
"type": "googlegenai_chat_completion",
"provider_type": "chat_completion",
@@ -925,10 +884,6 @@ CONFIG_METADATA_2 = {
"key": [],
"api_base": "https://generativelanguage.googleapis.com/",
"timeout": 120,
"model_config": {
"model": "gemini-2.0-flash-exp",
"temperature": 0.4,
},
"gm_resp_image_modal": False,
"gm_native_search": False,
"gm_native_coderunner": False,
@@ -939,13 +894,43 @@ CONFIG_METADATA_2 = {
"sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE",
"dangerous_content": "BLOCK_MEDIUM_AND_ABOVE",
},
"gm_thinking_config": {
"budget": 0,
},
"modalities": ["text", "image", "tool_use"],
"gm_thinking_config": {"budget": 0, "level": "HIGH"},
},
"Anthropic": {
"id": "anthropic",
"provider": "anthropic",
"type": "anthropic_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.anthropic.com/v1",
"timeout": 120,
},
"Moonshot": {
"id": "moonshot",
"provider": "moonshot",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"timeout": 120,
"api_base": "https://api.moonshot.cn/v1",
"custom_headers": {},
},
"xAI": {
"id": "xai",
"provider": "xai",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.x.ai/v1",
"timeout": 120,
"custom_headers": {},
"xai_native_search": False,
},
"DeepSeek": {
"id": "deepseek_default",
"id": "deepseek",
"provider": "deepseek",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
@@ -953,13 +938,75 @@ CONFIG_METADATA_2 = {
"key": [],
"api_base": "https://api.deepseek.com/v1",
"timeout": 120,
"model_config": {"model": "deepseek-chat", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "tool_use"],
},
"Zhipu": {
"id": "zhipu",
"provider": "zhipu",
"type": "zhipu_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"timeout": 120,
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
"custom_headers": {},
},
"Azure OpenAI": {
"id": "azure_openai",
"provider": "azure",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"api_version": "2024-05-01-preview",
"key": [],
"api_base": "",
"timeout": 120,
"custom_headers": {},
},
"Ollama": {
"id": "ollama",
"provider": "ollama",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": ["ollama"], # ollama 的 key 默认是 ollama
"api_base": "http://127.0.0.1:11434/v1",
"custom_headers": {},
},
"LM Studio": {
"id": "lm_studio",
"provider": "lm_studio",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": ["lmstudio"],
"api_base": "http://127.0.0.1:1234/v1",
"custom_headers": {},
},
"ModelStack": {
"id": "modelstack",
"provider": "modelstack",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://modelstack.app/v1",
"timeout": 120,
"custom_headers": {},
},
"Gemini_OpenAI_API": {
"id": "google_gemini_openai",
"provider": "google",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
"timeout": 120,
"custom_headers": {},
},
"Groq": {
"id": "groq_default",
"id": "groq",
"provider": "groq",
"type": "groq_chat_completion",
"provider_type": "chat_completion",
@@ -967,13 +1014,7 @@ CONFIG_METADATA_2 = {
"key": [],
"api_base": "https://api.groq.com/openai/v1",
"timeout": 120,
"model_config": {
"model": "openai/gpt-oss-20b",
"temperature": 0.4,
},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "tool_use"],
},
"302.AI": {
"id": "302ai",
@@ -984,12 +1025,9 @@ CONFIG_METADATA_2 = {
"key": [],
"api_base": "https://api.302.ai/v1",
"timeout": 120,
"model_config": {"model": "gpt-4.1-mini", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"硅基流动": {
"SiliconFlow": {
"id": "siliconflow",
"provider": "siliconflow",
"type": "openai_chat_completion",
@@ -998,15 +1036,9 @@ CONFIG_METADATA_2 = {
"key": [],
"timeout": 120,
"api_base": "https://api.siliconflow.cn/v1",
"model_config": {
"model": "deepseek-ai/DeepSeek-V3",
"temperature": 0.4,
},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"PPIO派欧云": {
"PPIO": {
"id": "ppio",
"provider": "ppio",
"type": "openai_chat_completion",
@@ -1015,14 +1047,9 @@ CONFIG_METADATA_2 = {
"key": [],
"api_base": "https://api.ppinfra.com/v3/openai",
"timeout": 120,
"model_config": {
"model": "deepseek/deepseek-r1",
"temperature": 0.4,
},
"custom_headers": {},
"custom_extra_body": {},
},
"小马算力": {
"TokenPony": {
"id": "tokenpony",
"provider": "tokenpony",
"type": "openai_chat_completion",
@@ -1031,14 +1058,9 @@ CONFIG_METADATA_2 = {
"key": [],
"api_base": "https://api.tokenpony.cn/v1",
"timeout": 120,
"model_config": {
"model": "kimi-k2-instruct-0905",
"temperature": 0.7,
},
"custom_headers": {},
"custom_extra_body": {},
},
"优云智算": {
"Compshare": {
"id": "compshare",
"provider": "compshare",
"type": "openai_chat_completion",
@@ -1047,42 +1069,18 @@ CONFIG_METADATA_2 = {
"key": [],
"api_base": "https://api.modelverse.cn/v1",
"timeout": 120,
"model_config": {
"model": "moonshotai/Kimi-K2-Instruct",
},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"Kimi": {
"id": "moonshot",
"provider": "moonshot",
"ModelScope": {
"id": "modelscope",
"provider": "modelscope",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"timeout": 120,
"api_base": "https://api.moonshot.cn/v1",
"model_config": {"model": "moonshot-v1-8k", "temperature": 0.4},
"api_base": "https://api-inference.modelscope.cn/v1",
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"智谱 AI": {
"id": "zhipu_default",
"provider": "zhipu",
"type": "zhipu_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"timeout": 120,
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
"model_config": {
"model": "glm-4-flash",
},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"Dify": {
"id": "dify_app_default",
@@ -1097,7 +1095,6 @@ CONFIG_METADATA_2 = {
"dify_query_input_key": "astrbot_text_query",
"variables": {},
"timeout": 60,
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!",
},
"Coze": {
"id": "coze",
@@ -1128,20 +1125,6 @@ CONFIG_METADATA_2 = {
"variables": {},
"timeout": 60,
},
"ModelScope": {
"id": "modelscope",
"provider": "modelscope",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"timeout": 120,
"api_base": "https://api-inference.modelscope.cn/v1",
"model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"FastGPT": {
"id": "fastgpt",
"provider": "fastgpt",
@@ -1165,7 +1148,6 @@ CONFIG_METADATA_2 = {
"model": "whisper-1",
},
"Whisper(Local)": {
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cudaCPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
"provider": "openai",
"type": "openai_whisper_selfhost",
"provider_type": "speech_to_text",
@@ -1174,7 +1156,6 @@ CONFIG_METADATA_2 = {
"model": "tiny",
},
"SenseVoice(Local)": {
"hint": "启用前请 pip 安装 funasr、funasr_onnx、torchaudio、torch、modelscope、jieba 库(默认使用CPU,大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
"type": "sensevoice_stt_selfhost",
"provider": "sensevoice",
"provider_type": "speech_to_text",
@@ -1196,7 +1177,6 @@ CONFIG_METADATA_2 = {
"timeout": "20",
},
"Edge TTS": {
"hint": "提示:使用这个服务前需要安装有 ffmpeg,并且可以直接在终端调用 ffmpeg 指令。",
"id": "edge_tts",
"provider": "microsoft",
"type": "edge_tts",
@@ -1412,6 +1392,10 @@ CONFIG_METADATA_2 = {
},
},
"items": {
"provider_source_id": {
"invisible": True,
"type": "string",
},
"xai_native_search": {
"description": "启用原生搜索功能",
"type": "bool",
@@ -1782,13 +1766,24 @@ CONFIG_METADATA_2 = {
},
},
"gm_thinking_config": {
"description": "Gemini思考设置",
"description": "Thinking Config",
"type": "object",
"items": {
"budget": {
"description": "思考预算",
"description": "Thinking Budget",
"type": "int",
"hint": "模型应该生成的思考Token的数量,设为0关闭思考。除gemini-2.5-flash外的模型会静默忽略此参数。",
"hint": "Guides the model on the specific number of thinking tokens to use for reasoning. See: https://ai.google.dev/gemini-api/docs/thinking#set-budget",
},
"level": {
"description": "Thinking Level",
"type": "string",
"hint": "Recommended for Gemini 3 models and onwards, lets you control reasoning behavior.See: https://ai.google.dev/gemini-api/docs/thinking#thinking-levels",
"options": [
"MINIMAL",
"LOW",
"MEDIUM",
"HIGH",
],
},
},
},
@@ -1969,7 +1964,6 @@ CONFIG_METADATA_2 = {
"id": {
"description": "ID",
"type": "string",
"hint": "模型提供商名字。",
},
"type": {
"description": "模型提供商种类",
@@ -1989,29 +1983,15 @@ CONFIG_METADATA_2 = {
"description": "API Key",
"type": "list",
"items": {"type": "string"},
"hint": "提供商 API Key。",
},
"api_base": {
"description": "API Base URL",
"type": "string",
"hint": "API Base URL 请在模型提供商处获得。如出现 404 报错,尝试在地址末尾加上 /v1",
},
"model_config": {
"description": "模型配置",
"type": "object",
"items": {
"model": {
"description": "模型名称",
"type": "string",
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
},
"max_tokens": {
"description": "模型最大输出长度(tokens",
"type": "int",
},
"temperature": {"description": "温度", "type": "float"},
"top_p": {"description": "Top P值", "type": "float"},
},
"model": {
"description": "模型 ID",
"type": "string",
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
},
"dify_api_key": {
"description": "API Key",
@@ -2173,6 +2153,9 @@ CONFIG_METADATA_2 = {
"use_file_service": {
"type": "bool",
},
"trigger_probability": {
"type": "float",
},
},
},
"provider_ltm_settings": {
@@ -2383,6 +2366,14 @@ CONFIG_METADATA_3 = {
"provider_tts_settings.enable": True,
},
},
"provider_tts_settings.trigger_probability": {
"description": "TTS 触发概率",
"type": "float",
"slider": {"min": 0, "max": 1, "step": 0.05},
"condition": {
"provider_tts_settings.enable": True,
},
},
"provider_settings.image_caption_prompt": {
"description": "图片转述提示词",
"type": "text",
@@ -2661,6 +2652,11 @@ CONFIG_METADATA_3 = {
"description": "只 @ 机器人是否触发等待",
"type": "bool",
},
"disable_builtin_commands": {
"description": "禁用自带指令",
"type": "bool",
"hint": "禁用所有 AstrBot 的自带指令,如 help, provider, model 等。",
},
},
},
"whitelist": {
@@ -2875,9 +2871,26 @@ CONFIG_METADATA_3 = {
"description": "分段回复字数阈值",
"type": "int",
},
"platform_settings.segmented_reply.split_mode": {
"description": "分段模式",
"type": "string",
"options": ["regex", "words"],
"labels": ["正则表达式", "分段词列表"],
},
"platform_settings.segmented_reply.regex": {
"description": "分段正则表达式",
"type": "string",
"condition": {
"platform_settings.segmented_reply.split_mode": "regex",
},
},
"platform_settings.segmented_reply.split_words": {
"description": "分段词列表",
"type": "list",
"hint": "检测到列表中的任意词时进行分段,如:。、?、!等",
"condition": {
"platform_settings.segmented_reply.split_mode": "words",
},
},
"platform_settings.segmented_reply.content_cleanup_rule": {
"description": "内容过滤正则表达式",
@@ -2928,6 +2941,7 @@ CONFIG_METADATA_3 = {
"description": "回复概率",
"type": "float",
"hint": "0.0-1.0 之间的数值",
"slider": {"min": 0, "max": 1, "step": 0.05},
"condition": {
"provider_ltm_settings.active_reply.enable": True,
},
+1
View File
@@ -79,6 +79,7 @@ class ConfigMetadataI18n:
"_special",
"invisible",
"options",
"slider",
]:
if attr in field_data:
field_result[attr] = field_data[attr]
+4 -1
View File
@@ -33,6 +33,7 @@ from astrbot.core.star.context import Context
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
from astrbot.core.umop_config_router import UmopConfigRouter
from astrbot.core.updator import AstrBotUpdator
from astrbot.core.utils.llm_metadata import update_llm_metadata
from astrbot.core.utils.migra_helper import migra
from . import astrbot_config, html_renderer
@@ -185,6 +186,8 @@ class AstrBotCoreLifecycle:
# 初始化关闭控制面板的事件
self.dashboard_shutdown_event = asyncio.Event()
asyncio.create_task(update_llm_metadata())
def _load(self) -> None:
"""加载事件总线和任务并初始化."""
# 创建一个异步任务来执行事件总线的 dispatch() 方法
@@ -197,7 +200,7 @@ class AstrBotCoreLifecycle:
# 把插件中注册的所有协程函数注册到事件总线中并执行
extra_tasks = []
for task in self.star_context._register_tasks:
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
extra_tasks.append(asyncio.create_task(task, name=task.__name__)) # type: ignore
tasks_ = [event_bus_task, *extra_tasks]
for task in tasks_:
+74 -3
View File
@@ -5,11 +5,12 @@ from contextlib import asynccontextmanager
from dataclasses import dataclass
from deprecated import deprecated
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from astrbot.core.db.po import (
Attachment,
CommandConfig,
CommandConflict,
ConversationV2,
Persona,
PlatformMessageHistory,
@@ -32,7 +33,7 @@ class BaseDatabase(abc.ABC):
echo=False,
future=True,
)
self.AsyncSessionLocal = sessionmaker(
self.AsyncSessionLocal = async_sessionmaker(
self.engine,
class_=AsyncSession,
expire_on_commit=False,
@@ -315,6 +316,76 @@ class BaseDatabase(abc.ABC):
"""Clear all preferences for a specific scope ID."""
...
@abc.abstractmethod
async def get_command_configs(self) -> list[CommandConfig]:
"""Get all stored command configurations."""
...
@abc.abstractmethod
async def get_command_config(self, handler_full_name: str) -> CommandConfig | None:
"""Fetch a single command configuration by handler."""
...
@abc.abstractmethod
async def upsert_command_config(
self,
handler_full_name: str,
plugin_name: str,
module_path: str,
original_command: str,
*,
resolved_command: str | None = None,
enabled: bool | None = None,
keep_original_alias: bool | None = None,
conflict_key: str | None = None,
resolution_strategy: str | None = None,
note: str | None = None,
extra_data: dict | None = None,
auto_managed: bool | None = None,
) -> CommandConfig:
"""Create or update a command configuration."""
...
@abc.abstractmethod
async def delete_command_config(self, handler_full_name: str) -> None:
"""Delete a single command configuration."""
...
@abc.abstractmethod
async def delete_command_configs(self, handler_full_names: list[str]) -> None:
"""Bulk delete command configurations."""
...
@abc.abstractmethod
async def list_command_conflicts(
self,
status: str | None = None,
) -> list[CommandConflict]:
"""List recorded command conflict entries."""
...
@abc.abstractmethod
async def upsert_command_conflict(
self,
conflict_key: str,
handler_full_name: str,
plugin_name: str,
*,
status: str | None = None,
resolution: str | None = None,
resolved_command: str | None = None,
note: str | None = None,
extra_data: dict | None = None,
auto_generated: bool | None = None,
) -> CommandConflict:
"""Create or update a conflict record."""
...
@abc.abstractmethod
async def delete_command_conflicts(self, ids: list[int]) -> None:
"""Delete conflict records."""
...
# @abc.abstractmethod
# async def insert_llm_message(
# self,
@@ -70,6 +70,7 @@ async def migration_conversation_table(
logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
)
continue
if ":" not in conv.user_id:
continue
session = MessageSesion.from_str(session_str=conv.user_id)
@@ -207,6 +208,7 @@ async def migration_webchat_data(
logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
)
continue
if ":" in conv.user_id:
continue
platform_id = "webchat"
+6 -4
View File
@@ -127,7 +127,7 @@ class SQLiteDatabase:
conn.text_factory = str
return conn
def _exec_sql(self, sql: str, params: tuple = None):
def _exec_sql(self, sql: str, params: tuple | None = None):
conn = self.conn
try:
c = self.conn.cursor()
@@ -224,9 +224,11 @@ class SQLiteDatabase:
c.close()
return Stats(platform, [], [])
return Stats(platform)
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation:
def get_conversation_by_user_id(
self, user_id: str, cid: str
) -> Conversation | None:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
@@ -258,7 +260,7 @@ class SQLiteDatabase:
(user_id, cid, history, updated_at, created_at),
)
def get_conversations(self, user_id: str) -> tuple:
def get_conversations(self, user_id: str) -> list[Conversation]:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
+75 -15
View File
@@ -12,7 +12,7 @@ class PlatformStat(SQLModel, table=True):
Note: In astrbot v4, we moved `platform` table to here.
"""
__tablename__ = "platform_stats" # type: ignore
__tablename__: str = "platform_stats"
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
timestamp: datetime = Field(nullable=False)
@@ -31,9 +31,10 @@ class PlatformStat(SQLModel, table=True):
class ConversationV2(SQLModel, table=True):
__tablename__ = "conversations" # type: ignore
__tablename__: str = "conversations"
inner_conversation_id: int = Field(
inner_conversation_id: int | None = Field(
default=None,
primary_key=True,
sa_column_kwargs={"autoincrement": True},
)
@@ -68,7 +69,7 @@ class Persona(SQLModel, table=True):
It can be used to customize the behavior of LLMs.
"""
__tablename__ = "personas" # type: ignore
__tablename__: str = "personas"
id: int | None = Field(
primary_key=True,
@@ -98,7 +99,7 @@ class Persona(SQLModel, table=True):
class Preference(SQLModel, table=True):
"""This class represents preferences for bots."""
__tablename__ = "preferences" # type: ignore
__tablename__: str = "preferences"
id: int | None = Field(
default=None,
@@ -134,7 +135,7 @@ class PlatformMessageHistory(SQLModel, table=True):
or platform-specific messages.
"""
__tablename__ = "platform_message_history" # type: ignore
__tablename__: str = "platform_message_history"
id: int | None = Field(
primary_key=True,
@@ -162,7 +163,7 @@ class PlatformSession(SQLModel, table=True):
Each session can have multiple conversations (对话) associated with it.
"""
__tablename__ = "platform_sessions" # type: ignore
__tablename__: str = "platform_sessions"
inner_id: int | None = Field(
primary_key=True,
@@ -203,7 +204,7 @@ class Attachment(SQLModel, table=True):
Attachments can be images, files, or other media types.
"""
__tablename__ = "attachments" # type: ignore
__tablename__: str = "attachments"
inner_attachment_id: int | None = Field(
primary_key=True,
@@ -233,6 +234,65 @@ class Attachment(SQLModel, table=True):
)
class CommandConfig(SQLModel, table=True):
"""Per-command configuration overrides for dashboard management."""
__tablename__ = "command_configs" # type: ignore
handler_full_name: str = Field(
primary_key=True,
max_length=512,
)
plugin_name: str = Field(nullable=False, max_length=255)
module_path: str = Field(nullable=False, max_length=255)
original_command: str = Field(nullable=False, max_length=255)
resolved_command: str | None = Field(default=None, max_length=255)
enabled: bool = Field(default=True, nullable=False)
keep_original_alias: bool = Field(default=False, nullable=False)
conflict_key: str | None = Field(default=None, max_length=255)
resolution_strategy: str | None = Field(default=None, max_length=64)
note: str | None = Field(default=None, sa_type=Text)
extra_data: dict | None = Field(default=None, sa_type=JSON)
auto_managed: bool = Field(default=False, nullable=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
class CommandConflict(SQLModel, table=True):
"""Conflict tracking for duplicated command names."""
__tablename__ = "command_conflicts" # type: ignore
id: int | None = Field(
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
)
conflict_key: str = Field(nullable=False, max_length=255)
handler_full_name: str = Field(nullable=False, max_length=512)
plugin_name: str = Field(nullable=False, max_length=255)
status: str = Field(default="pending", max_length=32)
resolution: str | None = Field(default=None, max_length=64)
resolved_command: str | None = Field(default=None, max_length=255)
note: str | None = Field(default=None, sa_type=Text)
extra_data: dict | None = Field(default=None, sa_type=JSON)
auto_generated: bool = Field(default=False, nullable=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
__table_args__ = (
UniqueConstraint(
"conflict_key",
"handler_full_name",
name="uix_conflict_handler",
),
)
@dataclass
class Conversation:
"""LLM 对话类
@@ -261,17 +321,17 @@ class Personality(TypedDict):
在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。
"""
prompt: str = ""
name: str = ""
begin_dialogs: list[str] = []
mood_imitation_dialogs: list[str] = []
prompt: str
name: str
begin_dialogs: list[str]
mood_imitation_dialogs: list[str]
"""情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。"""
tools: list[str] | None = None
tools: list[str] | None
"""工具列表。None 表示使用所有工具,空列表表示不使用任何工具"""
# cache
_begin_dialogs_processed: list[dict] = []
_mood_imitation_dialogs_processed: str = ""
_begin_dialogs_processed: list[dict]
_mood_imitation_dialogs_processed: str
# ====
+244 -3
View File
@@ -1,14 +1,18 @@
import asyncio
import threading
import typing as T
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta, timezone
from sqlalchemy import CursorResult
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col, delete, desc, func, or_, select, text, update
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import (
Attachment,
CommandConfig,
CommandConflict,
ConversationV2,
Persona,
PlatformMessageHistory,
@@ -25,6 +29,7 @@ from astrbot.core.db.po import (
)
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
TxResult = T.TypeVar("TxResult")
class SQLiteDatabase(BaseDatabase):
@@ -489,7 +494,7 @@ class SQLiteDatabase(BaseDatabase):
async with self.get_db() as session:
session: AsyncSession
query = select(Attachment).where(
Attachment.attachment_id.in_(attachment_ids)
col(Attachment.attachment_id).in_(attachment_ids)
)
result = await session.execute(query)
return list(result.scalars().all())
@@ -505,7 +510,7 @@ class SQLiteDatabase(BaseDatabase):
query = delete(Attachment).where(
col(Attachment.attachment_id) == attachment_id
)
result = await session.execute(query)
result = T.cast(CursorResult, await session.execute(query))
return result.rowcount > 0
async def delete_attachments(self, attachment_ids: list[str]) -> int:
@@ -521,7 +526,7 @@ class SQLiteDatabase(BaseDatabase):
query = delete(Attachment).where(
col(Attachment.attachment_id).in_(attachment_ids)
)
result = await session.execute(query)
result = T.cast(CursorResult, await session.execute(query))
return result.rowcount
async def insert_persona(
@@ -669,6 +674,242 @@ class SQLiteDatabase(BaseDatabase):
)
await session.commit()
# ====
# Command Configuration & Conflict Tracking
# ====
async def _run_in_tx(
self,
fn: Callable[[AsyncSession], Awaitable[TxResult]],
) -> TxResult:
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
return await fn(session)
@staticmethod
def _apply_updates(model, **updates) -> None:
for field, value in updates.items():
if value is not None:
setattr(model, field, value)
@staticmethod
def _new_command_config(
handler_full_name: str,
plugin_name: str,
module_path: str,
original_command: str,
*,
resolved_command: str | None = None,
enabled: bool | None = None,
keep_original_alias: bool | None = None,
conflict_key: str | None = None,
resolution_strategy: str | None = None,
note: str | None = None,
extra_data: dict | None = None,
auto_managed: bool | None = None,
) -> CommandConfig:
return CommandConfig(
handler_full_name=handler_full_name,
plugin_name=plugin_name,
module_path=module_path,
original_command=original_command,
resolved_command=resolved_command,
enabled=True if enabled is None else enabled,
keep_original_alias=False
if keep_original_alias is None
else keep_original_alias,
conflict_key=conflict_key or original_command,
resolution_strategy=resolution_strategy,
note=note,
extra_data=extra_data,
auto_managed=bool(auto_managed),
)
@staticmethod
def _new_command_conflict(
conflict_key: str,
handler_full_name: str,
plugin_name: str,
*,
status: str | None = None,
resolution: str | None = None,
resolved_command: str | None = None,
note: str | None = None,
extra_data: dict | None = None,
auto_generated: bool | None = None,
) -> CommandConflict:
return CommandConflict(
conflict_key=conflict_key,
handler_full_name=handler_full_name,
plugin_name=plugin_name,
status=status or "pending",
resolution=resolution,
resolved_command=resolved_command,
note=note,
extra_data=extra_data,
auto_generated=bool(auto_generated),
)
async def get_command_configs(self) -> list[CommandConfig]:
async with self.get_db() as session:
session: AsyncSession
result = await session.execute(select(CommandConfig))
return list(result.scalars().all())
async def get_command_config(
self,
handler_full_name: str,
) -> CommandConfig | None:
async with self.get_db() as session:
session: AsyncSession
return await session.get(CommandConfig, handler_full_name)
async def upsert_command_config(
self,
handler_full_name: str,
plugin_name: str,
module_path: str,
original_command: str,
*,
resolved_command: str | None = None,
enabled: bool | None = None,
keep_original_alias: bool | None = None,
conflict_key: str | None = None,
resolution_strategy: str | None = None,
note: str | None = None,
extra_data: dict | None = None,
auto_managed: bool | None = None,
) -> CommandConfig:
async def _op(session: AsyncSession) -> CommandConfig:
config = await session.get(CommandConfig, handler_full_name)
if not config:
config = self._new_command_config(
handler_full_name,
plugin_name,
module_path,
original_command,
resolved_command=resolved_command,
enabled=enabled,
keep_original_alias=keep_original_alias,
conflict_key=conflict_key,
resolution_strategy=resolution_strategy,
note=note,
extra_data=extra_data,
auto_managed=auto_managed,
)
session.add(config)
else:
self._apply_updates(
config,
plugin_name=plugin_name,
module_path=module_path,
original_command=original_command,
resolved_command=resolved_command,
enabled=enabled,
keep_original_alias=keep_original_alias,
conflict_key=conflict_key,
resolution_strategy=resolution_strategy,
note=note,
extra_data=extra_data,
auto_managed=auto_managed,
)
await session.flush()
await session.refresh(config)
return config
return await self._run_in_tx(_op)
async def delete_command_config(self, handler_full_name: str) -> None:
await self.delete_command_configs([handler_full_name])
async def delete_command_configs(self, handler_full_names: list[str]) -> None:
if not handler_full_names:
return
async def _op(session: AsyncSession) -> None:
await session.execute(
delete(CommandConfig).where(
col(CommandConfig.handler_full_name).in_(handler_full_names),
),
)
await self._run_in_tx(_op)
async def list_command_conflicts(
self,
status: str | None = None,
) -> list[CommandConflict]:
async with self.get_db() as session:
session: AsyncSession
query = select(CommandConflict)
if status:
query = query.where(CommandConflict.status == status)
result = await session.execute(query)
return list(result.scalars().all())
async def upsert_command_conflict(
self,
conflict_key: str,
handler_full_name: str,
plugin_name: str,
*,
status: str | None = None,
resolution: str | None = None,
resolved_command: str | None = None,
note: str | None = None,
extra_data: dict | None = None,
auto_generated: bool | None = None,
) -> CommandConflict:
async def _op(session: AsyncSession) -> CommandConflict:
result = await session.execute(
select(CommandConflict).where(
CommandConflict.conflict_key == conflict_key,
CommandConflict.handler_full_name == handler_full_name,
),
)
record = result.scalar_one_or_none()
if not record:
record = self._new_command_conflict(
conflict_key,
handler_full_name,
plugin_name,
status=status,
resolution=resolution,
resolved_command=resolved_command,
note=note,
extra_data=extra_data,
auto_generated=auto_generated,
)
session.add(record)
else:
self._apply_updates(
record,
plugin_name=plugin_name,
status=status,
resolution=resolution,
resolved_command=resolved_command,
note=note,
extra_data=extra_data,
auto_generated=auto_generated,
)
await session.flush()
await session.refresh(record)
return record
return await self._run_in_tx(_op)
async def delete_command_conflicts(self, ids: list[int]) -> None:
if not ids:
return
async def _op(session: AsyncSession) -> None:
await session.execute(
delete(CommandConflict).where(col(CommandConflict.id).in_(ids)),
)
await self._run_in_tx(_op)
# ====
# Deprecated Methods
# ====
@@ -90,4 +90,6 @@ class EmbeddingStorage:
path (str): 保存索引的路径
"""
if self.index is None:
return
faiss.write_index(self.index, self.path)
+6 -1
View File
@@ -27,7 +27,7 @@ class EventBus:
self,
event_queue: Queue,
pipeline_scheduler_mapping: dict[str, PipelineScheduler],
astrbot_config_mgr: AstrBotConfigManager = None,
astrbot_config_mgr: AstrBotConfigManager,
):
self.event_queue = event_queue # 事件队列
# abconf uuid -> scheduler
@@ -40,6 +40,11 @@ class EventBus:
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
self._print_event(event, conf_info["name"])
scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"])
if not scheduler:
logger.error(
f"PipelineScheduler not found for id: {conf_info['id']}, event ignored."
)
continue
asyncio.create_task(scheduler.execute(event))
def _print_event(self, event: AstrMessageEvent, conf_name: str):
@@ -166,7 +166,11 @@ class RetrievalManager:
# 5. Rerank
first_rerank = None
for kb_id in kb_ids:
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"]
vec_db = kb_options[kb_id]["vec_db"]
if not isinstance(vec_db, FaissVecDB):
logger.warning(f"vec_db for kb_id {kb_id} is not FaissVecDB")
continue
rerank_pi = kb_options[kb_id]["rerank_provider_id"]
if (
vec_db
+2 -1
View File
@@ -24,6 +24,7 @@ import asyncio
import logging
import os
import sys
import time
from asyncio import Queue
from collections import deque
@@ -148,7 +149,7 @@ class LogQueueHandler(logging.Handler):
self.log_broker.publish(
{
"level": record.levelname,
"time": record.asctime,
"time": time.time(),
"data": log_entry,
},
)
+13 -8
View File
@@ -66,6 +66,9 @@ class ComponentType(str, Enum):
class BaseMessageComponent(BaseModel):
type: ComponentType
def __init__(self, **kwargs):
super().__init__(**kwargs)
def toDict(self):
data = {}
for k, v in self.__dict__.items():
@@ -551,7 +554,7 @@ class Node(BaseMessageComponent):
id: int | None = 0 # 忽略
name: str | None = "" # qq昵称
uin: str | None = "0" # qq号
content: list[BaseMessageComponent] | None = []
content: list[BaseMessageComponent] = []
seq: str | list | None = "" # 忽略
time: int | None = 0 # 忽略
@@ -615,7 +618,7 @@ class Nodes(BaseMessageComponent):
ret["messages"].append(d)
return ret
async def to_dict(self):
async def to_dict(self) -> dict:
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
ret = {"messages": []}
for node in self.nodes:
@@ -626,12 +629,11 @@ class Nodes(BaseMessageComponent):
class Json(BaseMessageComponent):
type = ComponentType.Json
data: str | dict
resid: int | None = 0
data: dict
def __init__(self, data, **_):
if isinstance(data, dict):
data = json.dumps(data)
def __init__(self, data: str | dict, **_):
if isinstance(data, str):
data = json.loads(data)
super().__init__(data=data, **_)
@@ -714,12 +716,15 @@ class File(BaseMessageComponent):
if self.url:
await self._download_file()
return os.path.abspath(self.file_)
if self.file_:
return os.path.abspath(self.file_)
return ""
async def _download_file(self):
"""下载文件"""
if not self.url:
raise ValueError("Download failed: No URL provided in File component.")
download_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(download_dir, exist_ok=True)
if self.name:
+2 -2
View File
@@ -98,8 +98,8 @@ class PersonaManager:
self,
persona_id: str,
system_prompt: str,
begin_dialogs: list[str] = None,
tools: list[str] = None,
begin_dialogs: list[str] | None = None,
tools: list[str] | None = None,
) -> Persona:
"""创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
if await self.db.get_persona_by_id(persona_id):
@@ -24,7 +24,7 @@ class ContentSafetyCheckStage(Stage):
self,
event: AstrMessageEvent,
check_text: str | None = None,
) -> None | AsyncGenerator[None, None]:
) -> AsyncGenerator[None, None]:
"""检查内容安全"""
text = check_text if check_text else event.get_message_str()
ok, info = self.strategy_selector.check(text)
+2 -1
View File
@@ -11,7 +11,7 @@ from astrbot.core.star.star_handler import EventType, star_handlers_registry
async def call_handler(
event: AstrMessageEvent,
handler: T.Callable[..., T.Awaitable[T.Any]],
handler: T.Callable[..., T.Awaitable[T.Any] | T.AsyncGenerator[T.Any, None]],
*args,
**kwargs,
) -> T.AsyncGenerator[T.Any, None]:
@@ -91,6 +91,7 @@ async def call_event_hook(
)
for handler in handlers:
try:
assert inspect.iscoroutinefunction(handler.handler)
logger.debug(
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
)
@@ -321,7 +321,12 @@ class InternalAgentSubStage(Stage):
elif isinstance(req.tool_calls_result, list):
for tcr in req.tool_calls_result:
messages.extend(tcr.to_openai_messages())
messages.append({"role": "assistant", "content": llm_response.completion_text})
messages.append(
{
"role": "assistant",
"content": llm_response.completion_text or "*No response*",
}
)
messages = list(filter(lambda item: "_no_save" not in item, messages))
await self.conv_manager.update_conversation(
event.unified_msg_origin,
@@ -16,7 +16,6 @@ from ..stage import Stage
class StarRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.curr_provider = ctx.plugin_manager.context.get_using_provider()
self.prompt_prefix = ctx.astrbot_config["provider_settings"]["prompt_prefix"]
self.identifier = ctx.astrbot_config["provider_settings"]["identifier"]
self.ctx = ctx
@@ -24,7 +23,7 @@ class StarRequestSubStage(Stage):
async def process(
self,
event: AstrMessageEvent,
) -> AsyncGenerator[None, None]:
) -> AsyncGenerator[Any, None]:
activated_handlers: list[StarHandlerMetadata] = event.get_extra(
"activated_handlers",
)
+1 -1
View File
@@ -60,7 +60,7 @@ class ProcessStage(Stage):
):
# 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀
if (
event.get_result() and not event.get_result().is_stopped()
event.get_result() and not event.is_stopped()
) or not event.get_result():
async for _ in self.agent_sub_stage.process(event):
yield
+8 -2
View File
@@ -117,7 +117,9 @@ class RespondStage(Stage):
if not self.enable_seg:
return False
if self.only_llm_result and not event.get_result().is_llm_result():
if (result := event.get_result()) is None:
return False
if self.only_llm_result and not result.is_llm_result():
return False
if event.get_platform_name() in [
@@ -156,7 +158,11 @@ class RespondStage(Stage):
result = event.get_result()
if result is None:
return
if event.get_extra("_streaming_finished", False):
# prevent some plugin make result content type to LLM_RESULT after streaming finished, lead to send again
return
if result.result_content_type == ResultContentType.STREAMING_FINISH:
event.set_extra("_streaming_finished", True)
return
logger.info(
@@ -185,7 +191,7 @@ class RespondStage(Stage):
if isinstance(component, Comp.File) and component.file:
# 支持 File 消息段的路径映射。
component.file = path_Mapping(mappings, component.file)
event.get_result().chain[idx] = component
result.chain[idx] = component
# 检查消息链是否为空
try:
+89 -22
View File
@@ -1,3 +1,4 @@
import random
import re
import time
import traceback
@@ -6,6 +7,7 @@ from collections.abc import AsyncGenerator
from astrbot.core import file_token_service, html_renderer, logger
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
from astrbot.core.message.message_event_result import ResultContentType
from astrbot.core.pipeline.content_safety_check.stage import ContentSafetyCheckStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.message_type import MessageType
from astrbot.core.star.session_llm_manager import SessionServiceManager
@@ -41,6 +43,18 @@ class ResultDecorateStage(Stage):
"forward_threshold"
]
trigger_probability = ctx.astrbot_config["provider_tts_settings"].get(
"trigger_probability",
1,
)
try:
self.tts_trigger_probability = max(
0.0,
min(float(trigger_probability), 1.0),
)
except (TypeError, ValueError):
self.tts_trigger_probability = 1.0
# 分段回复
self.words_count_threshold = int(
ctx.astrbot_config["platform_settings"]["segmented_reply"][
@@ -53,7 +67,22 @@ class ResultDecorateStage(Stage):
self.only_llm_result = ctx.astrbot_config["platform_settings"][
"segmented_reply"
]["only_llm_result"]
self.split_mode = ctx.astrbot_config["platform_settings"][
"segmented_reply"
].get("split_mode", "regex")
self.regex = ctx.astrbot_config["platform_settings"]["segmented_reply"]["regex"]
self.split_words = ctx.astrbot_config["platform_settings"][
"segmented_reply"
].get("split_words", ["", "", "", "~", ""])
if self.split_words:
escaped_words = sorted(
[re.escape(word) for word in self.split_words], key=len, reverse=True
)
self.split_words_pattern = re.compile(
f"(.*?({'|'.join(escaped_words)})|.+$)", re.DOTALL
)
else:
self.split_words_pattern = None
self.content_cleanup_rule = ctx.astrbot_config["platform_settings"][
"segmented_reply"
]["content_cleanup_rule"]
@@ -69,6 +98,28 @@ class ResultDecorateStage(Stage):
self.content_safe_check_stage = stage_cls()
await self.content_safe_check_stage.initialize(ctx)
def _split_text_by_words(self, text: str) -> list[str]:
"""使用分段词列表分段文本"""
if not self.split_words_pattern:
return [text]
segments = self.split_words_pattern.findall(text)
result = []
for seg in segments:
if isinstance(seg, tuple):
content = seg[0]
if not isinstance(content, str):
continue
for word in self.split_words:
if content.endswith(word):
content = content[: -len(word)]
break
if content.strip():
result.append(content)
elif seg and seg.strip():
result.append(seg)
return result if result else [text]
async def process(
self,
event: AstrMessageEvent,
@@ -93,11 +144,13 @@ class ResultDecorateStage(Stage):
for comp in result.chain:
if isinstance(comp, Plain):
text += comp.text
async for _ in self.content_safe_check_stage.process(
event,
check_text=text,
):
yield
if isinstance(self.content_safe_check_stage, ContentSafetyCheckStage):
async for _ in self.content_safe_check_stage.process(
event,
check_text=text,
):
yield
# 发送消息前事件钩子
handlers = star_handlers_registry.get_handlers_by_event_type(
@@ -114,7 +167,8 @@ class ResultDecorateStage(Stage):
"启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作",
)
await handler.handler(event)
if event.get_result() is None or not event.get_result().chain:
if (result := event.get_result()) is None or not result.chain:
logger.debug(
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。",
)
@@ -161,21 +215,27 @@ class ResultDecorateStage(Stage):
# 不分段回复
new_chain.append(comp)
continue
try:
split_response = re.findall(
self.regex,
comp.text,
re.DOTALL | re.MULTILINE,
)
except re.error:
logger.error(
f"分段回复正则表达式错误,使用默认分段方式: {traceback.format_exc()}",
)
split_response = re.findall(
r".*?[。?!~…]+|.+$",
comp.text,
re.DOTALL | re.MULTILINE,
)
# 根据 split_mode 选择分段方式
if self.split_mode == "words":
split_response = self._split_text_by_words(comp.text)
else: # regex 模式
try:
split_response = re.findall(
self.regex,
comp.text,
re.DOTALL | re.MULTILINE,
)
except re.error:
logger.error(
f"分段回复正则表达式错误,使用默认分段方式: {traceback.format_exc()}",
)
split_response = re.findall(
r".*?[。?!~…]+|.+$",
comp.text,
re.DOTALL | re.MULTILINE,
)
if not split_response:
new_chain.append(comp)
continue
@@ -199,7 +259,14 @@ class ResultDecorateStage(Stage):
and result.is_llm_result()
and SessionServiceManager.should_process_tts_request(event)
):
if not tts_provider:
should_tts = self.tts_trigger_probability >= 1.0 or (
self.tts_trigger_probability > 0.0
and random.random() <= self.tts_trigger_probability
)
if not should_tts:
logger.debug("跳过 TTS:触发概率未命中。")
elif not tts_provider:
logger.warning(
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。",
)
+5 -1
View File
@@ -2,6 +2,10 @@ from collections.abc import AsyncGenerator
from astrbot.core import logger
from astrbot.core.platform import AstrMessageEvent
from astrbot.core.platform.sources.webchat.webchat_event import WebChatMessageEvent
from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import (
WecomAIBotMessageEvent,
)
from . import STAGES_ORDER
from .context import PipelineContext
@@ -78,7 +82,7 @@ class PipelineScheduler:
await self._process_stages(event)
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
if event.get_platform_name() in ["webchat", "wecom_ai_bot"]:
if isinstance(event, (WebChatMessageEvent, WecomAIBotMessageEvent)):
await event.send(None)
logger.debug("pipeline 执行完毕。")
@@ -50,6 +50,9 @@ class WakingCheckStage(Stage):
"ignore_at_all",
False,
)
self.disable_builtin_commands = self.ctx.astrbot_config.get(
"disable_builtin_commands", False
)
async def process(
self,
@@ -131,6 +134,13 @@ class WakingCheckStage(Stage):
EventType.AdapterMessageEvent,
plugins_name=event.plugins_name,
):
if (
self.disable_builtin_commands
and handler.handler_module_path == "packages.builtin_commands.main"
):
logger.debug("skipping builtin command")
continue
# filter 需满足 AND 逻辑关系
passed = True
permission_not_pass = False
+5 -3
View File
@@ -153,7 +153,9 @@ class AstrMessageEvent(abc.ABC):
def get_sender_name(self) -> str:
"""获取消息发送者的名称。(可能会返回空字符串)"""
return self.message_obj.sender.nickname
if isinstance(self.message_obj.sender.nickname, str):
return self.message_obj.sender.nickname
return ""
def set_extra(self, key, value):
"""设置额外的信息。"""
@@ -270,7 +272,7 @@ class AstrMessageEvent(abc.ABC):
"""
self.call_llm = call_llm
def get_result(self) -> MessageEventResult:
def get_result(self) -> MessageEventResult | None:
"""获取消息事件的结果。"""
return self._result
@@ -320,7 +322,7 @@ class AstrMessageEvent(abc.ABC):
self,
prompt: str,
func_tool_manager=None,
session_id: str = None,
session_id: str = "",
image_urls: list[str] | None = None,
contexts: list | None = None,
system_prompt: str = "",
+2 -2
View File
@@ -54,7 +54,7 @@ class AstrBotMessage:
self_id: str # 机器人的识别id
session_id: str # 会话id。取决于 unique_session 的设置。
message_id: str # 消息id
group: Group # 群组
group: Group | None # 群组
sender: MessageMember # 发送者
message: list[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
message_str: str # 最直观的纯文本消息字符串
@@ -78,7 +78,7 @@ class AstrBotMessage:
return ""
@group_id.setter
def group_id(self, value: str):
def group_id(self, value: str | None):
"""设置 group_id"""
if value:
if self.group:
+4
View File
@@ -5,6 +5,7 @@ from asyncio import Queue
from astrbot.core import logger
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
from astrbot.core.utils.webhook_utils import ensure_platform_webhook_config
from .platform import Platform, PlatformStatus
from .register import platform_cls_map
@@ -18,6 +19,7 @@ class PlatformManager:
self._inst_map: dict[str, dict] = {}
self.astrbot_config = config
self.platforms_config = config["platform"]
self.settings = config["platform_settings"]
"""NOTE: 这里是 default 的配置文件,以保证最大的兼容性;
@@ -29,6 +31,8 @@ class PlatformManager:
"""初始化所有平台适配器"""
for platform in self.platforms_config:
try:
if ensure_platform_webhook_config(platform):
self.astrbot_config.save_config()
await self.load_platform(platform)
except Exception as e:
logger.error(f"初始化 {platform} 平台适配器失败: {e}")
+11 -3
View File
@@ -1,7 +1,7 @@
import abc
import uuid
from asyncio import Queue
from collections.abc import Awaitable
from collections.abc import Coroutine
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
@@ -80,6 +80,13 @@ class Platform(abc.ABC):
if self._status == PlatformStatus.ERROR:
self._status = PlatformStatus.RUNNING
def unified_webhook(self) -> bool:
"""是否正在使用统一 Webhook 模式"""
return bool(
self.config.get("unified_webhook_mode", False)
and self.config.get("webhook_uuid")
)
def get_stats(self) -> dict:
"""获取平台统计信息"""
meta = self.meta()
@@ -97,10 +104,11 @@ class Platform(abc.ABC):
}
if self.last_error
else None,
"unified_webhook": self.unified_webhook(),
}
@abc.abstractmethod
def run(self) -> Awaitable[Any]:
def run(self) -> Coroutine[Any, Any, None]:
"""得到一个平台的运行实例,需要返回一个协程对象。"""
raise NotImplementedError
@@ -116,7 +124,7 @@ class Platform(abc.ABC):
self,
session: MessageSesion,
message_chain: MessageChain,
):
) -> None:
"""通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。
异步方法。
+1 -1
View File
@@ -7,7 +7,7 @@ class PlatformMetadata:
"""平台的名称,即平台的类型,如 aiocqhttp, discord, slack"""
description: str
"""平台的描述"""
id: str | None = None
id: str
"""平台的唯一标识符,用于配置中识别特定平台"""
default_config_tmpl: dict | None = None
+1
View File
@@ -40,6 +40,7 @@ def register_platform_adapter(
pm = PlatformMetadata(
name=adapter_name,
description=desc,
id=adapter_name,
default_config_tmpl=default_config_tmpl,
adapter_display_name=adapter_display_name,
logo_path=logo_path,
@@ -70,16 +70,18 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
bot: CQHttp,
event: Event | None,
is_group: bool,
session_id: str,
session_id: str | None,
messages: list[dict],
):
# session_id 必须是纯数字字符串
session_id = int(session_id) if session_id.isdigit() else None
session_id_int = (
int(session_id) if session_id and session_id.isdigit() else None
)
if is_group and isinstance(session_id, int):
await bot.send_group_msg(group_id=session_id, message=messages)
elif not is_group and isinstance(session_id, int):
await bot.send_private_msg(user_id=session_id, message=messages)
if is_group and isinstance(session_id_int, int):
await bot.send_group_msg(group_id=session_id_int, message=messages)
elif not is_group and isinstance(session_id_int, int):
await bot.send_private_msg(user_id=session_id_int, message=messages)
elif isinstance(event, Event): # 最后兜底
await bot.send(event=event, message=messages)
else:
@@ -4,7 +4,7 @@ import logging
import time
import uuid
from collections.abc import Awaitable
from typing import Any
from typing import Any, cast
from aiocqhttp import CQHttp, Event
from aiocqhttp.exceptions import ActionFailed
@@ -48,7 +48,7 @@ class AiocqhttpAdapter(Platform):
self.metadata = PlatformMetadata(
name="aiocqhttp",
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
id=self.config.get("id"),
id=cast(str, self.config.get("id")),
support_streaming_message=False,
)
@@ -127,7 +127,9 @@ class AiocqhttpAdapter(Platform):
"""OneBot V11 请求类事件"""
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
abm.sender = MessageMember(
user_id=str(event.user_id), nickname=str(event.user_id)
)
abm.type = MessageType.OTHER_MESSAGE
if event.get("group_id"):
abm.type = MessageType.GROUP_MESSAGE
@@ -194,6 +196,7 @@ class AiocqhttpAdapter(Platform):
@param event: 事件对象
@param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。
"""
assert event.sender is not None
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.sender = MessageMember(
@@ -203,6 +206,7 @@ class AiocqhttpAdapter(Platform):
if event["message_type"] == "group":
abm.type = MessageType.GROUP_MESSAGE
abm.group_id = str(event.group_id)
abm.group = Group(str(event.group_id))
abm.group.group_name = event.get("group_name", "N/A")
elif event["message_type"] == "private":
abm.type = MessageType.FRIEND_MESSAGE
@@ -228,7 +232,7 @@ class AiocqhttpAdapter(Platform):
await self.bot.send(event, err)
except BaseException as e:
logger.error(f"回复消息失败: {e}")
return None
raise ValueError(err)
# 按消息段类型类型适配
for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]):
@@ -381,10 +385,25 @@ class AiocqhttpAdapter(Platform):
logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。")
message_str += "".join(at_parts)
elif t == "markdown":
text = m["data"].get("markdown") or m["data"].get("content", "")
abm.message.append(Plain(text=text))
message_str += text
else:
for m in m_group:
a = ComponentTypes[t](**m["data"])
abm.message.append(a)
try:
if t not in ComponentTypes:
logger.warning(
f"不支持的消息段类型,已忽略: {t}, data={m['data']}"
)
continue
a = ComponentTypes[t](**m["data"])
abm.message.append(a)
except Exception as e:
logger.exception(
f"消息段解析失败: type={t}, data={m['data']}. {e}"
)
continue
abm.timestamp = int(time.time())
abm.message_str = message_str
@@ -417,7 +436,7 @@ class AiocqhttpAdapter(Platform):
async def shutdown_trigger_placeholder(self):
await self.shutdown_event.wait()
logger.info("aiocqhttp 适配器已被优雅地关闭")
logger.info("aiocqhttp 适配器已被关闭")
def meta(self) -> PlatformMetadata:
return self.metadata
@@ -2,6 +2,7 @@ import asyncio
import os
import threading
import uuid
from typing import cast
import aiohttp
import dingtalk_stream
@@ -54,12 +55,14 @@ class DingtalkPlatformAdapter(Platform):
self.client_id = platform_config["client_id"]
self.client_secret = platform_config["client_secret"]
outer_self = self
class AstrCallbackClient(dingtalk_stream.ChatbotHandler):
async def process(self_, message: dingtalk_stream.CallbackMessage):
async def process(self, message: dingtalk_stream.CallbackMessage):
logger.debug(f"dingtalk: {message.data}")
im = dingtalk_stream.ChatbotMessage.from_dict(message.data)
abm = await self.convert_msg(im)
await self.handle_msg(abm)
abm = await outer_self.convert_msg(im)
await outer_self.handle_msg(abm)
return AckMessage.STATUS_OK, "OK"
@@ -73,6 +76,7 @@ class DingtalkPlatformAdapter(Platform):
self.client,
)
self.client_ = client # 用于 websockets 的 client
self._shutdown_event: threading.Event | None = None
def _id_to_sid(self, dingtalk_id: str | None) -> str:
if not dingtalk_id:
@@ -93,7 +97,7 @@ class DingtalkPlatformAdapter(Platform):
return PlatformMetadata(
name="dingtalk",
description="钉钉机器人官方 API 适配器",
id=self.config.get("id"),
id=cast(str, self.config.get("id")),
support_streaming_message=False,
)
@@ -104,7 +108,7 @@ class DingtalkPlatformAdapter(Platform):
abm = AstrBotMessage()
abm.message = []
abm.message_str = ""
abm.timestamp = int(message.create_at / 1000)
abm.timestamp = int(cast(int, message.create_at) / 1000)
abm.type = (
MessageType.GROUP_MESSAGE
if message.conversation_type == "2"
@@ -115,7 +119,7 @@ class DingtalkPlatformAdapter(Platform):
nickname=message.sender_nick,
)
abm.self_id = self._id_to_sid(message.chatbot_user_id)
abm.message_id = message.message_id
abm.message_id = cast(str, message.message_id)
abm.raw_message = message
if abm.type == MessageType.GROUP_MESSAGE:
@@ -132,14 +136,16 @@ class DingtalkPlatformAdapter(Platform):
else:
abm.session_id = abm.sender.user_id
message_type: str = message.message_type
message_type: str = cast(str, message.message_type)
match message_type:
case "text":
abm.message_str = message.text.content.strip()
abm.message.append(Plain(abm.message_str))
case "richText":
rtc: dingtalk_stream.RichTextContent = message.rich_text_content
contents: list[dict] = rtc.rich_text_list
rtc: dingtalk_stream.RichTextContent = cast(
dingtalk_stream.RichTextContent, message.rich_text_content
)
contents: list[dict] = cast(list[dict], rtc.rich_text_list)
for content in contents:
plains = ""
if "text" in content:
@@ -148,7 +154,7 @@ class DingtalkPlatformAdapter(Platform):
elif "type" in content and content["type"] == "picture":
f_path = await self.download_ding_file(
content["downloadCode"],
message.robot_code,
cast(str, message.robot_code),
"jpg",
)
abm.message.append(Image.fromFileSystem(f_path))
@@ -193,7 +199,7 @@ class DingtalkPlatformAdapter(Platform):
logger.error(
f"下载钉钉文件失败: {resp.status}, {await resp.text()}",
)
return None
return ""
resp_data = await resp.json()
download_url = resp_data["data"]["downloadUrl"]
await download_file(download_url, f_path)
@@ -213,7 +219,7 @@ class DingtalkPlatformAdapter(Platform):
logger.error(
f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}",
)
return None
return ""
return (await resp.json())["data"]["accessToken"]
async def handle_msg(self, abm: AstrBotMessage):
@@ -239,7 +245,7 @@ class DingtalkPlatformAdapter(Platform):
task.result()
except Exception as e:
if "Graceful shutdown" in str(e):
logger.info("钉钉适配器已被优雅地关闭")
logger.info("钉钉适配器已被关闭")
return
logger.error(f"钉钉机器人启动失败: {e}")
@@ -250,9 +256,11 @@ class DingtalkPlatformAdapter(Platform):
def monkey_patch_close():
raise KeyboardInterrupt("Graceful shutdown")
self.client_.open_connection = monkey_patch_close
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
self._shutdown_event.set()
if self.client_.websocket is not None:
self.client_.open_connection = monkey_patch_close
await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
if self._shutdown_event is not None:
self._shutdown_event.set()
def get_client(self):
return self.client
@@ -1,4 +1,5 @@
import asyncio
from typing import cast
import dingtalk_stream
@@ -32,7 +33,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
client.reply_markdown,
segment.text,
segment.text,
self.message_obj.raw_message,
cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message),
)
elif isinstance(segment, Comp.Image):
markdown_str = ""
@@ -53,7 +54,9 @@ class DingtalkMessageEvent(AstrMessageEvent):
client.reply_markdown,
"😄",
markdown_str,
self.message_obj.raw_message,
cast(
dingtalk_stream.ChatbotMessage, self.message_obj.raw_message
),
)
logger.debug(f"send image: {ret}")
@@ -1,4 +1,5 @@
import sys
from collections.abc import Awaitable, Callable
import discord
@@ -27,13 +28,16 @@ class DiscordBotClient(discord.Bot):
super().__init__(intents=intents, proxy=proxy)
# 回调函数
self.on_message_received = None
self.on_ready_once_callback = None
self.on_message_received: Callable[[dict], Awaitable[None]] | None = None
self.on_ready_once_callback: Callable[[], Awaitable[None]] | None = None
self._ready_once_fired = False
@override
async def on_ready(self):
"""当机器人成功连接并准备就绪时触发"""
if self.user is None:
logger.error("[Discord] 客户端未正确加载用户信息 (self.user is None)")
return
logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录")
logger.info("[Discord] 客户端已准备就绪。")
@@ -49,6 +53,9 @@ class DiscordBotClient(discord.Bot):
def _create_message_data(self, message: discord.Message) -> dict:
"""从 discord.Message 创建数据字典"""
if self.user is None:
raise RuntimeError("Bot is not ready: self.user is None")
is_mentioned = self.user in message.mentions
return {
"message": message,
@@ -66,6 +73,12 @@ class DiscordBotClient(discord.Bot):
def _create_interaction_data(self, interaction: discord.Interaction) -> dict:
"""从 discord.Interaction 创建数据字典"""
if self.user is None:
raise RuntimeError("Bot is not ready: self.user is None")
if interaction.user is None:
raise ValueError("Interaction received without a valid user")
return {
"interaction": interaction,
"bot_id": str(self.user.id),
@@ -80,7 +93,6 @@ class DiscordBotClient(discord.Bot):
"type": "interaction",
}
@override
async def on_message(self, message: discord.Message):
"""当接收到消息时触发"""
if message.author.bot:
@@ -97,8 +97,8 @@ class DiscordView(BaseMessageComponent):
def __init__(
self,
components: list[BaseMessageComponent] = None,
timeout: float = None,
components: list[BaseMessageComponent] | None = None,
timeout: float | None = None,
):
self.components = components or []
self.timeout = timeout
@@ -1,10 +1,10 @@
import asyncio
import re
import sys
from typing import Any
from typing import Any, cast
import discord
from discord.abc import Messageable
from discord.abc import GuildChannel, Messageable, PrivateChannel
from discord.channel import DMChannel
from astrbot import logger
@@ -46,7 +46,7 @@ class DiscordPlatformAdapter(Platform):
) -> None:
super().__init__(platform_config, event_queue)
self.settings = platform_settings
self.client_self_id = None
self.client_self_id: str | None = None
self.registered_handlers = []
# 指令注册相关
self.enable_command_register = self.config.get("discord_command_register", True)
@@ -62,6 +62,12 @@ class DiscordPlatformAdapter(Platform):
message_chain: MessageChain,
):
"""通过会话发送消息"""
if self.client.user is None:
logger.error(
"[Discord] 客户端未就绪 (self.client.user is None),无法发送消息"
)
return
# 创建一个 message_obj 以便在 event 中使用
message_obj = AstrBotMessage()
if "_" in session.session_id:
@@ -89,7 +95,7 @@ class DiscordPlatformAdapter(Platform):
user_id=str(self.client_self_id),
nickname=self.client.user.display_name,
)
message_obj.self_id = self.client_self_id
message_obj.self_id = cast(str, self.client_self_id)
message_obj.session_id = session.session_id
message_obj.message = message_chain.chain
@@ -110,7 +116,7 @@ class DiscordPlatformAdapter(Platform):
return PlatformMetadata(
"discord",
"Discord 适配器",
id=self.config.get("id"),
id=cast(str, self.config.get("id")),
default_config_tmpl=self.config,
support_streaming_message=False,
)
@@ -160,7 +166,7 @@ class DiscordPlatformAdapter(Platform):
def _get_message_type(
self,
channel: Messageable,
channel: Messageable | GuildChannel | PrivateChannel,
guild_id: int | None = None,
) -> MessageType:
"""根据 channel 对象和 guild_id 判断消息类型"""
@@ -170,13 +176,15 @@ class DiscordPlatformAdapter(Platform):
return MessageType.FRIEND_MESSAGE
return MessageType.GROUP_MESSAGE
def _get_channel_id(self, channel: Messageable) -> str:
def _get_channel_id(
self, channel: Messageable | GuildChannel | PrivateChannel
) -> str:
"""根据 channel 对象获取ID"""
return str(getattr(channel, "id", None))
def _convert_message_to_abm(self, data: dict) -> AstrBotMessage:
"""将普通消息转换为 AstrBotMessage"""
message: discord.Message = data["message"]
message = data["message"]
content = message.content
@@ -233,7 +241,7 @@ class DiscordPlatformAdapter(Platform):
)
abm.message = message_chain
abm.raw_message = message
abm.self_id = self.client_self_id
abm.self_id = cast(str, self.client_self_id)
abm.session_id = str(message.channel.id)
abm.message_id = str(message.id)
return abm
@@ -254,32 +262,52 @@ class DiscordPlatformAdapter(Platform):
interaction_followup_webhook=followup_webhook,
)
if self.client.user is None:
logger.error(
"[Discord] 客户端未就绪 (self.client.user is None),无法处理消息"
)
return
# 检查是否为斜杠指令
is_slash_command = message_event.interaction_followup_webhook is not None
# 1. 优先处理斜杠指令
if is_slash_command:
message_event.is_wake = True
message_event.is_at_or_wake_command = True
self.commit_event(message_event)
return
# 2. 处理普通消息(提及检测)
# 确保 raw_message 是 discord.Message 类型,以便静态检查通过
raw_message = message.raw_message
if not isinstance(raw_message, discord.Message):
logger.warning(
f"[Discord] 收到非 Message 类型的消息: {type(raw_message)},已忽略。"
)
return
# 检查是否被@User Mention 或 Bot 拥有的 Role Mention
is_mention = False
# User Mention
if (
self.client
and self.client.user
and hasattr(message.raw_message, "mentions")
):
if self.client.user in message.raw_message.mentions:
is_mention = True
# 此时 Pylance 知道 raw_message 是 discord.Message,具有 mentions 属性
if self.client.user in raw_message.mentions:
is_mention = True
# Role MentionBot 拥有的角色被提及)
if not is_mention and hasattr(message.raw_message, "role_mentions"):
if not is_mention and raw_message.role_mentions:
bot_member = None
if hasattr(message.raw_message, "guild") and message.raw_message.guild:
if raw_message.guild:
try:
bot_member = message.raw_message.guild.get_member(
bot_member = raw_message.guild.get_member(
self.client.user.id,
)
except Exception:
bot_member = None
if bot_member and hasattr(bot_member, "roles"):
bot_roles = set(bot_member.roles)
mentioned_roles = set(message.raw_message.role_mentions)
mentioned_roles = set(raw_message.role_mentions)
if (
bot_roles
and mentioned_roles
@@ -287,8 +315,8 @@ class DiscordPlatformAdapter(Platform):
):
is_mention = True
# 如果是斜杠指令或被@的消息,设置为唤醒状态
if is_slash_command or is_mention:
# 如果是被@的消息,设置为唤醒状态
if is_mention:
message_event.is_wake = True
message_event.is_at_or_wake_command = True
@@ -424,7 +452,7 @@ class DiscordPlatformAdapter(Platform):
)
abm.message = [Plain(text=message_str_for_filter)]
abm.raw_message = ctx.interaction
abm.self_id = self.client_self_id
abm.self_id = cast(str, self.client_self_id)
abm.session_id = str(ctx.channel_id)
abm.message_id = str(ctx.interaction.id)
@@ -437,7 +465,7 @@ class DiscordPlatformAdapter(Platform):
def _extract_command_info(
event_filter: Any,
handler_metadata: StarHandlerMetadata,
) -> tuple[str, str, CommandFilter] | None:
) -> tuple[str, str, CommandFilter | None] | None:
"""从事件过滤器中提取指令信息"""
cmd_name = None
# is_group = False
@@ -4,8 +4,10 @@ import binascii
from collections.abc import AsyncGenerator
from io import BytesIO
from pathlib import Path
from typing import cast
import discord
from discord.types.interactions import ComponentInteractionData
from astrbot import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
@@ -85,6 +87,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
channel = await self._get_channel()
if not channel:
return
if not isinstance(channel, discord.abc.Messageable):
logger.error(f"[Discord] 频道 {channel.id} 不是可发送消息的类型")
return
await channel.send(**kwargs)
except Exception as e:
@@ -107,7 +112,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
async def _get_channel(self) -> discord.abc.Messageable | None:
async def _get_channel(
self,
) -> discord.Thread | discord.abc.GuildChannel | discord.abc.PrivateChannel | None:
"""获取当前事件对应的频道对象"""
try:
channel_id = int(self.session_id)
@@ -121,7 +128,13 @@ class DiscordPlatformEvent(AstrMessageEvent):
async def _parse_to_discord(
self,
message: MessageChain,
) -> tuple[str, list[discord.File], discord.ui.View | None, list[discord.Embed]]:
) -> tuple[
str,
list[discord.File],
discord.ui.View | None,
list[discord.Embed],
str | int | None,
]:
"""将 MessageChain 解析为 Discord 发送所需的内容"""
content_parts = []
files = []
@@ -261,7 +274,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
self.message_obj.raw_message,
"add_reaction",
):
await self.message_obj.raw_message.add_reaction(emoji)
await cast(discord.Message, self.message_obj.raw_message).add_reaction(
emoji
)
except Exception as e:
logger.error(f"[Discord] 添加反应失败: {e}")
@@ -270,7 +285,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
return (
hasattr(self.message_obj, "raw_message")
and hasattr(self.message_obj.raw_message, "type")
and self.message_obj.raw_message.type
and cast(discord.Interaction, self.message_obj.raw_message).type
== discord.InteractionType.application_command
)
@@ -279,14 +294,18 @@ class DiscordPlatformEvent(AstrMessageEvent):
return (
hasattr(self.message_obj, "raw_message")
and hasattr(self.message_obj.raw_message, "type")
and self.message_obj.raw_message.type == discord.InteractionType.component
and cast(discord.Interaction, self.message_obj.raw_message).type
== discord.InteractionType.component
)
def get_interaction_custom_id(self) -> str:
"""获取交互组件的custom_id"""
if self.is_button_interaction():
try:
return self.message_obj.raw_message.data.get("custom_id", "")
return cast(
ComponentInteractionData,
cast(discord.Interaction, self.message_obj.raw_message).data,
).get("custom_id", "")
except Exception:
pass
return ""
@@ -299,7 +318,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
):
return any(
mention.id == int(self.message_obj.self_id)
for mention in self.message_obj.raw_message.mentions
for mention in cast(
discord.Message, self.message_obj.raw_message
).mentions
)
return False
@@ -309,5 +330,5 @@ class DiscordPlatformEvent(AstrMessageEvent):
self.message_obj.raw_message,
"clean_content",
):
return self.message_obj.raw_message.clean_content
return cast(discord.Message, self.message_obj.raw_message).clean_content
return self.message_str
@@ -2,10 +2,17 @@ import asyncio
import base64
import json
import re
import time
import uuid
from typing import Any, cast
import lark_oapi as lark
from lark_oapi.api.im.v1 import *
from lark_oapi.api.im.v1 import (
CreateMessageRequest,
CreateMessageRequestBody,
GetMessageResourceRequest,
)
from lark_oapi.api.im.v1.processor import P2ImMessageReceiveV1Processor
import astrbot.api.message_components as Comp
from astrbot import logger
@@ -18,9 +25,11 @@ from astrbot.api.platform import (
PlatformMetadata,
)
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.webhook_utils import log_webhook_info
from ...register import register_platform_adapter
from .lark_event import LarkMessageEvent
from .server import LarkWebhookServer
@register_platform_adapter(
@@ -42,9 +51,13 @@ class LarkPlatformAdapter(Platform):
self.domain = platform_config.get("domain", lark.FEISHU_DOMAIN)
self.bot_name = platform_config.get("lark_bot_name", "astrbot")
# socket or webhook
self.connection_mode = platform_config.get("lark_connection_mode", "socket")
if not self.bot_name:
logger.warning("未设置飞书机器人名称,@ 机器人可能得不到回复。")
# 初始化 WebSocket 长连接相关配置
async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1):
await self.convert_msg(event)
@@ -57,6 +70,8 @@ class LarkPlatformAdapter(Platform):
.build()
)
self.do_v2_msg_event = do_v2_msg_event
self.client = lark.ws.Client(
app_id=self.appid,
app_secret=self.appsecret,
@@ -66,14 +81,56 @@ class LarkPlatformAdapter(Platform):
)
self.lark_api = (
lark.Client.builder().app_id(self.appid).app_secret(self.appsecret).build()
lark.Client.builder()
.app_id(self.appid)
.app_secret(self.appsecret)
.log_level(lark.LogLevel.ERROR)
.domain(self.domain)
.build()
)
self.webhook_server = None
if self.connection_mode == "webhook":
self.webhook_server = LarkWebhookServer(platform_config, event_queue)
self.webhook_server.set_callback(self.handle_webhook_event)
self.event_id_timestamps: dict[str, float] = {}
def _clean_expired_events(self):
"""清理超过 30 分钟的事件记录"""
current_time = time.time()
expired_keys = [
event_id
for event_id, timestamp in self.event_id_timestamps.items()
if current_time - timestamp > 1800
]
for event_id in expired_keys:
del self.event_id_timestamps[event_id]
def _is_duplicate_event(self, event_id: str) -> bool:
"""检查事件是否重复
Args:
event_id: 事件ID
Returns:
True 表示重复事件False 表示新事件
"""
self._clean_expired_events()
if event_id in self.event_id_timestamps:
return True
self.event_id_timestamps[event_id] = time.time()
return False
async def send_by_session(
self,
session: MessageSesion,
message_chain: MessageChain,
):
if self.lark_api.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法发送消息")
return
res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api)
wrapped = {
"zh_cn": {
@@ -114,14 +171,25 @@ class LarkPlatformAdapter(Platform):
return PlatformMetadata(
name="lark",
description="飞书机器人官方 API 适配器",
id=self.config.get("id"),
id=cast(str, self.config.get("id")),
support_streaming_message=False,
)
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
if event.event is None:
logger.debug("[Lark] 收到空事件(event.event is None)")
return
message = event.event.message
if message is None:
logger.debug("[Lark] 事件中没有消息体(message is None)")
return
abm = AstrBotMessage()
abm.timestamp = int(message.create_time) / 1000
if message.create_time:
abm.timestamp = int(message.create_time) // 1000
else:
abm.timestamp = int(time.time())
abm.message = []
abm.type = (
MessageType.GROUP_MESSAGE
@@ -136,14 +204,28 @@ class LarkPlatformAdapter(Platform):
at_list = {}
if message.mentions:
for m in message.mentions:
at_list[m.key] = Comp.At(qq=m.id.open_id, name=m.name)
if m.name == self.bot_name:
abm.self_id = m.id.open_id
if m.id is None:
continue
# 飞书 open_id 可能是 None,这里做个防护
open_id = m.id.open_id if m.id.open_id else ""
at_list[m.key] = Comp.At(qq=open_id, name=m.name)
content_json_b = json.loads(message.content)
if m.name == self.bot_name:
if m.id.open_id is not None:
abm.self_id = m.id.open_id
if message.content is None:
logger.warning("[Lark] 消息内容为空")
return
try:
content_json_b = json.loads(message.content)
except json.JSONDecodeError:
logger.error(f"[Lark] 解析消息内容失败: {message.content}")
return
if message.message_type == "text":
message_str_raw = content_json_b["text"] # 带有 @ 的消息
message_str_raw = content_json_b.get("text", "") # 带有 @ 的消息
at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则
# at_users = re.findall(at_pattern, message_str_raw)
# 拆分文本,去掉AT符号部分
@@ -168,27 +250,47 @@ class LarkPlatformAdapter(Platform):
content_json_b = _ls
elif message.message_type == "image":
content_json_b = [
{"tag": "img", "image_key": content_json_b["image_key"], "style": []},
{
"tag": "img",
"image_key": content_json_b.get("image_key"),
"style": [],
},
]
if message.message_type in ("post", "image"):
for comp in content_json_b:
if comp["tag"] == "at":
abm.message.append(at_list[comp["user_id"]])
elif comp["tag"] == "text" and comp["text"].strip():
if comp.get("tag") == "at":
user_id = comp.get("user_id")
if user_id in at_list:
abm.message.append(at_list[user_id])
elif comp.get("tag") == "text" and comp.get("text", "").strip():
abm.message.append(Comp.Plain(comp["text"].strip()))
elif comp["tag"] == "img":
image_key = comp["image_key"]
elif comp.get("tag") == "img":
image_key = comp.get("image_key")
if not image_key:
continue
request = (
GetMessageResourceRequest.builder()
.message_id(message.message_id)
.message_id(cast(str, message.message_id))
.file_key(image_key)
.type("image")
.build()
)
if self.lark_api.im is None:
logger.error("[Lark] API Client im 模块未初始化")
continue
response = await self.lark_api.im.v1.message_resource.aget(request)
if not response.success():
logger.error(f"无法下载飞书图片: {image_key}")
continue
if response.file is None:
logger.error(f"飞书图片响应中不包含文件流: {image_key}")
continue
image_bytes = response.file.read()
image_base64 = base64.b64encode(image_bytes).decode()
abm.message.append(Comp.Image.fromBase64(image_base64))
@@ -196,6 +298,19 @@ class LarkPlatformAdapter(Platform):
for comp in abm.message:
if isinstance(comp, Comp.Plain):
abm.message_str += comp.text
if message.message_id is None:
logger.error("[Lark] 消息缺少 message_id")
return
if (
event.event.sender is None
or event.event.sender.sender_id is None
or event.event.sender.sender_id.open_id is None
):
logger.error("[Lark] 消息发送者信息不完整")
return
abm.message_id = message.message_id
abm.raw_message = message
abm.sender = MessageMember(
@@ -227,13 +342,61 @@ class LarkPlatformAdapter(Platform):
self._event_queue.put_nowait(event)
async def handle_webhook_event(self, event_data: dict):
"""处理 Webhook 事件
Args:
event_data: Webhook 事件数据
"""
try:
header = event_data.get("header", {})
event_id = header.get("event_id", "")
if event_id and self._is_duplicate_event(event_id):
logger.debug(f"[Lark Webhook] 跳过重复事件: {event_id}")
return
event_type = header.get("event_type", "")
if event_type == "im.message.receive_v1":
processor = P2ImMessageReceiveV1Processor(self.do_v2_msg_event)
data = (processor.type())(event_data)
processor.do(data)
else:
logger.debug(f"[Lark Webhook] 未处理的事件类型: {event_type}")
except Exception as e:
logger.error(f"[Lark Webhook] 处理事件失败: {e}", exc_info=True)
async def run(self):
# self.client.start()
await self.client._connect()
if self.connection_mode == "webhook":
# Webhook 模式
if self.webhook_server is None:
logger.error("[Lark] Webhook 模式已启用,但 webhook_server 未初始化")
return
webhook_uuid = self.config.get("webhook_uuid")
if webhook_uuid:
log_webhook_info(f"{self.meta().id}(飞书 Webhook)", webhook_uuid)
else:
logger.warning("[Lark] Webhook 模式已启用,但未配置 webhook_uuid")
else:
# 长连接模式
await self.client._connect()
async def webhook_callback(self, request: Any) -> Any:
"""统一 Webhook 回调入口"""
if not self.webhook_server:
return {"error": "Webhook server not initialized"}, 500
return await self.webhook_server.handle_callback(request)
async def terminate(self):
await self.client._disconnect()
logger.info("飞书(Lark) 适配器已被优雅地关闭")
if self.connection_mode == "socket":
await self.client._disconnect()
logger.info("飞书(Lark) 适配器已关闭")
def get_client(self) -> lark.Client:
def get_client(self) -> lark.ws.Client:
return self.client
def unified_webhook(self) -> bool:
return bool(
self.config.get("lark_connection_mode", "") == "webhook"
and self.config.get("webhook_uuid")
)
@@ -5,7 +5,15 @@ import uuid
from io import BytesIO
import lark_oapi as lark
from lark_oapi.api.im.v1 import *
from lark_oapi.api.im.v1 import (
CreateImageRequest,
CreateImageRequestBody,
CreateMessageReactionRequest,
CreateMessageReactionRequestBody,
Emoji,
ReplyMessageRequest,
ReplyMessageRequestBody,
)
from astrbot import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
@@ -44,7 +52,7 @@ class LarkMessageEvent(AstrMessageEvent):
file_path = comp.file.replace("file:///", "")
elif comp.file and comp.file.startswith("http"):
image_file_path = await download_image_by_url(comp.file)
file_path = image_file_path
file_path = image_file_path if image_file_path else ""
elif comp.file and comp.file.startswith("base64://"):
base64_str = comp.file.removeprefix("base64://")
image_data = base64.b64decode(base64_str)
@@ -54,10 +62,17 @@ class LarkMessageEvent(AstrMessageEvent):
with open(file_path, "wb") as f:
f.write(BytesIO(image_data).getvalue())
else:
file_path = comp.file
file_path = comp.file if comp.file else ""
if image_file is None:
image_file = open(file_path, "rb")
if not file_path:
logger.error("[Lark] 图片路径为空,无法上传")
continue
try:
image_file = open(file_path, "rb")
except Exception as e:
logger.error(f"[Lark] 无法打开图片文件: {e}")
continue
request = (
CreateImageRequest.builder()
@@ -69,9 +84,20 @@ class LarkMessageEvent(AstrMessageEvent):
)
.build()
)
if lark_client.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法上传图片")
continue
response = await lark_client.im.v1.image.acreate(request)
if not response.success():
logger.error(f"无法上传飞书图片({response.code}): {response.msg}")
continue
if response.data is None:
logger.error("[Lark] 上传图片成功但未返回数据(data is None)")
continue
image_key = response.data.image_key
logger.debug(image_key)
ret.append(_stage)
@@ -107,6 +133,10 @@ class LarkMessageEvent(AstrMessageEvent):
.build()
)
if self.bot.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法回复消息")
return
response = await self.bot.im.v1.message.areply(request)
if not response.success():
@@ -115,6 +145,10 @@ class LarkMessageEvent(AstrMessageEvent):
await super().send(message)
async def react(self, emoji: str):
if self.bot.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法发送表情")
return
request = (
CreateMessageReactionRequest.builder()
.message_id(self.message_obj.message_id)
@@ -125,6 +159,7 @@ class LarkMessageEvent(AstrMessageEvent):
)
.build()
)
response = await self.bot.im.v1.message_reaction.acreate(request)
if not response.success():
logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}")
@@ -0,0 +1,206 @@
"""飞书(Lark) Webhook 服务器实现
实现飞书事件订阅的 Webhook 模式支持:
1. 请求 URL 验证 (challenge 验证)
2. 事件加密/解密 (AES-256-CBC)
3. 签名校验 (SHA256)
4. 事件接收和处理
"""
import asyncio
import base64
import hashlib
import json
from collections.abc import Awaitable, Callable
from Crypto.Cipher import AES
from astrbot.api import logger
class AESCipher:
"""AES 加密/解密工具类"""
def __init__(self, key: str):
self.bs = AES.block_size
self.key = hashlib.sha256(self.str_to_bytes(key)).digest()
@staticmethod
def str_to_bytes(data):
u_type = type(b"".decode("utf8"))
if isinstance(data, u_type):
return data.encode("utf8")
return data
@staticmethod
def _unpad(s):
return s[: -ord(s[len(s) - 1 :])]
def decrypt(self, enc):
iv = enc[: AES.block_size]
cipher = AES.new(self.key, AES.MODE_CBC, iv)
return self._unpad(cipher.decrypt(enc[AES.block_size :]))
def decrypt_string(self, enc):
enc = base64.b64decode(enc)
return self.decrypt(enc).decode("utf8")
class LarkWebhookServer:
"""飞书 Webhook 服务器
仅支持统一 Webhook 模式
"""
def __init__(self, config: dict, event_queue: asyncio.Queue):
"""初始化 Webhook 服务器
Args:
config: 飞书配置
event_queue: 事件队列
"""
self.app_id = config["app_id"]
self.app_secret = config["app_secret"]
self.encrypt_key = config.get("lark_encrypt_key", "")
self.verification_token = config.get("lark_verification_token", "")
self.event_queue = event_queue
self.callback: Callable[[dict], Awaitable[None]] | None = None
# 初始化加密工具
self.cipher = None
if self.encrypt_key:
self.cipher = AESCipher(self.encrypt_key)
def verify_signature(
self,
timestamp: str,
nonce: str,
encrypt_key: str,
body: bytes,
signature: str,
) -> bool:
"""验证签名
Args:
timestamp: 请求时间戳
nonce: 随机数
encrypt_key: 加密密钥
body: 请求体
signature: 签名
Returns:
签名是否有效
"""
# 拼接字符串: timestamp + nonce + encrypt_key + body
bytes_b1 = (timestamp + nonce + encrypt_key).encode("utf-8")
bytes_b = bytes_b1 + body
h = hashlib.sha256(bytes_b)
calculated_signature = h.hexdigest()
return calculated_signature == signature
def decrypt_event(self, encrypted_data: str) -> dict:
"""解密事件数据
Args:
encrypted_data: 加密的事件数据
Returns:
解密后的事件字典
"""
if not self.cipher:
raise ValueError("未配置 encrypt_key,无法解密事件")
decrypted_str = self.cipher.decrypt_string(encrypted_data)
return json.loads(decrypted_str)
async def handle_challenge(self, event_data: dict) -> dict:
"""处理 challenge 验证请求
Args:
event_data: 事件数据
Returns:
包含 challenge 的响应
"""
challenge = event_data.get("challenge", "")
logger.info(f"[Lark Webhook] 收到 challenge 验证请求: {challenge}")
return {"challenge": challenge}
async def handle_callback(self, request) -> tuple[dict, int] | dict:
"""处理 webhook 回调,可被统一 webhook 入口复用
Args:
request: Quart 请求对象
Returns:
响应数据
"""
# 获取原始请求体
body = await request.get_data()
try:
event_data = await request.json
except Exception as e:
logger.error(f"[Lark Webhook] 解析请求体失败: {e}")
return {"error": "Invalid JSON"}, 400
if not event_data:
logger.error("[Lark Webhook] 请求体为空")
return {"error": "Empty request body"}, 400
# 如果配置了 encrypt_key,进行签名验证
if self.encrypt_key:
timestamp = request.headers.get("X-Lark-Request-Timestamp", "")
nonce = request.headers.get("X-Lark-Request-Nonce", "")
signature = request.headers.get("X-Lark-Signature", "")
if timestamp and nonce and signature:
if not self.verify_signature(
timestamp, nonce, self.encrypt_key, body, signature
):
logger.error("[Lark Webhook] 签名验证失败")
return {"error": "Invalid signature"}, 401
# 检查是否是加密事件
if "encrypt" in event_data:
try:
event_data = self.decrypt_event(event_data["encrypt"])
logger.debug(f"[Lark Webhook] 解密后的事件: {event_data}")
except Exception as e:
logger.error(f"[Lark Webhook] 解密事件失败: {e}")
return {"error": "Decryption failed"}, 400
# 验证 token
if self.verification_token:
header = event_data.get("header", {})
if header:
token = header.get("token", "")
else:
token = event_data.get("token", "")
if token != self.verification_token:
logger.error("[Lark Webhook] Verification Token 不匹配。")
return {"error": "Invalid verification token"}, 401
# 处理 URL 验证 (challenge)
if event_data.get("type") == "url_verification":
return await self.handle_challenge(event_data)
# 调用回调函数处理事件
if self.callback:
try:
await self.callback(event_data)
except Exception as e:
logger.error(f"[Lark Webhook] 处理事件回调失败: {e}", exc_info=True)
return {"error": "Event processing failed"}, 500
return {}
def set_callback(self, callback: Callable[[dict], Awaitable[None]]):
"""设置事件回调函数
Args:
callback: 处理事件的异步函数
"""
self.callback = callback
@@ -1,7 +1,6 @@
import asyncio
import os
import random
from collections.abc import Awaitable
from typing import Any
import astrbot.api.message_components as Comp
@@ -203,7 +202,7 @@ class MisskeyPlatformAdapter(Platform):
if not isinstance(message.raw_message, dict):
message.raw_message = {}
message.raw_message["poll"] = poll
message.poll = poll
message.__setattr__("poll", poll)
except Exception:
pass
@@ -372,7 +371,7 @@ class MisskeyPlatformAdapter(Platform):
self,
session: MessageSession,
message_chain: MessageChain,
) -> Awaitable[Any]:
) -> None:
if not self.api:
logger.error("[Misskey] API 客户端未初始化")
return await super().send_by_session(session, message_chain)
@@ -3,6 +3,7 @@ import base64
import os
import random
import uuid
from typing import cast
import aiofiles
import botpy
@@ -60,7 +61,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
time_since_last_edit = current_time - last_edit_time
if time_since_last_edit >= throttle_interval:
ret = await self._post_send(stream=stream_payload)
ret = cast(
message.Message,
await self._post_send(stream=stream_payload),
)
stream_payload["index"] += 1
stream_payload["id"] = ret["id"]
last_edit_time = asyncio.get_event_loop().time()
@@ -83,7 +87,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
return None
source = self.message_obj.raw_message
assert isinstance(
if not isinstance(
source,
(
botpy.message.Message,
@@ -91,7 +96,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
botpy.message.DirectMessage,
botpy.message.C2CMessage,
),
)
):
logger.warning(f"[QQOfficial] 不支持的消息源类型: {type(source)}")
return None
(
plain_text,
@@ -108,7 +115,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
):
return None
payload = {
payload: dict = {
"content": plain_text,
"msg_id": self.message_obj.message_id,
}
@@ -118,8 +125,12 @@ class QQOfficialMessageEvent(AstrMessageEvent):
ret = None
match type(source):
case botpy.message.GroupMessage:
match source:
case botpy.message.GroupMessage():
if not source.group_openid:
logger.error("[QQOfficial] GroupMessage 缺少 group_openid")
return None
if image_base64:
media = await self.upload_group_and_c2c_image(
image_base64,
@@ -140,7 +151,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
group_openid=source.group_openid,
**payload,
)
case botpy.message.C2CMessage:
case botpy.message.C2CMessage():
if image_base64:
media = await self.upload_group_and_c2c_image(
image_base64,
@@ -169,18 +181,23 @@ class QQOfficialMessageEvent(AstrMessageEvent):
**payload,
)
logger.debug(f"Message sent to C2C: {ret}")
case botpy.message.Message:
case botpy.message.Message():
if image_path:
payload["file_image"] = image_path
ret = await self.bot.api.post_message(
channel_id=source.channel_id,
**payload,
)
case botpy.message.DirectMessage:
case botpy.message.DirectMessage():
if image_path:
payload["file_image"] = image_path
ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
case _:
pass
await super().send(self.send_buffer)
self.send_buffer = None
@@ -198,18 +215,33 @@ class QQOfficialMessageEvent(AstrMessageEvent):
"file_type": file_type,
"srv_send_msg": False,
}
result = None
if "openid" in kwargs:
payload["openid"] = kwargs["openid"]
route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"])
return await self.bot.api._http.request(route, json=payload)
if "group_openid" in kwargs:
result = await self.bot.api._http.request(route, json=payload)
elif "group_openid" in kwargs:
payload["group_openid"] = kwargs["group_openid"]
route = Route(
"POST",
"/v2/groups/{group_openid}/files",
group_openid=kwargs["group_openid"],
)
return await self.bot.api._http.request(route, json=payload)
result = await self.bot.api._http.request(route, json=payload)
else:
raise ValueError("Invalid upload parameters")
if not isinstance(result, dict):
raise RuntimeError(
f"Failed to upload image, response is not dict: {result}"
)
return Media(
file_uuid=result["file_uuid"],
file_info=result["file_info"],
ttl=result.get("ttl", 0),
)
async def upload_group_and_c2c_record(
self,
@@ -252,11 +284,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
result = await self.bot.api._http.request(route, json=payload)
if result:
if not isinstance(result, dict):
logger.error(f"上传文件响应格式错误: {result}")
return None
return Media(
file_uuid=result.get("file_uuid"),
file_info=result.get("file_info"),
file_uuid=result["file_uuid"],
file_info=result["file_info"],
ttl=result.get("ttl", 0),
file_id=result.get("id", ""),
)
except Exception as e:
logger.error(f"上传请求错误: {e}")
@@ -273,7 +308,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
message_reference: message.Reference | None = None,
media: message.Media | None = None,
msg_id: str | None = None,
msg_seq: str = 1,
msg_seq: int | None = 1,
event_id: str | None = None,
markdown: message.MarkdownPayload | None = None,
keyboard: message.Keyboard | None = None,
@@ -282,7 +317,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
payload = locals()
payload.pop("self", None)
route = Route("POST", "/v2/users/{openid}/messages", openid=openid)
return await self.bot.api._http.request(route, json=payload)
result = await self.bot.api._http.request(route, json=payload)
if not isinstance(result, dict):
raise RuntimeError(
f"Failed to post c2c message, response is not dict: {result}"
)
return message.Message(**result)
@staticmethod
async def _parse_to_qqofficial(message: MessageChain):
@@ -302,8 +344,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
image_base64 = file_to_base64(image_file_path)
elif i.file and i.file.startswith("base64://"):
image_base64 = i.file
else:
elif i.file:
image_base64 = file_to_base64(i.file)
else:
raise ValueError("Unsupported image file format")
image_base64 = image_base64.removeprefix("base64://")
elif isinstance(i, Record):
if i.file:
@@ -4,6 +4,7 @@ import asyncio
import logging
import os
import time
from typing import cast
import botpy
import botpy.message
@@ -44,7 +45,9 @@ class botClient(Client):
MessageType.GROUP_MESSAGE,
)
abm.session_id = (
abm.sender.user_id if self.platform.unique_session else message.group_openid
abm.sender.user_id
if self.platform.unique_session
else cast(str, message.group_openid)
)
self._commit(abm)
@@ -101,7 +104,7 @@ class QQOfficialPlatformAdapter(Platform):
self.appid = platform_config["appid"]
self.secret = platform_config["secret"]
self.unique_session = platform_settings["unique_session"]
self.unique_session: bool = platform_settings["unique_session"]
qq_group = platform_config["enable_group_c2c"]
guild_dm = platform_config["enable_guild_direct_message"]
@@ -137,12 +140,15 @@ class QQOfficialPlatformAdapter(Platform):
return PlatformMetadata(
name="qq_official",
description="QQ 机器人官方 API 适配器",
id=self.config.get("id"),
id=cast(str, self.config.get("id")),
)
@staticmethod
def _parse_from_qqofficial(
message: botpy.message.Message | botpy.message.GroupMessage,
message: botpy.message.Message
| botpy.message.GroupMessage
| botpy.message.DirectMessage
| botpy.message.C2CMessage,
message_type: MessageType,
):
abm = AstrBotMessage()
@@ -150,7 +156,7 @@ class QQOfficialPlatformAdapter(Platform):
abm.timestamp = int(time.time())
abm.raw_message = message
abm.message_id = message.id
abm.tag = "qq_official"
# abm.tag = "qq_official"
msg: list[BaseMessageComponent] = []
if isinstance(message, botpy.message.GroupMessage) or isinstance(
@@ -180,9 +186,9 @@ class QQOfficialPlatformAdapter(Platform):
message,
botpy.message.DirectMessage,
):
try:
if isinstance(message, botpy.message.Message):
abm.self_id = str(message.mentions[0].id)
except BaseException as _:
else:
abm.self_id = ""
plain_content = message.content.replace(
@@ -1,6 +1,6 @@
import asyncio
import logging
from typing import Any
from typing import Any, cast
import botpy
import botpy.message
@@ -36,7 +36,9 @@ class botClient(Client):
MessageType.GROUP_MESSAGE,
)
abm.session_id = (
abm.sender.user_id if self.platform.unique_session else message.group_openid
abm.sender.user_id
if self.platform.unique_session
else cast(str, message.group_openid)
)
self._commit(abm)
@@ -120,7 +122,7 @@ class QQOfficialWebhookPlatformAdapter(Platform):
return PlatformMetadata(
name="qq_official_webhook",
description="QQ 机器人官方 API 适配器",
id=self.config.get("id"),
id=cast(str, self.config.get("id")),
)
async def run(self):
@@ -1,5 +1,6 @@
import asyncio
import logging
from typing import cast
import quart
from botpy import BotAPI, BotHttp, BotWebSocket, Client, ConnectionSession, Token
@@ -99,7 +100,7 @@ class QQOfficialWebhook:
if opcode == 13:
# validation
signed = await self.webhook_validation(data)
signed = await self.webhook_validation(cast(dict, data))
print(signed)
return signed
@@ -4,9 +4,11 @@ import hmac
import json
import logging
from collections.abc import Callable
from typing import cast
from quart import Quart, Response, request
from slack_sdk.socket_mode.aiohttp import SocketModeClient
from slack_sdk.socket_mode.async_client import AsyncBaseSocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from slack_sdk.web.async_client import AsyncWebClient
@@ -66,7 +68,7 @@ class SlackWebhookClient:
"""
try:
# 获取请求体和头部
body = await req.get_data()
body = cast(bytes, await req.get_data())
event_data = json.loads(body.decode("utf-8"))
# Verify Slack request signature
@@ -139,9 +141,14 @@ class SlackSocketClient:
self.event_handler = event_handler
self.socket_client = None
async def _handle_events(self, _: SocketModeClient, req: SocketModeRequest):
async def _handle_events(
self, _: AsyncBaseSocketModeClient, req: SocketModeRequest
):
"""处理 Socket Mode 事件"""
try:
if self.socket_client is None:
raise RuntimeError("Socket client is not initialized")
# 确认收到事件
response = SocketModeResponse(envelope_id=req.envelope_id)
await self.socket_client.send_socket_mode_response(response)
@@ -3,8 +3,7 @@ import base64
import re
import time
import uuid
from collections.abc import Awaitable
from typing import Any
from typing import Any, cast
import aiohttp
from slack_sdk.socket_mode.request import SocketModeRequest
@@ -68,7 +67,7 @@ class SlackAdapter(Platform):
self.metadata = PlatformMetadata(
name="slack",
description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
id=self.config.get("id"),
id=cast(str, self.config.get("id")),
support_streaming_message=False,
)
@@ -118,13 +117,13 @@ class SlackAdapter(Platform):
logger.debug(f"[slack] RawMessage {event}")
abm = AstrBotMessage()
abm.self_id = self.bot_self_id
abm.self_id = cast(str, self.bot_self_id)
# 获取用户信息
user_id = event.get("user", "")
try:
user_info = await self.web_client.users_info(user=user_id)
user_data = user_info["user"]
user_data = cast(dict, user_info["user"])
user_name = user_data.get("real_name") or user_data.get("name", user_id)
except Exception:
user_name = user_id
@@ -135,7 +134,7 @@ class SlackAdapter(Platform):
channel_id = event.get("channel", "")
try:
channel_info = await self.web_client.conversations_info(channel=channel_id)
is_im = channel_info["channel"]["is_im"]
is_im = cast(dict, channel_info["channel"])["is_im"]
if is_im:
abm.type = MessageType.FRIEND_MESSAGE
@@ -178,7 +177,7 @@ class SlackAdapter(Platform):
for mention in mentions:
try:
mentioned_user = await self.web_client.users_info(user=mention)
user_data = mentioned_user["user"]
user_data = cast(dict, mentioned_user["user"])
user_name = user_data.get("real_name") or user_data.get(
"name",
mention,
@@ -329,7 +328,7 @@ class SlackAdapter(Platform):
)
raise Exception(f"下载文件失败: {resp.status}")
async def run(self) -> Awaitable[Any]:
async def run(self) -> None:
self.bot_self_id = await self.get_bot_user_id()
logger.info(f"Slack auth test OK. Bot ID: {self.bot_self_id}")
@@ -410,7 +409,7 @@ class SlackAdapter(Platform):
await self.socket_client.stop()
if self.webhook_client:
await self.webhook_client.stop()
logger.info("Slack 适配器已被优雅地关闭")
logger.info("Slack 适配器已被关闭")
def meta(self) -> PlatformMetadata:
return self.metadata
@@ -428,3 +427,10 @@ class SlackAdapter(Platform):
def get_client(self):
return self.web_client
def unified_webhook(self) -> bool:
return bool(
self.config.get("unified_webhook_mode", False)
and self.config.get("slack_connection_mode", "") == "webhook"
and self.config.get("webhook_uuid")
)
@@ -1,6 +1,7 @@
import asyncio
import re
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Iterable
from typing import cast
from slack_sdk.web.async_client import AsyncWebClient
@@ -38,7 +39,7 @@ class SlackMessageEvent(AstrMessageEvent):
if isinstance(segment, Image):
# upload file
url = segment.url or segment.file
if url.startswith("http"):
if url and url.startswith("http"):
return {
"type": "image",
"image_url": url,
@@ -55,7 +56,7 @@ class SlackMessageEvent(AstrMessageEvent):
"type": "section",
"text": {"type": "mrkdwn", "text": "图片上传失败"},
}
image_url = response["files"][0]["url_private"]
image_url = cast(list, response["files"])[0]["url_private"]
logger.debug(f"Slack file upload response: {response}")
return {
"type": "image",
@@ -77,7 +78,7 @@ class SlackMessageEvent(AstrMessageEvent):
"type": "section",
"text": {"type": "mrkdwn", "text": "文件上传失败"},
}
file_url = response["files"][0]["permalink"]
file_url = cast(list, response["files"])[0]["permalink"]
return {
"type": "section",
"text": {
@@ -225,10 +226,10 @@ class SlackMessageEvent(AstrMessageEvent):
)
members = []
for member_id in members_response["members"]:
for member_id in cast(Iterable, members_response["members"]):
try:
user_info = await self.web_client.users_info(user=member_id)
user_data = user_info["user"]
user_data = cast(dict, user_info["user"])
members.append(
MessageMember(
user_id=member_id,
@@ -240,7 +241,7 @@ class SlackMessageEvent(AstrMessageEvent):
# 如果获取用户信息失败,使用默认信息
members.append(MessageMember(user_id=member_id, nickname=member_id))
channel_data = channel_info["channel"]
channel_data = cast(dict, channel_info["channel"])
return Group(
group_id=channel_id,
group_name=channel_data.get("name", ""),
@@ -424,6 +424,6 @@ class TelegramPlatformAdapter(Platform):
if self.application.updater is not None:
await self.application.updater.stop()
logger.info("Telegram 适配器已被优雅地关闭")
logger.info("Telegram 适配器已被关闭")
except Exception as e:
logger.error(f"Telegram 适配器关闭时出错: {e}")
@@ -1,6 +1,7 @@
import asyncio
import os
import re
from typing import Any, cast
import telegramify_markdown
from telegram import ReactionTypeCustomEmoji, ReactionTypeEmoji
@@ -17,8 +18,6 @@ from astrbot.api.message_components import (
Reply,
)
from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file
class TelegramPlatformEvent(AstrMessageEvent):
@@ -97,7 +96,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
"chat_id": user_name,
}
if has_reply:
payload["reply_to_message_id"] = reply_message_id
payload["reply_to_message_id"] = str(reply_message_id)
if message_thread_id:
payload["message_thread_id"] = message_thread_id
@@ -110,33 +109,30 @@ class TelegramPlatformEvent(AstrMessageEvent):
try:
md_text = telegramify_markdown.markdownify(
chunk,
max_line_length=None,
normalize_whitespace=False,
)
await client.send_message(
text=md_text,
parse_mode="MarkdownV2",
**payload,
**cast(Any, payload),
)
except Exception as e:
logger.warning(
f"MarkdownV2 send failed: {e}. Using plain text instead.",
)
await client.send_message(text=chunk, **payload)
await client.send_message(text=chunk, **cast(Any, payload))
elif isinstance(i, Image):
image_path = await i.convert_to_file_path()
await client.send_photo(photo=image_path, **payload)
await client.send_photo(photo=image_path, **cast(Any, payload))
elif isinstance(i, File):
if i.file.startswith("https://"):
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, i.name)
await download_file(i.file, path)
i.file = path
await client.send_document(document=i.file, filename=i.name, **payload)
path = await i.get_file()
name = i.name or os.path.basename(path)
await client.send_document(
document=path, filename=name, **cast(Any, payload)
)
elif isinstance(i, Record):
path = await i.convert_to_file_path()
await client.send_voice(voice=path, **payload)
await client.send_voice(voice=path, **cast(Any, payload))
async def send(self, message: MessageChain):
if self.get_message_type() == MessageType.GROUP_MESSAGE:
@@ -204,6 +200,15 @@ class TelegramPlatformEvent(AstrMessageEvent):
if isinstance(chain, MessageChain):
if chain.type == "break":
# 分割符
if message_id:
try:
await self.client.edit_message_text(
text=delta,
chat_id=payload["chat_id"],
message_id=message_id,
)
except Exception as e:
logger.warning(f"编辑消息失败(streaming-break): {e!s}")
message_id = None # 重置消息 ID
delta = "" # 重置 delta
continue
@@ -214,24 +219,23 @@ class TelegramPlatformEvent(AstrMessageEvent):
delta += i.text
elif isinstance(i, Image):
image_path = await i.convert_to_file_path()
await self.client.send_photo(photo=image_path, **payload)
await self.client.send_photo(
photo=image_path, **cast(Any, payload)
)
continue
elif isinstance(i, File):
if i.file.startswith("https://"):
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, i.name)
await download_file(i.file, path)
i.file = path
path = await i.get_file()
name = i.name or os.path.basename(path)
await self.client.send_document(
document=i.file,
filename=i.name,
**payload,
document=path,
filename=name,
**cast(Any, payload),
)
continue
elif isinstance(i, Record):
path = await i.convert_to_file_path()
await self.client.send_voice(voice=path, **payload)
await self.client.send_voice(voice=path, **cast(Any, payload))
continue
else:
logger.warning(f"不支持的消息类型: {type(i)}")
@@ -260,7 +264,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
else:
# delta 长度一般不会大于 4096,因此这里直接发送
try:
msg = await self.client.send_message(text=delta, **payload)
msg = await self.client.send_message(
text=delta, **cast(Any, payload)
)
current_content = delta
except Exception as e:
logger.warning(f"发送消息失败(streaming): {e!s}")
@@ -274,7 +280,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
try:
markdown_text = telegramify_markdown.markdownify(
delta,
max_line_length=None,
normalize_whitespace=False,
)
await self.client.edit_message_text(
@@ -2,7 +2,7 @@ import asyncio
import os
import time
import uuid
from collections.abc import Awaitable, Callable
from collections.abc import Callable, Coroutine
from typing import Any
from astrbot import logger
@@ -207,7 +207,7 @@ class WebChatAdapter(Platform):
abm.raw_message = data
return abm
def run(self) -> Awaitable[Any]:
def run(self) -> Coroutine[Any, Any, None]:
async def callback(data: tuple):
abm = await self.convert_message(data)
await self.handle_msg(abm)
@@ -1,11 +1,12 @@
import base64
import json
import os
import shutil
import uuid
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import File, Image, Plain, Record
from astrbot.api.message_components import File, Image, Json, Plain, Record
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from .webchat_queue_mgr import webchat_queue_mgr
@@ -41,12 +42,20 @@ class WebChatMessageEvent(AstrMessageEvent):
await web_chat_back_queue.put(
{
"type": "plain",
"cid": cid,
"data": data,
"streaming": streaming,
"chain_type": message.type,
},
)
elif isinstance(comp, Json):
await web_chat_back_queue.put(
{
"type": "plain",
"data": json.dumps(comp.data, ensure_ascii=False),
"streaming": streaming,
"chain_type": message.type,
},
)
elif isinstance(comp, Image):
# save image to local
filename = f"{str(uuid.uuid4())}.jpg"
@@ -58,7 +67,6 @@ class WebChatMessageEvent(AstrMessageEvent):
await web_chat_back_queue.put(
{
"type": "image",
"cid": cid,
"data": data,
"streaming": streaming,
},
@@ -74,7 +82,6 @@ class WebChatMessageEvent(AstrMessageEvent):
await web_chat_back_queue.put(
{
"type": "record",
"cid": cid,
"data": data,
"streaming": streaming,
},
@@ -91,7 +98,6 @@ class WebChatMessageEvent(AstrMessageEvent):
await web_chat_back_queue.put(
{
"type": "file",
"cid": cid,
"data": data,
"streaming": streaming,
},
@@ -101,9 +107,9 @@ class WebChatMessageEvent(AstrMessageEvent):
return data
async def send(self, message: MessageChain):
async def send(self, message: MessageChain | None):
await WebChatMessageEvent._send(message, session_id=self.session_id)
await super().send(message)
await super().send(MessageChain([]))
async def send_streaming(self, generator, use_fallback: bool = False):
final_data = ""
@@ -111,18 +117,17 @@ class WebChatMessageEvent(AstrMessageEvent):
cid = self.session_id.split("!")[-1]
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
async for chain in generator:
if chain.type == "break" and final_data:
# 分割符
await web_chat_back_queue.put(
{
"type": "break", # break means a segment end
"data": final_data,
"streaming": True,
"cid": cid,
},
)
final_data = ""
continue
# if chain.type == "break" and final_data:
# # 分割符
# await web_chat_back_queue.put(
# {
# "type": "break", # break means a segment end
# "data": final_data,
# "streaming": True,
# },
# )
# final_data = ""
# continue
r = await WebChatMessageEvent._send(
chain,
@@ -142,7 +147,6 @@ class WebChatMessageEvent(AstrMessageEvent):
"data": final_data,
"reasoning": reasoning_content,
"streaming": True,
"cid": cid,
},
)
await super().send_streaming(generator, use_fallback)
@@ -4,6 +4,7 @@ import json
import os
import time
import traceback
from typing import cast
import aiohttp
import anyio
@@ -69,7 +70,7 @@ class WeChatPadProAdapter(Platform):
)
self.base_url = f"http://{self.host}:{self.port}"
self.auth_key = None # 用于保存生成的授权码
self.wxid = None # 用于保存登录成功后的 wxid
self.wxid: str | None = None # 用于保存登录成功后的 wxid
self.credentials_file = os.path.join(
get_astrbot_data_path(),
"wechatpadpro_credentials.json",
@@ -398,7 +399,7 @@ class WeChatPadProAdapter(Platform):
)
await asyncio.sleep(5)
async def handle_websocket_message(self, message: str):
async def handle_websocket_message(self, message: str | bytes):
"""处理从 WebSocket 接收到的消息。"""
logger.debug(f"收到 WebSocket 消息: {message}")
try:
@@ -430,10 +431,13 @@ class WeChatPadProAdapter(Platform):
async def convert_message(self, raw_message: dict) -> AstrBotMessage | None:
"""将 WeChatPadPro 原始消息转换为 AstrBotMessage。"""
if self.wxid is None:
logger.error("WeChatPadPro 适配器未登录或未获取到 wxid,无法处理消息。")
return None
abm = AstrBotMessage()
abm.raw_message = raw_message
abm.message_id = str(raw_message.get("msg_id"))
abm.timestamp = raw_message.get("create_time")
abm.timestamp = cast(int, raw_message.get("create_time"))
abm.self_id = self.wxid
if int(time.time()) - abm.timestamp > 180:
@@ -446,7 +450,7 @@ class WeChatPadProAdapter(Platform):
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
content = raw_message.get("content", {}).get("str", "")
push_content = raw_message.get("push_content", "")
msg_type = raw_message.get("msg_type")
msg_type = cast(int, raw_message.get("msg_type"))
abm.message_str = ""
abm.message = []
@@ -574,7 +578,7 @@ class WeChatPadProAdapter(Platform):
from_user_name: str,
to_user_name: str,
msg_id: int,
):
) -> dict | None:
"""下载原始图片。"""
url = f"{self.base_url}/message/GetMsgBigImg"
params = {"key": self.auth_key}
@@ -725,12 +729,15 @@ class WeChatPadProAdapter(Platform):
# 图片消息
from_user_name = raw_message.get("from_user_name", {}).get("str", "")
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
msg_id = raw_message.get("msg_id")
msg_id = cast(int, raw_message.get("msg_id"))
image_resp = await self._download_raw_image(
from_user_name,
to_user_name,
msg_id,
)
if image_resp is None:
logger.error(f"下载图片失败: msg_id={msg_id}")
return
image_bs64_data = (
image_resp.get("Data", {}).get("Data", {}).get("Buffer", None)
)
@@ -771,6 +778,9 @@ class WeChatPadProAdapter(Platform):
bufid = 0
to_user_name = raw_message.get("to_user_name", {}).get("str", "")
new_msg_id = raw_message.get("new_msg_id")
if new_msg_id is None:
logger.error("语音消息缺少 new_msg_id")
return
data_parser = GeweDataParser(
content=content,
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
@@ -778,6 +788,9 @@ class WeChatPadProAdapter(Platform):
)
voicemsg = data_parser._format_to_xml().find("voicemsg")
if voicemsg is None:
logger.error("无法从 XML 解析 voicemsg 节点")
return
bufid = voicemsg.get("bufid") or "0"
length = int(voicemsg.get("length") or 0)
voice_resp = await self.download_voice(
@@ -786,6 +799,9 @@ class WeChatPadProAdapter(Platform):
bufid=bufid,
length=length,
)
if voice_resp is None:
logger.error(f"下载语音失败: new_msg_id={new_msg_id}")
return
voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None)
if voice_bs64_data:
voice_bs64_data = base64.b64decode(voice_bs64_data)
@@ -827,7 +843,8 @@ class WeChatPadProAdapter(Platform):
try:
if self.ws_handle_task:
self.ws_handle_task.cancel()
self._shutdown_event.set()
if self._shutdown_event is not None:
self._shutdown_event.set()
except Exception:
pass
@@ -894,8 +911,8 @@ class WeChatPadProAdapter(Platform):
async def get_contact_details_list(
self,
room_wx_id_list: list[str] = None,
user_names: list[str] = None,
room_wx_id_list: list[str] | None = None,
user_names: list[str] | None = None,
) -> dict | None:
"""获取联系人详情列表。"""
if room_wx_id_list is None:
@@ -2,7 +2,8 @@ import asyncio
import os
import sys
import uuid
from typing import Any
from collections.abc import Awaitable, Callable
from typing import Any, cast
import quart
from requests import Response
@@ -40,7 +41,7 @@ else:
class WecomServer:
def __init__(self, event_queue: asyncio.Queue, config: dict):
self.server = quart.Quart(__name__)
self.port = int(config.get("port"))
self.port = int(cast(str, config.get("port")))
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
self.server.add_url_rule(
"/callback/command",
@@ -60,7 +61,7 @@ class WecomServer:
config["corpid"].strip(),
)
self.callback = None
self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
self.shutdown_event = asyncio.Event()
async def verify(self):
@@ -114,7 +115,7 @@ class WecomServer:
logger.error("解密失败,签名异常,请检查配置。")
raise
else:
msg = parse_message(xml)
msg = cast(BaseMessage, parse_message(xml))
logger.info(f"解析成功: {msg}")
if self.callback:
@@ -176,10 +177,10 @@ class WecomPlatformAdapter(Platform):
# inject
self.wechat_kf_api = WeChatKF(client=self.client)
self.wechat_kf_message_api = WeChatKFMessage(self.client)
self.client.kf = self.wechat_kf_api
self.client.kf_message = self.wechat_kf_message_api
self.client.__setattr__("kf", self.wechat_kf_api)
self.client.__setattr__("kf_message", self.wechat_kf_message_api)
self.client.API_BASE_URL = self.api_base_url
self.client.__setattr__("API_BASE_URL", self.api_base_url)
async def callback(msg: BaseMessage):
if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event":
@@ -278,37 +279,33 @@ class WecomPlatformAdapter(Platform):
async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None:
abm = AstrBotMessage()
if msg.type == "text":
assert isinstance(msg, TextMessage)
if isinstance(msg, TextMessage):
abm.message_str = msg.content
abm.self_id = str(msg.agent)
abm.message = [Plain(msg.content)]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
msg.source,
msg.source,
cast(str, msg.source),
cast(str, msg.source),
)
abm.message_id = msg.id
abm.timestamp = msg.time
abm.message_id = str(msg.id)
abm.timestamp = int(cast(int | str, msg.time))
abm.session_id = abm.sender.user_id
abm.raw_message = msg
elif msg.type == "image":
assert isinstance(msg, ImageMessage)
elif isinstance(msg, ImageMessage):
abm.message_str = "[图片]"
abm.self_id = str(msg.agent)
abm.message = [Image(file=msg.image, url=msg.image)]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
msg.source,
msg.source,
cast(str, msg.source),
cast(str, msg.source),
)
abm.message_id = msg.id
abm.timestamp = msg.time
abm.message_id = str(msg.id)
abm.timestamp = int(cast(int | str, msg.time))
abm.session_id = abm.sender.user_id
abm.raw_message = msg
elif msg.type == "voice":
assert isinstance(msg, VoiceMessage)
elif isinstance(msg, VoiceMessage):
resp: Response = await asyncio.get_event_loop().run_in_executor(
None,
self.client.media.download,
@@ -335,11 +332,11 @@ class WecomPlatformAdapter(Platform):
abm.message = [Record(file=path_wav, url=path_wav)]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
msg.source,
msg.source,
cast(str, msg.source),
cast(str, msg.source),
)
abm.message_id = msg.id
abm.timestamp = msg.time
abm.message_id = str(msg.id)
abm.timestamp = int(cast(int | str, msg.time))
abm.session_id = abm.sender.user_id
abm.raw_message = msg
else:
@@ -351,7 +348,7 @@ class WecomPlatformAdapter(Platform):
async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None:
msgtype = msg.get("msgtype")
external_userid = msg.get("external_userid")
external_userid = cast(str, msg.get("external_userid"))
abm = AstrBotMessage()
abm.raw_message = msg
abm.raw_message["_wechat_kf_flag"] = None # 方便处理
@@ -425,4 +422,4 @@ class WecomPlatformAdapter(Platform):
await self.server.server.shutdown()
except Exception as _:
pass
logger.info("企业微信 适配器已被优雅地关闭")
logger.info("企业微信 适配器已被关闭")
@@ -93,10 +93,10 @@ class WecomPlatformEvent(AstrMessageEvent):
if is_wechat_kf:
# 微信客服
kf_message_api = getattr(self.client, "kf_message", None)
if not kf_message_api:
if not isinstance(kf_message_api, WeChatKFMessage):
logger.warning("未找到微信客服发送消息方法。")
return
assert isinstance(kf_message_api, WeChatKFMessage)
user_id = self.get_sender_id()
for comp in message.chain:
if isinstance(comp, Plain):
@@ -39,7 +39,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
@staticmethod
async def _send(
message_chain: MessageChain,
message_chain: MessageChain | None,
stream_id: str,
queue_mgr: WecomAIQueueMgr,
streaming: bool = False,
@@ -90,7 +90,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
return data
async def send(self, message: MessageChain):
async def send(self, message: MessageChain | None):
"""发送消息"""
raw = self.message_obj.raw_message
assert isinstance(raw, dict), (
@@ -98,7 +98,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
)
stream_id = raw.get("stream_id", self.session_id)
await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr)
await super().send(message)
await super().send(MessageChain([]))
async def send_streaming(self, generator, use_fallback=False):
"""流式发送消息,参考webchat的send_streaming设计"""
@@ -1,7 +1,8 @@
import asyncio
import sys
import uuid
from typing import Any
from collections.abc import Awaitable, Callable
from typing import Any, cast
import quart
from requests import Response
@@ -36,7 +37,7 @@ else:
class WeixinOfficialAccountServer:
def __init__(self, event_queue: asyncio.Queue, config: dict):
self.server = quart.Quart(__name__)
self.port = int(config.get("port"))
self.port = int(cast(int | str, config.get("port")))
self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
self.token = config.get("token")
self.encoding_aes_key = config.get("encoding_aes_key")
@@ -55,7 +56,7 @@ class WeixinOfficialAccountServer:
self.event_queue = event_queue
self.callback = None
self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
self.shutdown_event = asyncio.Event()
async def verify(self):
@@ -114,6 +115,9 @@ class WeixinOfficialAccountServer:
raise
else:
msg = parse_message(xml)
if not msg:
logger.error("解析失败。msg为None。")
raise
logger.info(f"解析成功: {msg}")
if self.callback:
@@ -176,7 +180,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
self.config["secret"].strip(),
)
self.client.API_BASE_URL = self.api_base_url
self.client.__setattr__("API_BASE_URL", self.api_base_url)
# 微信公众号必须 5 秒内进行回复,否则会重试 3 次,我们需要对其进行消息排重
# msgid -> Future
@@ -188,11 +192,11 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
await self.convert_message(msg, None)
else:
if msg.id in self.wexin_event_workers:
future = self.wexin_event_workers[msg.id]
future = self.wexin_event_workers[str(cast(str | int, msg.id))]
logger.debug(f"duplicate message id checked: {msg.id}")
else:
future = asyncio.get_event_loop().create_future()
self.wexin_event_workers[msg.id] = future
self.wexin_event_workers[str(cast(str | int, msg.id))] = future
await self.convert_message(msg, future)
# I love shield so much!
result = await asyncio.wait_for(
@@ -200,7 +204,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
60,
) # wait for 60s
logger.debug(f"Got future result: {result}")
self.wexin_event_workers.pop(msg.id, None)
self.wexin_event_workers.pop(str(cast(str | int, msg.id)), None)
return result # xml. see weixin_offacc_event.py
except asyncio.TimeoutError:
pass
@@ -248,33 +252,33 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
async def convert_message(
self,
msg,
future: asyncio.Future = None,
future: asyncio.Future | None = None,
) -> AstrBotMessage | None:
abm = AstrBotMessage()
if isinstance(msg, TextMessage):
abm.message_str = msg.content
abm.message_str = cast(str, msg.content)
abm.self_id = str(msg.target)
abm.message = [Plain(msg.content)]
abm.message = [Plain(cast(str, msg.content))]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
msg.source,
msg.source,
cast(str, msg.source),
cast(str, msg.source),
)
abm.message_id = msg.id
abm.timestamp = msg.time
abm.message_id = str(cast(str | int, msg.id))
abm.timestamp = cast(int, msg.time)
abm.session_id = abm.sender.user_id
elif msg.type == "image":
assert isinstance(msg, ImageMessage)
abm.message_str = "[图片]"
abm.self_id = str(msg.target)
abm.message = [Image(file=msg.image, url=msg.image)]
abm.message = [Image(file=cast(str, msg.image), url=cast(str, msg.image))]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
msg.source,
msg.source,
cast(str, msg.source),
cast(str, msg.source),
)
abm.message_id = msg.id
abm.timestamp = msg.time
abm.message_id = str(cast(str | int, msg.id))
abm.timestamp = cast(int, msg.time)
abm.session_id = abm.sender.user_id
elif msg.type == "voice":
assert isinstance(msg, VoiceMessage)
@@ -306,15 +310,16 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
abm.message = [Record(file=path_wav, url=path_wav)]
abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember(
msg.source,
msg.source,
cast(str, msg.source),
cast(str, msg.source),
)
abm.message_id = msg.id
abm.timestamp = msg.time
abm.message_id = str(cast(str | int, msg.id))
abm.timestamp = cast(int, msg.time)
abm.session_id = abm.sender.user_id
else:
logger.warning(f"暂未实现的事件: {msg.type}")
future.set_result(None)
if future:
future.set_result(None)
return
# 很不优雅 :(
abm.raw_message = {
@@ -344,4 +349,4 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
await self.server.server.shutdown()
except Exception as _:
pass
logger.info("微信公众平台 适配器已被优雅地关闭")
logger.info("微信公众平台 适配器已被关闭")
@@ -1,5 +1,6 @@
import asyncio
import uuid
from typing import cast
from wechatpy import WeChatClient
from wechatpy.replies import ImageReply, TextReply, VoiceReply
@@ -85,7 +86,9 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
async def send(self, message: MessageChain):
message_obj = self.message_obj
active_send_mode = message_obj.raw_message.get("active_send_mode", False)
active_send_mode = cast(dict, message_obj.raw_message).get(
"active_send_mode", False
)
for comp in message.chain:
if isinstance(comp, Plain):
# Split long text messages if needed
@@ -96,10 +99,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
else:
reply = TextReply(
content=chunk,
message=self.message_obj.raw_message["message"],
message=cast(dict, self.message_obj.raw_message)["message"],
)
xml = reply.render()
future = self.message_obj.raw_message["future"]
future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future)
future.set_result(xml)
await asyncio.sleep(0.5) # Avoid sending too fast
@@ -125,10 +128,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
else:
reply = ImageReply(
media_id=response["media_id"],
message=self.message_obj.raw_message["message"],
message=cast(dict, self.message_obj.raw_message)["message"],
)
xml = reply.render()
future = self.message_obj.raw_message["future"]
future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future)
future.set_result(xml)
@@ -160,10 +163,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
else:
reply = VoiceReply(
media_id=response["media_id"],
message=self.message_obj.raw_message["message"],
message=cast(dict, self.message_obj.raw_message)["message"],
)
xml = reply.render()
future = self.message_obj.raw_message["future"]
future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future)
future.set_result(xml)
+73 -9
View File
@@ -1,3 +1,5 @@
from __future__ import annotations
import base64
import enum
import json
@@ -12,6 +14,7 @@ import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.core.agent.message import (
AssistantMessageSegment,
ContentPart,
ToolCall,
ToolCallMessageSegment,
)
@@ -90,6 +93,8 @@ class ProviderRequest:
"""会话 ID"""
image_urls: list[str] = field(default_factory=list)
"""图片 URL 列表"""
extra_user_content_parts: list[ContentPart] = field(default_factory=list)
"""额外的用户消息内容部分列表,用于在用户消息后添加额外的内容块(如系统提醒、指令等)。"""
func_tool: ToolSet | None = None
"""可用的函数工具"""
contexts: list[dict] = field(default_factory=list)
@@ -164,13 +169,23 @@ class ProviderRequest:
async def assemble_context(self) -> dict:
"""将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。"""
# 构建内容块列表
content_blocks = []
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
if self.prompt and self.prompt.strip():
content_blocks.append({"type": "text", "text": self.prompt})
elif self.image_urls:
# 如果没有文本但有图片,添加占位文本
content_blocks.append({"type": "text", "text": "[图片]"})
# 2. 额外的内容块(系统提醒、指令等)
if self.extra_user_content_parts:
for part in self.extra_user_content_parts:
content_blocks.append(part.model_dump())
# 3. 图片内容
if self.image_urls:
user_content = {
"role": "user",
"content": [
{"type": "text", "text": self.prompt if self.prompt else "[图片]"},
],
}
for image_url in self.image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
@@ -183,11 +198,21 @@ class ProviderRequest:
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
user_content["content"].append(
content_blocks.append(
{"type": "image_url", "image_url": {"url": image_data}},
)
return user_content
return {"role": "user", "content": self.prompt}
# 只有当只有一个来自 prompt 的文本块且没有额外内容块时,才降级为简单格式以保持向后兼容
if (
len(content_blocks) == 1
and content_blocks[0]["type"] == "text"
and not self.extra_user_content_parts
and not self.image_urls
):
return {"role": "user", "content": content_blocks[0]["text"]}
# 否则返回多模态格式
return {"role": "user", "content": content_blocks}
async def _encode_image_bs64(self, image_url: str) -> str:
"""将图片转换为 base64"""
@@ -199,6 +224,38 @@ class ProviderRequest:
return ""
@dataclass
class TokenUsage:
input_other: int = 0
"""The number of input tokens, excluding cached tokens."""
input_cached: int = 0
"""The number of input cached tokens."""
output: int = 0
"""The number of output tokens."""
@property
def total(self) -> int:
return self.input_other + self.input_cached + self.output
@property
def input(self) -> int:
return self.input_other + self.input_cached
def __add__(self, other: TokenUsage) -> TokenUsage:
return TokenUsage(
input_other=self.input_other + other.input_other,
input_cached=self.input_cached + other.input_cached,
output=self.output + other.output,
)
def __sub__(self, other: TokenUsage) -> TokenUsage:
return TokenUsage(
input_other=self.input_other - other.input_other,
input_cached=self.input_cached - other.input_cached,
output=self.output - other.output,
)
@dataclass
class LLMResponse:
role: str
@@ -227,6 +284,11 @@ class LLMResponse:
is_chunk: bool = False
"""Indicates if the response is a chunked response."""
id: str | None = None
"""The ID of the response. For chunked responses, it's the ID of the chunk; for non-chunked responses, it's the ID of the response."""
usage: TokenUsage | None = None
"""The usage of the response. For chunked responses, it's the usage of the chunk; for non-chunked responses, it's the usage of the response."""
def __init__(
self,
role: str,
@@ -241,6 +303,8 @@ class LLMResponse:
| AnthropicMessage
| None = None,
is_chunk: bool = False,
id: str | None = None,
usage: TokenUsage | None = None,
):
"""初始化 LLMResponse
+3 -3
View File
@@ -4,7 +4,7 @@ import asyncio
import copy
import json
import os
from collections.abc import Awaitable, Callable
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any
import aiohttp
@@ -118,7 +118,7 @@ class FunctionToolManager:
name: str,
func_args: list[dict],
desc: str,
handler: Callable[..., Awaitable[Any]],
handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
) -> FuncTool:
params = {
"type": "object", # hard-coded here
@@ -140,7 +140,7 @@ class FunctionToolManager:
name: str,
func_args: list,
desc: str,
handler: Callable[..., Awaitable[Any]],
handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
) -> None:
"""添加函数调用工具
+329 -163
View File
@@ -1,5 +1,7 @@
import asyncio
import copy
import traceback
from typing import Protocol, runtime_checkable
from astrbot.core import astrbot_config, logger, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
@@ -10,6 +12,7 @@ from .entities import ProviderType
from .provider import (
EmbeddingProvider,
Provider,
Providers,
RerankProvider,
STTProvider,
TTSProvider,
@@ -17,6 +20,11 @@ from .provider import (
from .register import llm_tools, provider_cls_map
@runtime_checkable
class HasInitialize(Protocol):
async def initialize(self) -> None: ...
class ProviderManager:
def __init__(
self,
@@ -25,10 +33,12 @@ class ProviderManager:
persona_mgr: PersonaManager,
):
self.reload_lock = asyncio.Lock()
self.resource_lock = asyncio.Lock()
self.persona_mgr = persona_mgr
self.acm = acm
config = acm.confs["default"]
self.providers_config: list = config["provider"]
self.provider_sources_config: list = config.get("provider_sources", [])
self.provider_settings: dict = config["provider_settings"]
self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
self.provider_tts_settings: dict = config.get("provider_tts_settings", {})
@@ -48,7 +58,7 @@ class ProviderManager:
"""加载的 Rerank Provider 的实例"""
self.inst_map: dict[
str,
Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider,
Providers,
] = {}
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
self.llm_tools = llm_tools
@@ -123,15 +133,13 @@ class ProviderManager:
self.curr_provider_inst = prov
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
async def get_provider_by_id(self, provider_id: str) -> Providers | None:
"""根据提供商 ID 获取提供商实例"""
return self.inst_map.get(provider_id)
def get_using_provider(
self,
provider_type: ProviderType,
umo=None,
) -> Provider | STTProvider | TTSProvider | None:
self, provider_type: ProviderType, umo=None
) -> Providers | None:
"""获取正在使用的提供商实例。
Args:
@@ -143,6 +151,7 @@ class ProviderManager:
"""
provider = None
provider_id = None
if umo:
provider_id = sp.get(
f"provider_perf_{provider_type.value}",
@@ -180,6 +189,12 @@ class ProviderManager:
)
else:
raise ValueError(f"Unknown provider type: {provider_type}")
if not provider and provider_id:
logger.warning(
f"没有找到 ID 为 {provider_id} 的提供商,这可能是由于您修改了提供商(模型)ID 导致的。"
)
return provider
async def initialize(self):
@@ -191,7 +206,6 @@ class ProviderManager:
logger.error(traceback.format_exc())
logger.error(e)
# 设置默认提供商
selected_provider_id = sp.get(
"curr_provider",
self.provider_settings.get("default_provider_id"),
@@ -210,22 +224,173 @@ class ProviderManager:
scope="global",
scope_id="global",
)
self.curr_provider_inst = self.inst_map.get(selected_provider_id)
temp_provider = (
self.inst_map.get(selected_provider_id)
if isinstance(selected_provider_id, str)
else None
)
self.curr_provider_inst = (
temp_provider if isinstance(temp_provider, Provider) else None
)
if not self.curr_provider_inst and self.provider_insts:
self.curr_provider_inst = self.provider_insts[0]
self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id)
temp_stt = (
self.inst_map.get(selected_stt_provider_id)
if isinstance(selected_stt_provider_id, str)
else None
)
self.curr_stt_provider_inst = (
temp_stt if isinstance(temp_stt, STTProvider) else None
)
if not self.curr_stt_provider_inst and self.stt_provider_insts:
self.curr_stt_provider_inst = self.stt_provider_insts[0]
self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id)
temp_tts = (
self.inst_map.get(selected_tts_provider_id)
if isinstance(selected_tts_provider_id, str)
else None
)
self.curr_tts_provider_inst = (
temp_tts if isinstance(temp_tts, TTSProvider) else None
)
if not self.curr_tts_provider_inst and self.tts_provider_insts:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
# 初始化 MCP Client 连接
asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients")
def dynamic_import_provider(self, type: str):
"""动态导入提供商适配器模块
Args:
type (str): 提供商请求类型
Raises:
ImportError: 如果提供商类型未知或无法导入对应模块则抛出异常
"""
match type:
case "openai_chat_completion":
from .sources.openai_source import (
ProviderOpenAIOfficial as ProviderOpenAIOfficial,
)
case "zhipu_chat_completion":
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
case "groq_chat_completion":
from .sources.groq_source import ProviderGroq as ProviderGroq
case "anthropic_chat_completion":
from .sources.anthropic_source import (
ProviderAnthropic as ProviderAnthropic,
)
case "googlegenai_chat_completion":
from .sources.gemini_source import (
ProviderGoogleGenAI as ProviderGoogleGenAI,
)
case "sensevoice_stt_selfhost":
from .sources.sensevoice_selfhosted_source import (
ProviderSenseVoiceSTTSelfHost as ProviderSenseVoiceSTTSelfHost,
)
case "openai_whisper_api":
from .sources.whisper_api_source import (
ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI,
)
case "openai_whisper_selfhost":
from .sources.whisper_selfhosted_source import (
ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost,
)
case "xinference_stt":
from .sources.xinference_stt_provider import (
ProviderXinferenceSTT as ProviderXinferenceSTT,
)
case "openai_tts_api":
from .sources.openai_tts_api_source import (
ProviderOpenAITTSAPI as ProviderOpenAITTSAPI,
)
case "edge_tts":
from .sources.edge_tts_source import (
ProviderEdgeTTS as ProviderEdgeTTS,
)
case "gsv_tts_selfhost":
from .sources.gsv_selfhosted_source import (
ProviderGSVTTS as ProviderGSVTTS,
)
case "gsvi_tts_api":
from .sources.gsvi_tts_source import (
ProviderGSVITTS as ProviderGSVITTS,
)
case "fishaudio_tts_api":
from .sources.fishaudio_tts_api_source import (
ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI,
)
case "dashscope_tts":
from .sources.dashscope_tts import (
ProviderDashscopeTTSAPI as ProviderDashscopeTTSAPI,
)
case "azure_tts":
from .sources.azure_tts_source import (
AzureTTSProvider as AzureTTSProvider,
)
case "minimax_tts_api":
from .sources.minimax_tts_api_source import (
ProviderMiniMaxTTSAPI as ProviderMiniMaxTTSAPI,
)
case "volcengine_tts":
from .sources.volcengine_tts import (
ProviderVolcengineTTS as ProviderVolcengineTTS,
)
case "gemini_tts":
from .sources.gemini_tts_source import (
ProviderGeminiTTSAPI as ProviderGeminiTTSAPI,
)
case "openai_embedding":
from .sources.openai_embedding_source import (
OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
)
case "gemini_embedding":
from .sources.gemini_embedding_source import (
GeminiEmbeddingProvider as GeminiEmbeddingProvider,
)
case "vllm_rerank":
from .sources.vllm_rerank_source import (
VLLMRerankProvider as VLLMRerankProvider,
)
case "xinference_rerank":
from .sources.xinference_rerank_source import (
XinferenceRerankProvider as XinferenceRerankProvider,
)
case "bailian_rerank":
from .sources.bailian_rerank_source import (
BailianRerankProvider as BailianRerankProvider,
)
def get_merged_provider_config(self, provider_config: dict) -> dict:
"""获取 provider 配置和 provider_source 配置合并后的结果
Returns:
dict: 合并后的 provider 配置key provider idvalue 为合并后的配置字典
"""
pc = copy.deepcopy(provider_config)
provider_source_id = pc.get("provider_source_id", "")
if provider_source_id:
provider_source = None
for ps in self.provider_sources_config:
if ps.get("id") == provider_source_id:
provider_source = ps
break
if provider_source:
# 合并配置,provider 的配置优先级更高
merged_config = {**provider_source, **pc}
# 保持 id 为 provider 的 id,而不是 source 的 id
merged_config["id"] = pc["id"]
pc = merged_config
return pc
async def load_provider(self, provider_config: dict):
# 如果 provider_source_id 存在且不为空,则从 provider_sources 中找到对应的配置并合并
provider_config = self.get_merged_provider_config(provider_config)
if not provider_config["enable"]:
logger.info(f"Provider {provider_config['id']} is disabled, skipping")
return
@@ -238,99 +403,7 @@ class ProviderManager:
# 动态导入
try:
match provider_config["type"]:
case "openai_chat_completion":
from .sources.openai_source import (
ProviderOpenAIOfficial as ProviderOpenAIOfficial,
)
case "zhipu_chat_completion":
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
case "groq_chat_completion":
from .sources.groq_source import ProviderGroq as ProviderGroq
case "anthropic_chat_completion":
from .sources.anthropic_source import (
ProviderAnthropic as ProviderAnthropic,
)
case "googlegenai_chat_completion":
from .sources.gemini_source import (
ProviderGoogleGenAI as ProviderGoogleGenAI,
)
case "sensevoice_stt_selfhost":
from .sources.sensevoice_selfhosted_source import (
ProviderSenseVoiceSTTSelfHost as ProviderSenseVoiceSTTSelfHost,
)
case "openai_whisper_api":
from .sources.whisper_api_source import (
ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI,
)
case "openai_whisper_selfhost":
from .sources.whisper_selfhosted_source import (
ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost,
)
case "xinference_stt":
from .sources.xinference_stt_provider import (
ProviderXinferenceSTT as ProviderXinferenceSTT,
)
case "openai_tts_api":
from .sources.openai_tts_api_source import (
ProviderOpenAITTSAPI as ProviderOpenAITTSAPI,
)
case "edge_tts":
from .sources.edge_tts_source import (
ProviderEdgeTTS as ProviderEdgeTTS,
)
case "gsv_tts_selfhost":
from .sources.gsv_selfhosted_source import (
ProviderGSVTTS as ProviderGSVTTS,
)
case "gsvi_tts_api":
from .sources.gsvi_tts_source import (
ProviderGSVITTS as ProviderGSVITTS,
)
case "fishaudio_tts_api":
from .sources.fishaudio_tts_api_source import (
ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI,
)
case "dashscope_tts":
from .sources.dashscope_tts import (
ProviderDashscopeTTSAPI as ProviderDashscopeTTSAPI,
)
case "azure_tts":
from .sources.azure_tts_source import (
AzureTTSProvider as AzureTTSProvider,
)
case "minimax_tts_api":
from .sources.minimax_tts_api_source import (
ProviderMiniMaxTTSAPI as ProviderMiniMaxTTSAPI,
)
case "volcengine_tts":
from .sources.volcengine_tts import (
ProviderVolcengineTTS as ProviderVolcengineTTS,
)
case "gemini_tts":
from .sources.gemini_tts_source import (
ProviderGeminiTTSAPI as ProviderGeminiTTSAPI,
)
case "openai_embedding":
from .sources.openai_embedding_source import (
OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
)
case "gemini_embedding":
from .sources.gemini_embedding_source import (
GeminiEmbeddingProvider as GeminiEmbeddingProvider,
)
case "vllm_rerank":
from .sources.vllm_rerank_source import (
VLLMRerankProvider as VLLMRerankProvider,
)
case "xinference_rerank":
from .sources.xinference_rerank_source import (
XinferenceRerankProvider as XinferenceRerankProvider,
)
case "bailian_rerank":
from .sources.bailian_rerank_source import (
BailianRerankProvider as BailianRerankProvider,
)
self.dynamic_import_provider(provider_config["type"])
except (ImportError, ModuleNotFoundError) as e:
logger.critical(
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。",
@@ -358,73 +431,103 @@ class ProviderManager:
provider_metadata.id = provider_config["id"]
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
# STT 任务
inst = cls_type(provider_config, self.provider_settings)
match provider_metadata.provider_type:
case ProviderType.SPEECH_TO_TEXT:
# STT 任务
if not issubclass(cls_type, STTProvider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of STTProvider"
)
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
if isinstance(inst, HasInitialize):
await inst.initialize()
self.stt_provider_insts.append(inst)
if (
self.provider_stt_settings.get("provider_id")
== provider_config["id"]
):
self.curr_stt_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
self.stt_provider_insts.append(inst)
if (
self.provider_stt_settings.get("provider_id")
== provider_config["id"]
):
self.curr_stt_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
)
if not self.curr_stt_provider_inst:
self.curr_stt_provider_inst = inst
case ProviderType.TEXT_TO_SPEECH:
# TTS 任务
if not issubclass(cls_type, TTSProvider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of TTSProvider"
)
inst = cls_type(provider_config, self.provider_settings)
if isinstance(inst, HasInitialize):
await inst.initialize()
self.tts_provider_insts.append(inst)
if (
self.provider_settings.get("provider_id")
== provider_config["id"]
):
self.curr_tts_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
)
if not self.curr_tts_provider_inst:
self.curr_tts_provider_inst = inst
case ProviderType.CHAT_COMPLETION:
# 文本生成任务
if not issubclass(cls_type, Provider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of Provider"
)
inst = cls_type(
provider_config,
self.provider_settings,
)
if not self.curr_stt_provider_inst:
self.curr_stt_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
# TTS 任务
inst = cls_type(provider_config, self.provider_settings)
if isinstance(inst, HasInitialize):
await inst.initialize()
if getattr(inst, "initialize", None):
await inst.initialize()
self.provider_insts.append(inst)
if (
self.provider_settings.get("default_provider_id")
== provider_config["id"]
):
self.curr_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
)
if not self.curr_provider_inst:
self.curr_provider_inst = inst
self.tts_provider_insts.append(inst)
if self.provider_settings.get("provider_id") == provider_config["id"]:
self.curr_tts_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
case ProviderType.EMBEDDING:
if not issubclass(cls_type, EmbeddingProvider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of EmbeddingProvider"
)
inst = cls_type(provider_config, self.provider_settings)
if isinstance(inst, HasInitialize):
await inst.initialize()
self.embedding_provider_insts.append(inst)
case ProviderType.RERANK:
if not issubclass(cls_type, RerankProvider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of RerankProvider"
)
inst = cls_type(provider_config, self.provider_settings)
if isinstance(inst, HasInitialize):
await inst.initialize()
self.rerank_provider_insts.append(inst)
case _:
# 未知供应商抛出异常,确保inst初始化
# Should be unreachable
raise Exception(
f"未知的提供商类型:{provider_metadata.provider_type}"
)
if not self.curr_tts_provider_inst:
self.curr_tts_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
# 文本生成任务
inst = cls_type(
provider_config,
self.provider_settings,
)
if getattr(inst, "initialize", None):
await inst.initialize()
self.provider_insts.append(inst)
if (
self.provider_settings.get("default_provider_id")
== provider_config["id"]
):
self.curr_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
)
if not self.curr_provider_inst:
self.curr_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.embedding_provider_insts.append(inst)
elif provider_metadata.provider_type == ProviderType.RERANK:
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.rerank_provider_insts.append(inst)
self.inst_map[provider_config["id"]] = inst
except Exception as e:
@@ -443,6 +546,7 @@ class ProviderManager:
# 和配置文件保持同步
self.providers_config = astrbot_config["provider"]
self.provider_sources_config = astrbot_config.get("provider_sources", [])
config_ids = [provider["id"] for provider in self.providers_config]
logger.info(f"providers in user's config: {config_ids}")
for key in list(self.inst_map.keys()):
@@ -514,6 +618,68 @@ class ProviderManager:
)
del self.inst_map[provider_id]
async def delete_provider(
self, provider_id: str | None = None, provider_source_id: str | None = None
):
"""Delete provider and/or provider source from config and terminate the instances. Config will be saved after deletion."""
async with self.resource_lock:
# delete from config
target_prov_ids = []
if provider_id:
target_prov_ids.append(provider_id)
else:
for prov in self.providers_config:
if prov.get("provider_source_id") == provider_source_id:
target_prov_ids.append(prov.get("id"))
config = self.acm.default_conf
for tpid in target_prov_ids:
await self.terminate_provider(tpid)
config["provider"] = [
prov for prov in config["provider"] if prov.get("id") != tpid
]
config.save_config()
logger.info(f"Provider {target_prov_ids} 已从配置中删除。")
async def update_provider(self, origin_provider_id: str, new_config: dict):
"""Update provider config and reload the instance. Config will be saved after update."""
async with self.resource_lock:
npid = new_config.get("id", None)
if not npid:
raise ValueError("New provider config must have an 'id' field")
config = self.acm.default_conf
for provider in config["provider"]:
if (
provider.get("id", None) == npid
and provider.get("id", None) != origin_provider_id
):
raise ValueError(f"Provider ID {npid} already exists")
# update config
for idx, provider in enumerate(config["provider"]):
if provider.get("id", None) == origin_provider_id:
config["provider"][idx] = new_config
break
else:
raise ValueError(f"Provider ID {origin_provider_id} not found")
config.save_config()
# reload instance
await self.reload(new_config)
async def create_provider(self, new_config: dict):
"""Add new provider config and load the instance. Config will be saved after addition."""
async with self.resource_lock:
npid = new_config.get("id", None)
if not npid:
raise ValueError("New provider config must have an 'id' field")
config = self.acm.default_conf
for provider in config["provider"]:
if provider.get("id", None) == npid:
raise ValueError(f"Provider ID {npid} already exists")
# add to config
config["provider"].append(new_config)
config.save_config()
# load instance
await self.load_provider(new_config)
async def terminate(self):
for provider_inst in self.provider_insts:
if hasattr(provider_inst, "terminate"):
+17 -2
View File
@@ -2,8 +2,9 @@ import abc
import asyncio
import os
from collections.abc import AsyncGenerator
from typing import TypeAlias, Union
from astrbot.core.agent.message import Message
from astrbot.core.agent.message import ContentPart, Message
from astrbot.core.agent.tool import ToolSet
from astrbot.core.provider.entities import (
LLMResponse,
@@ -14,6 +15,14 @@ from astrbot.core.provider.entities import (
from astrbot.core.provider.register import provider_cls_map
from astrbot.core.utils.astrbot_path import get_astrbot_path
Providers: TypeAlias = Union[
"Provider",
"STTProvider",
"TTSProvider",
"EmbeddingProvider",
"RerankProvider",
]
class AbstractProvider(abc.ABC):
"""Provider Abstract Class"""
@@ -94,6 +103,7 @@ class Provider(AbstractProvider):
system_prompt: str | None = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
**kwargs,
) -> LLMResponse:
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。
@@ -105,6 +115,7 @@ class Provider(AbstractProvider):
tools: tool set
contexts: 上下文 prompt 二选一使用
tool_calls_result: 回传给 LLM 的工具调用结果参考: https://platform.openai.com/docs/guides/function-calling
extra_user_content_parts: 额外的用户内容块列表用于在用户消息后添加额外的文本块如系统提醒指令等
kwargs: 其他参数
Notes:
@@ -124,6 +135,7 @@ class Provider(AbstractProvider):
system_prompt: str | None = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
@@ -135,6 +147,7 @@ class Provider(AbstractProvider):
tools: tool set
contexts: 上下文 prompt 二选一使用
tool_calls_result: 回传给 LLM 的工具调用结果参考: https://platform.openai.com/docs/guides/function-calling
extra_user_content_parts: 额外的用户内容块列表用于在用户消息后添加额外的文本块如系统提醒指令等
kwargs: 其他参数
Notes:
@@ -142,7 +155,9 @@ class Provider(AbstractProvider):
- 如果传入了 tools将会使用 tools 进行 Function-calling如果模型不支持 Function-calling将会抛出错误
"""
...
if False: # pragma: no cover - make this an async generator for typing
yield None # type: ignore
raise NotImplementedError()
async def pop_record(self, context: list):
"""弹出 context 第一条非系统提示词对话记录"""
+166 -47
View File
@@ -6,10 +6,13 @@ from mimetypes import guess_type
import anthropic
from anthropic import AsyncAnthropic
from anthropic.types import Message
from anthropic.types.message_delta_usage import MessageDeltaUsage
from anthropic.types.usage import Usage
from astrbot import logger
from astrbot.api.provider import Provider
from astrbot.core.provider.entities import LLMResponse
from astrbot.core.agent.message import ContentPart
from astrbot.core.provider.entities import LLMResponse, TokenUsage
from astrbot.core.provider.func_tool_manager import ToolSet
from astrbot.core.utils.io import download_image_by_url
@@ -45,7 +48,7 @@ class ProviderAnthropic(Provider):
base_url=self.base_url,
)
self.set_model(provider_config["model_config"]["model"])
self.set_model(provider_config.get("model", "unknown"))
def _prepare_payload(self, messages: list[dict]):
"""准备 Anthropic API 的请求 payload
@@ -107,12 +110,32 @@ class ProviderAnthropic(Provider):
return system_prompt, new_messages
def _extract_usage(self, usage: Usage) -> TokenUsage:
# https://docs.claude.com/en/docs/build-with-claude/prompt-caching#tracking-cache-performance
return TokenUsage(
input_other=usage.input_tokens or 0,
input_cached=usage.cache_read_input_tokens or 0,
output=usage.output_tokens,
)
def _update_usage(self, token_usage: TokenUsage, usage: MessageDeltaUsage) -> None:
if usage.input_tokens is not None:
token_usage.input_other = usage.input_tokens
if usage.cache_read_input_tokens is not None:
token_usage.input_cached = usage.cache_read_input_tokens
if usage.output_tokens is not None:
token_usage.output = usage.output_tokens
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
if tools:
if tool_list := tools.get_func_desc_anthropic_style():
payloads["tools"] = tool_list
completion = await self.client.messages.create(**payloads, stream=False)
extra_body = self.provider_config.get("custom_extra_body", {})
completion = await self.client.messages.create(
**payloads, stream=False, extra_body=extra_body
)
assert isinstance(completion, Message)
logger.debug(f"completion: {completion}")
@@ -131,6 +154,10 @@ class ProviderAnthropic(Provider):
llm_response.tools_call_args.append(content_block.input)
llm_response.tools_call_name.append(content_block.name)
llm_response.tools_call_ids.append(content_block.id)
llm_response.id = completion.id
llm_response.usage = self._extract_usage(completion.usage)
# TODO(Soulter): 处理 end_turn 情况
if not llm_response.completion_text and not llm_response.tools_call_args:
raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}")
@@ -151,10 +178,19 @@ class ProviderAnthropic(Provider):
# 用于累积最终结果
final_text = ""
final_tool_calls = []
id = None
usage = TokenUsage()
extra_body = self.provider_config.get("custom_extra_body", {})
async with self.client.messages.stream(**payloads) as stream:
async with self.client.messages.stream(
**payloads, extra_body=extra_body
) as stream:
assert isinstance(stream, anthropic.AsyncMessageStream)
async for event in stream:
if event.type == "message_start":
# the usage contains input token usage
id = event.message.id
usage = self._extract_usage(event.message.usage)
if event.type == "content_block_start":
if event.content_block.type == "text":
# 文本块开始
@@ -162,6 +198,8 @@ class ProviderAnthropic(Provider):
role="assistant",
completion_text="",
is_chunk=True,
usage=usage,
id=id,
)
elif event.content_block.type == "tool_use":
# 工具使用块开始,初始化缓冲区
@@ -179,6 +217,8 @@ class ProviderAnthropic(Provider):
role="assistant",
completion_text=event.delta.text,
is_chunk=True,
usage=usage,
id=id,
)
elif event.delta.type == "input_json_delta":
# 工具调用参数增量
@@ -215,6 +255,8 @@ class ProviderAnthropic(Provider):
tools_call_name=[tool_info["name"]],
tools_call_ids=[tool_info["id"]],
is_chunk=True,
usage=usage,
id=id,
)
except json.JSONDecodeError:
# JSON 解析失败,跳过这个工具调用
@@ -223,11 +265,17 @@ class ProviderAnthropic(Provider):
# 清理缓冲区
del tool_use_buffer[event.index]
elif event.type == "message_delta":
if event.usage:
self._update_usage(usage, event.usage)
# 返回最终的完整结果
final_response = LLMResponse(
role="assistant",
completion_text=final_text,
is_chunk=False,
usage=usage,
id=id,
)
if final_tool_calls:
@@ -249,13 +297,16 @@ class ProviderAnthropic(Provider):
system_prompt=None,
tool_calls_result=None,
model=None,
extra_user_content_parts=None,
**kwargs,
) -> LLMResponse:
if contexts is None:
contexts = []
new_record = None
if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls)
new_record = await self.assemble_context(
prompt, image_urls, extra_user_content_parts
)
context_query = self._ensure_message_to_dicts(contexts)
if new_record:
context_query.append(new_record)
@@ -277,10 +328,9 @@ class ProviderAnthropic(Provider):
system_prompt, new_messages = self._prepare_payload(context_query)
model_config = self.provider_config.get("model_config", {})
model_config["model"] = model or self.get_model()
model = model or self.get_model()
payloads = {"messages": new_messages, **model_config}
payloads = {"messages": new_messages, "model": model}
# Anthropic has a different way of handling system prompts
if system_prompt:
@@ -290,7 +340,6 @@ class ProviderAnthropic(Provider):
try:
llm_response = await self._query(payloads, func_tool)
except Exception as e:
# logger.error(f"发生了错误。Provider 配置如下: {model_config}")
raise e
return llm_response
@@ -305,13 +354,16 @@ class ProviderAnthropic(Provider):
system_prompt=None,
tool_calls_result=None,
model=None,
extra_user_content_parts=None,
**kwargs,
):
if contexts is None:
contexts = []
new_record = None
if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls)
new_record = await self.assemble_context(
prompt, image_urls, extra_user_content_parts
)
context_query = self._ensure_message_to_dicts(contexts)
if new_record:
context_query.append(new_record)
@@ -332,10 +384,9 @@ class ProviderAnthropic(Provider):
system_prompt, new_messages = self._prepare_payload(context_query)
model_config = self.provider_config.get("model_config", {})
model_config["model"] = model or self.get_model()
model = model or self.get_model()
payloads = {"messages": new_messages, **model_config}
payloads = {"messages": new_messages, "model": model}
# Anthropic has a different way of handling system prompts
if system_prompt:
@@ -344,48 +395,116 @@ class ProviderAnthropic(Provider):
async for llm_response in self._query_stream(payloads, func_tool):
yield llm_response
async def assemble_context(self, text: str, image_urls: list[str] | None = None):
async def assemble_context(
self,
text: str,
image_urls: list[str] | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
):
"""组装上下文,支持文本和图片"""
if not image_urls:
return {"role": "user", "content": text}
content = []
content.append({"type": "text", "text": text})
for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_path)
else:
image_data = await self.encode_image_bs64(image_url)
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
if text:
content.append({"type": "text", "text": text})
elif image_urls:
# 如果没有文本但有图片,添加占位文本
content.append({"type": "text", "text": "[图片]"})
elif extra_user_content_parts:
# 如果只有额外内容块,也需要添加占位文本
content.append({"type": "text", "text": " "})
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
# 2. 额外的内容块(系统提醒、指令等)
if extra_user_content_parts:
for block in extra_user_content_parts:
block_type = block.get("type")
# Get mime type for the image
mime_type, _ = guess_type(image_url)
if not mime_type:
mime_type = "image/jpeg" # Default to JPEG if can't determine
if block_type == "text":
# 文本直接添加
content.append(block)
content.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": (
image_data.split("base64,")[1]
if "base64," in image_data
else image_data
),
elif block_type == "image_url":
# 转换 OpenAI 格式的图片为 Anthropic 格式
image_url_data = block.get("image_url", {})
if isinstance(image_url_data, dict):
url = image_url_data.get("url", "")
else:
# 兼容直接传 URL 字符串的情况
url = str(image_url_data)
if url and url.startswith("data:"):
try:
# 提取 MIME 类型和 base64 数据
mime_type = url.split(":")[1].split(";")[0]
base64_data = (
url.split("base64,")[1] if "base64," in url else url
)
content.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": base64_data,
},
}
)
except Exception as e:
logger.warning(f"转换 image_url 到 Anthropic 格式失败: {e}")
else:
logger.warning(f"image_url 不是有效的 data URI: {url[:50]}...")
else:
# 其他类型(如 audio_urlAnthropic 不支持,记录警告
logger.debug(f"Anthropic 不支持的内容类型 '{block_type}',已忽略")
# 3. 图片内容
if image_urls:
for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_path)
else:
image_data = await self.encode_image_bs64(image_url)
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
# Get mime type for the image
mime_type, _ = guess_type(image_url)
if not mime_type:
mime_type = "image/jpeg" # Default to JPEG if can't determine
content.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": (
image_data.split("base64,")[1]
if "base64," in image_data
else image_data
),
},
},
},
)
)
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
if (
text
and not extra_user_content_parts
and not image_urls
and len(content) == 1
and content[0]["type"] == "text"
):
return {"role": "user", "content": content[0]["text"]}
# 否则返回多模态格式
return {"role": "user", "content": content}
async def encode_image_bs64(self, image_url: str) -> str:
@@ -29,15 +29,24 @@ class OTTSProvider:
self.last_sync_time = 0
self.timeout = Timeout(10.0)
self.retry_count = 3
self.client = None
self._client: AsyncClient | None = None
@property
def client(self) -> AsyncClient:
if self._client is None:
raise RuntimeError(
"Client not initialized. Please use 'async with' context."
)
return self._client
async def __aenter__(self):
self.client = AsyncClient(timeout=self.timeout)
self._client = AsyncClient(timeout=self.timeout)
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.client:
await self.client.aclose()
if self._client:
await self._client.aclose()
self._client = None
async def _sync_time(self):
try:
@@ -90,6 +99,7 @@ class OTTSProvider:
if attempt == self.retry_count - 1:
raise RuntimeError(f"OTTS请求失败: {e!s}") from e
await asyncio.sleep(0.5 * (attempt + 1))
raise RuntimeError("OTTS未返回音频文件")
class AzureNativeProvider(TTSProvider):
@@ -105,7 +115,7 @@ class AzureNativeProvider(TTSProvider):
self.endpoint = (
f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
)
self.client = None
self._client: AsyncClient | None = None
self.token = None
self.token_expire = 0
self.voice_params = {
@@ -116,8 +126,16 @@ class AzureNativeProvider(TTSProvider):
"volume": provider_config.get("azure_tts_volume", "100"),
}
@property
def client(self) -> AsyncClient:
if self._client is None:
raise RuntimeError(
"Client not initialized. Please use 'async with' context."
)
return self._client
async def __aenter__(self):
self.client = AsyncClient(
self._client = AsyncClient(
headers={
"User-Agent": f"AstrBot/{VERSION}",
"Content-Type": "application/ssml+xml",
@@ -127,8 +145,9 @@ class AzureNativeProvider(TTSProvider):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.client:
await self.client.aclose()
if self._client:
await self._client.aclose()
self._client = None
async def _refresh_token(self):
token_url = (
@@ -181,8 +200,11 @@ class AzureTTSProvider(TTSProvider):
key_value = provider_config.get("azure_tts_subscription_key", "")
self.provider = self._parse_provider(key_value, provider_config)
def _parse_provider(self, key_value: str, config: dict) -> TTSProvider:
def _parse_provider(
self, key_value: str, config: dict
) -> OTTSProvider | AzureNativeProvider:
if key_value.lower().startswith("other["):
json_str = ""
try:
match = re.match(r"other\[(.*)\]", key_value, re.DOTALL)
if not match:
@@ -177,6 +177,10 @@ class BailianRerankProvider(RerankProvider):
Returns:
重排序结果列表
"""
if not self.client:
logger.error("百炼 Rerank 客户端会话已关闭,返回空结果")
return []
if not documents:
logger.warning("文档列表为空,返回空结果")
return []
@@ -36,7 +36,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
super().__init__(provider_config, provider_settings)
self.chosen_api_key: str = provider_config.get("api_key", "")
self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella")
self.set_model(provider_config.get("model"))
self.set_model(provider_config["model"])
self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000
dashscope.api_key = self.chosen_api_key
@@ -71,9 +71,10 @@ class ProviderDashscopeTTSAPI(TTSProvider):
kwargs = {
"model": model,
"text": text,
"messages": None,
"api_key": self.chosen_api_key,
"voice": self.voice or "Cherry",
"text": text,
}
if not self.voice:
logging.warning(
@@ -67,7 +67,7 @@ class ProviderEdgeTTS(TTSProvider):
from pyffmpeg import FFmpeg
ff = FFmpeg()
ff.convert(input=mp3_path, output=wav_path)
ff.convert(input_file=mp3_path, output_file=wav_path)
except Exception as e:
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
# use ffmpeg command line
@@ -59,9 +59,9 @@ class ProviderFishAudioTTSAPI(TTSProvider):
self.headers = {
"Authorization": f"Bearer {self.chosen_api_key}",
}
self.set_model(provider_config.get("model"))
self.set_model(provider_config["model"])
async def _get_reference_id_by_character(self, character: str) -> str:
async def _get_reference_id_by_character(self, character: str) -> str | None:
"""获取角色的reference_id
Args:
@@ -109,7 +109,7 @@ class ProviderFishAudioTTSAPI(TTSProvider):
pattern = r"^[a-fA-F0-9]{32}$"
return bool(re.match(pattern, reference_id.strip()))
async def _generate_request(self, text: str) -> dict:
async def _generate_request(self, text: str) -> ServeTTSRequest:
# 向前兼容逻辑:优先使用reference_id,如果没有则使用角色名称查询
if self.reference_id and self.reference_id.strip():
# 验证reference_id格式
@@ -146,5 +146,6 @@ class ProviderFishAudioTTSAPI(TTSProvider):
async for chunk in response.aiter_bytes():
f.write(chunk)
return path
text = await response.aread()
body = await response.aread()
text = body.decode("utf-8", errors="replace")
raise Exception(f"Fish Audio API请求失败: {text}")
@@ -1,3 +1,5 @@
from typing import cast
from google import genai
from google.genai import types
from google.genai.errors import APIError
@@ -18,8 +20,8 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
self.provider_config = provider_config
self.provider_settings = provider_settings
api_key: str = provider_config.get("embedding_api_key")
api_base: str = provider_config.get("embedding_api_base")
api_key: str = provider_config["embedding_api_key"]
api_base: str = provider_config["embedding_api_base"]
timeout: int = int(provider_config.get("timeout", 20))
http_options = types.HttpOptions(timeout=timeout * 1000)
@@ -41,18 +43,26 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
model=self.model,
contents=text,
)
assert result.embeddings is not None
assert result.embeddings[0].values is not None
return result.embeddings[0].values
except APIError as e:
raise Exception(f"Gemini Embedding API请求失败: {e.message}")
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
"""批量获取文本的嵌入"""
try:
result = await self.client.models.embed_content(
model=self.model,
contents=texts,
contents=cast(types.ContentListUnion, text),
)
return [embedding.values for embedding in result.embeddings]
assert result.embeddings is not None
embeddings: list[list[float]] = []
for embedding in result.embeddings:
assert embedding.values is not None
embeddings.append(embedding.values)
return embeddings
except APIError as e:
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")
+144 -53
View File
@@ -4,6 +4,7 @@ import json
import logging
import random
from collections.abc import AsyncGenerator
from typing import cast
from google import genai
from google.genai import types
@@ -12,8 +13,9 @@ from google.genai.errors import APIError
import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.api.provider import Provider
from astrbot.core.agent.message import ContentPart
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse
from astrbot.core.provider.entities import LLMResponse, TokenUsage
from astrbot.core.provider.func_tool_manager import ToolSet
from astrbot.core.utils.io import download_image_by_url
@@ -67,7 +69,7 @@ class ProviderGoogleGenAI(Provider):
self.api_base = self.api_base[:-1]
self._init_client()
self.set_model(provider_config["model_config"]["model"])
self.set_model(provider_config.get("model", "unknown"))
self._init_safety_settings()
def _init_client(self) -> None:
@@ -126,18 +128,18 @@ class ProviderGoogleGenAI(Provider):
) -> types.GenerateContentConfig:
"""准备查询配置"""
if not modalities:
modalities = ["Text"]
modalities = ["TEXT"]
# 流式输出不支持图片模态
if (
self.provider_settings.get("streaming_response", False)
and "Image" in modalities
and "IMAGE" in modalities
):
logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
modalities = ["Text"]
modalities = ["TEXT"]
tool_list = []
model_name = self.get_model()
tool_list: list[types.Tool] | None = []
model_name = cast(str, payloads.get("model", self.get_model()))
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
native_search = self.provider_config.get("gm_native_search", False)
url_context = self.provider_config.get("gm_url_context", False)
@@ -196,6 +198,53 @@ class ProviderGoogleGenAI(Provider):
types.Tool(function_declarations=func_desc["function_declarations"]),
]
# oper thinking config
thinking_config = None
if model_name in [
"gemini-2.5-pro",
"gemini-2.5-pro-preview",
"gemini-2.5-flash",
"gemini-2.5-flash-preview",
"gemini-2.5-flash-lite",
"gemini-2.5-flash-lite-preview",
"gemini-robotics-er-1.5-preview",
"gemini-live-2.5-flash-preview-native-audio-09-2025",
]:
# The thinkingBudget parameter, introduced with the Gemini 2.5 series
thinking_budget = self.provider_config.get("gm_thinking_config", {}).get(
"budget", 0
)
if thinking_budget is not None:
thinking_config = types.ThinkingConfig(
thinking_budget=thinking_budget,
)
elif model_name in [
"gemini-3-pro",
"gemini-3-pro-preview",
"gemini-3-flash",
"gemini-3-flash-preview",
"gemini-3-flash-lite",
"gemini-3-flash-lite-preview",
]:
# The thinkingLevel parameter, recommended for Gemini 3 models and onwards
# Gemini 2.5 series models don't support thinkingLevel; use thinkingBudget instead.
thinking_level = self.provider_config.get("gm_thinking_config", {}).get(
"level", "HIGH"
)
if thinking_level and isinstance(thinking_level, str):
thinking_level = thinking_level.upper()
if thinking_level not in ["MINIMAL", "LOW", "MEDIUM", "HIGH"]:
logger.warning(
f"Invalid thinking level: {thinking_level}, using HIGH"
)
thinking_level = "HIGH"
level = types.ThinkingLevel(thinking_level)
thinking_config = types.ThinkingConfig()
if not hasattr(types.ThinkingConfig, "thinking_level"):
setattr(types.ThinkingConfig, "thinking_level", level)
else:
thinking_config.thinking_level = level
return types.GenerateContentConfig(
system_instruction=system_instruction,
temperature=temperature,
@@ -213,24 +262,9 @@ class ProviderGoogleGenAI(Provider):
logprobs=payloads.get("logprobs"),
seed=payloads.get("seed"),
response_modalities=modalities,
tools=tool_list,
tools=cast(types.ToolListUnion | None, tool_list),
safety_settings=self.safety_settings if self.safety_settings else None,
thinking_config=(
types.ThinkingConfig(
thinking_budget=min(
int(
self.provider_config.get("gm_thinking_config", {}).get(
"budget",
0,
),
),
24576,
),
)
if "gemini-2.5-flash" in self.get_model()
and hasattr(types.ThinkingConfig, "thinking_budget")
else None
),
thinking_config=thinking_config,
automatic_function_calling=types.AutomaticFunctionCallingConfig(
disable=True,
),
@@ -257,6 +291,7 @@ class ProviderGoogleGenAI(Provider):
content_cls: type[types.Content],
) -> None:
if contents and isinstance(contents[-1], content_cls):
assert contents[-1].parts is not None
contents[-1].parts.extend(part)
else:
contents.append(content_cls(parts=part))
@@ -345,6 +380,16 @@ class ProviderGoogleGenAI(Provider):
]
return "".join(thought_buf).strip()
def _extract_usage(
self, usage_metadata: types.GenerateContentResponseUsageMetadata
) -> TokenUsage:
"""Extract usage from candidate"""
return TokenUsage(
input_other=usage_metadata.prompt_token_count or 0,
input_cached=usage_metadata.cached_content_token_count or 0,
output=usage_metadata.candidates_token_count or 0,
)
def _process_content_parts(
self,
candidate: types.Candidate,
@@ -429,9 +474,11 @@ class ProviderGoogleGenAI(Provider):
None,
)
modalities = ["Text"]
model = payloads.get("model", self.get_model())
modalities = ["TEXT"]
if self.provider_config.get("gm_resp_image_modal", False):
modalities.append("Image")
modalities.append("IMAGE")
conversation = self._prepare_conversation(payloads)
temperature = payloads.get("temperature", 0.7)
@@ -447,8 +494,8 @@ class ProviderGoogleGenAI(Provider):
temperature,
)
result = await self.client.models.generate_content(
model=self.get_model(),
contents=conversation,
model=model,
contents=cast(types.ContentListUnion, conversation),
config=config,
)
logger.debug(f"genai result: {result}")
@@ -473,11 +520,11 @@ class ProviderGoogleGenAI(Provider):
e.message = ""
if "Developer instruction is not enabled" in e.message:
logger.warning(
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)",
f"{model} 不支持 system prompt,已自动去除(影响人格设置)",
)
system_instruction = None
elif "Function calling is not enabled" in e.message:
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
logger.warning(f"{model} 不支持函数调用,已自动去除")
tools = None
elif (
"Multi-modal output is not supported" in e.message
@@ -486,9 +533,9 @@ class ProviderGoogleGenAI(Provider):
or "only supports text output" in e.message
):
logger.warning(
f"{self.get_model()} 不支持多模态输出,降级为文本模态",
f"{model} 不支持多模态输出,降级为文本模态",
)
modalities = ["Text"]
modalities = ["TEXT"]
else:
raise
continue
@@ -499,6 +546,9 @@ class ProviderGoogleGenAI(Provider):
result.candidates[0],
llm_response,
)
llm_response.id = result.response_id
if result.usage_metadata:
llm_response.usage = self._extract_usage(result.usage_metadata)
return llm_response
async def _query_stream(
@@ -511,7 +561,7 @@ class ProviderGoogleGenAI(Provider):
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
None,
)
model = payloads.get("model", self.get_model())
conversation = self._prepare_conversation(payloads)
result = None
@@ -523,8 +573,8 @@ class ProviderGoogleGenAI(Provider):
system_instruction,
)
result = await self.client.models.generate_content_stream(
model=self.get_model(),
contents=conversation,
model=model,
contents=cast(types.ContentListUnion, conversation),
config=config,
)
break
@@ -533,11 +583,11 @@ class ProviderGoogleGenAI(Provider):
e.message = ""
if "Developer instruction is not enabled" in e.message:
logger.warning(
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)",
f"{model} 不支持 system prompt,已自动去除(影响人格设置)",
)
system_instruction = None
elif "Function calling is not enabled" in e.message:
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
logger.warning(f"{model} 不支持函数调用,已自动去除")
tools = None
else:
raise
@@ -567,6 +617,9 @@ class ProviderGoogleGenAI(Provider):
chunk.candidates[0],
llm_response,
)
llm_response.id = chunk.response_id
if chunk.usage_metadata:
llm_response.usage = self._extract_usage(chunk.usage_metadata)
yield llm_response
return
@@ -594,6 +647,9 @@ class ProviderGoogleGenAI(Provider):
chunk.candidates[0],
final_response,
)
final_response.id = chunk.response_id
if chunk.usage_metadata:
final_response.usage = self._extract_usage(chunk.usage_metadata)
break
# Yield final complete response with accumulated text
@@ -625,13 +681,16 @@ class ProviderGoogleGenAI(Provider):
system_prompt=None,
tool_calls_result=None,
model=None,
extra_user_content_parts=None,
**kwargs,
) -> LLMResponse:
if contexts is None:
contexts = []
new_record = None
if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls)
new_record = await self.assemble_context(
prompt, image_urls, extra_user_content_parts
)
context_query = self._ensure_message_to_dicts(contexts)
if new_record:
context_query.append(new_record)
@@ -650,10 +709,9 @@ class ProviderGoogleGenAI(Provider):
for tcr in tool_calls_result:
context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
model_config["model"] = model or self.get_model()
model = model or self.get_model()
payloads = {"messages": context_query, **model_config}
payloads = {"messages": context_query, "model": model}
retry = 10
keys = self.api_keys.copy()
@@ -678,13 +736,16 @@ class ProviderGoogleGenAI(Provider):
system_prompt=None,
tool_calls_result=None,
model=None,
extra_user_content_parts=None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
if contexts is None:
contexts = []
new_record = None
if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls)
new_record = await self.assemble_context(
prompt, image_urls, extra_user_content_parts
)
context_query = self._ensure_message_to_dicts(contexts)
if new_record:
context_query.append(new_record)
@@ -703,10 +764,9 @@ class ProviderGoogleGenAI(Provider):
for tcr in tool_calls_result:
context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
model_config["model"] = model or self.get_model()
model = model or self.get_model()
payloads = {"messages": context_query, **model_config}
payloads = {"messages": context_query, "model": model}
retry = 10
keys = self.api_keys.copy()
@@ -744,13 +804,33 @@ class ProviderGoogleGenAI(Provider):
self.chosen_api_key = key
self._init_client()
async def assemble_context(self, text: str, image_urls: list[str] | None = None):
async def assemble_context(
self,
text: str,
image_urls: list[str] | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
):
"""组装上下文。"""
# 构建内容块列表
content_blocks = []
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
if text:
content_blocks.append({"type": "text", "text": text})
elif image_urls:
# 如果没有文本但有图片,添加占位文本
content_blocks.append({"type": "text", "text": "[图片]"})
elif extra_user_content_parts:
# 如果只有额外内容块,也需要添加占位文本
content_blocks.append({"type": "text", "text": " "})
# 2. 额外的内容块(系统提醒、指令等)
if extra_user_content_parts:
for part in extra_user_content_parts:
content_blocks.append(part.model_dump())
# 3. 图片内容
if image_urls:
user_content = {
"role": "user",
"content": [{"type": "text", "text": text if text else "[图片]"}],
}
for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
@@ -763,14 +843,25 @@ class ProviderGoogleGenAI(Provider):
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
user_content["content"].append(
content_blocks.append(
{
"type": "image_url",
"image_url": {"url": image_data},
},
)
return user_content
return {"role": "user", "content": text}
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
if (
text
and not extra_user_content_parts
and not image_urls
and len(content_blocks) == 1
and content_blocks[0]["type"] == "text"
):
return {"role": "user", "content": content_blocks[0]["text"]}
# 否则返回多模态格式
return {"role": "user", "content": content_blocks}
async def encode_image_bs64(self, image_url: str) -> str:
"""将图片转换为 base64"""
@@ -87,7 +87,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
return json.dumps(dict_body)
async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]:
async def _call_tts_stream(self, text: str) -> AsyncIterator[str]:
"""进行流式请求"""
try:
async with (
@@ -117,7 +117,9 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
data = json.loads(message[6:])
if "extra_info" in data:
continue
audio = data.get("data", {}).get("audio")
audio: str | None = data.get("data", {}).get(
"audio"
)
if audio is not None:
yield audio
except json.JSONDecodeError:
@@ -30,9 +30,9 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
embedding = await self.client.embeddings.create(input=text, model=self.model)
return embedding.data[0].embedding
async def get_embeddings(self, texts: list[str]) -> list[list[float]]:
async def get_embeddings(self, text: list[str]) -> list[list[float]]:
"""批量获取文本的嵌入"""
embeddings = await self.client.embeddings.create(input=texts, model=self.model)
embeddings = await self.client.embeddings.create(input=text, model=self.model)
return [item.embedding for item in embeddings.data]
def get_dim(self) -> int:
+68 -15
View File
@@ -12,14 +12,15 @@ from openai._exceptions import NotFoundError
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.completion_usage import CompletionUsage
import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.api.provider import Provider
from astrbot.core.agent.message import Message
from astrbot.core.agent.message import ContentPart, Message
from astrbot.core.agent.tool import ToolSet
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult
from astrbot.core.utils.io import download_image_by_url
from ..register import register_provider_adapter
@@ -68,8 +69,7 @@ class ProviderOpenAIOfficial(Provider):
self.client.chat.completions.create,
).parameters.keys()
model_config = provider_config.get("model_config", {})
model = model_config.get("model", "unknown")
model = provider_config.get("model", "unknown")
self.set_model(model)
self.reasoning_key = "reasoning_content"
@@ -208,6 +208,7 @@ class ProviderOpenAIOfficial(Provider):
# handle the content delta
reasoning = self._extract_reasoning_content(chunk)
_y = False
llm_response.id = chunk.id
if reasoning:
llm_response.reasoning_content = reasoning
_y = True
@@ -217,6 +218,8 @@ class ProviderOpenAIOfficial(Provider):
chain=[Comp.Plain(completion_text)],
)
_y = True
if chunk.usage:
llm_response.usage = self._extract_usage(chunk.usage)
if _y:
yield llm_response
@@ -245,6 +248,15 @@ class ProviderOpenAIOfficial(Provider):
reasoning_text = str(reasoning_attr)
return reasoning_text
def _extract_usage(self, usage: CompletionUsage) -> TokenUsage:
ptd = usage.prompt_tokens_details
cached = ptd.cached_tokens if ptd and ptd.cached_tokens else 0
return TokenUsage(
input_other=usage.prompt_tokens - cached,
input_cached=ptd.cached_tokens if ptd and ptd.cached_tokens else 0,
output=usage.completion_tokens,
)
async def _parse_openai_completion(
self, completion: ChatCompletion, tools: ToolSet | None
) -> LLMResponse:
@@ -284,6 +296,10 @@ class ProviderOpenAIOfficial(Provider):
if isinstance(tool_call, str):
# workaround for #1359
tool_call = json.loads(tool_call)
if tools is None:
# 工具集未提供
# Should be unreachable
raise Exception("工具集未提供")
for tool in tools.func_list:
if (
tool_call.type == "function"
@@ -317,6 +333,10 @@ class ProviderOpenAIOfficial(Provider):
raise Exception(f"API 返回的 completion 无法解析:{completion}")
llm_response.raw_completion = completion
llm_response.id = completion.id
if completion.usage:
llm_response.usage = self._extract_usage(completion.usage)
return llm_response
@@ -328,6 +348,7 @@ class ProviderOpenAIOfficial(Provider):
system_prompt: str | None = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
**kwargs,
) -> tuple:
"""准备聊天所需的有效载荷和上下文"""
@@ -335,7 +356,9 @@ class ProviderOpenAIOfficial(Provider):
contexts = []
new_record = None
if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls)
new_record = await self.assemble_context(
prompt, image_urls, extra_user_content_parts
)
context_query = self._ensure_message_to_dicts(contexts)
if new_record:
context_query.append(new_record)
@@ -354,10 +377,9 @@ class ProviderOpenAIOfficial(Provider):
for tcr in tool_calls_result:
context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
model_config["model"] = model or self.get_model()
model = model or self.get_model()
payloads = {"messages": context_query, **model_config}
payloads = {"messages": context_query, "model": model}
# xAI origin search tool inject
self._maybe_inject_xai_search(payloads, **kwargs)
@@ -457,6 +479,7 @@ class ProviderOpenAIOfficial(Provider):
system_prompt=None,
tool_calls_result=None,
model=None,
extra_user_content_parts=None,
**kwargs,
) -> LLMResponse:
payloads, context_query = await self._prepare_chat_payload(
@@ -466,6 +489,7 @@ class ProviderOpenAIOfficial(Provider):
system_prompt,
tool_calls_result,
model=model,
extra_user_content_parts=extra_user_content_parts,
**kwargs,
)
@@ -520,6 +544,7 @@ class ProviderOpenAIOfficial(Provider):
system_prompt=None,
tool_calls_result=None,
model=None,
extra_user_content_parts=None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""流式对话,与服务商交互并逐步返回结果"""
@@ -530,6 +555,7 @@ class ProviderOpenAIOfficial(Provider):
system_prompt,
tool_calls_result,
model=model,
extra_user_content_parts=extra_user_content_parts,
**kwargs,
)
@@ -605,13 +631,29 @@ class ProviderOpenAIOfficial(Provider):
self,
text: str,
image_urls: list[str] | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
) -> dict:
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
# 构建内容块列表
content_blocks = []
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
if text:
content_blocks.append({"type": "text", "text": text})
elif image_urls:
# 如果没有文本但有图片,添加占位文本
content_blocks.append({"type": "text", "text": "[图片]"})
elif extra_user_content_parts:
# 如果只有额外内容块,也需要添加占位文本
content_blocks.append({"type": "text", "text": " "})
# 2. 额外的内容块(系统提醒、指令等)
if extra_user_content_parts:
for part in extra_user_content_parts:
content_blocks.append(part.model_dump())
# 3. 图片内容
if image_urls:
user_content = {
"role": "user",
"content": [{"type": "text", "text": text if text else "[图片]"}],
}
for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
@@ -624,14 +666,25 @@ class ProviderOpenAIOfficial(Provider):
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
user_content["content"].append(
content_blocks.append(
{
"type": "image_url",
"image_url": {"url": image_data},
},
)
return user_content
return {"role": "user", "content": text}
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
if (
text
and not extra_user_content_parts
and not image_urls
and len(content_blocks) == 1
and content_blocks[0]["type"] == "text"
):
return {"role": "user", "content": content_blocks[0]["text"]}
# 否则返回多模态格式
return {"role": "user", "content": content_blocks}
async def encode_image_bs64(self, image_url: str) -> str:
"""将图片转换为 base64"""
@@ -7,6 +7,7 @@ import asyncio
import os
import re
from datetime import datetime
from typing import cast
from funasr_onnx import SenseVoiceSmall
from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
@@ -32,7 +33,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.set_model(provider_config.get("stt_model"))
self.set_model(provider_config["stt_model"])
self.model = None
self.is_emotion = provider_config.get("is_emotion", False)
@@ -86,7 +87,9 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
loop = asyncio.get_event_loop()
res = await loop.run_in_executor(
None, # 使用默认的线程池
lambda: self.model(audio_url, language="auto", use_itn=True),
lambda: cast(SenseVoiceSmall, self.model)(
audio_url, language="auto", use_itn=True
),
)
# res = self.model(audio_url, language="auto", use_itn=True)
@@ -44,6 +44,7 @@ class VLLMRerankProvider(RerankProvider):
}
if top_n is not None:
payload["top_n"] = top_n
assert self.client is not None
async with self.client.post(
f"{self.base_url}/v1/rerank",
json=payload,
@@ -36,7 +36,7 @@ class ProviderOpenAIWhisperAPI(STTProvider):
timeout=provider_config.get("timeout", NOT_GIVEN),
)
self.set_model(provider_config.get("model"))
self.set_model(provider_config["model"])
async def _get_audio_format(self, file_path):
# 定义要检测的头部字节
@@ -1,6 +1,7 @@
import asyncio
import os
import uuid
from typing import cast
import whisper
@@ -26,7 +27,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.set_model(provider_config.get("model"))
self.set_model(provider_config["model"])
self.model = None
async def initialize(self):
@@ -75,5 +76,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
await tencent_silk_to_wav(audio_url, output_path)
audio_url = output_path
if not self.model:
raise RuntimeError("Whisper 模型未初始化")
result = await loop.run_in_executor(None, self.model.transcribe, audio_url)
return result["text"]
return cast(str, result["text"])
@@ -1,6 +1,11 @@
from typing import cast
from xinference_client.client.restful.async_restful_client import (
AsyncClient as Client,
)
from xinference_client.client.restful.async_restful_client import (
AsyncRESTfulRerankModelHandle,
)
from astrbot import logger
@@ -29,7 +34,7 @@ class XinferenceRerankProvider(RerankProvider):
False,
)
self.client = None
self.model = None
self.model: AsyncRESTfulRerankModelHandle | None = None
self.model_uid = None
async def initialize(self):
@@ -65,7 +70,10 @@ class XinferenceRerankProvider(RerankProvider):
return
if self.model_uid:
self.model = await self.client.get_model(self.model_uid)
self.model = cast(
AsyncRESTfulRerankModelHandle,
await self.client.get_model(self.model_uid),
)
except Exception as e:
logger.error(f"Failed to initialize Xinference model: {e}")
+5 -1
View File
@@ -2,15 +2,19 @@ from astrbot.core import html_renderer
from astrbot.core.provider import Provider
from astrbot.core.star.star_tools import StarTools
from astrbot.core.utils.command_parser import CommandParserMixin
from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin
from .context import Context
from .star import StarMetadata, star_map, star_registry
from .star_manager import PluginManager
class Star(CommandParserMixin):
class Star(CommandParserMixin, PluginKVStoreMixin):
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
author: str
name: str
def __init__(self, context: Context, config: dict | None = None):
StarTools.initialize(context)
self.context = context
+496
View File
@@ -0,0 +1,496 @@
from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any
from astrbot.core import db_helper, logger
from astrbot.core.db.po import CommandConfig
from astrbot.core.star.filter.command import CommandFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.filter.permission import PermissionType, PermissionTypeFilter
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import StarHandlerMetadata, star_handlers_registry
@dataclass
class CommandDescriptor:
handler: StarHandlerMetadata = field(repr=False)
filter_ref: CommandFilter | CommandGroupFilter | None = field(
default=None,
repr=False,
)
handler_full_name: str = ""
handler_name: str = ""
plugin_name: str = ""
plugin_display_name: str | None = None
module_path: str = ""
description: str = ""
command_type: str = "command" # "command" | "group" | "sub_command"
raw_command_name: str | None = None
current_fragment: str | None = None
parent_signature: str = ""
parent_group_handler: str = ""
original_command: str | None = None
effective_command: str | None = None
aliases: list[str] = field(default_factory=list)
permission: str = "everyone"
enabled: bool = True
is_group: bool = False
is_sub_command: bool = False
reserved: bool = False
config: CommandConfig | None = None
has_conflict: bool = False
sub_commands: list[CommandDescriptor] = field(default_factory=list)
async def sync_command_configs() -> None:
"""同步指令配置,清理过期配置。"""
descriptors = _collect_descriptors(include_sub_commands=False)
config_records = await db_helper.get_command_configs()
config_map = _bind_configs_to_descriptors(descriptors, config_records)
live_handlers = {desc.handler_full_name for desc in descriptors}
stale_configs = [key for key in config_map if key not in live_handlers]
if stale_configs:
await db_helper.delete_command_configs(stale_configs)
async def toggle_command(handler_full_name: str, enabled: bool) -> CommandDescriptor:
descriptor = _build_descriptor_by_full_name(handler_full_name)
if not descriptor:
raise ValueError("指定的处理函数不存在或不是指令。")
existing_cfg = await db_helper.get_command_config(handler_full_name)
config = await db_helper.upsert_command_config(
handler_full_name=handler_full_name,
plugin_name=descriptor.plugin_name or "",
module_path=descriptor.module_path,
original_command=descriptor.original_command or descriptor.handler_name,
resolved_command=(
existing_cfg.resolved_command
if existing_cfg
else descriptor.current_fragment
),
enabled=enabled,
keep_original_alias=False,
conflict_key=existing_cfg.conflict_key
if existing_cfg and existing_cfg.conflict_key
else descriptor.original_command,
resolution_strategy=existing_cfg.resolution_strategy if existing_cfg else None,
note=existing_cfg.note if existing_cfg else None,
extra_data=existing_cfg.extra_data if existing_cfg else None,
auto_managed=False,
)
_bind_descriptor_with_config(descriptor, config)
await sync_command_configs()
return descriptor
async def rename_command(
handler_full_name: str,
new_fragment: str,
aliases: list[str] | None = None,
) -> CommandDescriptor:
descriptor = _build_descriptor_by_full_name(handler_full_name)
if not descriptor:
raise ValueError("指定的处理函数不存在或不是指令。")
new_fragment = new_fragment.strip()
if not new_fragment:
raise ValueError("指令名不能为空。")
# 校验主指令名
candidate_full = _compose_command(descriptor.parent_signature, new_fragment)
if _is_command_in_use(handler_full_name, candidate_full):
raise ValueError(f"指令名 '{candidate_full}' 已被其他指令占用。")
# 校验别名
if aliases:
for alias in aliases:
alias = alias.strip()
if not alias:
continue
alias_full = _compose_command(descriptor.parent_signature, alias)
if _is_command_in_use(handler_full_name, alias_full):
raise ValueError(f"别名 '{alias_full}' 已被其他指令占用。")
existing_cfg = await db_helper.get_command_config(handler_full_name)
merged_extra = dict(existing_cfg.extra_data or {}) if existing_cfg else {}
merged_extra["resolved_aliases"] = aliases or []
config = await db_helper.upsert_command_config(
handler_full_name=handler_full_name,
plugin_name=descriptor.plugin_name or "",
module_path=descriptor.module_path,
original_command=descriptor.original_command or descriptor.handler_name,
resolved_command=new_fragment,
enabled=True if descriptor.enabled else False,
keep_original_alias=False,
conflict_key=descriptor.original_command,
resolution_strategy="manual_rename",
note=None,
extra_data=merged_extra,
auto_managed=False,
)
_bind_descriptor_with_config(descriptor, config)
await sync_command_configs()
return descriptor
async def list_commands() -> list[dict[str, Any]]:
descriptors = _collect_descriptors(include_sub_commands=True)
config_records = await db_helper.get_command_configs()
_bind_configs_to_descriptors(descriptors, config_records)
conflict_groups = _group_conflicts(descriptors)
conflict_handler_names: set[str] = {
d.handler_full_name for group in conflict_groups.values() for d in group
}
# 分类,设置冲突标志,将子指令挂载到父指令组
group_map: dict[str, CommandDescriptor] = {}
sub_commands: list[CommandDescriptor] = []
root_commands: list[CommandDescriptor] = []
for desc in descriptors:
desc.has_conflict = desc.handler_full_name in conflict_handler_names
if desc.is_group:
group_map[desc.handler_full_name] = desc
elif desc.is_sub_command:
sub_commands.append(desc)
else:
root_commands.append(desc)
for sub in sub_commands:
if sub.parent_group_handler and sub.parent_group_handler in group_map:
group_map[sub.parent_group_handler].sub_commands.append(sub)
else:
root_commands.append(sub)
# 指令组 + 普通指令,按 effective_command 字母排序
all_commands = list(group_map.values()) + root_commands
all_commands.sort(key=lambda d: (d.effective_command or "").lower())
result = [_descriptor_to_dict(desc) for desc in all_commands]
return result
async def list_command_conflicts() -> list[dict[str, Any]]:
"""列出所有冲突的指令组。"""
descriptors = _collect_descriptors(include_sub_commands=False)
config_records = await db_helper.get_command_configs()
_bind_configs_to_descriptors(descriptors, config_records)
conflict_groups = _group_conflicts(descriptors)
details = [
{
"conflict_key": key,
"handlers": [
{
"handler_full_name": item.handler_full_name,
"plugin": item.plugin_name,
"current_name": item.effective_command,
}
for item in group
],
}
for key, group in conflict_groups.items()
]
return details
# Internal helpers ----------------------------------------------------------
def _collect_descriptors(include_sub_commands: bool) -> list[CommandDescriptor]:
"""收集指令,按需包含子指令。"""
descriptors: list[CommandDescriptor] = []
for handler in star_handlers_registry:
try:
desc = _build_descriptor(handler)
if not desc:
continue
if not include_sub_commands and desc.is_sub_command:
continue
descriptors.append(desc)
except Exception as e:
logger.warning(
f"解析指令处理函数 {handler.handler_full_name} 失败,跳过该指令。原因: {e!s}"
)
continue
return descriptors
def _build_descriptor(handler: StarHandlerMetadata) -> CommandDescriptor | None:
filter_ref = _locate_primary_filter(handler)
if filter_ref is None:
return None
plugin_meta = star_map.get(handler.handler_module_path)
plugin_name = (
plugin_meta.name if plugin_meta else None
) or handler.handler_module_path
plugin_display = plugin_meta.display_name if plugin_meta else None
is_sub_command = bool(handler.extras_configs.get("sub_command"))
parent_group_handler = ""
if isinstance(filter_ref, CommandFilter):
raw_fragment = getattr(
filter_ref, "_original_command_name", filter_ref.command_name
)
current_fragment = filter_ref.command_name
parent_signature = (filter_ref.parent_command_names or [""])[0].strip()
# 如果是子指令,尝试找到父指令组的 handler_full_name
if is_sub_command and parent_signature:
parent_group_handler = _find_parent_group_handler(
handler.handler_module_path, parent_signature
)
else:
raw_fragment = getattr(
filter_ref, "_original_group_name", filter_ref.group_name
)
current_fragment = filter_ref.group_name
parent_signature = _resolve_group_parent_signature(filter_ref)
original_command = _compose_command(parent_signature, raw_fragment)
effective_command = _compose_command(parent_signature, current_fragment)
# 确定 command_type
if isinstance(filter_ref, CommandGroupFilter):
command_type = "group"
elif is_sub_command:
command_type = "sub_command"
else:
command_type = "command"
descriptor = CommandDescriptor(
handler=handler,
filter_ref=filter_ref,
handler_full_name=handler.handler_full_name,
handler_name=handler.handler_name,
plugin_name=plugin_name,
plugin_display_name=plugin_display,
module_path=handler.handler_module_path,
description=handler.desc or "",
command_type=command_type,
raw_command_name=raw_fragment,
current_fragment=current_fragment,
parent_signature=parent_signature,
parent_group_handler=parent_group_handler,
original_command=original_command,
effective_command=effective_command,
aliases=sorted(getattr(filter_ref, "alias", set())),
permission=_determine_permission(handler),
enabled=handler.enabled,
is_group=isinstance(filter_ref, CommandGroupFilter),
is_sub_command=is_sub_command,
reserved=plugin_meta.reserved if plugin_meta else False,
)
return descriptor
def _build_descriptor_by_full_name(full_name: str) -> CommandDescriptor | None:
handler = star_handlers_registry.get_handler_by_full_name(full_name)
if not handler:
return None
return _build_descriptor(handler)
def _locate_primary_filter(
handler: StarHandlerMetadata,
) -> CommandFilter | CommandGroupFilter | None:
for filter_ref in handler.event_filters:
if isinstance(filter_ref, (CommandFilter, CommandGroupFilter)):
return filter_ref
return None
def _determine_permission(handler: StarHandlerMetadata) -> str:
for filter_ref in handler.event_filters:
if isinstance(filter_ref, PermissionTypeFilter):
return (
"admin"
if filter_ref.permission_type == PermissionType.ADMIN
else "member"
)
return "everyone"
def _resolve_group_parent_signature(group_filter: CommandGroupFilter) -> str:
signatures: list[str] = []
parent = group_filter.parent_group
while parent:
signatures.append(getattr(parent, "_original_group_name", parent.group_name))
parent = parent.parent_group
return " ".join(reversed(signatures)).strip()
def _find_parent_group_handler(module_path: str, parent_signature: str) -> str:
"""根据模块路径和父级签名,找到对应的指令组 handler_full_name。"""
parent_sig_normalized = parent_signature.strip()
for handler in star_handlers_registry:
if handler.handler_module_path != module_path:
continue
filter_ref = _locate_primary_filter(handler)
if not isinstance(filter_ref, CommandGroupFilter):
continue
# 检查该指令组的完整指令名是否匹配 parent_signature
group_names = filter_ref.get_complete_command_names()
if parent_sig_normalized in group_names:
return handler.handler_full_name
return ""
def _compose_command(parent_signature: str, fragment: str | None) -> str:
fragment = (fragment or "").strip()
parent_signature = parent_signature.strip()
if not parent_signature:
return fragment
if not fragment:
return parent_signature
return f"{parent_signature} {fragment}"
def _bind_descriptor_with_config(
descriptor: CommandDescriptor,
config: CommandConfig,
) -> None:
_apply_config_to_descriptor(descriptor, config)
_apply_config_to_runtime(descriptor, config)
def _apply_config_to_descriptor(
descriptor: CommandDescriptor,
config: CommandConfig,
) -> None:
descriptor.config = config
descriptor.enabled = config.enabled
if config.original_command:
descriptor.original_command = config.original_command
new_fragment = config.resolved_command or descriptor.current_fragment
descriptor.current_fragment = new_fragment
descriptor.effective_command = _compose_command(
descriptor.parent_signature,
new_fragment,
)
extra = config.extra_data or {}
resolved_aliases = extra.get("resolved_aliases")
if isinstance(resolved_aliases, list):
descriptor.aliases = [str(x) for x in resolved_aliases if str(x).strip()]
def _apply_config_to_runtime(
descriptor: CommandDescriptor,
config: CommandConfig,
) -> None:
descriptor.handler.enabled = config.enabled
if descriptor.filter_ref:
if descriptor.current_fragment:
_set_filter_fragment(descriptor.filter_ref, descriptor.current_fragment)
extra = config.extra_data or {}
resolved_aliases = extra.get("resolved_aliases")
if isinstance(resolved_aliases, list):
_set_filter_aliases(
descriptor.filter_ref,
[str(x) for x in resolved_aliases if str(x).strip()],
)
def _bind_configs_to_descriptors(
descriptors: list[CommandDescriptor],
config_records: list[CommandConfig],
) -> dict[str, CommandConfig]:
config_map = {cfg.handler_full_name: cfg for cfg in config_records}
for desc in descriptors:
if cfg := config_map.get(desc.handler_full_name):
_bind_descriptor_with_config(desc, cfg)
return config_map
def _group_conflicts(
descriptors: list[CommandDescriptor],
) -> dict[str, list[CommandDescriptor]]:
conflicts: dict[str, list[CommandDescriptor]] = defaultdict(list)
for desc in descriptors:
if desc.effective_command and desc.enabled:
conflicts[desc.effective_command].append(desc)
return {k: v for k, v in conflicts.items() if len(v) > 1}
def _set_filter_fragment(
filter_ref: CommandFilter | CommandGroupFilter,
fragment: str,
) -> None:
attr = (
"group_name" if isinstance(filter_ref, CommandGroupFilter) else "command_name"
)
current_value = getattr(filter_ref, attr)
if fragment == current_value:
return
setattr(filter_ref, attr, fragment)
if hasattr(filter_ref, "_cmpl_cmd_names"):
filter_ref._cmpl_cmd_names = None
def _set_filter_aliases(
filter_ref: CommandFilter | CommandGroupFilter,
aliases: list[str],
) -> None:
current_aliases = getattr(filter_ref, "alias", set())
if set(aliases) == current_aliases:
return
setattr(filter_ref, "alias", set(aliases))
if hasattr(filter_ref, "_cmpl_cmd_names"):
filter_ref._cmpl_cmd_names = None
def _is_command_in_use(
target_handler_full_name: str,
candidate_full_command: str,
) -> bool:
candidate = candidate_full_command.strip()
for handler in star_handlers_registry:
if handler.handler_full_name == target_handler_full_name:
continue
filter_ref = _locate_primary_filter(handler)
if not filter_ref:
continue
names = {name.strip() for name in filter_ref.get_complete_command_names()}
if candidate in names:
return True
return False
def _descriptor_to_dict(desc: CommandDescriptor) -> dict[str, Any]:
result = {
"handler_full_name": desc.handler_full_name,
"handler_name": desc.handler_name,
"plugin": desc.plugin_name,
"plugin_display_name": desc.plugin_display_name,
"module_path": desc.module_path,
"description": desc.description,
"type": desc.command_type,
"parent_signature": desc.parent_signature,
"parent_group_handler": desc.parent_group_handler,
"original_command": desc.original_command,
"current_fragment": desc.current_fragment,
"effective_command": desc.effective_command,
"aliases": desc.aliases,
"permission": desc.permission,
"enabled": desc.enabled,
"is_group": desc.is_group,
"has_conflict": desc.has_conflict,
"reserved": desc.reserved,
}
# 如果是指令组,包含子指令列表
if desc.is_group and desc.sub_commands:
result["sub_commands"] = [_descriptor_to_dict(sub) for sub in desc.sub_commands]
else:
result["sub_commands"] = []
return result
+6 -2
View File
@@ -267,6 +267,10 @@ class Context:
):
"""通过 ID 获取对应的 LLM Provider。"""
prov = self.provider_manager.inst_map.get(provider_id)
if provider_id and not prov:
logger.warning(
f"没有找到 ID 为 {provider_id} 的提供商,这可能是由于您修改了提供商(模型)ID 导致的。"
)
return prov
def get_all_providers(self) -> list[Provider]:
@@ -285,7 +289,7 @@ class Context:
"""获取所有用于 Embedding 任务的 Provider。"""
return self.provider_manager.embedding_provider_insts
def get_using_provider(self, umo: str | None = None) -> Provider | None:
def get_using_provider(self, umo: str | None = None) -> Provider:
"""获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
Args:
@@ -296,7 +300,7 @@ class Context:
provider_type=ProviderType.CHAT_COMPLETION,
umo=umo,
)
if prov and not isinstance(prov, Provider):
if not isinstance(prov, Provider):
raise ValueError("返回的 Provider 不是 Provider 类型")
return prov
+1
View File
@@ -40,6 +40,7 @@ class CommandFilter(HandlerFilter):
):
self.command_name = command_name
self.alias = alias if alias else set()
self._original_command_name = command_name
self.parent_command_names = (
parent_command_names if parent_command_names is not None else [""]
)
@@ -18,6 +18,7 @@ class CommandGroupFilter(HandlerFilter):
):
self.group_name = group_name
self.alias = alias if alias else set()
self._original_group_name = group_name
self.sub_command_filters: list[CommandFilter | CommandGroupFilter] = []
self.custom_filter_list: list[CustomFilter] = []
self.parent_group = parent_group
+22 -5
View File
@@ -1,7 +1,7 @@
from __future__ import annotations
import re
from collections.abc import Awaitable, Callable
from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any
import docstring_parser
@@ -12,6 +12,7 @@ from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.tool import FunctionTool
from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools
@@ -28,13 +29,19 @@ from ..filter.regex import RegexFilter
from ..star_handler import EventType, StarHandlerMetadata, star_handlers_registry
def get_handler_full_name(awaitable: Callable[..., Awaitable[Any]]) -> str:
def get_handler_full_name(
awaitable: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
) -> str:
"""获取 Handler 的全名"""
return f"{awaitable.__module__}_{awaitable.__name__}"
def get_handler_or_create(
handler: Callable[..., Awaitable[Any]],
handler: Callable[
...,
Awaitable[MessageEventResult | str | None]
| AsyncGenerator[MessageEventResult | str | None],
],
event_type: EventType,
dont_add=False,
**kwargs,
@@ -169,6 +176,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
for (
sub_handle
) in parent_register_commandable.parent_group.sub_command_filters:
if isinstance(sub_handle, CommandGroupFilter):
continue
# 所有符合fullname一致的子指令handle添加自定义过滤器。
# 不确定是否会有多个子指令有一样的fullname,比如一个方法添加多个command装饰器?
sub_handle_md = sub_handle.get_handler_md()
@@ -180,6 +189,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
else:
# 裸指令
# 确保运行时是可调用的 handler,针对类型检查器添加忽略
assert isinstance(awaitable, Callable)
handler_md = get_handler_or_create(
awaitable,
EventType.AdapterMessageEvent,
@@ -237,7 +248,7 @@ class RegisteringCommandable:
group: Callable[..., Callable[..., RegisteringCommandable]] = register_command_group
command: Callable[..., Callable[..., None]] = register_command
custom_filter: Callable[..., Callable[..., None]] = register_custom_filter
custom_filter: Callable[..., Callable[..., Any]] = register_custom_filter
def __init__(self, parent_group: CommandGroupFilter):
self.parent_group = parent_group
@@ -412,7 +423,13 @@ def register_llm_tool(name: str | None = None, **kwargs):
if kwargs.get("registering_agent"):
registering_agent = kwargs["registering_agent"]
def decorator(awaitable: Callable[..., Awaitable[Any]]):
def decorator(
awaitable: Callable[
...,
AsyncGenerator[MessageEventResult | str | None]
| Awaitable[MessageEventResult | str | None],
],
):
llm_tool_name = name_ if name_ else awaitable.__name__
func_doc = awaitable.__doc__ or ""
docstring = docstring_parser.parse(func_doc)
+89 -4
View File
@@ -1,9 +1,9 @@
from __future__ import annotations
import enum
from collections.abc import Awaitable, Callable
from collections.abc import AsyncGenerator, Awaitable, Callable
from dataclasses import dataclass, field
from typing import Any, Generic, TypeVar
from typing import Any, Generic, Literal, TypeVar, overload
from .filter import HandlerFilter
from .star import star_map
@@ -29,6 +29,84 @@ class StarHandlerRegistry(Generic[T]):
for handler in self._handlers:
print(handler.handler_full_name)
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnAstrBotLoadedEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnPlatformLoadedEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.AdapterMessageEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnLLMRequestEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnLLMResponseEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnDecoratingResultEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnCallingFuncToolEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnAfterMessageSentEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: EventType,
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
]: ...
def get_handlers_by_event_type(
self,
event_type: EventType,
@@ -40,6 +118,8 @@ class StarHandlerRegistry(Generic[T]):
# 过滤事件类型
if handler.event_type != event_type:
continue
if not handler.enabled:
continue
# 过滤启用状态
if only_activated:
plugin = star_map.get(handler.handler_module_path)
@@ -111,8 +191,11 @@ class EventType(enum.Enum):
OnAfterMessageSentEvent = enum.auto() # 发送消息后
H = TypeVar("H", bound=Callable[..., Any])
@dataclass
class StarHandlerMetadata:
class StarHandlerMetadata(Generic[H]):
"""描述一个 Star 所注册的某一个 Handler。"""
event_type: EventType
@@ -127,7 +210,7 @@ class StarHandlerMetadata:
handler_module_path: str
"""Handler 所在的模块路径。"""
handler: Callable[..., Awaitable[Any]]
handler: H
"""Handler 的函数对象,应当是一个异步函数"""
event_filters: list[HandlerFilter]
@@ -139,6 +222,8 @@ class StarHandlerMetadata:
extras_configs: dict = field(default_factory=dict)
"""插件注册的一些其他的信息, 如 priority 等"""
enabled: bool = True
def __lt__(self, other: StarHandlerMetadata):
"""定义小于运算符以支持优先队列"""
return self.extras_configs.get("priority", 0) < other.extras_configs.get(
+18
View File
@@ -23,6 +23,7 @@ from astrbot.core.utils.astrbot_path import (
from astrbot.core.utils.io import remove_dir
from . import StarMetadata
from .command_management import sync_command_configs
from .context import Context
from .filter.permission import PermissionType, PermissionTypeFilter
from .star import star_map, star_registry
@@ -467,6 +468,18 @@ class PluginManager:
metadata.star_cls = metadata.star_cls_type(
context=self.context,
)
p_name = (metadata.name or "unknown").lower().replace("/", "_")
p_author = (
(metadata.author or "unknown").lower().replace("/", "_")
)
setattr(metadata.star_cls, "name", p_name)
setattr(metadata.star_cls, "author", p_author)
setattr(
metadata.star_cls,
"plugin_id",
f"{p_author}/{p_name}",
)
else:
logger.info(f"插件 {metadata.name} 已被禁用。")
@@ -618,6 +631,11 @@ class PluginManager:
# 清除 pip.main 导致的多余的 logging handlers
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
try:
await sync_command_configs()
except Exception as e:
logger.error(f"同步指令配置失败: {e!s}")
logger.error(traceback.format_exc())
if not fail_rec:
return True, None

Some files were not shown because too many files have changed in this diff Show More