mirror of
https://github.com/CompeyDev/ruck.git
synced 2025-01-23 09:48:03 +00:00
Refactor connection
This commit is contained in:
parent
e79510cafe
commit
3c7c5fe29d
6 changed files with 108 additions and 184 deletions
136
src/client.rs
136
src/client.rs
|
@ -1,10 +1,8 @@
|
|||
use crate::conf::BUFFER_SIZE;
|
||||
use crate::crypto::Crypt;
|
||||
use crate::connection::Connection;
|
||||
use crate::file::{to_size_string, FileHandle, FileInfo};
|
||||
use crate::handshake::Handshake;
|
||||
use crate::message::{
|
||||
EncryptedMessage, FileNegotiationPayload, FileTransferPayload, Message, MessageStream,
|
||||
};
|
||||
use crate::message::{FileNegotiationPayload, FileTransferPayload, Message};
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use bytes::{Bytes, BytesMut};
|
||||
|
@ -30,13 +28,12 @@ pub async fn send(file_paths: &Vec<PathBuf>, password: &String) -> Result<()> {
|
|||
// Complete handshake, returning key used for encryption
|
||||
let (socket, key) = handshake.negotiate(socket, s1).await?;
|
||||
|
||||
let mut stream = Message::to_stream(socket);
|
||||
let crypt = Crypt::new(&key);
|
||||
let mut connection = Connection::new(socket, key);
|
||||
// Complete file negotiation
|
||||
let handles = negotiate_files_up(handles, &mut stream, &crypt).await?;
|
||||
let handles = negotiate_files_up(&mut connection, handles).await?;
|
||||
|
||||
// Upload negotiated files
|
||||
upload_encrypted_files(&mut stream, handles, &crypt).await?;
|
||||
upload_encrypted_files(&mut connection, handles).await?;
|
||||
println!("Done uploading.");
|
||||
|
||||
// Exit
|
||||
|
@ -47,11 +44,10 @@ pub async fn receive(password: &String) -> Result<()> {
|
|||
let socket = TcpStream::connect("127.0.0.1:8080").await?;
|
||||
let (handshake, s1) = Handshake::from_password(password);
|
||||
let (socket, key) = handshake.negotiate(socket, s1).await?;
|
||||
let mut stream = Message::to_stream(socket);
|
||||
let crypt = Crypt::new(&key);
|
||||
let files = negotiate_files_down(&mut stream, &crypt).await?;
|
||||
let mut connection = Connection::new(socket, key);
|
||||
let files = negotiate_files_down(&mut connection).await?;
|
||||
|
||||
download_files(files, &mut stream, &crypt).await?;
|
||||
download_files(&mut connection, files).await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
|
@ -64,28 +60,15 @@ pub async fn get_file_handles(file_paths: &Vec<PathBuf>) -> Result<Vec<FileHandl
|
|||
}
|
||||
|
||||
pub async fn negotiate_files_up(
|
||||
conn: &mut Connection,
|
||||
file_handles: Vec<FileHandle>,
|
||||
stream: &mut MessageStream,
|
||||
crypt: &Crypt,
|
||||
) -> Result<Vec<FileHandle>> {
|
||||
let files = file_handles.iter().map(|fh| fh.to_file_info()).collect();
|
||||
let msg = EncryptedMessage::FileNegotiationMessage(FileNegotiationPayload { files });
|
||||
let server_msg = msg.to_encrypted_message(crypt)?;
|
||||
println!("server_msg encrypted: {:?}", server_msg);
|
||||
stream.send(server_msg).await?;
|
||||
let reply_payload = match stream.next().await {
|
||||
Some(Ok(msg)) => match msg {
|
||||
Message::EncryptedMessage(response) => response,
|
||||
},
|
||||
_ => {
|
||||
return Err(anyhow!("No response to negotiation message"));
|
||||
}
|
||||
};
|
||||
let plaintext_reply = EncryptedMessage::from_encrypted_message(crypt, &reply_payload)?;
|
||||
let requested_paths: Vec<PathBuf> = match plaintext_reply {
|
||||
EncryptedMessage::FileNegotiationMessage(fnm) => {
|
||||
fnm.files.into_iter().map(|f| f.path).collect()
|
||||
}
|
||||
let msg = Message::FileNegotiationMessage(FileNegotiationPayload { files });
|
||||
conn.send_msg(msg).await?;
|
||||
let reply = conn.await_msg().await?;
|
||||
let requested_paths: Vec<PathBuf> = match reply {
|
||||
Message::FileNegotiationMessage(fnm) => fnm.files.into_iter().map(|f| f.path).collect(),
|
||||
_ => return Err(anyhow!("Expecting file negotiation message back")),
|
||||
};
|
||||
Ok(file_handles
|
||||
|
@ -94,21 +77,10 @@ pub async fn negotiate_files_up(
|
|||
.collect())
|
||||
}
|
||||
|
||||
pub async fn negotiate_files_down(
|
||||
stream: &mut MessageStream,
|
||||
crypt: &Crypt,
|
||||
) -> Result<Vec<FileInfo>> {
|
||||
let file_offer = match stream.next().await {
|
||||
Some(Ok(msg)) => match msg {
|
||||
Message::EncryptedMessage(response) => response,
|
||||
},
|
||||
_ => {
|
||||
return Err(anyhow!("No response to negotiation message"));
|
||||
}
|
||||
};
|
||||
let plaintext_offer = EncryptedMessage::from_encrypted_message(crypt, &file_offer)?;
|
||||
let requested_infos: Vec<FileInfo> = match plaintext_offer {
|
||||
EncryptedMessage::FileNegotiationMessage(fnm) => fnm.files,
|
||||
pub async fn negotiate_files_down(conn: &mut Connection) -> Result<Vec<FileInfo>> {
|
||||
let offer = conn.await_msg().await?;
|
||||
let requested_infos: Vec<FileInfo> = match offer {
|
||||
Message::FileNegotiationMessage(fnm) => fnm.files,
|
||||
_ => return Err(anyhow!("Expecting file negotiation message back")),
|
||||
};
|
||||
let mut stdin = FramedRead::new(io::stdin(), LinesCodec::new());
|
||||
|
@ -123,47 +95,22 @@ pub async fn negotiate_files_down(
|
|||
_ => {}
|
||||
}
|
||||
}
|
||||
let msg = EncryptedMessage::FileNegotiationMessage(FileNegotiationPayload {
|
||||
let msg = Message::FileNegotiationMessage(FileNegotiationPayload {
|
||||
files: files.clone(),
|
||||
});
|
||||
let server_msg = msg.to_encrypted_message(crypt)?;
|
||||
stream.send(server_msg).await?;
|
||||
conn.send_msg(msg).await?;
|
||||
Ok(files)
|
||||
}
|
||||
|
||||
pub async fn upload_encrypted_files(
|
||||
stream: &mut MessageStream,
|
||||
handles: Vec<FileHandle>,
|
||||
cipher: &Crypt,
|
||||
) -> Result<()> {
|
||||
let (tx, mut rx) = mpsc::unbounded_channel::<EncryptedMessage>();
|
||||
pub async fn upload_encrypted_files(conn: &mut Connection, handles: Vec<FileHandle>) -> Result<()> {
|
||||
for mut handle in handles {
|
||||
let txc = tx.clone();
|
||||
tokio::spawn(async move {
|
||||
let _ = enqueue_file_chunks(&mut handle, txc).await;
|
||||
});
|
||||
}
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
Some(msg) = rx.recv() => {
|
||||
// println!("message received to client.rx {:?}", msg);
|
||||
let x = msg.to_encrypted_message(cipher)?;
|
||||
stream.send(x).await?
|
||||
}
|
||||
else => {
|
||||
println!("breaking");
|
||||
break
|
||||
},
|
||||
}
|
||||
enqueue_file_chunks(conn, &mut handle).await?;
|
||||
}
|
||||
println!("Files uploaded.");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn enqueue_file_chunks(
|
||||
fh: &mut FileHandle,
|
||||
tx: mpsc::UnboundedSender<EncryptedMessage>,
|
||||
) -> Result<()> {
|
||||
pub async fn enqueue_file_chunks(conn: &mut Connection, fh: &mut FileHandle) -> Result<()> {
|
||||
let mut chunk_num = 0;
|
||||
let mut bytes_read = 1;
|
||||
while bytes_read != 0 {
|
||||
|
@ -173,12 +120,12 @@ pub async fn enqueue_file_chunks(
|
|||
if bytes_read != 0 {
|
||||
let chunk = buf.freeze();
|
||||
let file_info = fh.to_file_info();
|
||||
let ftp = EncryptedMessage::FileTransferMessage(FileTransferPayload {
|
||||
let ftp = Message::FileTransferMessage(FileTransferPayload {
|
||||
chunk,
|
||||
chunk_num,
|
||||
file_info,
|
||||
});
|
||||
tx.send(ftp)?;
|
||||
conn.send_msg(ftp).await?;
|
||||
chunk_num += 1;
|
||||
}
|
||||
}
|
||||
|
@ -186,11 +133,7 @@ pub async fn enqueue_file_chunks(
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn download_files(
|
||||
file_infos: Vec<FileInfo>,
|
||||
stream: &mut MessageStream,
|
||||
cipher: &Crypt,
|
||||
) -> Result<()> {
|
||||
pub async fn download_files(conn: &mut Connection, file_infos: Vec<FileInfo>) -> Result<()> {
|
||||
// for each file_info
|
||||
let mut info_handles: HashMap<PathBuf, mpsc::UnboundedSender<(u64, Bytes)>> = HashMap::new();
|
||||
for fi in file_infos {
|
||||
|
@ -201,29 +144,24 @@ pub async fn download_files(
|
|||
}
|
||||
loop {
|
||||
tokio::select! {
|
||||
result = stream.next() => match result {
|
||||
Some(Ok(Message::EncryptedMessage(payload))) => {
|
||||
let ec = EncryptedMessage::from_encrypted_message(cipher, &payload)?;
|
||||
// println!("encrypted message received! {:?}", ec);
|
||||
match ec {
|
||||
EncryptedMessage::FileTransferMessage(payload) => {
|
||||
// println!("matched file transfer message");
|
||||
result = conn.await_msg() => match result {
|
||||
Ok(msg) => {
|
||||
match msg {
|
||||
Message::FileTransferMessage(payload) => {
|
||||
if let Some(tx) = info_handles.get(&payload.file_info.path) {
|
||||
// println!("matched on filetype, sending to tx");
|
||||
tx.send((payload.chunk_num, payload.chunk))?
|
||||
};
|
||||
}
|
||||
},
|
||||
_ => {println!("wrong msg")}
|
||||
_ => {
|
||||
println!("Wrong message type");
|
||||
return Err(anyhow!("wrong message type"));
|
||||
}
|
||||
}
|
||||
Some(Err(e)) => {
|
||||
println!("Error {:?}", e);
|
||||
}
|
||||
None => break,
|
||||
},
|
||||
Err(e) => return Err(anyhow!(e.to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn download_file(
|
||||
|
|
|
@ -1,4 +1,38 @@
|
|||
use crate::crypto::Crypt;
|
||||
use crate::message::{Message, MessageStream};
|
||||
use anyhow::{anyhow, Result};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
pub struct Connection {
|
||||
ms: MessageStream;
|
||||
crypt: Crypt
|
||||
ms: MessageStream,
|
||||
crypt: Crypt,
|
||||
}
|
||||
|
||||
impl Connection {
|
||||
pub fn new(socket: TcpStream, key: Vec<u8>) -> Self {
|
||||
let ms = Message::to_stream(socket);
|
||||
let crypt = Crypt::new(&key);
|
||||
Connection { ms, crypt }
|
||||
}
|
||||
pub async fn send_msg(&mut self, msg: Message) -> Result<()> {
|
||||
let msg = msg.serialize()?;
|
||||
let bytes = self.crypt.encrypt(msg)?;
|
||||
match self.ms.send(bytes).await {
|
||||
Ok(_) => Ok(()),
|
||||
Err(e) => Err(anyhow!(e.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn await_msg(&mut self) -> Result<Message> {
|
||||
match self.ms.next().await {
|
||||
Some(Ok(msg)) => {
|
||||
let decrypted_bytes = self.crypt.decrypt(msg.freeze())?;
|
||||
Message::deserialize(decrypted_bytes)
|
||||
}
|
||||
_ => {
|
||||
return Err(anyhow!("No response to negotiation message"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,9 +1,8 @@
|
|||
use crate::conf::NONCE_SIZE_IN_BYTES;
|
||||
use crate::message::EncryptedPayload;
|
||||
use aes_gcm::aead::{Aead, NewAead};
|
||||
use aes_gcm::{Aes256Gcm, Key, Nonce}; // Or `Aes128Gcm`
|
||||
use anyhow::{anyhow, Result};
|
||||
use bytes::Bytes;
|
||||
use bytes::{Bytes, BytesMut};
|
||||
|
||||
use rand::{thread_rng, Rng};
|
||||
|
||||
|
@ -19,25 +18,28 @@ impl Crypt {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn encrypt(&self, body: &Vec<u8>) -> Result<EncryptedPayload> {
|
||||
pub fn encrypt(&self, plaintext: Bytes) -> Result<Bytes> {
|
||||
let mut arr = [0u8; NONCE_SIZE_IN_BYTES];
|
||||
thread_rng().try_fill(&mut arr[..])?;
|
||||
let nonce = Nonce::from_slice(&arr);
|
||||
let plaintext = body.as_ref();
|
||||
match self.cipher.encrypt(nonce, plaintext) {
|
||||
Ok(body) => Ok(EncryptedPayload {
|
||||
nonce: arr.to_vec(),
|
||||
body,
|
||||
}),
|
||||
Err(_) => Err(anyhow!("Encryption error")),
|
||||
match self.cipher.encrypt(nonce, plaintext.as_ref()) {
|
||||
Ok(body) => {
|
||||
let mut buffer = BytesMut::with_capacity(NONCE_SIZE_IN_BYTES + body.len());
|
||||
buffer.extend_from_slice(nonce);
|
||||
buffer.extend_from_slice(&body);
|
||||
Ok(buffer.freeze())
|
||||
}
|
||||
Err(e) => Err(anyhow!(e.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decrypt(&self, payload: &EncryptedPayload) -> Result<Bytes> {
|
||||
let nonce = Nonce::from_slice(payload.nonce.as_ref());
|
||||
match self.cipher.decrypt(nonce, payload.body.as_ref()) {
|
||||
pub fn decrypt(&self, body: Bytes) -> Result<Bytes> {
|
||||
let mut body = body;
|
||||
let nonce_bytes = body.split_to(NONCE_SIZE_IN_BYTES);
|
||||
let nonce = Nonce::from_slice(&nonce_bytes);
|
||||
match self.cipher.decrypt(nonce, body.as_ref()) {
|
||||
Ok(payload) => Ok(Bytes::from(payload)),
|
||||
Err(_) => Err(anyhow!("Decryption error")),
|
||||
Err(e) => Err(anyhow!(e.to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
mod cli;
|
||||
mod client;
|
||||
mod conf;
|
||||
mod connection;
|
||||
mod crypto;
|
||||
mod file;
|
||||
mod handshake;
|
||||
|
|
|
@ -1,28 +1,13 @@
|
|||
use crate::crypto::Crypt;
|
||||
use crate::file::FileInfo;
|
||||
|
||||
use anyhow::{anyhow, Result};
|
||||
use bytes::Bytes;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio_serde::{formats::SymmetricalBincode, SymmetricallyFramed};
|
||||
use tokio_util::codec::{Framed, LengthDelimitedCodec};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum Message {
|
||||
EncryptedMessage(EncryptedPayload),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct EncryptedPayload {
|
||||
pub nonce: Vec<u8>,
|
||||
pub body: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum EncryptedMessage {
|
||||
FileNegotiationMessage(FileNegotiationPayload),
|
||||
FileTransferMessage(FileTransferPayload),
|
||||
}
|
||||
|
@ -39,62 +24,25 @@ pub struct FileTransferPayload {
|
|||
pub chunk: Bytes,
|
||||
}
|
||||
|
||||
impl EncryptedMessage {
|
||||
pub fn from_encrypted_message(crypt: &Crypt, payload: &EncryptedPayload) -> Result<Self> {
|
||||
let raw = crypt.decrypt(payload)?;
|
||||
let res = match bincode::deserialize(raw.as_ref()) {
|
||||
Ok(result) => result,
|
||||
Err(e) => {
|
||||
println!("deserialize error {:?}", e);
|
||||
return Err(anyhow!("deser error"));
|
||||
impl Message {
|
||||
pub fn serialize(&self) -> Result<Bytes> {
|
||||
match bincode::serialize(&self) {
|
||||
Ok(vec) => Ok(Bytes::from(vec)),
|
||||
Err(e) => Err(anyhow!(e.to_string())),
|
||||
}
|
||||
};
|
||||
Ok(res)
|
||||
}
|
||||
pub fn to_encrypted_message(&self, crypt: &Crypt) -> Result<Message> {
|
||||
let raw = match bincode::serialize(&self) {
|
||||
Ok(result) => result,
|
||||
Err(e) => {
|
||||
println!("serialize error {:?}", e);
|
||||
return Err(anyhow!("serialize error"));
|
||||
pub fn deserialize(bytes: Bytes) -> Result<Self> {
|
||||
match bincode::deserialize(bytes.as_ref()) {
|
||||
Ok(msg) => Ok(msg),
|
||||
Err(e) => Err(anyhow!(e.to_string())),
|
||||
}
|
||||
};
|
||||
let payload = crypt.encrypt(&raw)?;
|
||||
Ok(Message::EncryptedMessage(payload))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum RuckError {
|
||||
NotHandshake,
|
||||
SenderNotConnected,
|
||||
SenderAlreadyConnected,
|
||||
PairDisconnected,
|
||||
}
|
||||
|
||||
impl fmt::Display for RuckError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "RuckError is here!")
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for RuckError {
|
||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||
Some(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn to_stream(stream: TcpStream) -> MessageStream {
|
||||
tokio_serde::SymmetricallyFramed::new(
|
||||
Framed::new(stream, LengthDelimitedCodec::new()),
|
||||
tokio_serde::formats::SymmetricalBincode::<Message>::default(),
|
||||
)
|
||||
Framed::new(stream, LengthDelimitedCodec::new())
|
||||
}
|
||||
}
|
||||
|
||||
pub type MessageStream = SymmetricallyFramed<
|
||||
Framed<TcpStream, LengthDelimitedCodec>,
|
||||
Message,
|
||||
SymmetricalBincode<Message>,
|
||||
>;
|
||||
pub type MessageStream = Framed<TcpStream, LengthDelimitedCodec>;
|
||||
|
|
|
@ -57,7 +57,9 @@ impl Client {
|
|||
// Sender - needs to wait for the incoming msg to look up peer_tx
|
||||
None => {
|
||||
tokio::select! {
|
||||
// Client reads handshake message sent over channel
|
||||
Some(msg) = client.rx.recv() => {
|
||||
// Writes parnter handshake message over wire
|
||||
client.socket.write_all(&msg[..]).await?
|
||||
}
|
||||
}
|
||||
|
@ -126,9 +128,8 @@ pub async fn handle_connection(
|
|||
},
|
||||
Ok(n) => {
|
||||
let b = BytesMut::from(&client_buffer[0..n]).freeze();
|
||||
// println!("reading more = {:?}", b);
|
||||
client.peer_tx.send(b)?;
|
||||
client_buffer.clear();
|
||||
client.peer_tx.send(b)?
|
||||
},
|
||||
Err(e) => {
|
||||
println!("Error {:?}", e);
|
||||
|
|
Loading…
Reference in a new issue