Add new options for global injection and codegen to luau.load

This commit is contained in:
Filip Tibell 2024-06-05 19:02:48 +02:00
parent 3cf2be51bc
commit d3b9a4b9e8
No known key found for this signature in database
6 changed files with 188 additions and 36 deletions

View file

@ -44,26 +44,41 @@ fn load_source<'lua>(
(source, options): (LuaString<'lua>, LuauLoadOptions),
) -> LuaResult<LuaFunction<'lua>> {
let mut chunk = lua.load(source.as_bytes()).set_name(options.debug_name);
let env_changed = options.environment.is_some();
if let Some(environment) = options.environment {
let environment_with_globals = lua.create_table()?;
if let Some(custom_environment) = options.environment {
let environment = lua.create_table()?;
if let Some(meta) = environment.get_metatable() {
environment_with_globals.set_metatable(Some(meta));
// Inject all globals into the environment
if options.inject_globals {
for pair in lua.globals().pairs() {
let (key, value): (LuaValue, LuaValue) = pair?;
environment.set(key, value)?;
}
if let Some(global_metatable) = lua.globals().get_metatable() {
environment.set_metatable(Some(global_metatable));
}
} else if let Some(custom_metatable) = custom_environment.get_metatable() {
// Since we don't need to set the global metatable,
// we can just set a custom metatable if it exists
environment.set_metatable(Some(custom_metatable));
}
for pair in lua.globals().pairs() {
// Inject the custom environment
for pair in custom_environment.pairs() {
let (key, value): (LuaValue, LuaValue) = pair?;
environment_with_globals.set(key, value)?;
environment.set(key, value)?;
}
for pair in environment.pairs() {
let (key, value): (LuaValue, LuaValue) = pair?;
environment_with_globals.set(key, value)?;
}
chunk = chunk.set_environment(environment_with_globals);
chunk = chunk.set_environment(environment);
}
chunk.into_function()
// Enable JIT if codegen is enabled and the environment hasn't
// changed, otherwise disable JIT since it'll fall back anyways
lua.enable_jit(options.codegen_enabled && !env_changed);
let function = chunk.into_function()?;
lua.enable_jit(true);
Ok(function)
}

View file

@ -79,13 +79,11 @@ impl<'lua> FromLua<'lua> for LuauCompileOptions {
}
}
/**
Options for loading Lua source code.
*/
#[derive(Debug, Clone)]
pub struct LuauLoadOptions<'lua> {
pub(crate) debug_name: String,
pub(crate) environment: Option<LuaTable<'lua>>,
pub(crate) inject_globals: bool,
pub(crate) codegen_enabled: bool,
}
impl Default for LuauLoadOptions<'_> {
@ -93,6 +91,8 @@ impl Default for LuauLoadOptions<'_> {
Self {
debug_name: DEFAULT_DEBUG_NAME.to_string(),
environment: None,
inject_globals: true,
codegen_enabled: false,
}
}
}
@ -112,11 +112,21 @@ impl<'lua> FromLua<'lua> for LuauLoadOptions<'lua> {
options.environment = Some(environment);
}
if let Some(inject_globals) = t.get("injectGlobals")? {
options.inject_globals = inject_globals;
}
if let Some(codegen_enabled) = t.get("codegenEnabled")? {
options.codegen_enabled = codegen_enabled;
}
options
}
LuaValue::String(s) => Self {
debug_name: s.to_string_lossy().to_string(),
environment: None,
inject_globals: true,
codegen_enabled: false,
},
_ => {
return Err(LuaError::FromLuaConversionError {

View file

@ -113,6 +113,7 @@ create_tests! {
luau_compile: "luau/compile",
luau_load: "luau/load",
luau_options: "luau/options",
luau_safeenv: "luau/safeenv",
}
#[cfg(feature = "std-net")]

View file

@ -26,11 +26,11 @@ assert(
"expected source block name for 'luau.load' to return a custom debug name"
)
local success = pcall(function()
local loadSuccess = pcall(function()
luau.load(luau.compile(RETURN_LUAU_CODE_BLOCK))
end)
assert(success, "expected `luau.load` to be able to process the result of `luau.compile`")
assert(loadSuccess, "expected `luau.load` to be able to process the result of `luau.compile`")
local CUSTOM_SOURCE_WITH_FOO_FN = "return foo()"
@ -48,34 +48,92 @@ local fooFn = luau.load(CUSTOM_SOURCE_WITH_FOO_FN, {
local fooFnRet = fooFn()
assert(fooFnRet == fooValue, "expected `luau.load` with custom environment to return proper values")
local CUSTOM_SOURCE_WITH_PRINT_FN = "return print()"
-- NOTE: Same as what we did above, new userdata to guarantee unique-ness
local overriddenValue = newproxy(false)
local overriddenFn = luau.load(CUSTOM_SOURCE_WITH_PRINT_FN, {
local fooValue2 = newproxy(false)
local fooFn2 = luau.load(CUSTOM_SOURCE_WITH_FOO_FN, {
environment = {
print = function()
return overriddenValue
foo = function()
return fooValue2
end,
},
enableGlobals = false,
})
local overriddenFnRet = overriddenFn()
local fooFn2Ret = fooFn2()
assert(
overriddenFnRet == overriddenValue,
fooFn2Ret == fooValue2,
"expected `luau.load` with custom environment and no default globals to still return proper values"
)
local CUSTOM_SOURCE_WITH_PRINT_FN = "return print()"
-- NOTE: Testing overriding the print function
local overriddenPrintValue1 = newproxy(false)
local overriddenPrintFn1 = luau.load(CUSTOM_SOURCE_WITH_PRINT_FN, {
environment = {
print = function()
return overriddenPrintValue1
end,
},
enableGlobals = true,
})
local overriddenPrintFnRet1 = overriddenPrintFn1()
assert(
overriddenPrintFnRet1 == overriddenPrintValue1,
"expected `luau.load` with overridden environment to return proper values"
)
local CUSTOM_SOURCE_WITH_DEFAULT_FN = "return string.lower(...)"
local overriddenFn2 = luau.load(CUSTOM_SOURCE_WITH_DEFAULT_FN, {
local overriddenPrintValue2 = newproxy(false)
local overriddenPrintFn2 = luau.load(CUSTOM_SOURCE_WITH_PRINT_FN, {
environment = {
hello = "world",
print = function()
return overriddenPrintValue2
end,
},
enableGlobals = false,
})
local overriddenFn2Ret = overriddenFn2("LOWERCASE")
local overriddenPrintFnRet2 = overriddenPrintFn2()
assert(
overriddenFn2Ret == "lowercase",
"expected `luau.load` with overridden environment to contain default globals"
overriddenPrintFnRet2 == overriddenPrintValue2,
"expected `luau.load` with overridden environment and disabled default globals to return proper values"
)
-- NOTE: Testing whether injectGlobals works
local CUSTOM_SOURCE_WITH_DEFAULT_FN = "return string.lower(...)"
local lowerFn1 = luau.load(CUSTOM_SOURCE_WITH_DEFAULT_FN, {
environment = {},
injectGlobals = false,
})
local lowerFn1Success = pcall(lowerFn1, "LOWERCASE")
assert(
not lowerFn1Success,
"expected `luau.load` with injectGlobals = false and empty custom environment to not contain default globals"
)
local lowerFn2 = luau.load(CUSTOM_SOURCE_WITH_DEFAULT_FN, {
environment = { string = string },
injectGlobals = false,
})
local lowerFn2Success, lowerFn2Result = pcall(lowerFn2, "LOWERCASE")
assert(
lowerFn2Success and lowerFn2Result == "lowercase",
"expected `luau.load` with injectGlobals = false and valid custom environment to return proper values"
)
local lowerFn3 = luau.load(CUSTOM_SOURCE_WITH_DEFAULT_FN, {
environment = {},
injectGlobals = true,
})
local lowerFn3Success, lowerFn3Result = pcall(lowerFn3, "LOWERCASE")
assert(
lowerFn3Success and lowerFn3Result == "lowercase",
"expected `luau.load` with injectGlobals = true and empty custom environment to return proper values"
)

64
tests/luau/safeenv.luau Normal file
View file

@ -0,0 +1,64 @@
local luau = require("@lune/luau")
local TEST_SCRIPT = [[
local start = os.clock()
local x
for i = 1, 1e6 do
x = math.sqrt(i)
end
local finish = os.clock()
return finish - start
]]
local TEST_BYTECODE = luau.compile(TEST_SCRIPT, {
optimizationLevel = 2,
coverageLevel = 0,
debugLevel = 0,
})
-- Load the bytecode with different configurations
local safeCodegenFunction = luau.load(TEST_BYTECODE, {
debugName = "safeCodegenFunction",
codegenEnabled = true,
})
local unsafeCodegenFunction = luau.load(TEST_BYTECODE, {
debugName = "unsafeCodegenFunction",
codegenEnabled = true,
environment = {},
injectGlobals = true,
})
local safeFunction = luau.load(TEST_BYTECODE, {
debugName = "safeFunction",
codegenEnabled = false,
})
local unsafeFunction = luau.load(TEST_BYTECODE, {
debugName = "unsafeFunction",
codegenEnabled = false,
environment = {},
injectGlobals = true,
})
-- Run the functions to get the timings
local safeCodegenTime = safeCodegenFunction()
local unsafeCodegenTime = unsafeCodegenFunction()
local safeTime = safeFunction()
local unsafeTime = unsafeFunction()
-- Assert that safeCodegenTime is always twice as fast as both unsafe functions
local safeCodegenUpperBound = safeCodegenTime * 2
assert(
unsafeCodegenTime > safeCodegenUpperBound and unsafeTime > safeCodegenUpperBound,
"expected luau.load with codegenEnabled = true and no custom environment to use codegen"
)
-- Assert that safeTime is always atleast twice as fast as both unsafe functions
local safeUpperBound = safeTime * 2
assert(
unsafeCodegenTime > safeUpperBound and unsafeTime > safeUpperBound,
"expected luau.load with codegenEnabled = false and no custom environment to have safeenv enabled"
)
-- Normally we'd also want to check whether codegen is actually being enabled by
-- comparing timings of safe_codegen_fn and safe_fn but since we don't have a way of
-- checking whether the current device even supports codegen, we can't safely test this.

View file

@ -27,11 +27,15 @@ export type CompileOptions = {
This is a dictionary that may contain one or more of the following values:
* `debugName` - The debug name of the closure. Defaults to `luau.load(...)`.
* `environment` - Environment values to set and/or override. Includes default globals unless overwritten.
* `environment` - A custom environment to load the chunk in. Setting a custom environment will deoptimize the chunk and forcefully disable codegen. Defaults to the global environment.
* `injectGlobals` - Whether or not to inject globals in the custom environment. Has no effect if no custom environment is provided. Defaults to `true`.
* `codegenEnabled` - Whether or not to enable codegen. Defaults to `true`.
]=]
export type LoadOptions = {
debugName: string?,
environment: { [string]: any }?,
injectGlobals: boolean?,
codegenEnabled: boolean?,
}
--[=[