local crc32 = require("./crc")

-- Little endian constant signatures used in the ZIP file format
local SIGNATURES = table.freeze({
	-- Marks the beginning of each file in the ZIP
	LOCAL_FILE = 0x04034b50,
	-- Marks the start of a data descriptor
	DATA_DESCRIPTOR = 0x08074b50,
	-- Marks entries in the central directory
	CENTRAL_DIR = 0x02014b50,
	-- Marks the end of the central directory
	END_OF_CENTRAL_DIR = 0x06054b50,
})

type CrcValidationOptions = {
	skip: boolean,
	expected: number,
}

-- Verifies that the CRC32 checksum of `decompressed` matches `validation.expected`.
-- Does nothing when `validation.skip` is set; errors with a formatted message on mismatch.
local function validateCrc(decompressed: buffer, validation: CrcValidationOptions)
	-- Unless skipping validation is requested, we verify the checksum
	if not validation.skip then
		local computed = crc32(decompressed)
		assert(
			validation.expected == computed,
			-- FIX: the message previously interpolated `computed` on both sides,
			-- so the "expected" half of the report never showed the expected value
			`Validation failed; CRC checksum does not match: {string.format("%x", validation.expected)} ~= {string.format(
				"%x",
				computed
			)} (expected ~= got)`
		)
	end
end

export type CompressionMethod = "STORE" | "DEFLATE"

-- Dispatch table mapping ZIP compression method IDs to their decompressors.
-- Each routine validates the CRC of the *decompressed* bytes before returning.
local DECOMPRESSION_ROUTINES: { [number]: { name: CompressionMethod, decompress: (buffer, number, CrcValidationOptions) -> buffer } } =
	table.freeze({
		-- `STORE` decompression method - No compression
		[0x00] = {
			name = "STORE" :: CompressionMethod,
			decompress = function(buf, _, validation)
				validateCrc(buf, validation)
				return buf
			end,
		},

		-- `DEFLATE` decompression method - Compressed raw deflate chunks
		[0x08] = {
			name = "DEFLATE" :: CompressionMethod,
			decompress = function(buf, uncompressedSize, validation)
				-- FIXME: Why is uncompressedSize not getting inferred correctly although it
				-- is typed?
				local decompressed = inflate(buf, uncompressedSize :: any)
				validateCrc(decompressed, validation)
				return decompressed
			end,
		},
	})

-- TODO: ERROR HANDLING!
-- Properties accepted by `ZipEntry.new`; mirrors the metadata stored in a
-- central directory record.
type ZipEntryProperties = {
	size: number,
	attributes: number,
	timestamp: number,
	method: CompressionMethod?,
	crc: number,
	isAscii: boolean?,
}

-- Shared default properties for synthetic entries (e.g. the root directory).
local EMPTY_PROPERTIES: ZipEntryProperties = table.freeze({
	size = 0,
	attributes = 0,
	timestamp = 0,
	method = nil,
	crc = 0,
	isAscii = false,
})

-- Constructs a `ZipEntry` at `offset` with the given `name` and metadata.
-- An entry whose name ends in "/" is considered a directory.
function ZipEntry.new(offset: number, name: string, properties: ZipEntryProperties): ZipEntry
	return setmetatable(
		{
			name = name,
			size = properties.size,
			offset = offset,
			timestamp = properties.timestamp,
			method = properties.method,
			crc = properties.crc,
			isDirectory = string.sub(name, -1) == "/",
			-- FIX: `isAscii` is declared on ZipEntryInner and supplied by the
			-- central directory parser, but was never assigned here
			isAscii = properties.isAscii or false,
			attributes = properties.attributes,
			parent = nil,
			children = {},
		} :: ZipEntryInner,
		{ __index = ZipEntry }
	)
end

-- Returns whether this entry is a Unix symbolic link. The high 16 bits of the
-- external attributes hold the Unix mode; 0xA000 (S_IFLNK) marks a symlink.
function ZipEntry.isSymlink(self: ZipEntry): boolean
	return bit32.band(self.attributes, 0xA0000000) == 0xA0000000
end

-- Reconstructs the entry's full path by walking its parent chain up to the root.
function ZipEntry.getPath(self: ZipEntry): string
	local path = self.name
	local current = self.parent

	while current and current.name ~= "/" do
		path = current.name .. path
		current = current.parent
	end

	return path
end

local ZipReader = {}

-- Creates a reader over `data` (a buffer of ZIP bytes), parses the central
-- directory and builds the directory tree eagerly.
function ZipReader.new(data): ZipReader
	local root = ZipEntry.new(0, "/", EMPTY_PROPERTIES)
	root.isDirectory = true

	local this = setmetatable(
		{
			data = data,
			entries = {},
			directories = {},
			root = root,
		} :: ZipReaderInner,
		{ __index = ZipReader }
	)

	this:parseCentralDirectory()
	this:buildDirectoryTree()
	return this
end

-- Locates the End of Central Directory record, then walks every central
-- directory record, appending a `ZipEntry` per record to `self.entries`.
-- Errors if the EoCD signature cannot be found.
function ZipReader.parseCentralDirectory(self: ZipReader): ()
	-- ZIP files are read from the end, starting with the End of Central Directory record
	-- The EoCD is at least 22 bytes and contains pointers to the rest of the ZIP structure
	local bufSize = buffer.len(self.data)

	-- Start from the minimum possible position of EoCD (22 bytes from end)
	local minPos = math.max(0, bufSize - (22 + 65535) --[[ max comment size: 64 KiB ]])
	local pos = bufSize - 22

	-- Search backwards for the EoCD signature
	while pos >= minPos do
		if buffer.readu32(self.data, pos) == SIGNATURES.END_OF_CENTRAL_DIR then
			break
		end
		pos -= 1
	end

	-- Verify we found the signature
	if pos < minPos then
		error("Could not find End of Central Directory signature")
	end

	-- End of Central Directory format:
	-- Offset  Bytes  Description
	-- 0       4      End of central directory signature
	-- 4       2      Number of this disk
	-- 6       2      Disk where central directory starts
	-- 8       2      Number of central directory records on this disk
	-- 10      2      Total number of central directory records
	-- 12      4      Size of central directory (bytes)
	-- 16      4      Offset of start of central directory
	-- 20      2      Comment length (n)
	-- 22      n      Comment

	local cdOffset = buffer.readu32(self.data, pos + 16)
	local cdEntries = buffer.readu16(self.data, pos + 10)
	local cdCommentLength = buffer.readu16(self.data, pos + 20)
	self.comment = buffer.readstring(self.data, pos + 22, cdCommentLength)

	-- Process each entry in the Central Directory
	pos = cdOffset
	for i = 1, cdEntries do
		-- Central Directory Entry format:
		-- Offset  Bytes  Description
		-- 0       4      Central directory entry signature
		-- 8       2      General purpose bitflags
		-- 10      2      Compression method (8 = DEFLATE)
		-- 12      4      Last mod time/date
		-- 16      4      CRC-32
		-- 24      4      Uncompressed size
		-- 28      2      File name length (n)
		-- 30      2      Extra field length (m)
		-- 32      2      Comment length (k)
		-- 36      2      Internal file attributes
		-- 38      4      External file attributes
		-- 42      4      Local header offset
		-- 46      n      File name
		-- 46+n    m      Extra field
		-- 46+n+m  k      Comment

		local _bitflags = buffer.readu16(self.data, pos + 8)
		local timestamp = buffer.readu32(self.data, pos + 12)
		local compressionMethod = buffer.readu16(self.data, pos + 10)
		local crc = buffer.readu32(self.data, pos + 16)
		local size = buffer.readu32(self.data, pos + 24)
		local nameLength = buffer.readu16(self.data, pos + 28)
		local extraLength = buffer.readu16(self.data, pos + 30)
		local commentLength = buffer.readu16(self.data, pos + 32)
		local internalAttrs = buffer.readu16(self.data, pos + 36)
		local externalAttrs = buffer.readu32(self.data, pos + 38)
		local offset = buffer.readu32(self.data, pos + 42)
		local name = buffer.readstring(self.data, pos + 46, nameLength)

		table.insert(
			self.entries,
			ZipEntry.new(offset, name, {
				size = size,
				crc = crc,
				-- FIX: this key must be `method`; the previous `compressionMethod`
				-- key is not read by `ZipEntry.new`, which left `entry.method` nil
				method = DECOMPRESSION_ROUTINES[compressionMethod].name,
				timestamp = timestamp,
				attributes = externalAttrs,
				-- Bit 0 of the internal attributes marks the entry as ASCII text
				isAscii = bit32.band(internalAttrs, 0x0001) ~= 0,
			})
		)

		pos = pos + 46 + nameLength + extraLength + commentLength
	end
end
-- Looks up an entry by path. "/" resolves to the synthetic root; any other
-- path is matched against real entries first, then against virtual
-- (implicitly created) directories. Returns nil when nothing matches.
function ZipReader.findEntry(self: ZipReader, path: string): ZipEntry?
	-- The root directory is never stored in `entries`, answer it directly
	if path == "/" then
		return self.root
	end

	-- Strip leading and trailing slashes so lookups are format-agnostic,
	-- e.g. "/folder/file.txt/" -> "folder/file.txt"
	local normalized = string.gsub(string.gsub(path, "^/", ""), "/$", "")

	-- Scan the real entries (files and explicit directories) first
	for _, candidate in self.entries do
		-- Compare with the candidate's own trailing slash removed
		if string.gsub(candidate.name, "/$", "") == normalized then
			return candidate
		end
	end

	-- Fall back to virtual directory entries created implicitly while
	-- building the directory tree
	return self.directories[normalized]
end

type ExtractionOptions = {
	followSymlinks: boolean?,
	decompress: boolean?,
	isString: boolean?,
	skipCrcValidation: boolean?,
	skipSizeValidation: boolean?,
}
checksum - -- 18 4 Compressed size - -- 22 4 Uncompressed size - -- 26 2 File name length (n) - -- 28 2 Extra field length (m) - -- 30 n File name - -- 30+n m Extra field - -- 30+n+m - File data + -- Local File Header format: + -- Offset Bytes Description + -- 0 4 Local file header signature + -- 6 2 General purpose bitflags + -- 8 2 Compression method (8 = DEFLATE) + -- 14 4 CRC32 checksum + -- 18 4 Compressed size + -- 22 4 Uncompressed size + -- 26 2 File name length (n) + -- 28 2 Extra field length (m) + -- 30 n File name + -- 30+n m Extra field + -- 30+n+m - File data - if entry.isDirectory then - error("Cannot extract directory") - end + if entry.isDirectory then + error("Cannot extract directory") + end - local defaultOptions: ExtractionOptions = { - decompress = true, - isString = false, - skipValidation = false, - } + local defaultOptions: ExtractionOptions = { + followSymlinks = false, + decompress = true, + isString = entry.isAscii, + skipValidation = false, + } - -- TODO: Use a `Partial` type function for this in the future! - local optionsOrDefault: { - decompress: boolean, - isString: boolean, - skipCrcValidation: boolean, - skipSizeValidation: boolean, - } = if options - then setmetatable(options, { __index = defaultOptions }) :: any - else defaultOptions + -- TODO: Use a `Partial` type function for this in the future! 
+ local optionsOrDefault: { + followSymlinks: boolean, + decompress: boolean, + isString: boolean, + skipCrcValidation: boolean, + skipSizeValidation: boolean, + } = if options + then setmetatable(options, { __index = defaultOptions }) :: any + else defaultOptions - local pos = entry.offset - if buffer.readu32(self.data, pos) ~= SIGNATURES.LOCAL_FILE then - error("Invalid local file header") - end + local pos = entry.offset + if buffer.readu32(self.data, pos) ~= SIGNATURES.LOCAL_FILE then + error("Invalid local file header") + end - local bitflags = buffer.readu16(self.data, pos + 6) - local crcChecksum = buffer.readu32(self.data, pos + 14) - local compressedSize = buffer.readu32(self.data, pos + 18) - local uncompressedSize = buffer.readu32(self.data, pos + 22) - local nameLength = buffer.readu16(self.data, pos + 26) - local extraLength = buffer.readu16(self.data, pos + 28) + local bitflags = buffer.readu16(self.data, pos + 6) + local crcChecksum = buffer.readu32(self.data, pos + 14) + local compressedSize = buffer.readu32(self.data, pos + 18) + local uncompressedSize = buffer.readu32(self.data, pos + 22) + local nameLength = buffer.readu16(self.data, pos + 26) + local extraLength = buffer.readu16(self.data, pos + 28) - pos = pos + 30 + nameLength + extraLength + pos = pos + 30 + nameLength + extraLength - if bit32.band(bitflags, 0x08) ~= 0 then - -- The bit at offset 3 was set, meaning we did not have the file sizes - -- and CRC checksum at the time of the creation of the ZIP. Instead, they - -- were appended after the compressed data chunks in a data descriptor + if bit32.band(bitflags, 0x08) ~= 0 then + -- The bit at offset 3 was set, meaning we did not have the file sizes + -- and CRC checksum at the time of the creation of the ZIP. 
Instead, they + -- were appended after the compressed data chunks in a data descriptor - -- Data Descriptor format: - -- Offset Bytes Description - -- 0 0 or 4 0x08074b50 (optional signature) - -- 0 or 4 4 CRC32 checksum - -- 4 or 8 4 Compressed size - -- 8 or 12 4 Uncompressed size + -- Data Descriptor format: + -- Offset Bytes Description + -- 0 0 or 4 0x08074b50 (optional signature) + -- 0 or 4 4 CRC32 checksum + -- 4 or 8 4 Compressed size + -- 8 or 12 4 Uncompressed size - -- Start at the compressed data - local descriptorPos = pos - while true do - -- Try reading a u32 starting from current offset - local leading = buffer.readu32(self.data, descriptorPos) + -- Start at the compressed data + local descriptorPos = pos + while true do + -- Try reading a u32 starting from current offset + local leading = buffer.readu32(self.data, descriptorPos) - if leading == SIGNATURES.DATA_DESCRIPTOR then - -- If we find a data descriptor signature, that must mean - -- the current offset points is the start of the descriptor - break - end + if leading == SIGNATURES.DATA_DESCRIPTOR then + -- If we find a data descriptor signature, that must mean + -- the current offset points is the start of the descriptor + break + end - if leading == entry.crc then - -- If we find our file's CRC checksum, that means the data - -- descriptor signature was omitted, so our chunk starts 4 - -- bytes before - descriptorPos -= 4 - break - end + if leading == entry.crc then + -- If we find our file's CRC checksum, that means the data + -- descriptor signature was omitted, so our chunk starts 4 + -- bytes before + descriptorPos -= 4 + break + end - -- Skip to the next byte - descriptorPos += 1 - end + -- Skip to the next byte + descriptorPos += 1 + end - crcChecksum = buffer.readu32(self.data, descriptorPos + 4) - compressedSize = buffer.readu32(self.data, descriptorPos + 8) - uncompressedSize = buffer.readu32(self.data, descriptorPos + 12) - end + crcChecksum = buffer.readu32(self.data, 
descriptorPos + 4) + compressedSize = buffer.readu32(self.data, descriptorPos + 8) + uncompressedSize = buffer.readu32(self.data, descriptorPos + 12) + end - local content = buffer.create(compressedSize) - buffer.copy(content, 0, self.data, pos, compressedSize) + local content = buffer.create(compressedSize) + buffer.copy(content, 0, self.data, pos, compressedSize) - if optionsOrDefault.decompress then - local compressionMethod = buffer.readu16(self.data, entry.offset + 8) - local algo = DECOMPRESSION_ROUTINES[compressionMethod] - if algo == nil then - error(`Unsupported compression, ID: {compressionMethod}`) - end + if optionsOrDefault.decompress then + local compressionMethod = buffer.readu16(self.data, entry.offset + 8) + local algo = DECOMPRESSION_ROUTINES[compressionMethod] + if algo == nil then + error(`Unsupported compression, ID: {compressionMethod}`) + end - content = algo.decompress(content, uncompressedSize, { - expected = crcChecksum, - skip = optionsOrDefault.skipCrcValidation, - }) + if optionsOrDefault.followSymlinks then + local linkPath = buffer.tostring(algo.decompress(content, 0, { + expected = 0x00000000, + skip = true, + })) - -- Unless skipping validation is requested, we make sure the uncompressed size matches - assert( - optionsOrDefault.skipSizeValidation or uncompressedSize == buffer.len(content), - "Validation failed; uncompressed size does not match" - ) - end + --- Canonicalize a path by removing redundant components + local function canonicalize(path: string): string + -- NOTE: It is fine for us to use `/` here because ZIP file names + -- always use `/` as the path separator + local components = string.split(path, "/") + local result = {} + for _, component in components do + if component == "." then + -- Skip current directory + continue + end - return if optionsOrDefault.isString then buffer.tostring(content) else content + if component == ".." 
then + -- Traverse one upwards + table.remove(result, #result) + continue + end + + -- Otherwise, add the component to the result + table.insert(result, component) + end + + return table.concat(result, "/") + end + + -- Check if the path was a relative path + if + not ( + string.match(linkPath, "^/") + or string.match(linkPath, "^[a-zA-Z]:[\\/]") + or string.match(linkPath, "^//") + ) + then + if string.sub(linkPath, -1) ~= "/" then + linkPath ..= "/" + end + + linkPath = canonicalize(`{(entry.parent or self.root).name}{linkPath}`) + end + + optionsOrDefault.followSymlinks = false + optionsOrDefault.isString = false + optionsOrDefault.skipCrcValidation = true + optionsOrDefault.skipSizeValidation = true + content = self:extract( + self:findEntry(linkPath) or error("Symlink path not found"), + optionsOrDefault + ) :: buffer + end + + content = algo.decompress(content, uncompressedSize, { + expected = crcChecksum, + skip = optionsOrDefault.skipCrcValidation, + }) + + -- Unless skipping validation is requested, we make sure the uncompressed size matches + assert( + optionsOrDefault.skipSizeValidation or uncompressedSize == buffer.len(content), + "Validation failed; uncompressed size does not match" + ) + end + + return if optionsOrDefault.isString then buffer.tostring(content) else content end function ZipReader.extractDirectory( - self: ZipReader, - path: string, - options: ExtractionOptions + self: ZipReader, + path: string, + options: ExtractionOptions ): { [string]: buffer } | { [string]: string } - local files: { [string]: buffer } | { [string]: string } = {} - -- Normalize path by removing leading slash for consistent prefix matching - path = string.gsub(path, "^/", "") + local files: { [string]: buffer } | { [string]: string } = {} + -- Normalize path by removing leading slash for consistent prefix matching + path = string.gsub(path, "^/", "") - -- Iterate through all entries to find files within target directory - for _, entry in self.entries do - -- Check if 
-- Extracts every file under `path`, returning a map of full entry name to
-- its (buffer or string) contents. `options` are forwarded to `extract`.
function ZipReader.extractDirectory(
	self: ZipReader,
	path: string,
	options: ExtractionOptions
): { [string]: buffer } | { [string]: string }
	local files: { [string]: buffer } | { [string]: string } = {}
	-- Normalize path by removing the leading slash for consistent prefix matching
	path = string.gsub(path, "^/", "")
	-- FIX: require a "/" boundary after the prefix so that extracting "foo"
	-- no longer also matches sibling entries such as "foobar/baz"
	if #path > 0 and string.sub(path, -1) ~= "/" then
		path ..= "/"
	end

	-- Iterate through all entries to find files within the target directory
	for _, entry in self.entries do
		-- Check if entry is a file (not directory) and its path starts with the
		-- target directory prefix
		if not entry.isDirectory and string.sub(entry.name, 1, #path) == path then
			-- Store extracted content mapped to full path
			files[entry.name] = self:extract(entry, options)
		end
	end

	-- Return a map of file to contents
	return files
end

-- Returns the immediate children of the directory entry at `path`.
-- Errors when `path` does not resolve to a directory.
function ZipReader.listDirectory(self: ZipReader, path: string): { ZipEntry }
	-- Locate the entry with the path
	local entry = self:findEntry(path)
	if not entry or not entry.isDirectory then
		-- If an entry was not found, we error
		error("Not a directory")
	end

	-- Return the children of our discovered entry
	return entry.children
end

-- Depth-first walk over the whole tree, invoking `callback(entry, depth)`
-- for every entry starting at the root (depth 0).
function ZipReader.walk(self: ZipReader, callback: (entry: ZipEntry, depth: number) -> ()): ()
	-- Wrapper function which recursively calls callback for every child
	-- in an entry
	local function walkEntry(entry: ZipEntry, depth: number)
		callback(entry, depth)

		for _, child in entry.children do
			-- Recurse into each child one level deeper
			walkEntry(child, depth + 1)
		end
	end

	walkEntry(self.root, 0)
end

export type ZipStatistics = { fileCount: number, dirCount: number, totalSize: number }

-- Tallies file count, directory count and total uncompressed size across
-- all real entries (virtual directories are not counted).
function ZipReader.getStats(self: ZipReader): ZipStatistics
	local stats: ZipStatistics = {
		fileCount = 0,
		dirCount = 0,
		totalSize = 0,
	}

	-- Iterate through the entries, updating stats
	for _, entry in self.entries do
		if entry.isDirectory then
			stats.dirCount += 1
			continue
		end

		stats.fileCount += 1
		stats.totalSize += entry.size
	end

	return stats
end

return {
	-- Creates a `ZipReader` from a `buffer` of ZIP data.
	load = function(data: buffer)
		return ZipReader.new(data)
	end,
}
local fs = require("@lune/fs")
local process = require("@lune/process")
local serde = require("@lune/serde")

local frktest = require("../lune_packages/frktest")
local check = frktest.assert.check

local ZipReader = require("../lib")

return function(test: typeof(frktest.test))
	test.suite("Edge case tests", function()
		test.case("Handles misaligned comment properly", function()
			local data = fs.readFile("tests/data/misaligned_comment.zip")
			local zip = ZipReader.load(buffer.fromstring(data))

			check.equal(zip.comment, "short.")
		end)

		test.case("Follows symlinks correctly", function()
			-- TODO: More test files with symlinks

			local data = fs.readFile("tests/data/pandoc_soft_links.zip")
			local zip = ZipReader.load(buffer.fromstring(data))

			local entry = assert(zip:findEntry("/pandoc-3.2-arm64/bin/pandoc-lua"))
			assert(entry:isSymlink(), "Entry type must be a symlink")

			-- Without `followSymlinks`, extraction yields the raw link target path
			local targetPath = zip:extract(entry, { isString = true }) :: string
			check.equal(targetPath, "pandoc")

			-- With `followSymlinks`, extraction yields the pointed-to binary itself
			local bin = zip:extract(entry, { isString = false, followSymlinks = true }) :: buffer
			local expectedBin =
				process.spawn("unzip", { "-p", "tests/data/pandoc_soft_links.zip", "pandoc-3.2-arm64/bin/pandoc" })
			check.is_true(expectedBin.ok)

			-- Compare hashes instead of the entire binary to improve speed and not print out
			-- the entire binary data in case there's a mismatch
			check.equal(serde.hash("blake3", bin), serde.hash("blake3", expectedBin.stdout))
		end)
	end)
end