Compare commits
264 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 3c8ec2f42e | |||
| 7e193f7f52 | |||
| 7069b02929 | |||
| 66995db927 | |||
| c36054ca1b | |||
| 3e07fbf3dc | |||
| bf3fbe3e96 | |||
| 0a93d22bc8 | |||
| f5b3d94d16 | |||
| 4d1a6994aa | |||
| 05c686782c | |||
| 85609ea742 | |||
| 20dabc0615 | |||
| 356dd9bc2b | |||
| cd5d7534c4 | |||
| b4f12fc933 | |||
| cbea387ce0 | |||
| 345b155374 | |||
| 29d216950e | |||
| 321b04772c | |||
| 5b924aee98 | |||
| 46d44e3405 | |||
| 4d5332fe25 | |||
| 18bd4c54f4 | |||
| 31c7768ca0 | |||
| 6ec643e9d1 | |||
| 2b39f6f61c | |||
| bf3ca13961 | |||
| 82026370ec | |||
| 6d49bf5346 | |||
| 67431d87fb | |||
| fdf55221e6 | |||
| 07f277dd3b | |||
| cf8f0603ca | |||
| 5592408ab8 | |||
| a01617b45c | |||
| 7abb4087b3 | |||
| dff15cf27a | |||
| aa858137e5 | |||
| 45cb143202 | |||
| 7a9c6ab8c4 | |||
| e2c26c292d | |||
| be7c3fd00e | |||
| 7e5461a2cf | |||
| 6ee9010645 | |||
| a23d5be056 | |||
| 97a6a1fdc2 | |||
| c8f567347b | |||
| 74c1e7f69e | |||
| 15a5fc0cae | |||
| f07c54d47c | |||
| 70446be108 | |||
| d6d21fca56 | |||
| 8d7273924f | |||
| ea64afbaa7 | |||
| 45da9837ec | |||
| 8c19b7d163 | |||
| ab227a08d0 | |||
| 40d6e77964 | |||
| 9326e3f1b0 | |||
| 0e1eb3daf6 | |||
| 05daac12ed | |||
| c5b24b4764 | |||
| cc16548e5f | |||
| 291d65bb3e | |||
| bd3ad03da6 | |||
| 5fa6788357 | |||
| c5c5a98ac4 | |||
| a1151143cf | |||
| f5024984f7 | |||
| f4880fd90d | |||
| 0ae61d5865 | |||
| d3bd775a79 | |||
| da546cfe7f | |||
| a211933e83 | |||
| 1d40b5a821 | |||
| 33836daeb7 | |||
| d921b0f6bd | |||
| 0607b95df6 | |||
| 0de6d0e046 | |||
| 98427345cf | |||
| 9fedaa9f77 | |||
| bf4c2ecd33 | |||
| f8c18cc1e0 | |||
| 458b900412 | |||
| 192c776e0b | |||
| 5cdec18863 | |||
| 15f856f951 | |||
| 01d52cef74 | |||
| 95563c8659 | |||
| 31d8c40eca | |||
| 56001ed272 | |||
| d916fda04c | |||
| cfae655068 | |||
| 5596565ec4 | |||
| afa1aa5d93 | |||
| e98c3d8393 | |||
| 6687b816f0 | |||
| ea8035e854 | |||
| 54b0171d49 | |||
| 676d4277b9 | |||
| a4b1da3ca2 | |||
| 9e9c16e770 | |||
| dc87006fed | |||
| b9b260f26a | |||
| 33fd6a5016 | |||
| 97cbccc2ba | |||
| 1ee4685d5d | |||
| aba18232b1 | |||
| 0a02441b75 | |||
| 1be5b4c7ff | |||
| a0ce0cf18a | |||
| 7c54e5d093 | |||
| b825e51dab | |||
| 589855c393 | |||
| 4c546f2f53 | |||
| 3753fce912 | |||
| 4c02857ec5 | |||
| 33f87ff7d7 | |||
| 784dcf2a9a | |||
| 43ee943acb | |||
| a769fd7d13 | |||
| 2c4fd00b16 | |||
| 264771fe98 | |||
| ecd92dafef | |||
| c8b6e4bea3 | |||
| 3756cb766e | |||
| 068d9ca60b | |||
| 93f632d8b8 | |||
| bb44ce7e74 | |||
| 6986c8d8f7 | |||
| fe95506db4 | |||
| 310ed76b18 | |||
| 98830d147f | |||
| 19c9177d7b | |||
| f41c5f97f6 | |||
| 648c125697 | |||
| 0dc2b89897 | |||
| 83745f83a5 | |||
| 2f91fe4535 | |||
| 739f09059e | |||
| c86f9f0f5f | |||
| 9470ca6bc5 | |||
| 2a92c4d5de | |||
| bb6e892657 | |||
| c9079b9299 | |||
| b6963c1bf9 | |||
| 9c29df47bb | |||
| fc146d3d00 | |||
| 1bf5a21678 | |||
| 011542dc2b | |||
| 489784104e | |||
| 3860634fd2 | |||
| 709c324e18 | |||
| b75d24d92c | |||
| ed80e9424c | |||
| 2fe1f2060a | |||
| c6df820164 | |||
| d6239822db | |||
| bced9ffff9 | |||
| d7d1c1544a | |||
| 7c1e8ce48c | |||
| e3b0ca8ef6 | |||
| 9e266eb6d5 | |||
| 7231403e16 | |||
| 344a486fd7 | |||
| 4fd831875d | |||
| 0988d067ea | |||
| 44dbe475af | |||
| bd24cf3ea4 | |||
| b493a808fe | |||
| 54035d108d | |||
| c5e8bc7e20 | |||
| 3bbb4779a3 | |||
| 1b3963ebea | |||
| 3b6dd7e15a | |||
| 757d2a3947 | |||
| 61b71143f2 | |||
| 1b343a36c9 | |||
| 8e94937060 | |||
| e8ffebc006 | |||
| 2ca95eaa9f | |||
| 0dc5b4cdfc | |||
| cc6cd96d8e | |||
| 4244d37625 | |||
| 0b766095d4 | |||
| a4f212a18f | |||
| caafb73190 | |||
| 09482799c9 | |||
| 37f93d1760 | |||
| 725f2e5204 | |||
| 967198fae0 | |||
| 43d57f6dcb | |||
| 6afa4db577 | |||
| 3b8c3fb29a | |||
| 921c3b0627 | |||
| c0fadb45ab | |||
| a1481fb179 | |||
| 987cd972d3 | |||
| bdf25976a3 | |||
| 87c3aff4ce | |||
| 99350a957a | |||
| 319068dc7e | |||
| cd18806c39 | |||
| 95b08b2023 | |||
| 0e70f76c86 | |||
| 4d414a2994 | |||
| 3d22772d4e | |||
| 0b381e2570 | |||
| f2cc4311c5 | |||
| e349671fdf | |||
| 01c02d5efa | |||
| b62b1f3870 | |||
| 8844830859 | |||
| 0c51ee4b64 | |||
| 11920d5e31 | |||
| 848ea1eb63 | |||
| a216519486 | |||
| b04606c38e | |||
| 38072beea7 | |||
| b843f1fa03 | |||
| 560d40e571 | |||
| 5f0b8161b7 | |||
| 062d482917 | |||
| 39693a27e3 | |||
| 7cd1eeac30 | |||
| bafa473c8e | |||
| 750cf46b2e | |||
| 68885a4bbc | |||
| bcc99a8904 | |||
| 59fbd98db3 | |||
| b70ed425f1 | |||
| 45ef5811c8 | |||
| 3b137ac762 | |||
| 1ddb0caf73 | |||
| ae4c6fe2dd | |||
| b03fe438d0 | |||
| db257af58e | |||
| 735368c71b | |||
| 9e04e3679b | |||
| 43b8414727 | |||
| 5a00187147 | |||
| cb525c7c84 | |||
| d88420dd03 | |||
| b9a983f8e0 | |||
| 42431ea7db | |||
| f9459e4abb | |||
| 72f917d611 | |||
| 9fd1d19e93 | |||
| 41bd76e091 | |||
| cfd3f4b199 | |||
| b3866559e1 | |||
| 8ed3d5f3db | |||
| f0c8f39b6d | |||
| 431db8fc9b | |||
| ba252c5356 | |||
| a2812c39c0 | |||
| 0490758820 | |||
| 7f56824b42 | |||
| 627da3a2bc | |||
| 9b36a5c8a6 | |||
| c1cf2be533 | |||
| e6b69042de | |||
| 109650faf3 |
+4
-1
@@ -17,4 +17,7 @@ ENV/
|
||||
.conda/
|
||||
README*.md
|
||||
dashboard/
|
||||
data/
|
||||
data/
|
||||
changelogs/
|
||||
tests/
|
||||
.ruff_cache/
|
||||
@@ -0,0 +1,15 @@
|
||||
# These are supported funding model platforms
|
||||
|
||||
github: # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
|
||||
patreon: # Replace with a single Patreon username
|
||||
open_collective: astrbot
|
||||
ko_fi: # Replace with a single Ko-fi username
|
||||
tidelift: # Replace with a single Tidelift platform-name/package-name e.g., npm/babel
|
||||
community_bridge: # Replace with a single Community Bridge project-name e.g., cloud-foundry
|
||||
liberapay: # Replace with a single Liberapay username
|
||||
issuehunt: # Replace with a single IssueHunt username
|
||||
lfx_crowdfunding: # Replace with a single LFX Crowdfunding project-name e.g., cloud-foundry
|
||||
polar: # Replace with a single Polar username
|
||||
buy_me_a_coffee: # Replace with a single Buy Me a Coffee username
|
||||
thanks_dev: # Replace with a single thanks.dev username
|
||||
custom: ['https://afdian.com/a/astrbot_team']
|
||||
@@ -1,5 +1,5 @@
|
||||
<!-- 如果有的话,指定这个 PR 要解决的 ISSUE -->
|
||||
修复了 #XYZ
|
||||
解决了 #XYZ
|
||||
|
||||
### Motivation
|
||||
|
||||
@@ -8,3 +8,12 @@
|
||||
### Modifications
|
||||
|
||||
<!--简单解释你的改动-->
|
||||
|
||||
### Check
|
||||
|
||||
<!--如果分支被合并,您的代码将服务于数万名用户!在提交前,请核查一下几点内容-->
|
||||
|
||||
- [ ] 😊 我的 Commit Message 符合良好的[规范](https://www.conventionalcommits.org/en/v1.0.0/#summary)
|
||||
- [ ] 👀 我的更改经过良好的测试
|
||||
- [ ] 🤓 我确保没有引入新依赖库,或者引入了新依赖库的同时将其添加到了 `requirements.txt` 和 `pyproject.toml` 文件相应位置。
|
||||
- [ ] 😮 我的更改没有引入恶意代码
|
||||
|
||||
@@ -7,7 +7,7 @@ on:
|
||||
name: Auto Release
|
||||
|
||||
jobs:
|
||||
build:
|
||||
build-and-publish-to-github-release:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
@@ -28,8 +28,35 @@ jobs:
|
||||
run: |
|
||||
echo "changelog=changelogs/${{github.ref_name}}.md" >> "$GITHUB_ENV"
|
||||
|
||||
- name: Create Release
|
||||
- name: Create GitHub Release
|
||||
uses: ncipollo/release-action@v1
|
||||
with:
|
||||
bodyFile: ${{ env.changelog }}
|
||||
artifacts: "dashboard/dist.zip"
|
||||
artifacts: "dashboard/dist.zip"
|
||||
|
||||
build-and-publish-to-pypi:
|
||||
# 构建并发布到 PyPI
|
||||
runs-on: ubuntu-latest
|
||||
needs: build-and-publish-to-github-release
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v4
|
||||
with:
|
||||
python-version: '3.10'
|
||||
|
||||
- name: Install uv
|
||||
run: |
|
||||
python -m pip install uv
|
||||
|
||||
- name: Build package
|
||||
run: |
|
||||
uv build
|
||||
|
||||
- name: Publish to PyPI
|
||||
env:
|
||||
UV_PUBLISH_TOKEN: ${{ secrets.PYPI_TOKEN }}
|
||||
run: |
|
||||
uv publish
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
3.10
|
||||
@@ -4,6 +4,8 @@ WORKDIR /AstrBot
|
||||
COPY . /AstrBot/
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
nodejs \
|
||||
npm \
|
||||
gcc \
|
||||
build-essential \
|
||||
python3-dev \
|
||||
@@ -28,3 +30,6 @@ EXPOSE 6185
|
||||
EXPOSE 6186
|
||||
|
||||
CMD [ "python", "main.py" ]
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -13,9 +13,12 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
||||
[](https://github.com/Soulter/AstrBot/releases/latest)
|
||||
<img src="https://img.shields.io/badge/python-3.10+-blue.svg?style=for-the-badge&color=76bad9" alt="python">
|
||||
<a href="https://hub.docker.com/r/soulter/astrbot"><img alt="Docker pull" src="https://img.shields.io/docker/pulls/soulter/astrbot.svg?style=for-the-badge&color=76bad9"/></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="Static Badge" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://qm.qq.com/cgi-bin/qm/qr?k=wtbaNx7EioxeaqS9z7RQWVXPIxg2zYr7&jump_from=webapi&authKey=vlqnv/AV2DbJEvGIcxdlNSpfxVy+8vVqijgreRdnVKOaydpc+YSw4MctmEbr0k5"><img alt="QQ_community" src="https://img.shields.io/badge/QQ群-775869627-purple?style=for-the-badge&color=76bad9"></a>
|
||||
<a href="https://t.me/+hAsD2Ebl5as3NmY1"><img alt="Telegram_community" src="https://img.shields.io/badge/Telegram-AstrBot-purple?style=for-the-badge&color=76bad9"></a>
|
||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||

|
||||

|
||||

|
||||
|
||||
|
||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_en.md">English</a> |
|
||||
<a href="https://github.com/Soulter/AstrBot/blob/master/README_ja.md">日本語</a> |
|
||||
@@ -149,6 +152,8 @@ pre-commit install
|
||||
|
||||
## ✨ Demo
|
||||
|
||||
<details><summary>👉 点击展开多张 Demo 截图 👈</summary>
|
||||
|
||||
<div align='center'>
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
|
||||
@@ -170,6 +175,9 @@ _✨ WebUI ✨_
|
||||
|
||||
</div>
|
||||
|
||||
</details>
|
||||
|
||||
|
||||
## ❤️ Special Thanks
|
||||
|
||||
特别感谢所有 Contributors 和插件开发者对 AstrBot 的贡献 ❤️
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from astrbot.core.provider import Provider, STTProvider, Personality
|
||||
from astrbot.core.provider.entites import (
|
||||
from astrbot.core.provider.entities import (
|
||||
ProviderRequest,
|
||||
ProviderType,
|
||||
ProviderMetaData,
|
||||
|
||||
@@ -0,0 +1,238 @@
|
||||
import asyncio
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import click
|
||||
from pathlib import Path
|
||||
from astrbot.core.config.default import VERSION
|
||||
|
||||
|
||||
logo_tmpl = r"""
|
||||
___ _______.___________..______ .______ ______ .___________.
|
||||
/ \ / | || _ \ | _ \ / __ \ | |
|
||||
/ ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----`
|
||||
/ /_\ \ \ \ | | | / | _ < | | | | | |
|
||||
/ _____ \ .----) | | | | |\ \----.| |_) | | `--' | | |
|
||||
/__/ \__\ |_______/ |__| | _| `._____||______/ \______/ |__|
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# utils
|
||||
def _get_astrbot_root(path: str | None) -> Path:
|
||||
"""获取astrbot根目录"""
|
||||
match path:
|
||||
case None:
|
||||
match ASTRBOT_ROOT := os.getenv("ASTRBOT_ROOT"):
|
||||
case None:
|
||||
astrbot_root = Path.cwd() / "data"
|
||||
case _:
|
||||
astrbot_root = Path(ASTRBOT_ROOT).resolve()
|
||||
case str():
|
||||
astrbot_root = Path(path).resolve()
|
||||
|
||||
dot_astrbot = astrbot_root / ".astrbot"
|
||||
if not dot_astrbot.exists():
|
||||
if click.confirm(
|
||||
f"运行前必须先执行初始化!请检查当前目录是否正确,回车以继续: {astrbot_root}",
|
||||
default=True,
|
||||
abort=True,
|
||||
):
|
||||
dot_astrbot.touch()
|
||||
astrbot_root.mkdir(parents=True, exist_ok=True)
|
||||
click.echo(f"Created {dot_astrbot}")
|
||||
|
||||
return astrbot_root
|
||||
|
||||
|
||||
# 通过类型来验证先后,必须先获取 Path 对象才能对该目录进行检查
|
||||
def _check_astrbot_root(astrbot_root: Path) -> None:
|
||||
"""验证"""
|
||||
dot_astrbot = astrbot_root / ".astrbot"
|
||||
if not astrbot_root.exists():
|
||||
click.echo(f"AstrBot root directory does not exist: {astrbot_root}")
|
||||
click.echo("Please run 'astrbot init' to create the directory.")
|
||||
sys.exit(1)
|
||||
else:
|
||||
click.echo(f"AstrBot root directory exists: {astrbot_root}")
|
||||
if not dot_astrbot.exists():
|
||||
click.echo(
|
||||
"如果你确认这是 Astrbot root directory, 你需要在当前目录下创建一个 .astrbot 文件标记该目录为 AstrBot 的数据目录。"
|
||||
)
|
||||
if click.confirm(
|
||||
f"请检查当前目录是否正确,确认正确请回车: {astrbot_root}",
|
||||
default=True,
|
||||
abort=True,
|
||||
):
|
||||
dot_astrbot.touch()
|
||||
click.echo(f"Created {dot_astrbot}")
|
||||
else:
|
||||
click.echo(f"Welcome back! AstrBot root directory: {astrbot_root}")
|
||||
|
||||
|
||||
async def _check_dashboard(astrbot_root: Path) -> None:
|
||||
"""检查是否安装了dashboard"""
|
||||
try:
|
||||
from ..core.utils.io import get_dashboard_version, download_dashboard
|
||||
except ImportError:
|
||||
from astrbot.core.utils.io import get_dashboard_version, download_dashboard
|
||||
|
||||
try:
|
||||
# 添加 create=True 参数以确保在初始化时不会抛出异常
|
||||
dashboard_version = await get_dashboard_version()
|
||||
match dashboard_version:
|
||||
case None:
|
||||
click.echo("未安装管理面板")
|
||||
if click.confirm(
|
||||
"是否安装管理面板?",
|
||||
default=True,
|
||||
abort=True,
|
||||
):
|
||||
click.echo("正在安装管理面板...")
|
||||
# 确保使用 create=True 参数
|
||||
await download_dashboard(
|
||||
path="data/dashboard.zip", extract_path=str(astrbot_root)
|
||||
)
|
||||
click.echo("管理面板安装完成")
|
||||
|
||||
case str():
|
||||
if dashboard_version == f"v{VERSION}":
|
||||
click.echo("无需更新")
|
||||
else:
|
||||
try:
|
||||
version = dashboard_version.split("v")[1]
|
||||
click.echo(f"管理面板版本: {version}")
|
||||
# 确保使用 create=True 参数
|
||||
await download_dashboard(
|
||||
path="data/dashboard.zip", extract_path=str(astrbot_root)
|
||||
)
|
||||
except Exception as e:
|
||||
click.echo(f"下载管理面板失败: {e}")
|
||||
return
|
||||
except FileNotFoundError:
|
||||
click.echo("初始化管理面板目录...")
|
||||
# 初始化模式下,下载到指定位置
|
||||
try:
|
||||
await download_dashboard(
|
||||
path=str(astrbot_root / "dashboard.zip"), extract_path=str(astrbot_root)
|
||||
)
|
||||
click.echo("管理面板初始化完成")
|
||||
except Exception as e:
|
||||
click.echo(f"下载管理面板失败: {e}")
|
||||
return
|
||||
|
||||
|
||||
@click.group(name="astrbot")
|
||||
def cli() -> None:
|
||||
"""The AstrBot CLI"""
|
||||
click.echo(logo_tmpl)
|
||||
click.echo("Welcome to AstrBot CLI!")
|
||||
click.echo(f"AstrBot version: {VERSION}")
|
||||
|
||||
|
||||
# region init
|
||||
@cli.command()
|
||||
@click.option("--path", "-p", help="AstrBot 数据目录")
|
||||
@click.option("--force", "-f", is_flag=True, help="强制初始化")
|
||||
def init(path: str | None, force: bool) -> None:
|
||||
"""Initialize AstrBot"""
|
||||
click.echo("Initializing AstrBot...")
|
||||
astrbot_root = _get_astrbot_root(path)
|
||||
if force:
|
||||
if click.confirm(
|
||||
"强制初始化会删除当前目录下的所有文件,是否继续?",
|
||||
default=False,
|
||||
abort=True,
|
||||
):
|
||||
click.echo("正在删除当前目录下的所有文件...")
|
||||
shutil.rmtree(astrbot_root, ignore_errors=True)
|
||||
|
||||
_check_astrbot_root(astrbot_root)
|
||||
|
||||
click.echo(f"AstrBot root directory: {astrbot_root}")
|
||||
|
||||
if not astrbot_root.exists():
|
||||
# 创建目录
|
||||
astrbot_root.mkdir(parents=True, exist_ok=True)
|
||||
click.echo(f"Created directory: {astrbot_root}")
|
||||
else:
|
||||
click.echo(f"Directory already exists: {astrbot_root}")
|
||||
|
||||
config_path: Path = astrbot_root / "config"
|
||||
plugins_path: Path = astrbot_root / "plugins"
|
||||
temp_path: Path = astrbot_root / "temp"
|
||||
config_path.mkdir(parents=True, exist_ok=True)
|
||||
plugins_path.mkdir(parents=True, exist_ok=True)
|
||||
temp_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
click.echo(f"Created directories: {config_path}, {plugins_path}, {temp_path}")
|
||||
|
||||
# 检查是否安装了dashboard
|
||||
asyncio.run(_check_dashboard(astrbot_root))
|
||||
|
||||
|
||||
# region run
|
||||
@cli.command()
|
||||
@click.option("--path", "-p", help="AstrBot 数据目录")
|
||||
def run(path: str | None = None) -> None:
|
||||
"""Run AstrBot"""
|
||||
# 解析为绝对路径
|
||||
try:
|
||||
from ..core.log import LogBroker
|
||||
from ..core import db_helper
|
||||
from ..core.initial_loader import InitialLoader
|
||||
except ImportError:
|
||||
from astrbot.core.log import LogBroker
|
||||
from astrbot.core import db_helper
|
||||
from astrbot.core.initial_loader import InitialLoader
|
||||
|
||||
astrbot_root = _get_astrbot_root(path)
|
||||
|
||||
_check_astrbot_root(astrbot_root)
|
||||
|
||||
asyncio.run(_check_dashboard(astrbot_root))
|
||||
|
||||
log_broker = LogBroker()
|
||||
db = db_helper
|
||||
|
||||
core_lifecycle = InitialLoader(db, log_broker)
|
||||
try:
|
||||
asyncio.run(core_lifecycle.start())
|
||||
except KeyboardInterrupt:
|
||||
click.echo("接收到退出信号,正在关闭 AstrBot...")
|
||||
except Exception as e:
|
||||
click.echo(f"运行时出现错误: {e}")
|
||||
|
||||
|
||||
# region Basic
|
||||
@cli.command(name="version")
|
||||
def version() -> None:
|
||||
"""Show the version of AstrBot"""
|
||||
click.echo(f"AstrBot version: {VERSION}")
|
||||
|
||||
|
||||
@cli.command()
|
||||
@click.argument("command_name", required=False, type=str)
|
||||
def help(command_name: str | None) -> None:
|
||||
"""Show help information for commands
|
||||
|
||||
If COMMAND_NAME is provided, show detailed help for that command.
|
||||
Otherwise, show general help information.
|
||||
"""
|
||||
ctx = click.get_current_context()
|
||||
if command_name:
|
||||
# 查找指定命令
|
||||
command = cli.get_command(ctx, command_name)
|
||||
if command:
|
||||
# 显示特定命令的帮助信息
|
||||
click.echo(command.get_help(ctx))
|
||||
else:
|
||||
click.echo(f"Unknown command: {command_name}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
# 显示通用帮助信息
|
||||
click.echo(cli.get_help(ctx))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
@@ -23,7 +23,10 @@ db_helper = SQLiteDatabase(DB_PATH)
|
||||
sp = (
|
||||
SharedPreferences()
|
||||
) # 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中
|
||||
pip_installer = PipInstaller(astrbot_config.get("pip_install_arg", ""))
|
||||
pip_installer = PipInstaller(
|
||||
astrbot_config.get("pip_install_arg", ""),
|
||||
astrbot_config.get("pypi_index_url", None),
|
||||
)
|
||||
web_chat_queue = asyncio.Queue(maxsize=32)
|
||||
web_chat_back_queue = asyncio.Queue(maxsize=32)
|
||||
WEBUI_SK = "Advanced_System_for_Text_Response_and_Bot_Operations_Tool"
|
||||
|
||||
+100
-14
@@ -2,7 +2,7 @@
|
||||
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
|
||||
"""
|
||||
|
||||
VERSION = "3.5.2"
|
||||
VERSION = "3.5.7"
|
||||
DB_PATH = "data/data_v3.db"
|
||||
|
||||
# 默认配置
|
||||
@@ -38,6 +38,7 @@ DEFAULT_CONFIG = {
|
||||
"no_permission_reply": True,
|
||||
"empty_mention_waiting": True,
|
||||
"friend_message_needs_wake_prefix": False,
|
||||
"ignore_bot_self_message": False,
|
||||
},
|
||||
"provider": [],
|
||||
"provider_settings": {
|
||||
@@ -50,6 +51,9 @@ DEFAULT_CONFIG = {
|
||||
"default_personality": "default",
|
||||
"prompt_prefix": "",
|
||||
"max_context_length": -1,
|
||||
"dequeue_context_length": 1,
|
||||
"streaming_response": False,
|
||||
"streaming_segmented": False,
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
"enable": False,
|
||||
@@ -58,6 +62,7 @@ DEFAULT_CONFIG = {
|
||||
"provider_tts_settings": {
|
||||
"enable": False,
|
||||
"provider_id": "",
|
||||
"dual_output": False,
|
||||
},
|
||||
"provider_ltm_settings": {
|
||||
"group_icl_enable": False,
|
||||
@@ -95,7 +100,7 @@ DEFAULT_CONFIG = {
|
||||
"wake_prefix": ["/"],
|
||||
"log_level": "INFO",
|
||||
"pip_install_arg": "",
|
||||
"plugin_repo_mirror": "",
|
||||
"pypi_index_url": "https://mirrors.aliyun.com/pypi/simple/",
|
||||
"knowledge_db": {},
|
||||
"persona": [],
|
||||
"timezone": "",
|
||||
@@ -135,6 +140,7 @@ CONFIG_METADATA_2 = {
|
||||
"enable": False,
|
||||
"ws_reverse_host": "0.0.0.0",
|
||||
"ws_reverse_port": 6199,
|
||||
"ws_reverse_token": "",
|
||||
},
|
||||
"gewechat(微信)": {
|
||||
"id": "gwchat",
|
||||
@@ -153,6 +159,7 @@ CONFIG_METADATA_2 = {
|
||||
"secret": "",
|
||||
"token": "",
|
||||
"encoding_aes_key": "",
|
||||
"kf_name": "",
|
||||
"api_base_url": "https://qyapi.weixin.qq.com/cgi-bin/",
|
||||
"callback_server_host": "0.0.0.0",
|
||||
"port": 6195,
|
||||
@@ -181,14 +188,37 @@ CONFIG_METADATA_2 = {
|
||||
"start_message": "Hello, I'm AstrBot!",
|
||||
"telegram_api_base_url": "https://api.telegram.org/bot",
|
||||
"telegram_file_base_url": "https://api.telegram.org/file/bot",
|
||||
"telegram_command_register": True,
|
||||
"telegram_command_auto_refresh": True,
|
||||
"telegram_command_register_interval": 300,
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"kf_name": {
|
||||
"description": "微信客服账号名",
|
||||
"type": "string",
|
||||
"hint": "可选。微信客服账号名(不是 ID)。可在 https://kf.weixin.qq.com/kf/frame#/accounts 获取"
|
||||
},
|
||||
"telegram_token": {
|
||||
"description": "Bot Token",
|
||||
"type": "string",
|
||||
"hint": "如果你的网络环境为中国大陆,请在 `其他配置` 处设置代理或更改 api_base。",
|
||||
},
|
||||
"telegram_command_register": {
|
||||
"description": "Telegram 命令注册",
|
||||
"type": "bool",
|
||||
"hint": "启用后,AstrBot 将会自动注册 Telegram 命令。",
|
||||
},
|
||||
"telegram_command_auto_refresh": {
|
||||
"description": "Telegram 命令自动刷新",
|
||||
"type": "bool",
|
||||
"hint": "启用后,AstrBot 将会在运行时自动刷新 Telegram 命令。(单独设置此项无效)",
|
||||
},
|
||||
"telegram_command_register_interval": {
|
||||
"description": "Telegram 命令自动刷新间隔",
|
||||
"type": "int",
|
||||
"hint": "Telegram 命令自动刷新间隔,单位为秒。",
|
||||
},
|
||||
"id": {
|
||||
"description": "ID",
|
||||
"type": "string",
|
||||
@@ -213,7 +243,7 @@ CONFIG_METADATA_2 = {
|
||||
"secret": {
|
||||
"description": "secret",
|
||||
"type": "string",
|
||||
"hint": "必填项。QQ 官方机器人平台的 secret。如何获取请参考文档。",
|
||||
"hint": "必填项。",
|
||||
},
|
||||
"enable_group_c2c": {
|
||||
"description": "启用消息列表单聊",
|
||||
@@ -235,6 +265,11 @@ CONFIG_METADATA_2 = {
|
||||
"type": "int",
|
||||
"hint": "aiocqhttp 适配器的反向 Websocket 端口。",
|
||||
},
|
||||
"ws_reverse_token": {
|
||||
"description": "反向 Websocket Token",
|
||||
"type": "string",
|
||||
"hint": "aiocqhttp 适配器的反向 Websocket Token。未设置则不启用 Token 验证。",
|
||||
},
|
||||
"lark_bot_name": {
|
||||
"description": "飞书机器人的名字",
|
||||
"type": "string",
|
||||
@@ -247,6 +282,9 @@ CONFIG_METADATA_2 = {
|
||||
"description": "平台设置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"plugin_enable": {
|
||||
"invisible": True, # 隐藏插件启用配置
|
||||
},
|
||||
"unique_session": {
|
||||
"description": "会话隔离",
|
||||
"type": "bool",
|
||||
@@ -282,6 +320,11 @@ CONFIG_METADATA_2 = {
|
||||
"type": "bool",
|
||||
"hint": "启用后,私聊消息需要唤醒前缀才会被处理,同群聊一样。",
|
||||
},
|
||||
"ignore_bot_self_message": {
|
||||
"description": "是否忽略机器人自身的消息",
|
||||
"type": "bool",
|
||||
"hint": "某些平台如 gewechat 会将自身账号在其他 APP 端发送的消息也当做消息事件下发导致给自己发消息时唤醒机器人",
|
||||
},
|
||||
"segmented_reply": {
|
||||
"description": "分段回复",
|
||||
"type": "object",
|
||||
@@ -523,12 +566,17 @@ CONFIG_METADATA_2 = {
|
||||
"model": "gemini-2.0-flash-exp",
|
||||
},
|
||||
"gm_resp_image_modal": False,
|
||||
"gm_native_search": False,
|
||||
"gm_native_coderunner": False,
|
||||
"gm_safety_settings": {
|
||||
"harassment": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"hate_speech": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"sexually_explicit": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
"dangerous_content": "BLOCK_MEDIUM_AND_ABOVE",
|
||||
},
|
||||
"gm_thinking_config": {
|
||||
"budget": 0,
|
||||
},
|
||||
},
|
||||
"DeepSeek": {
|
||||
"id": "deepseek_default",
|
||||
@@ -699,6 +747,18 @@ CONFIG_METADATA_2 = {
|
||||
"type": "bool",
|
||||
"hint": "启用后,将支持返回图片内容。需要模型支持,否则会报错。具体支持模型请查看 Google Gemini 官方网站。温馨提示,如果您需要生成图片,请关闭 `启用群员识别` 配置获得更好的效果。",
|
||||
},
|
||||
"gm_native_search": {
|
||||
"description": "启用原生搜索功能",
|
||||
"type": "bool",
|
||||
"hint": "启用后所有函数工具将全部失效,免费次数限制请查阅官方文档",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"gm_native_coderunner": {
|
||||
"description": "启用原生代码执行器",
|
||||
"type": "bool",
|
||||
"hint": "启用后所有函数工具将全部失效",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"gm_safety_settings": {
|
||||
"description": "安全过滤器",
|
||||
"type": "object",
|
||||
@@ -750,6 +810,17 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
},
|
||||
},
|
||||
"gm_thinking_config": {
|
||||
"description": "Gemini思考设置",
|
||||
"type": "object",
|
||||
"items": {
|
||||
"budget": {
|
||||
"description": "思考预算",
|
||||
"type": "int",
|
||||
"hint": "模型应该生成的思考Token的数量,设为0关闭思考。除gemini-2.5-flash外的模型会静默忽略此参数。",
|
||||
},
|
||||
},
|
||||
},
|
||||
"rag_options": {
|
||||
"description": "RAG 选项",
|
||||
"type": "object",
|
||||
@@ -923,8 +994,8 @@ CONFIG_METADATA_2 = {
|
||||
"dify_api_type": {
|
||||
"description": "Dify 应用类型",
|
||||
"type": "string",
|
||||
"hint": "Dify API 类型。根据 Dify 官网,目前支持 chat, agent, workflow 三种应用类型",
|
||||
"options": ["chat", "agent", "workflow"],
|
||||
"hint": "Dify API 类型。根据 Dify 官网,目前支持 chat, chatflow, agent, workflow 三种应用类型。",
|
||||
"options": ["chat", "chatflow", "agent", "workflow"],
|
||||
},
|
||||
"dify_workflow_output_key": {
|
||||
"description": "Dify Workflow 输出变量名",
|
||||
@@ -993,6 +1064,21 @@ CONFIG_METADATA_2 = {
|
||||
"type": "int",
|
||||
"hint": "超出这个数量时将丢弃最旧的部分,用户和AI的一轮聊天记为 1 条。-1 表示不限制,默认为不限制。",
|
||||
},
|
||||
"dequeue_context_length": {
|
||||
"description": "丢弃对话数量(条)",
|
||||
"type": "int",
|
||||
"hint": "超出 最多携带对话数量(条) 时,丢弃多少条记录,用户和AI的一轮聊天记为 1 条。适宜的配置,可以提高超长上下文对话 deepseek 命中缓存效果,理想情况下计费将降低到1/3以下",
|
||||
},
|
||||
"streaming_response": {
|
||||
"description": "启用流式回复",
|
||||
"type": "bool",
|
||||
"hint": "启用后,将会流式输出 LLM 的响应。目前仅支持 OpenAI API提供商 以及 Telegram、QQ Official 私聊 两个平台",
|
||||
},
|
||||
"streaming_segmented": {
|
||||
"description": "不支持流式回复的平台分段输出",
|
||||
"type": "bool",
|
||||
"hint": "启用后,若平台不支持流式回复,会分段输出。目前仅支持 aiocqhttp 和 gewechat 两个平台,不支持或无需使用流式分段输出的平台会静默忽略此选项",
|
||||
},
|
||||
},
|
||||
},
|
||||
"persona": {
|
||||
@@ -1067,6 +1153,12 @@ CONFIG_METADATA_2 = {
|
||||
"type": "string",
|
||||
"hint": "文本转语音提供商 ID。如果不填写将使用载入的第一个提供商。",
|
||||
},
|
||||
"dual_output": {
|
||||
"description": "启用语音和文字双输出",
|
||||
"type": "bool",
|
||||
"hint": "启用后,Bot 将同时输出语音和文字消息。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_ltm_settings": {
|
||||
@@ -1201,16 +1293,10 @@ CONFIG_METADATA_2 = {
|
||||
"type": "string",
|
||||
"hint": "安装插件依赖时,会使用 Python 的 pip 工具。这里可以填写额外的参数,如 `--break-system-package` 等。",
|
||||
},
|
||||
"plugin_repo_mirror": {
|
||||
"description": "插件仓库镜像",
|
||||
"pypi_index_url": {
|
||||
"description": "PyPI 软件仓库地址",
|
||||
"type": "string",
|
||||
"hint": "已废弃,请使用管理面板->设置页的代理地址选择",
|
||||
"obvious_hint": True,
|
||||
"options": [
|
||||
"default",
|
||||
"https://ghp.ci/",
|
||||
"https://github-mirror.us.kg/",
|
||||
],
|
||||
"hint": "安装 Python 依赖时请求的 PyPI 软件仓库地址。默认为 https://mirrors.aliyun.com/pypi/simple/",
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -175,7 +175,15 @@ class ConversationManager:
|
||||
if record["role"] == "user":
|
||||
temp_contexts.append(f"User: {record['content']}")
|
||||
elif record["role"] == "assistant":
|
||||
temp_contexts.append(f"Assistant: {record['content']}")
|
||||
if "content" in record and record["content"]:
|
||||
temp_contexts.append(f"Assistant: {record['content']}")
|
||||
elif "tool_calls" in record:
|
||||
tool_calls_str = json.dumps(
|
||||
record["tool_calls"], ensure_ascii=False
|
||||
)
|
||||
temp_contexts.append(f"Assistant: [函数调用] {tool_calls_str}")
|
||||
else:
|
||||
temp_contexts.append("Assistant: [未知的内容]")
|
||||
contexts.insert(0, temp_contexts)
|
||||
temp_contexts = []
|
||||
|
||||
|
||||
@@ -106,7 +106,7 @@ class AstrBotCoreLifecycle:
|
||||
await self.pipeline_scheduler.initialize()
|
||||
|
||||
# 初始化更新器
|
||||
self.astrbot_updator = AstrBotUpdator(self.astrbot_config["plugin_repo_mirror"])
|
||||
self.astrbot_updator = AstrBotUpdator()
|
||||
|
||||
# 初始化事件总线
|
||||
self.event_bus = EventBus(self.event_queue, self.pipeline_scheduler)
|
||||
|
||||
+11
-6
@@ -25,6 +25,7 @@ import logging
|
||||
import colorlog
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from collections import deque
|
||||
from asyncio import Queue
|
||||
from typing import List
|
||||
@@ -141,11 +142,13 @@ class LogQueueHandler(logging.Handler):
|
||||
record (logging.LogRecord): 日志记录对象, 包含日志信息
|
||||
"""
|
||||
log_entry = self.format(record)
|
||||
self.log_broker.publish({
|
||||
"level": record.levelname,
|
||||
"time": record.asctime,
|
||||
"data": log_entry,
|
||||
})
|
||||
self.log_broker.publish(
|
||||
{
|
||||
"level": record.levelname,
|
||||
"time": record.asctime,
|
||||
"data": log_entry,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class LogManager:
|
||||
@@ -169,7 +172,9 @@ class LogManager:
|
||||
if logger.hasHandlers():
|
||||
return logger
|
||||
# 如果logger没有处理器
|
||||
console_handler = logging.StreamHandler() # 创建一个StreamHandler用于控制台输出
|
||||
console_handler = logging.StreamHandler(
|
||||
sys.stdout
|
||||
) # 创建一个StreamHandler用于控制台输出
|
||||
console_handler.setLevel(
|
||||
logging.DEBUG
|
||||
) # 将日志级别设置为DEBUG(最低级别, 显示所有日志), *如果插件没有设置级别, 默认为DEBUG
|
||||
|
||||
@@ -26,10 +26,12 @@ import base64
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
import asyncio
|
||||
import typing as T
|
||||
from enum import Enum
|
||||
from pydantic.v1 import BaseModel
|
||||
from astrbot.core.utils.io import download_image_by_url, file_to_base64
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.io import download_image_by_url, file_to_base64, download_file
|
||||
|
||||
|
||||
class ComponentType(Enum):
|
||||
@@ -193,6 +195,7 @@ class Record(BaseMessageComponent):
|
||||
bs64_data = file_to_base64(self.file)
|
||||
else:
|
||||
raise Exception(f"not a valid file: {self.file}")
|
||||
bs64_data = bs64_data.removeprefix("base64://")
|
||||
return bs64_data
|
||||
|
||||
|
||||
@@ -397,6 +400,7 @@ class Image(BaseMessageComponent):
|
||||
bs64_data = file_to_base64(url)
|
||||
else:
|
||||
raise Exception(f"not a valid file: {url}")
|
||||
bs64_data = bs64_data.removeprefix("base64://")
|
||||
return bs64_data
|
||||
|
||||
|
||||
@@ -405,17 +409,15 @@ class Reply(BaseMessageComponent):
|
||||
id: T.Union[str, int]
|
||||
"""所引用的消息 ID"""
|
||||
chain: T.Optional[T.List["BaseMessageComponent"]] = []
|
||||
"""引用的消息段列表"""
|
||||
"""被引用的消息段列表"""
|
||||
sender_id: T.Optional[int] | T.Optional[str] = 0
|
||||
"""引用的消息发送者 ID"""
|
||||
"""被引用的消息对应的发送者的 ID"""
|
||||
sender_nickname: T.Optional[str] = ""
|
||||
"""引用的消息发送者昵称"""
|
||||
"""被引用的消息对应的发送者的昵称"""
|
||||
time: T.Optional[int] = 0
|
||||
"""引用的消息发送时间"""
|
||||
"""被引用的消息发送时间"""
|
||||
message_str: T.Optional[str] = ""
|
||||
"""解析后的纯文本消息字符串"""
|
||||
sender_str: T.Optional[str] = ""
|
||||
"""被引用的消息纯文本"""
|
||||
"""被引用的消息解析后的纯文本消息字符串"""
|
||||
|
||||
text: T.Optional[str] = ""
|
||||
"""deprecated"""
|
||||
@@ -552,15 +554,91 @@ class Unknown(BaseMessageComponent):
|
||||
|
||||
class File(BaseMessageComponent):
|
||||
"""
|
||||
目前此消息段只适配了 Napcat。
|
||||
文件消息段
|
||||
"""
|
||||
|
||||
type: ComponentType = "File"
|
||||
name: T.Optional[str] = "" # 名字
|
||||
file: T.Optional[str] = "" # url(本地路径)
|
||||
_file: T.Optional[str] = "" # 本地路径
|
||||
url: T.Optional[str] = "" # url
|
||||
_downloaded: bool = False # 是否已经下载
|
||||
|
||||
def __init__(self, name: str, file: str):
|
||||
super().__init__(name=name, file=file)
|
||||
def __init__(self, name: str = "", file: str = "", url: str = ""):
|
||||
super().__init__(name=name, _file=file, url=url)
|
||||
|
||||
@property
|
||||
def file(self) -> str:
|
||||
"""
|
||||
获取文件路径,如果文件不存在但有URL,则同步下载文件
|
||||
|
||||
Returns:
|
||||
str: 文件路径
|
||||
"""
|
||||
if self._file and os.path.exists(self._file):
|
||||
return self._file
|
||||
|
||||
if self.url and not self._downloaded:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
logger.warning(
|
||||
"不可以在异步上下文中同步等待下载! 请使用 await get_file() 代替"
|
||||
)
|
||||
return ""
|
||||
else:
|
||||
# 等待下载完成
|
||||
loop.run_until_complete(self._download_file())
|
||||
|
||||
if self._file and os.path.exists(self._file):
|
||||
return self._file
|
||||
except Exception as e:
|
||||
logger.error(f"文件下载失败: {e}")
|
||||
|
||||
return ""
|
||||
|
||||
@file.setter
|
||||
def file(self, value: str):
|
||||
"""
|
||||
向前兼容, 设置file属性, 传入的参数可能是文件路径或URL
|
||||
|
||||
Args:
|
||||
value (str): 文件路径或URL
|
||||
"""
|
||||
if value.startswith("http://") or value.startswith("https://"):
|
||||
self.url = value
|
||||
else:
|
||||
self._file = value
|
||||
|
||||
async def get_file(self) -> str:
|
||||
"""
|
||||
异步获取文件
|
||||
To 插件开发者: 请注意在使用后清理下载的文件, 以免占用过多空间
|
||||
|
||||
Returns:
|
||||
str: 文件路径
|
||||
"""
|
||||
if self._file and os.path.exists(self._file):
|
||||
return self._file
|
||||
|
||||
if self.url:
|
||||
await self._download_file()
|
||||
return self._file
|
||||
|
||||
return ""
|
||||
|
||||
async def _download_file(self):
|
||||
"""下载文件"""
|
||||
if self._downloaded:
|
||||
return
|
||||
|
||||
os.makedirs("data/download", exist_ok=True)
|
||||
filename = self.name or f"{uuid.uuid4().hex}"
|
||||
file_path = f"data/download/{filename}"
|
||||
|
||||
await download_file(self.url, file_path)
|
||||
|
||||
self._file = file_path
|
||||
self._downloaded = True
|
||||
|
||||
|
||||
class WechatEmoji(BaseMessageComponent):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import enum
|
||||
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Union, AsyncGenerator
|
||||
from dataclasses import dataclass, field
|
||||
from astrbot.core.message.components import (
|
||||
BaseMessageComponent,
|
||||
@@ -111,6 +111,30 @@ class MessageChain:
|
||||
"""获取纯文本消息。这个方法将获取 chain 中所有 Plain 组件的文本并拼接成一条消息。空格分隔。"""
|
||||
return " ".join([comp.text for comp in self.chain if isinstance(comp, Plain)])
|
||||
|
||||
def squash_plain(self):
|
||||
"""将消息链中的所有 Plain 消息段聚合到第一个 Plain 消息段中。"""
|
||||
if not self.chain:
|
||||
return
|
||||
|
||||
new_chain = []
|
||||
first_plain = None
|
||||
plain_texts = []
|
||||
|
||||
for comp in self.chain:
|
||||
if isinstance(comp, Plain):
|
||||
if first_plain is None:
|
||||
first_plain = comp
|
||||
new_chain.append(comp)
|
||||
plain_texts.append(comp.text)
|
||||
else:
|
||||
new_chain.append(comp)
|
||||
|
||||
if first_plain is not None:
|
||||
first_plain.text = "".join(plain_texts)
|
||||
|
||||
self.chain = new_chain
|
||||
return self
|
||||
|
||||
|
||||
class EventResultType(enum.Enum):
|
||||
"""用于描述事件处理的结果类型。
|
||||
@@ -131,6 +155,10 @@ class ResultContentType(enum.Enum):
|
||||
"""调用 LLM 产生的结果"""
|
||||
GENERAL_RESULT = enum.auto()
|
||||
"""普通的消息结果"""
|
||||
STREAMING_RESULT = enum.auto()
|
||||
"""调用 LLM 产生的流式结果"""
|
||||
STREAMING_FINISH= enum.auto()
|
||||
"""流式输出完成"""
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -152,6 +180,9 @@ class MessageEventResult(MessageChain):
|
||||
default_factory=lambda: ResultContentType.GENERAL_RESULT
|
||||
)
|
||||
|
||||
async_stream: Optional[AsyncGenerator] = None
|
||||
"""异步流"""
|
||||
|
||||
def stop_event(self) -> "MessageEventResult":
|
||||
"""终止事件传播。"""
|
||||
self.result_type = EventResultType.STOP
|
||||
@@ -168,6 +199,11 @@ class MessageEventResult(MessageChain):
|
||||
"""
|
||||
return self.result_type == EventResultType.STOP
|
||||
|
||||
def set_async_stream(self, stream: AsyncGenerator) -> "MessageEventResult":
|
||||
"""设置异步流。"""
|
||||
self.async_stream = stream
|
||||
return self
|
||||
|
||||
def set_result_content_type(self, typ: ResultContentType) -> "MessageEventResult":
|
||||
"""设置事件处理的结果类型。
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from .waking_check.stage import WakingCheckStage
|
||||
from .whitelist_check.stage import WhitelistCheckStage
|
||||
from .rate_limit_check.stage import RateLimitStage
|
||||
from .content_safety_check.stage import ContentSafetyCheckStage
|
||||
from .platform_compatibility.stage import PlatformCompatibilityStage
|
||||
from .preprocess_stage.stage import PreProcessStage
|
||||
from .process_stage.stage import ProcessStage
|
||||
from .result_decorate.stage import ResultDecorateStage
|
||||
@@ -18,6 +19,7 @@ STAGES_ORDER = [
|
||||
"WhitelistCheckStage", # 检查是否在群聊/私聊白名单
|
||||
"RateLimitStage", # 检查会话是否超过频率限制
|
||||
"ContentSafetyCheckStage", # 检查内容安全
|
||||
"PlatformCompatibilityStage", # 检查所有处理器的平台兼容性
|
||||
"PreProcessStage", # 预处理
|
||||
"ProcessStage", # 交由 Stars 处理(a.k.a 插件),或者 LLM 调用
|
||||
"ResultDecorateStage", # 处理结果,比如添加回复前缀、t2i、转换为语音 等
|
||||
@@ -29,6 +31,7 @@ __all__ = [
|
||||
"WhitelistCheckStage",
|
||||
"RateLimitStage",
|
||||
"ContentSafetyCheckStage",
|
||||
"PlatformCompatibilityStage",
|
||||
"PreProcessStage",
|
||||
"ProcessStage",
|
||||
"ResultDecorateStage",
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from typing import Union, AsyncGenerator
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
@register_stage
|
||||
class PlatformCompatibilityStage(Stage):
|
||||
"""检查所有处理器的平台兼容性。
|
||||
|
||||
这个阶段会检查所有处理器是否在当前平台启用,如果未启用则设置platform_compatible属性为False。
|
||||
"""
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
"""初始化平台兼容性检查阶段
|
||||
|
||||
Args:
|
||||
ctx (PipelineContext): 消息管道上下文对象, 包括配置和插件管理器
|
||||
"""
|
||||
self.ctx = ctx
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
# 获取当前平台ID
|
||||
platform_id = event.get_platform_id()
|
||||
|
||||
# 获取已激活的处理器
|
||||
activated_handlers = event.get_extra("activated_handlers")
|
||||
if activated_handlers is None:
|
||||
activated_handlers = []
|
||||
|
||||
# 标记不兼容的处理器
|
||||
for handler in activated_handlers:
|
||||
if not isinstance(handler, StarHandlerMetadata):
|
||||
continue
|
||||
# 检查处理器是否在当前平台启用
|
||||
enabled = handler.is_enabled_for_platform(platform_id)
|
||||
if not enabled:
|
||||
if handler.handler_module_path in star_map:
|
||||
plugin_name = star_map[handler.handler_module_path].name
|
||||
logger.debug(
|
||||
f"[PlatformCompatibilityStage] 插件 {plugin_name} 在平台 {platform_id} 未启用,标记处理器 {handler.handler_name} 为平台不兼容"
|
||||
)
|
||||
# 设置处理器为平台不兼容状态
|
||||
# TODO: 更好的标记方式
|
||||
handler.platform_compatible = False
|
||||
else:
|
||||
# 确保处理器为平台兼容状态
|
||||
handler.platform_compatible = True
|
||||
|
||||
# 更新已激活的处理器列表
|
||||
event.set_extra("activated_handlers", activated_handlers)
|
||||
@@ -12,11 +12,12 @@ from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import (
|
||||
MessageEventResult,
|
||||
ResultContentType,
|
||||
MessageChain,
|
||||
)
|
||||
from astrbot.core.message.components import Image
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from astrbot.core.provider.entites import (
|
||||
from astrbot.core.provider.entities import (
|
||||
ProviderRequest,
|
||||
LLMResponse,
|
||||
ToolCallMessageSegment,
|
||||
@@ -25,6 +26,13 @@ from astrbot.core.provider.entites import (
|
||||
)
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from mcp.types import (
|
||||
TextContent,
|
||||
ImageContent,
|
||||
EmbeddedResource,
|
||||
TextResourceContents,
|
||||
BlobResourceContents,
|
||||
)
|
||||
|
||||
|
||||
class LLMRequestSubStage(Stage):
|
||||
@@ -37,6 +45,13 @@ class LLMRequestSubStage(Stage):
|
||||
self.max_context_length = ctx.astrbot_config["provider_settings"][
|
||||
"max_context_length"
|
||||
] # int
|
||||
self.dequeue_context_length = min(
|
||||
max(1, ctx.astrbot_config["provider_settings"]["dequeue_context_length"]),
|
||||
self.max_context_length - 1,
|
||||
) # int
|
||||
self.streaming_response = ctx.astrbot_config["provider_settings"][
|
||||
"streaming_response"
|
||||
] # bool
|
||||
|
||||
for bwp in self.bot_wake_prefixs:
|
||||
if self.provider_wake_prefix.startswith(bwp):
|
||||
@@ -63,7 +78,11 @@ class LLMRequestSubStage(Stage):
|
||||
)
|
||||
|
||||
if req.conversation:
|
||||
req.contexts = json.loads(req.conversation.history)
|
||||
all_contexts = json.loads(req.conversation.history)
|
||||
req.contexts = self._process_tool_message_pairs(
|
||||
all_contexts, remove_tags=True
|
||||
)
|
||||
|
||||
else:
|
||||
req = ProviderRequest(prompt="", image_urls=[])
|
||||
if self.provider_wake_prefix:
|
||||
@@ -104,8 +123,10 @@ class LLMRequestSubStage(Stage):
|
||||
|
||||
# 执行请求 LLM 前事件钩子。
|
||||
# 装饰 system_prompt 等功能
|
||||
# 获取当前平台ID
|
||||
platform_id = event.get_platform_id()
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.OnLLMRequestEvent
|
||||
EventType.OnLLMRequestEvent, platform_id=platform_id
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
@@ -131,76 +152,152 @@ class LLMRequestSubStage(Stage):
|
||||
and len(req.contexts) // 2 > self.max_context_length
|
||||
):
|
||||
logger.debug("上下文长度超过限制,将截断。")
|
||||
req.contexts = req.contexts[-self.max_context_length * 2 :]
|
||||
req.contexts = req.contexts[
|
||||
-(self.max_context_length - self.dequeue_context_length + 1) * 2 :
|
||||
]
|
||||
# 找到第一个role 为 user 的索引,确保上下文格式正确
|
||||
index = next(
|
||||
(
|
||||
i
|
||||
for i, item in enumerate(req.contexts)
|
||||
if item.get("role") == "user"
|
||||
),
|
||||
None,
|
||||
)
|
||||
if index is not None and index > 0:
|
||||
req.contexts = req.contexts[index:]
|
||||
|
||||
# session_id
|
||||
if not req.session_id:
|
||||
req.session_id = event.unified_msg_origin
|
||||
|
||||
try:
|
||||
need_loop = True
|
||||
while need_loop:
|
||||
need_loop = False
|
||||
logger.debug(f"提供商请求 Payload: {req}")
|
||||
llm_response = await provider.text_chat(**req.__dict__) # 请求 LLM
|
||||
async def requesting(req: ProviderRequest):
|
||||
try:
|
||||
need_loop = True
|
||||
while need_loop:
|
||||
need_loop = False
|
||||
logger.debug(f"提供商请求 Payload: {req}")
|
||||
|
||||
# 执行 LLM 响应后的事件钩子。
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.OnLLMResponseEvent
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
logger.debug(
|
||||
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||
)
|
||||
await handler.handler(event, llm_response)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
final_llm_response = None
|
||||
|
||||
if event.is_stopped():
|
||||
logger.info(
|
||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||
)
|
||||
return
|
||||
|
||||
async for result in self._handle_llm_response(event, req, llm_response):
|
||||
if isinstance(result, ProviderRequest):
|
||||
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
|
||||
req = result
|
||||
need_loop = True
|
||||
if self.streaming_response:
|
||||
stream = provider.text_chat_stream(**req.__dict__)
|
||||
async for llm_response in stream:
|
||||
if llm_response.is_chunk:
|
||||
if llm_response.result_chain:
|
||||
yield llm_response.result_chain # MessageChain
|
||||
else:
|
||||
yield MessageChain().message(
|
||||
llm_response.completion_text
|
||||
)
|
||||
else:
|
||||
final_llm_response = llm_response
|
||||
else:
|
||||
yield
|
||||
final_llm_response = await provider.text_chat(
|
||||
**req.__dict__
|
||||
) # 请求 LLM
|
||||
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=provider.get_model(),
|
||||
provider_type=provider.meta().type,
|
||||
if not final_llm_response:
|
||||
raise Exception("LLM response is None.")
|
||||
|
||||
# 执行 LLM 响应后的事件钩子。
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.OnLLMResponseEvent
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
logger.debug(
|
||||
f"hook(on_llm_response) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||
)
|
||||
await handler.handler(event, final_llm_response)
|
||||
except BaseException:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
if event.is_stopped():
|
||||
logger.info(
|
||||
f"{star_map[handler.handler_module_path].name} - {handler.handler_name} 终止了事件传播。"
|
||||
)
|
||||
return
|
||||
|
||||
if self.streaming_response:
|
||||
# 流式输出的处理
|
||||
async for result in self._handle_llm_stream_response(
|
||||
event, req, final_llm_response
|
||||
):
|
||||
if isinstance(result, ProviderRequest):
|
||||
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
|
||||
req = result
|
||||
need_loop = True
|
||||
else:
|
||||
yield
|
||||
else:
|
||||
# 非流式输出的处理
|
||||
async for result in self._handle_llm_response(
|
||||
event, req, final_llm_response
|
||||
):
|
||||
if isinstance(result, ProviderRequest):
|
||||
# 有函数工具调用并且返回了结果,我们需要再次请求 LLM
|
||||
req = result
|
||||
need_loop = True
|
||||
else:
|
||||
yield
|
||||
|
||||
asyncio.create_task(
|
||||
Metric.upload(
|
||||
llm_tick=1,
|
||||
model_name=provider.get_model(),
|
||||
provider_type=provider.meta().type,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# 保存到历史记录
|
||||
await self._save_to_history(event, req, llm_response)
|
||||
# 保存到历史记录
|
||||
await self._save_to_history(event, req, final_llm_response)
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
|
||||
)
|
||||
)
|
||||
|
||||
if not self.streaming_response:
|
||||
event.set_extra("tool_call_result", None)
|
||||
async for _ in requesting(req):
|
||||
yield
|
||||
else:
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"
|
||||
)
|
||||
MessageEventResult()
|
||||
.set_result_content_type(ResultContentType.STREAMING_RESULT)
|
||||
.set_async_stream(requesting(req))
|
||||
)
|
||||
return
|
||||
# 这里使用yield来暂停当前阶段,等待流式输出完成后继续处理
|
||||
yield
|
||||
|
||||
if event.get_extra("tool_call_result"):
|
||||
event.set_result(event.get_extra("tool_call_result"))
|
||||
event.set_extra("tool_call_result", None)
|
||||
yield
|
||||
|
||||
# 暂时直接发出去
|
||||
if img_b64 := event.get_extra("tool_call_img_respond"):
|
||||
await event.send(MessageChain(chain=[Image.fromBase64(img_b64)]))
|
||||
event.set_extra("tool_call_img_respond", None)
|
||||
yield
|
||||
|
||||
async def _handle_llm_response(
|
||||
self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse
|
||||
) -> AsyncGenerator[None, None]:
|
||||
"""处理 LLM 响应。
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse,
|
||||
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
|
||||
"""处理非流式 LLM 响应。
|
||||
|
||||
Returns:
|
||||
bool: 是否需要继续调用 LLM
|
||||
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
|
||||
|
||||
Yields:
|
||||
Iterator[bool]: 将 event 交付给下一个 stage
|
||||
Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
|
||||
"""
|
||||
if llm_response.role == "assistant":
|
||||
# text completion
|
||||
@@ -223,30 +320,83 @@ class LLMRequestSubStage(Stage):
|
||||
)
|
||||
)
|
||||
elif llm_response.role == "tool":
|
||||
# function calling
|
||||
tool_call_result: list[ToolCallMessageSegment] = []
|
||||
logger.info(
|
||||
f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}"
|
||||
# 处理函数工具调用
|
||||
async for result in self._handle_function_tools(event, req, llm_response):
|
||||
yield result
|
||||
|
||||
async def _handle_llm_stream_response(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse,
|
||||
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
|
||||
"""处理流式 LLM 响应。
|
||||
|
||||
专门用于处理流式输出完成后的响应,与非流式响应处理分离。
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
|
||||
|
||||
Yields:
|
||||
Iterator[Union[None, ProviderRequest]]: 将 event 交付给下一个 stage 或者返回 ProviderRequest 表示需要再次调用 LLM
|
||||
"""
|
||||
if llm_response.role == "assistant":
|
||||
# text completion
|
||||
if llm_response.result_chain:
|
||||
event.set_result(
|
||||
MessageEventResult(
|
||||
chain=llm_response.result_chain.chain
|
||||
).set_result_content_type(ResultContentType.STREAMING_FINISH)
|
||||
)
|
||||
else:
|
||||
event.set_result(
|
||||
MessageEventResult()
|
||||
.message(llm_response.completion_text)
|
||||
.set_result_content_type(ResultContentType.STREAMING_FINISH)
|
||||
)
|
||||
elif llm_response.role == "err":
|
||||
event.set_result(
|
||||
MessageEventResult().message(
|
||||
f"AstrBot 请求失败。\n错误信息: {llm_response.completion_text}"
|
||||
)
|
||||
)
|
||||
for func_tool_name, func_tool_args, func_tool_id in zip(
|
||||
llm_response.tools_call_name,
|
||||
llm_response.tools_call_args,
|
||||
llm_response.tools_call_ids,
|
||||
):
|
||||
try:
|
||||
func_tool = req.func_tool.get_func(func_tool_name)
|
||||
if func_tool.origin == "mcp":
|
||||
logger.info(
|
||||
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
|
||||
)
|
||||
client = req.func_tool.mcp_client_dict[
|
||||
func_tool.mcp_server_name
|
||||
]
|
||||
res = await client.session.call_tool(
|
||||
func_tool.name, func_tool_args
|
||||
)
|
||||
if res:
|
||||
# TODO content的类型可能包括list[TextContent | ImageContent | EmbeddedResource],这里只处理了TextContent。
|
||||
elif llm_response.role == "tool":
|
||||
# 处理函数工具调用
|
||||
async for result in self._handle_function_tools(event, req, llm_response):
|
||||
yield result
|
||||
|
||||
async def _handle_function_tools(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
req: ProviderRequest,
|
||||
llm_response: LLMResponse,
|
||||
) -> AsyncGenerator[Union[None, ProviderRequest], None]:
|
||||
"""处理函数工具调用。
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[Union[None, ProviderRequest], None]: 如果返回 ProviderRequest,表示需要再次调用 LLM
|
||||
"""
|
||||
# function calling
|
||||
tool_call_result: list[ToolCallMessageSegment] = []
|
||||
logger.info(
|
||||
f"触发 {len(llm_response.tools_call_name)} 个函数调用: {llm_response.tools_call_name}"
|
||||
)
|
||||
for func_tool_name, func_tool_args, func_tool_id in zip(
|
||||
llm_response.tools_call_name,
|
||||
llm_response.tools_call_args,
|
||||
llm_response.tools_call_ids,
|
||||
):
|
||||
try:
|
||||
func_tool = req.func_tool.get_func(func_tool_name)
|
||||
if func_tool.origin == "mcp":
|
||||
logger.info(
|
||||
f"从 MCP 服务 {func_tool.mcp_server_name} 调用工具函数:{func_tool.name},参数:{func_tool_args}"
|
||||
)
|
||||
client = req.func_tool.mcp_client_dict[func_tool.mcp_server_name]
|
||||
res = await client.session.call_tool(func_tool.name, func_tool_args)
|
||||
if res:
|
||||
# TODO 仅对ImageContent | EmbeddedResource进行了简单的Fallback
|
||||
if isinstance(res.content[0], TextContent):
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
@@ -254,52 +404,115 @@ class LLMRequestSubStage(Stage):
|
||||
content=res.content[0].text,
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
|
||||
)
|
||||
# 尝试调用工具函数
|
||||
wrapper = self._call_handler(
|
||||
self.ctx, event, func_tool.handler, **func_tool_args
|
||||
)
|
||||
async for resp in wrapper:
|
||||
if resp is not None: # 有 return 返回
|
||||
elif isinstance(res.content[0], ImageContent):
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回了图片(已直接发送给用户)",
|
||||
)
|
||||
)
|
||||
event.set_extra(
|
||||
"tool_call_img_respond",
|
||||
res.content[0].data,
|
||||
)
|
||||
elif isinstance(res.content[0], EmbeddedResource):
|
||||
resource = res.content[0].resource
|
||||
if isinstance(resource, TextResourceContents):
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=resp,
|
||||
content=resource.text,
|
||||
)
|
||||
)
|
||||
elif (
|
||||
isinstance(resource, BlobResourceContents)
|
||||
and resource.mimeType
|
||||
and resource.mimeType.startswith("image/")
|
||||
):
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回了图片(已直接发送给用户)",
|
||||
)
|
||||
)
|
||||
event.set_extra(
|
||||
"tool_call_img_respond",
|
||||
res.content[0].data,
|
||||
)
|
||||
else:
|
||||
yield # 有生成器返回
|
||||
event.clear_result() # 清除上一个 handler 的结果
|
||||
except BaseException as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=f"error: {str(e)}",
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content="返回的数据类型不受支持",
|
||||
)
|
||||
)
|
||||
else:
|
||||
# 获取处理器,过滤掉平台不兼容的处理器
|
||||
platform_id = event.get_platform_id()
|
||||
star_md = star_map.get(func_tool.handler_module_path)
|
||||
if (
|
||||
star_md
|
||||
and platform_id in star_md.supported_platforms
|
||||
and not star_md.supported_platforms[platform_id]
|
||||
):
|
||||
logger.debug(
|
||||
f"处理器 {func_tool_name}({star_md.name}) 在当前平台不兼容或者被禁用,跳过执行"
|
||||
)
|
||||
# 直接跳过,不添加任何消息到tool_call_result
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
f"调用工具函数:{func_tool_name},参数:{func_tool_args}"
|
||||
)
|
||||
if tool_call_result:
|
||||
# 函数调用结果
|
||||
req.func_tool = None # 暂时不支持递归工具调用
|
||||
assistant_msg_seg = AssistantMessageSegment(
|
||||
role="assistant", tool_calls=llm_response.to_openai_tool_calls()
|
||||
)
|
||||
# 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。
|
||||
req.tool_calls_result = ToolCallsResult(
|
||||
tool_calls_info=assistant_msg_seg,
|
||||
tool_calls_result=tool_call_result,
|
||||
)
|
||||
yield req # 再次执行 LLM 请求
|
||||
else:
|
||||
if llm_response.completion_text:
|
||||
event.set_result(
|
||||
MessageEventResult().message(llm_response.completion_text)
|
||||
# 尝试调用工具函数
|
||||
wrapper = self._call_handler(
|
||||
self.ctx, event, func_tool.handler, **func_tool_args
|
||||
)
|
||||
async for resp in wrapper:
|
||||
if resp is not None: # 有 return 返回
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=resp,
|
||||
)
|
||||
)
|
||||
else:
|
||||
res = event.get_result()
|
||||
if res and res.chain:
|
||||
event.set_extra("tool_call_result", res)
|
||||
yield # 有生成器返回
|
||||
event.clear_result() # 清除上一个 handler 的结果
|
||||
except BaseException as e:
|
||||
logger.warning(traceback.format_exc())
|
||||
tool_call_result.append(
|
||||
ToolCallMessageSegment(
|
||||
role="tool",
|
||||
tool_call_id=func_tool_id,
|
||||
content=f"error: {str(e)}",
|
||||
)
|
||||
)
|
||||
if tool_call_result:
|
||||
# 函数调用结果
|
||||
req.func_tool = None # 暂时不支持递归工具调用
|
||||
assistant_msg_seg = AssistantMessageSegment(
|
||||
role="assistant", tool_calls=llm_response.to_openai_tool_calls()
|
||||
)
|
||||
# 在多轮 Tool 调用的情况下,这里始终保持最新的 Tool 调用结果,减少上下文长度。
|
||||
req.tool_calls_result = ToolCallsResult(
|
||||
tool_calls_info=assistant_msg_seg,
|
||||
tool_calls_result=tool_call_result,
|
||||
)
|
||||
yield req # 再次执行 LLM 请求
|
||||
else:
|
||||
if llm_response.completion_text:
|
||||
event.set_result(
|
||||
MessageEventResult().message(llm_response.completion_text)
|
||||
)
|
||||
|
||||
async def _save_to_history(
|
||||
self, event: AstrMessageEvent, req: ProviderRequest, llm_response: LLMResponse
|
||||
@@ -309,12 +522,22 @@ class LLMRequestSubStage(Stage):
|
||||
|
||||
if llm_response.role == "assistant":
|
||||
# 文本回复
|
||||
contexts = req.contexts
|
||||
contexts = req.contexts.copy()
|
||||
contexts.append(await req.assemble_context())
|
||||
|
||||
# tool calls result
|
||||
# 记录并标记函数调用结果
|
||||
if req.tool_calls_result:
|
||||
contexts.extend(req.tool_calls_result.to_openai_messages())
|
||||
tool_calls_messages = req.tool_calls_result.to_openai_messages()
|
||||
|
||||
# 添加标记
|
||||
for message in tool_calls_messages:
|
||||
message["_tool_call_history"] = True
|
||||
|
||||
processed_tool_messages = self._process_tool_message_pairs(
|
||||
tool_calls_messages, remove_tags=False
|
||||
)
|
||||
|
||||
contexts.extend(processed_tool_messages)
|
||||
|
||||
contexts.append(
|
||||
{"role": "assistant", "content": llm_response.completion_text}
|
||||
@@ -325,3 +548,59 @@ class LLMRequestSubStage(Stage):
|
||||
await self.conv_manager.update_conversation(
|
||||
event.unified_msg_origin, req.conversation.cid, history=contexts_to_save
|
||||
)
|
||||
|
||||
def _process_tool_message_pairs(self, messages, remove_tags=True):
|
||||
"""处理工具调用消息,确保assistant和tool消息成对出现
|
||||
|
||||
Args:
|
||||
messages (list): 消息列表
|
||||
remove_tags (bool): 是否移除_tool_call_history标记
|
||||
|
||||
Returns:
|
||||
list: 处理后的消息列表,保证了assistant和对应tool消息的成对出现
|
||||
"""
|
||||
result = []
|
||||
i = 0
|
||||
|
||||
while i < len(messages):
|
||||
current_msg = messages[i]
|
||||
|
||||
# 普通消息直接添加
|
||||
if "_tool_call_history" not in current_msg:
|
||||
result.append(current_msg.copy() if remove_tags else current_msg)
|
||||
i += 1
|
||||
continue
|
||||
|
||||
# 工具调用消息成对处理
|
||||
if current_msg.get("role") == "assistant" and "tool_calls" in current_msg:
|
||||
assistant_msg = current_msg.copy()
|
||||
|
||||
if remove_tags and "_tool_call_history" in assistant_msg:
|
||||
del assistant_msg["_tool_call_history"]
|
||||
|
||||
related_tools = []
|
||||
j = i + 1
|
||||
while (
|
||||
j < len(messages)
|
||||
and messages[j].get("role") == "tool"
|
||||
and "_tool_call_history" in messages[j]
|
||||
):
|
||||
tool_msg = messages[j].copy()
|
||||
|
||||
if remove_tags:
|
||||
del tool_msg["_tool_call_history"]
|
||||
|
||||
related_tools.append(tool_msg)
|
||||
j += 1
|
||||
|
||||
# 成对的时候添加到结果
|
||||
if related_tools:
|
||||
result.append(assistant_msg)
|
||||
result.extend(related_tools)
|
||||
|
||||
i = j # 跳过已处理
|
||||
else:
|
||||
# 单独的tool消息
|
||||
i += 1
|
||||
|
||||
return result
|
||||
|
||||
@@ -31,7 +31,18 @@ class StarRequestSubStage(Stage):
|
||||
)
|
||||
if not handlers_parsed_params:
|
||||
handlers_parsed_params = {}
|
||||
|
||||
for handler in activated_handlers:
|
||||
# 检查处理器是否在当前平台兼容
|
||||
if (
|
||||
hasattr(handler, "platform_compatible")
|
||||
and handler.platform_compatible is False
|
||||
):
|
||||
logger.debug(
|
||||
f"处理器 {handler.handler_name} 在当前平台不兼容,跳过执行"
|
||||
)
|
||||
continue
|
||||
|
||||
params = handlers_parsed_params.get(handler.handler_full_name, {})
|
||||
try:
|
||||
if handler.handler_module_path not in star_map:
|
||||
|
||||
@@ -5,7 +5,7 @@ from .method.llm_request import LLMRequestSubStage
|
||||
from .method.star_request import StarRequestSubStage
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.star.star_handler import StarHandlerMetadata
|
||||
from astrbot.core.provider.entites import ProviderRequest
|
||||
from astrbot.core.provider.entities import ProviderRequest
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
|
||||
@@ -7,18 +7,21 @@ from typing import Union, AsyncGenerator
|
||||
from ..stage import register_stage, Stage
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.message.message_event_result import MessageChain, ResultContentType
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.message_event_result import BaseMessageComponent
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.utils.path_util import path_Mapping
|
||||
|
||||
|
||||
@register_stage
|
||||
class RespondStage(Stage):
|
||||
# 组件类型到其非空判断函数的映射
|
||||
_component_validators = {
|
||||
Comp.Plain: lambda comp: bool(comp.text and comp.text.strip()), # 纯文本消息需要strip
|
||||
Comp.Plain: lambda comp: bool(
|
||||
comp.text and comp.text.strip()
|
||||
), # 纯文本消息需要strip
|
||||
Comp.Face: lambda comp: comp.id is not None, # QQ表情
|
||||
Comp.Record: lambda comp: bool(comp.file), # 语音
|
||||
Comp.Video: lambda comp: bool(comp.file), # 视频
|
||||
@@ -31,13 +34,17 @@ class RespondStage(Stage):
|
||||
Comp.Share: lambda comp: bool(comp.url) and bool(comp.title), # 分享
|
||||
Comp.Contact: lambda comp: True, # 联系人(未完成)
|
||||
Comp.Location: lambda comp: bool(comp.lat and comp.lon), # 位置
|
||||
Comp.Music: lambda comp: bool(comp._type) and bool(comp.url) and bool(comp.audio), # 音乐
|
||||
Comp.Music: lambda comp: bool(comp._type)
|
||||
and bool(comp.url)
|
||||
and bool(comp.audio), # 音乐
|
||||
Comp.Image: lambda comp: bool(comp.file), # 图片
|
||||
Comp.Reply: lambda comp: bool(comp.id) and comp.sender_id is not None, # 回复
|
||||
Comp.RedBag: lambda comp: bool(comp.title), # 红包
|
||||
Comp.Poke: lambda comp: comp.id != 0 and comp.qq != 0, # 戳一戳
|
||||
Comp.Forward: lambda comp: bool(comp.id and comp.id.strip()), # 转发
|
||||
Comp.Node: lambda comp: bool(comp.name) and comp.uin != 0 and bool(comp.content), # 一个转发节点
|
||||
Comp.Node: lambda comp: bool(comp.name)
|
||||
and comp.uin != 0
|
||||
and bool(comp.content), # 一个转发节点
|
||||
Comp.Nodes: lambda comp: bool(comp.nodes), # 多个转发节点
|
||||
Comp.Xml: lambda comp: bool(comp.data and comp.data.strip()), # XML
|
||||
Comp.Json: lambda comp: bool(comp.data), # JSON
|
||||
@@ -50,6 +57,8 @@ class RespondStage(Stage):
|
||||
|
||||
async def initialize(self, ctx: PipelineContext):
|
||||
self.ctx = ctx
|
||||
self.config = ctx.astrbot_config
|
||||
self.platform_settings: dict = self.config.get("platform_settings", {})
|
||||
|
||||
self.reply_with_mention = ctx.astrbot_config["platform_settings"][
|
||||
"reply_with_mention"
|
||||
@@ -132,8 +141,28 @@ class RespondStage(Stage):
|
||||
result = event.get_result()
|
||||
if result is None:
|
||||
return
|
||||
if result.result_content_type == ResultContentType.STREAMING_FINISH:
|
||||
return
|
||||
|
||||
if result.result_content_type == ResultContentType.STREAMING_RESULT:
|
||||
# 流式结果直接交付平台适配器处理
|
||||
use_fallback = self.config.get("provider_settings", {}).get(
|
||||
"streaming_segmented", False
|
||||
)
|
||||
logger.info(f"应用流式输出({event.get_platform_name()})")
|
||||
await event._pre_send()
|
||||
await event.send_streaming(result.async_stream, use_fallback)
|
||||
await event._post_send()
|
||||
return
|
||||
elif len(result.chain) > 0:
|
||||
# 检查路径映射
|
||||
if mappings := self.platform_settings.get("path_mapping", []):
|
||||
for idx, component in enumerate(result.chain):
|
||||
if isinstance(component, Comp.File) and component.file:
|
||||
# 支持 File 消息段的路径映射。
|
||||
component.file = path_Mapping(mappings, component.file)
|
||||
event.get_result().chain[idx] = component
|
||||
|
||||
if len(result.chain) > 0:
|
||||
await event._pre_send()
|
||||
|
||||
# 检查消息链是否为空
|
||||
@@ -176,6 +205,7 @@ class RespondStage(Stage):
|
||||
try:
|
||||
await event.send(result)
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"发送消息失败: {e} chain: {result.chain}")
|
||||
await event._post_send()
|
||||
logger.info(
|
||||
@@ -183,7 +213,7 @@ class RespondStage(Stage):
|
||||
)
|
||||
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.OnAfterMessageSentEvent
|
||||
EventType.OnAfterMessageSentEvent, platform_id=event.get_platform_id()
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import Union, AsyncGenerator
|
||||
from ..stage import Stage, register_stage, registered_stages
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import ResultContentType
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.components import Plain, Image, At, Reply, Record, File, Node
|
||||
@@ -72,11 +73,17 @@ class ResultDecorateStage(Stage):
|
||||
if result is None or not result.chain:
|
||||
return
|
||||
|
||||
if result.result_content_type == ResultContentType.STREAMING_RESULT:
|
||||
return
|
||||
|
||||
is_stream = result.result_content_type == ResultContentType.STREAMING_FINISH
|
||||
|
||||
# 回复时检查内容安全
|
||||
if (
|
||||
self.content_safe_check_reply
|
||||
and self.content_safe_check_stage
|
||||
and result.is_llm_result()
|
||||
and not is_stream # 流式输出不检查内容安全
|
||||
):
|
||||
text = ""
|
||||
for comp in result.chain:
|
||||
@@ -89,13 +96,17 @@ class ResultDecorateStage(Stage):
|
||||
|
||||
# 发送消息前事件钩子
|
||||
handlers = star_handlers_registry.get_handlers_by_event_type(
|
||||
EventType.OnDecoratingResultEvent
|
||||
EventType.OnDecoratingResultEvent, platform_id=event.get_platform_id()
|
||||
)
|
||||
for handler in handlers:
|
||||
try:
|
||||
logger.debug(
|
||||
f"hook(on_decorating_result) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}"
|
||||
)
|
||||
if is_stream:
|
||||
logger.warning(
|
||||
"启用流式输出时,依赖发送消息前事件钩子的插件可能无法正常工作"
|
||||
)
|
||||
await handler.handler(event)
|
||||
if event.get_result() is None or not event.get_result().chain:
|
||||
logger.debug(
|
||||
@@ -110,6 +121,11 @@ class ResultDecorateStage(Stage):
|
||||
)
|
||||
return
|
||||
|
||||
# 流式输出不执行下面的逻辑
|
||||
if is_stream:
|
||||
logger.info("流式输出已启用,跳过结果装饰阶段")
|
||||
return
|
||||
|
||||
# 需要再获取一次。插件可能直接对 chain 进行了替换。
|
||||
result = event.get_result()
|
||||
if result is None:
|
||||
@@ -135,9 +151,9 @@ class ResultDecorateStage(Stage):
|
||||
# 不分段回复
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
split_response = []
|
||||
for line in comp.text.split("\n"):
|
||||
split_response.extend(re.findall(self.regex, line))
|
||||
split_response = re.findall(
|
||||
self.regex, comp.text, re.DOTALL | re.MULTILINE
|
||||
)
|
||||
if not split_response:
|
||||
new_chain.append(comp)
|
||||
continue
|
||||
@@ -168,6 +184,8 @@ class ResultDecorateStage(Stage):
|
||||
new_chain.append(
|
||||
Record(file=audio_path, url=audio_path)
|
||||
)
|
||||
if(self.ctx.astrbot_config["provider_tts_settings"]["dual_output"]):
|
||||
new_chain.append(comp)
|
||||
else:
|
||||
logger.error(
|
||||
f"由于 TTS 音频文件没找到,消息段转语音失败: {comp.text}"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from astrbot import logger
|
||||
from typing import Union, AsyncGenerator
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
||||
@@ -34,10 +35,21 @@ class WakingCheckStage(Stage):
|
||||
self.friend_message_needs_wake_prefix = self.ctx.astrbot_config[
|
||||
"platform_settings"
|
||||
].get("friend_message_needs_wake_prefix", False)
|
||||
# 是否忽略机器人自己发送的消息
|
||||
self.ignore_bot_self_message = self.ctx.astrbot_config["platform_settings"].get(
|
||||
"ignore_bot_self_message", False
|
||||
)
|
||||
|
||||
async def process(
|
||||
self, event: AstrMessageEvent
|
||||
) -> Union[None, AsyncGenerator[None, None]]:
|
||||
if (
|
||||
self.ignore_bot_self_message
|
||||
and event.get_self_id() == event.get_sender_id()
|
||||
):
|
||||
# 忽略机器人自己发送的消息
|
||||
event.stop_event()
|
||||
return
|
||||
# 设置 sender 身份
|
||||
event.message_str = event.message_str.strip()
|
||||
for admin_id in self.ctx.astrbot_config["admins_id"]:
|
||||
@@ -93,6 +105,7 @@ class WakingCheckStage(Stage):
|
||||
# filter 需满足 AND 逻辑关系
|
||||
passed = True
|
||||
permission_not_pass = False
|
||||
permission_filter_raise_error = False
|
||||
if len(handler.event_filters) == 0:
|
||||
continue
|
||||
|
||||
@@ -101,6 +114,7 @@ class WakingCheckStage(Stage):
|
||||
if isinstance(filter, PermissionTypeFilter):
|
||||
if not filter.filter(event, self.ctx.astrbot_config):
|
||||
permission_not_pass = True
|
||||
permission_filter_raise_error = filter.raise_error
|
||||
else:
|
||||
if not filter.filter(event, self.ctx.astrbot_config):
|
||||
passed = False
|
||||
@@ -117,6 +131,9 @@ class WakingCheckStage(Stage):
|
||||
break
|
||||
if passed:
|
||||
if permission_not_pass:
|
||||
if not permission_filter_raise_error:
|
||||
# 跳过
|
||||
continue
|
||||
if self.no_permission_reply:
|
||||
await event.send(
|
||||
MessageChain().message(
|
||||
@@ -124,6 +141,9 @@ class WakingCheckStage(Stage):
|
||||
)
|
||||
)
|
||||
await event._post_send()
|
||||
logger.info(
|
||||
f"触发 {star_map[handler.handler_module_path].name} 时, 用户(ID={event.get_sender_id()}) 权限不足。"
|
||||
)
|
||||
event.stop_event()
|
||||
return
|
||||
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import abc
|
||||
import asyncio
|
||||
import re
|
||||
import hashlib
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union, Optional
|
||||
from typing import List, Union, Optional, AsyncGenerator
|
||||
|
||||
from astrbot.core.db.po import Conversation
|
||||
from astrbot.core.message.components import (
|
||||
@@ -16,7 +19,7 @@ from astrbot.core.message.components import (
|
||||
)
|
||||
from astrbot.core.message.message_event_result import MessageEventResult, MessageChain
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core.provider.entites import ProviderRequest
|
||||
from astrbot.core.provider.entities import ProviderRequest
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
from .astrbot_message import AstrBotMessage, Group
|
||||
from .platform_metadata import PlatformMetadata
|
||||
@@ -81,6 +84,9 @@ class AstrMessageEvent(abc.ABC):
|
||||
def get_platform_name(self):
|
||||
return self.platform_meta.name
|
||||
|
||||
def get_platform_id(self):
|
||||
return self.platform_meta.id
|
||||
|
||||
def get_message_str(self) -> str:
|
||||
"""
|
||||
获取消息字符串。
|
||||
@@ -202,6 +208,32 @@ class AstrMessageEvent(abc.ABC):
|
||||
"""
|
||||
return self.role == "admin"
|
||||
|
||||
async def process_buffer(self, buffer: str, pattern: re.Pattern) -> str:
|
||||
"""
|
||||
将消息缓冲区中的文本按指定正则表达式分割后发送至消息平台,作为不支持流式输出平台的Fallback。
|
||||
"""
|
||||
while True:
|
||||
match = re.search(pattern, buffer)
|
||||
if not match:
|
||||
break
|
||||
matched_text = match.group()
|
||||
await self.send(MessageChain([Plain(matched_text)]))
|
||||
buffer = buffer[match.end() :]
|
||||
await asyncio.sleep(1.5) # 限速
|
||||
return buffer
|
||||
|
||||
async def send_streaming(
|
||||
self, generator: AsyncGenerator[MessageChain, None], use_fallback: bool = False
|
||||
):
|
||||
"""发送流式消息到消息平台,使用异步生成器。
|
||||
目前仅支持: telegram,qq official 私聊。
|
||||
Fallback仅支持 aiocqhttp, gewechat。
|
||||
"""
|
||||
asyncio.create_task(
|
||||
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
|
||||
)
|
||||
self._has_send_oper = True
|
||||
|
||||
async def _pre_send(self):
|
||||
"""调度器会在执行 send() 前调用该方法"""
|
||||
|
||||
@@ -372,8 +404,13 @@ class AstrMessageEvent(abc.ABC):
|
||||
Args:
|
||||
message (MessageChain): 消息链,具体使用方式请参考文档。
|
||||
"""
|
||||
# Leverage BLAKE2 hash function to generate a non-reversible hash of the sender ID for privacy.
|
||||
hash_obj = hashlib.blake2b(self.get_sender_id().encode("utf-8"), digest_size=16)
|
||||
sid = str(uuid.UUID(bytes=hash_obj.digest()))
|
||||
asyncio.create_task(
|
||||
Metric.upload(msg_event_tick=1, adapter_name=self.platform_meta.name)
|
||||
Metric.upload(
|
||||
msg_event_tick=1, adapter_name=self.platform_meta.name, sid=sid
|
||||
)
|
||||
)
|
||||
self._has_send_oper = True
|
||||
|
||||
|
||||
@@ -7,6 +7,8 @@ class PlatformMetadata:
|
||||
"""平台的名称"""
|
||||
description: str
|
||||
"""平台的描述"""
|
||||
id: str = None
|
||||
"""平台的唯一标识符,用于配置中识别特定平台"""
|
||||
|
||||
default_config_tmpl: dict = None
|
||||
"""平台的默认配置模板"""
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import asyncio
|
||||
import typing
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import Group, MessageMember
|
||||
from astrbot.api.message_components import Plain, Image, Record, At, Node, Nodes
|
||||
import re
|
||||
from typing import AsyncGenerator, Dict, List
|
||||
from aiocqhttp import CQHttp
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import At, Image, Node, Nodes, Plain, Record
|
||||
from astrbot.api.platform import Group, MessageMember
|
||||
|
||||
|
||||
class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
@@ -29,7 +30,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
# convert to base64
|
||||
bs64 = await segment.convert_to_base64()
|
||||
d["data"] = {
|
||||
"file": bs64,
|
||||
"file": f"base64://{bs64}",
|
||||
}
|
||||
elif isinstance(segment, At):
|
||||
d["data"] = {
|
||||
@@ -82,6 +83,40 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(
|
||||
self, generator: AsyncGenerator, use_fallback: bool = False
|
||||
):
|
||||
if not use_fallback:
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
buffer = ""
|
||||
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
|
||||
|
||||
async for chain in generator:
|
||||
if isinstance(chain, MessageChain):
|
||||
for comp in chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
buffer += comp.text
|
||||
if any(p in buffer for p in "。?!~…"):
|
||||
buffer = await self.process_buffer(buffer, pattern)
|
||||
else:
|
||||
await self.send(MessageChain(chain=[comp]))
|
||||
await asyncio.sleep(1.5) # 限速
|
||||
|
||||
if buffer.strip():
|
||||
await self.send(MessageChain([Plain(buffer)]))
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
async def get_group(self, group_id=None, **kwargs):
|
||||
if isinstance(group_id, str) and group_id.isdigit():
|
||||
group_id = int(group_id)
|
||||
@@ -95,7 +130,7 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
members: typing.List[typing.Dict] = await self.bot.call_action(
|
||||
members: List[Dict] = await self.bot.call_action(
|
||||
"get_group_member_list",
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import os
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
import itertools
|
||||
from typing import Awaitable, Any
|
||||
from aiocqhttp import CQHttp, Event
|
||||
from astrbot.api.platform import (
|
||||
@@ -20,7 +20,6 @@ from .aiocqhttp_message_event import AiocqhttpMessageEvent
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
from aiocqhttp.exceptions import ActionFailed
|
||||
from astrbot.core.utils.io import download_file
|
||||
|
||||
|
||||
@register_platform_adapter(
|
||||
@@ -39,12 +38,18 @@ class AiocqhttpAdapter(Platform):
|
||||
self.port = platform_config["ws_reverse_port"]
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
"aiocqhttp",
|
||||
"适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
|
||||
name="aiocqhttp",
|
||||
description="适用于 OneBot 标准的消息平台适配器,支持反向 WebSockets。",
|
||||
id=self.config.get("id"),
|
||||
)
|
||||
|
||||
self.bot = CQHttp(
|
||||
use_ws_reverse=True, import_name="aiocqhttp", api_timeout_sec=180
|
||||
use_ws_reverse=True,
|
||||
import_name="aiocqhttp",
|
||||
api_timeout_sec=180,
|
||||
access_token=platform_config.get(
|
||||
"ws_reverse_token"
|
||||
), # 以防旧版本配置不存在
|
||||
)
|
||||
|
||||
@self.bot.on_request()
|
||||
@@ -109,7 +114,7 @@ class AiocqhttpAdapter(Platform):
|
||||
"""OneBot V11 请求类事件"""
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = str(event.self_id)
|
||||
abm.sender = MessageMember(user_id=event.user_id, nickname=event.user_id)
|
||||
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
|
||||
abm.type = MessageType.OTHER_MESSAGE
|
||||
if "group_id" in event and event["group_id"]:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
@@ -118,6 +123,12 @@ class AiocqhttpAdapter(Platform):
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
if self.unique_session and abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = str(abm.sender.user_id) + "_" + str(event.group_id)
|
||||
else:
|
||||
abm.session_id = (
|
||||
str(event.group_id)
|
||||
if abm.type == MessageType.GROUP_MESSAGE
|
||||
else abm.sender.user_id
|
||||
)
|
||||
abm.message_str = ""
|
||||
abm.message = []
|
||||
abm.timestamp = int(time.time())
|
||||
@@ -129,7 +140,7 @@ class AiocqhttpAdapter(Platform):
|
||||
"""OneBot V11 通知类事件"""
|
||||
abm = AstrBotMessage()
|
||||
abm.self_id = str(event.self_id)
|
||||
abm.sender = MessageMember(user_id=event.user_id, nickname=event.user_id)
|
||||
abm.sender = MessageMember(user_id=str(event.user_id), nickname=event.user_id)
|
||||
abm.type = MessageType.OTHER_MESSAGE
|
||||
if "group_id" in event and event["group_id"]:
|
||||
abm.group_id = str(event.group_id)
|
||||
@@ -154,7 +165,9 @@ class AiocqhttpAdapter(Platform):
|
||||
|
||||
if "sub_type" in event:
|
||||
if event["sub_type"] == "poke" and "target_id" in event:
|
||||
abm.message.append(Poke(qq=str(event["target_id"]), type="poke")) # noqa: F405
|
||||
abm.message.append(
|
||||
Poke(qq=str(event["target_id"]), type="poke")
|
||||
) # noqa: F405
|
||||
|
||||
return abm
|
||||
|
||||
@@ -201,82 +214,83 @@ class AiocqhttpAdapter(Platform):
|
||||
return
|
||||
|
||||
# 按消息段类型类型适配
|
||||
for m in event.message:
|
||||
t = m["type"]
|
||||
for t, m_group in itertools.groupby(event.message, key=lambda x: x["type"]):
|
||||
a = None
|
||||
if t == "text":
|
||||
message_str += m["data"]["text"].strip()
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
# 合并相邻文本段
|
||||
message_str = "".join(m["data"]["text"] for m in m_group).strip()
|
||||
a = ComponentTypes[t](text=message_str) # noqa: F405
|
||||
abm.message.append(a)
|
||||
|
||||
elif t == "file":
|
||||
if m["data"].get("url") and m["data"].get("url").startswith("http"):
|
||||
# Lagrange
|
||||
logger.info("guessing lagrange")
|
||||
for m in m_group:
|
||||
if m["data"].get("url") and m["data"].get("url").startswith("http"):
|
||||
# Lagrange
|
||||
logger.info("guessing lagrange")
|
||||
file_name = m["data"].get("file_name", "file")
|
||||
abm.message.append(File(name=file_name, url=m["data"]["url"]))
|
||||
else:
|
||||
try:
|
||||
# Napcat
|
||||
ret = None
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
ret = await self.bot.call_action(
|
||||
action="get_group_file_url",
|
||||
file_id=event.message[0]["data"]["file_id"],
|
||||
group_id=event.group_id,
|
||||
)
|
||||
elif abm.type == MessageType.FRIEND_MESSAGE:
|
||||
ret = await self.bot.call_action(
|
||||
action="get_private_file_url",
|
||||
file_id=event.message[0]["data"]["file_id"],
|
||||
)
|
||||
if ret and "url" in ret:
|
||||
file_url = ret["url"] # https
|
||||
a = File(name="", url=file_url)
|
||||
abm.message.append(a)
|
||||
else:
|
||||
logger.error(f"获取文件失败: {ret}")
|
||||
|
||||
file_name = m["data"].get("file_name", "file")
|
||||
path = os.path.join("data/temp", file_name)
|
||||
await download_file(m["data"]["url"], path)
|
||||
|
||||
m["data"] = {"file": path, "name": file_name}
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
|
||||
else:
|
||||
try:
|
||||
# Napcat, LLBot
|
||||
ret = await self.bot.call_action(
|
||||
action="get_file",
|
||||
file_id=event.message[0]["data"]["file_id"],
|
||||
)
|
||||
if not ret.get("file", None):
|
||||
raise ValueError(f"无法解析文件响应: {ret}")
|
||||
if not os.path.exists(ret["file"]):
|
||||
raise FileNotFoundError(
|
||||
f"文件不存在或者权限问题: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),请先映射路径。如果路径在 /root 目录下,请用 sudo 打开 AstrBot"
|
||||
)
|
||||
|
||||
m["data"] = {"file": ret["file"], "name": ret["file_name"]}
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
except ActionFailed as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
except BaseException as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
except ActionFailed as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
except BaseException as e:
|
||||
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
|
||||
|
||||
elif t == "reply":
|
||||
if not get_reply:
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
else:
|
||||
try:
|
||||
reply_event_data = await self.bot.call_action(
|
||||
action="get_msg",
|
||||
message_id=int(m["data"]["id"]),
|
||||
)
|
||||
abm_reply = await self._convert_handle_message_event(
|
||||
Event.from_payload(reply_event_data), get_reply=False
|
||||
)
|
||||
|
||||
reply_seg = Reply(
|
||||
id=abm_reply.message_id,
|
||||
chain=abm_reply.message,
|
||||
sender_id=abm_reply.sender.user_id,
|
||||
sender_nickname=abm_reply.sender.nickname,
|
||||
time=abm_reply.timestamp,
|
||||
message_str=abm_reply.message_str,
|
||||
text=abm_reply.message_str, # for compatibility
|
||||
qq=abm_reply.sender.user_id, # for compatibility
|
||||
)
|
||||
|
||||
abm.message.append(reply_seg)
|
||||
except BaseException as e:
|
||||
logger.error(f"获取引用消息失败: {e}。")
|
||||
for m in m_group:
|
||||
if not get_reply:
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
else:
|
||||
try:
|
||||
reply_event_data = await self.bot.call_action(
|
||||
action="get_msg",
|
||||
message_id=int(m["data"]["id"]),
|
||||
)
|
||||
abm_reply = await self._convert_handle_message_event(
|
||||
Event.from_payload(reply_event_data), get_reply=False
|
||||
)
|
||||
|
||||
reply_seg = Reply(
|
||||
id=abm_reply.message_id,
|
||||
chain=abm_reply.message,
|
||||
sender_id=abm_reply.sender.user_id,
|
||||
sender_nickname=abm_reply.sender.nickname,
|
||||
time=abm_reply.timestamp,
|
||||
message_str=abm_reply.message_str,
|
||||
text=abm_reply.message_str, # for compatibility
|
||||
qq=abm_reply.sender.user_id, # for compatibility
|
||||
)
|
||||
|
||||
abm.message.append(reply_seg)
|
||||
except BaseException as e:
|
||||
logger.error(f"获取引用消息失败: {e}。")
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
else:
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
for m in m_group:
|
||||
a = ComponentTypes[t](**m["data"]) # noqa: F405
|
||||
abm.message.append(a)
|
||||
|
||||
abm.timestamp = int(time.time())
|
||||
abm.message_str = message_str
|
||||
|
||||
@@ -73,8 +73,9 @@ class DingtalkPlatformAdapter(Platform):
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"dingtalk",
|
||||
"钉钉机器人官方 API 适配器",
|
||||
name="dingtalk",
|
||||
description="钉钉机器人官方 API 适配器",
|
||||
id=self.config.get("id"),
|
||||
)
|
||||
|
||||
async def convert_msg(
|
||||
|
||||
@@ -24,7 +24,11 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
if isinstance(segment, Comp.Plain):
|
||||
segment.text = segment.text.strip()
|
||||
await asyncio.get_event_loop().run_in_executor(
|
||||
None, client.reply_markdown, "AstrBot", segment.text, self.message_obj.raw_message
|
||||
None,
|
||||
client.reply_markdown,
|
||||
"AstrBot",
|
||||
segment.text,
|
||||
self.message_obj.raw_message,
|
||||
)
|
||||
elif isinstance(segment, Comp.Image):
|
||||
markdown_str = ""
|
||||
@@ -56,3 +60,16 @@ class DingtalkMessageEvent(AstrMessageEvent):
|
||||
async def send(self, message: MessageChain):
|
||||
await self.send_with_client(self.client, message)
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
@@ -3,6 +3,7 @@ import base64
|
||||
import datetime
|
||||
import os
|
||||
import re
|
||||
import uuid
|
||||
import threading
|
||||
|
||||
import aiohttp
|
||||
@@ -63,7 +64,7 @@ class SimpleGewechatClient:
|
||||
"/astrbot-gewechat/callback", view_func=self._callback, methods=["POST"]
|
||||
)
|
||||
self.server.add_url_rule(
|
||||
"/astrbot-gewechat/file/<file_id>",
|
||||
"/astrbot-gewechat/file/<file_token>",
|
||||
view_func=self._handle_file,
|
||||
methods=["GET"],
|
||||
)
|
||||
@@ -81,6 +82,11 @@ class SimpleGewechatClient:
|
||||
|
||||
self.shutdown_event = asyncio.Event()
|
||||
|
||||
self.staged_files = {}
|
||||
"""存储了允许外部访问的文件列表。auth_token: file_path。通过 register_file 方法注册。"""
|
||||
|
||||
self.lock = asyncio.Lock()
|
||||
|
||||
async def get_token_id(self):
|
||||
"""获取 Gewechat Token。"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
@@ -143,18 +149,25 @@ class SimpleGewechatClient:
|
||||
content = d["Content"]["string"] # 消息内容
|
||||
|
||||
at_me = False
|
||||
at_wxids = []
|
||||
if "@chatroom" in from_user_name:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
_t = content.split(":\n")
|
||||
user_id = _t[0]
|
||||
content = _t[1]
|
||||
# at
|
||||
msg_source = d["MsgSource"]
|
||||
if "\u2005" in content:
|
||||
# at
|
||||
# content = content.split('\u2005')[1]
|
||||
content = re.sub(r"@[^\u2005]*\u2005", "", content)
|
||||
at_wxids = re.findall(
|
||||
r"<atuserlist><!\[CDATA\[.*?(?:,|\b)([^,]+?)(?=,|\]\]></atuserlist>)",
|
||||
msg_source,
|
||||
)
|
||||
|
||||
abm.group_id = from_user_name
|
||||
# at
|
||||
msg_source = d["MsgSource"]
|
||||
|
||||
if (
|
||||
f"<atuserlist><![CDATA[,{abm.self_id}]]>" in msg_source
|
||||
or f"<atuserlist><![CDATA[{abm.self_id}]]>" in msg_source
|
||||
@@ -167,13 +180,12 @@ class SimpleGewechatClient:
|
||||
user_id = from_user_name
|
||||
|
||||
# 检查消息是否由自己发送,若是则忽略
|
||||
if user_id == abm.self_id:
|
||||
logger.info("忽略自己发送的消息")
|
||||
return None
|
||||
# 已经有可配置项专门配置是否需要响应自己的消息,因此这里注释掉。
|
||||
# if user_id == abm.self_id:
|
||||
# logger.info("忽略自己发送的消息")
|
||||
# return None
|
||||
|
||||
abm.message = []
|
||||
if at_me:
|
||||
abm.message.insert(0, At(qq=abm.self_id))
|
||||
|
||||
# 解析用户真实名字
|
||||
user_real_name = "unknown"
|
||||
@@ -197,7 +209,19 @@ class SimpleGewechatClient:
|
||||
else:
|
||||
user_real_name = self.userrealnames[abm.group_id][user_id]
|
||||
else:
|
||||
user_real_name = d.get("PushContent", "unknown : ").split(" : ")[0]
|
||||
try:
|
||||
info = (await self.get_user_or_group_info(user_id))["data"][0]
|
||||
user_real_name = info["nickName"]
|
||||
except Exception as e:
|
||||
logger.debug(f"获取用户 {user_id} 昵称失败: {e}")
|
||||
user_real_name = user_id
|
||||
|
||||
if at_me:
|
||||
abm.message.insert(0, At(qq=abm.self_id, name=self.nickname))
|
||||
for wxid in at_wxids:
|
||||
# 群聊里 At 其他人的列表
|
||||
_username = self.userrealnames.get(abm.group_id, {}).get(wxid, wxid)
|
||||
abm.message.append(At(qq=wxid, name=_username))
|
||||
|
||||
abm.sender = MessageMember(user_id, user_real_name)
|
||||
abm.raw_message = d
|
||||
@@ -248,9 +272,12 @@ class SimpleGewechatClient:
|
||||
logger.info("消息类型(48):地理位置")
|
||||
case 49: # 公众号/文件/小程序/引用/转账/红包/视频号/群聊邀请
|
||||
data_parser = GeweDataParser(content, abm.group_id == "")
|
||||
abm_data = data_parser.parse_mutil_49()
|
||||
if abm_data:
|
||||
abm.message.append(abm_data)
|
||||
segments = data_parser.parse_mutil_49()
|
||||
if segments:
|
||||
abm.message.extend(segments)
|
||||
for seg in segments:
|
||||
if isinstance(seg, Plain):
|
||||
abm.message_str += seg.text
|
||||
case 51: # 帐号消息同步?
|
||||
logger.info("消息类型(51):帐号消息同步?")
|
||||
case 10000: # 被踢出群聊/更换群主/修改群名称
|
||||
@@ -289,9 +316,33 @@ class SimpleGewechatClient:
|
||||
|
||||
return quart.jsonify({"r": "AstrBot ACK"})
|
||||
|
||||
async def _handle_file(self, file_id):
|
||||
file_path = f"data/temp/{file_id}"
|
||||
return await quart.send_file(file_path)
|
||||
async def _register_file(self, file_path: str) -> str:
|
||||
"""向 AstrBot 回调服务器 注册一个允许外部访问的文件。
|
||||
|
||||
Args:
|
||||
file_path (str): 文件路径。
|
||||
Returns:
|
||||
str: 返回一个 auth_token,文件路径为 file_path。通过 /astrbot-gewechat/file/auth_token 得到文件。
|
||||
"""
|
||||
async with self.lock:
|
||||
if not os.path.exists(file_path):
|
||||
raise Exception(f"文件不存在: {file_path}")
|
||||
|
||||
file_token = str(uuid.uuid4())
|
||||
self.staged_files[file_token] = file_path
|
||||
return file_token
|
||||
|
||||
async def _handle_file(self, file_token):
|
||||
async with self.lock:
|
||||
if file_token not in self.staged_files:
|
||||
logger.warning(f"请求的文件 {file_token} 不存在。")
|
||||
return quart.abort(404)
|
||||
if not os.path.exists(self.staged_files[file_token]):
|
||||
logger.warning(f"请求的文件 {self.staged_files[file_token]} 不存在。")
|
||||
return quart.abort(404)
|
||||
file_path = self.staged_files[file_token]
|
||||
self.staged_files.pop(file_token, None)
|
||||
return await quart.send_file(file_path)
|
||||
|
||||
async def _set_callback_url(self):
|
||||
logger.info("设置回调,请等待...")
|
||||
@@ -441,17 +492,18 @@ class SimpleGewechatClient:
|
||||
"此次登录需要安全验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456"
|
||||
)
|
||||
else:
|
||||
status = json_blob["data"]["status"]
|
||||
nickname = json_blob["data"].get("nickName", "")
|
||||
if status == 1:
|
||||
logger.info(f"等待确认...{nickname}")
|
||||
elif status == 2:
|
||||
logger.info(f"绿泡泡平台登录成功: {nickname}")
|
||||
break
|
||||
elif status == 0:
|
||||
logger.info("等待扫码...")
|
||||
else:
|
||||
logger.warning(f"未知状态: {status}")
|
||||
if "status" in json_blob["data"]:
|
||||
status = json_blob["data"]["status"]
|
||||
nickname = json_blob["data"].get("nickName", "")
|
||||
if status == 1:
|
||||
logger.info(f"等待确认...{nickname}")
|
||||
elif status == 2:
|
||||
logger.info(f"绿泡泡平台登录成功: {nickname}")
|
||||
break
|
||||
elif status == 0:
|
||||
logger.info("等待扫码...")
|
||||
else:
|
||||
logger.warning(f"未知状态: {status}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
if appid:
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
import asyncio
|
||||
import re
|
||||
import wave
|
||||
import uuid
|
||||
import traceback
|
||||
import os
|
||||
|
||||
from typing import AsyncGenerator
|
||||
from astrbot.core.utils.io import save_temp_img, download_file
|
||||
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
|
||||
from astrbot.api import logger
|
||||
@@ -80,15 +83,9 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
|
||||
elif isinstance(comp, Image):
|
||||
img_path = await comp.convert_to_file_path()
|
||||
|
||||
# 检查 record_path 是否在 data/temp 目录中
|
||||
temp_directory = os.path.abspath("data/temp")
|
||||
if os.path.commonpath([temp_directory, img_path]) != temp_directory:
|
||||
with open(img_path, "rb") as f:
|
||||
img_path = save_temp_img(f.read())
|
||||
|
||||
file_id = os.path.basename(img_path)
|
||||
img_url = f"{client.file_server_url}/{file_id}"
|
||||
# 为了安全,向 AstrBot 回调服务注册可被 gewechat 访问的文件,并获得文件 token
|
||||
token = await client._register_file(img_path)
|
||||
img_url = f"{client.file_server_url}/{token}"
|
||||
logger.debug(f"gewe callback img url: {img_url}")
|
||||
await client.post_image(to_wxid, img_url)
|
||||
elif isinstance(comp, Video):
|
||||
@@ -107,20 +104,29 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
|
||||
video_url = comp.file
|
||||
# 根据 url 下载视频
|
||||
video_filename = f"{uuid.uuid4()}.mp4"
|
||||
video_path = f"data/temp/{video_filename}"
|
||||
await download_file(video_url, video_path)
|
||||
if video_url.startswith("http"):
|
||||
video_filename = f"{uuid.uuid4()}.mp4"
|
||||
video_path = f"data/temp/{video_filename}"
|
||||
await download_file(video_url, video_path)
|
||||
else:
|
||||
video_path = video_url
|
||||
|
||||
video_token = await client._register_file(video_path)
|
||||
video_callback_url = f"{client.file_server_url}/{video_token}"
|
||||
|
||||
# 获取视频第一帧
|
||||
thumb_path = f"data/temp/{uuid.uuid4()}.jpg"
|
||||
thumb_path = f"data/temp/gewechat_video_thumb_{uuid.uuid4()}.jpg"
|
||||
|
||||
video_path = video_path.replace(" ", "\\ ")
|
||||
try:
|
||||
ff = FFmpeg()
|
||||
command = f'-i "{video_path}" -ss 0 -vframes 1 "{thumb_path}"'
|
||||
command = f"-i {video_path} -ss 0 -vframes 1 {thumb_path}"
|
||||
ff.options(command)
|
||||
thumb_file_id = os.path.basename(thumb_path)
|
||||
thumb_url = f"{client.file_server_url}/{thumb_file_id}"
|
||||
thumb_token = await client._register_file(thumb_path)
|
||||
thumb_url = f"{client.file_server_url}/{thumb_token}"
|
||||
except Exception as e:
|
||||
logger.error(f"获取视频第一帧失败: {e}")
|
||||
|
||||
# 获取视频时长
|
||||
try:
|
||||
from pyffmpeg import FFprobe
|
||||
@@ -135,15 +141,12 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
logger.error(f"获取时长失败: {e}")
|
||||
video_duration = 10
|
||||
|
||||
file_id = os.path.basename(video_path)
|
||||
video_url = f"{client.file_server_url}/{file_id}"
|
||||
# 发送视频
|
||||
await client.post_video(
|
||||
to_wxid, video_url, thumb_url, video_duration
|
||||
to_wxid, video_callback_url, thumb_url, video_duration
|
||||
)
|
||||
|
||||
# 删除临时视频和缩略图文件
|
||||
if os.path.exists(video_path):
|
||||
os.remove(video_path)
|
||||
# 删除临时缩略图文件
|
||||
if os.path.exists(thumb_path):
|
||||
os.remove(thumb_path)
|
||||
elif isinstance(comp, Record):
|
||||
@@ -160,8 +163,8 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
logger.info("Silk 语音文件格式转换至: " + record_path)
|
||||
if duration == 0:
|
||||
duration = get_wav_duration(record_path)
|
||||
file_id = os.path.basename(silk_path)
|
||||
record_url = f"{client.file_server_url}/{file_id}"
|
||||
token = await client._register_file(silk_path)
|
||||
record_url = f"{client.file_server_url}/{token}"
|
||||
logger.debug(f"gewe callback record url: {record_url}")
|
||||
await client.post_voice(to_wxid, record_url, duration * 1000)
|
||||
elif isinstance(comp, File):
|
||||
@@ -174,10 +177,10 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
else:
|
||||
file_path = file_path
|
||||
|
||||
file_id = os.path.basename(file_path)
|
||||
file_url = f"{client.file_server_url}/{file_id}"
|
||||
token = await client._register_file(file_path)
|
||||
file_url = f"{client.file_server_url}/{token}"
|
||||
logger.debug(f"gewe callback file url: {file_url}")
|
||||
await client.post_file(to_wxid, file_url, file_id)
|
||||
await client.post_file(to_wxid, file_url, file_name)
|
||||
elif isinstance(comp, Emoji):
|
||||
await client.post_emoji(to_wxid, comp.md5, comp.md5_len, comp.cdnurl)
|
||||
elif isinstance(comp, At):
|
||||
@@ -216,3 +219,37 @@ class GewechatPlatformEvent(AstrMessageEvent):
|
||||
group_owner=data.get("chatRoomOwner"),
|
||||
members=members,
|
||||
)
|
||||
|
||||
async def send_streaming(
|
||||
self, generator: AsyncGenerator, use_fallback: bool = False
|
||||
):
|
||||
if not use_fallback:
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
buffer = ""
|
||||
pattern = re.compile(r"[^。?!~…]+[。?!~…]+")
|
||||
|
||||
async for chain in generator:
|
||||
if isinstance(chain, MessageChain):
|
||||
for comp in chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
buffer += comp.text
|
||||
if any(p in buffer for p in "。?!~…"):
|
||||
buffer = await self.process_buffer(buffer, pattern)
|
||||
else:
|
||||
await self.send(MessageChain(chain=[comp]))
|
||||
await asyncio.sleep(1.5) # 限速
|
||||
|
||||
if buffer.strip():
|
||||
await self.send(MessageChain([Plain(buffer)]))
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
@@ -60,13 +60,17 @@ class GewechatPlatformAdapter(Platform):
|
||||
@override
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"gewechat",
|
||||
"基于 gewechat 的 Wechat 适配器",
|
||||
name="gewechat",
|
||||
description="基于 gewechat 的 Wechat 适配器",
|
||||
id=self.config.get("id"),
|
||||
)
|
||||
|
||||
async def terminate(self):
|
||||
self.client.shutdown_event.set()
|
||||
await self.client.server.shutdown()
|
||||
try:
|
||||
await self.client.server.shutdown()
|
||||
except Exception as _:
|
||||
pass
|
||||
logger.info("Gewechat 适配器已被优雅地关闭。")
|
||||
|
||||
async def logout(self):
|
||||
|
||||
@@ -1,6 +1,11 @@
|
||||
from defusedxml import ElementTree as eT
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.message_components import WechatEmoji as Emoji, Reply, Plain
|
||||
from astrbot.api.message_components import (
|
||||
WechatEmoji as Emoji,
|
||||
Reply,
|
||||
Plain,
|
||||
BaseMessageComponent,
|
||||
)
|
||||
|
||||
|
||||
class GeweDataParser:
|
||||
@@ -11,7 +16,7 @@ class GeweDataParser:
|
||||
def _format_to_xml(self):
|
||||
return eT.fromstring(self.data)
|
||||
|
||||
def parse_mutil_49(self):
|
||||
def parse_mutil_49(self) -> list[BaseMessageComponent] | None:
|
||||
appmsg_type = self._format_to_xml().find(".//appmsg/type")
|
||||
if appmsg_type is None:
|
||||
return
|
||||
@@ -34,13 +39,18 @@ class GeweDataParser:
|
||||
except Exception as e:
|
||||
logger.error(f"gewechat: parse_emoji failed, {e}")
|
||||
|
||||
def parse_reply(self) -> Reply | None:
|
||||
def parse_reply(self) -> list[Reply, Plain] | None:
|
||||
"""解析引用消息
|
||||
|
||||
Returns:
|
||||
list[Reply, Plain]: 一个包含两个元素的列表。Reply 消息对象和引用者说的文本内容。微信平台下引用消息时只能发送文本消息。
|
||||
"""
|
||||
try:
|
||||
replied_id = -1
|
||||
replied_uid = 0
|
||||
replied_nickname = ""
|
||||
replied_content = ""
|
||||
content = ""
|
||||
replied_content = "" # 被引用者说的内容
|
||||
content = "" # 引用者说的内容
|
||||
|
||||
root = self._format_to_xml()
|
||||
refermsg = root.find(".//refermsg")
|
||||
@@ -57,22 +67,44 @@ class GeweDataParser:
|
||||
if displayname is not None:
|
||||
replied_nickname = displayname.text
|
||||
if refermsg_content is not None:
|
||||
replied_content = refermsg_content.text
|
||||
# 处理引用嵌套,包括嵌套公众号消息
|
||||
if refermsg_content.text.startswith(
|
||||
"<msg>"
|
||||
) or refermsg_content.text.startswith("<?xml"):
|
||||
try:
|
||||
logger.debug("gewechat: Reference message is nested")
|
||||
refer_root = eT.fromstring(refermsg_content.text)
|
||||
img = refer_root.find("img")
|
||||
if img is not None:
|
||||
replied_content = "[图片]"
|
||||
else:
|
||||
app_msg = refer_root.find("appmsg")
|
||||
refermsg_content_title = app_msg.find("title")
|
||||
logger.debug(
|
||||
f"gewechat: Reference message nesting: {refermsg_content_title.text}"
|
||||
)
|
||||
replied_content = refermsg_content_title.text
|
||||
except Exception as e:
|
||||
logger.error(f"gewechat: nested failed, {e}")
|
||||
# 处理异常情况
|
||||
replied_content = refermsg_content.text
|
||||
else:
|
||||
replied_content = refermsg_content.text
|
||||
|
||||
# 提取引用者说的内容
|
||||
title = root.find(".//appmsg/title")
|
||||
if title is not None:
|
||||
content = title.text
|
||||
|
||||
r = Reply(
|
||||
reply_seg = Reply(
|
||||
id=replied_id,
|
||||
chain=[Plain(content)],
|
||||
chain=[Plain(replied_content)],
|
||||
sender_id=replied_uid,
|
||||
sender_nickname=replied_nickname,
|
||||
sender_str=replied_content,
|
||||
message_str=content,
|
||||
message_str=replied_content,
|
||||
)
|
||||
return r
|
||||
plain_seg = Plain(content)
|
||||
return [reply_seg, plain_seg]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"gewechat: parse_reply failed, {e}")
|
||||
|
||||
@@ -2,6 +2,7 @@ import base64
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
import astrbot.api.message_components as Comp
|
||||
|
||||
from astrbot.api.platform import (
|
||||
@@ -66,12 +67,47 @@ class LarkPlatformAdapter(Platform):
|
||||
async def send_by_session(
|
||||
self, session: MessageSesion, message_chain: MessageChain
|
||||
):
|
||||
raise NotImplementedError("Lark 适配器不支持 send_by_session")
|
||||
res = await LarkMessageEvent._convert_to_lark(message_chain, self.lark_api)
|
||||
wrapped = {
|
||||
"zh_cn": {
|
||||
"title": "",
|
||||
"content": res,
|
||||
}
|
||||
}
|
||||
|
||||
if session.message_type == MessageType.GROUP_MESSAGE:
|
||||
id_type = "chat_id"
|
||||
if "%" in session.session_id:
|
||||
session.session_id = session.session_id.split("%")[1]
|
||||
else:
|
||||
id_type = "open_id"
|
||||
|
||||
request = (
|
||||
CreateMessageRequest.builder()
|
||||
.receive_id_type(id_type)
|
||||
.request_body(
|
||||
CreateMessageRequestBody.builder()
|
||||
.receive_id(session.session_id)
|
||||
.content(json.dumps(wrapped))
|
||||
.msg_type("post")
|
||||
.uuid(str(uuid.uuid4()))
|
||||
.build()
|
||||
)
|
||||
.build()
|
||||
)
|
||||
|
||||
response = await self.lark_api.im.v1.message.acreate(request)
|
||||
|
||||
if not response.success():
|
||||
logger.error(f"发送飞书消息失败({response.code}): {response.msg}")
|
||||
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"lark",
|
||||
"飞书机器人官方 API 适配器",
|
||||
name="lark",
|
||||
description="飞书机器人官方 API 适配器",
|
||||
id=self.config.get("id"),
|
||||
)
|
||||
|
||||
async def convert_msg(self, event: lark.im.v1.P2ImMessageReceiveV1):
|
||||
@@ -165,7 +201,10 @@ class LarkPlatformAdapter(Platform):
|
||||
else:
|
||||
abm.session_id = abm.sender.user_id
|
||||
else:
|
||||
abm.session_id = abm.sender.user_id
|
||||
if abm.type == MessageType.GROUP_MESSAGE:
|
||||
abm.session_id = f"{abm.sender.user_id}%{abm.group_id}" # 也保留群组id
|
||||
else:
|
||||
abm.session_id = abm.sender.user_id
|
||||
|
||||
logger.debug(abm)
|
||||
await self.handle_msg(abm)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import json
|
||||
import uuid
|
||||
import base64
|
||||
import lark_oapi as lark
|
||||
from io import BytesIO
|
||||
from typing import List
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.message_components import Plain, Image as AstrBotImage, At
|
||||
@@ -27,22 +29,32 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
_stage.append({"tag": "at", "user_id": comp.qq, "style": []})
|
||||
elif isinstance(comp, AstrBotImage):
|
||||
file_path = ""
|
||||
image_file = None
|
||||
|
||||
if comp.file and comp.file.startswith("file:///"):
|
||||
file_path = comp.file.replace("file:///", "")
|
||||
elif comp.file and comp.file.startswith("http"):
|
||||
image_file_path = await download_image_by_url(comp.file)
|
||||
file_path = image_file_path
|
||||
elif comp.file and comp.file.startswith("base64://"):
|
||||
pass
|
||||
base64_str = comp.file.removeprefix("base64://")
|
||||
image_data = base64.b64decode(base64_str)
|
||||
# save as temp file
|
||||
file_path = f"data/temp/{uuid.uuid4()}_test.jpg"
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(BytesIO(image_data).getvalue())
|
||||
else:
|
||||
file_path = comp.file
|
||||
|
||||
if image_file is None:
|
||||
image_file = open(file_path, "rb")
|
||||
|
||||
request = (
|
||||
CreateImageRequest.builder()
|
||||
.request_body(
|
||||
CreateImageRequestBody.builder()
|
||||
.image_type("message")
|
||||
.image(open(file_path, "rb"))
|
||||
.image(image_file)
|
||||
.build()
|
||||
)
|
||||
.build()
|
||||
@@ -51,7 +63,7 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
if not response.success():
|
||||
logger.error(f"无法上传飞书图片({response.code}): {response.msg}")
|
||||
image_key = response.data.image_key
|
||||
print(image_key)
|
||||
logger.debug(image_key)
|
||||
ret.append(_stage)
|
||||
ret.append([{"tag": "img", "image_key": image_key}])
|
||||
_stage.clear()
|
||||
@@ -91,3 +103,16 @@ class LarkMessageEvent(AstrMessageEvent):
|
||||
logger.error(f"回复飞书消息失败({response.code}): {response.msg}")
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
@@ -2,6 +2,7 @@ import botpy
|
||||
import botpy.message
|
||||
import botpy.types
|
||||
import botpy.types.message
|
||||
import asyncio
|
||||
from astrbot.core.utils.io import file_to_base64, download_image_by_url
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
@@ -9,6 +10,8 @@ from astrbot.api.message_components import Plain, Image
|
||||
from botpy import Client
|
||||
from botpy.http import Route
|
||||
from astrbot.api import logger
|
||||
from botpy.types import message
|
||||
import random
|
||||
|
||||
|
||||
class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
@@ -30,8 +33,45 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
else:
|
||||
self.send_buffer.chain.extend(message.chain)
|
||||
|
||||
async def _post_send(self):
|
||||
"""QQ 官方 API 仅支持回复一次"""
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
"""流式输出仅支持消息列表私聊"""
|
||||
stream_payload = {"state": 1, "id": None, "index": 0, "reset": False}
|
||||
last_edit_time = 0 # 上次编辑消息的时间
|
||||
throttle_interval = 1 # 编辑消息的间隔时间 (秒)
|
||||
try:
|
||||
async for chain in generator:
|
||||
source = self.message_obj.raw_message
|
||||
if not self.send_buffer:
|
||||
self.send_buffer = chain
|
||||
else:
|
||||
self.send_buffer.chain.extend(chain.chain)
|
||||
|
||||
if isinstance(source, botpy.message.C2CMessage):
|
||||
# 真流式传输
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
time_since_last_edit = current_time - last_edit_time
|
||||
|
||||
if time_since_last_edit >= throttle_interval:
|
||||
ret = await self._post_send(stream=stream_payload)
|
||||
stream_payload["index"] += 1
|
||||
stream_payload["id"] = ret["id"]
|
||||
last_edit_time = asyncio.get_event_loop().time()
|
||||
|
||||
if isinstance(source, botpy.message.C2CMessage):
|
||||
# 结束流式对话,并且传输 buffer 中剩余的消息
|
||||
stream_payload["state"] = 10
|
||||
ret = await self._post_send(stream=stream_payload)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送流式消息时出错: {e}", exc_info=True)
|
||||
self.send_buffer = None
|
||||
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
async def _post_send(self, stream: dict = None):
|
||||
if not self.send_buffer:
|
||||
return
|
||||
|
||||
source = self.message_obj.raw_message
|
||||
assert isinstance(
|
||||
source,
|
||||
@@ -57,6 +97,9 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
"msg_id": self.message_obj.message_id,
|
||||
}
|
||||
|
||||
if not isinstance(source, (botpy.message.Message, botpy.message.DirectMessage)):
|
||||
payload["msg_seq"] = random.randint(1, 10000)
|
||||
|
||||
match type(source):
|
||||
case botpy.message.GroupMessage:
|
||||
if image_base64:
|
||||
@@ -65,7 +108,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
payload["media"] = media
|
||||
payload["msg_type"] = 7
|
||||
await self.bot.api.post_group_message(
|
||||
ret = await self.bot.api.post_group_message(
|
||||
group_openid=source.group_openid, **payload
|
||||
)
|
||||
case botpy.message.C2CMessage:
|
||||
@@ -75,22 +118,34 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
payload["media"] = media
|
||||
payload["msg_type"] = 7
|
||||
await self.bot.api.post_c2c_message(
|
||||
openid=source.author.user_openid, **payload
|
||||
)
|
||||
if stream:
|
||||
ret = await self.post_c2c_message(
|
||||
openid=source.author.user_openid,
|
||||
**payload,
|
||||
stream=stream,
|
||||
)
|
||||
else:
|
||||
ret = await self.post_c2c_message(
|
||||
openid=source.author.user_openid, **payload
|
||||
)
|
||||
logger.debug(f"Message sent to C2C: {ret}")
|
||||
case botpy.message.Message:
|
||||
if image_path:
|
||||
payload["file_image"] = image_path
|
||||
await self.bot.api.post_message(channel_id=source.channel_id, **payload)
|
||||
ret = await self.bot.api.post_message(
|
||||
channel_id=source.channel_id, **payload
|
||||
)
|
||||
case botpy.message.DirectMessage:
|
||||
if image_path:
|
||||
payload["file_image"] = image_path
|
||||
await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
|
||||
ret = await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
|
||||
|
||||
await super().send(self.send_buffer)
|
||||
|
||||
self.send_buffer = None
|
||||
|
||||
return ret
|
||||
|
||||
async def upload_group_and_c2c_image(
|
||||
self, image_base64: str, file_type: int, **kwargs
|
||||
) -> botpy.types.message.Media:
|
||||
@@ -112,6 +167,27 @@ class QQOfficialMessageEvent(AstrMessageEvent):
|
||||
)
|
||||
return await self.bot.api._http.request(route, json=payload)
|
||||
|
||||
async def post_c2c_message(
|
||||
self,
|
||||
openid: str,
|
||||
msg_type: int = 0,
|
||||
content: str = None,
|
||||
embed: message.Embed = None,
|
||||
ark: message.Ark = None,
|
||||
message_reference: message.Reference = None,
|
||||
media: message.Media = None,
|
||||
msg_id: str = None,
|
||||
msg_seq: str = 1,
|
||||
event_id: str = None,
|
||||
markdown: message.MarkdownPayload = None,
|
||||
keyboard: message.Keyboard = None,
|
||||
stream: dict = None,
|
||||
) -> message.Message:
|
||||
payload = locals()
|
||||
payload.pop("self", None)
|
||||
route = Route("POST", "/v2/users/{openid}/messages", openid=openid)
|
||||
return await self.bot.api._http.request(route, json=payload)
|
||||
|
||||
@staticmethod
|
||||
async def _parse_to_qqofficial(message: MessageChain):
|
||||
plain_text = ""
|
||||
|
||||
@@ -126,8 +126,9 @@ class QQOfficialPlatformAdapter(Platform):
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"qq_official",
|
||||
"QQ 机器人官方 API 适配器",
|
||||
name="qq_official",
|
||||
description="QQ 机器人官方 API 适配器",
|
||||
id=self.config.get("id"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -99,8 +99,9 @@ class QQOfficialWebhookPlatformAdapter(Platform):
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"qq_official_webhook",
|
||||
"QQ 机器人官方 API 适配器",
|
||||
name="qq_official_webhook",
|
||||
description="QQ 机器人官方 API 适配器",
|
||||
id=self.config.get("id"),
|
||||
)
|
||||
|
||||
async def run(self):
|
||||
@@ -116,5 +117,8 @@ class QQOfficialWebhookPlatformAdapter(Platform):
|
||||
async def terminate(self):
|
||||
self.webhook_helper.shutdown_event.set()
|
||||
await self.client.close()
|
||||
await self.webhook_helper.server.shutdown()
|
||||
try:
|
||||
await self.webhook_helper.server.shutdown()
|
||||
except Exception as _:
|
||||
pass
|
||||
logger.info("QQ 机器人官方 API 适配器已经被优雅地关闭")
|
||||
|
||||
@@ -1,26 +1,32 @@
|
||||
import asyncio
|
||||
import re
|
||||
import sys
|
||||
import uuid
|
||||
import asyncio
|
||||
import astrbot.api.message_components as Comp
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from telegram import BotCommand, Update
|
||||
from telegram.constants import ChatType
|
||||
from telegram.ext import ApplicationBuilder, ContextTypes, ExtBot, filters
|
||||
from telegram.ext import MessageHandler as TelegramMessageHandler
|
||||
|
||||
import astrbot.api.message_components as Comp
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api.platform import (
|
||||
Platform,
|
||||
AstrBotMessage,
|
||||
MessageMember,
|
||||
PlatformMetadata,
|
||||
MessageType,
|
||||
Platform,
|
||||
PlatformMetadata,
|
||||
register_platform_adapter,
|
||||
)
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.api.platform import register_platform_adapter
|
||||
from astrbot.core.star.filter.command import CommandFilter
|
||||
from astrbot.core.star.filter.command_group import CommandGroupFilter
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.star_handler import star_handlers_registry
|
||||
|
||||
from telegram import Update
|
||||
from telegram.ext import ApplicationBuilder, ContextTypes, filters
|
||||
from telegram.constants import ChatType
|
||||
from telegram.ext import MessageHandler as TelegramMessageHandler
|
||||
from .tg_event import TelegramPlatformEvent
|
||||
from astrbot.api import logger
|
||||
from telegram.ext import ExtBot
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
@@ -52,6 +58,14 @@ class TelegramPlatformAdapter(Platform):
|
||||
|
||||
self.base_url = base_url
|
||||
|
||||
self.enable_command_register = self.config.get(
|
||||
"telegram_command_register", True
|
||||
)
|
||||
self.enable_command_refresh = self.config.get(
|
||||
"telegram_command_auto_refresh", True
|
||||
)
|
||||
self.last_command_hash = None
|
||||
|
||||
self.application = (
|
||||
ApplicationBuilder()
|
||||
.token(self.config["telegram_token"])
|
||||
@@ -67,6 +81,8 @@ class TelegramPlatformAdapter(Platform):
|
||||
self.client = self.application.bot
|
||||
logger.debug(f"Telegram base url: {self.client.base_url}")
|
||||
|
||||
self.scheduler = AsyncIOScheduler()
|
||||
|
||||
@override
|
||||
async def send_by_session(
|
||||
self, session: MessageSesion, message_chain: MessageChain
|
||||
@@ -80,18 +96,104 @@ class TelegramPlatformAdapter(Platform):
|
||||
@override
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"telegram",
|
||||
"telegram 适配器",
|
||||
name="telegram", description="telegram 适配器", id=self.config.get("id")
|
||||
)
|
||||
|
||||
@override
|
||||
async def run(self):
|
||||
await self.application.initialize()
|
||||
await self.application.start()
|
||||
|
||||
if self.enable_command_register:
|
||||
await self.register_commands()
|
||||
|
||||
if self.enable_command_refresh and self.enable_command_register:
|
||||
self.scheduler.add_job(
|
||||
self.register_commands,
|
||||
"interval",
|
||||
seconds=self.config.get("telegram_command_register_interval", 300),
|
||||
id="telegram_command_register",
|
||||
misfire_grace_time=60,
|
||||
)
|
||||
self.scheduler.start()
|
||||
|
||||
queue = self.application.updater.start_polling()
|
||||
logger.info("Telegram Platform Adapter is running.")
|
||||
await queue
|
||||
|
||||
async def register_commands(self):
|
||||
"""收集所有注册的指令并注册到 Telegram"""
|
||||
try:
|
||||
commands = self.collect_commands()
|
||||
|
||||
if commands:
|
||||
current_hash = hash(
|
||||
tuple((cmd.command, cmd.description) for cmd in commands)
|
||||
)
|
||||
if current_hash == self.last_command_hash:
|
||||
return
|
||||
self.last_command_hash = current_hash
|
||||
await self.client.delete_my_commands()
|
||||
await self.client.set_my_commands(commands)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"向 Telegram 注册指令时发生错误: {e!s}")
|
||||
|
||||
def collect_commands(self) -> list[BotCommand]:
|
||||
"""从注册的处理器中收集所有指令"""
|
||||
command_dict = {}
|
||||
skip_commands = {"start"}
|
||||
|
||||
for handler_md in star_handlers_registry._handlers:
|
||||
handler_metadata = handler_md[1]
|
||||
if not star_map[handler_metadata.handler_module_path].activated:
|
||||
continue
|
||||
for event_filter in handler_metadata.event_filters:
|
||||
cmd_info = self._extract_command_info(
|
||||
event_filter, handler_metadata, skip_commands
|
||||
)
|
||||
if cmd_info:
|
||||
cmd_name, description = cmd_info
|
||||
command_dict.setdefault(cmd_name, description)
|
||||
|
||||
commands_a = sorted(command_dict.keys())
|
||||
return [BotCommand(cmd, command_dict[cmd]) for cmd in commands_a]
|
||||
|
||||
@staticmethod
|
||||
def _extract_command_info(
|
||||
event_filter, handler_metadata, skip_commands: set
|
||||
) -> tuple[str, str] | None:
|
||||
"""从事件过滤器中提取指令信息"""
|
||||
cmd_name = None
|
||||
is_group = False
|
||||
if isinstance(event_filter, CommandFilter) and event_filter.command_name:
|
||||
if (
|
||||
event_filter.parent_command_names
|
||||
and event_filter.parent_command_names != [""]
|
||||
):
|
||||
return None
|
||||
cmd_name = event_filter.command_name
|
||||
elif isinstance(event_filter, CommandGroupFilter):
|
||||
if event_filter.parent_group:
|
||||
return None
|
||||
cmd_name = event_filter.group_name
|
||||
is_group = True
|
||||
|
||||
if not cmd_name or cmd_name in skip_commands:
|
||||
return None
|
||||
|
||||
if not re.match(r"^[a-z0-9_]+$", cmd_name) or len(cmd_name) > 32:
|
||||
logger.debug(f"跳过无法注册的命令: {cmd_name}")
|
||||
return None
|
||||
|
||||
# Build description.
|
||||
description = handler_metadata.desc or (
|
||||
f"指令组: {cmd_name} (包含多个子指令)" if is_group else f"指令: {cmd_name}"
|
||||
)
|
||||
if len(description) > 30:
|
||||
description = description[:30] + "..."
|
||||
return cmd_name, description
|
||||
|
||||
async def start(self, update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
await context.bot.send_message(
|
||||
chat_id=update.effective_chat.id, text=self.config["start_message"]
|
||||
@@ -163,6 +265,16 @@ class TelegramPlatformAdapter(Platform):
|
||||
# 处理文本消息
|
||||
plain_text = update.message.text
|
||||
|
||||
# 群聊场景命令特殊处理
|
||||
if plain_text.startswith("/"):
|
||||
command_parts = plain_text.split(" ", 1)
|
||||
if "@" in command_parts[0]:
|
||||
command, bot_name = command_parts[0].split("@")
|
||||
if bot_name == self.client.username:
|
||||
plain_text = command + (
|
||||
f" {command_parts[1]}" if len(command_parts) > 1 else ""
|
||||
)
|
||||
|
||||
if update.message.entities:
|
||||
for entity in update.message.entities:
|
||||
if entity.type == "mention":
|
||||
@@ -242,8 +354,14 @@ class TelegramPlatformAdapter(Platform):
|
||||
|
||||
async def terminate(self):
|
||||
try:
|
||||
if self.scheduler.running:
|
||||
self.scheduler.shutdown()
|
||||
|
||||
await self.application.stop()
|
||||
|
||||
if self.enable_command_register:
|
||||
await self.client.delete_my_commands()
|
||||
|
||||
# 保险起见先判断是否存在updater对象
|
||||
if self.application.updater is not None:
|
||||
await self.application.updater.stop()
|
||||
|
||||
@@ -1,7 +1,15 @@
|
||||
import asyncio
|
||||
import telegramify_markdown
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata, MessageType
|
||||
from astrbot.api.message_components import Plain, Image, Reply, At, File, Record
|
||||
from astrbot.api.message_components import (
|
||||
Plain,
|
||||
Image,
|
||||
Reply,
|
||||
At,
|
||||
File,
|
||||
Record,
|
||||
)
|
||||
from telegram.ext import ExtBot
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot import logger
|
||||
@@ -82,3 +90,107 @@ class TelegramPlatformEvent(AstrMessageEvent):
|
||||
else:
|
||||
await self.send_with_client(self.client, message, self.get_sender_id())
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
message_thread_id = None
|
||||
|
||||
if self.get_message_type() == MessageType.GROUP_MESSAGE:
|
||||
user_name = self.message_obj.group_id
|
||||
else:
|
||||
user_name = self.get_sender_id()
|
||||
|
||||
if "#" in user_name:
|
||||
# it's a supergroup chat with message_thread_id
|
||||
user_name, message_thread_id = user_name.split("#")
|
||||
payload = {
|
||||
"chat_id": user_name,
|
||||
}
|
||||
if message_thread_id:
|
||||
payload["reply_to_message_id"] = message_thread_id
|
||||
|
||||
delta = ""
|
||||
current_content = ""
|
||||
message_id = None
|
||||
last_edit_time = 0 # 上次编辑消息的时间
|
||||
throttle_interval = 0.6 # 编辑消息的间隔时间 (秒)
|
||||
|
||||
async for chain in generator:
|
||||
if isinstance(chain, MessageChain):
|
||||
# 处理消息链中的每个组件
|
||||
for i in chain.chain:
|
||||
if isinstance(i, Plain):
|
||||
delta += i.text
|
||||
elif isinstance(i, Image):
|
||||
image_path = await i.convert_to_file_path()
|
||||
await self.client.send_photo(photo=image_path, **payload)
|
||||
continue
|
||||
elif isinstance(i, File):
|
||||
if i.file.startswith("https://"):
|
||||
path = "data/temp/" + i.name
|
||||
await download_file(i.file, path)
|
||||
i.file = path
|
||||
|
||||
await self.client.send_document(
|
||||
document=i.file, filename=i.name, **payload
|
||||
)
|
||||
continue
|
||||
elif isinstance(i, Record):
|
||||
path = await i.convert_to_file_path()
|
||||
await self.client.send_voice(voice=path, **payload)
|
||||
continue
|
||||
else:
|
||||
logger.warning(f"不支持的消息类型: {type(i)}")
|
||||
continue
|
||||
|
||||
# Plain
|
||||
if not message_id:
|
||||
try:
|
||||
msg = await self.client.send_message(text=delta, **payload)
|
||||
current_content = delta
|
||||
except Exception as e:
|
||||
logger.warning(f"发送消息失败(streaming): {e!s}")
|
||||
message_id = msg.message_id
|
||||
last_edit_time = (
|
||||
asyncio.get_event_loop().time()
|
||||
) # 记录初始消息发送时间
|
||||
else:
|
||||
current_time = asyncio.get_event_loop().time()
|
||||
time_since_last_edit = current_time - last_edit_time
|
||||
|
||||
# 如果距离上次编辑的时间 >= 设定的间隔,等待一段时间
|
||||
if time_since_last_edit >= throttle_interval:
|
||||
# 编辑消息
|
||||
try:
|
||||
await self.client.edit_message_text(
|
||||
text=delta,
|
||||
chat_id=payload["chat_id"],
|
||||
message_id=message_id,
|
||||
)
|
||||
current_content = delta
|
||||
except Exception as e:
|
||||
logger.warning(f"编辑消息失败(streaming): {e!s}")
|
||||
last_edit_time = (
|
||||
asyncio.get_event_loop().time()
|
||||
) # 更新上次编辑的时间
|
||||
|
||||
try:
|
||||
if delta and current_content != delta:
|
||||
try:
|
||||
markdown_text = telegramify_markdown.markdownify(
|
||||
delta, max_line_length=None, normalize_whitespace=False
|
||||
)
|
||||
await self.client.edit_message_text(
|
||||
text=markdown_text,
|
||||
chat_id=payload["chat_id"],
|
||||
message_id=message_id,
|
||||
parse_mode="MarkdownV2",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Markdown转换失败,使用普通文本: {e!s}")
|
||||
await self.client.edit_message_text(
|
||||
text=delta, chat_id=payload["chat_id"], message_id=message_id
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"编辑消息失败(streaming): {e!s}")
|
||||
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
@@ -43,8 +43,7 @@ class WebChatAdapter(Platform):
|
||||
self.imgs_dir = "data/webchat/imgs"
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
"webchat",
|
||||
"webchat",
|
||||
name="webchat", description="webchat", id=self.config.get("id")
|
||||
)
|
||||
|
||||
async def send_by_session(
|
||||
|
||||
@@ -16,16 +16,26 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
os.makedirs(imgs_dir, exist_ok=True)
|
||||
|
||||
@staticmethod
|
||||
async def _send(message: MessageChain, session_id: str):
|
||||
async def _send(message: MessageChain, session_id: str, streaming: bool = False):
|
||||
if not message:
|
||||
web_chat_back_queue.put_nowait(None)
|
||||
return
|
||||
await web_chat_back_queue.put(
|
||||
{"type": "end", "data": "", "streaming": False}
|
||||
)
|
||||
return ""
|
||||
|
||||
cid = session_id.split("!")[-1]
|
||||
|
||||
data = ""
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
web_chat_back_queue.put_nowait((comp.text, cid))
|
||||
data = comp.text
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "plain",
|
||||
"cid": cid,
|
||||
"data": data,
|
||||
"streaming": streaming,
|
||||
}
|
||||
)
|
||||
elif isinstance(comp, Image):
|
||||
# save image to local
|
||||
filename = str(uuid.uuid4()) + ".jpg"
|
||||
@@ -46,7 +56,15 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
with open(path, "wb") as f:
|
||||
with open(comp.file, "rb") as f2:
|
||||
f.write(f2.read())
|
||||
web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid))
|
||||
data = f"[IMAGE]{filename}"
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "image",
|
||||
"cid": cid,
|
||||
"data": data,
|
||||
"streaming": streaming,
|
||||
}
|
||||
)
|
||||
elif isinstance(comp, Record):
|
||||
# save record to local
|
||||
filename = str(uuid.uuid4()) + ".wav"
|
||||
@@ -62,11 +80,45 @@ class WebChatMessageEvent(AstrMessageEvent):
|
||||
with open(path, "wb") as f:
|
||||
with open(comp.file, "rb") as f2:
|
||||
f.write(f2.read())
|
||||
web_chat_back_queue.put_nowait((f"[RECORD]{filename}", cid))
|
||||
data = f"[RECORD]{filename}"
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "record",
|
||||
"cid": cid,
|
||||
"data": data,
|
||||
"streaming": streaming,
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.debug(f"webchat 忽略: {comp.type}")
|
||||
web_chat_back_queue.put_nowait(None)
|
||||
|
||||
return data
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
await WebChatMessageEvent._send(message, session_id=self.session_id)
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "end",
|
||||
"data": "",
|
||||
"streaming": False,
|
||||
"cid": self.session_id.split("!")[-1],
|
||||
}
|
||||
)
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
final_data = ""
|
||||
async for chain in generator:
|
||||
final_data += await WebChatMessageEvent._send(
|
||||
chain, session_id=self.session_id, streaming=True
|
||||
)
|
||||
|
||||
await web_chat_back_queue.put(
|
||||
{
|
||||
"type": "end",
|
||||
"data": final_data,
|
||||
"streaming": True,
|
||||
"cid": self.session_id.split("!")[-1],
|
||||
}
|
||||
)
|
||||
await super().send_streaming(generator, use_fallback)
|
||||
|
||||
@@ -2,6 +2,7 @@ import sys
|
||||
import uuid
|
||||
import asyncio
|
||||
import quart
|
||||
import aiohttp
|
||||
|
||||
from astrbot.api.platform import (
|
||||
Platform,
|
||||
@@ -20,10 +21,14 @@ from requests import Response
|
||||
from wechatpy.enterprise.crypto import WeChatCrypto
|
||||
from wechatpy.enterprise import WeChatClient
|
||||
from wechatpy.enterprise.messages import TextMessage, ImageMessage, VoiceMessage
|
||||
from wechatpy.messages import BaseMessage
|
||||
from wechatpy.exceptions import InvalidSignatureException
|
||||
from wechatpy.enterprise import parse_message
|
||||
from .wecom_event import WecomPlatformEvent
|
||||
|
||||
from .wecom_kf import WeChatKF
|
||||
from .wecom_kf_message import WeChatKFMessage
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
@@ -131,9 +136,40 @@ class WecomPlatformAdapter(Platform):
|
||||
self.config["corpid"].strip(),
|
||||
self.config["secret"].strip(),
|
||||
)
|
||||
self.client.API_BASE_URL = self.api_base_url
|
||||
|
||||
async def callback(msg):
|
||||
# 微信客服
|
||||
self.kf_name = self.config.get("kf_name", None)
|
||||
if self.kf_name:
|
||||
# inject
|
||||
self.wechat_kf_api = WeChatKF(client=self.client)
|
||||
self.wechat_kf_message_api = WeChatKFMessage(self.client)
|
||||
self.client.kf = self.wechat_kf_api
|
||||
self.client.kf_message = self.wechat_kf_message_api
|
||||
|
||||
self.client.API_BASE_URL = self.api_base_url
|
||||
|
||||
async def callback(msg: BaseMessage):
|
||||
if msg.type == "unknown" and msg._data["Event"] == "kf_msg_or_event":
|
||||
|
||||
def get_latest_msg_item() -> dict | None:
|
||||
token = msg._data["Token"]
|
||||
kfid = msg._data["OpenKfId"]
|
||||
has_more = 1
|
||||
ret = {}
|
||||
while has_more:
|
||||
ret = self.wechat_kf_api.sync_msg(token, kfid)
|
||||
has_more = ret["has_more"]
|
||||
msg_list = ret.get("msg_list", [])
|
||||
if msg_list:
|
||||
return msg_list[-1]
|
||||
return None
|
||||
|
||||
msg_new = await asyncio.get_event_loop().run_in_executor(
|
||||
None, get_latest_msg_item
|
||||
)
|
||||
if msg_new:
|
||||
await self.convert_wechat_kf_message(msg_new)
|
||||
return
|
||||
await self.convert_message(msg)
|
||||
|
||||
self.server.callback = callback
|
||||
@@ -153,9 +189,39 @@ class WecomPlatformAdapter(Platform):
|
||||
|
||||
@override
|
||||
async def run(self):
|
||||
loop = asyncio.get_event_loop()
|
||||
if self.kf_name:
|
||||
try:
|
||||
acc_list = (
|
||||
await loop.run_in_executor(
|
||||
None, self.wechat_kf_api.get_account_list
|
||||
)
|
||||
).get("account_list", [])
|
||||
logger.debug(f"获取到微信客服列表: {str(acc_list)}")
|
||||
for acc in acc_list:
|
||||
name = acc.get("name", None)
|
||||
if name != self.kf_name:
|
||||
continue
|
||||
open_kfid = acc.get("open_kfid", None)
|
||||
if not open_kfid:
|
||||
logger.error("获取微信客服失败,open_kfid 为空。")
|
||||
logger.debug(f"Found open_kfid: {str(open_kfid)}")
|
||||
kf_url = (
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
self.wechat_kf_api.add_contact_way,
|
||||
open_kfid,
|
||||
"astrbot_placeholder",
|
||||
)
|
||||
).get("url", "")
|
||||
logger.info(
|
||||
f"请打开以下链接,在微信扫码以获取客服微信: https://api.cl2wm.cn/api/qrcode/code?text={kf_url}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
await self.server.start_polling()
|
||||
|
||||
async def convert_message(self, msg):
|
||||
async def convert_message(self, msg: BaseMessage) -> AstrBotMessage | None:
|
||||
abm = AstrBotMessage()
|
||||
if msg.type == "text":
|
||||
assert isinstance(msg, TextMessage)
|
||||
@@ -218,10 +284,42 @@ class WecomPlatformAdapter(Platform):
|
||||
abm.timestamp = msg.time
|
||||
abm.session_id = abm.sender.user_id
|
||||
abm.raw_message = msg
|
||||
else:
|
||||
logger.warning(f"暂未实现的事件: {msg.type}")
|
||||
return
|
||||
|
||||
logger.info(f"abm: {abm}")
|
||||
await self.handle_msg(abm)
|
||||
|
||||
async def convert_wechat_kf_message(self, msg: dict) -> AstrBotMessage | None:
|
||||
msgtype = msg.get("msgtype", None)
|
||||
external_userid = msg.get("external_userid", None)
|
||||
abm = AstrBotMessage()
|
||||
abm.raw_message = msg
|
||||
abm.raw_message["_wechat_kf_flag"] = None # 方便处理
|
||||
abm.self_id = msg["open_kfid"]
|
||||
abm.sender = MessageMember(external_userid, external_userid)
|
||||
abm.session_id = external_userid
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
if msgtype == "text":
|
||||
text = msg.get("text", {}).get("content", "").strip()
|
||||
abm.message = [Plain(text=text)]
|
||||
abm.message_str = text
|
||||
elif msgtype == "image":
|
||||
media_id = msg.get("image", {}).get("media_id", "")
|
||||
resp: Response = await asyncio.get_event_loop().run_in_executor(
|
||||
None, self.client.media.download, media_id
|
||||
)
|
||||
path = f"data/temp/wechat_kf_{media_id}.jpg"
|
||||
with open(path, "wb") as f:
|
||||
f.write(resp.content)
|
||||
abm.message = [Image(file=path, url=path)]
|
||||
abm.message_str = "[图片]"
|
||||
else:
|
||||
logger.warning(f"未实现的微信客服消息事件: {msg}")
|
||||
return
|
||||
await self.handle_msg(abm)
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
message_event = WecomPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
@@ -237,5 +335,8 @@ class WecomPlatformAdapter(Platform):
|
||||
|
||||
async def terminate(self):
|
||||
self.server.shutdown_event.set()
|
||||
await self.server.server.shutdown()
|
||||
try:
|
||||
await self.server.server.shutdown()
|
||||
except Exception as _:
|
||||
pass
|
||||
logger.info("企业微信 适配器已被优雅地关闭")
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import uuid
|
||||
import asyncio
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image, Record
|
||||
from wechatpy.enterprise import WeChatClient
|
||||
from .wecom_kf_message import WeChatKFMessage
|
||||
|
||||
from astrbot.api import logger
|
||||
|
||||
@@ -33,54 +35,157 @@ class WecomPlatformEvent(AstrMessageEvent):
|
||||
):
|
||||
pass
|
||||
|
||||
async def split_plain(self, plain: str) -> list[str]:
|
||||
"""将长文本分割成多个小文本, 每个小文本长度不超过 2048 字符
|
||||
|
||||
Args:
|
||||
plain (str): 要分割的长文本
|
||||
Returns:
|
||||
list[str]: 分割后的文本列表
|
||||
"""
|
||||
if len(plain) <= 2048:
|
||||
return [plain]
|
||||
else:
|
||||
result = []
|
||||
start = 0
|
||||
while start < len(plain):
|
||||
# 剩下的字符串长度<2048时结束
|
||||
if start + 2048 >= len(plain):
|
||||
result.append(plain[start:])
|
||||
break
|
||||
|
||||
# 向前搜索分割标点符号
|
||||
end = min(start + 2048, len(plain))
|
||||
cut_position = end
|
||||
for i in range(end, start, -1):
|
||||
if i < len(plain) and plain[i - 1] in [
|
||||
"。",
|
||||
"!",
|
||||
"?",
|
||||
".",
|
||||
"!",
|
||||
"?",
|
||||
"\n",
|
||||
";",
|
||||
";",
|
||||
]:
|
||||
cut_position = i
|
||||
break
|
||||
|
||||
# 没找到合适的位置分割, 直接切分
|
||||
if cut_position == end and end < len(plain):
|
||||
cut_position = end
|
||||
|
||||
result.append(plain[start:cut_position])
|
||||
start = cut_position
|
||||
|
||||
return result
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
message_obj = self.message_obj
|
||||
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
self.client.message.send_text(
|
||||
message_obj.self_id, message_obj.session_id, comp.text
|
||||
)
|
||||
elif isinstance(comp, Image):
|
||||
img_path = await comp.convert_to_file_path()
|
||||
is_wechat_kf = hasattr(self.client, "kf_message")
|
||||
if is_wechat_kf:
|
||||
# 微信客服
|
||||
kf_message_api = getattr(self.client, "kf_message", None)
|
||||
if not kf_message_api:
|
||||
logger.warning("未找到微信客服发送消息方法。")
|
||||
return
|
||||
assert isinstance(kf_message_api, WeChatKFMessage)
|
||||
user_id = self.get_sender_id()
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
# Split long text messages if needed
|
||||
plain_chunks = await self.split_plain(comp.text)
|
||||
for chunk in plain_chunks:
|
||||
kf_message_api.send_text(user_id, self.get_self_id(), chunk)
|
||||
await asyncio.sleep(0.5) # Avoid sending too fast
|
||||
elif isinstance(comp, Image):
|
||||
img_path = await comp.convert_to_file_path()
|
||||
|
||||
with open(img_path, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("image", f)
|
||||
except Exception as e:
|
||||
logger.error(f"企业微信上传图片失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"企业微信上传图片失败: {e}")
|
||||
with open(img_path, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("image", f)
|
||||
except Exception as e:
|
||||
logger.error(f"微信客服上传图片失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"微信客服上传图片失败: {e}")
|
||||
)
|
||||
return
|
||||
logger.debug(f"微信客服上传图片返回: {response}")
|
||||
kf_message_api.send_image(
|
||||
user_id,
|
||||
self.get_self_id(),
|
||||
response["media_id"],
|
||||
)
|
||||
return
|
||||
logger.info(f"企业微信上传图片返回: {response}")
|
||||
self.client.message.send_image(
|
||||
message_obj.self_id,
|
||||
message_obj.session_id,
|
||||
response["media_id"],
|
||||
)
|
||||
elif isinstance(comp, Record):
|
||||
record_path = await comp.convert_to_file_path()
|
||||
# 转成amr
|
||||
record_path_amr = f"data/temp/{uuid.uuid4()}.amr"
|
||||
pydub.AudioSegment.from_wav(record_path).export(
|
||||
record_path_amr, format="amr"
|
||||
)
|
||||
else:
|
||||
logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。")
|
||||
else:
|
||||
# 企业微信应用
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
# Split long text messages if needed
|
||||
plain_chunks = await self.split_plain(comp.text)
|
||||
for chunk in plain_chunks:
|
||||
self.client.message.send_text(
|
||||
message_obj.self_id, message_obj.session_id, chunk
|
||||
)
|
||||
await asyncio.sleep(0.5) # Avoid sending too fast
|
||||
elif isinstance(comp, Image):
|
||||
img_path = await comp.convert_to_file_path()
|
||||
|
||||
with open(record_path_amr, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("voice", f)
|
||||
except Exception as e:
|
||||
logger.error(f"企业微信上传语音失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"企业微信上传语音失败: {e}")
|
||||
with open(img_path, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("image", f)
|
||||
except Exception as e:
|
||||
logger.error(f"企业微信上传图片失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"企业微信上传图片失败: {e}")
|
||||
)
|
||||
return
|
||||
logger.debug(f"企业微信上传图片返回: {response}")
|
||||
self.client.message.send_image(
|
||||
message_obj.self_id,
|
||||
message_obj.session_id,
|
||||
response["media_id"],
|
||||
)
|
||||
return
|
||||
logger.info(f"企业微信上传语音返回: {response}")
|
||||
self.client.message.send_voice(
|
||||
message_obj.self_id,
|
||||
message_obj.session_id,
|
||||
response["media_id"],
|
||||
elif isinstance(comp, Record):
|
||||
record_path = await comp.convert_to_file_path()
|
||||
# 转成amr
|
||||
record_path_amr = f"data/temp/{uuid.uuid4()}.amr"
|
||||
pydub.AudioSegment.from_wav(record_path).export(
|
||||
record_path_amr, format="amr"
|
||||
)
|
||||
|
||||
with open(record_path_amr, "rb") as f:
|
||||
try:
|
||||
response = self.client.media.upload("voice", f)
|
||||
except Exception as e:
|
||||
logger.error(f"企业微信上传语音失败: {e}")
|
||||
await self.send(
|
||||
MessageChain().message(f"企业微信上传语音失败: {e}")
|
||||
)
|
||||
return
|
||||
logger.info(f"企业微信上传语音返回: {response}")
|
||||
self.client.message.send_voice(
|
||||
message_obj.self_id,
|
||||
message_obj.session_id,
|
||||
response["media_id"],
|
||||
)
|
||||
else:
|
||||
logger.warning(f"还没实现这个消息类型的发送逻辑: {comp.type}。")
|
||||
|
||||
await super().send(message)
|
||||
|
||||
async def send_streaming(self, generator, use_fallback: bool = False):
|
||||
buffer = None
|
||||
async for chain in generator:
|
||||
if not buffer:
|
||||
buffer = chain
|
||||
else:
|
||||
buffer.chain.extend(chain.chain)
|
||||
if not buffer:
|
||||
return
|
||||
buffer.squash_plain()
|
||||
await self.send(buffer)
|
||||
return await super().send_streaming(generator, use_fallback)
|
||||
|
||||
@@ -0,0 +1,278 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2014-2020 messense
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"""
|
||||
|
||||
from wechatpy.client.api.base import BaseWeChatAPI
|
||||
|
||||
|
||||
class WeChatKF(BaseWeChatAPI):
|
||||
"""
|
||||
微信客服接口
|
||||
|
||||
https://work.weixin.qq.com/api/doc/90000/90135/94670
|
||||
"""
|
||||
|
||||
def sync_msg(self, token, open_kfid, cursor="", limit=1000):
|
||||
"""
|
||||
微信客户发送的消息、接待人员在企业微信回复的消息、发送消息接口发送失败事件(如被用户拒收)
|
||||
、客户点击菜单消息的回复消息,可以通过该接口获取具体的消息内容和事件。不支持读取通过发送消息接口发送的消息。
|
||||
支持的消息类型:文本、图片、语音、视频、文件、位置、链接、名片、小程序、事件。
|
||||
|
||||
|
||||
:param token: 回调事件返回的token字段,10分钟内有效;可不填,如果不填接口有严格的频率限制。不多于128字节
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param cursor: 上一次调用时返回的next_cursor,第一次拉取可以不填。不多于64字节
|
||||
:param limit: 期望请求的数据量,默认值和最大值都为1000。
|
||||
注意:可能会出现返回条数少于limit的情况,需结合返回的has_more字段判断是否继续请求。
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {"token": token, "cursor": cursor, "limit": limit, "open_kfid": open_kfid}
|
||||
return self._post("kf/sync_msg", data=data)
|
||||
|
||||
def get_service_state(self, open_kfid, external_userid):
|
||||
"""
|
||||
获取会话状态
|
||||
|
||||
ID 状态 说明
|
||||
0 未处理 新会话接入。可选择:1.直接用API自动回复消息。2.放进待接入池等待接待人员接待。3.指定接待人员进行接待
|
||||
1 由智能助手接待 可使用API回复消息。可选择转入待接入池或者指定接待人员处理。
|
||||
2 待接入池排队中 在待接入池中排队等待接待人员接入。可选择转为指定人员接待
|
||||
3 由人工接待 人工接待中。可选择结束会话
|
||||
4 已结束 会话已经结束。不允许变更会话状态,等待用户重新发起咨询
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param external_userid: 微信客户的external_userid
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {
|
||||
"open_kfid": open_kfid,
|
||||
"external_userid": external_userid,
|
||||
}
|
||||
return self._post("kf/service_state/get", data=data)
|
||||
|
||||
def trans_service_state(self, open_kfid, external_userid, service_state, servicer_userid=""):
|
||||
"""
|
||||
变更会话状态
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param external_userid: 微信客户的external_userid
|
||||
:param service_state: 当前的会话状态,状态定义参考概述中的表格
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {
|
||||
"open_kfid": open_kfid,
|
||||
"external_userid": external_userid,
|
||||
"service_state": service_state,
|
||||
}
|
||||
if servicer_userid:
|
||||
data["servicer_userid"] = servicer_userid
|
||||
return self._post("kf/service_state/trans", data=data)
|
||||
|
||||
def get_servicer_list(self, open_kfid):
|
||||
"""
|
||||
获取接待人员列表
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {
|
||||
"open_kfid": open_kfid,
|
||||
}
|
||||
return self._get("kf/servicer/list", params=data)
|
||||
|
||||
def add_servicer(self, open_kfid, userid_list):
|
||||
"""
|
||||
添加接待人员
|
||||
添加指定客服帐号的接待人员。
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param userid_list: 接待人员userid列表
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
if not isinstance(userid_list, list):
|
||||
userid_list = [userid_list]
|
||||
|
||||
data = {
|
||||
"open_kfid": open_kfid,
|
||||
"userid_list": userid_list,
|
||||
}
|
||||
return self._post("kf/servicer/add", data=data)
|
||||
|
||||
def del_servicer(self, open_kfid, userid_list):
|
||||
"""
|
||||
删除接待人员
|
||||
从客服帐号删除接待人员
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param userid_list: 接待人员userid列表
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
if not isinstance(userid_list, list):
|
||||
userid_list = [userid_list]
|
||||
|
||||
data = {
|
||||
"open_kfid": open_kfid,
|
||||
"userid_list": userid_list,
|
||||
}
|
||||
return self._post("kf/servicer/del", data=data)
|
||||
|
||||
def batchget_customer(self, external_userid_list):
|
||||
"""
|
||||
客户基本信息获取
|
||||
|
||||
:param external_userid_list: external_userid列表
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
if not isinstance(external_userid_list, list):
|
||||
external_userid_list = [external_userid_list]
|
||||
|
||||
data = {
|
||||
"external_userid_list": external_userid_list,
|
||||
}
|
||||
return self._post("kf/customer/batchget", data=data)
|
||||
|
||||
def get_account_list(self):
|
||||
"""
|
||||
获取客服帐号列表
|
||||
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
return self._get("kf/account/list")
|
||||
|
||||
def add_contact_way(self, open_kfid, scene):
|
||||
"""
|
||||
获取客服帐号链接
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param scene: 场景值,字符串类型,由开发者自定义。不多于32字节;字符串取值范围(正则表达式):[0-9a-zA-Z_-]*
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {"open_kfid": open_kfid, "scene": scene}
|
||||
return self._post("kf/add_contact_way", data=data)
|
||||
|
||||
def get_upgrade_service_config(self):
|
||||
"""
|
||||
获取配置的专员与客户群
|
||||
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
return self._get("kf/customer/get_upgrade_service_config")
|
||||
|
||||
def upgrade_service(self, open_kfid, external_userid, service_type, member=None, groupchat=None):
|
||||
"""
|
||||
为客户升级为专员或客户群服务
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param external_userid: 微信客户的external_userid
|
||||
:param service_type: 表示是升级到专员服务还是客户群服务。1:专员服务。2:客户群服务
|
||||
:param member: 推荐的服务专员,type等于1时有效
|
||||
:param groupchat: 推荐的客户群,type等于2时有效
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
|
||||
data = {
|
||||
"open_kfid": open_kfid,
|
||||
"external_userid": external_userid,
|
||||
"type": service_type,
|
||||
}
|
||||
if service_type == 1:
|
||||
data["member"] = member
|
||||
else:
|
||||
data["groupchat"] = groupchat
|
||||
return self._post("kf/customer/upgrade_service", data=data)
|
||||
|
||||
def cancel_upgrade_service(self, open_kfid, external_userid):
|
||||
"""
|
||||
为客户取消推荐
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param external_userid: 微信客户的external_userid
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
|
||||
data = {"open_kfid": open_kfid, "external_userid": external_userid}
|
||||
return self._post("kf/customer/cancel_upgrade_service", data=data)
|
||||
|
||||
def send_msg_on_event(self, code, msgtype, msg_content, msgid=None):
|
||||
"""
|
||||
当特定的事件回调消息包含code字段,可以此code为凭证,调用该接口给用户发送相应事件场景下的消息,如客服欢迎语。
|
||||
支持发送消息类型:文本、菜单消息。
|
||||
|
||||
:param code: 事件响应消息对应的code。通过事件回调下发,仅可使用一次。
|
||||
:param msgtype: 消息类型。对不同的msgtype,有相应的结构描述,详见消息类型
|
||||
:param msg_content: 目前支持文本与菜单消息,具体查看文档
|
||||
:param msgid: 消息ID。如果请求参数指定了msgid,则原样返回,否则系统自动生成并返回。不多于32字节;
|
||||
字符串取值范围(正则表达式):[0-9a-zA-Z_-]*
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
|
||||
data = {"code": code, "msgtype": msgtype}
|
||||
if msgid:
|
||||
data["msgid"] = msgid
|
||||
data.update(msg_content)
|
||||
return self._post("kf/send_msg_on_event", data=data)
|
||||
|
||||
def get_corp_statistic(self, start_time, end_time, open_kfid=None):
|
||||
"""
|
||||
获取「客户数据统计」企业汇总数据
|
||||
|
||||
:param start_time: 开始时间
|
||||
:param end_time: 结束时间
|
||||
:param open_kfid: 客服帐号ID
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {"open_kfid": open_kfid, "start_time": start_time, "end_time": end_time}
|
||||
return self._post("kf/get_corp_statistic", data=data)
|
||||
|
||||
def get_servicer_statistic(self, start_time, end_time, open_kfid=None, servicer_userid=None):
|
||||
"""
|
||||
获取「客户数据统计」接待人员明细数据
|
||||
|
||||
:param start_time: 开始时间
|
||||
:param end_time: 结束时间
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param servicer_userid: 接待人员
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {
|
||||
"open_kfid": open_kfid,
|
||||
"servicer_userid": servicer_userid,
|
||||
"start_time": start_time,
|
||||
"end_time": end_time,
|
||||
}
|
||||
return self._post("kf/get_servicer_statistic", data=data)
|
||||
|
||||
def account_update(self, open_kfid, name, media_id):
|
||||
"""
|
||||
修改客服账号
|
||||
|
||||
:param open_kfid: 客服帐号ID
|
||||
:param name: 客服名称
|
||||
:param media_id: 客服头像临时素材
|
||||
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
data = {"open_kfid": open_kfid, "name": name, "media_id": media_id}
|
||||
return self._post("kf/account/update", data=data)
|
||||
@@ -0,0 +1,159 @@
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2014-2020 messense
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"""
|
||||
|
||||
from optionaldict import optionaldict
|
||||
|
||||
from wechatpy.client.api.base import BaseWeChatAPI
|
||||
|
||||
class WeChatKFMessage(BaseWeChatAPI):
|
||||
"""
|
||||
发送微信客服消息
|
||||
|
||||
https://work.weixin.qq.com/api/doc/90000/90135/94677
|
||||
|
||||
支持:
|
||||
* 文本消息
|
||||
* 图片消息
|
||||
* 语音消息
|
||||
* 视频消息
|
||||
* 文件消息
|
||||
* 图文链接
|
||||
* 小程序
|
||||
* 菜单消息
|
||||
* 地理位置
|
||||
"""
|
||||
|
||||
def send(self, user_id, open_kfid, msgid="", msg=None):
|
||||
"""
|
||||
当微信客户处于“新接入待处理”或“由智能助手接待”状态下,可调用该接口给用户发送消息。
|
||||
注意仅当微信客户在主动发送消息给客服后的48小时内,企业可发送消息给客户,最多可发送5条消息;若用户继续发送消息,企业可再次下发消息。
|
||||
支持发送消息类型:文本、图片、语音、视频、文件、图文、小程序、菜单消息、地理位置。
|
||||
|
||||
:param user_id: 指定接收消息的客户UserID
|
||||
:param open_kfid: 指定发送消息的客服帐号ID
|
||||
:param msgid: 指定消息ID
|
||||
:param tag_ids: 标签ID列表。
|
||||
:param msg: 发送消息的 dict 对象
|
||||
:type msg: dict | None
|
||||
:return: 接口调用结果
|
||||
"""
|
||||
msg = msg or {}
|
||||
data = {
|
||||
"touser": user_id,
|
||||
"open_kfid": open_kfid,
|
||||
}
|
||||
if msgid:
|
||||
data["msgid"] = msgid
|
||||
data.update(msg)
|
||||
return self._post("kf/send_msg", data=data)
|
||||
|
||||
def send_text(self, user_id, open_kfid, content, msgid=""):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={"msgtype": "text", "text": {"content": content}},
|
||||
)
|
||||
|
||||
def send_image(self, user_id, open_kfid, media_id, msgid=""):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={"msgtype": "image", "image": {"media_id": media_id}},
|
||||
)
|
||||
|
||||
def send_voice(self, user_id, open_kfid, media_id, msgid=""):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={"msgtype": "voice", "voice": {"media_id": media_id}},
|
||||
)
|
||||
|
||||
def send_video(self, user_id, open_kfid, media_id, msgid=""):
|
||||
video_data = optionaldict()
|
||||
video_data["media_id"] = media_id
|
||||
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={"msgtype": "video", "video": dict(video_data)},
|
||||
)
|
||||
|
||||
def send_file(self, user_id, open_kfid, media_id, msgid=""):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={"msgtype": "file", "file": {"media_id": media_id}},
|
||||
)
|
||||
|
||||
def send_articles_link(self, user_id, open_kfid, article, msgid=""):
|
||||
articles_data = {
|
||||
"title": article["title"],
|
||||
"desc": article["desc"],
|
||||
"url": article["url"],
|
||||
"thumb_media_id": article["thumb_media_id"],
|
||||
}
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={"msgtype": "news", "link": {"link": articles_data}},
|
||||
)
|
||||
|
||||
def send_msgmenu(self, user_id, open_kfid, head_content, menu_list, tail_content, msgid=""):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={
|
||||
"msgtype": "msgmenu",
|
||||
"msgmenu": {"head_content": head_content, "list": menu_list, "tail_content": tail_content},
|
||||
},
|
||||
)
|
||||
|
||||
def send_location(self, user_id, open_kfid, name, address, latitude, longitude, msgid=""):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={
|
||||
"msgtype": "location",
|
||||
"msgmenu": {"name": name, "address": address, "latitude": latitude, "longitude": longitude},
|
||||
},
|
||||
)
|
||||
|
||||
def send_miniprogram(self, user_id, open_kfid, appid, title, thumb_media_id, pagepath, msgid=""):
|
||||
return self.send(
|
||||
user_id,
|
||||
open_kfid,
|
||||
msgid,
|
||||
msg={
|
||||
"msgtype": "miniprogram",
|
||||
"msgmenu": {"appid": appid, "title": title, "thumb_media_id": thumb_media_id, "pagepath": pagepath},
|
||||
},
|
||||
)
|
||||
@@ -1,5 +1,5 @@
|
||||
from .provider import Provider, Personality, STTProvider
|
||||
|
||||
from .entites import ProviderMetaData
|
||||
from .entities import ProviderMetaData
|
||||
|
||||
__all__ = ["Provider", "Personality", "ProviderMetaData", "STTProvider"]
|
||||
|
||||
@@ -1,269 +1,19 @@
|
||||
import enum
|
||||
import base64
|
||||
import json
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot import logger
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Type
|
||||
from .func_tool_manager import FuncCall
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
from astrbot.core.provider.entities import (
|
||||
ProviderRequest,
|
||||
ProviderType,
|
||||
ProviderMetaData,
|
||||
ToolCallsResult,
|
||||
AssistantMessageSegment,
|
||||
ToolCallMessageSegment,
|
||||
LLMResponse,
|
||||
)
|
||||
from astrbot.core.db.po import Conversation
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
import astrbot.core.message.components as Comp
|
||||
|
||||
|
||||
class ProviderType(enum.Enum):
|
||||
CHAT_COMPLETION = "chat_completion"
|
||||
SPEECH_TO_TEXT = "speech_to_text"
|
||||
TEXT_TO_SPEECH = "text_to_speech"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderMetaData:
|
||||
type: str
|
||||
"""提供商适配器名称,如 openai, ollama"""
|
||||
desc: str = ""
|
||||
"""提供商适配器描述."""
|
||||
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
|
||||
cls_type: Type = None
|
||||
|
||||
default_config_tmpl: dict = None
|
||||
"""平台的默认配置模板"""
|
||||
provider_display_name: str = None
|
||||
"""显示在 WebUI 配置页中的提供商名称,如空则是 type"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallMessageSegment:
|
||||
"""OpenAI 格式的上下文中 role 为 tool 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
|
||||
|
||||
tool_call_id: str
|
||||
content: str
|
||||
role: str = "tool"
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"tool_call_id": self.tool_call_id,
|
||||
"content": self.content,
|
||||
"role": self.role,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class AssistantMessageSegment:
|
||||
"""OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
|
||||
|
||||
content: str = None
|
||||
tool_calls: List[ChatCompletionMessageToolCall | Dict] = None
|
||||
role: str = "assistant"
|
||||
|
||||
def to_dict(self):
|
||||
ret = {
|
||||
"role": self.role,
|
||||
}
|
||||
if self.content:
|
||||
ret["content"] = self.content
|
||||
elif self.tool_calls:
|
||||
ret["tool_calls"] = self.tool_calls
|
||||
return ret
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallsResult:
|
||||
"""工具调用结果"""
|
||||
|
||||
tool_calls_info: AssistantMessageSegment
|
||||
"""函数调用的信息"""
|
||||
tool_calls_result: List[ToolCallMessageSegment]
|
||||
"""函数调用的结果"""
|
||||
|
||||
def to_openai_messages(self) -> List[Dict]:
|
||||
ret = [
|
||||
self.tool_calls_info.to_dict(),
|
||||
*[item.to_dict() for item in self.tool_calls_result],
|
||||
]
|
||||
return ret
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderRequest:
|
||||
prompt: str
|
||||
"""提示词"""
|
||||
session_id: str = ""
|
||||
"""会话 ID"""
|
||||
image_urls: List[str] = None
|
||||
"""图片 URL 列表"""
|
||||
func_tool: FuncCall = None
|
||||
"""可用的函数工具"""
|
||||
contexts: List = None
|
||||
"""上下文。格式与 openai 的上下文格式一致:
|
||||
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
|
||||
"""
|
||||
system_prompt: str = ""
|
||||
"""系统提示词"""
|
||||
conversation: Conversation = None
|
||||
|
||||
tool_calls_result: ToolCallsResult = None
|
||||
"""附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls"""
|
||||
|
||||
def __repr__(self):
|
||||
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
def _print_friendly_context(self):
|
||||
"""打印友好的消息上下文。将 image_url 的值替换为 <Image>"""
|
||||
if not self.contexts:
|
||||
return f"prompt: {self.prompt}, image_count: {len(self.image_urls or [])}"
|
||||
|
||||
result_parts = []
|
||||
|
||||
for ctx in self.contexts:
|
||||
role = ctx.get("role", "unknown")
|
||||
content = ctx.get("content", "")
|
||||
|
||||
if isinstance(content, str):
|
||||
result_parts.append(f"{role}: {content}")
|
||||
elif isinstance(content, list):
|
||||
msg_parts = []
|
||||
image_count = 0
|
||||
|
||||
for item in content:
|
||||
item_type = item.get("type", "")
|
||||
|
||||
if item_type == "text":
|
||||
msg_parts.append(item.get("text", ""))
|
||||
elif item_type == "image_url":
|
||||
image_count += 1
|
||||
|
||||
if image_count > 0:
|
||||
if msg_parts:
|
||||
msg_parts.append(f"[+{image_count} images]")
|
||||
else:
|
||||
msg_parts.append(f"[{image_count} images]")
|
||||
|
||||
result_parts.append(f"{role}: {''.join(msg_parts)}")
|
||||
|
||||
return result_parts
|
||||
|
||||
async def assemble_context(self) -> Dict:
|
||||
"""将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。"""
|
||||
if self.image_urls:
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": self.prompt}],
|
||||
}
|
||||
for image_url in self.image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self._encode_image_bs64(image_path)
|
||||
elif image_url.startswith("file:///"):
|
||||
image_path = image_url.replace("file:///", "")
|
||||
image_data = await self._encode_image_bs64(image_path)
|
||||
else:
|
||||
image_data = await self._encode_image_bs64(image_url)
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
user_content["content"].append(
|
||||
{"type": "image_url", "image_url": {"url": image_data}}
|
||||
)
|
||||
return user_content
|
||||
else:
|
||||
return {"role": "user", "content": self.prompt}
|
||||
|
||||
async def _encode_image_bs64(self, image_url: str) -> str:
|
||||
"""将图片转换为 base64"""
|
||||
if image_url.startswith("base64://"):
|
||||
return image_url.replace("base64://", "data:image/jpeg;base64,")
|
||||
with open(image_url, "rb") as f:
|
||||
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
|
||||
return "data:image/jpeg;base64," + image_bs64
|
||||
return ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
role: str
|
||||
"""角色, assistant, tool, err"""
|
||||
result_chain: MessageChain = None
|
||||
"""返回的消息链"""
|
||||
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
|
||||
"""工具调用参数"""
|
||||
tools_call_name: List[str] = field(default_factory=list)
|
||||
"""工具调用名称"""
|
||||
tools_call_ids: List[str] = field(default_factory=list)
|
||||
"""工具调用 ID"""
|
||||
|
||||
raw_completion: ChatCompletion = None
|
||||
_new_record: Dict[str, any] = None
|
||||
|
||||
_completion_text: str = ""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
role: str,
|
||||
completion_text: str = "",
|
||||
result_chain: MessageChain = None,
|
||||
tools_call_args: List[Dict[str, any]] = [],
|
||||
tools_call_name: List[str] = [],
|
||||
tools_call_ids: List[str] = [],
|
||||
raw_completion: ChatCompletion = None,
|
||||
_new_record: Dict[str, any] = None,
|
||||
):
|
||||
"""初始化 LLMResponse
|
||||
|
||||
Args:
|
||||
role (str): 角色, assistant, tool, err
|
||||
completion_text (str, optional): 返回的结果文本,已经过时,推荐使用 result_chain. Defaults to "".
|
||||
result_chain (MessageChain, optional): 返回的消息链. Defaults to None.
|
||||
tools_call_args (List[Dict[str, any]], optional): 工具调用参数. Defaults to None.
|
||||
tools_call_name (List[str], optional): 工具调用名称. Defaults to None.
|
||||
raw_completion (ChatCompletion, optional): 原始响应, OpenAI 格式. Defaults to None.
|
||||
"""
|
||||
self.role = role
|
||||
self.completion_text = completion_text
|
||||
self.result_chain = result_chain
|
||||
self.tools_call_args = tools_call_args
|
||||
self.tools_call_name = tools_call_name
|
||||
self.tools_call_ids = tools_call_ids
|
||||
self.raw_completion = raw_completion
|
||||
self._new_record = _new_record
|
||||
|
||||
@property
|
||||
def completion_text(self):
|
||||
if self.result_chain:
|
||||
return self.result_chain.get_plain_text()
|
||||
return self._completion_text
|
||||
|
||||
@completion_text.setter
|
||||
def completion_text(self, value):
|
||||
if self.result_chain:
|
||||
self.result_chain.chain = [
|
||||
comp
|
||||
for comp in self.result_chain.chain
|
||||
if not isinstance(comp, Comp.Plain)
|
||||
] # 清空 Plain 组件
|
||||
self.result_chain.chain.insert(0, Comp.Plain(value))
|
||||
else:
|
||||
self._completion_text = value
|
||||
|
||||
def to_openai_tool_calls(self) -> List[Dict]:
|
||||
"""将工具调用信息转换为 OpenAI 格式"""
|
||||
ret = []
|
||||
for idx, tool_call_arg in enumerate(self.tools_call_args):
|
||||
ret.append(
|
||||
{
|
||||
"id": self.tools_call_ids[idx],
|
||||
"function": {
|
||||
"name": self.tools_call_name[idx],
|
||||
"arguments": json.dumps(tool_call_arg),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
)
|
||||
return ret
|
||||
__all__ = [
|
||||
"ProviderRequest",
|
||||
"ProviderType",
|
||||
"ProviderMetaData",
|
||||
"ToolCallsResult",
|
||||
"AssistantMessageSegment",
|
||||
"ToolCallMessageSegment",
|
||||
"LLMResponse",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,281 @@
|
||||
import enum
|
||||
import base64
|
||||
import json
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot import logger
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Type
|
||||
from .func_tool_manager import FuncCall
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_message_tool_call import (
|
||||
ChatCompletionMessageToolCall,
|
||||
)
|
||||
from astrbot.core.db.po import Conversation
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
import astrbot.core.message.components as Comp
|
||||
|
||||
|
||||
class ProviderType(enum.Enum):
|
||||
CHAT_COMPLETION = "chat_completion"
|
||||
SPEECH_TO_TEXT = "speech_to_text"
|
||||
TEXT_TO_SPEECH = "text_to_speech"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderMetaData:
|
||||
type: str
|
||||
"""提供商适配器名称,如 openai, ollama"""
|
||||
desc: str = ""
|
||||
"""提供商适配器描述."""
|
||||
provider_type: ProviderType = ProviderType.CHAT_COMPLETION
|
||||
cls_type: Type = None
|
||||
|
||||
default_config_tmpl: dict = None
|
||||
"""平台的默认配置模板"""
|
||||
provider_display_name: str = None
|
||||
"""显示在 WebUI 配置页中的提供商名称,如空则是 type"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallMessageSegment:
|
||||
"""OpenAI 格式的上下文中 role 为 tool 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
|
||||
|
||||
tool_call_id: str
|
||||
content: str
|
||||
role: str = "tool"
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"tool_call_id": self.tool_call_id,
|
||||
"content": self.content,
|
||||
"role": self.role,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class AssistantMessageSegment:
|
||||
"""OpenAI 格式的上下文中 role 为 assistant 的消息段。参考: https://platform.openai.com/docs/guides/function-calling"""
|
||||
|
||||
content: str = None
|
||||
tool_calls: List[ChatCompletionMessageToolCall | Dict] = None
|
||||
role: str = "assistant"
|
||||
|
||||
def to_dict(self):
|
||||
ret = {
|
||||
"role": self.role,
|
||||
}
|
||||
if self.content:
|
||||
ret["content"] = self.content
|
||||
elif self.tool_calls:
|
||||
ret["tool_calls"] = self.tool_calls
|
||||
return ret
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolCallsResult:
|
||||
"""工具调用结果"""
|
||||
|
||||
tool_calls_info: AssistantMessageSegment
|
||||
"""函数调用的信息"""
|
||||
tool_calls_result: List[ToolCallMessageSegment]
|
||||
"""函数调用的结果"""
|
||||
|
||||
def to_openai_messages(self) -> List[Dict]:
|
||||
ret = [
|
||||
self.tool_calls_info.to_dict(),
|
||||
*[item.to_dict() for item in self.tool_calls_result],
|
||||
]
|
||||
return ret
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProviderRequest:
|
||||
prompt: str
|
||||
"""提示词"""
|
||||
session_id: str = ""
|
||||
"""会话 ID"""
|
||||
image_urls: List[str] = None
|
||||
"""图片 URL 列表"""
|
||||
func_tool: FuncCall = None
|
||||
"""可用的函数工具"""
|
||||
contexts: List = None
|
||||
"""上下文。格式与 openai 的上下文格式一致:
|
||||
参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
|
||||
"""
|
||||
system_prompt: str = ""
|
||||
"""系统提示词"""
|
||||
conversation: Conversation = None
|
||||
|
||||
tool_calls_result: ToolCallsResult = None
|
||||
"""附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls"""
|
||||
|
||||
def __repr__(self):
|
||||
return f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, image_urls={self.image_urls}, func_tool={self.func_tool}, contexts={self._print_friendly_context()}, system_prompt={self.system_prompt.strip()}, tool_calls_result={self.tool_calls_result})"
|
||||
|
||||
def __str__(self):
|
||||
return self.__repr__()
|
||||
|
||||
def _print_friendly_context(self):
|
||||
"""打印友好的消息上下文。将 image_url 的值替换为 <Image>"""
|
||||
if not self.contexts:
|
||||
return f"prompt: {self.prompt}, image_count: {len(self.image_urls or [])}"
|
||||
|
||||
result_parts = []
|
||||
|
||||
for ctx in self.contexts:
|
||||
role = ctx.get("role", "unknown")
|
||||
content = ctx.get("content", "")
|
||||
|
||||
if isinstance(content, str):
|
||||
result_parts.append(f"{role}: {content}")
|
||||
elif isinstance(content, list):
|
||||
msg_parts = []
|
||||
image_count = 0
|
||||
|
||||
for item in content:
|
||||
item_type = item.get("type", "")
|
||||
|
||||
if item_type == "text":
|
||||
msg_parts.append(item.get("text", ""))
|
||||
elif item_type == "image_url":
|
||||
image_count += 1
|
||||
|
||||
if image_count > 0:
|
||||
if msg_parts:
|
||||
msg_parts.append(f"[+{image_count} images]")
|
||||
else:
|
||||
msg_parts.append(f"[{image_count} images]")
|
||||
|
||||
result_parts.append(f"{role}: {''.join(msg_parts)}")
|
||||
|
||||
return result_parts
|
||||
|
||||
async def assemble_context(self) -> Dict:
|
||||
"""将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。"""
|
||||
if self.image_urls:
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": self.prompt if self.prompt else "[图片]"}],
|
||||
}
|
||||
for image_url in self.image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
image_data = await self._encode_image_bs64(image_path)
|
||||
elif image_url.startswith("file:///"):
|
||||
image_path = image_url.replace("file:///", "")
|
||||
image_data = await self._encode_image_bs64(image_path)
|
||||
else:
|
||||
image_data = await self._encode_image_bs64(image_url)
|
||||
if not image_data:
|
||||
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
|
||||
continue
|
||||
user_content["content"].append(
|
||||
{"type": "image_url", "image_url": {"url": image_data}}
|
||||
)
|
||||
return user_content
|
||||
else:
|
||||
return {"role": "user", "content": self.prompt}
|
||||
|
||||
async def _encode_image_bs64(self, image_url: str) -> str:
|
||||
"""将图片转换为 base64"""
|
||||
if image_url.startswith("base64://"):
|
||||
return image_url.replace("base64://", "data:image/jpeg;base64,")
|
||||
with open(image_url, "rb") as f:
|
||||
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
|
||||
return "data:image/jpeg;base64," + image_bs64
|
||||
return ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMResponse:
|
||||
role: str
|
||||
"""角色, assistant, tool, err"""
|
||||
result_chain: MessageChain = None
|
||||
"""返回的消息链"""
|
||||
tools_call_args: List[Dict[str, any]] = field(default_factory=list)
|
||||
"""工具调用参数"""
|
||||
tools_call_name: List[str] = field(default_factory=list)
|
||||
"""工具调用名称"""
|
||||
tools_call_ids: List[str] = field(default_factory=list)
|
||||
"""工具调用 ID"""
|
||||
|
||||
raw_completion: ChatCompletion = None
|
||||
_new_record: Dict[str, any] = None
|
||||
|
||||
_completion_text: str = ""
|
||||
|
||||
is_chunk: bool = False
|
||||
"""是否是流式输出的单个 Chunk"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
role: str,
|
||||
completion_text: str = "",
|
||||
result_chain: MessageChain = None,
|
||||
tools_call_args: List[Dict[str, any]] = None,
|
||||
tools_call_name: List[str] = None,
|
||||
tools_call_ids: List[str] = None,
|
||||
raw_completion: ChatCompletion = None,
|
||||
_new_record: Dict[str, any] = None,
|
||||
is_chunk: bool = False,
|
||||
):
|
||||
"""初始化 LLMResponse
|
||||
|
||||
Args:
|
||||
role (str): 角色, assistant, tool, err
|
||||
completion_text (str, optional): 返回的结果文本,已经过时,推荐使用 result_chain. Defaults to "".
|
||||
result_chain (MessageChain, optional): 返回的消息链. Defaults to None.
|
||||
tools_call_args (List[Dict[str, any]], optional): 工具调用参数. Defaults to None.
|
||||
tools_call_name (List[str], optional): 工具调用名称. Defaults to None.
|
||||
raw_completion (ChatCompletion, optional): 原始响应, OpenAI 格式. Defaults to None.
|
||||
"""
|
||||
if tools_call_args is None:
|
||||
tools_call_args = []
|
||||
if tools_call_name is None:
|
||||
tools_call_name = []
|
||||
if tools_call_ids is None:
|
||||
tools_call_ids = []
|
||||
|
||||
self.role = role
|
||||
self.completion_text = completion_text
|
||||
self.result_chain = result_chain
|
||||
self.tools_call_args = tools_call_args
|
||||
self.tools_call_name = tools_call_name
|
||||
self.tools_call_ids = tools_call_ids
|
||||
self.raw_completion = raw_completion
|
||||
self._new_record = _new_record
|
||||
self.is_chunk = is_chunk
|
||||
|
||||
@property
|
||||
def completion_text(self):
|
||||
if self.result_chain:
|
||||
return self.result_chain.get_plain_text()
|
||||
return self._completion_text
|
||||
|
||||
@completion_text.setter
|
||||
def completion_text(self, value):
|
||||
if self.result_chain:
|
||||
self.result_chain.chain = [
|
||||
comp
|
||||
for comp in self.result_chain.chain
|
||||
if not isinstance(comp, Comp.Plain)
|
||||
] # 清空 Plain 组件
|
||||
self.result_chain.chain.insert(0, Comp.Plain(value))
|
||||
else:
|
||||
self._completion_text = value
|
||||
|
||||
def to_openai_tool_calls(self) -> List[Dict]:
|
||||
"""将工具调用信息转换为 OpenAI 格式"""
|
||||
ret = []
|
||||
for idx, tool_call_arg in enumerate(self.tools_call_args):
|
||||
ret.append(
|
||||
{
|
||||
"id": self.tools_call_ids[idx],
|
||||
"function": {
|
||||
"name": self.tools_call_name[idx],
|
||||
"arguments": json.dumps(tool_call_arg),
|
||||
},
|
||||
"type": "function",
|
||||
}
|
||||
)
|
||||
return ret
|
||||
@@ -3,16 +3,18 @@ import json
|
||||
import textwrap
|
||||
import os
|
||||
import asyncio
|
||||
import copy
|
||||
import logging
|
||||
|
||||
from typing import Dict, List, Awaitable, Literal, Any
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
from contextlib import AsyncExitStack
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.log_pipe import LogPipe
|
||||
|
||||
try:
|
||||
import mcp
|
||||
from mcp.client.sse import sse_client
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
logger.warning("警告: 缺少依赖库 'mcp',将无法使用 MCP 服务。")
|
||||
|
||||
@@ -87,26 +89,58 @@ class MCPClient:
|
||||
self.name = None
|
||||
self.active: bool = True
|
||||
self.tools: List[mcp.Tool] = []
|
||||
self.server_errlogs: List[str] = []
|
||||
|
||||
async def connect_to_server(self, mcp_server_config: dict):
|
||||
"""Connect to an MCP server
|
||||
async def connect_to_server(self, mcp_server_config: dict, name: str):
|
||||
"""连接到 MCP 服务器
|
||||
|
||||
如果 `url` 参数存在,则使用 SSE 的方式连接到 MCP 服务。
|
||||
|
||||
Args:
|
||||
mcp_server_config (dict): Configuration for the MCP server. See https://modelcontextprotocol.io/quickstart/server
|
||||
"""
|
||||
cfg = mcp_server_config.copy()
|
||||
cfg.pop("active", None)
|
||||
server_params = mcp.StdioServerParameters(
|
||||
**cfg,
|
||||
)
|
||||
if "mcpServers" in cfg and len(cfg["mcpServers"]) > 0:
|
||||
key_0 = list(cfg["mcpServers"].keys())[0]
|
||||
cfg = cfg["mcpServers"][key_0]
|
||||
cfg.pop("active", None) # Remove active flag from config
|
||||
|
||||
stdio_transport = await self.exit_stack.enter_async_context(
|
||||
mcp.stdio_client(server_params)
|
||||
)
|
||||
self.stdio, self.write = stdio_transport
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(self.stdio, self.write)
|
||||
)
|
||||
if "url" in cfg:
|
||||
# SSE transport method
|
||||
self._streams_context = sse_client(url=cfg["url"])
|
||||
streams = await self._streams_context.__aenter__()
|
||||
|
||||
# Create a new client session
|
||||
# self.session = await self._session_context.__aenter__()
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(*streams)
|
||||
)
|
||||
|
||||
else:
|
||||
server_params = mcp.StdioServerParameters(
|
||||
**cfg,
|
||||
)
|
||||
|
||||
def callback(msg: str):
|
||||
# 处理 MCP 服务的错误日志
|
||||
self.server_errlogs.append(msg)
|
||||
|
||||
stdio_transport = await self.exit_stack.enter_async_context(
|
||||
mcp.stdio_client(
|
||||
server_params,
|
||||
errlog=LogPipe(
|
||||
level=logging.ERROR,
|
||||
logger=logger,
|
||||
identifier=f"MCPServer-{name}",
|
||||
callback=callback,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
# Create a new client session
|
||||
self.session = await self.exit_stack.enter_async_context(
|
||||
mcp.ClientSession(*stdio_transport)
|
||||
)
|
||||
|
||||
await self.session.initialize()
|
||||
|
||||
@@ -260,6 +294,13 @@ class FuncCall:
|
||||
if data["name"] in self.mcp_client_event:
|
||||
self.mcp_client_event[data["name"]].set()
|
||||
self.mcp_client_event.pop(data["name"], None)
|
||||
self.func_list = [
|
||||
f
|
||||
for f in self.func_list
|
||||
if not (
|
||||
f.origin == "mcp" and f.mcp_server_name == data["name"]
|
||||
)
|
||||
]
|
||||
else:
|
||||
for name in self.mcp_client_dict.keys():
|
||||
# await self._terminate_mcp_client(name)
|
||||
@@ -267,6 +308,7 @@ class FuncCall:
|
||||
if name in self.mcp_client_event:
|
||||
self.mcp_client_event[name].set()
|
||||
self.mcp_client_event.pop(name, None)
|
||||
self.func_list = [f for f in self.func_list if f.origin != "mcp"]
|
||||
|
||||
async def _init_mcp_client_task_wrapper(
|
||||
self, name: str, cfg: dict, event: asyncio.Event
|
||||
@@ -278,6 +320,9 @@ class FuncCall:
|
||||
logger.info(f"收到 MCP 客户端 {name} 终止信号")
|
||||
await self._terminate_mcp_client(name)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
|
||||
|
||||
async def _init_mcp_client(self, name: str, config: dict) -> None:
|
||||
@@ -289,10 +334,10 @@ class FuncCall:
|
||||
|
||||
mcp_client = MCPClient()
|
||||
mcp_client.name = name
|
||||
await mcp_client.connect_to_server(config)
|
||||
self.mcp_client_dict[name] = mcp_client
|
||||
await mcp_client.connect_to_server(config, name)
|
||||
tools_res = await mcp_client.list_tools_and_save()
|
||||
tool_names = [tool.name for tool in tools_res.tools]
|
||||
self.mcp_client_dict[name] = mcp_client
|
||||
|
||||
# 移除该MCP服务之前的工具(如有)
|
||||
self.func_list = [
|
||||
@@ -314,13 +359,16 @@ class FuncCall:
|
||||
self.func_list.append(func_tool)
|
||||
|
||||
logger.info(f"已连接 MCP 服务 {name}, Tools: {tool_names}")
|
||||
return True
|
||||
return
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
logger.error(traceback.format_exc())
|
||||
logger.error(f"初始化 MCP 客户端 {name} 失败: {e}")
|
||||
# 发生错误时确保客户端被清理
|
||||
if name in self.mcp_client_dict:
|
||||
await self._terminate_mcp_client(name)
|
||||
return False
|
||||
return
|
||||
|
||||
async def _terminate_mcp_client(self, name: str) -> None:
|
||||
"""关闭并清理MCP客户端"""
|
||||
@@ -339,7 +387,7 @@ class FuncCall:
|
||||
]
|
||||
logger.info(f"已关闭 MCP 服务 {name}")
|
||||
|
||||
def get_func_desc_openai_style(self) -> list:
|
||||
def get_func_desc_openai_style(self, omit_empty_parameter_field=False) -> list:
|
||||
"""
|
||||
获得 OpenAI API 风格的**已经激活**的工具描述
|
||||
"""
|
||||
@@ -348,16 +396,19 @@ class FuncCall:
|
||||
for f in self.func_list:
|
||||
if not f.active:
|
||||
continue
|
||||
_l.append(
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": f.name,
|
||||
"parameters": f.parameters,
|
||||
"description": f.description,
|
||||
},
|
||||
}
|
||||
)
|
||||
func_ = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": f.name,
|
||||
# "parameters": f.parameters,
|
||||
"description": f.description,
|
||||
},
|
||||
}
|
||||
func_["function"]["parameters"] = f.parameters
|
||||
if not f.parameters.get("properties") and omit_empty_parameter_field:
|
||||
# 如果 properties 为空,并且 omit_empty_parameter_field 为 True,则删除 parameters 字段
|
||||
del func_["function"]["parameters"]
|
||||
_l.append(func_)
|
||||
return _l
|
||||
|
||||
def get_func_desc_anthropic_style(self) -> list:
|
||||
@@ -383,28 +434,86 @@ class FuncCall:
|
||||
tools.append(tool)
|
||||
return tools
|
||||
|
||||
def get_func_desc_google_genai_style(self) -> Dict:
|
||||
def get_func_desc_google_genai_style(self) -> dict:
|
||||
"""
|
||||
获得 Google GenAI API 风格的**已经激活**的工具描述
|
||||
"""
|
||||
|
||||
# Gemini API 支持的数据类型和格式
|
||||
supported_types = {
|
||||
"string",
|
||||
"number",
|
||||
"integer",
|
||||
"boolean",
|
||||
"array",
|
||||
"object",
|
||||
"null",
|
||||
}
|
||||
supported_formats = {
|
||||
"string": {"enum", "date-time"},
|
||||
"integer": {"int32", "int64"},
|
||||
"number": {"float", "double"},
|
||||
}
|
||||
|
||||
def convert_schema(schema: dict) -> dict:
|
||||
"""转换 schema 为 Gemini API 格式"""
|
||||
|
||||
# 如果 schema 包含 anyOf,则只返回 anyOf 字段
|
||||
if "anyOf" in schema:
|
||||
return {"anyOf": [convert_schema(s) for s in schema["anyOf"]]}
|
||||
|
||||
result = {}
|
||||
|
||||
if "type" in schema and schema["type"] in supported_types:
|
||||
result["type"] = schema["type"]
|
||||
if "format" in schema and schema["format"] in supported_formats.get(
|
||||
result["type"], set()
|
||||
):
|
||||
result["format"] = schema["format"]
|
||||
else:
|
||||
# 暂时指定默认为null
|
||||
result["type"] = "null"
|
||||
|
||||
support_fields = {
|
||||
"title",
|
||||
"description",
|
||||
"enum",
|
||||
"minimum",
|
||||
"maximum",
|
||||
"maxItems",
|
||||
"minItems",
|
||||
"nullable",
|
||||
"required",
|
||||
}
|
||||
result.update({k: schema[k] for k in support_fields if k in schema})
|
||||
|
||||
if "properties" in schema:
|
||||
properties = {}
|
||||
for key, value in schema["properties"].items():
|
||||
prop_value = convert_schema(value)
|
||||
if "default" in prop_value:
|
||||
del prop_value["default"]
|
||||
properties[key] = prop_value
|
||||
|
||||
if properties: # 只在有非空属性时添加
|
||||
result["properties"] = properties
|
||||
|
||||
if "items" in schema:
|
||||
result["items"] = convert_schema(schema["items"])
|
||||
|
||||
return result
|
||||
|
||||
tools = [
|
||||
{
|
||||
"name": f.name,
|
||||
"description": f.description,
|
||||
**({"parameters": convert_schema(f.parameters)}),
|
||||
}
|
||||
for f in self.func_list
|
||||
if f.active
|
||||
]
|
||||
|
||||
declarations = {}
|
||||
tools = []
|
||||
for f in self.func_list:
|
||||
if not f.active:
|
||||
continue
|
||||
|
||||
func_declaration = {"name": f.name, "description": f.description}
|
||||
|
||||
# 检查并添加非空的properties参数
|
||||
params = f.parameters if isinstance(f.parameters, dict) else {}
|
||||
params = copy.deepcopy(params)
|
||||
if params.get("properties", {}):
|
||||
properties = params["properties"]
|
||||
for key, value in properties.items():
|
||||
if "default" in value:
|
||||
del value["default"]
|
||||
params["properties"] = properties
|
||||
func_declaration["parameters"] = params
|
||||
|
||||
tools.append(func_declaration)
|
||||
|
||||
if tools:
|
||||
declarations["function_declarations"] = tools
|
||||
return declarations
|
||||
|
||||
@@ -2,7 +2,7 @@ import traceback
|
||||
import asyncio
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from .provider import Provider, STTProvider, TTSProvider, Personality
|
||||
from .entites import ProviderType
|
||||
from .entities import ProviderType
|
||||
from typing import List
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from .register import provider_cls_map, llm_tools
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import abc
|
||||
from typing import List
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from typing import TypedDict
|
||||
from typing import TypedDict, AsyncGenerator
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from astrbot.core.provider.entites import LLMResponse, ToolCallsResult
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@@ -108,7 +108,35 @@ class Provider(AbstractProvider):
|
||||
- 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
|
||||
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
...
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = None,
|
||||
func_tool: FuncCall = None,
|
||||
contexts: List = None,
|
||||
system_prompt: str = None,
|
||||
tool_calls_result: ToolCallsResult = None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。
|
||||
|
||||
Args:
|
||||
prompt: 提示词
|
||||
session_id: 会话 ID(此属性已经被废弃)
|
||||
image_urls: 图片 URL 列表
|
||||
tools: Function-calling 工具
|
||||
contexts: 上下文
|
||||
tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling
|
||||
kwargs: 其他参数
|
||||
|
||||
Notes:
|
||||
- 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。
|
||||
- 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。
|
||||
"""
|
||||
...
|
||||
|
||||
async def pop_record(self, context: List):
|
||||
"""
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import List, Dict
|
||||
from .entites import ProviderMetaData, ProviderType
|
||||
from .entities import ProviderMetaData, ProviderType
|
||||
from astrbot.core import logger
|
||||
from .func_tool_manager import FuncCall
|
||||
|
||||
|
||||
@@ -10,7 +10,8 @@ from astrbot.api.provider import Provider, Personality
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.entites import LLMResponse, ToolCallsResult
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
|
||||
|
||||
@@ -72,7 +73,8 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
||||
if content.type == "text":
|
||||
# text completion
|
||||
completion_text = str(content.text).strip()
|
||||
llm_response.completion_text = completion_text
|
||||
# llm_response.completion_text = completion_text
|
||||
llm_response.result_chain = MessageChain().message(completion_text)
|
||||
|
||||
# Anthropic每次只返回一个函数调用
|
||||
if completion.stop_reason == "tool_use":
|
||||
@@ -145,7 +147,7 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
||||
messages=context_query, **model_config
|
||||
)
|
||||
llm_response = LLMResponse("assistant")
|
||||
llm_response.completion_text = response.content[0].text
|
||||
llm_response.result_chain = MessageChain().message(response.content[0].text)
|
||||
llm_response.raw_completion = response
|
||||
return llm_response
|
||||
except Exception as e:
|
||||
@@ -160,6 +162,33 @@ class ProviderAnthropic(ProviderOpenAIOfficial):
|
||||
|
||||
return llm_response
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt,
|
||||
session_id=None,
|
||||
image_urls=...,
|
||||
func_tool=None,
|
||||
contexts=...,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
):
|
||||
# raise NotImplementedError("This method is not implemented yet.")
|
||||
# 调用 text_chat 模拟流式
|
||||
llm_response = await self.text_chat(
|
||||
prompt=prompt,
|
||||
session_id=session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=func_tool,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
tool_calls_result=tool_calls_result,
|
||||
)
|
||||
llm_response.is_chunk = True
|
||||
yield llm_response
|
||||
llm_response.is_chunk = False
|
||||
yield llm_response
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
||||
"""组装上下文,支持文本和图片"""
|
||||
if not image_urls:
|
||||
|
||||
@@ -3,10 +3,11 @@ import asyncio
|
||||
import functools
|
||||
from typing import List
|
||||
from .. import Provider, Personality
|
||||
from ..entites import LLMResponse
|
||||
from ..entities import LLMResponse
|
||||
from ..func_tool_manager import FuncCall
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
from astrbot.core import logger, sp
|
||||
from dashscope import Application
|
||||
@@ -132,7 +133,9 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
)
|
||||
return LLMResponse(
|
||||
role="err",
|
||||
completion_text=f"阿里云百炼请求失败: message={response.message} code={response.status_code}",
|
||||
result_chain=MessageChain().message(
|
||||
f"阿里云百炼请求失败: message={response.message} code={response.status_code}"
|
||||
),
|
||||
)
|
||||
|
||||
output_text = response.output.get("text", "")
|
||||
@@ -141,11 +144,45 @@ class ProviderDashscope(ProviderOpenAIOfficial):
|
||||
if self.output_reference and response.output.get("doc_references", None):
|
||||
ref_str = ""
|
||||
for ref in response.output.get("doc_references", []):
|
||||
ref_title = ref.get("title", "") if ref.get("title") else ref.get("doc_name", "")
|
||||
ref_title = (
|
||||
ref.get("title", "")
|
||||
if ref.get("title")
|
||||
else ref.get("doc_name", "")
|
||||
)
|
||||
ref_str += f"{ref['index_id']}. {ref_title}\n"
|
||||
output_text += f"\n\n回答来源:\n{ref_str}"
|
||||
|
||||
return LLMResponse(role="assistant", completion_text=output_text)
|
||||
llm_response = LLMResponse("assistant")
|
||||
llm_response.result_chain = MessageChain().message(output_text)
|
||||
|
||||
return llm_response
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt,
|
||||
session_id=None,
|
||||
image_urls=...,
|
||||
func_tool=None,
|
||||
contexts=...,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
):
|
||||
# raise NotImplementedError("This method is not implemented yet.")
|
||||
# 调用 text_chat 模拟流式
|
||||
llm_response = await self.text_chat(
|
||||
prompt=prompt,
|
||||
session_id=session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=func_tool,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
tool_calls_result=tool_calls_result,
|
||||
)
|
||||
llm_response.is_chunk = True
|
||||
yield llm_response
|
||||
llm_response.is_chunk = False
|
||||
yield llm_response
|
||||
|
||||
async def forget(self, session_id):
|
||||
return True
|
||||
|
||||
@@ -3,7 +3,7 @@ import uuid
|
||||
import asyncio
|
||||
from dashscope.audio.tts_v2 import *
|
||||
from ..provider import TTSProvider
|
||||
from ..entites import ProviderType
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
@@ -20,17 +20,16 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
||||
self.chosen_api_key: str = provider_config.get("api_key", "")
|
||||
self.voice: str = provider_config.get("dashscope_tts_voice", "loongstella")
|
||||
self.set_model(provider_config.get("model", None))
|
||||
self.timeout_ms = float(provider_config.get("timeout", 20))*1000
|
||||
|
||||
self.timeout_ms = float(provider_config.get("timeout", 20)) * 1000
|
||||
dashscope.api_key = self.chosen_api_key
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
path = f"data/temp/dashscope_tts_{uuid.uuid4()}.wav"
|
||||
self.synthesizer = SpeechSynthesizer(
|
||||
model=self.get_model(),
|
||||
voice=self.voice,
|
||||
format=AudioFormat.WAV_24000HZ_MONO_16BIT,
|
||||
)
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
path = f"data/temp/dashscope_tts_{uuid.uuid4()}.wav"
|
||||
audio = await asyncio.get_event_loop().run_in_executor(
|
||||
None, self.synthesizer.call, text, self.timeout_ms
|
||||
)
|
||||
|
||||
@@ -2,7 +2,7 @@ import astrbot.core.message.components as Comp
|
||||
|
||||
from typing import List
|
||||
from .. import Provider, Personality
|
||||
from ..entites import LLMResponse
|
||||
from ..entities import LLMResponse
|
||||
from ..func_tool_manager import FuncCall
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from ..register import register_provider_adapter
|
||||
@@ -102,7 +102,7 @@ class ProviderDify(Provider):
|
||||
|
||||
try:
|
||||
match self.api_type:
|
||||
case "chat" | "agent":
|
||||
case "chat" | "agent" | "chatflow":
|
||||
if not prompt:
|
||||
prompt = "请描述这张图片。"
|
||||
|
||||
@@ -189,6 +189,33 @@ class ProviderDify(Provider):
|
||||
|
||||
return LLMResponse(role="assistant", result_chain=chain)
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt,
|
||||
session_id=None,
|
||||
image_urls=...,
|
||||
func_tool=None,
|
||||
contexts=...,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
):
|
||||
# raise NotImplementedError("This method is not implemented yet.")
|
||||
# 调用 text_chat 模拟流式
|
||||
llm_response = await self.text_chat(
|
||||
prompt=prompt,
|
||||
session_id=session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=func_tool,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
tool_calls_result=tool_calls_result,
|
||||
)
|
||||
llm_response.is_chunk = True
|
||||
yield llm_response
|
||||
llm_response.is_chunk = False
|
||||
yield llm_response
|
||||
|
||||
async def parse_dify_result(self, chunk: dict | str) -> MessageChain:
|
||||
if isinstance(chunk, str):
|
||||
# Chat
|
||||
|
||||
@@ -4,7 +4,7 @@ import edge_tts
|
||||
import subprocess
|
||||
import asyncio
|
||||
from ..provider import TTSProvider
|
||||
from ..entites import ProviderType
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from pydantic import BaseModel, conint
|
||||
from httpx import AsyncClient
|
||||
from typing import Annotated, Literal
|
||||
from ..provider import TTSProvider
|
||||
from ..entites import ProviderType
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
|
||||
@@ -1,88 +1,55 @@
|
||||
import base64
|
||||
import aiohttp
|
||||
import json
|
||||
import random
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
from typing import Dict, List, Optional
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from google.genai.errors import APIError
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.api.provider import Provider, Personality
|
||||
from astrbot import logger
|
||||
from astrbot.api.provider import Personality, Provider
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from typing import List
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.entites import LLMResponse
|
||||
|
||||
|
||||
class SimpleGoogleGenAIClient:
|
||||
def __init__(self, api_key: str, api_base: str, timeout: int = 120) -> None:
|
||||
self.api_key = api_key
|
||||
if api_base.endswith("/"):
|
||||
self.api_base = api_base[:-1]
|
||||
else:
|
||||
self.api_base = api_base
|
||||
self.client = aiohttp.ClientSession(trust_env=True)
|
||||
self.timeout = timeout
|
||||
class SuppressNonTextPartsWarning(logging.Filter):
|
||||
"""过滤 Gemini SDK 中的非文本部分警告"""
|
||||
|
||||
async def models_list(self) -> List[str]:
|
||||
request_url = f"{self.api_base}/v1beta/models?key={self.api_key}"
|
||||
async with self.client.get(request_url, timeout=self.timeout) as resp:
|
||||
response = await resp.json()
|
||||
def filter(self, record):
|
||||
return "there are non-text parts in the response" not in record.getMessage()
|
||||
|
||||
models = []
|
||||
for model in response["models"]:
|
||||
if "generateContent" in model["supportedGenerationMethods"]:
|
||||
models.append(model["name"].replace("models/", ""))
|
||||
return models
|
||||
|
||||
async def generate_content(
|
||||
self,
|
||||
contents: List[dict],
|
||||
model: str = "gemini-1.5-flash",
|
||||
system_instruction: str = "",
|
||||
tools: dict = None,
|
||||
modalities: List[str] = ["Text"],
|
||||
safety_settings: List[dict] = [],
|
||||
):
|
||||
payload = {}
|
||||
if system_instruction:
|
||||
payload["system_instruction"] = {"parts": {"text": system_instruction}}
|
||||
if tools:
|
||||
payload["tools"] = [tools]
|
||||
payload["contents"] = contents
|
||||
payload["generationConfig"] = {
|
||||
"responseModalities": modalities,
|
||||
}
|
||||
payload["safetySettings"] = [
|
||||
{"category": s["category"], "threshold": s["threshold"]}
|
||||
for s in safety_settings
|
||||
]
|
||||
logger.debug(f"payload: {payload}")
|
||||
request_url = (
|
||||
f"{self.api_base}/v1beta/models/{model}:generateContent?key={self.api_key}"
|
||||
)
|
||||
async with self.client.post(
|
||||
request_url, json=payload, timeout=self.timeout
|
||||
) as resp:
|
||||
if "application/json" in resp.headers.get("Content-Type"):
|
||||
try:
|
||||
response = await resp.json()
|
||||
except Exception as e:
|
||||
text = await resp.text()
|
||||
logger.error(f"Gemini 返回了非 json 数据: {text}")
|
||||
raise e
|
||||
return response
|
||||
else:
|
||||
text = await resp.text()
|
||||
logger.error(f"Gemini 返回了非 json 数据: {text}")
|
||||
raise Exception("Gemini 返回了非 json 数据: ")
|
||||
logging.getLogger("google_genai.types").addFilter(SuppressNonTextPartsWarning())
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
"googlegenai_chat_completion", "Google Gemini Chat Completion 提供商适配器"
|
||||
)
|
||||
class ProviderGoogleGenAI(Provider):
|
||||
CATEGORY_MAPPING = {
|
||||
"harassment": types.HarmCategory.HARM_CATEGORY_HARASSMENT,
|
||||
"hate_speech": types.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||
"sexually_explicit": types.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
||||
"dangerous_content": types.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||
}
|
||||
|
||||
THRESHOLD_MAPPING = {
|
||||
"BLOCK_NONE": types.HarmBlockThreshold.BLOCK_NONE,
|
||||
"BLOCK_ONLY_HIGH": types.HarmBlockThreshold.BLOCK_ONLY_HIGH,
|
||||
"BLOCK_MEDIUM_AND_ABOVE": types.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
|
||||
"BLOCK_LOW_AND_ABOVE": types.HarmBlockThreshold.BLOCK_LOW_AND_ABOVE,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
provider_config: dict,
|
||||
@@ -98,183 +65,401 @@ class ProviderGoogleGenAI(Provider):
|
||||
db_helper,
|
||||
default_persona,
|
||||
)
|
||||
self.chosen_api_key = None
|
||||
self.api_keys: List = provider_config.get("key", [])
|
||||
self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
self.timeout = provider_config.get("timeout", 180)
|
||||
if isinstance(self.timeout, str):
|
||||
self.timeout = int(self.timeout)
|
||||
self.client = SimpleGoogleGenAIClient(
|
||||
api_key=self.chosen_api_key,
|
||||
api_base=provider_config.get("api_base", None),
|
||||
timeout=self.timeout,
|
||||
)
|
||||
self.chosen_api_key: str = self.api_keys[0] if len(self.api_keys) > 0 else None
|
||||
self.timeout: int = int(provider_config.get("timeout", 180))
|
||||
|
||||
self.api_base: Optional[str] = provider_config.get("api_base", None)
|
||||
if self.api_base and self.api_base.endswith("/"):
|
||||
self.api_base = self.api_base[:-1]
|
||||
|
||||
self._init_client()
|
||||
self.set_model(provider_config["model_config"]["model"])
|
||||
self._init_safety_settings()
|
||||
|
||||
safety_mapping = {
|
||||
"harassment": "HARM_CATEGORY_HARASSMENT",
|
||||
"hate_speech": "HARM_CATEGORY_HATE_SPEECH",
|
||||
"sexually_explicit": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
|
||||
"dangerous_content": "HARM_CATEGORY_DANGEROUS_CONTENT",
|
||||
}
|
||||
def _init_client(self) -> None:
|
||||
"""初始化Gemini客户端"""
|
||||
self.client = genai.Client(
|
||||
api_key=self.chosen_api_key,
|
||||
http_options=types.HttpOptions(
|
||||
base_url=self.api_base,
|
||||
timeout=self.timeout * 1000, # 毫秒
|
||||
),
|
||||
).aio
|
||||
|
||||
self.safety_settings = []
|
||||
def _init_safety_settings(self) -> None:
|
||||
"""初始化安全设置"""
|
||||
user_safety_config = self.provider_config.get("gm_safety_settings", {})
|
||||
for config_key, harm_category in safety_mapping.items():
|
||||
if threshold := user_safety_config.get(config_key):
|
||||
self.safety_settings.append(
|
||||
{"category": harm_category, "threshold": threshold}
|
||||
)
|
||||
self.safety_settings = [
|
||||
types.SafetySetting(
|
||||
category=harm_category, threshold=self.THRESHOLD_MAPPING[threshold_str]
|
||||
)
|
||||
for config_key, harm_category in self.CATEGORY_MAPPING.items()
|
||||
if (threshold_str := user_safety_config.get(config_key))
|
||||
and threshold_str in self.THRESHOLD_MAPPING
|
||||
]
|
||||
|
||||
async def get_models(self):
|
||||
return await self.client.models_list()
|
||||
async def _handle_api_error(self, e: APIError, keys: List[str]) -> bool:
|
||||
"""处理API错误,返回是否需要重试"""
|
||||
if e.code == 429 or "API key not valid" in e.message:
|
||||
keys.remove(self.chosen_api_key)
|
||||
if len(keys) > 0:
|
||||
self.set_key(random.choice(keys))
|
||||
logger.info(
|
||||
f"检测到 Key 异常({e.message}),正在尝试更换 API Key 重试... 当前 Key: {self.chosen_api_key[:12]}..."
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
return True
|
||||
else:
|
||||
logger.error(
|
||||
f"检测到 Key 异常({e.message}),且已没有可用的 Key。 当前 Key: {self.chosen_api_key[:12]}..."
|
||||
)
|
||||
raise Exception("达到了 Gemini 速率限制, 请稍后再试...")
|
||||
else:
|
||||
logger.error(
|
||||
f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}"
|
||||
)
|
||||
raise e
|
||||
|
||||
async def _prepare_query_config(
|
||||
self,
|
||||
payloads: dict,
|
||||
tools: Optional[FuncCall] = None,
|
||||
system_instruction: Optional[str] = None,
|
||||
modalities: Optional[List[str]] = None,
|
||||
temperature: float = 0.7,
|
||||
) -> types.GenerateContentConfig:
|
||||
"""准备查询配置"""
|
||||
if not modalities:
|
||||
modalities = ["Text"]
|
||||
|
||||
# 流式输出不支持图片模态
|
||||
if (
|
||||
self.provider_settings.get("streaming_response", False)
|
||||
and "Image" in modalities
|
||||
):
|
||||
logger.warning("流式输出不支持图片模态,已自动降级为文本模态")
|
||||
modalities = ["Text"]
|
||||
|
||||
tool_list = None
|
||||
native_coderunner = self.provider_config.get("gm_native_coderunner", False)
|
||||
native_search = self.provider_config.get("gm_native_search", False)
|
||||
|
||||
if native_coderunner:
|
||||
tool_list = [types.Tool(code_execution=types.ToolCodeExecution())]
|
||||
if native_search:
|
||||
logger.warning("已启用代码执行工具,搜索工具将被忽略")
|
||||
if tools:
|
||||
logger.warning("已启用代码执行工具,函数工具将被忽略")
|
||||
elif native_search:
|
||||
tool_list = [types.Tool(google_search=types.GoogleSearch())]
|
||||
if tools:
|
||||
logger.warning("已启用搜索工具,函数工具将被忽略")
|
||||
elif tools and (func_desc := tools.get_func_desc_google_genai_style()):
|
||||
tool_list = [
|
||||
types.Tool(function_declarations=func_desc["function_declarations"])
|
||||
]
|
||||
return types.GenerateContentConfig(
|
||||
system_instruction=system_instruction,
|
||||
temperature=temperature,
|
||||
max_output_tokens=payloads.get("max_tokens")
|
||||
or payloads.get("maxOutputTokens"),
|
||||
top_p=payloads.get("top_p") or payloads.get("topP"),
|
||||
top_k=payloads.get("top_k") or payloads.get("topK"),
|
||||
frequency_penalty=payloads.get("frequency_penalty")
|
||||
or payloads.get("frequencyPenalty"),
|
||||
presence_penalty=payloads.get("presence_penalty")
|
||||
or payloads.get("presencePenalty"),
|
||||
stop_sequences=payloads.get("stop") or payloads.get("stopSequences"),
|
||||
response_logprobs=payloads.get("response_logprobs")
|
||||
or payloads.get("responseLogprobs"),
|
||||
logprobs=payloads.get("logprobs"),
|
||||
seed=payloads.get("seed"),
|
||||
response_modalities=modalities,
|
||||
tools=tool_list,
|
||||
safety_settings=self.safety_settings if self.safety_settings else None,
|
||||
thinking_config=types.ThinkingConfig(
|
||||
thinking_budget=min(
|
||||
int(
|
||||
self.provider_config.get("gm_thinking_config", {}).get(
|
||||
"budget", 0
|
||||
)
|
||||
),
|
||||
24576,
|
||||
),
|
||||
)
|
||||
if "gemini-2.5-flash" in self.get_model()
|
||||
else None,
|
||||
automatic_function_calling=types.AutomaticFunctionCallingConfig(
|
||||
disable=True
|
||||
),
|
||||
)
|
||||
|
||||
def _prepare_conversation(self, payloads: Dict) -> List[types.Content]:
|
||||
"""准备 Gemini SDK 的 Content 列表"""
|
||||
|
||||
def create_text_part(text: str) -> types.Part:
|
||||
content_a = text if text else " "
|
||||
if not text:
|
||||
logger.warning("文本内容为空,已添加空格占位")
|
||||
return types.Part.from_text(text=content_a)
|
||||
|
||||
def process_image_url(image_url_dict: dict) -> types.Part:
|
||||
url = image_url_dict["url"]
|
||||
mime_type = url.split(":")[1].split(";")[0]
|
||||
image_bytes = base64.b64decode(url.split(",", 1)[1])
|
||||
return types.Part.from_bytes(data=image_bytes, mime_type=mime_type)
|
||||
|
||||
def append_or_extend(
|
||||
contents: list[types.Content],
|
||||
part: list[types.Part],
|
||||
content_cls: type[types.Content],
|
||||
) -> None:
|
||||
if contents and isinstance(contents[-1], content_cls):
|
||||
contents[-1].parts.extend(part)
|
||||
else:
|
||||
contents.append(content_cls(parts=part))
|
||||
|
||||
gemini_contents: List[types.Content] = []
|
||||
native_tool_enabled = any(
|
||||
[
|
||||
self.provider_config.get("gm_native_coderunner", False),
|
||||
self.provider_config.get("gm_native_search", False),
|
||||
]
|
||||
)
|
||||
for message in payloads["messages"]:
|
||||
role, content = message["role"], message.get("content")
|
||||
|
||||
if role == "user":
|
||||
if isinstance(content, list):
|
||||
parts = [
|
||||
types.Part.from_text(text=item["text"] or " ")
|
||||
if item["type"] == "text"
|
||||
else process_image_url(item["image_url"])
|
||||
for item in content
|
||||
]
|
||||
else:
|
||||
parts = [create_text_part(content)]
|
||||
append_or_extend(gemini_contents, parts, types.UserContent)
|
||||
|
||||
elif role == "assistant":
|
||||
if content:
|
||||
parts = [types.Part.from_text(text=content)]
|
||||
append_or_extend(gemini_contents, parts, types.ModelContent)
|
||||
elif not native_tool_enabled and "tool_calls" in message:
|
||||
parts = [
|
||||
types.Part.from_function_call(
|
||||
name=tool["function"]["name"],
|
||||
args=json.loads(tool["function"]["arguments"]),
|
||||
)
|
||||
for tool in message["tool_calls"]
|
||||
]
|
||||
append_or_extend(gemini_contents, parts, types.ModelContent)
|
||||
else:
|
||||
logger.warning("assistant 角色的消息内容为空,已添加空格占位")
|
||||
if native_tool_enabled and "tool_calls" in message:
|
||||
logger.warning(
|
||||
"检测到启用Gemini原生工具,且上下文中存在函数调用,建议使用 /reset 重置上下文"
|
||||
)
|
||||
parts = [types.Part.from_text(text=" ")]
|
||||
append_or_extend(gemini_contents, parts, types.ModelContent)
|
||||
|
||||
elif role == "tool" and not native_tool_enabled:
|
||||
parts = [
|
||||
types.Part.from_function_response(
|
||||
name=message["tool_call_id"],
|
||||
response={
|
||||
"name": message["tool_call_id"],
|
||||
"content": message["content"],
|
||||
},
|
||||
)
|
||||
]
|
||||
append_or_extend(gemini_contents, parts, types.UserContent)
|
||||
|
||||
if gemini_contents and isinstance(gemini_contents[0], types.ModelContent):
|
||||
gemini_contents.pop()
|
||||
|
||||
return gemini_contents
|
||||
|
||||
@staticmethod
|
||||
def _process_content_parts(
|
||||
result: types.GenerateContentResponse, llm_response: LLMResponse
|
||||
) -> MessageChain:
|
||||
"""处理内容部分并构建消息链"""
|
||||
finish_reason = result.candidates[0].finish_reason
|
||||
result_parts: Optional[types.Part] = result.candidates[0].content.parts
|
||||
|
||||
if finish_reason == types.FinishReason.SAFETY:
|
||||
raise Exception("模型生成内容未通过用户定义的内容安全检查")
|
||||
|
||||
if finish_reason in {
|
||||
types.FinishReason.PROHIBITED_CONTENT,
|
||||
types.FinishReason.SPII,
|
||||
types.FinishReason.BLOCKLIST,
|
||||
}:
|
||||
raise Exception("模型生成内容违反Gemini平台政策")
|
||||
|
||||
# 防止旧版本SDK不存在IMAGE_SAFETY
|
||||
if hasattr(types.FinishReason, "IMAGE_SAFETY"):
|
||||
if finish_reason == types.FinishReason.IMAGE_SAFETY:
|
||||
raise Exception("模型生成内容违反Gemini平台政策")
|
||||
|
||||
if not result_parts:
|
||||
logger.debug(result.candidates)
|
||||
raise Exception("API 返回的内容为空。")
|
||||
|
||||
chain = []
|
||||
part: types.Part
|
||||
|
||||
# 暂时这样Fallback
|
||||
if all(
|
||||
part.inline_data and part.inline_data.mime_type.startswith("image/")
|
||||
for part in result_parts
|
||||
):
|
||||
chain.append(Comp.Plain("这是图片"))
|
||||
for part in result_parts:
|
||||
if part.text:
|
||||
chain.append(Comp.Plain(part.text))
|
||||
elif part.function_call:
|
||||
llm_response.role = "tool"
|
||||
llm_response.tools_call_name.append(part.function_call.name)
|
||||
llm_response.tools_call_args.append(part.function_call.args)
|
||||
# gemini 返回的 function_call.id 可能为 None
|
||||
llm_response.tools_call_ids.append(
|
||||
part.function_call.id or part.function_call.name
|
||||
)
|
||||
elif part.inline_data and part.inline_data.mime_type.startswith("image/"):
|
||||
chain.append(Comp.Image.fromBytes(part.inline_data.data))
|
||||
return MessageChain(chain=chain)
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
tool = None
|
||||
if tools:
|
||||
tool = tools.get_func_desc_google_genai_style()
|
||||
if not tool:
|
||||
tool = None
|
||||
"""非流式请求 Gemini API"""
|
||||
system_instruction = next(
|
||||
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
|
||||
None,
|
||||
)
|
||||
|
||||
modalities = ["Text"]
|
||||
if self.provider_config.get("gm_resp_image_modal", False):
|
||||
modalities.append("Image")
|
||||
|
||||
conversation = self._prepare_conversation(payloads)
|
||||
temperature = payloads.get("temperature", 0.7)
|
||||
|
||||
result: Optional[types.GenerateContentResponse] = None
|
||||
while True:
|
||||
try:
|
||||
config = await self._prepare_query_config(
|
||||
payloads, tools, system_instruction, modalities, temperature
|
||||
)
|
||||
result = await self.client.models.generate_content(
|
||||
model=self.get_model(),
|
||||
contents=conversation,
|
||||
config=config,
|
||||
)
|
||||
|
||||
if result.candidates[0].finish_reason == types.FinishReason.RECITATION:
|
||||
if temperature > 2:
|
||||
raise Exception("温度参数已超过最大值2,仍然发生recitation")
|
||||
temperature += 0.2
|
||||
logger.warning(
|
||||
f"发生了recitation,正在提高温度至{temperature:.1f}重试..."
|
||||
)
|
||||
continue
|
||||
|
||||
system_instruction = ""
|
||||
for message in payloads["messages"]:
|
||||
if message["role"] == "system":
|
||||
system_instruction = message["content"]
|
||||
break
|
||||
|
||||
google_genai_conversation = []
|
||||
for message in payloads["messages"]:
|
||||
if message["role"] == "user":
|
||||
if isinstance(message["content"], str):
|
||||
if not message["content"]:
|
||||
message["content"] = ""
|
||||
|
||||
google_genai_conversation.append(
|
||||
{"role": "user", "parts": [{"text": message["content"]}]}
|
||||
except APIError as e:
|
||||
if "Developer instruction is not enabled" in e.message:
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)"
|
||||
)
|
||||
elif isinstance(message["content"], list):
|
||||
# images
|
||||
parts = []
|
||||
for part in message["content"]:
|
||||
if part["type"] == "text":
|
||||
if not part["text"]:
|
||||
part["text"] = ""
|
||||
parts.append({"text": part["text"]})
|
||||
elif part["type"] == "image_url":
|
||||
parts.append(
|
||||
{
|
||||
"inline_data": {
|
||||
"mime_type": "image/jpeg",
|
||||
"data": part["image_url"]["url"].replace(
|
||||
"data:image/jpeg;base64,", ""
|
||||
), # base64
|
||||
}
|
||||
}
|
||||
)
|
||||
google_genai_conversation.append({"role": "user", "parts": parts})
|
||||
|
||||
elif message["role"] == "assistant":
|
||||
if "content" in message:
|
||||
if not message["content"]:
|
||||
message["content"] = ""
|
||||
google_genai_conversation.append(
|
||||
{"role": "model", "parts": [{"text": message["content"]}]}
|
||||
system_instruction = None
|
||||
elif "Function calling is not enabled" in e.message:
|
||||
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
|
||||
tools = None
|
||||
elif (
|
||||
"Multi-modal output is not supported" in e.message
|
||||
or "Model does not support the requested response modalities"
|
||||
in e.message
|
||||
or "only supports text output" in e.message
|
||||
):
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持多模态输出,降级为文本模态"
|
||||
)
|
||||
elif "tool_calls" in message:
|
||||
# tool calls in the last turn
|
||||
parts = []
|
||||
for tool_call in message["tool_calls"]:
|
||||
parts.append(
|
||||
{
|
||||
"functionCall": {
|
||||
"name": tool_call["function"]["name"],
|
||||
"args": json.loads(
|
||||
tool_call["function"]["arguments"]
|
||||
),
|
||||
}
|
||||
}
|
||||
)
|
||||
google_genai_conversation.append({"role": "model", "parts": parts})
|
||||
elif message["role"] == "tool":
|
||||
parts = []
|
||||
parts.append(
|
||||
{
|
||||
"functionResponse": {
|
||||
"name": message["tool_call_id"],
|
||||
"response": {
|
||||
"name": message["tool_call_id"],
|
||||
"content": message["content"],
|
||||
},
|
||||
}
|
||||
}
|
||||
)
|
||||
google_genai_conversation.append({"role": "user", "parts": parts})
|
||||
modalities = ["Text"]
|
||||
else:
|
||||
raise
|
||||
continue
|
||||
|
||||
logger.debug(f"google_genai_conversation: {google_genai_conversation}")
|
||||
|
||||
modalites = ["Text"]
|
||||
if self.provider_config.get("gm_resp_image_modal", False):
|
||||
modalites.append("Image")
|
||||
|
||||
loop = True
|
||||
while loop:
|
||||
loop = False
|
||||
result = await self.client.generate_content(
|
||||
contents=google_genai_conversation,
|
||||
model=self.get_model(),
|
||||
system_instruction=system_instruction,
|
||||
tools=tool,
|
||||
modalities=modalites,
|
||||
safety_settings=self.safety_settings,
|
||||
)
|
||||
logger.debug(f"result: {result}")
|
||||
|
||||
# Developer instruction is not enabled for models/gemini-2.0-flash-exp
|
||||
if "Developer instruction is not enabled" in str(result):
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持 system prompt, 已自动去除, 将会影响人格设置。"
|
||||
)
|
||||
system_instruction = ""
|
||||
loop = True
|
||||
|
||||
elif "Function calling is not enabled" in str(result):
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持函数调用,已自动去除,不影响使用。"
|
||||
)
|
||||
tool = None
|
||||
loop = True
|
||||
|
||||
elif "Multi-modal output is not supported" in str(result):
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持多模态输出,降级为文本模态重新请求。"
|
||||
)
|
||||
modalites = ["Text"]
|
||||
loop = True
|
||||
|
||||
elif "candidates" not in result:
|
||||
raise Exception("Gemini 返回异常结果: " + str(result))
|
||||
|
||||
candidates = result["candidates"][0]["content"]["parts"]
|
||||
llm_response = LLMResponse("assistant")
|
||||
chain = []
|
||||
for candidate in candidates:
|
||||
if "text" in candidate:
|
||||
chain.append(Comp.Plain(candidate["text"]))
|
||||
elif "functionCall" in candidate:
|
||||
llm_response.role = "tool"
|
||||
llm_response.tools_call_args.append(candidate["functionCall"]["args"])
|
||||
llm_response.tools_call_name.append(candidate["functionCall"]["name"])
|
||||
llm_response.tools_call_ids.append(
|
||||
candidate["functionCall"]["name"]
|
||||
) # 没有 tool id
|
||||
elif "inlineData" in candidate:
|
||||
mime_type: str = candidate["inlineData"]["mimeType"]
|
||||
if mime_type.startswith("image/"):
|
||||
chain.append(Comp.Image.fromBase64(candidate["inlineData"]["data"]))
|
||||
|
||||
llm_response.result_chain = MessageChain(chain=chain)
|
||||
llm_response.result_chain = self._process_content_parts(result, llm_response)
|
||||
return llm_response
|
||||
|
||||
async def _query_stream(
|
||||
self, payloads: dict, tools: FuncCall
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""流式请求 Gemini API"""
|
||||
system_instruction = next(
|
||||
(msg["content"] for msg in payloads["messages"] if msg["role"] == "system"),
|
||||
None,
|
||||
)
|
||||
|
||||
conversation = self._prepare_conversation(payloads)
|
||||
|
||||
result = None
|
||||
while True:
|
||||
try:
|
||||
config = await self._prepare_query_config(
|
||||
payloads, tools, system_instruction
|
||||
)
|
||||
result = await self.client.models.generate_content_stream(
|
||||
model=self.get_model(),
|
||||
contents=conversation,
|
||||
config=config,
|
||||
)
|
||||
break
|
||||
except APIError as e:
|
||||
if "Developer instruction is not enabled" in e.message:
|
||||
logger.warning(
|
||||
f"{self.get_model()} 不支持 system prompt,已自动去除(影响人格设置)"
|
||||
)
|
||||
system_instruction = None
|
||||
elif "Function calling is not enabled" in e.message:
|
||||
logger.warning(f"{self.get_model()} 不支持函数调用,已自动去除")
|
||||
tools = None
|
||||
else:
|
||||
raise
|
||||
continue
|
||||
|
||||
async for chunk in result:
|
||||
llm_response = LLMResponse("assistant", is_chunk=True)
|
||||
|
||||
if chunk.candidates[0].content.parts and any(
|
||||
part.function_call for part in chunk.candidates[0].content.parts
|
||||
):
|
||||
llm_response = LLMResponse("assistant", is_chunk=False)
|
||||
llm_response.result_chain = self._process_content_parts(
|
||||
chunk, llm_response
|
||||
)
|
||||
yield llm_response
|
||||
break
|
||||
|
||||
if chunk.text:
|
||||
llm_response.result_chain = MessageChain(chain=[Comp.Plain(chunk.text)])
|
||||
yield llm_response
|
||||
|
||||
if chunk.candidates[0].finish_reason:
|
||||
llm_response = LLMResponse("assistant", is_chunk=False)
|
||||
if not chunk.candidates[0].content.parts:
|
||||
llm_response.result_chain = MessageChain(chain=[Comp.Plain(" ")])
|
||||
else:
|
||||
llm_response.result_chain = self._process_content_parts(
|
||||
chunk, llm_response
|
||||
)
|
||||
yield llm_response
|
||||
break
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -287,7 +472,6 @@ class ProviderGoogleGenAI(Provider):
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = []
|
||||
context_query = [*contexts, new_record]
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
@@ -304,55 +488,90 @@ class ProviderGoogleGenAI(Provider):
|
||||
model_config["model"] = self.get_model()
|
||||
|
||||
payloads = {"messages": context_query, **model_config}
|
||||
llm_response = None
|
||||
|
||||
retry = 10
|
||||
keys = self.api_keys.copy()
|
||||
chosen_key = random.choice(keys)
|
||||
|
||||
for i in range(retry):
|
||||
for _ in range(retry):
|
||||
try:
|
||||
self.client.api_key = chosen_key
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
return await self._query(payloads, func_tool)
|
||||
except APIError as e:
|
||||
if await self._handle_api_error(e, keys):
|
||||
continue
|
||||
break
|
||||
except Exception as e:
|
||||
if "429" in str(e) or "API key not valid" in str(e):
|
||||
keys.remove(chosen_key)
|
||||
if len(keys) > 0:
|
||||
chosen_key = random.choice(keys)
|
||||
logger.info(
|
||||
f"检测到 Key 异常({str(e)}),正在尝试更换 API Key 重试... 当前 Key: {chosen_key[:12]}..."
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
else:
|
||||
logger.error(
|
||||
f"检测到 Key 异常({str(e)}),且已没有可用的 Key。 当前 Key: {chosen_key[:12]}..."
|
||||
)
|
||||
raise Exception("达到了 Gemini 速率限制, 请稍后再试...")
|
||||
else:
|
||||
logger.error(
|
||||
f"发生了错误(gemini_source)。Provider 配置如下: {self.provider_config}"
|
||||
)
|
||||
raise e
|
||||
|
||||
return llm_response
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
func_tool: FuncCall = None,
|
||||
contexts=[],
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = [*contexts, new_record]
|
||||
if system_prompt:
|
||||
context_query.insert(0, {"role": "system", "content": system_prompt})
|
||||
|
||||
for part in context_query:
|
||||
if "_no_save" in part:
|
||||
del part["_no_save"]
|
||||
|
||||
# tool calls result
|
||||
if tool_calls_result:
|
||||
context_query.extend(tool_calls_result.to_openai_messages())
|
||||
|
||||
model_config = self.provider_config.get("model_config", {})
|
||||
model_config["model"] = self.get_model()
|
||||
|
||||
payloads = {"messages": context_query, **model_config}
|
||||
|
||||
retry = 10
|
||||
keys = self.api_keys.copy()
|
||||
|
||||
for _ in range(retry):
|
||||
try:
|
||||
async for response in self._query_stream(payloads, func_tool):
|
||||
yield response
|
||||
break
|
||||
except APIError as e:
|
||||
if await self._handle_api_error(e, keys):
|
||||
continue
|
||||
break
|
||||
|
||||
async def get_models(self):
|
||||
try:
|
||||
models = await self.client.models.list()
|
||||
return [
|
||||
m.name.replace("models/", "")
|
||||
for m in models
|
||||
if "generateContent" in m.supported_actions
|
||||
]
|
||||
except APIError as e:
|
||||
raise Exception(f"获取模型列表失败: {e.message}")
|
||||
|
||||
def get_current_key(self) -> str:
|
||||
return self.client.api_key
|
||||
return self.chosen_api_key
|
||||
|
||||
def get_keys(self) -> List[str]:
|
||||
return self.api_keys
|
||||
|
||||
def set_key(self, key):
|
||||
self.client.api_key = key
|
||||
self.chosen_api_key = key
|
||||
self._init_client()
|
||||
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None):
|
||||
"""
|
||||
组装上下文。
|
||||
"""
|
||||
if image_urls:
|
||||
user_content = {"role": "user", "content": [{"type": "text", "text": text}]}
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": text if text else "[图片]"}],
|
||||
}
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
@@ -384,5 +603,4 @@ class ProviderGoogleGenAI(Provider):
|
||||
return ""
|
||||
|
||||
async def terminate(self):
|
||||
await self.client.client.close()
|
||||
logger.info("Google GenAI 适配器已终止。")
|
||||
|
||||
@@ -2,7 +2,7 @@ import uuid
|
||||
import aiohttp
|
||||
import urllib.parse
|
||||
from ..provider import TTSProvider
|
||||
from ..entites import ProviderType
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import os
|
||||
from llmtuner.chat import ChatModel
|
||||
from typing import List
|
||||
from .. import Provider
|
||||
from ..entites import LLMResponse
|
||||
from ..entities import LLMResponse
|
||||
from ..func_tool_manager import FuncCall
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from ..register import register_provider_adapter
|
||||
@@ -95,6 +95,33 @@ class LLMTunerModelLoader(Provider):
|
||||
|
||||
return llm_response
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt,
|
||||
session_id=None,
|
||||
image_urls=...,
|
||||
func_tool=None,
|
||||
contexts=...,
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
):
|
||||
# raise NotImplementedError("This method is not implemented yet.")
|
||||
# 调用 text_chat 模拟流式
|
||||
llm_response = await self.text_chat(
|
||||
prompt=prompt,
|
||||
session_id=session_id,
|
||||
image_urls=image_urls,
|
||||
func_tool=func_tool,
|
||||
contexts=contexts,
|
||||
system_prompt=system_prompt,
|
||||
tool_calls_result=tool_calls_result,
|
||||
)
|
||||
llm_response.is_chunk = True
|
||||
yield llm_response
|
||||
llm_response.is_chunk = False
|
||||
yield llm_response
|
||||
|
||||
async def get_current_key(self):
|
||||
return "none"
|
||||
|
||||
|
||||
@@ -4,19 +4,24 @@ import os
|
||||
import inspect
|
||||
import random
|
||||
import asyncio
|
||||
import astrbot.core.message.components as Comp
|
||||
|
||||
from openai import AsyncOpenAI, AsyncAzureOpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
|
||||
# from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai._exceptions import NotFoundError, UnprocessableEntityError
|
||||
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.api.provider import Provider, Personality
|
||||
from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from typing import List
|
||||
from typing import List, AsyncGenerator
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.entites import LLMResponse
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
|
||||
|
||||
@register_provider_adapter(
|
||||
@@ -82,7 +87,11 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
async def _query(self, payloads: dict, tools: FuncCall) -> LLMResponse:
|
||||
if tools:
|
||||
tool_list = tools.get_func_desc_openai_style()
|
||||
model = payloads.get("model", "").lower()
|
||||
omit_empty_param_field = "gemini" in model
|
||||
tool_list = tools.get_func_desc_openai_style(
|
||||
omit_empty_parameter_field=omit_empty_param_field
|
||||
)
|
||||
if tool_list:
|
||||
payloads["tools"] = tool_list
|
||||
|
||||
@@ -107,16 +116,76 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
logger.debug(f"completion: {completion}")
|
||||
|
||||
llm_response = await self.parse_openai_completion(completion, tools)
|
||||
|
||||
return llm_response
|
||||
|
||||
async def _query_stream(
|
||||
self, payloads: dict, tools: FuncCall
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""流式查询API,逐步返回结果"""
|
||||
if tools:
|
||||
model = payloads.get("model", "").lower()
|
||||
omit_empty_param_field = "gemini" in model
|
||||
tool_list = tools.get_func_desc_openai_style(
|
||||
omit_empty_parameter_field=omit_empty_param_field
|
||||
)
|
||||
if tool_list:
|
||||
payloads["tools"] = tool_list
|
||||
|
||||
# 不在默认参数中的参数放在 extra_body 中
|
||||
extra_body = {}
|
||||
to_del = []
|
||||
for key in payloads.keys():
|
||||
if key not in self.default_params:
|
||||
extra_body[key] = payloads[key]
|
||||
to_del.append(key)
|
||||
for key in to_del:
|
||||
del payloads[key]
|
||||
|
||||
stream = await self.client.chat.completions.create(
|
||||
**payloads, stream=True, extra_body=extra_body
|
||||
)
|
||||
|
||||
llm_response = LLMResponse("assistant", is_chunk=True)
|
||||
|
||||
state = ChatCompletionStreamState()
|
||||
|
||||
async for chunk in stream:
|
||||
try:
|
||||
state.handle_chunk(chunk)
|
||||
except Exception as e:
|
||||
logger.warning("Saving chunk state error: " + str(e))
|
||||
if len(chunk.choices) == 0:
|
||||
continue
|
||||
delta = chunk.choices[0].delta
|
||||
# 处理文本内容
|
||||
if delta.content:
|
||||
completion_text = delta.content
|
||||
llm_response.result_chain = MessageChain(
|
||||
chain=[Comp.Plain(completion_text)]
|
||||
)
|
||||
yield llm_response
|
||||
|
||||
final_completion = state.get_final_completion()
|
||||
llm_response = await self.parse_openai_completion(final_completion, tools)
|
||||
|
||||
yield llm_response
|
||||
|
||||
async def parse_openai_completion(
|
||||
self, completion: ChatCompletion, tools: FuncCall
|
||||
):
|
||||
"""解析 OpenAI 的 ChatCompletion 响应"""
|
||||
llm_response = LLMResponse("assistant")
|
||||
|
||||
if len(completion.choices) == 0:
|
||||
raise Exception("API 返回的 completion 为空。")
|
||||
choice = completion.choices[0]
|
||||
|
||||
llm_response = LLMResponse("assistant")
|
||||
|
||||
if choice.message.content:
|
||||
# text completion
|
||||
completion_text = str(choice.message.content).strip()
|
||||
llm_response.completion_text = completion_text
|
||||
llm_response.result_chain = MessageChain().message(completion_text)
|
||||
|
||||
if choice.message.tool_calls:
|
||||
# tools call (function calling)
|
||||
@@ -148,7 +217,7 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
return llm_response
|
||||
|
||||
async def text_chat(
|
||||
async def _prepare_chat_payload(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
@@ -158,7 +227,8 @@ class ProviderOpenAIOfficial(Provider):
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
) -> tuple:
|
||||
"""准备聊天所需的有效载荷和上下文"""
|
||||
new_record = await self.assemble_context(prompt, image_urls)
|
||||
context_query = [*contexts, new_record]
|
||||
if system_prompt:
|
||||
@@ -177,13 +247,122 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
payloads = {"messages": context_query, **model_config}
|
||||
|
||||
llm_response = None
|
||||
return payloads, context_query, func_tool
|
||||
|
||||
async def _handle_api_error(
|
||||
self,
|
||||
e: Exception,
|
||||
payloads: dict,
|
||||
context_query: list,
|
||||
func_tool: FuncCall,
|
||||
chosen_key: str,
|
||||
available_api_keys: List[str],
|
||||
retry_cnt: int,
|
||||
max_retries: int,
|
||||
) -> tuple:
|
||||
"""处理API错误并尝试恢复"""
|
||||
if "429" in str(e):
|
||||
logger.warning(
|
||||
f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}"
|
||||
)
|
||||
# 最后一次不等待
|
||||
if retry_cnt < max_retries - 1:
|
||||
await asyncio.sleep(1)
|
||||
available_api_keys.remove(chosen_key)
|
||||
if len(available_api_keys) > 0:
|
||||
chosen_key = random.choice(available_api_keys)
|
||||
return (
|
||||
False,
|
||||
chosen_key,
|
||||
available_api_keys,
|
||||
payloads,
|
||||
context_query,
|
||||
func_tool,
|
||||
)
|
||||
else:
|
||||
raise e
|
||||
elif "maximum context length" in str(e):
|
||||
logger.warning(
|
||||
f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
|
||||
)
|
||||
await self.pop_record(context_query)
|
||||
payloads["messages"] = context_query
|
||||
return (
|
||||
False,
|
||||
chosen_key,
|
||||
available_api_keys,
|
||||
payloads,
|
||||
context_query,
|
||||
func_tool,
|
||||
)
|
||||
elif "The model is not a VLM" in str(e): # siliconcloud
|
||||
# 尝试删除所有 image
|
||||
new_contexts = await self._remove_image_from_context(context_query)
|
||||
payloads["messages"] = new_contexts
|
||||
context_query = new_contexts
|
||||
return (
|
||||
False,
|
||||
chosen_key,
|
||||
available_api_keys,
|
||||
payloads,
|
||||
context_query,
|
||||
func_tool,
|
||||
)
|
||||
elif (
|
||||
"Function calling is not enabled" in str(e)
|
||||
or ("tool" in str(e).lower() and "support" in str(e).lower())
|
||||
or ("function" in str(e).lower() and "support" in str(e).lower())
|
||||
):
|
||||
# openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配
|
||||
logger.info(
|
||||
f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。"
|
||||
)
|
||||
if "tools" in payloads:
|
||||
del payloads["tools"]
|
||||
return False, chosen_key, available_api_keys, payloads, context_query, None
|
||||
else:
|
||||
logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
|
||||
|
||||
if "tool" in str(e).lower() and "support" in str(e).lower():
|
||||
logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all")
|
||||
|
||||
if "Connection error." in str(e):
|
||||
proxy = os.environ.get("http_proxy", None)
|
||||
if proxy:
|
||||
logger.error(
|
||||
f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}"
|
||||
)
|
||||
|
||||
raise e
|
||||
|
||||
async def text_chat(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
func_tool: FuncCall = None,
|
||||
contexts=[],
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
) -> LLMResponse:
|
||||
payloads, context_query, func_tool = await self._prepare_chat_payload(
|
||||
prompt,
|
||||
session_id,
|
||||
image_urls,
|
||||
func_tool,
|
||||
contexts,
|
||||
system_prompt,
|
||||
tool_calls_result,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
llm_response = None
|
||||
max_retries = 10
|
||||
available_api_keys = self.api_keys.copy()
|
||||
chosen_key = random.choice(available_api_keys)
|
||||
|
||||
e = None
|
||||
last_exception = None
|
||||
retry_cnt = 0
|
||||
for retry_cnt in range(max_retries):
|
||||
try:
|
||||
@@ -197,64 +376,103 @@ class ProviderOpenAIOfficial(Provider):
|
||||
payloads["messages"] = new_contexts
|
||||
context_query = new_contexts
|
||||
except Exception as e:
|
||||
if "429" in str(e):
|
||||
logger.warning(
|
||||
f"API 调用过于频繁,尝试使用其他 Key 重试。当前 Key: {chosen_key[:12]}"
|
||||
)
|
||||
# 最后一次不等待
|
||||
if retry_cnt < max_retries - 1:
|
||||
await asyncio.sleep(1)
|
||||
available_api_keys.remove(chosen_key)
|
||||
if len(available_api_keys) > 0:
|
||||
chosen_key = random.choice(available_api_keys)
|
||||
continue
|
||||
else:
|
||||
raise e
|
||||
elif "maximum context length" in str(e):
|
||||
logger.warning(
|
||||
f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}"
|
||||
)
|
||||
await self.pop_record(context_query)
|
||||
elif "The model is not a VLM" in str(e): # siliconcloud
|
||||
# 尝试删除所有 image
|
||||
new_contexts = await self._remove_image_from_context(context_query)
|
||||
payloads["messages"] = new_contexts
|
||||
elif (
|
||||
"Function calling is not enabled" in str(e)
|
||||
or ("tool" in str(e).lower() and "support" in str(e).lower())
|
||||
or ("function" in str(e).lower() and "support" in str(e).lower())
|
||||
):
|
||||
# openai, ollama, gemini openai, siliconcloud 的错误提示与 code 不统一,只能通过字符串匹配
|
||||
logger.info(
|
||||
f"{self.get_model()} 不支持函数工具调用,已自动去除,不影响使用。"
|
||||
)
|
||||
if "tools" in payloads:
|
||||
del payloads["tools"]
|
||||
func_tool = None
|
||||
else:
|
||||
logger.error(
|
||||
f"发生了错误。Provider 配置如下: {self.provider_config}"
|
||||
)
|
||||
|
||||
if "tool" in str(e).lower() and "support" in str(e).lower():
|
||||
logger.error(
|
||||
"疑似该模型不支持函数调用工具调用。请输入 /tool off_all"
|
||||
)
|
||||
|
||||
if "Connection error." in str(e):
|
||||
proxy = os.environ.get("http_proxy", None)
|
||||
if proxy:
|
||||
logger.error(
|
||||
f"可能为代理原因,请检查代理是否正常。当前代理: {proxy}"
|
||||
)
|
||||
|
||||
raise e
|
||||
last_exception = e
|
||||
(
|
||||
success,
|
||||
chosen_key,
|
||||
available_api_keys,
|
||||
payloads,
|
||||
context_query,
|
||||
func_tool,
|
||||
) = await self._handle_api_error(
|
||||
e,
|
||||
payloads,
|
||||
context_query,
|
||||
func_tool,
|
||||
chosen_key,
|
||||
available_api_keys,
|
||||
retry_cnt,
|
||||
max_retries,
|
||||
)
|
||||
if success:
|
||||
break
|
||||
|
||||
if retry_cnt == max_retries - 1:
|
||||
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
|
||||
raise e
|
||||
if last_exception is None:
|
||||
raise Exception("未知错误")
|
||||
raise last_exception
|
||||
return llm_response
|
||||
|
||||
async def text_chat_stream(
|
||||
self,
|
||||
prompt: str,
|
||||
session_id: str = None,
|
||||
image_urls: List[str] = [],
|
||||
func_tool: FuncCall = None,
|
||||
contexts=[],
|
||||
system_prompt=None,
|
||||
tool_calls_result=None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[LLMResponse, None]:
|
||||
"""流式对话,与服务商交互并逐步返回结果"""
|
||||
payloads, context_query, func_tool = await self._prepare_chat_payload(
|
||||
prompt,
|
||||
session_id,
|
||||
image_urls,
|
||||
func_tool,
|
||||
contexts,
|
||||
system_prompt,
|
||||
tool_calls_result,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
max_retries = 10
|
||||
available_api_keys = self.api_keys.copy()
|
||||
chosen_key = random.choice(available_api_keys)
|
||||
|
||||
last_exception = None
|
||||
retry_cnt = 0
|
||||
for retry_cnt in range(max_retries):
|
||||
try:
|
||||
self.client.api_key = chosen_key
|
||||
async for response in self._query_stream(payloads, func_tool):
|
||||
yield response
|
||||
break
|
||||
except UnprocessableEntityError as e:
|
||||
logger.warning(f"不可处理的实体错误:{e},尝试删除图片。")
|
||||
# 尝试删除所有 image
|
||||
new_contexts = await self._remove_image_from_context(context_query)
|
||||
payloads["messages"] = new_contexts
|
||||
context_query = new_contexts
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
(
|
||||
success,
|
||||
chosen_key,
|
||||
available_api_keys,
|
||||
payloads,
|
||||
context_query,
|
||||
func_tool,
|
||||
) = await self._handle_api_error(
|
||||
e,
|
||||
payloads,
|
||||
context_query,
|
||||
func_tool,
|
||||
chosen_key,
|
||||
available_api_keys,
|
||||
retry_cnt,
|
||||
max_retries,
|
||||
)
|
||||
if success:
|
||||
break
|
||||
|
||||
if retry_cnt == max_retries - 1:
|
||||
logger.error(f"API 调用失败,重试 {max_retries} 次仍然失败。")
|
||||
if last_exception is None:
|
||||
raise Exception("未知错误")
|
||||
raise last_exception
|
||||
|
||||
async def _remove_image_from_context(self, contexts: List):
|
||||
"""
|
||||
从上下文中删除所有带有 image 的记录
|
||||
@@ -293,7 +511,10 @@ class ProviderOpenAIOfficial(Provider):
|
||||
async def assemble_context(self, text: str, image_urls: List[str] = None) -> dict:
|
||||
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
|
||||
if image_urls:
|
||||
user_content = {"role": "user", "content": [{"type": "text", "text": text}]}
|
||||
user_content = {
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": text if text else "[图片]"}],
|
||||
}
|
||||
for image_url in image_urls:
|
||||
if image_url.startswith("http"):
|
||||
image_path = await download_image_by_url(image_url)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import uuid
|
||||
from openai import AsyncOpenAI, NOT_GIVEN
|
||||
from ..provider import TTSProvider
|
||||
from ..entites import ProviderType
|
||||
from ..entities import ProviderType
|
||||
from ..register import register_provider_adapter
|
||||
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ import re
|
||||
from funasr_onnx import SenseVoiceSmall
|
||||
from funasr_onnx.utils.postprocess_utils import rich_transcription_postprocess
|
||||
from ..provider import STTProvider
|
||||
from ..entites import ProviderType
|
||||
from ..entities import ProviderType
|
||||
from astrbot.core.utils.io import download_file
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core import logger
|
||||
|
||||
@@ -2,7 +2,7 @@ import uuid
|
||||
import os
|
||||
from openai import AsyncOpenAI, NOT_GIVEN
|
||||
from ..provider import STTProvider
|
||||
from ..entites import ProviderType
|
||||
from ..entities import ProviderType
|
||||
from astrbot.core.utils.io import download_file
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core import logger
|
||||
|
||||
@@ -3,7 +3,7 @@ import os
|
||||
import asyncio
|
||||
import whisper
|
||||
from ..provider import STTProvider
|
||||
from ..entites import ProviderType
|
||||
from ..entities import ProviderType
|
||||
from astrbot.core.utils.io import download_file
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core import logger
|
||||
|
||||
@@ -3,7 +3,7 @@ from astrbot import logger
|
||||
from astrbot.core.provider.func_tool_manager import FuncCall
|
||||
from typing import List
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core.provider.entites import LLMResponse
|
||||
from astrbot.core.provider.entities import LLMResponse
|
||||
from .openai_source import ProviderOpenAIOfficial
|
||||
|
||||
|
||||
|
||||
Regular → Executable
Regular → Executable
@@ -47,5 +47,29 @@ class StarMetadata:
|
||||
star_handler_full_names: List[str] = field(default_factory=list)
|
||||
"""注册的 Handler 的全名列表"""
|
||||
|
||||
supported_platforms: Dict[str, bool] = field(default_factory=dict)
|
||||
"""插件支持的平台ID字典,key为平台ID,value为是否支持"""
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"StarMetadata({self.name}, {self.desc}, {self.version}, {self.repo})"
|
||||
|
||||
def update_platform_compatibility(self, plugin_enable_config: dict) -> None:
|
||||
"""更新插件支持的平台列表
|
||||
|
||||
Args:
|
||||
plugin_enable_config: 平台插件启用配置,即platform_settings.plugin_enable配置项
|
||||
"""
|
||||
if not plugin_enable_config:
|
||||
return
|
||||
|
||||
# 清空之前的配置
|
||||
self.supported_platforms.clear()
|
||||
|
||||
# 遍历所有平台配置
|
||||
for platform_id, plugins in plugin_enable_config.items():
|
||||
# 检查该插件在当前平台的配置
|
||||
if self.name in plugins:
|
||||
self.supported_platforms[platform_id] = plugins[self.name]
|
||||
else:
|
||||
# 如果没有明确配置,默认为启用
|
||||
self.supported_platforms[platform_id] = True
|
||||
|
||||
@@ -30,21 +30,36 @@ class StarHandlerRegistry(Generic[T]):
|
||||
print(handler.handler_full_name)
|
||||
|
||||
def get_handlers_by_event_type(
|
||||
self, event_type: EventType, only_activated=True
|
||||
self, event_type: EventType, only_activated=True, platform_id=None
|
||||
) -> List[StarHandlerMetadata]:
|
||||
"""通过事件类型获取 Handler"""
|
||||
handlers = [
|
||||
handler
|
||||
for _, handler in self._handlers
|
||||
if handler.event_type == event_type
|
||||
and (
|
||||
not only_activated
|
||||
or (
|
||||
star_map[handler.handler_module_path]
|
||||
and star_map[handler.handler_module_path].activated
|
||||
)
|
||||
)
|
||||
]
|
||||
"""通过事件类型获取 Handler
|
||||
|
||||
Args:
|
||||
event_type: 事件类型
|
||||
only_activated: 是否只返回已激活的插件的处理器
|
||||
platform_id: 平台ID,如果提供此参数,将过滤掉在此平台不兼容的处理器
|
||||
|
||||
Returns:
|
||||
List[StarHandlerMetadata]: 处理器列表
|
||||
"""
|
||||
handlers = []
|
||||
for _, handler in self._handlers:
|
||||
if handler.event_type != event_type:
|
||||
continue
|
||||
|
||||
# 只激活的插件处理器
|
||||
if only_activated:
|
||||
plugin = star_map.get(handler.handler_module_path)
|
||||
if not (plugin and plugin.activated):
|
||||
continue
|
||||
|
||||
# 平台兼容性过滤
|
||||
if platform_id and event_type != EventType.OnAstrBotLoadedEvent:
|
||||
if not handler.is_enabled_for_platform(platform_id):
|
||||
continue
|
||||
|
||||
handlers.append(handler)
|
||||
|
||||
return handlers
|
||||
|
||||
def get_handler_by_full_name(self, full_name: str) -> StarHandlerMetadata:
|
||||
@@ -139,3 +154,32 @@ class StarHandlerMetadata:
|
||||
return self.extras_configs.get("priority", 0) < other.extras_configs.get(
|
||||
"priority", 0
|
||||
)
|
||||
|
||||
def is_enabled_for_platform(self, platform_id: str) -> bool:
|
||||
"""检查插件是否在指定平台启用
|
||||
|
||||
Args:
|
||||
platform_id: 平台ID,这是从event.get_platform_id()获取的,用于唯一标识平台实例
|
||||
|
||||
Returns:
|
||||
bool: 是否启用,True表示启用,False表示禁用
|
||||
"""
|
||||
plugin = star_map.get(self.handler_module_path)
|
||||
|
||||
# 如果插件元数据不存在,默认允许执行
|
||||
if not plugin or not plugin.name:
|
||||
return True
|
||||
|
||||
# 先检查插件是否被激活
|
||||
if not plugin.activated:
|
||||
return False
|
||||
|
||||
# 直接使用StarMetadata中缓存的supported_platforms判断平台兼容性
|
||||
if (
|
||||
hasattr(plugin, "supported_platforms")
|
||||
and platform_id in plugin.supported_platforms
|
||||
):
|
||||
return plugin.supported_platforms[platform_id]
|
||||
|
||||
# 如果没有缓存数据,默认允许执行
|
||||
return True
|
||||
|
||||
@@ -28,7 +28,7 @@ from .filter.permission import PermissionTypeFilter, PermissionType
|
||||
|
||||
class PluginManager:
|
||||
def __init__(self, context: Context, config: AstrBotConfig):
|
||||
self.updator = PluginUpdator(config["plugin_repo_mirror"])
|
||||
self.updator = PluginUpdator()
|
||||
|
||||
self.context = context
|
||||
self.context._star_manager = self
|
||||
@@ -166,8 +166,71 @@ class PluginManager:
|
||||
|
||||
return metadata
|
||||
|
||||
def _get_plugin_related_modules(
|
||||
self, plugin_root_dir: str, is_reserved: bool = False
|
||||
) -> list[str]:
|
||||
"""获取与指定插件相关的所有已加载模块名
|
||||
|
||||
根据插件根目录名和是否为保留插件,从 sys.modules 中筛选出相关的模块名
|
||||
|
||||
Args:
|
||||
plugin_root_dir: 插件根目录名
|
||||
is_reserved: 是否是保留插件,影响模块路径前缀
|
||||
|
||||
Returns:
|
||||
list[str]: 与该插件相关的模块名列表
|
||||
"""
|
||||
prefix = "packages." if is_reserved else "data.plugins."
|
||||
return [
|
||||
key
|
||||
for key in list(sys.modules.keys())
|
||||
if key.startswith(f"{prefix}{plugin_root_dir}")
|
||||
]
|
||||
|
||||
def _purge_modules(
|
||||
self,
|
||||
module_patterns: list[str] = None,
|
||||
root_dir_name: str = None,
|
||||
is_reserved: bool = False,
|
||||
):
|
||||
"""从 sys.modules 中移除指定的模块
|
||||
|
||||
可以基于模块名模式或插件目录名移除模块,用于清理插件相关的模块缓存
|
||||
|
||||
Args:
|
||||
module_patterns: 要移除的模块名模式列表(例如 ["data.plugins", "packages"])
|
||||
root_dir_name: 插件根目录名,用于移除与该插件相关的所有模块
|
||||
is_reserved: 插件是否为保留插件(影响模块路径前缀)
|
||||
"""
|
||||
if module_patterns:
|
||||
for pattern in module_patterns:
|
||||
for key in list(sys.modules.keys()):
|
||||
if key.startswith(pattern):
|
||||
del sys.modules[key]
|
||||
logger.debug(f"删除模块 {key}")
|
||||
|
||||
if root_dir_name:
|
||||
for module_name in self._get_plugin_related_modules(
|
||||
root_dir_name, is_reserved
|
||||
):
|
||||
try:
|
||||
del sys.modules[module_name]
|
||||
logger.debug(f"删除模块 {module_name}")
|
||||
except KeyError:
|
||||
logger.warning(f"模块 {module_name} 未载入")
|
||||
|
||||
async def reload(self, specified_plugin_name=None):
|
||||
"""扫描并加载所有的插件 当 specified_module_path 指定时,重载指定插件"""
|
||||
"""重新加载插件
|
||||
|
||||
Args:
|
||||
specified_plugin_name (str, optional): 要重载的特定插件名称。
|
||||
如果为 None,则重载所有插件。
|
||||
|
||||
Returns:
|
||||
tuple: 返回 load() 方法的结果,包含 (success, error_message)
|
||||
- success (bool): 重载是否成功
|
||||
- error_message (str|None): 错误信息,成功时为 None
|
||||
"""
|
||||
specified_module_path = None
|
||||
if specified_plugin_name:
|
||||
for smd in star_registry:
|
||||
@@ -192,9 +255,6 @@ class PluginManager:
|
||||
star_handlers_registry.clear()
|
||||
star_map.clear()
|
||||
star_registry.clear()
|
||||
for key in list(sys.modules.keys()):
|
||||
if key.startswith("data.plugins") or key.startswith("packages"):
|
||||
del sys.modules[key]
|
||||
else:
|
||||
# 只重载指定插件
|
||||
smd = star_map.get(specified_module_path)
|
||||
@@ -209,11 +269,44 @@ class PluginManager:
|
||||
|
||||
await self._unbind_plugin(smd.name, specified_module_path)
|
||||
|
||||
return await self.load(specified_module_path)
|
||||
result = await self.load(specified_module_path)
|
||||
|
||||
# 更新所有插件的平台兼容性
|
||||
await self.update_all_platform_compatibility()
|
||||
|
||||
return result
|
||||
|
||||
async def update_all_platform_compatibility(self):
|
||||
"""更新所有插件的平台兼容性设置"""
|
||||
# 获取最新的平台插件启用配置
|
||||
plugin_enable_config = self.config.get("platform_settings", {}).get(
|
||||
"plugin_enable", {}
|
||||
)
|
||||
logger.debug(
|
||||
f"更新所有插件的平台兼容性设置,平台数量: {len(plugin_enable_config)}"
|
||||
)
|
||||
|
||||
# 遍历所有插件,更新平台兼容性
|
||||
for plugin in self.context.get_all_stars():
|
||||
plugin.update_platform_compatibility(plugin_enable_config)
|
||||
logger.debug(
|
||||
f"插件 {plugin.name} 支持的平台: {list(plugin.supported_platforms.keys())}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
async def load(self, specified_module_path=None, specified_dir_name=None):
|
||||
"""载入插件。
|
||||
当 specified_module_path 或者 specified_dir_name 不为 None 时,只载入指定的插件。
|
||||
|
||||
Args:
|
||||
specified_module_path (str, optional): 指定要加载的插件模块路径。例如: "data.plugins.my_plugin.main"
|
||||
specified_dir_name (str, optional): 指定要加载的插件目录名。例如: "my_plugin"
|
||||
|
||||
Returns:
|
||||
tuple: (success, error_message)
|
||||
- success (bool): 是否全部加载成功
|
||||
- error_message (str|None): 错误信息,成功时为 None
|
||||
"""
|
||||
inactivated_plugins: list = sp.get("inactivated_plugins", [])
|
||||
inactivated_llm_tools: list = sp.get("inactivated_llm_tools", [])
|
||||
@@ -320,6 +413,12 @@ class PluginManager:
|
||||
metadata.root_dir_name = root_dir_name
|
||||
metadata.reserved = reserved
|
||||
|
||||
# 更新插件的平台兼容性
|
||||
plugin_enable_config = self.config.get("platform_settings", {}).get(
|
||||
"plugin_enable", {}
|
||||
)
|
||||
metadata.update_platform_compatibility(plugin_enable_config)
|
||||
|
||||
# 绑定 handler
|
||||
related_handlers = (
|
||||
star_handlers_registry.get_handlers_by_module_name(
|
||||
@@ -447,6 +546,20 @@ class PluginManager:
|
||||
return False, fail_rec
|
||||
|
||||
async def install_plugin(self, repo_url: str, proxy=""):
|
||||
"""从仓库 URL 安装插件
|
||||
|
||||
从指定的仓库 URL 下载并安装插件,然后加载该插件到系统中
|
||||
|
||||
Args:
|
||||
repo_url (str): 要安装的插件仓库 URL
|
||||
proxy (str, optional): 用于下载的代理服务器。默认为空字符串。
|
||||
|
||||
Returns:
|
||||
dict | None: 安装成功时返回包含插件信息的字典:
|
||||
- repo: 插件的仓库 URL
|
||||
- readme: README.md 文件的内容(如果存在)
|
||||
如果找不到插件元数据则返回 None。
|
||||
"""
|
||||
plugin_path = await self.updator.install(repo_url, proxy)
|
||||
# reload the plugin
|
||||
dir_name = os.path.basename(plugin_path)
|
||||
@@ -481,6 +594,14 @@ class PluginManager:
|
||||
return plugin_info
|
||||
|
||||
async def uninstall_plugin(self, plugin_name: str):
|
||||
"""卸载指定的插件。
|
||||
|
||||
Args:
|
||||
plugin_name (str): 要卸载的插件名称
|
||||
|
||||
Raises:
|
||||
Exception: 当插件不存在、是保留插件时,或删除插件文件夹失败时抛出异常
|
||||
"""
|
||||
plugin = self.context.get_registered_star(plugin_name)
|
||||
if not plugin:
|
||||
raise Exception("插件不存在。")
|
||||
@@ -509,9 +630,17 @@ class PluginManager:
|
||||
)
|
||||
|
||||
async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str):
|
||||
"""解绑并移除一个插件。
|
||||
|
||||
Args:
|
||||
plugin_name: 要解绑的插件名称
|
||||
plugin_module_path: 插件的完整模块路径
|
||||
"""
|
||||
plugin = None
|
||||
del star_map[plugin_module_path]
|
||||
for i, p in enumerate(star_registry):
|
||||
if p.name == plugin_name:
|
||||
plugin = p
|
||||
del star_registry[i]
|
||||
break
|
||||
for handler in star_handlers_registry.get_handlers_by_module_name(
|
||||
@@ -521,21 +650,17 @@ class PluginManager:
|
||||
f"移除了插件 {plugin_name} 的处理函数 {handler.handler_name} ({len(star_handlers_registry)})"
|
||||
)
|
||||
star_handlers_registry.remove(handler)
|
||||
keys_to_delete = [
|
||||
k
|
||||
for k, v in star_handlers_registry.star_handlers_map.items()
|
||||
if k.startswith(plugin_module_path)
|
||||
]
|
||||
for k in keys_to_delete:
|
||||
try:
|
||||
del star_handlers_registry.star_handlers_map[k]
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
try:
|
||||
del sys.modules[plugin_module_path]
|
||||
except KeyError:
|
||||
logger.warning(f"模块 {plugin_module_path} 未载入")
|
||||
for k in [
|
||||
k
|
||||
for k in star_handlers_registry.star_handlers_map
|
||||
if k.startswith(plugin_module_path)
|
||||
]:
|
||||
del star_handlers_registry.star_handlers_map[k]
|
||||
|
||||
self._purge_modules(
|
||||
root_dir_name=plugin.root_dir_name, is_reserved=plugin.reserved
|
||||
)
|
||||
|
||||
async def update_plugin(self, plugin_name: str, proxy=""):
|
||||
"""升级一个插件"""
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
import inspect
|
||||
from typing import Union, Awaitable, List, Optional, ClassVar
|
||||
from astrbot.core.message.components import BaseMessageComponent
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.api.platform import MessageMember, AstrBotMessage
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.star.context import Context
|
||||
from astrbot.core.star.star import star_map
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class StarTools:
|
||||
@@ -142,3 +145,48 @@ class StarTools:
|
||||
name (str): 工具名称
|
||||
"""
|
||||
cls._context.unregister_llm_tool(name)
|
||||
|
||||
@classmethod
|
||||
def get_data_dir(cls, plugin_name: Optional[str] = None) -> Path:
|
||||
"""
|
||||
返回插件数据目录的绝对路径。
|
||||
|
||||
此方法会在 data/plugin_data 目录下为插件创建一个专属的数据目录。如果未提供插件名称,
|
||||
会自动从调用栈中获取插件信息。
|
||||
|
||||
Args:
|
||||
plugin_name: 可选的插件名称。如果为None,将自动检测调用者的插件名称。
|
||||
|
||||
Returns:
|
||||
Path (Path): 插件数据目录的绝对路径,位于 data/plugin_data/{plugin_name}。
|
||||
|
||||
Raises:
|
||||
RuntimeError: 当出现以下情况时抛出:
|
||||
- 无法获取调用者模块信息
|
||||
- 无法获取模块的元数据信息
|
||||
- 创建目录失败(权限不足或其他IO错误)
|
||||
"""
|
||||
if not plugin_name:
|
||||
frame = inspect.currentframe().f_back
|
||||
module = inspect.getmodule(frame)
|
||||
|
||||
if not module:
|
||||
raise RuntimeError("无法获取调用者模块信息")
|
||||
|
||||
metadata = star_map.get(module.__name__, None)
|
||||
|
||||
if not metadata:
|
||||
raise RuntimeError(f"无法获取模块 {module.__name__} 的元数据信息")
|
||||
|
||||
plugin_name = metadata.name
|
||||
|
||||
data_dir = Path("data/plugin_data") / plugin_name
|
||||
|
||||
try:
|
||||
data_dir.mkdir(parents=True, exist_ok=True)
|
||||
except OSError as e:
|
||||
if isinstance(e, PermissionError):
|
||||
raise RuntimeError(f"无法创建目录 {data_dir}:权限不足") from e
|
||||
raise RuntimeError(f"无法创建目录 {data_dir}:{e!s}") from e
|
||||
|
||||
return data_dir.resolve()
|
||||
|
||||
@@ -209,20 +209,20 @@ async def get_dashboard_version():
|
||||
return None
|
||||
|
||||
|
||||
async def download_dashboard():
|
||||
async def download_dashboard(path: str = "data/dashboard.zip", extract_path: str = "data"):
|
||||
"""下载管理面板文件"""
|
||||
dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip"
|
||||
try:
|
||||
await download_file(
|
||||
dashboard_release_url, "data/dashboard.zip", show_progress=True
|
||||
dashboard_release_url, path, show_progress=True
|
||||
)
|
||||
except BaseException as _:
|
||||
dashboard_release_url = (
|
||||
"https://github.com/Soulter/AstrBot/releases/latest/download/dist.zip"
|
||||
)
|
||||
await download_file(
|
||||
dashboard_release_url, "data/dashboard.zip", show_progress=True
|
||||
dashboard_release_url, path, show_progress=True
|
||||
)
|
||||
print("解压管理面板文件中...")
|
||||
with zipfile.ZipFile("data/dashboard.zip", "r") as z:
|
||||
z.extractall("data")
|
||||
with zipfile.ZipFile(path, "r") as z:
|
||||
z.extractall(extract_path)
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
import threading
|
||||
import os
|
||||
from logging import Logger
|
||||
|
||||
|
||||
class LogPipe(threading.Thread):
|
||||
def __init__(
|
||||
self,
|
||||
level,
|
||||
logger: Logger,
|
||||
identifier=None,
|
||||
callback=None,
|
||||
):
|
||||
threading.Thread.__init__(self)
|
||||
self.daemon = True
|
||||
self.level = level
|
||||
self.fd_read, self.fd_write = os.pipe()
|
||||
self.identifier = identifier
|
||||
self.logger = logger
|
||||
self.callback = callback
|
||||
self.reader = os.fdopen(self.fd_read)
|
||||
self.start()
|
||||
|
||||
def fileno(self):
|
||||
return self.fd_write
|
||||
|
||||
def run(self):
|
||||
for line in iter(self.reader.readline, ""):
|
||||
if self.callback:
|
||||
self.callback(line.strip())
|
||||
self.logger.log(self.level, f"[{self.identifier}] {line.strip()}")
|
||||
|
||||
self.reader.close()
|
||||
|
||||
def close(self):
|
||||
os.close(self.fd_write)
|
||||
@@ -1,10 +1,42 @@
|
||||
import aiohttp
|
||||
import sys
|
||||
import os
|
||||
import socket
|
||||
import uuid
|
||||
from astrbot.core.config import VERSION
|
||||
from astrbot.core import db_helper, logger
|
||||
|
||||
|
||||
class Metric:
|
||||
_iid_cache = None
|
||||
|
||||
@staticmethod
|
||||
def get_installation_id():
|
||||
"""获取或创建一个唯一的安装ID"""
|
||||
if Metric._iid_cache is not None:
|
||||
return Metric._iid_cache
|
||||
|
||||
config_dir = os.path.join(os.path.expanduser("~"), ".astrbot")
|
||||
id_file = os.path.join(config_dir, ".installation_id")
|
||||
|
||||
if os.path.exists(id_file):
|
||||
try:
|
||||
with open(id_file, "r") as f:
|
||||
Metric._iid_cache = f.read().strip()
|
||||
return Metric._iid_cache
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
os.makedirs(config_dir, exist_ok=True)
|
||||
installation_id = str(uuid.uuid4())
|
||||
with open(id_file, "w") as f:
|
||||
f.write(installation_id)
|
||||
Metric._iid_cache = installation_id
|
||||
return installation_id
|
||||
except Exception:
|
||||
Metric._iid_cache = "null"
|
||||
return "null"
|
||||
|
||||
@staticmethod
|
||||
async def upload(**kwargs):
|
||||
"""
|
||||
@@ -16,6 +48,14 @@ class Metric:
|
||||
kwargs["v"] = VERSION
|
||||
kwargs["os"] = sys.platform
|
||||
payload = {"metrics_data": kwargs}
|
||||
try:
|
||||
kwargs["hn"] = socket.gethostname()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
kwargs["iid"] = Metric.get_installation_id()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
if "adapter_name" in kwargs:
|
||||
db_helper.insert_platform_metrics({kwargs["adapter_name"]: 1})
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
import os
|
||||
from astrbot.core import logger
|
||||
|
||||
def path_Mapping(mappings, srcPath: str)->str:
|
||||
"""路径映射处理函数。尝试支援 Windows 和 Linux 的路径映射。
|
||||
Args:
|
||||
mappings: 映射规则列表
|
||||
srcPath: 原路径
|
||||
Returns:
|
||||
str: 处理后的路径
|
||||
"""
|
||||
for mapping in mappings:
|
||||
rule = mapping.split(":")
|
||||
if len(rule) == 2:
|
||||
from_, to_ = mapping.split(":")
|
||||
elif len(rule) > 4 or len(rule) == 1:
|
||||
# 切割后大于4个项目,或者只有1个项目,那肯定是错误的,只能是2,3,4个项目
|
||||
logger.warning(f"路径映射规则错误: {mapping}")
|
||||
continue
|
||||
else:
|
||||
# rule.len == 3 or 4
|
||||
if(os.path.exists(rule[0]+":"+rule[1])):
|
||||
# 前面两个项目合并路径存在,说明是本地Window路径。后面一个或两个项目组成的路径本地大概率无法解析,直接拼接
|
||||
from_ = rule[0] + ":" + rule[1]
|
||||
if len(rule) == 3:
|
||||
to_ = rule[2]
|
||||
else:
|
||||
to_ = rule[2] + ":" + rule[3]
|
||||
else:
|
||||
# 前面两个项目合并路径不存在,说明第一个项目是本地Linux路径,后面一个或两个项目直接拼接。
|
||||
from_ = rule[0]
|
||||
if len(rule) == 3:
|
||||
to_ = rule[1] + ":" + rule[2]
|
||||
else:
|
||||
# 这种情况下存在四个项目,说明规则也是错误的
|
||||
logger.warning(f"路径映射规则错误: {mapping}")
|
||||
continue
|
||||
|
||||
from_ = from_.removesuffix("/")
|
||||
from_ = from_.removesuffix("\\")
|
||||
to_ = to_.removesuffix("/")
|
||||
to_ = to_.removesuffix("\\")
|
||||
# logger.debug(f"\t路径映射-规则(处理): {from_} -> {to_}")
|
||||
|
||||
url = srcPath.removeprefix("file://")
|
||||
if url.startswith(from_):
|
||||
srcPath = url.replace(from_, to_, 1)
|
||||
if ":" in srcPath:
|
||||
# Windows路径处理
|
||||
srcPath = srcPath.replace("/", "\\")
|
||||
else:
|
||||
has_replaced_processed = False
|
||||
if srcPath.startswith("."):
|
||||
# 相对路径处理。如果是相对路径,可能是Linux路径,也可能是Windows路径
|
||||
sign = srcPath[1]
|
||||
# 处理两个点的情况
|
||||
if sign == ".":
|
||||
sign = srcPath[2]
|
||||
if sign == "/":
|
||||
srcPath = srcPath.replace("\\", "/")
|
||||
has_replaced_processed = True
|
||||
elif sign == "\\":
|
||||
srcPath = srcPath.replace("/", "\\")
|
||||
has_replaced_processed = True
|
||||
if has_replaced_processed == False:
|
||||
# 如果不是相对路径或不能处理,默认按照Linux路径处理
|
||||
srcPath = srcPath.replace("\\", "/")
|
||||
logger.info(f"路径映射: {url} -> {srcPath}")
|
||||
return srcPath
|
||||
return srcPath
|
||||
@@ -5,8 +5,9 @@ logger = logging.getLogger("astrbot")
|
||||
|
||||
|
||||
class PipInstaller:
|
||||
def __init__(self, pip_install_arg: str):
|
||||
def __init__(self, pip_install_arg: str, pypi_index_url: str = None):
|
||||
self.pip_install_arg = pip_install_arg
|
||||
self.pypi_index_url = pypi_index_url
|
||||
|
||||
def install(
|
||||
self,
|
||||
@@ -20,10 +21,9 @@ class PipInstaller:
|
||||
elif requirements_path:
|
||||
args.extend(["-r", requirements_path])
|
||||
|
||||
if not mirror:
|
||||
mirror = "https://mirrors.aliyun.com/pypi/simple/"
|
||||
index_url = mirror or self.pypi_index_url or "https://pypi.org/simple"
|
||||
|
||||
args.extend(["--trusted-host", "mirrors.aliyun.com", "-i", mirror])
|
||||
args.extend(["--trusted-host", "mirrors.aliyun.com", "-i", index_url])
|
||||
|
||||
if self.pip_install_arg:
|
||||
args.extend(self.pip_install_arg.split())
|
||||
|
||||
@@ -97,8 +97,8 @@ class SessionFilter:
|
||||
|
||||
class DefaultSessionFilter(SessionFilter):
|
||||
def filter(self, event: AstrMessageEvent) -> str:
|
||||
"""默认实现,返回发送者的 ID 作为会话标识符"""
|
||||
return event.get_sender_id()
|
||||
"""默认实现,返回统一消息来源字符串作为会话标识符"""
|
||||
return event.unified_msg_origin
|
||||
|
||||
|
||||
class SessionWaiter:
|
||||
|
||||
@@ -9,13 +9,16 @@ class SharedPreferences:
|
||||
|
||||
def _load_preferences(self):
|
||||
if os.path.exists(self.path):
|
||||
with open(self.path, "r") as f:
|
||||
return json.load(f)
|
||||
try:
|
||||
with open(self.path, "r") as f:
|
||||
return json.load(f)
|
||||
except json.JSONDecodeError:
|
||||
os.remove(self.path)
|
||||
return {}
|
||||
|
||||
def _save_preferences(self):
|
||||
with open(self.path, "w") as f:
|
||||
json.dump(self._data, f, indent=4)
|
||||
json.dump(self._data, f, indent=4, ensure_ascii=False)
|
||||
f.flush()
|
||||
|
||||
def get(self, key, default=None):
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
import re
|
||||
|
||||
|
||||
class VersionComparator:
|
||||
@staticmethod
|
||||
def compare_version(v1: str, v2: str) -> int:
|
||||
"""根据 Semver 语义版本规范来比较版本号的大小。支持不仅局限于 3 个数字的版本号,并处理预发布标签。
|
||||
|
||||
参考: https://semver.org/lang/zh-CN/
|
||||
|
||||
返回 1 表示 v1 > v2,返回 -1 表示 v1 < v2,返回 0 表示 v1 = v2。
|
||||
"""
|
||||
v1 = v1.lower().replace("v", "")
|
||||
v2 = v2.lower().replace("v", "")
|
||||
|
||||
def split_version(version):
|
||||
match = re.match(
|
||||
r"^([0-9]+(?:\.[0-9]+)*)(?:-([0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+(.+))?$",
|
||||
version,
|
||||
)
|
||||
if not match:
|
||||
return [], None
|
||||
major_minor_patch = match.group(1).split(".")
|
||||
prerelease = match.group(2)
|
||||
# buildmetadata = match.group(3) # 构建元数据在比较时忽略
|
||||
parts = [int(x) for x in major_minor_patch]
|
||||
prerelease = VersionComparator._split_prerelease(prerelease)
|
||||
return parts, prerelease
|
||||
|
||||
v1_parts, v1_prerelease = split_version(v1)
|
||||
v2_parts, v2_prerelease = split_version(v2)
|
||||
|
||||
# 比较数字部分
|
||||
length = max(len(v1_parts), len(v2_parts))
|
||||
v1_parts.extend([0] * (length - len(v1_parts)))
|
||||
v2_parts.extend([0] * (length - len(v2_parts)))
|
||||
|
||||
for i in range(length):
|
||||
if v1_parts[i] > v2_parts[i]:
|
||||
return 1
|
||||
elif v1_parts[i] < v2_parts[i]:
|
||||
return -1
|
||||
|
||||
# 比较预发布标签
|
||||
if v1_prerelease is None and v2_prerelease is not None:
|
||||
return 1 # 没有预发布标签的版本高于有预发布标签的版本
|
||||
elif v1_prerelease is not None and v2_prerelease is None:
|
||||
return -1 # 有预发布标签的版本低于没有预发布标签的版本
|
||||
elif v1_prerelease is not None and v2_prerelease is not None:
|
||||
len_pre = max(len(v1_prerelease), len(v2_prerelease))
|
||||
for i in range(len_pre):
|
||||
p1 = v1_prerelease[i] if i < len(v1_prerelease) else None
|
||||
p2 = v2_prerelease[i] if i < len(v2_prerelease) else None
|
||||
|
||||
if p1 is None and p2 is not None:
|
||||
return -1
|
||||
elif p1 is not None and p2 is None:
|
||||
return 1
|
||||
elif isinstance(p1, int) and isinstance(p2, str):
|
||||
return -1
|
||||
elif isinstance(p1, str) and isinstance(p2, int):
|
||||
return 1
|
||||
elif isinstance(p1, int) and isinstance(p2, int):
|
||||
if p1 > p2:
|
||||
return 1
|
||||
elif p1 < p2:
|
||||
return -1
|
||||
elif isinstance(p1, str) and isinstance(p2, str):
|
||||
if p1 > p2:
|
||||
return 1
|
||||
elif p1 < p2:
|
||||
return -1
|
||||
return 0 # 预发布标签完全相同
|
||||
|
||||
return 0 # 数字部分和预发布标签都相同
|
||||
|
||||
@staticmethod
|
||||
def _split_prerelease(prerelease):
|
||||
if not prerelease:
|
||||
return None
|
||||
parts = prerelease.split(".")
|
||||
result = []
|
||||
for part in parts:
|
||||
if part.isdigit():
|
||||
result.append(int(part))
|
||||
else:
|
||||
result.append(part)
|
||||
return result
|
||||
@@ -8,6 +8,7 @@ import certifi
|
||||
|
||||
from astrbot.core.utils.io import on_error, download_file
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.version_comparator import VersionComparator
|
||||
|
||||
|
||||
class ReleaseInfo:
|
||||
@@ -102,23 +103,10 @@ class RepoZipUpdator:
|
||||
raise NotImplementedError()
|
||||
|
||||
def compare_version(self, v1: str, v2: str) -> int:
|
||||
"""
|
||||
比较两个版本号的大小。
|
||||
返回 1 表示 v1 > v2,返回 -1 表示 v1 < v2,返回 0 表示 v1 = v2。
|
||||
"""
|
||||
v1 = v1.replace("v", "")
|
||||
v2 = v2.replace("v", "")
|
||||
v1 = v1.split(".")
|
||||
v2 = v2.split(".")
|
||||
"""Semver 版本比较"""
|
||||
return VersionComparator.compare_version(v1, v2)
|
||||
|
||||
for i in range(3):
|
||||
if int(v1[i]) > int(v2[i]):
|
||||
return 1
|
||||
elif int(v1[i]) < int(v2[i]):
|
||||
return -1
|
||||
return 0
|
||||
|
||||
async def check_update(self, url: str, current_version: str) -> ReleaseInfo:
|
||||
async def check_update(self, url: str, current_version: str) -> ReleaseInfo | None:
|
||||
update_data = await self.fetch_release_info(url)
|
||||
tag_name = update_data[0]["tag_name"]
|
||||
|
||||
|
||||
@@ -161,42 +161,53 @@ class ChatRoute(Route):
|
||||
username = g.get("username", "guest")
|
||||
|
||||
if username in self.curr_chat_sse:
|
||||
return "[ERROR]\n"
|
||||
return Response().error("Already connected").__dict__
|
||||
|
||||
self.curr_chat_sse[username] = None
|
||||
|
||||
heartbeat = json.dumps({"type": "heartbeat", "data": "ping"})
|
||||
|
||||
async def stream():
|
||||
try:
|
||||
yield "[HB]\n"
|
||||
yield f"data: {heartbeat}\n\n" # 心跳包
|
||||
while True:
|
||||
try:
|
||||
result = await asyncio.wait_for(
|
||||
web_chat_back_queue.get(), timeout=10
|
||||
) # 设置超时时间为5秒
|
||||
except asyncio.TimeoutError:
|
||||
yield "[HB]\n" # 心跳包
|
||||
yield f"data: {heartbeat}\n\n" # 心跳包
|
||||
continue
|
||||
|
||||
if not result:
|
||||
continue
|
||||
result_text, cid = result
|
||||
|
||||
result_text = result["data"]
|
||||
type = result.get("type")
|
||||
cid = result.get("cid")
|
||||
streaming = result.get("streaming", False)
|
||||
if cid != self.curr_user_cid.get(username):
|
||||
# 丢弃
|
||||
continue
|
||||
yield result_text + "\n"
|
||||
yield f"data: {json.dumps(result, ensure_ascii=False)}\n\n"
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
conversation = self.db.get_conversation_by_user_id(username, cid)
|
||||
try:
|
||||
history = json.loads(conversation.history)
|
||||
except BaseException as e:
|
||||
print(e)
|
||||
history = []
|
||||
history.append({"type": "bot", "message": result_text})
|
||||
self.db.update_conversation(
|
||||
username, cid, history=json.dumps(history)
|
||||
)
|
||||
if streaming and type != "end":
|
||||
continue
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
if result_text:
|
||||
conversation = self.db.get_conversation_by_user_id(
|
||||
username, cid
|
||||
)
|
||||
try:
|
||||
history = json.loads(conversation.history)
|
||||
except BaseException as e:
|
||||
print(e)
|
||||
history = []
|
||||
history.append({"type": "bot", "message": result_text})
|
||||
self.db.update_conversation(
|
||||
username, cid, history=json.dumps(history)
|
||||
)
|
||||
except BaseException as _:
|
||||
logger.debug(f"用户 {username} 断开聊天长连接。")
|
||||
self.curr_chat_sse.pop(username)
|
||||
|
||||
@@ -60,11 +60,13 @@ def validate_config(
|
||||
data[key] = False
|
||||
continue
|
||||
meta = metadata[key]
|
||||
if "type" not in meta:
|
||||
logger.debug(f"配置项 {path}{key} 没有类型定义, 跳过校验")
|
||||
continue
|
||||
# null 转换
|
||||
if value is None:
|
||||
data[key] = DEFAULT_VALUE_MAP[meta["type"]]
|
||||
continue
|
||||
# 递归验证
|
||||
if meta["type"] == "list" and not isinstance(value, list):
|
||||
errors.append(
|
||||
f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}"
|
||||
@@ -179,7 +181,7 @@ class ConfigRoute(Route):
|
||||
await self._save_astrbot_configs(post_configs)
|
||||
return Response().ok(None, "保存成功~ 机器人正在重载配置。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def post_plugin_configs(self):
|
||||
|
||||
@@ -20,7 +20,7 @@ class LogRoute(Route):
|
||||
message = await queue.get()
|
||||
payload = {
|
||||
"type": "log",
|
||||
**message # see astrbot/core/log.py
|
||||
**message, # see astrbot/core/log.py
|
||||
}
|
||||
yield f"data: {json.dumps(payload, ensure_ascii=False)}\n\n"
|
||||
except asyncio.CancelledError:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import traceback
|
||||
import aiohttp
|
||||
import os
|
||||
|
||||
import ssl
|
||||
import certifi
|
||||
@@ -36,6 +37,9 @@ class PluginRoute(Route):
|
||||
"/plugin/off": ("POST", self.off_plugin),
|
||||
"/plugin/on": ("POST", self.on_plugin),
|
||||
"/plugin/reload": ("POST", self.reload_plugins),
|
||||
"/plugin/readme": ("GET", self.get_plugin_readme),
|
||||
"/plugin/platform_enable/get": ("GET", self.get_plugin_platform_enable),
|
||||
"/plugin/platform_enable/set": ("POST", self.set_plugin_platform_enable),
|
||||
}
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.plugin_manager = plugin_manager
|
||||
@@ -141,7 +145,9 @@ class PluginRoute(Route):
|
||||
if handler.event_type == EventType.AdapterMessageEvent:
|
||||
# 处理平台适配器消息事件
|
||||
has_admin = False
|
||||
for filter in (
|
||||
for (
|
||||
filter
|
||||
) in (
|
||||
handler.event_filters
|
||||
): # 正常handler就只有 1~2 个 filter,因此这里时间复杂度不会太高
|
||||
if isinstance(filter, CommandFilter):
|
||||
@@ -317,3 +323,135 @@ class PluginRoute(Route):
|
||||
except Exception as e:
|
||||
logger.error(f"/api/plugin/on: {traceback.format_exc()}")
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def get_plugin_readme(self):
|
||||
plugin_name = request.args.get("name")
|
||||
logger.debug(f"正在获取插件 {plugin_name} 的README文件内容")
|
||||
|
||||
if not plugin_name:
|
||||
logger.warning("插件名称为空")
|
||||
return Response().error("插件名称不能为空").__dict__
|
||||
|
||||
plugin_obj = None
|
||||
for plugin in self.plugin_manager.context.get_all_stars():
|
||||
if plugin.name == plugin_name:
|
||||
plugin_obj = plugin
|
||||
break
|
||||
|
||||
if not plugin_obj:
|
||||
logger.warning(f"插件 {plugin_name} 不存在")
|
||||
return Response().error(f"插件 {plugin_name} 不存在").__dict__
|
||||
|
||||
plugin_dir = os.path.join(
|
||||
self.plugin_manager.plugin_store_path, plugin_obj.root_dir_name
|
||||
)
|
||||
|
||||
if not os.path.isdir(plugin_dir):
|
||||
logger.warning(f"无法找到插件目录: {plugin_dir}")
|
||||
return Response().error(f"无法找到插件 {plugin_name} 的目录").__dict__
|
||||
|
||||
readme_path = os.path.join(plugin_dir, "README.md")
|
||||
|
||||
if not os.path.isfile(readme_path):
|
||||
logger.warning(f"插件 {plugin_name} 没有README文件")
|
||||
return Response().error(f"插件 {plugin_name} 没有README文件").__dict__
|
||||
|
||||
try:
|
||||
with open(readme_path, "r", encoding="utf-8") as f:
|
||||
readme_content = f.read()
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok({"content": readme_content}, "成功获取README内容")
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"/api/plugin/readme: {traceback.format_exc()}")
|
||||
return Response().error(f"读取README文件失败: {str(e)}").__dict__
|
||||
|
||||
async def get_plugin_platform_enable(self):
|
||||
"""获取插件在各平台的可用性配置"""
|
||||
try:
|
||||
platform_enable = self.core_lifecycle.astrbot_config.get(
|
||||
"platform_settings", {}
|
||||
).get("plugin_enable", {})
|
||||
|
||||
# 获取所有可用平台
|
||||
platforms = []
|
||||
|
||||
for platform in self.core_lifecycle.astrbot_config.get("platform", []):
|
||||
platform_type = platform.get("type", "")
|
||||
platform_id = platform.get("id", "")
|
||||
|
||||
platforms.append(
|
||||
{
|
||||
"name": platform_id, # 使用type作为name,这是系统内部使用的平台名称
|
||||
"id": platform_id, # 保留id字段以便前端可以显示
|
||||
"type": platform_type,
|
||||
"display_name": f"{platform_type}({platform_id})",
|
||||
}
|
||||
)
|
||||
|
||||
adjusted_platform_enable = {}
|
||||
for platform_id, plugins in platform_enable.items():
|
||||
adjusted_platform_enable[platform_id] = plugins
|
||||
|
||||
# 获取所有插件,包括系统内部插件
|
||||
plugins = []
|
||||
for plugin in self.plugin_manager.context.get_all_stars():
|
||||
plugins.append(
|
||||
{
|
||||
"name": plugin.name,
|
||||
"desc": plugin.desc,
|
||||
"reserved": plugin.reserved, # 添加reserved标志
|
||||
}
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"获取插件平台配置: 原始配置={platform_enable}, 调整后={adjusted_platform_enable}"
|
||||
)
|
||||
|
||||
return (
|
||||
Response()
|
||||
.ok(
|
||||
{
|
||||
"platforms": platforms,
|
||||
"plugins": plugins,
|
||||
"platform_enable": adjusted_platform_enable,
|
||||
}
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"/api/plugin/platform_enable/get: {traceback.format_exc()}")
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
async def set_plugin_platform_enable(self):
|
||||
"""设置插件在各平台的可用性配置"""
|
||||
if DEMO_MODE:
|
||||
return (
|
||||
Response()
|
||||
.error("You are not permitted to do this operation in demo mode")
|
||||
.__dict__
|
||||
)
|
||||
|
||||
try:
|
||||
data = await request.json
|
||||
platform_enable = data.get("platform_enable", {})
|
||||
|
||||
# 更新配置
|
||||
config = self.core_lifecycle.astrbot_config
|
||||
platform_settings = config.get("platform_settings", {})
|
||||
platform_settings["plugin_enable"] = platform_enable
|
||||
config["platform_settings"] = platform_settings
|
||||
config.save_config()
|
||||
|
||||
# 更新插件的平台兼容性缓存
|
||||
await self.plugin_manager.update_all_platform_compatibility()
|
||||
|
||||
logger.info(f"插件平台可用性配置已更新: {platform_enable}")
|
||||
|
||||
return Response().ok(None, "插件平台可用性配置已更新").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/plugin/platform_enable/set: {traceback.format_exc()}")
|
||||
return Response().error(str(e)).__dict__
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import os
|
||||
import json
|
||||
import aiohttp
|
||||
import traceback
|
||||
from .route import Route, Response, RouteContext
|
||||
from quart import request
|
||||
@@ -20,6 +21,7 @@ class ToolsRoute(Route):
|
||||
"/tools/mcp/add": ("POST", self.add_mcp_server),
|
||||
"/tools/mcp/update": ("POST", self.update_mcp_server),
|
||||
"/tools/mcp/delete": ("POST", self.delete_mcp_server),
|
||||
"/tools/mcp/market": ("GET", self.get_mcp_markets),
|
||||
}
|
||||
self.register_routes()
|
||||
self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools
|
||||
@@ -78,6 +80,7 @@ class ToolsRoute(Route):
|
||||
) in self.tool_mgr.mcp_client_dict.items():
|
||||
if name_key == name:
|
||||
server_info["tools"] = [tool.name for tool in mcp_client.tools]
|
||||
server_info["errlogs"] = mcp_client.server_errlogs
|
||||
break
|
||||
else:
|
||||
server_info["tools"] = []
|
||||
@@ -105,8 +108,14 @@ class ToolsRoute(Route):
|
||||
|
||||
# 复制所有配置字段
|
||||
for key, value in server_data.items():
|
||||
if key not in ["name", "active", "tools"]: # 排除特殊字段
|
||||
server_config[key] = value
|
||||
if key not in ["name", "active", "tools", "errlogs"]: # 排除特殊字段
|
||||
if key == "mcpServers":
|
||||
key_0 = list(server_data["mcpServers"].keys())[
|
||||
0
|
||||
] # 不考虑为空的情况
|
||||
server_config = server_data["mcpServers"][key_0]
|
||||
else:
|
||||
server_config[key] = value
|
||||
has_valid_config = True
|
||||
|
||||
if not has_valid_config:
|
||||
@@ -121,7 +130,7 @@ class ToolsRoute(Route):
|
||||
|
||||
if self.save_mcp_config(config):
|
||||
# 动态初始化新MCP客户端
|
||||
self.tool_mgr.mcp_service_queue.put_nowait(
|
||||
await self.tool_mgr.mcp_service_queue.put(
|
||||
{
|
||||
"type": "init",
|
||||
"name": name,
|
||||
@@ -162,8 +171,14 @@ class ToolsRoute(Route):
|
||||
|
||||
# 复制所有配置字段
|
||||
for key, value in server_data.items():
|
||||
if key not in ["name", "active", "tools"]: # 排除特殊字段
|
||||
server_config[key] = value
|
||||
if key not in ["name", "active", "tools", "errlogs"]: # 排除特殊字段
|
||||
if key == "mcpServers":
|
||||
key_0 = list(server_data["mcpServers"].keys())[
|
||||
0
|
||||
] # 不考虑为空的情况
|
||||
server_config = server_data["mcpServers"][key_0]
|
||||
else:
|
||||
server_config[key] = value
|
||||
only_update_active = False
|
||||
|
||||
# 如果只更新活动状态,保留原始配置
|
||||
@@ -194,7 +209,7 @@ class ToolsRoute(Route):
|
||||
)
|
||||
else:
|
||||
# 客户端不存在,初始化
|
||||
self.tool_mgr.mcp_service_queue.put_nowait(
|
||||
await self.tool_mgr.mcp_service_queue.put(
|
||||
{
|
||||
"type": "init",
|
||||
"name": name,
|
||||
@@ -250,3 +265,26 @@ class ToolsRoute(Route):
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(f"删除 MCP 服务器失败: {str(e)}").__dict__
|
||||
|
||||
async def get_mcp_markets(self):
|
||||
page = request.args.get("page", 1, type=int)
|
||||
page_size = request.args.get("page_size", 10, type=int)
|
||||
BASE_URL = "https://api.soulter.top/astrbot/mcpservers?page={}&page_size={}".format(
|
||||
page,
|
||||
page_size,
|
||||
)
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(f"{BASE_URL}") as response:
|
||||
if response.status == 200:
|
||||
data = await response.json()
|
||||
return Response().ok(data["data"]).__dict__
|
||||
else:
|
||||
return (
|
||||
Response()
|
||||
.error(f"获取市场数据失败: HTTP {response.status}")
|
||||
.__dict__
|
||||
)
|
||||
except Exception as _:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error("获取市场数据失败").__dict__
|
||||
@@ -136,10 +136,11 @@ class UpdateRoute(Route):
|
||||
|
||||
data = await request.json
|
||||
package = data.get("package", "")
|
||||
mirror = data.get("mirror", None)
|
||||
if not package:
|
||||
return Response().error("缺少参数 package 或不合法。").__dict__
|
||||
try:
|
||||
pip_installer.install(package)
|
||||
pip_installer.install(package, mirror=mirror)
|
||||
return Response().ok(None, "安装成功。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/update_pip: {traceback.format_exc()}")
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
# What's Changed
|
||||
|
||||
> 📢 在升级前,请完整阅读本次更新日志。
|
||||
> 此版本为针对 `v3.5.3` 的紧急修复版本
|
||||
|
||||
## ✨ 新增的功能
|
||||
|
||||
1. Telegram、Webchat、QQ官方机器人平台(私聊)支持流式输出(实验性)。@Soulter @Raven95676 @anka-afk
|
||||
2. 支持针对不同消息平台开启/关闭插件 @zhx8702 @Raven95676 @Soulter
|
||||
3. 插件市场支持显示 Star 个数、插件管理支持插件帮助对话框 @kterna
|
||||
4. 飞书平台支持主动消息发送 @Soulter
|
||||
5. Telegram 平台适配显示指令列表,支持自动补全 @Raven95676
|
||||
6. 新增配置项允许配置当超出最多携带对话数量时,一次性丢弃多少条旧消息 @Rail1bc
|
||||
7. StarTool 新增获取插件数据目录接口 @Raven95676
|
||||
|
||||
## 🎈 功能性优化
|
||||
|
||||
1. 优化 /his 指令对函数调用的显示 @anka-afk
|
||||
2. QQ 官方机器人支持对同一条消息多次回复 @kuangfeng
|
||||
|
||||
## 🐛 修复的 Bug
|
||||
|
||||
1. ‼️ 修复使用 gemini 时,函数数工具调用会重复调用已经在过去会话中调用过的工具 @Soulter
|
||||
2. 修复使用 Gemini 模型时出现 <empty_content> 的问题 @anka-afk
|
||||
4. 修复使用 OneAPI + Gemini(openai) 传递空参数函数工具时可能报错的问题 @Soulter
|
||||
5. 修复 permission 过滤算子的 raise_error 参数失效的问题 @Soulter
|
||||
6. 修复函数调用时可能出现 `messages with role 'tool' must be a response to a preceeding message with 'tool_calls'` 报错的问题 @anka-afk
|
||||
7. 修复 dify 下删除对话的报错问题 @Soulter
|
||||
8. 修复人格预设对话多次插入上下文的问题 @Rail1bc
|
||||
9. 修复了 event.get_sender_id() 返回值与函数注释不一致的问题 @zsbai
|
||||
|
||||
|
||||
## 🧩 新增的插件
|
||||
|
||||
待补充
|
||||
@@ -0,0 +1,41 @@
|
||||
# What's Changed
|
||||
|
||||
> 📢 在升级前,请完整阅读本次更新日志。
|
||||
> 此版本为针对 `v3.5.3` 的紧急修复版本
|
||||
> 修复以下 BUG:
|
||||
> 1. 智谱 GLM 在函数工具有空参数时报错的问题。
|
||||
|
||||
---
|
||||
|
||||
v3.5.3
|
||||
|
||||
## ✨ 新增的功能
|
||||
|
||||
1. Telegram、Webchat、QQ官方机器人平台(私聊)支持流式输出(实验性)。@Soulter @Raven95676 @anka-afk
|
||||
2. 支持针对不同消息平台开启/关闭插件 @zhx8702 @Raven95676 @Soulter
|
||||
3. 插件市场支持显示 Star 个数、插件管理支持插件帮助对话框 @kterna
|
||||
4. 飞书平台支持主动消息发送 @Soulter
|
||||
5. Telegram 平台适配显示指令列表,支持自动补全 @Raven95676
|
||||
6. 新增配置项允许配置当超出最多携带对话数量时,一次性丢弃多少条旧消息 @Rail1bc
|
||||
7. StarTool 新增获取插件数据目录接口 @Raven95676
|
||||
|
||||
## 🎈 功能性优化
|
||||
|
||||
1. 优化 /his 指令对函数调用的显示 @anka-afk
|
||||
2. QQ 官方机器人支持对同一条消息多次回复 @kuangfeng
|
||||
|
||||
## 🐛 修复的 Bug
|
||||
|
||||
1. ‼️ 修复使用 gemini 时,函数数工具调用会重复调用已经在过去会话中调用过的工具 @Soulter
|
||||
2. 修复使用 Gemini 模型时出现 <empty_content> 的问题 @anka-afk
|
||||
4. 修复使用 OneAPI + Gemini(openai) 传递空参数函数工具时可能报错的问题 @Soulter
|
||||
5. 修复 permission 过滤算子的 raise_error 参数失效的问题 @Soulter
|
||||
6. 修复函数调用时可能出现 `messages with role 'tool' must be a response to a preceeding message with 'tool_calls'` 报错的问题 @anka-afk
|
||||
7. 修复 dify 下删除对话的报错问题 @Soulter
|
||||
8. 修复人格预设对话多次插入上下文的问题 @Rail1bc
|
||||
9. 修复了 event.get_sender_id() 返回值与函数注释不一致的问题 @zsbai
|
||||
|
||||
|
||||
## 🧩 新增的插件
|
||||
|
||||
待补充
|
||||
@@ -0,0 +1,34 @@
|
||||
# What's Changed
|
||||
|
||||
> 📢 在升级前,请完整阅读本次更新日志。
|
||||
|
||||
## ✨ 新增的功能
|
||||
|
||||
1. Telegram、Webchat、QQ官方机器人平台(私聊)支持流式输出(实验性)。@Soulter @Raven95676 @anka-afk
|
||||
2. 支持针对不同消息平台开启/关闭插件 @zhx8702 @Raven95676 @Soulter
|
||||
3. 插件市场支持显示 Star 个数、插件管理支持插件帮助对话框 @kterna
|
||||
4. 飞书平台支持主动消息发送 @Soulter
|
||||
5. Telegram 平台适配显示指令列表,支持自动补全 @Raven95676
|
||||
6. 新增配置项允许配置当超出最多携带对话数量时,一次性丢弃多少条旧消息 @Rail1bc
|
||||
7. StarTool 新增获取插件数据目录接口 @Raven95676
|
||||
|
||||
## 🎈 功能性优化
|
||||
|
||||
1. 优化 /his 指令对函数调用的显示 @anka-afk
|
||||
2. QQ 官方机器人支持对同一条消息多次回复 @kuangfeng
|
||||
|
||||
## 🐛 修复的 Bug
|
||||
|
||||
1. ‼️ 修复使用 gemini 时,函数数工具调用会重复调用已经在过去会话中调用过的工具 @Soulter
|
||||
2. 修复使用 Gemini 模型时出现 <empty_content> 的问题 @anka-afk
|
||||
4. 修复使用 OneAPI + Gemini(openai) 传递空参数函数工具时可能报错的问题 @Soulter
|
||||
5. 修复 permission 过滤算子的 raise_error 参数失效的问题 @Soulter
|
||||
6. 修复函数调用时可能出现 `messages with role 'tool' must be a response to a preceeding message with 'tool_calls'` 报错的问题 @anka-afk
|
||||
7. 修复 dify 下删除对话的报错问题 @Soulter
|
||||
8. 修复人格预设对话多次插入上下文的问题 @Rail1bc
|
||||
9. 修复了 event.get_sender_id() 返回值与函数注释不一致的问题 @zsbai
|
||||
|
||||
|
||||
## 🧩 新增的插件
|
||||
|
||||
待补充
|
||||
@@ -0,0 +1,156 @@
|
||||
# What's Changed
|
||||
|
||||
> 📢 在升级前,请完整阅读本次更新日志。
|
||||
|
||||
## ✨ 新增的功能
|
||||
|
||||
1. 上线 MCP 市场(beta) @Soulter
|
||||
2. MCP 服务器支持通过 SSE 连接 @Soulter
|
||||
3. 支持自定义 PyPI 软件仓库地址 @Soulter
|
||||
4. 支持开关是否忽略自身发送的消息 @Soulter
|
||||
5. Docker 镜像自带 node 环境以适应 MCP 需要 @Soulter
|
||||
6. 添加对 Gemini 原生搜索功能的支持 @Raven95676
|
||||
7. 企业微信添加长文本分割功能以支持发送超过 2048 字符的消息 @Soulter @anka-afk
|
||||
8. TTS 支持同时输出原始文本 @YOO-koishi
|
||||
|
||||
## 🎈 功能性优化
|
||||
|
||||
1. shared_preferences加载失败时自动删除无效文件 @Raven95676
|
||||
2. 适配 MCP 配置文件带 mcpServers 的情况(Cursor) @Soulter
|
||||
3. 采用 google-genai SDK 重构 Gemini 适配器 @Raven95676 @Soulter
|
||||
4. 优化已安装的插件页,支持以列表展示 @Soulter
|
||||
5. 分段回复优化 @huirh @Raven95676
|
||||
6. 优化 MCP 服务器的日志回显 @Soulter
|
||||
7. 为不支持流式输出的平台提供 Fallback @Raven95676
|
||||
8. 替换为采用 Semver 语义化版本来比较版本号 @Soulter
|
||||
9. 文件发送时支持路径映射 @Jackxwb
|
||||
|
||||
## 🐛 修复的 Bug
|
||||
|
||||
1. 修复关闭/删除 MCP 服务器后 Tools 没有清除的问题 @Soulter
|
||||
2. 修复超出最大对话数时每次清除的消息比实际上期望的多 1 条 的问题 @Raila23
|
||||
3. 修复调用函数工具可能导致 400 Bad Request 的问题 @Raila23
|
||||
4. 修复飞书适配器无法发送 Base64 图片的问题 @KimigaiiWuyi @Soulter
|
||||
5. 修复上下文带图的情况下,对话数据库页无法查看对话详情的问题 @Soulter
|
||||
6. Telegram 适配器注册指令功能优化 @Raven95676
|
||||
7. 修复阿里云百炼 TTS 只能发送一次语音,第二次就会报错 @Soulter
|
||||
|
||||
## 🧩 新增的插件
|
||||
|
||||
> Automatically generated by program.
|
||||
|
||||
- [Plugin] 60秒国内新闻 by @bbpn-cn in #970
|
||||
- [Plugin] astrbot_plugin_memelite by @Zhalslar in #977
|
||||
- [Plugin] 赛博打胶 by @tenno1174 in #980
|
||||
- [Plugin] astrbot_plugin_PockAttack by @LouieKH359 in #981
|
||||
- [Plugin] astrbot_plugin_saris_economic by @chengcheng0325 in #984
|
||||
- [Plugin] astrbot_plugin_saris_db by @chengcheng0325 in #985
|
||||
- [Plugin] astrbot_plugin_today_in_history by @Zhalslar in #987
|
||||
- [Plugin] astrbot_plugin_history_day by @Zhalslar in #989
|
||||
- [Plugin] astrbot_plugin_nachoneko by @Rinyin in #991
|
||||
- [Plugin] jmcomicsget by @Ayachi2225 in #993
|
||||
- [Plugin] astrbot_plugin_idiom by @zhx8702 in #994
|
||||
- [Plugin] anime_gacha by @xco2 in #997
|
||||
- [Plugin] jmcomic_downloader by @QiChenSn in #1007
|
||||
- [Plugin] gewe_chatsummary by @NiceAir in #1013
|
||||
- [Plugin] 群CCB by @tenno1174 in #1016
|
||||
- [Plugin] 自动生成图表(思维导图、流程图等) by @kterna in #1018
|
||||
- [Plugin] astrbot_plugin_jm_sender by @EnderPPT in #1019
|
||||
- [Plugin] 插件名random image by @IGCrystal in #1021
|
||||
- [Plugin] astrbot_plugin_saris_fish by @chengcheng0325 in #1022
|
||||
- [Plugin] astrbot_plugin_membercontrast by @laopanmemz in #1027
|
||||
- [Plugin] jm_search by @Ryonnoski0 in #1028
|
||||
- [Plugin] astrbot_plugin_gomoku by @zhx8702 in #1029
|
||||
- [Plugin] astrbot_plugin_CounterStrikle by @Last-emo-boy in #1036
|
||||
- [Plugin] encrypt-and-decrypt by @Soffd in #1037
|
||||
- [Plugin] emoji合成 by @ttq7 in #1041
|
||||
- [Plugin] astrbot_plugin_file_reader by @xiewoc in #1043
|
||||
- [Plugin] vv_pic by @LonelySky7490 in #1048
|
||||
- [Plugin] bot代戳 by @791819 in #1049
|
||||
- [Plugin] astrbot_plugin_weather_wttr_in by @xiewoc in #1051
|
||||
- [Plugin] astrbot_plugin_kahunabot by @AraragiEro in #1053
|
||||
- [Plugin] astrbot_plugin_answerbook by @litsum in #1058
|
||||
- [Plugin] astrbot_plugin_ewords by @IGCrystal in #1059
|
||||
- [Plugin] minecraft投影管理器 by @kterna in #1064
|
||||
- [Plugin] minecraft投影管理器 by @kterna in #1063
|
||||
- [Plugin] astrbot_plugin_encipherer by @Soffd in #1066
|
||||
- [Plugin] 定时任务提醒插件 by @advent259141 in #1068
|
||||
- [Plugin] 定时任务提醒插件 by @advent259141 in #1067
|
||||
- [Plugin] bot_plugin_doro_today by @Futureppo in #1071
|
||||
- [Plugin] JmCli by @gaxiic in #1076
|
||||
- [Plugin] 用户自定义识别nickname by @MR-pofeng in #1078
|
||||
- [Plugin] astrbot_plugin_timtip by @IGCrystal in #1082
|
||||
- [Plugin] y @caomeiguodong in #1083
|
||||
- [Plugin] astrbot_plugin_search_pic by @lyjlyjlyjly in #1084
|
||||
- [Plugin] ot_plugin_quote_collocter by @litsum in #1089
|
||||
- [Plugin] astrbot_plugin_ending by @clfpwp in #1090
|
||||
- [Plugin] astrbot_plugin_GPT_SoVITS by @Zhalslar in #1091
|
||||
- [Plugin] astrbot_plugin_Merge_WeMSG by @zj591227045 in #1092
|
||||
- [Plugin] 随机维什戴尔游戏日语语音的AstrBot插件 by @zhewang448 in #1094
|
||||
- [Plugin] astrbot_plugin_QQProfile by @Zhalslar in #1095
|
||||
- [Plugin] astrbot_plugin_mcping by @Zhalslar in #1097
|
||||
- [Plugin] astrbot_plugin_lorebook_lite by @Raven95676 in #1098
|
||||
- [Plugin] astrbot_plugin_grok_filter by @Cheng-MaoMao in #1099
|
||||
- [Plugin] astrbot_plugin_quarksave by @lm379 in #1100
|
||||
- [Plugin] daily_limit by @left666 in #1102
|
||||
- [Plugin] 赛博塔罗牌 by @XziXmn in #1103
|
||||
- [Plugin] 把QQ里面不可保存的表情转化为可以保存的插件astrbot_plugins_ConvetPicture by @orchidsziyou in #1115
|
||||
- [Plugin] astrbot_plugin_get_weather_cmd by @whzcc in #1117
|
||||
- [Plugin] httpposter by @Wayzinx in #1118
|
||||
- [Plugin] astrbot_plugin_get_weather_msg by @whzcc in #1119
|
||||
- [Plugin] astrbot_plugin_cs2-box by @bvzrays in #1124
|
||||
- [Plugin] astrbot_plugin_liars_bar by @xunxiing in #1125
|
||||
- [Plugin] astrbot_plugin_cs2-box by @bvzrays in #1129
|
||||
- [Plugin] astrbot_plugin_no_dragon_lord by @anka-afk in #1130
|
||||
- [Plugin] astrbot_plugin_liars_bar by @xunxiing in #1134
|
||||
- [Plugin] astrbot_plugin_QQAdmin by @Zhalslar in #1137
|
||||
- [Plugin] astrbot_plugin_SessionFaker by @advent259141 in #1138
|
||||
- [Plugin] astrbot_plugin_browser by @Zhalslar in #1140
|
||||
- [Plugin] astrbot_plugin_aishit by @advent259141 in #1141
|
||||
- [Plugin] Arch Linux 软件包搜索插件 by @xmengnet in #1142
|
||||
- [Plugin] astrbot_plugin_composting_bucket by @Rail1bc in #1147
|
||||
- [Plugin] 任务管理task-management by @zengweis in #1149
|
||||
- [Plugin] astrbot_plugin_appreview by @qiqi55488 in #1151
|
||||
- [Plugin] doro互动故事 by @ttq7 in #1153
|
||||
- [Plugin] 斗牛牛 by @LaoZhuJackson in #1155
|
||||
- [Plugin] astrbot_plugin_reread by @Zhalslar in #1162
|
||||
- [Plugin] astrbot_plugin_media302-save by @Qoo-330ml in #1163
|
||||
- [Plugin] astrbot_plugin_ehentai_bot by @drdon1234 in #1168
|
||||
- [Plugin] 追番助手(AGE) by @xiamuceer-j in #1181
|
||||
- [Plugin] astrbot_plugin_zanwo by @Futureppo in #1183
|
||||
- [Plugin] astrbot_plugin_jrrp by @exusiaiwei in #1189
|
||||
- [Plugin] astrbot_plugin_group_chatsummary by @glidersxu in #1193
|
||||
- [Plugin] astrbot_plugin_showmejm by @drdon1234 in #1202
|
||||
- [Plugin] astrbot_plugin_gscore_adapter by @KimigaiiWuyi in #1206
|
||||
- [Plugin] astrbot_portainer_plugin by @RC-CHN in #1209
|
||||
- [Plugin] astrbot_plugin_goldprice by @waterfeet in #1210
|
||||
- [Plugin] astrbot_plugin_xyzw by @XuYingJie-cmd in #1213
|
||||
- [Plugin] astrbot_plugin_alist by @yukikazechan in #1217
|
||||
- [Plugin] astrbot_plugin_showme_xjj by @drdon1234 in #1219
|
||||
- [Plugin] astrbot_plugin_60s_news by @flyinsz in #1220
|
||||
- [Plugin] astrbot_plugin_mp by @EWEDLCM in #1229
|
||||
- [Plugin] 浅草寺抽签插件-PRO by @xiamuceer-j in #1230
|
||||
- [Plugin] astrbot_plugin_gallery by @Zhalslar in #1238
|
||||
- [Plugin] astrbot_plugin_openweaponscase by @luooka in #1250
|
||||
- [Plugin] 多功能插件 by @ttq7 in #1254
|
||||
- [Plugin] astrbot_plugin_douyin_bot by @drdon1234 in #1255
|
||||
- [Plugin] astrbot_plugin_password by @Zhalslar in #1262
|
||||
- [Plugin] FavorSystem by @wuyan1003 in #1264
|
||||
- [Plugin] astrbot_plugin_ExchangeRateQuery by @MoonShadow1976 in #1271
|
||||
- [Plugin] astrbot_plugin_pexels by @xiamuceer-j in #1278
|
||||
- [Plugin] astrbot_plugin_hello-bye by @tinkerbellqwq in #1279
|
||||
- [Plugin] astrbot_plugin_gotify by @BetaCatX in #1283
|
||||
- [Plugin] astrbot_plugin_ccb_plus by @Koikokokokoro in #1287
|
||||
- [Plugin] astrbot-gold-plugin by @RC-CHN in #1299
|
||||
- [Plugin] 幻影坦克 by @bigshabei in #1305
|
||||
- [Plugin] astrbot_plugin_repeat_after_me by @0d00-Ciallo-0721 in #1306
|
||||
- [Plugin] astrbot_plugin_video by @guowenye in #1307
|
||||
- [Plugin] astrbot_plugin_group_sum_ai by @Ayu-u in #1308
|
||||
- [Plugin] 五子棋 by @bigshabei in #1309
|
||||
- [Plugin] 舔狗日记 by @bigshabei in #1310
|
||||
- [Plugin] astrbot_plugin_status-pro by @tinkerbellqwq in #1312
|
||||
- [Plugin] 监听/转发 by @Cedar2352 in #1322
|
||||
- [Plugin] astrbot_plugin_fuck by @vmoranv in #1338
|
||||
- [Plugin] 食物推荐插件 by @Wayzinx in #1331
|
||||
- [Plugin] astrbot_plugin_llmgo by @advent259141 in #1332
|
||||
- [Plugin] astrbot_plugin_a2s by @ZvZPvz in #1337
|
||||
@@ -0,0 +1,6 @@
|
||||
# What's Changed
|
||||
|
||||
## 🐛 修复的 Bug
|
||||
|
||||
1. 修复 Gemini 下可能无法正常使用 Tools 的问题 @Raven95676
|
||||
2. 修复 WebUI MCP 页面的一些问题 @Soulter
|
||||
@@ -0,0 +1,13 @@
|
||||
# What's Changed
|
||||
|
||||
> 🙁 Gewechat 已经停止维护,我们将更换更稳定的个人微信接入方式。如有问题请提交 issue。
|
||||
> 🧐 预告:接下来三个版本之内将会逐步上线 Live2D 桌宠、长期记忆(实验性)的功能。
|
||||
|
||||
1. Gewechat 相关 bug 修复(即使已经不可用 :( ) @BigFace123 @XiGuang @Soulter
|
||||
2. 支持 CLI 命令行 @LIghtJUNction
|
||||
3. 修复 QQ 下带有网址的指令可能无法识别的问题 @kkjzio
|
||||
4. `reset` 指令优化 @anka-afk
|
||||
5. Gemini 请求优化,支持 Gemini 思考信息设置 @Raven95676
|
||||
6. 支持处理 MCP 服务器返回的图片等多模态信息 @Raven95676
|
||||
7. 插件市场支持基于 Star 和 更新时间排序 @Soulter
|
||||
8. 优化 QQ 下自动下载文件导致磁盘被占满的问题 @Soulter @anka-afk
|
||||
@@ -0,0 +1,5 @@
|
||||
# What's Changed
|
||||
|
||||
> Gewechat 已经停止维护,此版本提供了 `微信客服` 的接入方式,可以在直接微信内聊天。这是微信官方推出的接入方式,因此没有风控风险。详见 [AstrBot 接入企业微信](https://astrbot.app/deploy/platform/wecom.html)。此接入方式处于测试阶段,有问题请及时在 GitHub 上提交 Issue。
|
||||
|
||||
1. 支持接入微信客服。
|
||||
@@ -24,13 +24,10 @@ const emit = defineEmits([
|
||||
'install',
|
||||
'uninstall',
|
||||
'toggle-activation',
|
||||
'view-handlers'
|
||||
'view-handlers',
|
||||
'view-readme'
|
||||
]);
|
||||
|
||||
const open = (link: string | undefined) => {
|
||||
window.open(link, '_blank');
|
||||
};
|
||||
|
||||
const reveal = ref(false);
|
||||
|
||||
// 操作函数
|
||||
@@ -70,6 +67,10 @@ const toggleActivation = () => {
|
||||
const viewHandlers = () => {
|
||||
emit('view-handlers', props.extension);
|
||||
};
|
||||
|
||||
const viewReadme = () => {
|
||||
emit('view-readme', props.extension);
|
||||
};
|
||||
</script>
|
||||
|
||||
<template>
|
||||
@@ -80,7 +81,7 @@ const viewHandlers = () => {
|
||||
<div class="flex-grow-1">
|
||||
<div>{{ extension.author }} /</div>
|
||||
|
||||
<p class="text-h3 font-weight-black" :class="{ 'text-h4': $vuetify.display.xs }">
|
||||
<p class="text-h4 font-weight-black" :class="{ 'text-h4': $vuetify.display.xs }">
|
||||
{{ extension.name }}
|
||||
<v-tooltip location="top" v-if="extension?.has_update && !marketMode">
|
||||
<template v-slot:activator="{ props: tooltipProps }">
|
||||
@@ -128,7 +129,7 @@ const viewHandlers = () => {
|
||||
</v-card-text>
|
||||
|
||||
<v-card-actions style="padding: 0px; margin-top: auto;">
|
||||
<v-btn color="teal-accent-4" text="帮助" variant="text" @click="open(extension.repo)"></v-btn>
|
||||
<v-btn color="teal-accent-4" text="查看文档" variant="text" @click="viewReadme"></v-btn>
|
||||
<v-btn v-if="!marketMode" color="teal-accent-4" text="操作" variant="text" @click="reveal = true"></v-btn>
|
||||
<v-btn v-if="marketMode && !extension?.installed" color="teal-accent-4" text="安装" variant="text"
|
||||
@click="emit('install', extension)"></v-btn>
|
||||
|
||||
@@ -114,19 +114,6 @@ export default {
|
||||
justify-content: space-between;
|
||||
}
|
||||
|
||||
.item-status-indicator {
|
||||
position: absolute;
|
||||
top: 0;
|
||||
left: 0;
|
||||
width: 100%;
|
||||
height: 4px;
|
||||
background-color: #e0e0e0;
|
||||
}
|
||||
|
||||
.item-status-indicator.active {
|
||||
background-color: #4CAF50;
|
||||
}
|
||||
|
||||
.hover-elevation:hover {
|
||||
box-shadow: 0 6px 12px rgba(0, 0, 0, 0.1);
|
||||
transform: translateY(-2px);
|
||||
|
||||
@@ -0,0 +1,302 @@
|
||||
<script setup>
|
||||
import { ref, watch, onMounted } from 'vue';
|
||||
import axios from 'axios';
|
||||
import { marked } from 'marked';
|
||||
import hljs from 'highlight.js';
|
||||
import 'highlight.js/styles/github.css';
|
||||
|
||||
const props = defineProps({
|
||||
show: {
|
||||
type: Boolean,
|
||||
default: false
|
||||
},
|
||||
pluginName: {
|
||||
type: String,
|
||||
default: ''
|
||||
},
|
||||
repoUrl: {
|
||||
type: String,
|
||||
default: null
|
||||
}
|
||||
});
|
||||
|
||||
const emit = defineEmits(['update:show']);
|
||||
|
||||
const content = ref(null);
|
||||
const error = ref(null);
|
||||
const loading = ref(false);
|
||||
|
||||
// 监听show的变化,当显示对话框时加载内容
|
||||
watch(() => props.show, (newVal) => {
|
||||
if (newVal && props.pluginName) {
|
||||
fetchReadme();
|
||||
}
|
||||
});
|
||||
|
||||
// 监听pluginName的变化
|
||||
watch(() => props.pluginName, (newVal) => {
|
||||
if (props.show && newVal) {
|
||||
fetchReadme();
|
||||
}
|
||||
});
|
||||
|
||||
// 获取README内容
|
||||
async function fetchReadme() {
|
||||
if (!props.pluginName) return;
|
||||
|
||||
loading.value = true;
|
||||
content.value = null;
|
||||
error.value = null;
|
||||
|
||||
try {
|
||||
// 从本地文件获取README
|
||||
const res = await axios.get(`/api/plugin/readme?name=${props.pluginName}`);
|
||||
if (res.data.status === 'ok') {
|
||||
content.value = res.data.data.content;
|
||||
} else {
|
||||
error.value = res.data.message || '获取README失败';
|
||||
}
|
||||
} catch (err) {
|
||||
error.value = err.message || '获取README时发生错误';
|
||||
} finally {
|
||||
loading.value = false;
|
||||
}
|
||||
}
|
||||
|
||||
// 打开GitHub中的仓库
|
||||
function openRepoInNewTab() {
|
||||
if (props.repoUrl) {
|
||||
window.open(props.repoUrl, '_blank');
|
||||
}
|
||||
}
|
||||
|
||||
// 渲染Markdown内容
|
||||
function renderMarkdown(content) {
|
||||
if (!content) return '';
|
||||
|
||||
// 配置marked使用highlight.js进行语法高亮
|
||||
marked.setOptions({
|
||||
highlight: function(code, lang) {
|
||||
if (lang && hljs.getLanguage(lang)) {
|
||||
try {
|
||||
return hljs.highlight(code, { language: lang }).value;
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
}
|
||||
return hljs.highlightAuto(code).value;
|
||||
},
|
||||
gfm: true, // GitHub Flavored Markdown
|
||||
breaks: true, // Convert \n to <br>
|
||||
headerIds: true, // Add id attributes to headers
|
||||
mangle: false // Don't mangle email addresses
|
||||
});
|
||||
|
||||
return marked(content);
|
||||
}
|
||||
|
||||
// 刷新README内容
|
||||
function refreshReadme() {
|
||||
fetchReadme();
|
||||
}
|
||||
</script>
|
||||
|
||||
<template>
|
||||
<v-dialog v-model="_show" width="800" persistent>
|
||||
<v-card>
|
||||
<v-card-title class="d-flex justify-space-between align-center">
|
||||
<span class="text-h5">插件说明文档</span>
|
||||
<v-btn icon @click="$emit('update:show', false)">
|
||||
<v-icon>mdi-close</v-icon>
|
||||
</v-btn>
|
||||
</v-card-title>
|
||||
<v-divider></v-divider>
|
||||
<v-card-text style="height: 70vh; overflow-y: auto;">
|
||||
<div class="d-flex justify-space-between mb-4">
|
||||
<v-btn
|
||||
v-if="repoUrl"
|
||||
color="primary"
|
||||
prepend-icon="mdi-github"
|
||||
@click="openRepoInNewTab()"
|
||||
>
|
||||
在GitHub中查看仓库
|
||||
</v-btn>
|
||||
<v-btn
|
||||
color="secondary"
|
||||
prepend-icon="mdi-refresh"
|
||||
@click="refreshReadme()"
|
||||
>
|
||||
刷新文档
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<!-- 加载中 -->
|
||||
<div v-if="loading" class="d-flex flex-column align-center justify-center" style="height: 100%;">
|
||||
<v-progress-circular indeterminate color="primary" size="64" class="mb-4"></v-progress-circular>
|
||||
<p class="text-body-1 text-center">正在加载README文档...</p>
|
||||
</div>
|
||||
|
||||
<!-- 内容显示 -->
|
||||
<div v-else-if="content" class="markdown-body" v-html="renderMarkdown(content)"></div>
|
||||
|
||||
<!-- 错误提示 -->
|
||||
<div v-else-if="error" class="d-flex flex-column align-center justify-center" style="height: 100%;">
|
||||
<v-icon size="64" color="error" class="mb-4">mdi-alert-circle-outline</v-icon>
|
||||
<p class="text-body-1 text-center mb-4">{{ error }}</p>
|
||||
</div>
|
||||
|
||||
<!-- 无内容提示 -->
|
||||
<div v-else class="d-flex flex-column align-center justify-center" style="height: 100%;">
|
||||
<v-icon size="64" color="warning" class="mb-4">mdi-file-question-outline</v-icon>
|
||||
<p class="text-body-1 text-center mb-4">该插件未提供文档链接或GitHub仓库地址。<br>请查看插件市场或联系插件作者获取更多信息。</p>
|
||||
</div>
|
||||
</v-card-text>
|
||||
<v-divider></v-divider>
|
||||
<v-card-actions>
|
||||
<v-spacer></v-spacer>
|
||||
<v-btn color="primary" variant="tonal" @click="$emit('update:show', false)">
|
||||
关闭
|
||||
</v-btn>
|
||||
</v-card-actions>
|
||||
</v-card>
|
||||
</v-dialog>
|
||||
</template>
|
||||
|
||||
<style>
|
||||
.markdown-body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Helvetica, Arial, sans-serif;
|
||||
line-height: 1.6;
|
||||
padding: 8px 0;
|
||||
color: #24292e;
|
||||
}
|
||||
|
||||
.markdown-body h1,
|
||||
.markdown-body h2,
|
||||
.markdown-body h3,
|
||||
.markdown-body h4,
|
||||
.markdown-body h5,
|
||||
.markdown-body h6 {
|
||||
margin-top: 24px;
|
||||
margin-bottom: 16px;
|
||||
font-weight: 600;
|
||||
line-height: 1.25;
|
||||
}
|
||||
|
||||
.markdown-body h1 {
|
||||
font-size: 2em;
|
||||
border-bottom: 1px solid #eaecef;
|
||||
padding-bottom: 0.3em;
|
||||
}
|
||||
|
||||
.markdown-body h2 {
|
||||
font-size: 1.5em;
|
||||
border-bottom: 1px solid #eaecef;
|
||||
padding-bottom: 0.3em;
|
||||
}
|
||||
|
||||
.markdown-body p {
|
||||
margin-top: 0;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.markdown-body code {
|
||||
padding: 0.2em 0.4em;
|
||||
margin: 0;
|
||||
background-color: rgba(27, 31, 35, 0.05);
|
||||
border-radius: 3px;
|
||||
font-family: "SFMono-Regular", Consolas, "Liberation Mono", Menlo, monospace;
|
||||
font-size: 85%;
|
||||
}
|
||||
|
||||
.markdown-body pre {
|
||||
padding: 16px;
|
||||
overflow: auto;
|
||||
font-size: 85%;
|
||||
line-height: 1.45;
|
||||
background-color: #f6f8fa;
|
||||
border-radius: 3px;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.markdown-body pre code {
|
||||
background-color: transparent;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
.markdown-body ul,
|
||||
.markdown-body ol {
|
||||
padding-left: 2em;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.markdown-body img {
|
||||
max-width: 100%;
|
||||
margin: 8px 0;
|
||||
box-sizing: border-box;
|
||||
background-color: #fff;
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.markdown-body blockquote {
|
||||
padding: 0 1em;
|
||||
color: #6a737d;
|
||||
border-left: 0.25em solid #dfe2e5;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.markdown-body a {
|
||||
color: #0366d6;
|
||||
text-decoration: none;
|
||||
}
|
||||
|
||||
.markdown-body a:hover {
|
||||
text-decoration: underline;
|
||||
}
|
||||
|
||||
.markdown-body table {
|
||||
border-spacing: 0;
|
||||
border-collapse: collapse;
|
||||
width: 100%;
|
||||
overflow: auto;
|
||||
margin-bottom: 16px;
|
||||
}
|
||||
|
||||
.markdown-body table th,
|
||||
.markdown-body table td {
|
||||
padding: 6px 13px;
|
||||
border: 1px solid #dfe2e5;
|
||||
}
|
||||
|
||||
.markdown-body table tr {
|
||||
background-color: #fff;
|
||||
border-top: 1px solid #c6cbd1;
|
||||
}
|
||||
|
||||
.markdown-body table tr:nth-child(2n) {
|
||||
background-color: #f6f8fa;
|
||||
}
|
||||
|
||||
.markdown-body hr {
|
||||
height: 0.25em;
|
||||
padding: 0;
|
||||
margin: 24px 0;
|
||||
background-color: #e1e4e8;
|
||||
border: 0;
|
||||
}
|
||||
</style>
|
||||
|
||||
<script>
|
||||
export default {
|
||||
name: 'ReadmeDialog',
|
||||
computed: {
|
||||
_show: {
|
||||
get() {
|
||||
return this.show;
|
||||
},
|
||||
set(value) {
|
||||
this.$emit('update:show', value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
</script>
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user