lune/packages/lib/src/globals/require.rs

96 lines
4.4 KiB
Rust

use std::{
env::{self, current_dir},
path::PathBuf,
sync::Arc,
};
use mlua::prelude::*;
use os_str_bytes::{OsStrBytes, RawOsStr};
use crate::utils::table::TableBuilder;
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
// Preserve original require behavior if we have a special env var set
if env::var_os("LUAU_PWD_REQUIRE").is_some() {
// Return an empty table since there are no globals to overwrite
return TableBuilder::new(lua)?.build_readonly();
}
/*
Store the current working directory so that we can use it later
and remove it from require paths in error messages, showing
absolute paths is bad ux and we should try to avoid it
Throughout this function we also take extra care to not perform any lossy
conversion and use os strings instead of Rust's utf-8 checked strings,
just in case someone out there uses luau with non-utf8 string requires
*/
let pwd = lua.create_string(&current_dir()?.to_raw_bytes())?;
lua.set_named_registry_value("pwd", pwd)?;
/*
Create a new function that fetches the file name from the current thread,
sets the luau module lookup path to be the exact script we are looking
for, and then runs the original require function with the wanted path
*/
let new_require = lua.create_function(|lua, require_path: LuaString| {
let require_pwd: LuaString = lua.named_registry_value("pwd")?;
let require_fn: LuaFunction = lua.named_registry_value("require")?;
let require_info: LuaFunction = lua.named_registry_value("dbg.info")?;
let require_source: LuaString = require_info.call((2, "s"))?;
/*
Combine the require caller source with the wanted path
string to get a final path relative to pwd - it is definitely
relative to pwd because Lune will only load files relative to pwd
*/
let raw_pwd_str = RawOsStr::assert_from_raw_bytes(require_pwd.as_bytes());
let raw_source = RawOsStr::assert_from_raw_bytes(require_source.as_bytes());
let raw_path = RawOsStr::assert_from_raw_bytes(require_path.as_bytes());
let mut path_relative_to_pwd = PathBuf::from(
&raw_source
.trim_start_matches("[string \"")
.trim_end_matches("\"]")
.to_os_str(),
)
.parent()
.unwrap()
.join(raw_path.to_os_str());
// Try to normalize and resolve relative path segments such as './' and '../'
if let Ok(canonicalized) = path_relative_to_pwd.with_extension("luau").canonicalize() {
path_relative_to_pwd = canonicalized.with_extension("");
}
if let Ok(canonicalized) = path_relative_to_pwd.with_extension("lua").canonicalize() {
path_relative_to_pwd = canonicalized.with_extension("");
}
if let Ok(stripped) = path_relative_to_pwd.strip_prefix(&raw_pwd_str.to_os_str()) {
path_relative_to_pwd = stripped.to_path_buf();
}
// Create a lossless lua string from the pathbuf and finally call require
let raw_path_str = RawOsStr::new(path_relative_to_pwd.as_os_str());
let lua_path_str = lua.create_string(raw_path_str.as_raw_bytes());
// If the require call errors then we should also replace
// the path in the error message to improve user experience
let result: LuaResult<_> = require_fn.call::<_, LuaValue>(lua_path_str);
match result {
Err(LuaError::CallbackError { traceback, cause }) => {
let before = format!(
"runtime error: cannot find '{}'",
path_relative_to_pwd.to_str().unwrap()
);
let after = format!(
"Invalid require path '{}' ({})",
require_path.to_str().unwrap(),
path_relative_to_pwd.to_str().unwrap()
);
let cause = Arc::new(LuaError::RuntimeError(
cause.to_string().replace(&before, &after),
));
Err(LuaError::CallbackError { traceback, cause })
}
Err(e) => Err(e),
Ok(result) => Ok(result),
}
})?;
// Override the original require global with our monkey-patched one
TableBuilder::new(lua)?
.with_value("require", new_require)?
.build_readonly()
}