Add: Whitelist AUTH

This commit is contained in:
2025-06-30 01:33:13 +08:00
parent c912244bd7
commit 008f4b2ede
+106 -44
View File
@@ -15,10 +15,11 @@ use sqlx::Row;
use sqlx::query;
use sqlx::sqlite::SqlitePool;
use std::env;
use tokio::sync::Mutex;
use std::time::Duration;
use std::time::SystemTime;
use std::time::UNIX_EPOCH;
use std::{collections::HashMap, sync::Arc};
use std::{collections::{HashMap, HashSet}, sync::Arc};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::{mpsc, RwLock};
use tokio::{
@@ -154,7 +155,7 @@ static SERVER_INFO: Lazy<ServerInfo> = Lazy::new(|| {
software: "https://git.moe.gift/laoxong/nostr-relay",
supported_nips: vec![1, 2, 5, 65],
version: env!("CARGO_PKG_VERSION"),
auth_required: false,
auth_required: true,
}
});
@@ -175,14 +176,18 @@ async fn main() -> Result<()> {
let clients: Arc<RwLock<HashMap<Uuid, Arc<RwLock<ClientConnection>>>>> =
Arc::new(RwLock::new(HashMap::new()));
let accounts = get_trust_accounts(&pool).await;
let trust_accounts: HashSet<_> = accounts.into_iter().collect();
let trust_accounts = Arc::new(RwLock::new(trust_accounts));
while let Ok((stream, client_addr)) = listener.accept().await {
let event_tx = event_tx.clone();
let event_rx = event_tx.subscribe();
let pool = pool.clone();
let clients = clients.clone();
let trust_accounts = trust_accounts.clone();
tokio::spawn(async move {
if let Err(e) = handle_connection_multiplex(stream, event_tx, event_rx, pool, clients).await {
if let Err(e) = handle_connection_multiplex(stream, event_tx, event_rx, pool, clients, trust_accounts).await {
error!("Error handling connection: {}", e);
}
});
@@ -196,6 +201,7 @@ async fn handle_connection_multiplex(
rx: broadcast::Receiver<String>,
pool: Arc<SqlitePool>,
clients: Arc<RwLock<HashMap<Uuid, Arc<RwLock<ClientConnection>>>>>,
trust_accounts: Arc<RwLock<HashSet<String>>>
) -> Result<(), anyhow::Error> {
// 分配足够大的缓冲区,一次读完整个头
let mut buf = vec![0u8; 4096];
@@ -225,7 +231,7 @@ async fn handle_connection_multiplex(
.map(|&v| v.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
{
handle_ws_connection(stream, tx, rx, pool, clients).await;
handle_ws_connection(stream, tx, rx, pool, clients, trust_accounts).await;
} else {
if let Some(accept) = header_map.get("accept") {
if accept.contains("application/json") || accept.contains("application/json") {
@@ -427,6 +433,7 @@ async fn handle_ws_connection(
mut event_rx: broadcast::Receiver<String>,
pool: Arc<SqlitePool>,
clients: Arc<RwLock<HashMap<Uuid, Arc<RwLock<ClientConnection>>>>>,
trust_accounts: Arc<RwLock<HashSet<String>>>,
) {
let ws_stream = match accept_async(stream).await {
Ok(stream) => stream,
@@ -503,7 +510,8 @@ async fn handle_ws_connection(
debug!("Received text: {}", text);
let pool = pool.clone();
let client_conn_for_msg = client_conn.clone();
match handle_message(&text, &pool, &client_tx, &client_conn_for_msg, &event_tx).await {
let trust_accounts = trust_accounts.clone();
match handle_message(&text, &pool, &client_tx, &client_conn_for_msg, &event_tx, trust_accounts).await {
Ok(_) => {
debug!("Message handled successfully");
}
@@ -531,6 +539,7 @@ async fn handle_message(
to_client_msg_tx: &mpsc::Sender<String>,
client_conn: &Arc<RwLock<ClientConnection>>,
event_tx: &broadcast::Sender<String>,
trust_accounts: Arc<RwLock<HashSet<String>>>,
) -> Result<(), anyhow::Error> {
let v = serde_json::from_str(text).map_err(|e| anyhow!("Invalid JSON: {}", e))?;
@@ -548,19 +557,6 @@ async fn handle_message(
.ok_or_else(|| anyhow!("First element must be a string"))?
{
"REQ" => {
if SERVER_INFO.auth_required {
let conn = client_conn.read().await;
if !conn.authenticated {
RelayMessage::send_closed(
&conn.id.to_string(),
"auth-required:Authentication required".to_string(),
to_client_msg_tx,
)
.await;
debug!("Sending message to Client {}: not authenticated", conn.id);
return Ok(());
}
}
if arr.len() < 3 {
RelayMessage::send_notice(
"Not enough array elements".to_string(),
@@ -623,19 +619,6 @@ async fn handle_message(
}
"EVENT" => {
if !SERVER_INFO.auth_required {
let conn = client_conn.read().await;
if !conn.authenticated {
RelayMessage::send_closed(
&conn.id.to_string(),
"auth-required:Authentication required".to_string(),
to_client_msg_tx,
)
.await;
debug!("Sending message to Client {}: not authenticated", conn.id);
return Ok(());
}
}
let event: NostrEvent = serde_json::from_value(arr[1].clone())
.map_err(|e| anyhow!("Event parse error: {}", e))?;
debug!("EVENT received: {:?}", event);
@@ -644,7 +627,30 @@ async fn handle_message(
let to_client_msg_tx = to_client_msg_tx.clone();
let event_tx = event_tx.clone();
let event = event.clone();
let conn = client_conn.read().await;
if SERVER_INFO.auth_required {
let trust_accounts = trust_accounts.read().await;
if !trust_accounts.contains(&event.pubkey) {
RelayMessage::send_closed(
&conn.id.to_string(),
"restricted: you are not in the trust list".to_string(),
&to_client_msg_tx,
)
.await;
debug!("Sending message to Client {}:restricted: client not in the trust list", conn.id);
return Ok(());
}
if !conn.authenticated {
RelayMessage::send_closed(
&conn.id.to_string(),
"auth-required:Authentication required".to_string(),
&to_client_msg_tx,
)
.await;
debug!("Sending message to Client {}: not authenticated", conn.id);
return Ok(());
}
}
tokio::spawn(async move {
match event.save(&pool).await {
Ok(_) => {
@@ -653,6 +659,15 @@ async fn handle_message(
if let Err(e) = event_tx.send(serde_json::to_string(&event).unwrap()) {
error!("Failed to broadcast event: {}", e);
}
//Update trust list
if event.kind == 3 && event.pubkey == SERVER_INFO.pubkey{
let mut ts = trust_accounts.write().await;
let new_trust_accounts = extract_p_tags_from_vec(&event.tags);
let mut new_trust_accounts: HashSet<String> = new_trust_accounts.into_iter().collect();
new_trust_accounts.insert(SERVER_INFO.pubkey.to_string());
*ts = new_trust_accounts;
debug!("Trust list updated: {}", ts.len());
}
}
Err(e) => {
RelayMessage::send_ok(&event, false, e.to_string(), &to_client_msg_tx).await;
@@ -673,17 +688,6 @@ async fn handle_message(
}
"CLOSE" => {
if SERVER_INFO.auth_required {
if !client_conn.read().await.authenticated {
RelayMessage::send_closed(
"AUTH",
"auth-required:Authentication required".to_string(),
to_client_msg_tx,
)
.await;
return Ok(());
}
}
let sub_id = arr
.get(1)
.and_then(Value::as_str)
@@ -1143,3 +1147,61 @@ impl RelayMessage {
}
async fn get_trust_accounts(pool: &SqlitePool) -> Vec<String> {
let pubkey = SERVER_INFO.pubkey;
let sql = "SELECT tags FROM events WHERE kind = 3 AND pubkey = ? ORDER BY created_at DESC LIMIT 1";
let row = match sqlx::query(sql)
.bind(pubkey)
.fetch_optional(pool)
.await
{
Ok(row) => row,
Err(e) => {
error!("Failed to execute query: {}", e);
return Vec::new();
}
};
match row {
Some(row) => {
let tags_json: String = row.get(0);
extract_p_tags_from_json(&tags_json).unwrap_or_else(|_| Vec::new())
}
None => Vec::new(),
}
}
fn extract_p_tags_from_json(tags_json: &str) -> Result<Vec<String>, serde_json::Error> {
let parsed: Value = serde_json::from_str(tags_json)?;
let mut trust_accounts:Vec<String> = parsed
.as_array()
.unwrap_or(&vec![])
.iter()
.filter_map(|item| {
item.as_array().and_then(|arr| {
// 检查是否是 "p" 标签且至少有两个元素
if arr.len() >= 2 && arr[0].as_str() == Some("p") {
arr[1].as_str().map(|s| s.to_string())
} else {
None
}
})
})
.collect();
debug!("Trust accounts: {}", trust_accounts.len());
trust_accounts.push(SERVER_INFO.pubkey.clone().to_string());
Ok(trust_accounts)
}
fn extract_p_tags_from_vec(tags: &[Vec<String>]) -> Vec<String> {
tags.iter()
.filter_map(|tag| {
if tag.len() >= 2 && tag[0] == "p" {
Some(tag[1].clone())
} else {
None
}
})
.collect()
}