mirror of
https://github.com/0x5eal/luau-unzip.git
synced 2025-04-04 06:30:53 +01:00
fix: deflate decompression failing for files with high compression ratio
This commit is contained in:
parent
6f4083f10f
commit
06b1f1a640
2 changed files with 305 additions and 300 deletions
488
lib/inflate.luau
488
lib/inflate.luau
|
@ -3,19 +3,19 @@ local Tree = {}
|
|||
|
||||
-- A Huffman tree object: the `TreeInner` fields with `Tree` attached through `__index`
export type Tree = typeof(setmetatable({} :: TreeInner, { __index = Tree }))
-- Raw storage backing a tree
type TreeInner = {
	table: { number }, -- Length of 16, stores code length counts
	trans: { number }, -- Length of 288, stores code to symbol translations
}
|
||||
|
||||
--- Creates a new Tree instance with initialized tables
|
||||
function Tree.new(): Tree
	-- Both arrays are pre-sized and zero-filled up front
	local inner: TreeInner = {
		-- 16 slots: number of codes for each code length
		table = table.create(16, 0),
		-- 288 slots: code-order position -> symbol
		trans = table.create(288, 0),
	}

	return setmetatable(inner, { __index = Tree })
end
|
||||
|
||||
-- Data class for managing compression state and buffers
|
||||
|
@ -37,19 +37,19 @@ export type DataInner = {
|
|||
|
||||
--- Creates a new Data instance with initialized compression state
|
||||
function Data.new(source: buffer, dest: buffer): Data
	-- All counters start at zero; the dynamic trees are allocated once per
	-- Data instance and stored on it as `ltree` / `dtree`
	local state: DataInner = {
		source = source, -- compressed input
		sourceIndex = 0, -- offset of the next input byte to load
		tag = 0, -- bit buffer holding already-loaded input bits
		bitcount = 0, -- number of valid bits currently in `tag`
		dest = dest, -- output buffer
		destLen = 0, -- number of bytes written to `dest` so far
		ltree = Tree.new(), -- dynamic literal/length tree
		dtree = Tree.new(), -- dynamic distance tree
	}

	return setmetatable(state, { __index = Data })
end
|
||||
|
||||
-- Static Huffman trees used for fixed block types
|
||||
|
@ -76,56 +76,56 @@ local lengths = table.create(288 + 32, 0)
|
|||
|
||||
--- Builds the extra bits and base tables for length and distance codes
|
||||
local function buildBitsBase(bits: { number }, base: { number }, delta: number, first: number)
	-- Fills `bits[0..29]` with the extra-bit count of each code and
	-- `base[0..29]` with the first value each code represents.
	--   bits:  output table, indexed from 0
	--   base:  output table, indexed from 0
	--   delta: how many codes share the same extra-bit count
	--   first: value represented by code 0

	-- The first `delta` codes carry no extra bits
	for code = 0, delta - 1 do
		bits[code] = 0
	end
	-- After that the extra-bit count is floor(i / delta), i.e. it grows by
	-- one every `delta` codes
	for i = 0, 29 - delta do
		bits[i + delta] = math.floor(i / delta)
	end

	-- Each base is the previous base plus the span (2^bits) the previous
	-- code covers
	local sum = first
	for code = 0, 29 do
		base[code] = sum
		sum += bit32.lshift(1, bits[code])
	end
end
|
||||
|
||||
--- Constructs the fixed Huffman trees used in DEFLATE format
|
||||
local function buildFixedTrees(lengthTree: Tree, distTree: Tree)
	-- Fills both trees in place with the fixed code tables; all indices are
	-- 0-based.

	-- Literal/length tree: no codes shorter than 7 bits, then 24 codes of
	-- 7 bits, 152 of 8 bits and 112 of 9 bits
	for len = 0, 6 do
		lengthTree.table[len] = 0
	end
	lengthTree.table[7] = 24
	lengthTree.table[8] = 152
	lengthTree.table[9] = 112

	-- Symbols laid out in code order:
	--   7-bit codes -> symbols 256..279
	for i = 0, 23 do
		lengthTree.trans[i] = 256 + i
	end
	--   8-bit codes -> symbols 0..143
	for i = 0, 143 do
		lengthTree.trans[24 + i] = i
	end
	--   8-bit codes (continued) -> symbols 280..287
	for i = 0, 7 do
		lengthTree.trans[24 + 144 + i] = 280 + i
	end
	--   9-bit codes -> symbols 144..255
	for i = 0, 111 do
		lengthTree.trans[24 + 144 + 8 + i] = 144 + i
	end

	-- Distance tree: 32 codes, all 5 bits long, mapped to themselves
	for len = 0, 4 do
		distTree.table[len] = 0
	end
	distTree.table[5] = 32

	for i = 0, 31 do
		distTree.trans[i] = i
	end
end
|
||||
|
||||
--- Temporary array for building trees
|
||||
|
@ -133,247 +133,249 @@ local offs = table.create(16, 0)
|
|||
|
||||
--- Builds a Huffman tree from a list of code lengths
|
||||
local function buildTree(t: Tree, lengths: { number }, off: number, num: number)
	-- Rebuilds `t` in place from `num` code lengths read from
	-- `lengths[off .. off + num - 1]`. Uses the module-level `offs` table as
	-- scratch space.

	-- Reset the per-length counters
	for len = 0, 15 do
		t.table[len] = 0
	end

	-- Count how many symbols use each code length
	for i = 0, num - 1 do
		t.table[lengths[off + i]] += 1
	end

	-- Length 0 means "symbol unused", so it must not count as a code
	t.table[0] = 0

	-- Running totals give the first slot in `trans` for every code length
	local sum = 0
	for len = 0, 15 do
		offs[len] = sum
		sum += t.table[len]
	end

	-- Place each used symbol into the next free slot for its length; visiting
	-- symbols in ascending order keeps them sorted within a length
	for sym = 0, num - 1 do
		local len = lengths[off + sym]
		if len > 0 then
			t.trans[offs[len]] = sym
			offs[len] += 1
		end
	end
end
|
||||
|
||||
--- Reads a single bit from the input stream
|
||||
local function getBit(d: Data): number
	-- Returns the next input bit (0 or 1), least significant bit first.

	-- Bit buffer empty: load one fresh byte. This overwrites `tag` rather
	-- than OR-ing into it. No bounds check is done here, so reading past the
	-- end of `source` raises from `buffer.readu8`.
	if d.bitcount <= 0 then
		d.tag = buffer.readu8(d.source, d.sourceIndex)
		d.sourceIndex += 1
		d.bitcount = 8
	end

	-- Take the lowest bit and shift it out of the buffer
	local bit = bit32.band(d.tag, 1)
	d.tag = bit32.rshift(d.tag, 1)
	d.bitcount -= 1

	return bit
end
|
||||
|
||||
--- Reads multiple bits from the input stream with a base value
|
||||
local function readBits(d: Data, num: number?, base: number): number
	-- Reads `num` bits (least significant first) and returns them added to
	-- `base`. A nil `num` returns `base` unchanged without touching the
	-- stream. Note that 0 is truthy in Luau, so `num == 0` takes the normal
	-- path and still yields `base`.
	if not num then
		return base
	end

	-- Top up the bit buffer one byte at a time while input remains; stopping
	-- below 24 bits keeps the buffer within 32 bits after adding a byte
	while d.bitcount < 24 and d.sourceIndex < buffer.len(d.source) do
		d.tag = bit32.bor(d.tag, bit32.lshift(buffer.readu8(d.source, d.sourceIndex), d.bitcount))
		d.sourceIndex += 1
		d.bitcount += 8
	end

	-- The mask is derived from a 16-bit constant, so `num` must not exceed 16
	local mask = bit32.rshift(0xffff, 16 - num)
	local val = bit32.band(d.tag, mask)
	d.tag = bit32.rshift(d.tag, num)
	d.bitcount -= num

	return val + base
end
|
||||
|
||||
--- Decodes a symbol using a Huffman tree
|
||||
local function decodeSymbol(d: Data, t: Tree): number
	-- Reads one Huffman code from the stream and returns the symbol that
	-- `t` maps it to.

	-- Top up the bit buffer while input remains
	while d.bitcount < 24 and d.sourceIndex < buffer.len(d.source) do
		d.tag = bit32.bor(d.tag, bit32.lshift(buffer.readu8(d.source, d.sourceIndex), d.bitcount))
		d.sourceIndex += 1
		d.bitcount += 8
	end

	-- sum: number of codes with length <= len
	-- cur: code read so far, relative to the first code of the current length
	-- len: number of bits consumed
	local sum, cur, len = 0, 0, 0
	local tag = d.tag

	-- Extend the code one bit at a time; once `cur` drops below zero the
	-- code belongs to the current length
	repeat
		cur = 2 * cur + bit32.band(tag, 1)
		tag = bit32.rshift(tag, 1)
		len += 1

		local count = t.table[len]
		sum += count
		cur -= count
	until cur < 0

	-- Commit the consumed bits back to the shared state
	d.tag = tag
	d.bitcount -= len

	return t.trans[sum + cur]
end
|
||||
|
||||
--- Decodes the dynamic Huffman trees for a block
|
||||
local function decodeTrees(d: Data, lengthTree: Tree, distTree: Tree)
	-- Reads a dynamic block's tree description from the stream and rebuilds
	-- `lengthTree` and `distTree` in place. Uses the module-level `lengths`
	-- scratch table and `codeTree`.
	local hlit = readBits(d, 5, 257) -- Number of literal/length codes
	local hdist = readBits(d, 5, 1) -- Number of distance codes
	local hclen = readBits(d, 4, 4) -- Number of code length codes

	-- Clear the 19 slots used by the code length alphabet
	for i = 0, 18 do
		lengths[i] = 0
	end

	-- The code length code lengths arrive in the order given by `clcIndex`
	-- (a 1-based table, hence `i + 1`)
	for i = 0, hclen - 1 do
		lengths[clcIndex[i + 1]] = readBits(d, 3, 0)
	end

	-- Build the tree used to decode the real code lengths
	buildTree(codeTree, lengths, 0, 19)

	-- Decode `hlit + hdist` code lengths into `lengths[0..]`
	local total = hlit + hdist
	local num = 0
	while num < total do
		local sym = decodeSymbol(d, codeTree)

		if sym == 16 then
			-- Copy previous code length 3-6 times
			local prev = lengths[num - 1]
			for _ = 1, readBits(d, 2, 3) do
				lengths[num] = prev
				num += 1
			end
		elseif sym == 17 then
			-- Repeat zero 3-10 times
			for _ = 1, readBits(d, 3, 3) do
				lengths[num] = 0
				num += 1
			end
		elseif sym == 18 then
			-- Repeat zero 11-138 times
			for _ = 1, readBits(d, 7, 11) do
				lengths[num] = 0
				num += 1
			end
		else
			-- Regular code length 0-15
			lengths[num] = sym
			num += 1
		end
	end

	-- The first `hlit` lengths describe the literal/length tree, the next
	-- `hdist` the distance tree
	buildTree(lengthTree, lengths, 0, hlit)
	buildTree(distTree, lengths, hlit, hdist)
end
|
||||
|
||||
--- Inflates a block of data using Huffman trees
|
||||
local function inflateBlockData(d: Data, lengthTree: Tree, distTree: Tree)
	-- Decodes symbols with the given trees and writes the resulting bytes to
	-- `d.dest` until the end-of-block symbol (256) is read.
	while true do
		local sym = decodeSymbol(d, lengthTree)

		if sym == 256 then
			-- End of block
			return
		end

		if sym < 256 then
			-- Literal byte
			buffer.writeu8(d.dest, d.destLen, sym)
			d.destLen += 1
		else
			-- Length/distance pair for copying; rebase the symbol so it
			-- indexes the 0-based length tables
			sym -= 257

			local length = readBits(d, lengthBits[sym], lengthBase[sym])
			local dist = decodeSymbol(d, distTree)

			-- Start of the back reference inside the output written so far
			local start = d.destLen - readBits(d, distBits[dist], distBase[dist])

			-- Copy byte by byte: the source range may overlap the bytes being
			-- written (distance < length), so each byte has to be written
			-- before a later one reads it
			for i = start, start + length - 1 do
				buffer.writeu8(d.dest, d.destLen, buffer.readu8(d.dest, i))
				d.destLen += 1
			end
		end
	end
end
|
||||
|
||||
--- Processes an uncompressed block
|
||||
local function inflateUncompressedBlock(d: Data)
	-- Copies one stored (uncompressed) block from `d.source` to `d.dest`.
	-- Raises "Invalid block length" when LEN and its complement disagree.

	-- Align to byte boundary: hand every whole prefetched byte back to the
	-- source. The comparison must be `>=`: when the block header ends exactly
	-- on a byte boundary the bit buffer holds a whole number of bytes, and
	-- with `>` one of them would stay "consumed", so LEN/NLEN would be read
	-- one byte too late. Any remaining (< 8) bits are padding.
	while d.bitcount >= 8 do
		d.sourceIndex -= 1
		d.bitcount -= 8
	end

	-- Read block length and its complement (both 16-bit little-endian)
	local length = buffer.readu8(d.source, d.sourceIndex + 1)
	length = 256 * length + buffer.readu8(d.source, d.sourceIndex)

	local invlength = buffer.readu8(d.source, d.sourceIndex + 3)
	invlength = 256 * invlength + buffer.readu8(d.source, d.sourceIndex + 2)

	-- Verify block length using ones complement
	if length ~= bit32.bxor(invlength, 0xffff) then
		error("Invalid block length")
	end

	d.sourceIndex += 4

	-- Copy uncompressed data to output
	for _ = 1, length do
		buffer.writeu8(d.dest, d.destLen, buffer.readu8(d.source, d.sourceIndex))
		d.destLen += 1
		d.sourceIndex += 1
	end

	-- The bit buffer is now empty; clear stale bits so a later refill that
	-- ORs into `tag` cannot pick them up
	d.tag = 0
	d.bitcount = 0
end
|
||||
|
||||
--- Main decompression function that processes DEFLATE compressed data
|
||||
local function uncompress(source: buffer, uncompressedSize: number?): buffer
	-- Inflates raw DEFLATE data from `source` and returns a buffer holding
	-- exactly the decompressed bytes.
	--   uncompressedSize: optional known size of the output; callers that
	--   omit it keep the previous sizing behavior.
	local dest = buffer.create(
		-- If the uncompressed size is known, we use that, otherwise we use a default
		-- size that is 7 times the compressed size; this factor works well for most
		-- cases other than those with a very high compression ratio
		uncompressedSize or buffer.len(source) * 7
	)
	local d = Data.new(source, dest)

	-- Process blocks until one is flagged as the last
	repeat
		local bfinal = getBit(d) -- Last block flag
		local btype = readBits(d, 2, 0) -- Block type (0=uncompressed, 1=fixed, 2=dynamic)

		if btype == 0 then
			inflateUncompressedBlock(d)
		elseif btype == 1 then
			inflateBlockData(d, staticLengthTree, staticDistTree)
		elseif btype == 2 then
			-- Dynamic blocks carry their own trees, decoded into `d` first
			decodeTrees(d, d.ltree, d.dtree)
			inflateBlockData(d, d.ltree, d.dtree)
		else
			error("Invalid block type")
		end
	until bfinal == 1

	-- Trim output buffer to actual size if needed
	if d.destLen < buffer.len(dest) then
		local result = buffer.create(d.destLen)
		buffer.copy(result, 0, dest, 0, d.destLen)
		return result
	end

	return dest
end
|
||||
|
||||
-- Initialize static trees and lookup tables for DEFLATE format
|
||||
|
|
117
lib/init.luau
117
lib/init.luau
|
@ -32,20 +32,23 @@ local function validateCrc(decompressed: buffer, validation: CrcValidationOption
|
|||
end
|
||||
end
|
||||
|
||||
-- Maps a ZIP compression method ID to its decompression routine. Every routine
-- receives the raw entry bytes, the uncompressed size, and the CRC validation
-- options, and returns the decompressed buffer.
local DECOMPRESSION_ROUTINES: { [number]: (buffer, number, CrcValidationOptions) -> buffer } =
	table.freeze({
		-- `STORE` decompression method - No compression
		[0x00] = function(buf, _, validation)
			validateCrc(buf, validation)
			return buf
		end,

		-- `DEFLATE` decompression method - Compressed raw deflate chunks
		[0x08] = function(buf, uncompressedSize, validation)
			-- FIXME: Why is uncompressedSize not getting inferred correctly although it
			-- is typed?
			local decompressed = inflate(buf, uncompressedSize :: any)
			validateCrc(decompressed, validation)
			return decompressed
		end,
	})
|
||||
|
||||
-- TODO: ERROR HANDLING!
|
||||
|
||||
|
@ -175,62 +178,62 @@ function ZipReader.parseCentralDirectory(self: ZipReader): ()
|
|||
end
|
||||
|
||||
function ZipReader.buildDirectoryTree(self: ZipReader): ()
	-- Links every entry in `self.entries` into a tree rooted at `self.root`,
	-- registering directories in `self.directories` along the way.

	-- Sort entries to process directories first; I could either handle
	-- directories and files in separate passes over the entries, or sort
	-- the entries so I handled the directories first -- I decided to do
	-- the latter
	table.sort(self.entries, function(a, b)
		if a.isDirectory ~= b.isDirectory then
			return a.isDirectory
		end
		return a.name < b.name
	end)

	for _, entry in self.entries do
		local parts = {}
		-- Split entry path into individual components
		-- e.g. "folder/subfolder/file.txt" -> {"folder", "subfolder", "file.txt"}
		for part in string.gmatch(entry.name, "([^/]+)/?") do
			table.insert(parts, part)
		end

		-- Start from root directory
		local current = self.root
		local path = ""

		-- Process each path component
		for i, part in parts do
			-- NOTE(review): components are joined without a "/" separator, so
			-- the lookup key for "a/bc" and "ab/c" is the same ("abc") --
			-- confirm against the other users of `self.directories` before
			-- changing the key format
			path ..= part

			if i < #parts or entry.isDirectory then
				-- Create missing directory entries for intermediate paths
				if not self.directories[path] then
					if entry.isDirectory and i == #parts then
						-- Existing directory entry, reuse it; it needs a parent
						-- link just like synthesized directories and files get
						entry.parent = current
						self.directories[path] = entry
					else
						-- Create new directory entry for intermediate paths or undefined
						-- parent directories in the ZIP
						local dir = ZipEntry.new(path .. "/", 0, 0, entry.timestamp, 0)
						dir.isDirectory = true
						dir.parent = current
						self.directories[path] = dir
					end

					-- Track directory in both lookup table and parent's children
					table.insert(current.children, self.directories[path])
				end

				-- Move deeper into the tree
				current = self.directories[path]
				continue
			end

			-- Link file entry to its parent directory
			entry.parent = current
			table.insert(current.children, entry)
		end
	end
end
|
||||
|
||||
function ZipReader.findEntry(self: ZipReader, path: string): ZipEntry
|
||||
|
@ -364,7 +367,7 @@ function ZipReader.extract(self: ZipReader, entry: ZipEntry, options: Extraction
|
|||
error(`Unsupported compression, ID: {compressionMethod}`)
|
||||
end
|
||||
|
||||
content = decompress(content, {
|
||||
content = decompress(content, uncompressedSize, {
|
||||
expected = crcChecksum,
|
||||
skip = optionsOrDefault.skipCrcValidation,
|
||||
})
|
||||
|
|
Loading…
Add table
Reference in a new issue