diff --git a/src/main.rs b/src/main.rs index 46d5ac2..1769bf9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,7 +18,7 @@ use std::time::SystemTime; use std::time::UNIX_EPOCH; use std::{collections::HashMap, sync::Arc}; use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, RwLock}; use tokio::{ net::{TcpListener, TcpStream}, sync::broadcast, @@ -36,7 +36,7 @@ const BROADCAST_CHANNEL_SIZE: usize = 100; const CLIENT_CHANNEL_SIZE: usize = 32; const MAX_SUBSCRIPTIONS: usize = 20; -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct NostrEvent { pub id: String, // 32 bytes hex string of sha256 pub pubkey: String, // 32 bytes hex pubkey @@ -79,7 +79,7 @@ pub enum ClientMessage { CLOSE { sub_id: String }, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub struct Filter { pub ids: Option>, // event ids 列表 pub authors: Option>, // pubkeys 列表,小写字符串 @@ -116,6 +116,16 @@ struct ServerInfo { version: &'static str, } +#[derive(Debug, Clone)] +struct Subscription { + filters: Vec, +} + +struct ClientConnection { + subscriptions: HashMap, + sender: mpsc::Sender, +} + #[tokio::main] async fn main() -> Result<()> { env_logger::init(); @@ -129,14 +139,17 @@ async fn main() -> Result<()> { let listener = TcpListener::bind(&addr).await.expect("Failed to bind"); info!("Listening on: {}", addr); let (event_tx, _) = broadcast::channel::(100); - + + let clients: Arc>>>> = + Arc::new(RwLock::new(HashMap::new())); while let Ok((stream, client_addr)) = listener.accept().await { let event_tx = event_tx.clone(); let event_rx = event_tx.subscribe(); let pool = pool.clone(); + let clients = clients.clone(); tokio::spawn(async move { - if let Err(e) = handle_connection_multiplex(stream, event_tx, event_rx, pool).await { + if let Err(e) = handle_connection_multiplex(stream, event_tx, event_rx, pool, clients).await { error!("Error handling connection: {}", e); } }); @@ -149,6 +162,7 @@ async fn handle_connection_multiplex( tx: broadcast::Sender, rx: broadcast::Receiver, pool: Arc, + clients: Arc>>>>, ) -> Result<(), anyhow::Error> { // 分配足够大的缓冲区,一次读完整个头 let mut buf = vec![0u8; 4096]; @@ -178,7 +192,7 @@ async fn handle_connection_multiplex( .map(|&v| v.eq_ignore_ascii_case("websocket")) .unwrap_or(false) { - handle_ws_connection(stream, tx, rx, pool).await; + handle_ws_connection(stream, tx, rx, pool, clients).await; } else { if let Some(accept) = header_map.get("accept") { if accept.contains("application/json") || accept.contains("application/json") { @@ -398,6 +412,7 @@ async fn handle_ws_connection( event_tx: broadcast::Sender, mut event_rx: broadcast::Receiver, pool: Arc, + clients: Arc>>>>, ) { let ws_stream = match accept_async(stream).await { Ok(stream) => stream, @@ -407,18 +422,39 @@ async fn handle_ws_connection( } }; let (mut ws_sender, mut ws_reciver) = ws_stream.split(); - let (client_tx, mut client_rx) = mpsc::channel::(32); + let (client_tx, mut client_rx) = mpsc::channel::(CLIENT_CHANNEL_SIZE); + + let client_id = Uuid::new_v4(); + let client_conn = Arc::new(RwLock::new(ClientConnection { + subscriptions: HashMap::new(), + sender: client_tx.clone(), + })); + // 注册客户端 + { + let mut clients_map = clients.write().await; + clients_map.insert(client_id, client_conn.clone()); + } + + info!("Client {} connected", client_id); + let client_conn_for_send = client_conn.clone(); let send_task = tokio::spawn(async move { loop { tokio::select! { //Subscribed events - Ok(msg) = event_rx.recv() => { - if ws_sender - .send(Message::Text(msg.into())) - .await - .is_err() - { - break; + Ok(event_msg) = event_rx.recv() => { + let event: NostrEvent = serde_json::from_str(&event_msg).unwrap(); + let conn = client_conn_for_send.read().await; + for (sub_id, subscription) in &conn.subscriptions { + if subscription.filters.iter().any(|filter| filter.matches(&event)) { + let msg = format!( + "[\"EVENT\", \"{}\", {}]", + sub_id, + serde_json::to_string(&event).unwrap() + ); + if ws_sender.send(Message::Text(msg.into())).await.is_err() { + return; + } + } } } //Send message to client @@ -443,7 +479,8 @@ async fn handle_ws_connection( Ok(Message::Text(text)) => { debug!("Received text: {}", text); let pool = pool.clone(); - match handle_message(&text, &pool, &client_tx).await { + let client_conn_for_msg = client_conn.clone(); + match handle_message(&text, &pool, &client_tx, &client_id, &client_conn_for_msg, &event_tx).await { Ok(_) => { debug!("Message handled successfully"); } @@ -457,7 +494,7 @@ async fn handle_ws_connection( } } Ok(Message::Close(_)) => { - info!("Client disconnected"); + info!("Client {} disconnected", client_id); break; } _ => {} @@ -469,6 +506,9 @@ async fn handle_message( text: &str, pool: &SqlitePool, to_client_msg_tx: &mpsc::Sender, + client_id: &Uuid, + client_conn: &Arc>, + event_tx: &broadcast::Sender, ) -> Result<(), anyhow::Error> { let v = serde_json::from_str(text).map_err(|e| anyhow!("Invalid JSON: {}", e))?; @@ -487,7 +527,7 @@ async fn handle_message( { "REQ" => { if arr.len() < 3 { - RelayMessage::send_message( + RelayMessage::send_notice( "Not enough array elements".to_string(), to_client_msg_tx, ) @@ -499,13 +539,50 @@ async fn handle_message( .and_then(Value::as_str) .ok_or_else(|| anyhow!("REQ missing sub_id"))? .to_string(); - let filters: Filter = serde_json::from_value(arr[2].clone()) - .map_err(|e| anyhow!("Filter parse error: {}", e))?; - debug!("REQ subscription: {}, filters: {:?}", sub_id, filters); - let r = filters.select(pool).await?; - for event in r { - RelayMessage::send_event(&event, &sub_id, to_client_msg_tx).await; + + // Check if the client has exceeded the maximum number of subscriptions + { + let conn = client_conn.read().await; + if conn.subscriptions.len() >= MAX_SUBSCRIPTIONS { + RelayMessage::send_closed( + &sub_id, + format!("Maximum subscriptions ({}) exceeded", MAX_SUBSCRIPTIONS), + to_client_msg_tx + ).await; + return Ok(()); + } } + let mut filters = Vec::new(); + for i in 2..arr.len() { + if filters.len() >= MAX_FILTERS_PER_REQ { + RelayMessage::send_closed( + &sub_id, + format!("Maximum filters ({}) exceeded", MAX_FILTERS_PER_REQ), + to_client_msg_tx + ).await; + return Ok(()); + } + let filter: Filter = serde_json::from_value(arr[i].clone()) + .map_err(|e| anyhow!("Filter parse error: {}", e))?; + filters.push(filter); + } + + + debug!("REQ subscription: {}, filters: {:?}", sub_id, filters); + { + let mut conn = client_conn.write().await; + conn.subscriptions.insert(sub_id.clone(), Subscription { filters: filters.clone() }); + } + + // 查询历史事件 + for filter in &filters { + let events = filter.select(pool).await?; + for event in events { + RelayMessage::send_event(&event, &sub_id, to_client_msg_tx).await; + } + } + + // 发送 EOSE RelayMessage::send_eose(&sub_id, to_client_msg_tx).await; Ok(()) } @@ -517,21 +594,29 @@ async fn handle_message( if event.verify() { let pool = pool.clone(); let to_client_msg_tx = to_client_msg_tx.clone(); + let event_tx = event_tx.clone(); + let event = event.clone(); + tokio::spawn(async move { - if let Err(e) = event.save(&pool).await { - RelayMessage::send_ok(&event, false, e.to_string(), &to_client_msg_tx) - .await; - error!("Failed to save event: {}", e); - } else { - RelayMessage::send_ok(&event, true, "Saved".to_string(), &to_client_msg_tx) - .await; + match event.save(&pool).await { + Ok(_) => { + RelayMessage::send_ok(&event, true, "Saved".to_string(), &to_client_msg_tx).await; + // 广播事件给所有客户端 + if let Err(e) = event_tx.send(serde_json::to_string(&event).unwrap()) { + error!("Failed to broadcast event: {}", e); + } + } + Err(e) => { + RelayMessage::send_ok(&event, false, e.to_string(), &to_client_msg_tx).await; + error!("Failed to save event: {}", e); + } } }); } else { RelayMessage::send_ok( &event, false, - "Invalid signature".to_string(), + "Failed to save event".to_string(), to_client_msg_tx, ) .await; @@ -545,6 +630,12 @@ async fn handle_message( .and_then(Value::as_str) .ok_or_else(|| anyhow!("CLOSE missing sub_id"))? .to_string(); + + { + let mut conn = client_conn.write().await; + conn.subscriptions.remove(&sub_id); + } + debug!("CLOSE subscription: {}", sub_id); Ok(()) @@ -717,10 +808,21 @@ impl Filter { sql.push(" AND created_at <= ").push_bind(until as i64); } - // TODO: 实现标签过滤 - // for (tag_name, tag_values) in &self.tag_filters { - // // 实现 #e, #p 等标签过滤 - // } + for (tag_name, tag_values) in &self.tag_filters { + if tag_name.starts_with('#') && !tag_values.is_empty() { + // 使用 JSON 查询来匹配标签 + sql.push(" AND EXISTS ("); + sql.push("SELECT 1 FROM json_each(tags) AS tag_array "); + sql.push("WHERE json_extract(tag_array.value, '$[0]') = "); + sql.push_bind(&tag_name[1..]); // 去掉 # 前缀 + sql.push(" AND json_extract(tag_array.value, '$[1]') IN ("); + let mut separated = sql.separated(","); + for value in tag_values { + separated.push_bind(value); + } + separated.push_unseparated("))"); + } + } sql.push(" ORDER BY created_at DESC"); @@ -750,23 +852,63 @@ impl Filter { Ok(events) } - fn matchs(&self, event: &NostrEvent) -> bool { + fn matches(&self, event: &NostrEvent) -> bool { + // ID 匹配 if let Some(ids) = &self.ids { - if ids.contains(&event.id) { - return true; + if !ids.is_empty() && !ids.contains(&event.id) { + return false; } } + + // 作者匹配 if let Some(authors) = &self.authors { - if authors.contains(&event.pubkey) { - return true; + if !authors.is_empty() && !authors.contains(&event.pubkey) { + return false; } } + + // 类型匹配 if let Some(kinds) = &self.kinds { - if kinds.contains(&event.kind) { - return true; + if !kinds.is_empty() && !kinds.contains(&event.kind) { + return false; } } - return false; + + // 时间匹配 + if let Some(since) = self.since { + if event.created_at < since { + return false; + } + } + + if let Some(until) = self.until { + if event.created_at > until { + return false; + } + } + + // 标签匹配 + for (tag_name, tag_values) in &self.tag_filters { + if tag_name.starts_with('#') && !tag_values.is_empty() { + let tag_letter = &tag_name[1..]; + let mut found = false; + + for tag in &event.tags { + if tag.get(0).map(|s| s == tag_letter).unwrap_or(false) && tag.len() > 1 { + if tag_values.contains(&tag[1]) { + found = true; + break; + } + } + } + + if !found { + return false; + } + } + } + + return true } } @@ -804,9 +946,18 @@ impl RelayMessage { } } - async fn send_message(message: String, to_client_msg_tx: &mpsc::Sender) { - if let Err(e) = to_client_msg_tx.send(message).await { + async fn send_closed(sub_id: &str, message: String, to_client_msg_tx: &mpsc::Sender) { + let msg = format!("[\"CLOSED\", \"{}\", \"{}\"]", sub_id, message); + if let Err(e) = to_client_msg_tx.send(msg).await { error!("Failed to send message: {}", e); } } + + async fn send_notice(message: String, to_client_msg_tx: &mpsc::Sender) { + let msg = "[\"NOTICE\", \"".to_string() + &message + "\"]"; + if let Err(e) = to_client_msg_tx.send(msg).await { + error!("Failed to send message: {}", e); + } + } + }