Compare commits
56 Commits
perf/trace
...
v4.14.7
| Author | SHA1 | Date | |
|---|---|---|---|
| 0553f84d6c | |||
| 3fd89808ee | |||
| 96753821b7 | |||
| eca3ede7b0 | |||
| a7e580407c | |||
| 8bd1565696 | |||
| 03e0949067 | |||
| dbe8e33c4b | |||
| 952023db30 | |||
| 4e0b5063c6 | |||
| 30d1d55e3c | |||
| 1e9026d44c | |||
| e48950d260 | |||
| 5e5207da95 | |||
| def8b730b7 | |||
| 22a109c2ae | |||
| 6416707e35 | |||
| 4658998b85 | |||
| d233fb8b1e | |||
| fc2a67188f | |||
| d69592aaa8 | |||
| f3397f6f08 | |||
| be92e4f395 | |||
| 912e40e7f0 | |||
| 2876c43387 | |||
| 464882f206 | |||
| 6736fb85c2 | |||
| 1f75255950 | |||
| a954e75547 | |||
| d2b9997620 | |||
| 36432c4361 | |||
| 36f0d1f0f9 | |||
| f65b268bb2 | |||
| fe06dfcca3 | |||
| bc9043bc3f | |||
| 430694aae9 | |||
| c643e3c093 | |||
| ff46eef3b2 | |||
| a0c364aa81 | |||
| 0e0f923a49 | |||
| f2d637b935 | |||
| 96e61a4a92 | |||
| e42c1b6da8 | |||
| 387bba093e | |||
| 123cf9cb11 | |||
| 93277ffac9 | |||
| c091053ea8 | |||
| 8b9f2f1e70 | |||
| 25ca7bd71e | |||
| 093b37e04b | |||
| a12e27f9ab | |||
| ae6e0db053 | |||
| cd6bef4d78 | |||
| de1304dc6a | |||
| f835f63542 | |||
| 5deb045e47 |
@@ -0,0 +1,227 @@
|
||||
name: Desktop Release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "v*"
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
ref:
|
||||
description: "Git ref to build (branch/tag/SHA)"
|
||||
required: false
|
||||
default: "master"
|
||||
tag:
|
||||
description: "Release tag to upload assets to (for example: v4.14.6)"
|
||||
required: false
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
|
||||
jobs:
|
||||
build-desktop:
|
||||
name: Build ${{ matrix.name }}
|
||||
runs-on: ${{ matrix.runner }}
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- name: linux-x64
|
||||
runner: ubuntu-24.04
|
||||
os: linux
|
||||
arch: amd64
|
||||
- name: linux-arm64
|
||||
runner: ubuntu-24.04-arm
|
||||
os: linux
|
||||
arch: arm64
|
||||
- name: windows-x64
|
||||
runner: windows-2022
|
||||
os: win
|
||||
arch: amd64
|
||||
- name: windows-arm64
|
||||
runner: windows-11-arm
|
||||
os: win
|
||||
arch: arm64
|
||||
- name: macos-x64
|
||||
runner: macos-15-intel
|
||||
os: mac
|
||||
arch: amd64
|
||||
- name: macos-arm64
|
||||
runner: macos-15
|
||||
os: mac
|
||||
arch: arm64
|
||||
env:
|
||||
CSC_IDENTITY_AUTO_DISCOVERY: "false"
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ inputs.ref || github.ref }}
|
||||
|
||||
- name: Setup uv
|
||||
uses: astral-sh/setup-uv@v6
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v6
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Setup pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
version: 10.28.2
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 20
|
||||
cache: "pnpm"
|
||||
cache-dependency-path: |
|
||||
dashboard/pnpm-lock.yaml
|
||||
desktop/pnpm-lock.yaml
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
uv sync
|
||||
pnpm --dir dashboard install --frozen-lockfile
|
||||
pnpm --dir desktop install --frozen-lockfile
|
||||
|
||||
- name: Build desktop package
|
||||
run: |
|
||||
pnpm --dir dashboard run build
|
||||
pnpm --dir desktop run build:webui
|
||||
pnpm --dir desktop run build:backend
|
||||
pnpm --dir desktop run sync:version
|
||||
pnpm --dir desktop exec electron-builder --publish never
|
||||
|
||||
- name: Resolve artifact tag
|
||||
id: tag
|
||||
shell: bash
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" = "push" ]; then
|
||||
tag="${GITHUB_REF_NAME}"
|
||||
elif [ -n "${{ inputs.tag }}" ]; then
|
||||
tag="${{ inputs.tag }}"
|
||||
else
|
||||
tag="$(git describe --tags --abbrev=0)"
|
||||
fi
|
||||
if [ -z "$tag" ]; then
|
||||
echo "Failed to resolve artifact tag." >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "tag=$tag" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Normalize artifact names
|
||||
shell: bash
|
||||
env:
|
||||
NAME_PREFIX: AstrBot-${{ steps.tag.outputs.tag }}-${{ matrix.arch }}-${{ matrix.os }}
|
||||
run: |
|
||||
shopt -s nullglob
|
||||
out_dir="desktop/dist/release"
|
||||
mkdir -p "$out_dir"
|
||||
files=(
|
||||
desktop/dist/*.AppImage
|
||||
desktop/dist/*.dmg
|
||||
desktop/dist/*.zip
|
||||
desktop/dist/*.exe
|
||||
)
|
||||
if [ ${#files[@]} -eq 0 ]; then
|
||||
echo "No desktop artifacts found to rename." >&2
|
||||
exit 1
|
||||
fi
|
||||
for src in "${files[@]}"; do
|
||||
file="$(basename "$src")"
|
||||
case "$file" in
|
||||
*.AppImage)
|
||||
dest="$out_dir/${NAME_PREFIX}.AppImage"
|
||||
;;
|
||||
*.dmg)
|
||||
dest="$out_dir/${NAME_PREFIX}.dmg"
|
||||
;;
|
||||
*.exe)
|
||||
dest="$out_dir/${NAME_PREFIX}.exe"
|
||||
;;
|
||||
*.zip)
|
||||
dest="$out_dir/${NAME_PREFIX}.zip"
|
||||
;;
|
||||
*)
|
||||
continue
|
||||
;;
|
||||
esac
|
||||
cp "$src" "$dest"
|
||||
done
|
||||
ls -la "$out_dir"
|
||||
|
||||
- name: Upload desktop artifacts
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: AstrBot-${{ steps.tag.outputs.tag }}-${{ matrix.arch }}-${{ matrix.os }}
|
||||
if-no-files-found: error
|
||||
path: desktop/dist/release/*
|
||||
|
||||
publish-release:
|
||||
name: Publish Release Assets
|
||||
runs-on: ubuntu-24.04
|
||||
needs: build-desktop
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
ref: ${{ inputs.ref || github.ref }}
|
||||
|
||||
- name: Resolve release tag
|
||||
id: tag
|
||||
shell: bash
|
||||
run: |
|
||||
if [ "${{ github.event_name }}" = "push" ]; then
|
||||
tag="${GITHUB_REF_NAME}"
|
||||
elif [ -n "${{ inputs.tag }}" ]; then
|
||||
tag="${{ inputs.tag }}"
|
||||
else
|
||||
tag="$(git describe --tags --abbrev=0)"
|
||||
fi
|
||||
if [ -z "$tag" ]; then
|
||||
echo "Failed to resolve release tag." >&2
|
||||
exit 1
|
||||
fi
|
||||
echo "tag=$tag" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Download built artifacts
|
||||
uses: actions/download-artifact@v6
|
||||
with:
|
||||
pattern: AstrBot-${{ steps.tag.outputs.tag }}-*
|
||||
path: release-assets
|
||||
merge-multiple: true
|
||||
|
||||
- name: Ensure release exists
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
shell: bash
|
||||
run: |
|
||||
tag="${{ steps.tag.outputs.tag }}"
|
||||
if ! gh release view "$tag" >/dev/null 2>&1; then
|
||||
gh release create "$tag" --title "$tag" --notes ""
|
||||
fi
|
||||
|
||||
- name: Remove stale desktop assets from release
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
shell: bash
|
||||
run: |
|
||||
tag="${{ steps.tag.outputs.tag }}"
|
||||
while IFS= read -r asset; do
|
||||
case "$asset" in
|
||||
*.AppImage|*.dmg|*.zip|*.exe|*.blockmap)
|
||||
gh release delete-asset "$tag" "$asset" -y || true
|
||||
;;
|
||||
esac
|
||||
done < <(gh release view "$tag" --json assets --jq '.assets[].name')
|
||||
|
||||
- name: Upload assets to release
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
shell: bash
|
||||
run: |
|
||||
tag="${{ steps.tag.outputs.tag }}"
|
||||
gh release upload "$tag" release-assets/* --clobber
|
||||
+9
-2
@@ -32,8 +32,15 @@ tests/astrbot_plugin_openai
|
||||
# Dashboard
|
||||
dashboard/node_modules/
|
||||
dashboard/dist/
|
||||
.pnpm-store/
|
||||
desktop/node_modules/
|
||||
desktop/dist/
|
||||
desktop/out/
|
||||
desktop/resources/backend/astrbot-backend*
|
||||
desktop/resources/backend/*.exe
|
||||
desktop/resources/webui/*
|
||||
desktop/resources/.pyinstaller/
|
||||
package-lock.json
|
||||
package.json
|
||||
yarn.lock
|
||||
|
||||
# Operating System
|
||||
@@ -53,4 +60,4 @@ IFLOW.md
|
||||
|
||||
# genie_tts data
|
||||
CharacterModels/
|
||||
GenieData/
|
||||
GenieData/
|
||||
|
||||
+1
-1
@@ -1 +1 @@
|
||||
3.10
|
||||
3.12
|
||||
@@ -26,6 +26,7 @@ Runs on `http://localhost:3000` by default.
|
||||
3. After finishing, use `ruff format .` and `ruff check .` to format and check the code.
|
||||
4. When committing, ensure to use conventional commits messages, such as `feat: add new agent for data analysis` or `fix: resolve bug in provider manager`.
|
||||
5. Use English for all new comments.
|
||||
6. For path handling, use `pathlib.Path` instead of string paths, and use `astrbot.core.utils.path_utils` to get the AstrBot data and temp directory.
|
||||
|
||||
## PR instructions
|
||||
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
我需要让 Agent 能够在未来提醒自己去做某些事情,这样 Agent 能够主动地去完成一些任务,而不是等用户主动来下达命令。
|
||||
|
||||
你需要实现一个 CronJob 系统,允许 Agent 创建未来任务,并且在未来的某个时间点自动触发这些任务的执行.
|
||||
|
||||
CronJob 系统分为 BasicCronJob 和 ActiveAgentCronJob 两种类型。前者只是简单的提供一个定时任务功能(给插件用),而后者则允许 Agent 主动地去完成一些任务。BasicCronJob 不必多说,就是定时执行某个函数。对于 ActiveAgentCronJob,Agent 应该可以主动管理(比如通过Tool来管理)这些 CronJobs,当添加的时候,Agent 可以给 CronJob 捎一段文字,以说明未来的自己需要做什么事情。比如说,Agent 在听到用户 “每天早上都给我整理一份今日早报” 之后,应该可以创建 Cron Job,并且自己写脚本来完成这个任务,并且注册 cron job。Agent 给未来的自己捎去的信息应该只是呈现为一段文字,这样可以保持设计简约。当触发后, CronJobManager 会调用 MainAgent 的一轮循环,MainAgent 通过上下文知道这是一个定时任务触发的循环,从而执行相应的操作。
|
||||
|
||||
此外,我还有一个需求,后台长任务。需要给当前的 FunctionTool 类增加一个属性,is_background_task: bool = False,插件可以通过这个属性来声明这是一个异步任务。这是为了解决一些 Tool 需要长时间运行的问题,比如 Deep Search tool 需要长时间搜索网页内容、Sub Agent 需要长时间运行来完成一个复杂任务。
|
||||
|
||||
基于上面的讨论,我觉得,应该:
|
||||
|
||||
1. 需要给当前的 FunctionTool 类增加一个属性is_background_task: bool = False,tool runner 在执行这个 tool 的时候,如果发现是后台任务,就不等待结果返回,而是直接返回一个任务 ID (已经创建成功提示)的结果,tool runner 在后台继续执行这个任务。当任务完成之后,任务的结果回传给 MainAgent(其实就是再执行一次 main agent loop,但是上下文应该是最新的),并且 MainAgent 此时应该有 send_message_to_user 的工具,通过这个工具可以选择是否主动通知用户任务完成的结果。
|
||||
2. 增加一个 CronJobManager 类,负责管理所有的定时任务。Agent 可以通过调用这个类的方法来创建、删除、修改定时任务。通过 cron expression 来定义触发条件。
|
||||
3. CronJobManager 除了管理普通的定时任务(比如插件可能有一些自己的定时任务),还有一种特殊的任务类型,就是上面提到的主动型 Agent 任务。用户提需求,MainAgent 选择性地调用 CronJobManager 的方法来创建这些任务,并且在任务触发时,CronJobManager 的回调就是执行 MainAgent 的一轮循环(需要加 send_message_to_user tool),MainAgent 通过上下文知道这是一个定时任务触发的循环,从而执行相应的操作。
|
||||
4. WebUI 需要增加 Cron Job 管理界面,用户可以在界面上查看、创建、修改、删除定时任务。对于主动型 Agent 任务,用户可以看到任务的描述、触发条件等信息。
|
||||
5. 除此之外,现在的代码中已经有了 subagent 的管理。WebUI 可以创建 SubAgent,但是还没写完。除了结合上面我说的之外,你还需要将 SubAgent 与 Persona 结合起来——因为 Persona 是一个包含了 tool、skills、name、description 的完整体,所以 SubAgent 应该直接继承 Persona 的定义,而不是单独定义 SubAgent。SubAgent 本质上就是一个有特定角色和能力的 Persona!多么美妙的设计啊!
|
||||
6. 为了实现大一统,is_background_task = True 的时候,后台任务也挂到 CronJobManager 上去管理,只不过这个是立即触发的任务,不需要等到未来某个时间点才触发罢了。
|
||||
|
||||
我希望设计尽可能简单,但是强大。
|
||||
+2
-2
@@ -1,4 +1,4 @@
|
||||
FROM python:3.11-slim
|
||||
FROM python:3.12-slim
|
||||
WORKDIR /AstrBot
|
||||
|
||||
COPY . /AstrBot/
|
||||
@@ -23,7 +23,7 @@ RUN apt-get update && apt-get install -y curl gnupg \
|
||||
&& apt-get install -y nodejs
|
||||
|
||||
RUN python -m pip install uv \
|
||||
&& echo "3.11" > .python-version
|
||||
&& echo "3.12" > .python-version
|
||||
RUN uv pip install -r requirements.txt --no-cache-dir --system
|
||||
RUN uv pip install socksio uv pilk --no-cache-dir --system
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@
|
||||
<a href="https://github.com/AstrBotDevs/AstrBot/issues">问题提交</a>
|
||||
</div>
|
||||
|
||||
AstrBot 是一个开源的一站式 Agent 聊天机器人平台,可接入主流即时通讯软件,为个人、开发者和团队打造可靠、可扩展的对话式智能基础设施。无论是个人 AI 伙伴、智能客服、自动化助手,还是企业知识库,AstrBot 都能在你的即时通讯软件平台的工作流中快速构建生产可用的 AI 应用。
|
||||
AstrBot 是一个开源的一站式 Agentic 个人和群聊助手,可在 QQ、Telegram、企业微信、飞书、钉钉、Slack、等数十款主流即时通讯软件上部署,此外还内置类似 OpenWebUI 的轻量化 ChatUI,为个人、开发者和团队打造可靠、可扩展的对话式智能基础设施。无论是个人 AI 伙伴、智能客服、自动化助手,还是企业知识库,AstrBot 都能在你的即时通讯软件平台的工作流中快速构建 AI 应用。
|
||||
|
||||

|
||||
|
||||
@@ -50,6 +50,23 @@ AstrBot 是一个开源的一站式 Agent 聊天机器人平台,可接入主
|
||||
7. 🌈 Web ChatUI 支持,ChatUI 内置代理沙盒、网页搜索等。
|
||||
8. 🌐 国际化(i18n)支持。
|
||||
|
||||
<br>
|
||||
|
||||
<table align="center">
|
||||
<tr align="center">
|
||||
<th>💙 角色扮演 & 情感陪伴</th>
|
||||
<th>✨ 主动式 Agent</th>
|
||||
<th>🚀 通用 Agentic 能力</th>
|
||||
<th>🧩 900+ 社区插件</th>
|
||||
</tr>
|
||||
<tr>
|
||||
<td align="center"><p align="center"><img width="984" height="1746" alt="99b587c5d35eea09d84f33e6cf6cfd4f" src="https://github.com/user-attachments/assets/89196061-3290-458d-b51f-afa178049f84" /></p></td>
|
||||
<td align="center"><p align="center"><img width="976" height="1612" alt="c449acd838c41d0915cc08a3824025b1" src="https://github.com/user-attachments/assets/f75368b4-e022-41dc-a9e0-131c3e73e32e" /></p></td>
|
||||
<td align="center"><p align="center"><img width="974" height="1732" alt="image" src="https://github.com/user-attachments/assets/e22a3968-87d7-4708-a7cd-e7f198c7c32e" /></p></td>
|
||||
<td align="center"><p align="center"><img width="976" height="1734" alt="image" src="https://github.com/user-attachments/assets/0952b395-6b4a-432a-8a50-c294b7f89750" /></p></td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## 快速开始
|
||||
|
||||
#### Docker 部署(推荐 🥳)
|
||||
@@ -115,6 +132,10 @@ uv run main.py
|
||||
|
||||
或者请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
||||
|
||||
#### 桌面端 Electron 打包
|
||||
|
||||
桌面端(Electron 打包,`pnpm` 工作流)构建流程请参阅:[`desktop/README.md`](desktop/README.md)。
|
||||
|
||||
## 支持的消息平台
|
||||
|
||||
**官方维护**
|
||||
@@ -247,8 +268,8 @@ pre-commit install
|
||||
|
||||
<div align="center">
|
||||
|
||||
_陪伴与能力从来不应该是对立面。我们希望创造的是一个既能理解情绪、给予陪伴,也能可靠完成工作的机器人。_
|
||||
|
||||
_私は、高性能ですから!_
|
||||
|
||||
<img src="https://files.astrbot.app/watashiwa-koseino-desukara.gif" width="100"/>
|
||||
</div
|
||||
|
||||
|
||||
@@ -117,6 +117,10 @@ uv run main.py
|
||||
|
||||
Or refer to the official documentation: [Deploy AstrBot from Source](https://astrbot.app/deploy/astrbot/cli.html).
|
||||
|
||||
#### Desktop Electron Build
|
||||
|
||||
For desktop build steps (Electron packaging, `pnpm` workflow), see [`desktop/README.md`](desktop/README.md).
|
||||
|
||||
## Supported Messaging Platforms
|
||||
|
||||
**Officially Maintained**
|
||||
|
||||
@@ -77,7 +77,6 @@ class Main(star.Star):
|
||||
|
||||
yield event.request_llm(
|
||||
prompt=prompt,
|
||||
func_tool_manager=self.context.get_llm_tool_manager(),
|
||||
session_id=event.session_id,
|
||||
conversation=conv,
|
||||
)
|
||||
|
||||
@@ -49,7 +49,7 @@ class Main(Star):
|
||||
if p_settings.get("empty_mention_waiting_need_reply", True):
|
||||
try:
|
||||
# 尝试使用 LLM 生成更生动的回复
|
||||
func_tools_mgr = self.context.get_llm_tool_manager()
|
||||
# func_tools_mgr = self.context.get_llm_tool_manager()
|
||||
|
||||
# 获取用户当前的对话信息
|
||||
curr_cid = await self.context.conversation_manager.get_curr_conversation_id(
|
||||
@@ -76,7 +76,6 @@ class Main(Star):
|
||||
"你友好地询问用户想要聊些什么或者需要什么帮助,回复要符合人设,不要太过机械化。"
|
||||
"请注意,你仅需要输出要回复用户的内容,不要输出其他任何东西"
|
||||
),
|
||||
func_tool_manager=func_tools_mgr,
|
||||
session_id=curr_cid,
|
||||
contexts=[],
|
||||
system_prompt="",
|
||||
|
||||
@@ -23,6 +23,7 @@ class Main(star.Star):
|
||||
"fetch_url",
|
||||
"web_search_tavily",
|
||||
"tavily_extract_web_page",
|
||||
"web_search_bocha",
|
||||
]
|
||||
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
@@ -30,6 +31,9 @@ class Main(star.Star):
|
||||
self.tavily_key_index = 0
|
||||
self.tavily_key_lock = asyncio.Lock()
|
||||
|
||||
self.bocha_key_index = 0
|
||||
self.bocha_key_lock = asyncio.Lock()
|
||||
|
||||
# 将 str 类型的 key 迁移至 list[str],并保存
|
||||
cfg = self.context.get_config()
|
||||
provider_settings = cfg.get("provider_settings")
|
||||
@@ -45,6 +49,14 @@ class Main(star.Star):
|
||||
provider_settings["websearch_tavily_key"] = []
|
||||
cfg.save_config()
|
||||
|
||||
bocha_key = provider_settings.get("websearch_bocha_key")
|
||||
if isinstance(bocha_key, str):
|
||||
if bocha_key:
|
||||
provider_settings["websearch_bocha_key"] = [bocha_key]
|
||||
else:
|
||||
provider_settings["websearch_bocha_key"] = []
|
||||
cfg.save_config()
|
||||
|
||||
self.bing_search = Bing()
|
||||
self.sogo_search = Sogo()
|
||||
self.baidu_initialized = False
|
||||
@@ -341,7 +353,7 @@ class Main(star.Star):
|
||||
}
|
||||
)
|
||||
if result.favicon:
|
||||
sp.temorary_cache["_ws_favicon"][result.url] = result.favicon
|
||||
sp.temporary_cache["_ws_favicon"][result.url] = result.favicon
|
||||
# ret = "\n".join(ret_ls)
|
||||
ret = json.dumps({"results": ret_ls}, ensure_ascii=False)
|
||||
return ret
|
||||
@@ -382,6 +394,160 @@ class Main(star.Star):
|
||||
return "Error: Tavily web searcher does not return any results."
|
||||
return ret
|
||||
|
||||
async def _get_bocha_key(self, cfg: AstrBotConfig) -> str:
|
||||
"""并发安全的从列表中获取并轮换BoCha API密钥。"""
|
||||
bocha_keys = cfg.get("provider_settings", {}).get("websearch_bocha_key", [])
|
||||
if not bocha_keys:
|
||||
raise ValueError("错误:BoCha API密钥未在AstrBot中配置。")
|
||||
|
||||
async with self.bocha_key_lock:
|
||||
key = bocha_keys[self.bocha_key_index]
|
||||
self.bocha_key_index = (self.bocha_key_index + 1) % len(bocha_keys)
|
||||
return key
|
||||
|
||||
async def _web_search_bocha(
|
||||
self,
|
||||
cfg: AstrBotConfig,
|
||||
payload: dict,
|
||||
) -> list[SearchResult]:
|
||||
"""使用 BoCha 搜索引擎进行搜索"""
|
||||
bocha_key = await self._get_bocha_key(cfg)
|
||||
url = "https://api.bochaai.com/v1/web-search"
|
||||
header = {
|
||||
"Authorization": f"Bearer {bocha_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.post(
|
||||
url,
|
||||
json=payload,
|
||||
headers=header,
|
||||
) as response:
|
||||
if response.status != 200:
|
||||
reason = await response.text()
|
||||
raise Exception(
|
||||
f"BoCha web search failed: {reason}, status: {response.status}",
|
||||
)
|
||||
data = await response.json()
|
||||
data = data["data"]["webPages"]["value"]
|
||||
results = []
|
||||
for item in data:
|
||||
result = SearchResult(
|
||||
title=item.get("name"),
|
||||
url=item.get("url"),
|
||||
snippet=item.get("snippet"),
|
||||
favicon=item.get("siteIcon"),
|
||||
)
|
||||
results.append(result)
|
||||
return results
|
||||
|
||||
@llm_tool("web_search_bocha")
|
||||
async def search_from_bocha(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
query: str,
|
||||
freshness: str = "noLimit",
|
||||
summary: bool = False,
|
||||
include: str = "",
|
||||
exclude: str = "",
|
||||
count: int = 10,
|
||||
) -> str:
|
||||
"""
|
||||
A web search tool based on Bocha Search API, used to retrieve web pages
|
||||
related to the user's query.
|
||||
|
||||
Args:
|
||||
query (string): Required. User's search query.
|
||||
|
||||
freshness (string): Optional. Specifies the time range of the search.
|
||||
Supported values:
|
||||
- "noLimit": No time limit (default, recommended).
|
||||
- "oneDay": Within one day.
|
||||
- "oneWeek": Within one week.
|
||||
- "oneMonth": Within one month.
|
||||
- "oneYear": Within one year.
|
||||
- "YYYY-MM-DD..YYYY-MM-DD": Search within a specific date range.
|
||||
Example: "2025-01-01..2025-04-06".
|
||||
- "YYYY-MM-DD": Search on a specific date.
|
||||
Example: "2025-04-06".
|
||||
It is recommended to use "noLimit", as the search algorithm will
|
||||
automatically optimize time relevance. Manually restricting the
|
||||
time range may result in no search results.
|
||||
|
||||
summary (boolean): Optional. Whether to include a text summary
|
||||
for each search result.
|
||||
- True: Include summary.
|
||||
- False: Do not include summary (default).
|
||||
|
||||
include (string): Optional. Specifies the domains to include in
|
||||
the search. Multiple domains can be separated by "|" or ",".
|
||||
A maximum of 100 domains is allowed.
|
||||
Examples:
|
||||
- "qq.com"
|
||||
- "qq.com|m.163.com"
|
||||
|
||||
exclude (string): Optional. Specifies the domains to exclude from
|
||||
the search. Multiple domains can be separated by "|" or ",".
|
||||
A maximum of 100 domains is allowed.
|
||||
Examples:
|
||||
- "qq.com"
|
||||
- "qq.com|m.163.com"
|
||||
|
||||
count (number): Optional. Number of search results to return.
|
||||
- Range: 1–50
|
||||
- Default: 10
|
||||
The actual number of returned results may be less than the
|
||||
specified count.
|
||||
"""
|
||||
logger.info(f"web_searcher - search_from_bocha: {query}")
|
||||
cfg = self.context.get_config(umo=event.unified_msg_origin)
|
||||
# websearch_link = cfg["provider_settings"].get("web_search_link", False)
|
||||
if not cfg.get("provider_settings", {}).get("websearch_bocha_key", []):
|
||||
raise ValueError("Error: BoCha API key is not configured in AstrBot.")
|
||||
|
||||
# build payload
|
||||
payload = {
|
||||
"query": query,
|
||||
"count": count,
|
||||
}
|
||||
|
||||
# freshness:时间范围
|
||||
if freshness:
|
||||
payload["freshness"] = freshness
|
||||
|
||||
# 是否返回摘要
|
||||
payload["summary"] = summary
|
||||
|
||||
# include:限制搜索域
|
||||
if include:
|
||||
payload["include"] = include
|
||||
|
||||
# exclude:排除搜索域
|
||||
if exclude:
|
||||
payload["exclude"] = exclude
|
||||
|
||||
results = await self._web_search_bocha(cfg, payload)
|
||||
if not results:
|
||||
return "Error: BoCha web searcher does not return any results."
|
||||
|
||||
ret_ls = []
|
||||
ref_uuid = str(uuid.uuid4())[:4]
|
||||
for idx, result in enumerate(results, 1):
|
||||
index = f"{ref_uuid}.{idx}"
|
||||
ret_ls.append(
|
||||
{
|
||||
"title": f"{result.title}",
|
||||
"url": f"{result.url}",
|
||||
"snippet": f"{result.snippet}",
|
||||
"index": index,
|
||||
}
|
||||
)
|
||||
if result.favicon:
|
||||
sp.temporary_cache["_ws_favicon"][result.url] = result.favicon
|
||||
# ret = "\n".join(ret_ls)
|
||||
ret = json.dumps({"results": ret_ls}, ensure_ascii=False)
|
||||
return ret
|
||||
|
||||
@filter.on_llm_request(priority=-10000)
|
||||
async def edit_web_search_tools(
|
||||
self,
|
||||
@@ -419,6 +585,7 @@ class Main(star.Star):
|
||||
tool_set.remove_tool("web_search_tavily")
|
||||
tool_set.remove_tool("tavily_extract_web_page")
|
||||
tool_set.remove_tool("AIsearch")
|
||||
tool_set.remove_tool("web_search_bocha")
|
||||
elif provider == "tavily":
|
||||
web_search_tavily = func_tool_mgr.get_func("web_search_tavily")
|
||||
tavily_extract_web_page = func_tool_mgr.get_func("tavily_extract_web_page")
|
||||
@@ -429,6 +596,7 @@ class Main(star.Star):
|
||||
tool_set.remove_tool("web_search")
|
||||
tool_set.remove_tool("fetch_url")
|
||||
tool_set.remove_tool("AIsearch")
|
||||
tool_set.remove_tool("web_search_bocha")
|
||||
elif provider == "baidu_ai_search":
|
||||
try:
|
||||
await self.ensure_baidu_ai_search_mcp(event.unified_msg_origin)
|
||||
@@ -440,5 +608,15 @@ class Main(star.Star):
|
||||
tool_set.remove_tool("fetch_url")
|
||||
tool_set.remove_tool("web_search_tavily")
|
||||
tool_set.remove_tool("tavily_extract_web_page")
|
||||
tool_set.remove_tool("web_search_bocha")
|
||||
except Exception as e:
|
||||
logger.error(f"Cannot Initialize Baidu AI Search MCP Server: {e}")
|
||||
elif provider == "bocha":
|
||||
web_search_bocha = func_tool_mgr.get_func("web_search_bocha")
|
||||
if web_search_bocha:
|
||||
tool_set.add_tool(web_search_bocha)
|
||||
tool_set.remove_tool("web_search")
|
||||
tool_set.remove_tool("fetch_url")
|
||||
tool_set.remove_tool("AIsearch")
|
||||
tool_set.remove_tool("web_search_tavily")
|
||||
tool_set.remove_tool("tavily_extract_web_page")
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "4.13.2"
|
||||
__version__ = "4.14.7"
|
||||
|
||||
@@ -3,7 +3,13 @@
|
||||
|
||||
from typing import Any, ClassVar, Literal, cast
|
||||
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler, model_serializer, model_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
GetCoreSchemaHandler,
|
||||
PrivateAttr,
|
||||
model_serializer,
|
||||
model_validator,
|
||||
)
|
||||
from pydantic_core import core_schema
|
||||
|
||||
|
||||
@@ -178,6 +184,8 @@ class Message(BaseModel):
|
||||
tool_call_id: str | None = None
|
||||
"""The ID of the tool call."""
|
||||
|
||||
_no_save: bool = PrivateAttr(default=False)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_content_required(self):
|
||||
# assistant + tool_calls is not None: allow content to be None
|
||||
|
||||
@@ -3,6 +3,7 @@ import sys
|
||||
import time
|
||||
import traceback
|
||||
import typing as T
|
||||
from dataclasses import dataclass
|
||||
|
||||
from mcp.types import (
|
||||
BlobResourceContents,
|
||||
@@ -14,8 +15,9 @@ from mcp.types import (
|
||||
)
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.agent.message import TextPart, ThinkPart
|
||||
from astrbot.core.agent.message import ImageURLPart, TextPart, ThinkPart
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.agent.tool_image_cache import tool_image_cache
|
||||
from astrbot.core.message.components import Json
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageChain,
|
||||
@@ -44,6 +46,28 @@ else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _HandleFunctionToolsResult:
|
||||
kind: T.Literal["message_chain", "tool_call_result_blocks", "cached_image"]
|
||||
message_chain: MessageChain | None = None
|
||||
tool_call_result_blocks: list[ToolCallMessageSegment] | None = None
|
||||
cached_image: T.Any = None
|
||||
|
||||
@classmethod
|
||||
def from_message_chain(cls, chain: MessageChain) -> "_HandleFunctionToolsResult":
|
||||
return cls(kind="message_chain", message_chain=chain)
|
||||
|
||||
@classmethod
|
||||
def from_tool_call_result_blocks(
|
||||
cls, blocks: list[ToolCallMessageSegment]
|
||||
) -> "_HandleFunctionToolsResult":
|
||||
return cls(kind="tool_call_result_blocks", tool_call_result_blocks=blocks)
|
||||
|
||||
@classmethod
|
||||
def from_cached_image(cls, image: T.Any) -> "_HandleFunctionToolsResult":
|
||||
return cls(kind="cached_image", cached_image=image)
|
||||
|
||||
|
||||
class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
@override
|
||||
async def reset(
|
||||
@@ -125,7 +149,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
messages = []
|
||||
# append existing messages in the run context
|
||||
for msg in request.contexts:
|
||||
messages.append(Message.model_validate(msg))
|
||||
m = Message.model_validate(msg)
|
||||
if isinstance(msg, dict) and msg.get("_no_save"):
|
||||
m._no_save = True
|
||||
messages.append(m)
|
||||
if request.prompt is not None:
|
||||
m = await request.assemble_context()
|
||||
messages.append(Message.model_validate(m))
|
||||
@@ -213,6 +240,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
if not llm_response.is_chunk and llm_response.usage:
|
||||
# only count the token usage of the final response for computation purpose
|
||||
self.stats.token_usage += llm_response.usage
|
||||
if self.req.conversation:
|
||||
self.req.conversation.token_usage = llm_response.usage.total
|
||||
break # got final response
|
||||
|
||||
if not llm_resp_result:
|
||||
@@ -252,6 +281,10 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
)
|
||||
if llm_resp.completion_text:
|
||||
parts.append(TextPart(text=llm_resp.completion_text))
|
||||
if len(parts) == 0:
|
||||
logger.warning(
|
||||
"LLM returned empty assistant message with no tool calls."
|
||||
)
|
||||
self.run_context.messages.append(Message(role="assistant", content=parts))
|
||||
|
||||
# call the on_agent_done hook
|
||||
@@ -280,20 +313,27 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
llm_resp, _ = await self._resolve_tool_exec(llm_resp)
|
||||
|
||||
tool_call_result_blocks = []
|
||||
cached_images = [] # Collect cached images for LLM visibility
|
||||
async for result in self._handle_function_tools(self.req, llm_resp):
|
||||
if isinstance(result, list):
|
||||
tool_call_result_blocks = result
|
||||
elif isinstance(result, MessageChain):
|
||||
if result.type is None:
|
||||
if result.kind == "tool_call_result_blocks":
|
||||
if result.tool_call_result_blocks is not None:
|
||||
tool_call_result_blocks = result.tool_call_result_blocks
|
||||
elif result.kind == "cached_image":
|
||||
if result.cached_image is not None:
|
||||
# Collect cached image info
|
||||
cached_images.append(result.cached_image)
|
||||
elif result.kind == "message_chain":
|
||||
chain = result.message_chain
|
||||
if chain is None or chain.type is None:
|
||||
# should not happen
|
||||
continue
|
||||
if result.type == "tool_direct_result":
|
||||
if chain.type == "tool_direct_result":
|
||||
ar_type = "tool_call_result"
|
||||
else:
|
||||
ar_type = result.type
|
||||
ar_type = chain.type
|
||||
yield AgentResponse(
|
||||
type=ar_type,
|
||||
data=AgentResponseData(chain=result),
|
||||
data=AgentResponseData(chain=chain),
|
||||
)
|
||||
|
||||
# 将结果添加到上下文中
|
||||
@@ -307,6 +347,8 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
)
|
||||
if llm_resp.completion_text:
|
||||
parts.append(TextPart(text=llm_resp.completion_text))
|
||||
if len(parts) == 0:
|
||||
parts = None
|
||||
tool_calls_result = ToolCallsResult(
|
||||
tool_calls_info=AssistantMessageSegment(
|
||||
tool_calls=llm_resp.to_openai_to_calls_model(),
|
||||
@@ -319,6 +361,41 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
tool_calls_result.to_openai_messages_model()
|
||||
)
|
||||
|
||||
# If there are cached images and the model supports image input,
|
||||
# append a user message with images so LLM can see them
|
||||
if cached_images:
|
||||
modalities = self.provider.provider_config.get("modalities", [])
|
||||
supports_image = "image" in modalities
|
||||
if supports_image:
|
||||
# Build user message with images for LLM to review
|
||||
image_parts = []
|
||||
for cached_img in cached_images:
|
||||
img_data = tool_image_cache.get_image_base64_by_path(
|
||||
cached_img.file_path, cached_img.mime_type
|
||||
)
|
||||
if img_data:
|
||||
base64_data, mime_type = img_data
|
||||
image_parts.append(
|
||||
TextPart(
|
||||
text=f"[Image from tool '{cached_img.tool_name}', path='{cached_img.file_path}']"
|
||||
)
|
||||
)
|
||||
image_parts.append(
|
||||
ImageURLPart(
|
||||
image_url=ImageURLPart.ImageURL(
|
||||
url=f"data:{mime_type};base64,{base64_data}",
|
||||
id=cached_img.file_path,
|
||||
)
|
||||
)
|
||||
)
|
||||
if image_parts:
|
||||
self.run_context.messages.append(
|
||||
Message(role="user", content=image_parts)
|
||||
)
|
||||
logger.debug(
|
||||
f"Appended {len(cached_images)} cached image(s) to context for LLM review"
|
||||
)
|
||||
|
||||
self.req.append_tool_calls_result(tool_calls_result)
|
||||
|
||||
async def step_until_done(
|
||||
@@ -354,7 +431,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
self,
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse,
|
||||
) -> T.AsyncGenerator[MessageChain | list[ToolCallMessageSegment], None]:
|
||||
) -> T.AsyncGenerator[_HandleFunctionToolsResult, None]:
|
||||
"""处理函数工具调用。"""
|
||||
tool_call_result_blocks: list[ToolCallMessageSegment] = []
|
||||
logger.info(f"Agent 使用工具: {llm_response.tools_call_name}")
|
||||
@@ -365,18 +442,20 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
llm_response.tools_call_args,
|
||||
llm_response.tools_call_ids,
|
||||
):
|
||||
yield MessageChain(
|
||||
type="tool_call",
|
||||
chain=[
|
||||
Json(
|
||||
data={
|
||||
"id": func_tool_id,
|
||||
"name": func_tool_name,
|
||||
"args": func_tool_args,
|
||||
"ts": time.time(),
|
||||
}
|
||||
)
|
||||
],
|
||||
yield _HandleFunctionToolsResult.from_message_chain(
|
||||
MessageChain(
|
||||
type="tool_call",
|
||||
chain=[
|
||||
Json(
|
||||
data={
|
||||
"id": func_tool_id,
|
||||
"name": func_tool_name,
|
||||
"args": func_tool_args,
|
||||
"ts": time.time(),
|
||||
}
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
try:
|
||||
if not req.func_tool:
|
||||
@@ -462,15 +541,28 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
),
|
||||
)
|
||||
elif isinstance(res.content[0], ImageContent):
|
||||
# Cache the image instead of sending directly
|
||||
cached_img = tool_image_cache.save_image(
|
||||
base64_data=res.content[0].data,
|
||||
tool_call_id=func_tool_id,
|
||||
tool_name=func_tool_name,
|
||||
index=0,
|
||||
mime_type=res.content[0].mimeType or "image/png",
|
||||
)
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="The tool has successfully returned an image and sent directly to the user. You can describe it in your next response.",
|
||||
content=(
|
||||
f"Image returned and cached at path='{cached_img.file_path}'. "
|
||||
f"Review the image below. Use send_message_to_user to send it to the user if satisfied, "
|
||||
f"with type='image' and path='{cached_img.file_path}'."
|
||||
),
|
||||
),
|
||||
)
|
||||
yield MessageChain(type="tool_direct_result").base64_image(
|
||||
res.content[0].data,
|
||||
# Yield image info for LLM visibility (will be handled in step())
|
||||
yield _HandleFunctionToolsResult.from_cached_image(
|
||||
cached_img
|
||||
)
|
||||
elif isinstance(res.content[0], EmbeddedResource):
|
||||
resource = res.content[0].resource
|
||||
@@ -487,16 +579,29 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
and resource.mimeType
|
||||
and resource.mimeType.startswith("image/")
|
||||
):
|
||||
# Cache the image instead of sending directly
|
||||
cached_img = tool_image_cache.save_image(
|
||||
base64_data=resource.blob,
|
||||
tool_call_id=func_tool_id,
|
||||
tool_name=func_tool_name,
|
||||
index=0,
|
||||
mime_type=resource.mimeType,
|
||||
)
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="The tool has successfully returned an image and sent directly to the user. You can describe it in your next response.",
|
||||
content=(
|
||||
f"Image returned and cached at path='{cached_img.file_path}'. "
|
||||
f"Review the image below. Use send_message_to_user to send it to the user if satisfied, "
|
||||
f"with type='image' and path='{cached_img.file_path}'."
|
||||
),
|
||||
),
|
||||
)
|
||||
yield MessageChain(
|
||||
type="tool_direct_result",
|
||||
).base64_image(resource.blob)
|
||||
# Yield image info for LLM visibility
|
||||
yield _HandleFunctionToolsResult.from_cached_image(
|
||||
cached_img
|
||||
)
|
||||
else:
|
||||
tool_call_result_blocks.append(
|
||||
ToolCallMessageSegment(
|
||||
@@ -557,23 +662,27 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]):
|
||||
# yield the last tool call result
|
||||
if tool_call_result_blocks:
|
||||
last_tcr_content = str(tool_call_result_blocks[-1].content)
|
||||
yield MessageChain(
|
||||
type="tool_call_result",
|
||||
chain=[
|
||||
Json(
|
||||
data={
|
||||
"id": func_tool_id,
|
||||
"ts": time.time(),
|
||||
"result": last_tcr_content,
|
||||
}
|
||||
)
|
||||
],
|
||||
yield _HandleFunctionToolsResult.from_message_chain(
|
||||
MessageChain(
|
||||
type="tool_call_result",
|
||||
chain=[
|
||||
Json(
|
||||
data={
|
||||
"id": func_tool_id,
|
||||
"ts": time.time(),
|
||||
"result": last_tcr_content,
|
||||
}
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
logger.info(f"Tool `{func_tool_name}` Result: {last_tcr_content}")
|
||||
|
||||
# 处理函数调用响应
|
||||
if tool_call_result_blocks:
|
||||
yield tool_call_result_blocks
|
||||
yield _HandleFunctionToolsResult.from_tool_call_result_blocks(
|
||||
tool_call_result_blocks
|
||||
)
|
||||
|
||||
def _build_tool_requery_context(
|
||||
self, tool_names: list[str]
|
||||
|
||||
@@ -246,8 +246,18 @@ class ToolSet:
|
||||
|
||||
result = {}
|
||||
|
||||
if "type" in schema and schema["type"] in supported_types:
|
||||
result["type"] = schema["type"]
|
||||
# Avoid side effects by not modifying the original schema
|
||||
origin_type = schema.get("type")
|
||||
target_type = origin_type
|
||||
|
||||
# Compatibility fix: Gemini API expects 'type' to be a string (enum),
|
||||
# but standard JSON Schema (MCP) allows lists (e.g. ["string", "null"]).
|
||||
# We fallback to the first non-null type.
|
||||
if isinstance(origin_type, list):
|
||||
target_type = next((t for t in origin_type if t != "null"), "string")
|
||||
|
||||
if target_type in supported_types:
|
||||
result["type"] = target_type
|
||||
if "format" in schema and schema["format"] in supported_formats.get(
|
||||
result["type"],
|
||||
set(),
|
||||
|
||||
@@ -0,0 +1,162 @@
|
||||
"""Tool image cache module for storing and retrieving images returned by tools.
|
||||
|
||||
This module allows LLM to review images before deciding whether to send them to users.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import ClassVar
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_temp_path
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedImage:
|
||||
"""Represents a cached image from a tool call."""
|
||||
|
||||
tool_call_id: str
|
||||
"""The tool call ID that produced this image."""
|
||||
tool_name: str
|
||||
"""The name of the tool that produced this image."""
|
||||
file_path: str
|
||||
"""The file path where the image is stored."""
|
||||
mime_type: str
|
||||
"""The MIME type of the image."""
|
||||
created_at: float = field(default_factory=time.time)
|
||||
"""Timestamp when the image was cached."""
|
||||
|
||||
|
||||
class ToolImageCache:
|
||||
"""Manages cached images from tool calls.
|
||||
|
||||
Images are stored in data/temp/tool_images/ and can be retrieved by file path.
|
||||
"""
|
||||
|
||||
_instance: ClassVar["ToolImageCache | None"] = None
|
||||
CACHE_DIR_NAME: ClassVar[str] = "tool_images"
|
||||
# Cache expiry time in seconds (1 hour)
|
||||
CACHE_EXPIRY: ClassVar[int] = 3600
|
||||
|
||||
def __new__(cls) -> "ToolImageCache":
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._initialized = False
|
||||
return cls._instance
|
||||
|
||||
def __init__(self) -> None:
|
||||
if self._initialized:
|
||||
return
|
||||
self._initialized = True
|
||||
self._cache_dir = os.path.join(get_astrbot_temp_path(), self.CACHE_DIR_NAME)
|
||||
os.makedirs(self._cache_dir, exist_ok=True)
|
||||
logger.debug(f"ToolImageCache initialized, cache dir: {self._cache_dir}")
|
||||
|
||||
def _get_file_extension(self, mime_type: str) -> str:
|
||||
"""Get file extension from MIME type."""
|
||||
mime_to_ext = {
|
||||
"image/png": ".png",
|
||||
"image/jpeg": ".jpg",
|
||||
"image/jpg": ".jpg",
|
||||
"image/gif": ".gif",
|
||||
"image/webp": ".webp",
|
||||
"image/bmp": ".bmp",
|
||||
"image/svg+xml": ".svg",
|
||||
}
|
||||
return mime_to_ext.get(mime_type.lower(), ".png")
|
||||
|
||||
def save_image(
|
||||
self,
|
||||
base64_data: str,
|
||||
tool_call_id: str,
|
||||
tool_name: str,
|
||||
index: int = 0,
|
||||
mime_type: str = "image/png",
|
||||
) -> CachedImage:
|
||||
"""Save an image to cache and return the cached image info.
|
||||
|
||||
Args:
|
||||
base64_data: Base64 encoded image data.
|
||||
tool_call_id: The tool call ID that produced this image.
|
||||
tool_name: The name of the tool that produced this image.
|
||||
index: The index of the image (for multiple images from same tool call).
|
||||
mime_type: The MIME type of the image.
|
||||
|
||||
Returns:
|
||||
CachedImage object with file path.
|
||||
"""
|
||||
ext = self._get_file_extension(mime_type)
|
||||
file_name = f"{tool_call_id}_{index}{ext}"
|
||||
file_path = os.path.join(self._cache_dir, file_name)
|
||||
|
||||
# Decode and save the image
|
||||
try:
|
||||
image_bytes = base64.b64decode(base64_data)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
logger.debug(f"Saved tool image to: {file_path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to save tool image: {e}")
|
||||
raise
|
||||
|
||||
return CachedImage(
|
||||
tool_call_id=tool_call_id,
|
||||
tool_name=tool_name,
|
||||
file_path=file_path,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
|
||||
def get_image_base64_by_path(
|
||||
self, file_path: str, mime_type: str = "image/png"
|
||||
) -> tuple[str, str] | None:
|
||||
"""Read an image file and return its base64 encoded data.
|
||||
|
||||
Args:
|
||||
file_path: The file path of the cached image.
|
||||
mime_type: The MIME type of the image.
|
||||
|
||||
Returns:
|
||||
Tuple of (base64_data, mime_type) if found, None otherwise.
|
||||
"""
|
||||
if not os.path.exists(file_path):
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
image_bytes = f.read()
|
||||
base64_data = base64.b64encode(image_bytes).decode("utf-8")
|
||||
return base64_data, mime_type
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to read cached image {file_path}: {e}")
|
||||
return None
|
||||
|
||||
def cleanup_expired(self) -> int:
|
||||
"""Clean up expired cached images.
|
||||
|
||||
Returns:
|
||||
Number of images cleaned up.
|
||||
"""
|
||||
now = time.time()
|
||||
cleaned = 0
|
||||
|
||||
try:
|
||||
for file_name in os.listdir(self._cache_dir):
|
||||
file_path = os.path.join(self._cache_dir, file_name)
|
||||
if os.path.isfile(file_path):
|
||||
file_age = now - os.path.getmtime(file_path)
|
||||
if file_age > self.CACHE_EXPIRY:
|
||||
os.remove(file_path)
|
||||
cleaned += 1
|
||||
except Exception as e:
|
||||
logger.warning(f"Error during cache cleanup: {e}")
|
||||
|
||||
if cleaned:
|
||||
logger.info(f"Cleaned up {cleaned} expired cached images")
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
tool_image_cache = ToolImageCache()
|
||||
@@ -59,7 +59,7 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
|
||||
platform_name = run_context.context.event.get_platform_name()
|
||||
if (
|
||||
platform_name == "webchat"
|
||||
and tool.name == "web_search_tavily"
|
||||
and tool.name in ["web_search_tavily", "web_search_bocha"]
|
||||
and len(run_context.messages) > 0
|
||||
and tool_result
|
||||
and len(tool_result.content)
|
||||
|
||||
@@ -54,6 +54,14 @@ async def run_agent(
|
||||
return
|
||||
if resp.type == "tool_call_result":
|
||||
msg_chain = resp.data["chain"]
|
||||
|
||||
astr_event.trace.record(
|
||||
"agent_tool_result",
|
||||
tool_result=msg_chain.get_plain_text(
|
||||
with_other_comps_mark=True
|
||||
),
|
||||
)
|
||||
|
||||
if msg_chain.type == "tool_direct_result":
|
||||
# tool_direct_result 用于标记 llm tool 需要直接发送给用户的内容
|
||||
await astr_event.send(msg_chain)
|
||||
@@ -67,12 +75,22 @@ async def run_agent(
|
||||
# 用来标记流式响应需要分节
|
||||
yield MessageChain(chain=[], type="break")
|
||||
|
||||
tool_info = None
|
||||
|
||||
if resp.data["chain"].chain:
|
||||
json_comp = resp.data["chain"].chain[0]
|
||||
if isinstance(json_comp, Json):
|
||||
tool_info = json_comp.data
|
||||
astr_event.trace.record(
|
||||
"agent_tool_call",
|
||||
tool_name=tool_info if tool_info else "unknown",
|
||||
)
|
||||
|
||||
if astr_event.get_platform_name() == "webchat":
|
||||
await astr_event.send(resp.data["chain"])
|
||||
elif show_tool_use:
|
||||
json_comp = resp.data["chain"].chain[0]
|
||||
if isinstance(json_comp, Json):
|
||||
m = f"🔨 调用工具: {json_comp.data.get('name')}"
|
||||
if tool_info:
|
||||
m = f"🔨 调用工具: {tool_info.get('name', 'unknown')}"
|
||||
else:
|
||||
m = "🔨 调用工具..."
|
||||
chain = MessageChain(type="tool_call").message(m)
|
||||
|
||||
@@ -7,11 +7,13 @@ import datetime
|
||||
import json
|
||||
import os
|
||||
import zoneinfo
|
||||
from collections.abc import Coroutine
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from astrbot.api import sp
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.agent.handoff import HandoffTool
|
||||
from astrbot.core.agent.mcp_client import MCPTool
|
||||
from astrbot.core.agent.message import TextPart
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.astr_agent_context import AgentContextWrapper, AstrAgentContext
|
||||
@@ -19,7 +21,6 @@ from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS
|
||||
from astrbot.core.astr_agent_run_util import AgentRunner
|
||||
from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor
|
||||
from astrbot.core.astr_main_agent_resources import (
|
||||
CHATUI_EXTRA_PROMPT,
|
||||
CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT,
|
||||
EXECUTE_SHELL_TOOL,
|
||||
FILE_DOWNLOAD_TOOL,
|
||||
@@ -99,6 +100,8 @@ class MainAgentBuildConfig:
|
||||
"""This will inject healthy and safe system prompt into the main agent,
|
||||
to prevent LLM output harmful information"""
|
||||
safety_mode_strategy: str = "system_prompt"
|
||||
computer_use_runtime: str = "local"
|
||||
"""The runtime for agent computer use: none, local, or sandbox."""
|
||||
sandbox_cfg: dict = field(default_factory=dict)
|
||||
add_cron_tools: bool = True
|
||||
"""This will add cron job management tools to the main agent for proactive cron job execution."""
|
||||
@@ -112,6 +115,7 @@ class MainAgentBuildResult:
|
||||
agent_runner: AgentRunner
|
||||
provider_request: ProviderRequest
|
||||
provider: Provider
|
||||
reset_coro: Coroutine | None = None
|
||||
|
||||
|
||||
def _select_provider(
|
||||
@@ -259,6 +263,8 @@ async def _ensure_persona_and_skills(
|
||||
return
|
||||
|
||||
# get persona ID
|
||||
|
||||
# 1. from session service config - highest priority
|
||||
persona_id = (
|
||||
await sp.get_async(
|
||||
scope="umo",
|
||||
@@ -269,14 +275,15 @@ async def _ensure_persona_and_skills(
|
||||
).get("persona_id")
|
||||
|
||||
if not persona_id:
|
||||
persona_id = req.conversation.persona_id or cfg.get("default_personality")
|
||||
if persona_id is None or persona_id != "[%None]":
|
||||
default_persona = plugin_context.persona_manager.selected_default_persona_v3
|
||||
if default_persona:
|
||||
persona_id = default_persona["name"]
|
||||
if event.get_platform_name() == "webchat":
|
||||
persona_id = "_chatui_default_"
|
||||
req.system_prompt += CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT
|
||||
# 2. from conversation setting - second priority
|
||||
persona_id = req.conversation.persona_id
|
||||
|
||||
if persona_id == "[%None]":
|
||||
# explicitly set to no persona
|
||||
pass
|
||||
elif persona_id is None:
|
||||
# 3. from config default persona setting - last priority
|
||||
persona_id = cfg.get("default_personality")
|
||||
|
||||
persona = next(
|
||||
builtins.filter(
|
||||
@@ -291,23 +298,18 @@ async def _ensure_persona_and_skills(
|
||||
req.system_prompt += f"\n# Persona Instructions\n\n{prompt}\n"
|
||||
if begin_dialogs := copy.deepcopy(persona.get("_begin_dialogs_processed")):
|
||||
req.contexts[:0] = begin_dialogs
|
||||
else:
|
||||
# special handling for webchat persona
|
||||
if event.get_platform_name() == "webchat" and persona_id != "[%None]":
|
||||
persona_id = "_chatui_default_"
|
||||
req.system_prompt += CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT
|
||||
|
||||
# Inject skills prompt
|
||||
skills_cfg = cfg.get("skills", {})
|
||||
sandbox_cfg = cfg.get("sandbox", {})
|
||||
runtime = cfg.get("computer_use_runtime", "local")
|
||||
skill_manager = SkillManager()
|
||||
runtime = skills_cfg.get("runtime", "local")
|
||||
skills = skill_manager.list_skills(active_only=True, runtime=runtime)
|
||||
|
||||
if runtime == "sandbox" and not sandbox_cfg.get("enable", False):
|
||||
logger.warning(
|
||||
"Skills runtime is set to sandbox, but sandbox mode is disabled, will skip skills prompt injection.",
|
||||
)
|
||||
req.system_prompt += (
|
||||
"\n[Background: User added some skills, and skills runtime is set to sandbox, "
|
||||
"but sandbox mode is disabled. So skills will be unavailable.]\n"
|
||||
)
|
||||
elif skills:
|
||||
if skills:
|
||||
if persona and persona.get("skills") is not None:
|
||||
if not persona["skills"]:
|
||||
skills = []
|
||||
@@ -316,12 +318,12 @@ async def _ensure_persona_and_skills(
|
||||
skills = [skill for skill in skills if skill.name in allowed]
|
||||
if skills:
|
||||
req.system_prompt += f"\n{build_skills_prompt(skills)}\n"
|
||||
|
||||
runtime = skills_cfg.get("runtime", "local")
|
||||
sandbox_enabled = sandbox_cfg.get("enable", False)
|
||||
if runtime == "local" and not sandbox_enabled:
|
||||
_apply_local_env_tools(req)
|
||||
|
||||
if runtime == "none":
|
||||
req.system_prompt += (
|
||||
"User has not enabled the Computer Use feature. "
|
||||
"You cannot use shell or Python to perform skills. "
|
||||
"If you need to use these capabilities, ask the user to enable Computer Use in the AstrBot WebUI -> Config."
|
||||
)
|
||||
tmgr = plugin_context.get_llm_tool_manager()
|
||||
|
||||
# sub agents integration
|
||||
@@ -708,9 +710,18 @@ def _sanitize_context_by_modalities(
|
||||
|
||||
|
||||
def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None:
|
||||
"""根据事件中的插件设置,过滤请求中的工具列表。
|
||||
|
||||
注意:没有 handler_module_path 的工具(如 MCP 工具)会被保留,
|
||||
因为它们不属于任何插件,不应被插件过滤逻辑影响。
|
||||
"""
|
||||
if event.plugins_name is not None and req.func_tool:
|
||||
new_tool_set = ToolSet()
|
||||
for tool in req.func_tool.tools:
|
||||
if isinstance(tool, MCPTool):
|
||||
# 保留 MCP 工具
|
||||
new_tool_set.add_tool(tool)
|
||||
continue
|
||||
mp = tool.handler_module_path
|
||||
if not mp:
|
||||
continue
|
||||
@@ -828,8 +839,12 @@ async def build_main_agent(
|
||||
config: MainAgentBuildConfig,
|
||||
provider: Provider | None = None,
|
||||
req: ProviderRequest | None = None,
|
||||
apply_reset: bool = True,
|
||||
) -> MainAgentBuildResult | None:
|
||||
"""构建主对话代理(Main Agent),并且自动 reset。"""
|
||||
"""构建主对话代理(Main Agent),并且自动 reset。
|
||||
|
||||
If apply_reset is False, will not call reset on the agent runner.
|
||||
"""
|
||||
provider = provider or _select_provider(event, plugin_context)
|
||||
if provider is None:
|
||||
logger.info("未找到任何对话模型(提供商),跳过 LLM 请求处理。")
|
||||
@@ -905,8 +920,10 @@ async def build_main_agent(
|
||||
if config.llm_safety_mode:
|
||||
_apply_llm_safety_mode(config, req)
|
||||
|
||||
if config.sandbox_cfg.get("enable", False):
|
||||
if config.computer_use_runtime == "sandbox":
|
||||
_apply_sandbox_tools(config, req, req.session_id)
|
||||
elif config.computer_use_runtime == "local":
|
||||
_apply_local_env_tools(req)
|
||||
|
||||
agent_runner = AgentRunner()
|
||||
astr_agent_ctx = AstrAgentContext(
|
||||
@@ -931,7 +948,6 @@ async def build_main_agent(
|
||||
|
||||
if event.get_platform_name() == "webchat":
|
||||
asyncio.create_task(_handle_webchat(event, req, provider))
|
||||
req.system_prompt += f"\n{CHATUI_EXTRA_PROMPT}\n"
|
||||
|
||||
if req.func_tool and req.func_tool.tools:
|
||||
tool_prompt = (
|
||||
@@ -945,7 +961,7 @@ async def build_main_agent(
|
||||
if action_type == "live":
|
||||
req.system_prompt += f"\n{LIVE_MODE_SYSTEM_PROMPT}\n"
|
||||
|
||||
await agent_runner.reset(
|
||||
reset_coro = agent_runner.reset(
|
||||
provider=provider,
|
||||
request=req,
|
||||
run_context=AgentContextWrapper(
|
||||
@@ -963,8 +979,12 @@ async def build_main_agent(
|
||||
tool_schema_mode=config.tool_schema_mode,
|
||||
)
|
||||
|
||||
if apply_reset:
|
||||
await reset_coro
|
||||
|
||||
return MainAgentBuildResult(
|
||||
agent_runner=agent_runner,
|
||||
provider_request=req,
|
||||
provider=provider,
|
||||
reset_coro=reset_coro if not apply_reset else None,
|
||||
)
|
||||
|
||||
@@ -78,9 +78,6 @@ CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT = (
|
||||
"You listen more than you speak, respect uncertainty, avoid forcing quick conclusions or grand narratives, "
|
||||
"and prefer clear, restrained language over unnecessary emotional embellishment. At your core, you value "
|
||||
"empathy, clarity, autonomy, and meaning, favoring steady, sustainable progress over judgment or dramatic leaps."
|
||||
)
|
||||
|
||||
CHATUI_EXTRA_PROMPT = (
|
||||
'When you answered, you need to add a follow up question / summarization but do not add "Follow up" words. '
|
||||
"Such as, user asked you to generate codes, you can add: Do you need me to run these codes for you?"
|
||||
)
|
||||
|
||||
@@ -35,12 +35,21 @@ async def _sync_skills_to_sandbox(booter: ComputerBooter) -> None:
|
||||
os.remove(zip_path)
|
||||
shutil.make_archive(zip_base, "zip", skills_root)
|
||||
remote_zip = Path(SANDBOX_SKILLS_ROOT) / "skills.zip"
|
||||
logger.info("Uploading skills bundle to sandbox...")
|
||||
await booter.shell.exec(f"mkdir -p {SANDBOX_SKILLS_ROOT}")
|
||||
upload_result = await booter.upload_file(zip_path, str(remote_zip))
|
||||
if not upload_result.get("success", False):
|
||||
raise RuntimeError("Failed to upload skills bundle to sandbox.")
|
||||
# Use -n flag to never overwrite existing files, fallback to Python if unzip unavailable
|
||||
await booter.shell.exec(
|
||||
f"unzip -o {remote_zip} -d {SANDBOX_SKILLS_ROOT} && rm -f {remote_zip}"
|
||||
f"unzip -n {remote_zip} -d {SANDBOX_SKILLS_ROOT} || "
|
||||
f"python3 -c \"import zipfile, os, pathlib; z=zipfile.ZipFile('{remote_zip}'); "
|
||||
f"[z.extract(m, '{SANDBOX_SKILLS_ROOT}') for m in z.namelist() "
|
||||
f"if not os.path.exists(os.path.join('{SANDBOX_SKILLS_ROOT}', m))]\" || "
|
||||
f"python -c \"import zipfile, os, pathlib; z=zipfile.ZipFile('{remote_zip}'); "
|
||||
f"[z.extract(m, '{SANDBOX_SKILLS_ROOT}') for m in z.namelist() "
|
||||
f"if not os.path.exists(os.path.join('{SANDBOX_SKILLS_ROOT}', m))]\"; "
|
||||
f"rm -f {remote_zip}"
|
||||
)
|
||||
finally:
|
||||
if os.path.exists(zip_path):
|
||||
|
||||
+126
-96
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
VERSION = "4.13.2"
|
||||
VERSION = "4.14.7"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
WEBHOOK_SUPPORTED_PLATFORMS = [
|
||||
@@ -74,6 +74,7 @@ DEFAULT_CONFIG = {
|
||||
"web_search": False,
|
||||
"websearch_provider": "default",
|
||||
"websearch_tavily_key": [],
|
||||
"websearch_bocha_key": [],
|
||||
"websearch_baidu_app_builder_key": "",
|
||||
"web_search_link": False,
|
||||
"display_reasoning_text": False,
|
||||
@@ -117,15 +118,14 @@ DEFAULT_CONFIG = {
|
||||
"proactive_capability": {
|
||||
"add_cron_tools": True,
|
||||
},
|
||||
"computer_use_runtime": "local",
|
||||
"sandbox": {
|
||||
"enable": False,
|
||||
"booter": "shipyard",
|
||||
"shipyard_endpoint": "",
|
||||
"shipyard_access_token": "",
|
||||
"shipyard_ttl": 3600,
|
||||
"shipyard_max_sessions": 10,
|
||||
},
|
||||
"skills": {"runtime": "sandbox"},
|
||||
},
|
||||
# SubAgent orchestrator mode:
|
||||
# - main_enable = False: disabled; main LLM mounts tools normally (persona selection).
|
||||
@@ -177,7 +177,7 @@ DEFAULT_CONFIG = {
|
||||
"t2i_use_file_service": False,
|
||||
"t2i_active_template": "base",
|
||||
"http_proxy": "",
|
||||
"no_proxy": ["localhost", "127.0.0.1", "::1"],
|
||||
"no_proxy": ["localhost", "127.0.0.1", "::1", "10.*", "192.168.*"],
|
||||
"dashboard": {
|
||||
"enable": True,
|
||||
"username": "astrbot",
|
||||
@@ -202,6 +202,7 @@ DEFAULT_CONFIG = {
|
||||
"log_file_enable": False,
|
||||
"log_file_path": "logs/astrbot.log",
|
||||
"log_file_max_mb": 20,
|
||||
"trace_enable": False,
|
||||
"trace_log_enable": False,
|
||||
"trace_log_path": "logs/astrbot.trace.log",
|
||||
"trace_log_max_mb": 20,
|
||||
@@ -912,6 +913,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.openai.com/v1",
|
||||
"timeout": 120,
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Google Gemini": {
|
||||
@@ -934,6 +936,7 @@ CONFIG_METADATA_2 = {
|
||||
"dangerous_content": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
},
|
||||
"gm_thinking_config": {"budget": 0, "level": "HIGH"},
|
||||
"proxy": "",
|
||||
},
|
||||
"Anthropic": {
|
||||
"id": "anthropic",
|
||||
@@ -944,6 +947,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.anthropic.com/v1",
|
||||
"timeout": 120,
|
||||
"proxy": "",
|
||||
"anth_thinking_config": {"budget": 0},
|
||||
},
|
||||
"Moonshot": {
|
||||
@@ -955,6 +959,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api.moonshot.cn/v1",
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"xAI": {
|
||||
@@ -966,6 +971,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.x.ai/v1",
|
||||
"timeout": 120,
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
"xai_native_search": False,
|
||||
},
|
||||
@@ -978,6 +984,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.deepseek.com/v1",
|
||||
"timeout": 120,
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Zhipu": {
|
||||
@@ -989,6 +996,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://open.bigmodel.cn/api/paas/v4/",
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Azure OpenAI": {
|
||||
@@ -1001,6 +1009,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "",
|
||||
"timeout": 120,
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Ollama": {
|
||||
@@ -1011,6 +1020,7 @@ CONFIG_METADATA_2 = {
|
||||
"enable": True,
|
||||
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
||||
"api_base": "http://127.0.0.1:11434/v1",
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"LM Studio": {
|
||||
@@ -1021,6 +1031,7 @@ CONFIG_METADATA_2 = {
|
||||
"enable": True,
|
||||
"key": ["lmstudio"],
|
||||
"api_base": "http://127.0.0.1:1234/v1",
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Gemini_OpenAI_API": {
|
||||
@@ -1032,6 +1043,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
"timeout": 120,
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Groq": {
|
||||
@@ -1043,6 +1055,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.groq.com/openai/v1",
|
||||
"timeout": 120,
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"302.AI": {
|
||||
@@ -1054,6 +1067,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.302.ai/v1",
|
||||
"timeout": 120,
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"SiliconFlow": {
|
||||
@@ -1065,6 +1079,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api.siliconflow.cn/v1",
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"PPIO": {
|
||||
@@ -1076,6 +1091,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.ppinfra.com/v3/openai",
|
||||
"timeout": 120,
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"TokenPony": {
|
||||
@@ -1087,6 +1103,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.tokenpony.cn/v1",
|
||||
"timeout": 120,
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Compshare": {
|
||||
@@ -1098,6 +1115,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.modelverse.cn/v1",
|
||||
"timeout": 120,
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"ModelScope": {
|
||||
@@ -1109,6 +1127,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"timeout": 120,
|
||||
"api_base": "https://api-inference.modelscope.cn/v1",
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
},
|
||||
"Dify": {
|
||||
@@ -1124,6 +1143,7 @@ CONFIG_METADATA_2 = {
|
||||
"dify_query_input_key": "astrbot_text_query",
|
||||
"variables": {},
|
||||
"timeout": 60,
|
||||
"proxy": "",
|
||||
},
|
||||
"Coze": {
|
||||
"id": "coze",
|
||||
@@ -1135,6 +1155,7 @@ CONFIG_METADATA_2 = {
|
||||
"bot_id": "",
|
||||
"coze_api_base": "https://api.coze.cn",
|
||||
"timeout": 60,
|
||||
"proxy": "",
|
||||
# "auto_save_history": True,
|
||||
},
|
||||
"阿里云百炼应用": {
|
||||
@@ -1153,6 +1174,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"variables": {},
|
||||
"timeout": 60,
|
||||
"proxy": "",
|
||||
},
|
||||
"FastGPT": {
|
||||
"id": "fastgpt",
|
||||
@@ -1163,6 +1185,7 @@ CONFIG_METADATA_2 = {
|
||||
"key": [],
|
||||
"api_base": "https://api.fastgpt.in/api/v1",
|
||||
"timeout": 60,
|
||||
"proxy": "",
|
||||
"custom_headers": {},
|
||||
"custom_extra_body": {},
|
||||
},
|
||||
@@ -1175,6 +1198,7 @@ CONFIG_METADATA_2 = {
|
||||
"api_key": "",
|
||||
"api_base": "",
|
||||
"model": "whisper-1",
|
||||
"proxy": "",
|
||||
},
|
||||
"Whisper(Local)": {
|
||||
"provider": "openai",
|
||||
@@ -1204,6 +1228,7 @@ CONFIG_METADATA_2 = {
|
||||
"model": "tts-1",
|
||||
"openai-tts-voice": "alloy",
|
||||
"timeout": "20",
|
||||
"proxy": "",
|
||||
},
|
||||
"Genie TTS": {
|
||||
"id": "genie_tts",
|
||||
@@ -1284,6 +1309,7 @@ CONFIG_METADATA_2 = {
|
||||
"fishaudio-tts-character": "可莉",
|
||||
"fishaudio-tts-reference-id": "",
|
||||
"timeout": "20",
|
||||
"proxy": "",
|
||||
},
|
||||
"阿里云百炼 TTS(API)": {
|
||||
"hint": "API Key 从 https://bailian.console.aliyun.com/?tab=model#/api-key 获取。模型和音色的选择文档请参考: 阿里云百炼语音合成音色名称。具体可参考 https://help.aliyun.com/zh/model-studio/speech-synthesis-and-speech-recognition",
|
||||
@@ -1310,6 +1336,7 @@ CONFIG_METADATA_2 = {
|
||||
"azure_tts_volume": "100",
|
||||
"azure_tts_subscription_key": "",
|
||||
"azure_tts_region": "eastus",
|
||||
"proxy": "",
|
||||
},
|
||||
"MiniMax TTS(API)": {
|
||||
"id": "minimax_tts",
|
||||
@@ -1332,6 +1359,7 @@ CONFIG_METADATA_2 = {
|
||||
"minimax-voice-latex": False,
|
||||
"minimax-voice-english-normalization": False,
|
||||
"timeout": 20,
|
||||
"proxy": "",
|
||||
},
|
||||
"火山引擎_TTS(API)": {
|
||||
"id": "volcengine_tts",
|
||||
@@ -1346,6 +1374,7 @@ CONFIG_METADATA_2 = {
|
||||
"volcengine_speed_ratio": 1.0,
|
||||
"api_base": "https://openspeech.bytedance.com/api/v1/tts",
|
||||
"timeout": 20,
|
||||
"proxy": "",
|
||||
},
|
||||
"Gemini TTS": {
|
||||
"id": "gemini_tts",
|
||||
@@ -1359,6 +1388,7 @@ CONFIG_METADATA_2 = {
|
||||
"gemini_tts_model": "gemini-2.5-flash-preview-tts",
|
||||
"gemini_tts_prefix": "",
|
||||
"gemini_tts_voice_name": "Leda",
|
||||
"proxy": "",
|
||||
},
|
||||
"OpenAI Embedding": {
|
||||
"id": "openai_embedding",
|
||||
@@ -1371,6 +1401,7 @@ CONFIG_METADATA_2 = {
|
||||
"embedding_model": "",
|
||||
"embedding_dimensions": 1024,
|
||||
"timeout": 20,
|
||||
"proxy": "",
|
||||
},
|
||||
"Gemini Embedding": {
|
||||
"id": "gemini_embedding",
|
||||
@@ -1383,6 +1414,7 @@ CONFIG_METADATA_2 = {
|
||||
"embedding_model": "gemini-embedding-exp-03-07",
|
||||
"embedding_dimensions": 768,
|
||||
"timeout": 20,
|
||||
"proxy": "",
|
||||
},
|
||||
"vLLM Rerank": {
|
||||
"id": "vllm_rerank",
|
||||
@@ -2079,6 +2111,11 @@ CONFIG_METADATA_2 = {
|
||||
"description": "API Base URL",
|
||||
"type": "string",
|
||||
},
|
||||
"proxy": {
|
||||
"description": "代理地址",
|
||||
"type": "string",
|
||||
"hint": "HTTP/HTTPS 代理地址,格式如 http://127.0.0.1:7890。仅对该提供商的 API 请求生效,不影响 Docker 内网通信。",
|
||||
},
|
||||
"model": {
|
||||
"description": "模型 ID",
|
||||
"type": "string",
|
||||
@@ -2224,17 +2261,6 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
},
|
||||
"skills": {
|
||||
"type": "object",
|
||||
"items": {
|
||||
"enable": {
|
||||
"type": "bool",
|
||||
},
|
||||
"runtime": {
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
"proactive_capability": {
|
||||
"type": "object",
|
||||
"items": {
|
||||
@@ -2515,6 +2541,7 @@ CONFIG_METADATA_3 = {
|
||||
},
|
||||
"persona": {
|
||||
"description": "人格",
|
||||
"hint": "",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"provider_settings.default_personality": {
|
||||
@@ -2530,6 +2557,7 @@ CONFIG_METADATA_3 = {
|
||||
},
|
||||
"knowledgebase": {
|
||||
"description": "知识库",
|
||||
"hint": "",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"kb_names": {
|
||||
@@ -2562,6 +2590,7 @@ CONFIG_METADATA_3 = {
|
||||
},
|
||||
"websearch": {
|
||||
"description": "网页搜索",
|
||||
"hint": "",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"provider_settings.web_search": {
|
||||
@@ -2571,7 +2600,10 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.websearch_provider": {
|
||||
"description": "网页搜索提供商",
|
||||
"type": "string",
|
||||
"options": ["default", "tavily", "baidu_ai_search"],
|
||||
"options": ["default", "tavily", "baidu_ai_search", "bocha"],
|
||||
"condition": {
|
||||
"provider_settings.web_search": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.websearch_tavily_key": {
|
||||
"description": "Tavily API Key",
|
||||
@@ -2580,6 +2612,17 @@ CONFIG_METADATA_3 = {
|
||||
"hint": "可添加多个 Key 进行轮询。",
|
||||
"condition": {
|
||||
"provider_settings.websearch_provider": "tavily",
|
||||
"provider_settings.web_search": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.websearch_bocha_key": {
|
||||
"description": "BoCha API Key",
|
||||
"type": "list",
|
||||
"items": {"type": "string"},
|
||||
"hint": "可添加多个 Key 进行轮询。",
|
||||
"condition": {
|
||||
"provider_settings.websearch_provider": "bocha",
|
||||
"provider_settings.web_search": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.websearch_baidu_app_builder_key": {
|
||||
@@ -2593,6 +2636,73 @@ CONFIG_METADATA_3 = {
|
||||
"provider_settings.web_search_link": {
|
||||
"description": "显示来源引用",
|
||||
"type": "bool",
|
||||
"condition": {
|
||||
"provider_settings.web_search": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"agent_computer_use": {
|
||||
"description": "Agent Computer Use",
|
||||
"hint": "",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"provider_settings.computer_use_runtime": {
|
||||
"description": "Computer Use Runtime",
|
||||
"type": "string",
|
||||
"options": ["none", "local", "sandbox"],
|
||||
"labels": ["无", "本地", "沙箱"],
|
||||
"hint": "选择 Computer Use 运行环境。",
|
||||
},
|
||||
"provider_settings.sandbox.booter": {
|
||||
"description": "沙箱环境驱动器",
|
||||
"type": "string",
|
||||
"options": ["shipyard"],
|
||||
"labels": ["Shipyard"],
|
||||
"condition": {
|
||||
"provider_settings.computer_use_runtime": "sandbox",
|
||||
},
|
||||
},
|
||||
"provider_settings.sandbox.shipyard_endpoint": {
|
||||
"description": "Shipyard API Endpoint",
|
||||
"type": "string",
|
||||
"hint": "Shipyard 服务的 API 访问地址。",
|
||||
"condition": {
|
||||
"provider_settings.computer_use_runtime": "sandbox",
|
||||
"provider_settings.sandbox.booter": "shipyard",
|
||||
},
|
||||
"_special": "check_shipyard_connection",
|
||||
},
|
||||
"provider_settings.sandbox.shipyard_access_token": {
|
||||
"description": "Shipyard Access Token",
|
||||
"type": "string",
|
||||
"hint": "用于访问 Shipyard 服务的访问令牌。",
|
||||
"condition": {
|
||||
"provider_settings.computer_use_runtime": "sandbox",
|
||||
"provider_settings.sandbox.booter": "shipyard",
|
||||
},
|
||||
},
|
||||
"provider_settings.sandbox.shipyard_ttl": {
|
||||
"description": "Shipyard Session TTL",
|
||||
"type": "int",
|
||||
"hint": "Shipyard 会话的生存时间(秒)。",
|
||||
"condition": {
|
||||
"provider_settings.computer_use_runtime": "sandbox",
|
||||
"provider_settings.sandbox.booter": "shipyard",
|
||||
},
|
||||
},
|
||||
"provider_settings.sandbox.shipyard_max_sessions": {
|
||||
"description": "Shipyard Max Sessions",
|
||||
"type": "int",
|
||||
"hint": "Shipyard 最大会话数量。",
|
||||
"condition": {
|
||||
"provider_settings.computer_use_runtime": "sandbox",
|
||||
"provider_settings.sandbox.booter": "shipyard",
|
||||
},
|
||||
},
|
||||
},
|
||||
"condition": {
|
||||
@@ -2630,86 +2740,6 @@ CONFIG_METADATA_3 = {
|
||||
# "provider_settings.enable": True,
|
||||
# },
|
||||
# },
|
||||
"sandbox": {
|
||||
"description": "Agent 沙箱环境",
|
||||
"hint": "",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"provider_settings.sandbox.enable": {
|
||||
"description": "启用沙箱环境",
|
||||
"type": "bool",
|
||||
"hint": "启用后,Agent 可以使用沙箱环境中的工具和资源,如 Python 代码执行、Shell 等。",
|
||||
},
|
||||
"provider_settings.sandbox.booter": {
|
||||
"description": "沙箱环境驱动器",
|
||||
"type": "string",
|
||||
"options": ["shipyard"],
|
||||
"labels": ["Shipyard"],
|
||||
"condition": {
|
||||
"provider_settings.sandbox.enable": True,
|
||||
},
|
||||
},
|
||||
"provider_settings.sandbox.shipyard_endpoint": {
|
||||
"description": "Shipyard API Endpoint",
|
||||
"type": "string",
|
||||
"hint": "Shipyard 服务的 API 访问地址。",
|
||||
"condition": {
|
||||
"provider_settings.sandbox.enable": True,
|
||||
"provider_settings.sandbox.booter": "shipyard",
|
||||
},
|
||||
"_special": "check_shipyard_connection",
|
||||
},
|
||||
"provider_settings.sandbox.shipyard_access_token": {
|
||||
"description": "Shipyard Access Token",
|
||||
"type": "string",
|
||||
"hint": "用于访问 Shipyard 服务的访问令牌。",
|
||||
"condition": {
|
||||
"provider_settings.sandbox.enable": True,
|
||||
"provider_settings.sandbox.booter": "shipyard",
|
||||
},
|
||||
},
|
||||
"provider_settings.sandbox.shipyard_ttl": {
|
||||
"description": "Shipyard Session TTL",
|
||||
"type": "int",
|
||||
"hint": "Shipyard 会话的生存时间(秒)。",
|
||||
"condition": {
|
||||
"provider_settings.sandbox.enable": True,
|
||||
"provider_settings.sandbox.booter": "shipyard",
|
||||
},
|
||||
},
|
||||
"provider_settings.sandbox.shipyard_max_sessions": {
|
||||
"description": "Shipyard Max Sessions",
|
||||
"type": "int",
|
||||
"hint": "Shipyard 最大会话数量。",
|
||||
"condition": {
|
||||
"provider_settings.sandbox.enable": True,
|
||||
"provider_settings.sandbox.booter": "shipyard",
|
||||
},
|
||||
},
|
||||
},
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"skills": {
|
||||
"description": "Skills",
|
||||
"type": "object",
|
||||
"hint": "",
|
||||
"items": {
|
||||
"provider_settings.skills.runtime": {
|
||||
"description": "Skill Runtime",
|
||||
"type": "string",
|
||||
"options": ["local", "sandbox"],
|
||||
"labels": ["本地", "沙箱"],
|
||||
"hint": "选择 Skills 运行环境。使用沙箱时需先启用沙箱环境。",
|
||||
},
|
||||
},
|
||||
"condition": {
|
||||
"provider_settings.agent_runner_type": "local",
|
||||
"provider_settings.enable": True,
|
||||
},
|
||||
},
|
||||
"proactive_capability": {
|
||||
"description": "主动型 Agent",
|
||||
"hint": "https://docs.astrbot.app/use/proactive-agent.html",
|
||||
|
||||
@@ -42,6 +42,55 @@ class ConfigMetadataI18n:
|
||||
"""
|
||||
result = {}
|
||||
|
||||
def convert_items(
|
||||
group: str, section: str, items: dict[str, Any], prefix: str = ""
|
||||
) -> dict[str, Any]:
|
||||
items_result: dict[str, Any] = {}
|
||||
|
||||
for field_key, field_data in items.items():
|
||||
if not isinstance(field_data, dict):
|
||||
items_result[field_key] = field_data
|
||||
continue
|
||||
|
||||
field_name = field_key
|
||||
field_path = f"{prefix}.{field_name}" if prefix else field_name
|
||||
|
||||
field_result = {
|
||||
key: value
|
||||
for key, value in field_data.items()
|
||||
if key not in {"description", "hint", "labels", "name"}
|
||||
}
|
||||
|
||||
if "description" in field_data:
|
||||
field_result["description"] = (
|
||||
f"{group}.{section}.{field_path}.description"
|
||||
)
|
||||
if "hint" in field_data:
|
||||
field_result["hint"] = f"{group}.{section}.{field_path}.hint"
|
||||
if "labels" in field_data:
|
||||
field_result["labels"] = f"{group}.{section}.{field_path}.labels"
|
||||
if "name" in field_data:
|
||||
field_result["name"] = f"{group}.{section}.{field_path}.name"
|
||||
|
||||
if "items" in field_data and isinstance(field_data["items"], dict):
|
||||
field_result["items"] = convert_items(
|
||||
group, section, field_data["items"], field_path
|
||||
)
|
||||
|
||||
if "template_schema" in field_data and isinstance(
|
||||
field_data["template_schema"], dict
|
||||
):
|
||||
field_result["template_schema"] = convert_items(
|
||||
group,
|
||||
section,
|
||||
field_data["template_schema"],
|
||||
f"{field_path}.template_schema",
|
||||
)
|
||||
|
||||
items_result[field_key] = field_result
|
||||
|
||||
return items_result
|
||||
|
||||
for group_key, group_data in metadata.items():
|
||||
group_result = {
|
||||
"name": f"{group_key}.name",
|
||||
@@ -50,59 +99,19 @@ class ConfigMetadataI18n:
|
||||
|
||||
for section_key, section_data in group_data.get("metadata", {}).items():
|
||||
section_result = {
|
||||
"description": f"{group_key}.{section_key}.description",
|
||||
"type": section_data.get("type"),
|
||||
key: value
|
||||
for key, value in section_data.items()
|
||||
if key not in {"description", "hint", "labels", "name"}
|
||||
}
|
||||
section_result["description"] = f"{group_key}.{section_key}.description"
|
||||
|
||||
# 复制其他属性
|
||||
for key in ["items", "condition", "_special", "invisible"]:
|
||||
if key in section_data:
|
||||
section_result[key] = section_data[key]
|
||||
|
||||
# 处理 hint
|
||||
if "hint" in section_data:
|
||||
section_result["hint"] = f"{group_key}.{section_key}.hint"
|
||||
|
||||
# 处理 items 中的字段
|
||||
if "items" in section_data and isinstance(section_data["items"], dict):
|
||||
items_result = {}
|
||||
for field_key, field_data in section_data["items"].items():
|
||||
# 处理嵌套的点号字段名(如 provider_settings.enable)
|
||||
field_name = field_key
|
||||
|
||||
field_result = {}
|
||||
|
||||
# 复制基本属性
|
||||
for attr in [
|
||||
"type",
|
||||
"condition",
|
||||
"_special",
|
||||
"invisible",
|
||||
"options",
|
||||
"slider",
|
||||
]:
|
||||
if attr in field_data:
|
||||
field_result[attr] = field_data[attr]
|
||||
|
||||
# 转换文本属性为国际化键
|
||||
if "description" in field_data:
|
||||
field_result["description"] = (
|
||||
f"{group_key}.{section_key}.{field_name}.description"
|
||||
)
|
||||
|
||||
if "hint" in field_data:
|
||||
field_result["hint"] = (
|
||||
f"{group_key}.{section_key}.{field_name}.hint"
|
||||
)
|
||||
|
||||
if "labels" in field_data:
|
||||
field_result["labels"] = (
|
||||
f"{group_key}.{section_key}.{field_name}.labels"
|
||||
)
|
||||
|
||||
items_result[field_key] = field_result
|
||||
|
||||
section_result["items"] = items_result
|
||||
section_result["items"] = convert_items(
|
||||
group_key, section_key, section_data["items"]
|
||||
)
|
||||
|
||||
group_result["metadata"][section_key] = section_result
|
||||
|
||||
|
||||
@@ -310,6 +310,7 @@ class CronJobManager:
|
||||
config = MainAgentBuildConfig(
|
||||
tool_call_timeout=3600,
|
||||
llm_safety_mode=False,
|
||||
streaming_response=False,
|
||||
)
|
||||
req = ProviderRequest()
|
||||
conv = await _get_session_conv(event=cron_event, plugin_context=self.ctx)
|
||||
|
||||
@@ -54,7 +54,6 @@ class EventBus:
|
||||
event (AstrMessageEvent): 事件对象
|
||||
|
||||
"""
|
||||
event.trace.record("event_dispatch", config_name=conf_name)
|
||||
# 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要
|
||||
if event.get_sender_name():
|
||||
logger.info(
|
||||
|
||||
@@ -9,6 +9,7 @@ from astrbot.core.message.components import (
|
||||
AtAll,
|
||||
BaseMessageComponent,
|
||||
Image,
|
||||
Json,
|
||||
Plain,
|
||||
)
|
||||
|
||||
@@ -117,9 +118,26 @@ class MessageChain:
|
||||
self.use_t2i_ = use_t2i
|
||||
return self
|
||||
|
||||
def get_plain_text(self) -> str:
|
||||
"""获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。"""
|
||||
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
|
||||
def get_plain_text(self, with_other_comps_mark: bool = False) -> str:
|
||||
"""获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。
|
||||
|
||||
Args:
|
||||
with_other_comps_mark (bool): 是否在纯文本中标记其他组件的位置
|
||||
"""
|
||||
if not with_other_comps_mark:
|
||||
return " ".join(
|
||||
[comp.text for comp in self.chain if isinstance(comp, Plain)]
|
||||
)
|
||||
else:
|
||||
texts = []
|
||||
for comp in self.chain:
|
||||
if isinstance(comp, Plain):
|
||||
texts.append(comp.text)
|
||||
elif isinstance(comp, Json):
|
||||
texts.append(f"{comp.data}")
|
||||
else:
|
||||
texts.append(f"[{comp.__class__.__name__}]")
|
||||
return " ".join(texts)
|
||||
|
||||
def squash_plain(self):
|
||||
"""将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中。"""
|
||||
|
||||
@@ -313,7 +313,7 @@ class PersonaManager:
|
||||
{
|
||||
"role": "user" if user_turn else "assistant",
|
||||
"content": dialog,
|
||||
"_no_save": None, # 不持久化到 db
|
||||
"_no_save": True, # 不持久化到 db
|
||||
},
|
||||
)
|
||||
user_turn = not user_turn
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""使用此功能应该先 pip install baidu-aip"""
|
||||
|
||||
from typing import Any, cast
|
||||
|
||||
from aip import AipContentCensor
|
||||
|
||||
from . import ContentSafetyStrategy
|
||||
@@ -23,7 +25,8 @@ class BaiduAipStrategy(ContentSafetyStrategy):
|
||||
count = len(res["data"])
|
||||
parts = [f"百度审核服务发现 {count} 处违规:\n"]
|
||||
for i in res["data"]:
|
||||
parts.append(f"{i['msg']};\n")
|
||||
# 百度 AIP 返回结构是动态 dict;类型检查时 i 可能被推断为序列,转成 dict 后用 get 取字段
|
||||
parts.append(f"{cast(dict[str, Any], i).get('msg', '')};\n")
|
||||
parts.append("\n判断结果:" + res["conclusion"])
|
||||
info = "".join(parts)
|
||||
return False, info
|
||||
|
||||
@@ -92,6 +92,7 @@ class InternalAgentSubStage(Stage):
|
||||
"safety_mode_strategy", "system_prompt"
|
||||
)
|
||||
|
||||
self.computer_use_runtime = settings.get("computer_use_runtime")
|
||||
self.sandbox_cfg = settings.get("sandbox", {})
|
||||
|
||||
# Proactive capability configuration
|
||||
@@ -116,6 +117,7 @@ class InternalAgentSubStage(Stage):
|
||||
dequeue_context_length=self.dequeue_context_length,
|
||||
llm_safety_mode=self.llm_safety_mode,
|
||||
safety_mode_strategy=self.safety_mode_strategy,
|
||||
computer_use_runtime=self.computer_use_runtime,
|
||||
sandbox_cfg=self.sandbox_cfg,
|
||||
add_cron_tools=self.add_cron_tools,
|
||||
provider_settings=settings,
|
||||
@@ -162,6 +164,7 @@ class InternalAgentSubStage(Stage):
|
||||
event=event,
|
||||
plugin_context=self.ctx.plugin_manager.context,
|
||||
config=build_cfg,
|
||||
apply_reset=False,
|
||||
)
|
||||
|
||||
if build_result is None:
|
||||
@@ -170,6 +173,7 @@ class InternalAgentSubStage(Stage):
|
||||
agent_runner = build_result.agent_runner
|
||||
req = build_result.provider_request
|
||||
provider = build_result.provider
|
||||
reset_coro = build_result.reset_coro
|
||||
|
||||
api_base = provider.provider_config.get("api_base", "")
|
||||
for host in decoded_blocked:
|
||||
@@ -188,6 +192,10 @@ class InternalAgentSubStage(Stage):
|
||||
if await call_event_hook(event, EventType.OnLLMRequestEvent, req):
|
||||
return
|
||||
|
||||
# apply reset
|
||||
if reset_coro:
|
||||
await reset_coro
|
||||
|
||||
action_type = event.get_extra("action_type")
|
||||
|
||||
event.trace.record(
|
||||
@@ -347,15 +355,14 @@ class InternalAgentSubStage(Stage):
|
||||
if message.role == "system" and not skipped_initial_system:
|
||||
skipped_initial_system = True
|
||||
continue
|
||||
if message.role in ["assistant", "user"] and getattr(
|
||||
message, "_no_save", None
|
||||
):
|
||||
if message.role in ["assistant", "user"] and message._no_save:
|
||||
continue
|
||||
message_to_save.append(message.model_dump())
|
||||
|
||||
token_usage = None
|
||||
if runner_stats:
|
||||
token_usage = runner_stats.token_usage.total
|
||||
# token_usage = runner_stats.token_usage.total
|
||||
token_usage = llm_response.usage.total if llm_response.usage else None
|
||||
|
||||
await self.conv_manager.update_conversation(
|
||||
event.unified_msg_origin,
|
||||
|
||||
@@ -85,6 +85,4 @@ class PipelineScheduler:
|
||||
if isinstance(event, WebChatMessageEvent | WecomAIBotMessageEvent):
|
||||
await event.send(None)
|
||||
|
||||
event.trace.record("event_end")
|
||||
|
||||
logger.debug("pipeline 执行完毕。")
|
||||
|
||||
@@ -8,6 +8,7 @@ from time import time
|
||||
from typing import Any
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.db.po import Conversation
|
||||
from astrbot.core.message.components import (
|
||||
At,
|
||||
@@ -73,9 +74,6 @@ class AstrMessageEvent(abc.ABC):
|
||||
self.span = self.trace
|
||||
"""事件级 TraceSpan(别名: span)"""
|
||||
|
||||
self.trace.record("umo", umo=self.unified_msg_origin)
|
||||
self.trace.record("event_created", created_at=self.created_at)
|
||||
|
||||
self._has_send_oper = False
|
||||
"""在此次事件中是否有过至少一次发送消息的操作"""
|
||||
self.call_llm = False
|
||||
@@ -358,6 +356,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
self,
|
||||
prompt: str,
|
||||
func_tool_manager=None,
|
||||
tool_set: ToolSet | None = None,
|
||||
session_id: str = "",
|
||||
image_urls: list[str] | None = None,
|
||||
contexts: list | None = None,
|
||||
@@ -380,7 +379,7 @@ class AstrMessageEvent(abc.ABC):
|
||||
|
||||
contexts: 当指定 contexts 时,将会使用 contexts 作为上下文。如果同时传入了 conversation,将会忽略 conversation。
|
||||
|
||||
func_tool_manager: 函数工具管理器,用于调用函数工具。用 self.context.get_llm_tool_manager() 获取。
|
||||
func_tool_manager: [Deprecated] 函数工具管理器,用于调用函数工具。用 self.context.get_llm_tool_manager() 获取。已过时,请使用 tool_set 参数代替。
|
||||
|
||||
conversation: 可选。如果指定,将在指定的对话中进行 LLM 请求。对话的人格会被用于 LLM 请求,并且结果将会被记录到对话中。
|
||||
|
||||
@@ -396,7 +395,8 @@ class AstrMessageEvent(abc.ABC):
|
||||
prompt=prompt,
|
||||
session_id=session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=func_tool_manager,
|
||||
# func_tool=func_tool_manager,
|
||||
func_tool=tool_set,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
conversation=conversation,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
|
||||
@@ -13,7 +13,7 @@ class MessageSession:
|
||||
"""平台适配器实例的唯一标识符。自 AstrBot v4.0.0 起,该字段实际为 platform_id。"""
|
||||
message_type: MessageType
|
||||
session_id: str
|
||||
platform_id: str | None = None
|
||||
platform_id: str = field(init=False)
|
||||
|
||||
def __str__(self):
|
||||
return f"{self.platform_id}:{self.message_type.value}:{self.session_id}"
|
||||
|
||||
@@ -21,3 +21,6 @@ class PlatformMetadata:
|
||||
"""平台是否支持真实流式传输"""
|
||||
support_proactive_message: bool = True
|
||||
"""平台是否支持主动消息推送(非用户触发)"""
|
||||
|
||||
module_path: str | None = None
|
||||
"""注册该适配器的模块路径,用于插件热重载时清理"""
|
||||
|
||||
@@ -37,6 +37,9 @@ def register_platform_adapter(
|
||||
if "id" not in default_config_tmpl:
|
||||
default_config_tmpl["id"] = adapter_name
|
||||
|
||||
# Get the module path of the class being decorated
|
||||
module_path = cls.__module__
|
||||
|
||||
pm = PlatformMetadata(
|
||||
name=adapter_name,
|
||||
description=desc,
|
||||
@@ -45,6 +48,7 @@ def register_platform_adapter(
|
||||
adapter_display_name=adapter_display_name,
|
||||
logo_path=logo_path,
|
||||
support_streaming_message=support_streaming_message,
|
||||
module_path=module_path,
|
||||
)
|
||||
platform_registry.append(pm)
|
||||
platform_cls_map[adapter_name] = cls
|
||||
@@ -52,3 +56,31 @@ def register_platform_adapter(
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def unregister_platform_adapters_by_module(module_path_prefix: str) -> list[str]:
|
||||
"""根据模块路径前缀注销平台适配器。
|
||||
|
||||
在插件热重载时调用,用于清理该插件注册的所有平台适配器。
|
||||
|
||||
Args:
|
||||
module_path_prefix: 模块路径前缀,如 "data.plugins.my_plugin"
|
||||
|
||||
Returns:
|
||||
被注销的平台适配器名称列表
|
||||
"""
|
||||
unregistered = []
|
||||
to_remove = []
|
||||
|
||||
for pm in platform_registry:
|
||||
if pm.module_path and pm.module_path.startswith(module_path_prefix):
|
||||
to_remove.append(pm)
|
||||
unregistered.append(pm.name)
|
||||
|
||||
for pm in to_remove:
|
||||
platform_registry.remove(pm)
|
||||
if pm.name in platform_cls_map:
|
||||
del platform_cls_map[pm.name]
|
||||
logger.debug(f"平台适配器 {pm.name} 已注销 (来自模块 {pm.module_path})")
|
||||
|
||||
return unregistered
|
||||
|
||||
@@ -444,9 +444,20 @@ class DiscordPlatformAdapter(Platform):
|
||||
logger.warning(f"[Discord] 指令 '{cmd_name}' defer 失败: {e}")
|
||||
|
||||
# 2. 构建 AstrBotMessage
|
||||
channel = ctx.channel
|
||||
abm = AstrBotMessage()
|
||||
abm.type = self._get_message_type(ctx.channel, ctx.guild_id)
|
||||
abm.group_id = self._get_channel_id(ctx.channel)
|
||||
if channel is not None:
|
||||
abm.type = self._get_message_type(channel, ctx.guild_id)
|
||||
abm.group_id = self._get_channel_id(channel)
|
||||
else:
|
||||
# 防守式兜底:channel 取不到时,仍能根据 guild_id/channel_id 推断会话信息
|
||||
abm.type = (
|
||||
MessageType.GROUP_MESSAGE
|
||||
if ctx.guild_id is not None
|
||||
else MessageType.FRIEND_MESSAGE
|
||||
)
|
||||
abm.group_id = str(ctx.channel_id)
|
||||
|
||||
abm.message_str = message_str_for_filter
|
||||
abm.sender = MessageMember(
|
||||
user_id=str(ctx.author.id),
|
||||
|
||||
@@ -3,13 +3,10 @@ import base64
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any, cast
|
||||
|
||||
import lark_oapi as lark
|
||||
from lark_oapi.api.im.v1 import (
|
||||
CreateMessageRequest,
|
||||
CreateMessageRequestBody,
|
||||
GetMessageResourceRequest,
|
||||
)
|
||||
from lark_oapi.api.im.v1.processor import P2ImMessageReceiveV1Processor
|
||||
@@ -125,44 +122,23 @@ class LarkPlatformAdapter(Platform):
|
||||
session: MessageSesion,
|
||||
message_chain: MessageChain,
|
||||
):
|
||||
if self.lark_api.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法发送消息")
|
||||
return
|
||||
|
||||
res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api)
|
||||
wrapped = {
|
||||
"zh_cn": {
|
||||
"title": "",
|
||||
"content": res,
|
||||
},
|
||||
}
|
||||
|
||||
if session.message_type == MessageType.GROUP_MESSAGE:
|
||||
id_type = "chat_id"
|
||||
if "%" in session.session_id:
|
||||
session.session_id = session.session_id.split("%")[1]
|
||||
receive_id = session.session_id
|
||||
if "%" in receive_id:
|
||||
receive_id = receive_id.split("%")[1]
|
||||
else:
|
||||
id_type = "open_id"
|
||||
receive_id = session.session_id
|
||||
|
||||
request = (
|
||||
CreateMessageRequest.builder()
|
||||
.receive_id_type(id_type)
|
||||
.request_body(
|
||||
CreateMessageRequestBody.builder()
|
||||
.receive_id(session.session_id)
|
||||
.content(json.dumps(wrapped))
|
||||
.msg_type("post")
|
||||
.uuid(str(uuid.uuid4()))
|
||||
.build(),
|
||||
)
|
||||
.build()
|
||||
# 复用 LarkMessageEvent 中的通用发送逻辑
|
||||
await LarkMessageEvent.send_message_chain(
|
||||
message_chain,
|
||||
self.lark_api,
|
||||
receive_id=receive_id,
|
||||
receive_id_type=id_type,
|
||||
)
|
||||
|
||||
response = await self.lark_api.im.v1.message.acreate(request)
|
||||
|
||||
if not response.success():
|
||||
logger.error(f"发送飞书消息失败({response.code}): {response.msg}")
|
||||
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
|
||||
@@ -6,6 +6,8 @@ from io import BytesIO
|
||||
|
||||
import lark_oapi as lark
|
||||
from lark_oapi.api.im.v1 import (
|
||||
CreateFileRequest,
|
||||
CreateFileRequestBody,
|
||||
CreateImageRequest,
|
||||
CreateImageRequestBody,
|
||||
CreateMessageReactionRequest,
|
||||
@@ -17,10 +19,15 @@ from lark_oapi.api.im.v1 import (
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import At, Plain
|
||||
from astrbot.api.message_components import At, File, Plain, Record, Video
|
||||
from astrbot.api.message_components import Image as AstrBotImage
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.utils.media_utils import (
|
||||
convert_audio_to_opus,
|
||||
convert_video_format,
|
||||
get_media_duration,
|
||||
)
|
||||
|
||||
|
||||
class LarkMessageEvent(AstrMessageEvent):
|
||||
@@ -35,6 +42,144 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.bot = bot
|
||||
|
||||
@staticmethod
|
||||
async def _send_im_message(
|
||||
lark_client: lark.Client,
|
||||
*,
|
||||
content: str,
|
||||
msg_type: str,
|
||||
reply_message_id: str | None = None,
|
||||
receive_id: str | None = None,
|
||||
receive_id_type: str | None = None,
|
||||
) -> bool:
|
||||
"""发送飞书 IM 消息的通用辅助函数
|
||||
|
||||
Args:
|
||||
lark_client: 飞书客户端
|
||||
content: 消息内容(JSON字符串)
|
||||
msg_type: 消息类型(post/file/audio/media等)
|
||||
reply_message_id: 回复的消息ID(用于回复消息)
|
||||
receive_id: 接收者ID(用于主动发送)
|
||||
receive_id_type: 接收者ID类型(用于主动发送)
|
||||
|
||||
Returns:
|
||||
是否发送成功
|
||||
"""
|
||||
if lark_client.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化")
|
||||
return False
|
||||
|
||||
if reply_message_id:
|
||||
request = (
|
||||
ReplyMessageRequest.builder()
|
||||
.message_id(reply_message_id)
|
||||
.request_body(
|
||||
ReplyMessageRequestBody.builder()
|
||||
.content(content)
|
||||
.msg_type(msg_type)
|
||||
.uuid(str(uuid.uuid4()))
|
||||
.reply_in_thread(False)
|
||||
.build()
|
||||
)
|
||||
.build()
|
||||
)
|
||||
response = await lark_client.im.v1.message.areply(request)
|
||||
else:
|
||||
from lark_oapi.api.im.v1 import (
|
||||
CreateMessageRequest,
|
||||
CreateMessageRequestBody,
|
||||
)
|
||||
|
||||
if receive_id_type is None or receive_id is None:
|
||||
logger.error(
|
||||
"[Lark] 主动发送消息时,receive_id 和 receive_id_type 不能为空",
|
||||
)
|
||||
return False
|
||||
|
||||
request = (
|
||||
CreateMessageRequest.builder()
|
||||
.receive_id_type(receive_id_type)
|
||||
.request_body(
|
||||
CreateMessageRequestBody.builder()
|
||||
.receive_id(receive_id)
|
||||
.content(content)
|
||||
.msg_type(msg_type)
|
||||
.uuid(str(uuid.uuid4()))
|
||||
.build()
|
||||
)
|
||||
.build()
|
||||
)
|
||||
response = await lark_client.im.v1.message.acreate(request)
|
||||
|
||||
if not response.success():
|
||||
logger.error(f"[Lark] 发送飞书消息失败({response.code}): {response.msg}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def _upload_lark_file(
|
||||
lark_client: lark.Client,
|
||||
*,
|
||||
path: str,
|
||||
file_type: str,
|
||||
duration: int | None = None,
|
||||
) -> str | None:
|
||||
"""上传文件到飞书的通用辅助函数
|
||||
|
||||
Args:
|
||||
lark_client: 飞书客户端
|
||||
path: 文件路径
|
||||
file_type: 文件类型(stream/opus/mp4等)
|
||||
duration: 媒体时长(毫秒),可选
|
||||
|
||||
Returns:
|
||||
成功返回file_key,失败返回None
|
||||
"""
|
||||
if not path or not os.path.exists(path):
|
||||
logger.error(f"[Lark] 文件不存在: {path}")
|
||||
return None
|
||||
|
||||
if lark_client.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法上传文件")
|
||||
return None
|
||||
|
||||
try:
|
||||
with open(path, "rb") as file_obj:
|
||||
body_builder = (
|
||||
CreateFileRequestBody.builder()
|
||||
.file_type(file_type)
|
||||
.file_name(os.path.basename(path))
|
||||
.file(file_obj)
|
||||
)
|
||||
if duration is not None:
|
||||
body_builder.duration(duration)
|
||||
|
||||
request = (
|
||||
CreateFileRequest.builder()
|
||||
.request_body(body_builder.build())
|
||||
.build()
|
||||
)
|
||||
response = await lark_client.im.v1.file.acreate(request)
|
||||
|
||||
if not response.success():
|
||||
logger.error(
|
||||
f"[Lark] 无法上传文件({response.code}): {response.msg}"
|
||||
)
|
||||
return None
|
||||
|
||||
if response.data is None:
|
||||
logger.error("[Lark] 上传文件成功但未返回数据(data is None)")
|
||||
return None
|
||||
|
||||
file_key = response.data.file_key
|
||||
logger.debug(f"[Lark] 文件上传成功: {file_key}")
|
||||
return file_key
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark] 无法打开或上传文件: {e}")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
async def _convert_to_lark(message: MessageChain, lark_client: lark.Client) -> list:
|
||||
ret = []
|
||||
@@ -103,6 +248,18 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
ret.append(_stage)
|
||||
ret.append([{"tag": "img", "image_key": image_key}])
|
||||
_stage.clear()
|
||||
elif isinstance(comp, File):
|
||||
# 文件将通过 _send_file_message 方法单独发送,这里跳过
|
||||
logger.debug("[Lark] 检测到文件组件,将单独发送")
|
||||
continue
|
||||
elif isinstance(comp, Record):
|
||||
# 音频将通过 _send_audio_message 方法单独发送,这里跳过
|
||||
logger.debug("[Lark] 检测到音频组件,将单独发送")
|
||||
continue
|
||||
elif isinstance(comp, Video):
|
||||
# 视频将通过 _send_media_message 方法单独发送,这里跳过
|
||||
logger.debug("[Lark] 检测到视频组件,将单独发送")
|
||||
continue
|
||||
else:
|
||||
logger.warning(f"飞书 暂时不支持消息段: {comp.type}")
|
||||
|
||||
@@ -110,40 +267,270 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
ret.append(_stage)
|
||||
return ret
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
res = await LarkMessageEvent._convert_to_lark(message, self.bot)
|
||||
wrapped = {
|
||||
"zh_cn": {
|
||||
"title": "",
|
||||
"content": res,
|
||||
},
|
||||
}
|
||||
@staticmethod
|
||||
async def send_message_chain(
|
||||
message_chain: MessageChain,
|
||||
lark_client: lark.Client,
|
||||
reply_message_id: str | None = None,
|
||||
receive_id: str | None = None,
|
||||
receive_id_type: str | None = None,
|
||||
):
|
||||
"""通用的消息链发送方法
|
||||
|
||||
request = (
|
||||
ReplyMessageRequest.builder()
|
||||
.message_id(self.message_obj.message_id)
|
||||
.request_body(
|
||||
ReplyMessageRequestBody.builder()
|
||||
.content(json.dumps(wrapped))
|
||||
.msg_type("post")
|
||||
.uuid(str(uuid.uuid4()))
|
||||
.reply_in_thread(False)
|
||||
.build(),
|
||||
)
|
||||
.build()
|
||||
)
|
||||
|
||||
if self.bot.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法回复消息")
|
||||
Args:
|
||||
message_chain: 要发送的消息链
|
||||
lark_client: 飞书客户端
|
||||
reply_message_id: 回复的消息ID(用于回复消息)
|
||||
receive_id: 接收者ID(用于主动发送)
|
||||
receive_id_type: 接收者ID类型,如 'open_id', 'chat_id'(用于主动发送)
|
||||
"""
|
||||
if lark_client.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化")
|
||||
return
|
||||
|
||||
response = await self.bot.im.v1.message.areply(request)
|
||||
# 分离文件、音频、视频组件和其他组件
|
||||
file_components: list[File] = []
|
||||
audio_components: list[Record] = []
|
||||
media_components: list[Video] = []
|
||||
other_components = []
|
||||
|
||||
if not response.success():
|
||||
logger.error(f"回复飞书消息失败({response.code}): {response.msg}")
|
||||
for comp in message_chain.chain:
|
||||
if isinstance(comp, File):
|
||||
file_components.append(comp)
|
||||
elif isinstance(comp, Record):
|
||||
audio_components.append(comp)
|
||||
elif isinstance(comp, Video):
|
||||
media_components.append(comp)
|
||||
else:
|
||||
other_components.append(comp)
|
||||
|
||||
# 先发送非文件内容(如果有)
|
||||
if other_components:
|
||||
temp_chain = MessageChain()
|
||||
temp_chain.chain = other_components
|
||||
res = await LarkMessageEvent._convert_to_lark(temp_chain, lark_client)
|
||||
|
||||
if res: # 只在有内容时发送
|
||||
wrapped = {
|
||||
"zh_cn": {
|
||||
"title": "",
|
||||
"content": res,
|
||||
},
|
||||
}
|
||||
await LarkMessageEvent._send_im_message(
|
||||
lark_client,
|
||||
content=json.dumps(wrapped),
|
||||
msg_type="post",
|
||||
reply_message_id=reply_message_id,
|
||||
receive_id=receive_id,
|
||||
receive_id_type=receive_id_type,
|
||||
)
|
||||
|
||||
# 发送附件
|
||||
for file_comp in file_components:
|
||||
await LarkMessageEvent._send_file_message(
|
||||
file_comp, lark_client, reply_message_id, receive_id, receive_id_type
|
||||
)
|
||||
|
||||
for audio_comp in audio_components:
|
||||
await LarkMessageEvent._send_audio_message(
|
||||
audio_comp, lark_client, reply_message_id, receive_id, receive_id_type
|
||||
)
|
||||
|
||||
for media_comp in media_components:
|
||||
await LarkMessageEvent._send_media_message(
|
||||
media_comp, lark_client, reply_message_id, receive_id, receive_id_type
|
||||
)
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
"""发送消息链到飞书,然后交给父类做框架级发送/记录"""
|
||||
await LarkMessageEvent.send_message_chain(
|
||||
message,
|
||||
self.bot,
|
||||
reply_message_id=self.message_obj.message_id,
|
||||
)
|
||||
await super().send(message)
|
||||
|
||||
@staticmethod
|
||||
async def _send_file_message(
|
||||
file_comp: File,
|
||||
lark_client: lark.Client,
|
||||
reply_message_id: str | None = None,
|
||||
receive_id: str | None = None,
|
||||
receive_id_type: str | None = None,
|
||||
):
|
||||
"""发送文件消息
|
||||
|
||||
Args:
|
||||
file_comp: 文件组件
|
||||
lark_client: 飞书客户端
|
||||
reply_message_id: 回复的消息ID(用于回复消息)
|
||||
receive_id: 接收者ID(用于主动发送)
|
||||
receive_id_type: 接收者ID类型(用于主动发送)
|
||||
"""
|
||||
file_path = file_comp.file or ""
|
||||
file_key = await LarkMessageEvent._upload_lark_file(
|
||||
lark_client, path=file_path, file_type="stream"
|
||||
)
|
||||
if not file_key:
|
||||
return
|
||||
|
||||
content = json.dumps({"file_key": file_key})
|
||||
await LarkMessageEvent._send_im_message(
|
||||
lark_client,
|
||||
content=content,
|
||||
msg_type="file",
|
||||
reply_message_id=reply_message_id,
|
||||
receive_id=receive_id,
|
||||
receive_id_type=receive_id_type,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _send_audio_message(
|
||||
audio_comp: Record,
|
||||
lark_client: lark.Client,
|
||||
reply_message_id: str | None = None,
|
||||
receive_id: str | None = None,
|
||||
receive_id_type: str | None = None,
|
||||
):
|
||||
"""发送音频消息
|
||||
|
||||
Args:
|
||||
audio_comp: 音频组件
|
||||
lark_client: 飞书客户端
|
||||
reply_message_id: 回复的消息ID(用于回复消息)
|
||||
receive_id: 接收者ID(用于主动发送)
|
||||
receive_id_type: 接收者ID类型(用于主动发送)
|
||||
"""
|
||||
# 获取音频文件路径
|
||||
try:
|
||||
original_audio_path = await audio_comp.convert_to_file_path()
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark] 无法获取音频文件路径: {e}")
|
||||
return
|
||||
|
||||
if not original_audio_path or not os.path.exists(original_audio_path):
|
||||
logger.error(f"[Lark] 音频文件不存在: {original_audio_path}")
|
||||
return
|
||||
|
||||
# 转换为opus格式
|
||||
converted_audio_path = None
|
||||
try:
|
||||
audio_path = await convert_audio_to_opus(original_audio_path)
|
||||
# 如果转换后路径与原路径不同,说明生成了新文件
|
||||
if audio_path != original_audio_path:
|
||||
converted_audio_path = audio_path
|
||||
else:
|
||||
audio_path = original_audio_path
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark] 音频格式转换失败,将尝试直接上传: {e}")
|
||||
# 如果转换失败,继续尝试直接上传原始文件
|
||||
audio_path = original_audio_path
|
||||
|
||||
# 获取音频时长
|
||||
duration = await get_media_duration(audio_path)
|
||||
|
||||
# 上传音频文件
|
||||
file_key = await LarkMessageEvent._upload_lark_file(
|
||||
lark_client,
|
||||
path=audio_path,
|
||||
file_type="opus",
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
# 清理转换后的临时音频文件
|
||||
if converted_audio_path and os.path.exists(converted_audio_path):
|
||||
try:
|
||||
os.remove(converted_audio_path)
|
||||
logger.debug(f"[Lark] 已删除转换后的音频文件: {converted_audio_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Lark] 删除转换后的音频文件失败: {e}")
|
||||
|
||||
if not file_key:
|
||||
return
|
||||
|
||||
await LarkMessageEvent._send_im_message(
|
||||
lark_client,
|
||||
content=json.dumps({"file_key": file_key}),
|
||||
msg_type="audio",
|
||||
reply_message_id=reply_message_id,
|
||||
receive_id=receive_id,
|
||||
receive_id_type=receive_id_type,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
async def _send_media_message(
|
||||
media_comp: Video,
|
||||
lark_client: lark.Client,
|
||||
reply_message_id: str | None = None,
|
||||
receive_id: str | None = None,
|
||||
receive_id_type: str | None = None,
|
||||
):
|
||||
"""发送视频消息
|
||||
|
||||
Args:
|
||||
media_comp: 视频组件
|
||||
lark_client: 飞书客户端
|
||||
reply_message_id: 回复的消息ID(用于回复消息)
|
||||
receive_id: 接收者ID(用于主动发送)
|
||||
receive_id_type: 接收者ID类型(用于主动发送)
|
||||
"""
|
||||
# 获取视频文件路径
|
||||
try:
|
||||
original_video_path = await media_comp.convert_to_file_path()
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark] 无法获取视频文件路径: {e}")
|
||||
return
|
||||
|
||||
if not original_video_path or not os.path.exists(original_video_path):
|
||||
logger.error(f"[Lark] 视频文件不存在: {original_video_path}")
|
||||
return
|
||||
|
||||
# 转换为mp4格式
|
||||
converted_video_path = None
|
||||
try:
|
||||
video_path = await convert_video_format(original_video_path, "mp4")
|
||||
# 如果转换后路径与原路径不同,说明生成了新文件
|
||||
if video_path != original_video_path:
|
||||
converted_video_path = video_path
|
||||
else:
|
||||
video_path = original_video_path
|
||||
except Exception as e:
|
||||
logger.error(f"[Lark] 视频格式转换失败,将尝试直接上传: {e}")
|
||||
# 如果转换失败,继续尝试直接上传原始文件
|
||||
video_path = original_video_path
|
||||
|
||||
# 获取视频时长
|
||||
duration = await get_media_duration(video_path)
|
||||
|
||||
# 上传视频文件
|
||||
file_key = await LarkMessageEvent._upload_lark_file(
|
||||
lark_client,
|
||||
path=video_path,
|
||||
file_type="mp4",
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
# 清理转换后的临时视频文件
|
||||
if converted_video_path and os.path.exists(converted_video_path):
|
||||
try:
|
||||
os.remove(converted_video_path)
|
||||
logger.debug(f"[Lark] 已删除转换后的视频文件: {converted_video_path}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[Lark] 删除转换后的视频文件失败: {e}")
|
||||
|
||||
if not file_key:
|
||||
return
|
||||
|
||||
await LarkMessageEvent._send_im_message(
|
||||
lark_client,
|
||||
content=json.dumps({"file_key": file_key}),
|
||||
msg_type="media",
|
||||
reply_message_id=reply_message_id,
|
||||
receive_id=receive_id,
|
||||
receive_id_type=receive_id_type,
|
||||
)
|
||||
|
||||
async def react(self, emoji: str):
|
||||
if self.bot.im is None:
|
||||
logger.error("[Lark] API Client im 模块未初始化,无法发送表情")
|
||||
|
||||
@@ -89,6 +89,16 @@ class TelegramPlatformAdapter(Platform):
|
||||
|
||||
self.scheduler = AsyncIOScheduler()
|
||||
|
||||
# Media group handling
|
||||
# Cache structure: {media_group_id: {"created_at": datetime, "items": [(update, context), ...]}}
|
||||
self.media_group_cache: dict[str, dict] = {}
|
||||
self.media_group_timeout = self.config.get(
|
||||
"telegram_media_group_timeout", 2.5
|
||||
) # seconds - debounce delay between messages
|
||||
self.media_group_max_wait = self.config.get(
|
||||
"telegram_media_group_max_wait", 10.0
|
||||
) # max seconds - hard cap to prevent indefinite delay
|
||||
|
||||
@override
|
||||
async def send_by_session(
|
||||
self,
|
||||
@@ -225,6 +235,13 @@ class TelegramPlatformAdapter(Platform):
|
||||
|
||||
async def message_handler(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
logger.debug(f"Telegram message: {update.message}")
|
||||
|
||||
# Handle media group messages
|
||||
if update.message and update.message.media_group_id:
|
||||
await self.handle_media_group_message(update, context)
|
||||
return
|
||||
|
||||
# Handle regular messages
|
||||
abm = await self.convert_message(update, context)
|
||||
if abm:
|
||||
await self.handle_msg(abm)
|
||||
@@ -399,6 +416,113 @@ class TelegramPlatformAdapter(Platform):
|
||||
|
||||
return message
|
||||
|
||||
async def handle_media_group_message(
|
||||
self, update: Update, context: ContextTypes.DEFAULT_TYPE
|
||||
):
|
||||
"""Handle messages that are part of a media group (album).
|
||||
|
||||
Caches incoming messages and schedules delayed processing to collect all
|
||||
media items before sending to the pipeline. Uses debounce mechanism with
|
||||
a hard cap (max_wait) to prevent indefinite delay.
|
||||
"""
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
if not update.message:
|
||||
return
|
||||
|
||||
media_group_id = update.message.media_group_id
|
||||
if not media_group_id:
|
||||
return
|
||||
|
||||
# Initialize cache for this media group if needed
|
||||
if media_group_id not in self.media_group_cache:
|
||||
self.media_group_cache[media_group_id] = {
|
||||
"created_at": datetime.now(),
|
||||
"items": [],
|
||||
}
|
||||
logger.debug(f"Create media group cache: {media_group_id}")
|
||||
|
||||
# Add this message to the cache
|
||||
entry = self.media_group_cache[media_group_id]
|
||||
entry["items"].append((update, context))
|
||||
logger.debug(
|
||||
f"Add message to media group {media_group_id}, "
|
||||
f"currently has {len(entry['items'])} items.",
|
||||
)
|
||||
|
||||
# Calculate delay: if already waited too long, process immediately;
|
||||
# otherwise use normal debounce timeout
|
||||
elapsed = (datetime.now() - entry["created_at"]).total_seconds()
|
||||
if elapsed >= self.media_group_max_wait:
|
||||
delay = 0
|
||||
logger.debug(
|
||||
f"Media group {media_group_id} has reached max wait time "
|
||||
f"({elapsed:.1f}s >= {self.media_group_max_wait}s), processing immediately.",
|
||||
)
|
||||
else:
|
||||
delay = self.media_group_timeout
|
||||
logger.debug(
|
||||
f"Scheduled media group {media_group_id} to be processed in {delay} seconds "
|
||||
f"(already waited {elapsed:.1f}s)"
|
||||
)
|
||||
|
||||
# Schedule/reschedule processing (replace_existing=True handles debounce)
|
||||
job_id = f"media_group_{media_group_id}"
|
||||
self.scheduler.add_job(
|
||||
self.process_media_group,
|
||||
"date",
|
||||
run_date=datetime.now() + timedelta(seconds=delay),
|
||||
args=[media_group_id],
|
||||
id=job_id,
|
||||
replace_existing=True,
|
||||
)
|
||||
|
||||
async def process_media_group(self, media_group_id: str):
|
||||
"""Process a complete media group by merging all collected messages.
|
||||
|
||||
Args:
|
||||
media_group_id: The unique identifier for this media group
|
||||
"""
|
||||
if media_group_id not in self.media_group_cache:
|
||||
logger.warning(f"Media group {media_group_id} not found in cache")
|
||||
return
|
||||
|
||||
entry = self.media_group_cache.pop(media_group_id)
|
||||
updates_and_contexts = entry["items"]
|
||||
if not updates_and_contexts:
|
||||
logger.warning(f"Media group {media_group_id} is empty")
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Processing media group {media_group_id}, total {len(updates_and_contexts)} items"
|
||||
)
|
||||
|
||||
# Use the first update to create the base message (with reply, caption, etc.)
|
||||
first_update, first_context = updates_and_contexts[0]
|
||||
abm = await self.convert_message(first_update, first_context)
|
||||
|
||||
if not abm:
|
||||
logger.warning(
|
||||
f"Failed to convert the first message of media group {media_group_id}"
|
||||
)
|
||||
return
|
||||
|
||||
# Add additional media from remaining updates by reusing convert_message
|
||||
for update, context in updates_and_contexts[1:]:
|
||||
# Convert the message but skip reply chains (get_reply=False)
|
||||
extra = await self.convert_message(update, context, get_reply=False)
|
||||
if not extra:
|
||||
continue
|
||||
|
||||
# Merge only the message components (keep base session/meta from first)
|
||||
abm.message.extend(extra.message)
|
||||
logger.debug(
|
||||
f"Added {len(extra.message)} components to media group {media_group_id}"
|
||||
)
|
||||
|
||||
# Process the merged message
|
||||
await self.handle_msg(abm)
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
message_event = TelegramPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
@@ -426,6 +550,6 @@ class TelegramPlatformAdapter(Platform):
|
||||
if self.application.updater is not None:
|
||||
await self.application.updater.stop()
|
||||
|
||||
logger.info("Telegram 适配器已被关闭")
|
||||
logger.info("Telegram adapter has been closed.")
|
||||
except Exception as e:
|
||||
logger.error(f"Telegram 适配器关闭时出错: {e}")
|
||||
logger.error(f"Error occurred while closing Telegram adapter: {e}")
|
||||
|
||||
@@ -29,43 +29,11 @@ class QueueListener:
|
||||
def __init__(self, webchat_queue_mgr: WebChatQueueMgr, callback: Callable) -> None:
|
||||
self.webchat_queue_mgr = webchat_queue_mgr
|
||||
self.callback = callback
|
||||
self.running_tasks = set()
|
||||
|
||||
async def listen_to_queue(self, conversation_id: str):
|
||||
"""Listen to a specific conversation queue"""
|
||||
queue = self.webchat_queue_mgr.get_or_create_queue(conversation_id)
|
||||
while True:
|
||||
try:
|
||||
data = await queue.get()
|
||||
await self.callback(data)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing message from conversation {conversation_id}: {e}",
|
||||
)
|
||||
break
|
||||
|
||||
async def run(self):
|
||||
"""Monitor for new conversation queues and start listeners"""
|
||||
monitored_conversations = set()
|
||||
|
||||
while True:
|
||||
# Check for new conversations
|
||||
current_conversations = set(self.webchat_queue_mgr.queues.keys())
|
||||
new_conversations = current_conversations - monitored_conversations
|
||||
|
||||
# Start listeners for new conversations
|
||||
for conversation_id in new_conversations:
|
||||
task = asyncio.create_task(self.listen_to_queue(conversation_id))
|
||||
self.running_tasks.add(task)
|
||||
task.add_done_callback(self.running_tasks.discard)
|
||||
monitored_conversations.add(conversation_id)
|
||||
logger.debug(f"Started listener for conversation: {conversation_id}")
|
||||
|
||||
# Clean up monitored conversations that no longer exist
|
||||
removed_conversations = monitored_conversations - current_conversations
|
||||
monitored_conversations -= removed_conversations
|
||||
|
||||
await asyncio.sleep(1) # Check for new conversations every second
|
||||
"""Register callback and keep adapter task alive."""
|
||||
self.webchat_queue_mgr.set_listener(self.callback)
|
||||
await asyncio.Event().wait()
|
||||
|
||||
|
||||
@register_platform_adapter("webchat", "webchat")
|
||||
|
||||
@@ -26,8 +26,12 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
session_id: str,
|
||||
streaming: bool = False,
|
||||
) -> str | None:
|
||||
cid = session_id.split("!")[-1]
|
||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
||||
request_id = str(message_id)
|
||||
conversation_id = session_id.split("!")[-1]
|
||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(
|
||||
request_id,
|
||||
conversation_id,
|
||||
)
|
||||
if not message:
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
@@ -124,9 +128,13 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
final_data = ""
|
||||
reasoning_content = ""
|
||||
cid = self.session_id.split("!")[-1]
|
||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
||||
message_id = self.message_obj.message_id
|
||||
request_id = str(message_id)
|
||||
conversation_id = self.session_id.split("!")[-1]
|
||||
web_chat_back_queue = webchat_queue_mgr.get_or_create_back_queue(
|
||||
request_id,
|
||||
conversation_id,
|
||||
)
|
||||
async for chain in generator:
|
||||
# 处理音频流(Live Mode)
|
||||
if chain.type == "audio_chunk":
|
||||
|
||||
@@ -1,35 +1,147 @@
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from astrbot import logger
|
||||
|
||||
|
||||
class WebChatQueueMgr:
|
||||
def __init__(self) -> None:
|
||||
self.queues = {}
|
||||
def __init__(self, queue_maxsize: int = 128, back_queue_maxsize: int = 512) -> None:
|
||||
self.queues: dict[str, asyncio.Queue] = {}
|
||||
"""Conversation ID to asyncio.Queue mapping"""
|
||||
self.back_queues = {}
|
||||
"""Conversation ID to asyncio.Queue mapping for responses"""
|
||||
self.back_queues: dict[str, asyncio.Queue] = {}
|
||||
"""Request ID to asyncio.Queue mapping for responses"""
|
||||
self._conversation_back_requests: dict[str, set[str]] = {}
|
||||
self._request_conversation: dict[str, str] = {}
|
||||
self._queue_close_events: dict[str, asyncio.Event] = {}
|
||||
self._listener_tasks: dict[str, asyncio.Task] = {}
|
||||
self._listener_callback: Callable[[tuple], Awaitable[None]] | None = None
|
||||
self.queue_maxsize = queue_maxsize
|
||||
self.back_queue_maxsize = back_queue_maxsize
|
||||
|
||||
def get_or_create_queue(self, conversation_id: str) -> asyncio.Queue:
|
||||
"""Get or create a queue for the given conversation ID"""
|
||||
if conversation_id not in self.queues:
|
||||
self.queues[conversation_id] = asyncio.Queue()
|
||||
self.queues[conversation_id] = asyncio.Queue(maxsize=self.queue_maxsize)
|
||||
self._queue_close_events[conversation_id] = asyncio.Event()
|
||||
self._start_listener_if_needed(conversation_id)
|
||||
return self.queues[conversation_id]
|
||||
|
||||
def get_or_create_back_queue(self, conversation_id: str) -> asyncio.Queue:
|
||||
"""Get or create a back queue for the given conversation ID"""
|
||||
if conversation_id not in self.back_queues:
|
||||
self.back_queues[conversation_id] = asyncio.Queue()
|
||||
return self.back_queues[conversation_id]
|
||||
def get_or_create_back_queue(
|
||||
self,
|
||||
request_id: str,
|
||||
conversation_id: str | None = None,
|
||||
) -> asyncio.Queue:
|
||||
"""Get or create a back queue for the given request ID"""
|
||||
if request_id not in self.back_queues:
|
||||
self.back_queues[request_id] = asyncio.Queue(
|
||||
maxsize=self.back_queue_maxsize
|
||||
)
|
||||
if conversation_id:
|
||||
self._request_conversation[request_id] = conversation_id
|
||||
if conversation_id not in self._conversation_back_requests:
|
||||
self._conversation_back_requests[conversation_id] = set()
|
||||
self._conversation_back_requests[conversation_id].add(request_id)
|
||||
return self.back_queues[request_id]
|
||||
|
||||
def remove_back_queue(self, request_id: str):
|
||||
"""Remove back queue for the given request ID"""
|
||||
self.back_queues.pop(request_id, None)
|
||||
conversation_id = self._request_conversation.pop(request_id, None)
|
||||
if conversation_id:
|
||||
request_ids = self._conversation_back_requests.get(conversation_id)
|
||||
if request_ids is not None:
|
||||
request_ids.discard(request_id)
|
||||
if not request_ids:
|
||||
self._conversation_back_requests.pop(conversation_id, None)
|
||||
|
||||
def remove_queues(self, conversation_id: str):
|
||||
"""Remove queues for the given conversation ID"""
|
||||
if conversation_id in self.queues:
|
||||
del self.queues[conversation_id]
|
||||
if conversation_id in self.back_queues:
|
||||
del self.back_queues[conversation_id]
|
||||
for request_id in list(
|
||||
self._conversation_back_requests.get(conversation_id, set())
|
||||
):
|
||||
self.remove_back_queue(request_id)
|
||||
self._conversation_back_requests.pop(conversation_id, None)
|
||||
self.remove_queue(conversation_id)
|
||||
|
||||
def remove_queue(self, conversation_id: str):
|
||||
"""Remove input queue and listener for the given conversation ID"""
|
||||
self.queues.pop(conversation_id, None)
|
||||
|
||||
close_event = self._queue_close_events.pop(conversation_id, None)
|
||||
if close_event is not None:
|
||||
close_event.set()
|
||||
|
||||
task = self._listener_tasks.pop(conversation_id, None)
|
||||
if task is not None:
|
||||
task.cancel()
|
||||
|
||||
def has_queue(self, conversation_id: str) -> bool:
|
||||
"""Check if a queue exists for the given conversation ID"""
|
||||
return conversation_id in self.queues
|
||||
|
||||
def set_listener(
|
||||
self,
|
||||
callback: Callable[[tuple], Awaitable[None]],
|
||||
):
|
||||
self._listener_callback = callback
|
||||
for conversation_id in list(self.queues.keys()):
|
||||
self._start_listener_if_needed(conversation_id)
|
||||
|
||||
def _start_listener_if_needed(self, conversation_id: str):
|
||||
if self._listener_callback is None:
|
||||
return
|
||||
if conversation_id in self._listener_tasks:
|
||||
task = self._listener_tasks[conversation_id]
|
||||
if not task.done():
|
||||
return
|
||||
queue = self.queues.get(conversation_id)
|
||||
close_event = self._queue_close_events.get(conversation_id)
|
||||
if queue is None or close_event is None:
|
||||
return
|
||||
task = asyncio.create_task(
|
||||
self._listen_to_queue(conversation_id, queue, close_event),
|
||||
name=f"webchat_listener_{conversation_id}",
|
||||
)
|
||||
self._listener_tasks[conversation_id] = task
|
||||
task.add_done_callback(
|
||||
lambda _: self._listener_tasks.pop(conversation_id, None)
|
||||
)
|
||||
logger.debug(f"Started listener for conversation: {conversation_id}")
|
||||
|
||||
async def _listen_to_queue(
|
||||
self,
|
||||
conversation_id: str,
|
||||
queue: asyncio.Queue,
|
||||
close_event: asyncio.Event,
|
||||
):
|
||||
while True:
|
||||
get_task = asyncio.create_task(queue.get())
|
||||
close_task = asyncio.create_task(close_event.wait())
|
||||
try:
|
||||
done, pending = await asyncio.wait(
|
||||
{get_task, close_task},
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
if close_task in done:
|
||||
break
|
||||
data = get_task.result()
|
||||
if self._listener_callback is None:
|
||||
continue
|
||||
try:
|
||||
await self._listener_callback(data)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error processing message from conversation {conversation_id}: {e}"
|
||||
)
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
finally:
|
||||
if not get_task.done():
|
||||
get_task.cancel()
|
||||
if not close_task.done():
|
||||
close_task.cancel()
|
||||
|
||||
|
||||
webchat_queue_mgr = WebChatQueueMgr()
|
||||
|
||||
@@ -51,44 +51,13 @@ class WecomAIQueueListener:
|
||||
) -> None:
|
||||
self.queue_mgr = queue_mgr
|
||||
self.callback = callback
|
||||
self.running_tasks = set()
|
||||
|
||||
async def listen_to_queue(self, session_id: str):
|
||||
"""监听特定会话的队列"""
|
||||
queue = self.queue_mgr.get_or_create_queue(session_id)
|
||||
while True:
|
||||
try:
|
||||
data = await queue.get()
|
||||
await self.callback(data)
|
||||
except Exception as e:
|
||||
logger.error(f"处理会话 {session_id} 消息时发生错误: {e}")
|
||||
break
|
||||
|
||||
async def run(self):
|
||||
"""监控新会话队列并启动监听器"""
|
||||
monitored_sessions = set()
|
||||
|
||||
"""注册监听回调并定期清理过期响应。"""
|
||||
self.queue_mgr.set_listener(self.callback)
|
||||
while True:
|
||||
# 检查新会话
|
||||
current_sessions = set(self.queue_mgr.queues.keys())
|
||||
new_sessions = current_sessions - monitored_sessions
|
||||
|
||||
# 为新会话启动监听器
|
||||
for session_id in new_sessions:
|
||||
task = asyncio.create_task(self.listen_to_queue(session_id))
|
||||
self.running_tasks.add(task)
|
||||
task.add_done_callback(self.running_tasks.discard)
|
||||
monitored_sessions.add(session_id)
|
||||
logger.debug(f"[WecomAI] 为会话启动监听器: {session_id}")
|
||||
|
||||
# 清理已不存在的会话
|
||||
removed_sessions = monitored_sessions - current_sessions
|
||||
monitored_sessions -= removed_sessions
|
||||
|
||||
# 清理过期的待处理响应
|
||||
self.queue_mgr.cleanup_expired_responses()
|
||||
|
||||
await asyncio.sleep(1) # 每秒检查一次新会话
|
||||
await asyncio.sleep(1)
|
||||
|
||||
|
||||
@register_platform_adapter(
|
||||
@@ -212,7 +181,12 @@ class WecomAIBotAdapter(Platform):
|
||||
# wechat server is requesting for updates of a stream
|
||||
stream_id = message_data["stream"]["id"]
|
||||
if not self.queue_mgr.has_back_queue(stream_id):
|
||||
logger.error(f"Cannot find back queue for stream_id: {stream_id}")
|
||||
if self.queue_mgr.is_stream_finished(stream_id):
|
||||
logger.debug(
|
||||
f"Stream already finished, returning end message: {stream_id}"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"Cannot find back queue for stream_id: {stream_id}")
|
||||
|
||||
# 返回结束标志,告诉微信服务器流已结束
|
||||
end_message = WecomAIBotStreamMessageBuilder.make_text_stream(
|
||||
@@ -243,10 +217,10 @@ class WecomAIBotAdapter(Platform):
|
||||
latest_plain_content = msg["data"] or ""
|
||||
elif msg["type"] == "image":
|
||||
image_base64.append(msg["image_data"])
|
||||
elif msg["type"] == "end":
|
||||
elif msg["type"] in {"end", "complete"}:
|
||||
# stream end
|
||||
finish = True
|
||||
self.queue_mgr.remove_queues(stream_id)
|
||||
self.queue_mgr.remove_queues(stream_id, mark_finished=True)
|
||||
break
|
||||
|
||||
logger.debug(
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from astrbot.api import logger
|
||||
@@ -12,7 +13,7 @@ from astrbot.api import logger
|
||||
class WecomAIQueueMgr:
|
||||
"""企业微信智能机器人队列管理器"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, queue_maxsize: int = 128, back_queue_maxsize: int = 512) -> None:
|
||||
self.queues: dict[str, asyncio.Queue] = {}
|
||||
"""StreamID 到输入队列的映射 - 用于接收用户消息"""
|
||||
|
||||
@@ -21,6 +22,13 @@ class WecomAIQueueMgr:
|
||||
|
||||
self.pending_responses: dict[str, dict[str, Any]] = {}
|
||||
"""待处理的响应缓存,用于流式响应"""
|
||||
self.completed_streams: dict[str, float] = {}
|
||||
"""已结束的 stream 缓存,用于兼容平台后续重复轮询"""
|
||||
self._queue_close_events: dict[str, asyncio.Event] = {}
|
||||
self._listener_tasks: dict[str, asyncio.Task] = {}
|
||||
self._listener_callback: Callable[[dict], Awaitable[None]] | None = None
|
||||
self.queue_maxsize = queue_maxsize
|
||||
self.back_queue_maxsize = back_queue_maxsize
|
||||
|
||||
def get_or_create_queue(self, session_id: str) -> asyncio.Queue:
|
||||
"""获取或创建指定会话的输入队列
|
||||
@@ -33,7 +41,9 @@ class WecomAIQueueMgr:
|
||||
|
||||
"""
|
||||
if session_id not in self.queues:
|
||||
self.queues[session_id] = asyncio.Queue()
|
||||
self.queues[session_id] = asyncio.Queue(maxsize=self.queue_maxsize)
|
||||
self._queue_close_events[session_id] = asyncio.Event()
|
||||
self._start_listener_if_needed(session_id)
|
||||
logger.debug(f"[WecomAI] 创建输入队列: {session_id}")
|
||||
return self.queues[session_id]
|
||||
|
||||
@@ -48,20 +58,21 @@ class WecomAIQueueMgr:
|
||||
|
||||
"""
|
||||
if session_id not in self.back_queues:
|
||||
self.back_queues[session_id] = asyncio.Queue()
|
||||
self.back_queues[session_id] = asyncio.Queue(
|
||||
maxsize=self.back_queue_maxsize
|
||||
)
|
||||
logger.debug(f"[WecomAI] 创建输出队列: {session_id}")
|
||||
return self.back_queues[session_id]
|
||||
|
||||
def remove_queues(self, session_id: str):
|
||||
def remove_queues(self, session_id: str, mark_finished: bool = False):
|
||||
"""移除指定会话的所有队列
|
||||
|
||||
Args:
|
||||
session_id: 会话ID
|
||||
mark_finished: 是否标记为已正常结束
|
||||
|
||||
"""
|
||||
if session_id in self.queues:
|
||||
del self.queues[session_id]
|
||||
logger.debug(f"[WecomAI] 移除输入队列: {session_id}")
|
||||
self.remove_queue(session_id)
|
||||
|
||||
if session_id in self.back_queues:
|
||||
del self.back_queues[session_id]
|
||||
@@ -70,6 +81,23 @@ class WecomAIQueueMgr:
|
||||
if session_id in self.pending_responses:
|
||||
del self.pending_responses[session_id]
|
||||
logger.debug(f"[WecomAI] 移除待处理响应: {session_id}")
|
||||
if mark_finished:
|
||||
self.completed_streams[session_id] = asyncio.get_event_loop().time()
|
||||
logger.debug(f"[WecomAI] 标记流已结束: {session_id}")
|
||||
|
||||
def remove_queue(self, session_id: str):
|
||||
"""仅移除输入队列和对应监听任务"""
|
||||
if session_id in self.queues:
|
||||
del self.queues[session_id]
|
||||
logger.debug(f"[WecomAI] 移除输入队列: {session_id}")
|
||||
|
||||
close_event = self._queue_close_events.pop(session_id, None)
|
||||
if close_event is not None:
|
||||
close_event.set()
|
||||
|
||||
task = self._listener_tasks.pop(session_id, None)
|
||||
if task is not None:
|
||||
task.cancel()
|
||||
|
||||
def has_queue(self, session_id: str) -> bool:
|
||||
"""检查是否存在指定会话的队列
|
||||
@@ -121,6 +149,20 @@ class WecomAIQueueMgr:
|
||||
"""
|
||||
return self.pending_responses.get(session_id)
|
||||
|
||||
def is_stream_finished(
|
||||
self,
|
||||
session_id: str,
|
||||
max_age_seconds: int = 60,
|
||||
) -> bool:
|
||||
"""判断 stream 是否在短期内已结束"""
|
||||
finished_at = self.completed_streams.get(session_id)
|
||||
if finished_at is None:
|
||||
return False
|
||||
if asyncio.get_event_loop().time() - finished_at > max_age_seconds:
|
||||
self.completed_streams.pop(session_id, None)
|
||||
return False
|
||||
return True
|
||||
|
||||
def cleanup_expired_responses(self, max_age_seconds: int = 300):
|
||||
"""清理过期的待处理响应
|
||||
|
||||
@@ -136,8 +178,75 @@ class WecomAIQueueMgr:
|
||||
expired_sessions.append(session_id)
|
||||
|
||||
for session_id in expired_sessions:
|
||||
del self.pending_responses[session_id]
|
||||
logger.debug(f"[WecomAI] 清理过期响应: {session_id}")
|
||||
self.remove_queues(session_id)
|
||||
logger.debug(f"[WecomAI] 清理过期响应及队列: {session_id}")
|
||||
expired_finished = [
|
||||
session_id
|
||||
for session_id, finished_at in self.completed_streams.items()
|
||||
if current_time - finished_at > 60
|
||||
]
|
||||
for session_id in expired_finished:
|
||||
self.completed_streams.pop(session_id, None)
|
||||
|
||||
def set_listener(
|
||||
self,
|
||||
callback: Callable[[dict], Awaitable[None]],
|
||||
):
|
||||
self._listener_callback = callback
|
||||
for session_id in list(self.queues.keys()):
|
||||
self._start_listener_if_needed(session_id)
|
||||
|
||||
def _start_listener_if_needed(self, session_id: str):
|
||||
if self._listener_callback is None:
|
||||
return
|
||||
if session_id in self._listener_tasks:
|
||||
task = self._listener_tasks[session_id]
|
||||
if not task.done():
|
||||
return
|
||||
queue = self.queues.get(session_id)
|
||||
close_event = self._queue_close_events.get(session_id)
|
||||
if queue is None or close_event is None:
|
||||
return
|
||||
task = asyncio.create_task(
|
||||
self._listen_to_queue(session_id, queue, close_event),
|
||||
name=f"wecomai_listener_{session_id}",
|
||||
)
|
||||
self._listener_tasks[session_id] = task
|
||||
task.add_done_callback(lambda _: self._listener_tasks.pop(session_id, None))
|
||||
logger.debug(f"[WecomAI] 为会话启动监听器: {session_id}")
|
||||
|
||||
async def _listen_to_queue(
|
||||
self,
|
||||
session_id: str,
|
||||
queue: asyncio.Queue,
|
||||
close_event: asyncio.Event,
|
||||
):
|
||||
while True:
|
||||
get_task = asyncio.create_task(queue.get())
|
||||
close_task = asyncio.create_task(close_event.wait())
|
||||
try:
|
||||
done, pending = await asyncio.wait(
|
||||
{get_task, close_task},
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
if close_task in done:
|
||||
break
|
||||
data = get_task.result()
|
||||
if self._listener_callback is None:
|
||||
continue
|
||||
try:
|
||||
await self._listener_callback(data)
|
||||
except Exception as e:
|
||||
logger.error(f"处理会话 {session_id} 消息时发生错误: {e}")
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
finally:
|
||||
if not get_task.done():
|
||||
get_task.cancel()
|
||||
if not close_task.done():
|
||||
close_task.cancel()
|
||||
|
||||
def get_stats(self) -> dict[str, int]:
|
||||
"""获取队列统计信息
|
||||
|
||||
@@ -3,6 +3,7 @@ import json
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import anthropic
|
||||
import httpx
|
||||
from anthropic import AsyncAnthropic
|
||||
from anthropic.types import Message
|
||||
from anthropic.types.message_delta_usage import MessageDeltaUsage
|
||||
@@ -14,6 +15,11 @@ from astrbot.core.agent.message import ContentPart, ImageURLPart, TextPart
|
||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.utils.network_utils import (
|
||||
create_proxy_client,
|
||||
is_connection_error,
|
||||
log_connection_failure,
|
||||
)
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
@@ -45,12 +51,18 @@ class ProviderAnthropic(Provider):
|
||||
api_key=self.chosen_api_key,
|
||||
timeout=self.timeout,
|
||||
base_url=self.base_url,
|
||||
http_client=self._create_http_client(provider_config),
|
||||
)
|
||||
|
||||
self.thinking_config = provider_config.get("anth_thinking_config", {})
|
||||
|
||||
self.set_model(provider_config.get("model", "unknown"))
|
||||
|
||||
def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient | None:
|
||||
"""创建带代理的 HTTP 客户端"""
|
||||
proxy = provider_config.get("proxy", "")
|
||||
return create_proxy_client("Anthropic", proxy)
|
||||
|
||||
def _prepare_payload(self, messages: list[dict]):
|
||||
"""准备 Anthropic API 的请求 payload
|
||||
|
||||
@@ -207,9 +219,19 @@ class ProviderAnthropic(Provider):
|
||||
"type": "enabled",
|
||||
}
|
||||
|
||||
completion = await self.client.messages.create(
|
||||
**payloads, stream=False, extra_body=extra_body
|
||||
)
|
||||
try:
|
||||
completion = await self.client.messages.create(
|
||||
**payloads, stream=False, extra_body=extra_body
|
||||
)
|
||||
except httpx.RequestError as e:
|
||||
proxy = self.provider_config.get("proxy", "")
|
||||
log_connection_failure("Anthropic", e, proxy)
|
||||
raise
|
||||
except Exception as e:
|
||||
if is_connection_error(e):
|
||||
proxy = self.provider_config.get("proxy", "")
|
||||
log_connection_failure("Anthropic", e, proxy)
|
||||
raise
|
||||
|
||||
assert isinstance(completion, Message)
|
||||
logger.debug(f"completion: {completion}")
|
||||
@@ -622,3 +644,7 @@ class ProviderAnthropic(Provider):
|
||||
|
||||
def set_key(self, key: str):
|
||||
self.chosen_api_key = key
|
||||
|
||||
async def terminate(self):
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
|
||||
@@ -10,6 +10,7 @@ from xml.sax.saxutils import escape
|
||||
|
||||
from httpx import AsyncClient, Timeout
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
|
||||
from ..entities import ProviderType
|
||||
@@ -29,6 +30,9 @@ class OTTSProvider:
|
||||
self.last_sync_time = 0
|
||||
self.timeout = Timeout(10.0)
|
||||
self.retry_count = 3
|
||||
self.proxy = config.get("proxy", "")
|
||||
if self.proxy:
|
||||
logger.info(f"[Azure TTS] 使用代理: {self.proxy}")
|
||||
self._client: AsyncClient | None = None
|
||||
|
||||
@property
|
||||
@@ -40,7 +44,9 @@ class OTTSProvider:
|
||||
return self._client
|
||||
|
||||
async def __aenter__(self):
|
||||
self._client = AsyncClient(timeout=self.timeout)
|
||||
self._client = AsyncClient(
|
||||
timeout=self.timeout, proxy=self.proxy if self.proxy else None
|
||||
)
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
@@ -125,6 +131,9 @@ class AzureNativeProvider(TTSProvider):
|
||||
"rate": provider_config.get("azure_tts_rate", "1"),
|
||||
"volume": provider_config.get("azure_tts_volume", "100"),
|
||||
}
|
||||
self.proxy = provider_config.get("proxy", "")
|
||||
if self.proxy:
|
||||
logger.info(f"[Azure TTS Native] 使用代理: {self.proxy}")
|
||||
|
||||
@property
|
||||
def client(self) -> AsyncClient:
|
||||
@@ -141,6 +150,7 @@ class AzureNativeProvider(TTSProvider):
|
||||
"Content-Type": "application/ssml+xml",
|
||||
"X-Microsoft-OutputFormat": "riff-48khz-16bit-mono-pcm",
|
||||
},
|
||||
proxy=self.proxy if self.proxy else None,
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import ormsgpack
|
||||
from httpx import AsyncClient
|
||||
from pydantic import BaseModel, conint
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
from ..entities import ProviderType
|
||||
@@ -60,10 +61,13 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
self.timeout: int = int(provider_config.get("timeout", 20))
|
||||
except ValueError:
|
||||
self.timeout = 20
|
||||
self.proxy: str = provider_config.get("proxy", "")
|
||||
if self.proxy:
|
||||
logger.info(f"[FishAudio TTS] 使用代理: {self.proxy}")
|
||||
self.headers = {
|
||||
"Authorization": f"Bearer {self.chosen_api_key}",
|
||||
}
|
||||
self.set_model(provider_config.get("model", None))
|
||||
self.set_model(provider_config.get("model", ""))
|
||||
|
||||
async def _get_reference_id_by_character(self, character: str) -> str | None:
|
||||
"""获取角色的reference_id
|
||||
@@ -79,7 +83,10 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
|
||||
"""
|
||||
sort_options = ["score", "task_count", "created_at"]
|
||||
async with AsyncClient(base_url=self.api_base.replace("/v1", "")) as client:
|
||||
async with AsyncClient(
|
||||
base_url=self.api_base.replace("/v1", ""),
|
||||
proxy=self.proxy if self.proxy else None,
|
||||
) as client:
|
||||
for sort_by in sort_options:
|
||||
params = {"title": character, "sort_by": sort_by}
|
||||
response = await client.get(
|
||||
@@ -139,7 +146,11 @@ class ProviderFishAudioTTSAPI(TTSProvider):
|
||||
path = os.path.join(temp_dir, f"fishaudio_tts_api_{uuid.uuid4()}.wav")
|
||||
self.headers["content-type"] = "application/msgpack"
|
||||
request = await self._generate_request(text)
|
||||
async with AsyncClient(base_url=self.api_base, timeout=self.timeout).stream(
|
||||
async with AsyncClient(
|
||||
base_url=self.api_base,
|
||||
timeout=self.timeout,
|
||||
proxy=self.proxy if self.proxy else None,
|
||||
).stream(
|
||||
"POST",
|
||||
"/tts",
|
||||
headers=self.headers,
|
||||
|
||||
@@ -4,6 +4,8 @@ from google import genai
|
||||
from google.genai import types
|
||||
from google.genai.errors import APIError
|
||||
|
||||
from astrbot import logger
|
||||
|
||||
from ..entities import ProviderType
|
||||
from ..provider import EmbeddingProvider
|
||||
from ..register import register_provider_adapter
|
||||
@@ -28,6 +30,10 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
|
||||
if api_base:
|
||||
api_base = api_base.removesuffix("/")
|
||||
http_options.base_url = api_base
|
||||
proxy = provider_config.get("proxy", "")
|
||||
if proxy:
|
||||
http_options.async_client_args = {"proxy": proxy}
|
||||
logger.info(f"[Gemini Embedding] 使用代理: {proxy}")
|
||||
|
||||
self.client = genai.Client(api_key=api_key, http_options=http_options).aio
|
||||
|
||||
@@ -69,3 +75,7 @@ class GeminiEmbeddingProvider(EmbeddingProvider):
|
||||
def get_dim(self) -> int:
|
||||
"""获取向量的维度"""
|
||||
return int(self.provider_config.get("embedding_dimensions", 768))
|
||||
|
||||
async def terminate(self):
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
|
||||
@@ -18,6 +18,7 @@ from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage
|
||||
from astrbot.core.provider.func_tool_manager import ToolSet
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.utils.network_utils import is_connection_error, log_connection_failure
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
@@ -74,12 +75,17 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
def _init_client(self) -> None:
|
||||
"""初始化Gemini客户端"""
|
||||
proxy = self.provider_config.get("proxy", "")
|
||||
http_options = types.HttpOptions(
|
||||
base_url=self.api_base,
|
||||
timeout=self.timeout * 1000, # 毫秒
|
||||
)
|
||||
if proxy:
|
||||
http_options.async_client_args = {"proxy": proxy}
|
||||
logger.info(f"[Gemini] 使用代理: {proxy}")
|
||||
self.client = genai.Client(
|
||||
api_key=self.chosen_api_key,
|
||||
http_options=types.HttpOptions(
|
||||
base_url=self.api_base,
|
||||
timeout=self.timeout * 1000, # 毫秒
|
||||
),
|
||||
http_options=http_options,
|
||||
).aio
|
||||
|
||||
def _init_safety_settings(self) -> None:
|
||||
@@ -113,9 +119,12 @@ class ProviderGoogleGenAI(Provider):
|
||||
f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}...",
|
||||
)
|
||||
raise Exception("达到了 Gemini 速率限制, 请稍后再试...")
|
||||
# logger.error(
|
||||
# f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}",
|
||||
# )
|
||||
|
||||
# 连接错误处理
|
||||
if is_connection_error(e):
|
||||
proxy = self.provider_config.get("proxy", "")
|
||||
log_connection_failure("Gemini", e, proxy)
|
||||
|
||||
raise e
|
||||
|
||||
async def _prepare_query_config(
|
||||
@@ -920,4 +929,5 @@ class ProviderGoogleGenAI(Provider):
|
||||
return "data:image/jpeg;base64," + image_bs64
|
||||
|
||||
async def terminate(self):
|
||||
logger.info("Google GenAI 适配器已终止。")
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
|
||||
@@ -5,6 +5,7 @@ import wave
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
from ..entities import ProviderType
|
||||
@@ -32,6 +33,10 @@ class ProviderGeminiTTSAPI(TTSProvider):
|
||||
if api_base:
|
||||
api_base = api_base.removesuffix("/")
|
||||
http_options.base_url = api_base
|
||||
proxy = provider_config.get("proxy", "")
|
||||
if proxy:
|
||||
http_options.async_client_args = {"proxy": proxy}
|
||||
logger.info(f"[Gemini TTS] 使用代理: {proxy}")
|
||||
|
||||
self.client = genai.Client(api_key=api_key, http_options=http_options).aio
|
||||
self.model: str = provider_config.get(
|
||||
@@ -79,3 +84,7 @@ class ProviderGeminiTTSAPI(TTSProvider):
|
||||
wf.writeframes(response.candidates[0].content.parts[0].inline_data.data)
|
||||
|
||||
return path
|
||||
|
||||
async def terminate(self):
|
||||
if self.client:
|
||||
await self.client.aclose()
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import httpx
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from astrbot import logger
|
||||
|
||||
from ..entities import ProviderType
|
||||
from ..provider import EmbeddingProvider
|
||||
from ..register import register_provider_adapter
|
||||
@@ -15,6 +18,11 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.provider_config = provider_config
|
||||
self.provider_settings = provider_settings
|
||||
proxy = provider_config.get("proxy", "")
|
||||
http_client = None
|
||||
if proxy:
|
||||
logger.info(f"[OpenAI Embedding] 使用代理: {proxy}")
|
||||
http_client = httpx.AsyncClient(proxy=proxy)
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=provider_config.get("embedding_api_key"),
|
||||
base_url=provider_config.get(
|
||||
@@ -22,6 +30,7 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
"https://api.openai.com/v1",
|
||||
),
|
||||
timeout=int(provider_config.get("timeout", 20)),
|
||||
http_client=http_client,
|
||||
)
|
||||
self.model = provider_config.get("embedding_model", "text-embedding-3-small")
|
||||
|
||||
@@ -38,3 +47,7 @@ class OpenAIEmbeddingProvider(EmbeddingProvider):
|
||||
def get_dim(self) -> int:
|
||||
"""获取向量的维度"""
|
||||
return int(self.provider_config.get("embedding_dimensions", 1024))
|
||||
|
||||
async def terminate(self):
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
|
||||
@@ -2,11 +2,11 @@ import asyncio
|
||||
import base64
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import httpx
|
||||
from openai import AsyncAzureOpenAI, AsyncOpenAI
|
||||
from openai._exceptions import NotFoundError
|
||||
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
|
||||
@@ -22,6 +22,11 @@ from astrbot.core.agent.tool import ToolSet
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse, TokenUsage, ToolCallsResult
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.utils.network_utils import (
|
||||
create_proxy_client,
|
||||
is_connection_error,
|
||||
log_connection_failure,
|
||||
)
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
@@ -31,6 +36,11 @@ from ..register import register_provider_adapter
|
||||
"OpenAI API Chat Completion 提供商适配器",
|
||||
)
|
||||
class ProviderOpenAIOfficial(Provider):
|
||||
def _create_http_client(self, provider_config: dict) -> httpx.AsyncClient | None:
|
||||
"""创建带代理的 HTTP 客户端"""
|
||||
proxy = provider_config.get("proxy", "")
|
||||
return create_proxy_client("OpenAI", proxy)
|
||||
|
||||
def __init__(self, provider_config, provider_settings) -> None:
|
||||
super().__init__(provider_config, provider_settings)
|
||||
self.chosen_api_key = None
|
||||
@@ -55,6 +65,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
default_headers=self.custom_headers,
|
||||
base_url=provider_config.get("api_base", ""),
|
||||
timeout=self.timeout,
|
||||
http_client=self._create_http_client(provider_config),
|
||||
)
|
||||
else:
|
||||
# Using OpenAI Official API
|
||||
@@ -63,6 +74,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
base_url=provider_config.get("api_base", None),
|
||||
default_headers=self.custom_headers,
|
||||
timeout=self.timeout,
|
||||
http_client=self._create_http_client(provider_config),
|
||||
)
|
||||
|
||||
self.default_params = inspect.signature(
|
||||
@@ -455,12 +467,9 @@ class ProviderOpenAIOfficial(Provider):
|
||||
if "tool" in str(e).lower() and "support" in str(e).lower():
|
||||
logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all")
|
||||
|
||||
if "Connection error." in str(e):
|
||||
proxy = os.environ.get("http_proxy", None)
|
||||
if proxy:
|
||||
logger.error(
|
||||
f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}",
|
||||
)
|
||||
if is_connection_error(e):
|
||||
proxy = self.provider_config.get("proxy", "")
|
||||
log_connection_failure("OpenAI", e, proxy)
|
||||
|
||||
raise e
|
||||
|
||||
@@ -697,3 +706,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
with open(image_url, "rb") as f:
|
||||
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
|
||||
return "data:image/jpeg;base64," + image_bs64
|
||||
|
||||
async def terminate(self):
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import httpx
|
||||
from openai import NOT_GIVEN, AsyncOpenAI
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
from ..entities import ProviderType
|
||||
@@ -29,10 +31,16 @@ class ProviderOpenAITTSAPI(TTSProvider):
|
||||
if isinstance(timeout, str):
|
||||
timeout = int(timeout)
|
||||
|
||||
proxy = provider_config.get("proxy", "")
|
||||
http_client = None
|
||||
if proxy:
|
||||
logger.info(f"[OpenAI TTS] 使用代理: {proxy}")
|
||||
http_client = httpx.AsyncClient(proxy=proxy)
|
||||
self.client = AsyncOpenAI(
|
||||
api_key=self.chosen_api_key,
|
||||
base_url=provider_config.get("api_base"),
|
||||
timeout=timeout,
|
||||
http_client=http_client,
|
||||
)
|
||||
|
||||
self.set_model(provider_config.get("model", ""))
|
||||
@@ -50,3 +58,7 @@ class ProviderOpenAITTSAPI(TTSProvider):
|
||||
async for chunk in response.iter_bytes(chunk_size=1024):
|
||||
f.write(chunk)
|
||||
return path
|
||||
|
||||
async def terminate(self):
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
|
||||
@@ -107,3 +107,7 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to remove temp file {audio_url}: {e}")
|
||||
return result.text
|
||||
|
||||
async def terminate(self):
|
||||
if self.client:
|
||||
await self.client.close()
|
||||
|
||||
@@ -4,6 +4,7 @@ from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from astrbot.api import sp
|
||||
from astrbot.core import db_helper, logger
|
||||
from astrbot.core.db.po import CommandConfig
|
||||
from astrbot.core.star.filter.command import CommandFilter
|
||||
@@ -139,6 +140,51 @@ async def rename_command(
|
||||
return descriptor
|
||||
|
||||
|
||||
async def update_command_permission(
|
||||
handler_full_name: str,
|
||||
permission_type: str,
|
||||
) -> CommandDescriptor:
|
||||
descriptor = _build_descriptor_by_full_name(handler_full_name)
|
||||
if not descriptor:
|
||||
raise ValueError("指定的处理函数不存在或不是指令。")
|
||||
|
||||
if permission_type not in ["admin", "member"]:
|
||||
raise ValueError("权限类型必须为 admin 或 member。")
|
||||
|
||||
handler = descriptor.handler
|
||||
found_plugin = star_map.get(handler.handler_module_path)
|
||||
if not found_plugin:
|
||||
raise ValueError("未找到指令所属插件")
|
||||
|
||||
# 1. Update Persistent Config (alter_cmd)
|
||||
alter_cmd_cfg = await sp.global_get("alter_cmd", {})
|
||||
plugin_ = alter_cmd_cfg.get(found_plugin.name, {})
|
||||
cfg = plugin_.get(handler.handler_name, {})
|
||||
cfg["permission"] = permission_type
|
||||
plugin_[handler.handler_name] = cfg
|
||||
alter_cmd_cfg[found_plugin.name] = plugin_
|
||||
|
||||
await sp.global_put("alter_cmd", alter_cmd_cfg)
|
||||
|
||||
# 2. Update Runtime Filter
|
||||
found_permission_filter = False
|
||||
target_perm_type = (
|
||||
PermissionType.ADMIN if permission_type == "admin" else PermissionType.MEMBER
|
||||
)
|
||||
|
||||
for filter_ in handler.event_filters:
|
||||
if isinstance(filter_, PermissionTypeFilter):
|
||||
filter_.permission_type = target_perm_type
|
||||
found_permission_filter = True
|
||||
break
|
||||
|
||||
if not found_permission_filter:
|
||||
handler.event_filters.insert(0, PermissionTypeFilter(target_perm_type))
|
||||
|
||||
# Re-build descriptor to reflect changes
|
||||
return _build_descriptor(handler) or descriptor
|
||||
|
||||
|
||||
async def list_commands() -> list[dict[str, Any]]:
|
||||
descriptors = _collect_descriptors(include_sub_commands=True)
|
||||
config_records = await db_helper.get_command_configs()
|
||||
|
||||
@@ -37,9 +37,9 @@ class CustomFilter(HandlerFilter, metaclass=CustomFilterMeta):
|
||||
class CustomFilterOr(CustomFilter):
|
||||
def __init__(self, filter1: CustomFilter, filter2: CustomFilter):
|
||||
super().__init__()
|
||||
if not isinstance(filter1, CustomFilter | CustomFilterAnd | CustomFilterOr):
|
||||
if not isinstance(filter1, (CustomFilter, CustomFilterAnd, CustomFilterOr)):
|
||||
raise ValueError(
|
||||
"CustomFilter lass can only operate with other CustomFilter.",
|
||||
"CustomFilter class can only operate with other CustomFilter.",
|
||||
)
|
||||
self.filter1 = filter1
|
||||
self.filter2 = filter2
|
||||
@@ -51,7 +51,7 @@ class CustomFilterOr(CustomFilter):
|
||||
class CustomFilterAnd(CustomFilter):
|
||||
def __init__(self, filter1: CustomFilter, filter2: CustomFilter):
|
||||
super().__init__()
|
||||
if not isinstance(filter1, CustomFilter | CustomFilterAnd | CustomFilterOr):
|
||||
if not isinstance(filter1, (CustomFilter, CustomFilterAnd, CustomFilterOr)):
|
||||
raise ValueError(
|
||||
"CustomFilter lass can only operate with other CustomFilter.",
|
||||
)
|
||||
|
||||
@@ -150,7 +150,7 @@ def register_custom_filter(custom_type_filter, *args, **kwargs):
|
||||
if args:
|
||||
raise_error = args[0]
|
||||
|
||||
if not isinstance(custom_filter, CustomFilterAnd | CustomFilterOr):
|
||||
if not isinstance(custom_filter, (CustomFilterAnd, CustomFilterOr)):
|
||||
custom_filter = custom_filter(raise_error)
|
||||
|
||||
def decorator(awaitable):
|
||||
|
||||
@@ -15,6 +15,7 @@ import yaml
|
||||
from astrbot.core import logger, pip_installer, sp
|
||||
from astrbot.core.agent.handoff import FunctionTool, HandoffTool
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core.platform.register import unregister_platform_adapters_by_module
|
||||
from astrbot.core.provider.register import llm_tools
|
||||
from astrbot.core.utils.astrbot_path import (
|
||||
get_astrbot_config_path,
|
||||
@@ -842,6 +843,18 @@ class PluginManager:
|
||||
for func_tool in to_remove:
|
||||
llm_tools.func_list.remove(func_tool)
|
||||
|
||||
# Unregister platform adapters registered by this plugin
|
||||
# module_path is like "data.plugins.my_plugin.main", extract prefix like "data.plugins.my_plugin"
|
||||
module_prefix = ".".join(plugin_module_path.split(".")[:-1])
|
||||
if module_prefix:
|
||||
unregistered_adapters = unregister_platform_adapters_by_module(
|
||||
module_prefix
|
||||
)
|
||||
for adapter_name in unregistered_adapters:
|
||||
logger.info(
|
||||
f"移除了插件 {plugin_name} 的平台适配器 {adapter_name}",
|
||||
)
|
||||
|
||||
if plugin is None:
|
||||
return
|
||||
|
||||
|
||||
@@ -57,14 +57,20 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
py = sys.executable
|
||||
|
||||
try:
|
||||
if "astrbot" in os.path.basename(sys.argv[0]): # 兼容cli
|
||||
# 仅 CLI 模式走 `python -m astrbot.cli.__main__`,
|
||||
# 打包后的后端可执行文件需要直接 exec 自身。
|
||||
if os.environ.get("ASTRBOT_CLI") == "1":
|
||||
if os.name == "nt":
|
||||
args = [f'"{arg}"' if " " in arg else arg for arg in sys.argv[1:]]
|
||||
else:
|
||||
args = sys.argv[1:]
|
||||
os.execl(sys.executable, py, "-m", "astrbot.cli.__main__", *args)
|
||||
else:
|
||||
os.execl(sys.executable, py, *sys.argv)
|
||||
if getattr(sys, "frozen", False):
|
||||
# Frozen executable should not receive argv[0] as a positional argument.
|
||||
os.execl(sys.executable, py, *sys.argv[1:])
|
||||
else:
|
||||
os.execl(sys.executable, py, *sys.argv)
|
||||
except Exception as e:
|
||||
logger.error(f"重启失败({py}, {e}),请尝试手动重启。")
|
||||
raise e
|
||||
|
||||
@@ -10,6 +10,7 @@ T2I 模板目录路径:固定为数据目录下的 t2i_templates 目录
|
||||
WebChat 数据目录路径:固定为数据目录下的 webchat 目录
|
||||
临时文件目录路径:固定为数据目录下的 temp 目录
|
||||
Skills 目录路径:固定为数据目录下的 skills 目录
|
||||
第三方依赖目录路径:固定为数据目录下的 site-packages 目录
|
||||
"""
|
||||
|
||||
import os
|
||||
@@ -69,6 +70,11 @@ def get_astrbot_skills_path() -> str:
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "skills"))
|
||||
|
||||
|
||||
def get_astrbot_site_packages_path() -> str:
|
||||
"""获取Astrbot第三方依赖目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "site-packages"))
|
||||
|
||||
|
||||
def get_astrbot_knowledge_base_path() -> str:
|
||||
"""获取Astrbot知识库根目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "knowledge_base"))
|
||||
|
||||
@@ -0,0 +1,207 @@
|
||||
"""媒体文件处理工具
|
||||
|
||||
提供音视频格式转换、时长获取等功能。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import subprocess
|
||||
import uuid
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
async def get_media_duration(file_path: str) -> int | None:
|
||||
"""使用ffprobe获取媒体文件时长
|
||||
|
||||
Args:
|
||||
file_path: 媒体文件路径
|
||||
|
||||
Returns:
|
||||
时长(毫秒),如果获取失败返回None
|
||||
"""
|
||||
try:
|
||||
# 使用ffprobe获取时长
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
"ffprobe",
|
||||
"-v",
|
||||
"error",
|
||||
"-show_entries",
|
||||
"format=duration",
|
||||
"-of",
|
||||
"default=noprint_wrappers=1:nokey=1",
|
||||
file_path,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode == 0 and stdout:
|
||||
duration_seconds = float(stdout.decode().strip())
|
||||
duration_ms = int(duration_seconds * 1000)
|
||||
logger.debug(f"[Media Utils] 获取媒体时长: {duration_ms}ms")
|
||||
return duration_ms
|
||||
else:
|
||||
logger.warning(f"[Media Utils] 无法获取媒体文件时长: {file_path}")
|
||||
return None
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.warning(
|
||||
"[Media Utils] ffprobe未安装或不在PATH中,无法获取媒体时长。请安装ffmpeg: https://ffmpeg.org/"
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"[Media Utils] 获取媒体时长时出错: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def convert_audio_to_opus(audio_path: str, output_path: str | None = None) -> str:
|
||||
"""使用ffmpeg将音频转换为opus格式
|
||||
|
||||
Args:
|
||||
audio_path: 原始音频文件路径
|
||||
output_path: 输出文件路径,如果为None则自动生成
|
||||
|
||||
Returns:
|
||||
转换后的opus文件路径
|
||||
|
||||
Raises:
|
||||
Exception: 转换失败时抛出异常
|
||||
"""
|
||||
# 如果已经是opus格式,直接返回
|
||||
if audio_path.lower().endswith(".opus"):
|
||||
return audio_path
|
||||
|
||||
# 生成输出文件路径
|
||||
if output_path is None:
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
output_path = os.path.join(temp_dir, f"{uuid.uuid4()}.opus")
|
||||
|
||||
try:
|
||||
# 使用ffmpeg转换为opus格式
|
||||
# -y: 覆盖输出文件
|
||||
# -i: 输入文件
|
||||
# -acodec libopus: 使用opus编码器
|
||||
# -ac 1: 单声道
|
||||
# -ar 16000: 采样率16kHz
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
audio_path,
|
||||
"-acodec",
|
||||
"libopus",
|
||||
"-ac",
|
||||
"1",
|
||||
"-ar",
|
||||
"16000",
|
||||
output_path,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode != 0:
|
||||
# 清理可能已生成但无效的临时文件
|
||||
if output_path and os.path.exists(output_path):
|
||||
try:
|
||||
os.remove(output_path)
|
||||
logger.debug(
|
||||
f"[Media Utils] 已清理失败的opus输出文件: {output_path}"
|
||||
)
|
||||
except OSError as e:
|
||||
logger.warning(f"[Media Utils] 清理失败的opus输出文件时出错: {e}")
|
||||
|
||||
error_msg = stderr.decode() if stderr else "未知错误"
|
||||
logger.error(f"[Media Utils] ffmpeg转换音频失败: {error_msg}")
|
||||
raise Exception(f"ffmpeg conversion failed: {error_msg}")
|
||||
|
||||
logger.debug(f"[Media Utils] 音频转换成功: {audio_path} -> {output_path}")
|
||||
return output_path
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(
|
||||
"[Media Utils] ffmpeg未安装或不在PATH中,无法转换音频格式。请安装ffmpeg: https://ffmpeg.org/"
|
||||
)
|
||||
raise Exception("ffmpeg not found")
|
||||
except Exception as e:
|
||||
logger.error(f"[Media Utils] 转换音频格式时出错: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def convert_video_format(
|
||||
video_path: str, output_format: str = "mp4", output_path: str | None = None
|
||||
) -> str:
|
||||
"""使用ffmpeg转换视频格式
|
||||
|
||||
Args:
|
||||
video_path: 原始视频文件路径
|
||||
output_format: 目标格式,默认mp4
|
||||
output_path: 输出文件路径,如果为None则自动生成
|
||||
|
||||
Returns:
|
||||
转换后的视频文件路径
|
||||
|
||||
Raises:
|
||||
Exception: 转换失败时抛出异常
|
||||
"""
|
||||
# 如果已经是目标格式,直接返回
|
||||
if video_path.lower().endswith(f".{output_format}"):
|
||||
return video_path
|
||||
|
||||
# 生成输出文件路径
|
||||
if output_path is None:
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
output_path = os.path.join(temp_dir, f"{uuid.uuid4()}.{output_format}")
|
||||
|
||||
try:
|
||||
# 使用ffmpeg转换视频格式
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
"ffmpeg",
|
||||
"-y",
|
||||
"-i",
|
||||
video_path,
|
||||
"-c:v",
|
||||
"libx264",
|
||||
"-c:a",
|
||||
"aac",
|
||||
output_path,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
)
|
||||
|
||||
stdout, stderr = await process.communicate()
|
||||
|
||||
if process.returncode != 0:
|
||||
# 清理可能已生成但无效的临时文件
|
||||
if output_path and os.path.exists(output_path):
|
||||
try:
|
||||
os.remove(output_path)
|
||||
logger.debug(
|
||||
f"[Media Utils] 已清理失败的{output_format}输出文件: {output_path}"
|
||||
)
|
||||
except OSError as e:
|
||||
logger.warning(
|
||||
f"[Media Utils] 清理失败的{output_format}输出文件时出错: {e}"
|
||||
)
|
||||
|
||||
error_msg = stderr.decode() if stderr else "未知错误"
|
||||
logger.error(f"[Media Utils] ffmpeg转换视频失败: {error_msg}")
|
||||
raise Exception(f"ffmpeg conversion failed: {error_msg}")
|
||||
|
||||
logger.debug(f"[Media Utils] 视频转换成功: {video_path} -> {output_path}")
|
||||
return output_path
|
||||
|
||||
except FileNotFoundError:
|
||||
logger.error(
|
||||
"[Media Utils] ffmpeg未安装或不在PATH中,无法转换视频格式。请安装ffmpeg: https://ffmpeg.org/"
|
||||
)
|
||||
raise Exception("ffmpeg not found")
|
||||
except Exception as e:
|
||||
logger.error(f"[Media Utils] 转换视频格式时出错: {e}")
|
||||
raise
|
||||
@@ -0,0 +1,104 @@
|
||||
"""Network error handling utilities for providers."""
|
||||
|
||||
import httpx
|
||||
|
||||
from astrbot import logger
|
||||
|
||||
|
||||
def is_connection_error(exc: BaseException) -> bool:
|
||||
"""Check if an exception is a connection/network related error.
|
||||
|
||||
Uses explicit exception type checking instead of brittle string matching.
|
||||
Handles httpx network errors, timeouts, and common Python network exceptions.
|
||||
|
||||
Args:
|
||||
exc: The exception to check
|
||||
|
||||
Returns:
|
||||
True if the exception is a connection/network error
|
||||
"""
|
||||
# Check for httpx network errors
|
||||
if isinstance(
|
||||
exc,
|
||||
(
|
||||
httpx.ConnectError,
|
||||
httpx.ConnectTimeout,
|
||||
httpx.ReadTimeout,
|
||||
httpx.WriteTimeout,
|
||||
httpx.PoolTimeout,
|
||||
httpx.NetworkError,
|
||||
httpx.ProxyError,
|
||||
httpx.RequestError,
|
||||
),
|
||||
):
|
||||
return True
|
||||
|
||||
# Check for common Python network errors
|
||||
if isinstance(exc, (TimeoutError, OSError, ConnectionError)):
|
||||
return True
|
||||
|
||||
# Check the __cause__ chain for wrapped connection errors
|
||||
cause = getattr(exc, "__cause__", None)
|
||||
if cause is not None and cause is not exc:
|
||||
return is_connection_error(cause)
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def log_connection_failure(
|
||||
provider_label: str,
|
||||
error: Exception,
|
||||
proxy: str | None = None,
|
||||
) -> None:
|
||||
"""Log a connection failure with proxy information.
|
||||
|
||||
If proxy is not provided, will fallback to check os.environ for
|
||||
http_proxy/https_proxy environment variables.
|
||||
|
||||
Args:
|
||||
provider_label: The provider name for log prefix (e.g., "OpenAI", "Gemini")
|
||||
error: The exception that occurred
|
||||
proxy: The proxy address if configured, or None/empty string
|
||||
"""
|
||||
import os
|
||||
|
||||
error_type = type(error).__name__
|
||||
|
||||
# Fallback to environment proxy if not configured
|
||||
effective_proxy = proxy
|
||||
if not effective_proxy:
|
||||
effective_proxy = os.environ.get(
|
||||
"http_proxy", os.environ.get("https_proxy", "")
|
||||
)
|
||||
|
||||
if effective_proxy:
|
||||
logger.error(
|
||||
f"[{provider_label}] 网络/代理连接失败 ({error_type})。"
|
||||
f"代理地址: {effective_proxy},错误: {error}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"[{provider_label}] 网络连接失败 ({error_type}),未配置代理。错误: {error}"
|
||||
)
|
||||
|
||||
|
||||
def create_proxy_client(
|
||||
provider_label: str,
|
||||
proxy: str | None = None,
|
||||
) -> httpx.AsyncClient | None:
|
||||
"""Create an httpx AsyncClient with proxy configuration if provided.
|
||||
|
||||
Note: The caller is responsible for closing the client when done.
|
||||
Consider using the client as a context manager or calling aclose() explicitly.
|
||||
|
||||
Args:
|
||||
provider_label: The provider name for log prefix (e.g., "OpenAI", "Gemini")
|
||||
proxy: The proxy address (e.g., "http://127.0.0.1:7890"), or None/empty
|
||||
|
||||
Returns:
|
||||
An httpx.AsyncClient configured with the proxy, or None if no proxy
|
||||
"""
|
||||
if proxy:
|
||||
logger.info(f"[{provider_label}] 使用代理: {proxy}")
|
||||
return httpx.AsyncClient(proxy=proxy)
|
||||
return None
|
||||
@@ -1,8 +1,14 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import importlib
|
||||
import io
|
||||
import locale
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_site_packages_path
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
|
||||
@@ -24,6 +30,36 @@ def _robust_decode(line: bytes) -> str:
|
||||
return line.decode("utf-8", errors="replace").strip()
|
||||
|
||||
|
||||
def _is_frozen_runtime() -> bool:
|
||||
return bool(getattr(sys, "frozen", False))
|
||||
|
||||
|
||||
def _get_pip_main():
|
||||
try:
|
||||
from pip._internal.cli.main import main as pip_main
|
||||
except ImportError:
|
||||
from pip import main as pip_main
|
||||
return pip_main
|
||||
|
||||
|
||||
def _run_pip_main_with_output(pip_main, args: list[str]) -> tuple[int, str]:
|
||||
stream = io.StringIO()
|
||||
with contextlib.redirect_stdout(stream), contextlib.redirect_stderr(stream):
|
||||
result_code = pip_main(args)
|
||||
return result_code, stream.getvalue()
|
||||
|
||||
|
||||
def _cleanup_added_root_handlers(original_handlers: list[logging.Handler]) -> None:
|
||||
root_logger = logging.getLogger()
|
||||
original_handler_ids = {id(handler) for handler in original_handlers}
|
||||
|
||||
for handler in list(root_logger.handlers):
|
||||
if id(handler) not in original_handler_ids:
|
||||
root_logger.removeHandler(handler)
|
||||
with contextlib.suppress(Exception):
|
||||
handler.close()
|
||||
|
||||
|
||||
class PipInstaller:
|
||||
def __init__(self, pip_install_arg: str, pypi_index_url: str | None = None):
|
||||
self.pip_install_arg = pip_install_arg
|
||||
@@ -45,37 +81,59 @@ class PipInstaller:
|
||||
|
||||
args.extend(["--trusted-host", "mirrors.aliyun.com", "-i", index_url])
|
||||
|
||||
target_site_packages = None
|
||||
if _is_frozen_runtime():
|
||||
target_site_packages = get_astrbot_site_packages_path()
|
||||
os.makedirs(target_site_packages, exist_ok=True)
|
||||
args.extend(["--target", target_site_packages])
|
||||
|
||||
if self.pip_install_arg:
|
||||
args.extend(self.pip_install_arg.split())
|
||||
|
||||
logger.info(f"Pip 包管理器: pip {' '.join(args)}")
|
||||
try:
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
sys.executable,
|
||||
"-m",
|
||||
"pip",
|
||||
*args,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.STDOUT,
|
||||
)
|
||||
result_code = None
|
||||
if _is_frozen_runtime():
|
||||
result_code = await self._run_pip_in_process(args)
|
||||
else:
|
||||
try:
|
||||
result_code = await self._run_pip_subprocess(args)
|
||||
except FileNotFoundError:
|
||||
result_code = await self._run_pip_in_process(args)
|
||||
|
||||
assert process.stdout is not None
|
||||
async for line in process.stdout:
|
||||
logger.info(_robust_decode(line))
|
||||
if result_code != 0:
|
||||
raise Exception(f"安装失败,错误码:{result_code}")
|
||||
|
||||
await process.wait()
|
||||
if target_site_packages and target_site_packages not in sys.path:
|
||||
sys.path.insert(0, target_site_packages)
|
||||
importlib.invalidate_caches()
|
||||
|
||||
if process.returncode != 0:
|
||||
raise Exception(f"安装失败,错误码:{process.returncode}")
|
||||
except FileNotFoundError:
|
||||
# 没有 pip
|
||||
from pip import main as pip_main
|
||||
async def _run_pip_subprocess(self, args: list[str]) -> int:
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
sys.executable,
|
||||
"-m",
|
||||
"pip",
|
||||
*args,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.STDOUT,
|
||||
)
|
||||
|
||||
result_code = await asyncio.to_thread(pip_main, args)
|
||||
assert process.stdout is not None
|
||||
async for line in process.stdout:
|
||||
logger.info(_robust_decode(line))
|
||||
|
||||
# 清除 pip.main 导致的多余的 logging handlers
|
||||
for handler in logging.root.handlers[:]:
|
||||
logging.root.removeHandler(handler)
|
||||
await process.wait()
|
||||
return process.returncode
|
||||
|
||||
if result_code != 0:
|
||||
raise Exception(f"安装失败,错误码:{result_code}")
|
||||
async def _run_pip_in_process(self, args: list[str]) -> int:
|
||||
pip_main = _get_pip_main()
|
||||
original_handlers = list(logging.getLogger().handlers)
|
||||
result_code, output = await asyncio.to_thread(
|
||||
_run_pip_main_with_output, pip_main, args
|
||||
)
|
||||
for line in output.splitlines():
|
||||
line = line.strip()
|
||||
if line:
|
||||
logger.info(line)
|
||||
|
||||
_cleanup_added_root_handlers(original_handlers)
|
||||
return result_code
|
||||
|
||||
@@ -23,7 +23,7 @@ class SharedPreferences:
|
||||
)
|
||||
self.path = json_storage_path
|
||||
self.db_helper = db_helper
|
||||
self.temorary_cache: dict[str, dict[str, Any]] = defaultdict(dict)
|
||||
self.temporary_cache: dict[str, dict[str, Any]] = defaultdict(dict)
|
||||
"""automatically clear per 24 hours. Might be helpful in some cases XD"""
|
||||
|
||||
self._sync_loop = asyncio.new_event_loop()
|
||||
@@ -37,7 +37,7 @@ class SharedPreferences:
|
||||
self._scheduler.start()
|
||||
|
||||
def _clear_temporary_cache(self):
|
||||
self.temorary_cache.clear()
|
||||
self.temporary_cache.clear()
|
||||
|
||||
async def get_async(
|
||||
self,
|
||||
|
||||
@@ -50,6 +50,10 @@ class TraceSpan:
|
||||
self.started_at = time.time()
|
||||
|
||||
def record(self, action: str, **fields: Any) -> None:
|
||||
# Check if trace recording is enabled
|
||||
if not astrbot_config.get("trace_enable", True):
|
||||
return
|
||||
|
||||
payload = {
|
||||
"type": "trace",
|
||||
"level": "TRACE",
|
||||
|
||||
@@ -238,6 +238,7 @@ class ChatRoute(Route):
|
||||
Returns:
|
||||
包含 used 列表的字典,记录被引用的搜索结果
|
||||
"""
|
||||
supported = ["web_search_tavily", "web_search_bocha"]
|
||||
# 从 accumulated_parts 中找到所有 web_search_tavily 的工具调用结果
|
||||
web_search_results = {}
|
||||
tool_call_parts = [
|
||||
@@ -248,7 +249,7 @@ class ChatRoute(Route):
|
||||
|
||||
for part in tool_call_parts:
|
||||
for tool_call in part["tool_calls"]:
|
||||
if tool_call.get("name") != "web_search_tavily" or not tool_call.get(
|
||||
if tool_call.get("name") not in supported or not tool_call.get(
|
||||
"result"
|
||||
):
|
||||
continue
|
||||
@@ -278,7 +279,7 @@ class ChatRoute(Route):
|
||||
if ref_index not in web_search_results:
|
||||
continue
|
||||
payload = {"index": ref_index, **web_search_results[ref_index]}
|
||||
if favicon := sp.temorary_cache.get("_ws_favicon", {}).get(payload["url"]):
|
||||
if favicon := sp.temporary_cache.get("_ws_favicon", {}).get(payload["url"]):
|
||||
payload["favicon"] = favicon
|
||||
used_refs.append(payload)
|
||||
|
||||
@@ -353,12 +354,15 @@ class ChatRoute(Route):
|
||||
return Response().error("session_id is empty").__dict__
|
||||
|
||||
webchat_conv_id = session_id
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(webchat_conv_id)
|
||||
|
||||
# 构建用户消息段(包含 path 用于传递给 adapter)
|
||||
message_parts = await self._build_user_message_parts(message)
|
||||
|
||||
message_id = str(uuid.uuid4())
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(
|
||||
message_id,
|
||||
webchat_conv_id,
|
||||
)
|
||||
|
||||
async def stream():
|
||||
client_disconnected = False
|
||||
@@ -531,6 +535,8 @@ class ChatRoute(Route):
|
||||
refs = {}
|
||||
except BaseException as e:
|
||||
logger.exception(f"WebChat stream unexpected error: {e}", exc_info=True)
|
||||
finally:
|
||||
webchat_queue_mgr.remove_back_queue(message_id)
|
||||
|
||||
# 将消息放入会话特定的队列
|
||||
chat_queue = webchat_queue_mgr.get_or_create_queue(webchat_conv_id)
|
||||
|
||||
@@ -10,6 +10,9 @@ from astrbot.core.star.command_management import (
|
||||
from astrbot.core.star.command_management import (
|
||||
toggle_command as toggle_command_service,
|
||||
)
|
||||
from astrbot.core.star.command_management import (
|
||||
update_command_permission as update_command_permission_service,
|
||||
)
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
@@ -22,6 +25,7 @@ class CommandRoute(Route):
|
||||
"/commands/conflicts": ("GET", self.get_conflicts),
|
||||
"/commands/toggle": ("POST", self.toggle_command),
|
||||
"/commands/rename": ("POST", self.rename_command),
|
||||
"/commands/permission": ("POST", self.update_permission),
|
||||
}
|
||||
self.register_routes()
|
||||
|
||||
@@ -74,6 +78,24 @@ class CommandRoute(Route):
|
||||
payload = await _get_command_payload(handler_full_name)
|
||||
return Response().ok(payload).__dict__
|
||||
|
||||
async def update_permission(self):
|
||||
data = await request.get_json()
|
||||
handler_full_name = data.get("handler_full_name")
|
||||
permission = data.get("permission")
|
||||
|
||||
if not handler_full_name or not permission:
|
||||
return (
|
||||
Response().error("handler_full_name 与 permission 均为必填。").__dict__
|
||||
)
|
||||
|
||||
try:
|
||||
await update_command_permission_service(handler_full_name, permission)
|
||||
except ValueError as exc:
|
||||
return Response().error(str(exc)).__dict__
|
||||
|
||||
payload = await _get_command_payload(handler_full_name)
|
||||
return Response().ok(payload).__dict__
|
||||
|
||||
|
||||
async def _get_command_payload(handler_full_name: str):
|
||||
commands = await list_commands()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import inspect
|
||||
import os
|
||||
import traceback
|
||||
@@ -407,8 +408,19 @@ class ConfigRoute(Route):
|
||||
return Response().ok(message="更新 provider source 成功").__dict__
|
||||
|
||||
async def get_provider_template(self):
|
||||
provider_metadata = ConfigMetadataI18n.convert_to_i18n_keys(
|
||||
{
|
||||
"provider_group": {
|
||||
"metadata": {
|
||||
"provider": CONFIG_METADATA_2["provider_group"]["metadata"][
|
||||
"provider"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
config_schema = {
|
||||
"provider": CONFIG_METADATA_2["provider_group"]["metadata"]["provider"]
|
||||
"provider": provider_metadata["provider_group"]["metadata"]["provider"]
|
||||
}
|
||||
data = {
|
||||
"config_schema": config_schema,
|
||||
@@ -1278,11 +1290,24 @@ class ConfigRoute(Route):
|
||||
|
||||
async def _get_astrbot_config(self):
|
||||
config = self.config
|
||||
metadata = copy.deepcopy(CONFIG_METADATA_2)
|
||||
platform_i18n = ConfigMetadataI18n.convert_to_i18n_keys(
|
||||
{
|
||||
"platform_group": {
|
||||
"metadata": {
|
||||
"platform": metadata["platform_group"]["metadata"]["platform"]
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
metadata["platform_group"]["metadata"]["platform"] = platform_i18n[
|
||||
"platform_group"
|
||||
]["metadata"]["platform"]
|
||||
|
||||
# 平台适配器的默认配置模板注入
|
||||
platform_default_tmpl = CONFIG_METADATA_2["platform_group"]["metadata"][
|
||||
"platform"
|
||||
]["config_template"]
|
||||
platform_default_tmpl = metadata["platform_group"]["metadata"]["platform"][
|
||||
"config_template"
|
||||
]
|
||||
|
||||
# 收集需要注册logo的平台
|
||||
logo_registration_tasks = []
|
||||
@@ -1300,14 +1325,14 @@ class ConfigRoute(Route):
|
||||
await asyncio.gather(*logo_registration_tasks, return_exceptions=True)
|
||||
|
||||
# 服务提供商的默认配置模板注入
|
||||
provider_default_tmpl = CONFIG_METADATA_2["provider_group"]["metadata"][
|
||||
"provider"
|
||||
]["config_template"]
|
||||
provider_default_tmpl = metadata["provider_group"]["metadata"]["provider"][
|
||||
"config_template"
|
||||
]
|
||||
for provider in provider_registry:
|
||||
if provider.default_config_tmpl:
|
||||
provider_default_tmpl[provider.type] = provider.default_config_tmpl
|
||||
|
||||
return {"metadata": CONFIG_METADATA_2, "config": config}
|
||||
return {"metadata": metadata, "config": config}
|
||||
|
||||
async def _get_plugin_config(self, plugin_name: str):
|
||||
ret: dict = {"metadata": None, "config": None}
|
||||
|
||||
@@ -23,7 +23,7 @@ class CronRoute(Route):
|
||||
]
|
||||
self.register_routes()
|
||||
|
||||
def _serialize_job(self, job):
|
||||
def _serialize_job(self, job) -> dict:
|
||||
data = job.model_dump() if hasattr(job, "model_dump") else job.__dict__
|
||||
for k in ["created_at", "updated_at", "last_run_at", "next_run_time"]:
|
||||
if isinstance(data.get(k), datetime):
|
||||
|
||||
@@ -4,6 +4,7 @@ import asyncio
|
||||
import os
|
||||
import traceback
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import aiofiles
|
||||
from quart import request
|
||||
@@ -75,7 +76,7 @@ class KnowledgeBaseRoute(Route):
|
||||
}
|
||||
|
||||
def _set_task_result(
|
||||
self, task_id: str, status: str, result: any = None, error: str | None = None
|
||||
self, task_id: str, status: str, result: Any = None, error: str | None = None
|
||||
) -> None:
|
||||
self.upload_tasks[task_id] = {
|
||||
"status": status,
|
||||
|
||||
@@ -256,143 +256,148 @@ class LiveChatRoute(Route):
|
||||
await queue.put((session.username, cid, payload))
|
||||
|
||||
# 3. 等待响应并流式发送 TTS 音频
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(cid)
|
||||
back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, cid)
|
||||
|
||||
bot_text = ""
|
||||
audio_playing = False
|
||||
|
||||
while True:
|
||||
if session.should_interrupt:
|
||||
# 用户打断,停止处理
|
||||
logger.info("[Live Chat] 检测到用户打断")
|
||||
await websocket.send_json({"t": "stop_play"})
|
||||
# 保存消息并标记为被打断
|
||||
await self._save_interrupted_message(session, user_text, bot_text)
|
||||
# 清空队列中未处理的消息
|
||||
while not back_queue.empty():
|
||||
try:
|
||||
while True:
|
||||
if session.should_interrupt:
|
||||
# 用户打断,停止处理
|
||||
logger.info("[Live Chat] 检测到用户打断")
|
||||
await websocket.send_json({"t": "stop_play"})
|
||||
# 保存消息并标记为被打断
|
||||
await self._save_interrupted_message(
|
||||
session, user_text, bot_text
|
||||
)
|
||||
# 清空队列中未处理的消息
|
||||
while not back_queue.empty():
|
||||
try:
|
||||
back_queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
break
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(back_queue.get(), timeout=0.5)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
|
||||
if not result:
|
||||
continue
|
||||
|
||||
result_message_id = result.get("message_id")
|
||||
if result_message_id != message_id:
|
||||
logger.warning(
|
||||
f"[Live Chat] 消息 ID 不匹配: {result_message_id} != {message_id}"
|
||||
)
|
||||
continue
|
||||
|
||||
result_type = result.get("type")
|
||||
result_chain_type = result.get("chain_type")
|
||||
data = result.get("data", "")
|
||||
|
||||
if result_chain_type == "agent_stats":
|
||||
try:
|
||||
back_queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
break
|
||||
stats = json.loads(data)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": {
|
||||
"llm_ttft": stats.get("time_to_first_token", 0),
|
||||
"llm_total_time": stats.get("end_time", 0)
|
||||
- stats.get("start_time", 0),
|
||||
},
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Chat] 解析 AgentStats 失败: {e}")
|
||||
continue
|
||||
|
||||
try:
|
||||
result = await asyncio.wait_for(back_queue.get(), timeout=0.5)
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
if result_chain_type == "tts_stats":
|
||||
try:
|
||||
stats = json.loads(data)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": stats,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Chat] 解析 TTSStats 失败: {e}")
|
||||
continue
|
||||
|
||||
if not result:
|
||||
continue
|
||||
if result_type == "plain":
|
||||
# 普通文本消息
|
||||
bot_text += data
|
||||
|
||||
result_message_id = result.get("message_id")
|
||||
if result_message_id != message_id:
|
||||
logger.warning(
|
||||
f"[Live Chat] 消息 ID 不匹配: {result_message_id} != {message_id}"
|
||||
)
|
||||
continue
|
||||
elif result_type == "audio_chunk":
|
||||
# 流式音频数据
|
||||
if not audio_playing:
|
||||
audio_playing = True
|
||||
logger.debug("[Live Chat] 开始播放音频流")
|
||||
|
||||
result_type = result.get("type")
|
||||
result_chain_type = result.get("chain_type")
|
||||
data = result.get("data", "")
|
||||
# Calculate latency from wav assembly finish to first audio chunk
|
||||
speak_to_first_frame_latency = (
|
||||
time.time() - wav_assembly_finish_time
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": {
|
||||
"speak_to_first_frame": speak_to_first_frame_latency
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
if result_chain_type == "agent_stats":
|
||||
try:
|
||||
stats = json.loads(data)
|
||||
text = result.get("text")
|
||||
if text:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "bot_text_chunk",
|
||||
"data": {"text": text},
|
||||
}
|
||||
)
|
||||
|
||||
# 发送音频数据给前端
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "response",
|
||||
"data": data, # base64 编码的音频数据
|
||||
}
|
||||
)
|
||||
|
||||
elif result_type in ["complete", "end"]:
|
||||
# 处理完成
|
||||
logger.info(f"[Live Chat] Bot 回复完成: {bot_text}")
|
||||
|
||||
# 如果没有音频流,发送 bot 消息文本
|
||||
if not audio_playing:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "bot_msg",
|
||||
"data": {
|
||||
"text": bot_text,
|
||||
"ts": int(time.time() * 1000),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# 发送结束标记
|
||||
await websocket.send_json({"t": "end"})
|
||||
|
||||
# 发送总耗时
|
||||
wav_to_tts_duration = time.time() - wav_assembly_finish_time
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": {
|
||||
"llm_ttft": stats.get("time_to_first_token", 0),
|
||||
"llm_total_time": stats.get("end_time", 0)
|
||||
- stats.get("start_time", 0),
|
||||
},
|
||||
"data": {"wav_to_tts_total_time": wav_to_tts_duration},
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Chat] 解析 AgentStats 失败: {e}")
|
||||
continue
|
||||
|
||||
if result_chain_type == "tts_stats":
|
||||
try:
|
||||
stats = json.loads(data)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": stats,
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Chat] 解析 TTSStats 失败: {e}")
|
||||
continue
|
||||
|
||||
if result_type == "plain":
|
||||
# 普通文本消息
|
||||
bot_text += data
|
||||
|
||||
elif result_type == "audio_chunk":
|
||||
# 流式音频数据
|
||||
if not audio_playing:
|
||||
audio_playing = True
|
||||
logger.debug("[Live Chat] 开始播放音频流")
|
||||
|
||||
# Calculate latency from wav assembly finish to first audio chunk
|
||||
speak_to_first_frame_latency = (
|
||||
time.time() - wav_assembly_finish_time
|
||||
)
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": {
|
||||
"speak_to_first_frame": speak_to_first_frame_latency
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
text = result.get("text")
|
||||
if text:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "bot_text_chunk",
|
||||
"data": {"text": text},
|
||||
}
|
||||
)
|
||||
|
||||
# 发送音频数据给前端
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "response",
|
||||
"data": data, # base64 编码的音频数据
|
||||
}
|
||||
)
|
||||
|
||||
elif result_type in ["complete", "end"]:
|
||||
# 处理完成
|
||||
logger.info(f"[Live Chat] Bot 回复完成: {bot_text}")
|
||||
|
||||
# 如果没有音频流,发送 bot 消息文本
|
||||
if not audio_playing:
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "bot_msg",
|
||||
"data": {
|
||||
"text": bot_text,
|
||||
"ts": int(time.time() * 1000),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
# 发送结束标记
|
||||
await websocket.send_json({"t": "end"})
|
||||
|
||||
# 发送总耗时
|
||||
wav_to_tts_duration = time.time() - wav_assembly_finish_time
|
||||
await websocket.send_json(
|
||||
{
|
||||
"t": "metrics",
|
||||
"data": {"wav_to_tts_total_time": wav_to_tts_duration},
|
||||
}
|
||||
)
|
||||
break
|
||||
break
|
||||
finally:
|
||||
webchat_queue_mgr.remove_back_queue(message_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[Live Chat] 处理音频失败: {e}", exc_info=True)
|
||||
|
||||
@@ -31,6 +31,16 @@ class LogRoute(Route):
|
||||
view_func=self.log_history,
|
||||
methods=["GET"],
|
||||
)
|
||||
self.app.add_url_rule(
|
||||
"/api/trace/settings",
|
||||
view_func=self.get_trace_settings,
|
||||
methods=["GET"],
|
||||
)
|
||||
self.app.add_url_rule(
|
||||
"/api/trace/settings",
|
||||
view_func=self.update_trace_settings,
|
||||
methods=["POST"],
|
||||
)
|
||||
|
||||
async def _replay_cached_logs(
|
||||
self, last_event_id: str
|
||||
@@ -106,3 +116,29 @@ class LogRoute(Route):
|
||||
except Exception as e:
|
||||
logger.error(f"获取日志历史失败: {e}")
|
||||
return Response().error(f"获取日志历史失败: {e}").__dict__
|
||||
|
||||
async def get_trace_settings(self):
|
||||
"""获取 Trace 设置"""
|
||||
try:
|
||||
trace_enable = self.config.get("trace_enable", True)
|
||||
return Response().ok(data={"trace_enable": trace_enable}).__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"获取 Trace 设置失败: {e}")
|
||||
return Response().error(f"获取 Trace 设置失败: {e}").__dict__
|
||||
|
||||
async def update_trace_settings(self):
|
||||
"""更新 Trace 设置"""
|
||||
try:
|
||||
data = await request.json
|
||||
if data is None:
|
||||
return Response().error("请求数据为空").__dict__
|
||||
|
||||
trace_enable = data.get("trace_enable")
|
||||
if trace_enable is not None:
|
||||
self.config["trace_enable"] = bool(trace_enable)
|
||||
self.config.save_config()
|
||||
|
||||
return Response().ok(message="Trace 设置已更新").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"更新 Trace 设置失败: {e}")
|
||||
return Response().error(f"更新 Trace 设置失败: {e}").__dict__
|
||||
|
||||
@@ -315,6 +315,17 @@ class PluginRoute(Route):
|
||||
"display_name": plugin.display_name,
|
||||
"logo": f"/api/file/{logo_url}" if logo_url else None,
|
||||
}
|
||||
# 检查是否为全空的幽灵插件
|
||||
if not any(
|
||||
[
|
||||
plugin.name,
|
||||
plugin.author,
|
||||
plugin.desc,
|
||||
plugin.version,
|
||||
plugin.display_name,
|
||||
]
|
||||
):
|
||||
continue
|
||||
_plugin_resp.append(_t)
|
||||
return (
|
||||
Response()
|
||||
|
||||
@@ -24,14 +24,22 @@ class SkillsRoute(Route):
|
||||
|
||||
async def get_skills(self):
|
||||
try:
|
||||
cfg = self.core_lifecycle.astrbot_config.get("provider_settings", {}).get(
|
||||
"skills", {}
|
||||
provider_settings = self.core_lifecycle.astrbot_config.get(
|
||||
"provider_settings", {}
|
||||
)
|
||||
runtime = cfg.get("runtime", "local")
|
||||
runtime = provider_settings.get("computer_use_runtime", "local")
|
||||
skills = SkillManager().list_skills(
|
||||
active_only=False, runtime=runtime, show_sandbox_path=False
|
||||
)
|
||||
return Response().ok([skill.__dict__ for skill in skills]).__dict__
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"skills": [skill.__dict__ for skill in skills],
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
@@ -2,14 +2,13 @@ import asyncio
|
||||
import logging
|
||||
import os
|
||||
import socket
|
||||
from typing import cast
|
||||
from typing import Protocol, cast
|
||||
|
||||
import jwt
|
||||
import psutil
|
||||
from flask.json.provider import DefaultJSONProvider
|
||||
from hypercorn.asyncio import serve
|
||||
from hypercorn.config import Config as HyperConfig
|
||||
from psutil._common import addr as psutil_addr
|
||||
from quart import Quart, g, jsonify, request
|
||||
from quart.logging import default_handler
|
||||
|
||||
@@ -29,6 +28,11 @@ from .routes.session_management import SessionManagementRoute
|
||||
from .routes.subagent import SubAgentRoute
|
||||
from .routes.t2i import T2iRoute
|
||||
|
||||
|
||||
class _AddrWithPort(Protocol):
|
||||
port: int
|
||||
|
||||
|
||||
APP: Quart
|
||||
|
||||
|
||||
@@ -168,7 +172,7 @@ class AstrBotDashboard:
|
||||
"""获取占用端口的进程详细信息"""
|
||||
try:
|
||||
for conn in psutil.net_connections(kind="inet"):
|
||||
if cast(psutil_addr, conn.laddr).port == port:
|
||||
if cast(_AddrWithPort, conn.laddr).port == port:
|
||||
try:
|
||||
process = psutil.Process(conn.pid)
|
||||
# 获取详细信息
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
## What's Changed - BIG AND BEAUTIFUL VERSION
|
||||
|
||||
> 如果在之前版本使用了 Skill,这次更新之后**需要重新配置** Skill Runtime 相关选项。
|
||||
|
||||
### 新增
|
||||
- 🔥 新增未来任务系统(Future Tasks)。给 AstrBot 布置的未来任务,让 AstrBot 能够在某一时刻自动唤醒,帮你完成任务。详见 [主动任务](https://docs.astrbot.app/use/proactive-agent.html) 。(实验性) ([#4697](https://github.com/AstrBotDevs/AstrBot/issues/4831))
|
||||
- 🔥 新增子代理(SubAgent)编排器。(实验性)([#4697](https://github.com/AstrBotDevs/AstrBot/issues/4831))
|
||||
- 🔥 AstrBot 目前可以直接通过调用 tool 将图片 / 文件推送给用户,大大提高交互效果。
|
||||
- 新增 Computer Use 运行时配置,以融合 Skill 和 Sandbox 配置 ([#4831](https://github.com/AstrBotDevs/AstrBot/issues/4831))
|
||||
- 新增主题自定义功能,可设置主色与辅色
|
||||
- 支持在配置页下人格对话框的编辑人格 ([#4826](https://github.com/AstrBotDevs/AstrBot/issues/4826))
|
||||
- 支持开关 “追踪” 功能;支持在系统配置中设置是否将日志写入 log 文件 ([#4822](https://github.com/AstrBotDevs/AstrBot/issues/4822))
|
||||
|
||||
### 修复
|
||||
- ‼️ 修复 ChatUI 图片、思考等显示异常问题。
|
||||
- ‼️ 修复 Skill 上传到 Sandbox 后未自动解压导致 Agent 无法读取的问题。
|
||||
- ‼️ 修复配置特定插件集时 MCP 工具被过滤的问题 ([#4825](https://github.com/AstrBotDevs/AstrBot/issues/4825))
|
||||
- ‼️ 移除 ChatUI 自带的让 LLM 最后提出问题的 prompt ([#4824](https://github.com/AstrBotDevs/AstrBot/issues/4824))
|
||||
- ‼️ 修复 WebUI 在上传 Skill 失败后仍显示成功消息的 bug ([#4768](https://github.com/AstrBotDevs/AstrBot/issues/4768))
|
||||
- 修复 MCP 服务器无法重命名的问题 ([#4766](https://github.com/AstrBotDevs/AstrBot/issues/4766))
|
||||
- 修复插件的 tool 无法在 WebUI 管理行为中看到来源的问题 ([#4776](https://github.com/AstrBotDevs/AstrBot/issues/4776))
|
||||
- ‼️ 修复 skill-like 的 tool 模式下,调用 tool 失败的问题 ([#4775](https://github.com/AstrBotDevs/AstrBot/issues/4775))
|
||||
|
||||
### 优化
|
||||
|
||||
- WebUI 整体 UI 效果优化
|
||||
- 部分 Dialog 标题样式统一
|
||||
|
||||
## What's Changed (EN)
|
||||
|
||||
### New Features
|
||||
- Introduce CronJob system with one-time tasks and enhanced dashboard management
|
||||
- Add theme customization with primary & secondary color options
|
||||
- Add computer-use runtime config for skills sandbox execution ([#4831](https://github.com/AstrBotDevs/AstrBot/issues/4831))
|
||||
- Add edit button to persona selector dialog ([#4826](https://github.com/AstrBotDevs/AstrBot/issues/4826))
|
||||
- Add trace logging toggle and configuration UI ([#4822](https://github.com/AstrBotDevs/AstrBot/issues/4822))
|
||||
- Add proactive-messaging capability with cron-tool trigger
|
||||
- Implement SubAgent orchestrator with configurable tool-management policies
|
||||
- Support resolving sandbox file paths and auto-download when necessary
|
||||
- Add embedded image & audio styles in MessagePartsRenderer
|
||||
- Introduce i18n foundation
|
||||
- Persist agent-interaction history
|
||||
- Add user notifications for file-download success/removal
|
||||
|
||||
### Bug Fixes
|
||||
- Improve ghost-plugin detection accuracy
|
||||
- Add error handling to prevent ghost-plugin crashes
|
||||
- Prevent skills bundle from overwriting existing files
|
||||
- Fix skills bundle unzip failure inside sandbox
|
||||
- Fix MCP tools being filtered when specific plugin set configured ([#4825](https://github.com/AstrBotDevs/AstrBot/issues/4825))
|
||||
- Merge ChatUI persona pop-up into default persona ([#4824](https://github.com/AstrBotDevs/AstrBot/issues/4824))
|
||||
- Fix reasoning block style
|
||||
- Add missing comma in truncate_and_compress hint
|
||||
- Fix frontend still showing success message ([#4768](https://github.com/AstrBotDevs/AstrBot/issues/4768))
|
||||
- Fix unable to rename MCP server ([#4766](https://github.com/AstrBotDevs/AstrBot/issues/4766))
|
||||
- Remove leftover sandbox runtime handling in skill upload ([#4798](https://github.com/AstrBotDevs/AstrBot/issues/4798))
|
||||
- Fix handler module path construction ([#4776](https://github.com/AstrBotDevs/AstrBot/issues/4776))
|
||||
- Fix skill-like tool invocation error ([#4775](https://github.com/AstrBotDevs/AstrBot/issues/4775))
|
||||
|
||||
### Improvements
|
||||
- Runtime hints & refined UI in skills management
|
||||
- Performance and UX improvements on cron-job page
|
||||
- General WebUI performance boost
|
||||
- Group tools by plugin in dropdown
|
||||
- Consistent dialog titles with padding and text styles
|
||||
- Code formatting unified (ruff format)
|
||||
- Bump version to 4.13.2
|
||||
|
||||
### Others
|
||||
- Remove obsolete reminder code
|
||||
- Extract main-agent module for better architecture
|
||||
- Merge AstrBot_skill branch changes
|
||||
@@ -0,0 +1,7 @@
|
||||
## What's Changed - BIG AND BEAUTIFUL VERSION
|
||||
|
||||
hotfix of v4.14.0
|
||||
|
||||
fixes:
|
||||
|
||||
- 由 `event.request_llm()` 过时导致的群聊上下文感知-主动回复功能可能不可用的问题
|
||||
@@ -0,0 +1,23 @@
|
||||
## What's Changed
|
||||
|
||||
### 新增
|
||||
- 控制台页面新增调试提示和本地化文件 ([#4852](https://github.com/AstrBotDevs/AstrBot/pull/4852))
|
||||
|
||||
### 修复
|
||||
- 修复插件热重载时平台适配器未清理导致注册冲突的问题 ([#4859](https://github.com/AstrBotDevs/AstrBot/pull/4859))
|
||||
|
||||
### 其他
|
||||
- 更新 ruff 版本至 0.15.0
|
||||
- 新增 robots.txt ([#4847](https://github.com/AstrBotDevs/AstrBot/pull/4847))
|
||||
|
||||
## What's Changed (EN)
|
||||
|
||||
### New Features
|
||||
- Add debug hint to console page and localization files ([#4852](https://github.com/AstrBotDevs/AstrBot/pull/4852))
|
||||
|
||||
### Bug Fixes
|
||||
- Fix platform adapter not being cleaned up during plugin hot reload, causing registration conflicts ([#4859](https://github.com/AstrBotDevs/AstrBot/pull/4859))
|
||||
|
||||
### Others
|
||||
- Update ruff version to 0.15.0
|
||||
- Add robots.txt ([#4847](https://github.com/AstrBotDevs/AstrBot/pull/4847))
|
||||
@@ -0,0 +1,4 @@
|
||||
## What's Changed
|
||||
|
||||
### 修复
|
||||
- 修复 `on_llm_request` 钩子可能无法应用效果的问题
|
||||
@@ -0,0 +1,4 @@
|
||||
## What's Changed
|
||||
|
||||
### 修复
|
||||
- 修复 token 统计错误的问题,修复在多轮 tool call 情况下或者其他极端情况下可能造成 tool 无限调用的问题。
|
||||
@@ -0,0 +1,11 @@
|
||||
## What's Changed
|
||||
|
||||
### Fix
|
||||
- fix: `fix: messages[x] assistant content must contain at least one part` after tool calling ([#4928](https://github.com/AstrBotDevs/AstrBot/issues/4928)) after tool calls.
|
||||
- fix: TypeError when MCP schema type is a list ([#4867](https://github.com/AstrBotDevs/AstrBot/issues/4867))
|
||||
- fix: Fixed an issue that caused scheduled task execution failures with specific providers 修复特定提供商导致的定时任务执行失败的问题 ([#4872](https://github.com/AstrBotDevs/AstrBot/issues/4872))
|
||||
|
||||
|
||||
### Feature
|
||||
- feat: add bocha web search tool ([#4902](https://github.com/AstrBotDevs/AstrBot/issues/4902))
|
||||
- feat: systemd support ([#4880](https://github.com/AstrBotDevs/AstrBot/issues/4880))
|
||||
@@ -0,0 +1,10 @@
|
||||
## What's Changed
|
||||
|
||||
### 修复
|
||||
- 修复一些原因导致 Tavily WebSearch、Bocha WebSearch 无法使用的问题
|
||||
|
||||
### xinzeng
|
||||
- 飞书支持 Bot 发送文件、图片和视频消息类型。
|
||||
|
||||
### 优化
|
||||
- 优化 WebChat 和 企业微信 AI 会话队列生命周期管理,减少内存泄漏,提高性能。
|
||||
@@ -0,0 +1,31 @@
|
||||
## What's Changed
|
||||
|
||||
### 修复
|
||||
- 人格预设对话可能会重复添加到上下文 ([#4961](https://github.com/AstrBotDevs/AstrBot/issues/4961))
|
||||
|
||||
### 新增
|
||||
- 增加提供商级别的代理支持 ([#4949](https://github.com/AstrBotDevs/AstrBot/issues/4949))
|
||||
- WebUI 管理行为增加插件指令权限管理功能 ([#4887](https://github.com/AstrBotDevs/AstrBot/issues/4887))
|
||||
- 允许 LLM 预览工具返回的图片并自主决定是否发送 ([#4895](https://github.com/AstrBotDevs/AstrBot/issues/4895))
|
||||
- Telegram 平台添加媒体组(相册)支持 ([#4893](https://github.com/AstrBotDevs/AstrBot/issues/4893))
|
||||
- 增加欢迎功能,支持本地化内容和新手引导步骤
|
||||
- 支持 Electron 桌面应用部署 ([#4952](https://github.com/AstrBotDevs/AstrBot/issues/4952))
|
||||
|
||||
### 注意
|
||||
- 更新 AstrBot Python 版本要求至 3.12 ([#4963](https://github.com/AstrBotDevs/AstrBot/issues/4963))
|
||||
|
||||
## What's Changed
|
||||
|
||||
### Fixes
|
||||
- Fixed issue where persona preset conversations could be duplicated in context ([#4961](https://github.com/AstrBotDevs/AstrBot/issues/4961))
|
||||
|
||||
### Features
|
||||
- Added provider-level proxy support ([#4949](https://github.com/AstrBotDevs/AstrBot/issues/4949))
|
||||
- Added plugin command permission management to WebUI management behavior ([#4887](https://github.com/AstrBotDevs/AstrBot/issues/4887))
|
||||
- Allowed LLMs to preview images returned by tools and autonomously decide whether to send them ([#4895](https://github.com/AstrBotDevs/AstrBot/issues/4895))
|
||||
- Added media group (album) support for Telegram platform ([#4893](https://github.com/AstrBotDevs/AstrBot/issues/4893))
|
||||
- Added welcome feature with support for localized content and onboarding steps
|
||||
- Supported Electron desktop application deployment ([#4952](https://github.com/AstrBotDevs/AstrBot/issues/4952))
|
||||
|
||||
### Notice
|
||||
- Updated AstrBot Python version requirement to 3.12 ([#4963](https://github.com/AstrBotDevs/AstrBot/issues/4963))
|
||||
@@ -6,6 +6,7 @@
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<meta name="keywords" content="AstrBot Soulter" />
|
||||
<meta name="description" content="AstrBot Dashboard" />
|
||||
<meta name="robots" content="noindex, nofollow" />
|
||||
<link
|
||||
rel="stylesheet"
|
||||
href="https://fonts.googleapis.com/css2?family=Outfit&family=Poppins:wght@400;500;600;700&family=Roboto:wght@400;500;700&display=swap"
|
||||
|
||||
@@ -30,7 +30,7 @@
|
||||
"markdown-it": "^14.1.0",
|
||||
"markstream-vue": "^0.0.6",
|
||||
"mermaid": "^11.12.2",
|
||||
"monaco-editor": "^0.55.1",
|
||||
"monaco-editor": "^0.52.2",
|
||||
"pinia": "2.1.6",
|
||||
"pinyin-pro": "^3.26.0",
|
||||
"remixicon": "3.5.0",
|
||||
|
||||
Generated
+5491
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,2 @@
|
||||
User-agent: *
|
||||
Disallow: /
|
||||
@@ -3,8 +3,7 @@
|
||||
<v-container fluid class="pa-0" elevation="0">
|
||||
<v-row class="d-flex justify-space-between align-center px-4 py-3 pb-8">
|
||||
<div>
|
||||
<v-btn color="success" prepend-icon="mdi-upload" class="me-2" variant="tonal"
|
||||
@click="uploadDialog = true">
|
||||
<v-btn color="success" prepend-icon="mdi-upload" class="me-2" variant="tonal" @click="uploadDialog = true">
|
||||
{{ tm('skills.upload') }}
|
||||
</v-btn>
|
||||
<v-btn color="primary" prepend-icon="mdi-refresh" variant="tonal" @click="fetchSkills">
|
||||
@@ -13,6 +12,10 @@
|
||||
</div>
|
||||
</v-row>
|
||||
|
||||
<div class="px-2 pb-2">
|
||||
<small style="color: grey;">{{ tm('skills.runtimeHint') }}</small>
|
||||
</div>
|
||||
|
||||
<v-progress-linear v-if="loading" indeterminate color="primary"></v-progress-linear>
|
||||
|
||||
<div v-else-if="skills.length === 0" class="text-center pa-8">
|
||||
@@ -40,13 +43,13 @@
|
||||
</v-row>
|
||||
</v-container>
|
||||
|
||||
<v-dialog v-model="uploadDialog" max-width="520px" persistent>
|
||||
<v-dialog v-model="uploadDialog" max-width="520px">
|
||||
<v-card>
|
||||
<v-card-title class="text-h3 pa-4 pb-0 pl-6">{{ tm('skills.uploadDialogTitle') }}</v-card-title>
|
||||
<v-card-text>
|
||||
<small class="text-grey">{{ tm('skills.uploadHint') }}</small>
|
||||
<v-file-input v-model="uploadFile" accept=".zip" :label="tm('skills.selectFile')" prepend-icon="mdi-folder-zip-outline"
|
||||
variant="outlined" class="mt-4" :multiple="false" />
|
||||
<v-file-input v-model="uploadFile" accept=".zip" :label="tm('skills.selectFile')"
|
||||
prepend-icon="mdi-folder-zip-outline" variant="outlined" class="mt-4" :multiple="false" />
|
||||
</v-card-text>
|
||||
<v-card-actions class="d-flex justify-end">
|
||||
<v-btn variant="text" @click="uploadDialog = false">{{ tm('skills.cancel') }}</v-btn>
|
||||
@@ -110,7 +113,12 @@ export default {
|
||||
loading.value = true;
|
||||
try {
|
||||
const res = await axios.get("/api/skills");
|
||||
skills.value = res.data.data || [];
|
||||
const payload = res.data?.data || [];
|
||||
if (Array.isArray(payload)) {
|
||||
skills.value = payload;
|
||||
} else {
|
||||
skills.value = payload.skills || [];
|
||||
}
|
||||
} catch (err) {
|
||||
showMessage(tm("skills.loadFailed"), "error");
|
||||
} finally {
|
||||
|
||||
@@ -18,6 +18,7 @@ const emit = defineEmits<{
|
||||
(e: 'toggle-command', cmd: CommandItem): void;
|
||||
(e: 'rename', cmd: CommandItem): void;
|
||||
(e: 'view-details', cmd: CommandItem): void;
|
||||
(e: 'update-permission', cmd: CommandItem, permission: 'admin' | 'member'): void;
|
||||
}>();
|
||||
|
||||
// 表格表头
|
||||
@@ -146,9 +147,36 @@ const getRowProps = ({ item }: { item: CommandItem }) => {
|
||||
</template>
|
||||
|
||||
<template v-slot:item.permission="{ item }">
|
||||
<v-chip :color="getPermissionColor(item.permission)" size="small" class="font-weight-medium">
|
||||
{{ getPermissionLabel(item.permission) }}
|
||||
</v-chip>
|
||||
<v-menu location="bottom">
|
||||
<template v-slot:activator="{ props }">
|
||||
<v-chip
|
||||
v-bind="props"
|
||||
:color="getPermissionColor(item.permission)"
|
||||
size="small"
|
||||
class="font-weight-medium cursor-pointer"
|
||||
link
|
||||
>
|
||||
{{ getPermissionLabel(item.permission) }}
|
||||
<v-icon end size="14">mdi-chevron-down</v-icon>
|
||||
</v-chip>
|
||||
</template>
|
||||
<v-list density="compact">
|
||||
<v-list-item
|
||||
:value="'member'"
|
||||
@click="$emit('update-permission', item, 'member')"
|
||||
:active="item.permission !== 'admin'"
|
||||
>
|
||||
<v-list-item-title>{{ tm('permission.everyone') }}</v-list-item-title>
|
||||
</v-list-item>
|
||||
<v-list-item
|
||||
:value="'admin'"
|
||||
@click="$emit('update-permission', item, 'admin')"
|
||||
:active="item.permission === 'admin'"
|
||||
>
|
||||
<v-list-item-title>{{ tm('permission.admin') }}</v-list-item-title>
|
||||
</v-list-item>
|
||||
</v-list>
|
||||
</v-menu>
|
||||
</template>
|
||||
|
||||
<template v-slot:item.enabled="{ item }">
|
||||
@@ -253,5 +281,9 @@ code.sub-command-code {
|
||||
.v-data-table .sub-command-row:hover {
|
||||
background-color: rgba(var(--v-theme-info), 0.08) !important;
|
||||
}
|
||||
|
||||
.cursor-pointer {
|
||||
cursor: pointer;
|
||||
}
|
||||
</style>
|
||||
|
||||
|
||||
@@ -28,8 +28,8 @@ export function useCommandActions(
|
||||
* 切换指令启用/禁用状态
|
||||
*/
|
||||
const toggleCommand = async (
|
||||
cmd: CommandItem,
|
||||
successMessage: string,
|
||||
cmd: CommandItem,
|
||||
successMessage: string,
|
||||
errorMessage: string
|
||||
) => {
|
||||
try {
|
||||
@@ -131,7 +131,7 @@ export function useCommandActions(
|
||||
* 获取状态显示信息
|
||||
*/
|
||||
const getStatusInfo = (
|
||||
cmd: CommandItem,
|
||||
cmd: CommandItem,
|
||||
translations: { conflict: string; enabled: string; disabled: string }
|
||||
): StatusInfo => {
|
||||
if (cmd.has_conflict) {
|
||||
@@ -160,13 +160,39 @@ export function useCommandActions(
|
||||
return classes.length > 0 ? { class: classes.join(' ') } : {};
|
||||
};
|
||||
|
||||
/**
|
||||
* 更新指令权限
|
||||
*/
|
||||
const updatePermission = async (
|
||||
cmd: CommandItem,
|
||||
permission: 'admin' | 'member',
|
||||
successMessage: string,
|
||||
errorMessage: string
|
||||
) => {
|
||||
try {
|
||||
const res = await axios.post('/api/commands/permission', {
|
||||
handler_full_name: cmd.handler_full_name,
|
||||
permission: permission
|
||||
});
|
||||
if (res.data.status === 'ok') {
|
||||
toast(successMessage, 'success');
|
||||
await fetchCommands();
|
||||
} else {
|
||||
toast(res.data.message || errorMessage, 'error');
|
||||
}
|
||||
} catch (err: any) {
|
||||
toast(err?.message || errorMessage, 'error');
|
||||
}
|
||||
};
|
||||
|
||||
return {
|
||||
// 状态
|
||||
renameDialog,
|
||||
detailsDialog,
|
||||
|
||||
|
||||
// 方法
|
||||
toggleCommand,
|
||||
updatePermission,
|
||||
openRenameDialog,
|
||||
confirmRename,
|
||||
openDetailsDialog,
|
||||
|
||||
@@ -76,6 +76,7 @@ const {
|
||||
renameDialog,
|
||||
detailsDialog,
|
||||
toggleCommand,
|
||||
updatePermission,
|
||||
openRenameDialog,
|
||||
confirmRename,
|
||||
openDetailsDialog
|
||||
@@ -95,6 +96,10 @@ const handleToggleCommand = async (cmd: CommandItem) => {
|
||||
await toggleCommand(cmd, tm('messages.toggleSuccess'), tm('messages.toggleFailed'));
|
||||
};
|
||||
|
||||
const handleUpdatePermission = async (cmd: CommandItem, permission: 'admin' | 'member') => {
|
||||
await updatePermission(cmd, permission, tm('messages.updateSuccess'), tm('messages.updateFailed'));
|
||||
};
|
||||
|
||||
const handleToggleTool = async (tool: ToolItem) => {
|
||||
const previous = tool.active;
|
||||
tool.active = !tool.active;
|
||||
@@ -240,6 +245,7 @@ watch(viewMode, async (mode) => {
|
||||
@toggle-command="handleToggleCommand"
|
||||
@rename="openRenameDialog"
|
||||
@view-details="openDetailsDialog"
|
||||
@update-permission="handleUpdatePermission"
|
||||
/>
|
||||
</div>
|
||||
|
||||
|
||||
@@ -119,8 +119,17 @@
|
||||
</v-list-item-subtitle>
|
||||
|
||||
<template v-slot:append>
|
||||
<v-icon v-if="selectedItemId === getItemId(item)"
|
||||
color="primary" size="22">mdi-check-circle</v-icon>
|
||||
<div class="d-flex align-center ga-1">
|
||||
<v-btn v-if="showEditButton && !isDefaultItem(item)"
|
||||
icon="mdi-pencil"
|
||||
size="small"
|
||||
variant="text"
|
||||
@click.stop="handleEditItem(item)"
|
||||
:title="labels.editButton || 'Edit'"
|
||||
/>
|
||||
<v-icon v-if="selectedItemId === getItemId(item)"
|
||||
color="primary" size="22">mdi-check-circle</v-icon>
|
||||
</div>
|
||||
</template>
|
||||
</v-list-item>
|
||||
</template>
|
||||
@@ -197,6 +206,11 @@ export default defineComponent({
|
||||
type: Boolean,
|
||||
default: false
|
||||
},
|
||||
// 是否显示编辑按钮
|
||||
showEditButton: {
|
||||
type: Boolean,
|
||||
default: false
|
||||
},
|
||||
// 默认项(如 "默认人格")
|
||||
defaultItem: {
|
||||
type: Object as PropType<SelectableItem | null>,
|
||||
@@ -221,7 +235,7 @@ export default defineComponent({
|
||||
default: null
|
||||
}
|
||||
},
|
||||
emits: ['update:modelValue', 'navigate', 'create'],
|
||||
emits: ['update:modelValue', 'navigate', 'create', 'edit'],
|
||||
data() {
|
||||
return {
|
||||
dialog: false,
|
||||
@@ -370,6 +384,17 @@ export default defineComponent({
|
||||
cancelSelection() {
|
||||
this.selectedItemId = this.modelValue || '';
|
||||
this.dialog = false;
|
||||
},
|
||||
|
||||
isDefaultItem(item: SelectableItem): boolean {
|
||||
if (this.defaultItem === null) {
|
||||
return false;
|
||||
}
|
||||
return this.getItemId(item) === this.getItemId(this.defaultItem);
|
||||
},
|
||||
|
||||
handleEditItem(item: SelectableItem) {
|
||||
this.$emit('edit', item);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
@@ -241,6 +241,7 @@ export interface FolderItemSelectorLabels {
|
||||
|
||||
// 按钮
|
||||
createButton?: string;
|
||||
editButton?: string;
|
||||
confirmButton?: string;
|
||||
cancelButton?: string;
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
<template>
|
||||
<v-dialog v-model="showDialog" max-width="800px" height="90%" @after-enter="prepareData">
|
||||
<v-dialog v-model="showDialog" max-width="800px" max-height="90%" @after-enter="prepareData">
|
||||
<v-card
|
||||
:title="updatingMode ? `${tm('dialog.edit')} ${updatingPlatformConfig.id} ${tm('dialog.adapter')}` : tm('dialog.addPlatform')">
|
||||
<v-card-text ref="dialogScrollContainer" class="pa-4 ml-2" style="overflow-y: auto;">
|
||||
@@ -9,14 +9,14 @@
|
||||
</div>
|
||||
<div style="flex: 1;">
|
||||
<h3>
|
||||
选择消息平台类别
|
||||
{{ tm('createDialog.step1Title') }}
|
||||
</h3>
|
||||
<small style="color: grey;">想把机器人接入到哪里?如 QQ、企业微信、飞书、Discord、Telegram 等。</small>
|
||||
<small style="color: grey;">{{ tm('createDialog.step1Hint') }}</small>
|
||||
<div>
|
||||
|
||||
<div v-if="!updatingMode">
|
||||
<v-select v-model="selectedPlatformType" :items="Object.keys(platformTemplates)" item-title="name"
|
||||
item-value="name" label="消息平台类别" variant="outlined" rounded="md" dense hide-details class="mt-6"
|
||||
item-value="name" :label="tm('createDialog.platformTypeLabel')" variant="outlined" rounded="md" dense hide-details class="mt-6"
|
||||
style="max-width: 30%; min-width: 300px;">
|
||||
|
||||
<template v-slot:item="{ props: itemProps, item }">
|
||||
@@ -41,7 +41,7 @@
|
||||
</div>
|
||||
</div>
|
||||
<div v-else>
|
||||
<v-text-field label="消息平台类别" variant="outlined" rounded="md" dense hide-details class="mt-6"
|
||||
<v-text-field :label="tm('createDialog.platformTypeLabel')" variant="outlined" rounded="md" dense hide-details class="mt-6"
|
||||
style="max-width: 30%; min-width: 300px;" v-model="updatingPlatformConfig.type"
|
||||
disabled></v-text-field>
|
||||
<div class="mt-3">
|
||||
@@ -65,13 +65,13 @@
|
||||
<div>
|
||||
<div class="d-flex align-center">
|
||||
<h3>
|
||||
配置文件
|
||||
{{ tm('createDialog.configFileTitle') }}
|
||||
</h3>
|
||||
<v-chip size="x-small" color="primary" variant="tonal" rounded="sm" class="ml-2"
|
||||
v-if="!updatingMode">可选</v-chip>
|
||||
v-if="!updatingMode">{{ tm('createDialog.optional') }}</v-chip>
|
||||
</div>
|
||||
<small style="color: grey;">想如何配置机器人?配置文件包含了聊天模型、人格、知识库、插件范围等丰富的机器人配置项。</small>
|
||||
<small style="color: grey;" v-if="!updatingMode">默认使用默认配置文件 “default”。您也可以稍后配置。</small>
|
||||
<small style="color: grey;">{{ tm('createDialog.configHint') }}</small>
|
||||
<small style="color: grey;" v-if="!updatingMode">{{ tm('createDialog.configDefaultHint') }}</small>
|
||||
</div>
|
||||
<div>
|
||||
<v-btn variant="plain" icon @click="toggleConfigSection" class="mt-2">
|
||||
@@ -86,12 +86,12 @@
|
||||
<v-radio-group class="mt-2" v-model="aBConfigRadioVal" hide-details="true">
|
||||
<v-radio value="0">
|
||||
<template v-slot:label>
|
||||
<span>使用现有配置文件</span>
|
||||
<span>{{ tm('createDialog.useExistingConfig') }}</span>
|
||||
</template>
|
||||
</v-radio>
|
||||
<div class="d-flex align-center ml-10 my-2" v-if="aBConfigRadioVal === '0'">
|
||||
<v-select v-model="selectedAbConfId" :items="configInfoList" item-title="name"
|
||||
item-value="id" label="选择配置文件" variant="outlined" rounded="md" dense hide-details
|
||||
item-value="id" :label="tm('createDialog.selectConfigLabel')" variant="outlined" rounded="md" dense hide-details
|
||||
style="max-width: 30%; min-width: 200px;">
|
||||
</v-select>
|
||||
<v-btn icon variant="text" density="comfortable" class="ml-2"
|
||||
@@ -99,10 +99,10 @@
|
||||
<v-icon>mdi-arrow-top-right-thick</v-icon>
|
||||
</v-btn>
|
||||
</div>
|
||||
<v-radio value="1" label="创建新配置文件">
|
||||
<v-radio value="1" :label="tm('createDialog.createNewConfig')">
|
||||
</v-radio>
|
||||
<div class="d-flex align-center" v-if="aBConfigRadioVal === '1'">
|
||||
<v-text-field v-model="selectedAbConfId" label="新配置文件名称" variant="outlined" rounded="md" dense
|
||||
<v-text-field v-model="selectedAbConfId" :label="tm('createDialog.newConfigNameLabel')" variant="outlined" rounded="md" dense
|
||||
hide-details style="max-width: 30%; min-width: 200px;" class="ml-10 my-2">
|
||||
</v-text-field>
|
||||
</div>
|
||||
@@ -131,12 +131,12 @@
|
||||
<v-progress-circular indeterminate color="primary"></v-progress-circular>
|
||||
</div>
|
||||
<div v-else-if="newConfigData && newConfigMetadata" class="config-preview-container">
|
||||
<h4 class="mb-3">使用新的配置文件</h4>
|
||||
<h4 class="mb-3">{{ tm('createDialog.newConfigTitle') }}</h4>
|
||||
<AstrBotCoreConfigWrapper :metadata="newConfigMetadata" :config_data="newConfigData" />
|
||||
</div>
|
||||
<div v-else class="text-center py-4 text-grey">
|
||||
<v-icon>mdi-information-outline</v-icon>
|
||||
<p class="mt-2">无法加载默认配置模板</p>
|
||||
<p class="mt-2">{{ tm('createDialog.newConfigLoadFailed') }}</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -147,18 +147,18 @@
|
||||
<div>
|
||||
<v-btn v-if="isEditingRoutes" color="primary" variant="tonal" @click="addNewRoute" size="small">
|
||||
<v-icon start>mdi-plus</v-icon>
|
||||
添加路由规则
|
||||
{{ tm('createDialog.addRouteRule') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
<v-btn :color="isEditingRoutes ? 'grey' : 'primary'" variant="tonal" size="small"
|
||||
@click="toggleEditMode">
|
||||
<v-icon start>{{ isEditingRoutes ? 'mdi-eye' : 'mdi-pencil' }}</v-icon>
|
||||
{{ isEditingRoutes ? '查看' : '编辑' }}
|
||||
{{ isEditingRoutes ? tm('createDialog.viewMode') : tm('createDialog.editMode') }}
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<v-data-table :headers="routeTableHeaders" :items="platformRoutes" item-value="umop"
|
||||
no-data-text="该平台暂无路由规则,将使用默认配置文件" hide-default-footer :items-per-page="-1" class="mt-2"
|
||||
:no-data-text="tm('createDialog.noRouteRules')" hide-default-footer :items-per-page="-1" class="mt-2"
|
||||
variant="outlined">
|
||||
|
||||
<template v-slot:item.source="{ item }">
|
||||
@@ -170,9 +170,9 @@
|
||||
<small v-else>{{ getMessageTypeLabel(item.messageType) }}</small>
|
||||
<small class="mx-1">:</small>
|
||||
<v-text-field v-if="isEditingRoutes" v-model="item.sessionId" variant="outlined" density="compact"
|
||||
hide-details placeholder="会话ID或*">
|
||||
hide-details :placeholder="tm('createDialog.sessionIdPlaceholder')">
|
||||
</v-text-field>
|
||||
<small v-else>{{ item.sessionId === '*' ? '全部会话' : item.sessionId }}</small>
|
||||
<small v-else>{{ item.sessionId === '*' ? tm('createDialog.allSessions') : item.sessionId }}</small>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
@@ -191,7 +191,7 @@
|
||||
</v-btn>
|
||||
</div>
|
||||
<small v-if="configInfoList.findIndex(c => c.id === item.configId) === -1" style="color: red;"
|
||||
class="ml-2">配置文件不存在</small>
|
||||
class="ml-2">{{ tm('createDialog.configMissing') }}</small>
|
||||
</template>
|
||||
|
||||
<template v-slot:item.actions="{ item, index }">
|
||||
@@ -211,8 +211,7 @@
|
||||
</template>
|
||||
|
||||
</v-data-table>
|
||||
<small class="ml-2 mt-2 d-block" style="color: grey">*消息下发时,根据会话来源按顺序从上到下匹配首个符合条件的配置文件。使用 * 表示匹配所有。使用 /sid 指令获取会话
|
||||
ID。全部不匹配时将使用默认配置文件。</small>
|
||||
<small class="ml-2 mt-2 d-block" style="color: grey">{{ tm('createDialog.routeHint') }}</small>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -266,10 +265,10 @@
|
||||
<v-card-actions class="px-4 pb-4">
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="error" @click="handleOneBotEmptyTokenWarningDismiss(true)">
|
||||
无视警告并继续创建
|
||||
{{ tm('createDialog.warningContinue') }}
|
||||
</v-btn>
|
||||
<v-btn color="primary" @click="handleOneBotEmptyTokenWarningDismiss(false)">
|
||||
重新修改
|
||||
{{ tm('createDialog.warningEditAgain') }}
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
@@ -286,9 +285,9 @@
|
||||
<v-card class="config-drawer-card" elevation="12">
|
||||
<div class="config-drawer-header">
|
||||
<div>
|
||||
<span class="text-h6">配置文件管理</span>
|
||||
<span class="text-h6">{{ tm('createDialog.configDrawerTitle') }}</span>
|
||||
<div v-if="configDrawerTargetId" class="text-caption text-grey">
|
||||
ID: {{ configDrawerTargetId }}
|
||||
{{ tm('createDialog.configDrawerIdLabel') }}: {{ configDrawerTargetId }}
|
||||
</div>
|
||||
</div>
|
||||
<v-btn icon variant="text" @click="closeConfigDrawer">
|
||||
@@ -359,23 +358,9 @@ export default {
|
||||
|
||||
// 平台配置文件表格(已弃用,改用 platformRoutes)
|
||||
platformConfigs: [],
|
||||
configTableHeaders: [
|
||||
{ title: '与此实例关联的配置文件 ID', key: 'name', sortable: false },
|
||||
{ title: '在此实例下的应用范围', key: 'scope', sortable: false },
|
||||
],
|
||||
|
||||
// 平台路由表
|
||||
platformRoutes: [],
|
||||
routeTableHeaders: [
|
||||
{ title: '消息会话来源(消息类型:会话 ID)', key: 'source', sortable: false, width: '60%' },
|
||||
{ title: '使用配置文件', key: 'configId', sortable: false, width: '20%' },
|
||||
{ title: '操作', key: 'actions', sortable: false, align: 'center', width: '20%' },
|
||||
],
|
||||
messageTypeOptions: [
|
||||
{ label: '全部消息', value: '*' },
|
||||
{ label: '群组消息(GroupMessage)', value: 'GroupMessage' },
|
||||
{ label: '私聊消息(FriendMessage)', value: 'FriendMessage' },
|
||||
],
|
||||
isEditingRoutes: false, // 编辑模式开关
|
||||
|
||||
// ID冲突确认对话框
|
||||
@@ -437,6 +422,26 @@ export default {
|
||||
}
|
||||
|
||||
return false;
|
||||
},
|
||||
configTableHeaders() {
|
||||
return [
|
||||
{ title: this.tm('createDialog.configTableHeaders.configId'), key: 'name', sortable: false },
|
||||
{ title: this.tm('createDialog.configTableHeaders.scope'), key: 'scope', sortable: false },
|
||||
];
|
||||
},
|
||||
routeTableHeaders() {
|
||||
return [
|
||||
{ title: this.tm('createDialog.routeTableHeaders.source'), key: 'source', sortable: false, width: '60%' },
|
||||
{ title: this.tm('createDialog.routeTableHeaders.config'), key: 'configId', sortable: false, width: '20%' },
|
||||
{ title: this.tm('createDialog.routeTableHeaders.actions'), key: 'actions', sortable: false, align: 'center', width: '20%' },
|
||||
];
|
||||
},
|
||||
messageTypeOptions() {
|
||||
return [
|
||||
{ label: this.tm('createDialog.messageTypeOptions.all'), value: '*' },
|
||||
{ label: this.tm('createDialog.messageTypeOptions.group'), value: 'GroupMessage' },
|
||||
{ label: this.tm('createDialog.messageTypeOptions.friend'), value: 'FriendMessage' },
|
||||
];
|
||||
}
|
||||
},
|
||||
watch: {
|
||||
@@ -603,7 +608,7 @@ export default {
|
||||
const targetId = configId || 'default';
|
||||
|
||||
if (configId && this.configInfoList.findIndex(c => c.id === configId) === -1) {
|
||||
this.showError('目标配置文件不存在,已打开配置页面以便检查。');
|
||||
this.showError(this.tm('messages.configNotFoundOpenConfig'));
|
||||
}
|
||||
|
||||
this.configDrawerTargetId = targetId;
|
||||
@@ -637,7 +642,7 @@ export default {
|
||||
const id = this.originalUpdatingPlatformId || this.updatingPlatformConfig.id;
|
||||
if (!id) {
|
||||
this.loading = false;
|
||||
this.showError('更新失败,缺少平台 ID。');
|
||||
this.showError(this.tm('messages.updateMissingPlatformId'));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -655,7 +660,7 @@ export default {
|
||||
})
|
||||
|
||||
if (resp.data.status === 'error') {
|
||||
throw new Error(resp.data.message || '平台更新失败');
|
||||
throw new Error(resp.data.message || this.tm('messages.platformUpdateFailed'));
|
||||
}
|
||||
|
||||
// 同时更新路由表
|
||||
@@ -665,7 +670,7 @@ export default {
|
||||
this.showDialog = false;
|
||||
this.resetForm();
|
||||
this.$emit('refresh-config');
|
||||
this.showSuccess('更新成功');
|
||||
this.showSuccess(this.tm('messages.updateSuccess'));
|
||||
} catch (err) {
|
||||
this.loading = false;
|
||||
this.showError(err.response?.data?.message || err.message);
|
||||
@@ -710,7 +715,7 @@ export default {
|
||||
this.showDialog = false;
|
||||
this.resetForm();
|
||||
this.$emit('refresh-config');
|
||||
this.showSuccess(res.data.message || '平台添加成功,配置文件已更新');
|
||||
this.showSuccess(res.data.message || this.tm('messages.addSuccessWithConfig'));
|
||||
} catch (err) {
|
||||
this.loading = false;
|
||||
this.showError(err.response?.data?.message || err.message);
|
||||
@@ -738,7 +743,7 @@ export default {
|
||||
}
|
||||
|
||||
if (!configId) {
|
||||
throw new Error('无法获取配置文件ID');
|
||||
throw new Error(this.tm('messages.configIdMissing'));
|
||||
}
|
||||
|
||||
// 第二步:统一更新路由表
|
||||
@@ -755,7 +760,8 @@ export default {
|
||||
console.log(`成功更新路由表: ${umop} -> ${configId}`);
|
||||
} catch (err) {
|
||||
console.error('更新路由表失败:', err);
|
||||
throw new Error(`更新路由表失败: ${err.response?.data?.message || err.message}`);
|
||||
const errorMessage = err.response?.data?.message || err.message;
|
||||
throw new Error(this.tm('messages.routingUpdateFailed', { message: errorMessage }));
|
||||
}
|
||||
},
|
||||
|
||||
@@ -778,7 +784,8 @@ export default {
|
||||
return newConfigId;
|
||||
} catch (err) {
|
||||
console.error('创建新配置文件失败:', err);
|
||||
throw new Error(`创建新配置文件失败: ${err.response?.data?.message || err.message}`);
|
||||
const errorMessage = err.response?.data?.message || err.message;
|
||||
throw new Error(this.tm('messages.createConfigFailed', { message: errorMessage }));
|
||||
}
|
||||
},
|
||||
|
||||
@@ -922,7 +929,7 @@ export default {
|
||||
const newPlatformId = this.updatingPlatformConfig?.id || originalPlatformId;
|
||||
|
||||
if (!originalPlatformId && !newPlatformId) {
|
||||
throw new Error('无法获取平台 ID');
|
||||
throw new Error(this.tm('messages.platformIdMissing'));
|
||||
}
|
||||
|
||||
try {
|
||||
@@ -958,7 +965,8 @@ export default {
|
||||
});
|
||||
} catch (err) {
|
||||
console.error('保存路由表失败:', err);
|
||||
throw new Error(`保存路由表失败: ${err.response?.data?.message || err.message}`);
|
||||
const errorMessage = err.response?.data?.message || err.message;
|
||||
throw new Error(this.tm('messages.routingSaveFailed', { message: errorMessage }));
|
||||
}
|
||||
},
|
||||
|
||||
@@ -987,10 +995,10 @@ export default {
|
||||
// 获取消息类型标签
|
||||
getMessageTypeLabel(messageType) {
|
||||
const typeMap = {
|
||||
'*': '全部消息',
|
||||
'': '全部消息',
|
||||
'GroupMessage': '群组消息',
|
||||
'FriendMessage': '私聊消息'
|
||||
'*': this.tm('createDialog.messageTypeLabels.all'),
|
||||
'': this.tm('createDialog.messageTypeLabels.all'),
|
||||
'GroupMessage': this.tm('createDialog.messageTypeLabels.group'),
|
||||
'FriendMessage': this.tm('createDialog.messageTypeLabels.friend')
|
||||
};
|
||||
return typeMap[messageType] || messageType;
|
||||
},
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
rounded="xl"
|
||||
size="small"
|
||||
>
|
||||
新增
|
||||
{{ tm('providerSources.add') }}
|
||||
</v-btn>
|
||||
</template>
|
||||
<v-list density="compact">
|
||||
|
||||
@@ -3,7 +3,7 @@ import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
|
||||
import { ref, computed } from 'vue'
|
||||
import ConfigItemRenderer from './ConfigItemRenderer.vue'
|
||||
import TemplateListEditor from './TemplateListEditor.vue'
|
||||
import { useI18n } from '@/i18n/composables'
|
||||
import { useI18n, useModuleI18n } from '@/i18n/composables'
|
||||
import axios from 'axios'
|
||||
import { useToast } from '@/utils/toast'
|
||||
|
||||
@@ -35,6 +35,12 @@ const props = defineProps({
|
||||
})
|
||||
|
||||
const { t } = useI18n()
|
||||
const { tm, getRaw } = useModuleI18n('features/config-metadata')
|
||||
|
||||
const translateIfKey = (value) => {
|
||||
if (!value || typeof value !== 'string') return value
|
||||
return getRaw(value) ? tm(value) : value
|
||||
}
|
||||
|
||||
const filteredIterable = computed(() => {
|
||||
if (!props.iterable) return {}
|
||||
@@ -134,11 +140,11 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
<template>
|
||||
<div class="config-section" v-if="iterable && metadata[metadataKey]?.type === 'object'">
|
||||
<v-list-item-title class="config-title">
|
||||
{{ metadata[metadataKey]?.description }} <span class="metadata-key">({{ metadataKey }})</span>
|
||||
{{ translateIfKey(metadata[metadataKey]?.description) }} <span class="metadata-key">({{ metadataKey }})</span>
|
||||
</v-list-item-title>
|
||||
<v-list-item-subtitle class="config-hint">
|
||||
<span v-if="metadata[metadataKey]?.obvious_hint && metadata[metadataKey]?.hint" class="important-hint">‼️</span>
|
||||
{{ metadata[metadataKey]?.hint }}
|
||||
{{ translateIfKey(metadata[metadataKey]?.hint) }}
|
||||
</v-list-item-subtitle>
|
||||
</div>
|
||||
|
||||
@@ -180,14 +186,14 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
<div class="config-section mb-2">
|
||||
<v-list-item-title class="config-title">
|
||||
<span v-if="metadata[metadataKey].items[key]?.description">
|
||||
{{ metadata[metadataKey].items[key]?.description }}
|
||||
{{ translateIfKey(metadata[metadataKey].items[key]?.description) }}
|
||||
<span class="property-key">({{ key }})</span>
|
||||
</span>
|
||||
<span v-else>{{ key }}</span>
|
||||
</v-list-item-title>
|
||||
<v-list-item-subtitle class="config-hint">
|
||||
<span v-if="metadata[metadataKey].items[key]?.obvious_hint && metadata[metadataKey].items[key]?.hint" class="important-hint">‼️</span>
|
||||
{{ metadata[metadataKey].items[key]?.hint }}
|
||||
{{ translateIfKey(metadata[metadataKey].items[key]?.hint) }}
|
||||
</v-list-item-subtitle>
|
||||
</div>
|
||||
<TemplateListEditor
|
||||
@@ -205,7 +211,7 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
<v-list-item density="compact">
|
||||
<v-list-item-title class="property-name">
|
||||
<span v-if="metadata[metadataKey].items[key]?.description">
|
||||
{{ metadata[metadataKey].items[key]?.description }}
|
||||
{{ translateIfKey(metadata[metadataKey].items[key]?.description) }}
|
||||
<span class="property-key">({{ key }})</span>
|
||||
</span>
|
||||
<span v-else>{{ key }}</span>
|
||||
@@ -214,7 +220,7 @@ function hasVisibleItemsAfter(items, currentIndex) {
|
||||
<v-list-item-subtitle class="property-hint">
|
||||
<span v-if="metadata[metadataKey].items[key]?.obvious_hint && metadata[metadataKey].items[key]?.hint"
|
||||
class="important-hint">‼️</span>
|
||||
{{ metadata[metadataKey].items[key]?.hint }}
|
||||
{{ translateIfKey(metadata[metadataKey].items[key]?.hint) }}
|
||||
</v-list-item-subtitle>
|
||||
</v-list-item>
|
||||
</v-col>
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
<script setup>
|
||||
import MarkdownIt from 'markdown-it'
|
||||
import { VueMonacoEditor } from '@guolao/vue-monaco-editor'
|
||||
import { ref, computed } from 'vue'
|
||||
import ConfigItemRenderer from './ConfigItemRenderer.vue'
|
||||
@@ -24,12 +25,23 @@ const props = defineProps({
|
||||
const { t } = useI18n()
|
||||
const { tm, getRaw } = useModuleI18n('features/config-metadata')
|
||||
|
||||
const hintMarkdown = new MarkdownIt({
|
||||
linkify: true,
|
||||
breaks: true
|
||||
})
|
||||
|
||||
// 翻译器函数 - 如果是国际化键则翻译,否则原样返回
|
||||
const translateIfKey = (value) => {
|
||||
if (!value || typeof value !== 'string') return value
|
||||
return tm(value)
|
||||
}
|
||||
|
||||
const renderHint = (value) => {
|
||||
const text = translateIfKey(value)
|
||||
if (!text) return ''
|
||||
return hintMarkdown.renderInline(text)
|
||||
}
|
||||
|
||||
// 处理labels翻译 - labels可以是数组或国际化键
|
||||
const getTranslatedLabels = (itemMeta) => {
|
||||
if (!itemMeta?.labels) return null
|
||||
@@ -185,7 +197,7 @@ function getSpecialSubtype(value) {
|
||||
</v-list-item-title>
|
||||
<v-list-item-subtitle class="config-hint">
|
||||
<span v-if="metadata[metadataKey]?.obvious_hint && metadata[metadataKey]?.hint" class="important-hint">‼️</span>
|
||||
{{ translateIfKey(metadata[metadataKey]?.hint) }}
|
||||
<span v-html="renderHint(metadata[metadataKey]?.hint)"></span>
|
||||
</v-list-item-subtitle>
|
||||
</v-card-text>
|
||||
|
||||
@@ -205,7 +217,7 @@ function getSpecialSubtype(value) {
|
||||
|
||||
<v-list-item-subtitle class="property-hint">
|
||||
<span v-if="itemMeta?.obvious_hint && itemMeta?.hint" class="important-hint">‼️</span>
|
||||
{{ translateIfKey(itemMeta?.hint) }}
|
||||
<span v-html="renderHint(itemMeta?.hint)"></span>
|
||||
</v-list-item-subtitle>
|
||||
</v-list-item>
|
||||
</v-col>
|
||||
@@ -293,6 +305,12 @@ function getSpecialSubtype(value) {
|
||||
margin-top: 2px;
|
||||
}
|
||||
|
||||
.config-hint :deep(a),
|
||||
.property-hint :deep(a) {
|
||||
color: var(--v-theme-primary);
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
.metadata-key,
|
||||
.property-key {
|
||||
font-size: 0.85em;
|
||||
|
||||
@@ -530,8 +530,13 @@ export default {
|
||||
try {
|
||||
const response = await axios.get('/api/skills');
|
||||
if (response.data.status === 'ok') {
|
||||
const skills = response.data.data || [];
|
||||
this.availableSkills = skills.filter(skill => skill.active !== false);
|
||||
const payload = response.data.data || [];
|
||||
if (Array.isArray(payload)) {
|
||||
this.availableSkills = payload.filter(skill => skill.active !== false);
|
||||
} else {
|
||||
const skills = payload.skills || [];
|
||||
this.availableSkills = skills.filter(skill => skill.active !== false);
|
||||
}
|
||||
} else {
|
||||
this.$emit('error', response.data.message || 'Failed to load skills');
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
:items-loading="itemsLoading"
|
||||
:labels="labels"
|
||||
:show-create-button="true"
|
||||
:show-edit-button="true"
|
||||
:default-item="defaultPersona"
|
||||
item-id-field="persona_id"
|
||||
item-name-field="persona_id"
|
||||
@@ -15,15 +16,16 @@
|
||||
:display-value-formatter="formatDisplayValue"
|
||||
@navigate="handleNavigate"
|
||||
@create="openCreatePersona"
|
||||
@edit="openEditPersona"
|
||||
/>
|
||||
|
||||
<!-- 创建人格对话框 -->
|
||||
<!-- 创建/编辑人格对话框 -->
|
||||
<PersonaForm
|
||||
v-model="showCreateDialog"
|
||||
:editing-persona="undefined"
|
||||
v-model="showPersonaDialog"
|
||||
:editing-persona="editingPersona ?? undefined"
|
||||
:current-folder-id="currentFolderId ?? undefined"
|
||||
:current-folder-name="currentFolderName ?? undefined"
|
||||
@saved="handlePersonaCreated"
|
||||
@saved="handlePersonaSaved"
|
||||
@error="handleError" />
|
||||
</template>
|
||||
|
||||
@@ -62,7 +64,8 @@ const folderTree = ref<FolderTreeNode[]>([])
|
||||
const currentPersonas = ref<Persona[]>([])
|
||||
const treeLoading = ref(false)
|
||||
const itemsLoading = ref(false)
|
||||
const showCreateDialog = ref(false)
|
||||
const showPersonaDialog = ref(false)
|
||||
const editingPersona = ref<Persona | null>(null)
|
||||
const currentFolderId = ref<string | null>(null)
|
||||
|
||||
// 默认人格
|
||||
@@ -104,6 +107,7 @@ const labels = computed(() => ({
|
||||
defaultItem: tm('personaSelector.defaultPersona'),
|
||||
noDescription: tm('personaSelector.noDescription'),
|
||||
createButton: tm('personaSelector.createPersona'),
|
||||
editButton: tm('personaSelector.editPersona') || 'Edit',
|
||||
confirmButton: t('core.common.confirm'),
|
||||
cancelButton: t('core.common.cancel'),
|
||||
rootFolder: tm('personaSelector.rootFolder') || '全部人格',
|
||||
@@ -171,13 +175,21 @@ async function handleNavigate(folderId: string | null) {
|
||||
|
||||
// 打开创建人格对话框
|
||||
function openCreatePersona() {
|
||||
showCreateDialog.value = true
|
||||
editingPersona.value = null
|
||||
showPersonaDialog.value = true
|
||||
}
|
||||
|
||||
// 人格创建成功
|
||||
async function handlePersonaCreated(message: string) {
|
||||
console.log('人格创建成功:', message)
|
||||
showCreateDialog.value = false
|
||||
// 打开编辑人格对话框
|
||||
function openEditPersona(persona: Persona) {
|
||||
editingPersona.value = persona
|
||||
showPersonaDialog.value = true
|
||||
}
|
||||
|
||||
// 人格保存成功(创建或编辑)
|
||||
async function handlePersonaSaved(message: string) {
|
||||
console.log('人格保存成功:', message)
|
||||
showPersonaDialog.value = false
|
||||
editingPersona.value = null
|
||||
// 刷新当前文件夹的人格列表
|
||||
await loadPersonasInFolder(currentFolderId.value)
|
||||
}
|
||||
|
||||
@@ -33,9 +33,15 @@ export default {
|
||||
methods: {
|
||||
async check() {
|
||||
this.newStartTime = -1
|
||||
this.startTime = useCommonStore().getStartTime()
|
||||
this.cnt = 0
|
||||
this.visible = true
|
||||
this.status = ""
|
||||
const commonStore = useCommonStore()
|
||||
try {
|
||||
this.startTime = await commonStore.fetchStartTime()
|
||||
} catch (_error) {
|
||||
this.startTime = commonStore.getStartTime()
|
||||
}
|
||||
console.log('start wfr')
|
||||
setTimeout(() => {
|
||||
this.timeoutInternal()
|
||||
@@ -50,7 +56,7 @@ export default {
|
||||
this.timeoutInternal()
|
||||
}, 1000)
|
||||
} else {
|
||||
if (this.cnt == 10) {
|
||||
if (this.cnt >= 60) {
|
||||
this.status = this.t('core.common.restart.maxRetriesReached')
|
||||
}
|
||||
this.cnt = 0
|
||||
@@ -60,18 +66,22 @@ export default {
|
||||
}
|
||||
},
|
||||
async checkStartTime() {
|
||||
let res = await axios.get('/api/stat/start-time', { timeout: 3000 })
|
||||
let newStartTime = res.data.data.start_time
|
||||
console.log('wfr: checkStartTime', this.newStartTime, this.startTime)
|
||||
if (this.newStartTime !== this.startTime) {
|
||||
this.newStartTime = newStartTime
|
||||
console.log('wfr: restarted')
|
||||
this.visible = false
|
||||
// reload
|
||||
window.location.reload()
|
||||
try {
|
||||
let res = await axios.get('/api/stat/start-time', { timeout: 3000 })
|
||||
let newStartTime = res.data.data.start_time
|
||||
console.log('wfr: checkStartTime', newStartTime, this.startTime)
|
||||
if (this.startTime !== -1 && newStartTime !== this.startTime) {
|
||||
this.newStartTime = newStartTime
|
||||
console.log('wfr: restarted')
|
||||
this.visible = false
|
||||
// reload
|
||||
window.location.reload()
|
||||
}
|
||||
} catch (_error) {
|
||||
// backend may be unavailable during restart window
|
||||
}
|
||||
return this.newStartTime
|
||||
}
|
||||
}
|
||||
}
|
||||
</script>
|
||||
</script>
|
||||
|
||||
@@ -59,14 +59,14 @@ export function useProviderSources(options: UseProviderSourcesOptions) {
|
||||
|
||||
let suppressSourceWatch = false
|
||||
|
||||
const providerTypes = [
|
||||
const providerTypes = computed(() => [
|
||||
{ value: 'chat_completion', label: tm('providers.tabs.chatCompletion'), icon: 'mdi-message-text' },
|
||||
{ value: 'agent_runner', label: tm('providers.tabs.agentRunner'), icon: 'mdi-robot' },
|
||||
{ value: 'speech_to_text', label: tm('providers.tabs.speechToText'), icon: 'mdi-microphone-message' },
|
||||
{ value: 'text_to_speech', label: tm('providers.tabs.textToSpeech'), icon: 'mdi-volume-high' },
|
||||
{ value: 'embedding', label: tm('providers.tabs.embedding'), icon: 'mdi-code-json' },
|
||||
{ value: 'rerank', label: tm('providers.tabs.rerank'), icon: 'mdi-compare-vertical' }
|
||||
]
|
||||
])
|
||||
|
||||
// ===== Computed =====
|
||||
const availableSourceTypes = computed(() => {
|
||||
@@ -233,6 +233,11 @@ export function useProviderSources(options: UseProviderSourcesOptions) {
|
||||
customSchema.provider.items.key.hint = tm('providerSources.hints.key')
|
||||
customSchema.provider.items.api_base.hint = tm('providerSources.hints.apiBase')
|
||||
}
|
||||
// 为 proxy 字段添加描述和提示
|
||||
if (customSchema.provider?.items?.proxy) {
|
||||
customSchema.provider.items.proxy.description = tm('providerSources.labels.proxy')
|
||||
customSchema.provider.items.proxy.hint = tm('providerSources.hints.proxy')
|
||||
}
|
||||
|
||||
return customSchema
|
||||
})
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user