lune-packaging/src/lib/globals/require.rs

73 lines
3.5 KiB
Rust
Raw Normal View History

use std::{env, path::PathBuf, sync::Arc};
use mlua::prelude::*;
use os_str_bytes::RawOsStr;
/// Replaces the global `require` with a version that resolves module paths
/// relative to the directory of the script that calls it.
///
/// The original `require` and `debug.info` functions are stored in the Lua
/// registry (under `"require_original"` and `"require_getinfo"`). The new
/// `require` joins the requested path onto the parent directory of the calling
/// script's source, then delegates to the original `require` with that path.
///
/// If the `LUAU_PWD_REQUIRE` environment variable is set, nothing is replaced
/// and the original `require` behavior is preserved.
///
/// # Errors
///
/// Returns an error if `debug`, `debug.info` or the global `require` cannot be
/// fetched, or if storing them in the registry or creating the new function
/// fails. Errors raised by the new `require` at call time (including when the
/// caller's directory cannot be determined) are reported to Lua, not returned
/// from this function.
pub fn create(lua: &Lua) -> LuaResult<()> {
    // Preserve original require behavior if we have a special env var set
    if env::var_os("LUAU_PWD_REQUIRE").is_some() {
        return Ok(());
    }
    // Fetch the debug info function and store it in the registry
    // - we will use it to fetch the current script's file name
    let debug: LuaTable = lua.globals().raw_get("debug")?;
    let info: LuaFunction = debug.raw_get("info")?;
    lua.set_named_registry_value("require_getinfo", info)?;
    // Fetch the original require function and store it in the registry
    let require: LuaFunction = lua.globals().raw_get("require")?;
    lua.set_named_registry_value("require_original", require)?;
    /*
        Create a new function that fetches the file name from the current thread,
        sets the luau module lookup path to be the exact script we are looking
        for, and then runs the original require function with the wanted path
    */
    let new_require = lua.create_function(|lua, require_path: LuaString| {
        let require_original: LuaFunction = lua.named_registry_value("require_original")?;
        let require_getinfo: LuaFunction = lua.named_registry_value("require_getinfo")?;
        // Level 1 is this callback itself, so level 2 is the script calling require;
        // the "s" option asks for that caller's source
        let require_source: LuaString = require_getinfo.call((2, "s"))?;
        /*
            Combine the require caller source with the wanted path
            string to get a final path relative to pwd - it is definitely
            relative to pwd because Lune will only load files relative to pwd
            Here we also take extra care to not perform any lossy conversion
            and use os strings instead of Rust's utf-8 checked strings, in the
            unlikely case someone out there uses luau with non-utf8 string requires
        */
        let raw_source = RawOsStr::assert_from_raw_bytes(require_source.as_bytes());
        let raw_path = RawOsStr::assert_from_raw_bytes(require_path.as_bytes());
        let source_path = PathBuf::from(&raw_source.to_os_str());
        // A source such as "" or "/" has no parent directory - report that as a
        // Lua error instead of panicking inside the callback
        let source_dir = source_path.parent().ok_or_else(|| {
            LuaError::RuntimeError(format!(
                "Invalid require path '{}' (could not determine the directory of '{}')",
                require_path.to_string_lossy(),
                source_path.display()
            ))
        })?;
        let path_relative_to_pwd = source_dir.join(raw_path.to_os_str());
        // Create a lossless lua string from the pathbuf and finally call require;
        // `?` propagates string creation failures instead of passing them to require
        let raw_path_str = RawOsStr::new(path_relative_to_pwd.as_os_str());
        let lua_path_str = lua.create_string(raw_path_str.as_raw_bytes())?;
        // If the require call errors then we should also replace
        // the path in the error message to improve user experience
        require_original
            .call::<_, LuaValue>(lua_path_str)
            .map_err(|err| match err {
                LuaError::CallbackError { traceback, cause } => {
                    // Lossy conversions are only used to build message text, so
                    // non-utf8 paths can no longer cause a panic here
                    let before = format!(
                        "runtime error: cannot find '{}'",
                        path_relative_to_pwd.to_string_lossy()
                    );
                    let after = format!(
                        "Invalid require path '{}' ({})",
                        require_path.to_string_lossy(),
                        path_relative_to_pwd.to_string_lossy()
                    );
                    let cause = Arc::new(LuaError::RuntimeError(
                        cause.to_string().replace(&before, &after),
                    ));
                    LuaError::CallbackError { traceback, cause }
                }
                other => other,
            })
    })?;
    // Override the original require global with our monkey-patched one
    lua.globals().raw_set("require", new_require)?;
    Ok(())
}