Compare commits
79 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 8f4a31cf8c | |||
| 23549f13d6 | |||
| 869d11f9a6 | |||
| 02e73b82ee | |||
| f85f87f545 | |||
| 1fff5713f3 | |||
| 8453ec36f0 | |||
| d5b3ce8424 | |||
| 80cbbfa5ca | |||
| 9177bb660f | |||
| a3df39a01a | |||
| 25dce05cbb | |||
| 1542ea3e03 | |||
| 6084abbcfe | |||
| ed19b63914 | |||
| 4efeb85296 | |||
| fc76665615 | |||
| 3a044bb71a | |||
| cddd606562 | |||
| 7a5bc51c11 | |||
| 9f939b4b6f | |||
| 80a86f5b1b | |||
| a0ce1855ab | |||
| a4b43b884a | |||
| 824c0f6667 | |||
| a030fe8491 | |||
| 3a9429e8ef | |||
| c4eb1ab748 | |||
| 29ed19d600 | |||
| 0cc65513a5 | |||
| debc048659 | |||
| 92f5c918dd | |||
| 9519f1e8e2 | |||
| a8f874bf05 | |||
| 9d9917e45b | |||
| 91ee0a870d | |||
| 6cbbffc5a9 | |||
| 8f26fd34d1 | |||
| fda655f6d7 | |||
| a663d6509b | |||
| 9ec8839efa | |||
| a7a0350eb2 | |||
| 39a7a0d960 | |||
| 7740e1e131 | |||
| 9dce1ed47e | |||
| e84a00d3a5 | |||
| 88a944cb57 | |||
| 20c32e72cc | |||
| 4788c20816 | |||
| e83fc570a4 | |||
| e841b6af88 | |||
| ea6f209557 | |||
| 9bfa726107 | |||
| d24902c66d | |||
| 72aea2d3f3 | |||
| dc9612d564 | |||
| 1770556d56 | |||
| 888fb84aee | |||
| d597fd056d | |||
| dea0ab3974 | |||
| da6facd7d7 | |||
| bb8ab5f173 | |||
| ac8a541059 | |||
| 0e66771f0e | |||
| d3a295a801 | |||
| f2df771771 | |||
| 7b72cd87a5 | |||
| 9431efc6d1 | |||
| 7c3f5431ba | |||
| d98cf16a4c | |||
| 2c3c3ae546 | |||
| 905eef48e3 | |||
| b31b520c7c | |||
| 17aee086a3 | |||
| c1756e5767 | |||
| 2920279c64 | |||
| 1f0f985b01 | |||
| 0762c81633 | |||
| 28ef301ccc |
@@ -1,19 +1,46 @@
|
|||||||
<!-- 如果有的话,指定这个 PR 要解决的 ISSUE -->
|
<!-- 如果有的话,请指定此 PR 旨在解决的 ISSUE 编号。 -->
|
||||||
解决了 #XYZ
|
<!-- If applicable, please specify the ISSUE number this PR aims to resolve. -->
|
||||||
|
|
||||||
### Motivation
|
fixes #XYZ
|
||||||
|
|
||||||
<!--解释为什么要改动-->
|
---
|
||||||
|
|
||||||
### Modifications
|
### Motivation / 动机
|
||||||
|
|
||||||
<!--简单解释你的改动-->
|
<!--请描述此项更改的动机:它解决了什么问题?(例如:修复了 XX 错误,添加了 YY 功能)-->
|
||||||
|
<!--Please describe the motivation for this change: What problem does it solve? (e.g., Fixes XX bug, adds YY feature)-->
|
||||||
|
|
||||||
### Check
|
### Modifications / 改动点
|
||||||
|
|
||||||
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容-->
|
<!--请总结你的改动:哪些核心文件被修改了?实现了什么功能?-->
|
||||||
|
<!--Please summarize your changes: What core files were modified? What functionality was implemented?-->
|
||||||
|
|
||||||
- [ ] 😊 我的 Commit Message 符合良好的[规范](https://www.conventionalcommits.org/en/v1.0.0/#summary)
|
### Verification Steps / 验证步骤
|
||||||
- [ ] 👀 我的更改经过良好的测试
|
|
||||||
- [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到了 `requirements.txt` 和 `pyproject.toml` 文件相应位置。
|
<!--请为审查者 (Reviewer) 提供清晰、可复现的验证步骤(例如:1. 导航到... 2. 点击...)。-->
|
||||||
- [ ] 😮 我的更改没有引入恶意代码
|
<!--Please provide clear and reproducible verification steps for the Reviewer (e.g., 1. Navigate to... 2. Click...).-->
|
||||||
|
|
||||||
|
### Screenshots or Test Results / 运行截图或测试结果
|
||||||
|
|
||||||
|
<!--请粘贴截图、GIF 或测试日志,作为执行“验证步骤”的证据,证明此改动有效。-->
|
||||||
|
<!--Please paste screenshots, GIFs, or test logs here as evidence of executing the "Verification Steps" to prove this change is effective.-->
|
||||||
|
|
||||||
|
### Compatibility & Breaking Changes / 兼容性与破坏性变更
|
||||||
|
|
||||||
|
<!--请说明此变更的兼容性:哪些是破坏性变更?哪些地方做了向后兼容处理?是否提供了数据迁移方法?-->
|
||||||
|
<!--Please explain the compatibility of this change: What are the breaking changes? What backward-compatible measures were taken? Are data migration paths provided?-->
|
||||||
|
|
||||||
|
- [ ] 这是一个破坏性变更 (Breaking Change)。/ This is a breaking change.
|
||||||
|
- [ ] 这不是一个破坏性变更。/ This is NOT a breaking change.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### Checklist / 检查清单
|
||||||
|
|
||||||
|
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容。-->
|
||||||
|
<!--If merged, your code will serve tens of thousands of users! Please double-check the following items before submitting.-->
|
||||||
|
|
||||||
|
- [ ] 😊 如果 PR 中有新加入的功能,已经通过 Issue / 邮件等方式和作者讨论过。/ If there are new features added in the PR, I have discussed it with the authors through issues/emails, etc.
|
||||||
|
- [ ] 👀 我的更改经过了良好的测试,**并已在上方提供了“验证步骤”和“运行截图”**。/ My changes have been well-tested, **and "Verification Steps" and "Screenshots" have been provided above**.
|
||||||
|
- [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到了 `requirements.txt` 和 `pyproject.toml` 文件相应位置。/ I have ensured that no new dependencies are introduced, OR if new dependencies are introduced, they have been added to the appropriate locations in `requirements.txt` and `pyproject.toml`.
|
||||||
|
- [ ] 😮 我的更改没有引入恶意代码。/ My changes do not introduce malicious code.
|
||||||
|
|||||||
@@ -0,0 +1,36 @@
|
|||||||
|
# Set to true to add reviewers to pull requests
|
||||||
|
addReviewers: true
|
||||||
|
|
||||||
|
# Set to true to add assignees to pull requests
|
||||||
|
addAssignees: false
|
||||||
|
|
||||||
|
# A list of reviewers to be added to pull requests (GitHub user name)
|
||||||
|
reviewers:
|
||||||
|
- Soulter
|
||||||
|
- Raven95676
|
||||||
|
- Larch-C
|
||||||
|
- anka-afk
|
||||||
|
- advent259141
|
||||||
|
# - zouyonghe
|
||||||
|
|
||||||
|
# A number of reviewers added to the pull request
|
||||||
|
# Set 0 to add all the reviewers (default: 0)
|
||||||
|
numberOfReviewers: 2
|
||||||
|
|
||||||
|
# A list of assignees, overrides reviewers if set
|
||||||
|
# assignees:
|
||||||
|
# - assigneeA
|
||||||
|
|
||||||
|
# A number of assignees to add to the pull request
|
||||||
|
# Set to 0 to add all of the assignees.
|
||||||
|
# Uses numberOfReviewers if unset.
|
||||||
|
# numberOfAssignees: 2
|
||||||
|
|
||||||
|
# A list of keywords to be skipped the process that add reviewers if pull requests include it
|
||||||
|
skipKeywords:
|
||||||
|
- wip
|
||||||
|
- draft
|
||||||
|
|
||||||
|
# A list of users to be skipped by both the add reviewers and add assignees processes
|
||||||
|
# skipUsers:
|
||||||
|
# - dependabot[bot]
|
||||||
@@ -73,7 +73,7 @@ jobs:
|
|||||||
uses: actions/checkout@v5
|
uses: actions/checkout@v5
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v6
|
||||||
with:
|
with:
|
||||||
python-version: '3.10'
|
python-version: '3.10'
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,34 @@
|
|||||||
|
name: Code Format Check
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
branches: [ master ]
|
||||||
|
push:
|
||||||
|
branches: [ master ]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
format-check:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v5
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@v6
|
||||||
|
with:
|
||||||
|
python-version: '3.10'
|
||||||
|
|
||||||
|
- name: Install UV
|
||||||
|
run: pip install uv
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: uv sync
|
||||||
|
|
||||||
|
- name: Check code formatting with ruff
|
||||||
|
run: |
|
||||||
|
uv run ruff format --check .
|
||||||
|
|
||||||
|
- name: Check code style with ruff
|
||||||
|
run: |
|
||||||
|
uv run ruff check .
|
||||||
@@ -22,7 +22,7 @@ jobs:
|
|||||||
fetch-depth: 0
|
fetch-depth: 0
|
||||||
|
|
||||||
- name: Set up Python
|
- name: Set up Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v6
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -37,6 +37,7 @@ jobs:
|
|||||||
!dist/**/*.md
|
!dist/**/*.md
|
||||||
|
|
||||||
- name: Create GitHub Release
|
- name: Create GitHub Release
|
||||||
|
if: github.event_name == 'push'
|
||||||
uses: ncipollo/release-action@v1
|
uses: ncipollo/release-action@v1
|
||||||
with:
|
with:
|
||||||
tag: release-${{ github.sha }}
|
tag: release-${{ github.sha }}
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ jobs:
|
|||||||
pull-requests: write
|
pull-requests: write
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/stale@v9
|
- uses: actions/stale@v10
|
||||||
with:
|
with:
|
||||||
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
repo-token: ${{ secrets.GITHUB_TOKEN }}
|
||||||
stale-issue-message: 'Stale issue message'
|
stale-issue-message: 'Stale issue message'
|
||||||
|
|||||||
@@ -6,8 +6,6 @@
|
|||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
|
|
||||||
_✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
|
||||||
|
|
||||||
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
<a href="https://trendshift.io/repositories/12875" target="_blank"><img src="https://trendshift.io/api/badge/repositories/12875" alt="Soulter%2FAstrBot | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||||
|
|
||||||
[](https://github.com/Soulter/AstrBot/releases/latest)
|
[](https://github.com/Soulter/AstrBot/releases/latest)
|
||||||
@@ -16,18 +14,18 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
|||||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||||

|
|
||||||

|

|
||||||
|
|
||||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
||||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a> |
|
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||||
<a href="https://astrbot.app/">查看文档</a> |
|
<a href="https://astrbot.app/">文档</a> |
|
||||||
|
<a href="https://blog.astrbot.app/">Blog</a> |
|
||||||
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
|
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
AstrBot 是一个开源的一站式 Agentic 聊天机器人平台及开发框架。
|
AstrBot 是一个开源的一站式 Agentic 聊天机器人平台及开发框架。
|
||||||
|
|
||||||
## ✨ 主要功能
|
## 主要功能
|
||||||
|
|
||||||
1. **大模型对话**。支持接入多种大模型服务。支持多模态、工具调用、MCP、原生知识库、人设等功能。
|
1. **大模型对话**。支持接入多种大模型服务。支持多模态、工具调用、MCP、原生知识库、人设等功能。
|
||||||
2. **多消息平台支持**。支持接入 QQ、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK 等平台。支持速率限制、白名单、百度内容审核。
|
2. **多消息平台支持**。支持接入 QQ、企业微信、微信公众号、飞书、Telegram、钉钉、Discord、KOOK 等平台。支持速率限制、白名单、百度内容审核。
|
||||||
@@ -35,7 +33,7 @@ AstrBot 是一个开源的一站式 Agentic 聊天机器人平台及开发框架
|
|||||||
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富。
|
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,社区插件生态丰富。
|
||||||
5. **WebUI**。可视化配置和管理机器人,功能齐全。
|
5. **WebUI**。可视化配置和管理机器人,功能齐全。
|
||||||
|
|
||||||
## ✨ 使用方式
|
## 部署方式
|
||||||
|
|
||||||
#### Docker 部署
|
#### Docker 部署
|
||||||
|
|
||||||
@@ -79,9 +77,7 @@ AstrBot 已由雨云官方上架至云应用平台,可一键部署。
|
|||||||
|
|
||||||
#### 手动部署
|
#### 手动部署
|
||||||
|
|
||||||
> 推荐使用 `uv`。
|
首先安装 uv:
|
||||||
|
|
||||||
首先,安装 uv:
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip install uv
|
pip install uv
|
||||||
@@ -96,6 +92,25 @@ uv run main.py
|
|||||||
|
|
||||||
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
||||||
|
|
||||||
|
## 🌍 社区
|
||||||
|
|
||||||
|
### QQ 群组
|
||||||
|
|
||||||
|
- 1 群:322154837
|
||||||
|
- 3 群:630166526
|
||||||
|
- 5 群:822130018
|
||||||
|
- 6 群:753075035
|
||||||
|
- 开发者群:975206796
|
||||||
|
- 开发者群(备份):295657329
|
||||||
|
|
||||||
|
### Telegram 群组
|
||||||
|
|
||||||
|
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||||
|
|
||||||
|
### Discord 群组
|
||||||
|
|
||||||
|
<a href="https://discord.gg/hAVk6tgV36"><img alt="Discord_community" src="https://img.shields.io/badge/Discord-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||||
|
|
||||||
## ⚡ 消息平台支持情况
|
## ⚡ 消息平台支持情况
|
||||||
|
|
||||||
| 平台 | 支持性 |
|
| 平台 | 支持性 |
|
||||||
@@ -112,22 +127,20 @@ uv run main.py
|
|||||||
| Discord | ✔ |
|
| Discord | ✔ |
|
||||||
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
|
| [KOOK](https://github.com/wuyan1003/astrbot_plugin_kook_adapter) | ✔ |
|
||||||
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
|
| [VoceChat](https://github.com/HikariFroya/astrbot_plugin_vocechat) | ✔ |
|
||||||
| 微信对话开放平台 | 🚧 |
|
| Satori | ✔ |
|
||||||
| WhatsApp | 🚧 |
|
| Misskey | ✔ |
|
||||||
| 小爱音响 | 🚧 |
|
|
||||||
|
|
||||||
## ⚡ 提供商支持情况
|
## ⚡ 提供商支持情况
|
||||||
|
|
||||||
| 名称 | 支持性 | 类型 | 备注 |
|
| 名称 | 支持性 | 类型 | 备注 |
|
||||||
| -------- | ------- | ------- | ------- |
|
| -------- | ------- | ------- | ------- |
|
||||||
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Gemini、Kimi、xAI 等兼容 OpenAI API 的服务 |
|
| OpenAI | ✔ | 文本生成 | 支持任何兼容 OpenAI API 的服务 |
|
||||||
| Claude API | ✔ | 文本生成 | |
|
| Anthropic | ✔ | 文本生成 | |
|
||||||
| Google Gemini API | ✔ | 文本生成 | |
|
| Google Gemini | ✔ | 文本生成 | |
|
||||||
| Dify | ✔ | LLMOps | |
|
| Dify | ✔ | LLMOps | |
|
||||||
| 阿里云百炼应用 | ✔ | LLMOps | |
|
| 阿里云百炼应用 | ✔ | LLMOps | |
|
||||||
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
| Ollama | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||||
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
| LM Studio | ✔ | 模型加载器 | 本地部署 DeepSeek、Llama 等开源语言模型 |
|
||||||
| LLMTuner | ✔ | 模型加载器 | 本地加载 lora 等微调模型 |
|
|
||||||
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | 模型 API 及算力服务平台 | |
|
| [优云智算](https://www.compshare.cn/?ytag=GPU_YY-gh_astrbot&referral_code=FV7DcGowN4hB5UuXKgpE74) | ✔ | 模型 API 及算力服务平台 | |
|
||||||
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | 模型 API 服务平台 | |
|
| [302.AI](https://share.302.ai/rr1M3l) | ✔ | 模型 API 服务平台 | |
|
||||||
| 硅基流动 | ✔ | 模型 API 服务平台 | |
|
| 硅基流动 | ✔ | 模型 API 服务平台 | |
|
||||||
@@ -143,7 +156,6 @@ uv run main.py
|
|||||||
| 阿里云百炼 TTS | ✔ | 文本转语音 | |
|
| 阿里云百炼 TTS | ✔ | 文本转语音 | |
|
||||||
| Azure TTS | ✔ | 文本转语音 | Microsoft Azure TTS |
|
| Azure TTS | ✔ | 文本转语音 | Microsoft Azure TTS |
|
||||||
|
|
||||||
|
|
||||||
## ❤️ 贡献
|
## ❤️ 贡献
|
||||||
|
|
||||||
欢迎任何 Issues/Pull Requests!只需要将你的更改提交到此项目 :)
|
欢迎任何 Issues/Pull Requests!只需要将你的更改提交到此项目 :)
|
||||||
@@ -162,39 +174,6 @@ pip install pre-commit
|
|||||||
pre-commit install
|
pre-commit install
|
||||||
```
|
```
|
||||||
|
|
||||||
## 🌟 支持
|
|
||||||
|
|
||||||
- Star 这个项目!
|
|
||||||
- 在[爱发电](https://afdian.com/a/soulter)支持我!
|
|
||||||
|
|
||||||
## ✨ Demo
|
|
||||||
|
|
||||||
<details><summary>👉 点击展开多张 Demo 截图 👈</summary>
|
|
||||||
|
|
||||||
<div align='center'>
|
|
||||||
|
|
||||||
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
|
|
||||||
|
|
||||||
_✨基于 Docker 的沙箱化代码执行器(Beta 测试)✨_
|
|
||||||
|
|
||||||
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
|
|
||||||
|
|
||||||
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
|
|
||||||
|
|
||||||
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
|
|
||||||
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
|
|
||||||
|
|
||||||
_✨ 插件系统——部分插件展示 ✨_
|
|
||||||
|
|
||||||
<img src="https://github.com/user-attachments/assets/0cdbf564-2f59-4da5-b524-ce0e7ef3d978" width=600>
|
|
||||||
|
|
||||||
_✨ WebUI ✨_
|
|
||||||
|
|
||||||
</div>
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
|
|
||||||
## ❤️ Special Thanks
|
## ❤️ Special Thanks
|
||||||
|
|
||||||
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
|
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
|
||||||
@@ -203,10 +182,18 @@ _✨ WebUI ✨_
|
|||||||
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
|
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
|
||||||
</a>
|
</a>
|
||||||
|
|
||||||
此外,本项目的诞生离不开以下开源项目:
|
此外,本项目的诞生离不开以下开源项目的帮助:
|
||||||
|
|
||||||
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 伟大的猫猫框架
|
- [NapNeko/NapCatQQ](https://github.com/NapNeko/NapCatQQ) - 伟大的猫猫框架
|
||||||
- [wechatpy/wechatpy](https://github.com/wechatpy/wechatpy)
|
|
||||||
|
另外,一些同类型其他的活跃开源 Bot 项目:
|
||||||
|
|
||||||
|
- [nonebot/nonebot2](https://github.com/nonebot/nonebot2) - 扩展性极强的 Bot 框架
|
||||||
|
- [koishijs/koishi](https://github.com/koishijs/koishi) - 扩展性极强的 Bot 框架
|
||||||
|
- [MaiM-with-u/MaiBot](https://github.com/MaiM-with-u/MaiBot) - 注重拟人功能的 ChatBot
|
||||||
|
- [langbot-app/LangBot](https://github.com/langbot-app/LangBot) - 功能丰富的 Bot 平台
|
||||||
|
- [LroMiose/nekro-agent](https://github.com/KroMiose/nekro-agent) - 注重 Agent 的 ChatBot
|
||||||
|
- [zhenxun-org/zhenxun_bot](https://github.com/zhenxun-org/zhenxun_bot) - 功能完善的 ChatBot
|
||||||
|
|
||||||
## ⭐ Star History
|
## ⭐ Star History
|
||||||
|
|
||||||
@@ -219,8 +206,6 @@ _✨ WebUI ✨_
|
|||||||
|
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||

|
</details>
|
||||||
|
|
||||||
|
|
||||||
_私は、高性能ですから!_
|
_私は、高性能ですから!_
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from astrbot.core.star.register import (
|
|||||||
register_permission_type as permission_type,
|
register_permission_type as permission_type,
|
||||||
register_custom_filter as custom_filter,
|
register_custom_filter as custom_filter,
|
||||||
register_on_astrbot_loaded as on_astrbot_loaded,
|
register_on_astrbot_loaded as on_astrbot_loaded,
|
||||||
|
register_on_platform_loaded as on_platform_loaded,
|
||||||
register_on_llm_request as on_llm_request,
|
register_on_llm_request as on_llm_request,
|
||||||
register_on_llm_response as on_llm_response,
|
register_on_llm_response as on_llm_response,
|
||||||
register_llm_tool as llm_tool,
|
register_llm_tool as llm_tool,
|
||||||
@@ -41,6 +42,7 @@ __all__ = [
|
|||||||
"custom_filter",
|
"custom_filter",
|
||||||
"PermissionType",
|
"PermissionType",
|
||||||
"on_astrbot_loaded",
|
"on_astrbot_loaded",
|
||||||
|
"on_platform_loaded",
|
||||||
"on_llm_request",
|
"on_llm_request",
|
||||||
"llm_tool",
|
"llm_tool",
|
||||||
"on_decorating_result",
|
"on_decorating_result",
|
||||||
|
|||||||
+22
-18
@@ -124,15 +124,17 @@ def build_plug_list(plugins_dir: Path) -> list:
|
|||||||
if metadata and all(
|
if metadata and all(
|
||||||
k in metadata for k in ["name", "desc", "version", "author", "repo"]
|
k in metadata for k in ["name", "desc", "version", "author", "repo"]
|
||||||
):
|
):
|
||||||
result.append({
|
result.append(
|
||||||
"name": str(metadata.get("name", "")),
|
{
|
||||||
"desc": str(metadata.get("desc", "")),
|
"name": str(metadata.get("name", "")),
|
||||||
"version": str(metadata.get("version", "")),
|
"desc": str(metadata.get("desc", "")),
|
||||||
"author": str(metadata.get("author", "")),
|
"version": str(metadata.get("version", "")),
|
||||||
"repo": str(metadata.get("repo", "")),
|
"author": str(metadata.get("author", "")),
|
||||||
"status": PluginStatus.INSTALLED,
|
"repo": str(metadata.get("repo", "")),
|
||||||
"local_path": str(plugin_dir),
|
"status": PluginStatus.INSTALLED,
|
||||||
})
|
"local_path": str(plugin_dir),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# 获取在线插件列表
|
# 获取在线插件列表
|
||||||
online_plugins = []
|
online_plugins = []
|
||||||
@@ -142,15 +144,17 @@ def build_plug_list(plugins_dir: Path) -> list:
|
|||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
for plugin_id, plugin_info in data.items():
|
for plugin_id, plugin_info in data.items():
|
||||||
online_plugins.append({
|
online_plugins.append(
|
||||||
"name": str(plugin_id),
|
{
|
||||||
"desc": str(plugin_info.get("desc", "")),
|
"name": str(plugin_id),
|
||||||
"version": str(plugin_info.get("version", "")),
|
"desc": str(plugin_info.get("desc", "")),
|
||||||
"author": str(plugin_info.get("author", "")),
|
"version": str(plugin_info.get("version", "")),
|
||||||
"repo": str(plugin_info.get("repo", "")),
|
"author": str(plugin_info.get("author", "")),
|
||||||
"status": PluginStatus.NOT_INSTALLED,
|
"repo": str(plugin_info.get("repo", "")),
|
||||||
"local_path": None,
|
"status": PluginStatus.NOT_INSTALLED,
|
||||||
})
|
"local_path": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
click.echo(f"获取在线插件列表失败: {e}", err=True)
|
click.echo(f"获取在线插件列表失败: {e}", err=True)
|
||||||
|
|
||||||
|
|||||||
@@ -9,5 +9,5 @@ from .hooks import BaseAgentRunHooks
|
|||||||
class Agent(Generic[TContext]):
|
class Agent(Generic[TContext]):
|
||||||
name: str
|
name: str
|
||||||
instructions: str | None = None
|
instructions: str | None = None
|
||||||
tools: list[str, FunctionTool] | None = None
|
tools: list[str | FunctionTool] | None = None
|
||||||
run_hooks: BaseAgentRunHooks[TContext] | None = None
|
run_hooks: BaseAgentRunHooks[TContext] | None = None
|
||||||
|
|||||||
@@ -92,7 +92,7 @@ class MCPClient:
|
|||||||
self.session: Optional[mcp.ClientSession] = None
|
self.session: Optional[mcp.ClientSession] = None
|
||||||
self.exit_stack = AsyncExitStack()
|
self.exit_stack = AsyncExitStack()
|
||||||
|
|
||||||
self.name = None
|
self.name: str | None = None
|
||||||
self.active: bool = True
|
self.active: bool = True
|
||||||
self.tools: list[mcp.Tool] = []
|
self.tools: list[mcp.Tool] = []
|
||||||
self.server_errlogs: list[str] = []
|
self.server_errlogs: list[str] = []
|
||||||
@@ -198,6 +198,8 @@ class MCPClient:
|
|||||||
|
|
||||||
async def list_tools_and_save(self) -> mcp.ListToolsResult:
|
async def list_tools_and_save(self) -> mcp.ListToolsResult:
|
||||||
"""List all tools from the server and save them to self.tools"""
|
"""List all tools from the server and save them to self.tools"""
|
||||||
|
if not self.session:
|
||||||
|
raise Exception("MCP Client is not initialized")
|
||||||
response = await self.session.list_tools()
|
response = await self.session.list_tools()
|
||||||
self.tools = response.tools
|
self.tools = response.tools
|
||||||
return response
|
return response
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from dataclasses import dataclass
|
|||||||
import typing as T
|
import typing as T
|
||||||
from astrbot.core.message.message_event_result import MessageChain
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
|
|
||||||
|
|
||||||
class AgentResponseData(T.TypedDict):
|
class AgentResponseData(T.TypedDict):
|
||||||
chain: MessageChain
|
chain: MessageChain
|
||||||
|
|
||||||
|
|||||||
@@ -14,4 +14,5 @@ class ContextWrapper(Generic[TContext]):
|
|||||||
context: TContext
|
context: TContext
|
||||||
event: AstrMessageEvent
|
event: AstrMessageEvent
|
||||||
|
|
||||||
|
|
||||||
NoContext = ContextWrapper[None]
|
NoContext = ContextWrapper[None]
|
||||||
|
|||||||
@@ -258,7 +258,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
)
|
)
|
||||||
yield MessageChain(
|
yield MessageChain(
|
||||||
type="tool_direct_result"
|
type="tool_direct_result"
|
||||||
).base64_image(res.content[0].data)
|
).base64_image(resource.blob)
|
||||||
else:
|
else:
|
||||||
tool_call_result_blocks.append(
|
tool_call_result_blocks.append(
|
||||||
ToolCallMessageSegment(
|
ToolCallMessageSegment(
|
||||||
@@ -269,17 +269,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
)
|
)
|
||||||
yield MessageChain().message("返回的数据类型不受支持。")
|
yield MessageChain().message("返回的数据类型不受支持。")
|
||||||
|
|
||||||
try:
|
|
||||||
await self.agent_hooks.on_tool_end(
|
|
||||||
self.run_context,
|
|
||||||
func_tool_name,
|
|
||||||
func_tool_args,
|
|
||||||
resp,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error in on_tool_end hook: {e}", exc_info=True
|
|
||||||
)
|
|
||||||
elif resp is None:
|
elif resp is None:
|
||||||
# Tool 直接请求发送消息给用户
|
# Tool 直接请求发送消息给用户
|
||||||
# 这里我们将直接结束 Agent Loop。
|
# 这里我们将直接结束 Agent Loop。
|
||||||
@@ -289,27 +278,17 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
yield MessageChain(
|
yield MessageChain(
|
||||||
chain=res.chain, type="tool_direct_result"
|
chain=res.chain, type="tool_direct_result"
|
||||||
)
|
)
|
||||||
try:
|
|
||||||
await self.agent_hooks.on_tool_end(
|
|
||||||
self.run_context, func_tool_name, func_tool_args, None
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(
|
|
||||||
f"Error in on_tool_end hook: {e}", exc_info=True
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
|
f"Tool 返回了不支持的类型: {type(resp)},将忽略。"
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.agent_hooks.on_tool_end(
|
await self.agent_hooks.on_tool_end(
|
||||||
self.run_context, func_tool_name, func_tool_args, None
|
self.run_context, func_tool, func_tool_args, None
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(f"Error in on_tool_end hook: {e}", exc_info=True)
|
||||||
f"Error in on_tool_end hook: {e}", exc_info=True
|
|
||||||
)
|
|
||||||
|
|
||||||
self.run_context.event.clear_result()
|
self.run_context.event.clear_result()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
+28
-17
@@ -1,6 +1,6 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from deprecated import deprecated
|
from deprecated import deprecated
|
||||||
from typing import Awaitable, Literal, Any, Optional
|
from typing import Awaitable, Callable, Literal, Any, Optional
|
||||||
from .mcp_client import MCPClient
|
from .mcp_client import MCPClient
|
||||||
|
|
||||||
|
|
||||||
@@ -8,10 +8,10 @@ from .mcp_client import MCPClient
|
|||||||
class FunctionTool:
|
class FunctionTool:
|
||||||
"""A class representing a function tool that can be used in function calling."""
|
"""A class representing a function tool that can be used in function calling."""
|
||||||
|
|
||||||
name: str | None = None
|
name: str
|
||||||
parameters: dict | None = None
|
parameters: dict | None = None
|
||||||
description: str | None = None
|
description: str | None = None
|
||||||
handler: Awaitable | None = None
|
handler: Callable[..., Awaitable[Any]] | None = None
|
||||||
"""处理函数, 当 origin 为 mcp 时,这个为空"""
|
"""处理函数, 当 origin 为 mcp 时,这个为空"""
|
||||||
handler_module_path: str | None = None
|
handler_module_path: str | None = None
|
||||||
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
|
"""处理函数的模块路径,当 origin 为 mcp 时,这个为空
|
||||||
@@ -51,7 +51,7 @@ class ToolSet:
|
|||||||
This class provides methods to add, remove, and retrieve tools, as well as
|
This class provides methods to add, remove, and retrieve tools, as well as
|
||||||
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
|
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI)."""
|
||||||
|
|
||||||
def __init__(self, tools: list[FunctionTool] = None):
|
def __init__(self, tools: list[FunctionTool] | None = None):
|
||||||
self.tools: list[FunctionTool] = tools or []
|
self.tools: list[FunctionTool] = tools or []
|
||||||
|
|
||||||
def empty(self) -> bool:
|
def empty(self) -> bool:
|
||||||
@@ -79,7 +79,13 @@ class ToolSet:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@deprecated(reason="Use add_tool() instead", version="4.0.0")
|
@deprecated(reason="Use add_tool() instead", version="4.0.0")
|
||||||
def add_func(self, name: str, func_args: list, desc: str, handler: Awaitable):
|
def add_func(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
func_args: list,
|
||||||
|
desc: str,
|
||||||
|
handler: Callable[..., Awaitable[Any]],
|
||||||
|
):
|
||||||
"""Add a function tool to the set."""
|
"""Add a function tool to the set."""
|
||||||
params = {
|
params = {
|
||||||
"type": "object", # hard-coded here
|
"type": "object", # hard-coded here
|
||||||
@@ -104,7 +110,7 @@ class ToolSet:
|
|||||||
self.remove_tool(name)
|
self.remove_tool(name)
|
||||||
|
|
||||||
@deprecated(reason="Use get_tool() instead", version="4.0.0")
|
@deprecated(reason="Use get_tool() instead", version="4.0.0")
|
||||||
def get_func(self, name: str) -> list[FunctionTool]:
|
def get_func(self, name: str) -> FunctionTool | None:
|
||||||
"""Get all function tools."""
|
"""Get all function tools."""
|
||||||
return self.get_tool(name)
|
return self.get_tool(name)
|
||||||
|
|
||||||
@@ -125,7 +131,11 @@ class ToolSet:
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
if tool.parameters.get("properties") or not omit_empty_parameter_field:
|
if (
|
||||||
|
tool.parameters
|
||||||
|
and tool.parameters.get("properties")
|
||||||
|
or not omit_empty_parameter_field
|
||||||
|
):
|
||||||
func_def["function"]["parameters"] = tool.parameters
|
func_def["function"]["parameters"] = tool.parameters
|
||||||
|
|
||||||
result.append(func_def)
|
result.append(func_def)
|
||||||
@@ -135,14 +145,14 @@ class ToolSet:
|
|||||||
"""Convert tools to Anthropic API format."""
|
"""Convert tools to Anthropic API format."""
|
||||||
result = []
|
result = []
|
||||||
for tool in self.tools:
|
for tool in self.tools:
|
||||||
|
input_schema = {"type": "object"}
|
||||||
|
if tool.parameters:
|
||||||
|
input_schema["properties"] = tool.parameters.get("properties", {})
|
||||||
|
input_schema["required"] = tool.parameters.get("required", [])
|
||||||
tool_def = {
|
tool_def = {
|
||||||
"name": tool.name,
|
"name": tool.name,
|
||||||
"description": tool.description,
|
"description": tool.description,
|
||||||
"input_schema": {
|
"input_schema": input_schema,
|
||||||
"type": "object",
|
|
||||||
"properties": tool.parameters.get("properties", {}),
|
|
||||||
"required": tool.parameters.get("required", []),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
result.append(tool_def)
|
result.append(tool_def)
|
||||||
return result
|
return result
|
||||||
@@ -210,14 +220,15 @@ class ToolSet:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
tools = [
|
tools = []
|
||||||
{
|
for tool in self.tools:
|
||||||
|
d = {
|
||||||
"name": tool.name,
|
"name": tool.name,
|
||||||
"description": tool.description,
|
"description": tool.description,
|
||||||
"parameters": convert_schema(tool.parameters),
|
|
||||||
}
|
}
|
||||||
for tool in self.tools
|
if tool.parameters:
|
||||||
]
|
d["parameters"] = convert_schema(tool.parameters)
|
||||||
|
tools.append(d)
|
||||||
|
|
||||||
declarations = {}
|
declarations = {}
|
||||||
if tools:
|
if tools:
|
||||||
|
|||||||
+142
-13
@@ -6,7 +6,7 @@ import os
|
|||||||
|
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
VERSION = "4.0.0-beta.5"
|
VERSION = "4.1.7"
|
||||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||||
|
|
||||||
# 默认配置
|
# 默认配置
|
||||||
@@ -56,10 +56,11 @@ DEFAULT_CONFIG = {
|
|||||||
"wake_prefix": "",
|
"wake_prefix": "",
|
||||||
"web_search": False,
|
"web_search": False,
|
||||||
"websearch_provider": "default",
|
"websearch_provider": "default",
|
||||||
"websearch_tavily_key": "",
|
"websearch_tavily_key": [],
|
||||||
"web_search_link": False,
|
"web_search_link": False,
|
||||||
"display_reasoning_text": False,
|
"display_reasoning_text": False,
|
||||||
"identifier": False,
|
"identifier": False,
|
||||||
|
"group_name_display": False,
|
||||||
"datetime_system_prompt": True,
|
"datetime_system_prompt": True,
|
||||||
"default_personality": "default",
|
"default_personality": "default",
|
||||||
"persona_pool": ["*"],
|
"persona_pool": ["*"],
|
||||||
@@ -103,6 +104,7 @@ DEFAULT_CONFIG = {
|
|||||||
"t2i_strategy": "remote",
|
"t2i_strategy": "remote",
|
||||||
"t2i_endpoint": "",
|
"t2i_endpoint": "",
|
||||||
"t2i_use_file_service": False,
|
"t2i_use_file_service": False,
|
||||||
|
"t2i_active_template": "base",
|
||||||
"http_proxy": "",
|
"http_proxy": "",
|
||||||
"no_proxy": ["localhost", "127.0.0.1", "::1"],
|
"no_proxy": ["localhost", "127.0.0.1", "::1"],
|
||||||
"dashboard": {
|
"dashboard": {
|
||||||
@@ -234,6 +236,16 @@ CONFIG_METADATA_2 = {
|
|||||||
"discord_guild_id_for_debug": "",
|
"discord_guild_id_for_debug": "",
|
||||||
"discord_activity_name": "",
|
"discord_activity_name": "",
|
||||||
},
|
},
|
||||||
|
"Misskey": {
|
||||||
|
"id": "misskey",
|
||||||
|
"type": "misskey",
|
||||||
|
"enable": False,
|
||||||
|
"misskey_instance_url": "https://misskey.example",
|
||||||
|
"misskey_token": "",
|
||||||
|
"misskey_default_visibility": "public",
|
||||||
|
"misskey_local_only": False,
|
||||||
|
"misskey_enable_chat": True,
|
||||||
|
},
|
||||||
"Slack": {
|
"Slack": {
|
||||||
"id": "slack",
|
"id": "slack",
|
||||||
"type": "slack",
|
"type": "slack",
|
||||||
@@ -246,8 +258,49 @@ CONFIG_METADATA_2 = {
|
|||||||
"slack_webhook_port": 6197,
|
"slack_webhook_port": 6197,
|
||||||
"slack_webhook_path": "/astrbot-slack-webhook/callback",
|
"slack_webhook_path": "/astrbot-slack-webhook/callback",
|
||||||
},
|
},
|
||||||
|
"Satori": {
|
||||||
|
"id": "satori",
|
||||||
|
"type": "satori",
|
||||||
|
"enable": False,
|
||||||
|
"satori_api_base_url": "http://localhost:5140/satori/v1",
|
||||||
|
"satori_endpoint": "ws://localhost:5140/satori/v1/events",
|
||||||
|
"satori_token": "",
|
||||||
|
"satori_auto_reconnect": True,
|
||||||
|
"satori_heartbeat_interval": 10,
|
||||||
|
"satori_reconnect_delay": 5,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"items": {
|
"items": {
|
||||||
|
"satori_api_base_url": {
|
||||||
|
"description": "Satori API 终结点",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "Satori API 的基础地址。",
|
||||||
|
},
|
||||||
|
"satori_endpoint": {
|
||||||
|
"description": "Satori WebSocket 终结点",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "Satori 事件的 WebSocket 端点。",
|
||||||
|
},
|
||||||
|
"satori_token": {
|
||||||
|
"description": "Satori 令牌",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "用于 Satori API 身份验证的令牌。",
|
||||||
|
},
|
||||||
|
"satori_auto_reconnect": {
|
||||||
|
"description": "启用自动重连",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "断开连接时是否自动重新连接 WebSocket。",
|
||||||
|
},
|
||||||
|
"satori_heartbeat_interval": {
|
||||||
|
"description": "Satori 心跳间隔",
|
||||||
|
"type": "int",
|
||||||
|
"hint": "发送心跳消息的间隔(秒)。",
|
||||||
|
},
|
||||||
|
"satori_reconnect_delay": {
|
||||||
|
"description": "Satori 重连延迟",
|
||||||
|
"type": "int",
|
||||||
|
"hint": "尝试重新连接前的延迟时间(秒)。",
|
||||||
|
},
|
||||||
"slack_connection_mode": {
|
"slack_connection_mode": {
|
||||||
"description": "Slack Connection Mode",
|
"description": "Slack Connection Mode",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
@@ -294,6 +347,32 @@ CONFIG_METADATA_2 = {
|
|||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "如果你的网络环境为中国大陆,请在 `其他配置` 处设置代理或更改 api_base。",
|
"hint": "如果你的网络环境为中国大陆,请在 `其他配置` 处设置代理或更改 api_base。",
|
||||||
},
|
},
|
||||||
|
"misskey_instance_url": {
|
||||||
|
"description": "Misskey 实例 URL",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "例如 https://misskey.example,填写 Bot 账号所在的 Misskey 实例地址",
|
||||||
|
},
|
||||||
|
"misskey_token": {
|
||||||
|
"description": "Misskey Access Token",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "连接服务设置生成的 API 鉴权访问令牌(Access token)",
|
||||||
|
},
|
||||||
|
"misskey_default_visibility": {
|
||||||
|
"description": "默认帖子可见性",
|
||||||
|
"type": "string",
|
||||||
|
"options": ["public", "home", "followers"],
|
||||||
|
"hint": "机器人发帖时的默认可见性设置。public:公开,home:主页时间线,followers:仅关注者。",
|
||||||
|
},
|
||||||
|
"misskey_local_only": {
|
||||||
|
"description": "仅限本站(不参与联合)",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "启用后,机器人发出的帖子将仅在本实例可见,不会联合到其他实例",
|
||||||
|
},
|
||||||
|
"misskey_enable_chat": {
|
||||||
|
"description": "启用聊天消息响应",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "启用后,机器人将会监听和响应私信聊天消息",
|
||||||
|
},
|
||||||
"telegram_command_register": {
|
"telegram_command_register": {
|
||||||
"description": "Telegram 命令注册",
|
"description": "Telegram 命令注册",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
@@ -557,6 +636,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"api_base": "https://api.openai.com/v1",
|
"api_base": "https://api.openai.com/v1",
|
||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
||||||
|
"custom_extra_body": {},
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
"hint": "也兼容所有与 OpenAI API 兼容的服务。",
|
"hint": "也兼容所有与 OpenAI API 兼容的服务。",
|
||||||
},
|
},
|
||||||
@@ -571,6 +651,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"api_base": "",
|
"api_base": "",
|
||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
"model_config": {"model": "gpt-4o-mini", "temperature": 0.4},
|
||||||
|
"custom_extra_body": {},
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
},
|
},
|
||||||
"xAI": {
|
"xAI": {
|
||||||
@@ -583,6 +664,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"api_base": "https://api.x.ai/v1",
|
"api_base": "https://api.x.ai/v1",
|
||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
"model_config": {"model": "grok-2-latest", "temperature": 0.4},
|
"model_config": {"model": "grok-2-latest", "temperature": 0.4},
|
||||||
|
"custom_extra_body": {},
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
},
|
},
|
||||||
"Anthropic": {
|
"Anthropic": {
|
||||||
@@ -612,6 +694,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
||||||
"api_base": "http://localhost:11434/v1",
|
"api_base": "http://localhost:11434/v1",
|
||||||
"model_config": {"model": "llama3.1-8b", "temperature": 0.4},
|
"model_config": {"model": "llama3.1-8b", "temperature": 0.4},
|
||||||
|
"custom_extra_body": {},
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
},
|
},
|
||||||
"LM Studio": {
|
"LM Studio": {
|
||||||
@@ -625,6 +708,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"model_config": {
|
"model_config": {
|
||||||
"model": "llama-3.1-8b",
|
"model": "llama-3.1-8b",
|
||||||
},
|
},
|
||||||
|
"custom_extra_body": {},
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
},
|
},
|
||||||
"Gemini(OpenAI兼容)": {
|
"Gemini(OpenAI兼容)": {
|
||||||
@@ -640,6 +724,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"model": "gemini-1.5-flash",
|
"model": "gemini-1.5-flash",
|
||||||
"temperature": 0.4,
|
"temperature": 0.4,
|
||||||
},
|
},
|
||||||
|
"custom_extra_body": {},
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
},
|
},
|
||||||
"Gemini": {
|
"Gemini": {
|
||||||
@@ -680,6 +765,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"api_base": "https://api.deepseek.com/v1",
|
"api_base": "https://api.deepseek.com/v1",
|
||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
"model_config": {"model": "deepseek-chat", "temperature": 0.4},
|
"model_config": {"model": "deepseek-chat", "temperature": 0.4},
|
||||||
|
"custom_extra_body": {},
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
},
|
},
|
||||||
"302.AI": {
|
"302.AI": {
|
||||||
@@ -692,6 +778,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"api_base": "https://api.302.ai/v1",
|
"api_base": "https://api.302.ai/v1",
|
||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
"model_config": {"model": "gpt-4.1-mini", "temperature": 0.4},
|
"model_config": {"model": "gpt-4.1-mini", "temperature": 0.4},
|
||||||
|
"custom_extra_body": {},
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
},
|
},
|
||||||
"硅基流动": {
|
"硅基流动": {
|
||||||
@@ -707,6 +794,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"model": "deepseek-ai/DeepSeek-V3",
|
"model": "deepseek-ai/DeepSeek-V3",
|
||||||
"temperature": 0.4,
|
"temperature": 0.4,
|
||||||
},
|
},
|
||||||
|
"custom_extra_body": {},
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
},
|
},
|
||||||
"PPIO派欧云": {
|
"PPIO派欧云": {
|
||||||
@@ -722,6 +810,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"model": "deepseek/deepseek-r1",
|
"model": "deepseek/deepseek-r1",
|
||||||
"temperature": 0.4,
|
"temperature": 0.4,
|
||||||
},
|
},
|
||||||
|
"custom_extra_body": {},
|
||||||
},
|
},
|
||||||
"优云智算": {
|
"优云智算": {
|
||||||
"id": "compshare",
|
"id": "compshare",
|
||||||
@@ -735,6 +824,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"model_config": {
|
"model_config": {
|
||||||
"model": "moonshotai/Kimi-K2-Instruct",
|
"model": "moonshotai/Kimi-K2-Instruct",
|
||||||
},
|
},
|
||||||
|
"custom_extra_body": {},
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
},
|
},
|
||||||
"Kimi": {
|
"Kimi": {
|
||||||
@@ -747,6 +837,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
"api_base": "https://api.moonshot.cn/v1",
|
"api_base": "https://api.moonshot.cn/v1",
|
||||||
"model_config": {"model": "moonshot-v1-8k", "temperature": 0.4},
|
"model_config": {"model": "moonshot-v1-8k", "temperature": 0.4},
|
||||||
|
"custom_extra_body": {},
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
},
|
},
|
||||||
"智谱 AI": {
|
"智谱 AI": {
|
||||||
@@ -805,6 +896,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
"api_base": "https://api-inference.modelscope.cn/v1",
|
"api_base": "https://api-inference.modelscope.cn/v1",
|
||||||
"model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4},
|
"model_config": {"model": "Qwen/Qwen3-32B", "temperature": 0.4},
|
||||||
|
"custom_extra_body": {},
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
},
|
},
|
||||||
"FastGPT": {
|
"FastGPT": {
|
||||||
@@ -816,6 +908,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"key": [],
|
"key": [],
|
||||||
"api_base": "https://api.fastgpt.in/api/v1",
|
"api_base": "https://api.fastgpt.in/api/v1",
|
||||||
"timeout": 60,
|
"timeout": 60,
|
||||||
|
"custom_extra_body": {},
|
||||||
},
|
},
|
||||||
"Whisper(API)": {
|
"Whisper(API)": {
|
||||||
"id": "whisper",
|
"id": "whisper",
|
||||||
@@ -1060,6 +1153,12 @@ CONFIG_METADATA_2 = {
|
|||||||
"render_type": "checkbox",
|
"render_type": "checkbox",
|
||||||
"hint": "模型支持的模态。如所填写的模型不支持图像,请取消勾选图像。",
|
"hint": "模型支持的模态。如所填写的模型不支持图像,请取消勾选图像。",
|
||||||
},
|
},
|
||||||
|
"custom_extra_body": {
|
||||||
|
"description": "自定义请求体参数",
|
||||||
|
"type": "dict",
|
||||||
|
"items": {},
|
||||||
|
"hint": "此处添加的键值对将被合并到发送给 API 的 extra_body 中。值可以是字符串、数字或布尔值。",
|
||||||
|
},
|
||||||
"provider": {
|
"provider": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"invisible": True,
|
"invisible": True,
|
||||||
@@ -1662,6 +1761,9 @@ CONFIG_METADATA_2 = {
|
|||||||
"identifier": {
|
"identifier": {
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
},
|
},
|
||||||
|
"group_name_display": {
|
||||||
|
"type": "bool",
|
||||||
|
},
|
||||||
"datetime_system_prompt": {
|
"datetime_system_prompt": {
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
},
|
},
|
||||||
@@ -1841,17 +1943,31 @@ CONFIG_METADATA_3 = {
|
|||||||
"_special": "select_provider",
|
"_special": "select_provider",
|
||||||
"hint": "留空代表不使用。可用于不支持视觉模态的聊天模型。",
|
"hint": "留空代表不使用。可用于不支持视觉模态的聊天模型。",
|
||||||
},
|
},
|
||||||
|
"provider_stt_settings.enable": {
|
||||||
|
"description": "默认启用语音转文本",
|
||||||
|
"type": "bool",
|
||||||
|
},
|
||||||
"provider_stt_settings.provider_id": {
|
"provider_stt_settings.provider_id": {
|
||||||
"description": "语音转文本模型",
|
"description": "语音转文本模型",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "留空代表不使用。",
|
"hint": "留空代表不使用。",
|
||||||
"_special": "select_provider_stt",
|
"_special": "select_provider_stt",
|
||||||
|
"condition": {
|
||||||
|
"provider_stt_settings.enable": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"provider_tts_settings.enable": {
|
||||||
|
"description": "默认启用文本转语音",
|
||||||
|
"type": "bool",
|
||||||
},
|
},
|
||||||
"provider_tts_settings.provider_id": {
|
"provider_tts_settings.provider_id": {
|
||||||
"description": "文本转语音模型",
|
"description": "文本转语音模型",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"hint": "留空代表不使用。",
|
"hint": "留空代表不使用。",
|
||||||
"_special": "select_provider_tts",
|
"_special": "select_provider_tts",
|
||||||
|
"condition": {
|
||||||
|
"provider_tts_settings.enable": True,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
"provider_settings.image_caption_prompt": {
|
"provider_settings.image_caption_prompt": {
|
||||||
"description": "图片转述提示词",
|
"description": "图片转述提示词",
|
||||||
@@ -1896,7 +2012,9 @@ CONFIG_METADATA_3 = {
|
|||||||
},
|
},
|
||||||
"provider_settings.websearch_tavily_key": {
|
"provider_settings.websearch_tavily_key": {
|
||||||
"description": "Tavily API Key",
|
"description": "Tavily API Key",
|
||||||
"type": "string",
|
"type": "list",
|
||||||
|
"items": {"type": "string"},
|
||||||
|
"hint": "可添加多个 Key 进行轮询。",
|
||||||
"condition": {
|
"condition": {
|
||||||
"provider_settings.websearch_provider": "tavily",
|
"provider_settings.websearch_provider": "tavily",
|
||||||
},
|
},
|
||||||
@@ -1919,6 +2037,11 @@ CONFIG_METADATA_3 = {
|
|||||||
"description": "用户识别",
|
"description": "用户识别",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
},
|
},
|
||||||
|
"provider_settings.group_name_display": {
|
||||||
|
"description": "显示群名称",
|
||||||
|
"type": "bool",
|
||||||
|
"hint": "启用后,在支持的平台(aiocqhttp)上会在 prompt 中包含群名称信息。",
|
||||||
|
},
|
||||||
"provider_settings.datetime_system_prompt": {
|
"provider_settings.datetime_system_prompt": {
|
||||||
"description": "现实世界时间感知",
|
"description": "现实世界时间感知",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
@@ -2066,41 +2189,41 @@ CONFIG_METADATA_3 = {
|
|||||||
"description": "内容安全",
|
"description": "内容安全",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"items": {
|
"items": {
|
||||||
"platform_settings.content_safety.also_use_in_response": {
|
"content_safety.also_use_in_response": {
|
||||||
"description": "同时检查模型的响应内容",
|
"description": "同时检查模型的响应内容",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
},
|
},
|
||||||
"platform_settings.content_safety.baidu_aip.enable": {
|
"content_safety.baidu_aip.enable": {
|
||||||
"description": "使用百度内容安全审核",
|
"description": "使用百度内容安全审核",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
"hint": "您需要手动安装 baidu-aip 库。",
|
"hint": "您需要手动安装 baidu-aip 库。",
|
||||||
},
|
},
|
||||||
"platform_settings.content_safety.baidu_aip.app_id": {
|
"content_safety.baidu_aip.app_id": {
|
||||||
"description": "App ID",
|
"description": "App ID",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"condition": {
|
"condition": {
|
||||||
"platform_settings.content_safety.baidu_aip.enable": True,
|
"content_safety.baidu_aip.enable": True,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"platform_settings.content_safety.baidu_aip.api_key": {
|
"content_safety.baidu_aip.api_key": {
|
||||||
"description": "API Key",
|
"description": "API Key",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"condition": {
|
"condition": {
|
||||||
"platform_settings.content_safety.baidu_aip.enable": True,
|
"content_safety.baidu_aip.enable": True,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"platform_settings.content_safety.baidu_aip.secret_key": {
|
"content_safety.baidu_aip.secret_key": {
|
||||||
"description": "Secret Key",
|
"description": "Secret Key",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"condition": {
|
"condition": {
|
||||||
"platform_settings.content_safety.baidu_aip.enable": True,
|
"content_safety.baidu_aip.enable": True,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"platform_settings.content_safety.internal_keywords.enable": {
|
"content_safety.internal_keywords.enable": {
|
||||||
"description": "关键词检查",
|
"description": "关键词检查",
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
},
|
},
|
||||||
"platform_settings.content_safety.internal_keywords.extra_keywords": {
|
"content_safety.internal_keywords.extra_keywords": {
|
||||||
"description": "额外关键词",
|
"description": "额外关键词",
|
||||||
"type": "list",
|
"type": "list",
|
||||||
"items": {"type": "string"},
|
"items": {"type": "string"},
|
||||||
@@ -2293,6 +2416,12 @@ CONFIG_METADATA_3_SYSTEM = {
|
|||||||
},
|
},
|
||||||
"_special": "t2i_template",
|
"_special": "t2i_template",
|
||||||
},
|
},
|
||||||
|
"t2i_active_template": {
|
||||||
|
"description": "当前应用的文转图渲染模板",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "此处的值由文转图模板管理页面进行维护。",
|
||||||
|
"invisible": True,
|
||||||
|
},
|
||||||
"log_level": {
|
"log_level": {
|
||||||
"description": "控制台日志级别",
|
"description": "控制台日志级别",
|
||||||
"type": "string",
|
"type": "string",
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ async def do_migration_v4(
|
|||||||
await migration_webchat_data(db_helper, platform_id_map)
|
await migration_webchat_data(db_helper, platform_id_map)
|
||||||
|
|
||||||
# 执行偏好设置迁移
|
# 执行偏好设置迁移
|
||||||
await migration_preferences(db_helper,platform_id_map)
|
await migration_preferences(db_helper, platform_id_map)
|
||||||
|
|
||||||
# 执行平台统计表迁移
|
# 执行平台统计表迁移
|
||||||
await migration_platform_table(db_helper, platform_id_map)
|
await migration_platform_table(db_helper, platform_id_map)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|||||||
|
|
||||||
_VT = TypeVar("_VT")
|
_VT = TypeVar("_VT")
|
||||||
|
|
||||||
|
|
||||||
class SharedPreferences:
|
class SharedPreferences:
|
||||||
def __init__(self, path=None):
|
def __init__(self, path=None):
|
||||||
if path is None:
|
if path is None:
|
||||||
@@ -42,4 +43,5 @@ class SharedPreferences:
|
|||||||
self._data.clear()
|
self._data.clear()
|
||||||
self._save_preferences()
|
self._save_preferences()
|
||||||
|
|
||||||
|
|
||||||
sp = SharedPreferences()
|
sp = SharedPreferences()
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from astrbot.core.db.po import Platform, Stats
|
|||||||
from typing import Tuple, List, Dict, Any
|
from typing import Tuple, List, Dict, Any
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Conversation:
|
class Conversation:
|
||||||
"""LLM 对话存储
|
"""LLM 对话存储
|
||||||
@@ -76,7 +77,7 @@ PRAGMA encoding = 'UTF-8';
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class SQLiteDatabase():
|
class SQLiteDatabase:
|
||||||
def __init__(self, db_path: str) -> None:
|
def __init__(self, db_path: str) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.db_path = db_path
|
self.db_path = db_path
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from astrbot.core.db.po import (
|
|||||||
from sqlalchemy import select, update, delete, text
|
from sqlalchemy import select, update, delete, text
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.sql import func
|
from sqlalchemy.sql import func
|
||||||
|
from sqlalchemy import or_
|
||||||
|
|
||||||
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
||||||
|
|
||||||
@@ -153,8 +154,22 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
ConversationV2.platform_id.in_(platform_ids)
|
ConversationV2.platform_id.in_(platform_ids)
|
||||||
)
|
)
|
||||||
if search_query:
|
if search_query:
|
||||||
|
search_query = search_query.encode("unicode_escape").decode("utf-8")
|
||||||
base_query = base_query.where(
|
base_query = base_query.where(
|
||||||
ConversationV2.title.ilike(f"%{search_query}%")
|
or_(
|
||||||
|
ConversationV2.title.ilike(f"%{search_query}%"),
|
||||||
|
ConversationV2.content.ilike(f"%{search_query}%"),
|
||||||
|
ConversationV2.user_id.ilike(f"%{search_query}%"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if "message_types" in kwargs and len(kwargs["message_types"]) > 0:
|
||||||
|
for msg_type in kwargs["message_types"]:
|
||||||
|
base_query = base_query.where(
|
||||||
|
ConversationV2.user_id.ilike(f"%:{msg_type}:%")
|
||||||
|
)
|
||||||
|
if "platforms" in kwargs and len(kwargs["platforms"]) > 0:
|
||||||
|
base_query = base_query.where(
|
||||||
|
ConversationV2.platform_id.in_(kwargs["platforms"])
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get total count matching the filters
|
# Get total count matching the filters
|
||||||
|
|||||||
@@ -113,7 +113,8 @@ class FaissVecDB(BaseVecDB):
|
|||||||
reranked_results, key=lambda x: x.relevance_score, reverse=True
|
reranked_results, key=lambda x: x.relevance_score, reverse=True
|
||||||
)
|
)
|
||||||
top_k_results = [
|
top_k_results = [
|
||||||
top_k_results[reranked_result.index] for reranked_result in reranked_results
|
top_k_results[reranked_result.index]
|
||||||
|
for reranked_result in reranked_results
|
||||||
]
|
]
|
||||||
|
|
||||||
return top_k_results
|
return top_k_results
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ class InitialLoader:
|
|||||||
self.db = db
|
self.db = db
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
self.log_broker = log_broker
|
self.log_broker = log_broker
|
||||||
|
self.webui_dir: str | None = None
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
|
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
|
||||||
@@ -35,8 +36,10 @@ class InitialLoader:
|
|||||||
|
|
||||||
core_task = core_lifecycle.start()
|
core_task = core_lifecycle.start()
|
||||||
|
|
||||||
|
webui_dir = self.webui_dir
|
||||||
|
|
||||||
self.dashboard_server = AstrBotDashboard(
|
self.dashboard_server = AstrBotDashboard(
|
||||||
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event
|
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event, webui_dir
|
||||||
)
|
)
|
||||||
task = asyncio.gather(
|
task = asyncio.gather(
|
||||||
core_task, self.dashboard_server.run()
|
core_task, self.dashboard_server.run()
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
|||||||
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
|
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
|
||||||
|
|
||||||
|
|
||||||
class ComponentType(Enum):
|
class ComponentType(str, Enum):
|
||||||
Plain = "Plain" # 纯文本消息
|
Plain = "Plain" # 纯文本消息
|
||||||
Face = "Face" # QQ表情
|
Face = "Face" # QQ表情
|
||||||
Record = "Record" # 语音
|
Record = "Record" # 语音
|
||||||
@@ -108,7 +108,7 @@ class BaseMessageComponent(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class Plain(BaseMessageComponent):
|
class Plain(BaseMessageComponent):
|
||||||
type: ComponentType = "Plain"
|
type = ComponentType.Plain
|
||||||
text: str
|
text: str
|
||||||
convert: T.Optional[bool] = True # 若为 False 则直接发送未转换 CQ 码的消息
|
convert: T.Optional[bool] = True # 若为 False 则直接发送未转换 CQ 码的消息
|
||||||
|
|
||||||
@@ -128,8 +128,9 @@ class Plain(BaseMessageComponent):
|
|||||||
async def to_dict(self):
|
async def to_dict(self):
|
||||||
return {"type": "text", "data": {"text": self.text}}
|
return {"type": "text", "data": {"text": self.text}}
|
||||||
|
|
||||||
|
|
||||||
class Face(BaseMessageComponent):
|
class Face(BaseMessageComponent):
|
||||||
type: ComponentType = "Face"
|
type = ComponentType.Face
|
||||||
id: int
|
id: int
|
||||||
|
|
||||||
def __init__(self, **_):
|
def __init__(self, **_):
|
||||||
@@ -137,7 +138,7 @@ class Face(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class Record(BaseMessageComponent):
|
class Record(BaseMessageComponent):
|
||||||
type: ComponentType = "Record"
|
type = ComponentType.Record
|
||||||
file: T.Optional[str] = ""
|
file: T.Optional[str] = ""
|
||||||
magic: T.Optional[bool] = False
|
magic: T.Optional[bool] = False
|
||||||
url: T.Optional[str] = ""
|
url: T.Optional[str] = ""
|
||||||
@@ -164,19 +165,24 @@ class Record(BaseMessageComponent):
|
|||||||
return Record(file=url, **_)
|
return Record(file=url, **_)
|
||||||
raise Exception("not a valid url")
|
raise Exception("not a valid url")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def fromBase64(bs64_data: str, **_):
|
||||||
|
return Record(file=f"base64://{bs64_data}", **_)
|
||||||
|
|
||||||
async def convert_to_file_path(self) -> str:
|
async def convert_to_file_path(self) -> str:
|
||||||
"""将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。
|
"""将这个语音统一转换为本地文件路径。这个方法避免了手动判断语音数据类型,直接返回语音数据的本地路径(如果是网络 URL, 则会自动进行下载)。
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
str: 语音的本地路径,以绝对路径表示。
|
str: 语音的本地路径,以绝对路径表示。
|
||||||
"""
|
"""
|
||||||
if self.file and self.file.startswith("file:///"):
|
if not self.file:
|
||||||
file_path = self.file[8:]
|
raise Exception(f"not a valid file: {self.file}")
|
||||||
return file_path
|
if self.file.startswith("file:///"):
|
||||||
elif self.file and self.file.startswith("http"):
|
return self.file[8:]
|
||||||
|
elif self.file.startswith("http"):
|
||||||
file_path = await download_image_by_url(self.file)
|
file_path = await download_image_by_url(self.file)
|
||||||
return os.path.abspath(file_path)
|
return os.path.abspath(file_path)
|
||||||
elif self.file and self.file.startswith("base64://"):
|
elif self.file.startswith("base64://"):
|
||||||
bs64_data = self.file.removeprefix("base64://")
|
bs64_data = self.file.removeprefix("base64://")
|
||||||
image_bytes = base64.b64decode(bs64_data)
|
image_bytes = base64.b64decode(bs64_data)
|
||||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
@@ -185,8 +191,7 @@ class Record(BaseMessageComponent):
|
|||||||
f.write(image_bytes)
|
f.write(image_bytes)
|
||||||
return os.path.abspath(file_path)
|
return os.path.abspath(file_path)
|
||||||
elif os.path.exists(self.file):
|
elif os.path.exists(self.file):
|
||||||
file_path = self.file
|
return os.path.abspath(self.file)
|
||||||
return os.path.abspath(file_path)
|
|
||||||
else:
|
else:
|
||||||
raise Exception(f"not a valid file: {self.file}")
|
raise Exception(f"not a valid file: {self.file}")
|
||||||
|
|
||||||
@@ -197,12 +202,14 @@ class Record(BaseMessageComponent):
|
|||||||
str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
|
str: 语音的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
|
||||||
"""
|
"""
|
||||||
# convert to base64
|
# convert to base64
|
||||||
if self.file and self.file.startswith("file:///"):
|
if not self.file:
|
||||||
|
raise Exception(f"not a valid file: {self.file}")
|
||||||
|
if self.file.startswith("file:///"):
|
||||||
bs64_data = file_to_base64(self.file[8:])
|
bs64_data = file_to_base64(self.file[8:])
|
||||||
elif self.file and self.file.startswith("http"):
|
elif self.file.startswith("http"):
|
||||||
file_path = await download_image_by_url(self.file)
|
file_path = await download_image_by_url(self.file)
|
||||||
bs64_data = file_to_base64(file_path)
|
bs64_data = file_to_base64(file_path)
|
||||||
elif self.file and self.file.startswith("base64://"):
|
elif self.file.startswith("base64://"):
|
||||||
bs64_data = self.file
|
bs64_data = self.file
|
||||||
elif os.path.exists(self.file):
|
elif os.path.exists(self.file):
|
||||||
bs64_data = file_to_base64(self.file)
|
bs64_data = file_to_base64(self.file)
|
||||||
@@ -236,7 +243,7 @@ class Record(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class Video(BaseMessageComponent):
|
class Video(BaseMessageComponent):
|
||||||
type: ComponentType = "Video"
|
type = ComponentType.Video
|
||||||
file: str
|
file: str
|
||||||
cover: T.Optional[str] = ""
|
cover: T.Optional[str] = ""
|
||||||
c: T.Optional[int] = 2
|
c: T.Optional[int] = 2
|
||||||
@@ -322,7 +329,7 @@ class Video(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class At(BaseMessageComponent):
|
class At(BaseMessageComponent):
|
||||||
type: ComponentType = "At"
|
type = ComponentType.At
|
||||||
qq: T.Union[int, str] # 此处str为all时代表所有人
|
qq: T.Union[int, str] # 此处str为all时代表所有人
|
||||||
name: T.Optional[str] = ""
|
name: T.Optional[str] = ""
|
||||||
|
|
||||||
@@ -344,28 +351,28 @@ class AtAll(At):
|
|||||||
|
|
||||||
|
|
||||||
class RPS(BaseMessageComponent): # TODO
|
class RPS(BaseMessageComponent): # TODO
|
||||||
type: ComponentType = "RPS"
|
type = ComponentType.RPS
|
||||||
|
|
||||||
def __init__(self, **_):
|
def __init__(self, **_):
|
||||||
super().__init__(**_)
|
super().__init__(**_)
|
||||||
|
|
||||||
|
|
||||||
class Dice(BaseMessageComponent): # TODO
|
class Dice(BaseMessageComponent): # TODO
|
||||||
type: ComponentType = "Dice"
|
type = ComponentType.Dice
|
||||||
|
|
||||||
def __init__(self, **_):
|
def __init__(self, **_):
|
||||||
super().__init__(**_)
|
super().__init__(**_)
|
||||||
|
|
||||||
|
|
||||||
class Shake(BaseMessageComponent): # TODO
|
class Shake(BaseMessageComponent): # TODO
|
||||||
type: ComponentType = "Shake"
|
type = ComponentType.Shake
|
||||||
|
|
||||||
def __init__(self, **_):
|
def __init__(self, **_):
|
||||||
super().__init__(**_)
|
super().__init__(**_)
|
||||||
|
|
||||||
|
|
||||||
class Anonymous(BaseMessageComponent): # TODO
|
class Anonymous(BaseMessageComponent): # TODO
|
||||||
type: ComponentType = "Anonymous"
|
type = ComponentType.Anonymous
|
||||||
ignore: T.Optional[bool] = False
|
ignore: T.Optional[bool] = False
|
||||||
|
|
||||||
def __init__(self, **_):
|
def __init__(self, **_):
|
||||||
@@ -373,7 +380,7 @@ class Anonymous(BaseMessageComponent): # TODO
|
|||||||
|
|
||||||
|
|
||||||
class Share(BaseMessageComponent):
|
class Share(BaseMessageComponent):
|
||||||
type: ComponentType = "Share"
|
type = ComponentType.Share
|
||||||
url: str
|
url: str
|
||||||
title: str
|
title: str
|
||||||
content: T.Optional[str] = ""
|
content: T.Optional[str] = ""
|
||||||
@@ -384,7 +391,7 @@ class Share(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class Contact(BaseMessageComponent): # TODO
|
class Contact(BaseMessageComponent): # TODO
|
||||||
type: ComponentType = "Contact"
|
type = ComponentType.Contact
|
||||||
_type: str # type 字段冲突
|
_type: str # type 字段冲突
|
||||||
id: T.Optional[int] = 0
|
id: T.Optional[int] = 0
|
||||||
|
|
||||||
@@ -393,7 +400,7 @@ class Contact(BaseMessageComponent): # TODO
|
|||||||
|
|
||||||
|
|
||||||
class Location(BaseMessageComponent): # TODO
|
class Location(BaseMessageComponent): # TODO
|
||||||
type: ComponentType = "Location"
|
type = ComponentType.Location
|
||||||
lat: float
|
lat: float
|
||||||
lon: float
|
lon: float
|
||||||
title: T.Optional[str] = ""
|
title: T.Optional[str] = ""
|
||||||
@@ -404,7 +411,7 @@ class Location(BaseMessageComponent): # TODO
|
|||||||
|
|
||||||
|
|
||||||
class Music(BaseMessageComponent):
|
class Music(BaseMessageComponent):
|
||||||
type: ComponentType = "Music"
|
type = ComponentType.Music
|
||||||
_type: str
|
_type: str
|
||||||
id: T.Optional[int] = 0
|
id: T.Optional[int] = 0
|
||||||
url: T.Optional[str] = ""
|
url: T.Optional[str] = ""
|
||||||
@@ -421,7 +428,7 @@ class Music(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class Image(BaseMessageComponent):
|
class Image(BaseMessageComponent):
|
||||||
type: ComponentType = "Image"
|
type = ComponentType.Image
|
||||||
file: T.Optional[str] = ""
|
file: T.Optional[str] = ""
|
||||||
_type: T.Optional[str] = ""
|
_type: T.Optional[str] = ""
|
||||||
subType: T.Optional[int] = 0
|
subType: T.Optional[int] = 0
|
||||||
@@ -464,14 +471,15 @@ class Image(BaseMessageComponent):
|
|||||||
Returns:
|
Returns:
|
||||||
str: 图片的本地路径,以绝对路径表示。
|
str: 图片的本地路径,以绝对路径表示。
|
||||||
"""
|
"""
|
||||||
url = self.url if self.url else self.file
|
url = self.url or self.file
|
||||||
if url and url.startswith("file:///"):
|
if not url:
|
||||||
image_file_path = url[8:]
|
raise ValueError("No valid file or URL provided")
|
||||||
return image_file_path
|
if url.startswith("file:///"):
|
||||||
elif url and url.startswith("http"):
|
return url[8:]
|
||||||
|
elif url.startswith("http"):
|
||||||
image_file_path = await download_image_by_url(url)
|
image_file_path = await download_image_by_url(url)
|
||||||
return os.path.abspath(image_file_path)
|
return os.path.abspath(image_file_path)
|
||||||
elif url and url.startswith("base64://"):
|
elif url.startswith("base64://"):
|
||||||
bs64_data = url.removeprefix("base64://")
|
bs64_data = url.removeprefix("base64://")
|
||||||
image_bytes = base64.b64decode(bs64_data)
|
image_bytes = base64.b64decode(bs64_data)
|
||||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
@@ -480,8 +488,7 @@ class Image(BaseMessageComponent):
|
|||||||
f.write(image_bytes)
|
f.write(image_bytes)
|
||||||
return os.path.abspath(image_file_path)
|
return os.path.abspath(image_file_path)
|
||||||
elif os.path.exists(url):
|
elif os.path.exists(url):
|
||||||
image_file_path = url
|
return os.path.abspath(url)
|
||||||
return os.path.abspath(image_file_path)
|
|
||||||
else:
|
else:
|
||||||
raise Exception(f"not a valid file: {url}")
|
raise Exception(f"not a valid file: {url}")
|
||||||
|
|
||||||
@@ -492,13 +499,15 @@ class Image(BaseMessageComponent):
|
|||||||
str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
|
str: 图片的 base64 编码,不以 base64:// 或者 data:image/jpeg;base64, 开头。
|
||||||
"""
|
"""
|
||||||
# convert to base64
|
# convert to base64
|
||||||
url = self.url if self.url else self.file
|
url = self.url or self.file
|
||||||
if url and url.startswith("file:///"):
|
if not url:
|
||||||
|
raise ValueError("No valid file or URL provided")
|
||||||
|
if url.startswith("file:///"):
|
||||||
bs64_data = file_to_base64(url[8:])
|
bs64_data = file_to_base64(url[8:])
|
||||||
elif url and url.startswith("http"):
|
elif url.startswith("http"):
|
||||||
image_file_path = await download_image_by_url(url)
|
image_file_path = await download_image_by_url(url)
|
||||||
bs64_data = file_to_base64(image_file_path)
|
bs64_data = file_to_base64(image_file_path)
|
||||||
elif url and url.startswith("base64://"):
|
elif url.startswith("base64://"):
|
||||||
bs64_data = url
|
bs64_data = url
|
||||||
elif os.path.exists(url):
|
elif os.path.exists(url):
|
||||||
bs64_data = file_to_base64(url)
|
bs64_data = file_to_base64(url)
|
||||||
@@ -532,7 +541,7 @@ class Image(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class Reply(BaseMessageComponent):
|
class Reply(BaseMessageComponent):
|
||||||
type: ComponentType = "Reply"
|
type = ComponentType.Reply
|
||||||
id: T.Union[str, int]
|
id: T.Union[str, int]
|
||||||
"""所引用的消息 ID"""
|
"""所引用的消息 ID"""
|
||||||
chain: T.Optional[T.List["BaseMessageComponent"]] = []
|
chain: T.Optional[T.List["BaseMessageComponent"]] = []
|
||||||
@@ -558,7 +567,7 @@ class Reply(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class RedBag(BaseMessageComponent):
|
class RedBag(BaseMessageComponent):
|
||||||
type: ComponentType = "RedBag"
|
type = ComponentType.RedBag
|
||||||
title: str
|
title: str
|
||||||
|
|
||||||
def __init__(self, **_):
|
def __init__(self, **_):
|
||||||
@@ -566,7 +575,7 @@ class RedBag(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class Poke(BaseMessageComponent):
|
class Poke(BaseMessageComponent):
|
||||||
type: str = ""
|
type: str = ComponentType.Poke
|
||||||
id: T.Optional[int] = 0
|
id: T.Optional[int] = 0
|
||||||
qq: T.Optional[int] = 0
|
qq: T.Optional[int] = 0
|
||||||
|
|
||||||
@@ -576,7 +585,7 @@ class Poke(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class Forward(BaseMessageComponent):
|
class Forward(BaseMessageComponent):
|
||||||
type: ComponentType = "Forward"
|
type = ComponentType.Forward
|
||||||
id: str
|
id: str
|
||||||
|
|
||||||
def __init__(self, **_):
|
def __init__(self, **_):
|
||||||
@@ -586,7 +595,7 @@ class Forward(BaseMessageComponent):
|
|||||||
class Node(BaseMessageComponent):
|
class Node(BaseMessageComponent):
|
||||||
"""群合并转发消息"""
|
"""群合并转发消息"""
|
||||||
|
|
||||||
type: ComponentType = "Node"
|
type = ComponentType.Node
|
||||||
id: T.Optional[int] = 0 # 忽略
|
id: T.Optional[int] = 0 # 忽略
|
||||||
name: T.Optional[str] = "" # qq昵称
|
name: T.Optional[str] = "" # qq昵称
|
||||||
uin: T.Optional[str] = "0" # qq号
|
uin: T.Optional[str] = "0" # qq号
|
||||||
@@ -638,7 +647,7 @@ class Node(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class Nodes(BaseMessageComponent):
|
class Nodes(BaseMessageComponent):
|
||||||
type: ComponentType = "Nodes"
|
type = ComponentType.Nodes
|
||||||
nodes: T.List[Node]
|
nodes: T.List[Node]
|
||||||
|
|
||||||
def __init__(self, nodes: T.List[Node], **_):
|
def __init__(self, nodes: T.List[Node], **_):
|
||||||
@@ -664,7 +673,7 @@ class Nodes(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class Xml(BaseMessageComponent):
|
class Xml(BaseMessageComponent):
|
||||||
type: ComponentType = "Xml"
|
type = ComponentType.Xml
|
||||||
data: str
|
data: str
|
||||||
resid: T.Optional[int] = 0
|
resid: T.Optional[int] = 0
|
||||||
|
|
||||||
@@ -673,7 +682,7 @@ class Xml(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class Json(BaseMessageComponent):
|
class Json(BaseMessageComponent):
|
||||||
type: ComponentType = "Json"
|
type = ComponentType.Json
|
||||||
data: T.Union[str, dict]
|
data: T.Union[str, dict]
|
||||||
resid: T.Optional[int] = 0
|
resid: T.Optional[int] = 0
|
||||||
|
|
||||||
@@ -684,7 +693,7 @@ class Json(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class CardImage(BaseMessageComponent):
|
class CardImage(BaseMessageComponent):
|
||||||
type: ComponentType = "CardImage"
|
type = ComponentType.CardImage
|
||||||
file: str
|
file: str
|
||||||
cache: T.Optional[bool] = True
|
cache: T.Optional[bool] = True
|
||||||
minwidth: T.Optional[int] = 400
|
minwidth: T.Optional[int] = 400
|
||||||
@@ -703,7 +712,7 @@ class CardImage(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class TTS(BaseMessageComponent):
|
class TTS(BaseMessageComponent):
|
||||||
type: ComponentType = "TTS"
|
type = ComponentType.TTS
|
||||||
text: str
|
text: str
|
||||||
|
|
||||||
def __init__(self, **_):
|
def __init__(self, **_):
|
||||||
@@ -711,7 +720,7 @@ class TTS(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class Unknown(BaseMessageComponent):
|
class Unknown(BaseMessageComponent):
|
||||||
type: ComponentType = "Unknown"
|
type = ComponentType.Unknown
|
||||||
text: str
|
text: str
|
||||||
|
|
||||||
def toString(self):
|
def toString(self):
|
||||||
@@ -723,7 +732,7 @@ class File(BaseMessageComponent):
|
|||||||
文件消息段
|
文件消息段
|
||||||
"""
|
"""
|
||||||
|
|
||||||
type: ComponentType = "File"
|
type = ComponentType.File
|
||||||
name: T.Optional[str] = "" # 名字
|
name: T.Optional[str] = "" # 名字
|
||||||
file_: T.Optional[str] = "" # 本地路径
|
file_: T.Optional[str] = "" # 本地路径
|
||||||
url: T.Optional[str] = "" # url
|
url: T.Optional[str] = "" # url
|
||||||
@@ -853,7 +862,7 @@ class File(BaseMessageComponent):
|
|||||||
|
|
||||||
|
|
||||||
class WechatEmoji(BaseMessageComponent):
|
class WechatEmoji(BaseMessageComponent):
|
||||||
type: ComponentType = "WechatEmoji"
|
type = ComponentType.WechatEmoji
|
||||||
md5: T.Optional[str] = ""
|
md5: T.Optional[str] = ""
|
||||||
md5_len: T.Optional[int] = 0
|
md5_len: T.Optional[int] = 0
|
||||||
cdnurl: T.Optional[str] = ""
|
cdnurl: T.Optional[str] = ""
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ class ContentSafetyCheckStage(Stage):
|
|||||||
self.strategy_selector = StrategySelector(config)
|
self.strategy_selector = StrategySelector(config)
|
||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
self, event: AstrMessageEvent, check_text: str = None
|
self, event: AstrMessageEvent, check_text: str | None = None
|
||||||
) -> Union[None, AsyncGenerator[None, None]]:
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
"""检查内容安全"""
|
"""检查内容安全"""
|
||||||
text = check_text if check_text else event.get_message_str()
|
text = check_text if check_text else event.get_message_str()
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ class BaiduAipStrategy(ContentSafetyStrategy):
|
|||||||
self.secret_key = sk
|
self.secret_key = sk
|
||||||
self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key)
|
self.client = AipContentCensor(self.app_id, self.api_key, self.secret_key)
|
||||||
|
|
||||||
def check(self, content: str):
|
def check(self, content: str) -> tuple[bool, str]:
|
||||||
res = self.client.textCensorUserDefined(content)
|
res = self.client.textCensorUserDefined(content)
|
||||||
if "conclusionType" not in res:
|
if "conclusionType" not in res:
|
||||||
return False, ""
|
return False, ""
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ class KeywordsStrategy(ContentSafetyStrategy):
|
|||||||
# json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"]
|
# json.loads(base64.b64decode(f.read()).decode("utf-8"))["keywords"]
|
||||||
# )
|
# )
|
||||||
|
|
||||||
def check(self, content: str) -> bool:
|
def check(self, content: str) -> tuple[bool, str]:
|
||||||
for keyword in self.keywords:
|
for keyword in self.keywords:
|
||||||
if re.search(keyword, content):
|
if re.search(keyword, content):
|
||||||
return False, "内容安全检查不通过,匹配到敏感词。"
|
return False, "内容安全检查不通过,匹配到敏感词。"
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
|||||||
|
|
||||||
async def call_handler(
|
async def call_handler(
|
||||||
event: AstrMessageEvent,
|
event: AstrMessageEvent,
|
||||||
handler: T.Awaitable,
|
handler: T.Callable[..., T.Awaitable[T.Any]],
|
||||||
*args,
|
*args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> T.AsyncGenerator[T.Any, None]:
|
) -> T.AsyncGenerator[T.Any, None]:
|
||||||
@@ -36,6 +36,9 @@ async def call_handler(
|
|||||||
except TypeError:
|
except TypeError:
|
||||||
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
|
logger.error("处理函数参数不匹配,请检查 handler 的定义。", exc_info=True)
|
||||||
|
|
||||||
|
if not ready_to_call:
|
||||||
|
return
|
||||||
|
|
||||||
if inspect.isasyncgen(ready_to_call):
|
if inspect.isasyncgen(ready_to_call):
|
||||||
_has_yielded = False
|
_has_yielded = False
|
||||||
try:
|
try:
|
||||||
@@ -77,7 +80,7 @@ async def call_event_hook(
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
bool: 如果事件被终止,返回 True
|
bool: 如果事件被终止,返回 True
|
||||||
# """
|
#"""
|
||||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||||
hook_type, plugins_name=event.plugins_name
|
hook_type, plugins_name=event.plugins_name
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import copy
|
|||||||
import json
|
import json
|
||||||
import traceback
|
import traceback
|
||||||
from typing import AsyncGenerator, Union
|
from typing import AsyncGenerator, Union
|
||||||
|
from astrbot.core.conversation_mgr import Conversation
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.message.components import Image
|
from astrbot.core.message.components import Image
|
||||||
from astrbot.core.message.message_event_result import (
|
from astrbot.core.message.message_event_result import (
|
||||||
@@ -133,6 +134,15 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
|||||||
|
|
||||||
if agent_runner.done():
|
if agent_runner.done():
|
||||||
llm_response = agent_runner.get_final_llm_resp()
|
llm_response = agent_runner.get_final_llm_resp()
|
||||||
|
|
||||||
|
if not llm_response:
|
||||||
|
text_content = mcp.types.TextContent(
|
||||||
|
type="text",
|
||||||
|
text=f"error when deligate task to {tool.agent.name}",
|
||||||
|
)
|
||||||
|
yield mcp.types.CallToolResult(content=[text_content])
|
||||||
|
return
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}"
|
f"Agent {tool.agent.name} 任务完成, response: {llm_response.completion_text}"
|
||||||
)
|
)
|
||||||
@@ -148,7 +158,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
|||||||
)
|
)
|
||||||
yield mcp.types.CallToolResult(content=[text_content])
|
yield mcp.types.CallToolResult(content=[text_content])
|
||||||
else:
|
else:
|
||||||
yield mcp.types.TextContent(
|
text_content = mcp.types.TextContent(
|
||||||
type="text",
|
type="text",
|
||||||
text=f"error when deligate task to {tool.agent.name}",
|
text=f"error when deligate task to {tool.agent.name}",
|
||||||
)
|
)
|
||||||
@@ -200,7 +210,11 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
|
|||||||
):
|
):
|
||||||
if not tool.mcp_client:
|
if not tool.mcp_client:
|
||||||
raise ValueError("MCP client is not available for MCP function tools.")
|
raise ValueError("MCP client is not available for MCP function tools.")
|
||||||
res = await tool.mcp_client.session.call_tool(
|
|
||||||
|
session = tool.mcp_client.session
|
||||||
|
if not session:
|
||||||
|
raise ValueError("MCP session is not available for MCP function tools.")
|
||||||
|
res = await session.call_tool(
|
||||||
name=tool.name,
|
name=tool.name,
|
||||||
arguments=tool_args,
|
arguments=tool_args,
|
||||||
)
|
)
|
||||||
@@ -271,11 +285,11 @@ async def run_agent(
|
|||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
astr_event.set_result(
|
err_msg = f"\n\nAstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
|
||||||
MessageEventResult().message(
|
if agent_runner.streaming:
|
||||||
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}\n\n请在控制台查看和分享错误详情。\n"
|
yield MessageChain().message(err_msg)
|
||||||
)
|
else:
|
||||||
)
|
astr_event.set_result(MessageEventResult().message(err_msg))
|
||||||
return
|
return
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
Metric.upload(
|
Metric.upload(
|
||||||
@@ -325,7 +339,7 @@ class LLMRequestSubStage(Stage):
|
|||||||
|
|
||||||
return _ctx.get_using_provider(umo=event.unified_msg_origin)
|
return _ctx.get_using_provider(umo=event.unified_msg_origin)
|
||||||
|
|
||||||
async def _get_session_conv(self, event: AstrMessageEvent):
|
async def _get_session_conv(self, event: AstrMessageEvent) -> Conversation:
|
||||||
umo = event.unified_msg_origin
|
umo = event.unified_msg_origin
|
||||||
conv_mgr = self.conv_manager
|
conv_mgr = self.conv_manager
|
||||||
|
|
||||||
@@ -337,6 +351,8 @@ class LLMRequestSubStage(Stage):
|
|||||||
if not conversation:
|
if not conversation:
|
||||||
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
|
cid = await conv_mgr.new_conversation(umo, event.get_platform_id())
|
||||||
conversation = await conv_mgr.get_conversation(umo, cid)
|
conversation = await conv_mgr.get_conversation(umo, cid)
|
||||||
|
if not conversation:
|
||||||
|
raise RuntimeError("无法创建新的对话。")
|
||||||
return conversation
|
return conversation
|
||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
@@ -444,7 +460,10 @@ class LLMRequestSubStage(Stage):
|
|||||||
if event.plugins_name is not None and req.func_tool:
|
if event.plugins_name is not None and req.func_tool:
|
||||||
new_tool_set = ToolSet()
|
new_tool_set = ToolSet()
|
||||||
for tool in req.func_tool.tools:
|
for tool in req.func_tool.tools:
|
||||||
plugin = star_map.get(tool.handler_module_path)
|
mp = tool.handler_module_path
|
||||||
|
if not mp:
|
||||||
|
continue
|
||||||
|
plugin = star_map.get(mp)
|
||||||
if not plugin:
|
if not plugin:
|
||||||
continue
|
continue
|
||||||
if plugin.name in event.plugins_name or plugin.reserved:
|
if plugin.name in event.plugins_name or plugin.reserved:
|
||||||
|
|||||||
@@ -34,12 +34,14 @@ class StarRequestSubStage(Stage):
|
|||||||
|
|
||||||
for handler in activated_handlers:
|
for handler in activated_handlers:
|
||||||
params = handlers_parsed_params.get(handler.handler_full_name, {})
|
params = handlers_parsed_params.get(handler.handler_full_name, {})
|
||||||
try:
|
md = star_map.get(handler.handler_module_path)
|
||||||
if handler.handler_module_path not in star_map:
|
if not md:
|
||||||
continue
|
logger.warning(
|
||||||
logger.debug(
|
f"Cannot find plugin for given handler module path: {handler.handler_module_path}"
|
||||||
f"plugin -> {star_map.get(handler.handler_module_path).name} - {handler.handler_name}"
|
|
||||||
)
|
)
|
||||||
|
continue
|
||||||
|
logger.debug(f"plugin -> {md.name} - {handler.handler_name}")
|
||||||
|
try:
|
||||||
wrapper = call_handler(event, handler.handler, **params)
|
wrapper = call_handler(event, handler.handler, **params)
|
||||||
async for ret in wrapper:
|
async for ret in wrapper:
|
||||||
yield ret
|
yield ret
|
||||||
@@ -49,7 +51,7 @@ class StarRequestSubStage(Stage):
|
|||||||
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
|
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
|
||||||
|
|
||||||
if event.is_at_or_wake_command:
|
if event.is_at_or_wake_command:
|
||||||
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
|
ret = f":(\n\n在调用插件 {md.name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
|
||||||
event.set_result(MessageEventResult().message(ret))
|
event.set_result(MessageEventResult().message(ret))
|
||||||
yield
|
yield
|
||||||
event.clear_result()
|
event.clear_result()
|
||||||
|
|||||||
@@ -1,17 +1,15 @@
|
|||||||
import random
|
import random
|
||||||
import asyncio
|
import asyncio
|
||||||
import math
|
import math
|
||||||
import traceback
|
|
||||||
import astrbot.core.message.components as Comp
|
import astrbot.core.message.components as Comp
|
||||||
from typing import Union, AsyncGenerator
|
from typing import Union, AsyncGenerator
|
||||||
from ..stage import register_stage, Stage
|
from ..stage import register_stage, Stage
|
||||||
from ..context import PipelineContext
|
from ..context import PipelineContext, call_event_hook
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
|
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.message.message_event_result import BaseMessageComponent
|
from astrbot.core.message.components import BaseMessageComponent, ComponentType
|
||||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
from astrbot.core.star.star_handler import EventType
|
||||||
from astrbot.core.star.star import star_map
|
|
||||||
from astrbot.core.utils.path_util import path_Mapping
|
from astrbot.core.utils.path_util import path_Mapping
|
||||||
from astrbot.core.utils.session_lock import session_lock_manager
|
from astrbot.core.utils.session_lock import session_lock_manager
|
||||||
|
|
||||||
@@ -114,6 +112,43 @@ class RespondStage(Stage):
|
|||||||
# 如果所有组件都为空
|
# 如果所有组件都为空
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def is_seg_reply_required(self, event: AstrMessageEvent) -> bool:
|
||||||
|
"""检查是否需要分段回复"""
|
||||||
|
if not self.enable_seg:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self.only_llm_result and not event.get_result().is_llm_result():
|
||||||
|
return False
|
||||||
|
|
||||||
|
if event.get_platform_name() in [
|
||||||
|
"qq_official",
|
||||||
|
"weixin_official_account",
|
||||||
|
"dingtalk",
|
||||||
|
]:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _extract_comp(
|
||||||
|
self,
|
||||||
|
raw_chain: list[BaseMessageComponent],
|
||||||
|
extract_types: set[ComponentType],
|
||||||
|
modify_raw_chain: bool = True,
|
||||||
|
):
|
||||||
|
extracted = []
|
||||||
|
if modify_raw_chain:
|
||||||
|
remaining = []
|
||||||
|
for comp in raw_chain:
|
||||||
|
if comp.type in extract_types:
|
||||||
|
extracted.append(comp)
|
||||||
|
else:
|
||||||
|
remaining.append(comp)
|
||||||
|
raw_chain[:] = remaining
|
||||||
|
else:
|
||||||
|
extracted = [comp for comp in raw_chain if comp.type in extract_types]
|
||||||
|
|
||||||
|
return extracted
|
||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
self, event: AstrMessageEvent
|
self, event: AstrMessageEvent
|
||||||
) -> Union[None, AsyncGenerator[None, None]]:
|
) -> Union[None, AsyncGenerator[None, None]]:
|
||||||
@@ -123,7 +158,14 @@ class RespondStage(Stage):
|
|||||||
if result.result_content_type == ResultContentType.STREAMING_FINISH:
|
if result.result_content_type == ResultContentType.STREAMING_FINISH:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"Prepare to send - {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
|
||||||
|
)
|
||||||
|
|
||||||
if result.result_content_type == ResultContentType.STREAMING_RESULT:
|
if result.result_content_type == ResultContentType.STREAMING_RESULT:
|
||||||
|
if result.async_stream is None:
|
||||||
|
logger.warning("async_stream 为空,跳过发送。")
|
||||||
|
return
|
||||||
# 流式结果直接交付平台适配器处理
|
# 流式结果直接交付平台适配器处理
|
||||||
use_fallback = self.config.get("provider_settings", {}).get(
|
use_fallback = self.config.get("provider_settings", {}).get(
|
||||||
"streaming_segmented", False
|
"streaming_segmented", False
|
||||||
@@ -148,87 +190,71 @@ class RespondStage(Stage):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"空内容检查异常: {e}")
|
logger.warning(f"空内容检查异常: {e}")
|
||||||
|
|
||||||
record_comps = [c for c in result.chain if isinstance(c, Comp.Record)]
|
# 发送消息链
|
||||||
non_record_comps = [
|
# Record 需要强制单独发送
|
||||||
c for c in result.chain if not isinstance(c, Comp.Record)
|
need_separately = {ComponentType.Record}
|
||||||
]
|
if self.is_seg_reply_required(event):
|
||||||
|
header_comps = self._extract_comp(
|
||||||
if (
|
result.chain,
|
||||||
self.enable_seg
|
{ComponentType.Reply, ComponentType.At},
|
||||||
and (
|
modify_raw_chain=True,
|
||||||
(self.only_llm_result and result.is_llm_result())
|
|
||||||
or not self.only_llm_result
|
|
||||||
)
|
)
|
||||||
and event.get_platform_name()
|
if not result.chain or len(result.chain) == 0:
|
||||||
not in ["qq_official", "weixin_official_account", "dingtalk"]
|
# may fix #2670
|
||||||
):
|
logger.warning(
|
||||||
decorated_comps = []
|
f"实际消息链为空, 跳过发送阶段。header_chain: {header_comps}, actual_chain: {result.chain}"
|
||||||
if self.reply_with_mention:
|
)
|
||||||
for comp in result.chain:
|
return
|
||||||
if isinstance(comp, Comp.At):
|
|
||||||
decorated_comps.append(comp)
|
|
||||||
result.chain.remove(comp)
|
|
||||||
break
|
|
||||||
if self.reply_with_quote:
|
|
||||||
for comp in result.chain:
|
|
||||||
if isinstance(comp, Comp.Reply):
|
|
||||||
decorated_comps.append(comp)
|
|
||||||
result.chain.remove(comp)
|
|
||||||
break
|
|
||||||
|
|
||||||
# leverage lock to guarentee the order of message sending among different events
|
|
||||||
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
async with session_lock_manager.acquire_lock(event.unified_msg_origin):
|
||||||
for rcomp in record_comps:
|
for comp in result.chain:
|
||||||
i = await self._calc_comp_interval(rcomp)
|
|
||||||
await asyncio.sleep(i)
|
|
||||||
try:
|
|
||||||
await event.send(MessageChain([rcomp]))
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
|
||||||
break
|
|
||||||
# 分段回复
|
|
||||||
for comp in non_record_comps:
|
|
||||||
i = await self._calc_comp_interval(comp)
|
i = await self._calc_comp_interval(comp)
|
||||||
await asyncio.sleep(i)
|
await asyncio.sleep(i)
|
||||||
try:
|
try:
|
||||||
await event.send(MessageChain([*decorated_comps, comp]))
|
if comp.type in need_separately:
|
||||||
decorated_comps = [] # 清空已发送的装饰组件
|
await event.send(MessageChain([comp]))
|
||||||
|
else:
|
||||||
|
await event.send(MessageChain([*header_comps, comp]))
|
||||||
|
header_comps.clear()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
logger.error(
|
||||||
break
|
f"发送消息链失败: chain = {MessageChain([comp])}, error = {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
for rcomp in record_comps:
|
if all(
|
||||||
|
comp.type in {ComponentType.Reply, ComponentType.At}
|
||||||
|
for comp in result.chain
|
||||||
|
):
|
||||||
|
# may fix #2670
|
||||||
|
logger.warning(
|
||||||
|
f"消息链全为 Reply 和 At 消息段, 跳过发送阶段。chain: {result.chain}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
sep_comps = self._extract_comp(
|
||||||
|
result.chain,
|
||||||
|
need_separately,
|
||||||
|
modify_raw_chain=True,
|
||||||
|
)
|
||||||
|
for comp in sep_comps:
|
||||||
|
chain = MessageChain([comp])
|
||||||
try:
|
try:
|
||||||
await event.send(MessageChain([rcomp]))
|
await event.send(chain)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
logger.error(
|
||||||
|
f"发送消息链失败: chain = {chain}, error = {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
chain = MessageChain(result.chain)
|
||||||
|
if result.chain and len(result.chain) > 0:
|
||||||
|
try:
|
||||||
|
await event.send(chain)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"发送消息链失败: chain = {chain}, error = {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
if await call_event_hook(event, EventType.OnAfterMessageSentEvent):
|
||||||
await event.send(MessageChain(non_record_comps))
|
return
|
||||||
except Exception as e:
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
|
||||||
EventType.OnAfterMessageSentEvent, plugins_name=event.plugins_name
|
|
||||||
)
|
|
||||||
for handler in handlers:
|
|
||||||
try:
|
|
||||||
logger.debug(
|
|
||||||
f"hook(on_after_message_sent) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
|
||||||
)
|
|
||||||
await handler.handler(event)
|
|
||||||
except BaseException:
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
|
|
||||||
if event.is_stopped():
|
|
||||||
logger.info(
|
|
||||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
event.clear_result()
|
event.clear_result()
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ class ResultDecorateStage(Stage):
|
|||||||
self.t2i_word_threshold = 150
|
self.t2i_word_threshold = 150
|
||||||
self.t2i_strategy = ctx.astrbot_config["t2i_strategy"]
|
self.t2i_strategy = ctx.astrbot_config["t2i_strategy"]
|
||||||
self.t2i_use_network = self.t2i_strategy == "remote"
|
self.t2i_use_network = self.t2i_strategy == "remote"
|
||||||
|
self.t2i_active_template = ctx.astrbot_config["t2i_active_template"]
|
||||||
|
|
||||||
self.forward_threshold = ctx.astrbot_config["platform_settings"][
|
self.forward_threshold = ctx.astrbot_config["platform_settings"][
|
||||||
"forward_threshold"
|
"forward_threshold"
|
||||||
@@ -247,7 +248,10 @@ class ResultDecorateStage(Stage):
|
|||||||
render_start = time.time()
|
render_start = time.time()
|
||||||
try:
|
try:
|
||||||
url = await html_renderer.render_t2i(
|
url = await html_renderer.render_t2i(
|
||||||
plain_str, return_url=True, use_network=self.t2i_use_network
|
plain_str,
|
||||||
|
return_url=True,
|
||||||
|
use_network=self.t2i_use_network,
|
||||||
|
template_name=self.t2i_active_template,
|
||||||
)
|
)
|
||||||
except BaseException:
|
except BaseException:
|
||||||
logger.error("文本转图片失败,使用文本发送。")
|
logger.error("文本转图片失败,使用文本发送。")
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ from astrbot.core.provider.entities import ProviderRequest
|
|||||||
from astrbot.core.utils.metrics import Metric
|
from astrbot.core.utils.metrics import Metric
|
||||||
from .astrbot_message import AstrBotMessage, Group
|
from .astrbot_message import AstrBotMessage, Group
|
||||||
from .platform_metadata import PlatformMetadata
|
from .platform_metadata import PlatformMetadata
|
||||||
from .message_session import MessageSession, MessageSesion # noqa
|
from .message_session import MessageSession, MessageSesion # noqa
|
||||||
|
|
||||||
|
|
||||||
class AstrMessageEvent(abc.ABC):
|
class AstrMessageEvent(abc.ABC):
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ class AstrBotMessage:
|
|||||||
self_id: str # 机器人的识别id
|
self_id: str # 机器人的识别id
|
||||||
session_id: str # 会话id。取决于 unique_session 的设置。
|
session_id: str # 会话id。取决于 unique_session 的设置。
|
||||||
message_id: str # 消息id
|
message_id: str # 消息id
|
||||||
group_id: str = "" # 群组id,如果为私聊,则为空
|
group: Group # 群组
|
||||||
sender: MessageMember # 发送者
|
sender: MessageMember # 发送者
|
||||||
message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
|
message: List[BaseMessageComponent] # 消息链使用 Nakuru 的消息链格式
|
||||||
message_str: str # 最直观的纯文本消息字符串
|
message_str: str # 最直观的纯文本消息字符串
|
||||||
@@ -64,6 +64,28 @@ class AstrBotMessage:
|
|||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.timestamp = int(time.time())
|
self.timestamp = int(time.time())
|
||||||
|
self.group = None
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
return str(self.__dict__)
|
return str(self.__dict__)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def group_id(self) -> str:
|
||||||
|
"""
|
||||||
|
向后兼容的 group_id 属性
|
||||||
|
群组id,如果为私聊,则为空
|
||||||
|
"""
|
||||||
|
if self.group:
|
||||||
|
return self.group.group_id
|
||||||
|
return ""
|
||||||
|
|
||||||
|
@group_id.setter
|
||||||
|
def group_id(self, value: str):
|
||||||
|
"""设置 group_id"""
|
||||||
|
if value:
|
||||||
|
if self.group:
|
||||||
|
self.group.group_id = value
|
||||||
|
else:
|
||||||
|
self.group = Group(group_id=value)
|
||||||
|
else:
|
||||||
|
self.group = None
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from typing import List
|
|||||||
from asyncio import Queue
|
from asyncio import Queue
|
||||||
from .register import platform_cls_map
|
from .register import platform_cls_map
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
|
from astrbot.core.star.star_handler import star_handlers_registry, star_map, EventType
|
||||||
from .sources.webchat.webchat_adapter import WebChatAdapter
|
from .sources.webchat.webchat_adapter import WebChatAdapter
|
||||||
|
|
||||||
|
|
||||||
@@ -66,25 +67,39 @@ class PlatformManager:
|
|||||||
WeChatPadProAdapter, # noqa: F401
|
WeChatPadProAdapter, # noqa: F401
|
||||||
)
|
)
|
||||||
case "lark":
|
case "lark":
|
||||||
from .sources.lark.lark_adapter import LarkPlatformAdapter # noqa: F401
|
from .sources.lark.lark_adapter import (
|
||||||
|
LarkPlatformAdapter, # noqa: F401
|
||||||
|
)
|
||||||
case "dingtalk":
|
case "dingtalk":
|
||||||
from .sources.dingtalk.dingtalk_adapter import (
|
from .sources.dingtalk.dingtalk_adapter import (
|
||||||
DingtalkPlatformAdapter, # noqa: F401
|
DingtalkPlatformAdapter, # noqa: F401
|
||||||
)
|
)
|
||||||
case "telegram":
|
case "telegram":
|
||||||
from .sources.telegram.tg_adapter import TelegramPlatformAdapter # noqa: F401
|
from .sources.telegram.tg_adapter import (
|
||||||
|
TelegramPlatformAdapter, # noqa: F401
|
||||||
|
)
|
||||||
case "wecom":
|
case "wecom":
|
||||||
from .sources.wecom.wecom_adapter import WecomPlatformAdapter # noqa: F401
|
from .sources.wecom.wecom_adapter import (
|
||||||
|
WecomPlatformAdapter, # noqa: F401
|
||||||
|
)
|
||||||
case "weixin_official_account":
|
case "weixin_official_account":
|
||||||
from .sources.weixin_official_account.weixin_offacc_adapter import (
|
from .sources.weixin_official_account.weixin_offacc_adapter import (
|
||||||
WeixinOfficialAccountPlatformAdapter, # noqa
|
WeixinOfficialAccountPlatformAdapter, # noqa: F401
|
||||||
)
|
)
|
||||||
case "discord":
|
case "discord":
|
||||||
from .sources.discord.discord_platform_adapter import (
|
from .sources.discord.discord_platform_adapter import (
|
||||||
DiscordPlatformAdapter, # noqa: F401
|
DiscordPlatformAdapter, # noqa: F401
|
||||||
)
|
)
|
||||||
|
case "misskey":
|
||||||
|
from .sources.misskey.misskey_adapter import (
|
||||||
|
MisskeyPlatformAdapter, # noqa: F401
|
||||||
|
)
|
||||||
case "slack":
|
case "slack":
|
||||||
from .sources.slack.slack_adapter import SlackAdapter # noqa: F401
|
from .sources.slack.slack_adapter import SlackAdapter # noqa: F401
|
||||||
|
case "satori":
|
||||||
|
from .sources.satori.satori_adapter import (
|
||||||
|
SatoriPlatformAdapter, # noqa: F401
|
||||||
|
)
|
||||||
except (ImportError, ModuleNotFoundError) as e:
|
except (ImportError, ModuleNotFoundError) as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。"
|
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->控制台->安装Pip库 中安装依赖库。"
|
||||||
@@ -113,6 +128,17 @@ class PlatformManager:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||||
|
EventType.OnPlatformLoadedEvent
|
||||||
|
)
|
||||||
|
for handler in handlers:
|
||||||
|
try:
|
||||||
|
logger.info(
|
||||||
|
f"hook(on_platform_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||||
|
)
|
||||||
|
await handler.handler()
|
||||||
|
except Exception:
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
async def _task_wrapper(self, task: asyncio.Task):
|
async def _task_wrapper(self, task: asyncio.Task):
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -182,11 +182,13 @@ class AiocqhttpAdapter(Platform):
|
|||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
abm.self_id = str(event.self_id)
|
abm.self_id = str(event.self_id)
|
||||||
abm.sender = MessageMember(
|
abm.sender = MessageMember(
|
||||||
str(event.sender["user_id"]), event.sender["nickname"]
|
str(event.sender["user_id"]),
|
||||||
|
event.sender.get("card") or event.sender.get("nickname", "N/A"),
|
||||||
)
|
)
|
||||||
if event["message_type"] == "group":
|
if event["message_type"] == "group":
|
||||||
abm.type = MessageType.GROUP_MESSAGE
|
abm.type = MessageType.GROUP_MESSAGE
|
||||||
abm.group_id = str(event.group_id)
|
abm.group_id = str(event.group_id)
|
||||||
|
abm.group.group_name = event.get("group_name", "N/A")
|
||||||
elif event["message_type"] == "private":
|
elif event["message_type"] == "private":
|
||||||
abm.type = MessageType.FRIEND_MESSAGE
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||||
@@ -308,13 +310,22 @@ class AiocqhttpAdapter(Platform):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
at_info = await self.bot.call_action(
|
at_info = await self.bot.call_action(
|
||||||
action="get_stranger_info",
|
action="get_group_member_info",
|
||||||
|
group_id=event.group_id,
|
||||||
user_id=int(m["data"]["qq"]),
|
user_id=int(m["data"]["qq"]),
|
||||||
|
no_cache=False,
|
||||||
)
|
)
|
||||||
if at_info:
|
if at_info:
|
||||||
nickname = at_info.get("nick", "") or at_info.get(
|
nickname = at_info.get("card", "")
|
||||||
"nickname", ""
|
if nickname == "":
|
||||||
)
|
at_info = await self.bot.call_action(
|
||||||
|
action="get_stranger_info",
|
||||||
|
user_id=int(m["data"]["qq"]),
|
||||||
|
no_cache=False,
|
||||||
|
)
|
||||||
|
nickname = at_info.get("nick", "") or at_info.get(
|
||||||
|
"nickname", ""
|
||||||
|
)
|
||||||
is_at_self = str(m["data"]["qq"]) in {abm.self_id, "all"}
|
is_at_self = str(m["data"]["qq"]) in {abm.self_id, "all"}
|
||||||
|
|
||||||
abm.message.append(
|
abm.message.append(
|
||||||
|
|||||||
@@ -54,9 +54,9 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
|||||||
logger.debug(f"send image: {ret}")
|
logger.debug(f"send image: {ret}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"钉钉图片处理失败: {e}")
|
logger.warning(f"钉钉图片处理失败: {e}, 跳过图片发送")
|
||||||
logger.warning(f"跳过图片发送: {image_path}")
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
async def send(self, message: MessageChain):
|
async def send(self, message: MessageChain):
|
||||||
await self.send_with_client(self.client, message)
|
await self.send_with_client(self.client, message)
|
||||||
await super().send(message)
|
await super().send(message)
|
||||||
|
|||||||
@@ -41,7 +41,8 @@ class DiscordBotClient(discord.Bot):
|
|||||||
await self.on_ready_once_callback()
|
await self.on_ready_once_callback()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True)
|
f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True
|
||||||
|
)
|
||||||
|
|
||||||
def _create_message_data(self, message: discord.Message) -> dict:
|
def _create_message_data(self, message: discord.Message) -> dict:
|
||||||
"""从 discord.Message 创建数据字典"""
|
"""从 discord.Message 创建数据字典"""
|
||||||
@@ -90,7 +91,6 @@ class DiscordBotClient(discord.Bot):
|
|||||||
message_data = self._create_message_data(message)
|
message_data = self._create_message_data(message)
|
||||||
await self.on_message_received(message_data)
|
await self.on_message_received(message_data)
|
||||||
|
|
||||||
|
|
||||||
def _extract_interaction_content(self, interaction: discord.Interaction) -> str:
|
def _extract_interaction_content(self, interaction: discord.Interaction) -> str:
|
||||||
"""从交互中提取内容"""
|
"""从交互中提取内容"""
|
||||||
interaction_type = interaction.type
|
interaction_type = interaction.type
|
||||||
|
|||||||
@@ -79,9 +79,12 @@ class DiscordButton(BaseMessageComponent):
|
|||||||
self.url = url
|
self.url = url
|
||||||
self.disabled = disabled
|
self.disabled = disabled
|
||||||
|
|
||||||
|
|
||||||
class DiscordReference(BaseMessageComponent):
|
class DiscordReference(BaseMessageComponent):
|
||||||
"""Discord引用组件"""
|
"""Discord引用组件"""
|
||||||
|
|
||||||
type: str = "discord_reference"
|
type: str = "discord_reference"
|
||||||
|
|
||||||
def __init__(self, message_id: str, channel_id: str):
|
def __init__(self, message_id: str, channel_id: str):
|
||||||
self.message_id = message_id
|
self.message_id = message_id
|
||||||
self.channel_id = channel_id
|
self.channel_id = channel_id
|
||||||
@@ -98,7 +101,6 @@ class DiscordView(BaseMessageComponent):
|
|||||||
self.components = components or []
|
self.components = components or []
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
|
|
||||||
|
|
||||||
def to_discord_view(self) -> discord.ui.View:
|
def to_discord_view(self) -> discord.ui.View:
|
||||||
"""转换为Discord View对象"""
|
"""转换为Discord View对象"""
|
||||||
view = discord.ui.View(timeout=self.timeout)
|
view = discord.ui.View(timeout=self.timeout)
|
||||||
|
|||||||
@@ -53,7 +53,13 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
|||||||
|
|
||||||
# 解析消息链为 Discord 所需的对象
|
# 解析消息链为 Discord 所需的对象
|
||||||
try:
|
try:
|
||||||
content, files, view, embeds, reference_message_id = await self._parse_to_discord(message)
|
(
|
||||||
|
content,
|
||||||
|
files,
|
||||||
|
view,
|
||||||
|
embeds,
|
||||||
|
reference_message_id,
|
||||||
|
) = await self._parse_to_discord(message)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"[Discord] 解析消息链时失败: {e}", exc_info=True)
|
logger.error(f"[Discord] 解析消息链时失败: {e}", exc_info=True)
|
||||||
return
|
return
|
||||||
@@ -206,8 +212,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
|
|||||||
if await asyncio.to_thread(path.exists):
|
if await asyncio.to_thread(path.exists):
|
||||||
file_bytes = await asyncio.to_thread(path.read_bytes)
|
file_bytes = await asyncio.to_thread(path.read_bytes)
|
||||||
files.append(
|
files.append(
|
||||||
discord.File(BytesIO(file_bytes),
|
discord.File(BytesIO(file_bytes), filename=i.name)
|
||||||
filename=i.name)
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
@@ -0,0 +1,391 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
from typing import Dict, Any, Optional, Awaitable
|
||||||
|
|
||||||
|
from astrbot.api import logger
|
||||||
|
from astrbot.api.event import MessageChain
|
||||||
|
from astrbot.api.platform import (
|
||||||
|
AstrBotMessage,
|
||||||
|
Platform,
|
||||||
|
PlatformMetadata,
|
||||||
|
register_platform_adapter,
|
||||||
|
)
|
||||||
|
from astrbot.core.platform.astr_message_event import MessageSession
|
||||||
|
import astrbot.api.message_components as Comp
|
||||||
|
|
||||||
|
from .misskey_api import MisskeyAPI
|
||||||
|
from .misskey_event import MisskeyPlatformEvent
|
||||||
|
from .misskey_utils import (
|
||||||
|
serialize_message_chain,
|
||||||
|
resolve_message_visibility,
|
||||||
|
is_valid_user_session_id,
|
||||||
|
is_valid_room_session_id,
|
||||||
|
add_at_mention_if_needed,
|
||||||
|
process_files,
|
||||||
|
extract_sender_info,
|
||||||
|
create_base_message,
|
||||||
|
process_at_mention,
|
||||||
|
cache_user_info,
|
||||||
|
cache_room_info,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@register_platform_adapter("misskey", "Misskey 平台适配器")
|
||||||
|
class MisskeyPlatformAdapter(Platform):
|
||||||
|
def __init__(
|
||||||
|
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||||
|
) -> None:
|
||||||
|
super().__init__(event_queue)
|
||||||
|
self.config = platform_config or {}
|
||||||
|
self.settings = platform_settings or {}
|
||||||
|
self.instance_url = self.config.get("misskey_instance_url", "")
|
||||||
|
self.access_token = self.config.get("misskey_token", "")
|
||||||
|
self.max_message_length = self.config.get("max_message_length", 3000)
|
||||||
|
self.default_visibility = self.config.get(
|
||||||
|
"misskey_default_visibility", "public"
|
||||||
|
)
|
||||||
|
self.local_only = self.config.get("misskey_local_only", False)
|
||||||
|
self.enable_chat = self.config.get("misskey_enable_chat", True)
|
||||||
|
|
||||||
|
self.unique_session = platform_settings["unique_session"]
|
||||||
|
|
||||||
|
self.api: Optional[MisskeyAPI] = None
|
||||||
|
self._running = False
|
||||||
|
self.client_self_id = ""
|
||||||
|
self._bot_username = ""
|
||||||
|
self._user_cache = {}
|
||||||
|
|
||||||
|
def meta(self) -> PlatformMetadata:
|
||||||
|
default_config = {
|
||||||
|
"misskey_instance_url": "",
|
||||||
|
"misskey_token": "",
|
||||||
|
"max_message_length": 3000,
|
||||||
|
"misskey_default_visibility": "public",
|
||||||
|
"misskey_local_only": False,
|
||||||
|
"misskey_enable_chat": True,
|
||||||
|
}
|
||||||
|
default_config.update(self.config)
|
||||||
|
|
||||||
|
return PlatformMetadata(
|
||||||
|
name="misskey",
|
||||||
|
description="Misskey 平台适配器",
|
||||||
|
id=self.config.get("id", "misskey"),
|
||||||
|
default_config_tmpl=default_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
if not self.instance_url or not self.access_token:
|
||||||
|
logger.error("[Misskey] 配置不完整,无法启动")
|
||||||
|
return
|
||||||
|
|
||||||
|
self.api = MisskeyAPI(self.instance_url, self.access_token)
|
||||||
|
self._running = True
|
||||||
|
|
||||||
|
try:
|
||||||
|
user_info = await self.api.get_current_user()
|
||||||
|
self.client_self_id = str(user_info.get("id", ""))
|
||||||
|
self._bot_username = user_info.get("username", "")
|
||||||
|
logger.info(
|
||||||
|
f"[Misskey] 已连接用户: {self._bot_username} (ID: {self.client_self_id})"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Misskey] 获取用户信息失败: {e}")
|
||||||
|
self._running = False
|
||||||
|
return
|
||||||
|
|
||||||
|
await self._start_websocket_connection()
|
||||||
|
|
||||||
|
async def _start_websocket_connection(self):
|
||||||
|
backoff_delay = 1.0
|
||||||
|
max_backoff = 300.0
|
||||||
|
backoff_multiplier = 1.5
|
||||||
|
connection_attempts = 0
|
||||||
|
|
||||||
|
while self._running:
|
||||||
|
try:
|
||||||
|
connection_attempts += 1
|
||||||
|
if not self.api:
|
||||||
|
logger.error("[Misskey] API 客户端未初始化")
|
||||||
|
break
|
||||||
|
|
||||||
|
streaming = self.api.get_streaming_client()
|
||||||
|
streaming.add_message_handler("notification", self._handle_notification)
|
||||||
|
if self.enable_chat:
|
||||||
|
streaming.add_message_handler(
|
||||||
|
"newChatMessage", self._handle_chat_message
|
||||||
|
)
|
||||||
|
streaming.add_message_handler("_debug", self._debug_handler)
|
||||||
|
|
||||||
|
if await streaming.connect():
|
||||||
|
logger.info(
|
||||||
|
f"[Misskey] WebSocket 已连接 (尝试 #{connection_attempts})"
|
||||||
|
)
|
||||||
|
connection_attempts = 0 # 重置计数器
|
||||||
|
await streaming.subscribe_channel("main")
|
||||||
|
if self.enable_chat:
|
||||||
|
await streaming.subscribe_channel("messaging")
|
||||||
|
await streaming.subscribe_channel("messagingIndex")
|
||||||
|
logger.info("[Misskey] 聊天频道已订阅")
|
||||||
|
|
||||||
|
backoff_delay = 1.0 # 重置延迟
|
||||||
|
await streaming.listen()
|
||||||
|
else:
|
||||||
|
logger.error(
|
||||||
|
f"[Misskey] WebSocket 连接失败 (尝试 #{connection_attempts})"
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
f"[Misskey] WebSocket 异常 (尝试 #{connection_attempts}): {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._running:
|
||||||
|
logger.info(
|
||||||
|
f"[Misskey] {backoff_delay:.1f}秒后重连 (下次尝试 #{connection_attempts + 1})"
|
||||||
|
)
|
||||||
|
await asyncio.sleep(backoff_delay)
|
||||||
|
backoff_delay = min(backoff_delay * backoff_multiplier, max_backoff)
|
||||||
|
|
||||||
|
async def _handle_notification(self, data: Dict[str, Any]):
|
||||||
|
try:
|
||||||
|
logger.debug(
|
||||||
|
f"[Misskey] 收到通知事件:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
|
||||||
|
)
|
||||||
|
notification_type = data.get("type")
|
||||||
|
if notification_type in ["mention", "reply", "quote"]:
|
||||||
|
note = data.get("note")
|
||||||
|
if note and self._is_bot_mentioned(note):
|
||||||
|
logger.info(
|
||||||
|
f"[Misskey] 处理贴文提及: {note.get('text', '')[:50]}..."
|
||||||
|
)
|
||||||
|
message = await self.convert_message(note)
|
||||||
|
event = MisskeyPlatformEvent(
|
||||||
|
message_str=message.message_str,
|
||||||
|
message_obj=message,
|
||||||
|
platform_meta=self.meta(),
|
||||||
|
session_id=message.session_id,
|
||||||
|
client=self.api,
|
||||||
|
)
|
||||||
|
self.commit_event(event)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Misskey] 处理通知失败: {e}")
|
||||||
|
|
||||||
|
async def _handle_chat_message(self, data: Dict[str, Any]):
|
||||||
|
try:
|
||||||
|
logger.debug(
|
||||||
|
f"[Misskey] 收到聊天事件数据:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
sender_id = str(
|
||||||
|
data.get("fromUserId", "") or data.get("fromUser", {}).get("id", "")
|
||||||
|
)
|
||||||
|
if sender_id == self.client_self_id:
|
||||||
|
return
|
||||||
|
|
||||||
|
room_id = data.get("toRoomId")
|
||||||
|
if room_id:
|
||||||
|
raw_text = data.get("text", "")
|
||||||
|
logger.debug(
|
||||||
|
f"[Misskey] 检查群聊消息: '{raw_text}', 机器人用户名: '{self._bot_username}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
message = await self.convert_room_message(data)
|
||||||
|
logger.info(f"[Misskey] 处理群聊消息: {message.message_str[:50]}...")
|
||||||
|
else:
|
||||||
|
message = await self.convert_chat_message(data)
|
||||||
|
logger.info(f"[Misskey] 处理私聊消息: {message.message_str[:50]}...")
|
||||||
|
|
||||||
|
event = MisskeyPlatformEvent(
|
||||||
|
message_str=message.message_str,
|
||||||
|
message_obj=message,
|
||||||
|
platform_meta=self.meta(),
|
||||||
|
session_id=message.session_id,
|
||||||
|
client=self.api,
|
||||||
|
)
|
||||||
|
self.commit_event(event)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Misskey] 处理聊天消息失败: {e}")
|
||||||
|
|
||||||
|
async def _debug_handler(self, data: Dict[str, Any]):
|
||||||
|
logger.debug(
|
||||||
|
f"[Misskey] 收到未处理事件:\n{json.dumps(data, indent=2, ensure_ascii=False)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _is_bot_mentioned(self, note: Dict[str, Any]) -> bool:
|
||||||
|
text = note.get("text", "")
|
||||||
|
if not text:
|
||||||
|
return False
|
||||||
|
|
||||||
|
mentions = note.get("mentions", [])
|
||||||
|
if self._bot_username and f"@{self._bot_username}" in text:
|
||||||
|
return True
|
||||||
|
if self.client_self_id in [str(uid) for uid in mentions]:
|
||||||
|
return True
|
||||||
|
|
||||||
|
reply = note.get("reply")
|
||||||
|
if reply and isinstance(reply, dict):
|
||||||
|
reply_user_id = str(reply.get("user", {}).get("id", ""))
|
||||||
|
if reply_user_id == self.client_self_id:
|
||||||
|
return bool(self._bot_username and f"@{self._bot_username}" in text)
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def send_by_session(
|
||||||
|
self, session: MessageSession, message_chain: MessageChain
|
||||||
|
) -> Awaitable[Any]:
|
||||||
|
if not self.api:
|
||||||
|
logger.error("[Misskey] API 客户端未初始化")
|
||||||
|
return await super().send_by_session(session, message_chain)
|
||||||
|
|
||||||
|
try:
|
||||||
|
session_id = session.session_id
|
||||||
|
text, has_at_user = serialize_message_chain(message_chain.chain)
|
||||||
|
|
||||||
|
if not has_at_user and session_id:
|
||||||
|
user_info = self._user_cache.get(session_id)
|
||||||
|
text = add_at_mention_if_needed(text, user_info, has_at_user)
|
||||||
|
|
||||||
|
if not text or not text.strip():
|
||||||
|
logger.warning("[Misskey] 消息内容为空,跳过发送")
|
||||||
|
return await super().send_by_session(session, message_chain)
|
||||||
|
|
||||||
|
if len(text) > self.max_message_length:
|
||||||
|
text = text[: self.max_message_length] + "..."
|
||||||
|
|
||||||
|
if session_id and is_valid_user_session_id(session_id):
|
||||||
|
from .misskey_utils import extract_user_id_from_session_id
|
||||||
|
|
||||||
|
user_id = extract_user_id_from_session_id(session_id)
|
||||||
|
await self.api.send_message(user_id, text)
|
||||||
|
elif session_id and is_valid_room_session_id(session_id):
|
||||||
|
from .misskey_utils import extract_room_id_from_session_id
|
||||||
|
|
||||||
|
room_id = extract_room_id_from_session_id(session_id)
|
||||||
|
await self.api.send_room_message(room_id, text)
|
||||||
|
else:
|
||||||
|
visibility, visible_user_ids = resolve_message_visibility(
|
||||||
|
user_id=session_id,
|
||||||
|
user_cache=self._user_cache,
|
||||||
|
self_id=self.client_self_id,
|
||||||
|
default_visibility=self.default_visibility,
|
||||||
|
)
|
||||||
|
|
||||||
|
await self.api.create_note(
|
||||||
|
text,
|
||||||
|
visibility=visibility,
|
||||||
|
visible_user_ids=visible_user_ids,
|
||||||
|
local_only=self.local_only,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Misskey] 发送消息失败: {e}")
|
||||||
|
|
||||||
|
return await super().send_by_session(session, message_chain)
|
||||||
|
|
||||||
|
async def convert_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
|
||||||
|
"""将 Misskey 贴文数据转换为 AstrBotMessage 对象"""
|
||||||
|
sender_info = extract_sender_info(raw_data, is_chat=False)
|
||||||
|
message = create_base_message(
|
||||||
|
raw_data,
|
||||||
|
sender_info,
|
||||||
|
self.client_self_id,
|
||||||
|
is_chat=False,
|
||||||
|
unique_session=self.unique_session,
|
||||||
|
)
|
||||||
|
cache_user_info(
|
||||||
|
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=False
|
||||||
|
)
|
||||||
|
|
||||||
|
message_parts = []
|
||||||
|
raw_text = raw_data.get("text", "")
|
||||||
|
|
||||||
|
if raw_text:
|
||||||
|
text_parts, processed_text = process_at_mention(
|
||||||
|
message, raw_text, self._bot_username, self.client_self_id
|
||||||
|
)
|
||||||
|
message_parts.extend(text_parts)
|
||||||
|
|
||||||
|
files = raw_data.get("files", [])
|
||||||
|
file_parts = process_files(message, files)
|
||||||
|
message_parts.extend(file_parts)
|
||||||
|
|
||||||
|
message.message_str = (
|
||||||
|
" ".join(part for part in message_parts if part.strip())
|
||||||
|
if message_parts
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
return message
|
||||||
|
|
||||||
|
async def convert_chat_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
|
||||||
|
"""将 Misskey 聊天消息数据转换为 AstrBotMessage 对象"""
|
||||||
|
sender_info = extract_sender_info(raw_data, is_chat=True)
|
||||||
|
message = create_base_message(
|
||||||
|
raw_data,
|
||||||
|
sender_info,
|
||||||
|
self.client_self_id,
|
||||||
|
is_chat=True,
|
||||||
|
unique_session=self.unique_session,
|
||||||
|
)
|
||||||
|
cache_user_info(
|
||||||
|
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=True
|
||||||
|
)
|
||||||
|
|
||||||
|
raw_text = raw_data.get("text", "")
|
||||||
|
if raw_text:
|
||||||
|
message.message.append(Comp.Plain(raw_text))
|
||||||
|
|
||||||
|
files = raw_data.get("files", [])
|
||||||
|
process_files(message, files, include_text_parts=False)
|
||||||
|
|
||||||
|
message.message_str = raw_text if raw_text else ""
|
||||||
|
return message
|
||||||
|
|
||||||
|
async def convert_room_message(self, raw_data: Dict[str, Any]) -> AstrBotMessage:
|
||||||
|
"""将 Misskey 群聊消息数据转换为 AstrBotMessage 对象"""
|
||||||
|
sender_info = extract_sender_info(raw_data, is_chat=True)
|
||||||
|
room_id = raw_data.get("toRoomId", "")
|
||||||
|
message = create_base_message(
|
||||||
|
raw_data,
|
||||||
|
sender_info,
|
||||||
|
self.client_self_id,
|
||||||
|
is_chat=False,
|
||||||
|
room_id=room_id,
|
||||||
|
unique_session=self.unique_session,
|
||||||
|
)
|
||||||
|
|
||||||
|
cache_user_info(
|
||||||
|
self._user_cache, sender_info, raw_data, self.client_self_id, is_chat=False
|
||||||
|
)
|
||||||
|
cache_room_info(self._user_cache, raw_data, self.client_self_id)
|
||||||
|
|
||||||
|
raw_text = raw_data.get("text", "")
|
||||||
|
message_parts = []
|
||||||
|
|
||||||
|
if raw_text:
|
||||||
|
if self._bot_username and f"@{self._bot_username}" in raw_text:
|
||||||
|
text_parts, processed_text = process_at_mention(
|
||||||
|
message, raw_text, self._bot_username, self.client_self_id
|
||||||
|
)
|
||||||
|
message_parts.extend(text_parts)
|
||||||
|
else:
|
||||||
|
message.message.append(Comp.Plain(raw_text))
|
||||||
|
message_parts.append(raw_text)
|
||||||
|
|
||||||
|
files = raw_data.get("files", [])
|
||||||
|
file_parts = process_files(message, files)
|
||||||
|
message_parts.extend(file_parts)
|
||||||
|
|
||||||
|
message.message_str = (
|
||||||
|
" ".join(part for part in message_parts if part.strip())
|
||||||
|
if message_parts
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
return message
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
self._running = False
|
||||||
|
if self.api:
|
||||||
|
await self.api.close()
|
||||||
|
|
||||||
|
def get_client(self) -> Any:
|
||||||
|
return self.api
|
||||||
@@ -0,0 +1,404 @@
|
|||||||
|
import json
|
||||||
|
from typing import Any, Optional, Dict, List, Callable, Awaitable
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
try:
|
||||||
|
import aiohttp
|
||||||
|
import websockets
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"aiohttp and websockets are required for Misskey API. Please install them with: pip install aiohttp websockets"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
from astrbot.api import logger
|
||||||
|
|
||||||
|
# Constants
|
||||||
|
API_MAX_RETRIES = 3
|
||||||
|
HTTP_OK = 200
|
||||||
|
|
||||||
|
|
||||||
|
class APIError(Exception):
|
||||||
|
"""Misskey API 基础异常"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class APIConnectionError(APIError):
|
||||||
|
"""网络连接异常"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class APIRateLimitError(APIError):
|
||||||
|
"""API 频率限制异常"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationError(APIError):
|
||||||
|
"""认证失败异常"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class WebSocketError(APIError):
|
||||||
|
"""WebSocket 连接异常"""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class StreamingClient:
|
||||||
|
def __init__(self, instance_url: str, access_token: str):
|
||||||
|
self.instance_url = instance_url.rstrip("/")
|
||||||
|
self.access_token = access_token
|
||||||
|
self.websocket: Optional[Any] = None
|
||||||
|
self.is_connected = False
|
||||||
|
self.message_handlers: Dict[str, Callable] = {}
|
||||||
|
self.channels: Dict[str, str] = {}
|
||||||
|
self._running = False
|
||||||
|
self._last_pong = None
|
||||||
|
|
||||||
|
async def connect(self) -> bool:
|
||||||
|
try:
|
||||||
|
ws_url = self.instance_url.replace("https://", "wss://").replace(
|
||||||
|
"http://", "ws://"
|
||||||
|
)
|
||||||
|
ws_url += f"/streaming?i={self.access_token}"
|
||||||
|
|
||||||
|
self.websocket = await websockets.connect(
|
||||||
|
ws_url, ping_interval=30, ping_timeout=10
|
||||||
|
)
|
||||||
|
self.is_connected = True
|
||||||
|
self._running = True
|
||||||
|
|
||||||
|
logger.info("[Misskey WebSocket] 已连接")
|
||||||
|
return True
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Misskey WebSocket] 连接失败: {e}")
|
||||||
|
self.is_connected = False
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def disconnect(self):
|
||||||
|
self._running = False
|
||||||
|
if self.websocket:
|
||||||
|
await self.websocket.close()
|
||||||
|
self.websocket = None
|
||||||
|
self.is_connected = False
|
||||||
|
logger.info("[Misskey WebSocket] 连接已断开")
|
||||||
|
|
||||||
|
async def subscribe_channel(
|
||||||
|
self, channel_type: str, params: Optional[Dict] = None
|
||||||
|
) -> str:
|
||||||
|
if not self.is_connected or not self.websocket:
|
||||||
|
raise WebSocketError("WebSocket 未连接")
|
||||||
|
|
||||||
|
channel_id = str(uuid.uuid4())
|
||||||
|
message = {
|
||||||
|
"type": "connect",
|
||||||
|
"body": {"channel": channel_type, "id": channel_id, "params": params or {}},
|
||||||
|
}
|
||||||
|
|
||||||
|
await self.websocket.send(json.dumps(message))
|
||||||
|
self.channels[channel_id] = channel_type
|
||||||
|
return channel_id
|
||||||
|
|
||||||
|
async def unsubscribe_channel(self, channel_id: str):
|
||||||
|
if (
|
||||||
|
not self.is_connected
|
||||||
|
or not self.websocket
|
||||||
|
or channel_id not in self.channels
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
message = {"type": "disconnect", "body": {"id": channel_id}}
|
||||||
|
|
||||||
|
await self.websocket.send(json.dumps(message))
|
||||||
|
del self.channels[channel_id]
|
||||||
|
|
||||||
|
def add_message_handler(
|
||||||
|
self, event_type: str, handler: Callable[[Dict], Awaitable[None]]
|
||||||
|
):
|
||||||
|
self.message_handlers[event_type] = handler
|
||||||
|
|
||||||
|
async def listen(self):
|
||||||
|
if not self.is_connected or not self.websocket:
|
||||||
|
raise WebSocketError("WebSocket 未连接")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for message in self.websocket:
|
||||||
|
if not self._running:
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(message)
|
||||||
|
await self._handle_message(data)
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.warning(f"[Misskey WebSocket] 无法解析消息: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Misskey WebSocket] 处理消息失败: {e}")
|
||||||
|
|
||||||
|
except websockets.exceptions.ConnectionClosedError as e:
|
||||||
|
logger.warning(f"[Misskey WebSocket] 连接意外关闭: {e}")
|
||||||
|
self.is_connected = False
|
||||||
|
except websockets.exceptions.ConnectionClosed as e:
|
||||||
|
logger.warning(
|
||||||
|
f"[Misskey WebSocket] 连接已关闭 (代码: {e.code}, 原因: {e.reason})"
|
||||||
|
)
|
||||||
|
self.is_connected = False
|
||||||
|
except websockets.exceptions.InvalidHandshake as e:
|
||||||
|
logger.error(f"[Misskey WebSocket] 握手失败: {e}")
|
||||||
|
self.is_connected = False
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[Misskey WebSocket] 监听消息失败: {e}")
|
||||||
|
self.is_connected = False
|
||||||
|
|
||||||
|
async def _handle_message(self, data: Dict[str, Any]):
|
||||||
|
message_type = data.get("type")
|
||||||
|
body = data.get("body", {})
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"[Misskey WebSocket] 收到消息类型: {message_type}\n数据: {json.dumps(data, indent=2, ensure_ascii=False)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if message_type == "channel":
|
||||||
|
channel_id = body.get("id")
|
||||||
|
event_type = body.get("type")
|
||||||
|
event_body = body.get("body", {})
|
||||||
|
|
||||||
|
logger.debug(
|
||||||
|
f"[Misskey WebSocket] 频道消息: {channel_id}, 事件类型: {event_type}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if channel_id in self.channels:
|
||||||
|
channel_type = self.channels[channel_id]
|
||||||
|
handler_key = f"{channel_type}:{event_type}"
|
||||||
|
|
||||||
|
if handler_key in self.message_handlers:
|
||||||
|
logger.debug(f"[Misskey WebSocket] 使用处理器: {handler_key}")
|
||||||
|
await self.message_handlers[handler_key](event_body)
|
||||||
|
elif event_type in self.message_handlers:
|
||||||
|
logger.debug(f"[Misskey WebSocket] 使用事件处理器: {event_type}")
|
||||||
|
await self.message_handlers[event_type](event_body)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
f"[Misskey WebSocket] 未找到处理器: {handler_key} 或 {event_type}"
|
||||||
|
)
|
||||||
|
if "_debug" in self.message_handlers:
|
||||||
|
await self.message_handlers["_debug"](
|
||||||
|
{
|
||||||
|
"type": event_type,
|
||||||
|
"body": event_body,
|
||||||
|
"channel": channel_type,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
elif message_type in self.message_handlers:
|
||||||
|
logger.debug(f"[Misskey WebSocket] 直接消息处理器: {message_type}")
|
||||||
|
await self.message_handlers[message_type](body)
|
||||||
|
else:
|
||||||
|
logger.debug(f"[Misskey WebSocket] 未处理的消息类型: {message_type}")
|
||||||
|
if "_debug" in self.message_handlers:
|
||||||
|
await self.message_handlers["_debug"](data)
|
||||||
|
|
||||||
|
|
||||||
|
def retry_async(max_retries: int = 3, retryable_exceptions: tuple = ()):
|
||||||
|
def decorator(func):
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
last_exc = None
|
||||||
|
for _ in range(max_retries):
|
||||||
|
try:
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
except retryable_exceptions as e:
|
||||||
|
last_exc = e
|
||||||
|
continue
|
||||||
|
if last_exc:
|
||||||
|
raise last_exc
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
class MisskeyAPI:
|
||||||
|
def __init__(self, instance_url: str, access_token: str):
|
||||||
|
self.instance_url = instance_url.rstrip("/")
|
||||||
|
self.access_token = access_token
|
||||||
|
self._session: Optional[aiohttp.ClientSession] = None
|
||||||
|
self.streaming: Optional[StreamingClient] = None
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
|
await self.close()
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
if self.streaming:
|
||||||
|
await self.streaming.disconnect()
|
||||||
|
self.streaming = None
|
||||||
|
if self._session:
|
||||||
|
await self._session.close()
|
||||||
|
self._session = None
|
||||||
|
logger.debug("[Misskey API] 客户端已关闭")
|
||||||
|
|
||||||
|
def get_streaming_client(self) -> StreamingClient:
|
||||||
|
if not self.streaming:
|
||||||
|
self.streaming = StreamingClient(self.instance_url, self.access_token)
|
||||||
|
return self.streaming
|
||||||
|
|
||||||
|
@property
|
||||||
|
def session(self) -> aiohttp.ClientSession:
|
||||||
|
if self._session is None or self._session.closed:
|
||||||
|
headers = {"Authorization": f"Bearer {self.access_token}"}
|
||||||
|
self._session = aiohttp.ClientSession(headers=headers)
|
||||||
|
return self._session
|
||||||
|
|
||||||
|
def _handle_response_status(self, status: int, endpoint: str):
|
||||||
|
"""处理 HTTP 响应状态码"""
|
||||||
|
if status == 400:
|
||||||
|
logger.error(f"API 请求错误: {endpoint} (状态码: {status})")
|
||||||
|
raise APIError(f"Bad request for {endpoint}")
|
||||||
|
elif status in (401, 403):
|
||||||
|
logger.error(f"API 认证失败: {endpoint} (状态码: {status})")
|
||||||
|
raise AuthenticationError(f"Authentication failed for {endpoint}")
|
||||||
|
elif status == 429:
|
||||||
|
logger.warning(f"API 频率限制: {endpoint} (状态码: {status})")
|
||||||
|
raise APIRateLimitError(f"Rate limit exceeded for {endpoint}")
|
||||||
|
else:
|
||||||
|
logger.error(f"API 请求失败: {endpoint} (状态码: {status})")
|
||||||
|
raise APIConnectionError(f"HTTP {status} for {endpoint}")
|
||||||
|
|
||||||
|
async def _process_response(
|
||||||
|
self, response: aiohttp.ClientResponse, endpoint: str
|
||||||
|
) -> Any:
|
||||||
|
"""处理 API 响应"""
|
||||||
|
if response.status == HTTP_OK:
|
||||||
|
try:
|
||||||
|
result = await response.json()
|
||||||
|
if endpoint == "i/notifications":
|
||||||
|
notifications_data = (
|
||||||
|
result
|
||||||
|
if isinstance(result, list)
|
||||||
|
else result.get("notifications", [])
|
||||||
|
if isinstance(result, dict)
|
||||||
|
else []
|
||||||
|
)
|
||||||
|
if notifications_data:
|
||||||
|
logger.debug(f"获取到 {len(notifications_data)} 条新通知")
|
||||||
|
else:
|
||||||
|
logger.debug(f"API 请求成功: {endpoint}")
|
||||||
|
return result
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"响应不是有效的 JSON 格式: {e}")
|
||||||
|
raise APIConnectionError("Invalid JSON response") from e
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
error_text = await response.text()
|
||||||
|
logger.error(
|
||||||
|
f"API 请求失败: {endpoint} - 状态码: {response.status}, 响应: {error_text}"
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.error(f"API 请求失败: {endpoint} - 状态码: {response.status}")
|
||||||
|
|
||||||
|
self._handle_response_status(response.status, endpoint)
|
||||||
|
raise APIConnectionError(f"Request failed for {endpoint}")
|
||||||
|
|
||||||
|
@retry_async(
|
||||||
|
max_retries=API_MAX_RETRIES,
|
||||||
|
retryable_exceptions=(APIConnectionError, APIRateLimitError),
|
||||||
|
)
|
||||||
|
async def _make_request(
|
||||||
|
self, endpoint: str, data: Optional[Dict[str, Any]] = None
|
||||||
|
) -> Any:
|
||||||
|
url = f"{self.instance_url}/api/{endpoint}"
|
||||||
|
payload = {"i": self.access_token}
|
||||||
|
if data:
|
||||||
|
payload.update(data)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with self.session.post(url, json=payload) as response:
|
||||||
|
return await self._process_response(response, endpoint)
|
||||||
|
except aiohttp.ClientError as e:
|
||||||
|
logger.error(f"HTTP 请求错误: {e}")
|
||||||
|
raise APIConnectionError(f"HTTP request failed: {e}") from e
|
||||||
|
|
||||||
|
async def create_note(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
visibility: str = "public",
|
||||||
|
reply_id: Optional[str] = None,
|
||||||
|
visible_user_ids: Optional[List[str]] = None,
|
||||||
|
local_only: bool = False,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""创建新贴文"""
|
||||||
|
data: Dict[str, Any] = {
|
||||||
|
"text": text,
|
||||||
|
"visibility": visibility,
|
||||||
|
"localOnly": local_only,
|
||||||
|
}
|
||||||
|
if reply_id:
|
||||||
|
data["replyId"] = reply_id
|
||||||
|
if visible_user_ids and visibility == "specified":
|
||||||
|
data["visibleUserIds"] = visible_user_ids
|
||||||
|
|
||||||
|
result = await self._make_request("notes/create", data)
|
||||||
|
note_id = result.get("createdNote", {}).get("id", "unknown")
|
||||||
|
logger.debug(f"发帖成功,note_id: {note_id}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def get_current_user(self) -> Dict[str, Any]:
|
||||||
|
"""获取当前用户信息"""
|
||||||
|
return await self._make_request("i", {})
|
||||||
|
|
||||||
|
async def send_message(self, user_id: str, text: str) -> Dict[str, Any]:
|
||||||
|
"""发送聊天消息"""
|
||||||
|
result = await self._make_request(
|
||||||
|
"chat/messages/create-to-user", {"toUserId": user_id, "text": text}
|
||||||
|
)
|
||||||
|
message_id = result.get("id", "unknown")
|
||||||
|
logger.debug(f"聊天发送成功,message_id: {message_id}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def send_room_message(self, room_id: str, text: str) -> Dict[str, Any]:
|
||||||
|
"""发送房间消息"""
|
||||||
|
result = await self._make_request(
|
||||||
|
"chat/messages/create-to-room", {"toRoomId": room_id, "text": text}
|
||||||
|
)
|
||||||
|
message_id = result.get("id", "unknown")
|
||||||
|
logger.debug(f"房间消息发送成功,message_id: {message_id}")
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def get_messages(
|
||||||
|
self, user_id: str, limit: int = 10, since_id: Optional[str] = None
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""获取聊天消息历史"""
|
||||||
|
data: Dict[str, Any] = {"userId": user_id, "limit": limit}
|
||||||
|
if since_id:
|
||||||
|
data["sinceId"] = since_id
|
||||||
|
|
||||||
|
result = await self._make_request("chat/messages/user-timeline", data)
|
||||||
|
if isinstance(result, list):
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
logger.warning(f"获取聊天消息响应格式异常: {type(result)}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def get_mentions(
|
||||||
|
self, limit: int = 10, since_id: Optional[str] = None
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""获取提及通知"""
|
||||||
|
data: Dict[str, Any] = {"limit": limit}
|
||||||
|
if since_id:
|
||||||
|
data["sinceId"] = since_id
|
||||||
|
data["includeTypes"] = ["mention", "reply", "quote"]
|
||||||
|
|
||||||
|
result = await self._make_request("i/notifications", data)
|
||||||
|
if isinstance(result, list):
|
||||||
|
return result
|
||||||
|
elif isinstance(result, dict) and "notifications" in result:
|
||||||
|
return result["notifications"]
|
||||||
|
else:
|
||||||
|
logger.warning(f"获取提及通知响应格式异常: {type(result)}")
|
||||||
|
return []
|
||||||
@@ -0,0 +1,123 @@
|
|||||||
|
import asyncio
|
||||||
|
import re
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
from astrbot.api import logger
|
||||||
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
|
from astrbot.api.platform import PlatformMetadata, AstrBotMessage
|
||||||
|
from astrbot.api.message_components import Plain
|
||||||
|
|
||||||
|
from .misskey_utils import (
|
||||||
|
serialize_message_chain,
|
||||||
|
resolve_visibility_from_raw_message,
|
||||||
|
is_valid_user_session_id,
|
||||||
|
is_valid_room_session_id,
|
||||||
|
add_at_mention_if_needed,
|
||||||
|
extract_user_id_from_session_id,
|
||||||
|
extract_room_id_from_session_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MisskeyPlatformEvent(AstrMessageEvent):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message_str: str,
|
||||||
|
message_obj: AstrBotMessage,
|
||||||
|
platform_meta: PlatformMetadata,
|
||||||
|
session_id: str,
|
||||||
|
client,
|
||||||
|
):
|
||||||
|
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
def _is_system_command(self, message_str: str) -> bool:
|
||||||
|
"""检测是否为系统指令"""
|
||||||
|
if not message_str or not message_str.strip():
|
||||||
|
return False
|
||||||
|
|
||||||
|
system_prefixes = ["/", "!", "#", ".", "^"]
|
||||||
|
message_trimmed = message_str.strip()
|
||||||
|
|
||||||
|
return any(message_trimmed.startswith(prefix) for prefix in system_prefixes)
|
||||||
|
|
||||||
|
async def send(self, message: MessageChain):
|
||||||
|
content, has_at = serialize_message_chain(message.chain)
|
||||||
|
|
||||||
|
if not content:
|
||||||
|
logger.debug("[MisskeyEvent] 内容为空,跳过发送")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
original_message_id = getattr(self.message_obj, "message_id", None)
|
||||||
|
raw_message = getattr(self.message_obj, "raw_message", {})
|
||||||
|
|
||||||
|
if raw_message and not has_at:
|
||||||
|
user_data = raw_message.get("user", {})
|
||||||
|
user_info = {
|
||||||
|
"username": user_data.get("username", ""),
|
||||||
|
"nickname": user_data.get("name", user_data.get("username", "")),
|
||||||
|
}
|
||||||
|
content = add_at_mention_if_needed(content, user_info, has_at)
|
||||||
|
|
||||||
|
# 根据会话类型选择发送方式
|
||||||
|
if hasattr(self.client, "send_message") and is_valid_user_session_id(
|
||||||
|
self.session_id
|
||||||
|
):
|
||||||
|
user_id = extract_user_id_from_session_id(self.session_id)
|
||||||
|
await self.client.send_message(user_id, content)
|
||||||
|
elif hasattr(self.client, "send_room_message") and is_valid_room_session_id(
|
||||||
|
self.session_id
|
||||||
|
):
|
||||||
|
room_id = extract_room_id_from_session_id(self.session_id)
|
||||||
|
await self.client.send_room_message(room_id, content)
|
||||||
|
elif original_message_id and hasattr(self.client, "create_note"):
|
||||||
|
visibility, visible_user_ids = resolve_visibility_from_raw_message(
|
||||||
|
raw_message
|
||||||
|
)
|
||||||
|
await self.client.create_note(
|
||||||
|
content,
|
||||||
|
reply_id=original_message_id,
|
||||||
|
visibility=visibility,
|
||||||
|
visible_user_ids=visible_user_ids,
|
||||||
|
)
|
||||||
|
elif hasattr(self.client, "create_note"):
|
||||||
|
logger.debug("[MisskeyEvent] 创建新帖子")
|
||||||
|
await self.client.create_note(content)
|
||||||
|
|
||||||
|
await super().send(message)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[MisskeyEvent] 发送失败: {e}")
|
||||||
|
|
||||||
|
async def send_streaming(
|
||||||
|
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
|
||||||
|
):
|
||||||
|
if not use_fallback:
|
||||||
|
buffer = None
|
||||||
|
async for chain in generator:
|
||||||
|
if not buffer:
|
||||||
|
buffer = chain
|
||||||
|
else:
|
||||||
|
buffer.chain.extend(chain.chain)
|
||||||
|
if not buffer:
|
||||||
|
return
|
||||||
|
buffer.squash_plain()
|
||||||
|
await self.send(buffer)
|
||||||
|
return await super().send_streaming(generator, use_fallback)
|
||||||
|
|
||||||
|
buffer = ""
|
||||||
|
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
|
||||||
|
|
||||||
|
async for chain in generator:
|
||||||
|
if isinstance(chain, MessageChain):
|
||||||
|
for comp in chain.chain:
|
||||||
|
if isinstance(comp, Plain):
|
||||||
|
buffer += comp.text
|
||||||
|
if any(p in buffer for p in "。?!~…"):
|
||||||
|
buffer = await self.process_buffer(buffer, pattern)
|
||||||
|
else:
|
||||||
|
await self.send(MessageChain(chain=[comp]))
|
||||||
|
await asyncio.sleep(1.5) # 限速
|
||||||
|
|
||||||
|
if buffer.strip():
|
||||||
|
await self.send(MessageChain([Plain(buffer)]))
|
||||||
|
return await super().send_streaming(generator, use_fallback)
|
||||||
@@ -0,0 +1,327 @@
|
|||||||
|
"""Misskey 平台适配器通用工具函数"""
|
||||||
|
|
||||||
|
from typing import Dict, Any, List, Tuple, Optional, Union
|
||||||
|
import astrbot.api.message_components as Comp
|
||||||
|
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
||||||
|
|
||||||
|
|
||||||
|
def serialize_message_chain(chain: List[Any]) -> Tuple[str, bool]:
|
||||||
|
"""将消息链序列化为文本字符串"""
|
||||||
|
text_parts = []
|
||||||
|
has_at = False
|
||||||
|
|
||||||
|
def process_component(component):
|
||||||
|
nonlocal has_at
|
||||||
|
if isinstance(component, Comp.Plain):
|
||||||
|
return component.text
|
||||||
|
elif isinstance(component, Comp.File):
|
||||||
|
file_name = getattr(component, "name", "文件")
|
||||||
|
return f"[文件: {file_name}]"
|
||||||
|
elif isinstance(component, Comp.At):
|
||||||
|
has_at = True
|
||||||
|
return f"@{component.qq}"
|
||||||
|
elif hasattr(component, "text"):
|
||||||
|
text = getattr(component, "text", "")
|
||||||
|
if "@" in text:
|
||||||
|
has_at = True
|
||||||
|
return text
|
||||||
|
else:
|
||||||
|
return str(component)
|
||||||
|
|
||||||
|
for component in chain:
|
||||||
|
if isinstance(component, Comp.Node) and component.content:
|
||||||
|
for node_comp in component.content:
|
||||||
|
result = process_component(node_comp)
|
||||||
|
if result:
|
||||||
|
text_parts.append(result)
|
||||||
|
else:
|
||||||
|
result = process_component(component)
|
||||||
|
if result:
|
||||||
|
text_parts.append(result)
|
||||||
|
|
||||||
|
return "".join(text_parts), has_at
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_message_visibility(
|
||||||
|
user_id: Optional[str],
|
||||||
|
user_cache: Dict[str, Any],
|
||||||
|
self_id: Optional[str],
|
||||||
|
default_visibility: str = "public",
|
||||||
|
) -> Tuple[str, Optional[List[str]]]:
|
||||||
|
"""解析 Misskey 消息的可见性设置"""
|
||||||
|
visibility = default_visibility
|
||||||
|
visible_user_ids = None
|
||||||
|
|
||||||
|
if user_id and user_cache:
|
||||||
|
user_info = user_cache.get(user_id)
|
||||||
|
if user_info:
|
||||||
|
original_visibility = user_info.get("visibility", default_visibility)
|
||||||
|
if original_visibility == "specified":
|
||||||
|
visibility = "specified"
|
||||||
|
original_visible_users = user_info.get("visible_user_ids", [])
|
||||||
|
users_to_include = [user_id]
|
||||||
|
if self_id:
|
||||||
|
users_to_include.append(self_id)
|
||||||
|
visible_user_ids = list(set(original_visible_users + users_to_include))
|
||||||
|
visible_user_ids = [uid for uid in visible_user_ids if uid]
|
||||||
|
else:
|
||||||
|
visibility = original_visibility
|
||||||
|
|
||||||
|
return visibility, visible_user_ids
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_visibility_from_raw_message(
|
||||||
|
raw_message: Dict[str, Any], self_id: Optional[str] = None
|
||||||
|
) -> Tuple[str, Optional[List[str]]]:
|
||||||
|
"""从原始消息数据中解析可见性设置"""
|
||||||
|
visibility = "public"
|
||||||
|
visible_user_ids = None
|
||||||
|
|
||||||
|
if not raw_message:
|
||||||
|
return visibility, visible_user_ids
|
||||||
|
|
||||||
|
original_visibility = raw_message.get("visibility", "public")
|
||||||
|
if original_visibility == "specified":
|
||||||
|
visibility = "specified"
|
||||||
|
original_visible_users = raw_message.get("visibleUserIds", [])
|
||||||
|
sender_id = raw_message.get("userId", "")
|
||||||
|
|
||||||
|
users_to_include = []
|
||||||
|
if sender_id:
|
||||||
|
users_to_include.append(sender_id)
|
||||||
|
if self_id:
|
||||||
|
users_to_include.append(self_id)
|
||||||
|
|
||||||
|
visible_user_ids = list(set(original_visible_users + users_to_include))
|
||||||
|
visible_user_ids = [uid for uid in visible_user_ids if uid]
|
||||||
|
else:
|
||||||
|
visibility = original_visibility
|
||||||
|
|
||||||
|
return visibility, visible_user_ids
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_user_session_id(session_id: Union[str, Any]) -> bool:
|
||||||
|
"""检查 session_id 是否是有效的聊天用户 session_id (仅限chat%前缀)"""
|
||||||
|
if not isinstance(session_id, str) or "%" not in session_id:
|
||||||
|
return False
|
||||||
|
|
||||||
|
parts = session_id.split("%")
|
||||||
|
return (
|
||||||
|
len(parts) == 2
|
||||||
|
and parts[0] == "chat"
|
||||||
|
and bool(parts[1])
|
||||||
|
and parts[1] != "unknown"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_room_session_id(session_id: Union[str, Any]) -> bool:
|
||||||
|
"""检查 session_id 是否是有效的房间 session_id (仅限room%前缀)"""
|
||||||
|
if not isinstance(session_id, str) or "%" not in session_id:
|
||||||
|
return False
|
||||||
|
|
||||||
|
parts = session_id.split("%")
|
||||||
|
return (
|
||||||
|
len(parts) == 2
|
||||||
|
and parts[0] == "room"
|
||||||
|
and bool(parts[1])
|
||||||
|
and parts[1] != "unknown"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_user_id_from_session_id(session_id: str) -> str:
|
||||||
|
"""从 session_id 中提取用户 ID"""
|
||||||
|
if "%" in session_id:
|
||||||
|
parts = session_id.split("%")
|
||||||
|
if len(parts) >= 2:
|
||||||
|
return parts[1]
|
||||||
|
return session_id
|
||||||
|
|
||||||
|
|
||||||
|
def extract_room_id_from_session_id(session_id: str) -> str:
|
||||||
|
"""从 session_id 中提取房间 ID"""
|
||||||
|
if "%" in session_id:
|
||||||
|
parts = session_id.split("%")
|
||||||
|
if len(parts) >= 2 and parts[0] == "room":
|
||||||
|
return parts[1]
|
||||||
|
return session_id
|
||||||
|
|
||||||
|
|
||||||
|
def add_at_mention_if_needed(
|
||||||
|
text: str, user_info: Optional[Dict[str, Any]], has_at: bool = False
|
||||||
|
) -> str:
|
||||||
|
"""如果需要且没有@用户,则添加@用户"""
|
||||||
|
if has_at or not user_info:
|
||||||
|
return text
|
||||||
|
|
||||||
|
username = user_info.get("username")
|
||||||
|
nickname = user_info.get("nickname")
|
||||||
|
|
||||||
|
if username:
|
||||||
|
mention = f"@{username}"
|
||||||
|
if not text.startswith(mention):
|
||||||
|
text = f"{mention}\n{text}".strip()
|
||||||
|
elif nickname:
|
||||||
|
mention = f"@{nickname}"
|
||||||
|
if not text.startswith(mention):
|
||||||
|
text = f"{mention}\n{text}".strip()
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def create_file_component(file_info: Dict[str, Any]) -> Tuple[Any, str]:
|
||||||
|
"""创建文件组件和描述文本"""
|
||||||
|
file_url = file_info.get("url", "")
|
||||||
|
file_name = file_info.get("name", "未知文件")
|
||||||
|
file_type = file_info.get("type", "")
|
||||||
|
|
||||||
|
if file_type.startswith("image/"):
|
||||||
|
return Comp.Image(url=file_url, file=file_name), f"图片[{file_name}]"
|
||||||
|
elif file_type.startswith("audio/"):
|
||||||
|
return Comp.Record(url=file_url, file=file_name), f"音频[{file_name}]"
|
||||||
|
elif file_type.startswith("video/"):
|
||||||
|
return Comp.Video(url=file_url, file=file_name), f"视频[{file_name}]"
|
||||||
|
else:
|
||||||
|
return Comp.File(name=file_name, url=file_url), f"文件[{file_name}]"
|
||||||
|
|
||||||
|
|
||||||
|
def process_files(
|
||||||
|
message: AstrBotMessage, files: list, include_text_parts: bool = True
|
||||||
|
) -> list:
|
||||||
|
"""处理文件列表,添加到消息组件中并返回文本描述"""
|
||||||
|
file_parts = []
|
||||||
|
for file_info in files:
|
||||||
|
component, part_text = create_file_component(file_info)
|
||||||
|
message.message.append(component)
|
||||||
|
if include_text_parts:
|
||||||
|
file_parts.append(part_text)
|
||||||
|
return file_parts
|
||||||
|
|
||||||
|
|
||||||
|
def extract_sender_info(
|
||||||
|
raw_data: Dict[str, Any], is_chat: bool = False
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""提取发送者信息"""
|
||||||
|
if is_chat:
|
||||||
|
sender = raw_data.get("fromUser", {})
|
||||||
|
sender_id = str(sender.get("id", "") or raw_data.get("fromUserId", ""))
|
||||||
|
else:
|
||||||
|
sender = raw_data.get("user", {})
|
||||||
|
sender_id = str(sender.get("id", ""))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"sender": sender,
|
||||||
|
"sender_id": sender_id,
|
||||||
|
"nickname": sender.get("name", sender.get("username", "")),
|
||||||
|
"username": sender.get("username", ""),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def create_base_message(
|
||||||
|
raw_data: Dict[str, Any],
|
||||||
|
sender_info: Dict[str, Any],
|
||||||
|
client_self_id: str,
|
||||||
|
is_chat: bool = False,
|
||||||
|
room_id: Optional[str] = None,
|
||||||
|
unique_session: bool = False,
|
||||||
|
) -> AstrBotMessage:
|
||||||
|
"""创建基础消息对象"""
|
||||||
|
message = AstrBotMessage()
|
||||||
|
message.raw_message = raw_data
|
||||||
|
message.message = []
|
||||||
|
|
||||||
|
message.sender = MessageMember(
|
||||||
|
user_id=sender_info["sender_id"],
|
||||||
|
nickname=sender_info["nickname"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if room_id:
|
||||||
|
session_prefix = "room"
|
||||||
|
session_id = f"{session_prefix}%{room_id}"
|
||||||
|
if unique_session:
|
||||||
|
session_id += f"_{sender_info['sender_id']}"
|
||||||
|
message.type = MessageType.GROUP_MESSAGE
|
||||||
|
message.group_id = room_id
|
||||||
|
elif is_chat:
|
||||||
|
session_prefix = "chat"
|
||||||
|
session_id = f"{session_prefix}%{sender_info['sender_id']}"
|
||||||
|
message.type = MessageType.FRIEND_MESSAGE
|
||||||
|
else:
|
||||||
|
session_prefix = "note"
|
||||||
|
session_id = f"{session_prefix}%{sender_info['sender_id']}"
|
||||||
|
message.type = MessageType.FRIEND_MESSAGE
|
||||||
|
|
||||||
|
message.session_id = (
|
||||||
|
session_id if sender_info["sender_id"] else f"{session_prefix}%unknown"
|
||||||
|
)
|
||||||
|
message.message_id = str(raw_data.get("id", ""))
|
||||||
|
message.self_id = client_self_id
|
||||||
|
|
||||||
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
def process_at_mention(
|
||||||
|
message: AstrBotMessage, raw_text: str, bot_username: str, client_self_id: str
|
||||||
|
) -> Tuple[List[str], str]:
|
||||||
|
"""处理@提及逻辑,返回消息部分列表和处理后的文本"""
|
||||||
|
message_parts = []
|
||||||
|
|
||||||
|
if not raw_text:
|
||||||
|
return message_parts, ""
|
||||||
|
|
||||||
|
if bot_username and raw_text.startswith(f"@{bot_username}"):
|
||||||
|
at_mention = f"@{bot_username}"
|
||||||
|
message.message.append(Comp.At(qq=client_self_id))
|
||||||
|
remaining_text = raw_text[len(at_mention) :].strip()
|
||||||
|
if remaining_text:
|
||||||
|
message.message.append(Comp.Plain(remaining_text))
|
||||||
|
message_parts.append(remaining_text)
|
||||||
|
return message_parts, remaining_text
|
||||||
|
else:
|
||||||
|
message.message.append(Comp.Plain(raw_text))
|
||||||
|
message_parts.append(raw_text)
|
||||||
|
return message_parts, raw_text
|
||||||
|
|
||||||
|
|
||||||
|
def cache_user_info(
|
||||||
|
user_cache: Dict[str, Any],
|
||||||
|
sender_info: Dict[str, Any],
|
||||||
|
raw_data: Dict[str, Any],
|
||||||
|
client_self_id: str,
|
||||||
|
is_chat: bool = False,
|
||||||
|
):
|
||||||
|
"""缓存用户信息"""
|
||||||
|
if is_chat:
|
||||||
|
user_cache_data = {
|
||||||
|
"username": sender_info["username"],
|
||||||
|
"nickname": sender_info["nickname"],
|
||||||
|
"visibility": "specified",
|
||||||
|
"visible_user_ids": [client_self_id, sender_info["sender_id"]],
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
user_cache_data = {
|
||||||
|
"username": sender_info["username"],
|
||||||
|
"nickname": sender_info["nickname"],
|
||||||
|
"visibility": raw_data.get("visibility", "public"),
|
||||||
|
"visible_user_ids": raw_data.get("visibleUserIds", []),
|
||||||
|
}
|
||||||
|
|
||||||
|
user_cache[sender_info["sender_id"]] = user_cache_data
|
||||||
|
|
||||||
|
|
||||||
|
def cache_room_info(
|
||||||
|
user_cache: Dict[str, Any], raw_data: Dict[str, Any], client_self_id: str
|
||||||
|
):
|
||||||
|
"""缓存房间信息"""
|
||||||
|
room_data = raw_data.get("toRoom")
|
||||||
|
room_id = raw_data.get("toRoomId")
|
||||||
|
|
||||||
|
if room_data and room_id:
|
||||||
|
room_cache_key = f"room:{room_id}"
|
||||||
|
user_cache[room_cache_key] = {
|
||||||
|
"room_id": room_id,
|
||||||
|
"room_name": room_data.get("name", ""),
|
||||||
|
"room_description": room_data.get("description", ""),
|
||||||
|
"owner_id": room_data.get("ownerId", ""),
|
||||||
|
"visibility": "specified",
|
||||||
|
"visible_user_ids": [client_self_id],
|
||||||
|
}
|
||||||
@@ -94,10 +94,15 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
plain_text,
|
plain_text,
|
||||||
image_base64,
|
image_base64,
|
||||||
image_path,
|
image_path,
|
||||||
record_file_path
|
record_file_path,
|
||||||
) = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer)
|
) = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer)
|
||||||
|
|
||||||
if not plain_text and not image_base64 and not image_path and not record_file_path:
|
if (
|
||||||
|
not plain_text
|
||||||
|
and not image_base64
|
||||||
|
and not image_path
|
||||||
|
and not record_file_path
|
||||||
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
@@ -118,7 +123,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
)
|
)
|
||||||
payload["media"] = media
|
payload["media"] = media
|
||||||
payload["msg_type"] = 7
|
payload["msg_type"] = 7
|
||||||
if record_file_path: # group record msg
|
if record_file_path: # group record msg
|
||||||
media = await self.upload_group_and_c2c_record(
|
media = await self.upload_group_and_c2c_record(
|
||||||
record_file_path, 3, group_openid=source.group_openid
|
record_file_path, 3, group_openid=source.group_openid
|
||||||
)
|
)
|
||||||
@@ -134,9 +139,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
)
|
)
|
||||||
payload["media"] = media
|
payload["media"] = media
|
||||||
payload["msg_type"] = 7
|
payload["msg_type"] = 7
|
||||||
if record_file_path: # c2c record
|
if record_file_path: # c2c record
|
||||||
media = await self.upload_group_and_c2c_record(
|
media = await self.upload_group_and_c2c_record(
|
||||||
record_file_path, 3, openid = source.author.user_openid
|
record_file_path, 3, openid=source.author.user_openid
|
||||||
)
|
)
|
||||||
payload["media"] = media
|
payload["media"] = media
|
||||||
payload["msg_type"] = 7
|
payload["msg_type"] = 7
|
||||||
@@ -190,28 +195,21 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
return await self.bot.api._http.request(route, json=payload)
|
return await self.bot.api._http.request(route, json=payload)
|
||||||
|
|
||||||
async def upload_group_and_c2c_record(
|
async def upload_group_and_c2c_record(
|
||||||
self,
|
self, file_source: str, file_type: int, srv_send_msg: bool = False, **kwargs
|
||||||
file_source: str,
|
|
||||||
file_type: int,
|
|
||||||
srv_send_msg: bool = False,
|
|
||||||
**kwargs
|
|
||||||
) -> Optional[Media]:
|
) -> Optional[Media]:
|
||||||
"""
|
"""
|
||||||
上传媒体文件
|
上传媒体文件
|
||||||
"""
|
"""
|
||||||
# 构建基础payload
|
# 构建基础payload
|
||||||
payload = {
|
payload = {"file_type": file_type, "srv_send_msg": srv_send_msg}
|
||||||
"file_type": file_type,
|
|
||||||
"srv_send_msg": srv_send_msg
|
|
||||||
}
|
|
||||||
|
|
||||||
# 处理文件数据
|
# 处理文件数据
|
||||||
if os.path.exists(file_source):
|
if os.path.exists(file_source):
|
||||||
# 读取本地文件
|
# 读取本地文件
|
||||||
async with aiofiles.open(file_source, 'rb') as f:
|
async with aiofiles.open(file_source, "rb") as f:
|
||||||
file_content = await f.read()
|
file_content = await f.read()
|
||||||
# use base64 encode
|
# use base64 encode
|
||||||
payload["file_data"] = base64.b64encode(file_content).decode('utf-8')
|
payload["file_data"] = base64.b64encode(file_content).decode("utf-8")
|
||||||
else:
|
else:
|
||||||
# 使用URL
|
# 使用URL
|
||||||
payload["url"] = file_source
|
payload["url"] = file_source
|
||||||
@@ -221,8 +219,12 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
payload["openid"] = kwargs["openid"]
|
payload["openid"] = kwargs["openid"]
|
||||||
route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"])
|
route = Route("POST", "/v2/users/{openid}/files", openid=kwargs["openid"])
|
||||||
elif "group_openid" in kwargs:
|
elif "group_openid" in kwargs:
|
||||||
payload["group_openid"] =kwargs["group_openid"]
|
payload["group_openid"] = kwargs["group_openid"]
|
||||||
route = Route("POST", "/v2/groups/{group_openid}/files", group_openid=kwargs["group_openid"])
|
route = Route(
|
||||||
|
"POST",
|
||||||
|
"/v2/groups/{group_openid}/files",
|
||||||
|
group_openid=kwargs["group_openid"],
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -235,7 +237,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
file_uuid=result.get("file_uuid"),
|
file_uuid=result.get("file_uuid"),
|
||||||
file_info=result.get("file_info"),
|
file_info=result.get("file_info"),
|
||||||
ttl=result.get("ttl", 0),
|
ttl=result.get("ttl", 0),
|
||||||
file_id=result.get("id", "")
|
file_id=result.get("id", ""),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"上传请求错误: {e}")
|
logger.error(f"上传请求错误: {e}")
|
||||||
@@ -286,11 +288,15 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
|||||||
image_base64 = image_base64.removeprefix("base64://")
|
image_base64 = image_base64.removeprefix("base64://")
|
||||||
elif isinstance(i, Record):
|
elif isinstance(i, Record):
|
||||||
if i.file:
|
if i.file:
|
||||||
record_wav_path = await i.convert_to_file_path() # wav 路径
|
record_wav_path = await i.convert_to_file_path() # wav 路径
|
||||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||||
record_tecent_silk_path = os.path.join(temp_dir, f"{uuid.uuid4()}.silk")
|
record_tecent_silk_path = os.path.join(
|
||||||
|
temp_dir, f"{uuid.uuid4()}.silk"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
duration = await wav_to_tencent_silk(record_wav_path, record_tecent_silk_path)
|
duration = await wav_to_tencent_silk(
|
||||||
|
record_wav_path, record_tecent_silk_path
|
||||||
|
)
|
||||||
if duration > 0:
|
if duration > 0:
|
||||||
record_file_path = record_tecent_silk_path
|
record_file_path = record_tecent_silk_path
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -0,0 +1,748 @@
|
|||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import websockets
|
||||||
|
from websockets.asyncio.client import connect
|
||||||
|
from typing import Optional
|
||||||
|
from aiohttp import ClientSession, ClientTimeout
|
||||||
|
from websockets.asyncio.client import ClientConnection
|
||||||
|
from astrbot.api import logger
|
||||||
|
from astrbot.api.event import MessageChain
|
||||||
|
from astrbot.api.platform import (
|
||||||
|
AstrBotMessage,
|
||||||
|
MessageMember,
|
||||||
|
MessageType,
|
||||||
|
Platform,
|
||||||
|
PlatformMetadata,
|
||||||
|
register_platform_adapter,
|
||||||
|
)
|
||||||
|
from astrbot.core.platform.astr_message_event import MessageSession
|
||||||
|
from astrbot.api.message_components import (
|
||||||
|
Plain,
|
||||||
|
Image,
|
||||||
|
At,
|
||||||
|
File,
|
||||||
|
Record,
|
||||||
|
Reply,
|
||||||
|
)
|
||||||
|
from xml.etree import ElementTree as ET
|
||||||
|
|
||||||
|
|
||||||
|
@register_platform_adapter(
|
||||||
|
"satori",
|
||||||
|
"Satori 协议适配器",
|
||||||
|
)
|
||||||
|
class SatoriPlatformAdapter(Platform):
|
||||||
|
def __init__(
|
||||||
|
self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue
|
||||||
|
) -> None:
|
||||||
|
super().__init__(event_queue)
|
||||||
|
self.config = platform_config
|
||||||
|
self.settings = platform_settings
|
||||||
|
|
||||||
|
self.api_base_url = self.config.get(
|
||||||
|
"satori_api_base_url", "http://localhost:5140/satori/v1"
|
||||||
|
)
|
||||||
|
self.token = self.config.get("satori_token", "")
|
||||||
|
self.endpoint = self.config.get(
|
||||||
|
"satori_endpoint", "ws://localhost:5140/satori/v1/events"
|
||||||
|
)
|
||||||
|
self.auto_reconnect = self.config.get("satori_auto_reconnect", True)
|
||||||
|
self.heartbeat_interval = self.config.get("satori_heartbeat_interval", 10)
|
||||||
|
self.reconnect_delay = self.config.get("satori_reconnect_delay", 5)
|
||||||
|
|
||||||
|
self.metadata = PlatformMetadata(
|
||||||
|
name="satori",
|
||||||
|
description="Satori 通用协议适配器",
|
||||||
|
id=self.config["id"],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.ws: Optional[ClientConnection] = None
|
||||||
|
self.session: Optional[ClientSession] = None
|
||||||
|
self.sequence = 0
|
||||||
|
self.logins = []
|
||||||
|
self.running = False
|
||||||
|
self.heartbeat_task: Optional[asyncio.Task] = None
|
||||||
|
self.ready_received = False
|
||||||
|
|
||||||
|
async def send_by_session(
|
||||||
|
self, session: MessageSession, message_chain: MessageChain
|
||||||
|
):
|
||||||
|
from .satori_event import SatoriPlatformEvent
|
||||||
|
|
||||||
|
await SatoriPlatformEvent.send_with_adapter(
|
||||||
|
self, message_chain, session.session_id
|
||||||
|
)
|
||||||
|
await super().send_by_session(session, message_chain)
|
||||||
|
|
||||||
|
def meta(self) -> PlatformMetadata:
|
||||||
|
return self.metadata
|
||||||
|
|
||||||
|
def _is_websocket_closed(self, ws) -> bool:
|
||||||
|
"""检查WebSocket连接是否已关闭"""
|
||||||
|
if not ws:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
if hasattr(ws, "closed"):
|
||||||
|
return ws.closed
|
||||||
|
elif hasattr(ws, "close_code"):
|
||||||
|
return ws.close_code is not None
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
except AttributeError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
self.running = True
|
||||||
|
self.session = ClientSession(timeout=ClientTimeout(total=30))
|
||||||
|
|
||||||
|
retry_count = 0
|
||||||
|
max_retries = 10
|
||||||
|
|
||||||
|
while self.running:
|
||||||
|
try:
|
||||||
|
await self.connect_websocket()
|
||||||
|
retry_count = 0
|
||||||
|
except websockets.exceptions.ConnectionClosed as e:
|
||||||
|
logger.warning(f"Satori WebSocket 连接关闭: {e}")
|
||||||
|
retry_count += 1
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Satori WebSocket 连接失败: {e}")
|
||||||
|
retry_count += 1
|
||||||
|
|
||||||
|
if not self.running:
|
||||||
|
break
|
||||||
|
|
||||||
|
if retry_count >= max_retries:
|
||||||
|
logger.error(f"达到最大重试次数 ({max_retries}),停止重试")
|
||||||
|
break
|
||||||
|
|
||||||
|
if not self.auto_reconnect:
|
||||||
|
break
|
||||||
|
|
||||||
|
delay = min(self.reconnect_delay * (2 ** (retry_count - 1)), 60)
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
|
if self.session:
|
||||||
|
await self.session.close()
|
||||||
|
|
||||||
|
async def connect_websocket(self):
|
||||||
|
logger.info(f"Satori 适配器正在连接到 WebSocket: {self.endpoint}")
|
||||||
|
logger.info(f"Satori 适配器 HTTP API 地址: {self.api_base_url}")
|
||||||
|
|
||||||
|
if not self.endpoint.startswith(("ws://", "wss://")):
|
||||||
|
logger.error(f"无效的WebSocket URL: {self.endpoint}")
|
||||||
|
raise ValueError(f"WebSocket URL必须以ws://或wss://开头: {self.endpoint}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
websocket = await connect(self.endpoint, additional_headers={})
|
||||||
|
self.ws = websocket
|
||||||
|
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
await self.send_identify()
|
||||||
|
|
||||||
|
self.heartbeat_task = asyncio.create_task(self.heartbeat_loop())
|
||||||
|
|
||||||
|
async for message in websocket:
|
||||||
|
try:
|
||||||
|
await self.handle_message(message) # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Satori 处理消息异常: {e}")
|
||||||
|
|
||||||
|
except websockets.exceptions.ConnectionClosed as e:
|
||||||
|
logger.warning(f"Satori WebSocket 连接关闭: {e}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Satori WebSocket 连接异常: {e}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
if self.heartbeat_task:
|
||||||
|
self.heartbeat_task.cancel()
|
||||||
|
try:
|
||||||
|
await self.heartbeat_task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
if self.ws:
|
||||||
|
try:
|
||||||
|
await self.ws.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Satori WebSocket 关闭异常: {e}")
|
||||||
|
|
||||||
|
async def send_identify(self):
|
||||||
|
if not self.ws:
|
||||||
|
raise Exception("WebSocket连接未建立")
|
||||||
|
|
||||||
|
if self._is_websocket_closed(self.ws):
|
||||||
|
raise Exception("WebSocket连接已关闭")
|
||||||
|
|
||||||
|
identify_payload = {
|
||||||
|
"op": 3, # IDENTIFY
|
||||||
|
"body": {
|
||||||
|
"token": str(self.token) if self.token else "", # 字符串
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# 只有在有序列号时才添加sn字段
|
||||||
|
if self.sequence > 0:
|
||||||
|
identify_payload["body"]["sn"] = self.sequence
|
||||||
|
|
||||||
|
try:
|
||||||
|
message_str = json.dumps(identify_payload, ensure_ascii=False)
|
||||||
|
await self.ws.send(message_str)
|
||||||
|
except websockets.exceptions.ConnectionClosed as e:
|
||||||
|
logger.error(f"发送 IDENTIFY 信令时连接关闭: {e}")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"发送 IDENTIFY 信令失败: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def heartbeat_loop(self):
|
||||||
|
try:
|
||||||
|
while self.running and self.ws:
|
||||||
|
await asyncio.sleep(self.heartbeat_interval)
|
||||||
|
|
||||||
|
if self.ws and not self._is_websocket_closed(self.ws):
|
||||||
|
try:
|
||||||
|
ping_payload = {
|
||||||
|
"op": 1, # PING
|
||||||
|
"body": {},
|
||||||
|
}
|
||||||
|
await self.ws.send(json.dumps(ping_payload, ensure_ascii=False))
|
||||||
|
except websockets.exceptions.ConnectionClosed as e:
|
||||||
|
logger.error(f"Satori WebSocket 连接关闭: {e}")
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Satori WebSocket 发送心跳失败: {e}")
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"心跳任务异常: {e}")
|
||||||
|
|
||||||
|
async def handle_message(self, message: str):
|
||||||
|
try:
|
||||||
|
data = json.loads(message)
|
||||||
|
op = data.get("op")
|
||||||
|
body = data.get("body", {})
|
||||||
|
|
||||||
|
if op == 4: # READY
|
||||||
|
self.logins = body.get("logins", [])
|
||||||
|
self.ready_received = True
|
||||||
|
|
||||||
|
# 输出连接成功的bot信息
|
||||||
|
if self.logins:
|
||||||
|
for i, login in enumerate(self.logins):
|
||||||
|
platform = login.get("platform", "")
|
||||||
|
user = login.get("user", {})
|
||||||
|
user_id = user.get("id", "")
|
||||||
|
user_name = user.get("name", "")
|
||||||
|
logger.info(
|
||||||
|
f"Satori 连接成功 - Bot {i + 1}: platform={platform}, user_id={user_id}, user_name={user_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if "sn" in body:
|
||||||
|
self.sequence = body["sn"]
|
||||||
|
|
||||||
|
elif op == 2: # PONG
|
||||||
|
pass
|
||||||
|
|
||||||
|
elif op == 0: # EVENT
|
||||||
|
await self.handle_event(body)
|
||||||
|
if "sn" in body:
|
||||||
|
self.sequence = body["sn"]
|
||||||
|
|
||||||
|
elif op == 5: # META
|
||||||
|
if "sn" in body:
|
||||||
|
self.sequence = body["sn"]
|
||||||
|
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
logger.error(f"解析 WebSocket 消息失败: {e}, 消息内容: {message}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理 WebSocket 消息异常: {e}")
|
||||||
|
|
||||||
|
async def handle_event(self, event_data: dict):
|
||||||
|
try:
|
||||||
|
event_type = event_data.get("type")
|
||||||
|
sn = event_data.get("sn")
|
||||||
|
if sn:
|
||||||
|
self.sequence = sn
|
||||||
|
|
||||||
|
if event_type == "message-created":
|
||||||
|
message = event_data.get("message", {})
|
||||||
|
user = event_data.get("user", {})
|
||||||
|
channel = event_data.get("channel", {})
|
||||||
|
guild = event_data.get("guild")
|
||||||
|
login = event_data.get("login", {})
|
||||||
|
timestamp = event_data.get("timestamp")
|
||||||
|
|
||||||
|
if user.get("id") == login.get("user", {}).get("id"):
|
||||||
|
return
|
||||||
|
|
||||||
|
abm = await self.convert_satori_message(
|
||||||
|
message, user, channel, guild, login, timestamp
|
||||||
|
)
|
||||||
|
if abm:
|
||||||
|
await self.handle_msg(abm)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理事件失败: {e}")
|
||||||
|
|
||||||
|
async def convert_satori_message(
|
||||||
|
self,
|
||||||
|
message: dict,
|
||||||
|
user: dict,
|
||||||
|
channel: dict,
|
||||||
|
guild: Optional[dict],
|
||||||
|
login: dict,
|
||||||
|
timestamp: Optional[int] = None,
|
||||||
|
) -> Optional[AstrBotMessage]:
|
||||||
|
try:
|
||||||
|
abm = AstrBotMessage()
|
||||||
|
abm.message_id = message.get("id", "")
|
||||||
|
abm.raw_message = {
|
||||||
|
"message": message,
|
||||||
|
"user": user,
|
||||||
|
"channel": channel,
|
||||||
|
"guild": guild,
|
||||||
|
"login": login,
|
||||||
|
}
|
||||||
|
|
||||||
|
if guild and guild.get("id"):
|
||||||
|
abm.type = MessageType.GROUP_MESSAGE
|
||||||
|
abm.group_id = guild.get("id", "")
|
||||||
|
abm.session_id = channel.get("id", "")
|
||||||
|
else:
|
||||||
|
abm.type = MessageType.FRIEND_MESSAGE
|
||||||
|
abm.session_id = channel.get("id", "")
|
||||||
|
|
||||||
|
abm.sender = MessageMember(
|
||||||
|
user_id=user.get("id", ""),
|
||||||
|
nickname=user.get("nick", user.get("name", "")),
|
||||||
|
)
|
||||||
|
|
||||||
|
abm.self_id = login.get("user", {}).get("id", "")
|
||||||
|
|
||||||
|
# 消息链
|
||||||
|
abm.message = []
|
||||||
|
|
||||||
|
content = message.get("content", "")
|
||||||
|
|
||||||
|
quote = message.get("quote")
|
||||||
|
content_for_parsing = content # 副本
|
||||||
|
|
||||||
|
# 提取<quote>标签
|
||||||
|
if "<quote" in content:
|
||||||
|
try:
|
||||||
|
quote_info = await self._extract_quote_element(content)
|
||||||
|
if quote_info:
|
||||||
|
quote = quote_info["quote"]
|
||||||
|
content_for_parsing = quote_info["content_without_quote"]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"解析<quote>标签时发生错误: {e}, 错误内容: {content}")
|
||||||
|
|
||||||
|
if quote:
|
||||||
|
# 引用消息
|
||||||
|
quote_abm = await self._convert_quote_message(quote)
|
||||||
|
if quote_abm:
|
||||||
|
sender_id = quote_abm.sender.user_id
|
||||||
|
if isinstance(sender_id, str) and sender_id.isdigit():
|
||||||
|
sender_id = int(sender_id)
|
||||||
|
elif not isinstance(sender_id, int):
|
||||||
|
sender_id = 0 # 默认值
|
||||||
|
|
||||||
|
reply_component = Reply(
|
||||||
|
id=quote_abm.message_id,
|
||||||
|
chain=quote_abm.message,
|
||||||
|
sender_id=quote_abm.sender.user_id,
|
||||||
|
sender_nickname=quote_abm.sender.nickname,
|
||||||
|
time=quote_abm.timestamp,
|
||||||
|
message_str=quote_abm.message_str,
|
||||||
|
text=quote_abm.message_str,
|
||||||
|
qq=sender_id,
|
||||||
|
)
|
||||||
|
abm.message.append(reply_component)
|
||||||
|
|
||||||
|
# 解析消息内容
|
||||||
|
content_elements = await self.parse_satori_elements(content_for_parsing)
|
||||||
|
abm.message.extend(content_elements)
|
||||||
|
|
||||||
|
abm.message_str = ""
|
||||||
|
for comp in content_elements:
|
||||||
|
if isinstance(comp, Plain):
|
||||||
|
abm.message_str += comp.text
|
||||||
|
|
||||||
|
# 优先使用Satori事件中的时间戳
|
||||||
|
if timestamp is not None:
|
||||||
|
abm.timestamp = timestamp
|
||||||
|
else:
|
||||||
|
abm.timestamp = int(time.time())
|
||||||
|
|
||||||
|
return abm
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"转换 Satori 消息失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _extract_namespace_prefixes(self, content: str) -> set:
|
||||||
|
"""提取XML内容中的命名空间前缀"""
|
||||||
|
prefixes = set()
|
||||||
|
|
||||||
|
# 查找所有标签
|
||||||
|
i = 0
|
||||||
|
while i < len(content):
|
||||||
|
# 查找开始标签
|
||||||
|
if content[i] == "<" and i + 1 < len(content) and content[i + 1] != "/":
|
||||||
|
# 找到标签结束位置
|
||||||
|
tag_end = content.find(">", i)
|
||||||
|
if tag_end != -1:
|
||||||
|
# 提取标签内容
|
||||||
|
tag_content = content[i + 1 : tag_end]
|
||||||
|
# 检查是否有命名空间前缀
|
||||||
|
if ":" in tag_content and "xmlns:" not in tag_content:
|
||||||
|
# 分割标签名
|
||||||
|
parts = tag_content.split()
|
||||||
|
if parts:
|
||||||
|
tag_name = parts[0]
|
||||||
|
if ":" in tag_name:
|
||||||
|
prefix = tag_name.split(":")[0]
|
||||||
|
# 确保是有效的命名空间前缀
|
||||||
|
if (
|
||||||
|
prefix.isalnum()
|
||||||
|
or prefix.replace("_", "").isalnum()
|
||||||
|
):
|
||||||
|
prefixes.add(prefix)
|
||||||
|
i = tag_end + 1
|
||||||
|
else:
|
||||||
|
i += 1
|
||||||
|
# 查找结束标签
|
||||||
|
elif content[i] == "<" and i + 1 < len(content) and content[i + 1] == "/":
|
||||||
|
# 找到标签结束位置
|
||||||
|
tag_end = content.find(">", i)
|
||||||
|
if tag_end != -1:
|
||||||
|
# 提取标签内容
|
||||||
|
tag_content = content[i + 2 : tag_end]
|
||||||
|
# 检查是否有命名空间前缀
|
||||||
|
if ":" in tag_content:
|
||||||
|
prefix = tag_content.split(":")[0]
|
||||||
|
# 确保是有效的命名空间前缀
|
||||||
|
if prefix.isalnum() or prefix.replace("_", "").isalnum():
|
||||||
|
prefixes.add(prefix)
|
||||||
|
i = tag_end + 1
|
||||||
|
else:
|
||||||
|
i += 1
|
||||||
|
else:
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
return prefixes
|
||||||
|
|
||||||
|
async def _extract_quote_element(self, content: str) -> Optional[dict]:
|
||||||
|
"""提取<quote>标签信息"""
|
||||||
|
try:
|
||||||
|
# 处理命名空间前缀问题
|
||||||
|
processed_content = content
|
||||||
|
if ":" in content and not content.startswith("<root"):
|
||||||
|
prefixes = self._extract_namespace_prefixes(content)
|
||||||
|
|
||||||
|
# 构建命名空间声明
|
||||||
|
ns_declarations = " ".join(
|
||||||
|
[
|
||||||
|
f'xmlns:{prefix}="http://temp.uri/{prefix}"'
|
||||||
|
for prefix in prefixes
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 包装内容
|
||||||
|
processed_content = f"<root {ns_declarations}>{content}</root>"
|
||||||
|
elif not content.startswith("<root"):
|
||||||
|
processed_content = f"<root>{content}</root>"
|
||||||
|
else:
|
||||||
|
processed_content = content
|
||||||
|
|
||||||
|
root = ET.fromstring(processed_content)
|
||||||
|
|
||||||
|
# 查找<quote>标签
|
||||||
|
quote_element = None
|
||||||
|
for elem in root.iter():
|
||||||
|
tag_name = elem.tag
|
||||||
|
if "}" in tag_name:
|
||||||
|
tag_name = tag_name.split("}")[1]
|
||||||
|
if tag_name.lower() == "quote":
|
||||||
|
quote_element = elem
|
||||||
|
break
|
||||||
|
|
||||||
|
if quote_element is not None:
|
||||||
|
# 提取quote标签的属性
|
||||||
|
quote_id = quote_element.get("id", "")
|
||||||
|
|
||||||
|
# 提取<quote>标签内部的内容
|
||||||
|
inner_content = ""
|
||||||
|
if quote_element.text:
|
||||||
|
inner_content += quote_element.text
|
||||||
|
for child in quote_element:
|
||||||
|
inner_content += ET.tostring(
|
||||||
|
child, encoding="unicode", method="xml"
|
||||||
|
)
|
||||||
|
if child.tail:
|
||||||
|
inner_content += child.tail
|
||||||
|
|
||||||
|
# 构造移除了<quote>标签的内容
|
||||||
|
content_without_quote = content.replace(
|
||||||
|
ET.tostring(quote_element, encoding="unicode", method="xml"), ""
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"quote": {"id": quote_id, "content": inner_content},
|
||||||
|
"content_without_quote": content_without_quote,
|
||||||
|
}
|
||||||
|
|
||||||
|
return None
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"提取<quote>标签时发生错误: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _convert_quote_message(self, quote: dict) -> Optional[AstrBotMessage]:
|
||||||
|
"""转换引用消息"""
|
||||||
|
try:
|
||||||
|
quote_abm = AstrBotMessage()
|
||||||
|
quote_abm.message_id = quote.get("id", "")
|
||||||
|
|
||||||
|
# 解析引用消息的发送者
|
||||||
|
quote_author = quote.get("author", {})
|
||||||
|
if quote_author:
|
||||||
|
quote_abm.sender = MessageMember(
|
||||||
|
user_id=quote_author.get("id", ""),
|
||||||
|
nickname=quote_author.get("nick", quote_author.get("name", "")),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 如果没有作者信息,使用默认值
|
||||||
|
quote_abm.sender = MessageMember(
|
||||||
|
user_id=quote.get("user_id", ""),
|
||||||
|
nickname="内容",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 解析引用消息内容
|
||||||
|
quote_content = quote.get("content", "")
|
||||||
|
quote_abm.message = await self.parse_satori_elements(quote_content)
|
||||||
|
|
||||||
|
quote_abm.message_str = ""
|
||||||
|
for comp in quote_abm.message:
|
||||||
|
if isinstance(comp, Plain):
|
||||||
|
quote_abm.message_str += comp.text
|
||||||
|
|
||||||
|
quote_abm.timestamp = int(quote.get("timestamp", time.time()))
|
||||||
|
|
||||||
|
# 如果没有任何内容,使用默认文本
|
||||||
|
if not quote_abm.message_str.strip():
|
||||||
|
quote_abm.message_str = "[引用消息]"
|
||||||
|
|
||||||
|
return quote_abm
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"转换引用消息失败: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def parse_satori_elements(self, content: str) -> list:
|
||||||
|
"""解析 Satori 消息元素"""
|
||||||
|
elements = []
|
||||||
|
|
||||||
|
if not content:
|
||||||
|
return elements
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 处理命名空间前缀问题
|
||||||
|
processed_content = content
|
||||||
|
if ":" in content and not content.startswith("<root"):
|
||||||
|
prefixes = self._extract_namespace_prefixes(content)
|
||||||
|
|
||||||
|
# 构建命名空间声明
|
||||||
|
ns_declarations = " ".join(
|
||||||
|
[
|
||||||
|
f'xmlns:{prefix}="http://temp.uri/{prefix}"'
|
||||||
|
for prefix in prefixes
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# 包装内容
|
||||||
|
processed_content = f"<root {ns_declarations}>{content}</root>"
|
||||||
|
elif not content.startswith("<root"):
|
||||||
|
processed_content = f"<root>{content}</root>"
|
||||||
|
else:
|
||||||
|
processed_content = content
|
||||||
|
|
||||||
|
root = ET.fromstring(processed_content)
|
||||||
|
await self._parse_xml_node(root, elements)
|
||||||
|
except ET.ParseError as e:
|
||||||
|
logger.error(f"解析 Satori 元素时发生解析错误: {e}, 错误内容: {content}")
|
||||||
|
# 如果解析失败,将整个内容当作纯文本
|
||||||
|
if content.strip():
|
||||||
|
elements.append(Plain(text=content))
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"解析 Satori 元素时发生未知错误: {e}")
|
||||||
|
raise e
|
||||||
|
|
||||||
|
# 如果没有解析到任何元素,将整个内容当作纯文本
|
||||||
|
if not elements and content.strip():
|
||||||
|
elements.append(Plain(text=content))
|
||||||
|
|
||||||
|
return elements
|
||||||
|
|
||||||
|
async def _parse_xml_node(self, node: ET.Element, elements: list) -> None:
|
||||||
|
"""递归解析 XML 节点"""
|
||||||
|
if node.text and node.text.strip():
|
||||||
|
elements.append(Plain(text=node.text))
|
||||||
|
|
||||||
|
for child in node:
|
||||||
|
# 获取标签名,去除命名空间前缀
|
||||||
|
tag_name = child.tag
|
||||||
|
if "}" in tag_name:
|
||||||
|
tag_name = tag_name.split("}")[1]
|
||||||
|
tag_name = tag_name.lower()
|
||||||
|
|
||||||
|
attrs = child.attrib
|
||||||
|
|
||||||
|
if tag_name == "at":
|
||||||
|
user_id = attrs.get("id") or attrs.get("name", "")
|
||||||
|
elements.append(At(qq=user_id, name=user_id))
|
||||||
|
|
||||||
|
elif tag_name in ("img", "image"):
|
||||||
|
src = attrs.get("src", "")
|
||||||
|
if not src:
|
||||||
|
continue
|
||||||
|
elements.append(Image(file=src))
|
||||||
|
|
||||||
|
elif tag_name == "file":
|
||||||
|
src = attrs.get("src", "")
|
||||||
|
name = attrs.get("name", "文件")
|
||||||
|
if src:
|
||||||
|
elements.append(File(name=name, file=src))
|
||||||
|
|
||||||
|
elif tag_name in ("audio", "record"):
|
||||||
|
src = attrs.get("src", "")
|
||||||
|
if not src:
|
||||||
|
continue
|
||||||
|
elements.append(Record(file=src))
|
||||||
|
|
||||||
|
elif tag_name == "quote":
|
||||||
|
# quote标签已经被特殊处理
|
||||||
|
pass
|
||||||
|
|
||||||
|
elif tag_name == "face":
|
||||||
|
face_id = attrs.get("id", "")
|
||||||
|
face_name = attrs.get("name", "")
|
||||||
|
face_type = attrs.get("type", "")
|
||||||
|
|
||||||
|
if face_name:
|
||||||
|
elements.append(Plain(text=f"[表情:{face_name}]"))
|
||||||
|
elif face_id and face_type:
|
||||||
|
elements.append(Plain(text=f"[表情ID:{face_id},类型:{face_type}]"))
|
||||||
|
elif face_id:
|
||||||
|
elements.append(Plain(text=f"[表情ID:{face_id}]"))
|
||||||
|
else:
|
||||||
|
elements.append(Plain(text="[表情]"))
|
||||||
|
|
||||||
|
elif tag_name == "ark":
|
||||||
|
# 作为纯文本添加到消息链中
|
||||||
|
data = attrs.get("data", "")
|
||||||
|
if data:
|
||||||
|
import html
|
||||||
|
|
||||||
|
decoded_data = html.unescape(data)
|
||||||
|
elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
|
||||||
|
else:
|
||||||
|
elements.append(Plain(text="[ARK卡片]"))
|
||||||
|
|
||||||
|
elif tag_name == "json":
|
||||||
|
# JSON标签 视为ARK卡片消息
|
||||||
|
data = attrs.get("data", "")
|
||||||
|
if data:
|
||||||
|
import html
|
||||||
|
|
||||||
|
decoded_data = html.unescape(data)
|
||||||
|
elements.append(Plain(text=f"[ARK卡片数据: {decoded_data}]"))
|
||||||
|
else:
|
||||||
|
elements.append(Plain(text="[JSON卡片]"))
|
||||||
|
|
||||||
|
else:
|
||||||
|
# 未知标签,递归处理其内容
|
||||||
|
if child.text and child.text.strip():
|
||||||
|
elements.append(Plain(text=child.text))
|
||||||
|
await self._parse_xml_node(child, elements)
|
||||||
|
|
||||||
|
# 处理标签后的文本
|
||||||
|
if child.tail and child.tail.strip():
|
||||||
|
elements.append(Plain(text=child.tail))
|
||||||
|
|
||||||
|
async def handle_msg(self, message: AstrBotMessage):
|
||||||
|
from .satori_event import SatoriPlatformEvent
|
||||||
|
|
||||||
|
message_event = SatoriPlatformEvent(
|
||||||
|
message_str=message.message_str,
|
||||||
|
message_obj=message,
|
||||||
|
platform_meta=self.meta(),
|
||||||
|
session_id=message.session_id,
|
||||||
|
adapter=self,
|
||||||
|
)
|
||||||
|
self.commit_event(message_event)
|
||||||
|
|
||||||
|
async def send_http_request(
|
||||||
|
self,
|
||||||
|
method: str,
|
||||||
|
path: str,
|
||||||
|
data: dict | None = None,
|
||||||
|
platform: str | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
) -> dict:
|
||||||
|
if not self.session:
|
||||||
|
raise Exception("HTTP session 未初始化")
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.token:
|
||||||
|
headers["Authorization"] = f"Bearer {self.token}"
|
||||||
|
|
||||||
|
if platform and user_id:
|
||||||
|
headers["satori-platform"] = platform
|
||||||
|
headers["satori-user-id"] = user_id
|
||||||
|
elif self.logins:
|
||||||
|
current_login = self.logins[0]
|
||||||
|
headers["satori-platform"] = current_login.get("platform", "")
|
||||||
|
user = current_login.get("user", {})
|
||||||
|
headers["satori-user-id"] = user.get("id", "") if user else ""
|
||||||
|
|
||||||
|
if not path.startswith("/"):
|
||||||
|
path = "/" + path
|
||||||
|
|
||||||
|
# 使用新的API地址配置
|
||||||
|
url = f"{self.api_base_url.rstrip('/')}{path}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with self.session.request(
|
||||||
|
method, url, json=data, headers=headers
|
||||||
|
) as response:
|
||||||
|
if response.status == 200:
|
||||||
|
result = await response.json()
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
return {}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Satori HTTP 请求异常: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
self.running = False
|
||||||
|
|
||||||
|
if self.heartbeat_task:
|
||||||
|
self.heartbeat_task.cancel()
|
||||||
|
|
||||||
|
if self.ws:
|
||||||
|
try:
|
||||||
|
await self.ws.close()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Satori WebSocket 关闭异常: {e}")
|
||||||
|
|
||||||
|
if self.session:
|
||||||
|
await self.session.close()
|
||||||
@@ -0,0 +1,230 @@
|
|||||||
|
from typing import TYPE_CHECKING
|
||||||
|
from astrbot.api import logger
|
||||||
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
|
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||||
|
from astrbot.api.message_components import Plain, Image, At, File, Record
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .satori_adapter import SatoriPlatformAdapter
|
||||||
|
|
||||||
|
|
||||||
|
class SatoriPlatformEvent(AstrMessageEvent):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message_str: str,
|
||||||
|
message_obj: AstrBotMessage,
|
||||||
|
platform_meta: PlatformMetadata,
|
||||||
|
session_id: str,
|
||||||
|
adapter: "SatoriPlatformAdapter",
|
||||||
|
):
|
||||||
|
# 更新平台元数据
|
||||||
|
if adapter and hasattr(adapter, "logins") and adapter.logins:
|
||||||
|
current_login = adapter.logins[0]
|
||||||
|
platform_name = current_login.get("platform", "satori")
|
||||||
|
user = current_login.get("user", {})
|
||||||
|
user_id = user.get("id", "") if user else ""
|
||||||
|
if not platform_meta.id and user_id:
|
||||||
|
platform_meta.id = f"{platform_name}({user_id})"
|
||||||
|
|
||||||
|
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||||
|
self.adapter = adapter
|
||||||
|
self.platform = None
|
||||||
|
self.user_id = None
|
||||||
|
if (
|
||||||
|
hasattr(message_obj, "raw_message")
|
||||||
|
and message_obj.raw_message
|
||||||
|
and isinstance(message_obj.raw_message, dict)
|
||||||
|
):
|
||||||
|
login = message_obj.raw_message.get("login", {})
|
||||||
|
self.platform = login.get("platform")
|
||||||
|
user = login.get("user", {})
|
||||||
|
self.user_id = user.get("id") if user else None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def send_with_adapter(
|
||||||
|
cls, adapter: "SatoriPlatformAdapter", message: MessageChain, session_id: str
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
content_parts = []
|
||||||
|
|
||||||
|
for component in message.chain:
|
||||||
|
if isinstance(component, Plain):
|
||||||
|
text = (
|
||||||
|
component.text.replace("&", "&")
|
||||||
|
.replace("<", "<")
|
||||||
|
.replace(">", ">")
|
||||||
|
)
|
||||||
|
content_parts.append(text)
|
||||||
|
|
||||||
|
elif isinstance(component, At):
|
||||||
|
if component.qq:
|
||||||
|
content_parts.append(f'<at id="{component.qq}"/>')
|
||||||
|
elif component.name:
|
||||||
|
content_parts.append(f'<at name="{component.name}"/>')
|
||||||
|
|
||||||
|
elif isinstance(component, Image):
|
||||||
|
try:
|
||||||
|
image_base64 = await component.convert_to_base64()
|
||||||
|
if image_base64:
|
||||||
|
content_parts.append(
|
||||||
|
f'<img src="data:image/jpeg;base64,{image_base64}"/>'
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"图片转换为base64失败: {e}")
|
||||||
|
|
||||||
|
elif isinstance(component, File):
|
||||||
|
content_parts.append(
|
||||||
|
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
|
||||||
|
)
|
||||||
|
|
||||||
|
elif isinstance(component, Record):
|
||||||
|
try:
|
||||||
|
record_base64 = await component.convert_to_base64()
|
||||||
|
if record_base64:
|
||||||
|
content_parts.append(
|
||||||
|
f'<audio src="data:audio/wav;base64,{record_base64}"/>'
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"语音转换为base64失败: {e}")
|
||||||
|
|
||||||
|
content = "".join(content_parts)
|
||||||
|
channel_id = session_id
|
||||||
|
data = {"channel_id": channel_id, "content": content}
|
||||||
|
|
||||||
|
platform = None
|
||||||
|
user_id = None
|
||||||
|
|
||||||
|
if hasattr(adapter, "logins") and adapter.logins:
|
||||||
|
current_login = adapter.logins[0]
|
||||||
|
platform = current_login.get("platform", "")
|
||||||
|
user = current_login.get("user", {})
|
||||||
|
user_id = user.get("id", "") if user else ""
|
||||||
|
|
||||||
|
result = await adapter.send_http_request(
|
||||||
|
"POST", "/message.create", data, platform, user_id
|
||||||
|
)
|
||||||
|
if result:
|
||||||
|
return result
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Satori 消息发送异常: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def send(self, message: MessageChain):
|
||||||
|
platform = getattr(self, "platform", None)
|
||||||
|
user_id = getattr(self, "user_id", None)
|
||||||
|
|
||||||
|
if not platform or not user_id:
|
||||||
|
if hasattr(self.adapter, "logins") and self.adapter.logins:
|
||||||
|
current_login = self.adapter.logins[0]
|
||||||
|
platform = current_login.get("platform", "")
|
||||||
|
user = current_login.get("user", {})
|
||||||
|
user_id = user.get("id", "") if user else ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
content_parts = []
|
||||||
|
|
||||||
|
for component in message.chain:
|
||||||
|
if isinstance(component, Plain):
|
||||||
|
text = (
|
||||||
|
component.text.replace("&", "&")
|
||||||
|
.replace("<", "<")
|
||||||
|
.replace(">", ">")
|
||||||
|
)
|
||||||
|
content_parts.append(text)
|
||||||
|
|
||||||
|
elif isinstance(component, At):
|
||||||
|
if component.qq:
|
||||||
|
content_parts.append(f'<at id="{component.qq}"/>')
|
||||||
|
elif component.name:
|
||||||
|
content_parts.append(f'<at name="{component.name}"/>')
|
||||||
|
|
||||||
|
elif isinstance(component, Image):
|
||||||
|
try:
|
||||||
|
image_base64 = await component.convert_to_base64()
|
||||||
|
if image_base64:
|
||||||
|
content_parts.append(
|
||||||
|
f'<img src="data:image/jpeg;base64,{image_base64}"/>'
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"图片转换为base64失败: {e}")
|
||||||
|
|
||||||
|
elif isinstance(component, File):
|
||||||
|
content_parts.append(
|
||||||
|
f'<file src="{component.file}" name="{component.name or "文件"}"/>'
|
||||||
|
)
|
||||||
|
|
||||||
|
elif isinstance(component, Record):
|
||||||
|
try:
|
||||||
|
record_base64 = await component.convert_to_base64()
|
||||||
|
if record_base64:
|
||||||
|
content_parts.append(
|
||||||
|
f'<audio src="data:audio/wav;base64,{record_base64}"/>'
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"语音转换为base64失败: {e}")
|
||||||
|
|
||||||
|
content = "".join(content_parts)
|
||||||
|
channel_id = self.session_id
|
||||||
|
data = {"channel_id": channel_id, "content": content}
|
||||||
|
|
||||||
|
result = await self.adapter.send_http_request(
|
||||||
|
"POST", "/message.create", data, platform, user_id
|
||||||
|
)
|
||||||
|
if not result:
|
||||||
|
logger.error("Satori 消息发送失败")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Satori 消息发送异常: {e}")
|
||||||
|
|
||||||
|
await super().send(message)
|
||||||
|
|
||||||
|
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||||
|
try:
|
||||||
|
content_parts = []
|
||||||
|
|
||||||
|
async for chain in generator:
|
||||||
|
if isinstance(chain, MessageChain):
|
||||||
|
if chain.type == "break":
|
||||||
|
if content_parts:
|
||||||
|
content = "".join(content_parts)
|
||||||
|
temp_chain = MessageChain([Plain(text=content)])
|
||||||
|
await self.send(temp_chain)
|
||||||
|
content_parts = []
|
||||||
|
continue
|
||||||
|
|
||||||
|
for component in chain.chain:
|
||||||
|
if isinstance(component, Plain):
|
||||||
|
content_parts.append(component.text)
|
||||||
|
elif isinstance(component, Image):
|
||||||
|
if content_parts:
|
||||||
|
content = "".join(content_parts)
|
||||||
|
temp_chain = MessageChain([Plain(text=content)])
|
||||||
|
await self.send(temp_chain)
|
||||||
|
content_parts = []
|
||||||
|
try:
|
||||||
|
image_base64 = await component.convert_to_base64()
|
||||||
|
if image_base64:
|
||||||
|
img_chain = MessageChain(
|
||||||
|
[
|
||||||
|
Plain(
|
||||||
|
text=f'<img src="data:image/jpeg;base64,{image_base64}"/>'
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
await self.send(img_chain)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"图片转换为base64失败: {e}")
|
||||||
|
else:
|
||||||
|
content_parts.append(str(component))
|
||||||
|
|
||||||
|
if content_parts:
|
||||||
|
content = "".join(content_parts)
|
||||||
|
temp_chain = MessageChain([Plain(text=content)])
|
||||||
|
await self.send(temp_chain)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Satori 流式消息发送异常: {e}")
|
||||||
|
|
||||||
|
return await super().send_streaming(generator, use_fallback)
|
||||||
@@ -308,7 +308,9 @@ class SlackAdapter(Platform):
|
|||||||
base64_content = base64.b64encode(content).decode("utf-8")
|
base64_content = base64.b64encode(content).decode("utf-8")
|
||||||
return base64_content
|
return base64_content
|
||||||
else:
|
else:
|
||||||
logger.error(f"Failed to download slack file: {resp.status} {await resp.text()}")
|
logger.error(
|
||||||
|
f"Failed to download slack file: {resp.status} {await resp.text()}"
|
||||||
|
)
|
||||||
raise Exception(f"下载文件失败: {resp.status}")
|
raise Exception(f"下载文件失败: {resp.status}")
|
||||||
|
|
||||||
async def run(self) -> Awaitable[Any]:
|
async def run(self) -> Awaitable[Any]:
|
||||||
|
|||||||
@@ -75,7 +75,13 @@ class SlackMessageEvent(AstrMessageEvent):
|
|||||||
"text": {"type": "mrkdwn", "text": "文件上传失败"},
|
"text": {"type": "mrkdwn", "text": "文件上传失败"},
|
||||||
}
|
}
|
||||||
file_url = response["files"][0]["permalink"]
|
file_url = response["files"][0]["permalink"]
|
||||||
return {"type": "section", "text": {"type": "mrkdwn", "text": f"文件: <{file_url}|{segment.name or '文件'}>"}}
|
return {
|
||||||
|
"type": "section",
|
||||||
|
"text": {
|
||||||
|
"type": "mrkdwn",
|
||||||
|
"text": f"文件: <{file_url}|{segment.name or '文件'}>",
|
||||||
|
},
|
||||||
|
}
|
||||||
else:
|
else:
|
||||||
return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}}
|
return {"type": "section", "text": {"type": "mrkdwn", "text": str(segment)}}
|
||||||
|
|
||||||
|
|||||||
@@ -183,7 +183,6 @@ class TelegramPlatformAdapter(Platform):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
if not re.match(r"^[a-z0-9_]+$", cmd_name) or len(cmd_name) > 32:
|
if not re.match(r"^[a-z0-9_]+$", cmd_name) or len(cmd_name) > 32:
|
||||||
logger.debug(f"跳过无法注册的命令: {cmd_name}")
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Build description.
|
# Build description.
|
||||||
|
|||||||
@@ -66,7 +66,9 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def send_with_client(cls, client: ExtBot, message: MessageChain, user_name: str):
|
async def send_with_client(
|
||||||
|
cls, client: ExtBot, message: MessageChain, user_name: str
|
||||||
|
):
|
||||||
image_path = None
|
image_path = None
|
||||||
|
|
||||||
has_reply = False
|
has_reply = False
|
||||||
@@ -216,7 +218,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
try:
|
try:
|
||||||
msg = await self.client.send_message(text=delta, **payload)
|
msg = await self.client.send_message(text=delta, **payload)
|
||||||
current_content = delta
|
current_content = delta
|
||||||
delta = ""
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"发送消息失败(streaming): {e!s}")
|
logger.warning(f"发送消息失败(streaming): {e!s}")
|
||||||
message_id = msg.message_id
|
message_id = msg.message_id
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
class WebChatQueueMgr:
|
class WebChatQueueMgr:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.queues = {}
|
self.queues = {}
|
||||||
@@ -30,4 +31,5 @@ class WebChatQueueMgr:
|
|||||||
"""Check if a queue exists for the given conversation ID"""
|
"""Check if a queue exists for the given conversation ID"""
|
||||||
return conversation_id in self.queues
|
return conversation_id in self.queues
|
||||||
|
|
||||||
|
|
||||||
webchat_queue_mgr = WebChatQueueMgr()
|
webchat_queue_mgr = WebChatQueueMgr()
|
||||||
|
|||||||
@@ -213,10 +213,10 @@ class WeChatPadProAdapter(Platform):
|
|||||||
def _extract_auth_key(self, data):
|
def _extract_auth_key(self, data):
|
||||||
"""Helper method to extract auth_key from response data."""
|
"""Helper method to extract auth_key from response data."""
|
||||||
if isinstance(data, dict):
|
if isinstance(data, dict):
|
||||||
auth_keys = data.get("authKeys") # 新接口
|
auth_keys = data.get("authKeys") # 新接口
|
||||||
if isinstance(auth_keys, list) and auth_keys:
|
if isinstance(auth_keys, list) and auth_keys:
|
||||||
return auth_keys[0]
|
return auth_keys[0]
|
||||||
elif isinstance(data, list) and data: # 旧接口
|
elif isinstance(data, list) and data: # 旧接口
|
||||||
return data[0]
|
return data[0]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -234,7 +234,9 @@ class WeChatPadProAdapter(Platform):
|
|||||||
try:
|
try:
|
||||||
async with session.post(url, params=params, json=payload) as response:
|
async with session.post(url, params=params, json=payload) as response:
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
logger.error(f"生成授权码失败: {response.status}, {await response.text()}")
|
logger.error(
|
||||||
|
f"生成授权码失败: {response.status}, {await response.text()}"
|
||||||
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
response_data = await response.json()
|
response_data = await response.json()
|
||||||
@@ -245,7 +247,9 @@ class WeChatPadProAdapter(Platform):
|
|||||||
if self.auth_key:
|
if self.auth_key:
|
||||||
logger.info("成功获取授权码")
|
logger.info("成功获取授权码")
|
||||||
else:
|
else:
|
||||||
logger.error(f"生成授权码成功但未找到授权码: {response_data}")
|
logger.error(
|
||||||
|
f"生成授权码成功但未找到授权码: {response_data}"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.error(f"生成授权码失败: {response_data}")
|
logger.error(f"生成授权码失败: {response_data}")
|
||||||
except aiohttp.ClientConnectorError as e:
|
except aiohttp.ClientConnectorError as e:
|
||||||
|
|||||||
@@ -48,7 +48,12 @@ class WeChatKF(BaseWeChatAPI):
|
|||||||
注意:可能会出现返回条数少于limit的情况,需结合返回的has_more字段判断是否继续请求。
|
注意:可能会出现返回条数少于limit的情况,需结合返回的has_more字段判断是否继续请求。
|
||||||
:return: 接口调用结果
|
:return: 接口调用结果
|
||||||
"""
|
"""
|
||||||
data = {"token": token, "cursor": cursor, "limit": limit, "open_kfid": open_kfid}
|
data = {
|
||||||
|
"token": token,
|
||||||
|
"cursor": cursor,
|
||||||
|
"limit": limit,
|
||||||
|
"open_kfid": open_kfid,
|
||||||
|
}
|
||||||
return self._post("kf/sync_msg", data=data)
|
return self._post("kf/sync_msg", data=data)
|
||||||
|
|
||||||
def get_service_state(self, open_kfid, external_userid):
|
def get_service_state(self, open_kfid, external_userid):
|
||||||
@@ -72,7 +77,9 @@ class WeChatKF(BaseWeChatAPI):
|
|||||||
}
|
}
|
||||||
return self._post("kf/service_state/get", data=data)
|
return self._post("kf/service_state/get", data=data)
|
||||||
|
|
||||||
def trans_service_state(self, open_kfid, external_userid, service_state, servicer_userid=""):
|
def trans_service_state(
|
||||||
|
self, open_kfid, external_userid, service_state, servicer_userid=""
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
变更会话状态
|
变更会话状态
|
||||||
|
|
||||||
@@ -180,7 +187,9 @@ class WeChatKF(BaseWeChatAPI):
|
|||||||
"""
|
"""
|
||||||
return self._get("kf/customer/get_upgrade_service_config")
|
return self._get("kf/customer/get_upgrade_service_config")
|
||||||
|
|
||||||
def upgrade_service(self, open_kfid, external_userid, service_type, member=None, groupchat=None):
|
def upgrade_service(
|
||||||
|
self, open_kfid, external_userid, service_type, member=None, groupchat=None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
为客户升级为专员或客户群服务
|
为客户升级为专员或客户群服务
|
||||||
|
|
||||||
@@ -246,7 +255,9 @@ class WeChatKF(BaseWeChatAPI):
|
|||||||
data = {"open_kfid": open_kfid, "start_time": start_time, "end_time": end_time}
|
data = {"open_kfid": open_kfid, "start_time": start_time, "end_time": end_time}
|
||||||
return self._post("kf/get_corp_statistic", data=data)
|
return self._post("kf/get_corp_statistic", data=data)
|
||||||
|
|
||||||
def get_servicer_statistic(self, start_time, end_time, open_kfid=None, servicer_userid=None):
|
def get_servicer_statistic(
|
||||||
|
self, start_time, end_time, open_kfid=None, servicer_userid=None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
获取「客户数据统计」接待人员明细数据
|
获取「客户数据统计」接待人员明细数据
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from optionaldict import optionaldict
|
|||||||
|
|
||||||
from wechatpy.client.api.base import BaseWeChatAPI
|
from wechatpy.client.api.base import BaseWeChatAPI
|
||||||
|
|
||||||
|
|
||||||
class WeChatKFMessage(BaseWeChatAPI):
|
class WeChatKFMessage(BaseWeChatAPI):
|
||||||
"""
|
"""
|
||||||
发送微信客服消息
|
发送微信客服消息
|
||||||
@@ -125,35 +126,55 @@ class WeChatKFMessage(BaseWeChatAPI):
|
|||||||
msg={"msgtype": "news", "link": {"link": articles_data}},
|
msg={"msgtype": "news", "link": {"link": articles_data}},
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_msgmenu(self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""):
|
def send_msgmenu(
|
||||||
|
self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""
|
||||||
|
):
|
||||||
return self.send(
|
return self.send(
|
||||||
user_id,
|
user_id,
|
||||||
open_kfid,
|
open_kfid,
|
||||||
msgid,
|
msgid,
|
||||||
msg={
|
msg={
|
||||||
"msgtype": "msgmenu",
|
"msgtype": "msgmenu",
|
||||||
"msgmenu": {"head_content": head_content, "list": menu_list, "tail_content": tail_content},
|
"msgmenu": {
|
||||||
|
"head_content": head_content,
|
||||||
|
"list": menu_list,
|
||||||
|
"tail_content": tail_content,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_location(self, user_id, open_kfid, name, address, latitude, longitude, msgid=""):
|
def send_location(
|
||||||
|
self, user_id, open_kfid, name, address, latitude, longitude, msgid=""
|
||||||
|
):
|
||||||
return self.send(
|
return self.send(
|
||||||
user_id,
|
user_id,
|
||||||
open_kfid,
|
open_kfid,
|
||||||
msgid,
|
msgid,
|
||||||
msg={
|
msg={
|
||||||
"msgtype": "location",
|
"msgtype": "location",
|
||||||
"msgmenu": {"name": name, "address": address, "latitude": latitude, "longitude": longitude},
|
"msgmenu": {
|
||||||
|
"name": name,
|
||||||
|
"address": address,
|
||||||
|
"latitude": latitude,
|
||||||
|
"longitude": longitude,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_miniprogram(self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""):
|
def send_miniprogram(
|
||||||
|
self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""
|
||||||
|
):
|
||||||
return self.send(
|
return self.send(
|
||||||
user_id,
|
user_id,
|
||||||
open_kfid,
|
open_kfid,
|
||||||
msgid,
|
msgid,
|
||||||
msg={
|
msg={
|
||||||
"msgtype": "miniprogram",
|
"msgtype": "miniprogram",
|
||||||
"msgmenu": {"appid": appid, "title": title, "thumb_media_id": thumb_media_id, "pagepath": pagepath},
|
"msgmenu": {
|
||||||
|
"appid": appid,
|
||||||
|
"title": title,
|
||||||
|
"thumb_media_id": thumb_media_id,
|
||||||
|
"pagepath": pagepath,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -160,7 +160,9 @@ class WeixinOfficialAccountPlatformAdapter(Platform):
|
|||||||
self.wexin_event_workers[msg.id] = future
|
self.wexin_event_workers[msg.id] = future
|
||||||
await self.convert_message(msg, future)
|
await self.convert_message(msg, future)
|
||||||
# I love shield so much!
|
# I love shield so much!
|
||||||
result = await asyncio.wait_for(asyncio.shield(future), 60) # wait for 60s
|
result = await asyncio.wait_for(
|
||||||
|
asyncio.shield(future), 60
|
||||||
|
) # wait for 60s
|
||||||
logger.debug(f"Got future result: {result}")
|
logger.debug(f"Got future result: {result}")
|
||||||
self.wexin_event_workers.pop(msg.id, None)
|
self.wexin_event_workers.pop(msg.id, None)
|
||||||
return result # xml. see weixin_offacc_event.py
|
return result # xml. see weixin_offacc_event.py
|
||||||
|
|||||||
@@ -150,7 +150,6 @@ class WeixinOfficialAccountPlatformEvent(AstrMessageEvent):
|
|||||||
return
|
return
|
||||||
logger.info(f"微信公众平台上传语音返回: {response}")
|
logger.info(f"微信公众平台上传语音返回: {response}")
|
||||||
|
|
||||||
|
|
||||||
if active_send_mode:
|
if active_send_mode:
|
||||||
self.client.message.send_voice(
|
self.client.message.send_voice(
|
||||||
message_obj.sender.user_id,
|
message_obj.sender.user_id,
|
||||||
|
|||||||
@@ -65,13 +65,16 @@ class AssistantMessageSegment:
|
|||||||
role: str = "assistant"
|
role: str = "assistant"
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
ret = {
|
ret: dict[str, str | list[dict]] = {
|
||||||
"role": self.role,
|
"role": self.role,
|
||||||
}
|
}
|
||||||
if self.content:
|
if self.content:
|
||||||
ret["content"] = self.content
|
ret["content"] = self.content
|
||||||
if self.tool_calls:
|
if self.tool_calls:
|
||||||
ret["tool_calls"] = self.tool_calls
|
tool_calls_dict = [
|
||||||
|
tc if isinstance(tc, dict) else tc.to_dict() for tc in self.tool_calls
|
||||||
|
]
|
||||||
|
ret["tool_calls"] = tool_calls_dict
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@@ -117,7 +120,14 @@ class ProviderRequest:
|
|||||||
"""模型名称,为 None 时使用提供商的默认模型"""
|
"""模型名称,为 None 时使用提供商的默认模型"""
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
|
return (
|
||||||
|
f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, "
|
||||||
|
f"image_count={len(self.image_urls or [])}, "
|
||||||
|
f"func_tool={self.func_tool}, "
|
||||||
|
f"contexts={self._print_friendly_context()}, "
|
||||||
|
f"system_prompt={self.system_prompt}, "
|
||||||
|
f"conversation_id={self.conversation.cid if self.conversation else 'N/A'}, "
|
||||||
|
)
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.__repr__()
|
return self.__repr__()
|
||||||
@@ -297,6 +307,7 @@ class LLMResponse:
|
|||||||
)
|
)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RerankResult:
|
class RerankResult:
|
||||||
index: int
|
index: int
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import os
|
|||||||
import asyncio
|
import asyncio
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
from typing import Dict, List, Awaitable
|
from typing import Dict, List, Awaitable, Callable, Any
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.core import sp
|
from astrbot.core import sp
|
||||||
|
|
||||||
@@ -109,7 +109,7 @@ class FunctionToolManager:
|
|||||||
name: str,
|
name: str,
|
||||||
func_args: list,
|
func_args: list,
|
||||||
desc: str,
|
desc: str,
|
||||||
handler: Awaitable,
|
handler: Callable[..., Awaitable[Any]],
|
||||||
) -> FuncTool:
|
) -> FuncTool:
|
||||||
params = {
|
params = {
|
||||||
"type": "object", # hard-coded here
|
"type": "object", # hard-coded here
|
||||||
@@ -132,7 +132,7 @@ class FunctionToolManager:
|
|||||||
name: str,
|
name: str,
|
||||||
func_args: list,
|
func_args: list,
|
||||||
desc: str,
|
desc: str,
|
||||||
handler: Awaitable,
|
handler: Callable[..., Awaitable[Any]],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""添加函数调用工具
|
"""添加函数调用工具
|
||||||
|
|
||||||
@@ -220,7 +220,7 @@ class FunctionToolManager:
|
|||||||
name: str,
|
name: str,
|
||||||
cfg: dict,
|
cfg: dict,
|
||||||
event: asyncio.Event,
|
event: asyncio.Event,
|
||||||
ready_future: asyncio.Future = None,
|
ready_future: asyncio.Future | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
|
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -7,7 +7,13 @@ from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
|||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
|
|
||||||
from .entities import ProviderType
|
from .entities import ProviderType
|
||||||
from .provider import Provider, STTProvider, TTSProvider, EmbeddingProvider
|
from .provider import (
|
||||||
|
Provider,
|
||||||
|
STTProvider,
|
||||||
|
TTSProvider,
|
||||||
|
EmbeddingProvider,
|
||||||
|
RerankProvider,
|
||||||
|
)
|
||||||
from .register import llm_tools, provider_cls_map
|
from .register import llm_tools, provider_cls_map
|
||||||
from ..persona_mgr import PersonaManager
|
from ..persona_mgr import PersonaManager
|
||||||
|
|
||||||
@@ -38,7 +44,12 @@ class ProviderManager:
|
|||||||
"""加载的 Text To Speech Provider 的实例"""
|
"""加载的 Text To Speech Provider 的实例"""
|
||||||
self.embedding_provider_insts: List[EmbeddingProvider] = []
|
self.embedding_provider_insts: List[EmbeddingProvider] = []
|
||||||
"""加载的 Embedding Provider 的实例"""
|
"""加载的 Embedding Provider 的实例"""
|
||||||
self.inst_map: dict[str, Provider] = {}
|
self.rerank_provider_insts: List[RerankProvider] = []
|
||||||
|
"""加载的 Rerank Provider 的实例"""
|
||||||
|
self.inst_map: dict[
|
||||||
|
str,
|
||||||
|
Provider | STTProvider | TTSProvider | EmbeddingProvider | RerankProvider,
|
||||||
|
] = {}
|
||||||
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
"""Provider 实例映射. key: provider_id, value: Provider 实例"""
|
||||||
self.llm_tools = llm_tools
|
self.llm_tools = llm_tools
|
||||||
|
|
||||||
@@ -87,19 +98,31 @@ class ProviderManager:
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
# 不启用提供商会话隔离模式的情况
|
# 不启用提供商会话隔离模式的情况
|
||||||
self.curr_provider_inst = self.inst_map[provider_id]
|
|
||||||
if provider_type == ProviderType.TEXT_TO_SPEECH:
|
prov = self.inst_map[provider_id]
|
||||||
|
if provider_type == ProviderType.TEXT_TO_SPEECH and isinstance(
|
||||||
|
prov, TTSProvider
|
||||||
|
):
|
||||||
|
self.curr_tts_provider_inst = prov
|
||||||
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
|
sp.put("curr_provider_tts", provider_id, scope="global", scope_id="global")
|
||||||
elif provider_type == ProviderType.SPEECH_TO_TEXT:
|
elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance(
|
||||||
|
prov, STTProvider
|
||||||
|
):
|
||||||
|
self.curr_stt_provider_inst = prov
|
||||||
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
|
sp.put("curr_provider_stt", provider_id, scope="global", scope_id="global")
|
||||||
elif provider_type == ProviderType.CHAT_COMPLETION:
|
elif provider_type == ProviderType.CHAT_COMPLETION and isinstance(
|
||||||
|
prov, Provider
|
||||||
|
):
|
||||||
|
self.curr_provider_inst = prov
|
||||||
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
|
sp.put("curr_provider", provider_id, scope="global", scope_id="global")
|
||||||
|
|
||||||
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
|
async def get_provider_by_id(self, provider_id: str) -> Provider | None:
|
||||||
"""根据提供商 ID 获取提供商实例"""
|
"""根据提供商 ID 获取提供商实例"""
|
||||||
return self.inst_map.get(provider_id)
|
return self.inst_map.get(provider_id)
|
||||||
|
|
||||||
def get_using_provider(self, provider_type: ProviderType, umo=None):
|
def get_using_provider(
|
||||||
|
self, provider_type: ProviderType, umo=None
|
||||||
|
) -> Provider | STTProvider | TTSProvider | None:
|
||||||
"""获取正在使用的提供商实例。
|
"""获取正在使用的提供商实例。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -303,12 +326,14 @@ class ProviderManager:
|
|||||||
provider_metadata = provider_cls_map[provider_config["type"]]
|
provider_metadata = provider_cls_map[provider_config["type"]]
|
||||||
try:
|
try:
|
||||||
# 按任务实例化提供商
|
# 按任务实例化提供商
|
||||||
|
cls_type = provider_metadata.cls_type
|
||||||
|
if not cls_type:
|
||||||
|
logger.error(f"无法找到 {provider_metadata.type} 的类")
|
||||||
|
return
|
||||||
|
|
||||||
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
|
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
|
||||||
# STT 任务
|
# STT 任务
|
||||||
inst = provider_metadata.cls_type(
|
inst = cls_type(provider_config, self.provider_settings)
|
||||||
provider_config, self.provider_settings
|
|
||||||
)
|
|
||||||
|
|
||||||
if getattr(inst, "initialize", None):
|
if getattr(inst, "initialize", None):
|
||||||
await inst.initialize()
|
await inst.initialize()
|
||||||
@@ -327,9 +352,7 @@ class ProviderManager:
|
|||||||
|
|
||||||
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
|
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
|
||||||
# TTS 任务
|
# TTS 任务
|
||||||
inst = provider_metadata.cls_type(
|
inst = cls_type(provider_config, self.provider_settings)
|
||||||
provider_config, self.provider_settings
|
|
||||||
)
|
|
||||||
|
|
||||||
if getattr(inst, "initialize", None):
|
if getattr(inst, "initialize", None):
|
||||||
await inst.initialize()
|
await inst.initialize()
|
||||||
@@ -345,7 +368,7 @@ class ProviderManager:
|
|||||||
|
|
||||||
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
|
||||||
# 文本生成任务
|
# 文本生成任务
|
||||||
inst = provider_metadata.cls_type(
|
inst = cls_type(
|
||||||
provider_config,
|
provider_config,
|
||||||
self.provider_settings,
|
self.provider_settings,
|
||||||
self.selected_default_persona,
|
self.selected_default_persona,
|
||||||
@@ -366,13 +389,16 @@ class ProviderManager:
|
|||||||
if not self.curr_provider_inst:
|
if not self.curr_provider_inst:
|
||||||
self.curr_provider_inst = inst
|
self.curr_provider_inst = inst
|
||||||
|
|
||||||
elif provider_metadata.provider_type in [ProviderType.EMBEDDING, ProviderType.RERANK]:
|
elif provider_metadata.provider_type == ProviderType.EMBEDDING:
|
||||||
inst = provider_metadata.cls_type(
|
inst = cls_type(provider_config, self.provider_settings)
|
||||||
provider_config, self.provider_settings
|
|
||||||
)
|
|
||||||
if getattr(inst, "initialize", None):
|
if getattr(inst, "initialize", None):
|
||||||
await inst.initialize()
|
await inst.initialize()
|
||||||
self.embedding_provider_insts.append(inst)
|
self.embedding_provider_insts.append(inst)
|
||||||
|
elif provider_metadata.provider_type == ProviderType.RERANK:
|
||||||
|
inst = cls_type(provider_config, self.provider_settings)
|
||||||
|
if getattr(inst, "initialize", None):
|
||||||
|
await inst.initialize()
|
||||||
|
self.rerank_provider_insts.append(inst)
|
||||||
|
|
||||||
self.inst_map[provider_config["id"]] = inst
|
self.inst_map[provider_config["id"]] = inst
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -388,6 +414,7 @@ class ProviderManager:
|
|||||||
|
|
||||||
# 和配置文件保持同步
|
# 和配置文件保持同步
|
||||||
config_ids = [provider["id"] for provider in self.providers_config]
|
config_ids = [provider["id"] for provider in self.providers_config]
|
||||||
|
logger.debug(f"providers in user's config: {config_ids}")
|
||||||
for key in list(self.inst_map.keys()):
|
for key in list(self.inst_map.keys()):
|
||||||
if key not in config_ids:
|
if key not in config_ids:
|
||||||
await self.terminate_provider(key)
|
await self.terminate_provider(key)
|
||||||
@@ -426,11 +453,17 @@ class ProviderManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.inst_map[provider_id] in self.provider_insts:
|
if self.inst_map[provider_id] in self.provider_insts:
|
||||||
self.provider_insts.remove(self.inst_map[provider_id])
|
prov_inst = self.inst_map[provider_id]
|
||||||
|
if isinstance(prov_inst, Provider):
|
||||||
|
self.provider_insts.remove(prov_inst)
|
||||||
if self.inst_map[provider_id] in self.stt_provider_insts:
|
if self.inst_map[provider_id] in self.stt_provider_insts:
|
||||||
self.stt_provider_insts.remove(self.inst_map[provider_id])
|
prov_inst = self.inst_map[provider_id]
|
||||||
|
if isinstance(prov_inst, STTProvider):
|
||||||
|
self.stt_provider_insts.remove(prov_inst)
|
||||||
if self.inst_map[provider_id] in self.tts_provider_insts:
|
if self.inst_map[provider_id] in self.tts_provider_insts:
|
||||||
self.tts_provider_insts.remove(self.inst_map[provider_id])
|
prov_inst = self.inst_map[provider_id]
|
||||||
|
if isinstance(prov_inst, TTSProvider):
|
||||||
|
self.tts_provider_insts.remove(prov_inst)
|
||||||
|
|
||||||
if self.inst_map[provider_id] == self.curr_provider_inst:
|
if self.inst_map[provider_id] == self.curr_provider_inst:
|
||||||
self.curr_provider_inst = None
|
self.curr_provider_inst = None
|
||||||
|
|||||||
@@ -98,7 +98,7 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
|||||||
|
|
||||||
# FishAudio的reference_id通常是32位十六进制字符串
|
# FishAudio的reference_id通常是32位十六进制字符串
|
||||||
# 例如: 626bb6d3f3364c9cbc3aa6a67300a664
|
# 例如: 626bb6d3f3364c9cbc3aa6a67300a664
|
||||||
pattern = r'^[a-fA-F0-9]{32}$'
|
pattern = r"^[a-fA-F0-9]{32}$"
|
||||||
return bool(re.match(pattern, reference_id.strip()))
|
return bool(re.match(pattern, reference_id.strip()))
|
||||||
|
|
||||||
async def _generate_request(self, text: str) -> dict:
|
async def _generate_request(self, text: str) -> dict:
|
||||||
|
|||||||
@@ -99,12 +99,15 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
for key in to_del:
|
for key in to_del:
|
||||||
del payloads[key]
|
del payloads[key]
|
||||||
|
|
||||||
model = payloads.get("model", "")
|
# 读取并合并 custom_extra_body 配置
|
||||||
# 针对 qwen3 非 thinking 模型的特殊处理:非流式调用必须设置 enable_thinking=false
|
custom_extra_body = self.provider_config.get("custom_extra_body", {})
|
||||||
if "qwen3" in model.lower() and "thinking" not in model.lower():
|
if isinstance(custom_extra_body, dict):
|
||||||
extra_body["enable_thinking"] = False
|
extra_body.update(custom_extra_body)
|
||||||
|
|
||||||
|
model = payloads.get("model", "").lower()
|
||||||
|
|
||||||
# 针对 deepseek 模型的特殊处理:deepseek-reasoner调用必须移除 tools ,否则将被切换至 deepseek-chat
|
# 针对 deepseek 模型的特殊处理:deepseek-reasoner调用必须移除 tools ,否则将被切换至 deepseek-chat
|
||||||
elif model == "deepseek-reasoner" and "tools" in payloads:
|
if model == "deepseek-reasoner" and "tools" in payloads:
|
||||||
del payloads["tools"]
|
del payloads["tools"]
|
||||||
|
|
||||||
completion = await self.client.chat.completions.create(
|
completion = await self.client.chat.completions.create(
|
||||||
@@ -137,6 +140,12 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
|
|
||||||
# 不在默认参数中的参数放在 extra_body 中
|
# 不在默认参数中的参数放在 extra_body 中
|
||||||
extra_body = {}
|
extra_body = {}
|
||||||
|
|
||||||
|
# 读取并合并 custom_extra_body 配置
|
||||||
|
custom_extra_body = self.provider_config.get("custom_extra_body", {})
|
||||||
|
if isinstance(custom_extra_body, dict):
|
||||||
|
extra_body.update(custom_extra_body)
|
||||||
|
|
||||||
to_del = []
|
to_del = []
|
||||||
for key in payloads.keys():
|
for key in payloads.keys():
|
||||||
if key not in self.default_params:
|
if key not in self.default_params:
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
|
from astrbot import logger
|
||||||
from ..provider import RerankProvider
|
from ..provider import RerankProvider
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
from ..entities import ProviderType, RerankResult
|
from ..entities import ProviderType, RerankResult
|
||||||
@@ -44,6 +45,11 @@ class VLLMRerankProvider(RerankProvider):
|
|||||||
response_data = await response.json()
|
response_data = await response.json()
|
||||||
results = response_data.get("results", [])
|
results = response_data.get("results", [])
|
||||||
|
|
||||||
|
if not results:
|
||||||
|
logger.warning(
|
||||||
|
f"Rerank API 返回了空的列表数据。原始响应: {response_data}"
|
||||||
|
)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
RerankResult(
|
RerankResult(
|
||||||
index=result["index"],
|
index=result["index"],
|
||||||
|
|||||||
@@ -27,14 +27,16 @@ class Star(CommandParserMixin):
|
|||||||
star_map[cls.__module__].star_cls_type = cls
|
star_map[cls.__module__].star_cls_type = cls
|
||||||
star_map[cls.__module__].module_path = cls.__module__
|
star_map[cls.__module__].module_path = cls.__module__
|
||||||
|
|
||||||
@staticmethod
|
async def text_to_image(self, text: str, return_url=True) -> str:
|
||||||
async def text_to_image(text: str, return_url=True) -> str:
|
|
||||||
"""将文本转换为图片"""
|
"""将文本转换为图片"""
|
||||||
return await html_renderer.render_t2i(text, return_url=return_url)
|
return await html_renderer.render_t2i(
|
||||||
|
text,
|
||||||
|
return_url=return_url,
|
||||||
|
template_name=self.context._config.get("t2i_active_template"),
|
||||||
|
)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def html_render(
|
async def html_render(
|
||||||
tmpl: str, data: dict, return_url=True, options: dict | None = None
|
self, tmpl: str, data: dict, return_url=True, options: dict | None = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""渲染 HTML"""
|
"""渲染 HTML"""
|
||||||
return await html_renderer.render_custom_template(
|
return await html_renderer.render_custom_template(
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from astrbot.core.provider.provider import (
|
|||||||
TTSProvider,
|
TTSProvider,
|
||||||
STTProvider,
|
STTProvider,
|
||||||
EmbeddingProvider,
|
EmbeddingProvider,
|
||||||
|
RerankProvider,
|
||||||
)
|
)
|
||||||
from astrbot.core.provider.entities import ProviderType
|
from astrbot.core.provider.entities import ProviderType
|
||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
@@ -23,7 +24,7 @@ from .star import star_registry, StarMetadata, star_map
|
|||||||
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
|
||||||
from .filter.command import CommandFilter
|
from .filter.command import CommandFilter
|
||||||
from .filter.regex import RegexFilter
|
from .filter.regex import RegexFilter
|
||||||
from typing import Awaitable
|
from typing import Awaitable, Any, Callable
|
||||||
from astrbot.core.conversation_mgr import ConversationManager
|
from astrbot.core.conversation_mgr import ConversationManager
|
||||||
from astrbot.core.star.filter.platform_adapter_type import (
|
from astrbot.core.star.filter.platform_adapter_type import (
|
||||||
PlatformAdapterType,
|
PlatformAdapterType,
|
||||||
@@ -103,9 +104,14 @@ class Context:
|
|||||||
"""
|
"""
|
||||||
self.provider_manager.provider_insts.append(provider)
|
self.provider_manager.provider_insts.append(provider)
|
||||||
|
|
||||||
def get_provider_by_id(self, provider_id: str) -> Provider | None:
|
def get_provider_by_id(
|
||||||
"""通过 ID 获取对应的 LLM Provider(Chat_Completion 类型)。"""
|
self, provider_id: str
|
||||||
return self.provider_manager.inst_map.get(provider_id)
|
) -> (
|
||||||
|
Provider | TTSProvider | STTProvider | EmbeddingProvider | RerankProvider | None
|
||||||
|
):
|
||||||
|
"""通过 ID 获取对应的 LLM Provider。"""
|
||||||
|
prov = self.provider_manager.inst_map.get(provider_id)
|
||||||
|
return prov
|
||||||
|
|
||||||
def get_all_providers(self) -> List[Provider]:
|
def get_all_providers(self) -> List[Provider]:
|
||||||
"""获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
|
"""获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。"""
|
||||||
@@ -130,34 +136,43 @@ class Context:
|
|||||||
Args:
|
Args:
|
||||||
umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。
|
umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。
|
||||||
"""
|
"""
|
||||||
return self.provider_manager.get_using_provider(
|
prov = self.provider_manager.get_using_provider(
|
||||||
provider_type=ProviderType.CHAT_COMPLETION,
|
provider_type=ProviderType.CHAT_COMPLETION,
|
||||||
umo=umo,
|
umo=umo,
|
||||||
)
|
)
|
||||||
|
if prov and not isinstance(prov, Provider):
|
||||||
|
raise ValueError("返回的 Provider 不是 Provider 类型")
|
||||||
|
return prov
|
||||||
|
|
||||||
def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider:
|
def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider | None:
|
||||||
"""
|
"""
|
||||||
获取当前使用的用于 TTS 任务的 Provider。
|
获取当前使用的用于 TTS 任务的 Provider。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
||||||
"""
|
"""
|
||||||
return self.provider_manager.get_using_provider(
|
prov = self.provider_manager.get_using_provider(
|
||||||
provider_type=ProviderType.TEXT_TO_SPEECH,
|
provider_type=ProviderType.TEXT_TO_SPEECH,
|
||||||
umo=umo,
|
umo=umo,
|
||||||
)
|
)
|
||||||
|
if prov and not isinstance(prov, TTSProvider):
|
||||||
|
raise ValueError("返回的 Provider 不是 TTSProvider 类型")
|
||||||
|
return prov
|
||||||
|
|
||||||
def get_using_stt_provider(self, umo: str | None = None) -> STTProvider:
|
def get_using_stt_provider(self, umo: str | None = None) -> STTProvider | None:
|
||||||
"""
|
"""
|
||||||
获取当前使用的用于 STT 任务的 Provider。
|
获取当前使用的用于 STT 任务的 Provider。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。
|
||||||
"""
|
"""
|
||||||
return self.provider_manager.get_using_provider(
|
prov = self.provider_manager.get_using_provider(
|
||||||
provider_type=ProviderType.SPEECH_TO_TEXT,
|
provider_type=ProviderType.SPEECH_TO_TEXT,
|
||||||
umo=umo,
|
umo=umo,
|
||||||
)
|
)
|
||||||
|
if prov and not isinstance(prov, STTProvider):
|
||||||
|
raise ValueError("返回的 Provider 不是 STTProvider 类型")
|
||||||
|
return prov
|
||||||
|
|
||||||
def get_config(self, umo: str | None = None) -> AstrBotConfig:
|
def get_config(self, umo: str | None = None) -> AstrBotConfig:
|
||||||
"""获取 AstrBot 的配置。"""
|
"""获取 AstrBot 的配置。"""
|
||||||
@@ -245,7 +260,11 @@ class Context:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def register_llm_tool(
|
def register_llm_tool(
|
||||||
self, name: str, func_args: list, desc: str, func_obj: Awaitable
|
self,
|
||||||
|
name: str,
|
||||||
|
func_args: list,
|
||||||
|
desc: str,
|
||||||
|
func_obj: Callable[..., Awaitable[Any]],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
为函数调用(function-calling / tools-use)添加工具。
|
为函数调用(function-calling / tools-use)添加工具。
|
||||||
@@ -267,9 +286,7 @@ class Context:
|
|||||||
desc=desc,
|
desc=desc,
|
||||||
)
|
)
|
||||||
star_handlers_registry.append(md)
|
star_handlers_registry.append(md)
|
||||||
self.provider_manager.llm_tools.add_func(
|
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj)
|
||||||
name, func_args, desc, func_obj, func_obj
|
|
||||||
)
|
|
||||||
|
|
||||||
def unregister_llm_tool(self, name: str) -> None:
|
def unregister_llm_tool(self, name: str) -> None:
|
||||||
"""删除一个函数调用工具。如果再要启用,需要重新注册。"""
|
"""删除一个函数调用工具。如果再要启用,需要重新注册。"""
|
||||||
@@ -281,7 +298,7 @@ class Context:
|
|||||||
command_name: str,
|
command_name: str,
|
||||||
desc: str,
|
desc: str,
|
||||||
priority: int,
|
priority: int,
|
||||||
awaitable: Awaitable,
|
awaitable: Callable[..., Awaitable[Any]],
|
||||||
use_regex=False,
|
use_regex=False,
|
||||||
ignore_prefix=False,
|
ignore_prefix=False,
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ class CommandGroupFilter(HandlerFilter):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
group_name: str,
|
group_name: str,
|
||||||
alias: set = None,
|
alias: set | None = None,
|
||||||
parent_group: CommandGroupFilter = None,
|
parent_group: CommandGroupFilter | None = None,
|
||||||
):
|
):
|
||||||
self.group_name = group_name
|
self.group_name = group_name
|
||||||
self.alias = alias if alias else set()
|
self.alias = alias if alias else set()
|
||||||
@@ -54,8 +54,8 @@ class CommandGroupFilter(HandlerFilter):
|
|||||||
self,
|
self,
|
||||||
sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]],
|
sub_command_filters: List[Union[CommandFilter, CommandGroupFilter]],
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
event: AstrMessageEvent = None,
|
event: AstrMessageEvent | None = None,
|
||||||
cfg: AstrBotConfig = None,
|
cfg: AstrBotConfig | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
result = ""
|
result = ""
|
||||||
for sub_filter in sub_command_filters:
|
for sub_filter in sub_command_filters:
|
||||||
@@ -113,8 +113,7 @@ class CommandGroupFilter(HandlerFilter):
|
|||||||
+ self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
|
+ self.print_cmd_tree(self.sub_command_filters, event=event, cfg=cfg)
|
||||||
)
|
)
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n"
|
f"参数不足。{self.group_name} 指令组下有如下指令,请参考:\n" + tree
|
||||||
+ tree
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# complete_command_names = [name + " " for name in complete_command_names]
|
# complete_command_names = [name + " " for name in complete_command_names]
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ import enum
|
|||||||
from . import HandlerFilter
|
from . import HandlerFilter
|
||||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||||
from astrbot.core.config import AstrBotConfig
|
from astrbot.core.config import AstrBotConfig
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
|
|
||||||
class PlatformAdapterType(enum.Flag):
|
class PlatformAdapterType(enum.Flag):
|
||||||
@@ -18,6 +17,8 @@ class PlatformAdapterType(enum.Flag):
|
|||||||
KOOK = enum.auto()
|
KOOK = enum.auto()
|
||||||
VOCECHAT = enum.auto()
|
VOCECHAT = enum.auto()
|
||||||
WEIXIN_OFFICIAL_ACCOUNT = enum.auto()
|
WEIXIN_OFFICIAL_ACCOUNT = enum.auto()
|
||||||
|
SATORI = enum.auto()
|
||||||
|
MISSKEY = enum.auto()
|
||||||
ALL = (
|
ALL = (
|
||||||
AIOCQHTTP
|
AIOCQHTTP
|
||||||
| QQOFFICIAL
|
| QQOFFICIAL
|
||||||
@@ -31,6 +32,8 @@ class PlatformAdapterType(enum.Flag):
|
|||||||
| KOOK
|
| KOOK
|
||||||
| VOCECHAT
|
| VOCECHAT
|
||||||
| WEIXIN_OFFICIAL_ACCOUNT
|
| WEIXIN_OFFICIAL_ACCOUNT
|
||||||
|
| SATORI
|
||||||
|
| MISSKEY
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -47,15 +50,20 @@ ADAPTER_NAME_2_TYPE = {
|
|||||||
"wechatpadpro": PlatformAdapterType.WECHATPADPRO,
|
"wechatpadpro": PlatformAdapterType.WECHATPADPRO,
|
||||||
"vocechat": PlatformAdapterType.VOCECHAT,
|
"vocechat": PlatformAdapterType.VOCECHAT,
|
||||||
"weixin_official_account": PlatformAdapterType.WEIXIN_OFFICIAL_ACCOUNT,
|
"weixin_official_account": PlatformAdapterType.WEIXIN_OFFICIAL_ACCOUNT,
|
||||||
|
"satori": PlatformAdapterType.SATORI,
|
||||||
|
"misskey": PlatformAdapterType.MISSKEY,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class PlatformAdapterTypeFilter(HandlerFilter):
|
class PlatformAdapterTypeFilter(HandlerFilter):
|
||||||
def __init__(self, platform_adapter_type_or_str: Union[PlatformAdapterType, str]):
|
def __init__(self, platform_adapter_type_or_str: PlatformAdapterType | str):
|
||||||
self.type_or_str = platform_adapter_type_or_str
|
if isinstance(platform_adapter_type_or_str, str):
|
||||||
|
self.platform_type = ADAPTER_NAME_2_TYPE.get(platform_adapter_type_or_str)
|
||||||
|
else:
|
||||||
|
self.platform_type = platform_adapter_type_or_str
|
||||||
|
|
||||||
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
def filter(self, event: AstrMessageEvent, cfg: AstrBotConfig) -> bool:
|
||||||
adapter_name = event.get_platform_name()
|
adapter_name = event.get_platform_name()
|
||||||
if adapter_name in ADAPTER_NAME_2_TYPE:
|
if adapter_name in ADAPTER_NAME_2_TYPE and self.platform_type is not None:
|
||||||
return ADAPTER_NAME_2_TYPE[adapter_name] & self.type_or_str
|
return bool(ADAPTER_NAME_2_TYPE[adapter_name] & self.platform_type)
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from .star_handler import (
|
|||||||
register_permission_type,
|
register_permission_type,
|
||||||
register_custom_filter,
|
register_custom_filter,
|
||||||
register_on_astrbot_loaded,
|
register_on_astrbot_loaded,
|
||||||
|
register_on_platform_loaded,
|
||||||
register_on_llm_request,
|
register_on_llm_request,
|
||||||
register_on_llm_response,
|
register_on_llm_response,
|
||||||
register_llm_tool,
|
register_llm_tool,
|
||||||
@@ -26,6 +27,7 @@ __all__ = [
|
|||||||
"register_permission_type",
|
"register_permission_type",
|
||||||
"register_custom_filter",
|
"register_custom_filter",
|
||||||
"register_on_astrbot_loaded",
|
"register_on_astrbot_loaded",
|
||||||
|
"register_on_platform_loaded",
|
||||||
"register_on_llm_request",
|
"register_on_llm_request",
|
||||||
"register_on_llm_response",
|
"register_on_llm_response",
|
||||||
"register_llm_tool",
|
"register_llm_tool",
|
||||||
|
|||||||
@@ -5,7 +5,9 @@ from astrbot.core.star import StarMetadata, star_map
|
|||||||
_warned_register_star = False
|
_warned_register_star = False
|
||||||
|
|
||||||
|
|
||||||
def register_star(name: str, author: str, desc: str, version: str, repo: str = None):
|
def register_star(
|
||||||
|
name: str, author: str, desc: str, version: str, repo: str | None = None
|
||||||
|
):
|
||||||
"""注册一个插件(Star)。
|
"""注册一个插件(Star)。
|
||||||
|
|
||||||
[DEPRECATED] 该装饰器已废弃,将在未来版本中移除。
|
[DEPRECATED] 该装饰器已废弃,将在未来版本中移除。
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from ..filter.platform_adapter_type import (
|
|||||||
from ..filter.permission import PermissionTypeFilter, PermissionType
|
from ..filter.permission import PermissionTypeFilter, PermissionType
|
||||||
from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr
|
from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr
|
||||||
from ..filter.regex import RegexFilter
|
from ..filter.regex import RegexFilter
|
||||||
from typing import Awaitable
|
from typing import Awaitable, Any, Callable
|
||||||
from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
|
from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
|
||||||
from astrbot.core.provider.register import llm_tools
|
from astrbot.core.provider.register import llm_tools
|
||||||
from astrbot.core.agent.agent import Agent
|
from astrbot.core.agent.agent import Agent
|
||||||
@@ -20,15 +20,19 @@ from astrbot.core.agent.tool import FunctionTool
|
|||||||
from astrbot.core.agent.handoff import HandoffTool
|
from astrbot.core.agent.handoff import HandoffTool
|
||||||
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
from astrbot.core.agent.hooks import BaseAgentRunHooks
|
||||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||||
|
from astrbot.core import logger
|
||||||
|
|
||||||
|
|
||||||
def get_handler_full_name(awaitable: Awaitable) -> str:
|
def get_handler_full_name(awaitable: Callable[..., Awaitable[Any]]) -> str:
|
||||||
"""获取 Handler 的全名"""
|
"""获取 Handler 的全名"""
|
||||||
return f"{awaitable.__module__}_{awaitable.__name__}"
|
return f"{awaitable.__module__}_{awaitable.__name__}"
|
||||||
|
|
||||||
|
|
||||||
def get_handler_or_create(
|
def get_handler_or_create(
|
||||||
handler: Awaitable, event_type: EventType, dont_add=False, **kwargs
|
handler: Callable[..., Awaitable[Any]],
|
||||||
|
event_type: EventType,
|
||||||
|
dont_add=False,
|
||||||
|
**kwargs,
|
||||||
) -> StarHandlerMetadata:
|
) -> StarHandlerMetadata:
|
||||||
"""获取 Handler 或者创建一个新的 Handler"""
|
"""获取 Handler 或者创建一个新的 Handler"""
|
||||||
handler_full_name = get_handler_full_name(handler)
|
handler_full_name = get_handler_full_name(handler)
|
||||||
@@ -59,22 +63,35 @@ def get_handler_or_create(
|
|||||||
|
|
||||||
|
|
||||||
def register_command(
|
def register_command(
|
||||||
command_name: str = None, sub_command: str = None, alias: set = None, **kwargs
|
command_name: str | None = None,
|
||||||
|
sub_command: str | None = None,
|
||||||
|
alias: set | None = None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""注册一个 Command."""
|
"""注册一个 Command."""
|
||||||
new_command = None
|
new_command = None
|
||||||
add_to_event_filters = False
|
add_to_event_filters = False
|
||||||
if isinstance(command_name, RegisteringCommandable):
|
if isinstance(command_name, RegisteringCommandable):
|
||||||
# 子指令
|
# 子指令
|
||||||
parent_command_names = command_name.parent_group.get_complete_command_names()
|
if sub_command is not None:
|
||||||
new_command = CommandFilter(
|
parent_command_names = (
|
||||||
sub_command, alias, None, parent_command_names=parent_command_names
|
command_name.parent_group.get_complete_command_names()
|
||||||
)
|
)
|
||||||
command_name.parent_group.add_sub_command_filter(new_command)
|
new_command = CommandFilter(
|
||||||
|
sub_command, alias, None, parent_command_names=parent_command_names
|
||||||
|
)
|
||||||
|
command_name.parent_group.add_sub_command_filter(new_command)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"注册指令{command_name} 的子指令时未提供 sub_command 参数。"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# 裸指令
|
# 裸指令
|
||||||
new_command = CommandFilter(command_name, alias, None)
|
if command_name is None:
|
||||||
add_to_event_filters = True
|
logger.warning("注册裸指令时未提供 command_name 参数。")
|
||||||
|
else:
|
||||||
|
new_command = CommandFilter(command_name, alias, None)
|
||||||
|
add_to_event_filters = True
|
||||||
|
|
||||||
def decorator(awaitable):
|
def decorator(awaitable):
|
||||||
if not add_to_event_filters:
|
if not add_to_event_filters:
|
||||||
@@ -84,8 +101,9 @@ def register_command(
|
|||||||
handler_md = get_handler_or_create(
|
handler_md = get_handler_or_create(
|
||||||
awaitable, EventType.AdapterMessageEvent, **kwargs
|
awaitable, EventType.AdapterMessageEvent, **kwargs
|
||||||
)
|
)
|
||||||
new_command.init_handler_md(handler_md)
|
if new_command:
|
||||||
handler_md.event_filters.append(new_command)
|
new_command.init_handler_md(handler_md)
|
||||||
|
handler_md.event_filters.append(new_command)
|
||||||
return awaitable
|
return awaitable
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
@@ -163,26 +181,38 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
|
|||||||
|
|
||||||
|
|
||||||
def register_command_group(
|
def register_command_group(
|
||||||
command_group_name: str = None, sub_command: str = None, alias: set = None, **kwargs
|
command_group_name: str | None = None,
|
||||||
|
sub_command: str | None = None,
|
||||||
|
alias: set | None = None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""注册一个 CommandGroup"""
|
"""注册一个 CommandGroup"""
|
||||||
new_group = None
|
new_group = None
|
||||||
if isinstance(command_group_name, RegisteringCommandable):
|
if isinstance(command_group_name, RegisteringCommandable):
|
||||||
# 子指令组
|
# 子指令组
|
||||||
new_group = CommandGroupFilter(
|
if sub_command is None:
|
||||||
sub_command, alias, parent_group=command_group_name.parent_group
|
logger.warning(f"{command_group_name} 指令组的子指令组 sub_command 未指定")
|
||||||
)
|
else:
|
||||||
command_group_name.parent_group.add_sub_command_filter(new_group)
|
new_group = CommandGroupFilter(
|
||||||
|
sub_command, alias, parent_group=command_group_name.parent_group
|
||||||
|
)
|
||||||
|
command_group_name.parent_group.add_sub_command_filter(new_group)
|
||||||
else:
|
else:
|
||||||
# 根指令组
|
# 根指令组
|
||||||
new_group = CommandGroupFilter(command_group_name, alias)
|
if command_group_name is None:
|
||||||
|
logger.warning("根指令组的名称未指定")
|
||||||
|
else:
|
||||||
|
new_group = CommandGroupFilter(command_group_name, alias)
|
||||||
|
|
||||||
def decorator(obj):
|
def decorator(obj):
|
||||||
# 根指令组
|
# 根指令组
|
||||||
handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent, **kwargs)
|
if new_group:
|
||||||
handler_md.event_filters.append(new_group)
|
handler_md = get_handler_or_create(
|
||||||
|
obj, EventType.AdapterMessageEvent, **kwargs
|
||||||
|
)
|
||||||
|
handler_md.event_filters.append(new_group)
|
||||||
|
|
||||||
return RegisteringCommandable(new_group)
|
return RegisteringCommandable(new_group)
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
@@ -267,6 +297,18 @@ def register_on_astrbot_loaded(**kwargs):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def register_on_platform_loaded(**kwargs):
|
||||||
|
"""
|
||||||
|
当平台加载完成时
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(awaitable):
|
||||||
|
_ = get_handler_or_create(awaitable, EventType.OnPlatformLoadedEvent, **kwargs)
|
||||||
|
return awaitable
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def register_on_llm_request(**kwargs):
|
def register_on_llm_request(**kwargs):
|
||||||
"""当有 LLM 请求时的事件
|
"""当有 LLM 请求时的事件
|
||||||
|
|
||||||
@@ -311,7 +353,7 @@ def register_on_llm_response(**kwargs):
|
|||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def register_llm_tool(name: str = None, **kwargs):
|
def register_llm_tool(name: str | None = None, **kwargs):
|
||||||
"""为函数调用(function-calling / tools-use)添加工具。
|
"""为函数调用(function-calling / tools-use)添加工具。
|
||||||
|
|
||||||
请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释)
|
请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释)
|
||||||
@@ -349,9 +391,10 @@ def register_llm_tool(name: str = None, **kwargs):
|
|||||||
if kwargs.get("registering_agent"):
|
if kwargs.get("registering_agent"):
|
||||||
registering_agent = kwargs["registering_agent"]
|
registering_agent = kwargs["registering_agent"]
|
||||||
|
|
||||||
def decorator(awaitable: Awaitable):
|
def decorator(awaitable: Callable[..., Awaitable[Any]]):
|
||||||
llm_tool_name = name_ if name_ else awaitable.__name__
|
llm_tool_name = name_ if name_ else awaitable.__name__
|
||||||
docstring = docstring_parser.parse(awaitable.__doc__)
|
func_doc = awaitable.__doc__ or ""
|
||||||
|
docstring = docstring_parser.parse(func_doc)
|
||||||
args = []
|
args = []
|
||||||
for arg in docstring.params:
|
for arg in docstring.params:
|
||||||
if arg.type_name not in SUPPORTED_TYPES:
|
if arg.type_name not in SUPPORTED_TYPES:
|
||||||
@@ -367,18 +410,18 @@ def register_llm_tool(name: str = None, **kwargs):
|
|||||||
)
|
)
|
||||||
# print(llm_tool_name, registering_agent)
|
# print(llm_tool_name, registering_agent)
|
||||||
if not registering_agent:
|
if not registering_agent:
|
||||||
|
doc_desc = docstring.description.strip() if docstring.description else ""
|
||||||
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
|
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
|
||||||
llm_tools.add_func(
|
llm_tools.add_func(llm_tool_name, args, doc_desc, md.handler)
|
||||||
llm_tool_name, args, docstring.description.strip(), md.handler
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
assert isinstance(registering_agent, RegisteringAgent)
|
assert isinstance(registering_agent, RegisteringAgent)
|
||||||
# print(f"Registering tool {llm_tool_name} for agent", registering_agent._agent.name)
|
# print(f"Registering tool {llm_tool_name} for agent", registering_agent._agent.name)
|
||||||
if registering_agent._agent.tools is None:
|
if registering_agent._agent.tools is None:
|
||||||
registering_agent._agent.tools = []
|
registering_agent._agent.tools = []
|
||||||
registering_agent._agent.tools.append(llm_tools.spec_to_func(
|
|
||||||
llm_tool_name, args, docstring.description.strip(), awaitable
|
desc = docstring.description.strip() if docstring.description else ""
|
||||||
))
|
tool = llm_tools.spec_to_func(llm_tool_name, args, desc, awaitable)
|
||||||
|
registering_agent._agent.tools.append(tool)
|
||||||
|
|
||||||
return awaitable
|
return awaitable
|
||||||
|
|
||||||
@@ -399,8 +442,8 @@ class RegisteringAgent:
|
|||||||
def register_agent(
|
def register_agent(
|
||||||
name: str,
|
name: str,
|
||||||
instruction: str,
|
instruction: str,
|
||||||
tools: list[str | FunctionTool] = None,
|
tools: list[str | FunctionTool] | None = None,
|
||||||
run_hooks: BaseAgentRunHooks[AstrAgentContext] = None,
|
run_hooks: BaseAgentRunHooks[AstrAgentContext] | None = None,
|
||||||
):
|
):
|
||||||
"""注册一个 Agent
|
"""注册一个 Agent
|
||||||
|
|
||||||
@@ -412,7 +455,7 @@ def register_agent(
|
|||||||
"""
|
"""
|
||||||
tools_ = tools or []
|
tools_ = tools or []
|
||||||
|
|
||||||
def decorator(awaitable: Awaitable):
|
def decorator(awaitable: Callable[..., Awaitable[Any]]):
|
||||||
AstrAgent = Agent[AstrAgentContext]
|
AstrAgent = Agent[AstrAgentContext]
|
||||||
agent = AstrAgent(
|
agent = AstrAgent(
|
||||||
name=name,
|
name=name,
|
||||||
@@ -421,7 +464,7 @@ def register_agent(
|
|||||||
run_hooks=run_hooks or BaseAgentRunHooks[AstrAgentContext](),
|
run_hooks=run_hooks or BaseAgentRunHooks[AstrAgentContext](),
|
||||||
)
|
)
|
||||||
handoff_tool = HandoffTool(agent=agent)
|
handoff_tool = HandoffTool(agent=agent)
|
||||||
handoff_tool.handler=awaitable
|
handoff_tool.handler = awaitable
|
||||||
llm_tools.func_list.append(handoff_tool)
|
llm_tools.func_list.append(handoff_tool)
|
||||||
return RegisteringAgent(agent)
|
return RegisteringAgent(agent)
|
||||||
|
|
||||||
|
|||||||
@@ -84,7 +84,10 @@ class SessionPluginManager:
|
|||||||
session_config["disabled_plugins"] = disabled_plugins
|
session_config["disabled_plugins"] = disabled_plugins
|
||||||
session_plugin_config[session_id] = session_config
|
session_plugin_config[session_id] = session_config
|
||||||
sp.put(
|
sp.put(
|
||||||
"session_plugin_config", session_plugin_config, scope="umo", scope_id=session_id
|
"session_plugin_config",
|
||||||
|
session_plugin_config,
|
||||||
|
scope="umo",
|
||||||
|
scope_id=session_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -137,6 +140,9 @@ class SessionPluginManager:
|
|||||||
filtered_handlers.append(handler)
|
filtered_handlers.append(handler)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if plugin.name is None:
|
||||||
|
continue
|
||||||
|
|
||||||
# 检查插件是否在当前会话中启用
|
# 检查插件是否在当前会话中启用
|
||||||
if SessionPluginManager.is_plugin_enabled_for_session(
|
if SessionPluginManager.is_plugin_enabled_for_session(
|
||||||
session_id, plugin.name
|
session_id, plugin.name
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
import enum
|
import enum
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Awaitable, List, Dict, TypeVar, Generic
|
from typing import Callable, Awaitable, Any, List, Dict, TypeVar, Generic
|
||||||
from .filter import HandlerFilter
|
from .filter import HandlerFilter
|
||||||
from .star import star_map
|
from .star import star_map
|
||||||
|
|
||||||
@@ -34,26 +34,33 @@ class StarHandlerRegistry(Generic[T]):
|
|||||||
) -> List[StarHandlerMetadata]:
|
) -> List[StarHandlerMetadata]:
|
||||||
handlers = []
|
handlers = []
|
||||||
for handler in self._handlers:
|
for handler in self._handlers:
|
||||||
|
# 过滤事件类型
|
||||||
if handler.event_type != event_type:
|
if handler.event_type != event_type:
|
||||||
continue
|
continue
|
||||||
|
# 过滤启用状态
|
||||||
if only_activated:
|
if only_activated:
|
||||||
plugin = star_map.get(handler.handler_module_path)
|
plugin = star_map.get(handler.handler_module_path)
|
||||||
if not (plugin and plugin.activated):
|
if not (plugin and plugin.activated):
|
||||||
continue
|
continue
|
||||||
|
# 过滤插件白名单
|
||||||
if plugins_name is not None and plugins_name != ["*"]:
|
if plugins_name is not None and plugins_name != ["*"]:
|
||||||
plugin = star_map.get(handler.handler_module_path)
|
plugin = star_map.get(handler.handler_module_path)
|
||||||
if not plugin:
|
if not plugin:
|
||||||
continue
|
continue
|
||||||
if (
|
if (
|
||||||
plugin.name not in plugins_name
|
plugin.name not in plugins_name
|
||||||
and event_type != EventType.OnAstrBotLoadedEvent
|
and event_type
|
||||||
|
not in (
|
||||||
|
EventType.OnAstrBotLoadedEvent,
|
||||||
|
EventType.OnPlatformLoadedEvent,
|
||||||
|
)
|
||||||
and not plugin.reserved
|
and not plugin.reserved
|
||||||
):
|
):
|
||||||
continue
|
continue
|
||||||
handlers.append(handler)
|
handlers.append(handler)
|
||||||
return handlers
|
return handlers
|
||||||
|
|
||||||
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
|
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata | None:
|
||||||
return self.star_handlers_map.get(full_name, None)
|
return self.star_handlers_map.get(full_name, None)
|
||||||
|
|
||||||
def get_handlers_by_module_name(
|
def get_handlers_by_module_name(
|
||||||
@@ -80,7 +87,7 @@ class StarHandlerRegistry(Generic[T]):
|
|||||||
return len(self._handlers)
|
return len(self._handlers)
|
||||||
|
|
||||||
|
|
||||||
star_handlers_registry = StarHandlerRegistry()
|
star_handlers_registry = StarHandlerRegistry() # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class EventType(enum.Enum):
|
class EventType(enum.Enum):
|
||||||
@@ -90,6 +97,7 @@ class EventType(enum.Enum):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
OnAstrBotLoadedEvent = enum.auto() # AstrBot 加载完成
|
OnAstrBotLoadedEvent = enum.auto() # AstrBot 加载完成
|
||||||
|
OnPlatformLoadedEvent = enum.auto() # 平台加载完成
|
||||||
|
|
||||||
AdapterMessageEvent = enum.auto() # 收到适配器发来的消息
|
AdapterMessageEvent = enum.auto() # 收到适配器发来的消息
|
||||||
OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件)
|
OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件)
|
||||||
@@ -115,7 +123,7 @@ class StarHandlerMetadata:
|
|||||||
handler_module_path: str
|
handler_module_path: str
|
||||||
"""Handler 所在的模块路径。"""
|
"""Handler 所在的模块路径。"""
|
||||||
|
|
||||||
handler: Awaitable
|
handler: Callable[..., Awaitable[Any]]
|
||||||
"""Handler 的函数对象,应当是一个异步函数"""
|
"""Handler 的函数对象,应当是一个异步函数"""
|
||||||
|
|
||||||
event_filters: List[HandlerFilter]
|
event_filters: List[HandlerFilter]
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ class PluginManager:
|
|||||||
self.updator = PluginUpdator()
|
self.updator = PluginUpdator()
|
||||||
|
|
||||||
self.context = context
|
self.context = context
|
||||||
self.context._star_manager = self
|
self.context._star_manager = self # type: ignore
|
||||||
|
|
||||||
self.config = config
|
self.config = config
|
||||||
self.plugin_store_path = get_astrbot_plugin_path()
|
self.plugin_store_path = get_astrbot_plugin_path()
|
||||||
@@ -478,9 +478,10 @@ class PluginManager:
|
|||||||
if isinstance(func_tool, HandoffTool):
|
if isinstance(func_tool, HandoffTool):
|
||||||
need_apply = []
|
need_apply = []
|
||||||
sub_tools = func_tool.agent.tools
|
sub_tools = func_tool.agent.tools
|
||||||
for sub_tool in sub_tools:
|
if sub_tools:
|
||||||
if isinstance(sub_tool, FunctionTool):
|
for sub_tool in sub_tools:
|
||||||
need_apply.append(sub_tool)
|
if isinstance(sub_tool, FunctionTool):
|
||||||
|
need_apply.append(sub_tool)
|
||||||
else:
|
else:
|
||||||
need_apply = [func_tool]
|
need_apply = [func_tool]
|
||||||
|
|
||||||
@@ -686,6 +687,9 @@ class PluginManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 从 star_registry 和 star_map 中删除
|
# 从 star_registry 和 star_map 中删除
|
||||||
|
if plugin.module_path is None or root_dir_name is None:
|
||||||
|
raise Exception(f"插件 {plugin_name} 数据不完整,无法卸载。")
|
||||||
|
|
||||||
await self._unbind_plugin(plugin_name, plugin.module_path)
|
await self._unbind_plugin(plugin_name, plugin.module_path)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -791,15 +795,17 @@ class PluginManager:
|
|||||||
if star_metadata.star_cls is None:
|
if star_metadata.star_cls is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if '__del__' in star_metadata.star_cls_type.__dict__:
|
if "__del__" in star_metadata.star_cls_type.__dict__:
|
||||||
asyncio.get_event_loop().run_in_executor(
|
asyncio.get_event_loop().run_in_executor(
|
||||||
None, star_metadata.star_cls.__del__
|
None, star_metadata.star_cls.__del__
|
||||||
)
|
)
|
||||||
elif 'terminate' in star_metadata.star_cls_type.__dict__:
|
elif "terminate" in star_metadata.star_cls_type.__dict__:
|
||||||
await star_metadata.star_cls.terminate()
|
await star_metadata.star_cls.terminate()
|
||||||
|
|
||||||
async def turn_on_plugin(self, plugin_name: str):
|
async def turn_on_plugin(self, plugin_name: str):
|
||||||
plugin = self.context.get_registered_star(plugin_name)
|
plugin = self.context.get_registered_star(plugin_name)
|
||||||
|
if plugin is None:
|
||||||
|
raise Exception(f"插件 {plugin_name} 不存在。")
|
||||||
inactivated_plugins: list = await sp.global_get("inactivated_plugins", [])
|
inactivated_plugins: list = await sp.global_get("inactivated_plugins", [])
|
||||||
inactivated_llm_tools: list = await sp.global_get("inactivated_llm_tools", [])
|
inactivated_llm_tools: list = await sp.global_get("inactivated_llm_tools", [])
|
||||||
if plugin.module_path in inactivated_plugins:
|
if plugin.module_path in inactivated_plugins:
|
||||||
|
|||||||
+135
-22
@@ -1,14 +1,41 @@
|
|||||||
|
"""
|
||||||
|
插件开发工具集
|
||||||
|
封装了许多常用的操作,方便插件开发者使用
|
||||||
|
|
||||||
|
说明:
|
||||||
|
|
||||||
|
主动发送消息: send_message(session, message_chain)
|
||||||
|
根据 session (unified_msg_origin) 主动发送消息, 前提是需要提前获得或构造 session
|
||||||
|
|
||||||
|
根据id直接主动发送消息: send_message_by_id(type, id, message_chain, platform="aiocqhttp")
|
||||||
|
根据 id (例如 qq 号, 群号等) 直接, 主动地发送消息
|
||||||
|
|
||||||
|
以上两种方式需要构造消息链, 也就是消息组件的列表
|
||||||
|
|
||||||
|
构造事件:
|
||||||
|
|
||||||
|
首先需要构造一个 AstrBotMessage 对象, 使用 create_message 方法
|
||||||
|
然后使用 create_event 方法提交事件到指定平台
|
||||||
|
"""
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union, Awaitable, List, Optional, ClassVar
|
from typing import Union, Awaitable, Callable, Any, List, Optional, ClassVar
|
||||||
from astrbot.core.message.components import BaseMessageComponent
|
from astrbot.core.message.components import BaseMessageComponent
|
||||||
from astrbot.core.message.message_event_result import MessageChain
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
from astrbot.api.platform import MessageMember, AstrBotMessage
|
from astrbot.api.platform import MessageMember, AstrBotMessage, MessageType
|
||||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
from astrbot.core.star.context import Context
|
from astrbot.core.star.context import Context
|
||||||
from astrbot.core.star.star import star_map
|
from astrbot.core.star.star import star_map
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_message_event import (
|
||||||
|
AiocqhttpMessageEvent,
|
||||||
|
)
|
||||||
|
from astrbot.core.platform.sources.aiocqhttp.aiocqhttp_platform_adapter import (
|
||||||
|
AiocqhttpAdapter,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class StarTools:
|
class StarTools:
|
||||||
@@ -49,42 +76,82 @@ class StarTools:
|
|||||||
Note:
|
Note:
|
||||||
qq_official(QQ官方API平台)不支持此方法
|
qq_official(QQ官方API平台)不支持此方法
|
||||||
"""
|
"""
|
||||||
|
if cls._context is None:
|
||||||
|
raise ValueError("StarTools not initialized")
|
||||||
return await cls._context.send_message(session, message_chain)
|
return await cls._context.send_message(session, message_chain)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def send_message_by_id(
|
||||||
|
cls,
|
||||||
|
type: str,
|
||||||
|
id: str,
|
||||||
|
message_chain: MessageChain,
|
||||||
|
platform: str = "aiocqhttp",
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
根据 id(例如qq号, 群号等) 直接, 主动地发送消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
type (str): 消息类型, 可选: PrivateMessage, GroupMessage
|
||||||
|
id (str): 目标ID, 例如QQ号, 群号等
|
||||||
|
message_chain (MessageChain): 消息链
|
||||||
|
platform (str): 可选的平台名称,默认平台(aiocqhttp), 目前只支持 aiocqhttp
|
||||||
|
"""
|
||||||
|
if cls._context is None:
|
||||||
|
raise ValueError("StarTools not initialized")
|
||||||
|
platforms = cls._context.platform_manager.get_insts()
|
||||||
|
if platform == "aiocqhttp":
|
||||||
|
adapter = next(
|
||||||
|
(p for p in platforms if isinstance(p, AiocqhttpAdapter)), None
|
||||||
|
)
|
||||||
|
if adapter is None:
|
||||||
|
raise ValueError("未找到适配器: AiocqhttpAdapter")
|
||||||
|
await AiocqhttpMessageEvent.send_message(
|
||||||
|
bot=adapter.bot,
|
||||||
|
message_chain=message_chain,
|
||||||
|
is_group=(type == "GroupMessage"),
|
||||||
|
session_id=id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的平台: {platform}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def create_message(
|
async def create_message(
|
||||||
cls,
|
cls,
|
||||||
type: str,
|
type: str,
|
||||||
self_id: str,
|
self_id: str,
|
||||||
session_id: str,
|
session_id: str,
|
||||||
message_id: str,
|
|
||||||
sender: MessageMember,
|
sender: MessageMember,
|
||||||
message: List[BaseMessageComponent],
|
message: List[BaseMessageComponent],
|
||||||
message_str: str,
|
message_str: str,
|
||||||
raw_message: object,
|
message_id: str = "",
|
||||||
|
raw_message: object = None,
|
||||||
group_id: str = "",
|
group_id: str = "",
|
||||||
):
|
) -> AstrBotMessage:
|
||||||
"""
|
"""
|
||||||
创建一个AstrBot消息对象
|
创建一个AstrBot消息对象
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
type (str): 消息类型
|
type (str): 消息类型, 例如 "GroupMessage" "FriendMessage" "OtherMessage"
|
||||||
self_id (str): 机器人自身ID
|
self_id (str): 机器人自身ID
|
||||||
session_id (str): 会话ID(通常为用户ID)(QQ号, 群号等)
|
session_id (str): 会话ID(通常为用户ID)(QQ号, 群号等)
|
||||||
message_id (str): 消息ID
|
sender (MessageMember): 发送者信息, 例如 MessageMember(user_id="123456", nickname="昵称")
|
||||||
sender (MessageMember): 发送者信息
|
message (List[BaseMessageComponent]): 消息组件列表, 也就是消息链, 这个不会发给 llm, 但是会经过其他处理
|
||||||
message (List[BaseMessageComponent]): 消息组件列表
|
message_str (str): 消息字符串, 也就是纯文本消息, 也就是发送给 llm 的消息, 与消息链一致
|
||||||
message_str (str): 消息字符串
|
|
||||||
raw_message (object): 原始消息对象
|
message_id (str): 消息ID, 构造消息时可以随意填写也可不填
|
||||||
|
raw_message (object): 原始消息对象, 可以随意填写也可不填
|
||||||
group_id (str, optional): 群组ID, 如果为私聊则为空. Defaults to "".
|
group_id (str, optional): 群组ID, 如果为私聊则为空. Defaults to "".
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
AstrBotMessage: 创建的消息对象
|
AstrBotMessage: 创建的消息对象
|
||||||
"""
|
"""
|
||||||
abm = AstrBotMessage()
|
abm = AstrBotMessage()
|
||||||
abm.type = type
|
abm.type = MessageType(type)
|
||||||
abm.self_id = self_id
|
abm.self_id = self_id
|
||||||
abm.session_id = session_id
|
abm.session_id = session_id
|
||||||
|
if message_id == "":
|
||||||
|
message_id = uuid.uuid4().hex
|
||||||
abm.message_id = message_id
|
abm.message_id = message_id
|
||||||
abm.sender = sender
|
abm.sender = sender
|
||||||
abm.message = message
|
abm.message = message
|
||||||
@@ -93,13 +160,39 @@ class StarTools:
|
|||||||
abm.group_id = group_id
|
abm.group_id = group_id
|
||||||
return abm
|
return abm
|
||||||
|
|
||||||
# todo: 添加构造事件的方法
|
@classmethod
|
||||||
# async def create_event(
|
async def create_event(
|
||||||
# self, platform: str, umo: str, sender_id: str, session_id: str
|
cls, abm: AstrBotMessage, platform: str = "aiocqhttp", is_wake: bool = True
|
||||||
# ):
|
) -> None:
|
||||||
# platform = self._context.get_platform(platform)
|
"""
|
||||||
|
创建并提交事件到指定平台
|
||||||
|
当有需要创建一个事件, 触发某些处理流程时, 使用该方法
|
||||||
|
|
||||||
# todo: 添加找到对应平台并提交对应事件的方法
|
Args:
|
||||||
|
abm (AstrBotMessage): 要提交的消息对象, 请先使用 create_message 创建
|
||||||
|
platform (str): 可选的平台名称,默认平台(aiocqhttp), 目前只支持 aiocqhttp
|
||||||
|
is_wake (bool): 是否标记为唤醒事件, 默认为 True, 只有唤醒事件才会被 llm 响应
|
||||||
|
"""
|
||||||
|
if cls._context is None:
|
||||||
|
raise ValueError("StarTools not initialized")
|
||||||
|
platforms = cls._context.platform_manager.get_insts()
|
||||||
|
if platform == "aiocqhttp":
|
||||||
|
adapter = next(
|
||||||
|
(p for p in platforms if isinstance(p, AiocqhttpAdapter)), None
|
||||||
|
)
|
||||||
|
if adapter is None:
|
||||||
|
raise ValueError("未找到适配器: AiocqhttpAdapter")
|
||||||
|
event = AiocqhttpMessageEvent(
|
||||||
|
message_str=abm.message_str,
|
||||||
|
message_obj=abm,
|
||||||
|
platform_meta=adapter.metadata,
|
||||||
|
session_id=abm.session_id,
|
||||||
|
bot=adapter.bot,
|
||||||
|
)
|
||||||
|
event.is_wake = is_wake
|
||||||
|
adapter.commit_event(event)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"不支持的平台: {platform}")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def activate_llm_tool(cls, name: str) -> bool:
|
def activate_llm_tool(cls, name: str) -> bool:
|
||||||
@@ -110,6 +203,8 @@ class StarTools:
|
|||||||
Args:
|
Args:
|
||||||
name (str): 工具名称
|
name (str): 工具名称
|
||||||
"""
|
"""
|
||||||
|
if cls._context is None:
|
||||||
|
raise ValueError("StarTools not initialized")
|
||||||
return cls._context.activate_llm_tool(name)
|
return cls._context.activate_llm_tool(name)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -120,11 +215,17 @@ class StarTools:
|
|||||||
Args:
|
Args:
|
||||||
name (str): 工具名称
|
name (str): 工具名称
|
||||||
"""
|
"""
|
||||||
|
if cls._context is None:
|
||||||
|
raise ValueError("StarTools not initialized")
|
||||||
return cls._context.deactivate_llm_tool(name)
|
return cls._context.deactivate_llm_tool(name)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def register_llm_tool(
|
def register_llm_tool(
|
||||||
cls, name: str, func_args: list, desc: str, func_obj: Awaitable
|
cls,
|
||||||
|
name: str,
|
||||||
|
func_args: list,
|
||||||
|
desc: str,
|
||||||
|
func_obj: Callable[..., Awaitable[Any]],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
为函数调用(function-calling/tools-use)添加工具
|
为函数调用(function-calling/tools-use)添加工具
|
||||||
@@ -135,6 +236,8 @@ class StarTools:
|
|||||||
desc (str): 工具描述
|
desc (str): 工具描述
|
||||||
func_obj (Awaitable): 函数对象,必须是异步函数
|
func_obj (Awaitable): 函数对象,必须是异步函数
|
||||||
"""
|
"""
|
||||||
|
if cls._context is None:
|
||||||
|
raise ValueError("StarTools not initialized")
|
||||||
cls._context.register_llm_tool(name, func_args, desc, func_obj)
|
cls._context.register_llm_tool(name, func_args, desc, func_obj)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -146,6 +249,8 @@ class StarTools:
|
|||||||
Args:
|
Args:
|
||||||
name (str): 工具名称
|
name (str): 工具名称
|
||||||
"""
|
"""
|
||||||
|
if cls._context is None:
|
||||||
|
raise ValueError("StarTools not initialized")
|
||||||
cls._context.unregister_llm_tool(name)
|
cls._context.unregister_llm_tool(name)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -169,8 +274,11 @@ class StarTools:
|
|||||||
- 创建目录失败(权限不足或其他IO错误)
|
- 创建目录失败(权限不足或其他IO错误)
|
||||||
"""
|
"""
|
||||||
if not plugin_name:
|
if not plugin_name:
|
||||||
frame = inspect.currentframe().f_back
|
frame = inspect.currentframe()
|
||||||
module = inspect.getmodule(frame)
|
module = None
|
||||||
|
if frame:
|
||||||
|
frame = frame.f_back
|
||||||
|
module = inspect.getmodule(frame)
|
||||||
|
|
||||||
if not module:
|
if not module:
|
||||||
raise RuntimeError("无法获取调用者模块信息")
|
raise RuntimeError("无法获取调用者模块信息")
|
||||||
@@ -182,7 +290,12 @@ class StarTools:
|
|||||||
|
|
||||||
plugin_name = metadata.name
|
plugin_name = metadata.name
|
||||||
|
|
||||||
data_dir = Path(os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name))
|
if not plugin_name:
|
||||||
|
raise ValueError("无法获取插件名称")
|
||||||
|
|
||||||
|
data_dir = Path(
|
||||||
|
os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name)
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data_dir.mkdir(parents=True, exist_ok=True)
|
data_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|||||||
@@ -32,6 +32,9 @@ class PluginUpdator(RepoZipUpdator):
|
|||||||
if not repo_url:
|
if not repo_url:
|
||||||
raise Exception(f"插件 {plugin.name} 没有指定仓库地址。")
|
raise Exception(f"插件 {plugin.name} 没有指定仓库地址。")
|
||||||
|
|
||||||
|
if not plugin.root_dir_name:
|
||||||
|
raise Exception(f"插件 {plugin.name} 的根目录名未指定。")
|
||||||
|
|
||||||
plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
|
plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
|
||||||
|
|
||||||
logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}")
|
logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}")
|
||||||
|
|||||||
@@ -227,9 +227,11 @@ async def download_dashboard(
|
|||||||
path = os.path.join(get_astrbot_data_path(), "dashboard.zip")
|
path = os.path.join(get_astrbot_data_path(), "dashboard.zip")
|
||||||
|
|
||||||
if latest or len(str(version)) != 40:
|
if latest or len(str(version)) != 40:
|
||||||
logger.info(f"准备下载 {version} 发行版本的 AstrBot WebUI 文件")
|
|
||||||
ver_name = "latest" if latest else version
|
ver_name = "latest" if latest else version
|
||||||
dashboard_release_url = f"https://astrbot-registry.soulter.top/download/astrbot-dashboard/{ver_name}/dist.zip"
|
dashboard_release_url = f"https://astrbot-registry.soulter.top/download/astrbot-dashboard/{ver_name}/dist.zip"
|
||||||
|
logger.info(
|
||||||
|
f"准备下载指定发行版本的 AstrBot WebUI 文件: {dashboard_release_url}"
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
await download_file(dashboard_release_url, path, show_progress=True)
|
await download_file(dashboard_release_url, path, show_progress=True)
|
||||||
except BaseException as _:
|
except BaseException as _:
|
||||||
@@ -241,24 +243,10 @@ async def download_dashboard(
|
|||||||
dashboard_release_url = f"{proxy}/{dashboard_release_url}"
|
dashboard_release_url = f"{proxy}/{dashboard_release_url}"
|
||||||
await download_file(dashboard_release_url, path, show_progress=True)
|
await download_file(dashboard_release_url, path, show_progress=True)
|
||||||
else:
|
else:
|
||||||
logger.info(f"准备下载指定版本的 AstrBot WebUI: {version}")
|
url = f"https://github.com/AstrBotDevs/astrbot-release-harbour/releases/download/release-{version}/dist.zip"
|
||||||
|
logger.info(f"准备下载指定版本的 AstrBot WebUI: {url}")
|
||||||
url = (
|
|
||||||
"https://api.github.com/repos/AstrBotDevs/astrbot-release-harbour/releases"
|
|
||||||
)
|
|
||||||
if proxy:
|
if proxy:
|
||||||
url = f"{proxy}/{url}"
|
url = f"{proxy}/{url}"
|
||||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
await download_file(url, path, show_progress=True)
|
||||||
async with session.get(url) as resp:
|
|
||||||
if resp.status == 200:
|
|
||||||
releases = await resp.json()
|
|
||||||
for release in releases:
|
|
||||||
if version in release["tag_name"]:
|
|
||||||
download_url = release["assets"][0]["browser_download_url"]
|
|
||||||
await download_file(download_url, path, show_progress=True)
|
|
||||||
else:
|
|
||||||
logger.warning(f"未找到指定的版本的 Dashboard 构建文件: {version}")
|
|
||||||
return
|
|
||||||
|
|
||||||
with zipfile.ZipFile(path, "r") as z:
|
with zipfile.ZipFile(path, "r") as z:
|
||||||
z.extractall(extract_path)
|
z.extractall(extract_path)
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
|
||||||
import ssl
|
import ssl
|
||||||
import certifi
|
import certifi
|
||||||
import logging
|
import logging
|
||||||
@@ -8,10 +7,9 @@ import random
|
|||||||
from . import RenderStrategy
|
from . import RenderStrategy
|
||||||
from astrbot.core.config import VERSION
|
from astrbot.core.config import VERSION
|
||||||
from astrbot.core.utils.io import download_image_by_url
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
from astrbot.core.utils.t2i.template_manager import TemplateManager
|
||||||
|
|
||||||
ASTRBOT_T2I_DEFAULT_ENDPOINT = "https://t2i.soulter.top/text2img"
|
ASTRBOT_T2I_DEFAULT_ENDPOINT = "https://t2i.soulter.top/text2img"
|
||||||
CUSTOM_T2I_TEMPLATE_PATH = os.path.join(get_astrbot_data_path(), "t2i_template.html")
|
|
||||||
|
|
||||||
logger = logging.getLogger("astrbot")
|
logger = logging.getLogger("astrbot")
|
||||||
|
|
||||||
@@ -23,26 +21,17 @@ class NetworkRenderStrategy(RenderStrategy):
|
|||||||
self.BASE_RENDER_URL = ASTRBOT_T2I_DEFAULT_ENDPOINT
|
self.BASE_RENDER_URL = ASTRBOT_T2I_DEFAULT_ENDPOINT
|
||||||
else:
|
else:
|
||||||
self.BASE_RENDER_URL = self._clean_url(base_url)
|
self.BASE_RENDER_URL = self._clean_url(base_url)
|
||||||
self.TEMPLATE_PATH = os.path.join(os.path.dirname(__file__), "template", "base.html")
|
|
||||||
with open(self.TEMPLATE_PATH, "r", encoding="utf-8") as f:
|
|
||||||
self.DEFAULT_TEMPLATE = f.read()
|
|
||||||
|
|
||||||
self.endpoints = [self.BASE_RENDER_URL]
|
self.endpoints = [self.BASE_RENDER_URL]
|
||||||
|
self.template_manager = TemplateManager()
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
if self.BASE_RENDER_URL == ASTRBOT_T2I_DEFAULT_ENDPOINT:
|
if self.BASE_RENDER_URL == ASTRBOT_T2I_DEFAULT_ENDPOINT:
|
||||||
asyncio.create_task(self.get_official_endpoints())
|
asyncio.create_task(self.get_official_endpoints())
|
||||||
|
|
||||||
async def get_template(self) -> str:
|
async def get_template(self, name: str = "base") -> str:
|
||||||
"""获取文转图 HTML 模板
|
"""通过名称获取文转图 HTML 模板"""
|
||||||
|
return self.template_manager.get_template(name)
|
||||||
Returns:
|
|
||||||
str: 文转图 HTML 模板字符串
|
|
||||||
"""
|
|
||||||
if os.path.exists(CUSTOM_T2I_TEMPLATE_PATH):
|
|
||||||
with open(CUSTOM_T2I_TEMPLATE_PATH, "r", encoding="utf-8") as f:
|
|
||||||
return f.read()
|
|
||||||
return self.DEFAULT_TEMPLATE
|
|
||||||
|
|
||||||
async def get_official_endpoints(self):
|
async def get_official_endpoints(self):
|
||||||
"""获取官方的 t2i 端点列表。"""
|
"""获取官方的 t2i 端点列表。"""
|
||||||
@@ -124,11 +113,15 @@ class NetworkRenderStrategy(RenderStrategy):
|
|||||||
logger.error(f"All endpoints failed: {last_exception}")
|
logger.error(f"All endpoints failed: {last_exception}")
|
||||||
raise RuntimeError(f"All endpoints failed: {last_exception}")
|
raise RuntimeError(f"All endpoints failed: {last_exception}")
|
||||||
|
|
||||||
async def render(self, text: str, return_url: bool = False) -> str:
|
async def render(
|
||||||
|
self, text: str, return_url: bool = False, template_name: str | None = "base"
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
返回图像的文件路径
|
返回图像的文件路径
|
||||||
"""
|
"""
|
||||||
tmpl_str = await self.get_template()
|
if not template_name:
|
||||||
|
template_name = "base"
|
||||||
|
tmpl_str = await self.get_template(name=template_name)
|
||||||
text = text.replace("`", "\\`")
|
text = text.replace("`", "\\`")
|
||||||
return await self.render_custom_template(
|
return await self.render_custom_template(
|
||||||
tmpl_str, {"text": text, "version": f"v{VERSION}"}, return_url
|
tmpl_str, {"text": text, "version": f"v{VERSION}"}, return_url
|
||||||
|
|||||||
@@ -34,12 +34,18 @@ class HtmlRenderer:
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def render_t2i(
|
async def render_t2i(
|
||||||
self, text: str, use_network: bool = True, return_url: bool = False
|
self,
|
||||||
|
text: str,
|
||||||
|
use_network: bool = True,
|
||||||
|
return_url: bool = False,
|
||||||
|
template_name: str | None = None,
|
||||||
):
|
):
|
||||||
"""使用默认文转图模板。"""
|
"""使用默认文转图模板。"""
|
||||||
if use_network:
|
if use_network:
|
||||||
try:
|
try:
|
||||||
return await self.network_strategy.render(text, return_url=return_url)
|
return await self.network_strategy.render(
|
||||||
|
text, return_url=return_url, template_name=template_name
|
||||||
|
)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"Failed to render image via AstrBot API: {e}. Falling back to local rendering."
|
f"Failed to render image via AstrBot API: {e}. Falling back to local rendering."
|
||||||
|
|||||||
@@ -0,0 +1,184 @@
|
|||||||
|
<!doctype html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="utf-8"/>
|
||||||
|
<title>Astrbot PowerShell {{ version }} </title>
|
||||||
|
<link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/katex@0.16.10/dist/katex.min.css" integrity="sha384-wcIxkf4k558AjM3Yz3BBFQUbk/zgIYC2R0QpeeYb+TwlBVMrlgLqwRjRtGZiK7ww" crossorigin="anonymous">
|
||||||
|
<script src="https://cdn.jsdelivr.net/npm/highlight.js@11.9.0/lib/common.min.js"></script>
|
||||||
|
<script>hljs.highlightAll();</script>
|
||||||
|
<script defer src="https://cdn.jsdelivr.net/npm/katex@0.16.10/dist/katex.min.js" integrity="sha384-hIoBPJpTUs74ddyc4bFZSM1TVlQDA60VBbJS0oA934VSz82sBx1X7kSx2ATBDIyd" crossorigin="anonymous"></script>
|
||||||
|
<script defer src="https://cdn.jsdelivr.net/npm/katex@0.16.10/dist/contrib/auto-render.min.js" integrity="sha384-43gviWU0YVjaDtb/GhzOouOXtZMP/7XUzwPTstBeZFe/+rCMvRwr4yROQP43s0Xk" crossorigin="anonymous"
|
||||||
|
onload="renderMathInElement(document.getElementById('content'),{delimiters: [{left: '$$', right: '$$', display: true},{left: '$', right: '$', display: false}]});"></script>
|
||||||
|
<style>
|
||||||
|
:root {
|
||||||
|
--bg-color: #010409;
|
||||||
|
--text-color: #e6edf3;
|
||||||
|
--title-bar-color: #161b22;
|
||||||
|
--title-text-color: #e6edf3;
|
||||||
|
--font-family: 'Consolas', 'Microsoft YaHei Mono', 'Dengxian Mono', 'Courier New', monospace;
|
||||||
|
--glow-color: rgba(200, 220, 255, 0.7);
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes scanline {
|
||||||
|
0% {
|
||||||
|
background-position: 0 0;
|
||||||
|
}
|
||||||
|
100% {
|
||||||
|
background-position: 0 100%;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
body {
|
||||||
|
background-color: var(--bg-color);
|
||||||
|
color: var(--text-color);
|
||||||
|
font-family: var(--font-family);
|
||||||
|
margin: 0;
|
||||||
|
padding: 0;
|
||||||
|
line-height: 1.6;
|
||||||
|
font-size: 18px;
|
||||||
|
/* The CRT glow effect from the image */
|
||||||
|
text-shadow: 0 0 15px var(--glow-color), 0 0 7px rgba(255, 255, 255, 1);
|
||||||
|
position: relative;
|
||||||
|
overflow: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
|
body::after {
|
||||||
|
content: " ";
|
||||||
|
display: block;
|
||||||
|
position: absolute;
|
||||||
|
top: 0;
|
||||||
|
left: 0;
|
||||||
|
right: 0;
|
||||||
|
bottom: 0;
|
||||||
|
background: linear-gradient(to bottom, transparent 50%, rgba(0, 0, 0, 0.3) 50%);
|
||||||
|
background-size: 100% 4px;
|
||||||
|
z-index: 2;
|
||||||
|
pointer-events: none;
|
||||||
|
animation: scanline 8s linear infinite;
|
||||||
|
}
|
||||||
|
|
||||||
|
.header {
|
||||||
|
background-color: var(--title-bar-color);
|
||||||
|
padding: 12px 18px;
|
||||||
|
color: var(--title-text-color);
|
||||||
|
font-size: 16px;
|
||||||
|
border-bottom: 1px solid #30363d;
|
||||||
|
text-shadow: none; /* No glow for title bar */
|
||||||
|
}
|
||||||
|
|
||||||
|
.header .title {
|
||||||
|
font-weight: bold;
|
||||||
|
font-size: 28px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.header .version {
|
||||||
|
opacity: 0.8;
|
||||||
|
margin-left: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
main {
|
||||||
|
padding: 1rem 1.5rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
#content {
|
||||||
|
/* min-width and max-width removed as per request */
|
||||||
|
}
|
||||||
|
|
||||||
|
/* --- Markdown Styles adjusted for terminal look --- */
|
||||||
|
h1, h2, h3, h4, h5, h6 {
|
||||||
|
line-height: 1.4;
|
||||||
|
margin-top: 20px;
|
||||||
|
margin-bottom: 10px;
|
||||||
|
padding-bottom: 5px;
|
||||||
|
border-bottom: 1px solid #30363d;
|
||||||
|
color: var(--text-color);
|
||||||
|
}
|
||||||
|
h1 { font-size: 2rem; }
|
||||||
|
h2 { font-size: 1.7rem; }
|
||||||
|
h3 { font-size: 1.4rem; }
|
||||||
|
|
||||||
|
p {
|
||||||
|
margin-top: 1rem;
|
||||||
|
margin-bottom: 1rem;
|
||||||
|
}
|
||||||
|
|
||||||
|
strong {
|
||||||
|
color: var(--text-color);
|
||||||
|
font-weight: bold;
|
||||||
|
}
|
||||||
|
|
||||||
|
img {
|
||||||
|
max-width: 100%;
|
||||||
|
border: 1px solid #30363d;
|
||||||
|
display: block;
|
||||||
|
margin: 1rem auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
hr {
|
||||||
|
border: 0;
|
||||||
|
border-top: 1px dashed #30363d;
|
||||||
|
margin: 2rem 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
code {
|
||||||
|
font-family: var(--font-family);
|
||||||
|
padding: 0.2em 0.4em;
|
||||||
|
margin: 0;
|
||||||
|
font-size: 90%;
|
||||||
|
background-color: #161b22;
|
||||||
|
border-radius: 4px;
|
||||||
|
}
|
||||||
|
|
||||||
|
pre {
|
||||||
|
font-family: var(--font-family);
|
||||||
|
border-radius: 4px;
|
||||||
|
background: #0d1117;
|
||||||
|
padding: 1rem;
|
||||||
|
overflow-x: auto;
|
||||||
|
border: 1px solid #30363d;
|
||||||
|
}
|
||||||
|
|
||||||
|
pre > code {
|
||||||
|
padding: 0;
|
||||||
|
margin: 0;
|
||||||
|
font-size: 100%;
|
||||||
|
background-color: transparent;
|
||||||
|
border-radius: 0;
|
||||||
|
text-shadow: none; /* Disable glow inside code blocks for clarity */
|
||||||
|
}
|
||||||
|
|
||||||
|
a {
|
||||||
|
color: #58a6ff;
|
||||||
|
text-decoration: underline;
|
||||||
|
}
|
||||||
|
a:hover {
|
||||||
|
text-decoration: underline;
|
||||||
|
}
|
||||||
|
|
||||||
|
blockquote {
|
||||||
|
border-left: 4px solid #30363d;
|
||||||
|
padding: 0.5rem 1rem;
|
||||||
|
margin: 1.5rem 0;
|
||||||
|
color: #8b949e;
|
||||||
|
background-color: #161b22;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
|
||||||
|
<div class="header">
|
||||||
|
<span class="title">> Astrbot PowerShell</span>
|
||||||
|
<span class="version">{{ version }}</span>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<main>
|
||||||
|
<div id="content"></div>
|
||||||
|
</main>
|
||||||
|
|
||||||
|
<script src="https://cdn.jsdelivr.net/npm/marked/marked.min.js"></script>
|
||||||
|
<script>
|
||||||
|
document.getElementById('content').innerHTML = marked.parse(`{{ text | safe }}`);
|
||||||
|
</script>
|
||||||
|
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
@@ -0,0 +1,112 @@
|
|||||||
|
# astrbot/core/utils/t2i/template_manager.py
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_path
|
||||||
|
|
||||||
|
|
||||||
|
class TemplateManager:
|
||||||
|
"""
|
||||||
|
负责管理 t2i HTML 模板的 CRUD 和重置操作。
|
||||||
|
采用“用户覆盖内置”策略:用户模板存储在 data 目录中,并优先于内置模板加载。
|
||||||
|
所有创建、更新、删除操作仅影响用户目录,以确保更新框架时用户数据安全。
|
||||||
|
"""
|
||||||
|
|
||||||
|
CORE_TEMPLATES = ["base.html", "astrbot_powershell.html"]
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.builtin_template_dir = os.path.join(
|
||||||
|
get_astrbot_path(), "astrbot", "core", "utils", "t2i", "template"
|
||||||
|
)
|
||||||
|
self.user_template_dir = os.path.join(get_astrbot_data_path(), "t2i_templates")
|
||||||
|
|
||||||
|
os.makedirs(self.user_template_dir, exist_ok=True)
|
||||||
|
self._initialize_user_templates()
|
||||||
|
|
||||||
|
def _copy_core_templates(self, overwrite: bool = False):
|
||||||
|
"""从内置目录复制核心模板到用户目录。"""
|
||||||
|
for filename in self.CORE_TEMPLATES:
|
||||||
|
src = os.path.join(self.builtin_template_dir, filename)
|
||||||
|
dst = os.path.join(self.user_template_dir, filename)
|
||||||
|
if os.path.exists(src) and (overwrite or not os.path.exists(dst)):
|
||||||
|
shutil.copyfile(src, dst)
|
||||||
|
|
||||||
|
def _initialize_user_templates(self):
|
||||||
|
"""如果用户目录下缺少核心模板,则进行复制。"""
|
||||||
|
self._copy_core_templates(overwrite=False)
|
||||||
|
|
||||||
|
def _get_user_template_path(self, name: str) -> str:
|
||||||
|
"""获取用户模板的完整路径,防止路径遍历漏洞。"""
|
||||||
|
if ".." in name or "/" in name or "\\" in name:
|
||||||
|
raise ValueError("模板名称包含非法字符。")
|
||||||
|
return os.path.join(self.user_template_dir, f"{name}.html")
|
||||||
|
|
||||||
|
def _read_file(self, path: str) -> str:
|
||||||
|
"""读取文件内容。"""
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
return f.read()
|
||||||
|
|
||||||
|
def list_templates(self) -> list[dict]:
|
||||||
|
"""
|
||||||
|
列出所有可用模板。
|
||||||
|
该列表是内置模板和用户模板的合并视图,用户模板将覆盖同名的内置模板。
|
||||||
|
"""
|
||||||
|
dirs_to_scan = [self.builtin_template_dir, self.user_template_dir]
|
||||||
|
all_names = {
|
||||||
|
os.path.splitext(f)[0]
|
||||||
|
for d in dirs_to_scan
|
||||||
|
for f in os.listdir(d)
|
||||||
|
if f.endswith(".html")
|
||||||
|
}
|
||||||
|
return [
|
||||||
|
{"name": name, "is_default": name == "base"} for name in sorted(all_names)
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_template(self, name: str) -> str:
|
||||||
|
"""
|
||||||
|
获取指定模板的内容。
|
||||||
|
优先从用户目录加载,如果不存在则回退到内置目录。
|
||||||
|
"""
|
||||||
|
user_path = self._get_user_template_path(name)
|
||||||
|
if os.path.exists(user_path):
|
||||||
|
return self._read_file(user_path)
|
||||||
|
|
||||||
|
builtin_path = os.path.join(self.builtin_template_dir, f"{name}.html")
|
||||||
|
if os.path.exists(builtin_path):
|
||||||
|
return self._read_file(builtin_path)
|
||||||
|
|
||||||
|
raise FileNotFoundError("模板不存在。")
|
||||||
|
|
||||||
|
def create_template(self, name: str, content: str):
|
||||||
|
"""在用户目录中创建一个新的模板文件。"""
|
||||||
|
path = self._get_user_template_path(name)
|
||||||
|
if os.path.exists(path):
|
||||||
|
raise FileExistsError("同名模板已存在。")
|
||||||
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(content)
|
||||||
|
|
||||||
|
def update_template(self, name: str, content: str):
|
||||||
|
"""
|
||||||
|
更新一个模板。此操作始终写入用户目录。
|
||||||
|
如果更新的是一个内置模板,此操作实际上会在用户目录中创建一个修改后的副本,
|
||||||
|
从而实现对内置模板的“覆盖”。
|
||||||
|
"""
|
||||||
|
path = self._get_user_template_path(name)
|
||||||
|
with open(path, "w", encoding="utf-8") as f:
|
||||||
|
f.write(content)
|
||||||
|
|
||||||
|
def delete_template(self, name: str):
|
||||||
|
"""
|
||||||
|
仅删除用户目录中的模板文件。
|
||||||
|
如果删除的是一个覆盖了内置模板的用户模板,这将有效地“恢复”到内置版本。
|
||||||
|
"""
|
||||||
|
path = self._get_user_template_path(name)
|
||||||
|
if not os.path.exists(path):
|
||||||
|
raise FileNotFoundError("用户模板不存在,无法删除。")
|
||||||
|
os.remove(path)
|
||||||
|
|
||||||
|
def reset_default_template(self):
|
||||||
|
"""
|
||||||
|
将核心模板从内置目录强制重置到用户目录。
|
||||||
|
"""
|
||||||
|
self._copy_core_templates(overwrite=True)
|
||||||
@@ -157,7 +157,11 @@ class ChatRoute(Route):
|
|||||||
|
|
||||||
if type == "end":
|
if type == "end":
|
||||||
break
|
break
|
||||||
elif (streaming and type == "complete") or not streaming:
|
elif (
|
||||||
|
(streaming and type == "complete")
|
||||||
|
or not streaming
|
||||||
|
or type == "break"
|
||||||
|
):
|
||||||
# append bot message
|
# append bot message
|
||||||
new_his = {"type": "bot", "message": result_text}
|
new_his = {"type": "bot", "message": result_text}
|
||||||
await self.platform_history_mgr.insert(
|
await self.platform_history_mgr.insert(
|
||||||
@@ -197,6 +201,7 @@ class ChatRoute(Route):
|
|||||||
"Connection": "keep-alive",
|
"Connection": "keep-alive",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
response.timeout = None # fix SSE auto disconnect issue
|
||||||
return response
|
return response
|
||||||
|
|
||||||
async def _get_webchat_conv_id_from_conv_id(self, conversation_id: str) -> str:
|
async def _get_webchat_conv_id_from_conv_id(self, conversation_id: str) -> str:
|
||||||
|
|||||||
@@ -16,11 +16,10 @@ from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
|||||||
from astrbot.core.platform.register import platform_registry
|
from astrbot.core.platform.register import platform_registry
|
||||||
from astrbot.core.provider.register import provider_registry
|
from astrbot.core.provider.register import provider_registry
|
||||||
from astrbot.core.star.star import star_registry
|
from astrbot.core.star.star import star_registry
|
||||||
from astrbot.core import logger, html_renderer
|
from astrbot.core import logger
|
||||||
from astrbot.core.provider import Provider
|
from astrbot.core.provider import Provider
|
||||||
from astrbot.core.provider.provider import RerankProvider
|
from astrbot.core.provider.provider import RerankProvider
|
||||||
import asyncio
|
import asyncio
|
||||||
from astrbot.core.utils.t2i.network_strategy import CUSTOM_T2I_TEMPLATE_PATH
|
|
||||||
|
|
||||||
|
|
||||||
def try_cast(value: str, type_: str):
|
def try_cast(value: str, type_: str):
|
||||||
@@ -52,24 +51,6 @@ def validate_config(
|
|||||||
def validate(data: dict, metadata: dict = schema, path=""):
|
def validate(data: dict, metadata: dict = schema, path=""):
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
if key not in metadata:
|
if key not in metadata:
|
||||||
# 无 schema 的配置项,执行类型猜测
|
|
||||||
if isinstance(value, str):
|
|
||||||
try:
|
|
||||||
data[key] = int(value)
|
|
||||||
continue
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
data[key] = float(value)
|
|
||||||
continue
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if value.lower() == "true":
|
|
||||||
data[key] = True
|
|
||||||
elif value.lower() == "false":
|
|
||||||
data[key] = False
|
|
||||||
continue
|
continue
|
||||||
meta = metadata[key]
|
meta = metadata[key]
|
||||||
if "type" not in meta:
|
if "type" not in meta:
|
||||||
@@ -128,12 +109,12 @@ def validate_config(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_core:
|
if is_core:
|
||||||
for key, group in schema.items():
|
meta_all = {
|
||||||
group_meta = group.get("metadata")
|
**schema["platform_group"]["metadata"],
|
||||||
if not group_meta:
|
**schema["provider_group"]["metadata"],
|
||||||
continue
|
**schema["misc_config_group"]["metadata"],
|
||||||
# logger.info(f"验证配置: 组 {key} ...")
|
}
|
||||||
validate(data, group_meta, path=f"{key}.")
|
validate(data, meta_all)
|
||||||
else:
|
else:
|
||||||
validate(data, schema)
|
validate(data, schema)
|
||||||
|
|
||||||
@@ -143,6 +124,7 @@ def validate_config(
|
|||||||
def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False):
|
def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False):
|
||||||
"""验证并保存配置"""
|
"""验证并保存配置"""
|
||||||
errors = None
|
errors = None
|
||||||
|
logger.info(f"Saving config, is_core={is_core}")
|
||||||
try:
|
try:
|
||||||
if is_core:
|
if is_core:
|
||||||
errors, post_config = validate_config(
|
errors, post_config = validate_config(
|
||||||
@@ -156,6 +138,7 @@ def save_config(post_config: dict, config: AstrBotConfig, is_core: bool = False)
|
|||||||
raise ValueError(f"验证配置时出现异常: {e}")
|
raise ValueError(f"验证配置时出现异常: {e}")
|
||||||
if errors:
|
if errors:
|
||||||
raise ValueError(f"格式校验未通过: {errors}")
|
raise ValueError(f"格式校验未通过: {errors}")
|
||||||
|
|
||||||
config.save_config(post_config)
|
config.save_config(post_config)
|
||||||
|
|
||||||
|
|
||||||
@@ -186,56 +169,9 @@ class ConfigRoute(Route):
|
|||||||
"/config/provider/check_one": ("GET", self.check_one_provider_status),
|
"/config/provider/check_one": ("GET", self.check_one_provider_status),
|
||||||
"/config/provider/list": ("GET", self.get_provider_config_list),
|
"/config/provider/list": ("GET", self.get_provider_config_list),
|
||||||
"/config/provider/model_list": ("GET", self.get_provider_model_list),
|
"/config/provider/model_list": ("GET", self.get_provider_model_list),
|
||||||
"/config/astrbot/t2i-template/get": ("GET", self.get_t2i_template),
|
|
||||||
"/config/astrbot/t2i-template/save": ("POST", self.post_t2i_template),
|
|
||||||
"/config/astrbot/t2i-template/delete": ("DELETE", self.delete_t2i_template),
|
|
||||||
}
|
}
|
||||||
self.register_routes()
|
self.register_routes()
|
||||||
|
|
||||||
async def get_t2i_template(self):
|
|
||||||
"""获取 T2I 模板"""
|
|
||||||
try:
|
|
||||||
template = await html_renderer.network_strategy.get_template()
|
|
||||||
has_custom_template = os.path.exists(CUSTOM_T2I_TEMPLATE_PATH)
|
|
||||||
return (
|
|
||||||
Response()
|
|
||||||
.ok({"template": template, "has_custom_template": has_custom_template})
|
|
||||||
.__dict__
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return Response().error(f"获取模板失败: {str(e)}").__dict__
|
|
||||||
|
|
||||||
async def post_t2i_template(self):
|
|
||||||
"""保存 T2I 模板"""
|
|
||||||
try:
|
|
||||||
post_data = await request.json
|
|
||||||
if not post_data or "template" not in post_data:
|
|
||||||
return Response().error("缺少模板内容").__dict__
|
|
||||||
|
|
||||||
template_content = post_data["template"]
|
|
||||||
|
|
||||||
# 保存自定义模板到文件
|
|
||||||
with open(CUSTOM_T2I_TEMPLATE_PATH, "w", encoding="utf-8") as f:
|
|
||||||
f.write(template_content)
|
|
||||||
|
|
||||||
return Response().ok(message="模板保存成功").__dict__
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return Response().error(f"保存模板失败: {str(e)}").__dict__
|
|
||||||
|
|
||||||
async def delete_t2i_template(self):
|
|
||||||
"""删除自定义 T2I 模板,恢复默认模板"""
|
|
||||||
try:
|
|
||||||
if os.path.exists(CUSTOM_T2I_TEMPLATE_PATH):
|
|
||||||
os.remove(CUSTOM_T2I_TEMPLATE_PATH)
|
|
||||||
return Response().ok(message="已恢复默认模板").__dict__
|
|
||||||
else:
|
|
||||||
return Response().ok(message="未找到自定义模板文件").__dict__
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return Response().error(f"删除模板失败: {str(e)}").__dict__
|
|
||||||
|
|
||||||
async def get_abconf_list(self):
|
async def get_abconf_list(self):
|
||||||
"""获取所有 AstrBot 配置文件的列表"""
|
"""获取所有 AstrBot 配置文件的列表"""
|
||||||
abconf_list = self.acm.get_conf_list()
|
abconf_list = self.acm.get_conf_list()
|
||||||
@@ -766,6 +702,13 @@ class ConfigRoute(Route):
|
|||||||
if conf_id not in self.acm.confs:
|
if conf_id not in self.acm.confs:
|
||||||
raise ValueError(f"配置文件 {conf_id} 不存在")
|
raise ValueError(f"配置文件 {conf_id} 不存在")
|
||||||
astrbot_config = self.acm.confs[conf_id]
|
astrbot_config = self.acm.confs[conf_id]
|
||||||
|
|
||||||
|
# 保留服务端的 t2i_active_template 值
|
||||||
|
if "t2i_active_template" in astrbot_config:
|
||||||
|
post_configs["t2i_active_template"] = astrbot_config[
|
||||||
|
"t2i_active_template"
|
||||||
|
]
|
||||||
|
|
||||||
save_config(post_configs, astrbot_config, is_core=True)
|
save_config(post_configs, astrbot_config, is_core=True)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise e
|
raise e
|
||||||
|
|||||||
@@ -169,15 +169,65 @@ class ConversationRoute(Route):
|
|||||||
"""删除对话"""
|
"""删除对话"""
|
||||||
try:
|
try:
|
||||||
data = await request.get_json()
|
data = await request.get_json()
|
||||||
user_id = data.get("user_id")
|
|
||||||
cid = data.get("cid")
|
|
||||||
|
|
||||||
if not user_id or not cid:
|
# 检查是否是批量删除
|
||||||
return Response().error("缺少必要参数: user_id 和 cid").__dict__
|
if "conversations" in data:
|
||||||
await self.core_lifecycle.conversation_manager.delete_conversation(
|
# 批量删除
|
||||||
unified_msg_origin=user_id, conversation_id=cid
|
conversations = data.get("conversations", [])
|
||||||
)
|
if not conversations:
|
||||||
return Response().ok({"message": "对话删除成功"}).__dict__
|
return (
|
||||||
|
Response().error("批量删除时conversations参数不能为空").__dict__
|
||||||
|
)
|
||||||
|
|
||||||
|
deleted_count = 0
|
||||||
|
failed_items = []
|
||||||
|
|
||||||
|
for conv in conversations:
|
||||||
|
user_id = conv.get("user_id")
|
||||||
|
cid = conv.get("cid")
|
||||||
|
|
||||||
|
if not user_id or not cid:
|
||||||
|
failed_items.append(
|
||||||
|
f"user_id:{user_id}, cid:{cid} - 缺少必要参数"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self.core_lifecycle.conversation_manager.delete_conversation(
|
||||||
|
unified_msg_origin=user_id, conversation_id=cid
|
||||||
|
)
|
||||||
|
deleted_count += 1
|
||||||
|
except Exception as e:
|
||||||
|
failed_items.append(f"user_id:{user_id}, cid:{cid} - {str(e)}")
|
||||||
|
|
||||||
|
message = f"成功删除 {deleted_count} 个对话"
|
||||||
|
if failed_items:
|
||||||
|
message += f",失败 {len(failed_items)} 个"
|
||||||
|
|
||||||
|
return (
|
||||||
|
Response()
|
||||||
|
.ok(
|
||||||
|
{
|
||||||
|
"message": message,
|
||||||
|
"deleted_count": deleted_count,
|
||||||
|
"failed_count": len(failed_items),
|
||||||
|
"failed_items": failed_items,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
.__dict__
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 单个删除
|
||||||
|
user_id = data.get("user_id")
|
||||||
|
cid = data.get("cid")
|
||||||
|
|
||||||
|
if not user_id or not cid:
|
||||||
|
return Response().error("缺少必要参数: user_id 和 cid").__dict__
|
||||||
|
|
||||||
|
await self.core_lifecycle.conversation_manager.delete_conversation(
|
||||||
|
unified_msg_origin=user_id, conversation_id=cid
|
||||||
|
)
|
||||||
|
return Response().ok({"message": "对话删除成功"}).__dict__
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"删除对话失败: {str(e)}\n{traceback.format_exc()}")
|
logger.error(f"删除对话失败: {str(e)}\n{traceback.format_exc()}")
|
||||||
|
|||||||
@@ -10,7 +10,9 @@ class LogRoute(Route):
|
|||||||
super().__init__(context)
|
super().__init__(context)
|
||||||
self.log_broker = log_broker
|
self.log_broker = log_broker
|
||||||
self.app.add_url_rule("/api/live-log", view_func=self.log, methods=["GET"])
|
self.app.add_url_rule("/api/live-log", view_func=self.log, methods=["GET"])
|
||||||
self.app.add_url_rule("/api/log-history", view_func=self.log_history, methods=["GET"])
|
self.app.add_url_rule(
|
||||||
|
"/api/log-history", view_func=self.log_history, methods=["GET"]
|
||||||
|
)
|
||||||
|
|
||||||
async def log(self):
|
async def log(self):
|
||||||
async def stream():
|
async def stream():
|
||||||
@@ -48,9 +50,15 @@ class LogRoute(Route):
|
|||||||
"""获取日志历史"""
|
"""获取日志历史"""
|
||||||
try:
|
try:
|
||||||
logs = list(self.log_broker.log_cache)
|
logs = list(self.log_broker.log_cache)
|
||||||
return Response().ok(data={
|
return (
|
||||||
"logs": logs,
|
Response()
|
||||||
}).__dict__
|
.ok(
|
||||||
|
data={
|
||||||
|
"logs": logs,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
.__dict__
|
||||||
|
)
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.error(f"获取日志历史失败: {e}")
|
logger.error(f"获取日志历史失败: {e}")
|
||||||
return Response().error(f"获取日志历史失败: {e}").__dict__
|
return Response().error(f"获取日志历史失败: {e}").__dict__
|
||||||
|
|||||||
@@ -15,8 +15,24 @@ class Route:
|
|||||||
self.config = context.config
|
self.config = context.config
|
||||||
|
|
||||||
def register_routes(self):
|
def register_routes(self):
|
||||||
for route, (method, func) in self.routes.items():
|
def _add_rule(path, method, func):
|
||||||
self.app.add_url_rule(f"/api{route}", view_func=func, methods=[method])
|
# 统一添加 /api 前缀
|
||||||
|
full_path = f"/api{path}"
|
||||||
|
self.app.add_url_rule(full_path, view_func=func, methods=[method])
|
||||||
|
|
||||||
|
# 兼容字典和列表两种格式
|
||||||
|
routes_to_register = (
|
||||||
|
self.routes.items() if isinstance(self.routes, dict) else self.routes
|
||||||
|
)
|
||||||
|
|
||||||
|
for route, definition in routes_to_register:
|
||||||
|
# 兼容一个路由多个方法
|
||||||
|
if isinstance(definition, list):
|
||||||
|
for method, func in definition:
|
||||||
|
_add_rule(route, method, func)
|
||||||
|
else:
|
||||||
|
method, func = definition
|
||||||
|
_add_rule(route, method, func)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -0,0 +1,230 @@
|
|||||||
|
# astrbot/dashboard/routes/t2i.py
|
||||||
|
|
||||||
|
from dataclasses import asdict
|
||||||
|
from quart import jsonify, request
|
||||||
|
|
||||||
|
from astrbot.core import logger
|
||||||
|
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||||
|
from astrbot.core.utils.t2i.template_manager import TemplateManager
|
||||||
|
from .route import Response, Route, RouteContext
|
||||||
|
|
||||||
|
|
||||||
|
class T2iRoute(Route):
|
||||||
|
def __init__(self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle):
|
||||||
|
super().__init__(context)
|
||||||
|
self.core_lifecycle = core_lifecycle
|
||||||
|
self.config = core_lifecycle.astrbot_config
|
||||||
|
self.manager = TemplateManager()
|
||||||
|
# 使用列表保证路由注册顺序,避免 /<name> 路由优先匹配 /reset_default
|
||||||
|
self.routes = [
|
||||||
|
("/t2i/templates", ("GET", self.list_templates)),
|
||||||
|
("/t2i/templates/active", ("GET", self.get_active_template)),
|
||||||
|
("/t2i/templates/create", ("POST", self.create_template)),
|
||||||
|
("/t2i/templates/reset_default", ("POST", self.reset_default_template)),
|
||||||
|
("/t2i/templates/set_active", ("POST", self.set_active_template)),
|
||||||
|
# 动态路由应该在静态路由之后注册
|
||||||
|
(
|
||||||
|
"/t2i/templates/<name>",
|
||||||
|
[
|
||||||
|
("GET", self.get_template),
|
||||||
|
("PUT", self.update_template),
|
||||||
|
("DELETE", self.delete_template),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
]
|
||||||
|
self.register_routes()
|
||||||
|
|
||||||
|
async def list_templates(self):
|
||||||
|
"""获取所有T2I模板列表"""
|
||||||
|
try:
|
||||||
|
templates = self.manager.list_templates()
|
||||||
|
return jsonify(asdict(Response().ok(data=templates)))
|
||||||
|
except Exception as e:
|
||||||
|
response = jsonify(asdict(Response().error(str(e))))
|
||||||
|
response.status_code = 500
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def get_active_template(self):
|
||||||
|
"""获取当前激活的T2I模板"""
|
||||||
|
try:
|
||||||
|
active_template = self.config.get("t2i_active_template", "base")
|
||||||
|
return jsonify(
|
||||||
|
asdict(Response().ok(data={"active_template": active_template}))
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error in get_active_template", exc_info=True)
|
||||||
|
response = jsonify(asdict(Response().error(str(e))))
|
||||||
|
response.status_code = 500
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def get_template(self, name: str):
|
||||||
|
"""获取指定名称的T2I模板内容"""
|
||||||
|
try:
|
||||||
|
content = self.manager.get_template(name)
|
||||||
|
return jsonify(
|
||||||
|
asdict(Response().ok(data={"name": name, "content": content}))
|
||||||
|
)
|
||||||
|
except FileNotFoundError:
|
||||||
|
response = jsonify(asdict(Response().error("Template not found")))
|
||||||
|
response.status_code = 404
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
response = jsonify(asdict(Response().error(str(e))))
|
||||||
|
response.status_code = 500
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def create_template(self):
|
||||||
|
"""创建一个新的T2I模板"""
|
||||||
|
try:
|
||||||
|
data = await request.json
|
||||||
|
name = data.get("name")
|
||||||
|
content = data.get("content")
|
||||||
|
if not name or not content:
|
||||||
|
response = jsonify(
|
||||||
|
asdict(Response().error("Name and content are required."))
|
||||||
|
)
|
||||||
|
response.status_code = 400
|
||||||
|
return response
|
||||||
|
name = name.strip()
|
||||||
|
|
||||||
|
self.manager.create_template(name, content)
|
||||||
|
response = jsonify(
|
||||||
|
asdict(
|
||||||
|
Response().ok(
|
||||||
|
data={"name": name}, message="Template created successfully."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
response.status_code = 201
|
||||||
|
return response
|
||||||
|
except FileExistsError:
|
||||||
|
response = jsonify(
|
||||||
|
asdict(Response().error("Template with this name already exists."))
|
||||||
|
)
|
||||||
|
response.status_code = 409
|
||||||
|
return response
|
||||||
|
except ValueError as e:
|
||||||
|
response = jsonify(asdict(Response().error(str(e))))
|
||||||
|
response.status_code = 400
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
response = jsonify(asdict(Response().error(str(e))))
|
||||||
|
response.status_code = 500
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def update_template(self, name: str):
|
||||||
|
"""更新一个已存在的T2I模板"""
|
||||||
|
try:
|
||||||
|
name = name.strip()
|
||||||
|
data = await request.json
|
||||||
|
content = data.get("content")
|
||||||
|
if content is None:
|
||||||
|
response = jsonify(asdict(Response().error("Content is required.")))
|
||||||
|
response.status_code = 400
|
||||||
|
return response
|
||||||
|
|
||||||
|
self.manager.update_template(name, content)
|
||||||
|
|
||||||
|
# 检查更新的是否为当前激活的模板,如果是,则热重载
|
||||||
|
active_template = self.config.get("t2i_active_template", "base")
|
||||||
|
if name == active_template:
|
||||||
|
await self.core_lifecycle.reload_pipeline_scheduler("default")
|
||||||
|
message = f"模板 '{name}' 已更新并重新加载。"
|
||||||
|
else:
|
||||||
|
message = f"模板 '{name}' 已更新。"
|
||||||
|
|
||||||
|
return jsonify(asdict(Response().ok(data={"name": name}, message=message)))
|
||||||
|
except ValueError as e:
|
||||||
|
response = jsonify(asdict(Response().error(str(e))))
|
||||||
|
response.status_code = 400
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
response = jsonify(asdict(Response().error(str(e))))
|
||||||
|
response.status_code = 500
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def delete_template(self, name: str):
|
||||||
|
"""删除一个T2I模板"""
|
||||||
|
try:
|
||||||
|
name = name.strip()
|
||||||
|
self.manager.delete_template(name)
|
||||||
|
return jsonify(
|
||||||
|
asdict(Response().ok(message="Template deleted successfully."))
|
||||||
|
)
|
||||||
|
except FileNotFoundError:
|
||||||
|
response = jsonify(asdict(Response().error("Template not found.")))
|
||||||
|
response.status_code = 404
|
||||||
|
return response
|
||||||
|
except ValueError as e:
|
||||||
|
response = jsonify(asdict(Response().error(str(e))))
|
||||||
|
response.status_code = 400
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
response = jsonify(asdict(Response().error(str(e))))
|
||||||
|
response.status_code = 500
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def set_active_template(self):
|
||||||
|
"""设置当前活动的T2I模板"""
|
||||||
|
try:
|
||||||
|
data = await request.json
|
||||||
|
name = data.get("name")
|
||||||
|
if not name:
|
||||||
|
response = jsonify(asdict(Response().error("模板名称(name)不能为空。")))
|
||||||
|
response.status_code = 400
|
||||||
|
return response
|
||||||
|
|
||||||
|
# 验证模板文件是否存在
|
||||||
|
self.manager.get_template(name)
|
||||||
|
|
||||||
|
# 更新配置
|
||||||
|
config = self.config
|
||||||
|
config["t2i_active_template"] = name
|
||||||
|
config.save_config(config)
|
||||||
|
|
||||||
|
# 热重载以应用更改
|
||||||
|
await self.core_lifecycle.reload_pipeline_scheduler("default")
|
||||||
|
|
||||||
|
return jsonify(asdict(Response().ok(message=f"模板 '{name}' 已成功应用。")))
|
||||||
|
|
||||||
|
except FileNotFoundError:
|
||||||
|
response = jsonify(
|
||||||
|
asdict(Response().error(f"模板 '{name}' 不存在,无法应用。"))
|
||||||
|
)
|
||||||
|
response.status_code = 404
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error in set_active_template", exc_info=True)
|
||||||
|
response = jsonify(asdict(Response().error(str(e))))
|
||||||
|
response.status_code = 500
|
||||||
|
return response
|
||||||
|
|
||||||
|
async def reset_default_template(self):
|
||||||
|
"""重置默认的'base'模板"""
|
||||||
|
try:
|
||||||
|
self.manager.reset_default_template()
|
||||||
|
|
||||||
|
# 更新配置,将激活模板也重置为'base'
|
||||||
|
config = self.config
|
||||||
|
config["t2i_active_template"] = "base"
|
||||||
|
config.save_config(config)
|
||||||
|
|
||||||
|
# 热重载以应用更改
|
||||||
|
await self.core_lifecycle.reload_pipeline_scheduler("default")
|
||||||
|
|
||||||
|
return jsonify(
|
||||||
|
asdict(
|
||||||
|
Response().ok(
|
||||||
|
message="Default template has been reset and activated."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
except FileNotFoundError as e:
|
||||||
|
response = jsonify(asdict(Response().error(str(e))))
|
||||||
|
response.status_code = 404
|
||||||
|
return response
|
||||||
|
except Exception as e:
|
||||||
|
logger.error("Error in reset_default_template", exc_info=True)
|
||||||
|
response = jsonify(asdict(Response().error(str(e))))
|
||||||
|
response.status_code = 500
|
||||||
|
return response
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
from quart import request
|
from quart import request
|
||||||
|
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from astrbot.core.utils.io import get_local_ip_addresses
|
|||||||
from .routes import *
|
from .routes import *
|
||||||
from .routes.route import Response, RouteContext
|
from .routes.route import Response, RouteContext
|
||||||
from .routes.session_management import SessionManagementRoute
|
from .routes.session_management import SessionManagementRoute
|
||||||
|
from .routes.t2i import T2iRoute
|
||||||
|
|
||||||
APP: Quart = None
|
APP: Quart = None
|
||||||
|
|
||||||
@@ -28,10 +29,19 @@ class AstrBotDashboard:
|
|||||||
core_lifecycle: AstrBotCoreLifecycle,
|
core_lifecycle: AstrBotCoreLifecycle,
|
||||||
db: BaseDatabase,
|
db: BaseDatabase,
|
||||||
shutdown_event: asyncio.Event,
|
shutdown_event: asyncio.Event,
|
||||||
|
webui_dir: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.core_lifecycle = core_lifecycle
|
self.core_lifecycle = core_lifecycle
|
||||||
self.config = core_lifecycle.astrbot_config
|
self.config = core_lifecycle.astrbot_config
|
||||||
self.data_path = os.path.abspath(os.path.join(get_astrbot_data_path(), "dist"))
|
|
||||||
|
# 参数指定webui目录
|
||||||
|
if webui_dir and os.path.exists(webui_dir):
|
||||||
|
self.data_path = os.path.abspath(webui_dir)
|
||||||
|
else:
|
||||||
|
self.data_path = os.path.abspath(
|
||||||
|
os.path.join(get_astrbot_data_path(), "dist")
|
||||||
|
)
|
||||||
|
|
||||||
self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/")
|
self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/")
|
||||||
APP = self.app # noqa
|
APP = self.app # noqa
|
||||||
self.app.config["MAX_CONTENT_LENGTH"] = (
|
self.app.config["MAX_CONTENT_LENGTH"] = (
|
||||||
@@ -60,9 +70,8 @@ class AstrBotDashboard:
|
|||||||
self.session_management_route = SessionManagementRoute(
|
self.session_management_route = SessionManagementRoute(
|
||||||
self.context, db, core_lifecycle
|
self.context, db, core_lifecycle
|
||||||
)
|
)
|
||||||
self.persona_route = PersonaRoute(
|
self.persona_route = PersonaRoute(self.context, db, core_lifecycle)
|
||||||
self.context, db, core_lifecycle
|
self.t2i_route = T2iRoute(self.context, core_lifecycle)
|
||||||
)
|
|
||||||
|
|
||||||
self.app.add_url_rule(
|
self.app.add_url_rule(
|
||||||
"/api/plug/<path:subpath>",
|
"/api/plug/<path:subpath>",
|
||||||
|
|||||||
@@ -0,0 +1,24 @@
|
|||||||
|
# What's Changed
|
||||||
|
|
||||||
|
> 新版本介绍和用法请看 AstrBot 官方 Blog [v4.0.0 的新变化](https://blog.astrbot.app/posts/what-is-changed-in-4.0.0/)。
|
||||||
|
|
||||||
|
* Refactor: using sqlmodel(sqlchemy+pydantic) as ORM framework and switch to async-based sqlite operation by @Soulter in https://github.com/AstrBotDevs/AstrBot/pull/2294
|
||||||
|
* Fix: 当多个相同消息平台实例部署时上下文可能混乱(共享) by @Soulter in https://github.com/AstrBotDevs/AstrBot/pull/2298
|
||||||
|
* Improve: 引入全新的人格管理模式 by @Soulter in https://github.com/AstrBotDevs/AstrBot/pull/2305
|
||||||
|
* Feature: Add support to sync MCP servers from ModelScope by @Soulter in https://github.com/AstrBotDevs/AstrBot/pull/2313
|
||||||
|
* Feature: 移除 MCP 市场相关逻辑 by @Soulter in https://github.com/AstrBotDevs/AstrBot/pull/2314
|
||||||
|
* Refactor: 重构配置文件管理,以支持更灵活的、会话粒度的(基于 umo part)配置文件隔离 by @Soulter in https://github.com/AstrBotDevs/AstrBot/pull/2328
|
||||||
|
* Feature: 增加图片转述提供商配置、支持用户自定义模型模态能力 by @Soulter in https://github.com/AstrBotDevs/AstrBot/pull/2422
|
||||||
|
* Feature: 优化 WebSearch 的爬取网页速度并且支持使用 Tavily 作为搜索引擎 by @Soulter in https://github.com/AstrBotDevs/AstrBot/pull/2427
|
||||||
|
* Feature: 添加url转知识库功能 by @RC-CHN in https://github.com/AstrBotDevs/AstrBot/pull/2280
|
||||||
|
* Feature: 添加条件显示逻辑以优化插件配置项的可见性管理 by @Soulter in https://github.com/AstrBotDevs/AstrBot/pull/2433
|
||||||
|
* Feature: 支持在 WebUI 配置文件页中配置默认知识库 by @Soulter in https://github.com/AstrBotDevs/AstrBot/pull/2437
|
||||||
|
* Feature: 重构 Function Tool 管理并初步引入 Multi Agent 及 Agent Handsoff 机制 by @Soulter in https://github.com/AstrBotDevs/AstrBot/pull/2454
|
||||||
|
* feat: 添加数据迁移助手以及相关迁移方法 by @Soulter in https://github.com/AstrBotDevs/AstrBot/pull/2477
|
||||||
|
* Refactor: 重构 SharedPreference 类并采用数据库存储替换 json 存储 by @Soulter in https://github.com/AstrBotDevs/AstrBot/pull/2482
|
||||||
|
* Feature: 支持配置重排序模型(vLLM API 格式)用于 score 任务 by @Soulter in https://github.com/AstrBotDevs/AstrBot/pull/2496
|
||||||
|
* Feature: 支持在配置文件配置可用的插件组 by @Soulter in https://github.com/AstrBotDevs/AstrBot/pull/2505
|
||||||
|
* Feature: llm_tool 装饰器返回值支持 mcp 库的 tool 返回值类型 (mcp.type.CallToolResult) by @Soulter in https://github.com/AstrBotDevs/AstrBot/pull/2507
|
||||||
|
* Feature: 多 t2i 服务的随机负载均衡 by @Soulter in https://github.com/AstrBotDevs/AstrBot/pull/2529
|
||||||
|
* Improve: 扩大配置文件生效范围的自定义程度到会话粒度 by @Soulter in https://github.com/AstrBotDevs/AstrBot/pull/2532
|
||||||
|
* Feature: 支持可视化自定义 T2I 模版 by @Soulter in https://github.com/AstrBotDevs/AstrBot/pull/2581
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
# What's Changed
|
||||||
|
|
||||||
|
> 如果已经使用自定义文转图模板,此次升级之后将会被覆盖,请提前备份。路径在 `astrbot/core/utils/t2i/template` 目录下。
|
||||||
|
|
||||||
|
0. ‼️‼️‼️ 修复 LLM 仍会调用已禁用的工具的问题 ([#2729](https://github.com/Soulter/AstrBot/issues/2729))
|
||||||
|
1. ‼️ 修复 WebChat 下,Agent 长时任务时,SSE 连接自动断开的问题
|
||||||
|
2. ‼️ 修复自定义文转图模板更新版本后会被覆盖的问题 ([#2677](https://github.com/Soulter/AstrBot/issues/2677))
|
||||||
|
3. 修复 Satori 适配器教程链接 ([#2668](https://github.com/Soulter/AstrBot/issues/2668))
|
||||||
|
4. 修复插件页表格视图中,点击状态字段表头排序不起作用的问题 ([#2714](https://github.com/Soulter/AstrBot/issues/2714))
|
||||||
|
5. 修复工具调用时的 content 内容在重新加载后没有显示在 webchat 的问题 ([#2727](https://github.com/Soulter/AstrBot/issues/2727))
|
||||||
|
6. 允许添加多个 tavily API Key 进行轮询 ([#2725](https://github.com/Soulter/AstrBot/issues/2725))
|
||||||
|
7. 添加 --webui-dir 启动参数以支持指定 WebUI 构建文件目录 ([#2680](https://github.com/Soulter/AstrBot/issues/2680))
|
||||||
|
8. 兼容指令名和第一个参数之间没有空格的情况 ([#2650](https://github.com/Soulter/AstrBot/issues/2650))
|
||||||
|
9. 支持在 WebUI 自定义 OpenAI API extra_body 参数 ([#2719](https://github.com/Soulter/AstrBot/issues/2719))
|
||||||
|
10. 增加 on_platform_loaded 钩子以在消息平台适配器实例化完成后触发 ([#2651](https://github.com/Soulter/AstrBot/issues/2651))
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
# What's Changed
|
||||||
|
|
||||||
|
修复了 v4.1.0 `model referenced before assignment` 的错误。
|
||||||
|
|
||||||
|
> 如果已经使用自定义文转图模板,此次升级之后将会被覆盖,请提前备份。路径在 `astrbot/core/utils/t2i/template` 目录下。
|
||||||
@@ -0,0 +1,9 @@
|
|||||||
|
# What's Changed
|
||||||
|
|
||||||
|
0. ‼️‼️‼️ fix: 修复 4.1.1 版本下,指令调用异常的问题
|
||||||
|
1. ‼️‼️ fix: 修复多配置文件配置的不同人格无法生效的问题 ([#2739](https://github.com/AstrBotDevs/AstrBot/issues/2739))
|
||||||
|
2. ‼️‼️ fix: 修复人格所选择的工具无法应用的问题 ([#2739](https://github.com/AstrBotDevs/AstrBot/issues/2739))
|
||||||
|
3. ‼️‼️ fix: 修复平台配置下的「内容安全」组无法生效 ([#2751](https://github.com/AstrBotDevs/AstrBot/issues/2751))
|
||||||
|
4. perf: 检查服务提供商可用性时跳过未启用的提供商,解决部分 `provider with id xxx not found` 的问题
|
||||||
|
|
||||||
|
fixes: [#2724](https://github.com/AstrBotDevs/AstrBot/issues/2724)
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
# What's Changed
|
||||||
|
|
||||||
|
0. ‼️ fix: 修复 4.0.0 版本之后,配置默认 TTS 或者 STT 模型之后仍无法生效的问题 ([#2758](https://github.com/Soulter/AstrBot/issues/2758))
|
||||||
|
1. ‼️ fix: 修复分段回复时,引用消息单独发送导致第一条消息内容为空的问题 ([#2757](https://github.com/Soulter/AstrBot/issues/2757))
|
||||||
|
2. feat: 支持在 WebUI 复制提供商配置以简化操作 ([#2767](https://github.com/Soulter/AstrBot/issues/2767))
|
||||||
|
3. fix: handle image value correctly for mcp BlobResourceContents ([#2753](https://github.com/Soulter/AstrBot/issues/2753))
|
||||||
|
4. feat: 增加 QQ 群名称识别到 system prompt, 并提供相应的配置 ([#2770](https://github.com/Soulter/AstrBot/issues/2770))
|
||||||
|
5. fix: parameter type/default handling in CommandFilter
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
# What's Changed
|
||||||
|
|
||||||
|
0. ‼️ fix: 修复 4.0.0 版本之后,配置默认 TTS 或者 STT 模型之后仍无法生效的问题 ([#2758](https://github.com/Soulter/AstrBot/issues/2758))
|
||||||
|
1. ‼️ fix: 修复分段回复时,引用消息单独发送导致第一条消息内容为空的问题 ([#2757](https://github.com/Soulter/AstrBot/issues/2757))
|
||||||
|
2. feat: 支持在 WebUI 复制提供商配置以简化操作 ([#2767](https://github.com/Soulter/AstrBot/issues/2767))
|
||||||
|
3. fix: handle image value correctly for mcp BlobResourceContents ([#2753](https://github.com/Soulter/AstrBot/issues/2753))
|
||||||
|
4. feat: 增加 QQ 群名称识别到 system prompt, 并提供相应的配置 ([#2770](https://github.com/Soulter/AstrBot/issues/2770))
|
||||||
|
5. fix: 修复 4.1.3 的异常问题
|
||||||
|
|
||||||
|
**总之上个版本有很严重的 bug 赶快更新!**
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
# What's Changed
|
||||||
|
|
||||||
|
0. feat: 新增 Misskey 平台适配器 ([#2774](https://github.com/AstrBotDevs/AstrBot/issues/2774))
|
||||||
|
1. fix: 修复aiocqhttp适配器at会获取群昵称而消息不会获取的逻辑不一致 ([#2769](https://github.com/AstrBotDevs/AstrBot/issues/2769))
|
||||||
|
2. fix: 修复「对话管理」页面的关键词搜索功能失效的问题并优化一些 UI 样式 ([#2837](https://github.com/AstrBotDevs/AstrBot/issues/2837))
|
||||||
|
3. fix: 识别「引用消息」的图片时优先使用默认图片转述提供商 ([#2836](https://github.com/AstrBotDevs/AstrBot/issues/2836))
|
||||||
|
5. fix: 修复 Telegram 下流式传输时,第一次输出的内容会被覆盖掉的问题
|
||||||
|
6. perf: 优化统计页内存占用和消息数据趋势的样式 ([#2826](https://github.com/AstrBotDevs/AstrBot/issues/2826))
|
||||||
|
7. perf: 优化 「插件页」、「对话管理页」、「会话管理页」的样式
|
||||||
|
8. fix: on_tool_end hook unavailable
|
||||||
|
9. feat: add audioop-lts dependencies ([#2809](https://github.com/AstrBotDevs/AstrBot/issues/2809))
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
# What's Changed
|
||||||
|
|
||||||
|
1. fix: 修复在某些情况下,出现 「返回的 Provider 不是 Provider 类型的错误」
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
# What's Changed
|
||||||
|
|
||||||
|
1. perf: 优化 WebChat 等组件的 UI 风格
|
||||||
|
2. fix: 修复 4.1.6 版本可能无法点击更新按钮的问题
|
||||||
|
3. fix: 修复更新开发版的时候,可能无法同时更新 WebUI 的问题
|
||||||
|
4. feat: 支持在「对话数据」页批量删除对话
|
||||||
|
5. fix: 修复部分错误地显示「格式校验未通过」的问题
|
||||||
|
6. perf: WebChat 支持手动填写模型名称
|
||||||
Binary file not shown.
|
After Width: | Height: | Size: 26 KiB |
Binary file not shown.
|
After Width: | Height: | Size: 21 KiB |
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user