Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a2fe0ec5a1 | |||
| 6957ec713d | |||
| d97c8b5b2b | |||
| d07a1ad5c9 | |||
| d8e6dfbd6b |
@@ -1,79 +0,0 @@
|
|||||||
name: Build Desktop App
|
|
||||||
|
|
||||||
on:
|
|
||||||
push:
|
|
||||||
tags:
|
|
||||||
- 'v*'
|
|
||||||
workflow_dispatch:
|
|
||||||
|
|
||||||
jobs:
|
|
||||||
build:
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
platform: [macos-latest, ubuntu-latest, windows-latest]
|
|
||||||
|
|
||||||
runs-on: ${{ matrix.platform }}
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Setup Python
|
|
||||||
uses: actions/setup-python@v5
|
|
||||||
with:
|
|
||||||
python-version: '3.10'
|
|
||||||
|
|
||||||
- name: Setup Node.js
|
|
||||||
uses: actions/setup-node@v4
|
|
||||||
with:
|
|
||||||
node-version: 20
|
|
||||||
|
|
||||||
- name: Install Rust
|
|
||||||
uses: dtolnay/rust-toolchain@stable
|
|
||||||
|
|
||||||
- name: Install dependencies (Ubuntu)
|
|
||||||
if: matrix.platform == 'ubuntu-latest'
|
|
||||||
run: |
|
|
||||||
sudo apt-get update
|
|
||||||
sudo apt-get install -y libgtk-3-dev libwebkit2gtk-4.0-dev libappindicator3-dev librsvg2-dev patchelf
|
|
||||||
|
|
||||||
- name: Install Python dependencies
|
|
||||||
run: |
|
|
||||||
pip install uv
|
|
||||||
uv sync
|
|
||||||
|
|
||||||
- name: Build Python backend with Nuitka
|
|
||||||
run: |
|
|
||||||
pip install nuitka
|
|
||||||
python build_nuitka.py
|
|
||||||
|
|
||||||
- name: Install Node dependencies
|
|
||||||
working-directory: ./dashboard
|
|
||||||
run: npm install
|
|
||||||
|
|
||||||
- name: Build Tauri app
|
|
||||||
working-directory: ./dashboard
|
|
||||||
run: npm run tauri:build
|
|
||||||
|
|
||||||
- name: Upload artifacts (macOS)
|
|
||||||
if: matrix.platform == 'macos-latest'
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: astrbot-macos
|
|
||||||
path: dashboard/src-tauri/target/release/bundle/dmg/*.dmg
|
|
||||||
|
|
||||||
- name: Upload artifacts (Windows)
|
|
||||||
if: matrix.platform == 'windows-latest'
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: astrbot-windows
|
|
||||||
path: dashboard/src-tauri/target/release/bundle/msi/*.msi
|
|
||||||
|
|
||||||
- name: Upload artifacts (Linux)
|
|
||||||
if: matrix.platform == 'ubuntu-latest'
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: astrbot-linux
|
|
||||||
path: |
|
|
||||||
dashboard/src-tauri/target/release/bundle/deb/*.deb
|
|
||||||
dashboard/src-tauri/target/release/bundle/appimage/*.AppImage
|
|
||||||
@@ -36,7 +36,7 @@ jobs:
|
|||||||
zip -r dist.zip dist
|
zip -r dist.zip dist
|
||||||
|
|
||||||
- name: Archive production artifacts
|
- name: Archive production artifacts
|
||||||
uses: actions/upload-artifact@v6
|
uses: actions/upload-artifact@v5
|
||||||
with:
|
with:
|
||||||
name: dist-without-markdown
|
name: dist-without-markdown
|
||||||
path: |
|
path: |
|
||||||
|
|||||||
@@ -32,7 +32,6 @@ tests/astrbot_plugin_openai
|
|||||||
# Dashboard
|
# Dashboard
|
||||||
dashboard/node_modules/
|
dashboard/node_modules/
|
||||||
dashboard/dist/
|
dashboard/dist/
|
||||||
dashboard/src-tauri/target
|
|
||||||
package-lock.json
|
package-lock.json
|
||||||
package.json
|
package.json
|
||||||
yarn.lock
|
yarn.lock
|
||||||
@@ -49,6 +48,5 @@ astrbot.lock
|
|||||||
chroma
|
chroma
|
||||||
venv/*
|
venv/*
|
||||||
pytest.ini
|
pytest.ini
|
||||||
build/
|
|
||||||
AGENTS.md
|
AGENTS.md
|
||||||
IFLOW.md
|
IFLOW.md
|
||||||
|
|||||||
@@ -1,287 +0,0 @@
|
|||||||
# AstrBot 桌面应用构建指南
|
|
||||||
|
|
||||||
本指南介绍如何使用 Nuitka 将 Python 后端打包并集成到 Tauri 桌面应用中。
|
|
||||||
|
|
||||||
## 前置要求
|
|
||||||
|
|
||||||
### 系统要求
|
|
||||||
- Python 3.10+
|
|
||||||
- Node.js 20+
|
|
||||||
- Rust (通过 rustup 安装)
|
|
||||||
- UV 包管理器
|
|
||||||
|
|
||||||
### macOS 额外要求
|
|
||||||
- Xcode Command Line Tools: `xcode-select --install`
|
|
||||||
|
|
||||||
### Linux 额外要求
|
|
||||||
```bash
|
|
||||||
sudo apt-get install -y libgtk-3-dev libwebkit2gtk-4.0-dev \
|
|
||||||
libappindicator3-dev librsvg2-dev patchelf
|
|
||||||
```
|
|
||||||
|
|
||||||
### Windows 额外要求
|
|
||||||
- Visual Studio 2019+ with C++ build tools
|
|
||||||
- Windows 10 SDK
|
|
||||||
|
|
||||||
## 构建步骤
|
|
||||||
|
|
||||||
### 1. 安装 Python 依赖
|
|
||||||
```bash
|
|
||||||
pip install uv
|
|
||||||
uv sync
|
|
||||||
```
|
|
||||||
|
|
||||||
### 2. 安装 Nuitka
|
|
||||||
```bash
|
|
||||||
pip install nuitka
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. 构建 Python 后端
|
|
||||||
```bash
|
|
||||||
python build_nuitka.py
|
|
||||||
```
|
|
||||||
|
|
||||||
这会使用 Nuitka 将 `main.py` 编译为独立可执行文件,输出到 `build/nuitka/` 目录。
|
|
||||||
|
|
||||||
**注意**: Nuitka 编译过程可能需要 10-30 分钟,取决于您的系统性能。
|
|
||||||
|
|
||||||
### 4. 安装前端依赖
|
|
||||||
```bash
|
|
||||||
cd dashboard
|
|
||||||
npm install
|
|
||||||
```
|
|
||||||
|
|
||||||
### 5. 构建 Tauri 应用
|
|
||||||
```bash
|
|
||||||
npm run tauri:build
|
|
||||||
```
|
|
||||||
|
|
||||||
构建脚本会自动:
|
|
||||||
1. 运行 `build_nuitka.py` 编译 Python 后端
|
|
||||||
2. 将编译好的可执行文件复制到 `src-tauri/resources/` 目录
|
|
||||||
3. 构建 Tauri 应用并打包所有资源
|
|
||||||
|
|
||||||
### 6. 查找构建产物
|
|
||||||
|
|
||||||
构建完成后,您可以在以下位置找到安装包:
|
|
||||||
|
|
||||||
- **macOS**: `dashboard/src-tauri/target/release/bundle/dmg/AstrBot_*.dmg`
|
|
||||||
- **Windows**: `dashboard/src-tauri/target/release/bundle/msi/AstrBot_*.msi`
|
|
||||||
- **Linux**:
|
|
||||||
- `dashboard/src-tauri/target/release/bundle/deb/astrbot_*.deb`
|
|
||||||
- `dashboard/src-tauri/target/release/bundle/appimage/astrbot_*.AppImage`
|
|
||||||
|
|
||||||
## 开发模式
|
|
||||||
|
|
||||||
在开发时,您可能不想每次都完整编译 Python 后端。
|
|
||||||
|
|
||||||
### 仅开发 Tauri + Vue
|
|
||||||
```bash
|
|
||||||
cd dashboard
|
|
||||||
npm run tauri:dev
|
|
||||||
```
|
|
||||||
|
|
||||||
这会启动开发服务器,但不会自动启动 Python 后端。您需要手动运行:
|
|
||||||
```bash
|
|
||||||
uv run main.py
|
|
||||||
```
|
|
||||||
|
|
||||||
### 测试完整集成
|
|
||||||
如果您想测试 Tauri 自动启动 Python 后端的功能:
|
|
||||||
|
|
||||||
1. 先编译一次 Python 后端:
|
|
||||||
```bash
|
|
||||||
python build_nuitka.py
|
|
||||||
```
|
|
||||||
|
|
||||||
2. 手动复制到资源目录:
|
|
||||||
```bash
|
|
||||||
# macOS
|
|
||||||
cp -r build/nuitka/main.app dashboard/src-tauri/resources/astrbot-backend.app
|
|
||||||
|
|
||||||
# Windows
|
|
||||||
copy build\nuitka\main.exe dashboard\src-tauri\resources\astrbot-backend.exe
|
|
||||||
|
|
||||||
# Linux
|
|
||||||
cp build/nuitka/main.bin dashboard/src-tauri/resources/astrbot-backend
|
|
||||||
```
|
|
||||||
|
|
||||||
3. 运行开发模式:
|
|
||||||
```bash
|
|
||||||
cd dashboard
|
|
||||||
npm run tauri:dev
|
|
||||||
```
|
|
||||||
|
|
||||||
## Nuitka 构建选项说明
|
|
||||||
|
|
||||||
`build_nuitka.py` 脚本使用以下关键选项:
|
|
||||||
|
|
||||||
- `--standalone`: 创建包含所有依赖的独立目录
|
|
||||||
- `--onefile`: 将所有内容打包到单个可执行文件
|
|
||||||
- `--follow-imports`: 自动跟踪所有 Python 导入
|
|
||||||
- `--include-package`: 明确包含特定包
|
|
||||||
- `--include-data-dir`: 包含数据目录(插件、配置等)
|
|
||||||
|
|
||||||
### 自定义构建
|
|
||||||
|
|
||||||
如果您需要修改构建选项,编辑 `build_nuitka.py`:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# 添加更多要包含的包
|
|
||||||
include_packages = [
|
|
||||||
"astrbot",
|
|
||||||
"your_custom_package",
|
|
||||||
# ...
|
|
||||||
]
|
|
||||||
|
|
||||||
# 添加更多数据目录
|
|
||||||
data_includes = [
|
|
||||||
"data/config",
|
|
||||||
"your_custom_data",
|
|
||||||
# ...
|
|
||||||
]
|
|
||||||
```
|
|
||||||
|
|
||||||
## 常见问题
|
|
||||||
|
|
||||||
### 1. Nuitka 编译失败
|
|
||||||
**问题**: 编译时出现 "module not found" 错误
|
|
||||||
|
|
||||||
**解决方案**: 在 `build_nuitka.py` 中添加缺失的包到 `include_packages` 列表
|
|
||||||
|
|
||||||
### 2. 运行时找不到资源文件
|
|
||||||
**问题**: 应用启动后提示找不到配置文件或插件
|
|
||||||
|
|
||||||
**解决方案**: 确保在 `build_nuitka.py` 中使用 `--include-data-dir` 包含了所有必要的数据目录
|
|
||||||
|
|
||||||
### 3. macOS 安全警告
|
|
||||||
**问题**: macOS 提示"应用来自未知开发者"
|
|
||||||
|
|
||||||
**解决方案**:
|
|
||||||
```bash
|
|
||||||
# 临时解除限制
|
|
||||||
sudo spctl --master-disable
|
|
||||||
|
|
||||||
# 或者为特定应用授权
|
|
||||||
xattr -cr /Applications/AstrBot.app
|
|
||||||
```
|
|
||||||
|
|
||||||
对于生产发布,您需要:
|
|
||||||
1. 注册 Apple Developer 账号
|
|
||||||
2. 对应用进行代码签名
|
|
||||||
3. 提交公证 (Notarization)
|
|
||||||
|
|
||||||
### 4. Windows Defender 报毒
|
|
||||||
**问题**: Windows Defender 或其他杀毒软件报毒
|
|
||||||
|
|
||||||
**解决方案**:
|
|
||||||
- 这是 Nuitka 打包程序的常见问题
|
|
||||||
- 可以使用 `--windows-company-name` 和 `--windows-product-name` 添加元数据
|
|
||||||
- 对于生产发布,需要购买代码签名证书
|
|
||||||
|
|
||||||
### 5. Linux 依赖问题
|
|
||||||
**问题**: 在某些 Linux 发行版上缺少共享库
|
|
||||||
|
|
||||||
**解决方案**: 使用 AppImage 格式,它包含所有依赖:
|
|
||||||
```bash
|
|
||||||
# 构建时会自动生成 AppImage
|
|
||||||
npm run tauri:build
|
|
||||||
```
|
|
||||||
|
|
||||||
## 优化构建大小
|
|
||||||
|
|
||||||
默认的 `--onefile` 模式会生成较大的可执行文件。如果需要减小体积:
|
|
||||||
|
|
||||||
1. 移除不需要的包
|
|
||||||
2. 使用 `--standalone` 而不是 `--onefile`
|
|
||||||
3. 排除不必要的数据文件
|
|
||||||
|
|
||||||
修改 `build_nuitka.py`:
|
|
||||||
```python
|
|
||||||
# 移除 --onefile,使用 --standalone
|
|
||||||
nuitka_cmd = [
|
|
||||||
sys.executable,
|
|
||||||
"-m", "nuitka",
|
|
||||||
"--standalone", # 只使用 standalone
|
|
||||||
# "--onefile", # 注释掉 onefile
|
|
||||||
# ...
|
|
||||||
]
|
|
||||||
```
|
|
||||||
|
|
||||||
## CI/CD 集成
|
|
||||||
|
|
||||||
项目已配置 GitHub Actions 工作流 (`.github/workflows/build-app.yml`),可以自动为所有平台构建应用。
|
|
||||||
|
|
||||||
推送标签时自动触发:
|
|
||||||
```bash
|
|
||||||
git tag v4.5.7
|
|
||||||
git push origin v4.5.7
|
|
||||||
```
|
|
||||||
|
|
||||||
或手动触发:
|
|
||||||
在 GitHub Actions 页面选择 "Build Desktop App" 工作流并点击 "Run workflow"
|
|
||||||
|
|
||||||
## 发布清单
|
|
||||||
|
|
||||||
在发布新版本前:
|
|
||||||
|
|
||||||
- [ ] 更新版本号
|
|
||||||
- `pyproject.toml` - Python 项目版本
|
|
||||||
- `dashboard/package.json` - Node 项目版本
|
|
||||||
- `dashboard/src-tauri/Cargo.toml` - Rust 项目版本
|
|
||||||
- `dashboard/src-tauri/tauri.conf.json` - Tauri 配置版本
|
|
||||||
|
|
||||||
- [ ] 运行代码检查
|
|
||||||
```bash
|
|
||||||
uv run ruff check .
|
|
||||||
uv run ruff format .
|
|
||||||
```
|
|
||||||
|
|
||||||
- [ ] 本地测试构建
|
|
||||||
```bash
|
|
||||||
python build_nuitka.py
|
|
||||||
cd dashboard && npm run tauri:build
|
|
||||||
```
|
|
||||||
|
|
||||||
- [ ] 测试安装包
|
|
||||||
- 安装生成的安装包
|
|
||||||
- 验证应用启动
|
|
||||||
- 验证 Python 后端自动启动
|
|
||||||
- 测试核心功能
|
|
||||||
|
|
||||||
- [ ] 创建发布标签
|
|
||||||
```bash
|
|
||||||
git tag -a v4.5.7 -m "Release v4.5.7"
|
|
||||||
git push origin v4.5.7
|
|
||||||
```
|
|
||||||
|
|
||||||
## 技术架构
|
|
||||||
|
|
||||||
```
|
|
||||||
┌─────────────────────────────────────┐
|
|
||||||
│ Tauri Desktop App │
|
|
||||||
│ (Rust + WebView) │
|
|
||||||
│ │
|
|
||||||
│ ┌─────────────────────────────┐ │
|
|
||||||
│ │ Vue.js Dashboard │ │
|
|
||||||
│ │ (Frontend UI) │ │
|
|
||||||
│ └─────────────────────────────┘ │
|
|
||||||
│ │
|
|
||||||
│ ┌─────────────────────────────┐ │
|
|
||||||
│ │ Python Backend │ │
|
|
||||||
│ │ (Nuitka Compiled) │ │
|
|
||||||
│ │ - AstrBot Core │ │
|
|
||||||
│ │ - Plugins │ │
|
|
||||||
│ │ - API Server │ │
|
|
||||||
│ └─────────────────────────────┘ │
|
|
||||||
│ │
|
|
||||||
│ HTTP/WebSocket │
|
|
||||||
│ localhost:6185 │
|
|
||||||
└─────────────────────────────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
## 参考资源
|
|
||||||
|
|
||||||
- [Nuitka 文档](https://nuitka.net/doc/user-manual.html)
|
|
||||||
- [Tauri 文档](https://tauri.app/v1/guides/)
|
|
||||||
- [AstrBot 文档](https://astrbot.fun)
|
|
||||||
@@ -33,20 +33,6 @@
|
|||||||
- 请使用英文描述您的 PR。
|
- 请使用英文描述您的 PR。
|
||||||
- 标题请使用 `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` 等语义化前缀,并简要描述更改内容。如:`fix: correct login page typo`。
|
- 标题请使用 `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` 等语义化前缀,并简要描述更改内容。如:`fix: correct login page typo`。
|
||||||
|
|
||||||
#### 代码规范
|
|
||||||
|
|
||||||
##### Core
|
|
||||||
|
|
||||||
我们使用 Ruff 作为代码格式化和静态分析工具。在提交代码之前,请运行以下命令以确保代码符合规范:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
ruff format .
|
|
||||||
ruff check .
|
|
||||||
```
|
|
||||||
|
|
||||||
如果您使用 VSCode,可以安装 `Ruff` 插件。
|
|
||||||
|
|
||||||
|
|
||||||
## Contributing Guide
|
## Contributing Guide
|
||||||
|
|
||||||
First off, thanks for taking the time to contribute! ❤️
|
First off, thanks for taking the time to contribute! ❤️
|
||||||
@@ -77,14 +63,3 @@ We use the `fix/` prefix for bug fixes and the `feat/` prefix for new features.
|
|||||||
#### PR Description
|
#### PR Description
|
||||||
- Please use English to describe your PR.
|
- Please use English to describe your PR.
|
||||||
- Use semantic prefixes like `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` in the title, followed by a brief description of the changes, e.g., `fix: correct login page typo`.
|
- Use semantic prefixes like `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` in the title, followed by a brief description of the changes, e.g., `fix: correct login page typo`.
|
||||||
|
|
||||||
#### Code Style
|
|
||||||
|
|
||||||
##### Core
|
|
||||||
|
|
||||||
We use Ruff as our code formatter and static analysis tool. Before submitting your code, please run the following commands to ensure your code adheres to the style guidelines:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
ruff format .
|
|
||||||
ruff check .
|
|
||||||
```
|
|
||||||
|
|||||||
@@ -243,10 +243,4 @@ pre-commit install
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
<div align="center">
|
|
||||||
|
|
||||||
_私は、高性能ですから!_
|
_私は、高性能ですから!_
|
||||||
|
|
||||||
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
|
|
||||||
</div
|
|
||||||
|
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = "4.9.2"
|
__version__ = "4.8.0"
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
|
|
||||||
from typing import Any, ClassVar, Literal, cast
|
from typing import Any, ClassVar, Literal, cast
|
||||||
|
|
||||||
from pydantic import BaseModel, GetCoreSchemaHandler, model_serializer, model_validator
|
from pydantic import BaseModel, GetCoreSchemaHandler, model_validator
|
||||||
from pydantic_core import core_schema
|
from pydantic_core import core_schema
|
||||||
|
|
||||||
|
|
||||||
@@ -122,12 +122,10 @@ class ToolCall(BaseModel):
|
|||||||
extra_content: dict[str, Any] | None = None
|
extra_content: dict[str, Any] | None = None
|
||||||
"""Extra metadata for the tool call."""
|
"""Extra metadata for the tool call."""
|
||||||
|
|
||||||
@model_serializer(mode="wrap")
|
def model_dump(self, **kwargs: Any) -> dict[str, Any]:
|
||||||
def serialize(self, handler):
|
|
||||||
data = handler(self)
|
|
||||||
if self.extra_content is None:
|
if self.extra_content is None:
|
||||||
data.pop("extra_content", None)
|
kwargs.setdefault("exclude", set()).add("extra_content")
|
||||||
return data
|
return super().model_dump(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
class ToolCallPart(BaseModel):
|
class ToolCallPart(BaseModel):
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
import typing as T
|
import typing as T
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from astrbot.core.message.message_event_result import MessageChain
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
from astrbot.core.provider.entities import TokenUsage
|
|
||||||
|
|
||||||
|
|
||||||
class AgentResponseData(T.TypedDict):
|
class AgentResponseData(T.TypedDict):
|
||||||
@@ -13,23 +12,3 @@ class AgentResponseData(T.TypedDict):
|
|||||||
class AgentResponse:
|
class AgentResponse:
|
||||||
type: str
|
type: str
|
||||||
data: AgentResponseData
|
data: AgentResponseData
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class AgentStats:
|
|
||||||
token_usage: TokenUsage = field(default_factory=TokenUsage)
|
|
||||||
start_time: float = 0.0
|
|
||||||
end_time: float = 0.0
|
|
||||||
time_to_first_token: float = 0.0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def duration(self) -> float:
|
|
||||||
return self.end_time - self.start_time
|
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
return {
|
|
||||||
"token_usage": self.token_usage.__dict__,
|
|
||||||
"start_time": self.start_time,
|
|
||||||
"end_time": self.end_time,
|
|
||||||
"time_to_first_token": self.time_to_first_token,
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from .message import Message
|
|||||||
TContext = TypeVar("TContext", default=Any)
|
TContext = TypeVar("TContext", default=Any)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass(config={"arbitrary_types_allowed": True})
|
||||||
class ContextWrapper(Generic[TContext]):
|
class ContextWrapper(Generic[TContext]):
|
||||||
"""A context for running an agent, which can be used to pass additional data or state."""
|
"""A context for running an agent, which can be used to pass additional data or state."""
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
import sys
|
import sys
|
||||||
import time
|
|
||||||
import traceback
|
import traceback
|
||||||
import typing as T
|
import typing as T
|
||||||
|
|
||||||
@@ -13,7 +12,6 @@ from mcp.types import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.core.message.components import Json
|
|
||||||
from astrbot.core.message.message_event_result import (
|
from astrbot.core.message.message_event_result import (
|
||||||
MessageChain,
|
MessageChain,
|
||||||
)
|
)
|
||||||
@@ -26,7 +24,7 @@ from astrbot.core.provider.provider import Provider
|
|||||||
|
|
||||||
from ..hooks import BaseAgentRunHooks
|
from ..hooks import BaseAgentRunHooks
|
||||||
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
|
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
|
||||||
from ..response import AgentResponseData, AgentStats
|
from ..response import AgentResponseData
|
||||||
from ..run_context import ContextWrapper, TContext
|
from ..run_context import ContextWrapper, TContext
|
||||||
from ..tool_executor import BaseFunctionToolExecutor
|
from ..tool_executor import BaseFunctionToolExecutor
|
||||||
from .base import AgentResponse, AgentState, BaseAgentRunner
|
from .base import AgentResponse, AgentState, BaseAgentRunner
|
||||||
@@ -71,9 +69,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
)
|
)
|
||||||
self.run_context.messages = messages
|
self.run_context.messages = messages
|
||||||
|
|
||||||
self.stats = AgentStats()
|
|
||||||
self.stats.start_time = time.time()
|
|
||||||
|
|
||||||
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
|
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
|
||||||
"""Yields chunks *and* a final LLMResponse."""
|
"""Yields chunks *and* a final LLMResponse."""
|
||||||
if self.streaming:
|
if self.streaming:
|
||||||
@@ -103,10 +98,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
|
|
||||||
async for llm_response in self._iter_llm_responses():
|
async for llm_response in self._iter_llm_responses():
|
||||||
if llm_response.is_chunk:
|
if llm_response.is_chunk:
|
||||||
# update ttft
|
|
||||||
if self.stats.time_to_first_token == 0:
|
|
||||||
self.stats.time_to_first_token = time.time() - self.stats.start_time
|
|
||||||
|
|
||||||
if llm_response.result_chain:
|
if llm_response.result_chain:
|
||||||
yield AgentResponse(
|
yield AgentResponse(
|
||||||
type="streaming_delta",
|
type="streaming_delta",
|
||||||
@@ -130,10 +121,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
llm_resp_result = llm_response
|
llm_resp_result = llm_response
|
||||||
|
|
||||||
if not llm_response.is_chunk and llm_response.usage:
|
|
||||||
# only count the token usage of the final response for computation purpose
|
|
||||||
self.stats.token_usage += llm_response.usage
|
|
||||||
break # got final response
|
break # got final response
|
||||||
|
|
||||||
if not llm_resp_result:
|
if not llm_resp_result:
|
||||||
@@ -145,7 +132,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
if llm_resp.role == "err":
|
if llm_resp.role == "err":
|
||||||
# 如果 LLM 响应错误,转换到错误状态
|
# 如果 LLM 响应错误,转换到错误状态
|
||||||
self.final_llm_resp = llm_resp
|
self.final_llm_resp = llm_resp
|
||||||
self.stats.end_time = time.time()
|
|
||||||
self._transition_state(AgentState.ERROR)
|
self._transition_state(AgentState.ERROR)
|
||||||
yield AgentResponse(
|
yield AgentResponse(
|
||||||
type="err",
|
type="err",
|
||||||
@@ -160,7 +146,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
# 如果没有工具调用,转换到完成状态
|
# 如果没有工具调用,转换到完成状态
|
||||||
self.final_llm_resp = llm_resp
|
self.final_llm_resp = llm_resp
|
||||||
self._transition_state(AgentState.DONE)
|
self._transition_state(AgentState.DONE)
|
||||||
self.stats.end_time = time.time()
|
|
||||||
# record the final assistant message
|
# record the final assistant message
|
||||||
self.run_context.messages.append(
|
self.run_context.messages.append(
|
||||||
Message(
|
Message(
|
||||||
@@ -190,19 +175,22 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
# 如果有工具调用,还需处理工具调用
|
# 如果有工具调用,还需处理工具调用
|
||||||
if llm_resp.tools_call_name:
|
if llm_resp.tools_call_name:
|
||||||
tool_call_result_blocks = []
|
tool_call_result_blocks = []
|
||||||
|
for tool_call_name in llm_resp.tools_call_name:
|
||||||
|
yield AgentResponse(
|
||||||
|
type="tool_call",
|
||||||
|
data=AgentResponseData(
|
||||||
|
chain=MessageChain(type="tool_call").message(
|
||||||
|
f"🔨 调用工具: {tool_call_name}"
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
async for result in self._handle_function_tools(self.req, llm_resp):
|
async for result in self._handle_function_tools(self.req, llm_resp):
|
||||||
if isinstance(result, list):
|
if isinstance(result, list):
|
||||||
tool_call_result_blocks = result
|
tool_call_result_blocks = result
|
||||||
elif isinstance(result, MessageChain):
|
elif isinstance(result, MessageChain):
|
||||||
if result.type is None:
|
result.type = "tool_call_result"
|
||||||
# should not happen
|
|
||||||
continue
|
|
||||||
if result.type == "tool_direct_result":
|
|
||||||
ar_type = "tool_call_result"
|
|
||||||
else:
|
|
||||||
ar_type = result.type
|
|
||||||
yield AgentResponse(
|
yield AgentResponse(
|
||||||
type=ar_type,
|
type="tool_call_result",
|
||||||
data=AgentResponseData(chain=result),
|
data=AgentResponseData(chain=result),
|
||||||
)
|
)
|
||||||
# 将结果添加到上下文中
|
# 将结果添加到上下文中
|
||||||
@@ -245,19 +233,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
llm_response.tools_call_args,
|
llm_response.tools_call_args,
|
||||||
llm_response.tools_call_ids,
|
llm_response.tools_call_ids,
|
||||||
):
|
):
|
||||||
yield MessageChain(
|
|
||||||
type="tool_call",
|
|
||||||
chain=[
|
|
||||||
Json(
|
|
||||||
data={
|
|
||||||
"id": func_tool_id,
|
|
||||||
"name": func_tool_name,
|
|
||||||
"args": func_tool_args,
|
|
||||||
"ts": time.time(),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
if not req.func_tool:
|
if not req.func_tool:
|
||||||
return
|
return
|
||||||
@@ -331,6 +306,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
content=res.content[0].text,
|
content=res.content[0].text,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
yield MessageChain().message(res.content[0].text)
|
||||||
elif isinstance(res.content[0], ImageContent):
|
elif isinstance(res.content[0], ImageContent):
|
||||||
tool_call_result_blocks.append(
|
tool_call_result_blocks.append(
|
||||||
ToolCallMessageSegment(
|
ToolCallMessageSegment(
|
||||||
@@ -352,6 +328,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
content=resource.text,
|
content=resource.text,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
yield MessageChain().message(resource.text)
|
||||||
elif (
|
elif (
|
||||||
isinstance(resource, BlobResourceContents)
|
isinstance(resource, BlobResourceContents)
|
||||||
and resource.mimeType
|
and resource.mimeType
|
||||||
@@ -375,22 +352,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
content="返回的数据类型不受支持",
|
content="返回的数据类型不受支持",
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
yield MessageChain().message("返回的数据类型不受支持。")
|
||||||
# yield the last tool call result
|
|
||||||
if tool_call_result_blocks:
|
|
||||||
last_tcr_content = str(tool_call_result_blocks[-1].content)
|
|
||||||
yield MessageChain(
|
|
||||||
type="tool_call_result",
|
|
||||||
chain=[
|
|
||||||
Json(
|
|
||||||
data={
|
|
||||||
"id": func_tool_id,
|
|
||||||
"ts": time.time(),
|
|
||||||
"result": last_tcr_content,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
elif resp is None:
|
elif resp is None:
|
||||||
# Tool 直接请求发送消息给用户
|
# Tool 直接请求发送消息给用户
|
||||||
@@ -400,7 +362,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
|||||||
f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户,此工具调用不会被记录到历史中。"
|
f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户,此工具调用不会被记录到历史中。"
|
||||||
)
|
)
|
||||||
self._transition_state(AgentState.DONE)
|
self._transition_state(AgentState.DONE)
|
||||||
self.stats.end_time = time.time()
|
|
||||||
else:
|
else:
|
||||||
# 不应该出现其他类型
|
# 不应该出现其他类型
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|||||||
@@ -6,10 +6,8 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
|||||||
from astrbot.core.star.context import Context
|
from astrbot.core.star.context import Context
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass(config={"arbitrary_types_allowed": True})
|
||||||
class AstrAgentContext:
|
class AstrAgentContext:
|
||||||
__pydantic_config__ = {"arbitrary_types_allowed": True}
|
|
||||||
|
|
||||||
context: Context
|
context: Context
|
||||||
"""The star context instance"""
|
"""The star context instance"""
|
||||||
event: AstrMessageEvent
|
event: AstrMessageEvent
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ from collections.abc import AsyncGenerator
|
|||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
|
||||||
from astrbot.core.astr_agent_context import AstrAgentContext
|
from astrbot.core.astr_agent_context import AstrAgentContext
|
||||||
from astrbot.core.message.components import Json
|
|
||||||
from astrbot.core.message.message_event_result import (
|
from astrbot.core.message.message_event_result import (
|
||||||
MessageChain,
|
MessageChain,
|
||||||
MessageEventResult,
|
MessageEventResult,
|
||||||
@@ -34,27 +33,16 @@ async def run_agent(
|
|||||||
msg_chain = resp.data["chain"]
|
msg_chain = resp.data["chain"]
|
||||||
if msg_chain.type == "tool_direct_result":
|
if msg_chain.type == "tool_direct_result":
|
||||||
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
|
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
|
||||||
await astr_event.send(msg_chain)
|
await astr_event.send(resp.data["chain"])
|
||||||
continue
|
continue
|
||||||
if astr_event.get_platform_id() == "webchat":
|
|
||||||
await astr_event.send(msg_chain)
|
|
||||||
# 对于其他情况,暂时先不处理
|
# 对于其他情况,暂时先不处理
|
||||||
continue
|
continue
|
||||||
elif resp.type == "tool_call":
|
elif resp.type == "tool_call":
|
||||||
if agent_runner.streaming:
|
if agent_runner.streaming:
|
||||||
# 用来标记流式响应需要分节
|
# 用来标记流式响应需要分节
|
||||||
yield MessageChain(chain=[], type="break")
|
yield MessageChain(chain=[], type="break")
|
||||||
|
if show_tool_use:
|
||||||
if astr_event.get_platform_name() == "webchat":
|
|
||||||
await astr_event.send(resp.data["chain"])
|
await astr_event.send(resp.data["chain"])
|
||||||
elif show_tool_use:
|
|
||||||
json_comp = resp.data["chain"].chain[0]
|
|
||||||
if isinstance(json_comp, Json):
|
|
||||||
m = f"🔨 调用工具: {json_comp.data.get('name')}"
|
|
||||||
else:
|
|
||||||
m = "🔨 调用工具..."
|
|
||||||
chain = MessageChain(type="tool_call").message(m)
|
|
||||||
await astr_event.send(chain)
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if stream_to_general and resp.type == "streaming_delta":
|
if stream_to_general and resp.type == "streaming_delta":
|
||||||
@@ -81,15 +69,6 @@ async def run_agent(
|
|||||||
continue
|
continue
|
||||||
yield resp.data["chain"] # MessageChain
|
yield resp.data["chain"] # MessageChain
|
||||||
if agent_runner.done():
|
if agent_runner.done():
|
||||||
# send agent stats to webchat
|
|
||||||
if astr_event.get_platform_name() == "webchat":
|
|
||||||
await astr_event.send(
|
|
||||||
MessageChain(
|
|
||||||
type="agent_stats",
|
|
||||||
chain=[Json(data=agent_runner.stats.to_dict())],
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import os
|
|||||||
|
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
VERSION = "4.9.2"
|
VERSION = "4.8.0"
|
||||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||||
|
|
||||||
WEBHOOK_SUPPORTED_PLATFORMS = [
|
WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||||
@@ -108,7 +108,6 @@ DEFAULT_CONFIG = {
|
|||||||
"provider_id": "",
|
"provider_id": "",
|
||||||
"dual_output": False,
|
"dual_output": False,
|
||||||
"use_file_service": False,
|
"use_file_service": False,
|
||||||
"trigger_probability": 1.0,
|
|
||||||
},
|
},
|
||||||
"provider_ltm_settings": {
|
"provider_ltm_settings": {
|
||||||
"group_icl_enable": False,
|
"group_icl_enable": False,
|
||||||
@@ -209,7 +208,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"callback_server_host": "0.0.0.0",
|
"callback_server_host": "0.0.0.0",
|
||||||
"port": 6196,
|
"port": 6196,
|
||||||
},
|
},
|
||||||
"OneBot v11": {
|
"QQ 个人号(OneBot v11)": {
|
||||||
"id": "default",
|
"id": "default",
|
||||||
"type": "aiocqhttp",
|
"type": "aiocqhttp",
|
||||||
"enable": False,
|
"enable": False,
|
||||||
@@ -946,7 +945,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
"model_config": {
|
"model_config": {
|
||||||
"model": "gemini-3-flash-preview",
|
"model": "gemini-1.5-flash",
|
||||||
"temperature": 0.4,
|
"temperature": 0.4,
|
||||||
},
|
},
|
||||||
"custom_headers": {},
|
"custom_headers": {},
|
||||||
@@ -963,7 +962,7 @@ CONFIG_METADATA_2 = {
|
|||||||
"api_base": "https://generativelanguage.googleapis.com/",
|
"api_base": "https://generativelanguage.googleapis.com/",
|
||||||
"timeout": 120,
|
"timeout": 120,
|
||||||
"model_config": {
|
"model_config": {
|
||||||
"model": "gemini-3-flash-preview",
|
"model": "gemini-2.0-flash-exp",
|
||||||
"temperature": 0.4,
|
"temperature": 0.4,
|
||||||
},
|
},
|
||||||
"gm_resp_image_modal": False,
|
"gm_resp_image_modal": False,
|
||||||
@@ -976,7 +975,9 @@ CONFIG_METADATA_2 = {
|
|||||||
"sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE",
|
"sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
"dangerous_content": "BLOCK_MEDIUM_AND_ABOVE",
|
"dangerous_content": "BLOCK_MEDIUM_AND_ABOVE",
|
||||||
},
|
},
|
||||||
"gm_thinking_config": {"budget": 0, "level": "HIGH"},
|
"gm_thinking_config": {
|
||||||
|
"budget": 0,
|
||||||
|
},
|
||||||
"modalities": ["text", "image", "tool_use"],
|
"modalities": ["text", "image", "tool_use"],
|
||||||
},
|
},
|
||||||
"DeepSeek": {
|
"DeepSeek": {
|
||||||
@@ -1817,24 +1818,13 @@ CONFIG_METADATA_2 = {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
"gm_thinking_config": {
|
"gm_thinking_config": {
|
||||||
"description": "Thinking Config",
|
"description": "Gemini思考设置",
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"items": {
|
"items": {
|
||||||
"budget": {
|
"budget": {
|
||||||
"description": "Thinking Budget",
|
"description": "思考预算",
|
||||||
"type": "int",
|
"type": "int",
|
||||||
"hint": "Guides the model on the specific number of thinking tokens to use for reasoning. See: https://ai.google.dev/gemini-api/docs/thinking#set-budget",
|
"hint": "模型应该生成的思考Token的数量,设为0关闭思考。除gemini-2.5-flash外的模型会静默忽略此参数。",
|
||||||
},
|
|
||||||
"level": {
|
|
||||||
"description": "Thinking Level",
|
|
||||||
"type": "string",
|
|
||||||
"hint": "Recommended for Gemini 3 models and onwards, lets you control reasoning behavior.See: https://ai.google.dev/gemini-api/docs/thinking#thinking-levels",
|
|
||||||
"options": [
|
|
||||||
"MINIMAL",
|
|
||||||
"LOW",
|
|
||||||
"MEDIUM",
|
|
||||||
"HIGH",
|
|
||||||
],
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -2219,9 +2209,6 @@ CONFIG_METADATA_2 = {
|
|||||||
"use_file_service": {
|
"use_file_service": {
|
||||||
"type": "bool",
|
"type": "bool",
|
||||||
},
|
},
|
||||||
"trigger_probability": {
|
|
||||||
"type": "float",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"provider_ltm_settings": {
|
"provider_ltm_settings": {
|
||||||
@@ -2432,14 +2419,6 @@ CONFIG_METADATA_3 = {
|
|||||||
"provider_tts_settings.enable": True,
|
"provider_tts_settings.enable": True,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"provider_tts_settings.trigger_probability": {
|
|
||||||
"description": "TTS 触发概率",
|
|
||||||
"type": "float",
|
|
||||||
"slider": {"min": 0, "max": 1, "step": 0.05},
|
|
||||||
"condition": {
|
|
||||||
"provider_tts_settings.enable": True,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
"provider_settings.image_caption_prompt": {
|
"provider_settings.image_caption_prompt": {
|
||||||
"description": "图片转述提示词",
|
"description": "图片转述提示词",
|
||||||
"type": "text",
|
"type": "text",
|
||||||
@@ -3007,7 +2986,6 @@ CONFIG_METADATA_3 = {
|
|||||||
"description": "回复概率",
|
"description": "回复概率",
|
||||||
"type": "float",
|
"type": "float",
|
||||||
"hint": "0.0-1.0 之间的数值",
|
"hint": "0.0-1.0 之间的数值",
|
||||||
"slider": {"min": 0, "max": 1, "step": 0.05},
|
|
||||||
"condition": {
|
"condition": {
|
||||||
"provider_ltm_settings.active_reply.enable": True,
|
"provider_ltm_settings.active_reply.enable": True,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -79,7 +79,6 @@ class ConfigMetadataI18n:
|
|||||||
"_special",
|
"_special",
|
||||||
"invisible",
|
"invisible",
|
||||||
"options",
|
"options",
|
||||||
"slider",
|
|
||||||
]:
|
]:
|
||||||
if attr in field_data:
|
if attr in field_data:
|
||||||
field_result[attr] = field_data[attr]
|
field_result[attr] = field_data[attr]
|
||||||
|
|||||||
@@ -9,8 +9,6 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_asyn
|
|||||||
|
|
||||||
from astrbot.core.db.po import (
|
from astrbot.core.db.po import (
|
||||||
Attachment,
|
Attachment,
|
||||||
CommandConfig,
|
|
||||||
CommandConflict,
|
|
||||||
ConversationV2,
|
ConversationV2,
|
||||||
Persona,
|
Persona,
|
||||||
PlatformMessageHistory,
|
PlatformMessageHistory,
|
||||||
@@ -316,76 +314,6 @@ class BaseDatabase(abc.ABC):
|
|||||||
"""Clear all preferences for a specific scope ID."""
|
"""Clear all preferences for a specific scope ID."""
|
||||||
...
|
...
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def get_command_configs(self) -> list[CommandConfig]:
|
|
||||||
"""Get all stored command configurations."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def get_command_config(self, handler_full_name: str) -> CommandConfig | None:
|
|
||||||
"""Fetch a single command configuration by handler."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def upsert_command_config(
|
|
||||||
self,
|
|
||||||
handler_full_name: str,
|
|
||||||
plugin_name: str,
|
|
||||||
module_path: str,
|
|
||||||
original_command: str,
|
|
||||||
*,
|
|
||||||
resolved_command: str | None = None,
|
|
||||||
enabled: bool | None = None,
|
|
||||||
keep_original_alias: bool | None = None,
|
|
||||||
conflict_key: str | None = None,
|
|
||||||
resolution_strategy: str | None = None,
|
|
||||||
note: str | None = None,
|
|
||||||
extra_data: dict | None = None,
|
|
||||||
auto_managed: bool | None = None,
|
|
||||||
) -> CommandConfig:
|
|
||||||
"""Create or update a command configuration."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def delete_command_config(self, handler_full_name: str) -> None:
|
|
||||||
"""Delete a single command configuration."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def delete_command_configs(self, handler_full_names: list[str]) -> None:
|
|
||||||
"""Bulk delete command configurations."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def list_command_conflicts(
|
|
||||||
self,
|
|
||||||
status: str | None = None,
|
|
||||||
) -> list[CommandConflict]:
|
|
||||||
"""List recorded command conflict entries."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def upsert_command_conflict(
|
|
||||||
self,
|
|
||||||
conflict_key: str,
|
|
||||||
handler_full_name: str,
|
|
||||||
plugin_name: str,
|
|
||||||
*,
|
|
||||||
status: str | None = None,
|
|
||||||
resolution: str | None = None,
|
|
||||||
resolved_command: str | None = None,
|
|
||||||
note: str | None = None,
|
|
||||||
extra_data: dict | None = None,
|
|
||||||
auto_generated: bool | None = None,
|
|
||||||
) -> CommandConflict:
|
|
||||||
"""Create or update a conflict record."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
async def delete_command_conflicts(self, ids: list[int]) -> None:
|
|
||||||
"""Delete conflict records."""
|
|
||||||
...
|
|
||||||
|
|
||||||
# @abc.abstractmethod
|
# @abc.abstractmethod
|
||||||
# async def insert_llm_message(
|
# async def insert_llm_message(
|
||||||
# self,
|
# self,
|
||||||
|
|||||||
@@ -234,65 +234,6 @@ class Attachment(SQLModel, table=True):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class CommandConfig(SQLModel, table=True):
|
|
||||||
"""Per-command configuration overrides for dashboard management."""
|
|
||||||
|
|
||||||
__tablename__ = "command_configs" # type: ignore
|
|
||||||
|
|
||||||
handler_full_name: str = Field(
|
|
||||||
primary_key=True,
|
|
||||||
max_length=512,
|
|
||||||
)
|
|
||||||
plugin_name: str = Field(nullable=False, max_length=255)
|
|
||||||
module_path: str = Field(nullable=False, max_length=255)
|
|
||||||
original_command: str = Field(nullable=False, max_length=255)
|
|
||||||
resolved_command: str | None = Field(default=None, max_length=255)
|
|
||||||
enabled: bool = Field(default=True, nullable=False)
|
|
||||||
keep_original_alias: bool = Field(default=False, nullable=False)
|
|
||||||
conflict_key: str | None = Field(default=None, max_length=255)
|
|
||||||
resolution_strategy: str | None = Field(default=None, max_length=64)
|
|
||||||
note: str | None = Field(default=None, sa_type=Text)
|
|
||||||
extra_data: dict | None = Field(default=None, sa_type=JSON)
|
|
||||||
auto_managed: bool = Field(default=False, nullable=False)
|
|
||||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
|
||||||
updated_at: datetime = Field(
|
|
||||||
default_factory=lambda: datetime.now(timezone.utc),
|
|
||||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CommandConflict(SQLModel, table=True):
|
|
||||||
"""Conflict tracking for duplicated command names."""
|
|
||||||
|
|
||||||
__tablename__ = "command_conflicts" # type: ignore
|
|
||||||
|
|
||||||
id: int | None = Field(
|
|
||||||
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
|
|
||||||
)
|
|
||||||
conflict_key: str = Field(nullable=False, max_length=255)
|
|
||||||
handler_full_name: str = Field(nullable=False, max_length=512)
|
|
||||||
plugin_name: str = Field(nullable=False, max_length=255)
|
|
||||||
status: str = Field(default="pending", max_length=32)
|
|
||||||
resolution: str | None = Field(default=None, max_length=64)
|
|
||||||
resolved_command: str | None = Field(default=None, max_length=255)
|
|
||||||
note: str | None = Field(default=None, sa_type=Text)
|
|
||||||
extra_data: dict | None = Field(default=None, sa_type=JSON)
|
|
||||||
auto_generated: bool = Field(default=False, nullable=False)
|
|
||||||
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
|
|
||||||
updated_at: datetime = Field(
|
|
||||||
default_factory=lambda: datetime.now(timezone.utc),
|
|
||||||
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
|
|
||||||
)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
UniqueConstraint(
|
|
||||||
"conflict_key",
|
|
||||||
"handler_full_name",
|
|
||||||
name="uix_conflict_handler",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Conversation:
|
class Conversation:
|
||||||
"""LLM 对话类
|
"""LLM 对话类
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import threading
|
import threading
|
||||||
import typing as T
|
import typing as T
|
||||||
from collections.abc import Awaitable, Callable
|
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
from sqlalchemy import CursorResult
|
from sqlalchemy import CursorResult
|
||||||
@@ -11,8 +10,6 @@ from sqlmodel import col, delete, desc, func, or_, select, text, update
|
|||||||
from astrbot.core.db import BaseDatabase
|
from astrbot.core.db import BaseDatabase
|
||||||
from astrbot.core.db.po import (
|
from astrbot.core.db.po import (
|
||||||
Attachment,
|
Attachment,
|
||||||
CommandConfig,
|
|
||||||
CommandConflict,
|
|
||||||
ConversationV2,
|
ConversationV2,
|
||||||
Persona,
|
Persona,
|
||||||
PlatformMessageHistory,
|
PlatformMessageHistory,
|
||||||
@@ -29,7 +26,6 @@ from astrbot.core.db.po import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
NOT_GIVEN = T.TypeVar("NOT_GIVEN")
|
||||||
TxResult = T.TypeVar("TxResult")
|
|
||||||
|
|
||||||
|
|
||||||
class SQLiteDatabase(BaseDatabase):
|
class SQLiteDatabase(BaseDatabase):
|
||||||
@@ -674,242 +670,6 @@ class SQLiteDatabase(BaseDatabase):
|
|||||||
)
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
# ====
|
|
||||||
# Command Configuration & Conflict Tracking
|
|
||||||
# ====
|
|
||||||
|
|
||||||
async def _run_in_tx(
|
|
||||||
self,
|
|
||||||
fn: Callable[[AsyncSession], Awaitable[TxResult]],
|
|
||||||
) -> TxResult:
|
|
||||||
async with self.get_db() as session:
|
|
||||||
session: AsyncSession
|
|
||||||
async with session.begin():
|
|
||||||
return await fn(session)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _apply_updates(model, **updates) -> None:
|
|
||||||
for field, value in updates.items():
|
|
||||||
if value is not None:
|
|
||||||
setattr(model, field, value)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _new_command_config(
|
|
||||||
handler_full_name: str,
|
|
||||||
plugin_name: str,
|
|
||||||
module_path: str,
|
|
||||||
original_command: str,
|
|
||||||
*,
|
|
||||||
resolved_command: str | None = None,
|
|
||||||
enabled: bool | None = None,
|
|
||||||
keep_original_alias: bool | None = None,
|
|
||||||
conflict_key: str | None = None,
|
|
||||||
resolution_strategy: str | None = None,
|
|
||||||
note: str | None = None,
|
|
||||||
extra_data: dict | None = None,
|
|
||||||
auto_managed: bool | None = None,
|
|
||||||
) -> CommandConfig:
|
|
||||||
return CommandConfig(
|
|
||||||
handler_full_name=handler_full_name,
|
|
||||||
plugin_name=plugin_name,
|
|
||||||
module_path=module_path,
|
|
||||||
original_command=original_command,
|
|
||||||
resolved_command=resolved_command,
|
|
||||||
enabled=True if enabled is None else enabled,
|
|
||||||
keep_original_alias=False
|
|
||||||
if keep_original_alias is None
|
|
||||||
else keep_original_alias,
|
|
||||||
conflict_key=conflict_key or original_command,
|
|
||||||
resolution_strategy=resolution_strategy,
|
|
||||||
note=note,
|
|
||||||
extra_data=extra_data,
|
|
||||||
auto_managed=bool(auto_managed),
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _new_command_conflict(
|
|
||||||
conflict_key: str,
|
|
||||||
handler_full_name: str,
|
|
||||||
plugin_name: str,
|
|
||||||
*,
|
|
||||||
status: str | None = None,
|
|
||||||
resolution: str | None = None,
|
|
||||||
resolved_command: str | None = None,
|
|
||||||
note: str | None = None,
|
|
||||||
extra_data: dict | None = None,
|
|
||||||
auto_generated: bool | None = None,
|
|
||||||
) -> CommandConflict:
|
|
||||||
return CommandConflict(
|
|
||||||
conflict_key=conflict_key,
|
|
||||||
handler_full_name=handler_full_name,
|
|
||||||
plugin_name=plugin_name,
|
|
||||||
status=status or "pending",
|
|
||||||
resolution=resolution,
|
|
||||||
resolved_command=resolved_command,
|
|
||||||
note=note,
|
|
||||||
extra_data=extra_data,
|
|
||||||
auto_generated=bool(auto_generated),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_command_configs(self) -> list[CommandConfig]:
|
|
||||||
async with self.get_db() as session:
|
|
||||||
session: AsyncSession
|
|
||||||
result = await session.execute(select(CommandConfig))
|
|
||||||
return list(result.scalars().all())
|
|
||||||
|
|
||||||
async def get_command_config(
|
|
||||||
self,
|
|
||||||
handler_full_name: str,
|
|
||||||
) -> CommandConfig | None:
|
|
||||||
async with self.get_db() as session:
|
|
||||||
session: AsyncSession
|
|
||||||
return await session.get(CommandConfig, handler_full_name)
|
|
||||||
|
|
||||||
async def upsert_command_config(
|
|
||||||
self,
|
|
||||||
handler_full_name: str,
|
|
||||||
plugin_name: str,
|
|
||||||
module_path: str,
|
|
||||||
original_command: str,
|
|
||||||
*,
|
|
||||||
resolved_command: str | None = None,
|
|
||||||
enabled: bool | None = None,
|
|
||||||
keep_original_alias: bool | None = None,
|
|
||||||
conflict_key: str | None = None,
|
|
||||||
resolution_strategy: str | None = None,
|
|
||||||
note: str | None = None,
|
|
||||||
extra_data: dict | None = None,
|
|
||||||
auto_managed: bool | None = None,
|
|
||||||
) -> CommandConfig:
|
|
||||||
async def _op(session: AsyncSession) -> CommandConfig:
|
|
||||||
config = await session.get(CommandConfig, handler_full_name)
|
|
||||||
if not config:
|
|
||||||
config = self._new_command_config(
|
|
||||||
handler_full_name,
|
|
||||||
plugin_name,
|
|
||||||
module_path,
|
|
||||||
original_command,
|
|
||||||
resolved_command=resolved_command,
|
|
||||||
enabled=enabled,
|
|
||||||
keep_original_alias=keep_original_alias,
|
|
||||||
conflict_key=conflict_key,
|
|
||||||
resolution_strategy=resolution_strategy,
|
|
||||||
note=note,
|
|
||||||
extra_data=extra_data,
|
|
||||||
auto_managed=auto_managed,
|
|
||||||
)
|
|
||||||
session.add(config)
|
|
||||||
else:
|
|
||||||
self._apply_updates(
|
|
||||||
config,
|
|
||||||
plugin_name=plugin_name,
|
|
||||||
module_path=module_path,
|
|
||||||
original_command=original_command,
|
|
||||||
resolved_command=resolved_command,
|
|
||||||
enabled=enabled,
|
|
||||||
keep_original_alias=keep_original_alias,
|
|
||||||
conflict_key=conflict_key,
|
|
||||||
resolution_strategy=resolution_strategy,
|
|
||||||
note=note,
|
|
||||||
extra_data=extra_data,
|
|
||||||
auto_managed=auto_managed,
|
|
||||||
)
|
|
||||||
await session.flush()
|
|
||||||
await session.refresh(config)
|
|
||||||
return config
|
|
||||||
|
|
||||||
return await self._run_in_tx(_op)
|
|
||||||
|
|
||||||
async def delete_command_config(self, handler_full_name: str) -> None:
|
|
||||||
await self.delete_command_configs([handler_full_name])
|
|
||||||
|
|
||||||
async def delete_command_configs(self, handler_full_names: list[str]) -> None:
|
|
||||||
if not handler_full_names:
|
|
||||||
return
|
|
||||||
|
|
||||||
async def _op(session: AsyncSession) -> None:
|
|
||||||
await session.execute(
|
|
||||||
delete(CommandConfig).where(
|
|
||||||
col(CommandConfig.handler_full_name).in_(handler_full_names),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
await self._run_in_tx(_op)
|
|
||||||
|
|
||||||
async def list_command_conflicts(
|
|
||||||
self,
|
|
||||||
status: str | None = None,
|
|
||||||
) -> list[CommandConflict]:
|
|
||||||
async with self.get_db() as session:
|
|
||||||
session: AsyncSession
|
|
||||||
query = select(CommandConflict)
|
|
||||||
if status:
|
|
||||||
query = query.where(CommandConflict.status == status)
|
|
||||||
result = await session.execute(query)
|
|
||||||
return list(result.scalars().all())
|
|
||||||
|
|
||||||
async def upsert_command_conflict(
|
|
||||||
self,
|
|
||||||
conflict_key: str,
|
|
||||||
handler_full_name: str,
|
|
||||||
plugin_name: str,
|
|
||||||
*,
|
|
||||||
status: str | None = None,
|
|
||||||
resolution: str | None = None,
|
|
||||||
resolved_command: str | None = None,
|
|
||||||
note: str | None = None,
|
|
||||||
extra_data: dict | None = None,
|
|
||||||
auto_generated: bool | None = None,
|
|
||||||
) -> CommandConflict:
|
|
||||||
async def _op(session: AsyncSession) -> CommandConflict:
|
|
||||||
result = await session.execute(
|
|
||||||
select(CommandConflict).where(
|
|
||||||
CommandConflict.conflict_key == conflict_key,
|
|
||||||
CommandConflict.handler_full_name == handler_full_name,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
record = result.scalar_one_or_none()
|
|
||||||
if not record:
|
|
||||||
record = self._new_command_conflict(
|
|
||||||
conflict_key,
|
|
||||||
handler_full_name,
|
|
||||||
plugin_name,
|
|
||||||
status=status,
|
|
||||||
resolution=resolution,
|
|
||||||
resolved_command=resolved_command,
|
|
||||||
note=note,
|
|
||||||
extra_data=extra_data,
|
|
||||||
auto_generated=auto_generated,
|
|
||||||
)
|
|
||||||
session.add(record)
|
|
||||||
else:
|
|
||||||
self._apply_updates(
|
|
||||||
record,
|
|
||||||
plugin_name=plugin_name,
|
|
||||||
status=status,
|
|
||||||
resolution=resolution,
|
|
||||||
resolved_command=resolved_command,
|
|
||||||
note=note,
|
|
||||||
extra_data=extra_data,
|
|
||||||
auto_generated=auto_generated,
|
|
||||||
)
|
|
||||||
await session.flush()
|
|
||||||
await session.refresh(record)
|
|
||||||
return record
|
|
||||||
|
|
||||||
return await self._run_in_tx(_op)
|
|
||||||
|
|
||||||
async def delete_command_conflicts(self, ids: list[int]) -> None:
|
|
||||||
if not ids:
|
|
||||||
return
|
|
||||||
|
|
||||||
async def _op(session: AsyncSession) -> None:
|
|
||||||
await session.execute(
|
|
||||||
delete(CommandConflict).where(col(CommandConflict.id).in_(ids)),
|
|
||||||
)
|
|
||||||
|
|
||||||
await self._run_in_tx(_op)
|
|
||||||
|
|
||||||
# ====
|
# ====
|
||||||
# Deprecated Methods
|
# Deprecated Methods
|
||||||
# ====
|
# ====
|
||||||
|
|||||||
@@ -24,7 +24,6 @@ import asyncio
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import time
|
|
||||||
from asyncio import Queue
|
from asyncio import Queue
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
|
||||||
@@ -149,7 +148,7 @@ class LogQueueHandler(logging.Handler):
|
|||||||
self.log_broker.publish(
|
self.log_broker.publish(
|
||||||
{
|
{
|
||||||
"level": record.levelname,
|
"level": record.levelname,
|
||||||
"time": time.time(),
|
"time": record.asctime,
|
||||||
"data": log_entry,
|
"data": log_entry,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -629,11 +629,12 @@ class Nodes(BaseMessageComponent):
|
|||||||
|
|
||||||
class Json(BaseMessageComponent):
|
class Json(BaseMessageComponent):
|
||||||
type = ComponentType.Json
|
type = ComponentType.Json
|
||||||
data: dict
|
data: str | dict
|
||||||
|
resid: int | None = 0
|
||||||
|
|
||||||
def __init__(self, data: str | dict, **_):
|
def __init__(self, data, **_):
|
||||||
if isinstance(data, str):
|
if isinstance(data, dict):
|
||||||
data = json.loads(data)
|
data = json.dumps(data)
|
||||||
super().__init__(data=data, **_)
|
super().__init__(data=data, **_)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -119,7 +119,7 @@ class RespondStage(Stage):
|
|||||||
|
|
||||||
if (result := event.get_result()) is None:
|
if (result := event.get_result()) is None:
|
||||||
return False
|
return False
|
||||||
if self.only_llm_result and not result.is_llm_result():
|
if self.only_llm_result and result.is_llm_result():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if event.get_platform_name() in [
|
if event.get_platform_name() in [
|
||||||
@@ -158,11 +158,7 @@ class RespondStage(Stage):
|
|||||||
result = event.get_result()
|
result = event.get_result()
|
||||||
if result is None:
|
if result is None:
|
||||||
return
|
return
|
||||||
if event.get_extra("_streaming_finished", False):
|
|
||||||
# prevent some plugin make result content type to LLM_RESULT after streaming finished, lead to send again
|
|
||||||
return
|
|
||||||
if result.result_content_type == ResultContentType.STREAMING_FINISH:
|
if result.result_content_type == ResultContentType.STREAMING_FINISH:
|
||||||
event.set_extra("_streaming_finished", True)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import random
|
|
||||||
import re
|
import re
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
@@ -43,18 +42,6 @@ class ResultDecorateStage(Stage):
|
|||||||
"forward_threshold"
|
"forward_threshold"
|
||||||
]
|
]
|
||||||
|
|
||||||
trigger_probability = ctx.astrbot_config["provider_tts_settings"].get(
|
|
||||||
"trigger_probability",
|
|
||||||
1,
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
self.tts_trigger_probability = max(
|
|
||||||
0.0,
|
|
||||||
min(float(trigger_probability), 1.0),
|
|
||||||
)
|
|
||||||
except (TypeError, ValueError):
|
|
||||||
self.tts_trigger_probability = 1.0
|
|
||||||
|
|
||||||
# 分段回复
|
# 分段回复
|
||||||
self.words_count_threshold = int(
|
self.words_count_threshold = int(
|
||||||
ctx.astrbot_config["platform_settings"]["segmented_reply"][
|
ctx.astrbot_config["platform_settings"]["segmented_reply"][
|
||||||
@@ -259,14 +246,7 @@ class ResultDecorateStage(Stage):
|
|||||||
and result.is_llm_result()
|
and result.is_llm_result()
|
||||||
and SessionServiceManager.should_process_tts_request(event)
|
and SessionServiceManager.should_process_tts_request(event)
|
||||||
):
|
):
|
||||||
should_tts = self.tts_trigger_probability >= 1.0 or (
|
if not tts_provider:
|
||||||
self.tts_trigger_probability > 0.0
|
|
||||||
and random.random() <= self.tts_trigger_probability
|
|
||||||
)
|
|
||||||
|
|
||||||
if not should_tts:
|
|
||||||
logger.debug("跳过 TTS:触发概率未命中。")
|
|
||||||
elif not tts_provider:
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。",
|
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -112,6 +112,10 @@ class PlatformManager:
|
|||||||
from .sources.satori.satori_adapter import (
|
from .sources.satori.satori_adapter import (
|
||||||
SatoriPlatformAdapter, # noqa: F401
|
SatoriPlatformAdapter, # noqa: F401
|
||||||
)
|
)
|
||||||
|
case "github_webhook":
|
||||||
|
from .sources.github_webhook.github_webhook_adapter import (
|
||||||
|
GitHubWebhookPlatformAdapter, # noqa: F401
|
||||||
|
)
|
||||||
except (ImportError, ModuleNotFoundError) as e:
|
except (ImportError, ModuleNotFoundError) as e:
|
||||||
logger.error(
|
logger.error(
|
||||||
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
|
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
|
||||||
|
|||||||
@@ -0,0 +1,315 @@
|
|||||||
|
import asyncio
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
|
from astrbot import logger
|
||||||
|
from astrbot.api.event import MessageChain
|
||||||
|
from astrbot.api.message_components import Plain
|
||||||
|
from astrbot.api.platform import (
|
||||||
|
AstrBotMessage,
|
||||||
|
MessageMember,
|
||||||
|
MessageType,
|
||||||
|
Platform,
|
||||||
|
PlatformMetadata,
|
||||||
|
)
|
||||||
|
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||||
|
from astrbot.core.platform.platform import PlatformStatus
|
||||||
|
from astrbot.core.utils.webhook_utils import log_webhook_info
|
||||||
|
|
||||||
|
from ...register import register_platform_adapter
|
||||||
|
from .github_webhook_event import GitHubWebhookMessageEvent
|
||||||
|
|
||||||
|
|
||||||
|
@register_platform_adapter(
|
||||||
|
"github_webhook",
|
||||||
|
"GitHub Webhook 适配器",
|
||||||
|
support_streaming_message=False,
|
||||||
|
)
|
||||||
|
class GitHubWebhookPlatformAdapter(Platform):
|
||||||
|
"""GitHub Webhook 平台适配器
|
||||||
|
|
||||||
|
支持的事件:
|
||||||
|
- issues (created)
|
||||||
|
- issue_comment (created)
|
||||||
|
- pull_request (opened)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
platform_config: dict,
|
||||||
|
platform_settings: dict,
|
||||||
|
event_queue: asyncio.Queue,
|
||||||
|
) -> None:
|
||||||
|
super().__init__(platform_config, event_queue)
|
||||||
|
|
||||||
|
self.unified_webhook_mode = platform_config.get("unified_webhook_mode", True)
|
||||||
|
self.webhook_secret = platform_config.get("webhook_secret", "")
|
||||||
|
self.shutdown_event = asyncio.Event()
|
||||||
|
|
||||||
|
async def send_by_session(
|
||||||
|
self,
|
||||||
|
session: MessageSesion,
|
||||||
|
message_chain: MessageChain,
|
||||||
|
):
|
||||||
|
"""GitHub Webhook 是单向接收,不支持主动发送消息"""
|
||||||
|
logger.warning("GitHub Webhook 适配器不支持 send_by_session")
|
||||||
|
|
||||||
|
def meta(self) -> PlatformMetadata:
|
||||||
|
return PlatformMetadata(
|
||||||
|
name="github_webhook",
|
||||||
|
description="GitHub Webhook 适配器",
|
||||||
|
id=cast(str, self.config.get("id")),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
"""运行适配器"""
|
||||||
|
self.status = PlatformStatus.RUNNING
|
||||||
|
|
||||||
|
# 如果启用统一 webhook 模式
|
||||||
|
webhook_uuid = self.config.get("webhook_uuid")
|
||||||
|
if self.unified_webhook_mode and webhook_uuid:
|
||||||
|
log_webhook_info(f"{self.meta().id}(GitHub Webhook)", webhook_uuid)
|
||||||
|
# 保持运行状态,等待 shutdown
|
||||||
|
await self.shutdown_event.wait()
|
||||||
|
else:
|
||||||
|
logger.warning("GitHub Webhook 适配器需要启用统一 webhook 模式")
|
||||||
|
await self.shutdown_event.wait()
|
||||||
|
|
||||||
|
async def webhook_callback(self, request: Any) -> Any:
|
||||||
|
"""统一 Webhook 回调入口
|
||||||
|
|
||||||
|
处理 GitHub webhook 事件
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Quart 请求对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
响应数据
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 获取事件类型
|
||||||
|
event_type = request.headers.get("X-GitHub-Event", "")
|
||||||
|
|
||||||
|
# 获取请求数据
|
||||||
|
payload = await request.json
|
||||||
|
|
||||||
|
# 验证 webhook 签名(如果配置了 secret)
|
||||||
|
if self.webhook_secret:
|
||||||
|
if not await self._verify_signature(request, payload):
|
||||||
|
logger.warning("GitHub webhook 签名验证失败")
|
||||||
|
return {"error": "Invalid signature"}, 401
|
||||||
|
|
||||||
|
logger.debug(f"收到 GitHub Webhook 事件: {event_type}")
|
||||||
|
|
||||||
|
# 处理不同类型的事件
|
||||||
|
if event_type == "issues":
|
||||||
|
await self._handle_issue_event(payload)
|
||||||
|
elif event_type == "issue_comment":
|
||||||
|
await self._handle_issue_comment_event(payload)
|
||||||
|
elif event_type == "pull_request":
|
||||||
|
await self._handle_pull_request_event(payload)
|
||||||
|
elif event_type == "ping":
|
||||||
|
# GitHub webhook 验证事件
|
||||||
|
return {"message": "pong"}
|
||||||
|
else:
|
||||||
|
logger.debug(f"忽略不支持的 GitHub 事件类型: {event_type}")
|
||||||
|
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"处理 GitHub webhook 回调时发生错误: {e}", exc_info=True)
|
||||||
|
return {"error": str(e)}, 500
|
||||||
|
|
||||||
|
async def _verify_signature(self, request: Any, payload: dict) -> bool:
|
||||||
|
"""验证 GitHub webhook 签名
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: Quart 请求对象
|
||||||
|
payload: 请求负载数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
签名是否有效
|
||||||
|
"""
|
||||||
|
signature_header = request.headers.get("X-Hub-Signature-256", "")
|
||||||
|
if not signature_header:
|
||||||
|
# 如果没有签名头,检查是否有旧版本的签名
|
||||||
|
signature_header = request.headers.get("X-Hub-Signature", "")
|
||||||
|
if not signature_header:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 获取原始请求体
|
||||||
|
body = await request.get_data()
|
||||||
|
|
||||||
|
# 计算 HMAC
|
||||||
|
if signature_header.startswith("sha256="):
|
||||||
|
expected_signature = hmac.new(
|
||||||
|
self.webhook_secret.encode("utf-8"),
|
||||||
|
body,
|
||||||
|
hashlib.sha256,
|
||||||
|
).hexdigest()
|
||||||
|
received_signature = signature_header.replace("sha256=", "")
|
||||||
|
elif signature_header.startswith("sha1="):
|
||||||
|
expected_signature = hmac.new(
|
||||||
|
self.webhook_secret.encode("utf-8"),
|
||||||
|
body,
|
||||||
|
hashlib.sha1,
|
||||||
|
).hexdigest()
|
||||||
|
received_signature = signature_header.replace("sha1=", "")
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 使用 hmac.compare_digest 防止时序攻击
|
||||||
|
return hmac.compare_digest(expected_signature, received_signature)
|
||||||
|
|
||||||
|
async def _handle_issue_event(self, payload: dict):
|
||||||
|
"""处理 issue 事件"""
|
||||||
|
action = payload.get("action", "")
|
||||||
|
|
||||||
|
# 只处理创建事件
|
||||||
|
if action != "created" and action != "opened":
|
||||||
|
return
|
||||||
|
|
||||||
|
issue = payload.get("issue", {})
|
||||||
|
repo = payload.get("repository", {})
|
||||||
|
sender = payload.get("sender", {})
|
||||||
|
|
||||||
|
# 构造消息文本
|
||||||
|
message_text = (
|
||||||
|
f"📝 新 Issue 创建\n"
|
||||||
|
f"仓库: {repo.get('full_name', 'unknown')}\n"
|
||||||
|
f"标题: {issue.get('title', 'No title')}\n"
|
||||||
|
f"作者: {sender.get('login', 'unknown')}\n"
|
||||||
|
f"链接: {issue.get('html_url', '')}\n"
|
||||||
|
f"内容:\n{issue.get('body', 'No description')[:200]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建 AstrBotMessage
|
||||||
|
abm = self._create_message(
|
||||||
|
message_text,
|
||||||
|
sender.get("login", "unknown"),
|
||||||
|
sender.get("login", "unknown"),
|
||||||
|
repo.get("full_name", "unknown"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 提交事件
|
||||||
|
self.commit_event(
|
||||||
|
GitHubWebhookMessageEvent(
|
||||||
|
message_text,
|
||||||
|
abm,
|
||||||
|
self.meta(),
|
||||||
|
repo.get("full_name", "unknown"),
|
||||||
|
"issues",
|
||||||
|
payload,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handle_issue_comment_event(self, payload: dict):
|
||||||
|
"""处理 issue 评论事件"""
|
||||||
|
action = payload.get("action", "")
|
||||||
|
|
||||||
|
# 只处理创建事件
|
||||||
|
if action != "created":
|
||||||
|
return
|
||||||
|
|
||||||
|
issue = payload.get("issue", {})
|
||||||
|
comment = payload.get("comment", {})
|
||||||
|
repo = payload.get("repository", {})
|
||||||
|
sender = payload.get("sender", {})
|
||||||
|
|
||||||
|
# 构造消息文本
|
||||||
|
message_text = (
|
||||||
|
f"💬 新 Issue 评论\n"
|
||||||
|
f"仓库: {repo.get('full_name', 'unknown')}\n"
|
||||||
|
f"Issue: {issue.get('title', 'No title')}\n"
|
||||||
|
f"评论者: {sender.get('login', 'unknown')}\n"
|
||||||
|
f"链接: {comment.get('html_url', '')}\n"
|
||||||
|
f"内容:\n{comment.get('body', 'No comment')[:200]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建 AstrBotMessage
|
||||||
|
abm = self._create_message(
|
||||||
|
message_text,
|
||||||
|
sender.get("login", "unknown"),
|
||||||
|
sender.get("login", "unknown"),
|
||||||
|
repo.get("full_name", "unknown"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 提交事件
|
||||||
|
self.commit_event(
|
||||||
|
GitHubWebhookMessageEvent(
|
||||||
|
message_text,
|
||||||
|
abm,
|
||||||
|
self.meta(),
|
||||||
|
repo.get("full_name", "unknown"),
|
||||||
|
"issue_comment",
|
||||||
|
payload,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handle_pull_request_event(self, payload: dict):
|
||||||
|
"""处理 pull request 事件"""
|
||||||
|
action = payload.get("action", "")
|
||||||
|
|
||||||
|
# 只处理打开事件
|
||||||
|
if action != "opened":
|
||||||
|
return
|
||||||
|
|
||||||
|
pr = payload.get("pull_request", {})
|
||||||
|
repo = payload.get("repository", {})
|
||||||
|
sender = payload.get("sender", {})
|
||||||
|
|
||||||
|
# 构造消息文本
|
||||||
|
message_text = (
|
||||||
|
f"🔀 新 Pull Request\n"
|
||||||
|
f"仓库: {repo.get('full_name', 'unknown')}\n"
|
||||||
|
f"标题: {pr.get('title', 'No title')}\n"
|
||||||
|
f"作者: {sender.get('login', 'unknown')}\n"
|
||||||
|
f"链接: {pr.get('html_url', '')}\n"
|
||||||
|
f"内容:\n{pr.get('body', 'No description')[:200]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建 AstrBotMessage
|
||||||
|
abm = self._create_message(
|
||||||
|
message_text,
|
||||||
|
sender.get("login", "unknown"),
|
||||||
|
sender.get("login", "unknown"),
|
||||||
|
repo.get("full_name", "unknown"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# 提交事件
|
||||||
|
self.commit_event(
|
||||||
|
GitHubWebhookMessageEvent(
|
||||||
|
message_text,
|
||||||
|
abm,
|
||||||
|
self.meta(),
|
||||||
|
repo.get("full_name", "unknown"),
|
||||||
|
"pull_request",
|
||||||
|
payload,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _create_message(
|
||||||
|
self,
|
||||||
|
message_text: str,
|
||||||
|
user_id: str,
|
||||||
|
nickname: str,
|
||||||
|
session_id: str,
|
||||||
|
) -> AstrBotMessage:
|
||||||
|
"""创建 AstrBotMessage 对象"""
|
||||||
|
abm = AstrBotMessage()
|
||||||
|
abm.type = MessageType.GROUP_MESSAGE
|
||||||
|
abm.self_id = self.client_self_id
|
||||||
|
abm.session_id = session_id
|
||||||
|
abm.message_id = ""
|
||||||
|
abm.sender = MessageMember(user_id=user_id, nickname=nickname)
|
||||||
|
abm.message = [Plain(message_text)]
|
||||||
|
abm.message_str = message_text
|
||||||
|
abm.raw_message = message_text
|
||||||
|
|
||||||
|
return abm
|
||||||
|
|
||||||
|
async def terminate(self):
|
||||||
|
"""终止适配器运行"""
|
||||||
|
self.shutdown_event.set()
|
||||||
|
logger.info("GitHub Webhook 适配器已经被优雅地关闭")
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||||
|
|
||||||
|
from ...astr_message_event import AstrMessageEvent
|
||||||
|
|
||||||
|
|
||||||
|
class GitHubWebhookMessageEvent(AstrMessageEvent):
|
||||||
|
"""GitHub Webhook 消息事件"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
message_str: str,
|
||||||
|
message_obj: AstrBotMessage,
|
||||||
|
platform_meta: PlatformMetadata,
|
||||||
|
session_id: str,
|
||||||
|
event_type: str,
|
||||||
|
event_data: dict,
|
||||||
|
):
|
||||||
|
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||||
|
self.event_type = event_type
|
||||||
|
"""GitHub 事件类型: issues, issue_comment, pull_request"""
|
||||||
|
self.event_data = event_data
|
||||||
|
"""原始事件数据"""
|
||||||
@@ -81,12 +81,7 @@ class LarkPlatformAdapter(Platform):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.lark_api = (
|
self.lark_api = (
|
||||||
lark.Client.builder()
|
lark.Client.builder().app_id(self.appid).app_secret(self.appsecret).build()
|
||||||
.app_id(self.appid)
|
|
||||||
.app_secret(self.appsecret)
|
|
||||||
.log_level(lark.LogLevel.ERROR)
|
|
||||||
.domain(self.domain)
|
|
||||||
.build()
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self.webhook_server = None
|
self.webhook_server = None
|
||||||
|
|||||||
@@ -200,15 +200,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
|||||||
if isinstance(chain, MessageChain):
|
if isinstance(chain, MessageChain):
|
||||||
if chain.type == "break":
|
if chain.type == "break":
|
||||||
# 分割符
|
# 分割符
|
||||||
if message_id:
|
|
||||||
try:
|
|
||||||
await self.client.edit_message_text(
|
|
||||||
text=delta,
|
|
||||||
chat_id=payload["chat_id"],
|
|
||||||
message_id=message_id,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"编辑消息失败(streaming-break): {e!s}")
|
|
||||||
message_id = None # 重置消息 ID
|
message_id = None # 重置消息 ID
|
||||||
delta = "" # 重置 delta
|
delta = "" # 重置 delta
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -1,12 +1,11 @@
|
|||||||
import base64
|
import base64
|
||||||
import json
|
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from astrbot.api import logger
|
from astrbot.api import logger
|
||||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||||
from astrbot.api.message_components import File, Image, Json, Plain, Record
|
from astrbot.api.message_components import File, Image, Plain, Record
|
||||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||||
|
|
||||||
from .webchat_queue_mgr import webchat_queue_mgr
|
from .webchat_queue_mgr import webchat_queue_mgr
|
||||||
@@ -42,20 +41,12 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|||||||
await web_chat_back_queue.put(
|
await web_chat_back_queue.put(
|
||||||
{
|
{
|
||||||
"type": "plain",
|
"type": "plain",
|
||||||
|
"cid": cid,
|
||||||
"data": data,
|
"data": data,
|
||||||
"streaming": streaming,
|
"streaming": streaming,
|
||||||
"chain_type": message.type,
|
"chain_type": message.type,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
elif isinstance(comp, Json):
|
|
||||||
await web_chat_back_queue.put(
|
|
||||||
{
|
|
||||||
"type": "plain",
|
|
||||||
"data": json.dumps(comp.data, ensure_ascii=False),
|
|
||||||
"streaming": streaming,
|
|
||||||
"chain_type": message.type,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
elif isinstance(comp, Image):
|
elif isinstance(comp, Image):
|
||||||
# save image to local
|
# save image to local
|
||||||
filename = f"{str(uuid.uuid4())}.jpg"
|
filename = f"{str(uuid.uuid4())}.jpg"
|
||||||
@@ -67,6 +58,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|||||||
await web_chat_back_queue.put(
|
await web_chat_back_queue.put(
|
||||||
{
|
{
|
||||||
"type": "image",
|
"type": "image",
|
||||||
|
"cid": cid,
|
||||||
"data": data,
|
"data": data,
|
||||||
"streaming": streaming,
|
"streaming": streaming,
|
||||||
},
|
},
|
||||||
@@ -82,6 +74,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|||||||
await web_chat_back_queue.put(
|
await web_chat_back_queue.put(
|
||||||
{
|
{
|
||||||
"type": "record",
|
"type": "record",
|
||||||
|
"cid": cid,
|
||||||
"data": data,
|
"data": data,
|
||||||
"streaming": streaming,
|
"streaming": streaming,
|
||||||
},
|
},
|
||||||
@@ -98,6 +91,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|||||||
await web_chat_back_queue.put(
|
await web_chat_back_queue.put(
|
||||||
{
|
{
|
||||||
"type": "file",
|
"type": "file",
|
||||||
|
"cid": cid,
|
||||||
"data": data,
|
"data": data,
|
||||||
"streaming": streaming,
|
"streaming": streaming,
|
||||||
},
|
},
|
||||||
@@ -117,17 +111,18 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|||||||
cid = self.session_id.split("!")[-1]
|
cid = self.session_id.split("!")[-1]
|
||||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
||||||
async for chain in generator:
|
async for chain in generator:
|
||||||
# if chain.type == "break" and final_data:
|
if chain.type == "break" and final_data:
|
||||||
# # 分割符
|
# 分割符
|
||||||
# await web_chat_back_queue.put(
|
await web_chat_back_queue.put(
|
||||||
# {
|
{
|
||||||
# "type": "break", # break means a segment end
|
"type": "break", # break means a segment end
|
||||||
# "data": final_data,
|
"data": final_data,
|
||||||
# "streaming": True,
|
"streaming": True,
|
||||||
# },
|
"cid": cid,
|
||||||
# )
|
},
|
||||||
# final_data = ""
|
)
|
||||||
# continue
|
final_data = ""
|
||||||
|
continue
|
||||||
|
|
||||||
r = await WebChatMessageEvent._send(
|
r = await WebChatMessageEvent._send(
|
||||||
chain,
|
chain,
|
||||||
@@ -147,6 +142,7 @@ class WebChatMessageEvent(AstrMessageEvent):
|
|||||||
"data": final_data,
|
"data": final_data,
|
||||||
"reasoning": reasoning_content,
|
"reasoning": reasoning_content,
|
||||||
"streaming": True,
|
"streaming": True,
|
||||||
|
"cid": cid,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
await super().send_streaming(generator, use_fallback)
|
await super().send_streaming(generator, use_fallback)
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import enum
|
import enum
|
||||||
import json
|
import json
|
||||||
@@ -201,38 +199,6 @@ class ProviderRequest:
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TokenUsage:
|
|
||||||
input_other: int = 0
|
|
||||||
"""The number of input tokens, excluding cached tokens."""
|
|
||||||
input_cached: int = 0
|
|
||||||
"""The number of input cached tokens."""
|
|
||||||
output: int = 0
|
|
||||||
"""The number of output tokens."""
|
|
||||||
|
|
||||||
@property
|
|
||||||
def total(self) -> int:
|
|
||||||
return self.input_other + self.input_cached + self.output
|
|
||||||
|
|
||||||
@property
|
|
||||||
def input(self) -> int:
|
|
||||||
return self.input_other + self.input_cached
|
|
||||||
|
|
||||||
def __add__(self, other: TokenUsage) -> TokenUsage:
|
|
||||||
return TokenUsage(
|
|
||||||
input_other=self.input_other + other.input_other,
|
|
||||||
input_cached=self.input_cached + other.input_cached,
|
|
||||||
output=self.output + other.output,
|
|
||||||
)
|
|
||||||
|
|
||||||
def __sub__(self, other: TokenUsage) -> TokenUsage:
|
|
||||||
return TokenUsage(
|
|
||||||
input_other=self.input_other - other.input_other,
|
|
||||||
input_cached=self.input_cached - other.input_cached,
|
|
||||||
output=self.output - other.output,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LLMResponse:
|
class LLMResponse:
|
||||||
role: str
|
role: str
|
||||||
@@ -261,11 +227,6 @@ class LLMResponse:
|
|||||||
is_chunk: bool = False
|
is_chunk: bool = False
|
||||||
"""Indicates if the response is a chunked response."""
|
"""Indicates if the response is a chunked response."""
|
||||||
|
|
||||||
id: str | None = None
|
|
||||||
"""The ID of the response. For chunked responses, it's the ID of the chunk; for non-chunked responses, it's the ID of the response."""
|
|
||||||
usage: TokenUsage | None = None
|
|
||||||
"""The usage of the response. For chunked responses, it's the usage of the chunk; for non-chunked responses, it's the usage of the response."""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
role: str,
|
role: str,
|
||||||
@@ -280,8 +241,6 @@ class LLMResponse:
|
|||||||
| AnthropicMessage
|
| AnthropicMessage
|
||||||
| None = None,
|
| None = None,
|
||||||
is_chunk: bool = False,
|
is_chunk: bool = False,
|
||||||
id: str | None = None,
|
|
||||||
usage: TokenUsage | None = None,
|
|
||||||
):
|
):
|
||||||
"""初始化 LLMResponse
|
"""初始化 LLMResponse
|
||||||
|
|
||||||
|
|||||||
@@ -6,12 +6,10 @@ from mimetypes import guess_type
|
|||||||
import anthropic
|
import anthropic
|
||||||
from anthropic import AsyncAnthropic
|
from anthropic import AsyncAnthropic
|
||||||
from anthropic.types import Message
|
from anthropic.types import Message
|
||||||
from anthropic.types.message_delta_usage import MessageDeltaUsage
|
|
||||||
from anthropic.types.usage import Usage
|
|
||||||
|
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.api.provider import Provider
|
from astrbot.api.provider import Provider
|
||||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
from astrbot.core.provider.entities import LLMResponse
|
||||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||||
from astrbot.core.utils.io import download_image_by_url
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
|
|
||||||
@@ -109,22 +107,6 @@ class ProviderAnthropic(Provider):
|
|||||||
|
|
||||||
return system_prompt, new_messages
|
return system_prompt, new_messages
|
||||||
|
|
||||||
def _extract_usage(self, usage: Usage) -> TokenUsage:
|
|
||||||
# https://docs.claude.com/en/docs/build-with-claude/prompt-caching#tracking-cache-performance
|
|
||||||
return TokenUsage(
|
|
||||||
input_other=usage.input_tokens or 0,
|
|
||||||
input_cached=usage.cache_read_input_tokens or 0,
|
|
||||||
output=usage.output_tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _update_usage(self, token_usage: TokenUsage, usage: MessageDeltaUsage) -> None:
|
|
||||||
if usage.input_tokens is not None:
|
|
||||||
token_usage.input_other = usage.input_tokens
|
|
||||||
if usage.cache_read_input_tokens is not None:
|
|
||||||
token_usage.input_cached = usage.cache_read_input_tokens
|
|
||||||
if usage.output_tokens is not None:
|
|
||||||
token_usage.output = usage.output_tokens
|
|
||||||
|
|
||||||
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
|
||||||
if tools:
|
if tools:
|
||||||
if tool_list := tools.get_func_desc_anthropic_style():
|
if tool_list := tools.get_func_desc_anthropic_style():
|
||||||
@@ -149,10 +131,6 @@ class ProviderAnthropic(Provider):
|
|||||||
llm_response.tools_call_args.append(content_block.input)
|
llm_response.tools_call_args.append(content_block.input)
|
||||||
llm_response.tools_call_name.append(content_block.name)
|
llm_response.tools_call_name.append(content_block.name)
|
||||||
llm_response.tools_call_ids.append(content_block.id)
|
llm_response.tools_call_ids.append(content_block.id)
|
||||||
|
|
||||||
llm_response.id = completion.id
|
|
||||||
llm_response.usage = self._extract_usage(completion.usage)
|
|
||||||
|
|
||||||
# TODO(Soulter): 处理 end_turn 情况
|
# TODO(Soulter): 处理 end_turn 情况
|
||||||
if not llm_response.completion_text and not llm_response.tools_call_args:
|
if not llm_response.completion_text and not llm_response.tools_call_args:
|
||||||
raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}。")
|
raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}。")
|
||||||
@@ -174,16 +152,9 @@ class ProviderAnthropic(Provider):
|
|||||||
final_text = ""
|
final_text = ""
|
||||||
final_tool_calls = []
|
final_tool_calls = []
|
||||||
|
|
||||||
id = None
|
|
||||||
usage = TokenUsage()
|
|
||||||
|
|
||||||
async with self.client.messages.stream(**payloads) as stream:
|
async with self.client.messages.stream(**payloads) as stream:
|
||||||
assert isinstance(stream, anthropic.AsyncMessageStream)
|
assert isinstance(stream, anthropic.AsyncMessageStream)
|
||||||
async for event in stream:
|
async for event in stream:
|
||||||
if event.type == "message_start":
|
|
||||||
# the usage contains input token usage
|
|
||||||
id = event.message.id
|
|
||||||
usage = self._extract_usage(event.message.usage)
|
|
||||||
if event.type == "content_block_start":
|
if event.type == "content_block_start":
|
||||||
if event.content_block.type == "text":
|
if event.content_block.type == "text":
|
||||||
# 文本块开始
|
# 文本块开始
|
||||||
@@ -191,8 +162,6 @@ class ProviderAnthropic(Provider):
|
|||||||
role="assistant",
|
role="assistant",
|
||||||
completion_text="",
|
completion_text="",
|
||||||
is_chunk=True,
|
is_chunk=True,
|
||||||
usage=usage,
|
|
||||||
id=id,
|
|
||||||
)
|
)
|
||||||
elif event.content_block.type == "tool_use":
|
elif event.content_block.type == "tool_use":
|
||||||
# 工具使用块开始,初始化缓冲区
|
# 工具使用块开始,初始化缓冲区
|
||||||
@@ -210,8 +179,6 @@ class ProviderAnthropic(Provider):
|
|||||||
role="assistant",
|
role="assistant",
|
||||||
completion_text=event.delta.text,
|
completion_text=event.delta.text,
|
||||||
is_chunk=True,
|
is_chunk=True,
|
||||||
usage=usage,
|
|
||||||
id=id,
|
|
||||||
)
|
)
|
||||||
elif event.delta.type == "input_json_delta":
|
elif event.delta.type == "input_json_delta":
|
||||||
# 工具调用参数增量
|
# 工具调用参数增量
|
||||||
@@ -248,8 +215,6 @@ class ProviderAnthropic(Provider):
|
|||||||
tools_call_name=[tool_info["name"]],
|
tools_call_name=[tool_info["name"]],
|
||||||
tools_call_ids=[tool_info["id"]],
|
tools_call_ids=[tool_info["id"]],
|
||||||
is_chunk=True,
|
is_chunk=True,
|
||||||
usage=usage,
|
|
||||||
id=id,
|
|
||||||
)
|
)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
# JSON 解析失败,跳过这个工具调用
|
# JSON 解析失败,跳过这个工具调用
|
||||||
@@ -258,17 +223,11 @@ class ProviderAnthropic(Provider):
|
|||||||
# 清理缓冲区
|
# 清理缓冲区
|
||||||
del tool_use_buffer[event.index]
|
del tool_use_buffer[event.index]
|
||||||
|
|
||||||
elif event.type == "message_delta":
|
|
||||||
if event.usage:
|
|
||||||
self._update_usage(usage, event.usage)
|
|
||||||
|
|
||||||
# 返回最终的完整结果
|
# 返回最终的完整结果
|
||||||
final_response = LLMResponse(
|
final_response = LLMResponse(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
completion_text=final_text,
|
completion_text=final_text,
|
||||||
is_chunk=False,
|
is_chunk=False,
|
||||||
usage=usage,
|
|
||||||
id=id,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if final_tool_calls:
|
if final_tool_calls:
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import astrbot.core.message.components as Comp
|
|||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
from astrbot.api.provider import Provider
|
from astrbot.api.provider import Provider
|
||||||
from astrbot.core.message.message_event_result import MessageChain
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
from astrbot.core.provider.entities import LLMResponse
|
||||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||||
from astrbot.core.utils.io import download_image_by_url
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
|
|
||||||
@@ -138,7 +138,7 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
modalities = ["TEXT"]
|
modalities = ["TEXT"]
|
||||||
|
|
||||||
tool_list: list[types.Tool] | None = []
|
tool_list: list[types.Tool] | None = []
|
||||||
model_name = payloads.get("model", self.get_model())
|
model_name = self.get_model()
|
||||||
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
|
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
|
||||||
native_search = self.provider_config.get("gm_native_search", False)
|
native_search = self.provider_config.get("gm_native_search", False)
|
||||||
url_context = self.provider_config.get("gm_url_context", False)
|
url_context = self.provider_config.get("gm_url_context", False)
|
||||||
@@ -197,37 +197,6 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
types.Tool(function_declarations=func_desc["function_declarations"]),
|
types.Tool(function_declarations=func_desc["function_declarations"]),
|
||||||
]
|
]
|
||||||
|
|
||||||
# oper thinking config
|
|
||||||
thinking_config = None
|
|
||||||
if model_name.startswith("gemini-2.5"):
|
|
||||||
# The thinkingBudget parameter, introduced with the Gemini 2.5 series
|
|
||||||
thinking_budget = self.provider_config.get("gm_thinking_config", {}).get(
|
|
||||||
"budget", 0
|
|
||||||
)
|
|
||||||
if thinking_budget is not None:
|
|
||||||
thinking_config = types.ThinkingConfig(
|
|
||||||
thinking_budget=thinking_budget,
|
|
||||||
)
|
|
||||||
elif model_name.startswith("gemini-3"):
|
|
||||||
# The thinkingLevel parameter, recommended for Gemini 3 models and onwards
|
|
||||||
# Gemini 2.5 series models don't support thinkingLevel; use thinkingBudget instead.
|
|
||||||
thinking_level = self.provider_config.get("gm_thinking_config", {}).get(
|
|
||||||
"level", "HIGH"
|
|
||||||
)
|
|
||||||
if thinking_level and isinstance(thinking_level, str):
|
|
||||||
thinking_level = thinking_level.upper()
|
|
||||||
if thinking_level not in ["MINIMAL", "LOW", "MEDIUM", "HIGH"]:
|
|
||||||
logger.warning(
|
|
||||||
f"Invalid thinking level: {thinking_level}, using HIGH"
|
|
||||||
)
|
|
||||||
thinking_level = "HIGH"
|
|
||||||
level = types.ThinkingLevel(thinking_level)
|
|
||||||
thinking_config = types.ThinkingConfig()
|
|
||||||
if not hasattr(types.ThinkingConfig, "thinking_level"):
|
|
||||||
setattr(types.ThinkingConfig, "thinking_level", level)
|
|
||||||
else:
|
|
||||||
thinking_config.thinking_level = level
|
|
||||||
|
|
||||||
return types.GenerateContentConfig(
|
return types.GenerateContentConfig(
|
||||||
system_instruction=system_instruction,
|
system_instruction=system_instruction,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
@@ -247,7 +216,22 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
response_modalities=modalities,
|
response_modalities=modalities,
|
||||||
tools=cast(types.ToolListUnion | None, tool_list),
|
tools=cast(types.ToolListUnion | None, tool_list),
|
||||||
safety_settings=self.safety_settings if self.safety_settings else None,
|
safety_settings=self.safety_settings if self.safety_settings else None,
|
||||||
thinking_config=thinking_config,
|
thinking_config=(
|
||||||
|
types.ThinkingConfig(
|
||||||
|
thinking_budget=min(
|
||||||
|
int(
|
||||||
|
self.provider_config.get("gm_thinking_config", {}).get(
|
||||||
|
"budget",
|
||||||
|
0,
|
||||||
|
),
|
||||||
|
),
|
||||||
|
24576,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
if "gemini-2.5-flash" in self.get_model()
|
||||||
|
and hasattr(types.ThinkingConfig, "thinking_budget")
|
||||||
|
else None
|
||||||
|
),
|
||||||
automatic_function_calling=types.AutomaticFunctionCallingConfig(
|
automatic_function_calling=types.AutomaticFunctionCallingConfig(
|
||||||
disable=True,
|
disable=True,
|
||||||
),
|
),
|
||||||
@@ -363,16 +347,6 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
]
|
]
|
||||||
return "".join(thought_buf).strip()
|
return "".join(thought_buf).strip()
|
||||||
|
|
||||||
def _extract_usage(
|
|
||||||
self, usage_metadata: types.GenerateContentResponseUsageMetadata
|
|
||||||
) -> TokenUsage:
|
|
||||||
"""Extract usage from candidate"""
|
|
||||||
return TokenUsage(
|
|
||||||
input_other=usage_metadata.prompt_token_count or 0,
|
|
||||||
input_cached=usage_metadata.cached_content_token_count or 0,
|
|
||||||
output=usage_metadata.candidates_token_count or 0,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _process_content_parts(
|
def _process_content_parts(
|
||||||
self,
|
self,
|
||||||
candidate: types.Candidate,
|
candidate: types.Candidate,
|
||||||
@@ -457,8 +431,6 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
model = payloads.get("model", self.get_model())
|
|
||||||
|
|
||||||
modalities = ["TEXT"]
|
modalities = ["TEXT"]
|
||||||
if self.provider_config.get("gm_resp_image_modal", False):
|
if self.provider_config.get("gm_resp_image_modal", False):
|
||||||
modalities.append("IMAGE")
|
modalities.append("IMAGE")
|
||||||
@@ -477,7 +449,7 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
temperature,
|
temperature,
|
||||||
)
|
)
|
||||||
result = await self.client.models.generate_content(
|
result = await self.client.models.generate_content(
|
||||||
model=model,
|
model=self.get_model(),
|
||||||
contents=cast(types.ContentListUnion, conversation),
|
contents=cast(types.ContentListUnion, conversation),
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
@@ -503,11 +475,11 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
e.message = ""
|
e.message = ""
|
||||||
if "Developer instruction is not enabled" in e.message:
|
if "Developer instruction is not enabled" in e.message:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"{model} 不支持 system prompt,已自动去除(影响人格设置)",
|
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)",
|
||||||
)
|
)
|
||||||
system_instruction = None
|
system_instruction = None
|
||||||
elif "Function calling is not enabled" in e.message:
|
elif "Function calling is not enabled" in e.message:
|
||||||
logger.warning(f"{model} 不支持函数调用,已自动去除")
|
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
|
||||||
tools = None
|
tools = None
|
||||||
elif (
|
elif (
|
||||||
"Multi-modal output is not supported" in e.message
|
"Multi-modal output is not supported" in e.message
|
||||||
@@ -516,7 +488,7 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
or "only supports text output" in e.message
|
or "only supports text output" in e.message
|
||||||
):
|
):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"{model} 不支持多模态输出,降级为文本模态",
|
f"{self.get_model()} 不支持多模态输出,降级为文本模态",
|
||||||
)
|
)
|
||||||
modalities = ["TEXT"]
|
modalities = ["TEXT"]
|
||||||
else:
|
else:
|
||||||
@@ -529,9 +501,6 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
result.candidates[0],
|
result.candidates[0],
|
||||||
llm_response,
|
llm_response,
|
||||||
)
|
)
|
||||||
llm_response.id = result.response_id
|
|
||||||
if result.usage_metadata:
|
|
||||||
llm_response.usage = self._extract_usage(result.usage_metadata)
|
|
||||||
return llm_response
|
return llm_response
|
||||||
|
|
||||||
async def _query_stream(
|
async def _query_stream(
|
||||||
@@ -544,7 +513,7 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
|
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
model = payloads.get("model", self.get_model())
|
|
||||||
conversation = self._prepare_conversation(payloads)
|
conversation = self._prepare_conversation(payloads)
|
||||||
|
|
||||||
result = None
|
result = None
|
||||||
@@ -556,7 +525,7 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
system_instruction,
|
system_instruction,
|
||||||
)
|
)
|
||||||
result = await self.client.models.generate_content_stream(
|
result = await self.client.models.generate_content_stream(
|
||||||
model=model,
|
model=self.get_model(),
|
||||||
contents=cast(types.ContentListUnion, conversation),
|
contents=cast(types.ContentListUnion, conversation),
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
@@ -566,11 +535,11 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
e.message = ""
|
e.message = ""
|
||||||
if "Developer instruction is not enabled" in e.message:
|
if "Developer instruction is not enabled" in e.message:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"{model} 不支持 system prompt,已自动去除(影响人格设置)",
|
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)",
|
||||||
)
|
)
|
||||||
system_instruction = None
|
system_instruction = None
|
||||||
elif "Function calling is not enabled" in e.message:
|
elif "Function calling is not enabled" in e.message:
|
||||||
logger.warning(f"{model} 不支持函数调用,已自动去除")
|
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
|
||||||
tools = None
|
tools = None
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
@@ -600,9 +569,6 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
chunk.candidates[0],
|
chunk.candidates[0],
|
||||||
llm_response,
|
llm_response,
|
||||||
)
|
)
|
||||||
llm_response.id = chunk.response_id
|
|
||||||
if chunk.usage_metadata:
|
|
||||||
llm_response.usage = self._extract_usage(chunk.usage_metadata)
|
|
||||||
yield llm_response
|
yield llm_response
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -630,9 +596,6 @@ class ProviderGoogleGenAI(Provider):
|
|||||||
chunk.candidates[0],
|
chunk.candidates[0],
|
||||||
final_response,
|
final_response,
|
||||||
)
|
)
|
||||||
final_response.id = chunk.response_id
|
|
||||||
if chunk.usage_metadata:
|
|
||||||
final_response.usage = self._extract_usage(chunk.usage_metadata)
|
|
||||||
break
|
break
|
||||||
|
|
||||||
# Yield final complete response with accumulated text
|
# Yield final complete response with accumulated text
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ from openai._exceptions import NotFoundError
|
|||||||
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
|
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
|
||||||
from openai.types.chat.chat_completion import ChatCompletion
|
from openai.types.chat.chat_completion import ChatCompletion
|
||||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||||
from openai.types.completion_usage import CompletionUsage
|
|
||||||
|
|
||||||
import astrbot.core.message.components as Comp
|
import astrbot.core.message.components as Comp
|
||||||
from astrbot import logger
|
from astrbot import logger
|
||||||
@@ -20,7 +19,7 @@ from astrbot.api.provider import Provider
|
|||||||
from astrbot.core.agent.message import Message
|
from astrbot.core.agent.message import Message
|
||||||
from astrbot.core.agent.tool import ToolSet
|
from astrbot.core.agent.tool import ToolSet
|
||||||
from astrbot.core.message.message_event_result import MessageChain
|
from astrbot.core.message.message_event_result import MessageChain
|
||||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult
|
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||||
from astrbot.core.utils.io import download_image_by_url
|
from astrbot.core.utils.io import download_image_by_url
|
||||||
|
|
||||||
from ..register import register_provider_adapter
|
from ..register import register_provider_adapter
|
||||||
@@ -209,7 +208,6 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
# handle the content delta
|
# handle the content delta
|
||||||
reasoning = self._extract_reasoning_content(chunk)
|
reasoning = self._extract_reasoning_content(chunk)
|
||||||
_y = False
|
_y = False
|
||||||
llm_response.id = chunk.id
|
|
||||||
if reasoning:
|
if reasoning:
|
||||||
llm_response.reasoning_content = reasoning
|
llm_response.reasoning_content = reasoning
|
||||||
_y = True
|
_y = True
|
||||||
@@ -219,8 +217,6 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
chain=[Comp.Plain(completion_text)],
|
chain=[Comp.Plain(completion_text)],
|
||||||
)
|
)
|
||||||
_y = True
|
_y = True
|
||||||
if chunk.usage:
|
|
||||||
llm_response.usage = self._extract_usage(chunk.usage)
|
|
||||||
if _y:
|
if _y:
|
||||||
yield llm_response
|
yield llm_response
|
||||||
|
|
||||||
@@ -249,15 +245,6 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
reasoning_text = str(reasoning_attr)
|
reasoning_text = str(reasoning_attr)
|
||||||
return reasoning_text
|
return reasoning_text
|
||||||
|
|
||||||
def _extract_usage(self, usage: CompletionUsage) -> TokenUsage:
|
|
||||||
ptd = usage.prompt_tokens_details
|
|
||||||
cached = ptd.cached_tokens if ptd and ptd.cached_tokens else 0
|
|
||||||
return TokenUsage(
|
|
||||||
input_other=usage.prompt_tokens - cached,
|
|
||||||
input_cached=ptd.cached_tokens if ptd and ptd.cached_tokens else 0,
|
|
||||||
output=usage.completion_tokens,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _parse_openai_completion(
|
async def _parse_openai_completion(
|
||||||
self, completion: ChatCompletion, tools: ToolSet | None
|
self, completion: ChatCompletion, tools: ToolSet | None
|
||||||
) -> LLMResponse:
|
) -> LLMResponse:
|
||||||
@@ -334,10 +321,6 @@ class ProviderOpenAIOfficial(Provider):
|
|||||||
raise Exception(f"API 返回的 completion 无法解析:{completion}。")
|
raise Exception(f"API 返回的 completion 无法解析:{completion}。")
|
||||||
|
|
||||||
llm_response.raw_completion = completion
|
llm_response.raw_completion = completion
|
||||||
llm_response.id = completion.id
|
|
||||||
|
|
||||||
if completion.usage:
|
|
||||||
llm_response.usage = self._extract_usage(completion.usage)
|
|
||||||
|
|
||||||
return llm_response
|
return llm_response
|
||||||
|
|
||||||
|
|||||||
@@ -2,19 +2,15 @@ from astrbot.core import html_renderer
|
|||||||
from astrbot.core.provider import Provider
|
from astrbot.core.provider import Provider
|
||||||
from astrbot.core.star.star_tools import StarTools
|
from astrbot.core.star.star_tools import StarTools
|
||||||
from astrbot.core.utils.command_parser import CommandParserMixin
|
from astrbot.core.utils.command_parser import CommandParserMixin
|
||||||
from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin
|
|
||||||
|
|
||||||
from .context import Context
|
from .context import Context
|
||||||
from .star import StarMetadata, star_map, star_registry
|
from .star import StarMetadata, star_map, star_registry
|
||||||
from .star_manager import PluginManager
|
from .star_manager import PluginManager
|
||||||
|
|
||||||
|
|
||||||
class Star(CommandParserMixin, PluginKVStoreMixin):
|
class Star(CommandParserMixin):
|
||||||
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
|
"""所有插件(Star)的父类,所有插件都应该继承于这个类"""
|
||||||
|
|
||||||
author: str
|
|
||||||
name: str
|
|
||||||
|
|
||||||
def __init__(self, context: Context, config: dict | None = None):
|
def __init__(self, context: Context, config: dict | None = None):
|
||||||
StarTools.initialize(context)
|
StarTools.initialize(context)
|
||||||
self.context = context
|
self.context = context
|
||||||
|
|||||||
@@ -1,449 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections import defaultdict
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from astrbot.core import db_helper
|
|
||||||
from astrbot.core.db.po import CommandConfig
|
|
||||||
from astrbot.core.star.filter.command import CommandFilter
|
|
||||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
|
||||||
from astrbot.core.star.filter.permission import PermissionType, PermissionTypeFilter
|
|
||||||
from astrbot.core.star.star import star_map
|
|
||||||
from astrbot.core.star.star_handler import StarHandlerMetadata, star_handlers_registry
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class CommandDescriptor:
|
|
||||||
handler: StarHandlerMetadata = field(repr=False)
|
|
||||||
filter_ref: CommandFilter | CommandGroupFilter | None = field(
|
|
||||||
default=None,
|
|
||||||
repr=False,
|
|
||||||
)
|
|
||||||
handler_full_name: str = ""
|
|
||||||
handler_name: str = ""
|
|
||||||
plugin_name: str = ""
|
|
||||||
plugin_display_name: str | None = None
|
|
||||||
module_path: str = ""
|
|
||||||
description: str = ""
|
|
||||||
command_type: str = "command" # "command" | "group" | "sub_command"
|
|
||||||
raw_command_name: str | None = None
|
|
||||||
current_fragment: str | None = None
|
|
||||||
parent_signature: str = ""
|
|
||||||
parent_group_handler: str = ""
|
|
||||||
original_command: str | None = None
|
|
||||||
effective_command: str | None = None
|
|
||||||
aliases: list[str] = field(default_factory=list)
|
|
||||||
permission: str = "everyone"
|
|
||||||
enabled: bool = True
|
|
||||||
is_group: bool = False
|
|
||||||
is_sub_command: bool = False
|
|
||||||
reserved: bool = False
|
|
||||||
config: CommandConfig | None = None
|
|
||||||
has_conflict: bool = False
|
|
||||||
sub_commands: list[CommandDescriptor] = field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
async def sync_command_configs() -> None:
|
|
||||||
"""同步指令配置,清理过期配置。"""
|
|
||||||
descriptors = _collect_descriptors(include_sub_commands=False)
|
|
||||||
config_records = await db_helper.get_command_configs()
|
|
||||||
config_map = _bind_configs_to_descriptors(descriptors, config_records)
|
|
||||||
live_handlers = {desc.handler_full_name for desc in descriptors}
|
|
||||||
|
|
||||||
stale_configs = [key for key in config_map if key not in live_handlers]
|
|
||||||
if stale_configs:
|
|
||||||
await db_helper.delete_command_configs(stale_configs)
|
|
||||||
|
|
||||||
|
|
||||||
async def toggle_command(handler_full_name: str, enabled: bool) -> CommandDescriptor:
|
|
||||||
descriptor = _build_descriptor_by_full_name(handler_full_name)
|
|
||||||
if not descriptor:
|
|
||||||
raise ValueError("指定的处理函数不存在或不是指令。")
|
|
||||||
|
|
||||||
existing_cfg = await db_helper.get_command_config(handler_full_name)
|
|
||||||
config = await db_helper.upsert_command_config(
|
|
||||||
handler_full_name=handler_full_name,
|
|
||||||
plugin_name=descriptor.plugin_name or "",
|
|
||||||
module_path=descriptor.module_path,
|
|
||||||
original_command=descriptor.original_command or descriptor.handler_name,
|
|
||||||
resolved_command=(
|
|
||||||
existing_cfg.resolved_command
|
|
||||||
if existing_cfg
|
|
||||||
else descriptor.current_fragment
|
|
||||||
),
|
|
||||||
enabled=enabled,
|
|
||||||
keep_original_alias=False,
|
|
||||||
conflict_key=existing_cfg.conflict_key
|
|
||||||
if existing_cfg and existing_cfg.conflict_key
|
|
||||||
else descriptor.original_command,
|
|
||||||
resolution_strategy=existing_cfg.resolution_strategy if existing_cfg else None,
|
|
||||||
note=existing_cfg.note if existing_cfg else None,
|
|
||||||
extra_data=existing_cfg.extra_data if existing_cfg else None,
|
|
||||||
auto_managed=False,
|
|
||||||
)
|
|
||||||
_bind_descriptor_with_config(descriptor, config)
|
|
||||||
await sync_command_configs()
|
|
||||||
return descriptor
|
|
||||||
|
|
||||||
|
|
||||||
async def rename_command(
|
|
||||||
handler_full_name: str,
|
|
||||||
new_fragment: str,
|
|
||||||
) -> CommandDescriptor:
|
|
||||||
descriptor = _build_descriptor_by_full_name(handler_full_name)
|
|
||||||
if not descriptor:
|
|
||||||
raise ValueError("指定的处理函数不存在或不是指令。")
|
|
||||||
|
|
||||||
new_fragment = new_fragment.strip()
|
|
||||||
if not new_fragment:
|
|
||||||
raise ValueError("指令名不能为空。")
|
|
||||||
|
|
||||||
candidate_full = _compose_command(descriptor.parent_signature, new_fragment)
|
|
||||||
if _is_command_in_use(handler_full_name, candidate_full):
|
|
||||||
raise ValueError("新的指令名已被其他指令占用,请换一个名称。")
|
|
||||||
|
|
||||||
config = await db_helper.upsert_command_config(
|
|
||||||
handler_full_name=handler_full_name,
|
|
||||||
plugin_name=descriptor.plugin_name or "",
|
|
||||||
module_path=descriptor.module_path,
|
|
||||||
original_command=descriptor.original_command or descriptor.handler_name,
|
|
||||||
resolved_command=new_fragment,
|
|
||||||
enabled=True if descriptor.enabled else False,
|
|
||||||
keep_original_alias=False,
|
|
||||||
conflict_key=descriptor.original_command,
|
|
||||||
resolution_strategy="manual_rename",
|
|
||||||
note=None,
|
|
||||||
extra_data=None,
|
|
||||||
auto_managed=False,
|
|
||||||
)
|
|
||||||
_bind_descriptor_with_config(descriptor, config)
|
|
||||||
|
|
||||||
await sync_command_configs()
|
|
||||||
return descriptor
|
|
||||||
|
|
||||||
|
|
||||||
async def list_commands() -> list[dict[str, Any]]:
|
|
||||||
descriptors = _collect_descriptors(include_sub_commands=True)
|
|
||||||
config_records = await db_helper.get_command_configs()
|
|
||||||
_bind_configs_to_descriptors(descriptors, config_records)
|
|
||||||
|
|
||||||
conflict_groups = _group_conflicts(descriptors)
|
|
||||||
conflict_handler_names: set[str] = {
|
|
||||||
d.handler_full_name for group in conflict_groups.values() for d in group
|
|
||||||
}
|
|
||||||
|
|
||||||
# 分类,设置冲突标志,将子指令挂载到父指令组
|
|
||||||
group_map: dict[str, CommandDescriptor] = {}
|
|
||||||
sub_commands: list[CommandDescriptor] = []
|
|
||||||
root_commands: list[CommandDescriptor] = []
|
|
||||||
|
|
||||||
for desc in descriptors:
|
|
||||||
desc.has_conflict = desc.handler_full_name in conflict_handler_names
|
|
||||||
if desc.is_group:
|
|
||||||
group_map[desc.handler_full_name] = desc
|
|
||||||
elif desc.is_sub_command:
|
|
||||||
sub_commands.append(desc)
|
|
||||||
else:
|
|
||||||
root_commands.append(desc)
|
|
||||||
|
|
||||||
for sub in sub_commands:
|
|
||||||
if sub.parent_group_handler and sub.parent_group_handler in group_map:
|
|
||||||
group_map[sub.parent_group_handler].sub_commands.append(sub)
|
|
||||||
else:
|
|
||||||
root_commands.append(sub)
|
|
||||||
|
|
||||||
# 指令组 + 普通指令,按 effective_command 字母排序
|
|
||||||
all_commands = list(group_map.values()) + root_commands
|
|
||||||
all_commands.sort(key=lambda d: (d.effective_command or "").lower())
|
|
||||||
|
|
||||||
result = [_descriptor_to_dict(desc) for desc in all_commands]
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
async def list_command_conflicts() -> list[dict[str, Any]]:
|
|
||||||
"""列出所有冲突的指令组。"""
|
|
||||||
descriptors = _collect_descriptors(include_sub_commands=False)
|
|
||||||
config_records = await db_helper.get_command_configs()
|
|
||||||
_bind_configs_to_descriptors(descriptors, config_records)
|
|
||||||
|
|
||||||
conflict_groups = _group_conflicts(descriptors)
|
|
||||||
details = [
|
|
||||||
{
|
|
||||||
"conflict_key": key,
|
|
||||||
"handlers": [
|
|
||||||
{
|
|
||||||
"handler_full_name": item.handler_full_name,
|
|
||||||
"plugin": item.plugin_name,
|
|
||||||
"current_name": item.effective_command,
|
|
||||||
}
|
|
||||||
for item in group
|
|
||||||
],
|
|
||||||
}
|
|
||||||
for key, group in conflict_groups.items()
|
|
||||||
]
|
|
||||||
return details
|
|
||||||
|
|
||||||
|
|
||||||
# Internal helpers ----------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _collect_descriptors(include_sub_commands: bool) -> list[CommandDescriptor]:
|
|
||||||
"""收集指令,按需包含子指令。"""
|
|
||||||
descriptors: list[CommandDescriptor] = []
|
|
||||||
for handler in star_handlers_registry:
|
|
||||||
desc = _build_descriptor(handler)
|
|
||||||
if not desc:
|
|
||||||
continue
|
|
||||||
if not include_sub_commands and desc.is_sub_command:
|
|
||||||
continue
|
|
||||||
descriptors.append(desc)
|
|
||||||
return descriptors
|
|
||||||
|
|
||||||
|
|
||||||
def _build_descriptor(handler: StarHandlerMetadata) -> CommandDescriptor | None:
|
|
||||||
filter_ref = _locate_primary_filter(handler)
|
|
||||||
if filter_ref is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
plugin_meta = star_map.get(handler.handler_module_path)
|
|
||||||
plugin_name = (
|
|
||||||
plugin_meta.name if plugin_meta else None
|
|
||||||
) or handler.handler_module_path
|
|
||||||
plugin_display = plugin_meta.display_name if plugin_meta else None
|
|
||||||
|
|
||||||
is_sub_command = bool(handler.extras_configs.get("sub_command"))
|
|
||||||
parent_group_handler = ""
|
|
||||||
|
|
||||||
if isinstance(filter_ref, CommandFilter):
|
|
||||||
raw_fragment = getattr(
|
|
||||||
filter_ref, "_original_command_name", filter_ref.command_name
|
|
||||||
)
|
|
||||||
current_fragment = filter_ref.command_name
|
|
||||||
parent_signature = (filter_ref.parent_command_names or [""])[0].strip()
|
|
||||||
# 如果是子指令,尝试找到父指令组的 handler_full_name
|
|
||||||
if is_sub_command and parent_signature:
|
|
||||||
parent_group_handler = _find_parent_group_handler(
|
|
||||||
handler.handler_module_path, parent_signature
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raw_fragment = getattr(
|
|
||||||
filter_ref, "_original_group_name", filter_ref.group_name
|
|
||||||
)
|
|
||||||
current_fragment = filter_ref.group_name
|
|
||||||
parent_signature = _resolve_group_parent_signature(filter_ref)
|
|
||||||
|
|
||||||
original_command = _compose_command(parent_signature, raw_fragment)
|
|
||||||
effective_command = _compose_command(parent_signature, current_fragment)
|
|
||||||
|
|
||||||
# 确定 command_type
|
|
||||||
if isinstance(filter_ref, CommandGroupFilter):
|
|
||||||
command_type = "group"
|
|
||||||
elif is_sub_command:
|
|
||||||
command_type = "sub_command"
|
|
||||||
else:
|
|
||||||
command_type = "command"
|
|
||||||
|
|
||||||
descriptor = CommandDescriptor(
|
|
||||||
handler=handler,
|
|
||||||
filter_ref=filter_ref,
|
|
||||||
handler_full_name=handler.handler_full_name,
|
|
||||||
handler_name=handler.handler_name,
|
|
||||||
plugin_name=plugin_name,
|
|
||||||
plugin_display_name=plugin_display,
|
|
||||||
module_path=handler.handler_module_path,
|
|
||||||
description=handler.desc or "",
|
|
||||||
command_type=command_type,
|
|
||||||
raw_command_name=raw_fragment,
|
|
||||||
current_fragment=current_fragment,
|
|
||||||
parent_signature=parent_signature,
|
|
||||||
parent_group_handler=parent_group_handler,
|
|
||||||
original_command=original_command,
|
|
||||||
effective_command=effective_command,
|
|
||||||
aliases=sorted(getattr(filter_ref, "alias", set())),
|
|
||||||
permission=_determine_permission(handler),
|
|
||||||
enabled=handler.enabled,
|
|
||||||
is_group=isinstance(filter_ref, CommandGroupFilter),
|
|
||||||
is_sub_command=is_sub_command,
|
|
||||||
reserved=plugin_meta.reserved if plugin_meta else False,
|
|
||||||
)
|
|
||||||
return descriptor
|
|
||||||
|
|
||||||
|
|
||||||
def _build_descriptor_by_full_name(full_name: str) -> CommandDescriptor | None:
|
|
||||||
handler = star_handlers_registry.get_handler_by_full_name(full_name)
|
|
||||||
if not handler:
|
|
||||||
return None
|
|
||||||
return _build_descriptor(handler)
|
|
||||||
|
|
||||||
|
|
||||||
def _locate_primary_filter(
|
|
||||||
handler: StarHandlerMetadata,
|
|
||||||
) -> CommandFilter | CommandGroupFilter | None:
|
|
||||||
for filter_ref in handler.event_filters:
|
|
||||||
if isinstance(filter_ref, (CommandFilter, CommandGroupFilter)):
|
|
||||||
return filter_ref
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _determine_permission(handler: StarHandlerMetadata) -> str:
|
|
||||||
for filter_ref in handler.event_filters:
|
|
||||||
if isinstance(filter_ref, PermissionTypeFilter):
|
|
||||||
return (
|
|
||||||
"admin"
|
|
||||||
if filter_ref.permission_type == PermissionType.ADMIN
|
|
||||||
else "member"
|
|
||||||
)
|
|
||||||
return "everyone"
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_group_parent_signature(group_filter: CommandGroupFilter) -> str:
|
|
||||||
signatures: list[str] = []
|
|
||||||
parent = group_filter.parent_group
|
|
||||||
while parent:
|
|
||||||
signatures.append(getattr(parent, "_original_group_name", parent.group_name))
|
|
||||||
parent = parent.parent_group
|
|
||||||
return " ".join(reversed(signatures)).strip()
|
|
||||||
|
|
||||||
|
|
||||||
def _find_parent_group_handler(module_path: str, parent_signature: str) -> str:
|
|
||||||
"""根据模块路径和父级签名,找到对应的指令组 handler_full_name。"""
|
|
||||||
parent_sig_normalized = parent_signature.strip()
|
|
||||||
for handler in star_handlers_registry:
|
|
||||||
if handler.handler_module_path != module_path:
|
|
||||||
continue
|
|
||||||
filter_ref = _locate_primary_filter(handler)
|
|
||||||
if not isinstance(filter_ref, CommandGroupFilter):
|
|
||||||
continue
|
|
||||||
# 检查该指令组的完整指令名是否匹配 parent_signature
|
|
||||||
group_names = filter_ref.get_complete_command_names()
|
|
||||||
if parent_sig_normalized in group_names:
|
|
||||||
return handler.handler_full_name
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def _compose_command(parent_signature: str, fragment: str | None) -> str:
|
|
||||||
fragment = (fragment or "").strip()
|
|
||||||
parent_signature = parent_signature.strip()
|
|
||||||
if not parent_signature:
|
|
||||||
return fragment
|
|
||||||
if not fragment:
|
|
||||||
return parent_signature
|
|
||||||
return f"{parent_signature} {fragment}"
|
|
||||||
|
|
||||||
|
|
||||||
def _bind_descriptor_with_config(
|
|
||||||
descriptor: CommandDescriptor,
|
|
||||||
config: CommandConfig,
|
|
||||||
) -> None:
|
|
||||||
_apply_config_to_descriptor(descriptor, config)
|
|
||||||
_apply_config_to_runtime(descriptor, config)
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_config_to_descriptor(
|
|
||||||
descriptor: CommandDescriptor,
|
|
||||||
config: CommandConfig,
|
|
||||||
) -> None:
|
|
||||||
descriptor.config = config
|
|
||||||
descriptor.enabled = config.enabled
|
|
||||||
|
|
||||||
if config.original_command:
|
|
||||||
descriptor.original_command = config.original_command
|
|
||||||
|
|
||||||
new_fragment = config.resolved_command or descriptor.current_fragment
|
|
||||||
descriptor.current_fragment = new_fragment
|
|
||||||
descriptor.effective_command = _compose_command(
|
|
||||||
descriptor.parent_signature,
|
|
||||||
new_fragment,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_config_to_runtime(
|
|
||||||
descriptor: CommandDescriptor,
|
|
||||||
config: CommandConfig,
|
|
||||||
) -> None:
|
|
||||||
descriptor.handler.enabled = config.enabled
|
|
||||||
if descriptor.filter_ref and descriptor.current_fragment:
|
|
||||||
_set_filter_fragment(descriptor.filter_ref, descriptor.current_fragment)
|
|
||||||
|
|
||||||
|
|
||||||
def _bind_configs_to_descriptors(
|
|
||||||
descriptors: list[CommandDescriptor],
|
|
||||||
config_records: list[CommandConfig],
|
|
||||||
) -> dict[str, CommandConfig]:
|
|
||||||
config_map = {cfg.handler_full_name: cfg for cfg in config_records}
|
|
||||||
for desc in descriptors:
|
|
||||||
if cfg := config_map.get(desc.handler_full_name):
|
|
||||||
_bind_descriptor_with_config(desc, cfg)
|
|
||||||
return config_map
|
|
||||||
|
|
||||||
|
|
||||||
def _group_conflicts(
|
|
||||||
descriptors: list[CommandDescriptor],
|
|
||||||
) -> dict[str, list[CommandDescriptor]]:
|
|
||||||
conflicts: dict[str, list[CommandDescriptor]] = defaultdict(list)
|
|
||||||
for desc in descriptors:
|
|
||||||
if desc.effective_command and desc.enabled:
|
|
||||||
conflicts[desc.effective_command].append(desc)
|
|
||||||
return {k: v for k, v in conflicts.items() if len(v) > 1}
|
|
||||||
|
|
||||||
|
|
||||||
def _set_filter_fragment(
|
|
||||||
filter_ref: CommandFilter | CommandGroupFilter,
|
|
||||||
fragment: str,
|
|
||||||
) -> None:
|
|
||||||
attr = (
|
|
||||||
"group_name" if isinstance(filter_ref, CommandGroupFilter) else "command_name"
|
|
||||||
)
|
|
||||||
current_value = getattr(filter_ref, attr)
|
|
||||||
if fragment == current_value:
|
|
||||||
return
|
|
||||||
setattr(filter_ref, attr, fragment)
|
|
||||||
if hasattr(filter_ref, "_cmpl_cmd_names"):
|
|
||||||
filter_ref._cmpl_cmd_names = None
|
|
||||||
|
|
||||||
|
|
||||||
def _is_command_in_use(
|
|
||||||
target_handler_full_name: str,
|
|
||||||
candidate_full_command: str,
|
|
||||||
) -> bool:
|
|
||||||
candidate = candidate_full_command.strip()
|
|
||||||
for handler in star_handlers_registry:
|
|
||||||
if handler.handler_full_name == target_handler_full_name:
|
|
||||||
continue
|
|
||||||
filter_ref = _locate_primary_filter(handler)
|
|
||||||
if not filter_ref:
|
|
||||||
continue
|
|
||||||
names = {name.strip() for name in filter_ref.get_complete_command_names()}
|
|
||||||
if candidate in names:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def _descriptor_to_dict(desc: CommandDescriptor) -> dict[str, Any]:
|
|
||||||
result = {
|
|
||||||
"handler_full_name": desc.handler_full_name,
|
|
||||||
"handler_name": desc.handler_name,
|
|
||||||
"plugin": desc.plugin_name,
|
|
||||||
"plugin_display_name": desc.plugin_display_name,
|
|
||||||
"module_path": desc.module_path,
|
|
||||||
"description": desc.description,
|
|
||||||
"type": desc.command_type,
|
|
||||||
"parent_signature": desc.parent_signature,
|
|
||||||
"parent_group_handler": desc.parent_group_handler,
|
|
||||||
"original_command": desc.original_command,
|
|
||||||
"current_fragment": desc.current_fragment,
|
|
||||||
"effective_command": desc.effective_command,
|
|
||||||
"aliases": desc.aliases,
|
|
||||||
"permission": desc.permission,
|
|
||||||
"enabled": desc.enabled,
|
|
||||||
"is_group": desc.is_group,
|
|
||||||
"has_conflict": desc.has_conflict,
|
|
||||||
"reserved": desc.reserved,
|
|
||||||
}
|
|
||||||
# 如果是指令组,包含子指令列表
|
|
||||||
if desc.is_group and desc.sub_commands:
|
|
||||||
result["sub_commands"] = [_descriptor_to_dict(sub) for sub in desc.sub_commands]
|
|
||||||
else:
|
|
||||||
result["sub_commands"] = []
|
|
||||||
return result
|
|
||||||
@@ -296,10 +296,6 @@ class Context:
|
|||||||
provider_type=ProviderType.CHAT_COMPLETION,
|
provider_type=ProviderType.CHAT_COMPLETION,
|
||||||
umo=umo,
|
umo=umo,
|
||||||
)
|
)
|
||||||
if prov is None:
|
|
||||||
raise ProviderNotFoundError(
|
|
||||||
"provider not found, please choose provider first"
|
|
||||||
)
|
|
||||||
if not isinstance(prov, Provider):
|
if not isinstance(prov, Provider):
|
||||||
raise ValueError("返回的 Provider 不是 Provider 类型")
|
raise ValueError("返回的 Provider 不是 Provider 类型")
|
||||||
return prov
|
return prov
|
||||||
|
|||||||
@@ -40,7 +40,6 @@ class CommandFilter(HandlerFilter):
|
|||||||
):
|
):
|
||||||
self.command_name = command_name
|
self.command_name = command_name
|
||||||
self.alias = alias if alias else set()
|
self.alias = alias if alias else set()
|
||||||
self._original_command_name = command_name
|
|
||||||
self.parent_command_names = (
|
self.parent_command_names = (
|
||||||
parent_command_names if parent_command_names is not None else [""]
|
parent_command_names if parent_command_names is not None else [""]
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -18,7 +18,6 @@ class CommandGroupFilter(HandlerFilter):
|
|||||||
):
|
):
|
||||||
self.group_name = group_name
|
self.group_name = group_name
|
||||||
self.alias = alias if alias else set()
|
self.alias = alias if alias else set()
|
||||||
self._original_group_name = group_name
|
|
||||||
self.sub_command_filters: list[CommandFilter | CommandGroupFilter] = []
|
self.sub_command_filters: list[CommandFilter | CommandGroupFilter] = []
|
||||||
self.custom_filter_list: list[CustomFilter] = []
|
self.custom_filter_list: list[CustomFilter] = []
|
||||||
self.parent_group = parent_group
|
self.parent_group = parent_group
|
||||||
|
|||||||
@@ -118,8 +118,6 @@ class StarHandlerRegistry(Generic[T]):
|
|||||||
# 过滤事件类型
|
# 过滤事件类型
|
||||||
if handler.event_type != event_type:
|
if handler.event_type != event_type:
|
||||||
continue
|
continue
|
||||||
if not handler.enabled:
|
|
||||||
continue
|
|
||||||
# 过滤启用状态
|
# 过滤启用状态
|
||||||
if only_activated:
|
if only_activated:
|
||||||
plugin = star_map.get(handler.handler_module_path)
|
plugin = star_map.get(handler.handler_module_path)
|
||||||
@@ -222,8 +220,6 @@ class StarHandlerMetadata(Generic[H]):
|
|||||||
extras_configs: dict = field(default_factory=dict)
|
extras_configs: dict = field(default_factory=dict)
|
||||||
"""插件注册的一些其他的信息, 如 priority 等"""
|
"""插件注册的一些其他的信息, 如 priority 等"""
|
||||||
|
|
||||||
enabled: bool = True
|
|
||||||
|
|
||||||
def __lt__(self, other: StarHandlerMetadata):
|
def __lt__(self, other: StarHandlerMetadata):
|
||||||
"""定义小于运算符以支持优先队列"""
|
"""定义小于运算符以支持优先队列"""
|
||||||
return self.extras_configs.get("priority", 0) < other.extras_configs.get(
|
return self.extras_configs.get("priority", 0) < other.extras_configs.get(
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ from astrbot.core.utils.astrbot_path import (
|
|||||||
from astrbot.core.utils.io import remove_dir
|
from astrbot.core.utils.io import remove_dir
|
||||||
|
|
||||||
from . import StarMetadata
|
from . import StarMetadata
|
||||||
from .command_management import sync_command_configs
|
|
||||||
from .context import Context
|
from .context import Context
|
||||||
from .filter.permission import PermissionType, PermissionTypeFilter
|
from .filter.permission import PermissionType, PermissionTypeFilter
|
||||||
from .star import star_map, star_registry
|
from .star import star_map, star_registry
|
||||||
@@ -468,18 +467,6 @@ class PluginManager:
|
|||||||
metadata.star_cls = metadata.star_cls_type(
|
metadata.star_cls = metadata.star_cls_type(
|
||||||
context=self.context,
|
context=self.context,
|
||||||
)
|
)
|
||||||
|
|
||||||
p_name = (metadata.name or "unknown").lower().replace("/", "_")
|
|
||||||
p_author = (
|
|
||||||
(metadata.author or "unknown").lower().replace("/", "_")
|
|
||||||
)
|
|
||||||
setattr(metadata.star_cls, "name", p_name)
|
|
||||||
setattr(metadata.star_cls, "author", p_author)
|
|
||||||
setattr(
|
|
||||||
metadata.star_cls,
|
|
||||||
"plugin_id",
|
|
||||||
f"{p_author}/{p_name}",
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logger.info(f"插件 {metadata.name} 已被禁用。")
|
logger.info(f"插件 {metadata.name} 已被禁用。")
|
||||||
|
|
||||||
@@ -631,7 +618,6 @@ class PluginManager:
|
|||||||
# 清除 pip.main 导致的多余的 logging handlers
|
# 清除 pip.main 导致的多余的 logging handlers
|
||||||
for handler in logging.root.handlers[:]:
|
for handler in logging.root.handlers[:]:
|
||||||
logging.root.removeHandler(handler)
|
logging.root.removeHandler(handler)
|
||||||
await sync_command_configs()
|
|
||||||
|
|
||||||
if not fail_rec:
|
if not fail_rec:
|
||||||
return True, None
|
return True, None
|
||||||
|
|||||||
@@ -1,28 +0,0 @@
|
|||||||
from typing import TypeVar
|
|
||||||
|
|
||||||
from astrbot.core import sp
|
|
||||||
|
|
||||||
SUPPORTED_VALUE_TYPES = int | float | str | bytes | bool | dict | list | None
|
|
||||||
_VT = TypeVar("_VT")
|
|
||||||
|
|
||||||
|
|
||||||
class PluginKVStoreMixin:
|
|
||||||
"""为插件提供键值存储功能的 Mixin 类"""
|
|
||||||
|
|
||||||
plugin_id: str
|
|
||||||
|
|
||||||
async def put_kv_data(
|
|
||||||
self,
|
|
||||||
key: str,
|
|
||||||
value: SUPPORTED_VALUE_TYPES,
|
|
||||||
) -> None:
|
|
||||||
"""为指定插件存储一个键值对"""
|
|
||||||
await sp.put_async("plugin", self.plugin_id, key, value)
|
|
||||||
|
|
||||||
async def get_kv_data(self, key: str, default: _VT) -> _VT | None:
|
|
||||||
"""获取指定插件存储的键值对"""
|
|
||||||
return await sp.get_async("plugin", self.plugin_id, key, default)
|
|
||||||
|
|
||||||
async def delete_kv_data(self, key: str) -> None:
|
|
||||||
"""删除指定插件存储的键值对"""
|
|
||||||
await sp.remove_async("plugin", self.plugin_id, key)
|
|
||||||
@@ -1,6 +1,5 @@
|
|||||||
from .auth import AuthRoute
|
from .auth import AuthRoute
|
||||||
from .chat import ChatRoute
|
from .chat import ChatRoute
|
||||||
from .command import CommandRoute
|
|
||||||
from .config import ConfigRoute
|
from .config import ConfigRoute
|
||||||
from .conversation import ConversationRoute
|
from .conversation import ConversationRoute
|
||||||
from .file import FileRoute
|
from .file import FileRoute
|
||||||
@@ -18,7 +17,6 @@ from .update import UpdateRoute
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"AuthRoute",
|
"AuthRoute",
|
||||||
"ChatRoute",
|
"ChatRoute",
|
||||||
"CommandRoute",
|
|
||||||
"ConfigRoute",
|
"ConfigRoute",
|
||||||
"ConversationRoute",
|
"ConversationRoute",
|
||||||
"FileRoute",
|
"FileRoute",
|
||||||
|
|||||||
@@ -227,19 +227,16 @@ class ChatRoute(Route):
|
|||||||
text: str,
|
text: str,
|
||||||
media_parts: list,
|
media_parts: list,
|
||||||
reasoning: str,
|
reasoning: str,
|
||||||
agent_stats: dict,
|
|
||||||
):
|
):
|
||||||
"""保存 bot 消息到历史记录,返回保存的记录"""
|
"""保存 bot 消息到历史记录,返回保存的记录"""
|
||||||
bot_message_parts = []
|
bot_message_parts = []
|
||||||
bot_message_parts.extend(media_parts)
|
|
||||||
if text:
|
if text:
|
||||||
bot_message_parts.append({"type": "plain", "text": text})
|
bot_message_parts.append({"type": "plain", "text": text})
|
||||||
|
bot_message_parts.extend(media_parts)
|
||||||
|
|
||||||
new_his = {"type": "bot", "message": bot_message_parts}
|
new_his = {"type": "bot", "message": bot_message_parts}
|
||||||
if reasoning:
|
if reasoning:
|
||||||
new_his["reasoning"] = reasoning
|
new_his["reasoning"] = reasoning
|
||||||
if agent_stats:
|
|
||||||
new_his["agent_stats"] = agent_stats
|
|
||||||
|
|
||||||
record = await self.platform_history_mgr.insert(
|
record = await self.platform_history_mgr.insert(
|
||||||
platform_id="webchat",
|
platform_id="webchat",
|
||||||
@@ -297,8 +294,7 @@ class ChatRoute(Route):
|
|||||||
accumulated_parts = []
|
accumulated_parts = []
|
||||||
accumulated_text = ""
|
accumulated_text = ""
|
||||||
accumulated_reasoning = ""
|
accumulated_reasoning = ""
|
||||||
tool_calls = {}
|
|
||||||
agent_stats = {}
|
|
||||||
try:
|
try:
|
||||||
async with track_conversation(self.running_convs, webchat_conv_id):
|
async with track_conversation(self.running_convs, webchat_conv_id):
|
||||||
while True:
|
while True:
|
||||||
@@ -318,16 +314,6 @@ class ChatRoute(Route):
|
|||||||
result_text = result["data"]
|
result_text = result["data"]
|
||||||
msg_type = result.get("type")
|
msg_type = result.get("type")
|
||||||
streaming = result.get("streaming", False)
|
streaming = result.get("streaming", False)
|
||||||
chain_type = result.get("chain_type")
|
|
||||||
|
|
||||||
if chain_type == "agent_stats":
|
|
||||||
stats_info = {
|
|
||||||
"type": "agent_stats",
|
|
||||||
"data": json.loads(result_text),
|
|
||||||
}
|
|
||||||
yield f"data: {json.dumps(stats_info, ensure_ascii=False)}\n\n"
|
|
||||||
agent_stats = stats_info["data"]
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 发送 SSE 数据
|
# 发送 SSE 数据
|
||||||
try:
|
try:
|
||||||
@@ -349,35 +335,11 @@ class ChatRoute(Route):
|
|||||||
|
|
||||||
# 累积消息部分
|
# 累积消息部分
|
||||||
if msg_type == "plain":
|
if msg_type == "plain":
|
||||||
chain_type = result.get("chain_type")
|
chain_type = result.get("chain_type", "normal")
|
||||||
if chain_type == "tool_call":
|
if chain_type == "reasoning":
|
||||||
tool_call = json.loads(result_text)
|
|
||||||
tool_calls[tool_call.get("id")] = tool_call
|
|
||||||
if accumulated_text:
|
|
||||||
# 如果累积了文本,则先保存文本
|
|
||||||
accumulated_parts.append(
|
|
||||||
{"type": "plain", "text": accumulated_text}
|
|
||||||
)
|
|
||||||
accumulated_text = ""
|
|
||||||
elif chain_type == "tool_call_result":
|
|
||||||
tcr = json.loads(result_text)
|
|
||||||
tc_id = tcr.get("id")
|
|
||||||
if tc_id in tool_calls:
|
|
||||||
tool_calls[tc_id]["result"] = tcr.get("result")
|
|
||||||
tool_calls[tc_id]["finished_ts"] = tcr.get("ts")
|
|
||||||
accumulated_parts.append(
|
|
||||||
{
|
|
||||||
"type": "tool_call",
|
|
||||||
"tool_calls": [tool_calls[tc_id]],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
tool_calls.pop(tc_id, None)
|
|
||||||
elif chain_type == "reasoning":
|
|
||||||
accumulated_reasoning += result_text
|
accumulated_reasoning += result_text
|
||||||
elif streaming:
|
|
||||||
accumulated_text += result_text
|
|
||||||
else:
|
else:
|
||||||
accumulated_text = result_text
|
accumulated_text += result_text
|
||||||
elif msg_type == "image":
|
elif msg_type == "image":
|
||||||
filename = result_text.replace("[IMAGE]", "")
|
filename = result_text.replace("[IMAGE]", "")
|
||||||
part = await self._create_attachment_from_file(
|
part = await self._create_attachment_from_file(
|
||||||
@@ -405,20 +367,15 @@ class ChatRoute(Route):
|
|||||||
if msg_type == "end":
|
if msg_type == "end":
|
||||||
break
|
break
|
||||||
elif (
|
elif (
|
||||||
(streaming and msg_type == "complete") or not streaming
|
(streaming and msg_type == "complete")
|
||||||
# or msg_type == "break"
|
or not streaming
|
||||||
|
or msg_type == "break"
|
||||||
):
|
):
|
||||||
if (
|
|
||||||
chain_type == "tool_call"
|
|
||||||
or chain_type == "tool_call_result"
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
saved_record = await self._save_bot_message(
|
saved_record = await self._save_bot_message(
|
||||||
webchat_conv_id,
|
webchat_conv_id,
|
||||||
accumulated_text,
|
accumulated_text,
|
||||||
accumulated_parts,
|
accumulated_parts,
|
||||||
accumulated_reasoning,
|
accumulated_reasoning,
|
||||||
agent_stats,
|
|
||||||
)
|
)
|
||||||
# 发送保存的消息信息给前端
|
# 发送保存的消息信息给前端
|
||||||
if saved_record and not client_disconnected:
|
if saved_record and not client_disconnected:
|
||||||
@@ -433,11 +390,11 @@ class ChatRoute(Route):
|
|||||||
yield f"data: {json.dumps(saved_info, ensure_ascii=False)}\n\n"
|
yield f"data: {json.dumps(saved_info, ensure_ascii=False)}\n\n"
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
accumulated_parts = []
|
# 重置累积变量 (对于 break 后的下一段消息)
|
||||||
accumulated_text = ""
|
if msg_type == "break":
|
||||||
accumulated_reasoning = ""
|
accumulated_parts = []
|
||||||
tool_calls = {}
|
accumulated_text = ""
|
||||||
agent_stats = {}
|
accumulated_reasoning = ""
|
||||||
except BaseException as e:
|
except BaseException as e:
|
||||||
logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
|
logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
|
||||||
|
|
||||||
|
|||||||
@@ -1,82 +0,0 @@
|
|||||||
from quart import request
|
|
||||||
|
|
||||||
from astrbot.core.star.command_management import (
|
|
||||||
list_command_conflicts,
|
|
||||||
list_commands,
|
|
||||||
)
|
|
||||||
from astrbot.core.star.command_management import (
|
|
||||||
rename_command as rename_command_service,
|
|
||||||
)
|
|
||||||
from astrbot.core.star.command_management import (
|
|
||||||
toggle_command as toggle_command_service,
|
|
||||||
)
|
|
||||||
|
|
||||||
from .route import Response, Route, RouteContext
|
|
||||||
|
|
||||||
|
|
||||||
class CommandRoute(Route):
|
|
||||||
def __init__(self, context: RouteContext) -> None:
|
|
||||||
super().__init__(context)
|
|
||||||
self.routes = {
|
|
||||||
"/commands": ("GET", self.get_commands),
|
|
||||||
"/commands/conflicts": ("GET", self.get_conflicts),
|
|
||||||
"/commands/toggle": ("POST", self.toggle_command),
|
|
||||||
"/commands/rename": ("POST", self.rename_command),
|
|
||||||
}
|
|
||||||
self.register_routes()
|
|
||||||
|
|
||||||
async def get_commands(self):
|
|
||||||
commands = await list_commands()
|
|
||||||
summary = {
|
|
||||||
"total": len(commands),
|
|
||||||
"disabled": len([cmd for cmd in commands if not cmd["enabled"]]),
|
|
||||||
"conflicts": len([cmd for cmd in commands if cmd.get("has_conflict")]),
|
|
||||||
}
|
|
||||||
return Response().ok({"items": commands, "summary": summary}).__dict__
|
|
||||||
|
|
||||||
async def get_conflicts(self):
|
|
||||||
conflicts = await list_command_conflicts()
|
|
||||||
return Response().ok(conflicts).__dict__
|
|
||||||
|
|
||||||
async def toggle_command(self):
|
|
||||||
data = await request.get_json()
|
|
||||||
handler_full_name = data.get("handler_full_name")
|
|
||||||
enabled = data.get("enabled")
|
|
||||||
|
|
||||||
if handler_full_name is None or enabled is None:
|
|
||||||
return Response().error("handler_full_name 与 enabled 均为必填。").__dict__
|
|
||||||
|
|
||||||
if isinstance(enabled, str):
|
|
||||||
enabled = enabled.lower() in ("1", "true", "yes", "on")
|
|
||||||
|
|
||||||
try:
|
|
||||||
await toggle_command_service(handler_full_name, bool(enabled))
|
|
||||||
except ValueError as exc:
|
|
||||||
return Response().error(str(exc)).__dict__
|
|
||||||
|
|
||||||
payload = await _get_command_payload(handler_full_name)
|
|
||||||
return Response().ok(payload).__dict__
|
|
||||||
|
|
||||||
async def rename_command(self):
|
|
||||||
data = await request.get_json()
|
|
||||||
handler_full_name = data.get("handler_full_name")
|
|
||||||
new_name = data.get("new_name")
|
|
||||||
|
|
||||||
if not handler_full_name or not new_name:
|
|
||||||
return Response().error("handler_full_name 与 new_name 均为必填。").__dict__
|
|
||||||
|
|
||||||
try:
|
|
||||||
await rename_command_service(handler_full_name, new_name)
|
|
||||||
except ValueError as exc:
|
|
||||||
return Response().error(str(exc)).__dict__
|
|
||||||
|
|
||||||
payload = await _get_command_payload(handler_full_name)
|
|
||||||
return Response().ok(payload).__dict__
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_command_payload(handler_full_name: str):
|
|
||||||
commands = await list_commands()
|
|
||||||
for cmd in commands:
|
|
||||||
if cmd["handler_full_name"] == handler_full_name:
|
|
||||||
return cmd
|
|
||||||
return {}
|
|
||||||
@@ -1,9 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import traceback
|
import traceback
|
||||||
from datetime import datetime
|
|
||||||
from io import BytesIO
|
|
||||||
|
|
||||||
from quart import request, send_file
|
from quart import request
|
||||||
|
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||||
@@ -32,7 +30,6 @@ class ConversationRoute(Route):
|
|||||||
"POST",
|
"POST",
|
||||||
self.update_history,
|
self.update_history,
|
||||||
),
|
),
|
||||||
"/conversation/export": ("POST", self.export_conversations),
|
|
||||||
}
|
}
|
||||||
self.db_helper = db_helper
|
self.db_helper = db_helper
|
||||||
self.conv_mgr = core_lifecycle.conversation_manager
|
self.conv_mgr = core_lifecycle.conversation_manager
|
||||||
@@ -286,90 +283,3 @@ class ConversationRoute(Route):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"更新对话历史失败: {e!s}\n{traceback.format_exc()}")
|
logger.error(f"更新对话历史失败: {e!s}\n{traceback.format_exc()}")
|
||||||
return Response().error(f"更新对话历史失败: {e!s}").__dict__
|
return Response().error(f"更新对话历史失败: {e!s}").__dict__
|
||||||
|
|
||||||
async def export_conversations(self):
|
|
||||||
"""批量导出对话为 JSONL 格式"""
|
|
||||||
try:
|
|
||||||
data = await request.get_json()
|
|
||||||
conversations_to_export = data.get("conversations", [])
|
|
||||||
|
|
||||||
if not conversations_to_export:
|
|
||||||
return Response().error("导出列表不能为空").__dict__
|
|
||||||
|
|
||||||
# 收集所有对话的内容
|
|
||||||
jsonl_lines = []
|
|
||||||
exported_count = 0
|
|
||||||
failed_items = []
|
|
||||||
|
|
||||||
for conv_info in conversations_to_export:
|
|
||||||
user_id = conv_info.get("user_id")
|
|
||||||
cid = conv_info.get("cid")
|
|
||||||
|
|
||||||
if not user_id or not cid:
|
|
||||||
failed_items.append(
|
|
||||||
f"user_id:{user_id}, cid:{cid} - 缺少必要参数",
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
try:
|
|
||||||
conversation = await self.conv_mgr.get_conversation(
|
|
||||||
unified_msg_origin=user_id,
|
|
||||||
conversation_id=cid,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not conversation:
|
|
||||||
failed_items.append(
|
|
||||||
f"user_id:{user_id}, cid:{cid} - 对话不存在"
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 解析对话内容 (history is always a JSON string from _convert_conv_from_v2_to_v1)
|
|
||||||
content = json.loads(conversation.history)
|
|
||||||
|
|
||||||
# 创建导出记录
|
|
||||||
export_record = {
|
|
||||||
"cid": cid,
|
|
||||||
"user_id": user_id,
|
|
||||||
"platform_id": conversation.platform_id,
|
|
||||||
"title": conversation.title,
|
|
||||||
"persona_id": conversation.persona_id,
|
|
||||||
"created_at": conversation.created_at,
|
|
||||||
"updated_at": conversation.updated_at,
|
|
||||||
"content": content,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 将记录转换为 JSON 字符串并添加到 JSONL
|
|
||||||
jsonl_lines.append(json.dumps(export_record, ensure_ascii=False))
|
|
||||||
exported_count += 1
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
failed_items.append(f"user_id:{user_id}, cid:{cid} - {e!s}")
|
|
||||||
logger.error(
|
|
||||||
f"导出对话失败: user_id={user_id}, cid={cid}, error={e!s}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if exported_count == 0:
|
|
||||||
return Response().error("没有成功导出任何对话").__dict__
|
|
||||||
|
|
||||||
# 创建 JSONL 内容
|
|
||||||
jsonl_content = "\n".join(jsonl_lines)
|
|
||||||
|
|
||||||
# 创建一个内存文件对象
|
|
||||||
file_obj = BytesIO(jsonl_content.encode("utf-8"))
|
|
||||||
file_obj.seek(0)
|
|
||||||
|
|
||||||
# 生成文件名
|
|
||||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
filename = f"astrbot_conversations_export_{timestamp}.jsonl"
|
|
||||||
|
|
||||||
# 返回文件流
|
|
||||||
return await send_file(
|
|
||||||
file_obj,
|
|
||||||
mimetype="application/jsonl",
|
|
||||||
as_attachment=True,
|
|
||||||
attachment_filename=filename,
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"批量导出对话失败: {e!s}\n{traceback.format_exc()}")
|
|
||||||
return Response().error(f"批量导出对话失败: {e!s}").__dict__
|
|
||||||
|
|||||||
@@ -48,7 +48,6 @@ class KnowledgeBaseRoute(Route):
|
|||||||
# 文档管理
|
# 文档管理
|
||||||
"/kb/document/list": ("GET", self.list_documents),
|
"/kb/document/list": ("GET", self.list_documents),
|
||||||
"/kb/document/upload": ("POST", self.upload_document),
|
"/kb/document/upload": ("POST", self.upload_document),
|
||||||
"/kb/document/import": ("POST", self.import_documents),
|
|
||||||
"/kb/document/upload/url": ("POST", self.upload_document_from_url),
|
"/kb/document/upload/url": ("POST", self.upload_document_from_url),
|
||||||
"/kb/document/upload/progress": ("GET", self.get_upload_progress),
|
"/kb/document/upload/progress": ("GET", self.get_upload_progress),
|
||||||
"/kb/document/get": ("GET", self.get_document),
|
"/kb/document/get": ("GET", self.get_document),
|
||||||
@@ -67,65 +66,6 @@ class KnowledgeBaseRoute(Route):
|
|||||||
def _get_kb_manager(self):
|
def _get_kb_manager(self):
|
||||||
return self.core_lifecycle.kb_manager
|
return self.core_lifecycle.kb_manager
|
||||||
|
|
||||||
def _init_task(self, task_id: str, status: str = "pending") -> None:
|
|
||||||
self.upload_tasks[task_id] = {
|
|
||||||
"status": status,
|
|
||||||
"result": None,
|
|
||||||
"error": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
def _set_task_result(
|
|
||||||
self, task_id: str, status: str, result: any = None, error: str | None = None
|
|
||||||
) -> None:
|
|
||||||
self.upload_tasks[task_id] = {
|
|
||||||
"status": status,
|
|
||||||
"result": result,
|
|
||||||
"error": error,
|
|
||||||
}
|
|
||||||
if task_id in self.upload_progress:
|
|
||||||
self.upload_progress[task_id]["status"] = status
|
|
||||||
|
|
||||||
def _update_progress(
|
|
||||||
self,
|
|
||||||
task_id: str,
|
|
||||||
*,
|
|
||||||
status: str | None = None,
|
|
||||||
file_index: int | None = None,
|
|
||||||
file_name: str | None = None,
|
|
||||||
stage: str | None = None,
|
|
||||||
current: int | None = None,
|
|
||||||
total: int | None = None,
|
|
||||||
) -> None:
|
|
||||||
if task_id not in self.upload_progress:
|
|
||||||
return
|
|
||||||
p = self.upload_progress[task_id]
|
|
||||||
if status is not None:
|
|
||||||
p["status"] = status
|
|
||||||
if file_index is not None:
|
|
||||||
p["file_index"] = file_index
|
|
||||||
if file_name is not None:
|
|
||||||
p["file_name"] = file_name
|
|
||||||
if stage is not None:
|
|
||||||
p["stage"] = stage
|
|
||||||
if current is not None:
|
|
||||||
p["current"] = current
|
|
||||||
if total is not None:
|
|
||||||
p["total"] = total
|
|
||||||
|
|
||||||
def _make_progress_callback(self, task_id: str, file_idx: int, file_name: str):
|
|
||||||
async def _callback(stage: str, current: int, total: int):
|
|
||||||
self._update_progress(
|
|
||||||
task_id,
|
|
||||||
status="processing",
|
|
||||||
file_index=file_idx,
|
|
||||||
file_name=file_name,
|
|
||||||
stage=stage,
|
|
||||||
current=current,
|
|
||||||
total=total,
|
|
||||||
)
|
|
||||||
|
|
||||||
return _callback
|
|
||||||
|
|
||||||
async def _background_upload_task(
|
async def _background_upload_task(
|
||||||
self,
|
self,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
@@ -140,7 +80,11 @@ class KnowledgeBaseRoute(Route):
|
|||||||
"""后台上传任务"""
|
"""后台上传任务"""
|
||||||
try:
|
try:
|
||||||
# 初始化任务状态
|
# 初始化任务状态
|
||||||
self._init_task(task_id, status="processing")
|
self.upload_tasks[task_id] = {
|
||||||
|
"status": "processing",
|
||||||
|
"result": None,
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
self.upload_progress[task_id] = {
|
self.upload_progress[task_id] = {
|
||||||
"status": "processing",
|
"status": "processing",
|
||||||
"file_index": 0,
|
"file_index": 0,
|
||||||
@@ -156,20 +100,30 @@ class KnowledgeBaseRoute(Route):
|
|||||||
for file_idx, file_info in enumerate(files_to_upload):
|
for file_idx, file_info in enumerate(files_to_upload):
|
||||||
try:
|
try:
|
||||||
# 更新整体进度
|
# 更新整体进度
|
||||||
self._update_progress(
|
self.upload_progress[task_id].update(
|
||||||
task_id,
|
{
|
||||||
status="processing",
|
"status": "processing",
|
||||||
file_index=file_idx,
|
"file_index": file_idx,
|
||||||
file_name=file_info["file_name"],
|
"file_name": file_info["file_name"],
|
||||||
stage="parsing",
|
"stage": "parsing",
|
||||||
current=0,
|
"current": 0,
|
||||||
total=100,
|
"total": 100,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建进度回调函数
|
# 创建进度回调函数
|
||||||
progress_callback = self._make_progress_callback(
|
async def progress_callback(stage, current, total):
|
||||||
task_id, file_idx, file_info["file_name"]
|
if task_id in self.upload_progress:
|
||||||
)
|
self.upload_progress[task_id].update(
|
||||||
|
{
|
||||||
|
"status": "processing",
|
||||||
|
"file_index": file_idx,
|
||||||
|
"file_name": file_info["file_name"],
|
||||||
|
"stage": stage,
|
||||||
|
"current": current,
|
||||||
|
"total": total,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
doc = await kb_helper.upload_document(
|
doc = await kb_helper.upload_document(
|
||||||
file_name=file_info["file_name"],
|
file_name=file_info["file_name"],
|
||||||
@@ -200,99 +154,23 @@ class KnowledgeBaseRoute(Route):
|
|||||||
"failed_count": len(failed_docs),
|
"failed_count": len(failed_docs),
|
||||||
}
|
}
|
||||||
|
|
||||||
self._set_task_result(task_id, "completed", result=result)
|
self.upload_tasks[task_id] = {
|
||||||
|
"status": "completed",
|
||||||
|
"result": result,
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
self.upload_progress[task_id]["status"] = "completed"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"后台上传任务 {task_id} 失败: {e}")
|
logger.error(f"后台上传任务 {task_id} 失败: {e}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
self._set_task_result(task_id, "failed", error=str(e))
|
self.upload_tasks[task_id] = {
|
||||||
|
"status": "failed",
|
||||||
async def _background_import_task(
|
"result": None,
|
||||||
self,
|
"error": str(e),
|
||||||
task_id: str,
|
|
||||||
kb_helper,
|
|
||||||
documents: list,
|
|
||||||
batch_size: int,
|
|
||||||
tasks_limit: int,
|
|
||||||
max_retries: int,
|
|
||||||
):
|
|
||||||
"""后台导入预切片文档任务"""
|
|
||||||
try:
|
|
||||||
# 初始化任务状态
|
|
||||||
self._init_task(task_id, status="processing")
|
|
||||||
self.upload_progress[task_id] = {
|
|
||||||
"status": "processing",
|
|
||||||
"file_index": 0,
|
|
||||||
"file_total": len(documents),
|
|
||||||
"stage": "waiting",
|
|
||||||
"current": 0,
|
|
||||||
"total": 100,
|
|
||||||
}
|
}
|
||||||
|
if task_id in self.upload_progress:
|
||||||
uploaded_docs = []
|
self.upload_progress[task_id]["status"] = "failed"
|
||||||
failed_docs = []
|
|
||||||
|
|
||||||
for file_idx, doc_info in enumerate(documents):
|
|
||||||
file_name = doc_info.get("file_name", f"imported_doc_{file_idx}")
|
|
||||||
chunks = doc_info.get("chunks", [])
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 更新整体进度
|
|
||||||
self._update_progress(
|
|
||||||
task_id,
|
|
||||||
status="processing",
|
|
||||||
file_index=file_idx,
|
|
||||||
file_name=file_name,
|
|
||||||
stage="importing",
|
|
||||||
current=0,
|
|
||||||
total=100,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 创建进度回调函数
|
|
||||||
progress_callback = self._make_progress_callback(
|
|
||||||
task_id, file_idx, file_name
|
|
||||||
)
|
|
||||||
|
|
||||||
# 调用 upload_document,传入 pre_chunked_text
|
|
||||||
doc = await kb_helper.upload_document(
|
|
||||||
file_name=file_name,
|
|
||||||
file_content=None, # 预切片模式下不需要原始内容
|
|
||||||
file_type=doc_info.get("file_type")
|
|
||||||
or (
|
|
||||||
file_name.rsplit(".", 1)[-1].lower()
|
|
||||||
if "." in file_name
|
|
||||||
else "txt"
|
|
||||||
),
|
|
||||||
batch_size=batch_size,
|
|
||||||
tasks_limit=tasks_limit,
|
|
||||||
max_retries=max_retries,
|
|
||||||
progress_callback=progress_callback,
|
|
||||||
pre_chunked_text=chunks,
|
|
||||||
)
|
|
||||||
|
|
||||||
uploaded_docs.append(doc.model_dump())
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"导入文档 {file_name} 失败: {e}")
|
|
||||||
failed_docs.append(
|
|
||||||
{"file_name": file_name, "error": str(e)},
|
|
||||||
)
|
|
||||||
|
|
||||||
# 更新任务完成状态
|
|
||||||
result = {
|
|
||||||
"task_id": task_id,
|
|
||||||
"uploaded": uploaded_docs,
|
|
||||||
"failed": failed_docs,
|
|
||||||
"total": len(documents),
|
|
||||||
"success_count": len(uploaded_docs),
|
|
||||||
"failed_count": len(failed_docs),
|
|
||||||
}
|
|
||||||
|
|
||||||
self._set_task_result(task_id, "completed", result=result)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"后台导入任务 {task_id} 失败: {e}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
self._set_task_result(task_id, "failed", error=str(e))
|
|
||||||
|
|
||||||
async def list_kbs(self):
|
async def list_kbs(self):
|
||||||
"""获取知识库列表
|
"""获取知识库列表
|
||||||
@@ -736,7 +614,11 @@ class KnowledgeBaseRoute(Route):
|
|||||||
task_id = str(uuid.uuid4())
|
task_id = str(uuid.uuid4())
|
||||||
|
|
||||||
# 初始化任务状态
|
# 初始化任务状态
|
||||||
self._init_task(task_id, status="pending")
|
self.upload_tasks[task_id] = {
|
||||||
|
"status": "pending",
|
||||||
|
"result": None,
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
|
||||||
# 启动后台任务
|
# 启动后台任务
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
@@ -771,93 +653,6 @@ class KnowledgeBaseRoute(Route):
|
|||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return Response().error(f"上传文档失败: {e!s}").__dict__
|
return Response().error(f"上传文档失败: {e!s}").__dict__
|
||||||
|
|
||||||
def _validate_import_request(self, data: dict):
|
|
||||||
kb_id = data.get("kb_id")
|
|
||||||
if not kb_id:
|
|
||||||
raise ValueError("缺少参数 kb_id")
|
|
||||||
|
|
||||||
documents = data.get("documents")
|
|
||||||
if not documents or not isinstance(documents, list):
|
|
||||||
raise ValueError("缺少参数 documents 或格式错误")
|
|
||||||
|
|
||||||
for doc in documents:
|
|
||||||
if "file_name" not in doc or "chunks" not in doc:
|
|
||||||
raise ValueError("文档格式错误,必须包含 file_name 和 chunks")
|
|
||||||
if not isinstance(doc["chunks"], list):
|
|
||||||
raise ValueError("chunks 必须是列表")
|
|
||||||
if not all(
|
|
||||||
isinstance(chunk, str) and chunk.strip() for chunk in doc["chunks"]
|
|
||||||
):
|
|
||||||
raise ValueError("chunks 必须是非空字符串列表")
|
|
||||||
|
|
||||||
batch_size = data.get("batch_size", 32)
|
|
||||||
tasks_limit = data.get("tasks_limit", 3)
|
|
||||||
max_retries = data.get("max_retries", 3)
|
|
||||||
return kb_id, documents, batch_size, tasks_limit, max_retries
|
|
||||||
|
|
||||||
async def import_documents(self):
|
|
||||||
"""导入预切片文档
|
|
||||||
|
|
||||||
Body:
|
|
||||||
- kb_id: 知识库 ID (必填)
|
|
||||||
- documents: 文档列表 (必填)
|
|
||||||
- file_name: 文件名 (必填)
|
|
||||||
- chunks: 切片列表 (必填, list[str])
|
|
||||||
- file_type: 文件类型 (可选, 默认从文件名推断或为 txt)
|
|
||||||
- batch_size: 批处理大小 (可选, 默认32)
|
|
||||||
- tasks_limit: 并发任务限制 (可选, 默认3)
|
|
||||||
- max_retries: 最大重试次数 (可选, 默认3)
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
kb_manager = self._get_kb_manager()
|
|
||||||
data = await request.json
|
|
||||||
|
|
||||||
kb_id, documents, batch_size, tasks_limit, max_retries = (
|
|
||||||
self._validate_import_request(data)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取知识库
|
|
||||||
kb_helper = await kb_manager.get_kb(kb_id)
|
|
||||||
if not kb_helper:
|
|
||||||
return Response().error("知识库不存在").__dict__
|
|
||||||
|
|
||||||
# 生成任务ID
|
|
||||||
task_id = str(uuid.uuid4())
|
|
||||||
|
|
||||||
# 初始化任务状态
|
|
||||||
self._init_task(task_id, status="pending")
|
|
||||||
|
|
||||||
# 启动后台任务
|
|
||||||
asyncio.create_task(
|
|
||||||
self._background_import_task(
|
|
||||||
task_id=task_id,
|
|
||||||
kb_helper=kb_helper,
|
|
||||||
documents=documents,
|
|
||||||
batch_size=batch_size,
|
|
||||||
tasks_limit=tasks_limit,
|
|
||||||
max_retries=max_retries,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
return (
|
|
||||||
Response()
|
|
||||||
.ok(
|
|
||||||
{
|
|
||||||
"task_id": task_id,
|
|
||||||
"doc_count": len(documents),
|
|
||||||
"message": "import task created, processing in background",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
.__dict__
|
|
||||||
)
|
|
||||||
|
|
||||||
except ValueError as e:
|
|
||||||
return Response().error(str(e)).__dict__
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"导入文档失败: {e}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return Response().error(f"导入文档失败: {e!s}").__dict__
|
|
||||||
|
|
||||||
async def get_upload_progress(self):
|
async def get_upload_progress(self):
|
||||||
"""获取上传进度和结果
|
"""获取上传进度和结果
|
||||||
|
|
||||||
@@ -1165,7 +960,11 @@ class KnowledgeBaseRoute(Route):
|
|||||||
task_id = str(uuid.uuid4())
|
task_id = str(uuid.uuid4())
|
||||||
|
|
||||||
# 初始化任务状态
|
# 初始化任务状态
|
||||||
self._init_task(task_id, status="pending")
|
self.upload_tasks[task_id] = {
|
||||||
|
"status": "pending",
|
||||||
|
"result": None,
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
|
||||||
# 启动后台任务
|
# 启动后台任务
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
@@ -1218,7 +1017,11 @@ class KnowledgeBaseRoute(Route):
|
|||||||
"""后台上传URL任务"""
|
"""后台上传URL任务"""
|
||||||
try:
|
try:
|
||||||
# 初始化任务状态
|
# 初始化任务状态
|
||||||
self._init_task(task_id, status="processing")
|
self.upload_tasks[task_id] = {
|
||||||
|
"status": "processing",
|
||||||
|
"result": None,
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
self.upload_progress[task_id] = {
|
self.upload_progress[task_id] = {
|
||||||
"status": "processing",
|
"status": "processing",
|
||||||
"file_index": 0,
|
"file_index": 0,
|
||||||
@@ -1230,7 +1033,18 @@ class KnowledgeBaseRoute(Route):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# 创建进度回调函数
|
# 创建进度回调函数
|
||||||
progress_callback = self._make_progress_callback(task_id, 0, f"URL: {url}")
|
async def progress_callback(stage, current, total):
|
||||||
|
if task_id in self.upload_progress:
|
||||||
|
self.upload_progress[task_id].update(
|
||||||
|
{
|
||||||
|
"status": "processing",
|
||||||
|
"file_index": 0,
|
||||||
|
"file_name": f"URL: {url}",
|
||||||
|
"stage": stage,
|
||||||
|
"current": current,
|
||||||
|
"total": total,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# 上传文档
|
# 上传文档
|
||||||
doc = await kb_helper.upload_from_url(
|
doc = await kb_helper.upload_from_url(
|
||||||
@@ -1255,9 +1069,20 @@ class KnowledgeBaseRoute(Route):
|
|||||||
"failed_count": 0,
|
"failed_count": 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
self._set_task_result(task_id, "completed", result=result)
|
self.upload_tasks[task_id] = {
|
||||||
|
"status": "completed",
|
||||||
|
"result": result,
|
||||||
|
"error": None,
|
||||||
|
}
|
||||||
|
self.upload_progress[task_id]["status"] = "completed"
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"后台上传URL任务 {task_id} 失败: {e}")
|
logger.error(f"后台上传URL任务 {task_id} 失败: {e}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
self._set_task_result(task_id, "failed", error=str(e))
|
self.upload_tasks[task_id] = {
|
||||||
|
"status": "failed",
|
||||||
|
"result": None,
|
||||||
|
"error": str(e),
|
||||||
|
}
|
||||||
|
if task_id in self.upload_progress:
|
||||||
|
self.upload_progress[task_id]["status"] = "failed"
|
||||||
|
|||||||
@@ -124,11 +124,7 @@ class PluginRoute(Route):
|
|||||||
session.get(url) as response,
|
session.get(url) as response,
|
||||||
):
|
):
|
||||||
if response.status == 200:
|
if response.status == 200:
|
||||||
try:
|
remote_data = await response.json()
|
||||||
remote_data = await response.json()
|
|
||||||
except aiohttp.ContentTypeError:
|
|
||||||
remote_text = await response.text()
|
|
||||||
remote_data = json.loads(remote_text)
|
|
||||||
|
|
||||||
# 检查远程数据是否为空
|
# 检查远程数据是否为空
|
||||||
if not remote_data or (
|
if not remote_data or (
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import traceback
|
|||||||
from quart import request
|
from quart import request
|
||||||
|
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.agent.mcp_client import MCPTool
|
|
||||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||||
from astrbot.core.star import star_map
|
from astrbot.core.star import star_map
|
||||||
|
|
||||||
@@ -297,30 +296,15 @@ class ToolsRoute(Route):
|
|||||||
"""获取所有注册的工具列表"""
|
"""获取所有注册的工具列表"""
|
||||||
try:
|
try:
|
||||||
tools = self.tool_mgr.func_list
|
tools = self.tool_mgr.func_list
|
||||||
tools_dict = []
|
tools_dict = [
|
||||||
for tool in tools:
|
{
|
||||||
if isinstance(tool, MCPTool):
|
|
||||||
origin = "mcp"
|
|
||||||
origin_name = tool.mcp_server_name
|
|
||||||
elif tool.handler_module_path and star_map.get(
|
|
||||||
tool.handler_module_path
|
|
||||||
):
|
|
||||||
star = star_map[tool.handler_module_path]
|
|
||||||
origin = "plugin"
|
|
||||||
origin_name = star.name
|
|
||||||
else:
|
|
||||||
origin = "unknown"
|
|
||||||
origin_name = "unknown"
|
|
||||||
|
|
||||||
tool_info = {
|
|
||||||
"name": tool.name,
|
"name": tool.name,
|
||||||
"description": tool.description,
|
"description": tool.description,
|
||||||
"parameters": tool.parameters,
|
"parameters": tool.parameters,
|
||||||
"active": tool.active,
|
"active": tool.active,
|
||||||
"origin": origin,
|
|
||||||
"origin_name": origin_name,
|
|
||||||
}
|
}
|
||||||
tools_dict.append(tool_info)
|
for tool in tools
|
||||||
|
]
|
||||||
return Response().ok(data=tools_dict).__dict__
|
return Response().ok(data=tools_dict).__dict__
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
|||||||
@@ -67,7 +67,6 @@ class AstrBotDashboard:
|
|||||||
core_lifecycle,
|
core_lifecycle,
|
||||||
core_lifecycle.plugin_manager,
|
core_lifecycle.plugin_manager,
|
||||||
)
|
)
|
||||||
self.command_route = CommandRoute(self.context)
|
|
||||||
self.cr = ConfigRoute(self.context, core_lifecycle)
|
self.cr = ConfigRoute(self.context, core_lifecycle)
|
||||||
self.lr = LogRoute(self.context, core_lifecycle.log_broker)
|
self.lr = LogRoute(self.context, core_lifecycle.log_broker)
|
||||||
self.sfr = StaticFileRoute(self.context)
|
self.sfr = StaticFileRoute(self.context)
|
||||||
|
|||||||
@@ -1,134 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Use Nuitka to build the AstrBot project into standalone executables
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import platform
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
|
|
||||||
def get_platform_info():
|
|
||||||
"""fetch the current platform information"""
|
|
||||||
system = platform.system()
|
|
||||||
machine = platform.machine()
|
|
||||||
return system, machine
|
|
||||||
|
|
||||||
|
|
||||||
def build_with_nuitka():
|
|
||||||
"""use Nuitka to build the project"""
|
|
||||||
system, machine = get_platform_info()
|
|
||||||
|
|
||||||
print(f"🚀 Starting build for {system} ({machine}) platform...")
|
|
||||||
|
|
||||||
# Output directory
|
|
||||||
output_dir = Path("build/nuitka")
|
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Base Nuitka command
|
|
||||||
nuitka_cmd = [
|
|
||||||
sys.executable,
|
|
||||||
"-m",
|
|
||||||
"nuitka",
|
|
||||||
"--standalone", # Create standalone directory
|
|
||||||
"--onefile", # Single file mode
|
|
||||||
"--follow-imports", # Follow all imports
|
|
||||||
"--enable-plugin=multiprocessing", # Enable multiprocessing support
|
|
||||||
"--output-dir=build/nuitka", # Output directory
|
|
||||||
"--quiet", # Reduce output verbosity
|
|
||||||
"--assume-yes-for-downloads", # Automatically download dependencies
|
|
||||||
"--jobs=4", # Use multiple CPU cores
|
|
||||||
]
|
|
||||||
|
|
||||||
# include specific packages
|
|
||||||
include_packages = [
|
|
||||||
"astrbot",
|
|
||||||
]
|
|
||||||
|
|
||||||
for pkg in include_packages:
|
|
||||||
nuitka_cmd.extend([f"--include-package={pkg}"])
|
|
||||||
|
|
||||||
# include data directories
|
|
||||||
# data_includes = [
|
|
||||||
# "data/config",
|
|
||||||
# "data/plugins",
|
|
||||||
# "data/temp",
|
|
||||||
# ]
|
|
||||||
|
|
||||||
# for data_dir in data_includes:
|
|
||||||
# if os.path.exists(data_dir):
|
|
||||||
# nuitka_cmd.extend([f"--include-data-dir={data_dir}={data_dir}"])
|
|
||||||
|
|
||||||
# include packages directory (built-in plugins)
|
|
||||||
# if os.path.exists("packages"):
|
|
||||||
# nuitka_cmd.extend(["--include-data-dir=packages=packages"])
|
|
||||||
|
|
||||||
# Platform specific settings
|
|
||||||
if system == "Darwin": # macOS
|
|
||||||
nuitka_cmd.extend(
|
|
||||||
[
|
|
||||||
"--macos-create-app-bundle", # Create .app bundle
|
|
||||||
"--macos-app-name=AstrBot",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
# macOS icon (if exists)
|
|
||||||
icon_path = "dashboard/src-tauri/icons/icon.icns"
|
|
||||||
if os.path.exists(icon_path):
|
|
||||||
nuitka_cmd.extend([f"--macos-app-icon={icon_path}"])
|
|
||||||
elif system == "Windows":
|
|
||||||
nuitka_cmd.extend(
|
|
||||||
[
|
|
||||||
"--windows-console-mode=disable", # 无控制台窗口
|
|
||||||
]
|
|
||||||
)
|
|
||||||
# Windows icon (if exists)
|
|
||||||
icon_path = "dashboard/src-tauri/icons/icon.ico"
|
|
||||||
if os.path.exists(icon_path):
|
|
||||||
nuitka_cmd.extend([f"--windows-icon-from-ico={icon_path}"])
|
|
||||||
|
|
||||||
# Main file to compile
|
|
||||||
nuitka_cmd.append("main.py")
|
|
||||||
|
|
||||||
print(f"📦 Executing command: {' '.join(nuitka_cmd)}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
subprocess.run(nuitka_cmd, check=True)
|
|
||||||
print("✅ Nuitka build successful!")
|
|
||||||
|
|
||||||
# Find the generated executable
|
|
||||||
if system == "Darwin":
|
|
||||||
built_file = list(output_dir.glob("*.app"))
|
|
||||||
if built_file:
|
|
||||||
print(f"Generated macOS app: {built_file[0]}")
|
|
||||||
elif system == "Windows":
|
|
||||||
built_file = list(output_dir.glob("*.exe"))
|
|
||||||
if built_file:
|
|
||||||
print(f"Generated Windows executable: {built_file[0]}")
|
|
||||||
else: # Linux
|
|
||||||
built_file = list(output_dir.glob("main.bin"))
|
|
||||||
if built_file:
|
|
||||||
print(f"Generated Linux executable: {built_file[0]}")
|
|
||||||
|
|
||||||
return True
|
|
||||||
except subprocess.CalledProcessError as e:
|
|
||||||
print(f"❌ Nuitka build failed: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print("=" * 60)
|
|
||||||
print("AstrBot Nuitka Builder")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# 构建
|
|
||||||
if build_with_nuitka():
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("🎉 Build Complete!")
|
|
||||||
print("=" * 60)
|
|
||||||
else:
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("❌ Build Failed")
|
|
||||||
print("=" * 60)
|
|
||||||
sys.exit(1)
|
|
||||||
@@ -1,134 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
"""
|
|
||||||
Use PyInstaller to build the AstrBot project into standalone executables
|
|
||||||
"""
|
|
||||||
|
|
||||||
import platform
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
|
|
||||||
def get_platform_info():
|
|
||||||
"""fetch the current platform information"""
|
|
||||||
system = platform.system()
|
|
||||||
machine = platform.machine()
|
|
||||||
return system, machine
|
|
||||||
|
|
||||||
|
|
||||||
def build_with_pyinstaller():
|
|
||||||
"""use PyInstaller to build the project"""
|
|
||||||
system, machine = get_platform_info()
|
|
||||||
|
|
||||||
print(f"🚀 Starting build for {system} ({machine}) platform...")
|
|
||||||
|
|
||||||
# Output directory
|
|
||||||
output_dir = Path("build/pyinstaller")
|
|
||||||
output_dir.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Base PyInstaller command
|
|
||||||
pyinstaller_cmd = [
|
|
||||||
sys.executable,
|
|
||||||
"-m",
|
|
||||||
"PyInstaller",
|
|
||||||
"--clean", # Clean cache before build
|
|
||||||
"--noconfirm", # Replace output directory without asking
|
|
||||||
"--onefile", # Single file mode
|
|
||||||
"--distpath=build/pyinstaller/dist", # Distribution directory
|
|
||||||
"--workpath=build/pyinstaller/build", # Work directory
|
|
||||||
"--specpath=build/pyinstaller", # Spec file directory
|
|
||||||
"--name=AstrBot", # Output executable name
|
|
||||||
]
|
|
||||||
# Platform specific settings
|
|
||||||
# if system == "Darwin": # macOS
|
|
||||||
# # macOS icon (if exists)
|
|
||||||
# icon_path = "dashboard/src-tauri/icons/icon.icns"
|
|
||||||
# if os.path.exists(icon_path):
|
|
||||||
# pyinstaller_cmd.extend([f"--icon={icon_path}"])
|
|
||||||
# # Create .app bundle
|
|
||||||
# pyinstaller_cmd.extend(["--windowed"])
|
|
||||||
# elif system == "Windows":
|
|
||||||
# # Windows icon (if exists)
|
|
||||||
# icon_path = "dashboard/src-tauri/icons/icon.ico"
|
|
||||||
# if os.path.exists(icon_path):
|
|
||||||
# pyinstaller_cmd.extend([f"--icon={icon_path}"])
|
|
||||||
# # No console window
|
|
||||||
# pyinstaller_cmd.extend(["--windowed"])
|
|
||||||
# else: # Linux
|
|
||||||
# pyinstaller_cmd.extend(["--console"])
|
|
||||||
|
|
||||||
# Main file to compile
|
|
||||||
pyinstaller_cmd.append("main.py")
|
|
||||||
|
|
||||||
print(f"📦 Executing command: {' '.join(pyinstaller_cmd)}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
subprocess.run(pyinstaller_cmd, check=True)
|
|
||||||
print("✅ PyInstaller build successful!")
|
|
||||||
|
|
||||||
# Find the generated executable
|
|
||||||
dist_dir = output_dir / "dist"
|
|
||||||
if system == "Darwin":
|
|
||||||
built_file = list(dist_dir.glob("AstrBot.app"))
|
|
||||||
if not built_file:
|
|
||||||
built_file = list(dist_dir.glob("AstrBot"))
|
|
||||||
if built_file:
|
|
||||||
print(f"📱 Generated macOS app: {built_file[0]}")
|
|
||||||
elif system == "Windows":
|
|
||||||
built_file = list(dist_dir.glob("AstrBot.exe"))
|
|
||||||
if built_file:
|
|
||||||
print(f"💻 Generated Windows executable: {built_file[0]}")
|
|
||||||
else: # Linux
|
|
||||||
built_file = list(dist_dir.glob("AstrBot"))
|
|
||||||
if built_file:
|
|
||||||
print(f"🐧 Generated Linux executable: {built_file[0]}")
|
|
||||||
|
|
||||||
print(f"\n📁 Output directory: {dist_dir.absolute()}")
|
|
||||||
return True
|
|
||||||
except subprocess.CalledProcessError as e:
|
|
||||||
print(f"❌ PyInstaller build failed: {e}")
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ Unexpected error: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def install_pyinstaller():
|
|
||||||
"""Install PyInstaller if not already installed"""
|
|
||||||
try:
|
|
||||||
import PyInstaller
|
|
||||||
|
|
||||||
print(f"✅ PyInstaller already installed (version {PyInstaller.__version__})")
|
|
||||||
return True
|
|
||||||
except ImportError:
|
|
||||||
print("📥 PyInstaller not found, installing...")
|
|
||||||
try:
|
|
||||||
subprocess.run(
|
|
||||||
[sys.executable, "-m", "pip", "install", "pyinstaller"], check=True
|
|
||||||
)
|
|
||||||
print("✅ PyInstaller installed successfully!")
|
|
||||||
return True
|
|
||||||
except subprocess.CalledProcessError as e:
|
|
||||||
print(f"❌ Failed to install PyInstaller: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
print("=" * 60)
|
|
||||||
print("AstrBot PyInstaller Builder")
|
|
||||||
print("=" * 60)
|
|
||||||
|
|
||||||
# Check and install PyInstaller
|
|
||||||
if not install_pyinstaller():
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
# Build
|
|
||||||
if build_with_pyinstaller():
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("🎉 Build Complete!")
|
|
||||||
print("=" * 60)
|
|
||||||
else:
|
|
||||||
print("\n" + "=" * 60)
|
|
||||||
print("❌ Build Failed")
|
|
||||||
print("=" * 60)
|
|
||||||
sys.exit(1)
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
## What's Changed
|
|
||||||
|
|
||||||
### 新增
|
|
||||||
|
|
||||||
- 支持自定义插件源。
|
|
||||||
- 支持飞书(Lark)的 Webhook 模式(将事件推送至开发者服务器)。
|
|
||||||
- 支持 “禁用自带指令” 快捷配置项,启用后将禁用所有 AstrBot 自带指令。入口: WebUI -> 配置文件 -> 平台配置。
|
|
||||||
|
|
||||||
### 优化
|
|
||||||
|
|
||||||
- 从 WebUI 移除了开发版本渠道。
|
|
||||||
- 当试图测试"Agent Runner"时,提示前往配置文件页测试。
|
|
||||||
- WebUI 列表项支持批量粘贴、回车创建项目。
|
|
||||||
|
|
||||||
### 修复
|
|
||||||
|
|
||||||
- Gemini API 部分调用失败的问题。
|
|
||||||
- WebUI 插件安装加载 Dialog 关闭按钮在手机端下显示异常的问题。
|
|
||||||
- 部分情况下,WebUI 日志显示不全的问题。
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
## What's Changed
|
|
||||||
|
|
||||||
-
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
## What's Changed
|
|
||||||
|
|
||||||
### 修复
|
|
||||||
|
|
||||||
- 企业自部署飞书(自定义 domain)可以接收消息但无法发送消息的问题。
|
|
||||||
- 安装插件 Dialog 的深色样式问题。
|
|
||||||
|
|
||||||
### 优化
|
|
||||||
|
|
||||||
- 避免某些插件在流式响应结束后重d复发送消息的问题。
|
|
||||||
|
|
||||||
### 新增
|
|
||||||
|
|
||||||
- 支持在对话管理批量导出对话轨迹数据为 `jsonl` 格式文件。入口:WebUI -> 对话管理 -> 批量选中 -> 导出。
|
|
||||||
- 支持对 TTS(文本转语音)设置概率触发。
|
|
||||||
- (插件开发)支持在 schema 中对 float 和 int 类型设置 `slider` 滑块控件。例如 `slider: {min: 0, max: 1, step: 0.1}`。
|
|
||||||
- (插件开发)支持 key-value 存储功能。例如使用 `await self.put_kv_data("key", value)`, `await self.get_kv_data("key", default_value)` 和 `await self.delete_kv_data("key")`。
|
|
||||||
@@ -1,225 +0,0 @@
|
|||||||
# AstrBot Dashboard - Tauri 桌面应用
|
|
||||||
|
|
||||||
本项目现已支持通过 Tauri 构建为桌面应用,同时保持与 Web 版本的兼容性。
|
|
||||||
|
|
||||||
## 环境要求
|
|
||||||
|
|
||||||
### 系统依赖
|
|
||||||
|
|
||||||
**macOS:**
|
|
||||||
```bash
|
|
||||||
# 安装 Xcode Command Line Tools
|
|
||||||
xcode-select --install
|
|
||||||
```
|
|
||||||
|
|
||||||
**Windows:**
|
|
||||||
- 安装 [Microsoft Visual Studio C++ Build Tools](https://visualstudio.microsoft.com/visual-cpp-build-tools/)
|
|
||||||
- 安装 [WebView2](https://developer.microsoft.com/en-us/microsoft-edge/webview2/)
|
|
||||||
|
|
||||||
**Linux (Ubuntu/Debian):**
|
|
||||||
```bash
|
|
||||||
sudo apt update
|
|
||||||
sudo apt install libwebkit2gtk-4.0-dev \
|
|
||||||
build-essential \
|
|
||||||
curl \
|
|
||||||
wget \
|
|
||||||
file \
|
|
||||||
libssl-dev \
|
|
||||||
libgtk-3-dev \
|
|
||||||
libayatana-appindicator3-dev \
|
|
||||||
librsvg2-dev
|
|
||||||
```
|
|
||||||
|
|
||||||
### Rust 环境
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# 安装 Rust
|
|
||||||
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
|
|
||||||
|
|
||||||
# 验证安装
|
|
||||||
rustc --version
|
|
||||||
cargo --version
|
|
||||||
```
|
|
||||||
|
|
||||||
## 安装依赖
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd dashboard
|
|
||||||
npm install
|
|
||||||
```
|
|
||||||
|
|
||||||
## 开发模式
|
|
||||||
|
|
||||||
### Web 端开发(不变)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
npm run dev
|
|
||||||
```
|
|
||||||
|
|
||||||
访问 http://localhost:3000
|
|
||||||
|
|
||||||
### 桌面端开发
|
|
||||||
|
|
||||||
```bash
|
|
||||||
npm run tauri:dev
|
|
||||||
```
|
|
||||||
|
|
||||||
这会同时启动:
|
|
||||||
1. Vite 开发服务器(端口 3000)
|
|
||||||
2. Tauri 桌面应用窗口
|
|
||||||
|
|
||||||
热重载功能正常工作,修改代码后会自动刷新。
|
|
||||||
|
|
||||||
## 构建
|
|
||||||
|
|
||||||
### Web 端构建(不变)
|
|
||||||
|
|
||||||
```bash
|
|
||||||
npm run build
|
|
||||||
```
|
|
||||||
|
|
||||||
输出目录:`dist/`
|
|
||||||
|
|
||||||
### 桌面端构建
|
|
||||||
|
|
||||||
```bash
|
|
||||||
npm run tauri:build
|
|
||||||
```
|
|
||||||
|
|
||||||
构建产物位置:
|
|
||||||
- **macOS**: `src-tauri/target/release/bundle/dmg/`
|
|
||||||
- **Windows**: `src-tauri/target/release/bundle/msi/`
|
|
||||||
- **Linux**: `src-tauri/target/release/bundle/deb/` 或 `appimage/`
|
|
||||||
|
|
||||||
## 图标设置
|
|
||||||
|
|
||||||
### 自动生成图标
|
|
||||||
|
|
||||||
准备一个至少 512x512 像素的 PNG 图标,然后运行:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
npm run tauri icon path/to/your/icon.png
|
|
||||||
```
|
|
||||||
|
|
||||||
### 手动设置图标
|
|
||||||
|
|
||||||
将以下图标放入 `src-tauri/icons/` 目录:
|
|
||||||
- `32x32.png`
|
|
||||||
- `128x128.png`
|
|
||||||
- `128x128@2x.png`
|
|
||||||
- `icon.icns` (macOS)
|
|
||||||
- `icon.ico` (Windows)
|
|
||||||
|
|
||||||
## 代码兼容性
|
|
||||||
|
|
||||||
项目已配置为同时支持 Web 和桌面端,使用相同的代码库。
|
|
||||||
|
|
||||||
### 环境检测工具
|
|
||||||
|
|
||||||
在 `src/utils/tauri.ts` 中提供了环境检测工具:
|
|
||||||
|
|
||||||
```typescript
|
|
||||||
import { isTauri, isWeb, PlatformAPI } from '@/utils/tauri';
|
|
||||||
|
|
||||||
// 检测运行环境
|
|
||||||
if (isTauri()) {
|
|
||||||
console.log('运行在桌面应用中');
|
|
||||||
} else {
|
|
||||||
console.log('运行在浏览器中');
|
|
||||||
}
|
|
||||||
|
|
||||||
// 获取正确的 API 端点
|
|
||||||
const baseURL = PlatformAPI.getBaseURL();
|
|
||||||
```
|
|
||||||
|
|
||||||
### API 调用注意事项
|
|
||||||
|
|
||||||
- **Web 端**: 使用 Vite 代理,API 路径为 `/api/*`
|
|
||||||
- **桌面端**: 直接连接到 `http://127.0.0.1:6185`
|
|
||||||
|
|
||||||
已在 `PlatformAPI.getBaseURL()` 中处理,使用 axios 时:
|
|
||||||
|
|
||||||
```typescript
|
|
||||||
import axios from 'axios';
|
|
||||||
import { PlatformAPI } from '@/utils/tauri';
|
|
||||||
|
|
||||||
const api = axios.create({
|
|
||||||
baseURL: PlatformAPI.getBaseURL()
|
|
||||||
});
|
|
||||||
```
|
|
||||||
|
|
||||||
## 配置说明
|
|
||||||
|
|
||||||
### tauri.conf.json
|
|
||||||
|
|
||||||
主要配置项:
|
|
||||||
- `build.devPath`: 开发服务器地址(http://localhost:3000)
|
|
||||||
- `build.distDir`: 构建输出目录(../dist)
|
|
||||||
- `tauri.allowlist`: API 权限配置
|
|
||||||
- `tauri.windows`: 窗口配置(大小、标题等)
|
|
||||||
|
|
||||||
### 安全性
|
|
||||||
|
|
||||||
默认配置已启用必要的权限:
|
|
||||||
- 文件系统访问(限定在 APPDATA 目录)
|
|
||||||
- HTTP 请求(限定到本地后端)
|
|
||||||
- 窗口控制
|
|
||||||
- 对话框(打开/保存文件)
|
|
||||||
|
|
||||||
可在 `tauri.conf.json` 的 `allowlist` 部分调整权限。
|
|
||||||
|
|
||||||
## 后端连接
|
|
||||||
|
|
||||||
桌面应用需要后端服务运行在 `http://127.0.0.1:6185`。
|
|
||||||
|
|
||||||
### 启动流程
|
|
||||||
|
|
||||||
1. 启动 AstrBot 后端:
|
|
||||||
```bash
|
|
||||||
cd /path/to/AstrBot
|
|
||||||
uv run main.py
|
|
||||||
```
|
|
||||||
|
|
||||||
2. 启动桌面应用:
|
|
||||||
```bash
|
|
||||||
cd dashboard
|
|
||||||
npm run tauri:dev
|
|
||||||
```
|
|
||||||
|
|
||||||
或直接运行打包后的应用(后端需要已启动)。
|
|
||||||
|
|
||||||
## 常见问题
|
|
||||||
|
|
||||||
### Q: 桌面应用无法连接到后端?
|
|
||||||
|
|
||||||
确保:
|
|
||||||
1. AstrBot 后端正在运行(`uv run main.py`)
|
|
||||||
2. 后端监听在 `127.0.0.1:6185`
|
|
||||||
3. 防火墙未阻止连接
|
|
||||||
|
|
||||||
### Q: 图标未显示?
|
|
||||||
|
|
||||||
检查 `src-tauri/icons/` 目录中是否有所需的图标文件,或使用 `npm run tauri icon` 命令生成。
|
|
||||||
|
|
||||||
### Q: 构建失败?
|
|
||||||
|
|
||||||
- 确保已安装 Rust 和系统依赖
|
|
||||||
- 运行 `cargo clean` 清理缓存后重试
|
|
||||||
- 检查 Rust 版本(需要 1.60+)
|
|
||||||
|
|
||||||
### Q: Web 端功能是否受影响?
|
|
||||||
|
|
||||||
不受影响。`npm run dev` 和 `npm run build` 的行为完全不变。
|
|
||||||
|
|
||||||
## 开发建议
|
|
||||||
|
|
||||||
1. **优先使用 Web 端开发**: 更快的热重载,更好的调试体验
|
|
||||||
2. **定期测试桌面端**: 确保跨平台兼容性
|
|
||||||
3. **使用环境检测**: 针对不同平台提供最佳体验
|
|
||||||
4. **注意 API 差异**: Web 和桌面端的某些 API 可能有差异
|
|
||||||
|
|
||||||
## 更多资源
|
|
||||||
|
|
||||||
- [Tauri 官方文档](https://tauri.app/)
|
|
||||||
- [Tauri API 参考](https://tauri.app/v1/api/js/)
|
|
||||||
- [Tauri Discord 社区](https://discord.com/invite/tauri)
|
|
||||||
@@ -10,14 +10,10 @@
|
|||||||
"build-prod": "vue-tsc --noEmit && vite build --base=/vue/free/",
|
"build-prod": "vue-tsc --noEmit && vite build --base=/vue/free/",
|
||||||
"preview": "vite preview --port 5050",
|
"preview": "vite preview --port 5050",
|
||||||
"typecheck": "vue-tsc --noEmit",
|
"typecheck": "vue-tsc --noEmit",
|
||||||
"lint": "eslint . --ext .vue,.js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --ignore-path .gitignore",
|
"lint": "eslint . --ext .vue,.js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --ignore-path .gitignore"
|
||||||
"tauri": "tauri",
|
|
||||||
"tauri:dev": "tauri dev",
|
|
||||||
"tauri:build": "tauri build"
|
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@guolao/vue-monaco-editor": "^1.5.4",
|
"@guolao/vue-monaco-editor": "^1.5.4",
|
||||||
"@tauri-apps/api": "^2.9.0",
|
|
||||||
"@tiptap/starter-kit": "2.1.7",
|
"@tiptap/starter-kit": "2.1.7",
|
||||||
"@tiptap/vue-3": "2.1.7",
|
"@tiptap/vue-3": "2.1.7",
|
||||||
"apexcharts": "3.42.0",
|
"apexcharts": "3.42.0",
|
||||||
@@ -47,7 +43,6 @@
|
|||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@mdi/font": "7.2.96",
|
"@mdi/font": "7.2.96",
|
||||||
"@rushstack/eslint-patch": "1.3.3",
|
"@rushstack/eslint-patch": "1.3.3",
|
||||||
"@tauri-apps/cli": "^2.9.4",
|
|
||||||
"@types/chance": "1.1.3",
|
"@types/chance": "1.1.3",
|
||||||
"@types/markdown-it": "^14.1.2",
|
"@types/markdown-it": "^14.1.2",
|
||||||
"@types/node": "^20.5.7",
|
"@types/node": "^20.5.7",
|
||||||
|
|||||||
@@ -1,3 +0,0 @@
|
|||||||
# Tauri specific
|
|
||||||
src-tauri/target/
|
|
||||||
src-tauri/WixTools/
|
|
||||||
@@ -1,27 +0,0 @@
|
|||||||
[package]
|
|
||||||
name = "astrbot-dashboard"
|
|
||||||
version = "4.5.6"
|
|
||||||
description = "AstrBot"
|
|
||||||
authors = ["AstrBot Team"]
|
|
||||||
license = "AGPL-3.0"
|
|
||||||
repository = "https://github.com/AstrBotDevs/AstrBot"
|
|
||||||
default-run = "astrbot-dashboard"
|
|
||||||
edition = "2021"
|
|
||||||
rust-version = "1.91.0"
|
|
||||||
|
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
|
||||||
|
|
||||||
[build-dependencies]
|
|
||||||
tauri-build = { version = "2", features = [] }
|
|
||||||
|
|
||||||
[dependencies]
|
|
||||||
serde_json = "1.0"
|
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
|
||||||
tauri = { version = "2.9.2", features = ["macos-private-api", "protocol-asset"] }
|
|
||||||
tauri-plugin-opener = "2"
|
|
||||||
|
|
||||||
[features]
|
|
||||||
# this feature is used for production builds or when `devPath` points to the filesystem and the built-in dev server is disabled.
|
|
||||||
# If you use cargo directly instead of tauri's cli you can use this feature flag to switch between tauri's `dev` and `build` modes.
|
|
||||||
# DO NOT REMOVE!!
|
|
||||||
custom-protocol = [ "tauri/custom-protocol" ]
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
fn main() {
|
|
||||||
tauri_build::build()
|
|
||||||
}
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
{}
|
|
||||||
|
Before Width: | Height: | Size: 7.3 KiB |
|
Before Width: | Height: | Size: 18 KiB |
|
Before Width: | Height: | Size: 1.3 KiB |
|
Before Width: | Height: | Size: 3.2 KiB |
|
Before Width: | Height: | Size: 5.9 KiB |
|
Before Width: | Height: | Size: 8.2 KiB |
|
Before Width: | Height: | Size: 8.8 KiB |
|
Before Width: | Height: | Size: 20 KiB |
|
Before Width: | Height: | Size: 1.2 KiB |
|
Before Width: | Height: | Size: 23 KiB |
|
Before Width: | Height: | Size: 2.0 KiB |
|
Before Width: | Height: | Size: 3.5 KiB |
|
Before Width: | Height: | Size: 4.8 KiB |
|
Before Width: | Height: | Size: 2.3 KiB |
@@ -1,5 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="utf-8"?>
|
|
||||||
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
|
|
||||||
<foreground android:drawable="@mipmap/ic_launcher_foreground"/>
|
|
||||||
<background android:drawable="@color/ic_launcher_background"/>
|
|
||||||
</adaptive-icon>
|
|
||||||
|
Before Width: | Height: | Size: 2.2 KiB |
|
Before Width: | Height: | Size: 9.8 KiB |
|
Before Width: | Height: | Size: 2.0 KiB |
|
Before Width: | Height: | Size: 2.1 KiB |
|
Before Width: | Height: | Size: 6.0 KiB |
|
Before Width: | Height: | Size: 1.8 KiB |
|
Before Width: | Height: | Size: 4.9 KiB |
|
Before Width: | Height: | Size: 14 KiB |
|
Before Width: | Height: | Size: 4.2 KiB |
|
Before Width: | Height: | Size: 7.9 KiB |
|
Before Width: | Height: | Size: 24 KiB |
|
Before Width: | Height: | Size: 6.8 KiB |
|
Before Width: | Height: | Size: 11 KiB |
|
Before Width: | Height: | Size: 37 KiB |
|
Before Width: | Height: | Size: 9.6 KiB |
@@ -1,4 +0,0 @@
|
|||||||
<?xml version="1.0" encoding="utf-8"?>
|
|
||||||
<resources>
|
|
||||||
<color name="ic_launcher_background">#fff</color>
|
|
||||||
</resources>
|
|
||||||
|
Before Width: | Height: | Size: 27 KiB |
|
Before Width: | Height: | Size: 47 KiB |
|
Before Width: | Height: | Size: 602 B |
|
Before Width: | Height: | Size: 1.4 KiB |