Merge remote-tracking branch 'origin/master'

This commit is contained in:
Raven95676
2025-05-02 10:51:01 +08:00
13 changed files with 1136 additions and 648 deletions
+30 -3
View File
@@ -7,7 +7,7 @@ on:
name: Auto Release
jobs:
build:
build-and-publish-to-github-release:
runs-on: ubuntu-latest
permissions:
contents: write
@@ -28,8 +28,35 @@ jobs:
run: |
echo "changelog=changelogs/${{github.ref_name}}.md" >> "$GITHUB_ENV"
- name: Create Release
- name: Create GitHub Release
uses: ncipollo/release-action@v1
with:
bodyFile: ${{ env.changelog }}
artifacts: "dashboard/dist.zip"
artifacts: "dashboard/dist.zip"
build-and-publish-to-pypi:
# 构建并发布到 PyPI
runs-on: ubuntu-latest
needs: build
steps:
- name: Checkout repository
uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: '3.10'
- name: Install uv
run: |
python -m pip install uv
- name: Build package
run: |
uv build
- name: Publish to PyPI
env:
UV_PUBLISH_TOKEN: ${{ secrets.PYPI_TOKEN }}
run: |
uv publish
+1
View File
@@ -0,0 +1 @@
3.10
+238
View File
@@ -0,0 +1,238 @@
import asyncio
import os
import shutil
import sys
import click
from pathlib import Path
from astrbot.core.config.default import VERSION
logo_tmpl = r"""
___ _______.___________..______ .______ ______ .___________.
/ \ / | || _ \ | _ \ / __ \ | |
/ ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----`
/ /_\ \ \ \ | | | / | _ < | | | | | |
/ _____ \ .----) | | | | |\ \----.| |_) | | `--' | | |
/__/ \__\ |_______/ |__| | _| `._____||______/ \______/ |__|
"""
# utils
def _get_astrbot_root(path: str | None) -> Path:
"""获取astrbot根目录"""
match path:
case None:
match ASTRBOT_ROOT := os.getenv("ASTRBOT_ROOT"):
case None:
astrbot_root = Path.cwd() / "data"
case _:
astrbot_root = Path(ASTRBOT_ROOT).resolve()
case str():
astrbot_root = Path(path).resolve()
dot_astrbot = astrbot_root / ".astrbot"
if not dot_astrbot.exists():
if click.confirm(
f"运行前必须先执行初始化!请检查当前目录是否正确,回车以继续: {astrbot_root}",
default=True,
abort=True,
):
dot_astrbot.touch()
astrbot_root.mkdir(parents=True, exist_ok=True)
click.echo(f"Created {dot_astrbot}")
return astrbot_root
# 通过类型来验证先后,必须先获取 Path 对象才能对该目录进行检查
def _check_astrbot_root(astrbot_root: Path) -> None:
"""验证"""
dot_astrbot = astrbot_root / ".astrbot"
if not astrbot_root.exists():
click.echo(f"AstrBot root directory does not exist: {astrbot_root}")
click.echo("Please run 'astrbot init' to create the directory.")
sys.exit(1)
else:
click.echo(f"AstrBot root directory exists: {astrbot_root}")
if not dot_astrbot.exists():
click.echo(
"如果你确认这是 Astrbot root directory, 你需要在当前目录下创建一个 .astrbot 文件标记该目录为 AstrBot 的数据目录。"
)
if click.confirm(
f"请检查当前目录是否正确,确认正确请回车: {astrbot_root}",
default=True,
abort=True,
):
dot_astrbot.touch()
click.echo(f"Created {dot_astrbot}")
else:
click.echo(f"Welcome back! AstrBot root directory: {astrbot_root}")
async def _check_dashboard(astrbot_root: Path) -> None:
"""检查是否安装了dashboard"""
try:
from ..core.utils.io import get_dashboard_version, download_dashboard
except ImportError:
from astrbot.core.utils.io import get_dashboard_version, download_dashboard
try:
# 添加 create=True 参数以确保在初始化时不会抛出异常
dashboard_version = await get_dashboard_version()
match dashboard_version:
case None:
click.echo("未安装管理面板")
if click.confirm(
"是否安装管理面板?",
default=True,
abort=True,
):
click.echo("正在安装管理面板...")
# 确保使用 create=True 参数
await download_dashboard(
path="data/dashboard.zip", extract_path=str(astrbot_root)
)
click.echo("管理面板安装完成")
case str():
if dashboard_version == f"v{VERSION}":
click.echo("无需更新")
else:
try:
version = dashboard_version.split("v")[1]
click.echo(f"管理面板版本: {version}")
# 确保使用 create=True 参数
await download_dashboard(
path="data/dashboard.zip", extract_path=str(astrbot_root)
)
except Exception as e:
click.echo(f"下载管理面板失败: {e}")
return
except FileNotFoundError:
click.echo("初始化管理面板目录...")
# 初始化模式下,下载到指定位置
try:
await download_dashboard(
path=str(astrbot_root / "dashboard.zip"), extract_path=str(astrbot_root)
)
click.echo("管理面板初始化完成")
except Exception as e:
click.echo(f"下载管理面板失败: {e}")
return
@click.group(name="astrbot")
def cli() -> None:
"""The AstrBot CLI"""
click.echo(logo_tmpl)
click.echo("Welcome to AstrBot CLI!")
click.echo(f"AstrBot version: {VERSION}")
# region init
@cli.command()
@click.option("--path", "-p", help="AstrBot 数据目录")
@click.option("--force", "-f", is_flag=True, help="强制初始化")
def init(path: str | None, force: bool) -> None:
"""Initialize AstrBot"""
click.echo("Initializing AstrBot...")
astrbot_root = _get_astrbot_root(path)
if force:
if click.confirm(
"强制初始化会删除当前目录下的所有文件,是否继续?",
default=False,
abort=True,
):
click.echo("正在删除当前目录下的所有文件...")
shutil.rmtree(astrbot_root, ignore_errors=True)
_check_astrbot_root(astrbot_root)
click.echo(f"AstrBot root directory: {astrbot_root}")
if not astrbot_root.exists():
# 创建目录
astrbot_root.mkdir(parents=True, exist_ok=True)
click.echo(f"Created directory: {astrbot_root}")
else:
click.echo(f"Directory already exists: {astrbot_root}")
config_path: Path = astrbot_root / "config"
plugins_path: Path = astrbot_root / "plugins"
temp_path: Path = astrbot_root / "temp"
config_path.mkdir(parents=True, exist_ok=True)
plugins_path.mkdir(parents=True, exist_ok=True)
temp_path.mkdir(parents=True, exist_ok=True)
click.echo(f"Created directories: {config_path}, {plugins_path}, {temp_path}")
# 检查是否安装了dashboard
asyncio.run(_check_dashboard(astrbot_root))
# region run
@cli.command()
@click.option("--path", "-p", help="AstrBot 数据目录")
def run(path: str | None = None) -> None:
"""Run AstrBot"""
# 解析为绝对路径
try:
from ..core.log import LogBroker
from ..core import db_helper
from ..core.initial_loader import InitialLoader
except ImportError:
from astrbot.core.log import LogBroker
from astrbot.core import db_helper
from astrbot.core.initial_loader import InitialLoader
astrbot_root = _get_astrbot_root(path)
_check_astrbot_root(astrbot_root)
asyncio.run(_check_dashboard(astrbot_root))
log_broker = LogBroker()
db = db_helper
core_lifecycle = InitialLoader(db, log_broker)
try:
asyncio.run(core_lifecycle.start())
except KeyboardInterrupt:
click.echo("接收到退出信号,正在关闭 AstrBot...")
except Exception as e:
click.echo(f"运行时出现错误: {e}")
# region Basic
@cli.command(name="version")
def version() -> None:
"""Show the version of AstrBot"""
click.echo(f"AstrBot version: {VERSION}")
@cli.command()
@click.argument("command_name", required=False, type=str)
def help(command_name: str | None) -> None:
"""Show help information for commands
If COMMAND_NAME is provided, show detailed help for that command.
Otherwise, show general help information.
"""
ctx = click.get_current_context()
if command_name:
# 查找指定命令
command = cli.get_command(ctx, command_name)
if command:
# 显示特定命令的帮助信息
click.echo(command.get_help(ctx))
else:
click.echo(f"Unknown command: {command_name}")
sys.exit(1)
else:
# 显示通用帮助信息
click.echo(cli.get_help(ctx))
if __name__ == "__main__":
cli()
+83 -5
View File
@@ -26,10 +26,12 @@ import base64
import json
import os
import uuid
import asyncio
import typing as T
from enum import Enum
from pydantic.v1 import BaseModel
from astrbot.core.utils.io import download_image_by_url, file_to_base64
from astrbot.core import logger
from astrbot.core.utils.io import download_image_by_url, file_to_base64, download_file
class ComponentType(Enum):
@@ -552,15 +554,91 @@ class Unknown(BaseMessageComponent):
class File(BaseMessageComponent):
"""
目前此消息段只适配了 Napcat。
文件消息段
"""
type: ComponentType = "File"
name: T.Optional[str] = "" # 名字
file: T.Optional[str] = "" # url本地路径
_file: T.Optional[str] = "" # 本地路径
url: T.Optional[str] = "" # url
_downloaded: bool = False # 是否已经下载
def __init__(self, name: str, file: str):
super().__init__(name=name, file=file)
def __init__(self, name: str = "", file: str = "", url: str = ""):
super().__init__(name=name, _file=file, url=url)
@property
def file(self) -> str:
"""
获取文件路径,如果文件不存在但有URL,则同步下载文件
Returns:
str: 文件路径
"""
if self._file and os.path.exists(self._file):
return self._file
if self.url and not self._downloaded:
try:
loop = asyncio.get_event_loop()
if loop.is_running():
logger.warning(
"不可以在异步上下文中同步等待下载! 请使用 await get_file() 代替"
)
return ""
else:
# 等待下载完成
loop.run_until_complete(self._download_file())
if self._file and os.path.exists(self._file):
return self._file
except Exception as e:
logger.error(f"文件下载失败: {e}")
return ""
@file.setter
def file(self, value: str):
"""
向前兼容, 设置file属性, 传入的参数可能是文件路径或URL
Args:
value (str): 文件路径或URL
"""
if value.startswith("http://") or value.startswith("https://"):
self.url = value
else:
self._file = value
async def get_file(self) -> str:
"""
异步获取文件
To 插件开发者: 请注意在使用后清理下载的文件, 以免占用过多空间
Returns:
str: 文件路径
"""
if self._file and os.path.exists(self._file):
return self._file
if self.url:
await self._download_file()
return self._file
return ""
async def _download_file(self):
"""下载文件"""
if self._downloaded:
return
os.makedirs("data/download", exist_ok=True)
filename = self.name or f"{uuid.uuid4().hex}"
file_path = f"data/download/{filename}"
await download_file(self.url, file_path)
self._file = file_path
self._downloaded = True
class WechatEmoji(BaseMessageComponent):
@@ -1,4 +1,3 @@
import os
import time
import asyncio
import logging
@@ -21,7 +20,6 @@ from .aiocqhttp_message_event import AiocqhttpMessageEvent
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from aiocqhttp.exceptions import ActionFailed
from astrbot.core.utils.io import download_file
@register_platform_adapter(
@@ -167,7 +165,9 @@ class AiocqhttpAdapter(Platform):
if "sub_type" in event:
if event["sub_type"] == "poke" and "target_id" in event:
abm.message.append(Poke(qq=str(event["target_id"]), type="poke")) # noqa: F405
abm.message.append(
Poke(qq=str(event["target_id"]), type="poke")
) # noqa: F405
return abm
@@ -227,32 +227,30 @@ class AiocqhttpAdapter(Platform):
if m["data"].get("url") and m["data"].get("url").startswith("http"):
# Lagrange
logger.info("guessing lagrange")
file_name = m["data"].get("file_name", "file")
path = os.path.join("data/temp", file_name)
await download_file(m["data"]["url"], path)
m["data"] = {"file": path, "name": file_name}
a = ComponentTypes[t](**m["data"]) # noqa: F405
abm.message.append(a)
abm.message.append(File(name=file_name, url=m["data"]["url"]))
else:
try:
# Napcat, LLBot
ret = await self.bot.call_action(
action="get_file",
file_id=event.message[0]["data"]["file_id"],
)
if not ret.get("file", None):
raise ValueError(f"无法解析文件响应: {ret}")
if not os.path.exists(ret["file"]):
raise FileNotFoundError(
f"文件不存在或者权限问题: {ret['file']}。如果您使用 Docker 部署了 AstrBot 或者消息协议端(Napcat等),请先映射路径。如果路径在 /root 目录下,请用 sudo 打开 AstrBot"
# Napcat
ret = None
if abm.type == MessageType.GROUP_MESSAGE:
ret = await self.bot.call_action(
action="get_group_file_url",
file_id=event.message[0]["data"]["file_id"],
group_id=event.group_id,
)
elif abm.type == MessageType.FRIEND_MESSAGE:
ret = await self.bot.call_action(
action="get_private_file_url",
file_id=event.message[0]["data"]["file_id"],
)
if ret and "url" in ret:
file_url = ret["url"] # https
a = File(name="", url=file_url)
abm.message.append(a)
else:
logger.error(f"获取文件失败: {ret}")
m["data"] = {"file": ret["file"], "name": ret["file_name"]}
a = ComponentTypes[t](**m["data"]) # noqa: F405
abm.message.append(a)
except ActionFailed as e:
logger.error(f"获取文件失败: {e},此消息段将被忽略。")
except BaseException as e:
@@ -3,6 +3,7 @@ import base64
import datetime
import os
import re
import uuid
import threading
import aiohttp
@@ -63,7 +64,7 @@ class SimpleGewechatClient:
"/astrbot-gewechat/callback", view_func=self._callback, methods=["POST"]
)
self.server.add_url_rule(
"/astrbot-gewechat/file/<file_id>",
"/astrbot-gewechat/file/<file_token>",
view_func=self._handle_file,
methods=["GET"],
)
@@ -81,6 +82,11 @@ class SimpleGewechatClient:
self.shutdown_event = asyncio.Event()
self.staged_files = {}
"""存储了允许外部访问的文件列表。auth_token: file_path。通过 register_file 方法注册。"""
self.lock = asyncio.Lock()
async def get_token_id(self):
"""获取 Gewechat Token。"""
async with aiohttp.ClientSession() as session:
@@ -310,9 +316,33 @@ class SimpleGewechatClient:
return quart.jsonify({"r": "AstrBot ACK"})
async def _handle_file(self, file_id):
file_path = f"data/temp/{file_id}"
return await quart.send_file(file_path)
async def _register_file(self, file_path: str) -> str:
"""向 AstrBot 回调服务器 注册一个允许外部访问的文件。
Args:
file_path (str): 文件路径。
Returns:
str: 返回一个 auth_token,文件路径为 file_path。通过 /astrbot-gewechat/file/auth_token 得到文件。
"""
async with self.lock:
if not os.path.exists(file_path):
raise Exception(f"文件不存在: {file_path}")
file_token = str(uuid.uuid4())
self.staged_files[file_token] = file_path
return file_token
async def _handle_file(self, file_token):
async with self.lock:
if file_token not in self.staged_files:
logger.warning(f"请求的文件 {file_token} 不存在。")
return quart.abort(404)
if not os.path.exists(self.staged_files[file_token]):
logger.warning(f"请求的文件 {self.staged_files[file_token]} 不存在。")
return quart.abort(404)
file_path = self.staged_files[file_token]
self.staged_files.pop(file_token, None)
return await quart.send_file(file_path)
async def _set_callback_url(self):
logger.info("设置回调,请等待...")
@@ -462,17 +492,18 @@ class SimpleGewechatClient:
"此次登录需要安全验证码,请在管理面板聊天页输入 /gewe_code 验证码 来验证,如 /gewe_code 123456"
)
else:
status = json_blob["data"]["status"]
nickname = json_blob["data"].get("nickName", "")
if status == 1:
logger.info(f"等待确认...{nickname}")
elif status == 2:
logger.info(f"绿泡泡平台登录成功: {nickname}")
break
elif status == 0:
logger.info("等待扫码...")
else:
logger.warning(f"未知状态: {status}")
if "status" in json_blob["data"]:
status = json_blob["data"]["status"]
nickname = json_blob["data"].get("nickName", "")
if status == 1:
logger.info(f"等待确认...{nickname}")
elif status == 2:
logger.info(f"绿泡泡平台登录成功: {nickname}")
break
elif status == 0:
logger.info("等待扫码...")
else:
logger.warning(f"未知状态: {status}")
await asyncio.sleep(5)
if appid:
@@ -83,15 +83,9 @@ class GewechatPlatformEvent(AstrMessageEvent):
elif isinstance(comp, Image):
img_path = await comp.convert_to_file_path()
# 检查 record_path 是否在 data/temp 目录中
temp_directory = os.path.abspath("data/temp")
if os.path.commonpath([temp_directory, img_path]) != temp_directory:
with open(img_path, "rb") as f:
img_path = save_temp_img(f.read())
file_id = os.path.basename(img_path)
img_url = f"{client.file_server_url}/{file_id}"
# 为了安全,向 AstrBot 回调服务注册可被 gewechat 访问的文件,并获得文件 token
token = await client._register_file(img_path)
img_url = f"{client.file_server_url}/{token}"
logger.debug(f"gewe callback img url: {img_url}")
await client.post_image(to_wxid, img_url)
elif isinstance(comp, Video):
@@ -110,20 +104,29 @@ class GewechatPlatformEvent(AstrMessageEvent):
video_url = comp.file
# 根据 url 下载视频
video_filename = f"{uuid.uuid4()}.mp4"
video_path = f"data/temp/{video_filename}"
await download_file(video_url, video_path)
if video_url.startswith("http"):
video_filename = f"{uuid.uuid4()}.mp4"
video_path = f"data/temp/{video_filename}"
await download_file(video_url, video_path)
else:
video_path = video_url
video_token = await client._register_file(video_path)
video_callback_url = f"{client.file_server_url}/{video_token}"
# 获取视频第一帧
thumb_path = f"data/temp/{uuid.uuid4()}.jpg"
thumb_path = f"data/temp/gewechat_video_thumb_{uuid.uuid4()}.jpg"
video_path = video_path.replace(" ", "\\ ")
try:
ff = FFmpeg()
command = f'-i "{video_path}" -ss 0 -vframes 1 "{thumb_path}"'
command = f"-i {video_path} -ss 0 -vframes 1 {thumb_path}"
ff.options(command)
thumb_file_id = os.path.basename(thumb_path)
thumb_url = f"{client.file_server_url}/{thumb_file_id}"
thumb_token = await client._register_file(thumb_path)
thumb_url = f"{client.file_server_url}/{thumb_token}"
except Exception as e:
logger.error(f"获取视频第一帧失败: {e}")
# 获取视频时长
try:
from pyffmpeg import FFprobe
@@ -138,15 +141,12 @@ class GewechatPlatformEvent(AstrMessageEvent):
logger.error(f"获取时长失败: {e}")
video_duration = 10
file_id = os.path.basename(video_path)
video_url = f"{client.file_server_url}/{file_id}"
# 发送视频
await client.post_video(
to_wxid, video_url, thumb_url, video_duration
to_wxid, video_callback_url, thumb_url, video_duration
)
# 删除临时视频和缩略图文件
if os.path.exists(video_path):
os.remove(video_path)
# 删除临时缩略图文件
if os.path.exists(thumb_path):
os.remove(thumb_path)
elif isinstance(comp, Record):
@@ -163,8 +163,8 @@ class GewechatPlatformEvent(AstrMessageEvent):
logger.info("Silk 语音文件格式转换至: " + record_path)
if duration == 0:
duration = get_wav_duration(record_path)
file_id = os.path.basename(silk_path)
record_url = f"{client.file_server_url}/{file_id}"
token = await client._register_file(silk_path)
record_url = f"{client.file_server_url}/{token}"
logger.debug(f"gewe callback record url: {record_url}")
await client.post_voice(to_wxid, record_url, duration * 1000)
elif isinstance(comp, File):
@@ -177,10 +177,10 @@ class GewechatPlatformEvent(AstrMessageEvent):
else:
file_path = file_path
file_id = os.path.basename(file_path)
file_url = f"{client.file_server_url}/{file_id}"
token = await client._register_file(file_path)
file_url = f"{client.file_server_url}/{token}"
logger.debug(f"gewe callback file url: {file_url}")
await client.post_file(to_wxid, file_url, file_id)
await client.post_file(to_wxid, file_url, file_name)
elif isinstance(comp, Emoji):
await client.post_emoji(to_wxid, comp.md5, comp.md5_len, comp.cdnurl)
elif isinstance(comp, At):
@@ -58,8 +58,12 @@ class TelegramPlatformAdapter(Platform):
self.base_url = base_url
self.enable_command_register = self.config.get("telegram_command_register", True)
self.enable_command_refresh = self.config.get("telegram_command_auto_refresh", True)
self.enable_command_register = self.config.get(
"telegram_command_register", True
)
self.enable_command_refresh = self.config.get(
"telegram_command_auto_refresh", True
)
self.last_command_hash = None
self.application = (
@@ -123,7 +127,9 @@ class TelegramPlatformAdapter(Platform):
commands = self.collect_commands()
if commands:
current_hash = hash(tuple((cmd.command, cmd.description) for cmd in commands))
current_hash = hash(
tuple((cmd.command, cmd.description) for cmd in commands)
)
if current_hash == self.last_command_hash:
return
self.last_command_hash = current_hash
+5 -5
View File
@@ -209,20 +209,20 @@ async def get_dashboard_version():
return None
async def download_dashboard():
async def download_dashboard(path: str = "data/dashboard.zip", extract_path: str = "data"):
"""下载管理面板文件"""
dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip"
try:
await download_file(
dashboard_release_url, "data/dashboard.zip", show_progress=True
dashboard_release_url, path, show_progress=True
)
except BaseException as _:
dashboard_release_url = (
"https://github.com/Soulter/AstrBot/releases/latest/download/dist.zip"
)
await download_file(
dashboard_release_url, "data/dashboard.zip", show_progress=True
dashboard_release_url, path, show_progress=True
)
print("解压管理面板文件中...")
with zipfile.ZipFile("data/dashboard.zip", "r") as z:
z.extractall("data")
with zipfile.ZipFile(path, "r") as z:
z.extractall(extract_path)
+3 -1
View File
@@ -145,7 +145,9 @@ class PluginRoute(Route):
if handler.event_type == EventType.AdapterMessageEvent:
# 处理平台适配器消息事件
has_admin = False
for filter in (
for (
filter
) in (
handler.event_filters
): # 正常handler就只有 1~2 个 filter,因此这里时间复杂度不会太高
if isinstance(filter, CommandFilter):
+10 -1
View File
@@ -560,7 +560,16 @@ export default {
// 过滤后的市场服务器
filteredMarketplaceServers() {
return this.marketplaceServers;
if (!this.marketplaceSearch.trim()) {
return this.marketplaceServers;
}
const searchTerm = this.marketplaceSearch.toLowerCase();
return this.marketplaceServers.filter(server =>
server.name.toLowerCase().includes(searchTerm) ||
(server.name_h && server.name_h.toLowerCase().includes(searchTerm)) ||
(server.description && server.description.toLowerCase().includes(searchTerm))
);
},
},
+7
View File
@@ -40,6 +40,13 @@ dependencies = [
"wechatpy>=1.8.18",
]
[project.scripts]
astrbot = "astrbot.cli.__main__:cli"
[build-system]
requires = ["hatchling", "uv-dynamic-versioning"]
build-backend = "hatchling.build"
[tool.ruff]
exclude = [
"astrbot/core/utils/t2i/local_strategy.py",
Generated
+655 -564
View File
File diff suppressed because it is too large Load Diff