Simplify connection stapling

This commit is contained in:
rictorlome 2022-02-13 22:02:47 -05:00
parent 7befc61ab3
commit e701d69946
4 changed files with 52 additions and 102 deletions

View file

@ -37,7 +37,6 @@ pub async fn send(file_paths: &Vec<PathBuf>, password: &String) -> Result<()> {
// Complete handshake, returning cipher used for encryption // Complete handshake, returning cipher used for encryption
let (stream, cipher) = handshake( let (stream, cipher) = handshake(
&mut stream, &mut stream,
true,
Bytes::from(password.to_string()), Bytes::from(password.to_string()),
pass_to_bytes(password), pass_to_bytes(password),
) )
@ -57,7 +56,6 @@ pub async fn receive(password: &String) -> Result<()> {
let mut stream = Message::to_stream(socket); let mut stream = Message::to_stream(socket);
let (stream, cipher) = handshake( let (stream, cipher) = handshake(
&mut stream, &mut stream,
false,
Bytes::from(password.to_string()), Bytes::from(password.to_string()),
pass_to_bytes(password), pass_to_bytes(password),
) )

View file

@ -10,7 +10,6 @@ use spake2::{Ed25519Group, Identity, Password, Spake2};
pub async fn handshake( pub async fn handshake(
stream: &mut MessageStream, stream: &mut MessageStream,
up: bool,
password: Bytes, password: Bytes,
id: Bytes, id: Bytes,
) -> Result<(&mut MessageStream, Aes256Gcm)> { ) -> Result<(&mut MessageStream, Aes256Gcm)> {
@ -18,7 +17,6 @@ pub async fn handshake(
Spake2::<Ed25519Group>::start_symmetric(&Password::new(password), &Identity::new(&id)); Spake2::<Ed25519Group>::start_symmetric(&Password::new(password), &Identity::new(&id));
println!("client - sending handshake msg"); println!("client - sending handshake msg");
let handshake_msg = Message::HandshakeMessage(HandshakePayload { let handshake_msg = Message::HandshakeMessage(HandshakePayload {
up,
id, id,
msg: Bytes::from(outbound_msg), msg: Bytes::from(outbound_msg),
}); });

View file

@ -21,7 +21,6 @@ pub enum Message {
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct HandshakePayload { pub struct HandshakePayload {
pub up: bool,
pub id: Bytes, pub id: Bytes,
pub msg: Bytes, pub msg: Bytes,
} }

View file

@ -13,109 +13,74 @@ type Tx = mpsc::UnboundedSender<Message>;
type Rx = mpsc::UnboundedReceiver<Message>; type Rx = mpsc::UnboundedReceiver<Message>;
pub struct Shared { pub struct Shared {
handshakes: HashMap<Bytes, Rx>, handshake_cache: HashMap<Bytes, Tx>,
senders: HashMap<Bytes, Tx>,
receivers: HashMap<Bytes, Tx>,
} }
type State = Arc<Mutex<Shared>>; type State = Arc<Mutex<Shared>>;
struct Client<'a> { struct Client {
up: bool, messages: MessageStream,
id: Bytes,
messages: &'a mut MessageStream,
rx: Rx, rx: Rx,
peer_tx: Option<Tx>,
}
struct StapledClient {
messages: MessageStream,
rx: Rx,
peer_tx: Tx,
} }
impl Shared { impl Shared {
fn new() -> Self { fn new() -> Self {
Shared { Shared {
handshakes: HashMap::new(), handshake_cache: HashMap::new(),
senders: HashMap::new(),
receivers: HashMap::new(),
} }
} }
async fn relay<'a>(&self, client: &Client<'a>, message: Message) -> Result<()> {
println!("in relay - got client={:?}, msg {:?}", client.id, message);
match client.up {
true => match self.receivers.get(&client.id) {
Some(tx) => {
tx.send(message)?;
}
None => {
return Err(anyhow!(RuckError::PairDisconnected));
}
},
false => match self.senders.get(&client.id) {
Some(tx) => {
tx.send(message)?;
}
None => {
return Err(anyhow!(RuckError::PairDisconnected));
}
},
}
Ok(())
}
} }
impl<'a> Client<'a> { impl Client {
async fn new( async fn new(id: Bytes, state: State, messages: MessageStream) -> Result<Client> {
up: bool,
id: Bytes,
state: State,
messages: &'a mut MessageStream,
) -> Result<Client<'a>> {
let (tx, rx) = mpsc::unbounded_channel(); let (tx, rx) = mpsc::unbounded_channel();
println!("server - creating client up={:?}, id={:?}", up, id); let mut shared = state.lock().await;
let shared = &mut state.lock().await; let client = Client {
match shared.senders.get(&id) {
Some(_) if up => {
messages
.send(Message::ErrorMessage(RuckError::SenderAlreadyConnected))
.await?;
}
Some(_) => {
println!("server - adding client to receivers");
shared.receivers.insert(id.clone(), tx);
}
None if up => {
println!("server - adding client to senders");
shared.senders.insert(id.clone(), tx);
}
None => {
messages
.send(Message::ErrorMessage(RuckError::SenderNotConnected))
.await?;
}
}
Ok(Client {
up,
id,
messages, messages,
rx, rx,
}) peer_tx: shared.handshake_cache.remove(&id),
};
shared.handshake_cache.insert(id, tx);
Ok(client)
} }
async fn complete_handshake(&mut self, state: State, msg: Message) -> Result<()> {
match self.up { async fn upgrade(
true => { client: Client,
let (tx, rx) = mpsc::unbounded_channel(); state: State,
tx.send(msg)?; handshake_payload: HandshakePayload,
state.lock().await.handshakes.insert(self.id.clone(), rx); ) -> Result<StapledClient> {
} let mut client = client;
false => { let peer_tx = match client.peer_tx {
let shared = &mut state.lock().await; // Receiver - already stapled at creation
if let Some(tx) = shared.senders.get(&self.id) { Some(peer_tx) => peer_tx,
tx.send(msg)?; // Sender - needs to wait for the incoming msg to look up peer_tx
} None => {
if let Some(mut rx) = shared.handshakes.remove(&self.id) { match client.rx.recv().await {
drop(shared); Some(msg) => client.messages.send(msg).await?,
if let Some(msg) = rx.recv().await { None => return Err(anyhow!("Connection not stapled")),
self.messages.send(msg).await?; };
} match state
.lock()
.await
.handshake_cache
.remove(&handshake_payload.id)
{
Some(peer_tx) => peer_tx,
None => return Err(anyhow!("Connection not stapled")),
} }
} }
} };
Ok(()) peer_tx.send(Message::HandshakeMessage(handshake_payload))?;
Ok(StapledClient {
messages: client.messages,
rx: client.rx,
peer_tx,
})
} }
} }
@ -157,16 +122,8 @@ pub async fn handle_connection(
} }
}; };
println!("server - received msg from {:?}", addr); println!("server - received msg from {:?}", addr);
let mut client = Client::new( let client = Client::new(handshake_payload.id.clone(), state.clone(), stream).await?;
handshake_payload.up, let mut client = Client::upgrade(client, state.clone(), handshake_payload).await?;
handshake_payload.id.clone(),
state.clone(),
&mut stream,
)
.await?;
client
.complete_handshake(state.clone(), Message::HandshakeMessage(handshake_payload))
.await?;
// add client to state here // add client to state here
loop { loop {
tokio::select! { tokio::select! {
@ -176,9 +133,7 @@ pub async fn handle_connection(
} }
result = client.messages.next() => match result { result = client.messages.next() => match result {
Some(Ok(msg)) => { Some(Ok(msg)) => {
println!("GOT: {:?}", msg); client.peer_tx.send(msg)?
let state = state.lock().await;
state.relay(&client, msg).await?;
} }
Some(Err(e)) => { Some(Err(e)) => {
println!("Error {:?}", e); println!("Error {:?}", e);