Fix bugs where calling start_file with incorrect parameters would close the ZipWriter

This commit is contained in:
Chris Hennick 2023-05-01 10:11:07 -07:00
parent df489189b1
commit 43a9db8886
No known key found for this signature in database
GPG key ID: 25653935CC8B6C74
2 changed files with 85 additions and 96 deletions

View file

@ -23,7 +23,6 @@ pbkdf2 = {version = "0.12.1", optional = true }
sha1 = {version = "0.10.5", optional = true } sha1 = {version = "0.10.5", optional = true }
time = { version = "0.3.20", optional = true, default-features = false, features = ["std"] } time = { version = "0.3.20", optional = true, default-features = false, features = ["std"] }
zstd = { version = "0.12.3", optional = true } zstd = { version = "0.12.3", optional = true }
replace_with = "0.1.7"
[target.'cfg(any(all(target_arch = "arm", target_pointer_width = "32"), target_arch = "mips", target_arch = "powerpc"))'.dependencies] [target.'cfg(any(all(target_arch = "arm", target_pointer_width = "32"), target_arch = "mips", target_arch = "powerpc"))'.dependencies]
crossbeam-utils = "0.8.15" crossbeam-utils = "0.8.15"

View file

@ -24,7 +24,6 @@ use flate2::write::DeflateEncoder;
#[cfg(feature = "bzip2")] #[cfg(feature = "bzip2")]
use bzip2::write::BzEncoder; use bzip2::write::BzEncoder;
use replace_with::{replace_with, replace_with_and_return};
#[cfg(feature = "time")] #[cfg(feature = "time")]
use time::OffsetDateTime; use time::OffsetDateTime;
@ -295,7 +294,7 @@ impl<A: Read + Write + Seek> ZipWriter<A> {
let _ = readwriter.seek(SeekFrom::Start(directory_start)); // seek directory_start to overwrite it let _ = readwriter.seek(SeekFrom::Start(directory_start)); // seek directory_start to overwrite it
Ok(ZipWriter { Ok(ZipWriter {
inner: GenericZipWriter::Storer(readwriter), inner: Storer(readwriter),
files, files,
files_by_name, files_by_name,
stats: Default::default(), stats: Default::default(),
@ -358,7 +357,7 @@ impl<W: Write + Seek> ZipWriter<W> {
/// Before writing to this object, the [`ZipWriter::start_file`] function should be called. /// Before writing to this object, the [`ZipWriter::start_file`] function should be called.
pub fn new(inner: W) -> ZipWriter<W> { pub fn new(inner: W) -> ZipWriter<W> {
ZipWriter { ZipWriter {
inner: GenericZipWriter::Storer(inner), inner: Storer(inner),
files: Vec::new(), files: Vec::new(),
files_by_name: HashMap::new(), files_by_name: HashMap::new(),
stats: Default::default(), stats: Default::default(),
@ -824,7 +823,7 @@ impl<W: Write + Seek> ZipWriter<W> {
/// Note that the zipfile will also be finished on drop. /// Note that the zipfile will also be finished on drop.
pub fn finish(&mut self) -> ZipResult<W> { pub fn finish(&mut self) -> ZipResult<W> {
self.finalize()?; self.finalize()?;
let inner = mem::replace(&mut self.inner, GenericZipWriter::Closed); let inner = mem::replace(&mut self.inner, Closed);
Ok(inner.unwrap()) Ok(inner.unwrap())
} }
@ -961,7 +960,7 @@ impl<W: Write + Seek> GenericZipWriter<W> {
} }
match self { match self {
GenericZipWriter::Closed => { Closed => {
return Err(io::Error::new( return Err(io::Error::new(
io::ErrorKind::BrokenPipe, io::ErrorKind::BrokenPipe,
"ZipWriter was already closed", "ZipWriter was already closed",
@ -970,27 +969,8 @@ impl<W: Write + Seek> GenericZipWriter<W> {
} }
_ => {} _ => {}
} }
let bare: &mut W = match self {
Storer(w) => &mut unsafe { mem::transmute_copy::<W, W>(w) },
#[cfg(any(
feature = "deflate",
feature = "deflate-miniz",
feature = "deflate-zlib"
))]
GenericZipWriter::Deflater(w) => &mut w.finish()?,
#[cfg(feature = "bzip2")]
GenericZipWriter::Bzip2(w) => &mut w.finish()?,
#[cfg(feature = "zstd")]
GenericZipWriter::Zstd(w) => &mut w.finish()?,
Closed => {
return Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"ZipWriter was already closed",
)
.into())
}
};
let make_new_self: Box<dyn FnOnce(W) -> GenericZipWriter<W>> = {
#[allow(deprecated)] #[allow(deprecated)]
match compression { match compression {
CompressionMethod::Stored => { CompressionMethod::Stored => {
@ -1000,8 +980,7 @@ impl<W: Write + Seek> GenericZipWriter<W> {
)); ));
} }
let _ = mem::replace(self, Storer(unsafe { mem::transmute_copy::<W, W>(bare) })); Box::new(|bare| Storer(bare))
Ok(())
} }
#[cfg(any( #[cfg(any(
feature = "deflate", feature = "deflate",
@ -1009,37 +988,28 @@ impl<W: Write + Seek> GenericZipWriter<W> {
feature = "deflate-zlib" feature = "deflate-zlib"
))] ))]
CompressionMethod::Deflated => { CompressionMethod::Deflated => {
let _ = mem::replace(self, GenericZipWriter::Deflater(DeflateEncoder::new( let level = clamp_opt(
unsafe { mem::transmute_copy::<W, W>(bare) }, compression_level.unwrap_or(flate2::Compression::default().level() as i32),
flate2::Compression::new( deflate_compression_level_range()
clamp_opt( ).ok_or(ZipError::UnsupportedArchive(
compression_level
.unwrap_or(flate2::Compression::default().level() as i32),
deflate_compression_level_range(),
)
.ok_or(ZipError::UnsupportedArchive(
"Unsupported compression level", "Unsupported compression level",
))? as u32, ))? as u32;
), Box::new(move |bare|
))); GenericZipWriter::Deflater(DeflateEncoder::new(bare, flate2::Compression::new(level)))
Ok(()) )
}, }
#[cfg(feature = "bzip2")] #[cfg(feature = "bzip2")]
CompressionMethod::Bzip2 => { CompressionMethod::Bzip2 => {
let _ = mem::replace(self, GenericZipWriter::Bzip2(BzEncoder::new( let level = clamp_opt(
unsafe { mem::transmute_copy::<W, W>(bare) },
bzip2::Compression::new(
clamp_opt(
compression_level compression_level
.unwrap_or(bzip2::Compression::default().level() as i32), .unwrap_or(bzip2::Compression::default().level() as i32),
bzip2_compression_level_range(), bzip2_compression_level_range(),
) ).ok_or(ZipError::UnsupportedArchive(
.ok_or(ZipError::UnsupportedArchive(
"Unsupported compression level", "Unsupported compression level",
))? as u32, ))? as u32;
), Box::new(move |bare|
))); GenericZipWriter::Bzip2(BzEncoder::new(bare, bzip2::Compression::new(level)))
Ok(()) )
}, },
CompressionMethod::AES => { CompressionMethod::AES => {
return Err(ZipError::UnsupportedArchive( return Err(ZipError::UnsupportedArchive(
@ -1048,25 +1018,45 @@ impl<W: Write + Seek> GenericZipWriter<W> {
} }
#[cfg(feature = "zstd")] #[cfg(feature = "zstd")]
CompressionMethod::Zstd => { CompressionMethod::Zstd => {
let _ = mem::replace(self, GenericZipWriter::Zstd( let level = clamp_opt(
ZstdEncoder::new(
unsafe { mem::transmute_copy::<W, W>(bare) },
clamp_opt(
compression_level.unwrap_or(zstd::DEFAULT_COMPRESSION_LEVEL), compression_level.unwrap_or(zstd::DEFAULT_COMPRESSION_LEVEL),
zstd::compression_level_range(), zstd::compression_level_range(),
) )
.ok_or(ZipError::UnsupportedArchive( .ok_or(ZipError::UnsupportedArchive(
"Unsupported compression level", "Unsupported compression level",
))?, ))?;
) Box::new(move |bare|
.unwrap(), GenericZipWriter::Zstd(ZstdEncoder::new(bare, level).unwrap()
)); ))
Ok(())
}, },
CompressionMethod::Unsupported(..) => { CompressionMethod::Unsupported(..) => {
return Err(ZipError::UnsupportedArchive("Unsupported compression")); return Err(ZipError::UnsupportedArchive("Unsupported compression"))
} }
} }
};
let bare = match mem::replace(self, Closed) {
Storer(w) => w,
#[cfg(any(
feature = "deflate",
feature = "deflate-miniz",
feature = "deflate-zlib"
))]
GenericZipWriter::Deflater(w) => w.finish()?,
#[cfg(feature = "bzip2")]
GenericZipWriter::Bzip2(w) => w.finish()?,
#[cfg(feature = "zstd")]
GenericZipWriter::Zstd(w) => w.finish()?,
Closed => {
return Err(io::Error::new(
io::ErrorKind::BrokenPipe,
"ZipWriter was already closed",
)
.into())
}
};
*self = (make_new_self)(bare);
Ok(())
} }
fn ref_mut(&mut self) -> Option<&mut dyn Write> { fn ref_mut(&mut self) -> Option<&mut dyn Write> {
@ -1099,7 +1089,7 @@ impl<W: Write + Seek> GenericZipWriter<W> {
fn current_compression(&self) -> Option<CompressionMethod> { fn current_compression(&self) -> Option<CompressionMethod> {
match *self { match *self {
GenericZipWriter::Storer(..) => Some(CompressionMethod::Stored), Storer(..) => Some(CompressionMethod::Stored),
#[cfg(any( #[cfg(any(
feature = "deflate", feature = "deflate",
feature = "deflate-miniz", feature = "deflate-miniz",
@ -1110,13 +1100,13 @@ impl<W: Write + Seek> GenericZipWriter<W> {
GenericZipWriter::Bzip2(..) => Some(CompressionMethod::Bzip2), GenericZipWriter::Bzip2(..) => Some(CompressionMethod::Bzip2),
#[cfg(feature = "zstd")] #[cfg(feature = "zstd")]
GenericZipWriter::Zstd(..) => Some(CompressionMethod::Zstd), GenericZipWriter::Zstd(..) => Some(CompressionMethod::Zstd),
GenericZipWriter::Closed => None, Closed => None,
} }
} }
fn unwrap(self) -> W { fn unwrap(self) -> W {
match self { match self {
GenericZipWriter::Storer(w) => w, Storer(w) => w,
_ => panic!("Should have switched to stored beforehand"), _ => panic!("Should have switched to stored beforehand"),
} }
} }
@ -1135,7 +1125,7 @@ fn deflate_compression_level_range() -> std::ops::RangeInclusive<i32> {
#[cfg(feature = "bzip2")] #[cfg(feature = "bzip2")]
fn bzip2_compression_level_range() -> std::ops::RangeInclusive<i32> { fn bzip2_compression_level_range() -> std::ops::RangeInclusive<i32> {
let min = bzip2::Compression::none().level() as i32; let min = bzip2::Compression::fast().level() as i32;
let max = bzip2::Compression::best().level() as i32; let max = bzip2::Compression::best().level() as i32;
min..=max min..=max
} }