diff --git a/Cargo.lock b/Cargo.lock index a06ddb1..8798cbc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -75,6 +75,15 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "blake2" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b94ba84325db59637ffc528bbe8c7f86c02c57cff5c0e2b9b00f9a851f42f309" +dependencies = [ + "digest 0.10.2", +] + [[package]] name = "block-buffer" version = "0.10.2" @@ -672,6 +681,7 @@ dependencies = [ "aes-gcm", "anyhow", "bincode", + "blake2", "bytes", "clap", "futures", diff --git a/Cargo.toml b/Cargo.toml index ba06217..0358121 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" aes-gcm = "0.9.4" anyhow = "1.0" +blake2 = "0.10.2" bytes = { version = "1", features = ["serde"] } bincode = "1.3.3" clap = { version = "3.0.14", features = ["derive"] } diff --git a/src/client.rs b/src/client.rs index 84d679e..1eb07b9 100644 --- a/src/client.rs +++ b/src/client.rs @@ -2,18 +2,28 @@ use crate::crypto::handshake; use crate::message::{Message, MessageStream}; use anyhow::Result; +use blake2::{Blake2s256, Digest}; use bytes::{BufMut, Bytes, BytesMut}; use futures::prelude::*; use std::path::PathBuf; use tokio::net::TcpStream; +fn pass_to_bytes(password: &String) -> Bytes { + let mut hasher = Blake2s256::new(); + hasher.update(password.as_bytes()); + let res = hasher.finalize(); + BytesMut::from(&res[..]).freeze() +} + pub async fn send(file_paths: &Vec, password: &String) -> Result<()> { let socket = TcpStream::connect("127.0.0.1:8080").await?; let mut stream = Message::to_stream(socket); + let (stream, key) = handshake( &mut stream, + true, Bytes::from(password.to_string()), - Bytes::from("id123"), + pass_to_bytes(password), ) .await?; @@ -39,8 +49,9 @@ pub async fn receive(password: &String) -> Result<()> { let mut stream = Message::to_stream(socket); let (stream, key) = handshake( &mut stream, + false, Bytes::from(password.to_string()), - Bytes::from("id123"), + pass_to_bytes(password), ) .await?; return Ok(()); diff --git a/src/crypto.rs b/src/crypto.rs index d6aec92..e98f20c 100644 --- a/src/crypto.rs +++ b/src/crypto.rs @@ -1,4 +1,4 @@ -use crate::message::{HandshakeMessage, Message, MessageStream}; +use crate::message::{HandshakePayload, Message, MessageStream}; use anyhow::{anyhow, Result}; use bytes::Bytes; @@ -7,17 +7,20 @@ use spake2::{Ed25519Group, Identity, Password, Spake2}; pub async fn handshake( stream: &mut MessageStream, + up: bool, password: Bytes, id: Bytes, ) -> Result<(&mut MessageStream, Bytes)> { let (s1, outbound_msg) = Spake2::::start_symmetric(&Password::new(password), &Identity::new(&id)); - stream - .send(Message::HandshakeMessage(HandshakeMessage { - id, - msg: Bytes::from(outbound_msg), - })) - .await?; + println!("client - sending handshake msg"); + let handshake_msg = Message::HandshakeMessage(HandshakePayload { + up, + id, + msg: Bytes::from(outbound_msg), + }); + println!("client - handshake msg, {:?}", handshake_msg); + stream.send(handshake_msg).await?; let first_message = match stream.next().await { Some(Ok(msg)) => match msg { Message::HandshakeMessage(response) => response.msg, @@ -27,6 +30,7 @@ pub async fn handshake( return Err(anyhow!("No response to handshake message")); } }; + println!("client - handshake msg responded to"); let key = match s1.finish(&first_message[..]) { Ok(key_bytes) => key_bytes, Err(e) => return Err(anyhow!(e.to_string())), diff --git a/src/message.rs b/src/message.rs index fcca1b6..5e42ea7 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,20 +1,44 @@ use bytes::Bytes; use serde::{Deserialize, Serialize}; +use std::error::Error; +use std::fmt; use tokio::net::TcpStream; use tokio_serde::{formats::SymmetricalBincode, SymmetricallyFramed}; use tokio_util::codec::{Framed, LengthDelimitedCodec}; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum Message { - HandshakeMessage(HandshakeMessage), + HandshakeMessage(HandshakePayload), + ErrorMessage(RuckError), } #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct HandshakeMessage { +pub struct HandshakePayload { + pub up: bool, pub id: Bytes, pub msg: Bytes, } +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] +pub enum RuckError { + NotHandshake, + SenderNotConnected, + SenderAlreadyConnected, + PairDisconnected, +} + +impl fmt::Display for RuckError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "RuckError is here!") + } +} + +impl Error for RuckError { + fn source(&self) -> Option<&(dyn Error + 'static)> { + Some(self) + } +} + impl Message { pub fn to_stream(stream: TcpStream) -> MessageStream { tokio_serde::SymmetricallyFramed::new( diff --git a/src/server.rs b/src/server.rs index f123e60..0ee4f4c 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,66 +1,125 @@ -use crate::message::{Message, MessageStream}; +use crate::message::{HandshakePayload, Message, MessageStream, RuckError}; +use anyhow::{anyhow, Result}; +use bytes::Bytes; use futures::prelude::*; use std::collections::HashMap; -use std::io; use std::net::SocketAddr; use std::sync::Arc; use tokio::net::{TcpListener, TcpStream}; -use tokio::sync::{mpsc, Mutex}; +use tokio::sync::{mpsc, oneshot, Mutex}; type Tx = mpsc::UnboundedSender; type Rx = mpsc::UnboundedReceiver; pub struct Shared { - rooms: HashMap, + handshakes: HashMap, + senders: HashMap, + receivers: HashMap, } type State = Arc>; -struct RoomInfo { - sender_tx: Tx, -} - -struct Client { - is_sender: bool, - messages: MessageStream, +struct Client<'a> { + up: bool, + id: Bytes, + messages: &'a mut MessageStream, rx: Rx, } impl Shared { fn new() -> Self { Shared { - rooms: HashMap::new(), + handshakes: HashMap::new(), + senders: HashMap::new(), + receivers: HashMap::new(), } } - - // async fn broadcast(&mut self, sender: SocketAddr, message: Message) { - // for peer in self.peers.iter_mut() { - // if *peer.0 != sender { - // let _ = peer.1.send(message.clone()); - // } - // } - // } + async fn relay<'a>(&self, client: &Client<'a>, message: Message) -> Result<()> { + println!("in relay - got client={:?}, msg {:?}", client.id, message); + match client.up { + true => match self.receivers.get(&client.id) { + Some(tx) => { + tx.send(message)?; + } + None => { + return Err(anyhow!(RuckError::PairDisconnected)); + } + }, + false => match self.senders.get(&client.id) { + Some(tx) => { + tx.send(message)?; + } + None => { + return Err(anyhow!(RuckError::PairDisconnected)); + } + }, + } + Ok(()) + } } -impl Client { - async fn new(is_sender: bool, state: State, messages: MessageStream) -> io::Result { +impl<'a> Client<'a> { + async fn new( + up: bool, + id: Bytes, + state: State, + messages: &'a mut MessageStream, + ) -> Result> { let (tx, rx) = mpsc::unbounded_channel(); - let room_info = RoomInfo { sender_tx: tx }; - state - .lock() - .await - .rooms - .insert("abc".to_string(), room_info); - + println!("server - creating client up={:?}, id={:?}", up, id); + let shared = &mut state.lock().await; + match shared.senders.get(&id) { + Some(_) if up => { + messages + .send(Message::ErrorMessage(RuckError::SenderAlreadyConnected)) + .await?; + } + Some(_) => { + println!("server - adding client to receivers"); + shared.receivers.insert(id.clone(), tx); + } + None if up => { + println!("server - adding client to senders"); + shared.senders.insert(id.clone(), tx); + } + None => { + messages + .send(Message::ErrorMessage(RuckError::SenderNotConnected)) + .await?; + } + } Ok(Client { - is_sender, + up, + id, messages, rx, }) } + async fn complete_handshake(&mut self, state: State, msg: Message) -> Result<()> { + match self.up { + true => { + let (tx, rx) = mpsc::unbounded_channel(); + tx.send(msg)?; + let shared = &mut state.lock().await; + shared.handshakes.insert(self.id.clone(), rx); + } + false => { + let shared = &mut state.lock().await; + if let Some(tx) = shared.senders.get(&self.id) { + tx.send(msg)?; + } + if let Some(mut rx) = shared.handshakes.remove(&self.id) { + if let Some(msg) = rx.recv().await { + self.messages.send(msg).await?; + } + } + } + } + Ok(()) + } } -pub async fn serve() -> Result<(), Box> { +pub async fn serve() -> Result<()> { let addr = "127.0.0.1:8080".to_string(); let listener = TcpListener::bind(&addr).await?; let state = Arc::new(Mutex::new(Shared::new())); @@ -70,8 +129,8 @@ pub async fn serve() -> Result<(), Box> { let state = Arc::clone(&state); tokio::spawn(async move { match handle_connection(state, stream, address).await { - Ok(_) => println!("ok"), - Err(_) => println!("err"), + Ok(_) => println!("Connection complete!"), + Err(err) => println!("Error handling connection! {:?}", err), } }); } @@ -81,19 +140,36 @@ pub async fn handle_connection( state: Arc>, socket: TcpStream, addr: SocketAddr, -) -> Result<(), Box> { +) -> Result<()> { let mut stream = Message::to_stream(socket); - let first_message = match stream.next().await { - Some(Ok(msg)) => { - println!("first msg: {:?}", msg); - msg + println!("server - new conn from {:?}", addr); + let handshake_payload = match stream.next().await { + Some(Ok(Message::HandshakeMessage(payload))) => payload, + Some(Ok(_)) => { + stream + .send(Message::ErrorMessage(RuckError::NotHandshake)) + .await?; + return Ok(()); } _ => { - println!("no first message"); + println!("No first message"); return Ok(()); } }; - let mut client = Client::new(true, state.clone(), stream).await?; + // How do I get this handshake message to the peer + // If it's the sender, the recipient hasn't arrived yet + // If it's the recipient, the sender was created before + println!("server - received msg from {:?}", addr); + let mut client = Client::new( + handshake_payload.up, + handshake_payload.id.clone(), + state.clone(), + &mut stream, + ) + .await?; + client + .complete_handshake(state.clone(), Message::HandshakeMessage(handshake_payload)) + .await?; // add client to state here loop { tokio::select! { @@ -104,6 +180,8 @@ pub async fn handle_connection( result = client.messages.next() => match result { Some(Ok(msg)) => { println!("GOT: {:?}", msg); + let state = state.lock().await; + state.relay(&client, msg).await?; } Some(Err(e)) => { println!("Error {:?}", e);