diff --git a/src/client.rs b/src/client.rs index bf15038..fa1ce0a 100644 --- a/src/client.rs +++ b/src/client.rs @@ -1,10 +1,8 @@ use crate::conf::BUFFER_SIZE; -use crate::crypto::Crypt; +use crate::connection::Connection; use crate::file::{to_size_string, FileHandle, FileInfo}; use crate::handshake::Handshake; -use crate::message::{ - EncryptedMessage, FileNegotiationPayload, FileTransferPayload, Message, MessageStream, -}; +use crate::message::{FileNegotiationPayload, FileTransferPayload, Message}; use anyhow::{anyhow, Result}; use bytes::{Bytes, BytesMut}; @@ -30,13 +28,12 @@ pub async fn send(file_paths: &Vec, password: &String) -> Result<()> { // Complete handshake, returning key used for encryption let (socket, key) = handshake.negotiate(socket, s1).await?; - let mut stream = Message::to_stream(socket); - let crypt = Crypt::new(&key); + let mut connection = Connection::new(socket, key); // Complete file negotiation - let handles = negotiate_files_up(handles, &mut stream, &crypt).await?; + let handles = negotiate_files_up(&mut connection, handles).await?; // Upload negotiated files - upload_encrypted_files(&mut stream, handles, &crypt).await?; + upload_encrypted_files(&mut connection, handles).await?; println!("Done uploading."); // Exit @@ -47,11 +44,10 @@ pub async fn receive(password: &String) -> Result<()> { let socket = TcpStream::connect("127.0.0.1:8080").await?; let (handshake, s1) = Handshake::from_password(password); let (socket, key) = handshake.negotiate(socket, s1).await?; - let mut stream = Message::to_stream(socket); - let crypt = Crypt::new(&key); - let files = negotiate_files_down(&mut stream, &crypt).await?; + let mut connection = Connection::new(socket, key); + let files = negotiate_files_down(&mut connection).await?; - download_files(files, &mut stream, &crypt).await?; + download_files(&mut connection, files).await?; return Ok(()); } @@ -64,28 +60,15 @@ pub async fn get_file_handles(file_paths: &Vec) -> Result, - stream: &mut MessageStream, - crypt: &Crypt, ) -> Result> { let files = file_handles.iter().map(|fh| fh.to_file_info()).collect(); - let msg = EncryptedMessage::FileNegotiationMessage(FileNegotiationPayload { files }); - let server_msg = msg.to_encrypted_message(crypt)?; - println!("server_msg encrypted: {:?}", server_msg); - stream.send(server_msg).await?; - let reply_payload = match stream.next().await { - Some(Ok(msg)) => match msg { - Message::EncryptedMessage(response) => response, - }, - _ => { - return Err(anyhow!("No response to negotiation message")); - } - }; - let plaintext_reply = EncryptedMessage::from_encrypted_message(crypt, &reply_payload)?; - let requested_paths: Vec = match plaintext_reply { - EncryptedMessage::FileNegotiationMessage(fnm) => { - fnm.files.into_iter().map(|f| f.path).collect() - } + let msg = Message::FileNegotiationMessage(FileNegotiationPayload { files }); + conn.send_msg(msg).await?; + let reply = conn.await_msg().await?; + let requested_paths: Vec = match reply { + Message::FileNegotiationMessage(fnm) => fnm.files.into_iter().map(|f| f.path).collect(), _ => return Err(anyhow!("Expecting file negotiation message back")), }; Ok(file_handles @@ -94,21 +77,10 @@ pub async fn negotiate_files_up( .collect()) } -pub async fn negotiate_files_down( - stream: &mut MessageStream, - crypt: &Crypt, -) -> Result> { - let file_offer = match stream.next().await { - Some(Ok(msg)) => match msg { - Message::EncryptedMessage(response) => response, - }, - _ => { - return Err(anyhow!("No response to negotiation message")); - } - }; - let plaintext_offer = EncryptedMessage::from_encrypted_message(crypt, &file_offer)?; - let requested_infos: Vec = match plaintext_offer { - EncryptedMessage::FileNegotiationMessage(fnm) => fnm.files, +pub async fn negotiate_files_down(conn: &mut Connection) -> Result> { + let offer = conn.await_msg().await?; + let requested_infos: Vec = match offer { + Message::FileNegotiationMessage(fnm) => fnm.files, _ => return Err(anyhow!("Expecting file negotiation message back")), }; let mut stdin = FramedRead::new(io::stdin(), LinesCodec::new()); @@ -123,47 +95,22 @@ pub async fn negotiate_files_down( _ => {} } } - let msg = EncryptedMessage::FileNegotiationMessage(FileNegotiationPayload { + let msg = Message::FileNegotiationMessage(FileNegotiationPayload { files: files.clone(), }); - let server_msg = msg.to_encrypted_message(crypt)?; - stream.send(server_msg).await?; + conn.send_msg(msg).await?; Ok(files) } -pub async fn upload_encrypted_files( - stream: &mut MessageStream, - handles: Vec, - cipher: &Crypt, -) -> Result<()> { - let (tx, mut rx) = mpsc::unbounded_channel::(); +pub async fn upload_encrypted_files(conn: &mut Connection, handles: Vec) -> Result<()> { for mut handle in handles { - let txc = tx.clone(); - tokio::spawn(async move { - let _ = enqueue_file_chunks(&mut handle, txc).await; - }); - } - - loop { - tokio::select! { - Some(msg) = rx.recv() => { - // println!("message received to client.rx {:?}", msg); - let x = msg.to_encrypted_message(cipher)?; - stream.send(x).await? - } - else => { - println!("breaking"); - break - }, - } + enqueue_file_chunks(conn, &mut handle).await?; } + println!("Files uploaded."); Ok(()) } -pub async fn enqueue_file_chunks( - fh: &mut FileHandle, - tx: mpsc::UnboundedSender, -) -> Result<()> { +pub async fn enqueue_file_chunks(conn: &mut Connection, fh: &mut FileHandle) -> Result<()> { let mut chunk_num = 0; let mut bytes_read = 1; while bytes_read != 0 { @@ -173,12 +120,12 @@ pub async fn enqueue_file_chunks( if bytes_read != 0 { let chunk = buf.freeze(); let file_info = fh.to_file_info(); - let ftp = EncryptedMessage::FileTransferMessage(FileTransferPayload { + let ftp = Message::FileTransferMessage(FileTransferPayload { chunk, chunk_num, file_info, }); - tx.send(ftp)?; + conn.send_msg(ftp).await?; chunk_num += 1; } } @@ -186,11 +133,7 @@ pub async fn enqueue_file_chunks( Ok(()) } -pub async fn download_files( - file_infos: Vec, - stream: &mut MessageStream, - cipher: &Crypt, -) -> Result<()> { +pub async fn download_files(conn: &mut Connection, file_infos: Vec) -> Result<()> { // for each file_info let mut info_handles: HashMap> = HashMap::new(); for fi in file_infos { @@ -201,29 +144,24 @@ pub async fn download_files( } loop { tokio::select! { - result = stream.next() => match result { - Some(Ok(Message::EncryptedMessage(payload))) => { - let ec = EncryptedMessage::from_encrypted_message(cipher, &payload)?; - // println!("encrypted message received! {:?}", ec); - match ec { - EncryptedMessage::FileTransferMessage(payload) => { - // println!("matched file transfer message"); + result = conn.await_msg() => match result { + Ok(msg) => { + match msg { + Message::FileTransferMessage(payload) => { if let Some(tx) = info_handles.get(&payload.file_info.path) { - // println!("matched on filetype, sending to tx"); tx.send((payload.chunk_num, payload.chunk))? - }; + } }, - _ => {println!("wrong msg")} + _ => { + println!("Wrong message type"); + return Err(anyhow!("wrong message type")); + } } - } - Some(Err(e)) => { - println!("Error {:?}", e); - } - None => break, + }, + Err(e) => return Err(anyhow!(e.to_string())), } } } - Ok(()) } pub async fn download_file( diff --git a/src/connection.rs b/src/connection.rs index 413ad0b..3bca2a2 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -1,4 +1,38 @@ +use crate::crypto::Crypt; +use crate::message::{Message, MessageStream}; +use anyhow::{anyhow, Result}; +use futures::{SinkExt, StreamExt}; +use tokio::net::TcpStream; + pub struct Connection { - ms: MessageStream; - crypt: Crypt -} \ No newline at end of file + ms: MessageStream, + crypt: Crypt, +} + +impl Connection { + pub fn new(socket: TcpStream, key: Vec) -> Self { + let ms = Message::to_stream(socket); + let crypt = Crypt::new(&key); + Connection { ms, crypt } + } + pub async fn send_msg(&mut self, msg: Message) -> Result<()> { + let msg = msg.serialize()?; + let bytes = self.crypt.encrypt(msg)?; + match self.ms.send(bytes).await { + Ok(_) => Ok(()), + Err(e) => Err(anyhow!(e.to_string())), + } + } + + pub async fn await_msg(&mut self) -> Result { + match self.ms.next().await { + Some(Ok(msg)) => { + let decrypted_bytes = self.crypt.decrypt(msg.freeze())?; + Message::deserialize(decrypted_bytes) + } + _ => { + return Err(anyhow!("No response to negotiation message")); + } + } + } +} diff --git a/src/crypto.rs b/src/crypto.rs index 6fb97d6..fc78e80 100644 --- a/src/crypto.rs +++ b/src/crypto.rs @@ -1,9 +1,8 @@ use crate::conf::NONCE_SIZE_IN_BYTES; -use crate::message::EncryptedPayload; use aes_gcm::aead::{Aead, NewAead}; use aes_gcm::{Aes256Gcm, Key, Nonce}; // Or `Aes128Gcm` use anyhow::{anyhow, Result}; -use bytes::Bytes; +use bytes::{Bytes, BytesMut}; use rand::{thread_rng, Rng}; @@ -19,25 +18,28 @@ impl Crypt { } } - pub fn encrypt(&self, body: &Vec) -> Result { + pub fn encrypt(&self, plaintext: Bytes) -> Result { let mut arr = [0u8; NONCE_SIZE_IN_BYTES]; thread_rng().try_fill(&mut arr[..])?; let nonce = Nonce::from_slice(&arr); - let plaintext = body.as_ref(); - match self.cipher.encrypt(nonce, plaintext) { - Ok(body) => Ok(EncryptedPayload { - nonce: arr.to_vec(), - body, - }), - Err(_) => Err(anyhow!("Encryption error")), + match self.cipher.encrypt(nonce, plaintext.as_ref()) { + Ok(body) => { + let mut buffer = BytesMut::with_capacity(NONCE_SIZE_IN_BYTES + body.len()); + buffer.extend_from_slice(nonce); + buffer.extend_from_slice(&body); + Ok(buffer.freeze()) + } + Err(e) => Err(anyhow!(e.to_string())), } } - pub fn decrypt(&self, payload: &EncryptedPayload) -> Result { - let nonce = Nonce::from_slice(payload.nonce.as_ref()); - match self.cipher.decrypt(nonce, payload.body.as_ref()) { + pub fn decrypt(&self, body: Bytes) -> Result { + let mut body = body; + let nonce_bytes = body.split_to(NONCE_SIZE_IN_BYTES); + let nonce = Nonce::from_slice(&nonce_bytes); + match self.cipher.decrypt(nonce, body.as_ref()) { Ok(payload) => Ok(Bytes::from(payload)), - Err(_) => Err(anyhow!("Decryption error")), + Err(e) => Err(anyhow!(e.to_string())), } } } diff --git a/src/main.rs b/src/main.rs index c33a731..e481fd1 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ mod cli; mod client; mod conf; +mod connection; mod crypto; mod file; mod handshake; diff --git a/src/message.rs b/src/message.rs index b815a1d..5abf618 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,28 +1,13 @@ -use crate::crypto::Crypt; use crate::file::FileInfo; use anyhow::{anyhow, Result}; use bytes::Bytes; use serde::{Deserialize, Serialize}; -use std::error::Error; -use std::fmt; use tokio::net::TcpStream; -use tokio_serde::{formats::SymmetricalBincode, SymmetricallyFramed}; use tokio_util::codec::{Framed, LengthDelimitedCodec}; #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] pub enum Message { - EncryptedMessage(EncryptedPayload), -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct EncryptedPayload { - pub nonce: Vec, - pub body: Vec, -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub enum EncryptedMessage { FileNegotiationMessage(FileNegotiationPayload), FileTransferMessage(FileTransferPayload), } @@ -39,62 +24,25 @@ pub struct FileTransferPayload { pub chunk: Bytes, } -impl EncryptedMessage { - pub fn from_encrypted_message(crypt: &Crypt, payload: &EncryptedPayload) -> Result { - let raw = crypt.decrypt(payload)?; - let res = match bincode::deserialize(raw.as_ref()) { - Ok(result) => result, - Err(e) => { - println!("deserialize error {:?}", e); - return Err(anyhow!("deser error")); - } - }; - Ok(res) +impl Message { + pub fn serialize(&self) -> Result { + match bincode::serialize(&self) { + Ok(vec) => Ok(Bytes::from(vec)), + Err(e) => Err(anyhow!(e.to_string())), + } } - pub fn to_encrypted_message(&self, crypt: &Crypt) -> Result { - let raw = match bincode::serialize(&self) { - Ok(result) => result, - Err(e) => { - println!("serialize error {:?}", e); - return Err(anyhow!("serialize error")); - } - }; - let payload = crypt.encrypt(&raw)?; - Ok(Message::EncryptedMessage(payload)) - } -} - -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub enum RuckError { - NotHandshake, - SenderNotConnected, - SenderAlreadyConnected, - PairDisconnected, -} - -impl fmt::Display for RuckError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "RuckError is here!") - } -} - -impl Error for RuckError { - fn source(&self) -> Option<&(dyn Error + 'static)> { - Some(self) + pub fn deserialize(bytes: Bytes) -> Result { + match bincode::deserialize(bytes.as_ref()) { + Ok(msg) => Ok(msg), + Err(e) => Err(anyhow!(e.to_string())), + } } } impl Message { pub fn to_stream(stream: TcpStream) -> MessageStream { - tokio_serde::SymmetricallyFramed::new( - Framed::new(stream, LengthDelimitedCodec::new()), - tokio_serde::formats::SymmetricalBincode::::default(), - ) + Framed::new(stream, LengthDelimitedCodec::new()) } } -pub type MessageStream = SymmetricallyFramed< - Framed, - Message, - SymmetricalBincode, ->; +pub type MessageStream = Framed; diff --git a/src/server.rs b/src/server.rs index 3509c47..f4cf567 100644 --- a/src/server.rs +++ b/src/server.rs @@ -57,7 +57,9 @@ impl Client { // Sender - needs to wait for the incoming msg to look up peer_tx None => { tokio::select! { + // Client reads handshake message sent over channel Some(msg) = client.rx.recv() => { + // Writes parnter handshake message over wire client.socket.write_all(&msg[..]).await? } } @@ -126,9 +128,8 @@ pub async fn handle_connection( }, Ok(n) => { let b = BytesMut::from(&client_buffer[0..n]).freeze(); - // println!("reading more = {:?}", b); + client.peer_tx.send(b)?; client_buffer.clear(); - client.peer_tx.send(b)? }, Err(e) => { println!("Error {:?}", e);