feat(chat): add websocket API key extraction and scope validation
This commit is contained in:
@@ -1,6 +1,9 @@
|
|||||||
|
import asyncio
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from quart import g, request
|
from quart import g, request, websocket
|
||||||
|
|
||||||
from astrbot.core import logger
|
from astrbot.core import logger
|
||||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||||
@@ -8,8 +11,12 @@ from astrbot.core.db import BaseDatabase
|
|||||||
from astrbot.core.platform.message_session import MessageSesion
|
from astrbot.core.platform.message_session import MessageSesion
|
||||||
from astrbot.core.platform.sources.webchat.message_parts_helper import (
|
from astrbot.core.platform.sources.webchat.message_parts_helper import (
|
||||||
build_message_chain_from_payload,
|
build_message_chain_from_payload,
|
||||||
|
strip_message_parts_path_fields,
|
||||||
|
webchat_message_parts_have_content,
|
||||||
)
|
)
|
||||||
|
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
|
||||||
|
|
||||||
|
from .api_key import ALL_OPEN_API_SCOPES
|
||||||
from .chat import ChatRoute
|
from .chat import ChatRoute
|
||||||
from .route import Response, Route, RouteContext
|
from .route import Response, Route, RouteContext
|
||||||
|
|
||||||
@@ -37,6 +44,7 @@ class OpenApiRoute(Route):
|
|||||||
"/v1/im/bots": ("GET", self.get_bots),
|
"/v1/im/bots": ("GET", self.get_bots),
|
||||||
}
|
}
|
||||||
self.register_routes()
|
self.register_routes()
|
||||||
|
self.app.websocket("/api/v1/chat/ws")(self.chat_ws)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _resolve_open_username(
|
def _resolve_open_username(
|
||||||
@@ -181,6 +189,348 @@ class OpenApiRoute(Route):
|
|||||||
finally:
|
finally:
|
||||||
g.username = original_username
|
g.username = original_username
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _extract_ws_api_key() -> str | None:
|
||||||
|
if key := websocket.args.get("api_key"):
|
||||||
|
return key.strip()
|
||||||
|
if key := websocket.args.get("key"):
|
||||||
|
return key.strip()
|
||||||
|
if key := websocket.headers.get("X-API-Key"):
|
||||||
|
return key.strip()
|
||||||
|
|
||||||
|
auth_header = websocket.headers.get("Authorization", "").strip()
|
||||||
|
if auth_header.startswith("Bearer "):
|
||||||
|
return auth_header.removeprefix("Bearer ").strip()
|
||||||
|
if auth_header.startswith("ApiKey "):
|
||||||
|
return auth_header.removeprefix("ApiKey ").strip()
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _authenticate_chat_ws_api_key(self) -> tuple[bool, str | None]:
|
||||||
|
raw_key = self._extract_ws_api_key()
|
||||||
|
if not raw_key:
|
||||||
|
return False, "Missing API key"
|
||||||
|
|
||||||
|
key_hash = hashlib.pbkdf2_hmac(
|
||||||
|
"sha256",
|
||||||
|
raw_key.encode("utf-8"),
|
||||||
|
b"astrbot_api_key",
|
||||||
|
100_000,
|
||||||
|
).hex()
|
||||||
|
api_key = await self.db.get_active_api_key_by_hash(key_hash)
|
||||||
|
if not api_key:
|
||||||
|
return False, "Invalid API key"
|
||||||
|
|
||||||
|
if isinstance(api_key.scopes, list):
|
||||||
|
scopes = api_key.scopes
|
||||||
|
else:
|
||||||
|
scopes = list(ALL_OPEN_API_SCOPES)
|
||||||
|
|
||||||
|
if "*" not in scopes and "chat" not in scopes:
|
||||||
|
return False, "Insufficient API key scope"
|
||||||
|
|
||||||
|
await self.db.touch_api_key(api_key.key_id)
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
async def _send_chat_ws_error(self, message: str, code: str) -> None:
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "error",
|
||||||
|
"code": code,
|
||||||
|
"data": message,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _update_session_config_route(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
username: str,
|
||||||
|
session_id: str,
|
||||||
|
config_id: str | None,
|
||||||
|
) -> str | None:
|
||||||
|
if not config_id:
|
||||||
|
return None
|
||||||
|
|
||||||
|
umo = f"webchat:FriendMessage:webchat!{username}!{session_id}"
|
||||||
|
try:
|
||||||
|
if config_id == "default":
|
||||||
|
await self.core_lifecycle.umop_config_router.delete_route(umo)
|
||||||
|
else:
|
||||||
|
await self.core_lifecycle.umop_config_router.update_route(
|
||||||
|
umo, config_id
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(
|
||||||
|
"Failed to update chat config route for %s with %s: %s",
|
||||||
|
umo,
|
||||||
|
config_id,
|
||||||
|
e,
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
return f"Failed to update chat config route: {e}"
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _handle_chat_ws_send(self, post_data: dict) -> None:
|
||||||
|
effective_username, username_err = self._resolve_open_username(
|
||||||
|
post_data.get("username")
|
||||||
|
)
|
||||||
|
if username_err or not effective_username:
|
||||||
|
await self._send_chat_ws_error(
|
||||||
|
username_err or "Invalid username", "BAD_USER"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
message = post_data.get("message")
|
||||||
|
if message is None:
|
||||||
|
await self._send_chat_ws_error("Missing key: message", "INVALID_MESSAGE")
|
||||||
|
return
|
||||||
|
|
||||||
|
raw_session_id = post_data.get("session_id", post_data.get("conversation_id"))
|
||||||
|
session_id = str(raw_session_id).strip() if raw_session_id is not None else ""
|
||||||
|
if not session_id:
|
||||||
|
session_id = str(uuid4())
|
||||||
|
|
||||||
|
ensure_session_err = await self._ensure_chat_session(
|
||||||
|
effective_username,
|
||||||
|
session_id,
|
||||||
|
)
|
||||||
|
if ensure_session_err:
|
||||||
|
await self._send_chat_ws_error(ensure_session_err, "SESSION_ERROR")
|
||||||
|
return
|
||||||
|
|
||||||
|
config_id, resolve_err = self._resolve_chat_config_id(post_data)
|
||||||
|
if resolve_err:
|
||||||
|
await self._send_chat_ws_error(resolve_err, "CONFIG_ERROR")
|
||||||
|
return
|
||||||
|
|
||||||
|
config_err = await self._update_session_config_route(
|
||||||
|
username=effective_username,
|
||||||
|
session_id=session_id,
|
||||||
|
config_id=config_id,
|
||||||
|
)
|
||||||
|
if config_err:
|
||||||
|
await self._send_chat_ws_error(config_err, "CONFIG_ERROR")
|
||||||
|
return
|
||||||
|
|
||||||
|
message_parts = await self.chat_route._build_user_message_parts(message)
|
||||||
|
if not webchat_message_parts_have_content(message_parts):
|
||||||
|
await self._send_chat_ws_error(
|
||||||
|
"Message content is empty (reply only is not allowed)",
|
||||||
|
"INVALID_MESSAGE",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
message_id = str(post_data.get("message_id") or uuid4())
|
||||||
|
selected_provider = post_data.get("selected_provider")
|
||||||
|
selected_model = post_data.get("selected_model")
|
||||||
|
enable_streaming = post_data.get("enable_streaming", True)
|
||||||
|
|
||||||
|
back_queue = webchat_queue_mgr.get_or_create_back_queue(message_id, session_id)
|
||||||
|
try:
|
||||||
|
chat_queue = webchat_queue_mgr.get_or_create_queue(session_id)
|
||||||
|
await chat_queue.put(
|
||||||
|
(
|
||||||
|
effective_username,
|
||||||
|
session_id,
|
||||||
|
{
|
||||||
|
"message": message_parts,
|
||||||
|
"selected_provider": selected_provider,
|
||||||
|
"selected_model": selected_model,
|
||||||
|
"enable_streaming": enable_streaming,
|
||||||
|
"message_id": message_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
message_parts_for_storage = strip_message_parts_path_fields(message_parts)
|
||||||
|
await self.chat_route.platform_history_mgr.insert(
|
||||||
|
platform_id="webchat",
|
||||||
|
user_id=session_id,
|
||||||
|
content={"type": "user", "message": message_parts_for_storage},
|
||||||
|
sender_id=effective_username,
|
||||||
|
sender_name=effective_username,
|
||||||
|
)
|
||||||
|
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "session_id",
|
||||||
|
"data": None,
|
||||||
|
"session_id": session_id,
|
||||||
|
"message_id": message_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
accumulated_parts = []
|
||||||
|
accumulated_text = ""
|
||||||
|
accumulated_reasoning = ""
|
||||||
|
tool_calls = {}
|
||||||
|
agent_stats = {}
|
||||||
|
refs = {}
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
result = await asyncio.wait_for(back_queue.get(), timeout=1)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not result:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if "message_id" in result and result["message_id"] != message_id:
|
||||||
|
logger.warning("openapi ws stream message_id mismatch")
|
||||||
|
continue
|
||||||
|
|
||||||
|
result_text = result.get("data", "")
|
||||||
|
msg_type = result.get("type")
|
||||||
|
streaming = result.get("streaming", False)
|
||||||
|
chain_type = result.get("chain_type")
|
||||||
|
|
||||||
|
if chain_type == "agent_stats":
|
||||||
|
try:
|
||||||
|
stats_info = {
|
||||||
|
"type": "agent_stats",
|
||||||
|
"data": json.loads(result_text),
|
||||||
|
}
|
||||||
|
await websocket.send_json(stats_info)
|
||||||
|
agent_stats = stats_info["data"]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
continue
|
||||||
|
|
||||||
|
await websocket.send_json(result)
|
||||||
|
|
||||||
|
if msg_type == "plain":
|
||||||
|
if chain_type == "tool_call":
|
||||||
|
tool_call = json.loads(result_text)
|
||||||
|
tool_calls[tool_call.get("id")] = tool_call
|
||||||
|
if accumulated_text:
|
||||||
|
accumulated_parts.append(
|
||||||
|
{"type": "plain", "text": accumulated_text}
|
||||||
|
)
|
||||||
|
accumulated_text = ""
|
||||||
|
elif chain_type == "tool_call_result":
|
||||||
|
tcr = json.loads(result_text)
|
||||||
|
tc_id = tcr.get("id")
|
||||||
|
if tc_id in tool_calls:
|
||||||
|
tool_calls[tc_id]["result"] = tcr.get("result")
|
||||||
|
tool_calls[tc_id]["finished_ts"] = tcr.get("ts")
|
||||||
|
accumulated_parts.append(
|
||||||
|
{"type": "tool_call", "tool_calls": [tool_calls[tc_id]]}
|
||||||
|
)
|
||||||
|
tool_calls.pop(tc_id, None)
|
||||||
|
elif chain_type == "reasoning":
|
||||||
|
accumulated_reasoning += result_text
|
||||||
|
elif streaming:
|
||||||
|
accumulated_text += result_text
|
||||||
|
else:
|
||||||
|
accumulated_text = result_text
|
||||||
|
elif msg_type == "image":
|
||||||
|
filename = str(result_text).replace("[IMAGE]", "")
|
||||||
|
part = await self.chat_route._create_attachment_from_file(
|
||||||
|
filename, "image"
|
||||||
|
)
|
||||||
|
if part:
|
||||||
|
accumulated_parts.append(part)
|
||||||
|
elif msg_type == "record":
|
||||||
|
filename = str(result_text).replace("[RECORD]", "")
|
||||||
|
part = await self.chat_route._create_attachment_from_file(
|
||||||
|
filename, "record"
|
||||||
|
)
|
||||||
|
if part:
|
||||||
|
accumulated_parts.append(part)
|
||||||
|
elif msg_type == "file":
|
||||||
|
filename = str(result_text).replace("[FILE]", "")
|
||||||
|
part = await self.chat_route._create_attachment_from_file(
|
||||||
|
filename, "file"
|
||||||
|
)
|
||||||
|
if part:
|
||||||
|
accumulated_parts.append(part)
|
||||||
|
elif msg_type == "video":
|
||||||
|
filename = str(result_text).replace("[VIDEO]", "")
|
||||||
|
part = await self.chat_route._create_attachment_from_file(
|
||||||
|
filename, "video"
|
||||||
|
)
|
||||||
|
if part:
|
||||||
|
accumulated_parts.append(part)
|
||||||
|
|
||||||
|
if msg_type == "end":
|
||||||
|
break
|
||||||
|
if (streaming and msg_type == "complete") or not streaming:
|
||||||
|
if chain_type in ("tool_call", "tool_call_result"):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
refs = self.chat_route._extract_web_search_refs(
|
||||||
|
accumulated_text,
|
||||||
|
accumulated_parts,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(
|
||||||
|
f"Open API WS failed to extract web search refs: {e}",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
saved_record = await self.chat_route._save_bot_message(
|
||||||
|
session_id,
|
||||||
|
accumulated_text,
|
||||||
|
accumulated_parts,
|
||||||
|
accumulated_reasoning,
|
||||||
|
agent_stats,
|
||||||
|
refs,
|
||||||
|
)
|
||||||
|
if saved_record:
|
||||||
|
await websocket.send_json(
|
||||||
|
{
|
||||||
|
"type": "message_saved",
|
||||||
|
"data": {
|
||||||
|
"id": saved_record.id,
|
||||||
|
"created_at": saved_record.created_at.astimezone().isoformat(),
|
||||||
|
},
|
||||||
|
"session_id": session_id,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
accumulated_parts = []
|
||||||
|
accumulated_text = ""
|
||||||
|
accumulated_reasoning = ""
|
||||||
|
agent_stats = {}
|
||||||
|
refs = {}
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Open API WS chat failed: {e}", exc_info=True)
|
||||||
|
await self._send_chat_ws_error(
|
||||||
|
f"Failed to process message: {e}", "PROCESSING_ERROR"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
webchat_queue_mgr.remove_back_queue(message_id)
|
||||||
|
|
||||||
|
async def chat_ws(self) -> None:
|
||||||
|
authed, auth_err = await self._authenticate_chat_ws_api_key()
|
||||||
|
if not authed:
|
||||||
|
await self._send_chat_ws_error(auth_err or "Unauthorized", "UNAUTHORIZED")
|
||||||
|
await websocket.close(1008, auth_err or "Unauthorized")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
message = await websocket.receive_json()
|
||||||
|
if not isinstance(message, dict):
|
||||||
|
await self._send_chat_ws_error(
|
||||||
|
"message must be an object",
|
||||||
|
"INVALID_MESSAGE",
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
msg_type = message.get("t", "send")
|
||||||
|
if msg_type == "ping":
|
||||||
|
await websocket.send_json({"type": "pong"})
|
||||||
|
continue
|
||||||
|
if msg_type != "send":
|
||||||
|
await self._send_chat_ws_error(
|
||||||
|
f"Unsupported message type: {msg_type}",
|
||||||
|
"INVALID_MESSAGE",
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
await self._handle_chat_ws_send(message)
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug("Open API WS connection closed: %s", e)
|
||||||
|
|
||||||
async def upload_file(self):
|
async def upload_file(self):
|
||||||
return await self.chat_route.post_file()
|
return await self.chat_route.post_file()
|
||||||
|
|
||||||
|
|||||||
@@ -204,6 +204,10 @@ class AstrBotDashboard:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_raw_api_key() -> str | None:
|
def _extract_raw_api_key() -> str | None:
|
||||||
|
if key := request.args.get("api_key"):
|
||||||
|
return key.strip()
|
||||||
|
if key := request.args.get("key"):
|
||||||
|
return key.strip()
|
||||||
if key := request.headers.get("X-API-Key"):
|
if key := request.headers.get("X-API-Key"):
|
||||||
return key.strip()
|
return key.strip()
|
||||||
auth_header = request.headers.get("Authorization", "").strip()
|
auth_header = request.headers.get("Authorization", "").strip()
|
||||||
@@ -217,6 +221,7 @@ class AstrBotDashboard:
|
|||||||
def _get_required_open_api_scope(path: str) -> str | None:
|
def _get_required_open_api_scope(path: str) -> str | None:
|
||||||
scope_map = {
|
scope_map = {
|
||||||
"/api/v1/chat": "chat",
|
"/api/v1/chat": "chat",
|
||||||
|
"/api/v1/chat/ws": "chat",
|
||||||
"/api/v1/chat/sessions": "chat",
|
"/api/v1/chat/sessions": "chat",
|
||||||
"/api/v1/configs": "config",
|
"/api/v1/configs": "config",
|
||||||
"/api/v1/file": "file",
|
"/api/v1/file": "file",
|
||||||
|
|||||||
Reference in New Issue
Block a user