Add: Whitelist AUTH
This commit is contained in:
+106
-44
@@ -15,10 +15,11 @@ use sqlx::Row;
|
||||
use sqlx::query;
|
||||
use sqlx::sqlite::SqlitePool;
|
||||
use std::env;
|
||||
use tokio::sync::Mutex;
|
||||
use std::time::Duration;
|
||||
use std::time::SystemTime;
|
||||
use std::time::UNIX_EPOCH;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use std::{collections::{HashMap, HashSet}, sync::Arc};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
use tokio::{
|
||||
@@ -154,7 +155,7 @@ static SERVER_INFO: Lazy<ServerInfo> = Lazy::new(|| {
|
||||
software: "https://git.moe.gift/laoxong/nostr-relay",
|
||||
supported_nips: vec![1, 2, 5, 65],
|
||||
version: env!("CARGO_PKG_VERSION"),
|
||||
auth_required: false,
|
||||
auth_required: true,
|
||||
}
|
||||
});
|
||||
|
||||
@@ -175,14 +176,18 @@ async fn main() -> Result<()> {
|
||||
|
||||
let clients: Arc<RwLock<HashMap<Uuid, Arc<RwLock<ClientConnection>>>>> =
|
||||
Arc::new(RwLock::new(HashMap::new()));
|
||||
let accounts = get_trust_accounts(&pool).await;
|
||||
let trust_accounts: HashSet<_> = accounts.into_iter().collect();
|
||||
let trust_accounts = Arc::new(RwLock::new(trust_accounts));
|
||||
|
||||
while let Ok((stream, client_addr)) = listener.accept().await {
|
||||
let event_tx = event_tx.clone();
|
||||
let event_rx = event_tx.subscribe();
|
||||
let pool = pool.clone();
|
||||
let clients = clients.clone();
|
||||
|
||||
let trust_accounts = trust_accounts.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle_connection_multiplex(stream, event_tx, event_rx, pool, clients).await {
|
||||
if let Err(e) = handle_connection_multiplex(stream, event_tx, event_rx, pool, clients, trust_accounts).await {
|
||||
error!("Error handling connection: {}", e);
|
||||
}
|
||||
});
|
||||
@@ -196,6 +201,7 @@ async fn handle_connection_multiplex(
|
||||
rx: broadcast::Receiver<String>,
|
||||
pool: Arc<SqlitePool>,
|
||||
clients: Arc<RwLock<HashMap<Uuid, Arc<RwLock<ClientConnection>>>>>,
|
||||
trust_accounts: Arc<RwLock<HashSet<String>>>
|
||||
) -> Result<(), anyhow::Error> {
|
||||
// 分配足够大的缓冲区,一次读完整个头
|
||||
let mut buf = vec![0u8; 4096];
|
||||
@@ -225,7 +231,7 @@ async fn handle_connection_multiplex(
|
||||
.map(|&v| v.eq_ignore_ascii_case("websocket"))
|
||||
.unwrap_or(false)
|
||||
{
|
||||
handle_ws_connection(stream, tx, rx, pool, clients).await;
|
||||
handle_ws_connection(stream, tx, rx, pool, clients, trust_accounts).await;
|
||||
} else {
|
||||
if let Some(accept) = header_map.get("accept") {
|
||||
if accept.contains("application/json") || accept.contains("application/json") {
|
||||
@@ -427,6 +433,7 @@ async fn handle_ws_connection(
|
||||
mut event_rx: broadcast::Receiver<String>,
|
||||
pool: Arc<SqlitePool>,
|
||||
clients: Arc<RwLock<HashMap<Uuid, Arc<RwLock<ClientConnection>>>>>,
|
||||
trust_accounts: Arc<RwLock<HashSet<String>>>,
|
||||
) {
|
||||
let ws_stream = match accept_async(stream).await {
|
||||
Ok(stream) => stream,
|
||||
@@ -503,7 +510,8 @@ async fn handle_ws_connection(
|
||||
debug!("Received text: {}", text);
|
||||
let pool = pool.clone();
|
||||
let client_conn_for_msg = client_conn.clone();
|
||||
match handle_message(&text, &pool, &client_tx, &client_conn_for_msg, &event_tx).await {
|
||||
let trust_accounts = trust_accounts.clone();
|
||||
match handle_message(&text, &pool, &client_tx, &client_conn_for_msg, &event_tx, trust_accounts).await {
|
||||
Ok(_) => {
|
||||
debug!("Message handled successfully");
|
||||
}
|
||||
@@ -531,6 +539,7 @@ async fn handle_message(
|
||||
to_client_msg_tx: &mpsc::Sender<String>,
|
||||
client_conn: &Arc<RwLock<ClientConnection>>,
|
||||
event_tx: &broadcast::Sender<String>,
|
||||
trust_accounts: Arc<RwLock<HashSet<String>>>,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let v = serde_json::from_str(text).map_err(|e| anyhow!("Invalid JSON: {}", e))?;
|
||||
|
||||
@@ -548,19 +557,6 @@ async fn handle_message(
|
||||
.ok_or_else(|| anyhow!("First element must be a string"))?
|
||||
{
|
||||
"REQ" => {
|
||||
if SERVER_INFO.auth_required {
|
||||
let conn = client_conn.read().await;
|
||||
if !conn.authenticated {
|
||||
RelayMessage::send_closed(
|
||||
&conn.id.to_string(),
|
||||
"auth-required:Authentication required".to_string(),
|
||||
to_client_msg_tx,
|
||||
)
|
||||
.await;
|
||||
debug!("Sending message to Client {}: not authenticated", conn.id);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
if arr.len() < 3 {
|
||||
RelayMessage::send_notice(
|
||||
"Not enough array elements".to_string(),
|
||||
@@ -623,19 +619,6 @@ async fn handle_message(
|
||||
}
|
||||
|
||||
"EVENT" => {
|
||||
if !SERVER_INFO.auth_required {
|
||||
let conn = client_conn.read().await;
|
||||
if !conn.authenticated {
|
||||
RelayMessage::send_closed(
|
||||
&conn.id.to_string(),
|
||||
"auth-required:Authentication required".to_string(),
|
||||
to_client_msg_tx,
|
||||
)
|
||||
.await;
|
||||
debug!("Sending message to Client {}: not authenticated", conn.id);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
let event: NostrEvent = serde_json::from_value(arr[1].clone())
|
||||
.map_err(|e| anyhow!("Event parse error: {}", e))?;
|
||||
debug!("EVENT received: {:?}", event);
|
||||
@@ -644,7 +627,30 @@ async fn handle_message(
|
||||
let to_client_msg_tx = to_client_msg_tx.clone();
|
||||
let event_tx = event_tx.clone();
|
||||
let event = event.clone();
|
||||
|
||||
let conn = client_conn.read().await;
|
||||
if SERVER_INFO.auth_required {
|
||||
let trust_accounts = trust_accounts.read().await;
|
||||
if !trust_accounts.contains(&event.pubkey) {
|
||||
RelayMessage::send_closed(
|
||||
&conn.id.to_string(),
|
||||
"restricted: you are not in the trust list".to_string(),
|
||||
&to_client_msg_tx,
|
||||
)
|
||||
.await;
|
||||
debug!("Sending message to Client {}:restricted: client not in the trust list", conn.id);
|
||||
return Ok(());
|
||||
}
|
||||
if !conn.authenticated {
|
||||
RelayMessage::send_closed(
|
||||
&conn.id.to_string(),
|
||||
"auth-required:Authentication required".to_string(),
|
||||
&to_client_msg_tx,
|
||||
)
|
||||
.await;
|
||||
debug!("Sending message to Client {}: not authenticated", conn.id);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
tokio::spawn(async move {
|
||||
match event.save(&pool).await {
|
||||
Ok(_) => {
|
||||
@@ -653,6 +659,15 @@ async fn handle_message(
|
||||
if let Err(e) = event_tx.send(serde_json::to_string(&event).unwrap()) {
|
||||
error!("Failed to broadcast event: {}", e);
|
||||
}
|
||||
//Update trust list
|
||||
if event.kind == 3 && event.pubkey == SERVER_INFO.pubkey{
|
||||
let mut ts = trust_accounts.write().await;
|
||||
let new_trust_accounts = extract_p_tags_from_vec(&event.tags);
|
||||
let mut new_trust_accounts: HashSet<String> = new_trust_accounts.into_iter().collect();
|
||||
new_trust_accounts.insert(SERVER_INFO.pubkey.to_string());
|
||||
*ts = new_trust_accounts;
|
||||
debug!("Trust list updated: {}", ts.len());
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
RelayMessage::send_ok(&event, false, e.to_string(), &to_client_msg_tx).await;
|
||||
@@ -673,17 +688,6 @@ async fn handle_message(
|
||||
}
|
||||
|
||||
"CLOSE" => {
|
||||
if SERVER_INFO.auth_required {
|
||||
if !client_conn.read().await.authenticated {
|
||||
RelayMessage::send_closed(
|
||||
"AUTH",
|
||||
"auth-required:Authentication required".to_string(),
|
||||
to_client_msg_tx,
|
||||
)
|
||||
.await;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
let sub_id = arr
|
||||
.get(1)
|
||||
.and_then(Value::as_str)
|
||||
@@ -1143,3 +1147,61 @@ impl RelayMessage {
|
||||
}
|
||||
|
||||
|
||||
async fn get_trust_accounts(pool: &SqlitePool) -> Vec<String> {
|
||||
let pubkey = SERVER_INFO.pubkey;
|
||||
let sql = "SELECT tags FROM events WHERE kind = 3 AND pubkey = ? ORDER BY created_at DESC LIMIT 1";
|
||||
|
||||
let row = match sqlx::query(sql)
|
||||
.bind(pubkey)
|
||||
.fetch_optional(pool)
|
||||
.await
|
||||
{
|
||||
Ok(row) => row,
|
||||
Err(e) => {
|
||||
error!("Failed to execute query: {}", e);
|
||||
return Vec::new();
|
||||
}
|
||||
};
|
||||
match row {
|
||||
Some(row) => {
|
||||
let tags_json: String = row.get(0);
|
||||
extract_p_tags_from_json(&tags_json).unwrap_or_else(|_| Vec::new())
|
||||
}
|
||||
None => Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn extract_p_tags_from_json(tags_json: &str) -> Result<Vec<String>, serde_json::Error> {
|
||||
let parsed: Value = serde_json::from_str(tags_json)?;
|
||||
|
||||
let mut trust_accounts:Vec<String> = parsed
|
||||
.as_array()
|
||||
.unwrap_or(&vec![])
|
||||
.iter()
|
||||
.filter_map(|item| {
|
||||
item.as_array().and_then(|arr| {
|
||||
// 检查是否是 "p" 标签且至少有两个元素
|
||||
if arr.len() >= 2 && arr[0].as_str() == Some("p") {
|
||||
arr[1].as_str().map(|s| s.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
debug!("Trust accounts: {}", trust_accounts.len());
|
||||
trust_accounts.push(SERVER_INFO.pubkey.clone().to_string());
|
||||
Ok(trust_accounts)
|
||||
}
|
||||
|
||||
fn extract_p_tags_from_vec(tags: &[Vec<String>]) -> Vec<String> {
|
||||
tags.iter()
|
||||
.filter_map(|tag| {
|
||||
if tag.len() >= 2 && tag[0] == "p" {
|
||||
Some(tag[1].clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
Reference in New Issue
Block a user