From d3b9a4b9e8ed489fab5af32d2fab83f9bf46db23 Mon Sep 17 00:00:00 2001 From: Filip Tibell Date: Wed, 5 Jun 2024 19:02:48 +0200 Subject: [PATCH] Add new options for global injection and codegen to luau.load --- crates/lune-std-luau/src/lib.rs | 41 +++++++++---- crates/lune-std-luau/src/options.rs | 18 ++++-- crates/lune/src/tests.rs | 1 + tests/luau/load.luau | 94 +++++++++++++++++++++++------ tests/luau/safeenv.luau | 64 ++++++++++++++++++++ types/luau.luau | 6 +- 6 files changed, 188 insertions(+), 36 deletions(-) create mode 100644 tests/luau/safeenv.luau diff --git a/crates/lune-std-luau/src/lib.rs b/crates/lune-std-luau/src/lib.rs index e41eed5..21eb912 100644 --- a/crates/lune-std-luau/src/lib.rs +++ b/crates/lune-std-luau/src/lib.rs @@ -44,26 +44,41 @@ fn load_source<'lua>( (source, options): (LuaString<'lua>, LuauLoadOptions), ) -> LuaResult> { let mut chunk = lua.load(source.as_bytes()).set_name(options.debug_name); + let env_changed = options.environment.is_some(); - if let Some(environment) = options.environment { - let environment_with_globals = lua.create_table()?; + if let Some(custom_environment) = options.environment { + let environment = lua.create_table()?; - if let Some(meta) = environment.get_metatable() { - environment_with_globals.set_metatable(Some(meta)); + // Inject all globals into the environment + if options.inject_globals { + for pair in lua.globals().pairs() { + let (key, value): (LuaValue, LuaValue) = pair?; + environment.set(key, value)?; + } + + if let Some(global_metatable) = lua.globals().get_metatable() { + environment.set_metatable(Some(global_metatable)); + } + } else if let Some(custom_metatable) = custom_environment.get_metatable() { + // Since we don't need to set the global metatable, + // we can just set a custom metatable if it exists + environment.set_metatable(Some(custom_metatable)); } - for pair in lua.globals().pairs() { + // Inject the custom environment + for pair in custom_environment.pairs() { let (key, value): (LuaValue, LuaValue) = pair?; - environment_with_globals.set(key, value)?; + environment.set(key, value)?; } - for pair in environment.pairs() { - let (key, value): (LuaValue, LuaValue) = pair?; - environment_with_globals.set(key, value)?; - } - - chunk = chunk.set_environment(environment_with_globals); + chunk = chunk.set_environment(environment); } - chunk.into_function() + // Enable JIT if codegen is enabled and the environment hasn't + // changed, otherwise disable JIT since it'll fall back anyways + lua.enable_jit(options.codegen_enabled && !env_changed); + let function = chunk.into_function()?; + lua.enable_jit(true); + + Ok(function) } diff --git a/crates/lune-std-luau/src/options.rs b/crates/lune-std-luau/src/options.rs index a2040ec..81b8ac0 100644 --- a/crates/lune-std-luau/src/options.rs +++ b/crates/lune-std-luau/src/options.rs @@ -79,13 +79,11 @@ impl<'lua> FromLua<'lua> for LuauCompileOptions { } } -/** - Options for loading Lua source code. -*/ -#[derive(Debug, Clone)] pub struct LuauLoadOptions<'lua> { pub(crate) debug_name: String, pub(crate) environment: Option>, + pub(crate) inject_globals: bool, + pub(crate) codegen_enabled: bool, } impl Default for LuauLoadOptions<'_> { @@ -93,6 +91,8 @@ impl Default for LuauLoadOptions<'_> { Self { debug_name: DEFAULT_DEBUG_NAME.to_string(), environment: None, + inject_globals: true, + codegen_enabled: false, } } } @@ -112,11 +112,21 @@ impl<'lua> FromLua<'lua> for LuauLoadOptions<'lua> { options.environment = Some(environment); } + if let Some(inject_globals) = t.get("injectGlobals")? { + options.inject_globals = inject_globals; + } + + if let Some(codegen_enabled) = t.get("codegenEnabled")? { + options.codegen_enabled = codegen_enabled; + } + options } LuaValue::String(s) => Self { debug_name: s.to_string_lossy().to_string(), environment: None, + inject_globals: true, + codegen_enabled: false, }, _ => { return Err(LuaError::FromLuaConversionError { diff --git a/crates/lune/src/tests.rs b/crates/lune/src/tests.rs index 0306b29..2e866dc 100644 --- a/crates/lune/src/tests.rs +++ b/crates/lune/src/tests.rs @@ -113,6 +113,7 @@ create_tests! { luau_compile: "luau/compile", luau_load: "luau/load", luau_options: "luau/options", + luau_safeenv: "luau/safeenv", } #[cfg(feature = "std-net")] diff --git a/tests/luau/load.luau b/tests/luau/load.luau index 7701270..1cff3f8 100644 --- a/tests/luau/load.luau +++ b/tests/luau/load.luau @@ -26,11 +26,11 @@ assert( "expected source block name for 'luau.load' to return a custom debug name" ) -local success = pcall(function() +local loadSuccess = pcall(function() luau.load(luau.compile(RETURN_LUAU_CODE_BLOCK)) end) -assert(success, "expected `luau.load` to be able to process the result of `luau.compile`") +assert(loadSuccess, "expected `luau.load` to be able to process the result of `luau.compile`") local CUSTOM_SOURCE_WITH_FOO_FN = "return foo()" @@ -48,34 +48,92 @@ local fooFn = luau.load(CUSTOM_SOURCE_WITH_FOO_FN, { local fooFnRet = fooFn() assert(fooFnRet == fooValue, "expected `luau.load` with custom environment to return proper values") -local CUSTOM_SOURCE_WITH_PRINT_FN = "return print()" - --- NOTE: Same as what we did above, new userdata to guarantee unique-ness -local overriddenValue = newproxy(false) -local overriddenFn = luau.load(CUSTOM_SOURCE_WITH_PRINT_FN, { +local fooValue2 = newproxy(false) +local fooFn2 = luau.load(CUSTOM_SOURCE_WITH_FOO_FN, { environment = { - print = function() - return overriddenValue + foo = function() + return fooValue2 end, }, + enableGlobals = false, }) -local overriddenFnRet = overriddenFn() +local fooFn2Ret = fooFn2() assert( - overriddenFnRet == overriddenValue, + fooFn2Ret == fooValue2, + "expected `luau.load` with custom environment and no default globals to still return proper values" +) + +local CUSTOM_SOURCE_WITH_PRINT_FN = "return print()" + +-- NOTE: Testing overriding the print function +local overriddenPrintValue1 = newproxy(false) +local overriddenPrintFn1 = luau.load(CUSTOM_SOURCE_WITH_PRINT_FN, { + environment = { + print = function() + return overriddenPrintValue1 + end, + }, + enableGlobals = true, +}) + +local overriddenPrintFnRet1 = overriddenPrintFn1() +assert( + overriddenPrintFnRet1 == overriddenPrintValue1, "expected `luau.load` with overridden environment to return proper values" ) -local CUSTOM_SOURCE_WITH_DEFAULT_FN = "return string.lower(...)" - -local overriddenFn2 = luau.load(CUSTOM_SOURCE_WITH_DEFAULT_FN, { +local overriddenPrintValue2 = newproxy(false) +local overriddenPrintFn2 = luau.load(CUSTOM_SOURCE_WITH_PRINT_FN, { environment = { - hello = "world", + print = function() + return overriddenPrintValue2 + end, }, + enableGlobals = false, }) -local overriddenFn2Ret = overriddenFn2("LOWERCASE") +local overriddenPrintFnRet2 = overriddenPrintFn2() assert( - overriddenFn2Ret == "lowercase", - "expected `luau.load` with overridden environment to contain default globals" + overriddenPrintFnRet2 == overriddenPrintValue2, + "expected `luau.load` with overridden environment and disabled default globals to return proper values" +) + +-- NOTE: Testing whether injectGlobals works +local CUSTOM_SOURCE_WITH_DEFAULT_FN = "return string.lower(...)" + +local lowerFn1 = luau.load(CUSTOM_SOURCE_WITH_DEFAULT_FN, { + environment = {}, + injectGlobals = false, +}) + +local lowerFn1Success = pcall(lowerFn1, "LOWERCASE") + +assert( + not lowerFn1Success, + "expected `luau.load` with injectGlobals = false and empty custom environment to not contain default globals" +) + +local lowerFn2 = luau.load(CUSTOM_SOURCE_WITH_DEFAULT_FN, { + environment = { string = string }, + injectGlobals = false, +}) + +local lowerFn2Success, lowerFn2Result = pcall(lowerFn2, "LOWERCASE") + +assert( + lowerFn2Success and lowerFn2Result == "lowercase", + "expected `luau.load` with injectGlobals = false and valid custom environment to return proper values" +) + +local lowerFn3 = luau.load(CUSTOM_SOURCE_WITH_DEFAULT_FN, { + environment = {}, + injectGlobals = true, +}) + +local lowerFn3Success, lowerFn3Result = pcall(lowerFn3, "LOWERCASE") + +assert( + lowerFn3Success and lowerFn3Result == "lowercase", + "expected `luau.load` with injectGlobals = true and empty custom environment to return proper values" ) diff --git a/tests/luau/safeenv.luau b/tests/luau/safeenv.luau new file mode 100644 index 0000000..f1a9d06 --- /dev/null +++ b/tests/luau/safeenv.luau @@ -0,0 +1,64 @@ +local luau = require("@lune/luau") + +local TEST_SCRIPT = [[ + local start = os.clock() + local x + for i = 1, 1e6 do + x = math.sqrt(i) + end + local finish = os.clock() + + return finish - start +]] + +local TEST_BYTECODE = luau.compile(TEST_SCRIPT, { + optimizationLevel = 2, + coverageLevel = 0, + debugLevel = 0, +}) + +-- Load the bytecode with different configurations +local safeCodegenFunction = luau.load(TEST_BYTECODE, { + debugName = "safeCodegenFunction", + codegenEnabled = true, +}) +local unsafeCodegenFunction = luau.load(TEST_BYTECODE, { + debugName = "unsafeCodegenFunction", + codegenEnabled = true, + environment = {}, + injectGlobals = true, +}) +local safeFunction = luau.load(TEST_BYTECODE, { + debugName = "safeFunction", + codegenEnabled = false, +}) +local unsafeFunction = luau.load(TEST_BYTECODE, { + debugName = "unsafeFunction", + codegenEnabled = false, + environment = {}, + injectGlobals = true, +}) + +-- Run the functions to get the timings +local safeCodegenTime = safeCodegenFunction() +local unsafeCodegenTime = unsafeCodegenFunction() +local safeTime = safeFunction() +local unsafeTime = unsafeFunction() + +-- Assert that safeCodegenTime is always twice as fast as both unsafe functions +local safeCodegenUpperBound = safeCodegenTime * 2 +assert( + unsafeCodegenTime > safeCodegenUpperBound and unsafeTime > safeCodegenUpperBound, + "expected luau.load with codegenEnabled = true and no custom environment to use codegen" +) + +-- Assert that safeTime is always atleast twice as fast as both unsafe functions +local safeUpperBound = safeTime * 2 +assert( + unsafeCodegenTime > safeUpperBound and unsafeTime > safeUpperBound, + "expected luau.load with codegenEnabled = false and no custom environment to have safeenv enabled" +) + +-- Normally we'd also want to check whether codegen is actually being enabled by +-- comparing timings of safe_codegen_fn and safe_fn but since we don't have a way of +-- checking whether the current device even supports codegen, we can't safely test this. diff --git a/types/luau.luau b/types/luau.luau index a810ee5..379c07d 100644 --- a/types/luau.luau +++ b/types/luau.luau @@ -27,11 +27,15 @@ export type CompileOptions = { This is a dictionary that may contain one or more of the following values: * `debugName` - The debug name of the closure. Defaults to `luau.load(...)`. - * `environment` - Environment values to set and/or override. Includes default globals unless overwritten. + * `environment` - A custom environment to load the chunk in. Setting a custom environment will deoptimize the chunk and forcefully disable codegen. Defaults to the global environment. + * `injectGlobals` - Whether or not to inject globals in the custom environment. Has no effect if no custom environment is provided. Defaults to `true`. + * `codegenEnabled` - Whether or not to enable codegen. Defaults to `true`. ]=] export type LoadOptions = { debugName: string?, environment: { [string]: any }?, + injectGlobals: boolean?, + codegenEnabled: boolean?, } --[=[