feat(chat): add websocket API key extraction and scope validation

This commit is contained in:
Soulter
2026-02-25 17:40:55 +08:00
parent f8f7e6d57a
commit cfb0538b32
2 changed files with 356 additions and 1 deletions
+351 -1
View File
@@ -1,6 +1,9 @@
import asyncio
import hashlib
import json
from uuid import uuid4 from uuid import uuid4
from quart import g, request from quart import g, request, websocket
from astrbot.core import logger from astrbot.core import logger
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
@@ -8,8 +11,12 @@ from astrbot.core.db import BaseDatabase
from astrbot.core.platform.message_session import MessageSesion from astrbot.core.platform.message_session import MessageSesion
from astrbot.core.platform.sources.webchat.message_parts_helper import ( from astrbot.core.platform.sources.webchat.message_parts_helper import (
build_message_chain_from_payload, build_message_chain_from_payload,
strip_message_parts_path_fields,
webchat_message_parts_have_content,
) )
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
from .api_key import ALL_OPEN_API_SCOPES
from .chat import ChatRoute from .chat import ChatRoute
from .route import Response, Route, RouteContext from .route import Response, Route, RouteContext
@@ -37,6 +44,7 @@ class OpenApiRoute(Route):
"/v1/im/bots": ("GET", self.get_bots), "/v1/im/bots": ("GET", self.get_bots),
} }
self.register_routes() self.register_routes()
self.app.websocket("/api/v1/chat/ws")(self.chat_ws)
@staticmethod @staticmethod
def _resolve_open_username( def _resolve_open_username(
@@ -181,6 +189,348 @@ class OpenApiRoute(Route):
finally: finally:
g.username = original_username g.username = original_username
@staticmethod
def _extract_ws_api_key() -> str | None:
if key := websocket.args.get("api_key"):
return key.strip()
if key := websocket.args.get("key"):
return key.strip()
if key := websocket.headers.get("X-API-Key"):
return key.strip()
auth_header = websocket.headers.get("Authorization", "").strip()
if auth_header.startswith("Bearer "):
return auth_header.removeprefix("Bearer ").strip()
if auth_header.startswith("ApiKey "):
return auth_header.removeprefix("ApiKey ").strip()
return None
async def _authenticate_chat_ws_api_key(self) -> tuple[bool, str | None]:
raw_key = self._extract_ws_api_key()
if not raw_key:
return False, "Missing API key"
key_hash = hashlib.pbkdf2_hmac(
"sha256",
raw_key.encode("utf-8"),
b"astrbot_api_key",
100_000,
).hex()
api_key = await self.db.get_active_api_key_by_hash(key_hash)
if not api_key:
return False, "Invalid API key"
if isinstance(api_key.scopes, list):
scopes = api_key.scopes
else:
scopes = list(ALL_OPEN_API_SCOPES)
if "*" not in scopes and "chat" not in scopes:
return False, "Insufficient API key scope"
await self.db.touch_api_key(api_key.key_id)
return True, None
async def _send_chat_ws_error(self, message: str, code: str) -> None:
await websocket.send_json(
{
"type": "error",
"code": code,
"data": message,
}
)
async def _update_session_config_route(
self,
*,
username: str,
session_id: str,
config_id: str | None,
) -> str | None:
if not config_id:
return None
umo = f"webchat:FriendMessage:webchat!{username}!{session_id}"
try:
if config_id == "default":
await self.core_lifecycle.umop_config_router.delete_route(umo)
else:
await self.core_lifecycle.umop_config_router.update_route(
umo, config_id
)
except Exception as e:
logger.error(
"Failed to update chat config route for %s with %s: %s",
umo,
config_id,
e,
exc_info=True,
)
return f"Failed to update chat config route: {e}"
return None
async def _handle_chat_ws_send(self, post_data: dict) -> None:
effective_username, username_err = self._resolve_open_username(
post_data.get("username")
)
if username_err or not effective_username:
await self._send_chat_ws_error(
username_err or "Invalid username", "BAD_USER"
)
return
message = post_data.get("message")
if message is None:
await self._send_chat_ws_error("Missing key: message", "INVALID_MESSAGE")
return
raw_session_id = post_data.get("session_id", post_data.get("conversation_id"))
session_id = str(raw_session_id).strip() if raw_session_id is not None else ""
if not session_id:
session_id = str(uuid4())
ensure_session_err = await self._ensure_chat_session(
effective_username,
session_id,
)
if ensure_session_err:
await self._send_chat_ws_error(ensure_session_err, "SESSION_ERROR")
return
config_id, resolve_err = self._resolve_chat_config_id(post_data)
if resolve_err:
await self._send_chat_ws_error(resolve_err, "CONFIG_ERROR")
return
config_err = await self._update_session_config_route(
username=effective_username,
session_id=session_id,
config_id=config_id,
)
if config_err:
await self._send_chat_ws_error(config_err, "CONFIG_ERROR")
return
message_parts = await self.chat_route._build_user_message_parts(message)
if not webchat_message_parts_have_content(message_parts):
await self._send_chat_ws_error(
"Message content is empty (reply only is not allowed)",
"INVALID_MESSAGE",
)
return
message_id = str(post_data.get("message_id") or uuid4())
selected_provider = post_data.get("selected_provider")
selected_model = post_data.get("selected_model")
enable_streaming = post_data.get("enable_streaming", True)
back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id)
try:
chat_queue = webchat_queue_mgr.get_or_create_queue(session_id)
await chat_queue.put(
(
effective_username,
session_id,
{
"message": message_parts,
"selected_provider": selected_provider,
"selected_model": selected_model,
"enable_streaming": enable_streaming,
"message_id": message_id,
},
)
)
message_parts_for_storage = strip_message_parts_path_fields(message_parts)
await self.chat_route.platform_history_mgr.insert(
platform_id="webchat",
user_id=session_id,
content={"type": "user", "message": message_parts_for_storage},
sender_id=effective_username,
sender_name=effective_username,
)
await websocket.send_json(
{
"type": "session_id",
"data": None,
"session_id": session_id,
"message_id": message_id,
}
)
accumulated_parts = []
accumulated_text = ""
accumulated_reasoning = ""
tool_calls = {}
agent_stats = {}
refs = {}
while True:
try:
result = await asyncio.wait_for(back_queue.get(), timeout=1)
except asyncio.TimeoutError:
continue
if not result:
continue
if "message_id" in result and result["message_id"] != message_id:
logger.warning("openapi ws stream message_id mismatch")
continue
result_text = result.get("data", "")
msg_type = result.get("type")
streaming = result.get("streaming", False)
chain_type = result.get("chain_type")
if chain_type == "agent_stats":
try:
stats_info = {
"type": "agent_stats",
"data": json.loads(result_text),
}
await websocket.send_json(stats_info)
agent_stats = stats_info["data"]
except Exception:
pass
continue
await websocket.send_json(result)
if msg_type == "plain":
if chain_type == "tool_call":
tool_call = json.loads(result_text)
tool_calls[tool_call.get("id")] = tool_call
if accumulated_text:
accumulated_parts.append(
{"type": "plain", "text": accumulated_text}
)
accumulated_text = ""
elif chain_type == "tool_call_result":
tcr = json.loads(result_text)
tc_id = tcr.get("id")
if tc_id in tool_calls:
tool_calls[tc_id]["result"] = tcr.get("result")
tool_calls[tc_id]["finished_ts"] = tcr.get("ts")
accumulated_parts.append(
{"type": "tool_call", "tool_calls": [tool_calls[tc_id]]}
)
tool_calls.pop(tc_id, None)
elif chain_type == "reasoning":
accumulated_reasoning += result_text
elif streaming:
accumulated_text += result_text
else:
accumulated_text = result_text
elif msg_type == "image":
filename = str(result_text).replace("[IMAGE]", "")
part = await self.chat_route._create_attachment_from_file(
filename, "image"
)
if part:
accumulated_parts.append(part)
elif msg_type == "record":
filename = str(result_text).replace("[RECORD]", "")
part = await self.chat_route._create_attachment_from_file(
filename, "record"
)
if part:
accumulated_parts.append(part)
elif msg_type == "file":
filename = str(result_text).replace("[FILE]", "")
part = await self.chat_route._create_attachment_from_file(
filename, "file"
)
if part:
accumulated_parts.append(part)
elif msg_type == "video":
filename = str(result_text).replace("[VIDEO]", "")
part = await self.chat_route._create_attachment_from_file(
filename, "video"
)
if part:
accumulated_parts.append(part)
if msg_type == "end":
break
if (streaming and msg_type == "complete") or not streaming:
if chain_type in ("tool_call", "tool_call_result"):
continue
try:
refs = self.chat_route._extract_web_search_refs(
accumulated_text,
accumulated_parts,
)
except Exception as e:
logger.exception(
f"Open API WS failed to extract web search refs: {e}",
exc_info=True,
)
saved_record = await self.chat_route._save_bot_message(
session_id,
accumulated_text,
accumulated_parts,
accumulated_reasoning,
agent_stats,
refs,
)
if saved_record:
await websocket.send_json(
{
"type": "message_saved",
"data": {
"id": saved_record.id,
"created_at": saved_record.created_at.astimezone().isoformat(),
},
"session_id": session_id,
}
)
accumulated_parts = []
accumulated_text = ""
accumulated_reasoning = ""
agent_stats = {}
refs = {}
except Exception as e:
logger.exception(f"Open API WS chat failed: {e}", exc_info=True)
await self._send_chat_ws_error(
f"Failed to process message: {e}", "PROCESSING_ERROR"
)
finally:
webchat_queue_mgr.remove_back_queue(message_id)
async def chat_ws(self) -> None:
authed, auth_err = await self._authenticate_chat_ws_api_key()
if not authed:
await self._send_chat_ws_error(auth_err or "Unauthorized", "UNAUTHORIZED")
await websocket.close(1008, auth_err or "Unauthorized")
return
try:
while True:
message = await websocket.receive_json()
if not isinstance(message, dict):
await self._send_chat_ws_error(
"message must be an object",
"INVALID_MESSAGE",
)
continue
msg_type = message.get("t", "send")
if msg_type == "ping":
await websocket.send_json({"type": "pong"})
continue
if msg_type != "send":
await self._send_chat_ws_error(
f"Unsupported message type: {msg_type}",
"INVALID_MESSAGE",
)
continue
await self._handle_chat_ws_send(message)
except Exception as e:
logger.debug("Open API WS connection closed: %s", e)
async def upload_file(self): async def upload_file(self):
return await self.chat_route.post_file() return await self.chat_route.post_file()
+5
View File
@@ -204,6 +204,10 @@ class AstrBotDashboard:
@staticmethod @staticmethod
def _extract_raw_api_key() -> str | None: def _extract_raw_api_key() -> str | None:
if key := request.args.get("api_key"):
return key.strip()
if key := request.args.get("key"):
return key.strip()
if key := request.headers.get("X-API-Key"): if key := request.headers.get("X-API-Key"):
return key.strip() return key.strip()
auth_header = request.headers.get("Authorization", "").strip() auth_header = request.headers.get("Authorization", "").strip()
@@ -217,6 +221,7 @@ class AstrBotDashboard:
def _get_required_open_api_scope(path: str) -> str | None: def _get_required_open_api_scope(path: str) -> str | None:
scope_map = { scope_map = {
"/api/v1/chat": "chat", "/api/v1/chat": "chat",
"/api/v1/chat/ws": "chat",
"/api/v1/chat/sessions": "chat", "/api/v1/chat/sessions": "chat",
"/api/v1/configs": "config", "/api/v1/configs": "config",
"/api/v1/file": "file", "/api/v1/file": "file",