This commit is contained in:
2025-06-15 23:24:33 +08:00
parent 1bc3afbff1
commit 6d1b9431d3
+196 -45
View File
@@ -18,7 +18,7 @@ use std::time::SystemTime;
use std::time::UNIX_EPOCH;
use std::{collections::HashMap, sync::Arc};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::mpsc;
use tokio::sync::{mpsc, RwLock};
use tokio::{
net::{TcpListener, TcpStream},
sync::broadcast,
@@ -36,7 +36,7 @@ const BROADCAST_CHANNEL_SIZE: usize = 100;
const CLIENT_CHANNEL_SIZE: usize = 32;
const MAX_SUBSCRIPTIONS: usize = 20;
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct NostrEvent {
pub id: String, // 32 bytes hex string of sha256
pub pubkey: String, // 32 bytes hex pubkey
@@ -79,7 +79,7 @@ pub enum ClientMessage {
CLOSE { sub_id: String },
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Filter {
pub ids: Option<Vec<String>>, // event ids 列表
pub authors: Option<Vec<String>>, // pubkeys 列表,小写字符串
@@ -116,6 +116,16 @@ struct ServerInfo {
version: &'static str,
}
#[derive(Debug, Clone)]
struct Subscription {
filters: Vec<Filter>,
}
struct ClientConnection {
subscriptions: HashMap<String, Subscription>,
sender: mpsc::Sender<String>,
}
#[tokio::main]
async fn main() -> Result<()> {
env_logger::init();
@@ -129,14 +139,17 @@ async fn main() -> Result<()> {
let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
info!("Listening on: {}", addr);
let (event_tx, _) = broadcast::channel::<String>(100);
let clients: Arc<RwLock<HashMap<Uuid, Arc<RwLock<ClientConnection>>>>> =
Arc::new(RwLock::new(HashMap::new()));
while let Ok((stream, client_addr)) = listener.accept().await {
let event_tx = event_tx.clone();
let event_rx = event_tx.subscribe();
let pool = pool.clone();
let clients = clients.clone();
tokio::spawn(async move {
if let Err(e) = handle_connection_multiplex(stream, event_tx, event_rx, pool).await {
if let Err(e) = handle_connection_multiplex(stream, event_tx, event_rx, pool, clients).await {
error!("Error handling connection: {}", e);
}
});
@@ -149,6 +162,7 @@ async fn handle_connection_multiplex(
tx: broadcast::Sender<String>,
rx: broadcast::Receiver<String>,
pool: Arc<SqlitePool>,
clients: Arc<RwLock<HashMap<Uuid, Arc<RwLock<ClientConnection>>>>>,
) -> Result<(), anyhow::Error> {
// 分配足够大的缓冲区,一次读完整个头
let mut buf = vec![0u8; 4096];
@@ -178,7 +192,7 @@ async fn handle_connection_multiplex(
.map(|&v| v.eq_ignore_ascii_case("websocket"))
.unwrap_or(false)
{
handle_ws_connection(stream, tx, rx, pool).await;
handle_ws_connection(stream, tx, rx, pool, clients).await;
} else {
if let Some(accept) = header_map.get("accept") {
if accept.contains("application/json") || accept.contains("application/json") {
@@ -398,6 +412,7 @@ async fn handle_ws_connection(
event_tx: broadcast::Sender<String>,
mut event_rx: broadcast::Receiver<String>,
pool: Arc<SqlitePool>,
clients: Arc<RwLock<HashMap<Uuid, Arc<RwLock<ClientConnection>>>>>,
) {
let ws_stream = match accept_async(stream).await {
Ok(stream) => stream,
@@ -407,18 +422,39 @@ async fn handle_ws_connection(
}
};
let (mut ws_sender, mut ws_reciver) = ws_stream.split();
let (client_tx, mut client_rx) = mpsc::channel::<String>(32);
let (client_tx, mut client_rx) = mpsc::channel::<String>(CLIENT_CHANNEL_SIZE);
let client_id = Uuid::new_v4();
let client_conn = Arc::new(RwLock::new(ClientConnection {
subscriptions: HashMap::new(),
sender: client_tx.clone(),
}));
// 注册客户端
{
let mut clients_map = clients.write().await;
clients_map.insert(client_id, client_conn.clone());
}
info!("Client {} connected", client_id);
let client_conn_for_send = client_conn.clone();
let send_task = tokio::spawn(async move {
loop {
tokio::select! {
//Subscribed events
Ok(msg) = event_rx.recv() => {
if ws_sender
.send(Message::Text(msg.into()))
.await
.is_err()
{
break;
Ok(event_msg) = event_rx.recv() => {
let event: NostrEvent = serde_json::from_str(&event_msg).unwrap();
let conn = client_conn_for_send.read().await;
for (sub_id, subscription) in &conn.subscriptions {
if subscription.filters.iter().any(|filter| filter.matches(&event)) {
let msg = format!(
"[\"EVENT\", \"{}\", {}]",
sub_id,
serde_json::to_string(&event).unwrap()
);
if ws_sender.send(Message::Text(msg.into())).await.is_err() {
return;
}
}
}
}
//Send message to client
@@ -443,7 +479,8 @@ async fn handle_ws_connection(
Ok(Message::Text(text)) => {
debug!("Received text: {}", text);
let pool = pool.clone();
match handle_message(&text, &pool, &client_tx).await {
let client_conn_for_msg = client_conn.clone();
match handle_message(&text, &pool, &client_tx, &client_id, &client_conn_for_msg, &event_tx).await {
Ok(_) => {
debug!("Message handled successfully");
}
@@ -457,7 +494,7 @@ async fn handle_ws_connection(
}
}
Ok(Message::Close(_)) => {
info!("Client disconnected");
info!("Client {} disconnected", client_id);
break;
}
_ => {}
@@ -469,6 +506,9 @@ async fn handle_message(
text: &str,
pool: &SqlitePool,
to_client_msg_tx: &mpsc::Sender<String>,
client_id: &Uuid,
client_conn: &Arc<RwLock<ClientConnection>>,
event_tx: &broadcast::Sender<String>,
) -> Result<(), anyhow::Error> {
let v = serde_json::from_str(text).map_err(|e| anyhow!("Invalid JSON: {}", e))?;
@@ -487,7 +527,7 @@ async fn handle_message(
{
"REQ" => {
if arr.len() < 3 {
RelayMessage::send_message(
RelayMessage::send_notice(
"Not enough array elements".to_string(),
to_client_msg_tx,
)
@@ -499,13 +539,50 @@ async fn handle_message(
.and_then(Value::as_str)
.ok_or_else(|| anyhow!("REQ missing sub_id"))?
.to_string();
let filters: Filter = serde_json::from_value(arr[2].clone())
.map_err(|e| anyhow!("Filter parse error: {}", e))?;
debug!("REQ subscription: {}, filters: {:?}", sub_id, filters);
let r = filters.select(pool).await?;
for event in r {
RelayMessage::send_event(&event, &sub_id, to_client_msg_tx).await;
// Check if the client has exceeded the maximum number of subscriptions
{
let conn = client_conn.read().await;
if conn.subscriptions.len() >= MAX_SUBSCRIPTIONS {
RelayMessage::send_closed(
&sub_id,
format!("Maximum subscriptions ({}) exceeded", MAX_SUBSCRIPTIONS),
to_client_msg_tx
).await;
return Ok(());
}
}
let mut filters = Vec::new();
for i in 2..arr.len() {
if filters.len() >= MAX_FILTERS_PER_REQ {
RelayMessage::send_closed(
&sub_id,
format!("Maximum filters ({}) exceeded", MAX_FILTERS_PER_REQ),
to_client_msg_tx
).await;
return Ok(());
}
let filter: Filter = serde_json::from_value(arr[i].clone())
.map_err(|e| anyhow!("Filter parse error: {}", e))?;
filters.push(filter);
}
debug!("REQ subscription: {}, filters: {:?}", sub_id, filters);
{
let mut conn = client_conn.write().await;
conn.subscriptions.insert(sub_id.clone(), Subscription { filters: filters.clone() });
}
// 查询历史事件
for filter in &filters {
let events = filter.select(pool).await?;
for event in events {
RelayMessage::send_event(&event, &sub_id, to_client_msg_tx).await;
}
}
// 发送 EOSE
RelayMessage::send_eose(&sub_id, to_client_msg_tx).await;
Ok(())
}
@@ -517,21 +594,29 @@ async fn handle_message(
if event.verify() {
let pool = pool.clone();
let to_client_msg_tx = to_client_msg_tx.clone();
let event_tx = event_tx.clone();
let event = event.clone();
tokio::spawn(async move {
if let Err(e) = event.save(&pool).await {
RelayMessage::send_ok(&event, false, e.to_string(), &to_client_msg_tx)
.await;
error!("Failed to save event: {}", e);
} else {
RelayMessage::send_ok(&event, true, "Saved".to_string(), &to_client_msg_tx)
.await;
match event.save(&pool).await {
Ok(_) => {
RelayMessage::send_ok(&event, true, "Saved".to_string(), &to_client_msg_tx).await;
// 广播事件给所有客户端
if let Err(e) = event_tx.send(serde_json::to_string(&event).unwrap()) {
error!("Failed to broadcast event: {}", e);
}
}
Err(e) => {
RelayMessage::send_ok(&event, false, e.to_string(), &to_client_msg_tx).await;
error!("Failed to save event: {}", e);
}
}
});
} else {
RelayMessage::send_ok(
&event,
false,
"Invalid signature".to_string(),
"Failed to save event".to_string(),
to_client_msg_tx,
)
.await;
@@ -545,6 +630,12 @@ async fn handle_message(
.and_then(Value::as_str)
.ok_or_else(|| anyhow!("CLOSE missing sub_id"))?
.to_string();
{
let mut conn = client_conn.write().await;
conn.subscriptions.remove(&sub_id);
}
debug!("CLOSE subscription: {}", sub_id);
Ok(())
@@ -717,10 +808,21 @@ impl Filter {
sql.push(" AND created_at <= ").push_bind(until as i64);
}
// TODO: 实现标签过滤
// for (tag_name, tag_values) in &self.tag_filters {
// // 实现 #e, #p 等标签过滤
// }
for (tag_name, tag_values) in &self.tag_filters {
if tag_name.starts_with('#') && !tag_values.is_empty() {
// 使用 JSON 查询来匹配标签
sql.push(" AND EXISTS (");
sql.push("SELECT 1 FROM json_each(tags) AS tag_array ");
sql.push("WHERE json_extract(tag_array.value, '$[0]') = ");
sql.push_bind(&tag_name[1..]); // 去掉 # 前缀
sql.push(" AND json_extract(tag_array.value, '$[1]') IN (");
let mut separated = sql.separated(",");
for value in tag_values {
separated.push_bind(value);
}
separated.push_unseparated("))");
}
}
sql.push(" ORDER BY created_at DESC");
@@ -750,23 +852,63 @@ impl Filter {
Ok(events)
}
fn matchs(&self, event: &NostrEvent) -> bool {
fn matches(&self, event: &NostrEvent) -> bool {
// ID 匹配
if let Some(ids) = &self.ids {
if ids.contains(&event.id) {
return true;
if !ids.is_empty() && !ids.contains(&event.id) {
return false;
}
}
// 作者匹配
if let Some(authors) = &self.authors {
if authors.contains(&event.pubkey) {
return true;
if !authors.is_empty() && !authors.contains(&event.pubkey) {
return false;
}
}
// 类型匹配
if let Some(kinds) = &self.kinds {
if kinds.contains(&event.kind) {
return true;
if !kinds.is_empty() && !kinds.contains(&event.kind) {
return false;
}
}
return false;
// 时间匹配
if let Some(since) = self.since {
if event.created_at < since {
return false;
}
}
if let Some(until) = self.until {
if event.created_at > until {
return false;
}
}
// 标签匹配
for (tag_name, tag_values) in &self.tag_filters {
if tag_name.starts_with('#') && !tag_values.is_empty() {
let tag_letter = &tag_name[1..];
let mut found = false;
for tag in &event.tags {
if tag.get(0).map(|s| s == tag_letter).unwrap_or(false) && tag.len() > 1 {
if tag_values.contains(&tag[1]) {
found = true;
break;
}
}
}
if !found {
return false;
}
}
}
return true
}
}
@@ -804,9 +946,18 @@ impl RelayMessage {
}
}
async fn send_message(message: String, to_client_msg_tx: &mpsc::Sender<String>) {
if let Err(e) = to_client_msg_tx.send(message).await {
async fn send_closed(sub_id: &str, message: String, to_client_msg_tx: &mpsc::Sender<String>) {
let msg = format!("[\"CLOSED\", \"{}\", \"{}\"]", sub_id, message);
if let Err(e) = to_client_msg_tx.send(msg).await {
error!("Failed to send message: {}", e);
}
}
async fn send_notice(message: String, to_client_msg_tx: &mpsc::Sender<String>) {
let msg = "[\"NOTICE\", \"".to_string() + &message + "\"]";
if let Err(e) = to_client_msg_tx.send(msg).await {
error!("Failed to send message: {}", e);
}
}
}