mirror of
https://github.com/CompeyDev/ruck.git
synced 2025-01-09 03:59:10 +00:00
Optimize server
This commit is contained in:
parent
5d2a216693
commit
a21ecafb0d
3 changed files with 62 additions and 73 deletions
|
@ -1,6 +1,6 @@
|
||||||
pub const ID_SIZE: usize = 32; // Blake256 of password
|
pub const ID_SIZE: usize = 32; // Blake256 of password
|
||||||
pub const HANDSHAKE_MSG_SIZE: usize = 33; // generated by Spake2
|
pub const HANDSHAKE_MSG_SIZE: usize = 33; // generated by Spake2
|
||||||
pub const PER_CLIENT_BUFFER: usize = 1024 * 64; // buffer size allocated by server for each client
|
// pub const PER_CLIENT_BUFFER: usize = 1024 * 64; // buffer size allocated by server for each client
|
||||||
pub const BUFFER_SIZE: usize = 1024 * 64; // chunk size for files sent over wire
|
pub const BUFFER_SIZE: usize = 1024 * 64; // chunk size for files sent over wire
|
||||||
pub const NONCE_SIZE: usize = 96 / 8; // used for every encrypted message
|
pub const NONCE_SIZE: usize = 96 / 8; // used for every encrypted message
|
||||||
pub const PASSWORD_LEN: usize = 12; // min size of password
|
pub const PASSWORD_LEN: usize = 12; // min size of password
|
||||||
|
|
|
@ -18,9 +18,7 @@ impl Handshake {
|
||||||
let id = Handshake::pass_to_bytes(&pw);
|
let id = Handshake::pass_to_bytes(&pw);
|
||||||
let (s1, outbound_msg) =
|
let (s1, outbound_msg) =
|
||||||
Spake2::<Ed25519Group>::start_symmetric(&Password::new(&password), &Identity::new(&id));
|
Spake2::<Ed25519Group>::start_symmetric(&Password::new(&password), &Identity::new(&id));
|
||||||
let mut buffer = BytesMut::with_capacity(HANDSHAKE_MSG_SIZE);
|
let outbound_msg = Bytes::from(outbound_msg);
|
||||||
buffer.extend_from_slice(&outbound_msg[..HANDSHAKE_MSG_SIZE]);
|
|
||||||
let outbound_msg = buffer.freeze();
|
|
||||||
let handshake = Handshake { id, outbound_msg };
|
let handshake = Handshake { id, outbound_msg };
|
||||||
Ok((handshake, s1))
|
Ok((handshake, s1))
|
||||||
}
|
}
|
||||||
|
@ -53,6 +51,7 @@ impl Handshake {
|
||||||
socket.write_all(&bytes).await?;
|
socket.write_all(&bytes).await?;
|
||||||
let mut buffer = [0; HANDSHAKE_MSG_SIZE];
|
let mut buffer = [0; HANDSHAKE_MSG_SIZE];
|
||||||
let n = socket.read_exact(&mut buffer).await?;
|
let n = socket.read_exact(&mut buffer).await?;
|
||||||
|
// println!("reading response");
|
||||||
let response = BytesMut::from(&buffer[..n]).freeze();
|
let response = BytesMut::from(&buffer[..n]).freeze();
|
||||||
// println!("client - handshake msg, {:?}", response);
|
// println!("client - handshake msg, {:?}", response);
|
||||||
let key = match s1.finish(&response[..]) {
|
let key = match s1.finish(&response[..]) {
|
||||||
|
|
122
src/server.rs
122
src/server.rs
|
@ -1,31 +1,28 @@
|
||||||
use crate::conf::PER_CLIENT_BUFFER;
|
|
||||||
use crate::handshake::Handshake;
|
use crate::handshake::Handshake;
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::Result;
|
||||||
use bytes::{Bytes, BytesMut};
|
use bytes::Bytes;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::net::SocketAddr;
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
use tokio::io::{copy, AsyncWriteExt};
|
||||||
|
use tokio::net::tcp::{OwnedReadHalf, OwnedWriteHalf};
|
||||||
use tokio::net::{TcpListener, TcpStream};
|
use tokio::net::{TcpListener, TcpStream};
|
||||||
use tokio::sync::{mpsc, Mutex};
|
use tokio::sync::{broadcast, Mutex};
|
||||||
|
use tokio::time::{sleep, Duration};
|
||||||
type Tx = mpsc::UnboundedSender<Bytes>;
|
|
||||||
type Rx = mpsc::UnboundedReceiver<Bytes>;
|
|
||||||
|
|
||||||
pub struct Shared {
|
pub struct Shared {
|
||||||
handshake_cache: HashMap<Bytes, Tx>,
|
handshake_cache: HashMap<Bytes, OwnedWriteHalf>,
|
||||||
}
|
}
|
||||||
type State = Arc<Mutex<Shared>>;
|
type State = Arc<Mutex<Shared>>;
|
||||||
|
type IdChannelSender = broadcast::Sender<bytes::Bytes>;
|
||||||
|
type IdChannelReceiver = broadcast::Receiver<bytes::Bytes>;
|
||||||
|
|
||||||
struct Client {
|
struct Client {
|
||||||
socket: TcpStream,
|
read_socket: OwnedReadHalf,
|
||||||
rx: Rx,
|
peer_write_socket: Option<OwnedWriteHalf>,
|
||||||
peer_tx: Option<Tx>,
|
|
||||||
}
|
}
|
||||||
struct StapledClient {
|
struct StapledClient {
|
||||||
socket: TcpStream,
|
read_socket: OwnedReadHalf,
|
||||||
rx: Rx,
|
peer_write_socket: OwnedWriteHalf,
|
||||||
peer_tx: Tx,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Shared {
|
impl Shared {
|
||||||
|
@ -38,42 +35,54 @@ impl Shared {
|
||||||
|
|
||||||
impl Client {
|
impl Client {
|
||||||
async fn new(id: Bytes, state: State, socket: TcpStream) -> Result<Client> {
|
async fn new(id: Bytes, state: State, socket: TcpStream) -> Result<Client> {
|
||||||
let (tx, rx) = mpsc::unbounded_channel();
|
let (read_socket, write_socket) = socket.into_split();
|
||||||
let mut shared = state.lock().await;
|
let mut shared = state.lock().await;
|
||||||
let client = Client {
|
let client = Client {
|
||||||
socket,
|
read_socket,
|
||||||
rx,
|
peer_write_socket: shared.handshake_cache.remove(&id),
|
||||||
peer_tx: shared.handshake_cache.remove(&id),
|
|
||||||
};
|
};
|
||||||
shared.handshake_cache.insert(id, tx);
|
shared.handshake_cache.insert(id, write_socket);
|
||||||
Ok(client)
|
Ok(client)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn upgrade(client: Client, state: State, handshake: Handshake) -> Result<StapledClient> {
|
async fn await_peer(state: State, id: &Bytes, id_channel: IdChannelReceiver) -> OwnedWriteHalf {
|
||||||
let mut client = client;
|
let mut id_channel = id_channel;
|
||||||
let peer_tx = match client.peer_tx {
|
loop {
|
||||||
// Receiver - already stapled at creation
|
|
||||||
Some(peer_tx) => peer_tx,
|
|
||||||
// Sender - needs to wait for the incoming msg to look up peer_tx
|
|
||||||
None => {
|
|
||||||
tokio::select! {
|
tokio::select! {
|
||||||
// Client reads handshake message sent over channel
|
res = id_channel.recv() => match res {
|
||||||
Some(msg) = client.rx.recv() => {
|
Ok(bytes) if bytes == id => {
|
||||||
// Writes parnter handshake message over wire
|
match state.lock().await.handshake_cache.remove(id) {
|
||||||
client.socket.write_all(&msg[..]).await?
|
Some(tx_write_half) => {return tx_write_half},
|
||||||
|
_ => continue
|
||||||
|
}
|
||||||
|
},
|
||||||
|
_ => continue
|
||||||
|
},
|
||||||
|
else => {
|
||||||
|
sleep(Duration::from_millis(500)).await;
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
match state.lock().await.handshake_cache.remove(&handshake.id) {
|
|
||||||
Some(peer_tx) => peer_tx,
|
|
||||||
None => return Err(anyhow!("Connection not stapled")),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn upgrade(
|
||||||
|
client: Client,
|
||||||
|
state: State,
|
||||||
|
handshake: Handshake,
|
||||||
|
id_channel: IdChannelSender,
|
||||||
|
) -> Result<StapledClient> {
|
||||||
|
let mut peer_write_socket = match client.peer_write_socket {
|
||||||
|
// Receiver - already stapled at creation
|
||||||
|
Some(peer_write_socket) => peer_write_socket,
|
||||||
|
// Sender - needs to wait for the incoming msg to look up peer_tx
|
||||||
|
None => Client::await_peer(state, &handshake.id, id_channel.subscribe()).await,
|
||||||
};
|
};
|
||||||
peer_tx.send(handshake.outbound_msg)?;
|
println!("past await peer");
|
||||||
|
peer_write_socket.write_all(&handshake.outbound_msg).await?;
|
||||||
Ok(StapledClient {
|
Ok(StapledClient {
|
||||||
socket: client.socket,
|
read_socket: client.read_socket,
|
||||||
rx: client.rx,
|
peer_write_socket,
|
||||||
peer_tx,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -82,12 +91,14 @@ pub async fn serve() -> Result<()> {
|
||||||
let addr = "127.0.0.1:8080".to_string();
|
let addr = "127.0.0.1:8080".to_string();
|
||||||
let listener = TcpListener::bind(&addr).await?;
|
let listener = TcpListener::bind(&addr).await?;
|
||||||
let state = Arc::new(Mutex::new(Shared::new()));
|
let state = Arc::new(Mutex::new(Shared::new()));
|
||||||
|
let (tx, _rx) = broadcast::channel::<Bytes>(100);
|
||||||
println!("Listening on: {}", addr);
|
println!("Listening on: {}", addr);
|
||||||
loop {
|
loop {
|
||||||
let (stream, address) = listener.accept().await?;
|
let (stream, _address) = listener.accept().await?;
|
||||||
let state = Arc::clone(&state);
|
let state = Arc::clone(&state);
|
||||||
|
let tx = tx.clone();
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
match handle_connection(state, stream, address).await {
|
match handle_connection(state, stream, tx).await {
|
||||||
Ok(_) => println!("Connection complete!"),
|
Ok(_) => println!("Connection complete!"),
|
||||||
Err(err) => println!("Error handling connection! {:?}", err),
|
Err(err) => println!("Error handling connection! {:?}", err),
|
||||||
}
|
}
|
||||||
|
@ -98,14 +109,15 @@ pub async fn serve() -> Result<()> {
|
||||||
pub async fn handle_connection(
|
pub async fn handle_connection(
|
||||||
state: Arc<Mutex<Shared>>,
|
state: Arc<Mutex<Shared>>,
|
||||||
socket: TcpStream,
|
socket: TcpStream,
|
||||||
_addr: SocketAddr,
|
id_channel: IdChannelSender,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
socket.readable().await?;
|
socket.readable().await?;
|
||||||
let (handshake, socket) = Handshake::from_socket(socket).await?;
|
let (handshake, socket) = Handshake::from_socket(socket).await?;
|
||||||
let id = handshake.id.clone();
|
let id = handshake.id.clone();
|
||||||
let client = Client::new(id.clone(), state.clone(), socket).await?;
|
let client = Client::new(id.clone(), state.clone(), socket).await?;
|
||||||
|
id_channel.send(id.clone())?;
|
||||||
println!("Client created");
|
println!("Client created");
|
||||||
let mut client = match Client::upgrade(client, state.clone(), handshake).await {
|
let mut client = match Client::upgrade(client, state.clone(), handshake, id_channel).await {
|
||||||
Ok(client) => client,
|
Ok(client) => client,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
// Clear handshake cache if staple is unsuccessful
|
// Clear handshake cache if staple is unsuccessful
|
||||||
|
@ -115,28 +127,6 @@ pub async fn handle_connection(
|
||||||
};
|
};
|
||||||
println!("Client upgraded");
|
println!("Client upgraded");
|
||||||
// The handshake cache should be empty for {id} at this point.
|
// The handshake cache should be empty for {id} at this point.
|
||||||
let mut client_buffer = BytesMut::with_capacity(PER_CLIENT_BUFFER);
|
copy(&mut client.read_socket, &mut client.peer_write_socket).await?;
|
||||||
loop {
|
|
||||||
tokio::select! {
|
|
||||||
Some(msg) = client.rx.recv() => {
|
|
||||||
// println!("piping bytes= {:?}", msg);
|
|
||||||
client.socket.write_all(&msg[..]).await?
|
|
||||||
}
|
|
||||||
result = client.socket.read_buf(&mut client_buffer) => match result {
|
|
||||||
Ok(0) => {
|
|
||||||
break;
|
|
||||||
},
|
|
||||||
Ok(n) => {
|
|
||||||
let b = BytesMut::from(&client_buffer[0..n]).freeze();
|
|
||||||
client.peer_tx.send(b)?;
|
|
||||||
client_buffer.clear();
|
|
||||||
},
|
|
||||||
Err(e) => {
|
|
||||||
println!("Error {:?}", e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
println!("done with client");
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in a new issue