diff --git a/lib/inflate.luau b/lib/inflate.luau index 8d6345e..1c49ee0 100644 --- a/lib/inflate.luau +++ b/lib/inflate.luau @@ -3,19 +3,19 @@ local Tree = {} export type Tree = typeof(setmetatable({} :: TreeInner, { __index = Tree })) type TreeInner = { - table: { number }, -- Length of 16, stores code length counts - trans: { number }, -- Length of 288, stores code to symbol translations + table: { number }, -- Length of 16, stores code length counts + trans: { number }, -- Length of 288, stores code to symbol translations } --- Creates a new Tree instance with initialized tables function Tree.new(): Tree - return setmetatable( - { - table = table.create(16, 0), - trans = table.create(288, 0), - } :: TreeInner, - { __index = Tree } - ) + return setmetatable( + { + table = table.create(16, 0), + trans = table.create(288, 0), + } :: TreeInner, + { __index = Tree } + ) end -- Data class for managing compression state and buffers @@ -37,19 +37,19 @@ export type DataInner = { --- Creates a new Data instance with initialized compression state function Data.new(source: buffer, dest: buffer): Data - return setmetatable( - { - source = source, - sourceIndex = 0, - tag = 0, - bitcount = 0, - dest = dest, - destLen = 0, - ltree = Tree.new(), - dtree = Tree.new(), - } :: DataInner, - { __index = Data } - ) + return setmetatable( + { + source = source, + sourceIndex = 0, + tag = 0, + bitcount = 0, + dest = dest, + destLen = 0, + ltree = Tree.new(), + dtree = Tree.new(), + } :: DataInner, + { __index = Data } + ) end -- Static Huffman trees used for fixed block types @@ -76,56 +76,56 @@ local lengths = table.create(288 + 32, 0) --- Builds the extra bits and base tables for length and distance codes local function buildBitsBase(bits: { number }, base: { number }, delta: number, first: number) - local sum = first + local sum = first - -- Initialize the bits table with appropriate bit lengths - for i = 0, delta - 1 do - bits[i] = 0 - end - for i = 0, 29 - delta do - bits[i + delta] = math.floor(i / delta) - end + -- Initialize the bits table with appropriate bit lengths + for i = 0, delta - 1 do + bits[i] = 0 + end + for i = 0, 29 - delta do + bits[i + delta] = math.floor(i / delta) + end - -- Build the base value table using bit lengths - for i = 0, 29 do - base[i] = sum - sum += bit32.lshift(1, bits[i]) - end + -- Build the base value table using bit lengths + for i = 0, 29 do + base[i] = sum + sum += bit32.lshift(1, bits[i]) + end end --- Constructs the fixed Huffman trees used in DEFLATE format local function buildFixedTrees(lengthTree: Tree, distTree: Tree) - -- Build the fixed length tree according to DEFLATE specification - for i = 0, 6 do - lengthTree.table[i] = 0 - end - lengthTree.table[7] = 24 - lengthTree.table[8] = 152 - lengthTree.table[9] = 112 + -- Build the fixed length tree according to DEFLATE specification + for i = 0, 6 do + lengthTree.table[i] = 0 + end + lengthTree.table[7] = 24 + lengthTree.table[8] = 152 + lengthTree.table[9] = 112 - -- Populate the translation table for length codes - for i = 0, 23 do - lengthTree.trans[i] = 256 + i - end - for i = 0, 143 do - lengthTree.trans[24 + i] = i - end - for i = 0, 7 do - lengthTree.trans[24 + 144 + i] = 280 + i - end - for i = 0, 111 do - lengthTree.trans[24 + 144 + 8 + i] = 144 + i - end + -- Populate the translation table for length codes + for i = 0, 23 do + lengthTree.trans[i] = 256 + i + end + for i = 0, 143 do + lengthTree.trans[24 + i] = i + end + for i = 0, 7 do + lengthTree.trans[24 + 144 + i] = 280 + i + end + for i = 0, 111 do + lengthTree.trans[24 + 144 + 8 + i] = 144 + i + end - -- Build the fixed distance tree (simpler than length tree) - for i = 0, 4 do - distTree.table[i] = 0 - end - distTree.table[5] = 32 + -- Build the fixed distance tree (simpler than length tree) + for i = 0, 4 do + distTree.table[i] = 0 + end + distTree.table[5] = 32 - for i = 0, 31 do - distTree.trans[i] = i - end + for i = 0, 31 do + distTree.trans[i] = i + end end --- Temporary array for building trees @@ -133,247 +133,249 @@ local offs = table.create(16, 0) --- Builds a Huffman tree from a list of code lengths local function buildTree(t: Tree, lengths: { number }, off: number, num: number) - -- Initialize the code length count table - for i = 0, 15 do - t.table[i] = 0 - end + -- Initialize the code length count table + for i = 0, 15 do + t.table[i] = 0 + end - -- Count the frequency of each code length - for i = 0, num - 1 do - t.table[lengths[off + i]] += 1 - end + -- Count the frequency of each code length + for i = 0, num - 1 do + t.table[lengths[off + i]] += 1 + end - t.table[0] = 0 + t.table[0] = 0 - -- Calculate offsets for distribution sort - local sum = 0 - for i = 0, 15 do - offs[i] = sum - sum += t.table[i] - end + -- Calculate offsets for distribution sort + local sum = 0 + for i = 0, 15 do + offs[i] = sum + sum += t.table[i] + end - -- Create the translation table mapping codes to symbols - for i = 0, num - 1 do - local len = lengths[off + i] - if len > 0 then - t.trans[offs[len]] = i - offs[len] += 1 - end - end + -- Create the translation table mapping codes to symbols + for i = 0, num - 1 do + local len = lengths[off + i] + if len > 0 then + t.trans[offs[len]] = i + offs[len] += 1 + end + end end --- Reads a single bit from the input stream local function getBit(d: Data): number - if d.bitcount <= 0 then - d.tag = buffer.readu8(d.source, d.sourceIndex) - d.sourceIndex += 1 - d.bitcount = 8 - end + if d.bitcount <= 0 then + d.tag = buffer.readu8(d.source, d.sourceIndex) + d.sourceIndex += 1 + d.bitcount = 8 + end - local bit = bit32.band(d.tag, 1) - d.tag = bit32.rshift(d.tag, 1) - d.bitcount -= 1 + local bit = bit32.band(d.tag, 1) + d.tag = bit32.rshift(d.tag, 1) + d.bitcount -= 1 - return bit + return bit end --- Reads multiple bits from the input stream with a base value local function readBits(d: Data, num: number?, base: number): number - if not num then - return base - end + if not num then + return base + end - -- Ensure we have enough bits in the buffer - while d.bitcount < 24 and d.sourceIndex < buffer.len(d.source) do - d.tag = bit32.bor(d.tag, bit32.lshift(buffer.readu8(d.source, d.sourceIndex), d.bitcount)) - d.sourceIndex += 1 - d.bitcount += 8 - end + -- Ensure we have enough bits in the buffer + while d.bitcount < 24 and d.sourceIndex < buffer.len(d.source) do + d.tag = bit32.bor(d.tag, bit32.lshift(buffer.readu8(d.source, d.sourceIndex), d.bitcount)) + d.sourceIndex += 1 + d.bitcount += 8 + end - local val = bit32.band(d.tag, bit32.rshift(0xffff, 16 - num)) - d.tag = bit32.rshift(d.tag, num) - d.bitcount -= num + local val = bit32.band(d.tag, bit32.rshift(0xffff, 16 - num)) + d.tag = bit32.rshift(d.tag, num) + d.bitcount -= num - return val + base + return val + base end --- Decodes a symbol using a Huffman tree local function decodeSymbol(d: Data, t: Tree): number - while d.bitcount < 24 and d.sourceIndex < buffer.len(d.source) do - d.tag = bit32.bor(d.tag, bit32.lshift(buffer.readu8(d.source, d.sourceIndex), d.bitcount)) - d.sourceIndex += 1 - d.bitcount += 8 - end + while d.bitcount < 24 and d.sourceIndex < buffer.len(d.source) do + d.tag = bit32.bor(d.tag, bit32.lshift(buffer.readu8(d.source, d.sourceIndex), d.bitcount)) + d.sourceIndex += 1 + d.bitcount += 8 + end - local sum, cur, len = 0, 0, 0 - local tag = d.tag + local sum, cur, len = 0, 0, 0 + local tag = d.tag - -- Traverse the Huffman tree to find the symbol - repeat - cur = 2 * cur + bit32.band(tag, 1) - tag = bit32.rshift(tag, 1) - len += 1 - sum += t.table[len] - cur -= t.table[len] - until cur < 0 + -- Traverse the Huffman tree to find the symbol + repeat + cur = 2 * cur + bit32.band(tag, 1) + tag = bit32.rshift(tag, 1) + len += 1 + sum += t.table[len] + cur -= t.table[len] + until cur < 0 - d.tag = tag - d.bitcount -= len + d.tag = tag + d.bitcount -= len - return t.trans[sum + cur] + return t.trans[sum + cur] end --- Decodes the dynamic Huffman trees for a block local function decodeTrees(d: Data, lengthTree: Tree, distTree: Tree) - local hlit = readBits(d, 5, 257) -- Number of literal/length codes - local hdist = readBits(d, 5, 1) -- Number of distance codes - local hclen = readBits(d, 4, 4) -- Number of code length codes + local hlit = readBits(d, 5, 257) -- Number of literal/length codes + local hdist = readBits(d, 5, 1) -- Number of distance codes + local hclen = readBits(d, 4, 4) -- Number of code length codes - -- Initialize code lengths array - for i = 0, 18 do - lengths[i] = 0 - end + -- Initialize code lengths array + for i = 0, 18 do + lengths[i] = 0 + end - -- Read code lengths for the code length alphabet - for i = 0, hclen - 1 do - lengths[clcIndex[i + 1]] = readBits(d, 3, 0) - end + -- Read code lengths for the code length alphabet + for i = 0, hclen - 1 do + lengths[clcIndex[i + 1]] = readBits(d, 3, 0) + end - -- Build the code lengths tree - buildTree(codeTree, lengths, 0, 19) + -- Build the code lengths tree + buildTree(codeTree, lengths, 0, 19) - -- Decode length/distance tree code lengths - local num = 0 - while num < hlit + hdist do - local sym = decodeSymbol(d, codeTree) + -- Decode length/distance tree code lengths + local num = 0 + while num < hlit + hdist do + local sym = decodeSymbol(d, codeTree) - if sym == 16 then - -- Copy previous code length 3-6 times - local prev = lengths[num - 1] - for _ = 1, readBits(d, 2, 3) do - lengths[num] = prev - num += 1 - end - elseif sym == 17 then - -- Repeat zero 3-10 times - for _ = 1, readBits(d, 3, 3) do - lengths[num] = 0 - num += 1 - end - elseif sym == 18 then - -- Repeat zero 11-138 times - for _ = 1, readBits(d, 7, 11) do - lengths[num] = 0 - num += 1 - end - else - -- Regular code length 0-15 - lengths[num] = sym - num += 1 - end - end + if sym == 16 then + -- Copy previous code length 3-6 times + local prev = lengths[num - 1] + for _ = 1, readBits(d, 2, 3) do + lengths[num] = prev + num += 1 + end + elseif sym == 17 then + -- Repeat zero 3-10 times + for _ = 1, readBits(d, 3, 3) do + lengths[num] = 0 + num += 1 + end + elseif sym == 18 then + -- Repeat zero 11-138 times + for _ = 1, readBits(d, 7, 11) do + lengths[num] = 0 + num += 1 + end + else + -- Regular code length 0-15 + lengths[num] = sym + num += 1 + end + end - -- Build the literal/length and distance trees - buildTree(lengthTree, lengths, 0, hlit) - buildTree(distTree, lengths, hlit, hdist) + -- Build the literal/length and distance trees + buildTree(lengthTree, lengths, 0, hlit) + buildTree(distTree, lengths, hlit, hdist) end --- Inflates a block of data using Huffman trees local function inflateBlockData(d: Data, lengthTree: Tree, distTree: Tree) - while true do - local sym = decodeSymbol(d, lengthTree) + while true do + local sym = decodeSymbol(d, lengthTree) - if sym == 256 then - -- End of block - return - end + if sym == 256 then + -- End of block + return + end - if sym < 256 then - -- Literal byte - buffer.writeu8(d.dest, d.destLen, sym) - d.destLen += 1 - else - -- Length/distance pair for copying - sym -= 257 + if sym < 256 then + -- Literal byte + buffer.writeu8(d.dest, d.destLen, sym) + d.destLen += 1 + else + -- Length/distance pair for copying + sym -= 257 - local length = readBits(d, lengthBits[sym], lengthBase[sym]) - local dist = decodeSymbol(d, distTree) + local length = readBits(d, lengthBits[sym], lengthBase[sym]) + local dist = decodeSymbol(d, distTree) - local offs = d.destLen - readBits(d, distBits[dist], distBase[dist]) + local offs = d.destLen - readBits(d, distBits[dist], distBase[dist]) - -- Copy bytes from back reference - for i = offs, offs + length - 1 do - buffer.writeu8(d.dest, d.destLen, buffer.readu8(d.dest, i)) - d.destLen += 1 - end - end - end + -- Copy bytes from back reference + for i = offs, offs + length - 1 do + buffer.writeu8(d.dest, d.destLen, buffer.readu8(d.dest, i)) + d.destLen += 1 + end + end + end end --- Processes an uncompressed block local function inflateUncompressedBlock(d: Data) - -- Align to byte boundary - while d.bitcount > 8 do - d.sourceIndex -= 1 - d.bitcount -= 8 - end + -- Align to byte boundary + while d.bitcount > 8 do + d.sourceIndex -= 1 + d.bitcount -= 8 + end - -- Read block length and its complement - local length = buffer.readu8(d.source, d.sourceIndex + 1) - length = 256 * length + buffer.readu8(d.source, d.sourceIndex) + -- Read block length and its complement + local length = buffer.readu8(d.source, d.sourceIndex + 1) + length = 256 * length + buffer.readu8(d.source, d.sourceIndex) - local invlength = buffer.readu8(d.source, d.sourceIndex + 3) - invlength = 256 * invlength + buffer.readu8(d.source, d.sourceIndex + 2) + local invlength = buffer.readu8(d.source, d.sourceIndex + 3) + invlength = 256 * invlength + buffer.readu8(d.source, d.sourceIndex + 2) - -- Verify block length using ones complement - if length ~= bit32.bxor(invlength, 0xffff) then - error("Invalid block length") - end + -- Verify block length using ones complement + if length ~= bit32.bxor(invlength, 0xffff) then + error("Invalid block length") + end - d.sourceIndex += 4 + d.sourceIndex += 4 - -- Copy uncompressed data to output - for _ = 1, length do - buffer.writeu8(d.dest, d.destLen, buffer.readu8(d.source, d.sourceIndex)) - d.destLen += 1 - d.sourceIndex += 1 - end + -- Copy uncompressed data to output + for _ = 1, length do + buffer.writeu8(d.dest, d.destLen, buffer.readu8(d.source, d.sourceIndex)) + d.destLen += 1 + d.sourceIndex += 1 + end - d.bitcount = 0 + d.bitcount = 0 end --- Main decompression function that processes DEFLATE compressed data -local function uncompress(source: buffer): buffer - -- FIXME: This is a temporary solution to avoid a buffer overflow - -- We likely want some type of reflection with the zip metadata to - -- have a definitive buffer size - local dest = buffer.create(buffer.len(source) * 7) - local d = Data.new(source, dest) +local function uncompress(source: buffer, uncompressedSize: number?): buffer + local dest = buffer.create( + -- If the uncompressed size is known, we use that, otherwise we use a default + -- size that is a 7 times more than the compressed size; this factor works + -- well for most cases other than those with a very high compression ratio + uncompressedSize or buffer.len(source) * 7 + ) + local d = Data.new(source, dest) - repeat - local bfinal = getBit(d) -- Last block flag - local btype = readBits(d, 2, 0) -- Block type (0=uncompressed, 1=fixed, 2=dynamic) + repeat + local bfinal = getBit(d) -- Last block flag + local btype = readBits(d, 2, 0) -- Block type (0=uncompressed, 1=fixed, 2=dynamic) - if btype == 0 then - inflateUncompressedBlock(d) - elseif btype == 1 then - inflateBlockData(d, staticLengthTree, staticDistTree) - elseif btype == 2 then - decodeTrees(d, d.ltree, d.dtree) - inflateBlockData(d, d.ltree, d.dtree) - else - error("Invalid block type") - end - until bfinal == 1 + if btype == 0 then + inflateUncompressedBlock(d) + elseif btype == 1 then + inflateBlockData(d, staticLengthTree, staticDistTree) + elseif btype == 2 then + decodeTrees(d, d.ltree, d.dtree) + inflateBlockData(d, d.ltree, d.dtree) + else + error("Invalid block type") + end + until bfinal == 1 - -- Trim output buffer to actual size if needed - if d.destLen < buffer.len(dest) then - local result = buffer.create(d.destLen) - buffer.copy(result, 0, dest, 0, d.destLen) - return result - end + -- Trim output buffer to actual size if needed + if d.destLen < buffer.len(dest) then + local result = buffer.create(d.destLen) + buffer.copy(result, 0, dest, 0, d.destLen) + return result + end - return dest + return dest end -- Initialize static trees and lookup tables for DEFLATE format diff --git a/lib/init.luau b/lib/init.luau index e0e1cae..c70ebfc 100644 --- a/lib/init.luau +++ b/lib/init.luau @@ -32,20 +32,23 @@ local function validateCrc(decompressed: buffer, validation: CrcValidationOption end end -local DECOMPRESSION_ROUTINES: { [number]: (buffer, validation: CrcValidationOptions) -> buffer } = table.freeze({ - -- `STORE` decompression method - No compression - [0x00] = function(buf, validation) - validateCrc(buf, validation) - return buf - end, +local DECOMPRESSION_ROUTINES: { [number]: (buffer, number, CrcValidationOptions) -> buffer } = + table.freeze({ + -- `STORE` decompression method - No compression + [0x00] = function(buf, _, validation) + validateCrc(buf, validation) + return buf + end, - -- `DEFLATE` decompression method - Compressed raw deflate chunks - [0x08] = function(buf, validation) - local decompressed = inflate(buf) - validateCrc(decompressed, validation) - return decompressed - end, -}) + -- `DEFLATE` decompression method - Compressed raw deflate chunks + [0x08] = function(buf, uncompressedSize, validation) + -- FIXME: Why is uncompressedSize not getting inferred correctly although it + -- is typed? + local decompressed = inflate(buf, uncompressedSize :: any) + validateCrc(decompressed, validation) + return decompressed + end, + }) -- TODO: ERROR HANDLING! @@ -175,62 +178,62 @@ function ZipReader.parseCentralDirectory(self: ZipReader): () end function ZipReader.buildDirectoryTree(self: ZipReader): () - -- Sort entries to process directories first; I could either handle + -- Sort entries to process directories first; I could either handle -- directories and files in separate passes over the entries, or sort -- the entries so I handled the directories first -- I decided to do -- the latter - table.sort(self.entries, function(a, b) - if a.isDirectory ~= b.isDirectory then - return a.isDirectory - end - return a.name < b.name - end) + table.sort(self.entries, function(a, b) + if a.isDirectory ~= b.isDirectory then + return a.isDirectory + end + return a.name < b.name + end) - for _, entry in self.entries do - local parts = {} - -- Split entry path into individual components - -- e.g. "folder/subfolder/file.txt" -> {"folder", "subfolder", "file.txt"} - for part in string.gmatch(entry.name, "([^/]+)/?") do - table.insert(parts, part) - end + for _, entry in self.entries do + local parts = {} + -- Split entry path into individual components + -- e.g. "folder/subfolder/file.txt" -> {"folder", "subfolder", "file.txt"} + for part in string.gmatch(entry.name, "([^/]+)/?") do + table.insert(parts, part) + end - -- Start from root directory - local current = self.root - local path = "" + -- Start from root directory + local current = self.root + local path = "" - -- Process each path component - for i, part in parts do - path ..= part + -- Process each path component + for i, part in parts do + path ..= part - if i < #parts or entry.isDirectory then - -- Create missing directory entries for intermediate paths - if not self.directories[path] then - if entry.isDirectory and i == #parts then - -- Existing directory entry, reuse it - self.directories[path] = entry + if i < #parts or entry.isDirectory then + -- Create missing directory entries for intermediate paths + if not self.directories[path] then + if entry.isDirectory and i == #parts then + -- Existing directory entry, reuse it + self.directories[path] = entry else -- Create new directory entry for intermediate paths or undefined -- parent directories in the ZIP - local dir = ZipEntry.new(path .. "/", 0, 0, entry.timestamp, 0) - dir.isDirectory = true - dir.parent = current - self.directories[path] = dir - end + local dir = ZipEntry.new(path .. "/", 0, 0, entry.timestamp, 0) + dir.isDirectory = true + dir.parent = current + self.directories[path] = dir + end - -- Track directory in both lookup table and parent's children - table.insert(current.children, self.directories[path]) - end + -- Track directory in both lookup table and parent's children + table.insert(current.children, self.directories[path]) + end - -- Move deeper into the tree - current = self.directories[path] - continue - end + -- Move deeper into the tree + current = self.directories[path] + continue + end - -- Link file entry to its parent directory - entry.parent = current - table.insert(current.children, entry) - end - end + -- Link file entry to its parent directory + entry.parent = current + table.insert(current.children, entry) + end + end end function ZipReader.findEntry(self: ZipReader, path: string): ZipEntry @@ -364,7 +367,7 @@ function ZipReader.extract(self: ZipReader, entry: ZipEntry, options: Extraction error(`Unsupported compression, ID: {compressionMethod}`) end - content = decompress(content, { + content = decompress(content, uncompressedSize, { expected = crcChecksum, skip = optionsOrDefault.skipCrcValidation, })