diff --git a/.github/workflows/smoke_test.yml b/.github/workflows/smoke_test.yml new file mode 100644 index 000000000..15571867f --- /dev/null +++ b/.github/workflows/smoke_test.yml @@ -0,0 +1,58 @@ +name: Smoke Test + +on: + push: + branches: + - master + paths-ignore: + - 'README*.md' + - 'changelogs/**' + - 'dashboard/**' + pull_request: + workflow_dispatch: + +jobs: + smoke-test: + name: Run smoke tests + runs-on: ubuntu-latest + timeout-minutes: 10 + + steps: + - name: Checkout + uses: actions/checkout@v6 + with: + fetch-depth: 0 + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: '3.12' + + - name: Install UV package manager + run: | + pip install uv + + - name: Install dependencies + run: | + uv sync + timeout-minutes: 15 + + - name: Run smoke tests + run: | + uv run main.py & + APP_PID=$! + + echo "Waiting for application to start..." + for i in {1..60}; do + if curl -f http://localhost:6185 > /dev/null 2>&1; then + echo "Application started successfully!" + kill $APP_PID + exit 0 + fi + sleep 1 + done + + echo "Application failed to start within 30 seconds" + kill $APP_PID 2>/dev/null || true + exit 1 + timeout-minutes: 2 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 000000000..1e261bfa3 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,65 @@ +# CONTRIBUTING + +## 贡献指南 + +首先,感谢您花时间做出贡献!❤️ + +所有类型的贡献都受到鼓励和重视。有关不同的帮助方式和处理方式的详细信息,请参阅[目录](#目录)。在做出贡献之前,请确保阅读相关部分。这将使我们维护人员的工作变得更加容易,并为所有参与者带来顺畅的体验。社区期待您的贡献。🎉 + +### 目录 + +- [报告问题](#报告问题) +- [提交代码更改](#提交代码更改) + +### 报告问题 + +如果您在使用 AstrBot 时遇到任何问题,请按照以下步骤报告: + +1. **检查现有问题**:在提交新问题之前,请先检查 [Issues](https://github.com/AstrBotDevs/AstrBot/issues) 中是否已经存在类似的问题。 +2. **创建新问题**:如果没有类似的问题,请创建一个新问题。请确保提供以下信息: + - 问题的简要描述 + - 重现问题的步骤 + - 预期结果和实际结果 + - 相关日志或错误消息 + +### 提交代码更改 + +#### 分支命名 + +我们使用 `fix/` 前缀来修复错误,使用 `feat/` 前缀来添加新功能。对于 `fix/` 分支,请使用简短的描述,或者直接使用 Issue 编号。例如:`fix/1234` 或者 `fix/1234-login-typo`。对于 `feat/` 分支,请使用简短的描述,例如:`feat/add-user-profile`。 + +#### PR 描述 + +- 请使用英文描述您的 PR。 +- 标题请使用 `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` 等语义化前缀,并简要描述更改内容。如:`fix: correct login page typo`。 + +## Contributing Guide + +First off, thanks for taking the time to contribute! ❤️ + +All types of contributions are encouraged and valued. See the [Table of Contents](#table-of-contents) for different ways to help and details about how this project handles them. Please make sure to read the relevant section before making your contribution. It will make it a lot easier for us maintainers and smooth out the experience for all involved. The community looks forward to your contributions. 🎉 + +### Table of Contents + +- [Reporting Issues](#reporting-issues) +- [Pull Requests](#pull-requests) + +### Reporting Issues + +If you encounter any issues while using AstrBot, please follow these steps to report them: +1. **Check Existing Issues**: Before submitting a new issue, please check if a similar issue already exists in the [Issues](https://github.com/AstrBotDevs/AstrBot/issues) section of the repository. +2. **Create a New Issue**: If no similar issue exists, please create a new issue. Make sure to provide the following information: + - A brief description of the issue + - Steps to reproduce the issue + - Expected and actual results + - Relevant logs or error messages + +### Pull Requests + +#### Branch Naming + +We use the `fix/` prefix for bug fixes and the `feat/` prefix for new features. For `fix/` branches, please use a short description or directly use the Issue number, e.g., `fix/1234` or `fix/1234-login-typo`. For `feat/` branches, please use a short description, e.g., `feat/add-user-profile`. + +#### PR Description +- Please use English to describe your PR. +- Use semantic prefixes like `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` in the title, followed by a brief description of the changes, e.g., `fix: correct login page typo`. \ No newline at end of file diff --git a/README.md b/README.md index 6d60f80a3..a45fa1fc5 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,13 @@ ![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9) -

-
-
+ +English | +日本語 | +繁體中文 | +Français | +Русский
Soulter%2FAstrBot | Trendshift @@ -14,35 +17,38 @@
- -python -Docker pull -QQ_community -Telegram_community - + +python + +zread +Docker pull + +

-English | -日本語文档Blog路线图问题提交
-AstrBot 是一个开源的一站式 Agent 聊天机器人平台,可无缝接入主流即时通讯软件,为个人、开发者和团队打造可靠、可扩展的对话式智能基础设施。无论是个人 AI 伙伴、智能客服、自动化助手,还是企业知识库,AstrBot 都能在你的即时通讯软件平台的工作流中快速构建生产可用的 AI 应用。 +AstrBot 是一个开源的一站式 Agent 聊天机器人平台,可接入主流即时通讯软件,为个人、开发者和团队打造可靠、可扩展的对话式智能基础设施。无论是个人 AI 伙伴、智能客服、自动化助手,还是企业知识库,AstrBot 都能在你的即时通讯软件平台的工作流中快速构建生产可用的 AI 应用。 + +image ## 主要功能 -1. **大模型对话**。支持接入多种大模型服务。支持多模态、工具调用、MCP、原生知识库、人设等功能。 -2. **多消息平台支持**。支持接入 QQ、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK 等平台。支持速率限制、白名单、百度内容审核。 -3. **Agent**。完善适配的 Agentic 能力。支持多轮工具调用、内置沙盒代码执行器、网页搜索等功能。 -4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富。 -5. **WebUI**。可视化配置和管理机器人,功能齐全。 +1. 💯 免费 & 开源。 +1. ✨ AI 大模型对话,多模态,Agent,MCP,知识库,人格设定。 +2. 🤖 支持接入 Dify、阿里云百炼、Coze 等智能体平台。 +2. 🌐 多平台,支持 QQ、企业微信、飞书、钉钉、微信公众号、Telegram、Slack 以及[更多](#支持的消息平台)。 +3. 📦 插件扩展,已有近 800 个插件可一键安装。 +5. 💻 WebUI 支持。 +6. 🌐 国际化(i18n)支持。 -## 部署方式 +## 快速开始 #### Docker 部署(推荐 🥳) @@ -50,6 +56,12 @@ AstrBot 是一个开源的一站式 Agent 聊天机器人平台,可无缝接 请参阅官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) 。 +#### uv 部署 + +```bash +uvx astrbot +``` + #### 宝塔面板部署 AstrBot 与宝塔面板合作,已上架至宝塔面板。 @@ -101,24 +113,6 @@ uv run main.py 或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。 -## 🌍 社区 - -### QQ 群组 - -- 1 群:322154837 -- 3 群:630166526 -- 5 群:822130018 -- 6 群:753075035 -- 开发者群:975206796 - -### Telegram 群组 - -Telegram_community - -### Discord 群组 - -Discord_community - ## 支持的消息平台 **官方维护** @@ -205,6 +199,25 @@ pip install pre-commit pre-commit install ``` +## 🌍 社区 + +### QQ 群组 + +- 1 群:322154837 +- 3 群:630166526 +- 5 群:822130018 +- 6 群:753075035 +- 7 群:743746109 +- 开发者群:975206796 + +### Telegram 群组 + +Telegram_community + +### Discord 群组 + +Discord_community + ## ❤️ Special Thanks 特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️ diff --git a/README_en.md b/README_en.md index 520cfebe8..c5bc86593 100644 --- a/README_en.md +++ b/README_en.md @@ -19,30 +19,38 @@ Docker pull QQ_community Telegram_community - +

中文日本語 | +繁體中文 | +Français | +Русский + DocumentationBlogRoadmapIssue Tracker -AstrBot is an open-source all-in-one Agent chatbot platform and development framework. +AstrBot is an open-source all-in-one Agent chatbot platform that integrates with mainstream instant messaging apps. It provides reliable and scalable conversational AI infrastructure for individuals, developers, and teams. Whether you're building a personal AI companion, intelligent customer service, automation assistant, or enterprise knowledge base, AstrBot enables you to quickly build production-ready AI applications within your IM platform workflows. + +image ## Key Features -1. **LLM Conversations**. Supports integration with various large language model services. Features include multimodal capabilities, tool calling, MCP, native knowledge base, character personas, and more. -2. **Multi-Platform Support**. Integrates with QQ, WeChat Work, WeChat Official Accounts, Feishu, Telegram, DingTalk, Discord, KOOK, and other platforms. Supports rate limiting, whitelisting, and Baidu content moderation. -3. **Agent Capabilities**. Fully optimized agentic features including multi-turn tool calling, built-in sandboxed code executor, web search, and more. -4. **Plugin Extensions**. Deeply optimized plugin mechanism supporting [plugin development](https://astrbot.app/dev/plugin.html) to extend functionality, with a rich community plugin ecosystem. -5. **Web UI**. Visual configuration and management of your bot with comprehensive features. +1. 💯 Free & Open Source. +2. ✨ AI LLM Conversations, Multimodal, Agent, MCP, Knowledge Base, Persona Settings. +3. 🤖 Supports integration with Dify, Alibaba Cloud Bailian, Coze and other agent platforms. +4. 🌐 Multi-Platform: QQ, WeChat Work, Feishu, DingTalk, WeChat Official Accounts, Telegram, Slack, and [more](#supported-messaging-platforms). +5. 📦 Plugin Extensions with nearly 800 plugins available for one-click installation. +6. 💻 WebUI Support. +7. 🌐 Internationalization (i18n) Support. -## Deployment Methods +## Quick Start #### Docker Deployment (Recommended 🥳) @@ -50,6 +58,12 @@ We recommend deploying AstrBot using Docker or Docker Compose. Please refer to the official documentation: [Deploy AstrBot with Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot). +#### uv Deployment + +```bash +uvx astrbot +``` + #### BT-Panel Deployment AstrBot has partnered with BT-Panel and is now available in their marketplace. @@ -101,24 +115,6 @@ uv run main.py Or refer to the official documentation: [Deploy AstrBot from Source](https://astrbot.app/deploy/astrbot/cli.html). -## 🌍 Community - -### QQ Groups - -- Group 1: 322154837 -- Group 3: 630166526 -- Group 5: 822130018 -- Group 6: 753075035 -- Developer Group: 975206796 - -### Telegram Group - -Telegram_community - -### Discord Server - -Discord_community - ## Supported Messaging Platforms **Officially Maintained** @@ -205,6 +201,24 @@ pip install pre-commit pre-commit install ``` +## 🌍 Community + +### QQ Groups + +- Group 1: 322154837 +- Group 3: 630166526 +- Group 5: 822130018 +- Group 6: 753075035 +- Developer Group: 975206796 + +### Telegram Group + +Telegram_community + +### Discord Server + +Discord_community + ## ❤️ Special Thanks Special thanks to all Contributors and plugin developers for their contributions to AstrBot ❤️ diff --git a/README_fr.md b/README_fr.md new file mode 100644 index 000000000..8f658c9a0 --- /dev/null +++ b/README_fr.md @@ -0,0 +1,248 @@ +![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9) + +

+ +
+ +
+ +
+Soulter%2FAstrBot | Trendshift +Featured|HelloGitHub +
+ +
+ +
+ +python +Docker pull +QQ_community +Telegram_community + +
+ +
+ +中文 | +English | +日本語 | +繁體中文 | +Русский + +Documentation | +Blog | +Feuille de route | +Signaler un problème +
+ +AstrBot est une plateforme de chatbot Agent tout-en-un open source qui s'intègre aux principales applications de messagerie instantanée. Elle fournit une infrastructure d'IA conversationnelle fiable et évolutive pour les particuliers, les développeurs et les équipes. Que vous construisiez un compagnon IA personnel, un service client intelligent, un assistant d'automatisation ou une base de connaissances d'entreprise, AstrBot vous permet de créer rapidement des applications d'IA prêtes pour la production dans les flux de travail de votre plateforme de messagerie. + +image + +## Fonctionnalités principales + +1. 💯 Gratuit & Open Source. +2. ✨ Conversations avec LLM IA, Multimodal, Agent, MCP, Base de connaissances, Paramètres de personnalité. +3. 🤖 Prise en charge de l'intégration avec Dify, Alibaba Cloud Bailian, Coze et autres plateformes d'agents. +4. 🌐 Multi-plateforme : QQ, WeChat Work, Feishu, DingTalk, Comptes officiels WeChat, Telegram, Slack, et [plus encore](#plateformes-de-messagerie-prises-en-charge). +5. 📦 Extensions de plugins avec près de 800 plugins disponibles pour une installation en un clic. +6. 💻 Support WebUI. +7. 🌐 Support de l'internationalisation (i18n). + +## Démarrage rapide + +#### Déploiement Docker (Recommandé 🥳) + +Nous recommandons de déployer AstrBot en utilisant Docker ou Docker Compose. + +Veuillez consulter la documentation officielle : [Déployer AstrBot avec Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot). + +#### Déploiement uv + +```bash +uvx astrbot +``` + +#### Déploiement BT-Panel + +AstrBot s'est associé à BT-Panel et est maintenant disponible sur leur marketplace. + +Veuillez consulter la documentation officielle : [Déploiement BT-Panel](https://astrbot.app/deploy/astrbot/btpanel.html). + +#### Déploiement 1Panel + +AstrBot a été officiellement listé sur le marketplace 1Panel. + +Veuillez consulter la documentation officielle : [Déploiement 1Panel](https://astrbot.app/deploy/astrbot/1panel.html). + +#### Déployer sur RainYun + +AstrBot a été officiellement listé sur la plateforme d'applications cloud de RainYun avec un déploiement en un clic. + +[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0) + +#### Déployer sur Replit + +Méthode de déploiement contribuée par la communauté. + +[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot) + +#### Installateur Windows en un clic + +Veuillez consulter la documentation officielle : [Déployer AstrBot avec l'installateur Windows en un clic](https://astrbot.app/deploy/astrbot/windows.html). + +#### Déploiement CasaOS + +Méthode de déploiement contribuée par la communauté. + +Veuillez consulter la documentation officielle : [Déploiement CasaOS](https://astrbot.app/deploy/astrbot/casaos.html). + +#### Déploiement manuel + +Tout d'abord, installez uv : + +```bash +pip install uv +``` + +Installez AstrBot via Git Clone : + +```bash +git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot +uv run main.py +``` + +Ou consultez la documentation officielle : [Déployer AstrBot depuis les sources](https://astrbot.app/deploy/astrbot/cli.html). + +## Plateformes de messagerie prises en charge + +**Maintenues officiellement** + +- QQ (Plateforme officielle & OneBot) +- Telegram +- Application WeChat Work & Bot intelligent WeChat Work +- Service client WeChat & Comptes officiels WeChat +- Feishu (Lark) +- DingTalk +- Slack +- Discord +- Satori +- Misskey +- WhatsApp (Bientôt disponible) +- LINE (Bientôt disponible) + +**Maintenues par la communauté** + +- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) +- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) +- [Messages directs Bilibili](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter) +- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11) + +## Services de modèles pris en charge + +**Services LLM** + +- OpenAI et services compatibles +- Anthropic +- Google Gemini +- Moonshot AI +- Zhipu AI +- DeepSeek +- Ollama (Auto-hébergé) +- LM Studio (Auto-hébergé) +- [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) +- [302.AI](https://share.302.ai/rr1M3l) +- [TokenPony](https://www.tokenpony.cn/3YPyf) +- [SiliconFlow](https://docs.siliconflow.cn/cn/usecases/use-siliconcloud-in-astrbot) +- [PPIO Cloud](https://ppio.com/user/register?invited_by=AIOONE) +- ModelScope +- OneAPI + +**Plateformes LLMOps** + +- Dify +- Applications Alibaba Cloud Bailian +- Coze + +**Services de reconnaissance vocale** + +- OpenAI Whisper +- SenseVoice + +**Services de synthèse vocale** + +- OpenAI TTS +- Gemini TTS +- GPT-Sovits-Inference +- GPT-Sovits +- FishAudio +- Edge TTS +- Alibaba Cloud Bailian TTS +- Azure TTS +- Minimax TTS +- Volcano Engine TTS + +## ❤️ Contribuer + +Les Issues et Pull Requests sont toujours les bienvenues ! N'hésitez pas à soumettre vos modifications à ce projet :) + +### Comment contribuer + +Vous pouvez contribuer en examinant les issues ou en aidant à la revue des pull requests. Toutes les issues ou PRs sont les bienvenues pour encourager la participation de la communauté. Bien sûr, ce ne sont que des suggestions - vous pouvez contribuer de la manière que vous souhaitez. Pour l'ajout de nouvelles fonctionnalités, veuillez d'abord en discuter via une Issue. + +### Environnement de développement + +AstrBot utilise `ruff` pour le formatage et le linting du code. + +```bash +git clone https://github.com/AstrBotDevs/AstrBot +pip install pre-commit +pre-commit install +``` + +## 🌍 Communauté + +### Groupes QQ + +- Groupe 1 : 322154837 +- Groupe 3 : 630166526 +- Groupe 5 : 822130018 +- Groupe 6 : 753075035 +- Groupe développeurs : 975206796 + +### Groupe Telegram + +Telegram_community + +### Serveur Discord + +Discord_community + +## ❤️ Remerciements spéciaux + +Un grand merci à tous les contributeurs et développeurs de plugins pour leurs contributions à AstrBot ❤️ + + + + + +De plus, la naissance de ce projet n'aurait pas été possible sans l'aide des projets open source suivants : + +- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - L'incroyable framework chat + +## ⭐ Historique des étoiles + +> [!TIP] +> Si ce projet vous a aidé dans votre vie ou votre travail, ou si vous êtes intéressé par son développement futur, veuillez donner une étoile au projet. C'est la force motrice derrière la maintenance de ce projet open source <3 + +
+ +[![Star History Chart](https://api.star-history.com/svg?repos=astrbotdevs/astrbot&type=Date)](https://star-history.com/#astrbotdevs/astrbot&Date) + +
+ + + +_私は、高性能ですから!_ + diff --git a/README_ja.md b/README_ja.md index 3fce01878..d94bf83b7 100644 --- a/README_ja.md +++ b/README_ja.md @@ -19,30 +19,38 @@ Docker pull QQ_community Telegram_community - +
中文English | +繁體中文 | +Français | +Русский + ドキュメントBlogロードマップIssue -AstrBot は、オープンソースのオールインワン Agent チャットボットプラットフォーム及び開発フレームワークです。 +AstrBot は、主要なインスタントメッセージングアプリと統合できるオープンソースのオールインワン Agent チャットボットプラットフォームです。個人、開発者、チームに信頼性が高くスケーラブルな会話型 AI インフラストラクチャを提供します。パーソナル AI コンパニオン、インテリジェントカスタマーサービス、オートメーションアシスタント、エンタープライズナレッジベースなど、AstrBot を使用すると、IM プラットフォームのワークフロー内で本番環境対応の AI アプリケーションを迅速に構築できます。 + +image ## 主な機能 -1. **大規模言語モデル対話**。多様な大規模言語モデルサービスとの統合をサポート。マルチモーダル、ツール呼び出し、MCP、ネイティブナレッジベース、キャラクター設定などの機能を搭載。 -2. **マルチメッセージプラットフォームサポート**。QQ、WeChat Work、WeChat公式アカウント、Feishu、Telegram、DingTalk、Discord、KOOK などのプラットフォームと統合可能。レート制限、ホワイトリスト、Baidu コンテンツ審査をサポート。 -3. **Agent**。完全に最適化された Agentic 機能。マルチターンツール呼び出し、内蔵サンドボックスコード実行環境、Web 検索などの機能をサポート。 -4. **プラグイン拡張**。深く最適化されたプラグインメカニズムで、[プラグイン開発](https://astrbot.app/dev/plugin.html)による機能拡張をサポート。豊富なコミュニティプラグインエコシステム。 -5. **WebUI**。ビジュアル設定とボット管理、充実した機能。 +1. 💯 無料 & オープンソース。 +2. ✨ AI 大規模言語モデル対話、マルチモーダル、Agent、MCP、ナレッジベース、ペルソナ設定。 +3. 🤖 Dify、Alibaba Cloud 百炼、Coze などの Agent プラットフォームとの統合をサポート。 +4. 🌐 マルチプラットフォーム:QQ、WeChat Work、Feishu、DingTalk、WeChat 公式アカウント、Telegram、Slack、[その他](#サポートされているメッセージプラットフォーム)。 +5. 📦 約800個のプラグインをワンクリックでインストール可能なプラグイン拡張機能。 +6. 💻 WebUI サポート。 +7. 🌐 国際化(i18n)サポート。 -## デプロイ方法 +## クイックスタート #### Docker デプロイ(推奨 🥳) @@ -50,6 +58,12 @@ Docker / Docker Compose を使用した AstrBot のデプロイを推奨しま 公式ドキュメント [Docker を使用した AstrBot のデプロイ](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) をご参照ください。 +#### uv デプロイ + +```bash +uvx astrbot +``` + #### 宝塔パネルデプロイ AstrBot は宝塔パネルと提携し、宝塔パネルに公開されています。 @@ -101,24 +115,6 @@ uv run main.py または、公式ドキュメント [ソースコードから AstrBot をデプロイ](https://astrbot.app/deploy/astrbot/cli.html) をご参照ください。 -## 🌍 コミュニティ - -### QQ グループ - -- 1群:322154837 -- 3群:630166526 -- 5群:822130018 -- 6群:753075035 -- 開発者群:975206796 - -### Telegram グループ - -Telegram_community - -### Discord サーバー - -Discord_community - ## サポートされているメッセージプラットフォーム **公式メンテナンス** @@ -205,6 +201,24 @@ pip install pre-commit pre-commit install ``` +## 🌍 コミュニティ + +### QQ グループ + +- 1群: 322154837 +- 3群: 630166526 +- 5群: 822130018 +- 6群: 753075035 +- 開発者群: 975206796 + +### Telegram グループ + +Telegram_community + +### Discord サーバー + +Discord_community + ## ❤️ Special Thanks AstrBot への貢献をしていただいたすべてのコントリビューターとプラグイン開発者に特別な感謝を ❤️ diff --git a/README_ru.md b/README_ru.md new file mode 100644 index 000000000..ea8e9b6bf --- /dev/null +++ b/README_ru.md @@ -0,0 +1,248 @@ +![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9) + +

+ +
+ +
+ +
+Soulter%2FAstrBot | Trendshift +Featured|HelloGitHub +
+ +
+ +
+ +python +Docker pull +QQ_community +Telegram_community + +
+ +
+ +中文 | +English | +日本語 | +繁體中文 | +Français + +Документация | +Блог | +Дорожная карта | +Сообщить о проблеме +
+ +AstrBot — это универсальная платформа Agent-чатботов с открытым исходным кодом, которая интегрируется с основными приложениями для обмена мгновенными сообщениями. Она предоставляет надёжную и масштабируемую инфраструктуру разговорного ИИ для частных лиц, разработчиков и команд. Будь то персональный ИИ-компаньон, интеллектуальная служба поддержки, автоматизированный помощник или корпоративная база знаний — AstrBot позволяет быстро создавать готовые к использованию ИИ-приложения в рабочих процессах вашей платформы обмена сообщениями. + +image + +## Основные возможности + +1. 💯 Бесплатно и с открытым исходным кодом. +2. ✨ ИИ-диалоги с LLM, мультимодальность, Agent, MCP, база знаний, настройки личности. +3. 🤖 Поддержка интеграции с Dify, Alibaba Cloud Bailian, Coze и другими платформами агентов. +4. 🌐 Мультиплатформенность: QQ, WeChat Work, Feishu, DingTalk, официальные аккаунты WeChat, Telegram, Slack и [другие](#поддерживаемые-платформы-обмена-сообщениями). +5. 📦 Расширения плагинов с почти 800 плагинами, доступными для установки в один клик. +6. 💻 Поддержка WebUI. +7. 🌐 Поддержка интернационализации (i18n). + +## Быстрый старт + +#### Развёртывание Docker (Рекомендуется 🥳) + +Мы рекомендуем развёртывать AstrBot с помощью Docker или Docker Compose. + +См. официальную документацию: [Развёртывание AstrBot с Docker](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot). + +#### Развёртывание uv + +```bash +uvx astrbot +``` + +#### Развёртывание BT-Panel + +AstrBot в партнёрстве с BT-Panel теперь доступен на их маркетплейсе. + +См. официальную документацию: [Развёртывание BT-Panel](https://astrbot.app/deploy/astrbot/btpanel.html). + +#### Развёртывание 1Panel + +AstrBot официально размещён на маркетплейсе 1Panel. + +См. официальную документацию: [Развёртывание 1Panel](https://astrbot.app/deploy/astrbot/1panel.html). + +#### Развёртывание на RainYun + +AstrBot официально размещён на облачной платформе приложений RainYun с развёртыванием в один клик. + +[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0) + +#### Развёртывание на Replit + +Метод развёртывания от сообщества. + +[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot) + +#### Установщик Windows в один клик + +См. официальную документацию: [Развёртывание AstrBot с установщиком Windows в один клик](https://astrbot.app/deploy/astrbot/windows.html). + +#### Развёртывание CasaOS + +Метод развёртывания от сообщества. + +См. официальную документацию: [Развёртывание CasaOS](https://astrbot.app/deploy/astrbot/casaos.html). + +#### Ручное развёртывание + +Сначала установите uv: + +```bash +pip install uv +``` + +Установите AstrBot через Git Clone: + +```bash +git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot +uv run main.py +``` + +Или см. официальную документацию: [Развёртывание AstrBot из исходного кода](https://astrbot.app/deploy/astrbot/cli.html). + +## Поддерживаемые платформы обмена сообщениями + +**Официально поддерживаемые** + +- QQ (Официальная платформа и OneBot) +- Telegram +- Приложение WeChat Work и интеллектуальный бот WeChat Work +- Служба поддержки WeChat и официальные аккаунты WeChat +- Feishu (Lark) +- DingTalk +- Slack +- Discord +- Satori +- Misskey +- WhatsApp (Скоро) +- LINE (Скоро) + +**Поддерживаемые сообществом** + +- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) +- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) +- [Личные сообщения Bilibili](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter) +- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11) + +## Поддерживаемые сервисы моделей + +**Сервисы LLM** + +- OpenAI и совместимые сервисы +- Anthropic +- Google Gemini +- Moonshot AI +- Zhipu AI +- DeepSeek +- Ollama (Самостоятельное размещение) +- LM Studio (Самостоятельное размещение) +- [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) +- [302.AI](https://share.302.ai/rr1M3l) +- [TokenPony](https://www.tokenpony.cn/3YPyf) +- [SiliconFlow](https://docs.siliconflow.cn/cn/usecases/use-siliconcloud-in-astrbot) +- [PPIO Cloud](https://ppio.com/user/register?invited_by=AIOONE) +- ModelScope +- OneAPI + +**Платформы LLMOps** + +- Dify +- Приложения Alibaba Cloud Bailian +- Coze + +**Сервисы распознавания речи** + +- OpenAI Whisper +- SenseVoice + +**Сервисы синтеза речи** + +- OpenAI TTS +- Gemini TTS +- GPT-Sovits-Inference +- GPT-Sovits +- FishAudio +- Edge TTS +- Alibaba Cloud Bailian TTS +- Azure TTS +- Minimax TTS +- Volcano Engine TTS + +## ❤️ Вклад в проект + +Issues и Pull Request всегда приветствуются! Не стесняйтесь отправлять свои изменения в этот проект :) + +### Как внести вклад + +Вы можете внести вклад, просматривая issues или помогая с ревью pull request. Любые issues или PR приветствуются для поощрения участия сообщества. Конечно, это лишь предложения — вы можете вносить вклад любым удобным для вас способом. Для добавления новых функций сначала обсудите это через Issue. + +### Среда разработки + +AstrBot использует `ruff` для форматирования и линтинга кода. + +```bash +git clone https://github.com/AstrBotDevs/AstrBot +pip install pre-commit +pre-commit install +``` + +## 🌍 Сообщество + +### Группы QQ + +- Группа 1: 322154837 +- Группа 3: 630166526 +- Группа 5: 822130018 +- Группа 6: 753075035 +- Группа разработчиков: 975206796 + +### Группа Telegram + +Telegram_community + +### Сервер Discord + +Discord_community + +## ❤️ Особая благодарность + +Особая благодарность всем контрибьюторам и разработчикам плагинов за их вклад в AstrBot ❤️ + + + + + +Кроме того, рождение этого проекта было бы невозможно без помощи следующих проектов с открытым исходным кодом: + +- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - Замечательный кошачий фреймворк + +## ⭐ История звёзд + +> [!TIP] +> Если этот проект помог вам в жизни или работе, или если вас интересует его будущее развитие, пожалуйста, поставьте проекту звезду. Это движущая сила поддержки этого проекта с открытым исходным кодом <3 + +
+ +[![Star History Chart](https://api.star-history.com/svg?repos=astrbotdevs/astrbot&type=Date)](https://star-history.com/#astrbotdevs/astrbot&Date) + +
+ + + +_私は、高性能ですから!_ + diff --git a/README_zh-TW.md b/README_zh-TW.md new file mode 100644 index 000000000..5f77ab7ce --- /dev/null +++ b/README_zh-TW.md @@ -0,0 +1,248 @@ +![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9) + +

+ +
+ +
+ +
+Soulter%2FAstrBot | Trendshift +Featured|HelloGitHub +
+ +
+ +
+ +python +Docker pull +QQ_community +Telegram_community + +
+ +
+ +简体中文 | +English | +日本語 | +Français | +Русский + +文件 | +Blog | +路線圖 | +問題回報 +
+ +AstrBot 是一個開源的一站式 Agent 聊天機器人平台,可接入主流即時通訊軟體,為個人、開發者和團隊打造可靠、可擴展的對話式智慧基礎設施。無論是個人 AI 夥伴、智慧客服、自動化助手,還是企業知識庫,AstrBot 都能在您的即時通訊軟體平台的工作流程中快速構建生產可用的 AI 應用程式。 + +image + +## 主要功能 + +1. 💯 免費 & 開源。 +2. ✨ AI 大型模型對話,多模態,Agent,MCP,知識庫,人格設定。 +3. 🤖 支援接入 Dify、阿里雲百煉、Coze 等智慧體平台。 +4. 🌐 多平台:QQ、企業微信、飛書、釘釘、微信公眾號、Telegram、Slack 以及[更多](#支援的訊息平台)。 +5. 📦 外掛擴充,已有近 800 個外掛可一鍵安裝。 +6. 💻 WebUI 支援。 +7. 🌐 國際化(i18n)支援。 + +## 快速開始 + +#### Docker 部署(推薦 🥳) + +推薦使用 Docker / Docker Compose 方式部署 AstrBot。 + +請參閱官方文件 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot)。 + +#### uv 部署 + +```bash +uvx astrbot +``` + +#### 寶塔面板部署 + +AstrBot 與寶塔面板合作,已上架至寶塔面板。 + +請參閱官方文件 [寶塔面板部署](https://astrbot.app/deploy/astrbot/btpanel.html)。 + +#### 1Panel 部署 + +AstrBot 已由 1Panel 官方上架至 1Panel 面板。 + +請參閱官方文件 [1Panel 部署](https://astrbot.app/deploy/astrbot/1panel.html)。 + +#### 在雨雲上部署 + +AstrBot 已由雨雲官方上架至雲端應用程式平台,可一鍵部署。 + +[![Deploy on RainYun](https://rainyun-apps.cn-nb1.rains3.com/materials/deploy-on-rainyun-en.svg)](https://app.rainyun.com/apps/rca/store/5994?ref=NjU1ODg0) + +#### 在 Replit 上部署 + +社群貢獻的部署方式。 + +[![Run on Repl.it](https://repl.it/badge/github/AstrBotDevs/AstrBot)](https://repl.it/github/AstrBotDevs/AstrBot) + +#### Windows 一鍵安裝器部署 + +請參閱官方文件 [使用 Windows 一鍵安裝器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html)。 + +#### CasaOS 部署 + +社群貢獻的部署方式。 + +請參閱官方文件 [CasaOS 部署](https://astrbot.app/deploy/astrbot/casaos.html)。 + +#### 手動部署 + +首先安裝 uv: + +```bash +pip install uv +``` + +透過 Git Clone 安裝 AstrBot: + +```bash +git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot +uv run main.py +``` + +或者請參閱官方文件 [透過原始碼部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html)。 + +## 支援的訊息平台 + +**官方維護** + +- QQ(官方平台 & OneBot) +- Telegram +- 企微應用 & 企微智慧機器人 +- 微信客服 & 微信公眾號 +- 飛書 +- 釘釘 +- Slack +- Discord +- Satori +- Misskey +- Whatsapp(即將支援) +- LINE(即將支援) + +**社群維護** + +- [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) +- [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) +- [Bilibili 私訊](https://github.com/Hina-Chat/astrbot_plugin_bilibili_adapter) +- [wxauto](https://github.com/luosheng520qaq/wxauto-repost-onebotv11) + +## 支援的模型服務 + +**大型模型服務** + +- OpenAI 及相容服務 +- Anthropic +- Google Gemini +- Moonshot AI +- 智譜 AI +- DeepSeek +- Ollama(本機部署) +- LM Studio(本機部署) +- [優雲智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) +- [302.AI](https://share.302.ai/rr1M3l) +- [小馬算力](https://www.tokenpony.cn/3YPyf) +- [矽基流動](https://docs.siliconflow.cn/cn/usercases/use-siliconcloud-in-astrbot) +- [PPIO 派歐雲](https://ppio.com/user/register?invited_by=AIOONE) +- ModelScope +- OneAPI + +**LLMOps 平台** + +- Dify +- 阿里雲百煉應用 +- Coze + +**語音轉文字服務** + +- OpenAI Whisper +- SenseVoice + +**文字轉語音服務** + +- OpenAI TTS +- Gemini TTS +- GPT-Sovits-Inference +- GPT-Sovits +- FishAudio +- Edge TTS +- 阿里雲百煉 TTS +- Azure TTS +- Minimax TTS +- 火山引擎 TTS + +## ❤️ 貢獻 + +歡迎任何 Issues/Pull Requests!只需要將您的變更提交到此專案 :) + +### 如何貢獻 + +您可以透過檢視問題或協助審核 PR(拉取請求)來貢獻。任何問題或 PR 都歡迎參與,以促進社群貢獻。當然,這些只是建議,您可以以任何方式進行貢獻。對於新功能的新增,請先透過 Issue 討論。 + +### 開發環境 + +AstrBot 使用 `ruff` 進行程式碼格式化和檢查。 + +```bash +git clone https://github.com/AstrBotDevs/AstrBot +pip install pre-commit +pre-commit install +``` + +## 🌍 社群 + +### QQ 群組 + +- 1 群:322154837 +- 3 群:630166526 +- 5 群:822130018 +- 6 群:753075035 +- 開發者群:975206796 + +### Telegram 群組 + +Telegram_community + +### Discord 群組 + +Discord_community + +## ❤️ Special Thanks + +特別感謝所有 Contributors 和外掛開發者對 AstrBot 的貢獻 ❤️ + + + + + +此外,本專案的誕生離不開以下開源專案的幫助: + +- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 偉大的貓貓框架 + +## ⭐ Star History + +> [!TIP] +> 如果本專案對您的生活 / 工作產生了幫助,或者您關注本專案的未來發展,請給專案 Star,這是我們維護這個開源專案的動力 <3 + +
+ +[![Star History Chart](https://api.star-history.com/svg?repos=astrbotdevs/astrbot&type=Date)](https://star-history.com/#astrbotdevs/astrbot&Date) + +
+ + + +_私は、高性能ですから!_ + diff --git a/astrbot/cli/__init__.py b/astrbot/cli/__init__.py index 8358b03cc..ea674c5c5 100644 --- a/astrbot/cli/__init__.py +++ b/astrbot/cli/__init__.py @@ -1 +1 @@ -__version__ = "4.7.4" +__version__ = "4.8.0" diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 6f3c813eb..1cf572aa8 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -97,7 +97,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): llm_resp_result = None async for llm_response in self._iter_llm_responses(): - assert isinstance(llm_response, LLMResponse) if llm_response.is_chunk: if llm_response.result_chain: yield AgentResponse( diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py index b60088609..7f30f44ef 100644 --- a/astrbot/core/agent/tool.py +++ b/astrbot/core/agent/tool.py @@ -1,4 +1,4 @@ -from collections.abc import Awaitable, Callable +from collections.abc import AsyncGenerator, Awaitable, Callable from typing import Any, Generic import jsonschema @@ -7,6 +7,8 @@ from deprecated import deprecated from pydantic import Field, model_validator from pydantic.dataclasses import dataclass +from astrbot.core.message.message_event_result import MessageEventResult + from .run_context import ContextWrapper, TContext ParametersType = dict[str, Any] @@ -38,7 +40,10 @@ class ToolSchema: class FunctionTool(ToolSchema, Generic[TContext]): """A callable tool, for function calling.""" - handler: Callable[..., Awaitable[Any]] | None = None + handler: ( + Callable[..., Awaitable[str | None] | AsyncGenerator[MessageEventResult, None]] + | None + ) = None """a callable that implements the tool's functionality. It should be an async function.""" handler_module_path: str | None = None diff --git a/astrbot/core/astr_agent_run_util.py b/astrbot/core/astr_agent_run_util.py index 5deb5af4e..d94d96a82 100644 --- a/astrbot/core/astr_agent_run_util.py +++ b/astrbot/core/astr_agent_run_util.py @@ -9,6 +9,7 @@ from astrbot.core.message.message_event_result import ( MessageEventResult, ResultContentType, ) +from astrbot.core.provider.entities import LLMResponse AgentRunner = ToolLoopAgentRunner[AstrAgentContext] @@ -72,7 +73,20 @@ async def run_agent( except Exception as e: logger.error(traceback.format_exc()) - err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n" + + err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {e!s}\n\n请在平台日志查看和分享错误详情。\n" + + error_llm_response = LLMResponse( + role="err", + completion_text=err_msg, + ) + try: + await agent_runner.agent_hooks.on_agent_done( + agent_runner.run_context, error_llm_response + ) + except Exception: + logger.exception("Error in on_agent_done hook") + if agent_runner.streaming: yield MessageChain().message(err_msg) else: diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index 440dea2d1..ed08e90a9 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -185,7 +185,11 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): async def call_local_llm_tool( context: ContextWrapper[AstrAgentContext], - handler: T.Callable[..., T.Awaitable[T.Any]], + handler: T.Callable[ + ..., + T.Awaitable[MessageEventResult | mcp.types.CallToolResult | str | None] + | T.AsyncGenerator[MessageEventResult | CommandResult | str | None, None], + ], method_name: str, *args, **kwargs, diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index 786d29c81..9477eabaa 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -24,6 +24,10 @@ class AstrBotConfig(dict): - 如果传入了 schema,将会通过 schema 解析出 default_config,此时传入的 default_config 会被忽略。 """ + config_path: str + default_config: dict + schema: dict | None + def __init__( self, config_path: str = ASTRBOT_CONFIG_PATH, diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index b91c57c63..e8778bfc6 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -4,9 +4,17 @@ import os from astrbot.core.utils.astrbot_path import get_astrbot_data_path -VERSION = "4.7.4" +VERSION = "4.8.0" DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db") +WEBHOOK_SUPPORTED_PLATFORMS = [ + "qq_official_webhook", + "weixin_official_account", + "wecom", + "wecom_ai_bot", + "slack", +] + # 默认配置 DEFAULT_CONFIG = { "config_version": 2, @@ -34,7 +42,15 @@ DEFAULT_CONFIG = { "interval": "1.5,3.5", "log_base": 2.6, "words_count_threshold": 150, + "split_mode": "regex", # regex 或 words "regex": ".*?[。?!~…]+|.+$", + "split_words": [ + "。", + "?", + "!", + "~", + "…", + ], # 当 split_mode 为 words 时使用 "content_cleanup_rule": "", }, "no_permission_reply": True, @@ -149,6 +165,7 @@ DEFAULT_CONFIG = { "kb_fusion_top_k": 20, # 知识库检索融合阶段返回结果数量 "kb_final_top_k": 5, # 知识库检索最终返回结果数量 "kb_agentic_mode": False, + "disable_builtin_commands": False, } @@ -185,6 +202,8 @@ CONFIG_METADATA_2 = { "appid": "", "secret": "", "is_sandbox": False, + "unified_webhook_mode": True, + "webhook_uuid": "", "callback_server_host": "0.0.0.0", "port": 6196, }, @@ -215,6 +234,8 @@ CONFIG_METADATA_2 = { "token": "", "encoding_aes_key": "", "api_base_url": "https://api.weixin.qq.com/cgi-bin/", + "unified_webhook_mode": True, + "webhook_uuid": "", "callback_server_host": "0.0.0.0", "port": 6194, "active_send_mode": False, @@ -229,6 +250,8 @@ CONFIG_METADATA_2 = { "encoding_aes_key": "", "kf_name": "", "api_base_url": "https://qyapi.weixin.qq.com/cgi-bin/", + "unified_webhook_mode": True, + "webhook_uuid": "", "callback_server_host": "0.0.0.0", "port": 6195, }, @@ -241,6 +264,8 @@ CONFIG_METADATA_2 = { "wecom_ai_bot_name": "", "token": "", "encoding_aes_key": "", + "unified_webhook_mode": True, + "webhook_uuid": "", "callback_server_host": "0.0.0.0", "port": 6198, }, @@ -308,6 +333,8 @@ CONFIG_METADATA_2 = { "app_token": "", "signing_secret": "", "slack_connection_mode": "socket", # webhook, socket + "unified_webhook_mode": True, + "webhook_uuid": "", "slack_webhook_host": "0.0.0.0", "slack_webhook_port": 6197, "slack_webhook_path": "/astrbot-slack-webhook/callback", @@ -387,16 +414,28 @@ CONFIG_METADATA_2 = { "description": "Slack Webhook Host", "type": "string", "hint": "Only valid when Slack connection mode is `webhook`.", + "condition": { + "slack_connection_mode": "webhook", + "unified_webhook_mode": False, + }, }, "slack_webhook_port": { "description": "Slack Webhook Port", "type": "int", "hint": "Only valid when Slack connection mode is `webhook`.", + "condition": { + "slack_connection_mode": "webhook", + "unified_webhook_mode": False, + }, }, "slack_webhook_path": { "description": "Slack Webhook Path", "type": "string", "hint": "Only valid when Slack connection mode is `webhook`.", + "condition": { + "slack_connection_mode": "webhook", + "unified_webhook_mode": False, + }, }, "active_send_mode": { "description": "是否换用主动发送接口", @@ -587,6 +626,33 @@ CONFIG_METADATA_2 = { "type": "string", "hint": "可选的 Discord 活动名称。留空则不设置活动。", }, + "port": { + "description": "回调服务器端口", + "type": "int", + "hint": "回调服务器端口。留空则不启用回调服务器。", + "condition": { + "unified_webhook_mode": False, + }, + }, + "callback_server_host": { + "description": "回调服务器主机", + "type": "string", + "hint": "回调服务器主机。留空则不启用回调服务器。", + "condition": { + "unified_webhook_mode": False, + }, + }, + "unified_webhook_mode": { + "description": "统一 Webhook 模式", + "type": "bool", + "hint": "启用后,将使用 AstrBot 统一 Webhook 入口,无需单独开启端口。回调地址为 /api/platform/webhook/{webhook_uuid}。", + }, + "webhook_uuid": { + "invisible": True, + "description": "Webhook UUID", + "type": "string", + "hint": "统一 Webhook 模式下的唯一标识符,创建平台时自动生成。", + }, }, }, "platform_settings": { @@ -2604,6 +2670,11 @@ CONFIG_METADATA_3 = { "description": "只 @ 机器人是否触发等待", "type": "bool", }, + "disable_builtin_commands": { + "description": "禁用自带指令", + "type": "bool", + "hint": "禁用所有 AstrBot 的自带指令,如 help, provider, model 等。", + }, }, }, "whitelist": { @@ -2818,9 +2889,26 @@ CONFIG_METADATA_3 = { "description": "分段回复字数阈值", "type": "int", }, + "platform_settings.segmented_reply.split_mode": { + "description": "分段模式", + "type": "string", + "options": ["regex", "words"], + "labels": ["正则表达式", "分段词列表"], + }, "platform_settings.segmented_reply.regex": { "description": "分段正则表达式", "type": "string", + "condition": { + "platform_settings.segmented_reply.split_mode": "regex", + }, + }, + "platform_settings.segmented_reply.split_words": { + "description": "分段词列表", + "type": "list", + "hint": "检测到列表中的任意词时进行分段,如:。、?、!等", + "condition": { + "platform_settings.segmented_reply.split_mode": "words", + }, }, "platform_settings.segmented_reply.content_cleanup_rule": { "description": "内容过滤正则表达式", diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index e8241f85a..5a8672837 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -197,7 +197,7 @@ class AstrBotCoreLifecycle: # 把插件中注册的所有协程函数注册到事件总线中并执行 extra_tasks = [] for task in self.star_context._register_tasks: - extra_tasks.append(asyncio.create_task(task, name=task.__name__)) + extra_tasks.append(asyncio.create_task(task, name=task.__name__)) # type: ignore tasks_ = [event_bus_task, *extra_tasks] for task in tasks_: diff --git a/astrbot/core/db/__init__.py b/astrbot/core/db/__init__.py index 0b341c9db..192c7b263 100644 --- a/astrbot/core/db/__init__.py +++ b/astrbot/core/db/__init__.py @@ -5,8 +5,7 @@ from contextlib import asynccontextmanager from dataclasses import dataclass from deprecated import deprecated -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine -from sqlalchemy.orm import sessionmaker +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from astrbot.core.db.po import ( Attachment, @@ -34,7 +33,7 @@ class BaseDatabase(abc.ABC): echo=False, future=True, ) - self.AsyncSessionLocal = sessionmaker( + self.AsyncSessionLocal = async_sessionmaker( self.engine, class_=AsyncSession, expire_on_commit=False, @@ -175,7 +174,7 @@ class BaseDatabase(abc.ABC): content: dict, sender_id: str | None = None, sender_name: str | None = None, - ) -> None: + ) -> PlatformMessageHistory: """Insert a new platform message history record.""" ... @@ -200,6 +199,14 @@ class BaseDatabase(abc.ABC): """Get platform message history for a specific user.""" ... + @abc.abstractmethod + async def get_platform_message_history_by_id( + self, + message_id: int, + ) -> PlatformMessageHistory | None: + """Get a platform message history record by its ID.""" + ... + @abc.abstractmethod async def insert_attachment( self, @@ -215,6 +222,27 @@ class BaseDatabase(abc.ABC): """Get an attachment by its ID.""" ... + @abc.abstractmethod + async def get_attachments(self, attachment_ids: list[str]) -> list[Attachment]: + """Get multiple attachments by their IDs.""" + ... + + @abc.abstractmethod + async def delete_attachment(self, attachment_id: str) -> bool: + """Delete an attachment by its ID. + + Returns True if the attachment was deleted, False if it was not found. + """ + ... + + @abc.abstractmethod + async def delete_attachments(self, attachment_ids: list[str]) -> int: + """Delete multiple attachments by their IDs. + + Returns the number of attachments deleted. + """ + ... + @abc.abstractmethod async def insert_persona( self, diff --git a/astrbot/core/db/migration/migra_3_to_4.py b/astrbot/core/db/migration/migra_3_to_4.py index a75c60a1b..66b72d5cb 100644 --- a/astrbot/core/db/migration/migra_3_to_4.py +++ b/astrbot/core/db/migration/migra_3_to_4.py @@ -70,6 +70,7 @@ async def migration_conversation_table( logger.info( f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。", ) + continue if ":" not in conv.user_id: continue session = MessageSesion.from_str(session_str=conv.user_id) @@ -207,6 +208,7 @@ async def migration_webchat_data( logger.info( f"未找到该条旧会话对应的具体数据: {conversation}, 跳过。", ) + continue if ":" in conv.user_id: continue platform_id = "webchat" diff --git a/astrbot/core/db/migration/sqlite_v3.py b/astrbot/core/db/migration/sqlite_v3.py index a301028d1..b1a780d48 100644 --- a/astrbot/core/db/migration/sqlite_v3.py +++ b/astrbot/core/db/migration/sqlite_v3.py @@ -127,7 +127,7 @@ class SQLiteDatabase: conn.text_factory = str return conn - def _exec_sql(self, sql: str, params: tuple = None): + def _exec_sql(self, sql: str, params: tuple | None = None): conn = self.conn try: c = self.conn.cursor() @@ -224,9 +224,11 @@ class SQLiteDatabase: c.close() - return Stats(platform, [], []) + return Stats(platform) - def get_conversation_by_user_id(self, user_id: str, cid: str) -> Conversation: + def get_conversation_by_user_id( + self, user_id: str, cid: str + ) -> Conversation | None: try: c = self.conn.cursor() except sqlite3.ProgrammingError: @@ -258,7 +260,7 @@ class SQLiteDatabase: (user_id, cid, history, updated_at, created_at), ) - def get_conversations(self, user_id: str) -> tuple: + def get_conversations(self, user_id: str) -> list[Conversation]: try: c = self.conn.cursor() except sqlite3.ProgrammingError: diff --git a/astrbot/core/db/po.py b/astrbot/core/db/po.py index e37d9290a..64bcf4ce3 100644 --- a/astrbot/core/db/po.py +++ b/astrbot/core/db/po.py @@ -12,7 +12,7 @@ class PlatformStat(SQLModel, table=True): Note: In astrbot v4, we moved `platform` table to here. """ - __tablename__ = "platform_stats" # type: ignore + __tablename__: str = "platform_stats" id: int = Field(primary_key=True, sa_column_kwargs={"autoincrement": True}) timestamp: datetime = Field(nullable=False) @@ -31,9 +31,10 @@ class PlatformStat(SQLModel, table=True): class ConversationV2(SQLModel, table=True): - __tablename__ = "conversations" # type: ignore + __tablename__: str = "conversations" - inner_conversation_id: int = Field( + inner_conversation_id: int | None = Field( + default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}, ) @@ -68,7 +69,7 @@ class Persona(SQLModel, table=True): It can be used to customize the behavior of LLMs. """ - __tablename__ = "personas" # type: ignore + __tablename__: str = "personas" id: int | None = Field( primary_key=True, @@ -98,7 +99,7 @@ class Persona(SQLModel, table=True): class Preference(SQLModel, table=True): """This class represents preferences for bots.""" - __tablename__ = "preferences" # type: ignore + __tablename__: str = "preferences" id: int | None = Field( default=None, @@ -134,7 +135,7 @@ class PlatformMessageHistory(SQLModel, table=True): or platform-specific messages. """ - __tablename__ = "platform_message_history" # type: ignore + __tablename__: str = "platform_message_history" id: int | None = Field( primary_key=True, @@ -162,7 +163,7 @@ class PlatformSession(SQLModel, table=True): Each session can have multiple conversations (对话) associated with it. """ - __tablename__ = "platform_sessions" # type: ignore + __tablename__: str = "platform_sessions" inner_id: int | None = Field( primary_key=True, @@ -203,7 +204,7 @@ class Attachment(SQLModel, table=True): Attachments can be images, files, or other media types. """ - __tablename__ = "attachments" # type: ignore + __tablename__: str = "attachments" inner_attachment_id: int | None = Field( primary_key=True, @@ -320,17 +321,17 @@ class Personality(TypedDict): 在 v4.0.0 版本及之后,推荐使用上面的 Persona 类。并且, mood_imitation_dialogs 字段已被废弃。 """ - prompt: str = "" - name: str = "" - begin_dialogs: list[str] = [] - mood_imitation_dialogs: list[str] = [] + prompt: str + name: str + begin_dialogs: list[str] + mood_imitation_dialogs: list[str] """情感模拟对话预设。在 v4.0.0 版本及之后,已被废弃。""" - tools: list[str] | None = None + tools: list[str] | None """工具列表。None 表示使用所有工具,空列表表示不使用任何工具""" # cache - _begin_dialogs_processed: list[dict] = [] - _mood_imitation_dialogs_processed: str = "" + _begin_dialogs_processed: list[dict] + _mood_imitation_dialogs_processed: str # ==== diff --git a/astrbot/core/db/sqlite.py b/astrbot/core/db/sqlite.py index ffa37f1e5..7203a40d1 100644 --- a/astrbot/core/db/sqlite.py +++ b/astrbot/core/db/sqlite.py @@ -3,6 +3,7 @@ import threading import typing as T from datetime import datetime, timedelta, timezone +from sqlalchemy import CursorResult from sqlalchemy.ext.asyncio import AsyncSession from sqlmodel import col, delete, desc, func, or_, select, text, update @@ -107,8 +108,8 @@ class SQLiteDatabase(BaseDatabase): text(""" SELECT * FROM platform_stats WHERE timestamp >= :start_time - ORDER BY timestamp DESC GROUP BY platform_id + ORDER BY timestamp DESC """), {"start_time": start_time}, ) @@ -451,6 +452,18 @@ class SQLiteDatabase(BaseDatabase): result = await session.execute(query.offset(offset).limit(page_size)) return result.scalars().all() + async def get_platform_message_history_by_id( + self, message_id: int + ) -> PlatformMessageHistory | None: + """Get a platform message history record by its ID.""" + async with self.get_db() as session: + session: AsyncSession + query = select(PlatformMessageHistory).where( + PlatformMessageHistory.id == message_id + ) + result = await session.execute(query) + return result.scalar_one_or_none() + async def insert_attachment(self, path, type, mime_type): """Insert a new attachment record.""" async with self.get_db() as session: @@ -472,6 +485,48 @@ class SQLiteDatabase(BaseDatabase): result = await session.execute(query) return result.scalar_one_or_none() + async def get_attachments(self, attachment_ids: list[str]) -> list: + """Get multiple attachments by their IDs.""" + if not attachment_ids: + return [] + async with self.get_db() as session: + session: AsyncSession + query = select(Attachment).where( + col(Attachment.attachment_id).in_(attachment_ids) + ) + result = await session.execute(query) + return list(result.scalars().all()) + + async def delete_attachment(self, attachment_id: str) -> bool: + """Delete an attachment by its ID. + + Returns True if the attachment was deleted, False if it was not found. + """ + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + query = delete(Attachment).where( + col(Attachment.attachment_id) == attachment_id + ) + result = T.cast(CursorResult, await session.execute(query)) + return result.rowcount > 0 + + async def delete_attachments(self, attachment_ids: list[str]) -> int: + """Delete multiple attachments by their IDs. + + Returns the number of attachments deleted. + """ + if not attachment_ids: + return 0 + async with self.get_db() as session: + session: AsyncSession + async with session.begin(): + query = delete(Attachment).where( + col(Attachment.attachment_id).in_(attachment_ids) + ) + result = T.cast(CursorResult, await session.execute(query)) + return result.rowcount + async def insert_persona( self, persona_id, diff --git a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py index 24f1c323c..564454cb1 100644 --- a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py @@ -90,4 +90,6 @@ class EmbeddingStorage: path (str): 保存索引的路径 """ + if self.index is None: + return faiss.write_index(self.index, self.path) diff --git a/astrbot/core/event_bus.py b/astrbot/core/event_bus.py index 749df753e..0017e65fa 100644 --- a/astrbot/core/event_bus.py +++ b/astrbot/core/event_bus.py @@ -27,7 +27,7 @@ class EventBus: self, event_queue: Queue, pipeline_scheduler_mapping: dict[str, PipelineScheduler], - astrbot_config_mgr: AstrBotConfigManager = None, + astrbot_config_mgr: AstrBotConfigManager, ): self.event_queue = event_queue # 事件队列 # abconf uuid -> scheduler @@ -40,6 +40,11 @@ class EventBus: conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin) self._print_event(event, conf_info["name"]) scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"]) + if not scheduler: + logger.error( + f"PipelineScheduler not found for id: {conf_info['id']}, event ignored." + ) + continue asyncio.create_task(scheduler.execute(event)) def _print_event(self, event: AstrMessageEvent, conf_name: str): diff --git a/astrbot/core/knowledge_base/retrieval/manager.py b/astrbot/core/knowledge_base/retrieval/manager.py index 9a42cd6cd..746406e90 100644 --- a/astrbot/core/knowledge_base/retrieval/manager.py +++ b/astrbot/core/knowledge_base/retrieval/manager.py @@ -166,7 +166,11 @@ class RetrievalManager: # 5. Rerank first_rerank = None for kb_id in kb_ids: - vec_db: FaissVecDB = kb_options[kb_id]["vec_db"] + vec_db = kb_options[kb_id]["vec_db"] + if not isinstance(vec_db, FaissVecDB): + logger.warning(f"vec_db for kb_id {kb_id} is not FaissVecDB") + continue + rerank_pi = kb_options[kb_id]["rerank_provider_id"] if ( vec_db diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 47d6ff781..0e7b3bab6 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -66,6 +66,9 @@ class ComponentType(str, Enum): class BaseMessageComponent(BaseModel): type: ComponentType + def __init__(self, **kwargs): + super().__init__(**kwargs) + def toDict(self): data = {} for k, v in self.__dict__.items(): @@ -551,7 +554,7 @@ class Node(BaseMessageComponent): id: int | None = 0 # 忽略 name: str | None = "" # qq昵称 uin: str | None = "0" # qq号 - content: list[BaseMessageComponent] | None = [] + content: list[BaseMessageComponent] = [] seq: str | list | None = "" # 忽略 time: int | None = 0 # 忽略 @@ -615,7 +618,7 @@ class Nodes(BaseMessageComponent): ret["messages"].append(d) return ret - async def to_dict(self): + async def to_dict(self) -> dict: """将 Nodes 转换为字典格式,适用于 OneBot JSON 格式""" ret = {"messages": []} for node in self.nodes: @@ -714,12 +717,15 @@ class File(BaseMessageComponent): if self.url: await self._download_file() - return os.path.abspath(self.file_) + if self.file_: + return os.path.abspath(self.file_) return "" async def _download_file(self): """下载文件""" + if not self.url: + raise ValueError("Download failed: No URL provided in File component.") download_dir = os.path.join(get_astrbot_data_path(), "temp") os.makedirs(download_dir, exist_ok=True) if self.name: diff --git a/astrbot/core/persona_mgr.py b/astrbot/core/persona_mgr.py index 5d1743ab9..b2d2c6be1 100644 --- a/astrbot/core/persona_mgr.py +++ b/astrbot/core/persona_mgr.py @@ -98,8 +98,8 @@ class PersonaManager: self, persona_id: str, system_prompt: str, - begin_dialogs: list[str] = None, - tools: list[str] = None, + begin_dialogs: list[str] | None = None, + tools: list[str] | None = None, ) -> Persona: """创建新的 persona。tools 参数为 None 时表示使用所有工具,空列表表示不使用任何工具""" if await self.db.get_persona_by_id(persona_id): diff --git a/astrbot/core/pipeline/content_safety_check/stage.py b/astrbot/core/pipeline/content_safety_check/stage.py index c477cc23a..b089c48e0 100644 --- a/astrbot/core/pipeline/content_safety_check/stage.py +++ b/astrbot/core/pipeline/content_safety_check/stage.py @@ -24,7 +24,7 @@ class ContentSafetyCheckStage(Stage): self, event: AstrMessageEvent, check_text: str | None = None, - ) -> None | AsyncGenerator[None, None]: + ) -> AsyncGenerator[None, None]: """检查内容安全""" text = check_text if check_text else event.get_message_str() ok, info = self.strategy_selector.check(text) diff --git a/astrbot/core/pipeline/context_utils.py b/astrbot/core/pipeline/context_utils.py index 73d28c5d1..1f5ba43a0 100644 --- a/astrbot/core/pipeline/context_utils.py +++ b/astrbot/core/pipeline/context_utils.py @@ -11,7 +11,7 @@ from astrbot.core.star.star_handler import EventType, star_handlers_registry async def call_handler( event: AstrMessageEvent, - handler: T.Callable[..., T.Awaitable[T.Any]], + handler: T.Callable[..., T.Awaitable[T.Any] | T.AsyncGenerator[T.Any, None]], *args, **kwargs, ) -> T.AsyncGenerator[T.Any, None]: @@ -91,6 +91,7 @@ async def call_event_hook( ) for handler in handlers: try: + assert inspect.iscoroutinefunction(handler.handler) logger.debug( f"hook({hook_type.name}) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}", ) diff --git a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py index 34dc02ceb..b590bd77e 100644 --- a/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py +++ b/astrbot/core/pipeline/process_stage/method/agent_sub_stages/third_party.py @@ -57,7 +57,7 @@ async def run_third_party_agent( logger.error(f"Third party agent runner error: {e}") err_msg = ( f"\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n" - f"错误信息: {e!s}\n\n请在控制台查看和分享错误详情。\n" + f"错误信息: {e!s}\n\n请在平台日志查看和分享错误详情。\n" ) yield MessageChain().message(err_msg) diff --git a/astrbot/core/pipeline/process_stage/method/star_request.py b/astrbot/core/pipeline/process_stage/method/star_request.py index 56d305de4..8a79b96c9 100644 --- a/astrbot/core/pipeline/process_stage/method/star_request.py +++ b/astrbot/core/pipeline/process_stage/method/star_request.py @@ -16,7 +16,6 @@ from ..stage import Stage class StarRequestSubStage(Stage): async def initialize(self, ctx: PipelineContext) -> None: - self.curr_provider = ctx.plugin_manager.context.get_using_provider() self.prompt_prefix = ctx.astrbot_config["provider_settings"]["prompt_prefix"] self.identifier = ctx.astrbot_config["provider_settings"]["identifier"] self.ctx = ctx @@ -24,7 +23,7 @@ class StarRequestSubStage(Stage): async def process( self, event: AstrMessageEvent, - ) -> AsyncGenerator[None, None]: + ) -> AsyncGenerator[Any, None]: activated_handlers: list[StarHandlerMetadata] = event.get_extra( "activated_handlers", ) diff --git a/astrbot/core/pipeline/process_stage/stage.py b/astrbot/core/pipeline/process_stage/stage.py index e19b8dc18..076f7f12a 100644 --- a/astrbot/core/pipeline/process_stage/stage.py +++ b/astrbot/core/pipeline/process_stage/stage.py @@ -60,7 +60,7 @@ class ProcessStage(Stage): ): # 是否有过发送操作 and 是否是被 @ 或者通过唤醒前缀 if ( - event.get_result() and not event.get_result().is_stopped() + event.get_result() and not event.is_stopped() ) or not event.get_result(): async for _ in self.agent_sub_stage.process(event): yield diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index 760649563..8f1b87efc 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -117,7 +117,9 @@ class RespondStage(Stage): if not self.enable_seg: return False - if self.only_llm_result and not event.get_result().is_llm_result(): + if (result := event.get_result()) is None: + return False + if self.only_llm_result and result.is_llm_result(): return False if event.get_platform_name() in [ @@ -185,7 +187,7 @@ class RespondStage(Stage): if isinstance(component, Comp.File) and component.file: # 支持 File 消息段的路径映射。 component.file = path_Mapping(mappings, component.file) - event.get_result().chain[idx] = component + result.chain[idx] = component # 检查消息链是否为空 try: diff --git a/astrbot/core/pipeline/result_decorate/stage.py b/astrbot/core/pipeline/result_decorate/stage.py index ef394edcf..208f3a9f2 100644 --- a/astrbot/core/pipeline/result_decorate/stage.py +++ b/astrbot/core/pipeline/result_decorate/stage.py @@ -6,6 +6,7 @@ from collections.abc import AsyncGenerator from astrbot.core import file_token_service, html_renderer, logger from astrbot.core.message.components import At, File, Image, Node, Plain, Record, Reply from astrbot.core.message.message_event_result import ResultContentType +from astrbot.core.pipeline.content_safety_check.stage import ContentSafetyCheckStage from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.message_type import MessageType from astrbot.core.star.session_llm_manager import SessionServiceManager @@ -53,7 +54,22 @@ class ResultDecorateStage(Stage): self.only_llm_result = ctx.astrbot_config["platform_settings"][ "segmented_reply" ]["only_llm_result"] + self.split_mode = ctx.astrbot_config["platform_settings"][ + "segmented_reply" + ].get("split_mode", "regex") self.regex = ctx.astrbot_config["platform_settings"]["segmented_reply"]["regex"] + self.split_words = ctx.astrbot_config["platform_settings"][ + "segmented_reply" + ].get("split_words", ["。", "?", "!", "~", "…"]) + if self.split_words: + escaped_words = sorted( + [re.escape(word) for word in self.split_words], key=len, reverse=True + ) + self.split_words_pattern = re.compile( + f"(.*?({'|'.join(escaped_words)})|.+$)", re.DOTALL + ) + else: + self.split_words_pattern = None self.content_cleanup_rule = ctx.astrbot_config["platform_settings"][ "segmented_reply" ]["content_cleanup_rule"] @@ -69,6 +85,28 @@ class ResultDecorateStage(Stage): self.content_safe_check_stage = stage_cls() await self.content_safe_check_stage.initialize(ctx) + def _split_text_by_words(self, text: str) -> list[str]: + """使用分段词列表分段文本""" + if not self.split_words_pattern: + return [text] + + segments = self.split_words_pattern.findall(text) + result = [] + for seg in segments: + if isinstance(seg, tuple): + content = seg[0] + if not isinstance(content, str): + continue + for word in self.split_words: + if content.endswith(word): + content = content[: -len(word)] + break + if content.strip(): + result.append(content) + elif seg and seg.strip(): + result.append(seg) + return result if result else [text] + async def process( self, event: AstrMessageEvent, @@ -93,11 +131,13 @@ class ResultDecorateStage(Stage): for comp in result.chain: if isinstance(comp, Plain): text += comp.text - async for _ in self.content_safe_check_stage.process( - event, - check_text=text, - ): - yield + + if isinstance(self.content_safe_check_stage, ContentSafetyCheckStage): + async for _ in self.content_safe_check_stage.process( + event, + check_text=text, + ): + yield # 发送消息前事件钩子 handlers = star_handlers_registry.get_handlers_by_event_type( @@ -114,7 +154,8 @@ class ResultDecorateStage(Stage): "启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作", ) await handler.handler(event) - if event.get_result() is None or not event.get_result().chain: + + if (result := event.get_result()) is None or not result.chain: logger.debug( f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name} 将消息结果清空。", ) @@ -161,21 +202,27 @@ class ResultDecorateStage(Stage): # 不分段回复 new_chain.append(comp) continue - try: - split_response = re.findall( - self.regex, - comp.text, - re.DOTALL | re.MULTILINE, - ) - except re.error: - logger.error( - f"分段回复正则表达式错误,使用默认分段方式: {traceback.format_exc()}", - ) - split_response = re.findall( - r".*?[。?!~…]+|.+$", - comp.text, - re.DOTALL | re.MULTILINE, - ) + + # 根据 split_mode 选择分段方式 + if self.split_mode == "words": + split_response = self._split_text_by_words(comp.text) + else: # regex 模式 + try: + split_response = re.findall( + self.regex, + comp.text, + re.DOTALL | re.MULTILINE, + ) + except re.error: + logger.error( + f"分段回复正则表达式错误,使用默认分段方式: {traceback.format_exc()}", + ) + split_response = re.findall( + r".*?[。?!~…]+|.+$", + comp.text, + re.DOTALL | re.MULTILINE, + ) + if not split_response: new_chain.append(comp) continue diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index 5c461a1e1..5fb3034f5 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -2,6 +2,10 @@ from collections.abc import AsyncGenerator from astrbot.core import logger from astrbot.core.platform import AstrMessageEvent +from astrbot.core.platform.sources.webchat.webchat_event import WebChatMessageEvent +from astrbot.core.platform.sources.wecom_ai_bot.wecomai_event import ( + WecomAIBotMessageEvent, +) from . import STAGES_ORDER from .context import PipelineContext @@ -78,7 +82,7 @@ class PipelineScheduler: await self._process_stages(event) # 如果没有发送操作, 则发送一个空消息, 以便于后续的处理 - if event.get_platform_name() in ["webchat", "wecom_ai_bot"]: + if isinstance(event, (WebChatMessageEvent, WecomAIBotMessageEvent)): await event.send(None) logger.debug("pipeline 执行完毕。") diff --git a/astrbot/core/pipeline/waking_check/stage.py b/astrbot/core/pipeline/waking_check/stage.py index 814919115..1efda7c84 100644 --- a/astrbot/core/pipeline/waking_check/stage.py +++ b/astrbot/core/pipeline/waking_check/stage.py @@ -50,6 +50,9 @@ class WakingCheckStage(Stage): "ignore_at_all", False, ) + self.disable_builtin_commands = self.ctx.astrbot_config.get( + "disable_builtin_commands", False + ) async def process( self, @@ -131,6 +134,13 @@ class WakingCheckStage(Stage): EventType.AdapterMessageEvent, plugins_name=event.plugins_name, ): + if ( + self.disable_builtin_commands + and handler.handler_module_path == "packages.builtin_commands.main" + ): + logger.debug("skipping builtin command") + continue + # filter 需满足 AND 逻辑关系 passed = True permission_not_pass = False diff --git a/astrbot/core/platform/astr_message_event.py b/astrbot/core/platform/astr_message_event.py index 6402aeaed..f6eda07a9 100644 --- a/astrbot/core/platform/astr_message_event.py +++ b/astrbot/core/platform/astr_message_event.py @@ -153,7 +153,9 @@ class AstrMessageEvent(abc.ABC): def get_sender_name(self) -> str: """获取消息发送者的名称。(可能会返回空字符串)""" - return self.message_obj.sender.nickname + if isinstance(self.message_obj.sender.nickname, str): + return self.message_obj.sender.nickname + return "" def set_extra(self, key, value): """设置额外的信息。""" @@ -270,7 +272,7 @@ class AstrMessageEvent(abc.ABC): """ self.call_llm = call_llm - def get_result(self) -> MessageEventResult: + def get_result(self) -> MessageEventResult | None: """获取消息事件的结果。""" return self._result @@ -320,7 +322,7 @@ class AstrMessageEvent(abc.ABC): self, prompt: str, func_tool_manager=None, - session_id: str = None, + session_id: str = "", image_urls: list[str] | None = None, contexts: list | None = None, system_prompt: str = "", diff --git a/astrbot/core/platform/astrbot_message.py b/astrbot/core/platform/astrbot_message.py index 0ada18506..253963322 100644 --- a/astrbot/core/platform/astrbot_message.py +++ b/astrbot/core/platform/astrbot_message.py @@ -54,7 +54,7 @@ class AstrBotMessage: self_id: str # 机器人的识别id session_id: str # 会话id。取决于 unique_session 的设置。 message_id: str # 消息id - group: Group # 群组 + group: Group | None # 群组 sender: MessageMember # 发送者 message: list[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式 message_str: str # 最直观的纯文本消息字符串 @@ -78,7 +78,7 @@ class AstrBotMessage: return "" @group_id.setter - def group_id(self, value: str): + def group_id(self, value: str | None): """设置 group_id""" if value: if self.group: diff --git a/astrbot/core/platform/manager.py b/astrbot/core/platform/manager.py index 9ff892025..b941c8cbc 100644 --- a/astrbot/core/platform/manager.py +++ b/astrbot/core/platform/manager.py @@ -6,7 +6,7 @@ from astrbot.core import logger from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map -from .platform import Platform +from .platform import Platform, PlatformStatus from .register import platform_cls_map from .sources.webchat.webchat_adapter import WebChatAdapter @@ -16,7 +16,7 @@ class PlatformManager: self.platform_insts: list[Platform] = [] """加载的 Platform 的实例""" - self._inst_map = {} + self._inst_map: dict[str, dict] = {} self.platforms_config = config["platform"] self.settings = config["platform_settings"] @@ -37,7 +37,10 @@ class PlatformManager: webchat_inst = WebChatAdapter({}, self.settings, self.event_queue) self.platform_insts.append(webchat_inst) asyncio.create_task( - self._task_wrapper(asyncio.create_task(webchat_inst.run(), name="webchat")), + self._task_wrapper( + asyncio.create_task(webchat_inst.run(), name="webchat"), + platform=webchat_inst, + ), ) async def load_platform(self, platform_config: dict): @@ -107,7 +110,7 @@ class PlatformManager: ) except (ImportError, ModuleNotFoundError) as e: logger.error( - f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。", + f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。", ) except Exception as e: logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。") @@ -131,6 +134,7 @@ class PlatformManager: inst.run(), name=f"platform_{platform_config['type']}_{platform_config['id']}", ), + platform=inst, ), ) handlers = star_handlers_registry.get_handlers_by_event_type( @@ -145,17 +149,28 @@ class PlatformManager: except Exception: logger.error(traceback.format_exc()) - async def _task_wrapper(self, task: asyncio.Task): + async def _task_wrapper(self, task: asyncio.Task, platform: Platform | None = None): + # 设置平台状态为运行中 + if platform: + platform.status = PlatformStatus.RUNNING + try: await task except asyncio.CancelledError: - pass + if platform: + platform.status = PlatformStatus.STOPPED except Exception as e: + error_msg = str(e) + tb_str = traceback.format_exc() logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}") - for line in traceback.format_exc().split("\n"): + for line in tb_str.split("\n"): logger.error(f"| {line}") logger.error("-------") + # 记录错误到平台实例 + if platform: + platform.record_error(error_msg, tb_str) + async def reload(self, platform_config: dict): await self.terminate_platform(platform_config["id"]) if platform_config["enable"]: @@ -172,9 +187,9 @@ class PlatformManager: logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...") # client_id = self._inst_map.pop(platform_id, None) - info = self._inst_map.pop(platform_id, None) + info = self._inst_map.pop(platform_id) client_id = info["client_id"] - inst = info["inst"] + inst: Platform = info["inst"] try: self.platform_insts.remove( next( @@ -196,3 +211,46 @@ class PlatformManager: def get_insts(self): return self.platform_insts + + def get_all_stats(self) -> dict: + """获取所有平台的统计信息 + + Returns: + 包含所有平台统计信息的字典 + """ + stats_list = [] + total_errors = 0 + running_count = 0 + error_count = 0 + + for inst in self.platform_insts: + try: + stat = inst.get_stats() + stats_list.append(stat) + total_errors += stat.get("error_count", 0) + if stat.get("status") == PlatformStatus.RUNNING.value: + running_count += 1 + elif stat.get("status") == PlatformStatus.ERROR.value: + error_count += 1 + except Exception as e: + # 如果获取统计信息失败,记录基本信息 + logger.warning(f"获取平台统计信息失败: {e}") + stats_list.append( + { + "id": getattr(inst, "config", {}).get("id", "unknown"), + "type": "unknown", + "status": "unknown", + "error_count": 0, + "last_error": None, + } + ) + + return { + "platforms": stats_list, + "summary": { + "total": len(stats_list), + "running": running_count, + "error": error_count, + "total_errors": total_errors, + }, + } diff --git a/astrbot/core/platform/platform.py b/astrbot/core/platform/platform.py index 3f36e17f3..c139b8bd7 100644 --- a/astrbot/core/platform/platform.py +++ b/astrbot/core/platform/platform.py @@ -1,7 +1,10 @@ import abc import uuid from asyncio import Queue -from collections.abc import Awaitable +from collections.abc import Coroutine +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum from typing import Any from astrbot.core.message.message_event_result import MessageChain @@ -12,15 +15,92 @@ from .message_session import MessageSesion from .platform_metadata import PlatformMetadata +class PlatformStatus(Enum): + """平台运行状态""" + + PENDING = "pending" # 待启动 + RUNNING = "running" # 运行中 + ERROR = "error" # 发生错误 + STOPPED = "stopped" # 已停止 + + +@dataclass +class PlatformError: + """平台错误信息""" + + message: str + timestamp: datetime = field(default_factory=datetime.now) + traceback: str | None = None + + class Platform(abc.ABC): - def __init__(self, event_queue: Queue): + def __init__(self, config: dict, event_queue: Queue): super().__init__() + # 平台配置 + self.config = config # 维护了消息平台的事件队列,EventBus 会从这里取出事件并处理。 self._event_queue = event_queue self.client_self_id = uuid.uuid4().hex + # 平台运行状态 + self._status: PlatformStatus = PlatformStatus.PENDING + self._errors: list[PlatformError] = [] + self._started_at: datetime | None = None + + @property + def status(self) -> PlatformStatus: + """获取平台运行状态""" + return self._status + + @status.setter + def status(self, value: PlatformStatus): + """设置平台运行状态""" + self._status = value + if value == PlatformStatus.RUNNING and self._started_at is None: + self._started_at = datetime.now() + + @property + def errors(self) -> list[PlatformError]: + """获取错误列表""" + return self._errors + + @property + def last_error(self) -> PlatformError | None: + """获取最近的错误""" + return self._errors[-1] if self._errors else None + + def record_error(self, message: str, traceback_str: str | None = None): + """记录一个错误""" + self._errors.append(PlatformError(message=message, traceback=traceback_str)) + self._status = PlatformStatus.ERROR + + def clear_errors(self): + """清除错误记录""" + self._errors.clear() + if self._status == PlatformStatus.ERROR: + self._status = PlatformStatus.RUNNING + + def get_stats(self) -> dict: + """获取平台统计信息""" + meta = self.meta() + return { + "id": meta.id or self.config.get("id"), + "type": meta.name, + "display_name": meta.adapter_display_name or meta.name, + "status": self._status.value, + "started_at": self._started_at.isoformat() if self._started_at else None, + "error_count": len(self._errors), + "last_error": { + "message": self.last_error.message, + "timestamp": self.last_error.timestamp.isoformat(), + "traceback": self.last_error.traceback, + } + if self.last_error + else None, + } + @abc.abstractmethod - def run(self) -> Awaitable[Any]: + def run(self) -> Coroutine[Any, Any, None]: """得到一个平台的运行实例,需要返回一个协程对象。""" raise NotImplementedError @@ -36,7 +116,7 @@ class Platform(abc.ABC): self, session: MessageSesion, message_chain: MessageChain, - ) -> Awaitable[Any]: + ) -> None: """通过会话发送消息。该方法旨在让插件能够直接通过**可持久化的会话数据**发送消息,而不需要保存 event 对象。 异步方法。 @@ -49,3 +129,20 @@ class Platform(abc.ABC): def get_client(self): """获取平台的客户端对象。""" + + async def webhook_callback(self, request: Any) -> Any: + """统一 Webhook 回调入口。 + + 支持统一 Webhook 模式的平台需要实现此方法。 + 当 Dashboard 收到 /api/platform/webhook/{uuid} 请求时,会调用此方法。 + + Args: + request: Quart 请求对象 + + Returns: + 响应内容,格式取决于具体平台的要求 + + Raises: + NotImplementedError: 平台未实现统一 Webhook 模式 + """ + raise NotImplementedError(f"平台 {self.meta().name} 未实现统一 Webhook 模式") diff --git a/astrbot/core/platform/platform_metadata.py b/astrbot/core/platform/platform_metadata.py index c63bd82b1..06455aac4 100644 --- a/astrbot/core/platform/platform_metadata.py +++ b/astrbot/core/platform/platform_metadata.py @@ -7,7 +7,7 @@ class PlatformMetadata: """平台的名称,即平台的类型,如 aiocqhttp, discord, slack""" description: str """平台的描述""" - id: str | None = None + id: str """平台的唯一标识符,用于配置中识别特定平台""" default_config_tmpl: dict | None = None diff --git a/astrbot/core/platform/register.py b/astrbot/core/platform/register.py index c1721c5c5..5f550ecd1 100644 --- a/astrbot/core/platform/register.py +++ b/astrbot/core/platform/register.py @@ -40,6 +40,7 @@ def register_platform_adapter( pm = PlatformMetadata( name=adapter_name, description=desc, + id=adapter_name, default_config_tmpl=default_config_tmpl, adapter_display_name=adapter_display_name, logo_path=logo_path, diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index ce8fd56df..293b462d3 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -70,16 +70,18 @@ class AiocqhttpMessageEvent(AstrMessageEvent): bot: CQHttp, event: Event | None, is_group: bool, - session_id: str, + session_id: str | None, messages: list[dict], ): # session_id 必须是纯数字字符串 - session_id = int(session_id) if session_id.isdigit() else None + session_id_int = ( + int(session_id) if session_id and session_id.isdigit() else None + ) - if is_group and isinstance(session_id, int): - await bot.send_group_msg(group_id=session_id, message=messages) - elif not is_group and isinstance(session_id, int): - await bot.send_private_msg(user_id=session_id, message=messages) + if is_group and isinstance(session_id_int, int): + await bot.send_group_msg(group_id=session_id_int, message=messages) + elif not is_group and isinstance(session_id_int, int): + await bot.send_private_msg(user_id=session_id_int, message=messages) elif isinstance(event, Event): # 最后兜底 await bot.send(event=event, message=messages) else: diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index 8e8bcdb30..b3c2229ab 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -4,7 +4,7 @@ import logging import time import uuid from collections.abc import Awaitable -from typing import Any +from typing import Any, cast from aiocqhttp import CQHttp, Event from aiocqhttp.exceptions import ActionFailed @@ -38,9 +38,8 @@ class AiocqhttpAdapter(Platform): platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(event_queue) + super().__init__(platform_config, event_queue) - self.config = platform_config self.settings = platform_settings self.unique_session = platform_settings["unique_session"] self.host = platform_config["ws_reverse_host"] @@ -49,7 +48,7 @@ class AiocqhttpAdapter(Platform): self.metadata = PlatformMetadata( name="aiocqhttp", description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。", - id=self.config.get("id"), + id=cast(str, self.config.get("id")), support_streaming_message=False, ) @@ -128,7 +127,9 @@ class AiocqhttpAdapter(Platform): """OneBot V11 请求类事件""" abm = AstrBotMessage() abm.self_id = str(event.self_id) - abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id) + abm.sender = MessageMember( + user_id=str(event.user_id), nickname=str(event.user_id) + ) abm.type = MessageType.OTHER_MESSAGE if event.get("group_id"): abm.type = MessageType.GROUP_MESSAGE @@ -154,7 +155,9 @@ class AiocqhttpAdapter(Platform): """OneBot V11 通知类事件""" abm = AstrBotMessage() abm.self_id = str(event.self_id) - abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id) + abm.sender = MessageMember( + user_id=str(event.user_id), nickname=str(event.user_id) + ) abm.type = MessageType.OTHER_MESSAGE if event.get("group_id"): abm.group_id = str(event.group_id) @@ -193,6 +196,7 @@ class AiocqhttpAdapter(Platform): @param event: 事件对象 @param get_reply: 是否获取回复消息。这个参数是为了防止多个回复嵌套。 """ + assert event.sender is not None abm = AstrBotMessage() abm.self_id = str(event.self_id) abm.sender = MessageMember( @@ -202,6 +206,7 @@ class AiocqhttpAdapter(Platform): if event["message_type"] == "group": abm.type = MessageType.GROUP_MESSAGE abm.group_id = str(event.group_id) + abm.group = Group(str(event.group_id)) abm.group.group_name = event.get("group_name", "N/A") elif event["message_type"] == "private": abm.type = MessageType.FRIEND_MESSAGE @@ -227,7 +232,7 @@ class AiocqhttpAdapter(Platform): await self.bot.send(event, err) except BaseException as e: logger.error(f"回复消息失败: {e}") - return None + raise ValueError(err) # 按消息段类型类型适配 for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]): diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py index 7ad612ef6..8905698a5 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py @@ -2,6 +2,7 @@ import asyncio import os import threading import uuid +from typing import cast import aiohttp import dingtalk_stream @@ -47,21 +48,21 @@ class DingtalkPlatformAdapter(Platform): platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(event_queue) - - self.config = platform_config + super().__init__(platform_config, event_queue) self.unique_session = platform_settings["unique_session"] self.client_id = platform_config["client_id"] self.client_secret = platform_config["client_secret"] + outer_self = self + class AstrCallbackClient(dingtalk_stream.ChatbotHandler): - async def process(self_, message: dingtalk_stream.CallbackMessage): + async def process(self, message: dingtalk_stream.CallbackMessage): logger.debug(f"dingtalk: {message.data}") im = dingtalk_stream.ChatbotMessage.from_dict(message.data) - abm = await self.convert_msg(im) - await self.handle_msg(abm) + abm = await outer_self.convert_msg(im) + await outer_self.handle_msg(abm) return AckMessage.STATUS_OK, "OK" @@ -75,14 +76,15 @@ class DingtalkPlatformAdapter(Platform): self.client, ) self.client_ = client # 用于 websockets 的 client + self._shutdown_event: threading.Event | None = None - def _id_to_sid(self, dingtalk_id: str | None) -> str | None: + def _id_to_sid(self, dingtalk_id: str | None) -> str: if not dingtalk_id: - return dingtalk_id + return dingtalk_id or "unknown" prefix = "$:LWCP_v1:$" if dingtalk_id.startswith(prefix): return dingtalk_id[len(prefix) :] - return dingtalk_id + return dingtalk_id or "unknown" async def send_by_session( self, @@ -95,7 +97,7 @@ class DingtalkPlatformAdapter(Platform): return PlatformMetadata( name="dingtalk", description="钉钉机器人官方 API 适配器", - id=self.config.get("id"), + id=cast(str, self.config.get("id")), support_streaming_message=False, ) @@ -106,7 +108,7 @@ class DingtalkPlatformAdapter(Platform): abm = AstrBotMessage() abm.message = [] abm.message_str = "" - abm.timestamp = int(message.create_at / 1000) + abm.timestamp = int(cast(int, message.create_at) / 1000) abm.type = ( MessageType.GROUP_MESSAGE if message.conversation_type == "2" @@ -117,7 +119,7 @@ class DingtalkPlatformAdapter(Platform): nickname=message.sender_nick, ) abm.self_id = self._id_to_sid(message.chatbot_user_id) - abm.message_id = message.message_id + abm.message_id = cast(str, message.message_id) abm.raw_message = message if abm.type == MessageType.GROUP_MESSAGE: @@ -134,14 +136,16 @@ class DingtalkPlatformAdapter(Platform): else: abm.session_id = abm.sender.user_id - message_type: str = message.message_type + message_type: str = cast(str, message.message_type) match message_type: case "text": abm.message_str = message.text.content.strip() abm.message.append(Plain(abm.message_str)) case "richText": - rtc: dingtalk_stream.RichTextContent = message.rich_text_content - contents: list[dict] = rtc.rich_text_list + rtc: dingtalk_stream.RichTextContent = cast( + dingtalk_stream.RichTextContent, message.rich_text_content + ) + contents: list[dict] = cast(list[dict], rtc.rich_text_list) for content in contents: plains = "" if "text" in content: @@ -150,7 +154,7 @@ class DingtalkPlatformAdapter(Platform): elif "type" in content and content["type"] == "picture": f_path = await self.download_ding_file( content["downloadCode"], - message.robot_code, + cast(str, message.robot_code), "jpg", ) abm.message.append(Image.fromFileSystem(f_path)) @@ -195,7 +199,7 @@ class DingtalkPlatformAdapter(Platform): logger.error( f"下载钉钉文件失败: {resp.status}, {await resp.text()}", ) - return None + return "" resp_data = await resp.json() download_url = resp_data["data"]["downloadUrl"] await download_file(download_url, f_path) @@ -215,7 +219,7 @@ class DingtalkPlatformAdapter(Platform): logger.error( f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}", ) - return None + return "" return (await resp.json())["data"]["accessToken"] async def handle_msg(self, abm: AstrBotMessage): @@ -252,9 +256,11 @@ class DingtalkPlatformAdapter(Platform): def monkey_patch_close(): raise KeyboardInterrupt("Graceful shutdown") - self.client_.open_connection = monkey_patch_close - await self.client_.websocket.close(code=1000, reason="Graceful shutdown") - self._shutdown_event.set() + if self.client_.websocket is not None: + self.client_.open_connection = monkey_patch_close + await self.client_.websocket.close(code=1000, reason="Graceful shutdown") + if self._shutdown_event is not None: + self._shutdown_event.set() def get_client(self): return self.client diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py index a1cd9c1aa..d520189d8 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_event.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_event.py @@ -1,4 +1,5 @@ import asyncio +from typing import cast import dingtalk_stream @@ -32,7 +33,7 @@ class DingtalkMessageEvent(AstrMessageEvent): client.reply_markdown, segment.text, segment.text, - self.message_obj.raw_message, + cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message), ) elif isinstance(segment, Comp.Image): markdown_str = "" @@ -53,7 +54,9 @@ class DingtalkMessageEvent(AstrMessageEvent): client.reply_markdown, "😄", markdown_str, - self.message_obj.raw_message, + cast( + dingtalk_stream.ChatbotMessage, self.message_obj.raw_message + ), ) logger.debug(f"send image: {ret}") diff --git a/astrbot/core/platform/sources/discord/client.py b/astrbot/core/platform/sources/discord/client.py index 5d29e3429..ac0610f2a 100644 --- a/astrbot/core/platform/sources/discord/client.py +++ b/astrbot/core/platform/sources/discord/client.py @@ -1,4 +1,5 @@ import sys +from collections.abc import Awaitable, Callable import discord @@ -27,13 +28,16 @@ class DiscordBotClient(discord.Bot): super().__init__(intents=intents, proxy=proxy) # 回调函数 - self.on_message_received = None - self.on_ready_once_callback = None + self.on_message_received: Callable[[dict], Awaitable[None]] | None = None + self.on_ready_once_callback: Callable[[], Awaitable[None]] | None = None self._ready_once_fired = False - @override async def on_ready(self): """当机器人成功连接并准备就绪时触发""" + if self.user is None: + logger.error("[Discord] 客户端未正确加载用户信息 (self.user is None)") + return + logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录") logger.info("[Discord] 客户端已准备就绪。") @@ -49,6 +53,9 @@ class DiscordBotClient(discord.Bot): def _create_message_data(self, message: discord.Message) -> dict: """从 discord.Message 创建数据字典""" + if self.user is None: + raise RuntimeError("Bot is not ready: self.user is None") + is_mentioned = self.user in message.mentions return { "message": message, @@ -66,6 +73,12 @@ class DiscordBotClient(discord.Bot): def _create_interaction_data(self, interaction: discord.Interaction) -> dict: """从 discord.Interaction 创建数据字典""" + if self.user is None: + raise RuntimeError("Bot is not ready: self.user is None") + + if interaction.user is None: + raise ValueError("Interaction received without a valid user") + return { "interaction": interaction, "bot_id": str(self.user.id), @@ -80,7 +93,6 @@ class DiscordBotClient(discord.Bot): "type": "interaction", } - @override async def on_message(self, message: discord.Message): """当接收到消息时触发""" if message.author.bot: diff --git a/astrbot/core/platform/sources/discord/components.py b/astrbot/core/platform/sources/discord/components.py index d3e69e763..f875652a0 100644 --- a/astrbot/core/platform/sources/discord/components.py +++ b/astrbot/core/platform/sources/discord/components.py @@ -97,8 +97,8 @@ class DiscordView(BaseMessageComponent): def __init__( self, - components: list[BaseMessageComponent] = None, - timeout: float = None, + components: list[BaseMessageComponent] | None = None, + timeout: float | None = None, ): self.components = components or [] self.timeout = timeout diff --git a/astrbot/core/platform/sources/discord/discord_platform_adapter.py b/astrbot/core/platform/sources/discord/discord_platform_adapter.py index 49b886dea..50aa0fe6f 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_adapter.py +++ b/astrbot/core/platform/sources/discord/discord_platform_adapter.py @@ -1,10 +1,10 @@ import asyncio import re import sys -from typing import Any +from typing import Any, cast import discord -from discord.abc import Messageable +from discord.abc import GuildChannel, Messageable, PrivateChannel from discord.channel import DMChannel from astrbot import logger @@ -44,10 +44,9 @@ class DiscordPlatformAdapter(Platform): platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(event_queue) - self.config = platform_config + super().__init__(platform_config, event_queue) self.settings = platform_settings - self.client_self_id = None + self.client_self_id: str | None = None self.registered_handlers = [] # 指令注册相关 self.enable_command_register = self.config.get("discord_command_register", True) @@ -63,6 +62,12 @@ class DiscordPlatformAdapter(Platform): message_chain: MessageChain, ): """通过会话发送消息""" + if self.client.user is None: + logger.error( + "[Discord] 客户端未就绪 (self.client.user is None),无法发送消息" + ) + return + # 创建一个 message_obj 以便在 event 中使用 message_obj = AstrBotMessage() if "_" in session.session_id: @@ -90,7 +95,7 @@ class DiscordPlatformAdapter(Platform): user_id=str(self.client_self_id), nickname=self.client.user.display_name, ) - message_obj.self_id = self.client_self_id + message_obj.self_id = cast(str, self.client_self_id) message_obj.session_id = session.session_id message_obj.message = message_chain.chain @@ -111,7 +116,7 @@ class DiscordPlatformAdapter(Platform): return PlatformMetadata( "discord", "Discord 适配器", - id=self.config.get("id"), + id=cast(str, self.config.get("id")), default_config_tmpl=self.config, support_streaming_message=False, ) @@ -161,7 +166,7 @@ class DiscordPlatformAdapter(Platform): def _get_message_type( self, - channel: Messageable, + channel: Messageable | GuildChannel | PrivateChannel, guild_id: int | None = None, ) -> MessageType: """根据 channel 对象和 guild_id 判断消息类型""" @@ -171,13 +176,15 @@ class DiscordPlatformAdapter(Platform): return MessageType.FRIEND_MESSAGE return MessageType.GROUP_MESSAGE - def _get_channel_id(self, channel: Messageable) -> str: + def _get_channel_id( + self, channel: Messageable | GuildChannel | PrivateChannel + ) -> str: """根据 channel 对象获取ID""" return str(getattr(channel, "id", None)) def _convert_message_to_abm(self, data: dict) -> AstrBotMessage: """将普通消息转换为 AstrBotMessage""" - message: discord.Message = data["message"] + message = data["message"] content = message.content @@ -234,7 +241,7 @@ class DiscordPlatformAdapter(Platform): ) abm.message = message_chain abm.raw_message = message - abm.self_id = self.client_self_id + abm.self_id = cast(str, self.client_self_id) abm.session_id = str(message.channel.id) abm.message_id = str(message.id) return abm @@ -255,32 +262,52 @@ class DiscordPlatformAdapter(Platform): interaction_followup_webhook=followup_webhook, ) + if self.client.user is None: + logger.error( + "[Discord] 客户端未就绪 (self.client.user is None),无法处理消息" + ) + return + # 检查是否为斜杠指令 is_slash_command = message_event.interaction_followup_webhook is not None + # 1. 优先处理斜杠指令 + if is_slash_command: + message_event.is_wake = True + message_event.is_at_or_wake_command = True + self.commit_event(message_event) + return + + # 2. 处理普通消息(提及检测) + # 确保 raw_message 是 discord.Message 类型,以便静态检查通过 + raw_message = message.raw_message + if not isinstance(raw_message, discord.Message): + logger.warning( + f"[Discord] 收到非 Message 类型的消息: {type(raw_message)},已忽略。" + ) + return + # 检查是否被@(User Mention 或 Bot 拥有的 Role Mention) is_mention = False + # User Mention - if ( - self.client - and self.client.user - and hasattr(message.raw_message, "mentions") - ): - if self.client.user in message.raw_message.mentions: - is_mention = True + # 此时 Pylance 知道 raw_message 是 discord.Message,具有 mentions 属性 + if self.client.user in raw_message.mentions: + is_mention = True + # Role Mention(Bot 拥有的角色被提及) - if not is_mention and hasattr(message.raw_message, "role_mentions"): + if not is_mention and raw_message.role_mentions: bot_member = None - if hasattr(message.raw_message, "guild") and message.raw_message.guild: + if raw_message.guild: try: - bot_member = message.raw_message.guild.get_member( + bot_member = raw_message.guild.get_member( self.client.user.id, ) except Exception: bot_member = None if bot_member and hasattr(bot_member, "roles"): bot_roles = set(bot_member.roles) - mentioned_roles = set(message.raw_message.role_mentions) + mentioned_roles = set(raw_message.role_mentions) if ( bot_roles and mentioned_roles @@ -288,8 +315,8 @@ class DiscordPlatformAdapter(Platform): ): is_mention = True - # 如果是斜杠指令或被@的消息,设置为唤醒状态 - if is_slash_command or is_mention: + # 如果是被@的消息,设置为唤醒状态 + if is_mention: message_event.is_wake = True message_event.is_at_or_wake_command = True @@ -425,7 +452,7 @@ class DiscordPlatformAdapter(Platform): ) abm.message = [Plain(text=message_str_for_filter)] abm.raw_message = ctx.interaction - abm.self_id = self.client_self_id + abm.self_id = cast(str, self.client_self_id) abm.session_id = str(ctx.channel_id) abm.message_id = str(ctx.interaction.id) @@ -438,7 +465,7 @@ class DiscordPlatformAdapter(Platform): def _extract_command_info( event_filter: Any, handler_metadata: StarHandlerMetadata, - ) -> tuple[str, str, CommandFilter] | None: + ) -> tuple[str, str, CommandFilter | None] | None: """从事件过滤器中提取指令信息""" cmd_name = None # is_group = False diff --git a/astrbot/core/platform/sources/discord/discord_platform_event.py b/astrbot/core/platform/sources/discord/discord_platform_event.py index 82eb9f144..053018225 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_event.py +++ b/astrbot/core/platform/sources/discord/discord_platform_event.py @@ -4,8 +4,10 @@ import binascii from collections.abc import AsyncGenerator from io import BytesIO from pathlib import Path +from typing import cast import discord +from discord.types.interactions import ComponentInteractionData from astrbot import logger from astrbot.api.event import AstrMessageEvent, MessageChain @@ -85,6 +87,9 @@ class DiscordPlatformEvent(AstrMessageEvent): channel = await self._get_channel() if not channel: return + if not isinstance(channel, discord.abc.Messageable): + logger.error(f"[Discord] 频道 {channel.id} 不是可发送消息的类型") + return await channel.send(**kwargs) except Exception as e: @@ -107,7 +112,9 @@ class DiscordPlatformEvent(AstrMessageEvent): await self.send(buffer) return await super().send_streaming(generator, use_fallback) - async def _get_channel(self) -> discord.abc.Messageable | None: + async def _get_channel( + self, + ) -> discord.Thread | discord.abc.GuildChannel | discord.abc.PrivateChannel | None: """获取当前事件对应的频道对象""" try: channel_id = int(self.session_id) @@ -121,7 +128,13 @@ class DiscordPlatformEvent(AstrMessageEvent): async def _parse_to_discord( self, message: MessageChain, - ) -> tuple[str, list[discord.File], discord.ui.View | None, list[discord.Embed]]: + ) -> tuple[ + str, + list[discord.File], + discord.ui.View | None, + list[discord.Embed], + str | int | None, + ]: """将 MessageChain 解析为 Discord 发送所需的内容""" content_parts = [] files = [] @@ -261,7 +274,9 @@ class DiscordPlatformEvent(AstrMessageEvent): self.message_obj.raw_message, "add_reaction", ): - await self.message_obj.raw_message.add_reaction(emoji) + await cast(discord.Message, self.message_obj.raw_message).add_reaction( + emoji + ) except Exception as e: logger.error(f"[Discord] 添加反应失败: {e}") @@ -270,7 +285,7 @@ class DiscordPlatformEvent(AstrMessageEvent): return ( hasattr(self.message_obj, "raw_message") and hasattr(self.message_obj.raw_message, "type") - and self.message_obj.raw_message.type + and cast(discord.Interaction, self.message_obj.raw_message).type == discord.InteractionType.application_command ) @@ -279,14 +294,18 @@ class DiscordPlatformEvent(AstrMessageEvent): return ( hasattr(self.message_obj, "raw_message") and hasattr(self.message_obj.raw_message, "type") - and self.message_obj.raw_message.type == discord.InteractionType.component + and cast(discord.Interaction, self.message_obj.raw_message).type + == discord.InteractionType.component ) def get_interaction_custom_id(self) -> str: """获取交互组件的custom_id""" if self.is_button_interaction(): try: - return self.message_obj.raw_message.data.get("custom_id", "") + return cast( + ComponentInteractionData, + cast(discord.Interaction, self.message_obj.raw_message).data, + ).get("custom_id", "") except Exception: pass return "" @@ -299,7 +318,9 @@ class DiscordPlatformEvent(AstrMessageEvent): ): return any( mention.id == int(self.message_obj.self_id) - for mention in self.message_obj.raw_message.mentions + for mention in cast( + discord.Message, self.message_obj.raw_message + ).mentions ) return False @@ -309,5 +330,5 @@ class DiscordPlatformEvent(AstrMessageEvent): self.message_obj.raw_message, "clean_content", ): - return self.message_obj.raw_message.clean_content + return cast(discord.Message, self.message_obj.raw_message).clean_content return self.message_str diff --git a/astrbot/core/platform/sources/lark/lark_adapter.py b/astrbot/core/platform/sources/lark/lark_adapter.py index e6e6d4d2b..473be096f 100644 --- a/astrbot/core/platform/sources/lark/lark_adapter.py +++ b/astrbot/core/platform/sources/lark/lark_adapter.py @@ -3,9 +3,14 @@ import base64 import json import re import uuid +from typing import cast import lark_oapi as lark -from lark_oapi.api.im.v1 import * +from lark_oapi.api.im.v1 import ( + CreateMessageRequest, + CreateMessageRequestBody, + GetMessageResourceRequest, +) import astrbot.api.message_components as Comp from astrbot import logger @@ -33,9 +38,7 @@ class LarkPlatformAdapter(Platform): platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(event_queue) - - self.config = platform_config + super().__init__(platform_config, event_queue) self.unique_session = platform_settings["unique_session"] @@ -76,6 +79,10 @@ class LarkPlatformAdapter(Platform): session: MessageSesion, message_chain: MessageChain, ): + if self.lark_api.im is None: + logger.error("[Lark] API Client im 模块未初始化,无法发送消息") + return + res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api) wrapped = { "zh_cn": { @@ -116,14 +123,21 @@ class LarkPlatformAdapter(Platform): return PlatformMetadata( name="lark", description="飞书机器人官方 API 适配器", - id=self.config.get("id"), + id=cast(str, self.config.get("id")), support_streaming_message=False, ) async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1): + if event.event is None: + logger.debug("[Lark] 收到空事件(event.event is None)") + return message = event.event.message + if message is None: + logger.debug("[Lark] 事件中没有消息体(message is None)") + return + abm = AstrBotMessage() - abm.timestamp = int(message.create_time) / 1000 + abm.timestamp = cast(int, message.create_time) // 1000 abm.message = [] abm.type = ( MessageType.GROUP_MESSAGE @@ -138,14 +152,28 @@ class LarkPlatformAdapter(Platform): at_list = {} if message.mentions: for m in message.mentions: - at_list[m.key] = Comp.At(qq=m.id.open_id, name=m.name) - if m.name == self.bot_name: - abm.self_id = m.id.open_id + if m.id is None: + continue + # 飞书 open_id 可能是 None,这里做个防护 + open_id = m.id.open_id if m.id.open_id else "" + at_list[m.key] = Comp.At(qq=open_id, name=m.name) - content_json_b = json.loads(message.content) + if m.name == self.bot_name: + if m.id.open_id is not None: + abm.self_id = m.id.open_id + + if message.content is None: + logger.warning("[Lark] 消息内容为空") + return + + try: + content_json_b = json.loads(message.content) + except json.JSONDecodeError: + logger.error(f"[Lark] 解析消息内容失败: {message.content}") + return if message.message_type == "text": - message_str_raw = content_json_b["text"] # 带有 @ 的消息 + message_str_raw = content_json_b.get("text", "") # 带有 @ 的消息 at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则 # at_users = re.findall(at_pattern, message_str_raw) # 拆分文本,去掉AT符号部分 @@ -170,27 +198,47 @@ class LarkPlatformAdapter(Platform): content_json_b = _ls elif message.message_type == "image": content_json_b = [ - {"tag": "img", "image_key": content_json_b["image_key"], "style": []}, + { + "tag": "img", + "image_key": content_json_b.get("image_key"), + "style": [], + }, ] if message.message_type in ("post", "image"): for comp in content_json_b: - if comp["tag"] == "at": - abm.message.append(at_list[comp["user_id"]]) - elif comp["tag"] == "text" and comp["text"].strip(): + if comp.get("tag") == "at": + user_id = comp.get("user_id") + if user_id in at_list: + abm.message.append(at_list[user_id]) + elif comp.get("tag") == "text" and comp.get("text", "").strip(): abm.message.append(Comp.Plain(comp["text"].strip())) - elif comp["tag"] == "img": - image_key = comp["image_key"] + elif comp.get("tag") == "img": + image_key = comp.get("image_key") + if not image_key: + continue + request = ( GetMessageResourceRequest.builder() - .message_id(message.message_id) + .message_id(cast(str, message.message_id)) .file_key(image_key) .type("image") .build() ) + + if self.lark_api.im is None: + logger.error("[Lark] API Client im 模块未初始化") + continue + response = await self.lark_api.im.v1.message_resource.aget(request) if not response.success(): logger.error(f"无法下载飞书图片: {image_key}") + continue + + if response.file is None: + logger.error(f"飞书图片响应中不包含文件流: {image_key}") + continue + image_bytes = response.file.read() image_base64 = base64.b64encode(image_bytes).decode() abm.message.append(Comp.Image.fromBase64(image_base64)) @@ -198,6 +246,19 @@ class LarkPlatformAdapter(Platform): for comp in abm.message: if isinstance(comp, Comp.Plain): abm.message_str += comp.text + + if message.message_id is None: + logger.error("[Lark] 消息缺少 message_id") + return + + if ( + event.event.sender is None + or event.event.sender.sender_id is None + or event.event.sender.sender_id.open_id is None + ): + logger.error("[Lark] 消息发送者信息不完整") + return + abm.message_id = message.message_id abm.raw_message = message abm.sender = MessageMember( @@ -237,5 +298,5 @@ class LarkPlatformAdapter(Platform): await self.client._disconnect() logger.info("飞书(Lark) 适配器已被优雅地关闭") - def get_client(self) -> lark.Client: + def get_client(self) -> lark.ws.Client: return self.client diff --git a/astrbot/core/platform/sources/lark/lark_event.py b/astrbot/core/platform/sources/lark/lark_event.py index 04204d35e..7b7d20b38 100644 --- a/astrbot/core/platform/sources/lark/lark_event.py +++ b/astrbot/core/platform/sources/lark/lark_event.py @@ -5,7 +5,15 @@ import uuid from io import BytesIO import lark_oapi as lark -from lark_oapi.api.im.v1 import * +from lark_oapi.api.im.v1 import ( + CreateImageRequest, + CreateImageRequestBody, + CreateMessageReactionRequest, + CreateMessageReactionRequestBody, + Emoji, + ReplyMessageRequest, + ReplyMessageRequestBody, +) from astrbot import logger from astrbot.api.event import AstrMessageEvent, MessageChain @@ -44,7 +52,7 @@ class LarkMessageEvent(AstrMessageEvent): file_path = comp.file.replace("file:///", "") elif comp.file and comp.file.startswith("http"): image_file_path = await download_image_by_url(comp.file) - file_path = image_file_path + file_path = image_file_path if image_file_path else "" elif comp.file and comp.file.startswith("base64://"): base64_str = comp.file.removeprefix("base64://") image_data = base64.b64decode(base64_str) @@ -54,10 +62,17 @@ class LarkMessageEvent(AstrMessageEvent): with open(file_path, "wb") as f: f.write(BytesIO(image_data).getvalue()) else: - file_path = comp.file + file_path = comp.file if comp.file else "" if image_file is None: - image_file = open(file_path, "rb") + if not file_path: + logger.error("[Lark] 图片路径为空,无法上传") + continue + try: + image_file = open(file_path, "rb") + except Exception as e: + logger.error(f"[Lark] 无法打开图片文件: {e}") + continue request = ( CreateImageRequest.builder() @@ -69,9 +84,20 @@ class LarkMessageEvent(AstrMessageEvent): ) .build() ) + + if lark_client.im is None: + logger.error("[Lark] API Client im 模块未初始化,无法上传图片") + continue + response = await lark_client.im.v1.image.acreate(request) if not response.success(): logger.error(f"无法上传飞书图片({response.code}): {response.msg}") + continue + + if response.data is None: + logger.error("[Lark] 上传图片成功但未返回数据(data is None)") + continue + image_key = response.data.image_key logger.debug(image_key) ret.append(_stage) @@ -107,6 +133,10 @@ class LarkMessageEvent(AstrMessageEvent): .build() ) + if self.bot.im is None: + logger.error("[Lark] API Client im 模块未初始化,无法回复消息") + return + response = await self.bot.im.v1.message.areply(request) if not response.success(): @@ -115,6 +145,10 @@ class LarkMessageEvent(AstrMessageEvent): await super().send(message) async def react(self, emoji: str): + if self.bot.im is None: + logger.error("[Lark] API Client im 模块未初始化,无法发送表情") + return + request = ( CreateMessageReactionRequest.builder() .message_id(self.message_obj.message_id) @@ -125,6 +159,7 @@ class LarkMessageEvent(AstrMessageEvent): ) .build() ) + response = await self.bot.im.v1.message_reaction.acreate(request) if not response.success(): logger.error(f"发送飞书表情回应失败({response.code}): {response.msg}") diff --git a/astrbot/core/platform/sources/misskey/misskey_adapter.py b/astrbot/core/platform/sources/misskey/misskey_adapter.py index ddeec93bc..7f3db3062 100644 --- a/astrbot/core/platform/sources/misskey/misskey_adapter.py +++ b/astrbot/core/platform/sources/misskey/misskey_adapter.py @@ -1,7 +1,6 @@ import asyncio import os import random -from collections.abc import Awaitable from typing import Any import astrbot.api.message_components as Comp @@ -55,8 +54,7 @@ class MisskeyPlatformAdapter(Platform): platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(event_queue) - self.config = platform_config or {} + super().__init__(platform_config or {}, event_queue) self.settings = platform_settings or {} self.instance_url = self.config.get("misskey_instance_url", "") self.access_token = self.config.get("misskey_token", "") @@ -204,7 +202,7 @@ class MisskeyPlatformAdapter(Platform): if not isinstance(message.raw_message, dict): message.raw_message = {} message.raw_message["poll"] = poll - message.poll = poll + message.__setattr__("poll", poll) except Exception: pass @@ -373,7 +371,7 @@ class MisskeyPlatformAdapter(Platform): self, session: MessageSession, message_chain: MessageChain, - ) -> Awaitable[Any]: + ) -> None: if not self.api: logger.error("[Misskey] API 客户端未初始化") return await super().send_by_session(session, message_chain) diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py index fe1496644..d693c4206 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_message_event.py @@ -3,6 +3,7 @@ import base64 import os import random import uuid +from typing import cast import aiofiles import botpy @@ -60,7 +61,10 @@ class QQOfficialMessageEvent(AstrMessageEvent): time_since_last_edit = current_time - last_edit_time if time_since_last_edit >= throttle_interval: - ret = await self._post_send(stream=stream_payload) + ret = cast( + message.Message, + await self._post_send(stream=stream_payload), + ) stream_payload["index"] += 1 stream_payload["id"] = ret["id"] last_edit_time = asyncio.get_event_loop().time() @@ -69,6 +73,8 @@ class QQOfficialMessageEvent(AstrMessageEvent): # 结束流式对话,并且传输 buffer 中剩余的消息 stream_payload["state"] = 10 ret = await self._post_send(stream=stream_payload) + else: + ret = await self._post_send() except Exception as e: logger.error(f"发送流式消息时出错: {e}", exc_info=True) @@ -81,7 +87,8 @@ class QQOfficialMessageEvent(AstrMessageEvent): return None source = self.message_obj.raw_message - assert isinstance( + + if not isinstance( source, ( botpy.message.Message, @@ -89,7 +96,9 @@ class QQOfficialMessageEvent(AstrMessageEvent): botpy.message.DirectMessage, botpy.message.C2CMessage, ), - ) + ): + logger.warning(f"[QQOfficial] 不支持的消息源类型: {type(source)}") + return None ( plain_text, @@ -106,7 +115,7 @@ class QQOfficialMessageEvent(AstrMessageEvent): ): return None - payload = { + payload: dict = { "content": plain_text, "msg_id": self.message_obj.message_id, } @@ -116,8 +125,12 @@ class QQOfficialMessageEvent(AstrMessageEvent): ret = None - match type(source): - case botpy.message.GroupMessage: + match source: + case botpy.message.GroupMessage(): + if not source.group_openid: + logger.error("[QQOfficial] GroupMessage 缺少 group_openid") + return None + if image_base64: media = await self.upload_group_and_c2c_image( image_base64, @@ -138,7 +151,8 @@ class QQOfficialMessageEvent(AstrMessageEvent): group_openid=source.group_openid, **payload, ) - case botpy.message.C2CMessage: + + case botpy.message.C2CMessage(): if image_base64: media = await self.upload_group_and_c2c_image( image_base64, @@ -167,18 +181,23 @@ class QQOfficialMessageEvent(AstrMessageEvent): **payload, ) logger.debug(f"Message sent to C2C: {ret}") - case botpy.message.Message: + + case botpy.message.Message(): if image_path: payload["file_image"] = image_path ret = await self.bot.api.post_message( channel_id=source.channel_id, **payload, ) - case botpy.message.DirectMessage: + + case botpy.message.DirectMessage(): if image_path: payload["file_image"] = image_path ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload) + case _: + pass + await super().send(self.send_buffer) self.send_buffer = None @@ -196,18 +215,33 @@ class QQOfficialMessageEvent(AstrMessageEvent): "file_type": file_type, "srv_send_msg": False, } + + result = None if "openid" in kwargs: payload["openid"] = kwargs["openid"] route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"]) - return await self.bot.api._http.request(route, json=payload) - if "group_openid" in kwargs: + result = await self.bot.api._http.request(route, json=payload) + elif "group_openid" in kwargs: payload["group_openid"] = kwargs["group_openid"] route = Route( "POST", "/v2/groups/{group_openid}/files", group_openid=kwargs["group_openid"], ) - return await self.bot.api._http.request(route, json=payload) + result = await self.bot.api._http.request(route, json=payload) + else: + raise ValueError("Invalid upload parameters") + + if not isinstance(result, dict): + raise RuntimeError( + f"Failed to upload image, response is not dict: {result}" + ) + + return Media( + file_uuid=result["file_uuid"], + file_info=result["file_info"], + ttl=result.get("ttl", 0), + ) async def upload_group_and_c2c_record( self, @@ -250,11 +284,14 @@ class QQOfficialMessageEvent(AstrMessageEvent): result = await self.bot.api._http.request(route, json=payload) if result: + if not isinstance(result, dict): + logger.error(f"上传文件响应格式错误: {result}") + return None + return Media( - file_uuid=result.get("file_uuid"), - file_info=result.get("file_info"), + file_uuid=result["file_uuid"], + file_info=result["file_info"], ttl=result.get("ttl", 0), - file_id=result.get("id", ""), ) except Exception as e: logger.error(f"上传请求错误: {e}") @@ -271,7 +308,7 @@ class QQOfficialMessageEvent(AstrMessageEvent): message_reference: message.Reference | None = None, media: message.Media | None = None, msg_id: str | None = None, - msg_seq: str = 1, + msg_seq: int | None = 1, event_id: str | None = None, markdown: message.MarkdownPayload | None = None, keyboard: message.Keyboard | None = None, @@ -280,7 +317,14 @@ class QQOfficialMessageEvent(AstrMessageEvent): payload = locals() payload.pop("self", None) route = Route("POST", "/v2/users/{openid}/messages", openid=openid) - return await self.bot.api._http.request(route, json=payload) + result = await self.bot.api._http.request(route, json=payload) + + if not isinstance(result, dict): + raise RuntimeError( + f"Failed to post c2c message, response is not dict: {result}" + ) + + return message.Message(**result) @staticmethod async def _parse_to_qqofficial(message: MessageChain): @@ -300,8 +344,10 @@ class QQOfficialMessageEvent(AstrMessageEvent): image_base64 = file_to_base64(image_file_path) elif i.file and i.file.startswith("base64://"): image_base64 = i.file - else: + elif i.file: image_base64 = file_to_base64(i.file) + else: + raise ValueError("Unsupported image file format") image_base64 = image_base64.removeprefix("base64://") elif isinstance(i, Record): if i.file: diff --git a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py index 96be734fd..2a1bcda47 100644 --- a/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py +++ b/astrbot/core/platform/sources/qqofficial/qqofficial_platform_adapter.py @@ -4,6 +4,7 @@ import asyncio import logging import os import time +from typing import cast import botpy import botpy.message @@ -44,7 +45,9 @@ class botClient(Client): MessageType.GROUP_MESSAGE, ) abm.session_id = ( - abm.sender.user_id if self.platform.unique_session else message.group_openid + abm.sender.user_id + if self.platform.unique_session + else cast(str, message.group_openid) ) self._commit(abm) @@ -97,13 +100,11 @@ class QQOfficialPlatformAdapter(Platform): platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(event_queue) - - self.config = platform_config + super().__init__(platform_config, event_queue) self.appid = platform_config["appid"] self.secret = platform_config["secret"] - self.unique_session = platform_settings["unique_session"] + self.unique_session: bool = platform_settings["unique_session"] qq_group = platform_config["enable_group_c2c"] guild_dm = platform_config["enable_guild_direct_message"] @@ -139,12 +140,15 @@ class QQOfficialPlatformAdapter(Platform): return PlatformMetadata( name="qq_official", description="QQ 机器人官方 API 适配器", - id=self.config.get("id"), + id=cast(str, self.config.get("id")), ) @staticmethod def _parse_from_qqofficial( - message: botpy.message.Message | botpy.message.GroupMessage, + message: botpy.message.Message + | botpy.message.GroupMessage + | botpy.message.DirectMessage + | botpy.message.C2CMessage, message_type: MessageType, ): abm = AstrBotMessage() @@ -152,7 +156,7 @@ class QQOfficialPlatformAdapter(Platform): abm.timestamp = int(time.time()) abm.raw_message = message abm.message_id = message.id - abm.tag = "qq_official" + # abm.tag = "qq_official" msg: list[BaseMessageComponent] = [] if isinstance(message, botpy.message.GroupMessage) or isinstance( @@ -182,9 +186,9 @@ class QQOfficialPlatformAdapter(Platform): message, botpy.message.DirectMessage, ): - try: + if isinstance(message, botpy.message.Message): abm.self_id = str(message.mentions[0].id) - except BaseException as _: + else: abm.self_id = "" plain_content = message.content.replace( diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py index 2b8c0b420..63b6726fe 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_adapter.py @@ -1,5 +1,6 @@ import asyncio import logging +from typing import Any, cast import botpy import botpy.message @@ -11,6 +12,7 @@ from astrbot import logger from astrbot.api.event import MessageChain from astrbot.api.platform import AstrBotMessage, MessageType, Platform, PlatformMetadata from astrbot.core.platform.astr_message_event import MessageSesion +from astrbot.core.utils.webhook_utils import log_webhook_info from ...register import register_platform_adapter from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter @@ -34,7 +36,9 @@ class botClient(Client): MessageType.GROUP_MESSAGE, ) abm.session_id = ( - abm.sender.user_id if self.platform.unique_session else message.group_openid + abm.sender.user_id + if self.platform.unique_session + else cast(str, message.group_openid) ) self._commit(abm) @@ -87,13 +91,12 @@ class QQOfficialWebhookPlatformAdapter(Platform): platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(event_queue) - - self.config = platform_config + super().__init__(platform_config, event_queue) self.appid = platform_config["appid"] self.secret = platform_config["secret"] self.unique_session = platform_settings["unique_session"] + self.unified_webhook_mode = platform_config.get("unified_webhook_mode", False) intents = botpy.Intents( public_messages=True, @@ -106,6 +109,7 @@ class QQOfficialWebhookPlatformAdapter(Platform): timeout=20, ) self.client.set_platform(self) + self.webhook_helper = None async def send_by_session( self, @@ -118,7 +122,7 @@ class QQOfficialWebhookPlatformAdapter(Platform): return PlatformMetadata( name="qq_official_webhook", description="QQ 机器人官方 API 适配器", - id=self.config.get("id"), + id=cast(str, self.config.get("id")), ) async def run(self): @@ -128,16 +132,37 @@ class QQOfficialWebhookPlatformAdapter(Platform): self.client, ) await self.webhook_helper.initialize() - await self.webhook_helper.start_polling() + + # 如果启用统一 webhook 模式,则不启动独立服务器 + webhook_uuid = self.config.get("webhook_uuid") + if self.unified_webhook_mode and webhook_uuid: + log_webhook_info(f"{self.meta().id}(QQ 官方机器人 Webhook)", webhook_uuid) + # 保持运行状态,等待 shutdown + await self.webhook_helper.shutdown_event.wait() + else: + await self.webhook_helper.start_polling() def get_client(self) -> botClient: return self.client + async def webhook_callback(self, request: Any) -> Any: + """统一 Webhook 回调入口""" + if not self.webhook_helper: + return {"error": "Webhook helper not initialized"}, 500 + + # 复用 webhook_helper 的回调处理逻辑 + return await self.webhook_helper.handle_callback(request) + async def terminate(self): - self.webhook_helper.shutdown_event.set() + if self.webhook_helper: + self.webhook_helper.shutdown_event.set() await self.client.close() - try: - await self.webhook_helper.server.shutdown() - except Exception as _: - pass + if self.webhook_helper and not self.unified_webhook_mode: + try: + await self.webhook_helper.server.shutdown() + except Exception as exc: + logger.warning( + f"Exception occurred during QQOfficialWebhook server shutdown: {exc}", + exc_info=True, + ) logger.info("QQ 机器人官方 API 适配器已经被优雅地关闭") diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py index 65b7c701a..2eda11a6c 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py @@ -1,5 +1,6 @@ import asyncio import logging +from typing import cast import quart from botpy import BotAPI, BotHttp, BotWebSocket, Client, ConnectionSession, Token @@ -78,7 +79,19 @@ class QQOfficialWebhook: return response async def callback(self): - msg: dict = await quart.request.json + """内部服务器的回调入口""" + return await self.handle_callback(quart.request) + + async def handle_callback(self, request) -> dict: + """处理 webhook 回调,可被统一 webhook 入口复用 + + Args: + request: Quart 请求对象 + + Returns: + 响应数据 + """ + msg: dict = await request.json logger.debug(f"收到 qq_official_webhook 回调: {msg}") event = msg.get("t") @@ -87,7 +100,7 @@ class QQOfficialWebhook: if opcode == 13: # validation - signed = await self.webhook_validation(data) + signed = await self.webhook_validation(cast(dict, data)) print(signed) return signed diff --git a/astrbot/core/platform/sources/satori/satori_adapter.py b/astrbot/core/platform/sources/satori/satori_adapter.py index fd90804f0..46f9a4e0f 100644 --- a/astrbot/core/platform/sources/satori/satori_adapter.py +++ b/astrbot/core/platform/sources/satori/satori_adapter.py @@ -38,8 +38,7 @@ class SatoriPlatformAdapter(Platform): platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(event_queue) - self.config = platform_config + super().__init__(platform_config, event_queue) self.settings = platform_settings self.api_base_url = self.config.get( diff --git a/astrbot/core/platform/sources/slack/client.py b/astrbot/core/platform/sources/slack/client.py index 0411f73a4..fbdc71759 100644 --- a/astrbot/core/platform/sources/slack/client.py +++ b/astrbot/core/platform/sources/slack/client.py @@ -4,9 +4,11 @@ import hmac import json import logging from collections.abc import Callable +from typing import cast from quart import Quart, Response, request from slack_sdk.socket_mode.aiohttp import SocketModeClient +from slack_sdk.socket_mode.async_client import AsyncBaseSocketModeClient from slack_sdk.socket_mode.request import SocketModeRequest from slack_sdk.socket_mode.response import SocketModeResponse from slack_sdk.web.async_client import AsyncWebClient @@ -47,51 +49,62 @@ class SlackWebhookClient: @self.app.route(self.path, methods=["POST"]) async def slack_events(): - """处理 Slack 事件""" - try: - # 获取请求体和头部 - body = await request.get_data() - event_data = json.loads(body.decode("utf-8")) - - # Verify Slack request signature - timestamp = request.headers.get("X-Slack-Request-Timestamp") - signature = request.headers.get("X-Slack-Signature") - if not timestamp or not signature: - return Response("Missing headers", status=400) - # Calculate the HMAC signature - sig_basestring = f"v0:{timestamp}:{body.decode('utf-8')}" - my_signature = ( - "v0=" - + hmac.new( - self.signing_secret.encode("utf-8"), - sig_basestring.encode("utf-8"), - hashlib.sha256, - ).hexdigest() - ) - # Verify the signature - if not hmac.compare_digest(my_signature, signature): - logger.warning("Slack request signature verification failed") - return Response("Invalid signature", status=400) - logger.info(f"Received Slack event: {event_data}") - - # 处理 URL 验证事件 - if event_data.get("type") == "url_verification": - return {"challenge": event_data.get("challenge")} - # 处理事件 - if self.event_handler and event_data.get("type") == "event_callback": - await self.event_handler(event_data) - - return Response("", status=200) - - except Exception as e: - logger.error(f"处理 Slack 事件时出错: {e}") - return Response("Internal Server Error", status=500) + """内部服务器的 POST 回调入口""" + return await self.handle_callback(request) @self.app.route("/health", methods=["GET"]) async def health_check(): """健康检查端点""" return {"status": "ok", "service": "slack-webhook"} + async def handle_callback(self, req): + """处理 Slack 回调请求,可被统一 webhook 入口复用 + + Args: + req: Quart 请求对象 + + Returns: + Response 对象或字典 + """ + try: + # 获取请求体和头部 + body = cast(bytes, await req.get_data()) + event_data = json.loads(body.decode("utf-8")) + + # Verify Slack request signature + timestamp = req.headers.get("X-Slack-Request-Timestamp") + signature = req.headers.get("X-Slack-Signature") + if not timestamp or not signature: + return Response("Missing headers", status=400) + # Calculate the HMAC signature + sig_basestring = f"v0:{timestamp}:{body.decode('utf-8')}" + my_signature = ( + "v0=" + + hmac.new( + self.signing_secret.encode("utf-8"), + sig_basestring.encode("utf-8"), + hashlib.sha256, + ).hexdigest() + ) + # Verify the signature + if not hmac.compare_digest(my_signature, signature): + logger.warning("Slack request signature verification failed") + return Response("Invalid signature", status=400) + logger.info(f"Received Slack event: {event_data}") + + # 处理 URL 验证事件 + if event_data.get("type") == "url_verification": + return {"challenge": event_data.get("challenge")} + # 处理事件 + if self.event_handler and event_data.get("type") == "event_callback": + await self.event_handler(event_data) + + return Response("", status=200) + + except Exception as e: + logger.error(f"处理 Slack 事件时出错: {e}") + return Response("Internal Server Error", status=500) + async def start(self): """启动 Webhook 服务器""" logger.info( @@ -128,9 +141,14 @@ class SlackSocketClient: self.event_handler = event_handler self.socket_client = None - async def _handle_events(self, _: SocketModeClient, req: SocketModeRequest): + async def _handle_events( + self, _: AsyncBaseSocketModeClient, req: SocketModeRequest + ): """处理 Socket Mode 事件""" try: + if self.socket_client is None: + raise RuntimeError("Socket client is not initialized") + # 确认收到事件 response = SocketModeResponse(envelope_id=req.envelope_id) await self.socket_client.send_socket_mode_response(response) diff --git a/astrbot/core/platform/sources/slack/slack_adapter.py b/astrbot/core/platform/sources/slack/slack_adapter.py index d5427deb7..4621f8494 100644 --- a/astrbot/core/platform/sources/slack/slack_adapter.py +++ b/astrbot/core/platform/sources/slack/slack_adapter.py @@ -3,8 +3,7 @@ import base64 import re import time import uuid -from collections.abc import Awaitable -from typing import Any +from typing import Any, cast import aiohttp from slack_sdk.socket_mode.request import SocketModeRequest @@ -21,6 +20,7 @@ from astrbot.api.platform import ( PlatformMetadata, ) from astrbot.core.platform.astr_message_event import MessageSesion +from astrbot.core.utils.webhook_utils import log_webhook_info from ...register import register_platform_adapter from .client import SlackSocketClient, SlackWebhookClient @@ -39,9 +39,7 @@ class SlackAdapter(Platform): platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(event_queue) - - self.config = platform_config + super().__init__(platform_config, event_queue) self.settings = platform_settings self.unique_session = platform_settings.get("unique_session", False) @@ -49,6 +47,7 @@ class SlackAdapter(Platform): self.app_token = platform_config.get("app_token") self.signing_secret = platform_config.get("signing_secret") self.connection_mode = platform_config.get("slack_connection_mode", "socket") + self.unified_webhook_mode = platform_config.get("unified_webhook_mode", False) self.webhook_host = platform_config.get("slack_webhook_host", "0.0.0.0") self.webhook_port = platform_config.get("slack_webhook_port", 3000) self.webhook_path = platform_config.get( @@ -68,7 +67,7 @@ class SlackAdapter(Platform): self.metadata = PlatformMetadata( name="slack", description="适用于 Slack 的消息平台适配器,支持 Socket Mode 和 Webhook Mode。", - id=self.config.get("id"), + id=cast(str, self.config.get("id")), support_streaming_message=False, ) @@ -118,13 +117,13 @@ class SlackAdapter(Platform): logger.debug(f"[slack] RawMessage {event}") abm = AstrBotMessage() - abm.self_id = self.bot_self_id + abm.self_id = cast(str, self.bot_self_id) # 获取用户信息 user_id = event.get("user", "") try: user_info = await self.web_client.users_info(user=user_id) - user_data = user_info["user"] + user_data = cast(dict, user_info["user"]) user_name = user_data.get("real_name") or user_data.get("name", user_id) except Exception: user_name = user_id @@ -135,7 +134,7 @@ class SlackAdapter(Platform): channel_id = event.get("channel", "") try: channel_info = await self.web_client.conversations_info(channel=channel_id) - is_im = channel_info["channel"]["is_im"] + is_im = cast(dict, channel_info["channel"])["is_im"] if is_im: abm.type = MessageType.FRIEND_MESSAGE @@ -178,7 +177,7 @@ class SlackAdapter(Platform): for mention in mentions: try: mentioned_user = await self.web_client.users_info(user=mention) - user_data = mentioned_user["user"] + user_data = cast(dict, mentioned_user["user"]) user_name = user_data.get("real_name") or user_data.get( "name", mention, @@ -329,7 +328,7 @@ class SlackAdapter(Platform): ) raise Exception(f"下载文件失败: {resp.status}") - async def run(self) -> Awaitable[Any]: + async def run(self) -> None: self.bot_self_id = await self.get_bot_user_id() logger.info(f"Slack auth test OK. Bot ID: {self.bot_self_id}") @@ -361,10 +360,17 @@ class SlackAdapter(Platform): self._handle_webhook_event, ) - logger.info( - f"Slack 适配器 (Webhook Mode) 启动中,监听 {self.webhook_host}:{self.webhook_port}{self.webhook_path}...", - ) - await self.webhook_client.start() + # 如果启用统一 webhook 模式,则不启动独立服务器 + webhook_uuid = self.config.get("webhook_uuid") + if self.unified_webhook_mode and webhook_uuid: + log_webhook_info(f"{self.meta().id}(Slack)", webhook_uuid) + # 保持运行状态,等待 shutdown + await self.webhook_client.shutdown_event.wait() + else: + logger.info( + f"Slack 适配器 (Webhook Mode) 启动中,监听 {self.webhook_host}:{self.webhook_port}{self.webhook_path}...", + ) + await self.webhook_client.start() else: raise ValueError( @@ -391,6 +397,13 @@ class SlackAdapter(Platform): if abm: await self.handle_msg(abm) + async def webhook_callback(self, request: Any) -> Any: + """统一 Webhook 回调入口""" + if self.connection_mode != "webhook" or not self.webhook_client: + return {"error": "Slack adapter is not in webhook mode"}, 400 + + return await self.webhook_client.handle_callback(request) + async def terminate(self): if self.socket_client: await self.socket_client.stop() diff --git a/astrbot/core/platform/sources/slack/slack_event.py b/astrbot/core/platform/sources/slack/slack_event.py index c918abbac..822e6fdeb 100644 --- a/astrbot/core/platform/sources/slack/slack_event.py +++ b/astrbot/core/platform/sources/slack/slack_event.py @@ -1,6 +1,7 @@ import asyncio import re -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Iterable +from typing import cast from slack_sdk.web.async_client import AsyncWebClient @@ -31,14 +32,14 @@ class SlackMessageEvent(AstrMessageEvent): async def _from_segment_to_slack_block( segment: BaseMessageComponent, web_client: AsyncWebClient, - ) -> dict: + ) -> dict | None: """将消息段转换为 Slack 块格式""" if isinstance(segment, Plain): return {"type": "section", "text": {"type": "mrkdwn", "text": segment.text}} if isinstance(segment, Image): # upload file url = segment.url or segment.file - if url.startswith("http"): + if url and url.startswith("http"): return { "type": "image", "image_url": url, @@ -55,7 +56,7 @@ class SlackMessageEvent(AstrMessageEvent): "type": "section", "text": {"type": "mrkdwn", "text": "图片上传失败"}, } - image_url = response["files"][0]["url_private"] + image_url = cast(list, response["files"])[0]["url_private"] logger.debug(f"Slack file upload response: {response}") return { "type": "image", @@ -77,7 +78,7 @@ class SlackMessageEvent(AstrMessageEvent): "type": "section", "text": {"type": "mrkdwn", "text": "文件上传失败"}, } - file_url = response["files"][0]["permalink"] + file_url = cast(list, response["files"])[0]["permalink"] return { "type": "section", "text": { @@ -85,7 +86,6 @@ class SlackMessageEvent(AstrMessageEvent): "text": f"文件: <{file_url}|{segment.name or '文件'}>", }, } - return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}} @staticmethod async def _parse_slack_blocks( @@ -115,7 +115,8 @@ class SlackMessageEvent(AstrMessageEvent): segment, web_client, ) - blocks.append(block) + if block: + blocks.append(block) # 如果最后还有文本内容 if text_content.strip(): @@ -225,10 +226,10 @@ class SlackMessageEvent(AstrMessageEvent): ) members = [] - for member_id in members_response["members"]: + for member_id in cast(Iterable, members_response["members"]): try: user_info = await self.web_client.users_info(user=member_id) - user_data = user_info["user"] + user_data = cast(dict, user_info["user"]) members.append( MessageMember( user_id=member_id, @@ -240,7 +241,7 @@ class SlackMessageEvent(AstrMessageEvent): # 如果获取用户信息失败,使用默认信息 members.append(MessageMember(user_id=member_id, nickname=member_id)) - channel_data = channel_info["channel"] + channel_data = cast(dict, channel_info["channel"]) return Group( group_id=channel_id, group_name=channel_data.get("name", ""), diff --git a/astrbot/core/platform/sources/telegram/tg_adapter.py b/astrbot/core/platform/sources/telegram/tg_adapter.py index 6b4d23f65..bca45ea8d 100644 --- a/astrbot/core/platform/sources/telegram/tg_adapter.py +++ b/astrbot/core/platform/sources/telegram/tg_adapter.py @@ -42,8 +42,7 @@ class TelegramPlatformAdapter(Platform): platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(event_queue) - self.config = platform_config + super().__init__(platform_config, event_queue) self.settings = platform_settings self.client_self_id = uuid.uuid4().hex[:8] diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index 34fd86ad9..37f60e65a 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -1,6 +1,7 @@ import asyncio import os import re +from typing import Any, cast import telegramify_markdown from telegram import ReactionTypeCustomEmoji, ReactionTypeEmoji @@ -17,8 +18,6 @@ from astrbot.api.message_components import ( Reply, ) from astrbot.api.platform import AstrBotMessage, MessageType, PlatformMetadata -from astrbot.core.utils.astrbot_path import get_astrbot_data_path -from astrbot.core.utils.io import download_file class TelegramPlatformEvent(AstrMessageEvent): @@ -97,7 +96,7 @@ class TelegramPlatformEvent(AstrMessageEvent): "chat_id": user_name, } if has_reply: - payload["reply_to_message_id"] = reply_message_id + payload["reply_to_message_id"] = str(reply_message_id) if message_thread_id: payload["message_thread_id"] = message_thread_id @@ -110,33 +109,30 @@ class TelegramPlatformEvent(AstrMessageEvent): try: md_text = telegramify_markdown.markdownify( chunk, - max_line_length=None, normalize_whitespace=False, ) await client.send_message( text=md_text, parse_mode="MarkdownV2", - **payload, + **cast(Any, payload), ) except Exception as e: logger.warning( f"MarkdownV2 send failed: {e}. Using plain text instead.", ) - await client.send_message(text=chunk, **payload) + await client.send_message(text=chunk, **cast(Any, payload)) elif isinstance(i, Image): image_path = await i.convert_to_file_path() - await client.send_photo(photo=image_path, **payload) + await client.send_photo(photo=image_path, **cast(Any, payload)) elif isinstance(i, File): - if i.file.startswith("https://"): - temp_dir = os.path.join(get_astrbot_data_path(), "temp") - path = os.path.join(temp_dir, i.name) - await download_file(i.file, path) - i.file = path - - await client.send_document(document=i.file, filename=i.name, **payload) + path = await i.get_file() + name = i.name or os.path.basename(path) + await client.send_document( + document=path, filename=name, **cast(Any, payload) + ) elif isinstance(i, Record): path = await i.convert_to_file_path() - await client.send_voice(voice=path, **payload) + await client.send_voice(voice=path, **cast(Any, payload)) async def send(self, message: MessageChain): if self.get_message_type() == MessageType.GROUP_MESSAGE: @@ -214,24 +210,23 @@ class TelegramPlatformEvent(AstrMessageEvent): delta += i.text elif isinstance(i, Image): image_path = await i.convert_to_file_path() - await self.client.send_photo(photo=image_path, **payload) + await self.client.send_photo( + photo=image_path, **cast(Any, payload) + ) continue elif isinstance(i, File): - if i.file.startswith("https://"): - temp_dir = os.path.join(get_astrbot_data_path(), "temp") - path = os.path.join(temp_dir, i.name) - await download_file(i.file, path) - i.file = path + path = await i.get_file() + name = i.name or os.path.basename(path) await self.client.send_document( - document=i.file, - filename=i.name, - **payload, + document=path, + filename=name, + **cast(Any, payload), ) continue elif isinstance(i, Record): path = await i.convert_to_file_path() - await self.client.send_voice(voice=path, **payload) + await self.client.send_voice(voice=path, **cast(Any, payload)) continue else: logger.warning(f"不支持的消息类型: {type(i)}") @@ -260,7 +255,9 @@ class TelegramPlatformEvent(AstrMessageEvent): else: # delta 长度一般不会大于 4096,因此这里直接发送 try: - msg = await self.client.send_message(text=delta, **payload) + msg = await self.client.send_message( + text=delta, **cast(Any, payload) + ) current_content = delta except Exception as e: logger.warning(f"发送消息失败(streaming): {e!s}") @@ -274,7 +271,6 @@ class TelegramPlatformEvent(AstrMessageEvent): try: markdown_text = telegramify_markdown.markdownify( delta, - max_line_length=None, normalize_whitespace=False, ) await self.client.edit_message_text( diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index ff5482f58..084d7860d 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -2,11 +2,13 @@ import asyncio import os import time import uuid -from collections.abc import Awaitable, Callable +from collections.abc import Callable, Coroutine from typing import Any from astrbot import logger -from astrbot.core.message.components import Image, Plain, Record +from astrbot.core import db_helper +from astrbot.core.db.po import PlatformMessageHistory +from astrbot.core.message.components import File, Image, Plain, Record, Reply, Video from astrbot.core.message.message_event_result import MessageChain from astrbot.core.platform import ( AstrBotMessage, @@ -74,9 +76,8 @@ class WebChatAdapter(Platform): platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(event_queue) + super().__init__(platform_config, event_queue) - self.config = platform_config self.settings = platform_settings self.unique_session = platform_settings["unique_session"] self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs") @@ -96,6 +97,92 @@ class WebChatAdapter(Platform): await WebChatMessageEvent._send(message_chain, session.session_id) await super().send_by_session(session, message_chain) + async def _get_message_history( + self, message_id: int + ) -> PlatformMessageHistory | None: + return await db_helper.get_platform_message_history_by_id(message_id) + + async def _parse_message_parts( + self, + message_parts: list, + depth: int = 0, + max_depth: int = 1, + ) -> tuple[list, list[str]]: + """解析消息段列表,返回消息组件列表和纯文本列表 + + Args: + message_parts: 消息段列表 + depth: 当前递归深度 + max_depth: 最大递归深度(用于处理 reply) + + Returns: + tuple[list, list[str]]: (消息组件列表, 纯文本列表) + """ + components = [] + text_parts = [] + + for part in message_parts: + part_type = part.get("type") + if part_type == "plain": + text = part.get("text", "") + components.append(Plain(text)) + text_parts.append(text) + elif part_type == "reply": + message_id = part.get("message_id") + reply_chain = [] + reply_message_str = "" + sender_id = None + sender_name = None + + # recursively get the content of the referenced message + if depth < max_depth and message_id: + history = await self._get_message_history(message_id) + if history and history.content: + reply_parts = history.content.get("message", []) + if isinstance(reply_parts, list): + ( + reply_chain, + reply_text_parts, + ) = await self._parse_message_parts( + reply_parts, + depth=depth + 1, + max_depth=max_depth, + ) + reply_message_str = "".join(reply_text_parts) + sender_id = history.sender_id + sender_name = history.sender_name + + components.append( + Reply( + id=message_id, + chain=reply_chain, + message_str=reply_message_str, + sender_id=sender_id, + sender_nickname=sender_name, + ) + ) + elif part_type == "image": + path = part.get("path") + if path: + components.append(Image.fromFileSystem(path)) + elif part_type == "record": + path = part.get("path") + if path: + components.append(Record.fromFileSystem(path)) + elif part_type == "file": + path = part.get("path") + if path: + filename = part.get("filename") or ( + os.path.basename(path) if path else "file" + ) + components.append(File(name=filename, file=path)) + elif part_type == "video": + path = part.get("path") + if path: + components.append(Video.fromFileSystem(path)) + + return components, text_parts + async def convert_message(self, data: tuple) -> AstrBotMessage: username, cid, payload = data @@ -108,40 +195,19 @@ class WebChatAdapter(Platform): abm.session_id = f"webchat!{username}!{cid}" abm.message_id = str(uuid.uuid4()) - abm.message = [] - if payload["message"]: - abm.message.append(Plain(payload["message"])) - if payload["image_url"]: - if isinstance(payload["image_url"], list): - for img in payload["image_url"]: - abm.message.append( - Image.fromFileSystem(os.path.join(self.imgs_dir, img)), - ) - else: - abm.message.append( - Image.fromFileSystem( - os.path.join(self.imgs_dir, payload["image_url"]), - ), - ) - if payload["audio_url"]: - if isinstance(payload["audio_url"], list): - for audio in payload["audio_url"]: - path = os.path.join(self.imgs_dir, audio) - abm.message.append(Record(file=path, path=path)) - else: - path = os.path.join(self.imgs_dir, payload["audio_url"]) - abm.message.append(Record(file=path, path=path)) + # 处理消息段列表 + message_parts = payload.get("message", []) + abm.message, message_str_parts = await self._parse_message_parts(message_parts) logger.debug(f"WebChatAdapter: {abm.message}") - message_str = payload["message"] abm.timestamp = int(time.time()) - abm.message_str = message_str + abm.message_str = "".join(message_str_parts) abm.raw_message = data return abm - def run(self) -> Awaitable[Any]: + def run(self) -> Coroutine[Any, Any, None]: async def callback(data: tuple): abm = await self.convert_message(data) await self.handle_msg(abm) diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index 4ced79b19..9f1a6d059 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -1,12 +1,12 @@ import base64 import os +import shutil import uuid from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.message_components import Image, Plain, Record +from astrbot.api.message_components import File, Image, Plain, Record from astrbot.core.utils.astrbot_path import get_astrbot_data_path -from astrbot.core.utils.io import download_image_by_url from .webchat_queue_mgr import webchat_queue_mgr @@ -19,7 +19,9 @@ class WebChatMessageEvent(AstrMessageEvent): os.makedirs(imgs_dir, exist_ok=True) @staticmethod - async def _send(message: MessageChain, session_id: str, streaming: bool = False): + async def _send( + message: MessageChain | None, session_id: str, streaming: bool = False + ) -> str | None: cid = session_id.split("!")[-1] web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid) if not message: @@ -30,7 +32,7 @@ class WebChatMessageEvent(AstrMessageEvent): "streaming": False, }, # end means this request is finished ) - return "" + return data = "" for comp in message.chain: @@ -47,24 +49,11 @@ class WebChatMessageEvent(AstrMessageEvent): ) elif isinstance(comp, Image): # save image to local - filename = str(uuid.uuid4()) + ".jpg" + filename = f"{str(uuid.uuid4())}.jpg" path = os.path.join(imgs_dir, filename) - if comp.file and comp.file.startswith("file:///"): - ph = comp.file[8:] - with open(path, "wb") as f: - with open(ph, "rb") as f2: - f.write(f2.read()) - elif comp.file.startswith("base64://"): - base64_str = comp.file[9:] - image_data = base64.b64decode(base64_str) - with open(path, "wb") as f: - f.write(image_data) - elif comp.file and comp.file.startswith("http"): - await download_image_by_url(comp.file, path=path) - else: - with open(path, "wb") as f: - with open(comp.file, "rb") as f2: - f.write(f2.read()) + image_base64 = await comp.convert_to_base64() + with open(path, "wb") as f: + f.write(base64.b64decode(image_base64)) data = f"[IMAGE]{filename}" await web_chat_back_queue.put( { @@ -76,19 +65,11 @@ class WebChatMessageEvent(AstrMessageEvent): ) elif isinstance(comp, Record): # save record to local - filename = str(uuid.uuid4()) + ".wav" + filename = f"{str(uuid.uuid4())}.wav" path = os.path.join(imgs_dir, filename) - if comp.file and comp.file.startswith("file:///"): - ph = comp.file[8:] - with open(path, "wb") as f: - with open(ph, "rb") as f2: - f.write(f2.read()) - elif comp.file and comp.file.startswith("http"): - await download_image_by_url(comp.file, path=path) - else: - with open(path, "wb") as f: - with open(comp.file, "rb") as f2: - f.write(f2.read()) + record_base64 = await comp.convert_to_base64() + with open(path, "wb") as f: + f.write(base64.b64decode(record_base64)) data = f"[RECORD]{filename}" await web_chat_back_queue.put( { @@ -98,14 +79,31 @@ class WebChatMessageEvent(AstrMessageEvent): "streaming": streaming, }, ) + elif isinstance(comp, File): + # save file to local + file_path = await comp.get_file() + original_name = comp.name or os.path.basename(file_path) + ext = os.path.splitext(original_name)[1] or "" + filename = f"{uuid.uuid4()!s}{ext}" + dest_path = os.path.join(imgs_dir, filename) + shutil.copy2(file_path, dest_path) + data = f"[FILE]{filename}|{original_name}" + await web_chat_back_queue.put( + { + "type": "file", + "cid": cid, + "data": data, + "streaming": streaming, + }, + ) else: logger.debug(f"webchat 忽略: {comp.type}") return data - async def send(self, message: MessageChain): + async def send(self, message: MessageChain | None): await WebChatMessageEvent._send(message, session_id=self.session_id) - await super().send(message) + await super().send(MessageChain([])) async def send_streaming(self, generator, use_fallback: bool = False): final_data = "" @@ -131,6 +129,8 @@ class WebChatMessageEvent(AstrMessageEvent): session_id=self.session_id, streaming=True, ) + if not r: + continue if chain.type == "reasoning": reasoning_content += chain.get_plain_text() else: diff --git a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py index e8629ec11..4c9a9d36b 100644 --- a/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py +++ b/astrbot/core/platform/sources/wechatpadpro/wechatpadpro_adapter.py @@ -4,6 +4,7 @@ import json import os import time import traceback +from typing import cast import aiohttp import anyio @@ -42,10 +43,9 @@ class WeChatPadProAdapter(Platform): platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(event_queue) + super().__init__(platform_config, event_queue) self._shutdown_event = None self.wxnewpass = None - self.config = platform_config self.settings = platform_settings self.unique_session = platform_settings.get("unique_session", False) @@ -70,7 +70,7 @@ class WeChatPadProAdapter(Platform): ) self.base_url = f"http://{self.host}:{self.port}" self.auth_key = None # 用于保存生成的授权码 - self.wxid = None # 用于保存登录成功后的 wxid + self.wxid: str | None = None # 用于保存登录成功后的 wxid self.credentials_file = os.path.join( get_astrbot_data_path(), "wechatpadpro_credentials.json", @@ -399,7 +399,7 @@ class WeChatPadProAdapter(Platform): ) await asyncio.sleep(5) - async def handle_websocket_message(self, message: str): + async def handle_websocket_message(self, message: str | bytes): """处理从 WebSocket 接收到的消息。""" logger.debug(f"收到 WebSocket 消息: {message}") try: @@ -431,10 +431,13 @@ class WeChatPadProAdapter(Platform): async def convert_message(self, raw_message: dict) -> AstrBotMessage | None: """将 WeChatPadPro 原始消息转换为 AstrBotMessage。""" + if self.wxid is None: + logger.error("WeChatPadPro 适配器未登录或未获取到 wxid,无法处理消息。") + return None abm = AstrBotMessage() abm.raw_message = raw_message abm.message_id = str(raw_message.get("msg_id")) - abm.timestamp = raw_message.get("create_time") + abm.timestamp = cast(int, raw_message.get("create_time")) abm.self_id = self.wxid if int(time.time()) - abm.timestamp > 180: @@ -447,7 +450,7 @@ class WeChatPadProAdapter(Platform): to_user_name = raw_message.get("to_user_name", {}).get("str", "") content = raw_message.get("content", {}).get("str", "") push_content = raw_message.get("push_content", "") - msg_type = raw_message.get("msg_type") + msg_type = cast(int, raw_message.get("msg_type")) abm.message_str = "" abm.message = [] @@ -575,7 +578,7 @@ class WeChatPadProAdapter(Platform): from_user_name: str, to_user_name: str, msg_id: int, - ): + ) -> dict | None: """下载原始图片。""" url = f"{self.base_url}/message/GetMsgBigImg" params = {"key": self.auth_key} @@ -726,12 +729,15 @@ class WeChatPadProAdapter(Platform): # 图片消息 from_user_name = raw_message.get("from_user_name", {}).get("str", "") to_user_name = raw_message.get("to_user_name", {}).get("str", "") - msg_id = raw_message.get("msg_id") + msg_id = cast(int, raw_message.get("msg_id")) image_resp = await self._download_raw_image( from_user_name, to_user_name, msg_id, ) + if image_resp is None: + logger.error(f"下载图片失败: msg_id={msg_id}") + return image_bs64_data = ( image_resp.get("Data", {}).get("Data", {}).get("Buffer", None) ) @@ -772,6 +778,9 @@ class WeChatPadProAdapter(Platform): bufid = 0 to_user_name = raw_message.get("to_user_name", {}).get("str", "") new_msg_id = raw_message.get("new_msg_id") + if new_msg_id is None: + logger.error("语音消息缺少 new_msg_id") + return data_parser = GeweDataParser( content=content, is_private_chat=(abm.type != MessageType.GROUP_MESSAGE), @@ -779,6 +788,9 @@ class WeChatPadProAdapter(Platform): ) voicemsg = data_parser._format_to_xml().find("voicemsg") + if voicemsg is None: + logger.error("无法从 XML 解析 voicemsg 节点") + return bufid = voicemsg.get("bufid") or "0" length = int(voicemsg.get("length") or 0) voice_resp = await self.download_voice( @@ -787,6 +799,9 @@ class WeChatPadProAdapter(Platform): bufid=bufid, length=length, ) + if voice_resp is None: + logger.error(f"下载语音失败: new_msg_id={new_msg_id}") + return voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None) if voice_bs64_data: voice_bs64_data = base64.b64decode(voice_bs64_data) @@ -828,7 +843,8 @@ class WeChatPadProAdapter(Platform): try: if self.ws_handle_task: self.ws_handle_task.cancel() - self._shutdown_event.set() + if self._shutdown_event is not None: + self._shutdown_event.set() except Exception: pass @@ -895,8 +911,8 @@ class WeChatPadProAdapter(Platform): async def get_contact_details_list( self, - room_wx_id_list: list[str] = None, - user_names: list[str] = None, + room_wx_id_list: list[str] | None = None, + user_names: list[str] | None = None, ) -> dict | None: """获取联系人详情列表。""" if room_wx_id_list is None: diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index 1ea4c8e20..8f3d091a4 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -2,6 +2,8 @@ import asyncio import os import sys import uuid +from collections.abc import Awaitable, Callable +from typing import Any, cast import quart from requests import Response @@ -24,6 +26,7 @@ from astrbot.api.platform import ( from astrbot.core import logger from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.utils.astrbot_path import get_astrbot_data_path +from astrbot.core.utils.webhook_utils import log_webhook_info from .wecom_event import WecomPlatformEvent from .wecom_kf import WeChatKF @@ -38,7 +41,7 @@ else: class WecomServer: def __init__(self, event_queue: asyncio.Queue, config: dict): self.server = quart.Quart(__name__) - self.port = int(config.get("port")) + self.port = int(cast(str, config.get("port"))) self.callback_server_host = config.get("callback_server_host", "0.0.0.0") self.server.add_url_rule( "/callback/command", @@ -58,12 +61,24 @@ class WecomServer: config["corpid"].strip(), ) - self.callback = None + self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None self.shutdown_event = asyncio.Event() async def verify(self): - logger.info(f"验证请求有效性: {quart.request.args}") - args = quart.request.args + """内部服务器的 GET 验证入口""" + return await self.handle_verify(quart.request) + + async def handle_verify(self, request) -> str: + """处理验证请求,可被统一 webhook 入口复用 + + Args: + request: Quart 请求对象 + + Returns: + 验证响应 + """ + logger.info(f"验证请求有效性: {request.args}") + args = request.args try: echo_str = self.crypto.check_signature( args.get("msg_signature"), @@ -78,17 +93,29 @@ class WecomServer: raise async def callback_command(self): - data = await quart.request.get_data() - msg_signature = quart.request.args.get("msg_signature") - timestamp = quart.request.args.get("timestamp") - nonce = quart.request.args.get("nonce") + """内部服务器的 POST 回调入口""" + return await self.handle_callback(quart.request) + + async def handle_callback(self, request) -> str: + """处理回调请求,可被统一 webhook 入口复用 + + Args: + request: Quart 请求对象 + + Returns: + 响应内容 + """ + data = await request.get_data() + msg_signature = request.args.get("msg_signature") + timestamp = request.args.get("timestamp") + nonce = request.args.get("nonce") try: xml = self.crypto.decrypt_message(data, msg_signature, timestamp, nonce) except InvalidSignatureException: logger.error("解密失败,签名异常,请检查配置。") raise else: - msg = parse_message(xml) + msg = cast(BaseMessage, parse_message(xml)) logger.info(f"解析成功: {msg}") if self.callback: @@ -118,14 +145,14 @@ class WecomPlatformAdapter(Platform): platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(event_queue) - self.config = platform_config + super().__init__(platform_config, event_queue) self.settingss = platform_settings self.client_self_id = uuid.uuid4().hex[:8] self.api_base_url = platform_config.get( "api_base_url", "https://qyapi.weixin.qq.com/cgi-bin/", ) + self.unified_webhook_mode = platform_config.get("unified_webhook_mode", False) if not self.api_base_url: self.api_base_url = "https://qyapi.weixin.qq.com/cgi-bin/" @@ -150,10 +177,10 @@ class WecomPlatformAdapter(Platform): # inject self.wechat_kf_api = WeChatKF(client=self.client) self.wechat_kf_message_api = WeChatKFMessage(self.client) - self.client.kf = self.wechat_kf_api - self.client.kf_message = self.wechat_kf_message_api + self.client.__setattr__("kf", self.wechat_kf_api) + self.client.__setattr__("kf_message", self.wechat_kf_message_api) - self.client.API_BASE_URL = self.api_base_url + self.client.__setattr__("API_BASE_URL", self.api_base_url) async def callback(msg: BaseMessage): if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event": @@ -232,41 +259,53 @@ class WecomPlatformAdapter(Platform): ) except Exception as e: logger.error(e) - await self.server.start_polling() + + # 如果启用统一 webhook 模式,则不启动独立服务器 + webhook_uuid = self.config.get("webhook_uuid") + if self.unified_webhook_mode and webhook_uuid: + log_webhook_info(f"{self.meta().id}(企业微信)", webhook_uuid) + # 保持运行状态,等待 shutdown + await self.server.shutdown_event.wait() + else: + await self.server.start_polling() + + async def webhook_callback(self, request: Any) -> Any: + """统一 Webhook 回调入口""" + # 根据请求方法分发到不同的处理函数 + if request.method == "GET": + return await self.server.handle_verify(request) + else: + return await self.server.handle_callback(request) async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None: abm = AstrBotMessage() - if msg.type == "text": - assert isinstance(msg, TextMessage) + if isinstance(msg, TextMessage): abm.message_str = msg.content abm.self_id = str(msg.agent) abm.message = [Plain(msg.content)] abm.type = MessageType.FRIEND_MESSAGE abm.sender = MessageMember( - msg.source, - msg.source, + cast(str, msg.source), + cast(str, msg.source), ) - abm.message_id = msg.id - abm.timestamp = msg.time + abm.message_id = str(msg.id) + abm.timestamp = int(cast(int | str, msg.time)) abm.session_id = abm.sender.user_id abm.raw_message = msg - elif msg.type == "image": - assert isinstance(msg, ImageMessage) + elif isinstance(msg, ImageMessage): abm.message_str = "[图片]" abm.self_id = str(msg.agent) abm.message = [Image(file=msg.image, url=msg.image)] abm.type = MessageType.FRIEND_MESSAGE abm.sender = MessageMember( - msg.source, - msg.source, + cast(str, msg.source), + cast(str, msg.source), ) - abm.message_id = msg.id - abm.timestamp = msg.time + abm.message_id = str(msg.id) + abm.timestamp = int(cast(int | str, msg.time)) abm.session_id = abm.sender.user_id abm.raw_message = msg - elif msg.type == "voice": - assert isinstance(msg, VoiceMessage) - + elif isinstance(msg, VoiceMessage): resp: Response = await asyncio.get_event_loop().run_in_executor( None, self.client.media.download, @@ -293,11 +332,11 @@ class WecomPlatformAdapter(Platform): abm.message = [Record(file=path_wav, url=path_wav)] abm.type = MessageType.FRIEND_MESSAGE abm.sender = MessageMember( - msg.source, - msg.source, + cast(str, msg.source), + cast(str, msg.source), ) - abm.message_id = msg.id - abm.timestamp = msg.time + abm.message_id = str(msg.id) + abm.timestamp = int(cast(int | str, msg.time)) abm.session_id = abm.sender.user_id abm.raw_message = msg else: @@ -309,7 +348,7 @@ class WecomPlatformAdapter(Platform): async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None: msgtype = msg.get("msgtype") - external_userid = msg.get("external_userid") + external_userid = cast(str, msg.get("external_userid")) abm = AstrBotMessage() abm.raw_message = msg abm.raw_message["_wechat_kf_flag"] = None # 方便处理 diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index ba9ad9a49..0b5dae272 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -16,7 +16,7 @@ try: import pydub except Exception: logger.warning( - "检测到 pydub 库未安装,企业微信将无法语音收发。如需使用语音,请前往管理面板 -> 控制台 -> 安装 Pip 库安装 pydub。", + "检测到 pydub 库未安装,企业微信将无法语音收发。如需使用语音,请前往管理面板 -> 平台日志 -> 安装 Pip 库安装 pydub。", ) @@ -93,10 +93,10 @@ class WecomPlatformEvent(AstrMessageEvent): if is_wechat_kf: # 微信客服 kf_message_api = getattr(self.client, "kf_message", None) - if not kf_message_api: + if not isinstance(kf_message_api, WeChatKFMessage): logger.warning("未找到微信客服发送消息方法。") return - assert isinstance(kf_message_api, WeChatKFMessage) + user_id = self.get_sender_id() for comp in message.chain: if isinstance(comp, Plain): diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py index 9c13cfeff..70581e7ea 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py @@ -22,6 +22,7 @@ from astrbot.api.platform import ( PlatformMetadata, ) from astrbot.core.platform.astr_message_event import MessageSesion +from astrbot.core.utils.webhook_utils import log_webhook_info from ...register import register_platform_adapter from .wecomai_api import ( @@ -103,9 +104,7 @@ class WecomAIBotAdapter(Platform): platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(event_queue) - - self.config = platform_config + super().__init__(platform_config, event_queue) self.settings = platform_settings # 初始化配置参数 @@ -122,6 +121,7 @@ class WecomAIBotAdapter(Platform): "wecomaibot_friend_message_welcome_text", "", ) + self.unified_webhook_mode = self.config.get("unified_webhook_mode", False) # 平台元数据 self.metadata = PlatformMetadata( @@ -425,17 +425,34 @@ class WecomAIBotAdapter(Platform): def run(self) -> Awaitable[Any]: """运行适配器,同时启动HTTP服务器和队列监听器""" - logger.info("启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port) async def run_both(): - # 同时运行HTTP服务器和队列监听器 - await asyncio.gather( - self.server.start_server(), - self.queue_listener.run(), - ) + # 如果启用统一 webhook 模式,则不启动独立服务器 + webhook_uuid = self.config.get("webhook_uuid") + if self.unified_webhook_mode and webhook_uuid: + log_webhook_info(f"{self.meta().id}(企业微信智能机器人)", webhook_uuid) + # 只运行队列监听器 + await self.queue_listener.run() + else: + logger.info( + "启动企业微信智能机器人适配器,监听 %s:%d", self.host, self.port + ) + # 同时运行HTTP服务器和队列监听器 + await asyncio.gather( + self.server.start_server(), + self.queue_listener.run(), + ) return run_both() + async def webhook_callback(self, request: Any) -> Any: + """统一 Webhook 回调入口""" + # 根据请求方法分发到不同的处理函数 + if request.method == "GET": + return await self.server.handle_verify(request) + else: + return await self.server.handle_callback(request) + async def terminate(self): """终止适配器""" logger.info("企业微信智能机器人适配器正在关闭...") diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py index 0091783a4..fd11d7ceb 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_event.py @@ -39,7 +39,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent): @staticmethod async def _send( - message_chain: MessageChain, + message_chain: MessageChain | None, stream_id: str, queue_mgr: WecomAIQueueMgr, streaming: bool = False, @@ -90,7 +90,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent): return data - async def send(self, message: MessageChain): + async def send(self, message: MessageChain | None): """发送消息""" raw = self.message_obj.raw_message assert isinstance(raw, dict), ( @@ -98,7 +98,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent): ) stream_id = raw.get("stream_id", self.session_id) await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr) - await super().send(message) + await super().send(MessageChain([])) async def send_streaming(self, generator, use_fallback=False): """流式发送消息,参考webchat的send_streaming设计""" diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py index 35acd9066..5cbdd1130 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py @@ -59,8 +59,19 @@ class WecomAIBotServer: ) async def verify_url(self): - """验证回调 URL""" - args = quart.request.args + """内部服务器的 GET 验证入口""" + return await self.handle_verify(quart.request) + + async def handle_verify(self, request): + """处理 URL 验证请求,可被统一 webhook 入口复用 + + Args: + request: Quart 请求对象 + + Returns: + 验证响应元组 (content, status_code, headers) + """ + args = request.args msg_signature = args.get("msg_signature") timestamp = args.get("timestamp") nonce = args.get("nonce") @@ -81,8 +92,19 @@ class WecomAIBotServer: return result, 200, {"Content-Type": "text/plain"} async def handle_message(self): - """处理消息回调""" - args = quart.request.args + """内部服务器的 POST 消息回调入口""" + return await self.handle_callback(quart.request) + + async def handle_callback(self, request): + """处理消息回调,可被统一 webhook 入口复用 + + Args: + request: Quart 请求对象 + + Returns: + 响应元组 (content, status_code, headers) + """ + args = request.args msg_signature = args.get("msg_signature") timestamp = args.get("timestamp") nonce = args.get("nonce") @@ -102,7 +124,7 @@ class WecomAIBotServer: try: # 获取请求体 - post_data = await quart.request.get_data() + post_data = await request.get_data() # 确保 post_data 是 bytes 类型 if isinstance(post_data, str): diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py index d1309374f..d0304a48e 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py @@ -1,6 +1,8 @@ import asyncio import sys import uuid +from collections.abc import Awaitable, Callable +from typing import Any, cast import quart from requests import Response @@ -22,6 +24,7 @@ from astrbot.api.platform import ( ) from astrbot.core import logger from astrbot.core.platform.astr_message_event import MessageSesion +from astrbot.core.utils.webhook_utils import log_webhook_info from .weixin_offacc_event import WeixinOfficialAccountPlatformEvent @@ -31,10 +34,10 @@ else: from typing_extensions import override -class WecomServer: +class WeixinOfficialAccountServer: def __init__(self, event_queue: asyncio.Queue, config: dict): self.server = quart.Quart(__name__) - self.port = int(config.get("port")) + self.port = int(cast(int | str, config.get("port"))) self.callback_server_host = config.get("callback_server_host", "0.0.0.0") self.token = config.get("token") self.encoding_aes_key = config.get("encoding_aes_key") @@ -53,13 +56,25 @@ class WecomServer: self.event_queue = event_queue - self.callback = None + self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None self.shutdown_event = asyncio.Event() async def verify(self): - logger.info(f"验证请求有效性: {quart.request.args}") + """内部服务器的 GET 验证入口""" + return await self.handle_verify(quart.request) - args = quart.request.args + async def handle_verify(self, request) -> str: + """处理验证请求,可被统一 webhook 入口复用 + + Args: + request: Quart 请求对象 + + Returns: + 验证响应 + """ + logger.info(f"验证请求有效性: {request.args}") + + args = request.args if not args.get("signature", None): logger.error("未知的响应,请检查回调地址是否填写正确。") return "err" @@ -77,10 +92,22 @@ class WecomServer: return "err" async def callback_command(self): - data = await quart.request.get_data() - msg_signature = quart.request.args.get("msg_signature") - timestamp = quart.request.args.get("timestamp") - nonce = quart.request.args.get("nonce") + """内部服务器的 POST 回调入口""" + return await self.handle_callback(quart.request) + + async def handle_callback(self, request) -> str: + """处理回调请求,可被统一 webhook 入口复用 + + Args: + request: Quart 请求对象 + + Returns: + 响应内容 + """ + data = await request.get_data() + msg_signature = request.args.get("msg_signature") + timestamp = request.args.get("timestamp") + nonce = request.args.get("nonce") try: xml = self.crypto.decrypt_message(data, msg_signature, timestamp, nonce) except InvalidSignatureException: @@ -88,6 +115,9 @@ class WecomServer: raise else: msg = parse_message(xml) + if not msg: + logger.error("解析失败。msg为None。") + raise logger.info(f"解析成功: {msg}") if self.callback: @@ -123,8 +153,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform): platform_settings: dict, event_queue: asyncio.Queue, ) -> None: - super().__init__(event_queue) - self.config = platform_config + super().__init__(platform_config, event_queue) self.settingss = platform_settings self.client_self_id = uuid.uuid4().hex[:8] self.api_base_url = platform_config.get( @@ -132,6 +161,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform): "https://api.weixin.qq.com/cgi-bin/", ) self.active_send_mode = self.config.get("active_send_mode", False) + self.unified_webhook_mode = platform_config.get("unified_webhook_mode", False) if not self.api_base_url: self.api_base_url = "https://api.weixin.qq.com/cgi-bin/" @@ -143,14 +173,14 @@ class WeixinOfficialAccountPlatformAdapter(Platform): if not self.api_base_url.endswith("/"): self.api_base_url += "/" - self.server = WecomServer(self._event_queue, self.config) + self.server = WeixinOfficialAccountServer(self._event_queue, self.config) self.client = WeChatClient( self.config["appid"].strip(), self.config["secret"].strip(), ) - self.client.API_BASE_URL = self.api_base_url + self.client.__setattr__("API_BASE_URL", self.api_base_url) # 微信公众号必须 5 秒内进行回复,否则会重试 3 次,我们需要对其进行消息排重 # msgid -> Future @@ -162,11 +192,11 @@ class WeixinOfficialAccountPlatformAdapter(Platform): await self.convert_message(msg, None) else: if msg.id in self.wexin_event_workers: - future = self.wexin_event_workers[msg.id] + future = self.wexin_event_workers[str(cast(str | int, msg.id))] logger.debug(f"duplicate message id checked: {msg.id}") else: future = asyncio.get_event_loop().create_future() - self.wexin_event_workers[msg.id] = future + self.wexin_event_workers[str(cast(str | int, msg.id))] = future await self.convert_message(msg, future) # I love shield so much! result = await asyncio.wait_for( @@ -174,7 +204,7 @@ class WeixinOfficialAccountPlatformAdapter(Platform): 60, ) # wait for 60s logger.debug(f"Got future result: {result}") - self.wexin_event_workers.pop(msg.id, None) + self.wexin_event_workers.pop(str(cast(str | int, msg.id)), None) return result # xml. see weixin_offacc_event.py except asyncio.TimeoutError: pass @@ -202,38 +232,53 @@ class WeixinOfficialAccountPlatformAdapter(Platform): @override async def run(self): - await self.server.start_polling() + # 如果启用统一 webhook 模式,则不启动独立服务器 + webhook_uuid = self.config.get("webhook_uuid") + if self.unified_webhook_mode and webhook_uuid: + log_webhook_info(f"{self.meta().id}(微信公众平台)", webhook_uuid) + # 保持运行状态,等待 shutdown + await self.server.shutdown_event.wait() + else: + await self.server.start_polling() + + async def webhook_callback(self, request: Any) -> Any: + """统一 Webhook 回调入口""" + # 根据请求方法分发到不同的处理函数 + if request.method == "GET": + return await self.server.handle_verify(request) + else: + return await self.server.handle_callback(request) async def convert_message( self, msg, - future: asyncio.Future = None, + future: asyncio.Future | None = None, ) -> AstrBotMessage | None: abm = AstrBotMessage() if isinstance(msg, TextMessage): - abm.message_str = msg.content + abm.message_str = cast(str, msg.content) abm.self_id = str(msg.target) - abm.message = [Plain(msg.content)] + abm.message = [Plain(cast(str, msg.content))] abm.type = MessageType.FRIEND_MESSAGE abm.sender = MessageMember( - msg.source, - msg.source, + cast(str, msg.source), + cast(str, msg.source), ) - abm.message_id = msg.id - abm.timestamp = msg.time + abm.message_id = str(cast(str | int, msg.id)) + abm.timestamp = cast(int, msg.time) abm.session_id = abm.sender.user_id elif msg.type == "image": assert isinstance(msg, ImageMessage) abm.message_str = "[图片]" abm.self_id = str(msg.target) - abm.message = [Image(file=msg.image, url=msg.image)] + abm.message = [Image(file=cast(str, msg.image), url=cast(str, msg.image))] abm.type = MessageType.FRIEND_MESSAGE abm.sender = MessageMember( - msg.source, - msg.source, + cast(str, msg.source), + cast(str, msg.source), ) - abm.message_id = msg.id - abm.timestamp = msg.time + abm.message_id = str(cast(str | int, msg.id)) + abm.timestamp = cast(int, msg.time) abm.session_id = abm.sender.user_id elif msg.type == "voice": assert isinstance(msg, VoiceMessage) @@ -265,15 +310,16 @@ class WeixinOfficialAccountPlatformAdapter(Platform): abm.message = [Record(file=path_wav, url=path_wav)] abm.type = MessageType.FRIEND_MESSAGE abm.sender = MessageMember( - msg.source, - msg.source, + cast(str, msg.source), + cast(str, msg.source), ) - abm.message_id = msg.id - abm.timestamp = msg.time + abm.message_id = str(cast(str | int, msg.id)) + abm.timestamp = cast(int, msg.time) abm.session_id = abm.sender.user_id else: logger.warning(f"暂未实现的事件: {msg.type}") - future.set_result(None) + if future: + future.set_result(None) return # 很不优雅 :( abm.raw_message = { diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py index d138fc80c..c1f137a41 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_event.py @@ -1,5 +1,6 @@ import asyncio import uuid +from typing import cast from wechatpy import WeChatClient from wechatpy.replies import ImageReply, TextReply, VoiceReply @@ -13,7 +14,7 @@ try: import pydub except Exception: logger.warning( - "检测到 pydub 库未安装,微信公众平台将无法语音收发。如需使用语音,请前往管理面板 -> 控制台 -> 安装 Pip 库安装 pydub。", + "检测到 pydub 库未安装,微信公众平台将无法语音收发。如需使用语音,请前往管理面板 -> 平台日志 -> 安装 Pip 库安装 pydub。", ) @@ -85,7 +86,9 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent): async def send(self, message: MessageChain): message_obj = self.message_obj - active_send_mode = message_obj.raw_message.get("active_send_mode", False) + active_send_mode = cast(dict, message_obj.raw_message).get( + "active_send_mode", False + ) for comp in message.chain: if isinstance(comp, Plain): # Split long text messages if needed @@ -96,10 +99,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent): else: reply = TextReply( content=chunk, - message=self.message_obj.raw_message["message"], + message=cast(dict, self.message_obj.raw_message)["message"], ) xml = reply.render() - future = self.message_obj.raw_message["future"] + future = cast(dict, self.message_obj.raw_message)["future"] assert isinstance(future, asyncio.Future) future.set_result(xml) await asyncio.sleep(0.5) # Avoid sending too fast @@ -125,10 +128,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent): else: reply = ImageReply( media_id=response["media_id"], - message=self.message_obj.raw_message["message"], + message=cast(dict, self.message_obj.raw_message)["message"], ) xml = reply.render() - future = self.message_obj.raw_message["future"] + future = cast(dict, self.message_obj.raw_message)["future"] assert isinstance(future, asyncio.Future) future.set_result(xml) @@ -160,10 +163,10 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent): else: reply = VoiceReply( media_id=response["media_id"], - message=self.message_obj.raw_message["message"], + message=cast(dict, self.message_obj.raw_message)["message"], ) xml = reply.render() - future = self.message_obj.raw_message["future"] + future = cast(dict, self.message_obj.raw_message)["future"] assert isinstance(future, asyncio.Future) future.set_result(xml) diff --git a/astrbot/core/platform_message_history_mgr.py b/astrbot/core/platform_message_history_mgr.py index 0e079e893..d6d524698 100644 --- a/astrbot/core/platform_message_history_mgr.py +++ b/astrbot/core/platform_message_history_mgr.py @@ -10,12 +10,12 @@ class PlatformMessageHistoryManager: self, platform_id: str, user_id: str, - content: list[dict], # TODO: parse from message chain + content: dict, # TODO: parse from message chain sender_id: str | None = None, sender_name: str | None = None, - ): + ) -> PlatformMessageHistory: """Insert a new platform message history record.""" - await self.db.insert_platform_message_history( + return await self.db.insert_platform_message_history( platform_id=platform_id, user_id=user_id, content=content, diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 8e04423ed..7aad86bdd 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -4,7 +4,7 @@ import asyncio import copy import json import os -from collections.abc import Awaitable, Callable +from collections.abc import AsyncGenerator, Awaitable, Callable from typing import Any import aiohttp @@ -118,7 +118,7 @@ class FunctionToolManager: name: str, func_args: list[dict], desc: str, - handler: Callable[..., Awaitable[Any]], + handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]], ) -> FuncTool: params = { "type": "object", # hard-coded here @@ -140,7 +140,7 @@ class FunctionToolManager: name: str, func_args: list, desc: str, - handler: Callable[..., Awaitable[Any]], + handler: Callable[..., Awaitable[Any] | AsyncGenerator[Any]], ) -> None: """添加函数调用工具 diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 3e477255a..be8edc282 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -1,5 +1,6 @@ import asyncio import traceback +from typing import Protocol, runtime_checkable from astrbot.core import astrbot_config, logger, sp from astrbot.core.astrbot_config_mgr import AstrBotConfigManager @@ -10,6 +11,7 @@ from .entities import ProviderType from .provider import ( EmbeddingProvider, Provider, + Providers, RerankProvider, STTProvider, TTSProvider, @@ -17,6 +19,11 @@ from .provider import ( from .register import llm_tools, provider_cls_map +@runtime_checkable +class HasInitialize(Protocol): + async def initialize(self) -> None: ... + + class ProviderManager: def __init__( self, @@ -48,7 +55,7 @@ class ProviderManager: """加载的 Rerank Provider 的实例""" self.inst_map: dict[ str, - Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider, + Providers, ] = {} """Provider 实例映射. key: provider_id, value: Provider 实例""" self.llm_tools = llm_tools @@ -123,15 +130,13 @@ class ProviderManager: self.curr_provider_inst = prov sp.put("curr_provider", provider_id, scope="global", scope_id="global") - async def get_provider_by_id(self, provider_id: str) -> Provider | None: + async def get_provider_by_id(self, provider_id: str) -> Providers | None: """根据提供商 ID 获取提供商实例""" return self.inst_map.get(provider_id) def get_using_provider( - self, - provider_type: ProviderType, - umo=None, - ) -> Provider | STTProvider | TTSProvider | None: + self, provider_type: ProviderType, umo=None + ) -> Providers | None: """获取正在使用的提供商实例。 Args: @@ -191,7 +196,6 @@ class ProviderManager: logger.error(traceback.format_exc()) logger.error(e) - # 设置默认提供商 selected_provider_id = sp.get( "curr_provider", self.provider_settings.get("default_provider_id"), @@ -210,15 +214,37 @@ class ProviderManager: scope="global", scope_id="global", ) - self.curr_provider_inst = self.inst_map.get(selected_provider_id) + + temp_provider = ( + self.inst_map.get(selected_provider_id) + if isinstance(selected_provider_id, str) + else None + ) + self.curr_provider_inst = ( + temp_provider if isinstance(temp_provider, Provider) else None + ) if not self.curr_provider_inst and self.provider_insts: self.curr_provider_inst = self.provider_insts[0] - self.curr_stt_provider_inst = self.inst_map.get(selected_stt_provider_id) + temp_stt = ( + self.inst_map.get(selected_stt_provider_id) + if isinstance(selected_stt_provider_id, str) + else None + ) + self.curr_stt_provider_inst = ( + temp_stt if isinstance(temp_stt, STTProvider) else None + ) if not self.curr_stt_provider_inst and self.stt_provider_insts: self.curr_stt_provider_inst = self.stt_provider_insts[0] - self.curr_tts_provider_inst = self.inst_map.get(selected_tts_provider_id) + temp_tts = ( + self.inst_map.get(selected_tts_provider_id) + if isinstance(selected_tts_provider_id, str) + else None + ) + self.curr_tts_provider_inst = ( + temp_tts if isinstance(temp_tts, TTSProvider) else None + ) if not self.curr_tts_provider_inst and self.tts_provider_insts: self.curr_tts_provider_inst = self.tts_provider_insts[0] @@ -358,73 +384,103 @@ class ProviderManager: provider_metadata.id = provider_config["id"] - if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT: - # STT 任务 - inst = cls_type(provider_config, self.provider_settings) + match provider_metadata.provider_type: + case ProviderType.SPEECH_TO_TEXT: + # STT 任务 + if not issubclass(cls_type, STTProvider): + raise TypeError( + f"Provider class {cls_type} is not a subclass of STTProvider" + ) + inst = cls_type(provider_config, self.provider_settings) - if getattr(inst, "initialize", None): - await inst.initialize() + if isinstance(inst, HasInitialize): + await inst.initialize() - self.stt_provider_insts.append(inst) - if ( - self.provider_stt_settings.get("provider_id") - == provider_config["id"] - ): - self.curr_stt_provider_inst = inst - logger.info( - f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。", + self.stt_provider_insts.append(inst) + if ( + self.provider_stt_settings.get("provider_id") + == provider_config["id"] + ): + self.curr_stt_provider_inst = inst + logger.info( + f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。", + ) + if not self.curr_stt_provider_inst: + self.curr_stt_provider_inst = inst + + case ProviderType.TEXT_TO_SPEECH: + # TTS 任务 + if not issubclass(cls_type, TTSProvider): + raise TypeError( + f"Provider class {cls_type} is not a subclass of TTSProvider" + ) + inst = cls_type(provider_config, self.provider_settings) + + if isinstance(inst, HasInitialize): + await inst.initialize() + + self.tts_provider_insts.append(inst) + if ( + self.provider_settings.get("provider_id") + == provider_config["id"] + ): + self.curr_tts_provider_inst = inst + logger.info( + f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。", + ) + if not self.curr_tts_provider_inst: + self.curr_tts_provider_inst = inst + + case ProviderType.CHAT_COMPLETION: + # 文本生成任务 + if not issubclass(cls_type, Provider): + raise TypeError( + f"Provider class {cls_type} is not a subclass of Provider" + ) + inst = cls_type( + provider_config, + self.provider_settings, ) - if not self.curr_stt_provider_inst: - self.curr_stt_provider_inst = inst - elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH: - # TTS 任务 - inst = cls_type(provider_config, self.provider_settings) + if isinstance(inst, HasInitialize): + await inst.initialize() - if getattr(inst, "initialize", None): - await inst.initialize() + self.provider_insts.append(inst) + if ( + self.provider_settings.get("default_provider_id") + == provider_config["id"] + ): + self.curr_provider_inst = inst + logger.info( + f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。", + ) + if not self.curr_provider_inst: + self.curr_provider_inst = inst - self.tts_provider_insts.append(inst) - if self.provider_settings.get("provider_id") == provider_config["id"]: - self.curr_tts_provider_inst = inst - logger.info( - f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。", + case ProviderType.EMBEDDING: + if not issubclass(cls_type, EmbeddingProvider): + raise TypeError( + f"Provider class {cls_type} is not a subclass of EmbeddingProvider" + ) + inst = cls_type(provider_config, self.provider_settings) + if isinstance(inst, HasInitialize): + await inst.initialize() + self.embedding_provider_insts.append(inst) + case ProviderType.RERANK: + if not issubclass(cls_type, RerankProvider): + raise TypeError( + f"Provider class {cls_type} is not a subclass of RerankProvider" + ) + inst = cls_type(provider_config, self.provider_settings) + if isinstance(inst, HasInitialize): + await inst.initialize() + self.rerank_provider_insts.append(inst) + case _: + # 未知供应商抛出异常,确保inst初始化 + # Should be unreachable + raise Exception( + f"未知的提供商类型:{provider_metadata.provider_type}" ) - if not self.curr_tts_provider_inst: - self.curr_tts_provider_inst = inst - - elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION: - # 文本生成任务 - inst = cls_type( - provider_config, - self.provider_settings, - ) - - if getattr(inst, "initialize", None): - await inst.initialize() - - self.provider_insts.append(inst) - if ( - self.provider_settings.get("default_provider_id") - == provider_config["id"] - ): - self.curr_provider_inst = inst - logger.info( - f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。", - ) - if not self.curr_provider_inst: - self.curr_provider_inst = inst - - elif provider_metadata.provider_type == ProviderType.EMBEDDING: - inst = cls_type(provider_config, self.provider_settings) - if getattr(inst, "initialize", None): - await inst.initialize() - self.embedding_provider_insts.append(inst) - elif provider_metadata.provider_type == ProviderType.RERANK: - inst = cls_type(provider_config, self.provider_settings) - if getattr(inst, "initialize", None): - await inst.initialize() - self.rerank_provider_insts.append(inst) self.inst_map[provider_config["id"]] = inst except Exception as e: diff --git a/astrbot/core/provider/provider.py b/astrbot/core/provider/provider.py index 2b5057e85..7f21a2ee1 100644 --- a/astrbot/core/provider/provider.py +++ b/astrbot/core/provider/provider.py @@ -2,6 +2,7 @@ import abc import asyncio import os from collections.abc import AsyncGenerator +from typing import TypeAlias, Union from astrbot.core.agent.message import Message from astrbot.core.agent.tool import ToolSet @@ -14,6 +15,14 @@ from astrbot.core.provider.entities import ( from astrbot.core.provider.register import provider_cls_map from astrbot.core.utils.astrbot_path import get_astrbot_path +Providers: TypeAlias = Union[ + "Provider", + "STTProvider", + "TTSProvider", + "EmbeddingProvider", + "RerankProvider", +] + class AbstractProvider(abc.ABC): """Provider Abstract Class""" @@ -142,7 +151,9 @@ class Provider(AbstractProvider): - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。 """ - ... + if False: # pragma: no cover - make this an async generator for typing + yield None # type: ignore + raise NotImplementedError() async def pop_record(self, context: list): """弹出 context 第一条非系统提示词对话记录""" diff --git a/astrbot/core/provider/sources/azure_tts_source.py b/astrbot/core/provider/sources/azure_tts_source.py index e85d91793..2ccf146ca 100644 --- a/astrbot/core/provider/sources/azure_tts_source.py +++ b/astrbot/core/provider/sources/azure_tts_source.py @@ -29,15 +29,24 @@ class OTTSProvider: self.last_sync_time = 0 self.timeout = Timeout(10.0) self.retry_count = 3 - self.client = None + self._client: AsyncClient | None = None + + @property + def client(self) -> AsyncClient: + if self._client is None: + raise RuntimeError( + "Client not initialized. Please use 'async with' context." + ) + return self._client async def __aenter__(self): - self.client = AsyncClient(timeout=self.timeout) + self._client = AsyncClient(timeout=self.timeout) return self async def __aexit__(self, exc_type, exc_val, exc_tb): - if self.client: - await self.client.aclose() + if self._client: + await self._client.aclose() + self._client = None async def _sync_time(self): try: @@ -90,6 +99,7 @@ class OTTSProvider: if attempt == self.retry_count - 1: raise RuntimeError(f"OTTS请求失败: {e!s}") from e await asyncio.sleep(0.5 * (attempt + 1)) + raise RuntimeError("OTTS未返回音频文件") class AzureNativeProvider(TTSProvider): @@ -105,7 +115,7 @@ class AzureNativeProvider(TTSProvider): self.endpoint = ( f"https://{self.region}.tts.speech.microsoft.com/cognitiveservices/v1" ) - self.client = None + self._client: AsyncClient | None = None self.token = None self.token_expire = 0 self.voice_params = { @@ -116,8 +126,16 @@ class AzureNativeProvider(TTSProvider): "volume": provider_config.get("azure_tts_volume", "100"), } + @property + def client(self) -> AsyncClient: + if self._client is None: + raise RuntimeError( + "Client not initialized. Please use 'async with' context." + ) + return self._client + async def __aenter__(self): - self.client = AsyncClient( + self._client = AsyncClient( headers={ "User-Agent": f"AstrBot/{VERSION}", "Content-Type": "application/ssml+xml", @@ -127,8 +145,9 @@ class AzureNativeProvider(TTSProvider): return self async def __aexit__(self, exc_type, exc_val, exc_tb): - if self.client: - await self.client.aclose() + if self._client: + await self._client.aclose() + self._client = None async def _refresh_token(self): token_url = ( @@ -181,8 +200,11 @@ class AzureTTSProvider(TTSProvider): key_value = provider_config.get("azure_tts_subscription_key", "") self.provider = self._parse_provider(key_value, provider_config) - def _parse_provider(self, key_value: str, config: dict) -> TTSProvider: + def _parse_provider( + self, key_value: str, config: dict + ) -> OTTSProvider | AzureNativeProvider: if key_value.lower().startswith("other["): + json_str = "" try: match = re.match(r"other\[(.*)\]", key_value, re.DOTALL) if not match: diff --git a/astrbot/core/provider/sources/bailian_rerank_source.py b/astrbot/core/provider/sources/bailian_rerank_source.py index e6f6f1a4d..9e079d4a9 100644 --- a/astrbot/core/provider/sources/bailian_rerank_source.py +++ b/astrbot/core/provider/sources/bailian_rerank_source.py @@ -177,6 +177,10 @@ class BailianRerankProvider(RerankProvider): Returns: 重排序结果列表 """ + if not self.client: + logger.error("百炼 Rerank 客户端会话已关闭,返回空结果") + return [] + if not documents: logger.warning("文档列表为空,返回空结果") return [] diff --git a/astrbot/core/provider/sources/dashscope_tts.py b/astrbot/core/provider/sources/dashscope_tts.py index 44e9965cc..50bc421fd 100644 --- a/astrbot/core/provider/sources/dashscope_tts.py +++ b/astrbot/core/provider/sources/dashscope_tts.py @@ -36,7 +36,7 @@ class ProviderDashscopeTTSAPI(TTSProvider): super().__init__(provider_config, provider_settings) self.chosen_api_key: str = provider_config.get("api_key", "") self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella") - self.set_model(provider_config.get("model")) + self.set_model(provider_config["model"]) self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000 dashscope.api_key = self.chosen_api_key @@ -71,9 +71,10 @@ class ProviderDashscopeTTSAPI(TTSProvider): kwargs = { "model": model, - "text": text, + "messages": None, "api_key": self.chosen_api_key, "voice": self.voice or "Cherry", + "text": text, } if not self.voice: logging.warning( diff --git a/astrbot/core/provider/sources/edge_tts_source.py b/astrbot/core/provider/sources/edge_tts_source.py index 8bbf62325..71a5a82d6 100644 --- a/astrbot/core/provider/sources/edge_tts_source.py +++ b/astrbot/core/provider/sources/edge_tts_source.py @@ -67,7 +67,7 @@ class ProviderEdgeTTS(TTSProvider): from pyffmpeg import FFmpeg ff = FFmpeg() - ff.convert(input=mp3_path, output=wav_path) + ff.convert(input_file=mp3_path, output_file=wav_path) except Exception as e: logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换") # use ffmpeg command line diff --git a/astrbot/core/provider/sources/fishaudio_tts_api_source.py b/astrbot/core/provider/sources/fishaudio_tts_api_source.py index ca571c3ee..8362ce1b4 100644 --- a/astrbot/core/provider/sources/fishaudio_tts_api_source.py +++ b/astrbot/core/provider/sources/fishaudio_tts_api_source.py @@ -59,9 +59,9 @@ class ProviderFishAudioTTSAPI(TTSProvider): self.headers = { "Authorization": f"Bearer {self.chosen_api_key}", } - self.set_model(provider_config.get("model")) + self.set_model(provider_config["model"]) - async def _get_reference_id_by_character(self, character: str) -> str: + async def _get_reference_id_by_character(self, character: str) -> str | None: """获取角色的reference_id Args: @@ -109,7 +109,7 @@ class ProviderFishAudioTTSAPI(TTSProvider): pattern = r"^[a-fA-F0-9]{32}$" return bool(re.match(pattern, reference_id.strip())) - async def _generate_request(self, text: str) -> dict: + async def _generate_request(self, text: str) -> ServeTTSRequest: # 向前兼容逻辑:优先使用reference_id,如果没有则使用角色名称查询 if self.reference_id and self.reference_id.strip(): # 验证reference_id格式 @@ -146,5 +146,6 @@ class ProviderFishAudioTTSAPI(TTSProvider): async for chunk in response.aiter_bytes(): f.write(chunk) return path - text = await response.aread() + body = await response.aread() + text = body.decode("utf-8", errors="replace") raise Exception(f"Fish Audio API请求失败: {text}") diff --git a/astrbot/core/provider/sources/gemini_embedding_source.py b/astrbot/core/provider/sources/gemini_embedding_source.py index 8d11cce5f..146b50a4e 100644 --- a/astrbot/core/provider/sources/gemini_embedding_source.py +++ b/astrbot/core/provider/sources/gemini_embedding_source.py @@ -1,3 +1,5 @@ +from typing import cast + from google import genai from google.genai import types from google.genai.errors import APIError @@ -18,8 +20,8 @@ class GeminiEmbeddingProvider(EmbeddingProvider): self.provider_config = provider_config self.provider_settings = provider_settings - api_key: str = provider_config.get("embedding_api_key") - api_base: str = provider_config.get("embedding_api_base") + api_key: str = provider_config["embedding_api_key"] + api_base: str = provider_config["embedding_api_base"] timeout: int = int(provider_config.get("timeout", 20)) http_options = types.HttpOptions(timeout=timeout * 1000) @@ -41,18 +43,26 @@ class GeminiEmbeddingProvider(EmbeddingProvider): model=self.model, contents=text, ) + assert result.embeddings is not None + assert result.embeddings[0].values is not None return result.embeddings[0].values except APIError as e: raise Exception(f"Gemini Embedding API请求失败: {e.message}") - async def get_embeddings(self, texts: list[str]) -> list[list[float]]: + async def get_embeddings(self, text: list[str]) -> list[list[float]]: """批量获取文本的嵌入""" try: result = await self.client.models.embed_content( model=self.model, - contents=texts, + contents=cast(types.ContentListUnion, text), ) - return [embedding.values for embedding in result.embeddings] + assert result.embeddings is not None + + embeddings: list[list[float]] = [] + for embedding in result.embeddings: + assert embedding.values is not None + embeddings.append(embedding.values) + return embeddings except APIError as e: raise Exception(f"Gemini Embedding API批量请求失败: {e.message}") diff --git a/astrbot/core/provider/sources/gemini_source.py b/astrbot/core/provider/sources/gemini_source.py index 3bc6c67cc..e2efc6aab 100644 --- a/astrbot/core/provider/sources/gemini_source.py +++ b/astrbot/core/provider/sources/gemini_source.py @@ -4,6 +4,7 @@ import json import logging import random from collections.abc import AsyncGenerator +from typing import cast from google import genai from google.genai import types @@ -126,17 +127,17 @@ class ProviderGoogleGenAI(Provider): ) -> types.GenerateContentConfig: """准备查询配置""" if not modalities: - modalities = ["Text"] + modalities = ["TEXT"] # 流式输出不支持图片模态 if ( self.provider_settings.get("streaming_response", False) - and "Image" in modalities + and "IMAGE" in modalities ): logger.warning("流式输出不支持图片模态,已自动降级为文本模态") - modalities = ["Text"] + modalities = ["TEXT"] - tool_list = [] + tool_list: list[types.Tool] | None = [] model_name = self.get_model() native_coderunner = self.provider_config.get("gm_native_coderunner", False) native_search = self.provider_config.get("gm_native_search", False) @@ -213,7 +214,7 @@ class ProviderGoogleGenAI(Provider): logprobs=payloads.get("logprobs"), seed=payloads.get("seed"), response_modalities=modalities, - tools=tool_list, + tools=cast(types.ToolListUnion | None, tool_list), safety_settings=self.safety_settings if self.safety_settings else None, thinking_config=( types.ThinkingConfig( @@ -257,6 +258,7 @@ class ProviderGoogleGenAI(Provider): content_cls: type[types.Content], ) -> None: if contents and isinstance(contents[-1], content_cls): + assert contents[-1].parts is not None contents[-1].parts.extend(part) else: contents.append(content_cls(parts=part)) @@ -429,9 +431,9 @@ class ProviderGoogleGenAI(Provider): None, ) - modalities = ["Text"] + modalities = ["TEXT"] if self.provider_config.get("gm_resp_image_modal", False): - modalities.append("Image") + modalities.append("IMAGE") conversation = self._prepare_conversation(payloads) temperature = payloads.get("temperature", 0.7) @@ -448,7 +450,7 @@ class ProviderGoogleGenAI(Provider): ) result = await self.client.models.generate_content( model=self.get_model(), - contents=conversation, + contents=cast(types.ContentListUnion, conversation), config=config, ) logger.debug(f"genai result: {result}") @@ -488,7 +490,7 @@ class ProviderGoogleGenAI(Provider): logger.warning( f"{self.get_model()} 不支持多模态输出,降级为文本模态", ) - modalities = ["Text"] + modalities = ["TEXT"] else: raise continue @@ -524,7 +526,7 @@ class ProviderGoogleGenAI(Provider): ) result = await self.client.models.generate_content_stream( model=self.get_model(), - contents=conversation, + contents=cast(types.ContentListUnion, conversation), config=config, ) break diff --git a/astrbot/core/provider/sources/minimax_tts_api_source.py b/astrbot/core/provider/sources/minimax_tts_api_source.py index 5ffc7cc63..9e2d665c7 100644 --- a/astrbot/core/provider/sources/minimax_tts_api_source.py +++ b/astrbot/core/provider/sources/minimax_tts_api_source.py @@ -87,7 +87,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider): return json.dumps(dict_body) - async def _call_tts_stream(self, text: str) -> AsyncIterator[bytes]: + async def _call_tts_stream(self, text: str) -> AsyncIterator[str]: """进行流式请求""" try: async with ( @@ -117,7 +117,9 @@ class ProviderMiniMaxTTSAPI(TTSProvider): data = json.loads(message[6:]) if "extra_info" in data: continue - audio = data.get("data", {}).get("audio") + audio: str | None = data.get("data", {}).get( + "audio" + ) if audio is not None: yield audio except json.JSONDecodeError: diff --git a/astrbot/core/provider/sources/openai_embedding_source.py b/astrbot/core/provider/sources/openai_embedding_source.py index 368e610ec..c9e03d7af 100644 --- a/astrbot/core/provider/sources/openai_embedding_source.py +++ b/astrbot/core/provider/sources/openai_embedding_source.py @@ -30,9 +30,9 @@ class OpenAIEmbeddingProvider(EmbeddingProvider): embedding = await self.client.embeddings.create(input=text, model=self.model) return embedding.data[0].embedding - async def get_embeddings(self, texts: list[str]) -> list[list[float]]: + async def get_embeddings(self, text: list[str]) -> list[list[float]]: """批量获取文本的嵌入""" - embeddings = await self.client.embeddings.create(input=texts, model=self.model) + embeddings = await self.client.embeddings.create(input=text, model=self.model) return [item.embedding for item in embeddings.data] def get_dim(self) -> int: diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index cce3f01c9..788b649a9 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -284,6 +284,10 @@ class ProviderOpenAIOfficial(Provider): if isinstance(tool_call, str): # workaround for #1359 tool_call = json.loads(tool_call) + if tools is None: + # 工具集未提供 + # Should be unreachable + raise Exception("工具集未提供") for tool in tools.func_list: if ( tool_call.type == "function" diff --git a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py index 67947c685..a41bd72fd 100644 --- a/astrbot/core/provider/sources/sensevoice_selfhosted_source.py +++ b/astrbot/core/provider/sources/sensevoice_selfhosted_source.py @@ -7,6 +7,7 @@ import asyncio import os import re from datetime import datetime +from typing import cast from funasr_onnx import SenseVoiceSmall from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess @@ -32,7 +33,7 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider): provider_settings: dict, ) -> None: super().__init__(provider_config, provider_settings) - self.set_model(provider_config.get("stt_model")) + self.set_model(provider_config["stt_model"]) self.model = None self.is_emotion = provider_config.get("is_emotion", False) @@ -86,7 +87,9 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider): loop = asyncio.get_event_loop() res = await loop.run_in_executor( None, # 使用默认的线程池 - lambda: self.model(audio_url, language="auto", use_itn=True), + lambda: cast(SenseVoiceSmall, self.model)( + audio_url, language="auto", use_itn=True + ), ) # res = self.model(audio_url, language="auto", use_itn=True) diff --git a/astrbot/core/provider/sources/vllm_rerank_source.py b/astrbot/core/provider/sources/vllm_rerank_source.py index 3e6f3d33c..edd8a5491 100644 --- a/astrbot/core/provider/sources/vllm_rerank_source.py +++ b/astrbot/core/provider/sources/vllm_rerank_source.py @@ -44,6 +44,7 @@ class VLLMRerankProvider(RerankProvider): } if top_n is not None: payload["top_n"] = top_n + assert self.client is not None async with self.client.post( f"{self.base_url}/v1/rerank", json=payload, diff --git a/astrbot/core/provider/sources/whisper_api_source.py b/astrbot/core/provider/sources/whisper_api_source.py index 8f6d9e292..fa69206ef 100644 --- a/astrbot/core/provider/sources/whisper_api_source.py +++ b/astrbot/core/provider/sources/whisper_api_source.py @@ -6,7 +6,10 @@ from openai import NOT_GIVEN, AsyncOpenAI from astrbot.core import logger from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.io import download_file -from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav +from astrbot.core.utils.tencent_record_helper import ( + convert_to_pcm_wav, + tencent_silk_to_wav, +) from ..entities import ProviderType from ..provider import STTProvider @@ -33,20 +36,30 @@ class ProviderOpenAIWhisperAPI(STTProvider): timeout=provider_config.get("timeout", NOT_GIVEN), ) - self.set_model(provider_config.get("model")) + self.set_model(provider_config["model"]) - async def _is_silk_file(self, file_path): + async def _get_audio_format(self, file_path): + # 定义要检测的头部字节 silk_header = b"SILK" - with open(file_path, "rb") as f: - file_header = f.read(8) + amr_header = b"#!AMR" + + try: + with open(file_path, "rb") as f: + file_header = f.read(8) + except FileNotFoundError: + return None if silk_header in file_header: - return True - return False + return "silk" + + if amr_header in file_header: + return "amr" + return None async def get_text(self, audio_url: str) -> str: """Only supports mp3, mp4, mpeg, m4a, wav, webm""" is_tencent = False + output_path = None if audio_url.startswith("http"): if "multimedia.nt.qq.com.cn" in audio_url: @@ -62,16 +75,35 @@ class ProviderOpenAIWhisperAPI(STTProvider): raise FileNotFoundError(f"文件不存在: {audio_url}") if audio_url.endswith(".amr") or audio_url.endswith(".silk") or is_tencent: - is_silk = await self._is_silk_file(audio_url) - if is_silk: - logger.info("Converting silk file to wav ...") + file_format = await self._get_audio_format(audio_url) + + # 判断是否需要转换 + if file_format in ["silk", "amr"]: temp_dir = os.path.join(get_astrbot_data_path(), "temp") output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav") - await tencent_silk_to_wav(audio_url, output_path) + + if file_format == "silk": + logger.info( + "Converting silk file to wav using tencent_silk_to_wav..." + ) + await tencent_silk_to_wav(audio_url, output_path) + elif file_format == "amr": + logger.info( + "Converting amr file to wav using convert_to_pcm_wav..." + ) + await convert_to_pcm_wav(audio_url, output_path) + audio_url = output_path result = await self.client.audio.transcriptions.create( model=self.model_name, - file=open(audio_url, "rb"), + file=("audio.wav", open(audio_url, "rb")), ) + + # remove temp file + if output_path and os.path.exists(output_path): + try: + os.remove(audio_url) + except Exception as e: + logger.error(f"Failed to remove temp file {audio_url}: {e}") return result.text diff --git a/astrbot/core/provider/sources/whisper_selfhosted_source.py b/astrbot/core/provider/sources/whisper_selfhosted_source.py index fbdc7d626..a14f93f14 100644 --- a/astrbot/core/provider/sources/whisper_selfhosted_source.py +++ b/astrbot/core/provider/sources/whisper_selfhosted_source.py @@ -1,6 +1,7 @@ import asyncio import os import uuid +from typing import cast import whisper @@ -26,7 +27,7 @@ class ProviderOpenAIWhisperSelfHost(STTProvider): provider_settings: dict, ) -> None: super().__init__(provider_config, provider_settings) - self.set_model(provider_config.get("model")) + self.set_model(provider_config["model"]) self.model = None async def initialize(self): @@ -75,5 +76,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider): await tencent_silk_to_wav(audio_url, output_path) audio_url = output_path + if not self.model: + raise RuntimeError("Whisper 模型未初始化") + result = await loop.run_in_executor(None, self.model.transcribe, audio_url) - return result["text"] + return cast(str, result["text"]) diff --git a/astrbot/core/provider/sources/xinference_rerank_source.py b/astrbot/core/provider/sources/xinference_rerank_source.py index 29f3ab095..960408550 100644 --- a/astrbot/core/provider/sources/xinference_rerank_source.py +++ b/astrbot/core/provider/sources/xinference_rerank_source.py @@ -1,6 +1,11 @@ +from typing import cast + from xinference_client.client.restful.async_restful_client import ( AsyncClient as Client, ) +from xinference_client.client.restful.async_restful_client import ( + AsyncRESTfulRerankModelHandle, +) from astrbot import logger @@ -29,7 +34,7 @@ class XinferenceRerankProvider(RerankProvider): False, ) self.client = None - self.model = None + self.model: AsyncRESTfulRerankModelHandle | None = None self.model_uid = None async def initialize(self): @@ -65,7 +70,10 @@ class XinferenceRerankProvider(RerankProvider): return if self.model_uid: - self.model = await self.client.get_model(self.model_uid) + self.model = cast( + AsyncRESTfulRerankModelHandle, + await self.client.get_model(self.model_uid), + ) except Exception as e: logger.error(f"Failed to initialize Xinference model: {e}") diff --git a/astrbot/core/star/context.py b/astrbot/core/star/context.py index 21c1ad8fd..9a52ec8bc 100644 --- a/astrbot/core/star/context.py +++ b/astrbot/core/star/context.py @@ -285,7 +285,7 @@ class Context: """获取所有用于 Embedding 任务的 Provider。""" return self.provider_manager.embedding_provider_insts - def get_using_provider(self, umo: str | None = None) -> Provider | None: + def get_using_provider(self, umo: str | None = None) -> Provider: """获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。 Args: @@ -296,7 +296,7 @@ class Context: provider_type=ProviderType.CHAT_COMPLETION, umo=umo, ) - if prov and not isinstance(prov, Provider): + if not isinstance(prov, Provider): raise ValueError("返回的 Provider 不是 Provider 类型") return prov diff --git a/astrbot/core/star/register/star_handler.py b/astrbot/core/star/register/star_handler.py index ee3c09680..daf36a8f6 100644 --- a/astrbot/core/star/register/star_handler.py +++ b/astrbot/core/star/register/star_handler.py @@ -1,7 +1,7 @@ from __future__ import annotations import re -from collections.abc import Awaitable, Callable +from collections.abc import AsyncGenerator, Awaitable, Callable from typing import Any import docstring_parser @@ -12,6 +12,7 @@ from astrbot.core.agent.handoff import HandoffTool from astrbot.core.agent.hooks import BaseAgentRunHooks from astrbot.core.agent.tool import FunctionTool from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.message.message_event_result import MessageEventResult from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES from astrbot.core.provider.register import llm_tools @@ -28,13 +29,19 @@ from ..filter.regex import RegexFilter from ..star_handler import EventType, StarHandlerMetadata, star_handlers_registry -def get_handler_full_name(awaitable: Callable[..., Awaitable[Any]]) -> str: +def get_handler_full_name( + awaitable: Callable[..., Awaitable[Any] | AsyncGenerator[Any]], +) -> str: """获取 Handler 的全名""" return f"{awaitable.__module__}_{awaitable.__name__}" def get_handler_or_create( - handler: Callable[..., Awaitable[Any]], + handler: Callable[ + ..., + Awaitable[MessageEventResult | str | None] + | AsyncGenerator[MessageEventResult | str | None], + ], event_type: EventType, dont_add=False, **kwargs, @@ -169,6 +176,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs): for ( sub_handle ) in parent_register_commandable.parent_group.sub_command_filters: + if isinstance(sub_handle, CommandGroupFilter): + continue # 所有符合fullname一致的子指令handle添加自定义过滤器。 # 不确定是否会有多个子指令有一样的fullname,比如一个方法添加多个command装饰器? sub_handle_md = sub_handle.get_handler_md() @@ -180,6 +189,8 @@ def register_custom_filter(custom_type_filter, *args, **kwargs): else: # 裸指令 + # 确保运行时是可调用的 handler,针对类型检查器添加忽略 + assert isinstance(awaitable, Callable) handler_md = get_handler_or_create( awaitable, EventType.AdapterMessageEvent, @@ -237,7 +248,7 @@ class RegisteringCommandable: group: Callable[..., Callable[..., RegisteringCommandable]] = register_command_group command: Callable[..., Callable[..., None]] = register_command - custom_filter: Callable[..., Callable[..., None]] = register_custom_filter + custom_filter: Callable[..., Callable[..., Any]] = register_custom_filter def __init__(self, parent_group: CommandGroupFilter): self.parent_group = parent_group @@ -412,7 +423,13 @@ def register_llm_tool(name: str | None = None, **kwargs): if kwargs.get("registering_agent"): registering_agent = kwargs["registering_agent"] - def decorator(awaitable: Callable[..., Awaitable[Any]]): + def decorator( + awaitable: Callable[ + ..., + AsyncGenerator[MessageEventResult | str | None] + | Awaitable[MessageEventResult | str | None], + ], + ): llm_tool_name = name_ if name_ else awaitable.__name__ func_doc = awaitable.__doc__ or "" docstring = docstring_parser.parse(func_doc) diff --git a/astrbot/core/star/star_handler.py b/astrbot/core/star/star_handler.py index 69a779b41..be5b4679f 100644 --- a/astrbot/core/star/star_handler.py +++ b/astrbot/core/star/star_handler.py @@ -1,9 +1,9 @@ from __future__ import annotations import enum -from collections.abc import Awaitable, Callable +from collections.abc import AsyncGenerator, Awaitable, Callable from dataclasses import dataclass, field -from typing import Any, Generic, TypeVar +from typing import Any, Generic, Literal, TypeVar, overload from .filter import HandlerFilter from .star import star_map @@ -29,6 +29,84 @@ class StarHandlerRegistry(Generic[T]): for handler in self._handlers: print(handler.handler_full_name) + @overload + def get_handlers_by_event_type( + self, + event_type: Literal[EventType.OnAstrBotLoadedEvent], + only_activated=True, + plugins_name: list[str] | None = None, + ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ... + + @overload + def get_handlers_by_event_type( + self, + event_type: Literal[EventType.OnPlatformLoadedEvent], + only_activated=True, + plugins_name: list[str] | None = None, + ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ... + + @overload + def get_handlers_by_event_type( + self, + event_type: Literal[EventType.AdapterMessageEvent], + only_activated=True, + plugins_name: list[str] | None = None, + ) -> list[ + StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]] + ]: ... + + @overload + def get_handlers_by_event_type( + self, + event_type: Literal[EventType.OnLLMRequestEvent], + only_activated=True, + plugins_name: list[str] | None = None, + ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ... + + @overload + def get_handlers_by_event_type( + self, + event_type: Literal[EventType.OnLLMResponseEvent], + only_activated=True, + plugins_name: list[str] | None = None, + ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ... + + @overload + def get_handlers_by_event_type( + self, + event_type: Literal[EventType.OnDecoratingResultEvent], + only_activated=True, + plugins_name: list[str] | None = None, + ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ... + + @overload + def get_handlers_by_event_type( + self, + event_type: Literal[EventType.OnCallingFuncToolEvent], + only_activated=True, + plugins_name: list[str] | None = None, + ) -> list[ + StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]] + ]: ... + + @overload + def get_handlers_by_event_type( + self, + event_type: Literal[EventType.OnAfterMessageSentEvent], + only_activated=True, + plugins_name: list[str] | None = None, + ) -> list[StarHandlerMetadata[Callable[..., Awaitable[Any]]]]: ... + + @overload + def get_handlers_by_event_type( + self, + event_type: EventType, + only_activated=True, + plugins_name: list[str] | None = None, + ) -> list[ + StarHandlerMetadata[Callable[..., Awaitable[Any] | AsyncGenerator[Any]]] + ]: ... + def get_handlers_by_event_type( self, event_type: EventType, @@ -113,8 +191,11 @@ class EventType(enum.Enum): OnAfterMessageSentEvent = enum.auto() # 发送消息后 +H = TypeVar("H", bound=Callable[..., Any]) + + @dataclass -class StarHandlerMetadata: +class StarHandlerMetadata(Generic[H]): """描述一个 Star 所注册的某一个 Handler。""" event_type: EventType @@ -129,7 +210,7 @@ class StarHandlerMetadata: handler_module_path: str """Handler 所在的模块路径。""" - handler: Callable[..., Awaitable[Any]] + handler: H """Handler 的函数对象,应当是一个异步函数""" event_filters: list[HandlerFilter] diff --git a/astrbot/core/updator.py b/astrbot/core/updator.py index d13bab687..0a7116a0d 100644 --- a/astrbot/core/updator.py +++ b/astrbot/core/updator.py @@ -71,10 +71,10 @@ class AstrBotUpdator(RepoZipUpdator): async def check_update( self, - url: str, - current_version: str, + url: str | None, + current_version: str | None, consider_prerelease: bool = True, - ) -> ReleaseInfo: + ) -> ReleaseInfo | None: """检查更新""" return await super().check_update( self.ASTRBOT_RELEASE_API, diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index 073c04938..fcf5bb3c7 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -49,7 +49,7 @@ def port_checker(port: int, host: str = "localhost"): return False -def save_temp_img(img: Image.Image | str) -> str: +def save_temp_img(img: Image.Image | bytes) -> str: temp_dir = os.path.join(get_astrbot_data_path(), "temp") # 获得文件创建时间,清除超过 12 小时的 try: diff --git a/astrbot/core/utils/session_waiter.py b/astrbot/core/utils/session_waiter.py index 33b7cb17a..e1f2fbef7 100644 --- a/astrbot/core/utils/session_waiter.py +++ b/astrbot/core/utils/session_waiter.py @@ -20,16 +20,16 @@ class SessionController: def __init__(self): self.future = asyncio.Future() - self.current_event: asyncio.Event = None + self.current_event: asyncio.Event | None = None """当前正在等待的所用的异步事件""" - self.ts: float = None + self.ts: float | None = None """上次保持(keep)开始时的时间""" - self.timeout: float | int = None + self.timeout: float | int | None = None """上次保持(keep)开始时的超时时间""" self.history_chains: list[list[Comp.BaseMessageComponent]] = [] - def stop(self, error: Exception = None): + def stop(self, error: Exception | None = None): """立即结束这个会话""" if not self.future.done(): if error: @@ -53,6 +53,8 @@ class SessionController: self.stop() return else: + assert self.timeout is not None + assert self.ts is not None left_timeout = self.timeout - (new_ts - self.ts) timeout = left_timeout + timeout if timeout <= 0: @@ -69,7 +71,7 @@ class SessionController: asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep - async def _holding(self, event: asyncio.Event, timeout: int): + async def _holding(self, event: asyncio.Event, timeout: float): """等待事件结束或超时""" try: await asyncio.wait_for(event.wait(), timeout) @@ -108,7 +110,9 @@ class SessionWaiter: ): self.session_id = session_id self.session_filter = session_filter - self.handler: Callable[[str], Awaitable[Any]] | None = None # 处理函数 + self.handler: ( + Callable[[SessionController, AstrMessageEvent], Awaitable[Any]] | None + ) = None # 处理函数 self.session_controller = SessionController() self.record_history_chains = record_history_chains @@ -119,7 +123,7 @@ class SessionWaiter: async def register_wait( self, - handler: Callable[[str], Awaitable[Any]], + handler: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]], timeout: int = 30, ) -> Any: """等待外部输入并处理""" @@ -137,7 +141,7 @@ class SessionWaiter: finally: self._cleanup() - def _cleanup(self, error: Exception = None): + def _cleanup(self, error: Exception | None = None): """清理会话""" USER_SESSIONS.pop(self.session_id, None) try: @@ -161,6 +165,7 @@ class SessionWaiter: ) try: # TODO: 这里使用 create_task,跟踪 task,防止超时后这里 handler 仍然在执行 + assert session.handler is not None await session.handler(session.session_controller, event) except Exception as e: session.session_controller.stop(e) @@ -173,11 +178,13 @@ def session_waiter(timeout: int = 30, record_history_chains: bool = False): :param record_history_chain: 是否自动记录历史消息链。可以通过 controller.get_history_chains() 获取。深拷贝。 """ - def decorator(func: Callable[[str], Awaitable[Any]]): + def decorator( + func: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]], + ): @functools.wraps(func) async def wrapper( event: AstrMessageEvent, - session_filter: SessionFilter = None, + session_filter: SessionFilter | None = None, *args, **kwargs, ): diff --git a/astrbot/core/utils/shared_preferences.py b/astrbot/core/utils/shared_preferences.py index 6b1f52a69..ccd394ee4 100644 --- a/astrbot/core/utils/shared_preferences.py +++ b/astrbot/core/utils/shared_preferences.py @@ -53,6 +53,38 @@ class SharedPreferences: ret = await self.db_helper.get_preferences(scope, scope_id, key) return ret + @overload + async def session_get( + self, + umo: str, + key: str, + default: _VT = None, + ) -> _VT: ... + + @overload + async def session_get( + self, + umo: None, + key: str, + default: Any = None, + ) -> list[Preference]: ... + + @overload + async def session_get( + self, + umo: str, + key: None, + default: Any = None, + ) -> list[Preference]: ... + + @overload + async def session_get( + self, + umo: None, + key: None, + default: Any = None, + ) -> list[Preference]: ... + async def session_get( self, umo: str | None, diff --git a/astrbot/core/utils/t2i/__init__.py b/astrbot/core/utils/t2i/__init__.py index 5038a46f7..e4112c354 100644 --- a/astrbot/core/utils/t2i/__init__.py +++ b/astrbot/core/utils/t2i/__init__.py @@ -3,11 +3,11 @@ from abc import ABC, abstractmethod class RenderStrategy(ABC): @abstractmethod - def render(self, text: str, return_url: bool) -> str: + async def render(self, text: str, return_url: bool) -> str: pass @abstractmethod - def render_custom_template( + async def render_custom_template( self, tmpl_str: str, tmpl_data: dict, diff --git a/astrbot/core/utils/t2i/local_strategy.py b/astrbot/core/utils/t2i/local_strategy.py index 19eab2efe..2fa235129 100644 --- a/astrbot/core/utils/t2i/local_strategy.py +++ b/astrbot/core/utils/t2i/local_strategy.py @@ -20,7 +20,7 @@ class FontManager: _font_cache = {} @classmethod - def get_font(cls, size: int) -> ImageFont.FreeTypeFont: + def get_font(cls, size: int) -> ImageFont.FreeTypeFont|ImageFont.ImageFont: """获取指定大小的字体,优先从缓存获取""" if size in cls._font_cache: return cls._font_cache[size] @@ -66,23 +66,17 @@ class TextMeasurer: """测量文本尺寸的工具类""" @staticmethod - def get_text_size(text: str, font: ImageFont.FreeTypeFont) -> Tuple[int, int]: + def get_text_size(text: str, font: ImageFont.FreeTypeFont|ImageFont.ImageFont) -> tuple[int, int]: """获取文本的尺寸""" - try: - # PIL 9.0.0 以上版本 - return ( - font.getbbox(text)[2:] - if hasattr(font, "getbbox") - else font.getsize(text) - ) - except Exception: - # 兼容旧版本 - return font.getsize(text) + + # 依赖库Pillow>=11.2.1,不再需要考虑<9.0.0 + left, top, right, bottom = font.getbbox("Hello world") + return int(right - left), int(bottom - top) @staticmethod def split_text_to_fit_width( - text: str, font: ImageFont.FreeTypeFont, max_width: int - ) -> List[str]: + text: str, font: ImageFont.FreeTypeFont|ImageFont.ImageFont, max_width: int + ) -> list[str]: """将文本拆分为多行,确保每行不超过指定宽度""" lines = [] if not text: @@ -126,7 +120,7 @@ class MarkdownElement(ABC): def render( self, image: Image.Image, - draw: ImageDraw.Draw, + draw: ImageDraw.ImageDraw, x: int, y: int, image_width: int, @@ -152,7 +146,7 @@ class TextElement(MarkdownElement): def render( self, image: Image.Image, - draw: ImageDraw.Draw, + draw: ImageDraw.ImageDraw, x: int, y: int, image_width: int, @@ -186,7 +180,7 @@ class BoldTextElement(MarkdownElement): def render( self, image: Image.Image, - draw: ImageDraw.Draw, + draw: ImageDraw.ImageDraw, x: int, y: int, image_width: int, @@ -251,7 +245,7 @@ class ItalicTextElement(MarkdownElement): def render( self, image: Image.Image, - draw: ImageDraw.Draw, + draw: ImageDraw.ImageDraw, x: int, y: int, image_width: int, @@ -299,7 +293,7 @@ class ItalicTextElement(MarkdownElement): # 倾斜变换,使用仿射变换实现斜体效果 # 变换矩阵: [1, 0.2, 0, 0, 1, 0] italic_img = text_img.transform( - text_img.size, Image.AFFINE, (1, 0.2, 0, 0, 1, 0), Image.BICUBIC + text_img.size, Image.Transform.AFFINE, (1, 0.2, 0, 0, 1, 0), Image.Resampling.BICUBIC ) # 粘贴到原图像 @@ -331,7 +325,7 @@ class UnderlineTextElement(MarkdownElement): def render( self, image: Image.Image, - draw: ImageDraw.Draw, + draw: ImageDraw.ImageDraw, x: int, y: int, image_width: int, @@ -371,7 +365,7 @@ class StrikethroughTextElement(MarkdownElement): def render( self, image: Image.Image, - draw: ImageDraw.Draw, + draw: ImageDraw.ImageDraw, x: int, y: int, image_width: int, @@ -422,7 +416,7 @@ class HeaderElement(MarkdownElement): def render( self, image: Image.Image, - draw: ImageDraw.Draw, + draw: ImageDraw.ImageDraw, x: int, y: int, image_width: int, @@ -458,7 +452,7 @@ class QuoteElement(MarkdownElement): def render( self, image: Image.Image, - draw: ImageDraw.Draw, + draw: ImageDraw.ImageDraw, x: int, y: int, image_width: int, @@ -502,7 +496,7 @@ class ListItemElement(MarkdownElement): def render( self, image: Image.Image, - draw: ImageDraw.Draw, + draw: ImageDraw.ImageDraw, x: int, y: int, image_width: int, @@ -532,7 +526,7 @@ class ListItemElement(MarkdownElement): class CodeBlockElement(MarkdownElement): """代码块元素""" - def __init__(self, content: List[str]): + def __init__(self, content: list[str]): super().__init__("\n".join(content)) def calculate_height(self, image_width: int, font_size: int) -> int: @@ -552,7 +546,7 @@ class CodeBlockElement(MarkdownElement): def render( self, image: Image.Image, - draw: ImageDraw.Draw, + draw: ImageDraw.ImageDraw, x: int, y: int, image_width: int, @@ -595,7 +589,7 @@ class InlineCodeElement(MarkdownElement): def render( self, image: Image.Image, - draw: ImageDraw.Draw, + draw: ImageDraw.ImageDraw, x: int, y: int, image_width: int, @@ -667,7 +661,7 @@ class ImageElement(MarkdownElement): def render( self, image: Image.Image, - draw: ImageDraw.Draw, + draw: ImageDraw.ImageDraw, x: int, y: int, image_width: int, @@ -686,7 +680,7 @@ class ImageElement(MarkdownElement): if pasted_image.width > max_width: ratio = max_width / pasted_image.width new_size = (int(max_width), int(pasted_image.height * ratio)) - pasted_image = pasted_image.resize(new_size, Image.LANCZOS) + pasted_image = pasted_image.resize(new_size, Image.Resampling.LANCZOS) # 计算居中位置 paste_x = x + (image_width - pasted_image.width) // 2 - 10 @@ -705,7 +699,7 @@ class MarkdownParser: """Markdown解析器,将文本解析为元素""" @staticmethod - async def parse(text: str) -> List[MarkdownElement]: + async def parse(text: str) -> list[MarkdownElement]: elements = [] lines = text.split("\n") @@ -847,7 +841,7 @@ class MarkdownRenderer: self, font_size: int = 26, width: int = 800, - bg_color: Tuple[int, int, int] = (255, 255, 255), + bg_color: tuple[int, int, int] = (255, 255, 255), ): self.font_size = font_size self.width = width diff --git a/astrbot/core/utils/tencent_record_helper.py b/astrbot/core/utils/tencent_record_helper.py index 9cc36571e..b58643bd3 100644 --- a/astrbot/core/utils/tencent_record_helper.py +++ b/astrbot/core/utils/tencent_record_helper.py @@ -36,7 +36,7 @@ async def wav_to_tencent_silk(wav_path: str, output_path: str) -> int: import pilk except (ImportError, ModuleNotFoundError) as _: raise Exception( - "pilk 模块未安装,请前往管理面板->控制台->安装pip库 安装 pilk 这个库", + "pilk 模块未安装,请前往管理面板->平台日志->安装pip库 安装 pilk 这个库", ) # with wave.open(wav_path, 'rb') as wav: # wav_data = wav.readframes(wav.getnframes()) @@ -68,7 +68,7 @@ async def convert_to_pcm_wav(input_path: str, output_path: str) -> str: from pyffmpeg import FFmpeg ff = FFmpeg() - ff.convert(input=input_path, output=output_path) + ff.convert(input_file=input_path, output_file=output_path) except Exception as e: logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换") diff --git a/astrbot/core/utils/version_comparator.py b/astrbot/core/utils/version_comparator.py index e3bf74951..4ad2da10e 100644 --- a/astrbot/core/utils/version_comparator.py +++ b/astrbot/core/utils/version_comparator.py @@ -60,9 +60,12 @@ class VersionComparator: return -1 if isinstance(p1, str) and isinstance(p2, int): return 1 - if (isinstance(p1, int) and isinstance(p2, int)) or ( - isinstance(p1, str) and isinstance(p2, str) - ): + if isinstance(p1, int) and isinstance(p2, int): + if p1 > p2: + return 1 + if p1 < p2: + return -1 + if isinstance(p1, str) and isinstance(p2, str): if p1 > p2: return 1 if p1 < p2: diff --git a/astrbot/core/utils/webhook_utils.py b/astrbot/core/utils/webhook_utils.py new file mode 100644 index 000000000..c56d00b37 --- /dev/null +++ b/astrbot/core/utils/webhook_utils.py @@ -0,0 +1,47 @@ +from astrbot.core import astrbot_config, logger + + +def _get_callback_api_base() -> str: + try: + return astrbot_config.get("callback_api_base", "").rstrip("/") + except Exception as e: + logger.error(f"获取 callback_api_base 失败: {e!s}") + return "" + + +def _get_dashboard_port() -> int: + try: + return astrbot_config.get("dashboard", {}).get("port", 6185) + except Exception as e: + logger.error(f"获取 dashboard 端口失败: {e!s}") + return 6185 + + +def log_webhook_info(platform_name: str, webhook_uuid: str): + """打印美观的 webhook 信息日志 + + Args: + platform_name: 平台名称 + webhook_uuid: webhook 的 UUID + """ + + callback_base = _get_callback_api_base() + + if not callback_base: + callback_base = "http(s)://" + + if not callback_base.startswith("http"): + callback_base = f"http(s)://{callback_base}" + + callback_base = callback_base.rstrip("/") + webhook_url = f"{callback_base}/api/platform/webhook/{webhook_uuid}" + + display_log = ( + "\n====================\n" + f"🔗 机器人平台 {platform_name} 已启用统一 Webhook 模式\n" + f"📍 Webhook 回调地址: \n" + f" ➜ http://:{_get_dashboard_port()}/api/platform/webhook/{webhook_uuid}\n" + f" ➜ {webhook_url}\n" + "====================\n" + ) + logger.info(display_log) diff --git a/astrbot/dashboard/routes/__init__.py b/astrbot/dashboard/routes/__init__.py index 03e9ea798..951db956c 100644 --- a/astrbot/dashboard/routes/__init__.py +++ b/astrbot/dashboard/routes/__init__.py @@ -7,6 +7,7 @@ from .file import FileRoute from .knowledge_base import KnowledgeBaseRoute from .log import LogRoute from .persona import PersonaRoute +from .platform import PlatformRoute from .plugin import PluginRoute from .session_management import SessionManagementRoute from .stat import StatRoute @@ -24,6 +25,7 @@ __all__ = [ "KnowledgeBaseRoute", "LogRoute", "PersonaRoute", + "PlatformRoute", "PluginRoute", "SessionManagementRoute", "StatRoute", diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index 5381b5649..cfb750803 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -1,11 +1,13 @@ import asyncio import json +import mimetypes import os import uuid from contextlib import asynccontextmanager +from typing import cast from quart import Response as QuartResponse -from quart import g, make_response, request +from quart import g, make_response, request, send_file from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle @@ -44,7 +46,7 @@ class ChatRoute(Route): self.update_session_display_name, ), "/chat/get_file": ("GET", self.get_file), - "/chat/post_image": ("POST", self.post_image), + "/chat/get_attachment": ("GET", self.get_attachment), "/chat/post_file": ("POST", self.post_file), } self.core_lifecycle = core_lifecycle @@ -73,52 +75,184 @@ class ChatRoute(Route): if not real_file_path.startswith(real_imgs_dir): return Response().error("Invalid file path").__dict__ - with open(real_file_path, "rb") as f: - filename_ext = os.path.splitext(filename)[1].lower() - - if filename_ext == ".wav": - return QuartResponse(f.read(), mimetype="audio/wav") - if filename_ext[1:] in self.supported_imgs: - return QuartResponse(f.read(), mimetype="image/jpeg") - return QuartResponse(f.read()) + filename_ext = os.path.splitext(filename)[1].lower() + if filename_ext == ".wav": + return await send_file(real_file_path, mimetype="audio/wav") + if filename_ext[1:] in self.supported_imgs: + return await send_file(real_file_path, mimetype="image/jpeg") + return await send_file(real_file_path) except (FileNotFoundError, OSError): return Response().error("File access error").__dict__ - async def post_image(self): - post_data = await request.files - if "file" not in post_data: - return Response().error("Missing key: file").__dict__ + async def get_attachment(self): + """Get attachment file by attachment_id.""" + attachment_id = request.args.get("attachment_id") + if not attachment_id: + return Response().error("Missing key: attachment_id").__dict__ - file = post_data["file"] - filename = str(uuid.uuid4()) + ".jpg" - path = os.path.join(self.imgs_dir, filename) - await file.save(path) + try: + attachment = await self.db.get_attachment_by_id(attachment_id) + if not attachment: + return Response().error("Attachment not found").__dict__ - return Response().ok(data={"filename": filename}).__dict__ + file_path = attachment.path + real_file_path = os.path.realpath(file_path) + + return await send_file(real_file_path, mimetype=attachment.mime_type) + + except (FileNotFoundError, OSError): + return Response().error("File access error").__dict__ async def post_file(self): + """Upload a file and create an attachment record, return attachment_id.""" post_data = await request.files if "file" not in post_data: return Response().error("Missing key: file").__dict__ file = post_data["file"] - filename = f"{uuid.uuid4()!s}" - # 通过文件格式判断文件类型 - if file.content_type.startswith("audio"): - filename += ".wav" + filename = file.filename or f"{uuid.uuid4()!s}" + content_type = file.content_type or "application/octet-stream" + + # 根据 content_type 判断文件类型并添加扩展名 + if content_type.startswith("image"): + attach_type = "image" + elif content_type.startswith("audio"): + attach_type = "record" + elif content_type.startswith("video"): + attach_type = "video" + else: + attach_type = "file" path = os.path.join(self.imgs_dir, filename) await file.save(path) - return Response().ok(data={"filename": filename}).__dict__ + # 创建 attachment 记录 + attachment = await self.db.insert_attachment( + path=path, + type=attach_type, + mime_type=content_type, + ) + + if not attachment: + return Response().error("Failed to create attachment").__dict__ + + filename = os.path.basename(attachment.path) + + return ( + Response() + .ok( + data={ + "attachment_id": attachment.attachment_id, + "filename": filename, + "type": attach_type, + } + ) + .__dict__ + ) + + async def _build_user_message_parts(self, message: str | list) -> list[dict]: + """构建用户消息的部分列表 + + Args: + message: 文本消息 (str) 或消息段列表 (list) + """ + parts = [] + + if isinstance(message, list): + for part in message: + part_type = part.get("type") + if part_type == "plain": + parts.append({"type": "plain", "text": part.get("text", "")}) + elif part_type == "reply": + parts.append( + {"type": "reply", "message_id": part.get("message_id")} + ) + elif attachment_id := part.get("attachment_id"): + attachment = await self.db.get_attachment_by_id(attachment_id) + if attachment: + parts.append( + { + "type": attachment.type, + "attachment_id": attachment.attachment_id, + "filename": os.path.basename(attachment.path), + "path": attachment.path, # will be deleted + } + ) + return parts + + if message: + parts.append({"type": "plain", "text": message}) + + return parts + + async def _create_attachment_from_file( + self, filename: str, attach_type: str + ) -> dict | None: + """从本地文件创建 attachment 并返回消息部分 + + 用于处理 bot 回复中的媒体文件 + + Args: + filename: 存储的文件名 + attach_type: 附件类型 (image, record, file, video) + """ + file_path = os.path.join(self.imgs_dir, os.path.basename(filename)) + if not os.path.exists(file_path): + return None + + # guess mime type + mime_type, _ = mimetypes.guess_type(filename) + if not mime_type: + mime_type = "application/octet-stream" + + # insert attachment + attachment = await self.db.insert_attachment( + path=file_path, + type=attach_type, + mime_type=mime_type, + ) + if not attachment: + return None + + return { + "type": attach_type, + "attachment_id": attachment.attachment_id, + "filename": os.path.basename(file_path), + } + + async def _save_bot_message( + self, + webchat_conv_id: str, + text: str, + media_parts: list, + reasoning: str, + ): + """保存 bot 消息到历史记录,返回保存的记录""" + bot_message_parts = [] + if text: + bot_message_parts.append({"type": "plain", "text": text}) + bot_message_parts.extend(media_parts) + + new_his = {"type": "bot", "message": bot_message_parts} + if reasoning: + new_his["reasoning"] = reasoning + + record = await self.platform_history_mgr.insert( + platform_id="webchat", + user_id=webchat_conv_id, + content=new_his, + sender_id="bot", + sender_name="bot", + ) + return record async def chat(self): username = g.get("username", "guest") post_data = await request.json - if "message" not in post_data and "image_url" not in post_data: - return Response().error("Missing key: message or image_url").__dict__ + if "message" not in post_data and "files" not in post_data: + return Response().error("Missing key: message or files").__dict__ if "session_id" not in post_data and "conversation_id" not in post_data: return ( @@ -126,44 +260,40 @@ class ChatRoute(Route): ) message = post_data["message"] - # conversation_id = post_data["conversation_id"] session_id = post_data.get("session_id", post_data.get("conversation_id")) - image_url = post_data.get("image_url") - audio_url = post_data.get("audio_url") selected_provider = post_data.get("selected_provider") selected_model = post_data.get("selected_model") - enable_streaming = post_data.get("enable_streaming", True) # 默认为 True + enable_streaming = post_data.get("enable_streaming", True) - if not message and not image_url and not audio_url: - return ( - Response() - .error("Message and image_url and audio_url are empty") - .__dict__ + # 检查消息是否为空 + if isinstance(message, list): + has_content = any( + part.get("type") in ("plain", "image", "record", "file", "video") + for part in message ) + if not has_content: + return ( + Response() + .error("Message content is empty (reply only is not allowed)") + .__dict__ + ) + elif not message: + return Response().error("Message are both empty").__dict__ + if not session_id: return Response().error("session_id is empty").__dict__ - # 追加用户消息 webchat_conv_id = session_id - - # 获取会话特定的队列 back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id) - new_his = {"type": "user", "message": message} - if image_url: - new_his["image_url"] = image_url - if audio_url: - new_his["audio_url"] = audio_url - await self.platform_history_mgr.insert( - platform_id="webchat", - user_id=webchat_conv_id, - content=new_his, - sender_id=username, - sender_name=username, - ) + # 构建用户消息段(包含 path 用于传递给 adapter) + message_parts = await self._build_user_message_parts(message) async def stream(): client_disconnected = False + accumulated_parts = [] + accumulated_text = "" + accumulated_reasoning = "" try: async with track_conversation(self.running_convs, webchat_conv_id): @@ -182,16 +312,17 @@ class ChatRoute(Route): continue result_text = result["data"] - type = result.get("type") + msg_type = result.get("type") streaming = result.get("streaming", False) + # 发送 SSE 数据 try: if not client_disconnected: yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n" except Exception as e: if not client_disconnected: logger.debug( - f"[WebChat] 用户 {username} 断开聊天长连接。 {e}", + f"[WebChat] 用户 {username} 断开聊天长连接。 {e}" ) client_disconnected = True @@ -202,24 +333,68 @@ class ChatRoute(Route): logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。") client_disconnected = True - if type == "end": + # 累积消息部分 + if msg_type == "plain": + chain_type = result.get("chain_type", "normal") + if chain_type == "reasoning": + accumulated_reasoning += result_text + else: + accumulated_text += result_text + elif msg_type == "image": + filename = result_text.replace("[IMAGE]", "") + part = await self._create_attachment_from_file( + filename, "image" + ) + if part: + accumulated_parts.append(part) + elif msg_type == "record": + filename = result_text.replace("[RECORD]", "") + part = await self._create_attachment_from_file( + filename, "record" + ) + if part: + accumulated_parts.append(part) + elif msg_type == "file": + # 格式: [FILE]filename + filename = result_text.replace("[FILE]", "") + part = await self._create_attachment_from_file( + filename, "file" + ) + if part: + accumulated_parts.append(part) + + # 消息结束处理 + if msg_type == "end": break elif ( - (streaming and type == "complete") + (streaming and msg_type == "complete") or not streaming - or type == "break" + or msg_type == "break" ): - # 追加机器人消息 - new_his = {"type": "bot", "message": result_text} - if "reasoning" in result: - new_his["reasoning"] = result["reasoning"] - await self.platform_history_mgr.insert( - platform_id="webchat", - user_id=webchat_conv_id, - content=new_his, - sender_id="bot", - sender_name="bot", + saved_record = await self._save_bot_message( + webchat_conv_id, + accumulated_text, + accumulated_parts, + accumulated_reasoning, ) + # 发送保存的消息信息给前端 + if saved_record and not client_disconnected: + saved_info = { + "type": "message_saved", + "data": { + "id": saved_record.id, + "created_at": saved_record.created_at.astimezone().isoformat(), + }, + } + try: + yield f"data: {json.dumps(saved_info, ensure_ascii=False)}\n\n" + except Exception: + pass + # 重置累积变量 (对于 break 后的下一段消息) + if msg_type == "break": + accumulated_parts = [] + accumulated_text = "" + accumulated_reasoning = "" except BaseException as e: logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True) @@ -230,9 +405,7 @@ class ChatRoute(Route): username, webchat_conv_id, { - "message": message, - "image_url": image_url, # list - "audio_url": audio_url, + "message": message_parts, "selected_provider": selected_provider, "selected_model": selected_model, "enable_streaming": enable_streaming, @@ -240,14 +413,30 @@ class ChatRoute(Route): ), ) - response = await make_response( - stream(), - { - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache", - "Transfer-Encoding": "chunked", - "Connection": "keep-alive", - }, + message_parts_for_storage = [] + for part in message_parts: + part_copy = {k: v for k, v in part.items() if k != "path"} + message_parts_for_storage.append(part_copy) + + await self.platform_history_mgr.insert( + platform_id="webchat", + user_id=webchat_conv_id, + content={"type": "user", "message": message_parts_for_storage}, + sender_id=username, + sender_name=username, + ) + + response = cast( + QuartResponse, + await make_response( + stream(), + { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Transfer-Encoding": "chunked", + "Connection": "keep-alive", + }, + ), ) response.timeout = None # fix SSE auto disconnect issue return response @@ -271,6 +460,17 @@ class ChatRoute(Route): unified_msg_origin = f"{session.platform_id}:{message_type}:{session.platform_id}!{username}!{session_id}" await self.conv_mgr.delete_conversations_by_user_id(unified_msg_origin) + # 获取消息历史中的所有附件 ID 并删除附件 + history_list = await self.platform_history_mgr.get( + platform_id=session.platform_id, + user_id=session_id, + page=1, + page_size=100000, # 获取足够多的记录 + ) + attachment_ids = self._extract_attachment_ids(history_list) + if attachment_ids: + await self._delete_attachments(attachment_ids) + # 删除消息历史 await self.platform_history_mgr.delete( platform_id=session.platform_id, @@ -297,6 +497,41 @@ class ChatRoute(Route): return Response().ok().__dict__ + def _extract_attachment_ids(self, history_list) -> list[str]: + """从消息历史中提取所有 attachment_id""" + attachment_ids = [] + for history in history_list: + content = history.content + if not content or "message" not in content: + continue + message_parts = content.get("message", []) + for part in message_parts: + if isinstance(part, dict) and "attachment_id" in part: + attachment_ids.append(part["attachment_id"]) + return attachment_ids + + async def _delete_attachments(self, attachment_ids: list[str]): + """删除附件(包括数据库记录和磁盘文件)""" + try: + attachments = await self.db.get_attachments(attachment_ids) + for attachment in attachments: + if not os.path.exists(attachment.path): + continue + try: + os.remove(attachment.path) + except OSError as e: + logger.warning( + f"Failed to delete attachment file {attachment.path}: {e}" + ) + except Exception as e: + logger.warning(f"Failed to get attachments: {e}") + + # 批量删除数据库记录 + try: + await self.db.delete_attachments(attachment_ids) + except Exception as e: + logger.warning(f"Failed to delete attachments: {e}") + async def new_session(self): """Create a new Platform session (default: webchat).""" username = g.get("username", "guest") diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 1089d8f81..e8f17cc99 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -2,6 +2,8 @@ import asyncio import inspect import os import traceback +import uuid +from typing import Any from quart import request @@ -13,6 +15,7 @@ from astrbot.core.config.default import ( CONFIG_METADATA_3_SYSTEM, DEFAULT_CONFIG, DEFAULT_VALUE_MAP, + WEBHOOK_SUPPORTED_PLATFORMS, ) from astrbot.core.config.i18n_utils import ConfigMetadataI18n from astrbot.core.core_lifecycle import AstrBotCoreLifecycle @@ -24,7 +27,7 @@ from astrbot.core.star.star import star_registry from .route import Response, Route, RouteContext -def try_cast(value: str, type_: str): +def try_cast(value: Any, type_: str): if type_ == "int": try: return int(value) @@ -503,9 +506,9 @@ class ConfigRoute(Route): if not isinstance(inst, EmbeddingProvider): return Response().error("提供商不是 EmbeddingProvider 类型").__dict__ - # 初始化 - if getattr(inst, "initialize", None): - await inst.initialize() + init_fn = getattr(inst, "initialize", None) + if inspect.iscoroutinefunction(init_fn): + await init_fn() # 获取嵌入向量维度 vec = await inst.get_embedding("echo") @@ -555,6 +558,15 @@ class ConfigRoute(Route): async def post_new_platform(self): new_platform_config = await request.json + + # 如果是支持统一 webhook 模式的平台,且启用了统一 webhook 模式,自动生成 webhook_uuid + platform_type = new_platform_config.get("type", "") + if platform_type in WEBHOOK_SUPPORTED_PLATFORMS: + if new_platform_config.get("unified_webhook_mode", False): + # 如果没有 webhook_uuid 或为空,自动生成 + if not new_platform_config.get("webhook_uuid"): + new_platform_config["webhook_uuid"] = uuid.uuid4().hex[:16] + self.config["platform"].append(new_platform_config) try: save_config(self.config, self.config, is_core=True) @@ -584,6 +596,14 @@ class ConfigRoute(Route): if not platform_id or not new_config: return Response().error("参数错误").__dict__ + # 如果是支持统一 webhook 模式的平台,且启用了统一 webhook 模式,确保有 webhook_uuid + platform_type = new_config.get("type", "") + if platform_type in WEBHOOK_SUPPORTED_PLATFORMS: + if new_config.get("unified_webhook_mode", False): + # 如果没有 webhook_uuid 或为空,自动生成 + if not new_config.get("webhook_uuid"): + new_config["webhook_uuid"] = uuid.uuid4().hex + for i, platform in enumerate(self.config["platform"]): if platform["id"] == platform_id: self.config["platform"][i] = new_config @@ -758,7 +778,7 @@ class ConfigRoute(Route): return {"metadata": CONFIG_METADATA_2, "config": config} async def _get_plugin_config(self, plugin_name: str): - ret = {"metadata": None, "config": None} + ret: dict = {"metadata": None, "config": None} for plugin_md in star_registry: if plugin_md.name == plugin_name: diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py index 050d5836c..d7db42c40 100644 --- a/astrbot/dashboard/routes/knowledge_base.py +++ b/astrbot/dashboard/routes/knowledge_base.py @@ -274,7 +274,7 @@ class KnowledgeBaseRoute(Route): except Exception as e: return ( Response() - .error(f"测试重排序模型失败: {e!s},请检查控制台日志输出。") + .error(f"测试重排序模型失败: {e!s},请检查平台日志输出。") .__dict__ ) diff --git a/astrbot/dashboard/routes/log.py b/astrbot/dashboard/routes/log.py index eb02fdf40..86cc8c6ca 100644 --- a/astrbot/dashboard/routes/log.py +++ b/astrbot/dashboard/routes/log.py @@ -1,6 +1,8 @@ import asyncio import json +from typing import cast +from quart import Response as QuartResponse from quart import make_response from astrbot.core import LogBroker, logger @@ -39,14 +41,17 @@ class LogRoute(Route): if queue: self.log_broker.unregister(queue) - response = await make_response( - stream(), - { - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "Transfer-Encoding": "chunked", - }, + response = cast( + QuartResponse, + await make_response( + stream(), + { + "Content-Type": "text/event-stream", + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Transfer-Encoding": "chunked", + }, + ), ) response.timeout = None return response diff --git a/astrbot/dashboard/routes/platform.py b/astrbot/dashboard/routes/platform.py new file mode 100644 index 000000000..5b709a628 --- /dev/null +++ b/astrbot/dashboard/routes/platform.py @@ -0,0 +1,100 @@ +"""统一 Webhook 路由 + +提供统一的 webhook 回调入口,支持多个平台使用同一端口接收回调。 +""" + +from quart import request + +from astrbot.core import logger +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.platform import Platform + +from .route import Response, Route, RouteContext + + +class PlatformRoute(Route): + """统一 Webhook 路由""" + + def __init__( + self, + context: RouteContext, + core_lifecycle: AstrBotCoreLifecycle, + ) -> None: + super().__init__(context) + self.core_lifecycle = core_lifecycle + self.platform_manager = core_lifecycle.platform_manager + + self._register_webhook_routes() + + def _register_webhook_routes(self): + """注册 webhook 路由""" + # 统一 webhook 入口,支持 GET 和 POST + self.app.add_url_rule( + "/api/platform/webhook/", + view_func=self.unified_webhook_callback, + methods=["GET", "POST"], + ) + + # 平台统计信息接口 + self.app.add_url_rule( + "/api/platform/stats", + view_func=self.get_platform_stats, + methods=["GET"], + ) + + async def unified_webhook_callback(self, webhook_uuid: str): + """统一 webhook 回调入口 + + Args: + webhook_uuid: 平台配置中的 webhook_uuid + + Returns: + 根据平台适配器返回相应的响应 + """ + # 根据 webhook_uuid 查找对应的平台 + platform_adapter = self._find_platform_by_uuid(webhook_uuid) + + if not platform_adapter: + logger.warning(f"未找到 webhook_uuid 为 {webhook_uuid} 的平台") + return Response().error("未找到对应平台").__dict__, 404 + + # 调用平台适配器的 webhook_callback 方法 + try: + result = await platform_adapter.webhook_callback(request) + return result + except NotImplementedError: + logger.error( + f"平台 {platform_adapter.meta().name} 未实现 webhook_callback 方法" + ) + return Response().error("平台未支持统一 Webhook 模式").__dict__, 500 + except Exception as e: + logger.error(f"处理 webhook 回调时发生错误: {e}", exc_info=True) + return Response().error("处理回调失败").__dict__, 500 + + def _find_platform_by_uuid(self, webhook_uuid: str) -> Platform | None: + """根据 webhook_uuid 查找对应的平台适配器 + + Args: + webhook_uuid: webhook UUID + + Returns: + 平台适配器实例,未找到则返回 None + """ + for platform in self.platform_manager.platform_insts: + if platform.config.get("webhook_uuid") == webhook_uuid: + if platform.config.get("unified_webhook_mode", False): + return platform + return None + + async def get_platform_stats(self): + """获取所有平台的统计信息 + + Returns: + 包含平台统计信息的响应 + """ + try: + stats = self.platform_manager.get_all_stats() + return Response().ok(stats).__dict__ + except Exception as e: + logger.error(f"获取平台统计信息失败: {e}", exc_info=True) + return Response().error(f"获取统计信息失败: {e}").__dict__, 500 diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py index f2a35dfe1..c249b07b7 100644 --- a/astrbot/dashboard/routes/plugin.py +++ b/astrbot/dashboard/routes/plugin.py @@ -1,14 +1,17 @@ import asyncio +import hashlib import json import os import ssl import traceback +from dataclasses import dataclass from datetime import datetime import aiohttp import certifi from quart import request +from astrbot.api import sp from astrbot.core import DEMO_MODE, file_token_service, logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.star.filter.command import CommandFilter @@ -25,6 +28,13 @@ PLUGIN_UPDATE_CONCURRENCY = ( ) +@dataclass +class RegistrySource: + urls: list[str] + cache_file: str + md5_url: str | None # None means "no remote MD5, always treat cache as stale" + + class PluginRoute(Route): def __init__( self, @@ -45,6 +55,8 @@ class PluginRoute(Route): "/plugin/on": ("POST", self.on_plugin), "/plugin/reload": ("POST", self.reload_plugins), "/plugin/readme": ("GET", self.get_plugin_readme), + "/plugin/source/get": ("GET", self.get_custom_source), + "/plugin/source/save": ("POST", self.save_custom_source), } self.core_lifecycle = core_lifecycle self.plugin_manager = plugin_manager @@ -84,22 +96,15 @@ class PluginRoute(Route): custom = request.args.get("custom_registry") force_refresh = request.args.get("force_refresh", "false").lower() == "true" - cache_file = "data/plugins.json" - - if custom: - urls = [custom] - else: - urls = [ - "https://api.soulter.top/astrbot/plugins", - "https://github.com/AstrBotDevs/AstrBot_Plugins_Collection/raw/refs/heads/main/plugin_cache_original.json", - ] + # 构建注册表源信息 + source = self._build_registry_source(custom) # 如果不是强制刷新,先检查缓存是否有效 cached_data = None if not force_refresh: # 先检查MD5是否匹配,如果匹配则使用缓存 - if await self._is_cache_valid(cache_file): - cached_data = self._load_plugin_cache(cache_file) + if await self._is_cache_valid(source): + cached_data = self._load_plugin_cache(source.cache_file) if cached_data: logger.debug("缓存MD5匹配,使用缓存的插件市场数据") return Response().ok(cached_data).__dict__ @@ -109,7 +114,7 @@ class PluginRoute(Route): ssl_context = ssl.create_default_context(cafile=certifi.where()) connector = aiohttp.TCPConnector(ssl=ssl_context) - for url in urls: + for url in source.urls: try: async with ( aiohttp.ClientSession( @@ -128,11 +133,13 @@ class PluginRoute(Route): logger.warning(f"远程插件市场数据为空: {url}") continue # 继续尝试其他URL或使用缓存 - logger.info("成功获取远程插件市场数据") + logger.info( + f"成功获取远程插件市场数据,包含 {len(remote_data)} 个插件" + ) # 获取最新的MD5并保存到缓存 - current_md5 = await self._get_remote_md5() + current_md5 = await self._fetch_remote_md5(source.md5_url) self._save_plugin_cache( - cache_file, + source.cache_file, remote_data, current_md5, ) @@ -143,7 +150,7 @@ class PluginRoute(Route): # 如果远程获取失败,尝试使用缓存数据 if not cached_data: - cached_data = self._load_plugin_cache(cache_file) + cached_data = self._load_plugin_cache(source.cache_file) if cached_data: logger.warning("远程插件市场数据获取失败,使用缓存数据") @@ -151,24 +158,75 @@ class PluginRoute(Route): return Response().error("获取插件列表失败,且没有可用的缓存数据").__dict__ - async def _is_cache_valid(self, cache_file: str) -> bool: - """检查缓存是否有效(基于MD5)""" - try: - if not os.path.exists(cache_file): - return False + def _build_registry_source(self, custom_url: str | None) -> RegistrySource: + """构建注册表源信息""" + if custom_url: + # 对自定义URL生成一个安全的文件名 + url_hash = hashlib.md5(custom_url.encode()).hexdigest()[:8] + cache_file = f"data/plugins_custom_{url_hash}.json" - # 加载缓存文件 + # 更安全的后缀处理方式 + if custom_url.endswith(".json"): + md5_url = custom_url[:-5] + "-md5.json" + else: + md5_url = custom_url + "-md5.json" + + urls = [custom_url] + else: + cache_file = "data/plugins.json" + md5_url = "https://api.soulter.top/astrbot/plugins-md5" + urls = [ + "https://api.soulter.top/astrbot/plugins", + "https://github.com/AstrBotDevs/AstrBot_Plugins_Collection/raw/refs/heads/main/plugin_cache_original.json", + ] + return RegistrySource(urls=urls, cache_file=cache_file, md5_url=md5_url) + + def _load_cached_md5(self, cache_file: str) -> str | None: + """从缓存文件中加载MD5""" + if not os.path.exists(cache_file): + return None + + try: with open(cache_file, encoding="utf-8") as f: cache_data = json.load(f) + return cache_data.get("md5") + except Exception as e: + logger.warning(f"加载缓存MD5失败: {e}") + return None - cached_md5 = cache_data.get("md5") + async def _fetch_remote_md5(self, md5_url: str | None) -> str | None: + """获取远程MD5""" + if not md5_url: + return None + + try: + ssl_context = ssl.create_default_context(cafile=certifi.where()) + connector = aiohttp.TCPConnector(ssl=ssl_context) + + async with ( + aiohttp.ClientSession( + trust_env=True, + connector=connector, + ) as session, + session.get(md5_url) as response, + ): + if response.status == 200: + data = await response.json() + return data.get("md5", "") + except Exception as e: + logger.debug(f"获取远程MD5失败: {e}") + return None + + async def _is_cache_valid(self, source: RegistrySource) -> bool: + """检查缓存是否有效(基于MD5)""" + try: + cached_md5 = self._load_cached_md5(source.cache_file) if not cached_md5: logger.debug("缓存文件中没有MD5信息") return False - # 获取远程MD5 - remote_md5 = await self._get_remote_md5() - if not remote_md5: + remote_md5 = await self._fetch_remote_md5(source.md5_url) + if remote_md5 is None: logger.warning("无法获取远程MD5,将使用缓存") return True # 如果无法获取远程MD5,认为缓存有效 @@ -182,30 +240,6 @@ class PluginRoute(Route): logger.warning(f"检查缓存有效性失败: {e}") return False - async def _get_remote_md5(self) -> str: - """获取远程插件数据的MD5""" - try: - ssl_context = ssl.create_default_context(cafile=certifi.where()) - connector = aiohttp.TCPConnector(ssl=ssl_context) - - async with ( - aiohttp.ClientSession( - trust_env=True, - connector=connector, - ) as session, - session.get( - "https://api.soulter.top/astrbot/plugins-md5", - ) as response, - ): - if response.status == 200: - data = await response.json() - return data.get("md5", "") - logger.error(f"获取MD5失败,状态码:{response.status}") - return "" - except Exception as e: - logger.error(f"获取远程MD5失败: {e}") - return "" - def _load_plugin_cache(self, cache_file: str): """加载本地缓存的插件市场数据""" try: @@ -545,9 +579,13 @@ class PluginRoute(Route): logger.warning(f"插件 {plugin_name} 不存在") return Response().error(f"插件 {plugin_name} 不存在").__dict__ + if not plugin_obj.root_dir_name: + logger.warning(f"插件 {plugin_name} 目录不存在") + return Response().error(f"插件 {plugin_name} 目录不存在").__dict__ + plugin_dir = os.path.join( self.plugin_manager.plugin_store_path, - plugin_obj.root_dir_name, + plugin_obj.root_dir_name or "", ) if not os.path.isdir(plugin_dir): @@ -572,3 +610,22 @@ class PluginRoute(Route): except Exception as e: logger.error(f"/api/plugin/readme: {traceback.format_exc()}") return Response().error(f"读取README文件失败: {e!s}").__dict__ + + async def get_custom_source(self): + """获取自定义插件源""" + sources = await sp.global_get("custom_plugin_sources", []) + return Response().ok(sources).__dict__ + + async def save_custom_source(self): + """保存自定义插件源""" + try: + data = await request.get_json() + sources = data.get("sources", []) + if not isinstance(sources, list): + return Response().error("sources fields must be a list").__dict__ + + await sp.global_put("custom_plugin_sources", sources) + return Response().ok(None, "保存成功").__dict__ + except Exception as e: + logger.error(f"/api/plugin/source/save: {traceback.format_exc()}") + return Response().error(str(e)).__dict__ diff --git a/astrbot/dashboard/routes/route.py b/astrbot/dashboard/routes/route.py index 1105b69a7..01ab292d4 100644 --- a/astrbot/dashboard/routes/route.py +++ b/astrbot/dashboard/routes/route.py @@ -12,6 +12,8 @@ class RouteContext: class Route: + routes: list | dict + def __init__(self, context: RouteContext): self.app = context.app self.config = context.config diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 643e96542..6d6530c90 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -2,9 +2,12 @@ import asyncio import logging import os import socket +from typing import cast import jwt import psutil +from flask.json.provider import DefaultJSONProvider +from psutil._common import addr as psutil_addr from quart import Quart, g, jsonify, request from quart.logging import default_handler @@ -16,11 +19,12 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.io import get_local_ip_addresses from .routes import * +from .routes.platform import PlatformRoute from .routes.route import Response, RouteContext from .routes.session_management import SessionManagementRoute from .routes.t2i import T2iRoute -APP: Quart = None +APP: Quart class AstrBotDashboard: @@ -47,7 +51,7 @@ class AstrBotDashboard: self.app.config["MAX_CONTENT_LENGTH"] = ( 128 * 1024 * 1024 ) # 将 Flask 允许的最大上传文件体大小设置为 128 MB - self.app.json.sort_keys = False + cast(DefaultJSONProvider, self.app.json).sort_keys = False self.app.before_request(self.auth_middleware) # token 用于验证请求 logging.getLogger(self.app.name).removeHandler(default_handler) @@ -80,6 +84,7 @@ class AstrBotDashboard: self.persona_route = PersonaRoute(self.context, db, core_lifecycle) self.t2i_route = T2iRoute(self.context, core_lifecycle) self.kb_route = KnowledgeBaseRoute(self.context, core_lifecycle) + self.platform_route = PlatformRoute(self.context, core_lifecycle) self.app.add_url_rule( "/api/plug/", @@ -103,7 +108,7 @@ class AstrBotDashboard: async def auth_middleware(self): if not request.path.startswith("/api"): return None - allowed_endpoints = ["/api/auth/login", "/api/file"] + allowed_endpoints = ["/api/auth/login", "/api/file", "/api/platform/webhook"] if any(request.path.startswith(prefix) for prefix in allowed_endpoints): return None # 声明 JWT @@ -146,7 +151,7 @@ class AstrBotDashboard: """获取占用端口的进程详细信息""" try: for conn in psutil.net_connections(kind="inet"): - if conn.laddr.port == port: + if cast(psutil_addr, conn.laddr).port == port: try: process = psutil.Process(conn.pid) # 获取详细信息 diff --git a/changelogs/v4.8.0.md b/changelogs/v4.8.0.md new file mode 100644 index 000000000..c0831c52d --- /dev/null +++ b/changelogs/v4.8.0.md @@ -0,0 +1,15 @@ +## What's Changed + +**新增:** +- 对部分需要 Webhook 的适配器(QQ 官方机器人、Slack、企业微信、微信客服、企业微信智能机器人、微信公众号)支持统一的 Webhook 链接模式,避免开多个端口。并支持在 WebUI 机器人卡片中查看和复制 Webhook 链接。详情请看:[统一 Webhook 模式](https://docs.astrbot.app/use/unified-webhook.html) +- 新增 Kubernetes 部署文档。 + +**修复:** +- 修复:Telegram 和 QQ 场景下,使用 Whisper API 报错。 +- 修复:部分情况下 Slack 输出消息段代码的问题。 +- 修复:当启动了流式输出时,QQ 官方机器人适配器无法正常回复消息。 +- 修复:对话数据页的对话详情在暗夜模式下显示异常的问题。 + +**优化:** +- 重构:WebChat 的消息数据结构,支持引用回复、文件发送、时间显示等功能,优化思考内容显示的部分 Bug。 +- 优化:机器人页面支持显示报错信息,方便排查问题。 diff --git a/compose.yml b/compose.yml index 2b3185301..99557a1d8 100644 --- a/compose.yml +++ b/compose.yml @@ -9,10 +9,9 @@ services: restart: always ports: # mappings description: https://github.com/AstrBotDevs/AstrBot/issues/497 - "6185:6185" # 必选,AstrBot WebUI 端口 - - "6195:6195" # 可选, 企业微信 Webhook 端口 - "6199:6199" # 可选, QQ 个人号 WebSocket 端口 - - "6196:6196" # 可选, QQ 官方接口 Webhook 端口 - - "11451:11451" # 可选, 微信个人号 Webhook 端口 + # - "6195:6195" # 可选, 企业微信 Webhook 端口 + # - "6196:6196" # 可选, QQ 官方接口 Webhook 端口 environment: - TZ=Asia/Shanghai volumes: diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue index caff448cc..509971ca8 100644 --- a/dashboard/src/components/chat/Chat.vue +++ b/dashboard/src/components/chat/Chat.vue @@ -71,6 +71,7 @@
@@ -84,19 +85,23 @@ v-model:prompt="prompt" :stagedImagesUrl="stagedImagesUrl" :stagedAudioUrl="stagedAudioUrl" + :stagedFiles="stagedNonImageFiles" :disabled="isStreaming" :enableStreaming="enableStreaming" :isRecording="isRecording" :session-id="currSessionId || null" :current-session="getCurrentSession" + :replyTo="replyTo" @send="handleSendMessage" @toggleStreaming="toggleStreaming" @removeImage="removeImage" @removeAudio="removeAudio" + @removeFile="removeFile" @startRecording="handleStartRecording" @stopRecording="handleStopRecording" @pasteImage="handlePaste" @fileSelect="handleFileSelect" + @clearReply="clearReply" ref="chatInputRef" />
@@ -189,14 +194,17 @@ const { } = useSessions(props.chatboxMode); const { - stagedImagesName, stagedImagesUrl, stagedAudioUrl, + stagedFiles, + stagedNonImageFiles, getMediaFile, processAndUploadImage, + processAndUploadFile, handlePaste, removeImage, removeAudio, + removeFile, clearStaged, cleanupMediaCache } = useMediaHandling(); @@ -220,6 +228,13 @@ const chatInputRef = ref | null>(null); // 输入状态 const prompt = ref(''); +// 引用消息状态 +interface ReplyInfo { + messageId: number; // PlatformSessionHistoryMessage 的 id + messageContent: string; // 用于显示的消息内容 +} +const replyTo = ref(null); + const isDark = computed(() => useCustomizerStore().uiTheme === 'PurpleThemeDark'); // 检测是否为手机端 @@ -250,6 +265,41 @@ function openImagePreview(imageUrl: string) { imagePreviewDialog.value = true; } +function handleReplyMessage(msg: any, index: number) { + // 从消息中获取 id (PlatformSessionHistoryMessage 的 id) + const messageId = msg.id; + if (!messageId) { + console.warn('Message does not have an id'); + return; + } + + // 获取消息内容用于显示 + let messageContent = ''; + if (typeof msg.content.message === 'string') { + messageContent = msg.content.message; + } else if (Array.isArray(msg.content.message)) { + // 从消息段数组中提取纯文本 + const textParts = msg.content.message + .filter((part: any) => part.type === 'plain' && part.text) + .map((part: any) => part.text); + messageContent = textParts.join(''); + } + + // 截断过长的内容 + if (messageContent.length > 100) { + messageContent = messageContent.substring(0, 100) + '...'; + } + + replyTo.value = { + messageId, + messageContent: messageContent || '[媒体内容]' + }; +} + +function clearReply() { + replyTo.value = null; +} + async function handleSelectConversation(sessionIds: string[]) { if (!sessionIds[0]) return; @@ -265,6 +315,9 @@ async function handleSelectConversation(sessionIds: string[]) { closeMobileSidebar(); } + // 清除引用状态 + clearReply(); + currSessionId.value = sessionIds[0]; selectedSessions.value = [sessionIds[0]]; @@ -278,6 +331,7 @@ async function handleSelectConversation(sessionIds: string[]) { function handleNewChat() { newChat(closeMobileSidebar); messages.value = []; + clearReply(); } async function handleDeleteConversation(sessionId: string) { @@ -295,13 +349,19 @@ async function handleStopRecording() { } async function handleFileSelect(files: FileList) { + const imageTypes = ['image/jpeg', 'image/png', 'image/gif', 'image/webp']; for (const file of files) { - await processAndUploadImage(file); + if (imageTypes.includes(file.type)) { + await processAndUploadImage(file); + } else { + await processAndUploadFile(file); + } } } async function handleSendMessage() { - if (!prompt.value.trim() && stagedImagesName.value.length === 0 && !stagedAudioUrl.value) { + // 只有引用不能发送,必须有输入内容 + if (!prompt.value.trim() && stagedFiles.value.length === 0 && !stagedAudioUrl.value) { return; } @@ -310,12 +370,19 @@ async function handleSendMessage() { } const promptToSend = prompt.value.trim(); - const imageNamesToSend = [...stagedImagesName.value]; const audioNameToSend = stagedAudioUrl.value; + const filesToSend = stagedFiles.value.map(f => ({ + attachment_id: f.attachment_id, + url: f.url, + original_name: f.original_name, + type: f.type + })); + const replyToSend = replyTo.value ? { ...replyTo.value } : null; - // 清空输入和附件 + // 清空输入和附件和引用 prompt.value = ''; clearStaged(); + clearReply(); // 获取选择的提供商和模型 const selection = chatInputRef.value?.getCurrentSelection(); @@ -324,10 +391,11 @@ async function handleSendMessage() { await sendMsg( promptToSend, - imageNamesToSend, + filesToSend, audioNameToSend, selectedProviderId, - selectedModelName + selectedModelName, + replyToSend ); } diff --git a/dashboard/src/components/chat/ChatInput.vue b/dashboard/src/components/chat/ChatInput.vue index 79ce27654..53e1e30c0 100644 --- a/dashboard/src/components/chat/ChatInput.vue +++ b/dashboard/src/components/chat/ChatInput.vue @@ -2,6 +2,14 @@
+ +
+
+ mdi-reply + "{{ props.replyTo.messageContent }}" +
+ +