diff --git a/src/main.rs b/src/main.rs index d9c310a..87cd65c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,10 +15,11 @@ use sqlx::Row; use sqlx::query; use sqlx::sqlite::SqlitePool; use std::env; +use tokio::sync::Mutex; use std::time::Duration; use std::time::SystemTime; use std::time::UNIX_EPOCH; -use std::{collections::HashMap, sync::Arc}; +use std::{collections::{HashMap, HashSet}, sync::Arc}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::sync::{mpsc, RwLock}; use tokio::{ @@ -154,7 +155,7 @@ static SERVER_INFO: Lazy = Lazy::new(|| { software: "https://git.moe.gift/laoxong/nostr-relay", supported_nips: vec![1, 2, 5, 65], version: env!("CARGO_PKG_VERSION"), - auth_required: false, + auth_required: true, } }); @@ -175,14 +176,18 @@ async fn main() -> Result<()> { let clients: Arc>>>> = Arc::new(RwLock::new(HashMap::new())); + let accounts = get_trust_accounts(&pool).await; + let trust_accounts: HashSet<_> = accounts.into_iter().collect(); + let trust_accounts = Arc::new(RwLock::new(trust_accounts)); + while let Ok((stream, client_addr)) = listener.accept().await { let event_tx = event_tx.clone(); let event_rx = event_tx.subscribe(); let pool = pool.clone(); let clients = clients.clone(); - + let trust_accounts = trust_accounts.clone(); tokio::spawn(async move { - if let Err(e) = handle_connection_multiplex(stream, event_tx, event_rx, pool, clients).await { + if let Err(e) = handle_connection_multiplex(stream, event_tx, event_rx, pool, clients, trust_accounts).await { error!("Error handling connection: {}", e); } }); @@ -196,6 +201,7 @@ async fn handle_connection_multiplex( rx: broadcast::Receiver, pool: Arc, clients: Arc>>>>, + trust_accounts: Arc>> ) -> Result<(), anyhow::Error> { // 分配足够大的缓冲区,一次读完整个头 let mut buf = vec![0u8; 4096]; @@ -225,7 +231,7 @@ async fn handle_connection_multiplex( .map(|&v| v.eq_ignore_ascii_case("websocket")) .unwrap_or(false) { - handle_ws_connection(stream, tx, rx, pool, clients).await; + handle_ws_connection(stream, tx, rx, pool, clients, trust_accounts).await; } else { if let Some(accept) = header_map.get("accept") { if accept.contains("application/json") || accept.contains("application/json") { @@ -427,6 +433,7 @@ async fn handle_ws_connection( mut event_rx: broadcast::Receiver, pool: Arc, clients: Arc>>>>, + trust_accounts: Arc>>, ) { let ws_stream = match accept_async(stream).await { Ok(stream) => stream, @@ -503,7 +510,8 @@ async fn handle_ws_connection( debug!("Received text: {}", text); let pool = pool.clone(); let client_conn_for_msg = client_conn.clone(); - match handle_message(&text, &pool, &client_tx, &client_conn_for_msg, &event_tx).await { + let trust_accounts = trust_accounts.clone(); + match handle_message(&text, &pool, &client_tx, &client_conn_for_msg, &event_tx, trust_accounts).await { Ok(_) => { debug!("Message handled successfully"); } @@ -531,6 +539,7 @@ async fn handle_message( to_client_msg_tx: &mpsc::Sender, client_conn: &Arc>, event_tx: &broadcast::Sender, + trust_accounts: Arc>>, ) -> Result<(), anyhow::Error> { let v = serde_json::from_str(text).map_err(|e| anyhow!("Invalid JSON: {}", e))?; @@ -548,19 +557,6 @@ async fn handle_message( .ok_or_else(|| anyhow!("First element must be a string"))? { "REQ" => { - if SERVER_INFO.auth_required { - let conn = client_conn.read().await; - if !conn.authenticated { - RelayMessage::send_closed( - &conn.id.to_string(), - "auth-required:Authentication required".to_string(), - to_client_msg_tx, - ) - .await; - debug!("Sending message to Client {}: not authenticated", conn.id); - return Ok(()); - } - } if arr.len() < 3 { RelayMessage::send_notice( "Not enough array elements".to_string(), @@ -623,19 +619,6 @@ async fn handle_message( } "EVENT" => { - if !SERVER_INFO.auth_required { - let conn = client_conn.read().await; - if !conn.authenticated { - RelayMessage::send_closed( - &conn.id.to_string(), - "auth-required:Authentication required".to_string(), - to_client_msg_tx, - ) - .await; - debug!("Sending message to Client {}: not authenticated", conn.id); - return Ok(()); - } - } let event: NostrEvent = serde_json::from_value(arr[1].clone()) .map_err(|e| anyhow!("Event parse error: {}", e))?; debug!("EVENT received: {:?}", event); @@ -644,7 +627,30 @@ async fn handle_message( let to_client_msg_tx = to_client_msg_tx.clone(); let event_tx = event_tx.clone(); let event = event.clone(); - + let conn = client_conn.read().await; + if SERVER_INFO.auth_required { + let trust_accounts = trust_accounts.read().await; + if !trust_accounts.contains(&event.pubkey) { + RelayMessage::send_closed( + &conn.id.to_string(), + "restricted: you are not in the trust list".to_string(), + &to_client_msg_tx, + ) + .await; + debug!("Sending message to Client {}:restricted: client not in the trust list", conn.id); + return Ok(()); + } + if !conn.authenticated { + RelayMessage::send_closed( + &conn.id.to_string(), + "auth-required:Authentication required".to_string(), + &to_client_msg_tx, + ) + .await; + debug!("Sending message to Client {}: not authenticated", conn.id); + return Ok(()); + } + } tokio::spawn(async move { match event.save(&pool).await { Ok(_) => { @@ -653,6 +659,15 @@ async fn handle_message( if let Err(e) = event_tx.send(serde_json::to_string(&event).unwrap()) { error!("Failed to broadcast event: {}", e); } + //Update trust list + if event.kind == 3 && event.pubkey == SERVER_INFO.pubkey{ + let mut ts = trust_accounts.write().await; + let new_trust_accounts = extract_p_tags_from_vec(&event.tags); + let mut new_trust_accounts: HashSet = new_trust_accounts.into_iter().collect(); + new_trust_accounts.insert(SERVER_INFO.pubkey.to_string()); + *ts = new_trust_accounts; + debug!("Trust list updated: {}", ts.len()); + } } Err(e) => { RelayMessage::send_ok(&event, false, e.to_string(), &to_client_msg_tx).await; @@ -673,17 +688,6 @@ async fn handle_message( } "CLOSE" => { - if SERVER_INFO.auth_required { - if !client_conn.read().await.authenticated { - RelayMessage::send_closed( - "AUTH", - "auth-required:Authentication required".to_string(), - to_client_msg_tx, - ) - .await; - return Ok(()); - } - } let sub_id = arr .get(1) .and_then(Value::as_str) @@ -1143,3 +1147,61 @@ impl RelayMessage { } +async fn get_trust_accounts(pool: &SqlitePool) -> Vec { + let pubkey = SERVER_INFO.pubkey; + let sql = "SELECT tags FROM events WHERE kind = 3 AND pubkey = ? ORDER BY created_at DESC LIMIT 1"; + + let row = match sqlx::query(sql) + .bind(pubkey) + .fetch_optional(pool) + .await + { + Ok(row) => row, + Err(e) => { + error!("Failed to execute query: {}", e); + return Vec::new(); + } + }; + match row { + Some(row) => { + let tags_json: String = row.get(0); + extract_p_tags_from_json(&tags_json).unwrap_or_else(|_| Vec::new()) + } + None => Vec::new(), + } +} + +fn extract_p_tags_from_json(tags_json: &str) -> Result, serde_json::Error> { + let parsed: Value = serde_json::from_str(tags_json)?; + + let mut trust_accounts:Vec = parsed + .as_array() + .unwrap_or(&vec![]) + .iter() + .filter_map(|item| { + item.as_array().and_then(|arr| { + // 检查是否是 "p" 标签且至少有两个元素 + if arr.len() >= 2 && arr[0].as_str() == Some("p") { + arr[1].as_str().map(|s| s.to_string()) + } else { + None + } + }) + }) + .collect(); + debug!("Trust accounts: {}", trust_accounts.len()); + trust_accounts.push(SERVER_INFO.pubkey.clone().to_string()); + Ok(trust_accounts) +} + +fn extract_p_tags_from_vec(tags: &[Vec]) -> Vec { + tags.iter() + .filter_map(|tag| { + if tag.len() >= 2 && tag[0] == "p" { + Some(tag[1].clone()) + } else { + None + } + }) + .collect() +} \ No newline at end of file