fix: deflate decompression failing for files with high compression ratio

This commit is contained in:
Erica Marigold 2025-01-02 06:07:53 +00:00
parent 6f4083f10f
commit 06b1f1a640
Signed by: DevComp
GPG key ID: 429EF1C337871656
2 changed files with 305 additions and 300 deletions

View file

@ -3,19 +3,19 @@ local Tree = {}
export type Tree = typeof(setmetatable({} :: TreeInner, { __index = Tree })) export type Tree = typeof(setmetatable({} :: TreeInner, { __index = Tree }))
type TreeInner = { type TreeInner = {
table: { number }, -- Length of 16, stores code length counts table: { number }, -- Length of 16, stores code length counts
trans: { number }, -- Length of 288, stores code to symbol translations trans: { number }, -- Length of 288, stores code to symbol translations
} }
--- Creates a new Tree instance with initialized tables --- Creates a new Tree instance with initialized tables
function Tree.new(): Tree function Tree.new(): Tree
return setmetatable( return setmetatable(
{ {
table = table.create(16, 0), table = table.create(16, 0),
trans = table.create(288, 0), trans = table.create(288, 0),
} :: TreeInner, } :: TreeInner,
{ __index = Tree } { __index = Tree }
) )
end end
-- Data class for managing compression state and buffers -- Data class for managing compression state and buffers
@ -37,19 +37,19 @@ export type DataInner = {
--- Creates a new Data instance with initialized compression state --- Creates a new Data instance with initialized compression state
function Data.new(source: buffer, dest: buffer): Data function Data.new(source: buffer, dest: buffer): Data
return setmetatable( return setmetatable(
{ {
source = source, source = source,
sourceIndex = 0, sourceIndex = 0,
tag = 0, tag = 0,
bitcount = 0, bitcount = 0,
dest = dest, dest = dest,
destLen = 0, destLen = 0,
ltree = Tree.new(), ltree = Tree.new(),
dtree = Tree.new(), dtree = Tree.new(),
} :: DataInner, } :: DataInner,
{ __index = Data } { __index = Data }
) )
end end
-- Static Huffman trees used for fixed block types -- Static Huffman trees used for fixed block types
@ -76,56 +76,56 @@ local lengths = table.create(288 + 32, 0)
--- Builds the extra bits and base tables for length and distance codes --- Builds the extra bits and base tables for length and distance codes
local function buildBitsBase(bits: { number }, base: { number }, delta: number, first: number) local function buildBitsBase(bits: { number }, base: { number }, delta: number, first: number)
local sum = first local sum = first
-- Initialize the bits table with appropriate bit lengths -- Initialize the bits table with appropriate bit lengths
for i = 0, delta - 1 do for i = 0, delta - 1 do
bits[i] = 0 bits[i] = 0
end end
for i = 0, 29 - delta do for i = 0, 29 - delta do
bits[i + delta] = math.floor(i / delta) bits[i + delta] = math.floor(i / delta)
end end
-- Build the base value table using bit lengths -- Build the base value table using bit lengths
for i = 0, 29 do for i = 0, 29 do
base[i] = sum base[i] = sum
sum += bit32.lshift(1, bits[i]) sum += bit32.lshift(1, bits[i])
end end
end end
--- Constructs the fixed Huffman trees used in DEFLATE format --- Constructs the fixed Huffman trees used in DEFLATE format
local function buildFixedTrees(lengthTree: Tree, distTree: Tree) local function buildFixedTrees(lengthTree: Tree, distTree: Tree)
-- Build the fixed length tree according to DEFLATE specification -- Build the fixed length tree according to DEFLATE specification
for i = 0, 6 do for i = 0, 6 do
lengthTree.table[i] = 0 lengthTree.table[i] = 0
end end
lengthTree.table[7] = 24 lengthTree.table[7] = 24
lengthTree.table[8] = 152 lengthTree.table[8] = 152
lengthTree.table[9] = 112 lengthTree.table[9] = 112
-- Populate the translation table for length codes -- Populate the translation table for length codes
for i = 0, 23 do for i = 0, 23 do
lengthTree.trans[i] = 256 + i lengthTree.trans[i] = 256 + i
end end
for i = 0, 143 do for i = 0, 143 do
lengthTree.trans[24 + i] = i lengthTree.trans[24 + i] = i
end end
for i = 0, 7 do for i = 0, 7 do
lengthTree.trans[24 + 144 + i] = 280 + i lengthTree.trans[24 + 144 + i] = 280 + i
end end
for i = 0, 111 do for i = 0, 111 do
lengthTree.trans[24 + 144 + 8 + i] = 144 + i lengthTree.trans[24 + 144 + 8 + i] = 144 + i
end end
-- Build the fixed distance tree (simpler than length tree) -- Build the fixed distance tree (simpler than length tree)
for i = 0, 4 do for i = 0, 4 do
distTree.table[i] = 0 distTree.table[i] = 0
end end
distTree.table[5] = 32 distTree.table[5] = 32
for i = 0, 31 do for i = 0, 31 do
distTree.trans[i] = i distTree.trans[i] = i
end end
end end
--- Temporary array for building trees --- Temporary array for building trees
@ -133,247 +133,249 @@ local offs = table.create(16, 0)
--- Builds a Huffman tree from a list of code lengths --- Builds a Huffman tree from a list of code lengths
local function buildTree(t: Tree, lengths: { number }, off: number, num: number) local function buildTree(t: Tree, lengths: { number }, off: number, num: number)
-- Initialize the code length count table -- Initialize the code length count table
for i = 0, 15 do for i = 0, 15 do
t.table[i] = 0 t.table[i] = 0
end end
-- Count the frequency of each code length -- Count the frequency of each code length
for i = 0, num - 1 do for i = 0, num - 1 do
t.table[lengths[off + i]] += 1 t.table[lengths[off + i]] += 1
end end
t.table[0] = 0 t.table[0] = 0
-- Calculate offsets for distribution sort -- Calculate offsets for distribution sort
local sum = 0 local sum = 0
for i = 0, 15 do for i = 0, 15 do
offs[i] = sum offs[i] = sum
sum += t.table[i] sum += t.table[i]
end end
-- Create the translation table mapping codes to symbols -- Create the translation table mapping codes to symbols
for i = 0, num - 1 do for i = 0, num - 1 do
local len = lengths[off + i] local len = lengths[off + i]
if len > 0 then if len > 0 then
t.trans[offs[len]] = i t.trans[offs[len]] = i
offs[len] += 1 offs[len] += 1
end end
end end
end end
--- Reads a single bit from the input stream --- Reads a single bit from the input stream
local function getBit(d: Data): number local function getBit(d: Data): number
if d.bitcount <= 0 then if d.bitcount <= 0 then
d.tag = buffer.readu8(d.source, d.sourceIndex) d.tag = buffer.readu8(d.source, d.sourceIndex)
d.sourceIndex += 1 d.sourceIndex += 1
d.bitcount = 8 d.bitcount = 8
end end
local bit = bit32.band(d.tag, 1) local bit = bit32.band(d.tag, 1)
d.tag = bit32.rshift(d.tag, 1) d.tag = bit32.rshift(d.tag, 1)
d.bitcount -= 1 d.bitcount -= 1
return bit return bit
end end
--- Reads multiple bits from the input stream with a base value --- Reads multiple bits from the input stream with a base value
local function readBits(d: Data, num: number?, base: number): number local function readBits(d: Data, num: number?, base: number): number
if not num then if not num then
return base return base
end end
-- Ensure we have enough bits in the buffer -- Ensure we have enough bits in the buffer
while d.bitcount < 24 and d.sourceIndex < buffer.len(d.source) do while d.bitcount < 24 and d.sourceIndex < buffer.len(d.source) do
d.tag = bit32.bor(d.tag, bit32.lshift(buffer.readu8(d.source, d.sourceIndex), d.bitcount)) d.tag = bit32.bor(d.tag, bit32.lshift(buffer.readu8(d.source, d.sourceIndex), d.bitcount))
d.sourceIndex += 1 d.sourceIndex += 1
d.bitcount += 8 d.bitcount += 8
end end
local val = bit32.band(d.tag, bit32.rshift(0xffff, 16 - num)) local val = bit32.band(d.tag, bit32.rshift(0xffff, 16 - num))
d.tag = bit32.rshift(d.tag, num) d.tag = bit32.rshift(d.tag, num)
d.bitcount -= num d.bitcount -= num
return val + base return val + base
end end
--- Decodes a symbol using a Huffman tree --- Decodes a symbol using a Huffman tree
local function decodeSymbol(d: Data, t: Tree): number local function decodeSymbol(d: Data, t: Tree): number
while d.bitcount < 24 and d.sourceIndex < buffer.len(d.source) do while d.bitcount < 24 and d.sourceIndex < buffer.len(d.source) do
d.tag = bit32.bor(d.tag, bit32.lshift(buffer.readu8(d.source, d.sourceIndex), d.bitcount)) d.tag = bit32.bor(d.tag, bit32.lshift(buffer.readu8(d.source, d.sourceIndex), d.bitcount))
d.sourceIndex += 1 d.sourceIndex += 1
d.bitcount += 8 d.bitcount += 8
end end
local sum, cur, len = 0, 0, 0 local sum, cur, len = 0, 0, 0
local tag = d.tag local tag = d.tag
-- Traverse the Huffman tree to find the symbol -- Traverse the Huffman tree to find the symbol
repeat repeat
cur = 2 * cur + bit32.band(tag, 1) cur = 2 * cur + bit32.band(tag, 1)
tag = bit32.rshift(tag, 1) tag = bit32.rshift(tag, 1)
len += 1 len += 1
sum += t.table[len] sum += t.table[len]
cur -= t.table[len] cur -= t.table[len]
until cur < 0 until cur < 0
d.tag = tag d.tag = tag
d.bitcount -= len d.bitcount -= len
return t.trans[sum + cur] return t.trans[sum + cur]
end end
--- Decodes the dynamic Huffman trees for a block --- Decodes the dynamic Huffman trees for a block
local function decodeTrees(d: Data, lengthTree: Tree, distTree: Tree) local function decodeTrees(d: Data, lengthTree: Tree, distTree: Tree)
local hlit = readBits(d, 5, 257) -- Number of literal/length codes local hlit = readBits(d, 5, 257) -- Number of literal/length codes
local hdist = readBits(d, 5, 1) -- Number of distance codes local hdist = readBits(d, 5, 1) -- Number of distance codes
local hclen = readBits(d, 4, 4) -- Number of code length codes local hclen = readBits(d, 4, 4) -- Number of code length codes
-- Initialize code lengths array -- Initialize code lengths array
for i = 0, 18 do for i = 0, 18 do
lengths[i] = 0 lengths[i] = 0
end end
-- Read code lengths for the code length alphabet -- Read code lengths for the code length alphabet
for i = 0, hclen - 1 do for i = 0, hclen - 1 do
lengths[clcIndex[i + 1]] = readBits(d, 3, 0) lengths[clcIndex[i + 1]] = readBits(d, 3, 0)
end end
-- Build the code lengths tree -- Build the code lengths tree
buildTree(codeTree, lengths, 0, 19) buildTree(codeTree, lengths, 0, 19)
-- Decode length/distance tree code lengths -- Decode length/distance tree code lengths
local num = 0 local num = 0
while num < hlit + hdist do while num < hlit + hdist do
local sym = decodeSymbol(d, codeTree) local sym = decodeSymbol(d, codeTree)
if sym == 16 then if sym == 16 then
-- Copy previous code length 3-6 times -- Copy previous code length 3-6 times
local prev = lengths[num - 1] local prev = lengths[num - 1]
for _ = 1, readBits(d, 2, 3) do for _ = 1, readBits(d, 2, 3) do
lengths[num] = prev lengths[num] = prev
num += 1 num += 1
end end
elseif sym == 17 then elseif sym == 17 then
-- Repeat zero 3-10 times -- Repeat zero 3-10 times
for _ = 1, readBits(d, 3, 3) do for _ = 1, readBits(d, 3, 3) do
lengths[num] = 0 lengths[num] = 0
num += 1 num += 1
end end
elseif sym == 18 then elseif sym == 18 then
-- Repeat zero 11-138 times -- Repeat zero 11-138 times
for _ = 1, readBits(d, 7, 11) do for _ = 1, readBits(d, 7, 11) do
lengths[num] = 0 lengths[num] = 0
num += 1 num += 1
end end
else else
-- Regular code length 0-15 -- Regular code length 0-15
lengths[num] = sym lengths[num] = sym
num += 1 num += 1
end end
end end
-- Build the literal/length and distance trees -- Build the literal/length and distance trees
buildTree(lengthTree, lengths, 0, hlit) buildTree(lengthTree, lengths, 0, hlit)
buildTree(distTree, lengths, hlit, hdist) buildTree(distTree, lengths, hlit, hdist)
end end
--- Inflates a block of data using Huffman trees --- Inflates a block of data using Huffman trees
local function inflateBlockData(d: Data, lengthTree: Tree, distTree: Tree) local function inflateBlockData(d: Data, lengthTree: Tree, distTree: Tree)
while true do while true do
local sym = decodeSymbol(d, lengthTree) local sym = decodeSymbol(d, lengthTree)
if sym == 256 then if sym == 256 then
-- End of block -- End of block
return return
end end
if sym < 256 then if sym < 256 then
-- Literal byte -- Literal byte
buffer.writeu8(d.dest, d.destLen, sym) buffer.writeu8(d.dest, d.destLen, sym)
d.destLen += 1 d.destLen += 1
else else
-- Length/distance pair for copying -- Length/distance pair for copying
sym -= 257 sym -= 257
local length = readBits(d, lengthBits[sym], lengthBase[sym]) local length = readBits(d, lengthBits[sym], lengthBase[sym])
local dist = decodeSymbol(d, distTree) local dist = decodeSymbol(d, distTree)
local offs = d.destLen - readBits(d, distBits[dist], distBase[dist]) local offs = d.destLen - readBits(d, distBits[dist], distBase[dist])
-- Copy bytes from back reference -- Copy bytes from back reference
for i = offs, offs + length - 1 do for i = offs, offs + length - 1 do
buffer.writeu8(d.dest, d.destLen, buffer.readu8(d.dest, i)) buffer.writeu8(d.dest, d.destLen, buffer.readu8(d.dest, i))
d.destLen += 1 d.destLen += 1
end end
end end
end end
end end
--- Processes an uncompressed block --- Processes an uncompressed block
local function inflateUncompressedBlock(d: Data) local function inflateUncompressedBlock(d: Data)
-- Align to byte boundary -- Align to byte boundary
while d.bitcount > 8 do while d.bitcount > 8 do
d.sourceIndex -= 1 d.sourceIndex -= 1
d.bitcount -= 8 d.bitcount -= 8
end end
-- Read block length and its complement -- Read block length and its complement
local length = buffer.readu8(d.source, d.sourceIndex + 1) local length = buffer.readu8(d.source, d.sourceIndex + 1)
length = 256 * length + buffer.readu8(d.source, d.sourceIndex) length = 256 * length + buffer.readu8(d.source, d.sourceIndex)
local invlength = buffer.readu8(d.source, d.sourceIndex + 3) local invlength = buffer.readu8(d.source, d.sourceIndex + 3)
invlength = 256 * invlength + buffer.readu8(d.source, d.sourceIndex + 2) invlength = 256 * invlength + buffer.readu8(d.source, d.sourceIndex + 2)
-- Verify block length using ones complement -- Verify block length using ones complement
if length ~= bit32.bxor(invlength, 0xffff) then if length ~= bit32.bxor(invlength, 0xffff) then
error("Invalid block length") error("Invalid block length")
end end
d.sourceIndex += 4 d.sourceIndex += 4
-- Copy uncompressed data to output -- Copy uncompressed data to output
for _ = 1, length do for _ = 1, length do
buffer.writeu8(d.dest, d.destLen, buffer.readu8(d.source, d.sourceIndex)) buffer.writeu8(d.dest, d.destLen, buffer.readu8(d.source, d.sourceIndex))
d.destLen += 1 d.destLen += 1
d.sourceIndex += 1 d.sourceIndex += 1
end end
d.bitcount = 0 d.bitcount = 0
end end
--- Main decompression function that processes DEFLATE compressed data --- Main decompression function that processes DEFLATE compressed data
local function uncompress(source: buffer): buffer local function uncompress(source: buffer, uncompressedSize: number?): buffer
-- FIXME: This is a temporary solution to avoid a buffer overflow local dest = buffer.create(
-- We likely want some type of reflection with the zip metadata to -- If the uncompressed size is known, we use that, otherwise we use a default
-- have a definitive buffer size -- size that is a 7 times more than the compressed size; this factor works
local dest = buffer.create(buffer.len(source) * 7) -- well for most cases other than those with a very high compression ratio
local d = Data.new(source, dest) uncompressedSize or buffer.len(source) * 7
)
local d = Data.new(source, dest)
repeat repeat
local bfinal = getBit(d) -- Last block flag local bfinal = getBit(d) -- Last block flag
local btype = readBits(d, 2, 0) -- Block type (0=uncompressed, 1=fixed, 2=dynamic) local btype = readBits(d, 2, 0) -- Block type (0=uncompressed, 1=fixed, 2=dynamic)
if btype == 0 then if btype == 0 then
inflateUncompressedBlock(d) inflateUncompressedBlock(d)
elseif btype == 1 then elseif btype == 1 then
inflateBlockData(d, staticLengthTree, staticDistTree) inflateBlockData(d, staticLengthTree, staticDistTree)
elseif btype == 2 then elseif btype == 2 then
decodeTrees(d, d.ltree, d.dtree) decodeTrees(d, d.ltree, d.dtree)
inflateBlockData(d, d.ltree, d.dtree) inflateBlockData(d, d.ltree, d.dtree)
else else
error("Invalid block type") error("Invalid block type")
end end
until bfinal == 1 until bfinal == 1
-- Trim output buffer to actual size if needed -- Trim output buffer to actual size if needed
if d.destLen < buffer.len(dest) then if d.destLen < buffer.len(dest) then
local result = buffer.create(d.destLen) local result = buffer.create(d.destLen)
buffer.copy(result, 0, dest, 0, d.destLen) buffer.copy(result, 0, dest, 0, d.destLen)
return result return result
end end
return dest return dest
end end
-- Initialize static trees and lookup tables for DEFLATE format -- Initialize static trees and lookup tables for DEFLATE format

View file

@ -32,20 +32,23 @@ local function validateCrc(decompressed: buffer, validation: CrcValidationOption
end end
end end
local DECOMPRESSION_ROUTINES: { [number]: (buffer, validation: CrcValidationOptions) -> buffer } = table.freeze({ local DECOMPRESSION_ROUTINES: { [number]: (buffer, number, CrcValidationOptions) -> buffer } =
-- `STORE` decompression method - No compression table.freeze({
[0x00] = function(buf, validation) -- `STORE` decompression method - No compression
validateCrc(buf, validation) [0x00] = function(buf, _, validation)
return buf validateCrc(buf, validation)
end, return buf
end,
-- `DEFLATE` decompression method - Compressed raw deflate chunks -- `DEFLATE` decompression method - Compressed raw deflate chunks
[0x08] = function(buf, validation) [0x08] = function(buf, uncompressedSize, validation)
local decompressed = inflate(buf) -- FIXME: Why is uncompressedSize not getting inferred correctly although it
validateCrc(decompressed, validation) -- is typed?
return decompressed local decompressed = inflate(buf, uncompressedSize :: any)
end, validateCrc(decompressed, validation)
}) return decompressed
end,
})
-- TODO: ERROR HANDLING! -- TODO: ERROR HANDLING!
@ -175,62 +178,62 @@ function ZipReader.parseCentralDirectory(self: ZipReader): ()
end end
function ZipReader.buildDirectoryTree(self: ZipReader): () function ZipReader.buildDirectoryTree(self: ZipReader): ()
-- Sort entries to process directories first; I could either handle -- Sort entries to process directories first; I could either handle
-- directories and files in separate passes over the entries, or sort -- directories and files in separate passes over the entries, or sort
-- the entries so I handled the directories first -- I decided to do -- the entries so I handled the directories first -- I decided to do
-- the latter -- the latter
table.sort(self.entries, function(a, b) table.sort(self.entries, function(a, b)
if a.isDirectory ~= b.isDirectory then if a.isDirectory ~= b.isDirectory then
return a.isDirectory return a.isDirectory
end end
return a.name < b.name return a.name < b.name
end) end)
for _, entry in self.entries do for _, entry in self.entries do
local parts = {} local parts = {}
-- Split entry path into individual components -- Split entry path into individual components
-- e.g. "folder/subfolder/file.txt" -> {"folder", "subfolder", "file.txt"} -- e.g. "folder/subfolder/file.txt" -> {"folder", "subfolder", "file.txt"}
for part in string.gmatch(entry.name, "([^/]+)/?") do for part in string.gmatch(entry.name, "([^/]+)/?") do
table.insert(parts, part) table.insert(parts, part)
end end
-- Start from root directory -- Start from root directory
local current = self.root local current = self.root
local path = "" local path = ""
-- Process each path component -- Process each path component
for i, part in parts do for i, part in parts do
path ..= part path ..= part
if i < #parts or entry.isDirectory then if i < #parts or entry.isDirectory then
-- Create missing directory entries for intermediate paths -- Create missing directory entries for intermediate paths
if not self.directories[path] then if not self.directories[path] then
if entry.isDirectory and i == #parts then if entry.isDirectory and i == #parts then
-- Existing directory entry, reuse it -- Existing directory entry, reuse it
self.directories[path] = entry self.directories[path] = entry
else else
-- Create new directory entry for intermediate paths or undefined -- Create new directory entry for intermediate paths or undefined
-- parent directories in the ZIP -- parent directories in the ZIP
local dir = ZipEntry.new(path .. "/", 0, 0, entry.timestamp, 0) local dir = ZipEntry.new(path .. "/", 0, 0, entry.timestamp, 0)
dir.isDirectory = true dir.isDirectory = true
dir.parent = current dir.parent = current
self.directories[path] = dir self.directories[path] = dir
end end
-- Track directory in both lookup table and parent's children -- Track directory in both lookup table and parent's children
table.insert(current.children, self.directories[path]) table.insert(current.children, self.directories[path])
end end
-- Move deeper into the tree -- Move deeper into the tree
current = self.directories[path] current = self.directories[path]
continue continue
end end
-- Link file entry to its parent directory -- Link file entry to its parent directory
entry.parent = current entry.parent = current
table.insert(current.children, entry) table.insert(current.children, entry)
end end
end end
end end
function ZipReader.findEntry(self: ZipReader, path: string): ZipEntry function ZipReader.findEntry(self: ZipReader, path: string): ZipEntry
@ -364,7 +367,7 @@ function ZipReader.extract(self: ZipReader, entry: ZipEntry, options: Extraction
error(`Unsupported compression, ID: {compressionMethod}`) error(`Unsupported compression, ID: {compressionMethod}`)
end end
content = decompress(content, { content = decompress(content, uncompressedSize, {
expected = crcChecksum, expected = crcChecksum,
skip = optionsOrDefault.skipCrcValidation, skip = optionsOrDefault.skipCrcValidation,
}) })