feat: support following symlinks

* Added new extraction option: `followSymlinks`
* Added method for querying whether an entry is a symlink using
  `ZipEntry:isSymlink`
* Include tests for symlinks
This commit is contained in:
Erica Marigold 2025-01-07 18:44:46 +00:00
parent 89ee51874b
commit d329a3f273
Signed by: DevComp
GPG key ID: 429EF1C337871656
2 changed files with 444 additions and 329 deletions

View file

@@ -3,59 +3,59 @@ local crc32 = require("./crc")
-- Little endian constant signatures used in the ZIP file format -- Little endian constant signatures used in the ZIP file format
local SIGNATURES = table.freeze({ local SIGNATURES = table.freeze({
-- Marks the beginning of each file in the ZIP -- Marks the beginning of each file in the ZIP
LOCAL_FILE = 0x04034b50, LOCAL_FILE = 0x04034b50,
-- Marks the start of a data descriptor -- Marks the start of a data descriptor
DATA_DESCRIPTOR = 0x08074b50, DATA_DESCRIPTOR = 0x08074b50,
-- Marks entries in the central directory -- Marks entries in the central directory
CENTRAL_DIR = 0x02014b50, CENTRAL_DIR = 0x02014b50,
-- Marks the end of the central directory -- Marks the end of the central directory
END_OF_CENTRAL_DIR = 0x06054b50, END_OF_CENTRAL_DIR = 0x06054b50,
}) })
type CrcValidationOptions = { type CrcValidationOptions = {
skip: boolean, skip: boolean,
expected: number, expected: number,
} }
local function validateCrc(decompressed: buffer, validation: CrcValidationOptions) local function validateCrc(decompressed: buffer, validation: CrcValidationOptions)
-- Unless skipping validation is requested, we verify the checksum -- Unless skipping validation is requested, we verify the checksum
if not validation.skip then if not validation.skip then
local computed = crc32(decompressed) local computed = crc32(decompressed)
assert( assert(
validation.expected == computed, validation.expected == computed,
`Validation failed; CRC checksum does not match: {string.format("%x", computed)} ~= {string.format( `Validation failed; CRC checksum does not match: {string.format("%x", computed)} ~= {string.format(
"%x", "%x",
computed computed
)} (expected ~= got)` )} (expected ~= got)`
) )
end end
end end
export type CompressionMethod = "STORE" | "DEFLATE" export type CompressionMethod = "STORE" | "DEFLATE"
local DECOMPRESSION_ROUTINES: { [number]: { name: CompressionMethod, decompress: (buffer, number, CrcValidationOptions) -> buffer } } = local DECOMPRESSION_ROUTINES: { [number]: { name: CompressionMethod, decompress: (buffer, number, CrcValidationOptions) -> buffer } } =
table.freeze({ table.freeze({
-- `STORE` decompression method - No compression -- `STORE` decompression method - No compression
[0x00] = { [0x00] = {
name = "STORE" :: CompressionMethod, name = "STORE" :: CompressionMethod,
decompress = function(buf, _, validation) decompress = function(buf, _, validation)
validateCrc(buf, validation) validateCrc(buf, validation)
return buf return buf
end, end,
}, },
-- `DEFLATE` decompression method - Compressed raw deflate chunks -- `DEFLATE` decompression method - Compressed raw deflate chunks
[0x08] = { [0x08] = {
name = "DEFLATE" :: CompressionMethod, name = "DEFLATE" :: CompressionMethod,
decompress = function(buf, uncompressedSize, validation) decompress = function(buf, uncompressedSize, validation)
-- FIXME: Why is uncompressedSize not getting inferred correctly although it -- FIXME: Why is uncompressedSize not getting inferred correctly although it
-- is typed? -- is typed?
local decompressed = inflate(buf, uncompressedSize :: any) local decompressed = inflate(buf, uncompressedSize :: any)
validateCrc(decompressed, validation) validateCrc(decompressed, validation)
return decompressed return decompressed
end, end,
}, },
}) })
-- TODO: ERROR HANDLING! -- TODO: ERROR HANDLING!
@@ -70,44 +70,59 @@ type ZipEntryInner = {
method: CompressionMethod, -- Method used to compress the file method: CompressionMethod, -- Method used to compress the file
crc: number, -- CRC32 checksum of uncompressed data crc: number, -- CRC32 checksum of uncompressed data
isDirectory: boolean, -- Whether the entry is a directory or not isDirectory: boolean, -- Whether the entry is a directory or not
isAscii: boolean, -- Whether the entry is plain ASCII text or binary
attributes: number, -- File attributes
parent: ZipEntry?, -- The parent of the current entry, nil for root parent: ZipEntry?, -- The parent of the current entry, nil for root
children: { ZipEntry }, -- The children of the entry children: { ZipEntry }, -- The children of the entry
} }
function ZipEntry.new( type ZipEntryProperties = {
name: string, size: number,
size: number, attributes: number,
offset: number, timestamp: number,
timestamp: number, method: CompressionMethod?,
method: CompressionMethod?, crc: number,
crc: number }
): ZipEntry local EMPTY_PROPERTIES: ZipEntryProperties = table.freeze({
return setmetatable( size = 0,
{ attributes = 0,
name = name, timestamp = 0,
size = size, method = nil,
offset = offset, crc = 0,
timestamp = timestamp, })
method = method,
crc = crc, function ZipEntry.new(offset: number, name: string, properties: ZipEntryProperties): ZipEntry
isDirectory = string.sub(name, -1) == "/", return setmetatable(
parent = nil, {
children = {}, name = name,
} :: ZipEntryInner, size = properties.size,
{ __index = ZipEntry } offset = offset,
) timestamp = properties.timestamp,
method = properties.method,
crc = properties.crc,
isDirectory = string.sub(name, -1) == "/",
attributes = properties.attributes,
parent = nil,
children = {},
} :: ZipEntryInner,
{ __index = ZipEntry }
)
end
function ZipEntry.isSymlink(self: ZipEntry): boolean
return bit32.band(self.attributes, 0xA0000000) == 0xA0000000
end end
function ZipEntry.getPath(self: ZipEntry): string function ZipEntry.getPath(self: ZipEntry): string
local path = self.name local path = self.name
local current = self.parent local current = self.parent
while current and current.name ~= "/" do while current and current.name ~= "/" do
path = current.name .. path path = current.name .. path
current = current.parent current = current.parent
end end
return path return path
end end
local ZipReader = {} local ZipReader = {}
@@ -122,98 +137,110 @@ type ZipReaderInner = {
} }
function ZipReader.new(data): ZipReader function ZipReader.new(data): ZipReader
local root = ZipEntry.new("/", 0, 0, 0, nil, 0) local root = ZipEntry.new(0, "/", EMPTY_PROPERTIES)
root.isDirectory = true root.isDirectory = true
local this = setmetatable( local this = setmetatable(
{ {
data = data, data = data,
entries = {}, entries = {},
directories = {}, directories = {},
root = root, root = root,
} :: ZipReaderInner, } :: ZipReaderInner,
{ __index = ZipReader } { __index = ZipReader }
) )
this:parseCentralDirectory() this:parseCentralDirectory()
this:buildDirectoryTree() this:buildDirectoryTree()
return this return this
end end
function ZipReader.parseCentralDirectory(self: ZipReader): () function ZipReader.parseCentralDirectory(self: ZipReader): ()
-- ZIP files are read from the end, starting with the End of Central Directory record -- ZIP files are read from the end, starting with the End of Central Directory record
-- The EoCD is at least 22 bytes and contains pointers to the rest of the ZIP structure -- The EoCD is at least 22 bytes and contains pointers to the rest of the ZIP structure
local bufSize = buffer.len(self.data) local bufSize = buffer.len(self.data)
-- Start from the minimum possible position of EoCD (22 bytes from end) -- Start from the minimum possible position of EoCD (22 bytes from end)
local minPos = math.max(0, bufSize - (22 + 65535) --[[ max comment size: 64 KiB ]]) local minPos = math.max(0, bufSize - (22 + 65535) --[[ max comment size: 64 KiB ]])
local pos = bufSize - 22 local pos = bufSize - 22
-- Search backwards for the EoCD signature -- Search backwards for the EoCD signature
while pos >= minPos do while pos >= minPos do
if buffer.readu32(self.data, pos) == SIGNATURES.END_OF_CENTRAL_DIR then if buffer.readu32(self.data, pos) == SIGNATURES.END_OF_CENTRAL_DIR then
break break
end end
pos -= 1 pos -= 1
end end
-- Verify we found the signature -- Verify we found the signature
if pos < minPos then if pos < minPos then
error("Could not find End of Central Directory signature") error("Could not find End of Central Directory signature")
end end
-- End of Central Directory format: -- End of Central Directory format:
-- Offset Bytes Description -- Offset Bytes Description
-- 0 4 End of central directory signature -- 0 4 End of central directory signature
-- 4 2 Number of this disk -- 4 2 Number of this disk
-- 6 2 Disk where central directory starts -- 6 2 Disk where central directory starts
-- 8 2 Number of central directory records on this disk -- 8 2 Number of central directory records on this disk
-- 10 2 Total number of central directory records -- 10 2 Total number of central directory records
-- 12 4 Size of central directory (bytes) -- 12 4 Size of central directory (bytes)
-- 16 4 Offset of start of central directory -- 16 4 Offset of start of central directory
-- 20 2 Comment length (n) -- 20 2 Comment length (n)
-- 22 n Comment -- 22 n Comment
local cdOffset = buffer.readu32(self.data, pos + 16) local cdOffset = buffer.readu32(self.data, pos + 16)
local cdEntries = buffer.readu16(self.data, pos + 10) local cdEntries = buffer.readu16(self.data, pos + 10)
local cdCommentLength = buffer.readu16(self.data, pos + 20) local cdCommentLength = buffer.readu16(self.data, pos + 20)
self.comment = buffer.readstring(self.data, pos + 22, cdCommentLength) self.comment = buffer.readstring(self.data, pos + 22, cdCommentLength)
-- Process each entry in the Central Directory -- Process each entry in the Central Directory
pos = cdOffset pos = cdOffset
for i = 1, cdEntries do for i = 1, cdEntries do
-- Central Directory Entry format: -- Central Directory Entry format:
-- Offset Bytes Description -- Offset Bytes Description
-- 0 4 Central directory entry signature -- 0 4 Central directory entry signature
-- 8 2 General purpose bitflags -- 8 2 General purpose bitflags
-- 10 2 Compression method (8 = DEFLATE) -- 10 2 Compression method (8 = DEFLATE)
-- 12 4 Last mod time/date -- 12 4 Last mod time/date
-- 16 4 CRC-32 -- 16 4 CRC-32
-- 24 4 Uncompressed size -- 24 4 Uncompressed size
-- 28 2 File name length (n) -- 28 2 File name length (n)
-- 30 2 Extra field length (m) -- 30 2 Extra field length (m)
-- 32 2 Comment length (k) -- 32 2 Comment length (k)
-- 42 4 Local header offset -- 36 2 Internal file attributes
-- 46 n File name -- 38 4 External file attributes
-- 46+n m Extra field -- 42 4 Local header offset
-- 46+n+m k Comment -- 46 n File name
-- 46+n m Extra field
-- 46+n+m k Comment
local _bitflags = buffer.readu16(self.data, pos + 8) local _bitflags = buffer.readu16(self.data, pos + 8)
local timestamp = buffer.readu32(self.data, pos + 12) local timestamp = buffer.readu32(self.data, pos + 12)
local compressionMethod = buffer.readu16(self.data, pos + 10) local compressionMethod = buffer.readu16(self.data, pos + 10)
local crc = buffer.readu32(self.data, pos + 16) local crc = buffer.readu32(self.data, pos + 16)
local size = buffer.readu32(self.data, pos + 24) local size = buffer.readu32(self.data, pos + 24)
local nameLength = buffer.readu16(self.data, pos + 28) local nameLength = buffer.readu16(self.data, pos + 28)
local extraLength = buffer.readu16(self.data, pos + 30) local extraLength = buffer.readu16(self.data, pos + 30)
local commentLength = buffer.readu16(self.data, pos + 32) local commentLength = buffer.readu16(self.data, pos + 32)
local offset = buffer.readu32(self.data, pos + 42) local internalAttrs = buffer.readu16(self.data, pos + 36)
local name = buffer.readstring(self.data, pos + 46, nameLength) local externalAttrs = buffer.readu32(self.data, pos + 38)
local offset = buffer.readu32(self.data, pos + 42)
local name = buffer.readstring(self.data, pos + 46, nameLength)
local entry = ZipEntry.new(name, size, offset, timestamp, DECOMPRESSION_ROUTINES[compressionMethod].name, crc) table.insert(
table.insert(self.entries, entry) self.entries,
ZipEntry.new(offset, name, {
size = size,
crc = crc,
compressionMethod = DECOMPRESSION_ROUTINES[compressionMethod].name,
timestamp = timestamp,
attributes = externalAttrs,
isAscii = bit32.band(internalAttrs, 0x0001) ~= 0,
})
)
pos = pos + 46 + nameLength + extraLength + commentLength pos = pos + 46 + nameLength + extraLength + commentLength
end end
end end
function ZipReader.buildDirectoryTree(self: ZipReader): () function ZipReader.buildDirectoryTree(self: ZipReader): ()
@@ -253,10 +280,16 @@ function ZipReader.buildDirectoryTree(self: ZipReader): ()
else else
-- Create new directory entry for intermediate paths or undefined -- Create new directory entry for intermediate paths or undefined
-- parent directories in the ZIP -- parent directories in the ZIP
local dir = ZipEntry.new(path .. "/", 0, 0, entry.timestamp, nil, 0) local dir = ZipEntry.new(0, path .. "/", {
dir.isDirectory = true size = 0,
dir.parent = current crc = 0,
self.directories[path] = dir compressionMethod = "STORED",
timestamp = entry.timestamp,
attributes = entry.attributes,
})
dir.isDirectory = true
dir.parent = current
self.directories[path] = dir
end end
-- Track directory in both lookup table and parent's children -- Track directory in both lookup table and parent's children
@@ -276,225 +309,284 @@ function ZipReader.buildDirectoryTree(self: ZipReader): ()
end end
function ZipReader.findEntry(self: ZipReader, path: string): ZipEntry? function ZipReader.findEntry(self: ZipReader, path: string): ZipEntry?
if path == "/" then if path == "/" then
-- If the root directory's entry was requested we do not -- If the root directory's entry was requested we do not
-- need to do any additional work -- need to do any additional work
return self.root return self.root
end end
-- Normalize path by removing leading and trailing slashes -- Normalize path by removing leading and trailing slashes
-- This ensures consistent lookup regardless of input format -- This ensures consistent lookup regardless of input format
-- e.g., "/folder/file.txt/" -> "folder/file.txt" -- e.g., "/folder/file.txt/" -> "folder/file.txt"
path = string.gsub(path, "^/", ""):gsub("/$", "") path = string.gsub(path, "^/", ""):gsub("/$", "")
-- First check regular files and explicit directories -- First check regular files and explicit directories
for _, entry in self.entries do for _, entry in self.entries do
-- Compare normalized paths -- Compare normalized paths
if string.gsub(entry.name, "/$", "") == path then if string.gsub(entry.name, "/$", "") == path then
return entry return entry
end end
end end
-- If not found, check virtual directory entries -- If not found, check virtual directory entries
-- These are directories that were created implicitly -- These are directories that were created implicitly
return self.directories[path] return self.directories[path]
end end
type ExtractionOptions = { type ExtractionOptions = {
decompress: boolean?, followSymlinks: boolean?,
isString: boolean?, decompress: boolean?,
skipCrcValidation: boolean?, isString: boolean?,
skipSizeValidation: boolean?, skipCrcValidation: boolean?,
skipSizeValidation: boolean?,
} }
function ZipReader.extract(self: ZipReader, entry: ZipEntry, options: ExtractionOptions?): buffer | string function ZipReader.extract(self: ZipReader, entry: ZipEntry, options: ExtractionOptions?): buffer | string
-- Local File Header format: -- Local File Header format:
-- Offset Bytes Description -- Offset Bytes Description
-- 0 4 Local file header signature -- 0 4 Local file header signature
-- 6 2 General purpose bitflags -- 6 2 General purpose bitflags
-- 8 2 Compression method (8 = DEFLATE) -- 8 2 Compression method (8 = DEFLATE)
-- 14 4 CRC32 checksum -- 14 4 CRC32 checksum
-- 18 4 Compressed size -- 18 4 Compressed size
-- 22 4 Uncompressed size -- 22 4 Uncompressed size
-- 26 2 File name length (n) -- 26 2 File name length (n)
-- 28 2 Extra field length (m) -- 28 2 Extra field length (m)
-- 30 n File name -- 30 n File name
-- 30+n m Extra field -- 30+n m Extra field
-- 30+n+m - File data -- 30+n+m - File data
if entry.isDirectory then if entry.isDirectory then
error("Cannot extract directory") error("Cannot extract directory")
end end
local defaultOptions: ExtractionOptions = { local defaultOptions: ExtractionOptions = {
decompress = true, followSymlinks = false,
isString = false, decompress = true,
skipValidation = false, isString = entry.isAscii,
} skipValidation = false,
}
-- TODO: Use a `Partial` type function for this in the future! -- TODO: Use a `Partial` type function for this in the future!
local optionsOrDefault: { local optionsOrDefault: {
decompress: boolean, followSymlinks: boolean,
isString: boolean, decompress: boolean,
skipCrcValidation: boolean, isString: boolean,
skipSizeValidation: boolean, skipCrcValidation: boolean,
} = if options skipSizeValidation: boolean,
then setmetatable(options, { __index = defaultOptions }) :: any } = if options
else defaultOptions then setmetatable(options, { __index = defaultOptions }) :: any
else defaultOptions
local pos = entry.offset local pos = entry.offset
if buffer.readu32(self.data, pos) ~= SIGNATURES.LOCAL_FILE then if buffer.readu32(self.data, pos) ~= SIGNATURES.LOCAL_FILE then
error("Invalid local file header") error("Invalid local file header")
end end
local bitflags = buffer.readu16(self.data, pos + 6) local bitflags = buffer.readu16(self.data, pos + 6)
local crcChecksum = buffer.readu32(self.data, pos + 14) local crcChecksum = buffer.readu32(self.data, pos + 14)
local compressedSize = buffer.readu32(self.data, pos + 18) local compressedSize = buffer.readu32(self.data, pos + 18)
local uncompressedSize = buffer.readu32(self.data, pos + 22) local uncompressedSize = buffer.readu32(self.data, pos + 22)
local nameLength = buffer.readu16(self.data, pos + 26) local nameLength = buffer.readu16(self.data, pos + 26)
local extraLength = buffer.readu16(self.data, pos + 28) local extraLength = buffer.readu16(self.data, pos + 28)
pos = pos + 30 + nameLength + extraLength pos = pos + 30 + nameLength + extraLength
if bit32.band(bitflags, 0x08) ~= 0 then if bit32.band(bitflags, 0x08) ~= 0 then
-- The bit at offset 3 was set, meaning we did not have the file sizes -- The bit at offset 3 was set, meaning we did not have the file sizes
-- and CRC checksum at the time of the creation of the ZIP. Instead, they -- and CRC checksum at the time of the creation of the ZIP. Instead, they
-- were appended after the compressed data chunks in a data descriptor -- were appended after the compressed data chunks in a data descriptor
-- Data Descriptor format: -- Data Descriptor format:
-- Offset Bytes Description -- Offset Bytes Description
-- 0 0 or 4 0x08074b50 (optional signature) -- 0 0 or 4 0x08074b50 (optional signature)
-- 0 or 4 4 CRC32 checksum -- 0 or 4 4 CRC32 checksum
-- 4 or 8 4 Compressed size -- 4 or 8 4 Compressed size
-- 8 or 12 4 Uncompressed size -- 8 or 12 4 Uncompressed size
-- Start at the compressed data -- Start at the compressed data
local descriptorPos = pos local descriptorPos = pos
while true do while true do
-- Try reading a u32 starting from current offset -- Try reading a u32 starting from current offset
local leading = buffer.readu32(self.data, descriptorPos) local leading = buffer.readu32(self.data, descriptorPos)
if leading == SIGNATURES.DATA_DESCRIPTOR then if leading == SIGNATURES.DATA_DESCRIPTOR then
-- If we find a data descriptor signature, that must mean -- If we find a data descriptor signature, that must mean
-- the current offset is the start of the descriptor -- the current offset is the start of the descriptor
break break
end end
if leading == entry.crc then if leading == entry.crc then
-- If we find our file's CRC checksum, that means the data -- If we find our file's CRC checksum, that means the data
-- descriptor signature was omitted, so our chunk starts 4 -- descriptor signature was omitted, so our chunk starts 4
-- bytes before -- bytes before
descriptorPos -= 4 descriptorPos -= 4
break break
end end
-- Skip to the next byte -- Skip to the next byte
descriptorPos += 1 descriptorPos += 1
end end
crcChecksum = buffer.readu32(self.data, descriptorPos + 4) crcChecksum = buffer.readu32(self.data, descriptorPos + 4)
compressedSize = buffer.readu32(self.data, descriptorPos + 8) compressedSize = buffer.readu32(self.data, descriptorPos + 8)
uncompressedSize = buffer.readu32(self.data, descriptorPos + 12) uncompressedSize = buffer.readu32(self.data, descriptorPos + 12)
end end
local content = buffer.create(compressedSize) local content = buffer.create(compressedSize)
buffer.copy(content, 0, self.data, pos, compressedSize) buffer.copy(content, 0, self.data, pos, compressedSize)
if optionsOrDefault.decompress then if optionsOrDefault.decompress then
local compressionMethod = buffer.readu16(self.data, entry.offset + 8) local compressionMethod = buffer.readu16(self.data, entry.offset + 8)
local algo = DECOMPRESSION_ROUTINES[compressionMethod] local algo = DECOMPRESSION_ROUTINES[compressionMethod]
if algo == nil then if algo == nil then
error(`Unsupported compression, ID: {compressionMethod}`) error(`Unsupported compression, ID: {compressionMethod}`)
end end
content = algo.decompress(content, uncompressedSize, { if optionsOrDefault.followSymlinks then
expected = crcChecksum, local linkPath = buffer.tostring(algo.decompress(content, 0, {
skip = optionsOrDefault.skipCrcValidation, expected = 0x00000000,
}) skip = true,
}))
-- Unless skipping validation is requested, we make sure the uncompressed size matches --- Canonicalize a path by removing redundant components
assert( local function canonicalize(path: string): string
optionsOrDefault.skipSizeValidation or uncompressedSize == buffer.len(content), -- NOTE: It is fine for us to use `/` here because ZIP file names
"Validation failed; uncompressed size does not match" -- always use `/` as the path separator
) local components = string.split(path, "/")
end local result = {}
for _, component in components do
if component == "." then
-- Skip current directory
continue
end
return if optionsOrDefault.isString then buffer.tostring(content) else content if component == ".." then
-- Traverse one upwards
table.remove(result, #result)
continue
end
-- Otherwise, add the component to the result
table.insert(result, component)
end
return table.concat(result, "/")
end
-- Check if the path was a relative path
if
not (
string.match(linkPath, "^/")
or string.match(linkPath, "^[a-zA-Z]:[\\/]")
or string.match(linkPath, "^//")
)
then
if string.sub(linkPath, -1) ~= "/" then
linkPath ..= "/"
end
linkPath = canonicalize(`{(entry.parent or self.root).name}{linkPath}`)
end
optionsOrDefault.followSymlinks = false
optionsOrDefault.isString = false
optionsOrDefault.skipCrcValidation = true
optionsOrDefault.skipSizeValidation = true
content = self:extract(
self:findEntry(linkPath) or error("Symlink path not found"),
optionsOrDefault
) :: buffer
end
content = algo.decompress(content, uncompressedSize, {
expected = crcChecksum,
skip = optionsOrDefault.skipCrcValidation,
})
-- Unless skipping validation is requested, we make sure the uncompressed size matches
assert(
optionsOrDefault.skipSizeValidation or uncompressedSize == buffer.len(content),
"Validation failed; uncompressed size does not match"
)
end
return if optionsOrDefault.isString then buffer.tostring(content) else content
end end
function ZipReader.extractDirectory( function ZipReader.extractDirectory(
self: ZipReader, self: ZipReader,
path: string, path: string,
options: ExtractionOptions options: ExtractionOptions
): { [string]: buffer } | { [string]: string } ): { [string]: buffer } | { [string]: string }
local files: { [string]: buffer } | { [string]: string } = {} local files: { [string]: buffer } | { [string]: string } = {}
-- Normalize path by removing leading slash for consistent prefix matching -- Normalize path by removing leading slash for consistent prefix matching
path = string.gsub(path, "^/", "") path = string.gsub(path, "^/", "")
-- Iterate through all entries to find files within target directory -- Iterate through all entries to find files within target directory
for _, entry in self.entries do for _, entry in self.entries do
-- Check if entry is a file (not directory) and its path starts with target directory -- Check if entry is a file (not directory) and its path starts with target directory
if not entry.isDirectory and string.sub(entry.name, 1, #path) == path then if not entry.isDirectory and string.sub(entry.name, 1, #path) == path then
-- Store extracted content mapped to full path -- Store extracted content mapped to full path
files[entry.name] = self:extract(entry, options) files[entry.name] = self:extract(entry, options)
end end
end end
-- Return a map of file to contents -- Return a map of file to contents
return files return files
end end
function ZipReader.listDirectory(self: ZipReader, path: string): { ZipEntry } function ZipReader.listDirectory(self: ZipReader, path: string): { ZipEntry }
-- Locate the entry with the path -- Locate the entry with the path
local entry = self:findEntry(path) local entry = self:findEntry(path)
if not entry or not entry.isDirectory then if not entry or not entry.isDirectory then
-- If an entry was not found, we error -- If an entry was not found, we error
error("Not a directory") error("Not a directory")
end end
-- Return the children of our discovered entry -- Return the children of our discovered entry
return entry.children return entry.children
end end
function ZipReader.walk(self: ZipReader, callback: (entry: ZipEntry, depth: number) -> ()): () function ZipReader.walk(self: ZipReader, callback: (entry: ZipEntry, depth: number) -> ()): ()
-- Wrapper function which recursively calls callback for every child -- Wrapper function which recursively calls callback for every child
-- in an entry -- in an entry
local function walkEntry(entry: ZipEntry, depth: number) local function walkEntry(entry: ZipEntry, depth: number)
callback(entry, depth) callback(entry, depth)
for _, child in entry.children do for _, child in entry.children do
-- ooo spooky recursion... blame this if shit go wrong -- ooo spooky recursion... blame this if shit go wrong
walkEntry(child, depth + 1) walkEntry(child, depth + 1)
end end
end end
walkEntry(self.root, 0) walkEntry(self.root, 0)
end end
export type ZipStatistics = { fileCount: number, dirCount: number, totalSize: number } export type ZipStatistics = { fileCount: number, dirCount: number, totalSize: number }
function ZipReader.getStats(self: ZipReader): ZipStatistics function ZipReader.getStats(self: ZipReader): ZipStatistics
local stats: ZipStatistics = { local stats: ZipStatistics = {
fileCount = 0, fileCount = 0,
dirCount = 0, dirCount = 0,
totalSize = 0, totalSize = 0,
} }
-- Iterate through the entries, updating stats -- Iterate through the entries, updating stats
for _, entry in self.entries do for _, entry in self.entries do
if entry.isDirectory then if entry.isDirectory then
stats.dirCount += 1 stats.dirCount += 1
continue continue
end end
stats.fileCount += 1 stats.fileCount += 1
stats.totalSize += entry.size stats.totalSize += entry.size
end end
return stats return stats
end end
return { return {
-- Creates a `ZipReader` from a `buffer` of ZIP data. -- Creates a `ZipReader` from a `buffer` of ZIP data.
load = function(data: buffer) load = function(data: buffer)
return ZipReader.new(data) return ZipReader.new(data)
end, end,
} }

View file

@@ -1,4 +1,6 @@
local fs = require("@lune/fs") local fs = require("@lune/fs")
local process = require("@lune/process")
local serde = require("@lune/serde")
local frktest = require("../lune_packages/frktest") local frktest = require("../lune_packages/frktest")
local check = frktest.assert.check local check = frktest.assert.check
@@ -8,10 +10,31 @@ local ZipReader = require("../lib")
return function(test: typeof(frktest.test)) return function(test: typeof(frktest.test))
test.suite("Edge case tests", function() test.suite("Edge case tests", function()
test.case("Handles misaligned comment properly", function() test.case("Handles misaligned comment properly", function()
local data = fs.readFile("tests/data/misaligned_comment.zip") local data = fs.readFile("tests/data/misaligned_comment.zip")
local zip = ZipReader.load(buffer.fromstring(data)) local zip = ZipReader.load(buffer.fromstring(data))
check.equal(zip.comment, "short.") check.equal(zip.comment, "short.")
end)
end) end)
test.case("Follows symlinks correctly", function()
-- TODO: More test files with symlinks
local data = fs.readFile("tests/data/pandoc_soft_links.zip")
local zip = ZipReader.load(buffer.fromstring(data))
local entry = assert(zip:findEntry("/pandoc-3.2-arm64/bin/pandoc-lua"))
assert(entry:isSymlink(), "Entry type must be a symlink")
local targetPath = zip:extract(entry, { isString = true }) :: string
check.equal(targetPath, "pandoc")
local bin = zip:extract(entry, { isString = false, followSymlinks = true }) :: buffer
local expectedBin = process.spawn("unzip", { "-p", "tests/data/pandoc_soft_links.zip", "pandoc-3.2-arm64/bin/pandoc" })
check.is_true(expectedBin.ok)
-- Compare hashes instead of the entire binary to improve speed and not print out
-- the entire binary data in case there's a mismatch
check.equal(serde.hash("blake3", bin), serde.hash("blake3", expectedBin.stdout))
end)
end)
end end