Optimize server

This commit is contained in:
Donald Knuth 2022-09-15 10:05:20 -04:00
parent 5d2a216693
commit a21ecafb0d
3 changed files with 62 additions and 73 deletions

View file

@ -1,6 +1,6 @@
pub const ID_SIZE: usize = 32; // Blake256 of password pub const ID_SIZE: usize = 32; // Blake256 of password
pub const HANDSHAKE_MSG_SIZE: usize = 33; // generated by Spake2 pub const HANDSHAKE_MSG_SIZE: usize = 33; // generated by Spake2
pub const PER_CLIENT_BUFFER: usize = 1024 * 64; // buffer size allocated by server for each client // pub const PER_CLIENT_BUFFER: usize = 1024 * 64; // buffer size allocated by server for each client
pub const BUFFER_SIZE: usize = 1024 * 64; // chunk size for files sent over wire pub const BUFFER_SIZE: usize = 1024 * 64; // chunk size for files sent over wire
pub const NONCE_SIZE: usize = 96 / 8; // used for every encrypted message pub const NONCE_SIZE: usize = 96 / 8; // used for every encrypted message
pub const PASSWORD_LEN: usize = 12; // min size of password pub const PASSWORD_LEN: usize = 12; // min size of password

View file

@ -18,9 +18,7 @@ impl Handshake {
let id = Handshake::pass_to_bytes(&pw); let id = Handshake::pass_to_bytes(&pw);
let (s1, outbound_msg) = let (s1, outbound_msg) =
Spake2::<Ed25519Group>::start_symmetric(&Password::new(&password), &Identity::new(&id)); Spake2::<Ed25519Group>::start_symmetric(&Password::new(&password), &Identity::new(&id));
let mut buffer = BytesMut::with_capacity(HANDSHAKE_MSG_SIZE); let outbound_msg = Bytes::from(outbound_msg);
buffer.extend_from_slice(&outbound_msg[..HANDSHAKE_MSG_SIZE]);
let outbound_msg = buffer.freeze();
let handshake = Handshake { id, outbound_msg }; let handshake = Handshake { id, outbound_msg };
Ok((handshake, s1)) Ok((handshake, s1))
} }
@ -53,6 +51,7 @@ impl Handshake {
socket.write_all(&bytes).await?; socket.write_all(&bytes).await?;
let mut buffer = [0; HANDSHAKE_MSG_SIZE]; let mut buffer = [0; HANDSHAKE_MSG_SIZE];
let n = socket.read_exact(&mut buffer).await?; let n = socket.read_exact(&mut buffer).await?;
// println!("reading response");
let response = BytesMut::from(&buffer[..n]).freeze(); let response = BytesMut::from(&buffer[..n]).freeze();
// println!("client - handshake msg, {:?}", response); // println!("client - handshake msg, {:?}", response);
let key = match s1.finish(&response[..]) { let key = match s1.finish(&response[..]) {

View file

@ -1,31 +1,28 @@
use crate::conf::PER_CLIENT_BUFFER;
use crate::handshake::Handshake; use crate::handshake::Handshake;
use anyhow::{anyhow, Result}; use anyhow::Result;
use bytes::{Bytes, BytesMut}; use bytes::Bytes;
use std::collections::HashMap; use std::collections::HashMap;
use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::io::{copy, AsyncWriteExt};
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
use tokio::net::{TcpListener, TcpStream}; use tokio::net::{TcpListener, TcpStream};
use tokio::sync::{mpsc, Mutex}; use tokio::sync::{broadcast, Mutex};
use tokio::time::{sleep, Duration};
type Tx = mpsc::UnboundedSender<Bytes>;
type Rx = mpsc::UnboundedReceiver<Bytes>;
pub struct Shared { pub struct Shared {
handshake_cache: HashMap<Bytes, Tx>, handshake_cache: HashMap<Bytes, OwnedWriteHalf>,
} }
type State = Arc<Mutex<Shared>>; type State = Arc<Mutex<Shared>>;
type IdChannelSender = broadcast::Sender<bytes::Bytes>;
type IdChannelReceiver = broadcast::Receiver<bytes::Bytes>;
struct Client { struct Client {
socket: TcpStream, read_socket: OwnedReadHalf,
rx: Rx, peer_write_socket: Option<OwnedWriteHalf>,
peer_tx: Option<Tx>,
} }
struct StapledClient { struct StapledClient {
socket: TcpStream, read_socket: OwnedReadHalf,
rx: Rx, peer_write_socket: OwnedWriteHalf,
peer_tx: Tx,
} }
impl Shared { impl Shared {
@ -38,42 +35,54 @@ impl Shared {
impl Client { impl Client {
async fn new(id: Bytes, state: State, socket: TcpStream) -> Result<Client> { async fn new(id: Bytes, state: State, socket: TcpStream) -> Result<Client> {
let (tx, rx) = mpsc::unbounded_channel(); let (read_socket, write_socket) = socket.into_split();
let mut shared = state.lock().await; let mut shared = state.lock().await;
let client = Client { let client = Client {
socket, read_socket,
rx, peer_write_socket: shared.handshake_cache.remove(&id),
peer_tx: shared.handshake_cache.remove(&id),
}; };
shared.handshake_cache.insert(id, tx); shared.handshake_cache.insert(id, write_socket);
Ok(client) Ok(client)
} }
async fn upgrade(client: Client, state: State, handshake: Handshake) -> Result<StapledClient> { async fn await_peer(state: State, id: &Bytes, id_channel: IdChannelReceiver) -> OwnedWriteHalf {
let mut client = client; let mut id_channel = id_channel;
let peer_tx = match client.peer_tx { loop {
// Receiver - already stapled at creation
Some(peer_tx) => peer_tx,
// Sender - needs to wait for the incoming msg to look up peer_tx
None => {
tokio::select! { tokio::select! {
// Client reads handshake message sent over channel res = id_channel.recv() => match res {
Some(msg) = client.rx.recv() => { Ok(bytes) if bytes == id => {
// Writes parnter handshake message over wire match state.lock().await.handshake_cache.remove(id) {
client.socket.write_all(&msg[..]).await? Some(tx_write_half) => {return tx_write_half},
_ => continue
}
},
_ => continue
},
else => {
sleep(Duration::from_millis(500)).await;
continue
} }
} }
match state.lock().await.handshake_cache.remove(&handshake.id) {
Some(peer_tx) => peer_tx,
None => return Err(anyhow!("Connection not stapled")),
} }
} }
async fn upgrade(
client: Client,
state: State,
handshake: Handshake,
id_channel: IdChannelSender,
) -> Result<StapledClient> {
let mut peer_write_socket = match client.peer_write_socket {
// Receiver - already stapled at creation
Some(peer_write_socket) => peer_write_socket,
// Sender - needs to wait for the incoming msg to look up peer_tx
None => Client::await_peer(state, &handshake.id, id_channel.subscribe()).await,
}; };
peer_tx.send(handshake.outbound_msg)?; println!("past await peer");
peer_write_socket.write_all(&handshake.outbound_msg).await?;
Ok(StapledClient { Ok(StapledClient {
socket: client.socket, read_socket: client.read_socket,
rx: client.rx, peer_write_socket,
peer_tx,
}) })
} }
} }
@ -82,12 +91,14 @@ pub async fn serve() -> Result<()> {
let addr = "127.0.0.1:8080".to_string(); let addr = "127.0.0.1:8080".to_string();
let listener = TcpListener::bind(&addr).await?; let listener = TcpListener::bind(&addr).await?;
let state = Arc::new(Mutex::new(Shared::new())); let state = Arc::new(Mutex::new(Shared::new()));
let (tx, _rx) = broadcast::channel::<Bytes>(100);
println!("Listening on: {}", addr); println!("Listening on: {}", addr);
loop { loop {
let (stream, address) = listener.accept().await?; let (stream, _address) = listener.accept().await?;
let state = Arc::clone(&state); let state = Arc::clone(&state);
let tx = tx.clone();
tokio::spawn(async move { tokio::spawn(async move {
match handle_connection(state, stream, address).await { match handle_connection(state, stream, tx).await {
Ok(_) => println!("Connection complete!"), Ok(_) => println!("Connection complete!"),
Err(err) => println!("Error handling connection! {:?}", err), Err(err) => println!("Error handling connection! {:?}", err),
} }
@ -98,14 +109,15 @@ pub async fn serve() -> Result<()> {
pub async fn handle_connection( pub async fn handle_connection(
state: Arc<Mutex<Shared>>, state: Arc<Mutex<Shared>>,
socket: TcpStream, socket: TcpStream,
_addr: SocketAddr, id_channel: IdChannelSender,
) -> Result<()> { ) -> Result<()> {
socket.readable().await?; socket.readable().await?;
let (handshake, socket) = Handshake::from_socket(socket).await?; let (handshake, socket) = Handshake::from_socket(socket).await?;
let id = handshake.id.clone(); let id = handshake.id.clone();
let client = Client::new(id.clone(), state.clone(), socket).await?; let client = Client::new(id.clone(), state.clone(), socket).await?;
id_channel.send(id.clone())?;
println!("Client created"); println!("Client created");
let mut client = match Client::upgrade(client, state.clone(), handshake).await { let mut client = match Client::upgrade(client, state.clone(), handshake, id_channel).await {
Ok(client) => client, Ok(client) => client,
Err(err) => { Err(err) => {
// Clear handshake cache if staple is unsuccessful // Clear handshake cache if staple is unsuccessful
@ -115,28 +127,6 @@ pub async fn handle_connection(
}; };
println!("Client upgraded"); println!("Client upgraded");
// The handshake cache should be empty for {id} at this point. // The handshake cache should be empty for {id} at this point.
let mut client_buffer = BytesMut::with_capacity(PER_CLIENT_BUFFER); copy(&mut client.read_socket, &mut client.peer_write_socket).await?;
loop {
tokio::select! {
Some(msg) = client.rx.recv() => {
// println!("piping bytes= {:?}", msg);
client.socket.write_all(&msg[..]).await?
}
result = client.socket.read_buf(&mut client_buffer) => match result {
Ok(0) => {
break;
},
Ok(n) => {
let b = BytesMut::from(&client_buffer[0..n]).freeze();
client.peer_tx.send(b)?;
client_buffer.clear();
},
Err(e) => {
println!("Error {:?}", e);
}
}
}
}
println!("done with client");
Ok(()) Ok(())
} }