From 74375ff70841f6527bb926b7d6412d44cbbdc31b Mon Sep 17 00:00:00 2001 From: Filip Tibell Date: Thu, 24 Apr 2025 21:02:16 +0200 Subject: [PATCH] Implement self alias for module requires --- .../lune-std/src/globals/require/context.rs | 34 +++++++++++++++++-- crates/lune-std/src/globals/require/mod.rs | 4 ++- crates/lune-std/src/globals/require/path.rs | 3 +- tests/require/tests/init.luau | 4 +++ tests/require/tests/invalid.luau | 2 +- .../tests/modules/self_alias/init.luau | 10 ++++++ .../tests/modules/self_alias/module.luau | 4 +++ 7 files changed, 56 insertions(+), 5 deletions(-) create mode 100644 tests/require/tests/modules/self_alias/init.luau create mode 100644 tests/require/tests/modules/self_alias/module.luau diff --git a/crates/lune-std/src/globals/require/context.rs b/crates/lune-std/src/globals/require/context.rs index 5759315..476583d 100644 --- a/crates/lune-std/src/globals/require/context.rs +++ b/crates/lune-std/src/globals/require/context.rs @@ -55,15 +55,45 @@ impl RequireContext { This will resolve path segments such as `./`, `../`, ..., and if the resolved path is not an absolute path, will create an absolute path by prepending the current working directory. + + If `resolve_as_self` is true, the given path should be a luau + module require path in the format of `@self/foo/bar/...` with the + `@self` prefix being stripped, and only `foo/bar/...` being passed. */ pub fn resolve_paths( source: impl AsRef, path: impl AsRef, + resolve_as_self: bool, ) -> LuaResult<(PathBuf, PathBuf)> { - let path = PathBuf::from(source.as_ref()) + let source = PathBuf::from(source.as_ref()); + let path = PathBuf::from(path.as_ref()); + + let is_init_module = { + let is_init = path + .file_stem() + .and_then(|stem| stem.to_str()) + .is_some_and(|stem| stem.eq_ignore_ascii_case("init")); + let is_luau = is_init + && path + .extension() + .and_then(|ext| ext.to_str()) + .is_some_and(|ext| matches!(ext, "lua" | "luau")); + is_init && is_luau + }; + + let source = if is_init_module && !resolve_as_self { + source + .parent() + .ok_or_else(|| LuaError::runtime("Failed to get parent path of self"))? + .to_path_buf() + } else { + source + }; + + let path = source .parent() .ok_or_else(|| LuaError::runtime("Failed to get parent path of source"))? - .join(path.as_ref()); + .join(path); let abs_path = clean_path_and_make_absolute(&path); let rel_path = clean_path(path); diff --git a/crates/lune-std/src/globals/require/mod.rs b/crates/lune-std/src/globals/require/mod.rs index d6a2a08..487cbda 100644 --- a/crates/lune-std/src/globals/require/mod.rs +++ b/crates/lune-std/src/globals/require/mod.rs @@ -80,13 +80,15 @@ async fn require(lua: Lua, (source, path): (LuaString, LuaString)) -> LuaResult< if let Some(builtin_name) = path.strip_prefix("@lune/").map(str::to_ascii_lowercase) { library::require(lua, &context, &builtin_name) + } else if let Some(self_path) = path.strip_prefix("@self/") { + path::require(lua, &context, &source, self_path, true).await } else if let Some(aliased_path) = path.strip_prefix('@') { let (alias, path) = aliased_path.split_once('/').ok_or(LuaError::runtime( "Require with custom alias must contain '/' delimiter", ))?; alias::require(lua, &context, &source, alias, path).await } else if path.starts_with("./") || path.starts_with("../") { - path::require(lua, &context, &source, &path).await + path::require(lua, &context, &source, &path, false).await } else { Err(LuaError::runtime( "Require path must start with \"./\", \"../\" or \"@\"", diff --git a/crates/lune-std/src/globals/require/path.rs b/crates/lune-std/src/globals/require/path.rs index 937dda4..82765a0 100644 --- a/crates/lune-std/src/globals/require/path.rs +++ b/crates/lune-std/src/globals/require/path.rs @@ -10,8 +10,9 @@ pub(super) async fn require( ctx: &RequireContext, source: &str, path: &str, + resolve_as_self: bool, ) -> LuaResult { - let (abs_path, rel_path) = RequireContext::resolve_paths(source, path)?; + let (abs_path, rel_path) = RequireContext::resolve_paths(source, path, resolve_as_self)?; require_abs_rel(lua, ctx, abs_path, rel_path).await } diff --git a/tests/require/tests/init.luau b/tests/require/tests/init.luau index d4900c3..2cd6a27 100644 --- a/tests/require/tests/init.luau +++ b/tests/require/tests/init.luau @@ -12,4 +12,8 @@ module = require("./modules/modules") assert(module.Foo == "Bar", "Required module did not contain correct values") assert(module.Hello == "World", "Required module did not contain correct values") +module = require("./modules/self_alias") +assert(module.Foo == "Bar", "Required module did not contain correct values") +assert(module.Hello == "World", "Required module did not contain correct values") + return true diff --git a/tests/require/tests/invalid.luau b/tests/require/tests/invalid.luau index 52a3574..8e545e4 100644 --- a/tests/require/tests/invalid.luau +++ b/tests/require/tests/invalid.luau @@ -1,6 +1,6 @@ local function test(path: string) local success, message = pcall(function() - local _ = require(path) :: any + local _ = require("./" .. path) :: any end) if success then error(string.format("Invalid require at path '%s' succeeded", path)) diff --git a/tests/require/tests/modules/self_alias/init.luau b/tests/require/tests/modules/self_alias/init.luau new file mode 100644 index 0000000..b8c85b9 --- /dev/null +++ b/tests/require/tests/modules/self_alias/init.luau @@ -0,0 +1,10 @@ +local outer = require("./module") +local inner = require("@self/module") + +assert(type(outer) == "table", "Outer module is not a table") +assert(type(inner) == "table", "Inner module is not a table") + +assert(outer.Foo == inner.Foo, "Outer and inner modules have different Foo values") +assert(inner.Bar == outer.Bar, "Outer and inner modules have different Bar values") + +return inner diff --git a/tests/require/tests/modules/self_alias/module.luau b/tests/require/tests/modules/self_alias/module.luau new file mode 100644 index 0000000..cb3159b --- /dev/null +++ b/tests/require/tests/modules/self_alias/module.luau @@ -0,0 +1,4 @@ +return { + Foo = "Bar", + Hello = "World", +}