mirror of
https://github.com/0x5eal/luau-unzip.git
synced 2025-04-10 17:20:53 +01:00
refactor: add more comments to inflate implementation
This commit is contained in:
parent
67139503df
commit
b1818de2f2
1 changed files with 66 additions and 27 deletions
|
@ -1,11 +1,13 @@
|
||||||
|
-- Tree class for storing Huffman trees used in DEFLATE decompression
|
||||||
local Tree = {}
|
local Tree = {}
|
||||||
|
|
||||||
export type Tree = typeof(setmetatable({} :: TreeInner, { __index = Tree }))
|
export type Tree = typeof(setmetatable({} :: TreeInner, { __index = Tree }))
|
||||||
type TreeInner = {
|
type TreeInner = {
|
||||||
table: { number }, -- len: 16
|
table: { number }, -- Length of 16, stores code length counts
|
||||||
trans: { number }, -- len: 288 (🏳️⚧️❓)
|
trans: { number }, -- Length of 288, stores code to symbol translations
|
||||||
}
|
}
|
||||||
|
|
||||||
|
--- Creates a new Tree instance with initialized tables
|
||||||
function Tree.new(): Tree
|
function Tree.new(): Tree
|
||||||
return setmetatable(
|
return setmetatable(
|
||||||
{
|
{
|
||||||
|
@ -16,21 +18,24 @@ function Tree.new(): Tree
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
-- Data class for managing decompression state and buffers
|
||||||
local Data = {}
|
local Data = {}
|
||||||
export type Data = typeof(setmetatable({} :: DataInner, { __index = Data }))
|
export type Data = typeof(setmetatable({} :: DataInner, { __index = Data }))
|
||||||
|
-- stylua: ignore
|
||||||
export type DataInner = {
|
export type DataInner = {
|
||||||
source: buffer,
|
source: buffer, -- Input buffer containing compressed data
|
||||||
sourceIndex: number,
|
sourceIndex: number, -- Current position in source buffer
|
||||||
tag: number,
|
tag: number, -- Bit buffer for reading compressed data
|
||||||
bitcount: number,
|
bitcount: number, -- Number of valid bits in tag
|
||||||
|
|
||||||
dest: buffer,
|
dest: buffer, -- Output buffer for decompressed data
|
||||||
destLen: number,
|
destLen: number, -- Current length of decompressed data
|
||||||
|
|
||||||
ltree: Tree,
|
ltree: Tree, -- Literal/length tree for current block
|
||||||
dtree: Tree,
|
dtree: Tree, -- Distance tree for current block
|
||||||
}
|
}
|
||||||
|
|
||||||
|
--- Creates a new Data instance with initialized decompression state
|
||||||
function Data.new(source: buffer, dest: buffer): Data
|
function Data.new(source: buffer, dest: buffer): Data
|
||||||
return setmetatable(
|
return setmetatable(
|
||||||
{
|
{
|
||||||
|
@ -47,30 +52,33 @@ function Data.new(source: buffer, dest: buffer): Data
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Static structures
|
-- Static Huffman trees used for fixed block types
|
||||||
local staticLengthTree = Tree.new()
|
local staticLengthTree = Tree.new()
|
||||||
local staticDistTree = Tree.new()
|
local staticDistTree = Tree.new()
|
||||||
|
|
||||||
-- Extra bits and base tables
|
-- Tables for storing extra bits and base values for length/distance codes
|
||||||
local lengthBits = table.create(30, 0)
|
local lengthBits = table.create(30, 0)
|
||||||
local lengthBase = table.create(30, 0)
|
local lengthBase = table.create(30, 0)
|
||||||
local distBits = table.create(30, 0)
|
local distBits = table.create(30, 0)
|
||||||
local distBase = table.create(30, 0)
|
local distBase = table.create(30, 0)
|
||||||
|
|
||||||
-- Special ordering of code length codes
|
-- Special ordering of code length codes used in dynamic Huffman trees
|
||||||
|
-- stylua: ignore
|
||||||
local clcIndex = {
|
local clcIndex = {
|
||||||
16, 17, 18, 0, 8, 7, 9, 6,
|
16, 17, 18, 0, 8, 7, 9, 6,
|
||||||
10, 5, 11, 4, 12, 3, 13, 2,
|
10, 5, 11, 4, 12, 3, 13, 2,
|
||||||
14, 1, 15
|
14, 1, 15
|
||||||
}
|
}
|
||||||
|
|
||||||
|
-- Tree used for decoding code lengths in dynamic blocks
|
||||||
local codeTree = Tree.new()
|
local codeTree = Tree.new()
|
||||||
local lengths = table.create(288 + 32, 0)
|
local lengths = table.create(288 + 32, 0)
|
||||||
|
|
||||||
|
--- Builds the extra bits and base tables for length and distance codes
|
||||||
local function buildBitsBase(bits: { number }, base: { number }, delta: number, first: number)
|
local function buildBitsBase(bits: { number }, base: { number }, delta: number, first: number)
|
||||||
local sum = first
|
local sum = first
|
||||||
|
|
||||||
-- build bits table
|
-- Initialize the bits table with appropriate bit lengths
|
||||||
for i = 0, delta - 1 do
|
for i = 0, delta - 1 do
|
||||||
bits[i] = 0
|
bits[i] = 0
|
||||||
end
|
end
|
||||||
|
@ -78,15 +86,16 @@ local function buildBitsBase(bits: { number }, base: { number }, delta: number,
|
||||||
bits[i + delta] = math.floor(i / delta)
|
bits[i + delta] = math.floor(i / delta)
|
||||||
end
|
end
|
||||||
|
|
||||||
-- build base table
|
-- Build the base value table using bit lengths
|
||||||
for i = 0, 29 do
|
for i = 0, 29 do
|
||||||
base[i] = sum
|
base[i] = sum
|
||||||
sum += bit32.lshift(1, bits[i])
|
sum += bit32.lshift(1, bits[i])
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
--- Constructs the fixed Huffman trees used in DEFLATE format
|
||||||
local function buildFixedTrees(lengthTree: Tree, distTree: Tree)
|
local function buildFixedTrees(lengthTree: Tree, distTree: Tree)
|
||||||
-- build fixed length tree
|
-- Build the fixed length tree according to DEFLATE specification
|
||||||
for i = 0, 6 do
|
for i = 0, 6 do
|
||||||
lengthTree.table[i] = 0
|
lengthTree.table[i] = 0
|
||||||
end
|
end
|
||||||
|
@ -94,6 +103,7 @@ local function buildFixedTrees(lengthTree: Tree, distTree: Tree)
|
||||||
lengthTree.table[8] = 152
|
lengthTree.table[8] = 152
|
||||||
lengthTree.table[9] = 112
|
lengthTree.table[9] = 112
|
||||||
|
|
||||||
|
-- Populate the translation table for length codes
|
||||||
for i = 0, 23 do
|
for i = 0, 23 do
|
||||||
lengthTree.trans[i] = 256 + i
|
lengthTree.trans[i] = 256 + i
|
||||||
end
|
end
|
||||||
|
@ -107,7 +117,7 @@ local function buildFixedTrees(lengthTree: Tree, distTree: Tree)
|
||||||
lengthTree.trans[24 + 144 + 8 + i] = 144 + i
|
lengthTree.trans[24 + 144 + 8 + i] = 144 + i
|
||||||
end
|
end
|
||||||
|
|
||||||
-- build fixed distance tree
|
-- Build the fixed distance tree (simpler than length tree)
|
||||||
for i = 0, 4 do
|
for i = 0, 4 do
|
||||||
distTree.table[i] = 0
|
distTree.table[i] = 0
|
||||||
end
|
end
|
||||||
|
@ -118,29 +128,31 @@ local function buildFixedTrees(lengthTree: Tree, distTree: Tree)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
--- Temporary array for building trees
|
||||||
local offs = table.create(16, 0)
|
local offs = table.create(16, 0)
|
||||||
|
|
||||||
|
--- Builds a Huffman tree from a list of code lengths
|
||||||
local function buildTree(t: Tree, lengths: { number }, off: number, num: number)
|
local function buildTree(t: Tree, lengths: { number }, off: number, num: number)
|
||||||
-- clear code length count table
|
-- Initialize the code length count table
|
||||||
for i = 0, 15 do
|
for i = 0, 15 do
|
||||||
t.table[i] = 0
|
t.table[i] = 0
|
||||||
end
|
end
|
||||||
|
|
||||||
-- scan symbol lengths, and sum code length counts
|
-- Count the frequency of each code length
|
||||||
for i = 0, num - 1 do
|
for i = 0, num - 1 do
|
||||||
t.table[lengths[off + i]] += 1
|
t.table[lengths[off + i]] += 1
|
||||||
end
|
end
|
||||||
|
|
||||||
t.table[0] = 0
|
t.table[0] = 0
|
||||||
|
|
||||||
-- compute offset table for distribution sort
|
-- Calculate offsets for distribution sort
|
||||||
local sum = 0
|
local sum = 0
|
||||||
for i = 0, 15 do
|
for i = 0, 15 do
|
||||||
offs[i] = sum
|
offs[i] = sum
|
||||||
sum += t.table[i]
|
sum += t.table[i]
|
||||||
end
|
end
|
||||||
|
|
||||||
-- create code->symbol translation table
|
-- Create the translation table mapping codes to symbols
|
||||||
for i = 0, num - 1 do
|
for i = 0, num - 1 do
|
||||||
local len = lengths[off + i]
|
local len = lengths[off + i]
|
||||||
if len > 0 then
|
if len > 0 then
|
||||||
|
@ -150,6 +162,7 @@ local function buildTree(t: Tree, lengths: { number }, off: number, num: number)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
--- Reads a single bit from the input stream
|
||||||
local function getBit(d: Data): number
|
local function getBit(d: Data): number
|
||||||
if d.bitcount <= 0 then
|
if d.bitcount <= 0 then
|
||||||
d.tag = buffer.readu8(d.source, d.sourceIndex)
|
d.tag = buffer.readu8(d.source, d.sourceIndex)
|
||||||
|
@ -164,11 +177,13 @@ local function getBit(d: Data): number
|
||||||
return bit
|
return bit
|
||||||
end
|
end
|
||||||
|
|
||||||
|
--- Reads `num` bits from the input stream and returns their value plus `base` (returns `base` when `num` is nil)
|
||||||
local function readBits(d: Data, num: number?, base: number): number
|
local function readBits(d: Data, num: number?, base: number): number
|
||||||
if not num then
|
if not num then
|
||||||
return base
|
return base
|
||||||
end
|
end
|
||||||
|
|
||||||
|
-- Ensure we have enough bits in the buffer
|
||||||
while d.bitcount < 24 and d.sourceIndex < buffer.len(d.source) do
|
while d.bitcount < 24 and d.sourceIndex < buffer.len(d.source) do
|
||||||
d.tag = bit32.bor(d.tag, bit32.lshift(buffer.readu8(d.source, d.sourceIndex), d.bitcount))
|
d.tag = bit32.bor(d.tag, bit32.lshift(buffer.readu8(d.source, d.sourceIndex), d.bitcount))
|
||||||
d.sourceIndex += 1
|
d.sourceIndex += 1
|
||||||
|
@ -182,6 +197,7 @@ local function readBits(d: Data, num: number?, base: number): number
|
||||||
return val + base
|
return val + base
|
||||||
end
|
end
|
||||||
|
|
||||||
|
--- Decodes a symbol using a Huffman tree
|
||||||
local function decodeSymbol(d: Data, t: Tree): number
|
local function decodeSymbol(d: Data, t: Tree): number
|
||||||
while d.bitcount < 24 and d.sourceIndex < buffer.len(d.source) do
|
while d.bitcount < 24 and d.sourceIndex < buffer.len(d.source) do
|
||||||
d.tag = bit32.bor(d.tag, bit32.lshift(buffer.readu8(d.source, d.sourceIndex), d.bitcount))
|
d.tag = bit32.bor(d.tag, bit32.lshift(buffer.readu8(d.source, d.sourceIndex), d.bitcount))
|
||||||
|
@ -192,6 +208,7 @@ local function decodeSymbol(d: Data, t: Tree): number
|
||||||
local sum, cur, len = 0, 0, 0
|
local sum, cur, len = 0, 0, 0
|
||||||
local tag = d.tag
|
local tag = d.tag
|
||||||
|
|
||||||
|
-- Traverse the Huffman tree to find the symbol
|
||||||
repeat
|
repeat
|
||||||
cur = 2 * cur + bit32.band(tag, 1)
|
cur = 2 * cur + bit32.band(tag, 1)
|
||||||
tag = bit32.rshift(tag, 1)
|
tag = bit32.rshift(tag, 1)
|
||||||
|
@ -206,63 +223,77 @@ local function decodeSymbol(d: Data, t: Tree): number
|
||||||
return t.trans[sum + cur]
|
return t.trans[sum + cur]
|
||||||
end
|
end
|
||||||
|
|
||||||
|
--- Decodes the dynamic Huffman trees for a block
|
||||||
local function decodeTrees(d: Data, lengthTree: Tree, distTree: Tree)
|
local function decodeTrees(d: Data, lengthTree: Tree, distTree: Tree)
|
||||||
local hlit = readBits(d, 5, 257)
|
local hlit = readBits(d, 5, 257) -- Number of literal/length codes
|
||||||
local hdist = readBits(d, 5, 1)
|
local hdist = readBits(d, 5, 1) -- Number of distance codes
|
||||||
local hclen = readBits(d, 4, 4)
|
local hclen = readBits(d, 4, 4) -- Number of code length codes
|
||||||
|
|
||||||
|
-- Initialize code lengths array
|
||||||
for i = 0, 18 do
|
for i = 0, 18 do
|
||||||
lengths[i] = 0
|
lengths[i] = 0
|
||||||
end
|
end
|
||||||
|
|
||||||
|
-- Read code lengths for the code length alphabet
|
||||||
for i = 0, hclen - 1 do
|
for i = 0, hclen - 1 do
|
||||||
lengths[clcIndex[i + 1]] = readBits(d, 3, 0)
|
lengths[clcIndex[i + 1]] = readBits(d, 3, 0)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
-- Build the code lengths tree
|
||||||
buildTree(codeTree, lengths, 0, 19)
|
buildTree(codeTree, lengths, 0, 19)
|
||||||
|
|
||||||
|
-- Decode length/distance tree code lengths
|
||||||
local num = 0
|
local num = 0
|
||||||
while num < hlit + hdist do
|
while num < hlit + hdist do
|
||||||
local sym = decodeSymbol(d, codeTree)
|
local sym = decodeSymbol(d, codeTree)
|
||||||
|
|
||||||
if sym == 16 then
|
if sym == 16 then
|
||||||
|
-- Copy previous code length 3-6 times
|
||||||
local prev = lengths[num - 1]
|
local prev = lengths[num - 1]
|
||||||
for _ = 1, readBits(d, 2, 3) do
|
for _ = 1, readBits(d, 2, 3) do
|
||||||
lengths[num] = prev
|
lengths[num] = prev
|
||||||
num += 1
|
num += 1
|
||||||
end
|
end
|
||||||
elseif sym == 17 then
|
elseif sym == 17 then
|
||||||
|
-- Repeat zero 3-10 times
|
||||||
for _ = 1, readBits(d, 3, 3) do
|
for _ = 1, readBits(d, 3, 3) do
|
||||||
lengths[num] = 0
|
lengths[num] = 0
|
||||||
num += 1
|
num += 1
|
||||||
end
|
end
|
||||||
elseif sym == 18 then
|
elseif sym == 18 then
|
||||||
|
-- Repeat zero 11-138 times
|
||||||
for _ = 1, readBits(d, 7, 11) do
|
for _ = 1, readBits(d, 7, 11) do
|
||||||
lengths[num] = 0
|
lengths[num] = 0
|
||||||
num += 1
|
num += 1
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
|
-- Regular code length 0-15
|
||||||
lengths[num] = sym
|
lengths[num] = sym
|
||||||
num += 1
|
num += 1
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
-- Build the literal/length and distance trees
|
||||||
buildTree(lengthTree, lengths, 0, hlit)
|
buildTree(lengthTree, lengths, 0, hlit)
|
||||||
buildTree(distTree, lengths, hlit, hdist)
|
buildTree(distTree, lengths, hlit, hdist)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
--- Inflates a block of data using Huffman trees
|
||||||
local function inflateBlockData(d: Data, lengthTree: Tree, distTree: Tree)
|
local function inflateBlockData(d: Data, lengthTree: Tree, distTree: Tree)
|
||||||
while true do
|
while true do
|
||||||
local sym = decodeSymbol(d, lengthTree)
|
local sym = decodeSymbol(d, lengthTree)
|
||||||
|
|
||||||
if sym == 256 then
|
if sym == 256 then
|
||||||
|
-- End of block
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
|
|
||||||
if sym < 256 then
|
if sym < 256 then
|
||||||
|
-- Literal byte
|
||||||
buffer.writeu8(d.dest, d.destLen, sym)
|
buffer.writeu8(d.dest, d.destLen, sym)
|
||||||
d.destLen += 1
|
d.destLen += 1
|
||||||
else
|
else
|
||||||
|
-- Length/distance pair for copying
|
||||||
sym -= 257
|
sym -= 257
|
||||||
|
|
||||||
local length = readBits(d, lengthBits[sym], lengthBase[sym])
|
local length = readBits(d, lengthBits[sym], lengthBase[sym])
|
||||||
|
@ -270,6 +301,7 @@ local function inflateBlockData(d: Data, lengthTree: Tree, distTree: Tree)
|
||||||
|
|
||||||
local offs = d.destLen - readBits(d, distBits[dist], distBase[dist])
|
local offs = d.destLen - readBits(d, distBits[dist], distBase[dist])
|
||||||
|
|
||||||
|
-- Copy bytes from back reference
|
||||||
for i = offs, offs + length - 1 do
|
for i = offs, offs + length - 1 do
|
||||||
buffer.writeu8(d.dest, d.destLen, buffer.readu8(d.dest, i))
|
buffer.writeu8(d.dest, d.destLen, buffer.readu8(d.dest, i))
|
||||||
d.destLen += 1
|
d.destLen += 1
|
||||||
|
@ -278,24 +310,29 @@ local function inflateBlockData(d: Data, lengthTree: Tree, distTree: Tree)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
--- Processes an uncompressed block
|
||||||
local function inflateUncompressedBlock(d: Data)
|
local function inflateUncompressedBlock(d: Data)
|
||||||
|
-- Align to byte boundary: rewind sourceIndex over whole bytes still buffered in tag
|
||||||
while d.bitcount > 8 do
|
while d.bitcount > 8 do
|
||||||
d.sourceIndex -= 1
|
d.sourceIndex -= 1
|
||||||
d.bitcount -= 8
|
d.bitcount -= 8
|
||||||
end
|
end
|
||||||
|
|
||||||
|
-- Read block length and its complement
|
||||||
local length = buffer.readu8(d.source, d.sourceIndex + 1)
|
local length = buffer.readu8(d.source, d.sourceIndex + 1)
|
||||||
length = 256 * length + buffer.readu8(d.source, d.sourceIndex)
|
length = 256 * length + buffer.readu8(d.source, d.sourceIndex)
|
||||||
|
|
||||||
local invlength = buffer.readu8(d.source, d.sourceIndex + 3)
|
local invlength = buffer.readu8(d.source, d.sourceIndex + 3)
|
||||||
invlength = 256 * invlength + buffer.readu8(d.source, d.sourceIndex + 2)
|
invlength = 256 * invlength + buffer.readu8(d.source, d.sourceIndex + 2)
|
||||||
|
|
||||||
|
-- Verify block length using ones complement
|
||||||
if length ~= bit32.bxor(invlength, 0xffff) then
|
if length ~= bit32.bxor(invlength, 0xffff) then
|
||||||
error("Invalid block length")
|
error("Invalid block length")
|
||||||
end
|
end
|
||||||
|
|
||||||
d.sourceIndex += 4
|
d.sourceIndex += 4
|
||||||
|
|
||||||
|
-- Copy uncompressed data to output
|
||||||
for _ = 1, length do
|
for _ = 1, length do
|
||||||
buffer.writeu8(d.dest, d.destLen, buffer.readu8(d.source, d.sourceIndex))
|
buffer.writeu8(d.dest, d.destLen, buffer.readu8(d.source, d.sourceIndex))
|
||||||
d.destLen += 1
|
d.destLen += 1
|
||||||
|
@ -305,13 +342,14 @@ local function inflateUncompressedBlock(d: Data)
|
||||||
d.bitcount = 0
|
d.bitcount = 0
|
||||||
end
|
end
|
||||||
|
|
||||||
|
--- Main decompression function that processes DEFLATE compressed data
|
||||||
local function uncompress(source: buffer): buffer
|
local function uncompress(source: buffer): buffer
|
||||||
local dest = buffer.create(buffer.len(source) * 4)
|
local dest = buffer.create(buffer.len(source) * 4)
|
||||||
local d = Data.new(source, dest)
|
local d = Data.new(source, dest)
|
||||||
|
|
||||||
repeat
|
repeat
|
||||||
local bfinal = getBit(d)
|
local bfinal = getBit(d) -- Last block flag
|
||||||
local btype = readBits(d, 2, 0)
|
local btype = readBits(d, 2, 0) -- Block type (0=uncompressed, 1=fixed, 2=dynamic)
|
||||||
|
|
||||||
if btype == 0 then
|
if btype == 0 then
|
||||||
inflateUncompressedBlock(d)
|
inflateUncompressedBlock(d)
|
||||||
|
@ -325,6 +363,7 @@ local function uncompress(source: buffer): buffer
|
||||||
end
|
end
|
||||||
until bfinal == 1
|
until bfinal == 1
|
||||||
|
|
||||||
|
-- Trim output buffer to actual size if needed
|
||||||
if d.destLen < buffer.len(dest) then
|
if d.destLen < buffer.len(dest) then
|
||||||
local result = buffer.create(d.destLen)
|
local result = buffer.create(d.destLen)
|
||||||
buffer.copy(result, 0, dest, 0, d.destLen)
|
buffer.copy(result, 0, dest, 0, d.destLen)
|
||||||
|
@ -334,7 +373,7 @@ local function uncompress(source: buffer): buffer
|
||||||
return dest
|
return dest
|
||||||
end
|
end
|
||||||
|
|
||||||
-- Initialize static trees and tables
|
-- Initialize static trees and lookup tables for DEFLATE format
|
||||||
buildFixedTrees(staticLengthTree, staticDistTree)
|
buildFixedTrees(staticLengthTree, staticDistTree)
|
||||||
buildBitsBase(lengthBits, lengthBase, 4, 3)
|
buildBitsBase(lengthBits, lengthBase, 4, 3)
|
||||||
buildBitsBase(distBits, distBase, 2, 1)
|
buildBitsBase(distBits, distBase, 2, 1)
|
||||||
|
|
Loading…
Add table
Reference in a new issue