diff --git a/lib/inflate.luau b/lib/inflate.luau index ae77452..d1f5d04 100644 --- a/lib/inflate.luau +++ b/lib/inflate.luau @@ -1,11 +1,13 @@ +-- Tree class for storing Huffman trees used in DEFLATE decompression local Tree = {} export type Tree = typeof(setmetatable({} :: TreeInner, { __index = Tree })) type TreeInner = { - table: { number }, -- len: 16 - trans: { number }, -- len: 288 (🏳️‍⚧️❓) + table: { number }, -- Length of 16, stores code length counts + trans: { number }, -- Length of 288, stores code to symbol translations } +--- Creates a new Tree instance with initialized tables function Tree.new(): Tree return setmetatable( { @@ -16,21 +18,24 @@ function Tree.new(): Tree ) end +-- Data class for managing compression state and buffers local Data = {} export type Data = typeof(setmetatable({} :: DataInner, { __index = Data })) +-- stylua: ignore export type DataInner = { - source: buffer, - sourceIndex: number, - tag: number, - bitcount: number, + source: buffer, -- Input buffer containing compressed data + sourceIndex: number, -- Current position in source buffer + tag: number, -- Bit buffer for reading compressed data + bitcount: number, -- Number of valid bits in tag - dest: buffer, - destLen: number, + dest: buffer, -- Output buffer for decompressed data + destLen: number, -- Current length of decompressed data - ltree: Tree, - dtree: Tree, + ltree: Tree, -- Length/literal tree for current block + dtree: Tree, -- Distance tree for current block } +--- Creates a new Data instance with initialized compression state function Data.new(source: buffer, dest: buffer): Data return setmetatable( { @@ -47,30 +52,33 @@ function Data.new(source: buffer, dest: buffer): Data ) end --- Static structures +-- Static Huffman trees used for fixed block types local staticLengthTree = Tree.new() local staticDistTree = Tree.new() --- Extra bits and base tables +-- Tables for storing extra bits and base values for length/distance codes local lengthBits = table.create(30, 0) local lengthBase = table.create(30, 0) local distBits = table.create(30, 0) local distBase = table.create(30, 0) --- Special ordering of code length codes +-- Special ordering of code length codes used in dynamic Huffman trees +-- stylua: ignore local clcIndex = { 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15 } +-- Tree used for decoding code lengths in dynamic blocks local codeTree = Tree.new() local lengths = table.create(288 + 32, 0) +--- Builds the extra bits and base tables for length and distance codes local function buildBitsBase(bits: { number }, base: { number }, delta: number, first: number) local sum = first - -- build bits table + -- Initialize the bits table with appropriate bit lengths for i = 0, delta - 1 do bits[i] = 0 end @@ -78,15 +86,16 @@ local function buildBitsBase(bits: { number }, base: { number }, delta: number, bits[i + delta] = math.floor(i / delta) end - -- build base table + -- Build the base value table using bit lengths for i = 0, 29 do base[i] = sum sum += bit32.lshift(1, bits[i]) end end +--- Constructs the fixed Huffman trees used in DEFLATE format local function buildFixedTrees(lengthTree: Tree, distTree: Tree) - -- build fixed length tree + -- Build the fixed length tree according to DEFLATE specification for i = 0, 6 do lengthTree.table[i] = 0 end @@ -94,6 +103,7 @@ local function buildFixedTrees(lengthTree: Tree, distTree: Tree) lengthTree.table[8] = 152 lengthTree.table[9] = 112 + -- Populate the translation table for length codes for i = 0, 23 do lengthTree.trans[i] = 256 + i end @@ -107,7 +117,7 @@ local function buildFixedTrees(lengthTree: Tree, distTree: Tree) lengthTree.trans[24 + 144 + 8 + i] = 144 + i end - -- build fixed distance tree + -- Build the fixed distance tree (simpler than length tree) for i = 0, 4 do distTree.table[i] = 0 end @@ -118,29 +128,31 @@ local function buildFixedTrees(lengthTree: Tree, distTree: Tree) end end +--- Temporary array for building trees local offs = table.create(16, 0) +--- Builds a Huffman tree from a list of code lengths local function buildTree(t: Tree, lengths: { number }, off: number, num: number) - -- clear code length count table + -- Initialize the code length count table for i = 0, 15 do t.table[i] = 0 end - -- scan symbol lengths, and sum code length counts + -- Count the frequency of each code length for i = 0, num - 1 do t.table[lengths[off + i]] += 1 end t.table[0] = 0 - -- compute offset table for distribution sort + -- Calculate offsets for distribution sort local sum = 0 for i = 0, 15 do offs[i] = sum sum += t.table[i] end - -- create code->symbol translation table + -- Create the translation table mapping codes to symbols for i = 0, num - 1 do local len = lengths[off + i] if len > 0 then @@ -150,6 +162,7 @@ local function buildTree(t: Tree, lengths: { number }, off: number, num: number) end end +--- Reads a single bit from the input stream local function getBit(d: Data): number if d.bitcount <= 0 then d.tag = buffer.readu8(d.source, d.sourceIndex) @@ -164,11 +177,13 @@ local function getBit(d: Data): number return bit end +--- Reads multiple bits from the input stream with a base value local function readBits(d: Data, num: number?, base: number): number if not num then return base end + -- Ensure we have enough bits in the buffer while d.bitcount < 24 and d.sourceIndex < buffer.len(d.source) do d.tag = bit32.bor(d.tag, bit32.lshift(buffer.readu8(d.source, d.sourceIndex), d.bitcount)) d.sourceIndex += 1 @@ -182,6 +197,7 @@ local function readBits(d: Data, num: number?, base: number): number return val + base end +--- Decodes a symbol using a Huffman tree local function decodeSymbol(d: Data, t: Tree): number while d.bitcount < 24 and d.sourceIndex < buffer.len(d.source) do d.tag = bit32.bor(d.tag, bit32.lshift(buffer.readu8(d.source, d.sourceIndex), d.bitcount)) @@ -192,6 +208,7 @@ local function decodeSymbol(d: Data, t: Tree): number local sum, cur, len = 0, 0, 0 local tag = d.tag + -- Traverse the Huffman tree to find the symbol repeat cur = 2 * cur + bit32.band(tag, 1) tag = bit32.rshift(tag, 1) @@ -206,63 +223,77 @@ local function decodeSymbol(d: Data, t: Tree): number return t.trans[sum + cur] end +--- Decodes the dynamic Huffman trees for a block local function decodeTrees(d: Data, lengthTree: Tree, distTree: Tree) - local hlit = readBits(d, 5, 257) - local hdist = readBits(d, 5, 1) - local hclen = readBits(d, 4, 4) + local hlit = readBits(d, 5, 257) -- Number of literal/length codes + local hdist = readBits(d, 5, 1) -- Number of distance codes + local hclen = readBits(d, 4, 4) -- Number of code length codes + -- Initialize code lengths array for i = 0, 18 do lengths[i] = 0 end + -- Read code lengths for the code length alphabet for i = 0, hclen - 1 do lengths[clcIndex[i + 1]] = readBits(d, 3, 0) end + -- Build the code lengths tree buildTree(codeTree, lengths, 0, 19) + -- Decode length/distance tree code lengths local num = 0 while num < hlit + hdist do local sym = decodeSymbol(d, codeTree) if sym == 16 then + -- Copy previous code length 3-6 times local prev = lengths[num - 1] for _ = 1, readBits(d, 2, 3) do lengths[num] = prev num += 1 end elseif sym == 17 then + -- Repeat zero 3-10 times for _ = 1, readBits(d, 3, 3) do lengths[num] = 0 num += 1 end elseif sym == 18 then + -- Repeat zero 11-138 times for _ = 1, readBits(d, 7, 11) do lengths[num] = 0 num += 1 end else + -- Regular code length 0-15 lengths[num] = sym num += 1 end end + -- Build the literal/length and distance trees buildTree(lengthTree, lengths, 0, hlit) buildTree(distTree, lengths, hlit, hdist) end +--- Inflates a block of data using Huffman trees local function inflateBlockData(d: Data, lengthTree: Tree, distTree: Tree) while true do local sym = decodeSymbol(d, lengthTree) if sym == 256 then + -- End of block return end if sym < 256 then + -- Literal byte buffer.writeu8(d.dest, d.destLen, sym) d.destLen += 1 else + -- Length/distance pair for copying sym -= 257 local length = readBits(d, lengthBits[sym], lengthBase[sym]) @@ -270,6 +301,7 @@ local function inflateBlockData(d: Data, lengthTree: Tree, distTree: Tree) local offs = d.destLen - readBits(d, distBits[dist], distBase[dist]) + -- Copy bytes from back reference for i = offs, offs + length - 1 do buffer.writeu8(d.dest, d.destLen, buffer.readu8(d.dest, i)) d.destLen += 1 @@ -278,24 +310,29 @@ local function inflateBlockData(d: Data, lengthTree: Tree, distTree: Tree) end end +--- Processes an uncompressed block local function inflateUncompressedBlock(d: Data) + -- Align to byte boundary while d.bitcount > 8 do d.sourceIndex -= 1 d.bitcount -= 8 end + -- Read block length and its complement local length = buffer.readu8(d.source, d.sourceIndex + 1) length = 256 * length + buffer.readu8(d.source, d.sourceIndex) local invlength = buffer.readu8(d.source, d.sourceIndex + 3) invlength = 256 * invlength + buffer.readu8(d.source, d.sourceIndex + 2) + -- Verify block length using ones complement if length ~= bit32.bxor(invlength, 0xffff) then error("Invalid block length") end d.sourceIndex += 4 + -- Copy uncompressed data to output for _ = 1, length do buffer.writeu8(d.dest, d.destLen, buffer.readu8(d.source, d.sourceIndex)) d.destLen += 1 @@ -305,13 +342,14 @@ local function inflateUncompressedBlock(d: Data) d.bitcount = 0 end +--- Main decompression function that processes DEFLATE compressed data local function uncompress(source: buffer): buffer local dest = buffer.create(buffer.len(source) * 4) local d = Data.new(source, dest) repeat - local bfinal = getBit(d) - local btype = readBits(d, 2, 0) + local bfinal = getBit(d) -- Last block flag + local btype = readBits(d, 2, 0) -- Block type (0=uncompressed, 1=fixed, 2=dynamic) if btype == 0 then inflateUncompressedBlock(d) @@ -325,6 +363,7 @@ local function uncompress(source: buffer): buffer end until bfinal == 1 + -- Trim output buffer to actual size if needed if d.destLen < buffer.len(dest) then local result = buffer.create(d.destLen) buffer.copy(result, 0, dest, 0, d.destLen) @@ -334,7 +373,7 @@ local function uncompress(source: buffer): buffer return dest end --- Initialize static trees and tables +-- Initialize static trees and lookup tables for DEFLATE format buildFixedTrees(staticLengthTree, staticDistTree) buildBitsBase(lengthBits, lengthBase, 4, 3) buildBitsBase(distBits, distBase, 2, 1)