From cfb0538b322553ea594201ec7c90f3f72646da3e Mon Sep 17 00:00:00 2001 From: Soulter <905617992@qq.com> Date: Wed, 25 Feb 2026 17:40:55 +0800 Subject: [PATCH] feat(chat): add websocket API key extraction and scope validation --- astrbot/dashboard/routes/open_api.py | 352 ++++++++++++++++++++++++++- astrbot/dashboard/server.py | 5 + 2 files changed, 356 insertions(+), 1 deletion(-) diff --git a/astrbot/dashboard/routes/open_api.py b/astrbot/dashboard/routes/open_api.py index 055de6732..653e22cbf 100644 --- a/astrbot/dashboard/routes/open_api.py +++ b/astrbot/dashboard/routes/open_api.py @@ -1,6 +1,9 @@ +import asyncio +import hashlib +import json from uuid import uuid4 -from quart import g, request +from quart import g, request, websocket from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle @@ -8,8 +11,12 @@ from astrbot.core.db import BaseDatabase from astrbot.core.platform.message_session import MessageSesion from astrbot.core.platform.sources.webchat.message_parts_helper import ( build_message_chain_from_payload, + strip_message_parts_path_fields, + webchat_message_parts_have_content, ) +from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr +from .api_key import ALL_OPEN_API_SCOPES from .chat import ChatRoute from .route import Response, Route, RouteContext @@ -37,6 +44,7 @@ class OpenApiRoute(Route): "/v1/im/bots": ("GET", self.get_bots), } self.register_routes() + self.app.websocket("/api/v1/chat/ws")(self.chat_ws) @staticmethod def _resolve_open_username( @@ -181,6 +189,348 @@ class OpenApiRoute(Route): finally: g.username = original_username + @staticmethod + def _extract_ws_api_key() -> str | None: + if key := websocket.args.get("api_key"): + return key.strip() + if key := websocket.args.get("key"): + return key.strip() + if key := websocket.headers.get("X-API-Key"): + return key.strip() + + auth_header = websocket.headers.get("Authorization", "").strip() + if auth_header.startswith("Bearer "): + return auth_header.removeprefix("Bearer ").strip() + if auth_header.startswith("ApiKey "): + return auth_header.removeprefix("ApiKey ").strip() + return None + + async def _authenticate_chat_ws_api_key(self) -> tuple[bool, str | None]: + raw_key = self._extract_ws_api_key() + if not raw_key: + return False, "Missing API key" + + key_hash = hashlib.pbkdf2_hmac( + "sha256", + raw_key.encode("utf-8"), + b"astrbot_api_key", + 100_000, + ).hex() + api_key = await self.db.get_active_api_key_by_hash(key_hash) + if not api_key: + return False, "Invalid API key" + + if isinstance(api_key.scopes, list): + scopes = api_key.scopes + else: + scopes = list(ALL_OPEN_API_SCOPES) + + if "*" not in scopes and "chat" not in scopes: + return False, "Insufficient API key scope" + + await self.db.touch_api_key(api_key.key_id) + return True, None + + async def _send_chat_ws_error(self, message: str, code: str) -> None: + await websocket.send_json( + { + "type": "error", + "code": code, + "data": message, + } + ) + + async def _update_session_config_route( + self, + *, + username: str, + session_id: str, + config_id: str | None, + ) -> str | None: + if not config_id: + return None + + umo = f"webchat:FriendMessage:webchat!{username}!{session_id}" + try: + if config_id == "default": + await self.core_lifecycle.umop_config_router.delete_route(umo) + else: + await self.core_lifecycle.umop_config_router.update_route( + umo, config_id + ) + except Exception as e: + logger.error( + "Failed to update chat config route for %s with %s: %s", + umo, + config_id, + e, + exc_info=True, + ) + return f"Failed to update chat config route: {e}" + return None + + async def _handle_chat_ws_send(self, post_data: dict) -> None: + effective_username, username_err = self._resolve_open_username( + post_data.get("username") + ) + if username_err or not effective_username: + await self._send_chat_ws_error( + username_err or "Invalid username", "BAD_USER" + ) + return + + message = post_data.get("message") + if message is None: + await self._send_chat_ws_error("Missing key: message", "INVALID_MESSAGE") + return + + raw_session_id = post_data.get("session_id", post_data.get("conversation_id")) + session_id = str(raw_session_id).strip() if raw_session_id is not None else "" + if not session_id: + session_id = str(uuid4()) + + ensure_session_err = await self._ensure_chat_session( + effective_username, + session_id, + ) + if ensure_session_err: + await self._send_chat_ws_error(ensure_session_err, "SESSION_ERROR") + return + + config_id, resolve_err = self._resolve_chat_config_id(post_data) + if resolve_err: + await self._send_chat_ws_error(resolve_err, "CONFIG_ERROR") + return + + config_err = await self._update_session_config_route( + username=effective_username, + session_id=session_id, + config_id=config_id, + ) + if config_err: + await self._send_chat_ws_error(config_err, "CONFIG_ERROR") + return + + message_parts = await self.chat_route._build_user_message_parts(message) + if not webchat_message_parts_have_content(message_parts): + await self._send_chat_ws_error( + "Message content is empty (reply only is not allowed)", + "INVALID_MESSAGE", + ) + return + + message_id = str(post_data.get("message_id") or uuid4()) + selected_provider = post_data.get("selected_provider") + selected_model = post_data.get("selected_model") + enable_streaming = post_data.get("enable_streaming", True) + + back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id) + try: + chat_queue = webchat_queue_mgr.get_or_create_queue(session_id) + await chat_queue.put( + ( + effective_username, + session_id, + { + "message": message_parts, + "selected_provider": selected_provider, + "selected_model": selected_model, + "enable_streaming": enable_streaming, + "message_id": message_id, + }, + ) + ) + + message_parts_for_storage = strip_message_parts_path_fields(message_parts) + await self.chat_route.platform_history_mgr.insert( + platform_id="webchat", + user_id=session_id, + content={"type": "user", "message": message_parts_for_storage}, + sender_id=effective_username, + sender_name=effective_username, + ) + + await websocket.send_json( + { + "type": "session_id", + "data": None, + "session_id": session_id, + "message_id": message_id, + } + ) + + accumulated_parts = [] + accumulated_text = "" + accumulated_reasoning = "" + tool_calls = {} + agent_stats = {} + refs = {} + while True: + try: + result = await asyncio.wait_for(back_queue.get(), timeout=1) + except asyncio.TimeoutError: + continue + + if not result: + continue + + if "message_id" in result and result["message_id"] != message_id: + logger.warning("openapi ws stream message_id mismatch") + continue + + result_text = result.get("data", "") + msg_type = result.get("type") + streaming = result.get("streaming", False) + chain_type = result.get("chain_type") + + if chain_type == "agent_stats": + try: + stats_info = { + "type": "agent_stats", + "data": json.loads(result_text), + } + await websocket.send_json(stats_info) + agent_stats = stats_info["data"] + except Exception: + pass + continue + + await websocket.send_json(result) + + if msg_type == "plain": + if chain_type == "tool_call": + tool_call = json.loads(result_text) + tool_calls[tool_call.get("id")] = tool_call + if accumulated_text: + accumulated_parts.append( + {"type": "plain", "text": accumulated_text} + ) + accumulated_text = "" + elif chain_type == "tool_call_result": + tcr = json.loads(result_text) + tc_id = tcr.get("id") + if tc_id in tool_calls: + tool_calls[tc_id]["result"] = tcr.get("result") + tool_calls[tc_id]["finished_ts"] = tcr.get("ts") + accumulated_parts.append( + {"type": "tool_call", "tool_calls": [tool_calls[tc_id]]} + ) + tool_calls.pop(tc_id, None) + elif chain_type == "reasoning": + accumulated_reasoning += result_text + elif streaming: + accumulated_text += result_text + else: + accumulated_text = result_text + elif msg_type == "image": + filename = str(result_text).replace("[IMAGE]", "") + part = await self.chat_route._create_attachment_from_file( + filename, "image" + ) + if part: + accumulated_parts.append(part) + elif msg_type == "record": + filename = str(result_text).replace("[RECORD]", "") + part = await self.chat_route._create_attachment_from_file( + filename, "record" + ) + if part: + accumulated_parts.append(part) + elif msg_type == "file": + filename = str(result_text).replace("[FILE]", "") + part = await self.chat_route._create_attachment_from_file( + filename, "file" + ) + if part: + accumulated_parts.append(part) + elif msg_type == "video": + filename = str(result_text).replace("[VIDEO]", "") + part = await self.chat_route._create_attachment_from_file( + filename, "video" + ) + if part: + accumulated_parts.append(part) + + if msg_type == "end": + break + if (streaming and msg_type == "complete") or not streaming: + if chain_type in ("tool_call", "tool_call_result"): + continue + try: + refs = self.chat_route._extract_web_search_refs( + accumulated_text, + accumulated_parts, + ) + except Exception as e: + logger.exception( + f"Open API WS failed to extract web search refs: {e}", + exc_info=True, + ) + + saved_record = await self.chat_route._save_bot_message( + session_id, + accumulated_text, + accumulated_parts, + accumulated_reasoning, + agent_stats, + refs, + ) + if saved_record: + await websocket.send_json( + { + "type": "message_saved", + "data": { + "id": saved_record.id, + "created_at": saved_record.created_at.astimezone().isoformat(), + }, + "session_id": session_id, + } + ) + accumulated_parts = [] + accumulated_text = "" + accumulated_reasoning = "" + agent_stats = {} + refs = {} + except Exception as e: + logger.exception(f"Open API WS chat failed: {e}", exc_info=True) + await self._send_chat_ws_error( + f"Failed to process message: {e}", "PROCESSING_ERROR" + ) + finally: + webchat_queue_mgr.remove_back_queue(message_id) + + async def chat_ws(self) -> None: + authed, auth_err = await self._authenticate_chat_ws_api_key() + if not authed: + await self._send_chat_ws_error(auth_err or "Unauthorized", "UNAUTHORIZED") + await websocket.close(1008, auth_err or "Unauthorized") + return + + try: + while True: + message = await websocket.receive_json() + if not isinstance(message, dict): + await self._send_chat_ws_error( + "message must be an object", + "INVALID_MESSAGE", + ) + continue + + msg_type = message.get("t", "send") + if msg_type == "ping": + await websocket.send_json({"type": "pong"}) + continue + if msg_type != "send": + await self._send_chat_ws_error( + f"Unsupported message type: {msg_type}", + "INVALID_MESSAGE", + ) + continue + + await self._handle_chat_ws_send(message) + except Exception as e: + logger.debug("Open API WS connection closed: %s", e) + async def upload_file(self): return await self.chat_route.post_file() diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index a9631fc09..a9650cd06 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -204,6 +204,10 @@ class AstrBotDashboard: @staticmethod def _extract_raw_api_key() -> str | None: + if key := request.args.get("api_key"): + return key.strip() + if key := request.args.get("key"): + return key.strip() if key := request.headers.get("X-API-Key"): return key.strip() auth_header = request.headers.get("Authorization", "").strip() @@ -217,6 +221,7 @@ class AstrBotDashboard: def _get_required_open_api_scope(path: str) -> str | None: scope_map = { "/api/v1/chat": "chat", + "/api/v1/chat/ws": "chat", "/api/v1/chat/sessions": "chat", "/api/v1/configs": "config", "/api/v1/file": "file",