Compare commits
124 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 062af1ac08 | |||
| 79d38f9597 | |||
| 4d186baa35 | |||
| e54eaab842 | |||
| 43b6297b5d | |||
| c20f4f5adf | |||
| dc1f222cd2 | |||
| c2b687212c | |||
| 849913276d | |||
| 23579c1e4a | |||
| e031161fd4 | |||
| 4800ee6c0a | |||
| d3a7fef9b0 | |||
| 40822fe77a | |||
| 837b670213 | |||
| 57ce69f3fb | |||
| be022c4894 | |||
| 8a366964bb | |||
| ee86b68470 | |||
| 60352307aa | |||
| 3ebd2f746f | |||
| 1c1a65b637 | |||
| 010e60d029 | |||
| 7a25568861 | |||
| 5f4f913661 | |||
| ccd0e34a53 | |||
| 72f1ffccd3 | |||
| ea7a52945f | |||
| 89d4d1351a | |||
| b757c91d93 | |||
| 27203d7a4d | |||
| 9ad4e18ac5 | |||
| fcdc8f3ce7 | |||
| 78b994b84a | |||
| 58bfc677e2 | |||
| 7d17285a0c | |||
| e9eb00a0d4 | |||
| 48d07af574 | |||
| 2fc62efd88 | |||
| be516d75bd | |||
| 951d5fde85 | |||
| 1389abc052 | |||
| 19ad67a77f | |||
| 641f308344 | |||
| 9f097fa4d5 | |||
| 5ad362c52b | |||
| 614f238a61 | |||
| dec91950bc | |||
| 6cef9c23f0 | |||
| 3f568bf136 | |||
| 5484b421ce | |||
| 02f21e07d3 | |||
| fff1f23a83 | |||
| a056ec0d38 | |||
| 2eb9e5dde3 | |||
| 627d2a4701 | |||
| 76895fe86d | |||
| 64c3c85780 | |||
| 7288348857 | |||
| 62e73299b1 | |||
| fe76c41ed8 | |||
| 1a92edf8be | |||
| b63b606a4e | |||
| 8e2ef3d22b | |||
| c6c4a32283 | |||
| b70b3b158e | |||
| 3d59ab8108 | |||
| b6c3089510 | |||
| bd92aac280 | |||
| 5299e802e9 | |||
| 8e5a57d7dd | |||
| beaa324fb6 | |||
| 79e64fe206 | |||
| 93f525e3fe | |||
| aacb803c64 | |||
| 8a0665b222 | |||
| 20e41a7f73 | |||
| 93a1699a35 | |||
| c33c07e4af | |||
| c7484d0cc9 | |||
| fb85a7bb35 | |||
| 42ff9a4d34 | |||
| 005e9eae7c | |||
| 3e325debcc | |||
| a221de9a2b | |||
| 32b0cc1865 | |||
| bbf85f8a12 | |||
| 67a0172b28 | |||
| fb19d4d45b | |||
| a156b1af14 | |||
| a604b4943c | |||
| 3f0b6435d9 | |||
| e0f029e2cb | |||
| 89d3fd5fab | |||
| a38b00be6b | |||
| 0e8d52b591 | |||
| 298c77740d | |||
| c681aae8ee | |||
| faef98b089 | |||
| 84a3e0a30b | |||
| 69bd553ce0 | |||
| fd0c0f8975 | |||
| 860ceb06b4 | |||
| ecf501bf72 | |||
| 81a2ed1e25 | |||
| 76ab28338a | |||
| 9a56c9630f | |||
| 53b9497c18 | |||
| 750b16b6ee | |||
| 0ee3e0779a | |||
| 333c2d9299 | |||
| ad37ff5048 | |||
| 33f86f3bde | |||
| 8acb969a49 | |||
| b74b5933b8 | |||
| 681c556b7e | |||
| 1746684e52 | |||
| 0b93d06555 | |||
| 8a8b8c7c27 | |||
| 6b6577006d | |||
| 5c14ebb049 | |||
| 9717a736b1 | |||
| 9e7fe773bd | |||
| 5c4326c302 |
+10
-2
@@ -9,12 +9,20 @@ RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3-dev \
|
||||
libffi-dev \
|
||||
libssl-dev \
|
||||
ca-certificates \
|
||||
bash \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN python -m pip install -r requirements.txt --no-cache-dir
|
||||
RUN python -m pip install uv
|
||||
RUN uv pip install -r requirements.txt --no-cache-dir --system
|
||||
RUN uv pip install socksio uv pyffmpeg pilk --no-cache-dir --system
|
||||
|
||||
RUN python -m pip install socksio wechatpy cryptography --no-cache-dir
|
||||
# 释出 ffmpeg
|
||||
RUN python -c "from pyffmpeg import FFmpeg; ff = FFmpeg();"
|
||||
|
||||
# add /root/.pyffmpeg/bin/ffmpeg to PATH, inorder to use ffmpeg
|
||||
RUN echo 'export PATH=$PATH:/root/.pyffmpeg/bin' >> ~/.bashrc
|
||||
|
||||
EXPOSE 6185
|
||||
EXPOSE 6186
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
FROM python:3.10-slim
|
||||
|
||||
WORKDIR /AstrBot
|
||||
|
||||
COPY . /AstrBot/
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
gcc \
|
||||
build-essential \
|
||||
python3-dev \
|
||||
libffi-dev \
|
||||
libssl-dev \
|
||||
curl \
|
||||
unzip \
|
||||
ca-certificates \
|
||||
bash \
|
||||
&& apt-get clean \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Installation of Node.js
|
||||
ENV NVM_DIR="/root/.nvm"
|
||||
RUN curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.2/install.sh | bash && \
|
||||
. "$NVM_DIR/nvm.sh" && \
|
||||
nvm install 22 && \
|
||||
nvm use 22
|
||||
RUN /bin/bash -c ". \"$NVM_DIR/nvm.sh\" && node -v && npm -v"
|
||||
|
||||
RUN python -m pip install uv
|
||||
RUN uv pip install -r requirements.txt --no-cache-dir --system
|
||||
RUN uv pip install socksio uv pyffmpeg --no-cache-dir --system
|
||||
|
||||
EXPOSE 6185
|
||||
EXPOSE 6186
|
||||
|
||||
CMD ["python", "main.py"]
|
||||
@@ -1,6 +1,6 @@
|
||||
<p align="center">
|
||||
|
||||

|
||||
|
||||

|
||||
|
||||
</p>
|
||||
|
||||
@@ -15,7 +15,7 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||

|
||||

|
||||
|
||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
@@ -30,20 +30,26 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
||||
<!-- [](https://codecov.io/gh/Soulter/AstrBot)
|
||||
-->
|
||||
|
||||
## ✨ 近期更新
|
||||
|
||||
1. AstrBot 现已支持接入 [MCP](https://modelcontextprotocol.io/) 服务器!
|
||||
|
||||
## ✨ 主要功能
|
||||
|
||||
> [!NOTE]
|
||||
> 🪧 我们正基于前沿科研成果,设计并实现适用于角色扮演和情感陪伴的长短期记忆模型及情绪控制模型,旨在提升对话的真实性与情感表达能力。敬请期待 `v3.6.0` 版本!
|
||||
|
||||
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力,支持图片理解、语音转文字(Whisper)。
|
||||
2. **支持 MCP**。AstrBot 现已支持接入 MCP 服务器。
|
||||
3. **多消息平台接入**。支持接入 QQ(OneBot)、QQ 频道、微信(Gewechat)、飞书、Telegram。后续将支持钉钉、Discord、WhatsApp、小爱音响。支持速率限制、白名单、关键词过滤、百度内容审核。
|
||||
4. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://astrbot.app/others/dify.html),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
|
||||
5. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
|
||||
6. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat,可在面板上与大模型对话。
|
||||
7. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
|
||||
2. **多消息平台接入**。支持接入 QQ(OneBot)、QQ 频道、微信(Gewechat)、飞书、Telegram。后续将支持钉钉、Discord、WhatsApp、小爱音响。支持速率限制、白名单、关键词过滤、百度内容审核。
|
||||
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://dify.ai/),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
|
||||
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
|
||||
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat,可在面板上与大模型对话。
|
||||
6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
|
||||
|
||||
> [!TIP]
|
||||
> 管理面板在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
|
||||
> WebUI 在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
|
||||
>
|
||||
> 用户名: `astrbot`, 密码: `astrbot`。未配置 LLM,无法在聊天页使用大模型。(不要再修改 demo 的登录密码了 😭)
|
||||
> 用户名: `astrbot`, 密码: `astrbot`。
|
||||
|
||||
## ✨ 使用方式
|
||||
|
||||
@@ -67,7 +73,15 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
||||
|
||||
#### 手动部署
|
||||
|
||||
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
||||
推荐使用 `uv`。
|
||||
|
||||
```bash
|
||||
git clone https://github.com/AstrBotDevs/AstrBot && cd AstrBot
|
||||
pip install uv
|
||||
uv run main.py
|
||||
```
|
||||
|
||||
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
||||
|
||||
#### Replit 部署
|
||||
|
||||
@@ -93,7 +107,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
|
||||
|
||||
| 名称 | 支持性 | 类型 | 备注 |
|
||||
| -------- | ------- | ------- | ------- |
|
||||
| OpenAI API | ✔ | 文本生成 | 同时也支持 DeepSeek、Google Gemini、GLM(智谱)、Moonshot(月之暗面)、阿里云百炼、硅基流动、xAI 等所有兼容 OpenAI API 的服务 |
|
||||
| OpenAI API | ✔ | 文本生成 | 也支持 DeepSeek、Google Gemini、GLM、Kimi、硅基流动、xAI 等兼容 OpenAI API 的服务 |
|
||||
| Claude API | ✔ | 文本生成 | |
|
||||
| Google Gemini API | ✔ | 文本生成 | |
|
||||
| Dify | ✔ | LLMOps | |
|
||||
@@ -135,38 +149,36 @@ pre-commit install
|
||||
|
||||
## ✨ Demo
|
||||
|
||||
> [!NOTE]
|
||||
> 代码执行器的文件输入/输出目前仅测试了 Napcat(QQ), Lagrange(QQ)
|
||||
|
||||
<div align='center'>
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
|
||||
|
||||
_✨基于 Docker 的沙箱化代码执行器(Beta 测试中)✨_
|
||||
_✨基于 Docker 的沙箱化代码执行器(Beta 测试)✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
|
||||
|
||||
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
|
||||
|
||||
_✨ 自然语言待办事项 ✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
|
||||
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
|
||||
|
||||
_✨ 插件系统——部分插件展示 ✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width=600>
|
||||
<img src="https://github.com/user-attachments/assets/0cdbf564-2f59-4da5-b524-ce0e7ef3d978" width=600>
|
||||
|
||||
_✨ 管理面板 ✨_
|
||||
|
||||

|
||||
|
||||
_✨ 内置 Web Chat,在线与机器人交互 ✨_
|
||||
_✨ WebUI ✨_
|
||||
|
||||
</div>
|
||||
|
||||
## ❤️ Special Thanks
|
||||
|
||||
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
|
||||
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=AstrBotDevs/AstrBot" />
|
||||
</a>
|
||||
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
> [!TIP]
|
||||
|
||||
+1
-1
@@ -28,7 +28,7 @@ AstrBot is a loosely coupled, asynchronous chatbot and development framework tha
|
||||
|
||||
1. **LLM Conversations** - Supports various LLMs including OpenAI API, Google Gemini, Llama, Deepseek, ChatGLM, etc. Enables local model deployment via Ollama/LLMTuner. Features multi-turn dialogues, personality contexts, multimodal capabilities (image understanding), and speech-to-text (Whisper).
|
||||
2. **Multi-platform Integration** - Supports QQ (OneBot), QQ Channels, WeChat (Gewechat), Feishu, and Telegram. Planned support for DingTalk, Discord, WhatsApp, and Xiaomi Smart Speakers. Includes rate limiting, whitelisting, keyword filtering, and Baidu content moderation.
|
||||
3. **Agent Capabilities** - Native support for code execution, natural language TODO lists, web search. Integrates with [Dify Platform](https://astrbot.app/others/dify.html) for easy access to Dify assistants/knowledge bases/workflows.
|
||||
3. **Agent Capabilities** - Native support for code execution, natural language TODO lists, web search. Integrates with [Dify Platform](https://dify.ai/) for easy access to Dify assistants/knowledge bases/workflows.
|
||||
4. **Plugin System** - Optimized plugin mechanism with minimal development effort. Supports multiple installed plugins.
|
||||
5. **Web Dashboard** - Visual configuration management, plugin controls, logging, and WebChat interface for direct LLM interaction.
|
||||
6. **High Stability & Modularity** - Event bus and pipeline architecture ensures high modularization and loose coupling.
|
||||
|
||||
+1
-1
@@ -28,7 +28,7 @@ AstrBot は、疎結合、非同期、複数のメッセージプラットフォ
|
||||
|
||||
1. **大規模言語モデルの対話**。OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM など、さまざまな大規模言語モデルをサポートし、Ollama、LLMTuner を介してローカルにデプロイされた大規模モデルをサポートします。多輪対話、人格シナリオ、多モーダル機能を備え、画像理解、音声からテキストへの変換(Whisper)をサポートします。
|
||||
2. **複数のメッセージプラットフォームの接続**。QQ(OneBot)、QQ チャンネル、WeChat(Gewechat)、Feishu、Telegram への接続をサポートします。今後、DingTalk、Discord、WhatsApp、Xiaoai 音響をサポートする予定です。レート制限、ホワイトリスト、キーワードフィルタリング、Baidu コンテンツ監査をサポートします。
|
||||
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://astrbot.app/others/dify.html)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
|
||||
3. **エージェント**。一部のエージェント機能をネイティブにサポートし、コードエグゼキューター、自然言語タスク、ウェブ検索などを提供します。[Dify プラットフォーム](https://dify.ai/)と連携し、Dify スマートアシスタント、ナレッジベース、Dify ワークフローを簡単に接続できます。
|
||||
4. **プラグインの拡張**。深く最適化されたプラグインメカニズムを備え、[プラグインの開発](https://astrbot.app/dev/plugin.html)をサポートし、機能を拡張できます。複数のプラグインのインストールをサポートします。
|
||||
5. **ビジュアル管理パネル**。設定の視覚的な変更、プラグイン管理、ログの表示などをサポートし、設定の難易度を低減します。WebChat を統合し、パネル上で大規模モデルと対話できます。
|
||||
6. **高い安定性と高いモジュール性**。イベントバスとパイプラインに基づくアーキテクチャ設計により、高度にモジュール化され、低結合です。
|
||||
|
||||
@@ -2,11 +2,7 @@ from astrbot.core.star.register import (
|
||||
register_star as register, # 注册插件(Star)
|
||||
)
|
||||
|
||||
from astrbot.core.star import Context, Star
|
||||
from astrbot.core.star import Context, Star, StarTools
|
||||
from astrbot.core.star.config import *
|
||||
|
||||
__all__ = [
|
||||
"register",
|
||||
"Context",
|
||||
"Star",
|
||||
]
|
||||
__all__ = ["register", "Context", "Star", "StarTools"]
|
||||
|
||||
@@ -8,6 +8,7 @@ from astrbot.core.db.sqlite import SQLiteDatabase
|
||||
from astrbot.core.config.default import DB_PATH
|
||||
from astrbot.core.config import AstrBotConfig
|
||||
|
||||
# 初始化数据存储文件夹
|
||||
os.makedirs("data", exist_ok=True)
|
||||
|
||||
astrbot_config = AstrBotConfig()
|
||||
@@ -19,8 +20,11 @@ if os.environ.get("TESTING", ""):
|
||||
logger.setLevel("DEBUG")
|
||||
|
||||
db_helper = SQLiteDatabase(DB_PATH)
|
||||
sp = SharedPreferences() # 简单的偏好设置存储
|
||||
sp = (
|
||||
SharedPreferences()
|
||||
) # 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
|
||||
pip_installer = PipInstaller(astrbot_config.get("pip_install_arg", ""))
|
||||
web_chat_queue = asyncio.Queue(maxsize=32)
|
||||
web_chat_back_queue = asyncio.Queue(maxsize=32)
|
||||
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
|
||||
DEMO_MODE = os.getenv("DEMO_MODE", False)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
|
||||
"""
|
||||
|
||||
VERSION = "3.5.0"
|
||||
VERSION = "3.5.2"
|
||||
DB_PATH = "data/data_v3.db"
|
||||
|
||||
# 默认配置
|
||||
@@ -98,6 +98,7 @@ DEFAULT_CONFIG = {
|
||||
"plugin_repo_mirror": "",
|
||||
"knowledge_db": {},
|
||||
"persona": [],
|
||||
"timezone": "",
|
||||
}
|
||||
|
||||
|
||||
@@ -519,7 +520,14 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": "https://generativelanguage.googleapis.com/",
|
||||
"timeout": 120,
|
||||
"model_config": {
|
||||
"model": "gemini-1.5-flash",
|
||||
"model": "gemini-2.0-flash-exp",
|
||||
},
|
||||
"gm_resp_image_modal": False,
|
||||
"gm_safety_settings": {
|
||||
"harassment": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"hate_speech": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"dangerous_content": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
},
|
||||
},
|
||||
"DeepSeek": {
|
||||
@@ -670,12 +678,82 @@ CONFIG_METADATA_2 = {
|
||||
"fishaudio-tts-character": "可莉",
|
||||
"timeout": "20",
|
||||
},
|
||||
"阿里云百炼_TTS(API)": {
|
||||
"id": "dashscope_tts",
|
||||
"type": "dashscope_tts",
|
||||
"enable": False,
|
||||
"api_key": "",
|
||||
"model": "cosyvoice-v1",
|
||||
"dashscope_tts_voice": "loongstella",
|
||||
"timeout": "20",
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"dashscope_tts_voice": {
|
||||
"description": "语音合成模型",
|
||||
"type": "string",
|
||||
"hint": "阿里云百炼语音合成模型名称。具体可参考 https://help.aliyun.com/zh/model-studio/developer-reference/cosyvoice-python-api 等内容",
|
||||
},
|
||||
"gm_resp_image_modal": {
|
||||
"description": "启用图片模态",
|
||||
"type": "bool",
|
||||
"hint": "启用后,将支持返回图片内容。需要模型支持,否则会报错。具体支持模型请查看 Google Gemini 官方网站。温馨提示,如果您需要生成图片,请关闭 `启用群员识别` 配置获得更好的效果。",
|
||||
},
|
||||
"gm_safety_settings": {
|
||||
"description": "安全过滤器",
|
||||
"type": "object",
|
||||
"hint": "设置模型输入的内容安全过滤级别。过滤级别分类为NONE(不屏蔽)、HIGH(高风险时屏蔽)、MEDIUM_AND_ABOVE(中等风险及以上屏蔽)、LOW_AND_ABOVE(低风险及以上时屏蔽),具体参见Gemini API文档。",
|
||||
"items": {
|
||||
"harassment": {
|
||||
"description": "骚扰内容",
|
||||
"type": "string",
|
||||
"hint": "负面或有害评论",
|
||||
"options": [
|
||||
"BLOCK_NONE",
|
||||
"BLOCK_ONLY_HIGH",
|
||||
"BLOCK_MEDIUM_AND_ABOVE",
|
||||
"BLOCK_LOW_AND_ABOVE",
|
||||
],
|
||||
},
|
||||
"hate_speech": {
|
||||
"description": "仇恨言论",
|
||||
"type": "string",
|
||||
"hint": "粗鲁、无礼或亵渎性质内容",
|
||||
"options": [
|
||||
"BLOCK_NONE",
|
||||
"BLOCK_ONLY_HIGH",
|
||||
"BLOCK_MEDIUM_AND_ABOVE",
|
||||
"BLOCK_LOW_AND_ABOVE",
|
||||
],
|
||||
},
|
||||
"sexually_explicit": {
|
||||
"description": "露骨色情内容",
|
||||
"type": "string",
|
||||
"hint": "包含性行为或其他淫秽内容的引用",
|
||||
"options": [
|
||||
"BLOCK_NONE",
|
||||
"BLOCK_ONLY_HIGH",
|
||||
"BLOCK_MEDIUM_AND_ABOVE",
|
||||
"BLOCK_LOW_AND_ABOVE",
|
||||
],
|
||||
},
|
||||
"dangerous_content": {
|
||||
"description": "危险内容",
|
||||
"type": "string",
|
||||
"hint": "宣扬、助长或鼓励有害行为的信息",
|
||||
"options": [
|
||||
"BLOCK_NONE",
|
||||
"BLOCK_ONLY_HIGH",
|
||||
"BLOCK_MEDIUM_AND_ABOVE",
|
||||
"BLOCK_LOW_AND_ABOVE",
|
||||
],
|
||||
},
|
||||
},
|
||||
},
|
||||
"rag_options": {
|
||||
"description": "RAG 选项",
|
||||
"type": "object",
|
||||
"hint": "检索知识库设置, 非必填。仅 Agent 应用类型支持(智能体应用, 包括 RAG 应用)",
|
||||
"hint": "检索知识库设置, 非必填。仅 Agent 应用类型支持(智能体应用, 包括 RAG 应用)。阿里云百炼应用开启此功能后将无法多轮对话。",
|
||||
"items": {
|
||||
"pipeline_ids": {
|
||||
"description": "知识库 ID 列表",
|
||||
@@ -1095,6 +1173,12 @@ CONFIG_METADATA_2 = {
|
||||
"type": "string",
|
||||
"hint": "启用后,会以添加环境变量的方式设置代理。格式为 `http://ip:port`",
|
||||
},
|
||||
"timezone": {
|
||||
"description": "时区",
|
||||
"type": "string",
|
||||
"obvious_hint": True,
|
||||
"hint": "时区设置。请填写 IANA 时区名称, 如 Asia/Shanghai, 为空时使用系统默认时区。所有时区请查看: https://data.iana.org/time-zones/tzdb-2021a/zone1970.tab",
|
||||
},
|
||||
"log_level": {
|
||||
"description": "控制台日志级别",
|
||||
"type": "string",
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
"""
|
||||
AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 格式的shared_preferences, 另外一个是数据库
|
||||
|
||||
在 AstrBot 中, 会话和对话是独立的, 会话用于标记对话窗口, 例如群聊"123456789"可以建立一个会话,
|
||||
在一个会话中可以建立多个对话, 并且支持对话的切换和删除
|
||||
"""
|
||||
|
||||
import uuid
|
||||
import json
|
||||
import asyncio
|
||||
@@ -11,24 +18,34 @@ class ConversationManager:
|
||||
"""负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""
|
||||
|
||||
def __init__(self, db_helper: BaseDatabase):
|
||||
# session_conversations 字典记录会话ID-对话ID 映射关系
|
||||
self.session_conversations: Dict[str, str] = sp.get("session_conversation", {})
|
||||
self.db = db_helper
|
||||
self.save_interval = 60 # 每 60 秒保存一次
|
||||
self._start_periodic_save()
|
||||
|
||||
def _start_periodic_save(self):
|
||||
"""启动定时保存任务"""
|
||||
asyncio.create_task(self._periodic_save())
|
||||
|
||||
async def _periodic_save(self):
|
||||
"""定时保存会话对话映射关系到存储中"""
|
||||
while True:
|
||||
await asyncio.sleep(self.save_interval)
|
||||
self._save_to_storage()
|
||||
|
||||
def _save_to_storage(self):
|
||||
"""保存会话对话映射关系到存储中"""
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
|
||||
async def new_conversation(self, unified_msg_origin: str) -> str:
|
||||
"""新建对话,并将当前会话的对话转移到新对话"""
|
||||
"""新建对话,并将当前会话的对话转移到新对话
|
||||
|
||||
Args:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
Returns:
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
"""
|
||||
conversation_id = str(uuid.uuid4())
|
||||
self.db.new_conversation(user_id=unified_msg_origin, cid=conversation_id)
|
||||
self.session_conversations[unified_msg_origin] = conversation_id
|
||||
@@ -36,14 +53,24 @@ class ConversationManager:
|
||||
return conversation_id
|
||||
|
||||
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
|
||||
"""切换会话的对话"""
|
||||
"""切换会话的对话
|
||||
|
||||
Args:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
"""
|
||||
self.session_conversations[unified_msg_origin] = conversation_id
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
|
||||
async def delete_conversation(
|
||||
self, unified_msg_origin: str, conversation_id: str = None
|
||||
):
|
||||
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话"""
|
||||
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话
|
||||
|
||||
Args:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
"""
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
if conversation_id:
|
||||
self.db.delete_conversation(user_id=unified_msg_origin, cid=conversation_id)
|
||||
@@ -51,23 +78,48 @@ class ConversationManager:
|
||||
sp.put("session_conversation", self.session_conversations)
|
||||
|
||||
async def get_curr_conversation_id(self, unified_msg_origin: str) -> str:
|
||||
"""获取会话当前的对话 ID"""
|
||||
"""获取会话当前的对话 ID
|
||||
|
||||
Args:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
Returns:
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
"""
|
||||
return self.session_conversations.get(unified_msg_origin, None)
|
||||
|
||||
async def get_conversation(
|
||||
self, unified_msg_origin: str, conversation_id: str
|
||||
) -> Conversation:
|
||||
"""获取会话的对话"""
|
||||
"""获取会话的对话
|
||||
|
||||
Args:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
Returns:
|
||||
conversation (Conversation): 对话对象
|
||||
"""
|
||||
return self.db.get_conversation_by_user_id(unified_msg_origin, conversation_id)
|
||||
|
||||
async def get_conversations(self, unified_msg_origin: str) -> List[Conversation]:
|
||||
"""获取会话的所有对话"""
|
||||
"""获取会话的所有对话
|
||||
|
||||
Args:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
Returns:
|
||||
conversations (List[Conversation]): 对话对象列表
|
||||
"""
|
||||
return self.db.get_conversations(unified_msg_origin)
|
||||
|
||||
async def update_conversation(
|
||||
self, unified_msg_origin: str, conversation_id: str, history: List[Dict]
|
||||
):
|
||||
"""更新会话的对话"""
|
||||
"""更新会话的对话
|
||||
|
||||
Args:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
|
||||
"""
|
||||
if conversation_id:
|
||||
self.db.update_conversation(
|
||||
user_id=unified_msg_origin,
|
||||
@@ -76,7 +128,12 @@ class ConversationManager:
|
||||
)
|
||||
|
||||
async def update_conversation_title(self, unified_msg_origin: str, title: str):
|
||||
"""更新会话的对话标题"""
|
||||
"""更新会话的对话标题
|
||||
|
||||
Args:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
title (str): 对话标题
|
||||
"""
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
if conversation_id:
|
||||
self.db.update_conversation_title(
|
||||
@@ -86,7 +143,12 @@ class ConversationManager:
|
||||
async def update_conversation_persona_id(
|
||||
self, unified_msg_origin: str, persona_id: str
|
||||
):
|
||||
"""更新会话的对话 Persona ID"""
|
||||
"""更新会话的对话 Persona ID
|
||||
|
||||
Args:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
persona_id (str): 对话 Persona ID
|
||||
"""
|
||||
conversation_id = self.session_conversations.get(unified_msg_origin)
|
||||
if conversation_id:
|
||||
self.db.update_conversation_persona_id(
|
||||
@@ -96,6 +158,14 @@ class ConversationManager:
|
||||
async def get_human_readable_context(
|
||||
self, unified_msg_origin, conversation_id, page=1, page_size=10
|
||||
):
|
||||
"""获取人类可读的上下文
|
||||
|
||||
Args:
|
||||
unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
|
||||
conversation_id (str): 对话 ID, 是 uuid 格式的字符串
|
||||
page (int): 页码
|
||||
page_size (int): 每页大小
|
||||
"""
|
||||
conversation = await self.get_conversation(unified_msg_origin, conversation_id)
|
||||
history = json.loads(conversation.history)
|
||||
|
||||
|
||||
@@ -1,3 +1,14 @@
|
||||
"""
|
||||
Astrbot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
|
||||
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus等。
|
||||
该类还负责加载和执行插件, 以及处理事件总线的分发。
|
||||
|
||||
工作流程:
|
||||
1. 初始化所有组件
|
||||
2. 启动事件总线和任务, 所有任务都在这里运行
|
||||
3. 执行启动完成事件钩子
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import asyncio
|
||||
import time
|
||||
@@ -24,31 +35,51 @@ from astrbot.core.star.star_handler import star_map
|
||||
|
||||
|
||||
class AstrBotCoreLifecycle:
|
||||
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
|
||||
self.log_broker = log_broker
|
||||
self.astrbot_config = astrbot_config
|
||||
self.db = db
|
||||
"""
|
||||
AstrBot 核心生命周期管理类, 负责管理 AstrBot 的启动、停止、重启等操作。
|
||||
该类负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、
|
||||
EventBus 等。
|
||||
该类还负责加载和执行插件, 以及处理事件总线的分发。
|
||||
"""
|
||||
|
||||
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
|
||||
self.log_broker = log_broker # 初始化日志代理
|
||||
self.astrbot_config = astrbot_config # 初始化配置
|
||||
self.db = db # 初始化数据库
|
||||
|
||||
# 根据环境变量设置代理
|
||||
os.environ["https_proxy"] = self.astrbot_config["http_proxy"]
|
||||
os.environ["http_proxy"] = self.astrbot_config["http_proxy"]
|
||||
os.environ["no_proxy"] = "localhost"
|
||||
|
||||
async def initialize(self):
|
||||
"""
|
||||
初始化 AstrBot 核心生命周期管理类, 负责初始化各个组件, 包括 ProviderManager、PlatformManager、KnowledgeDBManager、ConversationManager、PluginManager、PipelineScheduler、EventBus、AstrBotUpdator等。
|
||||
"""
|
||||
|
||||
# 初始化日志代理
|
||||
logger.info("AstrBot v" + VERSION)
|
||||
if os.environ.get("TESTING", ""):
|
||||
logger.setLevel("DEBUG")
|
||||
logger.setLevel("DEBUG") # 测试模式下设置日志级别为 DEBUG
|
||||
else:
|
||||
logger.setLevel(self.astrbot_config["log_level"])
|
||||
logger.setLevel(self.astrbot_config["log_level"]) # 设置日志级别
|
||||
|
||||
# 初始化事件队列
|
||||
self.event_queue = Queue()
|
||||
|
||||
# 初始化供应商管理器
|
||||
self.provider_manager = ProviderManager(self.astrbot_config, self.db)
|
||||
|
||||
# 初始化平台管理器
|
||||
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
|
||||
|
||||
# 初始化知识库管理器
|
||||
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
|
||||
|
||||
# 初始化对话管理器
|
||||
self.conversation_manager = ConversationManager(self.db)
|
||||
|
||||
# 初始化提供给插件的上下文
|
||||
self.star_context = Context(
|
||||
self.event_queue,
|
||||
self.astrbot_config,
|
||||
@@ -58,35 +89,50 @@ class AstrBotCoreLifecycle:
|
||||
self.conversation_manager,
|
||||
self.knowledge_db_manager,
|
||||
)
|
||||
|
||||
# 初始化插件管理器
|
||||
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
|
||||
|
||||
# 扫描、注册插件、实例化插件类
|
||||
await self.plugin_manager.reload()
|
||||
"""扫描、注册插件、实例化插件类"""
|
||||
|
||||
# 根据配置实例化各个 Provider
|
||||
await self.provider_manager.initialize()
|
||||
"""根据配置实例化各个 Provider"""
|
||||
|
||||
# 初始化消息事件流水线调度器
|
||||
self.pipeline_scheduler = PipelineScheduler(
|
||||
PipelineContext(self.astrbot_config, self.plugin_manager)
|
||||
)
|
||||
await self.pipeline_scheduler.initialize()
|
||||
"""初始化消息事件流水线调度器"""
|
||||
|
||||
# 初始化更新器
|
||||
self.astrbot_updator = AstrBotUpdator(self.astrbot_config["plugin_repo_mirror"])
|
||||
|
||||
# 初始化事件总线
|
||||
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
|
||||
|
||||
# 记录启动时间
|
||||
self.start_time = int(time.time())
|
||||
|
||||
# 初始化当前任务列表
|
||||
self.curr_tasks: List[asyncio.Task] = []
|
||||
|
||||
# 根据配置实例化各个平台适配器
|
||||
await self.platform_manager.initialize()
|
||||
"""根据配置实例化各个平台适配器"""
|
||||
|
||||
# 初始化关闭控制面板的事件
|
||||
self.dashboard_shutdown_event = asyncio.Event()
|
||||
|
||||
def _load(self):
|
||||
"""加载事件总线和任务并初始化"""
|
||||
|
||||
# 创建一个异步任务来执行事件总线的 dispatch() 方法
|
||||
# dispatch是一个无限循环的协程, 从事件队列中获取事件并处理
|
||||
event_bus_task = asyncio.create_task(
|
||||
self.event_bus.dispatch(), name="event_bus"
|
||||
)
|
||||
|
||||
# 把插件中注册的所有协程函数注册到事件总线中并执行
|
||||
extra_tasks = []
|
||||
for task in self.star_context._register_tasks:
|
||||
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
|
||||
@@ -100,17 +146,24 @@ class AstrBotCoreLifecycle:
|
||||
self.start_time = int(time.time())
|
||||
|
||||
async def _task_wrapper(self, task: asyncio.Task):
|
||||
"""异步任务包装器, 用于处理异步任务执行中出现的各种异常
|
||||
|
||||
Args:
|
||||
task (asyncio.Task): 要执行的异步任务
|
||||
"""
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
pass # 任务被取消, 静默处理
|
||||
except Exception as e:
|
||||
# 获取完整的异常堆栈信息, 按行分割并记录到日志中
|
||||
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
|
||||
for line in traceback.format_exc().split("\n"):
|
||||
logger.error(f"| {line}")
|
||||
logger.error("-------")
|
||||
|
||||
async def start(self):
|
||||
"""启动 AstrBot 核心生命周期管理类, 用load加载事件总线和任务并初始化, 执行启动完成事件钩子"""
|
||||
self._load()
|
||||
logger.info("AstrBot 启动完成。")
|
||||
|
||||
@@ -127,16 +180,29 @@ class AstrBotCoreLifecycle:
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# 同时运行curr_tasks中的所有任务
|
||||
await asyncio.gather(*self.curr_tasks, return_exceptions=True)
|
||||
|
||||
async def stop(self):
|
||||
"""停止 AstrBot 核心生命周期管理类, 取消所有当前任务并终止各个管理器"""
|
||||
# 请求停止所有正在运行的异步任务
|
||||
for task in self.curr_tasks:
|
||||
task.cancel()
|
||||
|
||||
for plugin in self.plugin_manager.context.get_all_stars():
|
||||
try:
|
||||
await self.plugin_manager._terminate_plugin(plugin)
|
||||
except Exception as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
logger.warning(
|
||||
f"插件 {plugin.name} 未被正常终止 {e!s}, 可能会导致资源泄露等问题。"
|
||||
)
|
||||
|
||||
await self.provider_manager.terminate()
|
||||
await self.platform_manager.terminate()
|
||||
self.dashboard_shutdown_event.set()
|
||||
|
||||
# 再次遍历curr_tasks等待每个任务真正结束
|
||||
for task in self.curr_tasks:
|
||||
try:
|
||||
await task
|
||||
@@ -146,6 +212,7 @@ class AstrBotCoreLifecycle:
|
||||
logger.error(f"任务 {task.get_name()} 发生错误: {e}")
|
||||
|
||||
async def restart(self):
|
||||
"""重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例"""
|
||||
await self.provider_manager.terminate()
|
||||
await self.platform_manager.terminate()
|
||||
self.dashboard_shutdown_event.set()
|
||||
@@ -154,6 +221,7 @@ class AstrBotCoreLifecycle:
|
||||
).start()
|
||||
|
||||
def load_platform(self) -> List[asyncio.Task]:
|
||||
"""加载平台实例并返回所有平台实例的异步任务列表"""
|
||||
tasks = []
|
||||
platform_insts = self.platform_manager.get_insts()
|
||||
for platform_inst in platform_insts:
|
||||
|
||||
@@ -0,0 +1,112 @@
|
||||
import json
|
||||
import aiosqlite
|
||||
import os
|
||||
from typing import Any
|
||||
from .plugin_storage import PluginStorage
|
||||
|
||||
DBPATH = "data/plugin_data/sqlite/plugin_data.db"
|
||||
|
||||
|
||||
class SQLitePluginStorage(PluginStorage):
|
||||
"""插件数据的 SQLite 存储实现类。
|
||||
|
||||
该类提供异步方式将插件数据存储到 SQLite 数据库中,支持数据的增删改查操作。
|
||||
所有数据以 (plugin, key) 作为复合主键进行索引。
|
||||
"""
|
||||
|
||||
_instance = None # Standalone instance of the class
|
||||
_db_conn = None
|
||||
db_path = None
|
||||
|
||||
def __new__(cls):
|
||||
"""
|
||||
创建或获取 SQLitePluginStorage 的单例实例。
|
||||
如果实例已存在,则返回现有实例;否则创建一个新实例。
|
||||
数据在 `data/plugin_data/sqlite/plugin_data.db` 下。
|
||||
"""
|
||||
os.makedirs(os.path.dirname(DBPATH), exist_ok=True)
|
||||
if cls._instance is None:
|
||||
cls._instance = super(SQLitePluginStorage, cls).__new__(cls)
|
||||
cls._instance.db_path = DBPATH
|
||||
return cls._instance
|
||||
|
||||
async def _init_db(self):
|
||||
"""初始化数据库连接(只执行一次)"""
|
||||
if SQLitePluginStorage._db_conn is None:
|
||||
SQLitePluginStorage._db_conn = await aiosqlite.connect(self.db_path)
|
||||
await self._setup_db()
|
||||
|
||||
async def _setup_db(self):
|
||||
"""
|
||||
异步初始化数据库。
|
||||
|
||||
创建插件数据表,如果表不存在则创建,表结构包含 plugin、key 和 value 字段,
|
||||
其中 plugin 和 key 组合作为主键。
|
||||
"""
|
||||
await self._db_conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS plugin_data (
|
||||
plugin TEXT,
|
||||
key TEXT,
|
||||
value TEXT,
|
||||
PRIMARY KEY (plugin, key)
|
||||
)
|
||||
""")
|
||||
await self._db_conn.commit()
|
||||
|
||||
async def set(self, plugin: str, key: str, value: Any):
|
||||
"""
|
||||
异步存储数据。
|
||||
|
||||
将指定插件的键值对存入数据库,如果键已存在则更新值。
|
||||
值会被序列化为 JSON 字符串后存储。
|
||||
|
||||
Args:
|
||||
plugin: 插件标识符
|
||||
key: 数据键名
|
||||
value: 要存储的数据值(任意类型,将被 JSON 序列化)
|
||||
"""
|
||||
await self._init_db()
|
||||
await self._db_conn.execute(
|
||||
"INSERT INTO plugin_data (plugin, key, value) VALUES (?, ?, ?) "
|
||||
"ON CONFLICT(plugin, key) DO UPDATE SET value = excluded.value",
|
||||
(plugin, key, json.dumps(value)),
|
||||
)
|
||||
await self._db_conn.commit()
|
||||
|
||||
async def get(self, plugin: str, key: str) -> Any:
|
||||
"""
|
||||
异步获取数据。
|
||||
|
||||
从数据库中获取指定插件和键名对应的值,
|
||||
返回的值会从 JSON 字符串反序列化为原始数据类型。
|
||||
|
||||
Args:
|
||||
plugin: 插件标识符
|
||||
key: 数据键名
|
||||
|
||||
Returns:
|
||||
Any: 存储的数据值,如果未找到则返回 None
|
||||
"""
|
||||
await self._init_db()
|
||||
async with self._db_conn.execute(
|
||||
"SELECT value FROM plugin_data WHERE plugin = ? AND key = ?",
|
||||
(plugin, key),
|
||||
) as cursor:
|
||||
row = await cursor.fetchone()
|
||||
return json.loads(row[0]) if row else None
|
||||
|
||||
async def delete(self, plugin: str, key: str):
|
||||
"""
|
||||
异步删除数据。
|
||||
|
||||
从数据库中删除指定插件和键名对应的数据项。
|
||||
|
||||
Args:
|
||||
plugin: 插件标识符
|
||||
key: 要删除的数据键名
|
||||
"""
|
||||
await self._init_db()
|
||||
await self._db_conn.execute(
|
||||
"DELETE FROM plugin_data WHERE plugin = ? AND key = ?", (plugin, key)
|
||||
)
|
||||
await self._db_conn.commit()
|
||||
@@ -6,6 +6,8 @@ from typing import List
|
||||
|
||||
@dataclass
|
||||
class Platform:
|
||||
"""平台使用统计数据"""
|
||||
|
||||
name: str
|
||||
count: int
|
||||
timestamp: int
|
||||
@@ -13,6 +15,8 @@ class Platform:
|
||||
|
||||
@dataclass
|
||||
class Provider:
|
||||
"""供应商使用统计数据"""
|
||||
|
||||
name: str
|
||||
count: int
|
||||
timestamp: int
|
||||
@@ -20,6 +24,8 @@ class Provider:
|
||||
|
||||
@dataclass
|
||||
class Plugin:
|
||||
"""插件使用统计数据"""
|
||||
|
||||
name: str
|
||||
count: int
|
||||
timestamp: int
|
||||
@@ -27,6 +33,8 @@ class Plugin:
|
||||
|
||||
@dataclass
|
||||
class Command:
|
||||
"""命令使用统计数据"""
|
||||
|
||||
name: str
|
||||
count: int
|
||||
timestamp: int
|
||||
|
||||
@@ -1,3 +1,16 @@
|
||||
"""
|
||||
事件总线, 用于处理事件的分发和处理
|
||||
事件总线是一个异步队列, 用于接收各种消息事件, 并将其发送到Scheduler调度器进行处理
|
||||
其中包含了一个无限循环的调度函数, 用于从事件队列中获取新的事件, 并创建一个新的异步任务来执行管道调度器的处理逻辑
|
||||
|
||||
class:
|
||||
EventBus: 事件总线, 用于处理事件的分发和处理
|
||||
|
||||
工作流程:
|
||||
1. 维护一个异步队列, 来接受各种消息事件
|
||||
2. 无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from asyncio import Queue
|
||||
from astrbot.core.pipeline.scheduler import PipelineScheduler
|
||||
@@ -6,21 +19,38 @@ from .platform import AstrMessageEvent
|
||||
|
||||
|
||||
class EventBus:
|
||||
"""事件总线: 用于处理事件的分发和处理
|
||||
|
||||
维护一个异步队列, 来接受各种消息事件
|
||||
"""
|
||||
|
||||
def __init__(self, event_queue: Queue, pipeline_scheduler: PipelineScheduler):
|
||||
self.event_queue = event_queue
|
||||
self.pipeline_scheduler = pipeline_scheduler
|
||||
self.event_queue = event_queue # 事件队列
|
||||
self.pipeline_scheduler = pipeline_scheduler # 管道调度器
|
||||
|
||||
async def dispatch(self):
|
||||
"""无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑"""
|
||||
while True:
|
||||
event: AstrMessageEvent = await self.event_queue.get()
|
||||
self._print_event(event)
|
||||
asyncio.create_task(self.pipeline_scheduler.execute(event))
|
||||
event: AstrMessageEvent = (
|
||||
await self.event_queue.get()
|
||||
) # 从事件队列中获取新的事件
|
||||
self._print_event(event) # 打印日志
|
||||
asyncio.create_task(
|
||||
self.pipeline_scheduler.execute(event)
|
||||
) # 创建新的异步任务来执行管道调度器的处理逻辑
|
||||
|
||||
def _print_event(self, event: AstrMessageEvent):
|
||||
"""用于记录事件信息
|
||||
|
||||
Args:
|
||||
event (AstrMessageEvent): 事件对象
|
||||
"""
|
||||
# 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要
|
||||
if event.get_sender_name():
|
||||
logger.info(
|
||||
f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
|
||||
)
|
||||
# 没有发送者名称: [平台名] 发送者ID: 消息概要
|
||||
else:
|
||||
logger.info(
|
||||
f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}"
|
||||
|
||||
@@ -1,3 +1,11 @@
|
||||
"""
|
||||
AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。
|
||||
|
||||
工作流程:
|
||||
1. 初始化核心生命周期, 传递数据库和日志代理实例到核心生命周期
|
||||
2. 运行核心生命周期任务和仪表板服务器
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import traceback
|
||||
from astrbot.core import logger
|
||||
@@ -8,6 +16,8 @@ from astrbot.dashboard.server import AstrBotDashboard
|
||||
|
||||
|
||||
class InitialLoader:
|
||||
"""AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。"""
|
||||
|
||||
def __init__(self, db: BaseDatabase, log_broker: LogBroker):
|
||||
self.db = db
|
||||
self.logger = logger
|
||||
@@ -27,10 +37,12 @@ class InitialLoader:
|
||||
self.dashboard_server = AstrBotDashboard(
|
||||
core_lifecycle, self.db, core_lifecycle.dashboard_shutdown_event
|
||||
)
|
||||
task = asyncio.gather(core_task, self.dashboard_server.run())
|
||||
task = asyncio.gather(
|
||||
core_task, self.dashboard_server.run()
|
||||
) # 启动核心任务和仪表板服务器
|
||||
|
||||
try:
|
||||
await task
|
||||
await task # 整个AstrBot在这里运行
|
||||
except asyncio.CancelledError:
|
||||
logger.info("🌈 正在关闭 AstrBot...")
|
||||
await core_lifecycle.stop()
|
||||
|
||||
+117
-18
@@ -1,3 +1,26 @@
|
||||
"""
|
||||
日志系统, 用于支持核心组件和插件的日志记录, 提供了日志订阅功能
|
||||
|
||||
const:
|
||||
CACHED_SIZE: 日志缓存大小, 用于限制缓存的日志数量
|
||||
log_color_config: 日志颜色配置, 定义了不同日志级别的颜色
|
||||
|
||||
class:
|
||||
LogBroker: 日志代理类, 用于缓存和分发日志消息
|
||||
LogQueueHandler: 日志处理器, 用于将日志消息发送到 LogBroker
|
||||
LogManager: 日志管理器, 用于创建和配置日志记录器
|
||||
|
||||
function:
|
||||
is_plugin_path: 检查文件路径是否来自插件目录
|
||||
get_short_level_name: 将日志级别名称转换为四个字母的缩写
|
||||
|
||||
工作流程:
|
||||
1. 通过 LogManager.GetLogger() 获取日志器, 配置了控制台输出和多个格式化过滤器
|
||||
2. 通过 set_queue_handler() 设置日志处理器, 将日志消息发送到 LogBroker
|
||||
3. logBroker 维护一个订阅者列表, 负责将日志分发给所有订阅者
|
||||
4. 订阅者可以使用 register() 方法注册到 LogBroker, 订阅日志流
|
||||
"""
|
||||
|
||||
import logging
|
||||
import colorlog
|
||||
import asyncio
|
||||
@@ -6,7 +29,9 @@ from collections import deque
|
||||
from asyncio import Queue
|
||||
from typing import List
|
||||
|
||||
# 日志缓存大小
|
||||
CACHED_SIZE = 200
|
||||
# 日志颜色配置
|
||||
log_color_config = {
|
||||
"DEBUG": "green",
|
||||
"INFO": "bold_cyan",
|
||||
@@ -19,8 +44,13 @@ log_color_config = {
|
||||
|
||||
|
||||
def is_plugin_path(pathname):
|
||||
"""
|
||||
检查文件路径是否来自插件目录
|
||||
"""检查文件路径是否来自插件目录
|
||||
|
||||
Args:
|
||||
pathname (str): 文件路径
|
||||
|
||||
Returns:
|
||||
bool: 如果路径来自插件目录,则返回 True,否则返回 False
|
||||
"""
|
||||
if not pathname:
|
||||
return False
|
||||
@@ -30,8 +60,13 @@ def is_plugin_path(pathname):
|
||||
|
||||
|
||||
def get_short_level_name(level_name):
|
||||
"""
|
||||
将日志级别名称转换为四个字母的缩写
|
||||
"""将日志级别名称转换为四个字母的缩写
|
||||
|
||||
Args:
|
||||
level_name (str): 日志级别名称, 如 "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"
|
||||
|
||||
Returns:
|
||||
str: 四个字母的日志级别缩写
|
||||
"""
|
||||
level_map = {
|
||||
"DEBUG": "DBUG",
|
||||
@@ -44,12 +79,21 @@ def get_short_level_name(level_name):
|
||||
|
||||
|
||||
class LogBroker:
|
||||
"""日志代理类, 用于缓存和分发日志消息
|
||||
|
||||
发布-订阅模式
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.log_cache = deque(maxlen=CACHED_SIZE)
|
||||
self.subscribers: List[Queue] = []
|
||||
self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志
|
||||
self.subscribers: List[Queue] = [] # 订阅者列表
|
||||
|
||||
def register(self) -> Queue:
|
||||
"""给每个订阅者返回一个带有日志缓存的队列"""
|
||||
"""注册新的订阅者, 并给每个订阅者返回一个带有日志缓存的队列
|
||||
|
||||
Returns:
|
||||
Queue: 订阅者的队列, 可用于接收日志消息
|
||||
"""
|
||||
q = Queue(maxsize=CACHED_SIZE + 10)
|
||||
for log in self.log_cache:
|
||||
q.put_nowait(log)
|
||||
@@ -57,11 +101,20 @@ class LogBroker:
|
||||
return q
|
||||
|
||||
def unregister(self, q: Queue):
|
||||
"""取消订阅"""
|
||||
"""取消订阅
|
||||
|
||||
Args:
|
||||
q (Queue): 需要取消订阅的队列
|
||||
"""
|
||||
self.subscribers.remove(q)
|
||||
|
||||
def publish(self, log_entry: str):
|
||||
"""发布消息"""
|
||||
def publish(self, log_entry: dict):
|
||||
"""发布新日志到所有订阅者, 使用非阻塞方式投递, 避免一个订阅者阻塞整个系统
|
||||
|
||||
Args:
|
||||
log_entry (dict): 日志消息, 包含日志级别和日志内容.
|
||||
example: {"level": "INFO", "data": "This is a log message.", "time": "2023-10-01 12:00:00"}
|
||||
"""
|
||||
self.log_cache.append(log_entry)
|
||||
for q in self.subscribers:
|
||||
try:
|
||||
@@ -71,24 +124,57 @@ class LogBroker:
|
||||
|
||||
|
||||
class LogQueueHandler(logging.Handler):
|
||||
"""日志处理器, 用于将日志消息发送到 LogBroker
|
||||
|
||||
继承自 logging.Handler
|
||||
"""
|
||||
|
||||
def __init__(self, log_broker: LogBroker):
|
||||
super().__init__()
|
||||
self.log_broker = log_broker
|
||||
|
||||
def emit(self, record):
|
||||
"""日志处理的入口方法, 接受一个日志记录, 转换为字符串后由 LogBroker 发布
|
||||
这个方法会在每次日志记录时被调用
|
||||
|
||||
Args:
|
||||
record (logging.LogRecord): 日志记录对象, 包含日志信息
|
||||
"""
|
||||
log_entry = self.format(record)
|
||||
self.log_broker.publish(log_entry)
|
||||
self.log_broker.publish({
|
||||
"level": record.levelname,
|
||||
"time": record.asctime,
|
||||
"data": log_entry,
|
||||
})
|
||||
|
||||
|
||||
class LogManager:
|
||||
"""日志管理器, 用于创建和配置日志记录器
|
||||
|
||||
提供了获取默认日志记录器logger和设置队列处理器的方法
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def GetLogger(cls, log_name: str = "default"):
|
||||
"""获取指定名称的日志记录器logger
|
||||
|
||||
Args:
|
||||
log_name (str): 日志记录器的名称, 默认为 "default"
|
||||
|
||||
Returns:
|
||||
logging.Logger: 返回配置好的日志记录器
|
||||
"""
|
||||
logger = logging.getLogger(log_name)
|
||||
# 检查该logger或父级logger是否已经有处理器, 如果已经有处理器, 直接返回该logger, 避免重复配置
|
||||
if logger.hasHandlers():
|
||||
return logger
|
||||
console_handler = logging.StreamHandler()
|
||||
console_handler.setLevel(logging.DEBUG)
|
||||
# 如果logger没有处理器
|
||||
console_handler = logging.StreamHandler() # 创建一个StreamHandler用于控制台输出
|
||||
console_handler.setLevel(
|
||||
logging.DEBUG
|
||||
) # 将日志级别设置为DEBUG(最低级别, 显示所有日志), *如果插件没有设置级别, 默认为DEBUG
|
||||
|
||||
# 创建彩色日志格式化器, 输出日志格式为: [时间] [插件标签] [日志级别] [文件名:行号]: 日志消息
|
||||
console_formatter = colorlog.ColoredFormatter(
|
||||
fmt="%(log_color)s [%(asctime)s] %(plugin_tag)s [%(short_levelname)-4s] [%(filename)s:%(lineno)d]: %(message)s %(reset)s",
|
||||
datefmt="%H:%M:%S",
|
||||
@@ -96,6 +182,8 @@ class LogManager:
|
||||
)
|
||||
|
||||
class PluginFilter(logging.Filter):
|
||||
"""插件过滤器类, 用于标记日志来源是插件还是核心组件"""
|
||||
|
||||
def filter(self, record):
|
||||
record.plugin_tag = (
|
||||
"[Plug]" if is_plugin_path(record.pathname) else "[Core]"
|
||||
@@ -103,6 +191,9 @@ class LogManager:
|
||||
return True
|
||||
|
||||
class FileNameFilter(logging.Filter):
|
||||
"""文件名过滤器类, 用于修改日志记录的文件名格式
|
||||
例如: 将文件路径 /path/to/file.py 转换为 file.<file> 格式"""
|
||||
|
||||
# 获取这个文件和父文件夹的名字:<folder>.<file> 并且去除 .py
|
||||
def filter(self, record):
|
||||
dirname = os.path.dirname(record.pathname)
|
||||
@@ -114,22 +205,30 @@ class LogManager:
|
||||
return True
|
||||
|
||||
class LevelNameFilter(logging.Filter):
|
||||
"""短日志级别名称过滤器类, 用于将日志级别名称转换为四个字母的缩写"""
|
||||
|
||||
# 添加短日志级别名称
|
||||
def filter(self, record):
|
||||
record.short_levelname = get_short_level_name(record.levelname)
|
||||
return True
|
||||
|
||||
console_handler.setFormatter(console_formatter)
|
||||
logger.addFilter(PluginFilter())
|
||||
logger.addFilter(FileNameFilter())
|
||||
console_handler.setFormatter(console_formatter) # 设置处理器的格式化器
|
||||
logger.addFilter(PluginFilter()) # 添加插件过滤器
|
||||
logger.addFilter(FileNameFilter()) # 添加文件名过滤器
|
||||
logger.addFilter(LevelNameFilter()) # 添加级别名称过滤器
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.addHandler(console_handler)
|
||||
logger.setLevel(logging.DEBUG) # 设置日志级别为DEBUG
|
||||
logger.addHandler(console_handler) # 添加处理器到logger
|
||||
|
||||
return logger
|
||||
|
||||
@classmethod
|
||||
def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker):
|
||||
"""设置队列处理器, 用于将日志消息发送到 LogBroker
|
||||
|
||||
Args:
|
||||
logger (logging.Logger): 日志记录器
|
||||
log_broker (LogBroker): 日志代理类, 用于缓存和分发日志消息
|
||||
"""
|
||||
handler = LogQueueHandler(log_broker)
|
||||
handler.setLevel(logging.DEBUG)
|
||||
if logger.handlers:
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
import enum
|
||||
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
from dataclasses import dataclass, field
|
||||
from astrbot.core.message.components import BaseMessageComponent, Plain, Image
|
||||
from astrbot.core.message.components import (
|
||||
BaseMessageComponent,
|
||||
Plain,
|
||||
Image,
|
||||
At,
|
||||
AtAll,
|
||||
)
|
||||
from typing_extensions import deprecated
|
||||
|
||||
|
||||
@@ -31,6 +37,30 @@ class MessageChain:
|
||||
self.chain.append(Plain(message))
|
||||
return self
|
||||
|
||||
def at(self, name: str, qq: Union[str, int]):
|
||||
"""添加一条 At 消息到消息链 `chain` 中。
|
||||
|
||||
Example:
|
||||
|
||||
CommandResult().at("张三", "12345678910")
|
||||
# 输出 @张三
|
||||
|
||||
"""
|
||||
self.chain.append(At(name=name, qq=qq))
|
||||
return self
|
||||
|
||||
def at_all(self):
|
||||
"""添加一条 AtAll 消息到消息链 `chain` 中。
|
||||
|
||||
Example:
|
||||
|
||||
CommandResult().at_all()
|
||||
# 输出 @所有人
|
||||
|
||||
"""
|
||||
self.chain.append(AtAll())
|
||||
return self
|
||||
|
||||
@deprecated("请使用 message 方法代替。")
|
||||
def error(self, message: str):
|
||||
"""添加一条错误消息到消息链 `chain` 中
|
||||
@@ -152,4 +182,5 @@ class MessageEventResult(MessageChain):
|
||||
return self.result_content_type == ResultContentType.LLM_RESULT
|
||||
|
||||
|
||||
# 为了兼容旧版代码,保留 CommandResult 的别名
|
||||
CommandResult = MessageEventResult
|
||||
|
||||
@@ -12,6 +12,7 @@ from .process_stage.stage import ProcessStage
|
||||
from .result_decorate.stage import ResultDecorateStage
|
||||
from .respond.stage import RespondStage
|
||||
|
||||
# 管道阶段顺序
|
||||
STAGES_ORDER = [
|
||||
"WakingCheckStage", # 检查是否需要唤醒
|
||||
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
|
||||
|
||||
@@ -5,5 +5,7 @@ from astrbot.core.star import PluginManager
|
||||
|
||||
@dataclass
|
||||
class PipelineContext:
|
||||
astrbot_config: AstrBotConfig
|
||||
plugin_manager: PluginManager
|
||||
"""上下文对象,包含管道执行所需的上下文信息"""
|
||||
|
||||
astrbot_config: AstrBotConfig # AstrBot 配置对象
|
||||
plugin_manager: PluginManager # 插件管理器对象
|
||||
|
||||
@@ -80,7 +80,6 @@ class LLMRequestSubStage(Stage):
|
||||
conversation_id = await self.conv_manager.get_curr_conversation_id(
|
||||
event.unified_msg_origin
|
||||
)
|
||||
req.session_id = event.unified_msg_origin
|
||||
if not conversation_id:
|
||||
conversation_id = await self.conv_manager.new_conversation(
|
||||
event.unified_msg_origin
|
||||
@@ -134,6 +133,10 @@ class LLMRequestSubStage(Stage):
|
||||
logger.debug("上下文长度超过限制,将截断。")
|
||||
req.contexts = req.contexts[-self.max_context_length * 2 :]
|
||||
|
||||
# session_id
|
||||
if not req.session_id:
|
||||
req.session_id = event.unified_msg_origin
|
||||
|
||||
try:
|
||||
need_loop = True
|
||||
while need_loop:
|
||||
|
||||
@@ -2,6 +2,7 @@ import random
|
||||
import asyncio
|
||||
import math
|
||||
import traceback
|
||||
import astrbot.core.message.components as Comp
|
||||
from typing import Union, AsyncGenerator
|
||||
from ..stage import register_stage, Stage
|
||||
from ..context import PipelineContext
|
||||
@@ -11,11 +12,42 @@ from astrbot.core import logger
|
||||
from astrbot.core.message.message_event_result import BaseMessageComponent
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.message.components import Plain, Reply, At
|
||||
|
||||
|
||||
@register_stage
|
||||
class RespondStage(Stage):
|
||||
# 组件类型到其非空判断函数的映射
|
||||
_component_validators = {
|
||||
Comp.Plain: lambda comp: bool(comp.text and comp.text.strip()), # 纯文本消息需要strip
|
||||
Comp.Face: lambda comp: comp.id is not None, # QQ表情
|
||||
Comp.Record: lambda comp: bool(comp.file), # 语音
|
||||
Comp.Video: lambda comp: bool(comp.file), # 视频
|
||||
Comp.At: lambda comp: bool(comp.qq) or bool(comp.name), # @
|
||||
Comp.AtAll: lambda comp: True, # @所有人
|
||||
Comp.RPS: lambda comp: True, # 不知道是啥(未完成)
|
||||
Comp.Dice: lambda comp: True, # 骰子(未完成)
|
||||
Comp.Shake: lambda comp: True, # 摇一摇(未完成)
|
||||
Comp.Anonymous: lambda comp: True, # 匿名(未完成)
|
||||
Comp.Share: lambda comp: bool(comp.url) and bool(comp.title), # 分享
|
||||
Comp.Contact: lambda comp: True, # 联系人(未完成)
|
||||
Comp.Location: lambda comp: bool(comp.lat and comp.lon), # 位置
|
||||
Comp.Music: lambda comp: bool(comp._type) and bool(comp.url) and bool(comp.audio), # 音乐
|
||||
Comp.Image: lambda comp: bool(comp.file), # 图片
|
||||
Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复
|
||||
Comp.RedBag: lambda comp: bool(comp.title), # 红包
|
||||
Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳
|
||||
Comp.Forward: lambda comp: bool(comp.id and comp.id.strip()), # 转发
|
||||
Comp.Node: lambda comp: bool(comp.name) and comp.uin != 0 and bool(comp.content), # 一个转发节点
|
||||
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
|
||||
Comp.Xml: lambda comp: bool(comp.data and comp.data.strip()), # XML
|
||||
Comp.Json: lambda comp: bool(comp.data), # JSON
|
||||
Comp.CardImage: lambda comp: bool(comp.file), # 卡片图片
|
||||
Comp.TTS: lambda comp: bool(comp.text and comp.text.strip()), # 语音合成
|
||||
Comp.Unknown: lambda comp: bool(comp.text and comp.text.strip()), # 未知消息
|
||||
Comp.File: lambda comp: bool(comp.file), # 文件
|
||||
Comp.WechatEmoji: lambda comp: bool(comp.md5), # 微信表情
|
||||
}
|
||||
|
||||
async def initialize(self, ctx: PipelineContext):
|
||||
self.ctx = ctx
|
||||
|
||||
@@ -62,7 +94,7 @@ class RespondStage(Stage):
|
||||
async def _calc_comp_interval(self, comp: BaseMessageComponent) -> float:
|
||||
"""分段回复 计算间隔时间"""
|
||||
if self.interval_method == "log":
|
||||
if isinstance(comp, Plain):
|
||||
if isinstance(comp, Comp.Plain):
|
||||
wc = await self._word_cnt(comp.text)
|
||||
i = math.log(wc + 1, self.log_base)
|
||||
return random.uniform(i, i + 0.5)
|
||||
@@ -72,6 +104,28 @@ class RespondStage(Stage):
|
||||
# random
|
||||
return random.uniform(self.interval[0], self.interval[1])
|
||||
|
||||
async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]):
|
||||
"""检查消息链是否为空
|
||||
|
||||
Args:
|
||||
chain (list[BaseMessageComponent]): 包含消息对象的列表
|
||||
"""
|
||||
if not chain:
|
||||
return True
|
||||
|
||||
for comp in chain:
|
||||
comp_type = type(comp)
|
||||
|
||||
# 检查组件类型是否在字典中
|
||||
if comp_type in self._component_validators:
|
||||
if self._component_validators[comp_type](comp):
|
||||
return False
|
||||
else:
|
||||
logger.info(f"空内容检查: 无法识别的组件类型: {comp_type.__name__}")
|
||||
|
||||
# 如果所有组件都为空
|
||||
return True
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
@@ -82,6 +136,16 @@ class RespondStage(Stage):
|
||||
if len(result.chain) > 0:
|
||||
await event._pre_send()
|
||||
|
||||
# 检查消息链是否为空
|
||||
try:
|
||||
if await self._is_empty_message_chain(result.chain):
|
||||
logger.info("消息为空,跳过发送阶段")
|
||||
event.clear_result()
|
||||
event.stop_event()
|
||||
return
|
||||
except Exception as e:
|
||||
logger.warning(f"空内容检查异常: {e}")
|
||||
|
||||
if self.enable_seg and (
|
||||
(self.only_llm_result and result.is_llm_result())
|
||||
or not self.only_llm_result
|
||||
@@ -89,13 +153,13 @@ class RespondStage(Stage):
|
||||
decorated_comps = []
|
||||
if self.reply_with_mention:
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, At):
|
||||
if isinstance(comp, Comp.At):
|
||||
decorated_comps.append(comp)
|
||||
result.chain.remove(comp)
|
||||
break
|
||||
if self.reply_with_quote:
|
||||
for comp in result.chain:
|
||||
if isinstance(comp, Reply):
|
||||
if isinstance(comp, Comp.Reply):
|
||||
decorated_comps.append(comp)
|
||||
result.chain.remove(comp)
|
||||
break
|
||||
|
||||
@@ -7,49 +7,72 @@ from astrbot.core import logger
|
||||
|
||||
|
||||
class PipelineScheduler:
|
||||
"""管道调度器,负责调度各个阶段的执行"""
|
||||
|
||||
def __init__(self, context: PipelineContext):
|
||||
registered_stages.sort(key=lambda x: STAGES_ORDER.index(x.__class__.__name__))
|
||||
self.ctx = context
|
||||
registered_stages.sort(
|
||||
key=lambda x: STAGES_ORDER.index(x.__class__.__name__)
|
||||
) # 按照顺序排序
|
||||
self.ctx = context # 上下文对象
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化管道调度器时, 初始化所有阶段"""
|
||||
for stage in registered_stages:
|
||||
# logger.debug(f"初始化阶段 {stage.__class__ .__name__}")
|
||||
|
||||
await stage.initialize(self.ctx)
|
||||
|
||||
async def _process_stages(self, event: AstrMessageEvent, from_stage=0):
|
||||
"""依次执行各个阶段
|
||||
|
||||
Args:
|
||||
event (AstrMessageEvent): 事件对象
|
||||
from_stage (int): 从第几个阶段开始执行, 默认从0开始
|
||||
"""
|
||||
for i in range(from_stage, len(registered_stages)):
|
||||
stage = registered_stages[i]
|
||||
stage = registered_stages[i] # 获取当前要执行的阶段
|
||||
# logger.debug(f"执行阶段 {stage.__class__ .__name__}")
|
||||
coro = stage.process(event)
|
||||
if isinstance(coro, AsyncGenerator):
|
||||
async for _ in coro:
|
||||
coroutine = stage.process(
|
||||
event
|
||||
) # 调用阶段的process方法, 返回协程或者异步生成器
|
||||
|
||||
if isinstance(coroutine, AsyncGenerator):
|
||||
# 如果返回的是异步生成器, 实现洋葱模型的核心
|
||||
async for _ in coroutine:
|
||||
# 此处是前置处理完成后的暂停点(yield), 下面开始执行后续阶段
|
||||
if event.is_stopped():
|
||||
logger.debug(
|
||||
f"阶段 {stage.__class__.__name__} 已终止事件传播。"
|
||||
)
|
||||
break
|
||||
|
||||
# 递归调用, 处理所有后续阶段
|
||||
await self._process_stages(event, i + 1)
|
||||
|
||||
# 此处是后续所有阶段处理完毕后返回的点, 执行后置处理
|
||||
if event.is_stopped():
|
||||
logger.debug(
|
||||
f"阶段 {stage.__class__.__name__} 已终止事件传播。"
|
||||
)
|
||||
break
|
||||
else:
|
||||
await coro
|
||||
# 如果返回的是普通协程(不含yield的async函数), 则不进入下一层(基线条件)
|
||||
# 简单地等待它执行完成, 然后继续执行下一个阶段
|
||||
await coroutine
|
||||
|
||||
if event.is_stopped():
|
||||
logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。")
|
||||
break
|
||||
|
||||
if event.is_stopped():
|
||||
logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。")
|
||||
break
|
||||
|
||||
async def execute(self, event: AstrMessageEvent):
|
||||
"""执行 pipeline"""
|
||||
"""执行 pipeline
|
||||
|
||||
Args:
|
||||
event (AstrMessageEvent): 事件对象
|
||||
"""
|
||||
await self._process_stages(event)
|
||||
|
||||
# 如果没有发送操作, 则发送一个空消息, 以便于后续的处理
|
||||
if not event._has_send_oper and event.get_platform_name() == "webchat":
|
||||
await event.send(None)
|
||||
|
||||
|
||||
@@ -8,8 +8,7 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from .context import PipelineContext
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
|
||||
|
||||
registered_stages: List[Stage] = []
|
||||
"""维护了所有已注册的 Stage 实现类"""
|
||||
registered_stages: List[Stage] = [] # 维护了所有已注册的 Stage 实现类
|
||||
|
||||
|
||||
def register_stage(cls):
|
||||
@@ -23,14 +22,24 @@ class Stage(abc.ABC):
|
||||
|
||||
@abc.abstractmethod
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
"""初始化阶段"""
|
||||
"""初始化阶段
|
||||
|
||||
Args:
|
||||
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
"""处理事件"""
|
||||
"""处理事件
|
||||
|
||||
Args:
|
||||
event (AstrMessageEvent): 事件对象,包含事件的相关信息
|
||||
Returns:
|
||||
Union[None, AsyncGenerator[None, None]]: 处理结果,可能是 None 或者异步生成器, 如果为 None 则表示不需要继续处理, 如果为异步生成器则表示需要继续处理(进入下一个阶段)
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def _call_handler(
|
||||
@@ -41,9 +50,23 @@ class Stage(abc.ABC):
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[None, None]:
|
||||
"""调用 Handler。"""
|
||||
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
|
||||
ready_to_call = None
|
||||
"""执行事件处理函数并处理其返回结果
|
||||
|
||||
该方法负责调用处理函数并处理不同类型的返回值。它支持两种类型的处理函数:
|
||||
1. 异步生成器: 实现洋葱模型,每次yield都会将控制权交回上层
|
||||
2. 协程: 执行一次并处理返回值
|
||||
|
||||
Args:
|
||||
ctx (PipelineContext): 消息管道上下文对象
|
||||
event (AstrMessageEvent): 待处理的事件对象
|
||||
handler (Awaitable): 事件处理函数
|
||||
*args: 传递给handler的位置参数
|
||||
**kwargs: 传递给handler的关键字参数
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[None, None]: 异步生成器,用于在管道中传递控制流
|
||||
"""
|
||||
ready_to_call = None # 一个协程或者异步生成器(async def)
|
||||
|
||||
trace_ = None
|
||||
|
||||
@@ -52,29 +75,36 @@ class Stage(abc.ABC):
|
||||
except TypeError as _:
|
||||
# 向下兼容
|
||||
trace_ = traceback.format_exc()
|
||||
# 以前的handler会额外传入一个参数, 但是context对象实际上在插件实例中有一份
|
||||
ready_to_call = handler(event, ctx.plugin_manager.context, *args, **kwargs)
|
||||
|
||||
if isinstance(ready_to_call, AsyncGenerator):
|
||||
_has_yielded = False
|
||||
# 如果是一个异步生成器, 进入洋葱模型
|
||||
_has_yielded = False # 是否返回过值
|
||||
try:
|
||||
async for ret in ready_to_call:
|
||||
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||
# 这里逐步执行异步生成器, 对于每个yield返回的ret, 执行下面的代码
|
||||
# 返回值只能是 MessageEventResult 或者 None(无返回值)
|
||||
_has_yielded = True
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
# 如果返回值是 MessageEventResult, 设置结果并继续
|
||||
event.set_result(ret)
|
||||
yield
|
||||
yield # 传递控制权给上一层的process函数
|
||||
else:
|
||||
yield ret
|
||||
# 如果返回值是 None, 则不设置结果并继续
|
||||
# 继续执行后续阶段
|
||||
yield ret # 传递控制权给上一层的process函数
|
||||
if not _has_yielded:
|
||||
# 如果这个异步生成器没有执行到yield分支
|
||||
yield
|
||||
except Exception as e:
|
||||
logger.error(f"Previous Error: {trace_}")
|
||||
raise e
|
||||
elif inspect.iscoroutine(ready_to_call):
|
||||
# 如果只是一个 coroutine
|
||||
# 如果只是一个协程, 直接执行
|
||||
ret = await ready_to_call
|
||||
if isinstance(ret, (MessageEventResult, CommandResult)):
|
||||
event.set_result(ret)
|
||||
yield
|
||||
yield # 传递控制权给上一层的process函数
|
||||
else:
|
||||
yield ret
|
||||
yield ret # 传递控制权给上一层的process函数
|
||||
|
||||
@@ -21,6 +21,11 @@ class WakingCheckStage(Stage):
|
||||
"""
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
"""初始化唤醒检查阶段
|
||||
|
||||
Args:
|
||||
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
|
||||
"""
|
||||
self.ctx = ctx
|
||||
self.no_permission_reply = self.ctx.astrbot_config["platform_settings"].get(
|
||||
"no_permission_reply", True
|
||||
|
||||
@@ -15,6 +15,9 @@ class WhitelistCheckStage(Stage):
|
||||
"enable_id_white_list"
|
||||
]
|
||||
self.whitelist = ctx.astrbot_config["platform_settings"]["id_whitelist"]
|
||||
self.whitelist = [
|
||||
str(i).strip() for i in self.whitelist if str(i).strip() != ""
|
||||
]
|
||||
self.wl_ignore_admin_on_group = ctx.astrbot_config["platform_settings"][
|
||||
"wl_ignore_admin_on_group"
|
||||
]
|
||||
@@ -53,7 +56,7 @@ class WhitelistCheckStage(Stage):
|
||||
return
|
||||
if (
|
||||
event.unified_msg_origin not in self.whitelist
|
||||
and event.get_group_id() not in self.whitelist
|
||||
and str(event.get_group_id()).strip() not in self.whitelist
|
||||
):
|
||||
if self.wl_log:
|
||||
logger.info(
|
||||
|
||||
@@ -22,6 +22,9 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
if isinstance(segment, Plain):
|
||||
d["type"] = "text"
|
||||
d["data"]["text"] = segment.text.strip()
|
||||
# 如果是空文本或者只带换行符的文本,不发送
|
||||
if not d["data"]["text"]:
|
||||
continue
|
||||
elif isinstance(segment, (Image, Record)):
|
||||
# convert to base64
|
||||
bs64 = await segment.convert_to_base64()
|
||||
@@ -38,6 +41,9 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
async def send(self, message: MessageChain):
|
||||
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
|
||||
|
||||
if not ret:
|
||||
return
|
||||
|
||||
send_one_by_one = False
|
||||
for seg in message.chain:
|
||||
if isinstance(seg, (Node, Nodes)):
|
||||
|
||||
@@ -24,7 +24,7 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
if isinstance(segment, Comp.Plain):
|
||||
segment.text = segment.text.strip()
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None, client.reply_text, segment.text, self.message_obj.raw_message
|
||||
None, client.reply_markdown, "AstrBot", segment.text, self.message_obj.raw_message
|
||||
)
|
||||
elif isinstance(segment, Comp.Image):
|
||||
markdown_str = ""
|
||||
|
||||
@@ -735,3 +735,20 @@ class SimpleGewechatClient:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取群信息结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
async def get_contacts_list(self):
|
||||
"""
|
||||
获取通讯录列表
|
||||
见 https://apifox.com/apidoc/shared/69ba62ca-cb7d-437e-85e4-6f3d3df271b1/api-196794504
|
||||
"""
|
||||
payload = {"appId": self.appid}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/contacts/fetchContactsList",
|
||||
headers=self.headers,
|
||||
json=payload,
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.debug(f"获取通讯录列表结果: {json_blob}")
|
||||
return json_blob
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import telegramify_markdown
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, MessageType
|
||||
from astrbot.api.message_components import Plain, Image, Reply, At, File, Record
|
||||
from telegram.ext import ExtBot
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot import logger
|
||||
|
||||
|
||||
class TelegramPlatformEvent(AstrMessageEvent):
|
||||
@@ -49,7 +51,17 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
if at_user_id and not at_flag:
|
||||
i.text = f"@{at_user_id} " + i.text
|
||||
at_flag = True
|
||||
await client.send_message(text=i.text, **payload)
|
||||
text = i.text
|
||||
try:
|
||||
text = telegramify_markdown.markdownify(
|
||||
i.text, max_line_length=None, normalize_whitespace=False
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"MarkdownV2 conversion failed: {e}. Using plain text instead."
|
||||
)
|
||||
return
|
||||
await client.send_message(text=text, parse_mode="MarkdownV2", **payload)
|
||||
elif isinstance(i, Image):
|
||||
image_path = await i.convert_to_file_path()
|
||||
await client.send_photo(photo=image_path, **payload)
|
||||
|
||||
@@ -3,7 +3,7 @@ import uuid
|
||||
import base64
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core import web_chat_back_queue
|
||||
|
||||
@@ -47,6 +47,22 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
with open(comp.file, "rb") as f2:
|
||||
f.write(f2.read())
|
||||
web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid))
|
||||
elif isinstance(comp, Record):
|
||||
# save record to local
|
||||
filename = str(uuid.uuid4()) + ".wav"
|
||||
path = os.path.join(imgs_dir, filename)
|
||||
if comp.file and comp.file.startswith("file:///"):
|
||||
ph = comp.file[8:]
|
||||
with open(path, "wb") as f:
|
||||
with open(ph, "rb") as f2:
|
||||
f.write(f2.read())
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
await download_image_by_url(comp.file, path=path)
|
||||
else:
|
||||
with open(path, "wb") as f:
|
||||
with open(comp.file, "rb") as f2:
|
||||
f.write(f2.read())
|
||||
web_chat_back_queue.put_nowait((f"[RECORD]{filename}", cid))
|
||||
else:
|
||||
logger.debug(f"webchat 忽略: {comp.type}")
|
||||
web_chat_back_queue.put_nowait(None)
|
||||
|
||||
@@ -198,6 +198,10 @@ class ProviderManager:
|
||||
from .sources.fishaudio_tts_api_source import (
|
||||
ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI,
|
||||
)
|
||||
case "dashscope_tts":
|
||||
from .sources.dashscope_tts import (
|
||||
ProviderDashscopeTTSAPI as ProviderDashscopeTTSAPI,
|
||||
)
|
||||
except (ImportError, ModuleNotFoundError) as e:
|
||||
logger.critical(
|
||||
f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。"
|
||||
@@ -306,10 +310,42 @@ class ProviderManager:
|
||||
|
||||
if len(self.provider_insts) == 0:
|
||||
self.curr_provider_inst = None
|
||||
elif (
|
||||
self.curr_provider_inst is None
|
||||
and len(self.provider_insts) > 0
|
||||
and self.provider_enabled
|
||||
):
|
||||
self.curr_provider_inst = self.provider_insts[0]
|
||||
self.selected_provider_id = self.curr_provider_inst.meta().id
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。"
|
||||
)
|
||||
|
||||
if len(self.stt_provider_insts) == 0:
|
||||
self.curr_stt_provider_inst = None
|
||||
elif (
|
||||
self.curr_stt_provider_inst is None
|
||||
and len(self.stt_provider_insts) > 0
|
||||
and self.stt_enabled
|
||||
):
|
||||
self.curr_stt_provider_inst = self.stt_provider_insts[0]
|
||||
self.selected_stt_provider_id = self.curr_stt_provider_inst.meta().id
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。"
|
||||
)
|
||||
|
||||
if len(self.tts_provider_insts) == 0:
|
||||
self.curr_tts_provider_inst = None
|
||||
elif (
|
||||
self.curr_tts_provider_inst is None
|
||||
and len(self.tts_provider_insts) > 0
|
||||
and self.tts_enabled
|
||||
):
|
||||
self.curr_tts_provider_inst = self.tts_provider_insts[0]
|
||||
self.selected_tts_provider_id = self.curr_tts_provider_inst.meta().id
|
||||
logger.info(
|
||||
f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。"
|
||||
)
|
||||
|
||||
def get_insts(self):
|
||||
return self.provider_insts
|
||||
|
||||
@@ -51,10 +51,14 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
self.timeout = int(self.timeout)
|
||||
|
||||
def has_rag_options(self):
|
||||
if (
|
||||
self.rag_options
|
||||
and self.rag_options.get("pipeline_ids", None)
|
||||
and self.rag_options.get("file_ids", None)
|
||||
"""判断是否有 RAG 选项
|
||||
|
||||
Returns:
|
||||
bool: 是否有 RAG 选项
|
||||
"""
|
||||
if self.rag_options and (
|
||||
len(self.rag_options.get("pipeline_ids", [])) > 0
|
||||
or len(self.rag_options.get("file_ids", [])) > 0
|
||||
):
|
||||
return True
|
||||
return False
|
||||
@@ -78,7 +82,7 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
|
||||
if (
|
||||
self.dashscope_app_type in ["agent", "dialog-workflow"]
|
||||
and self.has_rag_options()
|
||||
and not self.has_rag_options()
|
||||
):
|
||||
# 支持多轮对话的
|
||||
new_record = {"role": "user", "content": prompt}
|
||||
@@ -92,12 +96,15 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
if "_no_save" in part:
|
||||
del part["_no_save"]
|
||||
# 调用阿里云百炼 API
|
||||
payload = {
|
||||
"app_id": self.app_id,
|
||||
"api_key": self.api_key,
|
||||
"messages": context_query,
|
||||
"biz_params": payload_vars or None,
|
||||
}
|
||||
partial = functools.partial(
|
||||
Application.call,
|
||||
app_id=self.app_id,
|
||||
api_key=self.api_key,
|
||||
messages=context_query,
|
||||
biz_params=payload_vars or None,
|
||||
**payload,
|
||||
)
|
||||
response = await asyncio.get_event_loop().run_in_executor(None, partial)
|
||||
else:
|
||||
@@ -134,7 +141,8 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
if self.output_reference and response.output.get("doc_references", None):
|
||||
ref_str = ""
|
||||
for ref in response.output.get("doc_references", []):
|
||||
ref_str += f"{ref['index_id']}. {ref['title']}\n"
|
||||
ref_title = ref.get("title", "") if ref.get("title") else ref.get("doc_name", "")
|
||||
ref_str += f"{ref['index_id']}. {ref_title}\n"
|
||||
output_text += f"\n\n回答来源:\n{ref_str}"
|
||||
|
||||
return LLMResponse(role="assistant", completion_text=output_text)
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
import dashscope
|
||||
import uuid
|
||||
import asyncio
|
||||
from dashscope.audio.tts_v2 import *
|
||||
from ..provider import TTSProvider
|
||||
from ..entites import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"dashscope_tts", "Dashscope TTS API", provider_type=ProviderType.TEXT_TO_SPEECH
|
||||
)
|
||||
class ProviderDashscopeTTSAPI(TTSProvider):
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
provider_settings: dict,
|
||||
) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.chosen_api_key: str = provider_config.get("api_key", "")
|
||||
self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella")
|
||||
self.set_model(provider_config.get("model", None))
|
||||
self.timeout_ms = float(provider_config.get("timeout", 20))*1000
|
||||
|
||||
dashscope.api_key = self.chosen_api_key
|
||||
self.synthesizer = SpeechSynthesizer(
|
||||
model=self.get_model(),
|
||||
voice=self.voice,
|
||||
format=AudioFormat.WAV_24000HZ_MONO_16BIT,
|
||||
)
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
path = f"data/temp/dashscope_tts_{uuid.uuid4()}.wav"
|
||||
audio = await asyncio.get_event_loop().run_in_executor(
|
||||
None, self.synthesizer.call, text, self.timeout_ms
|
||||
)
|
||||
with open(path, "wb") as f:
|
||||
f.write(audio)
|
||||
return path
|
||||
@@ -35,6 +35,8 @@ class ProviderEdgeTTS(TTSProvider):
|
||||
self.pitch = provider_config.get("pitch", None)
|
||||
self.timeout = provider_config.get("timeout", 30)
|
||||
|
||||
self.proxy = os.getenv("https_proxy", None)
|
||||
|
||||
self.set_model("edge_tts")
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
@@ -42,7 +44,7 @@ class ProviderEdgeTTS(TTSProvider):
|
||||
mp3_path = f"data/temp/edge_tts_temp_{uuid.uuid4()}.mp3"
|
||||
wav_path = f"data/temp/edge_tts_{uuid.uuid4()}.wav"
|
||||
|
||||
# 构建Edge TTS参数
|
||||
# 构建 Edge TTS 参数
|
||||
kwargs = {"text": text, "voice": self.voice}
|
||||
if self.rate:
|
||||
kwargs["rate"] = self.rate
|
||||
@@ -52,35 +54,45 @@ class ProviderEdgeTTS(TTSProvider):
|
||||
kwargs["pitch"] = self.pitch
|
||||
|
||||
try:
|
||||
communicate = edge_tts.Communicate(**kwargs)
|
||||
communicate = edge_tts.Communicate(proxy=self.proxy, **kwargs)
|
||||
await communicate.save(mp3_path)
|
||||
|
||||
# 使用ffmpeg将MP3转换为标准WAV格式
|
||||
_ = await asyncio.create_subprocess_exec(
|
||||
"ffmpeg",
|
||||
"-y", # 覆盖输出文件
|
||||
"-i",
|
||||
mp3_path, # 输入文件
|
||||
"-acodec",
|
||||
"pcm_s16le", # 16位PCM编码
|
||||
"-ar",
|
||||
"24000", # 采样率24kHz (适合微信语音)
|
||||
"-ac",
|
||||
"1", # 单声道
|
||||
"-af",
|
||||
"apad=pad_dur=2", # 确保输出时长准确
|
||||
"-fflags",
|
||||
"+genpts", # 强制生成时间戳
|
||||
"-hide_banner", # 隐藏版本信息
|
||||
wav_path, # 输出文件
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
# 等待进程完成并获取输出
|
||||
stdout, stderr = await _.communicate()
|
||||
logger.info(f"[EdgeTTS] FFmpeg 标准输出: {stdout.decode().strip()}")
|
||||
logger.debug(f"FFmpeg错误输出: {stderr.decode().strip()}")
|
||||
logger.info(f"[EdgeTTS] 返回值(0代表成功): {_.returncode}")
|
||||
try:
|
||||
from pyffmpeg import FFmpeg
|
||||
|
||||
ff = FFmpeg()
|
||||
ff.convert(input=mp3_path, output=wav_path)
|
||||
except Exception as e:
|
||||
logger.debug(f"pyffmpeg 转换失败: {e}, 尝试使用 ffmpeg 命令行进行转换")
|
||||
# use ffmpeg command line
|
||||
|
||||
# 使用ffmpeg将MP3转换为标准WAV格式
|
||||
p = await asyncio.create_subprocess_exec(
|
||||
"ffmpeg",
|
||||
"-y", # 覆盖输出文件
|
||||
"-i",
|
||||
mp3_path, # 输入文件
|
||||
"-acodec",
|
||||
"pcm_s16le", # 16位PCM编码
|
||||
"-ar",
|
||||
"24000", # 采样率24kHz (适合微信语音)
|
||||
"-ac",
|
||||
"1", # 单声道
|
||||
"-af",
|
||||
"apad=pad_dur=2", # 确保输出时长准确
|
||||
"-fflags",
|
||||
"+genpts", # 强制生成时间戳
|
||||
"-hide_banner", # 隐藏版本信息
|
||||
wav_path, # 输出文件
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
# 等待进程完成并获取输出
|
||||
stdout, stderr = await p.communicate()
|
||||
logger.info(f"[EdgeTTS] FFmpeg 标准输出: {stdout.decode().strip()}")
|
||||
logger.debug(f"FFmpeg错误输出: {stderr.decode().strip()}")
|
||||
logger.info(f"[EdgeTTS] 返回值(0代表成功): {p.returncode}")
|
||||
|
||||
os.remove(mp3_path)
|
||||
if os.path.exists(wav_path) and os.path.getsize(wav_path) > 0:
|
||||
return wav_path
|
||||
|
||||
@@ -2,6 +2,9 @@ import base64
|
||||
import aiohttp
|
||||
import json
|
||||
import random
|
||||
import asyncio
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.api.provider import Provider, Personality
|
||||
@@ -39,6 +42,8 @@ class SimpleGoogleGenAIClient:
|
||||
model: str = "gemini-1.5-flash",
|
||||
system_instruction: str = "",
|
||||
tools: dict = None,
|
||||
modalities: List[str] = ["Text"],
|
||||
safety_settings: List[dict] = [],
|
||||
):
|
||||
payload = {}
|
||||
if system_instruction:
|
||||
@@ -46,6 +51,13 @@ class SimpleGoogleGenAIClient:
|
||||
if tools:
|
||||
payload["tools"] = [tools]
|
||||
payload["contents"] = contents
|
||||
payload["generationConfig"] = {
|
||||
"responseModalities": modalities,
|
||||
}
|
||||
payload["safetySettings"] = [
|
||||
{"category": s["category"], "threshold": s["threshold"]}
|
||||
for s in safety_settings
|
||||
]
|
||||
logger.debug(f"payload: {payload}")
|
||||
request_url = (
|
||||
f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}"
|
||||
@@ -99,6 +111,21 @@ class ProviderGoogleGenAI(Provider):
|
||||
)
|
||||
self.set_model(provider_config["model_config"]["model"])
|
||||
|
||||
safety_mapping = {
|
||||
"harassment": "HARM_CATEGORY_HARASSMENT",
|
||||
"hate_speech": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"sexually_explicit": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"dangerous_content": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
}
|
||||
|
||||
self.safety_settings = []
|
||||
user_safety_config = self.provider_config.get("gm_safety_settings", {})
|
||||
for config_key, harm_category in safety_mapping.items():
|
||||
if threshold := user_safety_config.get(config_key):
|
||||
self.safety_settings.append(
|
||||
{"category": harm_category, "threshold": threshold}
|
||||
)
|
||||
|
||||
async def get_models(self):
|
||||
return await self.client.models_list()
|
||||
|
||||
@@ -120,7 +147,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
if message["role"] == "user":
|
||||
if isinstance(message["content"], str):
|
||||
if not message["content"]:
|
||||
message["content"] = "<empty_content>"
|
||||
message["content"] = ""
|
||||
|
||||
google_genai_conversation.append(
|
||||
{"role": "user", "parts": [{"text": message["content"]}]}
|
||||
@@ -131,7 +158,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
for part in message["content"]:
|
||||
if part["type"] == "text":
|
||||
if not part["text"]:
|
||||
part["text"] = "<empty_content>"
|
||||
part["text"] = ""
|
||||
parts.append({"text": part["text"]})
|
||||
elif part["type"] == "image_url":
|
||||
parts.append(
|
||||
@@ -149,7 +176,7 @@ class ProviderGoogleGenAI(Provider):
|
||||
elif message["role"] == "assistant":
|
||||
if "content" in message:
|
||||
if not message["content"]:
|
||||
message["content"] = "<empty_content>"
|
||||
message["content"] = ""
|
||||
google_genai_conversation.append(
|
||||
{"role": "model", "parts": [{"text": message["content"]}]}
|
||||
)
|
||||
@@ -185,22 +212,54 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
logger.debug(f"google_genai_conversation: {google_genai_conversation}")
|
||||
|
||||
result = await self.client.generate_content(
|
||||
contents=google_genai_conversation,
|
||||
model=self.get_model(),
|
||||
system_instruction=system_instruction,
|
||||
tools=tool,
|
||||
)
|
||||
logger.debug(f"result: {result}")
|
||||
modalites = ["Text"]
|
||||
if self.provider_config.get("gm_resp_image_modal", False):
|
||||
modalites.append("Image")
|
||||
|
||||
if "candidates" not in result:
|
||||
raise Exception("Gemini 返回异常结果: " + str(result))
|
||||
loop = True
|
||||
while loop:
|
||||
loop = False
|
||||
result = await self.client.generate_content(
|
||||
contents=google_genai_conversation,
|
||||
model=self.get_model(),
|
||||
system_instruction=system_instruction,
|
||||
tools=tool,
|
||||
modalities=modalites,
|
||||
safety_settings=self.safety_settings,
|
||||
)
|
||||
logger.debug(f"result: {result}")
|
||||
|
||||
# Developer instruction is not enabled for models/gemini-2.0-flash-exp
|
||||
if "Developer instruction is not enabled" in str(result):
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持 system prompt, 已自动去除, 将会影响人格设置。"
|
||||
)
|
||||
system_instruction = ""
|
||||
loop = True
|
||||
|
||||
elif "Function calling is not enabled" in str(result):
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持函数调用,已自动去除,不影响使用。"
|
||||
)
|
||||
tool = None
|
||||
loop = True
|
||||
|
||||
elif "Multi-modal output is not supported" in str(result):
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持多模态输出,降级为文本模态重新请求。"
|
||||
)
|
||||
modalites = ["Text"]
|
||||
loop = True
|
||||
|
||||
elif "candidates" not in result:
|
||||
raise Exception("Gemini 返回异常结果: " + str(result))
|
||||
|
||||
candidates = result["candidates"][0]["content"]["parts"]
|
||||
llm_response = LLMResponse("assistant")
|
||||
chain = []
|
||||
for candidate in candidates:
|
||||
if "text" in candidate:
|
||||
llm_response.completion_text += candidate["text"]
|
||||
chain.append(Comp.Plain(candidate["text"]))
|
||||
elif "functionCall" in candidate:
|
||||
llm_response.role = "tool"
|
||||
llm_response.tools_call_args.append(candidate["functionCall"]["args"])
|
||||
@@ -208,8 +267,12 @@ class ProviderGoogleGenAI(Provider):
|
||||
llm_response.tools_call_ids.append(
|
||||
candidate["functionCall"]["name"]
|
||||
) # 没有 tool id
|
||||
elif "inlineData" in candidate:
|
||||
mime_type: str = candidate["inlineData"]["mimeType"]
|
||||
if mime_type.startswith("image/"):
|
||||
chain.append(Comp.Image.fromBase64(candidate["inlineData"]["data"]))
|
||||
|
||||
llm_response.completion_text = llm_response.completion_text.strip()
|
||||
llm_response.result_chain = MessageChain(chain=chain)
|
||||
return llm_response
|
||||
|
||||
async def text_chat(
|
||||
@@ -253,46 +316,20 @@ class ProviderGoogleGenAI(Provider):
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt = 20
|
||||
while retry_cnt > 0:
|
||||
logger.warning(
|
||||
f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
|
||||
)
|
||||
try:
|
||||
await self.pop_record(context_query)
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt -= 1
|
||||
else:
|
||||
raise e
|
||||
if retry_cnt == 0:
|
||||
llm_response = LLMResponse(
|
||||
"err", "err: 请尝试 /reset 重置会话"
|
||||
)
|
||||
elif "Function calling is not enabled" in str(e):
|
||||
logger.info(
|
||||
f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。"
|
||||
)
|
||||
if "tools" in payloads:
|
||||
del payloads["tools"]
|
||||
llm_response = await self._query(payloads, None)
|
||||
break
|
||||
elif "429" in str(e) or "API key not valid" in str(e):
|
||||
if "429" in str(e) or "API key not valid" in str(e):
|
||||
keys.remove(chosen_key)
|
||||
if len(keys) > 0:
|
||||
chosen_key = random.choice(keys)
|
||||
logger.info(
|
||||
f"检测到 Key 异常({str(e)}),正在尝试更换 API Key 重试... 当前 Key: {chosen_key[:12]}..."
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
else:
|
||||
logger.error(
|
||||
f"检测到 Key 异常({str(e)}),且已没有可用的 Key。 当前 Key: {chosen_key[:12]}..."
|
||||
)
|
||||
raise Exception("API 资源已耗尽,且没有可用的 Key 重试...")
|
||||
raise Exception("达到了 Gemini 速率限制, 请稍后再试...")
|
||||
else:
|
||||
logger.error(
|
||||
f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}"
|
||||
|
||||
@@ -2,6 +2,8 @@ import base64
|
||||
import json
|
||||
import os
|
||||
import inspect
|
||||
import random
|
||||
import asyncio
|
||||
|
||||
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
@@ -176,77 +178,81 @@ class ProviderOpenAIOfficial(Provider):
|
||||
payloads = {"messages": context_query, **model_config}
|
||||
|
||||
llm_response = None
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
except UnprocessableEntityError as e:
|
||||
logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
|
||||
# 尝试删除所有 image
|
||||
new_contexts = await self._remove_image_from_context(context_query)
|
||||
payloads["messages"] = new_contexts
|
||||
context_query = new_contexts
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
# 重试 10 次
|
||||
retry_cnt = 20
|
||||
while retry_cnt > 0:
|
||||
logger.warning(
|
||||
f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
|
||||
)
|
||||
try:
|
||||
await self.pop_record(context_query)
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt -= 1
|
||||
else:
|
||||
raise e
|
||||
if retry_cnt == 0:
|
||||
llm_response = LLMResponse(
|
||||
"err", "err: 请尝试 /reset 清除会话记录。"
|
||||
)
|
||||
elif "The model is not a VLM" in str(e): # siliconcloud
|
||||
|
||||
max_retries = 10
|
||||
available_api_keys = self.api_keys.copy()
|
||||
chosen_key = random.choice(available_api_keys)
|
||||
|
||||
e = None
|
||||
retry_cnt = 0
|
||||
for retry_cnt in range(max_retries):
|
||||
try:
|
||||
self.client.api_key = chosen_key
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
except UnprocessableEntityError as e:
|
||||
logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
|
||||
# 尝试删除所有 image
|
||||
new_contexts = await self._remove_image_from_context(context_query)
|
||||
payloads["messages"] = new_contexts
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
|
||||
# openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配
|
||||
elif (
|
||||
"does not support Function Calling" in str(e)
|
||||
or "does not support tools" in str(e)
|
||||
or "Function call is not supported" in str(e)
|
||||
or "Function calling is not enabled" in str(e)
|
||||
or "Tool calling is not supported" in str(e)
|
||||
or "No endpoints found that support tool use" in str(e)
|
||||
or "model does not support function calling" in str(e)
|
||||
or ("tool" in str(e) and "support" in str(e).lower())
|
||||
or ("function" in str(e) and "support" in str(e).lower())
|
||||
):
|
||||
logger.info(
|
||||
f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。"
|
||||
)
|
||||
if "tools" in payloads:
|
||||
del payloads["tools"]
|
||||
llm_response = await self._query(payloads, None)
|
||||
else:
|
||||
logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
|
||||
|
||||
if "tool" in str(e).lower() and "support" in str(e).lower():
|
||||
context_query = new_contexts
|
||||
except Exception as e:
|
||||
if "429" in str(e):
|
||||
logger.warning(
|
||||
f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}"
|
||||
)
|
||||
# 最后一次不等待
|
||||
if retry_cnt < max_retries - 1:
|
||||
await asyncio.sleep(1)
|
||||
available_api_keys.remove(chosen_key)
|
||||
if len(available_api_keys) > 0:
|
||||
chosen_key = random.choice(available_api_keys)
|
||||
continue
|
||||
else:
|
||||
raise e
|
||||
elif "maximum context length" in str(e):
|
||||
logger.warning(
|
||||
f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
|
||||
)
|
||||
await self.pop_record(context_query)
|
||||
elif "The model is not a VLM" in str(e): # siliconcloud
|
||||
# 尝试删除所有 image
|
||||
new_contexts = await self._remove_image_from_context(context_query)
|
||||
payloads["messages"] = new_contexts
|
||||
elif (
|
||||
"Function calling is not enabled" in str(e)
|
||||
or ("tool" in str(e).lower() and "support" in str(e).lower())
|
||||
or ("function" in str(e).lower() and "support" in str(e).lower())
|
||||
):
|
||||
# openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配
|
||||
logger.info(
|
||||
f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。"
|
||||
)
|
||||
if "tools" in payloads:
|
||||
del payloads["tools"]
|
||||
func_tool = None
|
||||
else:
|
||||
logger.error(
|
||||
"疑似该模型不支持函数调用工具调用。请输入 /tool off_all"
|
||||
f"发生了错误。Provider 配置如下: {self.provider_config}"
|
||||
)
|
||||
|
||||
if "Connection error." in str(e):
|
||||
proxy = os.environ.get("http_proxy", None)
|
||||
if proxy:
|
||||
if "tool" in str(e).lower() and "support" in str(e).lower():
|
||||
logger.error(
|
||||
f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}"
|
||||
"疑似该模型不支持函数调用工具调用。请输入 /tool off_all"
|
||||
)
|
||||
|
||||
raise e
|
||||
if "Connection error." in str(e):
|
||||
proxy = os.environ.get("http_proxy", None)
|
||||
if proxy:
|
||||
logger.error(
|
||||
f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}"
|
||||
)
|
||||
|
||||
raise e
|
||||
|
||||
if retry_cnt == max_retries - 1:
|
||||
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
|
||||
raise e
|
||||
return llm_response
|
||||
|
||||
async def _remove_image_from_context(self, contexts: List):
|
||||
|
||||
@@ -48,14 +48,6 @@ class ProviderSenseVoiceSTTSelfHost(STTProvider):
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
return os.path.join("data", "temp", f"{timestamp}")
|
||||
|
||||
async def _convert_audio(self, path: str) -> str:
|
||||
from pyffmpeg import FFmpeg
|
||||
|
||||
filename = await self.get_timestamped_path() + ".mp3"
|
||||
ff = FFmpeg()
|
||||
output_path = ff.convert(path, os.path.join('data","temp', filename))
|
||||
return output_path
|
||||
|
||||
async def _is_silk_file(self, file_path):
|
||||
silk_header = b"SILK"
|
||||
with open(file_path, "rb") as f:
|
||||
|
||||
@@ -31,14 +31,6 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
|
||||
self.set_model(provider_config.get("model", None))
|
||||
|
||||
async def _convert_audio(self, path: str) -> str:
|
||||
from pyffmpeg import FFmpeg
|
||||
|
||||
filename = str(uuid.uuid4()) + ".mp3"
|
||||
ff = FFmpeg()
|
||||
output_path = ff.convert(path, os.path.join("data/temp", filename))
|
||||
return output_path
|
||||
|
||||
async def _is_silk_file(self, file_path):
|
||||
silk_header = b"SILK"
|
||||
with open(file_path, "rb") as f:
|
||||
|
||||
@@ -33,14 +33,6 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
)
|
||||
logger.info("Whisper 模型加载完成。")
|
||||
|
||||
async def _convert_audio(self, path: str) -> str:
|
||||
from pyffmpeg import FFmpeg
|
||||
|
||||
filename = str(uuid.uuid4()) + ".mp3"
|
||||
ff = FFmpeg()
|
||||
output_path = ff.convert(path, os.path.join("data/temp", filename))
|
||||
return output_path
|
||||
|
||||
async def _is_silk_file(self, file_path):
|
||||
silk_header = b"SILK"
|
||||
with open(file_path, "rb") as f:
|
||||
|
||||
@@ -4,12 +4,14 @@ from .context import Context
|
||||
from astrbot.core.provider import Provider
|
||||
from astrbot.core.utils.command_parser import CommandParserMixin
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core.star.star_tools import StarTools
|
||||
|
||||
|
||||
class Star(CommandParserMixin):
|
||||
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
|
||||
|
||||
def __init__(self, context: Context):
|
||||
StarTools.initialize(context)
|
||||
self.context = context
|
||||
|
||||
async def text_to_image(self, text: str, return_url=True) -> str:
|
||||
@@ -27,4 +29,4 @@ class Star(CommandParserMixin):
|
||||
pass
|
||||
|
||||
|
||||
__all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider"]
|
||||
__all__ = ["Star", "StarMetadata", "PluginManager", "Context", "Provider", "StarTools"]
|
||||
|
||||
@@ -451,7 +451,34 @@ class PluginManager:
|
||||
# reload the plugin
|
||||
dir_name = os.path.basename(plugin_path)
|
||||
await self.load(specified_dir_name=dir_name)
|
||||
return plugin_path
|
||||
|
||||
# Get the plugin metadata to return repo info
|
||||
plugin = self.context.get_registered_star(dir_name)
|
||||
if not plugin:
|
||||
# Try to find by other name if directory name doesn't match plugin name
|
||||
for star in self.context.get_all_stars():
|
||||
if star.root_dir_name == dir_name:
|
||||
plugin = star
|
||||
break
|
||||
|
||||
# Extract README.md content if exists
|
||||
readme_content = None
|
||||
readme_path = os.path.join(plugin_path, "README.md")
|
||||
if not os.path.exists(readme_path):
|
||||
readme_path = os.path.join(plugin_path, "readme.md")
|
||||
|
||||
if os.path.exists(readme_path):
|
||||
try:
|
||||
with open(readme_path, "r", encoding="utf-8") as f:
|
||||
readme_content = f.read()
|
||||
except Exception as e:
|
||||
logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {str(e)}")
|
||||
|
||||
plugin_info = None
|
||||
if plugin:
|
||||
plugin_info = {"repo": plugin.repo, "readme": readme_content}
|
||||
|
||||
return plugin_info
|
||||
|
||||
async def uninstall_plugin(self, plugin_name: str):
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
@@ -558,7 +585,7 @@ class PluginManager:
|
||||
|
||||
async def _terminate_plugin(self, star_metadata: StarMetadata):
|
||||
"""终止插件,调用插件的 terminate() 和 __del__() 方法"""
|
||||
logging.info(f"正在终止插件 {star_metadata.name} ...")
|
||||
logger.info(f"正在终止插件 {star_metadata.name} ...")
|
||||
|
||||
if not star_metadata.activated:
|
||||
# 说明之前已经被禁用了
|
||||
@@ -569,7 +596,7 @@ class PluginManager:
|
||||
asyncio.get_event_loop().run_in_executor(
|
||||
None, star_metadata.star_cls.__del__
|
||||
)
|
||||
else:
|
||||
elif hasattr(star_metadata.star_cls, "terminate"):
|
||||
await star_metadata.star_cls.terminate()
|
||||
|
||||
async def turn_on_plugin(self, plugin_name: str):
|
||||
@@ -607,3 +634,31 @@ class PluginManager:
|
||||
logger.warning(f"删除插件压缩包失败: {str(e)}")
|
||||
# await self.reload()
|
||||
await self.load(specified_dir_name=dir_name)
|
||||
|
||||
# Get the plugin metadata to return repo info
|
||||
plugin = self.context.get_registered_star(dir_name)
|
||||
if not plugin:
|
||||
# Try to find by other name if directory name doesn't match plugin name
|
||||
for star in self.context.get_all_stars():
|
||||
if star.root_dir_name == dir_name:
|
||||
plugin = star
|
||||
break
|
||||
|
||||
# Extract README.md content if exists
|
||||
readme_content = None
|
||||
readme_path = os.path.join(desti_dir, "README.md")
|
||||
if not os.path.exists(readme_path):
|
||||
readme_path = os.path.join(desti_dir, "readme.md")
|
||||
|
||||
if os.path.exists(readme_path):
|
||||
try:
|
||||
with open(readme_path, "r", encoding="utf-8") as f:
|
||||
readme_content = f.read()
|
||||
except Exception as e:
|
||||
logger.warning(f"读取插件 {dir_name} 的 README.md 文件失败: {str(e)}")
|
||||
|
||||
plugin_info = None
|
||||
if plugin:
|
||||
plugin_info = {"repo": plugin.repo, "readme": readme_content}
|
||||
|
||||
return plugin_info
|
||||
|
||||
@@ -0,0 +1,144 @@
|
||||
from typing import Union, Awaitable, List, Optional, ClassVar
|
||||
from astrbot.core.message.components import BaseMessageComponent
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.api.platform import MessageMember, AstrBotMessage
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.star.context import Context
|
||||
|
||||
|
||||
class StarTools:
|
||||
"""
|
||||
提供给插件使用的便捷工具函数集合
|
||||
这些方法封装了一些常用操作,使插件开发更加简单便捷!
|
||||
"""
|
||||
|
||||
_context: ClassVar[Optional[Context]] = None
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, context: Context) -> None:
|
||||
"""
|
||||
初始化StarTools,设置context引用
|
||||
|
||||
Args:
|
||||
context: 暴露给插件的上下文
|
||||
"""
|
||||
cls._context = context
|
||||
|
||||
@classmethod
|
||||
async def send_message(
|
||||
cls, session: Union[str, MessageSesion], message_chain: MessageChain
|
||||
) -> bool:
|
||||
"""
|
||||
根据session(unified_msg_origin)主动发送消息
|
||||
|
||||
Args:
|
||||
session: 消息会话。通过event.session或者event.unified_msg_origin获取
|
||||
message_chain: 消息链
|
||||
|
||||
Returns:
|
||||
bool: 是否找到匹配的平台
|
||||
|
||||
Raises:
|
||||
ValueError: 当session为字符串且解析失败时抛出
|
||||
|
||||
Note:
|
||||
qq_official(QQ官方API平台)不支持此方法
|
||||
"""
|
||||
return await cls._context.send_message(session, message_chain)
|
||||
|
||||
@classmethod
|
||||
async def create_message(
|
||||
cls,
|
||||
type: str,
|
||||
self_id: str,
|
||||
session_id: str,
|
||||
message_id: str,
|
||||
sender: MessageMember,
|
||||
message: List[BaseMessageComponent],
|
||||
message_str: str,
|
||||
raw_message: object,
|
||||
group_id: str = "",
|
||||
):
|
||||
"""
|
||||
创建一个AstrBot消息对象
|
||||
|
||||
Args:
|
||||
type (str): 消息类型
|
||||
self_id (str): 机器人自身ID
|
||||
session_id (str): 会话ID(通常为用户ID)(QQ号, 群号等)
|
||||
message_id (str): 消息ID
|
||||
sender (MessageMember): 发送者信息
|
||||
message (List[BaseMessageComponent]): 消息组件列表
|
||||
message_str (str): 消息字符串
|
||||
raw_message (object): 原始消息对象
|
||||
group_id (str, optional): 群组ID, 如果为私聊则为空. Defaults to "".
|
||||
|
||||
Returns:
|
||||
AstrBotMessage: 创建的消息对象
|
||||
"""
|
||||
abm = AstrBotMessage()
|
||||
abm.type = type
|
||||
abm.self_id = self_id
|
||||
abm.session_id = session_id
|
||||
abm.message_id = message_id
|
||||
abm.sender = sender
|
||||
abm.message = message
|
||||
abm.message_str = message_str
|
||||
abm.raw_message = raw_message
|
||||
abm.group_id = group_id
|
||||
return abm
|
||||
|
||||
# todo: 添加构造事件的方法
|
||||
# async def create_event(
|
||||
# self, platform: str, umo: str, sender_id: str, session_id: str
|
||||
# ):
|
||||
# platform = self._context.get_platform(platform)
|
||||
|
||||
# todo: 添加找到对应平台并提交对应事件的方法
|
||||
|
||||
@classmethod
|
||||
def activate_llm_tool(cls, name: str) -> bool:
|
||||
"""
|
||||
激活一个已经注册的函数调用工具
|
||||
注册的工具默认是激活状态
|
||||
|
||||
Args:
|
||||
name (str): 工具名称
|
||||
"""
|
||||
return cls._context.activate_llm_tool(name)
|
||||
|
||||
@classmethod
|
||||
def deactivate_llm_tool(cls, name: str) -> bool:
|
||||
"""
|
||||
停用一个已经注册的函数调用工具
|
||||
|
||||
Args:
|
||||
name (str): 工具名称
|
||||
"""
|
||||
return cls._context.deactivate_llm_tool(name)
|
||||
|
||||
@classmethod
|
||||
def register_llm_tool(
|
||||
cls, name: str, func_args: list, desc: str, func_obj: Awaitable
|
||||
) -> None:
|
||||
"""
|
||||
为函数调用(function-calling/tools-use)添加工具
|
||||
|
||||
Args:
|
||||
name (str): 工具名称
|
||||
func_args (list): 函数参数列表
|
||||
desc (str): 工具描述
|
||||
func_obj (Awaitable): 函数对象,必须是异步函数
|
||||
"""
|
||||
cls._context.register_llm_tool(name, func_args, desc, func_obj)
|
||||
|
||||
@classmethod
|
||||
def unregister_llm_tool(cls, name: str) -> None:
|
||||
"""
|
||||
删除一个函数调用工具
|
||||
如果再要启用,需要重新注册
|
||||
|
||||
Args:
|
||||
name (str): 工具名称
|
||||
"""
|
||||
cls._context.unregister_llm_tool(name)
|
||||
@@ -41,7 +41,7 @@ class PluginUpdator(RepoZipUpdator):
|
||||
plugin_path = os.path.join(self.plugin_store_path, plugin.root_dir_name)
|
||||
|
||||
logger.info(f"正在更新插件,路径: {plugin_path},仓库地址: {repo_url}")
|
||||
await self.download_from_repo_url(plugin_path, repo_url)
|
||||
await self.download_from_repo_url(plugin_path, repo_url, proxy=proxy)
|
||||
|
||||
try:
|
||||
remove_dir(plugin_path)
|
||||
|
||||
@@ -9,6 +9,11 @@ from astrbot.core.utils.io import download_file
|
||||
|
||||
|
||||
class AstrBotUpdator(RepoZipUpdator):
|
||||
"""AstrBot 更新器,继承自 RepoZipUpdator 类
|
||||
该类用于处理 AstrBot 的更新操作
|
||||
功能包括检查更新、下载更新文件、解压缩更新文件等
|
||||
"""
|
||||
|
||||
def __init__(self, repo_mirror: str = "") -> None:
|
||||
super().__init__(repo_mirror)
|
||||
self.MAIN_PATH = os.path.abspath(
|
||||
@@ -17,6 +22,9 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
self.ASTRBOT_RELEASE_API = "https://api.soulter.top/releases"
|
||||
|
||||
def terminate_child_processes(self):
|
||||
"""终止当前进程的所有子进程
|
||||
使用 psutil 库获取当前进程的所有子进程,并尝试终止它们
|
||||
"""
|
||||
try:
|
||||
parent = psutil.Process(os.getpid())
|
||||
children = parent.children(recursive=True)
|
||||
@@ -35,6 +43,9 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
pass
|
||||
|
||||
def _reboot(self, delay: int = 3):
|
||||
"""重启当前程序
|
||||
在指定的延迟后,终止所有子进程并重新启动程序
|
||||
"""
|
||||
py = sys.executable
|
||||
time.sleep(delay)
|
||||
self.terminate_child_processes()
|
||||
@@ -46,6 +57,7 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
raise e
|
||||
|
||||
async def check_update(self, url: str, current_version: str) -> ReleaseInfo:
|
||||
"""检查更新"""
|
||||
return await super().check_update(self.ASTRBOT_RELEASE_API, VERSION)
|
||||
|
||||
async def get_releases(self) -> list:
|
||||
|
||||
@@ -103,7 +103,7 @@ async def download_image_by_url(
|
||||
with open(path, "wb") as f:
|
||||
f.write(await resp.read())
|
||||
return path
|
||||
except aiohttp.client.ClientConnectorSSLError:
|
||||
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
|
||||
# 关闭SSL验证
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.set_ciphers("DEFAULT")
|
||||
@@ -152,7 +152,7 @@ async def download_file(url: str, path: str, show_progress: bool = False):
|
||||
f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s",
|
||||
end="",
|
||||
)
|
||||
except aiohttp.client.ClientConnectorSSLError:
|
||||
except (aiohttp.ClientConnectorSSLError, aiohttp.ClientConnectorCertificateError):
|
||||
# 关闭SSL验证
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.set_ciphers("DEFAULT")
|
||||
|
||||
@@ -16,6 +16,7 @@ class SharedPreferences:
|
||||
def _save_preferences(self):
|
||||
with open(self.path, "w") as f:
|
||||
json.dump(self._data, f, indent=4)
|
||||
f.flush()
|
||||
|
||||
def get(self, key, default=None):
|
||||
return self._data.get(key, default)
|
||||
|
||||
@@ -2,7 +2,7 @@ import jwt
|
||||
import datetime
|
||||
from .route import Route, Response, RouteContext
|
||||
from quart import request
|
||||
from astrbot.core import WEBUI_SK
|
||||
from astrbot.core import WEBUI_SK, DEMO_MODE
|
||||
from astrbot import logger
|
||||
|
||||
|
||||
@@ -40,6 +40,13 @@ class AuthRoute(Route):
|
||||
return Response().error("用户名或密码错误").__dict__
|
||||
|
||||
async def edit_account(self):
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
Response()
|
||||
.error("You are not permitted to do this operation in demo mode")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
password = self.config["dashboard"]["password"]
|
||||
post_data = await request.json
|
||||
|
||||
|
||||
@@ -12,8 +12,11 @@ from astrbot.core import logger
|
||||
|
||||
|
||||
def try_cast(value: str, type_: str):
|
||||
if type_ == "int" and value.isdigit():
|
||||
return int(value)
|
||||
if type_ == "int":
|
||||
try:
|
||||
return int(value)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
elif (
|
||||
type_ == "float"
|
||||
and isinstance(value, str)
|
||||
@@ -22,6 +25,11 @@ def try_cast(value: str, type_: str):
|
||||
return float(value)
|
||||
elif type_ == "float" and isinstance(value, int):
|
||||
return float(value)
|
||||
elif type_ == "float":
|
||||
try:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
def validate_config(
|
||||
@@ -34,13 +42,21 @@ def validate_config(
|
||||
if key not in metadata:
|
||||
# 无 schema 的配置项,执行类型猜测
|
||||
if isinstance(value, str):
|
||||
if value.isdigit():
|
||||
try:
|
||||
data[key] = int(value)
|
||||
elif value.replace(".", "", 1).isdigit():
|
||||
continue
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
try:
|
||||
data[key] = float(value)
|
||||
elif value == "true":
|
||||
continue
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
if value.lower() == "true":
|
||||
data[key] = True
|
||||
elif value == "false":
|
||||
elif value.lower() == "false":
|
||||
data[key] = False
|
||||
continue
|
||||
meta = metadata[key]
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import asyncio
|
||||
from quart import websocket
|
||||
import json
|
||||
from quart import make_response
|
||||
from astrbot.core import logger, LogBroker
|
||||
from .route import Route, RouteContext
|
||||
|
||||
@@ -8,21 +9,36 @@ class LogRoute(Route):
|
||||
def __init__(self, context: RouteContext, log_broker: LogBroker) -> None:
|
||||
super().__init__(context)
|
||||
self.log_broker = log_broker
|
||||
self.app.add_url_rule(
|
||||
"/api/live-log", view_func=self.log, methods=["GET"], websocket=True
|
||||
)
|
||||
self.app.add_url_rule("/api/live-log", view_func=self.log, methods=["GET"])
|
||||
|
||||
async def log(self):
|
||||
queue = None
|
||||
try:
|
||||
queue = self.log_broker.register()
|
||||
while True:
|
||||
message = await queue.get()
|
||||
await websocket.send(message)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except BaseException as e:
|
||||
logger.error(f"WebSocket 连接错误: {e}")
|
||||
finally:
|
||||
if queue:
|
||||
self.log_broker.unregister(queue)
|
||||
async def stream():
|
||||
queue = None
|
||||
try:
|
||||
queue = self.log_broker.register()
|
||||
while True:
|
||||
message = await queue.get()
|
||||
payload = {
|
||||
"type": "log",
|
||||
**message # see astrbot/core/log.py
|
||||
}
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except BaseException as e:
|
||||
logger.error(f"Log SSE 连接错误: {e}")
|
||||
finally:
|
||||
if queue:
|
||||
self.log_broker.unregister(queue)
|
||||
|
||||
response = await make_response(
|
||||
stream(),
|
||||
{
|
||||
"Content-Type": "text/event-stream",
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"Transfer-Encoding": "chunked",
|
||||
},
|
||||
)
|
||||
response.timeout = None
|
||||
return response
|
||||
|
||||
@@ -15,6 +15,7 @@ from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.filter.permission import PermissionTypeFilter
|
||||
from astrbot.core.star.filter.regex import RegexFilter
|
||||
from astrbot.core.star.star_handler import EventType
|
||||
from astrbot.core import DEMO_MODE
|
||||
|
||||
|
||||
class PluginRoute(Route):
|
||||
@@ -50,6 +51,13 @@ class PluginRoute(Route):
|
||||
}
|
||||
|
||||
async def reload_plugins(self):
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
Response()
|
||||
.error("You are not permitted to do this operation in demo mode")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
data = await request.json
|
||||
plugin_name = data.get("name", None)
|
||||
try:
|
||||
@@ -187,6 +195,13 @@ class PluginRoute(Route):
|
||||
return handlers
|
||||
|
||||
async def install_plugin(self):
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
Response()
|
||||
.error("You are not permitted to do this operation in demo mode")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
post_data = await request.json
|
||||
repo_url = post_data["url"]
|
||||
|
||||
@@ -196,30 +211,44 @@ class PluginRoute(Route):
|
||||
|
||||
try:
|
||||
logger.info(f"正在安装插件 {repo_url}")
|
||||
await self.plugin_manager.install_plugin(repo_url, proxy)
|
||||
plugin_info = await self.plugin_manager.install_plugin(repo_url, proxy)
|
||||
# self.core_lifecycle.restart()
|
||||
logger.info(f"安装插件 {repo_url} 成功。")
|
||||
return Response().ok(None, "安装成功。").__dict__
|
||||
return Response().ok(plugin_info, "安装成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def install_plugin_upload(self):
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
Response()
|
||||
.error("You are not permitted to do this operation in demo mode")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
try:
|
||||
file = await request.files
|
||||
file = file["file"]
|
||||
logger.info(f"正在安装用户上传的插件 {file.filename}")
|
||||
file_path = f"data/temp/{file.filename}"
|
||||
await file.save(file_path)
|
||||
await self.plugin_manager.install_plugin_from_file(file_path)
|
||||
plugin_info = await self.plugin_manager.install_plugin_from_file(file_path)
|
||||
# self.core_lifecycle.restart()
|
||||
logger.info(f"安装插件 {file.filename} 成功")
|
||||
return Response().ok(None, "安装成功。").__dict__
|
||||
return Response().ok(plugin_info, "安装成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def uninstall_plugin(self):
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
Response()
|
||||
.error("You are not permitted to do this operation in demo mode")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
post_data = await request.json
|
||||
plugin_name = post_data["name"]
|
||||
try:
|
||||
@@ -232,6 +261,13 @@ class PluginRoute(Route):
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def update_plugin(self):
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
Response()
|
||||
.error("You are not permitted to do this operation in demo mode")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
post_data = await request.json
|
||||
plugin_name = post_data["name"]
|
||||
proxy: str = post_data.get("proxy", None)
|
||||
@@ -247,6 +283,13 @@ class PluginRoute(Route):
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def off_plugin(self):
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
Response()
|
||||
.error("You are not permitted to do this operation in demo mode")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
post_data = await request.json
|
||||
plugin_name = post_data["name"]
|
||||
try:
|
||||
@@ -258,6 +301,13 @@ class PluginRoute(Route):
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def on_plugin(self):
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
Response()
|
||||
.error("You are not permitted to do this operation in demo mode")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
post_data = await request.json
|
||||
plugin_name = post_data["name"]
|
||||
try:
|
||||
|
||||
@@ -8,6 +8,7 @@ from quart import request
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.config import VERSION
|
||||
from astrbot.core import DEMO_MODE
|
||||
|
||||
|
||||
class StatRoute(Route):
|
||||
@@ -29,6 +30,13 @@ class StatRoute(Route):
|
||||
self.core_lifecycle = core_lifecycle
|
||||
|
||||
async def restart_core(self):
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
Response()
|
||||
.error("You are not permitted to do this operation in demo mode")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
await self.core_lifecycle.restart()
|
||||
return Response().ok().__dict__
|
||||
|
||||
|
||||
@@ -20,6 +20,8 @@ class StaticFileRoute(Route):
|
||||
"/providers",
|
||||
"/about",
|
||||
"/extension-marketplace",
|
||||
"/conversation",
|
||||
"/tool-use",
|
||||
]
|
||||
for i in index_:
|
||||
self.app.add_url_rule(i, view_func=self.index)
|
||||
|
||||
@@ -6,6 +6,7 @@ from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core import logger, pip_installer
|
||||
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core import DEMO_MODE
|
||||
|
||||
|
||||
class UpdateRoute(Route):
|
||||
@@ -126,6 +127,13 @@ class UpdateRoute(Route):
|
||||
return Response().error(e.__str__()).__dict__
|
||||
|
||||
async def install_pip_package(self):
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
Response()
|
||||
.error("You are not permitted to do this operation in demo mode")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
data = await request.json
|
||||
package = data.get("package", "")
|
||||
if not package:
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
# What's Changed
|
||||
|
||||
> 📢 在升级前,请完整阅读本次更新日志。
|
||||
|
||||
## ✨ 新增的功能
|
||||
|
||||
1. 适配 `gemini-2.0-flash-exp-image-generation` 对图片模态的输入 [#1017](https://github.com/Soulter/AstrBot/issues/1017)
|
||||
2. 在 MessageChain 类中添加 at 和 at_all 方法,用于快速添加 At 消息 @left666
|
||||
3. Gewechat Client 增加获取通讯录列表接口
|
||||
4. 支持 /llm 指令快捷启停 LLM 功能 [#296](https://github.com/Soulter/AstrBot/issues/296)
|
||||
|
||||
## 🎈 功能性优化
|
||||
|
||||
1. Edge TTS 支持使用代理
|
||||
2. 在 Lifecycle 新增插件资源清理逻辑 @Raven95676
|
||||
3. Docker 镜像提供内置 FFmpeg [#979](https://github.com/Soulter/AstrBot/issues/979)
|
||||
4. 优化无对话情况下设置人格的反馈 @Raven95676
|
||||
5. 若禁用提供商,自动切换到另一个可用的提供商 @Raven95676
|
||||
6. openai_source 同步支持随机请求均衡,同时优化 LLM 请求逻辑的异常处理
|
||||
7. 保存 shared_preferences 时强制刷新文件缓冲区
|
||||
8. 优化空 At 回复 @advent259141
|
||||
|
||||
## 🐛 修复的 Bug
|
||||
|
||||
1. 插件更新时没有正确应用加速地址
|
||||
2. newgroup 指令名显示错误
|
||||
|
||||
## 🧩 新增的插件
|
||||
|
||||
待补充
|
||||
@@ -0,0 +1,31 @@
|
||||
# What's Changed
|
||||
|
||||
> 📢 在升级前,请完整阅读本次更新日志。
|
||||
|
||||
## ✨ 新增的功能
|
||||
|
||||
1. 安装完插件后自动弹出插件仓库 README 对话框 @zhx8702
|
||||
4. 支持阿里云百炼 TTS@Soulter
|
||||
5. 支持 Telegram MarkdownV2 渲染 @Soulter
|
||||
6. 支持 钉钉 Markdown 渲染 @Soulter
|
||||
6. 增加对 Gemini 系列模型的输入安全设置参数支持 @AliveGh0st
|
||||
7. 支持手动设置时区以应对容器、国外用户的时区问题 @anka-afk @Raven95676 @Soulter
|
||||
8. 插件市场显示帮助按钮 @Soulter
|
||||
|
||||
## 🎈 功能性优化
|
||||
|
||||
1. WebUI 的日志通信使用 SSE 替代 Websockets @Soulter
|
||||
2. 在发送消息之前统一检查消息内容是否为空, 不允许发送空消息, 以解决该消息内容不支持查看以及 Gemini 返回 `<empty content>` 问题 @anka-afk
|
||||
3. 更新 Dify 平台链接为官方域名 by @Captain-Slacker-OwO
|
||||
4. 人格 prompt 输入框支持调节高度 @Soulter
|
||||
|
||||
## 🐛 修复的 Bug
|
||||
|
||||
1. 将最多携带对话数量修改回 `-1` 时出现报错 #1074 @anka-afk
|
||||
2. 修复无法识别到函数调用异常的问题 by @Soulter
|
||||
3. 修复 aiocqhttp 适配器下空白 plain 导致的 `the object is not a proper segment chain` 报错问题 @Soulter
|
||||
4. 修复阿里百炼应用无法多轮会话的问题 @Soulter
|
||||
|
||||
## 🧩 新增的插件
|
||||
|
||||
待补充
|
||||
@@ -21,9 +21,10 @@
|
||||
"axios-mock-adapter": "^1.22.0",
|
||||
"chance": "1.1.11",
|
||||
"date-fns": "2.30.0",
|
||||
"highlight.js": "^11.11.1",
|
||||
"js-md5": "^0.8.3",
|
||||
"lodash": "4.17.21",
|
||||
"marked": "^15.0.6",
|
||||
"marked": "^15.0.7",
|
||||
"pinia": "2.1.6",
|
||||
"remixicon": "3.5.0",
|
||||
"vee-validate": "4.11.3",
|
||||
|
||||
@@ -94,7 +94,6 @@
|
||||
v-else-if="metadata[metadataKey].items[key]?.type === 'text' && !metadata[metadataKey].items[key]?.invisible"
|
||||
v-model="iterable[key]"
|
||||
variant="outlined"
|
||||
auto-grow
|
||||
rows="3"
|
||||
class="config-field"
|
||||
hide-details
|
||||
|
||||
@@ -3,9 +3,20 @@ import { useCommonStore } from '@/stores/common';
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<div id="term"
|
||||
style="background-color: #1e1e1e; padding: 16px; border-radius: 8px; overflow-y:auto">
|
||||
<div>
|
||||
<!-- 添加筛选级别控件 -->
|
||||
<div class="filter-controls mb-2">
|
||||
<v-chip-group v-model="selectedLevels" column multiple>
|
||||
<v-chip v-for="level in logLevels" :key="level" :color="getLevelColor(level)" filter
|
||||
:text-color="level === 'DEBUG' || level === 'INFO' ? 'black' : 'white'">
|
||||
{{ level }}
|
||||
</v-chip>
|
||||
</v-chip-group>
|
||||
</div>
|
||||
|
||||
<div id="term" style="background-color: #1e1e1e; padding: 16px; border-radius: 8px; overflow-y:auto; height: 100%">
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script>
|
||||
@@ -25,7 +36,16 @@ export default {
|
||||
'default': 'color: #FFFFFF;'
|
||||
},
|
||||
logCache: useCommonStore().getLogCache(),
|
||||
historyNum_: -1
|
||||
historyNum_: -1,
|
||||
logLevels: ['DEBUG', 'INFO', 'WARNING', 'ERROR', 'CRITICAL'],
|
||||
selectedLevels: [0, 1, 2, 3, 4], // 默认选中所有级别
|
||||
levelColors: {
|
||||
'DEBUG': 'grey',
|
||||
'INFO': 'blue-lighten-3',
|
||||
'WARNING': 'amber',
|
||||
'ERROR': 'red',
|
||||
'CRITICAL': 'purple'
|
||||
}
|
||||
}
|
||||
},
|
||||
props: {
|
||||
@@ -37,27 +57,82 @@ export default {
|
||||
watch: {
|
||||
logCache: {
|
||||
handler(val) {
|
||||
this.printLog(val[this.logCache.length - 1])
|
||||
const lastLog = val[this.logCache.length - 1];
|
||||
if (lastLog && this.isLevelSelected(lastLog.level)) {
|
||||
this.printLog(lastLog.data);
|
||||
}
|
||||
},
|
||||
deep: true
|
||||
},
|
||||
selectedLevels: {
|
||||
handler() {
|
||||
this.refreshDisplay();
|
||||
},
|
||||
deep: true
|
||||
}
|
||||
},
|
||||
mounted() {
|
||||
this.historyNum_ = parseInt(this.historyNum)
|
||||
let i = 0
|
||||
for (let log of this.logCache) {
|
||||
if (this.historyNum_ != -1 && i >= this.logCache.length - this.historyNum_) {
|
||||
this.printLog(log)
|
||||
++i
|
||||
} else if (this.historyNum_ == -1) {
|
||||
this.printLog(log)
|
||||
}
|
||||
if (this.logCache.length === 0) {
|
||||
this.delayInit()
|
||||
} else {
|
||||
this.init()
|
||||
}
|
||||
},
|
||||
methods: {
|
||||
getLevelColor(level) {
|
||||
return this.levelColors[level] || 'grey';
|
||||
},
|
||||
|
||||
isLevelSelected(level) {
|
||||
for (let i = 0; i < this.selectedLevels.length; ++i) {
|
||||
let level_ = this.logLevels[this.selectedLevels[i]]
|
||||
if (level_ === level) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
},
|
||||
|
||||
refreshDisplay() {
|
||||
// 清空现有的显示
|
||||
const termElement = document.getElementById('term');
|
||||
if (termElement) {
|
||||
termElement.innerHTML = '';
|
||||
}
|
||||
|
||||
// 重新显示符合筛选条件的日志
|
||||
this.init();
|
||||
},
|
||||
|
||||
delayInit() {
|
||||
if (this.logCache.length === 0) {
|
||||
setTimeout(() => {
|
||||
this.delayInit()
|
||||
}, 500)
|
||||
} else {
|
||||
this.init()
|
||||
}
|
||||
},
|
||||
|
||||
init() {
|
||||
this.historyNum_ = parseInt(this.historyNum)
|
||||
let i = 0
|
||||
for (let log of this.logCache) {
|
||||
if (this.isLevelSelected(log.level)) { // 只显示选中级别的日志
|
||||
if (this.historyNum_ != -1 && i >= this.logCache.length - this.historyNum_) {
|
||||
this.printLog(log.data)
|
||||
++i
|
||||
} else if (this.historyNum_ == -1) {
|
||||
this.printLog(log.data)
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
toggleAutoScroll() {
|
||||
this.autoScroll = !this.autoScroll;
|
||||
},
|
||||
|
||||
printLog(log) {
|
||||
// append 一个 span 标签到 term,block 的方式
|
||||
let ele = document.getElementById('term')
|
||||
@@ -70,14 +145,38 @@ export default {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
span.style = style + 'display: block; font-size: 12px; font-family: Consolas, monospace; white-space: pre-wrap;'
|
||||
span.classList.add('fade-in')
|
||||
span.innerText = log
|
||||
span.innerText = `${log}`;
|
||||
ele.appendChild(span)
|
||||
if (this.autoScroll) {
|
||||
if (this.autoScroll ) {
|
||||
ele.scrollTop = ele.scrollHeight
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
</script>
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.filter-controls {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 8px;
|
||||
margin-bottom: 8px;
|
||||
}
|
||||
|
||||
.fade-in {
|
||||
animation: fadeIn 0.3s;
|
||||
}
|
||||
|
||||
@keyframes fadeIn {
|
||||
from {
|
||||
opacity: 0;
|
||||
}
|
||||
|
||||
to {
|
||||
opacity: 1;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
@@ -3,12 +3,36 @@
|
||||
<v-list dense style="background-color: transparent;max-height: 300px; overflow-y: auto;">
|
||||
<v-list-item v-for="(item, index) in items" :key="index">
|
||||
<v-list-item-content style="display: flex; justify-content: space-between;">
|
||||
<v-list-item-title>
|
||||
<v-list-item-title v-if="editIndex !== index">
|
||||
<v-chip size="small" label color="primary">{{ item }}</v-chip>
|
||||
</v-list-item-title>
|
||||
<v-btn @click="removeItem(index)" variant="plain">
|
||||
<v-icon>mdi-close</v-icon>
|
||||
</v-btn>
|
||||
<v-text-field
|
||||
v-else
|
||||
v-model="editItem"
|
||||
dense
|
||||
hide-details
|
||||
variant="outlined"
|
||||
density="compact"
|
||||
@keyup.enter="saveEdit"
|
||||
@keyup.esc="cancelEdit"
|
||||
autofocus
|
||||
></v-text-field>
|
||||
<div v-if="editIndex !== index">
|
||||
<v-btn @click="startEdit(index, item)" variant="plain" class="edit-btn" icon size="small">
|
||||
<v-icon>mdi-pencil</v-icon>
|
||||
</v-btn>
|
||||
<v-btn @click="removeItem(index)" variant="plain" icon size="small">
|
||||
<v-icon>mdi-close</v-icon>
|
||||
</v-btn>
|
||||
</div>
|
||||
<div v-else>
|
||||
<v-btn @click="saveEdit" variant="plain" color="success" icon size="small">
|
||||
<v-icon>mdi-check</v-icon>
|
||||
</v-btn>
|
||||
<v-btn @click="cancelEdit" variant="plain" color="error" icon size="small">
|
||||
<v-icon>mdi-close</v-icon>
|
||||
</v-btn>
|
||||
</div>
|
||||
</v-list-item-content>
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
@@ -41,6 +65,8 @@ export default {
|
||||
return {
|
||||
newItem: '',
|
||||
items: this.value,
|
||||
editIndex: -1,
|
||||
editItem: '',
|
||||
};
|
||||
},
|
||||
watch: {
|
||||
@@ -58,6 +84,20 @@ export default {
|
||||
removeItem(index) {
|
||||
this.items.splice(index, 1);
|
||||
},
|
||||
startEdit(index, item) {
|
||||
this.editIndex = index;
|
||||
this.editItem = item;
|
||||
},
|
||||
saveEdit() {
|
||||
if (this.editItem.trim() !== '') {
|
||||
this.items[this.editIndex] = this.editItem.trim();
|
||||
this.cancelEdit();
|
||||
}
|
||||
},
|
||||
cancelEdit() {
|
||||
this.editIndex = -1;
|
||||
this.editItem = '';
|
||||
},
|
||||
},
|
||||
};
|
||||
</script>
|
||||
@@ -82,4 +122,8 @@ export default {
|
||||
.v-btn {
|
||||
margin-left: 8px;
|
||||
}
|
||||
|
||||
.edit-btn {
|
||||
margin-right: -8px;
|
||||
}
|
||||
</style>
|
||||
@@ -184,7 +184,7 @@ function updateDashboard() {
|
||||
checkUpdate();
|
||||
|
||||
const commonStore = useCommonStore();
|
||||
commonStore.createWebSocket();
|
||||
commonStore.createEventSource(); // log
|
||||
commonStore.getStartTime();
|
||||
|
||||
|
||||
|
||||
@@ -5,8 +5,10 @@ export const useCommonStore = defineStore({
|
||||
id: 'common',
|
||||
state: () => ({
|
||||
// @ts-ignore
|
||||
websocket: null,
|
||||
eventSource: null,
|
||||
log_cache: [],
|
||||
sse_connected: false,
|
||||
|
||||
log_cache_max_len: 1000,
|
||||
startTime: -1,
|
||||
|
||||
@@ -21,25 +23,92 @@ export const useCommonStore = defineStore({
|
||||
"dingtalk": "https://astrbot.app/deploy/platform/dingtalk.html",
|
||||
},
|
||||
|
||||
pluginMarketData: []
|
||||
|
||||
pluginMarketData: [],
|
||||
}),
|
||||
actions: {
|
||||
createWebSocket() {
|
||||
if (this.websocket) {
|
||||
createEventSource() {
|
||||
if (this.eventSource) {
|
||||
return
|
||||
}
|
||||
let protocol = window.location.protocol === 'https:' ? 'wss' : 'ws'
|
||||
let route = '/api/live-log'
|
||||
let port = window.location.port
|
||||
let url = `${protocol}://${window.location.hostname}:${port}${route}`
|
||||
console.log('websocket url:', url)
|
||||
this.websocket = new WebSocket(url)
|
||||
this.websocket.onmessage = (evt) => {
|
||||
this.log_cache.push(evt.data)
|
||||
if (this.log_cache.length > this.log_cache_max_len) {
|
||||
this.log_cache.shift()
|
||||
const controller = new AbortController();
|
||||
const { signal } = controller;
|
||||
const headers = {
|
||||
'Content-Type': 'multipart/form-data',
|
||||
'Authorization': 'Bearer ' + localStorage.getItem('token')
|
||||
};
|
||||
fetch('/api/live-log', {
|
||||
method: 'GET',
|
||||
headers,
|
||||
signal,
|
||||
cache: 'no-cache',
|
||||
}).then(response => {
|
||||
if (!response.ok) {
|
||||
throw new Error(`SSE connection failed: ${response.status}`);
|
||||
}
|
||||
console.log('SSE stream opened');
|
||||
this.sse_connected = true;
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
|
||||
const processStream = ({ done, value }) => {
|
||||
if (done) {
|
||||
console.log('SSE stream closed');
|
||||
setTimeout(() => {
|
||||
this.eventSource = null;
|
||||
this.createEventSource();
|
||||
}, 2000);
|
||||
return;
|
||||
}
|
||||
|
||||
const text = decoder.decode(value);
|
||||
const lines = text.split('\n\n');
|
||||
lines.forEach(line => {
|
||||
if (line.startsWith('data:')) {
|
||||
const data = line.substring(5).trim();
|
||||
// {"type":"log","data":"[2021-08-01 00:00:00] INFO: Hello, world!"}
|
||||
let data_json = {}
|
||||
try {
|
||||
data_json = JSON.parse(data);
|
||||
} catch (e) {
|
||||
console.error('Invalid JSON:', data);
|
||||
data_json = {
|
||||
type: 'log',
|
||||
data: data,
|
||||
level: 'INFO',
|
||||
time: new Date().toISOString(),
|
||||
}
|
||||
}
|
||||
if (data_json.type === 'log') {
|
||||
// let log = data_json.data
|
||||
this.log_cache.push(data_json);
|
||||
if (this.log_cache.length > this.log_cache_max_len) {
|
||||
this.log_cache.shift();
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
return reader.read().then(processStream);
|
||||
};
|
||||
|
||||
reader.read().then(processStream);
|
||||
}).catch(error => {
|
||||
console.error('SSE error:', error);
|
||||
// Attempt to reconnect after a delay
|
||||
this.log_cache.push('SSE Connection failed, retrying in 5 seconds...');
|
||||
setTimeout(() => {
|
||||
this.eventSource = null;
|
||||
this.createEventSource();
|
||||
}, 1000);
|
||||
});
|
||||
|
||||
// Store controller to allow closing the connection
|
||||
this.eventSource = controller;
|
||||
},
|
||||
closeEventSourcet() {
|
||||
if (this.eventSource) {
|
||||
this.eventSource.abort();
|
||||
this.eventSource = null;
|
||||
}
|
||||
},
|
||||
getLogCache() {
|
||||
@@ -50,7 +119,7 @@ export const useCommonStore = defineStore({
|
||||
return this.startTime
|
||||
}
|
||||
axios.get('/api/stat/start-time').then((res) => {
|
||||
this.startTime = res.data.data.start_time
|
||||
this.startTime = res.data.data.start_time
|
||||
})
|
||||
},
|
||||
getTutorialLink(platform) {
|
||||
|
||||
@@ -289,6 +289,16 @@ export default {
|
||||
message: `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
|
||||
}
|
||||
this.messages.push(bot_resp);
|
||||
} else if (chunk.startsWith('[RECORD]')) {
|
||||
let audio = chunk.replace('[RECORD]', '');
|
||||
let bot_resp = {
|
||||
type: 'bot',
|
||||
message: `<audio controls class="audio-player">
|
||||
<source src="/api/chat/get_file?filename=${audio}" type="audio/wav">
|
||||
您的浏览器不支持音频播放。
|
||||
</audio>`
|
||||
}
|
||||
this.messages.push(bot_resp);
|
||||
} else {
|
||||
let bot_resp = {
|
||||
type: 'bot',
|
||||
@@ -407,6 +417,13 @@ export default {
|
||||
let img = message[i].message.replace('[IMAGE]', '');
|
||||
message[i].message = `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
|
||||
}
|
||||
if (message[i].message.startsWith('[RECORD]')) {
|
||||
let audio = message[i].message.replace('[RECORD]', '');
|
||||
message[i].message = `<audio controls class="audio-player">
|
||||
<source src="/api/chat/get_file?filename=${audio}" type="audio/wav">
|
||||
您的浏览器不支持音频播放。
|
||||
</audio>`
|
||||
}
|
||||
if (message[i].image_url && message[i].image_url.length > 0) {
|
||||
for (let j = 0; j < message[i].image_url.length; j++) {
|
||||
message[i].image_url[j] = `/api/chat/get_file?filename=${message[i].image_url[j]}`;
|
||||
@@ -846,7 +863,6 @@ export default {
|
||||
}
|
||||
|
||||
.audio-player {
|
||||
width: 100%;
|
||||
height: 36px;
|
||||
border-radius: 18px;
|
||||
}
|
||||
|
||||
@@ -44,7 +44,7 @@ import axios from 'axios';
|
||||
</v-dialog>
|
||||
</div>
|
||||
</div>
|
||||
<ConsoleDisplayer ref="consoleDisplayer" style="height: calc(100vh - 160px); " />
|
||||
<ConsoleDisplayer ref="consoleDisplayer" style="height: calc(100vh - 220px); " />
|
||||
</div>
|
||||
</template>
|
||||
<script>
|
||||
|
||||
@@ -5,6 +5,9 @@ import AstrBotConfig from '@/components/shared/AstrBotConfig.vue';
|
||||
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
|
||||
import axios from 'axios';
|
||||
import { useCommonStore } from '@/stores/common';
|
||||
import { marked } from 'marked';
|
||||
import hljs from 'highlight.js';
|
||||
import 'highlight.js/styles/github.css';
|
||||
</script>
|
||||
|
||||
<template>
|
||||
@@ -60,14 +63,13 @@ import { useCommonStore } from '@/stores/common';
|
||||
</v-row>
|
||||
</div>
|
||||
|
||||
|
||||
|
||||
|
||||
<div v-if="isListView" class="mt-4">
|
||||
<h2>📦 全部插件</h2>
|
||||
<v-col cols="12" md="12" style="padding: 0px;">
|
||||
<v-data-table :headers="pluginMarketHeaders" :items="pluginMarketData" item-key="name"
|
||||
:loading="loading_" v-model:search="marketSearch"
|
||||
:filter-keys="filterKeys">
|
||||
:loading="loading_" v-model:search="marketSearch" :filter-keys="filterKeys">
|
||||
<template v-slot:item.name="{ item }">
|
||||
<div class="d-flex align-center">
|
||||
<img v-if="item.logo" :src="item.logo"
|
||||
@@ -83,20 +85,22 @@ import { useCommonStore } from '@/stores/common';
|
||||
</template>
|
||||
<template v-slot:item.author="{ item }">
|
||||
<span v-if="item?.social_link"><a :href="item?.social_link">{{ item.author
|
||||
}}</a></span>
|
||||
}}</a></span>
|
||||
<span v-else>{{ item.author }}</span>
|
||||
</template>
|
||||
<template v-slot:item.tags="{ item }">
|
||||
<span v-if="item.tags.length === 0">无</span>
|
||||
<v-chip v-for="tag in item.tags" :key="tag" color="primary" size="small">{{ tag
|
||||
}}</v-chip>
|
||||
}}</v-chip>
|
||||
</template>
|
||||
<template v-slot:item.actions="{ item }">
|
||||
<v-btn v-if="!item.installed" class="text-none mr-2" size="small" text="Read"
|
||||
<v-btn v-if="!item.installed" class="text-none mr-2" size="small"
|
||||
variant="flat" border
|
||||
@click="extension_url = item.repo; newExtension()">安装</v-btn>
|
||||
<v-btn v-else class="text-none mr-2" size="small" text="Read" variant="flat" border
|
||||
<v-btn v-else class="text-none mr-2" size="small" variant="flat" border
|
||||
disabled>已安装</v-btn>
|
||||
<v-btn class="text-none mr-2" size="small" variant="flat" border
|
||||
@click="open(item.repo)">查看帮助</v-btn>
|
||||
</template>
|
||||
</v-data-table>
|
||||
</v-col>
|
||||
@@ -175,6 +179,42 @@ import { useCommonStore } from '@/stores/common';
|
||||
</v-snackbar>
|
||||
|
||||
<WaitingForRestart ref="wfr"></WaitingForRestart>
|
||||
|
||||
<!-- README Dialog -->
|
||||
<v-dialog v-model="readmeDialog.show" width="800" persistent>
|
||||
<v-card>
|
||||
<v-card-title class="d-flex justify-space-between align-center">
|
||||
<span class="text-h5">插件说明文档</span>
|
||||
<v-btn icon @click="readmeDialog.show = false">
|
||||
<v-icon>mdi-close</v-icon>
|
||||
</v-btn>
|
||||
</v-card-title>
|
||||
<v-divider></v-divider>
|
||||
<v-card-text style="height: 70vh; overflow-y: auto;">
|
||||
<v-btn color="primary" prepend-icon="mdi-open-in-new" @click="openReadmeInNewTab()" class="mt-4">
|
||||
在GitHub中查看文档
|
||||
</v-btn>
|
||||
<div v-if="readmeDialog.content" class="markdown-body" v-html="renderMarkdown(readmeDialog.content)">
|
||||
</div>
|
||||
<div v-else-if="readmeDialog.error" class="d-flex flex-column align-center justify-center"
|
||||
style="height: 100%;">
|
||||
<v-icon size="64" color="error" class="mb-4">mdi-alert-circle-outline</v-icon>
|
||||
<p class="text-body-1 text-center mb-4">{{ readmeDialog.error }}</p>
|
||||
</div>
|
||||
<div v-else class="d-flex flex-column align-center justify-center" style="height: 100%;">
|
||||
<v-icon size="64" color="warning" class="mb-4">mdi-file-question-outline</v-icon>
|
||||
<p class="text-body-1 text-center mb-4">该插件未提供文档链接或GitHub仓库地址。<br>请查看插件市场或联系插件作者获取更多信息。</p>
|
||||
</div>
|
||||
</v-card-text>
|
||||
<v-divider></v-divider>
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="primary" variant="tonal" @click="readmeDialog.show = false">
|
||||
关闭
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
</template>
|
||||
|
||||
|
||||
@@ -209,6 +249,12 @@ export default {
|
||||
statusCode: 0, // 0: loading, 1: success, 2: error,
|
||||
result: ""
|
||||
},
|
||||
readmeDialog: {
|
||||
show: false,
|
||||
url: null,
|
||||
content: null,
|
||||
error: null
|
||||
},
|
||||
|
||||
announcement: "",
|
||||
isListView: true,
|
||||
@@ -234,8 +280,8 @@ export default {
|
||||
const search = this.marketSearch.toLowerCase();
|
||||
return this.pluginMarketData.filter(plugin =>
|
||||
this.filterKeys.some(key =>
|
||||
plugin[key]?.toLowerCase().includes(search)
|
||||
));
|
||||
plugin[key]?.toLowerCase().includes(search)
|
||||
));
|
||||
},
|
||||
pinnedPlugins() {
|
||||
return this.pluginMarketData.filter(plugin => plugin?.pinned);
|
||||
@@ -262,6 +308,12 @@ export default {
|
||||
});
|
||||
},
|
||||
methods: {
|
||||
open(link) {
|
||||
if (link) {
|
||||
window.open(link, '_blank');
|
||||
}
|
||||
},
|
||||
|
||||
jumpToPluginMarket() {
|
||||
window.open('https://soulter.github.io/AstrBot_Plugins_Collection/plugins.json', '_blank');
|
||||
},
|
||||
@@ -327,7 +379,50 @@ export default {
|
||||
});
|
||||
},
|
||||
|
||||
newExtension() {
|
||||
async getReadmeUrl(repoUrl) {
|
||||
// 去掉 repoUrl 末尾的斜杠
|
||||
repoUrl = repoUrl.replace(/\/+$/, '');
|
||||
|
||||
const match = repoUrl.match(/github\.com\/([^/]+)\/([^/]+)/);
|
||||
if (!match) {
|
||||
throw new Error("无效的 GitHub 仓库地址");
|
||||
}
|
||||
|
||||
const owner = match[1];
|
||||
const repo = match[2];
|
||||
|
||||
const apiUrl = `https://api.github.com/repos/${owner}/${repo}`;
|
||||
|
||||
try {
|
||||
const res = await fetch(apiUrl);
|
||||
const data = await res.json();
|
||||
|
||||
const branch = data?.default_branch || 'master';
|
||||
return `${repoUrl}/blob/${branch}/README.md`;
|
||||
} catch (error) {
|
||||
console.error("获取默认分支失败,使用 master 作为默认:", error);
|
||||
return `${repoUrl}/blob/master/README.md`;
|
||||
}
|
||||
},
|
||||
|
||||
async showReadmeDialog(res) {
|
||||
this.readmeDialog.content = null;
|
||||
this.readmeDialog.error = null;
|
||||
if (res?.data?.data?.repo) {
|
||||
this.readmeDialog.url = await this.getReadmeUrl(res.data.data.repo);
|
||||
if (res.data.data.readme) {
|
||||
this.readmeDialog.content = res.data.data.readme;
|
||||
} else {
|
||||
this.readmeDialog.error = "插件未提供README文档";
|
||||
}
|
||||
} else {
|
||||
this.readmeDialog.url = null;
|
||||
this.readmeDialog.error = "插件没有仓库信息或README文档";
|
||||
}
|
||||
this.readmeDialog.show = true;
|
||||
},
|
||||
|
||||
async newExtension() {
|
||||
if (this.extension_url === "" && this.upload_file === null) {
|
||||
this.toast("请填写插件链接或上传插件文件", "error");
|
||||
return;
|
||||
@@ -347,7 +442,7 @@ export default {
|
||||
headers: {
|
||||
'Content-Type': 'multipart/form-data'
|
||||
}
|
||||
}).then((res) => {
|
||||
}).then(async (res) => {
|
||||
this.loading_ = false;
|
||||
if (res.data.status === "error") {
|
||||
this.onLoadingDialogResult(2, res.data.message, -1);
|
||||
@@ -358,7 +453,8 @@ export default {
|
||||
this.onLoadingDialogResult(1, res.data.message);
|
||||
this.dialog = false;
|
||||
this.getExtensions();
|
||||
// this.$refs.wfr.check();
|
||||
|
||||
await this.showReadmeDialog(res);
|
||||
}).catch((err) => {
|
||||
this.loading_ = false;
|
||||
this.onLoadingDialogResult(2, err, -1);
|
||||
@@ -370,7 +466,7 @@ export default {
|
||||
{
|
||||
url: this.extension_url,
|
||||
proxy: localStorage.getItem('selectedGitHubProxy') || ""
|
||||
}).then((res) => {
|
||||
}).then(async (res) => {
|
||||
this.loading_ = false;
|
||||
this.toast(res.data.message, res.data.status === "ok" ? "success" : "error");
|
||||
if (res.data.status === "error") {
|
||||
@@ -382,7 +478,7 @@ export default {
|
||||
this.onLoadingDialogResult(1, res.data.message);
|
||||
this.dialog = false;
|
||||
this.getExtensions();
|
||||
// this.$refs.wfr.check();
|
||||
await this.showReadmeDialog(res);
|
||||
}).catch((err) => {
|
||||
this.loading_ = false;
|
||||
this.toast("安装插件失败: " + err, "error");
|
||||
@@ -412,8 +508,157 @@ export default {
|
||||
}
|
||||
}
|
||||
this.pluginMarketData = notInstalled.concat(installed);
|
||||
}
|
||||
},
|
||||
openReadmeInNewTab() {
|
||||
if (this.readmeDialog.url) {
|
||||
window.open(this.readmeDialog.url, '_blank');
|
||||
}
|
||||
},
|
||||
renderMarkdown(content) {
|
||||
if (!content) return '';
|
||||
// Configure marked with highlight.js for syntax highlighting
|
||||
marked.setOptions({
|
||||
highlight: function (code, lang) {
|
||||
if (lang && hljs.getLanguage(lang)) {
|
||||
try {
|
||||
return hljs.highlight(code, { language: lang }).value;
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
return hljs.highlightAuto(code).value;
|
||||
},
|
||||
gfm: true, // GitHub Flavored Markdown
|
||||
breaks: true, // Convert \n to <br>
|
||||
headerIds: true, // Add id attributes to headers
|
||||
mangle: false // Don't mangle email addresses
|
||||
});
|
||||
return marked(content);
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
</script>
|
||||
</script>
|
||||
|
||||
<style>
|
||||
.markdown-body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Helvetica, Arial, sans-serif;
|
||||
line-height: 1.6;
|
||||
padding: 8px 0;
|
||||
color: #24292e;
|
||||
}
|
||||
|
||||
.markdown-body h1,
|
||||
.markdown-body h2,
|
||||
.markdown-body h3,
|
||||
.markdown-body h4,
|
||||
.markdown-body h5,
|
||||
.markdown-body h6 {
|
||||
margin-top: 24px;
|
||||
margin-bottom: 16px;
|
||||
font-weight: 600;
|
||||
line-height: 1.25;
|
||||
}
|
||||
|
||||
.markdown-body h1 {
|
||||
font-size: 2em;
|
||||
border-bottom: 1px solid #eaecef;
|
||||
padding-bottom: 0.3em;
|
||||
}
|
||||
|
||||
.markdown-body h2 {
|
||||
font-size: 1.5em;
|
||||
border-bottom: 1px solid #eaecef;
|
||||
padding-bottom: 0.3em;
|
||||
}
|
||||
|
||||
.markdown-body p {
|
||||
margin-top: 0;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.markdown-body code {
|
||||
padding: 0.2em 0.4em;
|
||||
margin: 0;
|
||||
background-color: rgba(27, 31, 35, 0.05);
|
||||
border-radius: 3px;
|
||||
font-family: "SFMono-Regular", Consolas, "Liberation Mono", Menlo, monospace;
|
||||
font-size: 85%;
|
||||
}
|
||||
|
||||
.markdown-body pre {
|
||||
padding: 16px;
|
||||
overflow: auto;
|
||||
font-size: 85%;
|
||||
line-height: 1.45;
|
||||
background-color: #f6f8fa;
|
||||
border-radius: 3px;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.markdown-body pre code {
|
||||
background-color: transparent;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.markdown-body ul,
|
||||
.markdown-body ol {
|
||||
padding-left: 2em;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.markdown-body img {
|
||||
max-width: 100%;
|
||||
margin: 8px 0;
|
||||
box-sizing: border-box;
|
||||
background-color: #fff;
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.markdown-body blockquote {
|
||||
padding: 0 1em;
|
||||
color: #6a737d;
|
||||
border-left: 0.25em solid #dfe2e5;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.markdown-body a {
|
||||
color: #0366d6;
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
.markdown-body a:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
.markdown-body table {
|
||||
border-spacing: 0;
|
||||
border-collapse: collapse;
|
||||
width: 100%;
|
||||
overflow: auto;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.markdown-body table th,
|
||||
.markdown-body table td {
|
||||
padding: 6px 13px;
|
||||
border: 1px solid #dfe2e5;
|
||||
}
|
||||
|
||||
.markdown-body table tr {
|
||||
background-color: #fff;
|
||||
border-top: 1px solid #c6cbd1;
|
||||
}
|
||||
|
||||
.markdown-body table tr:nth-child(2n) {
|
||||
background-color: #f6f8fa;
|
||||
}
|
||||
|
||||
.markdown-body hr {
|
||||
height: 0.25em;
|
||||
padding: 0;
|
||||
margin: 24px 0;
|
||||
background-color: #e1e4e8;
|
||||
border: 0;
|
||||
}
|
||||
</style>
|
||||
@@ -79,5 +79,5 @@ if __name__ == "__main__":
|
||||
# print logo
|
||||
logger.info(logo_tmpl)
|
||||
|
||||
dashboard_lifecycle = InitialLoader(db, log_broker)
|
||||
asyncio.run(dashboard_lifecycle.start())
|
||||
core_lifecycle = InitialLoader(db, log_broker)
|
||||
asyncio.run(core_lifecycle.start())
|
||||
|
||||
+53
-11
@@ -2,6 +2,8 @@ import aiohttp
|
||||
import datetime
|
||||
import builtins
|
||||
import traceback
|
||||
import re
|
||||
import zoneinfo
|
||||
import astrbot.api.star as star
|
||||
import astrbot.api.event.filter as filter
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
@@ -21,7 +23,6 @@ from astrbot.core.config.default import VERSION
|
||||
from .long_term_memory import LongTermMemory
|
||||
from astrbot.core import logger
|
||||
from astrbot.api.message_components import Plain, Image, Reply
|
||||
|
||||
from typing import Union
|
||||
|
||||
|
||||
@@ -38,7 +39,12 @@ class Main(star.Star):
|
||||
self.prompt_prefix = cfg["provider_settings"]["prompt_prefix"]
|
||||
self.identifier = cfg["provider_settings"]["identifier"]
|
||||
self.enable_datetime = cfg["provider_settings"]["datetime_system_prompt"]
|
||||
|
||||
self.timezone = cfg.get("timezone")
|
||||
if not self.timezone:
|
||||
# 系统默认时区
|
||||
self.timezone = None
|
||||
else:
|
||||
logger.info(f"Timezone set to: {self.timezone}")
|
||||
self.ltm = None
|
||||
if (
|
||||
self.context.get_config()["provider_ltm_settings"]["group_icl_enable"]
|
||||
@@ -87,6 +93,7 @@ class Main(star.Star):
|
||||
/alter_cmd: 设置指令权限(op)
|
||||
|
||||
[大模型]
|
||||
/llm: 开启/关闭 LLM
|
||||
/provider: 大模型提供商
|
||||
/model: 模型列表
|
||||
/ls: 对话列表
|
||||
@@ -95,7 +102,7 @@ class Main(star.Star):
|
||||
/switch 序号: 切换对话
|
||||
/rename 新名字: 重命名当前对话
|
||||
/del: 删除当前会话对话(op)
|
||||
/reset: 重置 LLM 会话(op)
|
||||
/reset: 重置 LLM 会话
|
||||
/history: 当前对话的对话记录
|
||||
/persona: 人格情景(op)
|
||||
/tool ls: 函数工具
|
||||
@@ -105,6 +112,20 @@ class Main(star.Star):
|
||||
|
||||
event.set_result(MessageEventResult().message(msg).use_t2i(False))
|
||||
|
||||
@filter.command("llm")
|
||||
async def llm(self, event: AstrMessageEvent):
|
||||
"""开启/关闭 LLM"""
|
||||
cfg = self.context.get_config()
|
||||
enable = cfg["provider_settings"]["enable"]
|
||||
if enable:
|
||||
cfg["provider_settings"]["enable"] = False
|
||||
status = "关闭"
|
||||
else:
|
||||
cfg["provider_settings"]["enable"] = True
|
||||
status = "开启"
|
||||
cfg.save_config()
|
||||
yield event.plain_result(f"{status} LLM 聊天功能。")
|
||||
|
||||
@filter.command_group("tool")
|
||||
def tool(self):
|
||||
pass
|
||||
@@ -520,15 +541,18 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
MessageEventResult().message("未找到任何 LLM 提供商。请先配置。")
|
||||
)
|
||||
return
|
||||
# 定义正则表达式匹配 API 密钥
|
||||
api_key_pattern = re.compile(r"key=[^&'\" ]+")
|
||||
|
||||
if idx_or_name is None:
|
||||
models = []
|
||||
try:
|
||||
models = await self.context.get_using_provider().get_models()
|
||||
except BaseException as e:
|
||||
err_msg = api_key_pattern.sub("key=***", str(e))
|
||||
message.set_result(
|
||||
MessageEventResult()
|
||||
.message("获取模型列表失败: " + str(e))
|
||||
.message("获取模型列表失败: " + err_msg)
|
||||
.use_t2i(False)
|
||||
)
|
||||
return
|
||||
@@ -754,7 +778,7 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
)
|
||||
else:
|
||||
message.set_result(
|
||||
MessageEventResult().message("请输入群聊 ID。/newgroup 群聊ID。")
|
||||
MessageEventResult().message("请输入群聊 ID。/groupnew 群聊ID。")
|
||||
)
|
||||
|
||||
@filter.command("switch")
|
||||
@@ -950,7 +974,8 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
if len(l) == 1:
|
||||
message.set_result(
|
||||
MessageEventResult()
|
||||
.message(f"""[Persona]
|
||||
.message(
|
||||
f"""[Persona]
|
||||
|
||||
- 人格情景列表: `/persona list`
|
||||
- 设置人格情景: `/persona 人格`
|
||||
@@ -961,7 +986,8 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
当前对话 {curr_cid_title} 的人格情景: {curr_persona_name}
|
||||
|
||||
配置人格情景请前往管理面板-配置页
|
||||
""")
|
||||
"""
|
||||
)
|
||||
.use_t2i(False)
|
||||
)
|
||||
elif l[1] == "list":
|
||||
@@ -999,6 +1025,13 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
message.set_result(MessageEventResult().message("取消人格成功。"))
|
||||
else:
|
||||
ps = "".join(l[1:]).strip()
|
||||
if not cid:
|
||||
message.set_result(
|
||||
MessageEventResult().message(
|
||||
"当前没有对话,请先开始对话或使用 /new 创建一个对话。"
|
||||
)
|
||||
)
|
||||
return
|
||||
if persona := next(
|
||||
builtins.filter(
|
||||
lambda persona: persona["name"] == ps,
|
||||
@@ -1164,11 +1197,20 @@ UID: {user_id} 此 ID 可用于设置管理员。
|
||||
user_info = f"\n[User ID: {user_id}, Nickname: {user_nickname}]\n"
|
||||
req.prompt = user_info + req.prompt
|
||||
|
||||
# 启用附加时间戳
|
||||
if self.enable_datetime:
|
||||
# Including timezone
|
||||
current_time = (
|
||||
datetime.datetime.now().astimezone().strftime("%Y-%m-%d %H:%M (%Z)")
|
||||
)
|
||||
current_time = None
|
||||
if self.timezone:
|
||||
# 启用时区
|
||||
try:
|
||||
now = datetime.datetime.now(zoneinfo.ZoneInfo(self.timezone))
|
||||
current_time = now.strftime("%Y-%m-%d %H:%M (%Z)")
|
||||
except Exception as e:
|
||||
logger.error(f"时区设置错误: {e}, 使用本地时区")
|
||||
if not current_time:
|
||||
current_time = (
|
||||
datetime.datetime.now().astimezone().strftime("%Y-%m-%d %H:%M (%Z)")
|
||||
)
|
||||
req.system_prompt += f"\nCurrent datetime: {current_time}\n"
|
||||
|
||||
if req.conversation:
|
||||
|
||||
@@ -2,6 +2,7 @@ import os
|
||||
import json
|
||||
import datetime
|
||||
import uuid
|
||||
import zoneinfo
|
||||
import astrbot.api.star as star
|
||||
from astrbot.api.event import filter
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
@@ -17,7 +18,15 @@ class Main(star.Star):
|
||||
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
self.context = context
|
||||
self.scheduler = AsyncIOScheduler(timezone="Asia/Shanghai")
|
||||
self.timezone = self.context.get_config().get("timezone")
|
||||
if not self.timezone:
|
||||
self.timezone = None
|
||||
try:
|
||||
self.timezone = zoneinfo.ZoneInfo(self.timezone) if self.timezone else None
|
||||
except Exception as e:
|
||||
logger.error(f"时区设置错误: {e}, 使用本地时区")
|
||||
self.timezone = None
|
||||
self.scheduler = AsyncIOScheduler(timezone=self.timezone)
|
||||
|
||||
# set and load config
|
||||
if not os.path.exists("data/astrbot-reminder.json"):
|
||||
@@ -65,10 +74,10 @@ class Main(star.Star):
|
||||
def check_is_outdated(self, reminder: dict):
|
||||
"""Check if the reminder is outdated."""
|
||||
if "datetime" in reminder:
|
||||
return (
|
||||
datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M")
|
||||
< datetime.datetime.now()
|
||||
)
|
||||
reminder_time = datetime.datetime.strptime(
|
||||
reminder["datetime"], "%Y-%m-%d %H:%M"
|
||||
).replace(tzinfo=self.timezone)
|
||||
return reminder_time < datetime.datetime.now(self.timezone)
|
||||
return False
|
||||
|
||||
async def _save_data(self):
|
||||
@@ -171,12 +180,15 @@ class Main(star.Star):
|
||||
reminders = self.reminder_data.get(unified_msg_origin, [])
|
||||
if not reminders:
|
||||
return []
|
||||
now = datetime.datetime.now()
|
||||
now = datetime.datetime.now(self.timezone)
|
||||
upcoming_reminders = [
|
||||
reminder
|
||||
for reminder in reminders
|
||||
if "datetime" not in reminder
|
||||
or datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M") >= now
|
||||
or datetime.datetime.strptime(
|
||||
reminder["datetime"], "%Y-%m-%d %H:%M"
|
||||
).replace(tzinfo=self.timezone)
|
||||
>= now
|
||||
]
|
||||
return upcoming_reminders
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import astrbot.api.message_components as Comp
|
||||
import copy
|
||||
import json
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, filter
|
||||
from astrbot.api.star import Context, Star, register
|
||||
@@ -64,17 +63,11 @@ class Waiter(Star):
|
||||
event.unified_msg_origin
|
||||
)
|
||||
conversation = None
|
||||
context = []
|
||||
|
||||
if curr_cid:
|
||||
conversation = await self.context.conversation_manager.get_conversation(
|
||||
event.unified_msg_origin, curr_cid
|
||||
)
|
||||
context = (
|
||||
json.loads(conversation.history)
|
||||
if conversation.history
|
||||
else []
|
||||
)
|
||||
else:
|
||||
# 创建新对话
|
||||
curr_cid = await self.context.conversation_manager.new_conversation(
|
||||
@@ -83,10 +76,10 @@ class Waiter(Star):
|
||||
|
||||
# 使用 LLM 生成回复
|
||||
yield event.request_llm(
|
||||
prompt="用户只是@我或唤醒我,请友好地询问用户想要聊些什么或者需要什么帮助,回复要符合人设,不要太过机械化。",
|
||||
prompt="注意,你正在社交媒体上中与用户进行聊天,用户只是通过@来唤醒你,但并未在这条消息中输入内容,他可能会在接下来一条发送他想发送的内容。请你友好地询问用户想要聊些什么或者需要什么帮助,回复要符合人设,不要太过机械化。注意,你仅需要输出要回复用户的内容,不要输出其他任何东西",
|
||||
func_tool_manager=func_tools_mgr,
|
||||
session_id=curr_cid,
|
||||
contexts=context,
|
||||
contexts=[],
|
||||
system_prompt="",
|
||||
conversation=conversation,
|
||||
)
|
||||
@@ -113,16 +106,7 @@ class Waiter(Star):
|
||||
try:
|
||||
await empty_mention_waiter(event)
|
||||
except TimeoutError as _:
|
||||
try:
|
||||
# 超时时也尝试使用 LLM 生成回复
|
||||
yield event.request_llm(
|
||||
prompt="用户在提问后超时未回复,请生成一个温馨友好的提醒,告诉用户如果需要帮助可以再次提问,回答要符合人设。",
|
||||
func_tool_manager=self.context.get_llm_tool_manager(),
|
||||
system_prompt="",
|
||||
)
|
||||
except Exception:
|
||||
# LLM 回复失败,使用原始预设回复
|
||||
yield event.plain_result("如果需要帮助,请再次 @ 我哦~")
|
||||
pass
|
||||
except Exception as e:
|
||||
yield event.plain_result("发生错误,请联系管理员: " + str(e))
|
||||
finally:
|
||||
|
||||
@@ -35,6 +35,7 @@ dependencies = [
|
||||
"quart>=0.20.0",
|
||||
"readability-lxml>=0.8.1",
|
||||
"silk-python>=0.2.6",
|
||||
"telegramify-markdown>=0.5.0",
|
||||
"wechatpy>=1.8.18",
|
||||
]
|
||||
|
||||
|
||||
+2
-1
@@ -28,4 +28,5 @@ dingtalk-stream
|
||||
defusedxml
|
||||
mcp
|
||||
certifi
|
||||
pip
|
||||
pip
|
||||
telegramify-markdown
|
||||
@@ -225,6 +225,7 @@ dependencies = [
|
||||
{ name = "quart" },
|
||||
{ name = "readability-lxml" },
|
||||
{ name = "silk-python" },
|
||||
{ name = "telegramify-markdown" },
|
||||
{ name = "wechatpy" },
|
||||
]
|
||||
|
||||
@@ -260,6 +261,7 @@ requires-dist = [
|
||||
{ name = "quart", specifier = ">=0.20.0" },
|
||||
{ name = "readability-lxml", specifier = ">=0.8.1" },
|
||||
{ name = "silk-python", specifier = ">=0.2.6" },
|
||||
{ name = "telegramify-markdown", specifier = ">=0.5.0" },
|
||||
{ name = "wechatpy", specifier = ">=1.8.18" },
|
||||
]
|
||||
|
||||
@@ -1059,6 +1061,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/c1/d1/3ff566ecf322077d861f1a68a1ff025cad337417bd66ad22a7c6f7dfcfaf/mcp-1.5.0-py3-none-any.whl", hash = "sha256:51c3f35ce93cb702f7513c12406bbea9665ef75a08db909200b07da9db641527", size = 73734 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mistletoe"
|
||||
version = "1.4.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/11/96/ea46a376a7c4cd56955ecdfff0ea68de43996a4e6d1aee4599729453bd11/mistletoe-1.4.0.tar.gz", hash = "sha256:1630f906e5e4bbe66fdeb4d29d277e2ea515d642bb18a9b49b136361a9818c9d", size = 107203 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/0f/b5e545f0c7962be90366af3418989b12cf441d9da1e5d89d88f2f3e5cf8f/mistletoe-1.4.0-py3-none-any.whl", hash = "sha256:44a477803861de1237ba22e375c6b617690a31d2902b47279d1f8f7ed498a794", size = 51304 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "multidict"
|
||||
version = "6.2.0"
|
||||
@@ -1792,6 +1803,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/d1/b1/74babcc824a57904e919f3af16d86c08b524c0691504baf038ef2d7f655c/taskgroup-0.2.2-py2.py3-none-any.whl", hash = "sha256:e2c53121609f4ae97303e9ea1524304b4de6faf9eb2c9280c7f87976479a52fb", size = 14237 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "telegramify-markdown"
|
||||
version = "0.5.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "mistletoe" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f6/b7/a9b56f856f87e1b4d743932353f71811844a413561be180a22d667ef6f5a/telegramify_markdown-0.5.0.tar.gz", hash = "sha256:70e6eff7e341e6e9c8818fa1ec53a4e25e4f5e3ef50856d7772760fc6b7a4066", size = 36017 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/df/87/4a39dde5b4aea91b874cbbe057edea19334e62651ce0f1f74f5a1f721439/telegramify_markdown-0.5.0-py3-none-any.whl", hash = "sha256:6f66b7029c0eba268fed5f9daf9216f56c588c6202dd591ff572f7df0d318f2f", size = 32389 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tomli"
|
||||
version = "2.2.1"
|
||||
|
||||
Reference in New Issue
Block a user