Fix: Remove Disconnected Client

This commit is contained in:
2025-08-03 16:11:18 +08:00
parent 008f4b2ede
commit 45f978af9e
+262 -183
View File
@@ -4,6 +4,7 @@ use futures_util::{SinkExt, StreamExt};
use hex;
use httparse::{EMPTY_HEADER, Request};
use log::{Level, debug, error, info, log_enabled};
use once_cell::sync::Lazy;
use secp256k1::{Message as SecpMessage, PublicKey, Secp256k1, ecdsa::Signature};
use serde::de;
use serde::{Deserialize, Serialize};
@@ -15,20 +16,22 @@ use sqlx::Row;
use sqlx::query;
use sqlx::sqlite::SqlitePool;
use std::env;
use tokio::sync::Mutex;
use std::time::Duration;
use std::time::SystemTime;
use std::time::UNIX_EPOCH;
use std::{collections::{HashMap, HashSet}, sync::Arc};
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::{mpsc, RwLock};
use tokio::sync::Mutex;
use tokio::sync::{RwLock, mpsc};
use tokio::{
net::{TcpListener, TcpStream},
sync::broadcast,
};
use tokio_tungstenite::{WebSocketStream, accept_async, tungstenite::protocol::Message};
use uuid::Uuid;
use once_cell::sync::Lazy;
const DEFAULT_BIND_ADDR: &str = "0.0.0.0:8080";
const DEFAULT_DB_PATH: &str = "nostr.db";
@@ -135,37 +138,35 @@ struct ClientConnection {
pubkey: Option<String>,
}
static SERVER_INFO: Lazy<ServerInfo> = Lazy::new(|| {
ServerInfo {
contact: "https://www.moec.top/",
description: "Powered by laoXong.",
limitation: Limitation {
max_event_tags: 5000,
max_event_time_newer_than_now: 900,
max_event_time_older_than_now: 315576000,
max_filters: 100,
max_limit: 500,
max_message_length: 524288,
max_subid_length: 100,
max_subscriptions: 20,
min_prefix: 10,
},
name: "A rust nostr relay by laoXong",
pubkey: "63abd4f817e39cca4e6abb6e6cf3e133bb718cf8ec28b38c1645e84d7a6190c6",
software: "https://git.moe.gift/laoxong/nostr-relay",
supported_nips: vec![1, 2, 5, 65],
version: env!("CARGO_PKG_VERSION"),
auth_required: true,
}
static SERVER_INFO: Lazy<ServerInfo> = Lazy::new(|| ServerInfo {
contact: "https://www.moec.top/",
description: "Powered by laoXong.",
limitation: Limitation {
max_event_tags: 5000,
max_event_time_newer_than_now: 900,
max_event_time_older_than_now: 315576000,
max_filters: 100,
max_limit: 500,
max_message_length: 524288,
max_subid_length: 100,
max_subscriptions: 20,
min_prefix: 10,
},
name: "A rust nostr relay by laoXong",
pubkey: "63abd4f817e39cca4e6abb6e6cf3e133bb718cf8ec28b38c1645e84d7a6190c6",
software: "https://git.moe.gift/laoxong/nostr-relay",
supported_nips: vec![1, 2, 5, 42, 65],
version: env!("CARGO_PKG_VERSION"),
auth_required: env::var("AUTH_REQUIRED").unwrap_or_else(|_| "False".to_string()) == "True",
});
#[tokio::main]
async fn main() -> Result<()> {
env_logger::init();
let db_path = env::var("DB_PATH").unwrap_or_else(|_| "nostr.db".to_string());
let db_url = format!("sqlite://{}", db_path);
let pool = SqlitePool::connect(&db_url).await?;
info!("Database pool connected successfully");
init_database(&pool).await?;
let pool = Arc::new(pool);
info!("Connected to SQLite");
@@ -173,13 +174,13 @@ async fn main() -> Result<()> {
let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
info!("Listening on: {}", addr);
let (event_tx, _) = broadcast::channel::<String>(100);
let clients: Arc<RwLock<HashMap<Uuid, Arc<RwLock<ClientConnection>>>>> =
let clients: Arc<RwLock<HashMap<Uuid, Arc<RwLock<ClientConnection>>>>> =
Arc::new(RwLock::new(HashMap::new()));
let accounts = get_trust_accounts(&pool).await;
let trust_accounts: HashSet<_> = accounts.into_iter().collect();
let trust_accounts = Arc::new(RwLock::new(trust_accounts));
while let Ok((stream, client_addr)) = listener.accept().await {
let event_tx = event_tx.clone();
let event_rx = event_tx.subscribe();
@@ -187,7 +188,16 @@ async fn main() -> Result<()> {
let clients = clients.clone();
let trust_accounts = trust_accounts.clone();
tokio::spawn(async move {
if let Err(e) = handle_connection_multiplex(stream, event_tx, event_rx, pool, clients, trust_accounts).await {
if let Err(e) = handle_connection_multiplex(
stream,
event_tx,
event_rx,
pool,
clients,
trust_accounts,
)
.await
{
error!("Error handling connection: {}", e);
}
});
@@ -201,7 +211,7 @@ async fn handle_connection_multiplex(
rx: broadcast::Receiver<String>,
pool: Arc<SqlitePool>,
clients: Arc<RwLock<HashMap<Uuid, Arc<RwLock<ClientConnection>>>>>,
trust_accounts: Arc<RwLock<HashSet<String>>>
trust_accounts: Arc<RwLock<HashSet<String>>>,
) -> Result<(), anyhow::Error> {
// 分配足够大的缓冲区,一次读完整个头
let mut buf = vec![0u8; 4096];
@@ -253,7 +263,6 @@ async fn handle_http_info(mut stream: TcpStream) -> Result<(), anyhow::Error> {
let mut buffer = vec![0; 1024];
stream.read(&mut buffer).await?;
let json = serde_json::to_string(&*SERVER_INFO).expect("Failed to serialize server info");
let response = format!(
"HTTP/1.1 200 OK\r\n\
@@ -448,7 +457,10 @@ async fn handle_ws_connection(
let client_id = Uuid::new_v4();
let client_conn = Arc::new(RwLock::new(ClientConnection {
id: client_id.clone(),
connected_at: SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(),
connected_at: SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs(),
subscriptions: HashMap::new(),
sender: client_tx.clone(),
authenticated: false,
@@ -459,7 +471,7 @@ async fn handle_ws_connection(
let mut clients_map = clients.write().await;
clients_map.insert(client_id, client_conn.clone());
}
info!("Client {} connected", client_id);
let auth_msg = format!("[\"AUTH\", \"{}\"]", client_id);
if let Err(e) = ws_sender.send(Message::Text(auth_msg.into())).await {
@@ -477,12 +489,12 @@ async fn handle_ws_connection(
for (sub_id, subscription) in &conn.subscriptions {
if subscription.filters.iter().any(|filter| filter.matches(&event)) {
let msg = format!(
"[\"EVENT\", \"{}\", {}]",
sub_id,
"[\"EVENT\", \"{}\", {}]",
sub_id,
serde_json::to_string(&event).unwrap()
);
if ws_sender.send(Message::Text(msg.into())).await.is_err() {
return;
break;
}
}
}
@@ -511,21 +523,40 @@ async fn handle_ws_connection(
let pool = pool.clone();
let client_conn_for_msg = client_conn.clone();
let trust_accounts = trust_accounts.clone();
match handle_message(&text, &pool, &client_tx, &client_conn_for_msg, &event_tx, trust_accounts).await {
match handle_message(
&text,
&pool,
&client_tx,
&client_conn_for_msg,
&event_tx,
trust_accounts,
)
.await
{
Ok(_) => {
debug!("Message handled successfully");
}
Err(e) => {
error!("Error handling message: {}", e);
client_tx
.send(format!("Error: {}", e))
.await
.unwrap();
client_tx.send(format!("Error: {}", e)).await.unwrap();
}
}
}
Ok(Message::Close(_)) => {
info!("Client {} disconnected", client_id);
{
let mut clients_map = clients.write().await;
clients_map.remove(&client_id);
}
send_task.abort();
break;
}
Err(_e) => {
{
let mut clients_map = clients.write().await;
clients_map.remove(&client_id);
}
send_task.abort();
break;
}
_ => {}
@@ -576,10 +607,11 @@ async fn handle_message(
let conn = client_conn.read().await;
if conn.subscriptions.len() >= MAX_SUBSCRIPTIONS {
RelayMessage::send_closed(
&sub_id,
format!("Maximum subscriptions ({}) exceeded", MAX_SUBSCRIPTIONS),
to_client_msg_tx
).await;
&sub_id,
format!("Maximum subscriptions ({}) exceeded", MAX_SUBSCRIPTIONS),
to_client_msg_tx,
)
.await;
return Ok(());
}
}
@@ -587,24 +619,29 @@ async fn handle_message(
for i in 2..arr.len() {
if filters.len() >= MAX_FILTERS_PER_REQ {
RelayMessage::send_closed(
&sub_id,
format!("Maximum filters ({}) exceeded", MAX_FILTERS_PER_REQ),
to_client_msg_tx
).await;
&sub_id,
format!("Maximum filters ({}) exceeded", MAX_FILTERS_PER_REQ),
to_client_msg_tx,
)
.await;
return Ok(());
}
let filter: Filter = serde_json::from_value(arr[i].clone())
.map_err(|e| anyhow!("Filter parse error: {}", e))?;
filters.push(filter);
}
debug!("REQ subscription: {}, filters: {:?}", sub_id, filters);
{
let mut conn = client_conn.write().await;
conn.subscriptions.insert(sub_id.clone(), Subscription { filters: filters.clone() });
conn.subscriptions.insert(
sub_id.clone(),
Subscription {
filters: filters.clone(),
},
);
}
// 查询历史事件
for filter in &filters {
let events = filter.select(pool).await?;
@@ -612,7 +649,7 @@ async fn handle_message(
RelayMessage::send_event(&event, &sub_id, to_client_msg_tx).await;
}
}
// 发送 EOSE
RelayMessage::send_eose(&sub_id, to_client_msg_tx).await;
Ok(())
@@ -637,7 +674,10 @@ async fn handle_message(
&to_client_msg_tx,
)
.await;
debug!("Sending message to Client {}:restricted: client not in the trust list", conn.id);
debug!(
"Sending message to Client {}:restricted: client not in the trust list",
conn.id
);
return Ok(());
}
if !conn.authenticated {
@@ -654,23 +694,31 @@ async fn handle_message(
tokio::spawn(async move {
match event.save(&pool).await {
Ok(_) => {
RelayMessage::send_ok(&event, true, "Saved".to_string(), &to_client_msg_tx).await;
RelayMessage::send_ok(
&event,
true,
"Saved".to_string(),
&to_client_msg_tx,
)
.await;
// 广播事件给所有客户端
if let Err(e) = event_tx.send(serde_json::to_string(&event).unwrap()) {
error!("Failed to broadcast event: {}", e);
}
//Update trust list
if event.kind == 3 && event.pubkey == SERVER_INFO.pubkey{
if event.kind == 3 && event.pubkey == SERVER_INFO.pubkey {
let mut ts = trust_accounts.write().await;
let new_trust_accounts = extract_p_tags_from_vec(&event.tags);
let mut new_trust_accounts: HashSet<String> = new_trust_accounts.into_iter().collect();
let mut new_trust_accounts: HashSet<String> =
new_trust_accounts.into_iter().collect();
new_trust_accounts.insert(SERVER_INFO.pubkey.to_string());
*ts = new_trust_accounts;
debug!("Trust list updated: {}", ts.len());
}
}
Err(e) => {
RelayMessage::send_ok(&event, false, e.to_string(), &to_client_msg_tx).await;
RelayMessage::send_ok(&event, false, e.to_string(), &to_client_msg_tx)
.await;
error!("Failed to save event: {}", e);
}
}
@@ -698,7 +746,7 @@ async fn handle_message(
let mut conn = client_conn.write().await;
conn.subscriptions.remove(&sub_id);
}
debug!("CLOSE subscription: {}", sub_id);
Ok(())
@@ -713,11 +761,13 @@ async fn handle_message(
.await;
return Ok(());
}
let auth_event: NostrEvent = serde_json::from_value(arr[1].clone())
.map_err(|e| anyhow!("AUTH event parse error: {}", e))?;
auth_event.handle_auth_event(client_conn, to_client_msg_tx).await?;
auth_event
.handle_auth_event(client_conn, to_client_msg_tx)
.await?;
Ok(())
}
other => {
@@ -772,27 +822,66 @@ impl NostrEvent {
}
fn verify(&self) -> bool {
// 1. 验证ID
let mut hasher = Sha256::new();
hasher.update(&self.serialize());
let result = hasher.finalize();
if hex::encode(result) != self.id {
return false;
}
// 2. 验证时间戳
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
if self.created_at > now + 900 {
return false;
}
// 验证标签数量
// 3. 验证标签数量
if self.tags.len() > MAX_EVENT_TAGS as usize {
return false;
}
return true;
// 4. 解析公钥
let pubkey_bytes: Vec<u8> = match hex::decode(&self.pubkey) {
Ok(bytes) => bytes,
Err(_) => return false,
};
let public_key: PublicKey = match PublicKey::from_slice(&pubkey_bytes) {
Ok(key) => key,
Err(_) => return false,
};
// 5. 解析签名
let sig_bytes: Vec<u8> = match hex::decode(&self.sig) {
Ok(bytes) => bytes,
Err(_) => return false,
};
let signature: Signature = match Signature::from_compact(&sig_bytes) {
Ok(sig) => sig,
Err(_) => return false,
};
let event_hash: [u8; 32] = match hex::decode(&self.id) {
Ok(bytes) => {
let mut hash = [0u8; 32];
hash.copy_from_slice(&bytes);
hash
}
Err(_) => return false,
};
// 6. 创建消息对象
let message: SecpMessage = secp256k1::Message::from_digest(event_hash);
// 7. 验证签名
let secp = Secp256k1::verification_only();
match secp.verify_ecdsa(message, &signature, &public_key) {
Ok(_) => true,
Err(_) => false,
}
}
async fn save(&self, pool: &SqlitePool) -> Result<(), sqlx::Error> {
@@ -837,107 +926,102 @@ impl NostrEvent {
Ok(())
}
async fn handle_auth_event(
&self,
client_conn: &Arc<RwLock<ClientConnection>>,
to_client_msg_tx: &mpsc::Sender<String>,
) -> Result<(), anyhow::Error> {
// 验证认证事件
if self.kind != 22242 {
RelayMessage::send_ok(
&self,
false,
"AUTH event must be kind 22242".to_string(),
to_client_msg_tx,
)
.await;
return Ok(());
}
// 验证事件签名
if !self.verify() {
RelayMessage::send_notice(
"Invalid AUTH event signature".to_string(),
to_client_msg_tx,
)
.await;
return Ok(());
}
// 检查挑战
let mut relay_url = None;
let mut challenge = None;
for tag in &self.tags {
if tag.len() >= 2 {
match tag[0].as_str() {
"relay" => relay_url = Some(&tag[1]),
"challenge" => challenge = Some(&tag[1]),
_ => {}
}
}
}
let challenge = match challenge {
Some(c) => c,
None => {
RelayMessage::send_notice(
"AUTH event missing challenge tag".to_string(),
&self,
client_conn: &Arc<RwLock<ClientConnection>>,
to_client_msg_tx: &mpsc::Sender<String>,
) -> Result<(), anyhow::Error> {
// 验证认证事件
if self.kind != 22242 {
RelayMessage::send_ok(
&self,
false,
"AUTH event must be kind 22242".to_string(),
to_client_msg_tx,
)
.await;
return Ok(());
}
};
// 验证挑战是否匹配
let is_valid_challenge = {
let conn = client_conn.read().await;
if conn.id.to_string() == *challenge {
match SystemTime::now().duration_since(UNIX_EPOCH + Duration::from_secs(conn.connected_at)) {
Ok(elapsed) => elapsed <= Duration::from_secs(15 * 60),
Err(_) => false,
}
} else {
false
// 验证事件签名
if !self.verify() {
RelayMessage::send_notice("Invalid AUTH event signature".to_string(), to_client_msg_tx)
.await;
return Ok(());
}
};
if !is_valid_challenge {
RelayMessage::send_notice(
"Invalid or expired challenge".to_string(),
to_client_msg_tx,
)
.await;
return Ok(());
}
// 验证时间戳
let now = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs();
if self.created_at < now - 600 || self.created_at > now + 60 {
RelayMessage::send_notice(
"AUTH event timestamp out of acceptable range".to_string(),
to_client_msg_tx,
)
.await;
return Ok(());
}
// 认证成功,更新客户端状态
{
let mut conn = client_conn.write().await;
conn.authenticated = true;
conn.pubkey = Some(self.pubkey.clone());
}
info!("Client authenticated with pubkey: {}", self.pubkey);
RelayMessage::send_notice(
"Authentication successful".to_string(),
to_client_msg_tx,
)
.await;
Ok(())
// 检查挑战
let mut relay_url = None;
let mut challenge = None;
for tag in &self.tags {
if tag.len() >= 2 {
match tag[0].as_str() {
"relay" => relay_url = Some(&tag[1]),
"challenge" => challenge = Some(&tag[1]),
_ => {}
}
}
}
let challenge = match challenge {
Some(c) => c,
None => {
RelayMessage::send_notice(
"AUTH event missing challenge tag".to_string(),
to_client_msg_tx,
)
.await;
return Ok(());
}
};
// 验证挑战是否匹配
let is_valid_challenge = {
let conn = client_conn.read().await;
if conn.id.to_string() == *challenge {
match SystemTime::now()
.duration_since(UNIX_EPOCH + Duration::from_secs(conn.connected_at))
{
Ok(elapsed) => elapsed <= Duration::from_secs(15 * 60),
Err(_) => false,
}
} else {
false
}
};
if !is_valid_challenge {
RelayMessage::send_notice("Invalid or expired challenge".to_string(), to_client_msg_tx)
.await;
return Ok(());
}
// 验证时间戳
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
if self.created_at < now - 600 || self.created_at > now + 60 {
RelayMessage::send_notice(
"AUTH event timestamp out of acceptable range".to_string(),
to_client_msg_tx,
)
.await;
return Ok(());
}
// 认证成功,更新客户端状态
{
let mut conn = client_conn.write().await;
conn.authenticated = true;
conn.pubkey = Some(self.pubkey.clone());
}
info!("Client authenticated with pubkey: {}", self.pubkey);
RelayMessage::send_notice("Authentication successful".to_string(), to_client_msg_tx).await;
Ok(())
}
}
@@ -998,7 +1082,7 @@ impl Filter {
sql.push(" AND EXISTS (");
sql.push("SELECT 1 FROM json_each(tags) AS tag_array ");
sql.push("WHERE json_extract(tag_array.value, '$[0]') = ");
sql.push_bind(&tag_name[1..]); // 去掉 # 前缀
sql.push_bind(&tag_name[1..]); // 去掉 # 前缀
sql.push(" AND json_extract(tag_array.value, '$[1]') IN (");
let mut separated = sql.separated(",");
for value in tag_values {
@@ -1043,40 +1127,40 @@ impl Filter {
return false;
}
}
// 作者匹配
if let Some(authors) = &self.authors {
if !authors.is_empty() && !authors.contains(&event.pubkey) {
return false;
}
}
// 类型匹配
if let Some(kinds) = &self.kinds {
if !kinds.is_empty() && !kinds.contains(&event.kind) {
return false;
}
}
// 时间匹配
if let Some(since) = self.since {
if event.created_at < since {
return false;
}
}
if let Some(until) = self.until {
if event.created_at > until {
return false;
}
}
// 标签匹配
for (tag_name, tag_values) in &self.tag_filters {
if tag_name.starts_with('#') && !tag_values.is_empty() {
let tag_letter = &tag_name[1..];
let mut found = false;
for tag in &event.tags {
if tag.get(0).map(|s| s == tag_letter).unwrap_or(false) && tag.len() > 1 {
if tag_values.contains(&tag[1]) {
@@ -1085,14 +1169,14 @@ impl Filter {
}
}
}
if !found {
return false;
}
}
}
return true
return true;
}
}
@@ -1143,19 +1227,14 @@ impl RelayMessage {
error!("Failed to send message: {}", e);
}
}
}
async fn get_trust_accounts(pool: &SqlitePool) -> Vec<String> {
let pubkey = SERVER_INFO.pubkey;
let sql = "SELECT tags FROM events WHERE kind = 3 AND pubkey = ? ORDER BY created_at DESC LIMIT 1";
let row = match sqlx::query(sql)
.bind(pubkey)
.fetch_optional(pool)
.await
{
let sql =
"SELECT tags FROM events WHERE kind = 3 AND pubkey = ? ORDER BY created_at DESC LIMIT 1";
let row = match sqlx::query(sql).bind(pubkey).fetch_optional(pool).await {
Ok(row) => row,
Err(e) => {
error!("Failed to execute query: {}", e);
@@ -1173,8 +1252,8 @@ async fn get_trust_accounts(pool: &SqlitePool) -> Vec<String> {
fn extract_p_tags_from_json(tags_json: &str) -> Result<Vec<String>, serde_json::Error> {
let parsed: Value = serde_json::from_str(tags_json)?;
let mut trust_accounts:Vec<String> = parsed
let mut trust_accounts: Vec<String> = parsed
.as_array()
.unwrap_or(&vec![])
.iter()
@@ -1204,4 +1283,4 @@ fn extract_p_tags_from_vec(tags: &[Vec<String>]) -> Vec<String> {
}
})
.collect()
}
}