Wrap files in RefCell

This commit is contained in:
Chris Hennick 2023-04-29 13:47:24 -07:00
parent 59b9279235
commit 8c4ee1f42b
No known key found for this signature in database
GPG key ID: 25653935CC8B6C74

View file

@ -7,6 +7,7 @@ use crate::spec;
use crate::types::{ffi, AtomicU64, DateTime, System, ZipFileData, DEFAULT_VERSION}; use crate::types::{ffi, AtomicU64, DateTime, System, ZipFileData, DEFAULT_VERSION};
use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
use crc32fast::Hasher; use crc32fast::Hasher;
use std::cell::RefCell;
use std::convert::TryInto; use std::convert::TryInto;
use std::default::Default; use std::default::Default;
use std::io; use std::io;
@ -47,6 +48,7 @@ enum GenericZipWriter<W: Write + Seek> {
// Put the struct declaration in a private module to convince rustdoc to display ZipWriter nicely // Put the struct declaration in a private module to convince rustdoc to display ZipWriter nicely
pub(crate) mod zip_writer { pub(crate) mod zip_writer {
use super::*; use super::*;
use std::cell::RefCell;
/// ZIP archive generator /// ZIP archive generator
/// ///
/// Handles the bookkeeping involved in building an archive, and provides an /// Handles the bookkeeping involved in building an archive, and provides an
@ -77,7 +79,7 @@ pub(crate) mod zip_writer {
/// ``` /// ```
pub struct ZipWriter<W: Write + Seek> { pub struct ZipWriter<W: Write + Seek> {
pub(super) inner: GenericZipWriter<W>, pub(super) inner: GenericZipWriter<W>,
pub(super) files: Vec<ZipFileData>, pub(super) files: Vec<RefCell<ZipFileData>>,
pub(super) stats: ZipWriterStats, pub(super) stats: ZipWriterStats,
pub(super) writing_to_file: bool, pub(super) writing_to_file: bool,
pub(super) writing_to_extra_field: bool, pub(super) writing_to_extra_field: bool,
@ -212,13 +214,18 @@ impl<W: Write + Seek> Write for ZipWriter<W> {
match self.inner.ref_mut() { match self.inner.ref_mut() {
Some(ref mut w) => { Some(ref mut w) => {
if self.writing_to_extra_field { if self.writing_to_extra_field {
self.files.last_mut().unwrap().extra_field.write(buf) self.files
.last_mut()
.unwrap()
.borrow_mut()
.extra_field
.write(buf)
} else { } else {
let write_result = w.write(buf); let write_result = w.write(buf);
if let Ok(count) = write_result { if let Ok(count) = write_result {
self.stats.update(&buf[0..count]); self.stats.update(&buf[0..count]);
if self.stats.bytes_written > spec::ZIP64_BYTES_THR if self.stats.bytes_written > spec::ZIP64_BYTES_THR
&& !self.files.last_mut().unwrap().large_file && !self.files.last_mut().unwrap().borrow().large_file
{ {
let _inner = mem::replace(&mut self.inner, GenericZipWriter::Closed); let _inner = mem::replace(&mut self.inner, GenericZipWriter::Closed);
return Err(io::Error::new( return Err(io::Error::new(
@ -276,7 +283,7 @@ impl<A: Read + Write + Seek> ZipWriter<A> {
} }
let files = (0..number_of_files) let files = (0..number_of_files)
.map(|_| central_header_to_zip_file(&mut readwriter, archive_offset)) .map(|_| central_header_to_zip_file(&mut readwriter, archive_offset).map(RefCell::new))
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
let _ = readwriter.seek(SeekFrom::Start(directory_start)); // seek directory_start to overwrite it let _ = readwriter.seek(SeekFrom::Start(directory_start)); // seek directory_start to overwrite it
@ -300,7 +307,8 @@ impl<A: Read + Write + Seek> ZipWriter<A> {
pub fn deep_copy_file(&mut self, src_name: &str, dest_name: &str) -> ZipResult<()> { pub fn deep_copy_file(&mut self, src_name: &str, dest_name: &str) -> ZipResult<()> {
self.finish_file()?; self.finish_file()?;
let write_position = self.inner.get_plain().stream_position()?; let write_position = self.inner.get_plain().stream_position()?;
let src_data = self.data_by_name(src_name)?; let src_data_rc = self.data_by_name(src_name)?;
let src_data = src_data_rc.borrow();
let data_start = src_data.data_start.load(); let data_start = src_data.data_start.load();
let compressed_size = src_data.compressed_size; let compressed_size = src_data.compressed_size;
if compressed_size > write_position - data_start { if compressed_size > write_position - data_start {
@ -320,8 +328,9 @@ impl<A: Read + Write + Seek> ZipWriter<A> {
compressed_size, compressed_size,
uncompressed_size, uncompressed_size,
}; };
drop(src_data);
let mut reader = BufReader::new(ZipFileReader::Raw(find_content( let mut reader = BufReader::new(ZipFileReader::Raw(find_content(
&src_data.to_owned(), &src_data_rc.clone().borrow(),
self.inner.get_plain(), self.inner.get_plain(),
)?)); )?));
let mut copy = Vec::with_capacity(compressed_size as usize); let mut copy = Vec::with_capacity(compressed_size as usize);
@ -422,7 +431,7 @@ impl<W: Write + Seek> ZipWriter<W> {
self.stats.bytes_written = 0; self.stats.bytes_written = 0;
self.stats.hasher = Hasher::new(); self.stats.hasher = Hasher::new();
self.files.push(file); self.files.push(RefCell::new(file));
} }
Ok(()) Ok(())
@ -437,9 +446,9 @@ impl<W: Write + Seek> ZipWriter<W> {
let writer = self.inner.get_plain(); let writer = self.inner.get_plain();
if !self.writing_raw { if !self.writing_raw {
let file = match self.files.last_mut() { let mut file = match self.files.last_mut() {
None => return Ok(()), None => return Ok(()),
Some(f) => f, Some(f) => f.borrow_mut(),
}; };
file.crc32 = self.stats.hasher.clone().finalize(); file.crc32 = self.stats.hasher.clone().finalize();
file.uncompressed_size = self.stats.bytes_written; file.uncompressed_size = self.stats.bytes_written;
@ -447,7 +456,7 @@ impl<W: Write + Seek> ZipWriter<W> {
let file_end = writer.stream_position()?; let file_end = writer.stream_position()?;
file.compressed_size = file_end - self.stats.start; file.compressed_size = file_end - self.stats.start;
update_local_file_header(writer, file)?; update_local_file_header(writer, &file)?;
writer.seek(SeekFrom::Start(file_end))?; writer.seek(SeekFrom::Start(file_end))?;
} }
@ -598,7 +607,7 @@ impl<W: Write + Seek> ZipWriter<W> {
self.start_entry(name, options, None)?; self.start_entry(name, options, None)?;
self.writing_to_file = true; self.writing_to_file = true;
self.writing_to_extra_field = true; self.writing_to_extra_field = true;
Ok(self.files.last().unwrap().data_start.load()) Ok(self.files.last().unwrap().borrow().data_start.load())
} }
/// End local and start central extra data. Requires [`ZipWriter::start_file_with_extra_data`]. /// End local and start central extra data. Requires [`ZipWriter::start_file_with_extra_data`].
@ -606,7 +615,12 @@ impl<W: Write + Seek> ZipWriter<W> {
/// Returns the final starting offset of the file data. /// Returns the final starting offset of the file data.
pub fn end_local_start_central_extra_data(&mut self) -> ZipResult<u64> { pub fn end_local_start_central_extra_data(&mut self) -> ZipResult<u64> {
let data_start = self.end_extra_data()?; let data_start = self.end_extra_data()?;
self.files.last_mut().unwrap().extra_field.clear(); self.files
.last_mut()
.unwrap()
.borrow_mut()
.extra_field
.clear();
self.writing_to_extra_field = true; self.writing_to_extra_field = true;
self.writing_to_central_extra_field_only = true; self.writing_to_central_extra_field_only = true;
Ok(data_start) Ok(data_start)
@ -625,9 +639,10 @@ impl<W: Write + Seek> ZipWriter<W> {
} }
let file = self.files.last_mut().unwrap(); let file = self.files.last_mut().unwrap();
validate_extra_data(file)?; validate_extra_data(&file.borrow())?;
let data_start = file.data_start.get_mut(); let mut file = file.borrow_mut();
let mut data_start_result = file.data_start.load();
if !self.writing_to_central_extra_field_only { if !self.writing_to_central_extra_field_only {
let writer = self.inner.get_plain(); let writer = self.inner.get_plain();
@ -636,9 +651,12 @@ impl<W: Write + Seek> ZipWriter<W> {
writer.write_all(&file.extra_field)?; writer.write_all(&file.extra_field)?;
// Update final `data_start`. // Update final `data_start`.
let header_end = *data_start + file.extra_field.len() as u64; let length = file.extra_field.len();
let data_start = file.data_start.get_mut();
let header_end = *data_start + length as u64;
self.stats.start = header_end; self.stats.start = header_end;
*data_start = header_end; *data_start = header_end;
data_start_result = header_end;
// Update extra field length in local file header. // Update extra field length in local file header.
let extra_field_length = let extra_field_length =
@ -653,7 +671,7 @@ impl<W: Write + Seek> ZipWriter<W> {
self.writing_to_extra_field = false; self.writing_to_extra_field = false;
self.writing_to_central_extra_field_only = false; self.writing_to_central_extra_field_only = false;
Ok(*data_start) Ok(data_start_result)
} }
/// Add a new file using the already compressed data from a ZIP file being read and renames it, this /// Add a new file using the already compressed data from a ZIP file being read and renames it, this
@ -835,7 +853,7 @@ impl<W: Write + Seek> ZipWriter<W> {
let central_start = writer.stream_position()?; let central_start = writer.stream_position()?;
for file in self.files.iter() { for file in self.files.iter() {
write_central_directory_header(writer, file)?; write_central_directory_header(writer, &file.borrow())?;
} }
let central_size = writer.stream_position()? - central_start; let central_size = writer.stream_position()? - central_start;
@ -881,9 +899,9 @@ impl<W: Write + Seek> ZipWriter<W> {
Ok(()) Ok(())
} }
fn data_by_name(&self, name: &str) -> ZipResult<&ZipFileData> { fn data_by_name(&self, name: &str) -> ZipResult<&RefCell<ZipFileData>> {
for file in self.files.iter() { for file in self.files.iter() {
if file.file_name == name { if file.borrow().file_name == name {
return Ok(file); return Ok(file);
} }
} }
@ -897,10 +915,11 @@ impl<W: Write + Seek> ZipWriter<W> {
/// some other software (e.g. Minecraft) will refuse to extract a file copied this way. /// some other software (e.g. Minecraft) will refuse to extract a file copied this way.
pub fn shallow_copy_file(&mut self, src_name: &str, dest_name: &str) -> ZipResult<()> { pub fn shallow_copy_file(&mut self, src_name: &str, dest_name: &str) -> ZipResult<()> {
self.finish_file()?; self.finish_file()?;
let src_data = self.data_by_name(src_name)?; let src_data = self.data_by_name(src_name)?.borrow();
let mut dest_data = src_data.to_owned(); let mut dest_data = src_data.to_owned();
drop(src_data);
dest_data.file_name = dest_name.into(); dest_data.file_name = dest_name.into();
self.files.push(dest_data); self.files.push(RefCell::new(dest_data));
Ok(()) Ok(())
} }
} }