Compare commits
30 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4fd26814cb | |||
| 5f531c9be5 | |||
| 94591d965b | |||
| 8a0f865af1 | |||
| 4aced976a8 | |||
| 0299aa6e4c | |||
| fd05b0bf09 | |||
| 58e32b7b70 | |||
| 80b89fd2ea | |||
| 26f863ba81 | |||
| f78a90218e | |||
| a3ecebd2aa | |||
| aaee283367 | |||
| 4a5b7d1976 | |||
| 08244548ab | |||
| b486de6a98 | |||
| e2f928a7e5 | |||
| b8e4068c75 | |||
| 0916177a57 | |||
| 02cd5e396b | |||
| 56673ad78f | |||
| 9a4d05e2b6 | |||
| c3f45449e8 | |||
| 65da469deb | |||
| 16df64c405 | |||
| 6b73b19e54 | |||
| e7e97730af | |||
| 467ca1eb5c | |||
| 46528391c2 | |||
| 1c090299b1 |
@@ -0,0 +1,79 @@
|
|||||||
|
name: Build Desktop App
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- 'v*'
|
||||||
|
workflow_dispatch:
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
build:
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
platform: [macos-latest, ubuntu-latest, windows-latest]
|
||||||
|
|
||||||
|
runs-on: ${{ matrix.platform }}
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Setup Python
|
||||||
|
uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: '3.10'
|
||||||
|
|
||||||
|
- name: Setup Node.js
|
||||||
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: 20
|
||||||
|
|
||||||
|
- name: Install Rust
|
||||||
|
uses: dtolnay/rust-toolchain@stable
|
||||||
|
|
||||||
|
- name: Install dependencies (Ubuntu)
|
||||||
|
if: matrix.platform == 'ubuntu-latest'
|
||||||
|
run: |
|
||||||
|
sudo apt-get update
|
||||||
|
sudo apt-get install -y libgtk-3-dev libwebkit2gtk-4.0-dev libappindicator3-dev librsvg2-dev patchelf
|
||||||
|
|
||||||
|
- name: Install Python dependencies
|
||||||
|
run: |
|
||||||
|
pip install uv
|
||||||
|
uv sync
|
||||||
|
|
||||||
|
- name: Build Python backend with Nuitka
|
||||||
|
run: |
|
||||||
|
pip install nuitka
|
||||||
|
python build_nuitka.py
|
||||||
|
|
||||||
|
- name: Install Node dependencies
|
||||||
|
working-directory: ./dashboard
|
||||||
|
run: npm install
|
||||||
|
|
||||||
|
- name: Build Tauri app
|
||||||
|
working-directory: ./dashboard
|
||||||
|
run: npm run tauri:build
|
||||||
|
|
||||||
|
- name: Upload artifacts (macOS)
|
||||||
|
if: matrix.platform == 'macos-latest'
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: astrbot-macos
|
||||||
|
path: dashboard/src-tauri/target/release/bundle/dmg/*.dmg
|
||||||
|
|
||||||
|
- name: Upload artifacts (Windows)
|
||||||
|
if: matrix.platform == 'windows-latest'
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: astrbot-windows
|
||||||
|
path: dashboard/src-tauri/target/release/bundle/msi/*.msi
|
||||||
|
|
||||||
|
- name: Upload artifacts (Linux)
|
||||||
|
if: matrix.platform == 'ubuntu-latest'
|
||||||
|
uses: actions/upload-artifact@v4
|
||||||
|
with:
|
||||||
|
name: astrbot-linux
|
||||||
|
path: |
|
||||||
|
dashboard/src-tauri/target/release/bundle/deb/*.deb
|
||||||
|
dashboard/src-tauri/target/release/bundle/appimage/*.AppImage
|
||||||
@@ -36,7 +36,7 @@ jobs:
|
|||||||
zip -r dist.zip dist
|
zip -r dist.zip dist
|
||||||
|
|
||||||
- name: Archive production artifacts
|
- name: Archive production artifacts
|
||||||
uses: actions/upload-artifact@v5
|
uses: actions/upload-artifact@v6
|
||||||
with:
|
with:
|
||||||
name: dist-without-markdown
|
name: dist-without-markdown
|
||||||
path: |
|
path: |
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ tests/astrbot_plugin_openai
|
|||||||
# Dashboard
|
# Dashboard
|
||||||
dashboard/node_modules/
|
dashboard/node_modules/
|
||||||
dashboard/dist/
|
dashboard/dist/
|
||||||
|
dashboard/src-tauri/target
|
||||||
package-lock.json
|
package-lock.json
|
||||||
package.json
|
package.json
|
||||||
yarn.lock
|
yarn.lock
|
||||||
@@ -48,5 +49,6 @@ astrbot.lock
|
|||||||
chroma
|
chroma
|
||||||
venv/*
|
venv/*
|
||||||
pytest.ini
|
pytest.ini
|
||||||
|
build/
|
||||||
AGENTS.md
|
AGENTS.md
|
||||||
IFLOW.md
|
IFLOW.md
|
||||||
|
|||||||
@@ -0,0 +1,287 @@
|
|||||||
|
# AstrBot 桌面应用构建指南
|
||||||
|
|
||||||
|
本指南介绍如何使用 Nuitka 将 Python 后端打包并集成到 Tauri 桌面应用中。
|
||||||
|
|
||||||
|
## 前置要求
|
||||||
|
|
||||||
|
### 系统要求
|
||||||
|
- Python 3.10+
|
||||||
|
- Node.js 20+
|
||||||
|
- Rust (通过 rustup 安装)
|
||||||
|
- UV 包管理器
|
||||||
|
|
||||||
|
### macOS 额外要求
|
||||||
|
- Xcode Command Line Tools: `xcode-select --install`
|
||||||
|
|
||||||
|
### Linux 额外要求
|
||||||
|
```bash
|
||||||
|
sudo apt-get install -y libgtk-3-dev libwebkit2gtk-4.0-dev \
|
||||||
|
libappindicator3-dev librsvg2-dev patchelf
|
||||||
|
```
|
||||||
|
|
||||||
|
### Windows 额外要求
|
||||||
|
- Visual Studio 2019+ with C++ build tools
|
||||||
|
- Windows 10 SDK
|
||||||
|
|
||||||
|
## 构建步骤
|
||||||
|
|
||||||
|
### 1. 安装 Python 依赖
|
||||||
|
```bash
|
||||||
|
pip install uv
|
||||||
|
uv sync
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 安装 Nuitka
|
||||||
|
```bash
|
||||||
|
pip install nuitka
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 构建 Python 后端
|
||||||
|
```bash
|
||||||
|
python build_nuitka.py
|
||||||
|
```
|
||||||
|
|
||||||
|
这会使用 Nuitka 将 `main.py` 编译为独立可执行文件,输出到 `build/nuitka/` 目录。
|
||||||
|
|
||||||
|
**注意**: Nuitka 编译过程可能需要 10-30 分钟,取决于您的系统性能。
|
||||||
|
|
||||||
|
### 4. 安装前端依赖
|
||||||
|
```bash
|
||||||
|
cd dashboard
|
||||||
|
npm install
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. 构建 Tauri 应用
|
||||||
|
```bash
|
||||||
|
npm run tauri:build
|
||||||
|
```
|
||||||
|
|
||||||
|
构建脚本会自动:
|
||||||
|
1. 运行 `build_nuitka.py` 编译 Python 后端
|
||||||
|
2. 将编译好的可执行文件复制到 `src-tauri/resources/` 目录
|
||||||
|
3. 构建 Tauri 应用并打包所有资源
|
||||||
|
|
||||||
|
### 6. 查找构建产物
|
||||||
|
|
||||||
|
构建完成后,您可以在以下位置找到安装包:
|
||||||
|
|
||||||
|
- **macOS**: `dashboard/src-tauri/target/release/bundle/dmg/AstrBot_*.dmg`
|
||||||
|
- **Windows**: `dashboard/src-tauri/target/release/bundle/msi/AstrBot_*.msi`
|
||||||
|
- **Linux**:
|
||||||
|
- `dashboard/src-tauri/target/release/bundle/deb/astrbot_*.deb`
|
||||||
|
- `dashboard/src-tauri/target/release/bundle/appimage/astrbot_*.AppImage`
|
||||||
|
|
||||||
|
## 开发模式
|
||||||
|
|
||||||
|
在开发时,您可能不想每次都完整编译 Python 后端。
|
||||||
|
|
||||||
|
### 仅开发 Tauri + Vue
|
||||||
|
```bash
|
||||||
|
cd dashboard
|
||||||
|
npm run tauri:dev
|
||||||
|
```
|
||||||
|
|
||||||
|
这会启动开发服务器,但不会自动启动 Python 后端。您需要手动运行:
|
||||||
|
```bash
|
||||||
|
uv run main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### 测试完整集成
|
||||||
|
如果您想测试 Tauri 自动启动 Python 后端的功能:
|
||||||
|
|
||||||
|
1. 先编译一次 Python 后端:
|
||||||
|
```bash
|
||||||
|
python build_nuitka.py
|
||||||
|
```
|
||||||
|
|
||||||
|
2. 手动复制到资源目录:
|
||||||
|
```bash
|
||||||
|
# macOS
|
||||||
|
cp -r build/nuitka/main.app dashboard/src-tauri/resources/astrbot-backend.app
|
||||||
|
|
||||||
|
# Windows
|
||||||
|
copy build\nuitka\main.exe dashboard\src-tauri\resources\astrbot-backend.exe
|
||||||
|
|
||||||
|
# Linux
|
||||||
|
cp build/nuitka/main.bin dashboard/src-tauri/resources/astrbot-backend
|
||||||
|
```
|
||||||
|
|
||||||
|
3. 运行开发模式:
|
||||||
|
```bash
|
||||||
|
cd dashboard
|
||||||
|
npm run tauri:dev
|
||||||
|
```
|
||||||
|
|
||||||
|
## Nuitka 构建选项说明
|
||||||
|
|
||||||
|
`build_nuitka.py` 脚本使用以下关键选项:
|
||||||
|
|
||||||
|
- `--standalone`: 创建包含所有依赖的独立目录
|
||||||
|
- `--onefile`: 将所有内容打包到单个可执行文件
|
||||||
|
- `--follow-imports`: 自动跟踪所有 Python 导入
|
||||||
|
- `--include-package`: 明确包含特定包
|
||||||
|
- `--include-data-dir`: 包含数据目录(插件、配置等)
|
||||||
|
|
||||||
|
### 自定义构建
|
||||||
|
|
||||||
|
如果您需要修改构建选项,编辑 `build_nuitka.py`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 添加更多要包含的包
|
||||||
|
include_packages = [
|
||||||
|
"astrbot",
|
||||||
|
"your_custom_package",
|
||||||
|
# ...
|
||||||
|
]
|
||||||
|
|
||||||
|
# 添加更多数据目录
|
||||||
|
data_includes = [
|
||||||
|
"data/config",
|
||||||
|
"your_custom_data",
|
||||||
|
# ...
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
## 常见问题
|
||||||
|
|
||||||
|
### 1. Nuitka 编译失败
|
||||||
|
**问题**: 编译时出现 "module not found" 错误
|
||||||
|
|
||||||
|
**解决方案**: 在 `build_nuitka.py` 中添加缺失的包到 `include_packages` 列表
|
||||||
|
|
||||||
|
### 2. 运行时找不到资源文件
|
||||||
|
**问题**: 应用启动后提示找不到配置文件或插件
|
||||||
|
|
||||||
|
**解决方案**: 确保在 `build_nuitka.py` 中使用 `--include-data-dir` 包含了所有必要的数据目录
|
||||||
|
|
||||||
|
### 3. macOS 安全警告
|
||||||
|
**问题**: macOS 提示"应用来自未知开发者"
|
||||||
|
|
||||||
|
**解决方案**:
|
||||||
|
```bash
|
||||||
|
# 临时解除限制
|
||||||
|
sudo spctl --master-disable
|
||||||
|
|
||||||
|
# 或者为特定应用授权
|
||||||
|
xattr -cr /Applications/AstrBot.app
|
||||||
|
```
|
||||||
|
|
||||||
|
对于生产发布,您需要:
|
||||||
|
1. 注册 Apple Developer 账号
|
||||||
|
2. 对应用进行代码签名
|
||||||
|
3. 提交公证 (Notarization)
|
||||||
|
|
||||||
|
### 4. Windows Defender 报毒
|
||||||
|
**问题**: Windows Defender 或其他杀毒软件报毒
|
||||||
|
|
||||||
|
**解决方案**:
|
||||||
|
- 这是 Nuitka 打包程序的常见问题
|
||||||
|
- 可以使用 `--windows-company-name` 和 `--windows-product-name` 添加元数据
|
||||||
|
- 对于生产发布,需要购买代码签名证书
|
||||||
|
|
||||||
|
### 5. Linux 依赖问题
|
||||||
|
**问题**: 在某些 Linux 发行版上缺少共享库
|
||||||
|
|
||||||
|
**解决方案**: 使用 AppImage 格式,它包含所有依赖:
|
||||||
|
```bash
|
||||||
|
# 构建时会自动生成 AppImage
|
||||||
|
npm run tauri:build
|
||||||
|
```
|
||||||
|
|
||||||
|
## 优化构建大小
|
||||||
|
|
||||||
|
默认的 `--onefile` 模式会生成较大的可执行文件。如果需要减小体积:
|
||||||
|
|
||||||
|
1. 移除不需要的包
|
||||||
|
2. 使用 `--standalone` 而不是 `--onefile`
|
||||||
|
3. 排除不必要的数据文件
|
||||||
|
|
||||||
|
修改 `build_nuitka.py`:
|
||||||
|
```python
|
||||||
|
# 移除 --onefile,使用 --standalone
|
||||||
|
nuitka_cmd = [
|
||||||
|
sys.executable,
|
||||||
|
"-m", "nuitka",
|
||||||
|
"--standalone", # 只使用 standalone
|
||||||
|
# "--onefile", # 注释掉 onefile
|
||||||
|
# ...
|
||||||
|
]
|
||||||
|
```
|
||||||
|
|
||||||
|
## CI/CD 集成
|
||||||
|
|
||||||
|
项目已配置 GitHub Actions 工作流 (`.github/workflows/build-app.yml`),可以自动为所有平台构建应用。
|
||||||
|
|
||||||
|
推送标签时自动触发:
|
||||||
|
```bash
|
||||||
|
git tag v4.5.7
|
||||||
|
git push origin v4.5.7
|
||||||
|
```
|
||||||
|
|
||||||
|
或手动触发:
|
||||||
|
在 GitHub Actions 页面选择 "Build Desktop App" 工作流并点击 "Run workflow"
|
||||||
|
|
||||||
|
## 发布清单
|
||||||
|
|
||||||
|
在发布新版本前:
|
||||||
|
|
||||||
|
- [ ] 更新版本号
|
||||||
|
- `pyproject.toml` - Python 项目版本
|
||||||
|
- `dashboard/package.json` - Node 项目版本
|
||||||
|
- `dashboard/src-tauri/Cargo.toml` - Rust 项目版本
|
||||||
|
- `dashboard/src-tauri/tauri.conf.json` - Tauri 配置版本
|
||||||
|
|
||||||
|
- [ ] 运行代码检查
|
||||||
|
```bash
|
||||||
|
uv run ruff check .
|
||||||
|
uv run ruff format .
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] 本地测试构建
|
||||||
|
```bash
|
||||||
|
python build_nuitka.py
|
||||||
|
cd dashboard && npm run tauri:build
|
||||||
|
```
|
||||||
|
|
||||||
|
- [ ] 测试安装包
|
||||||
|
- 安装生成的安装包
|
||||||
|
- 验证应用启动
|
||||||
|
- 验证 Python 后端自动启动
|
||||||
|
- 测试核心功能
|
||||||
|
|
||||||
|
- [ ] 创建发布标签
|
||||||
|
```bash
|
||||||
|
git tag -a v4.5.7 -m "Release v4.5.7"
|
||||||
|
git push origin v4.5.7
|
||||||
|
```
|
||||||
|
|
||||||
|
## 技术架构
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────────────────────┐
|
||||||
|
│ Tauri Desktop App │
|
||||||
|
│ (Rust + WebView) │
|
||||||
|
│ │
|
||||||
|
│ ┌─────────────────────────────┐ │
|
||||||
|
│ │ Vue.js Dashboard │ │
|
||||||
|
│ │ (Frontend UI) │ │
|
||||||
|
│ └─────────────────────────────┘ │
|
||||||
|
│ │
|
||||||
|
│ ┌─────────────────────────────┐ │
|
||||||
|
│ │ Python Backend │ │
|
||||||
|
│ │ (Nuitka Compiled) │ │
|
||||||
|
│ │ - AstrBot Core │ │
|
||||||
|
│ │ - Plugins │ │
|
||||||
|
│ │ - API Server │ │
|
||||||
|
│ └─────────────────────────────┘ │
|
||||||
|
│ │
|
||||||
|
│ HTTP/WebSocket │
|
||||||
|
│ localhost:6185 │
|
||||||
|
└─────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
## 参考资源
|
||||||
|
|
||||||
|
- [Nuitka 文档](https://nuitka.net/doc/user-manual.html)
|
||||||
|
- [Tauri 文档](https://tauri.app/v1/guides/)
|
||||||
|
- [AstrBot 文档](https://astrbot.fun)
|
||||||
@@ -33,6 +33,20 @@
|
|||||||
- 请使用英文描述您的 PR。
|
- 请使用英文描述您的 PR。
|
||||||
- 标题请使用 `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` 等语义化前缀,并简要描述更改内容。如:`fix: correct login page typo`。
|
- 标题请使用 `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` 等语义化前缀,并简要描述更改内容。如:`fix: correct login page typo`。
|
||||||
|
|
||||||
|
#### 代码规范
|
||||||
|
|
||||||
|
##### Core
|
||||||
|
|
||||||
|
我们使用 Ruff 作为代码格式化和静态分析工具。在提交代码之前,请运行以下命令以确保代码符合规范:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ruff format .
|
||||||
|
ruff check .
|
||||||
|
```
|
||||||
|
|
||||||
|
如果您使用 VSCode,可以安装 `Ruff` 插件。
|
||||||
|
|
||||||
|
|
||||||
## Contributing Guide
|
## Contributing Guide
|
||||||
|
|
||||||
First off, thanks for taking the time to contribute! ❤️
|
First off, thanks for taking the time to contribute! ❤️
|
||||||
@@ -62,4 +76,15 @@ We use the `fix/` prefix for bug fixes and the `feat/` prefix for new features.
|
|||||||
|
|
||||||
#### PR Description
|
#### PR Description
|
||||||
- Please use English to describe your PR.
|
- Please use English to describe your PR.
|
||||||
- Use semantic prefixes like `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` in the title, followed by a brief description of the changes, e.g., `fix: correct login page typo`.
|
- Use semantic prefixes like `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` in the title, followed by a brief description of the changes, e.g., `fix: correct login page typo`.
|
||||||
|
|
||||||
|
#### Code Style
|
||||||
|
|
||||||
|
##### Core
|
||||||
|
|
||||||
|
We use Ruff as our code formatter and static analysis tool. Before submitting your code, please run the following commands to ensure your code adheres to the style guidelines:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
ruff format .
|
||||||
|
ruff check .
|
||||||
|
```
|
||||||
|
|||||||
@@ -243,4 +243,10 @@ pre-commit install
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
|
||||||
_私は、高性能ですから!_
|
_私は、高性能ですから!_
|
||||||
|
|
||||||
|
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
|
||||||
|
</div
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = "4.8.0"
|
__version__ = "4.9.2"
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
from typing import Any, ClassVar, Literal, cast
|
from typing import Any, ClassVar, Literal, cast
|
||||||
|
|
||||||
from pydantic import BaseModel, GetCoreSchemaHandler, model_validator
|
from pydantic import BaseModel, GetCoreSchemaHandler, model_serializer, model_validator
|
||||||
from pydantic_core import core_schema
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
|
||||||
@@ -122,10 +122,12 @@ class ToolCall(BaseModel):
|
|||||||
extra_content: dict[str, Any] | None = None
|
extra_content: dict[str, Any] | None = None
|
||||||
"""Extra metadata for the tool call."""
|
"""Extra metadata for the tool call."""
|
||||||
|
|
||||||
def model_dump(self, **kwargs: Any) -> dict[str, Any]:
|
@model_serializer(mode="wrap")
|
||||||
|
def serialize(self, handler):
|
||||||
|
data = handler(self)
|
||||||
if self.extra_content is None:
|
if self.extra_content is None:
|
||||||
kwargs.setdefault("exclude", set()).add("extra_content")
|
data.pop("extra_content", None)
|
||||||
return super().model_dump(**kwargs)
|
return data
|
||||||
|
|
||||||
|
|
||||||
class ToolCallPart(BaseModel):
|
class ToolCallPart(BaseModel):
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import typing as T
|
import typing as T
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from astrbot.core.message.message_event_result import MessageChain
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
|
from astrbot.core.provider.entities import TokenUsage
|
||||||
|
|
||||||
|
|
||||||
class AgentResponseData(T.TypedDict):
|
class AgentResponseData(T.TypedDict):
|
||||||
@@ -12,3 +13,23 @@ class AgentResponseData(T.TypedDict):
|
|||||||
class AgentResponse:
|
class AgentResponse:
|
||||||
type: str
|
type: str
|
||||||
data: AgentResponseData
|
data: AgentResponseData
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AgentStats:
|
||||||
|
token_usage: TokenUsage = field(default_factory=TokenUsage)
|
||||||
|
start_time: float = 0.0
|
||||||
|
end_time: float = 0.0
|
||||||
|
time_to_first_token: float = 0.0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def duration(self) -> float:
|
||||||
|
return self.end_time - self.start_time
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"token_usage": self.token_usage.__dict__,
|
||||||
|
"start_time": self.start_time,
|
||||||
|
"end_time": self.end_time,
|
||||||
|
"time_to_first_token": self.time_to_first_token,
|
||||||
|
}
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from .message import Message
|
|||||||
TContext = TypeVar("TContext", default=Any)
|
TContext = TypeVar("TContext", default=Any)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(config={"arbitrary_types_allowed": True})
|
@dataclass
|
||||||
class ContextWrapper(Generic[TContext]):
|
class ContextWrapper(Generic[TContext]):
|
||||||
"""A context for running an agent, which can be used to pass additional data or state."""
|
"""A context for running an agent, which can be used to pass additional data or state."""
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import sys
|
import sys
|
||||||
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
import typing as T
|
import typing as T
|
||||||
|
|
||||||
@@ -12,6 +13,7 @@ from mcp.types import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
|
from astrbot.core.message.components import Json
|
||||||
from astrbot.core.message.message_event_result import (
|
from astrbot.core.message.message_event_result import (
|
||||||
MessageChain,
|
MessageChain,
|
||||||
)
|
)
|
||||||
@@ -24,7 +26,7 @@ from astrbot.core.provider.provider import Provider
|
|||||||
|
|
||||||
from ..hooks import BaseAgentRunHooks
|
from ..hooks import BaseAgentRunHooks
|
||||||
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
|
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
|
||||||
from ..response import AgentResponseData
|
from ..response import AgentResponseData, AgentStats
|
||||||
from ..run_context import ContextWrapper, TContext
|
from ..run_context import ContextWrapper, TContext
|
||||||
from ..tool_executor import BaseFunctionToolExecutor
|
from ..tool_executor import BaseFunctionToolExecutor
|
||||||
from .base import AgentResponse, AgentState, BaseAgentRunner
|
from .base import AgentResponse, AgentState, BaseAgentRunner
|
||||||
@@ -69,6 +71,9 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
)
|
)
|
||||||
self.run_context.messages = messages
|
self.run_context.messages = messages
|
||||||
|
|
||||||
|
self.stats = AgentStats()
|
||||||
|
self.stats.start_time = time.time()
|
||||||
|
|
||||||
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
|
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
|
||||||
"""Yields chunks *and* a final LLMResponse."""
|
"""Yields chunks *and* a final LLMResponse."""
|
||||||
if self.streaming:
|
if self.streaming:
|
||||||
@@ -98,6 +103,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
|
|
||||||
async for llm_response in self._iter_llm_responses():
|
async for llm_response in self._iter_llm_responses():
|
||||||
if llm_response.is_chunk:
|
if llm_response.is_chunk:
|
||||||
|
# update ttft
|
||||||
|
if self.stats.time_to_first_token == 0:
|
||||||
|
self.stats.time_to_first_token = time.time() - self.stats.start_time
|
||||||
|
|
||||||
if llm_response.result_chain:
|
if llm_response.result_chain:
|
||||||
yield AgentResponse(
|
yield AgentResponse(
|
||||||
type="streaming_delta",
|
type="streaming_delta",
|
||||||
@@ -121,6 +130,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
llm_resp_result = llm_response
|
llm_resp_result = llm_response
|
||||||
|
|
||||||
|
if not llm_response.is_chunk and llm_response.usage:
|
||||||
|
# only count the token usage of the final response for computation purpose
|
||||||
|
self.stats.token_usage += llm_response.usage
|
||||||
break # got final response
|
break # got final response
|
||||||
|
|
||||||
if not llm_resp_result:
|
if not llm_resp_result:
|
||||||
@@ -132,6 +145,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
if llm_resp.role == "err":
|
if llm_resp.role == "err":
|
||||||
# 如果 LLM 响应错误,转换到错误状态
|
# 如果 LLM 响应错误,转换到错误状态
|
||||||
self.final_llm_resp = llm_resp
|
self.final_llm_resp = llm_resp
|
||||||
|
self.stats.end_time = time.time()
|
||||||
self._transition_state(AgentState.ERROR)
|
self._transition_state(AgentState.ERROR)
|
||||||
yield AgentResponse(
|
yield AgentResponse(
|
||||||
type="err",
|
type="err",
|
||||||
@@ -146,6 +160,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
# 如果没有工具调用,转换到完成状态
|
# 如果没有工具调用,转换到完成状态
|
||||||
self.final_llm_resp = llm_resp
|
self.final_llm_resp = llm_resp
|
||||||
self._transition_state(AgentState.DONE)
|
self._transition_state(AgentState.DONE)
|
||||||
|
self.stats.end_time = time.time()
|
||||||
# record the final assistant message
|
# record the final assistant message
|
||||||
self.run_context.messages.append(
|
self.run_context.messages.append(
|
||||||
Message(
|
Message(
|
||||||
@@ -175,22 +190,19 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
# 如果有工具调用,还需处理工具调用
|
# 如果有工具调用,还需处理工具调用
|
||||||
if llm_resp.tools_call_name:
|
if llm_resp.tools_call_name:
|
||||||
tool_call_result_blocks = []
|
tool_call_result_blocks = []
|
||||||
for tool_call_name in llm_resp.tools_call_name:
|
|
||||||
yield AgentResponse(
|
|
||||||
type="tool_call",
|
|
||||||
data=AgentResponseData(
|
|
||||||
chain=MessageChain(type="tool_call").message(
|
|
||||||
f"🔨 调用工具: {tool_call_name}"
|
|
||||||
),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
async for result in self._handle_function_tools(self.req, llm_resp):
|
async for result in self._handle_function_tools(self.req, llm_resp):
|
||||||
if isinstance(result, list):
|
if isinstance(result, list):
|
||||||
tool_call_result_blocks = result
|
tool_call_result_blocks = result
|
||||||
elif isinstance(result, MessageChain):
|
elif isinstance(result, MessageChain):
|
||||||
result.type = "tool_call_result"
|
if result.type is None:
|
||||||
|
# should not happen
|
||||||
|
continue
|
||||||
|
if result.type == "tool_direct_result":
|
||||||
|
ar_type = "tool_call_result"
|
||||||
|
else:
|
||||||
|
ar_type = result.type
|
||||||
yield AgentResponse(
|
yield AgentResponse(
|
||||||
type="tool_call_result",
|
type=ar_type,
|
||||||
data=AgentResponseData(chain=result),
|
data=AgentResponseData(chain=result),
|
||||||
)
|
)
|
||||||
# 将结果添加到上下文中
|
# 将结果添加到上下文中
|
||||||
@@ -233,6 +245,19 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
llm_response.tools_call_args,
|
llm_response.tools_call_args,
|
||||||
llm_response.tools_call_ids,
|
llm_response.tools_call_ids,
|
||||||
):
|
):
|
||||||
|
yield MessageChain(
|
||||||
|
type="tool_call",
|
||||||
|
chain=[
|
||||||
|
Json(
|
||||||
|
data={
|
||||||
|
"id": func_tool_id,
|
||||||
|
"name": func_tool_name,
|
||||||
|
"args": func_tool_args,
|
||||||
|
"ts": time.time(),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
if not req.func_tool:
|
if not req.func_tool:
|
||||||
return
|
return
|
||||||
@@ -306,7 +331,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
content=res.content[0].text,
|
content=res.content[0].text,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
yield MessageChain().message(res.content[0].text)
|
|
||||||
elif isinstance(res.content[0], ImageContent):
|
elif isinstance(res.content[0], ImageContent):
|
||||||
tool_call_result_blocks.append(
|
tool_call_result_blocks.append(
|
||||||
ToolCallMessageSegment(
|
ToolCallMessageSegment(
|
||||||
@@ -328,7 +352,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
content=resource.text,
|
content=resource.text,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
yield MessageChain().message(resource.text)
|
|
||||||
elif (
|
elif (
|
||||||
isinstance(resource, BlobResourceContents)
|
isinstance(resource, BlobResourceContents)
|
||||||
and resource.mimeType
|
and resource.mimeType
|
||||||
@@ -352,7 +375,22 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
content="返回的数据类型不受支持",
|
content="返回的数据类型不受支持",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
yield MessageChain().message("返回的数据类型不受支持。")
|
|
||||||
|
# yield the last tool call result
|
||||||
|
if tool_call_result_blocks:
|
||||||
|
last_tcr_content = str(tool_call_result_blocks[-1].content)
|
||||||
|
yield MessageChain(
|
||||||
|
type="tool_call_result",
|
||||||
|
chain=[
|
||||||
|
Json(
|
||||||
|
data={
|
||||||
|
"id": func_tool_id,
|
||||||
|
"ts": time.time(),
|
||||||
|
"result": last_tcr_content,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
elif resp is None:
|
elif resp is None:
|
||||||
# Tool 直接请求发送消息给用户
|
# Tool 直接请求发送消息给用户
|
||||||
@@ -362,6 +400,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户,此工具调用不会被记录到历史中。"
|
f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户,此工具调用不会被记录到历史中。"
|
||||||
)
|
)
|
||||||
self._transition_state(AgentState.DONE)
|
self._transition_state(AgentState.DONE)
|
||||||
|
self.stats.end_time = time.time()
|
||||||
else:
|
else:
|
||||||
# 不应该出现其他类型
|
# 不应该出现其他类型
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
@@ -6,8 +6,10 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
|||||||
from astrbot.core.star.context import Context
|
from astrbot.core.star.context import Context
|
||||||
|
|
||||||
|
|
||||||
@dataclass(config={"arbitrary_types_allowed": True})
|
@dataclass
|
||||||
class AstrAgentContext:
|
class AstrAgentContext:
|
||||||
|
__pydantic_config__ = {"arbitrary_types_allowed": True}
|
||||||
|
|
||||||
context: Context
|
context: Context
|
||||||
"""The star context instance"""
|
"""The star context instance"""
|
||||||
event: AstrMessageEvent
|
event: AstrMessageEvent
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from collections.abc import AsyncGenerator
|
|||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
||||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||||
|
from astrbot.core.message.components import Json
|
||||||
from astrbot.core.message.message_event_result import (
|
from astrbot.core.message.message_event_result import (
|
||||||
MessageChain,
|
MessageChain,
|
||||||
MessageEventResult,
|
MessageEventResult,
|
||||||
@@ -33,16 +34,27 @@ async def run_agent(
|
|||||||
msg_chain = resp.data["chain"]
|
msg_chain = resp.data["chain"]
|
||||||
if msg_chain.type == "tool_direct_result":
|
if msg_chain.type == "tool_direct_result":
|
||||||
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
|
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
|
||||||
await astr_event.send(resp.data["chain"])
|
await astr_event.send(msg_chain)
|
||||||
continue
|
continue
|
||||||
|
if astr_event.get_platform_id() == "webchat":
|
||||||
|
await astr_event.send(msg_chain)
|
||||||
# 对于其他情况,暂时先不处理
|
# 对于其他情况,暂时先不处理
|
||||||
continue
|
continue
|
||||||
elif resp.type == "tool_call":
|
elif resp.type == "tool_call":
|
||||||
if agent_runner.streaming:
|
if agent_runner.streaming:
|
||||||
# 用来标记流式响应需要分节
|
# 用来标记流式响应需要分节
|
||||||
yield MessageChain(chain=[], type="break")
|
yield MessageChain(chain=[], type="break")
|
||||||
if show_tool_use:
|
|
||||||
|
if astr_event.get_platform_name() == "webchat":
|
||||||
await astr_event.send(resp.data["chain"])
|
await astr_event.send(resp.data["chain"])
|
||||||
|
elif show_tool_use:
|
||||||
|
json_comp = resp.data["chain"].chain[0]
|
||||||
|
if isinstance(json_comp, Json):
|
||||||
|
m = f"🔨 调用工具: {json_comp.data.get('name')}"
|
||||||
|
else:
|
||||||
|
m = "🔨 调用工具..."
|
||||||
|
chain = MessageChain(type="tool_call").message(m)
|
||||||
|
await astr_event.send(chain)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if stream_to_general and resp.type == "streaming_delta":
|
if stream_to_general and resp.type == "streaming_delta":
|
||||||
@@ -69,6 +81,15 @@ async def run_agent(
|
|||||||
continue
|
continue
|
||||||
yield resp.data["chain"] # MessageChain
|
yield resp.data["chain"] # MessageChain
|
||||||
if agent_runner.done():
|
if agent_runner.done():
|
||||||
|
# send agent stats to webchat
|
||||||
|
if astr_event.get_platform_name() == "webchat":
|
||||||
|
await astr_event.send(
|
||||||
|
MessageChain(
|
||||||
|
type="agent_stats",
|
||||||
|
chain=[Json(data=agent_runner.stats.to_dict())],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import os
|
|||||||
|
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
VERSION = "4.8.0"
|
VERSION = "4.9.2"
|
||||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||||
|
|
||||||
WEBHOOK_SUPPORTED_PLATFORMS = [
|
WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||||
@@ -108,6 +108,7 @@ DEFAULT_CONFIG = {
|
|||||||
"provider_id": "",
|
"provider_id": "",
|
||||||
"dual_output": False,
|
"dual_output": False,
|
||||||
"use_file_service": False,
|
"use_file_service": False,
|
||||||
|
"trigger_probability": 1.0,
|
||||||
},
|
},
|
||||||
"provider_ltm_settings": {
|
"provider_ltm_settings": {
|
||||||
"group_icl_enable": False,
|
"group_icl_enable": False,
|
||||||
@@ -208,7 +209,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"callback_server_host": "0.0.0.0",
|
"callback_server_host": "0.0.0.0",
|
||||||
"port": 6196,
|
"port": 6196,
|
||||||
},
|
},
|
||||||
"QQ 个人号(OneBot v11)": {
|
"OneBot v11": {
|
||||||
"id": "default",
|
"id": "default",
|
||||||
"type": "aiocqhttp",
|
"type": "aiocqhttp",
|
||||||
"enable": False,
|
"enable": False,
|
||||||
@@ -945,7 +946,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
"model_config": {
|
"model_config": {
|
||||||
"model": "gemini-1.5-flash",
|
"model": "gemini-3-flash-preview",
|
||||||
"temperature": 0.4,
|
"temperature": 0.4,
|
||||||
},
|
},
|
||||||
"custom_headers": {},
|
"custom_headers": {},
|
||||||
@@ -962,7 +963,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"api_base": "https://generativelanguage.googleapis.com/",
|
"api_base": "https://generativelanguage.googleapis.com/",
|
||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
"model_config": {
|
"model_config": {
|
||||||
"model": "gemini-2.0-flash-exp",
|
"model": "gemini-3-flash-preview",
|
||||||
"temperature": 0.4,
|
"temperature": 0.4,
|
||||||
},
|
},
|
||||||
"gm_resp_image_modal": False,
|
"gm_resp_image_modal": False,
|
||||||
@@ -975,9 +976,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE",
|
"sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
"dangerous_content": "BLOCK_MEDIUM_AND_ABOVE",
|
"dangerous_content": "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
},
|
},
|
||||||
"gm_thinking_config": {
|
"gm_thinking_config": {"budget": 0, "level": "HIGH"},
|
||||||
"budget": 0,
|
|
||||||
},
|
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
},
|
},
|
||||||
"DeepSeek": {
|
"DeepSeek": {
|
||||||
@@ -1818,13 +1817,24 @@ CONFIG_METADATA_2 = {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
"gm_thinking_config": {
|
"gm_thinking_config": {
|
||||||
"description": "Gemini思考设置",
|
"description": "Thinking Config",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"items": {
|
"items": {
|
||||||
"budget": {
|
"budget": {
|
||||||
"description": "思考预算",
|
"description": "Thinking Budget",
|
||||||
"type": "int",
|
"type": "int",
|
||||||
"hint": "模型应该生成的思考Token的数量,设为0关闭思考。除gemini-2.5-flash外的模型会静默忽略此参数。",
|
"hint": "Guides the model on the specific number of thinking tokens to use for reasoning. See: https://ai.google.dev/gemini-api/docs/thinking#set-budget",
|
||||||
|
},
|
||||||
|
"level": {
|
||||||
|
"description": "Thinking Level",
|
||||||
|
"type": "string",
|
||||||
|
"hint": "Recommended for Gemini 3 models and onwards, lets you control reasoning behavior.See: https://ai.google.dev/gemini-api/docs/thinking#thinking-levels",
|
||||||
|
"options": [
|
||||||
|
"MINIMAL",
|
||||||
|
"LOW",
|
||||||
|
"MEDIUM",
|
||||||
|
"HIGH",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -2209,6 +2219,9 @@ CONFIG_METADATA_2 = {
|
|||||||
"use_file_service": {
|
"use_file_service": {
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
},
|
},
|
||||||
|
"trigger_probability": {
|
||||||
|
"type": "float",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"provider_ltm_settings": {
|
"provider_ltm_settings": {
|
||||||
@@ -2419,6 +2432,14 @@ CONFIG_METADATA_3 = {
|
|||||||
"provider_tts_settings.enable": True,
|
"provider_tts_settings.enable": True,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
"provider_tts_settings.trigger_probability": {
|
||||||
|
"description": "TTS 触发概率",
|
||||||
|
"type": "float",
|
||||||
|
"slider": {"min": 0, "max": 1, "step": 0.05},
|
||||||
|
"condition": {
|
||||||
|
"provider_tts_settings.enable": True,
|
||||||
|
},
|
||||||
|
},
|
||||||
"provider_settings.image_caption_prompt": {
|
"provider_settings.image_caption_prompt": {
|
||||||
"description": "图片转述提示词",
|
"description": "图片转述提示词",
|
||||||
"type": "text",
|
"type": "text",
|
||||||
@@ -2986,6 +3007,7 @@ CONFIG_METADATA_3 = {
|
|||||||
"description": "回复概率",
|
"description": "回复概率",
|
||||||
"type": "float",
|
"type": "float",
|
||||||
"hint": "0.0-1.0 之间的数值",
|
"hint": "0.0-1.0 之间的数值",
|
||||||
|
"slider": {"min": 0, "max": 1, "step": 0.05},
|
||||||
"condition": {
|
"condition": {
|
||||||
"provider_ltm_settings.active_reply.enable": True,
|
"provider_ltm_settings.active_reply.enable": True,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -79,6 +79,7 @@ class ConfigMetadataI18n:
|
|||||||
"_special",
|
"_special",
|
||||||
"invisible",
|
"invisible",
|
||||||
"options",
|
"options",
|
||||||
|
"slider",
|
||||||
]:
|
]:
|
||||||
if attr in field_data:
|
if attr in field_data:
|
||||||
field_result[attr] = field_data[attr]
|
field_result[attr] = field_data[attr]
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_asyn
|
|||||||
|
|
||||||
from astrbot.core.db.po import (
|
from astrbot.core.db.po import (
|
||||||
Attachment,
|
Attachment,
|
||||||
|
CommandConfig,
|
||||||
|
CommandConflict,
|
||||||
ConversationV2,
|
ConversationV2,
|
||||||
Persona,
|
Persona,
|
||||||
PlatformMessageHistory,
|
PlatformMessageHistory,
|
||||||
@@ -314,6 +316,76 @@ class BaseDatabase(abc.ABC):
|
|||||||
"""Clear all preferences for a specific scope ID."""
|
"""Clear all preferences for a specific scope ID."""
|
||||||
...
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def get_command_configs(self) -> list[CommandConfig]:
|
||||||
|
"""Get all stored command configurations."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def get_command_config(self, handler_full_name: str) -> CommandConfig | None:
|
||||||
|
"""Fetch a single command configuration by handler."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def upsert_command_config(
|
||||||
|
self,
|
||||||
|
handler_full_name: str,
|
||||||
|
plugin_name: str,
|
||||||
|
module_path: str,
|
||||||
|
original_command: str,
|
||||||
|
*,
|
||||||
|
resolved_command: str | None = None,
|
||||||
|
enabled: bool | None = None,
|
||||||
|
keep_original_alias: bool | None = None,
|
||||||
|
conflict_key: str | None = None,
|
||||||
|
resolution_strategy: str | None = None,
|
||||||
|
note: str | None = None,
|
||||||
|
extra_data: dict | None = None,
|
||||||
|
auto_managed: bool | None = None,
|
||||||
|
) -> CommandConfig:
|
||||||
|
"""Create or update a command configuration."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def delete_command_config(self, handler_full_name: str) -> None:
|
||||||
|
"""Delete a single command configuration."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def delete_command_configs(self, handler_full_names: list[str]) -> None:
|
||||||
|
"""Bulk delete command configurations."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def list_command_conflicts(
|
||||||
|
self,
|
||||||
|
status: str | None = None,
|
||||||
|
) -> list[CommandConflict]:
|
||||||
|
"""List recorded command conflict entries."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def upsert_command_conflict(
|
||||||
|
self,
|
||||||
|
conflict_key: str,
|
||||||
|
handler_full_name: str,
|
||||||
|
plugin_name: str,
|
||||||
|
*,
|
||||||
|
status: str | None = None,
|
||||||
|
resolution: str | None = None,
|
||||||
|
resolved_command: str | None = None,
|
||||||
|
note: str | None = None,
|
||||||
|
extra_data: dict | None = None,
|
||||||
|
auto_generated: bool | None = None,
|
||||||
|
) -> CommandConflict:
|
||||||
|
"""Create or update a conflict record."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def delete_command_conflicts(self, ids: list[int]) -> None:
|
||||||
|
"""Delete conflict records."""
|
||||||
|
...
|
||||||
|
|
||||||
# @abc.abstractmethod
|
# @abc.abstractmethod
|
||||||
# async def insert_llm_message(
|
# async def insert_llm_message(
|
||||||
# self,
|
# self,
|
||||||
|
|||||||
@@ -234,6 +234,65 @@ class Attachment(SQLModel, table=True):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CommandConfig(SQLModel, table=True):
|
||||||
|
"""Per-command configuration overrides for dashboard management."""
|
||||||
|
|
||||||
|
__tablename__ = "command_configs" # type: ignore
|
||||||
|
|
||||||
|
handler_full_name: str = Field(
|
||||||
|
primary_key=True,
|
||||||
|
max_length=512,
|
||||||
|
)
|
||||||
|
plugin_name: str = Field(nullable=False, max_length=255)
|
||||||
|
module_path: str = Field(nullable=False, max_length=255)
|
||||||
|
original_command: str = Field(nullable=False, max_length=255)
|
||||||
|
resolved_command: str | None = Field(default=None, max_length=255)
|
||||||
|
enabled: bool = Field(default=True, nullable=False)
|
||||||
|
keep_original_alias: bool = Field(default=False, nullable=False)
|
||||||
|
conflict_key: str | None = Field(default=None, max_length=255)
|
||||||
|
resolution_strategy: str | None = Field(default=None, max_length=64)
|
||||||
|
note: str | None = Field(default=None, sa_type=Text)
|
||||||
|
extra_data: dict | None = Field(default=None, sa_type=JSON)
|
||||||
|
auto_managed: bool = Field(default=False, nullable=False)
|
||||||
|
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
updated_at: datetime = Field(
|
||||||
|
default_factory=lambda: datetime.now(timezone.utc),
|
||||||
|
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CommandConflict(SQLModel, table=True):
|
||||||
|
"""Conflict tracking for duplicated command names."""
|
||||||
|
|
||||||
|
__tablename__ = "command_conflicts" # type: ignore
|
||||||
|
|
||||||
|
id: int | None = Field(
|
||||||
|
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
|
||||||
|
)
|
||||||
|
conflict_key: str = Field(nullable=False, max_length=255)
|
||||||
|
handler_full_name: str = Field(nullable=False, max_length=512)
|
||||||
|
plugin_name: str = Field(nullable=False, max_length=255)
|
||||||
|
status: str = Field(default="pending", max_length=32)
|
||||||
|
resolution: str | None = Field(default=None, max_length=64)
|
||||||
|
resolved_command: str | None = Field(default=None, max_length=255)
|
||||||
|
note: str | None = Field(default=None, sa_type=Text)
|
||||||
|
extra_data: dict | None = Field(default=None, sa_type=JSON)
|
||||||
|
auto_generated: bool = Field(default=False, nullable=False)
|
||||||
|
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
||||||
|
updated_at: datetime = Field(
|
||||||
|
default_factory=lambda: datetime.now(timezone.utc),
|
||||||
|
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
||||||
|
)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint(
|
||||||
|
"conflict_key",
|
||||||
|
"handler_full_name",
|
||||||
|
name="uix_conflict_handler",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Conversation:
|
class Conversation:
|
||||||
"""LLM 对话类
|
"""LLM 对话类
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import threading
|
import threading
|
||||||
import typing as T
|
import typing as T
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
from sqlalchemy import CursorResult
|
from sqlalchemy import CursorResult
|
||||||
@@ -10,6 +11,8 @@ from sqlmodel import col, delete, desc, func, or_, select, text, update
|
|||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
from astrbot.core.db.po import (
|
from astrbot.core.db.po import (
|
||||||
Attachment,
|
Attachment,
|
||||||
|
CommandConfig,
|
||||||
|
CommandConflict,
|
||||||
ConversationV2,
|
ConversationV2,
|
||||||
Persona,
|
Persona,
|
||||||
PlatformMessageHistory,
|
PlatformMessageHistory,
|
||||||
@@ -26,6 +29,7 @@ from astrbot.core.db.po import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
||||||
|
TxResult = T.TypeVar("TxResult")
|
||||||
|
|
||||||
|
|
||||||
class SQLiteDatabase(BaseDatabase):
|
class SQLiteDatabase(BaseDatabase):
|
||||||
@@ -670,6 +674,242 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
)
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
# ====
|
||||||
|
# Command Configuration & Conflict Tracking
|
||||||
|
# ====
|
||||||
|
|
||||||
|
async def _run_in_tx(
|
||||||
|
self,
|
||||||
|
fn: Callable[[AsyncSession], Awaitable[TxResult]],
|
||||||
|
) -> TxResult:
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
async with session.begin():
|
||||||
|
return await fn(session)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _apply_updates(model, **updates) -> None:
|
||||||
|
for field, value in updates.items():
|
||||||
|
if value is not None:
|
||||||
|
setattr(model, field, value)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _new_command_config(
|
||||||
|
handler_full_name: str,
|
||||||
|
plugin_name: str,
|
||||||
|
module_path: str,
|
||||||
|
original_command: str,
|
||||||
|
*,
|
||||||
|
resolved_command: str | None = None,
|
||||||
|
enabled: bool | None = None,
|
||||||
|
keep_original_alias: bool | None = None,
|
||||||
|
conflict_key: str | None = None,
|
||||||
|
resolution_strategy: str | None = None,
|
||||||
|
note: str | None = None,
|
||||||
|
extra_data: dict | None = None,
|
||||||
|
auto_managed: bool | None = None,
|
||||||
|
) -> CommandConfig:
|
||||||
|
return CommandConfig(
|
||||||
|
handler_full_name=handler_full_name,
|
||||||
|
plugin_name=plugin_name,
|
||||||
|
module_path=module_path,
|
||||||
|
original_command=original_command,
|
||||||
|
resolved_command=resolved_command,
|
||||||
|
enabled=True if enabled is None else enabled,
|
||||||
|
keep_original_alias=False
|
||||||
|
if keep_original_alias is None
|
||||||
|
else keep_original_alias,
|
||||||
|
conflict_key=conflict_key or original_command,
|
||||||
|
resolution_strategy=resolution_strategy,
|
||||||
|
note=note,
|
||||||
|
extra_data=extra_data,
|
||||||
|
auto_managed=bool(auto_managed),
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _new_command_conflict(
|
||||||
|
conflict_key: str,
|
||||||
|
handler_full_name: str,
|
||||||
|
plugin_name: str,
|
||||||
|
*,
|
||||||
|
status: str | None = None,
|
||||||
|
resolution: str | None = None,
|
||||||
|
resolved_command: str | None = None,
|
||||||
|
note: str | None = None,
|
||||||
|
extra_data: dict | None = None,
|
||||||
|
auto_generated: bool | None = None,
|
||||||
|
) -> CommandConflict:
|
||||||
|
return CommandConflict(
|
||||||
|
conflict_key=conflict_key,
|
||||||
|
handler_full_name=handler_full_name,
|
||||||
|
plugin_name=plugin_name,
|
||||||
|
status=status or "pending",
|
||||||
|
resolution=resolution,
|
||||||
|
resolved_command=resolved_command,
|
||||||
|
note=note,
|
||||||
|
extra_data=extra_data,
|
||||||
|
auto_generated=bool(auto_generated),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_command_configs(self) -> list[CommandConfig]:
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
result = await session.execute(select(CommandConfig))
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
async def get_command_config(
|
||||||
|
self,
|
||||||
|
handler_full_name: str,
|
||||||
|
) -> CommandConfig | None:
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
return await session.get(CommandConfig, handler_full_name)
|
||||||
|
|
||||||
|
async def upsert_command_config(
|
||||||
|
self,
|
||||||
|
handler_full_name: str,
|
||||||
|
plugin_name: str,
|
||||||
|
module_path: str,
|
||||||
|
original_command: str,
|
||||||
|
*,
|
||||||
|
resolved_command: str | None = None,
|
||||||
|
enabled: bool | None = None,
|
||||||
|
keep_original_alias: bool | None = None,
|
||||||
|
conflict_key: str | None = None,
|
||||||
|
resolution_strategy: str | None = None,
|
||||||
|
note: str | None = None,
|
||||||
|
extra_data: dict | None = None,
|
||||||
|
auto_managed: bool | None = None,
|
||||||
|
) -> CommandConfig:
|
||||||
|
async def _op(session: AsyncSession) -> CommandConfig:
|
||||||
|
config = await session.get(CommandConfig, handler_full_name)
|
||||||
|
if not config:
|
||||||
|
config = self._new_command_config(
|
||||||
|
handler_full_name,
|
||||||
|
plugin_name,
|
||||||
|
module_path,
|
||||||
|
original_command,
|
||||||
|
resolved_command=resolved_command,
|
||||||
|
enabled=enabled,
|
||||||
|
keep_original_alias=keep_original_alias,
|
||||||
|
conflict_key=conflict_key,
|
||||||
|
resolution_strategy=resolution_strategy,
|
||||||
|
note=note,
|
||||||
|
extra_data=extra_data,
|
||||||
|
auto_managed=auto_managed,
|
||||||
|
)
|
||||||
|
session.add(config)
|
||||||
|
else:
|
||||||
|
self._apply_updates(
|
||||||
|
config,
|
||||||
|
plugin_name=plugin_name,
|
||||||
|
module_path=module_path,
|
||||||
|
original_command=original_command,
|
||||||
|
resolved_command=resolved_command,
|
||||||
|
enabled=enabled,
|
||||||
|
keep_original_alias=keep_original_alias,
|
||||||
|
conflict_key=conflict_key,
|
||||||
|
resolution_strategy=resolution_strategy,
|
||||||
|
note=note,
|
||||||
|
extra_data=extra_data,
|
||||||
|
auto_managed=auto_managed,
|
||||||
|
)
|
||||||
|
await session.flush()
|
||||||
|
await session.refresh(config)
|
||||||
|
return config
|
||||||
|
|
||||||
|
return await self._run_in_tx(_op)
|
||||||
|
|
||||||
|
async def delete_command_config(self, handler_full_name: str) -> None:
|
||||||
|
await self.delete_command_configs([handler_full_name])
|
||||||
|
|
||||||
|
async def delete_command_configs(self, handler_full_names: list[str]) -> None:
|
||||||
|
if not handler_full_names:
|
||||||
|
return
|
||||||
|
|
||||||
|
async def _op(session: AsyncSession) -> None:
|
||||||
|
await session.execute(
|
||||||
|
delete(CommandConfig).where(
|
||||||
|
col(CommandConfig.handler_full_name).in_(handler_full_names),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._run_in_tx(_op)
|
||||||
|
|
||||||
|
async def list_command_conflicts(
|
||||||
|
self,
|
||||||
|
status: str | None = None,
|
||||||
|
) -> list[CommandConflict]:
|
||||||
|
async with self.get_db() as session:
|
||||||
|
session: AsyncSession
|
||||||
|
query = select(CommandConflict)
|
||||||
|
if status:
|
||||||
|
query = query.where(CommandConflict.status == status)
|
||||||
|
result = await session.execute(query)
|
||||||
|
return list(result.scalars().all())
|
||||||
|
|
||||||
|
async def upsert_command_conflict(
|
||||||
|
self,
|
||||||
|
conflict_key: str,
|
||||||
|
handler_full_name: str,
|
||||||
|
plugin_name: str,
|
||||||
|
*,
|
||||||
|
status: str | None = None,
|
||||||
|
resolution: str | None = None,
|
||||||
|
resolved_command: str | None = None,
|
||||||
|
note: str | None = None,
|
||||||
|
extra_data: dict | None = None,
|
||||||
|
auto_generated: bool | None = None,
|
||||||
|
) -> CommandConflict:
|
||||||
|
async def _op(session: AsyncSession) -> CommandConflict:
|
||||||
|
result = await session.execute(
|
||||||
|
select(CommandConflict).where(
|
||||||
|
CommandConflict.conflict_key == conflict_key,
|
||||||
|
CommandConflict.handler_full_name == handler_full_name,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
record = result.scalar_one_or_none()
|
||||||
|
if not record:
|
||||||
|
record = self._new_command_conflict(
|
||||||
|
conflict_key,
|
||||||
|
handler_full_name,
|
||||||
|
plugin_name,
|
||||||
|
status=status,
|
||||||
|
resolution=resolution,
|
||||||
|
resolved_command=resolved_command,
|
||||||
|
note=note,
|
||||||
|
extra_data=extra_data,
|
||||||
|
auto_generated=auto_generated,
|
||||||
|
)
|
||||||
|
session.add(record)
|
||||||
|
else:
|
||||||
|
self._apply_updates(
|
||||||
|
record,
|
||||||
|
plugin_name=plugin_name,
|
||||||
|
status=status,
|
||||||
|
resolution=resolution,
|
||||||
|
resolved_command=resolved_command,
|
||||||
|
note=note,
|
||||||
|
extra_data=extra_data,
|
||||||
|
auto_generated=auto_generated,
|
||||||
|
)
|
||||||
|
await session.flush()
|
||||||
|
await session.refresh(record)
|
||||||
|
return record
|
||||||
|
|
||||||
|
return await self._run_in_tx(_op)
|
||||||
|
|
||||||
|
async def delete_command_conflicts(self, ids: list[int]) -> None:
|
||||||
|
if not ids:
|
||||||
|
return
|
||||||
|
|
||||||
|
async def _op(session: AsyncSession) -> None:
|
||||||
|
await session.execute(
|
||||||
|
delete(CommandConflict).where(col(CommandConflict.id).in_(ids)),
|
||||||
|
)
|
||||||
|
|
||||||
|
await self._run_in_tx(_op)
|
||||||
|
|
||||||
# ====
|
# ====
|
||||||
# Deprecated Methods
|
# Deprecated Methods
|
||||||
# ====
|
# ====
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import time
|
||||||
from asyncio import Queue
|
from asyncio import Queue
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
@@ -148,7 +149,7 @@ class LogQueueHandler(logging.Handler):
|
|||||||
self.log_broker.publish(
|
self.log_broker.publish(
|
||||||
{
|
{
|
||||||
"level": record.levelname,
|
"level": record.levelname,
|
||||||
"time": record.asctime,
|
"time": time.time(),
|
||||||
"data": log_entry,
|
"data": log_entry,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -629,12 +629,11 @@ class Nodes(BaseMessageComponent):
|
|||||||
|
|
||||||
class Json(BaseMessageComponent):
|
class Json(BaseMessageComponent):
|
||||||
type = ComponentType.Json
|
type = ComponentType.Json
|
||||||
data: str | dict
|
data: dict
|
||||||
resid: int | None = 0
|
|
||||||
|
|
||||||
def __init__(self, data, **_):
|
def __init__(self, data: str | dict, **_):
|
||||||
if isinstance(data, dict):
|
if isinstance(data, str):
|
||||||
data = json.dumps(data)
|
data = json.loads(data)
|
||||||
super().__init__(data=data, **_)
|
super().__init__(data=data, **_)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -119,7 +119,7 @@ class RespondStage(Stage):
|
|||||||
|
|
||||||
if (result := event.get_result()) is None:
|
if (result := event.get_result()) is None:
|
||||||
return False
|
return False
|
||||||
if self.only_llm_result and result.is_llm_result():
|
if self.only_llm_result and not result.is_llm_result():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if event.get_platform_name() in [
|
if event.get_platform_name() in [
|
||||||
@@ -158,7 +158,11 @@ class RespondStage(Stage):
|
|||||||
result = event.get_result()
|
result = event.get_result()
|
||||||
if result is None:
|
if result is None:
|
||||||
return
|
return
|
||||||
|
if event.get_extra("_streaming_finished", False):
|
||||||
|
# prevent some plugin make result content type to LLM_RESULT after streaming finished, lead to send again
|
||||||
|
return
|
||||||
if result.result_content_type == ResultContentType.STREAMING_FINISH:
|
if result.result_content_type == ResultContentType.STREAMING_FINISH:
|
||||||
|
event.set_extra("_streaming_finished", True)
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import random
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
@@ -42,6 +43,18 @@ class ResultDecorateStage(Stage):
|
|||||||
"forward_threshold"
|
"forward_threshold"
|
||||||
]
|
]
|
||||||
|
|
||||||
|
trigger_probability = ctx.astrbot_config["provider_tts_settings"].get(
|
||||||
|
"trigger_probability",
|
||||||
|
1,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
self.tts_trigger_probability = max(
|
||||||
|
0.0,
|
||||||
|
min(float(trigger_probability), 1.0),
|
||||||
|
)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
self.tts_trigger_probability = 1.0
|
||||||
|
|
||||||
# 分段回复
|
# 分段回复
|
||||||
self.words_count_threshold = int(
|
self.words_count_threshold = int(
|
||||||
ctx.astrbot_config["platform_settings"]["segmented_reply"][
|
ctx.astrbot_config["platform_settings"]["segmented_reply"][
|
||||||
@@ -246,7 +259,14 @@ class ResultDecorateStage(Stage):
|
|||||||
and result.is_llm_result()
|
and result.is_llm_result()
|
||||||
and SessionServiceManager.should_process_tts_request(event)
|
and SessionServiceManager.should_process_tts_request(event)
|
||||||
):
|
):
|
||||||
if not tts_provider:
|
should_tts = self.tts_trigger_probability >= 1.0 or (
|
||||||
|
self.tts_trigger_probability > 0.0
|
||||||
|
and random.random() <= self.tts_trigger_probability
|
||||||
|
)
|
||||||
|
|
||||||
|
if not should_tts:
|
||||||
|
logger.debug("跳过 TTS:触发概率未命中。")
|
||||||
|
elif not tts_provider:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。",
|
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -112,10 +112,6 @@ class PlatformManager:
|
|||||||
from .sources.satori.satori_adapter import (
|
from .sources.satori.satori_adapter import (
|
||||||
SatoriPlatformAdapter, # noqa: F401
|
SatoriPlatformAdapter, # noqa: F401
|
||||||
)
|
)
|
||||||
case "github_webhook":
|
|
||||||
from .sources.github_webhook.github_webhook_adapter import (
|
|
||||||
GitHubWebhookPlatformAdapter, # noqa: F401
|
|
||||||
)
|
|
||||||
except (ImportError, ModuleNotFoundError) as e:
|
except (ImportError, ModuleNotFoundError) as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
|
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
|
||||||
|
|||||||
@@ -1,315 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import hashlib
|
|
||||||
import hmac
|
|
||||||
from typing import Any, cast
|
|
||||||
|
|
||||||
from astrbot import logger
|
|
||||||
from astrbot.api.event import MessageChain
|
|
||||||
from astrbot.api.message_components import Plain
|
|
||||||
from astrbot.api.platform import (
|
|
||||||
AstrBotMessage,
|
|
||||||
MessageMember,
|
|
||||||
MessageType,
|
|
||||||
Platform,
|
|
||||||
PlatformMetadata,
|
|
||||||
)
|
|
||||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
|
||||||
from astrbot.core.platform.platform import PlatformStatus
|
|
||||||
from astrbot.core.utils.webhook_utils import log_webhook_info
|
|
||||||
|
|
||||||
from ...register import register_platform_adapter
|
|
||||||
from .github_webhook_event import GitHubWebhookMessageEvent
|
|
||||||
|
|
||||||
|
|
||||||
@register_platform_adapter(
|
|
||||||
"github_webhook",
|
|
||||||
"GitHub Webhook 适配器",
|
|
||||||
support_streaming_message=False,
|
|
||||||
)
|
|
||||||
class GitHubWebhookPlatformAdapter(Platform):
|
|
||||||
"""GitHub Webhook 平台适配器
|
|
||||||
|
|
||||||
支持的事件:
|
|
||||||
- issues (created)
|
|
||||||
- issue_comment (created)
|
|
||||||
- pull_request (opened)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
platform_config: dict,
|
|
||||||
platform_settings: dict,
|
|
||||||
event_queue: asyncio.Queue,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(platform_config, event_queue)
|
|
||||||
|
|
||||||
self.unified_webhook_mode = platform_config.get("unified_webhook_mode", True)
|
|
||||||
self.webhook_secret = platform_config.get("webhook_secret", "")
|
|
||||||
self.shutdown_event = asyncio.Event()
|
|
||||||
|
|
||||||
async def send_by_session(
|
|
||||||
self,
|
|
||||||
session: MessageSesion,
|
|
||||||
message_chain: MessageChain,
|
|
||||||
):
|
|
||||||
"""GitHub Webhook 是单向接收,不支持主动发送消息"""
|
|
||||||
logger.warning("GitHub Webhook 适配器不支持 send_by_session")
|
|
||||||
|
|
||||||
def meta(self) -> PlatformMetadata:
|
|
||||||
return PlatformMetadata(
|
|
||||||
name="github_webhook",
|
|
||||||
description="GitHub Webhook 适配器",
|
|
||||||
id=cast(str, self.config.get("id")),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def run(self):
|
|
||||||
"""运行适配器"""
|
|
||||||
self.status = PlatformStatus.RUNNING
|
|
||||||
|
|
||||||
# 如果启用统一 webhook 模式
|
|
||||||
webhook_uuid = self.config.get("webhook_uuid")
|
|
||||||
if self.unified_webhook_mode and webhook_uuid:
|
|
||||||
log_webhook_info(f"{self.meta().id}(GitHub Webhook)", webhook_uuid)
|
|
||||||
# 保持运行状态,等待 shutdown
|
|
||||||
await self.shutdown_event.wait()
|
|
||||||
else:
|
|
||||||
logger.warning("GitHub Webhook 适配器需要启用统一 webhook 模式")
|
|
||||||
await self.shutdown_event.wait()
|
|
||||||
|
|
||||||
async def webhook_callback(self, request: Any) -> Any:
|
|
||||||
"""统一 Webhook 回调入口
|
|
||||||
|
|
||||||
处理 GitHub webhook 事件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: Quart 请求对象
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
响应数据
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 获取事件类型
|
|
||||||
event_type = request.headers.get("X-GitHub-Event", "")
|
|
||||||
|
|
||||||
# 获取请求数据
|
|
||||||
payload = await request.json
|
|
||||||
|
|
||||||
# 验证 webhook 签名(如果配置了 secret)
|
|
||||||
if self.webhook_secret:
|
|
||||||
if not await self._verify_signature(request, payload):
|
|
||||||
logger.warning("GitHub webhook 签名验证失败")
|
|
||||||
return {"error": "Invalid signature"}, 401
|
|
||||||
|
|
||||||
logger.debug(f"收到 GitHub Webhook 事件: {event_type}")
|
|
||||||
|
|
||||||
# 处理不同类型的事件
|
|
||||||
if event_type == "issues":
|
|
||||||
await self._handle_issue_event(payload)
|
|
||||||
elif event_type == "issue_comment":
|
|
||||||
await self._handle_issue_comment_event(payload)
|
|
||||||
elif event_type == "pull_request":
|
|
||||||
await self._handle_pull_request_event(payload)
|
|
||||||
elif event_type == "ping":
|
|
||||||
# GitHub webhook 验证事件
|
|
||||||
return {"message": "pong"}
|
|
||||||
else:
|
|
||||||
logger.debug(f"忽略不支持的 GitHub 事件类型: {event_type}")
|
|
||||||
|
|
||||||
return {"status": "ok"}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"处理 GitHub webhook 回调时发生错误: {e}", exc_info=True)
|
|
||||||
return {"error": str(e)}, 500
|
|
||||||
|
|
||||||
async def _verify_signature(self, request: Any, payload: dict) -> bool:
|
|
||||||
"""验证 GitHub webhook 签名
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: Quart 请求对象
|
|
||||||
payload: 请求负载数据
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
签名是否有效
|
|
||||||
"""
|
|
||||||
signature_header = request.headers.get("X-Hub-Signature-256", "")
|
|
||||||
if not signature_header:
|
|
||||||
# 如果没有签名头,检查是否有旧版本的签名
|
|
||||||
signature_header = request.headers.get("X-Hub-Signature", "")
|
|
||||||
if not signature_header:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 获取原始请求体
|
|
||||||
body = await request.get_data()
|
|
||||||
|
|
||||||
# 计算 HMAC
|
|
||||||
if signature_header.startswith("sha256="):
|
|
||||||
expected_signature = hmac.new(
|
|
||||||
self.webhook_secret.encode("utf-8"),
|
|
||||||
body,
|
|
||||||
hashlib.sha256,
|
|
||||||
).hexdigest()
|
|
||||||
received_signature = signature_header.replace("sha256=", "")
|
|
||||||
elif signature_header.startswith("sha1="):
|
|
||||||
expected_signature = hmac.new(
|
|
||||||
self.webhook_secret.encode("utf-8"),
|
|
||||||
body,
|
|
||||||
hashlib.sha1,
|
|
||||||
).hexdigest()
|
|
||||||
received_signature = signature_header.replace("sha1=", "")
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 使用 hmac.compare_digest 防止时序攻击
|
|
||||||
return hmac.compare_digest(expected_signature, received_signature)
|
|
||||||
|
|
||||||
async def _handle_issue_event(self, payload: dict):
|
|
||||||
"""处理 issue 事件"""
|
|
||||||
action = payload.get("action", "")
|
|
||||||
|
|
||||||
# 只处理创建事件
|
|
||||||
if action != "created" and action != "opened":
|
|
||||||
return
|
|
||||||
|
|
||||||
issue = payload.get("issue", {})
|
|
||||||
repo = payload.get("repository", {})
|
|
||||||
sender = payload.get("sender", {})
|
|
||||||
|
|
||||||
# 构造消息文本
|
|
||||||
message_text = (
|
|
||||||
f"📝 新 Issue 创建\n"
|
|
||||||
f"仓库: {repo.get('full_name', 'unknown')}\n"
|
|
||||||
f"标题: {issue.get('title', 'No title')}\n"
|
|
||||||
f"作者: {sender.get('login', 'unknown')}\n"
|
|
||||||
f"链接: {issue.get('html_url', '')}\n"
|
|
||||||
f"内容:\n{issue.get('body', 'No description')[:200]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建 AstrBotMessage
|
|
||||||
abm = self._create_message(
|
|
||||||
message_text,
|
|
||||||
sender.get("login", "unknown"),
|
|
||||||
sender.get("login", "unknown"),
|
|
||||||
repo.get("full_name", "unknown"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 提交事件
|
|
||||||
self.commit_event(
|
|
||||||
GitHubWebhookMessageEvent(
|
|
||||||
message_text,
|
|
||||||
abm,
|
|
||||||
self.meta(),
|
|
||||||
repo.get("full_name", "unknown"),
|
|
||||||
"issues",
|
|
||||||
payload,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _handle_issue_comment_event(self, payload: dict):
|
|
||||||
"""处理 issue 评论事件"""
|
|
||||||
action = payload.get("action", "")
|
|
||||||
|
|
||||||
# 只处理创建事件
|
|
||||||
if action != "created":
|
|
||||||
return
|
|
||||||
|
|
||||||
issue = payload.get("issue", {})
|
|
||||||
comment = payload.get("comment", {})
|
|
||||||
repo = payload.get("repository", {})
|
|
||||||
sender = payload.get("sender", {})
|
|
||||||
|
|
||||||
# 构造消息文本
|
|
||||||
message_text = (
|
|
||||||
f"💬 新 Issue 评论\n"
|
|
||||||
f"仓库: {repo.get('full_name', 'unknown')}\n"
|
|
||||||
f"Issue: {issue.get('title', 'No title')}\n"
|
|
||||||
f"评论者: {sender.get('login', 'unknown')}\n"
|
|
||||||
f"链接: {comment.get('html_url', '')}\n"
|
|
||||||
f"内容:\n{comment.get('body', 'No comment')[:200]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建 AstrBotMessage
|
|
||||||
abm = self._create_message(
|
|
||||||
message_text,
|
|
||||||
sender.get("login", "unknown"),
|
|
||||||
sender.get("login", "unknown"),
|
|
||||||
repo.get("full_name", "unknown"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 提交事件
|
|
||||||
self.commit_event(
|
|
||||||
GitHubWebhookMessageEvent(
|
|
||||||
message_text,
|
|
||||||
abm,
|
|
||||||
self.meta(),
|
|
||||||
repo.get("full_name", "unknown"),
|
|
||||||
"issue_comment",
|
|
||||||
payload,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _handle_pull_request_event(self, payload: dict):
|
|
||||||
"""处理 pull request 事件"""
|
|
||||||
action = payload.get("action", "")
|
|
||||||
|
|
||||||
# 只处理打开事件
|
|
||||||
if action != "opened":
|
|
||||||
return
|
|
||||||
|
|
||||||
pr = payload.get("pull_request", {})
|
|
||||||
repo = payload.get("repository", {})
|
|
||||||
sender = payload.get("sender", {})
|
|
||||||
|
|
||||||
# 构造消息文本
|
|
||||||
message_text = (
|
|
||||||
f"🔀 新 Pull Request\n"
|
|
||||||
f"仓库: {repo.get('full_name', 'unknown')}\n"
|
|
||||||
f"标题: {pr.get('title', 'No title')}\n"
|
|
||||||
f"作者: {sender.get('login', 'unknown')}\n"
|
|
||||||
f"链接: {pr.get('html_url', '')}\n"
|
|
||||||
f"内容:\n{pr.get('body', 'No description')[:200]}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建 AstrBotMessage
|
|
||||||
abm = self._create_message(
|
|
||||||
message_text,
|
|
||||||
sender.get("login", "unknown"),
|
|
||||||
sender.get("login", "unknown"),
|
|
||||||
repo.get("full_name", "unknown"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# 提交事件
|
|
||||||
self.commit_event(
|
|
||||||
GitHubWebhookMessageEvent(
|
|
||||||
message_text,
|
|
||||||
abm,
|
|
||||||
self.meta(),
|
|
||||||
repo.get("full_name", "unknown"),
|
|
||||||
"pull_request",
|
|
||||||
payload,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def _create_message(
|
|
||||||
self,
|
|
||||||
message_text: str,
|
|
||||||
user_id: str,
|
|
||||||
nickname: str,
|
|
||||||
session_id: str,
|
|
||||||
) -> AstrBotMessage:
|
|
||||||
"""创建 AstrBotMessage 对象"""
|
|
||||||
abm = AstrBotMessage()
|
|
||||||
abm.type = MessageType.GROUP_MESSAGE
|
|
||||||
abm.self_id = self.client_self_id
|
|
||||||
abm.session_id = session_id
|
|
||||||
abm.message_id = ""
|
|
||||||
abm.sender = MessageMember(user_id=user_id, nickname=nickname)
|
|
||||||
abm.message = [Plain(message_text)]
|
|
||||||
abm.message_str = message_text
|
|
||||||
abm.raw_message = message_text
|
|
||||||
|
|
||||||
return abm
|
|
||||||
|
|
||||||
async def terminate(self):
|
|
||||||
"""终止适配器运行"""
|
|
||||||
self.shutdown_event.set()
|
|
||||||
logger.info("GitHub Webhook 适配器已经被优雅地关闭")
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
|
||||||
|
|
||||||
from ...astr_message_event import AstrMessageEvent
|
|
||||||
|
|
||||||
|
|
||||||
class GitHubWebhookMessageEvent(AstrMessageEvent):
|
|
||||||
"""GitHub Webhook 消息事件"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
message_str: str,
|
|
||||||
message_obj: AstrBotMessage,
|
|
||||||
platform_meta: PlatformMetadata,
|
|
||||||
session_id: str,
|
|
||||||
event_type: str,
|
|
||||||
event_data: dict,
|
|
||||||
):
|
|
||||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
|
||||||
self.event_type = event_type
|
|
||||||
"""GitHub 事件类型: issues, issue_comment, pull_request"""
|
|
||||||
self.event_data = event_data
|
|
||||||
"""原始事件数据"""
|
|
||||||
@@ -81,7 +81,12 @@ class LarkPlatformAdapter(Platform):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.lark_api = (
|
self.lark_api = (
|
||||||
lark.Client.builder().app_id(self.appid).app_secret(self.appsecret).build()
|
lark.Client.builder()
|
||||||
|
.app_id(self.appid)
|
||||||
|
.app_secret(self.appsecret)
|
||||||
|
.log_level(lark.LogLevel.ERROR)
|
||||||
|
.domain(self.domain)
|
||||||
|
.build()
|
||||||
)
|
)
|
||||||
|
|
||||||
self.webhook_server = None
|
self.webhook_server = None
|
||||||
|
|||||||
@@ -200,6 +200,15 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
if isinstance(chain, MessageChain):
|
if isinstance(chain, MessageChain):
|
||||||
if chain.type == "break":
|
if chain.type == "break":
|
||||||
# 分割符
|
# 分割符
|
||||||
|
if message_id:
|
||||||
|
try:
|
||||||
|
await self.client.edit_message_text(
|
||||||
|
text=delta,
|
||||||
|
chat_id=payload["chat_id"],
|
||||||
|
message_id=message_id,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"编辑消息失败(streaming-break): {e!s}")
|
||||||
message_id = None # 重置消息 ID
|
message_id = None # 重置消息 ID
|
||||||
delta = "" # 重置 delta
|
delta = "" # 重置 delta
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
import base64
|
import base64
|
||||||
|
import json
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from astrbot.api import logger
|
from astrbot.api import logger
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
from astrbot.api.message_components import File, Image, Plain, Record
|
from astrbot.api.message_components import File, Image, Json, Plain, Record
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
from .webchat_queue_mgr import webchat_queue_mgr
|
from .webchat_queue_mgr import webchat_queue_mgr
|
||||||
@@ -41,12 +42,20 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|||||||
await web_chat_back_queue.put(
|
await web_chat_back_queue.put(
|
||||||
{
|
{
|
||||||
"type": "plain",
|
"type": "plain",
|
||||||
"cid": cid,
|
|
||||||
"data": data,
|
"data": data,
|
||||||
"streaming": streaming,
|
"streaming": streaming,
|
||||||
"chain_type": message.type,
|
"chain_type": message.type,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
elif isinstance(comp, Json):
|
||||||
|
await web_chat_back_queue.put(
|
||||||
|
{
|
||||||
|
"type": "plain",
|
||||||
|
"data": json.dumps(comp.data, ensure_ascii=False),
|
||||||
|
"streaming": streaming,
|
||||||
|
"chain_type": message.type,
|
||||||
|
},
|
||||||
|
)
|
||||||
elif isinstance(comp, Image):
|
elif isinstance(comp, Image):
|
||||||
# save image to local
|
# save image to local
|
||||||
filename = f"{str(uuid.uuid4())}.jpg"
|
filename = f"{str(uuid.uuid4())}.jpg"
|
||||||
@@ -58,7 +67,6 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|||||||
await web_chat_back_queue.put(
|
await web_chat_back_queue.put(
|
||||||
{
|
{
|
||||||
"type": "image",
|
"type": "image",
|
||||||
"cid": cid,
|
|
||||||
"data": data,
|
"data": data,
|
||||||
"streaming": streaming,
|
"streaming": streaming,
|
||||||
},
|
},
|
||||||
@@ -74,7 +82,6 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|||||||
await web_chat_back_queue.put(
|
await web_chat_back_queue.put(
|
||||||
{
|
{
|
||||||
"type": "record",
|
"type": "record",
|
||||||
"cid": cid,
|
|
||||||
"data": data,
|
"data": data,
|
||||||
"streaming": streaming,
|
"streaming": streaming,
|
||||||
},
|
},
|
||||||
@@ -91,7 +98,6 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|||||||
await web_chat_back_queue.put(
|
await web_chat_back_queue.put(
|
||||||
{
|
{
|
||||||
"type": "file",
|
"type": "file",
|
||||||
"cid": cid,
|
|
||||||
"data": data,
|
"data": data,
|
||||||
"streaming": streaming,
|
"streaming": streaming,
|
||||||
},
|
},
|
||||||
@@ -111,18 +117,17 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|||||||
cid = self.session_id.split("!")[-1]
|
cid = self.session_id.split("!")[-1]
|
||||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
||||||
async for chain in generator:
|
async for chain in generator:
|
||||||
if chain.type == "break" and final_data:
|
# if chain.type == "break" and final_data:
|
||||||
# 分割符
|
# # 分割符
|
||||||
await web_chat_back_queue.put(
|
# await web_chat_back_queue.put(
|
||||||
{
|
# {
|
||||||
"type": "break", # break means a segment end
|
# "type": "break", # break means a segment end
|
||||||
"data": final_data,
|
# "data": final_data,
|
||||||
"streaming": True,
|
# "streaming": True,
|
||||||
"cid": cid,
|
# },
|
||||||
},
|
# )
|
||||||
)
|
# final_data = ""
|
||||||
final_data = ""
|
# continue
|
||||||
continue
|
|
||||||
|
|
||||||
r = await WebChatMessageEvent._send(
|
r = await WebChatMessageEvent._send(
|
||||||
chain,
|
chain,
|
||||||
@@ -142,7 +147,6 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|||||||
"data": final_data,
|
"data": final_data,
|
||||||
"reasoning": reasoning_content,
|
"reasoning": reasoning_content,
|
||||||
"streaming": True,
|
"streaming": True,
|
||||||
"cid": cid,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
await super().send_streaming(generator, use_fallback)
|
await super().send_streaming(generator, use_fallback)
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import enum
|
import enum
|
||||||
import json
|
import json
|
||||||
@@ -199,6 +201,38 @@ class ProviderRequest:
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TokenUsage:
|
||||||
|
input_other: int = 0
|
||||||
|
"""The number of input tokens, excluding cached tokens."""
|
||||||
|
input_cached: int = 0
|
||||||
|
"""The number of input cached tokens."""
|
||||||
|
output: int = 0
|
||||||
|
"""The number of output tokens."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def total(self) -> int:
|
||||||
|
return self.input_other + self.input_cached + self.output
|
||||||
|
|
||||||
|
@property
|
||||||
|
def input(self) -> int:
|
||||||
|
return self.input_other + self.input_cached
|
||||||
|
|
||||||
|
def __add__(self, other: TokenUsage) -> TokenUsage:
|
||||||
|
return TokenUsage(
|
||||||
|
input_other=self.input_other + other.input_other,
|
||||||
|
input_cached=self.input_cached + other.input_cached,
|
||||||
|
output=self.output + other.output,
|
||||||
|
)
|
||||||
|
|
||||||
|
def __sub__(self, other: TokenUsage) -> TokenUsage:
|
||||||
|
return TokenUsage(
|
||||||
|
input_other=self.input_other - other.input_other,
|
||||||
|
input_cached=self.input_cached - other.input_cached,
|
||||||
|
output=self.output - other.output,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LLMResponse:
|
class LLMResponse:
|
||||||
role: str
|
role: str
|
||||||
@@ -227,6 +261,11 @@ class LLMResponse:
|
|||||||
is_chunk: bool = False
|
is_chunk: bool = False
|
||||||
"""Indicates if the response is a chunked response."""
|
"""Indicates if the response is a chunked response."""
|
||||||
|
|
||||||
|
id: str | None = None
|
||||||
|
"""The ID of the response. For chunked responses, it's the ID of the chunk; for non-chunked responses, it's the ID of the response."""
|
||||||
|
usage: TokenUsage | None = None
|
||||||
|
"""The usage of the response. For chunked responses, it's the usage of the chunk; for non-chunked responses, it's the usage of the response."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
role: str,
|
role: str,
|
||||||
@@ -241,6 +280,8 @@ class LLMResponse:
|
|||||||
| AnthropicMessage
|
| AnthropicMessage
|
||||||
| None = None,
|
| None = None,
|
||||||
is_chunk: bool = False,
|
is_chunk: bool = False,
|
||||||
|
id: str | None = None,
|
||||||
|
usage: TokenUsage | None = None,
|
||||||
):
|
):
|
||||||
"""初始化 LLMResponse
|
"""初始化 LLMResponse
|
||||||
|
|
||||||
|
|||||||
@@ -6,10 +6,12 @@ from mimetypes import guess_type
|
|||||||
import anthropic
|
import anthropic
|
||||||
from anthropic import AsyncAnthropic
|
from anthropic import AsyncAnthropic
|
||||||
from anthropic.types import Message
|
from anthropic.types import Message
|
||||||
|
from anthropic.types.message_delta_usage import MessageDeltaUsage
|
||||||
|
from anthropic.types.usage import Usage
|
||||||
|
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.api.provider import Provider
|
from astrbot.api.provider import Provider
|
||||||
from astrbot.core.provider.entities import LLMResponse
|
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
||||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||||
from astrbot.core.utils.io import download_image_by_url
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
|
|
||||||
@@ -107,6 +109,22 @@ class ProviderAnthropic(Provider):
|
|||||||
|
|
||||||
return system_prompt, new_messages
|
return system_prompt, new_messages
|
||||||
|
|
||||||
|
def _extract_usage(self, usage: Usage) -> TokenUsage:
|
||||||
|
# https://docs.claude.com/en/docs/build-with-claude/prompt-caching#tracking-cache-performance
|
||||||
|
return TokenUsage(
|
||||||
|
input_other=usage.input_tokens or 0,
|
||||||
|
input_cached=usage.cache_read_input_tokens or 0,
|
||||||
|
output=usage.output_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _update_usage(self, token_usage: TokenUsage, usage: MessageDeltaUsage) -> None:
|
||||||
|
if usage.input_tokens is not None:
|
||||||
|
token_usage.input_other = usage.input_tokens
|
||||||
|
if usage.cache_read_input_tokens is not None:
|
||||||
|
token_usage.input_cached = usage.cache_read_input_tokens
|
||||||
|
if usage.output_tokens is not None:
|
||||||
|
token_usage.output = usage.output_tokens
|
||||||
|
|
||||||
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
||||||
if tools:
|
if tools:
|
||||||
if tool_list := tools.get_func_desc_anthropic_style():
|
if tool_list := tools.get_func_desc_anthropic_style():
|
||||||
@@ -131,6 +149,10 @@ class ProviderAnthropic(Provider):
|
|||||||
llm_response.tools_call_args.append(content_block.input)
|
llm_response.tools_call_args.append(content_block.input)
|
||||||
llm_response.tools_call_name.append(content_block.name)
|
llm_response.tools_call_name.append(content_block.name)
|
||||||
llm_response.tools_call_ids.append(content_block.id)
|
llm_response.tools_call_ids.append(content_block.id)
|
||||||
|
|
||||||
|
llm_response.id = completion.id
|
||||||
|
llm_response.usage = self._extract_usage(completion.usage)
|
||||||
|
|
||||||
# TODO(Soulter): 处理 end_turn 情况
|
# TODO(Soulter): 处理 end_turn 情况
|
||||||
if not llm_response.completion_text and not llm_response.tools_call_args:
|
if not llm_response.completion_text and not llm_response.tools_call_args:
|
||||||
raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}。")
|
raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}。")
|
||||||
@@ -152,9 +174,16 @@ class ProviderAnthropic(Provider):
|
|||||||
final_text = ""
|
final_text = ""
|
||||||
final_tool_calls = []
|
final_tool_calls = []
|
||||||
|
|
||||||
|
id = None
|
||||||
|
usage = TokenUsage()
|
||||||
|
|
||||||
async with self.client.messages.stream(**payloads) as stream:
|
async with self.client.messages.stream(**payloads) as stream:
|
||||||
assert isinstance(stream, anthropic.AsyncMessageStream)
|
assert isinstance(stream, anthropic.AsyncMessageStream)
|
||||||
async for event in stream:
|
async for event in stream:
|
||||||
|
if event.type == "message_start":
|
||||||
|
# the usage contains input token usage
|
||||||
|
id = event.message.id
|
||||||
|
usage = self._extract_usage(event.message.usage)
|
||||||
if event.type == "content_block_start":
|
if event.type == "content_block_start":
|
||||||
if event.content_block.type == "text":
|
if event.content_block.type == "text":
|
||||||
# 文本块开始
|
# 文本块开始
|
||||||
@@ -162,6 +191,8 @@ class ProviderAnthropic(Provider):
|
|||||||
role="assistant",
|
role="assistant",
|
||||||
completion_text="",
|
completion_text="",
|
||||||
is_chunk=True,
|
is_chunk=True,
|
||||||
|
usage=usage,
|
||||||
|
id=id,
|
||||||
)
|
)
|
||||||
elif event.content_block.type == "tool_use":
|
elif event.content_block.type == "tool_use":
|
||||||
# 工具使用块开始,初始化缓冲区
|
# 工具使用块开始,初始化缓冲区
|
||||||
@@ -179,6 +210,8 @@ class ProviderAnthropic(Provider):
|
|||||||
role="assistant",
|
role="assistant",
|
||||||
completion_text=event.delta.text,
|
completion_text=event.delta.text,
|
||||||
is_chunk=True,
|
is_chunk=True,
|
||||||
|
usage=usage,
|
||||||
|
id=id,
|
||||||
)
|
)
|
||||||
elif event.delta.type == "input_json_delta":
|
elif event.delta.type == "input_json_delta":
|
||||||
# 工具调用参数增量
|
# 工具调用参数增量
|
||||||
@@ -215,6 +248,8 @@ class ProviderAnthropic(Provider):
|
|||||||
tools_call_name=[tool_info["name"]],
|
tools_call_name=[tool_info["name"]],
|
||||||
tools_call_ids=[tool_info["id"]],
|
tools_call_ids=[tool_info["id"]],
|
||||||
is_chunk=True,
|
is_chunk=True,
|
||||||
|
usage=usage,
|
||||||
|
id=id,
|
||||||
)
|
)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# JSON 解析失败,跳过这个工具调用
|
# JSON 解析失败,跳过这个工具调用
|
||||||
@@ -223,11 +258,17 @@ class ProviderAnthropic(Provider):
|
|||||||
# 清理缓冲区
|
# 清理缓冲区
|
||||||
del tool_use_buffer[event.index]
|
del tool_use_buffer[event.index]
|
||||||
|
|
||||||
|
elif event.type == "message_delta":
|
||||||
|
if event.usage:
|
||||||
|
self._update_usage(usage, event.usage)
|
||||||
|
|
||||||
# 返回最终的完整结果
|
# 返回最终的完整结果
|
||||||
final_response = LLMResponse(
|
final_response = LLMResponse(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
completion_text=final_text,
|
completion_text=final_text,
|
||||||
is_chunk=False,
|
is_chunk=False,
|
||||||
|
usage=usage,
|
||||||
|
id=id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if final_tool_calls:
|
if final_tool_calls:
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import astrbot.core.message.components as Comp
|
|||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.api.provider import Provider
|
from astrbot.api.provider import Provider
|
||||||
from astrbot.core.message.message_event_result import MessageChain
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
from astrbot.core.provider.entities import LLMResponse
|
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
||||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||||
from astrbot.core.utils.io import download_image_by_url
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
|
|
||||||
@@ -138,7 +138,7 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
modalities = ["TEXT"]
|
modalities = ["TEXT"]
|
||||||
|
|
||||||
tool_list: list[types.Tool] | None = []
|
tool_list: list[types.Tool] | None = []
|
||||||
model_name = self.get_model()
|
model_name = payloads.get("model", self.get_model())
|
||||||
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
|
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
|
||||||
native_search = self.provider_config.get("gm_native_search", False)
|
native_search = self.provider_config.get("gm_native_search", False)
|
||||||
url_context = self.provider_config.get("gm_url_context", False)
|
url_context = self.provider_config.get("gm_url_context", False)
|
||||||
@@ -197,6 +197,37 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
types.Tool(function_declarations=func_desc["function_declarations"]),
|
types.Tool(function_declarations=func_desc["function_declarations"]),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# oper thinking config
|
||||||
|
thinking_config = None
|
||||||
|
if model_name.startswith("gemini-2.5"):
|
||||||
|
# The thinkingBudget parameter, introduced with the Gemini 2.5 series
|
||||||
|
thinking_budget = self.provider_config.get("gm_thinking_config", {}).get(
|
||||||
|
"budget", 0
|
||||||
|
)
|
||||||
|
if thinking_budget is not None:
|
||||||
|
thinking_config = types.ThinkingConfig(
|
||||||
|
thinking_budget=thinking_budget,
|
||||||
|
)
|
||||||
|
elif model_name.startswith("gemini-3"):
|
||||||
|
# The thinkingLevel parameter, recommended for Gemini 3 models and onwards
|
||||||
|
# Gemini 2.5 series models don't support thinkingLevel; use thinkingBudget instead.
|
||||||
|
thinking_level = self.provider_config.get("gm_thinking_config", {}).get(
|
||||||
|
"level", "HIGH"
|
||||||
|
)
|
||||||
|
if thinking_level and isinstance(thinking_level, str):
|
||||||
|
thinking_level = thinking_level.upper()
|
||||||
|
if thinking_level not in ["MINIMAL", "LOW", "MEDIUM", "HIGH"]:
|
||||||
|
logger.warning(
|
||||||
|
f"Invalid thinking level: {thinking_level}, using HIGH"
|
||||||
|
)
|
||||||
|
thinking_level = "HIGH"
|
||||||
|
level = types.ThinkingLevel(thinking_level)
|
||||||
|
thinking_config = types.ThinkingConfig()
|
||||||
|
if not hasattr(types.ThinkingConfig, "thinking_level"):
|
||||||
|
setattr(types.ThinkingConfig, "thinking_level", level)
|
||||||
|
else:
|
||||||
|
thinking_config.thinking_level = level
|
||||||
|
|
||||||
return types.GenerateContentConfig(
|
return types.GenerateContentConfig(
|
||||||
system_instruction=system_instruction,
|
system_instruction=system_instruction,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
@@ -216,22 +247,7 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
response_modalities=modalities,
|
response_modalities=modalities,
|
||||||
tools=cast(types.ToolListUnion | None, tool_list),
|
tools=cast(types.ToolListUnion | None, tool_list),
|
||||||
safety_settings=self.safety_settings if self.safety_settings else None,
|
safety_settings=self.safety_settings if self.safety_settings else None,
|
||||||
thinking_config=(
|
thinking_config=thinking_config,
|
||||||
types.ThinkingConfig(
|
|
||||||
thinking_budget=min(
|
|
||||||
int(
|
|
||||||
self.provider_config.get("gm_thinking_config", {}).get(
|
|
||||||
"budget",
|
|
||||||
0,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
24576,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
if "gemini-2.5-flash" in self.get_model()
|
|
||||||
and hasattr(types.ThinkingConfig, "thinking_budget")
|
|
||||||
else None
|
|
||||||
),
|
|
||||||
automatic_function_calling=types.AutomaticFunctionCallingConfig(
|
automatic_function_calling=types.AutomaticFunctionCallingConfig(
|
||||||
disable=True,
|
disable=True,
|
||||||
),
|
),
|
||||||
@@ -347,6 +363,16 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
]
|
]
|
||||||
return "".join(thought_buf).strip()
|
return "".join(thought_buf).strip()
|
||||||
|
|
||||||
|
def _extract_usage(
|
||||||
|
self, usage_metadata: types.GenerateContentResponseUsageMetadata
|
||||||
|
) -> TokenUsage:
|
||||||
|
"""Extract usage from candidate"""
|
||||||
|
return TokenUsage(
|
||||||
|
input_other=usage_metadata.prompt_token_count or 0,
|
||||||
|
input_cached=usage_metadata.cached_content_token_count or 0,
|
||||||
|
output=usage_metadata.candidates_token_count or 0,
|
||||||
|
)
|
||||||
|
|
||||||
def _process_content_parts(
|
def _process_content_parts(
|
||||||
self,
|
self,
|
||||||
candidate: types.Candidate,
|
candidate: types.Candidate,
|
||||||
@@ -431,6 +457,8 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
model = payloads.get("model", self.get_model())
|
||||||
|
|
||||||
modalities = ["TEXT"]
|
modalities = ["TEXT"]
|
||||||
if self.provider_config.get("gm_resp_image_modal", False):
|
if self.provider_config.get("gm_resp_image_modal", False):
|
||||||
modalities.append("IMAGE")
|
modalities.append("IMAGE")
|
||||||
@@ -449,7 +477,7 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
temperature,
|
temperature,
|
||||||
)
|
)
|
||||||
result = await self.client.models.generate_content(
|
result = await self.client.models.generate_content(
|
||||||
model=self.get_model(),
|
model=model,
|
||||||
contents=cast(types.ContentListUnion, conversation),
|
contents=cast(types.ContentListUnion, conversation),
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
@@ -475,11 +503,11 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
e.message = ""
|
e.message = ""
|
||||||
if "Developer instruction is not enabled" in e.message:
|
if "Developer instruction is not enabled" in e.message:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)",
|
f"{model} 不支持 system prompt,已自动去除(影响人格设置)",
|
||||||
)
|
)
|
||||||
system_instruction = None
|
system_instruction = None
|
||||||
elif "Function calling is not enabled" in e.message:
|
elif "Function calling is not enabled" in e.message:
|
||||||
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
|
logger.warning(f"{model} 不支持函数调用,已自动去除")
|
||||||
tools = None
|
tools = None
|
||||||
elif (
|
elif (
|
||||||
"Multi-modal output is not supported" in e.message
|
"Multi-modal output is not supported" in e.message
|
||||||
@@ -488,7 +516,7 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
or "only supports text output" in e.message
|
or "only supports text output" in e.message
|
||||||
):
|
):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"{self.get_model()} 不支持多模态输出,降级为文本模态",
|
f"{model} 不支持多模态输出,降级为文本模态",
|
||||||
)
|
)
|
||||||
modalities = ["TEXT"]
|
modalities = ["TEXT"]
|
||||||
else:
|
else:
|
||||||
@@ -501,6 +529,9 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
result.candidates[0],
|
result.candidates[0],
|
||||||
llm_response,
|
llm_response,
|
||||||
)
|
)
|
||||||
|
llm_response.id = result.response_id
|
||||||
|
if result.usage_metadata:
|
||||||
|
llm_response.usage = self._extract_usage(result.usage_metadata)
|
||||||
return llm_response
|
return llm_response
|
||||||
|
|
||||||
async def _query_stream(
|
async def _query_stream(
|
||||||
@@ -513,7 +544,7 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
|
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
model = payloads.get("model", self.get_model())
|
||||||
conversation = self._prepare_conversation(payloads)
|
conversation = self._prepare_conversation(payloads)
|
||||||
|
|
||||||
result = None
|
result = None
|
||||||
@@ -525,7 +556,7 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
system_instruction,
|
system_instruction,
|
||||||
)
|
)
|
||||||
result = await self.client.models.generate_content_stream(
|
result = await self.client.models.generate_content_stream(
|
||||||
model=self.get_model(),
|
model=model,
|
||||||
contents=cast(types.ContentListUnion, conversation),
|
contents=cast(types.ContentListUnion, conversation),
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
@@ -535,11 +566,11 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
e.message = ""
|
e.message = ""
|
||||||
if "Developer instruction is not enabled" in e.message:
|
if "Developer instruction is not enabled" in e.message:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)",
|
f"{model} 不支持 system prompt,已自动去除(影响人格设置)",
|
||||||
)
|
)
|
||||||
system_instruction = None
|
system_instruction = None
|
||||||
elif "Function calling is not enabled" in e.message:
|
elif "Function calling is not enabled" in e.message:
|
||||||
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
|
logger.warning(f"{model} 不支持函数调用,已自动去除")
|
||||||
tools = None
|
tools = None
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
@@ -569,6 +600,9 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
chunk.candidates[0],
|
chunk.candidates[0],
|
||||||
llm_response,
|
llm_response,
|
||||||
)
|
)
|
||||||
|
llm_response.id = chunk.response_id
|
||||||
|
if chunk.usage_metadata:
|
||||||
|
llm_response.usage = self._extract_usage(chunk.usage_metadata)
|
||||||
yield llm_response
|
yield llm_response
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -596,6 +630,9 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
chunk.candidates[0],
|
chunk.candidates[0],
|
||||||
final_response,
|
final_response,
|
||||||
)
|
)
|
||||||
|
final_response.id = chunk.response_id
|
||||||
|
if chunk.usage_metadata:
|
||||||
|
final_response.usage = self._extract_usage(chunk.usage_metadata)
|
||||||
break
|
break
|
||||||
|
|
||||||
# Yield final complete response with accumulated text
|
# Yield final complete response with accumulated text
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from openai._exceptions import NotFoundError
|
|||||||
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
|
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
|
||||||
from openai.types.chat.chat_completion import ChatCompletion
|
from openai.types.chat.chat_completion import ChatCompletion
|
||||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||||
|
from openai.types.completion_usage import CompletionUsage
|
||||||
|
|
||||||
import astrbot.core.message.components as Comp
|
import astrbot.core.message.components as Comp
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
@@ -19,7 +20,7 @@ from astrbot.api.provider import Provider
|
|||||||
from astrbot.core.agent.message import Message
|
from astrbot.core.agent.message import Message
|
||||||
from astrbot.core.agent.tool import ToolSet
|
from astrbot.core.agent.tool import ToolSet
|
||||||
from astrbot.core.message.message_event_result import MessageChain
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult
|
||||||
from astrbot.core.utils.io import download_image_by_url
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
|
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
@@ -208,6 +209,7 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
# handle the content delta
|
# handle the content delta
|
||||||
reasoning = self._extract_reasoning_content(chunk)
|
reasoning = self._extract_reasoning_content(chunk)
|
||||||
_y = False
|
_y = False
|
||||||
|
llm_response.id = chunk.id
|
||||||
if reasoning:
|
if reasoning:
|
||||||
llm_response.reasoning_content = reasoning
|
llm_response.reasoning_content = reasoning
|
||||||
_y = True
|
_y = True
|
||||||
@@ -217,6 +219,8 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
chain=[Comp.Plain(completion_text)],
|
chain=[Comp.Plain(completion_text)],
|
||||||
)
|
)
|
||||||
_y = True
|
_y = True
|
||||||
|
if chunk.usage:
|
||||||
|
llm_response.usage = self._extract_usage(chunk.usage)
|
||||||
if _y:
|
if _y:
|
||||||
yield llm_response
|
yield llm_response
|
||||||
|
|
||||||
@@ -245,6 +249,15 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
reasoning_text = str(reasoning_attr)
|
reasoning_text = str(reasoning_attr)
|
||||||
return reasoning_text
|
return reasoning_text
|
||||||
|
|
||||||
|
def _extract_usage(self, usage: CompletionUsage) -> TokenUsage:
|
||||||
|
ptd = usage.prompt_tokens_details
|
||||||
|
cached = ptd.cached_tokens if ptd and ptd.cached_tokens else 0
|
||||||
|
return TokenUsage(
|
||||||
|
input_other=usage.prompt_tokens - cached,
|
||||||
|
input_cached=ptd.cached_tokens if ptd and ptd.cached_tokens else 0,
|
||||||
|
output=usage.completion_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
async def _parse_openai_completion(
|
async def _parse_openai_completion(
|
||||||
self, completion: ChatCompletion, tools: ToolSet | None
|
self, completion: ChatCompletion, tools: ToolSet | None
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
@@ -321,6 +334,10 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
raise Exception(f"API 返回的 completion 无法解析:{completion}。")
|
raise Exception(f"API 返回的 completion 无法解析:{completion}。")
|
||||||
|
|
||||||
llm_response.raw_completion = completion
|
llm_response.raw_completion = completion
|
||||||
|
llm_response.id = completion.id
|
||||||
|
|
||||||
|
if completion.usage:
|
||||||
|
llm_response.usage = self._extract_usage(completion.usage)
|
||||||
|
|
||||||
return llm_response
|
return llm_response
|
||||||
|
|
||||||
|
|||||||
@@ -2,15 +2,19 @@ from astrbot.core import html_renderer
|
|||||||
from astrbot.core.provider import Provider
|
from astrbot.core.provider import Provider
|
||||||
from astrbot.core.star.star_tools import StarTools
|
from astrbot.core.star.star_tools import StarTools
|
||||||
from astrbot.core.utils.command_parser import CommandParserMixin
|
from astrbot.core.utils.command_parser import CommandParserMixin
|
||||||
|
from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin
|
||||||
|
|
||||||
from .context import Context
|
from .context import Context
|
||||||
from .star import StarMetadata, star_map, star_registry
|
from .star import StarMetadata, star_map, star_registry
|
||||||
from .star_manager import PluginManager
|
from .star_manager import PluginManager
|
||||||
|
|
||||||
|
|
||||||
class Star(CommandParserMixin):
|
class Star(CommandParserMixin, PluginKVStoreMixin):
|
||||||
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
|
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
|
||||||
|
|
||||||
|
author: str
|
||||||
|
name: str
|
||||||
|
|
||||||
def __init__(self, context: Context, config: dict | None = None):
|
def __init__(self, context: Context, config: dict | None = None):
|
||||||
StarTools.initialize(context)
|
StarTools.initialize(context)
|
||||||
self.context = context
|
self.context = context
|
||||||
|
|||||||
@@ -0,0 +1,449 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from astrbot.core import db_helper
|
||||||
|
from astrbot.core.db.po import CommandConfig
|
||||||
|
from astrbot.core.star.filter.command import CommandFilter
|
||||||
|
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||||
|
from astrbot.core.star.filter.permission import PermissionType, PermissionTypeFilter
|
||||||
|
from astrbot.core.star.star import star_map
|
||||||
|
from astrbot.core.star.star_handler import StarHandlerMetadata, star_handlers_registry
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CommandDescriptor:
|
||||||
|
handler: StarHandlerMetadata = field(repr=False)
|
||||||
|
filter_ref: CommandFilter | CommandGroupFilter | None = field(
|
||||||
|
default=None,
|
||||||
|
repr=False,
|
||||||
|
)
|
||||||
|
handler_full_name: str = ""
|
||||||
|
handler_name: str = ""
|
||||||
|
plugin_name: str = ""
|
||||||
|
plugin_display_name: str | None = None
|
||||||
|
module_path: str = ""
|
||||||
|
description: str = ""
|
||||||
|
command_type: str = "command" # "command" | "group" | "sub_command"
|
||||||
|
raw_command_name: str | None = None
|
||||||
|
current_fragment: str | None = None
|
||||||
|
parent_signature: str = ""
|
||||||
|
parent_group_handler: str = ""
|
||||||
|
original_command: str | None = None
|
||||||
|
effective_command: str | None = None
|
||||||
|
aliases: list[str] = field(default_factory=list)
|
||||||
|
permission: str = "everyone"
|
||||||
|
enabled: bool = True
|
||||||
|
is_group: bool = False
|
||||||
|
is_sub_command: bool = False
|
||||||
|
reserved: bool = False
|
||||||
|
config: CommandConfig | None = None
|
||||||
|
has_conflict: bool = False
|
||||||
|
sub_commands: list[CommandDescriptor] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
async def sync_command_configs() -> None:
|
||||||
|
"""同步指令配置,清理过期配置。"""
|
||||||
|
descriptors = _collect_descriptors(include_sub_commands=False)
|
||||||
|
config_records = await db_helper.get_command_configs()
|
||||||
|
config_map = _bind_configs_to_descriptors(descriptors, config_records)
|
||||||
|
live_handlers = {desc.handler_full_name for desc in descriptors}
|
||||||
|
|
||||||
|
stale_configs = [key for key in config_map if key not in live_handlers]
|
||||||
|
if stale_configs:
|
||||||
|
await db_helper.delete_command_configs(stale_configs)
|
||||||
|
|
||||||
|
|
||||||
|
async def toggle_command(handler_full_name: str, enabled: bool) -> CommandDescriptor:
|
||||||
|
descriptor = _build_descriptor_by_full_name(handler_full_name)
|
||||||
|
if not descriptor:
|
||||||
|
raise ValueError("指定的处理函数不存在或不是指令。")
|
||||||
|
|
||||||
|
existing_cfg = await db_helper.get_command_config(handler_full_name)
|
||||||
|
config = await db_helper.upsert_command_config(
|
||||||
|
handler_full_name=handler_full_name,
|
||||||
|
plugin_name=descriptor.plugin_name or "",
|
||||||
|
module_path=descriptor.module_path,
|
||||||
|
original_command=descriptor.original_command or descriptor.handler_name,
|
||||||
|
resolved_command=(
|
||||||
|
existing_cfg.resolved_command
|
||||||
|
if existing_cfg
|
||||||
|
else descriptor.current_fragment
|
||||||
|
),
|
||||||
|
enabled=enabled,
|
||||||
|
keep_original_alias=False,
|
||||||
|
conflict_key=existing_cfg.conflict_key
|
||||||
|
if existing_cfg and existing_cfg.conflict_key
|
||||||
|
else descriptor.original_command,
|
||||||
|
resolution_strategy=existing_cfg.resolution_strategy if existing_cfg else None,
|
||||||
|
note=existing_cfg.note if existing_cfg else None,
|
||||||
|
extra_data=existing_cfg.extra_data if existing_cfg else None,
|
||||||
|
auto_managed=False,
|
||||||
|
)
|
||||||
|
_bind_descriptor_with_config(descriptor, config)
|
||||||
|
await sync_command_configs()
|
||||||
|
return descriptor
|
||||||
|
|
||||||
|
|
||||||
|
async def rename_command(
|
||||||
|
handler_full_name: str,
|
||||||
|
new_fragment: str,
|
||||||
|
) -> CommandDescriptor:
|
||||||
|
descriptor = _build_descriptor_by_full_name(handler_full_name)
|
||||||
|
if not descriptor:
|
||||||
|
raise ValueError("指定的处理函数不存在或不是指令。")
|
||||||
|
|
||||||
|
new_fragment = new_fragment.strip()
|
||||||
|
if not new_fragment:
|
||||||
|
raise ValueError("指令名不能为空。")
|
||||||
|
|
||||||
|
candidate_full = _compose_command(descriptor.parent_signature, new_fragment)
|
||||||
|
if _is_command_in_use(handler_full_name, candidate_full):
|
||||||
|
raise ValueError("新的指令名已被其他指令占用,请换一个名称。")
|
||||||
|
|
||||||
|
config = await db_helper.upsert_command_config(
|
||||||
|
handler_full_name=handler_full_name,
|
||||||
|
plugin_name=descriptor.plugin_name or "",
|
||||||
|
module_path=descriptor.module_path,
|
||||||
|
original_command=descriptor.original_command or descriptor.handler_name,
|
||||||
|
resolved_command=new_fragment,
|
||||||
|
enabled=True if descriptor.enabled else False,
|
||||||
|
keep_original_alias=False,
|
||||||
|
conflict_key=descriptor.original_command,
|
||||||
|
resolution_strategy="manual_rename",
|
||||||
|
note=None,
|
||||||
|
extra_data=None,
|
||||||
|
auto_managed=False,
|
||||||
|
)
|
||||||
|
_bind_descriptor_with_config(descriptor, config)
|
||||||
|
|
||||||
|
await sync_command_configs()
|
||||||
|
return descriptor
|
||||||
|
|
||||||
|
|
||||||
|
async def list_commands() -> list[dict[str, Any]]:
|
||||||
|
descriptors = _collect_descriptors(include_sub_commands=True)
|
||||||
|
config_records = await db_helper.get_command_configs()
|
||||||
|
_bind_configs_to_descriptors(descriptors, config_records)
|
||||||
|
|
||||||
|
conflict_groups = _group_conflicts(descriptors)
|
||||||
|
conflict_handler_names: set[str] = {
|
||||||
|
d.handler_full_name for group in conflict_groups.values() for d in group
|
||||||
|
}
|
||||||
|
|
||||||
|
# 分类,设置冲突标志,将子指令挂载到父指令组
|
||||||
|
group_map: dict[str, CommandDescriptor] = {}
|
||||||
|
sub_commands: list[CommandDescriptor] = []
|
||||||
|
root_commands: list[CommandDescriptor] = []
|
||||||
|
|
||||||
|
for desc in descriptors:
|
||||||
|
desc.has_conflict = desc.handler_full_name in conflict_handler_names
|
||||||
|
if desc.is_group:
|
||||||
|
group_map[desc.handler_full_name] = desc
|
||||||
|
elif desc.is_sub_command:
|
||||||
|
sub_commands.append(desc)
|
||||||
|
else:
|
||||||
|
root_commands.append(desc)
|
||||||
|
|
||||||
|
for sub in sub_commands:
|
||||||
|
if sub.parent_group_handler and sub.parent_group_handler in group_map:
|
||||||
|
group_map[sub.parent_group_handler].sub_commands.append(sub)
|
||||||
|
else:
|
||||||
|
root_commands.append(sub)
|
||||||
|
|
||||||
|
# 指令组 + 普通指令,按 effective_command 字母排序
|
||||||
|
all_commands = list(group_map.values()) + root_commands
|
||||||
|
all_commands.sort(key=lambda d: (d.effective_command or "").lower())
|
||||||
|
|
||||||
|
result = [_descriptor_to_dict(desc) for desc in all_commands]
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
async def list_command_conflicts() -> list[dict[str, Any]]:
|
||||||
|
"""列出所有冲突的指令组。"""
|
||||||
|
descriptors = _collect_descriptors(include_sub_commands=False)
|
||||||
|
config_records = await db_helper.get_command_configs()
|
||||||
|
_bind_configs_to_descriptors(descriptors, config_records)
|
||||||
|
|
||||||
|
conflict_groups = _group_conflicts(descriptors)
|
||||||
|
details = [
|
||||||
|
{
|
||||||
|
"conflict_key": key,
|
||||||
|
"handlers": [
|
||||||
|
{
|
||||||
|
"handler_full_name": item.handler_full_name,
|
||||||
|
"plugin": item.plugin_name,
|
||||||
|
"current_name": item.effective_command,
|
||||||
|
}
|
||||||
|
for item in group
|
||||||
|
],
|
||||||
|
}
|
||||||
|
for key, group in conflict_groups.items()
|
||||||
|
]
|
||||||
|
return details
|
||||||
|
|
||||||
|
|
||||||
|
# Internal helpers ----------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_descriptors(include_sub_commands: bool) -> list[CommandDescriptor]:
|
||||||
|
"""收集指令,按需包含子指令。"""
|
||||||
|
descriptors: list[CommandDescriptor] = []
|
||||||
|
for handler in star_handlers_registry:
|
||||||
|
desc = _build_descriptor(handler)
|
||||||
|
if not desc:
|
||||||
|
continue
|
||||||
|
if not include_sub_commands and desc.is_sub_command:
|
||||||
|
continue
|
||||||
|
descriptors.append(desc)
|
||||||
|
return descriptors
|
||||||
|
|
||||||
|
|
||||||
|
def _build_descriptor(handler: StarHandlerMetadata) -> CommandDescriptor | None:
|
||||||
|
filter_ref = _locate_primary_filter(handler)
|
||||||
|
if filter_ref is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
plugin_meta = star_map.get(handler.handler_module_path)
|
||||||
|
plugin_name = (
|
||||||
|
plugin_meta.name if plugin_meta else None
|
||||||
|
) or handler.handler_module_path
|
||||||
|
plugin_display = plugin_meta.display_name if plugin_meta else None
|
||||||
|
|
||||||
|
is_sub_command = bool(handler.extras_configs.get("sub_command"))
|
||||||
|
parent_group_handler = ""
|
||||||
|
|
||||||
|
if isinstance(filter_ref, CommandFilter):
|
||||||
|
raw_fragment = getattr(
|
||||||
|
filter_ref, "_original_command_name", filter_ref.command_name
|
||||||
|
)
|
||||||
|
current_fragment = filter_ref.command_name
|
||||||
|
parent_signature = (filter_ref.parent_command_names or [""])[0].strip()
|
||||||
|
# 如果是子指令,尝试找到父指令组的 handler_full_name
|
||||||
|
if is_sub_command and parent_signature:
|
||||||
|
parent_group_handler = _find_parent_group_handler(
|
||||||
|
handler.handler_module_path, parent_signature
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raw_fragment = getattr(
|
||||||
|
filter_ref, "_original_group_name", filter_ref.group_name
|
||||||
|
)
|
||||||
|
current_fragment = filter_ref.group_name
|
||||||
|
parent_signature = _resolve_group_parent_signature(filter_ref)
|
||||||
|
|
||||||
|
original_command = _compose_command(parent_signature, raw_fragment)
|
||||||
|
effective_command = _compose_command(parent_signature, current_fragment)
|
||||||
|
|
||||||
|
# 确定 command_type
|
||||||
|
if isinstance(filter_ref, CommandGroupFilter):
|
||||||
|
command_type = "group"
|
||||||
|
elif is_sub_command:
|
||||||
|
command_type = "sub_command"
|
||||||
|
else:
|
||||||
|
command_type = "command"
|
||||||
|
|
||||||
|
descriptor = CommandDescriptor(
|
||||||
|
handler=handler,
|
||||||
|
filter_ref=filter_ref,
|
||||||
|
handler_full_name=handler.handler_full_name,
|
||||||
|
handler_name=handler.handler_name,
|
||||||
|
plugin_name=plugin_name,
|
||||||
|
plugin_display_name=plugin_display,
|
||||||
|
module_path=handler.handler_module_path,
|
||||||
|
description=handler.desc or "",
|
||||||
|
command_type=command_type,
|
||||||
|
raw_command_name=raw_fragment,
|
||||||
|
current_fragment=current_fragment,
|
||||||
|
parent_signature=parent_signature,
|
||||||
|
parent_group_handler=parent_group_handler,
|
||||||
|
original_command=original_command,
|
||||||
|
effective_command=effective_command,
|
||||||
|
aliases=sorted(getattr(filter_ref, "alias", set())),
|
||||||
|
permission=_determine_permission(handler),
|
||||||
|
enabled=handler.enabled,
|
||||||
|
is_group=isinstance(filter_ref, CommandGroupFilter),
|
||||||
|
is_sub_command=is_sub_command,
|
||||||
|
reserved=plugin_meta.reserved if plugin_meta else False,
|
||||||
|
)
|
||||||
|
return descriptor
|
||||||
|
|
||||||
|
|
||||||
|
def _build_descriptor_by_full_name(full_name: str) -> CommandDescriptor | None:
|
||||||
|
handler = star_handlers_registry.get_handler_by_full_name(full_name)
|
||||||
|
if not handler:
|
||||||
|
return None
|
||||||
|
return _build_descriptor(handler)
|
||||||
|
|
||||||
|
|
||||||
|
def _locate_primary_filter(
|
||||||
|
handler: StarHandlerMetadata,
|
||||||
|
) -> CommandFilter | CommandGroupFilter | None:
|
||||||
|
for filter_ref in handler.event_filters:
|
||||||
|
if isinstance(filter_ref, (CommandFilter, CommandGroupFilter)):
|
||||||
|
return filter_ref
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _determine_permission(handler: StarHandlerMetadata) -> str:
|
||||||
|
for filter_ref in handler.event_filters:
|
||||||
|
if isinstance(filter_ref, PermissionTypeFilter):
|
||||||
|
return (
|
||||||
|
"admin"
|
||||||
|
if filter_ref.permission_type == PermissionType.ADMIN
|
||||||
|
else "member"
|
||||||
|
)
|
||||||
|
return "everyone"
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_group_parent_signature(group_filter: CommandGroupFilter) -> str:
|
||||||
|
signatures: list[str] = []
|
||||||
|
parent = group_filter.parent_group
|
||||||
|
while parent:
|
||||||
|
signatures.append(getattr(parent, "_original_group_name", parent.group_name))
|
||||||
|
parent = parent.parent_group
|
||||||
|
return " ".join(reversed(signatures)).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _find_parent_group_handler(module_path: str, parent_signature: str) -> str:
|
||||||
|
"""根据模块路径和父级签名,找到对应的指令组 handler_full_name。"""
|
||||||
|
parent_sig_normalized = parent_signature.strip()
|
||||||
|
for handler in star_handlers_registry:
|
||||||
|
if handler.handler_module_path != module_path:
|
||||||
|
continue
|
||||||
|
filter_ref = _locate_primary_filter(handler)
|
||||||
|
if not isinstance(filter_ref, CommandGroupFilter):
|
||||||
|
continue
|
||||||
|
# 检查该指令组的完整指令名是否匹配 parent_signature
|
||||||
|
group_names = filter_ref.get_complete_command_names()
|
||||||
|
if parent_sig_normalized in group_names:
|
||||||
|
return handler.handler_full_name
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _compose_command(parent_signature: str, fragment: str | None) -> str:
|
||||||
|
fragment = (fragment or "").strip()
|
||||||
|
parent_signature = parent_signature.strip()
|
||||||
|
if not parent_signature:
|
||||||
|
return fragment
|
||||||
|
if not fragment:
|
||||||
|
return parent_signature
|
||||||
|
return f"{parent_signature} {fragment}"
|
||||||
|
|
||||||
|
|
||||||
|
def _bind_descriptor_with_config(
|
||||||
|
descriptor: CommandDescriptor,
|
||||||
|
config: CommandConfig,
|
||||||
|
) -> None:
|
||||||
|
_apply_config_to_descriptor(descriptor, config)
|
||||||
|
_apply_config_to_runtime(descriptor, config)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_config_to_descriptor(
|
||||||
|
descriptor: CommandDescriptor,
|
||||||
|
config: CommandConfig,
|
||||||
|
) -> None:
|
||||||
|
descriptor.config = config
|
||||||
|
descriptor.enabled = config.enabled
|
||||||
|
|
||||||
|
if config.original_command:
|
||||||
|
descriptor.original_command = config.original_command
|
||||||
|
|
||||||
|
new_fragment = config.resolved_command or descriptor.current_fragment
|
||||||
|
descriptor.current_fragment = new_fragment
|
||||||
|
descriptor.effective_command = _compose_command(
|
||||||
|
descriptor.parent_signature,
|
||||||
|
new_fragment,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_config_to_runtime(
|
||||||
|
descriptor: CommandDescriptor,
|
||||||
|
config: CommandConfig,
|
||||||
|
) -> None:
|
||||||
|
descriptor.handler.enabled = config.enabled
|
||||||
|
if descriptor.filter_ref and descriptor.current_fragment:
|
||||||
|
_set_filter_fragment(descriptor.filter_ref, descriptor.current_fragment)
|
||||||
|
|
||||||
|
|
||||||
|
def _bind_configs_to_descriptors(
|
||||||
|
descriptors: list[CommandDescriptor],
|
||||||
|
config_records: list[CommandConfig],
|
||||||
|
) -> dict[str, CommandConfig]:
|
||||||
|
config_map = {cfg.handler_full_name: cfg for cfg in config_records}
|
||||||
|
for desc in descriptors:
|
||||||
|
if cfg := config_map.get(desc.handler_full_name):
|
||||||
|
_bind_descriptor_with_config(desc, cfg)
|
||||||
|
return config_map
|
||||||
|
|
||||||
|
|
||||||
|
def _group_conflicts(
|
||||||
|
descriptors: list[CommandDescriptor],
|
||||||
|
) -> dict[str, list[CommandDescriptor]]:
|
||||||
|
conflicts: dict[str, list[CommandDescriptor]] = defaultdict(list)
|
||||||
|
for desc in descriptors:
|
||||||
|
if desc.effective_command and desc.enabled:
|
||||||
|
conflicts[desc.effective_command].append(desc)
|
||||||
|
return {k: v for k, v in conflicts.items() if len(v) > 1}
|
||||||
|
|
||||||
|
|
||||||
|
def _set_filter_fragment(
|
||||||
|
filter_ref: CommandFilter | CommandGroupFilter,
|
||||||
|
fragment: str,
|
||||||
|
) -> None:
|
||||||
|
attr = (
|
||||||
|
"group_name" if isinstance(filter_ref, CommandGroupFilter) else "command_name"
|
||||||
|
)
|
||||||
|
current_value = getattr(filter_ref, attr)
|
||||||
|
if fragment == current_value:
|
||||||
|
return
|
||||||
|
setattr(filter_ref, attr, fragment)
|
||||||
|
if hasattr(filter_ref, "_cmpl_cmd_names"):
|
||||||
|
filter_ref._cmpl_cmd_names = None
|
||||||
|
|
||||||
|
|
||||||
|
def _is_command_in_use(
|
||||||
|
target_handler_full_name: str,
|
||||||
|
candidate_full_command: str,
|
||||||
|
) -> bool:
|
||||||
|
candidate = candidate_full_command.strip()
|
||||||
|
for handler in star_handlers_registry:
|
||||||
|
if handler.handler_full_name == target_handler_full_name:
|
||||||
|
continue
|
||||||
|
filter_ref = _locate_primary_filter(handler)
|
||||||
|
if not filter_ref:
|
||||||
|
continue
|
||||||
|
names = {name.strip() for name in filter_ref.get_complete_command_names()}
|
||||||
|
if candidate in names:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _descriptor_to_dict(desc: CommandDescriptor) -> dict[str, Any]:
|
||||||
|
result = {
|
||||||
|
"handler_full_name": desc.handler_full_name,
|
||||||
|
"handler_name": desc.handler_name,
|
||||||
|
"plugin": desc.plugin_name,
|
||||||
|
"plugin_display_name": desc.plugin_display_name,
|
||||||
|
"module_path": desc.module_path,
|
||||||
|
"description": desc.description,
|
||||||
|
"type": desc.command_type,
|
||||||
|
"parent_signature": desc.parent_signature,
|
||||||
|
"parent_group_handler": desc.parent_group_handler,
|
||||||
|
"original_command": desc.original_command,
|
||||||
|
"current_fragment": desc.current_fragment,
|
||||||
|
"effective_command": desc.effective_command,
|
||||||
|
"aliases": desc.aliases,
|
||||||
|
"permission": desc.permission,
|
||||||
|
"enabled": desc.enabled,
|
||||||
|
"is_group": desc.is_group,
|
||||||
|
"has_conflict": desc.has_conflict,
|
||||||
|
"reserved": desc.reserved,
|
||||||
|
}
|
||||||
|
# 如果是指令组,包含子指令列表
|
||||||
|
if desc.is_group and desc.sub_commands:
|
||||||
|
result["sub_commands"] = [_descriptor_to_dict(sub) for sub in desc.sub_commands]
|
||||||
|
else:
|
||||||
|
result["sub_commands"] = []
|
||||||
|
return result
|
||||||
@@ -296,6 +296,10 @@ class Context:
|
|||||||
provider_type=ProviderType.CHAT_COMPLETION,
|
provider_type=ProviderType.CHAT_COMPLETION,
|
||||||
umo=umo,
|
umo=umo,
|
||||||
)
|
)
|
||||||
|
if prov is None:
|
||||||
|
raise ProviderNotFoundError(
|
||||||
|
"provider not found, please choose provider first"
|
||||||
|
)
|
||||||
if not isinstance(prov, Provider):
|
if not isinstance(prov, Provider):
|
||||||
raise ValueError("返回的 Provider 不是 Provider 类型")
|
raise ValueError("返回的 Provider 不是 Provider 类型")
|
||||||
return prov
|
return prov
|
||||||
|
|||||||
@@ -40,6 +40,7 @@ class CommandFilter(HandlerFilter):
|
|||||||
):
|
):
|
||||||
self.command_name = command_name
|
self.command_name = command_name
|
||||||
self.alias = alias if alias else set()
|
self.alias = alias if alias else set()
|
||||||
|
self._original_command_name = command_name
|
||||||
self.parent_command_names = (
|
self.parent_command_names = (
|
||||||
parent_command_names if parent_command_names is not None else [""]
|
parent_command_names if parent_command_names is not None else [""]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ class CommandGroupFilter(HandlerFilter):
|
|||||||
):
|
):
|
||||||
self.group_name = group_name
|
self.group_name = group_name
|
||||||
self.alias = alias if alias else set()
|
self.alias = alias if alias else set()
|
||||||
|
self._original_group_name = group_name
|
||||||
self.sub_command_filters: list[CommandFilter | CommandGroupFilter] = []
|
self.sub_command_filters: list[CommandFilter | CommandGroupFilter] = []
|
||||||
self.custom_filter_list: list[CustomFilter] = []
|
self.custom_filter_list: list[CustomFilter] = []
|
||||||
self.parent_group = parent_group
|
self.parent_group = parent_group
|
||||||
|
|||||||
@@ -118,6 +118,8 @@ class StarHandlerRegistry(Generic[T]):
|
|||||||
# 过滤事件类型
|
# 过滤事件类型
|
||||||
if handler.event_type != event_type:
|
if handler.event_type != event_type:
|
||||||
continue
|
continue
|
||||||
|
if not handler.enabled:
|
||||||
|
continue
|
||||||
# 过滤启用状态
|
# 过滤启用状态
|
||||||
if only_activated:
|
if only_activated:
|
||||||
plugin = star_map.get(handler.handler_module_path)
|
plugin = star_map.get(handler.handler_module_path)
|
||||||
@@ -220,6 +222,8 @@ class StarHandlerMetadata(Generic[H]):
|
|||||||
extras_configs: dict = field(default_factory=dict)
|
extras_configs: dict = field(default_factory=dict)
|
||||||
"""插件注册的一些其他的信息, 如 priority 等"""
|
"""插件注册的一些其他的信息, 如 priority 等"""
|
||||||
|
|
||||||
|
enabled: bool = True
|
||||||
|
|
||||||
def __lt__(self, other: StarHandlerMetadata):
|
def __lt__(self, other: StarHandlerMetadata):
|
||||||
"""定义小于运算符以支持优先队列"""
|
"""定义小于运算符以支持优先队列"""
|
||||||
return self.extras_configs.get("priority", 0) < other.extras_configs.get(
|
return self.extras_configs.get("priority", 0) < other.extras_configs.get(
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from astrbot.core.utils.astrbot_path import (
|
|||||||
from astrbot.core.utils.io import remove_dir
|
from astrbot.core.utils.io import remove_dir
|
||||||
|
|
||||||
from . import StarMetadata
|
from . import StarMetadata
|
||||||
|
from .command_management import sync_command_configs
|
||||||
from .context import Context
|
from .context import Context
|
||||||
from .filter.permission import PermissionType, PermissionTypeFilter
|
from .filter.permission import PermissionType, PermissionTypeFilter
|
||||||
from .star import star_map, star_registry
|
from .star import star_map, star_registry
|
||||||
@@ -467,6 +468,18 @@ class PluginManager:
|
|||||||
metadata.star_cls = metadata.star_cls_type(
|
metadata.star_cls = metadata.star_cls_type(
|
||||||
context=self.context,
|
context=self.context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
p_name = (metadata.name or "unknown").lower().replace("/", "_")
|
||||||
|
p_author = (
|
||||||
|
(metadata.author or "unknown").lower().replace("/", "_")
|
||||||
|
)
|
||||||
|
setattr(metadata.star_cls, "name", p_name)
|
||||||
|
setattr(metadata.star_cls, "author", p_author)
|
||||||
|
setattr(
|
||||||
|
metadata.star_cls,
|
||||||
|
"plugin_id",
|
||||||
|
f"{p_author}/{p_name}",
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
logger.info(f"插件 {metadata.name} 已被禁用。")
|
logger.info(f"插件 {metadata.name} 已被禁用。")
|
||||||
|
|
||||||
@@ -618,6 +631,7 @@ class PluginManager:
|
|||||||
# 清除 pip.main 导致的多余的 logging handlers
|
# 清除 pip.main 导致的多余的 logging handlers
|
||||||
for handler in logging.root.handlers[:]:
|
for handler in logging.root.handlers[:]:
|
||||||
logging.root.removeHandler(handler)
|
logging.root.removeHandler(handler)
|
||||||
|
await sync_command_configs()
|
||||||
|
|
||||||
if not fail_rec:
|
if not fail_rec:
|
||||||
return True, None
|
return True, None
|
||||||
|
|||||||
@@ -0,0 +1,28 @@
|
|||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
|
from astrbot.core import sp
|
||||||
|
|
||||||
|
SUPPORTED_VALUE_TYPES = int | float | str | bytes | bool | dict | list | None
|
||||||
|
_VT = TypeVar("_VT")
|
||||||
|
|
||||||
|
|
||||||
|
class PluginKVStoreMixin:
|
||||||
|
"""为插件提供键值存储功能的 Mixin 类"""
|
||||||
|
|
||||||
|
plugin_id: str
|
||||||
|
|
||||||
|
async def put_kv_data(
|
||||||
|
self,
|
||||||
|
key: str,
|
||||||
|
value: SUPPORTED_VALUE_TYPES,
|
||||||
|
) -> None:
|
||||||
|
"""为指定插件存储一个键值对"""
|
||||||
|
await sp.put_async("plugin", self.plugin_id, key, value)
|
||||||
|
|
||||||
|
async def get_kv_data(self, key: str, default: _VT) -> _VT | None:
|
||||||
|
"""获取指定插件存储的键值对"""
|
||||||
|
return await sp.get_async("plugin", self.plugin_id, key, default)
|
||||||
|
|
||||||
|
async def delete_kv_data(self, key: str) -> None:
|
||||||
|
"""删除指定插件存储的键值对"""
|
||||||
|
await sp.remove_async("plugin", self.plugin_id, key)
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
from .auth import AuthRoute
|
from .auth import AuthRoute
|
||||||
from .chat import ChatRoute
|
from .chat import ChatRoute
|
||||||
|
from .command import CommandRoute
|
||||||
from .config import ConfigRoute
|
from .config import ConfigRoute
|
||||||
from .conversation import ConversationRoute
|
from .conversation import ConversationRoute
|
||||||
from .file import FileRoute
|
from .file import FileRoute
|
||||||
@@ -17,6 +18,7 @@ from .update import UpdateRoute
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"AuthRoute",
|
"AuthRoute",
|
||||||
"ChatRoute",
|
"ChatRoute",
|
||||||
|
"CommandRoute",
|
||||||
"ConfigRoute",
|
"ConfigRoute",
|
||||||
"ConversationRoute",
|
"ConversationRoute",
|
||||||
"FileRoute",
|
"FileRoute",
|
||||||
|
|||||||
@@ -227,16 +227,19 @@ class ChatRoute(Route):
|
|||||||
text: str,
|
text: str,
|
||||||
media_parts: list,
|
media_parts: list,
|
||||||
reasoning: str,
|
reasoning: str,
|
||||||
|
agent_stats: dict,
|
||||||
):
|
):
|
||||||
"""保存 bot 消息到历史记录,返回保存的记录"""
|
"""保存 bot 消息到历史记录,返回保存的记录"""
|
||||||
bot_message_parts = []
|
bot_message_parts = []
|
||||||
|
bot_message_parts.extend(media_parts)
|
||||||
if text:
|
if text:
|
||||||
bot_message_parts.append({"type": "plain", "text": text})
|
bot_message_parts.append({"type": "plain", "text": text})
|
||||||
bot_message_parts.extend(media_parts)
|
|
||||||
|
|
||||||
new_his = {"type": "bot", "message": bot_message_parts}
|
new_his = {"type": "bot", "message": bot_message_parts}
|
||||||
if reasoning:
|
if reasoning:
|
||||||
new_his["reasoning"] = reasoning
|
new_his["reasoning"] = reasoning
|
||||||
|
if agent_stats:
|
||||||
|
new_his["agent_stats"] = agent_stats
|
||||||
|
|
||||||
record = await self.platform_history_mgr.insert(
|
record = await self.platform_history_mgr.insert(
|
||||||
platform_id="webchat",
|
platform_id="webchat",
|
||||||
@@ -294,7 +297,8 @@ class ChatRoute(Route):
|
|||||||
accumulated_parts = []
|
accumulated_parts = []
|
||||||
accumulated_text = ""
|
accumulated_text = ""
|
||||||
accumulated_reasoning = ""
|
accumulated_reasoning = ""
|
||||||
|
tool_calls = {}
|
||||||
|
agent_stats = {}
|
||||||
try:
|
try:
|
||||||
async with track_conversation(self.running_convs, webchat_conv_id):
|
async with track_conversation(self.running_convs, webchat_conv_id):
|
||||||
while True:
|
while True:
|
||||||
@@ -314,6 +318,16 @@ class ChatRoute(Route):
|
|||||||
result_text = result["data"]
|
result_text = result["data"]
|
||||||
msg_type = result.get("type")
|
msg_type = result.get("type")
|
||||||
streaming = result.get("streaming", False)
|
streaming = result.get("streaming", False)
|
||||||
|
chain_type = result.get("chain_type")
|
||||||
|
|
||||||
|
if chain_type == "agent_stats":
|
||||||
|
stats_info = {
|
||||||
|
"type": "agent_stats",
|
||||||
|
"data": json.loads(result_text),
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(stats_info, ensure_ascii=False)}\n\n"
|
||||||
|
agent_stats = stats_info["data"]
|
||||||
|
continue
|
||||||
|
|
||||||
# 发送 SSE 数据
|
# 发送 SSE 数据
|
||||||
try:
|
try:
|
||||||
@@ -335,11 +349,35 @@ class ChatRoute(Route):
|
|||||||
|
|
||||||
# 累积消息部分
|
# 累积消息部分
|
||||||
if msg_type == "plain":
|
if msg_type == "plain":
|
||||||
chain_type = result.get("chain_type", "normal")
|
chain_type = result.get("chain_type")
|
||||||
if chain_type == "reasoning":
|
if chain_type == "tool_call":
|
||||||
|
tool_call = json.loads(result_text)
|
||||||
|
tool_calls[tool_call.get("id")] = tool_call
|
||||||
|
if accumulated_text:
|
||||||
|
# 如果累积了文本,则先保存文本
|
||||||
|
accumulated_parts.append(
|
||||||
|
{"type": "plain", "text": accumulated_text}
|
||||||
|
)
|
||||||
|
accumulated_text = ""
|
||||||
|
elif chain_type == "tool_call_result":
|
||||||
|
tcr = json.loads(result_text)
|
||||||
|
tc_id = tcr.get("id")
|
||||||
|
if tc_id in tool_calls:
|
||||||
|
tool_calls[tc_id]["result"] = tcr.get("result")
|
||||||
|
tool_calls[tc_id]["finished_ts"] = tcr.get("ts")
|
||||||
|
accumulated_parts.append(
|
||||||
|
{
|
||||||
|
"type": "tool_call",
|
||||||
|
"tool_calls": [tool_calls[tc_id]],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tool_calls.pop(tc_id, None)
|
||||||
|
elif chain_type == "reasoning":
|
||||||
accumulated_reasoning += result_text
|
accumulated_reasoning += result_text
|
||||||
else:
|
elif streaming:
|
||||||
accumulated_text += result_text
|
accumulated_text += result_text
|
||||||
|
else:
|
||||||
|
accumulated_text = result_text
|
||||||
elif msg_type == "image":
|
elif msg_type == "image":
|
||||||
filename = result_text.replace("[IMAGE]", "")
|
filename = result_text.replace("[IMAGE]", "")
|
||||||
part = await self._create_attachment_from_file(
|
part = await self._create_attachment_from_file(
|
||||||
@@ -367,15 +405,20 @@ class ChatRoute(Route):
|
|||||||
if msg_type == "end":
|
if msg_type == "end":
|
||||||
break
|
break
|
||||||
elif (
|
elif (
|
||||||
(streaming and msg_type == "complete")
|
(streaming and msg_type == "complete") or not streaming
|
||||||
or not streaming
|
# or msg_type == "break"
|
||||||
or msg_type == "break"
|
|
||||||
):
|
):
|
||||||
|
if (
|
||||||
|
chain_type == "tool_call"
|
||||||
|
or chain_type == "tool_call_result"
|
||||||
|
):
|
||||||
|
continue
|
||||||
saved_record = await self._save_bot_message(
|
saved_record = await self._save_bot_message(
|
||||||
webchat_conv_id,
|
webchat_conv_id,
|
||||||
accumulated_text,
|
accumulated_text,
|
||||||
accumulated_parts,
|
accumulated_parts,
|
||||||
accumulated_reasoning,
|
accumulated_reasoning,
|
||||||
|
agent_stats,
|
||||||
)
|
)
|
||||||
# 发送保存的消息信息给前端
|
# 发送保存的消息信息给前端
|
||||||
if saved_record and not client_disconnected:
|
if saved_record and not client_disconnected:
|
||||||
@@ -390,11 +433,11 @@ class ChatRoute(Route):
|
|||||||
yield f"data: {json.dumps(saved_info, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps(saved_info, ensure_ascii=False)}\n\n"
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
# 重置累积变量 (对于 break 后的下一段消息)
|
accumulated_parts = []
|
||||||
if msg_type == "break":
|
accumulated_text = ""
|
||||||
accumulated_parts = []
|
accumulated_reasoning = ""
|
||||||
accumulated_text = ""
|
tool_calls = {}
|
||||||
accumulated_reasoning = ""
|
agent_stats = {}
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
|
logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,82 @@
|
|||||||
|
from quart import request
|
||||||
|
|
||||||
|
from astrbot.core.star.command_management import (
|
||||||
|
list_command_conflicts,
|
||||||
|
list_commands,
|
||||||
|
)
|
||||||
|
from astrbot.core.star.command_management import (
|
||||||
|
rename_command as rename_command_service,
|
||||||
|
)
|
||||||
|
from astrbot.core.star.command_management import (
|
||||||
|
toggle_command as toggle_command_service,
|
||||||
|
)
|
||||||
|
|
||||||
|
from .route import Response, Route, RouteContext
|
||||||
|
|
||||||
|
|
||||||
|
class CommandRoute(Route):
|
||||||
|
def __init__(self, context: RouteContext) -> None:
|
||||||
|
super().__init__(context)
|
||||||
|
self.routes = {
|
||||||
|
"/commands": ("GET", self.get_commands),
|
||||||
|
"/commands/conflicts": ("GET", self.get_conflicts),
|
||||||
|
"/commands/toggle": ("POST", self.toggle_command),
|
||||||
|
"/commands/rename": ("POST", self.rename_command),
|
||||||
|
}
|
||||||
|
self.register_routes()
|
||||||
|
|
||||||
|
async def get_commands(self):
|
||||||
|
commands = await list_commands()
|
||||||
|
summary = {
|
||||||
|
"total": len(commands),
|
||||||
|
"disabled": len([cmd for cmd in commands if not cmd["enabled"]]),
|
||||||
|
"conflicts": len([cmd for cmd in commands if cmd.get("has_conflict")]),
|
||||||
|
}
|
||||||
|
return Response().ok({"items": commands, "summary": summary}).__dict__
|
||||||
|
|
||||||
|
async def get_conflicts(self):
|
||||||
|
conflicts = await list_command_conflicts()
|
||||||
|
return Response().ok(conflicts).__dict__
|
||||||
|
|
||||||
|
async def toggle_command(self):
|
||||||
|
data = await request.get_json()
|
||||||
|
handler_full_name = data.get("handler_full_name")
|
||||||
|
enabled = data.get("enabled")
|
||||||
|
|
||||||
|
if handler_full_name is None or enabled is None:
|
||||||
|
return Response().error("handler_full_name 与 enabled 均为必填。").__dict__
|
||||||
|
|
||||||
|
if isinstance(enabled, str):
|
||||||
|
enabled = enabled.lower() in ("1", "true", "yes", "on")
|
||||||
|
|
||||||
|
try:
|
||||||
|
await toggle_command_service(handler_full_name, bool(enabled))
|
||||||
|
except ValueError as exc:
|
||||||
|
return Response().error(str(exc)).__dict__
|
||||||
|
|
||||||
|
payload = await _get_command_payload(handler_full_name)
|
||||||
|
return Response().ok(payload).__dict__
|
||||||
|
|
||||||
|
async def rename_command(self):
|
||||||
|
data = await request.get_json()
|
||||||
|
handler_full_name = data.get("handler_full_name")
|
||||||
|
new_name = data.get("new_name")
|
||||||
|
|
||||||
|
if not handler_full_name or not new_name:
|
||||||
|
return Response().error("handler_full_name 与 new_name 均为必填。").__dict__
|
||||||
|
|
||||||
|
try:
|
||||||
|
await rename_command_service(handler_full_name, new_name)
|
||||||
|
except ValueError as exc:
|
||||||
|
return Response().error(str(exc)).__dict__
|
||||||
|
|
||||||
|
payload = await _get_command_payload(handler_full_name)
|
||||||
|
return Response().ok(payload).__dict__
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_command_payload(handler_full_name: str):
|
||||||
|
commands = await list_commands()
|
||||||
|
for cmd in commands:
|
||||||
|
if cmd["handler_full_name"] == handler_full_name:
|
||||||
|
return cmd
|
||||||
|
return {}
|
||||||
@@ -1,7 +1,9 @@
|
|||||||
import json
|
import json
|
||||||
import traceback
|
import traceback
|
||||||
|
from datetime import datetime
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
from quart import request
|
from quart import request, send_file
|
||||||
|
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||||
@@ -30,6 +32,7 @@ class ConversationRoute(Route):
|
|||||||
"POST",
|
"POST",
|
||||||
self.update_history,
|
self.update_history,
|
||||||
),
|
),
|
||||||
|
"/conversation/export": ("POST", self.export_conversations),
|
||||||
}
|
}
|
||||||
self.db_helper = db_helper
|
self.db_helper = db_helper
|
||||||
self.conv_mgr = core_lifecycle.conversation_manager
|
self.conv_mgr = core_lifecycle.conversation_manager
|
||||||
@@ -283,3 +286,90 @@ class ConversationRoute(Route):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"更新对话历史失败: {e!s}\n{traceback.format_exc()}")
|
logger.error(f"更新对话历史失败: {e!s}\n{traceback.format_exc()}")
|
||||||
return Response().error(f"更新对话历史失败: {e!s}").__dict__
|
return Response().error(f"更新对话历史失败: {e!s}").__dict__
|
||||||
|
|
||||||
|
async def export_conversations(self):
|
||||||
|
"""批量导出对话为 JSONL 格式"""
|
||||||
|
try:
|
||||||
|
data = await request.get_json()
|
||||||
|
conversations_to_export = data.get("conversations", [])
|
||||||
|
|
||||||
|
if not conversations_to_export:
|
||||||
|
return Response().error("导出列表不能为空").__dict__
|
||||||
|
|
||||||
|
# 收集所有对话的内容
|
||||||
|
jsonl_lines = []
|
||||||
|
exported_count = 0
|
||||||
|
failed_items = []
|
||||||
|
|
||||||
|
for conv_info in conversations_to_export:
|
||||||
|
user_id = conv_info.get("user_id")
|
||||||
|
cid = conv_info.get("cid")
|
||||||
|
|
||||||
|
if not user_id or not cid:
|
||||||
|
failed_items.append(
|
||||||
|
f"user_id:{user_id}, cid:{cid} - 缺少必要参数",
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
conversation = await self.conv_mgr.get_conversation(
|
||||||
|
unified_msg_origin=user_id,
|
||||||
|
conversation_id=cid,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not conversation:
|
||||||
|
failed_items.append(
|
||||||
|
f"user_id:{user_id}, cid:{cid} - 对话不存在"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 解析对话内容 (history is always a JSON string from _convert_conv_from_v2_to_v1)
|
||||||
|
content = json.loads(conversation.history)
|
||||||
|
|
||||||
|
# 创建导出记录
|
||||||
|
export_record = {
|
||||||
|
"cid": cid,
|
||||||
|
"user_id": user_id,
|
||||||
|
"platform_id": conversation.platform_id,
|
||||||
|
"title": conversation.title,
|
||||||
|
"persona_id": conversation.persona_id,
|
||||||
|
"created_at": conversation.created_at,
|
||||||
|
"updated_at": conversation.updated_at,
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 将记录转换为 JSON 字符串并添加到 JSONL
|
||||||
|
jsonl_lines.append(json.dumps(export_record, ensure_ascii=False))
|
||||||
|
exported_count += 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
failed_items.append(f"user_id:{user_id}, cid:{cid} - {e!s}")
|
||||||
|
logger.error(
|
||||||
|
f"导出对话失败: user_id={user_id}, cid={cid}, error={e!s}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if exported_count == 0:
|
||||||
|
return Response().error("没有成功导出任何对话").__dict__
|
||||||
|
|
||||||
|
# 创建 JSONL 内容
|
||||||
|
jsonl_content = "\n".join(jsonl_lines)
|
||||||
|
|
||||||
|
# 创建一个内存文件对象
|
||||||
|
file_obj = BytesIO(jsonl_content.encode("utf-8"))
|
||||||
|
file_obj.seek(0)
|
||||||
|
|
||||||
|
# 生成文件名
|
||||||
|
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
filename = f"astrbot_conversations_export_{timestamp}.jsonl"
|
||||||
|
|
||||||
|
# 返回文件流
|
||||||
|
return await send_file(
|
||||||
|
file_obj,
|
||||||
|
mimetype="application/jsonl",
|
||||||
|
as_attachment=True,
|
||||||
|
attachment_filename=filename,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"批量导出对话失败: {e!s}\n{traceback.format_exc()}")
|
||||||
|
return Response().error(f"批量导出对话失败: {e!s}").__dict__
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ class KnowledgeBaseRoute(Route):
|
|||||||
# 文档管理
|
# 文档管理
|
||||||
"/kb/document/list": ("GET", self.list_documents),
|
"/kb/document/list": ("GET", self.list_documents),
|
||||||
"/kb/document/upload": ("POST", self.upload_document),
|
"/kb/document/upload": ("POST", self.upload_document),
|
||||||
|
"/kb/document/import": ("POST", self.import_documents),
|
||||||
"/kb/document/upload/url": ("POST", self.upload_document_from_url),
|
"/kb/document/upload/url": ("POST", self.upload_document_from_url),
|
||||||
"/kb/document/upload/progress": ("GET", self.get_upload_progress),
|
"/kb/document/upload/progress": ("GET", self.get_upload_progress),
|
||||||
"/kb/document/get": ("GET", self.get_document),
|
"/kb/document/get": ("GET", self.get_document),
|
||||||
@@ -66,6 +67,65 @@ class KnowledgeBaseRoute(Route):
|
|||||||
def _get_kb_manager(self):
|
def _get_kb_manager(self):
|
||||||
return self.core_lifecycle.kb_manager
|
return self.core_lifecycle.kb_manager
|
||||||
|
|
||||||
|
def _init_task(self, task_id: str, status: str = "pending") -> None:
|
||||||
|
self.upload_tasks[task_id] = {
|
||||||
|
"status": status,
|
||||||
|
"result": None,
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _set_task_result(
|
||||||
|
self, task_id: str, status: str, result: any = None, error: str | None = None
|
||||||
|
) -> None:
|
||||||
|
self.upload_tasks[task_id] = {
|
||||||
|
"status": status,
|
||||||
|
"result": result,
|
||||||
|
"error": error,
|
||||||
|
}
|
||||||
|
if task_id in self.upload_progress:
|
||||||
|
self.upload_progress[task_id]["status"] = status
|
||||||
|
|
||||||
|
def _update_progress(
|
||||||
|
self,
|
||||||
|
task_id: str,
|
||||||
|
*,
|
||||||
|
status: str | None = None,
|
||||||
|
file_index: int | None = None,
|
||||||
|
file_name: str | None = None,
|
||||||
|
stage: str | None = None,
|
||||||
|
current: int | None = None,
|
||||||
|
total: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
if task_id not in self.upload_progress:
|
||||||
|
return
|
||||||
|
p = self.upload_progress[task_id]
|
||||||
|
if status is not None:
|
||||||
|
p["status"] = status
|
||||||
|
if file_index is not None:
|
||||||
|
p["file_index"] = file_index
|
||||||
|
if file_name is not None:
|
||||||
|
p["file_name"] = file_name
|
||||||
|
if stage is not None:
|
||||||
|
p["stage"] = stage
|
||||||
|
if current is not None:
|
||||||
|
p["current"] = current
|
||||||
|
if total is not None:
|
||||||
|
p["total"] = total
|
||||||
|
|
||||||
|
def _make_progress_callback(self, task_id: str, file_idx: int, file_name: str):
|
||||||
|
async def _callback(stage: str, current: int, total: int):
|
||||||
|
self._update_progress(
|
||||||
|
task_id,
|
||||||
|
status="processing",
|
||||||
|
file_index=file_idx,
|
||||||
|
file_name=file_name,
|
||||||
|
stage=stage,
|
||||||
|
current=current,
|
||||||
|
total=total,
|
||||||
|
)
|
||||||
|
|
||||||
|
return _callback
|
||||||
|
|
||||||
async def _background_upload_task(
|
async def _background_upload_task(
|
||||||
self,
|
self,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
@@ -80,11 +140,7 @@ class KnowledgeBaseRoute(Route):
|
|||||||
"""后台上传任务"""
|
"""后台上传任务"""
|
||||||
try:
|
try:
|
||||||
# 初始化任务状态
|
# 初始化任务状态
|
||||||
self.upload_tasks[task_id] = {
|
self._init_task(task_id, status="processing")
|
||||||
"status": "processing",
|
|
||||||
"result": None,
|
|
||||||
"error": None,
|
|
||||||
}
|
|
||||||
self.upload_progress[task_id] = {
|
self.upload_progress[task_id] = {
|
||||||
"status": "processing",
|
"status": "processing",
|
||||||
"file_index": 0,
|
"file_index": 0,
|
||||||
@@ -100,30 +156,20 @@ class KnowledgeBaseRoute(Route):
|
|||||||
for file_idx, file_info in enumerate(files_to_upload):
|
for file_idx, file_info in enumerate(files_to_upload):
|
||||||
try:
|
try:
|
||||||
# 更新整体进度
|
# 更新整体进度
|
||||||
self.upload_progress[task_id].update(
|
self._update_progress(
|
||||||
{
|
task_id,
|
||||||
"status": "processing",
|
status="processing",
|
||||||
"file_index": file_idx,
|
file_index=file_idx,
|
||||||
"file_name": file_info["file_name"],
|
file_name=file_info["file_name"],
|
||||||
"stage": "parsing",
|
stage="parsing",
|
||||||
"current": 0,
|
current=0,
|
||||||
"total": 100,
|
total=100,
|
||||||
},
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建进度回调函数
|
# 创建进度回调函数
|
||||||
async def progress_callback(stage, current, total):
|
progress_callback = self._make_progress_callback(
|
||||||
if task_id in self.upload_progress:
|
task_id, file_idx, file_info["file_name"]
|
||||||
self.upload_progress[task_id].update(
|
)
|
||||||
{
|
|
||||||
"status": "processing",
|
|
||||||
"file_index": file_idx,
|
|
||||||
"file_name": file_info["file_name"],
|
|
||||||
"stage": stage,
|
|
||||||
"current": current,
|
|
||||||
"total": total,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
doc = await kb_helper.upload_document(
|
doc = await kb_helper.upload_document(
|
||||||
file_name=file_info["file_name"],
|
file_name=file_info["file_name"],
|
||||||
@@ -154,23 +200,99 @@ class KnowledgeBaseRoute(Route):
|
|||||||
"failed_count": len(failed_docs),
|
"failed_count": len(failed_docs),
|
||||||
}
|
}
|
||||||
|
|
||||||
self.upload_tasks[task_id] = {
|
self._set_task_result(task_id, "completed", result=result)
|
||||||
"status": "completed",
|
|
||||||
"result": result,
|
|
||||||
"error": None,
|
|
||||||
}
|
|
||||||
self.upload_progress[task_id]["status"] = "completed"
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"后台上传任务 {task_id} 失败: {e}")
|
logger.error(f"后台上传任务 {task_id} 失败: {e}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
self.upload_tasks[task_id] = {
|
self._set_task_result(task_id, "failed", error=str(e))
|
||||||
"status": "failed",
|
|
||||||
"result": None,
|
async def _background_import_task(
|
||||||
"error": str(e),
|
self,
|
||||||
|
task_id: str,
|
||||||
|
kb_helper,
|
||||||
|
documents: list,
|
||||||
|
batch_size: int,
|
||||||
|
tasks_limit: int,
|
||||||
|
max_retries: int,
|
||||||
|
):
|
||||||
|
"""后台导入预切片文档任务"""
|
||||||
|
try:
|
||||||
|
# 初始化任务状态
|
||||||
|
self._init_task(task_id, status="processing")
|
||||||
|
self.upload_progress[task_id] = {
|
||||||
|
"status": "processing",
|
||||||
|
"file_index": 0,
|
||||||
|
"file_total": len(documents),
|
||||||
|
"stage": "waiting",
|
||||||
|
"current": 0,
|
||||||
|
"total": 100,
|
||||||
}
|
}
|
||||||
if task_id in self.upload_progress:
|
|
||||||
self.upload_progress[task_id]["status"] = "failed"
|
uploaded_docs = []
|
||||||
|
failed_docs = []
|
||||||
|
|
||||||
|
for file_idx, doc_info in enumerate(documents):
|
||||||
|
file_name = doc_info.get("file_name", f"imported_doc_{file_idx}")
|
||||||
|
chunks = doc_info.get("chunks", [])
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 更新整体进度
|
||||||
|
self._update_progress(
|
||||||
|
task_id,
|
||||||
|
status="processing",
|
||||||
|
file_index=file_idx,
|
||||||
|
file_name=file_name,
|
||||||
|
stage="importing",
|
||||||
|
current=0,
|
||||||
|
total=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建进度回调函数
|
||||||
|
progress_callback = self._make_progress_callback(
|
||||||
|
task_id, file_idx, file_name
|
||||||
|
)
|
||||||
|
|
||||||
|
# 调用 upload_document,传入 pre_chunked_text
|
||||||
|
doc = await kb_helper.upload_document(
|
||||||
|
file_name=file_name,
|
||||||
|
file_content=None, # 预切片模式下不需要原始内容
|
||||||
|
file_type=doc_info.get("file_type")
|
||||||
|
or (
|
||||||
|
file_name.rsplit(".", 1)[-1].lower()
|
||||||
|
if "." in file_name
|
||||||
|
else "txt"
|
||||||
|
),
|
||||||
|
batch_size=batch_size,
|
||||||
|
tasks_limit=tasks_limit,
|
||||||
|
max_retries=max_retries,
|
||||||
|
progress_callback=progress_callback,
|
||||||
|
pre_chunked_text=chunks,
|
||||||
|
)
|
||||||
|
|
||||||
|
uploaded_docs.append(doc.model_dump())
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"导入文档 {file_name} 失败: {e}")
|
||||||
|
failed_docs.append(
|
||||||
|
{"file_name": file_name, "error": str(e)},
|
||||||
|
)
|
||||||
|
|
||||||
|
# 更新任务完成状态
|
||||||
|
result = {
|
||||||
|
"task_id": task_id,
|
||||||
|
"uploaded": uploaded_docs,
|
||||||
|
"failed": failed_docs,
|
||||||
|
"total": len(documents),
|
||||||
|
"success_count": len(uploaded_docs),
|
||||||
|
"failed_count": len(failed_docs),
|
||||||
|
}
|
||||||
|
|
||||||
|
self._set_task_result(task_id, "completed", result=result)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"后台导入任务 {task_id} 失败: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
self._set_task_result(task_id, "failed", error=str(e))
|
||||||
|
|
||||||
async def list_kbs(self):
|
async def list_kbs(self):
|
||||||
"""获取知识库列表
|
"""获取知识库列表
|
||||||
@@ -614,11 +736,7 @@ class KnowledgeBaseRoute(Route):
|
|||||||
task_id = str(uuid.uuid4())
|
task_id = str(uuid.uuid4())
|
||||||
|
|
||||||
# 初始化任务状态
|
# 初始化任务状态
|
||||||
self.upload_tasks[task_id] = {
|
self._init_task(task_id, status="pending")
|
||||||
"status": "pending",
|
|
||||||
"result": None,
|
|
||||||
"error": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 启动后台任务
|
# 启动后台任务
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
@@ -653,6 +771,93 @@ class KnowledgeBaseRoute(Route):
|
|||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return Response().error(f"上传文档失败: {e!s}").__dict__
|
return Response().error(f"上传文档失败: {e!s}").__dict__
|
||||||
|
|
||||||
|
def _validate_import_request(self, data: dict):
|
||||||
|
kb_id = data.get("kb_id")
|
||||||
|
if not kb_id:
|
||||||
|
raise ValueError("缺少参数 kb_id")
|
||||||
|
|
||||||
|
documents = data.get("documents")
|
||||||
|
if not documents or not isinstance(documents, list):
|
||||||
|
raise ValueError("缺少参数 documents 或格式错误")
|
||||||
|
|
||||||
|
for doc in documents:
|
||||||
|
if "file_name" not in doc or "chunks" not in doc:
|
||||||
|
raise ValueError("文档格式错误,必须包含 file_name 和 chunks")
|
||||||
|
if not isinstance(doc["chunks"], list):
|
||||||
|
raise ValueError("chunks 必须是列表")
|
||||||
|
if not all(
|
||||||
|
isinstance(chunk, str) and chunk.strip() for chunk in doc["chunks"]
|
||||||
|
):
|
||||||
|
raise ValueError("chunks 必须是非空字符串列表")
|
||||||
|
|
||||||
|
batch_size = data.get("batch_size", 32)
|
||||||
|
tasks_limit = data.get("tasks_limit", 3)
|
||||||
|
max_retries = data.get("max_retries", 3)
|
||||||
|
return kb_id, documents, batch_size, tasks_limit, max_retries
|
||||||
|
|
||||||
|
async def import_documents(self):
|
||||||
|
"""导入预切片文档
|
||||||
|
|
||||||
|
Body:
|
||||||
|
- kb_id: 知识库 ID (必填)
|
||||||
|
- documents: 文档列表 (必填)
|
||||||
|
- file_name: 文件名 (必填)
|
||||||
|
- chunks: 切片列表 (必填, list[str])
|
||||||
|
- file_type: 文件类型 (可选, 默认从文件名推断或为 txt)
|
||||||
|
- batch_size: 批处理大小 (可选, 默认32)
|
||||||
|
- tasks_limit: 并发任务限制 (可选, 默认3)
|
||||||
|
- max_retries: 最大重试次数 (可选, 默认3)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
kb_manager = self._get_kb_manager()
|
||||||
|
data = await request.json
|
||||||
|
|
||||||
|
kb_id, documents, batch_size, tasks_limit, max_retries = (
|
||||||
|
self._validate_import_request(data)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 获取知识库
|
||||||
|
kb_helper = await kb_manager.get_kb(kb_id)
|
||||||
|
if not kb_helper:
|
||||||
|
return Response().error("知识库不存在").__dict__
|
||||||
|
|
||||||
|
# 生成任务ID
|
||||||
|
task_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
# 初始化任务状态
|
||||||
|
self._init_task(task_id, status="pending")
|
||||||
|
|
||||||
|
# 启动后台任务
|
||||||
|
asyncio.create_task(
|
||||||
|
self._background_import_task(
|
||||||
|
task_id=task_id,
|
||||||
|
kb_helper=kb_helper,
|
||||||
|
documents=documents,
|
||||||
|
batch_size=batch_size,
|
||||||
|
tasks_limit=tasks_limit,
|
||||||
|
max_retries=max_retries,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
return (
|
||||||
|
Response()
|
||||||
|
.ok(
|
||||||
|
{
|
||||||
|
"task_id": task_id,
|
||||||
|
"doc_count": len(documents),
|
||||||
|
"message": "import task created, processing in background",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.__dict__
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValueError as e:
|
||||||
|
return Response().error(str(e)).__dict__
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"导入文档失败: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return Response().error(f"导入文档失败: {e!s}").__dict__
|
||||||
|
|
||||||
async def get_upload_progress(self):
|
async def get_upload_progress(self):
|
||||||
"""获取上传进度和结果
|
"""获取上传进度和结果
|
||||||
|
|
||||||
@@ -960,11 +1165,7 @@ class KnowledgeBaseRoute(Route):
|
|||||||
task_id = str(uuid.uuid4())
|
task_id = str(uuid.uuid4())
|
||||||
|
|
||||||
# 初始化任务状态
|
# 初始化任务状态
|
||||||
self.upload_tasks[task_id] = {
|
self._init_task(task_id, status="pending")
|
||||||
"status": "pending",
|
|
||||||
"result": None,
|
|
||||||
"error": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 启动后台任务
|
# 启动后台任务
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
@@ -1017,11 +1218,7 @@ class KnowledgeBaseRoute(Route):
|
|||||||
"""后台上传URL任务"""
|
"""后台上传URL任务"""
|
||||||
try:
|
try:
|
||||||
# 初始化任务状态
|
# 初始化任务状态
|
||||||
self.upload_tasks[task_id] = {
|
self._init_task(task_id, status="processing")
|
||||||
"status": "processing",
|
|
||||||
"result": None,
|
|
||||||
"error": None,
|
|
||||||
}
|
|
||||||
self.upload_progress[task_id] = {
|
self.upload_progress[task_id] = {
|
||||||
"status": "processing",
|
"status": "processing",
|
||||||
"file_index": 0,
|
"file_index": 0,
|
||||||
@@ -1033,18 +1230,7 @@ class KnowledgeBaseRoute(Route):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 创建进度回调函数
|
# 创建进度回调函数
|
||||||
async def progress_callback(stage, current, total):
|
progress_callback = self._make_progress_callback(task_id, 0, f"URL: {url}")
|
||||||
if task_id in self.upload_progress:
|
|
||||||
self.upload_progress[task_id].update(
|
|
||||||
{
|
|
||||||
"status": "processing",
|
|
||||||
"file_index": 0,
|
|
||||||
"file_name": f"URL: {url}",
|
|
||||||
"stage": stage,
|
|
||||||
"current": current,
|
|
||||||
"total": total,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# 上传文档
|
# 上传文档
|
||||||
doc = await kb_helper.upload_from_url(
|
doc = await kb_helper.upload_from_url(
|
||||||
@@ -1069,20 +1255,9 @@ class KnowledgeBaseRoute(Route):
|
|||||||
"failed_count": 0,
|
"failed_count": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
self.upload_tasks[task_id] = {
|
self._set_task_result(task_id, "completed", result=result)
|
||||||
"status": "completed",
|
|
||||||
"result": result,
|
|
||||||
"error": None,
|
|
||||||
}
|
|
||||||
self.upload_progress[task_id]["status"] = "completed"
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"后台上传URL任务 {task_id} 失败: {e}")
|
logger.error(f"后台上传URL任务 {task_id} 失败: {e}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
self.upload_tasks[task_id] = {
|
self._set_task_result(task_id, "failed", error=str(e))
|
||||||
"status": "failed",
|
|
||||||
"result": None,
|
|
||||||
"error": str(e),
|
|
||||||
}
|
|
||||||
if task_id in self.upload_progress:
|
|
||||||
self.upload_progress[task_id]["status"] = "failed"
|
|
||||||
|
|||||||
@@ -124,7 +124,11 @@ class PluginRoute(Route):
|
|||||||
session.get(url) as response,
|
session.get(url) as response,
|
||||||
):
|
):
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
remote_data = await response.json()
|
try:
|
||||||
|
remote_data = await response.json()
|
||||||
|
except aiohttp.ContentTypeError:
|
||||||
|
remote_text = await response.text()
|
||||||
|
remote_data = json.loads(remote_text)
|
||||||
|
|
||||||
# 检查远程数据是否为空
|
# 检查远程数据是否为空
|
||||||
if not remote_data or (
|
if not remote_data or (
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ import traceback
|
|||||||
from quart import request
|
from quart import request
|
||||||
|
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
|
from astrbot.core.agent.mcp_client import MCPTool
|
||||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||||
from astrbot.core.star import star_map
|
from astrbot.core.star import star_map
|
||||||
|
|
||||||
@@ -296,15 +297,30 @@ class ToolsRoute(Route):
|
|||||||
"""获取所有注册的工具列表"""
|
"""获取所有注册的工具列表"""
|
||||||
try:
|
try:
|
||||||
tools = self.tool_mgr.func_list
|
tools = self.tool_mgr.func_list
|
||||||
tools_dict = [
|
tools_dict = []
|
||||||
{
|
for tool in tools:
|
||||||
|
if isinstance(tool, MCPTool):
|
||||||
|
origin = "mcp"
|
||||||
|
origin_name = tool.mcp_server_name
|
||||||
|
elif tool.handler_module_path and star_map.get(
|
||||||
|
tool.handler_module_path
|
||||||
|
):
|
||||||
|
star = star_map[tool.handler_module_path]
|
||||||
|
origin = "plugin"
|
||||||
|
origin_name = star.name
|
||||||
|
else:
|
||||||
|
origin = "unknown"
|
||||||
|
origin_name = "unknown"
|
||||||
|
|
||||||
|
tool_info = {
|
||||||
"name": tool.name,
|
"name": tool.name,
|
||||||
"description": tool.description,
|
"description": tool.description,
|
||||||
"parameters": tool.parameters,
|
"parameters": tool.parameters,
|
||||||
"active": tool.active,
|
"active": tool.active,
|
||||||
|
"origin": origin,
|
||||||
|
"origin_name": origin_name,
|
||||||
}
|
}
|
||||||
for tool in tools
|
tools_dict.append(tool_info)
|
||||||
]
|
|
||||||
return Response().ok(data=tools_dict).__dict__
|
return Response().ok(data=tools_dict).__dict__
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
|||||||
@@ -67,6 +67,7 @@ class AstrBotDashboard:
|
|||||||
core_lifecycle,
|
core_lifecycle,
|
||||||
core_lifecycle.plugin_manager,
|
core_lifecycle.plugin_manager,
|
||||||
)
|
)
|
||||||
|
self.command_route = CommandRoute(self.context)
|
||||||
self.cr = ConfigRoute(self.context, core_lifecycle)
|
self.cr = ConfigRoute(self.context, core_lifecycle)
|
||||||
self.lr = LogRoute(self.context, core_lifecycle.log_broker)
|
self.lr = LogRoute(self.context, core_lifecycle.log_broker)
|
||||||
self.sfr = StaticFileRoute(self.context)
|
self.sfr = StaticFileRoute(self.context)
|
||||||
|
|||||||
@@ -0,0 +1,134 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Use Nuitka to build the AstrBot project into standalone executables
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def get_platform_info():
|
||||||
|
"""fetch the current platform information"""
|
||||||
|
system = platform.system()
|
||||||
|
machine = platform.machine()
|
||||||
|
return system, machine
|
||||||
|
|
||||||
|
|
||||||
|
def build_with_nuitka():
|
||||||
|
"""use Nuitka to build the project"""
|
||||||
|
system, machine = get_platform_info()
|
||||||
|
|
||||||
|
print(f"🚀 Starting build for {system} ({machine}) platform...")
|
||||||
|
|
||||||
|
# Output directory
|
||||||
|
output_dir = Path("build/nuitka")
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Base Nuitka command
|
||||||
|
nuitka_cmd = [
|
||||||
|
sys.executable,
|
||||||
|
"-m",
|
||||||
|
"nuitka",
|
||||||
|
"--standalone", # Create standalone directory
|
||||||
|
"--onefile", # Single file mode
|
||||||
|
"--follow-imports", # Follow all imports
|
||||||
|
"--enable-plugin=multiprocessing", # Enable multiprocessing support
|
||||||
|
"--output-dir=build/nuitka", # Output directory
|
||||||
|
"--quiet", # Reduce output verbosity
|
||||||
|
"--assume-yes-for-downloads", # Automatically download dependencies
|
||||||
|
"--jobs=4", # Use multiple CPU cores
|
||||||
|
]
|
||||||
|
|
||||||
|
# include specific packages
|
||||||
|
include_packages = [
|
||||||
|
"astrbot",
|
||||||
|
]
|
||||||
|
|
||||||
|
for pkg in include_packages:
|
||||||
|
nuitka_cmd.extend([f"--include-package={pkg}"])
|
||||||
|
|
||||||
|
# include data directories
|
||||||
|
# data_includes = [
|
||||||
|
# "data/config",
|
||||||
|
# "data/plugins",
|
||||||
|
# "data/temp",
|
||||||
|
# ]
|
||||||
|
|
||||||
|
# for data_dir in data_includes:
|
||||||
|
# if os.path.exists(data_dir):
|
||||||
|
# nuitka_cmd.extend([f"--include-data-dir={data_dir}={data_dir}"])
|
||||||
|
|
||||||
|
# include packages directory (built-in plugins)
|
||||||
|
# if os.path.exists("packages"):
|
||||||
|
# nuitka_cmd.extend(["--include-data-dir=packages=packages"])
|
||||||
|
|
||||||
|
# Platform specific settings
|
||||||
|
if system == "Darwin": # macOS
|
||||||
|
nuitka_cmd.extend(
|
||||||
|
[
|
||||||
|
"--macos-create-app-bundle", # Create .app bundle
|
||||||
|
"--macos-app-name=AstrBot",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# macOS icon (if exists)
|
||||||
|
icon_path = "dashboard/src-tauri/icons/icon.icns"
|
||||||
|
if os.path.exists(icon_path):
|
||||||
|
nuitka_cmd.extend([f"--macos-app-icon={icon_path}"])
|
||||||
|
elif system == "Windows":
|
||||||
|
nuitka_cmd.extend(
|
||||||
|
[
|
||||||
|
"--windows-console-mode=disable", # 无控制台窗口
|
||||||
|
]
|
||||||
|
)
|
||||||
|
# Windows icon (if exists)
|
||||||
|
icon_path = "dashboard/src-tauri/icons/icon.ico"
|
||||||
|
if os.path.exists(icon_path):
|
||||||
|
nuitka_cmd.extend([f"--windows-icon-from-ico={icon_path}"])
|
||||||
|
|
||||||
|
# Main file to compile
|
||||||
|
nuitka_cmd.append("main.py")
|
||||||
|
|
||||||
|
print(f"📦 Executing command: {' '.join(nuitka_cmd)}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
subprocess.run(nuitka_cmd, check=True)
|
||||||
|
print("✅ Nuitka build successful!")
|
||||||
|
|
||||||
|
# Find the generated executable
|
||||||
|
if system == "Darwin":
|
||||||
|
built_file = list(output_dir.glob("*.app"))
|
||||||
|
if built_file:
|
||||||
|
print(f"Generated macOS app: {built_file[0]}")
|
||||||
|
elif system == "Windows":
|
||||||
|
built_file = list(output_dir.glob("*.exe"))
|
||||||
|
if built_file:
|
||||||
|
print(f"Generated Windows executable: {built_file[0]}")
|
||||||
|
else: # Linux
|
||||||
|
built_file = list(output_dir.glob("main.bin"))
|
||||||
|
if built_file:
|
||||||
|
print(f"Generated Linux executable: {built_file[0]}")
|
||||||
|
|
||||||
|
return True
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print(f"❌ Nuitka build failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("=" * 60)
|
||||||
|
print("AstrBot Nuitka Builder")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# 构建
|
||||||
|
if build_with_nuitka():
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("🎉 Build Complete!")
|
||||||
|
print("=" * 60)
|
||||||
|
else:
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("❌ Build Failed")
|
||||||
|
print("=" * 60)
|
||||||
|
sys.exit(1)
|
||||||
@@ -0,0 +1,134 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Use PyInstaller to build the AstrBot project into standalone executables
|
||||||
|
"""
|
||||||
|
|
||||||
|
import platform
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def get_platform_info():
|
||||||
|
"""fetch the current platform information"""
|
||||||
|
system = platform.system()
|
||||||
|
machine = platform.machine()
|
||||||
|
return system, machine
|
||||||
|
|
||||||
|
|
||||||
|
def build_with_pyinstaller():
|
||||||
|
"""use PyInstaller to build the project"""
|
||||||
|
system, machine = get_platform_info()
|
||||||
|
|
||||||
|
print(f"🚀 Starting build for {system} ({machine}) platform...")
|
||||||
|
|
||||||
|
# Output directory
|
||||||
|
output_dir = Path("build/pyinstaller")
|
||||||
|
output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# Base PyInstaller command
|
||||||
|
pyinstaller_cmd = [
|
||||||
|
sys.executable,
|
||||||
|
"-m",
|
||||||
|
"PyInstaller",
|
||||||
|
"--clean", # Clean cache before build
|
||||||
|
"--noconfirm", # Replace output directory without asking
|
||||||
|
"--onefile", # Single file mode
|
||||||
|
"--distpath=build/pyinstaller/dist", # Distribution directory
|
||||||
|
"--workpath=build/pyinstaller/build", # Work directory
|
||||||
|
"--specpath=build/pyinstaller", # Spec file directory
|
||||||
|
"--name=AstrBot", # Output executable name
|
||||||
|
]
|
||||||
|
# Platform specific settings
|
||||||
|
# if system == "Darwin": # macOS
|
||||||
|
# # macOS icon (if exists)
|
||||||
|
# icon_path = "dashboard/src-tauri/icons/icon.icns"
|
||||||
|
# if os.path.exists(icon_path):
|
||||||
|
# pyinstaller_cmd.extend([f"--icon={icon_path}"])
|
||||||
|
# # Create .app bundle
|
||||||
|
# pyinstaller_cmd.extend(["--windowed"])
|
||||||
|
# elif system == "Windows":
|
||||||
|
# # Windows icon (if exists)
|
||||||
|
# icon_path = "dashboard/src-tauri/icons/icon.ico"
|
||||||
|
# if os.path.exists(icon_path):
|
||||||
|
# pyinstaller_cmd.extend([f"--icon={icon_path}"])
|
||||||
|
# # No console window
|
||||||
|
# pyinstaller_cmd.extend(["--windowed"])
|
||||||
|
# else: # Linux
|
||||||
|
# pyinstaller_cmd.extend(["--console"])
|
||||||
|
|
||||||
|
# Main file to compile
|
||||||
|
pyinstaller_cmd.append("main.py")
|
||||||
|
|
||||||
|
print(f"📦 Executing command: {' '.join(pyinstaller_cmd)}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
subprocess.run(pyinstaller_cmd, check=True)
|
||||||
|
print("✅ PyInstaller build successful!")
|
||||||
|
|
||||||
|
# Find the generated executable
|
||||||
|
dist_dir = output_dir / "dist"
|
||||||
|
if system == "Darwin":
|
||||||
|
built_file = list(dist_dir.glob("AstrBot.app"))
|
||||||
|
if not built_file:
|
||||||
|
built_file = list(dist_dir.glob("AstrBot"))
|
||||||
|
if built_file:
|
||||||
|
print(f"📱 Generated macOS app: {built_file[0]}")
|
||||||
|
elif system == "Windows":
|
||||||
|
built_file = list(dist_dir.glob("AstrBot.exe"))
|
||||||
|
if built_file:
|
||||||
|
print(f"💻 Generated Windows executable: {built_file[0]}")
|
||||||
|
else: # Linux
|
||||||
|
built_file = list(dist_dir.glob("AstrBot"))
|
||||||
|
if built_file:
|
||||||
|
print(f"🐧 Generated Linux executable: {built_file[0]}")
|
||||||
|
|
||||||
|
print(f"\n📁 Output directory: {dist_dir.absolute()}")
|
||||||
|
return True
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print(f"❌ PyInstaller build failed: {e}")
|
||||||
|
return False
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ Unexpected error: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def install_pyinstaller():
|
||||||
|
"""Install PyInstaller if not already installed"""
|
||||||
|
try:
|
||||||
|
import PyInstaller
|
||||||
|
|
||||||
|
print(f"✅ PyInstaller already installed (version {PyInstaller.__version__})")
|
||||||
|
return True
|
||||||
|
except ImportError:
|
||||||
|
print("📥 PyInstaller not found, installing...")
|
||||||
|
try:
|
||||||
|
subprocess.run(
|
||||||
|
[sys.executable, "-m", "pip", "install", "pyinstaller"], check=True
|
||||||
|
)
|
||||||
|
print("✅ PyInstaller installed successfully!")
|
||||||
|
return True
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print(f"❌ Failed to install PyInstaller: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("=" * 60)
|
||||||
|
print("AstrBot PyInstaller Builder")
|
||||||
|
print("=" * 60)
|
||||||
|
|
||||||
|
# Check and install PyInstaller
|
||||||
|
if not install_pyinstaller():
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Build
|
||||||
|
if build_with_pyinstaller():
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("🎉 Build Complete!")
|
||||||
|
print("=" * 60)
|
||||||
|
else:
|
||||||
|
print("\n" + "=" * 60)
|
||||||
|
print("❌ Build Failed")
|
||||||
|
print("=" * 60)
|
||||||
|
sys.exit(1)
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
## What's Changed
|
||||||
|
|
||||||
|
### 新增
|
||||||
|
|
||||||
|
- 支持自定义插件源。
|
||||||
|
- 支持飞书(Lark)的 Webhook 模式(将事件推送至开发者服务器)。
|
||||||
|
- 支持 “禁用自带指令” 快捷配置项,启用后将禁用所有 AstrBot 自带指令。入口: WebUI -> 配置文件 -> 平台配置。
|
||||||
|
|
||||||
|
### 优化
|
||||||
|
|
||||||
|
- 从 WebUI 移除了开发版本渠道。
|
||||||
|
- 当试图测试"Agent Runner"时,提示前往配置文件页测试。
|
||||||
|
- WebUI 列表项支持批量粘贴、回车创建项目。
|
||||||
|
|
||||||
|
### 修复
|
||||||
|
|
||||||
|
- Gemini API 部分调用失败的问题。
|
||||||
|
- WebUI 插件安装加载 Dialog 关闭按钮在手机端下显示异常的问题。
|
||||||
|
- 部分情况下,WebUI 日志显示不全的问题。
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
## What's Changed
|
||||||
|
|
||||||
|
-
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
## What's Changed
|
||||||
|
|
||||||
|
### 修复
|
||||||
|
|
||||||
|
- 企业自部署飞书(自定义 domain)可以接收消息但无法发送消息的问题。
|
||||||
|
- 安装插件 Dialog 的深色样式问题。
|
||||||
|
|
||||||
|
### 优化
|
||||||
|
|
||||||
|
- 避免某些插件在流式响应结束后重d复发送消息的问题。
|
||||||
|
|
||||||
|
### 新增
|
||||||
|
|
||||||
|
- 支持在对话管理批量导出对话轨迹数据为 `jsonl` 格式文件。入口:WebUI -> 对话管理 -> 批量选中 -> 导出。
|
||||||
|
- 支持对 TTS(文本转语音)设置概率触发。
|
||||||
|
- (插件开发)支持在 schema 中对 float 和 int 类型设置 `slider` 滑块控件。例如 `slider: {min: 0, max: 1, step: 0.1}`。
|
||||||
|
- (插件开发)支持 key-value 存储功能。例如使用 `await self.put_kv_data("key", value)`, `await self.get_kv_data("key", default_value)` 和 `await self.delete_kv_data("key")`。
|
||||||
@@ -0,0 +1,225 @@
|
|||||||
|
# AstrBot Dashboard - Tauri 桌面应用
|
||||||
|
|
||||||
|
本项目现已支持通过 Tauri 构建为桌面应用,同时保持与 Web 版本的兼容性。
|
||||||
|
|
||||||
|
## 环境要求
|
||||||
|
|
||||||
|
### 系统依赖
|
||||||
|
|
||||||
|
**macOS:**
|
||||||
|
```bash
|
||||||
|
# 安装 Xcode Command Line Tools
|
||||||
|
xcode-select --install
|
||||||
|
```
|
||||||
|
|
||||||
|
**Windows:**
|
||||||
|
- 安装 [Microsoft Visual Studio C++ Build Tools](https://visualstudio.microsoft.com/visual-cpp-build-tools/)
|
||||||
|
- 安装 [WebView2](https://developer.microsoft.com/en-us/microsoft-edge/webview2/)
|
||||||
|
|
||||||
|
**Linux (Ubuntu/Debian):**
|
||||||
|
```bash
|
||||||
|
sudo apt update
|
||||||
|
sudo apt install libwebkit2gtk-4.0-dev \
|
||||||
|
build-essential \
|
||||||
|
curl \
|
||||||
|
wget \
|
||||||
|
file \
|
||||||
|
libssl-dev \
|
||||||
|
libgtk-3-dev \
|
||||||
|
libayatana-appindicator3-dev \
|
||||||
|
librsvg2-dev
|
||||||
|
```
|
||||||
|
|
||||||
|
### Rust 环境
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 安装 Rust
|
||||||
|
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
|
||||||
|
|
||||||
|
# 验证安装
|
||||||
|
rustc --version
|
||||||
|
cargo --version
|
||||||
|
```
|
||||||
|
|
||||||
|
## 安装依赖
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cd dashboard
|
||||||
|
npm install
|
||||||
|
```
|
||||||
|
|
||||||
|
## 开发模式
|
||||||
|
|
||||||
|
### Web 端开发(不变)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
npm run dev
|
||||||
|
```
|
||||||
|
|
||||||
|
访问 http://localhost:3000
|
||||||
|
|
||||||
|
### 桌面端开发
|
||||||
|
|
||||||
|
```bash
|
||||||
|
npm run tauri:dev
|
||||||
|
```
|
||||||
|
|
||||||
|
这会同时启动:
|
||||||
|
1. Vite 开发服务器(端口 3000)
|
||||||
|
2. Tauri 桌面应用窗口
|
||||||
|
|
||||||
|
热重载功能正常工作,修改代码后会自动刷新。
|
||||||
|
|
||||||
|
## 构建
|
||||||
|
|
||||||
|
### Web 端构建(不变)
|
||||||
|
|
||||||
|
```bash
|
||||||
|
npm run build
|
||||||
|
```
|
||||||
|
|
||||||
|
输出目录:`dist/`
|
||||||
|
|
||||||
|
### 桌面端构建
|
||||||
|
|
||||||
|
```bash
|
||||||
|
npm run tauri:build
|
||||||
|
```
|
||||||
|
|
||||||
|
构建产物位置:
|
||||||
|
- **macOS**: `src-tauri/target/release/bundle/dmg/`
|
||||||
|
- **Windows**: `src-tauri/target/release/bundle/msi/`
|
||||||
|
- **Linux**: `src-tauri/target/release/bundle/deb/` 或 `appimage/`
|
||||||
|
|
||||||
|
## 图标设置
|
||||||
|
|
||||||
|
### 自动生成图标
|
||||||
|
|
||||||
|
准备一个至少 512x512 像素的 PNG 图标,然后运行:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
npm run tauri icon path/to/your/icon.png
|
||||||
|
```
|
||||||
|
|
||||||
|
### 手动设置图标
|
||||||
|
|
||||||
|
将以下图标放入 `src-tauri/icons/` 目录:
|
||||||
|
- `32x32.png`
|
||||||
|
- `128x128.png`
|
||||||
|
- `128x128@2x.png`
|
||||||
|
- `icon.icns` (macOS)
|
||||||
|
- `icon.ico` (Windows)
|
||||||
|
|
||||||
|
## 代码兼容性
|
||||||
|
|
||||||
|
项目已配置为同时支持 Web 和桌面端,使用相同的代码库。
|
||||||
|
|
||||||
|
### 环境检测工具
|
||||||
|
|
||||||
|
在 `src/utils/tauri.ts` 中提供了环境检测工具:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { isTauri, isWeb, PlatformAPI } from '@/utils/tauri';
|
||||||
|
|
||||||
|
// 检测运行环境
|
||||||
|
if (isTauri()) {
|
||||||
|
console.log('运行在桌面应用中');
|
||||||
|
} else {
|
||||||
|
console.log('运行在浏览器中');
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取正确的 API 端点
|
||||||
|
const baseURL = PlatformAPI.getBaseURL();
|
||||||
|
```
|
||||||
|
|
||||||
|
### API 调用注意事项
|
||||||
|
|
||||||
|
- **Web 端**: 使用 Vite 代理,API 路径为 `/api/*`
|
||||||
|
- **桌面端**: 直接连接到 `http://127.0.0.1:6185`
|
||||||
|
|
||||||
|
已在 `PlatformAPI.getBaseURL()` 中处理,使用 axios 时:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import axios from 'axios';
|
||||||
|
import { PlatformAPI } from '@/utils/tauri';
|
||||||
|
|
||||||
|
const api = axios.create({
|
||||||
|
baseURL: PlatformAPI.getBaseURL()
|
||||||
|
});
|
||||||
|
```
|
||||||
|
|
||||||
|
## 配置说明
|
||||||
|
|
||||||
|
### tauri.conf.json
|
||||||
|
|
||||||
|
主要配置项:
|
||||||
|
- `build.devPath`: 开发服务器地址(http://localhost:3000)
|
||||||
|
- `build.distDir`: 构建输出目录(../dist)
|
||||||
|
- `tauri.allowlist`: API 权限配置
|
||||||
|
- `tauri.windows`: 窗口配置(大小、标题等)
|
||||||
|
|
||||||
|
### 安全性
|
||||||
|
|
||||||
|
默认配置已启用必要的权限:
|
||||||
|
- 文件系统访问(限定在 APPDATA 目录)
|
||||||
|
- HTTP 请求(限定到本地后端)
|
||||||
|
- 窗口控制
|
||||||
|
- 对话框(打开/保存文件)
|
||||||
|
|
||||||
|
可在 `tauri.conf.json` 的 `allowlist` 部分调整权限。
|
||||||
|
|
||||||
|
## 后端连接
|
||||||
|
|
||||||
|
桌面应用需要后端服务运行在 `http://127.0.0.1:6185`。
|
||||||
|
|
||||||
|
### 启动流程
|
||||||
|
|
||||||
|
1. 启动 AstrBot 后端:
|
||||||
|
```bash
|
||||||
|
cd /path/to/AstrBot
|
||||||
|
uv run main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
2. 启动桌面应用:
|
||||||
|
```bash
|
||||||
|
cd dashboard
|
||||||
|
npm run tauri:dev
|
||||||
|
```
|
||||||
|
|
||||||
|
或直接运行打包后的应用(后端需要已启动)。
|
||||||
|
|
||||||
|
## 常见问题
|
||||||
|
|
||||||
|
### Q: 桌面应用无法连接到后端?
|
||||||
|
|
||||||
|
确保:
|
||||||
|
1. AstrBot 后端正在运行(`uv run main.py`)
|
||||||
|
2. 后端监听在 `127.0.0.1:6185`
|
||||||
|
3. 防火墙未阻止连接
|
||||||
|
|
||||||
|
### Q: 图标未显示?
|
||||||
|
|
||||||
|
检查 `src-tauri/icons/` 目录中是否有所需的图标文件,或使用 `npm run tauri icon` 命令生成。
|
||||||
|
|
||||||
|
### Q: 构建失败?
|
||||||
|
|
||||||
|
- 确保已安装 Rust 和系统依赖
|
||||||
|
- 运行 `cargo clean` 清理缓存后重试
|
||||||
|
- 检查 Rust 版本(需要 1.60+)
|
||||||
|
|
||||||
|
### Q: Web 端功能是否受影响?
|
||||||
|
|
||||||
|
不受影响。`npm run dev` 和 `npm run build` 的行为完全不变。
|
||||||
|
|
||||||
|
## 开发建议
|
||||||
|
|
||||||
|
1. **优先使用 Web 端开发**: 更快的热重载,更好的调试体验
|
||||||
|
2. **定期测试桌面端**: 确保跨平台兼容性
|
||||||
|
3. **使用环境检测**: 针对不同平台提供最佳体验
|
||||||
|
4. **注意 API 差异**: Web 和桌面端的某些 API 可能有差异
|
||||||
|
|
||||||
|
## 更多资源
|
||||||
|
|
||||||
|
- [Tauri 官方文档](https://tauri.app/)
|
||||||
|
- [Tauri API 参考](https://tauri.app/v1/api/js/)
|
||||||
|
- [Tauri Discord 社区](https://discord.com/invite/tauri)
|
||||||
@@ -10,10 +10,14 @@
|
|||||||
"build-prod": "vue-tsc --noEmit && vite build --base=/vue/free/",
|
"build-prod": "vue-tsc --noEmit && vite build --base=/vue/free/",
|
||||||
"preview": "vite preview --port 5050",
|
"preview": "vite preview --port 5050",
|
||||||
"typecheck": "vue-tsc --noEmit",
|
"typecheck": "vue-tsc --noEmit",
|
||||||
"lint": "eslint . --ext .vue,.js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --ignore-path .gitignore"
|
"lint": "eslint . --ext .vue,.js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --ignore-path .gitignore",
|
||||||
|
"tauri": "tauri",
|
||||||
|
"tauri:dev": "tauri dev",
|
||||||
|
"tauri:build": "tauri build"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@guolao/vue-monaco-editor": "^1.5.4",
|
"@guolao/vue-monaco-editor": "^1.5.4",
|
||||||
|
"@tauri-apps/api": "^2.9.0",
|
||||||
"@tiptap/starter-kit": "2.1.7",
|
"@tiptap/starter-kit": "2.1.7",
|
||||||
"@tiptap/vue-3": "2.1.7",
|
"@tiptap/vue-3": "2.1.7",
|
||||||
"apexcharts": "3.42.0",
|
"apexcharts": "3.42.0",
|
||||||
@@ -43,6 +47,7 @@
|
|||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@mdi/font": "7.2.96",
|
"@mdi/font": "7.2.96",
|
||||||
"@rushstack/eslint-patch": "1.3.3",
|
"@rushstack/eslint-patch": "1.3.3",
|
||||||
|
"@tauri-apps/cli": "^2.9.4",
|
||||||
"@types/chance": "1.1.3",
|
"@types/chance": "1.1.3",
|
||||||
"@types/markdown-it": "^14.1.2",
|
"@types/markdown-it": "^14.1.2",
|
||||||
"@types/node": "^20.5.7",
|
"@types/node": "^20.5.7",
|
||||||
|
|||||||
@@ -0,0 +1,3 @@
|
|||||||
|
# Tauri specific
|
||||||
|
src-tauri/target/
|
||||||
|
src-tauri/WixTools/
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
[package]
|
||||||
|
name = "astrbot-dashboard"
|
||||||
|
version = "4.5.6"
|
||||||
|
description = "AstrBot"
|
||||||
|
authors = ["AstrBot Team"]
|
||||||
|
license = "AGPL-3.0"
|
||||||
|
repository = "https://github.com/AstrBotDevs/AstrBot"
|
||||||
|
default-run = "astrbot-dashboard"
|
||||||
|
edition = "2021"
|
||||||
|
rust-version = "1.91.0"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[build-dependencies]
|
||||||
|
tauri-build = { version = "2", features = [] }
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
serde_json = "1.0"
|
||||||
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
|
tauri = { version = "2.9.2", features = ["macos-private-api", "protocol-asset"] }
|
||||||
|
tauri-plugin-opener = "2"
|
||||||
|
|
||||||
|
[features]
|
||||||
|
# this feature is used for production builds or when `devPath` points to the filesystem and the built-in dev server is disabled.
|
||||||
|
# If you use cargo directly instead of tauri's cli you can use this feature flag to switch between tauri's `dev` and `build` modes.
|
||||||
|
# DO NOT REMOVE!!
|
||||||
|
custom-protocol = [ "tauri/custom-protocol" ]
|
||||||
@@ -0,0 +1,3 @@
|
|||||||
|
fn main() {
|
||||||
|
tauri_build::build()
|
||||||
|
}
|
||||||
@@ -0,0 +1 @@
|
|||||||
|
{}
|
||||||
|
After Width: | Height: | Size: 7.3 KiB |
|
After Width: | Height: | Size: 18 KiB |
|
After Width: | Height: | Size: 1.3 KiB |
|
After Width: | Height: | Size: 3.2 KiB |
|
After Width: | Height: | Size: 5.9 KiB |
|
After Width: | Height: | Size: 8.2 KiB |
|
After Width: | Height: | Size: 8.8 KiB |
|
After Width: | Height: | Size: 20 KiB |
|
After Width: | Height: | Size: 1.2 KiB |
|
After Width: | Height: | Size: 23 KiB |
|
After Width: | Height: | Size: 2.0 KiB |
|
After Width: | Height: | Size: 3.5 KiB |
|
After Width: | Height: | Size: 4.8 KiB |
|
After Width: | Height: | Size: 2.3 KiB |
@@ -0,0 +1,5 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
|
||||||
|
<foreground android:drawable="@mipmap/ic_launcher_foreground"/>
|
||||||
|
<background android:drawable="@color/ic_launcher_background"/>
|
||||||
|
</adaptive-icon>
|
||||||
|
After Width: | Height: | Size: 2.2 KiB |
|
After Width: | Height: | Size: 9.8 KiB |
|
After Width: | Height: | Size: 2.0 KiB |
|
After Width: | Height: | Size: 2.1 KiB |
|
After Width: | Height: | Size: 6.0 KiB |
|
After Width: | Height: | Size: 1.8 KiB |
|
After Width: | Height: | Size: 4.9 KiB |
|
After Width: | Height: | Size: 14 KiB |
|
After Width: | Height: | Size: 4.2 KiB |
|
After Width: | Height: | Size: 7.9 KiB |
|
After Width: | Height: | Size: 24 KiB |
|
After Width: | Height: | Size: 6.8 KiB |
|
After Width: | Height: | Size: 11 KiB |
|
After Width: | Height: | Size: 37 KiB |
|
After Width: | Height: | Size: 9.6 KiB |
@@ -0,0 +1,4 @@
|
|||||||
|
<?xml version="1.0" encoding="utf-8"?>
|
||||||
|
<resources>
|
||||||
|
<color name="ic_launcher_background">#fff</color>
|
||||||
|
</resources>
|
||||||
|
After Width: | Height: | Size: 27 KiB |
|
After Width: | Height: | Size: 47 KiB |
|
After Width: | Height: | Size: 602 B |
|
After Width: | Height: | Size: 1.4 KiB |