Files
AstrBot/astrbot/dashboard/server.py
T

520 lines
19 KiB
Python

import asyncio
import hashlib
import logging
import os
import platform
import socket
from collections.abc import Callable
from datetime import datetime
from ipaddress import IPv4Address, IPv6Address, ip_address
from pathlib import Path
from typing import Protocol
import jwt
import psutil
import werkzeug.exceptions
from flask.json.provider import DefaultJSONProvider
from hypercorn.asyncio import serve
from hypercorn.config import Config as HyperConfig
from quart import Quart, g, jsonify, request
from quart.logging import default_handler
from quart_cors import cors
from astrbot.core import logger
from astrbot.core.config.default import VERSION
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.db import BaseDatabase
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.datetime_utils import to_utc_isoformat
from astrbot.core.utils.io import get_local_ip_addresses
from .routes import *
from .routes.api_key import ALL_OPEN_API_SCOPES
# Static assets shipped inside the wheel (built during `hatch build`).
_BUNDLED_DIST = Path(__file__).parent / "dist"
class _AddrWithPort(Protocol):
port: int
APP: Quart
def _parse_env_bool(value: str | None, default: bool) -> bool:
if value is None:
return default
return value.strip().lower() in {"1", "true", "yes", "on"}
class AstrBotJSONProvider(DefaultJSONProvider):
def default(self, obj):
if isinstance(obj, datetime):
return to_utc_isoformat(obj)
return super().default(obj)
class AstrBotDashboard:
"""AstrBot Web Dashboard"""
ALLOWED_ENDPOINT_PREFIXES = (
"/api/auth/login",
"/api/file",
"/api/platform/webhook",
"/api/stat/start-time",
"/api/backup/download",
)
def __init__(
self,
core_lifecycle: AstrBotCoreLifecycle,
db: BaseDatabase,
shutdown_event: asyncio.Event,
webui_dir: str | None = None,
) -> None:
self.core_lifecycle = core_lifecycle
self.config = core_lifecycle.astrbot_config
self.db = db
self.shutdown_event = shutdown_event
self.enable_webui = self._check_webui_enabled()
self._init_paths(webui_dir)
self._init_app()
self.context = RouteContext(self.config, self.app)
self._init_routes(db)
self._init_plugin_route_index()
self._init_jwt_secret()
def _check_webui_enabled(self) -> bool:
cfg = self.config.get("dashboard", {})
_env = os.environ.get("DASHBOARD_ENABLE")
if _env is not None:
return _env.lower() in ("true", "1", "yes")
return cfg.get("enable", True)
def _init_paths(self, webui_dir: str | None):
# Path priority:
# 1. Explicit webui_dir argument
# 2. data/dist/ (user-installed / manually updated dashboard)
# 3. astrbot/dashboard/dist/ (bundled with the wheel)
if webui_dir and os.path.exists(webui_dir):
self.data_path = os.path.abspath(webui_dir)
else:
user_dist = os.path.join(get_astrbot_data_path(), "dist")
if os.path.exists(user_dist):
self.data_path = os.path.abspath(user_dist)
elif _BUNDLED_DIST.exists():
self.data_path = str(_BUNDLED_DIST.absolute())
logger.info("Using bundled dashboard dist: %s", self.data_path)
else:
self.data_path = os.path.abspath(user_dist)
if self.enable_webui and not (Path(self.data_path) / "index.html").exists():
raise RuntimeError(
f"Dashboard static assets not found: index.html is missing in {self.data_path}. "
"Please run the WebUI build step."
)
def _init_app(self):
"""初始化 Quart 应用"""
global APP
self.app = Quart(
"AstrBotDashboard",
static_folder=self.data_path,
static_url_path="/",
)
APP = self.app
self.app.json_provider_class = DefaultJSONProvider
self.app.config["MAX_CONTENT_LENGTH"] = 128 * 1024 * 1024 # 128MB
self.app.json = AstrBotJSONProvider(self.app)
self.app.json.sort_keys = False
# 配置 CORS
self.app = cors(
self.app,
allow_origin="*",
allow_headers=["Authorization", "Content-Type", "X-API-Key"],
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
)
@self.app.route("/")
async def index():
if not self.enable_webui:
return "WebUI is disabled."
try:
return await self.app.send_static_file("index.html")
except werkzeug.exceptions.NotFound:
logger.error(f"Dashboard index.html not found in {self.data_path}")
return "Dashboard files not found.", 404
@self.app.errorhandler(404)
async def not_found(e):
if not self.enable_webui:
return "WebUI is disabled."
if request.path.startswith("/api/"):
return jsonify(Response().error("Not Found").to_json()), 404
try:
return await self.app.send_static_file("index.html")
except werkzeug.exceptions.NotFound:
return "Dashboard files not found.", 404
@self.app.before_serving
async def startup():
pass
@self.app.after_serving
async def shutdown():
pass
self.app.before_request(self.auth_middleware)
logging.getLogger(self.app.name).removeHandler(default_handler)
def _init_routes(self, db: BaseDatabase):
UpdateRoute(
self.context, self.core_lifecycle.astrbot_updator, self.core_lifecycle
)
StatRoute(self.context, db, self.core_lifecycle)
PluginRoute(
self.context, self.core_lifecycle, self.core_lifecycle.plugin_manager
)
self.command_route = CommandRoute(self.context)
self.cr = ConfigRoute(self.context, self.core_lifecycle)
self.lr = LogRoute(self.context, self.core_lifecycle.log_broker)
self.sfr = StaticFileRoute(self.context)
self.ar = AuthRoute(self.context)
self.api_key_route = ApiKeyRoute(self.context, db)
self.chat_route = ChatRoute(self.context, db, self.core_lifecycle)
self.open_api_route = OpenApiRoute(
self.context,
db,
self.core_lifecycle,
self.chat_route,
)
self.chatui_project_route = ChatUIProjectRoute(self.context, db)
self.tools_root = ToolsRoute(self.context, self.core_lifecycle)
self.subagent_route = SubAgentRoute(self.context, self.core_lifecycle)
self.skills_route = SkillsRoute(self.context, self.core_lifecycle)
self.conversation_route = ConversationRoute(
self.context, db, self.core_lifecycle
)
self.file_route = FileRoute(self.context)
self.session_management_route = SessionManagementRoute(
self.context,
db,
self.core_lifecycle,
)
self.persona_route = PersonaRoute(self.context, db, self.core_lifecycle)
self.cron_route = CronRoute(self.context, self.core_lifecycle)
self.t2i_route = T2iRoute(self.context, self.core_lifecycle)
self.kb_route = KnowledgeBaseRoute(self.context, self.core_lifecycle)
self.platform_route = PlatformRoute(self.context, self.core_lifecycle)
self.backup_route = BackupRoute(self.context, db, self.core_lifecycle)
self.live_chat_route = LiveChatRoute(self.context, db, self.core_lifecycle)
self.app.add_url_rule(
"/api/plug/<path:subpath>",
view_func=self.srv_plug_route,
methods=["GET", "POST"],
)
def _init_plugin_route_index(self):
"""将插件路由索引,避免 O(n) 查找"""
self._plugin_route_map: dict[tuple[str, str], Callable] = {}
for (
route,
handler,
methods,
_,
) in self.core_lifecycle.star_context.registered_web_apis:
for method in methods:
self._plugin_route_map[(route, method)] = handler
def _init_jwt_secret(self):
dashboard_cfg = self.config.setdefault("dashboard", {})
if not dashboard_cfg.get("jwt_secret"):
dashboard_cfg["jwt_secret"] = os.urandom(32).hex()
self.config.save_config()
logger.info("Initialized random JWT secret for dashboard.")
self._jwt_secret = dashboard_cfg["jwt_secret"]
async def auth_middleware(self):
# 放行CORS预检请求
if request.method == "OPTIONS":
return None
if not request.path.startswith("/api"):
return None
if request.path.startswith("/api/v1"):
raw_key = self._extract_raw_api_key()
if not raw_key:
r = jsonify(Response().error("Missing API key").__dict__)
r.status_code = 401
return r
key_hash = hashlib.pbkdf2_hmac(
"sha256",
raw_key.encode("utf-8"),
b"astrbot_api_key",
100_000,
).hex()
api_key = await self.db.get_active_api_key_by_hash(key_hash)
if not api_key:
r = jsonify(Response().error("Invalid API key").__dict__)
r.status_code = 401
return r
if isinstance(api_key.scopes, list):
scopes = api_key.scopes
else:
scopes = list(ALL_OPEN_API_SCOPES)
required_scope = self._get_required_open_api_scope(request.path)
if required_scope and "*" not in scopes and required_scope not in scopes:
r = jsonify(Response().error("Insufficient API key scope").__dict__)
r.status_code = 403
return r
g.api_key_id = api_key.key_id
g.api_key_scopes = scopes
g.username = f"api_key:{api_key.key_id}"
await self.db.touch_api_key(api_key.key_id)
return None
if any(request.path.startswith(p) for p in self.ALLOWED_ENDPOINT_PREFIXES):
return None
token = request.headers.get("Authorization")
if not token:
return self._unauthorized("未授权")
try:
payload = jwt.decode(
token.removeprefix("Bearer "),
self._jwt_secret,
algorithms=["HS256"],
options={"require": ["username"]},
)
g.username = payload["username"]
except jwt.ExpiredSignatureError:
return self._unauthorized("Token 过期")
except jwt.PyJWTError:
return self._unauthorized("Token 无效")
@staticmethod
def _unauthorized(msg: str):
r = jsonify(Response().error(msg).to_json())
r.status_code = 401
return r
async def srv_plug_route(self, subpath: str, *args, **kwargs):
handler = self._plugin_route_map.get((f"/{subpath}", request.method))
if not handler:
return jsonify(Response().error("未找到该路由").to_json())
try:
return await handler(*args, **kwargs)
except Exception:
logger.exception("插件 Web API 执行异常")
return jsonify(Response().error("插件 Web API 执行异常").to_json())
@staticmethod
def _extract_raw_api_key() -> str | None:
if key := request.args.get("api_key"):
return key.strip()
if key := request.args.get("key"):
return key.strip()
if key := request.headers.get("X-API-Key"):
return key.strip()
auth_header = request.headers.get("Authorization", "").strip()
if auth_header.startswith("Bearer "):
return auth_header.removeprefix("Bearer ").strip()
if auth_header.startswith("ApiKey "):
return auth_header.removeprefix("ApiKey ").strip()
return None
@staticmethod
def _get_required_open_api_scope(path: str) -> str | None:
scope_map = {
"/api/v1/chat": "chat",
"/api/v1/chat/ws": "chat",
"/api/v1/chat/sessions": "chat",
"/api/v1/configs": "config",
"/api/v1/file": "file",
"/api/v1/im/message": "im",
"/api/v1/im/bots": "im",
}
return scope_map.get(path)
def check_port_in_use(self, host: str, port: int) -> bool:
"""跨平台检测端口是否被占用"""
family = socket.AF_INET6 if ":" in host else socket.AF_INET
try:
with socket.socket(family, socket.SOCK_STREAM) as s:
s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind((host, port))
return False
except OSError:
return True
def get_process_using_port(self, port: int) -> str:
"""获取占用端口的进程信息"""
try:
for proc in psutil.process_iter(["pid", "name"]):
try:
connections = proc.net_connections()
for conn in connections:
if conn.laddr.port == port:
return f"PID: {proc.info['pid']}, Name: {proc.info['name']}"
except (
psutil.NoSuchProcess,
psutil.AccessDenied,
psutil.ZombieProcess,
):
pass
except Exception as e:
return f"获取进程信息失败: {e!s}"
return "未知进程"
async def run(self) -> None:
"""Run dashboard server (blocking)"""
if not self.enable_webui:
logger.warning(
"WebUI 已禁用 (dashboard.enable=false or DASHBOARD_ENABLE=false)"
)
dashboard_config = self.config.get("dashboard", {})
host = os.environ.get("DASHBOARD_HOST") or dashboard_config.get(
"host", "0.0.0.0"
)
port = int(
os.environ.get("DASHBOARD_PORT") or dashboard_config.get("port", 6185)
)
ssl_config = dashboard_config.get("ssl", {})
ssl_enable = _parse_env_bool(
os.environ.get("DASHBOARD_SSL_ENABLE"),
ssl_config.get("enable", False),
)
scheme = "https" if ssl_enable else "http"
display_host = f"[{host}]" if ":" in host else host
if self.enable_webui:
logger.info(
"正在启动 WebUI + API, 监听地址: %s://%s:%s",
scheme,
display_host,
port,
)
else:
logger.info(
"正在启动 API Server (WebUI 已分离), 监听地址: %s://%s:%s",
scheme,
display_host,
port,
)
check_hosts = {host}
if host not in ("127.0.0.1", "localhost", "::1"):
check_hosts.add("127.0.0.1")
for check_host in check_hosts:
if self.check_port_in_use(check_host, port):
info = self.get_process_using_port(port)
raise RuntimeError(f"端口 {port} 已被占用\n{info}")
if self.enable_webui:
self._print_access_urls(host, port, scheme)
# 配置 Hypercorn
config = HyperConfig()
binds: list[str] = [self._build_bind(host, port)]
if host == "::" and platform.system() in ("Windows", "Darwin"):
binds.append(self._build_bind("0.0.0.0", port))
config.bind = binds
if ssl_enable:
cert_file = (
os.environ.get("DASHBOARD_SSL_CERT")
or os.environ.get("ASTRBOT_DASHBOARD_SSL_CERT")
or ssl_config.get("cert_file", "")
)
key_file = (
os.environ.get("DASHBOARD_SSL_KEY")
or os.environ.get("ASTRBOT_DASHBOARD_SSL_KEY")
or ssl_config.get("key_file", "")
)
ca_certs = (
os.environ.get("DASHBOARD_SSL_CA_CERTS")
or os.environ.get("ASTRBOT_DASHBOARD_SSL_CA_CERTS")
or ssl_config.get("ca_certs", "")
)
cert_path = Path(cert_file).expanduser()
key_path = Path(key_file).expanduser()
if not cert_file or not key_file:
raise ValueError(
"dashboard.ssl.enable 为 true 时,必须配置 cert_file 和 key_file。",
)
if not cert_path.is_file():
raise ValueError(f"SSL 证书文件不存在: {cert_path}")
if not key_path.is_file():
raise ValueError(f"SSL 私钥文件不存在: {key_path}")
config.certfile = str(cert_path.resolve())
config.keyfile = str(key_path.resolve())
if ca_certs:
ca_path = Path(ca_certs).expanduser()
if not ca_path.is_file():
raise ValueError(f"SSL CA 证书文件不存在: {ca_path}")
config.ca_certs = str(ca_path.resolve())
# 根据配置决定是否禁用访问日志
disable_access_log = dashboard_config.get("disable_access_log", True)
if disable_access_log:
config.accesslog = None
else:
config.accesslog = "-"
config.access_log_format = "%(h)s %(r)s %(s)s %(b)s %(D)s"
await serve(self.app, config, shutdown_trigger=self.shutdown_trigger)
@staticmethod
def _build_bind(host: str, port: int) -> str:
try:
ip: IPv4Address | IPv6Address = ip_address(host)
return f"[{ip}]:{port}" if ip.version == 6 else f"{ip}:{port}"
except ValueError:
return f"{host}:{port}"
def _print_access_urls(self, host: str, port: int, scheme: str = "http") -> None:
local_ips: list[IPv4Address | IPv6Address] = get_local_ip_addresses()
parts = [f"\n ✨✨✨\n AstrBot v{VERSION} WebUI 已启动\n\n"]
parts.append(f" ➜ 本地: {scheme}://localhost:{port}\n")
if host in ("::", "0.0.0.0"):
for ip in local_ips:
if ip.is_loopback:
continue
if ip.version == 6:
display_url = f"{scheme}://[{ip}]:{port}"
else:
display_url = f"{scheme}://{ip}:{port}"
parts.append(f" ➜ 网络: {display_url}\n")
parts.append(" ➜ 默认用户名和密码: astrbot\n ✨✨✨\n")
if not local_ips:
parts.append(
"可在 data/cmd_config.json 中配置 dashboard.host 以便远程访问。\n"
)
logger.info("".join(parts))
async def shutdown_trigger(self):
await self.shutdown_event.wait()