Compare commits

...

85 Commits

Author SHA1 Message Date
Soulter 94618e8feb feat: 添加 aiodocker 依赖 2025-01-11 22:02:15 +08:00
Soulter 55de7d4494 🎉 Bump to v3.4.5 2025-01-11 21:40:48 +08:00
Soulter 7ed639f741 🎉 bump to v3.4.5 2025-01-11 21:06:06 +08:00
Soulter 41f2870c29 Merge pull request #236 from Soulter/feat-stt
支持 Speech To Text,并适配腾讯修改过的 Silk 语音格式
2025-01-11 21:00:04 +08:00
Soulter ba198490fa feat: 支持自部署 Whisper 模型 2025-01-11 20:31:21 +08:00
Soulter 0f9ab082ab perf: 优化webchat,没有结果返回时的反馈 2025-01-11 19:45:42 +08:00
Soulter 97b58965f2 feat: webchat可显示Provider状态 2025-01-11 19:31:56 +08:00
Soulter f2566c68e3 feat: 按 K 语音 2025-01-11 19:07:26 +08:00
Soulter a456bf5449 fix: 初始化reminder时的一些问题 2025-01-11 18:55:18 +08:00
Soulter a09998f910 feat: webchat 支持语音输入 2025-01-11 18:54:40 +08:00
Soulter be662b913c feat: 支持 Whisper STT,并适配 Tencent 语音格式 2025-01-11 17:19:28 +08:00
Soulter e7ddc8448d perf: 代码执行器在成功执行后清空文件buffer 2025-01-11 11:31:56 +08:00
Soulter 29374f8d8a fix: 修复 /dashbord_update 指令 2025-01-11 00:25:02 +08:00
Soulter 359b971103 Merge pull request #235 from Soulter/feat-webchat
WebChat 支持
2025-01-11 00:17:18 +08:00
Soulter fbdb1ae208 chore: bump to v3.4.4 2025-01-11 00:14:08 +08:00
Soulter 22c13c1eff perf: webchat支持传图 2025-01-11 00:06:19 +08:00
Soulter 5fc63aeaf1 perf: ui 2025-01-10 22:45:14 +08:00
Soulter d4f32673ab fix: 修复持久化问题 2025-01-10 22:08:43 +08:00
Soulter 480dffb51b feat: 初步实现 webchat 页面 2025-01-10 21:48:15 +08:00
Soulter 966df00124 feat: 支持从管理面板(控制台页)手动安装 pip 库 2025-01-10 15:35:57 +08:00
Soulter 3e2b4bc727 feat: 支持动态设置会话变量以适用 Dify 输入变量 2025-01-10 12:32:20 +08:00
Soulter 5929a8d42b Update README.md 2025-01-09 23:11:11 +08:00
Soulter f8ab40eb39 chore: 上传管理面板package.json 2025-01-09 22:25:46 +08:00
Soulter 55e9233b93 docs: v3.4.3 changelog 2025-01-09 22:19:11 +08:00
Soulter b7277b51fd feat: 管理面板支持显示不在metadata中的配置 2025-01-09 22:03:53 +08:00
Soulter 1fa9111b2b perf: 进一步防止llm递归调用 2025-01-09 22:03:22 +08:00
Soulter 90a9e496d9 feat: 适配器类插件支持设置默认配置模板 2025-01-09 19:45:18 +08:00
Soulter 2a7dce1eb0 chore: clean code 2025-01-09 16:34:39 +08:00
Soulter 0c0841cc03 fix: websearch 在 cmd_config 中失效的问题 2025-01-09 16:33:58 +08:00
Soulter 4c9fe016bf fix: test_pipeline 2025-01-09 16:00:43 +08:00
Soulter acc90f140c chore: bump dashboard_release_url 2025-01-09 15:50:24 +08:00
Soulter 68a7bc3930 Merge pull request #232 from Soulter/feat-python-interpreter
初步实现代码执行器
2025-01-09 15:43:40 +08:00
Soulter 12ea64be0e fix: dashboard command bug 2025-01-09 15:42:04 +08:00
Soulter 7f30a673f7 fix: 修复 qq_official 无法发图 2025-01-09 15:20:54 +08:00
Soulter 897e100c32 Merge pull request #234 from Soulter/233-gemini-native-support
支持通过 Google GenAI 访问 Gemini 模型
2025-01-09 14:23:44 +08:00
Soulter 0d4ad5cb31 fix: 修复 APScheduler 任务错过后不执行的问题 2025-01-09 14:23:07 +08:00
Soulter b124bd0d0e feat: 支持通过 Google GenAI 访问 Gemini 模型 2025-01-09 14:05:48 +08:00
Soulter 6bc2f84602 Update README.md
qingcloud 在新网的账户余额不足导致原域名无法续费
2025-01-09 10:35:02 +08:00
Soulter d787a28c40 feat: 支持使用 /dashboard update 更新管理面板 2025-01-09 00:59:28 +08:00
Soulter 6b078a5731 cd: build dashboard files automatically 2025-01-09 00:57:48 +08:00
Soulter 17dddbfe21 chore: 禁用插件 2025-01-08 23:34:54 +08:00
Soulter 3ff3c9e144 perf: 检测到docker不可用时自动禁用本插件 2025-01-08 23:32:49 +08:00
Soulter f5a37d82cc Merge branch 'master' into feat-python-interpreter 2025-01-08 23:13:52 +08:00
Soulter d3d428dc9d fix: 管理面板支持禁用/启用插件 2025-01-08 23:04:03 +08:00
Soulter 8dc8c5b5dc feat: 支持对插件禁用/启用 2025-01-08 22:28:20 +08:00
Soulter e6b06f914b perf: provider 偏好项记忆 2025-01-08 20:46:34 +08:00
Soulter 4dc502a8b6 fix: 修复事件监听器会让wakestage失效的问题 2025-01-08 20:24:01 +08:00
Soulter b1d1a13d5f perf: 支持图片输入 2025-01-08 19:56:03 +08:00
Soulter 75cc4cac5a perf: 代码执行器添加部分控制指令,添加更多可用库 2025-01-08 13:26:16 +08:00
Soulter 1b7e4fbbdc perf: 退出时关闭 aiohttp client session 2025-01-08 12:43:34 +08:00
Soulter 9789e2f6c1 perf: 代码执行器请求llm不持久化历史记录 2025-01-08 02:12:35 +08:00
Soulter b8fb0bee24 feat: 初步实现代码执行器 #210 2025-01-08 02:10:27 +08:00
Soulter 419f77e245 Update README.md 2025-01-07 20:56:25 +08:00
Soulter 59b1c3473b Merge pull request #230 from Soulter/feat-dify
接入 Dify
2025-01-07 20:14:33 +08:00
Soulter 6db58ca375 perf: 优化在prompt为空的情况下不请求provider 2025-01-07 20:01:47 +08:00
Soulter 4832b342b0 Merge branch 'master' into feat-dify 2025-01-07 19:59:54 +08:00
Soulter 6cec542402 feat: 初步接入 Dify 2025-01-07 19:56:18 +08:00
Soulter 9644791783 feat: kdb 2024-12-30 18:06:09 +08:00
Soulter 5031c307d1 update: readme 2024-12-26 23:39:29 +08:00
Soulter aa49539e3e chore: fix test 2024-12-26 23:33:40 +08:00
Soulter 7b4118493b chore: fix test 2024-12-26 23:15:10 +08:00
Soulter d1cc9ba4ce chore: update test workflow 2024-12-26 23:09:11 +08:00
Soulter e0e92139d7 fix: test workflow 2024-12-26 23:07:50 +08:00
Soulter 62039392bb chore: fix test workflow 2024-12-26 23:06:30 +08:00
Soulter b72c69892e test: dashboard test 2024-12-26 22:59:17 +08:00
Soulter e6205e9aad ci: update workflow 2024-12-25 17:18:29 +08:00
Soulter b8a6fb1720 chore: update tests 2024-12-25 12:50:29 +08:00
Soulter 7c06d82f27 perf: plugin manager 重复 reload 释放资源 2024-12-25 12:50:29 +08:00
Soulter d92cb0f500 perf: 当没有provider时直接返回 2024-12-25 12:50:29 +08:00
Soulter 7fa72f2fe9 perf: adapt glm-4v-flash 2024-12-24 14:08:20 +08:00
Soulter 21d480a3b5 bugfixes 2024-12-22 05:31:29 +08:00
Soulter 771c045844 feat: 可配置是否启用白名单 2024-12-22 05:18:27 +08:00
Soulter e6ce484c15 perf: 不加载已经outdated的reminder 2024-12-22 05:06:15 +08:00
Soulter 102a92f62d perf: 移动对 prompt 的内置修改的逻辑 2024-12-21 18:39:10 +08:00
Soulter 6c7ac70701 Bump version to v3.4.2 2024-12-21 16:40:04 +08:00
Soulter 9d8372289f fix: fstring format error #226 2024-12-21 16:38:53 +08:00
Soulter 766f6a1ba2 perf: use request_llm 2024-12-21 16:35:16 +08:00
Soulter 193ff24f4c feat: 添加发送消息后的事件钩子 2024-12-20 16:31:36 +08:00
Soulter c675017374 feat: 新增LLM请求事件钩子和装饰消息结果钩子 2024-12-19 21:33:03 +08:00
Soulter 86cb852507 perf: llm-tuner adapter 检查路径 2024-12-18 21:25:04 +08:00
Soulter 73494e0d7d perf: 使用 astrbot-registry 下载面板静态资源 2024-12-18 21:24:39 +08:00
Soulter ec61aa1b6f Merge pull request #224 from Soulter/dashboard
迁移 AstrBot Dashboard 源代码至 AstrBot
2024-12-17 23:45:13 +08:00
Soulter 6df0e78b22 upload: dashboard from Soulter/AstrBot-Dashboard 2024-12-17 23:40:32 +08:00
Soulter 63c604359b fix: update 2024-12-16 22:53:23 +08:00
Soulter 08212588a0 chore: update docker ci/cd workflow 2024-12-16 21:12:02 +08:00
175 changed files with 18097 additions and 604 deletions
+2
View File
@@ -16,3 +16,5 @@ venv*/
ENV/
.conda/
README*.md
dashboard/
data/
+12 -1
View File
@@ -2,6 +2,7 @@ on:
push:
tags:
- 'v*'
workflow_dispatch:
name: Auto Release
@@ -14,6 +15,15 @@ jobs:
- name: Checkout repository
uses: actions/checkout@v4
- name: Dashboard Build
run: |
cd dashboard
npm install
npm run build
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
echo ${{ github.ref_name }} > dist/assets/version
zip -r dist.zip dist
- name: Fetch Changelog
run: |
echo "changelog=changelogs/${{github.ref_name}}.md" >> "$GITHUB_ENV"
@@ -21,4 +31,5 @@ jobs:
- name: Create Release
uses: ncipollo/release-action@v1
with:
bodyFile: ${{ env.changelog }}
bodyFile: ${{ env.changelog }}
artifacts: "dashboard/dist.zip"
+15 -9
View File
@@ -1,7 +1,14 @@
name: Run tests and upload coverage
on:
push
push:
branches:
- master
paths-ignore:
- 'README.md'
- 'changelogs/**'
- 'dashboard/**'
workflow_dispatch:
jobs:
test:
@@ -21,17 +28,16 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install pytest pytest-cov pytest-asyncio
mkdir data
mkdir data/plugins
mkdir data/config
mkdir temp
- name: Run tests
run: |
export LLM_MODEL=${{ secrets.LLM_MODEL }}
export OPENAI_API_BASE=${{ secrets.OPENAI_API_BASE }}
export OPENAI_API_KEY=${{ secrets.OPENAI_API_KEY }}
PYTHONPATH=./ pytest --cov=. tests/ -v
mkdir data
mkdir data/plugins
mkdir data/config
mkdir data/temp
export TESTING=true
export ZHIPU_API_KEY=${{ secrets.OPENAI_API_KEY }}
PYTHONPATH=./ pytest --cov=. tests/ -v -o log_cli=true -o log_level=DEBUG
- name: Upload results to Codecov
uses: codecov/codecov-action@v4
+4 -3
View File
@@ -1,8 +1,9 @@
name: Docker Image CI/CD
on:
release:
types: [published]
push:
tags:
- 'v*'
workflow_dispatch:
jobs:
@@ -35,7 +36,7 @@ jobs:
push: true
tags: |
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:latest
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.event.release.tag_name }}
${{ secrets.DOCKER_HUB_USERNAME }}/astrbot:${{ github.ref_name }}
- name: Post build notifications
run: echo "Docker image has been built and pushed successfully"
+6 -1
View File
@@ -16,4 +16,9 @@ addons/plugins
tests/astrbot_plugin_openai
chroma
chroma
node_modules/
.DS_Store
package-lock.json
package.json
venv/*
+61 -10
View File
@@ -14,9 +14,10 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg"/></a>
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-322154837-purple">
[![wakatime](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e.svg)](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
[![codecov](https://codecov.io/gh/Soulter/AstrBot/graph/badge.svg?token=FF3P5967B8)](https://codecov.io/gh/Soulter/AstrBot)
</a>
<a href="https://astrbot.soulter.top/">查看文档</a>
<a href="https://astrbot.lwl.lol/">查看文档</a>
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
</div>
@@ -24,7 +25,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
## ✨ 多消息平台部署
1. QQ 群、QQ 频道、微信、Telegram。
1. QQ 群、QQ 频道、微信个人号、Telegram。
2. 支持文本转图片,Markdown 渲染。
## ✨ 多 LLM 配置
@@ -33,7 +34,8 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
2. 支持 OneAPI 等分发平台。
3. 支持 LLMTuner 载入微调模型。
4. 支持 Ollama 载入自部署模型。
4. 支持网页搜索(Web Search)。
4. 支持网页搜索(Web Search、自然语言待办提醒
5. 支持 Whisper 语音转文字
## ✨ 管理面板
@@ -42,15 +44,23 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
3. 简单的信息统计
4. 插件管理
<!-- ## ✨ ATRI [Beta 测试]
## ✨ 支持 Dify
该功能作为插件载入。插件仓库地址:[astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
1. 对接了 LLMOps 平台 Dify,便捷接入 Dify 智能助手、知识库和 Dify 工作流![接入 Dify - AstrBot 文档](https://astrbot.lwl.lol/others/dify.html)
## ✨ 代码执行器(Beta)
基于 Docker 的沙箱化代码执行器(Beta 测试中)
> [!NOTE]
> 文件输入/输出目前仅支持 Napcat(QQ)
<div align='center'>
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
</div>
1. 基于《ATRI ~ My Dear Moments》主角 ATRI 角色台词作为微调数据集的 `Qwen1.5-7B-Chat Lora` 微调模型。
2. 长期记忆
3. 表情包理解与回复
4. TTS
-->
## ✨ 云部署
[![Run on Repl.it](https://repl.it/badge/github/Soulter/AstrBot)](https://repl.it/github/Soulter/AstrBot)
@@ -71,3 +81,44 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
- Star 这个项目!
- 在[爱发电](https://afdian.com/a/soulter)支持我!
- 在[微信](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)支持我~
## ✨ Demo
<div align='center'>
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
<img src="https://github.com/user-attachments/assets/8ec12797-e70f-460a-959e-48eca39ca2bb" height=100>
_✨ 自然语言待办事项 ✨_
<img src="https://github.com/user-attachments/assets/e137a9e1-340a-4bf2-bb2b-771132780735" height=150>
<img src="https://github.com/user-attachments/assets/480f5e82-cf6a-4955-a869-0d73137aa6e1" height=150>
_✨ 插件系统——部分插件展示 ✨_
<img src="https://github.com/user-attachments/assets/592a8630-14c7-4e06-b496-9c0386e4f36c" width=600>
_✨ 管理面板 ✨_
![webchat](https://drive.soulter.top/f/vlsA/ezgif-5-fb044b2542.gif)
_✨ 内置 Web Chat,在线与机器人交互 ✨_
</div>
<!-- ## ✨ ATRI [Beta 测试]
该功能作为插件载入。插件仓库地址:[astrbot_plugin_atri](https://github.com/Soulter/astrbot_plugin_atri)
1. 基于《ATRI ~ My Dear Moments》主角 ATRI 角色台词作为微调数据集的 `Qwen1.5-7B-Chat Lora` 微调模型。
2. 长期记忆
3. 表情包理解与回复
4. TTS
-->
+3 -1
View File
@@ -2,7 +2,8 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot import logger
from astrbot.core.utils.personality import personalities
from astrbot.core import html_renderer
from astrbot.core.provider.register import register_llm_tool as llm_tool
from astrbot.core import sp
from astrbot.core.star.register import register_llm_tool as llm_tool
__all__ = [
"AstrBotConfig",
@@ -10,4 +11,5 @@ __all__ = [
"personalities",
"html_renderer",
"llm_tool",
"sp"
]
+1 -1
View File
@@ -3,7 +3,7 @@ from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot import logger
from astrbot.core.utils.personality import personalities
from astrbot.core import html_renderer
from astrbot.core.provider.register import register_llm_tool as llm_tool
from astrbot.core.star.register import register_llm_tool as llm_tool
# event
from astrbot.core.message.message_event_result import (
+13 -4
View File
@@ -1,9 +1,18 @@
from astrbot.core.message.message_event_result import (
MessageEventResult, MessageChain, CommandResult, EventResultType
)
MessageEventResult,
MessageChain,
CommandResult,
EventResultType,
ResultContentType,
)
from astrbot.core.platform import AstrMessageEvent
__all__ = [
'MessageEventResult', 'MessageChain', 'CommandResult', 'EventResultType', 'AstrMessageEvent'
]
"MessageEventResult",
"MessageChain",
"CommandResult",
"EventResultType",
"AstrMessageEvent",
"ResultContentType",
]
+9 -1
View File
@@ -4,7 +4,11 @@ from astrbot.core.star.register import (
register_event_message_type as event_message_type,
register_regex as regex,
register_platform_adapter_type as platform_adapter_type,
register_permission_type as permission_type
register_permission_type as permission_type,
register_on_llm_request as on_llm_request,
register_llm_tool as llm_tool,
register_on_decorating_result as on_decorating_result,
register_after_message_sent as after_message_sent
)
from astrbot.core.star.filter.event_message_type import EventMessageTypeFilter, EventMessageType
@@ -24,4 +28,8 @@ __all__ = [
'PlatformAdapterType',
'PermissionTypeFilter',
'PermissionType',
'on_llm_request',
'llm_tool',
'on_decorating_result',
'after_message_sent'
]
+2 -1
View File
@@ -1 +1,2 @@
from astrbot.core.provider import Provider, Personality, ProviderMetaData
from astrbot.core.provider import Provider, STTProvider, Personality
from astrbot.core.provider.entites import ProviderRequest, ProviderType, ProviderMetaData
+14 -1
View File
@@ -1,12 +1,25 @@
import os
import asyncio
from .log import LogManager, LogBroker
from astrbot.core.utils.t2i.renderer import HtmlRenderer
from astrbot.core.utils.shared_preferences import SharedPreferences
from astrbot.core.utils.pip_installer import PipInstaller
from astrbot.core.db.sqlite import SQLiteDatabase
from astrbot.core.config.default import DB_PATH
from astrbot.core.config import AstrBotConfig
os.makedirs("data", exist_ok=True)
astrbot_config = AstrBotConfig()
html_renderer = HtmlRenderer()
logger = LogManager.GetLogger(log_name='astrbot')
if os.environ.get('TESTING', ""):
logger.setLevel('DEBUG')
db_helper = SQLiteDatabase(DB_PATH)
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
sp = SharedPreferences() # 简单的偏好设置存储
pip_installer = PipInstaller(astrbot_config.get('pip_install_arg', ''))
web_chat_queue = asyncio.Queue(maxsize=32)
web_chat_back_queue = asyncio.Queue(maxsize=32)
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
+96 -7
View File
@@ -2,7 +2,7 @@
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
"""
VERSION = "3.4.0"
VERSION = "3.4.5"
DB_PATH = "data/data_v3.db"
# 默认配置
@@ -17,6 +17,7 @@ DEFAULT_CONFIG = {
},
"reply_prefix": "",
"forward_threshold": 200,
"enable_id_white_list": True,
"id_whitelist": [],
"id_whitelist_log": True,
"wl_ignore_admin_on_group": True,
@@ -32,6 +33,10 @@ DEFAULT_CONFIG = {
"default_personality": "如果用户寻求帮助或者打招呼,请告诉他可以用 /help 查看 AstrBot 帮助。",
"prompt_prefix": "",
},
"provider_stt_settings": {
"enable": False,
"provider_id": "",
},
"content_safety": {
"internal_keywords": {"enable": True, "extra_keywords": []},
"baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""},
@@ -49,7 +54,8 @@ DEFAULT_CONFIG = {
"log_level": "INFO",
"t2i_endpoint": "",
"pip_install_arg": "",
"plugin_repo_mirror": ""
"plugin_repo_mirror": "",
"knowledge_db": {},
}
@@ -162,6 +168,10 @@ CONFIG_METADATA_2 = {
"type": "int",
"hint": "超过一定字数后,机器人会将消息折叠成 QQ 群聊的 “转发消息”,以防止刷屏。目前仅 QQ 平台适配器适用。",
},
"enable_id_white_list": {
"description": "启用 ID 白名单",
"type": "bool"
},
"id_whitelist": {
"description": "ID 白名单",
"type": "list",
@@ -225,10 +235,10 @@ CONFIG_METADATA_2 = {
},
},
"provider_group": {
"name": "大语言模型",
"name": "服务提供商",
"metadata": {
"provider": {
"description": "大语言模型配置",
"description": "服务提供商配置",
"type": "list",
"config_template": {
"openai": {
@@ -251,7 +261,7 @@ CONFIG_METADATA_2 = {
"model": "llama3.1-8b",
},
},
"gemini": {
"gemini(OpenAI兼容)": {
"id": "gemini_default",
"type": "openai_chat_completion",
"enable": True,
@@ -261,6 +271,16 @@ CONFIG_METADATA_2 = {
"model": "gemini-1.5-flash",
},
},
"gemini(googlegenai原生)": {
"id": "gemini_default",
"type": "googlegenai_chat_completion",
"enable": True,
"key": [],
"api_base": "https://generativelanguage.googleapis.com/",
"model_config": {
"model": "gemini-1.5-flash",
},
},
"deepseek": {
"id": "deepseek_default",
"type": "openai_chat_completion",
@@ -273,7 +293,7 @@ CONFIG_METADATA_2 = {
},
"zhipu": {
"id": "zhipu_default",
"type": "openai_chat_completion",
"type": "zhipu_chat_completion",
"enable": True,
"key": [],
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
@@ -290,9 +310,39 @@ CONFIG_METADATA_2 = {
"llmtuner_template": "",
"finetuning_type": "lora",
"quantization_bit": 4,
},
"dify": {
"id": "dify_app_default",
"type": "dify",
"enable": True,
"dify_api_type": "chat",
"dify_api_key": "",
"dify_api_base": "https://api.dify.ai/v1",
"dify_workflow_output_key": "",
},
"whisper(API)": {
"id": "whisper",
"type": "openai_whisper_api",
"enable": False,
"api_key": "",
"api_base": "",
"model": "whisper-1",
},
"whisper(本地加载)": {
"whisper_hint": "(不用修改我)",
"enable": False,
"id": "whisper",
"type": "openai_whisper_selfhost",
"model": "tiny",
}
},
"items": {
"whisper_hint": {
"description": "本地部署 Whisper 模型须知",
"type": "string",
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cudaCPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
"obvious_hint": True
},
"id": {
"description": "ID",
"type": "string",
@@ -361,6 +411,27 @@ CONFIG_METADATA_2 = {
"top_p": {"description": "Top P值", "type": "float"},
},
},
"dify_api_key": {
"description": "API Key",
"type": "string",
"hint": "Dify API Key。此项必填。",
},
"dify_api_base": {
"description": "API Base URL",
"type": "string",
"hint": "Dify API Base URL。默认为 https://api.dify.ai/v1",
},
"dify_api_type": {
"description": "Dify 应用类型",
"type": "string",
"hint": "Dify API 类型。根据 Dify 官网,目前支持 chat, agent, workflow 三种应用类型",
"options": ["chat", "agent", "workflow"],
},
"dify_workflow_output_key": {
"description": "Dify Workflow 输出变量名",
"type": "string",
"hint": "Dify Workflow 输出变量名。当应用类型为 workflow 时才使用。默认为 astrbot_wf_output。",
}
},
},
"provider_settings": {
@@ -370,7 +441,8 @@ CONFIG_METADATA_2 = {
"enable": {
"description": "启用大语言模型聊天",
"type": "bool",
"hint": "是否启用大语言模型聊天。默认启用",
"hint": "如需切换大语言模型提供商,请使用 `/provider` 命令。",
"obvious_hint": True
},
"wake_prefix": {
"description": "LLM 聊天额外唤醒前缀",
@@ -404,6 +476,23 @@ CONFIG_METADATA_2 = {
},
},
},
"provider_stt_settings": {
"description": "语音转文本(STT)",
"type": "object",
"items": {
"enable": {
"description": "启用语音转文本(STT)",
"type": "bool",
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 whisper。",
"obvious_hint": True
},
"provider_id": {
"description": "提供商 ID,不填则默认第一个STT提供商",
"type": "string",
"hint": "语音转文本提供商 ID。如果不填写将使用载入的第一个提供商。",
},
},
},
},
},
"misc_config_group": {
+20 -6
View File
@@ -3,6 +3,7 @@ import time
import threading
import os
from .event_bus import EventBus
from . import astrbot_config
from asyncio import Queue
from typing import List
from astrbot.core.config.astrbot_config import AstrBotConfig
@@ -16,11 +17,12 @@ from astrbot.core.db import BaseDatabase
from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
class AstrBotCoreLifecycle:
def __init__(self, log_broker: LogBroker, db: BaseDatabase):
self.log_broker = log_broker
self.astrbot_config = AstrBotConfig()
self.astrbot_config = astrbot_config
self.db = db
if self.astrbot_config['http_proxy']:
@@ -29,7 +31,10 @@ class AstrBotCoreLifecycle:
async def initialize(self):
logger.info("AstrBot v"+ VERSION)
logger.setLevel(self.astrbot_config['log_level'])
if os.environ.get("TESTING", ""):
logger.setLevel("DEBUG")
else:
logger.setLevel(self.astrbot_config['log_level'])
self.event_queue = Queue()
self.event_queue.closed = False
@@ -37,12 +42,19 @@ class AstrBotCoreLifecycle:
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
self.star_context = Context(self.event_queue, self.astrbot_config, self.db)
self.star_context.platform_manager = self.platform_manager
self.star_context.provider_manager = self.provider_manager
self.knowledge_db_manager = KnowledgeDBManager(self.astrbot_config)
self.star_context = Context(
self.event_queue,
self.astrbot_config,
self.db,
self.provider_manager,
self.platform_manager,
self.knowledge_db_manager
)
self.plugin_manager = PluginManager(self.star_context, self.astrbot_config)
self.plugin_manager.reload()
await self.plugin_manager.reload()
'''扫描、注册插件、实例化插件类'''
await self.provider_manager.initialize()
@@ -81,6 +93,8 @@ class AstrBotCoreLifecycle:
self.event_queue.closed = True
for task in self.curr_tasks:
task.cancel()
await self.provider_manager.terminate()
for task in self.curr_tasks:
try:
+25 -1
View File
@@ -1,7 +1,7 @@
import abc
from dataclasses import dataclass
from typing import List
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision
from astrbot.core.db.po import Stats, LLMHistory, ATRIVision, WebChatConversation
@dataclass
class BaseDatabase(abc.ABC):
@@ -76,4 +76,28 @@ class BaseDatabase(abc.ABC):
@abc.abstractmethod
def get_atri_vision_data_by_path_or_id(self, url_or_path: str, id: str) -> ATRIVision:
'''通过 url 或 path 获取 ATRI 视觉数据'''
raise NotImplementedError
@abc.abstractmethod
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
'''通过 user_id 和 cid 获取 WebChatConversation'''
raise NotImplementedError
@abc.abstractmethod
def webchat_new_conversation(self, user_id: str, cid: str):
'''新建 WebChatConversation'''
raise NotImplementedError
@abc.abstractmethod
def get_webchat_conversations(self, user_id: str) -> List[WebChatConversation]:
raise NotImplementedError
@abc.abstractmethod
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
'''更新 WebChatConversation'''
raise NotImplementedError
@abc.abstractmethod
def delete_webchat_conversation(self, user_id: str, cid: str):
'''删除 WebChatConversation'''
raise NotImplementedError
+12 -1
View File
@@ -51,4 +51,15 @@ class ATRIVision():
platform_name: str
session_id: str
sender_nickname: str
timestamp: int = -1
timestamp: int = -1
@dataclass
class WebChatConversation():
user_id: str
cid: str
history: str = ""
created_at: int = 0
updated_at: int = 0
+65 -1
View File
@@ -5,7 +5,8 @@ from astrbot.core.db.po import (
Platform,
Stats,
LLMHistory,
ATRIVision
ATRIVision,
WebChatConversation
)
from . import BaseDatabase
from typing import Tuple
@@ -199,6 +200,69 @@ class SQLiteDatabase(BaseDatabase):
c.close()
return Stats(platform, [], [])
def get_webchat_conversation_by_user_id(self, user_id: str, cid: str) -> WebChatConversation:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
'''
SELECT * FROM webchat_conversation WHERE user_id = ? AND cid = ?
''', (user_id, cid)
)
res = c.fetchone()
c.close()
return WebChatConversation(*res)
def webchat_new_conversation(self, user_id: str, cid: str):
history = "[]"
updated_at = int(time.time())
created_at = updated_at
self._exec_sql(
'''
INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at) VALUES (?, ?, ?, ?, ?)
''', (user_id, cid, history, updated_at, created_at)
)
def get_webchat_conversations(self, user_id: str) -> Tuple:
try:
c = self.conn.cursor()
except sqlite3.ProgrammingError:
c = self._get_conn(self.db_path).cursor()
c.execute(
'''
SELECT cid, created_at, updated_at FROM webchat_conversation WHERE user_id = ? ORDER BY updated_at DESC
''', (user_id,)
)
res = c.fetchall()
c.close()
conversations = []
for row in res:
cid = row[0]
created_at = row[1]
updated_at = row[2]
conversations.append(WebChatConversation("", cid, '[]', created_at, updated_at))
return conversations
def update_webchat_conversation(self, user_id: str, cid: str, history: str):
self._exec_sql(
'''
UPDATE webchat_conversation SET history = ? WHERE user_id = ? AND cid = ?
''', (history, user_id, cid)
)
def delete_webchat_conversation(self, user_id: str, cid: str):
self._exec_sql(
'''
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
''', (user_id, cid)
)
def insert_atri_vision_data(self, vision: ATRIVision):
+8
View File
@@ -35,4 +35,12 @@ CREATE TABLE IF NOT EXISTS atri_vision(
session_id VARCHAR(32),
sender_nickname VARCHAR(32),
timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS webchat_conversation(
user_id TEXT,
cid TEXT,
history TEXT,
created_at INTEGER,
updated_at INTEGER
);
+15 -2
View File
@@ -54,6 +54,7 @@ class ComponentType(Enum):
CardImage = "CardImage"
TTS = "TTS"
Unknown = "Unknown"
File = "File"
class BaseMessageComponent(BaseModel):
@@ -122,7 +123,7 @@ class Record(BaseMessageComponent):
proxy: T.Optional[bool] = True
timeout: T.Optional[int] = 0
# 额外
path: T.Optional[str]
path: T.Optional[str] # 用这个
def __init__(self, file: T.Optional[str], **_):
for k in _.keys():
@@ -415,6 +416,17 @@ class Unknown(BaseMessageComponent):
def toString(self):
return ""
class File(BaseMessageComponent):
'''
目前此消息段只适配了 Napcat。
'''
type: ComponentType = "File"
name: T.Optional[str] = "" # 名字
file: T.Optional[str] = "" # url(本地路径)
def __init__(self, name: str, file: str):
super().__init__(name=name, file=file)
ComponentTypes = {
"plain": Plain,
@@ -441,5 +453,6 @@ ComponentTypes = {
"json": Json,
"cardimage": CardImage,
"tts": TTS,
"unknown": Unknown
"unknown": Unknown,
'file': File,
}
+19 -1
View File
@@ -97,7 +97,14 @@ class EventResultType(enum.Enum):
'''
CONTINUE = enum.auto()
STOP = enum.auto()
class ResultContentType(enum.Enum):
'''用于描述事件结果的内容的类型。
'''
LLM_RESULT = enum.auto()
'''调用 LLM 产生的结果'''
GENERAL_RESULT = enum.auto()
'''普通的消息结果'''
@dataclass
class MessageEventResult(MessageChain):
'''MessageEventResult 描述了一整条消息中带有的所有组件以及事件处理的结果。
@@ -112,6 +119,8 @@ class MessageEventResult(MessageChain):
result_type: Optional[EventResultType] = field(default_factory=lambda: EventResultType.CONTINUE)
result_content_type: Optional[ResultContentType] = field(default_factory=lambda: ResultContentType.GENERAL_RESULT)
def stop_event(self) -> 'MessageEventResult':
'''终止事件传播。
'''
@@ -130,5 +139,14 @@ class MessageEventResult(MessageChain):
'''
return self.result_type == EventResultType.STOP
def set_result_content_type(self, typ: EventResultType) -> 'MessageEventResult':
'''设置事件处理的结果类型。
Args:
result_type (EventResultType): 事件处理的结果类型。
'''
self.result_content_type = typ
return self
CommandResult = MessageEventResult
+3
View File
@@ -3,6 +3,7 @@ from astrbot.core.message.message_event_result import MessageEventResult, EventR
from .waking_check.stage import WakingCheckStage
from .whitelist_check.stage import WhitelistCheckStage
from .content_safety_check.stage import ContentSafetyCheckStage
from .preprocess_stage.stage import PreProcessStage
from .process_stage.stage import ProcessStage
from .result_decorate.stage import ResultDecorateStage
from .respond.stage import RespondStage
@@ -12,6 +13,7 @@ STAGES_ORDER = [
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
"RateLimitCheckStage", # 检查会话是否超过频率限制
"ContentSafetyCheckStage", # 检查内容安全
"PreProcessStage", # 预处理
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
"RespondStage" # 发送消息
@@ -21,6 +23,7 @@ __all__ = [
"WakingCheckStage",
"WhitelistCheckStage",
"ContentSafetyCheckStage",
"PreProcessStage",
"ProcessStage",
"ResultDecorateStage",
"RespondStage",
@@ -1,6 +1,6 @@
from . import ContentSafetyStrategy
from typing import List, Tuple
from astrbot import logger
class StrategySelector:
def __init__(self, config: dict) -> None:
@@ -15,7 +15,8 @@ class StrategySelector:
try:
from .baidu_aip import BaiduAipStrategy
except ImportError:
raise ImportError("使用百度内容审核应该先 pip install baidu-aip")
logger.warning("使用百度内容审核应该先 pip install baidu-aip")
return
self.enabled_strategies.append(
BaiduAipStrategy(
config["baidu_aip"]["app_id"],
@@ -0,0 +1,55 @@
import traceback
import asyncio
from typing import Union, AsyncGenerator
from ..stage import Stage, register_stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core import logger
from astrbot.core.message.components import Plain, Record
@register_stage
class PreProcessStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.config = ctx.astrbot_config
self.plugin_manager = ctx.plugin_manager
self.stt_settings: dict = self.config.get('provider_stt_settings', {})
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
'''在处理事件之前的预处理'''
if self.stt_settings.get('enable', False):
# STT 处理
# TODO: 独立
stt_provider = self.plugin_manager.context.provider_manager.curr_stt_provider_inst
if stt_provider:
message_chain = event.get_messages()
for idx, component in enumerate(message_chain):
if isinstance(component, Record) and component.path:
path = component.path
retry = 5
for i in range(retry):
try:
result = await stt_provider.get_text(audio_url=path)
if result:
logger.info("语音转文本结果: " + result)
message_chain[idx] = Plain(result)
event.message_str += result
event.message_obj.message_str += result
break
except FileNotFoundError as e:
# napcat workaround
logger.warning(e)
logger.warning(f"语音文件不存在: {path}, 重试中: {i + 1}/{retry}")
await asyncio.sleep(0.5)
continue
except BaseException as e:
logger.error(traceback.format_exc())
logger.error(f"语音转文本失败: {e}")
break
@@ -0,0 +1,60 @@
'''
Dify 调用 Stage
'''
import traceback
from typing import Union, AsyncGenerator
from ...context import PipelineContext
from ..stage import Stage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, ResultContentType
from astrbot.core.message.components import Image
from astrbot.core import logger
from astrbot.core.utils.metrics import Metric
from astrbot.core.provider.entites import ProviderRequest
class DifyRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
req: ProviderRequest = None
provider = self.ctx.plugin_manager.context.get_using_provider()
if provider.meta().type != "dify":
return
if event.get_extra("provider_request"):
req = event.get_extra("provider_request")
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
else:
req = ProviderRequest(prompt="", image_urls=[])
if self.ctx.astrbot_config['provider_settings']['wake_prefix']:
if not event.message_str.startswith(self.ctx.astrbot_config['provider_settings']['wake_prefix']):
return
req.prompt = event.message_str[len(self.ctx.astrbot_config['provider_settings']['wake_prefix']):]
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
req.image_urls.append(image_url)
req.session_id = event.session_id
event.set_extra("provider_request", req)
if not req.prompt:
return
try:
logger.debug(f"Dify 请求 Payload: {req.__dict__}")
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
if llm_response.role == 'assistant':
# text completion
event.set_result(MessageEventResult().message(llm_response.completion_text)
.set_result_content_type(ResultContentType.LLM_RESULT))
yield # rick roll
except BaseException as e:
logger.error(traceback.format_exc())
event.set_result(MessageEventResult().message("AstrBot 请求 Dify 失败:" + str(e)))
return
@@ -1,113 +1,101 @@
'''
本地 Agent 模式的 LLM 调用 Stage
'''
import traceback
import inspect
from typing import Union, AsyncGenerator
from ...context import PipelineContext
from ..stage import Stage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
from astrbot.core.message.message_event_result import MessageEventResult, ResultContentType
from astrbot.core.message.components import Image
from astrbot.core import logger
from astrbot.core.utils.metrics import Metric
from astrbot.core.star.star import star_map
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core.star.star_handler import star_handlers_registry, EventType
class LLMRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
self.prompt_prefix = ctx.astrbot_config['provider_settings']['prompt_prefix']
self.identifier = ctx.astrbot_config['provider_settings']['identifier']
self.ctx = ctx
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
# Chat 唤醒前缀
if self.ctx.astrbot_config['provider_settings']['wake_prefix']:
if not event.message_str.startswith(self.ctx.astrbot_config['provider_settings']['wake_prefix']):
return
event.message_str = event.message_str[len(self.ctx.astrbot_config['provider_settings']['wake_prefix']):]
if self.prompt_prefix:
event.message_str = self.prompt_prefix + event.message_str
if self.identifier:
user_id = event.message_obj.sender.user_id
user_nickname = event.message_obj.sender.nickname
user_info = f"[User ID: {user_id}, Nickname: {user_nickname}]\n"
event.message_str = user_info + event.message_str
image_urls = []
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
image_urls.append(image_url)
tools = self.ctx.plugin_manager.context.get_llm_tool_manager()
async def process(self, event: AstrMessageEvent, _nested: bool = False) -> Union[None, AsyncGenerator[None, None]]:
req: ProviderRequest = None
provider = self.ctx.plugin_manager.context.get_using_provider()
try:
llm_response = await provider.text_chat(
prompt=event.message_str,
session_id=event.session_id,
image_urls=image_urls,
func_tool=tools
)
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
if provider is None:
return
if event.get_extra("provider_request"):
req = event.get_extra("provider_request")
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
else:
req = ProviderRequest(prompt="", image_urls=[])
if self.ctx.astrbot_config['provider_settings']['wake_prefix']:
if not event.message_str.startswith(self.ctx.astrbot_config['provider_settings']['wake_prefix']):
return
req.prompt = event.message_str[len(self.ctx.astrbot_config['provider_settings']['wake_prefix']):]
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
for comp in event.message_obj.message:
if isinstance(comp, Image):
image_url = comp.url if comp.url else comp.file
req.image_urls.append(image_url)
req.session_id = event.session_id
event.set_extra("provider_request", req)
session_provider_context = provider.session_memory.get(event.session_id)
req.contexts = session_provider_context if session_provider_context else []
if not req.prompt:
return
# 执行请求 LLM 前事件。
# 装饰 system_prompt 等功能
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnLLMRequestEvent)
for handler in handlers:
try:
await handler.handler(event, req)
except BaseException:
logger.error(traceback.format_exc())
try:
logger.debug(f"提供商请求 Payload: {req.__dict__}")
if _nested:
req.func_tool = None # 暂时不支持递归工具调用
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
await Metric.upload(llm_tick=1, model_name=provider.get_model(), provider_type=provider.meta().type)
if llm_response.role == 'assistant':
# text completion
event.set_result(MessageEventResult().message(llm_response.completion_text))
event.set_result(MessageEventResult().message(llm_response.completion_text)
.set_result_content_type(ResultContentType.LLM_RESULT))
elif llm_response.role == 'tool':
# function calling
function_calling_result = {}
for func_tool_name, func_tool_args in zip(llm_response.tools_call_name, llm_response.tools_call_args):
func_tool = tools.get_func(func_tool_name)
func_tool = req.func_tool.get_func(func_tool_name)
logger.info(f"调用工具函数:{func_tool_name},参数:{func_tool_args}")
try:
# 尝试调用工具函数
star_cls_obj = star_map.get(func_tool.module_name).star_cls
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
ready_to_call = None
if hasattr(func_tool.func_obj, '__self__'):
# 猜测没有通过装饰器去注册
try:
ready_to_call = func_tool.func_obj(event, **func_tool_args)
except TypeError:
# 向下兼容
ready_to_call = func_tool.func_obj(event, self.ctx.plugin_manager.context, **func_tool_args)
else:
ready_to_call = func_tool.func_obj(star_cls_obj, event, **func_tool_args)
if isinstance(ready_to_call, AsyncGenerator):
async for mer in ready_to_call:
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值)
if mer:
assert isinstance(mer, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
event.set_result(mer)
yield
else:
if event.get_result():
yield
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个 coroutine
ret = await ready_to_call
if ret:
# 如果有返回值
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
event.set_result(ret)
# 执行后续步骤来发送消息
if event.is_stopped() and event.get_result():
# 主动停止事件传播,并且有结果
event.continue_event()
yield
event.clear_result()
event.stop_event()
yield
elif not event.is_stopped and not event.get_result():
continue
wrapper = self._call_handler(self.ctx, event, func_tool.handler, **func_tool_args)
async for resp in wrapper:
if resp is not None:
function_calling_result[func_tool_name] = resp
else:
yield
event.clear_result() # 清除上一个 handler 的结果
except BaseException as e:
logger.warning(traceback.format_exc())
function_calling_result[func_tool_name] = "When calling the function, an error occurred: " + str(e)
if function_calling_result:
# 工具返回 LLM 资源。比如 RAG、网页 得到的相关结果等。
# 我们重新执行一遍这个 stage
req.func_tool = None # 暂时不支持递归工具调用
extra_prompt = "\n\nSystem executed some external tools for this task and here are the results:\n"
for tool_name, tool_result in function_calling_result.items():
extra_prompt += f"Tool: {tool_name}\nTool Result: {tool_result}\n"
req.prompt += extra_prompt
async for _ in self.process(event, _nested=True):
yield
except BaseException:
logger.error(traceback.format_exc())
except BaseException as e:
logger.error(traceback.format_exc())
event.set_result(MessageEventResult().message("AstrBot 请求 LLM 资源失败:" + str(e)))
@@ -1,13 +1,16 @@
'''
本地 Agent 模式的 AstrBot 插件调用 Stage
'''
from ...context import PipelineContext
from ..stage import Stage
from typing import Dict, Any, List, AsyncGenerator, Union
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core import logger
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core.star.star import star_map
import traceback
import inspect
class StarRequestSubStage(Stage):
async def initialize(self, ctx: PipelineContext) -> None:
@@ -24,58 +27,19 @@ class StarRequestSubStage(Stage):
for handler in activated_handlers:
params = handlers_parsed_params.get(handler.handler_full_name, {})
try:
if handler.handler_module_str not in star_map:
if handler.handler_module_path not in star_map:
# 孤立无援的 star handler
continue
star_cls_obj = star_map.get(handler.handler_module_str).star_cls
logger.debug(f"执行 Star Handler {handler.handler_full_name}")
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
ready_to_call = None
if hasattr(handler.handler, '__self__'):
# 猜测没有通过装饰器去注册
try:
ready_to_call = handler.handler(event, **params)
except TypeError:
# 向下兼容
ready_to_call = handler.handler(event, self.ctx.plugin_manager.context, **params)
else:
ready_to_call = handler.handler(star_cls_obj, event, **params)
if isinstance(ready_to_call, AsyncGenerator):
async for mer in ready_to_call:
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值)
if mer:
assert isinstance(mer, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
event.set_result(mer)
yield
else:
if event.get_result():
yield
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个 coroutine
ret = await ready_to_call
if ret:
# 如果有返回值
assert isinstance(ret, (MessageEventResult, CommandResult)), "如果有返回值,必须是 MessageEventResult 或 CommandResult 类型。"
event.set_result(ret)
# 执行后续步骤来发送消息
if event.is_stopped() and event.get_result():
# 插件主动停止事件传播,并且有结果
event.continue_event()
yield
event.clear_result()
event.stop_event()
yield
elif not event.is_stopped and not event.get_result():
continue
else:
yield
wrapper = self._call_handler(self.ctx, event, handler.handler, **params)
async for ret in wrapper:
yield ret
event.clear_result() # 清除上一个 handler 的结果
except Exception as e:
logger.error(traceback.format_exc())
logger.error(f"Star {handler.handler_full_name} handle error: {e}")
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_str).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
ret = f":(\n\n在调用插件 {star_map.get(handler.handler_module_path).name} 的处理函数 {handler.handler_name} 时出现异常:{e}"
event.set_result(MessageEventResult().message(ret))
yield
event.clear_result()
+31 -9
View File
@@ -3,8 +3,11 @@ from ..stage import Stage, register_stage
from ..context import PipelineContext
from .method.llm_request import LLMRequestSubStage
from .method.star_request import StarRequestSubStage
from .method.dify_request import DifyRequestSubStage
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.star_handler import StarHandlerMetadata
from astrbot.core.provider.entites import ProviderRequest
from astrbot.core import logger
@register_stage
class ProcessStage(Stage):
@@ -18,19 +21,38 @@ class ProcessStage(Stage):
self.star_request_sub_stage = StarRequestSubStage()
await self.star_request_sub_stage.initialize(ctx)
self.dify_request_sub_stage = DifyRequestSubStage()
await self.dify_request_sub_stage.initialize(ctx)
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
'''处理事件
'''
activated_handlers: List[StarHandlerMetadata] = event.get_extra("activated_handlers")
# 有插件 Handler 被激活
if activated_handlers:
async for _ in self.star_request_sub_stage.process(event):
yield
if self.ctx.astrbot_config['provider_settings'].get('enable', True):
if not event._has_send_oper:
'''当没有发送操作'''
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
async for resp in self.star_request_sub_stage.process(event):
# 生成器返回值处理
if isinstance(resp, ProviderRequest):
# Handler 的 LLM 请求
logger.debug(f"llm request -> {resp.prompt}")
event.set_extra("provider_request", resp)
async for _ in self.llm_request_sub_stage.process(event):
yield
yield
else:
yield
# 调用提供商相关请求
if not self.ctx.astrbot_config['provider_settings'].get('enable', True):
return
if not event._has_send_oper and event.is_at_or_wake_command:
if (event.get_result() and not event.get_result().is_stopped()) or not event.get_result():
provider = self.ctx.plugin_manager.context.get_using_provider()
match provider.meta().type:
case "dify":
async for _ in self.dify_request_sub_stage.process(event):
yield
case _:
async for _ in self.llm_request_sub_stage.process(event):
yield
+11 -4
View File
@@ -1,11 +1,12 @@
from typing import Union, AsyncGenerator
from ..stage import register_stage
from ..stage import register_stage, Stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core import logger
from astrbot.core.star.star_handler import star_handlers_registry, EventType
@register_stage
class RespondStage:
class RespondStage(Stage):
async def initialize(self, ctx: PipelineContext):
self.ctx = ctx
@@ -13,8 +14,14 @@ class RespondStage:
result = event.get_result()
if result is None:
return
if len(result.chain) > 0:
await event.send(result)
logger.info(f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}")
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent)
for handler in handlers:
# TODO: 如何让这里的 handler 也能使用 LLM 能力。也许需要将 LLMRequestSubStage 提取出来。
await handler.handler(event)
event.clear_result()
@@ -6,6 +6,7 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core import logger
from astrbot.core.message.components import Plain, Image
from astrbot.core import html_renderer
from astrbot.core.star.star_handler import star_handlers_registry, EventType
@register_stage
class ResultDecorateStage:
@@ -19,6 +20,11 @@ class ResultDecorateStage:
if result is None:
return
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnDecoratingResultEvent)
for handler in handlers:
# TODO: 如何让这里的 handler 也能使用 LLM 能力。也许需要将 LLMRequestSubStage 提取出来。
await handler.handler(event)
if len(result.chain) > 0:
# 回复前缀
if self.reply_prefix:
+4
View File
@@ -41,4 +41,8 @@ class PipelineScheduler():
async def execute(self, event: AstrMessageEvent):
'''执行 pipeline'''
await self._process_stages(event)
if not event._has_send_oper and event.get_platform_name() == "webchat":
await event.send(None)
logger.debug("pipeline 执行完毕。")
+35 -2
View File
@@ -1,8 +1,10 @@
from __future__ import annotations
import abc
from typing import List, AsyncGenerator, Union
import inspect
from typing import List, AsyncGenerator, Union, Awaitable
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from .context import PipelineContext
from astrbot.core.message.message_event_result import MessageEventResult, CommandResult
registered_stages: List[Stage] = []
'''维护了所有已注册的 Stage 实现类'''
@@ -29,4 +31,35 @@ class Stage(abc.ABC):
'''
raise NotImplementedError
async def _call_handler(
self,
ctx: PipelineContext,
event: AstrMessageEvent,
handler: Awaitable,
**params
) -> AsyncGenerator[None, None]:
'''调用 Handler。'''
# 判断 handler 是否是类方法(通过装饰器注册的没有 __self__ 属性)
ready_to_call = None
try:
ready_to_call = handler(event, **params)
except TypeError as e:
# 向下兼容
ready_to_call = handler(event, ctx.plugin_manager.context, **params)
if isinstance(ready_to_call, AsyncGenerator):
async for ret in ready_to_call:
# 如果处理函数是生成器,返回值只能是 MessageEventResult 或者 None(无返回值)
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
else:
yield ret
elif inspect.iscoroutine(ready_to_call):
# 如果只是一个 coroutine
ret = await ready_to_call
if isinstance(ret, (MessageEventResult, CommandResult)):
event.set_result(ret)
yield
else:
yield ret
+5 -2
View File
@@ -4,7 +4,7 @@ from typing import Union, AsyncGenerator
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.message.components import At
from astrbot.core.star.star_handler import star_handlers_registry
from astrbot.core.star.star_handler import star_handlers_registry, EventType
from astrbot.core.star.filter.command_group import CommandGroupFilter
@@ -47,6 +47,7 @@ class WakingCheckStage(Stage):
# 如果是群聊,且第一个消息段是 At 消息,但不是 At 机器人或 At 全体成员,则不唤醒
break
is_wake = True
event.is_at_or_wake_command = True
event.is_wake = True
event.message_str = event.message_str[len(wake_prefix) :].strip()
break
@@ -60,17 +61,19 @@ class WakingCheckStage(Stage):
is_wake = True
event.is_wake = True
wake_prefix = ""
event.is_at_or_wake_command = True
break
# 检查是否是私聊
if event.is_private_chat():
is_wake = True
event.is_wake = True
event.is_at_or_wake_command = True
wake_prefix = ""
# 检查插件的 handler filter
activated_handlers = []
handlers_parsed_params = {} # 注册了指令的 handler
for handler in star_handlers_registry:
for handler in star_handlers_registry.get_handlers_by_event_type(EventType.AdapterMessageEvent):
# filter 需要满足 AND 的逻辑关系
passed = True
child_command_handler_md = None
@@ -10,12 +10,20 @@ class WhitelistCheckStage(Stage):
'''检查是否在群聊/私聊白名单
'''
async def initialize(self, ctx: PipelineContext) -> None:
self.enable_whitelist_check = ctx.astrbot_config['platform_settings']['enable_id_white_list']
self.whitelist = ctx.astrbot_config['platform_settings']['id_whitelist']
self.wl_ignore_admin_on_group = ctx.astrbot_config['platform_settings']['wl_ignore_admin_on_group']
self.wl_ignore_admin_on_friend = ctx.astrbot_config['platform_settings']['wl_ignore_admin_on_friend']
self.wl_log = ctx.astrbot_config['platform_settings']['id_whitelist_log']
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
if not self.enable_whitelist_check:
return
if event.get_platform_name() == 'webchat':
# WebChat 豁免
return
# 检查是否在白名单
if self.wl_ignore_admin_on_group:
if event.role == 'admin' and event.get_message_type() == MessageType.GROUP_MESSAGE:
+37 -3
View File
@@ -7,6 +7,8 @@ from astrbot.core.platform.message_type import MessageType
from typing import List, Union
from astrbot.core.message.components import Plain, Image, BaseMessageComponent, Face, At, AtAll, Forward
from astrbot.core.utils.metrics import Metric
from astrbot.core.provider.entites import ProviderRequest
@dataclass
class MessageSesion:
@@ -33,7 +35,8 @@ class AstrMessageEvent(abc.ABC):
self.platform_meta = platform_meta
self.session_id = session_id
self.role = "member"
self.is_wake = False
self.is_wake = False # 是否通过 WakingStage
self.is_at_or_wake_command = False # 是否是 At 机器人或者带有唤醒词或者是私聊(事件监听器会让 is_wake 设为 True
self._extras = {}
self.session = MessageSesion(
platform_name=platform_meta.name,
@@ -236,7 +239,9 @@ class AstrMessageEvent(abc.ABC):
清除消息事件的结果。
'''
self._result = None
'''消息链相关'''
def make_result(self) -> MessageEventResult:
'''
创建一个空的消息事件结果。
@@ -275,4 +280,33 @@ class AstrMessageEvent(abc.ABC):
'''
mer = MessageEventResult()
mer.chain = chain
return mer
return mer
'''LLM 请求相关'''
def request_llm(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = None,
contexts: List = None,
system_prompt: str = ""
) -> ProviderRequest:
'''
创建一个 LLM 请求。
Examples:
```py
yield event.request_llm(prompt="hi")
```
image_urls: 可以是 base64:// 或者 http:// 开头的图片链接,也可以是本地图片路径。
contexts: 当指定 contexts 时,将会**只**使用 contexts 作为上下文。
'''
return ProviderRequest(
prompt = prompt,
session_id = session_id,
image_urls = image_urls,
contexts = contexts,
system_prompt = system_prompt
)
+4 -1
View File
@@ -4,7 +4,7 @@ from typing import List
from asyncio import Queue
from .register import platform_cls_map
from astrbot.core import logger
from .sources.webchat.webchat_adapter import WebChatAdapter
class PlatformManager():
def __init__(self, config: AstrBotConfig, event_queue: Queue):
@@ -25,6 +25,7 @@ class PlatformManager():
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
case "vchat":
from .sources.vchat.vchat_platform_adapter import VChatPlatformAdapter # noqa: F401
async def initialize(self):
for platform in self.platforms_config:
@@ -37,6 +38,8 @@ class PlatformManager():
logger.info(f"尝试实例化 {platform['type']}({platform['id']}) 平台适配器 ...")
inst = cls_type(platform, self.settings, self.event_queue)
self.platform_insts.append(inst)
self.platform_insts.append(WebChatAdapter({}, self.settings, self.event_queue))
def get_insts(self):
return self.platform_insts
+3 -1
View File
@@ -2,4 +2,6 @@ from dataclasses import dataclass
@dataclass
class PlatformMetadata():
name: str # 平台的名称
description: str # 平台的描述
description: str # 平台的描述
default_config_tmpl: dict = None # 平台的默认配置模板
+13 -2
View File
@@ -7,15 +7,26 @@ platform_registry: List[PlatformMetadata] = []
platform_cls_map: Dict[str, Type] = {}
'''维护了平台适配器名称和适配器类的映射'''
def register_platform_adapter(adapter_name: str, desc: str):
'''用于注册平台适配器的带参装饰器'''
def register_platform_adapter(adapter_name: str, desc: str, default_config_tmpl: dict = None):
'''用于注册平台适配器的带参装饰器
default_config_tmpl 指定了平台适配器的默认配置模板。用户填写好后将会作为 platform_config 传入你的 Platform 类的实现类。
'''
def decorator(cls):
if adapter_name in platform_cls_map:
raise ValueError(f"平台适配器 {adapter_name} 已经注册过了,可能发生了适配器命名冲突。")
# 添加必备选项
if default_config_tmpl:
if 'type' not in default_config_tmpl:
default_config_tmpl['type'] = adapter_name
if 'enable' not in default_config_tmpl:
default_config_tmpl['enable'] = False
pm = PlatformMetadata(
name=adapter_name,
description=desc,
default_config_tmpl=default_config_tmpl
)
platform_registry.append(pm)
platform_cls_map[adapter_name] = cls
@@ -1,3 +1,4 @@
import os
import time
import asyncio
import logging
@@ -5,12 +6,13 @@ from typing import Awaitable, Any
from aiocqhttp import CQHttp, Event
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from .aiocqhttp_message_event import *
from astrbot.api.message_components import *
from .aiocqhttp_message_event import * # noqa: F403
from astrbot.api.message_components import * # noqa: F403
from astrbot.api import logger
from .aiocqhttp_message_event import AiocqhttpMessageEvent
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from aiocqhttp.exceptions import ActionFailed
@register_platform_adapter("aiocqhttp", "适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。")
class AiocqhttpAdapter(Platform):
@@ -42,7 +44,7 @@ class AiocqhttpAdapter(Platform):
await self.bot.send_private_msg(user_id=session.session_id, message=ret)
await super().send_by_session(session, message_chain)
def convert_message(self, event: Event) -> AstrBotMessage:
async def convert_message(self, event: Event) -> AstrBotMessage:
abm = AstrBotMessage()
abm.self_id = str(event.self_id)
abm.tag = "aiocqhttp"
@@ -78,7 +80,25 @@ class AiocqhttpAdapter(Platform):
a = None
if t == 'text':
message_str += m['data']['text'].strip()
a = ComponentTypes[t](**m['data'])
elif t == 'file':
try:
# Napcat, LLBot
ret = await self.bot.call_action(action="get_file", file_id=event.message[0]['data']['file_id'])
if not ret.get('file', None):
raise ValueError(f"无法解析文件响应: {ret}")
if not os.path.exists(ret['file']):
raise FileNotFoundError(f"文件不存在: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),暂时无法获取用户上传的文件。")
m['data'] = {
"file": ret['file'],
"name": ret['file_name']
}
except ActionFailed as e:
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
except BaseException as e:
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
a = ComponentTypes[t](**m['data']) # noqa: F405
abm.message.append(a)
abm.timestamp = int(time.time())
abm.message_str = message_str
@@ -91,13 +111,13 @@ class AiocqhttpAdapter(Platform):
self.bot = CQHttp(use_ws_reverse=True, import_name='aiocqhttp', api_timeout_sec=180)
@self.bot.on_message('group')
async def group(event: Event):
abm = self.convert_message(event)
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@self.bot.on_message('private')
async def private(event: Event):
abm = self.convert_message(event)
abm = await self.convert_message(event)
if abm:
await self.handle_msg(abm)
@@ -31,11 +31,13 @@ class QQOfficialMessageEvent(AstrMessageEvent):
if image_base64:
media = await self.upload_group_and_c2c_image(image_base64, 1, group_openid=source.group_openid)
payload['media'] = media
payload['msg_type'] = 7
await self.bot.api.post_group_message(group_openid=source.group_openid, **payload)
case botpy.message.C2CMessage:
if image_base64:
media = await self.upload_group_and_c2c_image(image_base64, 1, openid=source.author.user_openid)
payload['media'] = media
payload['msg_type'] = 7
await self.bot.api.post_c2c_message(openid=source.author.user_openid, **payload)
case botpy.message.Message:
if image_path:
@@ -73,9 +75,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
plain_text += i.text
elif isinstance(i, Image) and not image_base64:
if i.file and i.file.startswith("file:///"):
image_base64 = file_to_base64(i.file[8:])
image_base64 = file_to_base64(i.file[8:]).replace("base64://", "")
image_file_path = i.file[8:]
elif i.file and i.file.startswith("http"):
image_file_path = await download_image_by_url(i.file)
image_base64 = file_to_base64(image_file_path)
image_base64 = file_to_base64(image_file_path).replace("base64://", "")
return plain_text, image_base64, image_file_path
@@ -11,7 +11,7 @@ from botpy import Client
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from typing import Union, List
from astrbot.api.message_components import Image, Plain
from astrbot.api.message_components import Image, Plain, At
from astrbot.core.platform.astr_message_event import MessageSesion
from .qqofficial_message_event import QQOfficialMessageEvent
from ...register import register_platform_adapter
@@ -111,6 +111,7 @@ class QQOfficialPlatformAdapter(Platform):
abm.message_id = message.id
abm.tag = "qq_official"
msg: List[BaseMessageComponent] = []
if isinstance(message, botpy.message.GroupMessage) or isinstance(message, botpy.message.C2CMessage):
if isinstance(message, botpy.message.GroupMessage):
@@ -126,7 +127,7 @@ class QQOfficialPlatformAdapter(Platform):
)
abm.message_str = message.content.strip()
abm.self_id = "unknown_selfid"
msg.append(At(qq="qq_official"))
msg.append(Plain(abm.message_str))
if message.attachments:
for i in message.attachments:
@@ -146,7 +147,7 @@ class QQOfficialPlatformAdapter(Platform):
plain_content = message.content.replace(
"<@!"+str(abm.self_id)+">", "").strip()
msg.append(Plain(plain_content))
if message.attachments:
for i in message.attachments:
if i.content_type.startswith("image"):
@@ -161,11 +162,14 @@ class QQOfficialPlatformAdapter(Platform):
str(message.author.id),
str(message.author.username)
)
msg.append(At(qq="qq_official"))
msg.append(Plain(plain_content))
if isinstance(message, botpy.message.Message):
abm.group_id = message.channel_id
else:
raise ValueError(f"Unknown message type: {message_type}")
abm.self_id = "qq_official"
return abm
def run(self):
@@ -2,6 +2,7 @@ import sys
import time
import uuid
import asyncio
import os
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
@@ -62,7 +63,7 @@ class VChatPlatformAdapter(Platform):
self.start_time = int(time.time())
return self._run()
async def _run(self):
await self.client.init()
await self.client.auto_login(hot_reload=True, enable_cmd_qr=True)
@@ -0,0 +1,110 @@
import time
import asyncio
import uuid
import os
from typing import Awaitable, Any
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from astrbot.api.message_components import Plain, Image, Record # noqa: F403
from astrbot.api import logger
from astrbot.core import web_chat_queue, web_chat_back_queue
from .webchat_event import WebChatMessageEvent
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
class QueueListener:
def __init__(self, queue: asyncio.Queue, callback: callable) -> None:
self.queue = queue
self.callback = callback
async def run(self):
while True:
data = await self.queue.get()
await self.callback(data)
@register_platform_adapter("webchat", "webchat")
class WebChatAdapter(Platform):
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
super().__init__(event_queue)
self.config = platform_config
self.settings = platform_settings
self.unique_session = platform_settings['unique_session']
self.imgs_dir = "data/webchat/imgs"
self.metadata = PlatformMetadata(
"webchat",
"webchat",
)
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
plain = ""
for comp in message_chain.chain:
if isinstance(comp, Plain):
plain += comp.text
web_chat_back_queue.put_nowait(plain)
await super().send_by_session(session, message_chain)
async def convert_message(self, data: tuple) -> AstrBotMessage:
username, cid, payload = data
abm = AstrBotMessage()
abm.self_id = "webchat"
abm.tag = "webchat"
abm.sender = MessageMember(username, username)
abm.type = MessageType.FRIEND_MESSAGE
abm.session_id = f"webchat!{username}!{cid}"
abm.message_id = str(uuid.uuid4())
abm.message = []
if payload['message']:
abm.message.append(Plain(payload['message']))
if payload['image_url']:
if isinstance(payload['image_url'], list):
for img in payload['image_url']:
abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, img)))
else:
abm.message.append(Image.fromFileSystem(os.path.join(self.imgs_dir, payload['image_url'])))
if payload['audio_url']:
if isinstance(payload['audio_url'], list):
for audio in payload['audio_url']:
path = os.path.join(self.imgs_dir, audio)
abm.message.append(Record(file=path, path=path))
else:
path = os.path.join(self.imgs_dir, payload['audio_url'])
abm.message.append(Record(file=path, path=path))
logger.debug(f"WebChatAdapter: {abm.message}")
message_str = payload['message']
abm.timestamp = int(time.time())
abm.message_str = message_str
abm.raw_message = data
return abm
def run(self) -> Awaitable[Any]:
async def callback(data: tuple):
abm = await self.convert_message(data)
await self.handle_msg(abm)
bot = QueueListener(web_chat_queue, callback)
return bot.run()
def meta(self) -> PlatformMetadata:
return self.metadata
async def handle_msg(self, message: AstrBotMessage):
message_event = WebChatMessageEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id
)
self.commit_event(message_event)
@@ -0,0 +1,35 @@
import os
import uuid
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image
from astrbot.core.utils.io import file_to_base64, download_image_by_url
from astrbot.core import web_chat_back_queue
class WebChatMessageEvent(AstrMessageEvent):
def __init__(self, message_str, message_obj, platform_meta, session_id):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.imgs_dir = "data/webchat/imgs"
os.makedirs(self.imgs_dir, exist_ok=True)
async def send(self, message: MessageChain):
if not message:
web_chat_back_queue.put_nowait(None)
return
for comp in message.chain:
if isinstance(comp, Plain):
web_chat_back_queue.put_nowait(comp.text)
elif isinstance(comp, Image):
# save image to local
filename = str(uuid.uuid4()) + ".jpg"
path = os.path.join(self.imgs_dir, filename)
if comp.file and comp.file.startswith("file:///"):
ph = comp.file[8:]
with open(path, "wb") as f:
with open(ph, "rb") as f2:
f.write(f2.read())
elif comp.file and comp.file.startswith("http"):
await download_image_by_url(comp.file, path=path)
web_chat_back_queue.put_nowait(f"[IMAGE]{filename}")
web_chat_back_queue.put_nowait(None)
await super().send(message)
+3 -2
View File
@@ -1,9 +1,10 @@
from .provider import Provider, Personality
from .provider import Provider, Personality, STTProvider
from .provider_metadata import ProviderMetaData
from .entites import ProviderMetaData
__all__ = [
"Provider",
"Personality",
"ProviderMetaData",
"STTProvider"
]
+49
View File
@@ -0,0 +1,49 @@
import enum
from dataclasses import dataclass, field
from typing import List, Dict, Type
from .func_tool_manager import FuncCall
class ProviderType(enum.Enum):
CHAT_COMPLETION = "chat_completion"
SPEECH_TO_TEXT = "speech_to_text"
TEXT_TO_SPEECH = "text_to_speech"
@dataclass
class ProviderMetaData():
type: str
'''提供商适配器名称,如 openai, ollama'''
desc: str = ""
'''提供商适配器描述.'''
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
cls_type: Type = None
@dataclass
class ProviderRequest():
prompt: str
'''提示词'''
session_id: str = ""
'''会话 ID'''
image_urls: List[str] = None
'''图片 URL 列表'''
func_tool: FuncCall = None
'''工具'''
contexts: List = None
'''上下文。格式与 openai 的上下文格式一致:
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
'''
system_prompt: str = ""
'''系统提示词'''
@dataclass
class LLMResponse:
role: str
'''角色'''
completion_text: str = ""
'''LLM 返回的文本'''
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
'''工具调用参数'''
tools_call_name: List[str] = field(default_factory=list)
'''工具调用名称'''
@@ -1,25 +1,9 @@
import json
import textwrap
from typing import Awaitable, Dict, List
from typing import Dict, List, Awaitable
from dataclasses import dataclass
class FuncCallJsonFormatError(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return self.msg
class FuncNotFoundError(Exception):
def __init__(self, msg):
self.msg = msg
def __str__(self):
return self.msg
@dataclass
class FuncTool:
"""
@@ -29,8 +13,8 @@ class FuncTool:
name: str
parameters: Dict
description: str
func_obj: Awaitable
module_name: str = None
handler: Awaitable
handler_module_path: str = None # 必须要保留这个,handler 在初始化会被 functools.partial 包装,导致 handler 的 __module__ 为 functools
active: bool = True
'''是否激活'''
@@ -56,8 +40,7 @@ class FuncCall:
name: str,
func_args: list,
desc: str,
func_obj: Awaitable,
module_name: str = None,
handler: Awaitable,
) -> None:
"""
为函数调用function-calling / tools-use添加工具
@@ -80,8 +63,7 @@ class FuncCall:
name=name,
parameters=params,
description=desc,
func_obj=func_obj,
module_name=module_name,
handler=handler,
)
self.func_list.append(_func)
@@ -119,10 +101,29 @@ class FuncCall:
}
)
return _l
def get_func_desc_google_genai_style(self) -> Dict:
declarations = {}
tools = []
for f in self.func_list:
if not f.active:
continue
tools.append(
{
"name": f.name,
"parameters": f.parameters,
"description": f.description,
}
)
declarations["function_declarations"] = tools
return declarations
async def func_call(self, question: str, session_id: str, provider) -> tuple:
_l = []
for f in self.func_list:
if not f.active:
continue
_l.append(
{
"name": f["name"],
@@ -179,12 +180,19 @@ class FuncCall:
# 调用函数
tool_callable = None
for func in self.func_list:
if func["name"] == func_name:
tool_callable = func["func_obj"]
if func.name == func_name:
tool_callable = func.star_handler_metadata.handler
break
if not tool_callable:
raise FuncNotFoundError(f"Request function {func_name} not found.")
raise Exception(f"Request function {func_name} not found.")
ret = await tool_callable(**args)
if ret:
tool_call_result.append(str(ret))
return tool_call_result, True
def __str__(self):
return str(self.func_list)
def __repr__(self):
return str(self.func_list)
-13
View File
@@ -1,13 +0,0 @@
from typing import Dict, List
from dataclasses import dataclass
@dataclass
class LLMResponse:
role: str
'''角色'''
completion_text: str = None
'''LLM 返回的文本'''
tools_call_args: List[Dict[str, any]] = None
'''工具调用参数'''
tools_call_name: List[str] = None
'''工具调用名称'''
+93 -16
View File
@@ -1,22 +1,36 @@
import traceback
from astrbot.core.config.astrbot_config import AstrBotConfig
from .provider import Provider
from .provider import Provider, STTProvider
from .entites import ProviderType
from typing import List
from astrbot.core.db import BaseDatabase
from collections import defaultdict
from .register import provider_cls_map, llm_tools
from astrbot.core import logger
from astrbot.core import logger, sp
class ProviderManager():
def __init__(self, config: AstrBotConfig, db_helper: BaseDatabase):
self.providers_config: List = config['provider']
self.provider_settings: dict = config['provider_settings']
self.provider_stt_settings: dict = config.get('provider_stt_settings', {})
self.provider_insts: List[Provider] = []
'''加载的 Provider 的实例'''
self.stt_provider_insts: List[STTProvider] = []
'''加载的 Speech To Text Provider 的实例'''
self.llm_tools = llm_tools
self.curr_provider_inst: Provider = None
'''当前使用的 Provider 实例'''
self.curr_stt_provider_inst: STTProvider = None
'''当前使用的 Speech To Text Provider 实例'''
self.loaded_ids = defaultdict(bool)
self.db_helper = db_helper
self.curr_kdb_name = ""
kdb_cfg = config.get("knowledge_db", {})
if kdb_cfg and len(kdb_cfg):
self.curr_kdb_name = list(kdb_cfg.keys())[0]
for provider_cfg in self.providers_config:
if not provider_cfg['enable']:
continue
@@ -25,28 +39,91 @@ class ProviderManager():
raise ValueError(f"Provider ID 重复:{provider_cfg['id']}")
self.loaded_ids[provider_cfg['id']] = True
match provider_cfg['type']:
case "openai_chat_completion":
from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401
case "llm_tuner":
logger.info("加载 LLM Tuner 工具 ...")
from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401
try:
match provider_cfg['type']:
case "openai_chat_completion":
from .sources.openai_source import ProviderOpenAIOfficial # noqa: F401
case "zhipu_chat_completion":
from .sources.zhipu_source import ProviderZhipu # noqa: F401
case "llm_tuner":
logger.info("加载 LLM Tuner 工具 ...")
from .sources.llmtuner_source import LLMTunerModelLoader # noqa: F401
case "dify":
from .sources.dify_source import ProviderDify # noqa: F401
case "googlegenai_chat_completion":
from .sources.gemini_source import ProviderGoogleGenAI # noqa: F401
case "openai_whisper_api":
from .sources.whisper_api_source import ProviderOpenAIWhisperAPI # noqa: F401
case "openai_whisper_selfhost":
from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost # noqa: F401
except (ImportError, ModuleNotFoundError) as e:
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。")
continue
except Exception as e:
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。未知原因")
continue
async def initialize(self):
for provider_config in self.providers_config:
if not provider_config['enable']:
continue
if provider_config['type'] not in provider_cls_map:
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的 大模型提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
continue
cls_type = provider_cls_map[provider_config['type']]
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 大模型提供商适配器 ...")
inst = cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True))
self.provider_insts.append(inst)
selected_provider_id = sp.get("curr_provider")
selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
provider_enabled = self.provider_settings.get("enable", False)
stt_enabled = self.provider_stt_settings.get("enable", False)
provider_metadata = provider_cls_map[provider_config['type']]
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...")
try:
# 按任务实例化提供商
if provider_metadata.provider_type == ProviderType.SPEECH_TO_TEXT:
# STT 任务
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.stt_provider_insts.append(inst)
if selected_stt_provider_id == provider_config['id'] and stt_enabled:
self.curr_stt_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。")
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
# 文本生成任务
inst = provider_metadata.cls_type(provider_config, self.provider_settings, self.db_helper, self.provider_settings.get('persistant_history', True))
if getattr(inst, "initialize", None):
await inst.initialize()
self.provider_insts.append(inst)
if selected_provider_id == provider_config['id'] and provider_enabled:
self.curr_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。")
except Exception as e:
traceback.print_exc()
logger.error(f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}")
if len(self.provider_insts) > 0:
if len(self.provider_insts) > 0 and not self.curr_provider_inst and provider_enabled:
self.curr_provider_inst = self.provider_insts[0]
if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst and stt_enabled:
self.curr_stt_provider_inst = self.stt_provider_insts[0]
if not self.curr_provider_inst:
logger.warning("未启用任何用于 文本生成 的提供商适配器。")
if self.provider_stt_settings.get("enable"):
if not self.curr_stt_provider_inst:
logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
def get_insts(self):
return self.provider_insts
return self.provider_insts
async def terminate(self):
for provider_inst in self.provider_insts:
if hasattr(provider_inst, "terminate"):
await provider_inst.terminate()
+31 -5
View File
@@ -5,8 +5,8 @@ from typing import List
from astrbot.core.db import BaseDatabase
from astrbot.core import logger
from typing import TypedDict
from astrbot.core.provider.tool import FuncCall
from astrbot.core.provider.llm_response import LLMResponse
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.entites import LLMResponse
from dataclasses import dataclass
class Personality(TypedDict):
prompt: str = ""
@@ -99,6 +99,7 @@ class Provider(abc.ABC):
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts: List=None,
system_prompt: str=None,
**kwargs) -> LLMResponse:
'''获得 LLM 的文本对话结果。会使用当前的模型进行对话。
@@ -111,13 +112,11 @@ class Provider(abc.ABC):
kwargs: 其他参数
Notes:
- 如果传入了 contexts,将会提前加上上下文。否则使用 session_memory 中的上下文。
- 可以选择性地传入 session_id,如果传入了 session_id,将会使用 session_id 对应的上下文进行对话,
并且也会记录相应的对话上下文,实现多轮对话。如果不传入则不会记录上下文。
- 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
- 如果传入了 contexts,将会**直接**使用所提供的 contexts 进行对话。
传入此值通常意味着你需要自己维护 context,AstrBot 将不会记录上下文,并且会忽略 prompt、session_id、image_urls、tools。
'''
raise NotImplementedError()
@@ -126,6 +125,33 @@ class Provider(abc.ABC):
'''重置某一个 session_id 的上下文'''
raise NotImplementedError()
def meta(self) -> ProviderMeta:
'''获取 Provider 的元数据'''
return ProviderMeta(
id=self.provider_config['id'],
model=self.get_model(),
type=self.provider_config['type']
)
class STTProvider():
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
self.provider_config = provider_config
self.provider_settings = provider_settings
@abc.abstractmethod
async def get_text(self, audio_url: str) -> str:
'''获取音频的文本'''
raise NotImplementedError()
def set_model(self, model_name: str):
'''设置当前使用的模型名称'''
self.model_name = model_name
def get_model(self) -> str:
'''获取当前使用的模型'''
return self.provider_config.get("model", "")
def meta(self) -> ProviderMeta:
'''获取 Provider 的元数据'''
return ProviderMeta(
@@ -1,6 +0,0 @@
from dataclasses import dataclass
@dataclass
class ProviderMetaData():
type: str # 提供商适配器名称,如 openai, ollama
desc: str = "" # 提供商适配器描述.
+13 -48
View File
@@ -1,17 +1,20 @@
import docstring_parser
from typing import List, Dict, Type, Awaitable
from .provider_metadata import ProviderMetaData
from typing import List, Dict, Type
from .entites import ProviderMetaData, ProviderType
from astrbot.core import logger
from .tool import FuncCall, SUPPORTED_TYPES
from .func_tool_manager import FuncCall
provider_registry: List[ProviderMetaData] = []
'''维护了通过装饰器注册的 Provider'''
provider_cls_map: Dict[str, Type] = {}
'''维护了 Provider 类型名称和 Provider 的映射'''
provider_cls_map: Dict[str, ProviderMetaData] = {}
'''维护了 Provider 类型名称和 ProviderMetadata 的映射'''
llm_tools = FuncCall()
def register_provider_adapter(provider_type_name: str, desc: str):
def register_provider_adapter(
provider_type_name: str,
desc: str,
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
):
'''用于注册平台适配器的带参装饰器'''
def decorator(cls):
if provider_type_name in provider_cls_map:
@@ -20,50 +23,12 @@ def register_provider_adapter(provider_type_name: str, desc: str):
pm = ProviderMetaData(
type=provider_type_name,
desc=desc,
provider_type=provider_type,
cls_type=cls
)
provider_registry.append(pm)
provider_cls_map[provider_type_name] = cls
provider_cls_map[provider_type_name] = pm
logger.debug(f"Provider {provider_type_name} 已注册")
return cls
return decorator
def register_llm_tool(name: str = None):
'''为函数调用(function-calling / tools-use)添加工具。
请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释)
```
@llm_tool(name="get_weather") # 如果 name 不填,将使用函数名
async def get_weather(event: AstrMessageEvent, location: str) -> MessageEventResult:
\'\'\'获取天气信息。
Args:
location(string): 地点
\'\'\'
# 处理逻辑
```
可接受的参数类型有:string, number, object, array, boolean。
'''
name_ = name
def decorator(func_obj: Awaitable):
llm_tool_name = name_ if name_ else func_obj.__name__
module_name = func_obj.__module__
docstring = docstring_parser.parse(func_obj.__doc__)
args = []
for arg in docstring.params:
if arg.type_name not in SUPPORTED_TYPES:
raise ValueError(f"LLM 函数工具 {func_obj.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}")
args.append({
"type": arg.type_name,
"name": arg.arg_name,
"description": arg.description
})
llm_tools.add_func(llm_tool_name, args, docstring.short_description, func_obj, module_name)
logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册")
return func_obj
return decorator
@@ -0,0 +1,137 @@
from typing import List
from .. import Provider
from ..entites import LLMResponse
from ..func_tool_manager import FuncCall
from astrbot.core.db import BaseDatabase
from ..register import register_provider_adapter
from astrbot.core.utils.dify_api_client import DifyAPIClient
from astrbot.core.utils.io import download_image_by_url
from astrbot.core import logger, sp
@register_provider_adapter("dify", "Dify APP 适配器。")
class ProviderDify(Provider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history=False,
) -> None:
super().__init__(
provider_config, provider_settings, persistant_history, db_helper
)
self.api_key = provider_config.get("dify_api_key", "")
if not self.api_key:
raise Exception("Dify API Key 不能为空。")
api_base = provider_config.get("dify_api_base", "https://api.dify.ai/v1")
self.api_client = DifyAPIClient(self.api_key, api_base)
self.api_type = provider_config.get("dify_api_type", "")
if not self.api_type:
raise Exception("Dify API 类型不能为空。")
self.model_name = "dify"
self.workflow_output_key = provider_config.get("dify_workflow_output_key", "astrbot_wf_output")
self.conversation_ids = {}
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = None,
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
**kwargs,
) -> LLMResponse:
result = ""
conversation_id = self.conversation_ids.get(session_id, "")
files_payload = []
for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
file_response = await self.api_client.file_upload(image_path, user=session_id)
if 'id' not in file_response:
logger.warning(f"上传图片后得到未知的 Dify 响应:{file_response},图片将忽略。")
continue
files_payload.append({
"type": "image",
"transfer_method": "local_file",
"upload_file_id": file_response['id'],
})
else:
# TODO: 处理更多情况
logger.warning(f"未知的图片链接:{image_url},图片将忽略。")
logger.debug(files_payload)
# 获得会话变量
session_vars = sp.get("session_variables", {})
session_var = session_vars.get(session_id, {})
match self.api_type:
case "chat" | "agent":
async for chunk in self.api_client.chat_messages(
inputs={
**session_var
},
query=prompt,
user=session_id,
conversation_id=conversation_id,
files=files_payload
):
logger.debug(f"dify resp chunk: {chunk}")
if chunk['event'] == "message" or \
chunk['event'] == "agent_message":
result += chunk['answer']
if not conversation_id:
self.conversation_ids[session_id] = chunk['conversation_id']
conversation_id = chunk['conversation_id']
case "workflow":
async for chunk in self.api_client.workflow_run(
inputs={
"astrbot_text_query": prompt,
"astrbot_session_id": session_id,
**session_var
},
user=session_id,
files=files_payload
):
match chunk['event']:
case "workflow_started":
logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})开始运行。")
case "node_finished":
logger.debug(f"Dify 工作流节点(ID: {chunk['data']['node_id']} Title: {chunk['data'].get('title', '')})运行结束。")
case "workflow_finished":
logger.info(f"Dify 工作流(ID: {chunk['workflow_run_id']})运行结束。")
if chunk['data']['error']:
logger.error(f"Dify 工作流出现错误:{chunk['data']['error']}")
raise Exception(f"Dify 工作流出现错误:{chunk['data']['error']}")
if self.workflow_output_key not in chunk['data']['outputs']:
raise Exception(f"Dify 工作流的输出不包含指定的键名:{self.workflow_output_key}")
result = chunk['data']['outputs'][self.workflow_output_key]
case _:
raise Exception(f"未知的 Dify API 类型:{self.api_type}")
return LLMResponse(role="assistant", completion_text=result)
async def forget(self, session_id):
self.conversation_ids.pop(session_id, None)
return True
async def get_current_key(self):
return self.api_key
async def set_key(self, key):
raise Exception("Dify 适配器不支持设置 API Key。")
async def get_models(self):
return [self.get_model()]
async def get_human_readable_context(self, session_id, page, page_size):
raise Exception("暂不支持获得 Dify 的历史消息记录。")
async def terminate(self):
await self.api_client.close()
@@ -0,0 +1,287 @@
import traceback
import base64
import json
import aiohttp
from astrbot.core.utils.io import download_image_by_url
from astrbot.core.db import BaseDatabase
from astrbot.api.provider import Provider
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
from ..register import register_provider_adapter
from astrbot.core.provider.entites import LLMResponse
class SimpleGoogleGenAIClient():
def __init__(self, api_key: str, api_base: str):
self.api_key = api_key
if api_base.endswith("/"):
self.api_base = api_base[:-1]
else:
self.api_base = api_base
self.client = aiohttp.ClientSession()
async def models_list(self) -> List[str]:
request_url = f"{self.api_base}/v1beta/models?key={self.api_key}"
async with self.client.get(request_url, timeout=10) as resp:
response = await resp.json()
models = []
for model in response["models"]:
if 'generateContent' in model["supportedGenerationMethods"]:
models.append(model["name"].replace("models/", ""))
return models
async def generate_content(
self,
contents: List[dict],
model: str="gemini-1.5-flash",
system_instruction: str="",
tools: dict=None
):
payload = {}
if system_instruction:
payload["system_instruction"] = {
"parts": {"text": system_instruction}
}
if tools:
payload["tools"] = [tools]
payload["contents"] = contents
logger.debug(f"payload: {payload}")
request_url = f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}"
async with self.client.post(request_url, json=payload, timeout=10) as resp:
response = await resp.json()
return response
@register_provider_adapter("googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器")
class ProviderGoogleGenAI(Provider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history = True
) -> None:
super().__init__(provider_config, provider_settings, persistant_history, db_helper)
self.chosen_api_key = None
self.api_keys: List = provider_config.get("key", [])
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
self.client = SimpleGoogleGenAIClient(
api_key=self.chosen_api_key,
api_base=provider_config.get("api_base", None)
)
self.set_model(provider_config['model_config']['model'])
async def get_human_readable_context(self, session_id, page, page_size):
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
contexts = []
temp_contexts = []
for record in self.session_memory[session_id]:
if record['role'] == "user":
temp_contexts.append(f"User: {record['content']}")
elif record['role'] == "assistant":
temp_contexts.append(f"Assistant: {record['content']}")
contexts.insert(0, temp_contexts)
temp_contexts = []
# 展平 contexts 列表
contexts = [item for sublist in contexts for item in sublist]
# 计算分页
paged_contexts = contexts[(page-1)*page_size:page*page_size]
total_pages = len(contexts) // page_size
if len(contexts) % page_size != 0:
total_pages += 1
return paged_contexts, total_pages
async def get_models(self):
return await self.client.models_list()
async def pop_record(self, session_id: str, pop_system_prompt: bool = False):
'''
弹出第一条记录
'''
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
if len(self.session_memory[session_id]) == 0:
return None
for i in range(len(self.session_memory[session_id])):
# 检查是否是 system prompt
if not pop_system_prompt and self.session_memory[session_id][i]['user']['role'] == "system":
# 如果只有一个 system prompt,才不删掉
f = False
for j in range(i+1, len(self.session_memory[session_id])):
if self.session_memory[session_id][j]['user']['role'] == "system":
f = True
break
if not f:
continue
record = self.session_memory[session_id].pop(i)
break
return record
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
tool = None
if tools:
tool = tools.get_func_desc_google_genai_style()
system_instruction = ""
for message in payloads["messages"]:
if message["role"] == "system":
system_instruction = message["content"]
break
google_genai_conversation = []
for message in payloads["messages"]:
if message["role"] == "user":
if isinstance(message["content"], str):
google_genai_conversation.append({
"role": "user",
"parts": [{"text": message["content"]}]
})
elif isinstance(message["content"], list):
# images
parts = []
for part in message["content"]:
if part["type"] == "text":
parts.append({"text": part["text"]})
elif part["type"] == "image_url":
parts.append({"inline_data": {
"mime_type": "image/jpeg",
"data": part["image_url"]["url"].replace("data:image/jpeg;base64,", "") # base64
}})
google_genai_conversation.append({
"role": "user",
"parts": parts
})
elif message["role"] == "assistant":
google_genai_conversation.append({
"role": "model",
"parts": [{"text": message["content"]}]
})
logger.debug(f"google_genai_conversation: {google_genai_conversation}")
result = await self.client.generate_content(
contents=google_genai_conversation,
model=self.get_model(),
system_instruction=system_instruction,
tools=tool
)
logger.debug(f"result: {result}")
candidates = result["candidates"][0]['content']['parts']
llm_response = LLMResponse("assistant")
for candidate in candidates:
if 'text' in candidate:
llm_response.completion_text += candidate['text']
elif 'functionCall' in candidate:
llm_response.role = "tool"
llm_response.tools_call_args.append(candidate['functionCall']['args'])
llm_response.tools_call_name.append(candidate['functionCall']['name'])
return llm_response
async def text_chat(
self,
prompt: str,
session_id: str,
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts=None,
system_prompt=None,
**kwargs
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
if not contexts:
context_query = [*self.session_memory[session_id], new_record]
else:
context_query = [*contexts, new_record]
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
payloads = {
"messages": context_query,
**self.provider_config.get("model_config", {})
}
try:
llm_response = await self._query(payloads, func_tool)
except Exception as e:
if "maximum context length" in str(e):
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
self.pop_record(session_id)
logger.warning(traceback.format_exc())
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
if llm_response.role == "assistant" and session_id:
# 文本回复
if not contexts:
# 添加用户 record
self.session_memory[session_id].append(new_record)
# 添加 assistant record
self.session_memory[session_id].append({
"role": "assistant",
"content": llm_response.completion_text
})
else:
self.session_memory[session_id] = [*contexts, new_record, {
"role": "assistant",
"content": llm_response.completion_text
}]
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
async def forget(self, session_id: str) -> bool:
self.session_memory[session_id] = []
return True
def get_current_key(self) -> str:
return self.client.api_key
def get_keys(self) -> List[str]:
return self.api_keys
def set_key(self, key):
self.client.api_key = key
async def assemble_context(self, text: str, image_urls: List[str] = None):
'''
组装上下文。
'''
if image_urls:
user_content = {"role": "user","content": [{"type": "text", "text": text}]}
for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
else:
image_data = await self.encode_image_bs64(image_url)
user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}})
return user_content
else:
return {"role": "user","content": text}
async def encode_image_bs64(self, image_url: str) -> str:
'''
将图片转换为 base64
'''
if image_url.startswith("base64://"):
return image_url.replace("base64://", "data:image/jpeg;base64,")
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode('utf-8')
return "data:image/jpeg;base64," + image_bs64
return ''
@@ -3,90 +3,119 @@ import os
from llmtuner.chat import ChatModel
from typing import List
from .. import Provider
from ..entites import LLMResponse
from ..func_tool_manager import FuncCall
from astrbot.core.db import BaseDatabase
from astrbot import logger
from ..register import register_provider_adapter
@register_provider_adapter("llm_tuner", "LLMTuner 适配器, 用于装载使用 LlamaFactory 微调后的模型")
@register_provider_adapter(
"llm_tuner", "LLMTuner 适配器, 用于装载使用 LlamaFactory 微调后的模型"
)
class LLMTunerModelLoader(Provider):
def __init__(
self,
provider_config: dict,
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history = True
db_helper: BaseDatabase,
persistant_history=True,
) -> None:
super().__init__(provider_config, provider_settings, persistant_history, db_helper)
self.base_model_path = provider_config['base_model_path']
self.adapter_model_path = provider_config['adapter_model_path']
self.model = ChatModel({
"model_name_or_path": self.base_model_path,
"adapter_name_or_path": self.adapter_model_path,
"template": provider_config['llmtuner_template'],
"finetuning_type": provider_config['finetuning_type'],
"quantization_bit": provider_config['quantization_bit'],
})
self.set_model(os.path.basename(self.base_model_path) + "_" + os.path.basename(self.adapter_model_path))
super().__init__(
provider_config, provider_settings, persistant_history, db_helper
)
if not os.path.exists(provider_config["base_model_path"]) or not os.path.exists(
provider_config["adapter_model_path"]
):
raise FileNotFoundError("模型文件路径不存在。")
self.base_model_path = provider_config["base_model_path"]
self.adapter_model_path = provider_config["adapter_model_path"]
self.model = ChatModel(
{
"model_name_or_path": self.base_model_path,
"adapter_name_or_path": self.adapter_model_path,
"template": provider_config["llmtuner_template"],
"finetuning_type": provider_config["finetuning_type"],
"quantization_bit": provider_config["quantization_bit"],
}
)
self.set_model(
os.path.basename(self.base_model_path)
+ "_"
+ os.path.basename(self.adapter_model_path)
)
async def assemble_context(self, text: str, image_urls: List[str] = None):
'''
"""
组装上下文。
'''
"""
return {"role": "user", "content": text}
async def text_chat(self,
prompt: str,
session_id: str,
image_urls: List[str] = None,
tools = None,
contexts: List=None,
**kwargs) -> str:
async def text_chat(
self,
prompt: str,
session_id: str = None,
image_urls: List[str] = None,
func_tool: FuncCall = None,
contexts: List = None,
system_prompt: str = None,
**kwargs,
) -> LLMResponse:
system_prompt = ""
if not contexts:
contexts = [*self.session_memory[session_id], {"role": "user", "content": prompt}]
query_context = [
*self.session_memory[session_id],
{"role": "user", "content": prompt},
]
system_prompt = self.curr_personality["prompt"]
else:
# 提取出系统提示
system_idxs = []
for idx, context in enumerate(contexts):
if context["role"] == "system":
system_idxs.append(idx)
for idx in reversed(system_idxs):
system_prompt += " " + contexts.pop(idx)["content"]
logger.debug(f"请求上下文:{contexts}")
logger.debug(f"请求 System Prompt{system_prompt}")
query_context = [*contexts, {"role": "user", "content": prompt}]
# 提取出系统提示
system_idxs = []
for idx, context in enumerate(query_context):
if context["role"] == "system":
system_idxs.append(idx)
for idx in reversed(system_idxs):
system_prompt += " " + query_context.pop(idx)["content"]
conf = {
"messages": contexts,
"messages": query_context,
"system": system_prompt,
}
if tools:
conf['tools'] = tools
if func_tool:
conf["tools"] = func_tool
responses = await self.model.achat(**conf)
logger.debug(f"返回上下文:{responses}")
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.meta().type)
self.session_memory[session_id].append({"role": "user", "content": prompt})
self.session_memory[session_id].append({"role": "assistant", "content": responses[-1].response_text})
if session_id:
if not contexts:
self.session_memory[session_id].append(
{"role": "user", "content": prompt}
)
self.session_memory[session_id].append(
{"role": "assistant", "content": responses[-1].response_text}
)
else:
self.session_memory[session_id] = [
*contexts,
{"role": "user", "content": prompt},
{"role": "assistant", "content": responses[-1].response_text},
]
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.meta().type)
return responses[-1].response_text
async def forget(self, session_id):
logger.info("llmtuner reset")
self.session_memory[session_id] = []
return True
async def get_current_key(self):
return "none"
async def set_key(self, key):
pass
async def get_models(self):
return [self.get_model()]
async def get_human_readable_context(self, session_id, page, page_size):
if session_id not in self.session_memory:
@@ -94,9 +123,9 @@ class LLMTunerModelLoader(Provider):
contexts = []
temp_contexts = []
for record in self.session_memory[session_id]:
if record['role'] == "user":
if record["role"] == "user":
temp_contexts.append(f"User: {record['content']}")
elif record['role'] == "assistant":
elif record["role"] == "assistant":
temp_contexts.append(f"Assistant: {record['content']}")
contexts.insert(0, temp_contexts)
temp_contexts = []
@@ -105,9 +134,9 @@ class LLMTunerModelLoader(Provider):
contexts = [item for sublist in contexts for item in sublist]
# 计算分页
paged_contexts = contexts[(page-1)*page_size:page*page_size]
paged_contexts = contexts[(page - 1) * page_size : page * page_size]
total_pages = len(contexts) // page_size
if len(contexts) % page_size != 0:
total_pages += 1
return paged_contexts, total_pages
return paged_contexts, total_pages
+41 -33
View File
@@ -1,7 +1,6 @@
import traceback
import base64
import json
import datetime
from openai import AsyncOpenAI, NOT_GIVEN
from openai.types.chat.chat_completion import ChatCompletion
@@ -11,10 +10,10 @@ from astrbot.core.utils.io import download_image_by_url
from astrbot.core.db import BaseDatabase
from astrbot.api.provider import Provider
from astrbot import logger
from astrbot.core.provider.tool import FuncCall
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
from ..register import register_provider_adapter
from astrbot.core.provider.llm_response import LLMResponse
from astrbot.core.provider.entites import LLMResponse
@register_provider_adapter("openai_chat_completion", "OpenAI API Chat Completion 提供商适配器")
class ProviderOpenAIOfficial(Provider):
@@ -29,7 +28,6 @@ class ProviderOpenAIOfficial(Provider):
self.chosen_api_key = None
self.api_keys: List = provider_config.get("key", [])
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
self.enable_datetime = provider_config.get("datetime_system_prompt", True)
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
@@ -37,16 +35,22 @@ class ProviderOpenAIOfficial(Provider):
timeout=provider_config.get("timeout", NOT_GIVEN),
)
self.set_model(provider_config['model_config']['model'])
async def get_human_readable_context(self, session_id, page, page_size):
if session_id not in self.session_memory:
raise Exception("会话 ID 不存在")
contexts = []
temp_contexts = []
for record in self.session_memory[session_id]:
if record['role'] == "user":
contexts.append(f"User: {record['content']}")
temp_contexts.append(f"User: {record['content']}")
elif record['role'] == "assistant":
contexts.append(f"Assistant: {record['content']}")
temp_contexts.append(f"Assistant: {record['content']}")
contexts.insert(0, temp_contexts)
temp_contexts = []
# 展平 contexts 列表
contexts = [item for sublist in contexts for item in sublist]
# 计算分页
paged_contexts = contexts[(page-1)*page_size:page*page_size]
@@ -127,36 +131,30 @@ class ProviderOpenAIOfficial(Provider):
else:
raise Exception("Internal Error")
async def text_chat(self,
prompt: str,
session_id: str,
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts=None,
**kwargs
) -> LLMResponse:
async def text_chat(
self,
prompt: str,
session_id: str,
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts=None,
system_prompt=None,
**kwargs
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
if not contexts:
context_query = [*self.session_memory[session_id], new_record]
system_prompt = ""
if self.curr_personality["prompt"]:
system_prompt = self.curr_personality["prompt"]
if self.enable_datetime:
system_prompt += f"Current datetime: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M')}"
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
else:
context_query = contexts
logger.debug(f"请求上下文:{context_query}, {self.get_model()}")
context_query = [*contexts, new_record]
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
payloads = {
"messages": context_query,
**self.provider_config.get("model_config", {})
}
try:
llm_response = await self._query(payloads, func_tool)
except Exception as e:
@@ -164,8 +162,13 @@ class ProviderOpenAIOfficial(Provider):
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
self.pop_record(session_id)
logger.warning(traceback.format_exc())
if llm_response.role == "assistant":
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
if llm_response.role == "assistant" and session_id:
# 文本回复
if not contexts:
# 添加用户 record
@@ -175,10 +178,13 @@ class ProviderOpenAIOfficial(Provider):
"role": "assistant",
"content": llm_response.completion_text
})
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
else:
self.session_memory[session_id] = [*contexts, new_record, {
"role": "assistant",
"content": llm_response.completion_text
}]
self.db_helper.update_llm_history(session_id, json.dumps(self.session_memory[session_id]), self.provider_config['type'])
return llm_response
async def forget(self, session_id: str) -> bool:
self.session_memory[session_id] = []
return True
@@ -203,6 +209,8 @@ class ProviderOpenAIOfficial(Provider):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
else:
if image_url.startswith("file:///"):
image_url = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_url)
user_content["content"].append({"type": "image_url", "image_url": {"url": image_data}})
return user_content
@@ -0,0 +1,95 @@
import uuid
import os
import io
from openai import AsyncOpenAI, NOT_GIVEN
from ..provider import STTProvider
from ..entites import ProviderType
from astrbot.core.utils.io import download_file
from ..register import register_provider_adapter
from astrbot.core import logger
@register_provider_adapter("openai_whisper_api", "OpenAI Whisper API", provider_type=ProviderType.SPEECH_TO_TEXT)
class ProviderOpenAIWhisperAPI(STTProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.chosen_api_key = provider_config.get("api_key", "")
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base", None),
timeout=provider_config.get("timeout", NOT_GIVEN),
)
self.set_model(provider_config.get("model", None))
async def _convert_audio(self, path: str) -> str:
from pyffmpeg import FFmpeg
filename = str(uuid.uuid4()) + '.mp3'
ff = FFmpeg()
output_path = ff.convert(path, os.path.join('data/temp', filename))
return output_path
async def _pcm_to_wav(self, input_io: io.BytesIO, output_path: str) -> str:
import wave
with wave.open(output_path, 'wb') as wav:
wav.setnchannels(1)
wav.setsampwidth(2)
wav.setframerate(24000)
wav.writeframes(input_io.read())
return output_path
async def _convert_silk(self, path: str) -> str:
import pysilk
filename = str(uuid.uuid4()) + '.wav'
output_path = os.path.join('data/temp', filename)
with open(path, "rb") as f:
input_data = f.read()
if input_data.startswith(b'\x02'):
# tencent 我爱你
input_data = input_data[1:]
input_io = io.BytesIO(input_data)
output_io = io.BytesIO()
pysilk.decode(input_io, output_io, 24000)
output_io.seek(0)
await self._pcm_to_wav(output_io, output_path)
return output_path
async def _is_silk_file(self, file_path):
silk_header = b"SILK"
with open(file_path, "rb") as f:
file_header = f.read(8)
if silk_header in file_header:
return True
else:
return False
async def get_text(self, audio_url: str) -> str:
'''only supports mp3, mp4, mpeg, m4a, wav, webm'''
if audio_url.startswith("http"):
name = str(uuid.uuid4())
path = os.path.join("data/temp", name)
audio_url = await download_file(audio_url, path)
if not os.path.exists(audio_url):
raise FileNotFoundError(f"文件不存在: {audio_url}")
if audio_url.endswith(".amr") or audio_url.endswith(".silk"):
is_silk = await self._is_silk_file(audio_url)
if is_silk:
logger.info("Converting silk file to wav ...")
audio_url = await self._convert_silk(audio_url)
result = await self.client.audio.transcriptions.create(
model=self.model_name,
file=open(audio_url, "rb"),
)
return result.text
@@ -0,0 +1,92 @@
import uuid
import os
import io
import asyncio
import whisper
from ..provider import STTProvider
from ..entites import ProviderType
from astrbot.core.utils.io import download_file
from ..register import register_provider_adapter
from astrbot.core import logger
@register_provider_adapter("openai_whisper_selfhost", "OpenAI Whisper 模型部署", provider_type=ProviderType.SPEECH_TO_TEXT)
class ProviderOpenAIWhisperSelfHost(STTProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.set_model(provider_config.get("model", None))
self.model = None
async def initialize(self):
loop = asyncio.get_event_loop()
logger.info("下载或者加载 Whisper 模型中,这可能需要一些时间 ...")
self.model = await loop.run_in_executor(None, whisper.load_model, self.model_name)
logger.info("Whisper 模型加载完成。")
async def _convert_audio(self, path: str) -> str:
from pyffmpeg import FFmpeg
filename = str(uuid.uuid4()) + '.mp3'
ff = FFmpeg()
output_path = ff.convert(path, os.path.join('data/temp', filename))
return output_path
async def _pcm_to_wav(self, input_io: io.BytesIO, output_path: str) -> str:
import wave
with wave.open(output_path, 'wb') as wav:
wav.setnchannels(1)
wav.setsampwidth(2)
wav.setframerate(24000)
wav.writeframes(input_io.read())
return output_path
async def _convert_silk(self, path: str) -> str:
import pysilk
filename = str(uuid.uuid4()) + '.wav'
output_path = os.path.join('data/temp', filename)
with open(path, "rb") as f:
input_data = f.read()
if input_data.startswith(b'\x02'):
# tencent 我爱你
input_data = input_data[1:]
input_io = io.BytesIO(input_data)
output_io = io.BytesIO()
pysilk.decode(input_io, output_io, 24000)
output_io.seek(0)
await self._pcm_to_wav(output_io, output_path)
return output_path
async def _is_silk_file(self, file_path):
silk_header = b"SILK"
with open(file_path, "rb") as f:
file_header = f.read(8)
if silk_header in file_header:
return True
else:
return False
async def get_text(self, audio_url: str) -> str:
loop = asyncio.get_event_loop()
if audio_url.startswith("http"):
name = str(uuid.uuid4())
path = os.path.join("data/temp", name)
audio_url = await download_file(audio_url, path)
if not os.path.exists(audio_url):
raise FileNotFoundError(f"文件不存在: {audio_url}")
if audio_url.endswith(".amr") or audio_url.endswith(".silk"):
is_silk = await self._is_silk_file(audio_url)
if is_silk:
logger.info("Converting silk file to wav ...")
audio_url = await self._convert_silk(audio_url)
result = await loop.run_in_executor(None, self.model.transcribe, audio_url)
return result['text']
@@ -0,0 +1,73 @@
import traceback
from astrbot.core.db import BaseDatabase
from astrbot import logger
from astrbot.core.provider.func_tool_manager import FuncCall
from typing import List
from ..register import register_provider_adapter
from astrbot.core.provider.entites import LLMResponse
from .openai_source import ProviderOpenAIOfficial
@register_provider_adapter("zhipu_chat_completion", "智浦 Chat Completion 提供商适配器")
class ProviderZhipu(ProviderOpenAIOfficial):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
db_helper: BaseDatabase,
persistant_history = True
) -> None:
super().__init__(provider_config, provider_settings, db_helper, persistant_history)
async def text_chat(
self,
prompt: str,
session_id: str,
image_urls: List[str]=None,
func_tool: FuncCall=None,
contexts=None,
system_prompt=None,
**kwargs
) -> LLMResponse:
new_record = await self.assemble_context(prompt, image_urls)
context_query = []
if not contexts:
context_query = [*self.session_memory[session_id], new_record]
else:
context_query = [*contexts, new_record]
model_cfgs: dict = self.provider_config.get("model_config", {})
# glm-4v-flash 只支持一张图片
model: str = model_cfgs.get("model", "")
if model.lower() == 'glm-4v-flash' and image_urls and len(context_query) > 1:
logger.debug("glm-4v-flash 只支持一张图片,将只保留最后一张图片")
logger.debug(context_query)
new_context_query_ = []
for i in range(0, len(context_query) - 1, 2):
if isinstance(context_query[i].get("content", ""), list):
continue
new_context_query_.append(context_query[i])
new_context_query_.append(context_query[i+1])
new_context_query_.append(context_query[-1]) # 保留最后一条记录
context_query = new_context_query_
logger.debug(context_query)
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
payloads = {
"messages": context_query,
**model_cfgs
}
llm_response = None
try:
llm_response = await self._query(payloads, func_tool)
except Exception as e:
if "maximum context length" in str(e):
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
self.pop_record(session_id)
logger.warning(traceback.format_exc())
await self.save_history(contexts, new_record, session_id, llm_response)
return llm_response
@@ -0,0 +1,25 @@
from typing import List
from openai import AsyncOpenAI
class SimpleOpenAIEmbedding():
def __init__(
self,
model,
api_key,
api_base=None,
) -> None:
self.client = AsyncOpenAI(
api_key=api_key,
base_url=api_base
)
self.model = model
async def get_embedding(self, text) -> List[float]:
'''
获取文本的嵌入
'''
embedding = await self.client.embeddings.create(
input=text,
model=self.model
)
return embedding.data[0].embedding
+92
View File
@@ -0,0 +1,92 @@
import os
from typing import List, Dict
from astrbot.core import logger
from .store import Store
from astrbot.core.config import AstrBotConfig
class KnowledgeDBManager():
def __init__(self, astrbot_config: AstrBotConfig) -> None:
self.db_path = "data/knowledge_db/"
self.config = astrbot_config.get("knowledge_db", {})
self.astrbot_config = astrbot_config
if not os.path.exists(self.db_path):
os.makedirs(self.db_path)
self.store_insts: Dict[str, Store] = {}
for name, cfg in self.config.items():
if cfg["strategy"] == "embedding":
logger.info(f"加载 Chroma Vector Store{name}")
try:
from .store.chroma_db import ChromaVectorStore
except ImportError as ie:
logger.error(f"{ie} 可能未安装 chromadb 库。")
continue
self.store_insts[name] = ChromaVectorStore(name, cfg["embedding_config"])
else:
logger.error(f"不支持的策略:{cfg['strategy']}")
async def list_knowledge_db(self) -> List[str]:
return [f for f in os.listdir(self.db_path) if os.path.isfile(os.path.join(self.db_path, f))]
async def create_knowledge_db(self, name: str, config: Dict):
'''
config 格式:
```
{
"strategy": "embedding", # 目前只支持 embedding
"chunk_method": {
"strategy": "fixed",
"chunk_size": 100,
"overlap_size": 10
},
"embedding_config": {
"strategy": "openai",
"base_url": "",
"model": "",
"api_key": ""
}
}
```
'''
if name in self.config:
raise ValueError(f"知识库已存在:{name}")
self.config[name] = config
self.astrbot_config["knowledge_db"] = self.config
self.astrbot_config.save_config()
async def insert_record(self, name: str, text: str):
if name not in self.store_insts:
raise ValueError(f"未找到知识库:{name}")
ret = []
match self.config[name]["chunk_method"]['strategy']:
case "fixed":
chunk_size = self.config[name]["chunk_method"]["chunk_size"]
chunk_overlap = self.config[name]["chunk_method"]["overlap_size"]
ret = self._fixed_chunk(text, chunk_size, chunk_overlap)
case _:
pass
for chunk in ret:
await self.store_insts[name].save(chunk)
async def retrive_records(self, name: str, query: str, top_n: int = 3) -> List[str]:
if name not in self.store_insts:
raise ValueError(f"未找到知识库:{name}")
inst = self.store_insts[name]
return await inst.query(query, top_n)
def _fixed_chunk(self, text: str, chunk_size: int, chunk_overlap: int) -> List[str]:
chunks = []
start = 0
while start < len(text):
end = start + chunk_size
chunks.append(text[start:end])
start += chunk_size - chunk_overlap
return chunks
+8
View File
@@ -0,0 +1,8 @@
from typing import List
class Store():
async def save(self, text: str):
pass
async def query(self, query: str, top_n: int = 3) -> List[str]:
pass
+39
View File
@@ -0,0 +1,39 @@
import chromadb
import uuid
from typing import List, Dict
from astrbot.api import logger
from ..embedding.openai_source import SimpleOpenAIEmbedding
from . import Store
class ChromaVectorStore(Store):
def __init__(self, name: str, embedding_cfg: Dict) -> None:
self.chroma_client = chromadb.PersistentClient(path='data/long_term_memory_chroma.db')
self.collection = self.chroma_client.get_or_create_collection(name=name)
self.embedding = None
if embedding_cfg["strategy"] == "openai":
self.embedding = SimpleOpenAIEmbedding(
model=embedding_cfg["model"],
api_key=embedding_cfg["api_key"],
api_base=embedding_cfg.get("base_url", None)
)
async def save(self, text: str, metadata: Dict = None):
logger.debug(f"Saving text: {text}")
embedding = await self.embedding.get_embedding(text)
self.collection.upsert(
documents=text,
metadatas=metadata,
ids=str(uuid.uuid4()),
embeddings=embedding
)
async def query(self, query: str, top_n=3, metadata_filter: Dict = None) -> List[str]:
embedding = await self.embedding.get_embedding(query)
results = self.collection.query(
query_embeddings=embedding,
n_results=top_n,
where=metadata_filter
)
return results['documents'][0]
+46 -15
View File
@@ -1,23 +1,21 @@
from asyncio import Queue
from typing import List, TypedDict, Union
from astrbot.core.provider import Provider
from astrbot.core import sp
from astrbot.core.provider.provider import Provider
from astrbot.core.db import BaseDatabase
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.provider.tool import FuncCall
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.manager import ProviderManager
from astrbot.core.platform.manager import PlatformManager
from .star import star_registry, StarMetadata
from .star_handler import star_handlers_registry, star_handlers_map, StarHandlerMetadata
from .star_handler import star_handlers_registry, StarHandlerMetadata, EventType
from .filter.command import CommandFilter
from .filter.regex import RegexFilter
from typing import Awaitable
class StarCommand(TypedDict):
full_command_name: str
command_name: str
from astrbot.core.rag.knowledge_db_mgr import KnowledgeDBManager
class Context:
'''
@@ -38,11 +36,22 @@ class Context:
# back compatibility
_register_tasks: List[Awaitable] = []
_star_manager = None
def __init__(self, event_queue: Queue, config: AstrBotConfig, db: BaseDatabase):
def __init__(self,
event_queue: Queue,
config: AstrBotConfig,
db: BaseDatabase,
provider_manager: ProviderManager = None,
platform_manager: PlatformManager = None,
knowledge_db_manager: KnowledgeDBManager = None
):
self._event_queue = event_queue
self._config = config
self._db = db
self.provider_manager = provider_manager
self.platform_manager = platform_manager
self.knowledge_db_manager = knowledge_db_manager
def get_registered_star(self, star_name: str) -> StarMetadata:
for star in star_registry:
@@ -69,7 +78,17 @@ class Context:
异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。
'''
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj.__module__)
md = StarHandlerMetadata(
event_type=EventType.OnLLMRequestEvent,
handler_full_name=func_obj.__module__ + "_" + func_obj.__name__,
handler_name=func_obj.__name__,
handler_module_path=func_obj.__module__,
handler=func_obj,
event_filters=[],
desc=desc
)
star_handlers_registry.append(md)
self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj, func_obj)
def unregister_llm_tool(self, name: str) -> None:
'''删除一个函数调用工具。如果再要启用,需要重新注册。'''
@@ -84,6 +103,12 @@ class Context:
func_tool = self.provider_manager.llm_tools.get_func(name)
if func_tool is not None:
func_tool.active = True
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if name in inactivated_llm_tools:
inactivated_llm_tools.remove(name)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
return True
return False
@@ -95,6 +120,12 @@ class Context:
func_tool = self.provider_manager.llm_tools.get_func(name)
if func_tool is not None:
func_tool.active = False
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if name not in inactivated_llm_tools:
inactivated_llm_tools.append(name)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
return True
return False
@@ -112,9 +143,10 @@ class Context:
'''
md = StarHandlerMetadata(
event_type=EventType.AdapterMessageEvent,
handler_full_name=awaitable.__module__ + "_" + awaitable.__name__,
handler_name=awaitable.__name__,
handler_module_str=awaitable.__module__,
handler_module_path=awaitable.__module__,
handler=awaitable,
event_filters=[],
desc=desc
@@ -129,17 +161,16 @@ class Context:
handler_md=md
))
star_handlers_registry.append(md)
star_handlers_map[md.handler_full_name] = md
def register_provider(self, provider: Provider):
'''
注册一个 LLM Provider。
注册一个 LLM Provider(Chat_Completion 类型)
'''
self.provider_manager.provider_insts.append(provider)
def get_provider_by_id(self, provider_id: str) -> Provider:
'''
通过 ID 获取 LLM Provider。
通过 ID 获取 LLM Provider(Chat_Completion 类型)
'''
for provider in self.provider_manager.provider_insts:
if provider.meta().id == provider_id:
@@ -148,13 +179,13 @@ class Context:
def get_all_providers(self) -> List[Provider]:
'''
获取所有 LLM Provider。
获取所有 LLM Provider(Chat_Completion 类型)
'''
return self.provider_manager.provider_insts
def get_using_provider(self) -> Provider:
'''
获取当前使用的 LLM Provider。
获取当前使用的 LLM Provider(Chat_Completion 类型)
通过 /provider 指令切换。
'''
+3
View File
@@ -51,6 +51,9 @@ class CommandFilter(HandlerFilter, ParameterValidationMixin):
ls = re.split(r"\s+", message_str)
if self.command_name != ls[0]:
return False
# if len(self.handler_params) == 0 and len(ls) > 1:
# # 一定程度避免 LLM 聊天时误判为指令
# return False
# params_str = message_str[len(self.command_name):].strip()
ls = ls[1:]
# 去除空字符串
+2 -2
View File
@@ -6,8 +6,8 @@ from astrbot.core.config import AstrBotConfig
class PermissionType(enum.Flag):
'''权限类型。当选择 MEMBER,ADMIN 也可以通过。
'''
ADMIN = "admin"
MEMBER = "member"
ADMIN = enum.auto()
MEMBER = enum.auto()
class PermissionTypeFilter(HandlerFilter):
def __init__(self, permission_type: PermissionType, raise_error: bool = True):
+10 -2
View File
@@ -5,7 +5,11 @@ from .star_handler import (
register_event_message_type,
register_platform_adapter_type,
register_regex,
register_permission_type
register_permission_type,
register_on_llm_request,
register_llm_tool,
register_on_decorating_result,
register_after_message_sent
)
__all__ = [
@@ -15,5 +19,9 @@ __all__ = [
'register_event_message_type',
'register_platform_adapter_type',
'register_regex',
'register_permission_type'
'register_permission_type',
'register_on_llm_request',
'register_llm_tool',
'register_on_decorating_result',
'register_after_message_sent'
]
+101 -23
View File
@@ -1,6 +1,7 @@
from __future__ import annotations
import docstring_parser
from ..star_handler import star_handlers_registry, star_handlers_map, StarHandlerMetadata
from ..star_handler import star_handlers_registry, StarHandlerMetadata, EventType
from ..filter.command import CommandFilter
from ..filter.command_group import CommandGroupFilter
from ..filter.event_message_type import EventMessageTypeFilter, EventMessageType
@@ -8,28 +9,31 @@ from ..filter.platform_adapter_type import PlatformAdapterTypeFilter, PlatformAd
from ..filter.permission import PermissionTypeFilter, PermissionType
from ..filter.regex import RegexFilter
from typing import Awaitable
from astrbot.core.provider.func_tool_manager import SUPPORTED_TYPES
from astrbot.core.provider.register import llm_tools
from astrbot.core import logger
def get_handler_full_name(awatable: Awaitable) -> str:
def get_handler_full_name(awaitable: Awaitable) -> str:
'''获取 Handler 的全名'''
return f"{awatable.__module__}_{awatable.__name__}"
return f"{awaitable.__module__}_{awaitable.__name__}"
def get_handler_or_create(handler: Awaitable, dont_add = False) -> StarHandlerMetadata:
def get_handler_or_create(handler: Awaitable, event_type: EventType, dont_add = False) -> StarHandlerMetadata:
'''获取 Handler 或者创建一个新的 Handler'''
handler_full_name = get_handler_full_name(handler)
if handler_full_name in star_handlers_map:
return star_handlers_map[handler_full_name]
md = star_handlers_registry.get_handler_by_full_name(handler_full_name)
if md:
return md
else:
md = StarHandlerMetadata(
event_type=event_type,
handler_full_name=handler_full_name,
handler_name=handler.__name__,
handler_module_str=handler.__module__,
handler_module_path=handler.__module__,
handler=handler,
event_filters=[]
)
if not dont_add:
star_handlers_registry.append(md)
star_handlers_map[handler_full_name] = md
return md
def register_command(command_name: str = None, *args):
@@ -47,7 +51,7 @@ def register_command(command_name: str = None, *args):
add_to_event_filters = True
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable)
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
new_command.init_handler_md(handler_md)
if add_to_event_filters:
# 裸指令
@@ -74,7 +78,7 @@ def register_command_group(command_group_name: str = None, *args):
def decorator(obj):
if add_to_event_filters:
# 根指令组
handler_md = get_handler_or_create(obj)
handler_md = get_handler_or_create(obj, EventType.AdapterMessageEvent)
handler_md.event_filters.append(new_group)
return RegisteringCommandable(new_group)
@@ -91,28 +95,28 @@ class RegisteringCommandable():
def register_event_message_type(event_message_type: EventMessageType):
'''注册一个 EventMessageType'''
def decorator(awatable):
handler_md = get_handler_or_create(awatable)
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
handler_md.event_filters.append(EventMessageTypeFilter(event_message_type))
return awatable
return awaitable
return decorator
def register_platform_adapter_type(platform_adapter_type: PlatformAdapterType):
'''注册一个 PlatformAdapterType'''
def decorator(awatable):
handler_md = get_handler_or_create(awatable)
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
handler_md.event_filters.append(PlatformAdapterTypeFilter(platform_adapter_type))
return awatable
return awaitable
return decorator
def register_regex(regex: str):
'''注册一个 Regex'''
def decorator(awatable):
handler_md = get_handler_or_create(awatable)
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
handler_md.event_filters.append(RegexFilter(regex))
return awatable
return awaitable
return decorator
@@ -123,9 +127,83 @@ def register_permission_type(permission_type: PermissionType, raise_error: bool
permission_type: PermissionType
raise_error: 如果没有权限,是否抛出错误到消息平台,并且停止事件传播。默认为 True
'''
def decorator(awatable):
handler_md = get_handler_or_create(awatable)
def decorator(awaitable):
handler_md = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
handler_md.event_filters.append(PermissionTypeFilter(permission_type, raise_error))
return awatable
return awaitable
return decorator
def register_on_llm_request():
'''当有 LLM 请求时的事件
Examples:
```py
@on_llm_request()
async def test(self, event: AstrMessageEvent, request: ProviderRequest) -> None:
request.system_prompt += "你是一个猫娘..."
```
请务必接收两个参数:event, request
'''
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnLLMRequestEvent)
return awaitable
return decorator
def register_llm_tool(name: str = None):
'''为函数调用(function-calling / tools-use)添加工具。
请务必按照以下格式编写一个工具(包括函数注释,AstrBot 会尝试解析该函数注释)
```
@llm_tool(name="get_weather") # 如果 name 不填,将使用函数名
async def get_weather(event: AstrMessageEvent, location: str) -> MessageEventResult:
\'\'\'获取天气信息。
Args:
location(string): 地点
\'\'\'
# 处理逻辑
```
可接受的参数类型有:string, number, object, array, boolean。
'''
name_ = name
def decorator(awaitable: Awaitable):
llm_tool_name = name_ if name_ else awaitable.__name__
docstring = docstring_parser.parse(awaitable.__doc__)
args = []
for arg in docstring.params:
if arg.type_name not in SUPPORTED_TYPES:
raise ValueError(f"LLM 函数工具 {awaitable.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}")
args.append({
"type": arg.type_name,
"name": arg.arg_name,
"description": arg.description
})
md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
llm_tools.add_func(llm_tool_name, args, docstring.description, md.handler)
logger.debug(f"LLM 函数工具 {llm_tool_name} 已注册")
return awaitable
return decorator
def register_on_decorating_result():
'''在发送消息前的事件'''
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnDecoratingResultEvent)
return awaitable
return decorator
def register_after_message_sent():
'''在消息发送后的事件'''
def decorator(awaitable):
_ = get_handler_or_create(awaitable, EventType.OnAfterMessageSentEvent)
return awaitable
return decorator
+3
View File
@@ -32,6 +32,9 @@ class StarMetadata:
'''Star 的根目录名'''
reserved: bool = False
'''是否是 AstrBot 的保留 Star'''
activated: bool = True
'''是否被激活'''
def __str__(self) -> str:
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"
+52 -6
View File
@@ -1,31 +1,77 @@
from __future__ import annotations
import enum
from dataclasses import dataclass
from typing import Awaitable, List, Dict
from typing import Awaitable, List, Dict, TypeVar, Generic
from .filter import HandlerFilter
from .star import star_map
star_handlers_registry: List[StarHandlerMetadata] = []
T = TypeVar('T', bound='StarHandlerMetadata')
class StarHandlerRegistry(Generic[T], List[T]):
'''用于存储所有的 Star Handler'''
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
'''用于快速查找。key 是 handler_full_name'''
def append(self, handler: StarHandlerMetadata):
'''添加一个 Handler'''
super().append(handler)
self.star_handlers_map[handler.handler_full_name] = handler
def get_handlers_by_event_type(self, event_type: EventType, only_activated = True) -> List[StarHandlerMetadata]:
'''通过事件类型获取 Handler'''
if only_activated:
return [
handler
for handler in self
if handler.event_type == event_type and
star_map[handler.handler_module_path] and
star_map[handler.handler_module_path].activated
]
else:
return [handler for handler in self if handler.event_type == event_type]
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
'''通过 Handler 的全名获取 Handler'''
return self.star_handlers_map.get(full_name, None)
def get_handlers_by_module_name(self, module_name: str) -> List[StarHandlerMetadata]:
'''通过模块名获取 Handler'''
return [handler for handler in self if handler.handler_module_path == module_name]
star_handlers_registry = StarHandlerRegistry()
star_handlers_map: Dict[str, StarHandlerMetadata] = {}
'''用于快速查找。key 是 handler_full_name'''
class EventType(enum.Enum):
'''表示一个 AstrBot 内部事件的类型。如适配器消息事件、LLM 请求事件、发送消息前的事件等
用于对 Handler 的职能分组。
'''
AdapterMessageEvent = enum.auto() # 收到适配器发来的消息
OnLLMRequestEvent = enum.auto() # 收到 LLM 请求(可以是用户也可以是插件)
OnDecoratingResultEvent = enum.auto() # 发送消息前
OnCallingFuncToolEvent = enum.auto() # 调用函数工具
OnAfterMessageSentEvent = enum.auto() # 发送消息后
@dataclass
class StarHandlerMetadata():
'''描述一个 Star 所注册的某一个 Handler。'''
event_type: EventType
'''Handler 的事件类型'''
handler_full_name: str
'''格式为 f"{handler.__module__}_{handler.__name__}"'''
handler_name: str
'''Handler 的名字,也就是方法名'''
handler_module_str: str
handler_module_path: str
'''Handler 所在的模块路径。'''
handler: Awaitable
'''Handler 的函数对象,应当是一个异步函数'''
event_filters: List[HandlerFilter]
'''一个事件过滤器,用于描述这个 Handler 能够处理、应该处理的事件'''
'''一个适配器消息事件过滤器,用于描述这个 Handler 能够处理、应该处理的适配器消息事件'''
desc: str = ""
'''Handler 的描述信息'''
+109 -22
View File
@@ -1,5 +1,7 @@
import inspect
import functools
import os
import sys
import traceback
import yaml
import logging
@@ -7,14 +9,14 @@ from types import ModuleType
from typing import List
from pip import main as pip_main
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core import logger
from astrbot.core import logger, sp, pip_installer
from .context import Context
from . import StarMetadata
from .updator import PluginUpdator
from astrbot.core.utils.io import remove_dir
from .star import star_registry, star_map
from .star_handler import star_handlers_registry
from astrbot.core.provider.register import llm_tools
class PluginManager:
def __init__(
@@ -25,6 +27,7 @@ class PluginManager:
self.updator = PluginUpdator(config['plugin_repo_mirror'])
self.context = context
self.context._star_manager = self # 就这样吧,不想改了
self.config = config
self.plugin_store_path = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins"))
@@ -89,21 +92,12 @@ class PluginManager:
plugin_path = os.path.join(plugin_dir, p)
if os.path.exists(os.path.join(plugin_path, "requirements.txt")):
pth = os.path.join(plugin_path, "requirements.txt")
logger.info(f"正在检查插件 {p} 的依赖: {pth}")
logger.info(f"正在安装插件 {p} 所需的依赖: {pth}")
try:
self._update_plugin_dept(os.path.join(plugin_path, "requirements.txt"))
pip_installer.install(requirements_path=pth)
except Exception as e:
logger.error(f"更新插件 {p} 的依赖失败。Code: {str(e)}")
def _update_plugin_dept(self, path):
'''更新插件的依赖'''
args = ['install', '-r', path, '--trusted-host', 'mirrors.aliyun.com', '-i', 'https://mirrors.aliyun.com/pypi/simple/']
if self.config.pip_install_arg:
args.extend(self.config.pip_install_arg)
result_code = pip_main(args)
if result_code != 0:
raise Exception(str(result_code))
def _load_plugin_metadata(self, plugin_path: str, plugin_obj = None) -> StarMetadata:
'''v3.4.0 以前的方式载入插件元数据
@@ -134,15 +128,29 @@ class PluginManager:
return metadata
def reload(self):
async def reload(self):
'''扫描并加载所有的 Star'''
for smd in star_registry:
logger.debug(f"尝试终止插件 {smd.name} ...")
if hasattr(smd.star_cls, "__del__"):
smd.star_cls.__del__()
star_handlers_registry.clear()
star_handlers_registry.star_handlers_map.clear()
star_map.clear()
star_registry.clear()
for key in list(sys.modules.keys()):
if key.startswith("data.plugins") or key.startswith("packages"):
del sys.modules[key]
plugin_modules = self._get_plugin_modules()
if plugin_modules is None:
return False, "未找到任何插件模块"
fail_rec = ""
inactivated_plugins: list = sp.get("inactivated_plugins", [])
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
# 导入 Star 模块,并尝试实例化 Star 类
for plugin_module in plugin_modules:
try:
@@ -169,11 +177,25 @@ class PluginManager:
if path in star_map:
# 通过装饰器的方式注册插件
star_metadata = star_map[path]
star_metadata.star_cls = star_metadata.star_cls_type(context=self.context)
star_metadata.module = module
star_metadata.root_dir_name = root_dir_name
star_metadata.reserved = reserved
metadata = star_map[path]
metadata.star_cls = metadata.star_cls_type(context=self.context)
metadata.module = module
metadata.root_dir_name = root_dir_name
metadata.reserved = reserved
related_handlers = star_handlers_registry.get_handlers_by_module_name(metadata.module_path)
for handler in related_handlers:
logger.debug(f"bind handler {handler.handler_name} to {metadata.name}")
# handler.handler.__self__ = star_metadata.star_cls # 绑定 handler 的 self
handler.handler = functools.partial(handler.handler, metadata.star_cls)
# llm_tool
for func_tool in llm_tools.func_list:
if func_tool.handler.__module__ == metadata.module_path:
func_tool.handler_module_path = metadata.module_path
func_tool.handler = functools.partial(func_tool.handler, metadata.star_cls)
if func_tool.name in inactivated_llm_tools:
func_tool.active = False
else:
# v3.4.0 以前的方式注册插件
logger.debug(f"插件 {path} 未通过装饰器注册。尝试通过旧版本方式载入。")
@@ -196,6 +218,13 @@ class PluginManager:
star_map[path] = metadata
star_registry.append(metadata)
logger.debug(f"插件 {root_dir_name} 载入成功。")
if metadata.module_path in inactivated_plugins:
metadata.activated = False
# 执行 initialize 函数
if hasattr(metadata.star_cls, "initialize"):
await metadata.star_cls.initialize()
except BaseException as e:
traceback.print_exc()
@@ -212,10 +241,11 @@ class PluginManager:
async def install_plugin(self, repo_url: str):
plugin_path = await self.updator.install(repo_url)
self._check_plugin_dept_update()
# reload the plugin
await self.reload()
return plugin_path
def uninstall_plugin(self, plugin_name: str):
async def uninstall_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
if not plugin:
raise Exception("插件不存在。")
@@ -224,10 +254,26 @@ class PluginManager:
root_dir_name = plugin.root_dir_name
ppath = self.plugin_store_path
del star_map[plugin.module_path]
# 从 star_registry 和 star_map 中删除
await self._unbind_plugin(plugin_name, plugin.module_path)
if not remove_dir(os.path.join(ppath, root_dir_name)):
raise Exception("移除插件成功,但是删除插件文件夹失败。您可以手动删除该文件夹,位于 addons/plugins/ 下。")
async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str):
del star_map[plugin_module_path]
for i, p in enumerate(star_registry):
if p.name == plugin_name:
del star_registry[i]
break
for handler in star_handlers_registry.get_handlers_by_module_name(plugin_module_path):
logger.debug(f"unbind handler {handler.handler_name} from {plugin_name}")
star_handlers_registry.remove(handler)
keys_to_delete = [k for k, v in star_handlers_registry.star_handlers_map.items() if k.startswith(plugin_module_path)]
for k in keys_to_delete:
v = star_handlers_registry.star_handlers_map[k]
logger.debug(f"unbind handler {v.handler_name} from {plugin_name} (map)")
del star_handlers_registry.star_handlers_map[k]
async def update_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
@@ -237,6 +283,46 @@ class PluginManager:
raise Exception("该插件是 AstrBot 保留插件,无法更新。")
await self.updator.update(plugin)
await self.reload()
async def turn_off_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
if not plugin:
raise Exception("插件不存在。")
inactivated_plugins: list = sp.get("inactivated_plugins", [])
if plugin.module_path not in inactivated_plugins:
inactivated_plugins.append(plugin.module_path)
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
# 禁用插件启用的 llm_tool
for func_tool in llm_tools.func_list:
if func_tool.handler_module_path == plugin.module_path:
func_tool.active = False
inactivated_llm_tools.append(func_tool.name)
sp.put("inactivated_plugins", inactivated_plugins)
sp.put("inactivated_llm_tools", inactivated_llm_tools)
plugin.activated = False
async def turn_on_plugin(self, plugin_name: str):
plugin = self.context.get_registered_star(plugin_name)
inactivated_plugins: list = sp.get("inactivated_plugins", [])
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
if plugin.module_path in inactivated_plugins:
inactivated_plugins.remove(plugin.module_path)
sp.put("inactivated_plugins", inactivated_plugins)
# 启用插件启用的 llm_tool
for func_tool in llm_tools.func_list:
if func_tool.handler_module_path == plugin.module_path:
inactivated_llm_tools.remove(func_tool.name)
func_tool.active = True
sp.put("inactivated_llm_tools", inactivated_llm_tools)
plugin.activated = True
def install_plugin_from_file(self, zip_file_path: str):
desti_dir = os.path.join(self.plugin_store_path, os.path.basename(zip_file_path))
@@ -249,3 +335,4 @@ class PluginManager:
logger.warning(f"删除插件压缩包失败: {str(e)}")
self._check_plugin_dept_update()
+1 -2
View File
@@ -53,7 +53,6 @@ class PluginUpdator(RepoZipUpdator):
files = os.listdir(os.path.join(target_dir, update_dir))
for f in files:
logger.info(f"移动更新文件/目录: {f}")
if os.path.isdir(os.path.join(target_dir, update_dir, f)):
if os.path.exists(os.path.join(target_dir, f)):
shutil.rmtree(os.path.join(target_dir, f), onerror=on_error)
@@ -63,7 +62,7 @@ class PluginUpdator(RepoZipUpdator):
shutil.move(os.path.join(target_dir, update_dir, f), target_dir)
try:
logger.info(f"删除临时更新文件: {zip_path}{os.path.join(target_dir, update_dir)}")
logger.info(f"删除临时文件: {zip_path}{os.path.join(target_dir, update_dir)}")
shutil.rmtree(os.path.join(target_dir, update_dir), onerror=on_error)
os.remove(zip_path)
except BaseException:
+90
View File
@@ -0,0 +1,90 @@
import json
from astrbot.core import logger
from aiohttp import ClientSession
from typing import Dict, List, Any, AsyncGenerator
class DifyAPIClient:
def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"):
self.api_key = api_key
self.api_base = api_base
self.session = ClientSession()
self.headers = {
"Authorization": f"Bearer {self.api_key}",
}
async def chat_messages(
self,
inputs: Dict,
query: str,
user: str,
response_mode: str = "streaming",
conversation_id: str = "",
files: List[Dict[str, Any]] = [],
timeout: float = 60,
) -> AsyncGenerator[Dict[str, Any], None]:
url = f"{self.api_base}/chat-messages"
payload = locals()
payload.pop("self")
payload.pop("timeout")
async with self.session.post(
url, json=payload, headers=self.headers, timeout=timeout
) as resp:
while True:
data = await resp.content.read(8192) # 防止数据过大导致高水位报错
if not data:
break
if not data.strip():
continue
elif data.startswith(b"data:"):
try:
json_ = json.loads(data[5:])
yield json_
except BaseException:
pass
async def workflow_run(
self,
inputs: Dict,
user: str,
response_mode: str = "streaming",
files: List[Dict[str, Any]] = [],
timeout: float = 60,
):
url = f"{self.api_base}/workflows/run"
payload = locals()
payload.pop("self")
payload.pop("timeout")
async with self.session.post(
url, json=payload, headers=self.headers, timeout=timeout
) as resp:
while True:
data = await resp.content.read(8192) # 防止数据过大导致高水位报错
if not data:
break
if not data.strip():
continue
elif data.startswith(b"data:"):
try:
json_ = json.loads(data[5:])
yield json_
except BaseException:
pass
async def file_upload(
self,
file_path: str,
user: str,
) -> Dict[str, Any]:
url = f"{self.api_base}/files/upload"
payload = {
"user": user,
"file": open(file_path, "rb"),
}
async with self.session.post(
url, data=payload, headers=self.headers
) as resp:
return await resp.json() # {"id": "xxx", ...}
async def close(self):
await self.session.close()
+24 -4
View File
@@ -5,6 +5,7 @@ import socket
import time
import aiohttp
import base64
import zipfile
from PIL import Image
@@ -64,7 +65,7 @@ def save_temp_img(img: Image) -> str:
f.write(img)
return p
async def download_image_by_url(url: str, post: bool = False, post_data: dict = None) -> str:
async def download_image_by_url(url: str, post: bool = False, post_data: dict = None, path = None) -> str:
'''
下载图片, 返回 path
'''
@@ -72,10 +73,20 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict =
async with aiohttp.ClientSession() as session:
if post:
async with session.post(url, json=post_data) as resp:
return save_temp_img(await resp.read())
if not path:
return save_temp_img(await resp.read())
else:
with open(path, "wb") as f:
f.write(await resp.read())
return path
else:
async with session.get(url) as resp:
return save_temp_img(await resp.read())
if not path:
return save_temp_img(await resp.read())
else:
with open(path, "wb") as f:
f.write(await resp.read())
return path
except aiohttp.client_exceptions.ClientConnectorSSLError:
# 关闭SSL验证
ssl_context = ssl.create_default_context()
@@ -96,7 +107,9 @@ async def download_file(url: str, path: str):
'''
try:
async with aiohttp.ClientSession() as session:
async with session.get(url) as resp:
async with session.get(url, timeout=20) as resp:
if resp.status != 200:
raise Exception(f"下载文件失败: {resp.status}")
with open(path, 'wb') as f:
while True:
chunk = await resp.content.read(8192)
@@ -123,3 +136,10 @@ def get_local_ip_addresses():
finally:
s.close()
return ip
async def download_dashboard():
'''下载管理面板文件'''
dashboard_release_url = "https://astrbot-registry.lwl.lol/download/astrbot-dashboard/latest/dist.zip"
await download_file(dashboard_release_url, "data/dashboard.zip")
with zipfile.ZipFile("data/dashboard.zip", "r") as z:
z.extractall("data")
@@ -22,6 +22,9 @@ class ParameterValidationMixin:
result[param_name] = int(params[i])
else:
result[param_name] = params[i]
elif isinstance(param_type_or_default_val, str):
# 如果 param_type_or_default_val 是字符串,直接赋值
result[param_name] = params[i]
else:
result[param_name] = param_type_or_default_val(params[i])
except ValueError:
+33
View File
@@ -0,0 +1,33 @@
import logging
from pip import main as pip_main
class PipInstaller():
def __init__(self, pip_install_arg: str):
self.pip_install_arg = pip_install_arg
def install(self, package_name: str = None, requirements_path: str = None, mirror: str = None):
args = ['install']
if package_name:
args.append(package_name)
elif requirements_path:
args.extend(['-r', requirements_path])
if not mirror:
mirror = 'https://mirrors.aliyun.com/pypi/simple/'
args.extend(['--trusted-host', 'mirrors.aliyun.com', '-i', mirror])
if self.pip_install_arg:
args.extend(self.pip_install_arg.split())
print(f"Pip 包管理器: {' '.join(args)}")
result_code = pip_main(args)
# 清除 pip.main 导致的多余的 logging handlers
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)
if result_code != 0:
raise Exception(f"安装失败,错误码:{result_code}")
+33
View File
@@ -0,0 +1,33 @@
import json
import os
class SharedPreferences:
def __init__(self, path="data/shared_preferences.json"):
self.path = path
self._data = self._load_preferences()
def _load_preferences(self):
if os.path.exists(self.path):
with open(self.path, "r") as f:
return json.load(f)
return {}
def _save_preferences(self):
with open(self.path, "w") as f:
json.dump(self._data, f, indent=4)
def get(self, key, default=None):
return self._data.get(key, default)
def put(self, key, value):
self._data[key] = value
self._save_preferences()
def remove(self, key):
if key in self._data:
del self._data[key]
self._save_preferences()
def clear(self):
self._data.clear()
self._save_preferences()
+1 -1
View File
@@ -111,7 +111,7 @@ class RepoZipUpdator():
releases = await self.fetch_release_info(url=release_url)
if not releases:
# download from the default branch directly.
logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,从默认分支下载。")
logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,正在从默认分支下载。")
release_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
else:
release_url = releases[0]['zipball_url']
+11 -2
View File
@@ -1,4 +1,5 @@
import asyncio
import traceback
from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from .server import AstrBotDashboard
@@ -13,8 +14,16 @@ class AstrBotDashBoardLifecycle:
async def start(self):
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
await core_lifecycle.initialize()
core_task = core_lifecycle.start()
core_task = []
try:
await core_lifecycle.initialize()
core_task = core_lifecycle.start()
except Exception as e:
logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!")
logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!")
logger.critical(f"初始化 AstrBot 失败:{e} !!!!!!!")
self.dashboard_server = AstrBotDashboard(core_lifecycle, self.db)
task = asyncio.gather(core_task, self.dashboard_server.run())
+3 -1
View File
@@ -5,6 +5,7 @@ from .update import UpdateRoute
from .stat import StatRoute
from .log import LogRoute
from .static_file import StaticFileRoute
from .chat import ChatRoute
__all__ = [
@@ -14,6 +15,7 @@ __all__ = [
"UpdateRoute",
"StatRoute",
"LogRoute",
"StaticFileRoute"
"StaticFileRoute",
"ChatRoute",
]
+197
View File
@@ -0,0 +1,197 @@
import uuid
import json
import os
from .route import Route, Response, RouteContext
from astrbot.core import web_chat_queue, web_chat_back_queue
from quart import request, Response as QuartResponse, g
from astrbot.core.db import BaseDatabase
import asyncio
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
class ChatRoute(Route):
def __init__(self, context: RouteContext, db: BaseDatabase, core_lifecycle: AstrBotCoreLifecycle) -> None:
super().__init__(context)
self.routes = {
'/chat/send': ('POST', self.chat),
'/chat/new_conversation': ('GET', self.new_conversation),
'/chat/conversations': ('GET', self.get_conversations),
'/chat/get_conversation': ('GET', self.get_conversation),
'/chat/delete_conversation': ('GET', self.delete_conversation),
'/chat/get_file': ('GET', self.get_file),
'/chat/post_image': ('POST', self.post_image),
'/chat/post_file': ('POST', self.post_file),
'/chat/status': ('GET', self.status),
}
self.db = db
self.core_lifecycle = core_lifecycle
self.register_routes()
self.imgs_dir = "data/webchat/imgs"
self.supported_imgs = ['jpg', 'jpeg', 'png', 'gif', 'webp']
async def status(self):
has_llm_enabled = self.core_lifecycle.provider_manager.curr_provider_inst is not None
has_stt_enabled = self.core_lifecycle.provider_manager.curr_stt_provider_inst is not None
return Response().ok(data={
'llm_enabled': has_llm_enabled,
'stt_enabled': has_stt_enabled
}).__dict__
async def get_file(self):
filename = request.args.get('filename')
if not filename:
return Response().error("Missing key: filename").__dict__
try:
with open(os.path.join(self.imgs_dir, filename), "rb") as f:
if filename.endswith(".wav"):
return QuartResponse(f.read(), mimetype="audio/wav")
elif filename.split('.')[-1] in self.supported_imgs:
return QuartResponse(f.read(), mimetype="image/jpeg")
else:
return QuartResponse(f.read())
except FileNotFoundError:
return Response().error("File not found").__dict__
async def post_image(self):
post_data = await request.files
if 'file' not in post_data:
return Response().error("Missing key: file").__dict__
file = post_data['file']
filename = str(uuid.uuid4()) + ".jpg"
path = os.path.join(self.imgs_dir, filename)
await file.save(path)
return Response().ok(data={
'filename': filename
}).__dict__
async def post_file(self):
post_data = await request.files
if 'file' not in post_data:
return Response().error("Missing key: file").__dict__
file = post_data['file']
filename = f"{str(uuid.uuid4())}"
print(file)
# 通过文件格式判断文件类型
if file.content_type.startswith('audio'):
filename += ".wav"
path = os.path.join(self.imgs_dir, filename)
await file.save(path)
return Response().ok(data={
'filename': filename
}).__dict__
async def chat(self):
username = g.get('username', 'guest')
post_data = await request.json
if 'message' not in post_data and 'image_url' not in post_data:
return Response().error("Missing key: message or image_url").__dict__
if 'conversation_id' not in post_data:
return Response().error("Missing key: conversation_id").__dict__
message = post_data['message']
conversation_id = post_data['conversation_id']
image_url = post_data.get('image_url')
audio_url = post_data.get('audio_url')
if not message and not image_url and not audio_url:
return Response().error("Message and image_url and audio_url are empty").__dict__
if not conversation_id:
return Response().error("conversation_id is empty").__dict__
await web_chat_queue.put((username, conversation_id, {
'message': message,
'image_url': image_url, # list
'audio_url': audio_url
}))
async def stream():
ret = []
while True:
try:
result = await asyncio.wait_for(web_chat_back_queue.get(), timeout=30) # 设置超时时间为5秒
except asyncio.TimeoutError:
yield '[Error] 30 秒内没有返回数据,已放弃。\n'
return
if result is None:
break
ret.append(result)
yield result + '\n'
await asyncio.sleep(0.5)
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
try:
history = json.loads(conversation.history)
except BaseException as e:
print(e)
history = []
new_his = {
'type': 'user',
'message': message
}
if image_url:
new_his['image_url'] = image_url
if audio_url:
new_his['audio_url'] = audio_url
history.append(new_his)
for r in ret:
history.append({
'type': 'bot',
'message': r
})
self.db.update_webchat_conversation(username, conversation_id, history=json.dumps(history))
return QuartResponse(
stream(),
mimetype="text/event-stream",
headers={
"Content-Type": "text/event-stream",
"Transfer-Encoding": "chunked",
"Connection": "keep-alive",
"Access-Control-Allow-Origin": "*" # 如果是跨域请求
}
)
async def delete_conversation(self):
username = g.get('username', 'guest')
conversation_id = request.args.get('conversation_id')
if not conversation_id:
return Response().error("Missing key: conversation_id").__dict__
self.db.delete_webchat_conversation(username, conversation_id)
return Response().ok().__dict__
async def new_conversation(self):
username = g.get('username', 'guest')
conversation_id = str(uuid.uuid4())
self.db.webchat_new_conversation(username, conversation_id)
return Response().ok(data={
'conversation_id': conversation_id
}).__dict__
async def get_conversations(self):
username = g.get('username', 'guest')
conversations = self.db.get_webchat_conversations(username)
return Response().ok(data=conversations).__dict__
async def get_conversation(self):
username = g.get('username', 'guest')
conversation_id = request.args.get('conversation_id')
if not conversation_id:
return Response().error("Missing key: conversation_id").__dict__
conversation = self.db.get_webchat_conversation_by_user_id(username, conversation_id)
return Response().ok(data=conversation).__dict__
+7
View File
@@ -7,6 +7,7 @@ from astrbot.core.config.default import CONFIG_METADATA_2, DEFAULT_VALUE_MAP
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.star.config import update_config
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.register import platform_registry
def try_cast(value: str, type_: str):
if type_ == "int" and value.isdigit():
@@ -121,6 +122,12 @@ class ConfigRoute(Route):
async def _get_astrbot_config(self):
config = self.config
platform_default_tmpl = CONFIG_METADATA_2['platform_group']['metadata']['platform']['config_template']
for platform in platform_registry:
if platform.default_config_tmpl:
platform_default_tmpl[platform.name] = platform.default_config_tmpl
return {
"metadata": CONFIG_METADATA_2,
"config": config
+30 -8
View File
@@ -16,7 +16,9 @@ class PluginRoute(Route):
'/plugin/install-upload': ('POST', self.install_plugin_upload),
'/plugin/update': ('POST', self.update_plugin),
'/plugin/uninstall': ('POST', self.uninstall_plugin),
'/plugin/market_list': ('GET', self.get_online_plugins)
'/plugin/market_list': ('GET', self.get_online_plugins),
'/plugin/off': ('POST', self.off_plugin),
'/plugin/on': ('POST', self.on_plugin)
}
self.core_lifecycle = core_lifecycle
self.plugin_manager = plugin_manager
@@ -42,7 +44,8 @@ class PluginRoute(Route):
"author": plugin.author,
"desc": plugin.desc,
"version": plugin.version,
"reserved": plugin.reserved
"reserved": plugin.reserved,
"activated": plugin.activated
}
_plugin_resp.append(_t)
return Response().ok(_plugin_resp).__dict__
@@ -53,7 +56,6 @@ class PluginRoute(Route):
try:
logger.info(f"正在安装插件 {repo_url}")
await self.plugin_manager.install_plugin(repo_url)
self.core_lifecycle.restart()
logger.info(f"安装插件 {repo_url} 成功。")
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
@@ -69,7 +71,6 @@ class PluginRoute(Route):
await file.save(file_path)
self.plugin_manager.install_plugin_from_file(file_path)
logger.info(f"安装插件 {file.filename} 成功")
self.core_lifecycle.restart()
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
logger.error(traceback.format_exc())
@@ -80,7 +81,7 @@ class PluginRoute(Route):
plugin_name = post_data["name"]
try:
logger.info(f"正在卸载插件 {plugin_name}")
self.plugin_manager.uninstall_plugin(plugin_name)
await self.plugin_manager.uninstall_plugin(plugin_name)
logger.info(f"卸载插件 {plugin_name} 成功")
return Response().ok(None, "卸载成功").__dict__
except Exception as e:
@@ -93,9 +94,30 @@ class PluginRoute(Route):
try:
logger.info(f"正在更新插件 {plugin_name}")
await self.plugin_manager.update_plugin(plugin_name)
self.core_lifecycle.restart()
logger.info(f"更新插件 {plugin_name} 成功,2秒后重启")
return Response().ok(None, "更新成功,程序将在 2 秒内重启。").__dict__
logger.info(f"更新插件 {plugin_name} 成功。")
return Response().ok(None, "更新成功。").__dict__
except Exception as e:
logger.error(f"/api/extensions/update: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
async def off_plugin(self):
post_data = await request.json
plugin_name = post_data["name"]
try:
await self.plugin_manager.turn_off_plugin(plugin_name)
logger.info(f"停用插件 {plugin_name}")
return Response().ok(None, "停用成功。").__dict__
except Exception as e:
logger.error(f"/api/extensions/off: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
async def on_plugin(self):
post_data = await request.json
plugin_name = post_data["name"]
try:
await self.plugin_manager.turn_on_plugin(plugin_name)
logger.info(f"启用插件 {plugin_name}")
return Response().ok(None, "启用成功。").__dict__
except Exception as e:
logger.error(f"/api/extensions/on: {traceback.format_exc()}")
return Response().error(str(e)).__dict__
+20 -3
View File
@@ -3,7 +3,7 @@ import traceback
from .route import Route, Response, RouteContext
from quart import request
from astrbot.core.updator import AstrBotUpdator
from astrbot.core import logger
from astrbot.core import logger, pip_installer
class UpdateRoute(Route):
def __init__(self, context: RouteContext, astrbot_updator: AstrBotUpdator) -> None:
@@ -11,6 +11,7 @@ class UpdateRoute(Route):
self.routes = {
'/update/check': ('GET', self.check_update),
'/update/do': ('POST', self.update_project),
'/update/pip-install': ('POST', self.install_pip_package)
}
self.astrbot_updator = astrbot_updator
self.register_routes()
@@ -32,6 +33,7 @@ class UpdateRoute(Route):
async def update_project(self):
data = await request.json
version = data.get('version', '')
reboot = data.get('reboot', True)
if version == "" or version == "latest":
latest = True
version = ''
@@ -39,8 +41,23 @@ class UpdateRoute(Route):
latest = False
try:
await self.astrbot_updator.update(latest=latest, version=version)
threading.Thread(target=self.astrbot_updator._reboot, args=(2, self.context)).start()
return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__
if reboot:
threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__
else:
return Response().ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。").__dict__
except Exception as e:
logger.error(f"/api/update_project: {traceback.format_exc()}")
return Response().error(e.__str__()).__dict__
async def install_pip_package(self):
data = await request.json
package = data.get('package', '')
if not package:
return Response().error("缺少参数 package 或不合法。").__dict__
try:
pip_installer.install(package)
return Response().ok(None, "安装成功。").__dict__
except Exception as e:
logger.error(f"/api/update_pip: {traceback.format_exc()}")
return Response().error(e.__str__()).__dict__
+6 -3
View File
@@ -2,7 +2,7 @@ import logging
import jwt
import asyncio
import os
from quart import Quart, request, jsonify
from quart import Quart, request, jsonify, g
from quart.logging import default_handler
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from .routes import *
@@ -18,7 +18,6 @@ class AstrBotDashboard():
self.core_lifecycle = core_lifecycle
self.config = core_lifecycle.astrbot_config
self.data_path = os.path.abspath(os.path.join(DATAPATH, "dist"))
logger.info(f"Dashboard data path: {self.data_path}")
self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/")
self.app.json.sort_keys = False
self.app.before_request(self.auth_middleware)
@@ -32,12 +31,15 @@ class AstrBotDashboard():
self.lr = LogRoute(self.context, core_lifecycle.log_broker)
self.sfr = StaticFileRoute(self.context)
self.ar = AuthRoute(self.context)
self.chat_route = ChatRoute(self.context, db, core_lifecycle)
async def auth_middleware(self):
if not request.path.startswith("/api"):
return
if request.path == "/api/auth/login":
return
if request.path == "/api/chat/get_file":
return
# claim jwt
token = request.headers.get("Authorization")
if not token:
@@ -47,7 +49,8 @@ class AstrBotDashboard():
if token.startswith("Bearer "):
token = token[7:]
try:
jwt.decode(token, WEBUI_SK, algorithms=["HS256"])
payload = jwt.decode(token, WEBUI_SK, algorithms=["HS256"])
g.username = payload["username"]
except jwt.ExpiredSignatureError:
r = jsonify(Response().error("Token 过期").__dict__)
r.status_code = 401
+14
View File
@@ -0,0 +1,14 @@
# What's Changed
1. 修复了 reminder 插件可能不会触发回调的问题。
2. 修复了 telegram 插件不可用的问题。
3. 修复了 qq_official 无法发图的问题。
4. 修复事件监听器会让 WakeStage 失效的问题。
5. 修复 websearch 在 cmd_config 中失效的问题。
3. 支持通过 Google GenAI 访问 Gemini 模型,而不需要使用 Gemini 对 OpenAI 的兼容 API。详见文档。
4. 支持对插件禁用/启用。/plugin off/on <plugin_name>
5. 支持基于 Docker 的沙箱化代码执行器。(Beta 测试)详见文档。
6. 支持接入 Dify LLMOps 平台。详见文档。
7. 适配器类插件支持设置默认配置模板。
8. 优化了部分指令的持久化记忆。如 /tool 的禁用、/provider 的选择都将持久化保存,每次启动时不需要重新设置。
9. 优化了 glm-4v-flash 模型。其只支持一张图。
+6
View File
@@ -0,0 +1,6 @@
# What's Changed
1. 支持通过 /set <k> <v> 设置持久化的会话变量, 方便 Dify App 输入变量
2. 管理面板支持 Web Chat
3. 管理面板支持手动安装 Pip 库, 在 `控制台` 页中可找到
+9
View File
@@ -0,0 +1,9 @@
# What's Changed
- 支持接入 STT(语音转文字)Provider
- 内置支持 OpenAI Whisper API/本地运行模型。[看这里](https://astrbot.lwl.lol/use/whisper.html)
- WebChat 支持语音输入
- WebChat 支持显示当前 Provider 状态
- 优化了 WebChat 在没有消息返回时的处理方式
- 修复了 reminder 在初始化历史待办时没有正常传入 session_id 的问题
- 代码执行器在成功回复后清空文件 buffer。
+2
View File
@@ -0,0 +1,2 @@
node_modules/
.DS_Store
+21
View File
@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 CodedThemes
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
+3
View File
@@ -0,0 +1,3 @@
# AstrBot 管理面板
基于 CodedThemes/Berry 模板开发。
+1
View File
@@ -0,0 +1 @@
/// <reference types="vite/client" />
+19
View File
@@ -0,0 +1,19 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<link rel="icon" href="/favicon.svg" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<meta name="keywords" content="AstrBot Soulter" />
<meta name="description" content="AstrBot Dashboard" />
<link
rel="stylesheet"
href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&family=Poppins:wght@400;500;600;700&family=Roboto:wght@400;500;700&display=swap"
/>
<title>AstrBot - 仪表盘</title>
</head>
<body>
<div id="app"></div>
<script type="module" src="/src/main.ts"></script>
</body>
</html>
+9998
View File
File diff suppressed because it is too large Load Diff
+58
View File
@@ -0,0 +1,58 @@
{
"name": "astrbot-dashboard",
"version": "1.0.0",
"private": true,
"author": "CodedThemes",
"scripts": {
"dev": "vite --host",
"build": "vue-tsc --noEmit && vite build",
"build-stage": "vue-tsc --noEmit && vite build --base=/vue/free/stage/",
"build-prod": "vue-tsc --noEmit && vite build --base=/vue/free/",
"preview": "vite preview --port 5050",
"typecheck": "vue-tsc --noEmit",
"lint": "eslint . --ext .vue,.js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --ignore-path .gitignore"
},
"dependencies": {
"@guolao/vue-monaco-editor": "^1.5.4",
"@tiptap/starter-kit": "2.1.7",
"@tiptap/vue-3": "2.1.7",
"apexcharts": "3.42.0",
"axios": "^1.6.2",
"axios-mock-adapter": "^1.22.0",
"chance": "1.1.11",
"date-fns": "2.30.0",
"js-md5": "^0.8.3",
"lodash": "4.17.21",
"marked": "^15.0.6",
"pinia": "2.1.6",
"remixicon": "3.5.0",
"vee-validate": "4.11.3",
"vite-plugin-vuetify": "1.0.2",
"vue": "3.3.4",
"vue-router": "4.2.4",
"vue3-apexcharts": "1.4.4",
"vue3-print-nb": "0.1.4",
"vuetify": "3.3.14",
"yup": "1.2.0"
},
"devDependencies": {
"@mdi/font": "7.2.96",
"@rushstack/eslint-patch": "1.3.3",
"@types/chance": "1.1.3",
"@types/node": "20.5.7",
"@vitejs/plugin-vue": "4.3.3",
"@vue/eslint-config-prettier": "8.0.0",
"@vue/eslint-config-typescript": "11.0.3",
"@vue/tsconfig": "0.4.0",
"eslint": "8.48.0",
"eslint-plugin-vue": "9.17.0",
"prettier": "3.0.2",
"sass": "1.66.1",
"sass-loader": "13.3.2",
"typescript": "5.1.6",
"vite": "4.4.9",
"vue-cli-plugin-vuetify": "2.5.8",
"vue-tsc": "1.8.8",
"vuetify-loader": "^2.0.0-alpha.9"
}
}
+1
View File
@@ -0,0 +1 @@
/* /index.html 200
+1
View File
@@ -0,0 +1 @@
<svg t="1702013028016" class="icon" viewBox="0 0 1024 1024" version="1.1" xmlns="http://www.w3.org/2000/svg" p-id="1541" width="200" height="200"><path d="M0 0m204.8 0l614.4 0q204.8 0 204.8 204.8l0 614.4q0 204.8-204.8 204.8l-614.4 0q-204.8 0-204.8-204.8l0-614.4q0-204.8 204.8-204.8Z" fill="#FFEC9C" p-id="1542"></path><path d="M819.2 0H534.272A756.48 756.48 0 0 0 0 483.584V819.2a204.8 204.8 0 0 0 204.8 204.8h614.4a204.8 204.8 0 0 0 204.8-204.8V204.8a204.8 204.8 0 0 0-204.8-204.8z" fill="#FFE98A" p-id="1543"></path><path d="M819.2 0h-3.84a755.2 755.2 0 0 0-539.392 1024H819.2a204.8 204.8 0 0 0 204.8-204.8V204.8a204.8 204.8 0 0 0-204.8-204.8z" fill="#FFE471" p-id="1544"></path><path d="M497.152 721.152A752.384 752.384 0 0 0 560.384 1024H819.2a204.8 204.8 0 0 0 204.8-204.8V204.8a204.8 204.8 0 0 0-89.088-168.96 755.2 755.2 0 0 0-437.76 685.312z" fill="#FFE161" p-id="1545"></path><path d="M526.08 140.032l98.304 199.168L844.8 371.2a15.616 15.616 0 0 1 8.704 25.6l-159.744 156.16 37.632 219.136a15.616 15.616 0 0 1-22.528 16.384l-196.608-102.4-196.608 102.4a15.616 15.616 0 0 1-22.528-16.384l37.12-219.136-159.232-155.136a15.616 15.616 0 0 1 8.704-25.6l219.904-32 98.304-199.168a15.616 15.616 0 0 1 28.16-1.024z" fill="#FFF5CC" p-id="1546"></path><path d="M665.6 409.6a444.16 444.16 0 0 0 25.6-61.44l-65.536-9.472-99.584-198.656a15.616 15.616 0 0 0-27.904 0l-98.304 199.168L179.2 371.2a15.616 15.616 0 0 0-8.704 25.6l159.744 156.16-15.104 87.04A407.808 407.808 0 0 0 665.6 409.6z" fill="#FFFFFF" p-id="1547"></path></svg>

After

Width:  |  Height:  |  Size: 1.5 KiB

+7
View File
@@ -0,0 +1,7 @@
<template>
<RouterView></RouterView>
</template>
<script setup lang="ts">
import { RouterView } from 'vue-router';
</script>
@@ -0,0 +1,6 @@
<svg width="22" height="22" viewBox="0 0 22 22" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M5.06129 13.2253L4.31871 15.9975L1.60458 16.0549C0.793457 14.5504 0.333374 12.8292 0.333374 11C0.333374 9.23119 0.763541 7.56319 1.52604 6.09448H1.52662L3.94296 6.53748L5.00146 8.93932C4.77992 9.58519 4.65917 10.2785 4.65917 11C4.65925 11.783 4.80108 12.5332 5.06129 13.2253Z" fill="#FBBB00"/>
<path d="M21.4804 9.00732C21.6029 9.65257 21.6668 10.3189 21.6668 11C21.6668 11.7637 21.5865 12.5086 21.4335 13.2271C20.9143 15.6722 19.5575 17.8073 17.678 19.3182L17.6774 19.3177L14.6339 19.1624L14.2031 16.4734C15.4503 15.742 16.425 14.5974 16.9384 13.2271H11.2346V9.00732H17.0216H21.4804Z" fill="#518EF8"/>
<path d="M17.6772 19.3176L17.6777 19.3182C15.8498 20.7875 13.5277 21.6666 11 21.6666C6.93783 21.6666 3.40612 19.3962 1.60449 16.0549L5.0612 13.2253C5.96199 15.6294 8.28112 17.3408 11 17.3408C12.1686 17.3408 13.2634 17.0249 14.2029 16.4734L17.6772 19.3176Z" fill="#28B446"/>
<path d="M17.8085 2.78892L14.353 5.61792C13.3807 5.01017 12.2313 4.65908 11 4.65908C8.21963 4.65908 5.85713 6.44896 5.00146 8.93925L1.52658 6.09442H1.526C3.30125 2.67171 6.8775 0.333252 11 0.333252C13.5881 0.333252 15.9612 1.25517 17.8085 2.78892Z" fill="#F14336"/>
</svg>

After

Width:  |  Height:  |  Size: 1.2 KiB

+18
View File
@@ -0,0 +1,18 @@
<svg width="46" height="55" viewBox="0 0 46 55" fill="none" xmlns="http://www.w3.org/2000/svg">
<g clipPath="url(#clip0)">
<path d="M19.6667 55C8.82205 55 0 46.2504 0 35.4968C0 24.7431 8.82292 15.9935 19.6667 15.9935C30.5105 15.9935 39.3334 24.7431 39.3334 35.4968C39.3334 46.2504 30.5122 55 19.6667 55ZM19.6667 17.8563C9.8587 17.8563 1.87839 25.7686 1.87839 35.4959C1.87839 45.2233 9.8587 53.1355 19.6667 53.1355C29.4747 53.1355 37.4559 45.2215 37.4559 35.4942C37.4559 25.7668 29.4765 17.8563 19.6667 17.8563Z" fill="#2196F3"/>
<path d="M33.9387 36.3618C33.3269 34.1133 27.7188 33.8706 24.3807 34.6949C22.6326 35.1283 20.846 35.6917 19.0034 36.0159C20.3521 37.2026 21.8005 38.3251 23.879 38.6042C29.0361 39.2942 32.2404 37.6898 33.9387 36.3618Z" fill="#2196F3"/>
<path d="M23.8788 38.6042C21.7959 38.3251 20.3519 37.2026 19.0032 36.016C16.9159 34.1792 15.0594 32.189 11.4154 32.9379C5.62198 34.1289 4.85978 40.9247 9.3333 45.2917C11.254 47.2864 13.7197 48.6822 16.4284 49.3079C19.137 49.9336 21.9709 49.7621 24.5828 48.8144C27.1946 47.8667 29.4709 46.1839 31.1327 43.9724C32.7945 41.7608 33.7696 39.1165 33.9385 36.3635C32.2402 37.6898 29.0358 39.2942 23.8788 38.6042Z" fill="#673AB7"/>
<path d="M26.9105 23.8962C26.1876 25.4331 32.6321 27.1381 33.4031 32.2419C33.7746 27.2178 27.8046 21.9962 26.9105 23.8962Z" fill="#2196F3"/>
<path d="M13.3649 30.3107C14.5267 29.8335 15.0784 28.5126 14.5972 27.3604C14.116 26.2083 12.784 25.6611 11.6222 26.1384C10.4604 26.6156 9.90867 27.9365 10.3899 29.0887C10.8712 30.2408 12.2031 30.7879 13.3649 30.3107Z" fill="#673AB7"/>
<path d="M18.5351 24.1103C19.0786 23.5714 19.0786 22.6977 18.5351 22.1587C17.9917 21.6198 17.1106 21.6198 16.5672 22.1587C16.0238 22.6977 16.0238 23.5714 16.5672 24.1103C17.1106 24.6492 17.9917 24.6492 18.5351 24.1103Z" fill="#2196F3"/>
<path d="M23.4513 15.2376C25.4617 9.3485 24.1103 4.64345 19.9786 2.40881C17.1544 2.97831 15.4779 4.334 14.5444 6.20544C20.0843 5.76077 23.5999 9.1994 23.4513 15.2376Z" fill="#2196F3"/>
<path d="M46.0001 10.0923C36.0487 6.55051 29.7685 7.76491 28.7808 15.8349C34.4841 21.6703 40.2286 18.8774 46.0001 10.0923Z" fill="#2196F3"/>
<path d="M38.0851 6.89635C38.5466 4.94082 38.7861 2.6299 38.8219 0C28.5017 2.27885 23.8473 6.6337 27.3584 13.9782C27.5333 14.0198 27.7011 14.0536 27.8698 14.0883C28.6905 8.34132 32.3031 6.2133 38.0851 6.89635Z" fill="#2196F3"/>
</g>
<defs>
<clipPath id="clip0">
<rect width="46" height="55" fill="white"/>
</clipPath>
</defs>
</svg>

After

Width:  |  Height:  |  Size: 2.4 KiB

@@ -0,0 +1,5 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M19 9H9C7.89543 9 7 9.89543 7 11V17C7 18.1046 7.89543 19 9 19H19C20.1046 19 21 18.1046 21 17V11C21 9.89543 20.1046 9 19 9Z" stroke="white" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
<path d="M14 16C15.1046 16 16 15.1046 16 14C16 12.8954 15.1046 12 14 12C12.8954 12 12 12.8954 12 14C12 15.1046 12.8954 16 14 16Z" fill="#90CAF9"/>
<path d="M17 9V7C17 6.46957 16.7893 5.96086 16.4142 5.58579C16.0391 5.21071 15.5304 5 15 5H5C4.46957 5 3.96086 5.21071 3.58579 5.58579C3.21071 5.96086 3 6.46957 3 7V13C3 13.5304 3.21071 14.0391 3.58579 14.4142C3.96086 14.7893 4.46957 15 5 15H7" stroke="white" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

After

Width:  |  Height:  |  Size: 794 B

Some files were not shown because too many files have changed in this diff Show More