Back to parity

This commit is contained in:
Donald Knuth 2022-09-02 20:29:18 -04:00
parent e0567d8479
commit 6bc011c910
4 changed files with 73 additions and 90 deletions

View file

@ -100,12 +100,7 @@ pub async fn create_files(desired_files: Vec<FileOffer>) -> Result<Vec<StdFileHa
.file_name()
.unwrap_or(OsStr::new("random.txt"));
let file = File::create(filename).await?;
let std_file = file.into_std().await;
let std_file_handle = StdFileHandle {
id: desired_file.id,
file: std_file,
start: 0,
};
let std_file_handle = StdFileHandle::new(desired_file.id, file, 0).await?;
v.push(std_file_handle)
}
return Ok(v);

View file

@ -3,4 +3,3 @@ pub const HANDSHAKE_MSG_SIZE: usize = 33; // generated by Spake2
pub const PER_CLIENT_BUFFER: usize = 1024 * 64; // buffer size allocated by server for each client
pub const BUFFER_SIZE: usize = 1024 * 64; // chunk size for files sent over wire
pub const NONCE_SIZE: usize = 96 / 8; // used for every encrypted message
pub const CHUNK_HEADER_SIZE: usize = 10; // used for every chunk header

View file

@ -1,16 +1,17 @@
use crate::conf::{BUFFER_SIZE, CHUNK_HEADER_SIZE};
use crate::conf::BUFFER_SIZE;
use crate::crypto::Crypt;
use crate::file::{ChunkHeader, FileHandle, StdFileHandle};
use crate::message::{Message, MessageStream};
use crate::file::{ChunkHeader, StdFileHandle};
use crate::message::{FileTransferPayload, Message, MessageStream};
use anyhow::{anyhow, Result};
use futures::{SinkExt, StreamExt};
use tokio::net::TcpStream;
use bytes::{Bytes, BytesMut};
use flate2::write::{GzDecoder, GzEncoder};
use flate2::bufread::GzEncoder;
use flate2::write::GzDecoder;
use flate2::Compression;
use std::io::{Read, Write};
use std::io::{BufReader, Read, Write};
pub struct Connection {
ms: MessageStream,
@ -48,61 +49,64 @@ impl Connection {
}
}
pub async fn upload_files(mut self, handles: Vec<StdFileHandle>) -> Result<()> {
let mut socket = self.ms.into_inner().into_std()?;
tokio::task::spawn_blocking(move || {
for mut handle in handles {
let mut buffer = BytesMut::with_capacity(BUFFER_SIZE);
let mut start = handle.start;
loop {
let end =
FileHandle::to_message(handle.id, &mut handle.file, &mut buffer, start)?;
let mut compressor = GzEncoder::new(Vec::new(), Compression::fast());
compressor.write_all(&buffer[..])?;
let compressed_bytes = compressor.finish()?;
let encrypted_bytes = self.crypt.encrypt(Bytes::from(compressed_bytes))?;
start = end;
socket.write(&encrypted_bytes[..])?;
if end == 0 {
break;
}
pub async fn upload_file(&mut self, handle: StdFileHandle) -> Result<()> {
let mut buffer = [0; BUFFER_SIZE];
let reader = BufReader::new(handle.file);
let mut gz = GzEncoder::new(reader, Compression::fast());
loop {
match gz.read(&mut buffer) {
Ok(0) => {
break;
}
Ok(n) => {
let message = Message::FileTransfer(FileTransferPayload {
chunk: BytesMut::from(&buffer[..n]).freeze(),
chunk_header: ChunkHeader {
id: handle.id,
start: 0,
},
});
self.send_msg(message).await?;
}
Err(e) => return Err(anyhow!(e.to_string())),
}
Ok(())
})
.await?
}
Ok(())
}
pub async fn upload_files(mut self, handles: Vec<StdFileHandle>) -> Result<()> {
for handle in handles {
self.upload_file(handle).await?;
}
Ok(())
}
pub async fn download_files(mut self, handles: Vec<StdFileHandle>) -> Result<()> {
let mut socket = self.ms.into_inner().into_std()?;
tokio::task::spawn_blocking(move || {
for mut handle in handles {
let mut buffer = BytesMut::with_capacity(BUFFER_SIZE);
let mut start = handle.start;
loop {
// read bytes
match socket.read(&mut buffer) {
Ok(0) => {
break;
}
Ok(n) => {
let decrypted_bytes =
self.crypt.decrypt(Bytes::from(&mut buffer[0..n]))?;
let mut writer = Vec::new();
let mut decompressor = GzDecoder::new(writer);
decompressor.write_all(&decrypted_bytes[..])?;
decompressor.try_finish()?;
writer = decompressor.finish()?;
let chunk_header: ChunkHeader =
bincode::deserialize(&writer[..CHUNK_HEADER_SIZE])?;
handle.file.write_all(&writer)
}
Err(e) => return Err(anyhow!(e.to_string())),
};
for handle in handles {
self.download_file(handle).await?;
}
Ok(())
}
pub async fn download_file(&mut self, handle: StdFileHandle) -> Result<()> {
let mut decoder = GzDecoder::new(handle.file);
loop {
let msg = self.await_msg().await?;
match msg {
Message::FileTransfer(payload) => {
if payload.chunk_header.id != handle.id {
return Err(anyhow!("Wrong file"));
}
if payload.chunk.len() == 0 {
break;
}
decoder.write_all(&payload.chunk[..])?
}
_ => return Err(anyhow!("Expecting file transfer message")),
}
Ok(())
})
.await?
}
decoder.finish()?;
println!("Done downloading file.");
Ok(())
}
}

View file

@ -5,8 +5,7 @@ use serde::{Deserialize, Serialize};
use std::fs::Metadata;
use std::path::PathBuf;
use bytes::BytesMut;
use std::io::{Read, Seek, SeekFrom};
use std::io::{Seek, SeekFrom};
use tokio::fs::File;
@ -29,6 +28,18 @@ pub struct StdFileHandle {
pub start: u64,
}
impl StdFileHandle {
pub async fn new(id: u8, file: File, start: u64) -> Result<StdFileHandle> {
let mut std_file = file.into_std().await;
std_file.seek(SeekFrom::Start(start))?;
Ok(StdFileHandle {
id: id,
file: std_file,
start: start,
})
}
}
pub struct FileHandle {
pub id: u8,
pub file: File,
@ -69,13 +80,7 @@ impl FileHandle {
}
async fn to_std(self, chunk_header: &ChunkHeader) -> Result<StdFileHandle> {
let mut std_file = self.file.into_std().await;
std_file.seek(SeekFrom::Start(chunk_header.start))?;
Ok(StdFileHandle {
id: self.id,
file: std_file,
start: chunk_header.start,
})
StdFileHandle::new(self.id, self.file, chunk_header.start).await
}
pub fn to_file_offer(&self) -> FileOffer {
@ -94,26 +99,6 @@ impl FileHandle {
let handles = try_join_all(tasks).await?;
Ok(handles)
}
pub fn to_message(
id: u8,
file: &mut std::fs::File,
buffer: &mut BytesMut,
start: u64,
) -> Result<u64> {
// reads the next chunk of the file
// packs it into the buffer, with the header taking up the first X bytes
let chunk_header = ChunkHeader { id, start };
let chunk_bytes = bincode::serialize(&chunk_header)?;
println!(
"chunk_bytes = {:?}, len = {:?}",
chunk_bytes.clone(),
chunk_bytes.len()
);
buffer.extend_from_slice(&chunk_bytes[..]);
let n = file.read(buffer)? as u64;
Ok(n)
}
}
const SUFFIX: [&'static str; 9] = ["B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"];