diff --git a/.python-version b/.python-version index fdcfcfdfc..e4fba2183 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.12 \ No newline at end of file +3.12 diff --git a/astrbot/cli/commands/cmd_init.py b/astrbot/cli/commands/cmd_init.py index e7e047cca..5d1c8e93d 100644 --- a/astrbot/cli/commands/cmd_init.py +++ b/astrbot/cli/commands/cmd_init.py @@ -30,8 +30,13 @@ async def initialize_astrbot(astrbot_root: Path) -> None: for name, path in paths.items(): path.mkdir(parents=True, exist_ok=True) click.echo(f"{'Created' if not path.exists() else 'Directory exists'}: {path}") - - await check_dashboard(astrbot_root / "data") + if click.confirm( + "是否需要集成式 WebUI?(个人电脑推荐,服务器不推荐)", + default=True, + ): + await check_dashboard(astrbot_root) + else: + click.echo("你可以使用在线面版(v4.14.4+),填写后端地址的方式来控制。") @click.command() diff --git a/astrbot/cli/commands/cmd_run.py b/astrbot/cli/commands/cmd_run.py index de09e5852..98acdcd19 100644 --- a/astrbot/cli/commands/cmd_run.py +++ b/astrbot/cli/commands/cmd_run.py @@ -15,7 +15,8 @@ async def run_astrbot(astrbot_root: Path) -> None: from astrbot.core import LogBroker, LogManager, db_helper, logger from astrbot.core.initial_loader import InitialLoader - await check_dashboard(astrbot_root / "data") + if os.environ.get("DASHBOARD_ENABLE") == "True": + await check_dashboard(astrbot_root) log_broker = LogBroker() LogManager.set_queue_handler(logger, log_broker) @@ -27,9 +28,16 @@ async def run_astrbot(astrbot_root: Path) -> None: @click.option("--reload", "-r", is_flag=True, help="Auto-reload plugins") +@click.option("--host", "-H", help="AstrBot Dashboard Host", required=False, type=str) @click.option("--port", "-p", help="AstrBot Dashboard port", required=False, type=str) +@click.option( + "--backend-only", + is_flag=True, + default=False, + help="Disable WebUI, run backend only", +) @click.command() -def run(reload: bool, port: str) -> None: +def run(reload: bool, host: str, port: str, backend_only: bool) -> None: """Run AstrBot""" try: os.environ["ASTRBOT_CLI"] = "1" @@ -43,8 +51,11 @@ def run(reload: bool, port: str) -> None: os.environ["ASTRBOT_ROOT"] = str(astrbot_root) sys.path.insert(0, str(astrbot_root)) - if port: + if port is not None: os.environ["DASHBOARD_PORT"] = port + if host is not None: + os.environ["DASHBOARD_HOST"] = host + os.environ["DASHBOARD_ENABLE"] = str(not backend_only) if reload: click.echo("Plugin auto-reload enabled") diff --git a/astrbot/cli/utils/basic.py b/astrbot/cli/utils/basic.py index 16b03218e..b90fd6e11 100644 --- a/astrbot/cli/utils/basic.py +++ b/astrbot/cli/utils/basic.py @@ -47,7 +47,7 @@ async def check_dashboard(astrbot_root: Path) -> None: click.echo("Installing dashboard...") await download_dashboard( path="data/dashboard.zip", - extract_path=str(astrbot_root), + extract_path=str(astrbot_root / "data"), version=f"v{VERSION}", latest=False, ) @@ -62,7 +62,7 @@ async def check_dashboard(astrbot_root: Path) -> None: click.echo(f"Dashboard version: {version}") await download_dashboard( path="data/dashboard.zip", - extract_path=str(astrbot_root), + extract_path=str(astrbot_root / "data"), version=f"v{VERSION}", latest=False, ) @@ -73,8 +73,8 @@ async def check_dashboard(astrbot_root: Path) -> None: click.echo("Initializing dashboard directory...") try: await download_dashboard( - path=str(astrbot_root / "dashboard.zip"), - extract_path=str(astrbot_root), + path=str(astrbot_root / "data" / "dashboard.zip"), + extract_path=str(astrbot_root / "data"), version=f"v{VERSION}", latest=False, ) diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index b56592674..f9ad47dbf 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -1,3 +1,4 @@ +import asyncio import base64 import logging import os @@ -7,6 +8,7 @@ import ssl import time import uuid import zipfile +from ipaddress import IPv4Address, IPv6Address, ip_address from pathlib import Path import aiohttp @@ -206,18 +208,53 @@ def file_to_base64(file_path: str) -> str: return "base64://" + base64_str -def get_local_ip_addresses(): +def get_local_ip_addresses() -> list[IPv4Address | IPv6Address]: net_interfaces = psutil.net_if_addrs() - network_ips = [] + network_ips: list[IPv4Address | IPv6Address] = [] - for interface, addrs in net_interfaces.items(): + for _, addrs in net_interfaces.items(): for addr in addrs: - if addr.family == socket.AF_INET: # 使用 socket.AF_INET 代替 psutil.AF_INET - network_ips.append(addr.address) + if addr.family == socket.AF_INET: + network_ips.append(ip_address(addr.address)) + elif addr.family == socket.AF_INET6: + # 过滤掉 IPv6 的 link-local 地址(fe80:...) + ip = ip_address(addr.address.split("%")[0]) # 处理带 zone index 的情况 + if not ip.is_link_local: + network_ips.append(ip) return network_ips +async def get_public_ip_address() -> list[IPv4Address | IPv6Address]: + urls = [ + "https://api64.ipify.org", + "https://ident.me", + "https://ifconfig.me", + "https://icanhazip.com", + ] + found_ips: dict[int, IPv4Address | IPv6Address] = {} + + async def fetch(session: aiohttp.ClientSession, url: str): + try: + async with session.get(url, timeout=3) as resp: + if resp.status == 200: + raw_ip = (await resp.text()).strip() + ip = ip_address(raw_ip) + if ip.version not in found_ips: + found_ips[ip.version] = ip + except Exception as e: + # Ignore errors from individual services so that a single failing + # endpoint does not prevent discovering the public IP from others. + logger.debug("Failed to fetch public IP from %s: %s", url, e) + + async with aiohttp.ClientSession() as session: + tasks = [fetch(session, url) for url in urls] + await asyncio.gather(*tasks) + + # 返回找到的所有 IP 对象列表 + return list(found_ips.values()) + + async def get_dashboard_version(): # First check user data directory (manually updated / downloaded dashboard). dist_dir = os.path.join(get_astrbot_data_path(), "dist") diff --git a/astrbot/dashboard/routes/__init__.py b/astrbot/dashboard/routes/__init__.py index fbbd0c7a0..7e6e79146 100644 --- a/astrbot/dashboard/routes/__init__.py +++ b/astrbot/dashboard/routes/__init__.py @@ -9,16 +9,19 @@ from .conversation import ConversationRoute from .cron import CronRoute from .file import FileRoute from .knowledge_base import KnowledgeBaseRoute +from .live_chat import LiveChatRoute from .log import LogRoute from .open_api import OpenApiRoute from .persona import PersonaRoute from .platform import PlatformRoute from .plugin import PluginRoute +from .route import Response, RouteContext from .session_management import SessionManagementRoute from .skills import SkillsRoute from .stat import StatRoute from .static_file import StaticFileRoute from .subagent import SubAgentRoute +from .t2i import T2iRoute from .tools import ToolsRoute from .update import UpdateRoute @@ -46,4 +49,8 @@ __all__ = [ "ToolsRoute", "SkillsRoute", "UpdateRoute", + "T2iRoute", + "LiveChatRoute", + "Response", + "RouteContext", ] diff --git a/astrbot/dashboard/routes/route.py b/astrbot/dashboard/routes/route.py index 53c623443..4fdc37971 100644 --- a/astrbot/dashboard/routes/route.py +++ b/astrbot/dashboard/routes/route.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import asdict, dataclass from quart import Quart @@ -57,3 +57,7 @@ class Response: self.data = data self.message = message return self + + def to_json(self): + # Return a plain dict so callers can safely wrap with jsonify() + return asdict(self) diff --git a/astrbot/dashboard/routes/static_file.py b/astrbot/dashboard/routes/static_file.py index e056b6c5a..15fec95d1 100644 --- a/astrbot/dashboard/routes/static_file.py +++ b/astrbot/dashboard/routes/static_file.py @@ -5,6 +5,9 @@ class StaticFileRoute(Route): def __init__(self, context: RouteContext) -> None: super().__init__(context) + if "index" in self.app.view_functions: + return + index_ = [ "/", "/auth/login", diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index a4742aa67..7fa50fa1b 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -2,10 +2,13 @@ import asyncio import hashlib import logging import os +import platform import socket +from collections.abc import Callable from datetime import datetime +from ipaddress import IPv4Address, IPv6Address, ip_address from pathlib import Path -from typing import Protocol, cast +from typing import Protocol import jwt import psutil @@ -14,6 +17,7 @@ from hypercorn.asyncio import serve from hypercorn.config import Config as HyperConfig from quart import Quart, g, jsonify, request from quart.logging import default_handler +from quart_cors import cors from astrbot.core import logger from astrbot.core.config.default import VERSION @@ -25,13 +29,6 @@ from astrbot.core.utils.io import get_local_ip_addresses from .routes import * from .routes.api_key import ALL_OPEN_API_SCOPES -from .routes.backup import BackupRoute -from .routes.live_chat import LiveChatRoute -from .routes.platform import PlatformRoute -from .routes.route import Response, RouteContext -from .routes.session_management import SessionManagementRoute -from .routes.subagent import SubAgentRoute -from .routes.t2i import T2iRoute # Static assets shipped inside the wheel (built during `hatch build`). _BUNDLED_DIST = Path(__file__).parent / "dist" @@ -58,6 +55,16 @@ class AstrBotJSONProvider(DefaultJSONProvider): class AstrBotDashboard: + """AstrBot Web Dashboard""" + + ALLOWED_ENDPOINT_PREFIXES = ( + "/api/auth/login", + "/api/file", + "/api/platform/webhook", + "/api/stat/start-time", + "/api/backup/download", + ) + def __init__( self, core_lifecycle: AstrBotCoreLifecycle, @@ -68,7 +75,26 @@ class AstrBotDashboard: self.core_lifecycle = core_lifecycle self.config = core_lifecycle.astrbot_config self.db = db + self.shutdown_event = shutdown_event + self.enable_webui = self._check_webui_enabled() + + self._init_paths(webui_dir) + self._init_app() + self.context = RouteContext(self.config, self.app) + + self._init_routes(db) + self._init_plugin_route_index() + self._init_jwt_secret() + + def _check_webui_enabled(self) -> bool: + cfg = self.config.get("dashboard", {}) + _env = os.environ.get("DASHBOARD_ENABLE") + if _env is not None: + return _env.lower() in ("true", "1", "yes") + return cfg.get("enable", True) + + def _init_paths(self, webui_dir: str | None): # Path priority: # 1. Explicit webui_dir argument # 2. data/dist/ (user-installed / manually updated dashboard) @@ -83,62 +109,96 @@ class AstrBotDashboard: self.data_path = str(_BUNDLED_DIST) logger.info("Using bundled dashboard dist: %s", self.data_path) else: - # Fall back to expected user path (will fail gracefully later) self.data_path = os.path.abspath(user_dist) - self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/") - APP = self.app # noqa - self.app.config["MAX_CONTENT_LENGTH"] = ( - 128 * 1024 * 1024 - ) # 将 Flask 允许的最大上传文件体大小设置为 128 MB + def _init_app(self): + """初始化 Quart 应用""" + global APP + self.app = Quart( + "AstrBotDashboard", + static_folder=self.data_path, + static_url_path="/", + ) + APP = self.app + self.app.json_provider_class = DefaultJSONProvider + self.app.config["MAX_CONTENT_LENGTH"] = 128 * 1024 * 1024 # 128MB self.app.json = AstrBotJSONProvider(self.app) self.app.json.sort_keys = False - self.app.before_request(self.auth_middleware) - # token 用于验证请求 - logging.getLogger(self.app.name).removeHandler(default_handler) - self.context = RouteContext(self.config, self.app) - self.ur = UpdateRoute( - self.context, - core_lifecycle.astrbot_updator, - core_lifecycle, + + # 配置 CORS + self.app = cors( + self.app, + allow_origin="*", + allow_headers=["Authorization", "Content-Type", "X-API-Key"], + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], ) - self.sr = StatRoute(self.context, db, core_lifecycle) - self.pr = PluginRoute( - self.context, - core_lifecycle, - core_lifecycle.plugin_manager, + + @self.app.route("/") + async def index(): + if not self.enable_webui: + return "WebUI is disabled." + return await self.app.send_static_file("index.html") + + @self.app.errorhandler(404) + async def not_found(e): + if not self.enable_webui: + return "WebUI is disabled." + if request.path.startswith("/api/"): + return jsonify(Response().error("Not Found").to_json()), 404 + return await self.app.send_static_file("index.html") + + @self.app.before_serving + async def startup(): + pass + + @self.app.after_serving + async def shutdown(): + pass + + self.app.before_request(self.auth_middleware) + logging.getLogger(self.app.name).removeHandler(default_handler) + + def _init_routes(self, db: BaseDatabase): + UpdateRoute( + self.context, self.core_lifecycle.astrbot_updator, self.core_lifecycle + ) + StatRoute(self.context, db, self.core_lifecycle) + PluginRoute( + self.context, self.core_lifecycle, self.core_lifecycle.plugin_manager ) self.command_route = CommandRoute(self.context) - self.cr = ConfigRoute(self.context, core_lifecycle) - self.lr = LogRoute(self.context, core_lifecycle.log_broker) + self.cr = ConfigRoute(self.context, self.core_lifecycle) + self.lr = LogRoute(self.context, self.core_lifecycle.log_broker) self.sfr = StaticFileRoute(self.context) self.ar = AuthRoute(self.context) self.api_key_route = ApiKeyRoute(self.context, db) - self.chat_route = ChatRoute(self.context, db, core_lifecycle) + self.chat_route = ChatRoute(self.context, db, self.core_lifecycle) self.open_api_route = OpenApiRoute( self.context, db, - core_lifecycle, + self.core_lifecycle, self.chat_route, ) self.chatui_project_route = ChatUIProjectRoute(self.context, db) - self.tools_root = ToolsRoute(self.context, core_lifecycle) - self.subagent_route = SubAgentRoute(self.context, core_lifecycle) - self.skills_route = SkillsRoute(self.context, core_lifecycle) - self.conversation_route = ConversationRoute(self.context, db, core_lifecycle) + self.tools_root = ToolsRoute(self.context, self.core_lifecycle) + self.subagent_route = SubAgentRoute(self.context, self.core_lifecycle) + self.skills_route = SkillsRoute(self.context, self.core_lifecycle) + self.conversation_route = ConversationRoute( + self.context, db, self.core_lifecycle + ) self.file_route = FileRoute(self.context) self.session_management_route = SessionManagementRoute( self.context, db, - core_lifecycle, + self.core_lifecycle, ) - self.persona_route = PersonaRoute(self.context, db, core_lifecycle) - self.cron_route = CronRoute(self.context, core_lifecycle) - self.t2i_route = T2iRoute(self.context, core_lifecycle) - self.kb_route = KnowledgeBaseRoute(self.context, core_lifecycle) - self.platform_route = PlatformRoute(self.context, core_lifecycle) - self.backup_route = BackupRoute(self.context, db, core_lifecycle) - self.live_chat_route = LiveChatRoute(self.context, db, core_lifecycle) + self.persona_route = PersonaRoute(self.context, db, self.core_lifecycle) + self.cron_route = CronRoute(self.context, self.core_lifecycle) + self.t2i_route = T2iRoute(self.context, self.core_lifecycle) + self.kb_route = KnowledgeBaseRoute(self.context, self.core_lifecycle) + self.platform_route = PlatformRoute(self.context, self.core_lifecycle) + self.backup_route = BackupRoute(self.context, db, self.core_lifecycle) + self.live_chat_route = LiveChatRoute(self.context, db, self.core_lifecycle) self.app.add_url_rule( "/api/plug/", @@ -146,20 +206,31 @@ class AstrBotDashboard: methods=["GET", "POST"], ) - self.shutdown_event = shutdown_event + def _init_plugin_route_index(self): + """将插件路由索引,避免 O(n) 查找""" + self._plugin_route_map: dict[tuple[str, str], Callable] = {} - self._init_jwt_secret() + for ( + route, + handler, + methods, + _, + ) in self.core_lifecycle.star_context.registered_web_apis: + for method in methods: + self._plugin_route_map[(route, method)] = handler - async def srv_plug_route(self, subpath, *args, **kwargs): - """插件路由""" - registered_web_apis = self.core_lifecycle.star_context.registered_web_apis - for api in registered_web_apis: - route, view_handler, methods, _ = api - if route == f"/{subpath}" and request.method in methods: - return await view_handler(*args, **kwargs) - return jsonify(Response().error("未找到该路由").__dict__) + def _init_jwt_secret(self): + dashboard_cfg = self.config.setdefault("dashboard", {}) + if not dashboard_cfg.get("jwt_secret"): + dashboard_cfg["jwt_secret"] = os.urandom(32).hex() + self.config.save_config() + logger.info("Initialized random JWT secret for dashboard.") + self._jwt_secret = dashboard_cfg["jwt_secret"] async def auth_middleware(self): + # 放行CORS预检请求 + if request.method == "OPTIONS": + return None if not request.path.startswith("/api"): return None if request.path.startswith("/api/v1"): @@ -196,33 +267,42 @@ class AstrBotDashboard: await self.db.touch_api_key(api_key.key_id) return None - allowed_endpoints = [ - "/api/auth/login", - "/api/file", - "/api/platform/webhook", - "/api/stat/start-time", - "/api/backup/download", # 备份下载使用 URL 参数传递 token - ] - if any(request.path.startswith(prefix) for prefix in allowed_endpoints): + if any(request.path.startswith(p) for p in self.ALLOWED_ENDPOINT_PREFIXES): return None - # 声明 JWT + token = request.headers.get("Authorization") if not token: - r = jsonify(Response().error("未授权").__dict__) - r.status_code = 401 - return r - token = token.removeprefix("Bearer ") + return self._unauthorized("未授权") + try: - payload = jwt.decode(token, self._jwt_secret, algorithms=["HS256"]) + payload = jwt.decode( + token.removeprefix("Bearer "), + self._jwt_secret, + algorithms=["HS256"], + options={"require": ["username"]}, + ) g.username = payload["username"] except jwt.ExpiredSignatureError: - r = jsonify(Response().error("Token 过期").__dict__) - r.status_code = 401 - return r - except jwt.InvalidTokenError: - r = jsonify(Response().error("Token 无效").__dict__) - r.status_code = 401 - return r + return self._unauthorized("Token 过期") + except jwt.PyJWTError: + return self._unauthorized("Token 无效") + + @staticmethod + def _unauthorized(msg: str): + r = jsonify(Response().error(msg).to_json()) + r.status_code = 401 + return r + + async def srv_plug_route(self, subpath: str, *args, **kwargs): + handler = self._plugin_route_map.get((f"/{subpath}", request.method)) + if not handler: + return jsonify(Response().error("未找到该路由").to_json()) + + try: + return await handler(*args, **kwargs) + except Exception: + logger.exception("插件 Web API 执行异常") + return jsonify(Response().error("插件 Web API 执行异常").to_json()) @staticmethod def _extract_raw_api_key() -> str | None: @@ -252,126 +332,92 @@ class AstrBotDashboard: } return scope_map.get(path) - def check_port_in_use(self, port: int) -> bool: + def check_port_in_use(self, host: str, port: int) -> bool: """跨平台检测端口是否被占用""" + family = socket.AF_INET6 if ":" in host else socket.AF_INET try: - # 创建 IPv4 TCP Socket - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - # 设置超时时间 - sock.settimeout(2) - result = sock.connect_ex(("127.0.0.1", port)) - sock.close() - # result 为 0 表示端口被占用 - return result == 0 - except Exception as e: - logger.warning(f"检查端口 {port} 时发生错误: {e!s}") - # 如果出现异常,保守起见认为端口可能被占用 + with socket.socket(family, socket.SOCK_STREAM) as s: + s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + s.bind((host, port)) + return False + except OSError: return True def get_process_using_port(self, port: int) -> str: - """获取占用端口的进程详细信息""" + """获取占用端口的进程信息""" try: - for conn in psutil.net_connections(kind="inet"): - if cast(_AddrWithPort, conn.laddr).port == port: - try: - process = psutil.Process(conn.pid) - # 获取详细信息 - proc_info = [ - f"进程名: {process.name()}", - f"PID: {process.pid}", - f"执行路径: {process.exe()}", - f"工作目录: {process.cwd()}", - f"启动命令: {' '.join(process.cmdline())}", - ] - return "\n ".join(proc_info) - except (psutil.NoSuchProcess, psutil.AccessDenied) as e: - return f"无法获取进程详细信息(可能需要管理员权限): {e!s}" - return "未找到占用进程" + for proc in psutil.process_iter(["pid", "name"]): + try: + connections = proc.net_connections() + for conn in connections: + if conn.laddr.port == port: + return f"PID: {proc.info['pid']}, Name: {proc.info['name']}" + except ( + psutil.NoSuchProcess, + psutil.AccessDenied, + psutil.ZombieProcess, + ): + pass except Exception as e: return f"获取进程信息失败: {e!s}" + return "未知进程" - def _init_jwt_secret(self) -> None: - if not self.config.get("dashboard", {}).get("jwt_secret", None): - # 如果没有设置 JWT 密钥,则生成一个新的密钥 - jwt_secret = os.urandom(32).hex() - self.config["dashboard"]["jwt_secret"] = jwt_secret - self.config.save_config() - logger.info("Initialized random JWT secret for dashboard.") - self._jwt_secret = self.config["dashboard"]["jwt_secret"] + async def run(self) -> None: + """Run dashboard server (blocking)""" + if not self.enable_webui: + logger.warning( + "WebUI 已禁用 (dashboard.enable=false or DASHBOARD_ENABLE=false)" + ) - def run(self): - ip_addr = [] - dashboard_config = self.core_lifecycle.astrbot_config.get("dashboard", {}) - port = ( - os.environ.get("DASHBOARD_PORT") - or os.environ.get("ASTRBOT_DASHBOARD_PORT") - or dashboard_config.get("port", 6185) + dashboard_config = self.config.get("dashboard", {}) + host = os.environ.get("DASHBOARD_HOST") or dashboard_config.get( + "host", "0.0.0.0" ) - host = ( - os.environ.get("DASHBOARD_HOST") - or os.environ.get("ASTRBOT_DASHBOARD_HOST") - or dashboard_config.get("host", "0.0.0.0") + port = int( + os.environ.get("DASHBOARD_PORT") or dashboard_config.get("port", 6185) ) - enable = dashboard_config.get("enable", True) ssl_config = dashboard_config.get("ssl", {}) - if not isinstance(ssl_config, dict): - ssl_config = {} ssl_enable = _parse_env_bool( - os.environ.get("DASHBOARD_SSL_ENABLE") - or os.environ.get("ASTRBOT_DASHBOARD_SSL_ENABLE"), - bool(ssl_config.get("enable", False)), + os.environ.get("DASHBOARD_SSL_ENABLE"), + ssl_config.get("enable", False), ) + scheme = "https" if ssl_enable else "http" + display_host = f"[{host}]" if ":" in host else host - if not enable: - logger.info("WebUI 已被禁用") - return None - - logger.info(f"正在启动 WebUI, 监听地址: {scheme}://{host}:{port}") - if host == "0.0.0.0": + if self.enable_webui: logger.info( - "提示: WebUI 将监听所有网络接口,请注意安全。(可在 data/cmd_config.json 中配置 dashboard.host 以修改 host)", + "正在启动 WebUI + API, 监听地址: %s://%s:%s", + scheme, + display_host, + port, + ) + else: + logger.info( + "正在启动 API Server (WebUI 已分离), 监听地址: %s://%s:%s", + scheme, + display_host, + port, ) - if host not in ["localhost", "127.0.0.1"]: - try: - ip_addr = get_local_ip_addresses() - except Exception as _: - pass - if isinstance(port, str): - port = int(port) + check_hosts = {host} + if host not in ("127.0.0.1", "localhost", "::1"): + check_hosts.add("127.0.0.1") + for check_host in check_hosts: + if self.check_port_in_use(check_host, port): + info = self.get_process_using_port(port) + raise RuntimeError(f"端口 {port} 已被占用\n{info}") - if self.check_port_in_use(port): - process_info = self.get_process_using_port(port) - logger.error( - f"错误:端口 {port} 已被占用\n" - f"占用信息: \n {process_info}\n" - f"请确保:\n" - f"1. 没有其他 AstrBot 实例正在运行\n" - f"2. 端口 {port} 没有被其他程序占用\n" - f"3. 如需使用其他端口,请修改配置文件", - ) - - raise Exception(f"端口 {port} 已被占用") - - parts = [f"\n ✨✨✨\n AstrBot v{VERSION} WebUI 已启动,可访问\n\n"] - parts.append(f" ➜ 本地: {scheme}://localhost:{port}\n") - for ip in ip_addr: - parts.append(f" ➜ 网络: {scheme}://{ip}:{port}\n") - parts.append(" ➜ 默认用户名和密码: astrbot\n ✨✨✨\n") - display = "".join(parts) - - if not ip_addr: - display += ( - "可在 data/cmd_config.json 中配置 dashboard.host 以便远程访问。\n" - ) - - logger.info(display) + if self.enable_webui: + self._print_access_urls(host, port, scheme) # 配置 Hypercorn config = HyperConfig() - config.bind = [f"{host}:{port}"] + binds: list[str] = [self._build_bind(host, port)] + if host == "::" and platform.system() in ("Windows", "Darwin"): + binds.append(self._build_bind("0.0.0.0", port)) + config.bind = binds + if ssl_enable: cert_file = ( os.environ.get("DASHBOARD_SSL_CERT") @@ -414,12 +460,46 @@ class AstrBotDashboard: if disable_access_log: config.accesslog = None else: - # 启用访问日志,使用简洁格式 config.accesslog = "-" config.access_log_format = "%(h)s %(r)s %(s)s %(b)s %(D)s" - return serve(self.app, config, shutdown_trigger=self.shutdown_trigger) + await serve(self.app, config, shutdown_trigger=self.shutdown_trigger) - async def shutdown_trigger(self) -> None: + @staticmethod + def _build_bind(host: str, port: int) -> str: + try: + ip: IPv4Address | IPv6Address = ip_address(host) + return f"[{ip}]:{port}" if ip.version == 6 else f"{ip}:{port}" + except ValueError: + return f"{host}:{port}" + + def _print_access_urls(self, host: str, port: int, scheme: str = "http") -> None: + local_ips: list[IPv4Address | IPv6Address] = get_local_ip_addresses() + + parts = [f"\n ✨✨✨\n AstrBot v{VERSION} WebUI 已启动\n\n"] + + parts.append(f" ➜ 本地: {scheme}://localhost:{port}\n") + + if host in ("::", "0.0.0.0"): + for ip in local_ips: + if ip.is_loopback: + continue + + if ip.version == 6: + display_url = f"{scheme}://[{ip}]:{port}" + else: + display_url = f"{scheme}://{ip}:{port}" + + parts.append(f" ➜ 网络: {display_url}\n") + + parts.append(" ➜ 默认用户名和密码: astrbot\n ✨✨✨\n") + + if not local_ips: + parts.append( + "可在 data/cmd_config.json 中配置 dashboard.host 以便远程访问。\n" + ) + + logger.info("".join(parts)) + + async def shutdown_trigger(self): await self.shutdown_event.wait() - logger.info("AstrBot WebUI 已经被优雅地关闭") diff --git a/dashboard/.gitignore b/dashboard/.gitignore index 6e03962af..f17c69129 100644 --- a/dashboard/.gitignore +++ b/dashboard/.gitignore @@ -1,3 +1,5 @@ node_modules/ .DS_Store -dist/ \ No newline at end of file +dist/ +bun.lock +pmpm-lock.yaml diff --git a/dashboard/env.d.ts b/dashboard/env.d.ts index b4b350830..a90bd47be 100644 --- a/dashboard/env.d.ts +++ b/dashboard/env.d.ts @@ -7,3 +7,9 @@ interface ImportMetaEnv { interface ImportMeta { readonly env: ImportMetaEnv; } + +declare module "*.vue" { + import type { DefineComponent } from "vue"; + const component: DefineComponent<{}, {}, any>; + export default component; +} diff --git a/dashboard/package.json b/dashboard/package.json index 56f1d8731..cfd0bd727 100644 --- a/dashboard/package.json +++ b/dashboard/package.json @@ -64,7 +64,7 @@ "sass": "1.66.1", "sass-loader": "13.3.2", "typescript": "5.1.6", - "vite": "6.4.1", + "vite": "5.4.1", "vue-cli-plugin-vuetify": "2.5.8", "vue-tsc": "1.8.8", "vuetify-loader": "^2.0.0-alpha.9" diff --git a/dashboard/public/config.json b/dashboard/public/config.json new file mode 100644 index 000000000..0d7e84a8a --- /dev/null +++ b/dashboard/public/config.json @@ -0,0 +1,13 @@ +{ + "apiBaseUrl": "", + "presets": [ + { + "name": "Default (Auto)", + "url": "" + }, + { + "name": "Localhost", + "url": "http://localhost:6185" + } + ] +} diff --git a/dashboard/src/components/chat/Chat.vue b/dashboard/src/components/chat/Chat.vue index 51b2dbd20..3bb98f136 100644 --- a/dashboard/src/components/chat/Chat.vue +++ b/dashboard/src/components/chat/Chat.vue @@ -11,6 +11,7 @@ :currSessionId="currSessionId" :selectedProjectId="selectedProjectId" :transportMode="transportMode" + :sendShortcut="sendShortcut" :isDark="isDark" :chatboxMode="chatboxMode" :isMobile="isMobile" @@ -29,6 +30,7 @@ @editProject="showEditProjectDialog" @deleteProject="handleDeleteProject" @updateTransportMode="setTransportMode" + @updateSendShortcut="setSendShortcut" /> @@ -79,6 +81,7 @@ :session-id="currSessionId || null" :current-session="getCurrentSession" :replyTo="replyTo" + :send-shortcut="sendShortcut" @send="handleSendMessage" @stop="handleStopMessage" @toggleStreaming="toggleStreaming" @@ -110,6 +113,7 @@ :session-id="currSessionId || null" :current-session="getCurrentSession" :replyTo="replyTo" + :send-shortcut="sendShortcut" @send="handleSendMessage" @stop="handleStopMessage" @toggleStreaming="toggleStreaming" @@ -140,6 +144,7 @@ :session-id="currSessionId || null" :current-session="getCurrentSession" :replyTo="replyTo" + :send-shortcut="sendShortcut" @send="handleSendMessage" @stop="handleStopMessage" @toggleStreaming="toggleStreaming" @@ -226,6 +231,8 @@ import { useToast } from '@/utils/toast'; interface Props { chatboxMode?: boolean; } +type SendShortcut = 'enter' | 'shift_enter'; +const SEND_SHORTCUT_STORAGE_KEY = 'chat_send_shortcut'; const props = withDefaults(defineProps(), { chatboxMode: false @@ -334,6 +341,12 @@ interface ReplyInfo { const replyTo = ref(null); const isDark = computed(() => useCustomizerStore().uiTheme === 'PurpleThemeDark'); +const sendShortcut = ref('shift_enter'); + +function setSendShortcut(mode: SendShortcut) { + sendShortcut.value = mode; + localStorage.setItem(SEND_SHORTCUT_STORAGE_KEY, mode); +} // 检测是否为手机端 function checkMobile() { @@ -725,6 +738,10 @@ watch(sessions, (newSessions) => { }); onMounted(() => { + const storedShortcut = localStorage.getItem(SEND_SHORTCUT_STORAGE_KEY); + if (storedShortcut === 'enter' || storedShortcut === 'shift_enter') { + sendShortcut.value = storedShortcut; + } checkMobile(); window.addEventListener('resize', checkMobile); getSessions(); diff --git a/dashboard/src/components/chat/ChatInput.vue b/dashboard/src/components/chat/ChatInput.vue index d2d9c4b88..ee48120f7 100644 --- a/dashboard/src/components/chat/ChatInput.vue +++ b/dashboard/src/components/chat/ChatInput.vue @@ -173,6 +173,7 @@ interface Props { currentSession?: Session | null; configId?: string | null; replyTo?: ReplyInfo | null; + sendShortcut?: 'enter' | 'shift_enter'; } const props = withDefaults(defineProps(), { @@ -180,7 +181,8 @@ const props = withDefaults(defineProps(), { currentSession: null, configId: null, stagedFiles: () => [], - replyTo: null + replyTo: null, + sendShortcut: 'shift_enter' }); const emit = defineEmits<{ @@ -253,9 +255,29 @@ watch(localPrompt, () => { }); function handleKeyDown(e: KeyboardEvent) { - // Enter 插入换行(桌面和手机端均如此,发送通过右下角发送按鈕) - // Shift+Enter 发送(Ctrl+Enter / Cmd+Enter 也保留) - if (e.keyCode === 13 && (e.shiftKey || e.ctrlKey || e.metaKey)) { + const isEnter = e.key === 'Enter'; + if (!isEnter) { + // Ctrl+B 录音 + if (e.ctrlKey && e.keyCode === 66) { + e.preventDefault(); + if (ctrlKeyDown.value) return; + + ctrlKeyDown.value = true; + ctrlKeyTimer.value = window.setTimeout(() => { + if (ctrlKeyDown.value && !props.isRecording) { + emit('startRecording'); + } + }, ctrlKeyLongPressThreshold); + } + return; + } + + const isSendHotkey = + e.ctrlKey || + e.metaKey || + (props.sendShortcut === 'enter' ? !e.shiftKey : e.shiftKey); + + if (isSendHotkey) { e.preventDefault(); if (localPrompt.value.trim() === '/astr_live_dev') { emit('openLiveMode'); @@ -267,19 +289,6 @@ function handleKeyDown(e: KeyboardEvent) { } return; } - - // Ctrl+B 录音 - if (e.ctrlKey && e.keyCode === 66) { - e.preventDefault(); - if (ctrlKeyDown.value) return; - - ctrlKeyDown.value = true; - ctrlKeyTimer.value = window.setTimeout(() => { - if (ctrlKeyDown.value && !props.isRecording) { - emit('startRecording'); - } - }, ctrlKeyLongPressThreshold); - } } function handleKeyUp(e: KeyboardEvent) { diff --git a/dashboard/src/components/chat/ConversationSidebar.vue b/dashboard/src/components/chat/ConversationSidebar.vue index a3645694b..8045fd8d1 100644 --- a/dashboard/src/components/chat/ConversationSidebar.vue +++ b/dashboard/src/components/chat/ConversationSidebar.vue @@ -231,6 +231,50 @@ + + + + + + + + {{ opt.label }} + + + + +