diff --git a/src/aes_ctr.rs b/src/aes_ctr.rs index 2d1a4f5e..b544996a 100644 --- a/src/aes_ctr.rs +++ b/src/aes_ctr.rs @@ -1,8 +1,7 @@ use aes::block_cipher::generic_array::GenericArray; use aes::{BlockCipher, NewBlockCipher}; -use arrayvec::{Array, ArrayVec}; -use std::io::Read; -use std::{any, fmt, io}; +use byteorder::WriteBytesExt; +use std::{any, fmt}; /// Internal block size of an AES cipher. const AES_BLOCK_SIZE: usize = 16; @@ -20,14 +19,13 @@ pub struct Aes256; /// An AES cipher kind. pub trait AesKind { /// Key type. - type Key: Array; + type Key: AsRef<[u8]>; /// Cipher used to decrypt. type Cipher; } impl AesKind for Aes256 { type Key = [u8; 32]; - type Cipher = aes::Aes256; } @@ -40,10 +38,15 @@ impl AesKind for Aes256 { /// /// The stream implements the `Read` trait; encryption or decryption is performed by XOR-ing the /// bytes from the key stream with the ciphertext/plaintext. -struct AesCtrZipKeyStream { +pub struct AesCtrZipKeyStream { + /// Current AES counter. counter: u128, + /// AES cipher instance. cipher: C::Cipher, - buffer: ArrayVec, + /// Stores the currently available keystream bytes. + buffer: [u8; AES_BLOCK_SIZE], + /// Number of bytes already used up from `buffer`. + pos: usize, } impl fmt::Debug for AesCtrZipKeyStream @@ -65,13 +68,13 @@ where C: AesKind, C::Cipher: NewBlockCipher, { - #[allow(dead_code)] /// Creates a new zip variant AES-CTR key stream. pub fn new(key: &C::Key) -> AesCtrZipKeyStream { AesCtrZipKeyStream { counter: 1, - cipher: C::Cipher::new_varkey(key.as_slice()).expect("key should have correct size"), - buffer: ArrayVec::new(), + cipher: C::Cipher::new_varkey(key.as_ref()).expect("key should have correct size"), + buffer: [0u8; AES_BLOCK_SIZE], + pos: AES_BLOCK_SIZE, } } } @@ -85,21 +88,26 @@ where #[inline] fn crypt(&mut self, mut target: &mut [u8]) { while target.len() > 0 { - if self.buffer.len() == 0 { + if self.pos == AES_BLOCK_SIZE { // Note: AES block size is always 16 bytes, same as u128. - let mut block = GenericArray::clone_from_slice(&self.counter.to_le_bytes()); - - // TODO: Use trait. - self.cipher.encrypt_block(&mut block); + self.buffer + .as_mut() + .write_u128::(self.counter) + .expect("did not expect u128 le conversion to fail"); + self.cipher + .encrypt_block(GenericArray::from_mut_slice(&mut self.buffer)); self.counter += 1; - self.buffer = block.into_iter().collect(); + self.pos = 0; } - let target_len = target.len().min(self.buffer.len()); + let target_len = target.len().min(AES_BLOCK_SIZE - self.pos); - xor(&mut target[0..target_len], &self.buffer[0..target_len]); - self.buffer.drain(0..target_len); + xor( + &mut target[0..target_len], + &self.buffer[self.pos..(self.pos + target_len)], + ); target = &mut target[target_len..]; + self.pos += target_len; } } } @@ -116,8 +124,7 @@ pub fn xor(dest: &mut [u8], src: &[u8]) { #[cfg(test)] mod tests { - use super::{xor, Aes256, AesCtrZipKeyStream}; - use std::io::Read; + use super::{Aes256, AesCtrZipKeyStream}; #[test] fn crypt_simple_example() { @@ -134,5 +141,10 @@ mod tests { let mut plaintext = ciphertext; key_stream.crypt(&mut plaintext); assert_eq!(&plaintext, expected_plaintext); + + // Round-tripping should yield the ciphertext again. + let mut key_stream = AesCtrZipKeyStream::::new(&key); + key_stream.crypt(&mut plaintext); + assert_eq!(plaintext, ciphertext); } }