diff --git a/src/client.rs b/src/client.rs index 35355ee..f6db18c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -12,9 +12,11 @@ use futures::future::try_join_all; use futures::prelude::*; use futures::stream::FuturesUnordered; use futures::StreamExt; +use std::collections::HashMap; +use std::ffi::OsStr; use std::path::PathBuf; -use std::pin::Pin; -use tokio::io::{self, AsyncReadExt}; +use tokio::fs::File; +use tokio::io::{self, AsyncReadExt, AsyncWriteExt}; use tokio::net::TcpStream; use tokio::sync::mpsc; use tokio_util::codec::{FramedRead, LinesCodec}; @@ -46,6 +48,7 @@ pub async fn send(file_paths: &Vec, password: &String) -> Result<()> { let handles = negotiate_files_up(handles, stream, &cipher).await?; // Upload negotiated files + upload_encrypted_files(stream, handles, &cipher).await?; // Exit Ok(()) @@ -62,6 +65,8 @@ pub async fn receive(password: &String) -> Result<()> { .await?; let files = negotiate_files_down(stream, &cipher).await?; + + download_files(files, stream, &cipher).await?; return Ok(()); } @@ -104,7 +109,10 @@ pub async fn negotiate_files_up( .collect()) } -pub async fn negotiate_files_down(stream: &mut MessageStream, cipher: &Aes256Gcm) -> Result<()> { +pub async fn negotiate_files_down( + stream: &mut MessageStream, + cipher: &Aes256Gcm, +) -> Result<(Vec)> { let file_offer = match stream.next().await { Some(Ok(msg)) => match msg { Message::EncryptedMessage(response) => response, @@ -121,20 +129,22 @@ pub async fn negotiate_files_down(stream: &mut MessageStream, cipher: &Aes256Gcm }; let mut stdin = FramedRead::new(io::stdin(), LinesCodec::new()); let mut files = vec![]; - for path in requested_infos.into_iter() { - let mut reply = prompt_user_input(&mut stdin, &path).await; + for file_info in requested_infos.into_iter() { + let mut reply = prompt_user_input(&mut stdin, &file_info).await; while reply.is_none() { - reply = prompt_user_input(&mut stdin, &path).await; + reply = prompt_user_input(&mut stdin, &file_info).await; } match reply { - Some(true) => files.push(path), + Some(true) => files.push(file_info), _ => {} } } - let msg = EncryptedMessage::FileNegotiationMessage(FileNegotiationPayload { files }); + let msg = EncryptedMessage::FileNegotiationMessage(FileNegotiationPayload { + files: files.clone(), + }); let server_msg = msg.to_encrypted_message(cipher)?; stream.send(server_msg).await?; - Ok(()) + Ok(files) } pub async fn upload_encrypted_files( @@ -143,7 +153,6 @@ pub async fn upload_encrypted_files( cipher: &Aes256Gcm, ) -> Result<()> { let (tx, mut rx) = mpsc::unbounded_channel::(); - //turn foo into something more concrete for mut handle in handles { let txc = tx.clone(); tokio::spawn(async move { @@ -168,25 +177,82 @@ pub async fn enqueue_file_chunks( fh: &mut FileHandle, tx: mpsc::UnboundedSender, ) -> Result<()> { - // let mut buf = BytesMut::with_capacity(BUFFER_SIZE); + let mut chunk_num = 0; + let mut buf: BytesMut; + while { + buf = BytesMut::with_capacity(BUFFER_SIZE); + let n = fh.file.read_exact(&mut buf[..]).await?; + n == 0 + } { + let chunk = buf.freeze(); + let file_info = fh.to_file_info(); + let ftp = EncryptedMessage::FileTransferMessage(FileTransferPayload { + chunk, + chunk_num, + file_info, + }); + tx.send(ftp)?; + chunk_num += 1; + } - // // The `read` method is defined by this trait. - // let mut chunk_num = 0; - // while { - // let n = fh.file.read(&mut buf[..]).await?; - // n == 0 - // } { - // let chunk = buf.freeze(); - // let file_info = fh.to_file_info(); - // let ftp = EncryptedMessage::FileTransferMessage(FileTransferPayload { - // chunk, - // chunk_num, - // file_info, - // }); - // tx.send(ftp); - // chunk_num += 1; - // } + Ok(()) +} +pub async fn download_files( + file_infos: Vec, + stream: &mut MessageStream, + cipher: &Aes256Gcm, +) -> Result<()> { + // for each file_info + let info_handles: HashMap<_, _> = file_infos + .into_iter() + .map(|fi| { + let (tx, rx) = mpsc::unbounded_channel::<(u64, Bytes)>(); + let path = fi.path.clone(); + tokio::spawn(async move { download_file(fi, rx).await }); + (path, tx) + }) + .collect(); + loop { + tokio::select! { + result = stream.next() => match result { + Some(Ok(Message::EncryptedMessage(payload))) => { + let ec = EncryptedMessage::from_encrypted_message(cipher, &payload)?; + match ec { + EncryptedMessage::FileTransferMessage(payload) => { + if let Some(tx) = info_handles.get(&payload.file_info.path) { + tx.send((payload.chunk_num, payload.chunk))? + }; + }, + _ => {println!("wrong msg")} + } + } + Some(Ok(_)) => { + println!("wrong msg"); + } + Some(Err(e)) => { + println!("Error {:?}", e); + } + None => break, + } + } + } + Ok(()) +} + +pub async fn download_file( + file_info: FileInfo, + rx: mpsc::UnboundedReceiver<(u64, Bytes)>, +) -> Result<()> { + let mut rx = rx; + let filename = match file_info.path.file_name() { + Some(f) => f, + None => OsStr::new("random.txt"), + }; + let mut file = File::open(filename).await?; + while let Some((chunk_num, chunk)) = rx.recv().await { + file.write_all(&chunk).await?; + } Ok(()) } @@ -196,7 +262,7 @@ pub async fn prompt_user_input( ) -> Option { let prompt_name = file_info.path.file_name().unwrap(); println!( - "Do you want to download {:?}? It's {:?}. (Y/n)", + "Accept {:?}? ({:?}). (Y/n)", prompt_name, to_size_string(file_info.size) ); diff --git a/src/server.rs b/src/server.rs index 8da91be..c52bc9f 100644 --- a/src/server.rs +++ b/src/server.rs @@ -124,7 +124,6 @@ pub async fn handle_connection( println!("server - received msg from {:?}", addr); let client = Client::new(handshake_payload.id.clone(), state.clone(), stream).await?; let mut client = Client::upgrade(client, state.clone(), handshake_payload).await?; - // add client to state here loop { tokio::select! { Some(msg) = client.rx.recv() => { @@ -142,6 +141,5 @@ pub async fn handle_connection( } } } - // client is disconnected, let's remove them from the state Ok(()) }