Refactor connection

This commit is contained in:
Donald Knuth 2022-08-30 09:53:16 -04:00
parent e79510cafe
commit 3c7c5fe29d
6 changed files with 108 additions and 184 deletions

View file

@ -1,10 +1,8 @@
use crate::conf::BUFFER_SIZE;
use crate::crypto::Crypt;
use crate::connection::Connection;
use crate::file::{to_size_string, FileHandle, FileInfo};
use crate::handshake::Handshake;
use crate::message::{
EncryptedMessage, FileNegotiationPayload, FileTransferPayload, Message, MessageStream,
};
use crate::message::{FileNegotiationPayload, FileTransferPayload, Message};
use anyhow::{anyhow, Result};
use bytes::{Bytes, BytesMut};
@ -30,13 +28,12 @@ pub async fn send(file_paths: &Vec<PathBuf>, password: &String) -> Result<()> {
// Complete handshake, returning key used for encryption
let (socket, key) = handshake.negotiate(socket, s1).await?;
let mut stream = Message::to_stream(socket);
let crypt = Crypt::new(&key);
let mut connection = Connection::new(socket, key);
// Complete file negotiation
let handles = negotiate_files_up(handles, &mut stream, &crypt).await?;
let handles = negotiate_files_up(&mut connection, handles).await?;
// Upload negotiated files
upload_encrypted_files(&mut stream, handles, &crypt).await?;
upload_encrypted_files(&mut connection, handles).await?;
println!("Done uploading.");
// Exit
@ -47,11 +44,10 @@ pub async fn receive(password: &String) -> Result<()> {
let socket = TcpStream::connect("127.0.0.1:8080").await?;
let (handshake, s1) = Handshake::from_password(password);
let (socket, key) = handshake.negotiate(socket, s1).await?;
let mut stream = Message::to_stream(socket);
let crypt = Crypt::new(&key);
let files = negotiate_files_down(&mut stream, &crypt).await?;
let mut connection = Connection::new(socket, key);
let files = negotiate_files_down(&mut connection).await?;
download_files(files, &mut stream, &crypt).await?;
download_files(&mut connection, files).await?;
return Ok(());
}
@ -64,28 +60,15 @@ pub async fn get_file_handles(file_paths: &Vec<PathBuf>) -> Result<Vec<FileHandl
}
pub async fn negotiate_files_up(
conn: &mut Connection,
file_handles: Vec<FileHandle>,
stream: &mut MessageStream,
crypt: &Crypt,
) -> Result<Vec<FileHandle>> {
let files = file_handles.iter().map(|fh| fh.to_file_info()).collect();
let msg = EncryptedMessage::FileNegotiationMessage(FileNegotiationPayload { files });
let server_msg = msg.to_encrypted_message(crypt)?;
println!("server_msg encrypted: {:?}", server_msg);
stream.send(server_msg).await?;
let reply_payload = match stream.next().await {
Some(Ok(msg)) => match msg {
Message::EncryptedMessage(response) => response,
},
_ => {
return Err(anyhow!("No response to negotiation message"));
}
};
let plaintext_reply = EncryptedMessage::from_encrypted_message(crypt, &reply_payload)?;
let requested_paths: Vec<PathBuf> = match plaintext_reply {
EncryptedMessage::FileNegotiationMessage(fnm) => {
fnm.files.into_iter().map(|f| f.path).collect()
}
let msg = Message::FileNegotiationMessage(FileNegotiationPayload { files });
conn.send_msg(msg).await?;
let reply = conn.await_msg().await?;
let requested_paths: Vec<PathBuf> = match reply {
Message::FileNegotiationMessage(fnm) => fnm.files.into_iter().map(|f| f.path).collect(),
_ => return Err(anyhow!("Expecting file negotiation message back")),
};
Ok(file_handles
@ -94,21 +77,10 @@ pub async fn negotiate_files_up(
.collect())
}
pub async fn negotiate_files_down(
stream: &mut MessageStream,
crypt: &Crypt,
) -> Result<Vec<FileInfo>> {
let file_offer = match stream.next().await {
Some(Ok(msg)) => match msg {
Message::EncryptedMessage(response) => response,
},
_ => {
return Err(anyhow!("No response to negotiation message"));
}
};
let plaintext_offer = EncryptedMessage::from_encrypted_message(crypt, &file_offer)?;
let requested_infos: Vec<FileInfo> = match plaintext_offer {
EncryptedMessage::FileNegotiationMessage(fnm) => fnm.files,
pub async fn negotiate_files_down(conn: &mut Connection) -> Result<Vec<FileInfo>> {
let offer = conn.await_msg().await?;
let requested_infos: Vec<FileInfo> = match offer {
Message::FileNegotiationMessage(fnm) => fnm.files,
_ => return Err(anyhow!("Expecting file negotiation message back")),
};
let mut stdin = FramedRead::new(io::stdin(), LinesCodec::new());
@ -123,47 +95,22 @@ pub async fn negotiate_files_down(
_ => {}
}
}
let msg = EncryptedMessage::FileNegotiationMessage(FileNegotiationPayload {
let msg = Message::FileNegotiationMessage(FileNegotiationPayload {
files: files.clone(),
});
let server_msg = msg.to_encrypted_message(crypt)?;
stream.send(server_msg).await?;
conn.send_msg(msg).await?;
Ok(files)
}
pub async fn upload_encrypted_files(
stream: &mut MessageStream,
handles: Vec<FileHandle>,
cipher: &Crypt,
) -> Result<()> {
let (tx, mut rx) = mpsc::unbounded_channel::<EncryptedMessage>();
pub async fn upload_encrypted_files(conn: &mut Connection, handles: Vec<FileHandle>) -> Result<()> {
for mut handle in handles {
let txc = tx.clone();
tokio::spawn(async move {
let _ = enqueue_file_chunks(&mut handle, txc).await;
});
}
loop {
tokio::select! {
Some(msg) = rx.recv() => {
// println!("message received to client.rx {:?}", msg);
let x = msg.to_encrypted_message(cipher)?;
stream.send(x).await?
}
else => {
println!("breaking");
break
},
}
enqueue_file_chunks(conn, &mut handle).await?;
}
println!("Files uploaded.");
Ok(())
}
pub async fn enqueue_file_chunks(
fh: &mut FileHandle,
tx: mpsc::UnboundedSender<EncryptedMessage>,
) -> Result<()> {
pub async fn enqueue_file_chunks(conn: &mut Connection, fh: &mut FileHandle) -> Result<()> {
let mut chunk_num = 0;
let mut bytes_read = 1;
while bytes_read != 0 {
@ -173,12 +120,12 @@ pub async fn enqueue_file_chunks(
if bytes_read != 0 {
let chunk = buf.freeze();
let file_info = fh.to_file_info();
let ftp = EncryptedMessage::FileTransferMessage(FileTransferPayload {
let ftp = Message::FileTransferMessage(FileTransferPayload {
chunk,
chunk_num,
file_info,
});
tx.send(ftp)?;
conn.send_msg(ftp).await?;
chunk_num += 1;
}
}
@ -186,11 +133,7 @@ pub async fn enqueue_file_chunks(
Ok(())
}
pub async fn download_files(
file_infos: Vec<FileInfo>,
stream: &mut MessageStream,
cipher: &Crypt,
) -> Result<()> {
pub async fn download_files(conn: &mut Connection, file_infos: Vec<FileInfo>) -> Result<()> {
// for each file_info
let mut info_handles: HashMap<PathBuf, mpsc::UnboundedSender<(u64, Bytes)>> = HashMap::new();
for fi in file_infos {
@ -201,29 +144,24 @@ pub async fn download_files(
}
loop {
tokio::select! {
result = stream.next() => match result {
Some(Ok(Message::EncryptedMessage(payload))) => {
let ec = EncryptedMessage::from_encrypted_message(cipher, &payload)?;
// println!("encrypted message received! {:?}", ec);
match ec {
EncryptedMessage::FileTransferMessage(payload) => {
// println!("matched file transfer message");
result = conn.await_msg() => match result {
Ok(msg) => {
match msg {
Message::FileTransferMessage(payload) => {
if let Some(tx) = info_handles.get(&payload.file_info.path) {
// println!("matched on filetype, sending to tx");
tx.send((payload.chunk_num, payload.chunk))?
};
}
},
_ => {println!("wrong msg")}
_ => {
println!("Wrong message type");
return Err(anyhow!("wrong message type"));
}
}
}
Some(Err(e)) => {
println!("Error {:?}", e);
}
None => break,
},
Err(e) => return Err(anyhow!(e.to_string())),
}
}
}
Ok(())
}
pub async fn download_file(

View file

@ -1,4 +1,38 @@
use crate::crypto::Crypt;
use crate::message::{Message, MessageStream};
use anyhow::{anyhow, Result};
use futures::{SinkExt, StreamExt};
use tokio::net::TcpStream;
pub struct Connection {
ms: MessageStream;
crypt: Crypt
}
ms: MessageStream,
crypt: Crypt,
}
impl Connection {
pub fn new(socket: TcpStream, key: Vec<u8>) -> Self {
let ms = Message::to_stream(socket);
let crypt = Crypt::new(&key);
Connection { ms, crypt }
}
pub async fn send_msg(&mut self, msg: Message) -> Result<()> {
let msg = msg.serialize()?;
let bytes = self.crypt.encrypt(msg)?;
match self.ms.send(bytes).await {
Ok(_) => Ok(()),
Err(e) => Err(anyhow!(e.to_string())),
}
}
pub async fn await_msg(&mut self) -> Result<Message> {
match self.ms.next().await {
Some(Ok(msg)) => {
let decrypted_bytes = self.crypt.decrypt(msg.freeze())?;
Message::deserialize(decrypted_bytes)
}
_ => {
return Err(anyhow!("No response to negotiation message"));
}
}
}
}

View file

@ -1,9 +1,8 @@
use crate::conf::NONCE_SIZE_IN_BYTES;
use crate::message::EncryptedPayload;
use aes_gcm::aead::{Aead, NewAead};
use aes_gcm::{Aes256Gcm, Key, Nonce}; // Or `Aes128Gcm`
use anyhow::{anyhow, Result};
use bytes::Bytes;
use bytes::{Bytes, BytesMut};
use rand::{thread_rng, Rng};
@ -19,25 +18,28 @@ impl Crypt {
}
}
pub fn encrypt(&self, body: &Vec<u8>) -> Result<EncryptedPayload> {
pub fn encrypt(&self, plaintext: Bytes) -> Result<Bytes> {
let mut arr = [0u8; NONCE_SIZE_IN_BYTES];
thread_rng().try_fill(&mut arr[..])?;
let nonce = Nonce::from_slice(&arr);
let plaintext = body.as_ref();
match self.cipher.encrypt(nonce, plaintext) {
Ok(body) => Ok(EncryptedPayload {
nonce: arr.to_vec(),
body,
}),
Err(_) => Err(anyhow!("Encryption error")),
match self.cipher.encrypt(nonce, plaintext.as_ref()) {
Ok(body) => {
let mut buffer = BytesMut::with_capacity(NONCE_SIZE_IN_BYTES + body.len());
buffer.extend_from_slice(nonce);
buffer.extend_from_slice(&body);
Ok(buffer.freeze())
}
Err(e) => Err(anyhow!(e.to_string())),
}
}
pub fn decrypt(&self, payload: &EncryptedPayload) -> Result<Bytes> {
let nonce = Nonce::from_slice(payload.nonce.as_ref());
match self.cipher.decrypt(nonce, payload.body.as_ref()) {
pub fn decrypt(&self, body: Bytes) -> Result<Bytes> {
let mut body = body;
let nonce_bytes = body.split_to(NONCE_SIZE_IN_BYTES);
let nonce = Nonce::from_slice(&nonce_bytes);
match self.cipher.decrypt(nonce, body.as_ref()) {
Ok(payload) => Ok(Bytes::from(payload)),
Err(_) => Err(anyhow!("Decryption error")),
Err(e) => Err(anyhow!(e.to_string())),
}
}
}

View file

@ -1,6 +1,7 @@
mod cli;
mod client;
mod conf;
mod connection;
mod crypto;
mod file;
mod handshake;

View file

@ -1,28 +1,13 @@
use crate::crypto::Crypt;
use crate::file::FileInfo;
use anyhow::{anyhow, Result};
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use std::error::Error;
use std::fmt;
use tokio::net::TcpStream;
use tokio_serde::{formats::SymmetricalBincode, SymmetricallyFramed};
use tokio_util::codec::{Framed, LengthDelimitedCodec};
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Message {
EncryptedMessage(EncryptedPayload),
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct EncryptedPayload {
pub nonce: Vec<u8>,
pub body: Vec<u8>,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum EncryptedMessage {
FileNegotiationMessage(FileNegotiationPayload),
FileTransferMessage(FileTransferPayload),
}
@ -39,62 +24,25 @@ pub struct FileTransferPayload {
pub chunk: Bytes,
}
impl EncryptedMessage {
pub fn from_encrypted_message(crypt: &Crypt, payload: &EncryptedPayload) -> Result<Self> {
let raw = crypt.decrypt(payload)?;
let res = match bincode::deserialize(raw.as_ref()) {
Ok(result) => result,
Err(e) => {
println!("deserialize error {:?}", e);
return Err(anyhow!("deser error"));
}
};
Ok(res)
impl Message {
pub fn serialize(&self) -> Result<Bytes> {
match bincode::serialize(&self) {
Ok(vec) => Ok(Bytes::from(vec)),
Err(e) => Err(anyhow!(e.to_string())),
}
}
pub fn to_encrypted_message(&self, crypt: &Crypt) -> Result<Message> {
let raw = match bincode::serialize(&self) {
Ok(result) => result,
Err(e) => {
println!("serialize error {:?}", e);
return Err(anyhow!("serialize error"));
}
};
let payload = crypt.encrypt(&raw)?;
Ok(Message::EncryptedMessage(payload))
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum RuckError {
NotHandshake,
SenderNotConnected,
SenderAlreadyConnected,
PairDisconnected,
}
impl fmt::Display for RuckError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "RuckError is here!")
}
}
impl Error for RuckError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
Some(self)
pub fn deserialize(bytes: Bytes) -> Result<Self> {
match bincode::deserialize(bytes.as_ref()) {
Ok(msg) => Ok(msg),
Err(e) => Err(anyhow!(e.to_string())),
}
}
}
impl Message {
pub fn to_stream(stream: TcpStream) -> MessageStream {
tokio_serde::SymmetricallyFramed::new(
Framed::new(stream, LengthDelimitedCodec::new()),
tokio_serde::formats::SymmetricalBincode::<Message>::default(),
)
Framed::new(stream, LengthDelimitedCodec::new())
}
}
pub type MessageStream = SymmetricallyFramed<
Framed<TcpStream, LengthDelimitedCodec>,
Message,
SymmetricalBincode<Message>,
>;
pub type MessageStream = Framed<TcpStream, LengthDelimitedCodec>;

View file

@ -57,7 +57,9 @@ impl Client {
// Sender - needs to wait for the incoming msg to look up peer_tx
None => {
tokio::select! {
// Client reads handshake message sent over channel
Some(msg) = client.rx.recv() => {
// Writes parnter handshake message over wire
client.socket.write_all(&msg[..]).await?
}
}
@ -126,9 +128,8 @@ pub async fn handle_connection(
},
Ok(n) => {
let b = BytesMut::from(&client_buffer[0..n]).freeze();
// println!("reading more = {:?}", b);
client.peer_tx.send(b)?;
client_buffer.clear();
client.peer_tx.send(b)?
},
Err(e) => {
println!("Error {:?}", e);