Compare commits

..

1 Commits

Author SHA1 Message Date
Soulter 7777895409 chore: bump version to 4.14.8 2026-02-09 00:52:07 +08:00
182 changed files with 1685 additions and 10191 deletions
+2 -2
View File
@@ -16,7 +16,7 @@ jobs:
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: '24.13.0'
node-version: 'latest'
- name: npm install, build
run: |
@@ -52,4 +52,4 @@ jobs:
repo: astrbot-release-harbour
body: "Automated release from commit ${{ github.sha }}"
token: ${{ secrets.ASTRBOT_HARBOUR_TOKEN }}
artifacts: "dashboard/dist.zip"
artifacts: "dashboard/dist.zip"
+2 -2
View File
@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
env:
DOCKER_HUB_USERNAME: ${{ secrets.DOCKER_HUB_USERNAME }}
GHCR_OWNER: astrbotdevs
GHCR_OWNER: soulter
HAS_GHCR_TOKEN: ${{ secrets.GHCR_GITHUB_TOKEN != '' }}
steps:
@@ -113,7 +113,7 @@ jobs:
runs-on: ubuntu-latest
env:
DOCKER_HUB_USERNAME: ${{ secrets.DOCKER_HUB_USERNAME }}
GHCR_OWNER: astrbotdevs
GHCR_OWNER: soulter
HAS_GHCR_TOKEN: ${{ secrets.GHCR_GITHUB_TOKEN != '' }}
steps:
+5 -5
View File
@@ -57,7 +57,7 @@ jobs:
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: '24.13.0'
node-version: 20
cache: "pnpm"
cache-dependency-path: dashboard/pnpm-lock.yaml
@@ -160,7 +160,7 @@ jobs:
echo "tag=$tag" >> "$GITHUB_OUTPUT"
- name: Setup uv
uses: astral-sh/setup-uv@v7
uses: astral-sh/setup-uv@v6
- name: Setup Python
uses: actions/setup-python@v6
@@ -175,7 +175,7 @@ jobs:
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: '24.13.0'
node-version: 20
cache: "pnpm"
cache-dependency-path: |
dashboard/pnpm-lock.yaml
@@ -291,13 +291,13 @@ jobs:
echo "tag=$tag" >> "$GITHUB_OUTPUT"
- name: Download dashboard artifact
uses: actions/download-artifact@v7
uses: actions/download-artifact@v6
with:
name: Dashboard-${{ steps.tag.outputs.tag }}
path: release-assets
- name: Download desktop artifacts
uses: actions/download-artifact@v7
uses: actions/download-artifact@v6
with:
pattern: AstrBot-${{ steps.tag.outputs.tag }}-*
path: release-assets
+7 -7
View File
@@ -15,17 +15,17 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
curl \
gnupg \
git \
&& curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - \
&& apt-get install -y --no-install-recommends nodejs \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/* /tmp/* /var/tmp/*
RUN apt-get update && apt-get install -y curl gnupg \
&& curl -fsSL https://deb.nodesource.com/setup_lts.x | bash - \
&& apt-get install -y nodejs
RUN python -m pip install uv \
&& echo "3.12" > .python-version \
&& uv lock \
&& uv export --format requirements.txt --output-file requirements.txt --frozen \
&& uv pip install -r requirements.txt --no-cache-dir --system \
&& uv pip install socksio uv pilk --no-cache-dir --system
&& echo "3.12" > .python-version
RUN uv pip install -r requirements.txt --no-cache-dir --system
RUN uv pip install socksio uv pilk --no-cache-dir --system
EXPOSE 6185
-14
View File
@@ -1,14 +0,0 @@
## Welcome to AstrBot
🌟 Thank you for using AstrBot!
AstrBot is an Agentic AI assistant for personal and group chats, with support for multiple IM platforms and a wide range of built-in features. We hope it brings you an efficient and enjoyable experience. ❤️
Important notice:
AstrBot is a **free and open-source software project** protected by the AGPLv3 license. You can find the full source code and related resources on our [**official website**](https://astrbot.app) and [**GitHub**](https://github.com/astrbotdevs/astrbot).
As of now, AstrBot has **no commercial services of any kind**, and the official team **will never charge users any fees** under any name.
If anyone asks you to pay while using AstrBot, **you are likely being scammed**. Please request a refund immediately and report it to us by email.
📮 Official email: [community@astrbot.app](mailto:community@astrbot.app)
-14
View File
@@ -1,14 +0,0 @@
## 欢迎使用 AstrBot
🌟 感谢您使用 AstrBot
AstrBot 是一款可接入多种 IM 平台的 Agentic AI 个人 / 群聊助手,内置多项强大功能,希望能为您带来高效、愉快的使用体验。❤️
我们想特别说明:
AstrBot 是受 AGPLv3 开源协议保护的**免费开源软件项目**,您可以在[**官方网站**](https://astrbot.app)、[**GitHub**](https://github.com/astrbotdevs/astrbot) 上找到 AstrBot 的全部源代码及相关资源。
截至目前,AstrBot 项目**未开展任何形式的商业化服务**,官方**不会以任何名义向用户收取费用**。
如果您在使用 AstrBot 的过程中被要求付费,**表明您已经遭遇诈骗行为**。请立即向相关方申请退款,并及时通过邮件向我们反馈。
📮 官方邮箱:[community@astrbot.app](mailto:community@astrbot.app)
+12 -26
View File
@@ -2,6 +2,7 @@
<div align="center">
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a>
@@ -40,14 +41,14 @@ AstrBot 是一个开源的一站式 Agentic 个人和群聊助手,可在 QQ、
## 主要功能
1. 💯 免费 & 开源。
2. ✨ AI 大模型对话,多模态,Agent,MCP,Skills,知识库,人格设定,自动压缩对话。
3. 🤖 支持接入 Dify、阿里云百炼、Coze 等智能体平台。
4. 🌐 多平台,支持 QQ、企业微信、飞书、钉钉、微信公众号、Telegram、Slack 以及[更多](#支持的消息平台)。
5. 📦 插件扩展,已有近 800 个插件可一键安装。
6. 🛡️ [Agent Sandbox](https://docs.astrbot.app/use/astrbot-agent-sandbox.html) 隔离化环境,安全地执行任何代码、调用 Shell、会话级资源复用。
7. 💻 WebUI 支持。
8. 🌈 Web ChatUI 支持,ChatUI 内置代理沙盒、网页搜索等。
9. 🌐 国际化(i18n)支持。
1. ✨ AI 大模型对话,多模态,Agent,MCP,Skills,知识库,人格设定,自动压缩对话。
2. 🤖 支持接入 Dify、阿里云百炼、Coze 等智能体平台。
2. 🌐 多平台,支持 QQ、企业微信、飞书、钉钉、微信公众号、Telegram、Slack 以及[更多](#支持的消息平台)。
3. 📦 插件扩展,已有近 800 个插件可一键安装。
5. 🛡️ [Agent Sandbox](https://docs.astrbot.app/use/astrbot-agent-sandbox.html) 隔离化环境,安全地执行任何代码、调用 Shell、会话级资源复用。
6. 💻 WebUI 支持。
7. 🌈 Web ChatUI 支持,ChatUI 内置代理沙盒、网页搜索等。
8. 🌐 国际化(i18n)支持。
<br>
@@ -77,14 +78,9 @@ AstrBot 是一个开源的一站式 Agentic 个人和群聊助手,可在 QQ、
#### uv 部署
```bash
uv tool install astrbot
astrbot
uvx astrbot
```
#### 启动器一键部署(AstrBot Launcher
进入 [AstrBot Launcher](https://github.com/Raven95676/astrbot-launcher) 仓库,在 Releases 页最新版本下找到对应的系统安装包安装即可。
#### 宝塔面板部署
AstrBot 与宝塔面板合作,已上架至宝塔面板。
@@ -136,16 +132,6 @@ uv run main.py
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
#### 系统包管理器安装
##### Arch Linux
```bash
yay -S astrbot-git
# 或者使用 paru
paru -S astrbot-git
```
#### 桌面端 Electron 打包
桌面端(Electron 打包,`pnpm` 工作流)构建流程请参阅:[`desktop/README.md`](desktop/README.md)。
@@ -278,6 +264,8 @@ pre-commit install
</div>
</details>
<div align="center">
_陪伴与能力从来不应该是对立面。我们希望创造的是一个既能理解情绪、给予陪伴,也能可靠完成工作的机器人。_
@@ -285,5 +273,3 @@ _陪伴与能力从来不应该是对立面。我们希望创造的是一个既
_私は、高性能ですから!_
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
</div>
+5 -42
View File
@@ -3,6 +3,7 @@
<div align="center">
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a>
@@ -51,23 +52,6 @@ AstrBot is an open-source all-in-one Agent chatbot platform that integrates with
8. 🌈 Web ChatUI Support with built-in agent sandbox and web search.
9. 🌐 Internationalization (i18n) Support.
<br>
<table align="center">
<tr align="center">
<th>💙 Role-playing & Emotional Companionship</th>
<th>✨ Proactive Agent</th>
<th>🚀 General Agentic Capabilities</th>
<th>🧩 900+ Community Plugins</th>
</tr>
<tr>
<td align="center"><p align="center"><img width="984" height="1746" alt="99b587c5d35eea09d84f33e6cf6cfd4f" src="https://github.com/user-attachments/assets/89196061-3290-458d-b51f-afa178049f84" /></p></td>
<td align="center"><p align="center"><img width="976" height="1612" alt="c449acd838c41d0915cc08a3824025b1" src="https://github.com/user-attachments/assets/f75368b4-e022-41dc-a9e0-131c3e73e32e" /></p></td>
<td align="center"><p align="center"><img width="974" height="1732" alt="image" src="https://github.com/user-attachments/assets/e22a3968-87d7-4708-a7cd-e7f198c7c32e" /></p></td>
<td align="center"><p align="center"><img width="976" height="1734" alt="image" src="https://github.com/user-attachments/assets/0952b395-6b4a-432a-8a50-c294b7f89750" /></p></td>
</tr>
</table>
## Quick Start
#### Docker Deployment (Recommended 🥳)
@@ -79,18 +63,7 @@ Please refer to the official documentation: [Deploy AstrBot with Docker](https:/
#### uv Deployment
```bash
uv tool install astrbot
astrbot
```
#### System Package Manager Installation
##### Arch Linux
```bash
yay -S astrbot-git
# or use paru
paru -S astrbot-git
uvx astrbot
```
#### BT-Panel Deployment
@@ -144,16 +117,6 @@ uv run main.py
Or refer to the official documentation: [Deploy AstrBot from Source](https://astrbot.app/deploy/astrbot/cli.html).
#### System Package Manager Installation
##### Arch Linux
```bash
yay -S astrbot-git
# or use paru
paru -S astrbot-git
```
#### Desktop Electron Build
For desktop build steps (Electron packaging, `pnpm` workflow), see [`desktop/README.md`](desktop/README.md).
@@ -196,7 +159,7 @@ For desktop build steps (Electron packaging, `pnpm` workflow), see [`desktop/REA
- [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
- [302.AI](https://share.302.ai/rr1M3l)
- [TokenPony](https://www.tokenpony.cn/3YPyf)
- [SiliconFlow](https://docs.siliconflow.cn/cn/usercases/use-siliconcloud-in-astrbot)
- [SiliconFlow](https://docs.siliconflow.cn/cn/usecases/use-siliconcloud-in-astrbot)
- [PPIO Cloud](https://ppio.com/user/register?invited_by=AIOONE)
- ModelScope
- OneAPI
@@ -286,9 +249,9 @@ Additionally, the birth of this project would not have been possible without the
</div>
<div align="center">
</details>
_Companionship and capability should never be at odds. What we aim to create is a robot that can understand emotions, provide genuine companionship, and reliably accomplish tasks._
<div align="center">
_私は、高性能ですから!_
+23 -67
View File
@@ -1,12 +1,8 @@
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9)
<div align="center">
</p>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
<div align="center">
<br>
@@ -18,17 +14,22 @@
<br>
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
<img src="https://gitcode.com/Soulter/AstrBot/star/badge.svg" href="https://gitcode.com/Soulter/AstrBot">
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%20plugins&style=for-the-badge&label=Marketplace&cacheSeconds=3600">
</div>
<br>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
<a href="https://astrbot.app/">Documentation</a>
<a href="https://blog.astrbot.app/">Blog</a>
<a href="https://astrbot.featurebase.app/roadmap">Feuille de route</a>
@@ -42,31 +43,12 @@ AstrBot est une plateforme de chatbot Agent tout-en-un open source qui s'intègr
## Fonctionnalités principales
1. 💯 Gratuit & Open Source.
2.Dialogue avec de grands modèles d'IA, multimodal, Agent, MCP, Skills, Base de connaissances, Paramétrage de personnalité, compression automatique des dialogues.
3. 🤖 Prise en charge de l'accès aux plateformes d'Agents telles que Dify, Alibaba Cloud Bailian, Coze, etc.
4. 🌐 Multiplateforme : supporte QQ, WeChat Enterprise, Feishu, DingTalk, Comptes officiels WeChat, Telegram, Slack et [plus encore](#plateformes-de-messagerie-prises-en-charge).
5. 📦 Extension par plugins, avec près de 800 plugins déjà disponibles pour une installation en un clic.
6. 🛡️ Environnement isolé [Agent Sandbox](https://docs.astrbot.app/use/astrbot-agent-sandbox.html) : exécution sécurisée de code, appels Shell et réutilisation des ressources au niveau de la session.
7. 💻 Support WebUI.
8. 🌈 Support Web ChatUI, avec sandbox d'agent intégrée, recherche web, etc.
9. 🌐 Support de l'internationalisation (i18n).
<br>
<table align="center">
<tr align="center">
<th>💙 Jeux de rôle & Accompagnement émotionnel</th>
<th>✨ Agent proactif</th>
<th>🚀 Capacités agentiques générales</th>
<th>🧩 900+ Plugins de communauté</th>
</tr>
<tr>
<td align="center"><p align="center"><img width="984" height="1746" alt="99b587c5d35eea09d84f33e6cf6cfd4f" src="https://github.com/user-attachments/assets/89196061-3290-458d-b51f-afa178049f84" /></p></td>
<td align="center"><p align="center"><img width="976" height="1612" alt="c449acd838c41d0915cc08a3824025b1" src="https://github.com/user-attachments/assets/f75368b4-e022-41dc-a9e0-131c3e73e32e" /></p></td>
<td align="center"><p align="center"><img width="974" height="1732" alt="image" src="https://github.com/user-attachments/assets/e22a3968-87d7-4708-a7cd-e7f198c7c32e" /></p></td>
<td align="center"><p align="center"><img width="976" height="1734" alt="image" src="https://github.com/user-attachments/assets/0952b395-6b4a-432a-8a50-c294b7f89750" /></p></td>
</tr>
</table>
2.Conversations avec LLM IA, Multimodal, Agent, MCP, Base de connaissances, Paramètres de personnalité.
3. 🤖 Prise en charge de l'intégration avec Dify, Alibaba Cloud Bailian, Coze et autres plateformes d'agents.
4. 🌐 Multi-plateforme : QQ, WeChat Work, Feishu, DingTalk, Comptes officiels WeChat, Telegram, Slack, et [plus encore](#plateformes-de-messagerie-prises-en-charge).
5. 📦 Extensions de plugins avec près de 800 plugins disponibles pour une installation en un clic.
6. 💻 Support WebUI.
7. 🌐 Support de l'internationalisation (i18n).
## Démarrage rapide
@@ -79,18 +61,7 @@ Veuillez consulter la documentation officielle : [Déployer AstrBot avec Docker]
#### Déploiement uv
```bash
uv tool install astrbot
astrbot
```
#### Installation via le gestionnaire de paquets du système
##### Arch Linux
```bash
yay -S astrbot-git
# ou utiliser paru
paru -S astrbot-git
uvx astrbot
```
#### Déploiement BT-Panel
@@ -144,16 +115,6 @@ uv run main.py
Ou consultez la documentation officielle : [Déployer AstrBot depuis les sources](https://astrbot.app/deploy/astrbot/cli.html).
#### Установка через системный пакетный менеджер
##### Arch Linux
```bash
yay -S astrbot-git
# или используйте paru
paru -S astrbot-git
```
## Plateformes de messagerie prises en charge
**Maintenues officiellement**
@@ -192,7 +153,7 @@ paru -S astrbot-git
- [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
- [302.AI](https://share.302.ai/rr1M3l)
- [TokenPony](https://www.tokenpony.cn/3YPyf)
- [SiliconFlow](https://docs.siliconflow.cn/cn/usercases/use-siliconcloud-in-astrbot)
- [SiliconFlow](https://docs.siliconflow.cn/cn/usecases/use-siliconcloud-in-astrbot)
- [PPIO Cloud](https://ppio.com/user/register?invited_by=AIOONE)
- ModelScope
- OneAPI
@@ -280,12 +241,7 @@ De plus, la naissance de ce projet n'aurait pas été possible sans l'aide des p
</div>
<div align="center">
_La compagnie et la capacité ne devraient jamais être des opposés. Nous souhaitons créer un robot capable à la fois de comprendre les émotions, d'offrir de la présence, et d'accomplir des tâches de manière fiable._
</details>
_私は、高性能ですから!_
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
</div>
+22 -67
View File
@@ -1,12 +1,8 @@
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9)
<div align="center">
</p>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
<div align="center">
<br>
@@ -18,17 +14,22 @@
<br>
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
<img src="https://gitcode.com/Soulter/AstrBot/star/badge.svg" href="https://gitcode.com/Soulter/AstrBot">
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E5%80%8B&style=for-the-badge&label=%E3%83%97%E3%83%A9%E3%82%B0%E3%82%A4%E3%83%B3&cacheSeconds=3600">
</div>
<br>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
<a href="https://astrbot.app/">ドキュメント</a>
<a href="https://blog.astrbot.app/">Blog</a>
<a href="https://astrbot.featurebase.app/roadmap">ロードマップ</a>
@@ -42,31 +43,12 @@ AstrBot は、主要なインスタントメッセージングアプリと統合
## 主な機能
1. 💯 無料 & オープンソース。
2. ✨ AI大規模言語モデル対話、マルチモーダル、Agent、MCP、Skills、ナレッジベース、ペルソナ設定、対話の自動圧縮
3. 🤖 Dify、Alibaba Cloud Bailian(百煉)、Coze などのAgentプラットフォームへの接続をサポート。
4. 🌐 マルチプラットフォーム:QQ、企業微信(WeCom)、飛書(Lark)、釘釘(DingTalk、WeChat公式アカウント、Telegram、Slack、[その他](#サポートされているメッセージプラットフォーム)に対応
5. 📦 プラグイン拡張:800近い既存プラグインをワンクリックでインストール可能。
6. 🛡️ 隔離環境[Agent Sandbox](https://docs.astrbot.app/use/astrbot-agent-sandbox.html):コードの安全な実行、Shell呼び出し、セッションレベルのリソース再利用
7. 💻 WebUI 対応
8. 🌈 Web ChatUI 対応:ChatUI内にAgent Sandboxやウェブ検索などを内蔵。
9. 🌐 多言語対応(i18n)。
<br>
<table align="center">
<tr align="center">
<th>💙 ロールプレイ & 感情的な対話</th>
<th>✨ プロアクティブ・エージェント (Proactive Agent)</th>
<th>🚀 汎用 エージェント的能力</th>
<th>🧩 900+ コミュニティプラグイン</th>
</tr>
<tr>
<td align="center"><p align="center"><img width="984" height="1746" alt="99b587c5d35eea09d84f33e6cf6cfd4f" src="https://github.com/user-attachments/assets/89196061-3290-458d-b51f-afa178049f84" /></p></td>
<td align="center"><p align="center"><img width="976" height="1612" alt="c449acd838c41d0915cc08a3824025b1" src="https://github.com/user-attachments/assets/f75368b4-e022-41dc-a9e0-131c3e73e32e" /></p></td>
<td align="center"><p align="center"><img width="974" height="1732" alt="image" src="https://github.com/user-attachments/assets/e22a3968-87d7-4708-a7cd-e7f198c7c32e" /></p></td>
<td align="center"><p align="center"><img width="976" height="1734" alt="image" src="https://github.com/user-attachments/assets/0952b395-6b4a-432a-8a50-c294b7f89750" /></p></td>
</tr>
</table>
2. ✨ AI 大規模言語モデル対話、マルチモーダル、Agent、MCP、ナレッジベース、ペルソナ設定。
3. 🤖 Dify、Alibaba Cloud 百炼、Coze などの Agent プラットフォームとの統合をサポート。
4. 🌐 マルチプラットフォーム:QQ、WeChat Work、Feishu、DingTalk、WeChat 公式アカウント、Telegram、Slack、[その他](#サポートされているメッセージプラットフォーム)。
5. 📦 約800個のプラグインをワンクリックでインストール可能なプラグイン拡張機能
6. 💻 WebUI サポート
7. 🌐 国際化(i18n)サポート
## クイックスタート
@@ -79,18 +61,7 @@ Docker / Docker Compose を使用した AstrBot のデプロイを推奨しま
#### uv デプロイ
```bash
uv tool install astrbot
astrbot
```
#### システムパッケージマネージャーでのインストール
##### Arch Linux
```bash
yay -S astrbot-git
# または paru を使用
paru -S astrbot-git
uvx astrbot
```
#### 宝塔パネルデプロイ
@@ -144,16 +115,6 @@ uv run main.py
または、公式ドキュメント [ソースコードから AstrBot をデプロイ](https://astrbot.app/deploy/astrbot/cli.html) をご参照ください。
#### Установка через системный пакетный менеджер
##### Arch Linux
```bash
yay -S astrbot-git
# или используйте paru
paru -S astrbot-git
```
## サポートされているメッセージプラットフォーム
**公式メンテナンス**
@@ -281,12 +242,6 @@ AstrBot への貢献をしていただいたすべてのコントリビュータ
</div>
<div align="center">
_共感力と能力は決して対立するものではありません。私たちが目指すのは、感情を理解し、心の支えとなるだけでなく、確実に仕事をこなせるロボットの創造です。_
</details>
_私は、高性能ですから!_
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
</div>
+24 -59
View File
@@ -1,12 +1,8 @@
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9)
<div align="center">
</p>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a>
<div align="center">
<br>
@@ -18,17 +14,22 @@
<br>
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E4%B8%AA&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%9C%BA&cacheSeconds=3600">
<img src="https://gitcode.com/Soulter/AstrBot/star/badge.svg" href="https://gitcode.com/Soulter/AstrBot">
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%20%D0%BF%D0%BB%D0%B0%D0%B3%D0%B8%D0%BD%D0%BE%D0%B2&style=for-the-badge&label=%D0%9C%D0%B0%D0%B3%D0%B0%D0%B7%D0%B8%D0%BD&cacheSeconds=3600">
</div>
<br>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_zh-TW.md">繁體中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a>
<a href="https://astrbot.app/">Документация</a>
<a href="https://blog.astrbot.app/">Блог</a>
<a href="https://astrbot.featurebase.app/roadmap">Дорожная карта</a>
@@ -41,32 +42,13 @@ AstrBot — это универсальная платформа Agent-чатб
## Основные возможности
1. 💯 Бесплатно & Открытый исходный код.
2.Диалоги с ИИ-моделями, мультимодальность, Agent, MCP, Skills, База знаний, Настройка личности, автоматическое сжатие диалогов.
3. 🤖 Поддержка интеграции с платформами Agents, такими как Dify, Alibaba Cloud Bailian, Coze и др.
4. 🌐 Мультиплатформенность: поддержка QQ, WeChat для предприятий, Feishu, DingTalk, публичных аккаунтов WeChat, Telegram, Slack и [других](#Поддерживаемые-платформы-обмена-сообщениями).
5. 📦 Расширение плагинами: доступно почти 800 плагинов для установки в один клик.
6. 🛡️ Изолированная среда[Agent Sandbox](https://docs.astrbot.app/use/astrbot-agent-sandbox.html): безопасное выполнение любого кода, вызов Shell, повторное использование ресурсов на уровне сессии.
7. 💻 Поддержка WebUI.
8. 🌈 Поддержка Web ChatUI: встроенная песочница агента, веб-поиск и др.
9. 🌐 Поддержка интернационализации (i18n).
<br>
<table align="center">
<tr align="center">
<th>💙 Ролевые игры & Эмоциональная поддержка</th>
<th>✨ Проактивный Агент(Agent)</th>
<th>🚀 Универсальные Агентные возможности</th>
<th>🧩 Универсальные Агентные (Agentic) возможности</th>
</tr>
<tr>
<td align="center"><p align="center"><img width="984" height="1746" alt="99b587c5d35eea09d84f33e6cf6cfd4f" src="https://github.com/user-attachments/assets/89196061-3290-458d-b51f-afa178049f84" /></p></td>
<td align="center"><p align="center"><img width="976" height="1612" alt="c449acd838c41d0915cc08a3824025b1" src="https://github.com/user-attachments/assets/f75368b4-e022-41dc-a9e0-131c3e73e32e" /></p></td>
<td align="center"><p align="center"><img width="974" height="1732" alt="image" src="https://github.com/user-attachments/assets/e22a3968-87d7-4708-a7cd-e7f198c7c32e" /></p></td>
<td align="center"><p align="center"><img width="976" height="1734" alt="image" src="https://github.com/user-attachments/assets/0952b395-6b4a-432a-8a50-c294b7f89750" /></p></td>
</tr>
</table>
1. 💯 Бесплатно и с открытым исходным кодом.
2.ИИ-диалоги с LLM, мультимодальность, Agent, MCP, база знаний, настройки личности.
3. 🤖 Поддержка интеграции с Dify, Alibaba Cloud Bailian, Coze и другими платформами агентов.
4. 🌐 Мультиплатформенность: QQ, WeChat Work, Feishu, DingTalk, официальные аккаунты WeChat, Telegram, Slack и [другие](#поддерживаемые-платформы-обмена-сообщениями).
5. 📦 Расширения плагинов с почти 800 плагинами, доступными для установки в один клик.
6. 💻 Поддержка WebUI.
7. 🌐 Поддержка интернационализации (i18n).
## Быстрый старт
@@ -79,8 +61,7 @@ AstrBot — это универсальная платформа Agent-чатб
#### Развёртывание uv
```bash
uv tool install astrbot
astrbot
uvx astrbot
```
#### Развёртывание BT-Panel
@@ -134,16 +115,6 @@ uv run main.py
Или см. официальную документацию: [Развёртывание AstrBot из исходного кода](https://astrbot.app/deploy/astrbot/cli.html).
#### Установка через системный пакетный менеджер
##### Arch Linux
```bash
yay -S astrbot-git
# или используйте paru
paru -S astrbot-git
```
## Поддерживаемые платформы обмена сообщениями
**Официально поддерживаемые**
@@ -182,7 +153,7 @@ paru -S astrbot-git
- [CompShare](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74)
- [302.AI](https://share.302.ai/rr1M3l)
- [TokenPony](https://www.tokenpony.cn/3YPyf)
- [SiliconFlow](https://docs.siliconflow.cn/cn/usercases/use-siliconcloud-in-astrbot)
- [SiliconFlow](https://docs.siliconflow.cn/cn/usecases/use-siliconcloud-in-astrbot)
- [PPIO Cloud](https://ppio.com/user/register?invited_by=AIOONE)
- ModelScope
- OneAPI
@@ -264,19 +235,13 @@ pre-commit install
> [!TIP]
> Если этот проект помог вам в жизни или работе, или если вас интересует его будущее развитие, пожалуйста, поставьте проекту звезду. Это движущая сила поддержки этого проекта с открытым исходным кодом <3
<div align="center">
[![Star History Chart](https://api.star-history.com/svg?repos=astrbotdevs/astrbot&type=Date)](https://star-history.com/#astrbotdevs/astrbot&Date)
</div>
<div align="center">
_Сопровождение и способности никогда не должны быть противоположностями. Мы стремимся создать робота, который сможет как понимать эмоции, оказывать душевную поддержку, так и надёжно выполнять работу._
</details>
_私は、高性能ですから!_
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
</div>
+22 -56
View File
@@ -1,12 +1,8 @@
![AstrBot-Logo-Simplified](https://github.com/user-attachments/assets/ffd99b6b-3272-4682-beaa-6fe74250f7d9)
<div align="center">
</p>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">简体中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
<div align="center">
<br>
@@ -18,17 +14,22 @@
<br>
<div>
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg" alt="python">
<img src="https://deepwiki.com/badge.svg" href="https://deepwiki.com/AstrBotDevs/AstrBot">
<a href="https://zread.ai/AstrBotDevs/AstrBot" target="_blank"><img src="https://img.shields.io/badge/Ask_Zread-_.svg?style=flat&color=00b0aa&labelColor=000000&logo=data%3Aimage%2Fsvg%2Bxml%3Bbase64%2CPHN2ZyB3aWR0aD0iMTYiIGhlaWdodD0iMTYiIHZpZXdCb3g9IjAgMCAxNiAxNiIgZmlsbD0ibm9uZSIgeG1sbnM9Imh0dHA6Ly93d3cudzMub3JnLzIwMDAvc3ZnIj4KPHBhdGggZD0iTTQuOTYxNTYgMS42MDAxSDIuMjQxNTZDMS44ODgxIDEuNjAwMSAxLjYwMTU2IDEuODg2NjQgMS42MDE1NiAyLjI0MDFWNC45NjAxQzEuNjAxNTYgNS4zMTM1NiAxLjg4ODEgNS42MDAxIDIuMjQxNTYgNS42MDAxSDQuOTYxNTZDNS4zMTUwMiA1LjYwMDEgNS42MDE1NiA1LjMxMzU2IDUuNjAxNTYgNC45NjAxVjIuMjQwMUM1LjYwMTU2IDEuODg2NjQgNS4zMTUwMiAxLjYwMDEgNC45NjE1NiAxLjYwMDFaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00Ljk2MTU2IDEwLjM5OTlIMi4yNDE1NkMxLjg4ODEgMTAuMzk5OSAxLjYwMTU2IDEwLjY4NjQgMS42MDE1NiAxMS4wMzk5VjEzLjc1OTlDMS42MDE1NiAxNC4xMTM0IDEuODg4MSAxNC4zOTk5IDIuMjQxNTYgMTQuMzk5OUg0Ljk2MTU2QzUuMzE1MDIgMTQuMzk5OSA1LjYwMTU2IDE0LjExMzQgNS42MDE1NiAxMy43NTk5VjExLjAzOTlDNS42MDE1NiAxMC42ODY0IDUuMzE1MDIgMTAuMzk5OSA0Ljk2MTU2IDEwLjM5OTlaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik0xMy43NTg0IDEuNjAwMUgxMS4wMzg0QzEwLjY4NSAxLjYwMDEgMTAuMzk4NCAxLjg4NjY0IDEwLjM5ODQgMi4yNDAxVjQuOTYwMUMxMC4zOTg0IDUuMzEzNTYgMTAuNjg1IDUuNjAwMSAxMS4wMzg0IDUuNjAwMUgxMy43NTg0QzE0LjExMTkgNS42MDAxIDE0LjM5ODQgNS4zMTM1NiAxNC4zOTg0IDQuOTYwMVYyLjI0MDFDMTQuMzk4NCAxLjg4NjY0IDE0LjExMTkgMS42MDAxIDEzLjc1ODQgMS42MDAxWiIgZmlsbD0iI2ZmZiIvPgo8cGF0aCBkPSJNNCAxMkwxMiA0TDQgMTJaIiBmaWxsPSIjZmZmIi8%2BCjxwYXRoIGQ9Ik00IDEyTDEyIDQiIHN0cm9rZT0iI2ZmZiIgc3Ryb2tlLXdpZHRoPSIxLjUiIHN0cm9rZS1saW5lY2FwPSJyb3VuZCIvPgo8L3N2Zz4K&logoColor=ffffff" alt="zread"/></a>
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?color=76bad9"/></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E5%80%8B&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%A0%B4&cacheSeconds=3600">
<img src="https://gitcode.com/Soulter/AstrBot/star/badge.svg" href="https://gitcode.com/Soulter/AstrBot">
<img src="https://img.shields.io/github/v/release/AstrBotDevs/AstrBot?style=for-the-badge&color=76bad9" href="https://github.com/AstrBotDevs/AstrBot/releases/latest">
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
<img src="https://img.shields.io/badge/dynamic/json?url=https%3A%2F%2Fapi.soulter.top%2Fastrbot%2Fplugin-num&query=%24.result&suffix=%E5%80%8B&style=for-the-badge&label=%E6%8F%92%E4%BB%B6%E5%B8%82%E5%A0%B4&cacheSeconds=3600">
</div>
<br>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README.md">简体中文</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_en.md">English</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ja.md">日本語</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_fr.md">Français</a>
<a href="https://github.com/AstrBotDevs/AstrBot/blob/master/README_ru.md">Русский</a>
<a href="https://astrbot.app/">文件</a>
<a href="https://blog.astrbot.app/">Blog</a>
<a href="https://astrbot.featurebase.app/roadmap">路線圖</a>
@@ -42,31 +43,12 @@ AstrBot 是一個開源的一站式 Agent 聊天機器人平台,可接入主
## 主要功能
1. 💯 免費 & 開源。
2. ✨ AI 大模型對話,多模態,Agent,MCP,Skills知識庫,人格設定,自動壓縮對話
3. 🤖 支援接入 Dify、阿里雲百煉、Coze 等智慧體 (Agent) 平台。
4. 🌐 多平台,支援 QQ、企業微信、飛書、釘釘、微信公眾號、Telegram、Slack 以及[更多](#支援的訊息平台)。
5. 📦 插件擴展,已有近 800 個插件可一鍵安裝。
6. 🛡️ [Agent Sandbox](https://docs.astrbot.app/use/astrbot-agent-sandbox.html) 隔離化環境,安全地執行任何代碼、調用 Shell、會話級資源複用
7. 💻 WebUI 支援。
8. 🌈 Web ChatUI 支援,ChatUI 內置代理沙盒 (Agent Sandbox)、網頁搜尋等。
9. 🌐 國際化(i18n)支援。
<br>
<table align="center">
<tr align="center">
<th>💙 角色扮演 & 情感陪伴</th>
<th>✨ 主動式 Agent</th>
<th>🚀 通用 Agentic 能力</th>
<th>🧩 900+ 社區外掛程式</th>
</tr>
<tr>
<td align="center"><p align="center"><img width="984" height="1746" alt="99b587c5d35eea09d84f33e6cf6cfd4f" src="https://github.com/user-attachments/assets/89196061-3290-458d-b51f-afa178049f84" /></p></td>
<td align="center"><p align="center"><img width="976" height="1612" alt="c449acd838c41d0915cc08a3824025b1" src="https://github.com/user-attachments/assets/f75368b4-e022-41dc-a9e0-131c3e73e32e" /></p></td>
<td align="center"><p align="center"><img width="974" height="1732" alt="image" src="https://github.com/user-attachments/assets/e22a3968-87d7-4708-a7cd-e7f198c7c32e" /></p></td>
<td align="center"><p align="center"><img width="976" height="1734" alt="image" src="https://github.com/user-attachments/assets/0952b395-6b4a-432a-8a50-c294b7f89750" /></p></td>
</tr>
</table>
2. ✨ AI 大模型對話,多模態,Agent,MCP,知識庫,人格設定。
3. 🤖 支援接入 Dify、阿里雲百煉、Coze 等智慧體平台。
4. 🌐 多平台QQ、企業微信、飛書、釘釘、微信公眾號、Telegram、Slack 以及[更多](#支援的訊息平台)。
5. 📦 外掛擴充,已有近 800 個外掛可一鍵安裝。
6. 💻 WebUI 支援
7. 🌐 國際化(i18n支援。
## 快速開始
@@ -79,8 +61,7 @@ AstrBot 是一個開源的一站式 Agent 聊天機器人平台,可接入主
#### uv 部署
```bash
uv tool install astrbot
astrbot
uvx astrbot
```
#### 寶塔面板部署
@@ -134,16 +115,6 @@ uv run main.py
或者請參閱官方文件 [透過原始碼部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html)。
#### 系統套件管理員安裝
##### Arch Linux
```bash
yay -S astrbot-git
# 或者使用 paru
paru -S astrbot-git
```
## 支援的訊息平台
**官方維護**
@@ -270,12 +241,7 @@ pre-commit install
</div>
<div align="center">
_陪伴與能力從來不應該是對立面。我們希望創造的是一個既能理解情緒、給予陪伴,也能可靠完成工作的機器人。_
</details>
_私は、高性能ですから!_
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
</div>
+1 -1
View File
@@ -1 +1 @@
__version__ = "4.17.2"
__version__ = "4.14.8"
+1 -2
View File
@@ -15,6 +15,7 @@ class HandoffTool(FunctionTool, Generic[TContext]):
tool_description: str | None = None,
**kwargs,
) -> None:
self.agent = agent
# Avoid passing duplicate `description` to the FunctionTool dataclass.
# Some call sites (e.g. SubAgentOrchestrator) pass `description` via kwargs
@@ -33,8 +34,6 @@ class HandoffTool(FunctionTool, Generic[TContext]):
# Optional provider override for this subagent. When set, the handoff
# execution will use this chat provider id instead of the global/default.
self.provider_id: str | None = None
# Note: Must assign after super().__init__() to prevent parent class from overriding this attribute
self.agent = agent
def default_parameters(self) -> dict:
return {
@@ -10,7 +10,7 @@ from astrbot.core.provider.entities import (
LLMResponse,
ProviderRequest,
)
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file
from ...hooks import BaseAgentRunHooks
@@ -291,8 +291,8 @@ class DifyAgentRunner(BaseAgentRunner[TContext]):
return Comp.Image(file=item["url"], url=item["url"])
case "audio":
# 仅支持 wav
temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"dify_{item['filename']}.wav")
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, f"{item['filename']}.wav")
await download_file(item["url"], path)
return Comp.Image(file=item["url"], url=item["url"])
case "video":
@@ -91,7 +91,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
custom_token_counter: TokenCounter | None = None,
custom_compressor: ContextCompressor | None = None,
tool_schema_mode: str | None = "full",
fallback_providers: list[Provider] | None = None,
**kwargs: T.Any,
) -> None:
self.req = request
@@ -121,17 +120,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self.context_manager = ContextManager(self.context_config)
self.provider = provider
self.fallback_providers: list[Provider] = []
seen_provider_ids: set[str] = {str(provider.provider_config.get("id", ""))}
for fallback_provider in fallback_providers or []:
fallback_id = str(fallback_provider.provider_config.get("id", ""))
if fallback_provider is provider:
continue
if fallback_id and fallback_id in seen_provider_ids:
continue
self.fallback_providers.append(fallback_provider)
if fallback_id:
seen_provider_ids.add(fallback_id)
self.final_llm_resp = None
self._state = AgentState.IDLE
self.tool_executor = tool_executor
@@ -178,19 +166,16 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
self.stats = AgentStats()
self.stats.start_time = time.time()
async def _iter_llm_responses(
self, *, include_model: bool = True
) -> T.AsyncGenerator[LLMResponse, None]:
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
"""Yields chunks *and* a final LLMResponse."""
payload = {
"contexts": self.run_context.messages, # list[Message]
"func_tool": self.req.func_tool,
"model": self.req.model, # NOTE: in fact, this arg is None in most cases
"session_id": self.req.session_id,
"extra_user_content_parts": self.req.extra_user_content_parts, # list[ContentPart]
}
if include_model:
# For primary provider we keep explicit model selection if provided.
payload["model"] = self.req.model
if self.streaming:
stream = self.provider.text_chat_stream(**payload)
async for resp in stream: # type: ignore
@@ -198,83 +183,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
else:
yield await self.provider.text_chat(**payload)
async def _iter_llm_responses_with_fallback(
self,
) -> T.AsyncGenerator[LLMResponse, None]:
"""Wrap _iter_llm_responses with provider fallback handling."""
candidates = [self.provider, *self.fallback_providers]
total_candidates = len(candidates)
last_exception: Exception | None = None
last_err_response: LLMResponse | None = None
for idx, candidate in enumerate(candidates):
candidate_id = candidate.provider_config.get("id", "<unknown>")
is_last_candidate = idx == total_candidates - 1
if idx > 0:
logger.warning(
"Switched from %s to fallback chat provider: %s",
self.provider.provider_config.get("id", "<unknown>"),
candidate_id,
)
self.provider = candidate
has_stream_output = False
try:
async for resp in self._iter_llm_responses(include_model=idx == 0):
if resp.is_chunk:
has_stream_output = True
yield resp
continue
if (
resp.role == "err"
and not has_stream_output
and (not is_last_candidate)
):
last_err_response = resp
logger.warning(
"Chat Model %s returns error response, trying fallback to next provider.",
candidate_id,
)
break
yield resp
return
if has_stream_output:
return
except Exception as exc: # noqa: BLE001
last_exception = exc
logger.warning(
"Chat Model %s request error: %s",
candidate_id,
exc,
exc_info=True,
)
continue
if last_err_response:
yield last_err_response
return
if last_exception:
yield LLMResponse(
role="err",
completion_text=(
"All chat models failed: "
f"{type(last_exception).__name__}: {last_exception}"
),
)
return
yield LLMResponse(
role="err",
completion_text="All available chat models are unavailable.",
)
def _simple_print_message_role(self, tag: str = ""):
roles = []
for message in self.run_context.messages:
roles.append(message.role)
logger.debug(f"{tag} RunCtx.messages -> [{len(roles)}] {','.join(roles)}")
@override
async def step(self):
"""Process a single step of the agent.
@@ -295,13 +203,11 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# do truncate and compress
token_usage = self.req.conversation.token_usage if self.req.conversation else 0
self._simple_print_message_role("[BefCompact]")
self.run_context.messages = await self.context_manager.process(
self.run_context.messages, trusted_token_usage=token_usage
)
self._simple_print_message_role("[AftCompact]")
async for llm_response in self._iter_llm_responses_with_fallback():
async for llm_response in self._iter_llm_responses():
if llm_response.is_chunk:
# update ttft
if self.stats.time_to_first_token == 0:
+31 -194
View File
@@ -42,7 +42,6 @@ from astrbot.core.message.components import File, Image, Reply
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import ProviderRequest
from astrbot.core.provider.manager import llm_tools
from astrbot.core.skills.skill_manager import SkillManager, build_skills_prompt
from astrbot.core.star.context import Context
from astrbot.core.star.star_handler import star_map
@@ -53,17 +52,6 @@ from astrbot.core.tools.cron_tools import (
)
from astrbot.core.utils.file_extract import extract_file_moonshotai
from astrbot.core.utils.llm_metadata import LLM_METADATAS
from astrbot.core.utils.quoted_message.settings import (
SETTINGS as DEFAULT_QUOTED_MESSAGE_SETTINGS,
)
from astrbot.core.utils.quoted_message.settings import (
QuotedMessageParserSettings,
)
from astrbot.core.utils.quoted_message_parser import (
extract_quoted_message_images,
extract_quoted_message_text,
)
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
@dataclass(slots=True)
@@ -120,8 +108,6 @@ class MainAgentBuildConfig:
provider_settings: dict = field(default_factory=dict)
subagent_orchestrator: dict = field(default_factory=dict)
timezone: str | None = None
max_quoted_fallback_images: int = 20
"""Maximum number of images injected from quoted-message fallback extraction."""
@dataclass(slots=True)
@@ -340,24 +326,6 @@ async def _ensure_persona_and_skills(
)
tmgr = plugin_context.get_llm_tool_manager()
# inject toolset in the persona
if (persona and persona.get("tools") is None) or not persona:
persona_toolset = tmgr.get_full_tool_set()
for tool in list(persona_toolset):
if not tool.active:
persona_toolset.remove_tool(tool.name)
else:
persona_toolset = ToolSet()
if persona["tools"]:
for tool_name in persona["tools"]:
tool = tmgr.get_func(tool_name)
if tool and tool.active:
persona_toolset.add_tool(tool)
if not req.func_tool:
req.func_tool = persona_toolset
else:
req.func_tool.merge(persona_toolset)
# sub agents integration
orch_cfg = plugin_context.get_config().get("subagent_orchestrator", {})
so = plugin_context.subagent_orchestrator
@@ -403,19 +371,22 @@ async def _ensure_persona_and_skills(
assigned_tools.add(name)
if req.func_tool is None:
req.func_tool = ToolSet()
toolset = ToolSet()
else:
toolset = req.func_tool
# add subagent handoff tools
for tool in so.handoffs:
req.func_tool.add_tool(tool)
toolset.add_tool(tool)
# check duplicates
if remove_dup:
handoff_names = {tool.name for tool in so.handoffs}
names = toolset.names()
for tool_name in assigned_tools:
if tool_name in handoff_names:
continue
req.func_tool.remove_tool(tool_name)
if tool_name in names:
toolset.remove_tool(tool_name)
req.func_tool = toolset
router_prompt = (
plugin_context.get_config()
@@ -424,14 +395,32 @@ async def _ensure_persona_and_skills(
).strip()
if router_prompt:
req.system_prompt += f"\n{router_prompt}\n"
return
# inject toolset in the persona
if (persona and persona.get("tools") is None) or not persona:
toolset = tmgr.get_full_tool_set()
for tool in list(toolset):
if not tool.active:
toolset.remove_tool(tool.name)
else:
toolset = ToolSet()
if persona["tools"]:
for tool_name in persona["tools"]:
tool = tmgr.get_func(tool_name)
if tool and tool.active:
toolset.add_tool(tool)
if not req.func_tool:
req.func_tool = toolset
else:
req.func_tool.merge(toolset)
try:
event.trace.record(
"sel_persona",
persona_id=persona_id,
persona_toolset=persona_toolset.names(),
"sel_persona", persona_id=persona_id, persona_toolset=toolset.names()
)
except Exception:
pass
logger.debug("Tool set for persona %s: %s", persona_id, toolset.names())
async def _request_img_caption(
@@ -484,29 +473,11 @@ async def _ensure_img_caption(
logger.error("处理图片描述失败: %s", exc)
def _append_quoted_image_attachment(req: ProviderRequest, image_path: str) -> None:
req.extra_user_content_parts.append(
TextPart(text=f"[Image Attachment in quoted message: path {image_path}]")
)
def _get_quoted_message_parser_settings(
provider_settings: dict[str, object] | None,
) -> QuotedMessageParserSettings:
if not isinstance(provider_settings, dict):
return DEFAULT_QUOTED_MESSAGE_SETTINGS
overrides = provider_settings.get("quoted_message_parser")
if not isinstance(overrides, dict):
return DEFAULT_QUOTED_MESSAGE_SETTINGS
return DEFAULT_QUOTED_MESSAGE_SETTINGS.with_overrides(overrides)
async def _process_quote_message(
event: AstrMessageEvent,
req: ProviderRequest,
img_cap_prov_id: str,
plugin_context: Context,
quoted_message_settings: QuotedMessageParserSettings = DEFAULT_QUOTED_MESSAGE_SETTINGS,
) -> None:
quote = None
for comp in event.message_obj.message:
@@ -518,15 +489,7 @@ async def _process_quote_message(
content_parts = []
sender_info = f"({quote.sender_nickname}): " if quote.sender_nickname else ""
message_str = (
await extract_quoted_message_text(
event,
quote,
settings=quoted_message_settings,
)
or quote.message_str
or "[Empty Text]"
)
message_str = quote.message_str or "[Empty Text]"
content_parts.append(f"{sender_info}{message_str}")
image_seg = None
@@ -632,13 +595,11 @@ async def _decorate_llm_request(
)
img_cap_prov_id = cfg.get("default_image_caption_provider_id") or ""
quoted_message_settings = _get_quoted_message_parser_settings(cfg)
await _process_quote_message(
event,
req,
img_cap_prov_id,
plugin_context,
quoted_message_settings,
)
tz = config.timezone
@@ -770,14 +731,6 @@ def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
if plugin.name in event.plugins_name or plugin.reserved:
new_tool_set.add_tool(tool)
req.func_tool = new_tool_set
else:
# mcp tools
tool_set = req.func_tool
if not tool_set:
tool_set = ToolSet()
for tool in llm_tools.func_list:
if isinstance(tool, MCPTool):
tool_set.add_tool(tool)
async def _handle_webchat(
@@ -879,41 +832,6 @@ def _get_compress_provider(
return provider
def _get_fallback_chat_providers(
provider: Provider, plugin_context: Context, provider_settings: dict
) -> list[Provider]:
fallback_ids = provider_settings.get("fallback_chat_models", [])
if not isinstance(fallback_ids, list):
logger.warning(
"fallback_chat_models setting is not a list, skip fallback providers."
)
return []
provider_id = str(provider.provider_config.get("id", ""))
seen_provider_ids: set[str] = {provider_id} if provider_id else set()
fallbacks: list[Provider] = []
for fallback_id in fallback_ids:
if not isinstance(fallback_id, str) or not fallback_id:
continue
if fallback_id in seen_provider_ids:
continue
fallback_provider = plugin_context.get_provider_by_id(fallback_id)
if fallback_provider is None:
logger.warning("Fallback chat provider `%s` not found, skip.", fallback_id)
continue
if not isinstance(fallback_provider, Provider):
logger.warning(
"Fallback chat provider `%s` is invalid type: %s, skip.",
fallback_id,
type(fallback_provider),
)
continue
fallbacks.append(fallback_provider)
seen_provider_ids.add(fallback_id)
return fallbacks
async def build_main_agent(
*,
event: AstrMessageEvent,
@@ -952,8 +870,6 @@ async def build_main_agent(
return None
req.prompt = event.message_str[len(config.provider_wake_prefix) :]
# media files attachments
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_path = await comp.convert_to_file_path()
@@ -969,81 +885,6 @@ async def build_main_agent(
text=f"[File Attachment: name {file_name}, path {file_path}]"
)
)
# quoted message attachments
reply_comps = [
comp for comp in event.message_obj.message if isinstance(comp, Reply)
]
quoted_message_settings = _get_quoted_message_parser_settings(
config.provider_settings
)
fallback_quoted_image_count = 0
for comp in reply_comps:
has_embedded_image = False
if comp.chain:
for reply_comp in comp.chain:
if isinstance(reply_comp, Image):
has_embedded_image = True
image_path = await reply_comp.convert_to_file_path()
req.image_urls.append(image_path)
_append_quoted_image_attachment(req, image_path)
elif isinstance(reply_comp, File):
file_path = await reply_comp.get_file()
file_name = reply_comp.name or os.path.basename(file_path)
req.extra_user_content_parts.append(
TextPart(
text=(
f"[File Attachment in quoted message: "
f"name {file_name}, path {file_path}]"
)
)
)
# Fallback quoted image extraction for reply-id-only payloads, or when
# embedded reply chain only contains placeholders (e.g. [Forward Message], [Image]).
if not has_embedded_image:
try:
fallback_images = normalize_and_dedupe_strings(
await extract_quoted_message_images(
event,
comp,
settings=quoted_message_settings,
)
)
remaining_limit = max(
config.max_quoted_fallback_images
- fallback_quoted_image_count,
0,
)
if remaining_limit <= 0 and fallback_images:
logger.warning(
"Skip quoted fallback images due to limit=%d for umo=%s",
config.max_quoted_fallback_images,
event.unified_msg_origin,
)
continue
if len(fallback_images) > remaining_limit:
logger.warning(
"Truncate quoted fallback images for umo=%s, reply_id=%s from %d to %d",
event.unified_msg_origin,
getattr(comp, "id", None),
len(fallback_images),
remaining_limit,
)
fallback_images = fallback_images[:remaining_limit]
for image_ref in fallback_images:
if image_ref in req.image_urls:
continue
req.image_urls.append(image_ref)
fallback_quoted_image_count += 1
_append_quoted_image_attachment(req, image_ref)
except Exception as exc: # noqa: BLE001
logger.warning(
"Failed to resolve fallback quoted images for umo=%s, reply_id=%s: %s",
event.unified_msg_origin,
getattr(comp, "id", None),
exc,
exc_info=True,
)
conversation = await _get_session_conv(event, plugin_context)
req.conversation = conversation
@@ -1052,7 +893,6 @@ async def build_main_agent(
if isinstance(req.contexts, str):
req.contexts = json.loads(req.contexts)
req.image_urls = normalize_and_dedupe_strings(req.image_urls)
if config.file_extract_enabled:
try:
@@ -1137,9 +977,6 @@ async def build_main_agent(
truncate_turns=config.dequeue_context_length,
enforce_max_turns=config.max_context_length,
tool_schema_mode=config.tool_schema_mode,
fallback_providers=_get_fallback_chat_providers(
provider, plugin_context, config.provider_settings
),
)
if apply_reset:
+6 -9
View File
@@ -1,7 +1,6 @@
import base64
import json
import os
import uuid
from pydantic import Field
from pydantic.dataclasses import dataclass
@@ -241,9 +240,7 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
if "_&exists_" in json.dumps(result):
# Download the file from sandbox
name = os.path.basename(path)
local_path = os.path.join(
get_astrbot_temp_path(), f"sandbox_{uuid.uuid4().hex[:4]}_{name}"
)
local_path = os.path.join(get_astrbot_temp_path(), name)
await sb.download_file(path, local_path)
logger.info(f"Downloaded file from sandbox: {path} -> {local_path}")
return local_path, True
@@ -355,11 +352,11 @@ class SendMessageToUserTool(FunctionTool[AstrAgentContext]):
MessageChain(chain=components),
)
# if file_from_sandbox:
# try:
# os.remove(local_path)
# except Exception as e:
# logger.error(f"Error removing temp file {local_path}: {e}")
if file_from_sandbox:
try:
os.remove(local_path)
except Exception as e:
logger.error(f"Error removing temp file {local_path}: {e}")
return f"Message sent to session {target_session}"
-2
View File
@@ -11,7 +11,6 @@ from astrbot.core.db.po import (
CommandConflict,
ConversationV2,
Persona,
PersonaFolder,
PlatformMessageHistory,
PlatformSession,
PlatformStat,
@@ -40,7 +39,6 @@ MAIN_DB_MODELS: dict[str, type[SQLModel]] = {
"platform_stats": PlatformStat,
"conversations": ConversationV2,
"personas": Persona,
"persona_folders": PersonaFolder,
"preferences": Preference,
"platform_message_history": PlatformMessageHistory,
"platform_sessions": PlatformSession,
+6 -9
View File
@@ -1,5 +1,4 @@
import os
import uuid
from dataclasses import dataclass, field
from astrbot.api import FunctionTool, logger
@@ -168,9 +167,7 @@ class FileDownloadTool(FunctionTool):
try:
name = os.path.basename(remote_path)
local_path = os.path.join(
get_astrbot_temp_path(), f"sandbox_{uuid.uuid4().hex[:4]}_{name}"
)
local_path = os.path.join(get_astrbot_temp_path(), name)
# Download file from sandbox
await sb.download_file(remote_path, local_path)
@@ -186,12 +183,12 @@ class FileDownloadTool(FunctionTool):
logger.error(f"Error sending file message: {e}")
# remove
# try:
# os.remove(local_path)
# except Exception as e:
# logger.error(f"Error removing temp file {local_path}: {e}")
try:
os.remove(local_path)
except Exception as e:
logger.error(f"Error removing temp file {local_path}: {e}")
return f"File downloaded successfully to {local_path} and sent to user."
return f"File downloaded successfully to {local_path} and sent to user. The file has been removed from local storage."
return f"File downloaded successfully to {local_path}"
except Exception as e:
+7 -130
View File
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.17.2"
VERSION = "4.14.8"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
WEBHOOK_SUPPORTED_PLATFORMS = [
@@ -15,7 +15,6 @@ WEBHOOK_SUPPORTED_PLATFORMS = [
"wecom_ai_bot",
"slack",
"lark",
"line",
]
# 默认配置
@@ -68,7 +67,6 @@ DEFAULT_CONFIG = {
"provider_settings": {
"enable": True,
"default_provider_id": "",
"fallback_chat_models": [],
"default_image_caption_provider_id": "",
"image_caption_prompt": "Please describe the image using Chinese.",
"provider_pool": ["*"], # "*" 表示使用所有可用的提供者
@@ -101,13 +99,6 @@ DEFAULT_CONFIG = {
"streaming_response": False,
"show_tool_use_status": False,
"sanitize_context_by_modalities": False,
"max_quoted_fallback_images": 20,
"quoted_message_parser": {
"max_component_chain_depth": 4,
"max_forward_node_depth": 6,
"max_forward_fetch": 32,
"warn_on_action_failure": False,
},
"agent_runner_type": "local",
"dify_agent_runner_provider_id": "",
"coze_agent_runner_provider_id": "",
@@ -138,9 +129,8 @@ DEFAULT_CONFIG = {
},
# SubAgent orchestrator mode:
# - main_enable = False: disabled; main LLM mounts tools normally (persona selection).
# - main_enable = True: enabled; main LLM keeps its own tools and includes handoff
# tools (transfer_to_*). remove_main_duplicate_tools can remove tools that are
# duplicated on subagents from the main LLM toolset.
# - main_enable = True: enabled; main LLM will include handoff tools and can optionally
# remove tools that are duplicated on subagents via remove_main_duplicate_tools.
"subagent_orchestrator": {
"main_enable": False,
"remove_main_duplicate_tools": False,
@@ -196,12 +186,6 @@ DEFAULT_CONFIG = {
"host": "0.0.0.0",
"port": 6185,
"disable_access_log": True,
"ssl": {
"enable": False,
"cert_file": "",
"key_file": "",
"ca_certs": "",
},
},
"platform": [],
"platform_specific": {
@@ -218,7 +202,6 @@ DEFAULT_CONFIG = {
"log_file_enable": False,
"log_file_path": "logs/astrbot.log",
"log_file_max_mb": 20,
"temp_dir_max_size": 1024,
"trace_enable": False,
"trace_log_enable": False,
"trace_log_path": "logs/astrbot.trace.log",
@@ -336,11 +319,9 @@ CONFIG_METADATA_2 = {
"id": "wecom_ai_bot",
"type": "wecom_ai_bot",
"enable": True,
"wecomaibot_init_respond_text": "",
"wecomaibot_init_respond_text": "💭 思考中...",
"wecomaibot_friend_message_welcome_text": "",
"wecom_ai_bot_name": "",
"msg_push_webhook_url": "",
"only_use_webhook_url_to_send": False,
"token": "",
"encoding_aes_key": "",
"unified_webhook_mode": True,
@@ -423,7 +404,6 @@ CONFIG_METADATA_2 = {
"slack_webhook_port": 6197,
"slack_webhook_path": "/astrbot-slack-webhook/callback",
},
# LINE's config is located in line_adapter.py
"Satori": {
"id": "satori",
"type": "satori",
@@ -707,23 +687,13 @@ CONFIG_METADATA_2 = {
"wecomaibot_init_respond_text": {
"description": "企业微信智能机器人初始响应文本",
"type": "string",
"hint": "当机器人收到消息时,首先回复的文本内容。留空则不设置",
"hint": "当机器人收到消息时,首先回复的文本内容。留空则使用默认值",
},
"wecomaibot_friend_message_welcome_text": {
"description": "企业微信智能机器人私聊欢迎语",
"type": "string",
"hint": "当用户当天进入智能机器人单聊会话,回复欢迎语,留空则不回复。",
},
"msg_push_webhook_url": {
"description": "企业微信消息推送 Webhook URL",
"type": "string",
"hint": "用于 send_by_session 主动消息推送。格式示例: https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=xxx",
},
"only_use_webhook_url_to_send": {
"description": "仅使用 Webhook 发送消息",
"type": "bool",
"hint": "启用后,企业微信智能机器人的所有回复都改为通过消息推送 Webhook 发送。消息推送 Webhook 支持更多的消息类型(如图片、文件等)。",
},
"lark_bot_name": {
"description": "飞书机器人的名字",
"type": "string",
@@ -2214,10 +2184,6 @@ CONFIG_METADATA_2 = {
"default_provider_id": {
"type": "string",
},
"fallback_chat_models": {
"type": "list",
"items": {"type": "string"},
},
"wake_prefix": {
"type": "string",
},
@@ -2412,23 +2378,9 @@ CONFIG_METADATA_2 = {
"type": "string",
"options": ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
},
"dashboard.ssl.enable": {"type": "bool"},
"dashboard.ssl.cert_file": {
"type": "string",
"condition": {"dashboard.ssl.enable": True},
},
"dashboard.ssl.key_file": {
"type": "string",
"condition": {"dashboard.ssl.enable": True},
},
"dashboard.ssl.ca_certs": {
"type": "string",
"condition": {"dashboard.ssl.enable": True},
},
"log_file_enable": {"type": "bool"},
"log_file_path": {"type": "string", "condition": {"log_file_enable": True}},
"log_file_max_mb": {"type": "int", "condition": {"log_file_enable": True}},
"temp_dir_max_size": {"type": "int"},
"trace_log_enable": {"type": "bool"},
"trace_log_path": {
"type": "string",
@@ -2528,22 +2480,15 @@ CONFIG_METADATA_3 = {
},
"ai": {
"description": "模型",
"hint": "当使用非内置 Agent 执行器时,默认对话模型和默认图片转述模型可能会无效,但某些插件会依赖此配置项来调用 AI 能力。",
"hint": "当使用非内置 Agent 执行器时,默认聊天模型和默认图片转述模型可能会无效,但某些插件会依赖此配置项来调用 AI 能力。",
"type": "object",
"items": {
"provider_settings.default_provider_id": {
"description": "默认对话模型",
"description": "默认聊天模型",
"type": "string",
"_special": "select_provider",
"hint": "留空时使用第一个模型",
},
"provider_settings.fallback_chat_models": {
"description": "回退对话模型列表",
"type": "list",
"items": {"type": "string"},
"_special": "select_providers",
"hint": "主聊天模型请求失败时,按顺序切换到这些模型。",
},
"provider_settings.default_image_caption_provider_id": {
"description": "默认图片转述模型",
"type": "string",
@@ -2948,46 +2893,6 @@ CONFIG_METADATA_3 = {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.max_quoted_fallback_images": {
"description": "引用图片回退解析上限",
"type": "int",
"hint": "引用/转发消息回退解析图片时的最大注入数量,超出会截断。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.quoted_message_parser.max_component_chain_depth": {
"description": "引用解析组件链深度",
"type": "int",
"hint": "解析 Reply 组件链时允许的最大递归深度。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.quoted_message_parser.max_forward_node_depth": {
"description": "引用解析转发节点深度",
"type": "int",
"hint": "解析合并转发节点时允许的最大递归深度。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.quoted_message_parser.max_forward_fetch": {
"description": "引用解析转发拉取上限",
"type": "int",
"hint": "递归拉取 get_forward_msg 的最大次数。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.quoted_message_parser.warn_on_action_failure": {
"description": "引用解析 action 失败告警",
"type": "bool",
"hint": "开启后,get_msg/get_forward_msg 全部尝试失败时输出 warning 日志。",
"condition": {
"provider_settings.agent_runner_type": "local",
},
},
"provider_settings.max_agent_step": {
"description": "工具调用轮数上限",
"type": "int",
@@ -3439,29 +3344,6 @@ CONFIG_METADATA_3_SYSTEM = {
"hint": "控制台输出日志的级别。",
"options": ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
},
"dashboard.ssl.enable": {
"description": "启用 WebUI HTTPS",
"type": "bool",
"hint": "启用后,WebUI 将直接使用 HTTPS 提供服务。",
},
"dashboard.ssl.cert_file": {
"description": "SSL 证书文件路径",
"type": "string",
"hint": "证书文件路径(PEM)。支持绝对路径和相对路径(相对于当前工作目录)。",
"condition": {"dashboard.ssl.enable": True},
},
"dashboard.ssl.key_file": {
"description": "SSL 私钥文件路径",
"type": "string",
"hint": "私钥文件路径(PEM)。支持绝对路径和相对路径(相对于当前工作目录)。",
"condition": {"dashboard.ssl.enable": True},
},
"dashboard.ssl.ca_certs": {
"description": "SSL CA 证书文件路径",
"type": "string",
"hint": "可选。用于指定 CA 证书文件路径。",
"condition": {"dashboard.ssl.enable": True},
},
"log_file_enable": {
"description": "启用文件日志",
"type": "bool",
@@ -3477,11 +3359,6 @@ CONFIG_METADATA_3_SYSTEM = {
"type": "int",
"hint": "超过大小后自动轮转,默认 20MB。",
},
"temp_dir_max_size": {
"description": "临时目录大小上限 (MB)",
"type": "int",
"hint": "用于限制 data/temp 目录总大小,单位为 MB。系统每 10 分钟检查一次,超限时按文件修改时间从旧到新删除,释放约 30% 当前体积。",
},
"trace_log_enable": {
"description": "启用 Trace 文件日志",
"type": "bool",
-19
View File
@@ -37,7 +37,6 @@ from astrbot.core.umop_config_router import UmopConfigRouter
from astrbot.core.updator import AstrBotUpdator
from astrbot.core.utils.llm_metadata import update_llm_metadata
from astrbot.core.utils.migra_helper import migra
from astrbot.core.utils.temp_dir_cleaner import TempDirCleaner
from . import astrbot_config, html_renderer
from .event_bus import EventBus
@@ -58,7 +57,6 @@ class AstrBotCoreLifecycle:
self.subagent_orchestrator: SubAgentOrchestrator | None = None
self.cron_manager: CronJobManager | None = None
self.temp_dir_cleaner: TempDirCleaner | None = None
# 设置代理
proxy_config = self.astrbot_config.get("http_proxy", "")
@@ -127,12 +125,6 @@ class AstrBotCoreLifecycle:
ucr=self.umop_config_router,
sp=sp,
)
self.temp_dir_cleaner = TempDirCleaner(
max_size_getter=lambda: self.astrbot_config_mgr.default_conf.get(
TempDirCleaner.CONFIG_KEY,
TempDirCleaner.DEFAULT_MAX_SIZE,
),
)
# apply migration
try:
@@ -246,12 +238,6 @@ class AstrBotCoreLifecycle:
self.cron_manager.start(self.star_context),
name="cron_manager",
)
temp_dir_cleaner_task = None
if self.temp_dir_cleaner:
temp_dir_cleaner_task = asyncio.create_task(
self.temp_dir_cleaner.run(),
name="temp_dir_cleaner",
)
# 把插件中注册的所有协程函数注册到事件总线中并执行
extra_tasks = []
@@ -261,8 +247,6 @@ class AstrBotCoreLifecycle:
tasks_ = [event_bus_task, *(extra_tasks if extra_tasks else [])]
if cron_task:
tasks_.append(cron_task)
if temp_dir_cleaner_task:
tasks_.append(temp_dir_cleaner_task)
for task in tasks_:
self.curr_tasks.append(
asyncio.create_task(self._task_wrapper(task), name=task.get_name()),
@@ -314,9 +298,6 @@ class AstrBotCoreLifecycle:
async def stop(self) -> None:
"""停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器."""
if self.temp_dir_cleaner:
await self.temp_dir_cleaner.stop()
# 请求停止所有正在运行的异步任务
for task in self.curr_tasks:
task.cancel()
+295 -252
View File
@@ -1,4 +1,24 @@
"""日志系统,统一将标准 logging 输出转发到 loguru。"""
"""日志系统, 用于支持核心组件和插件的日志记录, 提供了日志订阅功能
const:
CACHED_SIZE: 日志缓存大小, 用于限制缓存的日志数量
log_color_config: 日志颜色配置, 定义了不同日志级别的颜色
class:
LogBroker: 日志代理类, 用于缓存和分发日志消息
LogQueueHandler: 日志处理器, 用于将日志消息发送到 LogBroker
LogManager: 日志管理器, 用于创建和配置日志记录器
function:
is_plugin_path: 检查文件路径是否来自插件目录
get_short_level_name: 将日志级别名称转换为四个字母的缩写
工作流程:
1. 通过 LogManager.GetLogger() 获取日志器, 配置了控制台输出和多个格式化过滤器
2. 通过 set_queue_handler() 设置日志处理器, 将日志消息发送到 LogBroker
3. logBroker 维护一个订阅者列表, 负责将日志分发给所有订阅者
4. 订阅者可以使用 register() 方法注册到 LogBroker, 订阅日志流
"""
import asyncio
import logging
@@ -7,59 +27,54 @@ import sys
import time
from asyncio import Queue
from collections import deque
from typing import TYPE_CHECKING
from logging.handlers import RotatingFileHandler
from loguru import logger as _raw_loguru_logger
import colorlog
from astrbot.core.config.default import VERSION
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
# 日志缓存大小
CACHED_SIZE = 500
if TYPE_CHECKING:
from loguru import Record
# 日志颜色配置
log_color_config = {
"DEBUG": "green",
"INFO": "bold_cyan",
"WARNING": "bold_yellow",
"ERROR": "red",
"CRITICAL": "bold_red",
"RESET": "reset",
"asctime": "green",
}
class _RecordEnricherFilter(logging.Filter):
"""为 logging.LogRecord 注入 AstrBot 日志字段。"""
def is_plugin_path(pathname):
"""检查文件路径是否来自插件目录
def filter(self, record: logging.LogRecord) -> bool:
record.plugin_tag = "[Plug]" if _is_plugin_path(record.pathname) else "[Core]"
record.short_levelname = _get_short_level_name(record.levelname)
record.astrbot_version_tag = (
f" [v{VERSION}]" if record.levelno >= logging.WARNING else ""
)
record.source_file = _build_source_file(record.pathname)
record.source_line = record.lineno
record.is_trace = record.name == "astrbot.trace"
return True
Args:
pathname (str): 文件路径
Returns:
bool: 如果路径来自插件目录则返回 True否则返回 False
class _QueueAnsiColorFilter(logging.Filter):
"""Attach ANSI color prefix for WebUI console rendering."""
_LEVEL_COLOR = {
"DEBUG": "\u001b[1;34m",
"INFO": "\u001b[1;36m",
"WARNING": "\u001b[1;33m",
"ERROR": "\u001b[31m",
"CRITICAL": "\u001b[1;31m",
}
def filter(self, record: logging.LogRecord) -> bool:
record.ansi_prefix = self._LEVEL_COLOR.get(record.levelname, "\u001b[0m")
record.ansi_reset = "\u001b[0m"
return True
def _is_plugin_path(pathname: str | None) -> bool:
"""
if not pathname:
return False
norm_path = os.path.normpath(pathname)
return ("data/plugins" in norm_path) or ("astrbot/builtin_stars/" in norm_path)
def _get_short_level_name(level_name: str) -> str:
def get_short_level_name(level_name):
"""将日志级别名称转换为四个字母的缩写
Args:
level_name (str): 日志级别名称, "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"
Returns:
str: 四个字母的日志级别缩写
"""
level_map = {
"DEBUG": "DBUG",
"INFO": "INFO",
@@ -70,75 +85,44 @@ def _get_short_level_name(level_name: str) -> str:
return level_map.get(level_name, level_name[:4].upper())
def _build_source_file(pathname: str | None) -> str:
if not pathname:
return "unknown"
dirname = os.path.dirname(pathname)
return (
os.path.basename(dirname) + "." + os.path.basename(pathname).replace(".py", "")
)
def _patch_record(record: "Record") -> None:
extra = record["extra"]
extra.setdefault("plugin_tag", "[Core]")
extra.setdefault("short_levelname", _get_short_level_name(record["level"].name))
level_no = record["level"].no
extra.setdefault("astrbot_version_tag", f" [v{VERSION}]" if level_no >= 30 else "")
extra.setdefault("source_file", _build_source_file(record["file"].path))
extra.setdefault("source_line", record["line"])
extra.setdefault("is_trace", False)
_loguru = _raw_loguru_logger.patch(_patch_record)
class _LoguruInterceptHandler(logging.Handler):
"""将 logging 记录转发到 loguru。"""
def emit(self, record: logging.LogRecord) -> None:
try:
level: str | int = _loguru.level(record.levelname).name
except ValueError:
level = record.levelno
payload = {
"plugin_tag": getattr(record, "plugin_tag", "[Core]"),
"short_levelname": getattr(
record,
"short_levelname",
_get_short_level_name(record.levelname),
),
"astrbot_version_tag": getattr(record, "astrbot_version_tag", ""),
"source_file": getattr(
record, "source_file", _build_source_file(record.pathname)
),
"source_line": getattr(record, "source_line", record.lineno),
"is_trace": getattr(record, "is_trace", record.name == "astrbot.trace"),
}
_loguru.bind(**payload).opt(exception=record.exc_info).log(
level,
record.getMessage(),
)
class LogBroker:
"""日志代理类用于缓存和分发日志消息"""
"""日志代理类, 用于缓存和分发日志消息
发布-订阅模式
"""
def __init__(self) -> None:
self.log_cache = deque(maxlen=CACHED_SIZE)
self.subscribers: list[Queue] = []
self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志
self.subscribers: list[Queue] = [] # 订阅者列表
def register(self) -> Queue:
"""注册新的订阅者, 并给每个订阅者返回一个带有日志缓存的队列
Returns:
Queue: 订阅者的队列, 可用于接收日志消息
"""
q = Queue(maxsize=CACHED_SIZE + 10)
self.subscribers.append(q)
return q
def unregister(self, q: Queue) -> None:
"""取消订阅
Args:
q (Queue): 需要取消订阅的队列
"""
self.subscribers.remove(q)
def publish(self, log_entry: dict) -> None:
"""发布新日志到所有订阅者, 使用非阻塞方式投递, 避免一个订阅者阻塞整个系统
Args:
log_entry (dict): 日志消息, 包含日志级别和日志内容.
example: {"level": "INFO", "data": "This is a log message.", "time": "2023-10-01 12:00:00"}
"""
self.log_cache.append(log_entry)
for q in self.subscribers:
try:
@@ -148,13 +132,23 @@ class LogBroker:
class LogQueueHandler(logging.Handler):
"""日志处理器用于将日志消息发送到 LogBroker"""
"""日志处理器, 用于将日志消息发送到 LogBroker
继承自 logging.Handler
"""
def __init__(self, log_broker: LogBroker) -> None:
super().__init__()
self.log_broker = log_broker
def emit(self, record: logging.LogRecord) -> None:
def emit(self, record) -> None:
"""日志处理的入口方法, 接受一个日志记录, 转换为字符串后由 LogBroker 发布
这个方法会在每次日志记录时被调用
Args:
record (logging.LogRecord): 日志记录对象, 包含日志信息
"""
log_entry = self.format(record)
self.log_broker.publish(
{
@@ -166,16 +160,117 @@ class LogQueueHandler(logging.Handler):
class LogManager:
_LOGGER_HANDLER_FLAG = "_astrbot_loguru_handler"
_ENRICH_FILTER_FLAG = "_astrbot_enrich_filter"
"""日志管理器, 用于创建和配置日志记录器
_configured = False
_console_sink_id: int | None = None
_file_sink_id: int | None = None
_trace_sink_id: int | None = None
_NOISY_LOGGER_LEVELS: dict[str, int] = {
"aiosqlite": logging.WARNING,
}
提供了获取默认日志记录器logger和设置队列处理器的方法
"""
_FILE_HANDLER_FLAG = "_astrbot_file_handler"
_TRACE_FILE_HANDLER_FLAG = "_astrbot_trace_file_handler"
@classmethod
def GetLogger(cls, log_name: str = "default"):
"""获取指定名称的日志记录器logger
Args:
log_name (str): 日志记录器的名称, 默认为 "default"
Returns:
logging.Logger: 返回配置好的日志记录器
"""
logger = logging.getLogger(log_name)
# 检查该logger或父级logger是否已经有处理器, 如果已经有处理器, 直接返回该logger, 避免重复配置
if logger.hasHandlers():
return logger
# 如果logger没有处理器
console_handler = logging.StreamHandler(
sys.stdout,
) # 创建一个StreamHandler用于控制台输出
console_handler.setLevel(
logging.DEBUG,
) # 将日志级别设置为DEBUG(最低级别, 显示所有日志), *如果插件没有设置级别, 默认为DEBUG
# 创建彩色日志格式化器, 输出日志格式为: [时间] [插件标签] [日志级别] [文件名:行号]: 日志消息
console_formatter = colorlog.ColoredFormatter(
fmt="%(log_color)s [%(asctime)s] %(plugin_tag)s [%(short_levelname)-4s]%(astrbot_version_tag)s [%(filename)s:%(lineno)d]: %(message)s %(reset)s",
datefmt="%H:%M:%S",
log_colors=log_color_config,
)
class PluginFilter(logging.Filter):
"""插件过滤器类, 用于标记日志来源是插件还是核心组件"""
def filter(self, record) -> bool:
record.plugin_tag = (
"[Plug]" if is_plugin_path(record.pathname) else "[Core]"
)
return True
class FileNameFilter(logging.Filter):
"""文件名过滤器类, 用于修改日志记录的文件名格式
例如: 将文件路径 /path/to/file.py 转换为 file.<file> 格式
"""
# 获取这个文件和父文件夹的名字:<folder>.<file> 并且去除 .py
def filter(self, record) -> bool:
dirname = os.path.dirname(record.pathname)
record.filename = (
os.path.basename(dirname)
+ "."
+ os.path.basename(record.pathname).replace(".py", "")
)
return True
class LevelNameFilter(logging.Filter):
"""短日志级别名称过滤器类, 用于将日志级别名称转换为四个字母的缩写"""
# 添加短日志级别名称
def filter(self, record) -> bool:
record.short_levelname = get_short_level_name(record.levelname)
return True
class AstrBotVersionTagFilter(logging.Filter):
"""在 WARNING 及以上级别日志后追加当前 AstrBot 版本号。"""
def filter(self, record) -> bool:
if record.levelno >= logging.WARNING:
record.astrbot_version_tag = f" [v{VERSION}]"
else:
record.astrbot_version_tag = ""
return True
console_handler.setFormatter(console_formatter) # 设置处理器的格式化器
logger.addFilter(PluginFilter()) # 添加插件过滤器
logger.addFilter(FileNameFilter()) # 添加文件名过滤器
logger.addFilter(LevelNameFilter()) # 添加级别名称过滤器
logger.addFilter(AstrBotVersionTagFilter()) # 追加版本号(WARNING 及以上)
logger.setLevel(logging.DEBUG) # 设置日志级别为DEBUG
logger.addHandler(console_handler) # 添加处理器到logger
return logger
@classmethod
def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker) -> None:
"""设置队列处理器, 用于将日志消息发送到 LogBroker
Args:
logger (logging.Logger): 日志记录器
log_broker (LogBroker): 日志代理类, 用于缓存和分发日志消息
"""
handler = LogQueueHandler(log_broker)
handler.setLevel(logging.DEBUG)
if logger.handlers:
handler.setFormatter(logger.handlers[0].formatter)
else:
# 为队列处理器设置相同格式的formatter
handler.setFormatter(
logging.Formatter(
"[%(asctime)s] [%(short_levelname)s] %(plugin_tag)s[%(filename)s:%(lineno)d]: %(message)s",
),
)
logger.addHandler(handler)
@classmethod
def _default_log_path(cls) -> str:
@@ -190,147 +285,79 @@ class LogManager:
return os.path.join(get_astrbot_data_path(), configured_path)
@classmethod
def _setup_loguru(cls) -> None:
if cls._configured:
return
_loguru.remove()
cls._console_sink_id = _loguru.add(
sys.stdout,
level="DEBUG",
colorize=True,
filter=lambda record: not record["extra"].get("is_trace", False),
format=(
"<green>[{time:HH:mm:ss.SSS}]</green> {extra[plugin_tag]} "
"<level>[{extra[short_levelname]}]</level>{extra[astrbot_version_tag]} "
"[{extra[source_file]}:{extra[source_line]}]: <level>{message}</level>"
),
)
cls._configured = True
@classmethod
def _setup_root_bridge(cls) -> None:
root_logger = logging.getLogger()
has_handler = any(
getattr(handler, cls._LOGGER_HANDLER_FLAG, False)
for handler in root_logger.handlers
)
if not has_handler:
handler = _LoguruInterceptHandler()
setattr(handler, cls._LOGGER_HANDLER_FLAG, True)
root_logger.addHandler(handler)
root_logger.setLevel(logging.DEBUG)
for name, level in cls._NOISY_LOGGER_LEVELS.items():
logging.getLogger(name).setLevel(level)
@classmethod
def _ensure_logger_enricher_filter(cls, logger: logging.Logger) -> None:
has_filter = any(
getattr(existing_filter, cls._ENRICH_FILTER_FLAG, False)
for existing_filter in logger.filters
)
if not has_filter:
enrich_filter = _RecordEnricherFilter()
setattr(enrich_filter, cls._ENRICH_FILTER_FLAG, True)
logger.addFilter(enrich_filter)
@classmethod
def _ensure_logger_intercept_handler(cls, logger: logging.Logger) -> None:
has_handler = any(
getattr(handler, cls._LOGGER_HANDLER_FLAG, False)
def _get_file_handlers(cls, logger: logging.Logger) -> list[logging.Handler]:
return [
handler
for handler in logger.handlers
)
if not has_handler:
handler = _LoguruInterceptHandler()
setattr(handler, cls._LOGGER_HANDLER_FLAG, True)
logger.addHandler(handler)
if getattr(handler, cls._FILE_HANDLER_FLAG, False)
]
@classmethod
def GetLogger(cls, log_name: str = "default") -> logging.Logger:
cls._setup_loguru()
cls._setup_root_bridge()
logger = logging.getLogger(log_name)
cls._ensure_logger_enricher_filter(logger)
cls._ensure_logger_intercept_handler(logger)
logger.setLevel(logging.DEBUG)
logger.propagate = False
return logger
def _get_trace_file_handlers(cls, logger: logging.Logger) -> list[logging.Handler]:
return [
handler
for handler in logger.handlers
if getattr(handler, cls._TRACE_FILE_HANDLER_FLAG, False)
]
@classmethod
def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker) -> None:
cls._ensure_logger_enricher_filter(logger)
for handler in logger.handlers:
if isinstance(handler, LogQueueHandler):
return
handler = LogQueueHandler(log_broker)
handler.setLevel(logging.DEBUG)
handler.addFilter(_QueueAnsiColorFilter())
handler.setFormatter(
logging.Formatter(
"%(ansi_prefix)s[%(asctime)s.%(msecs)03d] %(plugin_tag)s [%(short_levelname)s]%(astrbot_version_tag)s "
"[%(source_file)s:%(source_line)d]: %(message)s%(ansi_reset)s",
datefmt="%Y-%m-%d %H:%M:%S",
),
)
logger.addHandler(handler)
def _remove_file_handlers(cls, logger: logging.Logger) -> None:
for handler in cls._get_file_handlers(logger):
logger.removeHandler(handler)
try:
handler.close()
except Exception:
pass
@classmethod
def _remove_sink(cls, sink_id: int | None) -> None:
if sink_id is None:
return
try:
_loguru.remove(sink_id)
except ValueError:
pass
def _remove_trace_file_handlers(cls, logger: logging.Logger) -> None:
for handler in cls._get_trace_file_handlers(logger):
logger.removeHandler(handler)
try:
handler.close()
except Exception:
pass
@classmethod
def _add_file_sink(
def _add_file_handler(
cls,
*,
logger: logging.Logger,
file_path: str,
level: int,
max_mb: int | None,
backup_count: int,
trace: bool,
) -> int:
max_mb: int | None = None,
backup_count: int = 3,
trace: bool = False,
) -> None:
os.makedirs(os.path.dirname(file_path) or ".", exist_ok=True)
rotation = f"{max_mb} MB" if max_mb and max_mb > 0 else None
retention = (
backup_count if rotation and backup_count and backup_count > 0 else None
)
if trace:
return _loguru.add(
max_bytes = 0
if max_mb and max_mb > 0:
max_bytes = max_mb * 1024 * 1024
if max_bytes > 0:
file_handler = RotatingFileHandler(
file_path,
level="INFO",
format="[{time:YYYY-MM-DD HH:mm:ss.SSS}] {message}",
maxBytes=max_bytes,
backupCount=backup_count,
encoding="utf-8",
rotation=rotation,
retention=retention,
enqueue=True,
filter=lambda record: record["extra"].get("is_trace", False),
)
logging_level_name = logging.getLevelName(level)
if isinstance(logging_level_name, int):
logging_level_name = "INFO"
return _loguru.add(
file_path,
level=logging_level_name,
format=(
"[{time:YYYY-MM-DD HH:mm:ss.SSS}] {extra[plugin_tag]} "
"[{extra[short_levelname]}]{extra[astrbot_version_tag]} "
"[{extra[source_file]}:{extra[source_line]}]: {message}"
),
encoding="utf-8",
rotation=rotation,
retention=retention,
enqueue=True,
filter=lambda record: not record["extra"].get("is_trace", False),
else:
file_handler = logging.FileHandler(file_path, encoding="utf-8")
file_handler.setLevel(logger.level)
if trace:
formatter = logging.Formatter(
"[%(asctime)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
else:
formatter = logging.Formatter(
"[%(asctime)s] %(plugin_tag)s [%(short_levelname)s]%(astrbot_version_tag)s [%(filename)s:%(lineno)d]: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
file_handler.setFormatter(formatter)
setattr(
file_handler,
cls._TRACE_FILE_HANDLER_FLAG if trace else cls._FILE_HANDLER_FLAG,
True,
)
logger.addHandler(file_handler)
@classmethod
def configure_logger(
@@ -339,6 +366,13 @@ class LogManager:
config: dict | None,
override_level: str | None = None,
) -> None:
"""根据配置设置日志级别和文件日志。
Args:
logger: 需要配置的 logger
config: 配置字典
override_level: 若提供将覆盖配置中的日志级别
"""
if not config:
return
@@ -349,6 +383,7 @@ class LogManager:
except Exception:
logger.setLevel(logging.INFO)
# 兼容旧版嵌套配置
if "log_file" in config:
file_conf = config.get("log_file") or {}
enable_file = bool(file_conf.get("enable", False))
@@ -359,25 +394,27 @@ class LogManager:
file_path = config.get("log_file_path")
max_mb = config.get("log_file_max_mb")
cls._remove_sink(cls._file_sink_id)
cls._file_sink_id = None
file_path = cls._resolve_log_path(file_path)
existing = cls._get_file_handlers(logger)
if not enable_file:
cls._remove_file_handlers(logger)
return
try:
cls._file_sink_id = cls._add_file_sink(
file_path=cls._resolve_log_path(file_path),
level=logger.level,
max_mb=max_mb,
backup_count=3,
trace=False,
)
except Exception as e:
logger.error(f"Failed to add file sink: {e}")
# 如果已有文件处理器且路径一致,则仅同步级别
if existing:
handler = existing[0]
base = getattr(handler, "baseFilename", "")
if base and os.path.abspath(base) == os.path.abspath(file_path):
handler.setLevel(logger.level)
return
cls._remove_file_handlers(logger)
cls._add_file_handler(logger, file_path, max_mb=max_mb)
@classmethod
def configure_trace_logger(cls, config: dict | None) -> None:
"""为 trace 事件配置独立的文件日志,不向控制台输出。"""
if not config:
return
@@ -392,22 +429,28 @@ class LogManager:
path = path or legacy.get("trace_path")
max_mb = max_mb or legacy.get("trace_max_mb")
if not enable:
trace_logger = logging.getLogger("astrbot.trace")
cls._remove_trace_file_handlers(trace_logger)
return
file_path = cls._resolve_log_path(path or "logs/astrbot.trace.log")
trace_logger = logging.getLogger("astrbot.trace")
cls._ensure_logger_enricher_filter(trace_logger)
cls._ensure_logger_intercept_handler(trace_logger)
trace_logger.setLevel(logging.INFO)
trace_logger.propagate = False
cls._remove_sink(cls._trace_sink_id)
cls._trace_sink_id = None
existing = cls._get_trace_file_handlers(trace_logger)
if existing:
handler = existing[0]
base = getattr(handler, "baseFilename", "")
if base and os.path.abspath(base) == os.path.abspath(file_path):
handler.setLevel(trace_logger.level)
return
cls._remove_trace_file_handlers(trace_logger)
if not enable:
return
cls._trace_sink_id = cls._add_file_sink(
file_path=cls._resolve_log_path(path or "logs/astrbot.trace.log"),
level=logging.INFO,
cls._add_file_handler(
trace_logger,
file_path,
max_mb=max_mb,
backup_count=3,
trace=True,
)
+11 -13
View File
@@ -31,7 +31,7 @@ from enum import Enum
from pydantic.v1 import BaseModel
from astrbot.core import astrbot_config, file_token_service, logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
@@ -156,9 +156,8 @@ class Record(BaseMessageComponent):
if self.file.startswith("base64://"):
bs64_data = self.file.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
file_path = os.path.join(
get_astrbot_temp_path(), f"recordseg_{uuid.uuid4()}.jpg"
)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
with open(file_path, "wb") as f:
f.write(image_bytes)
return os.path.abspath(file_path)
@@ -246,9 +245,8 @@ class Video(BaseMessageComponent):
if url and url.startswith("file:///"):
return url[8:]
if url and url.startswith("http"):
video_file_path = os.path.join(
get_astrbot_temp_path(), f"videoseg_{uuid.uuid4().hex}"
)
download_dir = os.path.join(get_astrbot_data_path(), "temp")
video_file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
await download_file(url, video_file_path)
if os.path.exists(video_file_path):
return os.path.abspath(video_file_path)
@@ -447,9 +445,8 @@ class Image(BaseMessageComponent):
if url.startswith("base64://"):
bs64_data = url.removeprefix("base64://")
image_bytes = base64.b64decode(bs64_data)
image_file_path = os.path.join(
get_astrbot_temp_path(), f"imgseg_{uuid.uuid4()}.jpg"
)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
image_file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
with open(image_file_path, "wb") as f:
f.write(image_bytes)
return os.path.abspath(image_file_path)
@@ -728,12 +725,13 @@ class File(BaseMessageComponent):
"""下载文件"""
if not self.url:
raise ValueError("Download failed: No URL provided in File component.")
download_dir = get_astrbot_temp_path()
download_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(download_dir, exist_ok=True)
if self.name:
name, ext = os.path.splitext(self.name)
filename = f"fileseg_{name}_{uuid.uuid4().hex[:8]}{ext}"
filename = f"{name}_{uuid.uuid4().hex[:8]}{ext}"
else:
filename = f"fileseg_{uuid.uuid4().hex}"
filename = f"{uuid.uuid4().hex}"
file_path = os.path.join(download_dir, filename)
await download_file(self.url, file_path)
self.file_ = os.path.abspath(file_path)
@@ -123,7 +123,6 @@ class InternalAgentSubStage(Stage):
provider_settings=settings,
subagent_orchestrator=conf.get("subagent_orchestrator", {}),
timezone=self.ctx.plugin_manager.context.get_config().get("timezone"),
max_quoted_fallback_images=settings.get("max_quoted_fallback_images", 20),
)
async def process(
@@ -150,7 +149,6 @@ class InternalAgentSubStage(Stage):
logger.debug("ready to request llm provider")
await event.send_typing()
await call_event_hook(event, EventType.OnWaitingLLMRequestEvent)
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
@@ -192,8 +190,6 @@ class InternalAgentSubStage(Stage):
)
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
if reset_coro:
reset_coro.close()
return
# apply reset
+10 -11
View File
@@ -61,17 +61,16 @@ class RespondStage(Stage):
self.log_base = float(
ctx.astrbot_config["platform_settings"]["segmented_reply"]["log_base"],
)
self.interval = [1.5, 3.5]
if self.enable_seg:
interval_str: str = ctx.astrbot_config["platform_settings"][
"segmented_reply"
]["interval"]
interval_str_ls = interval_str.replace(" ", "").split(",")
try:
self.interval = [float(t) for t in interval_str_ls]
except BaseException as e:
logger.error(f"解析分段回复间隔时间失败。{e}")
logger.info(f"分段回复间隔时间:{self.interval}")
interval_str: str = ctx.astrbot_config["platform_settings"]["segmented_reply"][
"interval"
]
interval_str_ls = interval_str.replace(" ", "").split(",")
try:
self.interval = [float(t) for t in interval_str_ls]
except BaseException as e:
logger.error(f"解析分段回复的间隔时间失败。{e}")
self.interval = [1.5, 3.5]
logger.info(f"分段回复间隔时间{self.interval}")
async def _word_cnt(self, text: str) -> int:
"""分段回复 统计字数"""
@@ -244,12 +244,6 @@ class AstrMessageEvent(abc.ABC):
)
self._has_send_oper = True
async def send_typing(self) -> None:
"""发送输入中状态。
默认实现为空由具体平台按需重写
"""
async def _pre_send(self) -> None:
"""调度器会在执行 send() 前调用该方法 deprecated in v3.5.18"""
+20 -71
View File
@@ -1,7 +1,6 @@
import asyncio
import traceback
from asyncio import Queue
from dataclasses import dataclass
from astrbot.core import logger
from astrbot.core.config.astrbot_config import AstrBotConfig
@@ -13,19 +12,12 @@ from .register import platform_cls_map
from .sources.webchat.webchat_adapter import WebChatAdapter
@dataclass
class PlatformTasks:
run: asyncio.Task
wrapper: asyncio.Task
class PlatformManager:
def __init__(self, config: AstrBotConfig, event_queue: Queue) -> None:
self.platform_insts: list[Platform] = []
"""加载的 Platform 的实例"""
self._inst_map: dict[str, dict] = {}
self._platform_tasks: dict[str, PlatformTasks] = {}
self.astrbot_config = config
self.platforms_config = config["platform"]
@@ -46,44 +38,6 @@ class PlatformManager:
sanitized = platform_id.replace(":", "_").replace("!", "_")
return sanitized, sanitized != platform_id
def _start_platform_task(self, task_name: str, inst: Platform) -> None:
run_task = asyncio.create_task(inst.run(), name=task_name)
wrapper_task = asyncio.create_task(
self._task_wrapper(run_task, platform=inst),
name=f"{task_name}_wrapper",
)
self._platform_tasks[inst.client_self_id] = PlatformTasks(
run=run_task,
wrapper=wrapper_task,
)
async def _stop_platform_task(self, client_id: str) -> None:
tasks = self._platform_tasks.pop(client_id, None)
if not tasks:
return
for task in (tasks.run, tasks.wrapper):
if not task.done():
task.cancel()
await asyncio.gather(tasks.run, tasks.wrapper, return_exceptions=True)
async def _terminate_inst_and_tasks(self, inst: Platform) -> None:
client_id = inst.client_self_id
try:
if getattr(inst, "terminate", None):
try:
await inst.terminate()
except asyncio.CancelledError:
raise
except Exception as e:
logger.error(
"终止平台适配器失败: client_id=%s, error=%s",
client_id,
e,
)
logger.error(traceback.format_exc())
finally:
await self._stop_platform_task(client_id)
async def initialize(self) -> None:
"""初始化所有平台适配器"""
for platform in self.platforms_config:
@@ -97,7 +51,12 @@ class PlatformManager:
# 网页聊天
webchat_inst = WebChatAdapter({}, self.settings, self.event_queue)
self.platform_insts.append(webchat_inst)
self._start_platform_task("webchat", webchat_inst)
asyncio.create_task(
self._task_wrapper(
asyncio.create_task(webchat_inst.run(), name="webchat"),
platform=webchat_inst,
),
)
async def load_platform(self, platform_config: dict) -> None:
"""实例化一个平台"""
@@ -176,10 +135,6 @@ class PlatformManager:
from .sources.satori.satori_adapter import (
SatoriPlatformAdapter, # noqa: F401
)
case "line":
from .sources.line.line_adapter import (
LinePlatformAdapter, # noqa: F401
)
except (ImportError, ModuleNotFoundError) as e:
logger.error(
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
@@ -199,9 +154,15 @@ class PlatformManager:
"client_id": inst.client_self_id,
}
self.platform_insts.append(inst)
self._start_platform_task(
f"platform_{platform_config['type']}_{platform_config['id']}",
inst,
asyncio.create_task(
self._task_wrapper(
asyncio.create_task(
inst.run(),
name=f"platform_{platform_config['type']}_{platform_config['id']}",
),
platform=inst,
),
)
handlers = star_handlers_registry.get_handlers_by_event_type(
EventType.OnPlatformLoadedEvent,
@@ -269,25 +230,13 @@ class PlatformManager:
except Exception:
logger.warning(f"可能未完全移除 {platform_id} 平台适配器")
await self._terminate_inst_and_tasks(inst)
if getattr(inst, "terminate", None):
await inst.terminate()
async def terminate(self) -> None:
terminated_client_ids: set[str] = set()
for platform_id in list(self._inst_map.keys()):
info = self._inst_map.get(platform_id)
if info:
terminated_client_ids.add(info["client_id"])
await self.terminate_platform(platform_id)
for inst in list(self.platform_insts):
client_id = inst.client_self_id
if client_id in terminated_client_ids:
continue
await self._terminate_inst_and_tasks(inst)
self.platform_insts.clear()
self._inst_map.clear()
self._platform_tasks.clear()
for inst in self.platform_insts:
if getattr(inst, "terminate", None):
await inst.terminate()
def get_insts(self):
return self.platform_insts
@@ -24,14 +24,3 @@ class PlatformMetadata:
module_path: str | None = None
"""注册该适配器的模块路径,用于插件热重载时清理"""
i18n_resources: dict[str, dict] | None = None
"""国际化资源数据,如 {"zh-CN": {...}, "en-US": {...}}
参考 https://github.com/AstrBotDevs/AstrBot/pull/5045
"""
config_metadata: dict | None = None
"""配置项元数据,用于 WebUI 生成表单。对应 config_metadata.json 的内容
参考 https://github.com/AstrBotDevs/AstrBot/pull/5045
"""
-5
View File
@@ -15,14 +15,11 @@ def register_platform_adapter(
adapter_display_name: str | None = None,
logo_path: str | None = None,
support_streaming_message: bool = True,
i18n_resources: dict[str, dict] | None = None,
config_metadata: dict | None = None,
):
"""用于注册平台适配器的带参装饰器。
default_config_tmpl 指定了平台适配器的默认配置模板用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类
logo_path 指定了平台适配器的 logo 文件路径是相对于插件目录的路径
config_metadata 指定了配置项的元数据用于 WebUI 生成表单如果不指定WebUI 将会把配置项渲染为原始的键值对编辑框
"""
def decorator(cls):
@@ -52,8 +49,6 @@ def register_platform_adapter(
logo_path=logo_path,
support_streaming_message=support_streaming_message,
module_path=module_path,
i18n_resources=i18n_resources,
config_metadata=config_metadata,
)
platform_registry.append(pm)
platform_cls_map[adapter_name] = cls
@@ -1,9 +1,8 @@
import asyncio
import json
import os
import threading
import uuid
from pathlib import Path
from typing import Literal, NoReturn, cast
from typing import NoReturn, cast
import aiohttp
import dingtalk_stream
@@ -11,7 +10,7 @@ from dingtalk_stream import AckMessage
from astrbot import logger
from astrbot.api.event import MessageChain
from astrbot.api.message_components import At, Image, Plain, Record, Video
from astrbot.api.message_components import At, Image, Plain
from astrbot.api.platform import (
AstrBotMessage,
MessageMember,
@@ -19,16 +18,9 @@ from astrbot.api.platform import (
Platform,
PlatformMetadata,
)
from astrbot.core import sp
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file
from astrbot.core.utils.media_utils import (
convert_audio_format,
convert_video_format,
extract_video_cover,
get_media_duration,
)
from ...register import register_platform_adapter
from .dingtalk_event import DingtalkMessageEvent
@@ -83,6 +75,8 @@ class DingtalkPlatformAdapter(Platform):
)
self.client_ = client # 用于 websockets 的 client
self._shutdown_event: threading.Event | None = None
self.card_template_id = platform_config.get("card_template_id")
self.card_instance_id_dict = {}
def _id_to_sid(self, dingtalk_id: str | None) -> str:
if not dingtalk_id:
@@ -97,44 +91,7 @@ class DingtalkPlatformAdapter(Platform):
session: MessageSesion,
message_chain: MessageChain,
) -> None:
robot_code = self.client_id
if session.message_type == MessageType.GROUP_MESSAGE:
open_conversation_id = session.session_id
await self.send_message_chain_to_group(
open_conversation_id=open_conversation_id,
robot_code=robot_code,
message_chain=message_chain,
)
else:
staff_id = await self._get_sender_staff_id(session)
if not staff_id:
logger.warning(
"钉钉私聊会话缺少 staff_id 映射,回退使用 session_id 作为 userId 发送",
)
staff_id = session.session_id
await self.send_message_chain_to_user(
staff_id=staff_id,
robot_code=robot_code,
message_chain=message_chain,
)
await super().send_by_session(session, message_chain)
async def send_with_session(
self,
session: MessageSesion,
message_chain: MessageChain,
) -> None:
await self.send_by_session(session, message_chain)
async def send_with_sesison(
self,
session: MessageSesion,
message_chain: MessageChain,
) -> None:
# backward typo compatibility
await self.send_by_session(session, message_chain)
raise NotImplementedError("钉钉机器人适配器不支持 send_by_session")
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
@@ -142,9 +99,67 @@ class DingtalkPlatformAdapter(Platform):
description="钉钉机器人官方 API 适配器",
id=cast(str, self.config.get("id")),
support_streaming_message=True,
support_proactive_message=True,
support_proactive_message=False,
)
async def create_message_card(
self, message_id: str, incoming_message: dingtalk_stream.ChatbotMessage
) -> bool | None:
if not self.card_template_id:
return False
card_instance = dingtalk_stream.AICardReplier(self.client_, incoming_message)
card_data = {"content": ""} # Initial content empty
try:
card_instance_id = await card_instance.async_create_and_deliver_card(
self.card_template_id,
card_data,
)
self.card_instance_id_dict[message_id] = (card_instance, card_instance_id)
return True
except Exception as e:
logger.error(f"创建钉钉卡片失败: {e}")
return False
async def send_card_message(
self, message_id: str, content: str, is_final: bool
) -> None:
if message_id not in self.card_instance_id_dict:
return
card_instance, card_instance_id = self.card_instance_id_dict[message_id]
content_key = "content"
try:
# 钉钉卡片流式更新
await card_instance.async_streaming(
card_instance_id,
content_key=content_key,
content_value=content,
append=False,
finished=is_final,
failed=False,
)
except Exception as e:
logger.error(f"发送钉钉卡片消息失败: {e}")
# Try to report failure
try:
await card_instance.async_streaming(
card_instance_id,
content_key=content_key,
content_value=content, # Keep existing content
append=False,
finished=True,
failed=True,
)
except Exception:
pass
if is_final:
self.card_instance_id_dict.pop(message_id, None)
async def convert_msg(
self,
message: dingtalk_stream.ChatbotMessage,
@@ -202,35 +217,8 @@ class DingtalkPlatformAdapter(Platform):
case "audio":
pass
await self._remember_sender_binding(message, abm)
return abm # 别忘了返回转换后的消息对象
async def _remember_sender_binding(
self,
message: dingtalk_stream.ChatbotMessage,
abm: AstrBotMessage,
) -> None:
try:
if abm.type == MessageType.FRIEND_MESSAGE:
sender_id = abm.sender.user_id
sender_staff_id = cast(str, message.sender_staff_id or "")
if sender_staff_id:
umo = str(
MessageSesion(
platform_name=self.meta().id,
message_type=abm.type,
session_id=sender_id,
)
)
await sp.put_async(
"global",
umo,
"dingtalk_staffid",
sender_staff_id,
)
except Exception as e:
logger.warning(f"保存钉钉会话映射失败: {e}")
async def download_ding_file(
self,
download_code: str,
@@ -253,9 +241,8 @@ class DingtalkPlatformAdapter(Platform):
"downloadCode": download_code,
"robotCode": robot_code,
}
temp_dir = Path(get_astrbot_temp_path())
temp_dir.mkdir(parents=True, exist_ok=True)
f_path = temp_dir / f"dingtalk_{uuid.uuid4()}.{ext}"
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
f_path = os.path.join(temp_dir, f"dingtalk_file_{uuid.uuid4()}.{ext}")
async with (
aiohttp.ClientSession() as session,
session.post(
@@ -271,21 +258,14 @@ class DingtalkPlatformAdapter(Platform):
return ""
resp_data = await resp.json()
download_url = resp_data["data"]["downloadUrl"]
await download_file(download_url, str(f_path))
return str(f_path)
await download_file(download_url, f_path)
return f_path
async def get_access_token(self) -> str:
try:
access_token = await asyncio.get_event_loop().run_in_executor(
None,
self.client_.get_access_token,
)
if access_token:
return access_token
except Exception as e:
logger.warning(f"通过 dingtalk_stream 获取 access_token 失败: {e}")
payload = {"appKey": self.client_id, "appSecret": self.client_secret}
payload = {
"appKey": self.client_id,
"appSecret": self.client_secret,
}
async with aiohttp.ClientSession() as session:
async with session.post(
"https://api.dingtalk.com/v1.0/oauth2/accessToken",
@@ -296,328 +276,7 @@ class DingtalkPlatformAdapter(Platform):
f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}",
)
return ""
data = await resp.json()
return cast(str, data.get("data", {}).get("accessToken", ""))
async def _get_sender_staff_id(self, session: MessageSesion) -> str:
try:
staff_id = await sp.get_async(
"global",
str(session),
"dingtalk_staffid",
"",
)
return cast(str, staff_id or "")
except Exception as e:
logger.warning(f"读取钉钉 staff_id 映射失败: {e}")
return ""
async def _send_group_message(
self,
open_conversation_id: str,
robot_code: str,
msg_key: str,
msg_param: dict,
) -> None:
access_token = await self.get_access_token()
if not access_token:
logger.error("钉钉群消息发送失败: access_token 为空")
return
payload = {
"msgKey": msg_key,
"msgParam": json.dumps(msg_param, ensure_ascii=False),
"openConversationId": open_conversation_id,
"robotCode": robot_code,
}
headers = {
"Content-Type": "application/json",
"x-acs-dingtalk-access-token": access_token,
}
async with aiohttp.ClientSession() as session:
async with session.post(
"https://api.dingtalk.com/v1.0/robot/groupMessages/send",
headers=headers,
json=payload,
) as resp:
if resp.status != 200:
logger.error(
f"钉钉群消息发送失败: {resp.status}, {await resp.text()}",
)
async def _send_private_message(
self,
staff_id: str,
robot_code: str,
msg_key: str,
msg_param: dict,
) -> None:
access_token = await self.get_access_token()
if not access_token:
logger.error("钉钉私聊消息发送失败: access_token 为空")
return
payload = {
"robotCode": robot_code,
"userIds": [staff_id],
"msgKey": msg_key,
"msgParam": json.dumps(msg_param, ensure_ascii=False),
}
headers = {
"Content-Type": "application/json",
"x-acs-dingtalk-access-token": access_token,
}
async with aiohttp.ClientSession() as session:
async with session.post(
"https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend",
headers=headers,
json=payload,
) as resp:
if resp.status != 200:
logger.error(
f"钉钉私聊消息发送失败: {resp.status}, {await resp.text()}",
)
def _safe_remove_file(self, file_path: str | None) -> None:
if not file_path:
return
try:
p = Path(file_path)
if p.exists() and p.is_file():
p.unlink()
except Exception as e:
logger.warning(f"清理临时文件失败: {file_path}, {e}")
async def _prepare_voice_for_dingtalk(self, input_path: str) -> tuple[str, bool]:
"""优先转换为 OGG(Opus),不可用时回退 AMR。"""
lower_path = input_path.lower()
if lower_path.endswith((".amr", ".ogg")):
return input_path, False
try:
converted = await convert_audio_format(input_path, "ogg")
return converted, converted != input_path
except Exception as e:
logger.warning(f"钉钉语音转 OGG 失败,回退 AMR: {e}")
converted = await convert_audio_format(input_path, "amr")
return converted, converted != input_path
async def upload_media(self, file_path: str, media_type: str) -> str:
media_file_path = Path(file_path)
access_token = await self.get_access_token()
if not access_token:
logger.error("钉钉媒体上传失败: access_token 为空")
return ""
form = aiohttp.FormData()
form.add_field(
"media",
media_file_path.read_bytes(),
filename=media_file_path.name,
content_type="application/octet-stream",
)
async with aiohttp.ClientSession() as session:
async with session.post(
f"https://oapi.dingtalk.com/media/upload?access_token={access_token}&type={media_type}",
data=form,
) as resp:
if resp.status != 200:
logger.error(
f"钉钉媒体上传失败: {resp.status}, {await resp.text()}"
)
return ""
data = await resp.json()
if data.get("errcode") != 0:
logger.error(f"钉钉媒体上传失败: {data}")
return ""
return cast(str, data.get("media_id", ""))
async def upload_image(self, image: Image) -> str:
image_file_path = await image.convert_to_file_path()
return await self.upload_media(image_file_path, "image")
async def _send_message_chain(
self,
target_type: Literal["group", "user"],
target_id: str,
robot_code: str,
message_chain: MessageChain,
at_str: str = "",
) -> None:
async def send_message(msg_key: str, msg_param: dict) -> None:
if target_type == "group":
await self._send_group_message(
open_conversation_id=target_id,
robot_code=robot_code,
msg_key=msg_key,
msg_param=msg_param,
)
else:
await self._send_private_message(
staff_id=target_id,
robot_code=robot_code,
msg_key=msg_key,
msg_param=msg_param,
)
for segment in message_chain.chain:
if isinstance(segment, Plain):
text = segment.text.strip()
if not text and not at_str:
continue
await send_message(
msg_key="sampleMarkdown",
msg_param={
"title": "AstrBot",
"text": f"{at_str} {text}".strip(),
},
)
elif isinstance(segment, Image):
photo_url = segment.file or segment.url or ""
if photo_url.startswith(("http://", "https://")):
pass
else:
photo_url = await self.upload_image(segment)
if not photo_url:
continue
await send_message(
msg_key="sampleImageMsg",
msg_param={"photoURL": photo_url},
)
elif isinstance(segment, Record):
converted_audio = None
try:
audio_path = await segment.convert_to_file_path()
(
audio_path,
converted_audio,
) = await self._prepare_voice_for_dingtalk(audio_path)
media_id = await self.upload_media(audio_path, "voice")
if not media_id:
continue
duration_ms = await get_media_duration(audio_path)
await send_message(
msg_key="sampleAudio",
msg_param={
"mediaId": media_id,
"duration": str(duration_ms or 1000),
},
)
except Exception as e:
logger.warning(f"钉钉语音发送失败: {e}")
continue
finally:
if converted_audio:
self._safe_remove_file(audio_path)
elif isinstance(segment, Video):
converted_video = False
cover_path = None
try:
source_video_path = await segment.convert_to_file_path()
video_path = source_video_path
if not video_path.lower().endswith(".mp4"):
video_path = await convert_video_format(video_path, "mp4")
converted_video = video_path != source_video_path
cover_path = await extract_video_cover(video_path)
video_media_id = await self.upload_media(video_path, "file")
pic_media_id = await self.upload_media(cover_path, "image")
if not video_media_id or not pic_media_id:
continue
duration_ms = await get_media_duration(video_path)
duration_sec = max(1, int((duration_ms or 1000) / 1000))
await send_message(
msg_key="sampleVideo",
msg_param={
"duration": str(duration_sec),
"videoMediaId": video_media_id,
"videoType": "mp4",
"picMediaId": pic_media_id,
},
)
except Exception as e:
logger.warning(f"钉钉视频发送失败: {e}")
continue
finally:
self._safe_remove_file(cover_path)
if converted_video:
self._safe_remove_file(video_path)
async def send_message_chain_to_group(
self,
open_conversation_id: str,
robot_code: str,
message_chain: MessageChain,
at_str: str = "",
) -> None:
await self._send_message_chain(
target_type="group",
target_id=open_conversation_id,
robot_code=robot_code,
message_chain=message_chain,
at_str=at_str,
)
async def send_message_chain_to_user(
self,
staff_id: str,
robot_code: str,
message_chain: MessageChain,
at_str: str = "",
) -> None:
await self._send_message_chain(
target_type="user",
target_id=staff_id,
robot_code=robot_code,
message_chain=message_chain,
at_str=at_str,
)
async def send_message_chain_with_incoming(
self,
incoming_message: dingtalk_stream.ChatbotMessage,
message_chain: MessageChain,
) -> None:
robot_code = self.client_id
# at_list: list[str] = []
sender_id = cast(str, incoming_message.sender_id or "")
sender_staff_id = cast(str, incoming_message.sender_staff_id or "")
normalized_sender_id = self._id_to_sid(sender_id)
# 现在用的发消息接口不支持 at
# for segment in message_chain.chain:
# if isinstance(segment, At):
# if (
# str(segment.qq) in {sender_id, normalized_sender_id}
# and sender_staff_id
# ):
# at_list.append(f"@{sender_staff_id}")
# else:
# at_list.append(f"@{segment.qq}")
# at_str = " ".join(at_list)
if incoming_message.conversation_type == "2":
await self.send_message_chain_to_group(
open_conversation_id=cast(str, incoming_message.conversation_id),
robot_code=robot_code,
message_chain=message_chain,
# at_str=at_str,
)
else:
session = MessageSesion(
platform_name=self.meta().id,
message_type=MessageType.FRIEND_MESSAGE,
session_id=normalized_sender_id,
)
staff_id = sender_staff_id or await self._get_sender_staff_id(session)
if not staff_id:
logger.error("钉钉私聊回复失败: 缺少 sender_staff_id")
return
await self.send_message_chain_to_user(
staff_id=staff_id,
robot_code=robot_code,
message_chain=message_chain,
# at_str=at_str,
)
return (await resp.json())["data"]["accessToken"]
async def handle_msg(self, abm: AstrBotMessage) -> None:
event = DingtalkMessageEvent(
@@ -1,5 +1,9 @@
from typing import Any
import asyncio
from typing import Any, cast
import dingtalk_stream
import astrbot.api.message_components as Comp
from astrbot import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
@@ -11,33 +15,128 @@ class DingtalkMessageEvent(AstrMessageEvent):
message_obj,
platform_meta,
session_id,
client: Any = None,
client: dingtalk_stream.ChatbotHandler,
adapter: "Any" = None,
) -> None:
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
self.adapter = adapter
async def send_with_client(
self,
client: dingtalk_stream.ChatbotHandler,
message: MessageChain,
) -> None:
icm = cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message)
ats = []
# fixes: #4218
# 钉钉 at 机器人需要使用 sender_staff_id 而不是 sender_id
for i in message.chain:
if isinstance(i, Comp.At):
print(i.qq, icm.sender_id, icm.sender_staff_id)
if str(i.qq) in str(icm.sender_id or ""):
# 适配器会将开头的 $:LWCP_v1:$ 去掉,因此我们用 in 判断
ats.append(f"@{icm.sender_staff_id}")
else:
ats.append(f"@{i.qq}")
at_str = " ".join(ats)
for segment in message.chain:
if isinstance(segment, Comp.Plain):
segment.text = segment.text.strip()
await asyncio.get_event_loop().run_in_executor(
None,
client.reply_markdown,
segment.text,
f"{at_str} {segment.text}".strip(),
cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message),
)
elif isinstance(segment, Comp.Image):
markdown_str = ""
try:
if not segment.file:
logger.warning("钉钉图片 segment 缺少 file 字段,跳过")
continue
if segment.file.startswith(("http://", "https://")):
image_url = segment.file
else:
image_url = await segment.register_to_file_service()
markdown_str = f"![image]({image_url})\n\n"
ret = await asyncio.get_event_loop().run_in_executor(
None,
client.reply_markdown,
"😄",
markdown_str,
cast(
dingtalk_stream.ChatbotMessage, self.message_obj.raw_message
),
)
logger.debug(f"send image: {ret}")
except Exception as e:
logger.warning(f"钉钉图片处理失败: {e}, 跳过图片发送")
continue
async def send(self, message: MessageChain) -> None:
if not self.adapter:
logger.error("钉钉消息发送失败: 缺少 adapter")
return
await self.adapter.send_message_chain_with_incoming(
incoming_message=self.message_obj.raw_message,
message_chain=message,
)
await self.send_with_client(self.client, message)
await super().send(message)
async def send_streaming(self, generator, use_fallback: bool = False):
# 钉钉统一回退为缓冲发送:最终发送仍使用新的 HTTP 消息接口。
buffer = None
async for chain in generator:
if not self.adapter or not self.adapter.card_template_id:
logger.warning(
f"DingTalk streaming is enabled, but 'card_template_id' is not configured for platform '{self.platform_meta.id}'. Falling back to text streaming."
)
# Fallback to default behavior (buffer and send)
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return None
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
return None
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
# Create card
msg_id = self.message_obj.message_id
incoming_msg = self.message_obj.raw_message
created = await self.adapter.create_message_card(msg_id, incoming_msg)
if not created:
# Fallback to default behavior (buffer and send)
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return None
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
full_content = ""
seq = 0
try:
async for chain in generator:
for segment in chain.chain:
if isinstance(segment, Comp.Plain):
full_content += segment.text
seq += 1
if seq % 2 == 0: # Update every 2 chunks to be more responsive than 8
await self.adapter.send_card_message(
msg_id, full_content, is_final=False
)
await self.adapter.send_card_message(msg_id, full_content, is_final=True)
except Exception as e:
logger.error(f"DingTalk streaming error: {e}")
# Try to ensure final state is sent or cleaned up?
await self.adapter.send_card_message(msg_id, full_content, is_final=True)
@@ -3,13 +3,10 @@ import base64
import json
import re
import time
from pathlib import Path
from typing import Any, cast
from uuid import uuid4
import lark_oapi as lark
from lark_oapi.api.im.v1 import (
GetMessageRequest,
GetMessageResourceRequest,
)
from lark_oapi.api.im.v1.processor import P2ImMessageReceiveV1Processor
@@ -25,7 +22,6 @@ from astrbot.api.platform import (
PlatformMetadata,
)
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.webhook_utils import log_webhook_info
from ...register import register_platform_adapter
@@ -95,347 +91,6 @@ class LarkPlatformAdapter(Platform):
self.event_id_timestamps: dict[str, float] = {}
async def _download_message_resource(
self,
*,
message_id: str,
file_key: str,
resource_type: str,
) -> bytes | None:
if self.lark_api.im is None:
logger.error("[Lark] API Client im 模块未初始化")
return None
request = (
GetMessageResourceRequest.builder()
.message_id(message_id)
.file_key(file_key)
.type(resource_type)
.build()
)
response = await self.lark_api.im.v1.message_resource.aget(request)
if not response.success():
logger.error(
f"[Lark] 下载消息资源失败 type={resource_type}, key={file_key}, "
f"code={response.code}, msg={response.msg}",
)
return None
if response.file is None:
logger.error(f"[Lark] 消息资源响应中不包含文件流: {file_key}")
return None
return response.file.read()
@staticmethod
def _build_message_str_from_components(
components: list[Comp.BaseMessageComponent],
) -> str:
parts: list[str] = []
for comp in components:
if isinstance(comp, Comp.Plain):
text = comp.text.strip()
if text:
parts.append(text)
elif isinstance(comp, Comp.At):
name = str(comp.name or comp.qq or "").strip()
if name:
parts.append(f"@{name}")
elif isinstance(comp, Comp.Image):
parts.append("[image]")
elif isinstance(comp, Comp.File):
parts.append(str(comp.name or "[file]"))
elif isinstance(comp, Comp.Record):
parts.append("[audio]")
elif isinstance(comp, Comp.Video):
parts.append("[video]")
return " ".join(parts).strip()
@staticmethod
def _parse_post_content(content: dict[str, Any]) -> list[dict[str, Any]]:
result: list[dict[str, Any]] = []
for item in content.get("content", []):
if isinstance(item, list):
for comp in item:
if isinstance(comp, dict):
result.append(comp)
elif isinstance(item, dict):
result.append(item)
return result
@staticmethod
def _build_at_map(mentions: list[Any] | None) -> dict[str, Comp.At]:
at_map: dict[str, Comp.At] = {}
if not mentions:
return at_map
for mention in mentions:
key = getattr(mention, "key", None)
if not key:
continue
mention_id = getattr(mention, "id", None)
open_id = ""
if mention_id is not None:
if hasattr(mention_id, "open_id"):
open_id = getattr(mention_id, "open_id", "") or ""
else:
open_id = str(mention_id)
mention_name = str(getattr(mention, "name", "") or "")
at_map[key] = Comp.At(qq=open_id, name=mention_name)
return at_map
async def _parse_message_components(
self,
*,
message_id: str | None,
message_type: str,
content: dict[str, Any],
at_map: dict[str, Comp.At],
) -> list[Comp.BaseMessageComponent]:
components: list[Comp.BaseMessageComponent] = []
if message_type == "text":
message_str_raw = str(content.get("text", ""))
at_pattern = r"(@_user_\d+)"
parts = re.split(at_pattern, message_str_raw)
for part in parts:
segment = part.strip()
if not segment:
continue
if segment in at_map:
components.append(at_map[segment])
else:
components.append(Comp.Plain(segment))
return components
if message_type in ("post", "image"):
if message_type == "image":
comp_list = [
{
"tag": "img",
"image_key": content.get("image_key"),
},
]
else:
comp_list = self._parse_post_content(content)
for comp in comp_list:
tag = comp.get("tag")
if tag == "at":
user_key = str(comp.get("user_id", ""))
if user_key in at_map:
components.append(at_map[user_key])
elif tag == "text":
text = str(comp.get("text", "")).strip()
if text:
components.append(Comp.Plain(text))
elif tag == "a":
text = str(comp.get("text", "")).strip()
href = str(comp.get("href", "")).strip()
if text and href:
components.append(Comp.Plain(f"{text}({href})"))
elif text:
components.append(Comp.Plain(text))
elif tag == "img":
image_key = str(comp.get("image_key", "")).strip()
if not image_key:
continue
if not message_id:
logger.error("[Lark] 图片消息缺少 message_id")
continue
image_bytes = await self._download_message_resource(
message_id=message_id,
file_key=image_key,
resource_type="image",
)
if image_bytes is None:
continue
image_base64 = base64.b64encode(image_bytes).decode()
components.append(Comp.Image.fromBase64(image_base64))
elif tag == "media":
file_key = str(comp.get("file_key", "")).strip()
file_name = (
str(comp.get("file_name", "")).strip() or "lark_media.mp4"
)
if not file_key:
continue
if not message_id:
logger.error("[Lark] 富文本视频消息缺少 message_id")
continue
file_path = await self._download_file_resource_to_temp(
message_id=message_id,
file_key=file_key,
message_type="post_media",
file_name=file_name,
default_suffix=".mp4",
)
if file_path:
components.append(Comp.Video(file=file_path, path=file_path))
return components
if message_type == "file":
file_key = str(content.get("file_key", "")).strip()
file_name = str(content.get("file_name", "")).strip() or "lark_file"
if not message_id:
logger.error("[Lark] 文件消息缺少 message_id")
return components
if not file_key:
logger.error("[Lark] 文件消息缺少 file_key")
return components
file_path = await self._download_file_resource_to_temp(
message_id=message_id,
file_key=file_key,
message_type="file",
file_name=file_name,
)
if file_path:
components.append(Comp.File(name=file_name, file=file_path))
return components
if message_type == "audio":
file_key = str(content.get("file_key", "")).strip()
if not message_id:
logger.error("[Lark] 音频消息缺少 message_id")
return components
if not file_key:
logger.error("[Lark] 音频消息缺少 file_key")
return components
file_path = await self._download_file_resource_to_temp(
message_id=message_id,
file_key=file_key,
message_type="audio",
default_suffix=".opus",
)
if file_path:
components.append(Comp.Record(file=file_path, url=file_path))
return components
if message_type == "media":
file_key = str(content.get("file_key", "")).strip()
file_name = str(content.get("file_name", "")).strip() or "lark_media.mp4"
if not message_id:
logger.error("[Lark] 视频消息缺少 message_id")
return components
if not file_key:
logger.error("[Lark] 视频消息缺少 file_key")
return components
file_path = await self._download_file_resource_to_temp(
message_id=message_id,
file_key=file_key,
message_type="media",
file_name=file_name,
default_suffix=".mp4",
)
if file_path:
components.append(Comp.Video(file=file_path, path=file_path))
return components
return components
async def _build_reply_from_parent_id(
self,
parent_message_id: str,
) -> Comp.Reply | None:
if self.lark_api.im is None:
logger.error("[Lark] API Client im 模块未初始化")
return None
request = GetMessageRequest.builder().message_id(parent_message_id).build()
response = await self.lark_api.im.v1.message.aget(request)
if not response.success():
logger.error(
f"[Lark] 获取引用消息失败 id={parent_message_id}, "
f"code={response.code}, msg={response.msg}",
)
return None
if response.data is None or not response.data.items:
logger.error(
f"[Lark] 引用消息响应为空 id={parent_message_id}",
)
return None
parent_message = response.data.items[0]
quoted_message_id = parent_message.message_id or parent_message_id
quoted_sender_id = (
parent_message.sender.id
if parent_message.sender and parent_message.sender.id
else "unknown"
)
quoted_time_raw = parent_message.create_time or 0
quoted_time = (
quoted_time_raw // 1000
if isinstance(quoted_time_raw, int) and quoted_time_raw > 10**11
else quoted_time_raw
)
quoted_content = (
parent_message.body.content if parent_message.body else ""
) or ""
quoted_type = parent_message.msg_type or ""
quoted_content_json: dict[str, Any] = {}
if quoted_content:
try:
parsed = json.loads(quoted_content)
if isinstance(parsed, dict):
quoted_content_json = parsed
except json.JSONDecodeError:
logger.warning(
f"[Lark] 解析引用消息内容失败 id={quoted_message_id}",
)
quoted_at_map = self._build_at_map(parent_message.mentions)
quoted_chain = await self._parse_message_components(
message_id=quoted_message_id,
message_type=quoted_type,
content=quoted_content_json,
at_map=quoted_at_map,
)
quoted_text = self._build_message_str_from_components(quoted_chain)
sender_nickname = (
quoted_sender_id[:8] if quoted_sender_id != "unknown" else "unknown"
)
return Comp.Reply(
id=quoted_message_id,
chain=quoted_chain,
sender_id=quoted_sender_id,
sender_nickname=sender_nickname,
time=quoted_time,
message_str=quoted_text,
text=quoted_text,
)
async def _download_file_resource_to_temp(
self,
*,
message_id: str,
file_key: str,
message_type: str,
file_name: str = "",
default_suffix: str = ".bin",
) -> str | None:
file_bytes = await self._download_message_resource(
message_id=message_id,
file_key=file_key,
resource_type="file",
)
if file_bytes is None:
return None
suffix = Path(file_name).suffix if file_name else default_suffix
temp_dir = Path(get_astrbot_temp_path())
temp_dir.mkdir(parents=True, exist_ok=True)
temp_path = (
temp_dir / f"lark_{message_type}_{file_name}_{uuid4().hex[:4]}{suffix}"
)
temp_path.write_bytes(file_bytes)
return str(temp_path.resolve())
def _clean_expired_events(self) -> None:
"""清理超过 30 分钟的事件记录"""
current_time = time.time()
@@ -521,11 +176,6 @@ class LarkPlatformAdapter(Platform):
abm.message_str = ""
at_list = {}
if message.parent_id:
reply_seg = await self._build_reply_from_parent_id(message.parent_id)
if reply_seg:
abm.message.append(reply_seg)
if message.mentions:
for m in message.mentions:
if m.id is None:
@@ -548,19 +198,80 @@ class LarkPlatformAdapter(Platform):
logger.error(f"[Lark] 解析消息内容失败: {message.content}")
return
if not isinstance(content_json_b, dict):
logger.error(f"[Lark] 消息内容不是 JSON Object: {message.content}")
return
if message.message_type == "text":
message_str_raw = content_json_b.get("text", "") # 带有 @ 的消息
at_pattern = r"(@_user_\d+)" # 可以根据需求修改正则
# at_users = re.findall(at_pattern, message_str_raw)
# 拆分文本,去掉AT符号部分
parts = re.split(at_pattern, message_str_raw)
for i in range(len(parts)):
s = parts[i].strip()
if not s:
continue
if s in at_list:
abm.message.append(at_list[s])
else:
abm.message.append(Comp.Plain(parts[i].strip()))
elif message.message_type == "post":
_ls = []
logger.debug(f"[Lark] 解析消息内容: {content_json_b}")
parsed_components = await self._parse_message_components(
message_id=message.message_id,
message_type=message.message_type or "unknown",
content=content_json_b,
at_map=at_list,
)
abm.message.extend(parsed_components)
abm.message_str = self._build_message_str_from_components(parsed_components)
content_ls = content_json_b.get("content", [])
for comp in content_ls:
if isinstance(comp, list):
_ls.extend(comp)
elif isinstance(comp, dict):
_ls.append(comp)
content_json_b = _ls
elif message.message_type == "image":
content_json_b = [
{
"tag": "img",
"image_key": content_json_b.get("image_key"),
"style": [],
},
]
if message.message_type in ("post", "image"):
for comp in content_json_b:
if comp.get("tag") == "at":
user_id = comp.get("user_id")
if user_id in at_list:
abm.message.append(at_list[user_id])
elif comp.get("tag") == "text" and comp.get("text", "").strip():
abm.message.append(Comp.Plain(comp["text"].strip()))
elif comp.get("tag") == "img":
image_key = comp.get("image_key")
if not image_key:
continue
request = (
GetMessageResourceRequest.builder()
.message_id(cast(str, message.message_id))
.file_key(image_key)
.type("image")
.build()
)
if self.lark_api.im is None:
logger.error("[Lark] API Client im 模块未初始化")
continue
response = await self.lark_api.im.v1.message_resource.aget(request)
if not response.success():
logger.error(f"无法下载飞书图片: {image_key}")
continue
if response.file is None:
logger.error(f"飞书图片响应中不包含文件流: {image_key}")
continue
image_bytes = response.file.read()
image_base64 = base64.b64encode(image_bytes).decode()
abm.message.append(Comp.Image.fromBase64(image_base64))
for comp in abm.message:
if isinstance(comp, Comp.Plain):
abm.message_str += comp.text
if message.message_id is None:
logger.error("[Lark] 消息缺少 message_id")
@@ -585,6 +296,7 @@ class LarkPlatformAdapter(Platform):
else:
abm.session_id = abm.sender.user_id
logger.debug(abm)
await self.handle_msg(abm)
async def handle_msg(self, abm: AstrBotMessage) -> None:
@@ -21,7 +21,7 @@ from astrbot import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import At, File, Plain, Record, Video
from astrbot.api.message_components import Image as AstrBotImage
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.utils.media_utils import (
convert_audio_to_opus,
@@ -202,11 +202,8 @@ class LarkMessageEvent(AstrMessageEvent):
base64_str = comp.file.removeprefix("base64://")
image_data = base64.b64decode(base64_str)
# save as temp file
temp_dir = get_astrbot_temp_path()
file_path = os.path.join(
temp_dir,
f"lark_image_{uuid.uuid4().hex[:8]}.jpg",
)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
file_path = os.path.join(temp_dir, f"{uuid.uuid4()}_test.jpg")
with open(file_path, "wb") as f:
f.write(BytesIO(image_data).getvalue())
else:
@@ -1,474 +0,0 @@
import asyncio
import mimetypes
import time
import uuid
from pathlib import Path
from typing import Any, cast
from astrbot.api import logger
from astrbot.api.event import MessageChain
from astrbot.api.message_components import At, File, Image, Plain, Record, Video
from astrbot.api.platform import (
AstrBotMessage,
Group,
MessageMember,
MessageType,
Platform,
PlatformMetadata,
)
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.webhook_utils import log_webhook_info
from ...register import register_platform_adapter
from .line_api import LineAPIClient
from .line_event import LineMessageEvent
LINE_CONFIG_METADATA = {
"channel_access_token": {
"description": "LINE Channel Access Token",
"type": "string",
"hint": "LINE Messaging API 的 channel access token。",
},
"channel_secret": {
"description": "LINE Channel Secret",
"type": "string",
"hint": "用于校验 LINE Webhook 签名。",
},
}
LINE_I18N_RESOURCES = {
"zh-CN": {
"channel_access_token": {
"description": "LINE Channel Access Token",
"hint": "LINE Messaging API 的 channel access token。",
},
"channel_secret": {
"description": "LINE Channel Secret",
"hint": "用于校验 LINE Webhook 签名。",
},
},
"en-US": {
"channel_access_token": {
"description": "LINE Channel Access Token",
"hint": "Channel access token for LINE Messaging API.",
},
"channel_secret": {
"description": "LINE Channel Secret",
"hint": "Used to verify LINE webhook signatures.",
},
},
}
@register_platform_adapter(
"line",
"LINE Messaging API 适配器",
support_streaming_message=False,
default_config_tmpl={
"id": "line",
"type": "line",
"enable": False,
"channel_access_token": "",
"channel_secret": "",
"unified_webhook_mode": True,
"webhook_uuid": "",
},
config_metadata=LINE_CONFIG_METADATA,
i18n_resources=LINE_I18N_RESOURCES,
)
class LinePlatformAdapter(Platform):
def __init__(
self,
platform_config: dict,
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
super().__init__(platform_config, event_queue)
self.config["unified_webhook_mode"] = True
self.destination = "unknown"
self.settings = platform_settings
self._event_id_timestamps: dict[str, float] = {}
self.shutdown_event = asyncio.Event()
channel_access_token = str(platform_config.get("channel_access_token", ""))
channel_secret = str(platform_config.get("channel_secret", ""))
if not channel_access_token or not channel_secret:
raise ValueError(
"LINE 适配器需要 channel_access_token 和 channel_secret。",
)
self.line_api = LineAPIClient(
channel_access_token=channel_access_token,
channel_secret=channel_secret,
)
async def send_by_session(
self,
session: MessageSesion,
message_chain: MessageChain,
) -> None:
messages = await LineMessageEvent.build_line_messages(message_chain)
if messages:
await self.line_api.push_message(session.session_id, messages)
await super().send_by_session(session, message_chain)
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
name="line",
description="LINE Messaging API 适配器",
id=cast(str, self.config.get("id", "line")),
support_streaming_message=False,
)
async def run(self) -> None:
webhook_uuid = self.config.get("webhook_uuid")
if webhook_uuid:
log_webhook_info(f"{self.meta().id}(LINE)", webhook_uuid)
else:
logger.warning("[LINE] webhook_uuid 为空,统一 Webhook 可能无法接收消息。")
await self.shutdown_event.wait()
async def terminate(self) -> None:
self.shutdown_event.set()
await self.line_api.close()
async def webhook_callback(self, request: Any) -> Any:
raw_body = await request.get_data()
signature = request.headers.get("x-line-signature")
if not self.line_api.verify_signature(raw_body, signature):
logger.warning("[LINE] invalid webhook signature")
return "invalid signature", 400
try:
payload = await request.get_json(force=True, silent=False)
except Exception as e:
logger.warning("[LINE] invalid webhook body: %s", e)
return "bad request", 400
if not isinstance(payload, dict):
return "bad request", 400
await self.handle_webhook_event(payload)
return "ok", 200
async def handle_webhook_event(self, payload: dict[str, Any]) -> None:
destination = str(payload.get("destination", "")).strip()
if destination:
self.destination = destination
events = payload.get("events")
if not isinstance(events, list):
return
for event in events:
if not isinstance(event, dict):
continue
event_id = str(event.get("webhookEventId", ""))
if event_id and self._is_duplicate_event(event_id):
logger.debug("[LINE] duplicate event skipped: %s", event_id)
continue
abm = await self.convert_message(event)
if abm is None:
continue
await self.handle_msg(abm)
async def convert_message(self, event: dict[str, Any]) -> AstrBotMessage | None:
if str(event.get("type", "")) != "message":
return None
if str(event.get("mode", "active")) == "standby":
return None
source = event.get("source", {})
if not isinstance(source, dict):
return None
message = event.get("message", {})
if not isinstance(message, dict):
return None
source_type = str(source.get("type", ""))
user_id = str(source.get("userId", "")).strip()
group_id = str(source.get("groupId", "")).strip()
room_id = str(source.get("roomId", "")).strip()
abm = AstrBotMessage()
abm.self_id = self.destination or self.meta().id
abm.message = []
abm.raw_message = event
abm.message_id = str(
message.get("id")
or event.get("webhookEventId")
or event.get("deliveryContext", {}).get("deliveryId", "")
or uuid.uuid4().hex
)
event_timestamp = event.get("timestamp")
if isinstance(event_timestamp, int):
abm.timestamp = (
event_timestamp // 1000
if event_timestamp > 1_000_000_000_000
else event_timestamp
)
else:
abm.timestamp = int(time.time())
if source_type in {"group", "room"}:
abm.type = MessageType.GROUP_MESSAGE
container_id = group_id or room_id
abm.group = Group(group_id=container_id, group_name=container_id)
abm.session_id = container_id
sender_id = user_id or container_id
elif source_type == "user":
abm.type = MessageType.FRIEND_MESSAGE
abm.session_id = user_id
sender_id = user_id
else:
abm.type = MessageType.OTHER_MESSAGE
abm.session_id = user_id or group_id or room_id or "unknown"
sender_id = abm.session_id
abm.sender = MessageMember(user_id=sender_id, nickname=sender_id[:8])
components = await self._parse_line_message_components(message)
if not components:
return None
abm.message = components
abm.message_str = self._build_message_str(components)
return abm
async def _parse_line_message_components(
self,
message: dict[str, Any],
) -> list:
msg_type = str(message.get("type", ""))
message_id = str(message.get("id", "")).strip()
if msg_type == "text":
text = str(message.get("text", ""))
mention = message.get("mention")
if isinstance(mention, dict):
return self._parse_text_with_mentions(text, mention)
return [Plain(text=text)] if text else []
if msg_type == "image":
image_component = await self._build_image_component(message_id, message)
return [image_component] if image_component else [Plain(text="[image]")]
if msg_type == "video":
video_component = await self._build_video_component(message_id, message)
return [video_component] if video_component else [Plain(text="[video]")]
if msg_type == "audio":
audio_component = await self._build_audio_component(message_id, message)
return [audio_component] if audio_component else [Plain(text="[audio]")]
if msg_type == "file":
file_component = await self._build_file_component(message_id, message)
return [file_component] if file_component else [Plain(text="[file]")]
if msg_type == "sticker":
return [Plain(text="[sticker]")]
return [Plain(text=f"[{msg_type}]")]
def _parse_text_with_mentions(self, text: str, mention_obj: dict[str, Any]) -> list:
mentions = mention_obj.get("mentionees", [])
if not isinstance(mentions, list) or not mentions:
return [Plain(text=text)] if text else []
normalized = []
for item in mentions:
if not isinstance(item, dict):
continue
start = item.get("index")
length = item.get("length")
if not isinstance(start, int) or not isinstance(length, int):
continue
normalized.append((start, length, item))
normalized.sort(key=lambda x: x[0])
ret = []
cursor = 0
for start, length, item in normalized:
if start > cursor:
part = text[cursor:start]
if part:
ret.append(Plain(text=part))
label = text[start : start + length] or "@user"
mention_type = str(item.get("type", ""))
if mention_type == "user":
target_id = str(item.get("userId", "")).strip()
ret.append(At(qq=target_id, name=label.lstrip("@")))
else:
ret.append(Plain(text=label))
cursor = max(cursor, start + length)
if cursor < len(text):
tail = text[cursor:]
if tail:
ret.append(Plain(text=tail))
return ret
async def _build_image_component(
self,
message_id: str,
message: dict[str, Any],
) -> Image | None:
external_url = self._get_external_content_url(message)
if external_url:
return Image.fromURL(external_url)
content = await self.line_api.get_message_content(message_id)
if not content:
return None
content_bytes, _, _ = content
return Image.fromBytes(content_bytes)
async def _build_video_component(
self,
message_id: str,
message: dict[str, Any],
) -> Video | None:
external_url = self._get_external_content_url(message)
if external_url:
return Video.fromURL(external_url)
content = await self.line_api.get_message_content(message_id)
if not content:
return None
content_bytes, content_type, _ = content
suffix = self._guess_suffix(content_type, ".mp4")
file_path = self._store_temp_content("video", message_id, content_bytes, suffix)
return Video(file=file_path, path=file_path)
async def _build_audio_component(
self,
message_id: str,
message: dict[str, Any],
) -> Record | None:
external_url = self._get_external_content_url(message)
if external_url:
return Record.fromURL(external_url)
content = await self.line_api.get_message_content(message_id)
if not content:
return None
content_bytes, content_type, _ = content
suffix = self._guess_suffix(content_type, ".m4a")
file_path = self._store_temp_content("audio", message_id, content_bytes, suffix)
return Record(file=file_path, url=file_path)
async def _build_file_component(
self,
message_id: str,
message: dict[str, Any],
) -> File | None:
content = await self.line_api.get_message_content(message_id)
if not content:
return None
content_bytes, content_type, filename = content
default_name = str(message.get("fileName", "")).strip() or f"{message_id}.bin"
suffix = Path(default_name).suffix or self._guess_suffix(content_type, ".bin")
final_name = filename or default_name
file_path = self._store_temp_content(
"file",
message_id,
content_bytes,
suffix,
original_name=final_name,
)
return File(name=final_name, file=file_path, url=file_path)
@staticmethod
def _get_external_content_url(message: dict[str, Any]) -> str:
provider = message.get("contentProvider")
if not isinstance(provider, dict):
return ""
if str(provider.get("type", "")) != "external":
return ""
return str(provider.get("originalContentUrl", "")).strip()
@staticmethod
def _guess_suffix(content_type: str | None, fallback: str) -> str:
if not content_type:
return fallback
base_type = content_type.split(";", 1)[0].strip().lower()
guessed = mimetypes.guess_extension(base_type)
if guessed:
return guessed
return fallback
@staticmethod
def _store_temp_content(
content_type: str,
message_id: str,
content: bytes,
suffix: str,
original_name: str = "",
) -> str:
temp_dir = Path(get_astrbot_temp_path())
temp_dir.mkdir(parents=True, exist_ok=True)
name_prefix = f"line_{content_type}"
if original_name:
safe_stem = Path(original_name).stem.strip()
safe_stem = "".join(
ch if ch.isalnum() or ch in ("-", "_", ".") else "_" for ch in safe_stem
)
safe_stem = safe_stem.strip("._")
if safe_stem:
name_prefix = safe_stem[:64]
file_path = temp_dir / f"{name_prefix}_{message_id}_{uuid.uuid4().hex[:6]}"
file_path = file_path.with_suffix(suffix)
file_path.write_bytes(content)
return str(file_path.resolve())
@staticmethod
def _build_message_str(components: list) -> str:
parts: list[str] = []
for comp in components:
if isinstance(comp, Plain):
parts.append(comp.text)
elif isinstance(comp, At):
parts.append(f"@{comp.name or comp.qq}")
elif isinstance(comp, Image):
parts.append("[image]")
elif isinstance(comp, Video):
parts.append("[video]")
elif isinstance(comp, Record):
parts.append("[audio]")
elif isinstance(comp, File):
parts.append(str(comp.name or "[file]"))
else:
parts.append(f"[{comp.type}]")
return " ".join(i for i in parts if i).strip()
def _clean_expired_events(self) -> None:
current = time.time()
expired = [
event_id
for event_id, ts in self._event_id_timestamps.items()
if current - ts > 1800
]
for event_id in expired:
del self._event_id_timestamps[event_id]
def _is_duplicate_event(self, event_id: str) -> bool:
self._clean_expired_events()
if event_id in self._event_id_timestamps:
return True
self._event_id_timestamps[event_id] = time.time()
return False
async def handle_msg(self, abm: AstrBotMessage) -> None:
event = LineMessageEvent(
message_str=abm.message_str,
message_obj=abm,
platform_meta=self.meta(),
session_id=abm.session_id,
line_api=self.line_api,
)
self._event_queue.put_nowait(event)
@@ -1,203 +0,0 @@
import asyncio
import base64
import hmac
import json
from hashlib import sha256
from typing import Any
from urllib.parse import unquote
import aiohttp
from astrbot.api import logger
class LineAPIClient:
def __init__(
self,
*,
channel_access_token: str,
channel_secret: str,
timeout_seconds: int = 30,
) -> None:
self.channel_access_token = channel_access_token.strip()
self.channel_secret = channel_secret.strip()
self.timeout = aiohttp.ClientTimeout(total=timeout_seconds)
self._session: aiohttp.ClientSession | None = None
async def _get_session(self) -> aiohttp.ClientSession:
if self._session is None or self._session.closed:
self._session = aiohttp.ClientSession(timeout=self.timeout)
return self._session
async def close(self) -> None:
if self._session and not self._session.closed:
await self._session.close()
def verify_signature(self, raw_body: bytes, signature: str | None) -> bool:
if not signature:
return False
digest = hmac.new(
self.channel_secret.encode("utf-8"),
raw_body,
sha256,
).digest()
expected = base64.b64encode(digest).decode("utf-8")
return hmac.compare_digest(expected, signature.strip())
@property
def _auth_headers(self) -> dict[str, str]:
return {"Authorization": f"Bearer {self.channel_access_token}"}
async def reply_message(
self,
reply_token: str,
messages: list[dict[str, Any]],
*,
notification_disabled: bool = False,
) -> bool:
payload = {
"replyToken": reply_token,
"messages": messages[:5],
"notificationDisabled": notification_disabled,
}
return await self._post_json(
"https://api.line.me/v2/bot/message/reply",
payload=payload,
op_name="reply",
)
async def push_message(
self,
to: str,
messages: list[dict[str, Any]],
*,
notification_disabled: bool = False,
) -> bool:
payload = {
"to": to,
"messages": messages[:5],
"notificationDisabled": notification_disabled,
}
return await self._post_json(
"https://api.line.me/v2/bot/message/push",
payload=payload,
op_name="push",
)
async def _post_json(
self,
url: str,
*,
payload: dict[str, Any],
op_name: str,
) -> bool:
session = await self._get_session()
headers = {
**self._auth_headers,
"Content-Type": "application/json",
}
try:
async with session.post(url, json=payload, headers=headers) as resp:
if resp.status < 400:
return True
body = await resp.text()
logger.error(
"[LINE] %s message failed: status=%s body=%s",
op_name,
resp.status,
body,
)
return False
except Exception as e:
logger.error("[LINE] %s message request failed: %s", op_name, e)
return False
async def get_message_content(
self,
message_id: str,
) -> tuple[bytes, str | None, str | None] | None:
session = await self._get_session()
url = f"https://api-data.line.me/v2/bot/message/{message_id}/content"
headers = self._auth_headers
async with session.get(url, headers=headers) as resp:
if resp.status == 202:
if not await self._wait_for_transcoding(message_id):
return None
async with session.get(url, headers=headers) as retry_resp:
if retry_resp.status != 200:
body = await retry_resp.text()
logger.warning(
"[LINE] get content retry failed: message_id=%s status=%s body=%s",
message_id,
retry_resp.status,
body,
)
return None
return await self._read_content_response(retry_resp)
if resp.status != 200:
body = await resp.text()
logger.warning(
"[LINE] get content failed: message_id=%s status=%s body=%s",
message_id,
resp.status,
body,
)
return None
return await self._read_content_response(resp)
async def _read_content_response(
self,
resp: aiohttp.ClientResponse,
) -> tuple[bytes, str | None, str | None]:
content = await resp.read()
content_type = resp.headers.get("Content-Type")
disposition = resp.headers.get("Content-Disposition")
filename = self._extract_filename_from_disposition(disposition)
return content, content_type, filename
def _extract_filename_from_disposition(self, disposition: str | None) -> str | None:
if not disposition:
return None
for part in disposition.split(";"):
token = part.strip()
if token.startswith("filename*="):
val = token.split("=", 1)[1].strip().strip('"')
if val.lower().startswith("utf-8''"):
val = val[7:]
return unquote(val)
if token.startswith("filename="):
return token.split("=", 1)[1].strip().strip('"')
return None
async def _wait_for_transcoding(
self,
message_id: str,
*,
max_attempts: int = 10,
interval_seconds: float = 1.0,
) -> bool:
session = await self._get_session()
url = (
f"https://api-data.line.me/v2/bot/message/{message_id}/content/transcoding"
)
headers = self._auth_headers
for _ in range(max_attempts):
try:
async with session.get(url, headers=headers) as resp:
if resp.status != 200:
await asyncio.sleep(interval_seconds)
continue
body = await resp.text()
data = json.loads(body)
status = str(data.get("status", "")).lower()
if status == "succeeded":
return True
if status == "failed":
return False
except Exception:
pass
await asyncio.sleep(interval_seconds)
return False
@@ -1,285 +0,0 @@
import asyncio
import os
import re
import uuid
from collections.abc import AsyncGenerator
from pathlib import Path
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import (
At,
BaseMessageComponent,
File,
Image,
Plain,
Record,
Video,
)
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.media_utils import get_media_duration
from .line_api import LineAPIClient
class LineMessageEvent(AstrMessageEvent):
def __init__(
self,
message_str,
message_obj,
platform_meta,
session_id,
line_api: LineAPIClient,
) -> None:
super().__init__(message_str, message_obj, platform_meta, session_id)
self.line_api = line_api
@staticmethod
async def _component_to_message_object(
segment: BaseMessageComponent,
) -> dict | None:
if isinstance(segment, Plain):
text = segment.text.strip()
if not text:
return None
return {"type": "text", "text": text[:5000]}
if isinstance(segment, At):
name = str(segment.name or segment.qq or "").strip()
if not name:
return None
return {"type": "text", "text": f"@{name}"[:5000]}
if isinstance(segment, Image):
image_url = await LineMessageEvent._resolve_image_url(segment)
if not image_url:
return None
return {
"type": "image",
"originalContentUrl": image_url,
"previewImageUrl": image_url,
}
if isinstance(segment, Record):
audio_url = await LineMessageEvent._resolve_record_url(segment)
if not audio_url:
return None
duration = await LineMessageEvent._resolve_record_duration(segment)
return {
"type": "audio",
"originalContentUrl": audio_url,
"duration": duration,
}
if isinstance(segment, Video):
video_url = await LineMessageEvent._resolve_video_url(segment)
if not video_url:
return None
preview_url = await LineMessageEvent._resolve_video_preview_url(segment)
if not preview_url:
return None
return {
"type": "video",
"originalContentUrl": video_url,
"previewImageUrl": preview_url,
}
if isinstance(segment, File):
file_url = await LineMessageEvent._resolve_file_url(segment)
if not file_url:
return None
file_name = str(segment.name or "").strip() or "file.bin"
file_size = await LineMessageEvent._resolve_file_size(segment)
if file_size <= 0:
return None
return {
"type": "file",
"fileName": file_name,
"fileSize": file_size,
"originalContentUrl": file_url,
}
return None
@staticmethod
async def _resolve_image_url(segment: Image) -> str:
candidate = (segment.url or segment.file or "").strip()
if candidate.startswith("http://") or candidate.startswith("https://"):
return candidate
try:
return await segment.register_to_file_service()
except Exception as e:
logger.debug("[LINE] resolve image url failed: %s", e)
return ""
@staticmethod
async def _resolve_record_url(segment: Record) -> str:
candidate = (segment.url or segment.file or "").strip()
if candidate.startswith("http://") or candidate.startswith("https://"):
return candidate
try:
return await segment.register_to_file_service()
except Exception as e:
logger.debug("[LINE] resolve record url failed: %s", e)
return ""
@staticmethod
async def _resolve_record_duration(segment: Record) -> int:
try:
file_path = await segment.convert_to_file_path()
duration_ms = await get_media_duration(file_path)
if isinstance(duration_ms, int) and duration_ms > 0:
return duration_ms
except Exception as e:
logger.debug("[LINE] resolve record duration failed: %s", e)
return 1000
@staticmethod
async def _resolve_video_url(segment: Video) -> str:
candidate = (segment.file or "").strip()
if candidate.startswith("http://") or candidate.startswith("https://"):
return candidate
try:
return await segment.register_to_file_service()
except Exception as e:
logger.debug("[LINE] resolve video url failed: %s", e)
return ""
@staticmethod
async def _resolve_video_preview_url(segment: Video) -> str:
cover_candidate = (segment.cover or "").strip()
if cover_candidate.startswith("http://") or cover_candidate.startswith(
"https://"
):
return cover_candidate
if cover_candidate:
try:
cover_seg = Image(file=cover_candidate)
return await cover_seg.register_to_file_service()
except Exception as e:
logger.debug("[LINE] resolve video cover failed: %s", e)
try:
video_path = await segment.convert_to_file_path()
temp_dir = Path(get_astrbot_temp_path())
temp_dir.mkdir(parents=True, exist_ok=True)
thumb_path = temp_dir / f"line_video_preview_{uuid.uuid4().hex}.jpg"
process = await asyncio.create_subprocess_exec(
"ffmpeg",
"-y",
"-ss",
"00:00:01",
"-i",
video_path,
"-frames:v",
"1",
str(thumb_path),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
await process.communicate()
if process.returncode != 0 or not thumb_path.exists():
return ""
cover_seg = Image.fromFileSystem(str(thumb_path))
return await cover_seg.register_to_file_service()
except Exception as e:
logger.debug("[LINE] generate video preview failed: %s", e)
return ""
@staticmethod
async def _resolve_file_url(segment: File) -> str:
if segment.url and segment.url.startswith(("http://", "https://")):
return segment.url
try:
return await segment.register_to_file_service()
except Exception as e:
logger.debug("[LINE] resolve file url failed: %s", e)
return ""
@staticmethod
async def _resolve_file_size(segment: File) -> int:
try:
file_path = await segment.get_file(allow_return_url=False)
if file_path and os.path.exists(file_path):
return int(os.path.getsize(file_path))
except Exception as e:
logger.debug("[LINE] resolve file size failed: %s", e)
return 0
@classmethod
async def build_line_messages(cls, message_chain: MessageChain) -> list[dict]:
messages: list[dict] = []
for segment in message_chain.chain:
obj = await cls._component_to_message_object(segment)
if obj:
messages.append(obj)
if not messages:
return []
if len(messages) > 5:
logger.warning(
"[LINE] message count exceeds 5, extra segments will be dropped."
)
messages = messages[:5]
return messages
async def send(self, message: MessageChain) -> None:
messages = await self.build_line_messages(message)
if not messages:
return
raw = self.message_obj.raw_message
reply_token = ""
if isinstance(raw, dict):
reply_token = str(raw.get("replyToken") or "")
sent = False
if reply_token:
sent = await self.line_api.reply_message(reply_token, messages)
if not sent:
target_id = self.get_group_id() or self.get_sender_id()
if target_id:
await self.line_api.push_message(target_id, messages)
await super().send(message)
async def send_streaming(
self,
generator: AsyncGenerator,
use_fallback: bool = False,
):
if not use_fallback:
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return None
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
buffer = ""
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
async for chain in generator:
if isinstance(chain, MessageChain):
for comp in chain.chain:
if isinstance(comp, Plain):
buffer += comp.text
if any(p in buffer for p in "。?!~…"):
buffer = await self.process_buffer(buffer, pattern)
else:
await self.send(MessageChain(chain=[comp]))
await asyncio.sleep(1.5)
if buffer.strip():
await self.send(MessageChain([Plain(buffer)]))
return await super().send_streaming(generator, use_fallback)
@@ -21,7 +21,7 @@ try:
except Exception:
magic = None
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from .misskey_event import MisskeyPlatformEvent
from .misskey_utils import (
@@ -498,7 +498,7 @@ class MisskeyPlatformAdapter(Platform):
finally:
# 清理临时文件
if local_path and isinstance(local_path, str):
data_temp = get_astrbot_temp_path()
data_temp = os.path.join(get_astrbot_data_path(), "temp")
if local_path.startswith(data_temp) and os.path.exists(
local_path,
):
@@ -19,7 +19,7 @@ from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Image, Plain, Record
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_image_by_url, file_to_base64
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
@@ -350,10 +350,10 @@ class QQOfficialMessageEvent(AstrMessageEvent):
elif isinstance(i, Record):
if i.file:
record_wav_path = await i.convert_to_file_path() # wav 路径
temp_dir = get_astrbot_temp_path()
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
record_tecent_silk_path = os.path.join(
temp_dir,
f"qqofficial_{uuid.uuid4()}.silk",
f"{uuid.uuid4()}.silk",
)
try:
duration = await wav_to_tencent_silk(
@@ -8,11 +8,13 @@ from typing import cast
import botpy
import botpy.message
import botpy.types
import botpy.types.message
from botpy import Client
from astrbot import logger
from astrbot.api.event import MessageChain
from astrbot.api.message_components import At, File, Image, Plain
from astrbot.api.message_components import At, Image, Plain
from astrbot.api.platform import (
AstrBotMessage,
MessageMember,
@@ -141,41 +143,6 @@ class QQOfficialPlatformAdapter(Platform):
support_proactive_message=False,
)
@staticmethod
def _normalize_attachment_url(url: str | None) -> str:
if not url:
return ""
if url.startswith("http://") or url.startswith("https://"):
return url
return f"https://{url}"
@staticmethod
def _append_attachments(
msg: list[BaseMessageComponent],
attachments: list | None,
) -> None:
if not attachments:
return
for attachment in attachments:
content_type = cast(str, getattr(attachment, "content_type", "") or "")
url = QQOfficialPlatformAdapter._normalize_attachment_url(
cast(str | None, getattr(attachment, "url", None))
)
if not url:
continue
if content_type.startswith("image"):
msg.append(Image.fromURL(url))
else:
filename = cast(
str,
getattr(attachment, "filename", None)
or getattr(attachment, "name", None)
or "attachment",
)
msg.append(File(name=filename, file=url, url=url))
@staticmethod
def _parse_from_qqofficial(
message: botpy.message.Message
@@ -205,7 +172,14 @@ class QQOfficialPlatformAdapter(Platform):
abm.self_id = "unknown_selfid"
msg.append(At(qq="qq_official"))
msg.append(Plain(abm.message_str))
QQOfficialPlatformAdapter._append_attachments(msg, message.attachments)
if message.attachments:
for i in message.attachments:
if i.content_type.startswith("image"):
url = i.url
if not url.startswith("http"):
url = "https://" + url
img = Image.fromURL(url)
msg.append(img)
abm.message = msg
elif isinstance(message, botpy.message.Message) or isinstance(
@@ -222,7 +196,14 @@ class QQOfficialPlatformAdapter(Platform):
"",
).strip()
QQOfficialPlatformAdapter._append_attachments(msg, message.attachments)
if message.attachments:
for i in message.attachments:
if i.content_type.startswith("image"):
url = i.url
if not url.startswith("http"):
url = "https://" + url
img = Image.fromURL(url)
msg.append(img)
abm.message = msg
abm.message_str = plain_content
abm.sender = MessageMember(
@@ -1,11 +1,11 @@
import asyncio
import logging
import random
from types import SimpleNamespace
from typing import Any, cast
import botpy
import botpy.message
import botpy.types
import botpy.types.message
from botpy import Client
from astrbot import logger
@@ -15,7 +15,6 @@ from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.webhook_utils import log_webhook_info
from ...register import register_platform_adapter
from ..qqofficial.qqofficial_message_event import QQOfficialMessageEvent
from ..qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter
from .qo_webhook_event import QQOfficialWebhookMessageEvent
from .qo_webhook_server import QQOfficialWebhook
@@ -40,7 +39,6 @@ class botClient(Client):
)
abm.group_id = cast(str, message.group_openid)
abm.session_id = abm.group_id
self.platform.remember_session_scene(abm.session_id, "group")
self._commit(abm)
# 收到频道消息
@@ -51,7 +49,6 @@ class botClient(Client):
)
abm.group_id = message.channel_id
abm.session_id = abm.group_id
self.platform.remember_session_scene(abm.session_id, "channel")
self._commit(abm)
# 收到私聊消息
@@ -63,7 +60,6 @@ class botClient(Client):
MessageType.FRIEND_MESSAGE,
)
abm.session_id = abm.sender.user_id
self.platform.remember_session_scene(abm.session_id, "friend")
self._commit(abm)
# 收到 C2C 消息
@@ -73,11 +69,9 @@ class botClient(Client):
MessageType.FRIEND_MESSAGE,
)
abm.session_id = abm.sender.user_id
self.platform.remember_session_scene(abm.session_id, "friend")
self._commit(abm)
def _commit(self, abm: AstrBotMessage) -> None:
self.platform.remember_session_message_id(abm.session_id, abm.message_id)
self.platform.commit_event(
QQOfficialWebhookMessageEvent(
abm.message_str,
@@ -115,129 +109,20 @@ class QQOfficialWebhookPlatformAdapter(Platform):
)
self.client.set_platform(self)
self.webhook_helper = None
self._session_last_message_id: dict[str, str] = {}
self._session_scene: dict[str, str] = {}
async def send_by_session(
self,
session: MessageSesion,
message_chain: MessageChain,
) -> None:
(
plain_text,
image_base64,
image_path,
record_file_path,
) = await QQOfficialMessageEvent._parse_to_qqofficial(message_chain)
if not plain_text and not image_path:
return
msg_id = self._session_last_message_id.get(session.session_id)
if not msg_id:
logger.warning(
"[QQOfficialWebhook] No cached msg_id for session: %s, skip send_by_session",
session.session_id,
)
return
payload: dict[str, Any] = {"content": plain_text, "msg_id": msg_id}
ret: Any = None
send_helper = SimpleNamespace(bot=self.client)
if session.message_type == MessageType.GROUP_MESSAGE:
scene = self._session_scene.get(session.session_id)
if scene == "group":
payload["msg_seq"] = random.randint(1, 10000)
if image_base64:
media = await QQOfficialMessageEvent.upload_group_and_c2c_image(
send_helper, # type: ignore
image_base64,
1,
group_openid=session.session_id,
)
payload["media"] = media
payload["msg_type"] = 7
if record_file_path:
media = await QQOfficialMessageEvent.upload_group_and_c2c_record(
send_helper, # type: ignore
record_file_path,
3,
group_openid=session.session_id,
)
payload["media"] = media
payload["msg_type"] = 7
ret = await self.client.api.post_group_message(
group_openid=session.session_id,
**payload,
)
else:
if image_path:
payload["file_image"] = image_path
ret = await self.client.api.post_message(
channel_id=session.session_id,
**payload,
)
elif session.message_type == MessageType.FRIEND_MESSAGE:
payload["msg_seq"] = random.randint(1, 10000)
if image_base64:
media = await QQOfficialMessageEvent.upload_group_and_c2c_image(
send_helper, # type: ignore
image_base64,
1,
openid=session.session_id,
)
payload["media"] = media
payload["msg_type"] = 7
if record_file_path:
media = await QQOfficialMessageEvent.upload_group_and_c2c_record(
send_helper, # type: ignore
record_file_path,
3,
openid=session.session_id,
)
payload["media"] = media
payload["msg_type"] = 7
ret = await QQOfficialMessageEvent.post_c2c_message(
send_helper, # type: ignore
openid=session.session_id,
**payload,
)
else:
logger.warning(
"[QQOfficialWebhook] Unsupported message type for send_by_session: %s",
session.message_type,
)
return
sent_message_id = self._extract_message_id(ret)
if sent_message_id:
self.remember_session_message_id(session.session_id, sent_message_id)
await super().send_by_session(session, message_chain)
def remember_session_message_id(self, session_id: str, message_id: str) -> None:
if not session_id or not message_id:
return
self._session_last_message_id[session_id] = message_id
def remember_session_scene(self, session_id: str, scene: str) -> None:
if not session_id or not scene:
return
self._session_scene[session_id] = scene
def _extract_message_id(self, ret: Any) -> str | None:
if isinstance(ret, dict):
message_id = ret.get("id")
return str(message_id) if message_id else None
message_id = getattr(ret, "id", None)
if message_id:
return str(message_id)
return None
raise NotImplementedError("QQ 机器人官方 API 适配器不支持 send_by_session")
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
name="qq_official_webhook",
description="QQ 机器人官方 API 适配器",
id=cast(str, self.config.get("id")),
support_proactive_message=True,
support_proactive_message=False,
)
async def run(self) -> None:
@@ -5,7 +5,6 @@ from typing import Any, cast
import telegramify_markdown
from telegram import ReactionTypeCustomEmoji, ReactionTypeEmoji
from telegram.constants import ChatAction
from telegram.ext import ExtBot
from astrbot import logger
@@ -32,14 +31,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
"word": re.compile(r"\s"),
}
# 消息类型到 chat action 的映射,用于优先级判断
ACTION_BY_TYPE: dict[type, str] = {
Record: ChatAction.UPLOAD_VOICE,
File: ChatAction.UPLOAD_DOCUMENT,
Image: ChatAction.UPLOAD_PHOTO,
Plain: ChatAction.TYPING,
}
def __init__(
self,
message_str: str,
@@ -76,71 +67,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
return chunks
@classmethod
async def _send_chat_action(
cls,
client: ExtBot,
chat_id: str,
action: ChatAction | str,
message_thread_id: str | None = None,
) -> None:
"""发送聊天状态动作"""
try:
payload: dict[str, Any] = {"chat_id": chat_id, "action": action}
if message_thread_id:
payload["message_thread_id"] = message_thread_id
await client.send_chat_action(**payload)
except Exception as e:
logger.warning(f"[Telegram] 发送 chat action 失败: {e}")
@classmethod
def _get_chat_action_for_chain(cls, chain: list[Any]) -> ChatAction | str:
"""根据消息链中的组件类型确定合适的 chat action(按优先级)"""
for seg_type, action in cls.ACTION_BY_TYPE.items():
if any(isinstance(seg, seg_type) for seg in chain):
return action
return ChatAction.TYPING
@classmethod
async def _send_media_with_action(
cls,
client: ExtBot,
upload_action: ChatAction | str,
send_coro,
*,
user_name: str,
message_thread_id: str | None = None,
**payload: Any,
) -> None:
"""发送媒体时显示 upload action,发送完成后恢复 typing"""
await cls._send_chat_action(client, user_name, upload_action, message_thread_id)
await send_coro(**payload)
await cls._send_chat_action(
client, user_name, ChatAction.TYPING, message_thread_id
)
async def _ensure_typing(
self,
user_name: str,
message_thread_id: str | None = None,
) -> None:
"""确保显示 typing 状态"""
await self._send_chat_action(
self.client, user_name, ChatAction.TYPING, message_thread_id
)
async def send_typing(self) -> None:
message_thread_id = None
if self.get_message_type() == MessageType.GROUP_MESSAGE:
user_name = self.message_obj.group_id
else:
user_name = self.get_sender_id()
if "#" in user_name:
user_name, message_thread_id = user_name.split("#")
await self._ensure_typing(user_name, message_thread_id)
@classmethod
async def send_with_client(
cls,
@@ -165,11 +91,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
if "#" in user_name:
# it's a supergroup chat with message_thread_id
user_name, message_thread_id = user_name.split("#")
# 根据消息链确定合适的 chat action 并发送
action = cls._get_chat_action_for_chain(message.chain)
await cls._send_chat_action(client, user_name, action, message_thread_id)
for i in message.chain:
payload = {
"chat_id": user_name,
@@ -274,12 +195,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
message_id = None
last_edit_time = 0 # 上次编辑消息的时间
throttle_interval = 0.6 # 编辑消息的间隔时间 (秒)
last_chat_action_time = 0 # 上次发送 chat action 的时间
chat_action_interval = 0.5 # chat action 的节流间隔 (秒)
# 发送初始 typing 状态
await self._ensure_typing(user_name, message_thread_id)
last_chat_action_time = asyncio.get_event_loop().time()
async for chain in generator:
if isinstance(chain, MessageChain):
@@ -304,25 +219,15 @@ class TelegramPlatformEvent(AstrMessageEvent):
delta += i.text
elif isinstance(i, Image):
image_path = await i.convert_to_file_path()
await self._send_media_with_action(
self.client,
ChatAction.UPLOAD_PHOTO,
self.client.send_photo,
user_name=user_name,
message_thread_id=message_thread_id,
photo=image_path,
**cast(Any, payload),
await self.client.send_photo(
photo=image_path, **cast(Any, payload)
)
continue
elif isinstance(i, File):
path = await i.get_file()
name = i.name or os.path.basename(path)
await self._send_media_with_action(
self.client,
ChatAction.UPLOAD_DOCUMENT,
self.client.send_document,
user_name=user_name,
message_thread_id=message_thread_id,
await self.client.send_document(
document=path,
filename=name,
**cast(Any, payload),
@@ -330,15 +235,7 @@ class TelegramPlatformEvent(AstrMessageEvent):
continue
elif isinstance(i, Record):
path = await i.convert_to_file_path()
await self._send_media_with_action(
self.client,
ChatAction.UPLOAD_VOICE,
self.client.send_voice,
user_name=user_name,
message_thread_id=message_thread_id,
voice=path,
**cast(Any, payload),
)
await self.client.send_voice(voice=path, **cast(Any, payload))
continue
else:
logger.warning(f"不支持的消息类型: {type(i)}")
@@ -351,11 +248,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
# 如果距离上次编辑的时间 >= 设定的间隔,等待一段时间
if time_since_last_edit >= throttle_interval:
# 发送 typing 状态(带节流)
current_time = asyncio.get_event_loop().time()
if current_time - last_chat_action_time >= chat_action_interval:
await self._ensure_typing(user_name, message_thread_id)
last_chat_action_time = current_time
# 编辑消息
try:
await self.client.edit_message_text(
@@ -371,11 +263,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
) # 更新上次编辑的时间
else:
# delta 长度一般不会大于 4096,因此这里直接发送
# 发送 typing 状态(带节流)
current_time = asyncio.get_event_loop().time()
if current_time - last_chat_action_time >= chat_action_interval:
await self._ensure_typing(user_name, message_thread_id)
last_chat_action_time = current_time
try:
msg = await self.client.send_message(
text=delta, **cast(Any, payload)
@@ -26,23 +26,14 @@ from .webchat_queue_mgr import WebChatQueueMgr, webchat_queue_mgr
class QueueListener:
def __init__(
self,
webchat_queue_mgr: WebChatQueueMgr,
callback: Callable,
stop_event: asyncio.Event,
) -> None:
def __init__(self, webchat_queue_mgr: WebChatQueueMgr, callback: Callable) -> None:
self.webchat_queue_mgr = webchat_queue_mgr
self.callback = callback
self.stop_event = stop_event
async def run(self) -> None:
"""Register callback and keep adapter task alive."""
self.webchat_queue_mgr.set_listener(self.callback)
try:
await self.stop_event.wait()
finally:
await self.webchat_queue_mgr.clear_listener()
await asyncio.Event().wait()
@register_platform_adapter("webchat", "webchat")
@@ -65,8 +56,6 @@ class WebChatAdapter(Platform):
id="webchat",
support_proactive_message=False,
)
self._shutdown_event = asyncio.Event()
self._webchat_queue_mgr = webchat_queue_mgr
async def send_by_session(
self,
@@ -195,7 +184,7 @@ class WebChatAdapter(Platform):
abm = await self.convert_message(data)
await self.handle_msg(abm)
bot = QueueListener(self._webchat_queue_mgr, callback, self._shutdown_event)
bot = QueueListener(webchat_queue_mgr, callback)
return bot.run()
def meta(self) -> PlatformMetadata:
@@ -220,4 +209,5 @@ class WebChatAdapter(Platform):
self.commit_event(message_event)
async def terminate(self) -> None:
self._shutdown_event.set()
# Do nothing
pass
@@ -87,19 +87,6 @@ class WebChatQueueMgr:
for conversation_id in list(self.queues.keys()):
self._start_listener_if_needed(conversation_id)
async def clear_listener(self) -> None:
self._listener_callback = None
for close_event in list(self._queue_close_events.values()):
close_event.set()
self._queue_close_events.clear()
listener_tasks = list(self._listener_tasks.values())
for task in listener_tasks:
task.cancel()
if listener_tasks:
await asyncio.gather(*listener_tasks, return_exceptions=True)
self._listener_tasks.clear()
def _start_listener_if_needed(self, conversation_id: str):
if self._listener_callback is None:
return
@@ -25,8 +25,7 @@ from astrbot.api.platform import (
)
from astrbot.core import logger
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.media_utils import convert_audio_to_wav
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.webhook_utils import log_webhook_info
from .wecom_event import WecomPlatformEvent
@@ -166,7 +165,6 @@ class WecomPlatformAdapter(Platform):
self.api_base_url += "/"
self.server = WecomServer(self._event_queue, self.config)
self.agent_id: str | None = None
self.client = WeChatClient(
self.config["corpid"].strip(),
@@ -217,36 +215,6 @@ class WecomPlatformAdapter(Platform):
session: MessageSesion,
message_chain: MessageChain,
) -> None:
# 企业微信客服不支持主动发送
if hasattr(self.client, "kf_message"):
logger.warning("企业微信客服模式不支持 send_by_session 主动发送。")
await super().send_by_session(session, message_chain)
return
if not self.agent_id:
logger.warning(
f"send_by_session 失败:无法为会话 {session.session_id} 推断 agent_id。",
)
await super().send_by_session(session, message_chain)
return
message_obj = AstrBotMessage()
message_obj.self_id = self.agent_id
message_obj.session_id = session.session_id
message_obj.type = session.message_type
message_obj.sender = MessageMember(session.session_id, session.session_id)
message_obj.message = []
message_obj.message_str = ""
message_obj.message_id = uuid.uuid4().hex
message_obj.raw_message = {"_proactive_send": True}
event = WecomPlatformEvent(
message_str=message_obj.message_str,
message_obj=message_obj,
platform_meta=self.meta(),
session_id=message_obj.session_id,
client=self.client,
)
await event.send(message_chain)
await super().send_by_session(session, message_chain)
@override
@@ -344,14 +312,17 @@ class WecomPlatformAdapter(Platform):
self.client.media.download,
msg.media_id,
)
temp_dir = get_astrbot_temp_path()
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, f"wecom_{msg.media_id}.amr")
with open(path, "wb") as f:
f.write(resp.content)
try:
from pydub import AudioSegment
path_wav = os.path.join(temp_dir, f"wecom_{msg.media_id}.wav")
path_wav = await convert_audio_to_wav(path, path_wav)
audio = AudioSegment.from_file(path)
audio.export(path_wav, format="wav")
except Exception as e:
logger.error(f"转换音频失败: {e}。如果没有安装 ffmpeg 请先安装。")
path_wav = path
@@ -373,7 +344,6 @@ class WecomPlatformAdapter(Platform):
logger.warning(f"暂未实现的事件: {msg.type}")
return
self.agent_id = abm.self_id
logger.info(f"abm: {abm}")
await self.handle_msg(abm)
@@ -400,8 +370,7 @@ class WecomPlatformAdapter(Platform):
self.client.media.download,
media_id,
)
temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"weixinkefu_{media_id}.jpg")
path = f"data/temp/wechat_kf_{media_id}.jpg"
with open(path, "wb") as f:
f.write(resp.content)
abm.message = [Image(file=path, url=path)]
@@ -413,14 +382,17 @@ class WecomPlatformAdapter(Platform):
media_id,
)
temp_dir = get_astrbot_temp_path()
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, f"weixinkefu_{media_id}.amr")
with open(path, "wb") as f:
f.write(resp.content)
try:
from pydub import AudioSegment
path_wav = os.path.join(temp_dir, f"weixinkefu_{media_id}.wav")
path_wav = await convert_audio_to_wav(path, path_wav)
audio = AudioSegment.from_file(path)
audio.export(path_wav, format="wav")
except Exception as e:
logger.error(f"转换音频失败: {e}。如果没有安装 ffmpeg 请先安装。")
path_wav = path
@@ -1,16 +1,24 @@
import asyncio
import os
import uuid
from wechatpy.enterprise import WeChatClient
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import File, Image, Plain, Record, Video
from astrbot.api.message_components import Image, Plain, Record
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.core.utils.media_utils import convert_audio_to_amr
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from .wecom_kf_message import WeChatKFMessage
try:
import pydub
except Exception:
logger.warning(
"检测到 pydub 库未安装,企业微信将无法语音收发。如需使用语音,请前往管理面板 -> 平台日志 -> 安装 Pip 库安装 pydub。",
)
class WecomPlatformEvent(AstrMessageEvent):
def __init__(
@@ -117,66 +125,25 @@ class WecomPlatformEvent(AstrMessageEvent):
)
elif isinstance(comp, Record):
record_path = await comp.convert_to_file_path()
record_path_amr = await convert_audio_to_amr(record_path)
# 转成amr
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
record_path_amr = os.path.join(temp_dir, f"{uuid.uuid4()}.amr")
pydub.AudioSegment.from_wav(record_path).export(
record_path_amr,
format="amr",
)
try:
with open(record_path_amr, "rb") as f:
try:
response = self.client.media.upload("voice", f)
except Exception as e:
logger.error(f"微信客服上传语音失败: {e}")
await self.send(
MessageChain().message(
f"微信客服上传语音失败: {e}"
),
)
return
logger.info(f"微信客服上传语音返回: {response}")
kf_message_api.send_voice(
user_id,
self.get_self_id(),
response["media_id"],
)
finally:
if record_path_amr != record_path and os.path.exists(
record_path_amr,
):
try:
os.remove(record_path_amr)
except OSError as e:
logger.warning(f"删除临时音频文件失败: {e}")
elif isinstance(comp, File):
file_path = await comp.get_file()
with open(file_path, "rb") as f:
with open(record_path_amr, "rb") as f:
try:
response = self.client.media.upload("file", f)
response = self.client.media.upload("voice", f)
except Exception as e:
logger.error(f"微信客服上传文件失败: {e}")
logger.error(f"微信客服上传语音失败: {e}")
await self.send(
MessageChain().message(f"微信客服上传文件失败: {e}"),
MessageChain().message(f"微信客服上传语音失败: {e}"),
)
return
logger.debug(f"微信客服上传文件返回: {response}")
kf_message_api.send_file(
user_id,
self.get_self_id(),
response["media_id"],
)
elif isinstance(comp, Video):
video_path = await comp.convert_to_file_path()
with open(video_path, "rb") as f:
try:
response = self.client.media.upload("video", f)
except Exception as e:
logger.error(f"微信客服上传视频失败: {e}")
await self.send(
MessageChain().message(f"微信客服上传视频失败: {e}"),
)
return
logger.debug(f"微信客服上传视频返回: {response}")
kf_message_api.send_video(
logger.info(f"微信客服上传语音返回: {response}")
kf_message_api.send_voice(
user_id,
self.get_self_id(),
response["media_id"],
@@ -216,66 +183,25 @@ class WecomPlatformEvent(AstrMessageEvent):
)
elif isinstance(comp, Record):
record_path = await comp.convert_to_file_path()
record_path_amr = await convert_audio_to_amr(record_path)
# 转成amr
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
record_path_amr = os.path.join(temp_dir, f"{uuid.uuid4()}.amr")
pydub.AudioSegment.from_wav(record_path).export(
record_path_amr,
format="amr",
)
try:
with open(record_path_amr, "rb") as f:
try:
response = self.client.media.upload("voice", f)
except Exception as e:
logger.error(f"企业微信上传语音失败: {e}")
await self.send(
MessageChain().message(
f"企业微信上传语音失败: {e}"
),
)
return
logger.info(f"企业微信上传语音返回: {response}")
self.client.message.send_voice(
message_obj.self_id,
message_obj.session_id,
response["media_id"],
)
finally:
if record_path_amr != record_path and os.path.exists(
record_path_amr,
):
try:
os.remove(record_path_amr)
except OSError as e:
logger.warning(f"删除临时音频文件失败: {e}")
elif isinstance(comp, File):
file_path = await comp.get_file()
with open(file_path, "rb") as f:
with open(record_path_amr, "rb") as f:
try:
response = self.client.media.upload("file", f)
response = self.client.media.upload("voice", f)
except Exception as e:
logger.error(f"企业微信上传文件失败: {e}")
logger.error(f"企业微信上传语音失败: {e}")
await self.send(
MessageChain().message(f"企业微信上传文件失败: {e}"),
MessageChain().message(f"企业微信上传语音失败: {e}"),
)
return
logger.debug(f"企业微信上传文件返回: {response}")
self.client.message.send_file(
message_obj.self_id,
message_obj.session_id,
response["media_id"],
)
elif isinstance(comp, Video):
video_path = await comp.convert_to_file_path()
with open(video_path, "rb") as f:
try:
response = self.client.media.upload("video", f)
except Exception as e:
logger.error(f"企业微信上传视频失败: {e}")
await self.send(
MessageChain().message(f"企业微信上传视频失败: {e}"),
)
return
logger.debug(f"企业微信上传视频返回: {response}")
self.client.message.send_video(
logger.info(f"企业微信上传语音返回: {response}")
self.client.message.send_voice(
message_obj.self_id,
message_obj.session_id,
response["media_id"],
@@ -39,7 +39,6 @@ from .wecomai_utils import (
generate_random_string,
process_encrypted_image,
)
from .wecomai_webhook import WecomAIBotWebhookClient, WecomAIBotWebhookError
class WecomAIQueueListener:
@@ -85,24 +84,20 @@ class WecomAIBotAdapter(Platform):
self.bot_name = self.config.get("wecom_ai_bot_name", "")
self.initial_respond_text = self.config.get(
"wecomaibot_init_respond_text",
"",
"💭 思考中...",
)
self.friend_message_welcome_text = self.config.get(
"wecomaibot_friend_message_welcome_text",
"",
)
self.unified_webhook_mode = self.config.get("unified_webhook_mode", False)
self.msg_push_webhook_url = self.config.get("msg_push_webhook_url", "").strip()
self.only_use_webhook_url_to_send = bool(
self.config.get("only_use_webhook_url_to_send", False),
)
# 平台元数据
self.metadata = PlatformMetadata(
name="wecom_ai_bot",
description="企业微信智能机器人适配器,支持 HTTP 回调接收消息",
id=self.config.get("id", "wecom_ai_bot"),
support_proactive_message=bool(self.msg_push_webhook_url),
support_proactive_message=False,
)
# 初始化 API 客户端
@@ -127,16 +122,6 @@ class WecomAIBotAdapter(Platform):
self.queue_mgr,
self._handle_queued_message,
)
self._stream_plain_cache: dict[str, str] = {}
self.webhook_client: WecomAIBotWebhookClient | None = None
if self.msg_push_webhook_url:
try:
self.webhook_client = WecomAIBotWebhookClient(
self.msg_push_webhook_url,
)
except WecomAIBotWebhookError as e:
logger.error("企业微信消息推送 webhook 配置无效: %s", e)
async def _handle_queued_message(self, data: dict) -> None:
"""处理队列中的消息,类似webchat的callback"""
@@ -179,19 +164,16 @@ class WecomAIBotAdapter(Platform):
)
self.queue_mgr.set_pending_response(stream_id, callback_params)
if self.only_use_webhook_url_to_send and self.webhook_client:
return None
if self.initial_respond_text:
resp = WecomAIBotStreamMessageBuilder.make_text_stream(
stream_id,
self.initial_respond_text,
False,
)
return await self.api_client.encrypt_message(
resp,
callback_params["nonce"],
callback_params["timestamp"],
)
resp = WecomAIBotStreamMessageBuilder.make_text_stream(
stream_id,
self.initial_respond_text,
False,
)
return await self.api_client.encrypt_message(
resp,
callback_params["nonce"],
callback_params["timestamp"],
)
except Exception as e:
logger.error("处理消息时发生异常: %s", e)
return None
@@ -199,7 +181,6 @@ class WecomAIBotAdapter(Platform):
# wechat server is requesting for updates of a stream
stream_id = message_data["stream"]["id"]
if not self.queue_mgr.has_back_queue(stream_id):
self._stream_plain_cache.pop(stream_id, None)
if self.queue_mgr.is_stream_finished(stream_id):
logger.debug(
f"Stream already finished, returning end message: {stream_id}"
@@ -227,48 +208,24 @@ class WecomAIBotAdapter(Platform):
return None
# aggregate all delta chains in the back queue
cached_plain_content = self._stream_plain_cache.get(stream_id, "")
latest_plain_content = cached_plain_content
latest_plain_content = ""
image_base64 = []
finish = False
while not queue.empty():
msg = await queue.get()
if msg["type"] == "plain":
plain_data = msg.get("data") or ""
if msg.get("streaming", False):
# streaming plain payload is already cumulative
cached_plain_content = plain_data
else:
# segmented non-stream send() pushes plain chunks, needs append
cached_plain_content += plain_data
latest_plain_content = cached_plain_content
latest_plain_content = msg["data"] or ""
elif msg["type"] == "image":
image_base64.append(msg["image_data"])
elif msg["type"] == "break":
continue
elif msg["type"] in {"end", "complete"}:
# stream end
finish = True
self.queue_mgr.remove_queues(stream_id, mark_finished=True)
self._stream_plain_cache.pop(stream_id, None)
break
logger.debug(
f"Aggregated content: {latest_plain_content}, image: {len(image_base64)}, finish: {finish}",
)
if not finish:
self._stream_plain_cache[stream_id] = cached_plain_content
if finish and not latest_plain_content and not image_base64:
end_message = WecomAIBotStreamMessageBuilder.make_text_stream(
stream_id,
"",
True,
)
return await self.api_client.encrypt_message(
end_message,
callback_params["nonce"],
callback_params["timestamp"],
)
if latest_plain_content or image_base64:
msg_items = []
if finish and image_base64:
@@ -436,23 +393,9 @@ class WecomAIBotAdapter(Platform):
session: MessageSesion,
message_chain: MessageChain,
) -> None:
"""通过消息推送 webhook 发送消息"""
if not self.webhook_client:
logger.warning(
"主动消息发送失败: 未配置企业微信消息推送 Webhook URL,请前往配置添加。session_id=%s",
session.session_id,
)
await super().send_by_session(session, message_chain)
return
try:
await self.webhook_client.send_message_chain(message_chain)
except Exception as e:
logger.error(
"企业微信消息推送失败(session=%s): %s",
session.session_id,
e,
)
"""通过会话发送消息"""
# 企业微信智能机器人主要通过回调响应,这里记录日志
logger.info("会话发送消息: %s -> %s", session.session_id, message_chain)
await super().send_by_session(session, message_chain)
def run(self) -> Awaitable[Any]:
@@ -505,8 +448,6 @@ class WecomAIBotAdapter(Platform):
session_id=message.session_id,
api_client=self.api_client,
queue_mgr=self.queue_mgr,
webhook_client=self.webhook_client,
only_use_webhook_url_to_send=self.only_use_webhook_url_to_send,
)
self.commit_event(message_event)
@@ -2,11 +2,13 @@
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import At, Image, Plain
from astrbot.api.message_components import (
Image,
Plain,
)
from .wecomai_api import WecomAIBotAPIClient
from .wecomai_queue_mgr import WecomAIQueueMgr
from .wecomai_webhook import WecomAIBotWebhookClient
class WecomAIBotMessageEvent(AstrMessageEvent):
@@ -20,8 +22,6 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
session_id: str,
api_client: WecomAIBotAPIClient,
queue_mgr: WecomAIQueueMgr,
webhook_client: WecomAIBotWebhookClient | None = None,
only_use_webhook_url_to_send: bool = False,
) -> None:
"""初始化消息事件
@@ -36,19 +36,6 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.api_client = api_client
self.queue_mgr = queue_mgr
self.webhook_client = webhook_client
self.only_use_webhook_url_to_send = only_use_webhook_url_to_send
async def _mark_stream_complete(self, stream_id: str) -> None:
back_queue = self.queue_mgr.get_or_create_back_queue(stream_id)
await back_queue.put(
{
"type": "complete",
"data": "",
"streaming": False,
"session_id": stream_id,
},
)
@staticmethod
async def _send(
@@ -56,7 +43,6 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
stream_id: str,
queue_mgr: WecomAIQueueMgr,
streaming: bool = False,
suppress_unsupported_log: bool = False,
):
back_queue = queue_mgr.get_or_create_back_queue(stream_id)
@@ -72,17 +58,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
data = ""
for comp in message_chain.chain:
if isinstance(comp, At):
data = f"@{comp.name} "
await back_queue.put(
{
"type": "plain",
"data": data,
"streaming": streaming,
"session_id": stream_id,
},
)
elif isinstance(comp, Plain):
if isinstance(comp, Plain):
data = comp.text
await back_queue.put(
{
@@ -110,10 +86,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
except Exception as e:
logger.error("处理图片消息失败: %s", e)
else:
if not suppress_unsupported_log:
logger.warning(
f"[WecomAI] 不支持的消息组件类型: {type(comp)}, 跳过"
)
logger.warning(f"[WecomAI] 不支持的消息组件类型: {type(comp)}, 跳过")
return data
@@ -124,24 +97,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
"wecom_ai_bot platform event raw_message should be a dict"
)
stream_id = raw.get("stream_id", self.session_id)
if self.only_use_webhook_url_to_send and self.webhook_client and message:
await self.webhook_client.send_message_chain(message)
await self._mark_stream_complete(stream_id)
await super().send(MessageChain([]))
return
if self.webhook_client and message:
await self.webhook_client.send_message_chain(
message,
unsupported_only=True,
)
await WecomAIBotMessageEvent._send(
message,
stream_id,
self.queue_mgr,
suppress_unsupported_log=self.webhook_client is not None,
)
await WecomAIBotMessageEvent._send(message, stream_id, self.queue_mgr)
await super().send(MessageChain([]))
async def send_streaming(self, generator, use_fallback=False) -> None:
@@ -154,23 +110,9 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
stream_id = raw.get("stream_id", self.session_id)
back_queue = self.queue_mgr.get_or_create_back_queue(stream_id)
if self.only_use_webhook_url_to_send and self.webhook_client:
merged_chain = MessageChain([])
async for chain in generator:
merged_chain.chain.extend(chain.chain)
merged_chain.squash_plain()
await self.webhook_client.send_message_chain(merged_chain)
await self._mark_stream_complete(stream_id)
await super().send_streaming(generator, use_fallback)
return
# 企业微信智能机器人不支持增量发送,因此我们需要在这里将增量内容累积起来,积累发送
increment_plain = ""
async for chain in generator:
if self.webhook_client:
await self.webhook_client.send_message_chain(
chain, unsupported_only=True
)
# 累积增量内容,并改写 Plain 段
chain.squash_plain()
for comp in chain.chain:
@@ -186,7 +128,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
"type": "break", # break means a segment end
"data": final_data,
"streaming": True,
"session_id": stream_id,
"session_id": self.session_id,
},
)
final_data = ""
@@ -197,7 +139,6 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
stream_id=stream_id,
queue_mgr=self.queue_mgr,
streaming=True,
suppress_unsupported_log=self.webhook_client is not None,
)
await back_queue.put(
@@ -205,7 +146,7 @@ class WecomAIBotMessageEvent(AstrMessageEvent):
"type": "complete", # complete means we return the final result
"data": final_data,
"streaming": True,
"session_id": stream_id,
"session_id": self.session_id,
},
)
await super().send_streaming(generator, use_fallback)
@@ -1,225 +0,0 @@
"""企业微信智能机器人 webhook 推送客户端。"""
from __future__ import annotations
import base64
import hashlib
import mimetypes
from pathlib import Path
from typing import Any, Literal
from urllib.parse import parse_qs, urlencode, urlparse
import aiohttp
from astrbot.api import logger
from astrbot.api.event import MessageChain
from astrbot.api.message_components import At, File, Image, Plain, Record, Video
from astrbot.core.utils.media_utils import convert_audio_format
class WecomAIBotWebhookError(RuntimeError):
"""企业微信 webhook 推送异常。"""
class WecomAIBotWebhookClient:
"""企业微信智能机器人 webhook 消息推送客户端。"""
def __init__(self, webhook_url: str, timeout_seconds: int = 15) -> None:
self.webhook_url = webhook_url.strip()
self.timeout_seconds = timeout_seconds
if not self.webhook_url:
raise WecomAIBotWebhookError("消息推送 webhook URL 不能为空")
self._webhook_key = self._extract_webhook_key()
def _extract_webhook_key(self) -> str:
parsed = urlparse(self.webhook_url)
key = parse_qs(parsed.query).get("key", [""])[0].strip()
if not key:
raise WecomAIBotWebhookError("消息推送 webhook URL 缺少 key 参数")
return key
def _build_upload_url(self, media_type: Literal["file", "voice"]) -> str:
query = urlencode({"key": self._webhook_key, "type": media_type})
return f"https://qyapi.weixin.qq.com/cgi-bin/webhook/upload_media?{query}"
@staticmethod
def _split_markdown_v2_content(content: str, max_bytes: int = 4096) -> list[str]:
if not content:
return []
chunks: list[str] = []
buffer: list[str] = []
current_size = 0
for char in content:
char_size = len(char.encode("utf-8"))
if current_size + char_size > max_bytes and buffer:
chunks.append("".join(buffer))
buffer = [char]
current_size = char_size
else:
buffer.append(char)
current_size += char_size
if buffer:
chunks.append("".join(buffer))
return chunks
async def send_payload(self, payload: dict[str, Any]) -> None:
timeout = aiohttp.ClientTimeout(total=self.timeout_seconds)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.post(self.webhook_url, json=payload) as response:
text = await response.text()
if response.status != 200:
raise WecomAIBotWebhookError(
f"Webhook 请求失败: HTTP {response.status}, {text}"
)
result = await response.json(content_type=None)
if result.get("errcode") != 0:
raise WecomAIBotWebhookError(
f"Webhook 返回错误: {result.get('errcode')} {result.get('errmsg')}"
)
logger.debug("企业微信消息推送成功: %s", payload.get("msgtype", "unknown"))
async def send_markdown_v2(self, content: str) -> None:
for chunk in self._split_markdown_v2_content(content):
await self.send_payload(
{
"msgtype": "markdown_v2",
"markdown_v2": {"content": chunk},
}
)
async def send_image_base64(self, image_base64: str) -> None:
image_bytes = base64.b64decode(image_base64)
md5 = hashlib.md5(image_bytes).hexdigest()
await self.send_payload(
{
"msgtype": "image",
"image": {
"base64": image_base64,
"md5": md5,
},
}
)
async def upload_media(
self, file_path: Path, media_type: Literal["file", "voice"]
) -> str:
if not file_path.exists() or not file_path.is_file():
raise WecomAIBotWebhookError(f"文件不存在: {file_path}")
content_type = (
mimetypes.guess_type(str(file_path))[0] or "application/octet-stream"
)
form = aiohttp.FormData()
form.add_field(
"media",
file_path.read_bytes(),
filename=file_path.name,
content_type=content_type,
)
timeout = aiohttp.ClientTimeout(total=self.timeout_seconds)
async with aiohttp.ClientSession(timeout=timeout) as session:
async with session.post(
self._build_upload_url(media_type),
data=form,
) as response:
text = await response.text()
if response.status != 200:
raise WecomAIBotWebhookError(
f"上传媒体失败: HTTP {response.status}, {text}"
)
result = await response.json(content_type=None)
if result.get("errcode") != 0:
raise WecomAIBotWebhookError(
f"上传媒体失败: {result.get('errcode')} {result.get('errmsg')}"
)
media_id = result.get("media_id", "")
if not media_id:
raise WecomAIBotWebhookError("上传媒体失败: 返回缺少 media_id")
return str(media_id)
async def send_file(self, file_path: Path) -> None:
media_id = await self.upload_media(file_path, "file")
await self.send_payload(
{
"msgtype": "file",
"file": {"media_id": media_id},
}
)
async def send_voice(self, file_path: Path) -> None:
media_id = await self.upload_media(file_path, "voice")
await self.send_payload(
{
"msgtype": "voice",
"voice": {"media_id": media_id},
}
)
@staticmethod
def is_stream_supported_component(component: Any) -> bool:
return isinstance(component, Plain | Image | At)
async def send_message_chain(
self,
message_chain: MessageChain,
unsupported_only: bool = False,
) -> None:
async def flush_markdown_buffer(parts: list[str]) -> None:
content = "".join(parts).strip()
parts.clear()
if content:
await self.send_markdown_v2(content)
markdown_buffer: list[str] = []
for component in message_chain.chain:
if unsupported_only and self.is_stream_supported_component(component):
continue
if isinstance(component, Plain):
markdown_buffer.append(component.text)
elif isinstance(component, At):
mention_name = component.name or str(component.qq)
markdown_buffer.append(f" @{mention_name} ")
elif isinstance(component, Image):
await flush_markdown_buffer(markdown_buffer)
image_base64 = await component.convert_to_base64()
await self.send_image_base64(image_base64)
elif isinstance(component, File):
await flush_markdown_buffer(markdown_buffer)
file_path = await component.get_file()
if not file_path:
logger.warning("文件消息缺少有效文件路径,已跳过: %s", component)
continue
await self.send_file(Path(file_path))
elif isinstance(component, Video):
await flush_markdown_buffer(markdown_buffer)
video_path = await component.convert_to_file_path()
await self.send_file(Path(video_path))
elif isinstance(component, Record):
await flush_markdown_buffer(markdown_buffer)
source_voice_path = Path(await component.convert_to_file_path())
target_voice_path = source_voice_path
converted = False
if source_voice_path.suffix.lower() != ".amr":
target_voice_path = Path(
await convert_audio_format(str(source_voice_path), "amr"),
)
converted = target_voice_path != source_voice_path
try:
await self.send_voice(target_voice_path)
finally:
if converted and target_voice_path.exists():
try:
target_voice_path.unlink()
except Exception as e:
logger.warning(
"清理临时语音文件失败 %s: %s", target_voice_path, e
)
else:
logger.warning(
"企业微信消息推送暂不支持组件类型 %s,已跳过",
type(component).__name__,
)
await flush_markdown_buffer(markdown_buffer)
@@ -1,5 +1,4 @@
import asyncio
import os
import sys
import uuid
from collections.abc import Awaitable, Callable
@@ -25,8 +24,6 @@ from astrbot.api.platform import (
)
from astrbot.core import logger
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.media_utils import convert_audio_to_wav
from astrbot.core.utils.webhook_utils import log_webhook_info
from .weixin_offacc_event import WeixinOfficialAccountPlatformEvent
@@ -292,20 +289,19 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
self.client.media.download,
msg.media_id,
)
temp_dir = get_astrbot_temp_path()
path = os.path.join(temp_dir, f"weixin_offacc_{msg.media_id}.amr")
path = f"data/temp/wecom_{msg.media_id}.amr"
with open(path, "wb") as f:
f.write(resp.content)
try:
path_wav = os.path.join(
temp_dir,
f"weixin_offacc_{msg.media_id}.wav",
)
path_wav = await convert_audio_to_wav(path, path_wav)
from pydub import AudioSegment
path_wav = f"data/temp/wecom_{msg.media_id}.wav"
audio = AudioSegment.from_file(path)
audio.export(path_wav, format="wav")
except Exception as e:
logger.error(
f"转换音频失败: {e}。如果没有安装 ffmpeg 请先安装。",
f"转换音频失败: {e}。如果没有安装 pydub 和 ffmpeg 请先安装。",
)
path_wav = path
return
@@ -1,5 +1,5 @@
import asyncio
import os
import uuid
from typing import cast
from wechatpy import WeChatClient
@@ -9,7 +9,13 @@ from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Image, Plain, Record
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.core.utils.media_utils import convert_audio_to_amr
try:
import pydub
except Exception:
logger.warning(
"检测到 pydub 库未安装,微信公众平台将无法语音收发。如需使用语音,请前往管理面板 -> 平台日志 -> 安装 Pip 库安装 pydub。",
)
class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
@@ -131,46 +137,38 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
elif isinstance(comp, Record):
record_path = await comp.convert_to_file_path()
record_path_amr = await convert_audio_to_amr(record_path)
# 转成amr
record_path_amr = f"data/temp/{uuid.uuid4()}.amr"
pydub.AudioSegment.from_wav(record_path).export(
record_path_amr,
format="amr",
)
try:
with open(record_path_amr, "rb") as f:
try:
response = self.client.media.upload("voice", f)
except Exception as e:
logger.error(f"微信公众平台上传语音失败: {e}")
await self.send(
MessageChain().message(
f"微信公众平台上传语音失败: {e}"
),
)
return
logger.info(f"微信公众平台上传语音返回: {response}")
with open(record_path_amr, "rb") as f:
try:
response = self.client.media.upload("voice", f)
except Exception as e:
logger.error(f"微信公众平台上传语音失败: {e}")
await self.send(
MessageChain().message(f"微信公众平台上传语音失败: {e}"),
)
return
logger.info(f"微信公众平台上传语音返回: {response}")
if active_send_mode:
self.client.message.send_voice(
message_obj.sender.user_id,
response["media_id"],
)
else:
reply = VoiceReply(
media_id=response["media_id"],
message=cast(dict, self.message_obj.raw_message)[
"message"
],
)
xml = reply.render()
future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future)
future.set_result(xml)
finally:
if record_path_amr != record_path and os.path.exists(
record_path_amr
):
try:
os.remove(record_path_amr)
except OSError as e:
logger.warning(f"删除临时音频文件失败: {e}")
if active_send_mode:
self.client.message.send_voice(
message_obj.sender.user_id,
response["media_id"],
)
else:
reply = VoiceReply(
media_id=response["media_id"],
message=cast(dict, self.message_obj.raw_message)["message"],
)
xml = reply.render()
future = cast(dict, self.message_obj.raw_message)["future"]
assert isinstance(future, asyncio.Future)
future.set_result(xml)
else:
logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}")
@@ -22,7 +22,6 @@ from astrbot.core.utils.network_utils import (
)
from ..register import register_provider_adapter
from .default import with_model_request_retry
@register_provider_adapter(
@@ -205,7 +204,6 @@ class ProviderAnthropic(Provider):
if usage.output_tokens is not None:
token_usage.output = usage.output_tokens
@with_model_request_retry()
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
if tools:
if tool_list := tools.get_func_desc_anthropic_style():
@@ -267,10 +265,6 @@ class ProviderAnthropic(Provider):
return llm_response
@with_model_request_retry()
async def _create_message_stream(self, payloads: dict, extra_body: dict):
return self.client.messages.stream(**payloads, extra_body=extra_body)
async def _query_stream(
self,
payloads: dict,
@@ -299,8 +293,9 @@ class ProviderAnthropic(Provider):
"type": "enabled",
}
stream_ctx = await self._create_message_stream(payloads, extra_body)
async with stream_ctx as stream:
async with self.client.messages.stream(
**payloads, extra_body=extra_body
) as stream:
assert isinstance(stream, anthropic.AsyncMessageStream)
async for event in stream:
if event.type == "message_start":
@@ -12,13 +12,12 @@ from httpx import AsyncClient, Timeout
from astrbot import logger
from astrbot.core.config.default import VERSION
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType
from ..provider import TTSProvider
from ..register import register_provider_adapter
TEMP_DIR = Path(get_astrbot_temp_path()) / "azure_tts"
TEMP_DIR = Path("data/temp/azure_tts")
TEMP_DIR.mkdir(parents=True, exist_ok=True)
@@ -15,7 +15,7 @@ except (
): # pragma: no cover - older dashscope versions without Qwen TTS support
MultiModalConversation = None
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from ..entities import ProviderType
from ..provider import TTSProvider
@@ -45,7 +45,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
if not model:
raise RuntimeError("Dashscope TTS model is not configured.")
temp_dir = get_astrbot_temp_path()
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(temp_dir, exist_ok=True)
if self._is_qwen_tts_model(model):
-38
View File
@@ -1,38 +0,0 @@
from tenacity import (
AsyncRetrying,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
MODEL_REQUEST_RETRY_ATTEMPTS = 5
MODEL_REQUEST_RETRY_WAIT_MAX_SECONDS = 15
MODEL_REQUEST_RETRY_WAIT_MIN_SECONDS = 1
MODEL_REQUEST_RETRY_WAIT_MULTIPLIER = 1
def with_model_request_retry():
return retry(
retry=retry_if_exception_type(Exception),
stop=stop_after_attempt(MODEL_REQUEST_RETRY_ATTEMPTS),
wait=wait_exponential(
multiplier=MODEL_REQUEST_RETRY_WAIT_MULTIPLIER,
min=MODEL_REQUEST_RETRY_WAIT_MIN_SECONDS,
max=MODEL_REQUEST_RETRY_WAIT_MAX_SECONDS,
),
reraise=True,
)
def get_model_request_async_retrying() -> AsyncRetrying:
return AsyncRetrying(
retry=retry_if_exception_type(Exception),
stop=stop_after_attempt(MODEL_REQUEST_RETRY_ATTEMPTS),
wait=wait_exponential(
multiplier=MODEL_REQUEST_RETRY_WAIT_MULTIPLIER,
min=MODEL_REQUEST_RETRY_WAIT_MIN_SECONDS,
max=MODEL_REQUEST_RETRY_WAIT_MAX_SECONDS,
),
reraise=True,
)
@@ -6,7 +6,7 @@ import uuid
import edge_tts
from astrbot.core import logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from ..entities import ProviderType
from ..provider import TTSProvider
@@ -46,7 +46,7 @@ class ProviderEdgeTTS(TTSProvider):
self.set_model("edge_tts")
async def get_audio(self, text: str) -> str:
temp_dir = get_astrbot_temp_path()
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
mp3_path = os.path.join(temp_dir, f"edge_tts_temp_{uuid.uuid4()}.mp3")
wav_path = os.path.join(temp_dir, f"edge_tts_{uuid.uuid4()}.wav")
@@ -8,7 +8,7 @@ from httpx import AsyncClient
from pydantic import BaseModel, conint
from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from ..entities import ProviderType
from ..provider import TTSProvider
@@ -142,7 +142,7 @@ class ProviderFishAudioTTSAPI(TTSProvider):
)
async def get_audio(self, text: str) -> str:
temp_dir = get_astrbot_temp_path()
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, f"fishaudio_tts_api_{uuid.uuid4()}.wav")
self.headers["content-type"] = "application/msgpack"
request = await self._generate_request(text)
+24 -16
View File
@@ -21,7 +21,6 @@ from astrbot.core.utils.io import download_image_by_url
from astrbot.core.utils.network_utils import is_connection_error, log_connection_failure
from ..register import register_provider_adapter
from .default import get_model_request_async_retrying, with_model_request_retry
class SuppressNonTextPartsWarning(logging.Filter):
@@ -514,7 +513,6 @@ class ProviderGoogleGenAI(Provider):
llm_response.reasoning_signature = base64.b64encode(ts).decode("utf-8")
return MessageChain(chain=chain)
@with_model_request_retry()
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
"""非流式请求 Gemini API"""
system_instruction = next(
@@ -603,17 +601,6 @@ class ProviderGoogleGenAI(Provider):
self,
payloads: dict,
tools: ToolSet | None,
) -> AsyncGenerator[LLMResponse, None]:
async for attempt in get_model_request_async_retrying():
with attempt:
async for response in self._query_stream_once(payloads, tools):
yield response
return
async def _query_stream_once(
self,
payloads: dict,
tools: ToolSet | None,
) -> AsyncGenerator[LLMResponse, None]:
"""流式请求 Gemini API"""
system_instruction = next(
@@ -772,7 +759,18 @@ class ProviderGoogleGenAI(Provider):
payloads = {"messages": context_query, "model": model}
return await self._query(payloads, func_tool)
retry = 10
keys = self.api_keys.copy()
for _ in range(retry):
try:
return await self._query(payloads, func_tool)
except APIError as e:
if await self._handle_api_error(e, keys):
continue
break
raise Exception("请求失败。")
async def text_chat_stream(
self,
@@ -816,8 +814,18 @@ class ProviderGoogleGenAI(Provider):
payloads = {"messages": context_query, "model": model}
async for response in self._query_stream(payloads, func_tool):
yield response
retry = 10
keys = self.api_keys.copy()
for _ in range(retry):
try:
async for response in self._query_stream(payloads, func_tool):
yield response
break
except APIError as e:
if await self._handle_api_error(e, keys):
continue
break
async def get_models(self):
try:
@@ -6,7 +6,7 @@ from google import genai
from google.genai import types
from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from ..entities import ProviderType
from ..provider import TTSProvider
@@ -49,7 +49,7 @@ class ProviderGeminiTTSAPI(TTSProvider):
self.voice_name: str = provider_config.get("gemini_tts_voice_name", "Leda")
async def get_audio(self, text: str) -> str:
temp_dir = get_astrbot_temp_path()
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, f"gemini_tts_{uuid.uuid4()}.wav")
prompt = f"{self.prefix}: {text}" if self.prefix else text
response = await self.client.models.generate_content(
+3 -3
View File
@@ -6,7 +6,7 @@ from astrbot.core import logger
from astrbot.core.provider.entities import ProviderType
from astrbot.core.provider.provider import TTSProvider
from astrbot.core.provider.register import register_provider_adapter
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
try:
import genie_tts as genie # type: ignore
@@ -54,7 +54,7 @@ class GenieTTSProvider(TTSProvider):
return True
async def get_audio(self, text: str) -> str:
temp_dir = get_astrbot_temp_path()
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(temp_dir, exist_ok=True)
filename = f"genie_tts_{uuid.uuid4()}.wav"
path = os.path.join(temp_dir, filename)
@@ -94,7 +94,7 @@ class GenieTTSProvider(TTSProvider):
break
try:
temp_dir = get_astrbot_temp_path()
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(temp_dir, exist_ok=True)
filename = f"genie_tts_{uuid.uuid4()}.wav"
path = os.path.join(temp_dir, filename)
@@ -5,7 +5,7 @@ import uuid
import aiohttp
from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from ..entities import ProviderType
from ..provider import TTSProvider
@@ -121,7 +121,7 @@ class ProviderGSVTTS(TTSProvider):
params = self.build_synthesis_params(text)
temp_dir = get_astrbot_temp_path()
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(temp_dir, exist_ok=True)
path = os.path.join(temp_dir, f"gsv_tts_{uuid.uuid4().hex}.wav")
@@ -4,7 +4,7 @@ import uuid
import aiohttp
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from ..entities import ProviderType
from ..provider import TTSProvider
@@ -29,7 +29,7 @@ class ProviderGSVITTS(TTSProvider):
self.emotion = provider_config.get("emotion")
async def get_audio(self, text: str) -> str:
temp_dir = get_astrbot_temp_path()
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav")
params = {"text": text}
@@ -6,7 +6,7 @@ from collections.abc import AsyncIterator
import aiohttp
from astrbot.api import logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from ..entities import ProviderType
from ..provider import TTSProvider
@@ -145,7 +145,7 @@ class ProviderMiniMaxTTSAPI(TTSProvider):
return b"".join(chunks)
async def get_audio(self, text: str) -> str:
temp_dir = get_astrbot_temp_path()
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(temp_dir, exist_ok=True)
path = os.path.join(temp_dir, f"minimax_tts_api_{uuid.uuid4()}.mp3")
+90 -261
View File
@@ -5,7 +5,6 @@ import json
import random
import re
from collections.abc import AsyncGenerator
from typing import Any
import httpx
from openai import AsyncAzureOpenAI, AsyncOpenAI
@@ -28,10 +27,8 @@ from astrbot.core.utils.network_utils import (
is_connection_error,
log_connection_failure,
)
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
from ..register import register_provider_adapter
from .default import get_model_request_async_retrying, with_model_request_retry
@register_provider_adapter(
@@ -39,128 +36,6 @@ from .default import get_model_request_async_retrying, with_model_request_retry
"OpenAI API Chat Completion 提供商适配器",
)
class ProviderOpenAIOfficial(Provider):
_ERROR_TEXT_CANDIDATE_MAX_CHARS = 4096
@classmethod
def _truncate_error_text_candidate(cls, text: str) -> str:
if len(text) <= cls._ERROR_TEXT_CANDIDATE_MAX_CHARS:
return text
return text[: cls._ERROR_TEXT_CANDIDATE_MAX_CHARS]
@staticmethod
def _safe_json_dump(value: Any) -> str | None:
try:
return json.dumps(value, ensure_ascii=False, default=str)
except Exception:
return None
def _get_image_moderation_error_patterns(self) -> list[str]:
"""Return configured moderation patterns (case-insensitive substring match, not regex)."""
configured = self.provider_config.get("image_moderation_error_patterns", [])
patterns: list[str] = []
if isinstance(configured, str):
configured = [configured]
if isinstance(configured, list):
for pattern in configured:
if not isinstance(pattern, str):
continue
pattern = pattern.strip()
if pattern:
patterns.append(pattern)
return patterns
@staticmethod
def _extract_error_text_candidates(error: Exception) -> list[str]:
candidates: list[str] = []
def _append_candidate(candidate: Any):
if candidate is None:
return
text = str(candidate).strip()
if not text:
return
candidates.append(
ProviderOpenAIOfficial._truncate_error_text_candidate(text)
)
_append_candidate(str(error))
body = getattr(error, "body", None)
if isinstance(body, dict):
err_obj = body.get("error")
body_text = ProviderOpenAIOfficial._safe_json_dump(
{"error": err_obj} if isinstance(err_obj, dict) else body
)
_append_candidate(body_text)
if isinstance(err_obj, dict):
for field in ("message", "type", "code", "param"):
value = err_obj.get(field)
if value is not None:
_append_candidate(value)
elif isinstance(body, str):
_append_candidate(body)
response = getattr(error, "response", None)
if response is not None:
response_text = getattr(response, "text", None)
if isinstance(response_text, str):
_append_candidate(response_text)
return normalize_and_dedupe_strings(candidates)
def _is_content_moderated_upload_error(self, error: Exception) -> bool:
patterns = [
pattern.lower() for pattern in self._get_image_moderation_error_patterns()
]
if not patterns:
return False
candidates = [
candidate.lower()
for candidate in self._extract_error_text_candidates(error)
]
for pattern in patterns:
if any(pattern in candidate for candidate in candidates):
return True
return False
@staticmethod
def _context_contains_image(contexts: list[dict]) -> bool:
for context in contexts:
content = context.get("content")
if not isinstance(content, list):
continue
for item in content:
if isinstance(item, dict) and item.get("type") == "image_url":
return True
return False
async def _fallback_to_text_only_and_retry(
self,
payloads: dict,
context_query: list,
chosen_key: str,
available_api_keys: list[str],
func_tool: ToolSet | None,
reason: str,
*,
image_fallback_used: bool = False,
) -> tuple:
logger.warning(
"检测到图片请求失败(%s),已移除图片并重试(保留文本内容)。",
reason,
)
new_contexts = await self._remove_image_from_context(context_query)
payloads["messages"] = new_contexts
return (
False,
chosen_key,
available_api_keys,
payloads,
new_contexts,
func_tool,
image_fallback_used,
)
def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient | None:
"""创建带代理的 HTTP 客户端"""
proxy = provider_config.get("proxy", "")
@@ -222,7 +97,6 @@ class ProviderOpenAIOfficial(Provider):
except NotFoundError as e:
raise Exception(f"获取模型列表失败:{e}")
@with_model_request_retry()
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
if tools:
model = payloads.get("model", "").lower()
@@ -248,6 +122,8 @@ class ProviderOpenAIOfficial(Provider):
if isinstance(custom_extra_body, dict):
extra_body.update(custom_extra_body)
model = payloads.get("model", "").lower()
completion = await self.client.chat.completions.create(
**payloads,
stream=False,
@@ -269,17 +145,6 @@ class ProviderOpenAIOfficial(Provider):
self,
payloads: dict,
tools: ToolSet | None,
) -> AsyncGenerator[LLMResponse, None]:
async for attempt in get_model_request_async_retrying():
with attempt:
async for response in self._query_stream_once(payloads, tools):
yield response
return
async def _query_stream_once(
self,
payloads: dict,
tools: ToolSet | None,
) -> AsyncGenerator[LLMResponse, None]:
"""流式查询API,逐步返回结果"""
if tools:
@@ -334,8 +199,7 @@ class ProviderOpenAIOfficial(Provider):
llm_response.reasoning_content = reasoning
_y = True
if delta.content:
# Don't strip streaming chunks to preserve spaces between words
completion_text = self._normalize_content(delta.content, strip=False)
completion_text = delta.content
llm_response.result_chain = MessageChain(
chain=[Comp.Plain(completion_text)],
)
@@ -383,86 +247,6 @@ class ProviderOpenAIOfficial(Provider):
output=completion_tokens,
)
@staticmethod
def _normalize_content(raw_content: Any, strip: bool = True) -> str:
"""Normalize content from various formats to plain string.
Some LLM providers return content as list[dict] format
like [{'type': 'text', 'text': '...'}] instead of
plain string. This method handles both formats.
Args:
raw_content: The raw content from LLM response, can be str, list, or other.
strip: Whether to strip whitespace from the result. Set to False for
streaming chunks to preserve spaces between words.
Returns:
Normalized plain text string.
"""
if isinstance(raw_content, list):
# Check if this looks like OpenAI content-part format
# Only process if at least one item has {'type': 'text', 'text': ...} structure
has_content_part = any(
isinstance(part, dict) and part.get("type") == "text"
for part in raw_content
)
if has_content_part:
text_parts = []
for part in raw_content:
if isinstance(part, dict) and part.get("type") == "text":
text_val = part.get("text", "")
# Coerce to str in case text is null or non-string
text_parts.append(str(text_val) if text_val is not None else "")
return "".join(text_parts)
# Not content-part format, return string representation
return str(raw_content)
if isinstance(raw_content, str):
content = raw_content.strip() if strip else raw_content
# Check if the string is a JSON-encoded list (e.g., "[{'type': 'text', ...}]")
# This can happen when streaming concatenates content that was originally list format
# Only check if it looks like a complete JSON array (requires strip for check)
check_content = raw_content.strip()
if (
check_content.startswith("[")
and check_content.endswith("]")
and len(check_content) < 8192
):
try:
# First try standard JSON parsing
parsed = json.loads(check_content)
except json.JSONDecodeError:
# If that fails, try parsing as Python literal (handles single quotes)
# This is safer than blind replace("'", '"') which corrupts apostrophes
try:
import ast
parsed = ast.literal_eval(check_content)
except (ValueError, SyntaxError):
parsed = None
if isinstance(parsed, list):
# Only convert if it matches OpenAI content-part schema
# i.e., at least one item has {'type': 'text', 'text': ...}
has_content_part = any(
isinstance(part, dict) and part.get("type") == "text"
for part in parsed
)
if has_content_part:
text_parts = []
for part in parsed:
if isinstance(part, dict) and part.get("type") == "text":
text_val = part.get("text", "")
# Coerce to str in case text is null or non-string
text_parts.append(
str(text_val) if text_val is not None else ""
)
if text_parts:
return "".join(text_parts)
return content
return str(raw_content)
async def _parse_openai_completion(
self, completion: ChatCompletion, tools: ToolSet | None
) -> LLMResponse:
@@ -475,7 +259,8 @@ class ProviderOpenAIOfficial(Provider):
# parse the text completion
if choice.message.content is not None:
completion_text = self._normalize_content(choice.message.content)
# text completion
completion_text = str(choice.message.content).strip()
# specially, some providers may set <think> tags around reasoning content in the completion text,
# we use regex to remove them, and store then in reasoning_content field
reasoning_pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL)
@@ -485,8 +270,6 @@ class ProviderOpenAIOfficial(Provider):
[match.strip() for match in matches],
)
completion_text = reasoning_pattern.sub("", completion_text).strip()
# Also clean up orphan </think> tags that may leak from some models
completion_text = re.sub(r"</think>\s*$", "", completion_text).strip()
llm_response.result_chain = MessageChain().message(completion_text)
# parse the reasoning content if any
@@ -620,7 +403,6 @@ class ProviderOpenAIOfficial(Provider):
available_api_keys: list[str],
retry_cnt: int,
max_retries: int,
image_fallback_used: bool = False,
) -> tuple:
"""处理API错误并尝试恢复"""
if "429" in str(e):
@@ -640,7 +422,6 @@ class ProviderOpenAIOfficial(Provider):
payloads,
context_query,
func_tool,
image_fallback_used,
)
raise e
if "maximum context length" in str(e):
@@ -656,34 +437,20 @@ class ProviderOpenAIOfficial(Provider):
payloads,
context_query,
func_tool,
image_fallback_used,
)
if "The model is not a VLM" in str(e): # siliconcloud
if image_fallback_used or not self._context_contains_image(context_query):
raise e
# 尝试删除所有 image
return await self._fallback_to_text_only_and_retry(
payloads,
context_query,
new_contexts = await self._remove_image_from_context(context_query)
payloads["messages"] = new_contexts
context_query = new_contexts
return (
False,
chosen_key,
available_api_keys,
func_tool,
"model_not_vlm",
image_fallback_used=True,
)
if self._is_content_moderated_upload_error(e):
if image_fallback_used or not self._context_contains_image(context_query):
raise e
return await self._fallback_to_text_only_and_retry(
payloads,
context_query,
chosen_key,
available_api_keys,
func_tool,
"image_content_moderated",
image_fallback_used=True,
)
if (
"Function calling is not enabled" in str(e)
or ("tool" in str(e).lower() and "support" in str(e).lower())
@@ -694,15 +461,7 @@ class ProviderOpenAIOfficial(Provider):
f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。",
)
payloads.pop("tools", None)
return (
False,
chosen_key,
available_api_keys,
payloads,
context_query,
None,
image_fallback_used,
)
return False, chosen_key, available_api_keys, payloads, context_query, None
# logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
if "tool" in str(e).lower() and "support" in str(e).lower():
@@ -727,7 +486,7 @@ class ProviderOpenAIOfficial(Provider):
extra_user_content_parts=None,
**kwargs,
) -> LLMResponse:
payloads, _ = await self._prepare_chat_payload(
payloads, context_query = await self._prepare_chat_payload(
prompt,
image_urls,
contexts,
@@ -739,9 +498,44 @@ class ProviderOpenAIOfficial(Provider):
)
llm_response = None
if self.api_keys:
self.client.api_key = random.choice(self.api_keys)
llm_response = await self._query(payloads, func_tool)
max_retries = 10
available_api_keys = self.api_keys.copy()
chosen_key = random.choice(available_api_keys)
last_exception = None
retry_cnt = 0
for retry_cnt in range(max_retries):
try:
self.client.api_key = chosen_key
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
last_exception = e
(
success,
chosen_key,
available_api_keys,
payloads,
context_query,
func_tool,
) = await self._handle_api_error(
e,
payloads,
context_query,
func_tool,
chosen_key,
available_api_keys,
retry_cnt,
max_retries,
)
if success:
break
if retry_cnt == max_retries - 1 or llm_response is None:
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
if last_exception is None:
raise Exception("未知错误")
raise last_exception
return llm_response
async def text_chat_stream(
@@ -757,7 +551,7 @@ class ProviderOpenAIOfficial(Provider):
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""流式对话,与服务商交互并逐步返回结果"""
payloads, _ = await self._prepare_chat_payload(
payloads, context_query = await self._prepare_chat_payload(
prompt,
image_urls,
contexts,
@@ -767,10 +561,45 @@ class ProviderOpenAIOfficial(Provider):
**kwargs,
)
if self.api_keys:
self.client.api_key = random.choice(self.api_keys)
async for response in self._query_stream(payloads, func_tool):
yield response
max_retries = 10
available_api_keys = self.api_keys.copy()
chosen_key = random.choice(available_api_keys)
last_exception = None
retry_cnt = 0
for retry_cnt in range(max_retries):
try:
self.client.api_key = chosen_key
async for response in self._query_stream(payloads, func_tool):
yield response
break
except Exception as e:
last_exception = e
(
success,
chosen_key,
available_api_keys,
payloads,
context_query,
func_tool,
) = await self._handle_api_error(
e,
payloads,
context_query,
func_tool,
chosen_key,
available_api_keys,
retry_cnt,
max_retries,
)
if success:
break
if retry_cnt == max_retries - 1:
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
if last_exception is None:
raise Exception("未知错误")
raise last_exception
async def _remove_image_from_context(self, contexts: list):
"""从上下文中删除所有带有 image 的记录"""
@@ -5,7 +5,7 @@ import httpx
from openai import NOT_GIVEN, AsyncOpenAI
from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from ..entities import ProviderType
from ..provider import TTSProvider
@@ -46,7 +46,7 @@ class ProviderOpenAITTSAPI(TTSProvider):
self.set_model(provider_config.get("model", ""))
async def get_audio(self, text: str) -> str:
temp_dir = get_astrbot_temp_path()
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, f"openai_tts_api_{uuid.uuid4()}.wav")
async with self.client.audio.speech.with_streaming_response.create(
model=self.model_name,
@@ -8,7 +8,6 @@ import uuid
import aiohttp
from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..entities import ProviderType
from ..provider import TTSProvider
@@ -93,12 +92,9 @@ class ProviderVolcengineTTS(TTSProvider):
if "data" in resp_data:
audio_data = base64.b64decode(resp_data["data"])
temp_dir = get_astrbot_temp_path()
os.makedirs(temp_dir, exist_ok=True)
file_path = os.path.join(
temp_dir,
f"volcengine_tts_{uuid.uuid4()}.mp3",
)
os.makedirs("data/temp", exist_ok=True)
file_path = f"data/temp/volcengine_tts_{uuid.uuid4()}.mp3"
loop = asyncio.get_running_loop()
await loop.run_in_executor(
@@ -4,7 +4,7 @@ import uuid
from openai import NOT_GIVEN, AsyncOpenAI
from astrbot.core import logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file
from astrbot.core.utils.tencent_record_helper import (
convert_to_pcm_wav,
@@ -65,11 +65,9 @@ class ProviderOpenAIWhisperAPI(STTProvider):
if "multimedia.nt.qq.com.cn" in audio_url:
is_tencent = True
temp_dir = get_astrbot_temp_path()
path = os.path.join(
temp_dir,
f"whisper_api_{uuid.uuid4().hex[:8]}.input",
)
name = str(uuid.uuid4())
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, name)
await download_file(audio_url, path)
audio_url = path
@@ -81,11 +79,8 @@ class ProviderOpenAIWhisperAPI(STTProvider):
# 判断是否需要转换
if file_format in ["silk", "amr"]:
temp_dir = get_astrbot_temp_path()
output_path = os.path.join(
temp_dir,
f"whisper_api_{uuid.uuid4().hex[:8]}.wav",
)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
if file_format == "silk":
logger.info(
@@ -6,7 +6,7 @@ from typing import cast
import whisper
from astrbot.core import logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
@@ -58,11 +58,9 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
if "multimedia.nt.qq.com.cn" in audio_url:
is_tencent = True
temp_dir = get_astrbot_temp_path()
path = os.path.join(
temp_dir,
f"whisper_selfhost_{uuid.uuid4().hex[:8]}.input",
)
name = str(uuid.uuid4())
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
path = os.path.join(temp_dir, name)
await download_file(audio_url, path)
audio_url = path
@@ -73,11 +71,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
is_silk = await self._is_silk_file(audio_url)
if is_silk:
logger.info("Converting silk file to wav ...")
temp_dir = get_astrbot_temp_path()
output_path = os.path.join(
temp_dir,
f"whisper_selfhost_{uuid.uuid4().hex[:8]}.wav",
)
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
await tencent_silk_to_wav(audio_url, output_path)
audio_url = output_path
@@ -7,7 +7,7 @@ from xinference_client.client.restful.async_restful_client import (
)
from astrbot.core import logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.tencent_record_helper import (
convert_to_pcm_wav,
tencent_silk_to_wav,
@@ -130,17 +130,11 @@ class ProviderXinferenceSTT(STTProvider):
logger.info(
f"Audio requires conversion ({conversion_type}), using temporary files..."
)
temp_dir = get_astrbot_temp_path()
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(temp_dir, exist_ok=True)
input_path = os.path.join(
temp_dir,
f"xinference_stt_{uuid.uuid4().hex[:8]}.input",
)
output_path = os.path.join(
temp_dir,
f"xinference_stt_{uuid.uuid4().hex[:8]}.wav",
)
input_path = os.path.join(temp_dir, str(uuid.uuid4()))
output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
temp_files.extend([input_path, output_path])
with open(input_path, "wb") as f:
+1
View File
@@ -93,6 +93,7 @@ class SkillManager:
self.skills_root = skills_root or get_astrbot_skills_path()
self.config_path = os.path.join(get_astrbot_data_path(), SKILLS_CONFIG_FILENAME)
os.makedirs(self.skills_root, exist_ok=True)
os.makedirs(get_astrbot_temp_path(), exist_ok=True)
def _load_config(self) -> dict:
if not os.path.exists(self.config_path):
@@ -20,7 +20,6 @@ class PlatformAdapterType(enum.Flag):
WEIXIN_OFFICIAL_ACCOUNT = enum.auto()
SATORI = enum.auto()
MISSKEY = enum.auto()
LINE = enum.auto()
ALL = (
AIOCQHTTP
| QQOFFICIAL
@@ -35,7 +34,6 @@ class PlatformAdapterType(enum.Flag):
| WEIXIN_OFFICIAL_ACCOUNT
| SATORI
| MISSKEY
| LINE
)
@@ -53,7 +51,6 @@ ADAPTER_NAME_2_TYPE = {
"weixin_official_account": PlatformAdapterType.WEIXIN_OFFICIAL_ACCOUNT,
"satori": PlatformAdapterType.SATORI,
"misskey": PlatformAdapterType.MISSKEY,
"line": PlatformAdapterType.LINE,
}
+107 -247
View File
@@ -62,9 +62,6 @@ class PluginManager:
self._pm_lock = asyncio.Lock()
"""StarManager操作互斥锁"""
self.failed_plugin_dict = {}
"""加载失败插件的信息,用于后续可能的热重载"""
self.failed_plugin_info = ""
if os.getenv("ASTRBOT_RELOAD", "0") == "1":
asyncio.create_task(self._watch_plugins_changes())
@@ -194,38 +191,6 @@ class PluginManager:
await pip_installer.install(requirements_path=pth)
except Exception as e:
logger.error(f"更新插件 {p} 的依赖失败。Code: {e!s}")
return True
async def _import_plugin_with_dependency_recovery(
self,
path: str,
module_str: str,
root_dir_name: str,
requirements_path: str,
) -> ModuleType:
try:
return __import__(path, fromlist=[module_str])
except (ModuleNotFoundError, ImportError) as import_exc:
if os.path.exists(requirements_path):
try:
logger.info(
f"插件 {root_dir_name} 导入失败,尝试从已安装依赖恢复: {import_exc!s}"
)
pip_installer.prefer_installed_dependencies(
requirements_path=requirements_path
)
module = __import__(path, fromlist=[module_str])
logger.info(
f"插件 {root_dir_name} 已从 site-packages 恢复依赖,跳过重新安装。"
)
return module
except Exception as recover_exc:
logger.info(
f"插件 {root_dir_name} 已安装依赖恢复失败,将重新安装依赖: {recover_exc!s}"
)
await self._check_plugin_dept_update(target_plugin=root_dir_name)
return __import__(path, fromlist=[module_str])
@staticmethod
def _load_plugin_metadata(plugin_path: str, plugin_obj=None) -> StarMetadata | None:
@@ -330,28 +295,6 @@ class PluginManager:
except KeyError:
logger.warning(f"模块 {module_name} 未载入")
async def reload_failed_plugin(self, dir_name):
"""
重新加载未注册加载失败的插件
Args:
dir_name (str): 要重载的特定插件名称
Returns:
tuple: 返回 load() 方法的结果包含 (success, error_message)
- success (bool): 重载是否成功
- error_message (str|None): 错误信息成功时为 None
"""
async with self._pm_lock:
if dir_name in self.failed_plugin_dict:
success, error = await self.load(specified_dir_name=dir_name)
if success:
self.failed_plugin_dict.pop(dir_name, None)
if not self.failed_plugin_dict:
self.failed_plugin_info = ""
return success, None
else:
return False, error
return False, "插件不存在于失败列表中"
async def reload(self, specified_plugin_name=None):
"""重新加载插件
@@ -442,12 +385,6 @@ class PluginManager:
"reserved",
False,
) # 是否是保留插件。目前在 astrbot/builtin_stars 目录下的都是保留插件。保留插件不可以卸载。
plugin_dir_path = (
os.path.join(self.plugin_store_path, root_dir_name)
if not reserved
else os.path.join(self.reserved_plugin_path, root_dir_name)
)
requirements_path = os.path.join(plugin_dir_path, "requirements.txt")
path = "data.plugins." if not reserved else "astrbot.builtin_stars."
path += root_dir_name + "." + module_str
@@ -462,12 +399,11 @@ class PluginManager:
# 尝试导入模块
try:
module = await self._import_plugin_with_dependency_recovery(
path=path,
module_str=module_str,
root_dir_name=root_dir_name,
requirements_path=requirements_path,
)
module = __import__(path, fromlist=[module_str])
except (ModuleNotFoundError, ImportError):
# 尝试安装依赖
await self._check_plugin_dept_update(target_plugin=root_dir_name)
module = __import__(path, fromlist=[module_str])
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"插件 {root_dir_name} 导入失败。原因:{e!s}")
@@ -475,6 +411,11 @@ class PluginManager:
# 检查 _conf_schema.json
plugin_config = None
plugin_dir_path = (
os.path.join(self.plugin_store_path, root_dir_name)
if not reserved
else os.path.join(self.reserved_plugin_path, root_dir_name)
)
plugin_schema_path = os.path.join(
plugin_dir_path,
self.conf_schema_fname,
@@ -688,11 +629,6 @@ class PluginManager:
logger.error(f"| {line}")
logger.error("----------------------------------")
fail_rec += f"加载 {root_dir_name} 插件时出现问题,原因 {e!s}\n"
self.failed_plugin_dict[root_dir_name] = {
"error": str(e),
"traceback": errors,
}
# 记录注册失败的插件名称,以便后续重载插件
# 清除 pip.main 导致的多余的 logging handlers
for handler in logging.root.handlers[:]:
@@ -708,49 +644,6 @@ class PluginManager:
self.failed_plugin_info = fail_rec
return False, fail_rec
async def _cleanup_failed_plugin_install(
self,
dir_name: str,
plugin_path: str,
) -> None:
plugin = None
for star in self.context.get_all_stars():
if star.root_dir_name == dir_name:
plugin = star
break
if plugin and plugin.name and plugin.module_path:
try:
await self._terminate_plugin(plugin)
except Exception:
logger.warning(traceback.format_exc())
try:
await self._unbind_plugin(plugin.name, plugin.module_path)
except Exception:
logger.warning(traceback.format_exc())
if os.path.exists(plugin_path):
try:
remove_dir(plugin_path)
logger.warning(f"已清理安装失败的插件目录: {plugin_path}")
except Exception as e:
logger.warning(
f"清理安装失败插件目录失败: {plugin_path},原因: {e!s}",
)
plugin_config_path = os.path.join(
self.plugin_config_path,
f"{dir_name}_config.json",
)
if os.path.exists(plugin_config_path):
try:
os.remove(plugin_config_path)
logger.warning(f"已清理安装失败插件配置: {plugin_config_path}")
except Exception as e:
logger.warning(
f"清理安装失败插件配置失败: {plugin_config_path},原因: {e!s}",
)
async def install_plugin(self, repo_url: str, proxy=""):
"""从仓库 URL 安装插件
@@ -776,62 +669,44 @@ class PluginManager:
)
async with self._pm_lock:
plugin_path = ""
dir_name = ""
cleanup_required = False
try:
plugin_path = await self.updator.install(repo_url, proxy)
cleanup_required = True
plugin_path = await self.updator.install(repo_url, proxy)
# reload the plugin
dir_name = os.path.basename(plugin_path)
await self.load(specified_dir_name=dir_name)
# reload the plugin
dir_name = os.path.basename(plugin_path)
success, error_message = await self.load(specified_dir_name=dir_name)
if not success:
raise Exception(
error_message
or f"安装插件 {dir_name} 失败,请检查插件依赖或兼容性。"
# Get the plugin metadata to return repo info
plugin = self.context.get_registered_star(dir_name)
if not plugin:
# Try to find by other name if directory name doesn't match plugin name
for star in self.context.get_all_stars():
if star.root_dir_name == dir_name:
plugin = star
break
# Extract README.md content if exists
readme_content = None
readme_path = os.path.join(plugin_path, "README.md")
if not os.path.exists(readme_path):
readme_path = os.path.join(plugin_path, "readme.md")
if os.path.exists(readme_path):
try:
with open(readme_path, encoding="utf-8") as f:
readme_content = f.read()
except Exception as e:
logger.warning(
f"读取插件 {dir_name} 的 README.md 文件失败: {e!s}",
)
# Get the plugin metadata to return repo info
plugin = self.context.get_registered_star(dir_name)
if not plugin:
# Try to find by other name if directory name doesn't match plugin name
for star in self.context.get_all_stars():
if star.root_dir_name == dir_name:
plugin = star
break
plugin_info = None
if plugin:
plugin_info = {
"repo": plugin.repo,
"readme": readme_content,
"name": plugin.name,
}
# Extract README.md content if exists
readme_content = None
readme_path = os.path.join(plugin_path, "README.md")
if not os.path.exists(readme_path):
readme_path = os.path.join(plugin_path, "readme.md")
if os.path.exists(readme_path):
try:
with open(readme_path, encoding="utf-8") as f:
readme_content = f.read()
except Exception as e:
logger.warning(
f"读取插件 {dir_name} 的 README.md 文件失败: {e!s}",
)
plugin_info = None
if plugin:
plugin_info = {
"repo": plugin.repo,
"readme": readme_content,
"name": plugin.name,
}
return plugin_info
except Exception:
if cleanup_required and dir_name and plugin_path:
await self._cleanup_failed_plugin_install(
dir_name=dir_name,
plugin_path=plugin_path,
)
raise
return plugin_info
async def uninstall_plugin(
self,
@@ -1093,7 +968,6 @@ class PluginManager:
dir_name = os.path.basename(zip_file_path).replace(".zip", "")
dir_name = dir_name.removesuffix("-master").removesuffix("-main").lower()
desti_dir = os.path.join(self.plugin_store_path, dir_name)
cleanup_required = False
# 第一步:检查是否已安装同目录名的插件,先终止旧插件
existing_plugin = None
@@ -1113,88 +987,74 @@ class PluginManager:
existing_plugin.name, existing_plugin.module_path
)
self.updator.unzip_file(zip_file_path, desti_dir)
# 第二步:解压后,读取新插件的 metadata.yaml,检查是否存在同名但不同目录的插件
try:
self.updator.unzip_file(zip_file_path, desti_dir)
cleanup_required = True
# 第二步:解压后,读取新插件的 metadata.yaml,检查是否存在同名但不同目录的插件
try:
new_metadata = self._load_plugin_metadata(desti_dir)
if new_metadata and new_metadata.name:
for star in self.context.get_all_stars():
if (
star.name == new_metadata.name
and star.root_dir_name != dir_name
):
logger.warning(
f"检测到同名插件 {star.name} 存在于不同目录 {star.root_dir_name},正在终止..."
)
try:
await self._terminate_plugin(star)
except Exception:
logger.warning(traceback.format_exc())
if star.name and star.module_path:
await self._unbind_plugin(star.name, star.module_path)
break # 只处理第一个匹配的
except Exception as e:
logger.debug(f"读取新插件 metadata.yaml 失败,跳过同名检查: {e!s}")
# remove the zip
try:
os.remove(zip_file_path)
except BaseException as e:
logger.warning(f"删除插件压缩包失败: {e!s}")
# await self.reload()
success, error_message = await self.load(specified_dir_name=dir_name)
if not success:
raise Exception(
error_message
or f"安装插件 {dir_name} 失败,请检查插件依赖或兼容性。"
)
# Get the plugin metadata to return repo info
plugin = self.context.get_registered_star(dir_name)
if not plugin:
# Try to find by other name if directory name doesn't match plugin name
new_metadata = self._load_plugin_metadata(desti_dir)
if new_metadata and new_metadata.name:
for star in self.context.get_all_stars():
if star.root_dir_name == dir_name:
plugin = star
break
if (
star.name == new_metadata.name
and star.root_dir_name != dir_name
):
logger.warning(
f"检测到同名插件 {star.name} 存在于不同目录 {star.root_dir_name},正在终止..."
)
try:
await self._terminate_plugin(star)
except Exception:
logger.warning(traceback.format_exc())
if star.name and star.module_path:
await self._unbind_plugin(star.name, star.module_path)
break # 只处理第一个匹配的
except Exception as e:
logger.debug(f"读取新插件 metadata.yaml 失败,跳过同名检查: {e!s}")
# Extract README.md content if exists
readme_content = None
readme_path = os.path.join(desti_dir, "README.md")
if not os.path.exists(readme_path):
readme_path = os.path.join(desti_dir, "readme.md")
# remove the zip
try:
os.remove(zip_file_path)
except BaseException as e:
logger.warning(f"删除插件压缩包失败: {e!s}")
# await self.reload()
await self.load(specified_dir_name=dir_name)
if os.path.exists(readme_path):
try:
with open(readme_path, encoding="utf-8") as f:
readme_content = f.read()
except Exception as e:
logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {e!s}")
# Get the plugin metadata to return repo info
plugin = self.context.get_registered_star(dir_name)
if not plugin:
# Try to find by other name if directory name doesn't match plugin name
for star in self.context.get_all_stars():
if star.root_dir_name == dir_name:
plugin = star
break
plugin_info = None
if plugin:
plugin_info = {
"repo": plugin.repo,
"readme": readme_content,
"name": plugin.name,
}
# Extract README.md content if exists
readme_content = None
readme_path = os.path.join(desti_dir, "README.md")
if not os.path.exists(readme_path):
readme_path = os.path.join(desti_dir, "readme.md")
if plugin.repo:
asyncio.create_task(
Metric.upload(
et="install_star_f", # install star
repo=plugin.repo,
),
)
if os.path.exists(readme_path):
try:
with open(readme_path, encoding="utf-8") as f:
readme_content = f.read()
except Exception as e:
logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {e!s}")
return plugin_info
except Exception:
if cleanup_required:
await self._cleanup_failed_plugin_install(
dir_name=dir_name,
plugin_path=desti_dir,
plugin_info = None
if plugin:
plugin_info = {
"repo": plugin.repo,
"readme": readme_content,
"name": plugin.name,
}
if plugin.repo:
asyncio.create_task(
Metric.upload(
et="install_star_f", # install star
repo=plugin.repo,
),
)
raise
return plugin_info
+19 -72
View File
@@ -44,73 +44,6 @@ class AstrBotUpdator(RepoZipUpdator):
except psutil.NoSuchProcess:
pass
@staticmethod
def _is_option_arg(arg: str) -> bool:
return arg.startswith("-")
@classmethod
def _collect_flag_values(cls, argv: list[str], flag: str) -> str | None:
try:
idx = argv.index(flag)
except ValueError:
return None
if idx + 1 >= len(argv):
return None
value_parts: list[str] = []
for arg in argv[idx + 1 :]:
if cls._is_option_arg(arg):
break
if arg:
value_parts.append(arg)
if not value_parts:
return None
return " ".join(value_parts).strip() or None
@classmethod
def _resolve_webui_dir_arg(cls, argv: list[str]) -> str | None:
return cls._collect_flag_values(argv, "--webui-dir")
def _build_frozen_reboot_args(self) -> list[str]:
argv = list(sys.argv[1:])
webui_dir = self._resolve_webui_dir_arg(argv)
if not webui_dir:
webui_dir = os.environ.get("ASTRBOT_WEBUI_DIR")
if webui_dir:
return ["--webui-dir", webui_dir]
return []
@staticmethod
def _reset_pyinstaller_environment() -> None:
if not getattr(sys, "frozen", False):
return
os.environ["PYINSTALLER_RESET_ENVIRONMENT"] = "1"
for key in list(os.environ.keys()):
if key.startswith("_PYI_"):
os.environ.pop(key, None)
def _build_reboot_argv(self, executable: str) -> list[str]:
if os.environ.get("ASTRBOT_CLI") == "1":
args = sys.argv[1:]
return [executable, "-m", "astrbot.cli.__main__", *args]
if getattr(sys, "frozen", False):
args = self._build_frozen_reboot_args()
return [executable, *args]
return [executable, *sys.argv]
@staticmethod
def _exec_reboot(executable: str, argv: list[str]) -> None:
if os.name == "nt" and getattr(sys, "frozen", False):
quoted_executable = f'"{executable}"' if " " in executable else executable
quoted_args = [f'"{arg}"' if " " in arg else arg for arg in argv[1:]]
os.execl(executable, quoted_executable, *quoted_args)
return
os.execv(executable, argv)
def _reboot(self, delay: int = 3) -> None:
"""重启当前程序
在指定的延迟后终止所有子进程并重新启动程序
@@ -118,14 +51,28 @@ class AstrBotUpdator(RepoZipUpdator):
"""
time.sleep(delay)
self.terminate_child_processes()
executable = sys.executable
if os.name == "nt":
py = f'"{sys.executable}"'
else:
py = sys.executable
try:
self._reset_pyinstaller_environment()
reboot_argv = self._build_reboot_argv(executable)
self._exec_reboot(executable, reboot_argv)
# 仅 CLI 模式走 `python -m astrbot.cli.__main__`
# 打包后的后端可执行文件需要直接 exec 自身。
if os.environ.get("ASTRBOT_CLI") == "1":
if os.name == "nt":
args = [f'"{arg}"' if " " in arg else arg for arg in sys.argv[1:]]
else:
args = sys.argv[1:]
os.execl(sys.executable, py, "-m", "astrbot.cli.__main__", *args)
else:
if getattr(sys, "frozen", False):
# Frozen executable should not receive argv[0] as a positional argument.
os.execl(sys.executable, py, *sys.argv[1:])
else:
os.execl(sys.executable, py, *sys.argv)
except Exception as e:
logger.error(f"重启失败({executable}, {e}),请尝试手动重启。")
logger.error(f"重启失败({py}, {e}),请尝试手动重启。")
raise e
async def check_update(
-4
View File
@@ -15,8 +15,6 @@ Skills 目录路径:固定为数据目录下的 skills 目录
import os
from astrbot.core.utils.runtime_env import is_packaged_electron_runtime
def get_astrbot_path() -> str:
"""获取Astrbot项目路径"""
@@ -29,8 +27,6 @@ def get_astrbot_root() -> str:
"""获取Astrbot根目录路径"""
if path := os.environ.get("ASTRBOT_ROOT"):
return os.path.realpath(path)
if is_packaged_electron_runtime():
return os.path.realpath(os.path.join(os.path.expanduser("~"), ".astrbot"))
return os.path.realpath(os.getcwd())
-33
View File
@@ -1,33 +0,0 @@
import logging
import ssl
import threading
import aiohttp
from astrbot.utils.http_ssl_common import (
build_ssl_context_with_certifi as _build_ssl_context,
)
logger = logging.getLogger("astrbot")
_SHARED_TLS_CONTEXT: ssl.SSLContext | None = None
_SHARED_TLS_CONTEXT_LOCK = threading.Lock()
def build_ssl_context_with_certifi() -> ssl.SSLContext:
"""Build an SSL context from system trust store and add certifi CAs."""
global _SHARED_TLS_CONTEXT
if _SHARED_TLS_CONTEXT is not None:
return _SHARED_TLS_CONTEXT
with _SHARED_TLS_CONTEXT_LOCK:
if _SHARED_TLS_CONTEXT is not None:
return _SHARED_TLS_CONTEXT
_SHARED_TLS_CONTEXT = _build_ssl_context(log_obj=logger)
return _SHARED_TLS_CONTEXT
def build_tls_connector() -> aiohttp.TCPConnector:
return aiohttp.TCPConnector(ssl=build_ssl_context_with_certifi())
+14 -3
View File
@@ -14,7 +14,7 @@ import certifi
import psutil
from PIL import Image
from .astrbot_path import get_astrbot_data_path, get_astrbot_temp_path
from .astrbot_path import get_astrbot_data_path
logger = logging.getLogger("astrbot")
@@ -50,10 +50,21 @@ def port_checker(port: int, host: str = "localhost") -> bool:
def save_temp_img(img: Image.Image | bytes) -> str:
temp_dir = get_astrbot_temp_path()
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
# 获得文件创建时间,清除超过 12 小时的
try:
for f in os.listdir(temp_dir):
path = os.path.join(temp_dir, f)
if os.path.isfile(path):
ctime = os.path.getctime(path)
if time.time() - ctime > 3600 * 12:
os.remove(path)
except Exception as e:
print(f"清除临时文件失败: {e}")
# 获得时间戳
timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
p = os.path.join(temp_dir, f"io_temp_img_{timestamp}.jpg")
p = os.path.join(temp_dir, f"{timestamp}.jpg")
if isinstance(img, Image.Image):
img.save(p)
+1 -4
View File
@@ -3,7 +3,6 @@ from typing import Literal, TypedDict
import aiohttp
from astrbot.core import logger
from astrbot.core.utils.http_ssl import build_tls_connector
class LLMModalities(TypedDict):
@@ -33,9 +32,7 @@ LLM_METADATAS: dict[str, LLMMetadata] = {}
async def update_llm_metadata() -> None:
url = "https://models.dev/api.json"
try:
async with aiohttp.ClientSession(
trust_env=True, connector=build_tls_connector()
) as session:
async with aiohttp.ClientSession() as session:
async with session.get(url) as response:
data = await response.json()
global LLM_METADATAS
+5 -116
View File
@@ -7,10 +7,9 @@ import asyncio
import os
import subprocess
import uuid
from pathlib import Path
from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
async def get_media_duration(file_path: str) -> int | None:
@@ -77,9 +76,9 @@ async def convert_audio_to_opus(audio_path: str, output_path: str | None = None)
# 生成输出文件路径
if output_path is None:
temp_dir = get_astrbot_temp_path()
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(temp_dir, exist_ok=True)
output_path = os.path.join(temp_dir, f"media_audio_{uuid.uuid4().hex}.opus")
output_path = os.path.join(temp_dir, f"{uuid.uuid4()}.opus")
try:
# 使用ffmpeg转换为opus格式
@@ -156,12 +155,9 @@ async def convert_video_format(
# 生成输出文件路径
if output_path is None:
temp_dir = get_astrbot_temp_path()
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(temp_dir, exist_ok=True)
output_path = os.path.join(
temp_dir,
f"media_video_{uuid.uuid4().hex}.{output_format}",
)
output_path = os.path.join(temp_dir, f"{uuid.uuid4()}.{output_format}")
try:
# 使用ffmpeg转换视频格式
@@ -209,110 +205,3 @@ async def convert_video_format(
except Exception as e:
logger.error(f"[Media Utils] 转换视频格式时出错: {e}")
raise
async def convert_audio_format(
audio_path: str,
output_format: str = "amr",
output_path: str | None = None,
) -> str:
"""使用ffmpeg将音频转换为指定格式。
Args:
audio_path: 原始音频文件路径
output_format: 目标格式例如 amr / ogg
output_path: 输出文件路径如果为None则自动生成
Returns:
转换后的音频文件路径
"""
if audio_path.lower().endswith(f".{output_format}"):
return audio_path
if output_path is None:
temp_dir = Path(get_astrbot_temp_path())
temp_dir.mkdir(parents=True, exist_ok=True)
output_path = str(temp_dir / f"media_audio_{uuid.uuid4().hex}.{output_format}")
args = ["ffmpeg", "-y", "-i", audio_path]
if output_format == "amr":
args.extend(["-ac", "1", "-ar", "8000", "-ab", "12.2k"])
elif output_format == "ogg":
args.extend(["-acodec", "libopus", "-ac", "1", "-ar", "16000"])
args.append(output_path)
try:
process = await asyncio.create_subprocess_exec(
*args,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
_, stderr = await process.communicate()
if process.returncode != 0:
if output_path and os.path.exists(output_path):
try:
os.remove(output_path)
except OSError as e:
logger.warning(f"[Media Utils] 清理失败的音频输出文件时出错: {e}")
error_msg = stderr.decode() if stderr else "未知错误"
raise Exception(f"ffmpeg conversion failed: {error_msg}")
logger.debug(f"[Media Utils] 音频转换成功: {audio_path} -> {output_path}")
return output_path
except FileNotFoundError:
raise Exception("ffmpeg not found")
async def convert_audio_to_amr(audio_path: str, output_path: str | None = None) -> str:
"""将音频转换为amr格式。"""
return await convert_audio_format(
audio_path=audio_path,
output_format="amr",
output_path=output_path,
)
async def convert_audio_to_wav(audio_path: str, output_path: str | None = None) -> str:
"""将音频转换为wav格式。"""
return await convert_audio_format(
audio_path=audio_path,
output_format="wav",
output_path=output_path,
)
async def extract_video_cover(
video_path: str,
output_path: str | None = None,
) -> str:
"""从视频中提取封面图(JPG)。"""
if output_path is None:
temp_dir = Path(get_astrbot_temp_path())
temp_dir.mkdir(parents=True, exist_ok=True)
output_path = str(temp_dir / f"media_cover_{uuid.uuid4().hex}.jpg")
try:
process = await asyncio.create_subprocess_exec(
"ffmpeg",
"-y",
"-i",
video_path,
"-ss",
"00:00:00",
"-frames:v",
"1",
output_path,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
_, stderr = await process.communicate()
if process.returncode != 0:
if output_path and os.path.exists(output_path):
try:
os.remove(output_path)
except OSError as e:
logger.warning(f"[Media Utils] 清理失败的视频封面文件时出错: {e}")
error_msg = stderr.decode() if stderr else "未知错误"
raise Exception(f"ffmpeg extract cover failed: {error_msg}")
return output_path
except FileNotFoundError:
raise Exception("ffmpeg not found")
+3 -1
View File
@@ -77,7 +77,9 @@ def log_connection_failure(
f"代理地址: {effective_proxy},错误: {error}"
)
else:
logger.error(f"[{provider_label}] 网络连接失败 ({error_type})。错误: {error}")
logger.error(
f"[{provider_label}] 网络连接失败 ({error_type}),未配置代理。错误: {error}"
)
def create_proxy_client(
+48 -526
View File
@@ -1,43 +1,44 @@
import asyncio
import contextlib
import importlib
import importlib.metadata as importlib_metadata
import importlib.util
import io
import locale
import logging
import os
import re
import sys
import threading
from collections import deque
from astrbot.core.utils.astrbot_path import get_astrbot_site_packages_path
from astrbot.core.utils.runtime_env import is_packaged_electron_runtime
logger = logging.getLogger("astrbot")
_DISTLIB_FINDER_PATCH_ATTEMPTED = False
_SITE_PACKAGES_IMPORT_LOCK = threading.RLock()
def _robust_decode(line: bytes) -> str:
"""解码字节流,兼容不同平台的编码"""
try:
return line.decode("utf-8").strip()
except UnicodeDecodeError:
pass
try:
return line.decode(locale.getpreferredencoding(False)).strip()
except UnicodeDecodeError:
pass
if sys.platform.startswith("win"):
try:
return line.decode("gbk").strip()
except UnicodeDecodeError:
pass
return line.decode("utf-8", errors="replace").strip()
def _canonicalize_distribution_name(name: str) -> str:
return re.sub(r"[-_.]+", "-", name).strip("-").lower()
def _is_frozen_runtime() -> bool:
return bool(getattr(sys, "frozen", False))
def _get_pip_main():
try:
from pip._internal.cli.main import main as pip_main
except ImportError:
try:
from pip import main as pip_main
except ImportError as exc:
raise ImportError(
"pip module is unavailable "
f"(sys.executable={sys.executable}, "
f"frozen={getattr(sys, 'frozen', False)}, "
f"ASTRBOT_ELECTRON_CLIENT={os.environ.get('ASTRBOT_ELECTRON_CLIENT')})"
) from exc
from pip import main as pip_main
return pip_main
@@ -59,477 +60,6 @@ def _cleanup_added_root_handlers(original_handlers: list[logging.Handler]) -> No
handler.close()
def _prepend_sys_path(path: str) -> None:
normalized_target = os.path.realpath(path)
sys.path[:] = [
item for item in sys.path if os.path.realpath(item) != normalized_target
]
sys.path.insert(0, normalized_target)
def _module_exists_in_site_packages(module_name: str, site_packages_path: str) -> bool:
base_path = os.path.join(site_packages_path, *module_name.split("."))
package_init = os.path.join(base_path, "__init__.py")
module_file = f"{base_path}.py"
return os.path.isfile(package_init) or os.path.isfile(module_file)
def _is_module_loaded_from_site_packages(
module_name: str,
site_packages_path: str,
) -> bool:
module = sys.modules.get(module_name)
if module is None:
try:
module = importlib.import_module(module_name)
except Exception:
return False
module_file = getattr(module, "__file__", None)
if not module_file:
return False
module_path = os.path.realpath(module_file)
site_packages_real = os.path.realpath(site_packages_path)
try:
return (
os.path.commonpath([module_path, site_packages_real]) == site_packages_real
)
except ValueError:
return False
def _extract_requirement_name(raw_requirement: str) -> str | None:
line = raw_requirement.split("#", 1)[0].strip()
if not line:
return None
if line.startswith(("-r", "--requirement", "-c", "--constraint")):
return None
if line.startswith("-"):
return None
egg_match = re.search(r"#egg=([A-Za-z0-9_.-]+)", raw_requirement)
if egg_match:
return _canonicalize_distribution_name(egg_match.group(1))
candidate = re.split(r"[<>=!~;\s\[]", line, maxsplit=1)[0].strip()
if not candidate:
return None
return _canonicalize_distribution_name(candidate)
def _extract_requirement_names(requirements_path: str) -> set[str]:
names: set[str] = set()
try:
with open(requirements_path, encoding="utf-8") as requirements_file:
for line in requirements_file:
requirement_name = _extract_requirement_name(line)
if requirement_name:
names.add(requirement_name)
except Exception as exc:
logger.warning("读取依赖文件失败,跳过冲突检测: %s", exc)
return names
def _extract_top_level_modules(
distribution: importlib_metadata.Distribution,
) -> set[str]:
try:
text = distribution.read_text("top_level.txt") or ""
except Exception:
return set()
modules: set[str] = set()
for line in text.splitlines():
candidate = line.strip()
if not candidate or candidate.startswith("#"):
continue
modules.add(candidate)
return modules
def _collect_candidate_modules(
requirement_names: set[str],
site_packages_path: str,
) -> set[str]:
by_name: dict[str, list[importlib_metadata.Distribution]] = {}
try:
for distribution in importlib_metadata.distributions(path=[site_packages_path]):
distribution_name = distribution.metadata.get("Name")
if not distribution_name:
continue
canonical_name = _canonicalize_distribution_name(distribution_name)
by_name.setdefault(canonical_name, []).append(distribution)
except Exception as exc:
logger.warning("读取 site-packages 元数据失败,使用回退模块名: %s", exc)
expanded_requirement_names: set[str] = set()
pending = deque(requirement_names)
while pending:
requirement_name = pending.popleft()
if requirement_name in expanded_requirement_names:
continue
expanded_requirement_names.add(requirement_name)
for distribution in by_name.get(requirement_name, []):
for dependency_line in distribution.requires or []:
dependency_name = _extract_requirement_name(dependency_line)
if not dependency_name:
continue
if dependency_name in expanded_requirement_names:
continue
pending.append(dependency_name)
candidates: set[str] = set()
for requirement_name in expanded_requirement_names:
matched_distributions = by_name.get(requirement_name, [])
modules_for_requirement: set[str] = set()
for distribution in matched_distributions:
modules_for_requirement.update(_extract_top_level_modules(distribution))
if modules_for_requirement:
candidates.update(modules_for_requirement)
continue
fallback_module_name = requirement_name.replace("-", "_")
if fallback_module_name:
candidates.add(fallback_module_name)
return candidates
def _ensure_preferred_modules(
module_names: set[str],
site_packages_path: str,
) -> None:
unresolved_prefer_reasons = _prefer_modules_from_site_packages(
module_names, site_packages_path
)
unresolved_modules: list[str] = []
for module_name in sorted(module_names):
if not _module_exists_in_site_packages(module_name, site_packages_path):
continue
if _is_module_loaded_from_site_packages(module_name, site_packages_path):
continue
failure_reason = unresolved_prefer_reasons.get(module_name)
if failure_reason:
unresolved_modules.append(f"{module_name} -> {failure_reason}")
continue
loaded_module = sys.modules.get(module_name)
loaded_from = getattr(loaded_module, "__file__", "unknown")
unresolved_modules.append(f"{module_name} -> {loaded_from}")
if unresolved_modules:
conflict_message = (
"检测到插件依赖与当前运行时发生冲突,无法安全加载该插件。"
f"冲突模块: {', '.join(unresolved_modules)}"
)
raise RuntimeError(conflict_message)
def _prefer_module_from_site_packages(
module_name: str, site_packages_path: str
) -> bool:
with _SITE_PACKAGES_IMPORT_LOCK:
base_path = os.path.join(site_packages_path, *module_name.split("."))
package_init = os.path.join(base_path, "__init__.py")
module_file = f"{base_path}.py"
module_location = None
submodule_search_locations = None
if os.path.isfile(package_init):
module_location = package_init
submodule_search_locations = [os.path.dirname(package_init)]
elif os.path.isfile(module_file):
module_location = module_file
else:
return False
spec = importlib.util.spec_from_file_location(
module_name,
module_location,
submodule_search_locations=submodule_search_locations,
)
if spec is None or spec.loader is None:
return False
matched_keys = [
key
for key in list(sys.modules.keys())
if key == module_name or key.startswith(f"{module_name}.")
]
original_modules = {key: sys.modules[key] for key in matched_keys}
try:
for key in matched_keys:
sys.modules.pop(key, None)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
if "." in module_name:
parent_name, child_name = module_name.rsplit(".", 1)
parent_module = sys.modules.get(parent_name)
if parent_module is not None:
setattr(parent_module, child_name, module)
logger.info(
"Loaded %s from plugin site-packages: %s",
module_name,
module_location,
)
return True
except Exception:
failed_keys = [
key
for key in list(sys.modules.keys())
if key == module_name or key.startswith(f"{module_name}.")
]
for key in failed_keys:
sys.modules.pop(key, None)
sys.modules.update(original_modules)
raise
def _extract_conflicting_module_name(exc: Exception) -> str | None:
if isinstance(exc, ModuleNotFoundError):
missing_name = getattr(exc, "name", None)
if missing_name:
return missing_name.split(".", 1)[0]
message = str(exc)
from_match = re.search(r"from '([A-Za-z0-9_.]+)'", message)
if from_match:
return from_match.group(1).split(".", 1)[0]
no_module_match = re.search(r"No module named '([A-Za-z0-9_.]+)'", message)
if no_module_match:
return no_module_match.group(1).split(".", 1)[0]
return None
def _prefer_module_with_dependency_recovery(
module_name: str,
site_packages_path: str,
max_attempts: int = 3,
) -> bool:
recovered_dependencies: set[str] = set()
for _ in range(max_attempts):
try:
return _prefer_module_from_site_packages(module_name, site_packages_path)
except Exception as exc:
dependency_name = _extract_conflicting_module_name(exc)
if (
not dependency_name
or dependency_name == module_name
or dependency_name in recovered_dependencies
):
raise
recovered_dependencies.add(dependency_name)
recovered = _prefer_module_from_site_packages(
dependency_name,
site_packages_path,
)
if not recovered:
raise
logger.info(
"Recovered dependency %s while preferring %s from plugin site-packages.",
dependency_name,
module_name,
)
return False
def _prefer_modules_from_site_packages(
module_names: set[str],
site_packages_path: str,
) -> dict[str, str]:
pending_modules = sorted(module_names)
unresolved_reasons: dict[str, str] = {}
max_rounds = max(2, min(6, len(pending_modules) + 1))
for _ in range(max_rounds):
if not pending_modules:
break
next_round_pending: list[str] = []
round_progress = False
for module_name in pending_modules:
try:
loaded = _prefer_module_with_dependency_recovery(
module_name,
site_packages_path,
)
except Exception as exc:
unresolved_reasons[module_name] = str(exc)
next_round_pending.append(module_name)
continue
unresolved_reasons.pop(module_name, None)
if loaded:
round_progress = True
else:
logger.debug(
"Module %s not found in plugin site-packages: %s",
module_name,
site_packages_path,
)
if not next_round_pending:
pending_modules = []
break
if not round_progress and len(next_round_pending) == len(pending_modules):
pending_modules = next_round_pending
break
pending_modules = next_round_pending
final_unresolved = {
module_name: unresolved_reasons.get(module_name, "unknown import error")
for module_name in pending_modules
}
for module_name, reason in final_unresolved.items():
logger.warning(
"Failed to prefer module %s from plugin site-packages: %s",
module_name,
reason,
)
return final_unresolved
def _ensure_plugin_dependencies_preferred(
target_site_packages: str,
requested_requirements: set[str],
) -> None:
if not requested_requirements:
return
candidate_modules = _collect_candidate_modules(
requested_requirements,
target_site_packages,
)
if not candidate_modules:
return
_ensure_preferred_modules(candidate_modules, target_site_packages)
def _get_loader_for_package(package: object) -> object | None:
loader = getattr(package, "__loader__", None)
if loader is not None:
return loader
spec = getattr(package, "__spec__", None)
if spec is None:
return None
return getattr(spec, "loader", None)
def _try_register_distlib_finder(
distlib_resources: object,
finder_registry: dict[type, object],
register_finder,
resource_finder: object,
loader: object,
package_name: str,
) -> bool:
loader_type = type(loader)
if loader_type in finder_registry:
return False
try:
register_finder(loader, resource_finder)
except Exception as exc:
logger.warning(
"Failed to patch pip distlib finder for loader %s (%s): %s",
loader_type.__name__,
package_name,
exc,
)
return False
updated_registry = getattr(distlib_resources, "_finder_registry", finder_registry)
if isinstance(updated_registry, dict) and loader_type not in updated_registry:
logger.warning(
"Distlib finder patch did not take effect for loader %s (%s).",
loader_type.__name__,
package_name,
)
return False
logger.info(
"Patched pip distlib finder for frozen loader: %s (%s)",
loader_type.__name__,
package_name,
)
return True
def _patch_distlib_finder_for_frozen_runtime() -> None:
global _DISTLIB_FINDER_PATCH_ATTEMPTED
if not getattr(sys, "frozen", False):
return
if _DISTLIB_FINDER_PATCH_ATTEMPTED:
return
_DISTLIB_FINDER_PATCH_ATTEMPTED = True
try:
from pip._vendor.distlib import resources as distlib_resources
except Exception:
return
finder_registry = getattr(distlib_resources, "_finder_registry", None)
register_finder = getattr(distlib_resources, "register_finder", None)
resource_finder = getattr(distlib_resources, "ResourceFinder", None)
if not isinstance(finder_registry, dict):
logger.warning(
"Skip patching distlib finder because _finder_registry is unavailable."
)
return
if not callable(register_finder) or resource_finder is None:
logger.warning(
"Skip patching distlib finder because register API is unavailable."
)
return
for package_name in ("pip._vendor.distlib", "pip._vendor"):
try:
package = importlib.import_module(package_name)
except Exception:
continue
loader = _get_loader_for_package(package)
if loader is None:
continue
if _try_register_distlib_finder(
distlib_resources,
finder_registry,
register_finder,
resource_finder,
loader,
package_name,
):
finder_registry = getattr(
distlib_resources, "_finder_registry", finder_registry
)
class PipInstaller:
def __init__(self, pip_install_arg: str, pypi_index_url: str | None = None) -> None:
self.pip_install_arg = pip_install_arg
@@ -542,68 +72,60 @@ class PipInstaller:
mirror: str | None = None,
) -> None:
args = ["install"]
requested_requirements: set[str] = set()
if package_name:
args.append(package_name)
requirement_name = _extract_requirement_name(package_name)
if requirement_name:
requested_requirements.add(requirement_name)
elif requirements_path:
args.extend(["-r", requirements_path])
requested_requirements = _extract_requirement_names(requirements_path)
index_url = mirror or self.pypi_index_url or "https://pypi.org/simple"
args.extend(["--trusted-host", "mirrors.aliyun.com", "-i", index_url])
target_site_packages = None
if is_packaged_electron_runtime():
if _is_frozen_runtime():
target_site_packages = get_astrbot_site_packages_path()
os.makedirs(target_site_packages, exist_ok=True)
_prepend_sys_path(target_site_packages)
args.extend(["--target", target_site_packages])
args.extend(["--upgrade", "--force-reinstall"])
if self.pip_install_arg:
args.extend(self.pip_install_arg.split())
logger.info(f"Pip 包管理器: pip {' '.join(args)}")
result_code = await self._run_pip_in_process(args)
result_code = None
if _is_frozen_runtime():
result_code = await self._run_pip_in_process(args)
else:
try:
result_code = await self._run_pip_subprocess(args)
except FileNotFoundError:
result_code = await self._run_pip_in_process(args)
if result_code != 0:
raise Exception(f"安装失败,错误码:{result_code}")
if target_site_packages:
_prepend_sys_path(target_site_packages)
_ensure_plugin_dependencies_preferred(
target_site_packages,
requested_requirements,
)
if target_site_packages and target_site_packages not in sys.path:
sys.path.insert(0, target_site_packages)
importlib.invalidate_caches()
def prefer_installed_dependencies(self, requirements_path: str) -> None:
"""优先使用已安装在插件 site-packages 中的依赖,不执行安装。"""
if not is_packaged_electron_runtime():
return
target_site_packages = get_astrbot_site_packages_path()
if not os.path.isdir(target_site_packages):
return
requested_requirements = _extract_requirement_names(requirements_path)
if not requested_requirements:
return
_prepend_sys_path(target_site_packages)
_ensure_plugin_dependencies_preferred(
target_site_packages,
requested_requirements,
async def _run_pip_subprocess(self, args: list[str]) -> int:
process = await asyncio.create_subprocess_exec(
sys.executable,
"-m",
"pip",
*args,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
importlib.invalidate_caches()
assert process.stdout is not None
async for line in process.stdout:
logger.info(_robust_decode(line))
await process.wait()
return process.returncode
async def _run_pip_in_process(self, args: list[str]) -> int:
pip_main = _get_pip_main()
_patch_distlib_finder_for_frozen_runtime()
original_handlers = list(logging.getLogger().handlers)
result_code, output = await asyncio.to_thread(
_run_pip_main_with_output, pip_main, args
@@ -1,8 +0,0 @@
from __future__ import annotations
from .extractor import extract_quoted_message_images, extract_quoted_message_text
__all__ = [
"extract_quoted_message_text",
"extract_quoted_message_images",
]
@@ -1,505 +0,0 @@
from __future__ import annotations
import json
import re
from typing import Any, TypedDict
from astrbot.core.message.components import (
At,
AtAll,
File,
Forward,
Image,
Node,
Nodes,
Plain,
Reply,
Video,
)
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
from .image_refs import looks_like_image_file_name, normalize_file_like_url
from .settings import SETTINGS, QuotedMessageParserSettings
_FORWARD_PLACEHOLDER_PATTERN = re.compile(
r"^(?:[\(\[]?[^\]:\)]*[\)\]]?\s*:\s*)?\[(?:forward message|转发消息|合并转发)\]$",
flags=re.IGNORECASE,
)
class ParsedOneBotPayload(TypedDict):
text: str | None
forward_ids: list[str]
image_refs: list[str]
def _build_parsed_payload(
text: str | None,
forward_ids: list[str] | None = None,
image_refs: list[str] | None = None,
) -> ParsedOneBotPayload:
return {
"text": text,
"forward_ids": forward_ids or [],
"image_refs": image_refs or [],
}
def _join_text_parts(parts: list[str]) -> str | None:
text = "".join(parts).strip()
return text or None
def _find_first_reply_component(event: AstrMessageEvent) -> Reply | None:
for comp in event.message_obj.message:
if isinstance(comp, Reply):
return comp
return None
def _is_forward_placeholder_only_text(text: str | None) -> bool:
if not isinstance(text, str):
return False
lines = [line.strip() for line in text.splitlines() if line.strip()]
if not lines:
return False
return all(_FORWARD_PLACEHOLDER_PATTERN.match(line) for line in lines)
def _extract_image_refs_from_component_chain(
chain: list[Any] | None,
*,
depth: int = 0,
settings: QuotedMessageParserSettings = SETTINGS,
) -> list[str]:
if not isinstance(chain, list) or depth > settings.max_component_chain_depth:
return []
image_refs: list[str] = []
for seg in chain:
if isinstance(seg, Image):
for candidate in (seg.url, seg.file, seg.path):
if isinstance(candidate, str) and candidate.strip():
image_refs.append(candidate.strip())
break
elif isinstance(seg, Reply):
image_refs.extend(
_extract_image_refs_from_reply_component(
seg,
depth=depth + 1,
settings=settings,
)
)
elif isinstance(seg, Node):
image_refs.extend(
_extract_image_refs_from_component_chain(
seg.content,
depth=depth + 1,
settings=settings,
)
)
elif isinstance(seg, Nodes):
for node in seg.nodes:
image_refs.extend(
_extract_image_refs_from_component_chain(
node.content,
depth=depth + 1,
settings=settings,
)
)
return normalize_and_dedupe_strings(image_refs)
def _extract_text_from_component_chain(
chain: list[Any] | None,
*,
depth: int = 0,
settings: QuotedMessageParserSettings = SETTINGS,
) -> str | None:
if not isinstance(chain, list) or depth > settings.max_component_chain_depth:
return None
parts: list[str] = []
for seg in chain:
if isinstance(seg, Plain):
if seg.text:
parts.append(seg.text)
elif isinstance(seg, At):
if seg.name:
parts.append(f"@{seg.name}")
elif seg.qq:
parts.append(f"@{seg.qq}")
elif isinstance(seg, AtAll):
parts.append("@all")
elif isinstance(seg, Image):
parts.append("[Image]")
elif isinstance(seg, Video):
parts.append("[Video]")
elif isinstance(seg, File):
file_name = seg.name or "file"
parts.append(f"[File:{file_name}]")
elif isinstance(seg, Forward):
parts.append("[Forward Message]")
elif isinstance(seg, Reply):
nested = _extract_text_from_reply_component(
seg,
depth=depth + 1,
settings=settings,
)
if nested:
parts.append(nested)
elif isinstance(seg, Node):
node_sender = seg.name or seg.uin or "Unknown User"
node_text = _extract_text_from_component_chain(
seg.content,
depth=depth + 1,
settings=settings,
)
if node_text:
parts.append(f"{node_sender}: {node_text}")
elif isinstance(seg, Nodes):
for node in seg.nodes:
node_sender = node.name or node.uin or "Unknown User"
node_text = _extract_text_from_component_chain(
node.content,
depth=depth + 1,
settings=settings,
)
if node_text:
parts.append(f"{node_sender}: {node_text}")
return _join_text_parts(parts)
def _extract_image_refs_from_reply_component(
reply: Reply,
*,
depth: int = 0,
settings: QuotedMessageParserSettings = SETTINGS,
) -> list[str]:
for attr in ("chain", "message", "origin", "content"):
payload = getattr(reply, attr, None)
image_refs = _extract_image_refs_from_component_chain(
payload,
depth=depth,
settings=settings,
)
if image_refs:
return image_refs
return []
def _extract_text_from_reply_component(
reply: Reply,
*,
depth: int = 0,
settings: QuotedMessageParserSettings = SETTINGS,
) -> str | None:
for attr in ("chain", "message", "origin", "content"):
payload = getattr(reply, attr, None)
text = _extract_text_from_component_chain(
payload,
depth=depth,
settings=settings,
)
if text:
return text
if reply.message_str and reply.message_str.strip():
return reply.message_str.strip()
return None
def _unwrap_onebot_data(payload: Any) -> dict[str, Any]:
if not isinstance(payload, dict):
return {}
data = payload.get("data")
if isinstance(data, dict):
return data
return payload
def _extract_text_from_multimsg_json(raw_json: str) -> str | None:
try:
parsed = json.loads(raw_json)
except Exception:
return None
if not isinstance(parsed, dict):
return None
if parsed.get("app") != "com.tencent.multimsg":
return None
config = parsed.get("config")
if not isinstance(config, dict):
return None
if config.get("forward") != 1:
return None
meta = parsed.get("meta")
if not isinstance(meta, dict):
return None
detail = meta.get("detail")
if not isinstance(detail, dict):
return None
news_items = detail.get("news")
if not isinstance(news_items, list):
return None
texts: list[str] = []
for item in news_items:
if not isinstance(item, dict):
continue
text_content = item.get("text")
if not isinstance(text_content, str):
continue
cleaned = text_content.strip().replace("[图片]", "").strip()
if cleaned:
texts.append(cleaned)
return "\n".join(texts).strip() or None
def _parse_onebot_segments(
segments: list[Any],
*,
settings: QuotedMessageParserSettings = SETTINGS,
) -> ParsedOneBotPayload:
text_parts: list[str] = []
forward_ids: list[str] = []
image_refs: list[str] = []
for seg in segments:
if not isinstance(seg, dict):
continue
seg_type = seg.get("type")
seg_data = seg.get("data", {}) if isinstance(seg.get("data"), dict) else {}
if seg_type in ("text", "plain"):
text = seg_data.get("text")
if isinstance(text, str) and text:
text_parts.append(text)
elif seg_type == "image":
text_parts.append("[Image]")
candidate = seg_data.get("url") or seg_data.get("file")
if isinstance(candidate, str) and candidate.strip():
image_refs.append(candidate.strip())
elif seg_type == "video":
text_parts.append("[Video]")
elif seg_type == "file":
file_name = (
seg_data.get("name")
or seg_data.get("file_name")
or seg_data.get("file")
or "file"
)
text_parts.append(f"[File:{file_name}]")
candidate_url = seg_data.get("url")
if (
isinstance(candidate_url, str)
and candidate_url.strip()
and looks_like_image_file_name(normalize_file_like_url(candidate_url))
):
image_refs.append(candidate_url.strip())
candidate_file = seg_data.get("file")
if (
isinstance(candidate_file, str)
and candidate_file.strip()
and looks_like_image_file_name(
normalize_file_like_url(
seg_data.get("name")
or seg_data.get("file_name")
or candidate_file
)
)
):
image_refs.append(candidate_file.strip())
elif seg_type in ("forward", "forward_msg", "nodes"):
fid = seg_data.get("id") or seg_data.get("message_id")
if isinstance(fid, (str, int)) and str(fid):
forward_ids.append(str(fid))
else:
nested_nodes = seg_data.get("content")
nested_text, nested_forward_ids, nested_images = (
_extract_text_forward_ids_and_images_from_forward_nodes(
nested_nodes if isinstance(nested_nodes, list) else [],
depth=1,
settings=settings,
)
)
if nested_text:
text_parts.append(nested_text)
if nested_forward_ids:
forward_ids.extend(nested_forward_ids)
if nested_images:
image_refs.extend(nested_images)
elif seg_type == "json":
raw_json = seg_data.get("data")
if isinstance(raw_json, str) and raw_json.strip():
raw_json = raw_json.replace("&#44;", ",")
multimsg_text = _extract_text_from_multimsg_json(raw_json)
if multimsg_text:
text_parts.append(multimsg_text)
return _build_parsed_payload(
_join_text_parts(text_parts),
forward_ids,
normalize_and_dedupe_strings(image_refs),
)
def _extract_text_forward_ids_and_images_from_forward_nodes(
nodes: list[Any],
*,
depth: int = 0,
settings: QuotedMessageParserSettings = SETTINGS,
) -> tuple[str | None, list[str], list[str]]:
if not isinstance(nodes, list) or depth > settings.max_forward_node_depth:
return None, [], []
texts: list[str] = []
forward_ids: list[str] = []
image_refs: list[str] = []
indent = " " * depth
for node in nodes:
if not isinstance(node, dict):
continue
sender = node.get("sender") if isinstance(node.get("sender"), dict) else {}
sender_name = (
sender.get("nickname")
or sender.get("card")
or sender.get("user_id")
or "Unknown User"
)
raw_content = node.get("message") or node.get("content") or []
chain: list[Any] = []
if isinstance(raw_content, list):
chain = raw_content
elif isinstance(raw_content, str):
raw_content = raw_content.strip()
if raw_content:
try:
parsed = json.loads(raw_content)
except Exception:
parsed = None
if isinstance(parsed, list):
chain = parsed
else:
chain = [{"type": "text", "data": {"text": raw_content}}]
parsed_segments = _parse_onebot_segments(chain, settings=settings)
node_text = parsed_segments["text"]
node_forward_ids = parsed_segments["forward_ids"]
node_images = parsed_segments["image_refs"]
if node_text:
texts.append(f"{indent}{sender_name}: {node_text}")
if node_forward_ids:
forward_ids.extend(node_forward_ids)
if node_images:
image_refs.extend(node_images)
return (
"\n".join(texts).strip() or None,
normalize_and_dedupe_strings(forward_ids),
normalize_and_dedupe_strings(image_refs),
)
def _parse_onebot_get_msg_payload(
payload: dict[str, Any],
*,
settings: QuotedMessageParserSettings = SETTINGS,
) -> ParsedOneBotPayload:
data = _unwrap_onebot_data(payload)
segments = data.get("message") or data.get("messages")
if isinstance(segments, list):
return _parse_onebot_segments(segments, settings=settings)
text: str | None = None
if isinstance(segments, str) and segments.strip():
text = segments.strip()
else:
raw = data.get("raw_message")
if isinstance(raw, str) and raw.strip():
text = raw.strip()
return _build_parsed_payload(text)
def _parse_onebot_get_forward_payload(
payload: dict[str, Any],
*,
settings: QuotedMessageParserSettings = SETTINGS,
) -> ParsedOneBotPayload:
data = _unwrap_onebot_data(payload)
nodes = (
data.get("messages")
or data.get("message")
or data.get("nodes")
or data.get("nodeList")
)
if not isinstance(nodes, list):
return _build_parsed_payload(None)
text, forward_ids, image_refs = (
_extract_text_forward_ids_and_images_from_forward_nodes(
nodes,
settings=settings,
)
)
return _build_parsed_payload(text, forward_ids, image_refs)
class ReplyChainParser:
def __init__(self, settings: QuotedMessageParserSettings = SETTINGS):
self._settings = settings
@staticmethod
def find_first_reply_component(event: AstrMessageEvent) -> Reply | None:
return _find_first_reply_component(event)
@staticmethod
def is_forward_placeholder_only_text(text: str | None) -> bool:
return _is_forward_placeholder_only_text(text)
def extract_text_from_reply_component(
self,
reply: Reply,
*,
depth: int = 0,
) -> str | None:
return _extract_text_from_reply_component(
reply,
depth=depth,
settings=self._settings,
)
def extract_image_refs_from_reply_component(
self,
reply: Reply,
*,
depth: int = 0,
) -> list[str]:
return _extract_image_refs_from_reply_component(
reply,
depth=depth,
settings=self._settings,
)
class OneBotPayloadParser:
def __init__(self, settings: QuotedMessageParserSettings = SETTINGS):
self._settings = settings
def parse_get_msg_payload(self, payload: dict[str, Any]) -> ParsedOneBotPayload:
return _parse_onebot_get_msg_payload(payload, settings=self._settings)
def parse_get_forward_payload(
self,
payload: dict[str, Any],
) -> ParsedOneBotPayload:
return _parse_onebot_get_forward_payload(payload, settings=self._settings)
@@ -1,211 +0,0 @@
from __future__ import annotations
from dataclasses import dataclass
from astrbot import logger
from astrbot.core.message.components import Reply
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
from .chain_parser import OneBotPayloadParser, ReplyChainParser
from .image_resolver import ImageResolver
from .onebot_client import OneBotClient
from .settings import SETTINGS, QuotedMessageParserSettings
async def _collect_text_and_images_from_forward_ids(
onebot_client: OneBotClient,
payload_parser: OneBotPayloadParser,
forward_ids: list[str],
*,
max_fetch: int,
) -> tuple[list[str], list[str]]:
texts: list[str] = []
image_refs: list[str] = []
pending: list[str] = []
seen: set[str] = set()
for fid in forward_ids:
if not isinstance(fid, str):
continue
cleaned = fid.strip()
if cleaned:
pending.append(cleaned)
fetch_count = 0
while pending and fetch_count < max_fetch:
current_id = pending.pop(0)
if current_id in seen:
continue
seen.add(current_id)
fetch_count += 1
forward_payload = await onebot_client.get_forward_msg(current_id)
if not forward_payload:
continue
parsed = payload_parser.parse_get_forward_payload(forward_payload)
if parsed["text"]:
texts.append(parsed["text"])
if parsed["image_refs"]:
image_refs.extend(parsed["image_refs"])
for nested_id in parsed["forward_ids"]:
if nested_id not in seen:
pending.append(nested_id)
if pending:
logger.warning(
"quoted_message_parser: stop fetching nested forward messages after %d hops",
max_fetch,
)
return texts, normalize_and_dedupe_strings(image_refs)
@dataclass(slots=True)
class QuotedMessageContent:
embedded_text: str | None
embedded_image_refs: list[str]
reply_id: str
direct_text: str | None
direct_image_refs: list[str]
forward_texts: list[str]
forward_image_refs: list[str]
class QuotedMessageExtractor:
def __init__(
self,
event: AstrMessageEvent,
settings: QuotedMessageParserSettings = SETTINGS,
):
self._event = event
self._settings = settings
self._reply_parser = ReplyChainParser(settings=settings)
self._payload_parser = OneBotPayloadParser(settings=settings)
self._client = OneBotClient(event, settings=settings)
self._image_resolver = ImageResolver(event, self._client)
async def _fetch_quoted_content(
self,
reply_component: Reply | None = None,
*,
fetch_remote: bool,
) -> QuotedMessageContent | None:
reply = reply_component or self._reply_parser.find_first_reply_component(
self._event
)
if not reply:
return None
embedded_text = self._reply_parser.extract_text_from_reply_component(reply)
embedded_image_refs = list(
self._reply_parser.extract_image_refs_from_reply_component(reply)
)
reply_id = getattr(reply, "id", None)
reply_id_str = str(reply_id).strip() if reply_id is not None else ""
if not fetch_remote or not reply_id_str:
return QuotedMessageContent(
embedded_text=embedded_text,
embedded_image_refs=embedded_image_refs,
reply_id=reply_id_str,
direct_text=None,
direct_image_refs=[],
forward_texts=[],
forward_image_refs=[],
)
msg_payload = await self._client.get_msg(reply_id_str)
if not msg_payload:
return QuotedMessageContent(
embedded_text=embedded_text,
embedded_image_refs=embedded_image_refs,
reply_id=reply_id_str,
direct_text=None,
direct_image_refs=[],
forward_texts=[],
forward_image_refs=[],
)
parsed = self._payload_parser.parse_get_msg_payload(msg_payload)
forward_texts, forward_images = await _collect_text_and_images_from_forward_ids(
self._client,
self._payload_parser,
parsed["forward_ids"],
max_fetch=self._settings.max_forward_fetch,
)
return QuotedMessageContent(
embedded_text=embedded_text,
embedded_image_refs=embedded_image_refs,
reply_id=reply_id_str,
direct_text=parsed["text"],
direct_image_refs=list(parsed["image_refs"]),
forward_texts=forward_texts,
forward_image_refs=forward_images,
)
async def text(self, reply_component: Reply | None = None) -> str | None:
embedded_content = await self._fetch_quoted_content(
reply_component,
fetch_remote=False,
)
if not embedded_content:
return None
if (
embedded_content.embedded_text
and not self._reply_parser.is_forward_placeholder_only_text(
embedded_content.embedded_text
)
):
return embedded_content.embedded_text
if not embedded_content.reply_id:
return embedded_content.embedded_text
fetched_content = await self._fetch_quoted_content(
reply_component,
fetch_remote=True,
)
if not fetched_content:
return embedded_content.embedded_text
text_parts: list[str] = []
if fetched_content.direct_text:
text_parts.append(fetched_content.direct_text)
text_parts.extend(fetched_content.forward_texts)
return "\n".join(text_parts).strip() or embedded_content.embedded_text
async def images(self, reply_component: Reply | None = None) -> list[str]:
content = await self._fetch_quoted_content(reply_component, fetch_remote=True)
if not content:
return []
image_refs: list[str] = []
image_refs.extend(content.embedded_image_refs)
image_refs.extend(content.direct_image_refs)
image_refs.extend(content.forward_image_refs)
return await self._image_resolver.resolve_for_llm(image_refs)
async def extract_quoted_message_text(
event: AstrMessageEvent,
reply_component: Reply | None = None,
settings: QuotedMessageParserSettings | None = None,
) -> str | None:
return await QuotedMessageExtractor(event, settings=settings or SETTINGS).text(
reply_component
)
async def extract_quoted_message_images(
event: AstrMessageEvent,
reply_component: Reply | None = None,
settings: QuotedMessageParserSettings | None = None,
) -> list[str]:
return await QuotedMessageExtractor(event, settings=settings or SETTINGS).images(
reply_component
)
@@ -1,94 +0,0 @@
from __future__ import annotations
import os
from urllib.parse import urlsplit
IMAGE_EXTENSIONS = {
".jpg",
".jpeg",
".png",
".webp",
".bmp",
".tif",
".tiff",
".gif",
}
def normalize_file_like_url(path: str | None) -> str | None:
if path is None:
return None
if not isinstance(path, str):
return None
if "?" not in path and "#" not in path:
return path
try:
split = urlsplit(path)
except Exception:
return path
return split.path or path
def looks_like_image_file_name(name: str) -> bool:
normalized_name = normalize_file_like_url(name)
if not isinstance(normalized_name, str) or not normalized_name.strip():
return False
_, ext = os.path.splitext(normalized_name.strip().lower())
return ext in IMAGE_EXTENSIONS
def convert_data_image_to_base64_ref(image_ref: str) -> str | None:
if not isinstance(image_ref, str):
return None
value = image_ref.strip()
if not value:
return None
lower_value = value.lower()
if not lower_value.startswith("data:image/"):
return None
comma_index = value.find(",")
if comma_index <= 0:
return None
header = value[:comma_index].lower()
payload = value[comma_index + 1 :].strip()
if ";base64" not in header or not payload:
return None
return f"base64://{payload}"
def get_existing_local_path(value: str) -> str | None:
lower_value = value.lower()
if lower_value.startswith("file://"):
file_path = value[7:]
if file_path.startswith("/") and len(file_path) > 3 and file_path[2] == ":":
file_path = file_path[1:]
if file_path and os.path.exists(file_path):
return os.path.abspath(file_path)
return None
if os.path.exists(value):
return os.path.abspath(value)
return None
def normalize_image_ref(image_ref: str) -> str | None:
if not isinstance(image_ref, str):
return None
value = image_ref.strip()
if not value:
return None
lower_value = value.lower()
if lower_value.startswith(("http://", "https://")):
return value
if lower_value.startswith("base64://"):
return value
data_image_ref = convert_data_image_to_base64_ref(value)
if data_image_ref:
return data_image_ref
local_path = get_existing_local_path(value)
if local_path and looks_like_image_file_name(local_path):
return local_path
return None
@@ -1,130 +0,0 @@
from __future__ import annotations
import os
from typing import Any
from astrbot import logger
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.utils.string_utils import normalize_and_dedupe_strings
from .image_refs import IMAGE_EXTENSIONS, get_existing_local_path, normalize_image_ref
from .onebot_client import OneBotClient
def _build_image_id_candidates(image_ref: str) -> list[str]:
candidates: list[str] = [image_ref]
base_name, ext = os.path.splitext(image_ref)
if ext and base_name and base_name not in candidates:
if ext.lower() in IMAGE_EXTENSIONS:
candidates.append(base_name)
return candidates
def _build_image_resolve_actions(
event: AstrMessageEvent,
image_ref: str,
) -> list[tuple[str, dict[str, Any]]]:
actions: list[tuple[str, dict[str, Any]]] = []
candidates = _build_image_id_candidates(image_ref)
for candidate in candidates:
actions.extend(
[
("get_image", {"file": candidate}),
("get_image", {"file_id": candidate}),
("get_image", {"id": candidate}),
("get_image", {"image": candidate}),
("get_file", {"file_id": candidate}),
("get_file", {"file": candidate}),
]
)
try:
group_id = event.get_group_id()
except Exception:
group_id = None
group_id_value = group_id
if isinstance(group_id, str) and group_id.isdigit():
group_id_value = int(group_id)
if group_id_value:
for candidate in candidates:
actions.append(
(
"get_group_file_url",
{"group_id": group_id_value, "file_id": candidate},
)
)
for candidate in candidates:
actions.append(("get_private_file_url", {"file_id": candidate}))
return actions
class ImageResolver:
def __init__(
self,
event: AstrMessageEvent,
onebot_client: OneBotClient | None = None,
):
self._event = event
self._client = onebot_client or OneBotClient(event)
async def resolve_for_llm(self, image_refs: list[str]) -> list[str]:
resolved: list[str] = []
unresolved: list[str] = []
for image_ref in normalize_and_dedupe_strings(image_refs):
normalized = normalize_image_ref(image_ref)
if normalized:
resolved.append(normalized)
elif get_existing_local_path(image_ref):
# Drop non-image local paths instead of treating them as remote IDs.
logger.debug(
"quoted_message_parser: skip non-image local path ref=%s",
image_ref[:128],
)
else:
unresolved.append(image_ref)
for image_ref in unresolved:
resolved_ref = await self._resolve_one(image_ref)
if resolved_ref:
resolved.append(resolved_ref)
return normalize_and_dedupe_strings(resolved)
async def _resolve_one(self, image_ref: str) -> str | None:
resolved = normalize_image_ref(image_ref)
if resolved:
return resolved
actions = _build_image_resolve_actions(self._event, image_ref)
for action, params in actions:
data = await self._client.call(
action,
params,
warn_on_all_failed=False,
unwrap_data=True,
)
if not isinstance(data, dict):
continue
url = data.get("url")
if isinstance(url, str):
normalized = normalize_image_ref(url)
if normalized:
return normalized
file_value = data.get("file")
if isinstance(file_value, str):
normalized = normalize_image_ref(file_value)
if normalized:
return normalized
logger.warning(
"quoted_message_parser: failed to resolve quoted image ref=%s after %d actions",
image_ref[:128],
len(actions),
)
return None
@@ -1,119 +0,0 @@
from __future__ import annotations
from typing import Any
from astrbot import logger
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from .settings import SETTINGS, QuotedMessageParserSettings
def _unwrap_action_response(ret: dict[str, Any] | None) -> dict[str, Any]:
if not isinstance(ret, dict):
return {}
data = ret.get("data")
if isinstance(data, dict):
return data
return ret
class OneBotClient:
def __init__(
self,
event: AstrMessageEvent,
settings: QuotedMessageParserSettings = SETTINGS,
):
self._call_action = self._resolve_call_action(event)
self._settings = settings
@staticmethod
def _resolve_call_action(event: AstrMessageEvent):
bot = getattr(event, "bot", None)
api = getattr(bot, "api", None)
call_action = getattr(api, "call_action", None)
if not callable(call_action):
return None
return call_action
async def _call_action_try_params(
self,
action: str,
params_list: list[dict[str, Any]],
*,
warn_on_all_failed: bool | None = None,
) -> dict[str, Any] | None:
if self._call_action is None:
return None
if warn_on_all_failed is None:
warn_on_all_failed = self._settings.warn_on_action_failure
last_error: Exception | None = None
last_params: dict[str, Any] | None = None
for params in params_list:
try:
result = await self._call_action(action, **params)
if isinstance(result, dict):
return result
except Exception as exc:
last_error = exc
last_params = params
logger.debug(
"quoted_message_parser: action %s failed with params %s: %s",
action,
{k: str(v)[:64] for k, v in params.items()},
exc,
)
if warn_on_all_failed and last_error is not None:
logger.warning(
"quoted_message_parser: all attempts failed for action %s, "
"last_params=%s, error=%s",
action,
(
{k: str(v)[:64] for k, v in last_params.items()}
if isinstance(last_params, dict)
else None
),
last_error,
)
return None
async def call(
self,
action: str,
params: dict[str, Any],
*,
warn_on_all_failed: bool = False,
unwrap_data: bool = True,
) -> dict[str, Any] | None:
ret = await self._call_action_try_params(
action,
[params],
warn_on_all_failed=warn_on_all_failed,
)
if not unwrap_data:
return ret
return _unwrap_action_response(ret)
async def _call_action_compat(
self,
action: str,
message_id: str | int,
) -> dict[str, Any] | None:
message_id_str = str(message_id).strip()
if not message_id_str:
return None
params_list: list[dict[str, Any]] = [
{"message_id": message_id_str},
{"id": message_id_str},
]
if message_id_str.isdigit():
int_id = int(message_id_str)
params_list.extend([{"message_id": int_id}, {"id": int_id}])
return await self._call_action_try_params(action, params_list)
async def get_msg(self, message_id: str | int) -> dict[str, Any] | None:
return await self._call_action_compat("get_msg", message_id)
async def get_forward_msg(self, forward_id: str | int) -> dict[str, Any] | None:
return await self._call_action_compat("get_forward_msg", forward_id)
@@ -1,85 +0,0 @@
from __future__ import annotations
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any
_DEFAULT_MAX_COMPONENT_CHAIN_DEPTH = 4
_DEFAULT_MAX_FORWARD_NODE_DEPTH = 6
_DEFAULT_MAX_FORWARD_FETCH = 32
def _read_int_mapping(
mapping: Mapping[str, Any],
key: str,
default: int,
) -> int:
raw = mapping.get(key)
if raw is None:
return default
try:
value = int(raw)
except (TypeError, ValueError):
return default
if value <= 0:
return default
return value
def _read_bool_mapping(
mapping: Mapping[str, Any],
key: str,
default: bool,
) -> bool:
raw = mapping.get(key)
if raw is None:
return default
if isinstance(raw, bool):
return raw
if isinstance(raw, str):
lowered = raw.strip().lower()
if lowered in {"1", "true", "yes", "on"}:
return True
if lowered in {"0", "false", "no", "off"}:
return False
return default
@dataclass(frozen=True)
class QuotedMessageParserSettings:
max_component_chain_depth: int = _DEFAULT_MAX_COMPONENT_CHAIN_DEPTH
max_forward_node_depth: int = _DEFAULT_MAX_FORWARD_NODE_DEPTH
max_forward_fetch: int = _DEFAULT_MAX_FORWARD_FETCH
warn_on_action_failure: bool = False
def with_overrides(
self,
overrides: Mapping[str, Any] | None = None,
) -> QuotedMessageParserSettings:
if not overrides:
return self
return QuotedMessageParserSettings(
max_component_chain_depth=_read_int_mapping(
overrides,
"max_component_chain_depth",
self.max_component_chain_depth,
),
max_forward_node_depth=_read_int_mapping(
overrides,
"max_forward_node_depth",
self.max_forward_node_depth,
),
max_forward_fetch=_read_int_mapping(
overrides,
"max_forward_fetch",
self.max_forward_fetch,
),
warn_on_action_failure=_read_bool_mapping(
overrides,
"warn_on_action_failure",
self.warn_on_action_failure,
),
)
SETTINGS = QuotedMessageParserSettings()
@@ -1,11 +0,0 @@
from __future__ import annotations
from astrbot.core.utils.quoted_message.extractor import (
extract_quoted_message_images,
extract_quoted_message_text,
)
__all__ = [
"extract_quoted_message_text",
"extract_quoted_message_images",
]
-10
View File
@@ -1,10 +0,0 @@
import os
import sys
def is_frozen_runtime() -> bool:
return bool(getattr(sys, "frozen", False))
def is_packaged_electron_runtime() -> bool:
return is_frozen_runtime() and os.environ.get("ASTRBOT_ELECTRON_CLIENT") == "1"
-21
View File
@@ -1,21 +0,0 @@
from __future__ import annotations
from collections.abc import Iterable
from typing import Any
def normalize_and_dedupe_strings(items: Iterable[Any] | None) -> list[str]:
if items is None:
return []
normalized: list[str] = []
seen: set[str] = set()
for item in items:
if not isinstance(item, str):
continue
cleaned = item.strip()
if not cleaned or cleaned in seen:
continue
seen.add(cleaned)
normalized.append(cleaned)
return normalized
+6 -6
View File
@@ -1,11 +1,12 @@
import asyncio
import logging
import random
import ssl
import aiohttp
import certifi
from astrbot.core.config import VERSION
from astrbot.core.utils.http_ssl import build_tls_connector
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.utils.t2i.template_manager import TemplateManager
@@ -38,10 +39,7 @@ class NetworkRenderStrategy(RenderStrategy):
async def get_official_endpoints(self) -> None:
"""获取官方的 t2i 端点列表。"""
try:
async with aiohttp.ClientSession(
trust_env=True,
connector=build_tls_connector(),
) as session:
async with aiohttp.ClientSession() as session:
async with session.get(
"https://api.soulter.top/astrbot/t2i-endpoints",
) as resp:
@@ -90,10 +88,12 @@ class NetworkRenderStrategy(RenderStrategy):
for endpoint in endpoints:
try:
if return_url:
ssl_context = ssl.create_default_context(cafile=certifi.where())
connector = aiohttp.TCPConnector(ssl=ssl_context)
async with (
aiohttp.ClientSession(
trust_env=True,
connector=build_tls_connector(),
connector=connector,
) as session,
session.post(
f"{endpoint}/generate",
-150
View File
@@ -1,150 +0,0 @@
import asyncio
from collections.abc import Callable
from dataclasses import dataclass
from pathlib import Path
from astrbot import logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
def parse_size_to_bytes(value: str | int | float | None) -> int:
"""Parse size in MB to bytes."""
if value is None:
return 0
try:
size_mb = float(str(value).strip())
except (TypeError, ValueError):
return 0
if size_mb <= 0:
return 0
return int(size_mb * 1024**2)
@dataclass
class TempFileInfo:
path: Path
size: int
mtime: float
class TempDirCleaner:
CONFIG_KEY = "temp_dir_max_size"
DEFAULT_MAX_SIZE = 1024
CHECK_INTERVAL_SECONDS = 10 * 60
CLEANUP_RATIO = 0.30
def __init__(
self,
max_size_getter: Callable[[], str | int | float | None],
temp_dir: Path | None = None,
) -> None:
self._max_size_getter = max_size_getter
self._temp_dir = temp_dir or Path(get_astrbot_temp_path())
self._stop_event = asyncio.Event()
def _limit_bytes(self) -> int:
configured = self._max_size_getter()
parsed = parse_size_to_bytes(configured)
if parsed <= 0:
fallback = parse_size_to_bytes(self.DEFAULT_MAX_SIZE)
logger.warning(
f"Invalid {self.CONFIG_KEY}={configured!r}, fallback to {self.DEFAULT_MAX_SIZE}MB.",
)
return fallback
return parsed
def _scan_temp_files(self) -> tuple[int, list[TempFileInfo]]:
if not self._temp_dir.exists():
return 0, []
total_size = 0
files: list[TempFileInfo] = []
for path in self._temp_dir.rglob("*"):
if not path.is_file():
continue
try:
stat = path.stat()
except OSError as e:
logger.debug(f"Skip temp file {path} due to stat error: {e}")
continue
total_size += stat.st_size
files.append(
TempFileInfo(path=path, size=stat.st_size, mtime=stat.st_mtime)
)
return total_size, files
def _cleanup_empty_dirs(self) -> None:
if not self._temp_dir.exists():
return
for path in sorted(
self._temp_dir.rglob("*"), key=lambda p: len(p.parts), reverse=True
):
if not path.is_dir():
continue
try:
path.rmdir()
except OSError:
continue
def cleanup_once(self) -> None:
limit = self._limit_bytes()
if limit <= 0:
return
total_size, files = self._scan_temp_files()
if total_size <= limit:
return
target_release = max(int(total_size * self.CLEANUP_RATIO), 1)
released = 0
removed_files = 0
for file_info in sorted(files, key=lambda item: item.mtime):
try:
file_info.path.unlink()
except OSError as e:
logger.warning(f"Failed to delete temp file {file_info.path}: {e}")
continue
released += file_info.size
removed_files += 1
if released >= target_release:
break
self._cleanup_empty_dirs()
logger.warning(
f"Temp dir exceeded limit ({total_size} > {limit}). "
f"Removed {removed_files} files, released {released} bytes "
f"(target {target_release} bytes).",
)
async def run(self) -> None:
logger.info(
f"TempDirCleaner started. interval={self.CHECK_INTERVAL_SECONDS}s "
f"cleanup_ratio={self.CLEANUP_RATIO}",
)
while not self._stop_event.is_set():
try:
# File-system traversal and deletion are blocking operations.
# Run cleanup in a worker thread to avoid blocking the event loop.
await asyncio.to_thread(self.cleanup_once)
except Exception as e:
logger.error(f"TempDirCleaner run failed: {e}", exc_info=True)
try:
await asyncio.wait_for(
self._stop_event.wait(),
timeout=self.CHECK_INTERVAL_SECONDS,
)
except asyncio.TimeoutError:
continue
logger.info("TempDirCleaner stopped.")
async def stop(self) -> None:
self._stop_event.set()
+2 -4
View File
@@ -7,7 +7,7 @@ import wave
from io import BytesIO
from astrbot.core import logger
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
async def tencent_silk_to_wav(silk_path: str, output_path: str) -> str:
@@ -117,13 +117,12 @@ async def audio_to_tencent_silk_base64(audio_path: str) -> tuple[str, float]:
except ImportError as e:
raise Exception("未安装 pilk: pip install pilk") from e
temp_dir = get_astrbot_temp_path()
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(temp_dir, exist_ok=True)
# 是否需要转换为 WAV
ext = os.path.splitext(audio_path)[1].lower()
temp_wav = tempfile.NamedTemporaryFile(
prefix="tencent_record_",
suffix=".wav",
delete=False,
dir=temp_dir,
@@ -141,7 +140,6 @@ async def audio_to_tencent_silk_base64(audio_path: str) -> tuple[str, float]:
rate = wav_file.getframerate()
silk_path = tempfile.NamedTemporaryFile(
prefix="tencent_record_",
suffix=".silk",
delete=False,
dir=temp_dir,
+1 -17
View File
@@ -1,4 +1,3 @@
import os
import uuid
from astrbot.core import astrbot_config, logger
@@ -21,20 +20,6 @@ def _get_dashboard_port() -> int:
return 6185
def _is_dashboard_ssl_enabled() -> bool:
env_ssl = os.environ.get("DASHBOARD_SSL_ENABLE") or os.environ.get(
"ASTRBOT_DASHBOARD_SSL_ENABLE"
)
if env_ssl is not None:
return env_ssl.strip().lower() in {"1", "true", "yes", "on"}
try:
return bool(astrbot_config.get("dashboard", {}).get("ssl", {}).get("enable"))
except Exception as e:
logger.error(f"获取 dashboard SSL 配置失败: {e!s}")
return False
def log_webhook_info(platform_name: str, webhook_uuid: str) -> None:
"""打印美观的 webhook 信息日志
@@ -53,13 +38,12 @@ def log_webhook_info(platform_name: str, webhook_uuid: str) -> None:
callback_base = callback_base.rstrip("/")
webhook_url = f"{callback_base}/api/platform/webhook/{webhook_uuid}"
scheme = "https" if _is_dashboard_ssl_enabled() else "http"
display_log = (
"\n====================\n"
f"🔗 机器人平台 {platform_name} 已启用统一 Webhook 模式\n"
f"📍 Webhook 回调地址: \n"
f"{scheme}://<your-ip>:{_get_dashboard_port()}/api/platform/webhook/{webhook_uuid}\n"
f"http://<your-ip>:{_get_dashboard_port()}/api/platform/webhook/{webhook_uuid}\n"
f"{webhook_url}\n"
"====================\n"
)
+2 -42
View File
@@ -1290,30 +1290,6 @@ class ConfigRoute(Route):
f"Unexpected error registering logo for platform {platform.name}: {e}",
)
def _inject_platform_metadata_with_i18n(
self, platform, metadata, platform_i18n_translations: dict
):
"""将配置元数据注入到 metadata 中并处理国际化键转换。"""
metadata["platform_group"]["metadata"]["platform"].setdefault("items", {})
platform_items_to_inject = copy.deepcopy(platform.config_metadata)
if platform.i18n_resources:
i18n_prefix = f"platform_group.platform.{platform.name}"
for lang, lang_data in platform.i18n_resources.items():
platform_i18n_translations.setdefault(lang, {}).setdefault(
"platform_group", {}
).setdefault("platform", {})[platform.name] = lang_data
for field_key, field_value in platform_items_to_inject.items():
for key in ("description", "hint", "labels"):
if key in field_value:
field_value[key] = f"{i18n_prefix}.{field_key}.{key}"
metadata["platform_group"]["metadata"]["platform"]["items"].update(
platform_items_to_inject
)
async def _get_astrbot_config(self):
config = self.config
metadata = copy.deepcopy(CONFIG_METADATA_2)
@@ -1335,23 +1311,11 @@ class ConfigRoute(Route):
"config_template"
]
# 收集平台的 i18n 翻译数据
platform_i18n_translations = {}
# 收集需要注册logo的平台
logo_registration_tasks = []
for platform in platform_registry:
if platform.default_config_tmpl:
platform_default_tmpl[platform.name] = copy.deepcopy(
platform.default_config_tmpl
)
# 注入配置元数据(在 convert_to_i18n_keys 之后,使用国际化键)
if platform.config_metadata:
self._inject_platform_metadata_with_i18n(
platform, metadata, platform_i18n_translations
)
platform_default_tmpl[platform.name] = platform.default_config_tmpl
# 收集logo注册任务
if platform.logo_path:
logo_registration_tasks.append(
@@ -1370,11 +1334,7 @@ class ConfigRoute(Route):
if provider.default_config_tmpl:
provider_default_tmpl[provider.type] = provider.default_config_tmpl
return {
"metadata": metadata,
"config": config,
"platform_i18n_translations": platform_i18n_translations,
}
return {"metadata": metadata, "config": config}
async def _get_plugin_config(self, plugin_name: str):
ret: dict = {"metadata": None, "config": None}
+1 -5
View File
@@ -12,7 +12,6 @@ from quart import request
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from ..utils import generate_tsne_visualization
from .route import Response, Route, RouteContext
@@ -704,10 +703,7 @@ class KnowledgeBaseRoute(Route):
file_name = file.filename
# 保存到临时文件
temp_file_path = os.path.join(
get_astrbot_temp_path(),
f"kb_upload_{uuid.uuid4()}_{file_name}",
)
temp_file_path = f"data/temp/{uuid.uuid4()}_{file_name}"
await file.save(temp_file_path)
try:
+2 -2
View File
@@ -12,7 +12,7 @@ from quart import websocket
from astrbot import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from .route import Route, RouteContext
@@ -60,7 +60,7 @@ class LiveChatSession:
# 组装 WAV 文件
try:
temp_dir = get_astrbot_temp_path()
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
os.makedirs(temp_dir, exist_ok=True)
audio_path = os.path.join(temp_dir, f"live_audio_{uuid.uuid4()}.wav")
+1 -38
View File
@@ -20,7 +20,6 @@ from astrbot.core.star.filter.permission import PermissionTypeFilter
from astrbot.core.star.filter.regex import RegexFilter
from astrbot.core.star.star_handler import EventType, star_handlers_registry
from astrbot.core.star.star_manager import PluginManager
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
from .route import Response, Route, RouteContext
@@ -54,13 +53,11 @@ class PluginRoute(Route):
"/plugin/market_list": ("GET", self.get_online_plugins),
"/plugin/off": ("POST", self.off_plugin),
"/plugin/on": ("POST", self.on_plugin),
"/plugin/reload-failed": ("POST", self.reload_failed_plugins),
"/plugin/reload": ("POST", self.reload_plugins),
"/plugin/readme": ("GET", self.get_plugin_readme),
"/plugin/changelog": ("GET", self.get_plugin_changelog),
"/plugin/source/get": ("GET", self.get_custom_source),
"/plugin/source/save": ("POST", self.save_custom_source),
"/plugin/source/get-failed-plugins": ("GET", self.get_failed_plugins),
}
self.core_lifecycle = core_lifecycle
self.plugin_manager = plugin_manager
@@ -77,33 +74,6 @@ class PluginRoute(Route):
self._logo_cache = {}
async def reload_failed_plugins(self):
if DEMO_MODE:
return (
Response()
.error("You are not permitted to do this operation in demo mode")
.__dict__
)
try:
data = await request.get_json()
dir_name = data.get("dir_name") # 这里拿的是目录名,不是插件名
if not dir_name:
return Response().error("缺少插件目录名").__dict__
# 调用 star_manager.py 中的函数
# 注意:传入的是目录名
success, err = await self.plugin_manager.reload_failed_plugin(dir_name)
if success:
return Response().ok(None, f"插件 {dir_name} 重载成功。").__dict__
else:
return Response().error(f"重载失败: {err}").__dict__
except Exception as e:
logger.error(f"/api/plugin/reload-failed: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
async def reload_plugins(self):
if DEMO_MODE:
return (
@@ -363,10 +333,6 @@ class PluginRoute(Route):
.__dict__
)
async def get_failed_plugins(self):
"""专门获取加载失败的插件列表(字典格式)"""
return Response().ok(self.plugin_manager.failed_plugin_dict).__dict__
async def get_plugin_handlers_info(self, handler_full_names: list[str]):
"""解析插件行为"""
handlers = []
@@ -465,10 +431,7 @@ class PluginRoute(Route):
file = await request.files
file = file["file"]
logger.info(f"正在安装用户上传的插件 {file.filename}")
file_path = os.path.join(
get_astrbot_temp_path(),
f"plugin_upload_{file.filename}",
)
file_path = f"data/temp/{file.filename}"
await file.save(file_path)
plugin_info = await self.plugin_manager.install_plugin_from_file(file_path)
# self.core_lifecycle.restart()
-39
View File
@@ -4,7 +4,6 @@ import threading
import time
import traceback
from functools import cmp_to_key
from pathlib import Path
import aiohttp
import psutil
@@ -38,7 +37,6 @@ class StatRoute(Route):
"/stat/test-ghproxy-connection": ("POST", self.test_ghproxy_connection),
"/stat/changelog": ("GET", self.get_changelog),
"/stat/changelog/list": ("GET", self.list_changelog_versions),
"/stat/first-notice": ("GET", self.get_first_notice),
}
self.db_helper = db_helper
self.register_routes()
@@ -281,40 +279,3 @@ class StatRoute(Route):
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"Error: {e!s}").__dict__
async def get_first_notice(self):
"""读取项目根目录 FIRST_NOTICE.md 内容。"""
try:
locale = (request.args.get("locale") or "").strip()
if not re.match(r"^[A-Za-z0-9_-]*$", locale):
locale = ""
base_path = Path(get_astrbot_path())
candidates: list[Path] = []
if locale:
candidates.append(base_path / f"FIRST_NOTICE.{locale}.md")
if locale.lower().startswith("zh"):
candidates.append(base_path / "FIRST_NOTICE.md")
candidates.append(base_path / "FIRST_NOTICE.zh-CN.md")
elif locale.lower().startswith("en"):
candidates.append(base_path / "FIRST_NOTICE.en-US.md")
candidates.extend(
[
base_path / "FIRST_NOTICE.md",
base_path / "FIRST_NOTICE.en-US.md",
],
)
for notice_path in candidates:
if not notice_path.is_file():
continue
content = notice_path.read_text(encoding="utf-8")
if content.strip():
return Response().ok({"content": content}).__dict__
return Response().ok({"content": None}).__dict__
except Exception as e:
logger.error(traceback.format_exc())
return Response().error(f"Error: {e!s}").__dict__
+13 -68
View File
@@ -2,7 +2,6 @@ import asyncio
import logging
import os
import socket
from pathlib import Path
from typing import Protocol, cast
import jwt
@@ -37,12 +36,6 @@ class _AddrWithPort(Protocol):
APP: Quart
def _parse_env_bool(value: str | None, default: bool) -> bool:
if value is None:
return default
return value.strip().lower() in {"1", "true", "yes", "on"}
class AstrBotDashboard:
def __init__(
self,
@@ -208,33 +201,19 @@ class AstrBotDashboard:
def run(self):
ip_addr = []
dashboard_config = self.core_lifecycle.astrbot_config.get("dashboard", {})
port = (
os.environ.get("DASHBOARD_PORT")
or os.environ.get("ASTRBOT_DASHBOARD_PORT")
or dashboard_config.get("port", 6185)
)
host = (
os.environ.get("DASHBOARD_HOST")
or os.environ.get("ASTRBOT_DASHBOARD_HOST")
or dashboard_config.get("host", "0.0.0.0")
)
enable = dashboard_config.get("enable", True)
ssl_config = dashboard_config.get("ssl", {})
if not isinstance(ssl_config, dict):
ssl_config = {}
ssl_enable = _parse_env_bool(
os.environ.get("DASHBOARD_SSL_ENABLE")
or os.environ.get("ASTRBOT_DASHBOARD_SSL_ENABLE"),
bool(ssl_config.get("enable", False)),
)
scheme = "https" if ssl_enable else "http"
if p := os.environ.get("DASHBOARD_PORT"):
port = p
else:
port = self.core_lifecycle.astrbot_config["dashboard"].get("port", 6185)
host = self.core_lifecycle.astrbot_config["dashboard"].get("host", "0.0.0.0")
enable = self.core_lifecycle.astrbot_config["dashboard"].get("enable", True)
if not enable:
logger.info("WebUI 已被禁用")
return None
logger.info(f"正在启动 WebUI, 监听地址: {scheme}://{host}:{port}")
logger.info(f"正在启动 WebUI, 监听地址: http://{host}:{port}")
if host == "0.0.0.0":
logger.info(
"提示: WebUI 将监听所有网络接口,请注意安全。(可在 data/cmd_config.json 中配置 dashboard.host 以修改 host",
@@ -262,9 +241,9 @@ class AstrBotDashboard:
raise Exception(f"端口 {port} 已被占用")
parts = [f"\n ✨✨✨\n AstrBot v{VERSION} WebUI 已启动,可访问\n\n"]
parts.append(f" ➜ 本地: {scheme}://localhost:{port}\n")
parts.append(f" ➜ 本地: http://localhost:{port}\n")
for ip in ip_addr:
parts.append(f" ➜ 网络: {scheme}://{ip}:{port}\n")
parts.append(f" ➜ 网络: http://{ip}:{port}\n")
parts.append(" ➜ 默认用户名和密码: astrbot\n ✨✨✨\n")
display = "".join(parts)
@@ -278,45 +257,11 @@ class AstrBotDashboard:
# 配置 Hypercorn
config = HyperConfig()
config.bind = [f"{host}:{port}"]
if ssl_enable:
cert_file = (
os.environ.get("DASHBOARD_SSL_CERT")
or os.environ.get("ASTRBOT_DASHBOARD_SSL_CERT")
or ssl_config.get("cert_file", "")
)
key_file = (
os.environ.get("DASHBOARD_SSL_KEY")
or os.environ.get("ASTRBOT_DASHBOARD_SSL_KEY")
or ssl_config.get("key_file", "")
)
ca_certs = (
os.environ.get("DASHBOARD_SSL_CA_CERTS")
or os.environ.get("ASTRBOT_DASHBOARD_SSL_CA_CERTS")
or ssl_config.get("ca_certs", "")
)
cert_path = Path(cert_file).expanduser()
key_path = Path(key_file).expanduser()
if not cert_file or not key_file:
raise ValueError(
"dashboard.ssl.enable 为 true 时,必须配置 cert_file 和 key_file。",
)
if not cert_path.is_file():
raise ValueError(f"SSL 证书文件不存在: {cert_path}")
if not key_path.is_file():
raise ValueError(f"SSL 私钥文件不存在: {key_path}")
config.certfile = str(cert_path.resolve())
config.keyfile = str(key_path.resolve())
if ca_certs:
ca_path = Path(ca_certs).expanduser()
if not ca_path.is_file():
raise ValueError(f"SSL CA 证书文件不存在: {ca_path}")
config.ca_certs = str(ca_path.resolve())
# 根据配置决定是否禁用访问日志
disable_access_log = dashboard_config.get("disable_access_log", True)
disable_access_log = self.core_lifecycle.astrbot_config.get(
"dashboard", {}
).get("disable_access_log", True)
if disable_access_log:
config.accesslog = None
else:

Some files were not shown because too many files have changed in this diff Show More