diff --git a/Cargo.toml b/Cargo.toml index 35131ece..683ad7ac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,11 +33,13 @@ indexmap = "2" hmac = { version = "0.12.1", optional = true, features = ["reset"] } num_enum = "0.7.2" pbkdf2 = { version = "0.12.2", optional = true } +rand = { version = "0.8.5", optional = true } sha1 = { version = "0.10.6", optional = true } thiserror = "1.0.48" time = { workspace = true, optional = true, features = [ "std", ] } +zeroize = { version = "1.6.0", optional = true, features = ["zeroize_derive"] } zstd = { version = "0.13.1", optional = true, default-features = false } zopfli = { version = "0.8.0", optional = true } deflate64 = { version = "0.1.8", optional = true } @@ -58,7 +60,7 @@ anyhow = "1" clap = { version = "=4.4.18", features = ["derive"] } [features] -aes-crypto = ["aes", "constant_time_eq", "hmac", "pbkdf2", "sha1"] +aes-crypto = ["aes", "constant_time_eq", "hmac", "pbkdf2", "sha1", "rand", "zeroize"] chrono = ["chrono/default"] _deflate-any = [] deflate = ["flate2/rust_backend", "_deflate-any"] diff --git a/src/aes.rs b/src/aes.rs index b28df02e..772d8fb7 100644 --- a/src/aes.rs +++ b/src/aes.rs @@ -9,8 +9,10 @@ use crate::aes_ctr::AesCipher; use crate::types::AesMode; use constant_time_eq::constant_time_eq; use hmac::{Hmac, Mac}; +use rand::RngCore; use sha1::Sha1; -use std::io::{self, Error, ErrorKind, Read}; +use std::io::{self, Error, ErrorKind, Read, Write}; +use zeroize::{Zeroize, Zeroizing}; /// The length of the password verifcation value in bytes const PWD_VERIFY_LENGTH: usize = 2; @@ -204,3 +206,196 @@ impl AesReaderValid { self.reader } } + +pub struct AesWriter { + writer: W, + cipher: Cipher, + hmac: Hmac, + buffer: Zeroizing>, + encrypted_file_header: Option>, +} + +impl AesWriter { + pub fn new(writer: W, aes_mode: AesMode, password: &[u8]) -> io::Result { + let salt_length = aes_mode.salt_length(); + let key_length = aes_mode.key_length(); + + let mut encrypted_file_header = Vec::with_capacity(salt_length + 2); + + let mut salt = vec![0; salt_length]; + rand::thread_rng().fill_bytes(&mut salt); + encrypted_file_header.write_all(&salt)?; + + // Derive a key from the password and salt. The length depends on the aes key length + let derived_key_len = 2 * key_length + PWD_VERIFY_LENGTH; + let mut derived_key: Zeroizing> = Zeroizing::new(vec![0; derived_key_len]); + + // Use PBKDF2 with HMAC-Sha1 to derive the key. + pbkdf2::pbkdf2::>(password, &salt, ITERATION_COUNT, &mut derived_key) + .map_err(|e| Error::new(ErrorKind::InvalidInput, e))?; + let encryption_key = &derived_key[0..key_length]; + let hmac_key = &derived_key[key_length..key_length * 2]; + + let pwd_verify = derived_key[derived_key_len - 2..].to_vec(); + encrypted_file_header.write_all(&pwd_verify)?; + + let cipher = Cipher::from_mode(aes_mode, encryption_key); + let hmac = Hmac::::new_from_slice(hmac_key).unwrap(); + + Ok(Self { + writer, + cipher, + hmac, + buffer: Default::default(), + encrypted_file_header: Some(encrypted_file_header), + }) + } + + pub fn finish(mut self) -> io::Result { + self.write_encrypted_file_header()?; + + // Zip uses HMAC-Sha1-80, which only uses the first half of the hash + // see https://www.winzip.com/win/en/aes_info.html#auth-faq + let computed_auth_code = &self.hmac.finalize_reset().into_bytes()[0..AUTH_CODE_LENGTH]; + self.writer.write_all(computed_auth_code)?; + + Ok(self.writer) + } + + /// The AES encryption specification requires some metadata being written at the start of the + /// file data section, but this can only be done once the extra data writing has been finished + /// so we can't do it when the writer is constructed. + fn write_encrypted_file_header(&mut self) -> io::Result<()> { + if let Some(header) = self.encrypted_file_header.take() { + self.writer.write_all(&header)?; + } + + Ok(()) + } +} + +impl Write for AesWriter { + fn write(&mut self, buf: &[u8]) -> io::Result { + self.write_encrypted_file_header()?; + + // Fill the internal buffer and encrypt it in-place. + self.buffer.extend_from_slice(buf); + self.cipher.crypt_in_place(&mut self.buffer[..]); + + // Update the hmac with the encrypted data. + self.hmac.update(&self.buffer[..]); + + // Write the encrypted buffer to the inner writer. We need to use `write_all` here as if + // we only write parts of the data we can't easily reverse the keystream in the cipher + // implementation. + self.writer.write_all(&self.buffer[..])?; + + // Zeroize the backing memory before clearing the buffer to prevent cleartext data from + // being left in memory. + self.buffer.zeroize(); + self.buffer.clear(); + + Ok(buf.len()) + } + + fn flush(&mut self) -> io::Result<()> { + self.writer.flush() + } +} + +#[cfg(test)] +mod tests { + use std::io::{self, Read, Write}; + + use crate::{ + aes::{AesReader, AesWriter}, + types::AesMode, + }; + + /// Checks whether `AesReader` can successfully decrypt what `AesWriter` produces. + fn roundtrip(aes_mode: AesMode, password: &[u8], plaintext: &[u8]) -> io::Result { + let mut buf = io::Cursor::new(vec![]); + let mut read_buffer = vec![]; + + { + let mut writer = AesWriter::new(&mut buf, aes_mode, &password)?; + writer.write_all(plaintext)?; + writer.finish()?; + } + + // Reset cursor position to the beginning. + buf.set_position(0); + + { + let compressed_length = buf.get_ref().len() as u64; + let mut reader = + match AesReader::new(&mut buf, aes_mode, compressed_length).validate(&password)? { + None => { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + "Invalid authentication code", + )) + } + Some(r) => r, + }; + reader.read_to_end(&mut read_buffer)?; + } + + return Ok(plaintext == read_buffer); + } + + #[test] + fn crypt_aes_256_0_byte() { + let plaintext = &[]; + let password = b"some super secret password"; + assert!(roundtrip(AesMode::Aes256, password, plaintext).expect("could encrypt and decrypt")); + } + + #[test] + fn crypt_aes_128_5_byte() { + let plaintext = b"asdf\n"; + let password = b"some super secret password"; + + assert!(roundtrip(AesMode::Aes128, password, plaintext).expect("could encrypt and decrypt")); + } + + #[test] + fn crypt_aes_192_5_byte() { + let plaintext = b"asdf\n"; + let password = b"some super secret password"; + + assert!(roundtrip(AesMode::Aes192, password, plaintext).expect("could encrypt and decrypt")); + } + + #[test] + fn crypt_aes_256_5_byte() { + let plaintext = b"asdf\n"; + let password = b"some super secret password"; + + assert!(roundtrip(AesMode::Aes256, password, plaintext).expect("could encrypt and decrypt")); + } + + #[test] + fn crypt_aes_128_40_byte() { + let plaintext = b"Lorem ipsum dolor sit amet, consectetur\n"; + let password = b"some super secret password"; + + assert!(roundtrip(AesMode::Aes128, password, plaintext).expect("could encrypt and decrypt")); + } + + #[test] + fn crypt_aes_192_40_byte() { + let plaintext = b"Lorem ipsum dolor sit amet, consectetur\n"; + let password = b"some super secret password"; + + assert!(roundtrip(AesMode::Aes192, password, plaintext).expect("could encrypt and decrypt")); + } + + #[test] + fn crypt_aes_256_40_byte() { + let plaintext = b"Lorem ipsum dolor sit amet, consectetur\n"; + let password = b"some super secret password"; + + assert!(roundtrip(AesMode::Aes256, password, plaintext).expect("could encrypt and decrypt")); + } +}