From b01d5c9b1f5ba4f6c7b533db770dc4b76d10d847 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20du=20Garreau?= Date: Thu, 11 Jul 2024 17:27:12 +0200 Subject: [PATCH] Split reader and decompressor --- src/compression.rs | 88 +++++++++++++++++- src/read.rs | 224 ++++++++------------------------------------- src/read/lzma.rs | 7 +- src/read/xz.rs | 18 ++-- 4 files changed, 135 insertions(+), 202 deletions(-) diff --git a/src/compression.rs b/src/compression.rs index 0dd21017..83a7669b 100644 --- a/src/compression.rs +++ b/src/compression.rs @@ -1,6 +1,6 @@ //! Possible ZIP compression methods. -use std::fmt; +use std::{fmt, io}; #[allow(deprecated)] /// Identifies the storage format used to compress a file within a ZIP archive. @@ -189,6 +189,92 @@ pub const SUPPORTED_COMPRESSION_METHODS: &[CompressionMethod] = &[ CompressionMethod::Zstd, ]; +pub(crate) enum Decompressor { + Stored(R), + #[cfg(feature = "_deflate-any")] + Deflated(flate2::bufread::DeflateDecoder), + #[cfg(feature = "deflate64")] + Deflate64(deflate64::Deflate64Decoder), + #[cfg(feature = "bzip2")] + Bzip2(bzip2::bufread::BzDecoder), + #[cfg(feature = "zstd")] + Zstd(zstd::Decoder<'static, R>), + #[cfg(feature = "lzma")] + Lzma(Box>), + #[cfg(feature = "xz")] + Xz(crate::read::xz::XzDecoder), +} + +impl io::Read for Decompressor { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + match self { + Decompressor::Stored(r) => r.read(buf), + #[cfg(feature = "_deflate-any")] + Decompressor::Deflated(r) => r.read(buf), + #[cfg(feature = "deflate64")] + Decompressor::Deflate64(r) => r.read(buf), + #[cfg(feature = "bzip2")] + Decompressor::Bzip2(r) => r.read(buf), + #[cfg(feature = "zstd")] + Decompressor::Zstd(r) => r.read(buf), + #[cfg(feature = "lzma")] + Decompressor::Lzma(r) => r.read(buf), + #[cfg(feature = "xz")] + Decompressor::Xz(r) => r.read(buf), + } + } +} + +impl Decompressor { + pub fn new(reader: R, compression_method: CompressionMethod) -> crate::result::ZipResult { + Ok(match compression_method { + CompressionMethod::Stored => Decompressor::Stored(reader), + #[cfg(feature = "_deflate-any")] + CompressionMethod::Deflated => { + Decompressor::Deflated(flate2::bufread::DeflateDecoder::new(reader)) + } + #[cfg(feature = "deflate64")] + CompressionMethod::Deflate64 => { + Decompressor::Deflate64(deflate64::Deflate64Decoder::with_buffer(reader)) + } + #[cfg(feature = "bzip2")] + CompressionMethod::Bzip2 => Decompressor::Bzip2(bzip2::bufread::BzDecoder::new(reader)), + #[cfg(feature = "zstd")] + CompressionMethod::Zstd => Decompressor::Zstd(zstd::Decoder::with_buffer(reader)?), + #[cfg(feature = "lzma")] + CompressionMethod::Lzma => { + Decompressor::Lzma(Box::new(crate::read::lzma::LzmaDecoder::new(reader))) + } + #[cfg(feature = "xz")] + CompressionMethod::Xz => Decompressor::Xz(crate::read::xz::XzDecoder::new(reader)), + _ => { + return Err(crate::result::ZipError::UnsupportedArchive( + "Compression method not supported", + )) + } + }) + } + + /// Consumes this decoder, returning the underlying reader. + pub fn into_inner(self) -> R { + match self { + Decompressor::Stored(r) => r, + #[cfg(feature = "_deflate-any")] + Decompressor::Deflated(r) => r.into_inner(), + #[cfg(feature = "deflate64")] + Decompressor::Deflate64(r) => r.into_inner(), + #[cfg(feature = "bzip2")] + Decompressor::Bzip2(r) => r.into_inner(), + #[cfg(feature = "zstd")] + Decompressor::Zstd(r) => r.finish(), + #[cfg(feature = "lzma")] + Decompressor::Lzma(r) => r.into_inner(), + #[cfg(feature = "xz")] + Decompressor::Xz(r) => r.into_inner(), + } + } +} + #[cfg(test)] mod test { use super::{CompressionMethod, SUPPORTED_COMPRESSION_METHODS}; diff --git a/src/read.rs b/src/read.rs index 830a5851..ceaa4a11 100644 --- a/src/read.rs +++ b/src/read.rs @@ -2,7 +2,7 @@ #[cfg(feature = "aes-crypto")] use crate::aes::{AesReader, AesReaderValid}; -use crate::compression::CompressionMethod; +use crate::compression::{CompressionMethod, Decompressor}; use crate::cp437::FromCp437; use crate::crc32::Crc32Reader; use crate::extra_fields::{ExtendedTimestamp, ExtraField}; @@ -26,18 +26,6 @@ use std::path::{Path, PathBuf}; use std::rc::Rc; use std::sync::{Arc, OnceLock}; -#[cfg(feature = "deflate-flate2")] -use flate2::read::DeflateDecoder; - -#[cfg(feature = "deflate64")] -use deflate64::Deflate64Decoder; - -#[cfg(feature = "bzip2")] -use bzip2::read::BzDecoder; - -#[cfg(feature = "zstd")] -use zstd::stream::read::Decoder as ZstdDecoder; - mod config; pub use config::*; @@ -123,11 +111,7 @@ pub(crate) mod zip_archive { #[cfg(feature = "aes-crypto")] use crate::aes::PWD_VERIFY_LENGTH; use crate::extra_fields::UnicodeExtraField; -#[cfg(feature = "lzma")] -use crate::read::lzma::LzmaDecoder; -#[cfg(feature = "xz")] -use crate::read::xz::XzDecoder; -use crate::result::ZipError::{InvalidArchive, InvalidPassword, UnsupportedArchive}; +use crate::result::ZipError::{InvalidArchive, InvalidPassword}; use crate::spec::is_dir; use crate::types::ffi::S_IFLNK; use crate::unstable::{path_to_string, LittleEndianReadExt}; @@ -199,134 +183,63 @@ impl<'a> CryptoReader<'a> { } } +#[cold] +fn invalid_state() -> io::Error { + io::Error::new( + io::ErrorKind::Other, + "ZipFileReader was in an invalid state", + ) +} + pub(crate) enum ZipFileReader<'a> { NoReader, Raw(io::Take<&'a mut dyn Read>), - Stored(Crc32Reader>), - #[cfg(feature = "_deflate-any")] - Deflated(Crc32Reader>>), - #[cfg(feature = "deflate64")] - Deflate64(Crc32Reader>>>), - #[cfg(feature = "bzip2")] - Bzip2(Crc32Reader>>), - #[cfg(feature = "zstd")] - Zstd(Crc32Reader>>>), - #[cfg(feature = "lzma")] - Lzma(Crc32Reader>>>), - #[cfg(feature = "xz")] - Xz(Crc32Reader>>), + Compressed(Box>>>>), } impl<'a> Read for ZipFileReader<'a> { fn read(&mut self, buf: &mut [u8]) -> io::Result { match self { - ZipFileReader::NoReader => panic!("ZipFileReader was in an invalid state"), + ZipFileReader::NoReader => Err(invalid_state()), ZipFileReader::Raw(r) => r.read(buf), - ZipFileReader::Stored(r) => r.read(buf), - #[cfg(feature = "_deflate-any")] - ZipFileReader::Deflated(r) => r.read(buf), - #[cfg(feature = "deflate64")] - ZipFileReader::Deflate64(r) => r.read(buf), - #[cfg(feature = "bzip2")] - ZipFileReader::Bzip2(r) => r.read(buf), - #[cfg(feature = "zstd")] - ZipFileReader::Zstd(r) => r.read(buf), - #[cfg(feature = "lzma")] - ZipFileReader::Lzma(r) => r.read(buf), - #[cfg(feature = "xz")] - ZipFileReader::Xz(r) => r.read(buf), + ZipFileReader::Compressed(r) => r.read(buf), } } fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> { match self { - ZipFileReader::NoReader => panic!("ZipFileReader was in an invalid state"), + ZipFileReader::NoReader => Err(invalid_state()), ZipFileReader::Raw(r) => r.read_exact(buf), - ZipFileReader::Stored(r) => r.read_exact(buf), - #[cfg(feature = "_deflate-any")] - ZipFileReader::Deflated(r) => r.read_exact(buf), - #[cfg(feature = "deflate64")] - ZipFileReader::Deflate64(r) => r.read_exact(buf), - #[cfg(feature = "bzip2")] - ZipFileReader::Bzip2(r) => r.read_exact(buf), - #[cfg(feature = "zstd")] - ZipFileReader::Zstd(r) => r.read_exact(buf), - #[cfg(feature = "lzma")] - ZipFileReader::Lzma(r) => r.read_exact(buf), - #[cfg(feature = "xz")] - ZipFileReader::Xz(r) => r.read_exact(buf), + ZipFileReader::Compressed(r) => r.read_exact(buf), } } fn read_to_end(&mut self, buf: &mut Vec) -> io::Result { match self { - ZipFileReader::NoReader => panic!("ZipFileReader was in an invalid state"), + ZipFileReader::NoReader => Err(invalid_state()), ZipFileReader::Raw(r) => r.read_to_end(buf), - ZipFileReader::Stored(r) => r.read_to_end(buf), - #[cfg(feature = "_deflate-any")] - ZipFileReader::Deflated(r) => r.read_to_end(buf), - #[cfg(feature = "deflate64")] - ZipFileReader::Deflate64(r) => r.read_to_end(buf), - #[cfg(feature = "bzip2")] - ZipFileReader::Bzip2(r) => r.read_to_end(buf), - #[cfg(feature = "zstd")] - ZipFileReader::Zstd(r) => r.read_to_end(buf), - #[cfg(feature = "lzma")] - ZipFileReader::Lzma(r) => r.read_to_end(buf), - #[cfg(feature = "xz")] - ZipFileReader::Xz(r) => r.read_to_end(buf), + ZipFileReader::Compressed(r) => r.read_to_end(buf), } } fn read_to_string(&mut self, buf: &mut String) -> io::Result { match self { - ZipFileReader::NoReader => panic!("ZipFileReader was in an invalid state"), + ZipFileReader::NoReader => Err(invalid_state()), ZipFileReader::Raw(r) => r.read_to_string(buf), - ZipFileReader::Stored(r) => r.read_to_string(buf), - #[cfg(feature = "_deflate-any")] - ZipFileReader::Deflated(r) => r.read_to_string(buf), - #[cfg(feature = "deflate64")] - ZipFileReader::Deflate64(r) => r.read_to_string(buf), - #[cfg(feature = "bzip2")] - ZipFileReader::Bzip2(r) => r.read_to_string(buf), - #[cfg(feature = "zstd")] - ZipFileReader::Zstd(r) => r.read_to_string(buf), - #[cfg(feature = "lzma")] - ZipFileReader::Lzma(r) => r.read_to_string(buf), - #[cfg(feature = "xz")] - ZipFileReader::Xz(r) => r.read_to_string(buf), + ZipFileReader::Compressed(r) => r.read_to_string(buf), } } } impl<'a> ZipFileReader<'a> { - /// Consumes this decoder, returning the underlying reader. - pub fn drain(self) { - let mut inner = match self { - ZipFileReader::NoReader => panic!("ZipFileReader was in an invalid state"), - ZipFileReader::Raw(r) => r, - ZipFileReader::Stored(r) => r.into_inner().into_inner(), - #[cfg(feature = "_deflate-any")] - ZipFileReader::Deflated(r) => r.into_inner().into_inner().into_inner(), - #[cfg(feature = "deflate64")] - ZipFileReader::Deflate64(r) => r.into_inner().into_inner().into_inner().into_inner(), - #[cfg(feature = "bzip2")] - ZipFileReader::Bzip2(r) => r.into_inner().into_inner().into_inner(), - #[cfg(feature = "zstd")] - ZipFileReader::Zstd(r) => r.into_inner().finish().into_inner().into_inner(), - #[cfg(feature = "lzma")] - ZipFileReader::Lzma(r) => { - // Lzma reader owns its buffer rather than mutably borrowing it, so we have to drop - // it separately - if let Ok(mut remaining) = r.into_inner().finish() { - let _ = copy(&mut remaining, &mut sink()); - } - return; + fn into_inner(self) -> io::Result> { + match self { + ZipFileReader::NoReader => Err(invalid_state()), + ZipFileReader::Raw(r) => Ok(r), + ZipFileReader::Compressed(r) => { + Ok(r.into_inner().into_inner().into_inner().into_inner()) } - #[cfg(feature = "xz")] - ZipFileReader::Xz(r) => r.into_inner().into_inner().into_inner(), - }; - let _ = copy(&mut inner, &mut sink()); + } } } @@ -434,68 +347,11 @@ pub(crate) fn make_reader( ) -> ZipResult { let ae2_encrypted = reader.is_ae2_encrypted(); - match compression_method { - CompressionMethod::Stored => Ok(ZipFileReader::Stored(Crc32Reader::new( - reader, - crc32, - ae2_encrypted, - ))), - #[cfg(feature = "_deflate-any")] - CompressionMethod::Deflated => { - let deflate_reader = DeflateDecoder::new(reader); - Ok(ZipFileReader::Deflated(Crc32Reader::new( - deflate_reader, - crc32, - ae2_encrypted, - ))) - } - #[cfg(feature = "deflate64")] - CompressionMethod::Deflate64 => { - let deflate64_reader = Deflate64Decoder::new(reader); - Ok(ZipFileReader::Deflate64(Crc32Reader::new( - deflate64_reader, - crc32, - ae2_encrypted, - ))) - } - #[cfg(feature = "bzip2")] - CompressionMethod::Bzip2 => { - let bzip2_reader = BzDecoder::new(reader); - Ok(ZipFileReader::Bzip2(Crc32Reader::new( - bzip2_reader, - crc32, - ae2_encrypted, - ))) - } - #[cfg(feature = "zstd")] - CompressionMethod::Zstd => { - let zstd_reader = ZstdDecoder::new(reader).unwrap(); - Ok(ZipFileReader::Zstd(Crc32Reader::new( - zstd_reader, - crc32, - ae2_encrypted, - ))) - } - #[cfg(feature = "lzma")] - CompressionMethod::Lzma => { - let reader = LzmaDecoder::new(reader); - Ok(ZipFileReader::Lzma(Crc32Reader::new( - Box::new(reader), - crc32, - ae2_encrypted, - ))) - } - #[cfg(feature = "xz")] - CompressionMethod::Xz => { - let reader = XzDecoder::new(reader); - Ok(ZipFileReader::Xz(Crc32Reader::new( - reader, - crc32, - ae2_encrypted, - ))) - } - _ => Err(UnsupportedArchive("Compression method not supported")), - } + Ok(ZipFileReader::Compressed(Box::new(Crc32Reader::new( + Decompressor::new(io::BufReader::new(reader), compression_method)?, + crc32, + ae2_encrypted, + )))) } #[derive(Debug)] @@ -1722,19 +1578,11 @@ impl<'a> Drop for ZipFile<'a> { // In this case, we want to exhaust the reader so that the next file is accessible. if let Cow::Owned(_) = self.data { // Get the inner `Take` reader so all decryption, decompression and CRC calculation is skipped. - match &mut self.reader { - ZipFileReader::NoReader => { - let innerreader = self.crypto_reader.take(); - let _ = copy( - &mut innerreader.expect("Invalid reader state").into_inner(), - &mut sink(), - ); - } - reader => { - let innerreader = std::mem::replace(reader, ZipFileReader::NoReader); - innerreader.drain(); - } - }; + if let Ok(mut inner) = + std::mem::replace(&mut self.reader, ZipFileReader::NoReader).into_inner() + { + let _ = copy(&mut inner, &mut sink()); + } } } } diff --git a/src/read/lzma.rs b/src/read/lzma.rs index 04b6edd8..0b29b761 100644 --- a/src/read/lzma.rs +++ b/src/read/lzma.rs @@ -1,6 +1,6 @@ use lzma_rs::decompress::{Options, Stream, UnpackedSize}; use std::collections::VecDeque; -use std::io::{copy, Error, Read, Result, Write}; +use std::io::{Read, Result, Write}; const COMPRESSED_BYTES_TO_BUFFER: usize = 4096; @@ -24,9 +24,8 @@ impl LzmaDecoder { } } - pub fn finish(mut self) -> Result> { - copy(&mut self.compressed_reader, &mut self.stream)?; - self.stream.finish().map_err(Error::from) + pub fn into_inner(self) -> R { + self.compressed_reader } } diff --git a/src/read/xz.rs b/src/read/xz.rs index 478ae102..991df62b 100644 --- a/src/read/xz.rs +++ b/src/read/xz.rs @@ -2,12 +2,12 @@ use crc32fast::Hasher; use lzma_rs::decompress::raw::Lzma2Decoder; use std::{ collections::VecDeque, - io::{BufRead, BufReader, Error, Read, Result, Write}, + io::{BufRead, Error, Read, Result, Write}, }; #[derive(Debug)] -pub struct XzDecoder { - compressed_reader: BufReader, +pub struct XzDecoder { + compressed_reader: R, stream_size: usize, buf: VecDeque, check_size: usize, @@ -15,10 +15,10 @@ pub struct XzDecoder { flags: [u8; 2], } -impl XzDecoder { +impl XzDecoder { pub fn new(inner: R) -> Self { XzDecoder { - compressed_reader: BufReader::new(inner), + compressed_reader: inner, stream_size: 0, buf: VecDeque::new(), check_size: 0, @@ -83,7 +83,7 @@ fn error(s: &'static str) -> Result { Err(Error::new(std::io::ErrorKind::InvalidData, s)) } -fn get_multibyte(input: &mut R, hasher: &mut Hasher) -> Result { +fn get_multibyte(input: &mut R, hasher: &mut Hasher) -> Result { let mut result = 0; for i in 0..9 { let mut b = [0u8; 1]; @@ -98,7 +98,7 @@ fn get_multibyte(input: &mut R, hasher: &mut Hasher) -> Result { error("Invalid multi-byte encoding") } -impl Read for XzDecoder { +impl Read for XzDecoder { fn read(&mut self, buf: &mut [u8]) -> Result { if !self.buf.is_empty() { let len = std::cmp::min(buf.len(), self.buf.len()); @@ -263,8 +263,8 @@ impl Read for XzDecoder { } } -impl XzDecoder { +impl XzDecoder { pub fn into_inner(self) -> R { - self.compressed_reader.into_inner() + self.compressed_reader } }