diff --git a/src/client.rs b/src/client.rs index 69bdcac..61e6fc2 100644 --- a/src/client.rs +++ b/src/client.rs @@ -100,12 +100,7 @@ pub async fn create_files(desired_files: Vec) -> Result) -> Result<()> { - let mut socket = self.ms.into_inner().into_std()?; - tokio::task::spawn_blocking(move || { - for mut handle in handles { - let mut buffer = BytesMut::with_capacity(BUFFER_SIZE); - let mut start = handle.start; - loop { - let end = - FileHandle::to_message(handle.id, &mut handle.file, &mut buffer, start)?; - let mut compressor = GzEncoder::new(Vec::new(), Compression::fast()); - compressor.write_all(&buffer[..])?; - let compressed_bytes = compressor.finish()?; - let encrypted_bytes = self.crypt.encrypt(Bytes::from(compressed_bytes))?; - start = end; - socket.write(&encrypted_bytes[..])?; - if end == 0 { - break; - } + pub async fn upload_file(&mut self, handle: StdFileHandle) -> Result<()> { + let mut buffer = [0; BUFFER_SIZE]; + let reader = BufReader::new(handle.file); + let mut gz = GzEncoder::new(reader, Compression::fast()); + loop { + match gz.read(&mut buffer) { + Ok(0) => { + break; } + Ok(n) => { + let message = Message::FileTransfer(FileTransferPayload { + chunk: BytesMut::from(&buffer[..n]).freeze(), + chunk_header: ChunkHeader { + id: handle.id, + start: 0, + }, + }); + self.send_msg(message).await?; + } + Err(e) => return Err(anyhow!(e.to_string())), } - Ok(()) - }) - .await? + } + Ok(()) + } + + pub async fn upload_files(mut self, handles: Vec) -> Result<()> { + for handle in handles { + self.upload_file(handle).await?; + } + Ok(()) } pub async fn download_files(mut self, handles: Vec) -> Result<()> { - let mut socket = self.ms.into_inner().into_std()?; - tokio::task::spawn_blocking(move || { - for mut handle in handles { - let mut buffer = BytesMut::with_capacity(BUFFER_SIZE); - let mut start = handle.start; - loop { - // read bytes - match socket.read(&mut buffer) { - Ok(0) => { - break; - } - Ok(n) => { - let decrypted_bytes = - self.crypt.decrypt(Bytes::from(&mut buffer[0..n]))?; - let mut writer = Vec::new(); - let mut decompressor = GzDecoder::new(writer); - decompressor.write_all(&decrypted_bytes[..])?; - decompressor.try_finish()?; - writer = decompressor.finish()?; - let chunk_header: ChunkHeader = - bincode::deserialize(&writer[..CHUNK_HEADER_SIZE])?; - handle.file.write_all(&writer) - } - Err(e) => return Err(anyhow!(e.to_string())), - }; + for handle in handles { + self.download_file(handle).await?; + } + Ok(()) + } + + pub async fn download_file(&mut self, handle: StdFileHandle) -> Result<()> { + let mut decoder = GzDecoder::new(handle.file); + loop { + let msg = self.await_msg().await?; + match msg { + Message::FileTransfer(payload) => { + if payload.chunk_header.id != handle.id { + return Err(anyhow!("Wrong file")); + } + if payload.chunk.len() == 0 { + break; + } + decoder.write_all(&payload.chunk[..])? } + _ => return Err(anyhow!("Expecting file transfer message")), } - Ok(()) - }) - .await? + } + decoder.finish()?; + println!("Done downloading file."); + Ok(()) } } diff --git a/src/file.rs b/src/file.rs index 4967cd4..c63e094 100644 --- a/src/file.rs +++ b/src/file.rs @@ -5,8 +5,7 @@ use serde::{Deserialize, Serialize}; use std::fs::Metadata; use std::path::PathBuf; -use bytes::BytesMut; -use std::io::{Read, Seek, SeekFrom}; +use std::io::{Seek, SeekFrom}; use tokio::fs::File; @@ -29,6 +28,18 @@ pub struct StdFileHandle { pub start: u64, } +impl StdFileHandle { + pub async fn new(id: u8, file: File, start: u64) -> Result { + let mut std_file = file.into_std().await; + std_file.seek(SeekFrom::Start(start))?; + Ok(StdFileHandle { + id: id, + file: std_file, + start: start, + }) + } +} + pub struct FileHandle { pub id: u8, pub file: File, @@ -69,13 +80,7 @@ impl FileHandle { } async fn to_std(self, chunk_header: &ChunkHeader) -> Result { - let mut std_file = self.file.into_std().await; - std_file.seek(SeekFrom::Start(chunk_header.start))?; - Ok(StdFileHandle { - id: self.id, - file: std_file, - start: chunk_header.start, - }) + StdFileHandle::new(self.id, self.file, chunk_header.start).await } pub fn to_file_offer(&self) -> FileOffer { @@ -94,26 +99,6 @@ impl FileHandle { let handles = try_join_all(tasks).await?; Ok(handles) } - - pub fn to_message( - id: u8, - file: &mut std::fs::File, - buffer: &mut BytesMut, - start: u64, - ) -> Result { - // reads the next chunk of the file - // packs it into the buffer, with the header taking up the first X bytes - let chunk_header = ChunkHeader { id, start }; - let chunk_bytes = bincode::serialize(&chunk_header)?; - println!( - "chunk_bytes = {:?}, len = {:?}", - chunk_bytes.clone(), - chunk_bytes.len() - ); - buffer.extend_from_slice(&chunk_bytes[..]); - let n = file.read(buffer)? as u64; - Ok(n) - } } const SUFFIX: [&'static str; 9] = ["B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"];