2023-01-24 20:27:38 +00:00
|
|
|
use std::{
|
|
|
|
env::{self, current_dir},
|
|
|
|
path::PathBuf,
|
|
|
|
sync::Arc,
|
|
|
|
};
|
2023-01-24 07:05:54 +00:00
|
|
|
|
|
|
|
use mlua::prelude::*;
|
2023-01-24 20:27:38 +00:00
|
|
|
use os_str_bytes::{OsStrBytes, RawOsStr};
|
2023-01-24 07:05:54 +00:00
|
|
|
|
2023-02-10 11:14:28 +00:00
|
|
|
use crate::utils::table::TableBuilder;
|
|
|
|
|
2023-02-11 11:39:39 +00:00
|
|
|
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
|
2023-02-10 11:14:28 +00:00
|
|
|
let require: LuaFunction = lua.globals().raw_get("require")?;
|
2023-01-24 07:05:54 +00:00
|
|
|
// Preserve original require behavior if we have a special env var set
|
|
|
|
if env::var_os("LUAU_PWD_REQUIRE").is_some() {
|
2023-02-10 11:14:28 +00:00
|
|
|
return TableBuilder::new(lua)?
|
|
|
|
.with_value("require", require)?
|
|
|
|
.build_readonly();
|
2023-01-24 07:05:54 +00:00
|
|
|
}
|
2023-01-24 20:27:38 +00:00
|
|
|
/*
|
|
|
|
Store the current working directory so that we can use it later
|
|
|
|
and remove it from require paths in error messages, showing
|
|
|
|
absolute paths is bad ux and we should try to avoid it
|
|
|
|
|
|
|
|
Throughout this function we also take extra care to not perform any lossy
|
|
|
|
conversion and use os strings instead of Rust's utf-8 checked strings,
|
|
|
|
just in case someone out there uses luau with non-utf8 string requires
|
|
|
|
*/
|
|
|
|
let pwd = lua.create_string(¤t_dir()?.to_raw_bytes())?;
|
|
|
|
lua.set_named_registry_value("require_pwd", pwd)?;
|
2023-01-24 07:05:54 +00:00
|
|
|
// Fetch the debug info function and store it in the registry
|
|
|
|
// - we will use it to fetch the current scripts file name
|
|
|
|
let debug: LuaTable = lua.globals().raw_get("debug")?;
|
|
|
|
let info: LuaFunction = debug.raw_get("info")?;
|
|
|
|
lua.set_named_registry_value("require_getinfo", info)?;
|
2023-02-10 11:14:28 +00:00
|
|
|
// Store the original require function in the registry
|
2023-01-24 07:05:54 +00:00
|
|
|
lua.set_named_registry_value("require_original", require)?;
|
|
|
|
/*
|
|
|
|
Create a new function that fetches the file name from the current thread,
|
|
|
|
sets the luau module lookup path to be the exact script we are looking
|
|
|
|
for, and then runs the original require function with the wanted path
|
|
|
|
*/
|
|
|
|
let new_require = lua.create_function(|lua, require_path: LuaString| {
|
2023-01-24 20:27:38 +00:00
|
|
|
let require_pwd: LuaString = lua.named_registry_value("require_pwd")?;
|
2023-01-24 07:05:54 +00:00
|
|
|
let require_original: LuaFunction = lua.named_registry_value("require_original")?;
|
|
|
|
let require_getinfo: LuaFunction = lua.named_registry_value("require_getinfo")?;
|
|
|
|
let require_source: LuaString = require_getinfo.call((2, "s"))?;
|
|
|
|
/*
|
|
|
|
Combine the require caller source with the wanted path
|
|
|
|
string to get a final path relative to pwd - it is definitely
|
|
|
|
relative to pwd because Lune will only load files relative to pwd
|
|
|
|
*/
|
2023-01-24 20:27:38 +00:00
|
|
|
let raw_pwd_str = RawOsStr::assert_from_raw_bytes(require_pwd.as_bytes());
|
2023-01-24 07:05:54 +00:00
|
|
|
let raw_source = RawOsStr::assert_from_raw_bytes(require_source.as_bytes());
|
|
|
|
let raw_path = RawOsStr::assert_from_raw_bytes(require_path.as_bytes());
|
2023-01-24 20:27:38 +00:00
|
|
|
let mut path_relative_to_pwd = PathBuf::from(&raw_source.to_os_str())
|
2023-01-24 07:05:54 +00:00
|
|
|
.parent()
|
|
|
|
.unwrap()
|
|
|
|
.join(raw_path.to_os_str());
|
2023-01-24 20:27:38 +00:00
|
|
|
// Try to normalize and resolve relative path segments such as './' and '../'
|
|
|
|
if let Ok(canonicalized) = path_relative_to_pwd.with_extension("luau").canonicalize() {
|
|
|
|
path_relative_to_pwd = canonicalized.with_extension("");
|
|
|
|
}
|
|
|
|
if let Ok(canonicalized) = path_relative_to_pwd.with_extension("lua").canonicalize() {
|
|
|
|
path_relative_to_pwd = canonicalized.with_extension("");
|
|
|
|
}
|
|
|
|
if let Ok(stripped) = path_relative_to_pwd.strip_prefix(&raw_pwd_str.to_os_str()) {
|
|
|
|
path_relative_to_pwd = stripped.to_path_buf();
|
|
|
|
}
|
2023-01-24 07:05:54 +00:00
|
|
|
// Create a lossless lua string from the pathbuf and finally call require
|
|
|
|
let raw_path_str = RawOsStr::new(path_relative_to_pwd.as_os_str());
|
|
|
|
let lua_path_str = lua.create_string(raw_path_str.as_raw_bytes());
|
|
|
|
// If the require call errors then we should also replace
|
|
|
|
// the path in the error message to improve user experience
|
|
|
|
let result: LuaResult<_> = require_original.call::<_, LuaValue>(lua_path_str);
|
|
|
|
match result {
|
|
|
|
Err(LuaError::CallbackError { traceback, cause }) => {
|
|
|
|
let before = format!(
|
|
|
|
"runtime error: cannot find '{}'",
|
|
|
|
path_relative_to_pwd.to_str().unwrap()
|
|
|
|
);
|
|
|
|
let after = format!(
|
|
|
|
"Invalid require path '{}' ({})",
|
|
|
|
require_path.to_str().unwrap(),
|
|
|
|
path_relative_to_pwd.to_str().unwrap()
|
|
|
|
);
|
|
|
|
let cause = Arc::new(LuaError::RuntimeError(
|
|
|
|
cause.to_string().replace(&before, &after),
|
|
|
|
));
|
|
|
|
Err(LuaError::CallbackError { traceback, cause })
|
|
|
|
}
|
|
|
|
Err(e) => Err(e),
|
|
|
|
Ok(result) => Ok(result),
|
|
|
|
}
|
|
|
|
})?;
|
|
|
|
// Override the original require global with our monkey-patched one
|
2023-02-10 11:14:28 +00:00
|
|
|
TableBuilder::new(lua)?
|
|
|
|
.with_value("require", new_require)?
|
|
|
|
.build_readonly()
|
2023-01-24 07:05:54 +00:00
|
|
|
}
|