diff --git a/src/server.rs b/src/server.rs index 0930ead..759e154 100644 --- a/src/server.rs +++ b/src/server.rs @@ -60,10 +60,15 @@ impl Client { Some(peer_tx) => peer_tx, // Sender - needs to wait for the incoming msg to look up peer_tx None => { - match client.rx.recv().await { - Some(msg) => client.messages.send(msg).await?, - None => return Err(anyhow!("Connection not stapled")), - }; + tokio::select! { + Some(msg) = client.rx.recv() => { + client.messages.send(msg).await? + } + result = client.messages.next() => match result { + Some(_) => return Err(anyhow!("Client sending more messages before handshake complete")), + None => return Err(anyhow!("Connection interrupted")), + } + } match state .lock() .await @@ -121,13 +126,20 @@ pub async fn handle_connection( return Ok(()); } }; - // println!("server - received msg from {:?}", addr); - let client = Client::new(handshake_payload.id.clone(), state.clone(), stream).await?; - let mut client = Client::upgrade(client, state.clone(), handshake_payload).await?; + let id = handshake_payload.id.clone(); + let client = Client::new(id.clone(), state.clone(), stream).await?; + let mut client = match Client::upgrade(client, state.clone(), handshake_payload).await { + Ok(client) => client, + Err(err) => { + // Clear handshake cache if staple is unsuccessful + state.lock().await.handshake_cache.remove(&id); + return Err(err); + } + }; + // The handshake cache should be empty for {id} at this point. loop { tokio::select! { Some(msg) = client.rx.recv() => { - // println!("message received to client.rx {:?}", msg); client.messages.send(msg).await? } result = client.messages.next() => match result {