Compare commits

...

22 Commits

Author SHA1 Message Date
Soulter 16d49d568b fix: add reminder for v4.14.8 users regarding manual redeployment due to a bug 2026-02-10 23:20:49 +08:00
Soulter 776e17062c chore: bump version to 4.15.0 (#5003) 2026-02-10 23:17:23 +08:00
エイカク 8fa8c14b0b fix: 修复app内重启异常,修复app内点击重启不能立刻提示重启,以及在后端就绪时及时刷新界面的问题 (#5013)
* fix: patch pip distlib finder for frozen electron runtime

* fix: use certifi CA bundle for runtime SSL requests

* fix: configure certifi CA before core imports

* fix: improve mac font fallback for dashboard text

* fix: harden frozen pip patch and unify TLS connector

* refactor: centralize dashboard CJK font fallback stacks

* perf: reuse TLS context and avoid repeated frozen pip patch

* refactor: bootstrap TLS setup before core imports

* fix: use async confirm dialog for provider deletions

* fix: replace native confirm dialogs in dashboard

- Add shared confirm helper in dashboard/src/utils/confirmDialog.ts for async dialog usage with safe fallback.

- Migrate provider, chat, config, session, platform, persona, MCP, backup, and knowledge-base delete/close confirmations to use the shared helper.

- Remove scattered inline confirm handling to keep behavior consistent and avoid native blocking dialog focus/caret issues in Electron.

* fix: capture runtime bootstrap logs after logger init

- Add bootstrap record buffer in runtime_bootstrap for early TLS patch logs before logger is ready.

- Flush buffered bootstrap logs to astrbot logger at process startup in main.py.

- Include concrete exception details for TLS bootstrap failures to improve diagnosis.

* fix: harden runtime bootstrap and unify confirm handling

- Simplify bootstrap log buffering and add a public initialize hook for non-main startup paths.

- Guard aiohttp TLS patching with feature/type checks and keep graceful fallback when internals are unavailable.

- Standardize dashboard confirmation flow via shared confirm helpers across composition and options API components.

* refactor: simplify runtime tls bootstrap and tighten confirm typing

* refactor: align ssl helper namespace and confirm usage

* fix: avoid frozen restart crash from multiprocessing import

* fix: include missing frozen dependencies for windows backend

* fix: use execv for stable backend reboot args

* Revert "fix: use execv for stable backend reboot args"

This reverts commit 9cc27becff.

* Revert "fix: include missing frozen dependencies for windows backend"

This reverts commit 52554bea1f.

* Revert "fix: avoid frozen restart crash from multiprocessing import"

This reverts commit 10548645b0.

* fix: reset pyinstaller onefile env before reboot

* fix: unify electron restart path and tray-exit backend cleanup

* fix: stabilize desktop restart detection and frozen reboot args

* fix: make dashboard restart wait detection robust

* fix: revert dashboard restart waiting interaction tweaks

* fix: pass auth token for desktop graceful restart

* fix: avoid false failure during graceful restart wait

* fix: start restart waiting before electron restart call

* fix: harden restart waiting and reboot arg parsing

* fix: parse start_time as numeric timestamp

* fix: preserve windows frozen reboot argv quoting

* fix: align restart waiting with electron restart timing

* fix: tighten graceful restart and unmanaged kill safety
2026-02-10 22:21:04 +09:00
エイカク 64de474139 fix: 修复 Windows 打包版后端重启失败问题 (#5009)
* fix: patch pip distlib finder for frozen electron runtime

* fix: use certifi CA bundle for runtime SSL requests

* fix: configure certifi CA before core imports

* fix: improve mac font fallback for dashboard text

* fix: harden frozen pip patch and unify TLS connector

* refactor: centralize dashboard CJK font fallback stacks

* perf: reuse TLS context and avoid repeated frozen pip patch

* refactor: bootstrap TLS setup before core imports

* fix: use async confirm dialog for provider deletions

* fix: replace native confirm dialogs in dashboard

- Add shared confirm helper in dashboard/src/utils/confirmDialog.ts for async dialog usage with safe fallback.

- Migrate provider, chat, config, session, platform, persona, MCP, backup, and knowledge-base delete/close confirmations to use the shared helper.

- Remove scattered inline confirm handling to keep behavior consistent and avoid native blocking dialog focus/caret issues in Electron.

* fix: capture runtime bootstrap logs after logger init

- Add bootstrap record buffer in runtime_bootstrap for early TLS patch logs before logger is ready.

- Flush buffered bootstrap logs to astrbot logger at process startup in main.py.

- Include concrete exception details for TLS bootstrap failures to improve diagnosis.

* fix: harden runtime bootstrap and unify confirm handling

- Simplify bootstrap log buffering and add a public initialize hook for non-main startup paths.

- Guard aiohttp TLS patching with feature/type checks and keep graceful fallback when internals are unavailable.

- Standardize dashboard confirmation flow via shared confirm helpers across composition and options API components.

* refactor: simplify runtime tls bootstrap and tighten confirm typing

* refactor: align ssl helper namespace and confirm usage

* fix: avoid frozen restart crash from multiprocessing import

* fix: include missing frozen dependencies for windows backend

* fix: use execv for stable backend reboot args

* Revert "fix: use execv for stable backend reboot args"

This reverts commit 9cc27becff.

* Revert "fix: include missing frozen dependencies for windows backend"

This reverts commit 52554bea1f.

* Revert "fix: avoid frozen restart crash from multiprocessing import"

This reverts commit 10548645b0.

* fix: reset pyinstaller onefile env before reboot

* fix: unify electron restart path and tray-exit backend cleanup

* fix: stabilize desktop restart detection and frozen reboot args

* fix: make dashboard restart wait detection robust

* fix: revert dashboard restart waiting interaction tweaks

* fix: pass auth token for desktop graceful restart

* fix: avoid false failure during graceful restart wait

* fix: start restart waiting before electron restart call

* fix: harden restart waiting and reboot arg parsing

* fix: parse start_time as numeric timestamp
2026-02-10 21:33:06 +09:00
エイカク d35771f97d fix: stabilize packaged runtime pip/ssl behavior and mac font fallback (#5007)
* fix: patch pip distlib finder for frozen electron runtime

* fix: use certifi CA bundle for runtime SSL requests

* fix: configure certifi CA before core imports

* fix: improve mac font fallback for dashboard text

* fix: harden frozen pip patch and unify TLS connector

* refactor: centralize dashboard CJK font fallback stacks

* perf: reuse TLS context and avoid repeated frozen pip patch

* refactor: bootstrap TLS setup before core imports

* fix: use async confirm dialog for provider deletions

* fix: replace native confirm dialogs in dashboard

- Add shared confirm helper in dashboard/src/utils/confirmDialog.ts for async dialog usage with safe fallback.

- Migrate provider, chat, config, session, platform, persona, MCP, backup, and knowledge-base delete/close confirmations to use the shared helper.

- Remove scattered inline confirm handling to keep behavior consistent and avoid native blocking dialog focus/caret issues in Electron.

* fix: capture runtime bootstrap logs after logger init

- Add bootstrap record buffer in runtime_bootstrap for early TLS patch logs before logger is ready.

- Flush buffered bootstrap logs to astrbot logger at process startup in main.py.

- Include concrete exception details for TLS bootstrap failures to improve diagnosis.

* fix: harden runtime bootstrap and unify confirm handling

- Simplify bootstrap log buffering and add a public initialize hook for non-main startup paths.

- Guard aiohttp TLS patching with feature/type checks and keep graceful fallback when internals are unavailable.

- Standardize dashboard confirmation flow via shared confirm helpers across composition and options API components.

* refactor: simplify runtime tls bootstrap and tighten confirm typing

* refactor: align ssl helper namespace and confirm usage
2026-02-10 16:42:43 +09:00
dependabot[bot] 7a4d20d329 chore(deps): bump the github-actions group with 2 updates (#5006)
Bumps the github-actions group with 2 updates: [astral-sh/setup-uv](https://github.com/astral-sh/setup-uv) and [actions/download-artifact](https://github.com/actions/download-artifact).


Updates `astral-sh/setup-uv` from 6 to 7
- [Release notes](https://github.com/astral-sh/setup-uv/releases)
- [Commits](https://github.com/astral-sh/setup-uv/compare/v6...v7)

Updates `actions/download-artifact` from 6 to 7
- [Release notes](https://github.com/actions/download-artifact/releases)
- [Commits](https://github.com/actions/download-artifact/compare/v6...v7)

---
updated-dependencies:
- dependency-name: astral-sh/setup-uv
  dependency-version: '7'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
- dependency-name: actions/download-artifact
  dependency-version: '7'
  dependency-type: direct:production
  update-type: version-update:semver-major
  dependency-group: github-actions
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-02-10 11:10:26 +08:00
Li-shi-ling aab095347f fix: 'HandoffTool' object has no attribute 'agent' (#5005)
* fix: 移动agent的位置到super().__init__之后

* add: 添加一行注释
2026-02-10 11:01:49 +08:00
エイカク 1addd5b2ab perf: 稳定源码与 Electron 打包环境下的 pip 安装行为,并修复非 Electron 环境下点击 WebUI 更新按钮时出现跳转对话框的问题 (#4996)
* fix: handle pip install execution in frozen runtime

* fix: harden pip subprocess fallback handling

* fix: scope global data root to packaged electron runtime

* refactor: inline frozen runtime check for electron guard

* fix: prefer current interpreter for source pip installs

* fix: avoid resolving venv python symlink for pip

* refactor: share runtime environment detection utilities

* fix: improve error message when pip module is unavailable

* fix: raise ImportError when pip module is unavailable

* fix: preserve ImportError semantics for missing pip

* fix: 修复非electron app环境更新时仍然显示electron更新对话框的问题

---------

Co-authored-by: Soulter <905617992@qq.com>
2026-02-09 23:12:18 +08:00
Soulter da4bb6549c feat: enhance persona tool management and update UI localization for subagent orchestration (#4990)
* feat: enhance persona tool management and update UI localization for subagent orchestration

* fix: remove debug logging for final ProviderRequest in build_main_agent function
2026-02-09 22:38:05 +08:00
Soulter 7193454d50 feat: enhance WecomAIBotAdapter and WecomAIBotMessageEvent for improved streaming message handling (#5000)
fixes: #3965
2026-02-09 22:30:24 +08:00
Soulter d204b92877 feat: 企业微信智能机器人支持主动消息推送以及发送视频、文件等消息类型支持 (#4999) 2026-02-09 22:16:44 +08:00
Soulter 04faf26140 feat: 企业微信应用 支持主动消息推送,并优化企微应用、微信公众号、微信客服音频相关的处理 (#4998) 2026-02-09 22:15:11 +08:00
鸦羽 67b81c279b fix: collect certifi data in desktop backend build (#4995) 2026-02-09 19:40:32 +09:00
エイカク 2afb08d8b2 fix: handle pip install execution in frozen runtime (#4985)
* fix: handle pip install execution in frozen runtime

* fix: harden pip subprocess fallback handling
2026-02-09 15:19:01 +08:00
Soulter 06b2c7cb16 feat: enhance Dingtalk adapter with active push message and image, video, audio message type (#4986) 2026-02-09 15:17:55 +08:00
Copilot 9c12803ddd feat: add delete button to persona management dialog (#4978)
* Initial plan

* feat: add delete button to persona management dialog

- Added delete button to PersonaForm dialog (only visible when editing)
- Implemented deletePersona method with confirmation dialog
- Connected delete event to PersonaManager for proper handling
- Button positioned on left side of dialog actions for clear separation
- Uses existing i18n translations for delete button and messages

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

* fix: use finally block to ensure saving state is reset

- Moved `this.saving = false` to finally block in deletePersona
- Ensures UI doesn't stay in saving state after errors
- Follows best practices for state management

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2026-02-09 11:59:28 +08:00
Soulter ce65491d55 chore: update pydantic dependency version (#4980) 2026-02-09 11:59:05 +08:00
Soulter b67adcf481 ci: change ghcr namespace 2026-02-09 11:51:56 +08:00
Soulter 1707d55c02 fix: prepare OpenSSL via vcpkg for Windows ARM64 2026-02-09 11:04:31 +08:00
Dt8333 7dd95d8a59 chore: auto ann fix by ruff (#4903)
* chore: auto fix by ruff

* refactor: 统一修正返回类型注解为 None/bool 以匹配实现

* refactor: 将 _get_next_page 改为异步并移除多余的请求错误抛出

* refactor: 将 get_client 的返回类型改为 object

* style: 为 LarkMessageEvent 的相关方法添加返回类型注解 None

---------

Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com>
2026-02-09 00:22:24 +08:00
Soulter e1b71540c7 chore: bump version to 4.14.8 and bump faiss-cpu version up to date 2026-02-09 00:19:12 +08:00
Soulter 85e1764857 feat: refactor release workflow and add special update handling for electron app (#4969) 2026-02-08 23:56:30 +08:00
238 changed files with 3640 additions and 1624 deletions
-92
View File
@@ -1,92 +0,0 @@
on:
push:
tags:
- 'v*'
workflow_dispatch:
name: Auto Release
jobs:
build-and-publish-to-github-release:
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- name: Checkout repository
uses: actions/checkout@v6
- name: Dashboard Build
run: |
cd dashboard
npm install
npm run build
echo "COMMIT_SHA=$(git rev-parse HEAD)" >> $GITHUB_ENV
echo ${{ github.ref_name }} > dist/assets/version
zip -r dist.zip dist
- name: Upload to Cloudflare R2
env:
R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }}
R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }}
R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }}
R2_BUCKET_NAME: "astrbot"
R2_OBJECT_NAME: "astrbot-webui-latest.zip"
VERSION_TAG: ${{ github.ref_name }}
run: |
echo "Installing rclone..."
curl https://rclone.org/install.sh | sudo bash
echo "Configuring rclone remote..."
mkdir -p ~/.config/rclone
cat <<EOF > ~/.config/rclone/rclone.conf
[r2]
type = s3
provider = Cloudflare
access_key_id = $R2_ACCESS_KEY_ID
secret_access_key = $R2_SECRET_ACCESS_KEY
endpoint = https://${R2_ACCOUNT_ID}.r2.cloudflarestorage.com
EOF
echo "Uploading dist.zip to R2 bucket: $R2_BUCKET_NAME/$R2_OBJECT_NAME"
mv dashboard/dist.zip dashboard/$R2_OBJECT_NAME
rclone copy dashboard/$R2_OBJECT_NAME r2:$R2_BUCKET_NAME --progress
mv dashboard/$R2_OBJECT_NAME dashboard/astrbot-webui-${VERSION_TAG}.zip
rclone copy dashboard/astrbot-webui-${VERSION_TAG}.zip r2:$R2_BUCKET_NAME --progress
mv dashboard/astrbot-webui-${VERSION_TAG}.zip dashboard/dist.zip
- name: Fetch Changelog
run: |
echo "changelog=changelogs/${{github.ref_name}}.md" >> "$GITHUB_ENV"
- name: Create GitHub Release
uses: ncipollo/release-action@v1
with:
bodyFile: ${{ env.changelog }}
artifacts: "dashboard/dist.zip"
build-and-publish-to-pypi:
# 构建并发布到 PyPI
runs-on: ubuntu-latest
needs: build-and-publish-to-github-release
steps:
- name: Checkout repository
uses: actions/checkout@v6
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: '3.10'
- name: Install uv
run: |
python -m pip install uv
- name: Build package
run: |
uv build
- name: Publish to PyPI
env:
UV_PUBLISH_TOKEN: ${{ secrets.PYPI_TOKEN }}
run: |
uv publish
-227
View File
@@ -1,227 +0,0 @@
name: Desktop Release
on:
push:
tags:
- "v*"
workflow_dispatch:
inputs:
ref:
description: "Git ref to build (branch/tag/SHA)"
required: false
default: "master"
tag:
description: "Release tag to upload assets to (for example: v4.14.6)"
required: false
permissions:
contents: write
jobs:
build-desktop:
name: Build ${{ matrix.name }}
runs-on: ${{ matrix.runner }}
strategy:
fail-fast: false
matrix:
include:
- name: linux-x64
runner: ubuntu-24.04
os: linux
arch: amd64
- name: linux-arm64
runner: ubuntu-24.04-arm
os: linux
arch: arm64
- name: windows-x64
runner: windows-2022
os: win
arch: amd64
- name: windows-arm64
runner: windows-11-arm
os: win
arch: arm64
- name: macos-x64
runner: macos-15-intel
os: mac
arch: amd64
- name: macos-arm64
runner: macos-15
os: mac
arch: arm64
env:
CSC_IDENTITY_AUTO_DISCOVERY: "false"
steps:
- name: Checkout repository
uses: actions/checkout@v6
with:
fetch-depth: 0
ref: ${{ inputs.ref || github.ref }}
- name: Setup uv
uses: astral-sh/setup-uv@v6
- name: Setup Python
uses: actions/setup-python@v6
with:
python-version: "3.11"
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 10.28.2
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: 20
cache: "pnpm"
cache-dependency-path: |
dashboard/pnpm-lock.yaml
desktop/pnpm-lock.yaml
- name: Install dependencies
run: |
uv sync
pnpm --dir dashboard install --frozen-lockfile
pnpm --dir desktop install --frozen-lockfile
- name: Build desktop package
run: |
pnpm --dir dashboard run build
pnpm --dir desktop run build:webui
pnpm --dir desktop run build:backend
pnpm --dir desktop run sync:version
pnpm --dir desktop exec electron-builder --publish never
- name: Resolve artifact tag
id: tag
shell: bash
run: |
if [ "${{ github.event_name }}" = "push" ]; then
tag="${GITHUB_REF_NAME}"
elif [ -n "${{ inputs.tag }}" ]; then
tag="${{ inputs.tag }}"
else
tag="$(git describe --tags --abbrev=0)"
fi
if [ -z "$tag" ]; then
echo "Failed to resolve artifact tag." >&2
exit 1
fi
echo "tag=$tag" >> "$GITHUB_OUTPUT"
- name: Normalize artifact names
shell: bash
env:
NAME_PREFIX: AstrBot-${{ steps.tag.outputs.tag }}-${{ matrix.arch }}-${{ matrix.os }}
run: |
shopt -s nullglob
out_dir="desktop/dist/release"
mkdir -p "$out_dir"
files=(
desktop/dist/*.AppImage
desktop/dist/*.dmg
desktop/dist/*.zip
desktop/dist/*.exe
)
if [ ${#files[@]} -eq 0 ]; then
echo "No desktop artifacts found to rename." >&2
exit 1
fi
for src in "${files[@]}"; do
file="$(basename "$src")"
case "$file" in
*.AppImage)
dest="$out_dir/${NAME_PREFIX}.AppImage"
;;
*.dmg)
dest="$out_dir/${NAME_PREFIX}.dmg"
;;
*.exe)
dest="$out_dir/${NAME_PREFIX}.exe"
;;
*.zip)
dest="$out_dir/${NAME_PREFIX}.zip"
;;
*)
continue
;;
esac
cp "$src" "$dest"
done
ls -la "$out_dir"
- name: Upload desktop artifacts
uses: actions/upload-artifact@v6
with:
name: AstrBot-${{ steps.tag.outputs.tag }}-${{ matrix.arch }}-${{ matrix.os }}
if-no-files-found: error
path: desktop/dist/release/*
publish-release:
name: Publish Release Assets
runs-on: ubuntu-24.04
needs: build-desktop
steps:
- name: Checkout repository
uses: actions/checkout@v6
with:
fetch-depth: 0
ref: ${{ inputs.ref || github.ref }}
- name: Resolve release tag
id: tag
shell: bash
run: |
if [ "${{ github.event_name }}" = "push" ]; then
tag="${GITHUB_REF_NAME}"
elif [ -n "${{ inputs.tag }}" ]; then
tag="${{ inputs.tag }}"
else
tag="$(git describe --tags --abbrev=0)"
fi
if [ -z "$tag" ]; then
echo "Failed to resolve release tag." >&2
exit 1
fi
echo "tag=$tag" >> "$GITHUB_OUTPUT"
- name: Download built artifacts
uses: actions/download-artifact@v6
with:
pattern: AstrBot-${{ steps.tag.outputs.tag }}-*
path: release-assets
merge-multiple: true
- name: Ensure release exists
env:
GH_TOKEN: ${{ github.token }}
shell: bash
run: |
tag="${{ steps.tag.outputs.tag }}"
if ! gh release view "$tag" >/dev/null 2>&1; then
gh release create "$tag" --title "$tag" --notes ""
fi
- name: Remove stale desktop assets from release
env:
GH_TOKEN: ${{ github.token }}
shell: bash
run: |
tag="${{ steps.tag.outputs.tag }}"
while IFS= read -r asset; do
case "$asset" in
*.AppImage|*.dmg|*.zip|*.exe|*.blockmap)
gh release delete-asset "$tag" "$asset" -y || true
;;
esac
done < <(gh release view "$tag" --json assets --jq '.assets[].name')
- name: Upload assets to release
env:
GH_TOKEN: ${{ github.token }}
shell: bash
run: |
tag="${{ steps.tag.outputs.tag }}"
gh release upload "$tag" release-assets/* --clobber
+2 -2
View File
@@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
env:
DOCKER_HUB_USERNAME: ${{ secrets.DOCKER_HUB_USERNAME }}
GHCR_OWNER: soulter
GHCR_OWNER: astrbotdevs
HAS_GHCR_TOKEN: ${{ secrets.GHCR_GITHUB_TOKEN != '' }}
steps:
@@ -113,7 +113,7 @@ jobs:
runs-on: ubuntu-latest
env:
DOCKER_HUB_USERNAME: ${{ secrets.DOCKER_HUB_USERNAME }}
GHCR_OWNER: soulter
GHCR_OWNER: astrbotdevs
HAS_GHCR_TOKEN: ${{ secrets.GHCR_GITHUB_TOKEN != '' }}
steps:
+377
View File
@@ -0,0 +1,377 @@
name: Release
on:
push:
tags:
- "v*"
workflow_dispatch:
inputs:
ref:
description: "Git ref to build (branch/tag/SHA)"
required: false
default: "master"
tag:
description: "Release tag to publish assets to (for example: v4.14.6)"
required: false
permissions:
contents: write
jobs:
build-dashboard:
name: Build Dashboard
runs-on: ubuntu-24.04
env:
R2_ACCOUNT_ID: ${{ secrets.R2_ACCOUNT_ID }}
R2_ACCESS_KEY_ID: ${{ secrets.R2_ACCESS_KEY_ID }}
R2_SECRET_ACCESS_KEY: ${{ secrets.R2_SECRET_ACCESS_KEY }}
steps:
- name: Checkout repository
uses: actions/checkout@v6
with:
fetch-depth: 0
ref: ${{ inputs.ref || github.ref }}
- name: Resolve tag
id: tag
shell: bash
run: |
if [ "${{ github.event_name }}" = "push" ]; then
tag="${GITHUB_REF_NAME}"
elif [ -n "${{ inputs.tag }}" ]; then
tag="${{ inputs.tag }}"
else
tag="$(git describe --tags --abbrev=0)"
fi
if [ -z "$tag" ]; then
echo "Failed to resolve tag." >&2
exit 1
fi
echo "tag=$tag" >> "$GITHUB_OUTPUT"
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 10.28.2
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: 20
cache: "pnpm"
cache-dependency-path: dashboard/pnpm-lock.yaml
- name: Build dashboard dist
shell: bash
run: |
pnpm --dir dashboard install --frozen-lockfile
pnpm --dir dashboard run build
echo "${{ steps.tag.outputs.tag }}" > dashboard/dist/assets/version
cd dashboard
zip -r "AstrBot-${{ steps.tag.outputs.tag }}-dashboard.zip" dist
- name: Upload dashboard artifact
uses: actions/upload-artifact@v6
with:
name: Dashboard-${{ steps.tag.outputs.tag }}
if-no-files-found: error
path: dashboard/AstrBot-${{ steps.tag.outputs.tag }}-dashboard.zip
- name: Upload dashboard package to Cloudflare R2
if: ${{ env.R2_ACCOUNT_ID != '' && env.R2_ACCESS_KEY_ID != '' && env.R2_SECRET_ACCESS_KEY != '' }}
env:
R2_BUCKET_NAME: "astrbot"
R2_OBJECT_NAME: "astrbot-webui-latest.zip"
VERSION_TAG: ${{ steps.tag.outputs.tag }}
shell: bash
run: |
curl https://rclone.org/install.sh | sudo bash
mkdir -p ~/.config/rclone
cat <<EOF > ~/.config/rclone/rclone.conf
[r2]
type = s3
provider = Cloudflare
access_key_id = $R2_ACCESS_KEY_ID
secret_access_key = $R2_SECRET_ACCESS_KEY
endpoint = https://${R2_ACCOUNT_ID}.r2.cloudflarestorage.com
EOF
cp "dashboard/AstrBot-${VERSION_TAG}-dashboard.zip" "dashboard/${R2_OBJECT_NAME}"
rclone copy "dashboard/${R2_OBJECT_NAME}" "r2:${R2_BUCKET_NAME}" --progress
cp "dashboard/AstrBot-${VERSION_TAG}-dashboard.zip" "dashboard/astrbot-webui-${VERSION_TAG}.zip"
rclone copy "dashboard/astrbot-webui-${VERSION_TAG}.zip" "r2:${R2_BUCKET_NAME}" --progress
build-desktop:
name: Build ${{ matrix.name }}
runs-on: ${{ matrix.runner }}
strategy:
fail-fast: false
matrix:
include:
- name: linux-x64
runner: ubuntu-24.04
os: linux
arch: amd64
- name: linux-arm64
runner: ubuntu-24.04-arm
os: linux
arch: arm64
- name: windows-x64
runner: windows-2022
os: win
arch: amd64
- name: windows-arm64
runner: windows-11-arm
os: win
arch: arm64
- name: macos-x64
runner: macos-15-intel
os: mac
arch: amd64
- name: macos-arm64
runner: macos-15
os: mac
arch: arm64
env:
CSC_IDENTITY_AUTO_DISCOVERY: "false"
steps:
- name: Checkout repository
uses: actions/checkout@v6
with:
fetch-depth: 0
ref: ${{ inputs.ref || github.ref }}
- name: Resolve tag
id: tag
shell: bash
run: |
if [ "${{ github.event_name }}" = "push" ]; then
tag="${GITHUB_REF_NAME}"
elif [ -n "${{ inputs.tag }}" ]; then
tag="${{ inputs.tag }}"
else
tag="$(git describe --tags --abbrev=0)"
fi
if [ -z "$tag" ]; then
echo "Failed to resolve tag." >&2
exit 1
fi
echo "tag=$tag" >> "$GITHUB_OUTPUT"
- name: Setup uv
uses: astral-sh/setup-uv@v7
- name: Setup Python
uses: actions/setup-python@v6
with:
python-version: "3.12"
- name: Setup pnpm
uses: pnpm/action-setup@v4
with:
version: 10.28.2
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: 20
cache: "pnpm"
cache-dependency-path: |
dashboard/pnpm-lock.yaml
desktop/pnpm-lock.yaml
- name: Prepare OpenSSL for Windows ARM64
if: ${{ matrix.os == 'win' && matrix.arch == 'arm64' }}
shell: pwsh
run: |
git clone https://github.com/microsoft/vcpkg.git C:\vcpkg
& C:\vcpkg\bootstrap-vcpkg.bat -disableMetrics
& C:\vcpkg\vcpkg.exe install openssl:arm64-windows
"VCPKG_ROOT=C:\vcpkg" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
"VCPKGRS_TRIPLET=arm64-windows" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
"OPENSSL_DIR=C:\vcpkg\installed\arm64-windows" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
"OPENSSL_ROOT_DIR=C:\vcpkg\installed\arm64-windows" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
"OPENSSL_LIB_DIR=C:\vcpkg\installed\arm64-windows\lib" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
"OPENSSL_INCLUDE_DIR=C:\vcpkg\installed\arm64-windows\include" | Out-File -FilePath $env:GITHUB_ENV -Encoding utf8 -Append
- name: Install dependencies
shell: bash
run: |
uv sync
pnpm --dir dashboard install --frozen-lockfile
pnpm --dir desktop install --frozen-lockfile
- name: Build desktop package
shell: bash
run: |
pnpm --dir dashboard run build
pnpm --dir desktop run build:webui
pnpm --dir desktop run build:backend
pnpm --dir desktop run sync:version
pnpm --dir desktop exec electron-builder --publish never
- name: Normalize artifact names
shell: bash
env:
NAME_PREFIX: AstrBot-${{ steps.tag.outputs.tag }}-${{ matrix.arch }}-${{ matrix.os }}
run: |
shopt -s nullglob
out_dir="desktop/dist/release"
mkdir -p "$out_dir"
files=(
desktop/dist/*.AppImage
desktop/dist/*.dmg
desktop/dist/*.zip
desktop/dist/*.exe
)
if [ ${#files[@]} -eq 0 ]; then
echo "No desktop artifacts found to rename." >&2
exit 1
fi
for src in "${files[@]}"; do
file="$(basename "$src")"
case "$file" in
*.AppImage)
dest="$out_dir/${NAME_PREFIX}.AppImage"
;;
*.dmg)
dest="$out_dir/${NAME_PREFIX}.dmg"
;;
*.exe)
dest="$out_dir/${NAME_PREFIX}.exe"
;;
*.zip)
dest="$out_dir/${NAME_PREFIX}.zip"
;;
*)
continue
;;
esac
cp "$src" "$dest"
done
ls -la "$out_dir"
- name: Upload desktop artifacts
uses: actions/upload-artifact@v6
with:
name: AstrBot-${{ steps.tag.outputs.tag }}-${{ matrix.arch }}-${{ matrix.os }}
if-no-files-found: error
path: desktop/dist/release/*
publish-release:
name: Publish GitHub Release
runs-on: ubuntu-24.04
needs:
- build-dashboard
- build-desktop
steps:
- name: Checkout repository
uses: actions/checkout@v6
with:
fetch-depth: 0
ref: ${{ inputs.ref || github.ref }}
- name: Resolve tag
id: tag
shell: bash
run: |
if [ "${{ github.event_name }}" = "push" ]; then
tag="${GITHUB_REF_NAME}"
elif [ -n "${{ inputs.tag }}" ]; then
tag="${{ inputs.tag }}"
else
tag="$(git describe --tags --abbrev=0)"
fi
if [ -z "$tag" ]; then
echo "Failed to resolve tag." >&2
exit 1
fi
echo "tag=$tag" >> "$GITHUB_OUTPUT"
- name: Download dashboard artifact
uses: actions/download-artifact@v7
with:
name: Dashboard-${{ steps.tag.outputs.tag }}
path: release-assets
- name: Download desktop artifacts
uses: actions/download-artifact@v7
with:
pattern: AstrBot-${{ steps.tag.outputs.tag }}-*
path: release-assets
merge-multiple: true
- name: Resolve release notes
id: notes
shell: bash
run: |
note_file="changelogs/${{ steps.tag.outputs.tag }}.md"
if [ ! -f "$note_file" ]; then
note_file="$(mktemp)"
echo "Release ${{ steps.tag.outputs.tag }}" > "$note_file"
fi
echo "file=$note_file" >> "$GITHUB_OUTPUT"
- name: Ensure release exists
env:
GH_TOKEN: ${{ github.token }}
shell: bash
run: |
tag="${{ steps.tag.outputs.tag }}"
if ! gh release view "$tag" >/dev/null 2>&1; then
gh release create "$tag" --title "$tag" --notes-file "${{ steps.notes.outputs.file }}"
fi
- name: Remove stale assets from release
env:
GH_TOKEN: ${{ github.token }}
shell: bash
run: |
tag="${{ steps.tag.outputs.tag }}"
while IFS= read -r asset; do
case "$asset" in
*.AppImage|*.dmg|*.zip|*.exe|*.blockmap)
gh release delete-asset "$tag" "$asset" -y || true
;;
esac
done < <(gh release view "$tag" --json assets --jq '.assets[].name')
- name: Upload assets to release
env:
GH_TOKEN: ${{ github.token }}
shell: bash
run: |
tag="${{ steps.tag.outputs.tag }}"
gh release upload "$tag" release-assets/* --clobber
publish-pypi:
name: Publish PyPI
runs-on: ubuntu-24.04
needs: publish-release
steps:
- name: Checkout repository
uses: actions/checkout@v6
with:
fetch-depth: 0
ref: ${{ inputs.ref || github.ref }}
- name: Set up Python
uses: actions/setup-python@v6
with:
python-version: "3.10"
- name: Install uv
shell: bash
run: python -m pip install uv
- name: Build package
shell: bash
run: uv build
- name: Publish to PyPI
env:
UV_PUBLISH_TOKEN: ${{ secrets.PYPI_TOKEN }}
shell: bash
run: uv publish
@@ -17,7 +17,7 @@ from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
class LongTermMemory:
def __init__(self, acm: AstrBotConfigManager, context: star.Context):
def __init__(self, acm: AstrBotConfigManager, context: star.Context) -> None:
self.acm = acm
self.context = context
self.session_chats = defaultdict(list)
@@ -111,7 +111,7 @@ class LongTermMemory:
return False
async def handle_message(self, event: AstrMessageEvent):
async def handle_message(self, event: AstrMessageEvent) -> None:
"""仅支持群聊"""
if event.get_message_type() == MessageType.GROUP_MESSAGE:
datetime_str = datetime.datetime.now().strftime("%H:%M:%S")
@@ -148,7 +148,7 @@ class LongTermMemory:
if len(self.session_chats[event.unified_msg_origin]) > cfg["max_cnt"]:
self.session_chats[event.unified_msg_origin].pop(0)
async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest):
async def on_req_llm(self, event: AstrMessageEvent, req: ProviderRequest) -> None:
"""当触发 LLM 请求前,调用此方法修改 req"""
if event.unified_msg_origin not in self.session_chats:
return
@@ -171,7 +171,9 @@ class LongTermMemory:
)
req.system_prompt += chats_str
async def after_req_llm(self, event: AstrMessageEvent, llm_resp: LLMResponse):
async def after_req_llm(
self, event: AstrMessageEvent, llm_resp: LLMResponse
) -> None:
if event.unified_msg_origin not in self.session_chats:
return
+7 -3
View File
@@ -85,7 +85,9 @@ class Main(star.Star):
logger.error(f"主动回复失败: {e}")
@filter.on_llm_request()
async def decorate_llm_req(self, event: AstrMessageEvent, req: ProviderRequest):
async def decorate_llm_req(
self, event: AstrMessageEvent, req: ProviderRequest
) -> None:
"""在请求 LLM 前注入人格信息、Identifier、时间、回复内容等 System Prompt"""
if self.ltm and self.ltm_enabled(event):
try:
@@ -94,7 +96,9 @@ class Main(star.Star):
logger.error(f"ltm: {e}")
@filter.on_llm_response()
async def record_llm_resp_to_ltm(self, event: AstrMessageEvent, resp: LLMResponse):
async def record_llm_resp_to_ltm(
self, event: AstrMessageEvent, resp: LLMResponse
) -> None:
"""在 LLM 响应后记录对话"""
if self.ltm and self.ltm_enabled(event):
try:
@@ -103,7 +107,7 @@ class Main(star.Star):
logger.error(f"ltm: {e}")
@filter.after_message_sent()
async def after_message_sent(self, event: AstrMessageEvent):
async def after_message_sent(self, event: AstrMessageEvent) -> None:
"""消息发送后处理"""
if self.ltm and self.ltm_enabled(event):
try:
@@ -5,10 +5,10 @@ from astrbot.core.utils.io import download_dashboard
class AdminCommands:
def __init__(self, context: star.Context):
def __init__(self, context: star.Context) -> None:
self.context = context
async def op(self, event: AstrMessageEvent, admin_id: str = ""):
async def op(self, event: AstrMessageEvent, admin_id: str = "") -> None:
"""授权管理员。op <admin_id>"""
if not admin_id:
event.set_result(
@@ -21,7 +21,7 @@ class AdminCommands:
self.context.get_config().save_config()
event.set_result(MessageEventResult().message("授权成功。"))
async def deop(self, event: AstrMessageEvent, admin_id: str = ""):
async def deop(self, event: AstrMessageEvent, admin_id: str = "") -> None:
"""取消授权管理员。deop <admin_id>"""
if not admin_id:
event.set_result(
@@ -39,7 +39,7 @@ class AdminCommands:
MessageEventResult().message("此用户 ID 不在管理员名单内。"),
)
async def wl(self, event: AstrMessageEvent, sid: str = ""):
async def wl(self, event: AstrMessageEvent, sid: str = "") -> None:
"""添加白名单。wl <sid>"""
if not sid:
event.set_result(
@@ -53,7 +53,7 @@ class AdminCommands:
cfg.save_config()
event.set_result(MessageEventResult().message("添加白名单成功。"))
async def dwl(self, event: AstrMessageEvent, sid: str = ""):
async def dwl(self, event: AstrMessageEvent, sid: str = "") -> None:
"""删除白名单。dwl <sid>"""
if not sid:
event.set_result(
@@ -70,7 +70,7 @@ class AdminCommands:
except ValueError:
event.set_result(MessageEventResult().message("此 SID 不在白名单内。"))
async def update_dashboard(self, event: AstrMessageEvent):
async def update_dashboard(self, event: AstrMessageEvent) -> None:
"""更新管理面板"""
await event.send(MessageChain().message("正在尝试更新管理面板..."))
await download_dashboard(version=f"v{VERSION}", latest=False)
@@ -11,10 +11,10 @@ from .utils.rst_scene import RstScene
class AlterCmdCommands(CommandParserMixin):
def __init__(self, context: star.Context):
def __init__(self, context: star.Context) -> None:
self.context = context
async def update_reset_permission(self, scene_key: str, perm_type: str):
async def update_reset_permission(self, scene_key: str, perm_type: str) -> None:
"""更新reset命令在特定场景下的权限设置"""
from astrbot.api import sp
@@ -26,7 +26,7 @@ class AlterCmdCommands(CommandParserMixin):
alter_cmd_cfg["astrbot"] = plugin_cfg
await sp.global_put("alter_cmd", alter_cmd_cfg)
async def alter_cmd(self, event: AstrMessageEvent):
async def alter_cmd(self, event: AstrMessageEvent) -> None:
token = self.parse_commands(event.message_str)
if token.len < 3:
await event.send(
@@ -16,7 +16,7 @@ THIRD_PARTY_AGENT_RUNNER_STR = ", ".join(THIRD_PARTY_AGENT_RUNNER_KEY.keys())
class ConversationCommands:
def __init__(self, context: star.Context):
def __init__(self, context: star.Context) -> None:
self.context = context
async def _get_current_persona_id(self, session_id):
@@ -33,7 +33,7 @@ class ConversationCommands:
return None
return conv.persona_id
async def reset(self, message: AstrMessageEvent):
async def reset(self, message: AstrMessageEvent) -> None:
"""重置 LLM 会话"""
umo = message.unified_msg_origin
cfg = self.context.get_config(umo=message.unified_msg_origin)
@@ -98,7 +98,7 @@ class ConversationCommands:
message.set_result(MessageEventResult().message(ret))
async def his(self, message: AstrMessageEvent, page: int = 1):
async def his(self, message: AstrMessageEvent, page: int = 1) -> None:
"""查看对话记录"""
if not self.context.get_using_provider(message.unified_msg_origin):
message.set_result(
@@ -141,7 +141,7 @@ class ConversationCommands:
message.set_result(MessageEventResult().message(ret).use_t2i(False))
async def convs(self, message: AstrMessageEvent, page: int = 1):
async def convs(self, message: AstrMessageEvent, page: int = 1) -> None:
"""查看对话列表"""
cfg = self.context.get_config(umo=message.unified_msg_origin)
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
@@ -216,7 +216,7 @@ class ConversationCommands:
message.set_result(MessageEventResult().message(ret).use_t2i(False))
return
async def new_conv(self, message: AstrMessageEvent):
async def new_conv(self, message: AstrMessageEvent) -> None:
"""创建新对话"""
cfg = self.context.get_config(umo=message.unified_msg_origin)
agent_runner_type = cfg["provider_settings"]["agent_runner_type"]
@@ -242,7 +242,7 @@ class ConversationCommands:
MessageEventResult().message(f"切换到新对话: 新对话({cid[:4]})。"),
)
async def groupnew_conv(self, message: AstrMessageEvent, sid: str = ""):
async def groupnew_conv(self, message: AstrMessageEvent, sid: str = "") -> None:
"""创建新群聊对话"""
if sid:
session = str(
@@ -273,7 +273,7 @@ class ConversationCommands:
self,
message: AstrMessageEvent,
index: int | None = None,
):
) -> None:
"""通过 /ls 前面的序号切换对话"""
if not isinstance(index, int):
message.set_result(
@@ -308,7 +308,7 @@ class ConversationCommands:
),
)
async def rename_conv(self, message: AstrMessageEvent, new_name: str = ""):
async def rename_conv(self, message: AstrMessageEvent, new_name: str = "") -> None:
"""重命名对话"""
if not new_name:
message.set_result(MessageEventResult().message("请输入新的对话名称。"))
@@ -319,7 +319,7 @@ class ConversationCommands:
)
message.set_result(MessageEventResult().message("重命名对话成功。"))
async def del_conv(self, message: AstrMessageEvent):
async def del_conv(self, message: AstrMessageEvent) -> None:
"""删除当前对话"""
cfg = self.context.get_config(umo=message.unified_msg_origin)
is_unique_session = cfg["platform_settings"]["unique_session"]
@@ -8,7 +8,7 @@ from astrbot.core.utils.io import get_dashboard_version
class HelpCommand:
def __init__(self, context: star.Context):
def __init__(self, context: star.Context) -> None:
self.context = context
async def _query_astrbot_notice(self):
@@ -34,7 +34,7 @@ class HelpCommand:
lines: list[str] = []
hidden_commands = {"set", "unset", "websearch"}
def walk(items: list[dict], indent: int = 0):
def walk(items: list[dict], indent: int = 0) -> None:
for item in items:
if not item.get("reserved") or not item.get("enabled"):
continue
@@ -62,7 +62,7 @@ class HelpCommand:
walk(commands)
return lines
async def help(self, event: AstrMessageEvent):
async def help(self, event: AstrMessageEvent) -> None:
"""查看帮助"""
notice = ""
try:
@@ -3,10 +3,10 @@ from astrbot.api.event import AstrMessageEvent, MessageChain
class LLMCommands:
def __init__(self, context: star.Context):
def __init__(self, context: star.Context) -> None:
self.context = context
async def llm(self, event: AstrMessageEvent):
async def llm(self, event: AstrMessageEvent) -> None:
"""开启/关闭 LLM"""
cfg = self.context.get_config(umo=event.unified_msg_origin)
enable = cfg["provider_settings"].get("enable", True)
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
class PersonaCommands:
def __init__(self, context: star.Context):
def __init__(self, context: star.Context) -> None:
self.context = context
def _build_tree_output(
@@ -50,7 +50,7 @@ class PersonaCommands:
return lines
async def persona(self, message: AstrMessageEvent):
async def persona(self, message: AstrMessageEvent) -> None:
l = message.message_str.split(" ") # noqa: E741
umo = message.unified_msg_origin
@@ -8,10 +8,10 @@ from astrbot.core.star.star_manager import PluginManager
class PluginCommands:
def __init__(self, context: star.Context):
def __init__(self, context: star.Context) -> None:
self.context = context
async def plugin_ls(self, event: AstrMessageEvent):
async def plugin_ls(self, event: AstrMessageEvent) -> None:
"""获取已经安装的插件列表。"""
parts = ["已加载的插件:\n"]
for plugin in self.context.get_all_stars():
@@ -30,7 +30,7 @@ class PluginCommands:
MessageEventResult().message(f"{plugin_list_info}").use_t2i(False),
)
async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = ""):
async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""禁用插件"""
if DEMO_MODE:
event.set_result(MessageEventResult().message("演示模式下无法禁用插件。"))
@@ -43,7 +43,7 @@ class PluginCommands:
await self.context._star_manager.turn_off_plugin(plugin_name) # type: ignore
event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已禁用。"))
async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = ""):
async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""启用插件"""
if DEMO_MODE:
event.set_result(MessageEventResult().message("演示模式下无法启用插件。"))
@@ -56,7 +56,7 @@ class PluginCommands:
await self.context._star_manager.turn_on_plugin(plugin_name) # type: ignore
event.set_result(MessageEventResult().message(f"插件 {plugin_name} 已启用。"))
async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = ""):
async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = "") -> None:
"""安装插件"""
if DEMO_MODE:
event.set_result(MessageEventResult().message("演示模式下无法安装插件。"))
@@ -77,7 +77,7 @@ class PluginCommands:
event.set_result(MessageEventResult().message(f"安装插件失败: {e}"))
return
async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = ""):
async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""获取插件帮助"""
if not plugin_name:
event.set_result(
@@ -8,7 +8,7 @@ from astrbot.core.provider.entities import ProviderType
class ProviderCommands:
def __init__(self, context: star.Context):
def __init__(self, context: star.Context) -> None:
self.context = context
def _log_reachability_failure(
@@ -17,7 +17,7 @@ class ProviderCommands:
provider_capability_type: ProviderType | None,
err_code: str,
err_reason: str,
):
) -> None:
"""记录不可达原因到日志。"""
meta = provider.meta()
logger.warning(
@@ -49,7 +49,7 @@ class ProviderCommands:
event: AstrMessageEvent,
idx: str | int | None = None,
idx2: int | None = None,
):
) -> None:
"""查看或者切换 LLM Provider"""
umo = event.unified_msg_origin
cfg = self.context.get_config(umo).get("provider_settings", {})
@@ -228,7 +228,7 @@ class ProviderCommands:
self,
message: AstrMessageEvent,
idx_or_name: int | str | None = None,
):
) -> None:
"""查看或者切换模型"""
prov = self.context.get_using_provider(message.unified_msg_origin)
if not prov:
@@ -293,7 +293,7 @@ class ProviderCommands:
MessageEventResult().message(f"切换模型到 {prov.get_model()}"),
)
async def key(self, message: AstrMessageEvent, index: int | None = None):
async def key(self, message: AstrMessageEvent, index: int | None = None) -> None:
prov = self.context.get_using_provider(message.unified_msg_origin)
if not prov:
message.set_result(
@@ -3,10 +3,10 @@ from astrbot.api.event import AstrMessageEvent, MessageEventResult
class SetUnsetCommands:
def __init__(self, context: star.Context):
def __init__(self, context: star.Context) -> None:
self.context = context
async def set_variable(self, event: AstrMessageEvent, key: str, value: str):
async def set_variable(self, event: AstrMessageEvent, key: str, value: str) -> None:
"""设置会话变量"""
uid = event.unified_msg_origin
session_var = await sp.session_get(uid, "session_variables", {})
@@ -19,7 +19,7 @@ class SetUnsetCommands:
),
)
async def unset_variable(self, event: AstrMessageEvent, key: str):
async def unset_variable(self, event: AstrMessageEvent, key: str) -> None:
"""移除会话变量"""
uid = event.unified_msg_origin
session_var = await sp.session_get(uid, "session_variables", {})
@@ -7,10 +7,10 @@ from astrbot.api.event import AstrMessageEvent, MessageEventResult
class SIDCommand:
"""会话ID命令类"""
def __init__(self, context: star.Context):
def __init__(self, context: star.Context) -> None:
self.context = context
async def sid(self, event: AstrMessageEvent):
async def sid(self, event: AstrMessageEvent) -> None:
"""获取消息来源信息"""
sid = event.unified_msg_origin
user_id = str(event.get_sender_id())
@@ -7,10 +7,10 @@ from astrbot.api.event import AstrMessageEvent, MessageEventResult
class T2ICommand:
"""文本转图片命令类"""
def __init__(self, context: star.Context):
def __init__(self, context: star.Context) -> None:
self.context = context
async def t2i(self, event: AstrMessageEvent):
async def t2i(self, event: AstrMessageEvent) -> None:
"""开关文本转图片"""
config = self.context.get_config(umo=event.unified_msg_origin)
if config["t2i"]:
@@ -8,10 +8,10 @@ from astrbot.core.star.session_llm_manager import SessionServiceManager
class TTSCommand:
"""文本转语音命令类"""
def __init__(self, context: star.Context):
def __init__(self, context: star.Context) -> None:
self.context = context
async def tts(self, event: AstrMessageEvent):
async def tts(self, event: AstrMessageEvent) -> None:
"""开关文本转语音(会话级别)"""
umo = event.unified_msg_origin
ses_tts = await SessionServiceManager.is_tts_enabled_for_session(umo)
+33 -31
View File
@@ -35,84 +35,84 @@ class Main(star.Star):
self.sid_c = SIDCommand(self.context)
@filter.command("help")
async def help(self, event: AstrMessageEvent):
async def help(self, event: AstrMessageEvent) -> None:
"""查看帮助"""
await self.help_c.help(event)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("llm")
async def llm(self, event: AstrMessageEvent):
async def llm(self, event: AstrMessageEvent) -> None:
"""开启/关闭 LLM"""
await self.llm_c.llm(event)
@filter.command_group("plugin")
def plugin(self):
def plugin(self) -> None:
"""插件管理"""
@plugin.command("ls")
async def plugin_ls(self, event: AstrMessageEvent):
async def plugin_ls(self, event: AstrMessageEvent) -> None:
"""获取已经安装的插件列表。"""
await self.plugin_c.plugin_ls(event)
@filter.permission_type(filter.PermissionType.ADMIN)
@plugin.command("off")
async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = ""):
async def plugin_off(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""禁用插件"""
await self.plugin_c.plugin_off(event, plugin_name)
@filter.permission_type(filter.PermissionType.ADMIN)
@plugin.command("on")
async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = ""):
async def plugin_on(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""启用插件"""
await self.plugin_c.plugin_on(event, plugin_name)
@filter.permission_type(filter.PermissionType.ADMIN)
@plugin.command("get")
async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = ""):
async def plugin_get(self, event: AstrMessageEvent, plugin_repo: str = "") -> None:
"""安装插件"""
await self.plugin_c.plugin_get(event, plugin_repo)
@plugin.command("help")
async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = ""):
async def plugin_help(self, event: AstrMessageEvent, plugin_name: str = "") -> None:
"""获取插件帮助"""
await self.plugin_c.plugin_help(event, plugin_name)
@filter.command("t2i")
async def t2i(self, event: AstrMessageEvent):
async def t2i(self, event: AstrMessageEvent) -> None:
"""开关文本转图片"""
await self.t2i_c.t2i(event)
@filter.command("tts")
async def tts(self, event: AstrMessageEvent):
async def tts(self, event: AstrMessageEvent) -> None:
"""开关文本转语音(会话级别)"""
await self.tts_c.tts(event)
@filter.command("sid")
async def sid(self, event: AstrMessageEvent):
async def sid(self, event: AstrMessageEvent) -> None:
"""获取会话 ID 和 管理员 ID"""
await self.sid_c.sid(event)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("op")
async def op(self, event: AstrMessageEvent, admin_id: str = ""):
async def op(self, event: AstrMessageEvent, admin_id: str = "") -> None:
"""授权管理员。op <admin_id>"""
await self.admin_c.op(event, admin_id)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("deop")
async def deop(self, event: AstrMessageEvent, admin_id: str):
async def deop(self, event: AstrMessageEvent, admin_id: str) -> None:
"""取消授权管理员。deop <admin_id>"""
await self.admin_c.deop(event, admin_id)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("wl")
async def wl(self, event: AstrMessageEvent, sid: str = ""):
async def wl(self, event: AstrMessageEvent, sid: str = "") -> None:
"""添加白名单。wl <sid>"""
await self.admin_c.wl(event, sid)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("dwl")
async def dwl(self, event: AstrMessageEvent, sid: str):
async def dwl(self, event: AstrMessageEvent, sid: str) -> None:
"""删除白名单。dwl <sid>"""
await self.admin_c.dwl(event, sid)
@@ -123,12 +123,12 @@ class Main(star.Star):
event: AstrMessageEvent,
idx: str | int | None = None,
idx2: int | None = None,
):
) -> None:
"""查看或者切换 LLM Provider"""
await self.provider_c.provider(event, idx, idx2)
@filter.command("reset")
async def reset(self, message: AstrMessageEvent):
async def reset(self, message: AstrMessageEvent) -> None:
"""重置 LLM 会话"""
await self.conversation_c.reset(message)
@@ -138,74 +138,76 @@ class Main(star.Star):
self,
message: AstrMessageEvent,
idx_or_name: int | str | None = None,
):
) -> None:
"""查看或者切换模型"""
await self.provider_c.model_ls(message, idx_or_name)
@filter.command("history")
async def his(self, message: AstrMessageEvent, page: int = 1):
async def his(self, message: AstrMessageEvent, page: int = 1) -> None:
"""查看对话记录"""
await self.conversation_c.his(message, page)
@filter.command("ls")
async def convs(self, message: AstrMessageEvent, page: int = 1):
async def convs(self, message: AstrMessageEvent, page: int = 1) -> None:
"""查看对话列表"""
await self.conversation_c.convs(message, page)
@filter.command("new")
async def new_conv(self, message: AstrMessageEvent):
async def new_conv(self, message: AstrMessageEvent) -> None:
"""创建新对话"""
await self.conversation_c.new_conv(message)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("groupnew")
async def groupnew_conv(self, message: AstrMessageEvent, sid: str):
async def groupnew_conv(self, message: AstrMessageEvent, sid: str) -> None:
"""创建新群聊对话"""
await self.conversation_c.groupnew_conv(message, sid)
@filter.command("switch")
async def switch_conv(self, message: AstrMessageEvent, index: int | None = None):
async def switch_conv(
self, message: AstrMessageEvent, index: int | None = None
) -> None:
"""通过 /ls 前面的序号切换对话"""
await self.conversation_c.switch_conv(message, index)
@filter.command("rename")
async def rename_conv(self, message: AstrMessageEvent, new_name: str):
async def rename_conv(self, message: AstrMessageEvent, new_name: str) -> None:
"""重命名对话"""
await self.conversation_c.rename_conv(message, new_name)
@filter.command("del")
async def del_conv(self, message: AstrMessageEvent):
async def del_conv(self, message: AstrMessageEvent) -> None:
"""删除当前对话"""
await self.conversation_c.del_conv(message)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("key")
async def key(self, message: AstrMessageEvent, index: int | None = None):
async def key(self, message: AstrMessageEvent, index: int | None = None) -> None:
"""查看或者切换 Key"""
await self.provider_c.key(message, index)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("persona")
async def persona(self, message: AstrMessageEvent):
async def persona(self, message: AstrMessageEvent) -> None:
"""查看或者切换 Persona"""
await self.persona_c.persona(message)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("dashboard_update")
async def update_dashboard(self, event: AstrMessageEvent):
async def update_dashboard(self, event: AstrMessageEvent) -> None:
"""更新管理面板"""
await self.admin_c.update_dashboard(event)
@filter.command("set")
async def set_variable(self, event: AstrMessageEvent, key: str, value: str):
async def set_variable(self, event: AstrMessageEvent, key: str, value: str) -> None:
await self.setunset_c.set_variable(event, key, value)
@filter.command("unset")
async def unset_variable(self, event: AstrMessageEvent, key: str):
async def unset_variable(self, event: AstrMessageEvent, key: str) -> None:
await self.setunset_c.unset_variable(event, key)
@filter.permission_type(filter.PermissionType.ADMIN)
@filter.command("alter_cmd", alias={"alter"})
async def alter_cmd(self, event: AstrMessageEvent):
async def alter_cmd(self, event: AstrMessageEvent) -> None:
"""修改命令权限"""
await self.alter_cmd_c.alter_cmd(event)
@@ -17,11 +17,11 @@ from astrbot.core.utils.session_waiter import (
class Main(Star):
"""会话控制"""
def __init__(self, context: Context):
def __init__(self, context: Context) -> None:
super().__init__(context)
@filter.event_message_type(filter.EventMessageType.ALL, priority=maxsize)
async def handle_session_control_agent(self, event: AstrMessageEvent):
async def handle_session_control_agent(self, event: AstrMessageEvent) -> None:
"""会话控制代理"""
for session_filter in FILTERS:
session_id = session_filter.filter(event)
@@ -90,7 +90,7 @@ class Main(Star):
async def empty_mention_waiter(
controller: SessionController,
event: AstrMessageEvent,
):
) -> None:
event.message_obj.message.insert(
0,
Comp.At(qq=event.get_self_id(), name=event.get_self_id()),
@@ -49,7 +49,7 @@ class SearchEngine:
def _set_selector(self, selector: str) -> str:
raise NotImplementedError
def _get_next_page(self, query: str):
async def _get_next_page(self, query: str) -> str:
raise NotImplementedError
async def _get_html(self, url: str, data: dict | None = None) -> str:
+3 -3
View File
@@ -199,7 +199,7 @@ class Main(star.Star):
return results
@filter.command("websearch")
async def websearch(self, event: AstrMessageEvent, oper: str | None = None):
async def websearch(self, event: AstrMessageEvent, oper: str | None = None) -> None:
"""网页搜索指令(已废弃)"""
event.set_result(
MessageEventResult().message(
@@ -246,7 +246,7 @@ class Main(star.Star):
return ret
async def ensure_baidu_ai_search_mcp(self, umo: str | None = None):
async def ensure_baidu_ai_search_mcp(self, umo: str | None = None) -> None:
if self.baidu_initialized:
return
cfg = self.context.get_config(umo=umo)
@@ -553,7 +553,7 @@ class Main(star.Star):
self,
event: AstrMessageEvent,
req: ProviderRequest,
):
) -> None:
"""Get the session conversation for the given event."""
cfg = self.context.get_config(umo=event.unified_msg_origin)
prov_settings = cfg.get("provider_settings", {})
+1 -1
View File
@@ -1 +1 @@
__version__ = "4.14.7"
__version__ = "4.15.0"
+3 -3
View File
@@ -127,7 +127,7 @@ def _get_nested_item(obj: dict[str, Any], path: str) -> Any:
@click.group(name="conf")
def conf():
def conf() -> None:
"""配置管理命令
支持的配置项:
@@ -149,7 +149,7 @@ def conf():
@conf.command(name="set")
@click.argument("key")
@click.argument("value")
def set_config(key: str, value: str):
def set_config(key: str, value: str) -> None:
"""设置配置项的值"""
if key not in CONFIG_VALIDATORS:
raise click.ClickException(f"不支持的配置项: {key}")
@@ -178,7 +178,7 @@ def set_config(key: str, value: str):
@conf.command(name="get")
@click.argument("key", required=False)
def get_config(key: str | None = None):
def get_config(key: str | None = None) -> None:
"""获取配置项的值,不提供key则显示所有可配置项"""
config = _load_config()
+8 -8
View File
@@ -15,7 +15,7 @@ from ..utils import (
@click.group()
def plug():
def plug() -> None:
"""插件管理"""
@@ -28,7 +28,7 @@ def _get_data_path() -> Path:
return (base / "data").resolve()
def display_plugins(plugins, title=None, color=None):
def display_plugins(plugins, title=None, color=None) -> None:
if title:
click.echo(click.style(title, fg=color, bold=True))
@@ -45,7 +45,7 @@ def display_plugins(plugins, title=None, color=None):
@plug.command()
@click.argument("name")
def new(name: str):
def new(name: str) -> None:
"""创建新插件"""
base_path = _get_data_path()
plug_path = base_path / "plugins" / name
@@ -100,7 +100,7 @@ def new(name: str):
@plug.command()
@click.option("--all", "-a", is_flag=True, help="列出未安装的插件")
def list(all: bool):
def list(all: bool) -> None:
"""列出插件"""
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
@@ -141,7 +141,7 @@ def list(all: bool):
@plug.command()
@click.argument("name")
@click.option("--proxy", help="代理服务器地址")
def install(name: str, proxy: str | None):
def install(name: str, proxy: str | None) -> None:
"""安装插件"""
base_path = _get_data_path()
plug_path = base_path / "plugins"
@@ -164,7 +164,7 @@ def install(name: str, proxy: str | None):
@plug.command()
@click.argument("name")
def remove(name: str):
def remove(name: str) -> None:
"""卸载插件"""
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
@@ -187,7 +187,7 @@ def remove(name: str):
@plug.command()
@click.argument("name", required=False)
@click.option("--proxy", help="Github代理地址")
def update(name: str, proxy: str | None):
def update(name: str, proxy: str | None) -> None:
"""更新插件"""
base_path = _get_data_path()
plug_path = base_path / "plugins"
@@ -225,7 +225,7 @@ def update(name: str, proxy: str | None):
@plug.command()
@click.argument("query")
def search(query: str):
def search(query: str) -> None:
"""搜索插件"""
base_path = _get_data_path()
plugins = build_plug_list(base_path / "plugins")
+1 -1
View File
@@ -10,7 +10,7 @@ from filelock import FileLock, Timeout
from ..utils import check_astrbot_root, check_dashboard, get_astrbot_root
async def run_astrbot(astrbot_root: Path):
async def run_astrbot(astrbot_root: Path) -> None:
"""运行 AstrBot"""
from astrbot.core import LogBroker, LogManager, db_helper, logger
from astrbot.core.initial_loader import InitialLoader
+1 -1
View File
@@ -19,7 +19,7 @@ class PluginStatus(str, Enum):
NOT_PUBLISHED = "未发布"
def get_git_repo(url: str, target_path: Path, proxy: str | None = None):
def get_git_repo(url: str, target_path: Path, proxy: str | None = None) -> None:
"""从 Git 仓库下载代码并解压到指定路径"""
temp_dir = Path(tempfile.mkdtemp())
try:
+4 -2
View File
@@ -57,7 +57,9 @@ class TruncateByTurnsCompressor:
Truncates the message list by removing older turns.
"""
def __init__(self, truncate_turns: int = 1, compression_threshold: float = 0.82):
def __init__(
self, truncate_turns: int = 1, compression_threshold: float = 0.82
) -> None:
"""Initialize the truncate by turns compressor.
Args:
@@ -152,7 +154,7 @@ class LLMSummaryCompressor:
keep_recent: int = 4,
instruction_text: str | None = None,
compression_threshold: float = 0.82,
):
) -> None:
"""Initialize the LLM summary compressor.
Args:
+1 -1
View File
@@ -13,7 +13,7 @@ class ContextManager:
def __init__(
self,
config: ContextConfig,
):
) -> None:
"""Initialize the context manager.
There are two strategies to handle context limit reached:
+3 -2
View File
@@ -14,8 +14,7 @@ class HandoffTool(FunctionTool, Generic[TContext]):
parameters: dict | None = None,
tool_description: str | None = None,
**kwargs,
):
self.agent = agent
) -> None:
# Avoid passing duplicate `description` to the FunctionTool dataclass.
# Some call sites (e.g. SubAgentOrchestrator) pass `description` via kwargs
@@ -34,6 +33,8 @@ class HandoffTool(FunctionTool, Generic[TContext]):
# Optional provider override for this subagent. When set, the handoff
# execution will use this chat provider id instead of the global/default.
self.provider_id: str | None = None
# Note: Must assign after super().__init__() to prevent parent class from overriding this attribute
self.agent = agent
def default_parameters(self) -> dict:
return {
+4 -4
View File
@@ -9,22 +9,22 @@ from .run_context import ContextWrapper, TContext
class BaseAgentRunHooks(Generic[TContext]):
async def on_agent_begin(self, run_context: ContextWrapper[TContext]): ...
async def on_agent_begin(self, run_context: ContextWrapper[TContext]) -> None: ...
async def on_tool_start(
self,
run_context: ContextWrapper[TContext],
tool: FunctionTool,
tool_args: dict | None,
): ...
) -> None: ...
async def on_tool_end(
self,
run_context: ContextWrapper[TContext],
tool: FunctionTool,
tool_args: dict | None,
tool_result: mcp.types.CallToolResult | None,
): ...
) -> None: ...
async def on_agent_done(
self,
run_context: ContextWrapper[TContext],
llm_response: LLMResponse,
): ...
) -> None: ...
+6 -6
View File
@@ -108,7 +108,7 @@ async def _quick_test_mcp_connection(config: dict) -> tuple[bool, str]:
class MCPClient:
def __init__(self):
def __init__(self) -> None:
# Initialize session and client objects
self.session: mcp.ClientSession | None = None
self.exit_stack = AsyncExitStack()
@@ -126,7 +126,7 @@ class MCPClient:
self._reconnect_lock = asyncio.Lock() # Lock for thread-safe reconnection
self._reconnecting: bool = False # For logging and debugging
async def connect_to_server(self, mcp_server_config: dict, name: str):
async def connect_to_server(self, mcp_server_config: dict, name: str) -> None:
"""Connect to MCP server
If `url` parameter exists:
@@ -144,7 +144,7 @@ class MCPClient:
cfg = _prepare_config(mcp_server_config.copy())
def logging_callback(msg: str):
def logging_callback(msg: str) -> None:
# Handle MCP service error logs
print(f"MCP Server {name} Error: {msg}")
self.server_errlogs.append(msg)
@@ -214,7 +214,7 @@ class MCPClient:
**cfg,
)
def callback(msg: str):
def callback(msg: str) -> None:
# Handle MCP service error logs
self.server_errlogs.append(msg)
@@ -343,7 +343,7 @@ class MCPClient:
return await _call_with_retry()
async def cleanup(self):
async def cleanup(self) -> None:
"""Clean up resources including old exit stacks from reconnections"""
# Close current exit stack
try:
@@ -365,7 +365,7 @@ class MCPTool(FunctionTool, Generic[TContext]):
def __init__(
self, mcp_tool: mcp.Tool, mcp_client: MCPClient, mcp_server_name: str, **kwargs
):
) -> None:
super().__init__(
name=mcp_tool.name,
description=mcp_tool.description or "",
@@ -10,7 +10,7 @@ from astrbot.core import logger
class CozeAPIClient:
def __init__(self, api_key: str, api_base: str = "https://api.coze.cn"):
def __init__(self, api_key: str, api_base: str = "https://api.coze.cn") -> None:
self.api_key = api_key
self.api_base = api_base
self.session = None
@@ -277,7 +277,7 @@ class CozeAPIClient:
logger.error(f"获取Coze消息列表失败: {e!s}")
raise Exception(f"获取Coze消息列表失败: {e!s}")
async def close(self):
async def close(self) -> None:
"""关闭会话"""
if self.session:
await self.session.close()
@@ -288,7 +288,7 @@ if __name__ == "__main__":
import asyncio
import os
async def test_coze_api_client():
async def test_coze_api_client() -> None:
api_key = os.getenv("COZE_API_KEY", "")
bot_id = os.getenv("COZE_BOT_ID", "")
client = CozeAPIClient(api_key=api_key)
@@ -67,7 +67,7 @@ class DashscopeAgentRunner(BaseAgentRunner[TContext]):
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
def has_rag_options(self):
def has_rag_options(self) -> bool:
"""判断是否有 RAG 选项
Returns:
@@ -31,7 +31,7 @@ async def _stream_sse(resp: ClientResponse) -> AsyncGenerator[dict, None]:
class DifyAPIClient:
def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1"):
def __init__(self, api_key: str, api_base: str = "https://api.dify.ai/v1") -> None:
self.api_key = api_key
self.api_base = api_base
self.session = ClientSession(trust_env=True)
@@ -155,7 +155,7 @@ class DifyAPIClient:
raise Exception(f"Dify 文件上传失败:{resp.status}. {text}")
return await resp.json() # {"id": "xxx", ...}
async def close(self):
async def close(self) -> None:
await self.session.close()
async def get_chat_convs(self, user: str, limit: int = 20):
+10 -10
View File
@@ -64,7 +64,7 @@ class FunctionTool(ToolSchema, Generic[TContext]):
with a task identifier while the real work continues asynchronously.
"""
def __repr__(self):
def __repr__(self) -> str:
return f"FuncTool(name={self.name}, parameters={self.parameters}, description={self.description})"
async def call(self, context: ContextWrapper[TContext], **kwargs) -> ToolExecResult:
@@ -88,7 +88,7 @@ class ToolSet:
"""Check if the tool set is empty."""
return len(self.tools) == 0
def add_tool(self, tool: FunctionTool):
def add_tool(self, tool: FunctionTool) -> None:
"""Add a tool to the set."""
# 检查是否已存在同名工具
for i, existing_tool in enumerate(self.tools):
@@ -97,7 +97,7 @@ class ToolSet:
return
self.tools.append(tool)
def remove_tool(self, name: str):
def remove_tool(self, name: str) -> None:
"""Remove a tool by its name."""
self.tools = [tool for tool in self.tools if tool.name != name]
@@ -156,7 +156,7 @@ class ToolSet:
func_args: list,
desc: str,
handler: Callable[..., Awaitable[Any]],
):
) -> None:
"""Add a function tool to the set."""
params = {
"type": "object", # hard-coded here
@@ -176,7 +176,7 @@ class ToolSet:
self.add_tool(_func)
@deprecated(reason="Use remove_tool() instead", version="4.0.0")
def remove_func(self, name: str):
def remove_func(self, name: str) -> None:
"""Remove a function tool by its name."""
self.remove_tool(name)
@@ -325,22 +325,22 @@ class ToolSet:
"""获取所有工具的名称列表"""
return [tool.name for tool in self.tools]
def merge(self, other: "ToolSet"):
def merge(self, other: "ToolSet") -> None:
"""Merge another ToolSet into this one."""
for tool in other.tools:
self.add_tool(tool)
def __len__(self):
def __len__(self) -> int:
return len(self.tools)
def __bool__(self):
def __bool__(self) -> bool:
return len(self.tools) > 0
def __iter__(self):
return iter(self.tools)
def __repr__(self):
def __repr__(self) -> str:
return f"ToolSet(tools={self.tools})"
def __str__(self):
def __str__(self) -> str:
return f"ToolSet(tools={self.tools})"
+3 -3
View File
@@ -12,7 +12,7 @@ from astrbot.core.star.star_handler import EventType
class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
async def on_agent_done(self, run_context, llm_response):
async def on_agent_done(self, run_context, llm_response) -> None:
# 执行事件钩子
if llm_response and llm_response.reasoning_content:
# we will use this in result_decorate stage to inject reasoning content to chain
@@ -31,7 +31,7 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
run_context: ContextWrapper[AstrAgentContext],
tool: FunctionTool[Any],
tool_args: dict | None,
):
) -> None:
await call_event_hook(
run_context.context.event,
EventType.OnUsingLLMToolEvent,
@@ -45,7 +45,7 @@ class MainAgentHooks(BaseAgentRunHooks[AstrAgentContext]):
tool: FunctionTool[Any],
tool_args: dict | None,
tool_result: CallToolResult | None,
):
) -> None:
run_context.context.event.clear_result()
await call_event_hook(
run_context.context.event,
+3 -3
View File
@@ -295,7 +295,7 @@ async def _run_agent_feeder(
max_step: int,
show_tool_use: bool,
show_reasoning: bool,
):
) -> None:
"""运行 Agent 并将文本输出分句放入队列"""
buffer = ""
try:
@@ -352,7 +352,7 @@ async def _safe_tts_stream_wrapper(
tts_provider: TTSProvider,
text_queue: asyncio.Queue[str | None],
audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]",
):
) -> None:
"""包装原生流式 TTS 确保异常处理和队列关闭"""
try:
await tts_provider.get_audio_stream(text_queue, audio_queue)
@@ -366,7 +366,7 @@ async def _simulated_stream_tts(
tts_provider: TTSProvider,
text_queue: asyncio.Queue[str | None],
audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]",
):
) -> None:
"""模拟流式 TTS 分句生成音频"""
try:
while True:
+2 -2
View File
@@ -57,7 +57,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
elif tool.is_background_task:
task_id = uuid.uuid4().hex
async def _run_in_background():
async def _run_in_background() -> None:
try:
await cls._execute_background(
tool=tool,
@@ -153,7 +153,7 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]):
run_context: ContextWrapper[AstrAgentContext],
task_id: str,
**tool_args,
):
) -> None:
from astrbot.core.astr_main_agent import (
MainAgentBuildConfig,
_get_session_conv,
+27 -30
View File
@@ -326,6 +326,24 @@ async def _ensure_persona_and_skills(
)
tmgr = plugin_context.get_llm_tool_manager()
# inject toolset in the persona
if (persona and persona.get("tools") is None) or not persona:
persona_toolset = tmgr.get_full_tool_set()
for tool in list(persona_toolset):
if not tool.active:
persona_toolset.remove_tool(tool.name)
else:
persona_toolset = ToolSet()
if persona["tools"]:
for tool_name in persona["tools"]:
tool = tmgr.get_func(tool_name)
if tool and tool.active:
persona_toolset.add_tool(tool)
if not req.func_tool:
req.func_tool = persona_toolset
else:
req.func_tool.merge(persona_toolset)
# sub agents integration
orch_cfg = plugin_context.get_config().get("subagent_orchestrator", {})
so = plugin_context.subagent_orchestrator
@@ -371,22 +389,19 @@ async def _ensure_persona_and_skills(
assigned_tools.add(name)
if req.func_tool is None:
toolset = ToolSet()
else:
toolset = req.func_tool
req.func_tool = ToolSet()
# add subagent handoff tools
for tool in so.handoffs:
toolset.add_tool(tool)
req.func_tool.add_tool(tool)
# check duplicates
if remove_dup:
names = toolset.names()
handoff_names = {tool.name for tool in so.handoffs}
for tool_name in assigned_tools:
if tool_name in names:
toolset.remove_tool(tool_name)
req.func_tool = toolset
if tool_name in handoff_names:
continue
req.func_tool.remove_tool(tool_name)
router_prompt = (
plugin_context.get_config()
@@ -395,32 +410,14 @@ async def _ensure_persona_and_skills(
).strip()
if router_prompt:
req.system_prompt += f"\n{router_prompt}\n"
return
# inject toolset in the persona
if (persona and persona.get("tools") is None) or not persona:
toolset = tmgr.get_full_tool_set()
for tool in list(toolset):
if not tool.active:
toolset.remove_tool(tool.name)
else:
toolset = ToolSet()
if persona["tools"]:
for tool_name in persona["tools"]:
tool = tmgr.get_func(tool_name)
if tool and tool.active:
toolset.add_tool(tool)
if not req.func_tool:
req.func_tool = toolset
else:
req.func_tool.merge(toolset)
try:
event.trace.record(
"sel_persona", persona_id=persona_id, persona_toolset=toolset.names()
"sel_persona",
persona_id=persona_id,
persona_toolset=persona_toolset.names(),
)
except Exception:
pass
logger.debug("Tool set for persona %s: %s", persona_id, toolset.names())
async def _request_img_caption(
+2 -2
View File
@@ -36,7 +36,7 @@ class AstrBotConfigManager:
default_config: AstrBotConfig,
ucr: UmopConfigRouter,
sp: SharedPreferences,
):
) -> None:
self.sp = sp
self.ucr = ucr
self.confs: dict[str, AstrBotConfig] = {}
@@ -56,7 +56,7 @@ class AstrBotConfigManager:
)
return self.abconf_data
def _load_all_configs(self):
def _load_all_configs(self) -> None:
"""Load all configurations from the shared preferences."""
abconf_data = self._get_abconf_data()
self.abconf_data = abconf_data
+1 -1
View File
@@ -59,7 +59,7 @@ class AstrBotExporter:
main_db: BaseDatabase,
kb_manager: "KnowledgeBaseManager | None" = None,
config_path: str = CMD_CONFIG_FILE_PATH,
):
) -> None:
self.main_db = main_db
self.kb_manager = kb_manager
self.config_path = config_path
+2 -2
View File
@@ -110,7 +110,7 @@ class ImportPreCheckResult:
class ImportResult:
"""导入结果"""
def __init__(self):
def __init__(self) -> None:
self.success = True
self.imported_tables: dict[str, int] = {}
self.imported_files: dict[str, int] = {}
@@ -161,7 +161,7 @@ class AstrBotImporter:
kb_manager: "KnowledgeBaseManager | None" = None,
config_path: str = CMD_CONFIG_FILE_PATH,
kb_root_dir: str = KB_PATH,
):
) -> None:
self.main_db = main_db
self.kb_manager = kb_manager
self.config_path = config_path
+1 -1
View File
@@ -22,7 +22,7 @@ class ComputerBooter:
"""
...
async def download_file(self, remote_path: str, local_path: str):
async def download_file(self, remote_path: str, local_path: str) -> None:
"""Download file from the computer."""
...
+1 -1
View File
@@ -225,7 +225,7 @@ class LocalBooter(ComputerBooter):
"LocalBooter does not support upload_file operation. Use shell instead."
)
async def download_file(self, remote_path: str, local_path: str):
async def download_file(self, remote_path: str, local_path: str) -> None:
raise NotImplementedError(
"LocalBooter does not support download_file operation. Use shell instead."
)
+1 -1
View File
@@ -100,7 +100,7 @@ class FileUploadTool(FunctionTool):
self,
context: ContextWrapper[AstrAgentContext],
local_path: str,
):
) -> str | None:
sb = await get_booter(
context.context.context,
context.context.event.unified_msg_origin,
+5 -5
View File
@@ -33,7 +33,7 @@ class AstrBotConfig(dict):
config_path: str = ASTRBOT_CONFIG_PATH,
default_config: dict = DEFAULT_CONFIG,
schema: dict | None = None,
):
) -> None:
super().__init__()
# 调用父类的 __setattr__ 方法,防止保存配置时将此属性写入配置文件
@@ -66,7 +66,7 @@ class AstrBotConfig(dict):
"""将 Schema 转换成 Config"""
conf = {}
def _parse_schema(schema: dict, conf: dict):
def _parse_schema(schema: dict, conf: dict) -> None:
for k, v in schema.items():
if v["type"] not in DEFAULT_VALUE_MAP:
raise TypeError(
@@ -148,7 +148,7 @@ class AstrBotConfig(dict):
return has_new
def save_config(self, replace_config: dict | None = None):
def save_config(self, replace_config: dict | None = None) -> None:
"""将配置写入文件
如果传入 replace_config则将配置替换为 replace_config
@@ -164,14 +164,14 @@ class AstrBotConfig(dict):
except KeyError:
return None
def __delattr__(self, key):
def __delattr__(self, key) -> None:
try:
del self[key]
self.save_config()
except KeyError:
raise AttributeError(f"没有找到 Key: '{key}'")
def __setattr__(self, key, value):
def __setattr__(self, key, value) -> None:
self[key] = value
def check_exist(self) -> bool:
+18 -5
View File
@@ -5,7 +5,7 @@ from typing import Any, TypedDict
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
VERSION = "4.14.7"
VERSION = "4.15.0"
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
WEBHOOK_SUPPORTED_PLATFORMS = [
@@ -129,8 +129,9 @@ DEFAULT_CONFIG = {
},
# SubAgent orchestrator mode:
# - main_enable = False: disabled; main LLM mounts tools normally (persona selection).
# - main_enable = True: enabled; main LLM will include handoff tools and can optionally
# remove tools that are duplicated on subagents via remove_main_duplicate_tools.
# - main_enable = True: enabled; main LLM keeps its own tools and includes handoff
# tools (transfer_to_*). remove_main_duplicate_tools can remove tools that are
# duplicated on subagents from the main LLM toolset.
"subagent_orchestrator": {
"main_enable": False,
"remove_main_duplicate_tools": False,
@@ -319,9 +320,11 @@ CONFIG_METADATA_2 = {
"id": "wecom_ai_bot",
"type": "wecom_ai_bot",
"enable": True,
"wecomaibot_init_respond_text": "💭 思考中...",
"wecomaibot_init_respond_text": "",
"wecomaibot_friend_message_welcome_text": "",
"wecom_ai_bot_name": "",
"msg_push_webhook_url": "",
"only_use_webhook_url_to_send": False,
"token": "",
"encoding_aes_key": "",
"unified_webhook_mode": True,
@@ -687,13 +690,23 @@ CONFIG_METADATA_2 = {
"wecomaibot_init_respond_text": {
"description": "企业微信智能机器人初始响应文本",
"type": "string",
"hint": "当机器人收到消息时,首先回复的文本内容。留空则使用默认值",
"hint": "当机器人收到消息时,首先回复的文本内容。留空则不设置",
},
"wecomaibot_friend_message_welcome_text": {
"description": "企业微信智能机器人私聊欢迎语",
"type": "string",
"hint": "当用户当天进入智能机器人单聊会话,回复欢迎语,留空则不回复。",
},
"msg_push_webhook_url": {
"description": "企业微信消息推送 Webhook URL",
"type": "string",
"hint": "用于 send_by_session 主动消息推送。格式示例: https://qyapi.weixin.qq.com/cgi-bin/webhook/send?key=xxx",
},
"only_use_webhook_url_to_send": {
"description": "仅使用 Webhook 发送消息",
"type": "bool",
"hint": "启用后,企业微信智能机器人的所有回复都改为通过消息推送 Webhook 发送。消息推送 Webhook 支持更多的消息类型(如图片、文件等)。",
},
"lark_bot_name": {
"description": "飞书机器人的名字",
"type": "string",
+6 -4
View File
@@ -16,7 +16,7 @@ from astrbot.core.db.po import Conversation, ConversationV2
class ConversationManager:
"""负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""
def __init__(self, db_helper: BaseDatabase):
def __init__(self, db_helper: BaseDatabase) -> None:
self.session_conversations: dict[str, str] = {}
self.db = db_helper
self.save_interval = 60 # 每 60 秒保存一次
@@ -106,7 +106,9 @@ class ConversationManager:
await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id)
return conv.conversation_id
async def switch_conversation(self, unified_msg_origin: str, conversation_id: str):
async def switch_conversation(
self, unified_msg_origin: str, conversation_id: str
) -> None:
"""切换会话的对话
Args:
@@ -121,7 +123,7 @@ class ConversationManager:
self,
unified_msg_origin: str,
conversation_id: str | None = None,
):
) -> None:
"""删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话
Args:
@@ -138,7 +140,7 @@ class ConversationManager:
self.session_conversations.pop(unified_msg_origin, None)
await sp.session_remove(unified_msg_origin, "sel_conv_id")
async def delete_conversations_by_user_id(self, unified_msg_origin: str):
async def delete_conversations_by_user_id(self, unified_msg_origin: str) -> None:
"""删除会话的所有对话
Args:
+3 -3
View File
@@ -24,7 +24,7 @@ class CronMessageEvent(AstrMessageEvent):
sender_name: str = "Scheduler",
extras: dict[str, Any] | None = None,
message_type: MessageType = MessageType.FRIEND_MESSAGE,
):
) -> None:
platform_meta = PlatformMetadata(
name="cron",
description="CronJob",
@@ -53,13 +53,13 @@ class CronMessageEvent(AstrMessageEvent):
if extras:
self._extras.update(extras)
async def send(self, message: MessageChain):
async def send(self, message: MessageChain) -> None:
if message is None:
return
await self.context_obj.send_message(self.session, message)
await super().send(message)
async def send_streaming(self, generator, use_fallback: bool = False):
async def send_streaming(self, generator, use_fallback: bool = False) -> None:
async for chain in generator:
await self.send(chain)
+10 -10
View File
@@ -25,14 +25,14 @@ if TYPE_CHECKING:
class CronJobManager:
"""Central scheduler for BasicCronJob and ActiveAgentCronJob."""
def __init__(self, db: BaseDatabase):
def __init__(self, db: BaseDatabase) -> None:
self.db = db
self.scheduler = AsyncIOScheduler()
self._basic_handlers: dict[str, Callable[..., Any]] = {}
self._lock = asyncio.Lock()
self._started = False
async def start(self, ctx: "Context"):
async def start(self, ctx: "Context") -> None:
self.ctx: Context = ctx # star context
async with self._lock:
if self._started:
@@ -41,14 +41,14 @@ class CronJobManager:
self._started = True
await self.sync_from_db()
async def shutdown(self):
async def shutdown(self) -> None:
async with self._lock:
if not self._started:
return
self.scheduler.shutdown(wait=False)
self._started = False
async def sync_from_db(self):
async def sync_from_db(self) -> None:
jobs = await self.db.list_cron_jobs()
for job in jobs:
if not job.enabled or not job.persistent:
@@ -136,11 +136,11 @@ class CronJobManager:
async def list_jobs(self, job_type: str | None = None) -> list[CronJob]:
return await self.db.list_cron_jobs(job_type)
def _remove_scheduled(self, job_id: str):
def _remove_scheduled(self, job_id: str) -> None:
if self.scheduler.get_job(job_id):
self.scheduler.remove_job(job_id)
def _schedule_job(self, job: CronJob):
def _schedule_job(self, job: CronJob) -> None:
if not self._started:
self.scheduler.start()
self._started = True
@@ -188,7 +188,7 @@ class CronJobManager:
aps_job = self.scheduler.get_job(job_id)
return aps_job.next_run_time if aps_job else None
async def _run_job(self, job_id: str):
async def _run_job(self, job_id: str) -> None:
job = await self.db.get_cron_job(job_id)
if not job or not job.enabled:
return
@@ -222,7 +222,7 @@ class CronJobManager:
# one-shot: remove after execution regardless of success
await self.delete_job(job_id)
async def _run_basic_job(self, job: CronJob):
async def _run_basic_job(self, job: CronJob) -> None:
handler = self._basic_handlers.get(job.job_id)
if not handler:
raise RuntimeError(f"Basic cron job handler not found for {job.job_id}")
@@ -231,7 +231,7 @@ class CronJobManager:
if asyncio.iscoroutine(result):
await result
async def _run_active_agent_job(self, job: CronJob, start_time: datetime):
async def _run_active_agent_job(self, job: CronJob, start_time: datetime) -> None:
payload = job.payload or {}
session_str = payload.get("session")
if not session_str:
@@ -266,7 +266,7 @@ class CronJobManager:
message: str,
session_str: str,
extras: dict,
):
) -> None:
"""Woke the main agent to handle the cron job message."""
from astrbot.core.astr_main_agent import (
MainAgentBuildConfig,
+1 -1
View File
@@ -43,7 +43,7 @@ class BaseDatabase(abc.ABC):
expire_on_commit=False,
)
async def initialize(self):
async def initialize(self) -> None:
"""初始化数据库连接"""
@asynccontextmanager
+5 -5
View File
@@ -43,7 +43,7 @@ def get_platform_type(
async def migration_conversation_table(
db_helper: BaseDatabase,
platform_id_map: dict[str, dict[str, str]],
):
) -> None:
db_helper_v3 = SQLiteV3DatabaseV3(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db"),
)
@@ -101,7 +101,7 @@ async def migration_conversation_table(
async def migration_platform_table(
db_helper: BaseDatabase,
platform_id_map: dict[str, dict[str, str]],
):
) -> None:
db_helper_v3 = SQLiteV3DatabaseV3(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db"),
)
@@ -180,7 +180,7 @@ async def migration_platform_table(
async def migration_webchat_data(
db_helper: BaseDatabase,
platform_id_map: dict[str, dict[str, str]],
):
) -> None:
"""迁移 WebChat 的历史记录到新的 PlatformMessageHistory 表中"""
db_helper_v3 = SQLiteV3DatabaseV3(
db_path=DB_PATH.replace("data_v4.db", "data_v3.db"),
@@ -236,7 +236,7 @@ async def migration_webchat_data(
async def migration_persona_data(
db_helper: BaseDatabase,
astrbot_config: AstrBotConfig,
):
) -> None:
"""迁移 Persona 数据到新的表中。
旧的 Persona 数据存储在 preference 新的 Persona 数据存储在 persona 表中
"""
@@ -279,7 +279,7 @@ async def migration_persona_data(
async def migration_preferences(
db_helper: BaseDatabase,
platform_id_map: dict[str, dict[str, str]],
):
) -> None:
# 1. global scope migration
keys = [
"inactivated_llm_tools",
+1 -1
View File
@@ -3,7 +3,7 @@ from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
from astrbot.core.umop_config_router import UmopConfigRouter
async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter):
async def migrate_45_to_46(acm: AstrBotConfigManager, ucr: UmopConfigRouter) -> None:
abconf_data = acm.abconf_data
if not isinstance(abconf_data, dict):
@@ -12,7 +12,7 @@ from astrbot.api import logger, sp
from astrbot.core.db import BaseDatabase
async def migrate_token_usage(db_helper: BaseDatabase):
async def migrate_token_usage(db_helper: BaseDatabase) -> None:
"""Add token_usage column to conversations table.
This migration adds a new column to track token consumption in conversations.
@@ -17,7 +17,7 @@ from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import ConversationV2, PlatformMessageHistory, PlatformSession
async def migrate_webchat_session(db_helper: BaseDatabase):
async def migrate_webchat_session(db_helper: BaseDatabase) -> None:
"""Create PlatformSession records from platform_message_history.
This migration extracts all unique user_ids from platform_message_history
@@ -8,7 +8,7 @@ _VT = TypeVar("_VT")
class SharedPreferences:
def __init__(self, path=None):
def __init__(self, path=None) -> None:
if path is None:
path = os.path.join(get_astrbot_data_path(), "shared_preferences.json")
self.path = path
@@ -23,7 +23,7 @@ class SharedPreferences:
os.remove(self.path)
return {}
def _save_preferences(self):
def _save_preferences(self) -> None:
with open(self.path, "w") as f:
json.dump(self._data, f, indent=4, ensure_ascii=False)
f.flush()
@@ -31,16 +31,16 @@ class SharedPreferences:
def get(self, key, default: _VT = None) -> _VT:
return self._data.get(key, default)
def put(self, key, value):
def put(self, key, value) -> None:
self._data[key] = value
self._save_preferences()
def remove(self, key):
def remove(self, key) -> None:
if key in self._data:
del self._data[key]
self._save_preferences()
def clear(self):
def clear(self) -> None:
self._data.clear()
self._save_preferences()
+10 -8
View File
@@ -127,7 +127,7 @@ class SQLiteDatabase:
conn.text_factory = str
return conn
def _exec_sql(self, sql: str, params: tuple | None = None):
def _exec_sql(self, sql: str, params: tuple | None = None) -> None:
conn = self.conn
try:
c = self.conn.cursor()
@@ -144,7 +144,7 @@ class SQLiteDatabase:
conn.commit()
def insert_platform_metrics(self, metrics: dict):
def insert_platform_metrics(self, metrics: dict) -> None:
for k, v in metrics.items():
self._exec_sql(
"""
@@ -153,7 +153,7 @@ class SQLiteDatabase:
(k, v, int(time.time())),
)
def insert_llm_metrics(self, metrics: dict):
def insert_llm_metrics(self, metrics: dict) -> None:
for k, v in metrics.items():
self._exec_sql(
"""
@@ -249,7 +249,7 @@ class SQLiteDatabase:
return Conversation(*res)
def new_conversation(self, user_id: str, cid: str):
def new_conversation(self, user_id: str, cid: str) -> None:
history = "[]"
updated_at = int(time.time())
created_at = updated_at
@@ -287,7 +287,7 @@ class SQLiteDatabase:
)
return conversations
def update_conversation(self, user_id: str, cid: str, history: str):
def update_conversation(self, user_id: str, cid: str, history: str) -> None:
"""更新对话,并且同时更新时间"""
updated_at = int(time.time())
self._exec_sql(
@@ -297,7 +297,7 @@ class SQLiteDatabase:
(history, updated_at, user_id, cid),
)
def update_conversation_title(self, user_id: str, cid: str, title: str):
def update_conversation_title(self, user_id: str, cid: str, title: str) -> None:
self._exec_sql(
"""
UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?
@@ -305,7 +305,9 @@ class SQLiteDatabase:
(title, user_id, cid),
)
def update_conversation_persona_id(self, user_id: str, cid: str, persona_id: str):
def update_conversation_persona_id(
self, user_id: str, cid: str, persona_id: str
) -> None:
self._exec_sql(
"""
UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?
@@ -313,7 +315,7 @@ class SQLiteDatabase:
(persona_id, user_id, cid),
)
def delete_conversation(self, user_id: str, cid: str):
def delete_conversation(self, user_id: str, cid: str) -> None:
self._exec_sql(
"""
DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?
+8 -8
View File
@@ -305,7 +305,7 @@ class SQLiteDatabase(BaseDatabase):
await session.execute(query)
return await self.get_conversation_by_id(cid)
async def delete_conversation(self, cid):
async def delete_conversation(self, cid) -> None:
async with self.get_db() as session:
session: AsyncSession
async with session.begin():
@@ -461,7 +461,7 @@ class SQLiteDatabase(BaseDatabase):
platform_id,
user_id,
offset_sec=86400,
):
) -> None:
"""Delete platform message history records newer than the specified offset."""
async with self.get_db() as session:
session: AsyncSession
@@ -645,7 +645,7 @@ class SQLiteDatabase(BaseDatabase):
await session.execute(query)
return await self.get_persona_by_id(persona_id)
async def delete_persona(self, persona_id):
async def delete_persona(self, persona_id) -> None:
"""Delete a persona by its ID."""
async with self.get_db() as session:
session: AsyncSession
@@ -903,7 +903,7 @@ class SQLiteDatabase(BaseDatabase):
result = await session.execute(query)
return result.scalars().all()
async def remove_preference(self, scope, scope_id, key):
async def remove_preference(self, scope, scope_id, key) -> None:
"""Remove a preference by scope ID and key."""
async with self.get_db() as session:
session: AsyncSession
@@ -917,7 +917,7 @@ class SQLiteDatabase(BaseDatabase):
)
await session.commit()
async def clear_preferences(self, scope, scope_id):
async def clear_preferences(self, scope, scope_id) -> None:
"""Clear all preferences for a specific scope ID."""
async with self.get_db() as session:
session: AsyncSession
@@ -1195,7 +1195,7 @@ class SQLiteDatabase(BaseDatabase):
result = None
def runner():
def runner() -> None:
nonlocal result
result = asyncio.run(_inner())
@@ -1218,7 +1218,7 @@ class SQLiteDatabase(BaseDatabase):
result = None
def runner():
def runner() -> None:
nonlocal result
result = asyncio.run(_inner())
@@ -1253,7 +1253,7 @@ class SQLiteDatabase(BaseDatabase):
result = None
def runner():
def runner() -> None:
nonlocal result
result = asyncio.run(_inner())
+1 -1
View File
@@ -9,7 +9,7 @@ class Result:
class BaseVecDB:
async def initialize(self):
async def initialize(self) -> None:
"""初始化向量数据库"""
@abc.abstractmethod
@@ -33,7 +33,7 @@ class Document(BaseDocModel, table=True):
class DocumentStorage:
def __init__(self, db_path: str):
def __init__(self, db_path: str) -> None:
self.db_path = db_path
self.DATABASE_URL = f"sqlite+aiosqlite:///{db_path}"
self.engine: AsyncEngine | None = None
@@ -43,7 +43,7 @@ class DocumentStorage:
"sqlite_init.sql",
)
async def initialize(self):
async def initialize(self) -> None:
"""Initialize the SQLite database and create the documents table if it doesn't exist."""
await self.connect()
async with self.engine.begin() as conn: # type: ignore
@@ -80,7 +80,7 @@ class DocumentStorage:
await conn.commit()
async def connect(self):
async def connect(self) -> None:
"""Connect to the SQLite database."""
if self.engine is None:
self.engine = create_async_engine(
@@ -211,7 +211,7 @@ class DocumentStorage:
await session.flush() # Flush to get all IDs
return [doc.id for doc in documents] # type: ignore
async def delete_document_by_doc_id(self, doc_id: str):
async def delete_document_by_doc_id(self, doc_id: str) -> None:
"""Delete a document by its doc_id.
Args:
@@ -249,7 +249,7 @@ class DocumentStorage:
return self._document_to_dict(document)
return None
async def update_document_by_doc_id(self, doc_id: str, new_text: str):
async def update_document_by_doc_id(self, doc_id: str, new_text: str) -> None:
"""Update a document by its doc_id.
Args:
@@ -269,7 +269,7 @@ class DocumentStorage:
document.updated_at = datetime.now()
session.add(document)
async def delete_documents(self, metadata_filters: dict):
async def delete_documents(self, metadata_filters: dict) -> None:
"""Delete documents by their metadata filters.
Args:
@@ -384,7 +384,7 @@ class DocumentStorage:
"updated_at": row[5],
}
async def close(self):
async def close(self) -> None:
"""Close the connection to the SQLite database."""
if self.engine:
await self.engine.dispose()
@@ -10,7 +10,7 @@ import numpy as np
class EmbeddingStorage:
def __init__(self, dimension: int, path: str | None = None):
def __init__(self, dimension: int, path: str | None = None) -> None:
self.dimension = dimension
self.path = path
self.index = None
@@ -20,7 +20,7 @@ class EmbeddingStorage:
base_index = faiss.IndexFlatL2(dimension)
self.index = faiss.IndexIDMap(base_index)
async def insert(self, vector: np.ndarray, id: int):
async def insert(self, vector: np.ndarray, id: int) -> None:
"""插入向量
Args:
@@ -38,7 +38,7 @@ class EmbeddingStorage:
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
await self.save_index()
async def insert_batch(self, vectors: np.ndarray, ids: list[int]):
async def insert_batch(self, vectors: np.ndarray, ids: list[int]) -> None:
"""批量插入向量
Args:
@@ -71,7 +71,7 @@ class EmbeddingStorage:
distances, indices = self.index.search(vector, k)
return distances, indices
async def delete(self, ids: list[int]):
async def delete(self, ids: list[int]) -> None:
"""删除向量
Args:
@@ -83,7 +83,7 @@ class EmbeddingStorage:
self.index.remove_ids(id_array)
await self.save_index()
async def save_index(self):
async def save_index(self) -> None:
"""保存索引
Args:
+5 -5
View File
@@ -20,7 +20,7 @@ class FaissVecDB(BaseVecDB):
index_store_path: str,
embedding_provider: EmbeddingProvider,
rerank_provider: RerankProvider | None = None,
):
) -> None:
self.doc_store_path = doc_store_path
self.index_store_path = index_store_path
self.embedding_provider = embedding_provider
@@ -32,7 +32,7 @@ class FaissVecDB(BaseVecDB):
self.embedding_provider = embedding_provider
self.rerank_provider = rerank_provider
async def initialize(self):
async def initialize(self) -> None:
await self.document_storage.initialize()
async def insert(
@@ -165,7 +165,7 @@ class FaissVecDB(BaseVecDB):
return top_k_results
async def delete(self, doc_id: str):
async def delete(self, doc_id: str) -> None:
"""删除一条文档块(chunk"""
# 获得对应的 int id
result = await self.document_storage.get_document_by_doc_id(doc_id)
@@ -177,7 +177,7 @@ class FaissVecDB(BaseVecDB):
await self.document_storage.delete_document_by_doc_id(doc_id)
await self.embedding_storage.delete([int_id])
async def close(self):
async def close(self) -> None:
await self.document_storage.close()
async def count_documents(self, metadata_filter: dict | None = None) -> int:
@@ -192,7 +192,7 @@ class FaissVecDB(BaseVecDB):
)
return count
async def delete_documents(self, metadata_filters: dict):
async def delete_documents(self, metadata_filters: dict) -> None:
"""根据元数据过滤器删除文档"""
docs = await self.document_storage.get_documents(
metadata_filters=metadata_filters,
+3 -3
View File
@@ -28,13 +28,13 @@ class EventBus:
event_queue: Queue,
pipeline_scheduler_mapping: dict[str, PipelineScheduler],
astrbot_config_mgr: AstrBotConfigManager,
):
) -> None:
self.event_queue = event_queue # 事件队列
# abconf uuid -> scheduler
self.pipeline_scheduler_mapping = pipeline_scheduler_mapping
self.astrbot_config_mgr = astrbot_config_mgr
async def dispatch(self):
async def dispatch(self) -> None:
while True:
event: AstrMessageEvent = await self.event_queue.get()
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
@@ -47,7 +47,7 @@ class EventBus:
continue
asyncio.create_task(scheduler.execute(event))
def _print_event(self, event: AstrMessageEvent, conf_name: str):
def _print_event(self, event: AstrMessageEvent, conf_name: str) -> None:
"""用于记录事件信息
Args:
+2 -2
View File
@@ -9,12 +9,12 @@ from urllib.parse import unquote, urlparse
class FileTokenService:
"""维护一个简单的基于令牌的文件下载服务,支持超时和懒清除。"""
def __init__(self, default_timeout: float = 300):
def __init__(self, default_timeout: float = 300) -> None:
self.lock = asyncio.Lock()
self.staged_files = {} # token: (file_path, expire_time)
self.default_timeout = default_timeout
async def _cleanup_expired_tokens(self):
async def _cleanup_expired_tokens(self) -> None:
"""清理过期的令牌"""
now = time.time()
expired_tokens = [
+2 -2
View File
@@ -17,13 +17,13 @@ from astrbot.dashboard.server import AstrBotDashboard
class InitialLoader:
"""AstrBot 启动器,负责初始化和启动核心组件和仪表板服务器。"""
def __init__(self, db: BaseDatabase, log_broker: LogBroker):
def __init__(self, db: BaseDatabase, log_broker: LogBroker) -> None:
self.db = db
self.logger = logger
self.log_broker = log_broker
self.webui_dir: str | None = None
async def start(self):
async def start(self) -> None:
core_lifecycle = AstrBotCoreLifecycle(self.log_broker, self.db)
try:
@@ -12,7 +12,7 @@ class FixedSizeChunker(BaseChunker):
按照固定的字符数分块,并支持块之间的重叠
"""
def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50):
def __init__(self, chunk_size: int = 512, chunk_overlap: int = 50) -> None:
"""初始化分块器
Args:
@@ -11,7 +11,7 @@ class RecursiveCharacterChunker(BaseChunker):
length_function: Callable[[str], int] = len,
is_separator_regex: bool = False,
separators: list[str] | None = None,
):
) -> None:
"""初始化递归字符文本分割器
Args:
+1 -1
View File
@@ -253,7 +253,7 @@ class KBSQLiteDatabase:
"knowledge_base": row[1],
}
async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB):
async def delete_document_by_id(self, doc_id: str, vec_db: FaissVecDB) -> None:
"""删除单个文档及其相关数据"""
# 在知识库表中删除
async with self.get_db() as session, session.begin():
+9 -9
View File
@@ -31,7 +31,7 @@ from .prompts import TEXT_REPAIR_SYSTEM_PROMPT
class RateLimiter:
"""一个简单的速率限制器"""
def __init__(self, max_rpm: int):
def __init__(self, max_rpm: int) -> None:
self.max_per_minute = max_rpm
self.interval = 60.0 / max_rpm if max_rpm > 0 else 0
self.last_call_time = 0
@@ -116,7 +116,7 @@ class KBHelper:
provider_manager: ProviderManager,
kb_root_dir: str,
chunker: BaseChunker,
):
) -> None:
self.kb_db = kb_db
self.kb = kb
self.prov_mgr = provider_manager
@@ -130,7 +130,7 @@ class KBHelper:
self.kb_medias_dir.mkdir(parents=True, exist_ok=True)
self.kb_files_dir.mkdir(parents=True, exist_ok=True)
async def initialize(self):
async def initialize(self) -> None:
await self._ensure_vec_db()
async def get_ep(self) -> EmbeddingProvider:
@@ -174,7 +174,7 @@ class KBHelper:
self.vec_db = vec_db
return vec_db
async def delete_vec_db(self):
async def delete_vec_db(self) -> None:
"""删除知识库的向量数据库和所有相关文件"""
import shutil
@@ -182,7 +182,7 @@ class KBHelper:
if self.kb_dir.exists():
shutil.rmtree(self.kb_dir)
async def terminate(self):
async def terminate(self) -> None:
if self.vec_db:
await self.vec_db.close()
@@ -293,7 +293,7 @@ class KBHelper:
await progress_callback("chunking", 100, 100)
# 阶段3: 生成向量(带进度回调)
async def embedding_progress_callback(current, total):
async def embedding_progress_callback(current, total) -> None:
if progress_callback:
await progress_callback("embedding", current, total)
@@ -360,7 +360,7 @@ class KBHelper:
doc = await self.kb_db.get_document_by_id(doc_id)
return doc
async def delete_document(self, doc_id: str):
async def delete_document(self, doc_id: str) -> None:
"""删除单个文档及其相关数据"""
await self.kb_db.delete_document_by_id(
doc_id=doc_id,
@@ -372,7 +372,7 @@ class KBHelper:
)
await self.refresh_kb()
async def delete_chunk(self, chunk_id: str, doc_id: str):
async def delete_chunk(self, chunk_id: str, doc_id: str) -> None:
"""删除单个文本块及其相关数据"""
vec_db: FaissVecDB = self.vec_db # type: ignore
await vec_db.delete(chunk_id)
@@ -383,7 +383,7 @@ class KBHelper:
await self.refresh_kb()
await self.refresh_document(doc_id)
async def refresh_kb(self):
async def refresh_kb(self) -> None:
if self.kb:
kb = await self.kb_db.get_kb_by_id(self.kb.kb_id)
if kb:
+5 -5
View File
@@ -26,14 +26,14 @@ class KnowledgeBaseManager:
def __init__(
self,
provider_manager: ProviderManager,
):
) -> None:
Path(DB_PATH).parent.mkdir(parents=True, exist_ok=True)
self.provider_manager = provider_manager
self._session_deleted_callback_registered = False
self.kb_insts: dict[str, KBHelper] = {}
async def initialize(self):
async def initialize(self) -> None:
"""初始化知识库模块"""
try:
logger.info("正在初始化知识库模块...")
@@ -58,13 +58,13 @@ class KnowledgeBaseManager:
logger.error(f"知识库模块初始化失败: {e}")
logger.error(traceback.format_exc())
async def _init_kb_database(self):
async def _init_kb_database(self) -> None:
self.kb_db = KBSQLiteDatabase(DB_PATH.as_posix())
await self.kb_db.initialize()
await self.kb_db.migrate_to_v1()
logger.info(f"KnowledgeBase database initialized: {DB_PATH}")
async def load_kbs(self):
async def load_kbs(self) -> None:
"""加载所有知识库实例"""
kb_records = await self.kb_db.list_kbs()
for record in kb_records:
@@ -275,7 +275,7 @@ class KnowledgeBaseManager:
return "\n".join(lines)
async def terminate(self):
async def terminate(self) -> None:
"""终止所有知识库实例,关闭数据库连接"""
for kb_id, kb_helper in self.kb_insts.items():
try:
@@ -6,7 +6,7 @@ import aiohttp
class URLExtractor:
"""URL 内容提取器,封装了 Tavily API 调用和密钥管理"""
def __init__(self, tavily_keys: list[str]):
def __init__(self, tavily_keys: list[str]) -> None:
"""
初始化 URL 提取器
@@ -44,7 +44,7 @@ class RetrievalManager:
sparse_retriever: SparseRetriever,
rank_fusion: RankFusion,
kb_db: KBSQLiteDatabase,
):
) -> None:
"""初始化检索管理器
Args:
@@ -31,7 +31,7 @@ class RankFusion:
- 使用 Reciprocal Rank Fusion (RRF) 算法
"""
def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60):
def __init__(self, kb_db: KBSQLiteDatabase, k: int = 60) -> None:
"""初始化结果融合器
Args:
@@ -34,7 +34,7 @@ class SparseRetriever:
- 使用 BM25 算法计算相关度
"""
def __init__(self, kb_db: KBSQLiteDatabase):
def __init__(self, kb_db: KBSQLiteDatabase) -> None:
"""初始化稀疏检索器
Args:
+15 -15
View File
@@ -91,7 +91,7 @@ class LogBroker:
发布-订阅模式
"""
def __init__(self):
def __init__(self) -> None:
self.log_cache = deque(maxlen=CACHED_SIZE) # 环形缓冲区, 保存最近的日志
self.subscribers: list[Queue] = [] # 订阅者列表
@@ -106,7 +106,7 @@ class LogBroker:
self.subscribers.append(q)
return q
def unregister(self, q: Queue):
def unregister(self, q: Queue) -> None:
"""取消订阅
Args:
@@ -115,7 +115,7 @@ class LogBroker:
"""
self.subscribers.remove(q)
def publish(self, log_entry: dict):
def publish(self, log_entry: dict) -> None:
"""发布新日志到所有订阅者, 使用非阻塞方式投递, 避免一个订阅者阻塞整个系统
Args:
@@ -137,11 +137,11 @@ class LogQueueHandler(logging.Handler):
继承自 logging.Handler
"""
def __init__(self, log_broker: LogBroker):
def __init__(self, log_broker: LogBroker) -> None:
super().__init__()
self.log_broker = log_broker
def emit(self, record):
def emit(self, record) -> None:
"""日志处理的入口方法, 接受一个日志记录, 转换为字符串后由 LogBroker 发布
这个方法会在每次日志记录时被调用
@@ -201,7 +201,7 @@ class LogManager:
class PluginFilter(logging.Filter):
"""插件过滤器类, 用于标记日志来源是插件还是核心组件"""
def filter(self, record):
def filter(self, record) -> bool:
record.plugin_tag = (
"[Plug]" if is_plugin_path(record.pathname) else "[Core]"
)
@@ -213,7 +213,7 @@ class LogManager:
"""
# 获取这个文件和父文件夹的名字:<folder>.<file> 并且去除 .py
def filter(self, record):
def filter(self, record) -> bool:
dirname = os.path.dirname(record.pathname)
record.filename = (
os.path.basename(dirname)
@@ -226,14 +226,14 @@ class LogManager:
"""短日志级别名称过滤器类, 用于将日志级别名称转换为四个字母的缩写"""
# 添加短日志级别名称
def filter(self, record):
def filter(self, record) -> bool:
record.short_levelname = get_short_level_name(record.levelname)
return True
class AstrBotVersionTagFilter(logging.Filter):
"""在 WARNING 及以上级别日志后追加当前 AstrBot 版本号。"""
def filter(self, record):
def filter(self, record) -> bool:
if record.levelno >= logging.WARNING:
record.astrbot_version_tag = f" [v{VERSION}]"
else:
@@ -251,7 +251,7 @@ class LogManager:
return logger
@classmethod
def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker):
def set_queue_handler(cls, logger: logging.Logger, log_broker: LogBroker) -> None:
"""设置队列处理器, 用于将日志消息发送到 LogBroker
Args:
@@ -301,7 +301,7 @@ class LogManager:
]
@classmethod
def _remove_file_handlers(cls, logger: logging.Logger):
def _remove_file_handlers(cls, logger: logging.Logger) -> None:
for handler in cls._get_file_handlers(logger):
logger.removeHandler(handler)
try:
@@ -310,7 +310,7 @@ class LogManager:
pass
@classmethod
def _remove_trace_file_handlers(cls, logger: logging.Logger):
def _remove_trace_file_handlers(cls, logger: logging.Logger) -> None:
for handler in cls._get_trace_file_handlers(logger):
logger.removeHandler(handler)
try:
@@ -326,7 +326,7 @@ class LogManager:
max_mb: int | None = None,
backup_count: int = 3,
trace: bool = False,
):
) -> None:
os.makedirs(os.path.dirname(file_path) or ".", exist_ok=True)
max_bytes = 0
if max_mb and max_mb > 0:
@@ -365,7 +365,7 @@ class LogManager:
logger: logging.Logger,
config: dict | None,
override_level: str | None = None,
):
) -> None:
"""根据配置设置日志级别和文件日志。
Args:
@@ -413,7 +413,7 @@ class LogManager:
cls._add_file_handler(logger, file_path, max_mb=max_mb)
@classmethod
def configure_trace_logger(cls, config: dict | None):
def configure_trace_logger(cls, config: dict | None) -> None:
"""为 trace 事件配置独立的文件日志,不向控制台输出。"""
if not config:
return
+27 -27
View File
@@ -66,7 +66,7 @@ class ComponentType(str, Enum):
class BaseMessageComponent(BaseModel):
type: ComponentType
def __init__(self, **kwargs):
def __init__(self, **kwargs) -> None:
super().__init__(**kwargs)
def toDict(self):
@@ -89,7 +89,7 @@ class Plain(BaseMessageComponent):
text: str
convert: bool | None = True
def __init__(self, text: str, convert: bool = True, **_):
def __init__(self, text: str, convert: bool = True, **_) -> None:
super().__init__(text=text, convert=convert, **_)
def toDict(self):
@@ -103,7 +103,7 @@ class Face(BaseMessageComponent):
type = ComponentType.Face
id: int
def __init__(self, **_):
def __init__(self, **_) -> None:
super().__init__(**_)
@@ -118,7 +118,7 @@ class Record(BaseMessageComponent):
# 额外
path: str | None
def __init__(self, file: str | None, **_):
def __init__(self, file: str | None, **_) -> None:
for k in _:
if k == "url":
pass
@@ -221,7 +221,7 @@ class Video(BaseMessageComponent):
# 额外
path: str | None = ""
def __init__(self, file: str, **_):
def __init__(self, file: str, **_) -> None:
super().__init__(file=file, **_)
@staticmethod
@@ -255,7 +255,7 @@ class Video(BaseMessageComponent):
return os.path.abspath(url)
raise Exception(f"not a valid file: {url}")
async def register_to_file_service(self):
async def register_to_file_service(self) -> str:
"""将视频注册到文件服务。
Returns:
@@ -303,7 +303,7 @@ class At(BaseMessageComponent):
qq: int | str # 此处str为all时代表所有人
name: str | None = ""
def __init__(self, **_):
def __init__(self, **_) -> None:
super().__init__(**_)
def toDict(self):
@@ -316,28 +316,28 @@ class At(BaseMessageComponent):
class AtAll(At):
qq: str = "all"
def __init__(self, **_):
def __init__(self, **_) -> None:
super().__init__(**_)
class RPS(BaseMessageComponent): # TODO
type = ComponentType.RPS
def __init__(self, **_):
def __init__(self, **_) -> None:
super().__init__(**_)
class Dice(BaseMessageComponent): # TODO
type = ComponentType.Dice
def __init__(self, **_):
def __init__(self, **_) -> None:
super().__init__(**_)
class Shake(BaseMessageComponent): # TODO
type = ComponentType.Shake
def __init__(self, **_):
def __init__(self, **_) -> None:
super().__init__(**_)
@@ -348,7 +348,7 @@ class Share(BaseMessageComponent):
content: str | None = ""
image: str | None = ""
def __init__(self, **_):
def __init__(self, **_) -> None:
super().__init__(**_)
@@ -357,7 +357,7 @@ class Contact(BaseMessageComponent): # TODO
_type: str # type 字段冲突
id: int | None = 0
def __init__(self, **_):
def __init__(self, **_) -> None:
super().__init__(**_)
@@ -368,7 +368,7 @@ class Location(BaseMessageComponent): # TODO
title: str | None = ""
content: str | None = ""
def __init__(self, **_):
def __init__(self, **_) -> None:
super().__init__(**_)
@@ -382,7 +382,7 @@ class Music(BaseMessageComponent):
content: str | None = ""
image: str | None = ""
def __init__(self, **_):
def __init__(self, **_) -> None:
# for k in _.keys():
# if k == "_type" and _[k] not in ["qq", "163", "xm", "custom"]:
# logger.warn(f"Protocol: {k}={_[k]} doesn't match values")
@@ -402,7 +402,7 @@ class Image(BaseMessageComponent):
path: str | None = ""
file_unique: str | None = "" # 某些平台可能有图片缓存的唯一标识
def __init__(self, file: str | None, **_):
def __init__(self, file: str | None, **_) -> None:
super().__init__(file=file, **_)
@staticmethod
@@ -525,7 +525,7 @@ class Reply(BaseMessageComponent):
seq: int | None = 0
"""deprecated"""
def __init__(self, **_):
def __init__(self, **_) -> None:
super().__init__(**_)
@@ -534,7 +534,7 @@ class Poke(BaseMessageComponent):
id: int | None = 0
qq: int | None = 0
def __init__(self, type: str, **_):
def __init__(self, type: str, **_) -> None:
type = f"Poke:{type}"
super().__init__(type=type, **_)
@@ -543,7 +543,7 @@ class Forward(BaseMessageComponent):
type = ComponentType.Forward
id: str
def __init__(self, **_):
def __init__(self, **_) -> None:
super().__init__(**_)
@@ -558,7 +558,7 @@ class Node(BaseMessageComponent):
seq: str | list | None = "" # 忽略
time: int | None = 0 # 忽略
def __init__(self, content: list[BaseMessageComponent], **_):
def __init__(self, content: list[BaseMessageComponent], **_) -> None:
if isinstance(content, Node):
# back
content = [content]
@@ -605,7 +605,7 @@ class Nodes(BaseMessageComponent):
type = ComponentType.Nodes
nodes: list[Node]
def __init__(self, nodes: list[Node], **_):
def __init__(self, nodes: list[Node], **_) -> None:
super().__init__(nodes=nodes, **_)
def toDict(self):
@@ -631,7 +631,7 @@ class Json(BaseMessageComponent):
type = ComponentType.Json
data: dict
def __init__(self, data: str | dict, **_):
def __init__(self, data: str | dict, **_) -> None:
if isinstance(data, str):
data = json.loads(data)
super().__init__(data=data, **_)
@@ -650,7 +650,7 @@ class File(BaseMessageComponent):
file_: str | None = "" # 本地路径
url: str | None = "" # url
def __init__(self, name: str, file: str = "", url: str = ""):
def __init__(self, name: str, file: str = "", url: str = "") -> None:
"""文件消息段。"""
super().__init__(name=name, file_=file, url=url)
@@ -686,7 +686,7 @@ class File(BaseMessageComponent):
return ""
@file.setter
def file(self, value: str):
def file(self, value: str) -> None:
"""向前兼容, 设置file属性, 传入的参数可能是文件路径或URL
Args:
@@ -721,7 +721,7 @@ class File(BaseMessageComponent):
return ""
async def _download_file(self):
async def _download_file(self) -> None:
"""下载文件"""
if not self.url:
raise ValueError("Download failed: No URL provided in File component.")
@@ -736,7 +736,7 @@ class File(BaseMessageComponent):
await download_file(self.url, file_path)
self.file_ = os.path.abspath(file_path)
async def register_to_file_service(self):
async def register_to_file_service(self) -> str:
"""将文件注册到文件服务。
Returns:
@@ -786,7 +786,7 @@ class WechatEmoji(BaseMessageComponent):
md5_len: int | None = 0
cdnurl: str | None = ""
def __init__(self, **_):
def __init__(self, **_) -> None:
super().__init__(**_)
+3 -3
View File
@@ -17,7 +17,7 @@ DEFAULT_PERSONALITY = Personality(
class PersonaManager:
def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager):
def __init__(self, db_helper: BaseDatabase, acm: AstrBotConfigManager) -> None:
self.db = db_helper
self.acm = acm
default_ps = acm.default_conf.get("provider_settings", {})
@@ -29,7 +29,7 @@ class PersonaManager:
self.selected_default_persona_v3: Personality | None = None
self.persona_v3_config: list[dict] = []
async def initialize(self):
async def initialize(self) -> None:
self.personas = await self.get_all_personas()
self.get_v3_persona_data()
logger.info(f"已加载 {len(self.personas)} 个人格。")
@@ -58,7 +58,7 @@ class PersonaManager:
except Exception:
return DEFAULT_PERSONALITY
async def delete_persona(self, persona_id: str):
async def delete_persona(self, persona_id: str) -> None:
"""删除指定 persona"""
if not await self.db.get_persona_by_id(persona_id):
raise ValueError(f"Persona with ID {persona_id} does not exist.")
@@ -16,7 +16,7 @@ class ContentSafetyCheckStage(Stage):
当前只会检查文本的
"""
async def initialize(self, ctx: PipelineContext):
async def initialize(self, ctx: PipelineContext) -> None:
config = ctx.astrbot_config["content_safety"]
self.strategy_selector = StrategySelector(config)
@@ -336,7 +336,7 @@ class InternalAgentSubStage(Stage):
llm_response: LLMResponse | None,
all_messages: list[Message],
runner_stats: AgentStats | None,
):
) -> None:
if (
not req
or not req.conversation
@@ -19,7 +19,7 @@ class RateLimitStage(Stage):
如果触发限流 stall 流水线直到下一个时间窗口来临时自动唤醒
"""
def __init__(self):
def __init__(self) -> None:
# 存储每个会话的请求时间队列
self.event_timestamps: defaultdict[str, deque[datetime]] = defaultdict(deque)
# 为每个会话设置一个锁,避免并发冲突
+2 -2
View File
@@ -35,7 +35,7 @@ class RespondStage(Stage):
Comp.WechatEmoji: lambda comp: comp.md5 is not None, # 微信表情
}
async def initialize(self, ctx: PipelineContext):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.config = ctx.astrbot_config
self.platform_settings: dict = self.config.get("platform_settings", {})
@@ -91,7 +91,7 @@ class RespondStage(Stage):
# random
return random.uniform(self.interval[0], self.interval[1])
async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]):
async def _is_empty_message_chain(self, chain: list[BaseMessageComponent]) -> bool:
"""检查消息链是否为空
Args:
@@ -20,7 +20,7 @@ from ..stage import Stage, register_stage, registered_stages
@register_stage
class ResultDecorateStage(Stage):
async def initialize(self, ctx: PipelineContext):
async def initialize(self, ctx: PipelineContext) -> None:
self.ctx = ctx
self.reply_prefix = ctx.astrbot_config["platform_settings"]["reply_prefix"]
self.reply_with_mention = ctx.astrbot_config["platform_settings"][
+4 -4
View File
@@ -15,21 +15,21 @@ from .stage import registered_stages
class PipelineScheduler:
"""管道调度器,负责调度各个阶段的执行"""
def __init__(self, context: PipelineContext):
def __init__(self, context: PipelineContext) -> None:
registered_stages.sort(
key=lambda x: STAGES_ORDER.index(x.__name__),
) # 按照顺序排序
self.ctx = context # 上下文对象
self.stages = [] # 存储阶段实例
async def initialize(self):
async def initialize(self) -> None:
"""初始化管道调度器时, 初始化所有阶段"""
for stage_cls in registered_stages:
stage_instance = stage_cls() # 创建实例
await stage_instance.initialize(self.ctx)
self.stages.append(stage_instance)
async def _process_stages(self, event: AstrMessageEvent, from_stage=0):
async def _process_stages(self, event: AstrMessageEvent, from_stage=0) -> None:
"""依次执行各个阶段
Args:
@@ -72,7 +72,7 @@ class PipelineScheduler:
logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。")
break
async def execute(self, event: AstrMessageEvent):
async def execute(self, event: AstrMessageEvent) -> None:
"""执行 pipeline
Args:
+15 -15
View File
@@ -38,7 +38,7 @@ class AstrMessageEvent(abc.ABC):
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
):
) -> None:
self.message_str = message_str
"""纯文本的消息"""
self.message_obj = message_obj
@@ -91,7 +91,7 @@ class AstrMessageEvent(abc.ABC):
return str(self.session)
@unified_msg_origin.setter
def unified_msg_origin(self, value: str):
def unified_msg_origin(self, value: str) -> None:
"""设置统一的消息来源字符串。格式为 platform_name:message_type:session_id"""
self.new_session = MessageSession.from_str(value)
self.session = self.new_session
@@ -102,7 +102,7 @@ class AstrMessageEvent(abc.ABC):
return self.session.session_id
@session_id.setter
def session_id(self, value: str):
def session_id(self, value: str) -> None:
"""设置用户的会话 ID。可以直接使用下面的 unified_msg_origin"""
self.session.session_id = value
@@ -191,7 +191,7 @@ class AstrMessageEvent(abc.ABC):
return self.message_obj.sender.nickname
return ""
def set_extra(self, key, value):
def set_extra(self, key, value) -> None:
"""设置额外的信息。"""
self._extras[key] = value
@@ -201,7 +201,7 @@ class AstrMessageEvent(abc.ABC):
return self._extras
return self._extras.get(key, default)
def clear_extra(self):
def clear_extra(self) -> None:
"""清除额外的信息。"""
logger.info(f"清除 {self.get_platform_name()} 的额外信息: {self._extras}")
self._extras.clear()
@@ -234,7 +234,7 @@ class AstrMessageEvent(abc.ABC):
self,
generator: AsyncGenerator[MessageChain, None],
use_fallback: bool = False,
):
) -> None:
"""发送流式消息到消息平台,使用异步生成器。
目前仅支持: telegramqq official 私聊
Fallback仅支持 aiocqhttp
@@ -244,13 +244,13 @@ class AstrMessageEvent(abc.ABC):
)
self._has_send_oper = True
async def _pre_send(self):
async def _pre_send(self) -> None:
"""调度器会在执行 send() 前调用该方法 deprecated in v3.5.18"""
async def _post_send(self):
async def _post_send(self) -> None:
"""调度器会在执行 send() 后调用该方法 deprecated in v3.5.18"""
def set_result(self, result: MessageEventResult | str):
def set_result(self, result: MessageEventResult | str) -> None:
"""设置消息事件的结果。
Note:
@@ -279,14 +279,14 @@ class AstrMessageEvent(abc.ABC):
result.chain = []
self._result = result
def stop_event(self):
def stop_event(self) -> None:
"""终止事件传播。"""
if self._result is None:
self.set_result(MessageEventResult().stop_event())
else:
self._result.stop_event()
def continue_event(self):
def continue_event(self) -> None:
"""继续事件传播。"""
if self._result is None:
self.set_result(MessageEventResult().continue_event())
@@ -299,7 +299,7 @@ class AstrMessageEvent(abc.ABC):
return False # 默认是继续传播
return self._result.is_stopped()
def should_call_llm(self, call_llm: bool):
def should_call_llm(self, call_llm: bool) -> None:
"""是否在此消息事件中禁止默认的 LLM 请求。
只会阻止 AstrBot 默认的 LLM 请求链路不会阻止插件中的 LLM 请求
@@ -310,7 +310,7 @@ class AstrMessageEvent(abc.ABC):
"""获取消息事件的结果。"""
return self._result
def clear_result(self):
def clear_result(self) -> None:
"""清除消息事件的结果。"""
self._result = None
@@ -404,7 +404,7 @@ class AstrMessageEvent(abc.ABC):
"""平台适配器"""
async def send(self, message: MessageChain):
async def send(self, message: MessageChain) -> None:
"""发送消息到消息平台。
Args:
@@ -423,7 +423,7 @@ class AstrMessageEvent(abc.ABC):
)
self._has_send_oper = True
async def react(self, emoji: str):
async def react(self, emoji: str) -> None:
"""对消息添加表情回应。
默认实现为发送一条包含该表情的消息
+3 -3
View File
@@ -11,7 +11,7 @@ class MessageMember:
user_id: str # 发送者id
nickname: str | None = None
def __str__(self):
def __str__(self) -> str:
# 使用 f-string 来构建返回的字符串表示形式
return (
f"User ID: {self.user_id},"
@@ -34,7 +34,7 @@ class Group:
members: list[MessageMember] | None = None
"""所有群成员"""
def __str__(self):
def __str__(self) -> str:
# 使用 f-string 来构建返回的字符串表示形式
return (
f"Group ID: {self.group_id}\n"
@@ -78,7 +78,7 @@ class AstrBotMessage:
return ""
@group_id.setter
def group_id(self, value: str | None):
def group_id(self, value: str | None) -> None:
"""设置 group_id"""
if value:
if self.group:
+9 -7
View File
@@ -13,7 +13,7 @@ from .sources.webchat.webchat_adapter import WebChatAdapter
class PlatformManager:
def __init__(self, config: AstrBotConfig, event_queue: Queue):
def __init__(self, config: AstrBotConfig, event_queue: Queue) -> None:
self.platform_insts: list[Platform] = []
"""加载的 Platform 的实例"""
@@ -38,7 +38,7 @@ class PlatformManager:
sanitized = platform_id.replace(":", "_").replace("!", "_")
return sanitized, sanitized != platform_id
async def initialize(self):
async def initialize(self) -> None:
"""初始化所有平台适配器"""
for platform in self.platforms_config:
try:
@@ -58,7 +58,7 @@ class PlatformManager:
),
)
async def load_platform(self, platform_config: dict):
async def load_platform(self, platform_config: dict) -> None:
"""实例化一个平台"""
# 动态导入
try:
@@ -176,7 +176,9 @@ class PlatformManager:
except Exception:
logger.error(traceback.format_exc())
async def _task_wrapper(self, task: asyncio.Task, platform: Platform | None = None):
async def _task_wrapper(
self, task: asyncio.Task, platform: Platform | None = None
) -> None:
# 设置平台状态为运行中
if platform:
platform.status = PlatformStatus.RUNNING
@@ -198,7 +200,7 @@ class PlatformManager:
if platform:
platform.record_error(error_msg, tb_str)
async def reload(self, platform_config: dict):
async def reload(self, platform_config: dict) -> None:
await self.terminate_platform(platform_config["id"])
if platform_config["enable"]:
await self.load_platform(platform_config)
@@ -209,7 +211,7 @@ class PlatformManager:
if key not in config_ids:
await self.terminate_platform(key)
async def terminate_platform(self, platform_id: str):
async def terminate_platform(self, platform_id: str) -> None:
if platform_id in self._inst_map:
logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...")
@@ -231,7 +233,7 @@ class PlatformManager:
if getattr(inst, "terminate", None):
await inst.terminate()
async def terminate(self):
async def terminate(self) -> None:
for inst in self.platform_insts:
if getattr(inst, "terminate", None):
await inst.terminate()
+1 -1
View File
@@ -15,7 +15,7 @@ class MessageSession:
session_id: str
platform_id: str = field(init=False)
def __str__(self):
def __str__(self) -> str:
return f"{self.platform_id}:{self.message_type.value}:{self.session_id}"
def __post_init__(self):
+7 -7
View File
@@ -34,7 +34,7 @@ class PlatformError:
class Platform(abc.ABC):
def __init__(self, config: dict, event_queue: Queue):
def __init__(self, config: dict, event_queue: Queue) -> None:
super().__init__()
# 平台配置
self.config = config
@@ -53,7 +53,7 @@ class Platform(abc.ABC):
return self._status
@status.setter
def status(self, value: PlatformStatus):
def status(self, value: PlatformStatus) -> None:
"""设置平台运行状态"""
self._status = value
if value == PlatformStatus.RUNNING and self._started_at is None:
@@ -69,12 +69,12 @@ class Platform(abc.ABC):
"""获取最近的错误"""
return self._errors[-1] if self._errors else None
def record_error(self, message: str, traceback_str: str | None = None):
def record_error(self, message: str, traceback_str: str | None = None) -> None:
"""记录一个错误"""
self._errors.append(PlatformError(message=message, traceback=traceback_str))
self._status = PlatformStatus.ERROR
def clear_errors(self):
def clear_errors(self) -> None:
"""清除错误记录"""
self._errors.clear()
if self._status == PlatformStatus.ERROR:
@@ -121,7 +121,7 @@ class Platform(abc.ABC):
"""得到一个平台的运行实例,需要返回一个协程对象。"""
raise NotImplementedError
async def terminate(self):
async def terminate(self) -> None:
"""终止一个平台的运行实例。"""
@abc.abstractmethod
@@ -140,11 +140,11 @@ class Platform(abc.ABC):
"""
await Metric.upload(msg_event_tick=1, adapter_name=self.meta().name)
def commit_event(self, event: AstrMessageEvent):
def commit_event(self, event: AstrMessageEvent) -> None:
"""提交一个事件到事件队列。"""
self._event_queue.put_nowait(event)
def get_client(self):
def get_client(self) -> object:
"""获取平台的客户端对象。"""
async def webhook_callback(self, request: Any) -> Any:
@@ -26,7 +26,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
platform_meta,
session_id,
bot: CQHttp,
):
) -> None:
super().__init__(message_str, message_obj, platform_meta, session_id)
self.bot = bot
@@ -72,7 +72,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
is_group: bool,
session_id: str | None,
messages: list[dict],
):
) -> None:
# session_id 必须是纯数字字符串
session_id_int = (
int(session_id) if session_id and session_id.isdigit() else None
@@ -97,7 +97,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
event: Event | None = None,
is_group: bool = False,
session_id: str | None = None,
):
) -> None:
"""发送消息至 QQ 协议端(aiocqhttp)。
Args:
@@ -143,7 +143,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
await cls._dispatch_send(bot, event, is_group, session_id, messages)
await asyncio.sleep(0.5)
async def send(self, message: MessageChain):
async def send(self, message: MessageChain) -> None:
"""发送消息"""
event = getattr(self.message_obj, "raw_message", None)
@@ -61,7 +61,7 @@ class AiocqhttpAdapter(Platform):
)
@self.bot.on_request()
async def request(event: Event):
async def request(event: Event) -> None:
try:
abm = await self.convert_message(event)
if not abm:
@@ -72,7 +72,7 @@ class AiocqhttpAdapter(Platform):
return
@self.bot.on_notice()
async def notice(event: Event):
async def notice(event: Event) -> None:
try:
abm = await self.convert_message(event)
if abm:
@@ -82,7 +82,7 @@ class AiocqhttpAdapter(Platform):
return
@self.bot.on_message("group")
async def group(event: Event):
async def group(event: Event) -> None:
try:
abm = await self.convert_message(event)
if abm:
@@ -92,7 +92,7 @@ class AiocqhttpAdapter(Platform):
return
@self.bot.on_message("private")
async def private(event: Event):
async def private(event: Event) -> None:
try:
abm = await self.convert_message(event)
if abm:
@@ -102,14 +102,14 @@ class AiocqhttpAdapter(Platform):
return
@self.bot.on_websocket_connection
def on_websocket_connection(_):
def on_websocket_connection(_) -> None:
logger.info("aiocqhttp(OneBot v11) 适配器已连接。")
async def send_by_session(
self,
session: MessageSesion,
message_chain: MessageChain,
):
) -> None:
is_group = session.message_type == MessageType.GROUP_MESSAGE
if is_group:
session_id = session.session_id.split("_")[-1]
@@ -435,17 +435,17 @@ class AiocqhttpAdapter(Platform):
self.shutdown_event = asyncio.Event()
return coro
async def terminate(self):
async def terminate(self) -> None:
self.shutdown_event.set()
async def shutdown_trigger_placeholder(self):
async def shutdown_trigger_placeholder(self) -> None:
await self.shutdown_event.wait()
logger.info("aiocqhttp 适配器已被关闭")
def meta(self) -> PlatformMetadata:
return self.metadata
async def handle_msg(self, message: AstrBotMessage):
async def handle_msg(self, message: AstrBotMessage) -> None:
message_event = AiocqhttpMessageEvent(
message_str=message.message_str,
message_obj=message,
@@ -1,8 +1,9 @@
import asyncio
import os
import json
import threading
import uuid
from typing import cast
from pathlib import Path
from typing import Literal, NoReturn, cast
import aiohttp
import dingtalk_stream
@@ -10,7 +11,7 @@ from dingtalk_stream import AckMessage
from astrbot import logger
from astrbot.api.event import MessageChain
from astrbot.api.message_components import At, Image, Plain
from astrbot.api.message_components import At, Image, Plain, Record, Video
from astrbot.api.platform import (
AstrBotMessage,
MessageMember,
@@ -18,9 +19,16 @@ from astrbot.api.platform import (
Platform,
PlatformMetadata,
)
from astrbot.core import sp
from astrbot.core.platform.astr_message_event import MessageSesion
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file
from astrbot.core.utils.media_utils import (
convert_audio_format,
convert_video_format,
extract_video_cover,
get_media_duration,
)
from ...register import register_platform_adapter
from .dingtalk_event import DingtalkMessageEvent
@@ -75,8 +83,6 @@ class DingtalkPlatformAdapter(Platform):
)
self.client_ = client # 用于 websockets 的 client
self._shutdown_event: threading.Event | None = None
self.card_template_id = platform_config.get("card_template_id")
self.card_instance_id_dict = {}
def _id_to_sid(self, dingtalk_id: str | None) -> str:
if not dingtalk_id:
@@ -90,8 +96,45 @@ class DingtalkPlatformAdapter(Platform):
self,
session: MessageSesion,
message_chain: MessageChain,
):
raise NotImplementedError("钉钉机器人适配器不支持 send_by_session")
) -> None:
robot_code = self.client_id
if session.message_type == MessageType.GROUP_MESSAGE:
open_conversation_id = session.session_id
await self.send_message_chain_to_group(
open_conversation_id=open_conversation_id,
robot_code=robot_code,
message_chain=message_chain,
)
else:
staff_id = await self._get_sender_staff_id(session)
if not staff_id:
logger.warning(
"钉钉私聊会话缺少 staff_id 映射,回退使用 session_id 作为 userId 发送",
)
staff_id = session.session_id
await self.send_message_chain_to_user(
staff_id=staff_id,
robot_code=robot_code,
message_chain=message_chain,
)
await super().send_by_session(session, message_chain)
async def send_with_session(
self,
session: MessageSesion,
message_chain: MessageChain,
) -> None:
await self.send_by_session(session, message_chain)
async def send_with_sesison(
self,
session: MessageSesion,
message_chain: MessageChain,
) -> None:
# backward typo compatibility
await self.send_by_session(session, message_chain)
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
@@ -99,65 +142,9 @@ class DingtalkPlatformAdapter(Platform):
description="钉钉机器人官方 API 适配器",
id=cast(str, self.config.get("id")),
support_streaming_message=True,
support_proactive_message=False,
support_proactive_message=True,
)
async def create_message_card(
self, message_id: str, incoming_message: dingtalk_stream.ChatbotMessage
):
if not self.card_template_id:
return False
card_instance = dingtalk_stream.AICardReplier(self.client_, incoming_message)
card_data = {"content": ""} # Initial content empty
try:
card_instance_id = await card_instance.async_create_and_deliver_card(
self.card_template_id,
card_data,
)
self.card_instance_id_dict[message_id] = (card_instance, card_instance_id)
return True
except Exception as e:
logger.error(f"创建钉钉卡片失败: {e}")
return False
async def send_card_message(self, message_id: str, content: str, is_final: bool):
if message_id not in self.card_instance_id_dict:
return
card_instance, card_instance_id = self.card_instance_id_dict[message_id]
content_key = "content"
try:
# 钉钉卡片流式更新
await card_instance.async_streaming(
card_instance_id,
content_key=content_key,
content_value=content,
append=False,
finished=is_final,
failed=False,
)
except Exception as e:
logger.error(f"发送钉钉卡片消息失败: {e}")
# Try to report failure
try:
await card_instance.async_streaming(
card_instance_id,
content_key=content_key,
content_value=content, # Keep existing content
append=False,
finished=True,
failed=True,
)
except Exception:
pass
if is_final:
self.card_instance_id_dict.pop(message_id, None)
async def convert_msg(
self,
message: dingtalk_stream.ChatbotMessage,
@@ -215,8 +202,35 @@ class DingtalkPlatformAdapter(Platform):
case "audio":
pass
await self._remember_sender_binding(message, abm)
return abm # 别忘了返回转换后的消息对象
async def _remember_sender_binding(
self,
message: dingtalk_stream.ChatbotMessage,
abm: AstrBotMessage,
) -> None:
try:
if abm.type == MessageType.FRIEND_MESSAGE:
sender_id = abm.sender.user_id
sender_staff_id = cast(str, message.sender_staff_id or "")
if sender_staff_id:
umo = str(
MessageSesion(
platform_name=self.meta().id,
message_type=abm.type,
session_id=sender_id,
)
)
await sp.put_async(
"global",
umo,
"dingtalk_staffid",
sender_staff_id,
)
except Exception as e:
logger.warning(f"保存钉钉会话映射失败: {e}")
async def download_ding_file(
self,
download_code: str,
@@ -239,8 +253,9 @@ class DingtalkPlatformAdapter(Platform):
"downloadCode": download_code,
"robotCode": robot_code,
}
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
f_path = os.path.join(temp_dir, f"dingtalk_file_{uuid.uuid4()}.{ext}")
temp_dir = Path(get_astrbot_data_path()) / "temp"
temp_dir.mkdir(parents=True, exist_ok=True)
f_path = temp_dir / f"dingtalk_file_{uuid.uuid4()}.{ext}"
async with (
aiohttp.ClientSession() as session,
session.post(
@@ -256,14 +271,21 @@ class DingtalkPlatformAdapter(Platform):
return ""
resp_data = await resp.json()
download_url = resp_data["data"]["downloadUrl"]
await download_file(download_url, f_path)
return f_path
await download_file(download_url, str(f_path))
return str(f_path)
async def get_access_token(self) -> str:
payload = {
"appKey": self.client_id,
"appSecret": self.client_secret,
}
try:
access_token = await asyncio.get_event_loop().run_in_executor(
None,
self.client_.get_access_token,
)
if access_token:
return access_token
except Exception as e:
logger.warning(f"通过 dingtalk_stream 获取 access_token 失败: {e}")
payload = {"appKey": self.client_id, "appSecret": self.client_secret}
async with aiohttp.ClientSession() as session:
async with session.post(
"https://api.dingtalk.com/v1.0/oauth2/accessToken",
@@ -274,9 +296,330 @@ class DingtalkPlatformAdapter(Platform):
f"获取钉钉机器人 access_token 失败: {resp.status}, {await resp.text()}",
)
return ""
return (await resp.json())["data"]["accessToken"]
data = await resp.json()
return cast(str, data.get("data", {}).get("accessToken", ""))
async def handle_msg(self, abm: AstrBotMessage):
async def _get_sender_staff_id(self, session: MessageSesion) -> str:
try:
staff_id = await sp.get_async(
"global",
str(session),
"dingtalk_staffid",
"",
)
return cast(str, staff_id or "")
except Exception as e:
logger.warning(f"读取钉钉 staff_id 映射失败: {e}")
return ""
async def _send_group_message(
self,
open_conversation_id: str,
robot_code: str,
msg_key: str,
msg_param: dict,
) -> None:
access_token = await self.get_access_token()
if not access_token:
logger.error("钉钉群消息发送失败: access_token 为空")
return
payload = {
"msgKey": msg_key,
"msgParam": json.dumps(msg_param, ensure_ascii=False),
"openConversationId": open_conversation_id,
"robotCode": robot_code,
}
headers = {
"Content-Type": "application/json",
"x-acs-dingtalk-access-token": access_token,
}
async with aiohttp.ClientSession() as session:
async with session.post(
"https://api.dingtalk.com/v1.0/robot/groupMessages/send",
headers=headers,
json=payload,
) as resp:
if resp.status != 200:
logger.error(
f"钉钉群消息发送失败: {resp.status}, {await resp.text()}",
)
async def _send_private_message(
self,
staff_id: str,
robot_code: str,
msg_key: str,
msg_param: dict,
) -> None:
access_token = await self.get_access_token()
if not access_token:
logger.error("钉钉私聊消息发送失败: access_token 为空")
return
payload = {
"robotCode": robot_code,
"userIds": [staff_id],
"msgKey": msg_key,
"msgParam": json.dumps(msg_param, ensure_ascii=False),
}
headers = {
"Content-Type": "application/json",
"x-acs-dingtalk-access-token": access_token,
}
async with aiohttp.ClientSession() as session:
async with session.post(
"https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend",
headers=headers,
json=payload,
) as resp:
if resp.status != 200:
logger.error(
f"钉钉私聊消息发送失败: {resp.status}, {await resp.text()}",
)
def _safe_remove_file(self, file_path: str | None) -> None:
if not file_path:
return
try:
p = Path(file_path)
if p.exists() and p.is_file():
p.unlink()
except Exception as e:
logger.warning(f"清理临时文件失败: {file_path}, {e}")
async def _prepare_voice_for_dingtalk(self, input_path: str) -> tuple[str, bool]:
"""优先转换为 OGG(Opus),不可用时回退 AMR。"""
lower_path = input_path.lower()
if lower_path.endswith((".amr", ".ogg")):
return input_path, False
try:
converted = await convert_audio_format(input_path, "ogg")
return converted, converted != input_path
except Exception as e:
logger.warning(f"钉钉语音转 OGG 失败,回退 AMR: {e}")
converted = await convert_audio_format(input_path, "amr")
return converted, converted != input_path
async def upload_media(self, file_path: str, media_type: str) -> str:
media_file_path = Path(file_path)
access_token = await self.get_access_token()
if not access_token:
logger.error("钉钉媒体上传失败: access_token 为空")
return ""
form = aiohttp.FormData()
form.add_field(
"media",
media_file_path.read_bytes(),
filename=media_file_path.name,
content_type="application/octet-stream",
)
async with aiohttp.ClientSession() as session:
async with session.post(
f"https://oapi.dingtalk.com/media/upload?access_token={access_token}&type={media_type}",
data=form,
) as resp:
if resp.status != 200:
logger.error(
f"钉钉媒体上传失败: {resp.status}, {await resp.text()}"
)
return ""
data = await resp.json()
if data.get("errcode") != 0:
logger.error(f"钉钉媒体上传失败: {data}")
return ""
return cast(str, data.get("media_id", ""))
async def upload_image(self, image: Image) -> str:
image_file_path = await image.convert_to_file_path()
return await self.upload_media(image_file_path, "image")
async def _send_message_chain(
self,
target_type: Literal["group", "user"],
target_id: str,
robot_code: str,
message_chain: MessageChain,
at_str: str = "",
) -> None:
async def send_message(msg_key: str, msg_param: dict) -> None:
if target_type == "group":
await self._send_group_message(
open_conversation_id=target_id,
robot_code=robot_code,
msg_key=msg_key,
msg_param=msg_param,
)
else:
await self._send_private_message(
staff_id=target_id,
robot_code=robot_code,
msg_key=msg_key,
msg_param=msg_param,
)
for segment in message_chain.chain:
if isinstance(segment, Plain):
text = segment.text.strip()
if not text and not at_str:
continue
await send_message(
msg_key="sampleMarkdown",
msg_param={
"title": "AstrBot",
"text": f"{at_str} {text}".strip(),
},
)
elif isinstance(segment, Image):
photo_url = segment.file or segment.url or ""
if photo_url.startswith(("http://", "https://")):
pass
else:
photo_url = await self.upload_image(segment)
if not photo_url:
continue
await send_message(
msg_key="sampleImageMsg",
msg_param={"photoURL": photo_url},
)
elif isinstance(segment, Record):
converted_audio = None
try:
audio_path = await segment.convert_to_file_path()
(
audio_path,
converted_audio,
) = await self._prepare_voice_for_dingtalk(audio_path)
media_id = await self.upload_media(audio_path, "voice")
if not media_id:
continue
duration_ms = await get_media_duration(audio_path)
await send_message(
msg_key="sampleAudio",
msg_param={
"mediaId": media_id,
"duration": str(duration_ms or 1000),
},
)
except Exception as e:
logger.warning(f"钉钉语音发送失败: {e}")
continue
finally:
if converted_audio:
self._safe_remove_file(audio_path)
elif isinstance(segment, Video):
converted_video = False
cover_path = None
try:
source_video_path = await segment.convert_to_file_path()
video_path = source_video_path
if not video_path.lower().endswith(".mp4"):
video_path = await convert_video_format(video_path, "mp4")
converted_video = video_path != source_video_path
cover_path = await extract_video_cover(video_path)
video_media_id = await self.upload_media(video_path, "file")
pic_media_id = await self.upload_media(cover_path, "image")
if not video_media_id or not pic_media_id:
continue
duration_ms = await get_media_duration(video_path)
duration_sec = max(1, int((duration_ms or 1000) / 1000))
await send_message(
msg_key="sampleVideo",
msg_param={
"duration": str(duration_sec),
"videoMediaId": video_media_id,
"videoType": "mp4",
"picMediaId": pic_media_id,
},
)
except Exception as e:
logger.warning(f"钉钉视频发送失败: {e}")
continue
finally:
self._safe_remove_file(cover_path)
if converted_video:
self._safe_remove_file(video_path)
async def send_message_chain_to_group(
self,
open_conversation_id: str,
robot_code: str,
message_chain: MessageChain,
at_str: str = "",
) -> None:
await self._send_message_chain(
target_type="group",
target_id=open_conversation_id,
robot_code=robot_code,
message_chain=message_chain,
at_str=at_str,
)
async def send_message_chain_to_user(
self,
staff_id: str,
robot_code: str,
message_chain: MessageChain,
at_str: str = "",
) -> None:
await self._send_message_chain(
target_type="user",
target_id=staff_id,
robot_code=robot_code,
message_chain=message_chain,
at_str=at_str,
)
async def send_message_chain_with_incoming(
self,
incoming_message: dingtalk_stream.ChatbotMessage,
message_chain: MessageChain,
) -> None:
robot_code = self.client_id
# at_list: list[str] = []
sender_id = cast(str, incoming_message.sender_id or "")
sender_staff_id = cast(str, incoming_message.sender_staff_id or "")
normalized_sender_id = self._id_to_sid(sender_id)
# 现在用的发消息接口不支持 at
# for segment in message_chain.chain:
# if isinstance(segment, At):
# if (
# str(segment.qq) in {sender_id, normalized_sender_id}
# and sender_staff_id
# ):
# at_list.append(f"@{sender_staff_id}")
# else:
# at_list.append(f"@{segment.qq}")
# at_str = " ".join(at_list)
if incoming_message.conversation_type == "2":
await self.send_message_chain_to_group(
open_conversation_id=cast(str, incoming_message.conversation_id),
robot_code=robot_code,
message_chain=message_chain,
# at_str=at_str,
)
else:
session = MessageSesion(
platform_name=self.meta().id,
message_type=MessageType.FRIEND_MESSAGE,
session_id=normalized_sender_id,
)
staff_id = sender_staff_id or await self._get_sender_staff_id(session)
if not staff_id:
logger.error("钉钉私聊回复失败: 缺少 sender_staff_id")
return
await self.send_message_chain_to_user(
staff_id=staff_id,
robot_code=robot_code,
message_chain=message_chain,
# at_str=at_str,
)
async def handle_msg(self, abm: AstrBotMessage) -> None:
event = DingtalkMessageEvent(
message_str=abm.message_str,
message_obj=abm,
@@ -288,10 +631,10 @@ class DingtalkPlatformAdapter(Platform):
self._event_queue.put_nowait(event)
async def run(self):
async def run(self) -> None:
# await self.client_.start()
# 钉钉的 SDK 并没有实现真正的异步,start() 里面有堵塞方法。
def start_client(loop: asyncio.AbstractEventLoop):
def start_client(loop: asyncio.AbstractEventLoop) -> None:
try:
self._shutdown_event = threading.Event()
task = loop.create_task(self.client_.start())
@@ -307,8 +650,8 @@ class DingtalkPlatformAdapter(Platform):
loop = asyncio.get_event_loop()
await loop.run_in_executor(None, start_client, loop)
async def terminate(self):
def monkey_patch_close():
async def terminate(self) -> None:
def monkey_patch_close() -> NoReturn:
raise KeyboardInterrupt("Graceful shutdown")
if self.client_.websocket is not None:
@@ -1,9 +1,5 @@
import asyncio
from typing import Any, cast
from typing import Any
import dingtalk_stream
import astrbot.api.message_components as Comp
from astrbot import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
@@ -15,128 +11,33 @@ class DingtalkMessageEvent(AstrMessageEvent):
message_obj,
platform_meta,
session_id,
client: dingtalk_stream.ChatbotHandler,
client: Any = None,
adapter: "Any" = None,
):
) -> None:
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
self.adapter = adapter
async def send_with_client(
self,
client: dingtalk_stream.ChatbotHandler,
message: MessageChain,
):
icm = cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message)
ats = []
# fixes: #4218
# 钉钉 at 机器人需要使用 sender_staff_id 而不是 sender_id
for i in message.chain:
if isinstance(i, Comp.At):
print(i.qq, icm.sender_id, icm.sender_staff_id)
if str(i.qq) in str(icm.sender_id or ""):
# 适配器会将开头的 $:LWCP_v1:$ 去掉,因此我们用 in 判断
ats.append(f"@{icm.sender_staff_id}")
else:
ats.append(f"@{i.qq}")
at_str = " ".join(ats)
for segment in message.chain:
if isinstance(segment, Comp.Plain):
segment.text = segment.text.strip()
await asyncio.get_event_loop().run_in_executor(
None,
client.reply_markdown,
segment.text,
f"{at_str} {segment.text}".strip(),
cast(dingtalk_stream.ChatbotMessage, self.message_obj.raw_message),
)
elif isinstance(segment, Comp.Image):
markdown_str = ""
try:
if not segment.file:
logger.warning("钉钉图片 segment 缺少 file 字段,跳过")
continue
if segment.file.startswith(("http://", "https://")):
image_url = segment.file
else:
image_url = await segment.register_to_file_service()
markdown_str = f"![image]({image_url})\n\n"
ret = await asyncio.get_event_loop().run_in_executor(
None,
client.reply_markdown,
"😄",
markdown_str,
cast(
dingtalk_stream.ChatbotMessage, self.message_obj.raw_message
),
)
logger.debug(f"send image: {ret}")
except Exception as e:
logger.warning(f"钉钉图片处理失败: {e}, 跳过图片发送")
continue
async def send(self, message: MessageChain):
await self.send_with_client(self.client, message)
async def send(self, message: MessageChain) -> None:
if not self.adapter:
logger.error("钉钉消息发送失败: 缺少 adapter")
return
await self.adapter.send_message_chain_with_incoming(
incoming_message=self.message_obj.raw_message,
message_chain=message,
)
await super().send(message)
async def send_streaming(self, generator, use_fallback: bool = False):
if not self.adapter or not self.adapter.card_template_id:
logger.warning(
f"DingTalk streaming is enabled, but 'card_template_id' is not configured for platform '{self.platform_meta.id}'. Falling back to text streaming."
)
# Fallback to default behavior (buffer and send)
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
# 钉钉统一回退为缓冲发送:最终发送仍使用新的 HTTP 消息接口。
buffer = None
async for chain in generator:
if not buffer:
return None
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
# Create card
msg_id = self.message_obj.message_id
incoming_msg = self.message_obj.raw_message
created = await self.adapter.create_message_card(msg_id, incoming_msg)
if not created:
# Fallback to default behavior (buffer and send)
buffer = None
async for chain in generator:
if not buffer:
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return None
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
full_content = ""
seq = 0
try:
async for chain in generator:
for segment in chain.chain:
if isinstance(segment, Comp.Plain):
full_content += segment.text
seq += 1
if seq % 2 == 0: # Update every 2 chunks to be more responsive than 8
await self.adapter.send_card_message(
msg_id, full_content, is_final=False
)
await self.adapter.send_card_message(msg_id, full_content, is_final=True)
except Exception as e:
logger.error(f"DingTalk streaming error: {e}")
# Try to ensure final state is sent or cleaned up?
await self.adapter.send_card_message(msg_id, full_content, is_final=True)
buffer = chain
else:
buffer.chain.extend(chain.chain)
if not buffer:
return None
buffer.squash_plain()
await self.send(buffer)
return await super().send_streaming(generator, use_fallback)
@@ -15,7 +15,7 @@ else:
class DiscordBotClient(discord.Bot):
"""Discord客户端封装"""
def __init__(self, token: str, proxy: str | None = None):
def __init__(self, token: str, proxy: str | None = None) -> None:
self.token = token
self.proxy = proxy
@@ -32,7 +32,7 @@ class DiscordBotClient(discord.Bot):
self.on_ready_once_callback: Callable[[], Awaitable[None]] | None = None
self._ready_once_fired = False
async def on_ready(self):
async def on_ready(self) -> None:
"""当机器人成功连接并准备就绪时触发"""
if self.user is None:
logger.error("[Discord] 客户端未正确加载用户信息 (self.user is None)")
@@ -93,7 +93,7 @@ class DiscordBotClient(discord.Bot):
"type": "interaction",
}
async def on_message(self, message: discord.Message):
async def on_message(self, message: discord.Message) -> None:
"""当接收到消息时触发"""
if message.author.bot:
return
@@ -130,12 +130,12 @@ class DiscordBotClient(discord.Bot):
return str(interaction_data)
async def start_polling(self):
async def start_polling(self) -> None:
"""开始轮询消息,这是个阻塞方法"""
await self.start(self.token)
@override
async def close(self):
async def close(self) -> None:
"""关闭客户端"""
if not self.is_closed():
await super().close()
@@ -19,7 +19,7 @@ class DiscordEmbed(BaseMessageComponent):
image: str | None = None,
footer: str | None = None,
fields: list[dict] | None = None,
):
) -> None:
self.title = title
self.description = description
self.color = color
@@ -71,7 +71,7 @@ class DiscordButton(BaseMessageComponent):
emoji: str | None = None,
url: str | None = None,
disabled: bool = False,
):
) -> None:
self.label = label
self.custom_id = custom_id
self.style = style
@@ -85,7 +85,7 @@ class DiscordReference(BaseMessageComponent):
type: str = "discord_reference"
def __init__(self, message_id: str, channel_id: str):
def __init__(self, message_id: str, channel_id: str) -> None:
self.message_id = message_id
self.channel_id = channel_id
@@ -99,7 +99,7 @@ class DiscordView(BaseMessageComponent):
self,
components: list[BaseMessageComponent] | None = None,
timeout: float | None = None,
):
) -> None:
self.components = components or []
self.timeout = timeout
@@ -60,7 +60,7 @@ class DiscordPlatformAdapter(Platform):
self,
session: MessageSesion,
message_chain: MessageChain,
):
) -> None:
"""通过会话发送消息"""
if self.client.user is None:
logger.error(
@@ -122,11 +122,11 @@ class DiscordPlatformAdapter(Platform):
)
@override
async def run(self):
async def run(self) -> None:
"""主要运行逻辑"""
# 初始化回调函数
async def on_received(message_data):
async def on_received(message_data) -> None:
logger.debug(f"[Discord] 收到消息: {message_data}")
if self.client_self_id is None:
self.client_self_id = message_data.get("bot_id")
@@ -143,7 +143,7 @@ class DiscordPlatformAdapter(Platform):
self.client = DiscordBotClient(token, proxy)
self.client.on_message_received = on_received
async def callback():
async def callback() -> None:
if self.enable_command_register:
await self._collect_and_register_commands()
if self.activity_name:
@@ -251,7 +251,7 @@ class DiscordPlatformAdapter(Platform):
# 由于 on_interaction 已被禁用,我们只处理普通消息
return self._convert_message_to_abm(data)
async def handle_msg(self, message: AstrBotMessage, followup_webhook=None):
async def handle_msg(self, message: AstrBotMessage, followup_webhook=None) -> None:
"""处理消息"""
message_event = DiscordPlatformEvent(
message_str=message.message_str,
@@ -323,7 +323,7 @@ class DiscordPlatformAdapter(Platform):
self.commit_event(message_event)
@override
async def terminate(self):
async def terminate(self) -> None:
"""终止适配器"""
logger.info("[Discord] 正在终止适配器... (step 1: cancel polling task)")
self.shutdown_event.set()
@@ -358,11 +358,11 @@ class DiscordPlatformAdapter(Platform):
logger.warning(f"[Discord] 客户端关闭异常: {e}")
logger.info("[Discord] 适配器已终止。")
def register_handler(self, handler_info):
def register_handler(self, handler_info) -> None:
"""注册处理器信息"""
self.registered_handlers.append(handler_info)
async def _collect_and_register_commands(self):
async def _collect_and_register_commands(self) -> None:
"""收集所有指令并注册到Discord"""
logger.info("[Discord] 开始收集并注册斜杠指令...")
registered_commands = []
@@ -420,7 +420,7 @@ class DiscordPlatformAdapter(Platform):
async def dynamic_callback(
ctx: discord.ApplicationContext, params: str | None = None
):
) -> None:
# 将平台特定的前缀'/'剥离,以适配通用的CommandFilter
logger.debug(f"[Discord] 回调函数触发: {cmd_name}")
logger.debug(f"[Discord] 回调函数参数: {ctx}")
@@ -28,7 +28,7 @@ from .components import DiscordEmbed, DiscordView
class DiscordViewComponent(BaseMessageComponent):
type: str = "discord_view"
def __init__(self, view: discord.ui.View):
def __init__(self, view: discord.ui.View) -> None:
self.view = view
@@ -41,12 +41,12 @@ class DiscordPlatformEvent(AstrMessageEvent):
session_id: str,
client: DiscordBotClient,
interaction_followup_webhook: discord.Webhook | None = None,
):
) -> None:
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
self.interaction_followup_webhook = interaction_followup_webhook
async def send(self, message: MessageChain):
async def send(self, message: MessageChain) -> None:
"""发送消息到Discord平台"""
# 解析消息链为 Discord 所需的对象
try:
@@ -267,7 +267,7 @@ class DiscordPlatformEvent(AstrMessageEvent):
content = content[:2000]
return content, files, view, embeds, reference_message_id
async def react(self, emoji: str):
async def react(self, emoji: str) -> None:
"""对原消息添加反应"""
try:
if hasattr(self.message_obj, "raw_message") and hasattr(
@@ -53,10 +53,10 @@ class LarkPlatformAdapter(Platform):
logger.warning("未设置飞书机器人名称,@ 机器人可能得不到回复。")
# 初始化 WebSocket 长连接相关配置
async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1):
async def on_msg_event_recv(event: lark.im.v1.P2ImMessageReceiveV1) -> None:
await self.convert_msg(event)
def do_v2_msg_event(event: lark.im.v1.P2ImMessageReceiveV1):
def do_v2_msg_event(event: lark.im.v1.P2ImMessageReceiveV1) -> None:
asyncio.create_task(on_msg_event_recv(event))
self.event_handler = (
@@ -91,7 +91,7 @@ class LarkPlatformAdapter(Platform):
self.event_id_timestamps: dict[str, float] = {}
def _clean_expired_events(self):
def _clean_expired_events(self) -> None:
"""清理超过 30 分钟的事件记录"""
current_time = time.time()
expired_keys = [
@@ -121,7 +121,7 @@ class LarkPlatformAdapter(Platform):
self,
session: MessageSesion,
message_chain: MessageChain,
):
) -> None:
if session.message_type == MessageType.GROUP_MESSAGE:
id_type = "chat_id"
receive_id = session.session_id
@@ -149,7 +149,7 @@ class LarkPlatformAdapter(Platform):
support_streaming_message=False,
)
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1) -> None:
if event.event is None:
logger.debug("[Lark] 收到空事件(event.event is None)")
return
@@ -299,7 +299,7 @@ class LarkPlatformAdapter(Platform):
logger.debug(abm)
await self.handle_msg(abm)
async def handle_msg(self, abm: AstrBotMessage):
async def handle_msg(self, abm: AstrBotMessage) -> None:
event = LarkMessageEvent(
message_str=abm.message_str,
message_obj=abm,
@@ -310,7 +310,7 @@ class LarkPlatformAdapter(Platform):
self._event_queue.put_nowait(event)
async def handle_webhook_event(self, event_data: dict):
async def handle_webhook_event(self, event_data: dict) -> None:
"""处理 Webhook 事件
Args:
@@ -332,7 +332,7 @@ class LarkPlatformAdapter(Platform):
except Exception as e:
logger.error(f"[Lark Webhook] 处理事件失败: {e}", exc_info=True)
async def run(self):
async def run(self) -> None:
if self.connection_mode == "webhook":
# Webhook 模式
if self.webhook_server is None:
@@ -355,7 +355,7 @@ class LarkPlatformAdapter(Platform):
return await self.webhook_server.handle_callback(request)
async def terminate(self):
async def terminate(self) -> None:
if self.connection_mode == "socket":
await self.client._disconnect()
logger.info("飞书(Lark) 适配器已关闭")
@@ -38,7 +38,7 @@ class LarkMessageEvent(AstrMessageEvent):
platform_meta,
session_id,
bot: lark.Client,
):
) -> None:
super().__init__(message_str, message_obj, platform_meta, session_id)
self.bot = bot
@@ -274,7 +274,7 @@ class LarkMessageEvent(AstrMessageEvent):
reply_message_id: str | None = None,
receive_id: str | None = None,
receive_id_type: str | None = None,
):
) -> None:
"""通用的消息链发送方法
Args:
@@ -342,7 +342,7 @@ class LarkMessageEvent(AstrMessageEvent):
media_comp, lark_client, reply_message_id, receive_id, receive_id_type
)
async def send(self, message: MessageChain):
async def send(self, message: MessageChain) -> None:
"""发送消息链到飞书,然后交给父类做框架级发送/记录"""
await LarkMessageEvent.send_message_chain(
message,
@@ -358,7 +358,7 @@ class LarkMessageEvent(AstrMessageEvent):
reply_message_id: str | None = None,
receive_id: str | None = None,
receive_id_type: str | None = None,
):
) -> None:
"""发送文件消息
Args:
@@ -392,7 +392,7 @@ class LarkMessageEvent(AstrMessageEvent):
reply_message_id: str | None = None,
receive_id: str | None = None,
receive_id_type: str | None = None,
):
) -> None:
"""发送音频消息
Args:
@@ -465,7 +465,7 @@ class LarkMessageEvent(AstrMessageEvent):
reply_message_id: str | None = None,
receive_id: str | None = None,
receive_id_type: str | None = None,
):
) -> None:
"""发送视频消息
Args:
@@ -531,7 +531,7 @@ class LarkMessageEvent(AstrMessageEvent):
receive_id_type=receive_id_type,
)
async def react(self, emoji: str):
async def react(self, emoji: str) -> None:
if self.bot.im is None:
logger.error("[Lark] API Client im 模块未初始化,无法发送表情")
return
+3 -3
View File
@@ -21,7 +21,7 @@ from astrbot.api import logger
class AESCipher:
"""AES 加密/解密工具类"""
def __init__(self, key: str):
def __init__(self, key: str) -> None:
self.bs = AES.block_size
self.key = hashlib.sha256(self.str_to_bytes(key)).digest()
@@ -52,7 +52,7 @@ class LarkWebhookServer:
仅支持统一 Webhook 模式
"""
def __init__(self, config: dict, event_queue: asyncio.Queue):
def __init__(self, config: dict, event_queue: asyncio.Queue) -> None:
"""初始化 Webhook 服务器
Args:
@@ -197,7 +197,7 @@ class LarkWebhookServer:
return {}
def set_callback(self, callback: Callable[[dict], Awaitable[None]]):
def set_callback(self, callback: Callable[[dict], Awaitable[None]]) -> None:
"""设置事件回调函数
Args:

Some files were not shown because too many files have changed in this diff Show More