Add: REQ
This commit is contained in:
+196
-45
@@ -18,7 +18,7 @@ use std::time::SystemTime;
|
||||
use std::time::UNIX_EPOCH;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::{mpsc, RwLock};
|
||||
use tokio::{
|
||||
net::{TcpListener, TcpStream},
|
||||
sync::broadcast,
|
||||
@@ -36,7 +36,7 @@ const BROADCAST_CHANNEL_SIZE: usize = 100;
|
||||
const CLIENT_CHANNEL_SIZE: usize = 32;
|
||||
const MAX_SUBSCRIPTIONS: usize = 20;
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct NostrEvent {
|
||||
pub id: String, // 32 bytes hex string of sha256
|
||||
pub pubkey: String, // 32 bytes hex pubkey
|
||||
@@ -79,7 +79,7 @@ pub enum ClientMessage {
|
||||
CLOSE { sub_id: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
#[derive(Debug, Serialize, Deserialize, Clone)]
|
||||
pub struct Filter {
|
||||
pub ids: Option<Vec<String>>, // event ids 列表
|
||||
pub authors: Option<Vec<String>>, // pubkeys 列表,小写字符串
|
||||
@@ -116,6 +116,16 @@ struct ServerInfo {
|
||||
version: &'static str,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Subscription {
|
||||
filters: Vec<Filter>,
|
||||
}
|
||||
|
||||
struct ClientConnection {
|
||||
subscriptions: HashMap<String, Subscription>,
|
||||
sender: mpsc::Sender<String>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
env_logger::init();
|
||||
@@ -129,14 +139,17 @@ async fn main() -> Result<()> {
|
||||
let listener = TcpListener::bind(&addr).await.expect("Failed to bind");
|
||||
info!("Listening on: {}", addr);
|
||||
let (event_tx, _) = broadcast::channel::<String>(100);
|
||||
|
||||
|
||||
let clients: Arc<RwLock<HashMap<Uuid, Arc<RwLock<ClientConnection>>>>> =
|
||||
Arc::new(RwLock::new(HashMap::new()));
|
||||
while let Ok((stream, client_addr)) = listener.accept().await {
|
||||
let event_tx = event_tx.clone();
|
||||
let event_rx = event_tx.subscribe();
|
||||
let pool = pool.clone();
|
||||
let clients = clients.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle_connection_multiplex(stream, event_tx, event_rx, pool).await {
|
||||
if let Err(e) = handle_connection_multiplex(stream, event_tx, event_rx, pool, clients).await {
|
||||
error!("Error handling connection: {}", e);
|
||||
}
|
||||
});
|
||||
@@ -149,6 +162,7 @@ async fn handle_connection_multiplex(
|
||||
tx: broadcast::Sender<String>,
|
||||
rx: broadcast::Receiver<String>,
|
||||
pool: Arc<SqlitePool>,
|
||||
clients: Arc<RwLock<HashMap<Uuid, Arc<RwLock<ClientConnection>>>>>,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
// 分配足够大的缓冲区,一次读完整个头
|
||||
let mut buf = vec![0u8; 4096];
|
||||
@@ -178,7 +192,7 @@ async fn handle_connection_multiplex(
|
||||
.map(|&v| v.eq_ignore_ascii_case("websocket"))
|
||||
.unwrap_or(false)
|
||||
{
|
||||
handle_ws_connection(stream, tx, rx, pool).await;
|
||||
handle_ws_connection(stream, tx, rx, pool, clients).await;
|
||||
} else {
|
||||
if let Some(accept) = header_map.get("accept") {
|
||||
if accept.contains("application/json") || accept.contains("application/json") {
|
||||
@@ -398,6 +412,7 @@ async fn handle_ws_connection(
|
||||
event_tx: broadcast::Sender<String>,
|
||||
mut event_rx: broadcast::Receiver<String>,
|
||||
pool: Arc<SqlitePool>,
|
||||
clients: Arc<RwLock<HashMap<Uuid, Arc<RwLock<ClientConnection>>>>>,
|
||||
) {
|
||||
let ws_stream = match accept_async(stream).await {
|
||||
Ok(stream) => stream,
|
||||
@@ -407,18 +422,39 @@ async fn handle_ws_connection(
|
||||
}
|
||||
};
|
||||
let (mut ws_sender, mut ws_reciver) = ws_stream.split();
|
||||
let (client_tx, mut client_rx) = mpsc::channel::<String>(32);
|
||||
let (client_tx, mut client_rx) = mpsc::channel::<String>(CLIENT_CHANNEL_SIZE);
|
||||
|
||||
let client_id = Uuid::new_v4();
|
||||
let client_conn = Arc::new(RwLock::new(ClientConnection {
|
||||
subscriptions: HashMap::new(),
|
||||
sender: client_tx.clone(),
|
||||
}));
|
||||
// 注册客户端
|
||||
{
|
||||
let mut clients_map = clients.write().await;
|
||||
clients_map.insert(client_id, client_conn.clone());
|
||||
}
|
||||
|
||||
info!("Client {} connected", client_id);
|
||||
let client_conn_for_send = client_conn.clone();
|
||||
let send_task = tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::select! {
|
||||
//Subscribed events
|
||||
Ok(msg) = event_rx.recv() => {
|
||||
if ws_sender
|
||||
.send(Message::Text(msg.into()))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
Ok(event_msg) = event_rx.recv() => {
|
||||
let event: NostrEvent = serde_json::from_str(&event_msg).unwrap();
|
||||
let conn = client_conn_for_send.read().await;
|
||||
for (sub_id, subscription) in &conn.subscriptions {
|
||||
if subscription.filters.iter().any(|filter| filter.matches(&event)) {
|
||||
let msg = format!(
|
||||
"[\"EVENT\", \"{}\", {}]",
|
||||
sub_id,
|
||||
serde_json::to_string(&event).unwrap()
|
||||
);
|
||||
if ws_sender.send(Message::Text(msg.into())).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
//Send message to client
|
||||
@@ -443,7 +479,8 @@ async fn handle_ws_connection(
|
||||
Ok(Message::Text(text)) => {
|
||||
debug!("Received text: {}", text);
|
||||
let pool = pool.clone();
|
||||
match handle_message(&text, &pool, &client_tx).await {
|
||||
let client_conn_for_msg = client_conn.clone();
|
||||
match handle_message(&text, &pool, &client_tx, &client_id, &client_conn_for_msg, &event_tx).await {
|
||||
Ok(_) => {
|
||||
debug!("Message handled successfully");
|
||||
}
|
||||
@@ -457,7 +494,7 @@ async fn handle_ws_connection(
|
||||
}
|
||||
}
|
||||
Ok(Message::Close(_)) => {
|
||||
info!("Client disconnected");
|
||||
info!("Client {} disconnected", client_id);
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
@@ -469,6 +506,9 @@ async fn handle_message(
|
||||
text: &str,
|
||||
pool: &SqlitePool,
|
||||
to_client_msg_tx: &mpsc::Sender<String>,
|
||||
client_id: &Uuid,
|
||||
client_conn: &Arc<RwLock<ClientConnection>>,
|
||||
event_tx: &broadcast::Sender<String>,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let v = serde_json::from_str(text).map_err(|e| anyhow!("Invalid JSON: {}", e))?;
|
||||
|
||||
@@ -487,7 +527,7 @@ async fn handle_message(
|
||||
{
|
||||
"REQ" => {
|
||||
if arr.len() < 3 {
|
||||
RelayMessage::send_message(
|
||||
RelayMessage::send_notice(
|
||||
"Not enough array elements".to_string(),
|
||||
to_client_msg_tx,
|
||||
)
|
||||
@@ -499,13 +539,50 @@ async fn handle_message(
|
||||
.and_then(Value::as_str)
|
||||
.ok_or_else(|| anyhow!("REQ missing sub_id"))?
|
||||
.to_string();
|
||||
let filters: Filter = serde_json::from_value(arr[2].clone())
|
||||
.map_err(|e| anyhow!("Filter parse error: {}", e))?;
|
||||
debug!("REQ subscription: {}, filters: {:?}", sub_id, filters);
|
||||
let r = filters.select(pool).await?;
|
||||
for event in r {
|
||||
RelayMessage::send_event(&event, &sub_id, to_client_msg_tx).await;
|
||||
|
||||
// Check if the client has exceeded the maximum number of subscriptions
|
||||
{
|
||||
let conn = client_conn.read().await;
|
||||
if conn.subscriptions.len() >= MAX_SUBSCRIPTIONS {
|
||||
RelayMessage::send_closed(
|
||||
&sub_id,
|
||||
format!("Maximum subscriptions ({}) exceeded", MAX_SUBSCRIPTIONS),
|
||||
to_client_msg_tx
|
||||
).await;
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
let mut filters = Vec::new();
|
||||
for i in 2..arr.len() {
|
||||
if filters.len() >= MAX_FILTERS_PER_REQ {
|
||||
RelayMessage::send_closed(
|
||||
&sub_id,
|
||||
format!("Maximum filters ({}) exceeded", MAX_FILTERS_PER_REQ),
|
||||
to_client_msg_tx
|
||||
).await;
|
||||
return Ok(());
|
||||
}
|
||||
let filter: Filter = serde_json::from_value(arr[i].clone())
|
||||
.map_err(|e| anyhow!("Filter parse error: {}", e))?;
|
||||
filters.push(filter);
|
||||
}
|
||||
|
||||
|
||||
debug!("REQ subscription: {}, filters: {:?}", sub_id, filters);
|
||||
{
|
||||
let mut conn = client_conn.write().await;
|
||||
conn.subscriptions.insert(sub_id.clone(), Subscription { filters: filters.clone() });
|
||||
}
|
||||
|
||||
// 查询历史事件
|
||||
for filter in &filters {
|
||||
let events = filter.select(pool).await?;
|
||||
for event in events {
|
||||
RelayMessage::send_event(&event, &sub_id, to_client_msg_tx).await;
|
||||
}
|
||||
}
|
||||
|
||||
// 发送 EOSE
|
||||
RelayMessage::send_eose(&sub_id, to_client_msg_tx).await;
|
||||
Ok(())
|
||||
}
|
||||
@@ -517,21 +594,29 @@ async fn handle_message(
|
||||
if event.verify() {
|
||||
let pool = pool.clone();
|
||||
let to_client_msg_tx = to_client_msg_tx.clone();
|
||||
let event_tx = event_tx.clone();
|
||||
let event = event.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = event.save(&pool).await {
|
||||
RelayMessage::send_ok(&event, false, e.to_string(), &to_client_msg_tx)
|
||||
.await;
|
||||
error!("Failed to save event: {}", e);
|
||||
} else {
|
||||
RelayMessage::send_ok(&event, true, "Saved".to_string(), &to_client_msg_tx)
|
||||
.await;
|
||||
match event.save(&pool).await {
|
||||
Ok(_) => {
|
||||
RelayMessage::send_ok(&event, true, "Saved".to_string(), &to_client_msg_tx).await;
|
||||
// 广播事件给所有客户端
|
||||
if let Err(e) = event_tx.send(serde_json::to_string(&event).unwrap()) {
|
||||
error!("Failed to broadcast event: {}", e);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
RelayMessage::send_ok(&event, false, e.to_string(), &to_client_msg_tx).await;
|
||||
error!("Failed to save event: {}", e);
|
||||
}
|
||||
}
|
||||
});
|
||||
} else {
|
||||
RelayMessage::send_ok(
|
||||
&event,
|
||||
false,
|
||||
"Invalid signature".to_string(),
|
||||
"Failed to save event".to_string(),
|
||||
to_client_msg_tx,
|
||||
)
|
||||
.await;
|
||||
@@ -545,6 +630,12 @@ async fn handle_message(
|
||||
.and_then(Value::as_str)
|
||||
.ok_or_else(|| anyhow!("CLOSE missing sub_id"))?
|
||||
.to_string();
|
||||
|
||||
{
|
||||
let mut conn = client_conn.write().await;
|
||||
conn.subscriptions.remove(&sub_id);
|
||||
}
|
||||
|
||||
debug!("CLOSE subscription: {}", sub_id);
|
||||
|
||||
Ok(())
|
||||
@@ -717,10 +808,21 @@ impl Filter {
|
||||
sql.push(" AND created_at <= ").push_bind(until as i64);
|
||||
}
|
||||
|
||||
// TODO: 实现标签过滤
|
||||
// for (tag_name, tag_values) in &self.tag_filters {
|
||||
// // 实现 #e, #p 等标签过滤
|
||||
// }
|
||||
for (tag_name, tag_values) in &self.tag_filters {
|
||||
if tag_name.starts_with('#') && !tag_values.is_empty() {
|
||||
// 使用 JSON 查询来匹配标签
|
||||
sql.push(" AND EXISTS (");
|
||||
sql.push("SELECT 1 FROM json_each(tags) AS tag_array ");
|
||||
sql.push("WHERE json_extract(tag_array.value, '$[0]') = ");
|
||||
sql.push_bind(&tag_name[1..]); // 去掉 # 前缀
|
||||
sql.push(" AND json_extract(tag_array.value, '$[1]') IN (");
|
||||
let mut separated = sql.separated(",");
|
||||
for value in tag_values {
|
||||
separated.push_bind(value);
|
||||
}
|
||||
separated.push_unseparated("))");
|
||||
}
|
||||
}
|
||||
|
||||
sql.push(" ORDER BY created_at DESC");
|
||||
|
||||
@@ -750,23 +852,63 @@ impl Filter {
|
||||
Ok(events)
|
||||
}
|
||||
|
||||
fn matchs(&self, event: &NostrEvent) -> bool {
|
||||
fn matches(&self, event: &NostrEvent) -> bool {
|
||||
// ID 匹配
|
||||
if let Some(ids) = &self.ids {
|
||||
if ids.contains(&event.id) {
|
||||
return true;
|
||||
if !ids.is_empty() && !ids.contains(&event.id) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// 作者匹配
|
||||
if let Some(authors) = &self.authors {
|
||||
if authors.contains(&event.pubkey) {
|
||||
return true;
|
||||
if !authors.is_empty() && !authors.contains(&event.pubkey) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// 类型匹配
|
||||
if let Some(kinds) = &self.kinds {
|
||||
if kinds.contains(&event.kind) {
|
||||
return true;
|
||||
if !kinds.is_empty() && !kinds.contains(&event.kind) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
|
||||
// 时间匹配
|
||||
if let Some(since) = self.since {
|
||||
if event.created_at < since {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(until) = self.until {
|
||||
if event.created_at > until {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// 标签匹配
|
||||
for (tag_name, tag_values) in &self.tag_filters {
|
||||
if tag_name.starts_with('#') && !tag_values.is_empty() {
|
||||
let tag_letter = &tag_name[1..];
|
||||
let mut found = false;
|
||||
|
||||
for tag in &event.tags {
|
||||
if tag.get(0).map(|s| s == tag_letter).unwrap_or(false) && tag.len() > 1 {
|
||||
if tag_values.contains(&tag[1]) {
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -804,9 +946,18 @@ impl RelayMessage {
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_message(message: String, to_client_msg_tx: &mpsc::Sender<String>) {
|
||||
if let Err(e) = to_client_msg_tx.send(message).await {
|
||||
async fn send_closed(sub_id: &str, message: String, to_client_msg_tx: &mpsc::Sender<String>) {
|
||||
let msg = format!("[\"CLOSED\", \"{}\", \"{}\"]", sub_id, message);
|
||||
if let Err(e) = to_client_msg_tx.send(msg).await {
|
||||
error!("Failed to send message: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_notice(message: String, to_client_msg_tx: &mpsc::Sender<String>) {
|
||||
let msg = "[\"NOTICE\", \"".to_string() + &message + "\"]";
|
||||
if let Err(e) = to_client_msg_tx.send(msg).await {
|
||||
error!("Failed to send message: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user