diff --git a/src/client.rs b/src/client.rs index 74988b0..35355ee 100644 --- a/src/client.rs +++ b/src/client.rs @@ -37,7 +37,6 @@ pub async fn send(file_paths: &Vec, password: &String) -> Result<()> { // Complete handshake, returning cipher used for encryption let (stream, cipher) = handshake( &mut stream, - true, Bytes::from(password.to_string()), pass_to_bytes(password), ) @@ -57,7 +56,6 @@ pub async fn receive(password: &String) -> Result<()> { let mut stream = Message::to_stream(socket); let (stream, cipher) = handshake( &mut stream, - false, Bytes::from(password.to_string()), pass_to_bytes(password), ) diff --git a/src/crypto.rs b/src/crypto.rs index 789175a..7074718 100644 --- a/src/crypto.rs +++ b/src/crypto.rs @@ -10,7 +10,6 @@ use spake2::{Ed25519Group, Identity, Password, Spake2}; pub async fn handshake( stream: &mut MessageStream, - up: bool, password: Bytes, id: Bytes, ) -> Result<(&mut MessageStream, Aes256Gcm)> { @@ -18,7 +17,6 @@ pub async fn handshake( Spake2::::start_symmetric(&Password::new(password), &Identity::new(&id)); println!("client - sending handshake msg"); let handshake_msg = Message::HandshakeMessage(HandshakePayload { - up, id, msg: Bytes::from(outbound_msg), }); diff --git a/src/message.rs b/src/message.rs index 5819978..10c440f 100644 --- a/src/message.rs +++ b/src/message.rs @@ -21,7 +21,6 @@ pub enum Message { #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub struct HandshakePayload { - pub up: bool, pub id: Bytes, pub msg: Bytes, } diff --git a/src/server.rs b/src/server.rs index cf155b0..8da91be 100644 --- a/src/server.rs +++ b/src/server.rs @@ -13,109 +13,74 @@ type Tx = mpsc::UnboundedSender; type Rx = mpsc::UnboundedReceiver; pub struct Shared { - handshakes: HashMap, - senders: HashMap, - receivers: HashMap, + handshake_cache: HashMap, } type State = Arc>; -struct Client<'a> { - up: bool, - id: Bytes, - messages: &'a mut MessageStream, +struct Client { + messages: MessageStream, rx: Rx, + peer_tx: Option, +} +struct StapledClient { + messages: MessageStream, + rx: Rx, + peer_tx: Tx, } impl Shared { fn new() -> Self { Shared { - handshakes: HashMap::new(), - senders: HashMap::new(), - receivers: HashMap::new(), + handshake_cache: HashMap::new(), } } - async fn relay<'a>(&self, client: &Client<'a>, message: Message) -> Result<()> { - println!("in relay - got client={:?}, msg {:?}", client.id, message); - match client.up { - true => match self.receivers.get(&client.id) { - Some(tx) => { - tx.send(message)?; - } - None => { - return Err(anyhow!(RuckError::PairDisconnected)); - } - }, - false => match self.senders.get(&client.id) { - Some(tx) => { - tx.send(message)?; - } - None => { - return Err(anyhow!(RuckError::PairDisconnected)); - } - }, - } - Ok(()) - } } -impl<'a> Client<'a> { - async fn new( - up: bool, - id: Bytes, - state: State, - messages: &'a mut MessageStream, - ) -> Result> { +impl Client { + async fn new(id: Bytes, state: State, messages: MessageStream) -> Result { let (tx, rx) = mpsc::unbounded_channel(); - println!("server - creating client up={:?}, id={:?}", up, id); - let shared = &mut state.lock().await; - match shared.senders.get(&id) { - Some(_) if up => { - messages - .send(Message::ErrorMessage(RuckError::SenderAlreadyConnected)) - .await?; - } - Some(_) => { - println!("server - adding client to receivers"); - shared.receivers.insert(id.clone(), tx); - } - None if up => { - println!("server - adding client to senders"); - shared.senders.insert(id.clone(), tx); - } - None => { - messages - .send(Message::ErrorMessage(RuckError::SenderNotConnected)) - .await?; - } - } - Ok(Client { - up, - id, + let mut shared = state.lock().await; + let client = Client { messages, rx, - }) + peer_tx: shared.handshake_cache.remove(&id), + }; + shared.handshake_cache.insert(id, tx); + Ok(client) } - async fn complete_handshake(&mut self, state: State, msg: Message) -> Result<()> { - match self.up { - true => { - let (tx, rx) = mpsc::unbounded_channel(); - tx.send(msg)?; - state.lock().await.handshakes.insert(self.id.clone(), rx); - } - false => { - let shared = &mut state.lock().await; - if let Some(tx) = shared.senders.get(&self.id) { - tx.send(msg)?; - } - if let Some(mut rx) = shared.handshakes.remove(&self.id) { - drop(shared); - if let Some(msg) = rx.recv().await { - self.messages.send(msg).await?; - } + + async fn upgrade( + client: Client, + state: State, + handshake_payload: HandshakePayload, + ) -> Result { + let mut client = client; + let peer_tx = match client.peer_tx { + // Receiver - already stapled at creation + Some(peer_tx) => peer_tx, + // Sender - needs to wait for the incoming msg to look up peer_tx + None => { + match client.rx.recv().await { + Some(msg) => client.messages.send(msg).await?, + None => return Err(anyhow!("Connection not stapled")), + }; + match state + .lock() + .await + .handshake_cache + .remove(&handshake_payload.id) + { + Some(peer_tx) => peer_tx, + None => return Err(anyhow!("Connection not stapled")), } } - } - Ok(()) + }; + peer_tx.send(Message::HandshakeMessage(handshake_payload))?; + Ok(StapledClient { + messages: client.messages, + rx: client.rx, + peer_tx, + }) } } @@ -157,16 +122,8 @@ pub async fn handle_connection( } }; println!("server - received msg from {:?}", addr); - let mut client = Client::new( - handshake_payload.up, - handshake_payload.id.clone(), - state.clone(), - &mut stream, - ) - .await?; - client - .complete_handshake(state.clone(), Message::HandshakeMessage(handshake_payload)) - .await?; + let client = Client::new(handshake_payload.id.clone(), state.clone(), stream).await?; + let mut client = Client::upgrade(client, state.clone(), handshake_payload).await?; // add client to state here loop { tokio::select! { @@ -176,9 +133,7 @@ pub async fn handle_connection( } result = client.messages.next() => match result { Some(Ok(msg)) => { - println!("GOT: {:?}", msg); - let state = state.lock().await; - state.relay(&client, msg).await?; + client.peer_tx.send(msg)? } Some(Err(e)) => { println!("Error {:?}", e);