Compare commits

..

108 Commits

Author SHA1 Message Date
Soulter 7cedf0d587 chore: improve documentation for extra_user_content_parts in Provider classes 2025-12-26 21:55:44 +08:00
kawayiYokami aeb21f719e claude额外块支持图片模态 2025-12-26 21:54:01 +08:00
Soulter 7c1dbecea5 refactor: unify extra_user_content_parts type to ContentPart across providers and update related handling 2025-12-26 21:47:02 +08:00
kawayiYokami 05012af627 重命名 2025-12-26 20:54:38 +08:00
kawayiYokami 17b52ab5dd 传递链 2025-12-26 18:57:51 +08:00
kawayiYokami 9449ff668b FIX 2025-12-25 13:33:40 +08:00
kawayiYokami c5a2827def feat: 多文本块功能 2025-12-25 03:54:05 +08:00
Soulter 701399c00c docs: update readme xmas 2025-12-24 21:58:04 +08:00
Soulter eaee98d4b8 chore: bump version to 4.10.2 2025-12-24 21:55:05 +08:00
Soulter 76c66000a7 chore: restrict psutil version <7.2.0 to avoid compatibility issues
fixes: #4176
2025-12-24 15:48:58 +08:00
Oscar Shaw 4b365143c0 feat: support for managing command aliases (#4170)
* feat(command): persist aliases on rename and apply to runtime filter

* feat(dashboard-api): support aliases in rename command endpoint

* feat(dashboard-ui): add alias editor to rename command dialog

* feat(dashboard-ui): enhance alias editor UI in rename dialog
2025-12-24 15:37:10 +08:00
Soulter 6e4e5011e2 chore: bump version to 4.10.1 2025-12-23 21:35:40 +08:00
Venus Yan d853bfde84 perf: handle unsupported message types with logging in OneBot adapter (#4164)
* Handle unsupported message types with logging

解决else 分支中对未知消息类型毫无防御,直接索引ComponentTypes[t],导致新类型markdown类信息报错并炸掉事件管道,且对应群聊单群永久不响应插件;尝试支持markdown类型进行支持但未经过测试

* chore: ruff format

* chore: ruff format

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-12-23 21:31:32 +08:00
Soulter a0e856f80f fix: provider source id contains slash will lead to 405 (#4162) 2025-12-22 20:28:20 +08:00
Oscar Shaw 8c94a0010c fix(core): improve error handling of command parser and sync (#4161) 2025-12-22 19:54:26 +08:00
Soulter a44fdaaec0 chore: bump version to 4.10.0 2025-12-22 18:10:30 +08:00
Soulter 60105c76f5 feat: implement router loading progress indicator 2025-12-22 13:20:39 +08:00
Soulter bcf87d3ce4 fix: update provider subtitle for clarity in English and Chinese locales
- Revised the subtitle in the provider feature localization files to provide a more detailed description of functionalities, including chat model configuration and third-party service integrations.
2025-12-22 13:13:42 +08:00
Soulter 4d7c8c8453 style: add active background color for provider source list item in dark theme 2025-12-22 12:59:55 +08:00
Soulter a064a9115f fix: omit thinking params for gemini image generation models (#4151)
- Expanded model name checks to include specific Gemini 2.5 and 3 variants, ensuring correct configuration for thinking parameters based on the model used.
2025-12-22 00:09:30 +08:00
Soulter 6ef99e1553 feat: enhance ChatInput and ConversationSidebar dark theme 2025-12-21 21:19:54 +08:00
Soulter c0dbe5cf65 chore: bump version to 4.10.0-alpha.2 2025-12-21 13:11:32 +08:00
Soulter 3598c51eff fix: enhance provider model menu and sidebar session selection handling (#4144)
- Updated `ProviderModelMenu.vue` to manage menu state and load provider configurations dynamically upon opening.
- Filtered provider configurations to exclude those with `enable` set to false.
- Improved session selection logic in `useSessions.ts` to ensure the currently selected session is highlighted and properly managed during navigation.
2025-12-21 13:05:15 +08:00
Soulter b5cdb8f650 fix: improve error handling in tool execution to prevent infinite tool call loops (#4143)
* fix: improve error handling in tool execution to prevent infinite tool call loops

- Enhanced error handling in `call_local_llm_tool` to provide more informative exceptions for ValueError and TypeError, including detailed parameter information.
- Updated `ToolLoopAgentRunner` to yield appropriate messages for cases with no response or unsupported types, ensuring clearer communication to users.
- Improved logging and messaging consistency across tool execution processes.

* refactor: clean up unused router parameter in message retrieval functions

- Removed the unused `router` parameter from `getSessionMessages` and related function calls in `Chat.vue` and `useMessages.ts`.
- Commented out the `tool_calls` dictionary in `chat.py` for clarity, indicating it is not currently in use.

* fix: enhance exception handling in tool execution for clearer error reporting

- Improved exception handling in `call_local_llm_tool` by chaining exceptions for ValueError and TypeError, providing more context in error messages.
- Ensured that traceback information is preserved in raised exceptions for better debugging.
2025-12-21 12:57:54 +08:00
Yokami fc5b520f9b perf(agent): add max step limit to prevent infinite tool call loops (#4110)
* perf(agent): add max step limit to prevent infinite tool call loops

* feat: implement max step limit handling in main agent runner

- Enhanced the agent runner to enforce a maximum step limit, logging a warning and forcing a final response when the limit is reached.
- Updated message handling to append a user prompt when the tool call limit is exceeded.
- Refactored tool response handling to yield appropriate messages based on the response type, including handling cases with no response or unsupported types.
- Improved conversation message formatting to ensure consistent output in the assistant's responses.

* chore: ruff format

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-12-21 12:30:43 +08:00
Soulter 904f56b32f fix: webui conversation traj data display error (#4142)
fixes: #4141
2025-12-20 23:29:40 +08:00
Soulter 2f15fd019c chore: bump version to v4.10.0-alpha.1 2025-12-20 16:35:54 +08:00
Soulter 82330b8d10 feat: add changelog functionality and dialog component (#4135)
* feat: add changelog functionality and dialog component

- Implemented new routes for fetching changelogs and available versions in StatRoute.
- Created ChangelogDialog.vue for displaying changelog content and version selection.
- Updated VerticalSidebar.vue to include a button for opening the changelog dialog.
- Enhanced localization files for English and Chinese to support new changelog features.
- Adjusted styles in VerticalHeader.vue for improved layout consistency.

* chore: ruff format
2025-12-20 16:33:12 +08:00
Soulter 3ee6af7027 feat: add route watcher for viewMode changes in VerticalHeader.vue
- Introduced a watcher to monitor changes in customizer.viewMode, automatically redirecting to the homepage when switching from 'chat' to 'bot' mode.
- Updated imports to include useRoute from vue-router for routing functionality.
- Adjusted button styles for improved layout consistency in bot mode.
2025-12-20 15:38:01 +08:00
Soulter 6e20ebe901 feat: add KaTeX and Mermaid and computation-friendly renderer support (#4118)
* feat: add KaTeX and Mermaid support for enhanced markdown rendering in MessageList.vue

closes: #3747
- Integrated @mdit/plugin-katex and katex for LaTeX rendering.
- Added markstream-vue for improved markdown rendering capabilities.
- Updated MessageList.vue to utilize MarkdownRender component for rendering markdown content.
- Enhanced UI for dark mode compatibility across various components.
- Introduced new styles for file links, reasoning blocks, and tool call cards to improve visual consistency.

* refactor: replace markdown-it with markstream-vue for improved markdown rendering

- Removed markdown-it and related configurations from ReadmeDialog.vue, VerticalHeader.vue, and ConversationPage.vue.
- Integrated markstream-vue for enhanced markdown rendering capabilities, including support for KaTeX and Mermaid.
- Updated components to utilize MarkdownRender for rendering markdown content, improving consistency and performance.

* chore: remove deprecated markdown-it and marked dependencies from pnpm-lock.yaml

- Cleaned up pnpm-lock.yaml by removing markdown-it and marked entries, streamlining the dependency list.
- This change follows the recent integration of markstream-vue for improved markdown rendering capabilities.

* chore: remove d3 dependency and update MessageList.vue for dark mode support

- Removed d3 from package.json and commented out its import in LongTermMemory.vue to clean up unused dependencies.
- Updated MessageList.vue to ensure consistent dark mode styling by passing the isDark prop to MarkdownRender components.

* feat: add loading indicator for message retrieval in Chat and MessageList components

- Introduced a loading overlay in Chat.vue and MessageList.vue to indicate when messages are being loaded.
- Added a new `isLoadingMessages` prop to manage loading state and enhance user experience during message retrieval.
- Updated styles to ensure the loading indicator is visually integrated with the existing UI.

* feat: add provider configuration dialog to chat sidebar

- Introduced a new `ProviderConfigDialog` component for managing provider settings.
- Added a menu item in the `ConversationSidebar` to open the provider configuration dialog.
- Updated English and Chinese localization files to include translations for the new provider configuration feature.

* feat: update dashboard components and styles for improved chat experience

- Replaced font in index.html to use 'Outfit' for a fresh look.
- Changed icon in ConversationSidebar.vue to 'mdi-creation' for better representation.
- Refactored MessageList.vue to streamline loading indicators and enhance styling consistency.
- Updated localization files to change 'Provider Configuration' to 'AI Configuration' for clarity.
- Introduced new styles for loading indicators and chat mode adjustments in FullLayout.vue.
- Added functionality for toggling between bot and chat modes in the header.
- Removed deprecated sidebar item for chat navigation.

* feat: xmas easter egg

* chore: remove pnpm lock file
2025-12-20 15:22:48 +08:00
Yokami 4d6150fd6d fix: handle quoted messages correctly to prevent breaking cache (#4112)
* fix: Handle quoted messages correctly as user context

This change ensures quoted messages, including text and image captions, are appended to the conversation history as a user message rather than being injected into the system prompt.

Fixes #3886

* 注入到req.prompt里
2025-12-20 11:03:27 +08:00
Soulter 544e52191b Merge pull request #4065 from AstrBotDevs/refactor/provider-source
refactor: SUPER AMAZING model provider refactor
2025-12-20 00:09:36 +08:00
Soulter f2c2a6da4a chore: ruff format 2025-12-20 00:07:42 +08:00
Soulter dd3df425ee feat: add warnings for missing provider IDs in manager and context
- Introduced logging warnings in ProviderManager and Context classes when a provider ID is not found, indicating potential issues due to ID modifications.
- Updated the ProviderPage.vue to advise against modifying provider IDs, highlighting possible configuration impacts.
2025-12-20 00:06:42 +08:00
Soulter 40b4a27a3d Merge remote-tracking branch 'origin/master' into refactor/provider-source 2025-12-19 15:48:42 +08:00
Soulter 9d991c7468 perf: enhance chat components with theme and fullscreen toggles (#4116)
* perf: enhance chat components with theme and fullscreen toggles

- Added theme and fullscreen toggle functionality to Chat.vue and ConversationSidebar.vue.
- Introduced a new StyledMenu component for improved dropdown menus.
- Updated MessageList.vue and ChatInput.vue for better mobile responsiveness and UI consistency.
- Enhanced language switcher integration in ConversationSidebar.vue.
- Added new settings translations in English and Chinese locales.

* fix: streamline conversation selection handling in Chat.vue

- Updated handleSelectConversation function to immediately set the current session ID and selected sessions, reducing the need for multiple clicks.
- Adjusted padding in ConversationSidebar.vue for improved layout consistency.
2025-12-19 11:18:01 +08:00
Soulter ad6a8b5c94 Merge remote-tracking branch 'origin/master' into refactor/provider-source 2025-12-18 17:39:27 +08:00
Soulter 1b4bfcbd72 chore: ruff format 2025-12-18 17:37:12 +08:00
Soulter 9d3cc593a1 feat: supports thinking level of google gemini (#4104)
* feat: supports thinking level of google gemini

- Updated google-genai version to >=1.56.0 in pyproject.toml and requirements.txt.
- Changed model configuration from "gemini-1.5-flash" to "gemini-3-flash-preview" in default.py.
- Enhanced thinking configuration handling in gemini_source.py to support new parameters for Gemini 3 models.

* fix: standardize thinking level configuration in default.py and gemini_source.py

- Updated the thinking level values in default.py to uppercase for consistency.
- Enhanced gemini_source.py to validate the thinking level and default to "HIGH" if an invalid value is provided.
2025-12-18 17:37:11 +08:00
Soulter f0dee35ba9 feat: enhance tool call handling and agent stats tracking and UI integration for tool calls render (#4101)
* feat: enhance tool call handling and UI integration for tool calls render

- Added support for tool call messages in the agent runner and webchat event handling.
- Implemented JSON message component for structured tool call data.
- Updated chat route to save tool call information in message history.
- Enhanced frontend to display tool call details in a collapsible format, including status and results.
- Introduced elapsed time tracking for ongoing tool calls in the chat interface.

* fix: improve message handling in agent run utility and tool loop runner

- Refactored message sending logic in `astr_agent_run_util.py` to use `msg_chain` directly for better clarity.
- Added a check in `tool_loop_agent_runner.py` to ensure `tool_call_result_blocks` is not empty before yielding the last tool call result, preventing potential errors.

* refactor: enhance message structure and UI for chat components

- Updated message handling in `MessageList.vue` to support structured message parts, including plain text, images, audio, and files.
- Improved the `Chat.vue` component styles for better visual consistency.
- Refactored message parsing logic in `useMessages.ts` to accommodate new message formats and ensure proper rendering of embedded content.
- Removed deprecated tool call handling from the message structure, streamlining the message display process.

* chore: ruff format

* feat: implement agent statistics tracking and display in chat

- Added `AgentStats` and `TokenUsage` data classes to track agent performance metrics.
- Enhanced `ToolLoopAgentRunner` to collect and update agent statistics during execution.
- Integrated agent statistics sending to webchat for real-time updates.
- Updated chat route to save and display agent statistics in message history.
- Improved frontend components to visualize agent statistics, including token usage and duration metrics.

* fix: improve message handling in Telegram event and agent run utility

- Updated message sending logic in `astr_agent_run_util.py` to send the correct message chain for tool calls.
- Enhanced `tg_event.py` to edit messages during streaming breaks, improving message management and user experience.
- Added error handling for message editing failures to ensure robustness.

* chore: ruff format
2025-12-18 17:36:45 +08:00
Soulter 4135bd84d5 refactor: update OneBot configuration and add platform logo (#4106)
- Renamed "QQ 个人号(OneBot v11)" to "OneBot v11" in the configuration.
- Added a new logo for OneBot in the dashboard assets.
- Updated platform icon retrieval logic to include the new OneBot logo.
2025-12-18 17:34:59 +08:00
Soulter f6da614e5d fix: validation error for ToolCall.extra_content in specific upstream model providers (#4102)
* fix: validation error for ToolCall.extra_content in specific upstream model providers

* fix: handle missing extra_content gracefully in ToolCall serialization
2025-12-18 17:34:59 +08:00
Soulter 5f531c9be5 chore: ruff format 2025-12-18 17:17:17 +08:00
Soulter 94591d965b feat: supports thinking level of google gemini (#4104)
* feat: supports thinking level of google gemini

- Updated google-genai version to >=1.56.0 in pyproject.toml and requirements.txt.
- Changed model configuration from "gemini-1.5-flash" to "gemini-3-flash-preview" in default.py.
- Enhanced thinking configuration handling in gemini_source.py to support new parameters for Gemini 3 models.

* fix: standardize thinking level configuration in default.py and gemini_source.py

- Updated the thinking level values in default.py to uppercase for consistency.
- Enhanced gemini_source.py to validate the thinking level and default to "HIGH" if an invalid value is provided.
2025-12-18 17:15:01 +08:00
Soulter 8a0f865af1 feat: enhance tool call handling and agent stats tracking and UI integration for tool calls render (#4101)
* feat: enhance tool call handling and UI integration for tool calls render

- Added support for tool call messages in the agent runner and webchat event handling.
- Implemented JSON message component for structured tool call data.
- Updated chat route to save tool call information in message history.
- Enhanced frontend to display tool call details in a collapsible format, including status and results.
- Introduced elapsed time tracking for ongoing tool calls in the chat interface.

* fix: improve message handling in agent run utility and tool loop runner

- Refactored message sending logic in `astr_agent_run_util.py` to use `msg_chain` directly for better clarity.
- Added a check in `tool_loop_agent_runner.py` to ensure `tool_call_result_blocks` is not empty before yielding the last tool call result, preventing potential errors.

* refactor: enhance message structure and UI for chat components

- Updated message handling in `MessageList.vue` to support structured message parts, including plain text, images, audio, and files.
- Improved the `Chat.vue` component styles for better visual consistency.
- Refactored message parsing logic in `useMessages.ts` to accommodate new message formats and ensure proper rendering of embedded content.
- Removed deprecated tool call handling from the message structure, streamlining the message display process.

* chore: ruff format

* feat: implement agent statistics tracking and display in chat

- Added `AgentStats` and `TokenUsage` data classes to track agent performance metrics.
- Enhanced `ToolLoopAgentRunner` to collect and update agent statistics during execution.
- Integrated agent statistics sending to webchat for real-time updates.
- Updated chat route to save and display agent statistics in message history.
- Improved frontend components to visualize agent statistics, including token usage and duration metrics.

* fix: improve message handling in Telegram event and agent run utility

- Updated message sending logic in `astr_agent_run_util.py` to send the correct message chain for tool calls.
- Enhanced `tg_event.py` to edit messages during streaming breaks, improving message management and user experience.
- Added error handling for message editing failures to ensure robustness.

* chore: ruff format
2025-12-18 17:11:09 +08:00
Soulter 4aced976a8 refactor: update OneBot configuration and add platform logo (#4106)
- Renamed "QQ 个人号(OneBot v11)" to "OneBot v11" in the configuration.
- Added a new logo for OneBot in the dashboard assets.
- Updated platform icon retrieval logic to include the new OneBot logo.
2025-12-18 15:19:15 +08:00
Soulter 0299aa6e4c fix: validation error for ToolCall.extra_content in specific upstream model providers (#4102)
* fix: validation error for ToolCall.extra_content in specific upstream model providers

* fix: handle missing extra_content gracefully in ToolCall serialization
2025-12-18 11:55:49 +08:00
Soulter e8b54a019e refactor: replace ProviderModelSelector with ProviderModelMenu for improved UI and functionality 2025-12-17 22:57:32 +08:00
Soulter 98ce796275 chore: remove copilot instruction 2025-12-17 17:21:33 +08:00
Soulter b87dcf2275 refactor: improve provider source ID validation to prevent duplicates during configuration updates 2025-12-17 17:19:35 +08:00
Soulter 591a228431 refactor: enhance provider management with resource locking and CRUD operations 2025-12-17 17:08:52 +08:00
Soulter f52f375154 refactor: update provider handling to use new config structure and improve template retrieval 2025-12-17 16:55:12 +08:00
Soulter 975c685a17 chore: ruff format 2025-12-17 16:32:38 +08:00
Soulter 6db80d36a8 fix: prevent platform ID modification during updates and ensure correct routing table handling 2025-12-17 16:16:50 +08:00
Soulter 4651bd2807 feat: implement provider deletion functionality and ensure unique provider IDs 2025-12-17 15:00:22 +08:00
Soulter 94ada3793e Merge remote-tracking branch 'origin/master' into refactor/provider-source 2025-12-17 13:33:23 +08:00
Soulter fd05b0bf09 docs: update contributing guidelines to include code style and formatting instructions 2025-12-17 13:26:22 +08:00
Soulter 4d046f8490 delete: remove backup of ProviderPage.vue 2025-12-17 11:34:12 +08:00
Copilot 58e32b7b70 fix: inverted logic in segmented reply LLM-only filter (#4071)
* Initial plan

* Fix: Correct inverted logic in is_seg_reply_required for only_llm_result option

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-12-17 11:12:05 +08:00
Soulter 903dd0f9f7 feat: add manual model addition functionality and search capability in ProviderPage 2025-12-17 10:56:45 +08:00
Soulter 1acac0cac2 feat: enhance provider selection with a new drawer interface and localization updates 2025-12-17 10:39:16 +08:00
Oscar Shaw 80b89fd2ea feat: implements command management and improve webui feature structure (#3904)
move mcp management to plugin managemanet page

* feat: 新增命令配置数据库模型

* feat: 实现核心命令管理系统

* feat: 将命令管理集成到 Star 框架

* feat: 新增命令管理后台 API

* feat: 新增命令管理界面页面

* feat: 新增命令管理国际化支持

* test: 新增命令管理相关测试

* refactor(command): 移除指令重命名时的别名功能

* fix(command): 修正指令冲突检测逻辑

* fix(command): 排除已禁用指令的冲突检测

- 只有 `effective_command` 存在且 `enabled` 为 `True` 的指令才会被纳入冲突检测范围。

* feat(command): 优化指令冲突显示与提示

- 【功能】新增指令冲突警告提示,当检测到冲突时显示详细信息及解决方案。
- 【优化】调整指令列表排序逻辑,将冲突指令优先显示并分组。
- 【样式】为冲突指令行添加专属高亮样式,提升视觉识别度。
- 【国际化】更新英文和中文多语言文件,增加指令冲突警告相关的翻译文本。

* chore(command-page): 禁用命令表格部分列的排序功能

* style(command-page): 调整命令页面表格样式和图标大小

* refactor(command): 优化指令页面布局并更新冲突警告

- 【布局优化】重新组织指令管理页面布局,将筛选器移至顶部独立行
- 【信息展示】将搜索栏与总指令数、已禁用指令数合并显示,提升页面空间利用率
- 【视觉更新】更新指令冲突警告样式

* style: UI 细节

* refactor(command): 调整指令管理中的成员权限显示与筛选

  - 更新指令筛选逻辑,当选择“所有人”权限筛选时,将同时包含 `everyone` 和 `member` 权限的指令。

* feat(command-management): 新增指令层级管理与UI展示

- 【后端】
  - `CommandDescriptor` 新增 `parent_group_handler` 和 `sub_commands` 字段,支持指令层级结构定义。
  - `list_commands` 函数重构,实现指令的层级收集与构建,将子指令正确挂载到其父指令组下。
  - 新增 `_collect_all_descriptors` 和 `_find_parent_group_handler` 辅助函数,用于全面收集指令并定位父指令组。
  - `_build_descriptor` 优化指令类型判断逻辑,明确区分普通指令、指令组和子指令。
  - `_descriptor_to_dict` 递归处理子指令,确保 API 返回完整的指令层级数据。
- 【前端】
  - 指令管理页面 (`CommandPage.vue`) 增加指令类型筛选器,并支持指令组的展开/折叠功能。
  - 表格展示优化,为指令组和子指令添加不同的样式和缩进,提升层级结构的视觉可读性。
  - 指令详情对话框新增指令类型、所属指令组和子指令列表的展示。
  - 更新 `CommandItem` 接口,以适配后端提供的层级数据结构。
- 【i18n】
  - 新增指令类型(指令、指令组、子指令)的国际化文本。
  - 更新指令管理相关 UI 文本,包括表格头部、详情对话框字段和筛选器选项。

* style(command): 优化指令组子指令数量显示UI

* refactor(command): 修改指令列表排序逻辑

* style(command-page): 优化命令列表UI

* feat(command): 添加系统插件指令过滤与冲突处理

* refactor(command): 更新指令数展示逻辑

* style(command): 更新空状态描述

* feat(extension): 添加插件指令冲突检测与提示

- 在插件安装或启用后,自动检测并提示指令冲突。
- 当检测到指令冲突时,显示警告对话框,告知用户冲突数量及可能的影响。

* refactor(command): 移除指令表格内部加载指示器

* style(extension): 文案修改

* refactor(command): 模块化指令管理面板前端代码

* refactor(commandPanel): 重命名指令模块目录为 commandPanel

* style(commandPanel): 微调指令面板UI

* fix(command): 确保新命令配置的事务提交

* fix(sidebar): 补全新增侧边栏项后的侧边栏位追加逻辑

* refactor(commands): 重构/help指令以动态显示实际命令并补充部分命令描述

* style(builtin_commands): 补充命令描述

* refactor(commandPanel): 移除未使用的 filterState 常量

* perf(dashboard): 删除多余的CommandPage.vue文件(已被模块化引用)

* perf(command): 优化命令冲突计数逻辑

* perf(command): 优化指令管理辅助函数和配置绑定逻辑

* perf(db): 优化重构command相关数据库操作

* refactor(sidebar): 提取侧边栏项目解析逻辑到工具函数复用

* refactor: move mcp and command page to extension page

* refactor: remove unused imports in component panel

* fix: update terminology for handler management in extension localization

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-12-16 20:24:57 +08:00
Soulter 26f863ba81 Revert "fix: omit empty content field for the LLM request after tool calls ar…" (#4068)
This reverts commit f78a90218e.
2025-12-16 20:22:13 +08:00
sctop f78a90218e fix: omit empty content field for the LLM request after tool calls are completed (#4008)
* fix: omit content field for the LLM request after tool calls are completed and content is empy string or none

* chore: ruff format

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-12-16 20:11:11 +08:00
Soulter a3ecebd2aa fix: correct text accumulation logic in webchat (#4066) 2025-12-16 19:35:41 +08:00
Soulter 67c33b842d feat: add new provider icons and improve provider source handling
- Added icons for 'modelstack', 'tokenpony', and 'compshare' in providerUtils.js.
- Updated ProviderPage.vue to display the correct count of displayed provider sources.
- Enhanced the logic for displaying provider sources to include placeholders for unselected templates.
- Improved the display name for provider sources to show template keys for placeholders.
- Adjusted styles for better layout and overflow handling in provider source list and cards.
- Refactored source selection logic to handle placeholder sources correctly.
- Updated error handling in provider testing to provide clearer messages.
2025-12-16 16:11:56 +08:00
Soulter 5431c9f46e refactor: remove unused tab from AddNewProvider and disable button based on provider status in ProviderPage 2025-12-16 12:26:26 +08:00
Soulter 764b91a5f7 chore: ruff check 2025-12-16 12:21:14 +08:00
Soulter c20c1b84bf feat: implement LLM metadata fetching and integrate into provider model selection 2025-12-16 12:19:40 +08:00
Soulter fd66a0ac00 perf: better UI 2025-12-16 11:24:07 +08:00
Soulter aaee283367 fix: type checking of AstrAgentContext 2025-12-16 10:09:57 +08:00
Soulter 4a5b7d1976 fix: type checking of contextwrapper 2025-12-16 09:59:56 +08:00
Sukafon 08244548ab fix: incorrect type assignment when the agent send an image (#4050) 2025-12-16 08:28:10 +08:00
dependabot[bot] b486de6a98 chore(deps): bump actions/upload-artifact in the github-actions group (#4061)
Bumps the github-actions group with 1 update: [actions/upload-artifact](https://github.com/actions/upload-artifact).


Updates `actions/upload-artifact` from 5 to 6
- [Release notes](https://github.com/actions/upload-artifact/releases)
- [Commits](https://github.com/actions/upload-artifact/compare/v5...v6)

---
updated-dependencies:
- dependency-name: actions/upload-artifact
  dependency-version: '6'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2025-12-16 08:24:03 +08:00
Soulter e2f928a7e5 chore: bump version to 4.9.2 2025-12-15 16:58:32 +08:00
Soulter b8e4068c75 feat: support key-value storage for plugins (#4048)
* feat: support key-value storage for plugins

* fix: remove unnecessary initialization method from Main class
2025-12-15 16:50:44 +08:00
Soulter 0916177a57 chore: bump version to 4.9.1 2025-12-15 16:07:10 +08:00
Soulter 02cd5e396b feat: add trigger probability setting for TTS and support to render slider in schema (#4047)
* feat: add trigger probability setting for TTS and support to render slider in schema

* chore: ruff format
2025-12-15 16:04:27 +08:00
Soulter 56673ad78f fix: prevent duplicate result content type after streaming finishes in RespondStage 2025-12-15 15:33:40 +08:00
Soulter 9a4d05e2b6 fix: remove unnecessary persistent attribute from ReadmeDialog and adjust dialog structure in ExtensionPage 2025-12-15 15:27:42 +08:00
Soulter b2e9dab233 refactor: enhance layout and improve provider source management in ProviderPage 2025-12-15 15:15:17 +08:00
Soulter 45110200ea feat: update provider and provider source configuration handling 2025-12-15 12:31:29 +08:00
Soulter c3f45449e8 docs: readme
wa ta shi wa ko sei no de su ka ra!
2025-12-15 11:47:21 +08:00
Copilot 65da469deb feat: add conversation export feature to JSONL for AI training (#4037)
* Initial plan

* Add conversation export functionality (backend and frontend)

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

* Address code review feedback: move imports, simplify logic, improve i18n

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

* Simplify frontend download logic: remove redundant Blob wrapper and complex filename parsing

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

* fix: update conversation export filename format for consistency

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
Co-authored-by: Soulter <905617992@qq.com>
2025-12-14 21:44:12 +08:00
Soulter 16df64c405 fix: lark domain and log_level of Lark API client (#4038)
fixes: #4035
2025-12-14 21:31:17 +08:00
i0cLiceao 6b73b19e54 fix: support using GitHub Raw content as plugin source (#3975)
* Update plugin.py

* Update plugin.py

* Update plugin.py

* Update plugin.py
2025-12-14 18:23:29 +08:00
Soulter a70088b799 Merge remote-tracking branch 'origin/master' into refactor/provider-source 2025-12-13 23:37:23 +08:00
Soulter e7e97730af chore: bump version to 4.9.0 2025-12-13 18:49:07 +08:00
Soulter 467ca1eb5c fix: webui log output incompletely (#4029)
* fix: webui log output incompletely

* fix: improve SSE log parsing to handle partial data chunks

* fix: enhance log handling by implementing local cache and fetching history

* fix: log time handling to use epoch time
2025-12-13 18:46:16 +08:00
Soulter bb45d9cb54 stage 2025-12-13 17:16:07 +08:00
RC-CHN 46528391c2 feat: add pre-chunk import strategy for knowledge base (#3973)
* feat: 添加文档导入功能及相关测试

* feat: 优化文档上传功能,支持从文件名推断文件类型,并增强文档切片验证

* feat: 添加文档导入功能的无效输入测试,验证 chunks 类型和内容的错误处理

* refactor: 重构文档上传和导入任务的状态管理,添加任务初始化、结果设置和进度更新方法
2025-12-12 23:15:11 +08:00
Soulter 8a0b7717cc feat: supports webhook mode for Lark platform (#4016)
* feat: add Lark platform support with unified webhook configuration

* fix: update token verification logic in LarkWebhookServer

* feat: implement event deduplication and cleanup for Lark webhook events
2025-12-12 22:12:13 +08:00
Copilot 3b81fb4985 fix: mobile dialog close button visibility (#4010)
* Initial plan

* Fix mobile dialog close button visibility by adding max-height and scrollable content

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-12-12 16:02:24 +08:00
Soulter c09d57a820 refactor: improve UI layout and interaction for list item management (#4002)
* refactor: improve UI layout and interaction for list item management

* feat: enhance list configuration UI with batch import functionality

* feat: add internationalization support for list configuration UI
2025-12-11 18:55:56 +08:00
Soulter ec408a2aff fix: lark message timestamp 2025-12-11 18:20:50 +08:00
Soulter 417179a6b9 ci: add smoke test 2025-12-11 10:44:15 +08:00
Soulter fcd29445c7 refactor: remove unused current provider initialization in StarRequestSubStage 2025-12-11 10:36:33 +08:00
BiDuang 5f535001db fix: incorrect modalities enum of gemini api provider (#3993) 2025-12-10 20:27:51 +08:00
PaloMiku 750d245b16 docs: Update README with new Zread link and badges (#3992)
ZRead 是由智谱 AI 推出的 DeepWiki 类似平替品。
2025-12-10 20:22:56 +08:00
Dt8333 f624971613 chore: fix bunches of type checking errors (#3213)
* chore(core.utils): 🚨 修正错误Lint

* chore(core.provider): 🚨 修复基类错误Lint

* chore(core.utils): 补全session_get()的重载

* chore(core.provider): 🚨 修正实现错误Lint

* chore(core.platform): 🚨 修正platform基类和webchat的错误Lint

* chore(core.platform): 修正错误实现Lint

* fix(core.provider): 修复循环调用和错误assert

* chore(core.platform): 修复部分实现Lint

* chore(core.provider): 补充Dify.text_chat_stream的参数类型

* chore(core.pipeline): 🚨 修复错误Lint

* fix(core.slack): 补充遗漏导入

* chore(core.utils): 修复错误的session_get声明

* chore(core.platform): 移除Lark adapter import中的wildcard

* chore(core.db): 修复声明和部分逻辑

* chore(core.db): 添加typings,使faiss参数能被正确识别。

* chore(core): 修复声明

* chore(core): 修改声明

* chore: 补充faiss声明

* chore(dashboard): 修改实现,减少报错

* chore(package): 修改部分声明与实现,减少报错

* chore(core): 添加Handler的overload,以去除部分assert同时通过类型检查

* chore(core.pipeline): 修改Pipeline Scheduler的execute,将判断属性改为判断类型,通过静态类型检查

* chore(core.config): 添加类型标注,通过类型检查

* chore(core.message): 为File._download_file添加检查,通过类型检查

* fix: 将断言改为条件判断以实现优雅关闭的容错性

* refactor: 移除 discord 客户端中的 assert,改用 if None 判断并抛出异常

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: DiscordPlatformAdapter 对 self.client.user 为 None 做日志并返回,移除断言

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 增强 Lark 相关空值/异常检查并完善日志输出

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 将断言替换为条件检查并加入日志与错误处理

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* chore: 移除LLM生成的无用注释

* refactor: 使用 File.get_file 替换下载逻辑并移除 assert,提供默认 filename

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: Slack Socket 未初始化抛出运行时异常,图片 URL 判空改为非空判断

* refactor: 将 WeChatPadProAdapter 的断言改为空值判断并添加日志

* refactor: 使用 isinstance 替代断言实现类型判断,便于静态检查

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 去除cast,直接使用字段与字典访问,修正端口解析

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 使用 match-case 重构 ProviderManager 加载并通过类型检查抛出 TypeError

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: group_name_display 时若 group 对象为空则记录错误并返回

* fix: 将 _get_current_persona_id 的 assert 替换成 if guard 并返回 None

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 优化插件目录存在性检查及图片URL非空验证,更新JSON排序配置

* fix: 将 datetime_str 的 assert 替换为显式检查并抛出异常

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除 cast,改为运行时检查并在找不到调度器时跳过

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除 cast,改用 isinstance 检查 FaissVecDB 并警告

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 删除 typing.cast 导入,并在获取文件绝对路径前校验 file_

* refactor: 移除 typing.cast,简化内容安全检查调用

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 将 PlatformMetadata.id 设为必填并在注册时传入 id,移除 cast

* refactor: 移除 cast,改用 HasInitialize 与 isinstance 进行初始化

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 为 ProviderManager.initialize 增加ID类型判断,避免 None 导致 get 失败

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 为 OTTSProvider 与 AzureNativeProvider 引入 _client 与 client 属性改进上下文管理

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 为 Whisper 自托管源添加模型未初始化校验并直接调用 transcribe

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 移除未使用的 cast 导入并简化 platform_name 赋值

* refactor: 引入 cast 并对 id 使用 cast(str, ...) 提升类型安全

* fix: 将 _id_to_sid 返回改为 str,空值返回空串;对 id 与 message_id 使用 cast

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 重构 Discord 处理逻辑:强制 类型转换、优先斜杠指令并优化提及判断

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* fix: 统一对 id 获取执行 cast,并在微信消息解析失败时抛错

* Revert "fix: 去除cast,直接使用字段与字典访问,修正端口解析"

This reverts commit 1cbfdf9d1b.

* fix: 百炼 Rerank 会话关闭时返回空结果;初始化 request.prompt 避免空值拼接

* fix: 统一处理搜索结果链接为字符串,新增 _get_url 助手并适配 Bing/Sogo

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>

* refactor: 调整 call_handler 泛型、Discord 通道注解及 FishAudioTTS API 请求类型

* refactor: 使用 col(...) 替代列引用并对结果进行 CursorResult 强转

* chore: ruff format

---------

Co-authored-by: aider (openai/gemini-3-pro-high) <aider@aider.chat>
Co-authored-by: Soulter <905617992@qq.com>
2025-12-09 14:13:47 +08:00
Soulter aa6d07afcc refactor: move all internal commands from astrbot plugin to default_command plugin (#3960)
* refactor: move all internal commands from astrbot plugin to default_command plugin

* ruff check

* feat: add config

* ruff check
2025-12-08 22:17:32 +08:00
Soulter 2c36649874 feat: add Agent Runner test prompt dialog in ProviderPage (#3968) 2025-12-08 21:46:47 +08:00
Soulter c95735dcc0 docs: update readme 2025-12-08 12:05:57 +08:00
Soulter 03bb278f50 chore: ruff check 2025-12-08 11:00:43 +08:00
Soulter a5e0974da3 chore: ruff format 2025-12-08 00:36:56 +08:00
vmoranv f0fb447fbc feat: custom plugin api source manager (#3956)
* feat: custom plugin api source manager

* fix: rename plugin source file in a safer way

* chore: turned the way of saving plugin source to backend and refacted some components

* style: clean up whitespace and improve logging message formatting

---------

Co-authored-by: Soulter <905617992@qq.com>
2025-12-08 00:32:50 +08:00
Soulter 37566182b0 feat: segment reply supports segmentation words (#3959)
* feat: segment reply supports segmentation words

* chore: ruff format

* feat: enhance segmented reply processing by refining word extraction logic

* ruff format
2025-12-08 00:27:17 +08:00
Soulter e460b411da chore: remove dev version from webui (#3951)
* chore: remove dev version

* chore: remove development version references from header localization files
2025-12-07 15:23:30 +08:00
257 changed files with 13754 additions and 4493 deletions
+1 -1
View File
@@ -36,7 +36,7 @@ jobs:
zip -r dist.zip dist zip -r dist.zip dist
- name: Archive production artifacts - name: Archive production artifacts
uses: actions/upload-artifact@v5 uses: actions/upload-artifact@v6
with: with:
name: dist-without-markdown name: dist-without-markdown
path: | path: |
+58
View File
@@ -0,0 +1,58 @@
name: Smoke Test
on:
push:
branches:
- master
paths-ignore:
- 'README*.md'
- 'changelogs/**'
- 'dashboard/**'
pull_request:
workflow_dispatch:
jobs:
smoke-test:
name: Run smoke tests
runs-on: ubuntu-latest
timeout-minutes: 10
steps:
- name: Checkout
uses: actions/checkout@v6
with:
fetch-depth: 0
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.12'
- name: Install UV package manager
run: |
pip install uv
- name: Install dependencies
run: |
uv sync
timeout-minutes: 15
- name: Run smoke tests
run: |
uv run main.py &
APP_PID=$!
echo "Waiting for application to start..."
for i in {1..60}; do
if curl -f http://localhost:6185 > /dev/null 2>&1; then
echo "Application started successfully!"
kill $APP_PID
exit 0
fi
sleep 1
done
echo "Application failed to start within 30 seconds"
kill $APP_PID 2>/dev/null || true
exit 1
timeout-minutes: 2
+26 -1
View File
@@ -33,6 +33,20 @@
- 请使用英文描述您的 PR。 - 请使用英文描述您的 PR。
- 标题请使用 `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` 等语义化前缀,并简要描述更改内容。如:`fix: correct login page typo` - 标题请使用 `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` 等语义化前缀,并简要描述更改内容。如:`fix: correct login page typo`
#### 代码规范
##### Core
我们使用 Ruff 作为代码格式化和静态分析工具。在提交代码之前,请运行以下命令以确保代码符合规范:
```bash
ruff format .
ruff check .
```
如果您使用 VSCode,可以安装 `Ruff` 插件。
## Contributing Guide ## Contributing Guide
First off, thanks for taking the time to contribute! ❤️ First off, thanks for taking the time to contribute! ❤️
@@ -62,4 +76,15 @@ We use the `fix/` prefix for bug fixes and the `feat/` prefix for new features.
#### PR Description #### PR Description
- Please use English to describe your PR. - Please use English to describe your PR.
- Use semantic prefixes like `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` in the title, followed by a brief description of the changes, e.g., `fix: correct login page typo`. - Use semantic prefixes like `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` in the title, followed by a brief description of the changes, e.g., `fix: correct login page typo`.
#### Code Style
##### Core
We use Ruff as our code formatter and static analysis tool. Before submitting your code, please run the following commands to ensure your code adheres to the style guidelines:
```bash
ruff format .
ruff check .
```
+9 -1
View File
@@ -1,4 +1,4 @@
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9) ![astrbot-banner-xmas](https://github.com/user-attachments/assets/bf2341de-ec7a-45a7-a04a-02ad36450e99)
<div align="center"> <div align="center">
@@ -20,6 +20,7 @@
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest"> <img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python"> <img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot"> <img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a> <a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600"> <img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
<img src="https://gitcode.com/Soulter/AstrBot/star/badge.svg" href="https://gitcode.com/Soulter/AstrBot"> <img src="https://gitcode.com/Soulter/AstrBot/star/badge.svg" href="https://gitcode.com/Soulter/AstrBot">
@@ -206,6 +207,7 @@ pre-commit install
- 3 群:630166526 - 3 群:630166526
- 5 群:822130018 - 5 群:822130018
- 6 群:753075035 - 6 群:753075035
- 7 群:743746109
- 开发者群:975206796 - 开发者群:975206796
### Telegram 群组 ### Telegram 群组
@@ -241,4 +243,10 @@ pre-commit install
</details> </details>
<div align="center">
_私は、高性能ですから!_ _私は、高性能ですから!_
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
</div
+1 -1
View File
@@ -1 +1 @@
__version__ = "4.8.0" __version__ = "4.10.2"
+6 -4
View File
@@ -3,7 +3,7 @@
from typing import Any, ClassVar, Literal, cast from typing import Any, ClassVar, Literal, cast
from pydantic import BaseModel, GetCoreSchemaHandler, model_validator from pydantic import BaseModel, GetCoreSchemaHandler, model_serializer, model_validator
from pydantic_core import core_schema from pydantic_core import core_schema
@@ -122,10 +122,12 @@ class ToolCall(BaseModel):
extra_content: dict[str, Any] | None = None extra_content: dict[str, Any] | None = None
"""Extra metadata for the tool call.""" """Extra metadata for the tool call."""
def model_dump(self, **kwargs: Any) -> dict[str, Any]: @model_serializer(mode="wrap")
def serialize(self, handler):
data = handler(self)
if self.extra_content is None: if self.extra_content is None:
kwargs.setdefault("exclude", set()).add("extra_content") data.pop("extra_content", None)
return super().model_dump(**kwargs) return data
class ToolCallPart(BaseModel): class ToolCallPart(BaseModel):
+22 -1
View File
@@ -1,7 +1,8 @@
import typing as T import typing as T
from dataclasses import dataclass from dataclasses import dataclass, field
from astrbot.core.message.message_event_result import MessageChain from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import TokenUsage
class AgentResponseData(T.TypedDict): class AgentResponseData(T.TypedDict):
@@ -12,3 +13,23 @@ class AgentResponseData(T.TypedDict):
class AgentResponse: class AgentResponse:
type: str type: str
data: AgentResponseData data: AgentResponseData
@dataclass
class AgentStats:
token_usage: TokenUsage = field(default_factory=TokenUsage)
start_time: float = 0.0
end_time: float = 0.0
time_to_first_token: float = 0.0
@property
def duration(self) -> float:
return self.end_time - self.start_time
def to_dict(self) -> dict:
return {
"token_usage": self.token_usage.__dict__,
"start_time": self.start_time,
"end_time": self.end_time,
"time_to_first_token": self.time_to_first_token,
}
+1 -1
View File
@@ -9,7 +9,7 @@ from .message import Message
TContext = TypeVar("TContext", default=Any) TContext = TypeVar("TContext", default=Any)
@dataclass(config={"arbitrary_types_allowed": True}) @dataclass
class ContextWrapper(Generic[TContext]): class ContextWrapper(Generic[TContext]):
"""A context for running an agent, which can be used to pass additional data or state.""" """A context for running an agent, which can be used to pass additional data or state."""
@@ -1,4 +1,5 @@
import sys import sys
import time
import traceback import traceback
import typing as T import typing as T
@@ -12,6 +13,7 @@ from mcp.types import (
) )
from astrbot import logger from astrbot import logger
from astrbot.core.message.components import Json
from astrbot.core.message.message_event_result import ( from astrbot.core.message.message_event_result import (
MessageChain, MessageChain,
) )
@@ -24,7 +26,7 @@ from astrbot.core.provider.provider import Provider
from ..hooks import BaseAgentRunHooks from ..hooks import BaseAgentRunHooks
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
from ..response import AgentResponseData from ..response import AgentResponseData, AgentStats
from ..run_context import ContextWrapper, TContext from ..run_context import ContextWrapper, TContext
from ..tool_executor import BaseFunctionToolExecutor from ..tool_executor import BaseFunctionToolExecutor
from .base import AgentResponse, AgentState, BaseAgentRunner from .base import AgentResponse, AgentState, BaseAgentRunner
@@ -69,14 +71,25 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
) )
self.run_context.messages = messages self.run_context.messages = messages
self.stats = AgentStats()
self.stats.start_time = time.time()
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]: async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
"""Yields chunks *and* a final LLMResponse.""" """Yields chunks *and* a final LLMResponse."""
payload = {
"contexts": self.run_context.messages, # list[Message]
"func_tool": self.req.func_tool,
"model": self.req.model, # NOTE: in fact, this arg is None in most cases
"session_id": self.req.session_id,
"extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart]
}
if self.streaming: if self.streaming:
stream = self.provider.text_chat_stream(**self.req.__dict__) stream = self.provider.text_chat_stream(**payload)
async for resp in stream: # type: ignore async for resp in stream: # type: ignore
yield resp yield resp
else: else:
yield await self.provider.text_chat(**self.req.__dict__) yield await self.provider.text_chat(**payload)
@override @override
async def step(self): async def step(self):
@@ -97,8 +110,11 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
llm_resp_result = None llm_resp_result = None
async for llm_response in self._iter_llm_responses(): async for llm_response in self._iter_llm_responses():
assert isinstance(llm_response, LLMResponse)
if llm_response.is_chunk: if llm_response.is_chunk:
# update ttft
if self.stats.time_to_first_token == 0:
self.stats.time_to_first_token = time.time() - self.stats.start_time
if llm_response.result_chain: if llm_response.result_chain:
yield AgentResponse( yield AgentResponse(
type="streaming_delta", type="streaming_delta",
@@ -122,6 +138,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
) )
continue continue
llm_resp_result = llm_response llm_resp_result = llm_response
if not llm_response.is_chunk and llm_response.usage:
# only count the token usage of the final response for computation purpose
self.stats.token_usage += llm_response.usage
break # got final response break # got final response
if not llm_resp_result: if not llm_resp_result:
@@ -133,6 +153,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
if llm_resp.role == "err": if llm_resp.role == "err":
# 如果 LLM 响应错误,转换到错误状态 # 如果 LLM 响应错误,转换到错误状态
self.final_llm_resp = llm_resp self.final_llm_resp = llm_resp
self.stats.end_time = time.time()
self._transition_state(AgentState.ERROR) self._transition_state(AgentState.ERROR)
yield AgentResponse( yield AgentResponse(
type="err", type="err",
@@ -147,11 +168,12 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# 如果没有工具调用,转换到完成状态 # 如果没有工具调用,转换到完成状态
self.final_llm_resp = llm_resp self.final_llm_resp = llm_resp
self._transition_state(AgentState.DONE) self._transition_state(AgentState.DONE)
self.stats.end_time = time.time()
# record the final assistant message # record the final assistant message
self.run_context.messages.append( self.run_context.messages.append(
Message( Message(
role="assistant", role="assistant",
content=llm_resp.completion_text or "", content=llm_resp.completion_text or "*No response*",
), ),
) )
try: try:
@@ -176,22 +198,19 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# 如果有工具调用,还需处理工具调用 # 如果有工具调用,还需处理工具调用
if llm_resp.tools_call_name: if llm_resp.tools_call_name:
tool_call_result_blocks = [] tool_call_result_blocks = []
for tool_call_name in llm_resp.tools_call_name:
yield AgentResponse(
type="tool_call",
data=AgentResponseData(
chain=MessageChain(type="tool_call").message(
f"🔨 调用工具: {tool_call_name}"
),
),
)
async for result in self._handle_function_tools(self.req, llm_resp): async for result in self._handle_function_tools(self.req, llm_resp):
if isinstance(result, list): if isinstance(result, list):
tool_call_result_blocks = result tool_call_result_blocks = result
elif isinstance(result, MessageChain): elif isinstance(result, MessageChain):
result.type = "tool_call_result" if result.type is None:
# should not happen
continue
if result.type == "tool_direct_result":
ar_type = "tool_call_result"
else:
ar_type = result.type
yield AgentResponse( yield AgentResponse(
type="tool_call_result", type=ar_type,
data=AgentResponseData(chain=result), data=AgentResponseData(chain=result),
) )
# 将结果添加到上下文中 # 将结果添加到上下文中
@@ -219,6 +238,25 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
async for resp in self.step(): async for resp in self.step():
yield resp yield resp
# 如果循环结束了但是 agent 还没有完成,说明是达到了 max_step
if not self.done():
logger.warning(
f"Agent reached max steps ({max_step}), forcing a final response."
)
# 拔掉所有工具
if self.req:
self.req.func_tool = None
# 注入提示词
self.run_context.messages.append(
Message(
role="user",
content="工具调用次数已达到上限,请停止使用工具,并根据已经收集到的信息,对你的任务和发现进行总结,然后直接回复用户。",
)
)
# 再执行最后一步
async for resp in self.step():
yield resp
async def _handle_function_tools( async def _handle_function_tools(
self, self,
req: ProviderRequest, req: ProviderRequest,
@@ -234,6 +272,19 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
llm_response.tools_call_args, llm_response.tools_call_args,
llm_response.tools_call_ids, llm_response.tools_call_ids,
): ):
yield MessageChain(
type="tool_call",
chain=[
Json(
data={
"id": func_tool_id,
"name": func_tool_name,
"args": func_tool_args,
"ts": time.time(),
}
)
],
)
try: try:
if not req.func_tool: if not req.func_tool:
return return
@@ -307,7 +358,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
content=res.content[0].text, content=res.content[0].text,
), ),
) )
yield MessageChain().message(res.content[0].text)
elif isinstance(res.content[0], ImageContent): elif isinstance(res.content[0], ImageContent):
tool_call_result_blocks.append( tool_call_result_blocks.append(
ToolCallMessageSegment( ToolCallMessageSegment(
@@ -329,7 +379,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
content=resource.text, content=resource.text,
), ),
) )
yield MessageChain().message(resource.text)
elif ( elif (
isinstance(resource, BlobResourceContents) isinstance(resource, BlobResourceContents)
and resource.mimeType and resource.mimeType
@@ -353,20 +402,34 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
content="返回的数据类型不受支持", content="返回的数据类型不受支持",
), ),
) )
yield MessageChain().message("返回的数据类型不受支持。")
elif resp is None: elif resp is None:
# Tool 直接请求发送消息给用户 # Tool 直接请求发送消息给用户
# 这里我们将直接结束 Agent Loop。 # 这里我们将直接结束 Agent Loop。
# 发送消息逻辑在 ToolExecutor 中处理了。 # 发送消息逻辑在 ToolExecutor 中处理了。
logger.warning( logger.warning(
f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户,此工具调用不会被记录到历史中" f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户。"
) )
self._transition_state(AgentState.DONE) self._transition_state(AgentState.DONE)
self.stats.end_time = time.time()
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="*工具没有返回值或者将结果直接发送给了用户*",
),
)
else: else:
# 不应该出现其他类型 # 不应该出现其他类型
logger.warning( logger.warning(
f"Tool 返回了不支持的类型: {type(resp)},将忽略", f"Tool 返回了不支持的类型: {type(resp)}",
)
tool_call_result_blocks.append(
ToolCallMessageSegment(
role="tool",
tool_call_id=func_tool_id,
content="*工具返回了不支持的类型,请告诉用户检查这个工具的定义和实现。*",
),
) )
try: try:
@@ -388,6 +451,22 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
), ),
) )
# yield the last tool call result
if tool_call_result_blocks:
last_tcr_content = str(tool_call_result_blocks[-1].content)
yield MessageChain(
type="tool_call_result",
chain=[
Json(
data={
"id": func_tool_id,
"ts": time.time(),
"result": last_tcr_content,
}
)
],
)
# 处理函数调用响应 # 处理函数调用响应
if tool_call_result_blocks: if tool_call_result_blocks:
yield tool_call_result_blocks yield tool_call_result_blocks
+7 -2
View File
@@ -1,4 +1,4 @@
from collections.abc import Awaitable, Callable from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any, Generic from typing import Any, Generic
import jsonschema import jsonschema
@@ -7,6 +7,8 @@ from deprecated import deprecated
from pydantic import Field, model_validator from pydantic import Field, model_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
from astrbot.core.message.message_event_result import MessageEventResult
from .run_context import ContextWrapper, TContext from .run_context import ContextWrapper, TContext
ParametersType = dict[str, Any] ParametersType = dict[str, Any]
@@ -38,7 +40,10 @@ class ToolSchema:
class FunctionTool(ToolSchema, Generic[TContext]): class FunctionTool(ToolSchema, Generic[TContext]):
"""A callable tool, for function calling.""" """A callable tool, for function calling."""
handler: Callable[..., Awaitable[Any]] | None = None handler: (
Callable[..., Awaitable[str | None] | AsyncGenerator[MessageEventResult, None]]
| None
) = None
"""a callable that implements the tool's functionality. It should be an async function.""" """a callable that implements the tool's functionality. It should be an async function."""
handler_module_path: str | None = None handler_module_path: str | None = None
+3 -1
View File
@@ -6,8 +6,10 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.context import Context from astrbot.core.star.context import Context
@dataclass(config={"arbitrary_types_allowed": True}) @dataclass
class AstrAgentContext: class AstrAgentContext:
__pydantic_config__ = {"arbitrary_types_allowed": True}
context: Context context: Context
"""The star context instance""" """The star context instance"""
event: AstrMessageEvent event: AstrMessageEvent
+42 -3
View File
@@ -2,8 +2,10 @@ import traceback
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from astrbot.core import logger from astrbot.core import logger
from astrbot.core.agent.message import Message
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.astr_agent_context import AstrAgentContext from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.components import Json
from astrbot.core.message.message_event_result import ( from astrbot.core.message.message_event_result import (
MessageChain, MessageChain,
MessageEventResult, MessageEventResult,
@@ -23,8 +25,25 @@ async def run_agent(
) -> AsyncGenerator[MessageChain | None, None]: ) -> AsyncGenerator[MessageChain | None, None]:
step_idx = 0 step_idx = 0
astr_event = agent_runner.run_context.context.event astr_event = agent_runner.run_context.context.event
while step_idx < max_step: while step_idx < max_step + 1:
step_idx += 1 step_idx += 1
if step_idx == max_step + 1:
logger.warning(
f"Agent reached max steps ({max_step}), forcing a final response."
)
if not agent_runner.done():
# 拔掉所有工具
if agent_runner.req:
agent_runner.req.func_tool = None
# 注入提示词
agent_runner.run_context.messages.append(
Message(
role="user",
content="工具调用次数已达到上限,请停止使用工具,并根据已经收集到的信息,对你的任务和发现进行总结,然后直接回复用户。",
)
)
try: try:
async for resp in agent_runner.step(): async for resp in agent_runner.step():
if astr_event.is_stopped(): if astr_event.is_stopped():
@@ -33,16 +52,27 @@ async def run_agent(
msg_chain = resp.data["chain"] msg_chain = resp.data["chain"]
if msg_chain.type == "tool_direct_result": if msg_chain.type == "tool_direct_result":
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容 # tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
await astr_event.send(resp.data["chain"]) await astr_event.send(msg_chain)
continue continue
if astr_event.get_platform_id() == "webchat":
await astr_event.send(msg_chain)
# 对于其他情况,暂时先不处理 # 对于其他情况,暂时先不处理
continue continue
elif resp.type == "tool_call": elif resp.type == "tool_call":
if agent_runner.streaming: if agent_runner.streaming:
# 用来标记流式响应需要分节 # 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break") yield MessageChain(chain=[], type="break")
if show_tool_use:
if astr_event.get_platform_name() == "webchat":
await astr_event.send(resp.data["chain"]) await astr_event.send(resp.data["chain"])
elif show_tool_use:
json_comp = resp.data["chain"].chain[0]
if isinstance(json_comp, Json):
m = f"🔨 调用工具: {json_comp.data.get('name')}"
else:
m = "🔨 调用工具..."
chain = MessageChain(type="tool_call").message(m)
await astr_event.send(chain)
continue continue
if stream_to_general and resp.type == "streaming_delta": if stream_to_general and resp.type == "streaming_delta":
@@ -69,6 +99,15 @@ async def run_agent(
continue continue
yield resp.data["chain"] # MessageChain yield resp.data["chain"] # MessageChain
if agent_runner.done(): if agent_runner.done():
# send agent stats to webchat
if astr_event.get_platform_name() == "webchat":
await astr_event.send(
MessageChain(
type="agent_stats",
chain=[Json(data=agent_runner.stats.to_dict())],
)
)
break break
except Exception as e: except Exception as e:
+39 -5
View File
@@ -185,7 +185,11 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
async def call_local_llm_tool( async def call_local_llm_tool(
context: ContextWrapper[AstrAgentContext], context: ContextWrapper[AstrAgentContext],
handler: T.Callable[..., T.Awaitable[T.Any]], handler: T.Callable[
...,
T.Awaitable[MessageEventResult | mcp.types.CallToolResult | str | None]
| T.AsyncGenerator[MessageEventResult | CommandResult | str | None, None],
],
method_name: str, method_name: str,
*args, *args,
**kwargs, **kwargs,
@@ -205,12 +209,42 @@ async def call_local_llm_tool(
else: else:
raise ValueError(f"未知的方法名: {method_name}") raise ValueError(f"未知的方法名: {method_name}")
except ValueError as e: except ValueError as e:
logger.error(f"调用本地 LLM 工具时出错: {e}", exc_info=True) raise Exception(f"Tool execution ValueError: {e}") from e
except TypeError: except TypeError as e:
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True) # 获取函数的签名(包括类型),除了第一个 event/context 参数。
try:
sig = inspect.signature(handler)
params = list(sig.parameters.values())
# 跳过第一个参数(event 或 context
if params:
params = params[1:]
param_strs = []
for param in params:
param_str = param.name
if param.annotation != inspect.Parameter.empty:
# 获取类型注解的字符串表示
if isinstance(param.annotation, type):
type_str = param.annotation.__name__
else:
type_str = str(param.annotation)
param_str += f": {type_str}"
if param.default != inspect.Parameter.empty:
param_str += f" = {param.default!r}"
param_strs.append(param_str)
handler_param_str = (
", ".join(param_strs) if param_strs else "(no additional parameters)"
)
except Exception:
handler_param_str = "(unable to inspect signature)"
raise Exception(
f"Tool handler parameter mismatch, please check the handler definition. Handler parameters: {handler_param_str}"
) from e
except Exception as e: except Exception as e:
trace_ = traceback.format_exc() trace_ = traceback.format_exc()
logger.error(f"调用本地 LLM 工具时出错: {e}\n{trace_}") raise Exception(f"Tool execution error: {e}. Traceback: {trace_}") from e
if not ready_to_call: if not ready_to_call:
return return
+4
View File
@@ -24,6 +24,10 @@ class AstrBotConfig(dict):
- 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。 - 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。
""" """
config_path: str
default_config: dict
schema: dict | None
def __init__( def __init__(
self, self,
config_path: str = ASTRBOT_CONFIG_PATH, config_path: str = ASTRBOT_CONFIG_PATH,
+226 -212
View File
@@ -1,10 +1,11 @@
"""如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。""" """如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。"""
import os import os
from typing import Any, TypedDict
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.8.0" VERSION = "4.10.2"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db") DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
WEBHOOK_SUPPORTED_PLATFORMS = [ WEBHOOK_SUPPORTED_PLATFORMS = [
@@ -13,6 +14,7 @@ WEBHOOK_SUPPORTED_PLATFORMS = [
"wecom", "wecom",
"wecom_ai_bot", "wecom_ai_bot",
"slack", "slack",
"lark",
] ]
# 默认配置 # 默认配置
@@ -42,7 +44,15 @@ DEFAULT_CONFIG = {
"interval": "1.5,3.5", "interval": "1.5,3.5",
"log_base": 2.6, "log_base": 2.6,
"words_count_threshold": 150, "words_count_threshold": 150,
"split_mode": "regex", # regex 或 words
"regex": ".*?[。?!~…]+|.+$", "regex": ".*?[。?!~…]+|.+$",
"split_words": [
"",
"",
"",
"~",
"",
], # 当 split_mode 为 words 时使用
"content_cleanup_rule": "", "content_cleanup_rule": "",
}, },
"no_permission_reply": True, "no_permission_reply": True,
@@ -52,7 +62,8 @@ DEFAULT_CONFIG = {
"ignore_bot_self_message": False, "ignore_bot_self_message": False,
"ignore_at_all": False, "ignore_at_all": False,
}, },
"provider": [], "provider_sources": [], # provider sources
"provider": [], # models from provider_sources
"provider_settings": { "provider_settings": {
"enable": True, "enable": True,
"default_provider_id": "", "default_provider_id": "",
@@ -99,6 +110,7 @@ DEFAULT_CONFIG = {
"provider_id": "", "provider_id": "",
"dual_output": False, "dual_output": False,
"use_file_service": False, "use_file_service": False,
"trigger_probability": 1.0,
}, },
"provider_ltm_settings": { "provider_ltm_settings": {
"group_icl_enable": False, "group_icl_enable": False,
@@ -157,9 +169,26 @@ DEFAULT_CONFIG = {
"kb_fusion_top_k": 20, # 知识库检索融合阶段返回结果数量 "kb_fusion_top_k": 20, # 知识库检索融合阶段返回结果数量
"kb_final_top_k": 5, # 知识库检索最终返回结果数量 "kb_final_top_k": 5, # 知识库检索最终返回结果数量
"kb_agentic_mode": False, "kb_agentic_mode": False,
"disable_builtin_commands": False,
} }
class ChatProviderTemplate(TypedDict):
id: str
provider_source_id: str
model: str
modalities: list
custom_extra_body: dict[str, Any]
CHAT_PROVIDER_TEMPLATE = {
"id": "",
"provide_source_id": "",
"model": "",
"modalities": [],
"custom_extra_body": {},
}
""" """
AstrBot v3 时代的配置元数据,目前仅承担以下功能: AstrBot v3 时代的配置元数据,目前仅承担以下功能:
@@ -198,7 +227,7 @@ CONFIG_METADATA_2 = {
"callback_server_host": "0.0.0.0", "callback_server_host": "0.0.0.0",
"port": 6196, "port": 6196,
}, },
"QQ 个人号(OneBot v11)": { "OneBot v11": {
"id": "default", "id": "default",
"type": "aiocqhttp", "type": "aiocqhttp",
"enable": False, "enable": False,
@@ -268,6 +297,10 @@ CONFIG_METADATA_2 = {
"app_id": "", "app_id": "",
"app_secret": "", "app_secret": "",
"domain": "https://open.feishu.cn", "domain": "https://open.feishu.cn",
"lark_connection_mode": "socket", # webhook, socket
"webhook_uuid": "",
"lark_encrypt_key": "",
"lark_verification_token": "",
}, },
"钉钉(DingTalk)": { "钉钉(DingTalk)": {
"id": "dingtalk", "id": "dingtalk",
@@ -361,6 +394,28 @@ CONFIG_METADATA_2 = {
# "type": "string", # "type": "string",
# "options": ["fullscreen", "embedded"], # "options": ["fullscreen", "embedded"],
# }, # },
"lark_connection_mode": {
"description": "订阅方式",
"type": "string",
"options": ["socket", "webhook"],
"labels": ["长连接模式", "推送至服务器模式"],
},
"lark_encrypt_key": {
"description": "Encrypt Key",
"type": "string",
"hint": "用于解密飞书回调数据的加密密钥",
"condition": {
"lark_connection_mode": "webhook",
},
},
"lark_verification_token": {
"description": "Verification Token",
"type": "string",
"hint": "用于验证飞书回调请求的令牌",
"condition": {
"lark_connection_mode": "webhook",
},
},
"is_sandbox": { "is_sandbox": {
"description": "沙箱模式", "description": "沙箱模式",
"type": "bool", "type": "bool",
@@ -807,6 +862,7 @@ CONFIG_METADATA_2 = {
"metadata": { "metadata": {
"provider": { "provider": {
"type": "list", "type": "list",
# provider sources templates
"config_template": { "config_template": {
"OpenAI": { "OpenAI": {
"id": "openai", "id": "openai",
@@ -817,107 +873,10 @@ CONFIG_METADATA_2 = {
"key": [], "key": [],
"api_base": "https://api.openai.com/v1", "api_base": "https://api.openai.com/v1",
"timeout": 120, "timeout": 120,
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
"custom_headers": {}, "custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
"hint": "也兼容所有与 OpenAI API 兼容的服务。",
}, },
"Azure OpenAI": { "Google Gemini": {
"id": "azure", "id": "google_gemini",
"provider": "azure",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"api_version": "2024-05-01-preview",
"key": [],
"api_base": "",
"timeout": 120,
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"xAI": {
"id": "xai",
"provider": "xai",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.x.ai/v1",
"timeout": 120,
"model_config": {"model": "grok-2-latest", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"xai_native_search": False,
"modalities": ["text", "image", "tool_use"],
},
"Anthropic": {
"hint": "注意Claude系列模型的温度调节范围为0到1.0,超出可能导致报错",
"id": "claude",
"provider": "anthropic",
"type": "anthropic_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.anthropic.com/v1",
"timeout": 120,
"model_config": {
"model": "claude-3-5-sonnet-latest",
"max_tokens": 4096,
"temperature": 0.2,
},
"modalities": ["text", "image", "tool_use"],
},
"Ollama": {
"hint": "启用前请确保已正确安装并运行 Ollama 服务端,Ollama默认不带鉴权,无需修改key",
"id": "ollama_default",
"provider": "ollama",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": ["ollama"], # ollama 的 key 默认是 ollama
"api_base": "http://localhost:11434/v1",
"model_config": {"model": "llama3.1-8b", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"LM Studio": {
"id": "lm_studio",
"provider": "lm_studio",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": ["lmstudio"],
"api_base": "http://localhost:1234/v1",
"model_config": {
"model": "llama-3.1-8b",
},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"Gemini(OpenAI兼容)": {
"id": "gemini_default",
"provider": "google",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
"timeout": 120,
"model_config": {
"model": "gemini-1.5-flash",
"temperature": 0.4,
},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"Gemini": {
"id": "gemini_default",
"provider": "google", "provider": "google",
"type": "googlegenai_chat_completion", "type": "googlegenai_chat_completion",
"provider_type": "chat_completion", "provider_type": "chat_completion",
@@ -925,10 +884,6 @@ CONFIG_METADATA_2 = {
"key": [], "key": [],
"api_base": "https://generativelanguage.googleapis.com/", "api_base": "https://generativelanguage.googleapis.com/",
"timeout": 120, "timeout": 120,
"model_config": {
"model": "gemini-2.0-flash-exp",
"temperature": 0.4,
},
"gm_resp_image_modal": False, "gm_resp_image_modal": False,
"gm_native_search": False, "gm_native_search": False,
"gm_native_coderunner": False, "gm_native_coderunner": False,
@@ -939,13 +894,43 @@ CONFIG_METADATA_2 = {
"sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE", "sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE",
"dangerous_content": "BLOCK_MEDIUM_AND_ABOVE", "dangerous_content": "BLOCK_MEDIUM_AND_ABOVE",
}, },
"gm_thinking_config": { "gm_thinking_config": {"budget": 0, "level": "HIGH"},
"budget": 0, },
}, "Anthropic": {
"modalities": ["text", "image", "tool_use"], "id": "anthropic",
"provider": "anthropic",
"type": "anthropic_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.anthropic.com/v1",
"timeout": 120,
},
"Moonshot": {
"id": "moonshot",
"provider": "moonshot",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"timeout": 120,
"api_base": "https://api.moonshot.cn/v1",
"custom_headers": {},
},
"xAI": {
"id": "xai",
"provider": "xai",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://api.x.ai/v1",
"timeout": 120,
"custom_headers": {},
"xai_native_search": False,
}, },
"DeepSeek": { "DeepSeek": {
"id": "deepseek_default", "id": "deepseek",
"provider": "deepseek", "provider": "deepseek",
"type": "openai_chat_completion", "type": "openai_chat_completion",
"provider_type": "chat_completion", "provider_type": "chat_completion",
@@ -953,13 +938,75 @@ CONFIG_METADATA_2 = {
"key": [], "key": [],
"api_base": "https://api.deepseek.com/v1", "api_base": "https://api.deepseek.com/v1",
"timeout": 120, "timeout": 120,
"model_config": {"model": "deepseek-chat", "temperature": 0.4},
"custom_headers": {}, "custom_headers": {},
"custom_extra_body": {}, },
"modalities": ["text", "tool_use"], "Zhipu": {
"id": "zhipu",
"provider": "zhipu",
"type": "zhipu_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"timeout": 120,
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
"custom_headers": {},
},
"Azure OpenAI": {
"id": "azure_openai",
"provider": "azure",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"api_version": "2024-05-01-preview",
"key": [],
"api_base": "",
"timeout": 120,
"custom_headers": {},
},
"Ollama": {
"id": "ollama",
"provider": "ollama",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": ["ollama"], # ollama 的 key 默认是 ollama
"api_base": "http://127.0.0.1:11434/v1",
"custom_headers": {},
},
"LM Studio": {
"id": "lm_studio",
"provider": "lm_studio",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": ["lmstudio"],
"api_base": "http://127.0.0.1:1234/v1",
"custom_headers": {},
},
"ModelStack": {
"id": "modelstack",
"provider": "modelstack",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://modelstack.app/v1",
"timeout": 120,
"custom_headers": {},
},
"Gemini_OpenAI_API": {
"id": "google_gemini_openai",
"provider": "google",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
"timeout": 120,
"custom_headers": {},
}, },
"Groq": { "Groq": {
"id": "groq_default", "id": "groq",
"provider": "groq", "provider": "groq",
"type": "groq_chat_completion", "type": "groq_chat_completion",
"provider_type": "chat_completion", "provider_type": "chat_completion",
@@ -967,13 +1014,7 @@ CONFIG_METADATA_2 = {
"key": [], "key": [],
"api_base": "https://api.groq.com/openai/v1", "api_base": "https://api.groq.com/openai/v1",
"timeout": 120, "timeout": 120,
"model_config": {
"model": "openai/gpt-oss-20b",
"temperature": 0.4,
},
"custom_headers": {}, "custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "tool_use"],
}, },
"302.AI": { "302.AI": {
"id": "302ai", "id": "302ai",
@@ -984,12 +1025,9 @@ CONFIG_METADATA_2 = {
"key": [], "key": [],
"api_base": "https://api.302.ai/v1", "api_base": "https://api.302.ai/v1",
"timeout": 120, "timeout": 120,
"model_config": {"model": "gpt-4.1-mini", "temperature": 0.4},
"custom_headers": {}, "custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
}, },
"硅基流动": { "SiliconFlow": {
"id": "siliconflow", "id": "siliconflow",
"provider": "siliconflow", "provider": "siliconflow",
"type": "openai_chat_completion", "type": "openai_chat_completion",
@@ -998,15 +1036,9 @@ CONFIG_METADATA_2 = {
"key": [], "key": [],
"timeout": 120, "timeout": 120,
"api_base": "https://api.siliconflow.cn/v1", "api_base": "https://api.siliconflow.cn/v1",
"model_config": {
"model": "deepseek-ai/DeepSeek-V3",
"temperature": 0.4,
},
"custom_headers": {}, "custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
}, },
"PPIO派欧云": { "PPIO": {
"id": "ppio", "id": "ppio",
"provider": "ppio", "provider": "ppio",
"type": "openai_chat_completion", "type": "openai_chat_completion",
@@ -1015,14 +1047,9 @@ CONFIG_METADATA_2 = {
"key": [], "key": [],
"api_base": "https://api.ppinfra.com/v3/openai", "api_base": "https://api.ppinfra.com/v3/openai",
"timeout": 120, "timeout": 120,
"model_config": {
"model": "deepseek/deepseek-r1",
"temperature": 0.4,
},
"custom_headers": {}, "custom_headers": {},
"custom_extra_body": {},
}, },
"小马算力": { "TokenPony": {
"id": "tokenpony", "id": "tokenpony",
"provider": "tokenpony", "provider": "tokenpony",
"type": "openai_chat_completion", "type": "openai_chat_completion",
@@ -1031,14 +1058,9 @@ CONFIG_METADATA_2 = {
"key": [], "key": [],
"api_base": "https://api.tokenpony.cn/v1", "api_base": "https://api.tokenpony.cn/v1",
"timeout": 120, "timeout": 120,
"model_config": {
"model": "kimi-k2-instruct-0905",
"temperature": 0.7,
},
"custom_headers": {}, "custom_headers": {},
"custom_extra_body": {},
}, },
"优云智算": { "Compshare": {
"id": "compshare", "id": "compshare",
"provider": "compshare", "provider": "compshare",
"type": "openai_chat_completion", "type": "openai_chat_completion",
@@ -1047,42 +1069,18 @@ CONFIG_METADATA_2 = {
"key": [], "key": [],
"api_base": "https://api.modelverse.cn/v1", "api_base": "https://api.modelverse.cn/v1",
"timeout": 120, "timeout": 120,
"model_config": {
"model": "moonshotai/Kimi-K2-Instruct",
},
"custom_headers": {}, "custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
}, },
"Kimi": { "ModelScope": {
"id": "moonshot", "id": "modelscope",
"provider": "moonshot", "provider": "modelscope",
"type": "openai_chat_completion", "type": "openai_chat_completion",
"provider_type": "chat_completion", "provider_type": "chat_completion",
"enable": True, "enable": True,
"key": [], "key": [],
"timeout": 120, "timeout": 120,
"api_base": "https://api.moonshot.cn/v1", "api_base": "https://api-inference.modelscope.cn/v1",
"model_config": {"model": "moonshot-v1-8k", "temperature": 0.4},
"custom_headers": {}, "custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"智谱 AI": {
"id": "zhipu_default",
"provider": "zhipu",
"type": "zhipu_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"timeout": 120,
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
"model_config": {
"model": "glm-4-flash",
},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
}, },
"Dify": { "Dify": {
"id": "dify_app_default", "id": "dify_app_default",
@@ -1097,7 +1095,6 @@ CONFIG_METADATA_2 = {
"dify_query_input_key": "astrbot_text_query", "dify_query_input_key": "astrbot_text_query",
"variables": {}, "variables": {},
"timeout": 60, "timeout": 60,
"hint": "请确保你在 AstrBot 里设置的 APP 类型和 Dify 里面创建的应用的类型一致!",
}, },
"Coze": { "Coze": {
"id": "coze", "id": "coze",
@@ -1128,20 +1125,6 @@ CONFIG_METADATA_2 = {
"variables": {}, "variables": {},
"timeout": 60, "timeout": 60,
}, },
"ModelScope": {
"id": "modelscope",
"provider": "modelscope",
"type": "openai_chat_completion",
"provider_type": "chat_completion",
"enable": True,
"key": [],
"timeout": 120,
"api_base": "https://api-inference.modelscope.cn/v1",
"model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4},
"custom_headers": {},
"custom_extra_body": {},
"modalities": ["text", "image", "tool_use"],
},
"FastGPT": { "FastGPT": {
"id": "fastgpt", "id": "fastgpt",
"provider": "fastgpt", "provider": "fastgpt",
@@ -1165,7 +1148,6 @@ CONFIG_METADATA_2 = {
"model": "whisper-1", "model": "whisper-1",
}, },
"Whisper(Local)": { "Whisper(Local)": {
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cudaCPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
"provider": "openai", "provider": "openai",
"type": "openai_whisper_selfhost", "type": "openai_whisper_selfhost",
"provider_type": "speech_to_text", "provider_type": "speech_to_text",
@@ -1174,7 +1156,6 @@ CONFIG_METADATA_2 = {
"model": "tiny", "model": "tiny",
}, },
"SenseVoice(Local)": { "SenseVoice(Local)": {
"hint": "启用前请 pip 安装 funasr、funasr_onnx、torchaudio、torch、modelscope、jieba 库(默认使用CPU,大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
"type": "sensevoice_stt_selfhost", "type": "sensevoice_stt_selfhost",
"provider": "sensevoice", "provider": "sensevoice",
"provider_type": "speech_to_text", "provider_type": "speech_to_text",
@@ -1196,7 +1177,6 @@ CONFIG_METADATA_2 = {
"timeout": "20", "timeout": "20",
}, },
"Edge TTS": { "Edge TTS": {
"hint": "提示:使用这个服务前需要安装有 ffmpeg,并且可以直接在终端调用 ffmpeg 指令。",
"id": "edge_tts", "id": "edge_tts",
"provider": "microsoft", "provider": "microsoft",
"type": "edge_tts", "type": "edge_tts",
@@ -1412,6 +1392,10 @@ CONFIG_METADATA_2 = {
}, },
}, },
"items": { "items": {
"provider_source_id": {
"invisible": True,
"type": "string",
},
"xai_native_search": { "xai_native_search": {
"description": "启用原生搜索功能", "description": "启用原生搜索功能",
"type": "bool", "type": "bool",
@@ -1782,13 +1766,24 @@ CONFIG_METADATA_2 = {
}, },
}, },
"gm_thinking_config": { "gm_thinking_config": {
"description": "Gemini思考设置", "description": "Thinking Config",
"type": "object", "type": "object",
"items": { "items": {
"budget": { "budget": {
"description": "思考预算", "description": "Thinking Budget",
"type": "int", "type": "int",
"hint": "模型应该生成的思考Token的数量,设为0关闭思考。除gemini-2.5-flash外的模型会静默忽略此参数。", "hint": "Guides the model on the specific number of thinking tokens to use for reasoning. See: https://ai.google.dev/gemini-api/docs/thinking#set-budget",
},
"level": {
"description": "Thinking Level",
"type": "string",
"hint": "Recommended for Gemini 3 models and onwards, lets you control reasoning behavior.See: https://ai.google.dev/gemini-api/docs/thinking#thinking-levels",
"options": [
"MINIMAL",
"LOW",
"MEDIUM",
"HIGH",
],
}, },
}, },
}, },
@@ -1969,7 +1964,6 @@ CONFIG_METADATA_2 = {
"id": { "id": {
"description": "ID", "description": "ID",
"type": "string", "type": "string",
"hint": "模型提供商名字。",
}, },
"type": { "type": {
"description": "模型提供商种类", "description": "模型提供商种类",
@@ -1989,29 +1983,15 @@ CONFIG_METADATA_2 = {
"description": "API Key", "description": "API Key",
"type": "list", "type": "list",
"items": {"type": "string"}, "items": {"type": "string"},
"hint": "提供商 API Key。",
}, },
"api_base": { "api_base": {
"description": "API Base URL", "description": "API Base URL",
"type": "string", "type": "string",
"hint": "API Base URL 请在模型提供商处获得。如出现 404 报错,尝试在地址末尾加上 /v1",
}, },
"model_config": { "model": {
"description": "模型配置", "description": "模型 ID",
"type": "object", "type": "string",
"items": { "hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
"model": {
"description": "模型名称",
"type": "string",
"hint": "模型名称,如 gpt-4o-mini, deepseek-chat。",
},
"max_tokens": {
"description": "模型最大输出长度(tokens",
"type": "int",
},
"temperature": {"description": "温度", "type": "float"},
"top_p": {"description": "Top P值", "type": "float"},
},
}, },
"dify_api_key": { "dify_api_key": {
"description": "API Key", "description": "API Key",
@@ -2173,6 +2153,9 @@ CONFIG_METADATA_2 = {
"use_file_service": { "use_file_service": {
"type": "bool", "type": "bool",
}, },
"trigger_probability": {
"type": "float",
},
}, },
}, },
"provider_ltm_settings": { "provider_ltm_settings": {
@@ -2383,6 +2366,14 @@ CONFIG_METADATA_3 = {
"provider_tts_settings.enable": True, "provider_tts_settings.enable": True,
}, },
}, },
"provider_tts_settings.trigger_probability": {
"description": "TTS 触发概率",
"type": "float",
"slider": {"min": 0, "max": 1, "step": 0.05},
"condition": {
"provider_tts_settings.enable": True,
},
},
"provider_settings.image_caption_prompt": { "provider_settings.image_caption_prompt": {
"description": "图片转述提示词", "description": "图片转述提示词",
"type": "text", "type": "text",
@@ -2661,6 +2652,11 @@ CONFIG_METADATA_3 = {
"description": "只 @ 机器人是否触发等待", "description": "只 @ 机器人是否触发等待",
"type": "bool", "type": "bool",
}, },
"disable_builtin_commands": {
"description": "禁用自带指令",
"type": "bool",
"hint": "禁用所有 AstrBot 的自带指令,如 help, provider, model 等。",
},
}, },
}, },
"whitelist": { "whitelist": {
@@ -2875,9 +2871,26 @@ CONFIG_METADATA_3 = {
"description": "分段回复字数阈值", "description": "分段回复字数阈值",
"type": "int", "type": "int",
}, },
"platform_settings.segmented_reply.split_mode": {
"description": "分段模式",
"type": "string",
"options": ["regex", "words"],
"labels": ["正则表达式", "分段词列表"],
},
"platform_settings.segmented_reply.regex": { "platform_settings.segmented_reply.regex": {
"description": "分段正则表达式", "description": "分段正则表达式",
"type": "string", "type": "string",
"condition": {
"platform_settings.segmented_reply.split_mode": "regex",
},
},
"platform_settings.segmented_reply.split_words": {
"description": "分段词列表",
"type": "list",
"hint": "检测到列表中的任意词时进行分段,如:。、?、!等",
"condition": {
"platform_settings.segmented_reply.split_mode": "words",
},
}, },
"platform_settings.segmented_reply.content_cleanup_rule": { "platform_settings.segmented_reply.content_cleanup_rule": {
"description": "内容过滤正则表达式", "description": "内容过滤正则表达式",
@@ -2928,6 +2941,7 @@ CONFIG_METADATA_3 = {
"description": "回复概率", "description": "回复概率",
"type": "float", "type": "float",
"hint": "0.0-1.0 之间的数值", "hint": "0.0-1.0 之间的数值",
"slider": {"min": 0, "max": 1, "step": 0.05},
"condition": { "condition": {
"provider_ltm_settings.active_reply.enable": True, "provider_ltm_settings.active_reply.enable": True,
}, },
+1
View File
@@ -79,6 +79,7 @@ class ConfigMetadataI18n:
"_special", "_special",
"invisible", "invisible",
"options", "options",
"slider",
]: ]:
if attr in field_data: if attr in field_data:
field_result[attr] = field_data[attr] field_result[attr] = field_data[attr]
+4 -1
View File
@@ -33,6 +33,7 @@ from astrbot.core.star.context import Context
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
from astrbot.core.umop_config_router import UmopConfigRouter from astrbot.core.umop_config_router import UmopConfigRouter
from astrbot.core.updator import AstrBotUpdator from astrbot.core.updator import AstrBotUpdator
from astrbot.core.utils.llm_metadata import update_llm_metadata
from astrbot.core.utils.migra_helper import migra from astrbot.core.utils.migra_helper import migra
from . import astrbot_config, html_renderer from . import astrbot_config, html_renderer
@@ -185,6 +186,8 @@ class AstrBotCoreLifecycle:
# 初始化关闭控制面板的事件 # 初始化关闭控制面板的事件
self.dashboard_shutdown_event = asyncio.Event() self.dashboard_shutdown_event = asyncio.Event()
asyncio.create_task(update_llm_metadata())
def _load(self) -> None: def _load(self) -> None:
"""加载事件总线和任务并初始化.""" """加载事件总线和任务并初始化."""
# 创建一个异步任务来执行事件总线的 dispatch() 方法 # 创建一个异步任务来执行事件总线的 dispatch() 方法
@@ -197,7 +200,7 @@ class AstrBotCoreLifecycle:
# 把插件中注册的所有协程函数注册到事件总线中并执行 # 把插件中注册的所有协程函数注册到事件总线中并执行
extra_tasks = [] extra_tasks = []
for task in self.star_context._register_tasks: for task in self.star_context._register_tasks:
extra_tasks.append(asyncio.create_task(task, name=task.__name__)) extra_tasks.append(asyncio.create_task(task, name=task.__name__)) # type: ignore
tasks_ = [event_bus_task, *extra_tasks] tasks_ = [event_bus_task, *extra_tasks]
for task in tasks_: for task in tasks_:
+74 -3
View File
@@ -5,11 +5,12 @@ from contextlib import asynccontextmanager
from dataclasses import dataclass from dataclasses import dataclass
from deprecated import deprecated from deprecated import deprecated
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import sessionmaker
from astrbot.core.db.po import ( from astrbot.core.db.po import (
Attachment, Attachment,
CommandConfig,
CommandConflict,
ConversationV2, ConversationV2,
Persona, Persona,
PlatformMessageHistory, PlatformMessageHistory,
@@ -32,7 +33,7 @@ class BaseDatabase(abc.ABC):
echo=False, echo=False,
future=True, future=True,
) )
self.AsyncSessionLocal = sessionmaker( self.AsyncSessionLocal = async_sessionmaker(
self.engine, self.engine,
class_=AsyncSession, class_=AsyncSession,
expire_on_commit=False, expire_on_commit=False,
@@ -315,6 +316,76 @@ class BaseDatabase(abc.ABC):
"""Clear all preferences for a specific scope ID.""" """Clear all preferences for a specific scope ID."""
... ...
@abc.abstractmethod
async def get_command_configs(self) -> list[CommandConfig]:
"""Get all stored command configurations."""
...
@abc.abstractmethod
async def get_command_config(self, handler_full_name: str) -> CommandConfig | None:
"""Fetch a single command configuration by handler."""
...
@abc.abstractmethod
async def upsert_command_config(
self,
handler_full_name: str,
plugin_name: str,
module_path: str,
original_command: str,
*,
resolved_command: str | None = None,
enabled: bool | None = None,
keep_original_alias: bool | None = None,
conflict_key: str | None = None,
resolution_strategy: str | None = None,
note: str | None = None,
extra_data: dict | None = None,
auto_managed: bool | None = None,
) -> CommandConfig:
"""Create or update a command configuration."""
...
@abc.abstractmethod
async def delete_command_config(self, handler_full_name: str) -> None:
"""Delete a single command configuration."""
...
@abc.abstractmethod
async def delete_command_configs(self, handler_full_names: list[str]) -> None:
"""Bulk delete command configurations."""
...
@abc.abstractmethod
async def list_command_conflicts(
self,
status: str | None = None,
) -> list[CommandConflict]:
"""List recorded command conflict entries."""
...
@abc.abstractmethod
async def upsert_command_conflict(
self,
conflict_key: str,
handler_full_name: str,
plugin_name: str,
*,
status: str | None = None,
resolution: str | None = None,
resolved_command: str | None = None,
note: str | None = None,
extra_data: dict | None = None,
auto_generated: bool | None = None,
) -> CommandConflict:
"""Create or update a conflict record."""
...
@abc.abstractmethod
async def delete_command_conflicts(self, ids: list[int]) -> None:
"""Delete conflict records."""
...
# @abc.abstractmethod # @abc.abstractmethod
# async def insert_llm_message( # async def insert_llm_message(
# self, # self,
@@ -70,6 +70,7 @@ async def migration_conversation_table(
logger.info( logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。", f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
) )
continue
if ":" not in conv.user_id: if ":" not in conv.user_id:
continue continue
session = MessageSesion.from_str(session_str=conv.user_id) session = MessageSesion.from_str(session_str=conv.user_id)
@@ -207,6 +208,7 @@ async def migration_webchat_data(
logger.info( logger.info(
f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。", f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。",
) )
continue
if ":" in conv.user_id: if ":" in conv.user_id:
continue continue
platform_id = "webchat" platform_id = "webchat"
+6 -4
View File
@@ -127,7 +127,7 @@ class SQLiteDatabase:
conn.text_factory = str conn.text_factory = str
return conn return conn
def _exec_sql(self, sql: str, params: tuple = None): def _exec_sql(self, sql: str, params: tuple | None = None):
conn = self.conn conn = self.conn
try: try:
c = self.conn.cursor() c = self.conn.cursor()
@@ -224,9 +224,11 @@ class SQLiteDatabase:
c.close() c.close()
return Stats(platform, [], []) return Stats(platform)
def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation: def get_conversation_by_user_id(
self, user_id: str, cid: str
) -> Conversation | None:
try: try:
c = self.conn.cursor() c = self.conn.cursor()
except sqlite3.ProgrammingError: except sqlite3.ProgrammingError:
@@ -258,7 +260,7 @@ class SQLiteDatabase:
(user_id, cid, history, updated_at, created_at), (user_id, cid, history, updated_at, created_at),
) )
def get_conversations(self, user_id: str) -> tuple: def get_conversations(self, user_id: str) -> list[Conversation]:
try: try:
c = self.conn.cursor() c = self.conn.cursor()
except sqlite3.ProgrammingError: except sqlite3.ProgrammingError:
+75 -15
View File
@@ -12,7 +12,7 @@ class PlatformStat(SQLModel, table=True):
Note: In astrbot v4, we moved `platform` table to here. Note: In astrbot v4, we moved `platform` table to here.
""" """
__tablename__ = "platform_stats" # type: ignore __tablename__: str = "platform_stats"
id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True}) id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True})
timestamp: datetime = Field(nullable=False) timestamp: datetime = Field(nullable=False)
@@ -31,9 +31,10 @@ class PlatformStat(SQLModel, table=True):
class ConversationV2(SQLModel, table=True): class ConversationV2(SQLModel, table=True):
__tablename__ = "conversations" # type: ignore __tablename__: str = "conversations"
inner_conversation_id: int = Field( inner_conversation_id: int | None = Field(
default=None,
primary_key=True, primary_key=True,
sa_column_kwargs={"autoincrement": True}, sa_column_kwargs={"autoincrement": True},
) )
@@ -68,7 +69,7 @@ class Persona(SQLModel, table=True):
It can be used to customize the behavior of LLMs. It can be used to customize the behavior of LLMs.
""" """
__tablename__ = "personas" # type: ignore __tablename__: str = "personas"
id: int | None = Field( id: int | None = Field(
primary_key=True, primary_key=True,
@@ -98,7 +99,7 @@ class Persona(SQLModel, table=True):
class Preference(SQLModel, table=True): class Preference(SQLModel, table=True):
"""This class represents preferences for bots.""" """This class represents preferences for bots."""
__tablename__ = "preferences" # type: ignore __tablename__: str = "preferences"
id: int | None = Field( id: int | None = Field(
default=None, default=None,
@@ -134,7 +135,7 @@ class PlatformMessageHistory(SQLModel, table=True):
or platform-specific messages. or platform-specific messages.
""" """
__tablename__ = "platform_message_history" # type: ignore __tablename__: str = "platform_message_history"
id: int | None = Field( id: int | None = Field(
primary_key=True, primary_key=True,
@@ -162,7 +163,7 @@ class PlatformSession(SQLModel, table=True):
Each session can have multiple conversations (对话) associated with it. Each session can have multiple conversations (对话) associated with it.
""" """
__tablename__ = "platform_sessions" # type: ignore __tablename__: str = "platform_sessions"
inner_id: int | None = Field( inner_id: int | None = Field(
primary_key=True, primary_key=True,
@@ -203,7 +204,7 @@ class Attachment(SQLModel, table=True):
Attachments can be images, files, or other media types. Attachments can be images, files, or other media types.
""" """
__tablename__ = "attachments" # type: ignore __tablename__: str = "attachments"
inner_attachment_id: int | None = Field( inner_attachment_id: int | None = Field(
primary_key=True, primary_key=True,
@@ -233,6 +234,65 @@ class Attachment(SQLModel, table=True):
) )
class CommandConfig(SQLModel, table=True):
"""Per-command configuration overrides for dashboard management."""
__tablename__ = "command_configs" # type: ignore
handler_full_name: str = Field(
primary_key=True,
max_length=512,
)
plugin_name: str = Field(nullable=False, max_length=255)
module_path: str = Field(nullable=False, max_length=255)
original_command: str = Field(nullable=False, max_length=255)
resolved_command: str | None = Field(default=None, max_length=255)
enabled: bool = Field(default=True, nullable=False)
keep_original_alias: bool = Field(default=False, nullable=False)
conflict_key: str | None = Field(default=None, max_length=255)
resolution_strategy: str | None = Field(default=None, max_length=64)
note: str | None = Field(default=None, sa_type=Text)
extra_data: dict | None = Field(default=None, sa_type=JSON)
auto_managed: bool = Field(default=False, nullable=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
class CommandConflict(SQLModel, table=True):
"""Conflict tracking for duplicated command names."""
__tablename__ = "command_conflicts" # type: ignore
id: int | None = Field(
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
)
conflict_key: str = Field(nullable=False, max_length=255)
handler_full_name: str = Field(nullable=False, max_length=512)
plugin_name: str = Field(nullable=False, max_length=255)
status: str = Field(default="pending", max_length=32)
resolution: str | None = Field(default=None, max_length=64)
resolved_command: str | None = Field(default=None, max_length=255)
note: str | None = Field(default=None, sa_type=Text)
extra_data: dict | None = Field(default=None, sa_type=JSON)
auto_generated: bool = Field(default=False, nullable=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
__table_args__ = (
UniqueConstraint(
"conflict_key",
"handler_full_name",
name="uix_conflict_handler",
),
)
@dataclass @dataclass
class Conversation: class Conversation:
"""LLM 对话类 """LLM 对话类
@@ -261,17 +321,17 @@ class Personality(TypedDict):
v4.0.0 版本及之后推荐使用上面的 Persona 并且 mood_imitation_dialogs 字段已被废弃 v4.0.0 版本及之后推荐使用上面的 Persona 并且 mood_imitation_dialogs 字段已被废弃
""" """
prompt: str = "" prompt: str
name: str = "" name: str
begin_dialogs: list[str] = [] begin_dialogs: list[str]
mood_imitation_dialogs: list[str] = [] mood_imitation_dialogs: list[str]
"""情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。""" """情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。"""
tools: list[str] | None = None tools: list[str] | None
"""工具列表。None 表示使用所有工具,空列表表示不使用任何工具""" """工具列表。None 表示使用所有工具,空列表表示不使用任何工具"""
# cache # cache
_begin_dialogs_processed: list[dict] = [] _begin_dialogs_processed: list[dict]
_mood_imitation_dialogs_processed: str = "" _mood_imitation_dialogs_processed: str
# ==== # ====
+244 -3
View File
@@ -1,14 +1,18 @@
import asyncio import asyncio
import threading import threading
import typing as T import typing as T
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from sqlalchemy import CursorResult
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col, delete, desc, func, or_, select, text, update from sqlmodel import col, delete, desc, func, or_, select, text, update
from astrbot.core.db import BaseDatabase from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import ( from astrbot.core.db.po import (
Attachment, Attachment,
CommandConfig,
CommandConflict,
ConversationV2, ConversationV2,
Persona, Persona,
PlatformMessageHistory, PlatformMessageHistory,
@@ -25,6 +29,7 @@ from astrbot.core.db.po import (
) )
NOT_GIVEN = T.TypeVar("NOT_GIVEN") NOT_GIVEN = T.TypeVar("NOT_GIVEN")
TxResult = T.TypeVar("TxResult")
class SQLiteDatabase(BaseDatabase): class SQLiteDatabase(BaseDatabase):
@@ -489,7 +494,7 @@ class SQLiteDatabase(BaseDatabase):
async with self.get_db() as session: async with self.get_db() as session:
session: AsyncSession session: AsyncSession
query = select(Attachment).where( query = select(Attachment).where(
Attachment.attachment_id.in_(attachment_ids) col(Attachment.attachment_id).in_(attachment_ids)
) )
result = await session.execute(query) result = await session.execute(query)
return list(result.scalars().all()) return list(result.scalars().all())
@@ -505,7 +510,7 @@ class SQLiteDatabase(BaseDatabase):
query = delete(Attachment).where( query = delete(Attachment).where(
col(Attachment.attachment_id) == attachment_id col(Attachment.attachment_id) == attachment_id
) )
result = await session.execute(query) result = T.cast(CursorResult, await session.execute(query))
return result.rowcount > 0 return result.rowcount > 0
async def delete_attachments(self, attachment_ids: list[str]) -> int: async def delete_attachments(self, attachment_ids: list[str]) -> int:
@@ -521,7 +526,7 @@ class SQLiteDatabase(BaseDatabase):
query = delete(Attachment).where( query = delete(Attachment).where(
col(Attachment.attachment_id).in_(attachment_ids) col(Attachment.attachment_id).in_(attachment_ids)
) )
result = await session.execute(query) result = T.cast(CursorResult, await session.execute(query))
return result.rowcount return result.rowcount
async def insert_persona( async def insert_persona(
@@ -669,6 +674,242 @@ class SQLiteDatabase(BaseDatabase):
) )
await session.commit() await session.commit()
# ====
# Command Configuration & Conflict Tracking
# ====
async def _run_in_tx(
self,
fn: Callable[[AsyncSession], Awaitable[TxResult]],
) -> TxResult:
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
return await fn(session)
@staticmethod
def _apply_updates(model, **updates) -> None:
for field, value in updates.items():
if value is not None:
setattr(model, field, value)
@staticmethod
def _new_command_config(
handler_full_name: str,
plugin_name: str,
module_path: str,
original_command: str,
*,
resolved_command: str | None = None,
enabled: bool | None = None,
keep_original_alias: bool | None = None,
conflict_key: str | None = None,
resolution_strategy: str | None = None,
note: str | None = None,
extra_data: dict | None = None,
auto_managed: bool | None = None,
) -> CommandConfig:
return CommandConfig(
handler_full_name=handler_full_name,
plugin_name=plugin_name,
module_path=module_path,
original_command=original_command,
resolved_command=resolved_command,
enabled=True if enabled is None else enabled,
keep_original_alias=False
if keep_original_alias is None
else keep_original_alias,
conflict_key=conflict_key or original_command,
resolution_strategy=resolution_strategy,
note=note,
extra_data=extra_data,
auto_managed=bool(auto_managed),
)
@staticmethod
def _new_command_conflict(
conflict_key: str,
handler_full_name: str,
plugin_name: str,
*,
status: str | None = None,
resolution: str | None = None,
resolved_command: str | None = None,
note: str | None = None,
extra_data: dict | None = None,
auto_generated: bool | None = None,
) -> CommandConflict:
return CommandConflict(
conflict_key=conflict_key,
handler_full_name=handler_full_name,
plugin_name=plugin_name,
status=status or "pending",
resolution=resolution,
resolved_command=resolved_command,
note=note,
extra_data=extra_data,
auto_generated=bool(auto_generated),
)
async def get_command_configs(self) -> list[CommandConfig]:
async with self.get_db() as session:
session: AsyncSession
result = await session.execute(select(CommandConfig))
return list(result.scalars().all())
async def get_command_config(
self,
handler_full_name: str,
) -> CommandConfig | None:
async with self.get_db() as session:
session: AsyncSession
return await session.get(CommandConfig, handler_full_name)
async def upsert_command_config(
self,
handler_full_name: str,
plugin_name: str,
module_path: str,
original_command: str,
*,
resolved_command: str | None = None,
enabled: bool | None = None,
keep_original_alias: bool | None = None,
conflict_key: str | None = None,
resolution_strategy: str | None = None,
note: str | None = None,
extra_data: dict | None = None,
auto_managed: bool | None = None,
) -> CommandConfig:
async def _op(session: AsyncSession) -> CommandConfig:
config = await session.get(CommandConfig, handler_full_name)
if not config:
config = self._new_command_config(
handler_full_name,
plugin_name,
module_path,
original_command,
resolved_command=resolved_command,
enabled=enabled,
keep_original_alias=keep_original_alias,
conflict_key=conflict_key,
resolution_strategy=resolution_strategy,
note=note,
extra_data=extra_data,
auto_managed=auto_managed,
)
session.add(config)
else:
self._apply_updates(
config,
plugin_name=plugin_name,
module_path=module_path,
original_command=original_command,
resolved_command=resolved_command,
enabled=enabled,
keep_original_alias=keep_original_alias,
conflict_key=conflict_key,
resolution_strategy=resolution_strategy,
note=note,
extra_data=extra_data,
auto_managed=auto_managed,
)
await session.flush()
await session.refresh(config)
return config
return await self._run_in_tx(_op)
async def delete_command_config(self, handler_full_name: str) -> None:
await self.delete_command_configs([handler_full_name])
async def delete_command_configs(self, handler_full_names: list[str]) -> None:
if not handler_full_names:
return
async def _op(session: AsyncSession) -> None:
await session.execute(
delete(CommandConfig).where(
col(CommandConfig.handler_full_name).in_(handler_full_names),
),
)
await self._run_in_tx(_op)
async def list_command_conflicts(
self,
status: str | None = None,
) -> list[CommandConflict]:
async with self.get_db() as session:
session: AsyncSession
query = select(CommandConflict)
if status:
query = query.where(CommandConflict.status == status)
result = await session.execute(query)
return list(result.scalars().all())
async def upsert_command_conflict(
self,
conflict_key: str,
handler_full_name: str,
plugin_name: str,
*,
status: str | None = None,
resolution: str | None = None,
resolved_command: str | None = None,
note: str | None = None,
extra_data: dict | None = None,
auto_generated: bool | None = None,
) -> CommandConflict:
async def _op(session: AsyncSession) -> CommandConflict:
result = await session.execute(
select(CommandConflict).where(
CommandConflict.conflict_key == conflict_key,
CommandConflict.handler_full_name == handler_full_name,
),
)
record = result.scalar_one_or_none()
if not record:
record = self._new_command_conflict(
conflict_key,
handler_full_name,
plugin_name,
status=status,
resolution=resolution,
resolved_command=resolved_command,
note=note,
extra_data=extra_data,
auto_generated=auto_generated,
)
session.add(record)
else:
self._apply_updates(
record,
plugin_name=plugin_name,
status=status,
resolution=resolution,
resolved_command=resolved_command,
note=note,
extra_data=extra_data,
auto_generated=auto_generated,
)
await session.flush()
await session.refresh(record)
return record
return await self._run_in_tx(_op)
async def delete_command_conflicts(self, ids: list[int]) -> None:
if not ids:
return
async def _op(session: AsyncSession) -> None:
await session.execute(
delete(CommandConflict).where(col(CommandConflict.id).in_(ids)),
)
await self._run_in_tx(_op)
# ==== # ====
# Deprecated Methods # Deprecated Methods
# ==== # ====
@@ -90,4 +90,6 @@ class EmbeddingStorage:
path (str): 保存索引的路径 path (str): 保存索引的路径
""" """
if self.index is None:
return
faiss.write_index(self.index, self.path) faiss.write_index(self.index, self.path)
+6 -1
View File
@@ -27,7 +27,7 @@ class EventBus:
self, self,
event_queue: Queue, event_queue: Queue,
pipeline_scheduler_mapping: dict[str, PipelineScheduler], pipeline_scheduler_mapping: dict[str, PipelineScheduler],
astrbot_config_mgr: AstrBotConfigManager = None, astrbot_config_mgr: AstrBotConfigManager,
): ):
self.event_queue = event_queue # 事件队列 self.event_queue = event_queue # 事件队列
# abconf uuid -> scheduler # abconf uuid -> scheduler
@@ -40,6 +40,11 @@ class EventBus:
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin) conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
self._print_event(event, conf_info["name"]) self._print_event(event, conf_info["name"])
scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"]) scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"])
if not scheduler:
logger.error(
f"PipelineScheduler not found for id: {conf_info['id']}, event ignored."
)
continue
asyncio.create_task(scheduler.execute(event)) asyncio.create_task(scheduler.execute(event))
def _print_event(self, event: AstrMessageEvent, conf_name: str): def _print_event(self, event: AstrMessageEvent, conf_name: str):
@@ -166,7 +166,11 @@ class RetrievalManager:
# 5. Rerank # 5. Rerank
first_rerank = None first_rerank = None
for kb_id in kb_ids: for kb_id in kb_ids:
vec_db: FaissVecDB = kb_options[kb_id]["vec_db"] vec_db = kb_options[kb_id]["vec_db"]
if not isinstance(vec_db, FaissVecDB):
logger.warning(f"vec_db for kb_id {kb_id} is not FaissVecDB")
continue
rerank_pi = kb_options[kb_id]["rerank_provider_id"] rerank_pi = kb_options[kb_id]["rerank_provider_id"]
if ( if (
vec_db vec_db
+2 -1
View File
@@ -24,6 +24,7 @@ import asyncio
import logging import logging
import os import os
import sys import sys
import time
from asyncio import Queue from asyncio import Queue
from collections import deque from collections import deque
@@ -148,7 +149,7 @@ class LogQueueHandler(logging.Handler):
self.log_broker.publish( self.log_broker.publish(
{ {
"level": record.levelname, "level": record.levelname,
"time": record.asctime, "time": time.time(),
"data": log_entry, "data": log_entry,
}, },
) )
+13 -8
View File
@@ -66,6 +66,9 @@ class ComponentType(str, Enum):
class BaseMessageComponent(BaseModel): class BaseMessageComponent(BaseModel):
type: ComponentType type: ComponentType
def __init__(self, **kwargs):
super().__init__(**kwargs)
def toDict(self): def toDict(self):
data = {} data = {}
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
@@ -551,7 +554,7 @@ class Node(BaseMessageComponent):
id: int | None = 0 # 忽略 id: int | None = 0 # 忽略
name: str | None = "" # qq昵称 name: str | None = "" # qq昵称
uin: str | None = "0" # qq号 uin: str | None = "0" # qq号
content: list[BaseMessageComponent] | None = [] content: list[BaseMessageComponent] = []
seq: str | list | None = "" # 忽略 seq: str | list | None = "" # 忽略
time: int | None = 0 # 忽略 time: int | None = 0 # 忽略
@@ -615,7 +618,7 @@ class Nodes(BaseMessageComponent):
ret["messages"].append(d) ret["messages"].append(d)
return ret return ret
async def to_dict(self): async def to_dict(self) -> dict:
"""将 Nodes 转换为字典格式,适用于 OneBot JSON 格式""" """将 Nodes 转换为字典格式,适用于 OneBot JSON 格式"""
ret = {"messages": []} ret = {"messages": []}
for node in self.nodes: for node in self.nodes:
@@ -626,12 +629,11 @@ class Nodes(BaseMessageComponent):
class Json(BaseMessageComponent): class Json(BaseMessageComponent):
type = ComponentType.Json type = ComponentType.Json
data: str | dict data: dict
resid: int | None = 0
def __init__(self, data, **_): def __init__(self, data: str | dict, **_):
if isinstance(data, dict): if isinstance(data, str):
data = json.dumps(data) data = json.loads(data)
super().__init__(data=data, **_) super().__init__(data=data, **_)
@@ -714,12 +716,15 @@ class File(BaseMessageComponent):
if self.url: if self.url:
await self._download_file() await self._download_file()
return os.path.abspath(self.file_) if self.file_:
return os.path.abspath(self.file_)
return "" return ""
async def _download_file(self): async def _download_file(self):
"""下载文件""" """下载文件"""
if not self.url:
raise ValueError("Download failed: No URL provided in File component.")
download_dir = os.path.join(get_astrbot_data_path(), "temp") download_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(download_dir, exist_ok=True) os.makedirs(download_dir, exist_ok=True)
if self.name: if self.name:
+2 -2
View File
@@ -98,8 +98,8 @@ class PersonaManager:
self, self,
persona_id: str, persona_id: str,
system_prompt: str, system_prompt: str,
begin_dialogs: list[str] = None, begin_dialogs: list[str] | None = None,
tools: list[str] = None, tools: list[str] | None = None,
) -> Persona: ) -> Persona:
"""创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具""" """创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具"""
if await self.db.get_persona_by_id(persona_id): if await self.db.get_persona_by_id(persona_id):
@@ -24,7 +24,7 @@ class ContentSafetyCheckStage(Stage):
self, self,
event: AstrMessageEvent, event: AstrMessageEvent,
check_text: str | None = None, check_text: str | None = None,
) -> None | AsyncGenerator[None, None]: ) -> AsyncGenerator[None, None]:
"""检查内容安全""" """检查内容安全"""
text = check_text if check_text else event.get_message_str() text = check_text if check_text else event.get_message_str()
ok, info = self.strategy_selector.check(text) ok, info = self.strategy_selector.check(text)
+2 -1
View File
@@ -11,7 +11,7 @@ from astrbot.core.star.star_handler import EventType, star_handlers_registry
async def call_handler( async def call_handler(
event: AstrMessageEvent, event: AstrMessageEvent,
handler: T.Callable[..., T.Awaitable[T.Any]], handler: T.Callable[..., T.Awaitable[T.Any] | T.AsyncGenerator[T.Any, None]],
*args, *args,
**kwargs, **kwargs,
) -> T.AsyncGenerator[T.Any, None]: ) -> T.AsyncGenerator[T.Any, None]:
@@ -91,6 +91,7 @@ async def call_event_hook(
) )
for handler in handlers: for handler in handlers:
try: try:
assert inspect.iscoroutinefunction(handler.handler)
logger.debug( logger.debug(
f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}", f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
) )
@@ -321,7 +321,12 @@ class InternalAgentSubStage(Stage):
elif isinstance(req.tool_calls_result, list): elif isinstance(req.tool_calls_result, list):
for tcr in req.tool_calls_result: for tcr in req.tool_calls_result:
messages.extend(tcr.to_openai_messages()) messages.extend(tcr.to_openai_messages())
messages.append({"role": "assistant", "content": llm_response.completion_text}) messages.append(
{
"role": "assistant",
"content": llm_response.completion_text or "*No response*",
}
)
messages = list(filter(lambda item: "_no_save" not in item, messages)) messages = list(filter(lambda item: "_no_save" not in item, messages))
await self.conv_manager.update_conversation( await self.conv_manager.update_conversation(
event.unified_msg_origin, event.unified_msg_origin,
@@ -16,7 +16,6 @@ from ..stage import Stage
class StarRequestSubStage(Stage): class StarRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None: async def initialize(self, ctx: PipelineContext) -> None:
self.curr_provider = ctx.plugin_manager.context.get_using_provider()
self.prompt_prefix = ctx.astrbot_config["provider_settings"]["prompt_prefix"] self.prompt_prefix = ctx.astrbot_config["provider_settings"]["prompt_prefix"]
self.identifier = ctx.astrbot_config["provider_settings"]["identifier"] self.identifier = ctx.astrbot_config["provider_settings"]["identifier"]
self.ctx = ctx self.ctx = ctx
@@ -24,7 +23,7 @@ class StarRequestSubStage(Stage):
async def process( async def process(
self, self,
event: AstrMessageEvent, event: AstrMessageEvent,
) -> AsyncGenerator[None, None]: ) -> AsyncGenerator[Any, None]:
activated_handlers: list[StarHandlerMetadata] = event.get_extra( activated_handlers: list[StarHandlerMetadata] = event.get_extra(
"activated_handlers", "activated_handlers",
) )
+1 -1
View File
@@ -60,7 +60,7 @@ class ProcessStage(Stage):
): ):
# 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀 # 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀
if ( if (
event.get_result() and not event.get_result().is_stopped() event.get_result() and not event.is_stopped()
) or not event.get_result(): ) or not event.get_result():
async for _ in self.agent_sub_stage.process(event): async for _ in self.agent_sub_stage.process(event):
yield yield
+8 -2
View File
@@ -117,7 +117,9 @@ class RespondStage(Stage):
if not self.enable_seg: if not self.enable_seg:
return False return False
if self.only_llm_result and not event.get_result().is_llm_result(): if (result := event.get_result()) is None:
return False
if self.only_llm_result and not result.is_llm_result():
return False return False
if event.get_platform_name() in [ if event.get_platform_name() in [
@@ -156,7 +158,11 @@ class RespondStage(Stage):
result = event.get_result() result = event.get_result()
if result is None: if result is None:
return return
if event.get_extra("_streaming_finished", False):
# prevent some plugin make result content type to LLM_RESULT after streaming finished, lead to send again
return
if result.result_content_type == ResultContentType.STREAMING_FINISH: if result.result_content_type == ResultContentType.STREAMING_FINISH:
event.set_extra("_streaming_finished", True)
return return
logger.info( logger.info(
@@ -185,7 +191,7 @@ class RespondStage(Stage):
if isinstance(component, Comp.File) and component.file: if isinstance(component, Comp.File) and component.file:
# 支持 File 消息段的路径映射。 # 支持 File 消息段的路径映射。
component.file = path_Mapping(mappings, component.file) component.file = path_Mapping(mappings, component.file)
event.get_result().chain[idx] = component result.chain[idx] = component
# 检查消息链是否为空 # 检查消息链是否为空
try: try:
+89 -22
View File
@@ -1,3 +1,4 @@
import random
import re import re
import time import time
import traceback import traceback
@@ -6,6 +7,7 @@ from collections.abc import AsyncGenerator
from astrbot.core import file_token_service, html_renderer, logger from astrbot.core import file_token_service, html_renderer, logger
from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply
from astrbot.core.message.message_event_result import ResultContentType from astrbot.core.message.message_event_result import ResultContentType
from astrbot.core.pipeline.content_safety_check.stage import ContentSafetyCheckStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.message_type import MessageType from astrbot.core.platform.message_type import MessageType
from astrbot.core.star.session_llm_manager import SessionServiceManager from astrbot.core.star.session_llm_manager import SessionServiceManager
@@ -41,6 +43,18 @@ class ResultDecorateStage(Stage):
"forward_threshold" "forward_threshold"
] ]
trigger_probability = ctx.astrbot_config["provider_tts_settings"].get(
"trigger_probability",
1,
)
try:
self.tts_trigger_probability = max(
0.0,
min(float(trigger_probability), 1.0),
)
except (TypeError, ValueError):
self.tts_trigger_probability = 1.0
# 分段回复 # 分段回复
self.words_count_threshold = int( self.words_count_threshold = int(
ctx.astrbot_config["platform_settings"]["segmented_reply"][ ctx.astrbot_config["platform_settings"]["segmented_reply"][
@@ -53,7 +67,22 @@ class ResultDecorateStage(Stage):
self.only_llm_result = ctx.astrbot_config["platform_settings"][ self.only_llm_result = ctx.astrbot_config["platform_settings"][
"segmented_reply" "segmented_reply"
]["only_llm_result"] ]["only_llm_result"]
self.split_mode = ctx.astrbot_config["platform_settings"][
"segmented_reply"
].get("split_mode", "regex")
self.regex = ctx.astrbot_config["platform_settings"]["segmented_reply"]["regex"] self.regex = ctx.astrbot_config["platform_settings"]["segmented_reply"]["regex"]
self.split_words = ctx.astrbot_config["platform_settings"][
"segmented_reply"
].get("split_words", ["", "", "", "~", ""])
if self.split_words:
escaped_words = sorted(
[re.escape(word) for word in self.split_words], key=len, reverse=True
)
self.split_words_pattern = re.compile(
f"(.*?({'|'.join(escaped_words)})|.+$)", re.DOTALL
)
else:
self.split_words_pattern = None
self.content_cleanup_rule = ctx.astrbot_config["platform_settings"][ self.content_cleanup_rule = ctx.astrbot_config["platform_settings"][
"segmented_reply" "segmented_reply"
]["content_cleanup_rule"] ]["content_cleanup_rule"]
@@ -69,6 +98,28 @@ class ResultDecorateStage(Stage):
self.content_safe_check_stage = stage_cls() self.content_safe_check_stage = stage_cls()
await self.content_safe_check_stage.initialize(ctx) await self.content_safe_check_stage.initialize(ctx)
def _split_text_by_words(self, text: str) -> list[str]:
"""使用分段词列表分段文本"""
if not self.split_words_pattern:
return [text]
segments = self.split_words_pattern.findall(text)
result = []
for seg in segments:
if isinstance(seg, tuple):
content = seg[0]
if not isinstance(content, str):
continue
for word in self.split_words:
if content.endswith(word):
content = content[: -len(word)]
break
if content.strip():
result.append(content)
elif seg and seg.strip():
result.append(seg)
return result if result else [text]
async def process( async def process(
self, self,
event: AstrMessageEvent, event: AstrMessageEvent,
@@ -93,11 +144,13 @@ class ResultDecorateStage(Stage):
for comp in result.chain: for comp in result.chain:
if isinstance(comp, Plain): if isinstance(comp, Plain):
text += comp.text text += comp.text
async for _ in self.content_safe_check_stage.process(
event, if isinstance(self.content_safe_check_stage, ContentSafetyCheckStage):
check_text=text, async for _ in self.content_safe_check_stage.process(
): event,
yield check_text=text,
):
yield
# 发送消息前事件钩子 # 发送消息前事件钩子
handlers = star_handlers_registry.get_handlers_by_event_type( handlers = star_handlers_registry.get_handlers_by_event_type(
@@ -114,7 +167,8 @@ class ResultDecorateStage(Stage):
"启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作", "启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作",
) )
await handler.handler(event) await handler.handler(event)
if event.get_result() is None or not event.get_result().chain:
if (result := event.get_result()) is None or not result.chain:
logger.debug( logger.debug(
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。", f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。",
) )
@@ -161,21 +215,27 @@ class ResultDecorateStage(Stage):
# 不分段回复 # 不分段回复
new_chain.append(comp) new_chain.append(comp)
continue continue
try:
split_response = re.findall( # 根据 split_mode 选择分段方式
self.regex, if self.split_mode == "words":
comp.text, split_response = self._split_text_by_words(comp.text)
re.DOTALL | re.MULTILINE, else: # regex 模式
) try:
except re.error: split_response = re.findall(
logger.error( self.regex,
f"分段回复正则表达式错误,使用默认分段方式: {traceback.format_exc()}", comp.text,
) re.DOTALL | re.MULTILINE,
split_response = re.findall( )
r".*?[。?!~…]+|.+$", except re.error:
comp.text, logger.error(
re.DOTALL | re.MULTILINE, f"分段回复正则表达式错误,使用默认分段方式: {traceback.format_exc()}",
) )
split_response = re.findall(
r".*?[。?!~…]+|.+$",
comp.text,
re.DOTALL | re.MULTILINE,
)
if not split_response: if not split_response:
new_chain.append(comp) new_chain.append(comp)
continue continue
@@ -199,7 +259,14 @@ class ResultDecorateStage(Stage):
and result.is_llm_result() and result.is_llm_result()
and SessionServiceManager.should_process_tts_request(event) and SessionServiceManager.should_process_tts_request(event)
): ):
if not tts_provider: should_tts = self.tts_trigger_probability >= 1.0 or (
self.tts_trigger_probability > 0.0
and random.random() <= self.tts_trigger_probability
)
if not should_tts:
logger.debug("跳过 TTS:触发概率未命中。")
elif not tts_provider:
logger.warning( logger.warning(
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。", f"会话 {event.unified_msg_origin} 未配置文本转语音模型。",
) )
+5 -1
View File
@@ -2,6 +2,10 @@ from collections.abc import AsyncGenerator
from astrbot.core import logger from astrbot.core import logger
from astrbot.core.platform import AstrMessageEvent from astrbot.core.platform import AstrMessageEvent
from astrbot.core.platform.sources.webchat.webchat_event import WebChatMessageEvent
from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import (
WecomAIBotMessageEvent,
)
from . import STAGES_ORDER from . import STAGES_ORDER
from .context import PipelineContext from .context import PipelineContext
@@ -78,7 +82,7 @@ class PipelineScheduler:
await self._process_stages(event) await self._process_stages(event)
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理 # 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
if event.get_platform_name() in ["webchat", "wecom_ai_bot"]: if isinstance(event, (WebChatMessageEvent, WecomAIBotMessageEvent)):
await event.send(None) await event.send(None)
logger.debug("pipeline 执行完毕。") logger.debug("pipeline 执行完毕。")
@@ -50,6 +50,9 @@ class WakingCheckStage(Stage):
"ignore_at_all", "ignore_at_all",
False, False,
) )
self.disable_builtin_commands = self.ctx.astrbot_config.get(
"disable_builtin_commands", False
)
async def process( async def process(
self, self,
@@ -131,6 +134,13 @@ class WakingCheckStage(Stage):
EventType.AdapterMessageEvent, EventType.AdapterMessageEvent,
plugins_name=event.plugins_name, plugins_name=event.plugins_name,
): ):
if (
self.disable_builtin_commands
and handler.handler_module_path == "packages.builtin_commands.main"
):
logger.debug("skipping builtin command")
continue
# filter 需满足 AND 逻辑关系 # filter 需满足 AND 逻辑关系
passed = True passed = True
permission_not_pass = False permission_not_pass = False
+5 -3
View File
@@ -153,7 +153,9 @@ class AstrMessageEvent(abc.ABC):
def get_sender_name(self) -> str: def get_sender_name(self) -> str:
"""获取消息发送者的名称。(可能会返回空字符串)""" """获取消息发送者的名称。(可能会返回空字符串)"""
return self.message_obj.sender.nickname if isinstance(self.message_obj.sender.nickname, str):
return self.message_obj.sender.nickname
return ""
def set_extra(self, key, value): def set_extra(self, key, value):
"""设置额外的信息。""" """设置额外的信息。"""
@@ -270,7 +272,7 @@ class AstrMessageEvent(abc.ABC):
""" """
self.call_llm = call_llm self.call_llm = call_llm
def get_result(self) -> MessageEventResult: def get_result(self) -> MessageEventResult | None:
"""获取消息事件的结果。""" """获取消息事件的结果。"""
return self._result return self._result
@@ -320,7 +322,7 @@ class AstrMessageEvent(abc.ABC):
self, self,
prompt: str, prompt: str,
func_tool_manager=None, func_tool_manager=None,
session_id: str = None, session_id: str = "",
image_urls: list[str] | None = None, image_urls: list[str] | None = None,
contexts: list | None = None, contexts: list | None = None,
system_prompt: str = "", system_prompt: str = "",
+2 -2
View File
@@ -54,7 +54,7 @@ class AstrBotMessage:
self_id: str # 机器人的识别id self_id: str # 机器人的识别id
session_id: str # 会话id。取决于 unique_session 的设置。 session_id: str # 会话id。取决于 unique_session 的设置。
message_id: str # 消息id message_id: str # 消息id
group: Group # 群组 group: Group | None # 群组
sender: MessageMember # 发送者 sender: MessageMember # 发送者
message: list[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式 message: list[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
message_str: str # 最直观的纯文本消息字符串 message_str: str # 最直观的纯文本消息字符串
@@ -78,7 +78,7 @@ class AstrBotMessage:
return "" return ""
@group_id.setter @group_id.setter
def group_id(self, value: str): def group_id(self, value: str | None):
"""设置 group_id""" """设置 group_id"""
if value: if value:
if self.group: if self.group:
+4
View File
@@ -5,6 +5,7 @@ from asyncio import Queue
from astrbot.core import logger from astrbot.core import logger
from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
from astrbot.core.utils.webhook_utils import ensure_platform_webhook_config
from .platform import Platform, PlatformStatus from .platform import Platform, PlatformStatus
from .register import platform_cls_map from .register import platform_cls_map
@@ -18,6 +19,7 @@ class PlatformManager:
self._inst_map: dict[str, dict] = {} self._inst_map: dict[str, dict] = {}
self.astrbot_config = config
self.platforms_config = config["platform"] self.platforms_config = config["platform"]
self.settings = config["platform_settings"] self.settings = config["platform_settings"]
"""NOTE: 这里是 default 的配置文件,以保证最大的兼容性; """NOTE: 这里是 default 的配置文件,以保证最大的兼容性;
@@ -29,6 +31,8 @@ class PlatformManager:
"""初始化所有平台适配器""" """初始化所有平台适配器"""
for platform in self.platforms_config: for platform in self.platforms_config:
try: try:
if ensure_platform_webhook_config(platform):
self.astrbot_config.save_config()
await self.load_platform(platform) await self.load_platform(platform)
except Exception as e: except Exception as e:
logger.error(f"初始化 {platform} 平台适配器失败: {e}") logger.error(f"初始化 {platform} 平台适配器失败: {e}")
+11 -3
View File
@@ -1,7 +1,7 @@
import abc import abc
import uuid import uuid
from asyncio import Queue from asyncio import Queue
from collections.abc import Awaitable from collections.abc import Coroutine
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
@@ -80,6 +80,13 @@ class Platform(abc.ABC):
if self._status == PlatformStatus.ERROR: if self._status == PlatformStatus.ERROR:
self._status = PlatformStatus.RUNNING self._status = PlatformStatus.RUNNING
def unified_webhook(self) -> bool:
"""是否正在使用统一 Webhook 模式"""
return bool(
self.config.get("unified_webhook_mode", False)
and self.config.get("webhook_uuid")
)
def get_stats(self) -> dict: def get_stats(self) -> dict:
"""获取平台统计信息""" """获取平台统计信息"""
meta = self.meta() meta = self.meta()
@@ -97,10 +104,11 @@ class Platform(abc.ABC):
} }
if self.last_error if self.last_error
else None, else None,
"unified_webhook": self.unified_webhook(),
} }
@abc.abstractmethod @abc.abstractmethod
def run(self) -> Awaitable[Any]: def run(self) -> Coroutine[Any, Any, None]:
"""得到一个平台的运行实例,需要返回一个协程对象。""" """得到一个平台的运行实例,需要返回一个协程对象。"""
raise NotImplementedError raise NotImplementedError
@@ -116,7 +124,7 @@ class Platform(abc.ABC):
self, self,
session: MessageSesion, session: MessageSesion,
message_chain: MessageChain, message_chain: MessageChain,
): ) -> None:
"""通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。 """通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。
异步方法 异步方法
+1 -1
View File
@@ -7,7 +7,7 @@ class PlatformMetadata:
"""平台的名称,即平台的类型,如 aiocqhttp, discord, slack""" """平台的名称,即平台的类型,如 aiocqhttp, discord, slack"""
description: str description: str
"""平台的描述""" """平台的描述"""
id: str | None = None id: str
"""平台的唯一标识符,用于配置中识别特定平台""" """平台的唯一标识符,用于配置中识别特定平台"""
default_config_tmpl: dict | None = None default_config_tmpl: dict | None = None
+1
View File
@@ -40,6 +40,7 @@ def register_platform_adapter(
pm = PlatformMetadata( pm = PlatformMetadata(
name=adapter_name, name=adapter_name,
description=desc, description=desc,
id=adapter_name,
default_config_tmpl=default_config_tmpl, default_config_tmpl=default_config_tmpl,
adapter_display_name=adapter_display_name, adapter_display_name=adapter_display_name,
logo_path=logo_path, logo_path=logo_path,
@@ -70,16 +70,18 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
bot: CQHttp, bot: CQHttp,
event: Event | None, event: Event | None,
is_group: bool, is_group: bool,
session_id: str, session_id: str | None,
messages: list[dict], messages: list[dict],
): ):
# session_id 必须是纯数字字符串 # session_id 必须是纯数字字符串
session_id = int(session_id) if session_id.isdigit() else None session_id_int = (
int(session_id) if session_id and session_id.isdigit() else None
)
if is_group and isinstance(session_id, int): if is_group and isinstance(session_id_int, int):
await bot.send_group_msg(group_id=session_id, message=messages) await bot.send_group_msg(group_id=session_id_int, message=messages)
elif not is_group and isinstance(session_id, int): elif not is_group and isinstance(session_id_int, int):
await bot.send_private_msg(user_id=session_id, message=messages) await bot.send_private_msg(user_id=session_id_int, message=messages)
elif isinstance(event, Event): # 最后兜底 elif isinstance(event, Event): # 最后兜底
await bot.send(event=event, message=messages) await bot.send(event=event, message=messages)
else: else:
@@ -4,7 +4,7 @@ import logging
import time import time
import uuid import uuid
from collections.abc import Awaitable from collections.abc import Awaitable
from typing import Any from typing import Any, cast
from aiocqhttp import CQHttp, Event from aiocqhttp import CQHttp, Event
from aiocqhttp.exceptions import ActionFailed from aiocqhttp.exceptions import ActionFailed
@@ -48,7 +48,7 @@ class AiocqhttpAdapter(Platform):
self.metadata = PlatformMetadata( self.metadata = PlatformMetadata(
name="aiocqhttp", name="aiocqhttp",
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。", description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
id=self.config.get("id"), id=cast(str, self.config.get("id")),
support_streaming_message=False, support_streaming_message=False,
) )
@@ -127,7 +127,9 @@ class AiocqhttpAdapter(Platform):
"""OneBot V11 请求类事件""" """OneBot V11 请求类事件"""
abm = AstrBotMessage() abm = AstrBotMessage()
abm.self_id = str(event.self_id) abm.self_id = str(event.self_id)
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id) abm.sender = MessageMember(
user_id=str(event.user_id), nickname=str(event.user_id)
)
abm.type = MessageType.OTHER_MESSAGE abm.type = MessageType.OTHER_MESSAGE
if event.get("group_id"): if event.get("group_id"):
abm.type = MessageType.GROUP_MESSAGE abm.type = MessageType.GROUP_MESSAGE
@@ -194,6 +196,7 @@ class AiocqhttpAdapter(Platform):
@param event: 事件对象 @param event: 事件对象
@param get_reply: 是否获取回复消息这个参数是为了防止多个回复嵌套 @param get_reply: 是否获取回复消息这个参数是为了防止多个回复嵌套
""" """
assert event.sender is not None
abm = AstrBotMessage() abm = AstrBotMessage()
abm.self_id = str(event.self_id) abm.self_id = str(event.self_id)
abm.sender = MessageMember( abm.sender = MessageMember(
@@ -203,6 +206,7 @@ class AiocqhttpAdapter(Platform):
if event["message_type"] == "group": if event["message_type"] == "group":
abm.type = MessageType.GROUP_MESSAGE abm.type = MessageType.GROUP_MESSAGE
abm.group_id = str(event.group_id) abm.group_id = str(event.group_id)
abm.group = Group(str(event.group_id))
abm.group.group_name = event.get("group_name", "N/A") abm.group.group_name = event.get("group_name", "N/A")
elif event["message_type"] == "private": elif event["message_type"] == "private":
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
@@ -228,7 +232,7 @@ class AiocqhttpAdapter(Platform):
await self.bot.send(event, err) await self.bot.send(event, err)
except BaseException as e: except BaseException as e:
logger.error(f"回复消息失败: {e}") logger.error(f"回复消息失败: {e}")
return None raise ValueError(err)
# 按消息段类型类型适配 # 按消息段类型类型适配
for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]): for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]):
@@ -381,10 +385,25 @@ class AiocqhttpAdapter(Platform):
logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。") logger.error(f"获取 @ 用户信息失败: {e},此消息段将被忽略。")
message_str += "".join(at_parts) message_str += "".join(at_parts)
elif t == "markdown":
text = m["data"].get("markdown") or m["data"].get("content", "")
abm.message.append(Plain(text=text))
message_str += text
else: else:
for m in m_group: for m in m_group:
a = ComponentTypes[t](**m["data"]) try:
abm.message.append(a) if t not in ComponentTypes:
logger.warning(
f"不支持的消息段类型,已忽略: {t}, data={m['data']}"
)
continue
a = ComponentTypes[t](**m["data"])
abm.message.append(a)
except Exception as e:
logger.exception(
f"消息段解析失败: type={t}, data={m['data']}. {e}"
)
continue
abm.timestamp = int(time.time()) abm.timestamp = int(time.time())
abm.message_str = message_str abm.message_str = message_str
@@ -417,7 +436,7 @@ class AiocqhttpAdapter(Platform):
async def shutdown_trigger_placeholder(self): async def shutdown_trigger_placeholder(self):
await self.shutdown_event.wait() await self.shutdown_event.wait()
logger.info("aiocqhttp 适配器已被优雅地关闭") logger.info("aiocqhttp 适配器已被关闭")
def meta(self) -> PlatformMetadata: def meta(self) -> PlatformMetadata:
return self.metadata return self.metadata
@@ -2,6 +2,7 @@ import asyncio
import os import os
import threading import threading
import uuid import uuid
from typing import cast
import aiohttp import aiohttp
import dingtalk_stream import dingtalk_stream
@@ -54,12 +55,14 @@ class DingtalkPlatformAdapter(Platform):
self.client_id = platform_config["client_id"] self.client_id = platform_config["client_id"]
self.client_secret = platform_config["client_secret"] self.client_secret = platform_config["client_secret"]
outer_self = self
class AstrCallbackClient(dingtalk_stream.ChatbotHandler): class AstrCallbackClient(dingtalk_stream.ChatbotHandler):
async def process(self_, message: dingtalk_stream.CallbackMessage): async def process(self, message: dingtalk_stream.CallbackMessage):
logger.debug(f"dingtalk: {message.data}") logger.debug(f"dingtalk: {message.data}")
im = dingtalk_stream.ChatbotMessage.from_dict(message.data) im = dingtalk_stream.ChatbotMessage.from_dict(message.data)
abm = await self.convert_msg(im) abm = await outer_self.convert_msg(im)
await self.handle_msg(abm) await outer_self.handle_msg(abm)
return AckMessage.STATUS_OK, "OK" return AckMessage.STATUS_OK, "OK"
@@ -73,6 +76,7 @@ class DingtalkPlatformAdapter(Platform):
self.client, self.client,
) )
self.client_ = client # 用于 websockets 的 client self.client_ = client # 用于 websockets 的 client
self._shutdown_event: threading.Event | None = None
def _id_to_sid(self, dingtalk_id: str | None) -> str: def _id_to_sid(self, dingtalk_id: str | None) -> str:
if not dingtalk_id: if not dingtalk_id:
@@ -93,7 +97,7 @@ class DingtalkPlatformAdapter(Platform):
return PlatformMetadata( return PlatformMetadata(
name="dingtalk", name="dingtalk",
description="钉钉机器人官方 API 适配器", description="钉钉机器人官方 API 适配器",
id=self.config.get("id"), id=cast(str, self.config.get("id")),
support_streaming_message=False, support_streaming_message=False,
) )
@@ -104,7 +108,7 @@ class DingtalkPlatformAdapter(Platform):
abm = AstrBotMessage() abm = AstrBotMessage()
abm.message = [] abm.message = []
abm.message_str = "" abm.message_str = ""
abm.timestamp = int(message.create_at / 1000) abm.timestamp = int(cast(int, message.create_at) / 1000)
abm.type = ( abm.type = (
MessageType.GROUP_MESSAGE MessageType.GROUP_MESSAGE
if message.conversation_type == "2" if message.conversation_type == "2"
@@ -115,7 +119,7 @@ class DingtalkPlatformAdapter(Platform):
nickname=message.sender_nick, nickname=message.sender_nick,
) )
abm.self_id = self._id_to_sid(message.chatbot_user_id) abm.self_id = self._id_to_sid(message.chatbot_user_id)
abm.message_id = message.message_id abm.message_id = cast(str, message.message_id)
abm.raw_message = message abm.raw_message = message
if abm.type == MessageType.GROUP_MESSAGE: if abm.type == MessageType.GROUP_MESSAGE:
@@ -132,14 +136,16 @@ class DingtalkPlatformAdapter(Platform):
else: else:
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
message_type: str = message.message_type message_type: str = cast(str, message.message_type)
match message_type: match message_type:
case "text": case "text":
abm.message_str = message.text.content.strip() abm.message_str = message.text.content.strip()
abm.message.append(Plain(abm.message_str)) abm.message.append(Plain(abm.message_str))
case "richText": case "richText":
rtc: dingtalk_stream.RichTextContent = message.rich_text_content rtc: dingtalk_stream.RichTextContent = cast(
contents: list[dict] = rtc.rich_text_list dingtalk_stream.RichTextContent, message.rich_text_content
)
contents: list[dict] = cast(list[dict], rtc.rich_text_list)
for content in contents: for content in contents:
plains = "" plains = ""
if "text" in content: if "text" in content:
@@ -148,7 +154,7 @@ class DingtalkPlatformAdapter(Platform):
elif "type" in content and content["type"] == "picture": elif "type" in content and content["type"] == "picture":
f_path = await self.download_ding_file( f_path = await self.download_ding_file(
content["downloadCode"], content["downloadCode"],
message.robot_code, cast(str, message.robot_code),
"jpg", "jpg",
) )
abm.message.append(Image.fromFileSystem(f_path)) abm.message.append(Image.fromFileSystem(f_path))
@@ -193,7 +199,7 @@ class DingtalkPlatformAdapter(Platform):
logger.error( logger.error(
f"下载钉钉文件失败: {resp.status}, {await resp.text()}", f"下载钉钉文件失败: {resp.status}, {await resp.text()}",
) )
return None return ""
resp_data = await resp.json() resp_data = await resp.json()
download_url = resp_data["data"]["downloadUrl"] download_url = resp_data["data"]["downloadUrl"]
await download_file(download_url, f_path) await download_file(download_url, f_path)
@@ -213,7 +219,7 @@ class DingtalkPlatformAdapter(Platform):
logger.error( logger.error(
f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}", f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}",
) )
return None return ""
return (await resp.json())["data"]["accessToken"] return (await resp.json())["data"]["accessToken"]
async def handle_msg(self, abm: AstrBotMessage): async def handle_msg(self, abm: AstrBotMessage):
@@ -239,7 +245,7 @@ class DingtalkPlatformAdapter(Platform):
task.result() task.result()
except Exception as e: except Exception as e:
if "Graceful shutdown" in str(e): if "Graceful shutdown" in str(e):
logger.info("钉钉适配器已被优雅地关闭") logger.info("钉钉适配器已被关闭")
return return
logger.error(f"钉钉机器人启动失败: {e}") logger.error(f"钉钉机器人启动失败: {e}")
@@ -250,9 +256,11 @@ class DingtalkPlatformAdapter(Platform):
def monkey_patch_close(): def monkey_patch_close():
raise KeyboardInterrupt("Graceful shutdown") raise KeyboardInterrupt("Graceful shutdown")
self.client_.open_connection = monkey_patch_close if self.client_.websocket is not None:
await self.client_.websocket.close(code=1000, reason="Graceful shutdown") self.client_.open_connection = monkey_patch_close
self._shutdown_event.set() await self.client_.websocket.close(code=1000, reason="Graceful shutdown")
if self._shutdown_event is not None:
self._shutdown_event.set()
def get_client(self): def get_client(self):
return self.client return self.client
@@ -1,4 +1,5 @@
import asyncio import asyncio
from typing import cast
import dingtalk_stream import dingtalk_stream
@@ -32,7 +33,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
client.reply_markdown, client.reply_markdown,
segment.text, segment.text,
segment.text, segment.text,
self.message_obj.raw_message, cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message),
) )
elif isinstance(segment, Comp.Image): elif isinstance(segment, Comp.Image):
markdown_str = "" markdown_str = ""
@@ -53,7 +54,9 @@ class DingtalkMessageEvent(AstrMessageEvent):
client.reply_markdown, client.reply_markdown,
"😄", "😄",
markdown_str, markdown_str,
self.message_obj.raw_message, cast(
dingtalk_stream.ChatbotMessage, self.message_obj.raw_message
),
) )
logger.debug(f"send image: {ret}") logger.debug(f"send image: {ret}")
@@ -1,4 +1,5 @@
import sys import sys
from collections.abc import Awaitable, Callable
import discord import discord
@@ -27,13 +28,16 @@ class DiscordBotClient(discord.Bot):
super().__init__(intents=intents, proxy=proxy) super().__init__(intents=intents, proxy=proxy)
# 回调函数 # 回调函数
self.on_message_received = None self.on_message_received: Callable[[dict], Awaitable[None]] | None = None
self.on_ready_once_callback = None self.on_ready_once_callback: Callable[[], Awaitable[None]] | None = None
self._ready_once_fired = False self._ready_once_fired = False
@override
async def on_ready(self): async def on_ready(self):
"""当机器人成功连接并准备就绪时触发""" """当机器人成功连接并准备就绪时触发"""
if self.user is None:
logger.error("[Discord] 客户端未正确加载用户信息 (self.user is None)")
return
logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录") logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录")
logger.info("[Discord] 客户端已准备就绪。") logger.info("[Discord] 客户端已准备就绪。")
@@ -49,6 +53,9 @@ class DiscordBotClient(discord.Bot):
def _create_message_data(self, message: discord.Message) -> dict: def _create_message_data(self, message: discord.Message) -> dict:
"""从 discord.Message 创建数据字典""" """从 discord.Message 创建数据字典"""
if self.user is None:
raise RuntimeError("Bot is not ready: self.user is None")
is_mentioned = self.user in message.mentions is_mentioned = self.user in message.mentions
return { return {
"message": message, "message": message,
@@ -66,6 +73,12 @@ class DiscordBotClient(discord.Bot):
def _create_interaction_data(self, interaction: discord.Interaction) -> dict: def _create_interaction_data(self, interaction: discord.Interaction) -> dict:
"""从 discord.Interaction 创建数据字典""" """从 discord.Interaction 创建数据字典"""
if self.user is None:
raise RuntimeError("Bot is not ready: self.user is None")
if interaction.user is None:
raise ValueError("Interaction received without a valid user")
return { return {
"interaction": interaction, "interaction": interaction,
"bot_id": str(self.user.id), "bot_id": str(self.user.id),
@@ -80,7 +93,6 @@ class DiscordBotClient(discord.Bot):
"type": "interaction", "type": "interaction",
} }
@override
async def on_message(self, message: discord.Message): async def on_message(self, message: discord.Message):
"""当接收到消息时触发""" """当接收到消息时触发"""
if message.author.bot: if message.author.bot:
@@ -97,8 +97,8 @@ class DiscordView(BaseMessageComponent):
def __init__( def __init__(
self, self,
components: list[BaseMessageComponent] = None, components: list[BaseMessageComponent] | None = None,
timeout: float = None, timeout: float | None = None,
): ):
self.components = components or [] self.components = components or []
self.timeout = timeout self.timeout = timeout
@@ -1,10 +1,10 @@
import asyncio import asyncio
import re import re
import sys import sys
from typing import Any from typing import Any, cast
import discord import discord
from discord.abc import Messageable from discord.abc import GuildChannel, Messageable, PrivateChannel
from discord.channel import DMChannel from discord.channel import DMChannel
from astrbot import logger from astrbot import logger
@@ -46,7 +46,7 @@ class DiscordPlatformAdapter(Platform):
) -> None: ) -> None:
super().__init__(platform_config, event_queue) super().__init__(platform_config, event_queue)
self.settings = platform_settings self.settings = platform_settings
self.client_self_id = None self.client_self_id: str | None = None
self.registered_handlers = [] self.registered_handlers = []
# 指令注册相关 # 指令注册相关
self.enable_command_register = self.config.get("discord_command_register", True) self.enable_command_register = self.config.get("discord_command_register", True)
@@ -62,6 +62,12 @@ class DiscordPlatformAdapter(Platform):
message_chain: MessageChain, message_chain: MessageChain,
): ):
"""通过会话发送消息""" """通过会话发送消息"""
if self.client.user is None:
logger.error(
"[Discord] 客户端未就绪 (self.client.user is None),无法发送消息"
)
return
# 创建一个 message_obj 以便在 event 中使用 # 创建一个 message_obj 以便在 event 中使用
message_obj = AstrBotMessage() message_obj = AstrBotMessage()
if "_" in session.session_id: if "_" in session.session_id:
@@ -89,7 +95,7 @@ class DiscordPlatformAdapter(Platform):
user_id=str(self.client_self_id), user_id=str(self.client_self_id),
nickname=self.client.user.display_name, nickname=self.client.user.display_name,
) )
message_obj.self_id = self.client_self_id message_obj.self_id = cast(str, self.client_self_id)
message_obj.session_id = session.session_id message_obj.session_id = session.session_id
message_obj.message = message_chain.chain message_obj.message = message_chain.chain
@@ -110,7 +116,7 @@ class DiscordPlatformAdapter(Platform):
return PlatformMetadata( return PlatformMetadata(
"discord", "discord",
"Discord 适配器", "Discord 适配器",
id=self.config.get("id"), id=cast(str, self.config.get("id")),
default_config_tmpl=self.config, default_config_tmpl=self.config,
support_streaming_message=False, support_streaming_message=False,
) )
@@ -160,7 +166,7 @@ class DiscordPlatformAdapter(Platform):
def _get_message_type( def _get_message_type(
self, self,
channel: Messageable, channel: Messageable | GuildChannel | PrivateChannel,
guild_id: int | None = None, guild_id: int | None = None,
) -> MessageType: ) -> MessageType:
"""根据 channel 对象和 guild_id 判断消息类型""" """根据 channel 对象和 guild_id 判断消息类型"""
@@ -170,13 +176,15 @@ class DiscordPlatformAdapter(Platform):
return MessageType.FRIEND_MESSAGE return MessageType.FRIEND_MESSAGE
return MessageType.GROUP_MESSAGE return MessageType.GROUP_MESSAGE
def _get_channel_id(self, channel: Messageable) -> str: def _get_channel_id(
self, channel: Messageable | GuildChannel | PrivateChannel
) -> str:
"""根据 channel 对象获取ID""" """根据 channel 对象获取ID"""
return str(getattr(channel, "id", None)) return str(getattr(channel, "id", None))
def _convert_message_to_abm(self, data: dict) -> AstrBotMessage: def _convert_message_to_abm(self, data: dict) -> AstrBotMessage:
"""将普通消息转换为 AstrBotMessage""" """将普通消息转换为 AstrBotMessage"""
message: discord.Message = data["message"] message = data["message"]
content = message.content content = message.content
@@ -233,7 +241,7 @@ class DiscordPlatformAdapter(Platform):
) )
abm.message = message_chain abm.message = message_chain
abm.raw_message = message abm.raw_message = message
abm.self_id = self.client_self_id abm.self_id = cast(str, self.client_self_id)
abm.session_id = str(message.channel.id) abm.session_id = str(message.channel.id)
abm.message_id = str(message.id) abm.message_id = str(message.id)
return abm return abm
@@ -254,32 +262,52 @@ class DiscordPlatformAdapter(Platform):
interaction_followup_webhook=followup_webhook, interaction_followup_webhook=followup_webhook,
) )
if self.client.user is None:
logger.error(
"[Discord] 客户端未就绪 (self.client.user is None),无法处理消息"
)
return
# 检查是否为斜杠指令 # 检查是否为斜杠指令
is_slash_command = message_event.interaction_followup_webhook is not None is_slash_command = message_event.interaction_followup_webhook is not None
# 1. 优先处理斜杠指令
if is_slash_command:
message_event.is_wake = True
message_event.is_at_or_wake_command = True
self.commit_event(message_event)
return
# 2. 处理普通消息(提及检测)
# 确保 raw_message 是 discord.Message 类型,以便静态检查通过
raw_message = message.raw_message
if not isinstance(raw_message, discord.Message):
logger.warning(
f"[Discord] 收到非 Message 类型的消息: {type(raw_message)},已忽略。"
)
return
# 检查是否被@User Mention 或 Bot 拥有的 Role Mention # 检查是否被@User Mention 或 Bot 拥有的 Role Mention
is_mention = False is_mention = False
# User Mention # User Mention
if ( # 此时 Pylance 知道 raw_message 是 discord.Message,具有 mentions 属性
self.client if self.client.user in raw_message.mentions:
and self.client.user is_mention = True
and hasattr(message.raw_message, "mentions")
):
if self.client.user in message.raw_message.mentions:
is_mention = True
# Role MentionBot 拥有的角色被提及) # Role MentionBot 拥有的角色被提及)
if not is_mention and hasattr(message.raw_message, "role_mentions"): if not is_mention and raw_message.role_mentions:
bot_member = None bot_member = None
if hasattr(message.raw_message, "guild") and message.raw_message.guild: if raw_message.guild:
try: try:
bot_member = message.raw_message.guild.get_member( bot_member = raw_message.guild.get_member(
self.client.user.id, self.client.user.id,
) )
except Exception: except Exception:
bot_member = None bot_member = None
if bot_member and hasattr(bot_member, "roles"): if bot_member and hasattr(bot_member, "roles"):
bot_roles = set(bot_member.roles) bot_roles = set(bot_member.roles)
mentioned_roles = set(message.raw_message.role_mentions) mentioned_roles = set(raw_message.role_mentions)
if ( if (
bot_roles bot_roles
and mentioned_roles and mentioned_roles
@@ -287,8 +315,8 @@ class DiscordPlatformAdapter(Platform):
): ):
is_mention = True is_mention = True
# 如果是斜杠指令或被@的消息,设置为唤醒状态 # 如果是被@的消息,设置为唤醒状态
if is_slash_command or is_mention: if is_mention:
message_event.is_wake = True message_event.is_wake = True
message_event.is_at_or_wake_command = True message_event.is_at_or_wake_command = True
@@ -424,7 +452,7 @@ class DiscordPlatformAdapter(Platform):
) )
abm.message = [Plain(text=message_str_for_filter)] abm.message = [Plain(text=message_str_for_filter)]
abm.raw_message = ctx.interaction abm.raw_message = ctx.interaction
abm.self_id = self.client_self_id abm.self_id = cast(str, self.client_self_id)
abm.session_id = str(ctx.channel_id) abm.session_id = str(ctx.channel_id)
abm.message_id = str(ctx.interaction.id) abm.message_id = str(ctx.interaction.id)
@@ -437,7 +465,7 @@ class DiscordPlatformAdapter(Platform):
def _extract_command_info( def _extract_command_info(
event_filter: Any, event_filter: Any,
handler_metadata: StarHandlerMetadata, handler_metadata: StarHandlerMetadata,
) -> tuple[str, str, CommandFilter] | None: ) -> tuple[str, str, CommandFilter | None] | None:
"""从事件过滤器中提取指令信息""" """从事件过滤器中提取指令信息"""
cmd_name = None cmd_name = None
# is_group = False # is_group = False
@@ -4,8 +4,10 @@ import binascii
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import cast
import discord import discord
from discord.types.interactions import ComponentInteractionData
from astrbot import logger from astrbot import logger
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
@@ -85,6 +87,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
channel = await self._get_channel() channel = await self._get_channel()
if not channel: if not channel:
return return
if not isinstance(channel, discord.abc.Messageable):
logger.error(f"[Discord] 频道 {channel.id} 不是可发送消息的类型")
return
await channel.send(**kwargs) await channel.send(**kwargs)
except Exception as e: except Exception as e:
@@ -107,7 +112,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
await self.send(buffer) await self.send(buffer)
return await super().send_streaming(generator, use_fallback) return await super().send_streaming(generator, use_fallback)
async def _get_channel(self) -> discord.abc.Messageable | None: async def _get_channel(
self,
) -> discord.Thread | discord.abc.GuildChannel | discord.abc.PrivateChannel | None:
"""获取当前事件对应的频道对象""" """获取当前事件对应的频道对象"""
try: try:
channel_id = int(self.session_id) channel_id = int(self.session_id)
@@ -121,7 +128,13 @@ class DiscordPlatformEvent(AstrMessageEvent):
async def _parse_to_discord( async def _parse_to_discord(
self, self,
message: MessageChain, message: MessageChain,
) -> tuple[str, list[discord.File], discord.ui.View | None, list[discord.Embed]]: ) -> tuple[
str,
list[discord.File],
discord.ui.View | None,
list[discord.Embed],
str | int | None,
]:
"""将 MessageChain 解析为 Discord 发送所需的内容""" """将 MessageChain 解析为 Discord 发送所需的内容"""
content_parts = [] content_parts = []
files = [] files = []
@@ -261,7 +274,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
self.message_obj.raw_message, self.message_obj.raw_message,
"add_reaction", "add_reaction",
): ):
await self.message_obj.raw_message.add_reaction(emoji) await cast(discord.Message, self.message_obj.raw_message).add_reaction(
emoji
)
except Exception as e: except Exception as e:
logger.error(f"[Discord] 添加反应失败: {e}") logger.error(f"[Discord] 添加反应失败: {e}")
@@ -270,7 +285,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
return ( return (
hasattr(self.message_obj, "raw_message") hasattr(self.message_obj, "raw_message")
and hasattr(self.message_obj.raw_message, "type") and hasattr(self.message_obj.raw_message, "type")
and self.message_obj.raw_message.type and cast(discord.Interaction, self.message_obj.raw_message).type
== discord.InteractionType.application_command == discord.InteractionType.application_command
) )
@@ -279,14 +294,18 @@ class DiscordPlatformEvent(AstrMessageEvent):
return ( return (
hasattr(self.message_obj, "raw_message") hasattr(self.message_obj, "raw_message")
and hasattr(self.message_obj.raw_message, "type") and hasattr(self.message_obj.raw_message, "type")
and self.message_obj.raw_message.type == discord.InteractionType.component and cast(discord.Interaction, self.message_obj.raw_message).type
== discord.InteractionType.component
) )
def get_interaction_custom_id(self) -> str: def get_interaction_custom_id(self) -> str:
"""获取交互组件的custom_id""" """获取交互组件的custom_id"""
if self.is_button_interaction(): if self.is_button_interaction():
try: try:
return self.message_obj.raw_message.data.get("custom_id", "") return cast(
ComponentInteractionData,
cast(discord.Interaction, self.message_obj.raw_message).data,
).get("custom_id", "")
except Exception: except Exception:
pass pass
return "" return ""
@@ -299,7 +318,9 @@ class DiscordPlatformEvent(AstrMessageEvent):
): ):
return any( return any(
mention.id == int(self.message_obj.self_id) mention.id == int(self.message_obj.self_id)
for mention in self.message_obj.raw_message.mentions for mention in cast(
discord.Message, self.message_obj.raw_message
).mentions
) )
return False return False
@@ -309,5 +330,5 @@ class DiscordPlatformEvent(AstrMessageEvent):
self.message_obj.raw_message, self.message_obj.raw_message,
"clean_content", "clean_content",
): ):
return self.message_obj.raw_message.clean_content return cast(discord.Message, self.message_obj.raw_message).clean_content
return self.message_str return self.message_str
@@ -2,10 +2,17 @@ import asyncio
import base64 import base64
import json import json
import re import re
import time
import uuid import uuid
from typing import Any, cast
import lark_oapi as lark import lark_oapi as lark
from lark_oapi.api.im.v1 import * from lark_oapi.api.im.v1 import (
CreateMessageRequest,
CreateMessageRequestBody,
GetMessageResourceRequest,
)
from lark_oapi.api.im.v1.processor import P2ImMessageReceiveV1Processor
import astrbot.api.message_components as Comp import astrbot.api.message_components as Comp
from astrbot import logger from astrbot import logger
@@ -18,9 +25,11 @@ from astrbot.api.platform import (
PlatformMetadata, PlatformMetadata,
) )
from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.webhook_utils import log_webhook_info
from ...register import register_platform_adapter from ...register import register_platform_adapter
from .lark_event import LarkMessageEvent from .lark_event import LarkMessageEvent
from .server import LarkWebhookServer
@register_platform_adapter( @register_platform_adapter(
@@ -42,9 +51,13 @@ class LarkPlatformAdapter(Platform):
self.domain = platform_config.get("domain", lark.FEISHU_DOMAIN) self.domain = platform_config.get("domain", lark.FEISHU_DOMAIN)
self.bot_name = platform_config.get("lark_bot_name", "astrbot") self.bot_name = platform_config.get("lark_bot_name", "astrbot")
# socket or webhook
self.connection_mode = platform_config.get("lark_connection_mode", "socket")
if not self.bot_name: if not self.bot_name:
logger.warning("未设置飞书机器人名称,@ 机器人可能得不到回复。") logger.warning("未设置飞书机器人名称,@ 机器人可能得不到回复。")
# 初始化 WebSocket 长连接相关配置
async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1): async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1):
await self.convert_msg(event) await self.convert_msg(event)
@@ -57,6 +70,8 @@ class LarkPlatformAdapter(Platform):
.build() .build()
) )
self.do_v2_msg_event = do_v2_msg_event
self.client = lark.ws.Client( self.client = lark.ws.Client(
app_id=self.appid, app_id=self.appid,
app_secret=self.appsecret, app_secret=self.appsecret,
@@ -66,14 +81,56 @@ class LarkPlatformAdapter(Platform):
) )
self.lark_api = ( self.lark_api = (
lark.Client.builder().app_id(self.appid).app_secret(self.appsecret).build() lark.Client.builder()
.app_id(self.appid)
.app_secret(self.appsecret)
.log_level(lark.LogLevel.ERROR)
.domain(self.domain)
.build()
) )
self.webhook_server = None
if self.connection_mode == "webhook":
self.webhook_server = LarkWebhookServer(platform_config, event_queue)
self.webhook_server.set_callback(self.handle_webhook_event)
self.event_id_timestamps: dict[str, float] = {}
def _clean_expired_events(self):
"""清理超过 30 分钟的事件记录"""
current_time = time.time()
expired_keys = [
event_id
for event_id, timestamp in self.event_id_timestamps.items()
if current_time - timestamp > 1800
]
for event_id in expired_keys:
del self.event_id_timestamps[event_id]
def _is_duplicate_event(self, event_id: str) -> bool:
"""检查事件是否重复
Args:
event_id: 事件ID
Returns:
True 表示重复事件False 表示新事件
"""
self._clean_expired_events()
if event_id in self.event_id_timestamps:
return True
self.event_id_timestamps[event_id] = time.time()
return False
async def send_by_session( async def send_by_session(
self, self,
session: MessageSesion, session: MessageSesion,
message_chain: MessageChain, message_chain: MessageChain,
): ):
if self.lark_api.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法发送消息")
return
res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api) res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api)
wrapped = { wrapped = {
"zh_cn": { "zh_cn": {
@@ -114,14 +171,25 @@ class LarkPlatformAdapter(Platform):
return PlatformMetadata( return PlatformMetadata(
name="lark", name="lark",
description="飞书机器人官方 API 适配器", description="飞书机器人官方 API 适配器",
id=self.config.get("id"), id=cast(str, self.config.get("id")),
support_streaming_message=False, support_streaming_message=False,
) )
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1): async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
if event.event is None:
logger.debug("[Lark] 收到空事件(event.event is None)")
return
message = event.event.message message = event.event.message
if message is None:
logger.debug("[Lark] 事件中没有消息体(message is None)")
return
abm = AstrBotMessage() abm = AstrBotMessage()
abm.timestamp = int(message.create_time) / 1000
if message.create_time:
abm.timestamp = int(message.create_time) // 1000
else:
abm.timestamp = int(time.time())
abm.message = [] abm.message = []
abm.type = ( abm.type = (
MessageType.GROUP_MESSAGE MessageType.GROUP_MESSAGE
@@ -136,14 +204,28 @@ class LarkPlatformAdapter(Platform):
at_list = {} at_list = {}
if message.mentions: if message.mentions:
for m in message.mentions: for m in message.mentions:
at_list[m.key] = Comp.At(qq=m.id.open_id, name=m.name) if m.id is None:
if m.name == self.bot_name: continue
abm.self_id = m.id.open_id # 飞书 open_id 可能是 None,这里做个防护
open_id = m.id.open_id if m.id.open_id else ""
at_list[m.key] = Comp.At(qq=open_id, name=m.name)
content_json_b = json.loads(message.content) if m.name == self.bot_name:
if m.id.open_id is not None:
abm.self_id = m.id.open_id
if message.content is None:
logger.warning("[Lark] 消息内容为空")
return
try:
content_json_b = json.loads(message.content)
except json.JSONDecodeError:
logger.error(f"[Lark] 解析消息内容失败: {message.content}")
return
if message.message_type == "text": if message.message_type == "text":
message_str_raw = content_json_b["text"] # 带有 @ 的消息 message_str_raw = content_json_b.get("text", "") # 带有 @ 的消息
at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则 at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则
# at_users = re.findall(at_pattern, message_str_raw) # at_users = re.findall(at_pattern, message_str_raw)
# 拆分文本,去掉AT符号部分 # 拆分文本,去掉AT符号部分
@@ -168,27 +250,47 @@ class LarkPlatformAdapter(Platform):
content_json_b = _ls content_json_b = _ls
elif message.message_type == "image": elif message.message_type == "image":
content_json_b = [ content_json_b = [
{"tag": "img", "image_key": content_json_b["image_key"], "style": []}, {
"tag": "img",
"image_key": content_json_b.get("image_key"),
"style": [],
},
] ]
if message.message_type in ("post", "image"): if message.message_type in ("post", "image"):
for comp in content_json_b: for comp in content_json_b:
if comp["tag"] == "at": if comp.get("tag") == "at":
abm.message.append(at_list[comp["user_id"]]) user_id = comp.get("user_id")
elif comp["tag"] == "text" and comp["text"].strip(): if user_id in at_list:
abm.message.append(at_list[user_id])
elif comp.get("tag") == "text" and comp.get("text", "").strip():
abm.message.append(Comp.Plain(comp["text"].strip())) abm.message.append(Comp.Plain(comp["text"].strip()))
elif comp["tag"] == "img": elif comp.get("tag") == "img":
image_key = comp["image_key"] image_key = comp.get("image_key")
if not image_key:
continue
request = ( request = (
GetMessageResourceRequest.builder() GetMessageResourceRequest.builder()
.message_id(message.message_id) .message_id(cast(str, message.message_id))
.file_key(image_key) .file_key(image_key)
.type("image") .type("image")
.build() .build()
) )
if self.lark_api.im is None:
logger.error("[Lark] API Client im 模块未初始化")
continue
response = await self.lark_api.im.v1.message_resource.aget(request) response = await self.lark_api.im.v1.message_resource.aget(request)
if not response.success(): if not response.success():
logger.error(f"无法下载飞书图片: {image_key}") logger.error(f"无法下载飞书图片: {image_key}")
continue
if response.file is None:
logger.error(f"飞书图片响应中不包含文件流: {image_key}")
continue
image_bytes = response.file.read() image_bytes = response.file.read()
image_base64 = base64.b64encode(image_bytes).decode() image_base64 = base64.b64encode(image_bytes).decode()
abm.message.append(Comp.Image.fromBase64(image_base64)) abm.message.append(Comp.Image.fromBase64(image_base64))
@@ -196,6 +298,19 @@ class LarkPlatformAdapter(Platform):
for comp in abm.message: for comp in abm.message:
if isinstance(comp, Comp.Plain): if isinstance(comp, Comp.Plain):
abm.message_str += comp.text abm.message_str += comp.text
if message.message_id is None:
logger.error("[Lark] 消息缺少 message_id")
return
if (
event.event.sender is None
or event.event.sender.sender_id is None
or event.event.sender.sender_id.open_id is None
):
logger.error("[Lark] 消息发送者信息不完整")
return
abm.message_id = message.message_id abm.message_id = message.message_id
abm.raw_message = message abm.raw_message = message
abm.sender = MessageMember( abm.sender = MessageMember(
@@ -227,13 +342,61 @@ class LarkPlatformAdapter(Platform):
self._event_queue.put_nowait(event) self._event_queue.put_nowait(event)
async def handle_webhook_event(self, event_data: dict):
"""处理 Webhook 事件
Args:
event_data: Webhook 事件数据
"""
try:
header = event_data.get("header", {})
event_id = header.get("event_id", "")
if event_id and self._is_duplicate_event(event_id):
logger.debug(f"[Lark Webhook] 跳过重复事件: {event_id}")
return
event_type = header.get("event_type", "")
if event_type == "im.message.receive_v1":
processor = P2ImMessageReceiveV1Processor(self.do_v2_msg_event)
data = (processor.type())(event_data)
processor.do(data)
else:
logger.debug(f"[Lark Webhook] 未处理的事件类型: {event_type}")
except Exception as e:
logger.error(f"[Lark Webhook] 处理事件失败: {e}", exc_info=True)
async def run(self): async def run(self):
# self.client.start() if self.connection_mode == "webhook":
await self.client._connect() # Webhook 模式
if self.webhook_server is None:
logger.error("[Lark] Webhook 模式已启用,但 webhook_server 未初始化")
return
webhook_uuid = self.config.get("webhook_uuid")
if webhook_uuid:
log_webhook_info(f"{self.meta().id}(飞书 Webhook)", webhook_uuid)
else:
logger.warning("[Lark] Webhook 模式已启用,但未配置 webhook_uuid")
else:
# 长连接模式
await self.client._connect()
async def webhook_callback(self, request: Any) -> Any:
"""统一 Webhook 回调入口"""
if not self.webhook_server:
return {"error": "Webhook server not initialized"}, 500
return await self.webhook_server.handle_callback(request)
async def terminate(self): async def terminate(self):
await self.client._disconnect() if self.connection_mode == "socket":
logger.info("飞书(Lark) 适配器已被优雅地关闭") await self.client._disconnect()
logger.info("飞书(Lark) 适配器已关闭")
def get_client(self) -> lark.Client: def get_client(self) -> lark.ws.Client:
return self.client return self.client
def unified_webhook(self) -> bool:
return bool(
self.config.get("lark_connection_mode", "") == "webhook"
and self.config.get("webhook_uuid")
)
@@ -5,7 +5,15 @@ import uuid
from io import BytesIO from io import BytesIO
import lark_oapi as lark import lark_oapi as lark
from lark_oapi.api.im.v1 import * from lark_oapi.api.im.v1 import (
CreateImageRequest,
CreateImageRequestBody,
CreateMessageReactionRequest,
CreateMessageReactionRequestBody,
Emoji,
ReplyMessageRequest,
ReplyMessageRequestBody,
)
from astrbot import logger from astrbot import logger
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
@@ -44,7 +52,7 @@ class LarkMessageEvent(AstrMessageEvent):
file_path = comp.file.replace("file:///", "") file_path = comp.file.replace("file:///", "")
elif comp.file and comp.file.startswith("http"): elif comp.file and comp.file.startswith("http"):
image_file_path = await download_image_by_url(comp.file) image_file_path = await download_image_by_url(comp.file)
file_path = image_file_path file_path = image_file_path if image_file_path else ""
elif comp.file and comp.file.startswith("base64://"): elif comp.file and comp.file.startswith("base64://"):
base64_str = comp.file.removeprefix("base64://") base64_str = comp.file.removeprefix("base64://")
image_data = base64.b64decode(base64_str) image_data = base64.b64decode(base64_str)
@@ -54,10 +62,17 @@ class LarkMessageEvent(AstrMessageEvent):
with open(file_path, "wb") as f: with open(file_path, "wb") as f:
f.write(BytesIO(image_data).getvalue()) f.write(BytesIO(image_data).getvalue())
else: else:
file_path = comp.file file_path = comp.file if comp.file else ""
if image_file is None: if image_file is None:
image_file = open(file_path, "rb") if not file_path:
logger.error("[Lark] 图片路径为空,无法上传")
continue
try:
image_file = open(file_path, "rb")
except Exception as e:
logger.error(f"[Lark] 无法打开图片文件: {e}")
continue
request = ( request = (
CreateImageRequest.builder() CreateImageRequest.builder()
@@ -69,9 +84,20 @@ class LarkMessageEvent(AstrMessageEvent):
) )
.build() .build()
) )
if lark_client.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法上传图片")
continue
response = await lark_client.im.v1.image.acreate(request) response = await lark_client.im.v1.image.acreate(request)
if not response.success(): if not response.success():
logger.error(f"无法上传飞书图片({response.code}): {response.msg}") logger.error(f"无法上传飞书图片({response.code}): {response.msg}")
continue
if response.data is None:
logger.error("[Lark] 上传图片成功但未返回数据(data is None)")
continue
image_key = response.data.image_key image_key = response.data.image_key
logger.debug(image_key) logger.debug(image_key)
ret.append(_stage) ret.append(_stage)
@@ -107,6 +133,10 @@ class LarkMessageEvent(AstrMessageEvent):
.build() .build()
) )
if self.bot.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法回复消息")
return
response = await self.bot.im.v1.message.areply(request) response = await self.bot.im.v1.message.areply(request)
if not response.success(): if not response.success():
@@ -115,6 +145,10 @@ class LarkMessageEvent(AstrMessageEvent):
await super().send(message) await super().send(message)
async def react(self, emoji: str): async def react(self, emoji: str):
if self.bot.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法发送表情")
return
request = ( request = (
CreateMessageReactionRequest.builder() CreateMessageReactionRequest.builder()
.message_id(self.message_obj.message_id) .message_id(self.message_obj.message_id)
@@ -125,6 +159,7 @@ class LarkMessageEvent(AstrMessageEvent):
) )
.build() .build()
) )
response = await self.bot.im.v1.message_reaction.acreate(request) response = await self.bot.im.v1.message_reaction.acreate(request)
if not response.success(): if not response.success():
logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}") logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}")
@@ -0,0 +1,206 @@
"""飞书(Lark) Webhook 服务器实现
实现飞书事件订阅的 Webhook 模式支持:
1. 请求 URL 验证 (challenge 验证)
2. 事件加密/解密 (AES-256-CBC)
3. 签名校验 (SHA256)
4. 事件接收和处理
"""
import asyncio
import base64
import hashlib
import json
from collections.abc import Awaitable, Callable
from Crypto.Cipher import AES
from astrbot.api import logger
class AESCipher:
"""AES 加密/解密工具类"""
def __init__(self, key: str):
self.bs = AES.block_size
self.key = hashlib.sha256(self.str_to_bytes(key)).digest()
@staticmethod
def str_to_bytes(data):
u_type = type(b"".decode("utf8"))
if isinstance(data, u_type):
return data.encode("utf8")
return data
@staticmethod
def _unpad(s):
return s[: -ord(s[len(s) - 1 :])]
def decrypt(self, enc):
iv = enc[: AES.block_size]
cipher = AES.new(self.key, AES.MODE_CBC, iv)
return self._unpad(cipher.decrypt(enc[AES.block_size :]))
def decrypt_string(self, enc):
enc = base64.b64decode(enc)
return self.decrypt(enc).decode("utf8")
class LarkWebhookServer:
"""飞书 Webhook 服务器
仅支持统一 Webhook 模式
"""
def __init__(self, config: dict, event_queue: asyncio.Queue):
"""初始化 Webhook 服务器
Args:
config: 飞书配置
event_queue: 事件队列
"""
self.app_id = config["app_id"]
self.app_secret = config["app_secret"]
self.encrypt_key = config.get("lark_encrypt_key", "")
self.verification_token = config.get("lark_verification_token", "")
self.event_queue = event_queue
self.callback: Callable[[dict], Awaitable[None]] | None = None
# 初始化加密工具
self.cipher = None
if self.encrypt_key:
self.cipher = AESCipher(self.encrypt_key)
def verify_signature(
self,
timestamp: str,
nonce: str,
encrypt_key: str,
body: bytes,
signature: str,
) -> bool:
"""验证签名
Args:
timestamp: 请求时间戳
nonce: 随机数
encrypt_key: 加密密钥
body: 请求体
signature: 签名
Returns:
签名是否有效
"""
# 拼接字符串: timestamp + nonce + encrypt_key + body
bytes_b1 = (timestamp + nonce + encrypt_key).encode("utf-8")
bytes_b = bytes_b1 + body
h = hashlib.sha256(bytes_b)
calculated_signature = h.hexdigest()
return calculated_signature == signature
def decrypt_event(self, encrypted_data: str) -> dict:
"""解密事件数据
Args:
encrypted_data: 加密的事件数据
Returns:
解密后的事件字典
"""
if not self.cipher:
raise ValueError("未配置 encrypt_key,无法解密事件")
decrypted_str = self.cipher.decrypt_string(encrypted_data)
return json.loads(decrypted_str)
async def handle_challenge(self, event_data: dict) -> dict:
"""处理 challenge 验证请求
Args:
event_data: 事件数据
Returns:
包含 challenge 的响应
"""
challenge = event_data.get("challenge", "")
logger.info(f"[Lark Webhook] 收到 challenge 验证请求: {challenge}")
return {"challenge": challenge}
async def handle_callback(self, request) -> tuple[dict, int] | dict:
"""处理 webhook 回调,可被统一 webhook 入口复用
Args:
request: Quart 请求对象
Returns:
响应数据
"""
# 获取原始请求体
body = await request.get_data()
try:
event_data = await request.json
except Exception as e:
logger.error(f"[Lark Webhook] 解析请求体失败: {e}")
return {"error": "Invalid JSON"}, 400
if not event_data:
logger.error("[Lark Webhook] 请求体为空")
return {"error": "Empty request body"}, 400
# 如果配置了 encrypt_key,进行签名验证
if self.encrypt_key:
timestamp = request.headers.get("X-Lark-Request-Timestamp", "")
nonce = request.headers.get("X-Lark-Request-Nonce", "")
signature = request.headers.get("X-Lark-Signature", "")
if timestamp and nonce and signature:
if not self.verify_signature(
timestamp, nonce, self.encrypt_key, body, signature
):
logger.error("[Lark Webhook] 签名验证失败")
return {"error": "Invalid signature"}, 401
# 检查是否是加密事件
if "encrypt" in event_data:
try:
event_data = self.decrypt_event(event_data["encrypt"])
logger.debug(f"[Lark Webhook] 解密后的事件: {event_data}")
except Exception as e:
logger.error(f"[Lark Webhook] 解密事件失败: {e}")
return {"error": "Decryption failed"}, 400
# 验证 token
if self.verification_token:
header = event_data.get("header", {})
if header:
token = header.get("token", "")
else:
token = event_data.get("token", "")
if token != self.verification_token:
logger.error("[Lark Webhook] Verification Token 不匹配。")
return {"error": "Invalid verification token"}, 401
# 处理 URL 验证 (challenge)
if event_data.get("type") == "url_verification":
return await self.handle_challenge(event_data)
# 调用回调函数处理事件
if self.callback:
try:
await self.callback(event_data)
except Exception as e:
logger.error(f"[Lark Webhook] 处理事件回调失败: {e}", exc_info=True)
return {"error": "Event processing failed"}, 500
return {}
def set_callback(self, callback: Callable[[dict], Awaitable[None]]):
"""设置事件回调函数
Args:
callback: 处理事件的异步函数
"""
self.callback = callback
@@ -1,7 +1,6 @@
import asyncio import asyncio
import os import os
import random import random
from collections.abc import Awaitable
from typing import Any from typing import Any
import astrbot.api.message_components as Comp import astrbot.api.message_components as Comp
@@ -203,7 +202,7 @@ class MisskeyPlatformAdapter(Platform):
if not isinstance(message.raw_message, dict): if not isinstance(message.raw_message, dict):
message.raw_message = {} message.raw_message = {}
message.raw_message["poll"] = poll message.raw_message["poll"] = poll
message.poll = poll message.__setattr__("poll", poll)
except Exception: except Exception:
pass pass
@@ -372,7 +371,7 @@ class MisskeyPlatformAdapter(Platform):
self, self,
session: MessageSession, session: MessageSession,
message_chain: MessageChain, message_chain: MessageChain,
) -> Awaitable[Any]: ) -> None:
if not self.api: if not self.api:
logger.error("[Misskey] API 客户端未初始化") logger.error("[Misskey] API 客户端未初始化")
return await super().send_by_session(session, message_chain) return await super().send_by_session(session, message_chain)
@@ -3,6 +3,7 @@ import base64
import os import os
import random import random
import uuid import uuid
from typing import cast
import aiofiles import aiofiles
import botpy import botpy
@@ -60,7 +61,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
time_since_last_edit = current_time - last_edit_time time_since_last_edit = current_time - last_edit_time
if time_since_last_edit >= throttle_interval: if time_since_last_edit >= throttle_interval:
ret = await self._post_send(stream=stream_payload) ret = cast(
message.Message,
await self._post_send(stream=stream_payload),
)
stream_payload["index"] += 1 stream_payload["index"] += 1
stream_payload["id"] = ret["id"] stream_payload["id"] = ret["id"]
last_edit_time = asyncio.get_event_loop().time() last_edit_time = asyncio.get_event_loop().time()
@@ -83,7 +87,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
return None return None
source = self.message_obj.raw_message source = self.message_obj.raw_message
assert isinstance(
if not isinstance(
source, source,
( (
botpy.message.Message, botpy.message.Message,
@@ -91,7 +96,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
botpy.message.DirectMessage, botpy.message.DirectMessage,
botpy.message.C2CMessage, botpy.message.C2CMessage,
), ),
) ):
logger.warning(f"[QQOfficial] 不支持的消息源类型: {type(source)}")
return None
( (
plain_text, plain_text,
@@ -108,7 +115,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
): ):
return None return None
payload = { payload: dict = {
"content": plain_text, "content": plain_text,
"msg_id": self.message_obj.message_id, "msg_id": self.message_obj.message_id,
} }
@@ -118,8 +125,12 @@ class QQOfficialMessageEvent(AstrMessageEvent):
ret = None ret = None
match type(source): match source:
case botpy.message.GroupMessage: case botpy.message.GroupMessage():
if not source.group_openid:
logger.error("[QQOfficial] GroupMessage 缺少 group_openid")
return None
if image_base64: if image_base64:
media = await self.upload_group_and_c2c_image( media = await self.upload_group_and_c2c_image(
image_base64, image_base64,
@@ -140,7 +151,8 @@ class QQOfficialMessageEvent(AstrMessageEvent):
group_openid=source.group_openid, group_openid=source.group_openid,
**payload, **payload,
) )
case botpy.message.C2CMessage:
case botpy.message.C2CMessage():
if image_base64: if image_base64:
media = await self.upload_group_and_c2c_image( media = await self.upload_group_and_c2c_image(
image_base64, image_base64,
@@ -169,18 +181,23 @@ class QQOfficialMessageEvent(AstrMessageEvent):
**payload, **payload,
) )
logger.debug(f"Message sent to C2C: {ret}") logger.debug(f"Message sent to C2C: {ret}")
case botpy.message.Message:
case botpy.message.Message():
if image_path: if image_path:
payload["file_image"] = image_path payload["file_image"] = image_path
ret = await self.bot.api.post_message( ret = await self.bot.api.post_message(
channel_id=source.channel_id, channel_id=source.channel_id,
**payload, **payload,
) )
case botpy.message.DirectMessage:
case botpy.message.DirectMessage():
if image_path: if image_path:
payload["file_image"] = image_path payload["file_image"] = image_path
ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload) ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
case _:
pass
await super().send(self.send_buffer) await super().send(self.send_buffer)
self.send_buffer = None self.send_buffer = None
@@ -198,18 +215,33 @@ class QQOfficialMessageEvent(AstrMessageEvent):
"file_type": file_type, "file_type": file_type,
"srv_send_msg": False, "srv_send_msg": False,
} }
result = None
if "openid" in kwargs: if "openid" in kwargs:
payload["openid"] = kwargs["openid"] payload["openid"] = kwargs["openid"]
route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"]) route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"])
return await self.bot.api._http.request(route, json=payload) result = await self.bot.api._http.request(route, json=payload)
if "group_openid" in kwargs: elif "group_openid" in kwargs:
payload["group_openid"] = kwargs["group_openid"] payload["group_openid"] = kwargs["group_openid"]
route = Route( route = Route(
"POST", "POST",
"/v2/groups/{group_openid}/files", "/v2/groups/{group_openid}/files",
group_openid=kwargs["group_openid"], group_openid=kwargs["group_openid"],
) )
return await self.bot.api._http.request(route, json=payload) result = await self.bot.api._http.request(route, json=payload)
else:
raise ValueError("Invalid upload parameters")
if not isinstance(result, dict):
raise RuntimeError(
f"Failed to upload image, response is not dict: {result}"
)
return Media(
file_uuid=result["file_uuid"],
file_info=result["file_info"],
ttl=result.get("ttl", 0),
)
async def upload_group_and_c2c_record( async def upload_group_and_c2c_record(
self, self,
@@ -252,11 +284,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
result = await self.bot.api._http.request(route, json=payload) result = await self.bot.api._http.request(route, json=payload)
if result: if result:
if not isinstance(result, dict):
logger.error(f"上传文件响应格式错误: {result}")
return None
return Media( return Media(
file_uuid=result.get("file_uuid"), file_uuid=result["file_uuid"],
file_info=result.get("file_info"), file_info=result["file_info"],
ttl=result.get("ttl", 0), ttl=result.get("ttl", 0),
file_id=result.get("id", ""),
) )
except Exception as e: except Exception as e:
logger.error(f"上传请求错误: {e}") logger.error(f"上传请求错误: {e}")
@@ -273,7 +308,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
message_reference: message.Reference | None = None, message_reference: message.Reference | None = None,
media: message.Media | None = None, media: message.Media | None = None,
msg_id: str | None = None, msg_id: str | None = None,
msg_seq: str = 1, msg_seq: int | None = 1,
event_id: str | None = None, event_id: str | None = None,
markdown: message.MarkdownPayload | None = None, markdown: message.MarkdownPayload | None = None,
keyboard: message.Keyboard | None = None, keyboard: message.Keyboard | None = None,
@@ -282,7 +317,14 @@ class QQOfficialMessageEvent(AstrMessageEvent):
payload = locals() payload = locals()
payload.pop("self", None) payload.pop("self", None)
route = Route("POST", "/v2/users/{openid}/messages", openid=openid) route = Route("POST", "/v2/users/{openid}/messages", openid=openid)
return await self.bot.api._http.request(route, json=payload) result = await self.bot.api._http.request(route, json=payload)
if not isinstance(result, dict):
raise RuntimeError(
f"Failed to post c2c message, response is not dict: {result}"
)
return message.Message(**result)
@staticmethod @staticmethod
async def _parse_to_qqofficial(message: MessageChain): async def _parse_to_qqofficial(message: MessageChain):
@@ -302,8 +344,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
image_base64 = file_to_base64(image_file_path) image_base64 = file_to_base64(image_file_path)
elif i.file and i.file.startswith("base64://"): elif i.file and i.file.startswith("base64://"):
image_base64 = i.file image_base64 = i.file
else: elif i.file:
image_base64 = file_to_base64(i.file) image_base64 = file_to_base64(i.file)
else:
raise ValueError("Unsupported image file format")
image_base64 = image_base64.removeprefix("base64://") image_base64 = image_base64.removeprefix("base64://")
elif isinstance(i, Record): elif isinstance(i, Record):
if i.file: if i.file:
@@ -4,6 +4,7 @@ import asyncio
import logging import logging
import os import os
import time import time
from typing import cast
import botpy import botpy
import botpy.message import botpy.message
@@ -44,7 +45,9 @@ class botClient(Client):
MessageType.GROUP_MESSAGE, MessageType.GROUP_MESSAGE,
) )
abm.session_id = ( abm.session_id = (
abm.sender.user_id if self.platform.unique_session else message.group_openid abm.sender.user_id
if self.platform.unique_session
else cast(str, message.group_openid)
) )
self._commit(abm) self._commit(abm)
@@ -101,7 +104,7 @@ class QQOfficialPlatformAdapter(Platform):
self.appid = platform_config["appid"] self.appid = platform_config["appid"]
self.secret = platform_config["secret"] self.secret = platform_config["secret"]
self.unique_session = platform_settings["unique_session"] self.unique_session: bool = platform_settings["unique_session"]
qq_group = platform_config["enable_group_c2c"] qq_group = platform_config["enable_group_c2c"]
guild_dm = platform_config["enable_guild_direct_message"] guild_dm = platform_config["enable_guild_direct_message"]
@@ -137,12 +140,15 @@ class QQOfficialPlatformAdapter(Platform):
return PlatformMetadata( return PlatformMetadata(
name="qq_official", name="qq_official",
description="QQ 机器人官方 API 适配器", description="QQ 机器人官方 API 适配器",
id=self.config.get("id"), id=cast(str, self.config.get("id")),
) )
@staticmethod @staticmethod
def _parse_from_qqofficial( def _parse_from_qqofficial(
message: botpy.message.Message | botpy.message.GroupMessage, message: botpy.message.Message
| botpy.message.GroupMessage
| botpy.message.DirectMessage
| botpy.message.C2CMessage,
message_type: MessageType, message_type: MessageType,
): ):
abm = AstrBotMessage() abm = AstrBotMessage()
@@ -150,7 +156,7 @@ class QQOfficialPlatformAdapter(Platform):
abm.timestamp = int(time.time()) abm.timestamp = int(time.time())
abm.raw_message = message abm.raw_message = message
abm.message_id = message.id abm.message_id = message.id
abm.tag = "qq_official" # abm.tag = "qq_official"
msg: list[BaseMessageComponent] = [] msg: list[BaseMessageComponent] = []
if isinstance(message, botpy.message.GroupMessage) or isinstance( if isinstance(message, botpy.message.GroupMessage) or isinstance(
@@ -180,9 +186,9 @@ class QQOfficialPlatformAdapter(Platform):
message, message,
botpy.message.DirectMessage, botpy.message.DirectMessage,
): ):
try: if isinstance(message, botpy.message.Message):
abm.self_id = str(message.mentions[0].id) abm.self_id = str(message.mentions[0].id)
except BaseException as _: else:
abm.self_id = "" abm.self_id = ""
plain_content = message.content.replace( plain_content = message.content.replace(
@@ -1,6 +1,6 @@
import asyncio import asyncio
import logging import logging
from typing import Any from typing import Any, cast
import botpy import botpy
import botpy.message import botpy.message
@@ -36,7 +36,9 @@ class botClient(Client):
MessageType.GROUP_MESSAGE, MessageType.GROUP_MESSAGE,
) )
abm.session_id = ( abm.session_id = (
abm.sender.user_id if self.platform.unique_session else message.group_openid abm.sender.user_id
if self.platform.unique_session
else cast(str, message.group_openid)
) )
self._commit(abm) self._commit(abm)
@@ -120,7 +122,7 @@ class QQOfficialWebhookPlatformAdapter(Platform):
return PlatformMetadata( return PlatformMetadata(
name="qq_official_webhook", name="qq_official_webhook",
description="QQ 机器人官方 API 适配器", description="QQ 机器人官方 API 适配器",
id=self.config.get("id"), id=cast(str, self.config.get("id")),
) )
async def run(self): async def run(self):
@@ -1,5 +1,6 @@
import asyncio import asyncio
import logging import logging
from typing import cast
import quart import quart
from botpy import BotAPI, BotHttp, BotWebSocket, Client, ConnectionSession, Token from botpy import BotAPI, BotHttp, BotWebSocket, Client, ConnectionSession, Token
@@ -99,7 +100,7 @@ class QQOfficialWebhook:
if opcode == 13: if opcode == 13:
# validation # validation
signed = await self.webhook_validation(data) signed = await self.webhook_validation(cast(dict, data))
print(signed) print(signed)
return signed return signed
@@ -4,9 +4,11 @@ import hmac
import json import json
import logging import logging
from collections.abc import Callable from collections.abc import Callable
from typing import cast
from quart import Quart, Response, request from quart import Quart, Response, request
from slack_sdk.socket_mode.aiohttp import SocketModeClient from slack_sdk.socket_mode.aiohttp import SocketModeClient
from slack_sdk.socket_mode.async_client import AsyncBaseSocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse from slack_sdk.socket_mode.response import SocketModeResponse
from slack_sdk.web.async_client import AsyncWebClient from slack_sdk.web.async_client import AsyncWebClient
@@ -66,7 +68,7 @@ class SlackWebhookClient:
""" """
try: try:
# 获取请求体和头部 # 获取请求体和头部
body = await req.get_data() body = cast(bytes, await req.get_data())
event_data = json.loads(body.decode("utf-8")) event_data = json.loads(body.decode("utf-8"))
# Verify Slack request signature # Verify Slack request signature
@@ -139,9 +141,14 @@ class SlackSocketClient:
self.event_handler = event_handler self.event_handler = event_handler
self.socket_client = None self.socket_client = None
async def _handle_events(self, _: SocketModeClient, req: SocketModeRequest): async def _handle_events(
self, _: AsyncBaseSocketModeClient, req: SocketModeRequest
):
"""处理 Socket Mode 事件""" """处理 Socket Mode 事件"""
try: try:
if self.socket_client is None:
raise RuntimeError("Socket client is not initialized")
# 确认收到事件 # 确认收到事件
response = SocketModeResponse(envelope_id=req.envelope_id) response = SocketModeResponse(envelope_id=req.envelope_id)
await self.socket_client.send_socket_mode_response(response) await self.socket_client.send_socket_mode_response(response)
@@ -3,8 +3,7 @@ import base64
import re import re
import time import time
import uuid import uuid
from collections.abc import Awaitable from typing import Any, cast
from typing import Any
import aiohttp import aiohttp
from slack_sdk.socket_mode.request import SocketModeRequest from slack_sdk.socket_mode.request import SocketModeRequest
@@ -68,7 +67,7 @@ class SlackAdapter(Platform):
self.metadata = PlatformMetadata( self.metadata = PlatformMetadata(
name="slack", name="slack",
description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。", description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。",
id=self.config.get("id"), id=cast(str, self.config.get("id")),
support_streaming_message=False, support_streaming_message=False,
) )
@@ -118,13 +117,13 @@ class SlackAdapter(Platform):
logger.debug(f"[slack] RawMessage {event}") logger.debug(f"[slack] RawMessage {event}")
abm = AstrBotMessage() abm = AstrBotMessage()
abm.self_id = self.bot_self_id abm.self_id = cast(str, self.bot_self_id)
# 获取用户信息 # 获取用户信息
user_id = event.get("user", "") user_id = event.get("user", "")
try: try:
user_info = await self.web_client.users_info(user=user_id) user_info = await self.web_client.users_info(user=user_id)
user_data = user_info["user"] user_data = cast(dict, user_info["user"])
user_name = user_data.get("real_name") or user_data.get("name", user_id) user_name = user_data.get("real_name") or user_data.get("name", user_id)
except Exception: except Exception:
user_name = user_id user_name = user_id
@@ -135,7 +134,7 @@ class SlackAdapter(Platform):
channel_id = event.get("channel", "") channel_id = event.get("channel", "")
try: try:
channel_info = await self.web_client.conversations_info(channel=channel_id) channel_info = await self.web_client.conversations_info(channel=channel_id)
is_im = channel_info["channel"]["is_im"] is_im = cast(dict, channel_info["channel"])["is_im"]
if is_im: if is_im:
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
@@ -178,7 +177,7 @@ class SlackAdapter(Platform):
for mention in mentions: for mention in mentions:
try: try:
mentioned_user = await self.web_client.users_info(user=mention) mentioned_user = await self.web_client.users_info(user=mention)
user_data = mentioned_user["user"] user_data = cast(dict, mentioned_user["user"])
user_name = user_data.get("real_name") or user_data.get( user_name = user_data.get("real_name") or user_data.get(
"name", "name",
mention, mention,
@@ -329,7 +328,7 @@ class SlackAdapter(Platform):
) )
raise Exception(f"下载文件失败: {resp.status}") raise Exception(f"下载文件失败: {resp.status}")
async def run(self) -> Awaitable[Any]: async def run(self) -> None:
self.bot_self_id = await self.get_bot_user_id() self.bot_self_id = await self.get_bot_user_id()
logger.info(f"Slack auth test OK. Bot ID: {self.bot_self_id}") logger.info(f"Slack auth test OK. Bot ID: {self.bot_self_id}")
@@ -410,7 +409,7 @@ class SlackAdapter(Platform):
await self.socket_client.stop() await self.socket_client.stop()
if self.webhook_client: if self.webhook_client:
await self.webhook_client.stop() await self.webhook_client.stop()
logger.info("Slack 适配器已被优雅地关闭") logger.info("Slack 适配器已被关闭")
def meta(self) -> PlatformMetadata: def meta(self) -> PlatformMetadata:
return self.metadata return self.metadata
@@ -428,3 +427,10 @@ class SlackAdapter(Platform):
def get_client(self): def get_client(self):
return self.web_client return self.web_client
def unified_webhook(self) -> bool:
return bool(
self.config.get("unified_webhook_mode", False)
and self.config.get("slack_connection_mode", "") == "webhook"
and self.config.get("webhook_uuid")
)
@@ -1,6 +1,7 @@
import asyncio import asyncio
import re import re
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator, Iterable
from typing import cast
from slack_sdk.web.async_client import AsyncWebClient from slack_sdk.web.async_client import AsyncWebClient
@@ -38,7 +39,7 @@ class SlackMessageEvent(AstrMessageEvent):
if isinstance(segment, Image): if isinstance(segment, Image):
# upload file # upload file
url = segment.url or segment.file url = segment.url or segment.file
if url.startswith("http"): if url and url.startswith("http"):
return { return {
"type": "image", "type": "image",
"image_url": url, "image_url": url,
@@ -55,7 +56,7 @@ class SlackMessageEvent(AstrMessageEvent):
"type": "section", "type": "section",
"text": {"type": "mrkdwn", "text": "图片上传失败"}, "text": {"type": "mrkdwn", "text": "图片上传失败"},
} }
image_url = response["files"][0]["url_private"] image_url = cast(list, response["files"])[0]["url_private"]
logger.debug(f"Slack file upload response: {response}") logger.debug(f"Slack file upload response: {response}")
return { return {
"type": "image", "type": "image",
@@ -77,7 +78,7 @@ class SlackMessageEvent(AstrMessageEvent):
"type": "section", "type": "section",
"text": {"type": "mrkdwn", "text": "文件上传失败"}, "text": {"type": "mrkdwn", "text": "文件上传失败"},
} }
file_url = response["files"][0]["permalink"] file_url = cast(list, response["files"])[0]["permalink"]
return { return {
"type": "section", "type": "section",
"text": { "text": {
@@ -225,10 +226,10 @@ class SlackMessageEvent(AstrMessageEvent):
) )
members = [] members = []
for member_id in members_response["members"]: for member_id in cast(Iterable, members_response["members"]):
try: try:
user_info = await self.web_client.users_info(user=member_id) user_info = await self.web_client.users_info(user=member_id)
user_data = user_info["user"] user_data = cast(dict, user_info["user"])
members.append( members.append(
MessageMember( MessageMember(
user_id=member_id, user_id=member_id,
@@ -240,7 +241,7 @@ class SlackMessageEvent(AstrMessageEvent):
# 如果获取用户信息失败,使用默认信息 # 如果获取用户信息失败,使用默认信息
members.append(MessageMember(user_id=member_id, nickname=member_id)) members.append(MessageMember(user_id=member_id, nickname=member_id))
channel_data = channel_info["channel"] channel_data = cast(dict, channel_info["channel"])
return Group( return Group(
group_id=channel_id, group_id=channel_id,
group_name=channel_data.get("name", ""), group_name=channel_data.get("name", ""),
@@ -424,6 +424,6 @@ class TelegramPlatformAdapter(Platform):
if self.application.updater is not None: if self.application.updater is not None:
await self.application.updater.stop() await self.application.updater.stop()
logger.info("Telegram 适配器已被优雅地关闭") logger.info("Telegram 适配器已被关闭")
except Exception as e: except Exception as e:
logger.error(f"Telegram 适配器关闭时出错: {e}") logger.error(f"Telegram 适配器关闭时出错: {e}")
@@ -1,6 +1,7 @@
import asyncio import asyncio
import os import os
import re import re
from typing import Any, cast
import telegramify_markdown import telegramify_markdown
from telegram import ReactionTypeCustomEmoji, ReactionTypeEmoji from telegram import ReactionTypeCustomEmoji, ReactionTypeEmoji
@@ -17,8 +18,6 @@ from astrbot.api.message_components import (
Reply, Reply,
) )
from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file
class TelegramPlatformEvent(AstrMessageEvent): class TelegramPlatformEvent(AstrMessageEvent):
@@ -97,7 +96,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
"chat_id": user_name, "chat_id": user_name,
} }
if has_reply: if has_reply:
payload["reply_to_message_id"] = reply_message_id payload["reply_to_message_id"] = str(reply_message_id)
if message_thread_id: if message_thread_id:
payload["message_thread_id"] = message_thread_id payload["message_thread_id"] = message_thread_id
@@ -110,33 +109,30 @@ class TelegramPlatformEvent(AstrMessageEvent):
try: try:
md_text = telegramify_markdown.markdownify( md_text = telegramify_markdown.markdownify(
chunk, chunk,
max_line_length=None,
normalize_whitespace=False, normalize_whitespace=False,
) )
await client.send_message( await client.send_message(
text=md_text, text=md_text,
parse_mode="MarkdownV2", parse_mode="MarkdownV2",
**payload, **cast(Any, payload),
) )
except Exception as e: except Exception as e:
logger.warning( logger.warning(
f"MarkdownV2 send failed: {e}. Using plain text instead.", f"MarkdownV2 send failed: {e}. Using plain text instead.",
) )
await client.send_message(text=chunk, **payload) await client.send_message(text=chunk, **cast(Any, payload))
elif isinstance(i, Image): elif isinstance(i, Image):
image_path = await i.convert_to_file_path() image_path = await i.convert_to_file_path()
await client.send_photo(photo=image_path, **payload) await client.send_photo(photo=image_path, **cast(Any, payload))
elif isinstance(i, File): elif isinstance(i, File):
if i.file.startswith("https://"): path = await i.get_file()
temp_dir = os.path.join(get_astrbot_data_path(), "temp") name = i.name or os.path.basename(path)
path = os.path.join(temp_dir, i.name) await client.send_document(
await download_file(i.file, path) document=path, filename=name, **cast(Any, payload)
i.file = path )
await client.send_document(document=i.file, filename=i.name, **payload)
elif isinstance(i, Record): elif isinstance(i, Record):
path = await i.convert_to_file_path() path = await i.convert_to_file_path()
await client.send_voice(voice=path, **payload) await client.send_voice(voice=path, **cast(Any, payload))
async def send(self, message: MessageChain): async def send(self, message: MessageChain):
if self.get_message_type() == MessageType.GROUP_MESSAGE: if self.get_message_type() == MessageType.GROUP_MESSAGE:
@@ -204,6 +200,15 @@ class TelegramPlatformEvent(AstrMessageEvent):
if isinstance(chain, MessageChain): if isinstance(chain, MessageChain):
if chain.type == "break": if chain.type == "break":
# 分割符 # 分割符
if message_id:
try:
await self.client.edit_message_text(
text=delta,
chat_id=payload["chat_id"],
message_id=message_id,
)
except Exception as e:
logger.warning(f"编辑消息失败(streaming-break): {e!s}")
message_id = None # 重置消息 ID message_id = None # 重置消息 ID
delta = "" # 重置 delta delta = "" # 重置 delta
continue continue
@@ -214,24 +219,23 @@ class TelegramPlatformEvent(AstrMessageEvent):
delta += i.text delta += i.text
elif isinstance(i, Image): elif isinstance(i, Image):
image_path = await i.convert_to_file_path() image_path = await i.convert_to_file_path()
await self.client.send_photo(photo=image_path, **payload) await self.client.send_photo(
photo=image_path, **cast(Any, payload)
)
continue continue
elif isinstance(i, File): elif isinstance(i, File):
if i.file.startswith("https://"): path = await i.get_file()
temp_dir = os.path.join(get_astrbot_data_path(), "temp") name = i.name or os.path.basename(path)
path = os.path.join(temp_dir, i.name)
await download_file(i.file, path)
i.file = path
await self.client.send_document( await self.client.send_document(
document=i.file, document=path,
filename=i.name, filename=name,
**payload, **cast(Any, payload),
) )
continue continue
elif isinstance(i, Record): elif isinstance(i, Record):
path = await i.convert_to_file_path() path = await i.convert_to_file_path()
await self.client.send_voice(voice=path, **payload) await self.client.send_voice(voice=path, **cast(Any, payload))
continue continue
else: else:
logger.warning(f"不支持的消息类型: {type(i)}") logger.warning(f"不支持的消息类型: {type(i)}")
@@ -260,7 +264,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
else: else:
# delta 长度一般不会大于 4096,因此这里直接发送 # delta 长度一般不会大于 4096,因此这里直接发送
try: try:
msg = await self.client.send_message(text=delta, **payload) msg = await self.client.send_message(
text=delta, **cast(Any, payload)
)
current_content = delta current_content = delta
except Exception as e: except Exception as e:
logger.warning(f"发送消息失败(streaming): {e!s}") logger.warning(f"发送消息失败(streaming): {e!s}")
@@ -274,7 +280,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
try: try:
markdown_text = telegramify_markdown.markdownify( markdown_text = telegramify_markdown.markdownify(
delta, delta,
max_line_length=None,
normalize_whitespace=False, normalize_whitespace=False,
) )
await self.client.edit_message_text( await self.client.edit_message_text(
@@ -2,7 +2,7 @@ import asyncio
import os import os
import time import time
import uuid import uuid
from collections.abc import Awaitable, Callable from collections.abc import Callable, Coroutine
from typing import Any from typing import Any
from astrbot import logger from astrbot import logger
@@ -207,7 +207,7 @@ class WebChatAdapter(Platform):
abm.raw_message = data abm.raw_message = data
return abm return abm
def run(self) -> Awaitable[Any]: def run(self) -> Coroutine[Any, Any, None]:
async def callback(data: tuple): async def callback(data: tuple):
abm = await self.convert_message(data) abm = await self.convert_message(data)
await self.handle_msg(abm) await self.handle_msg(abm)
@@ -1,11 +1,12 @@
import base64 import base64
import json
import os import os
import shutil import shutil
import uuid import uuid
from astrbot.api import logger from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import File, Image, Plain, Record from astrbot.api.message_components import File, Image, Json, Plain, Record
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from .webchat_queue_mgr import webchat_queue_mgr from .webchat_queue_mgr import webchat_queue_mgr
@@ -41,12 +42,20 @@ class WebChatMessageEvent(AstrMessageEvent):
await web_chat_back_queue.put( await web_chat_back_queue.put(
{ {
"type": "plain", "type": "plain",
"cid": cid,
"data": data, "data": data,
"streaming": streaming, "streaming": streaming,
"chain_type": message.type, "chain_type": message.type,
}, },
) )
elif isinstance(comp, Json):
await web_chat_back_queue.put(
{
"type": "plain",
"data": json.dumps(comp.data, ensure_ascii=False),
"streaming": streaming,
"chain_type": message.type,
},
)
elif isinstance(comp, Image): elif isinstance(comp, Image):
# save image to local # save image to local
filename = f"{str(uuid.uuid4())}.jpg" filename = f"{str(uuid.uuid4())}.jpg"
@@ -58,7 +67,6 @@ class WebChatMessageEvent(AstrMessageEvent):
await web_chat_back_queue.put( await web_chat_back_queue.put(
{ {
"type": "image", "type": "image",
"cid": cid,
"data": data, "data": data,
"streaming": streaming, "streaming": streaming,
}, },
@@ -74,7 +82,6 @@ class WebChatMessageEvent(AstrMessageEvent):
await web_chat_back_queue.put( await web_chat_back_queue.put(
{ {
"type": "record", "type": "record",
"cid": cid,
"data": data, "data": data,
"streaming": streaming, "streaming": streaming,
}, },
@@ -91,7 +98,6 @@ class WebChatMessageEvent(AstrMessageEvent):
await web_chat_back_queue.put( await web_chat_back_queue.put(
{ {
"type": "file", "type": "file",
"cid": cid,
"data": data, "data": data,
"streaming": streaming, "streaming": streaming,
}, },
@@ -101,9 +107,9 @@ class WebChatMessageEvent(AstrMessageEvent):
return data return data
async def send(self, message: MessageChain): async def send(self, message: MessageChain | None):
await WebChatMessageEvent._send(message, session_id=self.session_id) await WebChatMessageEvent._send(message, session_id=self.session_id)
await super().send(message) await super().send(MessageChain([]))
async def send_streaming(self, generator, use_fallback: bool = False): async def send_streaming(self, generator, use_fallback: bool = False):
final_data = "" final_data = ""
@@ -111,18 +117,17 @@ class WebChatMessageEvent(AstrMessageEvent):
cid = self.session_id.split("!")[-1] cid = self.session_id.split("!")[-1]
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid) web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
async for chain in generator: async for chain in generator:
if chain.type == "break" and final_data: # if chain.type == "break" and final_data:
# 分割符 # # 分割符
await web_chat_back_queue.put( # await web_chat_back_queue.put(
{ # {
"type": "break", # break means a segment end # "type": "break", # break means a segment end
"data": final_data, # "data": final_data,
"streaming": True, # "streaming": True,
"cid": cid, # },
}, # )
) # final_data = ""
final_data = "" # continue
continue
r = await WebChatMessageEvent._send( r = await WebChatMessageEvent._send(
chain, chain,
@@ -142,7 +147,6 @@ class WebChatMessageEvent(AstrMessageEvent):
"data": final_data, "data": final_data,
"reasoning": reasoning_content, "reasoning": reasoning_content,
"streaming": True, "streaming": True,
"cid": cid,
}, },
) )
await super().send_streaming(generator, use_fallback) await super().send_streaming(generator, use_fallback)
@@ -4,6 +4,7 @@ import json
import os import os
import time import time
import traceback import traceback
from typing import cast
import aiohttp import aiohttp
import anyio import anyio
@@ -69,7 +70,7 @@ class WeChatPadProAdapter(Platform):
) )
self.base_url = f"http://{self.host}:{self.port}" self.base_url = f"http://{self.host}:{self.port}"
self.auth_key = None # 用于保存生成的授权码 self.auth_key = None # 用于保存生成的授权码
self.wxid = None # 用于保存登录成功后的 wxid self.wxid: str | None = None # 用于保存登录成功后的 wxid
self.credentials_file = os.path.join( self.credentials_file = os.path.join(
get_astrbot_data_path(), get_astrbot_data_path(),
"wechatpadpro_credentials.json", "wechatpadpro_credentials.json",
@@ -398,7 +399,7 @@ class WeChatPadProAdapter(Platform):
) )
await asyncio.sleep(5) await asyncio.sleep(5)
async def handle_websocket_message(self, message: str): async def handle_websocket_message(self, message: str | bytes):
"""处理从 WebSocket 接收到的消息。""" """处理从 WebSocket 接收到的消息。"""
logger.debug(f"收到 WebSocket 消息: {message}") logger.debug(f"收到 WebSocket 消息: {message}")
try: try:
@@ -430,10 +431,13 @@ class WeChatPadProAdapter(Platform):
async def convert_message(self, raw_message: dict) -> AstrBotMessage | None: async def convert_message(self, raw_message: dict) -> AstrBotMessage | None:
"""将 WeChatPadPro 原始消息转换为 AstrBotMessage。""" """将 WeChatPadPro 原始消息转换为 AstrBotMessage。"""
if self.wxid is None:
logger.error("WeChatPadPro 适配器未登录或未获取到 wxid,无法处理消息。")
return None
abm = AstrBotMessage() abm = AstrBotMessage()
abm.raw_message = raw_message abm.raw_message = raw_message
abm.message_id = str(raw_message.get("msg_id")) abm.message_id = str(raw_message.get("msg_id"))
abm.timestamp = raw_message.get("create_time") abm.timestamp = cast(int, raw_message.get("create_time"))
abm.self_id = self.wxid abm.self_id = self.wxid
if int(time.time()) - abm.timestamp > 180: if int(time.time()) - abm.timestamp > 180:
@@ -446,7 +450,7 @@ class WeChatPadProAdapter(Platform):
to_user_name = raw_message.get("to_user_name", {}).get("str", "") to_user_name = raw_message.get("to_user_name", {}).get("str", "")
content = raw_message.get("content", {}).get("str", "") content = raw_message.get("content", {}).get("str", "")
push_content = raw_message.get("push_content", "") push_content = raw_message.get("push_content", "")
msg_type = raw_message.get("msg_type") msg_type = cast(int, raw_message.get("msg_type"))
abm.message_str = "" abm.message_str = ""
abm.message = [] abm.message = []
@@ -574,7 +578,7 @@ class WeChatPadProAdapter(Platform):
from_user_name: str, from_user_name: str,
to_user_name: str, to_user_name: str,
msg_id: int, msg_id: int,
): ) -> dict | None:
"""下载原始图片。""" """下载原始图片。"""
url = f"{self.base_url}/message/GetMsgBigImg" url = f"{self.base_url}/message/GetMsgBigImg"
params = {"key": self.auth_key} params = {"key": self.auth_key}
@@ -725,12 +729,15 @@ class WeChatPadProAdapter(Platform):
# 图片消息 # 图片消息
from_user_name = raw_message.get("from_user_name", {}).get("str", "") from_user_name = raw_message.get("from_user_name", {}).get("str", "")
to_user_name = raw_message.get("to_user_name", {}).get("str", "") to_user_name = raw_message.get("to_user_name", {}).get("str", "")
msg_id = raw_message.get("msg_id") msg_id = cast(int, raw_message.get("msg_id"))
image_resp = await self._download_raw_image( image_resp = await self._download_raw_image(
from_user_name, from_user_name,
to_user_name, to_user_name,
msg_id, msg_id,
) )
if image_resp is None:
logger.error(f"下载图片失败: msg_id={msg_id}")
return
image_bs64_data = ( image_bs64_data = (
image_resp.get("Data", {}).get("Data", {}).get("Buffer", None) image_resp.get("Data", {}).get("Data", {}).get("Buffer", None)
) )
@@ -771,6 +778,9 @@ class WeChatPadProAdapter(Platform):
bufid = 0 bufid = 0
to_user_name = raw_message.get("to_user_name", {}).get("str", "") to_user_name = raw_message.get("to_user_name", {}).get("str", "")
new_msg_id = raw_message.get("new_msg_id") new_msg_id = raw_message.get("new_msg_id")
if new_msg_id is None:
logger.error("语音消息缺少 new_msg_id")
return
data_parser = GeweDataParser( data_parser = GeweDataParser(
content=content, content=content,
is_private_chat=(abm.type != MessageType.GROUP_MESSAGE), is_private_chat=(abm.type != MessageType.GROUP_MESSAGE),
@@ -778,6 +788,9 @@ class WeChatPadProAdapter(Platform):
) )
voicemsg = data_parser._format_to_xml().find("voicemsg") voicemsg = data_parser._format_to_xml().find("voicemsg")
if voicemsg is None:
logger.error("无法从 XML 解析 voicemsg 节点")
return
bufid = voicemsg.get("bufid") or "0" bufid = voicemsg.get("bufid") or "0"
length = int(voicemsg.get("length") or 0) length = int(voicemsg.get("length") or 0)
voice_resp = await self.download_voice( voice_resp = await self.download_voice(
@@ -786,6 +799,9 @@ class WeChatPadProAdapter(Platform):
bufid=bufid, bufid=bufid,
length=length, length=length,
) )
if voice_resp is None:
logger.error(f"下载语音失败: new_msg_id={new_msg_id}")
return
voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None) voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None)
if voice_bs64_data: if voice_bs64_data:
voice_bs64_data = base64.b64decode(voice_bs64_data) voice_bs64_data = base64.b64decode(voice_bs64_data)
@@ -827,7 +843,8 @@ class WeChatPadProAdapter(Platform):
try: try:
if self.ws_handle_task: if self.ws_handle_task:
self.ws_handle_task.cancel() self.ws_handle_task.cancel()
self._shutdown_event.set() if self._shutdown_event is not None:
self._shutdown_event.set()
except Exception: except Exception:
pass pass
@@ -894,8 +911,8 @@ class WeChatPadProAdapter(Platform):
async def get_contact_details_list( async def get_contact_details_list(
self, self,
room_wx_id_list: list[str] = None, room_wx_id_list: list[str] | None = None,
user_names: list[str] = None, user_names: list[str] | None = None,
) -> dict | None: ) -> dict | None:
"""获取联系人详情列表。""" """获取联系人详情列表。"""
if room_wx_id_list is None: if room_wx_id_list is None:
@@ -2,7 +2,8 @@ import asyncio
import os import os
import sys import sys
import uuid import uuid
from typing import Any from collections.abc import Awaitable, Callable
from typing import Any, cast
import quart import quart
from requests import Response from requests import Response
@@ -40,7 +41,7 @@ else:
class WecomServer: class WecomServer:
def __init__(self, event_queue: asyncio.Queue, config: dict): def __init__(self, event_queue: asyncio.Queue, config: dict):
self.server = quart.Quart(__name__) self.server = quart.Quart(__name__)
self.port = int(config.get("port")) self.port = int(cast(str, config.get("port")))
self.callback_server_host = config.get("callback_server_host", "0.0.0.0") self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
self.server.add_url_rule( self.server.add_url_rule(
"/callback/command", "/callback/command",
@@ -60,7 +61,7 @@ class WecomServer:
config["corpid"].strip(), config["corpid"].strip(),
) )
self.callback = None self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
self.shutdown_event = asyncio.Event() self.shutdown_event = asyncio.Event()
async def verify(self): async def verify(self):
@@ -114,7 +115,7 @@ class WecomServer:
logger.error("解密失败,签名异常,请检查配置。") logger.error("解密失败,签名异常,请检查配置。")
raise raise
else: else:
msg = parse_message(xml) msg = cast(BaseMessage, parse_message(xml))
logger.info(f"解析成功: {msg}") logger.info(f"解析成功: {msg}")
if self.callback: if self.callback:
@@ -176,10 +177,10 @@ class WecomPlatformAdapter(Platform):
# inject # inject
self.wechat_kf_api = WeChatKF(client=self.client) self.wechat_kf_api = WeChatKF(client=self.client)
self.wechat_kf_message_api = WeChatKFMessage(self.client) self.wechat_kf_message_api = WeChatKFMessage(self.client)
self.client.kf = self.wechat_kf_api self.client.__setattr__("kf", self.wechat_kf_api)
self.client.kf_message = self.wechat_kf_message_api self.client.__setattr__("kf_message", self.wechat_kf_message_api)
self.client.API_BASE_URL = self.api_base_url self.client.__setattr__("API_BASE_URL", self.api_base_url)
async def callback(msg: BaseMessage): async def callback(msg: BaseMessage):
if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event": if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event":
@@ -278,37 +279,33 @@ class WecomPlatformAdapter(Platform):
async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None: async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None:
abm = AstrBotMessage() abm = AstrBotMessage()
if msg.type == "text": if isinstance(msg, TextMessage):
assert isinstance(msg, TextMessage)
abm.message_str = msg.content abm.message_str = msg.content
abm.self_id = str(msg.agent) abm.self_id = str(msg.agent)
abm.message = [Plain(msg.content)] abm.message = [Plain(msg.content)]
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember( abm.sender = MessageMember(
msg.source, cast(str, msg.source),
msg.source, cast(str, msg.source),
) )
abm.message_id = msg.id abm.message_id = str(msg.id)
abm.timestamp = msg.time abm.timestamp = int(cast(int | str, msg.time))
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
abm.raw_message = msg abm.raw_message = msg
elif msg.type == "image": elif isinstance(msg, ImageMessage):
assert isinstance(msg, ImageMessage)
abm.message_str = "[图片]" abm.message_str = "[图片]"
abm.self_id = str(msg.agent) abm.self_id = str(msg.agent)
abm.message = [Image(file=msg.image, url=msg.image)] abm.message = [Image(file=msg.image, url=msg.image)]
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember( abm.sender = MessageMember(
msg.source, cast(str, msg.source),
msg.source, cast(str, msg.source),
) )
abm.message_id = msg.id abm.message_id = str(msg.id)
abm.timestamp = msg.time abm.timestamp = int(cast(int | str, msg.time))
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
abm.raw_message = msg abm.raw_message = msg
elif msg.type == "voice": elif isinstance(msg, VoiceMessage):
assert isinstance(msg, VoiceMessage)
resp: Response = await asyncio.get_event_loop().run_in_executor( resp: Response = await asyncio.get_event_loop().run_in_executor(
None, None,
self.client.media.download, self.client.media.download,
@@ -335,11 +332,11 @@ class WecomPlatformAdapter(Platform):
abm.message = [Record(file=path_wav, url=path_wav)] abm.message = [Record(file=path_wav, url=path_wav)]
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember( abm.sender = MessageMember(
msg.source, cast(str, msg.source),
msg.source, cast(str, msg.source),
) )
abm.message_id = msg.id abm.message_id = str(msg.id)
abm.timestamp = msg.time abm.timestamp = int(cast(int | str, msg.time))
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
abm.raw_message = msg abm.raw_message = msg
else: else:
@@ -351,7 +348,7 @@ class WecomPlatformAdapter(Platform):
async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None: async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None:
msgtype = msg.get("msgtype") msgtype = msg.get("msgtype")
external_userid = msg.get("external_userid") external_userid = cast(str, msg.get("external_userid"))
abm = AstrBotMessage() abm = AstrBotMessage()
abm.raw_message = msg abm.raw_message = msg
abm.raw_message["_wechat_kf_flag"] = None # 方便处理 abm.raw_message["_wechat_kf_flag"] = None # 方便处理
@@ -425,4 +422,4 @@ class WecomPlatformAdapter(Platform):
await self.server.server.shutdown() await self.server.server.shutdown()
except Exception as _: except Exception as _:
pass pass
logger.info("企业微信 适配器已被优雅地关闭") logger.info("企业微信 适配器已被关闭")
@@ -93,10 +93,10 @@ class WecomPlatformEvent(AstrMessageEvent):
if is_wechat_kf: if is_wechat_kf:
# 微信客服 # 微信客服
kf_message_api = getattr(self.client, "kf_message", None) kf_message_api = getattr(self.client, "kf_message", None)
if not kf_message_api: if not isinstance(kf_message_api, WeChatKFMessage):
logger.warning("未找到微信客服发送消息方法。") logger.warning("未找到微信客服发送消息方法。")
return return
assert isinstance(kf_message_api, WeChatKFMessage)
user_id = self.get_sender_id() user_id = self.get_sender_id()
for comp in message.chain: for comp in message.chain:
if isinstance(comp, Plain): if isinstance(comp, Plain):
@@ -39,7 +39,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
@staticmethod @staticmethod
async def _send( async def _send(
message_chain: MessageChain, message_chain: MessageChain | None,
stream_id: str, stream_id: str,
queue_mgr: WecomAIQueueMgr, queue_mgr: WecomAIQueueMgr,
streaming: bool = False, streaming: bool = False,
@@ -90,7 +90,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
return data return data
async def send(self, message: MessageChain): async def send(self, message: MessageChain | None):
"""发送消息""" """发送消息"""
raw = self.message_obj.raw_message raw = self.message_obj.raw_message
assert isinstance(raw, dict), ( assert isinstance(raw, dict), (
@@ -98,7 +98,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
) )
stream_id = raw.get("stream_id", self.session_id) stream_id = raw.get("stream_id", self.session_id)
await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr) await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr)
await super().send(message) await super().send(MessageChain([]))
async def send_streaming(self, generator, use_fallback=False): async def send_streaming(self, generator, use_fallback=False):
"""流式发送消息,参考webchat的send_streaming设计""" """流式发送消息,参考webchat的send_streaming设计"""
@@ -1,7 +1,8 @@
import asyncio import asyncio
import sys import sys
import uuid import uuid
from typing import Any from collections.abc import Awaitable, Callable
from typing import Any, cast
import quart import quart
from requests import Response from requests import Response
@@ -36,7 +37,7 @@ else:
class WeixinOfficialAccountServer: class WeixinOfficialAccountServer:
def __init__(self, event_queue: asyncio.Queue, config: dict): def __init__(self, event_queue: asyncio.Queue, config: dict):
self.server = quart.Quart(__name__) self.server = quart.Quart(__name__)
self.port = int(config.get("port")) self.port = int(cast(int | str, config.get("port")))
self.callback_server_host = config.get("callback_server_host", "0.0.0.0") self.callback_server_host = config.get("callback_server_host", "0.0.0.0")
self.token = config.get("token") self.token = config.get("token")
self.encoding_aes_key = config.get("encoding_aes_key") self.encoding_aes_key = config.get("encoding_aes_key")
@@ -55,7 +56,7 @@ class WeixinOfficialAccountServer:
self.event_queue = event_queue self.event_queue = event_queue
self.callback = None self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None
self.shutdown_event = asyncio.Event() self.shutdown_event = asyncio.Event()
async def verify(self): async def verify(self):
@@ -114,6 +115,9 @@ class WeixinOfficialAccountServer:
raise raise
else: else:
msg = parse_message(xml) msg = parse_message(xml)
if not msg:
logger.error("解析失败。msg为None。")
raise
logger.info(f"解析成功: {msg}") logger.info(f"解析成功: {msg}")
if self.callback: if self.callback:
@@ -176,7 +180,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
self.config["secret"].strip(), self.config["secret"].strip(),
) )
self.client.API_BASE_URL = self.api_base_url self.client.__setattr__("API_BASE_URL", self.api_base_url)
# 微信公众号必须 5 秒内进行回复,否则会重试 3 次,我们需要对其进行消息排重 # 微信公众号必须 5 秒内进行回复,否则会重试 3 次,我们需要对其进行消息排重
# msgid -> Future # msgid -> Future
@@ -188,11 +192,11 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
await self.convert_message(msg, None) await self.convert_message(msg, None)
else: else:
if msg.id in self.wexin_event_workers: if msg.id in self.wexin_event_workers:
future = self.wexin_event_workers[msg.id] future = self.wexin_event_workers[str(cast(str | int, msg.id))]
logger.debug(f"duplicate message id checked: {msg.id}") logger.debug(f"duplicate message id checked: {msg.id}")
else: else:
future = asyncio.get_event_loop().create_future() future = asyncio.get_event_loop().create_future()
self.wexin_event_workers[msg.id] = future self.wexin_event_workers[str(cast(str | int, msg.id))] = future
await self.convert_message(msg, future) await self.convert_message(msg, future)
# I love shield so much! # I love shield so much!
result = await asyncio.wait_for( result = await asyncio.wait_for(
@@ -200,7 +204,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
60, 60,
) # wait for 60s ) # wait for 60s
logger.debug(f"Got future result: {result}") logger.debug(f"Got future result: {result}")
self.wexin_event_workers.pop(msg.id, None) self.wexin_event_workers.pop(str(cast(str | int, msg.id)), None)
return result # xml. see weixin_offacc_event.py return result # xml. see weixin_offacc_event.py
except asyncio.TimeoutError: except asyncio.TimeoutError:
pass pass
@@ -248,33 +252,33 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
async def convert_message( async def convert_message(
self, self,
msg, msg,
future: asyncio.Future = None, future: asyncio.Future | None = None,
) -> AstrBotMessage | None: ) -> AstrBotMessage | None:
abm = AstrBotMessage() abm = AstrBotMessage()
if isinstance(msg, TextMessage): if isinstance(msg, TextMessage):
abm.message_str = msg.content abm.message_str = cast(str, msg.content)
abm.self_id = str(msg.target) abm.self_id = str(msg.target)
abm.message = [Plain(msg.content)] abm.message = [Plain(cast(str, msg.content))]
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember( abm.sender = MessageMember(
msg.source, cast(str, msg.source),
msg.source, cast(str, msg.source),
) )
abm.message_id = msg.id abm.message_id = str(cast(str | int, msg.id))
abm.timestamp = msg.time abm.timestamp = cast(int, msg.time)
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
elif msg.type == "image": elif msg.type == "image":
assert isinstance(msg, ImageMessage) assert isinstance(msg, ImageMessage)
abm.message_str = "[图片]" abm.message_str = "[图片]"
abm.self_id = str(msg.target) abm.self_id = str(msg.target)
abm.message = [Image(file=msg.image, url=msg.image)] abm.message = [Image(file=cast(str, msg.image), url=cast(str, msg.image))]
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember( abm.sender = MessageMember(
msg.source, cast(str, msg.source),
msg.source, cast(str, msg.source),
) )
abm.message_id = msg.id abm.message_id = str(cast(str | int, msg.id))
abm.timestamp = msg.time abm.timestamp = cast(int, msg.time)
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
elif msg.type == "voice": elif msg.type == "voice":
assert isinstance(msg, VoiceMessage) assert isinstance(msg, VoiceMessage)
@@ -306,15 +310,16 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
abm.message = [Record(file=path_wav, url=path_wav)] abm.message = [Record(file=path_wav, url=path_wav)]
abm.type = MessageType.FRIEND_MESSAGE abm.type = MessageType.FRIEND_MESSAGE
abm.sender = MessageMember( abm.sender = MessageMember(
msg.source, cast(str, msg.source),
msg.source, cast(str, msg.source),
) )
abm.message_id = msg.id abm.message_id = str(cast(str | int, msg.id))
abm.timestamp = msg.time abm.timestamp = cast(int, msg.time)
abm.session_id = abm.sender.user_id abm.session_id = abm.sender.user_id
else: else:
logger.warning(f"暂未实现的事件: {msg.type}") logger.warning(f"暂未实现的事件: {msg.type}")
future.set_result(None) if future:
future.set_result(None)
return return
# 很不优雅 :( # 很不优雅 :(
abm.raw_message = { abm.raw_message = {
@@ -344,4 +349,4 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
await self.server.server.shutdown() await self.server.server.shutdown()
except Exception as _: except Exception as _:
pass pass
logger.info("微信公众平台 适配器已被优雅地关闭") logger.info("微信公众平台 适配器已被关闭")
@@ -1,5 +1,6 @@
import asyncio import asyncio
import uuid import uuid
from typing import cast
from wechatpy import WeChatClient from wechatpy import WeChatClient
from wechatpy.replies import ImageReply, TextReply, VoiceReply from wechatpy.replies import ImageReply, TextReply, VoiceReply
@@ -85,7 +86,9 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
async def send(self, message: MessageChain): async def send(self, message: MessageChain):
message_obj = self.message_obj message_obj = self.message_obj
active_send_mode = message_obj.raw_message.get("active_send_mode", False) active_send_mode = cast(dict, message_obj.raw_message).get(
"active_send_mode", False
)
for comp in message.chain: for comp in message.chain:
if isinstance(comp, Plain): if isinstance(comp, Plain):
# Split long text messages if needed # Split long text messages if needed
@@ -96,10 +99,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
else: else:
reply = TextReply( reply = TextReply(
content=chunk, content=chunk,
message=self.message_obj.raw_message["message"], message=cast(dict, self.message_obj.raw_message)["message"],
) )
xml = reply.render() xml = reply.render()
future = self.message_obj.raw_message["future"] future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future) assert isinstance(future, asyncio.Future)
future.set_result(xml) future.set_result(xml)
await asyncio.sleep(0.5) # Avoid sending too fast await asyncio.sleep(0.5) # Avoid sending too fast
@@ -125,10 +128,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
else: else:
reply = ImageReply( reply = ImageReply(
media_id=response["media_id"], media_id=response["media_id"],
message=self.message_obj.raw_message["message"], message=cast(dict, self.message_obj.raw_message)["message"],
) )
xml = reply.render() xml = reply.render()
future = self.message_obj.raw_message["future"] future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future) assert isinstance(future, asyncio.Future)
future.set_result(xml) future.set_result(xml)
@@ -160,10 +163,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
else: else:
reply = VoiceReply( reply = VoiceReply(
media_id=response["media_id"], media_id=response["media_id"],
message=self.message_obj.raw_message["message"], message=cast(dict, self.message_obj.raw_message)["message"],
) )
xml = reply.render() xml = reply.render()
future = self.message_obj.raw_message["future"] future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future) assert isinstance(future, asyncio.Future)
future.set_result(xml) future.set_result(xml)
+73 -9
View File
@@ -1,3 +1,5 @@
from __future__ import annotations
import base64 import base64
import enum import enum
import json import json
@@ -12,6 +14,7 @@ import astrbot.core.message.components as Comp
from astrbot import logger from astrbot import logger
from astrbot.core.agent.message import ( from astrbot.core.agent.message import (
AssistantMessageSegment, AssistantMessageSegment,
ContentPart,
ToolCall, ToolCall,
ToolCallMessageSegment, ToolCallMessageSegment,
) )
@@ -90,6 +93,8 @@ class ProviderRequest:
"""会话 ID""" """会话 ID"""
image_urls: list[str] = field(default_factory=list) image_urls: list[str] = field(default_factory=list)
"""图片 URL 列表""" """图片 URL 列表"""
extra_user_content_parts: list[ContentPart] = field(default_factory=list)
"""额外的用户消息内容部分列表,用于在用户消息后添加额外的内容块(如系统提醒、指令等)。"""
func_tool: ToolSet | None = None func_tool: ToolSet | None = None
"""可用的函数工具""" """可用的函数工具"""
contexts: list[dict] = field(default_factory=list) contexts: list[dict] = field(default_factory=list)
@@ -164,13 +169,23 @@ class ProviderRequest:
async def assemble_context(self) -> dict: async def assemble_context(self) -> dict:
"""将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。""" """将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。"""
# 构建内容块列表
content_blocks = []
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
if self.prompt and self.prompt.strip():
content_blocks.append({"type": "text", "text": self.prompt})
elif self.image_urls:
# 如果没有文本但有图片,添加占位文本
content_blocks.append({"type": "text", "text": "[图片]"})
# 2. 额外的内容块(系统提醒、指令等)
if self.extra_user_content_parts:
for part in self.extra_user_content_parts:
content_blocks.append(part.model_dump())
# 3. 图片内容
if self.image_urls: if self.image_urls:
user_content = {
"role": "user",
"content": [
{"type": "text", "text": self.prompt if self.prompt else "[图片]"},
],
}
for image_url in self.image_urls: for image_url in self.image_urls:
if image_url.startswith("http"): if image_url.startswith("http"):
image_path = await download_image_by_url(image_url) image_path = await download_image_by_url(image_url)
@@ -183,11 +198,21 @@ class ProviderRequest:
if not image_data: if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue continue
user_content["content"].append( content_blocks.append(
{"type": "image_url", "image_url": {"url": image_data}}, {"type": "image_url", "image_url": {"url": image_data}},
) )
return user_content
return {"role": "user", "content": self.prompt} # 只有当只有一个来自 prompt 的文本块且没有额外内容块时,才降级为简单格式以保持向后兼容
if (
len(content_blocks) == 1
and content_blocks[0]["type"] == "text"
and not self.extra_user_content_parts
and not self.image_urls
):
return {"role": "user", "content": content_blocks[0]["text"]}
# 否则返回多模态格式
return {"role": "user", "content": content_blocks}
async def _encode_image_bs64(self, image_url: str) -> str: async def _encode_image_bs64(self, image_url: str) -> str:
"""将图片转换为 base64""" """将图片转换为 base64"""
@@ -199,6 +224,38 @@ class ProviderRequest:
return "" return ""
@dataclass
class TokenUsage:
input_other: int = 0
"""The number of input tokens, excluding cached tokens."""
input_cached: int = 0
"""The number of input cached tokens."""
output: int = 0
"""The number of output tokens."""
@property
def total(self) -> int:
return self.input_other + self.input_cached + self.output
@property
def input(self) -> int:
return self.input_other + self.input_cached
def __add__(self, other: TokenUsage) -> TokenUsage:
return TokenUsage(
input_other=self.input_other + other.input_other,
input_cached=self.input_cached + other.input_cached,
output=self.output + other.output,
)
def __sub__(self, other: TokenUsage) -> TokenUsage:
return TokenUsage(
input_other=self.input_other - other.input_other,
input_cached=self.input_cached - other.input_cached,
output=self.output - other.output,
)
@dataclass @dataclass
class LLMResponse: class LLMResponse:
role: str role: str
@@ -227,6 +284,11 @@ class LLMResponse:
is_chunk: bool = False is_chunk: bool = False
"""Indicates if the response is a chunked response.""" """Indicates if the response is a chunked response."""
id: str | None = None
"""The ID of the response. For chunked responses, it's the ID of the chunk; for non-chunked responses, it's the ID of the response."""
usage: TokenUsage | None = None
"""The usage of the response. For chunked responses, it's the usage of the chunk; for non-chunked responses, it's the usage of the response."""
def __init__( def __init__(
self, self,
role: str, role: str,
@@ -241,6 +303,8 @@ class LLMResponse:
| AnthropicMessage | AnthropicMessage
| None = None, | None = None,
is_chunk: bool = False, is_chunk: bool = False,
id: str | None = None,
usage: TokenUsage | None = None,
): ):
"""初始化 LLMResponse """初始化 LLMResponse
+3 -3
View File
@@ -4,7 +4,7 @@ import asyncio
import copy import copy
import json import json
import os import os
from collections.abc import Awaitable, Callable from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any from typing import Any
import aiohttp import aiohttp
@@ -118,7 +118,7 @@ class FunctionToolManager:
name: str, name: str,
func_args: list[dict], func_args: list[dict],
desc: str, desc: str,
handler: Callable[..., Awaitable[Any]], handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
) -> FuncTool: ) -> FuncTool:
params = { params = {
"type": "object", # hard-coded here "type": "object", # hard-coded here
@@ -140,7 +140,7 @@ class FunctionToolManager:
name: str, name: str,
func_args: list, func_args: list,
desc: str, desc: str,
handler: Callable[..., Awaitable[Any]], handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
) -> None: ) -> None:
"""添加函数调用工具 """添加函数调用工具
+329 -163
View File
@@ -1,5 +1,7 @@
import asyncio import asyncio
import copy
import traceback import traceback
from typing import Protocol, runtime_checkable
from astrbot.core import astrbot_config, logger, sp from astrbot.core import astrbot_config, logger, sp
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
@@ -10,6 +12,7 @@ from .entities import ProviderType
from .provider import ( from .provider import (
EmbeddingProvider, EmbeddingProvider,
Provider, Provider,
Providers,
RerankProvider, RerankProvider,
STTProvider, STTProvider,
TTSProvider, TTSProvider,
@@ -17,6 +20,11 @@ from .provider import (
from .register import llm_tools, provider_cls_map from .register import llm_tools, provider_cls_map
@runtime_checkable
class HasInitialize(Protocol):
async def initialize(self) -> None: ...
class ProviderManager: class ProviderManager:
def __init__( def __init__(
self, self,
@@ -25,10 +33,12 @@ class ProviderManager:
persona_mgr: PersonaManager, persona_mgr: PersonaManager,
): ):
self.reload_lock = asyncio.Lock() self.reload_lock = asyncio.Lock()
self.resource_lock = asyncio.Lock()
self.persona_mgr = persona_mgr self.persona_mgr = persona_mgr
self.acm = acm self.acm = acm
config = acm.confs["default"] config = acm.confs["default"]
self.providers_config: list = config["provider"] self.providers_config: list = config["provider"]
self.provider_sources_config: list = config.get("provider_sources", [])
self.provider_settings: dict = config["provider_settings"] self.provider_settings: dict = config["provider_settings"]
self.provider_stt_settings: dict = config.get("provider_stt_settings", {}) self.provider_stt_settings: dict = config.get("provider_stt_settings", {})
self.provider_tts_settings: dict = config.get("provider_tts_settings", {}) self.provider_tts_settings: dict = config.get("provider_tts_settings", {})
@@ -48,7 +58,7 @@ class ProviderManager:
"""加载的 Rerank Provider 的实例""" """加载的 Rerank Provider 的实例"""
self.inst_map: dict[ self.inst_map: dict[
str, str,
Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider, Providers,
] = {} ] = {}
"""Provider 实例映射. key: provider_id, value: Provider 实例""" """Provider 实例映射. key: provider_id, value: Provider 实例"""
self.llm_tools = llm_tools self.llm_tools = llm_tools
@@ -123,15 +133,13 @@ class ProviderManager:
self.curr_provider_inst = prov self.curr_provider_inst = prov
sp.put("curr_provider", provider_id, scope="global", scope_id="global") sp.put("curr_provider", provider_id, scope="global", scope_id="global")
async def get_provider_by_id(self, provider_id: str) -> Provider | None: async def get_provider_by_id(self, provider_id: str) -> Providers | None:
"""根据提供商 ID 获取提供商实例""" """根据提供商 ID 获取提供商实例"""
return self.inst_map.get(provider_id) return self.inst_map.get(provider_id)
def get_using_provider( def get_using_provider(
self, self, provider_type: ProviderType, umo=None
provider_type: ProviderType, ) -> Providers | None:
umo=None,
) -> Provider | STTProvider | TTSProvider | None:
"""获取正在使用的提供商实例。 """获取正在使用的提供商实例。
Args: Args:
@@ -143,6 +151,7 @@ class ProviderManager:
""" """
provider = None provider = None
provider_id = None
if umo: if umo:
provider_id = sp.get( provider_id = sp.get(
f"provider_perf_{provider_type.value}", f"provider_perf_{provider_type.value}",
@@ -180,6 +189,12 @@ class ProviderManager:
) )
else: else:
raise ValueError(f"Unknown provider type: {provider_type}") raise ValueError(f"Unknown provider type: {provider_type}")
if not provider and provider_id:
logger.warning(
f"没有找到 ID 为 {provider_id} 的提供商,这可能是由于您修改了提供商(模型)ID 导致的。"
)
return provider return provider
async def initialize(self): async def initialize(self):
@@ -191,7 +206,6 @@ class ProviderManager:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
logger.error(e) logger.error(e)
# 设置默认提供商
selected_provider_id = sp.get( selected_provider_id = sp.get(
"curr_provider", "curr_provider",
self.provider_settings.get("default_provider_id"), self.provider_settings.get("default_provider_id"),
@@ -210,22 +224,173 @@ class ProviderManager:
scope="global", scope="global",
scope_id="global", scope_id="global",
) )
self.curr_provider_inst = self.inst_map.get(selected_provider_id)
temp_provider = (
self.inst_map.get(selected_provider_id)
if isinstance(selected_provider_id, str)
else None
)
self.curr_provider_inst = (
temp_provider if isinstance(temp_provider, Provider) else None
)
if not self.curr_provider_inst and self.provider_insts: if not self.curr_provider_inst and self.provider_insts:
self.curr_provider_inst = self.provider_insts[0] self.curr_provider_inst = self.provider_insts[0]
self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id) temp_stt = (
self.inst_map.get(selected_stt_provider_id)
if isinstance(selected_stt_provider_id, str)
else None
)
self.curr_stt_provider_inst = (
temp_stt if isinstance(temp_stt, STTProvider) else None
)
if not self.curr_stt_provider_inst and self.stt_provider_insts: if not self.curr_stt_provider_inst and self.stt_provider_insts:
self.curr_stt_provider_inst = self.stt_provider_insts[0] self.curr_stt_provider_inst = self.stt_provider_insts[0]
self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id) temp_tts = (
self.inst_map.get(selected_tts_provider_id)
if isinstance(selected_tts_provider_id, str)
else None
)
self.curr_tts_provider_inst = (
temp_tts if isinstance(temp_tts, TTSProvider) else None
)
if not self.curr_tts_provider_inst and self.tts_provider_insts: if not self.curr_tts_provider_inst and self.tts_provider_insts:
self.curr_tts_provider_inst = self.tts_provider_insts[0] self.curr_tts_provider_inst = self.tts_provider_insts[0]
# 初始化 MCP Client 连接 # 初始化 MCP Client 连接
asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients") asyncio.create_task(self.llm_tools.init_mcp_clients(), name="init_mcp_clients")
def dynamic_import_provider(self, type: str):
"""动态导入提供商适配器模块
Args:
type (str): 提供商请求类型
Raises:
ImportError: 如果提供商类型未知或无法导入对应模块则抛出异常
"""
match type:
case "openai_chat_completion":
from .sources.openai_source import (
ProviderOpenAIOfficial as ProviderOpenAIOfficial,
)
case "zhipu_chat_completion":
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
case "groq_chat_completion":
from .sources.groq_source import ProviderGroq as ProviderGroq
case "anthropic_chat_completion":
from .sources.anthropic_source import (
ProviderAnthropic as ProviderAnthropic,
)
case "googlegenai_chat_completion":
from .sources.gemini_source import (
ProviderGoogleGenAI as ProviderGoogleGenAI,
)
case "sensevoice_stt_selfhost":
from .sources.sensevoice_selfhosted_source import (
ProviderSenseVoiceSTTSelfHost as ProviderSenseVoiceSTTSelfHost,
)
case "openai_whisper_api":
from .sources.whisper_api_source import (
ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI,
)
case "openai_whisper_selfhost":
from .sources.whisper_selfhosted_source import (
ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost,
)
case "xinference_stt":
from .sources.xinference_stt_provider import (
ProviderXinferenceSTT as ProviderXinferenceSTT,
)
case "openai_tts_api":
from .sources.openai_tts_api_source import (
ProviderOpenAITTSAPI as ProviderOpenAITTSAPI,
)
case "edge_tts":
from .sources.edge_tts_source import (
ProviderEdgeTTS as ProviderEdgeTTS,
)
case "gsv_tts_selfhost":
from .sources.gsv_selfhosted_source import (
ProviderGSVTTS as ProviderGSVTTS,
)
case "gsvi_tts_api":
from .sources.gsvi_tts_source import (
ProviderGSVITTS as ProviderGSVITTS,
)
case "fishaudio_tts_api":
from .sources.fishaudio_tts_api_source import (
ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI,
)
case "dashscope_tts":
from .sources.dashscope_tts import (
ProviderDashscopeTTSAPI as ProviderDashscopeTTSAPI,
)
case "azure_tts":
from .sources.azure_tts_source import (
AzureTTSProvider as AzureTTSProvider,
)
case "minimax_tts_api":
from .sources.minimax_tts_api_source import (
ProviderMiniMaxTTSAPI as ProviderMiniMaxTTSAPI,
)
case "volcengine_tts":
from .sources.volcengine_tts import (
ProviderVolcengineTTS as ProviderVolcengineTTS,
)
case "gemini_tts":
from .sources.gemini_tts_source import (
ProviderGeminiTTSAPI as ProviderGeminiTTSAPI,
)
case "openai_embedding":
from .sources.openai_embedding_source import (
OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
)
case "gemini_embedding":
from .sources.gemini_embedding_source import (
GeminiEmbeddingProvider as GeminiEmbeddingProvider,
)
case "vllm_rerank":
from .sources.vllm_rerank_source import (
VLLMRerankProvider as VLLMRerankProvider,
)
case "xinference_rerank":
from .sources.xinference_rerank_source import (
XinferenceRerankProvider as XinferenceRerankProvider,
)
case "bailian_rerank":
from .sources.bailian_rerank_source import (
BailianRerankProvider as BailianRerankProvider,
)
def get_merged_provider_config(self, provider_config: dict) -> dict:
"""获取 provider 配置和 provider_source 配置合并后的结果
Returns:
dict: 合并后的 provider 配置key provider idvalue 为合并后的配置字典
"""
pc = copy.deepcopy(provider_config)
provider_source_id = pc.get("provider_source_id", "")
if provider_source_id:
provider_source = None
for ps in self.provider_sources_config:
if ps.get("id") == provider_source_id:
provider_source = ps
break
if provider_source:
# 合并配置,provider 的配置优先级更高
merged_config = {**provider_source, **pc}
# 保持 id 为 provider 的 id,而不是 source 的 id
merged_config["id"] = pc["id"]
pc = merged_config
return pc
async def load_provider(self, provider_config: dict): async def load_provider(self, provider_config: dict):
# 如果 provider_source_id 存在且不为空,则从 provider_sources 中找到对应的配置并合并
provider_config = self.get_merged_provider_config(provider_config)
if not provider_config["enable"]: if not provider_config["enable"]:
logger.info(f"Provider {provider_config['id']} is disabled, skipping") logger.info(f"Provider {provider_config['id']} is disabled, skipping")
return return
@@ -238,99 +403,7 @@ class ProviderManager:
# 动态导入 # 动态导入
try: try:
match provider_config["type"]: self.dynamic_import_provider(provider_config["type"])
case "openai_chat_completion":
from .sources.openai_source import (
ProviderOpenAIOfficial as ProviderOpenAIOfficial,
)
case "zhipu_chat_completion":
from .sources.zhipu_source import ProviderZhipu as ProviderZhipu
case "groq_chat_completion":
from .sources.groq_source import ProviderGroq as ProviderGroq
case "anthropic_chat_completion":
from .sources.anthropic_source import (
ProviderAnthropic as ProviderAnthropic,
)
case "googlegenai_chat_completion":
from .sources.gemini_source import (
ProviderGoogleGenAI as ProviderGoogleGenAI,
)
case "sensevoice_stt_selfhost":
from .sources.sensevoice_selfhosted_source import (
ProviderSenseVoiceSTTSelfHost as ProviderSenseVoiceSTTSelfHost,
)
case "openai_whisper_api":
from .sources.whisper_api_source import (
ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI,
)
case "openai_whisper_selfhost":
from .sources.whisper_selfhosted_source import (
ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost,
)
case "xinference_stt":
from .sources.xinference_stt_provider import (
ProviderXinferenceSTT as ProviderXinferenceSTT,
)
case "openai_tts_api":
from .sources.openai_tts_api_source import (
ProviderOpenAITTSAPI as ProviderOpenAITTSAPI,
)
case "edge_tts":
from .sources.edge_tts_source import (
ProviderEdgeTTS as ProviderEdgeTTS,
)
case "gsv_tts_selfhost":
from .sources.gsv_selfhosted_source import (
ProviderGSVTTS as ProviderGSVTTS,
)
case "gsvi_tts_api":
from .sources.gsvi_tts_source import (
ProviderGSVITTS as ProviderGSVITTS,
)
case "fishaudio_tts_api":
from .sources.fishaudio_tts_api_source import (
ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI,
)
case "dashscope_tts":
from .sources.dashscope_tts import (
ProviderDashscopeTTSAPI as ProviderDashscopeTTSAPI,
)
case "azure_tts":
from .sources.azure_tts_source import (
AzureTTSProvider as AzureTTSProvider,
)
case "minimax_tts_api":
from .sources.minimax_tts_api_source import (
ProviderMiniMaxTTSAPI as ProviderMiniMaxTTSAPI,
)
case "volcengine_tts":
from .sources.volcengine_tts import (
ProviderVolcengineTTS as ProviderVolcengineTTS,
)
case "gemini_tts":
from .sources.gemini_tts_source import (
ProviderGeminiTTSAPI as ProviderGeminiTTSAPI,
)
case "openai_embedding":
from .sources.openai_embedding_source import (
OpenAIEmbeddingProvider as OpenAIEmbeddingProvider,
)
case "gemini_embedding":
from .sources.gemini_embedding_source import (
GeminiEmbeddingProvider as GeminiEmbeddingProvider,
)
case "vllm_rerank":
from .sources.vllm_rerank_source import (
VLLMRerankProvider as VLLMRerankProvider,
)
case "xinference_rerank":
from .sources.xinference_rerank_source import (
XinferenceRerankProvider as XinferenceRerankProvider,
)
case "bailian_rerank":
from .sources.bailian_rerank_source import (
BailianRerankProvider as BailianRerankProvider,
)
except (ImportError, ModuleNotFoundError) as e: except (ImportError, ModuleNotFoundError) as e:
logger.critical( logger.critical(
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。", f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。",
@@ -358,73 +431,103 @@ class ProviderManager:
provider_metadata.id = provider_config["id"] provider_metadata.id = provider_config["id"]
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT: match provider_metadata.provider_type:
# STT 任务 case ProviderType.SPEECH_TO_TEXT:
inst = cls_type(provider_config, self.provider_settings) # STT 任务
if not issubclass(cls_type, STTProvider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of STTProvider"
)
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None): if isinstance(inst, HasInitialize):
await inst.initialize() await inst.initialize()
self.stt_provider_insts.append(inst) self.stt_provider_insts.append(inst)
if ( if (
self.provider_stt_settings.get("provider_id") self.provider_stt_settings.get("provider_id")
== provider_config["id"] == provider_config["id"]
): ):
self.curr_stt_provider_inst = inst self.curr_stt_provider_inst = inst
logger.info( logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。", f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。",
)
if not self.curr_stt_provider_inst:
self.curr_stt_provider_inst = inst
case ProviderType.TEXT_TO_SPEECH:
# TTS 任务
if not issubclass(cls_type, TTSProvider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of TTSProvider"
)
inst = cls_type(provider_config, self.provider_settings)
if isinstance(inst, HasInitialize):
await inst.initialize()
self.tts_provider_insts.append(inst)
if (
self.provider_settings.get("provider_id")
== provider_config["id"]
):
self.curr_tts_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。",
)
if not self.curr_tts_provider_inst:
self.curr_tts_provider_inst = inst
case ProviderType.CHAT_COMPLETION:
# 文本生成任务
if not issubclass(cls_type, Provider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of Provider"
)
inst = cls_type(
provider_config,
self.provider_settings,
) )
if not self.curr_stt_provider_inst:
self.curr_stt_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH: if isinstance(inst, HasInitialize):
# TTS 任务 await inst.initialize()
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None): self.provider_insts.append(inst)
await inst.initialize() if (
self.provider_settings.get("default_provider_id")
== provider_config["id"]
):
self.curr_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
)
if not self.curr_provider_inst:
self.curr_provider_inst = inst
self.tts_provider_insts.append(inst) case ProviderType.EMBEDDING:
if self.provider_settings.get("provider_id") == provider_config["id"]: if not issubclass(cls_type, EmbeddingProvider):
self.curr_tts_provider_inst = inst raise TypeError(
logger.info( f"Provider class {cls_type} is not a subclass of EmbeddingProvider"
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。", )
inst = cls_type(provider_config, self.provider_settings)
if isinstance(inst, HasInitialize):
await inst.initialize()
self.embedding_provider_insts.append(inst)
case ProviderType.RERANK:
if not issubclass(cls_type, RerankProvider):
raise TypeError(
f"Provider class {cls_type} is not a subclass of RerankProvider"
)
inst = cls_type(provider_config, self.provider_settings)
if isinstance(inst, HasInitialize):
await inst.initialize()
self.rerank_provider_insts.append(inst)
case _:
# 未知供应商抛出异常,确保inst初始化
# Should be unreachable
raise Exception(
f"未知的提供商类型:{provider_metadata.provider_type}"
) )
if not self.curr_tts_provider_inst:
self.curr_tts_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
# 文本生成任务
inst = cls_type(
provider_config,
self.provider_settings,
)
if getattr(inst, "initialize", None):
await inst.initialize()
self.provider_insts.append(inst)
if (
self.provider_settings.get("default_provider_id")
== provider_config["id"]
):
self.curr_provider_inst = inst
logger.info(
f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。",
)
if not self.curr_provider_inst:
self.curr_provider_inst = inst
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.embedding_provider_insts.append(inst)
elif provider_metadata.provider_type == ProviderType.RERANK:
inst = cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.rerank_provider_insts.append(inst)
self.inst_map[provider_config["id"]] = inst self.inst_map[provider_config["id"]] = inst
except Exception as e: except Exception as e:
@@ -443,6 +546,7 @@ class ProviderManager:
# 和配置文件保持同步 # 和配置文件保持同步
self.providers_config = astrbot_config["provider"] self.providers_config = astrbot_config["provider"]
self.provider_sources_config = astrbot_config.get("provider_sources", [])
config_ids = [provider["id"] for provider in self.providers_config] config_ids = [provider["id"] for provider in self.providers_config]
logger.info(f"providers in user's config: {config_ids}") logger.info(f"providers in user's config: {config_ids}")
for key in list(self.inst_map.keys()): for key in list(self.inst_map.keys()):
@@ -514,6 +618,68 @@ class ProviderManager:
) )
del self.inst_map[provider_id] del self.inst_map[provider_id]
async def delete_provider(
self, provider_id: str | None = None, provider_source_id: str | None = None
):
"""Delete provider and/or provider source from config and terminate the instances. Config will be saved after deletion."""
async with self.resource_lock:
# delete from config
target_prov_ids = []
if provider_id:
target_prov_ids.append(provider_id)
else:
for prov in self.providers_config:
if prov.get("provider_source_id") == provider_source_id:
target_prov_ids.append(prov.get("id"))
config = self.acm.default_conf
for tpid in target_prov_ids:
await self.terminate_provider(tpid)
config["provider"] = [
prov for prov in config["provider"] if prov.get("id") != tpid
]
config.save_config()
logger.info(f"Provider {target_prov_ids} 已从配置中删除。")
async def update_provider(self, origin_provider_id: str, new_config: dict):
"""Update provider config and reload the instance. Config will be saved after update."""
async with self.resource_lock:
npid = new_config.get("id", None)
if not npid:
raise ValueError("New provider config must have an 'id' field")
config = self.acm.default_conf
for provider in config["provider"]:
if (
provider.get("id", None) == npid
and provider.get("id", None) != origin_provider_id
):
raise ValueError(f"Provider ID {npid} already exists")
# update config
for idx, provider in enumerate(config["provider"]):
if provider.get("id", None) == origin_provider_id:
config["provider"][idx] = new_config
break
else:
raise ValueError(f"Provider ID {origin_provider_id} not found")
config.save_config()
# reload instance
await self.reload(new_config)
async def create_provider(self, new_config: dict):
"""Add new provider config and load the instance. Config will be saved after addition."""
async with self.resource_lock:
npid = new_config.get("id", None)
if not npid:
raise ValueError("New provider config must have an 'id' field")
config = self.acm.default_conf
for provider in config["provider"]:
if provider.get("id", None) == npid:
raise ValueError(f"Provider ID {npid} already exists")
# add to config
config["provider"].append(new_config)
config.save_config()
# load instance
await self.load_provider(new_config)
async def terminate(self): async def terminate(self):
for provider_inst in self.provider_insts: for provider_inst in self.provider_insts:
if hasattr(provider_inst, "terminate"): if hasattr(provider_inst, "terminate"):
+17 -2
View File
@@ -2,8 +2,9 @@ import abc
import asyncio import asyncio
import os import os
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import TypeAlias, Union
from astrbot.core.agent.message import Message from astrbot.core.agent.message import ContentPart, Message
from astrbot.core.agent.tool import ToolSet from astrbot.core.agent.tool import ToolSet
from astrbot.core.provider.entities import ( from astrbot.core.provider.entities import (
LLMResponse, LLMResponse,
@@ -14,6 +15,14 @@ from astrbot.core.provider.entities import (
from astrbot.core.provider.register import provider_cls_map from astrbot.core.provider.register import provider_cls_map
from astrbot.core.utils.astrbot_path import get_astrbot_path from astrbot.core.utils.astrbot_path import get_astrbot_path
Providers: TypeAlias = Union[
"Provider",
"STTProvider",
"TTSProvider",
"EmbeddingProvider",
"RerankProvider",
]
class AbstractProvider(abc.ABC): class AbstractProvider(abc.ABC):
"""Provider Abstract Class""" """Provider Abstract Class"""
@@ -94,6 +103,7 @@ class Provider(AbstractProvider):
system_prompt: str | None = None, system_prompt: str | None = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None, tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None, model: str | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
**kwargs, **kwargs,
) -> LLMResponse: ) -> LLMResponse:
"""获得 LLM 的文本对话结果。会使用当前的模型进行对话。 """获得 LLM 的文本对话结果。会使用当前的模型进行对话。
@@ -105,6 +115,7 @@ class Provider(AbstractProvider):
tools: tool set tools: tool set
contexts: 上下文 prompt 二选一使用 contexts: 上下文 prompt 二选一使用
tool_calls_result: 回传给 LLM 的工具调用结果参考: https://platform.openai.com/docs/guides/function-calling tool_calls_result: 回传给 LLM 的工具调用结果参考: https://platform.openai.com/docs/guides/function-calling
extra_user_content_parts: 额外的用户内容块列表用于在用户消息后添加额外的文本块如系统提醒指令等
kwargs: 其他参数 kwargs: 其他参数
Notes: Notes:
@@ -124,6 +135,7 @@ class Provider(AbstractProvider):
system_prompt: str | None = None, system_prompt: str | None = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None, tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None, model: str | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
**kwargs, **kwargs,
) -> AsyncGenerator[LLMResponse, None]: ) -> AsyncGenerator[LLMResponse, None]:
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。 """获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
@@ -135,6 +147,7 @@ class Provider(AbstractProvider):
tools: tool set tools: tool set
contexts: 上下文 prompt 二选一使用 contexts: 上下文 prompt 二选一使用
tool_calls_result: 回传给 LLM 的工具调用结果参考: https://platform.openai.com/docs/guides/function-calling tool_calls_result: 回传给 LLM 的工具调用结果参考: https://platform.openai.com/docs/guides/function-calling
extra_user_content_parts: 额外的用户内容块列表用于在用户消息后添加额外的文本块如系统提醒指令等
kwargs: 其他参数 kwargs: 其他参数
Notes: Notes:
@@ -142,7 +155,9 @@ class Provider(AbstractProvider):
- 如果传入了 tools将会使用 tools 进行 Function-calling如果模型不支持 Function-calling将会抛出错误 - 如果传入了 tools将会使用 tools 进行 Function-calling如果模型不支持 Function-calling将会抛出错误
""" """
... if False: # pragma: no cover - make this an async generator for typing
yield None # type: ignore
raise NotImplementedError()
async def pop_record(self, context: list): async def pop_record(self, context: list):
"""弹出 context 第一条非系统提示词对话记录""" """弹出 context 第一条非系统提示词对话记录"""
+166 -47
View File
@@ -6,10 +6,13 @@ from mimetypes import guess_type
import anthropic import anthropic
from anthropic import AsyncAnthropic from anthropic import AsyncAnthropic
from anthropic.types import Message from anthropic.types import Message
from anthropic.types.message_delta_usage import MessageDeltaUsage
from anthropic.types.usage import Usage
from astrbot import logger from astrbot import logger
from astrbot.api.provider import Provider from astrbot.api.provider import Provider
from astrbot.core.provider.entities import LLMResponse from astrbot.core.agent.message import ContentPart
from astrbot.core.provider.entities import LLMResponse, TokenUsage
from astrbot.core.provider.func_tool_manager import ToolSet from astrbot.core.provider.func_tool_manager import ToolSet
from astrbot.core.utils.io import download_image_by_url from astrbot.core.utils.io import download_image_by_url
@@ -45,7 +48,7 @@ class ProviderAnthropic(Provider):
base_url=self.base_url, base_url=self.base_url,
) )
self.set_model(provider_config["model_config"]["model"]) self.set_model(provider_config.get("model", "unknown"))
def _prepare_payload(self, messages: list[dict]): def _prepare_payload(self, messages: list[dict]):
"""准备 Anthropic API 的请求 payload """准备 Anthropic API 的请求 payload
@@ -107,12 +110,32 @@ class ProviderAnthropic(Provider):
return system_prompt, new_messages return system_prompt, new_messages
def _extract_usage(self, usage: Usage) -> TokenUsage:
# https://docs.claude.com/en/docs/build-with-claude/prompt-caching#tracking-cache-performance
return TokenUsage(
input_other=usage.input_tokens or 0,
input_cached=usage.cache_read_input_tokens or 0,
output=usage.output_tokens,
)
def _update_usage(self, token_usage: TokenUsage, usage: MessageDeltaUsage) -> None:
if usage.input_tokens is not None:
token_usage.input_other = usage.input_tokens
if usage.cache_read_input_tokens is not None:
token_usage.input_cached = usage.cache_read_input_tokens
if usage.output_tokens is not None:
token_usage.output = usage.output_tokens
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
if tools: if tools:
if tool_list := tools.get_func_desc_anthropic_style(): if tool_list := tools.get_func_desc_anthropic_style():
payloads["tools"] = tool_list payloads["tools"] = tool_list
completion = await self.client.messages.create(**payloads, stream=False) extra_body = self.provider_config.get("custom_extra_body", {})
completion = await self.client.messages.create(
**payloads, stream=False, extra_body=extra_body
)
assert isinstance(completion, Message) assert isinstance(completion, Message)
logger.debug(f"completion: {completion}") logger.debug(f"completion: {completion}")
@@ -131,6 +154,10 @@ class ProviderAnthropic(Provider):
llm_response.tools_call_args.append(content_block.input) llm_response.tools_call_args.append(content_block.input)
llm_response.tools_call_name.append(content_block.name) llm_response.tools_call_name.append(content_block.name)
llm_response.tools_call_ids.append(content_block.id) llm_response.tools_call_ids.append(content_block.id)
llm_response.id = completion.id
llm_response.usage = self._extract_usage(completion.usage)
# TODO(Soulter): 处理 end_turn 情况 # TODO(Soulter): 处理 end_turn 情况
if not llm_response.completion_text and not llm_response.tools_call_args: if not llm_response.completion_text and not llm_response.tools_call_args:
raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}") raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}")
@@ -151,10 +178,19 @@ class ProviderAnthropic(Provider):
# 用于累积最终结果 # 用于累积最终结果
final_text = "" final_text = ""
final_tool_calls = [] final_tool_calls = []
id = None
usage = TokenUsage()
extra_body = self.provider_config.get("custom_extra_body", {})
async with self.client.messages.stream(**payloads) as stream: async with self.client.messages.stream(
**payloads, extra_body=extra_body
) as stream:
assert isinstance(stream, anthropic.AsyncMessageStream) assert isinstance(stream, anthropic.AsyncMessageStream)
async for event in stream: async for event in stream:
if event.type == "message_start":
# the usage contains input token usage
id = event.message.id
usage = self._extract_usage(event.message.usage)
if event.type == "content_block_start": if event.type == "content_block_start":
if event.content_block.type == "text": if event.content_block.type == "text":
# 文本块开始 # 文本块开始
@@ -162,6 +198,8 @@ class ProviderAnthropic(Provider):
role="assistant", role="assistant",
completion_text="", completion_text="",
is_chunk=True, is_chunk=True,
usage=usage,
id=id,
) )
elif event.content_block.type == "tool_use": elif event.content_block.type == "tool_use":
# 工具使用块开始,初始化缓冲区 # 工具使用块开始,初始化缓冲区
@@ -179,6 +217,8 @@ class ProviderAnthropic(Provider):
role="assistant", role="assistant",
completion_text=event.delta.text, completion_text=event.delta.text,
is_chunk=True, is_chunk=True,
usage=usage,
id=id,
) )
elif event.delta.type == "input_json_delta": elif event.delta.type == "input_json_delta":
# 工具调用参数增量 # 工具调用参数增量
@@ -215,6 +255,8 @@ class ProviderAnthropic(Provider):
tools_call_name=[tool_info["name"]], tools_call_name=[tool_info["name"]],
tools_call_ids=[tool_info["id"]], tools_call_ids=[tool_info["id"]],
is_chunk=True, is_chunk=True,
usage=usage,
id=id,
) )
except json.JSONDecodeError: except json.JSONDecodeError:
# JSON 解析失败,跳过这个工具调用 # JSON 解析失败,跳过这个工具调用
@@ -223,11 +265,17 @@ class ProviderAnthropic(Provider):
# 清理缓冲区 # 清理缓冲区
del tool_use_buffer[event.index] del tool_use_buffer[event.index]
elif event.type == "message_delta":
if event.usage:
self._update_usage(usage, event.usage)
# 返回最终的完整结果 # 返回最终的完整结果
final_response = LLMResponse( final_response = LLMResponse(
role="assistant", role="assistant",
completion_text=final_text, completion_text=final_text,
is_chunk=False, is_chunk=False,
usage=usage,
id=id,
) )
if final_tool_calls: if final_tool_calls:
@@ -249,13 +297,16 @@ class ProviderAnthropic(Provider):
system_prompt=None, system_prompt=None,
tool_calls_result=None, tool_calls_result=None,
model=None, model=None,
extra_user_content_parts=None,
**kwargs, **kwargs,
) -> LLMResponse: ) -> LLMResponse:
if contexts is None: if contexts is None:
contexts = [] contexts = []
new_record = None new_record = None
if prompt is not None: if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls) new_record = await self.assemble_context(
prompt, image_urls, extra_user_content_parts
)
context_query = self._ensure_message_to_dicts(contexts) context_query = self._ensure_message_to_dicts(contexts)
if new_record: if new_record:
context_query.append(new_record) context_query.append(new_record)
@@ -277,10 +328,9 @@ class ProviderAnthropic(Provider):
system_prompt, new_messages = self._prepare_payload(context_query) system_prompt, new_messages = self._prepare_payload(context_query)
model_config = self.provider_config.get("model_config", {}) model = model or self.get_model()
model_config["model"] = model or self.get_model()
payloads = {"messages": new_messages, **model_config} payloads = {"messages": new_messages, "model": model}
# Anthropic has a different way of handling system prompts # Anthropic has a different way of handling system prompts
if system_prompt: if system_prompt:
@@ -290,7 +340,6 @@ class ProviderAnthropic(Provider):
try: try:
llm_response = await self._query(payloads, func_tool) llm_response = await self._query(payloads, func_tool)
except Exception as e: except Exception as e:
# logger.error(f"发生了错误。Provider 配置如下: {model_config}")
raise e raise e
return llm_response return llm_response
@@ -305,13 +354,16 @@ class ProviderAnthropic(Provider):
system_prompt=None, system_prompt=None,
tool_calls_result=None, tool_calls_result=None,
model=None, model=None,
extra_user_content_parts=None,
**kwargs, **kwargs,
): ):
if contexts is None: if contexts is None:
contexts = [] contexts = []
new_record = None new_record = None
if prompt is not None: if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls) new_record = await self.assemble_context(
prompt, image_urls, extra_user_content_parts
)
context_query = self._ensure_message_to_dicts(contexts) context_query = self._ensure_message_to_dicts(contexts)
if new_record: if new_record:
context_query.append(new_record) context_query.append(new_record)
@@ -332,10 +384,9 @@ class ProviderAnthropic(Provider):
system_prompt, new_messages = self._prepare_payload(context_query) system_prompt, new_messages = self._prepare_payload(context_query)
model_config = self.provider_config.get("model_config", {}) model = model or self.get_model()
model_config["model"] = model or self.get_model()
payloads = {"messages": new_messages, **model_config} payloads = {"messages": new_messages, "model": model}
# Anthropic has a different way of handling system prompts # Anthropic has a different way of handling system prompts
if system_prompt: if system_prompt:
@@ -344,48 +395,116 @@ class ProviderAnthropic(Provider):
async for llm_response in self._query_stream(payloads, func_tool): async for llm_response in self._query_stream(payloads, func_tool):
yield llm_response yield llm_response
async def assemble_context(self, text: str, image_urls: list[str] | None = None): async def assemble_context(
self,
text: str,
image_urls: list[str] | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
):
"""组装上下文,支持文本和图片""" """组装上下文,支持文本和图片"""
if not image_urls:
return {"role": "user", "content": text}
content = [] content = []
content.append({"type": "text", "text": text})
for image_url in image_urls: # 1. 用户原始发言(OpenAI 建议:用户发言在前)
if image_url.startswith("http"): if text:
image_path = await download_image_by_url(image_url) content.append({"type": "text", "text": text})
image_data = await self.encode_image_bs64(image_path) elif image_urls:
elif image_url.startswith("file:///"): # 如果没有文本但有图片,添加占位文本
image_path = image_url.replace("file:///", "") content.append({"type": "text", "text": "[图片]"})
image_data = await self.encode_image_bs64(image_path) elif extra_user_content_parts:
else: # 如果只有额外内容块,也需要添加占位文本
image_data = await self.encode_image_bs64(image_url) content.append({"type": "text", "text": " "})
if not image_data: # 2. 额外的内容块(系统提醒、指令等)
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") if extra_user_content_parts:
continue for block in extra_user_content_parts:
block_type = block.get("type")
# Get mime type for the image if block_type == "text":
mime_type, _ = guess_type(image_url) # 文本直接添加
if not mime_type: content.append(block)
mime_type = "image/jpeg" # Default to JPEG if can't determine
content.append( elif block_type == "image_url":
{ # 转换 OpenAI 格式的图片为 Anthropic 格式
"type": "image", image_url_data = block.get("image_url", {})
"source": { if isinstance(image_url_data, dict):
"type": "base64", url = image_url_data.get("url", "")
"media_type": mime_type, else:
"data": ( # 兼容直接传 URL 字符串的情况
image_data.split("base64,")[1] url = str(image_url_data)
if "base64," in image_data
else image_data if url and url.startswith("data:"):
), try:
# 提取 MIME 类型和 base64 数据
mime_type = url.split(":")[1].split(";")[0]
base64_data = (
url.split("base64,")[1] if "base64," in url else url
)
content.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": base64_data,
},
}
)
except Exception as e:
logger.warning(f"转换 image_url 到 Anthropic 格式失败: {e}")
else:
logger.warning(f"image_url 不是有效的 data URI: {url[:50]}...")
else:
# 其他类型(如 audio_urlAnthropic 不支持,记录警告
logger.debug(f"Anthropic 不支持的内容类型 '{block_type}',已忽略")
# 3. 图片内容
if image_urls:
for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_path)
else:
image_data = await self.encode_image_bs64(image_url)
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
# Get mime type for the image
mime_type, _ = guess_type(image_url)
if not mime_type:
mime_type = "image/jpeg" # Default to JPEG if can't determine
content.append(
{
"type": "image",
"source": {
"type": "base64",
"media_type": mime_type,
"data": (
image_data.split("base64,")[1]
if "base64," in image_data
else image_data
),
},
}, },
}, )
)
# 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
if (
text
and not extra_user_content_parts
and not image_urls
and len(content) == 1
and content[0]["type"] == "text"
):
return {"role": "user", "content": content[0]["text"]}
# 否则返回多模态格式
return {"role": "user", "content": content} return {"role": "user", "content": content}
async def encode_image_bs64(self, image_url: str) -> str: async def encode_image_bs64(self, image_url: str) -> str:
@@ -29,15 +29,24 @@ class OTTSProvider:
self.last_sync_time = 0 self.last_sync_time = 0
self.timeout = Timeout(10.0) self.timeout = Timeout(10.0)
self.retry_count = 3 self.retry_count = 3
self.client = None self._client: AsyncClient | None = None
@property
def client(self) -> AsyncClient:
if self._client is None:
raise RuntimeError(
"Client not initialized. Please use 'async with' context."
)
return self._client
async def __aenter__(self): async def __aenter__(self):
self.client = AsyncClient(timeout=self.timeout) self._client = AsyncClient(timeout=self.timeout)
return self return self
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.client: if self._client:
await self.client.aclose() await self._client.aclose()
self._client = None
async def _sync_time(self): async def _sync_time(self):
try: try:
@@ -90,6 +99,7 @@ class OTTSProvider:
if attempt == self.retry_count - 1: if attempt == self.retry_count - 1:
raise RuntimeError(f"OTTS请求失败: {e!s}") from e raise RuntimeError(f"OTTS请求失败: {e!s}") from e
await asyncio.sleep(0.5 * (attempt + 1)) await asyncio.sleep(0.5 * (attempt + 1))
raise RuntimeError("OTTS未返回音频文件")
class AzureNativeProvider(TTSProvider): class AzureNativeProvider(TTSProvider):
@@ -105,7 +115,7 @@ class AzureNativeProvider(TTSProvider):
self.endpoint = ( self.endpoint = (
f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1" f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1"
) )
self.client = None self._client: AsyncClient | None = None
self.token = None self.token = None
self.token_expire = 0 self.token_expire = 0
self.voice_params = { self.voice_params = {
@@ -116,8 +126,16 @@ class AzureNativeProvider(TTSProvider):
"volume": provider_config.get("azure_tts_volume", "100"), "volume": provider_config.get("azure_tts_volume", "100"),
} }
@property
def client(self) -> AsyncClient:
if self._client is None:
raise RuntimeError(
"Client not initialized. Please use 'async with' context."
)
return self._client
async def __aenter__(self): async def __aenter__(self):
self.client = AsyncClient( self._client = AsyncClient(
headers={ headers={
"User-Agent": f"AstrBot/{VERSION}", "User-Agent": f"AstrBot/{VERSION}",
"Content-Type": "application/ssml+xml", "Content-Type": "application/ssml+xml",
@@ -127,8 +145,9 @@ class AzureNativeProvider(TTSProvider):
return self return self
async def __aexit__(self, exc_type, exc_val, exc_tb): async def __aexit__(self, exc_type, exc_val, exc_tb):
if self.client: if self._client:
await self.client.aclose() await self._client.aclose()
self._client = None
async def _refresh_token(self): async def _refresh_token(self):
token_url = ( token_url = (
@@ -181,8 +200,11 @@ class AzureTTSProvider(TTSProvider):
key_value = provider_config.get("azure_tts_subscription_key", "") key_value = provider_config.get("azure_tts_subscription_key", "")
self.provider = self._parse_provider(key_value, provider_config) self.provider = self._parse_provider(key_value, provider_config)
def _parse_provider(self, key_value: str, config: dict) -> TTSProvider: def _parse_provider(
self, key_value: str, config: dict
) -> OTTSProvider | AzureNativeProvider:
if key_value.lower().startswith("other["): if key_value.lower().startswith("other["):
json_str = ""
try: try:
match = re.match(r"other\[(.*)\]", key_value, re.DOTALL) match = re.match(r"other\[(.*)\]", key_value, re.DOTALL)
if not match: if not match:
@@ -177,6 +177,10 @@ class BailianRerankProvider(RerankProvider):
Returns: Returns:
重排序结果列表 重排序结果列表
""" """
if not self.client:
logger.error("百炼 Rerank 客户端会话已关闭,返回空结果")
return []
if not documents: if not documents:
logger.warning("文档列表为空,返回空结果") logger.warning("文档列表为空,返回空结果")
return [] return []
@@ -36,7 +36,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
super().__init__(provider_config, provider_settings) super().__init__(provider_config, provider_settings)
self.chosen_api_key: str = provider_config.get("api_key", "") self.chosen_api_key: str = provider_config.get("api_key", "")
self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella") self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella")
self.set_model(provider_config.get("model")) self.set_model(provider_config["model"])
self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000 self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000
dashscope.api_key = self.chosen_api_key dashscope.api_key = self.chosen_api_key
@@ -71,9 +71,10 @@ class ProviderDashscopeTTSAPI(TTSProvider):
kwargs = { kwargs = {
"model": model, "model": model,
"text": text, "messages": None,
"api_key": self.chosen_api_key, "api_key": self.chosen_api_key,
"voice": self.voice or "Cherry", "voice": self.voice or "Cherry",
"text": text,
} }
if not self.voice: if not self.voice:
logging.warning( logging.warning(
@@ -67,7 +67,7 @@ class ProviderEdgeTTS(TTSProvider):
from pyffmpeg import FFmpeg from pyffmpeg import FFmpeg
ff = FFmpeg() ff = FFmpeg()
ff.convert(input=mp3_path, output=wav_path) ff.convert(input_file=mp3_path, output_file=wav_path)
except Exception as e: except Exception as e:
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换") logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
# use ffmpeg command line # use ffmpeg command line
@@ -59,9 +59,9 @@ class ProviderFishAudioTTSAPI(TTSProvider):
self.headers = { self.headers = {
"Authorization": f"Bearer {self.chosen_api_key}", "Authorization": f"Bearer {self.chosen_api_key}",
} }
self.set_model(provider_config.get("model")) self.set_model(provider_config["model"])
async def _get_reference_id_by_character(self, character: str) -> str: async def _get_reference_id_by_character(self, character: str) -> str | None:
"""获取角色的reference_id """获取角色的reference_id
Args: Args:
@@ -109,7 +109,7 @@ class ProviderFishAudioTTSAPI(TTSProvider):
pattern = r"^[a-fA-F0-9]{32}$" pattern = r"^[a-fA-F0-9]{32}$"
return bool(re.match(pattern, reference_id.strip())) return bool(re.match(pattern, reference_id.strip()))
async def _generate_request(self, text: str) -> dict: async def _generate_request(self, text: str) -> ServeTTSRequest:
# 向前兼容逻辑:优先使用reference_id,如果没有则使用角色名称查询 # 向前兼容逻辑:优先使用reference_id,如果没有则使用角色名称查询
if self.reference_id and self.reference_id.strip(): if self.reference_id and self.reference_id.strip():
# 验证reference_id格式 # 验证reference_id格式
@@ -146,5 +146,6 @@ class ProviderFishAudioTTSAPI(TTSProvider):
async for chunk in response.aiter_bytes(): async for chunk in response.aiter_bytes():
f.write(chunk) f.write(chunk)
return path return path
text = await response.aread() body = await response.aread()
text = body.decode("utf-8", errors="replace")
raise Exception(f"Fish Audio API请求失败: {text}") raise Exception(f"Fish Audio API请求失败: {text}")
@@ -1,3 +1,5 @@
from typing import cast
from google import genai from google import genai
from google.genai import types from google.genai import types
from google.genai.errors import APIError from google.genai.errors import APIError
@@ -18,8 +20,8 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
self.provider_config = provider_config self.provider_config = provider_config
self.provider_settings = provider_settings self.provider_settings = provider_settings
api_key: str = provider_config.get("embedding_api_key") api_key: str = provider_config["embedding_api_key"]
api_base: str = provider_config.get("embedding_api_base") api_base: str = provider_config["embedding_api_base"]
timeout: int = int(provider_config.get("timeout", 20)) timeout: int = int(provider_config.get("timeout", 20))
http_options = types.HttpOptions(timeout=timeout * 1000) http_options = types.HttpOptions(timeout=timeout * 1000)
@@ -41,18 +43,26 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
model=self.model, model=self.model,
contents=text, contents=text,
) )
assert result.embeddings is not None
assert result.embeddings[0].values is not None
return result.embeddings[0].values return result.embeddings[0].values
except APIError as e: except APIError as e:
raise Exception(f"Gemini Embedding API请求失败: {e.message}") raise Exception(f"Gemini Embedding API请求失败: {e.message}")
async def get_embeddings(self, texts: list[str]) -> list[list[float]]: async def get_embeddings(self, text: list[str]) -> list[list[float]]:
"""批量获取文本的嵌入""" """批量获取文本的嵌入"""
try: try:
result = await self.client.models.embed_content( result = await self.client.models.embed_content(
model=self.model, model=self.model,
contents=texts, contents=cast(types.ContentListUnion, text),
) )
return [embedding.values for embedding in result.embeddings] assert result.embeddings is not None
embeddings: list[list[float]] = []
for embedding in result.embeddings:
assert embedding.values is not None
embeddings.append(embedding.values)
return embeddings
except APIError as e: except APIError as e:
raise Exception(f"Gemini Embedding API批量请求失败: {e.message}") raise Exception(f"Gemini Embedding API批量请求失败: {e.message}")
+144 -53
View File
@@ -4,6 +4,7 @@ import json
import logging import logging
import random import random
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import cast
from google import genai from google import genai
from google.genai import types from google.genai import types
@@ -12,8 +13,9 @@ from google.genai.errors import APIError
import astrbot.core.message.components as Comp import astrbot.core.message.components as Comp
from astrbot import logger from astrbot import logger
from astrbot.api.provider import Provider from astrbot.api.provider import Provider
from astrbot.core.agent.message import ContentPart
from astrbot.core.message.message_event_result import MessageChain from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse from astrbot.core.provider.entities import LLMResponse, TokenUsage
from astrbot.core.provider.func_tool_manager import ToolSet from astrbot.core.provider.func_tool_manager import ToolSet
from astrbot.core.utils.io import download_image_by_url from astrbot.core.utils.io import download_image_by_url
@@ -67,7 +69,7 @@ class ProviderGoogleGenAI(Provider):
self.api_base = self.api_base[:-1] self.api_base = self.api_base[:-1]
self._init_client() self._init_client()
self.set_model(provider_config["model_config"]["model"]) self.set_model(provider_config.get("model", "unknown"))
self._init_safety_settings() self._init_safety_settings()
def _init_client(self) -> None: def _init_client(self) -> None:
@@ -126,18 +128,18 @@ class ProviderGoogleGenAI(Provider):
) -> types.GenerateContentConfig: ) -> types.GenerateContentConfig:
"""准备查询配置""" """准备查询配置"""
if not modalities: if not modalities:
modalities = ["Text"] modalities = ["TEXT"]
# 流式输出不支持图片模态 # 流式输出不支持图片模态
if ( if (
self.provider_settings.get("streaming_response", False) self.provider_settings.get("streaming_response", False)
and "Image" in modalities and "IMAGE" in modalities
): ):
logger.warning("流式输出不支持图片模态,已自动降级为文本模态") logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
modalities = ["Text"] modalities = ["TEXT"]
tool_list = [] tool_list: list[types.Tool] | None = []
model_name = self.get_model() model_name = cast(str, payloads.get("model", self.get_model()))
native_coderunner = self.provider_config.get("gm_native_coderunner", False) native_coderunner = self.provider_config.get("gm_native_coderunner", False)
native_search = self.provider_config.get("gm_native_search", False) native_search = self.provider_config.get("gm_native_search", False)
url_context = self.provider_config.get("gm_url_context", False) url_context = self.provider_config.get("gm_url_context", False)
@@ -196,6 +198,53 @@ class ProviderGoogleGenAI(Provider):
types.Tool(function_declarations=func_desc["function_declarations"]), types.Tool(function_declarations=func_desc["function_declarations"]),
] ]
# oper thinking config
thinking_config = None
if model_name in [
"gemini-2.5-pro",
"gemini-2.5-pro-preview",
"gemini-2.5-flash",
"gemini-2.5-flash-preview",
"gemini-2.5-flash-lite",
"gemini-2.5-flash-lite-preview",
"gemini-robotics-er-1.5-preview",
"gemini-live-2.5-flash-preview-native-audio-09-2025",
]:
# The thinkingBudget parameter, introduced with the Gemini 2.5 series
thinking_budget = self.provider_config.get("gm_thinking_config", {}).get(
"budget", 0
)
if thinking_budget is not None:
thinking_config = types.ThinkingConfig(
thinking_budget=thinking_budget,
)
elif model_name in [
"gemini-3-pro",
"gemini-3-pro-preview",
"gemini-3-flash",
"gemini-3-flash-preview",
"gemini-3-flash-lite",
"gemini-3-flash-lite-preview",
]:
# The thinkingLevel parameter, recommended for Gemini 3 models and onwards
# Gemini 2.5 series models don't support thinkingLevel; use thinkingBudget instead.
thinking_level = self.provider_config.get("gm_thinking_config", {}).get(
"level", "HIGH"
)
if thinking_level and isinstance(thinking_level, str):
thinking_level = thinking_level.upper()
if thinking_level not in ["MINIMAL", "LOW", "MEDIUM", "HIGH"]:
logger.warning(
f"Invalid thinking level: {thinking_level}, using HIGH"
)
thinking_level = "HIGH"
level = types.ThinkingLevel(thinking_level)
thinking_config = types.ThinkingConfig()
if not hasattr(types.ThinkingConfig, "thinking_level"):
setattr(types.ThinkingConfig, "thinking_level", level)
else:
thinking_config.thinking_level = level
return types.GenerateContentConfig( return types.GenerateContentConfig(
system_instruction=system_instruction, system_instruction=system_instruction,
temperature=temperature, temperature=temperature,
@@ -213,24 +262,9 @@ class ProviderGoogleGenAI(Provider):
logprobs=payloads.get("logprobs"), logprobs=payloads.get("logprobs"),
seed=payloads.get("seed"), seed=payloads.get("seed"),
response_modalities=modalities, response_modalities=modalities,
tools=tool_list, tools=cast(types.ToolListUnion | None, tool_list),
safety_settings=self.safety_settings if self.safety_settings else None, safety_settings=self.safety_settings if self.safety_settings else None,
thinking_config=( thinking_config=thinking_config,
types.ThinkingConfig(
thinking_budget=min(
int(
self.provider_config.get("gm_thinking_config", {}).get(
"budget",
0,
),
),
24576,
),
)
if "gemini-2.5-flash" in self.get_model()
and hasattr(types.ThinkingConfig, "thinking_budget")
else None
),
automatic_function_calling=types.AutomaticFunctionCallingConfig( automatic_function_calling=types.AutomaticFunctionCallingConfig(
disable=True, disable=True,
), ),
@@ -257,6 +291,7 @@ class ProviderGoogleGenAI(Provider):
content_cls: type[types.Content], content_cls: type[types.Content],
) -> None: ) -> None:
if contents and isinstance(contents[-1], content_cls): if contents and isinstance(contents[-1], content_cls):
assert contents[-1].parts is not None
contents[-1].parts.extend(part) contents[-1].parts.extend(part)
else: else:
contents.append(content_cls(parts=part)) contents.append(content_cls(parts=part))
@@ -345,6 +380,16 @@ class ProviderGoogleGenAI(Provider):
] ]
return "".join(thought_buf).strip() return "".join(thought_buf).strip()
def _extract_usage(
self, usage_metadata: types.GenerateContentResponseUsageMetadata
) -> TokenUsage:
"""Extract usage from candidate"""
return TokenUsage(
input_other=usage_metadata.prompt_token_count or 0,
input_cached=usage_metadata.cached_content_token_count or 0,
output=usage_metadata.candidates_token_count or 0,
)
def _process_content_parts( def _process_content_parts(
self, self,
candidate: types.Candidate, candidate: types.Candidate,
@@ -429,9 +474,11 @@ class ProviderGoogleGenAI(Provider):
None, None,
) )
modalities = ["Text"] model = payloads.get("model", self.get_model())
modalities = ["TEXT"]
if self.provider_config.get("gm_resp_image_modal", False): if self.provider_config.get("gm_resp_image_modal", False):
modalities.append("Image") modalities.append("IMAGE")
conversation = self._prepare_conversation(payloads) conversation = self._prepare_conversation(payloads)
temperature = payloads.get("temperature", 0.7) temperature = payloads.get("temperature", 0.7)
@@ -447,8 +494,8 @@ class ProviderGoogleGenAI(Provider):
temperature, temperature,
) )
result = await self.client.models.generate_content( result = await self.client.models.generate_content(
model=self.get_model(), model=model,
contents=conversation, contents=cast(types.ContentListUnion, conversation),
config=config, config=config,
) )
logger.debug(f"genai result: {result}") logger.debug(f"genai result: {result}")
@@ -473,11 +520,11 @@ class ProviderGoogleGenAI(Provider):
e.message = "" e.message = ""
if "Developer instruction is not enabled" in e.message: if "Developer instruction is not enabled" in e.message:
logger.warning( logger.warning(
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)", f"{model} 不支持 system prompt,已自动去除(影响人格设置)",
) )
system_instruction = None system_instruction = None
elif "Function calling is not enabled" in e.message: elif "Function calling is not enabled" in e.message:
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除") logger.warning(f"{model} 不支持函数调用,已自动去除")
tools = None tools = None
elif ( elif (
"Multi-modal output is not supported" in e.message "Multi-modal output is not supported" in e.message
@@ -486,9 +533,9 @@ class ProviderGoogleGenAI(Provider):
or "only supports text output" in e.message or "only supports text output" in e.message
): ):
logger.warning( logger.warning(
f"{self.get_model()} 不支持多模态输出,降级为文本模态", f"{model} 不支持多模态输出,降级为文本模态",
) )
modalities = ["Text"] modalities = ["TEXT"]
else: else:
raise raise
continue continue
@@ -499,6 +546,9 @@ class ProviderGoogleGenAI(Provider):
result.candidates[0], result.candidates[0],
llm_response, llm_response,
) )
llm_response.id = result.response_id
if result.usage_metadata:
llm_response.usage = self._extract_usage(result.usage_metadata)
return llm_response return llm_response
async def _query_stream( async def _query_stream(
@@ -511,7 +561,7 @@ class ProviderGoogleGenAI(Provider):
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"), (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
None, None,
) )
model = payloads.get("model", self.get_model())
conversation = self._prepare_conversation(payloads) conversation = self._prepare_conversation(payloads)
result = None result = None
@@ -523,8 +573,8 @@ class ProviderGoogleGenAI(Provider):
system_instruction, system_instruction,
) )
result = await self.client.models.generate_content_stream( result = await self.client.models.generate_content_stream(
model=self.get_model(), model=model,
contents=conversation, contents=cast(types.ContentListUnion, conversation),
config=config, config=config,
) )
break break
@@ -533,11 +583,11 @@ class ProviderGoogleGenAI(Provider):
e.message = "" e.message = ""
if "Developer instruction is not enabled" in e.message: if "Developer instruction is not enabled" in e.message:
logger.warning( logger.warning(
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)", f"{model} 不支持 system prompt,已自动去除(影响人格设置)",
) )
system_instruction = None system_instruction = None
elif "Function calling is not enabled" in e.message: elif "Function calling is not enabled" in e.message:
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除") logger.warning(f"{model} 不支持函数调用,已自动去除")
tools = None tools = None
else: else:
raise raise
@@ -567,6 +617,9 @@ class ProviderGoogleGenAI(Provider):
chunk.candidates[0], chunk.candidates[0],
llm_response, llm_response,
) )
llm_response.id = chunk.response_id
if chunk.usage_metadata:
llm_response.usage = self._extract_usage(chunk.usage_metadata)
yield llm_response yield llm_response
return return
@@ -594,6 +647,9 @@ class ProviderGoogleGenAI(Provider):
chunk.candidates[0], chunk.candidates[0],
final_response, final_response,
) )
final_response.id = chunk.response_id
if chunk.usage_metadata:
final_response.usage = self._extract_usage(chunk.usage_metadata)
break break
# Yield final complete response with accumulated text # Yield final complete response with accumulated text
@@ -625,13 +681,16 @@ class ProviderGoogleGenAI(Provider):
system_prompt=None, system_prompt=None,
tool_calls_result=None, tool_calls_result=None,
model=None, model=None,
extra_user_content_parts=None,
**kwargs, **kwargs,
) -> LLMResponse: ) -> LLMResponse:
if contexts is None: if contexts is None:
contexts = [] contexts = []
new_record = None new_record = None
if prompt is not None: if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls) new_record = await self.assemble_context(
prompt, image_urls, extra_user_content_parts
)
context_query = self._ensure_message_to_dicts(contexts) context_query = self._ensure_message_to_dicts(contexts)
if new_record: if new_record:
context_query.append(new_record) context_query.append(new_record)
@@ -650,10 +709,9 @@ class ProviderGoogleGenAI(Provider):
for tcr in tool_calls_result: for tcr in tool_calls_result:
context_query.extend(tcr.to_openai_messages()) context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {}) model = model or self.get_model()
model_config["model"] = model or self.get_model()
payloads = {"messages": context_query, **model_config} payloads = {"messages": context_query, "model": model}
retry = 10 retry = 10
keys = self.api_keys.copy() keys = self.api_keys.copy()
@@ -678,13 +736,16 @@ class ProviderGoogleGenAI(Provider):
system_prompt=None, system_prompt=None,
tool_calls_result=None, tool_calls_result=None,
model=None, model=None,
extra_user_content_parts=None,
**kwargs, **kwargs,
) -> AsyncGenerator[LLMResponse, None]: ) -> AsyncGenerator[LLMResponse, None]:
if contexts is None: if contexts is None:
contexts = [] contexts = []
new_record = None new_record = None
if prompt is not None: if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls) new_record = await self.assemble_context(
prompt, image_urls, extra_user_content_parts
)
context_query = self._ensure_message_to_dicts(contexts) context_query = self._ensure_message_to_dicts(contexts)
if new_record: if new_record:
context_query.append(new_record) context_query.append(new_record)
@@ -703,10 +764,9 @@ class ProviderGoogleGenAI(Provider):
for tcr in tool_calls_result: for tcr in tool_calls_result:
context_query.extend(tcr.to_openai_messages()) context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {}) model = model or self.get_model()
model_config["model"] = model or self.get_model()
payloads = {"messages": context_query, **model_config} payloads = {"messages": context_query, "model": model}
retry = 10 retry = 10
keys = self.api_keys.copy() keys = self.api_keys.copy()
@@ -744,13 +804,33 @@ class ProviderGoogleGenAI(Provider):
self.chosen_api_key = key self.chosen_api_key = key
self._init_client() self._init_client()
async def assemble_context(self, text: str, image_urls: list[str] | None = None): async def assemble_context(
self,
text: str,
image_urls: list[str] | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
):
"""组装上下文。""" """组装上下文。"""
# 构建内容块列表
content_blocks = []
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
if text:
content_blocks.append({"type": "text", "text": text})
elif image_urls:
# 如果没有文本但有图片,添加占位文本
content_blocks.append({"type": "text", "text": "[图片]"})
elif extra_user_content_parts:
# 如果只有额外内容块,也需要添加占位文本
content_blocks.append({"type": "text", "text": " "})
# 2. 额外的内容块(系统提醒、指令等)
if extra_user_content_parts:
for part in extra_user_content_parts:
content_blocks.append(part.model_dump())
# 3. 图片内容
if image_urls: if image_urls:
user_content = {
"role": "user",
"content": [{"type": "text", "text": text if text else "[图片]"}],
}
for image_url in image_urls: for image_url in image_urls:
if image_url.startswith("http"): if image_url.startswith("http"):
image_path = await download_image_by_url(image_url) image_path = await download_image_by_url(image_url)
@@ -763,14 +843,25 @@ class ProviderGoogleGenAI(Provider):
if not image_data: if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue continue
user_content["content"].append( content_blocks.append(
{ {
"type": "image_url", "type": "image_url",
"image_url": {"url": image_data}, "image_url": {"url": image_data},
}, },
) )
return user_content
return {"role": "user", "content": text} # 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
if (
text
and not extra_user_content_parts
and not image_urls
and len(content_blocks) == 1
and content_blocks[0]["type"] == "text"
):
return {"role": "user", "content": content_blocks[0]["text"]}
# 否则返回多模态格式
return {"role": "user", "content": content_blocks}
async def encode_image_bs64(self, image_url: str) -> str: async def encode_image_bs64(self, image_url: str) -> str:
"""将图片转换为 base64""" """将图片转换为 base64"""
@@ -87,7 +87,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
return json.dumps(dict_body) return json.dumps(dict_body)
async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]: async def _call_tts_stream(self, text: str) -> AsyncIterator[str]:
"""进行流式请求""" """进行流式请求"""
try: try:
async with ( async with (
@@ -117,7 +117,9 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
data = json.loads(message[6:]) data = json.loads(message[6:])
if "extra_info" in data: if "extra_info" in data:
continue continue
audio = data.get("data", {}).get("audio") audio: str | None = data.get("data", {}).get(
"audio"
)
if audio is not None: if audio is not None:
yield audio yield audio
except json.JSONDecodeError: except json.JSONDecodeError:
@@ -30,9 +30,9 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
embedding = await self.client.embeddings.create(input=text, model=self.model) embedding = await self.client.embeddings.create(input=text, model=self.model)
return embedding.data[0].embedding return embedding.data[0].embedding
async def get_embeddings(self, texts: list[str]) -> list[list[float]]: async def get_embeddings(self, text: list[str]) -> list[list[float]]:
"""批量获取文本的嵌入""" """批量获取文本的嵌入"""
embeddings = await self.client.embeddings.create(input=texts, model=self.model) embeddings = await self.client.embeddings.create(input=text, model=self.model)
return [item.embedding for item in embeddings.data] return [item.embedding for item in embeddings.data]
def get_dim(self) -> int: def get_dim(self) -> int:
+68 -15
View File
@@ -12,14 +12,15 @@ from openai._exceptions import NotFoundError
from openai.lib.streaming.chat._completions import ChatCompletionStreamState from openai.lib.streaming.chat._completions import ChatCompletionStreamState
from openai.types.chat.chat_completion import ChatCompletion from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.completion_usage import CompletionUsage
import astrbot.core.message.components as Comp import astrbot.core.message.components as Comp
from astrbot import logger from astrbot import logger
from astrbot.api.provider import Provider from astrbot.api.provider import Provider
from astrbot.core.agent.message import Message from astrbot.core.agent.message import ContentPart, Message
from astrbot.core.agent.tool import ToolSet from astrbot.core.agent.tool import ToolSet
from astrbot.core.message.message_event_result import MessageChain from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult
from astrbot.core.utils.io import download_image_by_url from astrbot.core.utils.io import download_image_by_url
from ..register import register_provider_adapter from ..register import register_provider_adapter
@@ -68,8 +69,7 @@ class ProviderOpenAIOfficial(Provider):
self.client.chat.completions.create, self.client.chat.completions.create,
).parameters.keys() ).parameters.keys()
model_config = provider_config.get("model_config", {}) model = provider_config.get("model", "unknown")
model = model_config.get("model", "unknown")
self.set_model(model) self.set_model(model)
self.reasoning_key = "reasoning_content" self.reasoning_key = "reasoning_content"
@@ -208,6 +208,7 @@ class ProviderOpenAIOfficial(Provider):
# handle the content delta # handle the content delta
reasoning = self._extract_reasoning_content(chunk) reasoning = self._extract_reasoning_content(chunk)
_y = False _y = False
llm_response.id = chunk.id
if reasoning: if reasoning:
llm_response.reasoning_content = reasoning llm_response.reasoning_content = reasoning
_y = True _y = True
@@ -217,6 +218,8 @@ class ProviderOpenAIOfficial(Provider):
chain=[Comp.Plain(completion_text)], chain=[Comp.Plain(completion_text)],
) )
_y = True _y = True
if chunk.usage:
llm_response.usage = self._extract_usage(chunk.usage)
if _y: if _y:
yield llm_response yield llm_response
@@ -245,6 +248,15 @@ class ProviderOpenAIOfficial(Provider):
reasoning_text = str(reasoning_attr) reasoning_text = str(reasoning_attr)
return reasoning_text return reasoning_text
def _extract_usage(self, usage: CompletionUsage) -> TokenUsage:
ptd = usage.prompt_tokens_details
cached = ptd.cached_tokens if ptd and ptd.cached_tokens else 0
return TokenUsage(
input_other=usage.prompt_tokens - cached,
input_cached=ptd.cached_tokens if ptd and ptd.cached_tokens else 0,
output=usage.completion_tokens,
)
async def _parse_openai_completion( async def _parse_openai_completion(
self, completion: ChatCompletion, tools: ToolSet | None self, completion: ChatCompletion, tools: ToolSet | None
) -> LLMResponse: ) -> LLMResponse:
@@ -284,6 +296,10 @@ class ProviderOpenAIOfficial(Provider):
if isinstance(tool_call, str): if isinstance(tool_call, str):
# workaround for #1359 # workaround for #1359
tool_call = json.loads(tool_call) tool_call = json.loads(tool_call)
if tools is None:
# 工具集未提供
# Should be unreachable
raise Exception("工具集未提供")
for tool in tools.func_list: for tool in tools.func_list:
if ( if (
tool_call.type == "function" tool_call.type == "function"
@@ -317,6 +333,10 @@ class ProviderOpenAIOfficial(Provider):
raise Exception(f"API 返回的 completion 无法解析:{completion}") raise Exception(f"API 返回的 completion 无法解析:{completion}")
llm_response.raw_completion = completion llm_response.raw_completion = completion
llm_response.id = completion.id
if completion.usage:
llm_response.usage = self._extract_usage(completion.usage)
return llm_response return llm_response
@@ -328,6 +348,7 @@ class ProviderOpenAIOfficial(Provider):
system_prompt: str | None = None, system_prompt: str | None = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None, tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None, model: str | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
**kwargs, **kwargs,
) -> tuple: ) -> tuple:
"""准备聊天所需的有效载荷和上下文""" """准备聊天所需的有效载荷和上下文"""
@@ -335,7 +356,9 @@ class ProviderOpenAIOfficial(Provider):
contexts = [] contexts = []
new_record = None new_record = None
if prompt is not None: if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls) new_record = await self.assemble_context(
prompt, image_urls, extra_user_content_parts
)
context_query = self._ensure_message_to_dicts(contexts) context_query = self._ensure_message_to_dicts(contexts)
if new_record: if new_record:
context_query.append(new_record) context_query.append(new_record)
@@ -354,10 +377,9 @@ class ProviderOpenAIOfficial(Provider):
for tcr in tool_calls_result: for tcr in tool_calls_result:
context_query.extend(tcr.to_openai_messages()) context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {}) model = model or self.get_model()
model_config["model"] = model or self.get_model()
payloads = {"messages": context_query, **model_config} payloads = {"messages": context_query, "model": model}
# xAI origin search tool inject # xAI origin search tool inject
self._maybe_inject_xai_search(payloads, **kwargs) self._maybe_inject_xai_search(payloads, **kwargs)
@@ -457,6 +479,7 @@ class ProviderOpenAIOfficial(Provider):
system_prompt=None, system_prompt=None,
tool_calls_result=None, tool_calls_result=None,
model=None, model=None,
extra_user_content_parts=None,
**kwargs, **kwargs,
) -> LLMResponse: ) -> LLMResponse:
payloads, context_query = await self._prepare_chat_payload( payloads, context_query = await self._prepare_chat_payload(
@@ -466,6 +489,7 @@ class ProviderOpenAIOfficial(Provider):
system_prompt, system_prompt,
tool_calls_result, tool_calls_result,
model=model, model=model,
extra_user_content_parts=extra_user_content_parts,
**kwargs, **kwargs,
) )
@@ -520,6 +544,7 @@ class ProviderOpenAIOfficial(Provider):
system_prompt=None, system_prompt=None,
tool_calls_result=None, tool_calls_result=None,
model=None, model=None,
extra_user_content_parts=None,
**kwargs, **kwargs,
) -> AsyncGenerator[LLMResponse, None]: ) -> AsyncGenerator[LLMResponse, None]:
"""流式对话,与服务商交互并逐步返回结果""" """流式对话,与服务商交互并逐步返回结果"""
@@ -530,6 +555,7 @@ class ProviderOpenAIOfficial(Provider):
system_prompt, system_prompt,
tool_calls_result, tool_calls_result,
model=model, model=model,
extra_user_content_parts=extra_user_content_parts,
**kwargs, **kwargs,
) )
@@ -605,13 +631,29 @@ class ProviderOpenAIOfficial(Provider):
self, self,
text: str, text: str,
image_urls: list[str] | None = None, image_urls: list[str] | None = None,
extra_user_content_parts: list[ContentPart] | None = None,
) -> dict: ) -> dict:
"""组装成符合 OpenAI 格式的 role 为 user 的消息段""" """组装成符合 OpenAI 格式的 role 为 user 的消息段"""
# 构建内容块列表
content_blocks = []
# 1. 用户原始发言(OpenAI 建议:用户发言在前)
if text:
content_blocks.append({"type": "text", "text": text})
elif image_urls:
# 如果没有文本但有图片,添加占位文本
content_blocks.append({"type": "text", "text": "[图片]"})
elif extra_user_content_parts:
# 如果只有额外内容块,也需要添加占位文本
content_blocks.append({"type": "text", "text": " "})
# 2. 额外的内容块(系统提醒、指令等)
if extra_user_content_parts:
for part in extra_user_content_parts:
content_blocks.append(part.model_dump())
# 3. 图片内容
if image_urls: if image_urls:
user_content = {
"role": "user",
"content": [{"type": "text", "text": text if text else "[图片]"}],
}
for image_url in image_urls: for image_url in image_urls:
if image_url.startswith("http"): if image_url.startswith("http"):
image_path = await download_image_by_url(image_url) image_path = await download_image_by_url(image_url)
@@ -624,14 +666,25 @@ class ProviderOpenAIOfficial(Provider):
if not image_data: if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue continue
user_content["content"].append( content_blocks.append(
{ {
"type": "image_url", "type": "image_url",
"image_url": {"url": image_data}, "image_url": {"url": image_data},
}, },
) )
return user_content
return {"role": "user", "content": text} # 如果只有主文本且没有额外内容块和图片,返回简单格式以保持向后兼容
if (
text
and not extra_user_content_parts
and not image_urls
and len(content_blocks) == 1
and content_blocks[0]["type"] == "text"
):
return {"role": "user", "content": content_blocks[0]["text"]}
# 否则返回多模态格式
return {"role": "user", "content": content_blocks}
async def encode_image_bs64(self, image_url: str) -> str: async def encode_image_bs64(self, image_url: str) -> str:
"""将图片转换为 base64""" """将图片转换为 base64"""
@@ -7,6 +7,7 @@ import asyncio
import os import os
import re import re
from datetime import datetime from datetime import datetime
from typing import cast
from funasr_onnx import SenseVoiceSmall from funasr_onnx import SenseVoiceSmall
from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
@@ -32,7 +33,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
provider_settings: dict, provider_settings: dict,
) -> None: ) -> None:
super().__init__(provider_config, provider_settings) super().__init__(provider_config, provider_settings)
self.set_model(provider_config.get("stt_model")) self.set_model(provider_config["stt_model"])
self.model = None self.model = None
self.is_emotion = provider_config.get("is_emotion", False) self.is_emotion = provider_config.get("is_emotion", False)
@@ -86,7 +87,9 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
res = await loop.run_in_executor( res = await loop.run_in_executor(
None, # 使用默认的线程池 None, # 使用默认的线程池
lambda: self.model(audio_url, language="auto", use_itn=True), lambda: cast(SenseVoiceSmall, self.model)(
audio_url, language="auto", use_itn=True
),
) )
# res = self.model(audio_url, language="auto", use_itn=True) # res = self.model(audio_url, language="auto", use_itn=True)
@@ -44,6 +44,7 @@ class VLLMRerankProvider(RerankProvider):
} }
if top_n is not None: if top_n is not None:
payload["top_n"] = top_n payload["top_n"] = top_n
assert self.client is not None
async with self.client.post( async with self.client.post(
f"{self.base_url}/v1/rerank", f"{self.base_url}/v1/rerank",
json=payload, json=payload,
@@ -36,7 +36,7 @@ class ProviderOpenAIWhisperAPI(STTProvider):
timeout=provider_config.get("timeout", NOT_GIVEN), timeout=provider_config.get("timeout", NOT_GIVEN),
) )
self.set_model(provider_config.get("model")) self.set_model(provider_config["model"])
async def _get_audio_format(self, file_path): async def _get_audio_format(self, file_path):
# 定义要检测的头部字节 # 定义要检测的头部字节
@@ -1,6 +1,7 @@
import asyncio import asyncio
import os import os
import uuid import uuid
from typing import cast
import whisper import whisper
@@ -26,7 +27,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
provider_settings: dict, provider_settings: dict,
) -> None: ) -> None:
super().__init__(provider_config, provider_settings) super().__init__(provider_config, provider_settings)
self.set_model(provider_config.get("model")) self.set_model(provider_config["model"])
self.model = None self.model = None
async def initialize(self): async def initialize(self):
@@ -75,5 +76,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
await tencent_silk_to_wav(audio_url, output_path) await tencent_silk_to_wav(audio_url, output_path)
audio_url = output_path audio_url = output_path
if not self.model:
raise RuntimeError("Whisper 模型未初始化")
result = await loop.run_in_executor(None, self.model.transcribe, audio_url) result = await loop.run_in_executor(None, self.model.transcribe, audio_url)
return result["text"] return cast(str, result["text"])
@@ -1,6 +1,11 @@
from typing import cast
from xinference_client.client.restful.async_restful_client import ( from xinference_client.client.restful.async_restful_client import (
AsyncClient as Client, AsyncClient as Client,
) )
from xinference_client.client.restful.async_restful_client import (
AsyncRESTfulRerankModelHandle,
)
from astrbot import logger from astrbot import logger
@@ -29,7 +34,7 @@ class XinferenceRerankProvider(RerankProvider):
False, False,
) )
self.client = None self.client = None
self.model = None self.model: AsyncRESTfulRerankModelHandle | None = None
self.model_uid = None self.model_uid = None
async def initialize(self): async def initialize(self):
@@ -65,7 +70,10 @@ class XinferenceRerankProvider(RerankProvider):
return return
if self.model_uid: if self.model_uid:
self.model = await self.client.get_model(self.model_uid) self.model = cast(
AsyncRESTfulRerankModelHandle,
await self.client.get_model(self.model_uid),
)
except Exception as e: except Exception as e:
logger.error(f"Failed to initialize Xinference model: {e}") logger.error(f"Failed to initialize Xinference model: {e}")
+5 -1
View File
@@ -2,15 +2,19 @@ from astrbot.core import html_renderer
from astrbot.core.provider import Provider from astrbot.core.provider import Provider
from astrbot.core.star.star_tools import StarTools from astrbot.core.star.star_tools import StarTools
from astrbot.core.utils.command_parser import CommandParserMixin from astrbot.core.utils.command_parser import CommandParserMixin
from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin
from .context import Context from .context import Context
from .star import StarMetadata, star_map, star_registry from .star import StarMetadata, star_map, star_registry
from .star_manager import PluginManager from .star_manager import PluginManager
class Star(CommandParserMixin): class Star(CommandParserMixin, PluginKVStoreMixin):
"""所有插件(Star)的父类,所有插件都应该继承于这个类""" """所有插件(Star)的父类,所有插件都应该继承于这个类"""
author: str
name: str
def __init__(self, context: Context, config: dict | None = None): def __init__(self, context: Context, config: dict | None = None):
StarTools.initialize(context) StarTools.initialize(context)
self.context = context self.context = context
+496
View File
@@ -0,0 +1,496 @@
from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any
from astrbot.core import db_helper, logger
from astrbot.core.db.po import CommandConfig
from astrbot.core.star.filter.command import CommandFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.filter.permission import PermissionType, PermissionTypeFilter
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import StarHandlerMetadata, star_handlers_registry
@dataclass
class CommandDescriptor:
handler: StarHandlerMetadata = field(repr=False)
filter_ref: CommandFilter | CommandGroupFilter | None = field(
default=None,
repr=False,
)
handler_full_name: str = ""
handler_name: str = ""
plugin_name: str = ""
plugin_display_name: str | None = None
module_path: str = ""
description: str = ""
command_type: str = "command" # "command" | "group" | "sub_command"
raw_command_name: str | None = None
current_fragment: str | None = None
parent_signature: str = ""
parent_group_handler: str = ""
original_command: str | None = None
effective_command: str | None = None
aliases: list[str] = field(default_factory=list)
permission: str = "everyone"
enabled: bool = True
is_group: bool = False
is_sub_command: bool = False
reserved: bool = False
config: CommandConfig | None = None
has_conflict: bool = False
sub_commands: list[CommandDescriptor] = field(default_factory=list)
async def sync_command_configs() -> None:
"""同步指令配置,清理过期配置。"""
descriptors = _collect_descriptors(include_sub_commands=False)
config_records = await db_helper.get_command_configs()
config_map = _bind_configs_to_descriptors(descriptors, config_records)
live_handlers = {desc.handler_full_name for desc in descriptors}
stale_configs = [key for key in config_map if key not in live_handlers]
if stale_configs:
await db_helper.delete_command_configs(stale_configs)
async def toggle_command(handler_full_name: str, enabled: bool) -> CommandDescriptor:
descriptor = _build_descriptor_by_full_name(handler_full_name)
if not descriptor:
raise ValueError("指定的处理函数不存在或不是指令。")
existing_cfg = await db_helper.get_command_config(handler_full_name)
config = await db_helper.upsert_command_config(
handler_full_name=handler_full_name,
plugin_name=descriptor.plugin_name or "",
module_path=descriptor.module_path,
original_command=descriptor.original_command or descriptor.handler_name,
resolved_command=(
existing_cfg.resolved_command
if existing_cfg
else descriptor.current_fragment
),
enabled=enabled,
keep_original_alias=False,
conflict_key=existing_cfg.conflict_key
if existing_cfg and existing_cfg.conflict_key
else descriptor.original_command,
resolution_strategy=existing_cfg.resolution_strategy if existing_cfg else None,
note=existing_cfg.note if existing_cfg else None,
extra_data=existing_cfg.extra_data if existing_cfg else None,
auto_managed=False,
)
_bind_descriptor_with_config(descriptor, config)
await sync_command_configs()
return descriptor
async def rename_command(
handler_full_name: str,
new_fragment: str,
aliases: list[str] | None = None,
) -> CommandDescriptor:
descriptor = _build_descriptor_by_full_name(handler_full_name)
if not descriptor:
raise ValueError("指定的处理函数不存在或不是指令。")
new_fragment = new_fragment.strip()
if not new_fragment:
raise ValueError("指令名不能为空。")
# 校验主指令名
candidate_full = _compose_command(descriptor.parent_signature, new_fragment)
if _is_command_in_use(handler_full_name, candidate_full):
raise ValueError(f"指令名 '{candidate_full}' 已被其他指令占用。")
# 校验别名
if aliases:
for alias in aliases:
alias = alias.strip()
if not alias:
continue
alias_full = _compose_command(descriptor.parent_signature, alias)
if _is_command_in_use(handler_full_name, alias_full):
raise ValueError(f"别名 '{alias_full}' 已被其他指令占用。")
existing_cfg = await db_helper.get_command_config(handler_full_name)
merged_extra = dict(existing_cfg.extra_data or {}) if existing_cfg else {}
merged_extra["resolved_aliases"] = aliases or []
config = await db_helper.upsert_command_config(
handler_full_name=handler_full_name,
plugin_name=descriptor.plugin_name or "",
module_path=descriptor.module_path,
original_command=descriptor.original_command or descriptor.handler_name,
resolved_command=new_fragment,
enabled=True if descriptor.enabled else False,
keep_original_alias=False,
conflict_key=descriptor.original_command,
resolution_strategy="manual_rename",
note=None,
extra_data=merged_extra,
auto_managed=False,
)
_bind_descriptor_with_config(descriptor, config)
await sync_command_configs()
return descriptor
async def list_commands() -> list[dict[str, Any]]:
descriptors = _collect_descriptors(include_sub_commands=True)
config_records = await db_helper.get_command_configs()
_bind_configs_to_descriptors(descriptors, config_records)
conflict_groups = _group_conflicts(descriptors)
conflict_handler_names: set[str] = {
d.handler_full_name for group in conflict_groups.values() for d in group
}
# 分类,设置冲突标志,将子指令挂载到父指令组
group_map: dict[str, CommandDescriptor] = {}
sub_commands: list[CommandDescriptor] = []
root_commands: list[CommandDescriptor] = []
for desc in descriptors:
desc.has_conflict = desc.handler_full_name in conflict_handler_names
if desc.is_group:
group_map[desc.handler_full_name] = desc
elif desc.is_sub_command:
sub_commands.append(desc)
else:
root_commands.append(desc)
for sub in sub_commands:
if sub.parent_group_handler and sub.parent_group_handler in group_map:
group_map[sub.parent_group_handler].sub_commands.append(sub)
else:
root_commands.append(sub)
# 指令组 + 普通指令,按 effective_command 字母排序
all_commands = list(group_map.values()) + root_commands
all_commands.sort(key=lambda d: (d.effective_command or "").lower())
result = [_descriptor_to_dict(desc) for desc in all_commands]
return result
async def list_command_conflicts() -> list[dict[str, Any]]:
"""列出所有冲突的指令组。"""
descriptors = _collect_descriptors(include_sub_commands=False)
config_records = await db_helper.get_command_configs()
_bind_configs_to_descriptors(descriptors, config_records)
conflict_groups = _group_conflicts(descriptors)
details = [
{
"conflict_key": key,
"handlers": [
{
"handler_full_name": item.handler_full_name,
"plugin": item.plugin_name,
"current_name": item.effective_command,
}
for item in group
],
}
for key, group in conflict_groups.items()
]
return details
# Internal helpers ----------------------------------------------------------
def _collect_descriptors(include_sub_commands: bool) -> list[CommandDescriptor]:
"""收集指令,按需包含子指令。"""
descriptors: list[CommandDescriptor] = []
for handler in star_handlers_registry:
try:
desc = _build_descriptor(handler)
if not desc:
continue
if not include_sub_commands and desc.is_sub_command:
continue
descriptors.append(desc)
except Exception as e:
logger.warning(
f"解析指令处理函数 {handler.handler_full_name} 失败,跳过该指令。原因: {e!s}"
)
continue
return descriptors
def _build_descriptor(handler: StarHandlerMetadata) -> CommandDescriptor | None:
filter_ref = _locate_primary_filter(handler)
if filter_ref is None:
return None
plugin_meta = star_map.get(handler.handler_module_path)
plugin_name = (
plugin_meta.name if plugin_meta else None
) or handler.handler_module_path
plugin_display = plugin_meta.display_name if plugin_meta else None
is_sub_command = bool(handler.extras_configs.get("sub_command"))
parent_group_handler = ""
if isinstance(filter_ref, CommandFilter):
raw_fragment = getattr(
filter_ref, "_original_command_name", filter_ref.command_name
)
current_fragment = filter_ref.command_name
parent_signature = (filter_ref.parent_command_names or [""])[0].strip()
# 如果是子指令,尝试找到父指令组的 handler_full_name
if is_sub_command and parent_signature:
parent_group_handler = _find_parent_group_handler(
handler.handler_module_path, parent_signature
)
else:
raw_fragment = getattr(
filter_ref, "_original_group_name", filter_ref.group_name
)
current_fragment = filter_ref.group_name
parent_signature = _resolve_group_parent_signature(filter_ref)
original_command = _compose_command(parent_signature, raw_fragment)
effective_command = _compose_command(parent_signature, current_fragment)
# 确定 command_type
if isinstance(filter_ref, CommandGroupFilter):
command_type = "group"
elif is_sub_command:
command_type = "sub_command"
else:
command_type = "command"
descriptor = CommandDescriptor(
handler=handler,
filter_ref=filter_ref,
handler_full_name=handler.handler_full_name,
handler_name=handler.handler_name,
plugin_name=plugin_name,
plugin_display_name=plugin_display,
module_path=handler.handler_module_path,
description=handler.desc or "",
command_type=command_type,
raw_command_name=raw_fragment,
current_fragment=current_fragment,
parent_signature=parent_signature,
parent_group_handler=parent_group_handler,
original_command=original_command,
effective_command=effective_command,
aliases=sorted(getattr(filter_ref, "alias", set())),
permission=_determine_permission(handler),
enabled=handler.enabled,
is_group=isinstance(filter_ref, CommandGroupFilter),
is_sub_command=is_sub_command,
reserved=plugin_meta.reserved if plugin_meta else False,
)
return descriptor
def _build_descriptor_by_full_name(full_name: str) -> CommandDescriptor | None:
handler = star_handlers_registry.get_handler_by_full_name(full_name)
if not handler:
return None
return _build_descriptor(handler)
def _locate_primary_filter(
handler: StarHandlerMetadata,
) -> CommandFilter | CommandGroupFilter | None:
for filter_ref in handler.event_filters:
if isinstance(filter_ref, (CommandFilter, CommandGroupFilter)):
return filter_ref
return None
def _determine_permission(handler: StarHandlerMetadata) -> str:
for filter_ref in handler.event_filters:
if isinstance(filter_ref, PermissionTypeFilter):
return (
"admin"
if filter_ref.permission_type == PermissionType.ADMIN
else "member"
)
return "everyone"
def _resolve_group_parent_signature(group_filter: CommandGroupFilter) -> str:
signatures: list[str] = []
parent = group_filter.parent_group
while parent:
signatures.append(getattr(parent, "_original_group_name", parent.group_name))
parent = parent.parent_group
return " ".join(reversed(signatures)).strip()
def _find_parent_group_handler(module_path: str, parent_signature: str) -> str:
"""根据模块路径和父级签名,找到对应的指令组 handler_full_name。"""
parent_sig_normalized = parent_signature.strip()
for handler in star_handlers_registry:
if handler.handler_module_path != module_path:
continue
filter_ref = _locate_primary_filter(handler)
if not isinstance(filter_ref, CommandGroupFilter):
continue
# 检查该指令组的完整指令名是否匹配 parent_signature
group_names = filter_ref.get_complete_command_names()
if parent_sig_normalized in group_names:
return handler.handler_full_name
return ""
def _compose_command(parent_signature: str, fragment: str | None) -> str:
fragment = (fragment or "").strip()
parent_signature = parent_signature.strip()
if not parent_signature:
return fragment
if not fragment:
return parent_signature
return f"{parent_signature} {fragment}"
def _bind_descriptor_with_config(
descriptor: CommandDescriptor,
config: CommandConfig,
) -> None:
_apply_config_to_descriptor(descriptor, config)
_apply_config_to_runtime(descriptor, config)
def _apply_config_to_descriptor(
descriptor: CommandDescriptor,
config: CommandConfig,
) -> None:
descriptor.config = config
descriptor.enabled = config.enabled
if config.original_command:
descriptor.original_command = config.original_command
new_fragment = config.resolved_command or descriptor.current_fragment
descriptor.current_fragment = new_fragment
descriptor.effective_command = _compose_command(
descriptor.parent_signature,
new_fragment,
)
extra = config.extra_data or {}
resolved_aliases = extra.get("resolved_aliases")
if isinstance(resolved_aliases, list):
descriptor.aliases = [str(x) for x in resolved_aliases if str(x).strip()]
def _apply_config_to_runtime(
descriptor: CommandDescriptor,
config: CommandConfig,
) -> None:
descriptor.handler.enabled = config.enabled
if descriptor.filter_ref:
if descriptor.current_fragment:
_set_filter_fragment(descriptor.filter_ref, descriptor.current_fragment)
extra = config.extra_data or {}
resolved_aliases = extra.get("resolved_aliases")
if isinstance(resolved_aliases, list):
_set_filter_aliases(
descriptor.filter_ref,
[str(x) for x in resolved_aliases if str(x).strip()],
)
def _bind_configs_to_descriptors(
descriptors: list[CommandDescriptor],
config_records: list[CommandConfig],
) -> dict[str, CommandConfig]:
config_map = {cfg.handler_full_name: cfg for cfg in config_records}
for desc in descriptors:
if cfg := config_map.get(desc.handler_full_name):
_bind_descriptor_with_config(desc, cfg)
return config_map
def _group_conflicts(
descriptors: list[CommandDescriptor],
) -> dict[str, list[CommandDescriptor]]:
conflicts: dict[str, list[CommandDescriptor]] = defaultdict(list)
for desc in descriptors:
if desc.effective_command and desc.enabled:
conflicts[desc.effective_command].append(desc)
return {k: v for k, v in conflicts.items() if len(v) > 1}
def _set_filter_fragment(
filter_ref: CommandFilter | CommandGroupFilter,
fragment: str,
) -> None:
attr = (
"group_name" if isinstance(filter_ref, CommandGroupFilter) else "command_name"
)
current_value = getattr(filter_ref, attr)
if fragment == current_value:
return
setattr(filter_ref, attr, fragment)
if hasattr(filter_ref, "_cmpl_cmd_names"):
filter_ref._cmpl_cmd_names = None
def _set_filter_aliases(
filter_ref: CommandFilter | CommandGroupFilter,
aliases: list[str],
) -> None:
current_aliases = getattr(filter_ref, "alias", set())
if set(aliases) == current_aliases:
return
setattr(filter_ref, "alias", set(aliases))
if hasattr(filter_ref, "_cmpl_cmd_names"):
filter_ref._cmpl_cmd_names = None
def _is_command_in_use(
target_handler_full_name: str,
candidate_full_command: str,
) -> bool:
candidate = candidate_full_command.strip()
for handler in star_handlers_registry:
if handler.handler_full_name == target_handler_full_name:
continue
filter_ref = _locate_primary_filter(handler)
if not filter_ref:
continue
names = {name.strip() for name in filter_ref.get_complete_command_names()}
if candidate in names:
return True
return False
def _descriptor_to_dict(desc: CommandDescriptor) -> dict[str, Any]:
result = {
"handler_full_name": desc.handler_full_name,
"handler_name": desc.handler_name,
"plugin": desc.plugin_name,
"plugin_display_name": desc.plugin_display_name,
"module_path": desc.module_path,
"description": desc.description,
"type": desc.command_type,
"parent_signature": desc.parent_signature,
"parent_group_handler": desc.parent_group_handler,
"original_command": desc.original_command,
"current_fragment": desc.current_fragment,
"effective_command": desc.effective_command,
"aliases": desc.aliases,
"permission": desc.permission,
"enabled": desc.enabled,
"is_group": desc.is_group,
"has_conflict": desc.has_conflict,
"reserved": desc.reserved,
}
# 如果是指令组,包含子指令列表
if desc.is_group and desc.sub_commands:
result["sub_commands"] = [_descriptor_to_dict(sub) for sub in desc.sub_commands]
else:
result["sub_commands"] = []
return result
+6 -2
View File
@@ -267,6 +267,10 @@ class Context:
): ):
"""通过 ID 获取对应的 LLM Provider。""" """通过 ID 获取对应的 LLM Provider。"""
prov = self.provider_manager.inst_map.get(provider_id) prov = self.provider_manager.inst_map.get(provider_id)
if provider_id and not prov:
logger.warning(
f"没有找到 ID 为 {provider_id} 的提供商,这可能是由于您修改了提供商(模型)ID 导致的。"
)
return prov return prov
def get_all_providers(self) -> list[Provider]: def get_all_providers(self) -> list[Provider]:
@@ -285,7 +289,7 @@ class Context:
"""获取所有用于 Embedding 任务的 Provider。""" """获取所有用于 Embedding 任务的 Provider。"""
return self.provider_manager.embedding_provider_insts return self.provider_manager.embedding_provider_insts
def get_using_provider(self, umo: str | None = None) -> Provider | None: def get_using_provider(self, umo: str | None = None) -> Provider:
"""获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。 """获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。
Args: Args:
@@ -296,7 +300,7 @@ class Context:
provider_type=ProviderType.CHAT_COMPLETION, provider_type=ProviderType.CHAT_COMPLETION,
umo=umo, umo=umo,
) )
if prov and not isinstance(prov, Provider): if not isinstance(prov, Provider):
raise ValueError("返回的 Provider 不是 Provider 类型") raise ValueError("返回的 Provider 不是 Provider 类型")
return prov return prov
+1
View File
@@ -40,6 +40,7 @@ class CommandFilter(HandlerFilter):
): ):
self.command_name = command_name self.command_name = command_name
self.alias = alias if alias else set() self.alias = alias if alias else set()
self._original_command_name = command_name
self.parent_command_names = ( self.parent_command_names = (
parent_command_names if parent_command_names is not None else [""] parent_command_names if parent_command_names is not None else [""]
) )
@@ -18,6 +18,7 @@ class CommandGroupFilter(HandlerFilter):
): ):
self.group_name = group_name self.group_name = group_name
self.alias = alias if alias else set() self.alias = alias if alias else set()
self._original_group_name = group_name
self.sub_command_filters: list[CommandFilter | CommandGroupFilter] = [] self.sub_command_filters: list[CommandFilter | CommandGroupFilter] = []
self.custom_filter_list: list[CustomFilter] = [] self.custom_filter_list: list[CustomFilter] = []
self.parent_group = parent_group self.parent_group = parent_group
+22 -5
View File
@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
import re import re
from collections.abc import Awaitable, Callable from collections.abc import AsyncGenerator, Awaitable, Callable
from typing import Any from typing import Any
import docstring_parser import docstring_parser
@@ -12,6 +12,7 @@ from astrbot.core.agent.handoff import HandoffTool
from astrbot.core.agent.hooks import BaseAgentRunHooks from astrbot.core.agent.hooks import BaseAgentRunHooks
from astrbot.core.agent.tool import FunctionTool from astrbot.core.agent.tool import FunctionTool
from astrbot.core.astr_agent_context import AstrAgentContext from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools from astrbot.core.provider.register import llm_tools
@@ -28,13 +29,19 @@ from ..filter.regex import RegexFilter
from ..star_handler import EventType, StarHandlerMetadata, star_handlers_registry from ..star_handler import EventType, StarHandlerMetadata, star_handlers_registry
def get_handler_full_name(awaitable: Callable[..., Awaitable[Any]]) -> str: def get_handler_full_name(
awaitable: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
) -> str:
"""获取 Handler 的全名""" """获取 Handler 的全名"""
return f"{awaitable.__module__}_{awaitable.__name__}" return f"{awaitable.__module__}_{awaitable.__name__}"
def get_handler_or_create( def get_handler_or_create(
handler: Callable[..., Awaitable[Any]], handler: Callable[
...,
Awaitable[MessageEventResult | str | None]
| AsyncGenerator[MessageEventResult | str | None],
],
event_type: EventType, event_type: EventType,
dont_add=False, dont_add=False,
**kwargs, **kwargs,
@@ -169,6 +176,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
for ( for (
sub_handle sub_handle
) in parent_register_commandable.parent_group.sub_command_filters: ) in parent_register_commandable.parent_group.sub_command_filters:
if isinstance(sub_handle, CommandGroupFilter):
continue
# 所有符合fullname一致的子指令handle添加自定义过滤器。 # 所有符合fullname一致的子指令handle添加自定义过滤器。
# 不确定是否会有多个子指令有一样的fullname,比如一个方法添加多个command装饰器? # 不确定是否会有多个子指令有一样的fullname,比如一个方法添加多个command装饰器?
sub_handle_md = sub_handle.get_handler_md() sub_handle_md = sub_handle.get_handler_md()
@@ -180,6 +189,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
else: else:
# 裸指令 # 裸指令
# 确保运行时是可调用的 handler,针对类型检查器添加忽略
assert isinstance(awaitable, Callable)
handler_md = get_handler_or_create( handler_md = get_handler_or_create(
awaitable, awaitable,
EventType.AdapterMessageEvent, EventType.AdapterMessageEvent,
@@ -237,7 +248,7 @@ class RegisteringCommandable:
group: Callable[..., Callable[..., RegisteringCommandable]] = register_command_group group: Callable[..., Callable[..., RegisteringCommandable]] = register_command_group
command: Callable[..., Callable[..., None]] = register_command command: Callable[..., Callable[..., None]] = register_command
custom_filter: Callable[..., Callable[..., None]] = register_custom_filter custom_filter: Callable[..., Callable[..., Any]] = register_custom_filter
def __init__(self, parent_group: CommandGroupFilter): def __init__(self, parent_group: CommandGroupFilter):
self.parent_group = parent_group self.parent_group = parent_group
@@ -412,7 +423,13 @@ def register_llm_tool(name: str | None = None, **kwargs):
if kwargs.get("registering_agent"): if kwargs.get("registering_agent"):
registering_agent = kwargs["registering_agent"] registering_agent = kwargs["registering_agent"]
def decorator(awaitable: Callable[..., Awaitable[Any]]): def decorator(
awaitable: Callable[
...,
AsyncGenerator[MessageEventResult | str | None]
| Awaitable[MessageEventResult | str | None],
],
):
llm_tool_name = name_ if name_ else awaitable.__name__ llm_tool_name = name_ if name_ else awaitable.__name__
func_doc = awaitable.__doc__ or "" func_doc = awaitable.__doc__ or ""
docstring = docstring_parser.parse(func_doc) docstring = docstring_parser.parse(func_doc)
+89 -4
View File
@@ -1,9 +1,9 @@
from __future__ import annotations from __future__ import annotations
import enum import enum
from collections.abc import Awaitable, Callable from collections.abc import AsyncGenerator, Awaitable, Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Generic, TypeVar from typing import Any, Generic, Literal, TypeVar, overload
from .filter import HandlerFilter from .filter import HandlerFilter
from .star import star_map from .star import star_map
@@ -29,6 +29,84 @@ class StarHandlerRegistry(Generic[T]):
for handler in self._handlers: for handler in self._handlers:
print(handler.handler_full_name) print(handler.handler_full_name)
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnAstrBotLoadedEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnPlatformLoadedEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.AdapterMessageEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnLLMRequestEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnLLMResponseEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnDecoratingResultEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnCallingFuncToolEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: Literal[EventType.OnAfterMessageSentEvent],
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ...
@overload
def get_handlers_by_event_type(
self,
event_type: EventType,
only_activated=True,
plugins_name: list[str] | None = None,
) -> list[
StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]]
]: ...
def get_handlers_by_event_type( def get_handlers_by_event_type(
self, self,
event_type: EventType, event_type: EventType,
@@ -40,6 +118,8 @@ class StarHandlerRegistry(Generic[T]):
# 过滤事件类型 # 过滤事件类型
if handler.event_type != event_type: if handler.event_type != event_type:
continue continue
if not handler.enabled:
continue
# 过滤启用状态 # 过滤启用状态
if only_activated: if only_activated:
plugin = star_map.get(handler.handler_module_path) plugin = star_map.get(handler.handler_module_path)
@@ -111,8 +191,11 @@ class EventType(enum.Enum):
OnAfterMessageSentEvent = enum.auto() # 发送消息后 OnAfterMessageSentEvent = enum.auto() # 发送消息后
H = TypeVar("H", bound=Callable[..., Any])
@dataclass @dataclass
class StarHandlerMetadata: class StarHandlerMetadata(Generic[H]):
"""描述一个 Star 所注册的某一个 Handler。""" """描述一个 Star 所注册的某一个 Handler。"""
event_type: EventType event_type: EventType
@@ -127,7 +210,7 @@ class StarHandlerMetadata:
handler_module_path: str handler_module_path: str
"""Handler 所在的模块路径。""" """Handler 所在的模块路径。"""
handler: Callable[..., Awaitable[Any]] handler: H
"""Handler 的函数对象,应当是一个异步函数""" """Handler 的函数对象,应当是一个异步函数"""
event_filters: list[HandlerFilter] event_filters: list[HandlerFilter]
@@ -139,6 +222,8 @@ class StarHandlerMetadata:
extras_configs: dict = field(default_factory=dict) extras_configs: dict = field(default_factory=dict)
"""插件注册的一些其他的信息, 如 priority 等""" """插件注册的一些其他的信息, 如 priority 等"""
enabled: bool = True
def __lt__(self, other: StarHandlerMetadata): def __lt__(self, other: StarHandlerMetadata):
"""定义小于运算符以支持优先队列""" """定义小于运算符以支持优先队列"""
return self.extras_configs.get("priority", 0) < other.extras_configs.get( return self.extras_configs.get("priority", 0) < other.extras_configs.get(
+18
View File
@@ -23,6 +23,7 @@ from astrbot.core.utils.astrbot_path import (
from astrbot.core.utils.io import remove_dir from astrbot.core.utils.io import remove_dir
from . import StarMetadata from . import StarMetadata
from .command_management import sync_command_configs
from .context import Context from .context import Context
from .filter.permission import PermissionType, PermissionTypeFilter from .filter.permission import PermissionType, PermissionTypeFilter
from .star import star_map, star_registry from .star import star_map, star_registry
@@ -467,6 +468,18 @@ class PluginManager:
metadata.star_cls = metadata.star_cls_type( metadata.star_cls = metadata.star_cls_type(
context=self.context, context=self.context,
) )
p_name = (metadata.name or "unknown").lower().replace("/", "_")
p_author = (
(metadata.author or "unknown").lower().replace("/", "_")
)
setattr(metadata.star_cls, "name", p_name)
setattr(metadata.star_cls, "author", p_author)
setattr(
metadata.star_cls,
"plugin_id",
f"{p_author}/{p_name}",
)
else: else:
logger.info(f"插件 {metadata.name} 已被禁用。") logger.info(f"插件 {metadata.name} 已被禁用。")
@@ -618,6 +631,11 @@ class PluginManager:
# 清除 pip.main 导致的多余的 logging handlers # 清除 pip.main 导致的多余的 logging handlers
for handler in logging.root.handlers[:]: for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler) logging.root.removeHandler(handler)
try:
await sync_command_configs()
except Exception as e:
logger.error(f"同步指令配置失败: {e!s}")
logger.error(traceback.format_exc())
if not fail_rec: if not fail_rec:
return True, None return True, None

Some files were not shown because too many files have changed in this diff Show More