Compare commits

..

5 Commits

Author SHA1 Message Date
copilot-swe-agent[bot] a2fe0ec5a1 Add webhook signature verification for security
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-12-12 14:27:51 +00:00
copilot-swe-agent[bot] 6957ec713d Clean up unused imports in tests
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-12-12 14:24:18 +00:00
copilot-swe-agent[bot] d97c8b5b2b Add tests for GitHub webhook platform adapter
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-12-12 14:23:22 +00:00
copilot-swe-agent[bot] d07a1ad5c9 Add GitHub webhook platform adapter with event handlers
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2025-12-12 14:20:33 +00:00
copilot-swe-agent[bot] d8e6dfbd6b Initial plan 2025-12-12 14:14:49 +00:00
178 changed files with 1695 additions and 20142 deletions
-79
View File
@@ -1,79 +0,0 @@
name: Build Desktop App
on:
push:
tags:
- 'v*'
workflow_dispatch:
jobs:
build:
strategy:
fail-fast: false
matrix:
platform: [macos-latest, ubuntu-latest, windows-latest]
runs-on: ${{ matrix.platform }}
steps:
- uses: actions/checkout@v4
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Setup Node.js
uses: actions/setup-node@v4
with:
node-version: 20
- name: Install Rust
uses: dtolnay/rust-toolchain@stable
- name: Install dependencies (Ubuntu)
if: matrix.platform == 'ubuntu-latest'
run: |
sudo apt-get update
sudo apt-get install -y libgtk-3-dev libwebkit2gtk-4.0-dev libappindicator3-dev librsvg2-dev patchelf
- name: Install Python dependencies
run: |
pip install uv
uv sync
- name: Build Python backend with Nuitka
run: |
pip install nuitka
python build_nuitka.py
- name: Install Node dependencies
working-directory: ./dashboard
run: npm install
- name: Build Tauri app
working-directory: ./dashboard
run: npm run tauri:build
- name: Upload artifacts (macOS)
if: matrix.platform == 'macos-latest'
uses: actions/upload-artifact@v4
with:
name: astrbot-macos
path: dashboard/src-tauri/target/release/bundle/dmg/*.dmg
- name: Upload artifacts (Windows)
if: matrix.platform == 'windows-latest'
uses: actions/upload-artifact@v4
with:
name: astrbot-windows
path: dashboard/src-tauri/target/release/bundle/msi/*.msi
- name: Upload artifacts (Linux)
if: matrix.platform == 'ubuntu-latest'
uses: actions/upload-artifact@v4
with:
name: astrbot-linux
path: |
dashboard/src-tauri/target/release/bundle/deb/*.deb
dashboard/src-tauri/target/release/bundle/appimage/*.AppImage
+1 -1
View File
@@ -36,7 +36,7 @@ jobs:
zip -r dist.zip dist zip -r dist.zip dist
- name: Archive production artifacts - name: Archive production artifacts
uses: actions/upload-artifact@v6 uses: actions/upload-artifact@v5
with: with:
name: dist-without-markdown name: dist-without-markdown
path: | path: |
-2
View File
@@ -32,7 +32,6 @@ tests/astrbot_plugin_openai
# Dashboard # Dashboard
dashboard/node_modules/ dashboard/node_modules/
dashboard/dist/ dashboard/dist/
dashboard/src-tauri/target
package-lock.json package-lock.json
package.json package.json
yarn.lock yarn.lock
@@ -49,6 +48,5 @@ astrbot.lock
chroma chroma
venv/* venv/*
pytest.ini pytest.ini
build/
AGENTS.md AGENTS.md
IFLOW.md IFLOW.md
-287
View File
@@ -1,287 +0,0 @@
# AstrBot 桌面应用构建指南
本指南介绍如何使用 Nuitka 将 Python 后端打包并集成到 Tauri 桌面应用中。
## 前置要求
### 系统要求
- Python 3.10+
- Node.js 20+
- Rust (通过 rustup 安装)
- UV 包管理器
### macOS 额外要求
- Xcode Command Line Tools: `xcode-select --install`
### Linux 额外要求
```bash
sudo apt-get install -y libgtk-3-dev libwebkit2gtk-4.0-dev \
libappindicator3-dev librsvg2-dev patchelf
```
### Windows 额外要求
- Visual Studio 2019+ with C++ build tools
- Windows 10 SDK
## 构建步骤
### 1. 安装 Python 依赖
```bash
pip install uv
uv sync
```
### 2. 安装 Nuitka
```bash
pip install nuitka
```
### 3. 构建 Python 后端
```bash
python build_nuitka.py
```
这会使用 Nuitka 将 `main.py` 编译为独立可执行文件,输出到 `build/nuitka/` 目录。
**注意**: Nuitka 编译过程可能需要 10-30 分钟,取决于您的系统性能。
### 4. 安装前端依赖
```bash
cd dashboard
npm install
```
### 5. 构建 Tauri 应用
```bash
npm run tauri:build
```
构建脚本会自动:
1. 运行 `build_nuitka.py` 编译 Python 后端
2. 将编译好的可执行文件复制到 `src-tauri/resources/` 目录
3. 构建 Tauri 应用并打包所有资源
### 6. 查找构建产物
构建完成后,您可以在以下位置找到安装包:
- **macOS**: `dashboard/src-tauri/target/release/bundle/dmg/AstrBot_*.dmg`
- **Windows**: `dashboard/src-tauri/target/release/bundle/msi/AstrBot_*.msi`
- **Linux**:
- `dashboard/src-tauri/target/release/bundle/deb/astrbot_*.deb`
- `dashboard/src-tauri/target/release/bundle/appimage/astrbot_*.AppImage`
## 开发模式
在开发时,您可能不想每次都完整编译 Python 后端。
### 仅开发 Tauri + Vue
```bash
cd dashboard
npm run tauri:dev
```
这会启动开发服务器,但不会自动启动 Python 后端。您需要手动运行:
```bash
uv run main.py
```
### 测试完整集成
如果您想测试 Tauri 自动启动 Python 后端的功能:
1. 先编译一次 Python 后端:
```bash
python build_nuitka.py
```
2. 手动复制到资源目录:
```bash
# macOS
cp -r build/nuitka/main.app dashboard/src-tauri/resources/astrbot-backend.app
# Windows
copy build\nuitka\main.exe dashboard\src-tauri\resources\astrbot-backend.exe
# Linux
cp build/nuitka/main.bin dashboard/src-tauri/resources/astrbot-backend
```
3. 运行开发模式:
```bash
cd dashboard
npm run tauri:dev
```
## Nuitka 构建选项说明
`build_nuitka.py` 脚本使用以下关键选项:
- `--standalone`: 创建包含所有依赖的独立目录
- `--onefile`: 将所有内容打包到单个可执行文件
- `--follow-imports`: 自动跟踪所有 Python 导入
- `--include-package`: 明确包含特定包
- `--include-data-dir`: 包含数据目录(插件、配置等)
### 自定义构建
如果您需要修改构建选项,编辑 `build_nuitka.py`:
```python
# 添加更多要包含的包
include_packages = [
"astrbot",
"your_custom_package",
# ...
]
# 添加更多数据目录
data_includes = [
"data/config",
"your_custom_data",
# ...
]
```
## 常见问题
### 1. Nuitka 编译失败
**问题**: 编译时出现 "module not found" 错误
**解决方案**: 在 `build_nuitka.py` 中添加缺失的包到 `include_packages` 列表
### 2. 运行时找不到资源文件
**问题**: 应用启动后提示找不到配置文件或插件
**解决方案**: 确保在 `build_nuitka.py` 中使用 `--include-data-dir` 包含了所有必要的数据目录
### 3. macOS 安全警告
**问题**: macOS 提示"应用来自未知开发者"
**解决方案**:
```bash
# 临时解除限制
sudo spctl --master-disable
# 或者为特定应用授权
xattr -cr /Applications/AstrBot.app
```
对于生产发布,您需要:
1. 注册 Apple Developer 账号
2. 对应用进行代码签名
3. 提交公证 (Notarization)
### 4. Windows Defender 报毒
**问题**: Windows Defender 或其他杀毒软件报毒
**解决方案**:
- 这是 Nuitka 打包程序的常见问题
- 可以使用 `--windows-company-name``--windows-product-name` 添加元数据
- 对于生产发布,需要购买代码签名证书
### 5. Linux 依赖问题
**问题**: 在某些 Linux 发行版上缺少共享库
**解决方案**: 使用 AppImage 格式,它包含所有依赖:
```bash
# 构建时会自动生成 AppImage
npm run tauri:build
```
## 优化构建大小
默认的 `--onefile` 模式会生成较大的可执行文件。如果需要减小体积:
1. 移除不需要的包
2. 使用 `--standalone` 而不是 `--onefile`
3. 排除不必要的数据文件
修改 `build_nuitka.py`:
```python
# 移除 --onefile,使用 --standalone
nuitka_cmd = [
sys.executable,
"-m", "nuitka",
"--standalone", # 只使用 standalone
# "--onefile", # 注释掉 onefile
# ...
]
```
## CI/CD 集成
项目已配置 GitHub Actions 工作流 (`.github/workflows/build-app.yml`),可以自动为所有平台构建应用。
推送标签时自动触发:
```bash
git tag v4.5.7
git push origin v4.5.7
```
或手动触发:
在 GitHub Actions 页面选择 "Build Desktop App" 工作流并点击 "Run workflow"
## 发布清单
在发布新版本前:
- [ ] 更新版本号
- `pyproject.toml` - Python 项目版本
- `dashboard/package.json` - Node 项目版本
- `dashboard/src-tauri/Cargo.toml` - Rust 项目版本
- `dashboard/src-tauri/tauri.conf.json` - Tauri 配置版本
- [ ] 运行代码检查
```bash
uv run ruff check .
uv run ruff format .
```
- [ ] 本地测试构建
```bash
python build_nuitka.py
cd dashboard && npm run tauri:build
```
- [ ] 测试安装包
- 安装生成的安装包
- 验证应用启动
- 验证 Python 后端自动启动
- 测试核心功能
- [ ] 创建发布标签
```bash
git tag -a v4.5.7 -m "Release v4.5.7"
git push origin v4.5.7
```
## 技术架构
```
┌─────────────────────────────────────┐
│ Tauri Desktop App │
│ (Rust + WebView) │
│ │
│ ┌─────────────────────────────┐ │
│ │ Vue.js Dashboard │ │
│ │ (Frontend UI) │ │
│ └─────────────────────────────┘ │
│ │
│ ┌─────────────────────────────┐ │
│ │ Python Backend │ │
│ │ (Nuitka Compiled) │ │
│ │ - AstrBot Core │ │
│ │ - Plugins │ │
│ │ - API Server │ │
│ └─────────────────────────────┘ │
│ │
│ HTTP/WebSocket │
│ localhost:6185 │
└─────────────────────────────────────┘
```
## 参考资源
- [Nuitka 文档](https://nuitka.net/doc/user-manual.html)
- [Tauri 文档](https://tauri.app/v1/guides/)
- [AstrBot 文档](https://astrbot.fun)
-25
View File
@@ -33,20 +33,6 @@
- 请使用英文描述您的 PR。 - 请使用英文描述您的 PR。
- 标题请使用 `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` 等语义化前缀,并简要描述更改内容。如:`fix: correct login page typo` - 标题请使用 `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` 等语义化前缀,并简要描述更改内容。如:`fix: correct login page typo`
#### 代码规范
##### Core
我们使用 Ruff 作为代码格式化和静态分析工具。在提交代码之前,请运行以下命令以确保代码符合规范:
```bash
ruff format .
ruff check .
```
如果您使用 VSCode,可以安装 `Ruff` 插件。
## Contributing Guide ## Contributing Guide
First off, thanks for taking the time to contribute! ❤️ First off, thanks for taking the time to contribute! ❤️
@@ -77,14 +63,3 @@ We use the `fix/` prefix for bug fixes and the `feat/` prefix for new features.
#### PR Description #### PR Description
- Please use English to describe your PR. - Please use English to describe your PR.
- Use semantic prefixes like `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` in the title, followed by a brief description of the changes, e.g., `fix: correct login page typo`. - Use semantic prefixes like `fix: `, `feat: `, `docs: `, `style: `, `refactor: `, `test: `, `chore: ` in the title, followed by a brief description of the changes, e.g., `fix: correct login page typo`.
#### Code Style
##### Core
We use Ruff as our code formatter and static analysis tool. Before submitting your code, please run the following commands to ensure your code adheres to the style guidelines:
```bash
ruff format .
ruff check .
```
-6
View File
@@ -243,10 +243,4 @@ pre-commit install
</details> </details>
<div align="center">
_私は、高性能ですから!_ _私は、高性能ですから!_
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
</div
+1 -1
View File
@@ -1 +1 @@
__version__ = "4.9.2" __version__ = "4.8.0"
+4 -6
View File
@@ -3,7 +3,7 @@
from typing import Any, ClassVar, Literal, cast from typing import Any, ClassVar, Literal, cast
from pydantic import BaseModel, GetCoreSchemaHandler, model_serializer, model_validator from pydantic import BaseModel, GetCoreSchemaHandler, model_validator
from pydantic_core import core_schema from pydantic_core import core_schema
@@ -122,12 +122,10 @@ class ToolCall(BaseModel):
extra_content: dict[str, Any] | None = None extra_content: dict[str, Any] | None = None
"""Extra metadata for the tool call.""" """Extra metadata for the tool call."""
@model_serializer(mode="wrap") def model_dump(self, **kwargs: Any) -> dict[str, Any]:
def serialize(self, handler):
data = handler(self)
if self.extra_content is None: if self.extra_content is None:
data.pop("extra_content", None) kwargs.setdefault("exclude", set()).add("extra_content")
return data return super().model_dump(**kwargs)
class ToolCallPart(BaseModel): class ToolCallPart(BaseModel):
+1 -22
View File
@@ -1,8 +1,7 @@
import typing as T import typing as T
from dataclasses import dataclass, field from dataclasses import dataclass
from astrbot.core.message.message_event_result import MessageChain from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import TokenUsage
class AgentResponseData(T.TypedDict): class AgentResponseData(T.TypedDict):
@@ -13,23 +12,3 @@ class AgentResponseData(T.TypedDict):
class AgentResponse: class AgentResponse:
type: str type: str
data: AgentResponseData data: AgentResponseData
@dataclass
class AgentStats:
token_usage: TokenUsage = field(default_factory=TokenUsage)
start_time: float = 0.0
end_time: float = 0.0
time_to_first_token: float = 0.0
@property
def duration(self) -> float:
return self.end_time - self.start_time
def to_dict(self) -> dict:
return {
"token_usage": self.token_usage.__dict__,
"start_time": self.start_time,
"end_time": self.end_time,
"time_to_first_token": self.time_to_first_token,
}
+1 -1
View File
@@ -9,7 +9,7 @@ from .message import Message
TContext = TypeVar("TContext", default=Any) TContext = TypeVar("TContext", default=Any)
@dataclass @dataclass(config={"arbitrary_types_allowed": True})
class ContextWrapper(Generic[TContext]): class ContextWrapper(Generic[TContext]):
"""A context for running an agent, which can be used to pass additional data or state.""" """A context for running an agent, which can be used to pass additional data or state."""
@@ -1,5 +1,4 @@
import sys import sys
import time
import traceback import traceback
import typing as T import typing as T
@@ -13,7 +12,6 @@ from mcp.types import (
) )
from astrbot import logger from astrbot import logger
from astrbot.core.message.components import Json
from astrbot.core.message.message_event_result import ( from astrbot.core.message.message_event_result import (
MessageChain, MessageChain,
) )
@@ -26,7 +24,7 @@ from astrbot.core.provider.provider import Provider
from ..hooks import BaseAgentRunHooks from ..hooks import BaseAgentRunHooks
from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment from ..message import AssistantMessageSegment, Message, ToolCallMessageSegment
from ..response import AgentResponseData, AgentStats from ..response import AgentResponseData
from ..run_context import ContextWrapper, TContext from ..run_context import ContextWrapper, TContext
from ..tool_executor import BaseFunctionToolExecutor from ..tool_executor import BaseFunctionToolExecutor
from .base import AgentResponse, AgentState, BaseAgentRunner from .base import AgentResponse, AgentState, BaseAgentRunner
@@ -71,9 +69,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
) )
self.run_context.messages = messages self.run_context.messages = messages
self.stats = AgentStats()
self.stats.start_time = time.time()
async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]: async def _iter_llm_responses(self) -> T.AsyncGenerator[LLMResponse, None]:
"""Yields chunks *and* a final LLMResponse.""" """Yields chunks *and* a final LLMResponse."""
if self.streaming: if self.streaming:
@@ -103,10 +98,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
async for llm_response in self._iter_llm_responses(): async for llm_response in self._iter_llm_responses():
if llm_response.is_chunk: if llm_response.is_chunk:
# update ttft
if self.stats.time_to_first_token == 0:
self.stats.time_to_first_token = time.time() - self.stats.start_time
if llm_response.result_chain: if llm_response.result_chain:
yield AgentResponse( yield AgentResponse(
type="streaming_delta", type="streaming_delta",
@@ -130,10 +121,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
) )
continue continue
llm_resp_result = llm_response llm_resp_result = llm_response
if not llm_response.is_chunk and llm_response.usage:
# only count the token usage of the final response for computation purpose
self.stats.token_usage += llm_response.usage
break # got final response break # got final response
if not llm_resp_result: if not llm_resp_result:
@@ -145,7 +132,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
if llm_resp.role == "err": if llm_resp.role == "err":
# 如果 LLM 响应错误,转换到错误状态 # 如果 LLM 响应错误,转换到错误状态
self.final_llm_resp = llm_resp self.final_llm_resp = llm_resp
self.stats.end_time = time.time()
self._transition_state(AgentState.ERROR) self._transition_state(AgentState.ERROR)
yield AgentResponse( yield AgentResponse(
type="err", type="err",
@@ -160,7 +146,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# 如果没有工具调用,转换到完成状态 # 如果没有工具调用,转换到完成状态
self.final_llm_resp = llm_resp self.final_llm_resp = llm_resp
self._transition_state(AgentState.DONE) self._transition_state(AgentState.DONE)
self.stats.end_time = time.time()
# record the final assistant message # record the final assistant message
self.run_context.messages.append( self.run_context.messages.append(
Message( Message(
@@ -190,19 +175,22 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
# 如果有工具调用,还需处理工具调用 # 如果有工具调用,还需处理工具调用
if llm_resp.tools_call_name: if llm_resp.tools_call_name:
tool_call_result_blocks = [] tool_call_result_blocks = []
for tool_call_name in llm_resp.tools_call_name:
yield AgentResponse(
type="tool_call",
data=AgentResponseData(
chain=MessageChain(type="tool_call").message(
f"🔨 调用工具: {tool_call_name}"
),
),
)
async for result in self._handle_function_tools(self.req, llm_resp): async for result in self._handle_function_tools(self.req, llm_resp):
if isinstance(result, list): if isinstance(result, list):
tool_call_result_blocks = result tool_call_result_blocks = result
elif isinstance(result, MessageChain): elif isinstance(result, MessageChain):
if result.type is None: result.type = "tool_call_result"
# should not happen
continue
if result.type == "tool_direct_result":
ar_type = "tool_call_result"
else:
ar_type = result.type
yield AgentResponse( yield AgentResponse(
type=ar_type, type="tool_call_result",
data=AgentResponseData(chain=result), data=AgentResponseData(chain=result),
) )
# 将结果添加到上下文中 # 将结果添加到上下文中
@@ -245,19 +233,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
llm_response.tools_call_args, llm_response.tools_call_args,
llm_response.tools_call_ids, llm_response.tools_call_ids,
): ):
yield MessageChain(
type="tool_call",
chain=[
Json(
data={
"id": func_tool_id,
"name": func_tool_name,
"args": func_tool_args,
"ts": time.time(),
}
)
],
)
try: try:
if not req.func_tool: if not req.func_tool:
return return
@@ -331,6 +306,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
content=res.content[0].text, content=res.content[0].text,
), ),
) )
yield MessageChain().message(res.content[0].text)
elif isinstance(res.content[0], ImageContent): elif isinstance(res.content[0], ImageContent):
tool_call_result_blocks.append( tool_call_result_blocks.append(
ToolCallMessageSegment( ToolCallMessageSegment(
@@ -352,6 +328,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
content=resource.text, content=resource.text,
), ),
) )
yield MessageChain().message(resource.text)
elif ( elif (
isinstance(resource, BlobResourceContents) isinstance(resource, BlobResourceContents)
and resource.mimeType and resource.mimeType
@@ -375,22 +352,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
content="返回的数据类型不受支持", content="返回的数据类型不受支持",
), ),
) )
yield MessageChain().message("返回的数据类型不受支持。")
# yield the last tool call result
if tool_call_result_blocks:
last_tcr_content = str(tool_call_result_blocks[-1].content)
yield MessageChain(
type="tool_call_result",
chain=[
Json(
data={
"id": func_tool_id,
"ts": time.time(),
"result": last_tcr_content,
}
)
],
)
elif resp is None: elif resp is None:
# Tool 直接请求发送消息给用户 # Tool 直接请求发送消息给用户
@@ -400,7 +362,6 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户,此工具调用不会被记录到历史中。" f"{func_tool_name} 没有没有返回值或者将结果直接发送给用户,此工具调用不会被记录到历史中。"
) )
self._transition_state(AgentState.DONE) self._transition_state(AgentState.DONE)
self.stats.end_time = time.time()
else: else:
# 不应该出现其他类型 # 不应该出现其他类型
logger.warning( logger.warning(
+1 -3
View File
@@ -6,10 +6,8 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.star.context import Context from astrbot.core.star.context import Context
@dataclass @dataclass(config={"arbitrary_types_allowed": True})
class AstrAgentContext: class AstrAgentContext:
__pydantic_config__ = {"arbitrary_types_allowed": True}
context: Context context: Context
"""The star context instance""" """The star context instance"""
event: AstrMessageEvent event: AstrMessageEvent
+2 -23
View File
@@ -4,7 +4,6 @@ from collections.abc import AsyncGenerator
from astrbot.core import logger from astrbot.core import logger
from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner
from astrbot.core.astr_agent_context import AstrAgentContext from astrbot.core.astr_agent_context import AstrAgentContext
from astrbot.core.message.components import Json
from astrbot.core.message.message_event_result import ( from astrbot.core.message.message_event_result import (
MessageChain, MessageChain,
MessageEventResult, MessageEventResult,
@@ -34,27 +33,16 @@ async def run_agent(
msg_chain = resp.data["chain"] msg_chain = resp.data["chain"]
if msg_chain.type == "tool_direct_result": if msg_chain.type == "tool_direct_result":
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容 # tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
await astr_event.send(msg_chain) await astr_event.send(resp.data["chain"])
continue continue
if astr_event.get_platform_id() == "webchat":
await astr_event.send(msg_chain)
# 对于其他情况,暂时先不处理 # 对于其他情况,暂时先不处理
continue continue
elif resp.type == "tool_call": elif resp.type == "tool_call":
if agent_runner.streaming: if agent_runner.streaming:
# 用来标记流式响应需要分节 # 用来标记流式响应需要分节
yield MessageChain(chain=[], type="break") yield MessageChain(chain=[], type="break")
if show_tool_use:
if astr_event.get_platform_name() == "webchat":
await astr_event.send(resp.data["chain"]) await astr_event.send(resp.data["chain"])
elif show_tool_use:
json_comp = resp.data["chain"].chain[0]
if isinstance(json_comp, Json):
m = f"🔨 调用工具: {json_comp.data.get('name')}"
else:
m = "🔨 调用工具..."
chain = MessageChain(type="tool_call").message(m)
await astr_event.send(chain)
continue continue
if stream_to_general and resp.type == "streaming_delta": if stream_to_general and resp.type == "streaming_delta":
@@ -81,15 +69,6 @@ async def run_agent(
continue continue
yield resp.data["chain"] # MessageChain yield resp.data["chain"] # MessageChain
if agent_runner.done(): if agent_runner.done():
# send agent stats to webchat
if astr_event.get_platform_name() == "webchat":
await astr_event.send(
MessageChain(
type="agent_stats",
chain=[Json(data=agent_runner.stats.to_dict())],
)
)
break break
except Exception as e: except Exception as e:
+10 -32
View File
@@ -4,7 +4,7 @@ import os
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.9.2" VERSION = "4.8.0"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db") DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
WEBHOOK_SUPPORTED_PLATFORMS = [ WEBHOOK_SUPPORTED_PLATFORMS = [
@@ -108,7 +108,6 @@ DEFAULT_CONFIG = {
"provider_id": "", "provider_id": "",
"dual_output": False, "dual_output": False,
"use_file_service": False, "use_file_service": False,
"trigger_probability": 1.0,
}, },
"provider_ltm_settings": { "provider_ltm_settings": {
"group_icl_enable": False, "group_icl_enable": False,
@@ -209,7 +208,7 @@ CONFIG_METADATA_2 = {
"callback_server_host": "0.0.0.0", "callback_server_host": "0.0.0.0",
"port": 6196, "port": 6196,
}, },
"OneBot v11": { "QQ 个人号(OneBot v11)": {
"id": "default", "id": "default",
"type": "aiocqhttp", "type": "aiocqhttp",
"enable": False, "enable": False,
@@ -946,7 +945,7 @@ CONFIG_METADATA_2 = {
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/", "api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
"timeout": 120, "timeout": 120,
"model_config": { "model_config": {
"model": "gemini-3-flash-preview", "model": "gemini-1.5-flash",
"temperature": 0.4, "temperature": 0.4,
}, },
"custom_headers": {}, "custom_headers": {},
@@ -963,7 +962,7 @@ CONFIG_METADATA_2 = {
"api_base": "https://generativelanguage.googleapis.com/", "api_base": "https://generativelanguage.googleapis.com/",
"timeout": 120, "timeout": 120,
"model_config": { "model_config": {
"model": "gemini-3-flash-preview", "model": "gemini-2.0-flash-exp",
"temperature": 0.4, "temperature": 0.4,
}, },
"gm_resp_image_modal": False, "gm_resp_image_modal": False,
@@ -976,7 +975,9 @@ CONFIG_METADATA_2 = {
"sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE", "sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE",
"dangerous_content": "BLOCK_MEDIUM_AND_ABOVE", "dangerous_content": "BLOCK_MEDIUM_AND_ABOVE",
}, },
"gm_thinking_config": {"budget": 0, "level": "HIGH"}, "gm_thinking_config": {
"budget": 0,
},
"modalities": ["text", "image", "tool_use"], "modalities": ["text", "image", "tool_use"],
}, },
"DeepSeek": { "DeepSeek": {
@@ -1817,24 +1818,13 @@ CONFIG_METADATA_2 = {
}, },
}, },
"gm_thinking_config": { "gm_thinking_config": {
"description": "Thinking Config", "description": "Gemini思考设置",
"type": "object", "type": "object",
"items": { "items": {
"budget": { "budget": {
"description": "Thinking Budget", "description": "思考预算",
"type": "int", "type": "int",
"hint": "Guides the model on the specific number of thinking tokens to use for reasoning. See: https://ai.google.dev/gemini-api/docs/thinking#set-budget", "hint": "模型应该生成的思考Token的数量,设为0关闭思考。除gemini-2.5-flash外的模型会静默忽略此参数。",
},
"level": {
"description": "Thinking Level",
"type": "string",
"hint": "Recommended for Gemini 3 models and onwards, lets you control reasoning behavior.See: https://ai.google.dev/gemini-api/docs/thinking#thinking-levels",
"options": [
"MINIMAL",
"LOW",
"MEDIUM",
"HIGH",
],
}, },
}, },
}, },
@@ -2219,9 +2209,6 @@ CONFIG_METADATA_2 = {
"use_file_service": { "use_file_service": {
"type": "bool", "type": "bool",
}, },
"trigger_probability": {
"type": "float",
},
}, },
}, },
"provider_ltm_settings": { "provider_ltm_settings": {
@@ -2432,14 +2419,6 @@ CONFIG_METADATA_3 = {
"provider_tts_settings.enable": True, "provider_tts_settings.enable": True,
}, },
}, },
"provider_tts_settings.trigger_probability": {
"description": "TTS 触发概率",
"type": "float",
"slider": {"min": 0, "max": 1, "step": 0.05},
"condition": {
"provider_tts_settings.enable": True,
},
},
"provider_settings.image_caption_prompt": { "provider_settings.image_caption_prompt": {
"description": "图片转述提示词", "description": "图片转述提示词",
"type": "text", "type": "text",
@@ -3007,7 +2986,6 @@ CONFIG_METADATA_3 = {
"description": "回复概率", "description": "回复概率",
"type": "float", "type": "float",
"hint": "0.0-1.0 之间的数值", "hint": "0.0-1.0 之间的数值",
"slider": {"min": 0, "max": 1, "step": 0.05},
"condition": { "condition": {
"provider_ltm_settings.active_reply.enable": True, "provider_ltm_settings.active_reply.enable": True,
}, },
-1
View File
@@ -79,7 +79,6 @@ class ConfigMetadataI18n:
"_special", "_special",
"invisible", "invisible",
"options", "options",
"slider",
]: ]:
if attr in field_data: if attr in field_data:
field_result[attr] = field_data[attr] field_result[attr] = field_data[attr]
-72
View File
@@ -9,8 +9,6 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_asyn
from astrbot.core.db.po import ( from astrbot.core.db.po import (
Attachment, Attachment,
CommandConfig,
CommandConflict,
ConversationV2, ConversationV2,
Persona, Persona,
PlatformMessageHistory, PlatformMessageHistory,
@@ -316,76 +314,6 @@ class BaseDatabase(abc.ABC):
"""Clear all preferences for a specific scope ID.""" """Clear all preferences for a specific scope ID."""
... ...
@abc.abstractmethod
async def get_command_configs(self) -> list[CommandConfig]:
"""Get all stored command configurations."""
...
@abc.abstractmethod
async def get_command_config(self, handler_full_name: str) -> CommandConfig | None:
"""Fetch a single command configuration by handler."""
...
@abc.abstractmethod
async def upsert_command_config(
self,
handler_full_name: str,
plugin_name: str,
module_path: str,
original_command: str,
*,
resolved_command: str | None = None,
enabled: bool | None = None,
keep_original_alias: bool | None = None,
conflict_key: str | None = None,
resolution_strategy: str | None = None,
note: str | None = None,
extra_data: dict | None = None,
auto_managed: bool | None = None,
) -> CommandConfig:
"""Create or update a command configuration."""
...
@abc.abstractmethod
async def delete_command_config(self, handler_full_name: str) -> None:
"""Delete a single command configuration."""
...
@abc.abstractmethod
async def delete_command_configs(self, handler_full_names: list[str]) -> None:
"""Bulk delete command configurations."""
...
@abc.abstractmethod
async def list_command_conflicts(
self,
status: str | None = None,
) -> list[CommandConflict]:
"""List recorded command conflict entries."""
...
@abc.abstractmethod
async def upsert_command_conflict(
self,
conflict_key: str,
handler_full_name: str,
plugin_name: str,
*,
status: str | None = None,
resolution: str | None = None,
resolved_command: str | None = None,
note: str | None = None,
extra_data: dict | None = None,
auto_generated: bool | None = None,
) -> CommandConflict:
"""Create or update a conflict record."""
...
@abc.abstractmethod
async def delete_command_conflicts(self, ids: list[int]) -> None:
"""Delete conflict records."""
...
# @abc.abstractmethod # @abc.abstractmethod
# async def insert_llm_message( # async def insert_llm_message(
# self, # self,
-59
View File
@@ -234,65 +234,6 @@ class Attachment(SQLModel, table=True):
) )
class CommandConfig(SQLModel, table=True):
"""Per-command configuration overrides for dashboard management."""
__tablename__ = "command_configs" # type: ignore
handler_full_name: str = Field(
primary_key=True,
max_length=512,
)
plugin_name: str = Field(nullable=False, max_length=255)
module_path: str = Field(nullable=False, max_length=255)
original_command: str = Field(nullable=False, max_length=255)
resolved_command: str | None = Field(default=None, max_length=255)
enabled: bool = Field(default=True, nullable=False)
keep_original_alias: bool = Field(default=False, nullable=False)
conflict_key: str | None = Field(default=None, max_length=255)
resolution_strategy: str | None = Field(default=None, max_length=64)
note: str | None = Field(default=None, sa_type=Text)
extra_data: dict | None = Field(default=None, sa_type=JSON)
auto_managed: bool = Field(default=False, nullable=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
class CommandConflict(SQLModel, table=True):
"""Conflict tracking for duplicated command names."""
__tablename__ = "command_conflicts" # type: ignore
id: int | None = Field(
default=None, primary_key=True, sa_column_kwargs={"autoincrement": True}
)
conflict_key: str = Field(nullable=False, max_length=255)
handler_full_name: str = Field(nullable=False, max_length=512)
plugin_name: str = Field(nullable=False, max_length=255)
status: str = Field(default="pending", max_length=32)
resolution: str | None = Field(default=None, max_length=64)
resolved_command: str | None = Field(default=None, max_length=255)
note: str | None = Field(default=None, sa_type=Text)
extra_data: dict | None = Field(default=None, sa_type=JSON)
auto_generated: bool = Field(default=False, nullable=False)
created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
updated_at: datetime = Field(
default_factory=lambda: datetime.now(timezone.utc),
sa_column_kwargs={"onupdate": datetime.now(timezone.utc)},
)
__table_args__ = (
UniqueConstraint(
"conflict_key",
"handler_full_name",
name="uix_conflict_handler",
),
)
@dataclass @dataclass
class Conversation: class Conversation:
"""LLM 对话类 """LLM 对话类
-240
View File
@@ -1,7 +1,6 @@
import asyncio import asyncio
import threading import threading
import typing as T import typing as T
from collections.abc import Awaitable, Callable
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from sqlalchemy import CursorResult from sqlalchemy import CursorResult
@@ -11,8 +10,6 @@ from sqlmodel import col, delete, desc, func, or_, select, text, update
from astrbot.core.db import BaseDatabase from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import ( from astrbot.core.db.po import (
Attachment, Attachment,
CommandConfig,
CommandConflict,
ConversationV2, ConversationV2,
Persona, Persona,
PlatformMessageHistory, PlatformMessageHistory,
@@ -29,7 +26,6 @@ from astrbot.core.db.po import (
) )
NOT_GIVEN = T.TypeVar("NOT_GIVEN") NOT_GIVEN = T.TypeVar("NOT_GIVEN")
TxResult = T.TypeVar("TxResult")
class SQLiteDatabase(BaseDatabase): class SQLiteDatabase(BaseDatabase):
@@ -674,242 +670,6 @@ class SQLiteDatabase(BaseDatabase):
) )
await session.commit() await session.commit()
# ====
# Command Configuration & Conflict Tracking
# ====
async def _run_in_tx(
self,
fn: Callable[[AsyncSession], Awaitable[TxResult]],
) -> TxResult:
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
return await fn(session)
@staticmethod
def _apply_updates(model, **updates) -> None:
for field, value in updates.items():
if value is not None:
setattr(model, field, value)
@staticmethod
def _new_command_config(
handler_full_name: str,
plugin_name: str,
module_path: str,
original_command: str,
*,
resolved_command: str | None = None,
enabled: bool | None = None,
keep_original_alias: bool | None = None,
conflict_key: str | None = None,
resolution_strategy: str | None = None,
note: str | None = None,
extra_data: dict | None = None,
auto_managed: bool | None = None,
) -> CommandConfig:
return CommandConfig(
handler_full_name=handler_full_name,
plugin_name=plugin_name,
module_path=module_path,
original_command=original_command,
resolved_command=resolved_command,
enabled=True if enabled is None else enabled,
keep_original_alias=False
if keep_original_alias is None
else keep_original_alias,
conflict_key=conflict_key or original_command,
resolution_strategy=resolution_strategy,
note=note,
extra_data=extra_data,
auto_managed=bool(auto_managed),
)
@staticmethod
def _new_command_conflict(
conflict_key: str,
handler_full_name: str,
plugin_name: str,
*,
status: str | None = None,
resolution: str | None = None,
resolved_command: str | None = None,
note: str | None = None,
extra_data: dict | None = None,
auto_generated: bool | None = None,
) -> CommandConflict:
return CommandConflict(
conflict_key=conflict_key,
handler_full_name=handler_full_name,
plugin_name=plugin_name,
status=status or "pending",
resolution=resolution,
resolved_command=resolved_command,
note=note,
extra_data=extra_data,
auto_generated=bool(auto_generated),
)
async def get_command_configs(self) -> list[CommandConfig]:
async with self.get_db() as session:
session: AsyncSession
result = await session.execute(select(CommandConfig))
return list(result.scalars().all())
async def get_command_config(
self,
handler_full_name: str,
) -> CommandConfig | None:
async with self.get_db() as session:
session: AsyncSession
return await session.get(CommandConfig, handler_full_name)
async def upsert_command_config(
self,
handler_full_name: str,
plugin_name: str,
module_path: str,
original_command: str,
*,
resolved_command: str | None = None,
enabled: bool | None = None,
keep_original_alias: bool | None = None,
conflict_key: str | None = None,
resolution_strategy: str | None = None,
note: str | None = None,
extra_data: dict | None = None,
auto_managed: bool | None = None,
) -> CommandConfig:
async def _op(session: AsyncSession) -> CommandConfig:
config = await session.get(CommandConfig, handler_full_name)
if not config:
config = self._new_command_config(
handler_full_name,
plugin_name,
module_path,
original_command,
resolved_command=resolved_command,
enabled=enabled,
keep_original_alias=keep_original_alias,
conflict_key=conflict_key,
resolution_strategy=resolution_strategy,
note=note,
extra_data=extra_data,
auto_managed=auto_managed,
)
session.add(config)
else:
self._apply_updates(
config,
plugin_name=plugin_name,
module_path=module_path,
original_command=original_command,
resolved_command=resolved_command,
enabled=enabled,
keep_original_alias=keep_original_alias,
conflict_key=conflict_key,
resolution_strategy=resolution_strategy,
note=note,
extra_data=extra_data,
auto_managed=auto_managed,
)
await session.flush()
await session.refresh(config)
return config
return await self._run_in_tx(_op)
async def delete_command_config(self, handler_full_name: str) -> None:
await self.delete_command_configs([handler_full_name])
async def delete_command_configs(self, handler_full_names: list[str]) -> None:
if not handler_full_names:
return
async def _op(session: AsyncSession) -> None:
await session.execute(
delete(CommandConfig).where(
col(CommandConfig.handler_full_name).in_(handler_full_names),
),
)
await self._run_in_tx(_op)
async def list_command_conflicts(
self,
status: str | None = None,
) -> list[CommandConflict]:
async with self.get_db() as session:
session: AsyncSession
query = select(CommandConflict)
if status:
query = query.where(CommandConflict.status == status)
result = await session.execute(query)
return list(result.scalars().all())
async def upsert_command_conflict(
self,
conflict_key: str,
handler_full_name: str,
plugin_name: str,
*,
status: str | None = None,
resolution: str | None = None,
resolved_command: str | None = None,
note: str | None = None,
extra_data: dict | None = None,
auto_generated: bool | None = None,
) -> CommandConflict:
async def _op(session: AsyncSession) -> CommandConflict:
result = await session.execute(
select(CommandConflict).where(
CommandConflict.conflict_key == conflict_key,
CommandConflict.handler_full_name == handler_full_name,
),
)
record = result.scalar_one_or_none()
if not record:
record = self._new_command_conflict(
conflict_key,
handler_full_name,
plugin_name,
status=status,
resolution=resolution,
resolved_command=resolved_command,
note=note,
extra_data=extra_data,
auto_generated=auto_generated,
)
session.add(record)
else:
self._apply_updates(
record,
plugin_name=plugin_name,
status=status,
resolution=resolution,
resolved_command=resolved_command,
note=note,
extra_data=extra_data,
auto_generated=auto_generated,
)
await session.flush()
await session.refresh(record)
return record
return await self._run_in_tx(_op)
async def delete_command_conflicts(self, ids: list[int]) -> None:
if not ids:
return
async def _op(session: AsyncSession) -> None:
await session.execute(
delete(CommandConflict).where(col(CommandConflict.id).in_(ids)),
)
await self._run_in_tx(_op)
# ==== # ====
# Deprecated Methods # Deprecated Methods
# ==== # ====
+1 -2
View File
@@ -24,7 +24,6 @@ import asyncio
import logging import logging
import os import os
import sys import sys
import time
from asyncio import Queue from asyncio import Queue
from collections import deque from collections import deque
@@ -149,7 +148,7 @@ class LogQueueHandler(logging.Handler):
self.log_broker.publish( self.log_broker.publish(
{ {
"level": record.levelname, "level": record.levelname,
"time": time.time(), "time": record.asctime,
"data": log_entry, "data": log_entry,
}, },
) )
+5 -4
View File
@@ -629,11 +629,12 @@ class Nodes(BaseMessageComponent):
class Json(BaseMessageComponent): class Json(BaseMessageComponent):
type = ComponentType.Json type = ComponentType.Json
data: dict data: str | dict
resid: int | None = 0
def __init__(self, data: str | dict, **_): def __init__(self, data, **_):
if isinstance(data, str): if isinstance(data, dict):
data = json.loads(data) data = json.dumps(data)
super().__init__(data=data, **_) super().__init__(data=data, **_)
+1 -5
View File
@@ -119,7 +119,7 @@ class RespondStage(Stage):
if (result := event.get_result()) is None: if (result := event.get_result()) is None:
return False return False
if self.only_llm_result and not result.is_llm_result(): if self.only_llm_result and result.is_llm_result():
return False return False
if event.get_platform_name() in [ if event.get_platform_name() in [
@@ -158,11 +158,7 @@ class RespondStage(Stage):
result = event.get_result() result = event.get_result()
if result is None: if result is None:
return return
if event.get_extra("_streaming_finished", False):
# prevent some plugin make result content type to LLM_RESULT after streaming finished, lead to send again
return
if result.result_content_type == ResultContentType.STREAMING_FINISH: if result.result_content_type == ResultContentType.STREAMING_FINISH:
event.set_extra("_streaming_finished", True)
return return
logger.info( logger.info(
+1 -21
View File
@@ -1,4 +1,3 @@
import random
import re import re
import time import time
import traceback import traceback
@@ -43,18 +42,6 @@ class ResultDecorateStage(Stage):
"forward_threshold" "forward_threshold"
] ]
trigger_probability = ctx.astrbot_config["provider_tts_settings"].get(
"trigger_probability",
1,
)
try:
self.tts_trigger_probability = max(
0.0,
min(float(trigger_probability), 1.0),
)
except (TypeError, ValueError):
self.tts_trigger_probability = 1.0
# 分段回复 # 分段回复
self.words_count_threshold = int( self.words_count_threshold = int(
ctx.astrbot_config["platform_settings"]["segmented_reply"][ ctx.astrbot_config["platform_settings"]["segmented_reply"][
@@ -259,14 +246,7 @@ class ResultDecorateStage(Stage):
and result.is_llm_result() and result.is_llm_result()
and SessionServiceManager.should_process_tts_request(event) and SessionServiceManager.should_process_tts_request(event)
): ):
should_tts = self.tts_trigger_probability >= 1.0 or ( if not tts_provider:
self.tts_trigger_probability > 0.0
and random.random() <= self.tts_trigger_probability
)
if not should_tts:
logger.debug("跳过 TTS:触发概率未命中。")
elif not tts_provider:
logger.warning( logger.warning(
f"会话 {event.unified_msg_origin} 未配置文本转语音模型。", f"会话 {event.unified_msg_origin} 未配置文本转语音模型。",
) )
+4
View File
@@ -112,6 +112,10 @@ class PlatformManager:
from .sources.satori.satori_adapter import ( from .sources.satori.satori_adapter import (
SatoriPlatformAdapter, # noqa: F401 SatoriPlatformAdapter, # noqa: F401
) )
case "github_webhook":
from .sources.github_webhook.github_webhook_adapter import (
GitHubWebhookPlatformAdapter, # noqa: F401
)
except (ImportError, ModuleNotFoundError) as e: except (ImportError, ModuleNotFoundError) as e:
logger.error( logger.error(
f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。", f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
@@ -0,0 +1,315 @@
import asyncio
import hashlib
import hmac
from typing import Any, cast
from astrbot import logger
from astrbot.api.event import MessageChain
from astrbot.api.message_components import Plain
from astrbot.api.platform import (
AstrBotMessage,
MessageMember,
MessageType,
Platform,
PlatformMetadata,
)
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.platform.platform import PlatformStatus
from astrbot.core.utils.webhook_utils import log_webhook_info
from ...register import register_platform_adapter
from .github_webhook_event import GitHubWebhookMessageEvent
@register_platform_adapter(
"github_webhook",
"GitHub Webhook 适配器",
support_streaming_message=False,
)
class GitHubWebhookPlatformAdapter(Platform):
"""GitHub Webhook 平台适配器
支持的事件:
- issues (created)
- issue_comment (created)
- pull_request (opened)
"""
def __init__(
self,
platform_config: dict,
platform_settings: dict,
event_queue: asyncio.Queue,
) -> None:
super().__init__(platform_config, event_queue)
self.unified_webhook_mode = platform_config.get("unified_webhook_mode", True)
self.webhook_secret = platform_config.get("webhook_secret", "")
self.shutdown_event = asyncio.Event()
async def send_by_session(
self,
session: MessageSesion,
message_chain: MessageChain,
):
"""GitHub Webhook 是单向接收,不支持主动发送消息"""
logger.warning("GitHub Webhook 适配器不支持 send_by_session")
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
name="github_webhook",
description="GitHub Webhook 适配器",
id=cast(str, self.config.get("id")),
)
async def run(self):
"""运行适配器"""
self.status = PlatformStatus.RUNNING
# 如果启用统一 webhook 模式
webhook_uuid = self.config.get("webhook_uuid")
if self.unified_webhook_mode and webhook_uuid:
log_webhook_info(f"{self.meta().id}(GitHub Webhook)", webhook_uuid)
# 保持运行状态,等待 shutdown
await self.shutdown_event.wait()
else:
logger.warning("GitHub Webhook 适配器需要启用统一 webhook 模式")
await self.shutdown_event.wait()
async def webhook_callback(self, request: Any) -> Any:
"""统一 Webhook 回调入口
处理 GitHub webhook 事件
Args:
request: Quart 请求对象
Returns:
响应数据
"""
try:
# 获取事件类型
event_type = request.headers.get("X-GitHub-Event", "")
# 获取请求数据
payload = await request.json
# 验证 webhook 签名(如果配置了 secret
if self.webhook_secret:
if not await self._verify_signature(request, payload):
logger.warning("GitHub webhook 签名验证失败")
return {"error": "Invalid signature"}, 401
logger.debug(f"收到 GitHub Webhook 事件: {event_type}")
# 处理不同类型的事件
if event_type == "issues":
await self._handle_issue_event(payload)
elif event_type == "issue_comment":
await self._handle_issue_comment_event(payload)
elif event_type == "pull_request":
await self._handle_pull_request_event(payload)
elif event_type == "ping":
# GitHub webhook 验证事件
return {"message": "pong"}
else:
logger.debug(f"忽略不支持的 GitHub 事件类型: {event_type}")
return {"status": "ok"}
except Exception as e:
logger.error(f"处理 GitHub webhook 回调时发生错误: {e}", exc_info=True)
return {"error": str(e)}, 500
async def _verify_signature(self, request: Any, payload: dict) -> bool:
"""验证 GitHub webhook 签名
Args:
request: Quart 请求对象
payload: 请求负载数据
Returns:
签名是否有效
"""
signature_header = request.headers.get("X-Hub-Signature-256", "")
if not signature_header:
# 如果没有签名头,检查是否有旧版本的签名
signature_header = request.headers.get("X-Hub-Signature", "")
if not signature_header:
return False
# 获取原始请求体
body = await request.get_data()
# 计算 HMAC
if signature_header.startswith("sha256="):
expected_signature = hmac.new(
self.webhook_secret.encode("utf-8"),
body,
hashlib.sha256,
).hexdigest()
received_signature = signature_header.replace("sha256=", "")
elif signature_header.startswith("sha1="):
expected_signature = hmac.new(
self.webhook_secret.encode("utf-8"),
body,
hashlib.sha1,
).hexdigest()
received_signature = signature_header.replace("sha1=", "")
else:
return False
# 使用 hmac.compare_digest 防止时序攻击
return hmac.compare_digest(expected_signature, received_signature)
async def _handle_issue_event(self, payload: dict):
"""处理 issue 事件"""
action = payload.get("action", "")
# 只处理创建事件
if action != "created" and action != "opened":
return
issue = payload.get("issue", {})
repo = payload.get("repository", {})
sender = payload.get("sender", {})
# 构造消息文本
message_text = (
f"📝 新 Issue 创建\n"
f"仓库: {repo.get('full_name', 'unknown')}\n"
f"标题: {issue.get('title', 'No title')}\n"
f"作者: {sender.get('login', 'unknown')}\n"
f"链接: {issue.get('html_url', '')}\n"
f"内容:\n{issue.get('body', 'No description')[:200]}"
)
# 创建 AstrBotMessage
abm = self._create_message(
message_text,
sender.get("login", "unknown"),
sender.get("login", "unknown"),
repo.get("full_name", "unknown"),
)
# 提交事件
self.commit_event(
GitHubWebhookMessageEvent(
message_text,
abm,
self.meta(),
repo.get("full_name", "unknown"),
"issues",
payload,
)
)
async def _handle_issue_comment_event(self, payload: dict):
"""处理 issue 评论事件"""
action = payload.get("action", "")
# 只处理创建事件
if action != "created":
return
issue = payload.get("issue", {})
comment = payload.get("comment", {})
repo = payload.get("repository", {})
sender = payload.get("sender", {})
# 构造消息文本
message_text = (
f"💬 新 Issue 评论\n"
f"仓库: {repo.get('full_name', 'unknown')}\n"
f"Issue: {issue.get('title', 'No title')}\n"
f"评论者: {sender.get('login', 'unknown')}\n"
f"链接: {comment.get('html_url', '')}\n"
f"内容:\n{comment.get('body', 'No comment')[:200]}"
)
# 创建 AstrBotMessage
abm = self._create_message(
message_text,
sender.get("login", "unknown"),
sender.get("login", "unknown"),
repo.get("full_name", "unknown"),
)
# 提交事件
self.commit_event(
GitHubWebhookMessageEvent(
message_text,
abm,
self.meta(),
repo.get("full_name", "unknown"),
"issue_comment",
payload,
)
)
async def _handle_pull_request_event(self, payload: dict):
"""处理 pull request 事件"""
action = payload.get("action", "")
# 只处理打开事件
if action != "opened":
return
pr = payload.get("pull_request", {})
repo = payload.get("repository", {})
sender = payload.get("sender", {})
# 构造消息文本
message_text = (
f"🔀 新 Pull Request\n"
f"仓库: {repo.get('full_name', 'unknown')}\n"
f"标题: {pr.get('title', 'No title')}\n"
f"作者: {sender.get('login', 'unknown')}\n"
f"链接: {pr.get('html_url', '')}\n"
f"内容:\n{pr.get('body', 'No description')[:200]}"
)
# 创建 AstrBotMessage
abm = self._create_message(
message_text,
sender.get("login", "unknown"),
sender.get("login", "unknown"),
repo.get("full_name", "unknown"),
)
# 提交事件
self.commit_event(
GitHubWebhookMessageEvent(
message_text,
abm,
self.meta(),
repo.get("full_name", "unknown"),
"pull_request",
payload,
)
)
def _create_message(
self,
message_text: str,
user_id: str,
nickname: str,
session_id: str,
) -> AstrBotMessage:
"""创建 AstrBotMessage 对象"""
abm = AstrBotMessage()
abm.type = MessageType.GROUP_MESSAGE
abm.self_id = self.client_self_id
abm.session_id = session_id
abm.message_id = ""
abm.sender = MessageMember(user_id=user_id, nickname=nickname)
abm.message = [Plain(message_text)]
abm.message_str = message_text
abm.raw_message = message_text
return abm
async def terminate(self):
"""终止适配器运行"""
self.shutdown_event.set()
logger.info("GitHub Webhook 适配器已经被优雅地关闭")
@@ -0,0 +1,22 @@
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from ...astr_message_event import AstrMessageEvent
class GitHubWebhookMessageEvent(AstrMessageEvent):
"""GitHub Webhook 消息事件"""
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
event_type: str,
event_data: dict,
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.event_type = event_type
"""GitHub 事件类型: issues, issue_comment, pull_request"""
self.event_data = event_data
"""原始事件数据"""
@@ -81,12 +81,7 @@ class LarkPlatformAdapter(Platform):
) )
self.lark_api = ( self.lark_api = (
lark.Client.builder() lark.Client.builder().app_id(self.appid).app_secret(self.appsecret).build()
.app_id(self.appid)
.app_secret(self.appsecret)
.log_level(lark.LogLevel.ERROR)
.domain(self.domain)
.build()
) )
self.webhook_server = None self.webhook_server = None
@@ -200,15 +200,6 @@ class TelegramPlatformEvent(AstrMessageEvent):
if isinstance(chain, MessageChain): if isinstance(chain, MessageChain):
if chain.type == "break": if chain.type == "break":
# 分割符 # 分割符
if message_id:
try:
await self.client.edit_message_text(
text=delta,
chat_id=payload["chat_id"],
message_id=message_id,
)
except Exception as e:
logger.warning(f"编辑消息失败(streaming-break): {e!s}")
message_id = None # 重置消息 ID message_id = None # 重置消息 ID
delta = "" # 重置 delta delta = "" # 重置 delta
continue continue
@@ -1,12 +1,11 @@
import base64 import base64
import json
import os import os
import shutil import shutil
import uuid import uuid
from astrbot.api import logger from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import File, Image, Json, Plain, Record from astrbot.api.message_components import File, Image, Plain, Record
from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from .webchat_queue_mgr import webchat_queue_mgr from .webchat_queue_mgr import webchat_queue_mgr
@@ -42,20 +41,12 @@ class WebChatMessageEvent(AstrMessageEvent):
await web_chat_back_queue.put( await web_chat_back_queue.put(
{ {
"type": "plain", "type": "plain",
"cid": cid,
"data": data, "data": data,
"streaming": streaming, "streaming": streaming,
"chain_type": message.type, "chain_type": message.type,
}, },
) )
elif isinstance(comp, Json):
await web_chat_back_queue.put(
{
"type": "plain",
"data": json.dumps(comp.data, ensure_ascii=False),
"streaming": streaming,
"chain_type": message.type,
},
)
elif isinstance(comp, Image): elif isinstance(comp, Image):
# save image to local # save image to local
filename = f"{str(uuid.uuid4())}.jpg" filename = f"{str(uuid.uuid4())}.jpg"
@@ -67,6 +58,7 @@ class WebChatMessageEvent(AstrMessageEvent):
await web_chat_back_queue.put( await web_chat_back_queue.put(
{ {
"type": "image", "type": "image",
"cid": cid,
"data": data, "data": data,
"streaming": streaming, "streaming": streaming,
}, },
@@ -82,6 +74,7 @@ class WebChatMessageEvent(AstrMessageEvent):
await web_chat_back_queue.put( await web_chat_back_queue.put(
{ {
"type": "record", "type": "record",
"cid": cid,
"data": data, "data": data,
"streaming": streaming, "streaming": streaming,
}, },
@@ -98,6 +91,7 @@ class WebChatMessageEvent(AstrMessageEvent):
await web_chat_back_queue.put( await web_chat_back_queue.put(
{ {
"type": "file", "type": "file",
"cid": cid,
"data": data, "data": data,
"streaming": streaming, "streaming": streaming,
}, },
@@ -117,17 +111,18 @@ class WebChatMessageEvent(AstrMessageEvent):
cid = self.session_id.split("!")[-1] cid = self.session_id.split("!")[-1]
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid) web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
async for chain in generator: async for chain in generator:
# if chain.type == "break" and final_data: if chain.type == "break" and final_data:
# # 分割符 # 分割符
# await web_chat_back_queue.put( await web_chat_back_queue.put(
# { {
# "type": "break", # break means a segment end "type": "break", # break means a segment end
# "data": final_data, "data": final_data,
# "streaming": True, "streaming": True,
# }, "cid": cid,
# ) },
# final_data = "" )
# continue final_data = ""
continue
r = await WebChatMessageEvent._send( r = await WebChatMessageEvent._send(
chain, chain,
@@ -147,6 +142,7 @@ class WebChatMessageEvent(AstrMessageEvent):
"data": final_data, "data": final_data,
"reasoning": reasoning_content, "reasoning": reasoning_content,
"streaming": True, "streaming": True,
"cid": cid,
}, },
) )
await super().send_streaming(generator, use_fallback) await super().send_streaming(generator, use_fallback)
-41
View File
@@ -1,5 +1,3 @@
from __future__ import annotations
import base64 import base64
import enum import enum
import json import json
@@ -201,38 +199,6 @@ class ProviderRequest:
return "" return ""
@dataclass
class TokenUsage:
input_other: int = 0
"""The number of input tokens, excluding cached tokens."""
input_cached: int = 0
"""The number of input cached tokens."""
output: int = 0
"""The number of output tokens."""
@property
def total(self) -> int:
return self.input_other + self.input_cached + self.output
@property
def input(self) -> int:
return self.input_other + self.input_cached
def __add__(self, other: TokenUsage) -> TokenUsage:
return TokenUsage(
input_other=self.input_other + other.input_other,
input_cached=self.input_cached + other.input_cached,
output=self.output + other.output,
)
def __sub__(self, other: TokenUsage) -> TokenUsage:
return TokenUsage(
input_other=self.input_other - other.input_other,
input_cached=self.input_cached - other.input_cached,
output=self.output - other.output,
)
@dataclass @dataclass
class LLMResponse: class LLMResponse:
role: str role: str
@@ -261,11 +227,6 @@ class LLMResponse:
is_chunk: bool = False is_chunk: bool = False
"""Indicates if the response is a chunked response.""" """Indicates if the response is a chunked response."""
id: str | None = None
"""The ID of the response. For chunked responses, it's the ID of the chunk; for non-chunked responses, it's the ID of the response."""
usage: TokenUsage | None = None
"""The usage of the response. For chunked responses, it's the usage of the chunk; for non-chunked responses, it's the usage of the response."""
def __init__( def __init__(
self, self,
role: str, role: str,
@@ -280,8 +241,6 @@ class LLMResponse:
| AnthropicMessage | AnthropicMessage
| None = None, | None = None,
is_chunk: bool = False, is_chunk: bool = False,
id: str | None = None,
usage: TokenUsage | None = None,
): ):
"""初始化 LLMResponse """初始化 LLMResponse
@@ -6,12 +6,10 @@ from mimetypes import guess_type
import anthropic import anthropic
from anthropic import AsyncAnthropic from anthropic import AsyncAnthropic
from anthropic.types import Message from anthropic.types import Message
from anthropic.types.message_delta_usage import MessageDeltaUsage
from anthropic.types.usage import Usage
from astrbot import logger from astrbot import logger
from astrbot.api.provider import Provider from astrbot.api.provider import Provider
from astrbot.core.provider.entities import LLMResponse, TokenUsage from astrbot.core.provider.entities import LLMResponse
from astrbot.core.provider.func_tool_manager import ToolSet from astrbot.core.provider.func_tool_manager import ToolSet
from astrbot.core.utils.io import download_image_by_url from astrbot.core.utils.io import download_image_by_url
@@ -109,22 +107,6 @@ class ProviderAnthropic(Provider):
return system_prompt, new_messages return system_prompt, new_messages
def _extract_usage(self, usage: Usage) -> TokenUsage:
# https://docs.claude.com/en/docs/build-with-claude/prompt-caching#tracking-cache-performance
return TokenUsage(
input_other=usage.input_tokens or 0,
input_cached=usage.cache_read_input_tokens or 0,
output=usage.output_tokens,
)
def _update_usage(self, token_usage: TokenUsage, usage: MessageDeltaUsage) -> None:
if usage.input_tokens is not None:
token_usage.input_other = usage.input_tokens
if usage.cache_read_input_tokens is not None:
token_usage.input_cached = usage.cache_read_input_tokens
if usage.output_tokens is not None:
token_usage.output = usage.output_tokens
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
if tools: if tools:
if tool_list := tools.get_func_desc_anthropic_style(): if tool_list := tools.get_func_desc_anthropic_style():
@@ -149,10 +131,6 @@ class ProviderAnthropic(Provider):
llm_response.tools_call_args.append(content_block.input) llm_response.tools_call_args.append(content_block.input)
llm_response.tools_call_name.append(content_block.name) llm_response.tools_call_name.append(content_block.name)
llm_response.tools_call_ids.append(content_block.id) llm_response.tools_call_ids.append(content_block.id)
llm_response.id = completion.id
llm_response.usage = self._extract_usage(completion.usage)
# TODO(Soulter): 处理 end_turn 情况 # TODO(Soulter): 处理 end_turn 情况
if not llm_response.completion_text and not llm_response.tools_call_args: if not llm_response.completion_text and not llm_response.tools_call_args:
raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}") raise Exception(f"Anthropic API 返回的 completion 无法解析:{completion}")
@@ -174,16 +152,9 @@ class ProviderAnthropic(Provider):
final_text = "" final_text = ""
final_tool_calls = [] final_tool_calls = []
id = None
usage = TokenUsage()
async with self.client.messages.stream(**payloads) as stream: async with self.client.messages.stream(**payloads) as stream:
assert isinstance(stream, anthropic.AsyncMessageStream) assert isinstance(stream, anthropic.AsyncMessageStream)
async for event in stream: async for event in stream:
if event.type == "message_start":
# the usage contains input token usage
id = event.message.id
usage = self._extract_usage(event.message.usage)
if event.type == "content_block_start": if event.type == "content_block_start":
if event.content_block.type == "text": if event.content_block.type == "text":
# 文本块开始 # 文本块开始
@@ -191,8 +162,6 @@ class ProviderAnthropic(Provider):
role="assistant", role="assistant",
completion_text="", completion_text="",
is_chunk=True, is_chunk=True,
usage=usage,
id=id,
) )
elif event.content_block.type == "tool_use": elif event.content_block.type == "tool_use":
# 工具使用块开始,初始化缓冲区 # 工具使用块开始,初始化缓冲区
@@ -210,8 +179,6 @@ class ProviderAnthropic(Provider):
role="assistant", role="assistant",
completion_text=event.delta.text, completion_text=event.delta.text,
is_chunk=True, is_chunk=True,
usage=usage,
id=id,
) )
elif event.delta.type == "input_json_delta": elif event.delta.type == "input_json_delta":
# 工具调用参数增量 # 工具调用参数增量
@@ -248,8 +215,6 @@ class ProviderAnthropic(Provider):
tools_call_name=[tool_info["name"]], tools_call_name=[tool_info["name"]],
tools_call_ids=[tool_info["id"]], tools_call_ids=[tool_info["id"]],
is_chunk=True, is_chunk=True,
usage=usage,
id=id,
) )
except json.JSONDecodeError: except json.JSONDecodeError:
# JSON 解析失败,跳过这个工具调用 # JSON 解析失败,跳过这个工具调用
@@ -258,17 +223,11 @@ class ProviderAnthropic(Provider):
# 清理缓冲区 # 清理缓冲区
del tool_use_buffer[event.index] del tool_use_buffer[event.index]
elif event.type == "message_delta":
if event.usage:
self._update_usage(usage, event.usage)
# 返回最终的完整结果 # 返回最终的完整结果
final_response = LLMResponse( final_response = LLMResponse(
role="assistant", role="assistant",
completion_text=final_text, completion_text=final_text,
is_chunk=False, is_chunk=False,
usage=usage,
id=id,
) )
if final_tool_calls: if final_tool_calls:
+26 -63
View File
@@ -14,7 +14,7 @@ import astrbot.core.message.components as Comp
from astrbot import logger from astrbot import logger
from astrbot.api.provider import Provider from astrbot.api.provider import Provider
from astrbot.core.message.message_event_result import MessageChain from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse, TokenUsage from astrbot.core.provider.entities import LLMResponse
from astrbot.core.provider.func_tool_manager import ToolSet from astrbot.core.provider.func_tool_manager import ToolSet
from astrbot.core.utils.io import download_image_by_url from astrbot.core.utils.io import download_image_by_url
@@ -138,7 +138,7 @@ class ProviderGoogleGenAI(Provider):
modalities = ["TEXT"] modalities = ["TEXT"]
tool_list: list[types.Tool] | None = [] tool_list: list[types.Tool] | None = []
model_name = payloads.get("model", self.get_model()) model_name = self.get_model()
native_coderunner = self.provider_config.get("gm_native_coderunner", False) native_coderunner = self.provider_config.get("gm_native_coderunner", False)
native_search = self.provider_config.get("gm_native_search", False) native_search = self.provider_config.get("gm_native_search", False)
url_context = self.provider_config.get("gm_url_context", False) url_context = self.provider_config.get("gm_url_context", False)
@@ -197,37 +197,6 @@ class ProviderGoogleGenAI(Provider):
types.Tool(function_declarations=func_desc["function_declarations"]), types.Tool(function_declarations=func_desc["function_declarations"]),
] ]
# oper thinking config
thinking_config = None
if model_name.startswith("gemini-2.5"):
# The thinkingBudget parameter, introduced with the Gemini 2.5 series
thinking_budget = self.provider_config.get("gm_thinking_config", {}).get(
"budget", 0
)
if thinking_budget is not None:
thinking_config = types.ThinkingConfig(
thinking_budget=thinking_budget,
)
elif model_name.startswith("gemini-3"):
# The thinkingLevel parameter, recommended for Gemini 3 models and onwards
# Gemini 2.5 series models don't support thinkingLevel; use thinkingBudget instead.
thinking_level = self.provider_config.get("gm_thinking_config", {}).get(
"level", "HIGH"
)
if thinking_level and isinstance(thinking_level, str):
thinking_level = thinking_level.upper()
if thinking_level not in ["MINIMAL", "LOW", "MEDIUM", "HIGH"]:
logger.warning(
f"Invalid thinking level: {thinking_level}, using HIGH"
)
thinking_level = "HIGH"
level = types.ThinkingLevel(thinking_level)
thinking_config = types.ThinkingConfig()
if not hasattr(types.ThinkingConfig, "thinking_level"):
setattr(types.ThinkingConfig, "thinking_level", level)
else:
thinking_config.thinking_level = level
return types.GenerateContentConfig( return types.GenerateContentConfig(
system_instruction=system_instruction, system_instruction=system_instruction,
temperature=temperature, temperature=temperature,
@@ -247,7 +216,22 @@ class ProviderGoogleGenAI(Provider):
response_modalities=modalities, response_modalities=modalities,
tools=cast(types.ToolListUnion | None, tool_list), tools=cast(types.ToolListUnion | None, tool_list),
safety_settings=self.safety_settings if self.safety_settings else None, safety_settings=self.safety_settings if self.safety_settings else None,
thinking_config=thinking_config, thinking_config=(
types.ThinkingConfig(
thinking_budget=min(
int(
self.provider_config.get("gm_thinking_config", {}).get(
"budget",
0,
),
),
24576,
),
)
if "gemini-2.5-flash" in self.get_model()
and hasattr(types.ThinkingConfig, "thinking_budget")
else None
),
automatic_function_calling=types.AutomaticFunctionCallingConfig( automatic_function_calling=types.AutomaticFunctionCallingConfig(
disable=True, disable=True,
), ),
@@ -363,16 +347,6 @@ class ProviderGoogleGenAI(Provider):
] ]
return "".join(thought_buf).strip() return "".join(thought_buf).strip()
def _extract_usage(
self, usage_metadata: types.GenerateContentResponseUsageMetadata
) -> TokenUsage:
"""Extract usage from candidate"""
return TokenUsage(
input_other=usage_metadata.prompt_token_count or 0,
input_cached=usage_metadata.cached_content_token_count or 0,
output=usage_metadata.candidates_token_count or 0,
)
def _process_content_parts( def _process_content_parts(
self, self,
candidate: types.Candidate, candidate: types.Candidate,
@@ -457,8 +431,6 @@ class ProviderGoogleGenAI(Provider):
None, None,
) )
model = payloads.get("model", self.get_model())
modalities = ["TEXT"] modalities = ["TEXT"]
if self.provider_config.get("gm_resp_image_modal", False): if self.provider_config.get("gm_resp_image_modal", False):
modalities.append("IMAGE") modalities.append("IMAGE")
@@ -477,7 +449,7 @@ class ProviderGoogleGenAI(Provider):
temperature, temperature,
) )
result = await self.client.models.generate_content( result = await self.client.models.generate_content(
model=model, model=self.get_model(),
contents=cast(types.ContentListUnion, conversation), contents=cast(types.ContentListUnion, conversation),
config=config, config=config,
) )
@@ -503,11 +475,11 @@ class ProviderGoogleGenAI(Provider):
e.message = "" e.message = ""
if "Developer instruction is not enabled" in e.message: if "Developer instruction is not enabled" in e.message:
logger.warning( logger.warning(
f"{model} 不支持 system prompt,已自动去除(影响人格设置)", f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)",
) )
system_instruction = None system_instruction = None
elif "Function calling is not enabled" in e.message: elif "Function calling is not enabled" in e.message:
logger.warning(f"{model} 不支持函数调用,已自动去除") logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
tools = None tools = None
elif ( elif (
"Multi-modal output is not supported" in e.message "Multi-modal output is not supported" in e.message
@@ -516,7 +488,7 @@ class ProviderGoogleGenAI(Provider):
or "only supports text output" in e.message or "only supports text output" in e.message
): ):
logger.warning( logger.warning(
f"{model} 不支持多模态输出,降级为文本模态", f"{self.get_model()} 不支持多模态输出,降级为文本模态",
) )
modalities = ["TEXT"] modalities = ["TEXT"]
else: else:
@@ -529,9 +501,6 @@ class ProviderGoogleGenAI(Provider):
result.candidates[0], result.candidates[0],
llm_response, llm_response,
) )
llm_response.id = result.response_id
if result.usage_metadata:
llm_response.usage = self._extract_usage(result.usage_metadata)
return llm_response return llm_response
async def _query_stream( async def _query_stream(
@@ -544,7 +513,7 @@ class ProviderGoogleGenAI(Provider):
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"), (msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
None, None,
) )
model = payloads.get("model", self.get_model())
conversation = self._prepare_conversation(payloads) conversation = self._prepare_conversation(payloads)
result = None result = None
@@ -556,7 +525,7 @@ class ProviderGoogleGenAI(Provider):
system_instruction, system_instruction,
) )
result = await self.client.models.generate_content_stream( result = await self.client.models.generate_content_stream(
model=model, model=self.get_model(),
contents=cast(types.ContentListUnion, conversation), contents=cast(types.ContentListUnion, conversation),
config=config, config=config,
) )
@@ -566,11 +535,11 @@ class ProviderGoogleGenAI(Provider):
e.message = "" e.message = ""
if "Developer instruction is not enabled" in e.message: if "Developer instruction is not enabled" in e.message:
logger.warning( logger.warning(
f"{model} 不支持 system prompt,已自动去除(影响人格设置)", f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)",
) )
system_instruction = None system_instruction = None
elif "Function calling is not enabled" in e.message: elif "Function calling is not enabled" in e.message:
logger.warning(f"{model} 不支持函数调用,已自动去除") logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
tools = None tools = None
else: else:
raise raise
@@ -600,9 +569,6 @@ class ProviderGoogleGenAI(Provider):
chunk.candidates[0], chunk.candidates[0],
llm_response, llm_response,
) )
llm_response.id = chunk.response_id
if chunk.usage_metadata:
llm_response.usage = self._extract_usage(chunk.usage_metadata)
yield llm_response yield llm_response
return return
@@ -630,9 +596,6 @@ class ProviderGoogleGenAI(Provider):
chunk.candidates[0], chunk.candidates[0],
final_response, final_response,
) )
final_response.id = chunk.response_id
if chunk.usage_metadata:
final_response.usage = self._extract_usage(chunk.usage_metadata)
break break
# Yield final complete response with accumulated text # Yield final complete response with accumulated text
+1 -18
View File
@@ -12,7 +12,6 @@ from openai._exceptions import NotFoundError
from openai.lib.streaming.chat._completions import ChatCompletionStreamState from openai.lib.streaming.chat._completions import ChatCompletionStreamState
from openai.types.chat.chat_completion import ChatCompletion from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.completion_usage import CompletionUsage
import astrbot.core.message.components as Comp import astrbot.core.message.components as Comp
from astrbot import logger from astrbot import logger
@@ -20,7 +19,7 @@ from astrbot.api.provider import Provider
from astrbot.core.agent.message import Message from astrbot.core.agent.message import Message
from astrbot.core.agent.tool import ToolSet from astrbot.core.agent.tool import ToolSet
from astrbot.core.message.message_event_result import MessageChain from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
from astrbot.core.utils.io import download_image_by_url from astrbot.core.utils.io import download_image_by_url
from ..register import register_provider_adapter from ..register import register_provider_adapter
@@ -209,7 +208,6 @@ class ProviderOpenAIOfficial(Provider):
# handle the content delta # handle the content delta
reasoning = self._extract_reasoning_content(chunk) reasoning = self._extract_reasoning_content(chunk)
_y = False _y = False
llm_response.id = chunk.id
if reasoning: if reasoning:
llm_response.reasoning_content = reasoning llm_response.reasoning_content = reasoning
_y = True _y = True
@@ -219,8 +217,6 @@ class ProviderOpenAIOfficial(Provider):
chain=[Comp.Plain(completion_text)], chain=[Comp.Plain(completion_text)],
) )
_y = True _y = True
if chunk.usage:
llm_response.usage = self._extract_usage(chunk.usage)
if _y: if _y:
yield llm_response yield llm_response
@@ -249,15 +245,6 @@ class ProviderOpenAIOfficial(Provider):
reasoning_text = str(reasoning_attr) reasoning_text = str(reasoning_attr)
return reasoning_text return reasoning_text
def _extract_usage(self, usage: CompletionUsage) -> TokenUsage:
ptd = usage.prompt_tokens_details
cached = ptd.cached_tokens if ptd and ptd.cached_tokens else 0
return TokenUsage(
input_other=usage.prompt_tokens - cached,
input_cached=ptd.cached_tokens if ptd and ptd.cached_tokens else 0,
output=usage.completion_tokens,
)
async def _parse_openai_completion( async def _parse_openai_completion(
self, completion: ChatCompletion, tools: ToolSet | None self, completion: ChatCompletion, tools: ToolSet | None
) -> LLMResponse: ) -> LLMResponse:
@@ -334,10 +321,6 @@ class ProviderOpenAIOfficial(Provider):
raise Exception(f"API 返回的 completion 无法解析:{completion}") raise Exception(f"API 返回的 completion 无法解析:{completion}")
llm_response.raw_completion = completion llm_response.raw_completion = completion
llm_response.id = completion.id
if completion.usage:
llm_response.usage = self._extract_usage(completion.usage)
return llm_response return llm_response
+1 -5
View File
@@ -2,19 +2,15 @@ from astrbot.core import html_renderer
from astrbot.core.provider import Provider from astrbot.core.provider import Provider
from astrbot.core.star.star_tools import StarTools from astrbot.core.star.star_tools import StarTools
from astrbot.core.utils.command_parser import CommandParserMixin from astrbot.core.utils.command_parser import CommandParserMixin
from astrbot.core.utils.plugin_kv_store import PluginKVStoreMixin
from .context import Context from .context import Context
from .star import StarMetadata, star_map, star_registry from .star import StarMetadata, star_map, star_registry
from .star_manager import PluginManager from .star_manager import PluginManager
class Star(CommandParserMixin, PluginKVStoreMixin): class Star(CommandParserMixin):
"""所有插件(Star)的父类,所有插件都应该继承于这个类""" """所有插件(Star)的父类,所有插件都应该继承于这个类"""
author: str
name: str
def __init__(self, context: Context, config: dict | None = None): def __init__(self, context: Context, config: dict | None = None):
StarTools.initialize(context) StarTools.initialize(context)
self.context = context self.context = context
-449
View File
@@ -1,449 +0,0 @@
from __future__ import annotations
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any
from astrbot.core import db_helper
from astrbot.core.db.po import CommandConfig
from astrbot.core.star.filter.command import CommandFilter
from astrbot.core.star.filter.command_group import CommandGroupFilter
from astrbot.core.star.filter.permission import PermissionType, PermissionTypeFilter
from astrbot.core.star.star import star_map
from astrbot.core.star.star_handler import StarHandlerMetadata, star_handlers_registry
@dataclass
class CommandDescriptor:
handler: StarHandlerMetadata = field(repr=False)
filter_ref: CommandFilter | CommandGroupFilter | None = field(
default=None,
repr=False,
)
handler_full_name: str = ""
handler_name: str = ""
plugin_name: str = ""
plugin_display_name: str | None = None
module_path: str = ""
description: str = ""
command_type: str = "command" # "command" | "group" | "sub_command"
raw_command_name: str | None = None
current_fragment: str | None = None
parent_signature: str = ""
parent_group_handler: str = ""
original_command: str | None = None
effective_command: str | None = None
aliases: list[str] = field(default_factory=list)
permission: str = "everyone"
enabled: bool = True
is_group: bool = False
is_sub_command: bool = False
reserved: bool = False
config: CommandConfig | None = None
has_conflict: bool = False
sub_commands: list[CommandDescriptor] = field(default_factory=list)
async def sync_command_configs() -> None:
"""同步指令配置,清理过期配置。"""
descriptors = _collect_descriptors(include_sub_commands=False)
config_records = await db_helper.get_command_configs()
config_map = _bind_configs_to_descriptors(descriptors, config_records)
live_handlers = {desc.handler_full_name for desc in descriptors}
stale_configs = [key for key in config_map if key not in live_handlers]
if stale_configs:
await db_helper.delete_command_configs(stale_configs)
async def toggle_command(handler_full_name: str, enabled: bool) -> CommandDescriptor:
descriptor = _build_descriptor_by_full_name(handler_full_name)
if not descriptor:
raise ValueError("指定的处理函数不存在或不是指令。")
existing_cfg = await db_helper.get_command_config(handler_full_name)
config = await db_helper.upsert_command_config(
handler_full_name=handler_full_name,
plugin_name=descriptor.plugin_name or "",
module_path=descriptor.module_path,
original_command=descriptor.original_command or descriptor.handler_name,
resolved_command=(
existing_cfg.resolved_command
if existing_cfg
else descriptor.current_fragment
),
enabled=enabled,
keep_original_alias=False,
conflict_key=existing_cfg.conflict_key
if existing_cfg and existing_cfg.conflict_key
else descriptor.original_command,
resolution_strategy=existing_cfg.resolution_strategy if existing_cfg else None,
note=existing_cfg.note if existing_cfg else None,
extra_data=existing_cfg.extra_data if existing_cfg else None,
auto_managed=False,
)
_bind_descriptor_with_config(descriptor, config)
await sync_command_configs()
return descriptor
async def rename_command(
handler_full_name: str,
new_fragment: str,
) -> CommandDescriptor:
descriptor = _build_descriptor_by_full_name(handler_full_name)
if not descriptor:
raise ValueError("指定的处理函数不存在或不是指令。")
new_fragment = new_fragment.strip()
if not new_fragment:
raise ValueError("指令名不能为空。")
candidate_full = _compose_command(descriptor.parent_signature, new_fragment)
if _is_command_in_use(handler_full_name, candidate_full):
raise ValueError("新的指令名已被其他指令占用,请换一个名称。")
config = await db_helper.upsert_command_config(
handler_full_name=handler_full_name,
plugin_name=descriptor.plugin_name or "",
module_path=descriptor.module_path,
original_command=descriptor.original_command or descriptor.handler_name,
resolved_command=new_fragment,
enabled=True if descriptor.enabled else False,
keep_original_alias=False,
conflict_key=descriptor.original_command,
resolution_strategy="manual_rename",
note=None,
extra_data=None,
auto_managed=False,
)
_bind_descriptor_with_config(descriptor, config)
await sync_command_configs()
return descriptor
async def list_commands() -> list[dict[str, Any]]:
descriptors = _collect_descriptors(include_sub_commands=True)
config_records = await db_helper.get_command_configs()
_bind_configs_to_descriptors(descriptors, config_records)
conflict_groups = _group_conflicts(descriptors)
conflict_handler_names: set[str] = {
d.handler_full_name for group in conflict_groups.values() for d in group
}
# 分类,设置冲突标志,将子指令挂载到父指令组
group_map: dict[str, CommandDescriptor] = {}
sub_commands: list[CommandDescriptor] = []
root_commands: list[CommandDescriptor] = []
for desc in descriptors:
desc.has_conflict = desc.handler_full_name in conflict_handler_names
if desc.is_group:
group_map[desc.handler_full_name] = desc
elif desc.is_sub_command:
sub_commands.append(desc)
else:
root_commands.append(desc)
for sub in sub_commands:
if sub.parent_group_handler and sub.parent_group_handler in group_map:
group_map[sub.parent_group_handler].sub_commands.append(sub)
else:
root_commands.append(sub)
# 指令组 + 普通指令,按 effective_command 字母排序
all_commands = list(group_map.values()) + root_commands
all_commands.sort(key=lambda d: (d.effective_command or "").lower())
result = [_descriptor_to_dict(desc) for desc in all_commands]
return result
async def list_command_conflicts() -> list[dict[str, Any]]:
"""列出所有冲突的指令组。"""
descriptors = _collect_descriptors(include_sub_commands=False)
config_records = await db_helper.get_command_configs()
_bind_configs_to_descriptors(descriptors, config_records)
conflict_groups = _group_conflicts(descriptors)
details = [
{
"conflict_key": key,
"handlers": [
{
"handler_full_name": item.handler_full_name,
"plugin": item.plugin_name,
"current_name": item.effective_command,
}
for item in group
],
}
for key, group in conflict_groups.items()
]
return details
# Internal helpers ----------------------------------------------------------
def _collect_descriptors(include_sub_commands: bool) -> list[CommandDescriptor]:
"""收集指令,按需包含子指令。"""
descriptors: list[CommandDescriptor] = []
for handler in star_handlers_registry:
desc = _build_descriptor(handler)
if not desc:
continue
if not include_sub_commands and desc.is_sub_command:
continue
descriptors.append(desc)
return descriptors
def _build_descriptor(handler: StarHandlerMetadata) -> CommandDescriptor | None:
filter_ref = _locate_primary_filter(handler)
if filter_ref is None:
return None
plugin_meta = star_map.get(handler.handler_module_path)
plugin_name = (
plugin_meta.name if plugin_meta else None
) or handler.handler_module_path
plugin_display = plugin_meta.display_name if plugin_meta else None
is_sub_command = bool(handler.extras_configs.get("sub_command"))
parent_group_handler = ""
if isinstance(filter_ref, CommandFilter):
raw_fragment = getattr(
filter_ref, "_original_command_name", filter_ref.command_name
)
current_fragment = filter_ref.command_name
parent_signature = (filter_ref.parent_command_names or [""])[0].strip()
# 如果是子指令,尝试找到父指令组的 handler_full_name
if is_sub_command and parent_signature:
parent_group_handler = _find_parent_group_handler(
handler.handler_module_path, parent_signature
)
else:
raw_fragment = getattr(
filter_ref, "_original_group_name", filter_ref.group_name
)
current_fragment = filter_ref.group_name
parent_signature = _resolve_group_parent_signature(filter_ref)
original_command = _compose_command(parent_signature, raw_fragment)
effective_command = _compose_command(parent_signature, current_fragment)
# 确定 command_type
if isinstance(filter_ref, CommandGroupFilter):
command_type = "group"
elif is_sub_command:
command_type = "sub_command"
else:
command_type = "command"
descriptor = CommandDescriptor(
handler=handler,
filter_ref=filter_ref,
handler_full_name=handler.handler_full_name,
handler_name=handler.handler_name,
plugin_name=plugin_name,
plugin_display_name=plugin_display,
module_path=handler.handler_module_path,
description=handler.desc or "",
command_type=command_type,
raw_command_name=raw_fragment,
current_fragment=current_fragment,
parent_signature=parent_signature,
parent_group_handler=parent_group_handler,
original_command=original_command,
effective_command=effective_command,
aliases=sorted(getattr(filter_ref, "alias", set())),
permission=_determine_permission(handler),
enabled=handler.enabled,
is_group=isinstance(filter_ref, CommandGroupFilter),
is_sub_command=is_sub_command,
reserved=plugin_meta.reserved if plugin_meta else False,
)
return descriptor
def _build_descriptor_by_full_name(full_name: str) -> CommandDescriptor | None:
handler = star_handlers_registry.get_handler_by_full_name(full_name)
if not handler:
return None
return _build_descriptor(handler)
def _locate_primary_filter(
handler: StarHandlerMetadata,
) -> CommandFilter | CommandGroupFilter | None:
for filter_ref in handler.event_filters:
if isinstance(filter_ref, (CommandFilter, CommandGroupFilter)):
return filter_ref
return None
def _determine_permission(handler: StarHandlerMetadata) -> str:
for filter_ref in handler.event_filters:
if isinstance(filter_ref, PermissionTypeFilter):
return (
"admin"
if filter_ref.permission_type == PermissionType.ADMIN
else "member"
)
return "everyone"
def _resolve_group_parent_signature(group_filter: CommandGroupFilter) -> str:
signatures: list[str] = []
parent = group_filter.parent_group
while parent:
signatures.append(getattr(parent, "_original_group_name", parent.group_name))
parent = parent.parent_group
return " ".join(reversed(signatures)).strip()
def _find_parent_group_handler(module_path: str, parent_signature: str) -> str:
"""根据模块路径和父级签名,找到对应的指令组 handler_full_name。"""
parent_sig_normalized = parent_signature.strip()
for handler in star_handlers_registry:
if handler.handler_module_path != module_path:
continue
filter_ref = _locate_primary_filter(handler)
if not isinstance(filter_ref, CommandGroupFilter):
continue
# 检查该指令组的完整指令名是否匹配 parent_signature
group_names = filter_ref.get_complete_command_names()
if parent_sig_normalized in group_names:
return handler.handler_full_name
return ""
def _compose_command(parent_signature: str, fragment: str | None) -> str:
fragment = (fragment or "").strip()
parent_signature = parent_signature.strip()
if not parent_signature:
return fragment
if not fragment:
return parent_signature
return f"{parent_signature} {fragment}"
def _bind_descriptor_with_config(
descriptor: CommandDescriptor,
config: CommandConfig,
) -> None:
_apply_config_to_descriptor(descriptor, config)
_apply_config_to_runtime(descriptor, config)
def _apply_config_to_descriptor(
descriptor: CommandDescriptor,
config: CommandConfig,
) -> None:
descriptor.config = config
descriptor.enabled = config.enabled
if config.original_command:
descriptor.original_command = config.original_command
new_fragment = config.resolved_command or descriptor.current_fragment
descriptor.current_fragment = new_fragment
descriptor.effective_command = _compose_command(
descriptor.parent_signature,
new_fragment,
)
def _apply_config_to_runtime(
descriptor: CommandDescriptor,
config: CommandConfig,
) -> None:
descriptor.handler.enabled = config.enabled
if descriptor.filter_ref and descriptor.current_fragment:
_set_filter_fragment(descriptor.filter_ref, descriptor.current_fragment)
def _bind_configs_to_descriptors(
descriptors: list[CommandDescriptor],
config_records: list[CommandConfig],
) -> dict[str, CommandConfig]:
config_map = {cfg.handler_full_name: cfg for cfg in config_records}
for desc in descriptors:
if cfg := config_map.get(desc.handler_full_name):
_bind_descriptor_with_config(desc, cfg)
return config_map
def _group_conflicts(
descriptors: list[CommandDescriptor],
) -> dict[str, list[CommandDescriptor]]:
conflicts: dict[str, list[CommandDescriptor]] = defaultdict(list)
for desc in descriptors:
if desc.effective_command and desc.enabled:
conflicts[desc.effective_command].append(desc)
return {k: v for k, v in conflicts.items() if len(v) > 1}
def _set_filter_fragment(
filter_ref: CommandFilter | CommandGroupFilter,
fragment: str,
) -> None:
attr = (
"group_name" if isinstance(filter_ref, CommandGroupFilter) else "command_name"
)
current_value = getattr(filter_ref, attr)
if fragment == current_value:
return
setattr(filter_ref, attr, fragment)
if hasattr(filter_ref, "_cmpl_cmd_names"):
filter_ref._cmpl_cmd_names = None
def _is_command_in_use(
target_handler_full_name: str,
candidate_full_command: str,
) -> bool:
candidate = candidate_full_command.strip()
for handler in star_handlers_registry:
if handler.handler_full_name == target_handler_full_name:
continue
filter_ref = _locate_primary_filter(handler)
if not filter_ref:
continue
names = {name.strip() for name in filter_ref.get_complete_command_names()}
if candidate in names:
return True
return False
def _descriptor_to_dict(desc: CommandDescriptor) -> dict[str, Any]:
result = {
"handler_full_name": desc.handler_full_name,
"handler_name": desc.handler_name,
"plugin": desc.plugin_name,
"plugin_display_name": desc.plugin_display_name,
"module_path": desc.module_path,
"description": desc.description,
"type": desc.command_type,
"parent_signature": desc.parent_signature,
"parent_group_handler": desc.parent_group_handler,
"original_command": desc.original_command,
"current_fragment": desc.current_fragment,
"effective_command": desc.effective_command,
"aliases": desc.aliases,
"permission": desc.permission,
"enabled": desc.enabled,
"is_group": desc.is_group,
"has_conflict": desc.has_conflict,
"reserved": desc.reserved,
}
# 如果是指令组,包含子指令列表
if desc.is_group and desc.sub_commands:
result["sub_commands"] = [_descriptor_to_dict(sub) for sub in desc.sub_commands]
else:
result["sub_commands"] = []
return result
-4
View File
@@ -296,10 +296,6 @@ class Context:
provider_type=ProviderType.CHAT_COMPLETION, provider_type=ProviderType.CHAT_COMPLETION,
umo=umo, umo=umo,
) )
if prov is None:
raise ProviderNotFoundError(
"provider not found, please choose provider first"
)
if not isinstance(prov, Provider): if not isinstance(prov, Provider):
raise ValueError("返回的 Provider 不是 Provider 类型") raise ValueError("返回的 Provider 不是 Provider 类型")
return prov return prov
-1
View File
@@ -40,7 +40,6 @@ class CommandFilter(HandlerFilter):
): ):
self.command_name = command_name self.command_name = command_name
self.alias = alias if alias else set() self.alias = alias if alias else set()
self._original_command_name = command_name
self.parent_command_names = ( self.parent_command_names = (
parent_command_names if parent_command_names is not None else [""] parent_command_names if parent_command_names is not None else [""]
) )
@@ -18,7 +18,6 @@ class CommandGroupFilter(HandlerFilter):
): ):
self.group_name = group_name self.group_name = group_name
self.alias = alias if alias else set() self.alias = alias if alias else set()
self._original_group_name = group_name
self.sub_command_filters: list[CommandFilter | CommandGroupFilter] = [] self.sub_command_filters: list[CommandFilter | CommandGroupFilter] = []
self.custom_filter_list: list[CustomFilter] = [] self.custom_filter_list: list[CustomFilter] = []
self.parent_group = parent_group self.parent_group = parent_group
-4
View File
@@ -118,8 +118,6 @@ class StarHandlerRegistry(Generic[T]):
# 过滤事件类型 # 过滤事件类型
if handler.event_type != event_type: if handler.event_type != event_type:
continue continue
if not handler.enabled:
continue
# 过滤启用状态 # 过滤启用状态
if only_activated: if only_activated:
plugin = star_map.get(handler.handler_module_path) plugin = star_map.get(handler.handler_module_path)
@@ -222,8 +220,6 @@ class StarHandlerMetadata(Generic[H]):
extras_configs: dict = field(default_factory=dict) extras_configs: dict = field(default_factory=dict)
"""插件注册的一些其他的信息, 如 priority 等""" """插件注册的一些其他的信息, 如 priority 等"""
enabled: bool = True
def __lt__(self, other: StarHandlerMetadata): def __lt__(self, other: StarHandlerMetadata):
"""定义小于运算符以支持优先队列""" """定义小于运算符以支持优先队列"""
return self.extras_configs.get("priority", 0) < other.extras_configs.get( return self.extras_configs.get("priority", 0) < other.extras_configs.get(
-14
View File
@@ -23,7 +23,6 @@ from astrbot.core.utils.astrbot_path import (
from astrbot.core.utils.io import remove_dir from astrbot.core.utils.io import remove_dir
from . import StarMetadata from . import StarMetadata
from .command_management import sync_command_configs
from .context import Context from .context import Context
from .filter.permission import PermissionType, PermissionTypeFilter from .filter.permission import PermissionType, PermissionTypeFilter
from .star import star_map, star_registry from .star import star_map, star_registry
@@ -468,18 +467,6 @@ class PluginManager:
metadata.star_cls = metadata.star_cls_type( metadata.star_cls = metadata.star_cls_type(
context=self.context, context=self.context,
) )
p_name = (metadata.name or "unknown").lower().replace("/", "_")
p_author = (
(metadata.author or "unknown").lower().replace("/", "_")
)
setattr(metadata.star_cls, "name", p_name)
setattr(metadata.star_cls, "author", p_author)
setattr(
metadata.star_cls,
"plugin_id",
f"{p_author}/{p_name}",
)
else: else:
logger.info(f"插件 {metadata.name} 已被禁用。") logger.info(f"插件 {metadata.name} 已被禁用。")
@@ -631,7 +618,6 @@ class PluginManager:
# 清除 pip.main 导致的多余的 logging handlers # 清除 pip.main 导致的多余的 logging handlers
for handler in logging.root.handlers[:]: for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler) logging.root.removeHandler(handler)
await sync_command_configs()
if not fail_rec: if not fail_rec:
return True, None return True, None
-28
View File
@@ -1,28 +0,0 @@
from typing import TypeVar
from astrbot.core import sp
SUPPORTED_VALUE_TYPES = int | float | str | bytes | bool | dict | list | None
_VT = TypeVar("_VT")
class PluginKVStoreMixin:
"""为插件提供键值存储功能的 Mixin 类"""
plugin_id: str
async def put_kv_data(
self,
key: str,
value: SUPPORTED_VALUE_TYPES,
) -> None:
"""为指定插件存储一个键值对"""
await sp.put_async("plugin", self.plugin_id, key, value)
async def get_kv_data(self, key: str, default: _VT) -> _VT | None:
"""获取指定插件存储的键值对"""
return await sp.get_async("plugin", self.plugin_id, key, default)
async def delete_kv_data(self, key: str) -> None:
"""删除指定插件存储的键值对"""
await sp.remove_async("plugin", self.plugin_id, key)
-2
View File
@@ -1,6 +1,5 @@
from .auth import AuthRoute from .auth import AuthRoute
from .chat import ChatRoute from .chat import ChatRoute
from .command import CommandRoute
from .config import ConfigRoute from .config import ConfigRoute
from .conversation import ConversationRoute from .conversation import ConversationRoute
from .file import FileRoute from .file import FileRoute
@@ -18,7 +17,6 @@ from .update import UpdateRoute
__all__ = [ __all__ = [
"AuthRoute", "AuthRoute",
"ChatRoute", "ChatRoute",
"CommandRoute",
"ConfigRoute", "ConfigRoute",
"ConversationRoute", "ConversationRoute",
"FileRoute", "FileRoute",
+13 -56
View File
@@ -227,19 +227,16 @@ class ChatRoute(Route):
text: str, text: str,
media_parts: list, media_parts: list,
reasoning: str, reasoning: str,
agent_stats: dict,
): ):
"""保存 bot 消息到历史记录,返回保存的记录""" """保存 bot 消息到历史记录,返回保存的记录"""
bot_message_parts = [] bot_message_parts = []
bot_message_parts.extend(media_parts)
if text: if text:
bot_message_parts.append({"type": "plain", "text": text}) bot_message_parts.append({"type": "plain", "text": text})
bot_message_parts.extend(media_parts)
new_his = {"type": "bot", "message": bot_message_parts} new_his = {"type": "bot", "message": bot_message_parts}
if reasoning: if reasoning:
new_his["reasoning"] = reasoning new_his["reasoning"] = reasoning
if agent_stats:
new_his["agent_stats"] = agent_stats
record = await self.platform_history_mgr.insert( record = await self.platform_history_mgr.insert(
platform_id="webchat", platform_id="webchat",
@@ -297,8 +294,7 @@ class ChatRoute(Route):
accumulated_parts = [] accumulated_parts = []
accumulated_text = "" accumulated_text = ""
accumulated_reasoning = "" accumulated_reasoning = ""
tool_calls = {}
agent_stats = {}
try: try:
async with track_conversation(self.running_convs, webchat_conv_id): async with track_conversation(self.running_convs, webchat_conv_id):
while True: while True:
@@ -318,16 +314,6 @@ class ChatRoute(Route):
result_text = result["data"] result_text = result["data"]
msg_type = result.get("type") msg_type = result.get("type")
streaming = result.get("streaming", False) streaming = result.get("streaming", False)
chain_type = result.get("chain_type")
if chain_type == "agent_stats":
stats_info = {
"type": "agent_stats",
"data": json.loads(result_text),
}
yield f"data: {json.dumps(stats_info, ensure_ascii=False)}\n\n"
agent_stats = stats_info["data"]
continue
# 发送 SSE 数据 # 发送 SSE 数据
try: try:
@@ -349,35 +335,11 @@ class ChatRoute(Route):
# 累积消息部分 # 累积消息部分
if msg_type == "plain": if msg_type == "plain":
chain_type = result.get("chain_type") chain_type = result.get("chain_type", "normal")
if chain_type == "tool_call": if chain_type == "reasoning":
tool_call = json.loads(result_text)
tool_calls[tool_call.get("id")] = tool_call
if accumulated_text:
# 如果累积了文本,则先保存文本
accumulated_parts.append(
{"type": "plain", "text": accumulated_text}
)
accumulated_text = ""
elif chain_type == "tool_call_result":
tcr = json.loads(result_text)
tc_id = tcr.get("id")
if tc_id in tool_calls:
tool_calls[tc_id]["result"] = tcr.get("result")
tool_calls[tc_id]["finished_ts"] = tcr.get("ts")
accumulated_parts.append(
{
"type": "tool_call",
"tool_calls": [tool_calls[tc_id]],
}
)
tool_calls.pop(tc_id, None)
elif chain_type == "reasoning":
accumulated_reasoning += result_text accumulated_reasoning += result_text
elif streaming:
accumulated_text += result_text
else: else:
accumulated_text = result_text accumulated_text += result_text
elif msg_type == "image": elif msg_type == "image":
filename = result_text.replace("[IMAGE]", "") filename = result_text.replace("[IMAGE]", "")
part = await self._create_attachment_from_file( part = await self._create_attachment_from_file(
@@ -405,20 +367,15 @@ class ChatRoute(Route):
if msg_type == "end": if msg_type == "end":
break break
elif ( elif (
(streaming and msg_type == "complete") or not streaming (streaming and msg_type == "complete")
# or msg_type == "break" or not streaming
or msg_type == "break"
): ):
if (
chain_type == "tool_call"
or chain_type == "tool_call_result"
):
continue
saved_record = await self._save_bot_message( saved_record = await self._save_bot_message(
webchat_conv_id, webchat_conv_id,
accumulated_text, accumulated_text,
accumulated_parts, accumulated_parts,
accumulated_reasoning, accumulated_reasoning,
agent_stats,
) )
# 发送保存的消息信息给前端 # 发送保存的消息信息给前端
if saved_record and not client_disconnected: if saved_record and not client_disconnected:
@@ -433,11 +390,11 @@ class ChatRoute(Route):
yield f"data: {json.dumps(saved_info, ensure_ascii=False)}\n\n" yield f"data: {json.dumps(saved_info, ensure_ascii=False)}\n\n"
except Exception: except Exception:
pass pass
accumulated_parts = [] # 重置累积变量 (对于 break 后的下一段消息)
accumulated_text = "" if msg_type == "break":
accumulated_reasoning = "" accumulated_parts = []
tool_calls = {} accumulated_text = ""
agent_stats = {} accumulated_reasoning = ""
except BaseException as e: except BaseException as e:
logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True) logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
-82
View File
@@ -1,82 +0,0 @@
from quart import request
from astrbot.core.star.command_management import (
list_command_conflicts,
list_commands,
)
from astrbot.core.star.command_management import (
rename_command as rename_command_service,
)
from astrbot.core.star.command_management import (
toggle_command as toggle_command_service,
)
from .route import Response, Route, RouteContext
class CommandRoute(Route):
def __init__(self, context: RouteContext) -> None:
super().__init__(context)
self.routes = {
"/commands": ("GET", self.get_commands),
"/commands/conflicts": ("GET", self.get_conflicts),
"/commands/toggle": ("POST", self.toggle_command),
"/commands/rename": ("POST", self.rename_command),
}
self.register_routes()
async def get_commands(self):
commands = await list_commands()
summary = {
"total": len(commands),
"disabled": len([cmd for cmd in commands if not cmd["enabled"]]),
"conflicts": len([cmd for cmd in commands if cmd.get("has_conflict")]),
}
return Response().ok({"items": commands, "summary": summary}).__dict__
async def get_conflicts(self):
conflicts = await list_command_conflicts()
return Response().ok(conflicts).__dict__
async def toggle_command(self):
data = await request.get_json()
handler_full_name = data.get("handler_full_name")
enabled = data.get("enabled")
if handler_full_name is None or enabled is None:
return Response().error("handler_full_name 与 enabled 均为必填。").__dict__
if isinstance(enabled, str):
enabled = enabled.lower() in ("1", "true", "yes", "on")
try:
await toggle_command_service(handler_full_name, bool(enabled))
except ValueError as exc:
return Response().error(str(exc)).__dict__
payload = await _get_command_payload(handler_full_name)
return Response().ok(payload).__dict__
async def rename_command(self):
data = await request.get_json()
handler_full_name = data.get("handler_full_name")
new_name = data.get("new_name")
if not handler_full_name or not new_name:
return Response().error("handler_full_name 与 new_name 均为必填。").__dict__
try:
await rename_command_service(handler_full_name, new_name)
except ValueError as exc:
return Response().error(str(exc)).__dict__
payload = await _get_command_payload(handler_full_name)
return Response().ok(payload).__dict__
async def _get_command_payload(handler_full_name: str):
commands = await list_commands()
for cmd in commands:
if cmd["handler_full_name"] == handler_full_name:
return cmd
return {}
+1 -91
View File
@@ -1,9 +1,7 @@
import json import json
import traceback import traceback
from datetime import datetime
from io import BytesIO
from quart import request, send_file from quart import request
from astrbot.core import logger from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
@@ -32,7 +30,6 @@ class ConversationRoute(Route):
"POST", "POST",
self.update_history, self.update_history,
), ),
"/conversation/export": ("POST", self.export_conversations),
} }
self.db_helper = db_helper self.db_helper = db_helper
self.conv_mgr = core_lifecycle.conversation_manager self.conv_mgr = core_lifecycle.conversation_manager
@@ -286,90 +283,3 @@ class ConversationRoute(Route):
except Exception as e: except Exception as e:
logger.error(f"更新对话历史失败: {e!s}\n{traceback.format_exc()}") logger.error(f"更新对话历史失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"更新对话历史失败: {e!s}").__dict__ return Response().error(f"更新对话历史失败: {e!s}").__dict__
async def export_conversations(self):
"""批量导出对话为 JSONL 格式"""
try:
data = await request.get_json()
conversations_to_export = data.get("conversations", [])
if not conversations_to_export:
return Response().error("导出列表不能为空").__dict__
# 收集所有对话的内容
jsonl_lines = []
exported_count = 0
failed_items = []
for conv_info in conversations_to_export:
user_id = conv_info.get("user_id")
cid = conv_info.get("cid")
if not user_id or not cid:
failed_items.append(
f"user_id:{user_id}, cid:{cid} - 缺少必要参数",
)
continue
try:
conversation = await self.conv_mgr.get_conversation(
unified_msg_origin=user_id,
conversation_id=cid,
)
if not conversation:
failed_items.append(
f"user_id:{user_id}, cid:{cid} - 对话不存在"
)
continue
# 解析对话内容 (history is always a JSON string from _convert_conv_from_v2_to_v1)
content = json.loads(conversation.history)
# 创建导出记录
export_record = {
"cid": cid,
"user_id": user_id,
"platform_id": conversation.platform_id,
"title": conversation.title,
"persona_id": conversation.persona_id,
"created_at": conversation.created_at,
"updated_at": conversation.updated_at,
"content": content,
}
# 将记录转换为 JSON 字符串并添加到 JSONL
jsonl_lines.append(json.dumps(export_record, ensure_ascii=False))
exported_count += 1
except Exception as e:
failed_items.append(f"user_id:{user_id}, cid:{cid} - {e!s}")
logger.error(
f"导出对话失败: user_id={user_id}, cid={cid}, error={e!s}"
)
if exported_count == 0:
return Response().error("没有成功导出任何对话").__dict__
# 创建 JSONL 内容
jsonl_content = "\n".join(jsonl_lines)
# 创建一个内存文件对象
file_obj = BytesIO(jsonl_content.encode("utf-8"))
file_obj.seek(0)
# 生成文件名
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
filename = f"astrbot_conversations_export_{timestamp}.jsonl"
# 返回文件流
return await send_file(
file_obj,
mimetype="application/jsonl",
as_attachment=True,
attachment_filename=filename,
)
except Exception as e:
logger.error(f"批量导出对话失败: {e!s}\n{traceback.format_exc()}")
return Response().error(f"批量导出对话失败: {e!s}").__dict__
+78 -253
View File
@@ -48,7 +48,6 @@ class KnowledgeBaseRoute(Route):
# 文档管理 # 文档管理
"/kb/document/list": ("GET", self.list_documents), "/kb/document/list": ("GET", self.list_documents),
"/kb/document/upload": ("POST", self.upload_document), "/kb/document/upload": ("POST", self.upload_document),
"/kb/document/import": ("POST", self.import_documents),
"/kb/document/upload/url": ("POST", self.upload_document_from_url), "/kb/document/upload/url": ("POST", self.upload_document_from_url),
"/kb/document/upload/progress": ("GET", self.get_upload_progress), "/kb/document/upload/progress": ("GET", self.get_upload_progress),
"/kb/document/get": ("GET", self.get_document), "/kb/document/get": ("GET", self.get_document),
@@ -67,65 +66,6 @@ class KnowledgeBaseRoute(Route):
def _get_kb_manager(self): def _get_kb_manager(self):
return self.core_lifecycle.kb_manager return self.core_lifecycle.kb_manager
def _init_task(self, task_id: str, status: str = "pending") -> None:
self.upload_tasks[task_id] = {
"status": status,
"result": None,
"error": None,
}
def _set_task_result(
self, task_id: str, status: str, result: any = None, error: str | None = None
) -> None:
self.upload_tasks[task_id] = {
"status": status,
"result": result,
"error": error,
}
if task_id in self.upload_progress:
self.upload_progress[task_id]["status"] = status
def _update_progress(
self,
task_id: str,
*,
status: str | None = None,
file_index: int | None = None,
file_name: str | None = None,
stage: str | None = None,
current: int | None = None,
total: int | None = None,
) -> None:
if task_id not in self.upload_progress:
return
p = self.upload_progress[task_id]
if status is not None:
p["status"] = status
if file_index is not None:
p["file_index"] = file_index
if file_name is not None:
p["file_name"] = file_name
if stage is not None:
p["stage"] = stage
if current is not None:
p["current"] = current
if total is not None:
p["total"] = total
def _make_progress_callback(self, task_id: str, file_idx: int, file_name: str):
async def _callback(stage: str, current: int, total: int):
self._update_progress(
task_id,
status="processing",
file_index=file_idx,
file_name=file_name,
stage=stage,
current=current,
total=total,
)
return _callback
async def _background_upload_task( async def _background_upload_task(
self, self,
task_id: str, task_id: str,
@@ -140,7 +80,11 @@ class KnowledgeBaseRoute(Route):
"""后台上传任务""" """后台上传任务"""
try: try:
# 初始化任务状态 # 初始化任务状态
self._init_task(task_id, status="processing") self.upload_tasks[task_id] = {
"status": "processing",
"result": None,
"error": None,
}
self.upload_progress[task_id] = { self.upload_progress[task_id] = {
"status": "processing", "status": "processing",
"file_index": 0, "file_index": 0,
@@ -156,20 +100,30 @@ class KnowledgeBaseRoute(Route):
for file_idx, file_info in enumerate(files_to_upload): for file_idx, file_info in enumerate(files_to_upload):
try: try:
# 更新整体进度 # 更新整体进度
self._update_progress( self.upload_progress[task_id].update(
task_id, {
status="processing", "status": "processing",
file_index=file_idx, "file_index": file_idx,
file_name=file_info["file_name"], "file_name": file_info["file_name"],
stage="parsing", "stage": "parsing",
current=0, "current": 0,
total=100, "total": 100,
},
) )
# 创建进度回调函数 # 创建进度回调函数
progress_callback = self._make_progress_callback( async def progress_callback(stage, current, total):
task_id, file_idx, file_info["file_name"] if task_id in self.upload_progress:
) self.upload_progress[task_id].update(
{
"status": "processing",
"file_index": file_idx,
"file_name": file_info["file_name"],
"stage": stage,
"current": current,
"total": total,
},
)
doc = await kb_helper.upload_document( doc = await kb_helper.upload_document(
file_name=file_info["file_name"], file_name=file_info["file_name"],
@@ -200,99 +154,23 @@ class KnowledgeBaseRoute(Route):
"failed_count": len(failed_docs), "failed_count": len(failed_docs),
} }
self._set_task_result(task_id, "completed", result=result) self.upload_tasks[task_id] = {
"status": "completed",
"result": result,
"error": None,
}
self.upload_progress[task_id]["status"] = "completed"
except Exception as e: except Exception as e:
logger.error(f"后台上传任务 {task_id} 失败: {e}") logger.error(f"后台上传任务 {task_id} 失败: {e}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
self._set_task_result(task_id, "failed", error=str(e)) self.upload_tasks[task_id] = {
"status": "failed",
async def _background_import_task( "result": None,
self, "error": str(e),
task_id: str,
kb_helper,
documents: list,
batch_size: int,
tasks_limit: int,
max_retries: int,
):
"""后台导入预切片文档任务"""
try:
# 初始化任务状态
self._init_task(task_id, status="processing")
self.upload_progress[task_id] = {
"status": "processing",
"file_index": 0,
"file_total": len(documents),
"stage": "waiting",
"current": 0,
"total": 100,
} }
if task_id in self.upload_progress:
uploaded_docs = [] self.upload_progress[task_id]["status"] = "failed"
failed_docs = []
for file_idx, doc_info in enumerate(documents):
file_name = doc_info.get("file_name", f"imported_doc_{file_idx}")
chunks = doc_info.get("chunks", [])
try:
# 更新整体进度
self._update_progress(
task_id,
status="processing",
file_index=file_idx,
file_name=file_name,
stage="importing",
current=0,
total=100,
)
# 创建进度回调函数
progress_callback = self._make_progress_callback(
task_id, file_idx, file_name
)
# 调用 upload_document,传入 pre_chunked_text
doc = await kb_helper.upload_document(
file_name=file_name,
file_content=None, # 预切片模式下不需要原始内容
file_type=doc_info.get("file_type")
or (
file_name.rsplit(".", 1)[-1].lower()
if "." in file_name
else "txt"
),
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
progress_callback=progress_callback,
pre_chunked_text=chunks,
)
uploaded_docs.append(doc.model_dump())
except Exception as e:
logger.error(f"导入文档 {file_name} 失败: {e}")
failed_docs.append(
{"file_name": file_name, "error": str(e)},
)
# 更新任务完成状态
result = {
"task_id": task_id,
"uploaded": uploaded_docs,
"failed": failed_docs,
"total": len(documents),
"success_count": len(uploaded_docs),
"failed_count": len(failed_docs),
}
self._set_task_result(task_id, "completed", result=result)
except Exception as e:
logger.error(f"后台导入任务 {task_id} 失败: {e}")
logger.error(traceback.format_exc())
self._set_task_result(task_id, "failed", error=str(e))
async def list_kbs(self): async def list_kbs(self):
"""获取知识库列表 """获取知识库列表
@@ -736,7 +614,11 @@ class KnowledgeBaseRoute(Route):
task_id = str(uuid.uuid4()) task_id = str(uuid.uuid4())
# 初始化任务状态 # 初始化任务状态
self._init_task(task_id, status="pending") self.upload_tasks[task_id] = {
"status": "pending",
"result": None,
"error": None,
}
# 启动后台任务 # 启动后台任务
asyncio.create_task( asyncio.create_task(
@@ -771,93 +653,6 @@ class KnowledgeBaseRoute(Route):
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return Response().error(f"上传文档失败: {e!s}").__dict__ return Response().error(f"上传文档失败: {e!s}").__dict__
def _validate_import_request(self, data: dict):
kb_id = data.get("kb_id")
if not kb_id:
raise ValueError("缺少参数 kb_id")
documents = data.get("documents")
if not documents or not isinstance(documents, list):
raise ValueError("缺少参数 documents 或格式错误")
for doc in documents:
if "file_name" not in doc or "chunks" not in doc:
raise ValueError("文档格式错误,必须包含 file_name 和 chunks")
if not isinstance(doc["chunks"], list):
raise ValueError("chunks 必须是列表")
if not all(
isinstance(chunk, str) and chunk.strip() for chunk in doc["chunks"]
):
raise ValueError("chunks 必须是非空字符串列表")
batch_size = data.get("batch_size", 32)
tasks_limit = data.get("tasks_limit", 3)
max_retries = data.get("max_retries", 3)
return kb_id, documents, batch_size, tasks_limit, max_retries
async def import_documents(self):
"""导入预切片文档
Body:
- kb_id: 知识库 ID (必填)
- documents: 文档列表 (必填)
- file_name: 文件名 (必填)
- chunks: 切片列表 (必填, list[str])
- file_type: 文件类型 (可选, 默认从文件名推断或为 txt)
- batch_size: 批处理大小 (可选, 默认32)
- tasks_limit: 并发任务限制 (可选, 默认3)
- max_retries: 最大重试次数 (可选, 默认3)
"""
try:
kb_manager = self._get_kb_manager()
data = await request.json
kb_id, documents, batch_size, tasks_limit, max_retries = (
self._validate_import_request(data)
)
# 获取知识库
kb_helper = await kb_manager.get_kb(kb_id)
if not kb_helper:
return Response().error("知识库不存在").__dict__
# 生成任务ID
task_id = str(uuid.uuid4())
# 初始化任务状态
self._init_task(task_id, status="pending")
# 启动后台任务
asyncio.create_task(
self._background_import_task(
task_id=task_id,
kb_helper=kb_helper,
documents=documents,
batch_size=batch_size,
tasks_limit=tasks_limit,
max_retries=max_retries,
),
)
return (
Response()
.ok(
{
"task_id": task_id,
"doc_count": len(documents),
"message": "import task created, processing in background",
},
)
.__dict__
)
except ValueError as e:
return Response().error(str(e)).__dict__
except Exception as e:
logger.error(f"导入文档失败: {e}")
logger.error(traceback.format_exc())
return Response().error(f"导入文档失败: {e!s}").__dict__
async def get_upload_progress(self): async def get_upload_progress(self):
"""获取上传进度和结果 """获取上传进度和结果
@@ -1165,7 +960,11 @@ class KnowledgeBaseRoute(Route):
task_id = str(uuid.uuid4()) task_id = str(uuid.uuid4())
# 初始化任务状态 # 初始化任务状态
self._init_task(task_id, status="pending") self.upload_tasks[task_id] = {
"status": "pending",
"result": None,
"error": None,
}
# 启动后台任务 # 启动后台任务
asyncio.create_task( asyncio.create_task(
@@ -1218,7 +1017,11 @@ class KnowledgeBaseRoute(Route):
"""后台上传URL任务""" """后台上传URL任务"""
try: try:
# 初始化任务状态 # 初始化任务状态
self._init_task(task_id, status="processing") self.upload_tasks[task_id] = {
"status": "processing",
"result": None,
"error": None,
}
self.upload_progress[task_id] = { self.upload_progress[task_id] = {
"status": "processing", "status": "processing",
"file_index": 0, "file_index": 0,
@@ -1230,7 +1033,18 @@ class KnowledgeBaseRoute(Route):
} }
# 创建进度回调函数 # 创建进度回调函数
progress_callback = self._make_progress_callback(task_id, 0, f"URL: {url}") async def progress_callback(stage, current, total):
if task_id in self.upload_progress:
self.upload_progress[task_id].update(
{
"status": "processing",
"file_index": 0,
"file_name": f"URL: {url}",
"stage": stage,
"current": current,
"total": total,
},
)
# 上传文档 # 上传文档
doc = await kb_helper.upload_from_url( doc = await kb_helper.upload_from_url(
@@ -1255,9 +1069,20 @@ class KnowledgeBaseRoute(Route):
"failed_count": 0, "failed_count": 0,
} }
self._set_task_result(task_id, "completed", result=result) self.upload_tasks[task_id] = {
"status": "completed",
"result": result,
"error": None,
}
self.upload_progress[task_id]["status"] = "completed"
except Exception as e: except Exception as e:
logger.error(f"后台上传URL任务 {task_id} 失败: {e}") logger.error(f"后台上传URL任务 {task_id} 失败: {e}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
self._set_task_result(task_id, "failed", error=str(e)) self.upload_tasks[task_id] = {
"status": "failed",
"result": None,
"error": str(e),
}
if task_id in self.upload_progress:
self.upload_progress[task_id]["status"] = "failed"
+1 -5
View File
@@ -124,11 +124,7 @@ class PluginRoute(Route):
session.get(url) as response, session.get(url) as response,
): ):
if response.status == 200: if response.status == 200:
try: remote_data = await response.json()
remote_data = await response.json()
except aiohttp.ContentTypeError:
remote_text = await response.text()
remote_data = json.loads(remote_text)
# 检查远程数据是否为空 # 检查远程数据是否为空
if not remote_data or ( if not remote_data or (
+4 -20
View File
@@ -3,7 +3,6 @@ import traceback
from quart import request from quart import request
from astrbot.core import logger from astrbot.core import logger
from astrbot.core.agent.mcp_client import MCPTool
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.star import star_map from astrbot.core.star import star_map
@@ -297,30 +296,15 @@ class ToolsRoute(Route):
"""获取所有注册的工具列表""" """获取所有注册的工具列表"""
try: try:
tools = self.tool_mgr.func_list tools = self.tool_mgr.func_list
tools_dict = [] tools_dict = [
for tool in tools: {
if isinstance(tool, MCPTool):
origin = "mcp"
origin_name = tool.mcp_server_name
elif tool.handler_module_path and star_map.get(
tool.handler_module_path
):
star = star_map[tool.handler_module_path]
origin = "plugin"
origin_name = star.name
else:
origin = "unknown"
origin_name = "unknown"
tool_info = {
"name": tool.name, "name": tool.name,
"description": tool.description, "description": tool.description,
"parameters": tool.parameters, "parameters": tool.parameters,
"active": tool.active, "active": tool.active,
"origin": origin,
"origin_name": origin_name,
} }
tools_dict.append(tool_info) for tool in tools
]
return Response().ok(data=tools_dict).__dict__ return Response().ok(data=tools_dict).__dict__
except Exception as e: except Exception as e:
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
-1
View File
@@ -67,7 +67,6 @@ class AstrBotDashboard:
core_lifecycle, core_lifecycle,
core_lifecycle.plugin_manager, core_lifecycle.plugin_manager,
) )
self.command_route = CommandRoute(self.context)
self.cr = ConfigRoute(self.context, core_lifecycle) self.cr = ConfigRoute(self.context, core_lifecycle)
self.lr = LogRoute(self.context, core_lifecycle.log_broker) self.lr = LogRoute(self.context, core_lifecycle.log_broker)
self.sfr = StaticFileRoute(self.context) self.sfr = StaticFileRoute(self.context)
-134
View File
@@ -1,134 +0,0 @@
#!/usr/bin/env python3
"""
Use Nuitka to build the AstrBot project into standalone executables
"""
import os
import platform
import subprocess
import sys
from pathlib import Path
def get_platform_info():
"""fetch the current platform information"""
system = platform.system()
machine = platform.machine()
return system, machine
def build_with_nuitka():
"""use Nuitka to build the project"""
system, machine = get_platform_info()
print(f"🚀 Starting build for {system} ({machine}) platform...")
# Output directory
output_dir = Path("build/nuitka")
output_dir.mkdir(parents=True, exist_ok=True)
# Base Nuitka command
nuitka_cmd = [
sys.executable,
"-m",
"nuitka",
"--standalone", # Create standalone directory
"--onefile", # Single file mode
"--follow-imports", # Follow all imports
"--enable-plugin=multiprocessing", # Enable multiprocessing support
"--output-dir=build/nuitka", # Output directory
"--quiet", # Reduce output verbosity
"--assume-yes-for-downloads", # Automatically download dependencies
"--jobs=4", # Use multiple CPU cores
]
# include specific packages
include_packages = [
"astrbot",
]
for pkg in include_packages:
nuitka_cmd.extend([f"--include-package={pkg}"])
# include data directories
# data_includes = [
# "data/config",
# "data/plugins",
# "data/temp",
# ]
# for data_dir in data_includes:
# if os.path.exists(data_dir):
# nuitka_cmd.extend([f"--include-data-dir={data_dir}={data_dir}"])
# include packages directory (built-in plugins)
# if os.path.exists("packages"):
# nuitka_cmd.extend(["--include-data-dir=packages=packages"])
# Platform specific settings
if system == "Darwin": # macOS
nuitka_cmd.extend(
[
"--macos-create-app-bundle", # Create .app bundle
"--macos-app-name=AstrBot",
]
)
# macOS icon (if exists)
icon_path = "dashboard/src-tauri/icons/icon.icns"
if os.path.exists(icon_path):
nuitka_cmd.extend([f"--macos-app-icon={icon_path}"])
elif system == "Windows":
nuitka_cmd.extend(
[
"--windows-console-mode=disable", # 无控制台窗口
]
)
# Windows icon (if exists)
icon_path = "dashboard/src-tauri/icons/icon.ico"
if os.path.exists(icon_path):
nuitka_cmd.extend([f"--windows-icon-from-ico={icon_path}"])
# Main file to compile
nuitka_cmd.append("main.py")
print(f"📦 Executing command: {' '.join(nuitka_cmd)}")
try:
subprocess.run(nuitka_cmd, check=True)
print("✅ Nuitka build successful!")
# Find the generated executable
if system == "Darwin":
built_file = list(output_dir.glob("*.app"))
if built_file:
print(f"Generated macOS app: {built_file[0]}")
elif system == "Windows":
built_file = list(output_dir.glob("*.exe"))
if built_file:
print(f"Generated Windows executable: {built_file[0]}")
else: # Linux
built_file = list(output_dir.glob("main.bin"))
if built_file:
print(f"Generated Linux executable: {built_file[0]}")
return True
except subprocess.CalledProcessError as e:
print(f"❌ Nuitka build failed: {e}")
return False
if __name__ == "__main__":
print("=" * 60)
print("AstrBot Nuitka Builder")
print("=" * 60)
# 构建
if build_with_nuitka():
print("\n" + "=" * 60)
print("🎉 Build Complete!")
print("=" * 60)
else:
print("\n" + "=" * 60)
print("❌ Build Failed")
print("=" * 60)
sys.exit(1)
-134
View File
@@ -1,134 +0,0 @@
#!/usr/bin/env python3
"""
Use PyInstaller to build the AstrBot project into standalone executables
"""
import platform
import subprocess
import sys
from pathlib import Path
def get_platform_info():
"""fetch the current platform information"""
system = platform.system()
machine = platform.machine()
return system, machine
def build_with_pyinstaller():
"""use PyInstaller to build the project"""
system, machine = get_platform_info()
print(f"🚀 Starting build for {system} ({machine}) platform...")
# Output directory
output_dir = Path("build/pyinstaller")
output_dir.mkdir(parents=True, exist_ok=True)
# Base PyInstaller command
pyinstaller_cmd = [
sys.executable,
"-m",
"PyInstaller",
"--clean", # Clean cache before build
"--noconfirm", # Replace output directory without asking
"--onefile", # Single file mode
"--distpath=build/pyinstaller/dist", # Distribution directory
"--workpath=build/pyinstaller/build", # Work directory
"--specpath=build/pyinstaller", # Spec file directory
"--name=AstrBot", # Output executable name
]
# Platform specific settings
# if system == "Darwin": # macOS
# # macOS icon (if exists)
# icon_path = "dashboard/src-tauri/icons/icon.icns"
# if os.path.exists(icon_path):
# pyinstaller_cmd.extend([f"--icon={icon_path}"])
# # Create .app bundle
# pyinstaller_cmd.extend(["--windowed"])
# elif system == "Windows":
# # Windows icon (if exists)
# icon_path = "dashboard/src-tauri/icons/icon.ico"
# if os.path.exists(icon_path):
# pyinstaller_cmd.extend([f"--icon={icon_path}"])
# # No console window
# pyinstaller_cmd.extend(["--windowed"])
# else: # Linux
# pyinstaller_cmd.extend(["--console"])
# Main file to compile
pyinstaller_cmd.append("main.py")
print(f"📦 Executing command: {' '.join(pyinstaller_cmd)}")
try:
subprocess.run(pyinstaller_cmd, check=True)
print("✅ PyInstaller build successful!")
# Find the generated executable
dist_dir = output_dir / "dist"
if system == "Darwin":
built_file = list(dist_dir.glob("AstrBot.app"))
if not built_file:
built_file = list(dist_dir.glob("AstrBot"))
if built_file:
print(f"📱 Generated macOS app: {built_file[0]}")
elif system == "Windows":
built_file = list(dist_dir.glob("AstrBot.exe"))
if built_file:
print(f"💻 Generated Windows executable: {built_file[0]}")
else: # Linux
built_file = list(dist_dir.glob("AstrBot"))
if built_file:
print(f"🐧 Generated Linux executable: {built_file[0]}")
print(f"\n📁 Output directory: {dist_dir.absolute()}")
return True
except subprocess.CalledProcessError as e:
print(f"❌ PyInstaller build failed: {e}")
return False
except Exception as e:
print(f"❌ Unexpected error: {e}")
return False
def install_pyinstaller():
"""Install PyInstaller if not already installed"""
try:
import PyInstaller
print(f"✅ PyInstaller already installed (version {PyInstaller.__version__})")
return True
except ImportError:
print("📥 PyInstaller not found, installing...")
try:
subprocess.run(
[sys.executable, "-m", "pip", "install", "pyinstaller"], check=True
)
print("✅ PyInstaller installed successfully!")
return True
except subprocess.CalledProcessError as e:
print(f"❌ Failed to install PyInstaller: {e}")
return False
if __name__ == "__main__":
print("=" * 60)
print("AstrBot PyInstaller Builder")
print("=" * 60)
# Check and install PyInstaller
if not install_pyinstaller():
sys.exit(1)
# Build
if build_with_pyinstaller():
print("\n" + "=" * 60)
print("🎉 Build Complete!")
print("=" * 60)
else:
print("\n" + "=" * 60)
print("❌ Build Failed")
print("=" * 60)
sys.exit(1)
-19
View File
@@ -1,19 +0,0 @@
## What's Changed
### 新增
- 支持自定义插件源。
- 支持飞书(Lark)的 Webhook 模式(将事件推送至开发者服务器)。
- 支持 “禁用自带指令” 快捷配置项,启用后将禁用所有 AstrBot 自带指令。入口: WebUI -> 配置文件 -> 平台配置。
### 优化
- 从 WebUI 移除了开发版本渠道。
- 当试图测试"Agent Runner"时,提示前往配置文件页测试。
- WebUI 列表项支持批量粘贴、回车创建项目。
### 修复
- Gemini API 部分调用失败的问题。
- WebUI 插件安装加载 Dialog 关闭按钮在手机端下显示异常的问题。
- 部分情况下,WebUI 日志显示不全的问题。
-3
View File
@@ -1,3 +0,0 @@
## What's Changed
-
-17
View File
@@ -1,17 +0,0 @@
## What's Changed
### 修复
- 企业自部署飞书(自定义 domain)可以接收消息但无法发送消息的问题。
- 安装插件 Dialog 的深色样式问题。
### 优化
- 避免某些插件在流式响应结束后重d复发送消息的问题。
### 新增
- 支持在对话管理批量导出对话轨迹数据为 `jsonl` 格式文件。入口:WebUI -> 对话管理 -> 批量选中 -> 导出。
- 支持对 TTS(文本转语音)设置概率触发。
- (插件开发)支持在 schema 中对 float 和 int 类型设置 `slider` 滑块控件。例如 `slider: {min: 0, max: 1, step: 0.1}`
- (插件开发)支持 key-value 存储功能。例如使用 `await self.put_kv_data("key", value)`, `await self.get_kv_data("key", default_value)``await self.delete_kv_data("key")`
-225
View File
@@ -1,225 +0,0 @@
# AstrBot Dashboard - Tauri 桌面应用
本项目现已支持通过 Tauri 构建为桌面应用,同时保持与 Web 版本的兼容性。
## 环境要求
### 系统依赖
**macOS:**
```bash
# 安装 Xcode Command Line Tools
xcode-select --install
```
**Windows:**
- 安装 [Microsoft Visual Studio C++ Build Tools](https://visualstudio.microsoft.com/visual-cpp-build-tools/)
- 安装 [WebView2](https://developer.microsoft.com/en-us/microsoft-edge/webview2/)
**Linux (Ubuntu/Debian):**
```bash
sudo apt update
sudo apt install libwebkit2gtk-4.0-dev \
build-essential \
curl \
wget \
file \
libssl-dev \
libgtk-3-dev \
libayatana-appindicator3-dev \
librsvg2-dev
```
### Rust 环境
```bash
# 安装 Rust
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
# 验证安装
rustc --version
cargo --version
```
## 安装依赖
```bash
cd dashboard
npm install
```
## 开发模式
### Web 端开发(不变)
```bash
npm run dev
```
访问 http://localhost:3000
### 桌面端开发
```bash
npm run tauri:dev
```
这会同时启动:
1. Vite 开发服务器(端口 3000)
2. Tauri 桌面应用窗口
热重载功能正常工作,修改代码后会自动刷新。
## 构建
### Web 端构建(不变)
```bash
npm run build
```
输出目录:`dist/`
### 桌面端构建
```bash
npm run tauri:build
```
构建产物位置:
- **macOS**: `src-tauri/target/release/bundle/dmg/`
- **Windows**: `src-tauri/target/release/bundle/msi/`
- **Linux**: `src-tauri/target/release/bundle/deb/``appimage/`
## 图标设置
### 自动生成图标
准备一个至少 512x512 像素的 PNG 图标,然后运行:
```bash
npm run tauri icon path/to/your/icon.png
```
### 手动设置图标
将以下图标放入 `src-tauri/icons/` 目录:
- `32x32.png`
- `128x128.png`
- `128x128@2x.png`
- `icon.icns` (macOS)
- `icon.ico` (Windows)
## 代码兼容性
项目已配置为同时支持 Web 和桌面端,使用相同的代码库。
### 环境检测工具
`src/utils/tauri.ts` 中提供了环境检测工具:
```typescript
import { isTauri, isWeb, PlatformAPI } from '@/utils/tauri';
// 检测运行环境
if (isTauri()) {
console.log('运行在桌面应用中');
} else {
console.log('运行在浏览器中');
}
// 获取正确的 API 端点
const baseURL = PlatformAPI.getBaseURL();
```
### API 调用注意事项
- **Web 端**: 使用 Vite 代理,API 路径为 `/api/*`
- **桌面端**: 直接连接到 `http://127.0.0.1:6185`
已在 `PlatformAPI.getBaseURL()` 中处理,使用 axios 时:
```typescript
import axios from 'axios';
import { PlatformAPI } from '@/utils/tauri';
const api = axios.create({
baseURL: PlatformAPI.getBaseURL()
});
```
## 配置说明
### tauri.conf.json
主要配置项:
- `build.devPath`: 开发服务器地址(http://localhost:3000
- `build.distDir`: 构建输出目录(../dist
- `tauri.allowlist`: API 权限配置
- `tauri.windows`: 窗口配置(大小、标题等)
### 安全性
默认配置已启用必要的权限:
- 文件系统访问(限定在 APPDATA 目录)
- HTTP 请求(限定到本地后端)
- 窗口控制
- 对话框(打开/保存文件)
可在 `tauri.conf.json``allowlist` 部分调整权限。
## 后端连接
桌面应用需要后端服务运行在 `http://127.0.0.1:6185`
### 启动流程
1. 启动 AstrBot 后端:
```bash
cd /path/to/AstrBot
uv run main.py
```
2. 启动桌面应用:
```bash
cd dashboard
npm run tauri:dev
```
或直接运行打包后的应用(后端需要已启动)。
## 常见问题
### Q: 桌面应用无法连接到后端?
确保:
1. AstrBot 后端正在运行(`uv run main.py`
2. 后端监听在 `127.0.0.1:6185`
3. 防火墙未阻止连接
### Q: 图标未显示?
检查 `src-tauri/icons/` 目录中是否有所需的图标文件,或使用 `npm run tauri icon` 命令生成。
### Q: 构建失败?
- 确保已安装 Rust 和系统依赖
- 运行 `cargo clean` 清理缓存后重试
- 检查 Rust 版本(需要 1.60+
### Q: Web 端功能是否受影响?
不受影响。`npm run dev``npm run build` 的行为完全不变。
## 开发建议
1. **优先使用 Web 端开发**: 更快的热重载,更好的调试体验
2. **定期测试桌面端**: 确保跨平台兼容性
3. **使用环境检测**: 针对不同平台提供最佳体验
4. **注意 API 差异**: Web 和桌面端的某些 API 可能有差异
## 更多资源
- [Tauri 官方文档](https://tauri.app/)
- [Tauri API 参考](https://tauri.app/v1/api/js/)
- [Tauri Discord 社区](https://discord.com/invite/tauri)
+1 -6
View File
@@ -10,14 +10,10 @@
"build-prod": "vue-tsc --noEmit && vite build --base=/vue/free/", "build-prod": "vue-tsc --noEmit && vite build --base=/vue/free/",
"preview": "vite preview --port 5050", "preview": "vite preview --port 5050",
"typecheck": "vue-tsc --noEmit", "typecheck": "vue-tsc --noEmit",
"lint": "eslint . --ext .vue,.js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --ignore-path .gitignore", "lint": "eslint . --ext .vue,.js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --ignore-path .gitignore"
"tauri": "tauri",
"tauri:dev": "tauri dev",
"tauri:build": "tauri build"
}, },
"dependencies": { "dependencies": {
"@guolao/vue-monaco-editor": "^1.5.4", "@guolao/vue-monaco-editor": "^1.5.4",
"@tauri-apps/api": "^2.9.0",
"@tiptap/starter-kit": "2.1.7", "@tiptap/starter-kit": "2.1.7",
"@tiptap/vue-3": "2.1.7", "@tiptap/vue-3": "2.1.7",
"apexcharts": "3.42.0", "apexcharts": "3.42.0",
@@ -47,7 +43,6 @@
"devDependencies": { "devDependencies": {
"@mdi/font": "7.2.96", "@mdi/font": "7.2.96",
"@rushstack/eslint-patch": "1.3.3", "@rushstack/eslint-patch": "1.3.3",
"@tauri-apps/cli": "^2.9.4",
"@types/chance": "1.1.3", "@types/chance": "1.1.3",
"@types/markdown-it": "^14.1.2", "@types/markdown-it": "^14.1.2",
"@types/node": "^20.5.7", "@types/node": "^20.5.7",
-4509
View File
File diff suppressed because it is too large Load Diff
-3
View File
@@ -1,3 +0,0 @@
# Tauri specific
src-tauri/target/
src-tauri/WixTools/
-4692
View File
File diff suppressed because it is too large Load Diff
-27
View File
@@ -1,27 +0,0 @@
[package]
name = "astrbot-dashboard"
version = "4.5.6"
description = "AstrBot"
authors = ["AstrBot Team"]
license = "AGPL-3.0"
repository = "https://github.com/AstrBotDevs/AstrBot"
default-run = "astrbot-dashboard"
edition = "2021"
rust-version = "1.91.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[build-dependencies]
tauri-build = { version = "2", features = [] }
[dependencies]
serde_json = "1.0"
serde = { version = "1.0", features = ["derive"] }
tauri = { version = "2.9.2", features = ["macos-private-api", "protocol-asset"] }
tauri-plugin-opener = "2"
[features]
# this feature is used for production builds or when `devPath` points to the filesystem and the built-in dev server is disabled.
# If you use cargo directly instead of tauri's cli you can use this feature flag to switch between tauri's `dev` and `build` modes.
# DO NOT REMOVE!!
custom-protocol = [ "tauri/custom-protocol" ]
-3
View File
@@ -1,3 +0,0 @@
fn main() {
tauri_build::build()
}
File diff suppressed because one or more lines are too long
@@ -1 +0,0 @@
{}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.3 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 18 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.3 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 5.9 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 8.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 8.8 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 20 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 23 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.0 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 3.5 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.8 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.3 KiB

@@ -1,5 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<adaptive-icon xmlns:android="http://schemas.android.com/apk/res/android">
<foreground android:drawable="@mipmap/ic_launcher_foreground"/>
<background android:drawable="@color/ic_launcher_background"/>
</adaptive-icon>
Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 9.8 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.0 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.1 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 6.0 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.8 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.9 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 4.2 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 7.9 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 24 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 6.8 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 11 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 37 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 9.6 KiB

@@ -1,4 +0,0 @@
<?xml version="1.0" encoding="utf-8"?>
<resources>
<color name="ic_launcher_background">#fff</color>
</resources>
Binary file not shown.
Binary file not shown.

Before

Width:  |  Height:  |  Size: 27 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 47 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 602 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.4 KiB

Some files were not shown because too many files have changed in this diff Show More