diff --git a/src/conf.rs b/src/conf.rs index 383ca92..6b354d5 100644 --- a/src/conf.rs +++ b/src/conf.rs @@ -1,6 +1,6 @@ pub const ID_SIZE: usize = 32; // Blake256 of password pub const HANDSHAKE_MSG_SIZE: usize = 33; // generated by Spake2 -pub const PER_CLIENT_BUFFER: usize = 1024 * 64; // buffer size allocated by server for each client + // pub const PER_CLIENT_BUFFER: usize = 1024 * 64; // buffer size allocated by server for each client pub const BUFFER_SIZE: usize = 1024 * 64; // chunk size for files sent over wire pub const NONCE_SIZE: usize = 96 / 8; // used for every encrypted message pub const PASSWORD_LEN: usize = 12; // min size of password diff --git a/src/handshake.rs b/src/handshake.rs index fb3ab20..31a8c3a 100644 --- a/src/handshake.rs +++ b/src/handshake.rs @@ -18,9 +18,7 @@ impl Handshake { let id = Handshake::pass_to_bytes(&pw); let (s1, outbound_msg) = Spake2::::start_symmetric(&Password::new(&password), &Identity::new(&id)); - let mut buffer = BytesMut::with_capacity(HANDSHAKE_MSG_SIZE); - buffer.extend_from_slice(&outbound_msg[..HANDSHAKE_MSG_SIZE]); - let outbound_msg = buffer.freeze(); + let outbound_msg = Bytes::from(outbound_msg); let handshake = Handshake { id, outbound_msg }; Ok((handshake, s1)) } @@ -53,6 +51,7 @@ impl Handshake { socket.write_all(&bytes).await?; let mut buffer = [0; HANDSHAKE_MSG_SIZE]; let n = socket.read_exact(&mut buffer).await?; + // println!("reading response"); let response = BytesMut::from(&buffer[..n]).freeze(); // println!("client - handshake msg, {:?}", response); let key = match s1.finish(&response[..]) { diff --git a/src/server.rs b/src/server.rs index f4cf567..d3f8cf4 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,31 +1,28 @@ -use crate::conf::PER_CLIENT_BUFFER; use crate::handshake::Handshake; -use anyhow::{anyhow, Result}; -use bytes::{Bytes, BytesMut}; +use anyhow::Result; +use bytes::Bytes; use std::collections::HashMap; -use std::net::SocketAddr; use std::sync::Arc; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::io::{copy, AsyncWriteExt}; +use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf}; use tokio::net::{TcpListener, TcpStream}; -use tokio::sync::{mpsc, Mutex}; - -type Tx = mpsc::UnboundedSender; -type Rx = mpsc::UnboundedReceiver; +use tokio::sync::{broadcast, Mutex}; +use tokio::time::{sleep, Duration}; pub struct Shared { - handshake_cache: HashMap, + handshake_cache: HashMap, } type State = Arc>; +type IdChannelSender = broadcast::Sender; +type IdChannelReceiver = broadcast::Receiver; struct Client { - socket: TcpStream, - rx: Rx, - peer_tx: Option, + read_socket: OwnedReadHalf, + peer_write_socket: Option, } struct StapledClient { - socket: TcpStream, - rx: Rx, - peer_tx: Tx, + read_socket: OwnedReadHalf, + peer_write_socket: OwnedWriteHalf, } impl Shared { @@ -38,42 +35,54 @@ impl Shared { impl Client { async fn new(id: Bytes, state: State, socket: TcpStream) -> Result { - let (tx, rx) = mpsc::unbounded_channel(); + let (read_socket, write_socket) = socket.into_split(); let mut shared = state.lock().await; let client = Client { - socket, - rx, - peer_tx: shared.handshake_cache.remove(&id), + read_socket, + peer_write_socket: shared.handshake_cache.remove(&id), }; - shared.handshake_cache.insert(id, tx); + shared.handshake_cache.insert(id, write_socket); Ok(client) } - async fn upgrade(client: Client, state: State, handshake: Handshake) -> Result { - let mut client = client; - let peer_tx = match client.peer_tx { - // Receiver - already stapled at creation - Some(peer_tx) => peer_tx, - // Sender - needs to wait for the incoming msg to look up peer_tx - None => { - tokio::select! { - // Client reads handshake message sent over channel - Some(msg) = client.rx.recv() => { - // Writes parnter handshake message over wire - client.socket.write_all(&msg[..]).await? - } - } - match state.lock().await.handshake_cache.remove(&handshake.id) { - Some(peer_tx) => peer_tx, - None => return Err(anyhow!("Connection not stapled")), + async fn await_peer(state: State, id: &Bytes, id_channel: IdChannelReceiver) -> OwnedWriteHalf { + let mut id_channel = id_channel; + loop { + tokio::select! { + res = id_channel.recv() => match res { + Ok(bytes) if bytes == id => { + match state.lock().await.handshake_cache.remove(id) { + Some(tx_write_half) => {return tx_write_half}, + _ => continue + } + }, + _ => continue + }, + else => { + sleep(Duration::from_millis(500)).await; + continue } } + } + } + + async fn upgrade( + client: Client, + state: State, + handshake: Handshake, + id_channel: IdChannelSender, + ) -> Result { + let mut peer_write_socket = match client.peer_write_socket { + // Receiver - already stapled at creation + Some(peer_write_socket) => peer_write_socket, + // Sender - needs to wait for the incoming msg to look up peer_tx + None => Client::await_peer(state, &handshake.id, id_channel.subscribe()).await, }; - peer_tx.send(handshake.outbound_msg)?; + println!("past await peer"); + peer_write_socket.write_all(&handshake.outbound_msg).await?; Ok(StapledClient { - socket: client.socket, - rx: client.rx, - peer_tx, + read_socket: client.read_socket, + peer_write_socket, }) } } @@ -82,12 +91,14 @@ pub async fn serve() -> Result<()> { let addr = "127.0.0.1:8080".to_string(); let listener = TcpListener::bind(&addr).await?; let state = Arc::new(Mutex::new(Shared::new())); + let (tx, _rx) = broadcast::channel::(100); println!("Listening on: {}", addr); loop { - let (stream, address) = listener.accept().await?; + let (stream, _address) = listener.accept().await?; let state = Arc::clone(&state); + let tx = tx.clone(); tokio::spawn(async move { - match handle_connection(state, stream, address).await { + match handle_connection(state, stream, tx).await { Ok(_) => println!("Connection complete!"), Err(err) => println!("Error handling connection! {:?}", err), } @@ -98,14 +109,15 @@ pub async fn serve() -> Result<()> { pub async fn handle_connection( state: Arc>, socket: TcpStream, - _addr: SocketAddr, + id_channel: IdChannelSender, ) -> Result<()> { socket.readable().await?; let (handshake, socket) = Handshake::from_socket(socket).await?; let id = handshake.id.clone(); let client = Client::new(id.clone(), state.clone(), socket).await?; + id_channel.send(id.clone())?; println!("Client created"); - let mut client = match Client::upgrade(client, state.clone(), handshake).await { + let mut client = match Client::upgrade(client, state.clone(), handshake, id_channel).await { Ok(client) => client, Err(err) => { // Clear handshake cache if staple is unsuccessful @@ -115,28 +127,6 @@ pub async fn handle_connection( }; println!("Client upgraded"); // The handshake cache should be empty for {id} at this point. - let mut client_buffer = BytesMut::with_capacity(PER_CLIENT_BUFFER); - loop { - tokio::select! { - Some(msg) = client.rx.recv() => { - // println!("piping bytes= {:?}", msg); - client.socket.write_all(&msg[..]).await? - } - result = client.socket.read_buf(&mut client_buffer) => match result { - Ok(0) => { - break; - }, - Ok(n) => { - let b = BytesMut::from(&client_buffer[0..n]).freeze(); - client.peer_tx.send(b)?; - client_buffer.clear(); - }, - Err(e) => { - println!("Error {:?}", e); - } - } - } - } - println!("done with client"); + copy(&mut client.read_socket, &mut client.peer_write_socket).await?; Ok(()) }