use crate::message::{HandshakePayload, Message, MessageStream, RuckError}; use anyhow::{anyhow, Result}; use bytes::Bytes; use futures::prelude::*; use std::collections::HashMap; use std::net::SocketAddr; use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; use tokio::sync::{mpsc, Mutex}; type Tx = mpsc::UnboundedSender; type Rx = mpsc::UnboundedReceiver; pub struct Shared { handshake_cache: HashMap, } type State = Arc>; struct Client { messages: MessageStream, rx: Rx, peer_tx: Option, } struct StapledClient { messages: MessageStream, rx: Rx, peer_tx: Tx, } impl Shared { fn new() -> Self { Shared { handshake_cache: HashMap::new(), } } } impl Client { async fn new(id: Bytes, state: State, messages: MessageStream) -> Result { let (tx, rx) = mpsc::unbounded_channel(); let mut shared = state.lock().await; let client = Client { messages, rx, peer_tx: shared.handshake_cache.remove(&id), }; shared.handshake_cache.insert(id, tx); Ok(client) } async fn upgrade( client: Client, state: State, handshake_payload: HandshakePayload, ) -> Result { let mut client = client; let peer_tx = match client.peer_tx { // Receiver - already stapled at creation Some(peer_tx) => peer_tx, // Sender - needs to wait for the incoming msg to look up peer_tx None => { tokio::select! { Some(msg) = client.rx.recv() => { client.messages.send(msg).await? } result = client.messages.next() => match result { Some(_) => return Err(anyhow!("Client sending more messages before handshake complete")), None => return Err(anyhow!("Connection interrupted")), } } match state .lock() .await .handshake_cache .remove(&handshake_payload.id) { Some(peer_tx) => peer_tx, None => return Err(anyhow!("Connection not stapled")), } } }; peer_tx.send(Message::HandshakeMessage(handshake_payload))?; Ok(StapledClient { messages: client.messages, rx: client.rx, peer_tx, }) } } pub async fn serve() -> Result<()> { let addr = "127.0.0.1:8080".to_string(); let listener = TcpListener::bind(&addr).await?; let state = Arc::new(Mutex::new(Shared::new())); println!("Listening on: {}", addr); loop { let (stream, address) = listener.accept().await?; let state = Arc::clone(&state); tokio::spawn(async move { match handle_connection(state, stream, address).await { Ok(_) => println!("Connection complete!"), Err(err) => println!("Error handling connection! {:?}", err), } }); } } pub async fn handle_connection( state: Arc>, socket: TcpStream, addr: SocketAddr, ) -> Result<()> { let mut stream = Message::to_stream(socket); println!("server - new conn from {:?}", addr); let handshake_payload = match stream.next().await { Some(Ok(Message::HandshakeMessage(payload))) => payload, Some(Ok(_)) => { stream .send(Message::ErrorMessage(RuckError::NotHandshake)) .await?; return Ok(()); } _ => { println!("No first message"); return Ok(()); } }; let id = handshake_payload.id.clone(); let client = Client::new(id.clone(), state.clone(), stream).await?; let mut client = match Client::upgrade(client, state.clone(), handshake_payload).await { Ok(client) => client, Err(err) => { // Clear handshake cache if staple is unsuccessful state.lock().await.handshake_cache.remove(&id); return Err(err); } }; // The handshake cache should be empty for {id} at this point. loop { tokio::select! { Some(msg) = client.rx.recv() => { client.messages.send(msg).await? } result = client.messages.next() => match result { Some(Ok(msg)) => { client.peer_tx.send(msg)? } Some(Err(e)) => { println!("Error {:?}", e); } None => break, } } } Ok(()) }