From 75ee46715af1d06551960d1d72e4cd070a13038e Mon Sep 17 00:00:00 2001 From: LIghtJUNction Date: Fri, 6 Feb 2026 02:46:16 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=AF=E6=8C=81ipv6=E5=B9=B6=E5=AE=8C?= =?UTF-8?q?=E5=96=84astrbot=20run=E5=AD=90=E5=91=BD=E4=BB=A4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 默认host修改为::,同时新增两个环境变量DASHBOARD_HOST,DASHBOARD_ENABLE,和DASHBOARD_PORT对齐 * feat: systemd support (#4880) * fix: pyright lint (#4874) * feat: 将 MessageSession 的 platform_id 改为 init=False,实例化时无需传入 Co-authored-by: aider (openai/gpt-5.2) * refactor: 将 isinstance 检查改为元组、将默认模型值设为空字符串、将类型注解改为 Any 并导入 * refactor: 为 _serialize_job 增加返回类型注解 dict * fix: 使用 cast 获取百度 AIP 的 msg 并对 psutil_addr 引入 type: ignore Co-authored-by: aider (openai/gpt-5.2) * refactor: 引入 _AddrWithPort 协议并替换 conn.laddr 的 cast Co-authored-by: aider (openai/gpt-5.2) * fix: 在构建 AstrBotMessage 时对 ctx.channel 可能为 None 进行兜底处理 Co-authored-by: aider (openai/gpt-5.2) --------- Co-authored-by: aider (openai/gpt-5.2) * fix: TypeError when MCP schema type is a list (#4867) * Fix TypeError when MCP schema type is a list Fixes crash in Gemini native tools with VRChat MCP. * Refactor: avoid modifying schema in place per feedback * Fix formatting and cleanup comments * docs: update watashiwakoseinodesukara Removed duplicate text and added a new image. * 修复/跨平台一致性 * 琐事/类型标注和一些简单错误修正 * 修复/检查端口时候包含ipv6 * 修复/enable变量的赋值逻辑 --------- Co-authored-by: Dt8333 <25431943+Dt8333@users.noreply.github.com> Co-authored-by: aider (openai/gpt-5.2) Co-authored-by: boushi1111 <95118141+boushi1111@users.noreply.github.com> Co-authored-by: Soulter <37870767+Soulter@users.noreply.github.com> --- .python-version | 2 +- README.md | 5 +- astrbot/cli/commands/cmd_run.py | 17 +- astrbot/core/agent/tool.py | 14 +- astrbot/core/config/default.py | 14 +- .../strategies/baidu_aip.py | 5 +- astrbot/core/platform/message_session.py | 4 +- .../aiocqhttp/aiocqhttp_platform_adapter.py | 4 +- .../discord/discord_platform_adapter.py | 15 +- .../qqofficial_webhook/qo_webhook_server.py | 2 +- astrbot/core/platform/sources/slack/client.py | 2 +- .../platform/sources/slack/slack_adapter.py | 2 +- .../platform/sources/wecom/wecom_adapter.py | 2 +- .../sources/wecom_ai_bot/wecomai_adapter.py | 2 +- .../weixin_offacc_adapter.py | 2 +- .../sources/fishaudio_tts_api_source.py | 2 +- astrbot/core/utils/io.py | 45 ++- astrbot/dashboard/routes/cron.py | 2 +- astrbot/dashboard/routes/knowledge_base.py | 3 +- astrbot/dashboard/routes/route.py | 7 +- astrbot/dashboard/server.py | 375 ++++++++++-------- dashboard/vite.config.ts | 46 +-- scripts/astrbot.service | 15 + 23 files changed, 359 insertions(+), 228 deletions(-) create mode 100644 scripts/astrbot.service diff --git a/.python-version b/.python-version index 7c7a975f4..c8cfe3959 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.10 \ No newline at end of file +3.10 diff --git a/README.md b/README.md index eca615c99..f4640ee43 100644 --- a/README.md +++ b/README.md @@ -264,8 +264,9 @@ pre-commit install
+_陪伴与能力从来不应该是对立面。我们希望创造的是一个既能理解情绪、给予陪伴,也能可靠完成工作的机器人。_ + _私は、高性能ですから!_ -陪伴与能力从来不应该是对立面。我们希望创造的是一个既能理解情绪、给予陪伴,也能可靠完成工作的机器人。 - + diff --git a/astrbot/cli/commands/cmd_run.py b/astrbot/cli/commands/cmd_run.py index 9333f1b87..6952ba323 100644 --- a/astrbot/cli/commands/cmd_run.py +++ b/astrbot/cli/commands/cmd_run.py @@ -27,9 +27,17 @@ async def run_astrbot(astrbot_root: Path): @click.option("--reload", "-r", is_flag=True, help="插件自动重载") -@click.option("--port", "-p", help="Astrbot Dashboard端口", required=False, type=str) +@click.option( + "--host", "-H", help="Astrbot Dashboard Host,默认::", required=False, type=str +) +@click.option( + "--port", "-p", help="Astrbot Dashboard端口,默认6185", required=False, type=str +) +@click.option( + "--backend-only", is_flag=True, default=False, help="禁用WEBUI,仅启动后端" +) @click.command() -def run(reload: bool, port: str) -> None: +def run(reload: bool, host: str, port: str, backend_only: bool) -> None: """运行 AstrBot""" try: os.environ["ASTRBOT_CLI"] = "1" @@ -43,8 +51,9 @@ def run(reload: bool, port: str) -> None: os.environ["ASTRBOT_ROOT"] = str(astrbot_root) sys.path.insert(0, str(astrbot_root)) - if port: - os.environ["DASHBOARD_PORT"] = port + os.environ["DASHBOARD_PORT"] = port or "6185" + os.environ["DASHBOARD_HOST"] = host or "::" + os.environ["DASHBOARD_ENABLE"] = str(not backend_only) if reload: click.echo("启用插件自动重载") diff --git a/astrbot/core/agent/tool.py b/astrbot/core/agent/tool.py index 2ffbd40ca..50899ff80 100644 --- a/astrbot/core/agent/tool.py +++ b/astrbot/core/agent/tool.py @@ -246,8 +246,18 @@ class ToolSet: result = {} - if "type" in schema and schema["type"] in supported_types: - result["type"] = schema["type"] + # Avoid side effects by not modifying the original schema + origin_type = schema.get("type") + target_type = origin_type + + # Compatibility fix: Gemini API expects 'type' to be a string (enum), + # but standard JSON Schema (MCP) allows lists (e.g. ["string", "null"]). + # We fallback to the first non-null type. + if isinstance(origin_type, list): + target_type = next((t for t in origin_type if t != "null"), "string") + + if target_type in supported_types: + result["type"] = target_type if "format" in schema and schema["format"] in supported_formats.get( result["type"], set(), diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 10a6fc599..fc0d6b373 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -182,7 +182,7 @@ DEFAULT_CONFIG = { "username": "astrbot", "password": "77b90590a8945a7d36c963981a307dc9", "jwt_secret": "", - "host": "0.0.0.0", + "host": "::", "port": 6185, "disable_access_log": True, }, @@ -273,14 +273,14 @@ CONFIG_METADATA_2 = { "is_sandbox": False, "unified_webhook_mode": True, "webhook_uuid": "", - "callback_server_host": "0.0.0.0", + "callback_server_host": "::", "port": 6196, }, "OneBot v11": { "id": "default", "type": "aiocqhttp", "enable": False, - "ws_reverse_host": "0.0.0.0", + "ws_reverse_host": "::", "ws_reverse_port": 6199, "ws_reverse_token": "", }, @@ -295,7 +295,7 @@ CONFIG_METADATA_2 = { "api_base_url": "https://api.weixin.qq.com/cgi-bin/", "unified_webhook_mode": True, "webhook_uuid": "", - "callback_server_host": "0.0.0.0", + "callback_server_host": "::", "port": 6194, "active_send_mode": False, }, @@ -311,7 +311,7 @@ CONFIG_METADATA_2 = { "api_base_url": "https://qyapi.weixin.qq.com/cgi-bin/", "unified_webhook_mode": True, "webhook_uuid": "", - "callback_server_host": "0.0.0.0", + "callback_server_host": "::", "port": 6195, }, "企业微信智能机器人": { @@ -325,7 +325,7 @@ CONFIG_METADATA_2 = { "encoding_aes_key": "", "unified_webhook_mode": True, "webhook_uuid": "", - "callback_server_host": "0.0.0.0", + "callback_server_host": "::", "port": 6198, }, "飞书(Lark)": { @@ -399,7 +399,7 @@ CONFIG_METADATA_2 = { "slack_connection_mode": "socket", # webhook, socket "unified_webhook_mode": True, "webhook_uuid": "", - "slack_webhook_host": "0.0.0.0", + "slack_webhook_host": "::", "slack_webhook_port": 6197, "slack_webhook_path": "/astrbot-slack-webhook/callback", }, diff --git a/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py b/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py index bfa82de0e..dd8ca629e 100644 --- a/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py +++ b/astrbot/core/pipeline/content_safety_check/strategies/baidu_aip.py @@ -1,5 +1,7 @@ """使用此功能应该先 pip install baidu-aip""" +from typing import Any, cast + from aip import AipContentCensor from . import ContentSafetyStrategy @@ -23,7 +25,8 @@ class BaiduAipStrategy(ContentSafetyStrategy): count = len(res["data"]) parts = [f"百度审核服务发现 {count} 处违规:\n"] for i in res["data"]: - parts.append(f"{i['msg']};\n") + # 百度 AIP 返回结构是动态 dict;类型检查时 i 可能被推断为序列,转成 dict 后用 get 取字段 + parts.append(f"{cast(dict[str, Any], i).get('msg', '')};\n") parts.append("\n判断结果:" + res["conclusion"]) info = "".join(parts) return False, info diff --git a/astrbot/core/platform/message_session.py b/astrbot/core/platform/message_session.py index 982a844c2..b282b307a 100644 --- a/astrbot/core/platform/message_session.py +++ b/astrbot/core/platform/message_session.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field from astrbot.core.platform.message_type import MessageType @@ -13,7 +13,7 @@ class MessageSession: """平台适配器实例的唯一标识符。自 AstrBot v4.0.0 起,该字段实际为 platform_id。""" message_type: MessageType session_id: str - platform_id: str | None = None + platform_id: str = field(init=False) def __str__(self): return f"{self.platform_id}:{self.message_type.value}:{self.session_id}" diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py index d4d8e1d62..8540ff592 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_platform_adapter.py @@ -418,9 +418,9 @@ class AiocqhttpAdapter(Platform): def run(self) -> Awaitable[Any]: if not self.host or not self.port: logger.warning( - "aiocqhttp: 未配置 ws_reverse_host 或 ws_reverse_port,将使用默认值:http://0.0.0.0:6199", + "aiocqhttp: 未配置 ws_reverse_host 或 ws_reverse_port,将使用默认值:http://[::]:6199", ) - self.host = "0.0.0.0" + self.host = "::" self.port = 6199 coro = self.bot.run_task( diff --git a/astrbot/core/platform/sources/discord/discord_platform_adapter.py b/astrbot/core/platform/sources/discord/discord_platform_adapter.py index d81afd1ab..ed9899f6f 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_adapter.py +++ b/astrbot/core/platform/sources/discord/discord_platform_adapter.py @@ -444,9 +444,20 @@ class DiscordPlatformAdapter(Platform): logger.warning(f"[Discord] 指令 '{cmd_name}' defer 失败: {e}") # 2. 构建 AstrBotMessage + channel = ctx.channel abm = AstrBotMessage() - abm.type = self._get_message_type(ctx.channel, ctx.guild_id) - abm.group_id = self._get_channel_id(ctx.channel) + if channel is not None: + abm.type = self._get_message_type(channel, ctx.guild_id) + abm.group_id = self._get_channel_id(channel) + else: + # 防守式兜底:channel 取不到时,仍能根据 guild_id/channel_id 推断会话信息 + abm.type = ( + MessageType.GROUP_MESSAGE + if ctx.guild_id is not None + else MessageType.FRIEND_MESSAGE + ) + abm.group_id = str(ctx.channel_id) + abm.message_str = message_str_for_filter abm.sender = MessageMember( user_id=str(ctx.author.id), diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py index 2eda11a6c..50db7fb21 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py @@ -19,7 +19,7 @@ class QQOfficialWebhook: self.secret = config["secret"] self.port = config.get("port", 6196) self.is_sandbox = config.get("is_sandbox", False) - self.callback_server_host = config.get("callback_server_host", "0.0.0.0") + self.callback_server_host = config.get("callback_server_host", "::") if isinstance(self.port, str): self.port = int(self.port) diff --git a/astrbot/core/platform/sources/slack/client.py b/astrbot/core/platform/sources/slack/client.py index fbdc71759..0e43cadee 100644 --- a/astrbot/core/platform/sources/slack/client.py +++ b/astrbot/core/platform/sources/slack/client.py @@ -23,7 +23,7 @@ class SlackWebhookClient: self, web_client: AsyncWebClient, signing_secret: str, - host: str = "0.0.0.0", + host: str = "::", port: int = 3000, path: str = "/slack/events", event_handler: Callable | None = None, diff --git a/astrbot/core/platform/sources/slack/slack_adapter.py b/astrbot/core/platform/sources/slack/slack_adapter.py index afd80a8fe..f34242ce6 100644 --- a/astrbot/core/platform/sources/slack/slack_adapter.py +++ b/astrbot/core/platform/sources/slack/slack_adapter.py @@ -47,7 +47,7 @@ class SlackAdapter(Platform): self.signing_secret = platform_config.get("signing_secret") self.connection_mode = platform_config.get("slack_connection_mode", "socket") self.unified_webhook_mode = platform_config.get("unified_webhook_mode", False) - self.webhook_host = platform_config.get("slack_webhook_host", "0.0.0.0") + self.webhook_host = platform_config.get("slack_webhook_host", "::") self.webhook_port = platform_config.get("slack_webhook_port", 3000) self.webhook_path = platform_config.get( "slack_webhook_path", diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index adc24578f..8ae97f746 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -42,7 +42,7 @@ class WecomServer: def __init__(self, event_queue: asyncio.Queue, config: dict): self.server = quart.Quart(__name__) self.port = int(cast(str, config.get("port"))) - self.callback_server_host = config.get("callback_server_host", "0.0.0.0") + self.callback_server_host = config.get("callback_server_host", "::") self.server.add_url_rule( "/callback/command", view_func=self.verify, diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py index 57da5176b..af6f834b1 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_adapter.py @@ -111,7 +111,7 @@ class WecomAIBotAdapter(Platform): self.token = self.config["token"] self.encoding_aes_key = self.config["encoding_aes_key"] self.port = int(self.config["port"]) - self.host = self.config.get("callback_server_host", "0.0.0.0") + self.host = self.config.get("callback_server_host", "::") self.bot_name = self.config.get("wecom_ai_bot_name", "") self.initial_respond_text = self.config.get( "wecomaibot_init_respond_text", diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py index a38952127..7aa33a91c 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py @@ -38,7 +38,7 @@ class WeixinOfficialAccountServer: def __init__(self, event_queue: asyncio.Queue, config: dict): self.server = quart.Quart(__name__) self.port = int(cast(int | str, config.get("port"))) - self.callback_server_host = config.get("callback_server_host", "0.0.0.0") + self.callback_server_host = config.get("callback_server_host", "::") self.token = config.get("token") self.encoding_aes_key = config.get("encoding_aes_key") self.appid = config.get("appid") diff --git a/astrbot/core/provider/sources/fishaudio_tts_api_source.py b/astrbot/core/provider/sources/fishaudio_tts_api_source.py index e246e00ed..70eabd289 100644 --- a/astrbot/core/provider/sources/fishaudio_tts_api_source.py +++ b/astrbot/core/provider/sources/fishaudio_tts_api_source.py @@ -63,7 +63,7 @@ class ProviderFishAudioTTSAPI(TTSProvider): self.headers = { "Authorization": f"Bearer {self.chosen_api_key}", } - self.set_model(provider_config.get("model", None)) + self.set_model(provider_config.get("model", "")) async def _get_reference_id_by_character(self, character: str) -> str | None: """获取角色的reference_id diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index fcf5bb3c7..ba487bbc9 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -1,3 +1,4 @@ +import asyncio import base64 import logging import os @@ -7,6 +8,7 @@ import ssl import time import uuid import zipfile +from ipaddress import IPv4Address, IPv6Address, ip_address from pathlib import Path import aiohttp @@ -217,18 +219,51 @@ def file_to_base64(file_path: str) -> str: return "base64://" + base64_str -def get_local_ip_addresses(): +def get_local_ip_addresses() -> list[IPv4Address | IPv6Address]: net_interfaces = psutil.net_if_addrs() - network_ips = [] + network_ips: list[IPv4Address | IPv6Address] = [] - for interface, addrs in net_interfaces.items(): + for _, addrs in net_interfaces.items(): for addr in addrs: - if addr.family == socket.AF_INET: # 使用 socket.AF_INET 代替 psutil.AF_INET - network_ips.append(addr.address) + if addr.family == socket.AF_INET: + network_ips.append(ip_address(addr.address)) + elif addr.family == socket.AF_INET6: + # 过滤掉 IPv6 的 link-local 地址(fe80:...) + # 用这个不如用::1 + ip = ip_address(addr.address.split("%")[0]) # 处理带 zone index 的情况 + network_ips.append(ip) return network_ips +async def get_public_ip_address() -> list[IPv4Address | IPv6Address]: + urls = [ + "https://api64.ipify.org", + "https://ident.me", + "https://ifconfig.me", + "https://icanhazip.com", + ] + found_ips: dict[int, IPv4Address | IPv6Address] = {} + + async def fetch(session: aiohttp.ClientSession, url: str): + try: + async with session.get(url, timeout=3) as resp: + if resp.status == 200: + raw_ip = (await resp.text()).strip() + ip = ip_address(raw_ip) + if ip.version not in found_ips: + found_ips[ip.version] = ip + except Exception: + pass + + async with aiohttp.ClientSession() as session: + tasks = [fetch(session, url) for url in urls] + await asyncio.gather(*tasks) + + # 返回找到的所有 IP 对象列表 + return list(found_ips.values()) + + async def get_dashboard_version(): dist_dir = os.path.join(get_astrbot_data_path(), "dist") if os.path.exists(dist_dir): diff --git a/astrbot/dashboard/routes/cron.py b/astrbot/dashboard/routes/cron.py index 6bef93859..8861fc5cc 100644 --- a/astrbot/dashboard/routes/cron.py +++ b/astrbot/dashboard/routes/cron.py @@ -23,7 +23,7 @@ class CronRoute(Route): ] self.register_routes() - def _serialize_job(self, job): + def _serialize_job(self, job) -> dict: data = job.model_dump() if hasattr(job, "model_dump") else job.__dict__ for k in ["created_at", "updated_at", "last_run_at", "next_run_time"]: if isinstance(data.get(k), datetime): diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py index 537a81f0b..25bc2cf34 100644 --- a/astrbot/dashboard/routes/knowledge_base.py +++ b/astrbot/dashboard/routes/knowledge_base.py @@ -4,6 +4,7 @@ import asyncio import os import traceback import uuid +from typing import Any import aiofiles from quart import request @@ -75,7 +76,7 @@ class KnowledgeBaseRoute(Route): } def _set_task_result( - self, task_id: str, status: str, result: any = None, error: str | None = None + self, task_id: str, status: str, result: Any = None, error: str | None = None ) -> None: self.upload_tasks[task_id] = { "status": status, diff --git a/astrbot/dashboard/routes/route.py b/astrbot/dashboard/routes/route.py index 01ab292d4..530a0807c 100644 --- a/astrbot/dashboard/routes/route.py +++ b/astrbot/dashboard/routes/route.py @@ -1,6 +1,6 @@ -from dataclasses import dataclass +from dataclasses import asdict, dataclass -from quart import Quart +from quart import Quart, jsonify from astrbot.core.config.astrbot_config import AstrBotConfig @@ -57,3 +57,6 @@ class Response: self.data = data self.message = message return self + + def to_json(self): + return jsonify(asdict(self)) diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index a34c12aef..5c2831386 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -1,7 +1,11 @@ import asyncio +import ipaddress import logging import os +import platform import socket +from collections.abc import Callable +from ipaddress import IPv4Address, IPv6Address from typing import cast import jwt @@ -9,7 +13,6 @@ import psutil from flask.json.provider import DefaultJSONProvider from hypercorn.asyncio import serve from hypercorn.config import Config as HyperConfig -from psutil._common import addr as psutil_addr from quart import Quart, g, jsonify, request from quart.logging import default_handler from quart_cors import cors @@ -32,8 +35,17 @@ from .routes.t2i import T2iRoute APP: Quart - class AstrBotDashboard: + """AstrBot Web Dashboard""" + + ALLOWED_ENDPOINT_PREFIXES = ( + "/api/auth/login", + "/api/file", + "/api/platform/webhook", + "/api/stat/start-time", + "/api/backup/download", + ) + def __init__( self, core_lifecycle: AstrBotCoreLifecycle, @@ -43,17 +55,35 @@ class AstrBotDashboard: ) -> None: self.core_lifecycle = core_lifecycle self.config = core_lifecycle.astrbot_config + self.shutdown_event = shutdown_event - # 参数指定webui目录 + self._init_paths(webui_dir) + self._init_app() + self.context = RouteContext(self.config, self.app) + + self._init_routes(db) + self._init_plugin_route_index() + self._init_jwt_secret() + + # ------------------------------------------------------------------ + # 初始化阶段 + # ------------------------------------------------------------------ + + def _init_paths(self, webui_dir: str | None): if webui_dir and os.path.exists(webui_dir): self.data_path = os.path.abspath(webui_dir) else: self.data_path = os.path.abspath( - os.path.join(get_astrbot_data_path(), "dist"), + os.path.join(get_astrbot_data_path(), "dist") ) - self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/") - APP = self.app # noqa + def _init_app(self): + self.app = Quart( + "dashboard", + static_folder=self.data_path, + static_url_path="/", + ) + APP = self.app self.app = cors( self.app, allow_origin="*", allow_methods="*", allow_headers="*" ) @@ -61,45 +91,38 @@ class AstrBotDashboard: 128 * 1024 * 1024 ) # 将 Flask 允许的最大上传文件体大小设置为 128 MB cast(DefaultJSONProvider, self.app.json).sort_keys = False + self.app.before_request(self.auth_middleware) - # token 用于验证请求 logging.getLogger(self.app.name).removeHandler(default_handler) - self.context = RouteContext(self.config, self.app) - self.ur = UpdateRoute( - self.context, - core_lifecycle.astrbot_updator, - core_lifecycle, + + def _init_routes(self, db: BaseDatabase): + UpdateRoute( + self.context, self.core_lifecycle.astrbot_updator, self.core_lifecycle ) - self.sr = StatRoute(self.context, db, core_lifecycle) - self.pr = PluginRoute( - self.context, - core_lifecycle, - core_lifecycle.plugin_manager, + StatRoute(self.context, db, self.core_lifecycle) + PluginRoute( + self.context, self.core_lifecycle, self.core_lifecycle.plugin_manager ) - self.command_route = CommandRoute(self.context) - self.cr = ConfigRoute(self.context, core_lifecycle) - self.lr = LogRoute(self.context, core_lifecycle.log_broker) - self.sfr = StaticFileRoute(self.context) - self.ar = AuthRoute(self.context) - self.chat_route = ChatRoute(self.context, db, core_lifecycle) - self.chatui_project_route = ChatUIProjectRoute(self.context, db) - self.tools_root = ToolsRoute(self.context, core_lifecycle) - self.subagent_route = SubAgentRoute(self.context, core_lifecycle) - self.skills_route = SkillsRoute(self.context, core_lifecycle) - self.conversation_route = ConversationRoute(self.context, db, core_lifecycle) - self.file_route = FileRoute(self.context) - self.session_management_route = SessionManagementRoute( - self.context, - db, - core_lifecycle, - ) - self.persona_route = PersonaRoute(self.context, db, core_lifecycle) - self.cron_route = CronRoute(self.context, core_lifecycle) - self.t2i_route = T2iRoute(self.context, core_lifecycle) - self.kb_route = KnowledgeBaseRoute(self.context, core_lifecycle) - self.platform_route = PlatformRoute(self.context, core_lifecycle) - self.backup_route = BackupRoute(self.context, db, core_lifecycle) - self.live_chat_route = LiveChatRoute(self.context, db, core_lifecycle) + CommandRoute(self.context) + ConfigRoute(self.context, self.core_lifecycle) + LogRoute(self.context, self.core_lifecycle.log_broker) + StaticFileRoute(self.context) + AuthRoute(self.context) + ChatRoute(self.context, db, self.core_lifecycle) + ChatUIProjectRoute(self.context, db) + ToolsRoute(self.context, self.core_lifecycle) + SubAgentRoute(self.context, self.core_lifecycle) + SkillsRoute(self.context, self.core_lifecycle) + ConversationRoute(self.context, db, self.core_lifecycle) + FileRoute(self.context) + SessionManagementRoute(self.context, db, self.core_lifecycle) + PersonaRoute(self.context, db, self.core_lifecycle) + CronRoute(self.context, self.core_lifecycle) + T2iRoute(self.context, self.core_lifecycle) + KnowledgeBaseRoute(self.context, self.core_lifecycle) + PlatformRoute(self.context, self.core_lifecycle) + BackupRoute(self.context, db, self.core_lifecycle) + LiveChatRoute(self.context, db, self.core_lifecycle) self.app.add_url_rule( "/api/plug/", @@ -107,18 +130,30 @@ class AstrBotDashboard: methods=["GET", "POST"], ) - self.shutdown_event = shutdown_event + def _init_plugin_route_index(self): + """将插件路由索引,避免 O(n) 查找""" + self._plugin_route_map: dict[tuple[str, str], Callable] = {} - self._init_jwt_secret() + for ( + route, + handler, + methods, + _, + ) in self.core_lifecycle.star_context.registered_web_apis: + for method in methods: + self._plugin_route_map[(route, method)] = handler - async def srv_plug_route(self, subpath, *args, **kwargs): - """插件路由""" - registered_web_apis = self.core_lifecycle.star_context.registered_web_apis - for api in registered_web_apis: - route, view_handler, methods, _ = api - if route == f"/{subpath}" and request.method in methods: - return await view_handler(*args, **kwargs) - return jsonify(Response().error("未找到该路由").__dict__) + def _init_jwt_secret(self): + dashboard_cfg = self.config.setdefault("dashboard", {}) + if not dashboard_cfg.get("jwt_secret"): + dashboard_cfg["jwt_secret"] = os.urandom(32).hex() + self.config.save_config() + logger.info("Initialized random JWT secret for dashboard.") + self._jwt_secret = dashboard_cfg["jwt_secret"] + + # ------------------------------------------------------------------ + # Middleware中间件 + # ------------------------------------------------------------------ async def auth_middleware(self): # 放行CORS预检请求 @@ -126,154 +161,162 @@ class AstrBotDashboard: return None if not request.path.startswith("/api"): return None - allowed_endpoints = [ - "/api/auth/login", - "/api/file", - "/api/platform/webhook", - "/api/stat/start-time", - "/api/backup/download", # 备份下载使用 URL 参数传递 token - ] - if any(request.path.startswith(prefix) for prefix in allowed_endpoints): + + if any(request.path.startswith(p) for p in self.ALLOWED_ENDPOINT_PREFIXES): return None - # 声明 JWT + token = request.headers.get("Authorization") if not token: - r = jsonify(Response().error("未授权").__dict__) - r.status_code = 401 - return r - token = token.removeprefix("Bearer ") + return self._unauthorized("未授权") + try: - payload = jwt.decode(token, self._jwt_secret, algorithms=["HS256"]) + payload = jwt.decode( + token.removeprefix("Bearer "), + self._jwt_secret, + algorithms=["HS256"], + options={"require": ["username"]}, + ) g.username = payload["username"] except jwt.ExpiredSignatureError: - r = jsonify(Response().error("Token 过期").__dict__) - r.status_code = 401 - return r - except jwt.InvalidTokenError: - r = jsonify(Response().error("Token 无效").__dict__) - r.status_code = 401 - return r + return self._unauthorized("Token 过期") + except jwt.PyJWTError: + return self._unauthorized("Token 无效") + + @staticmethod + def _unauthorized(msg: str): + r = jsonify(Response().error(msg).to_json()) + r.status_code = 401 + return r + + # ------------------------------------------------------------------ + # 插件路由 + # ------------------------------------------------------------------ + + async def srv_plug_route(self, subpath: str, *args, **kwargs): + handler = self._plugin_route_map.get((f"/{subpath}", request.method)) + if not handler: + return jsonify(Response().error("未找到该路由").to_json()) - def check_port_in_use(self, port: int) -> bool: - """跨平台检测端口是否被占用""" try: - # 创建 IPv4 TCP Socket - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - # 设置超时时间 - sock.settimeout(2) - result = sock.connect_ex(("127.0.0.1", port)) - sock.close() - # result 为 0 表示端口被占用 - return result == 0 - except Exception as e: - logger.warning(f"检查端口 {port} 时发生错误: {e!s}") - # 如果出现异常,保守起见认为端口可能被占用 + return await handler(*args, **kwargs) + except Exception: + logger.exception("插件 Web API 执行异常") + return jsonify(Response().error("插件执行失败").to_json()) + + # ------------------------------------------------------------------ + # 网络 / 端口 + # ------------------------------------------------------------------ + + def check_port_in_use(self, host: str, port: int) -> bool: + try: + family = socket.AF_INET6 if ":" in host else socket.AF_INET + with socket.socket(family, socket.SOCK_STREAM) as sock: + sock.settimeout(2) + return sock.connect_ex((host, port)) == 0 + except Exception: return True - def get_process_using_port(self, port: int) -> str: - """获取占用端口的进程详细信息""" + @staticmethod + def get_process_using_port(port: int) -> str: try: - for conn in psutil.net_connections(kind="inet"): - if cast(psutil_addr, conn.laddr).port == port: - try: - process = psutil.Process(conn.pid) - # 获取详细信息 - proc_info = [ - f"进程名: {process.name()}", - f"PID: {process.pid}", - f"执行路径: {process.exe()}", - f"工作目录: {process.cwd()}", - f"启动命令: {' '.join(process.cmdline())}", + for conn in psutil.net_connections(kind="all"): + if conn.laddr and conn.laddr.port == port and conn.pid: + p = psutil.Process(conn.pid) + return "\n ".join( + [ + f"进程名: {p.name()}", + f"PID: {p.pid}", + f"执行路径: {p.exe()}", + f"工作目录: {p.cwd()}", + f"启动命令: {' '.join(p.cmdline())}", ] - return "\n ".join(proc_info) - except (psutil.NoSuchProcess, psutil.AccessDenied) as e: - return f"无法获取进程详细信息(可能需要管理员权限): {e!s}" + ) return "未找到占用进程" except Exception as e: return f"获取进程信息失败: {e!s}" - def _init_jwt_secret(self): - if not self.config.get("dashboard", {}).get("jwt_secret", None): - # 如果没有设置 JWT 密钥,则生成一个新的密钥 - jwt_secret = os.urandom(32).hex() - self.config["dashboard"]["jwt_secret"] = jwt_secret - self.config.save_config() - logger.info("Initialized random JWT secret for dashboard.") - self._jwt_secret = self.config["dashboard"]["jwt_secret"] + # ------------------------------------------------------------------ + # 启动 + # ------------------------------------------------------------------ - def run(self): - ip_addr = [] - if p := os.environ.get("DASHBOARD_PORT"): - port = p - else: - port = self.core_lifecycle.astrbot_config["dashboard"].get("port", 6185) - host = self.core_lifecycle.astrbot_config["dashboard"].get("host", "0.0.0.0") - enable = self.core_lifecycle.astrbot_config["dashboard"].get("enable", True) + def run(self) -> None: + cfg = self.config.get("dashboard", {}) + _port: str = os.environ.get("DASHBOARD_PORT") or cfg.get("port", 6185) + port: int = int(_port) + _host = os.environ.get("DASHBOARD_HOST") or cfg.get("host", "::") + host: str = _host.strip("[]") + _env = os.environ.get("DASHBOARD_ENABLE") + enable = ( + _env.lower() in ("true", "1", "yes") + if _env is not None + else cfg.get("enable", True) + ) if not enable: logger.info("WebUI 已被禁用") return None - logger.info(f"正在启动 WebUI, 监听地址: http://{host}:{port}") + display_host = f"[{host}]" if ":" in host else host + logger.info( + "正在启动 WebUI, 监听地址: http://%s:%s", + display_host, + port, + ) - if host == "0.0.0.0": - logger.info( - "提示: WebUI 将监听所有网络接口,请注意安全。(可在 data/cmd_config.json 中配置 dashboard.host 以修改 host)", - ) + if self.check_port_in_use("127.0.0.1", port): + info = self.get_process_using_port(port) + raise RuntimeError(f"端口 {port} 已被占用\n{info}") - if host not in ["localhost", "127.0.0.1"]: - try: - ip_addr = get_local_ip_addresses() - except Exception as _: - pass - if isinstance(port, str): - port = int(port) + self._print_access_urls(host, port) - if self.check_port_in_use(port): - process_info = self.get_process_using_port(port) - logger.error( - f"错误:端口 {port} 已被占用\n" - f"占用信息: \n {process_info}\n" - f"请确保:\n" - f"1. 没有其他 AstrBot 实例正在运行\n" - f"2. 端口 {port} 没有被其他程序占用\n" - f"3. 如需使用其他端口,请修改配置文件", - ) - - raise Exception(f"端口 {port} 已被占用") - - parts = [f"\n ✨✨✨\n AstrBot v{VERSION} WebUI 已启动,可访问\n\n"] - parts.append(f" ➜ 本地: http://localhost:{port}\n") - for ip in ip_addr: - parts.append(f" ➜ 网络: http://{ip}:{port}\n") - parts.append(" ➜ 默认用户名和密码: astrbot\n ✨✨✨\n") - display = "".join(parts) - - if not ip_addr: - display += ( - "可在 data/cmd_config.json 中配置 dashboard.host 以便远程访问。\n" - ) - - logger.info(display) - - # 配置 Hypercorn config = HyperConfig() - config.bind = [f"{host}:{port}"] + binds: list[str] = [self._build_bind(host, port)] + # 参考:https://github.com/pgjones/hypercorn/issues/85 + if host == "::" and platform.system() in ("Windows", "Darwin"): + binds.append(self._build_bind("0.0.0.0", port)) + config.bind = binds - # 根据配置决定是否禁用访问日志 - disable_access_log = self.core_lifecycle.astrbot_config.get( - "dashboard", {} - ).get("disable_access_log", True) - if disable_access_log: + if cfg.get("disable_access_log", True): config.accesslog = None else: - # 启用访问日志,使用简洁格式 config.accesslog = "-" config.access_log_format = "%(h)s %(r)s %(s)s %(b)s %(D)s" return serve(self.app, config, shutdown_trigger=self.shutdown_trigger) - async def shutdown_trigger(self): + @staticmethod + def _build_bind(host: str, port: int) -> str: + try: + ip: IPv4Address | IPv6Address = ipaddress.ip_address(host) + return f"[{ip}]:{port}" if ip.version == 6 else f"{ip}:{port}" + except ValueError: + return f"{host}:{port}" + + def _print_access_urls(self, host: str, port: int) -> None: + local_ips: list[IPv4Address | IPv6Address] = get_local_ip_addresses() + + parts = [f"\n ✨✨✨\n AstrBot v{VERSION} WebUI 已启动\n\n"] + + parts.append(f" ➜ 本地: http://localhost:{port}\n") + + if host in ("::", "0.0.0.0"): + for ip in local_ips: + if ip.is_loopback: + continue + + # 再次过滤掉 fe80(第一次过滤在get_local_ip_addresses) + if ip.is_link_local: + continue + if ip.version == 6: + display_url = f"http://[{ip}]:{port}" + else: + display_url = f"http://{ip}:{port}" + + parts.append(f" ➜ 网络: {display_url}\n") + + parts.append(" ➜ 默认用户名和密码: astrbot\n ✨✨✨\n") + logger.info("".join(parts)) + + async def shutdown_trigger(self) -> None: await self.shutdown_event.wait() logger.info("AstrBot WebUI 已经被优雅地关闭") diff --git a/dashboard/vite.config.ts b/dashboard/vite.config.ts index 9a31b0b5c..a5168eef7 100644 --- a/dashboard/vite.config.ts +++ b/dashboard/vite.config.ts @@ -1,7 +1,7 @@ -import { fileURLToPath, URL } from 'url'; -import { defineConfig } from 'vite'; -import vue from '@vitejs/plugin-vue'; -import vuetify from 'vite-plugin-vuetify'; +import { fileURLToPath, URL } from "url"; +import { defineConfig } from "vite"; +import vue from "@vitejs/plugin-vue"; +import vuetify from "vite-plugin-vuetify"; // https://vitejs.dev/config/ export default defineConfig({ @@ -9,40 +9,40 @@ export default defineConfig({ vue({ template: { compilerOptions: { - isCustomElement: (tag) => ['v-list-recognize-title'].includes(tag) - } - } + isCustomElement: (tag) => ["v-list-recognize-title"].includes(tag), + }, + }, }), vuetify({ - autoImport: true - }) + autoImport: true, + }), ], resolve: { alias: { - mermaid: 'mermaid/dist/mermaid.js', - '@': fileURLToPath(new URL('./src', import.meta.url)) - } + mermaid: "mermaid/dist/mermaid.js", + "@": fileURLToPath(new URL("./src", import.meta.url)), + }, }, css: { preprocessorOptions: { - scss: {} - } + scss: {}, + }, }, build: { - chunkSizeWarningLimit: 1024 * 1024 // Set the limit to 1 MB + chunkSizeWarningLimit: 1024 * 1024, // Set the limit to 1 MB }, optimizeDeps: { - exclude: ['vuetify'], - entries: ['./src/**/*.vue'] + exclude: ["vuetify"], + entries: ["./src/**/*.vue"], }, server: { - host: '0.0.0.0', + host: "::", port: 3000, proxy: { - '/api': { - target: 'http://127.0.0.1:6185/', + "/api": { + target: "http://127.0.0.1:6185/", changeOrigin: true, - } - } - } + }, + }, + }, }); diff --git a/scripts/astrbot.service b/scripts/astrbot.service new file mode 100644 index 000000000..fdf891be9 --- /dev/null +++ b/scripts/astrbot.service @@ -0,0 +1,15 @@ +[Unit] +Description=AstrBot Service +After=network-online.target +Wants=network-online.target + +[Service] +Type=simple +WorkingDirectory=%h/.local/share/astrbot +ExecStart=/usr/bin/sh -c '/usr/bin/astrbot run || { /usr/bin/astrbot init && /usr/bin/astrbot run; }' +Restart=on-failure +RestartSec=5 +Environment=PYTHONUNBUFFERED=1 + +[Install] +WantedBy=default.target