diff --git a/test.luau b/examples/tour.luau
similarity index 58%
rename from test.luau
rename to examples/tour.luau
index 1d9ff1d..6a58cc6 100644
--- a/test.luau
+++ b/examples/tour.luau
@@ -1,5 +1,5 @@
 local fs = require("@lune/fs")
-local zip = require("./lib")
+local zip = require("../lib")
 
 local file = fs.readFile("test.zip")
 local reader = zip.load(buffer.fromstring(file))
@@ -11,16 +11,14 @@ reader:walk(function(entry, depth)
 	print(prefix .. entry.name .. suffix)
 end)
 
-print("\nContents of /lib/:")
-local assets = reader:listDirectory("/lib/")
-for _, entry in ipairs(assets) do
-	print(entry.name, entry.isDirectory and "DIR" or entry.size)
-end
-
-local configEntry = reader:findEntry("config.json")
-if configEntry then
-	local content = reader:extract(configEntry)
-	print("\nConfig file size:", buffer.len(content))
+print("\nContents of `/`:")
+local assets = reader:listDirectory("/")
+for _, entry in assets do
+	print(entry.name, if entry.isDirectory then "DIR" else entry.size)
+	if not entry.isDirectory then
+		local extracted = reader:extract(entry, { isString = true })
+		print("Content:", extracted)
+	end
 end
 
 -- Get archive statistics
diff --git a/lib/crc.luau b/lib/crc.luau
new file mode 100644
index 0000000..bd9fd79
--- /dev/null
+++ b/lib/crc.luau
@@ -0,0 +1,30 @@
+-- Byte-indexed (0..255) lookup table for the reflected CRC-32 polynomial 0xEDB88320
+local CRC32_TABLE = table.create(256)
+
+-- Fill every slot by running its byte value through eight shift/xor rounds,
+-- then freeze the table so it cannot be modified afterwards
+for byte = 0, 255 do
+	local value = byte
+	for _ = 1, 8 do
+		local shifted = bit32.rshift(value, 1)
+		value = if bit32.btest(value, 1) then bit32.bxor(shifted, 0xEDB88320) else shifted
+	end
+	CRC32_TABLE[byte] = value
+end
+
+table.freeze(CRC32_TABLE)
+
+-- Computes the CRC-32 checksum over every byte of `buf` and returns it as an
+-- unsigned 32-bit number
+local function crc32(buf: buffer): number
+	local checksum = 0xFFFFFFFF
+
+	for offset = 0, buffer.len(buf) - 1 do
+		local slot = bit32.band(bit32.bxor(checksum, buffer.readu8(buf, offset)), 0xFF)
+		checksum = bit32.bxor(bit32.rshift(checksum, 8), CRC32_TABLE[slot])
+	end
+
+	return bit32.bnot(checksum)
+end
+
+return crc32
diff --git a/lib/inflate.luau
b/lib/inflate.luau
new file mode 100644
index 0000000..b1c96ee
--- /dev/null
+++ b/lib/inflate.luau
@@ -0,0 +1,385 @@
+local Tree = {}
+
+export type Tree = typeof(setmetatable({} :: TreeInner, { __index = Tree }))
+type TreeInner = {
+	table: { number }, -- code length counts, indexed by bit length (0..15)
+	trans: { number }, -- code -> symbol translation table (up to 288 entries)
+}
+
+-- Creates a tree whose count and translation tables are zero-filled
+function Tree.new(): Tree
+	return setmetatable(
+		{
+			table = table.create(16, 0),
+			trans = table.create(288, 0),
+		} :: TreeInner,
+		{ __index = Tree }
+	)
+end
+
+local Data = {}
+export type Data = typeof(setmetatable({} :: DataInner, { __index = Data }))
+export type DataInner = {
+	source: buffer,
+	sourceIndex: number,
+	tag: number,
+	bitcount: number,
+
+	dest: buffer,
+	destLen: number,
+
+	ltree: Tree,
+	dtree: Tree,
+}
+
+-- Creates the decoder state for reading `source` and writing into `dest`
+function Data.new(source: buffer, dest: buffer): Data
+	return setmetatable(
+		{
+			source = source,
+			sourceIndex = 0,
+			tag = 0,
+			bitcount = 0,
+			dest = dest,
+			destLen = 0,
+			ltree = Tree.new(),
+			dtree = Tree.new(),
+		} :: DataInner,
+		{ __index = Data }
+	)
+end
+
+-- Static structures
+local staticLengthTree = Tree.new()
+local staticDistTree = Tree.new()
+
+-- Extra bits and base tables
+local lengthBits = table.create(30, 0)
+local lengthBase = table.create(30, 0)
+local distBits = table.create(30, 0)
+local distBase = table.create(30, 0)
+
+-- Special ordering of code length codes
+-- stylua: ignore
+local clcIndex = {
+	16, 17, 18, 0, 8, 7, 9, 6,
+	10, 5, 11, 4, 12, 3, 13, 2,
+	14, 1, 15
+}
+
+local codeTree = Tree.new()
+local lengths = table.create(288 + 32, 0)
+
+-- Fills `bits` (extra bit counts) and `base` (base values) for the 30 length or distance codes
+local function buildBitsBase(bits: { number }, base: { number }, delta: number, first: number)
+	local sum = first
+
+	-- build bits table
+	for i = 0, delta - 1 do
+		bits[i] = 0
+	end
+	for i = 0, 29 - delta do
+		bits[i + delta] = math.floor(i / delta)
+	end
+
+	-- build base table
+	for i = 0, 29 do
+		base[i] = sum
+		sum = sum + bit32.lshift(1, bits[i])
+	end
+end
+
+-- Fills in the trees used for fixed-Huffman blocks (block type 1)
+local function buildFixedTrees(lengthTree: Tree, distTree: Tree)
+	-- build fixed length tree
+	for i = 0, 6 do
+		lengthTree.table[i] = 0
+	end
+	lengthTree.table[7] = 24
+	lengthTree.table[8] = 152
+	lengthTree.table[9] = 112
+
+	for i = 0, 23 do
+		lengthTree.trans[i] = 256 + i
+	end
+	for i = 0, 143 do
+		lengthTree.trans[24 + i] = i
+	end
+	for i = 0, 7 do
+		lengthTree.trans[24 + 144 + i] = 280 + i
+	end
+	for i = 0, 111 do
+		lengthTree.trans[24 + 144 + 8 + i] = 144 + i
+	end
+
+	-- build fixed distance tree
+	for i = 0, 4 do
+		distTree.table[i] = 0
+	end
+	distTree.table[5] = 32
+
+	for i = 0, 31 do
+		distTree.trans[i] = i
+	end
+end
+
+local offs = table.create(16, 0)
+
+-- Builds the count and translation tables of `t` from `num` code lengths read from `lengths[off]` onwards
+local function buildTree(t: Tree, lengths: { number }, off: number, num: number)
+	-- clear code length count table
+	for i = 0, 15 do
+		t.table[i] = 0
+	end
+
+	-- scan symbol lengths, and sum code length counts
+	for i = 0, num - 1 do
+		t.table[lengths[off + i]] += 1
+	end
+
+	t.table[0] = 0
+
+	-- compute offset table for distribution sort
+	local sum = 0
+	for i = 0, 15 do
+		offs[i] = sum
+		sum = sum + t.table[i]
+	end
+
+	-- create code->symbol translation table
+	for i = 0, num - 1 do
+		local len = lengths[off + i]
+		if len > 0 then
+			t.trans[offs[len]] = i
+			offs[len] += 1
+		end
+	end
+end
+
+-- Reads one bit from the input, loading the next source byte when the bit buffer is empty
+local function getBit(d: Data): number
+	if d.bitcount <= 0 then
+		d.tag = buffer.readu8(d.source, d.sourceIndex)
+		d.sourceIndex += 1
+		d.bitcount = 8
+	end
+
+	local bit = bit32.band(d.tag, 1)
+	d.tag = bit32.rshift(d.tag, 1)
+	d.bitcount -= 1
+
+	return bit
+end
+
+-- Reads `num` bits from the input and returns their value added to `base`;
+-- returns `base` unchanged when `num` is nil
+local function readBits(d: Data, num: number?, base: number): number
+	if not num then
+		return base
+	end
+
+	while d.bitcount < 24 and d.sourceIndex < buffer.len(d.source) do
+		d.tag = bit32.bor(d.tag, bit32.lshift(buffer.readu8(d.source, d.sourceIndex), d.bitcount))
+		d.sourceIndex += 1
+		d.bitcount += 8
+	end
+
+	local val = bit32.band(d.tag, bit32.rshift(0xffff, 16 - num))
+	d.tag = bit32.rshift(d.tag, num)
+	d.bitcount -= num
+
+	return val + base
+end
+
+-- Decodes and returns the next symbol from the input using the code tables of `t`
+local function decodeSymbol(d: Data, t: Tree): number
+	while d.bitcount < 24 and d.sourceIndex < buffer.len(d.source) do
+		d.tag = bit32.bor(d.tag, bit32.lshift(buffer.readu8(d.source, d.sourceIndex), d.bitcount))
+		d.sourceIndex += 1
+		d.bitcount += 8
+	end
+
+	local sum, cur, len = 0, 0, 0
+	local tag = d.tag
+
+	repeat
+		cur = 2 * cur + bit32.band(tag, 1)
+		tag = bit32.rshift(tag, 1)
+		len += 1
+		sum += t.table[len]
+		cur -= t.table[len]
+	until cur < 0
+
+	d.tag = tag
+	d.bitcount -= len
+
+	return t.trans[sum + cur]
+end
+
+-- Parses the code-length header of a dynamic block and builds `lengthTree` and `distTree` from it
+local function decodeTrees(d: Data, lengthTree: Tree, distTree: Tree)
+	local hlit = readBits(d, 5, 257)
+	local hdist = readBits(d, 5, 1)
+	local hclen = readBits(d, 4, 4)
+
+	for i = 0, 18 do
+		lengths[i] = 0
+	end
+
+	for i = 0, hclen - 1 do
+		lengths[clcIndex[i + 1]] = readBits(d, 3, 0)
+	end
+
+	buildTree(codeTree, lengths, 0, 19)
+
+	local num = 0
+	while num < hlit + hdist do
+		local sym = decodeSymbol(d, codeTree)
+
+		if sym == 16 then
+			local prev = lengths[num - 1]
+			for _ = 1, readBits(d, 2, 3) do
+				lengths[num] = prev
+				num += 1
+			end
+		elseif sym == 17 then
+			for _ = 1, readBits(d, 3, 3) do
+				lengths[num] = 0
+				num += 1
+			end
+		elseif sym == 18 then
+			for _ = 1, readBits(d, 7, 11) do
+				lengths[num] = 0
+				num += 1
+			end
+		else
+			lengths[num] = sym
+			num += 1
+		end
+	end
+
+	buildTree(lengthTree, lengths, 0, hlit)
+	buildTree(distTree, lengths, hlit, hdist)
+end
+
+-- Makes sure `extra` more bytes fit after the bytes already written, replacing
+-- `d.dest` with a larger buffer (at least doubled) when they do not
+local function ensureCapacity(d: Data, extra: number)
+	local required = d.destLen + extra
+	local capacity = buffer.len(d.dest)
+	if required <= capacity then
+		return
+	end
+
+	repeat
+		capacity = math.max(capacity * 2, 64)
+	until capacity >= required
+
+	local grown = buffer.create(capacity)
+	buffer.copy(grown, 0, d.dest, 0, d.destLen)
+	d.dest = grown
+end
+
+-- Decodes literals and back-references into the output until the end-of-block symbol (256)
+local function inflateBlockData(d: Data, lengthTree: Tree, distTree: Tree)
+	while true do
+		local sym = decodeSymbol(d, lengthTree)
+
+		if sym == 256 then
+			return
+		end
+
+		if sym < 256 then
+			ensureCapacity(d, 1)
+			buffer.writeu8(d.dest, d.destLen, sym)
+			d.destLen += 1
+		else
+			sym -= 257
+
+			local length = readBits(d, lengthBits[sym], lengthBase[sym])
+			local dist = decodeSymbol(d, distTree)
+
+			local offs = d.destLen - readBits(d, distBits[dist], distBase[dist])
+
+			-- Reserve room for the whole match before copying it byte by byte
+			ensureCapacity(d, length)
+			for i = offs, offs + length - 1 do
+				buffer.writeu8(d.dest, d.destLen, buffer.readu8(d.dest, i))
+				d.destLen += 1
+			end
+		end
+	end
+end
+
+-- Copies a stored block (block type 0) to the output after checking its length field
+local function inflateUncompressedBlock(d: Data)
+	-- Hand back every whole byte still held in the bit buffer so `sourceIndex`
+	-- lands on the first byte after the block header; this must include the case
+	-- of exactly 8 buffered bits, otherwise the length field is read one byte late
+	while d.bitcount >= 8 do
+		d.sourceIndex -= 1
+		d.bitcount -= 8
+	end
+
+	local length = buffer.readu8(d.source, d.sourceIndex + 1)
+	length = 256 * length + buffer.readu8(d.source, d.sourceIndex)
+
+	local invlength = buffer.readu8(d.source, d.sourceIndex + 3)
+	invlength = 256 * invlength + buffer.readu8(d.source, d.sourceIndex + 2)
+
+	if length ~= bit32.bxor(invlength, 0xffff) then
+		error("Invalid block length")
+	end
+
+	d.sourceIndex += 4
+
+	ensureCapacity(d, length)
+	for _ = 1, length do
+		buffer.writeu8(d.dest, d.destLen, buffer.readu8(d.source, d.sourceIndex))
+		d.destLen += 1
+		d.sourceIndex += 1
+	end
+
+	d.bitcount = 0
+end
+
+-- Inflates the raw DEFLATE data in `source` and returns a buffer holding exactly the decompressed bytes
+local function uncompress(source: buffer): buffer
+	-- The initial size is only a guess; ensureCapacity grows the buffer when needed
+	local dest = buffer.create(buffer.len(source) * 4)
+	local d = Data.new(source, dest)
+
+	repeat
+		local bfinal = getBit(d)
+		local btype = readBits(d, 2, 0)
+
+		if btype == 0 then
+			inflateUncompressedBlock(d)
+		elseif btype == 1 then
+			inflateBlockData(d, staticLengthTree, staticDistTree)
+		elseif btype == 2 then
+			decodeTrees(d, d.ltree, d.dtree)
+			inflateBlockData(d, d.ltree, d.dtree)
+		else
+			error("Invalid block type")
+		end
+	until bfinal == 1
+
+	-- `d.dest` may have been replaced by a larger buffer while decoding
+	if d.destLen < buffer.len(d.dest) then
+		local result = buffer.create(d.destLen)
+		buffer.copy(result, 0, d.dest, 0, d.destLen)
+		return result
+	end
+
+	return d.dest
+end
+
+-- Initialize static trees and tables
+buildFixedTrees(staticLengthTree, staticDistTree)
+buildBitsBase(lengthBits, lengthBase, 4, 3)
+buildBitsBase(distBits, distBase, 2, 1)
+lengthBits[28] = 0
+lengthBase[28] = 258
+
+return uncompress
diff --git a/lib/init.luau b/lib/init.luau
index d00898e..6350671 100644
--- a/lib/init.luau
+++ b/lib/init.luau
@@ -1,3 +1,6 @@
+local inflate = require("./inflate")
+local crc32 = require("./crc")
+
 -- Little endian constant signatures used in the ZIP file format
 local SIGNATURES = table.freeze({
 	-- Marks the beginning of each file in the ZIP
@@ -8,7 +11,45 @@ local SIGNATURES = table.freeze({
 	END_OF_CENTRAL_DIR = 0x06054b50,
 })
 
--- TODO: ERROR HANDLING!!
+type CrcValidationOptions = {
+	skip: boolean,
+	expected: number,
+}
+
+-- Raises when the CRC-32 of `decompressed` differs from `validation.expected`;
+-- does nothing when `validation.skip` is set
+local function validateCrc(decompressed: buffer, validation: CrcValidationOptions)
+	-- Unless skipping validation is requested, we verify the checksum
+	if not validation.skip then
+		local computed = crc32(decompressed)
+		assert(
+			validation.expected == computed,
+			string.format(
+				"Validation failed; CRC checksum does not match: %x ~= %x (expected ~= got)",
+				validation.expected,
+				computed
+			)
+		)
+	end
+end
+
+-- Decompression routines keyed by the ZIP compression method ID; each one also
+-- runs the CRC validation on the data it returns
+local DECOMPRESSION_ROUTINES: { [number]: (buffer, validation: CrcValidationOptions) -> buffer } = table.freeze({
+	-- Method 0: the data is returned as-is
+	[0x00] = function(buf, validation)
+		validateCrc(buf, validation)
+		return buf
+	end,
+	-- Method 8: DEFLATE
+	[0x08] = function(buf, validation)
+		local decompressed = inflate(buf)
+		validateCrc(decompressed, validation)
+		return decompressed
+	end,
+})
+
+-- TODO: ERROR HANDLING!
 
 local ZipEntry = {}
 export type ZipEntry = typeof(setmetatable({} :: ZipEntryInner, { __index = ZipEntry }))
@@ -140,75 +181,86 @@ end
 
 function ZipReader.buildDirectoryTree(self: ZipReader): ()
 	for _, entry in self.entries do
-        local parts = {}
-        -- Split entry path into individual components
-        -- e.g. "folder/subfolder/file.txt" -> {"folder", "subfolder", "file.txt"}
-        for part in string.gmatch(entry.name, "([^/]+)/?") do
-            table.insert(parts, part)
-        end
+		local parts = {}
+		-- Split entry path into individual components
+		-- e.g. "folder/subfolder/file.txt" -> {"folder", "subfolder", "file.txt"}
+		for part in string.gmatch(entry.name, "([^/]+)/?") do
+			table.insert(parts, part)
+		end
 
-        -- Start from root directory
-        local current = self.root
-        local path = ""
+		-- Start from root directory
+		local current = self.root
+		local path = ""
 
-        -- Process each path component
-        for i, part in parts do
-            path ..= part
-            if i < #parts then
-                -- Create missing directory entries for intermediate paths
-                if not self.directories[path] then
-                    local dir = ZipEntry.new(path, 0, 0, entry.timestamp, 0)
-                    dir.isDirectory = true
-                    dir.parent = current
+		-- Process each path component
+		for i, part in parts do
+			path ..= part
+			if i < #parts then
+				-- Create missing directory entries for intermediate paths
+				if not self.directories[path] then
+					local dir = ZipEntry.new(path, 0, 0, entry.timestamp, 0)
+					dir.isDirectory = true
+					dir.parent = current
 
-                    -- Track directory in both lookup table and parent's children
-                    self.directories[path] = dir
-                    table.insert(current.children, dir)
-                end
+					-- Track directory in both lookup table and parent's children
+					self.directories[path] = dir
+					table.insert(current.children, dir)
+				end
 
-                -- Move deeper into the tree
-                current = self.directories[path]
-                continue
-            end
+				-- Move deeper into the tree
+				current = self.directories[path]
+				continue
+			end
 
-            -- Link file entry to its parent directory
-            entry.parent = current
-            table.insert(current.children, entry)
-        end
-    end
+			-- Link file entry to its parent directory
+			entry.parent = current
+			table.insert(current.children, entry)
+		end
+	end
 end
 
 function ZipReader.findEntry(self: ZipReader, path: string): ZipEntry
-    if path == "/" then
-        -- If the root directory's entry was requested we do not
+	if path == "/" then
+		-- If the root directory's entry was requested we do not
 		-- need to do any additional work
 		return self.root
-    end
+	end
 
-    -- Normalize path by removing leading and trailing slashes
-    -- This ensures consistent lookup regardless of input format
-    -- e.g., "/folder/file.txt/" -> "folder/file.txt"
-    path = string.gsub(path, "^/", ""):gsub("/$", "")
+	-- Normalize path by removing leading and trailing slashes
+	-- This ensures consistent lookup regardless of input format
+	-- e.g., "/folder/file.txt/" -> "folder/file.txt"
+	path = string.gsub(path, "^/", ""):gsub("/$", "")
 
-    -- First check regular files and explicit directories
-    for _, entry in self.entries do
-        -- Compare normalized paths
-        if string.gsub(entry.name, "/$", "") == path then
-            return entry
-        end
-    end
+	-- First check regular files and explicit directories
+	for _, entry in self.entries do
+		-- Compare normalized paths
+		if string.gsub(entry.name, "/$", "") == path then
+			return entry
+		end
+	end
 
-    -- If not found, check virtual directory entries
-    -- These are directories that were created implicitly
-    return self.directories[path]
+	-- If not found, check virtual directory entries
+	-- These are directories that were created implicitly
+	return self.directories[path]
 end
 
-function ZipReader.extract(self: ZipReader, entry: ZipEntry): buffer
+type ExtractionOptions = {
+	decompress: boolean?, -- run the matching decompression routine (default: true)
+	isString: boolean?, -- return a string instead of a buffer (default: false)
+	skipValidation: boolean?, -- skip the CRC and size checks (default: false)
+}
+
+-- Extracts the contents of a file entry, decompressing and validating it
+-- according to `options`; returns a string when `options.isString` is set
+function ZipReader.extract(self: ZipReader, entry: ZipEntry, options: ExtractionOptions?): buffer | string
 	-- Local File Header format:
 	-- Offset  Bytes  Description
 	-- 0       4      Local file header signature
 	-- 8       2      Compression method (8 = DEFLATE)
+	-- 14      4      CRC32 checksum
+	-- 18      4      Compressed size
+	-- 22      4      Uncompressed size
 	-- 26      2      File name length (n)
 	-- 28      2      Extra field length (m)
 	-- 30      n      File name
@@ -219,45 +271,83 @@ function ZipReader.extract(self: ZipReader, entry: ZipEntry): buffer
 		error("Cannot extract directory")
 	end
 
+	local defaultOptions: ExtractionOptions = {
+		decompress = true,
+		isString = false,
+		skipValidation = false,
+	}
+
+	-- TODO: Use a `Partial` type function for this in the future!
+	-- The caller's table is cloned first so attaching the defaults does not modify it
+	local optionsOrDefault: {
+		decompress: boolean,
+		isString: boolean,
+		skipValidation: boolean
+	} = if options then setmetatable(table.clone(options), { __index = defaultOptions }) :: any else defaultOptions
+
 	local pos = entry.offset
 	if buffer.readu32(self.data, pos) ~= SIGNATURES.LOCAL_FILE then
 		error("Invalid local file header")
 	end
 
+	local crcChecksum = buffer.readu32(self.data, pos + 14)
+	local compressedSize = buffer.readu32(self.data, pos + 18)
+	local uncompressedSize = buffer.readu32(self.data, pos + 22)
 	local nameLength = buffer.readu16(self.data, pos + 26)
 	local extraLength = buffer.readu16(self.data, pos + 28)
 
 	pos = pos + 30 + nameLength + extraLength
 
-	local content = buffer.create(entry.size)
-	buffer.copy(content, 0, self.data, pos, entry.size)
+	local content = buffer.create(compressedSize)
+	buffer.copy(content, 0, self.data, pos, compressedSize)
 
-	-- TODO: decompress data! `buffer.readu16(self.data, entry.offset + 8)`
-	-- will give the compression method, where method id 8 corresponds to
-	-- deflate
+	if optionsOrDefault.decompress then
+		local compressionMethod = buffer.readu16(self.data, entry.offset + 8)
+		local decompress = DECOMPRESSION_ROUTINES[compressionMethod]
+		if decompress == nil then
+			error(`Unsupported compression, ID: {compressionMethod}`)
+		end
 
-	return content
+		content = decompress(content, {
+			expected = crcChecksum,
+			skip = optionsOrDefault.skipValidation,
+		})
+
+		-- Unless skipping validation is requested, we make sure the uncompressed size matches
+		assert(
+			optionsOrDefault.skipValidation or uncompressedSize == buffer.len(content),
+			"Validation failed; uncompressed size does not match"
+		)
+	end
+
+	return if optionsOrDefault.isString then buffer.tostring(content) else content
 end
 
-function ZipReader.extractDirectory(self: ZipReader, path: string): { [string]: buffer }
-    local files = {}
-    -- Normalize path by removing leading slash for consistent prefix matching
-    path = string.gsub(path, "^/", "")
+-- Extracts every file whose path starts with `path`, passing `options` on to
+-- `extract`; returns a map of entry name to extracted contents
+function ZipReader.extractDirectory(
+	self: ZipReader,
+	path: string,
+	options: ExtractionOptions?
+): { [string]: buffer } | { [string]: string }
+	local files: { [string]: buffer } | { [string]: string } = {}
+	-- Normalize path by removing leading slash for consistent prefix matching
+	path = string.gsub(path, "^/", "")
 
-    -- Iterate through all entries to find files within target directory
-    for _, entry in self.entries do
-        -- Check if entry is a file (not directory) and its path starts with target directory
-        if not entry.isDirectory and string.sub(entry.name, 1, #path) == path then
-            -- Store extracted content mapped to full path
-            files[entry.name] = self:extract(entry)
-        end
-    end
+	-- Iterate through all entries to find files within target directory
+	for _, entry in self.entries do
+		-- Check if entry is a file (not directory) and its path starts with target directory
+		if not entry.isDirectory and string.sub(entry.name, 1, #path) == path then
+			-- Store extracted content mapped to full path
+			files[entry.name] = self:extract(entry, options)
+		end
+	end
 
-    -- Return a map of file to contents
-    return files
+	-- Return a map of file to contents
+	return files
 end
 
 function ZipReader.listDirectory(self: ZipReader, path: string): { ZipEntry }
 	-- Locate the entry with the path
 	local entry = self:findEntry(path)