Compare commits

...

70 Commits

Author SHA1 Message Date
Soulter 94618e8feb feat: 添加 aiodocker 依赖 2025-01-11 22:02:15 +08:00
Soulter 55de7d4494 🎉 Bump to v3.4.5 2025-01-11 21:40:48 +08:00
Soulter 7ed639f741 🎉 bump to v3.4.5 2025-01-11 21:06:06 +08:00
Soulter 41f2870c29 Merge pull request #236 from Soulter/feat-stt
支持 Speech To Text,并适配腾讯修改过的 Silk 语音格式
2025-01-11 21:00:04 +08:00
Soulter ba198490fa feat: 支持自部署 Whisper 模型 2025-01-11 20:31:21 +08:00
Soulter 0f9ab082ab perf: 优化webchat,没有结果返回时的反馈 2025-01-11 19:45:42 +08:00
Soulter 97b58965f2 feat: webchat可显示Provider状态 2025-01-11 19:31:56 +08:00
Soulter f2566c68e3 feat: 按 K 语音 2025-01-11 19:07:26 +08:00
Soulter a456bf5449 fix: 初始化reminder时的一些问题 2025-01-11 18:55:18 +08:00
Soulter a09998f910 feat: webchat 支持语音输入 2025-01-11 18:54:40 +08:00
Soulter be662b913c feat: 支持 Whisper STT,并适配 Tencent 语音格式 2025-01-11 17:19:28 +08:00
Soulter e7ddc8448d perf: 代码执行器在成功执行后清空文件buffer 2025-01-11 11:31:56 +08:00
Soulter 29374f8d8a fix: 修复 /dashbord_update 指令 2025-01-11 00:25:02 +08:00
Soulter 359b971103 Merge pull request #235 from Soulter/feat-webchat
WebChat 支持
2025-01-11 00:17:18 +08:00
Soulter fbdb1ae208 chore: bump to v3.4.4 2025-01-11 00:14:08 +08:00
Soulter 22c13c1eff perf: webchat支持传图 2025-01-11 00:06:19 +08:00
Soulter 5fc63aeaf1 perf: ui 2025-01-10 22:45:14 +08:00
Soulter d4f32673ab fix: 修复持久化问题 2025-01-10 22:08:43 +08:00
Soulter 480dffb51b feat: 初步实现 webchat 页面 2025-01-10 21:48:15 +08:00
Soulter 966df00124 feat: 支持从管理面板(控制台页)手动安装 pip 库 2025-01-10 15:35:57 +08:00
Soulter 3e2b4bc727 feat: 支持动态设置会话变量以适用 Dify 输入变量 2025-01-10 12:32:20 +08:00
Soulter 5929a8d42b Update README.md 2025-01-09 23:11:11 +08:00
Soulter f8ab40eb39 chore: 上传管理面板package.json 2025-01-09 22:25:46 +08:00
Soulter 55e9233b93 docs: v3.4.3 changelog 2025-01-09 22:19:11 +08:00
Soulter b7277b51fd feat: 管理面板支持显示不在metadata中的配置 2025-01-09 22:03:53 +08:00
Soulter 1fa9111b2b perf: 进一步防止llm递归调用 2025-01-09 22:03:22 +08:00
Soulter 90a9e496d9 feat: 适配器类插件支持设置默认配置模板 2025-01-09 19:45:18 +08:00
Soulter 2a7dce1eb0 chore: clean code 2025-01-09 16:34:39 +08:00
Soulter 0c0841cc03 fix: websearch 在 cmd_config 中失效的问题 2025-01-09 16:33:58 +08:00
Soulter 4c9fe016bf fix: test_pipeline 2025-01-09 16:00:43 +08:00
Soulter acc90f140c chore: bump dashboard_release_url 2025-01-09 15:50:24 +08:00
Soulter 68a7bc3930 Merge pull request #232 from Soulter/feat-python-interpreter
初步实现代码执行器
2025-01-09 15:43:40 +08:00
Soulter 12ea64be0e fix: dashboard command bug 2025-01-09 15:42:04 +08:00
Soulter 7f30a673f7 fix: 修复 qq_official 无法发图 2025-01-09 15:20:54 +08:00
Soulter 897e100c32 Merge pull request #234 from Soulter/233-gemini-native-support
支持通过 Google GenAI 访问 Gemini 模型
2025-01-09 14:23:44 +08:00
Soulter 0d4ad5cb31 fix: 修复 APScheduler 任务错过后不执行的问题 2025-01-09 14:23:07 +08:00
Soulter b124bd0d0e feat: 支持通过 Google GenAI 访问 Gemini 模型 2025-01-09 14:05:48 +08:00
Soulter 6bc2f84602 Update README.md
qingcloud 在新网的账户余额不足导致原域名无法续费
2025-01-09 10:35:02 +08:00
Soulter d787a28c40 feat: 支持使用 /dashboard update 更新管理面板 2025-01-09 00:59:28 +08:00
Soulter 6b078a5731 cd: build dashboard files automatically 2025-01-09 00:57:48 +08:00
Soulter 17dddbfe21 chore: 禁用插件 2025-01-08 23:34:54 +08:00
Soulter 3ff3c9e144 perf: 检测到docker不可用时自动禁用本插件 2025-01-08 23:32:49 +08:00
Soulter f5a37d82cc Merge branch 'master' into feat-python-interpreter 2025-01-08 23:13:52 +08:00
Soulter d3d428dc9d fix: 管理面板支持禁用/启用插件 2025-01-08 23:04:03 +08:00
Soulter 8dc8c5b5dc feat: 支持对插件禁用/启用 2025-01-08 22:28:20 +08:00
Soulter e6b06f914b perf: provider 偏好项记忆 2025-01-08 20:46:34 +08:00
Soulter 4dc502a8b6 fix: 修复事件监听器会让wakestage失效的问题 2025-01-08 20:24:01 +08:00
Soulter b1d1a13d5f perf: 支持图片输入 2025-01-08 19:56:03 +08:00
Soulter 75cc4cac5a perf: 代码执行器添加部分控制指令,添加更多可用库 2025-01-08 13:26:16 +08:00
Soulter 1b7e4fbbdc perf: 退出时关闭 aiohttp client session 2025-01-08 12:43:34 +08:00
Soulter 9789e2f6c1 perf: 代码执行器请求llm不持久化历史记录 2025-01-08 02:12:35 +08:00
Soulter b8fb0bee24 feat: 初步实现代码执行器 #210 2025-01-08 02:10:27 +08:00
Soulter 419f77e245 Update README.md 2025-01-07 20:56:25 +08:00
Soulter 59b1c3473b Merge pull request #230 from Soulter/feat-dify
接入 Dify
2025-01-07 20:14:33 +08:00
Soulter 6db58ca375 perf: 优化在prompt为空的情况下不请求provider 2025-01-07 20:01:47 +08:00
Soulter 4832b342b0 Merge branch 'master' into feat-dify 2025-01-07 19:59:54 +08:00
Soulter 6cec542402 feat: 初步接入 Dify 2025-01-07 19:56:18 +08:00
Soulter 9644791783 feat: kdb 2024-12-30 18:06:09 +08:00
Soulter 5031c307d1 update: readme 2024-12-26 23:39:29 +08:00
Soulter aa49539e3e chore: fix test 2024-12-26 23:33:40 +08:00
Soulter 7b4118493b chore: fix test 2024-12-26 23:15:10 +08:00
Soulter d1cc9ba4ce chore: update test workflow 2024-12-26 23:09:11 +08:00
Soulter e0e92139d7 fix: test workflow 2024-12-26 23:07:50 +08:00
Soulter 62039392bb chore: fix test workflow 2024-12-26 23:06:30 +08:00
Soulter b72c69892e test: dashboard test 2024-12-26 22:59:17 +08:00
Soulter e6205e9aad ci: update workflow 2024-12-25 17:18:29 +08:00
Soulter b8a6fb1720 chore: update tests 2024-12-25 12:50:29 +08:00
Soulter 7c06d82f27 perf: plugin manager 重复 reload 释放资源 2024-12-25 12:50:29 +08:00
Soulter d92cb0f500 perf: 当没有provider时直接返回 2024-12-25 12:50:29 +08:00
Soulter 7fa72f2fe9 perf: adapt glm-4v-flash 2024-12-24 14:08:20 +08:00
95 changed files with 14175 additions and 231 deletions
+12 -1
View File
@@ -2,6 +2,7 @@ on:
push:
tags:
- 'v*'
workflow_dispatch:
name: Auto Release
@@ -14,6 +15,15 @@ jobs:
- name: Checkout repository
uses: actions/checkout@v4
- name: Dashboard Build
run: |
cd dashboard
npm install
npm run build
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
echo ${{ github.ref_name }} > dist/assets/version
zip -r dist.zip dist
- name: Fetch Changelog
run: |
echo "changelog=changelogs/${{github.ref_name}}.md" >> "$GITHUB_ENV"
@@ -21,4 +31,5 @@ jobs:
- name: Create Release
uses: ncipollo/release-action@v1
with:
bodyFile: ${{ env.changelog }}
bodyFile: ${{ env.changelog }}
artifacts: "dashboard/dist.zip"
+15 -9
View File
@@ -1,7 +1,14 @@
name: Run tests and upload coverage
on:
push
push:
branches:
- master
paths-ignore:
- 'README.md'
- 'changelogs/**'
- 'dashboard/**'
workflow_dispatch:
jobs:
test:
@@ -21,17 +28,16 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install pytest pytest-cov pytest-asyncio
mkdir data
mkdir data/plugins
mkdir data/config
mkdir temp
- name: Run tests
run: |
export LLM_MODEL=${{ secrets.LLM_MODEL }}
export OPENAI_API_BASE=${{ secrets.OPENAI_API_BASE }}
export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}
PYTHONPATH=./ pytest --cov=. tests/ -v
mkdir data
mkdir data/plugins
mkdir data/config
mkdir data/temp
export TESTING=true
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
PYTHONPATH=./ pytest --cov=. tests/ -v -o log_cli=true -o log_level=DEBUG
- name: Upload results to Codecov
uses: codecov/codecov-action@v4
+2 -1
View File
@@ -20,4 +20,5 @@ chroma
node_modules/
.DS_Store
package-lock.json
package.json
package.json
venv/*
+61 -10
View File
@@ -14,9 +14,10 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-322154837-purple">
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
[![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot)
</a>
<a href="https://astrbot.soulter.top/">查看文档</a>
<a href="https://astrbot.lwl.lol/">查看文档</a>
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
</div>
@@ -24,7 +25,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
## ✨ 多消息平台部署
1. QQ 群、QQ 频道、微信、Telegram。
1. QQ 群、QQ 频道、微信个人号、Telegram。
2. 支持文本转图片,Markdown 渲染。
## ✨ 多 LLM 配置
@@ -33,7 +34,8 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
2. 支持 OneAPI 等分发平台。
3. 支持 LLMTuner 载入微调模型。
4. 支持 Ollama 载入自部署模型。
4. 支持网页搜索(Web Search)。
4. 支持网页搜索(Web Search、自然语言待办提醒
5. 支持 Whisper 语音转文字
## ✨ 管理面板
@@ -42,15 +44,23 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
3. 简单的信息统计
4. 插件管理
<!-- ## ✨ ATRI [Beta 测试]
## ✨ 支持 Dify
该功能作为插件载入。插件仓库地址:[astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
1. 对接了 LLMOps 平台 Dify,便捷接入 Dify 智能助手、知识库和 Dify 工作流![接入 Dify - AstrBot 文档](https://astrbot.lwl.lol/others/dify.html)
## ✨ 代码执行器(Beta)
基于 Docker 的沙箱化代码执行器(Beta 测试中)
> [!NOTE]
> 文件输入/输出目前仅支持 Napcat(QQ)
<div align='center'>
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
</div>
1. 基于《ATRI ~ My Dear Moments》主角 ATRI 角色台词作为微调数据集的 `Qwen1.5-7B-Chat Lora` 微调模型。
2. 长期记忆
3. 表情包理解与回复
4. TTS
-->
## ✨ 云部署
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
@@ -71,3 +81,44 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
- Star 这个项目!
- 在[爱发电](https://afdian.com/a/soulter)支持我!
- 在[微信](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)支持我~
## ✨ Demo
<div align='center'>
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
_✨ 自然语言待办事项 ✨_
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
_✨ 插件系统——部分插件展示 ✨_
<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width=600>
_✨ 管理面板 ✨_
![webchat](https://drive.soulter.top/f/vlsA/ezgif-5-fb044b2542.gif)
_✨ 内置 Web Chat,在线与机器人交互 ✨_
</div>
<!-- ## ✨ ATRI [Beta 测试]
该功能作为插件载入。插件仓库地址:[astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
1. 基于《ATRI ~ My Dear Moments》主角 ATRI 角色台词作为微调数据集的 `Qwen1.5-7B-Chat Lora` 微调模型。
2. 长期记忆
3. 表情包理解与回复
4. TTS
-->
+2
View File
@@ -2,6 +2,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot import logger
from astrbot.core.utils.personality import personalities
from astrbot.core import html_renderer
from astrbot.core import sp
from astrbot.core.star.register import register_llm_tool as llm_tool
__all__ = [
@@ -10,4 +11,5 @@ __all__ = [
"personalities",
"html_renderer",
"llm_tool",
"sp"
]
+2 -2
View File
@@ -1,2 +1,2 @@
from astrbot.core.provider import Provider, Personality, ProviderMetaData
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core.provider import Provider, STTProvider, Personality
from astrbot.core.provider.entites import ProviderRequest, ProviderType, ProviderMetaData
+14 -1
View File
@@ -1,12 +1,25 @@
import os
import asyncio
from .log import LogManager, LogBroker
from astrbot.core.utils.t2i.renderer import HtmlRenderer
from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.utils.pip_installer import PipInstaller
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.config.default import DB_PATH
from astrbot.core.config import AstrBotConfig
os.makedirs("data", exist_ok=True)
astrbot_config = AstrBotConfig()
html_renderer = HtmlRenderer()
logger = LogManager.GetLogger(log_name='astrbot')
if os.environ.get('TESTING', ""):
logger.setLevel('DEBUG')
db_helper = SQLiteDatabase(DB_PATH)
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
sp = SharedPreferences() # 简单的偏好设置存储
pip_installer = PipInstaller(astrbot_config.get('pip_install_arg', ''))
web_chat_queue = asyncio.Queue(maxsize=32)
web_chat_back_queue = asyncio.Queue(maxsize=32)
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
+91 -7
View File
@@ -2,7 +2,7 @@
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
"""
VERSION = "3.4.2"
VERSION = "3.4.5"
DB_PATH = "data/data_v3.db"
# 默认配置
@@ -33,6 +33,10 @@ DEFAULT_CONFIG = {
"default_personality": "如果用户寻求帮助或者打招呼,请告诉他可以用 /help 查看 AstrBot 帮助。",
"prompt_prefix": "",
},
"provider_stt_settings": {
"enable": False,
"provider_id": "",
},
"content_safety": {
"internal_keywords": {"enable": True, "extra_keywords": []},
"baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""},
@@ -50,7 +54,8 @@ DEFAULT_CONFIG = {
"log_level": "INFO",
"t2i_endpoint": "",
"pip_install_arg": "",
"plugin_repo_mirror": ""
"plugin_repo_mirror": "",
"knowledge_db": {},
}
@@ -230,10 +235,10 @@ CONFIG_METADATA_2 = {
},
},
"provider_group": {
"name": "大语言模型",
"name": "服务提供商",
"metadata": {
"provider": {
"description": "大语言模型配置",
"description": "服务提供商配置",
"type": "list",
"config_template": {
"openai": {
@@ -256,7 +261,7 @@ CONFIG_METADATA_2 = {
"model": "llama3.1-8b",
},
},
"gemini": {
"gemini(OpenAI兼容)": {
"id": "gemini_default",
"type": "openai_chat_completion",
"enable": True,
@@ -266,6 +271,16 @@ CONFIG_METADATA_2 = {
"model": "gemini-1.5-flash",
},
},
"gemini(googlegenai原生)": {
"id": "gemini_default",
"type": "googlegenai_chat_completion",
"enable": True,
"key": [],
"api_base": "https://generativelanguage.googleapis.com/",
"model_config": {
"model": "gemini-1.5-flash",
},
},
"deepseek": {
"id": "deepseek_default",
"type": "openai_chat_completion",
@@ -278,7 +293,7 @@ CONFIG_METADATA_2 = {
},
"zhipu": {
"id": "zhipu_default",
"type": "openai_chat_completion",
"type": "zhipu_chat_completion",
"enable": True,
"key": [],
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
@@ -295,9 +310,39 @@ CONFIG_METADATA_2 = {
"llmtuner_template": "",
"finetuning_type": "lora",
"quantization_bit": 4,
},
"dify": {
"id": "dify_app_default",
"type": "dify",
"enable": True,
"dify_api_type": "chat",
"dify_api_key": "",
"dify_api_base": "https://api.dify.ai/v1",
"dify_workflow_output_key": "",
},
"whisper(API)": {
"id": "whisper",
"type": "openai_whisper_api",
"enable": False,
"api_key": "",
"api_base": "",
"model": "whisper-1",
},
"whisper(本地加载)": {
"whisper_hint": "(不用修改我)",
"enable": False,
"id": "whisper",
"type": "openai_whisper_selfhost",
"model": "tiny",
}
},
"items": {
"whisper_hint": {
"description": "本地部署 Whisper 模型须知",
"type": "string",
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cudaCPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
"obvious_hint": True
},
"id": {
"description": "ID",
"type": "string",
@@ -366,6 +411,27 @@ CONFIG_METADATA_2 = {
"top_p": {"description": "Top P值", "type": "float"},
},
},
"dify_api_key": {
"description": "API Key",
"type": "string",
"hint": "Dify API Key。此项必填。",
},
"dify_api_base": {
"description": "API Base URL",
"type": "string",
"hint": "Dify API Base URL。默认为 https://api.dify.ai/v1",
},
"dify_api_type": {
"description": "Dify 应用类型",
"type": "string",
"hint": "Dify API 类型。根据 Dify 官网,目前支持 chat, agent, workflow 三种应用类型",
"options": ["chat", "agent", "workflow"],
},
"dify_workflow_output_key": {
"description": "Dify Workflow 输出变量名",
"type": "string",
"hint": "Dify Workflow 输出变量名。当应用类型为 workflow 时才使用。默认为 astrbot_wf_output。",
}
},
},
"provider_settings": {
@@ -375,7 +441,8 @@ CONFIG_METADATA_2 = {
"enable": {
"description": "启用大语言模型聊天",
"type": "bool",
"hint": "是否启用大语言模型聊天。默认启用",
"hint": "如需切换大语言模型提供商,请使用 `/provider` 命令。",
"obvious_hint": True
},
"wake_prefix": {
"description": "LLM 聊天额外唤醒前缀",
@@ -409,6 +476,23 @@ CONFIG_METADATA_2 = {
},
},
},
"provider_stt_settings": {
"description": "语音转文本(STT)",
"type": "object",
"items": {
"enable": {
"description": "启用语音转文本(STT)",
"type": "bool",
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 whisper。",
"obvious_hint": True
},
"provider_id": {
"description": "提供商 ID,不填则默认第一个STT提供商",
"type": "string",
"hint": "语音转文本提供商 ID。如果不填写将使用载入的第一个提供商。",
},
},
},
},
},
"misc_config_group": {
+20 -6
View File
@@ -3,6 +3,7 @@ import time
import threading
import os
from .event_bus import EventBus
from . import astrbot_config
from asyncio import Queue
from typing import List
from astrbot.core.config.astrbot_config import AstrBotConfig
@@ -16,11 +17,12 @@ from astrbot.core.db import BaseDatabase
from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
class AstrBotCoreLifecycle:
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
self.log_broker = log_broker
self.astrbot_config = AstrBotConfig()
self.astrbot_config = astrbot_config
self.db = db
if self.astrbot_config['http_proxy']:
@@ -29,7 +31,10 @@ class AstrBotCoreLifecycle:
async def initialize(self):
logger.info("AstrBot v"+ VERSION)
logger.setLevel(self.astrbot_config['log_level'])
if os.environ.get("TESTING", ""):
logger.setLevel("DEBUG")
else:
logger.setLevel(self.astrbot_config['log_level'])
self.event_queue = Queue()
self.event_queue.closed = False
@@ -37,12 +42,19 @@ class AstrBotCoreLifecycle:
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
self.star_context = Context(self.event_queue, self.astrbot_config, self.db)
self.star_context.platform_manager = self.platform_manager
self.star_context.provider_manager = self.provider_manager
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
self.star_context = Context(
self.event_queue,
self.astrbot_config,
self.db,
self.provider_manager,
self.platform_manager,
self.knowledge_db_manager
)
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
self.plugin_manager.reload()
await self.plugin_manager.reload()
'''扫描、注册插件、实例化插件类'''
await self.provider_manager.initialize()
@@ -81,6 +93,8 @@ class AstrBotCoreLifecycle:
self.event_queue.closed = True
for task in self.curr_tasks:
task.cancel()
await self.provider_manager.terminate()
for task in self.curr_tasks:
try:
+25 -1
View File
@@ -1,7 +1,7 @@
import abc
from dataclasses import dataclass
from typing import List
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, WebChatConversation
@dataclass
class BaseDatabase(abc.ABC):
@@ -76,4 +76,28 @@ class BaseDatabase(abc.ABC):
@abc.abstractmethod
def get_atri_vision_data_by_path_or_id(self, url_or_path: str, id: str) -> ATRIVision:
'''通过 url 或 path 获取 ATRI 视觉数据'''
raise NotImplementedError
@abc.abstractmethod
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
'''通过 user_id 和 cid 获取 WebChatConversation'''
raise NotImplementedError
@abc.abstractmethod
def webchat_new_conversation(self, user_id: str, cid: str):
'''新建 WebChatConversation'''
raise NotImplementedError
@abc.abstractmethod
def get_webchat_conversations(self, user_id: str) -> List[WebChatConversation]:
raise NotImplementedError
@abc.abstractmethod
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
'''更新 WebChatConversation'''
raise NotImplementedError
@abc.abstractmethod
def delete_webchat_conversation(self, user_id: str, cid: str):
'''删除 WebChatConversation'''
raise NotImplementedError
+12 -1
View File
@@ -51,4 +51,15 @@ class ATRIVision():
platform_name: str
session_id: str
sender_nickname: str
timestamp: int = -1
timestamp: int = -1
@dataclass
class WebChatConversation():
user_id: str
cid: str
history: str = ""
created_at: int = 0
updated_at: int = 0
+65 -1
View File
@@ -5,7 +5,8 @@ from astrbot.core.db.po import (
Platform,
Stats,
LLMHistory,
ATRIVision
ATRIVision,
WebChatConversation
)
from . import BaseDatabase
from typing import Tuple
@@ -199,6 +200,69 @@ class SQLiteDatabase(BaseDatabase):
c.close()
return Stats(platform, [], [])
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
'''
SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ?
''', (user_id, cid)
)
res = c.fetchone()
c.close()
return WebChatConversation(*res)
def webchat_new_conversation(self, user_id: str, cid: str):
history = "[]"
updated_at = int(time.time())
created_at = updated_at
self._exec_sql(
'''
INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?)
''', (user_id, cid, history, updated_at, created_at)
)
def get_webchat_conversations(self, user_id: str) -> Tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
'''
SELECT cid, created_at, updated_at FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
''', (user_id,)
)
res = c.fetchall()
c.close()
conversations = []
for row in res:
cid = row[0]
created_at = row[1]
updated_at = row[2]
conversations.append(WebChatConversation("", cid, '[]', created_at, updated_at))
return conversations
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
self._exec_sql(
'''
UPDATE webchat_conversation SET history = ? WHERE user_id = ? AND cid = ?
''', (history, user_id, cid)
)
def delete_webchat_conversation(self, user_id: str, cid: str):
self._exec_sql(
'''
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
''', (user_id, cid)
)
def insert_atri_vision_data(self, vision: ATRIVision):
+8
View File
@@ -35,4 +35,12 @@ CREATE TABLE IF NOT EXISTS atri_vision(
session_id VARCHAR(32),
sender_nickname VARCHAR(32),
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS webchat_conversation(
user_id TEXT,
cid TEXT,
history TEXT,
created_at INTEGER,
updated_at INTEGER
);
+15 -2
View File
@@ -54,6 +54,7 @@ class ComponentType(Enum):
CardImage = "CardImage"
TTS = "TTS"
Unknown = "Unknown"
File = "File"
class BaseMessageComponent(BaseModel):
@@ -122,7 +123,7 @@ class Record(BaseMessageComponent):
proxy: T.Optional[bool] = True
timeout: T.Optional[int] = 0
# 额外
path: T.Optional[str]
path: T.Optional[str] # 用这个
def __init__(self, file: T.Optional[str], **_):
for k in _.keys():
@@ -415,6 +416,17 @@ class Unknown(BaseMessageComponent):
def toString(self):
return ""
class File(BaseMessageComponent):
'''
目前此消息段只适配了 Napcat。
'''
type: ComponentType = "File"
name: T.Optional[str] = "" # 名字
file: T.Optional[str] = "" # url(本地路径)
def __init__(self, name: str, file: str):
super().__init__(name=name, file=file)
ComponentTypes = {
"plain": Plain,
@@ -441,5 +453,6 @@ ComponentTypes = {
"json": Json,
"cardimage": CardImage,
"tts": TTS,
"unknown": Unknown
"unknown": Unknown,
'file': File,
}
+3
View File
@@ -3,6 +3,7 @@ from astrbot.core.message.message_event_result import MessageEventResult, EventR
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
from .content_safety_check.stage import ContentSafetyCheckStage
from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
from .result_decorate.stage import ResultDecorateStage
from .respond.stage import RespondStage
@@ -12,6 +13,7 @@ STAGES_ORDER = [
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"RateLimitCheckStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PreProcessStage", # 预处理
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
"RespondStage" # 发送消息
@@ -21,6 +23,7 @@ __all__ = [
"WakingCheckStage",
"WhitelistCheckStage",
"ContentSafetyCheckStage",
"PreProcessStage",
"ProcessStage",
"ResultDecorateStage",
"RespondStage",
@@ -0,0 +1,55 @@
import traceback
import asyncio
from typing import Union, AsyncGenerator
from ..stage import Stage, register_stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core import logger
from astrbot.core.message.components import Plain, Record
@register_stage
class PreProcessStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.config = ctx.astrbot_config
self.plugin_manager = ctx.plugin_manager
self.stt_settings: dict = self.config.get('provider_stt_settings', {})
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
'''在处理事件之前的预处理'''
if self.stt_settings.get('enable', False):
# STT 处理
# TODO: 独立
stt_provider = self.plugin_manager.context.provider_manager.curr_stt_provider_inst
if stt_provider:
message_chain = event.get_messages()
for idx, component in enumerate(message_chain):
if isinstance(component, Record) and component.path:
path = component.path
retry = 5
for i in range(retry):
try:
result = await stt_provider.get_text(audio_url=path)
if result:
logger.info("语音转文本结果: " + result)
message_chain[idx] = Plain(result)
event.message_str += result
event.message_obj.message_str += result
break
except FileNotFoundError as e:
# napcat workaround
logger.warning(e)
logger.warning(f"语音文件不存在: {path}, 重试中: {i + 1}/{retry}")
await asyncio.sleep(0.5)
continue
except BaseException as e:
logger.error(traceback.format_exc())
logger.error(f"语音转文本失败: {e}")
break
@@ -0,0 +1,60 @@
'''
Dify 调用 Stage
'''
import traceback
from typing import Union, AsyncGenerator
from ...context import PipelineContext
from ..stage import Stage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, ResultContentType
from astrbot.core.message.components import Image
from astrbot.core import logger
from astrbot.core.utils.metrics import Metric
from astrbot.core.provider.entites import ProviderRequest
class DifyRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
req: ProviderRequest = None
provider = self.ctx.plugin_manager.context.get_using_provider()
if provider.meta().type != "dify":
return
if event.get_extra("provider_request"):
req = event.get_extra("provider_request")
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
else:
req = ProviderRequest(prompt="", image_urls=[])
if self.ctx.astrbot_config['provider_settings']['wake_prefix']:
if not event.message_str.startswith(self.ctx.astrbot_config['provider_settings']['wake_prefix']):
return
req.prompt = event.message_str[len(self.ctx.astrbot_config['provider_settings']['wake_prefix']):]
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
req.image_urls.append(image_url)
req.session_id = event.session_id
event.set_extra("provider_request", req)
if not req.prompt:
return
try:
logger.debug(f"Dify 请求 Payload: {req.__dict__}")
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
if llm_response.role == 'assistant':
# text completion
event.set_result(MessageEventResult().message(llm_response.completion_text)
.set_result_content_type(ResultContentType.LLM_RESULT))
yield # rick roll
except BaseException as e:
logger.error(traceback.format_exc())
event.set_result(MessageEventResult().message("AstrBot 请求 Dify 失败:" + str(e)))
return
@@ -1,3 +1,6 @@
'''
本地 Agent 模式的 LLM 调用 Stage
'''
import traceback
from typing import Union, AsyncGenerator
from ...context import PipelineContext
@@ -15,10 +18,13 @@ class LLMRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
async def process(self, event: AstrMessageEvent, _nested: bool = False) -> Union[None, AsyncGenerator[None, None]]:
req: ProviderRequest = None
provider = self.ctx.plugin_manager.context.get_using_provider()
if provider is None:
return
if event.get_extra("provider_request"):
req = event.get_extra("provider_request")
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
@@ -38,6 +44,9 @@ class LLMRequestSubStage(Stage):
session_provider_context = provider.session_memory.get(event.session_id)
req.contexts = session_provider_context if session_provider_context else []
if not req.prompt:
return
# 执行请求 LLM 前事件。
# 装饰 system_prompt 等功能
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMRequestEvent)
@@ -48,7 +57,9 @@ class LLMRequestSubStage(Stage):
logger.error(traceback.format_exc())
try:
logger.debug(f"请求 LLM{req.__dict__}")
logger.debug(f"提供商请求 Payload: {req.__dict__}")
if _nested:
req.func_tool = None # 暂时不支持递归工具调用
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
@@ -82,7 +93,7 @@ class LLMRequestSubStage(Stage):
for tool_name, tool_result in function_calling_result.items():
extra_prompt += f"Tool: {tool_name}\nTool Result: {tool_result}\n"
req.prompt += extra_prompt
async for _ in self.process(event):
async for _ in self.process(event, _nested=True):
yield
except BaseException as e:
@@ -1,3 +1,6 @@
'''
本地 Agent 模式的 AstrBot 插件调用 Stage
'''
from ...context import PipelineContext
from ..stage import Stage
from typing import Dict, Any, List, AsyncGenerator, Union
@@ -24,7 +27,7 @@ class StarRequestSubStage(Stage):
for handler in activated_handlers:
params = handlers_parsed_params.get(handler.handler_full_name, {})
try:
if handler.handler_module_str not in star_map:
if handler.handler_module_path not in star_map:
# 孤立无援的 star handler
continue
@@ -36,7 +39,7 @@ class StarRequestSubStage(Stage):
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_str).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
event.set_result(MessageEventResult().message(ret))
yield
event.clear_result()
+20 -7
View File
@@ -3,6 +3,7 @@ from ..stage import Stage, register_stage
from ..context import PipelineContext
from .method.llm_request import LLMRequestSubStage
from .method.star_request import StarRequestSubStage
from .method.dify_request import DifyRequestSubStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core.provider.entites import ProviderRequest
@@ -20,11 +21,15 @@ class ProcessStage(Stage):
self.star_request_sub_stage = StarRequestSubStage()
await self.star_request_sub_stage.initialize(ctx)
self.dify_request_sub_stage = DifyRequestSubStage()
await self.dify_request_sub_stage.initialize(ctx)
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
'''处理事件
'''
activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers")
# 有插件 Handler 被激活
if activated_handlers:
async for resp in self.star_request_sub_stage.process(event):
# 生成器返回值处理
@@ -36,10 +41,18 @@ class ProcessStage(Stage):
yield
else:
yield
if self.ctx.astrbot_config['provider_settings'].get('enable', True):
if not event._has_send_oper:
'''当没有发送操作'''
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
async for _ in self.llm_request_sub_stage.process(event):
yield
# 调用提供商相关请求
if not self.ctx.astrbot_config['provider_settings'].get('enable', True):
return
if not event._has_send_oper and event.is_at_or_wake_command:
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
provider = self.ctx.plugin_manager.context.get_using_provider()
match provider.meta().type:
case "dify":
async for _ in self.dify_request_sub_stage.process(event):
yield
case _:
async for _ in self.llm_request_sub_stage.process(event):
yield
+3 -1
View File
@@ -22,4 +22,6 @@ class RespondStage(Stage):
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent)
for handler in handlers:
# TODO: 如何让这里的 handler 也能使用 LLM 能力。也许需要将 LLMRequestSubStage 提取出来。
await handler.handler(event)
await handler.handler(event)
event.clear_result()
+4
View File
@@ -41,4 +41,8 @@ class PipelineScheduler():
async def execute(self, event: AstrMessageEvent):
'''执行 pipeline'''
await self._process_stages(event)
if not event._has_send_oper and event.get_platform_name() == "webchat":
await event.send(None)
logger.debug("pipeline 执行完毕。")
-1
View File
@@ -44,7 +44,6 @@ class Stage(abc.ABC):
try:
ready_to_call = handler(event, **params)
except TypeError as e:
print(e)
# 向下兼容
ready_to_call = handler(event, ctx.plugin_manager.context, **params)
@@ -47,6 +47,7 @@ class WakingCheckStage(Stage):
# 如果是群聊,且第一个消息段是 At 消息,但不是 At 机器人或 At 全体成员,则不唤醒
break
is_wake = True
event.is_at_or_wake_command = True
event.is_wake = True
event.message_str = event.message_str[len(wake_prefix) :].strip()
break
@@ -60,11 +61,13 @@ class WakingCheckStage(Stage):
is_wake = True
event.is_wake = True
wake_prefix = ""
event.is_at_or_wake_command = True
break
# 检查是否是私聊
if event.is_private_chat():
is_wake = True
event.is_wake = True
event.is_at_or_wake_command = True
wake_prefix = ""
# 检查插件的 handler filter
@@ -20,6 +20,10 @@ class WhitelistCheckStage(Stage):
if not self.enable_whitelist_check:
return
if event.get_platform_name() == 'webchat':
# WebChat 豁免
return
# 检查是否在白名单
if self.wl_ignore_admin_on_group:
if event.role == 'admin' and event.get_message_type() == MessageType.GROUP_MESSAGE:
+2 -1
View File
@@ -35,7 +35,8 @@ class AstrMessageEvent(abc.ABC):
self.platform_meta = platform_meta
self.session_id = session_id
self.role = "member"
self.is_wake = False
self.is_wake = False # 是否通过 WakingStage
self.is_at_or_wake_command = False # 是否是 At 机器人或者带有唤醒词或者是私聊(事件监听器会让 is_wake 设为 True
self._extras = {}
self.session = MessageSesion(
platform_name=platform_meta.name,
+4 -1
View File
@@ -4,7 +4,7 @@ from typing import List
from asyncio import Queue
from .register import platform_cls_map
from astrbot.core import logger
from .sources.webchat.webchat_adapter import WebChatAdapter
class PlatformManager():
def __init__(self, config: AstrBotConfig, event_queue: Queue):
@@ -25,6 +25,7 @@ class PlatformManager():
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
case "vchat":
from .sources.vchat.vchat_platform_adapter import VChatPlatformAdapter # noqa: F401
async def initialize(self):
for platform in self.platforms_config:
@@ -37,6 +38,8 @@ class PlatformManager():
logger.info(f"尝试实例化 {platform['type']}({platform['id']}) 平台适配器 ...")
inst = cls_type(platform, self.settings, self.event_queue)
self.platform_insts.append(inst)
self.platform_insts.append(WebChatAdapter({}, self.settings, self.event_queue))
def get_insts(self):
return self.platform_insts
+3 -1
View File
@@ -2,4 +2,6 @@ from dataclasses import dataclass
@dataclass
class PlatformMetadata():
name: str # 平台的名称
description: str # 平台的描述
description: str # 平台的描述
default_config_tmpl: dict = None # 平台的默认配置模板
+13 -2
View File
@@ -7,15 +7,26 @@ platform_registry: List[PlatformMetadata] = []
platform_cls_map: Dict[str, Type] = {}
'''维护了平台适配器名称和适配器类的映射'''
def register_platform_adapter(adapter_name: str, desc: str):
'''用于注册平台适配器的带参装饰器'''
def register_platform_adapter(adapter_name: str, desc: str, default_config_tmpl: dict = None):
'''用于注册平台适配器的带参装饰器
default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
'''
def decorator(cls):
if adapter_name in platform_cls_map:
raise ValueError(f"平台适配器 {adapter_name} 已经注册过了,可能发生了适配器命名冲突。")
# 添加必备选项
if default_config_tmpl:
if 'type' not in default_config_tmpl:
default_config_tmpl['type'] = adapter_name
if 'enable' not in default_config_tmpl:
default_config_tmpl['enable'] = False
pm = PlatformMetadata(
name=adapter_name,
description=desc,
default_config_tmpl=default_config_tmpl
)
platform_registry.append(pm)
platform_cls_map[adapter_name] = cls
@@ -1,3 +1,4 @@
import os
import time
import asyncio
import logging
@@ -5,12 +6,13 @@ from typing import Awaitable, Any
from aiocqhttp import CQHttp, Event
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from .aiocqhttp_message_event import *
from astrbot.api.message_components import *
from .aiocqhttp_message_event import * # noqa: F403
from astrbot.api.message_components import * # noqa: F403
from astrbot.api import logger
from .aiocqhttp_message_event import AiocqhttpMessageEvent
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from aiocqhttp.exceptions import ActionFailed
@register_platform_adapter("aiocqhttp", "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。")
class AiocqhttpAdapter(Platform):
@@ -42,7 +44,7 @@ class AiocqhttpAdapter(Platform):
await self.bot.send_private_msg(user_id=session.session_id, message=ret)
await super().send_by_session(session, message_chain)
def convert_message(self, event: Event) -> AstrBotMessage:
async def convert_message(self, event: Event) -> AstrBotMessage:
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.tag = "aiocqhttp"
@@ -78,7 +80,25 @@ class AiocqhttpAdapter(Platform):
a = None
if t == 'text':
message_str += m['data']['text'].strip()
a = ComponentTypes[t](**m['data'])
elif t == 'file':
try:
# Napcat, LLBot
ret = await self.bot.call_action(action="get_file", file_id=event.message[0]['data']['file_id'])
if not ret.get('file', None):
raise ValueError(f"无法解析文件响应: {ret}")
if not os.path.exists(ret['file']):
raise FileNotFoundError(f"文件不存在: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),暂时无法获取用户上传的文件。")
m['data'] = {
"file": ret['file'],
"name": ret['file_name']
}
except ActionFailed as e:
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
except BaseException as e:
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
a = ComponentTypes[t](**m['data']) # noqa: F405
abm.message.append(a)
abm.timestamp = int(time.time())
abm.message_str = message_str
@@ -91,13 +111,13 @@ class AiocqhttpAdapter(Platform):
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180)
@self.bot.on_message('group')
async def group(event: Event):
abm = self.convert_message(event)
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@self.bot.on_message('private')
async def private(event: Event):
abm = self.convert_message(event)
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@@ -31,11 +31,13 @@ class QQOfficialMessageEvent(AstrMessageEvent):
if image_base64:
media = await self.upload_group_and_c2c_image(image_base64, 1, group_openid=source.group_openid)
payload['media'] = media
payload['msg_type'] = 7
await self.bot.api.post_group_message(group_openid=source.group_openid, **payload)
case botpy.message.C2CMessage:
if image_base64:
media = await self.upload_group_and_c2c_image(image_base64, 1, openid=source.author.user_openid)
payload['media'] = media
payload['msg_type'] = 7
await self.bot.api.post_c2c_message(openid=source.author.user_openid, **payload)
case botpy.message.Message:
if image_path:
@@ -73,9 +75,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
plain_text += i.text
elif isinstance(i, Image) and not image_base64:
if i.file and i.file.startswith("file:///"):
image_base64 = file_to_base64(i.file[8:])
image_base64 = file_to_base64(i.file[8:]).replace("base64://", "")
image_file_path = i.file[8:]
elif i.file and i.file.startswith("http"):
image_file_path = await download_image_by_url(i.file)
image_base64 = file_to_base64(image_file_path)
image_base64 = file_to_base64(image_file_path).replace("base64://", "")
return plain_text, image_base64, image_file_path
@@ -2,6 +2,7 @@ import sys
import time
import uuid
import asyncio
import os
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
@@ -62,7 +63,7 @@ class VChatPlatformAdapter(Platform):
self.start_time = int(time.time())
return self._run()
async def _run(self):
await self.client.init()
await self.client.auto_login(hot_reload=True, enable_cmd_qr=True)
@@ -0,0 +1,110 @@
import time
import asyncio
import uuid
import os
from typing import Awaitable, Any
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from astrbot.api.message_components import Plain, Image, Record # noqa: F403
from astrbot.api import logger
from astrbot.core import web_chat_queue, web_chat_back_queue
from .webchat_event import WebChatMessageEvent
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
class QueueListener:
def __init__(self, queue: asyncio.Queue, callback: callable) -> None:
self.queue = queue
self.callback = callback
async def run(self):
while True:
data = await self.queue.get()
await self.callback(data)
@register_platform_adapter("webchat", "webchat")
class WebChatAdapter(Platform):
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settings = platform_settings
self.unique_session = platform_settings['unique_session']
self.imgs_dir = "data/webchat/imgs"
self.metadata = PlatformMetadata(
"webchat",
"webchat",
)
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
plain = ""
for comp in message_chain.chain:
if isinstance(comp, Plain):
plain += comp.text
web_chat_back_queue.put_nowait(plain)
await super().send_by_session(session, message_chain)
async def convert_message(self, data: tuple) -> AstrBotMessage:
username, cid, payload = data
abm = AstrBotMessage()
abm.self_id = "webchat"
abm.tag = "webchat"
abm.sender = MessageMember(username, username)
abm.type = MessageType.FRIEND_MESSAGE
abm.session_id = f"webchat!{username}!{cid}"
abm.message_id = str(uuid.uuid4())
abm.message = []
if payload['message']:
abm.message.append(Plain(payload['message']))
if payload['image_url']:
if isinstance(payload['image_url'], list):
for img in payload['image_url']:
abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, img)))
else:
abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, payload['image_url'])))
if payload['audio_url']:
if isinstance(payload['audio_url'], list):
for audio in payload['audio_url']:
path = os.path.join(self.imgs_dir, audio)
abm.message.append(Record(file=path, path=path))
else:
path = os.path.join(self.imgs_dir, payload['audio_url'])
abm.message.append(Record(file=path, path=path))
logger.debug(f"WebChatAdapter: {abm.message}")
message_str = payload['message']
abm.timestamp = int(time.time())
abm.message_str = message_str
abm.raw_message = data
return abm
def run(self) -> Awaitable[Any]:
async def callback(data: tuple):
abm = await self.convert_message(data)
await self.handle_msg(abm)
bot = QueueListener(web_chat_queue, callback)
return bot.run()
def meta(self) -> PlatformMetadata:
return self.metadata
async def handle_msg(self, message: AstrBotMessage):
message_event = WebChatMessageEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id
)
self.commit_event(message_event)
@@ -0,0 +1,35 @@
import os
import uuid
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image
from astrbot.core.utils.io import file_to_base64, download_image_by_url
from astrbot.core import web_chat_back_queue
class WebChatMessageEvent(AstrMessageEvent):
def __init__(self, message_str, message_obj, platform_meta, session_id):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.imgs_dir = "data/webchat/imgs"
os.makedirs(self.imgs_dir, exist_ok=True)
async def send(self, message: MessageChain):
if not message:
web_chat_back_queue.put_nowait(None)
return
for comp in message.chain:
if isinstance(comp, Plain):
web_chat_back_queue.put_nowait(comp.text)
elif isinstance(comp, Image):
# save image to local
filename = str(uuid.uuid4()) + ".jpg"
path = os.path.join(self.imgs_dir, filename)
if comp.file and comp.file.startswith("file:///"):
ph = comp.file[8:]
with open(path, "wb") as f:
with open(ph, "rb") as f2:
f.write(f2.read())
elif comp.file and comp.file.startswith("http"):
await download_image_by_url(comp.file, path=path)
web_chat_back_queue.put_nowait(f"[IMAGE]{filename}")
web_chat_back_queue.put_nowait(None)
await super().send(message)
+2 -1
View File
@@ -1,4 +1,4 @@
from .provider import Provider, Personality
from .provider import Provider, Personality, STTProvider
from .entites import ProviderMetaData
@@ -6,4 +6,5 @@ __all__ = [
"Provider",
"Personality",
"ProviderMetaData",
"STTProvider"
]
+17 -8
View File
@@ -1,13 +1,22 @@
from dataclasses import dataclass
from typing import List, Dict
import enum
from dataclasses import dataclass, field
from typing import List, Dict, Type
from .func_tool_manager import FuncCall
class ProviderType(enum.Enum):
CHAT_COMPLETION = "chat_completion"
SPEECH_TO_TEXT = "speech_to_text"
TEXT_TO_SPEECH = "text_to_speech"
@dataclass
class ProviderMetaData():
type: str # 提供商适配器名称,如 openai, ollama
desc: str = "" # 提供商适配器描述.
type: str
'''提供商适配器名称,如 openai, ollama'''
desc: str = ""
'''提供商适配器描述.'''
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
cls_type: Type = None
@dataclass
class ProviderRequest():
@@ -32,9 +41,9 @@ class ProviderRequest():
class LLMResponse:
role: str
'''角色'''
completion_text: str = None
completion_text: str = ""
'''LLM 返回的文本'''
tools_call_args: List[Dict[str, any]] = None
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
'''工具调用参数'''
tools_call_name: List[str] = None
tools_call_name: List[str] = field(default_factory=list)
'''工具调用名称'''
@@ -14,6 +14,7 @@ class FuncTool:
parameters: Dict
description: str
handler: Awaitable
handler_module_path: str = None # 必须要保留这个,handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
active: bool = True
'''是否激活'''
@@ -100,10 +101,29 @@ class FuncCall:
}
)
return _l
def get_func_desc_google_genai_style(self) -> Dict:
declarations = {}
tools = []
for f in self.func_list:
if not f.active:
continue
tools.append(
{
"name": f.name,
"parameters": f.parameters,
"description": f.description,
}
)
declarations["function_declarations"] = tools
return declarations
async def func_call(self, question: str, session_id: str, provider) -> tuple:
_l = []
for f in self.func_list:
if not f.active:
continue
_l.append(
{
"name": f["name"],
@@ -169,3 +189,10 @@ class FuncCall:
if ret:
tool_call_result.append(str(ret))
return tool_call_result, True
def __str__(self):
return str(self.func_list)
def __repr__(self):
return str(self.func_list)
+89 -17
View File
@@ -1,23 +1,36 @@
import traceback
from astrbot.core.config.astrbot_config import AstrBotConfig
from .provider import Provider
from .provider import Provider, STTProvider
from .entites import ProviderType
from typing import List
from astrbot.core.db import BaseDatabase
from collections import defaultdict
from .register import provider_cls_map, llm_tools
from astrbot.core import logger
from astrbot.core import logger, sp
class ProviderManager():
def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase):
self.providers_config: List = config['provider']
self.provider_settings: dict = config['provider_settings']
self.provider_stt_settings: dict = config.get('provider_stt_settings', {})
self.provider_insts: List[Provider] = []
'''加载的 Provider 的实例'''
self.stt_provider_insts: List[STTProvider] = []
'''加载的 Speech To Text Provider 的实例'''
self.llm_tools = llm_tools
self.curr_provider_inst: Provider = None
'''当前使用的 Provider 实例'''
self.curr_stt_provider_inst: STTProvider = None
'''当前使用的 Speech To Text Provider 实例'''
self.loaded_ids = defaultdict(bool)
self.db_helper = db_helper
self.curr_kdb_name = ""
kdb_cfg = config.get("knowledge_db", {})
if kdb_cfg and len(kdb_cfg):
self.curr_kdb_name = list(kdb_cfg.keys())[0]
for provider_cfg in self.providers_config:
if not provider_cfg['enable']:
continue
@@ -26,32 +39,91 @@ class ProviderManager():
raise ValueError(f"Provider ID 重复:{provider_cfg['id']}")
self.loaded_ids[provider_cfg['id']] = True
match provider_cfg['type']:
case "openai_chat_completion":
from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401
case "llm_tuner":
logger.info("加载 LLM Tuner 工具 ...")
from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401
try:
match provider_cfg['type']:
case "openai_chat_completion":
from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401
case "zhipu_chat_completion":
from .sources.zhipu_source import ProviderZhipu # noqa: F401
case "llm_tuner":
logger.info("加载 LLM Tuner 工具 ...")
from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401
case "dify":
from .sources.dify_source import ProviderDify # noqa: F401
case "googlegenai_chat_completion":
from .sources.gemini_source import ProviderGoogleGenAI # noqa: F401
case "openai_whisper_api":
from .sources.whisper_api_source import ProviderOpenAIWhisperAPI # noqa: F401
case "openai_whisper_selfhost":
from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost # noqa: F401
except (ImportError, ModuleNotFoundError) as e:
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。")
continue
except Exception as e:
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。未知原因")
continue
async def initialize(self):
for provider_config in self.providers_config:
if not provider_config['enable']:
continue
if provider_config['type'] not in provider_cls_map:
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的 大模型提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
continue
cls_type = provider_cls_map[provider_config['type']]
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 大模型提供商适配器 ...")
selected_provider_id = sp.get("curr_provider")
selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
provider_enabled = self.provider_settings.get("enable", False)
stt_enabled = self.provider_stt_settings.get("enable", False)
provider_metadata = provider_cls_map[provider_config['type']]
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...")
try:
inst = cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True))
self.provider_insts.append(inst)
# 按任务实例化提供商
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
# STT 任务
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.stt_provider_insts.append(inst)
if selected_stt_provider_id == provider_config['id'] and stt_enabled:
self.curr_stt_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。")
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
# 文本生成任务
inst = provider_metadata.cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True))
if getattr(inst, "initialize", None):
await inst.initialize()
self.provider_insts.append(inst)
if selected_provider_id == provider_config['id'] and provider_enabled:
self.curr_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。")
except Exception as e:
traceback.print_exc()
logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 大模型提供商适配器 失败:{e}")
logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}")
if len(self.provider_insts) > 0:
if len(self.provider_insts) > 0 and not self.curr_provider_inst and provider_enabled:
self.curr_provider_inst = self.provider_insts[0]
if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst and stt_enabled:
self.curr_stt_provider_inst = self.stt_provider_insts[0]
if not self.curr_provider_inst:
logger.warning("未启用任何用于 文本生成 的提供商适配器。")
if self.provider_stt_settings.get("enable"):
if not self.curr_stt_provider_inst:
logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
def get_insts(self):
return self.provider_insts
return self.provider_insts
async def terminate(self):
for provider_inst in self.provider_insts:
if hasattr(provider_inst, "terminate"):
await provider_inst.terminate()
+27
View File
@@ -125,6 +125,33 @@ class Provider(abc.ABC):
'''重置某一个 session_id 的上下文'''
raise NotImplementedError()
def meta(self) -> ProviderMeta:
'''获取 Provider 的元数据'''
return ProviderMeta(
id=self.provider_config['id'],
model=self.get_model(),
type=self.provider_config['type']
)
class STTProvider():
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
self.provider_config = provider_config
self.provider_settings = provider_settings
@abc.abstractmethod
async def get_text(self, audio_url: str) -> str:
'''获取音频的文本'''
raise NotImplementedError()
def set_model(self, model_name: str):
'''设置当前使用的模型名称'''
self.model_name = model_name
def get_model(self) -> str:
'''获取当前使用的模型'''
return self.provider_config.get("model", "")
def meta(self) -> ProviderMeta:
'''获取 Provider 的元数据'''
return ProviderMeta(
+11 -5
View File
@@ -1,16 +1,20 @@
from typing import List, Dict, Type
from .entites import ProviderMetaData
from .entites import ProviderMetaData, ProviderType
from astrbot.core import logger
from .func_tool_manager import FuncCall
provider_registry: List[ProviderMetaData] = []
'''维护了通过装饰器注册的 Provider'''
provider_cls_map: Dict[str, Type] = {}
'''维护了 Provider 类型名称和 Provider 的映射'''
provider_cls_map: Dict[str, ProviderMetaData] = {}
'''维护了 Provider 类型名称和 ProviderMetadata 的映射'''
llm_tools = FuncCall()
def register_provider_adapter(provider_type_name: str, desc: str):
def register_provider_adapter(
provider_type_name: str,
desc: str,
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
):
'''用于注册平台适配器的带参装饰器'''
def decorator(cls):
if provider_type_name in provider_cls_map:
@@ -19,9 +23,11 @@ def register_provider_adapter(provider_type_name: str, desc: str):
pm = ProviderMetaData(
type=provider_type_name,
desc=desc,
provider_type=provider_type,
cls_type=cls
)
provider_registry.append(pm)
provider_cls_map[provider_type_name] = cls
provider_cls_map[provider_type_name] = pm
logger.debug(f"Provider {provider_type_name} 已注册")
return cls
@@ -0,0 +1,137 @@
from typing import List
from .. import Provider
from ..entites import LLMResponse
from ..func_tool_manager import FuncCall
from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter
from astrbot.core.utils.dify_api_client import DifyAPIClient
from astrbot.core.utils.io import download_image_by_url
from astrbot.core import logger, sp
@register_provider_adapter("dify", "Dify APP 适配器。")
class ProviderDify(Provider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history=False,
) -> None:
super().__init__(
provider_config, provider_settings, persistant_history, db_helper
)
self.api_key = provider_config.get("dify_api_key", "")
if not self.api_key:
raise Exception("Dify API Key 不能为空。")
api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1")
self.api_client = DifyAPIClient(self.api_key, api_base)
self.api_type = provider_config.get("dify_api_type", "")
if not self.api_type:
raise Exception("Dify API 类型不能为空。")
self.model_name = "dify"
self.workflow_output_key = provider_config.get("dify_workflow_output_key", "astrbot_wf_output")
self.conversation_ids = {}
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = None,
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
**kwargs,
) -> LLMResponse:
result = ""
conversation_id = self.conversation_ids.get(session_id, "")
files_payload = []
for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
file_response = await self.api_client.file_upload(image_path, user=session_id)
if 'id' not in file_response:
logger.warning(f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。")
continue
files_payload.append({
"type": "image",
"transfer_method": "local_file",
"upload_file_id": file_response['id'],
})
else:
# TODO: 处理更多情况
logger.warning(f"未知的图片链接:{image_url},图片将忽略。")
logger.debug(files_payload)
# 获得会话变量
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
match self.api_type:
case "chat" | "agent":
async for chunk in self.api_client.chat_messages(
inputs={
**session_var
},
query=prompt,
user=session_id,
conversation_id=conversation_id,
files=files_payload
):
logger.debug(f"dify resp chunk: {chunk}")
if chunk['event'] == "message" or \
chunk['event'] == "agent_message":
result += chunk['answer']
if not conversation_id:
self.conversation_ids[session_id] = chunk['conversation_id']
conversation_id = chunk['conversation_id']
case "workflow":
async for chunk in self.api_client.workflow_run(
inputs={
"astrbot_text_query": prompt,
"astrbot_session_id": session_id,
**session_var
},
user=session_id,
files=files_payload
):
match chunk['event']:
case "workflow_started":
logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。")
case "node_finished":
logger.debug(f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。")
case "workflow_finished":
logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束。")
if chunk['data']['error']:
logger.error(f"Dify 工作流出现错误:{chunk['data']['error']}")
raise Exception(f"Dify 工作流出现错误:{chunk['data']['error']}")
if self.workflow_output_key not in chunk['data']['outputs']:
raise Exception(f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}")
result = chunk['data']['outputs'][self.workflow_output_key]
case _:
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
return LLMResponse(role="assistant", completion_text=result)
async def forget(self, session_id):
self.conversation_ids.pop(session_id, None)
return True
async def get_current_key(self):
return self.api_key
async def set_key(self, key):
raise Exception("Dify 适配器不支持设置 API Key。")
async def get_models(self):
return [self.get_model()]
async def get_human_readable_context(self, session_id, page, page_size):
raise Exception("暂不支持获得 Dify 的历史消息记录。")
async def terminate(self):
await self.api_client.close()
@@ -0,0 +1,287 @@
import traceback
import base64
import json
import aiohttp
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.db import BaseDatabase
from astrbot.api.provider import Provider
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
from ..register import register_provider_adapter
from astrbot.core.provider.entites import LLMResponse
class SimpleGoogleGenAIClient():
def __init__(self, api_key: str, api_base: str):
self.api_key = api_key
if api_base.endswith("/"):
self.api_base = api_base[:-1]
else:
self.api_base = api_base
self.client = aiohttp.ClientSession()
async def models_list(self) -> List[str]:
request_url = f"{self.api_base}/v1beta/models?key={self.api_key}"
async with self.client.get(request_url, timeout=10) as resp:
response = await resp.json()
models = []
for model in response["models"]:
if 'generateContent' in model["supportedGenerationMethods"]:
models.append(model["name"].replace("models/", ""))
return models
async def generate_content(
self,
contents: List[dict],
model: str="gemini-1.5-flash",
system_instruction: str="",
tools: dict=None
):
payload = {}
if system_instruction:
payload["system_instruction"] = {
"parts": {"text": system_instruction}
}
if tools:
payload["tools"] = [tools]
payload["contents"] = contents
logger.debug(f"payload: {payload}")
request_url = f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}"
async with self.client.post(request_url, json=payload, timeout=10) as resp:
response = await resp.json()
return response
@register_provider_adapter("googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器")
class ProviderGoogleGenAI(Provider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history = True
) -> None:
super().__init__(provider_config, provider_settings, persistant_history, db_helper)
self.chosen_api_key = None
self.api_keys: List = provider_config.get("key", [])
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
self.client = SimpleGoogleGenAIClient(
api_key=self.chosen_api_key,
api_base=provider_config.get("api_base", None)
)
self.set_model(provider_config['model_config']['model'])
async def get_human_readable_context(self, session_id, page, page_size):
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
contexts = []
temp_contexts = []
for record in self.session_memory[session_id]:
if record['role'] == "user":
temp_contexts.append(f"User: {record['content']}")
elif record['role'] == "assistant":
temp_contexts.append(f"Assistant: {record['content']}")
contexts.insert(0, temp_contexts)
temp_contexts = []
# 展平 contexts 列表
contexts = [item for sublist in contexts for item in sublist]
# 计算分页
paged_contexts = contexts[(page-1)*page_size:page*page_size]
total_pages = len(contexts) // page_size
if len(contexts) % page_size != 0:
total_pages += 1
return paged_contexts, total_pages
async def get_models(self):
return await self.client.models_list()
async def pop_record(self, session_id: str, pop_system_prompt: bool = False):
'''
弹出第一条记录
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
if len(self.session_memory[session_id]) == 0:
return None
for i in range(len(self.session_memory[session_id])):
# 检查是否是 system prompt
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
# 如果只有一个 system prompt,才不删掉
f = False
for j in range(i+1, len(self.session_memory[session_id])):
if self.session_memory[session_id][j]['user']['role'] == "system":
f = True
break
if not f:
continue
record = self.session_memory[session_id].pop(i)
break
return record
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
tool = None
if tools:
tool = tools.get_func_desc_google_genai_style()
system_instruction = ""
for message in payloads["messages"]:
if message["role"] == "system":
system_instruction = message["content"]
break
google_genai_conversation = []
for message in payloads["messages"]:
if message["role"] == "user":
if isinstance(message["content"], str):
google_genai_conversation.append({
"role": "user",
"parts": [{"text": message["content"]}]
})
elif isinstance(message["content"], list):
# images
parts = []
for part in message["content"]:
if part["type"] == "text":
parts.append({"text": part["text"]})
elif part["type"] == "image_url":
parts.append({"inline_data": {
"mime_type": "image/jpeg",
"data": part["image_url"]["url"].replace("data:image/jpeg;base64,", "") # base64
}})
google_genai_conversation.append({
"role": "user",
"parts": parts
})
elif message["role"] == "assistant":
google_genai_conversation.append({
"role": "model",
"parts": [{"text": message["content"]}]
})
logger.debug(f"google_genai_conversation: {google_genai_conversation}")
result = await self.client.generate_content(
contents=google_genai_conversation,
model=self.get_model(),
system_instruction=system_instruction,
tools=tool
)
logger.debug(f"result: {result}")
candidates = result["candidates"][0]['content']['parts']
llm_response = LLMResponse("assistant")
for candidate in candidates:
if 'text' in candidate:
llm_response.completion_text += candidate['text']
elif 'functionCall' in candidate:
llm_response.role = "tool"
llm_response.tools_call_args.append(candidate['functionCall']['args'])
llm_response.tools_call_name.append(candidate['functionCall']['name'])
return llm_response
async def text_chat(
self,
prompt: str,
session_id: str,
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts=None,
system_prompt=None,
**kwargs
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
if not contexts:
context_query = [*self.session_memory[session_id], new_record]
else:
context_query = [*contexts, new_record]
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
payloads = {
"messages": context_query,
**self.provider_config.get("model_config", {})
}
try:
llm_response = await self._query(payloads, func_tool)
except Exception as e:
if "maximum context length" in str(e):
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
self.pop_record(session_id)
logger.warning(traceback.format_exc())
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
if llm_response.role == "assistant" and session_id:
# 文本回复
if not contexts:
# 添加用户 record
self.session_memory[session_id].append(new_record)
# 添加 assistant record
self.session_memory[session_id].append({
"role": "assistant",
"content": llm_response.completion_text
})
else:
self.session_memory[session_id] = [*contexts, new_record, {
"role": "assistant",
"content": llm_response.completion_text
}]
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
async def forget(self, session_id: str) -> bool:
self.session_memory[session_id] = []
return True
def get_current_key(self) -> str:
return self.client.api_key
def get_keys(self) -> List[str]:
return self.api_keys
def set_key(self, key):
self.client.api_key = key
async def assemble_context(self, text: str, image_urls: List[str] = None):
'''
组装上下文。
'''
if image_urls:
user_content = {"role": "user","content": [{"type": "text", "text": text}]}
for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
else:
image_data = await self.encode_image_bs64(image_url)
user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}})
return user_content
else:
return {"role": "user","content": text}
async def encode_image_bs64(self, image_url: str) -> str:
'''
将图片转换为 base64
'''
if image_url.startswith("base64://"):
return image_url.replace("base64://", "data:image/jpeg;base64,")
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode('utf-8')
return "data:image/jpeg;base64," + image_bs64
return ''
@@ -162,7 +162,12 @@ class ProviderOpenAIOfficial(Provider):
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
self.pop_record(session_id)
logger.warning(traceback.format_exc())
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
if llm_response.role == "assistant" and session_id:
# 文本回复
if not contexts:
@@ -180,8 +185,6 @@ class ProviderOpenAIOfficial(Provider):
}]
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
return llm_response
async def forget(self, session_id: str) -> bool:
self.session_memory[session_id] = []
return True
@@ -206,6 +209,8 @@ class ProviderOpenAIOfficial(Provider):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
else:
if image_url.startswith("file:///"):
image_url = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_url)
user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}})
return user_content
@@ -0,0 +1,95 @@
import uuid
import os
import io
from openai import AsyncOpenAI, NOT_GIVEN
from ..provider import STTProvider
from ..entites import ProviderType
from astrbot.core.utils.io import download_file
from ..register import register_provider_adapter
from astrbot.core import logger
@register_provider_adapter("openai_whisper_api", "OpenAI Whisper API", provider_type=ProviderType.SPEECH_TO_TEXT)
class ProviderOpenAIWhisperAPI(STTProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.chosen_api_key = provider_config.get("api_key", "")
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base", None),
timeout=provider_config.get("timeout", NOT_GIVEN),
)
self.set_model(provider_config.get("model", None))
async def _convert_audio(self, path: str) -> str:
from pyffmpeg import FFmpeg
filename = str(uuid.uuid4()) + '.mp3'
ff = FFmpeg()
output_path = ff.convert(path, os.path.join('data/temp', filename))
return output_path
async def _pcm_to_wav(self, input_io: io.BytesIO, output_path: str) -> str:
import wave
with wave.open(output_path, 'wb') as wav:
wav.setnchannels(1)
wav.setsampwidth(2)
wav.setframerate(24000)
wav.writeframes(input_io.read())
return output_path
async def _convert_silk(self, path: str) -> str:
import pysilk
filename = str(uuid.uuid4()) + '.wav'
output_path = os.path.join('data/temp', filename)
with open(path, "rb") as f:
input_data = f.read()
if input_data.startswith(b'\x02'):
# tencent 我爱你
input_data = input_data[1:]
input_io = io.BytesIO(input_data)
output_io = io.BytesIO()
pysilk.decode(input_io, output_io, 24000)
output_io.seek(0)
await self._pcm_to_wav(output_io, output_path)
return output_path
async def _is_silk_file(self, file_path):
silk_header = b"SILK"
with open(file_path, "rb") as f:
file_header = f.read(8)
if silk_header in file_header:
return True
else:
return False
async def get_text(self, audio_url: str) -> str:
'''only supports mp3, mp4, mpeg, m4a, wav, webm'''
if audio_url.startswith("http"):
name = str(uuid.uuid4())
path = os.path.join("data/temp", name)
audio_url = await download_file(audio_url, path)
if not os.path.exists(audio_url):
raise FileNotFoundError(f"文件不存在: {audio_url}")
if audio_url.endswith(".amr") or audio_url.endswith(".silk"):
is_silk = await self._is_silk_file(audio_url)
if is_silk:
logger.info("Converting silk file to wav ...")
audio_url = await self._convert_silk(audio_url)
result = await self.client.audio.transcriptions.create(
model=self.model_name,
file=open(audio_url, "rb"),
)
return result.text
@@ -0,0 +1,92 @@
import uuid
import os
import io
import asyncio
import whisper
from ..provider import STTProvider
from ..entites import ProviderType
from astrbot.core.utils.io import download_file
from ..register import register_provider_adapter
from astrbot.core import logger
@register_provider_adapter("openai_whisper_selfhost", "OpenAI Whisper 模型部署", provider_type=ProviderType.SPEECH_TO_TEXT)
class ProviderOpenAIWhisperSelfHost(STTProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.set_model(provider_config.get("model", None))
self.model = None
async def initialize(self):
loop = asyncio.get_event_loop()
logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...")
self.model = await loop.run_in_executor(None, whisper.load_model, self.model_name)
logger.info("Whisper 模型加载完成。")
async def _convert_audio(self, path: str) -> str:
from pyffmpeg import FFmpeg
filename = str(uuid.uuid4()) + '.mp3'
ff = FFmpeg()
output_path = ff.convert(path, os.path.join('data/temp', filename))
return output_path
async def _pcm_to_wav(self, input_io: io.BytesIO, output_path: str) -> str:
import wave
with wave.open(output_path, 'wb') as wav:
wav.setnchannels(1)
wav.setsampwidth(2)
wav.setframerate(24000)
wav.writeframes(input_io.read())
return output_path
async def _convert_silk(self, path: str) -> str:
import pysilk
filename = str(uuid.uuid4()) + '.wav'
output_path = os.path.join('data/temp', filename)
with open(path, "rb") as f:
input_data = f.read()
if input_data.startswith(b'\x02'):
# tencent 我爱你
input_data = input_data[1:]
input_io = io.BytesIO(input_data)
output_io = io.BytesIO()
pysilk.decode(input_io, output_io, 24000)
output_io.seek(0)
await self._pcm_to_wav(output_io, output_path)
return output_path
async def _is_silk_file(self, file_path):
silk_header = b"SILK"
with open(file_path, "rb") as f:
file_header = f.read(8)
if silk_header in file_header:
return True
else:
return False
async def get_text(self, audio_url: str) -> str:
loop = asyncio.get_event_loop()
if audio_url.startswith("http"):
name = str(uuid.uuid4())
path = os.path.join("data/temp", name)
audio_url = await download_file(audio_url, path)
if not os.path.exists(audio_url):
raise FileNotFoundError(f"文件不存在: {audio_url}")
if audio_url.endswith(".amr") or audio_url.endswith(".silk"):
is_silk = await self._is_silk_file(audio_url)
if is_silk:
logger.info("Converting silk file to wav ...")
audio_url = await self._convert_silk(audio_url)
result = await loop.run_in_executor(None, self.model.transcribe, audio_url)
return result['text']
@@ -0,0 +1,73 @@
import traceback
from astrbot.core.db import BaseDatabase
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
from ..register import register_provider_adapter
from astrbot.core.provider.entites import LLMResponse
from .openai_source import ProviderOpenAIOfficial
@register_provider_adapter("zhipu_chat_completion", "智浦 Chat Completion 提供商适配器")
class ProviderZhipu(ProviderOpenAIOfficial):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history = True
) -> None:
super().__init__(provider_config, provider_settings, db_helper, persistant_history)
async def text_chat(
self,
prompt: str,
session_id: str,
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts=None,
system_prompt=None,
**kwargs
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
if not contexts:
context_query = [*self.session_memory[session_id], new_record]
else:
context_query = [*contexts, new_record]
model_cfgs: dict = self.provider_config.get("model_config", {})
# glm-4v-flash 只支持一张图片
model: str = model_cfgs.get("model", "")
if model.lower() == 'glm-4v-flash' and image_urls and len(context_query) > 1:
logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片")
logger.debug(context_query)
new_context_query_ = []
for i in range(0, len(context_query) - 1, 2):
if isinstance(context_query[i].get("content", ""), list):
continue
new_context_query_.append(context_query[i])
new_context_query_.append(context_query[i+1])
new_context_query_.append(context_query[-1]) # 保留最后一条记录
context_query = new_context_query_
logger.debug(context_query)
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
payloads = {
"messages": context_query,
**model_cfgs
}
llm_response = None
try:
llm_response = await self._query(payloads, func_tool)
except Exception as e:
if "maximum context length" in str(e):
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
self.pop_record(session_id)
logger.warning(traceback.format_exc())
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
@@ -0,0 +1,25 @@
from typing import List
from openai import AsyncOpenAI
class SimpleOpenAIEmbedding():
def __init__(
self,
model,
api_key,
api_base=None,
) -> None:
self.client = AsyncOpenAI(
api_key=api_key,
base_url=api_base
)
self.model = model
async def get_embedding(self, text) -> List[float]:
'''
获取文本的嵌入
'''
embedding = await self.client.embeddings.create(
input=text,
model=self.model
)
return embedding.data[0].embedding
+92
View File
@@ -0,0 +1,92 @@
import os
from typing import List, Dict
from astrbot.core import logger
from .store import Store
from astrbot.core.config import AstrBotConfig
class KnowledgeDBManager():
def __init__(self, astrbot_config: AstrBotConfig) -> None:
self.db_path = "data/knowledge_db/"
self.config = astrbot_config.get("knowledge_db", {})
self.astrbot_config = astrbot_config
if not os.path.exists(self.db_path):
os.makedirs(self.db_path)
self.store_insts: Dict[str, Store] = {}
for name, cfg in self.config.items():
if cfg["strategy"] == "embedding":
logger.info(f"加载 Chroma Vector Store{name}")
try:
from .store.chroma_db import ChromaVectorStore
except ImportError as ie:
logger.error(f"{ie} 可能未安装 chromadb 库。")
continue
self.store_insts[name] = ChromaVectorStore(name, cfg["embedding_config"])
else:
logger.error(f"不支持的策略:{cfg['strategy']}")
async def list_knowledge_db(self) -> List[str]:
return [f for f in os.listdir(self.db_path) if os.path.isfile(os.path.join(self.db_path, f))]
async def create_knowledge_db(self, name: str, config: Dict):
'''
config 格式:
```
{
"strategy": "embedding", # 目前只支持 embedding
"chunk_method": {
"strategy": "fixed",
"chunk_size": 100,
"overlap_size": 10
},
"embedding_config": {
"strategy": "openai",
"base_url": "",
"model": "",
"api_key": ""
}
}
```
'''
if name in self.config:
raise ValueError(f"知识库已存在:{name}")
self.config[name] = config
self.astrbot_config["knowledge_db"] = self.config
self.astrbot_config.save_config()
async def insert_record(self, name: str, text: str):
if name not in self.store_insts:
raise ValueError(f"未找到知识库:{name}")
ret = []
match self.config[name]["chunk_method"]['strategy']:
case "fixed":
chunk_size = self.config[name]["chunk_method"]["chunk_size"]
chunk_overlap = self.config[name]["chunk_method"]["overlap_size"]
ret = self._fixed_chunk(text, chunk_size, chunk_overlap)
case _:
pass
for chunk in ret:
await self.store_insts[name].save(chunk)
async def retrive_records(self, name: str, query: str, top_n: int = 3) -> List[str]:
if name not in self.store_insts:
raise ValueError(f"未找到知识库:{name}")
inst = self.store_insts[name]
return await inst.query(query, top_n)
def _fixed_chunk(self, text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
chunks = []
start = 0
while start < len(text):
end = start + chunk_size
chunks.append(text[start:end])
start += chunk_size - chunk_overlap
return chunks
+8
View File
@@ -0,0 +1,8 @@
from typing import List
class Store():
async def save(self, text: str):
pass
async def query(self, query: str, top_n: int = 3) -> List[str]:
pass
+39
View File
@@ -0,0 +1,39 @@
import chromadb
import uuid
from typing import List, Dict
from astrbot.api import logger
from ..embedding.openai_source import SimpleOpenAIEmbedding
from . import Store
class ChromaVectorStore(Store):
def __init__(self, name: str, embedding_cfg: Dict) -> None:
self.chroma_client = chromadb.PersistentClient(path='data/long_term_memory_chroma.db')
self.collection = self.chroma_client.get_or_create_collection(name=name)
self.embedding = None
if embedding_cfg["strategy"] == "openai":
self.embedding = SimpleOpenAIEmbedding(
model=embedding_cfg["model"],
api_key=embedding_cfg["api_key"],
api_base=embedding_cfg.get("base_url", None)
)
async def save(self, text: str, metadata: Dict = None):
logger.debug(f"Saving text: {text}")
embedding = await self.embedding.get_embedding(text)
self.collection.upsert(
documents=text,
metadatas=metadata,
ids=str(uuid.uuid4()),
embeddings=embedding
)
async def query(self, query: str, top_n=3, metadata_filter: Dict = None) -> List[str]:
embedding = await self.embedding.get_embedding(query)
results = self.collection.query(
query_embeddings=embedding,
n_results=top_n,
where=metadata_filter
)
return results['documents'][0]
+32 -11
View File
@@ -1,6 +1,7 @@
from asyncio import Queue
from typing import List, TypedDict, Union
from astrbot.core import sp
from astrbot.core.provider.provider import Provider
from astrbot.core.db import BaseDatabase
from astrbot.core.config.astrbot_config import AstrBotConfig
@@ -14,10 +15,7 @@ from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
from .filter.command import CommandFilter
from .filter.regex import RegexFilter
from typing import Awaitable
class StarCommand(TypedDict):
full_command_name: str
command_name: str
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
class Context:
'''
@@ -38,11 +36,22 @@ class Context:
# back compatibility
_register_tasks: List[Awaitable] = []
_star_manager = None
def __init__(self, event_queue: Queue, config: AstrBotConfig, db: BaseDatabase):
def __init__(self,
event_queue: Queue,
config: AstrBotConfig,
db: BaseDatabase,
provider_manager: ProviderManager = None,
platform_manager: PlatformManager = None,
knowledge_db_manager: KnowledgeDBManager = None
):
self._event_queue = event_queue
self._config = config
self._db = db
self.provider_manager = provider_manager
self.platform_manager = platform_manager
self.knowledge_db_manager = knowledge_db_manager
def get_registered_star(self, star_name: str) -> StarMetadata:
for star in star_registry:
@@ -73,7 +82,7 @@ class Context:
event_type=EventType.OnLLMRequestEvent,
handler_full_name=func_obj.__module__ + "_" + func_obj.__name__,
handler_name=func_obj.__name__,
handler_module_str=func_obj.__module__,
handler_module_path=func_obj.__module__,
handler=func_obj,
event_filters=[],
desc=desc
@@ -94,6 +103,12 @@ class Context:
func_tool = self.provider_manager.llm_tools.get_func(name)
if func_tool is not None:
func_tool.active = True
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if name in inactivated_llm_tools:
inactivated_llm_tools.remove(name)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
return True
return False
@@ -105,6 +120,12 @@ class Context:
func_tool = self.provider_manager.llm_tools.get_func(name)
if func_tool is not None:
func_tool.active = False
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if name not in inactivated_llm_tools:
inactivated_llm_tools.append(name)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
return True
return False
@@ -125,7 +146,7 @@ class Context:
event_type=EventType.AdapterMessageEvent,
handler_full_name=awaitable.__module__ + "_" + awaitable.__name__,
handler_name=awaitable.__name__,
handler_module_str=awaitable.__module__,
handler_module_path=awaitable.__module__,
handler=awaitable,
event_filters=[],
desc=desc
@@ -143,13 +164,13 @@ class Context:
def register_provider(self, provider: Provider):
'''
注册一个 LLM Provider。
注册一个 LLM Provider(Chat_Completion 类型)
'''
self.provider_manager.provider_insts.append(provider)
def get_provider_by_id(self, provider_id: str) -> Provider:
'''
通过 ID 获取 LLM Provider。
通过 ID 获取 LLM Provider(Chat_Completion 类型)
'''
for provider in self.provider_manager.provider_insts:
if provider.meta().id == provider_id:
@@ -158,13 +179,13 @@ class Context:
def get_all_providers(self) -> List[Provider]:
'''
获取所有 LLM Provider。
获取所有 LLM Provider(Chat_Completion 类型)
'''
return self.provider_manager.provider_insts
def get_using_provider(self) -> Provider:
'''
获取当前使用的 LLM Provider。
获取当前使用的 LLM Provider(Chat_Completion 类型)
通过 /provider 指令切换。
'''
+3
View File
@@ -51,6 +51,9 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin):
ls = re.split(r"\s+", message_str)
if self.command_name != ls[0]:
return False
# if len(self.handler_params) == 0 and len(ls) > 1:
# # 一定程度避免 LLM 聊天时误判为指令
# return False
# params_str = message_str[len(self.command_name):].strip()
ls = ls[1:]
# 去除空字符串
+2 -2
View File
@@ -28,7 +28,7 @@ def get_handler_or_create(handler: Awaitable, event_type: EventType, dont_add =
event_type=event_type,
handler_full_name=handler_full_name,
handler_name=handler.__name__,
handler_module_str=handler.__module__,
handler_module_path=handler.__module__,
handler=handler,
event_filters=[]
)
@@ -185,7 +185,7 @@ def register_llm_tool(name: str = None):
"description": arg.description
})
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
llm_tools.add_func(llm_tool_name, args, docstring.short_description, md.handler)
llm_tools.add_func(llm_tool_name, args, docstring.description, md.handler)
logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册")
return awaitable
+3
View File
@@ -32,6 +32,9 @@ class StarMetadata:
'''Star 的根目录名'''
reserved: bool = False
'''是否是 AstrBot 的保留 Star'''
activated: bool = True
'''是否被激活'''
def __str__(self) -> str:
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"
+17 -8
View File
@@ -1,11 +1,12 @@
from __future__ import annotations
import enum
from dataclasses import dataclass
from typing import Awaitable, List, Dict
from typing import Awaitable, List, Dict, TypeVar, Generic
from .filter import HandlerFilter
from .star import star_map
class StarHandlerRegistry(List):
T = TypeVar('T', bound='StarHandlerMetadata')
class StarHandlerRegistry(Generic[T], List[T]):
'''用于存储所有的 Star Handler'''
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
@@ -16,9 +17,18 @@ class StarHandlerRegistry(List):
super().append(handler)
self.star_handlers_map[handler.handler_full_name] = handler
def get_handlers_by_event_type(self, event_type: EventType) -> List[StarHandlerMetadata]:
def get_handlers_by_event_type(self, event_type: EventType, only_activated = True) -> List[StarHandlerMetadata]:
'''通过事件类型获取 Handler'''
return [handler for handler in self if handler.event_type == event_type]
if only_activated:
return [
handler
for handler in self
if handler.event_type == event_type and
star_map[handler.handler_module_path] and
star_map[handler.handler_module_path].activated
]
else:
return [handler for handler in self if handler.event_type == event_type]
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
'''通过 Handler 的全名获取 Handler'''
@@ -26,8 +36,7 @@ class StarHandlerRegistry(List):
def get_handlers_by_module_name(self, module_name: str) -> List[StarHandlerMetadata]:
'''通过模块名获取 Handler'''
return [handler for handler in self if handler.handler_module_str == module_name]
return [handler for handler in self if handler.handler_module_path == module_name]
star_handlers_registry = StarHandlerRegistry()
@@ -55,7 +64,7 @@ class StarHandlerMetadata():
handler_name: str
'''Handler 的名字,也就是方法名'''
handler_module_str: str
handler_module_path: str
'''Handler 所在的模块路径。'''
handler: Awaitable
+102 -28
View File
@@ -1,6 +1,7 @@
import inspect
import functools
import os
import sys
import traceback
import yaml
import logging
@@ -8,15 +9,14 @@ from types import ModuleType
from typing import List
from pip import main as pip_main
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core import logger
from astrbot.core import logger, sp, pip_installer
from .context import Context
from . import StarMetadata
from .updator import PluginUpdator
from astrbot.core.utils.io import remove_dir
from .star import star_registry, star_map
from astrbot.core.provider.register import llm_tools
from .star_handler import star_handlers_registry
from astrbot.core.provider.register import llm_tools
class PluginManager:
def __init__(
@@ -27,6 +27,7 @@ class PluginManager:
self.updator = PluginUpdator(config['plugin_repo_mirror'])
self.context = context
self.context._star_manager = self # 就这样吧,不想改了
self.config = config
self.plugin_store_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins"))
@@ -91,21 +92,12 @@ class PluginManager:
plugin_path = os.path.join(plugin_dir, p)
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
pth = os.path.join(plugin_path, "requirements.txt")
logger.info(f"正在检查插件 {p} 的依赖: {pth}")
logger.info(f"正在安装插件 {p} 所需的依赖: {pth}")
try:
self._update_plugin_dept(os.path.join(plugin_path, "requirements.txt"))
pip_installer.install(requirements_path=pth)
except Exception as e:
logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}")
def _update_plugin_dept(self, path):
'''更新插件的依赖'''
args = ['install', '-r', path, '--trusted-host', 'mirrors.aliyun.com', '-i', 'https://mirrors.aliyun.com/pypi/simple/']
if self.config.pip_install_arg:
args.extend(self.config.pip_install_arg)
result_code = pip_main(args)
if result_code != 0:
raise Exception(str(result_code))
def _load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> StarMetadata:
'''v3.4.0 以前的方式载入插件元数据
@@ -136,15 +128,29 @@ class PluginManager:
return metadata
def reload(self):
async def reload(self):
'''扫描并加载所有的 Star'''
for smd in star_registry:
logger.debug(f"尝试终止插件 {smd.name} ...")
if hasattr(smd.star_cls, "__del__"):
smd.star_cls.__del__()
star_handlers_registry.clear()
star_handlers_registry.star_handlers_map.clear()
star_map.clear()
star_registry.clear()
for key in list(sys.modules.keys()):
if key.startswith("data.plugins") or key.startswith("packages"):
del sys.modules[key]
plugin_modules = self._get_plugin_modules()
if plugin_modules is None:
return False, "未找到任何插件模块"
fail_rec = ""
inactivated_plugins: list = sp.get("inactivated_plugins", [])
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
# 导入 Star 模块,并尝试实例化 Star 类
for plugin_module in plugin_modules:
try:
@@ -171,21 +177,24 @@ class PluginManager:
if path in star_map:
# 通过装饰器的方式注册插件
star_metadata = star_map[path]
star_metadata.star_cls = star_metadata.star_cls_type(context=self.context)
star_metadata.module = module
star_metadata.root_dir_name = root_dir_name
star_metadata.reserved = reserved
metadata = star_map[path]
metadata.star_cls = metadata.star_cls_type(context=self.context)
metadata.module = module
metadata.root_dir_name = root_dir_name
metadata.reserved = reserved
related_handlers = star_handlers_registry.get_handlers_by_module_name(star_metadata.module_path)
related_handlers = star_handlers_registry.get_handlers_by_module_name(metadata.module_path)
for handler in related_handlers:
logger.debug(f"bind handler {handler.handler_name} to {star_metadata.name}")
logger.debug(f"bind handler {handler.handler_name} to {metadata.name}")
# handler.handler.__self__ = star_metadata.star_cls # 绑定 handler 的 self
handler.handler = functools.partial(handler.handler, star_metadata.star_cls)
handler.handler = functools.partial(handler.handler, metadata.star_cls)
# llm_tool
for func_tool in llm_tools.func_list:
if func_tool.handler.__module__ == star_metadata.module_path:
func_tool.handler = functools.partial(func_tool.handler, star_metadata.star_cls)
if func_tool.handler.__module__ == metadata.module_path:
func_tool.handler_module_path = metadata.module_path
func_tool.handler = functools.partial(func_tool.handler, metadata.star_cls)
if func_tool.name in inactivated_llm_tools:
func_tool.active = False
else:
# v3.4.0 以前的方式注册插件
@@ -209,6 +218,13 @@ class PluginManager:
star_map[path] = metadata
star_registry.append(metadata)
logger.debug(f"插件 {root_dir_name} 载入成功。")
if metadata.module_path in inactivated_plugins:
metadata.activated = False
# 执行 initialize 函数
if hasattr(metadata.star_cls, "initialize"):
await metadata.star_cls.initialize()
except BaseException as e:
traceback.print_exc()
@@ -225,10 +241,11 @@ class PluginManager:
async def install_plugin(self, repo_url: str):
plugin_path = await self.updator.install(repo_url)
self._check_plugin_dept_update()
# reload the plugin
await self.reload()
return plugin_path
def uninstall_plugin(self, plugin_name: str):
async def uninstall_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
if not plugin:
raise Exception("插件不存在。")
@@ -237,10 +254,26 @@ class PluginManager:
root_dir_name = plugin.root_dir_name
ppath = self.plugin_store_path
del star_map[plugin.module_path]
# 从 star_registry 和 star_map 中删除
await self._unbind_plugin(plugin_name, plugin.module_path)
if not remove_dir(os.path.join(ppath, root_dir_name)):
raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。")
async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str):
del star_map[plugin_module_path]
for i, p in enumerate(star_registry):
if p.name == plugin_name:
del star_registry[i]
break
for handler in star_handlers_registry.get_handlers_by_module_name(plugin_module_path):
logger.debug(f"unbind handler {handler.handler_name} from {plugin_name}")
star_handlers_registry.remove(handler)
keys_to_delete = [k for k, v in star_handlers_registry.star_handlers_map.items() if k.startswith(plugin_module_path)]
for k in keys_to_delete:
v = star_handlers_registry.star_handlers_map[k]
logger.debug(f"unbind handler {v.handler_name} from {plugin_name} (map)")
del star_handlers_registry.star_handlers_map[k]
async def update_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
@@ -250,6 +283,46 @@ class PluginManager:
raise Exception("该插件是 AstrBot 保留插件,无法更新。")
await self.updator.update(plugin)
await self.reload()
async def turn_off_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
if not plugin:
raise Exception("插件不存在。")
inactivated_plugins: list = sp.get("inactivated_plugins", [])
if plugin.module_path not in inactivated_plugins:
inactivated_plugins.append(plugin.module_path)
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
# 禁用插件启用的 llm_tool
for func_tool in llm_tools.func_list:
if func_tool.handler_module_path == plugin.module_path:
func_tool.active = False
inactivated_llm_tools.append(func_tool.name)
sp.put("inactivated_plugins", inactivated_plugins)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
plugin.activated = False
async def turn_on_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
inactivated_plugins: list = sp.get("inactivated_plugins", [])
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if plugin.module_path in inactivated_plugins:
inactivated_plugins.remove(plugin.module_path)
sp.put("inactivated_plugins", inactivated_plugins)
# 启用插件启用的 llm_tool
for func_tool in llm_tools.func_list:
if func_tool.handler_module_path == plugin.module_path:
inactivated_llm_tools.remove(func_tool.name)
func_tool.active = True
sp.put("inactivated_llm_tools", inactivated_llm_tools)
plugin.activated = True
def install_plugin_from_file(self, zip_file_path: str):
desti_dir = os.path.join(self.plugin_store_path, os.path.basename(zip_file_path))
@@ -262,3 +335,4 @@ class PluginManager:
logger.warning(f"删除插件压缩包失败: {str(e)}")
self._check_plugin_dept_update()
+1 -2
View File
@@ -53,7 +53,6 @@ class PluginUpdator(RepoZipUpdator):
files = os.listdir(os.path.join(target_dir, update_dir))
for f in files:
logger.info(f"移动更新文件/目录: {f}")
if os.path.isdir(os.path.join(target_dir, update_dir, f)):
if os.path.exists(os.path.join(target_dir, f)):
shutil.rmtree(os.path.join(target_dir, f), onerror=on_error)
@@ -63,7 +62,7 @@ class PluginUpdator(RepoZipUpdator):
shutil.move(os.path.join(target_dir, update_dir, f), target_dir)
try:
logger.info(f"删除临时更新文件: {zip_path}{os.path.join(target_dir, update_dir)}")
logger.info(f"删除临时文件: {zip_path}{os.path.join(target_dir, update_dir)}")
shutil.rmtree(os.path.join(target_dir, update_dir), onerror=on_error)
os.remove(zip_path)
except BaseException:
+90
View File
@@ -0,0 +1,90 @@
import json
from astrbot.core import logger
from aiohttp import ClientSession
from typing import Dict, List, Any, AsyncGenerator
class DifyAPIClient:
def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"):
self.api_key = api_key
self.api_base = api_base
self.session = ClientSession()
self.headers = {
"Authorization": f"Bearer {self.api_key}",
}
async def chat_messages(
self,
inputs: Dict,
query: str,
user: str,
response_mode: str = "streaming",
conversation_id: str = "",
files: List[Dict[str, Any]] = [],
timeout: float = 60,
) -> AsyncGenerator[Dict[str, Any], None]:
url = f"{self.api_base}/chat-messages"
payload = locals()
payload.pop("self")
payload.pop("timeout")
async with self.session.post(
url, json=payload, headers=self.headers, timeout=timeout
) as resp:
while True:
data = await resp.content.read(8192) # 防止数据过大导致高水位报错
if not data:
break
if not data.strip():
continue
elif data.startswith(b"data:"):
try:
json_ = json.loads(data[5:])
yield json_
except BaseException:
pass
async def workflow_run(
self,
inputs: Dict,
user: str,
response_mode: str = "streaming",
files: List[Dict[str, Any]] = [],
timeout: float = 60,
):
url = f"{self.api_base}/workflows/run"
payload = locals()
payload.pop("self")
payload.pop("timeout")
async with self.session.post(
url, json=payload, headers=self.headers, timeout=timeout
) as resp:
while True:
data = await resp.content.read(8192) # 防止数据过大导致高水位报错
if not data:
break
if not data.strip():
continue
elif data.startswith(b"data:"):
try:
json_ = json.loads(data[5:])
yield json_
except BaseException:
pass
async def file_upload(
self,
file_path: str,
user: str,
) -> Dict[str, Any]:
url = f"{self.api_base}/files/upload"
payload = {
"user": user,
"file": open(file_path, "rb"),
}
async with self.session.post(
url, data=payload, headers=self.headers
) as resp:
return await resp.json() # {"id": "xxx", ...}
async def close(self):
await self.session.close()
+24 -4
View File
@@ -5,6 +5,7 @@ import socket
import time
import aiohttp
import base64
import zipfile
from PIL import Image
@@ -64,7 +65,7 @@ def save_temp_img(img: Image) -> str:
f.write(img)
return p
async def download_image_by_url(url: str, post: bool = False, post_data: dict = None) -> str:
async def download_image_by_url(url: str, post: bool = False, post_data: dict = None, path = None) -> str:
'''
下载图片, 返回 path
'''
@@ -72,10 +73,20 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict =
async with aiohttp.ClientSession() as session:
if post:
async with session.post(url, json=post_data) as resp:
return save_temp_img(await resp.read())
if not path:
return save_temp_img(await resp.read())
else:
with open(path, "wb") as f:
f.write(await resp.read())
return path
else:
async with session.get(url) as resp:
return save_temp_img(await resp.read())
if not path:
return save_temp_img(await resp.read())
else:
with open(path, "wb") as f:
f.write(await resp.read())
return path
except aiohttp.client_exceptions.ClientConnectorSSLError:
# 关闭SSL验证
ssl_context = ssl.create_default_context()
@@ -96,7 +107,9 @@ async def download_file(url: str, path: str):
'''
try:
async with aiohttp.ClientSession() as session:
async with session.get(url) as resp:
async with session.get(url, timeout=20) as resp:
if resp.status != 200:
raise Exception(f"下载文件失败: {resp.status}")
with open(path, 'wb') as f:
while True:
chunk = await resp.content.read(8192)
@@ -123,3 +136,10 @@ def get_local_ip_addresses():
finally:
s.close()
return ip
async def download_dashboard():
'''下载管理面板文件'''
dashboard_release_url = "https://astrbot-registry.lwl.lol/download/astrbot-dashboard/latest/dist.zip"
await download_file(dashboard_release_url, "data/dashboard.zip")
with zipfile.ZipFile("data/dashboard.zip", "r") as z:
z.extractall("data")
@@ -22,6 +22,9 @@ class ParameterValidationMixin:
result[param_name] = int(params[i])
else:
result[param_name] = params[i]
elif isinstance(param_type_or_default_val, str):
# 如果 param_type_or_default_val 是字符串,直接赋值
result[param_name] = params[i]
else:
result[param_name] = param_type_or_default_val(params[i])
except ValueError:
+33
View File
@@ -0,0 +1,33 @@
import logging
from pip import main as pip_main
class PipInstaller():
def __init__(self, pip_install_arg: str):
self.pip_install_arg = pip_install_arg
def install(self, package_name: str = None, requirements_path: str = None, mirror: str = None):
args = ['install']
if package_name:
args.append(package_name)
elif requirements_path:
args.extend(['-r', requirements_path])
if not mirror:
mirror = 'https://mirrors.aliyun.com/pypi/simple/'
args.extend(['--trusted-host', 'mirrors.aliyun.com', '-i', mirror])
if self.pip_install_arg:
args.extend(self.pip_install_arg.split())
print(f"Pip 包管理器: {' '.join(args)}")
result_code = pip_main(args)
# 清除 pip.main 导致的多余的 logging handlers
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
if result_code != 0:
raise Exception(f"安装失败,错误码:{result_code}")
+33
View File
@@ -0,0 +1,33 @@
import json
import os
class SharedPreferences:
def __init__(self, path="data/shared_preferences.json"):
self.path = path
self._data = self._load_preferences()
def _load_preferences(self):
if os.path.exists(self.path):
with open(self.path, "r") as f:
return json.load(f)
return {}
def _save_preferences(self):
with open(self.path, "w") as f:
json.dump(self._data, f, indent=4)
def get(self, key, default=None):
return self._data.get(key, default)
def put(self, key, value):
self._data[key] = value
self._save_preferences()
def remove(self, key):
if key in self._data:
del self._data[key]
self._save_preferences()
def clear(self):
self._data.clear()
self._save_preferences()
+1 -1
View File
@@ -111,7 +111,7 @@ class RepoZipUpdator():
releases = await self.fetch_release_info(url=release_url)
if not releases:
# download from the default branch directly.
logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,从默认分支下载。")
logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,正在从默认分支下载。")
release_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
else:
release_url = releases[0]['zipball_url']
+11 -2
View File
@@ -1,4 +1,5 @@
import asyncio
import traceback
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from .server import AstrBotDashboard
@@ -13,8 +14,16 @@ class AstrBotDashBoardLifecycle:
async def start(self):
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
await core_lifecycle.initialize()
core_task = core_lifecycle.start()
core_task = []
try:
await core_lifecycle.initialize()
core_task = core_lifecycle.start()
except Exception as e:
logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!")
logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!")
logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!")
self.dashboard_server = AstrBotDashboard(core_lifecycle, self.db)
task = asyncio.gather(core_task, self.dashboard_server.run())
+3 -1
View File
@@ -5,6 +5,7 @@ from .update import UpdateRoute
from .stat import StatRoute
from .log import LogRoute
from .static_file import StaticFileRoute
from .chat import ChatRoute
__all__ = [
@@ -14,6 +15,7 @@ __all__ = [
"UpdateRoute",
"StatRoute",
"LogRoute",
"StaticFileRoute"
"StaticFileRoute",
"ChatRoute",
]
+197
View File
@@ -0,0 +1,197 @@
import uuid
import json
import os
from .route import Route, Response, RouteContext
from astrbot.core import web_chat_queue, web_chat_back_queue
from quart import request, Response as QuartResponse, g
from astrbot.core.db import BaseDatabase
import asyncio
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
class ChatRoute(Route):
def __init__(self, context: RouteContext, db: BaseDatabase, core_lifecycle: AstrBotCoreLifecycle) -> None:
super().__init__(context)
self.routes = {
'/chat/send': ('POST', self.chat),
'/chat/new_conversation': ('GET', self.new_conversation),
'/chat/conversations': ('GET', self.get_conversations),
'/chat/get_conversation': ('GET', self.get_conversation),
'/chat/delete_conversation': ('GET', self.delete_conversation),
'/chat/get_file': ('GET', self.get_file),
'/chat/post_image': ('POST', self.post_image),
'/chat/post_file': ('POST', self.post_file),
'/chat/status': ('GET', self.status),
}
self.db = db
self.core_lifecycle = core_lifecycle
self.register_routes()
self.imgs_dir = "data/webchat/imgs"
self.supported_imgs = ['jpg', 'jpeg', 'png', 'gif', 'webp']
async def status(self):
has_llm_enabled = self.core_lifecycle.provider_manager.curr_provider_inst is not None
has_stt_enabled = self.core_lifecycle.provider_manager.curr_stt_provider_inst is not None
return Response().ok(data={
'llm_enabled': has_llm_enabled,
'stt_enabled': has_stt_enabled
}).__dict__
async def get_file(self):
filename = request.args.get('filename')
if not filename:
return Response().error("Missing key: filename").__dict__
try:
with open(os.path.join(self.imgs_dir, filename), "rb") as f:
if filename.endswith(".wav"):
return QuartResponse(f.read(), mimetype="audio/wav")
elif filename.split('.')[-1] in self.supported_imgs:
return QuartResponse(f.read(), mimetype="image/jpeg")
else:
return QuartResponse(f.read())
except FileNotFoundError:
return Response().error("File not found").__dict__
async def post_image(self):
post_data = await request.files
if 'file' not in post_data:
return Response().error("Missing key: file").__dict__
file = post_data['file']
filename = str(uuid.uuid4()) + ".jpg"
path = os.path.join(self.imgs_dir, filename)
await file.save(path)
return Response().ok(data={
'filename': filename
}).__dict__
async def post_file(self):
post_data = await request.files
if 'file' not in post_data:
return Response().error("Missing key: file").__dict__
file = post_data['file']
filename = f"{str(uuid.uuid4())}"
print(file)
# 通过文件格式判断文件类型
if file.content_type.startswith('audio'):
filename += ".wav"
path = os.path.join(self.imgs_dir, filename)
await file.save(path)
return Response().ok(data={
'filename': filename
}).__dict__
async def chat(self):
username = g.get('username', 'guest')
post_data = await request.json
if 'message' not in post_data and 'image_url' not in post_data:
return Response().error("Missing key: message or image_url").__dict__
if 'conversation_id' not in post_data:
return Response().error("Missing key: conversation_id").__dict__
message = post_data['message']
conversation_id = post_data['conversation_id']
image_url = post_data.get('image_url')
audio_url = post_data.get('audio_url')
if not message and not image_url and not audio_url:
return Response().error("Message and image_url and audio_url are empty").__dict__
if not conversation_id:
return Response().error("conversation_id is empty").__dict__
await web_chat_queue.put((username, conversation_id, {
'message': message,
'image_url': image_url, # list
'audio_url': audio_url
}))
async def stream():
ret = []
while True:
try:
result = await asyncio.wait_for(web_chat_back_queue.get(), timeout=30) # 设置超时时间为5秒
except asyncio.TimeoutError:
yield '[Error] 30 秒内没有返回数据,已放弃。\n'
return
if result is None:
break
ret.append(result)
yield result + '\n'
await asyncio.sleep(0.5)
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
try:
history = json.loads(conversation.history)
except BaseException as e:
print(e)
history = []
new_his = {
'type': 'user',
'message': message
}
if image_url:
new_his['image_url'] = image_url
if audio_url:
new_his['audio_url'] = audio_url
history.append(new_his)
for r in ret:
history.append({
'type': 'bot',
'message': r
})
self.db.update_webchat_conversation(username, conversation_id, history=json.dumps(history))
return QuartResponse(
stream(),
mimetype="text/event-stream",
headers={
"Content-Type": "text/event-stream",
"Transfer-Encoding": "chunked",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*" # 如果是跨域请求
}
)
async def delete_conversation(self):
username = g.get('username', 'guest')
conversation_id = request.args.get('conversation_id')
if not conversation_id:
return Response().error("Missing key: conversation_id").__dict__
self.db.delete_webchat_conversation(username, conversation_id)
return Response().ok().__dict__
async def new_conversation(self):
username = g.get('username', 'guest')
conversation_id = str(uuid.uuid4())
self.db.webchat_new_conversation(username, conversation_id)
return Response().ok(data={
'conversation_id': conversation_id
}).__dict__
async def get_conversations(self):
username = g.get('username', 'guest')
conversations = self.db.get_webchat_conversations(username)
return Response().ok(data=conversations).__dict__
async def get_conversation(self):
username = g.get('username', 'guest')
conversation_id = request.args.get('conversation_id')
if not conversation_id:
return Response().error("Missing key: conversation_id").__dict__
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
return Response().ok(data=conversation).__dict__
+7
View File
@@ -7,6 +7,7 @@ from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.star.config import update_config
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.register import platform_registry
def try_cast(value: str, type_: str):
if type_ == "int" and value.isdigit():
@@ -121,6 +122,12 @@ class ConfigRoute(Route):
async def _get_astrbot_config(self):
config = self.config
platform_default_tmpl = CONFIG_METADATA_2['platform_group']['metadata']['platform']['config_template']
for platform in platform_registry:
if platform.default_config_tmpl:
platform_default_tmpl[platform.name] = platform.default_config_tmpl
return {
"metadata": CONFIG_METADATA_2,
"config": config
+30 -8
View File
@@ -16,7 +16,9 @@ class PluginRoute(Route):
'/plugin/install-upload': ('POST', self.install_plugin_upload),
'/plugin/update': ('POST', self.update_plugin),
'/plugin/uninstall': ('POST', self.uninstall_plugin),
'/plugin/market_list': ('GET', self.get_online_plugins)
'/plugin/market_list': ('GET', self.get_online_plugins),
'/plugin/off': ('POST', self.off_plugin),
'/plugin/on': ('POST', self.on_plugin)
}
self.core_lifecycle = core_lifecycle
self.plugin_manager = plugin_manager
@@ -42,7 +44,8 @@ class PluginRoute(Route):
"author": plugin.author,
"desc": plugin.desc,
"version": plugin.version,
"reserved": plugin.reserved
"reserved": plugin.reserved,
"activated": plugin.activated
}
_plugin_resp.append(_t)
return Response().ok(_plugin_resp).__dict__
@@ -53,7 +56,6 @@ class PluginRoute(Route):
try:
logger.info(f"正在安装插件 {repo_url}")
await self.plugin_manager.install_plugin(repo_url)
self.core_lifecycle.restart()
logger.info(f"安装插件 {repo_url} 成功。")
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
@@ -69,7 +71,6 @@ class PluginRoute(Route):
await file.save(file_path)
self.plugin_manager.install_plugin_from_file(file_path)
logger.info(f"安装插件 {file.filename} 成功")
self.core_lifecycle.restart()
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
logger.error(traceback.format_exc())
@@ -80,7 +81,7 @@ class PluginRoute(Route):
plugin_name = post_data["name"]
try:
logger.info(f"正在卸载插件 {plugin_name}")
self.plugin_manager.uninstall_plugin(plugin_name)
await self.plugin_manager.uninstall_plugin(plugin_name)
logger.info(f"卸载插件 {plugin_name} 成功")
return Response().ok(None, "卸载成功").__dict__
except Exception as e:
@@ -93,9 +94,30 @@ class PluginRoute(Route):
try:
logger.info(f"正在更新插件 {plugin_name}")
await self.plugin_manager.update_plugin(plugin_name)
self.core_lifecycle.restart()
logger.info(f"更新插件 {plugin_name} 成功,2秒后重启")
return Response().ok(None, "更新成功,程序将在 2 秒内重启。").__dict__
logger.info(f"更新插件 {plugin_name} 成功。")
return Response().ok(None, "更新成功。").__dict__
except Exception as e:
logger.error(f"/api/extensions/update: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
async def off_plugin(self):
post_data = await request.json
plugin_name = post_data["name"]
try:
await self.plugin_manager.turn_off_plugin(plugin_name)
logger.info(f"停用插件 {plugin_name}")
return Response().ok(None, "停用成功。").__dict__
except Exception as e:
logger.error(f"/api/extensions/off: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
async def on_plugin(self):
post_data = await request.json
plugin_name = post_data["name"]
try:
await self.plugin_manager.turn_on_plugin(plugin_name)
logger.info(f"启用插件 {plugin_name}")
return Response().ok(None, "启用成功。").__dict__
except Exception as e:
logger.error(f"/api/extensions/on: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
+20 -3
View File
@@ -3,7 +3,7 @@ import traceback
from .route import Route, Response, RouteContext
from quart import request
from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger
from astrbot.core import logger, pip_installer
class UpdateRoute(Route):
def __init__(self, context: RouteContext, astrbot_updator: AstrBotUpdator) -> None:
@@ -11,6 +11,7 @@ class UpdateRoute(Route):
self.routes = {
'/update/check': ('GET', self.check_update),
'/update/do': ('POST', self.update_project),
'/update/pip-install': ('POST', self.install_pip_package)
}
self.astrbot_updator = astrbot_updator
self.register_routes()
@@ -32,6 +33,7 @@ class UpdateRoute(Route):
async def update_project(self):
data = await request.json
version = data.get('version', '')
reboot = data.get('reboot', True)
if version == "" or version == "latest":
latest = True
version = ''
@@ -39,8 +41,23 @@ class UpdateRoute(Route):
latest = False
try:
await self.astrbot_updator.update(latest=latest, version=version)
threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__
if reboot:
threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__
else:
return Response().ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。").__dict__
except Exception as e:
logger.error(f"/api/update_project: {traceback.format_exc()}")
return Response().error(e.__str__()).__dict__
async def install_pip_package(self):
data = await request.json
package = data.get('package', '')
if not package:
return Response().error("缺少参数 package 或不合法。").__dict__
try:
pip_installer.install(package)
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
logger.error(f"/api/update_pip: {traceback.format_exc()}")
return Response().error(e.__str__()).__dict__
+6 -3
View File
@@ -2,7 +2,7 @@ import logging
import jwt
import asyncio
import os
from quart import Quart, request, jsonify
from quart import Quart, request, jsonify, g
from quart.logging import default_handler
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from .routes import *
@@ -18,7 +18,6 @@ class AstrBotDashboard():
self.core_lifecycle = core_lifecycle
self.config = core_lifecycle.astrbot_config
self.data_path = os.path.abspath(os.path.join(DATAPATH, "dist"))
logger.info(f"Dashboard data path: {self.data_path}")
self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/")
self.app.json.sort_keys = False
self.app.before_request(self.auth_middleware)
@@ -32,12 +31,15 @@ class AstrBotDashboard():
self.lr = LogRoute(self.context, core_lifecycle.log_broker)
self.sfr = StaticFileRoute(self.context)
self.ar = AuthRoute(self.context)
self.chat_route = ChatRoute(self.context, db, core_lifecycle)
async def auth_middleware(self):
if not request.path.startswith("/api"):
return
if request.path == "/api/auth/login":
return
if request.path == "/api/chat/get_file":
return
# claim jwt
token = request.headers.get("Authorization")
if not token:
@@ -47,7 +49,8 @@ class AstrBotDashboard():
if token.startswith("Bearer "):
token = token[7:]
try:
jwt.decode(token, WEBUI_SK, algorithms=["HS256"])
payload = jwt.decode(token, WEBUI_SK, algorithms=["HS256"])
g.username = payload["username"]
except jwt.ExpiredSignatureError:
r = jsonify(Response().error("Token 过期").__dict__)
r.status_code = 401
+14
View File
@@ -0,0 +1,14 @@
# What's Changed
1. 修复了 reminder 插件可能不会触发回调的问题。
2. 修复了 telegram 插件不可用的问题。
3. 修复了 qq_official 无法发图的问题。
4. 修复事件监听器会让 WakeStage 失效的问题。
5. 修复 websearch 在 cmd_config 中失效的问题。
3. 支持通过 Google GenAI 访问 Gemini 模型,而不需要使用 Gemini 对 OpenAI 的兼容 API。详见文档。
4. 支持对插件禁用/启用。/plugin off/on <plugin_name>
5. 支持基于 Docker 的沙箱化代码执行器。(Beta 测试)详见文档。
6. 支持接入 Dify LLMOps 平台。详见文档。
7. 适配器类插件支持设置默认配置模板。
8. 优化了部分指令的持久化记忆。如 /tool 的禁用、/provider 的选择都将持久化保存,每次启动时不需要重新设置。
9. 优化了 glm-4v-flash 模型。其只支持一张图。
+6
View File
@@ -0,0 +1,6 @@
# What's Changed
1. 支持通过 /set <k> <v> 设置持久化的会话变量, 方便 Dify App 输入变量
2. 管理面板支持 Web Chat
3. 管理面板支持手动安装 Pip 库, 在 `控制台` 页中可找到
+9
View File
@@ -0,0 +1,9 @@
# What's Changed
- 支持接入 STT(语音转文字)Provider
- 内置支持 OpenAI Whisper API/本地运行模型。[看这里](https://astrbot.lwl.lol/use/whisper.html)
- WebChat 支持语音输入
- WebChat 支持显示当前 Provider 状态
- 优化了 WebChat 在没有消息返回时的处理方式
- 修复了 reminder 在初始化历史待办时没有正常传入 session_id 的问题
- 代码执行器在成功回复后清空文件 buffer。
+1 -3
View File
@@ -1,4 +1,2 @@
node_modules/
.DS_Store
package-lock.json
package.json
.DS_Store
+9998
View File
File diff suppressed because it is too large Load Diff
+58
View File
@@ -0,0 +1,58 @@
{
"name": "astrbot-dashboard",
"version": "1.0.0",
"private": true,
"author": "CodedThemes",
"scripts": {
"dev": "vite --host",
"build": "vue-tsc --noEmit && vite build",
"build-stage": "vue-tsc --noEmit && vite build --base=/vue/free/stage/",
"build-prod": "vue-tsc --noEmit && vite build --base=/vue/free/",
"preview": "vite preview --port 5050",
"typecheck": "vue-tsc --noEmit",
"lint": "eslint . --ext .vue,.js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --ignore-path .gitignore"
},
"dependencies": {
"@guolao/vue-monaco-editor": "^1.5.4",
"@tiptap/starter-kit": "2.1.7",
"@tiptap/vue-3": "2.1.7",
"apexcharts": "3.42.0",
"axios": "^1.6.2",
"axios-mock-adapter": "^1.22.0",
"chance": "1.1.11",
"date-fns": "2.30.0",
"js-md5": "^0.8.3",
"lodash": "4.17.21",
"marked": "^15.0.6",
"pinia": "2.1.6",
"remixicon": "3.5.0",
"vee-validate": "4.11.3",
"vite-plugin-vuetify": "1.0.2",
"vue": "3.3.4",
"vue-router": "4.2.4",
"vue3-apexcharts": "1.4.4",
"vue3-print-nb": "0.1.4",
"vuetify": "3.3.14",
"yup": "1.2.0"
},
"devDependencies": {
"@mdi/font": "7.2.96",
"@rushstack/eslint-patch": "1.3.3",
"@types/chance": "1.1.3",
"@types/node": "20.5.7",
"@vitejs/plugin-vue": "4.3.3",
"@vue/eslint-config-prettier": "8.0.0",
"@vue/eslint-config-typescript": "11.0.3",
"@vue/tsconfig": "0.4.0",
"eslint": "8.48.0",
"eslint-plugin-vue": "9.17.0",
"prettier": "3.0.2",
"sass": "1.66.1",
"sass-loader": "13.3.2",
"typescript": "5.1.6",
"vite": "4.4.9",
"vue-cli-plugin-vuetify": "2.5.8",
"vue-tsc": "1.8.8",
"vuetify-loader": "^2.0.0-alpha.9"
}
}
@@ -10,7 +10,7 @@
</v-alert>
<div style="display: flex; align-items: center; justify-content: center; gap: 16px">
<div style="width: 100%;">
<div style="width: 100%;" v-if="metadata[metadataKey].items[key]">
<v-select v-if="metadata[metadataKey].items[key]?.options && !metadata[metadataKey].items[key]?.invisible" v-model="iterable[key]"
variant="outlined" :items="metadata[metadataKey].items[key]?.options"
:label="metadata[metadataKey].items[key]?.description + '(' + key + ')'" dense :disabled="metadata[metadataKey].items[key]?.readonly"></v-select>
@@ -46,6 +46,11 @@
</div>
</div>
<div style="width: 100%;" v-else>
<!-- metadata 中没有 key -->
<v-text-field v-model="iterable[key]" :label="key" variant="outlined" dense></v-text-field>
</div>
<div
v-if="!metadata[metadataKey].items[key]?.obvious_hint && metadata[metadataKey].items[key]?.hint && metadata[metadataKey].items[key]?.type !== 'object' && !metadata[metadataKey].items[key]?.invisible">
<v-btn icon size="x-small" style="margin-bottom: 22px;">
@@ -30,6 +30,11 @@ const sidebarItem: menu[] = [
icon: 'mdi-puzzle',
to: '/extension'
},
{
title: '聊天',
icon: 'mdi-chat',
to: '/chat'
},
{
title: '控制台',
icon: 'mdi-console',
+5
View File
@@ -36,6 +36,11 @@ const MainRoutes = {
name: 'Project ATRI',
path: '/project-atri',
component: () => import('@/views/ATRIProject.vue')
},
{
name: 'Chat',
path: '/chat',
component: () => import('@/views/ChatPage.vue')
}
]
};
+504
View File
@@ -0,0 +1,504 @@
<script setup>
import axios from 'axios';
import { ref } from 'vue';
import { marked } from 'marked';
marked.setOptions({
breaks: true
});
</script>
<template>
<v-card style="margin-bottom: 16px; width: 100%; background-color: #fff; height: 100%;">
<v-card-text style="width: 100%; height: calc(100vh - 120px);">
<div style="height: 100%; display: flex; gap: 16px;">
<div style="max-width: 200px;">
<!-- conversation -->
<v-btn variant="tonal" rounded="xl" style="margin-bottom: 16px; min-width: 200px;" @click="newC"
:disabled="!currCid">+ 创建对话</v-btn>
<v-card class="mx-auto" min-width="200">
<v-list dense nav v-if="conversations.length > 0" style="max-height: 500px; overflow-y: auto;"
@update:selected="getConversationMessages">
<v-list-item v-for="(item, i) in conversations" :key="item.cid" :value="item.cid"
color="primary" rounded="xl">
<v-list-item-title>新对话</v-list-item-title>
<v-list-item-subtitle>{{ formatDate(item.updated_at) }}</v-list-item-subtitle>
</v-list-item>
</v-list>
</v-card>
<div>
<v-chip class="mt-4" color="primary" :append-icon="status?.llm_enabled ? 'mdi-check' : 'mdi-close'">
LLM
</v-chip>
<v-chip class="mt-4 ml-2" color="success" :append-icon="status?.stt_enabled ? 'mdi-check' : 'mdi-close'">
语音转文本
</v-chip>
</div>
<v-btn variant="tonal" rounded="xl"
style="position: fixed; bottom: 48px; margin-bottom: 16px; min-width: 200px;" v-if="currCid"
@click="deleteConversation(currCid)" color="error">删除此对话</v-btn>
</div>
<div style="height: 100%; width: 100%;">
<div style="height: calc(100% - 120px); overflow-y: auto; padding: 16px; " ref="messageContainer">
<div class="fade-in" v-if="messages.length == 0"
style="height: 100%; display: flex; justify-content: center; align-items: center; flex-direction: column;">
<div>
<span style="font-size: 28px;">Hello, I'm</span>
<span style="font-weight: 1000; font-size: 28px; margin-left: 8px;">AstrBot ⭐</span>
</div>
<div style="margin-top: 8px; color: #aaa;">
<span>输入</span>
<span
style="background-color: #eee; padding-left: 4px; padding-right: 4px; margin: 2px; border-radius: 4px;">/help</span>
<span>获取帮助 😊</span>
</div>
<div style="margin-top: 8px; color: #aaa;">
<span>按</span>
<span
style="background-color: #eee; padding-left: 4px; padding-right: 4px; margin: 2px; border-radius: 4px;">K</span>
<span>开始语音 🎤</span>
</div>
</div>
<div v-else style="max-height: 100%; padding: 16px; max-width: 700px; margin: 0 auto;">
<div class="fade-in" v-for="(msg, index) in messages" :key="index"
style="margin-bottom: 16px;">
<div v-if="msg.type == 'user'" style="display: flex; justify-content: flex-end;">
<div
style="padding: 12px; border-radius: 8px; background-color: rgba(94, 53, 177, 0.15)">
<span>{{ msg.message }}</span>
<div style="display: flex; gap: 8px; margin-top: 8px;"
v-if="msg.image_url && msg.image_url.length > 0">
<div v-for="(img, index) in msg.image_url" :key="index"
style="position: relative; display: inline-block;">
<img :src="img"
style="width: 100px; height: 100px; border-radius: 8px; box-shadow: 0 0 5px rgba(0, 0, 0, 0.1);" />
</div>
</div>
<!-- audio -->
<div>
<audio controls v-if="msg.audio_url && msg.audio_url.length > 0">
<source :src="msg.audio_url" type="audio/wav">
Your browser does not support the audio element.
</audio>
</div>
</div>
</div>
<div v-else style="display: flex; justify-content: flex-start; gap: 16px;">
<span style="font-size: 32px;">✨</span>
<div v-html="marked(msg.message)" class="mc" style="font-family: inherit;"></div>
</div>
</div>
</div>
</div>
<div class="fade-in" style="bottom: 16px; width: 100%; padding: 8px; ">
<div
style="width: 100%; justify-content: center; align-items: center; display: flex; flex-direction: column; margin-top: 8px;">
<v-text-field id="input-field" variant="outlined" v-model="prompt" :label="inputFieldLabel"
placeholder="Start typing..." loading clear-icon="mdi-close-circle" clearable
@click:clear="clearMessage" style="width: 100%; max-width: 850px;">
<template v-slot:loader>
<v-progress-linear :active="loadingChat" height="6"
indeterminate></v-progress-linear>
</template>
<template v-slot:append>
<v-tooltip text="发送">
<template v-slot:activator="{ props }">
<v-icon v-bind="props" @click="sendMessage" size="35"
icon="mdi-arrow-up-circle" />
</template>
</v-tooltip>
<v-tooltip text="语音输入">
<template v-slot:activator="{ props }">
<v-icon :color="isRecording ? 'error' : ''" v-bind="props"
@click="isRecording ? stopRecording() : startRecording()" size="35"
icon="mdi-record-circle" />
</template>
</v-tooltip>
</template>
</v-text-field>
<div style="display: flex; gap: 8px; margin-top: -8px;">
<div v-for="(img, index) in stagedImagesUrl" :key="index"
style="position: relative; display: inline-block;">
<img :src="img"
style="width: 50px; height: 50px; border-radius: 8px; box-shadow: 0 0 5px rgba(0, 0, 0, 0.1);" />
<v-icon @click="removeImage(index)" size="20" color="red"
style="position: absolute; top: 0; right: 0; cursor: pointer;">mdi-close-circle</v-icon>
</div>
<div style="display: inline-block; width: 50px; height: 50px;">
<div v-if="stagedAudioUrl"
style="position: relative; padding: 6px; border-radius: 8px; background-color: rgba(94, 53, 177, 0.15); display: inline-block;">
新录音
<v-icon @click="removeAudio" size="20" color="red"
style="position: absolute; top: 0; right: 0; cursor: pointer;">mdi-close-circle</v-icon>
</div>
</div>
</div>
</div>
</div>
</div>
</div>
</v-card-text>
</v-card>
</template>
<script>
export default {
name: 'ChatPage',
components: {
},
data() {
return {
prompt: '',
messages: [],
conversations: [],
currCid: '',
stagedImagesUrl: [],
loadingChat: false,
inputFieldLabel: '聊天吧!',
isRecording: false,
audioChunks: [],
stagedAudioUrl: "",
mediaRecorder: null,
status: {},
statusText: ''
}
},
mounted() {
this.checkStatus();
this.getConversations();
let inputField = document.getElementById('input-field');
inputField.addEventListener('paste', this.handlePaste);
inputField.addEventListener('keydown', function (e) {
if (e.keyCode == 13 && !e.shiftKey) {
e.preventDefault();
this.sendMessage();
}
}.bind(this));
document.addEventListener('keydown', function (e) {
if (e.keyCode == 75) {
this.isRecording ? this.stopRecording() : this.startRecording();
}
}.bind(this));
},
methods: {
removeAudio() {
this.stagedAudioUrl = null;
},
checkStatus() {
axios.get('/api/chat/status').then(response => {
console.log(response.data);
this.status = response.data.data;
}).catch(err => {
console.error(err);
});
},
async startRecording() {
const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
this.mediaRecorder = new MediaRecorder(stream);
this.mediaRecorder.ondataavailable = (event) => {
this.audioChunks.push(event.data);
};
this.mediaRecorder.start();
this.isRecording = true;
this.inputFieldLabel = "录音中,请说话...";
},
async stopRecording() {
this.isRecording = false;
this.inputFieldLabel = "聊天吧!";
this.mediaRecorder.stop();
this.mediaRecorder.onstop = async () => {
const audioBlob = new Blob(this.audioChunks, { type: 'audio/wav' });
this.audioChunks = [];
this.mediaRecorder.stream.getTracks().forEach(track => track.stop());
const formData = new FormData();
formData.append('file', audioBlob);
try {
const response = await axios.post('/api/chat/post_file', formData, {
headers: {
'Content-Type': 'multipart/form-data',
'Authorization': 'Bearer ' + localStorage.getItem('token')
}
});
const audio = response.data.data.filename;
console.log('Audio uploaded:', audio);
this.stagedAudioUrl = `/api/chat/get_file?filename=${audio}`;
} catch (err) {
console.error('Error uploading audio:', err);
}
};
},
async handlePaste(event) {
console.log('Pasting image...');
const items = event.clipboardData.items;
for (let i = 0; i < items.length; i++) {
if (items[i].type.indexOf('image') !== -1) {
const file = items[i].getAsFile();
const formData = new FormData();
formData.append('file', file);
try {
const response = await axios.post('/api/chat/post_image', formData, {
headers: {
'Content-Type': 'multipart/form-data',
'Authorization': 'Bearer ' + localStorage.getItem('token')
}
});
const img = response.data.data.filename;
this.stagedImagesUrl.push(`/api/chat/get_file?filename=${img}`);
} catch (err) {
console.error('Error uploading image:', err);
}
}
}
},
removeImage(index) {
this.stagedImagesUrl.splice(index, 1);
},
clearMessage() {
this.prompt = '';
},
getConversations() {
axios.get('/api/chat/conversations').then(response => {
this.conversations = response.data.data;
}).catch(err => {
console.error(err);
});
},
getConversationMessages(cid) {
if (!cid[0])
return;
axios.get('/api/chat/get_conversation?conversation_id=' + cid[0]).then(response => {
this.currCid = cid[0];
let message = JSON.parse(response.data.data.history);
for (let i = 0; i < message.length; i++) {
if (message[i].message.startsWith('[IMAGE]')) {
let img = message[i].message.replace('[IMAGE]', '');
message[i].message = `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
}
if (message[i].image_url && message[i].image_url.length > 0) {
for (let j = 0; j < message[i].image_url.length; j++) {
message[i].image_url[j] = `/api/chat/get_file?filename=${message[i].image_url[j]}`;
}
}
if (message[i].audio_url) {
message[i].audio_url = `/api/chat/get_file?filename=${message[i].audio_url}`;
}
}
this.messages = message;
}).catch(err => {
console.error(err);
});
},
async newConversation() {
await axios.get('/api/chat/new_conversation').then(response => {
this.currCid = response.data.data.conversation_id;
this.getConversations();
}).catch(err => {
console.error(err);
});
},
newC() {
this.currCid = '';
this.messages = [];
},
formatDate(timestamp) {
const date = new Date(timestamp * 1000); // 假设时间戳是以秒为单位
const options = {
year: 'numeric',
month: '2-digit',
day: '2-digit',
hour: '2-digit',
minute: '2-digit',
second: '2-digit',
hour12: false
};
return date.toLocaleString('zh-CN', options).replace(/\//g, '-').replace(/, /g, ' ');
},
deleteConversation(cid) {
axios.get('/api/chat/delete_conversation?conversation_id=' + cid).then(response => {
this.getConversations();
this.currCid = '';
this.messages = [];
}).catch(err => {
console.error(err);
});
},
async sendMessage() {
if (this.currCid == '') {
await this.newConversation();
}
this.messages.push({
type: 'user',
message: this.prompt,
image_url: this.stagedImagesUrl,
audio_url: this.stagedAudioUrl
});
this.scrollToBottom();
// images
let image_filenames = [];
for (let i = 0; i < this.stagedImagesUrl.length; i++) {
let img = this.stagedImagesUrl[i].replace('/api/chat/get_file?filename=', '');
image_filenames.push(img);
}
// audio
let audio_filenames = [];
if (this.stagedAudioUrl) {
let audio = this.stagedAudioUrl.replace('/api/chat/get_file?filename=', '');
audio_filenames.push(audio);
}
this.loadingChat = true;
fetch('/api/chat/send', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Authorization': 'Bearer ' + localStorage.getItem('token')
},
body: JSON.stringify({
message: this.prompt,
conversation_id: this.currCid,
image_url: image_filenames,
audio_url: audio_filenames
}) // 发送请求体
})
.then(response => {
this.prompt = '';
this.stagedImagesUrl = [];
this.stagedAudioUrl = "";
this.loadingChat = false;
const reader = response.body.getReader(); // 获取流的 Reader
const decoder = new TextDecoder();
const readStream = async () => {
const { done, value } = await reader.read(); // 读取流中的数据
if (done) {
console.log("Stream finished.");
return;
}
const chunk = decoder.decode(value, { stream: true });
// bot_resp.message.value += chunk;
console.log("!!!!", chunk);
if (chunk.startsWith('[IMAGE]')) {
let img = chunk.replace('[IMAGE]', '');
let bot_resp = {
type: 'bot',
message: `<img src="/api/chat/get_file?filename=${img}" style="max-width: 80%; border-radius: 8px; box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);"/>`
}
this.messages.push(bot_resp);
} else {
let bot_resp = {
type: 'bot',
message: chunk
}
this.messages.push(bot_resp);
}
this.scrollToBottom();
readStream(); // 递归读取流
};
readStream();
})
.catch(err => {
console.error(err);
});
},
scrollToBottom() {
this.$nextTick(() => {
const container = this.$refs.messageContainer;
container.scrollTop = container.scrollHeight;
});
}
}
}
</script>
<style>
@keyframes fadeIn {
from {
opacity: 0;
}
to {
opacity: 1;
}
}
.fade-in {
animation: fadeIn 0.2s ease-in-out;
}
.mc h1,
.mc h2,
.mc h3,
.mc h4,
.mc h5,
.mc h6 {
margin-bottom: 10px;
}
.mc li {
margin-left: 16px;
}
.mc p {
margin-top: 10px;
margin-bottom: 10px;
}
</style>
+61 -1
View File
@@ -1,5 +1,7 @@
<script setup>
import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
import axios from 'axios';
</script>
<template>
@@ -7,8 +9,34 @@ import ConsoleDisplayer from '@/components/shared/ConsoleDisplayer.vue';
<div
style="background-color: white; padding: 8px; padding-left: 16px; border-radius: 8px; margin-bottom: 16px; display: flex; flex-direction: row; align-items: center; justify-content: space-between;">
<h4>控制台</h4>
<v-dialog v-model="pipDialog" width="400">
<template v-slot:activator="{ props }">
<v-btn variant="plain" v-bind="props">安装 pip </v-btn>
</template>
<v-card>
<v-card-title>
<span class="text-h5">安装 Pip </span>
</v-card-title>
<v-card-text>
<v-text-field v-model="pipInstallPayload.package" label="*库名,如 llmtuner" variant="outlined"></v-text-field>
<v-text-field v-model="pipInstallPayload.mirror" label="镜像站链接(可选)" variant="outlined"></v-text-field>
<small>如果不填镜像站链接默认使用阿里云镜像https://mirrors.aliyun.com/pypi/simple/</small>
<div>
<small>{{ status }}</small>
</div>
</v-card-text>
<v-card-actions>
<v-spacer></v-spacer>
<v-btn color="blue-darken-1" variant="text" @click="pipInstall" :loading="loading">
安装
</v-btn>
</v-card-actions>
</v-card>
</v-dialog>
</div>
<ConsoleDisplayer style="height: calc(100vh - 160px); "/>
<ConsoleDisplayer style="height: calc(100vh - 160px); " />
</div>
</template>
<script>
@@ -17,6 +45,36 @@ export default {
components: {
ConsoleDisplayer
},
data() {
return {
pipDialog: false,
pipInstallPayload: {
package: '',
mirror: ''
},
loading: false,
status: ''
}
},
methods: {
pipInstall() {
this.loading = true;
axios.post('/api/update/pip-install', this.pipInstallPayload)
.then(res => {
this.status = res.data.message;
setTimeout(() => {
this.status = '';
this.pipDialog = false;
}, 2000);
})
.catch(err => {
this.status = err.response.data.message;
}).finally(() => {
this.loading = false;
});
}
}
}
</script>
@@ -26,10 +84,12 @@ export default {
from {
opacity: 0;
}
to {
opacity: 1;
}
}
.fade-in {
animation: fadeIn 0.2s ease-in-out;
}
+36 -4
View File
@@ -9,8 +9,8 @@ import axios from 'axios';
<template>
<v-row>
<v-alert style="margin: 16px" text="1. 如果因为网络问题安装失败,可以前往 配置->其他配置->插件仓库镜像 修改安装镜像源。2. 如需插件帮助请点击 `仓库` 查看 README"
title="💡提示" type="info" variant="tonal">
<v-alert style="margin: 16px" text="1. 如果因为网络问题安装失败,可以自行前往仓库下载压缩包,然后从本地上传。2. 如需插件帮助请点击 `仓库` 查看 README"
title="💡提示" type="info" variant="tonal">
</v-alert>
<v-col cols="12" md="12">
<div style="background-color: white; width: 100%; padding: 16px; border-radius: 10px;">
@@ -29,7 +29,9 @@ import axios from 'axios';
<v-btn variant="plain" @click="updateExtension(extension.name)">更新</v-btn>
<v-btn variant="plain" @click="uninstallExtension(extension.name)">卸载</v-btn>
</div>
<span v-else>保留插件</span>
<!-- <span v-else>保留插件</span> -->
<v-btn variant="plain" v-if="extension.activated" @click="pluginOff(extension)">禁用</v-btn>
<v-btn variant="plain" v-else @click="pluginOn(extension)">启用</v-btn>
</div>
</ExtensionCard>
</v-col>
@@ -78,7 +80,7 @@ import axios from 'axios';
</v-card>
</v-dialog>
<v-dialog v-model="dialog" persistent width="700">
<v-dialog v-model="dialog" width="700">
<template v-slot:activator="{ props }">
<v-btn v-bind="props" icon="mdi-plus" size="x-large" style="position: fixed; right: 52px; bottom: 52px;"
color="darkprimary">
@@ -329,6 +331,36 @@ export default {
this.toast(err, "error");
});
},
pluginOn(extension) {
axios.post('/api/plugin/on',
{
name: extension.name
}).then((res) => {
if (res.data.status === "error") {
this.toast(res.data.message, "error");
return;
}
this.toast(res.data.message, "success");
this.getExtensions();
}).catch((err) => {
this.toast(err, "error");
});
},
pluginOff(extension) {
axios.post('/api/plugin/off',
{
name: extension.name
}).then((res) => {
if (res.data.status === "error") {
this.toast(res.data.message, "error");
return;
}
this.toast(res.data.message, "success");
this.getExtensions();
}).catch((err) => {
this.toast(err, "error");
});
},
openExtensionConfig(extension_name) {
this.curr_namespace = extension_name;
this.configDialog = true;
+14 -20
View File
@@ -1,13 +1,12 @@
import os
import asyncio
import sys
import mimetypes
import aiohttp
import zipfile
from astrbot.dashboard import AstrBotDashBoardLifecycle
from astrbot.core import db_helper
from astrbot.core import logger, LogManager, LogBroker
from astrbot.core.config.default import VERSION
from astrbot.core.utils.io import download_dashboard
# add parent path to sys.path
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
@@ -39,25 +38,20 @@ def check_env():
async def check_dashboard_files():
'''下载管理面板文件'''
if os.path.exists("data/dist"):
return
dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip"
logger.info("开始下载管理面板文件...")
async with aiohttp.ClientSession() as session:
async with session.get(dashboard_release_url) as resp:
if resp.status != 200:
logger.error(f"下载管理面板文件失败: {resp.status}")
with open("data/dashboard.zip", "wb") as f:
f.write(await resp.read())
logger.info("管理面板文件下载完成。")
ok = True
if not ok:
logger.critical("下载管理面板文件失败")
if os.path.exists("data/dist/assets/version"):
with open("data/dist/assets/version", "r") as f:
if f.read() != VERSION:
logger.warning("检测到管理面板有更新。可以使用 /dashboard update 命令更新。")
return
# unzip
with zipfile.ZipFile("data/dashboard.zip", "r") as z:
z.extractall("data")
logger.info("开始下载管理面板文件...")
try:
await download_dashboard()
except Exception as e:
logger.critical(f"下载管理面板文件失败: {e}")
return
logger.info("管理面板下载完成。")
if __name__ == "__main__":
+108 -20
View File
@@ -3,8 +3,9 @@ import datetime
import astrbot.api.star as star
import astrbot.api.event.filter as filter
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.api import personalities
from astrbot.api import personalities, sp
from astrbot.api.provider import Personality, ProviderRequest
from astrbot.core.utils.io import download_dashboard
from typing import Union
@@ -16,6 +17,8 @@ class Main(star.Star):
self.prompt_prefix = cfg['provider_settings']['prompt_prefix']
self.identifier = cfg['provider_settings']['identifier']
self.enable_datetime = cfg['provider_settings']["datetime_system_prompt"]
self.kdb_enabled = False
async def _query_astrbot_notice(self):
try:
@@ -42,6 +45,7 @@ class Main(star.Star):
/deop <admin_id>: 取消管理员
/wl <sid>: 添加会话白名单
/dwl <sid>: 删除会话白名单
/dashboard update: 更新管理面板
[大模型]
/provider: 查看、切换大模型提供商
@@ -52,6 +56,10 @@ class Main(star.Star):
/persona: 情境人格设置
/tool ls: 查看、激活、停用当前注册的函数工具
[其他]
/set <变量名> <值>: 为当前会话定义一个变量。适用于 Dify 工作流输入。
/unset <变量名>: 删除当前会话的变量。
提示:如果要查看插件指令,请输入 /plugin 查看具体信息。
{notice}"""
@@ -87,24 +95,41 @@ class Main(star.Star):
event.set_result(MessageEventResult().message(f"停用工具 {tool_name} 失败,未找到此工具。"))
@filter.command("plugin")
async def plugin(self, event: AstrMessageEvent, oper: str = None):
if oper is None:
async def plugin(self, event: AstrMessageEvent, oper1: str = None, oper2: str = None):
if oper1 is None:
plugin_list_info = "已加载的插件:\n"
for plugin in self.context.get_all_stars():
plugin_list_info += f"- `{plugin.name}` By {plugin.author}: {plugin.desc}\n"
if plugin_list_info.strip() == "":
plugin_list_info = "没有加载任何插件。"
plugin_list_info += "\n使用 /plugin <插件名> 查看插件帮助。"
plugin_list_info += "\n使用 /plugin <插件名> 查看插件帮助。\n使用 /plugin on/off <插件名> 启用或者禁用插件。"
event.set_result(MessageEventResult().message(f"{plugin_list_info}").use_t2i(False))
else:
plugin = self.context.get_registered_star(oper)
if plugin is None:
event.set_result(MessageEventResult().message("未找到此插件。"))
if oper1 == "off":
# 禁用插件
if oper2 is None:
event.set_result(MessageEventResult().message("/plugin off <插件名> 禁用插件。"))
return
await self.context._star_manager.turn_off_plugin(oper2)
event.set_result(MessageEventResult().message(f"插件 {oper2} 已禁用。"))
elif oper1 == "on":
# 启用插件
if oper2 is None:
event.set_result(MessageEventResult().message("/plugin on <插件名> 启用插件。"))
return
await self.context._star_manager.turn_on_plugin(oper2)
event.set_result(MessageEventResult().message(f"插件 {oper2} 已启用。"))
else:
help_msg = plugin.star_cls.__doc__ if plugin.star_cls.__doc__ else "该插件未提供帮助信息"
ret = f"插件 {oper} 帮助信息:\n" + help_msg
event.set_result(MessageEventResult().message(ret).use_t2i(False))
# 获取插件帮助
plugin = self.context.get_registered_star(oper1)
if plugin is None:
event.set_result(MessageEventResult().message("未找到此插件。"))
else:
help_msg = plugin.star_cls.__doc__ if plugin.star_cls.__doc__ else "该插件未提供帮助信息"
ret = f"插件 {oper1} 帮助信息:\n" + help_msg
event.set_result(MessageEventResult().message(ret).use_t2i(False))
@filter.command("t2i")
async def t2i(self, event: AstrMessageEvent):
@@ -167,8 +192,9 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
if idx is None:
ret = "## 当前载入的 LLM 提供商\n"
for idx, llm in enumerate(self.context.get_all_providers()):
ret += f"{idx + 1}. {llm.meta().id} ({llm.meta().model})"
if self.provider == llm:
id_ = llm.meta().id
ret += f"{idx + 1}. {id_} ({llm.meta().model})"
if self.context.get_using_provider().meta().id == id_:
ret += " (当前使用)"
ret += "\n"
@@ -178,9 +204,12 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
if idx > len(self.context.get_all_providers()) or idx < 1:
event.set_result(MessageEventResult().message("无效的序号。"))
self.context.provider_manager.curr_provider_inst = self.context.get_all_providers()[idx - 1]
provider = self.context.get_all_providers()[idx - 1]
id_ = provider.meta().id
self.context.provider_manager.curr_provider_inst = provider
sp.put("curr_provider", id_)
event.set_result(MessageEventResult().message(f"成功切换到 {self.context.provider_manager.curr_provider_inst.meta().id}"))
event.set_result(MessageEventResult().message(f"成功切换到 {id_}"))
@filter.command("reset")
async def reset(self, message: AstrMessageEvent):
@@ -289,7 +318,7 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
- 重置 LLM 会话(保留人格): /reset p
【当前人格】: {str(self.context.get_using_provider().curr_personality['prompt'])}
"""))
""").use_t2i(False))
elif l[1] == "list":
msg = "人格列表:\n"
for key in personalities.keys():
@@ -318,6 +347,13 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
name="自定义人格", prompt=ps)
message.set_result(
MessageEventResult().message(f"人格已设置。 \n人格信息: {ps}"))
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("dashboard_update")
async def update_dashboard(self, event: AstrMessageEvent):
yield event.plain_result("正在尝试更新管理面板...")
await download_dashboard()
yield event.plain_result("管理面板更新完成。")
@filter.on_llm_request()
async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest):
@@ -333,8 +369,60 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
req.system_prompt += f"\nCurrent datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}"
if provider.curr_personality['prompt']:
req.system_prompt += f"\n{provider.curr_personality['prompt']}"
@filter.event_message_type(filter.EventMessageType.OTHER_MESSAGE)
async def other_message(self, event: AstrMessageEvent):
print("triggered")
event.stop_event()
@filter.command("set")
async def set_variable(self, event: AstrMessageEvent, key: str, value: str):
session_id = event.get_session_id()
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
session_var[key] = value
session_vars[session_id] = session_var
sp.put("session_variables", session_vars)
yield event.plain_result(f"会话 {session_id} 变量 {key} 存储成功。")
@filter.command("unset")
async def unset_variable(self, event: AstrMessageEvent, key: str):
session_id = event.get_session_id()
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
if key not in session_var:
yield event.plain_result("没有那个变量名。")
else:
del session_var[key]
sp.put("session_variables", session_vars)
yield event.plain_result(f"会话 {session_id} 变量 {key} 移除成功。")
@filter.command_group("kdb")
def kdb(self):
pass
@kdb.command("on")
async def on_kdb(self, event: AstrMessageEvent):
self.kdb_enabled = True
curr_kdb_name = self.context.provider_manager.curr_kdb_name
if not curr_kdb_name:
yield event.plain_result("未载入任何知识库")
else:
yield event.plain_result(f"知识库已打开。当前载入的知识库: {curr_kdb_name}")
@kdb.command("off")
async def off_kdb(self, event: AstrMessageEvent):
self.kdb_enabled = False
yield event.plain_result("知识库已关闭")
@filter.on_llm_request()
async def on_llm_response(self, event: AstrMessageEvent, req: ProviderRequest):
curr_kdb_name = self.context.provider_manager.curr_kdb_name
if self.kdb_enabled and curr_kdb_name:
mgr = self.context.knowledge_db_manager
results = await mgr.retrive_records(curr_kdb_name, req.prompt)
if results:
req.system_prompt += "\nHere are documents that related to user's query: \n"
for result in results:
req.system_prompt += f"- {result}\n"
+396
View File
@@ -0,0 +1,396 @@
import os
import json
import shutil
import aiohttp
import uuid
import asyncio
import re
import astrbot.api.star as star
import aiodocker
from collections import defaultdict
from astrbot.api.event import AstrMessageEvent, MessageEventResult
from astrbot.api import llm_tool, logger
from astrbot.api.event import filter
from astrbot.api.provider import ProviderRequest
from astrbot.api.message_components import Image, File
PROMPT = """
## Task
You need to generate python codes to solve user's problem: {prompt}
{extra_input}
## Limit
1. Available libraries:
- standard libs
- `Pillow`
- `requests`
- `numpy`
- `matplotlib`
- `scipy`
- `scikit-learn`
- `beautifulsoup4`
- `pandas`
- `opencv-python`
- `python-docx`
- `python-pptx`
- `pymupdf` (Do not use fpdf, reportlab, etc.)
- `mplfonts`
You can only use these libraries and the libraries that they depend on.
2. Do not generate malicious code.
3. Use given `shared.api` package to output the result.
It has 3 functions: `send_text(text: str)`, `send_image(image_path: str)`, `send_file(file_path: str)`.
For Image and file, you must save it to `output` folder.
4. You must only output the code, do not output the result of the code and any other information.
5. The output language is same as user's input language.
6. Please first provide relevant knowledge about user's problem appropriately.
## Example
1. User's problem: `please solve the fabonacci sequence problem.`
Output:
```python
from shared.api import send_text, send_image, send_file
def fabonacci(n):
if n <= 1:
return n
else:
return fabonacci(n-1) + fabonacci(n-2)
result = fabonacci(10)
send_text("The fabonacci sequence is a series of numbers in which each number is the sum of the two preceding ones, starting from 0 and 1.")
send_text("Let's calculate the fabonacci sequence of 10: " + result) # send_text is a function to send pure text to user
```
2. User's problem: `please draw a sin(x) function.`
Output:
```python
from shared.api import send_text, send_image, send_file
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(0, 2*np.pi, 100)
y = np.sin(x)
plt.plot(x, y)
plt.savefig("output/sin_x.png")
send_text("The sin(x) is a periodic function with a period of 2π, and the value range is [-1, 1]. The following is the image of sin(x).")
send_image("output/sin_x.png") # send_image is a function to send image to user
send_text("If you need more information, please let me know :)")
```
{extra_prompt}
"""
DEFAULT_CONFIG = {
"sandbox": {
"image": "soulter/astrbot-code-interpreter-sandbox",
"docker_mirror": "", # cjie.eu.org
}
}
PATH = "data/config/python_interpreter.json"
@star.register(name="astrbot-python-interpreter", desc="Python 代码执行器", author="Soulter", version="0.0.1")
class Main(star.Star):
'''基于 Docker 沙箱的 Python 代码执行器'''
def __init__(self, context: star.Context) -> None:
self.context = context
self.curr_dir = os.path.dirname(os.path.abspath(__file__))
self.workplace_path = os.path.join(self.curr_dir, "workplace")
self.shared_path = os.path.join(self.curr_dir, "shared")
os.makedirs(self.workplace_path, exist_ok=True)
self.user_file_msg_buffer = defaultdict(list)
'''存放用户上传的文件'''
# 加载配置
if not os.path.exists(PATH):
self.config = DEFAULT_CONFIG
self._save_config()
else:
with open(PATH, "r") as f:
self.config = json.load(f)
async def initialize(self):
ok = await self.is_docker_available()
if not ok:
logger.warning("Docker 不可用,代码解释器将无法使用,astrbot-python-interpreter 将自动禁用。")
await self.context._star_manager.turn_off_plugin("astrbot-python-interpreter")
async def file_upload(self, file_path: str):
'''
上传图像文件到 S3
'''
ext = os.path.splitext(file_path)[1]
S3_URL = "https://s3.neko.soulter.top/astrbot-s3"
with open(file_path, "rb") as f:
file = f.read()
s3_file_url = f"{S3_URL}/{uuid.uuid4().hex}{ext}"
async with aiohttp.ClientSession(headers = {"Accept": "application/json"}) as session:
async with session.put(s3_file_url, data=file) as resp:
if resp.status != 200:
raise Exception(f"Failed to upload image: {resp.status}")
return s3_file_url
async def is_docker_available(self) -> bool:
'''Check if docker is available'''
try:
docker = aiodocker.Docker()
await docker.version()
return True
except aiodocker.exceptions.DockerError as e:
logger.error(f"检查 Docker 可用性时出现问题: {e}")
return False
async def get_image_name(self) -> str:
'''Get the image name'''
if self.config["sandbox"]["docker_mirror"]:
return f"{self.config['sandbox']['docker_mirror']}/{self.config['sandbox']['image']}"
return self.config["sandbox"]["image"]
async def _save_config(self):
with open(PATH, "w") as f:
json.dump(self.config, f)
async def gen_magic_code(self) -> str:
return uuid.uuid4().hex[:8]
async def download_image(self, image_url: str, workplace_path: str, filename: str) -> str:
'''Download image from url to workplace_path'''
async with aiohttp.ClientSession() as session:
async with session.get(image_url) as resp:
if resp.status != 200:
return ""
image_path = os.path.join(workplace_path, f"{filename}.jpg")
with open(image_path, 'wb') as f:
f.write(await resp.read())
return f"{filename}.jpg"
async def tidy_code(self, code: str) -> str:
'''Tidy the code'''
pattern = r"```(?:py|python)?\n(.*?)\n```"
match = re.search(pattern, code, re.DOTALL)
if match is None:
raise ValueError("The code is not in the code block.")
return match.group(1)
@filter.event_message_type(filter.EventMessageType.ALL)
async def on_message(self, event: AstrMessageEvent):
'''处理消息'''
for comp in event.message_obj.message:
if isinstance(comp, File):
self.user_file_msg_buffer[event.get_session_id()].append(comp.file)
logger.debug(f"User uploaded file: {comp.file}")
break # 一个消息中,文件只能有一个,这里直接 break 减少计算量。
@filter.on_llm_request()
async def on_llm_req(self, event: AstrMessageEvent, request: ProviderRequest):
if event.get_session_id() in self.user_file_msg_buffer:
files = self.user_file_msg_buffer[event.get_session_id()]
request.prompt += f"\nUser provided files: {files}"
@filter.command_group("pi")
def pi(self):
pass
@pi.command("mirror")
async def pi_mirror(self, event: AstrMessageEvent, url: str = ""):
'''Docker 镜像地址'''
if not url:
yield event.plain_result(f"""当前 Docker 镜像地址: {self.config['sandbox']['docker_mirror']}
使用 `pi mirror <url>` 来设置 Docker 镜像地址。
您所设置的 Docker 镜像地址将会自动加在 Docker 镜像名前。如: `soulter/astrbot-code-interpreter-sandbox` -> `cjie.eu.org/soulter/astrbot-code-interpreter-sandbox`。
""")
else:
self.config["sandbox"]["docker_mirror"] = url
await self._save_config()
yield event.plain_result("设置 Docker 镜像地址成功。")
@pi.command("repull")
async def pi_repull(self, event: AstrMessageEvent):
'''重新拉取沙箱镜像'''
docker = aiodocker.Docker()
image_name = await self.get_image_name()
try:
await docker.images.get(image_name)
await docker.images.delete(image_name, force=True)
except aiodocker.exceptions.DockerError:
pass
await docker.images.pull(image_name)
yield event.plain_result("重新拉取沙箱镜像成功。")
@llm_tool("python_interpreter")
async def python_interpreter(self, event: AstrMessageEvent):
'''Use this tool only if user really want to solve a complex problem and the problem can be solved very well by Python code.
For example, user can use this tool to solve math problems, edit image, docx, pptx, pdf, etc.
'''
if not await self.is_docker_available():
yield event.plain_result("Docker 在当前机器不可用,无法沙箱化执行代码。")
plain_text = event.message_str
# 创建必要的工作目录和幻术码
magic_code = await self.gen_magic_code()
workplace_path = os.path.join(self.workplace_path, magic_code)
output_path = os.path.join(workplace_path, "output")
os.makedirs(workplace_path, exist_ok=True)
os.makedirs(output_path, exist_ok=True)
# 图片
images = []
idx = 1
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
if image_url.startswith("http"):
image_path = await self.download_image(image_url, workplace_path, f"img_{idx}")
if image_path:
images.append(image_path)
idx += 1
# 文件
files = []
for file_path in self.user_file_msg_buffer[event.get_session_id()]:
# cp
file_name = os.path.basename(file_path)
shutil.copy(file_path, os.path.join(workplace_path, file_name))
files.append(file_name)
logger.debug(f"user query: {plain_text}, images: {images}, files: {files}")
# 整理额外输入
extra_inputs = ""
if images:
extra_inputs += f"User provided images: {images}\n"
if files:
extra_inputs += f"User provided files: {files}\n"
obs = ""
n = 5
for i in range(n):
if i > 0:
logger.info(f"Try {i+1}/{n}")
PROMPT_ = PROMPT.format(
prompt=plain_text,
extra_input=extra_inputs,
extra_prompt=obs,
)
provider = self.context.get_using_provider()
llm_response = await provider.text_chat(prompt=PROMPT_, session_id=f"{event.session_id}_{magic_code}_{str(i)}")
logger.debug("code interpreter llm gened code:" + llm_response.completion_text)
# 整理代码并保存
code_clean = await self.tidy_code(llm_response.completion_text)
with open(os.path.join(workplace_path, "exec.py"), "w") as f:
f.write(code_clean)
# 启动容器
docker = aiodocker.Docker()
# 检查有没有image
image_name = await self.get_image_name()
try:
await docker.images.get(image_name)
except aiodocker.exceptions.DockerError:
# 拉取镜像
logger.info(f"未找到沙箱镜像,正在尝试拉取 {image_name}...")
await docker.images.pull(image_name)
yield event.plain_result(f"使用沙箱执行代码中,请稍等...(尝试次数: {i+1}/{n})")
container = await docker.containers.run({
"Image": image_name,
"Cmd": ["python", "exec.py"],
"Memory": 512 * 1024 * 1024,
"NanoCPUs": 1000000000,
"HostConfig": {
"Binds": [
f"{self.shared_path}:/astrbot_sandbox/shared:ro",
f"{output_path}:/astrbot_sandbox/output:rw",
f"{workplace_path}:/astrbot_sandbox:rw",
]
},
"Env": [
f"MAGIC_CODE={magic_code}"
],
"AutoRemove": True
})
logger.debug(f"Container {container.id} created.")
logs = await self.run_container(container)
logger.debug(f"Container {container.id} finished.")
logger.debug(f"Container {container.id} logs: {logs}")
# 发送结果
pattern = r"\[ASTRBOT_(TEXT|IMAGE|FILE)_OUTPUT#\w+\]: (.*)"
ok = False
traceback = ""
for idx, log in enumerate(logs):
match = re.match(pattern, log)
if match:
ok = True
if match.group(1) == "TEXT":
yield event.plain_result(match.group(2))
elif match.group(1) == "IMAGE":
image_path = os.path.join(workplace_path, match.group(2))
logger.debug(f"Sending image: {image_path}")
yield event.image_result(image_path)
elif match.group(1) == "FILE":
file_path = os.path.join(workplace_path, match.group(2))
logger.debug(f"Sending file: {file_path}")
file_s3_url = await self.file_upload(file_path)
logger.info(f"文件上传到 AstrBot 云节点: {file_s3_url}")
file_name = os.path.basename(file_path)
chain = [File(name=file_name, file=file_s3_url)]
yield event.set_result(MessageEventResult(chain=chain))
elif "Traceback (most recent call last)" in log \
or "[Error]: " in log:
traceback = "\n".join(logs[idx:])
if not ok:
if traceback:
obs = f"## Observation \n When execute the code: ```python\n{code_clean}\n```\n\n Error occured:\n\n{traceback}\n Need to improve/fix the code."
else:
logger.warning(f"未从沙箱输出中捕获到合法的输出。沙箱输出日志: {logs}")
break
else:
# 成功了
self.user_file_msg_buffer.pop(event.get_session_id())
return
yield event.plain_result("经过多次尝试后,未从沙箱输出中捕获到合法的输出,请更换问法或者查看日志。")
@pi.command("cleanfile")
async def pi_cleanfile(self, event: AstrMessageEvent):
'''清理用户上传的文件'''
for file in self.user_file_msg_buffer[event.get_session_id()]:
try:
os.remove(file)
except BaseException as e:
logger.error(f"删除文件 {file} 失败: {e}")
self.user_file_msg_buffer.pop(event.get_session_id())
yield event.plain_result(f"用户 {event.get_session_id()} 上传的文件已清理。")
async def run_container(self, container: aiodocker.docker.DockerContainer, timeout: int = 20) -> list[str]:
'''Run the container and get the output'''
try:
await container.wait(timeout=timeout)
logs = await container.log(stdout=True, stderr=True)
return logs
except asyncio.TimeoutError:
logger.warning(f"Container {container.id} timeout.")
await container.kill()
return [f"[Error]: Container has been killed due to timeout ({timeout}s)."]
finally:
await container.delete()
@@ -0,0 +1 @@
aiodocker
+18
View File
@@ -0,0 +1,18 @@
import os
def _get_magic_code():
'''防止注入攻击'''
return os.getenv("MAGIC_CODE")
def send_text(text: str):
print(f"[ASTRBOT_TEXT_OUTPUT#{_get_magic_code()}]: {text}")
def send_image(image_path: str):
if not os.path.exists(image_path):
raise Exception(f"Image file not found: {image_path}")
print(f"[ASTRBOT_IMAGE_OUTPUT#{_get_magic_code()}]: {image_path}")
def send_file(file_path: str):
if not os.path.exists(file_path):
raise Exception(f"File not found: {file_path}")
print(f"[ASTRBOT_FILE_OUTPUT#{_get_magic_code()}]: {file_path}")
+27 -4
View File
@@ -31,9 +31,21 @@ class Main(star.Star):
if "datetime" in reminder:
if self.check_is_outdated(reminder):
continue
self.scheduler.add_job(self._reminder_callback, 'date', args=[reminder["text"], reminder], run_date=datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M"))
self.scheduler.add_job(
self._reminder_callback,
trigger='date',
args=[group, reminder],
run_date=datetime.datetime.strptime(reminder["datetime"], "%Y-%m-%d %H:%M"),
misfire_grace_time=60
)
elif "cron" in reminder:
self.scheduler.add_job(self._reminder_callback, 'cron', args=[reminder["text"], reminder], **self._parse_cron_expr(reminder["cron"]))
self.scheduler.add_job(
self._reminder_callback,
trigger='cron',
args=[group, reminder],
misfire_grace_time=60,
**self._parse_cron_expr(reminder["cron"])
)
def check_is_outdated(self, reminder: dict):
'''Check if the reminder is outdated.'''
@@ -75,14 +87,25 @@ class Main(star.Star):
if cron_expression:
d = { "text": text, "cron": cron_expression, "cron_h": human_readable_cron }
self.reminder_data[event.unified_msg_origin].append(d)
self.scheduler.add_job(self._reminder_callback, 'cron', **self._parse_cron_expr(cron_expression), args=[event.unified_msg_origin, d])
self.scheduler.add_job(
self._reminder_callback,
'cron',
misfire_grace_time=60,
**self._parse_cron_expr(cron_expression), args=[event.unified_msg_origin, d]
)
if human_readable_cron:
reminder_time = f"{human_readable_cron}(Cron: {cron_expression})"
else:
d = { "text": text, "datetime": datetime_str }
self.reminder_data[event.unified_msg_origin].append(d)
datetime_scheduled = datetime.datetime.strptime(datetime_str, "%Y-%m-%d %H:%M")
self.scheduler.add_job(self._reminder_callback, 'date', args=[event.unified_msg_origin, d], run_date=datetime_scheduled)
self.scheduler.add_job(
self._reminder_callback,
'date',
args=[event.unified_msg_origin, d],
run_date=datetime_scheduled,
misfire_grace_time=60
)
reminder_time = datetime_str
await self._save_data()
yield event.plain_result("成功设置待办事项。\n内容: " + text + "\n时间: " + reminder_time + "\n\n使用 /reminder ls 查看所有待办事项。")
+9
View File
@@ -22,6 +22,15 @@ class Main(star.Star):
self.sogo_search = Sogo()
self.google = Google()
async def initialize(self):
websearch = self.context.get_config()['provider_settings']['web_search']
if websearch:
self.context.activate_llm_tool("web_search")
self.context.activate_llm_tool("fetch_url")
else:
self.context.deactivate_llm_tool("web_search")
self.context.deactivate_llm_tool("fetch_url")
async def _tidy_text(self, text: str) -> str:
'''清理文本,去除空格、换行符等'''
return text.strip().replace("\n", " ").replace("\r", " ").replace(" ", " ")
+3 -1
View File
@@ -15,4 +15,6 @@ colorlog
aiocqhttp
pyjwt
apscheduler
docstring_parser
docstring_parser
aiodocker
silk-python
+33
View File
@@ -0,0 +1,33 @@
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor, pipeline
from datasets import load_dataset
device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model_id = "openai/whisper-large-v3"
model = AutoModelForSpeechSeq2Seq.from_pretrained(
model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True
)
model.to(device)
processor = AutoProcessor.from_pretrained(model_id)
pipe = pipeline(
"automatic-speech-recognition",
model=model,
tokenizer=processor.tokenizer,
feature_extractor=processor.feature_extractor,
chunk_length_s=30,
batch_size=16, # batch size for inference - set based on your device
torch_dtype=torch_dtype,
device=device,
)
dataset = load_dataset("distil-whisper/librispeech_long", "clean", split="validation")
sample = dataset[0]["audio"]
result = pipe(sample)
print(result["text"])
+148
View File
@@ -0,0 +1,148 @@
import pytest
import os
from quart import Quart
from astrbot.dashboard.server import AstrBotDashboard
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core import LogBroker
from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.star.star import star_registry
@pytest.fixture(scope="module")
def core_lifecycle_td():
db = SQLiteDatabase("data/data_v3.db")
log_broker = LogBroker()
core_lifecycle_td = AstrBotCoreLifecycle(log_broker, db)
return core_lifecycle_td
@pytest.fixture(scope="module")
def app(core_lifecycle_td):
db = SQLiteDatabase("data/data_v3.db")
server = AstrBotDashboard(core_lifecycle_td, db)
return server.app
@pytest.fixture(scope="module")
def header():
return {}
@pytest.mark.asyncio
async def test_init_core_lifecycle_td(core_lifecycle_td):
await core_lifecycle_td.initialize()
assert core_lifecycle_td is not None
@pytest.mark.asyncio
async def test_auth_login(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle, header: dict):
test_client = app.test_client()
response = await test_client.post('/api/auth/login', json={
"username": "wrong",
"password": "password"
})
data = await response.get_json()
assert data['status'] == 'error'
response = await test_client.post('/api/auth/login', json={
"username": core_lifecycle_td.astrbot_config['dashboard']['username'],
"password": core_lifecycle_td.astrbot_config['dashboard']['password']
})
data = await response.get_json()
assert data['status'] == 'ok' and 'token' in data['data']
header['Authorization'] = f"Bearer {data['data']['token']}"
@pytest.mark.asyncio
async def test_get_stat(app: Quart, header: dict):
test_client = app.test_client()
response = await test_client.get('/api/stat/get')
assert response.status_code == 401
response = await test_client.get('/api/stat/get', headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok' and 'platform' in data['data']
@pytest.mark.asyncio
async def test_plugins(app: Quart, header: dict):
test_client = app.test_client()
# 已经安装的插件
response = await test_client.get('/api/plugin/get', headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok'
# 插件市场
response = await test_client.get('/api/plugin/market_list', headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok'
# 插件安装
response = await test_client.post('/api/plugin/install', json={
"url": "https://github.com/Soulter/astrbot_plugin_essential"
}, headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok'
exists = False
for md in star_registry:
if md.name == "astrbot_plugin_essential":
exists = True
break
assert exists is True, "插件 astrbot_plugin_essential 未成功载入"
# 插件更新
response = await test_client.post('/api/plugin/update', json={
"name": "astrbot_plugin_essential"
}, headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok'
# 插件卸载
response = await test_client.post('/api/plugin/uninstall', json={
"name": "astrbot_plugin_essential"
}, headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok'
exists = False
for md in star_registry:
if md.name == "astrbot_plugin_essential":
exists = True
break
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
exists = False
for md in star_handlers_registry:
if "astrbot_plugin_essential" in md.handler_module_path:
exists = True
break
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
@pytest.mark.asyncio
async def test_check_update(app: Quart, header: dict):
test_client = app.test_client()
response = await test_client.get('/api/update/check', headers=header)
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'success'
@pytest.mark.asyncio
async def test_do_update(app: Quart, header: dict, core_lifecycle_td: AstrBotCoreLifecycle):
global VERSION
test_client = app.test_client()
os.makedirs("data/astrbot_release", exist_ok=True)
core_lifecycle_td.astrbot_updator.MAIN_PATH = "data/astrbot_release"
VERSION = "114.514.1919810"
response = await test_client.post('/api/update/do', headers=header, json={
"version": "latest"
})
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'error' # 已经是最新版本
response = await test_client.post('/api/update/do', headers=header, json={
"version": "v3.4.0",
"reboot": False
})
assert response.status_code == 200
data = await response.get_json()
assert data['status'] == 'ok'
assert os.path.exists("data/astrbot_release/astrbot")
+48
View File
@@ -0,0 +1,48 @@
import os
import sys
import pytest
from unittest import mock
from main import check_env, check_dashboard_files
class _version_info():
def __init__(self, major, minor):
self.major = major
self.minor = minor
def test_check_env(monkeypatch):
version_info_correct = _version_info(3, 10)
version_info_wrong = _version_info(3, 9)
monkeypatch.setattr(sys, 'version_info', version_info_correct)
with mock.patch('os.makedirs') as mock_makedirs:
check_env()
mock_makedirs.assert_any_call("data/config", exist_ok=True)
mock_makedirs.assert_any_call("data/plugins", exist_ok=True)
mock_makedirs.assert_any_call("data/temp", exist_ok=True)
monkeypatch.setattr(sys, 'version_info', version_info_wrong)
with pytest.raises(SystemExit):
check_env()
@pytest.mark.asyncio
async def test_check_dashboard_files(monkeypatch):
monkeypatch.setattr(os.path, 'exists', lambda x: False)
async def mock_get(*args, **kwargs):
class MockResponse:
status = 200
async def read(self):
return b'content'
return MockResponse()
with mock.patch('aiohttp.ClientSession.get', new=mock_get):
with mock.patch('builtins.open', mock.mock_open()) as mock_file:
with mock.patch('zipfile.ZipFile.extractall') as mock_extractall:
async def mock_aenter(_):
await check_dashboard_files()
mock_file.assert_called_once_with("data/dashboard.zip", "wb")
mock_extractall.assert_called_once()
async def mock_aexit(obj, exc_type, exc, tb):
return
mock_extractall.__aenter__ = mock_aenter
mock_extractall.__aexit__ = mock_aexit
+226
View File
@@ -0,0 +1,226 @@
import pytest
import logging
import os
import asyncio
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember, MessageType
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
from astrbot.core.message.components import Plain, At
from astrbot.core.platform.platform_metadata import PlatformMetadata
from astrbot.core.platform.manager import PlatformManager
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.star.context import Context
from asyncio import Queue
SESSION_ID_IN_WHITELIST = "test_sid_wl"
SESSION_ID_NOT_IN_WHITELIST = "test_sid"
TEST_LLM_PROVIDER = {
"id": "zhipu_default",
"type": "openai_chat_completion",
"enable": True,
"key": [os.getenv("ZHIPU_API_KEY")],
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
"model_config": {
"model": "glm-4-flash",
},
}
TEST_COMMANDS = [
["help", "已注册的 AstrBot 内置指令"],
["tool ls", "函数工具"],
["tool on websearch", "激活工具"],
["tool off websearch", "停用工具"],
["plugin", "已加载的插件"],
["t2i", "文本转图片模式"],
["sid", "此 ID 可用于设置会话白名单。"],
["op test_op", "授权成功。"],
["deop test_op", "取消授权成功。"],
["wl test_platform:FriendMessage:test_sid_wl2", "添加白名单成功。"],
["dwl test_platform:FriendMessage:test_sid_wl2", "删除白名单成功。"],
["provider", "当前载入的 LLM 提供商"],
["reset", "重置成功"],
# ["model", "查看、切换提供商模型列表"],
["history", "历史记录:"],
["key", "当前 Key"],
["persona", "[Persona]"]
]
class FakeAstrMessageEvent(AstrMessageEvent):
def __init__(self, abm: AstrBotMessage = None):
meta = PlatformMetadata("test_platform", "test")
super().__init__(
message_str=abm.message_str,
message_obj=abm,
platform_meta=meta,
session_id=abm.session_id
)
async def send(self, message: MessageChain):
await super().send(message)
@staticmethod
def create_fake_event(
message_str: str,
session_id: str = "test_sid",
is_at: bool = False,
is_group: bool = False,
sender_id: str = "123456"
):
abm = AstrBotMessage()
abm.message_str = message_str
abm.group_id = "test"
abm.message = [Plain(message_str)]
if is_at:
abm.message.append(At(qq="bot"))
abm.self_id = "bot"
abm.sender = MessageMember(sender_id, "mika")
abm.timestamp = 1234567890
abm.message_id = "test"
abm.session_id = session_id
if is_group:
abm.type = MessageType.GROUP_MESSAGE
else:
abm.type = MessageType.FRIEND_MESSAGE
return FakeAstrMessageEvent(abm)
@pytest.fixture(scope="module")
def event_queue():
return Queue()
@pytest.fixture(scope="module")
def config():
cfg = AstrBotConfig()
cfg['platform_settings']['id_whitelist'] = ["test_platform:FriendMessage:test_sid_wl", "test_platform:GroupMessage:test_sid_wl"]
cfg['admins_id'] = ["123456"]
cfg['content_safety']['internal_keywords']['extra_keywords'] = ["^TEST_NEGATIVE"]
cfg['provider'] = [TEST_LLM_PROVIDER]
return cfg
@pytest.fixture(scope="module")
def db():
return SQLiteDatabase("data/data_v3.db")
@pytest.fixture(scope="module")
def platform_manager(event_queue, config):
return PlatformManager(config, event_queue)
@pytest.fixture(scope="module")
def provider_manager(config, db):
return ProviderManager(config, db)
@pytest.fixture(scope="module")
def star_context(event_queue, config, db, platform_manager, provider_manager):
star_context = Context(event_queue, config, db, provider_manager, platform_manager)
return star_context
@pytest.fixture(scope="module")
def plugin_manager(star_context, config):
plugin_manager = PluginManager(star_context, config)
# await plugin_manager.reload()
asyncio.run(plugin_manager.reload())
return plugin_manager
@pytest.fixture(scope="module")
def pipeline_context(config, plugin_manager):
return PipelineContext(config, plugin_manager)
@pytest.fixture(scope="module")
def pipeline_scheduler(pipeline_context):
return PipelineScheduler(pipeline_context)
@pytest.mark.asyncio
async def test_platform_initialization(platform_manager: PlatformManager):
await platform_manager.initialize()
@pytest.mark.asyncio
async def test_provider_initialization(provider_manager: ProviderManager):
await provider_manager.initialize()
@pytest.mark.asyncio
async def test_pipeline_scheduler_initialization(pipeline_scheduler: PipelineScheduler):
await pipeline_scheduler.initialize()
@pytest.mark.asyncio
async def test_pipeline_wakeup(pipeline_scheduler: PipelineScheduler, caplog):
'''测试唤醒'''
# 群聊无 @ 无指令
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
assert any("执行阶段 WhitelistCheckStage" not in message for message in caplog.messages)
# 群聊有 @ 无指令
mock_event = FakeAstrMessageEvent.create_fake_event("test", is_group=True, is_at=True)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
assert any("执行阶段 WhitelistCheckStage" in message for message in caplog.messages)
# 群聊有指令
mock_event = FakeAstrMessageEvent.create_fake_event("/help", is_group=True, session_id=SESSION_ID_IN_WHITELIST)
await pipeline_scheduler.execute(mock_event)
assert mock_event._has_send_oper is True
@pytest.mark.asyncio
async def test_pipeline_wl(pipeline_scheduler: PipelineScheduler, config: AstrBotConfig, caplog):
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event("test", SESSION_ID_IN_WHITELIST, sender_id="123")
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("不在会话白名单中,已终止事件传播。" not in message for message in caplog.messages), "日志中未找到预期的消息"
mock_event = FakeAstrMessageEvent.create_fake_event("test", sender_id="123")
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("不在会话白名单中,已终止事件传播。" in message for message in caplog.messages), "日志中未找到预期的消息"
@pytest.mark.asyncio
async def test_pipeline_content_safety(pipeline_scheduler: PipelineScheduler, caplog):
# 测试默认屏蔽词
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event("色情", session_id=SESSION_ID_IN_WHITELIST) # 测试需要。
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("内容安全检查不通过" in message for message in caplog.messages), "日志中未找到预期的消息"
# 测试额外屏蔽词
mock_event = FakeAstrMessageEvent.create_fake_event("TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST)
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("内容安全检查不通过" in message for message in caplog.messages), "日志中未找到预期的消息"
mock_event = FakeAstrMessageEvent.create_fake_event("_TEST_NEGATIVE", session_id=SESSION_ID_IN_WHITELIST)
with caplog.at_level(logging.INFO):
await pipeline_scheduler.execute(mock_event)
assert any("内容安全检查不通过" not in message for message in caplog.messages)
# TODO: 测试 百度AI 的内容安全检查
@pytest.mark.asyncio
async def test_pipeline_llm(pipeline_scheduler: PipelineScheduler, caplog):
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event("just reply me `OK`", session_id=SESSION_ID_IN_WHITELIST)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
assert any("请求 LLM" in message for message in caplog.messages)
assert mock_event.get_result() is not None
assert mock_event.get_result().result_content_type == ResultContentType.LLM_RESULT
@pytest.mark.asyncio
async def test_pipeline_websearch(pipeline_scheduler: PipelineScheduler, caplog):
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event("help me search the latest OpenAI news", session_id=SESSION_ID_IN_WHITELIST)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
assert any("请求 LLM" in message for message in caplog.messages)
assert any("web_searcher - search_from_search_engine" in message for message in caplog.messages)
@pytest.mark.asyncio
async def test_commands(pipeline_scheduler: PipelineScheduler, caplog):
for command in TEST_COMMANDS:
caplog.clear()
mock_event = FakeAstrMessageEvent.create_fake_event(command[0], session_id=SESSION_ID_IN_WHITELIST)
with caplog.at_level(logging.DEBUG):
await pipeline_scheduler.execute(mock_event)
# assert any("执行阶段 ProcessStage" in message for message in caplog.messages)
assert any(command[1] in message for message in caplog.messages)
+80
View File
@@ -0,0 +1,80 @@
import pytest
import os
from astrbot.core.star.star_manager import PluginManager
from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.star.star import star_registry
from astrbot.core.star.context import Context
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.db.sqlite import SQLiteDatabase
from asyncio import Queue
event_queue = Queue()
config = AstrBotConfig()
db = SQLiteDatabase("data/data_v3.db")
star_context = Context(event_queue, config, db)
@pytest.fixture
def plugin_manager_pm():
return PluginManager(star_context, config)
def test_plugin_manager_initialization(plugin_manager_pm: PluginManager):
assert plugin_manager_pm is not None
assert plugin_manager_pm.context is not None
assert plugin_manager_pm.config is not None
@pytest.mark.asyncio
async def test_plugin_manager_reload(plugin_manager_pm: PluginManager):
success, err_message = await plugin_manager_pm.reload()
assert success is True
assert err_message is None
assert len(star_handlers_registry) > 0 # package
@pytest.mark.asyncio
async def test_plugin_crud(plugin_manager_pm: PluginManager):
'''测试插件安装和重载'''
os.makedirs("data/plugins", exist_ok=True)
test_repo = "https://github.com/Soulter/astrbot_plugin_essential"
plugin_path = await plugin_manager_pm.install_plugin(test_repo)
exists = False
for md in star_registry:
if md.name == "astrbot_plugin_essential":
exists = True
break
assert plugin_path is not None
assert os.path.exists(plugin_path)
assert exists is True, "插件 astrbot_plugin_essential 未成功载入"
# shutil.rmtree(plugin_path)
# install plugin which is not exists
with pytest.raises(Exception):
plugin_path = await plugin_manager_pm.install_plugin(test_repo + "haha")
# update
await plugin_manager_pm.update_plugin("astrbot_plugin_essential")
with pytest.raises(Exception):
await plugin_manager_pm.update_plugin("astrbot_plugin_essentialhaha")
# uninstall
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essential")
assert not os.path.exists(plugin_path)
exists = False
for md in star_registry:
if md.name == "astrbot_plugin_essential":
exists = True
break
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
exists = False
for md in star_handlers_registry:
if "astrbot_plugin_essential" in md.handler_module_path:
exists = True
break
assert exists is False, "插件 astrbot_plugin_essential 未成功卸载"
with pytest.raises(Exception):
await plugin_manager_pm.uninstall_plugin("astrbot_plugin_essentialhaha")
# TODO: file installation