diff --git a/.lune/hello_lune.luau b/.lune/hello_lune.luau index c03f8fa..bfe87e3 100644 --- a/.lune/hello_lune.luau +++ b/.lune/hello_lune.luau @@ -6,7 +6,7 @@ print("Hello, lune! 🌙") Using a function from another module ]==] -local module = require(".lune/module") +local module = require("./module") module.sayHello() --[==[ @@ -36,8 +36,7 @@ end) print("Spawning a delayed task that will run in 5 seconds") task.delay(5, function() - print() - print("...") + print("\n...") task.wait(1) print("Hello again!") task.wait(1) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3000f8e..3d56c4f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,18 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## Unreleased + +### Changed + +- `require` now uses paths relative to the file instead of being relative to the current directory, which is consistent with almost all other languages but not original Lua / Luau - this is a breaking change but will allow for proper packaging of third-party modules and more in the future. +- Improved error message when an invalid file path is passed to `require` +- Much improved error formatting and stack traces + +### Fixed + +- Process termination will now always make sure all lua state is cleaned up before exiting, in all cases + ## `0.0.6` - January 23rd, 2023 ### Added diff --git a/src/cli/cli.rs b/src/cli/cli.rs index 27796a1..02711bc 100644 --- a/src/cli/cli.rs +++ b/src/cli/cli.rs @@ -1,4 +1,4 @@ -use std::fs::read_to_string; +use std::{fs::read_to_string, process::ExitCode}; use anyhow::Result; use clap::{CommandFactory, Parser}; @@ -54,7 +54,7 @@ impl Cli { } } - pub async fn run(self) -> Result<()> { + pub async fn run(self) -> Result { // Download definition files, if wanted let download_types_requested = self.download_selene_types || self.download_luau_types; if download_types_requested { @@ -82,7 +82,7 @@ impl Cli { // Only downloading types without running a script is completely // fine, and we should just exit the program normally afterwards if download_types_requested { - return Ok(()); + return Ok(ExitCode::SUCCESS); } // HACK: We know that we didn't get any arguments here but since // script_path is optional clap will not error on its own, to fix @@ -98,10 +98,13 @@ impl Cli { let file_display_name = file_path.with_extension("").display().to_string(); // Create a new lune object with all globals & run the script let lune = Lune::new().with_args(self.script_args).with_all_globals(); - if let Err(e) = lune.run(&file_display_name, &file_contents).await { - eprintln!("{e}"); - std::process::exit(1); - }; - Ok(()) + let result = lune.run(&file_display_name, &file_contents).await; + Ok(match result { + Err(e) => { + eprintln!("{e}"); + ExitCode::from(1) + } + Ok(code) => code, + }) } } diff --git a/src/cli/main.rs b/src/cli/main.rs index 78e2d45..3416f98 100644 --- a/src/cli/main.rs +++ b/src/cli/main.rs @@ -2,6 +2,8 @@ #![warn(clippy::cargo, clippy::pedantic)] #![allow(clippy::needless_pass_by_value, clippy::match_bool)] +use std::process::ExitCode; + use anyhow::Result; use clap::Parser; @@ -10,10 +12,6 @@ mod utils; use cli::Cli; -fn main() -> Result<()> { - smol::block_on(async { - let cli = Cli::parse(); - cli.run().await?; - Ok(()) - }) +fn main() -> Result { + smol::block_on(async { Cli::parse().run().await }) } diff --git a/src/lib/globals/mod.rs b/src/lib/globals/mod.rs index 28ba9f0..50e8620 100644 --- a/src/lib/globals/mod.rs +++ b/src/lib/globals/mod.rs @@ -2,10 +2,12 @@ mod console; mod fs; mod net; mod process; +mod require; mod task; pub use console::create as create_console; pub use fs::create as create_fs; pub use net::create as create_net; pub use process::create as create_process; +pub use require::create as create_require; pub use task::create as create_task; diff --git a/src/lib/globals/require.rs b/src/lib/globals/require.rs new file mode 100644 index 0000000..5134263 --- /dev/null +++ b/src/lib/globals/require.rs @@ -0,0 +1,72 @@ +use std::{env, path::PathBuf, sync::Arc}; + +use mlua::prelude::*; +use os_str_bytes::RawOsStr; + +pub fn create(lua: &Lua) -> LuaResult<()> { + // Preserve original require behavior if we have a special env var set + if env::var_os("LUAU_PWD_REQUIRE").is_some() { + return Ok(()); + } + // Fetch the debug info function and store it in the registry + // - we will use it to fetch the current scripts file name + let debug: LuaTable = lua.globals().raw_get("debug")?; + let info: LuaFunction = debug.raw_get("info")?; + lua.set_named_registry_value("require_getinfo", info)?; + // Fetch the original require function and store it in the registry + let require: LuaFunction = lua.globals().raw_get("require")?; + lua.set_named_registry_value("require_original", require)?; + /* + Create a new function that fetches the file name from the current thread, + sets the luau module lookup path to be the exact script we are looking + for, and then runs the original require function with the wanted path + */ + let new_require = lua.create_function(|lua, require_path: LuaString| { + let require_original: LuaFunction = lua.named_registry_value("require_original")?; + let require_getinfo: LuaFunction = lua.named_registry_value("require_getinfo")?; + let require_source: LuaString = require_getinfo.call((2, "s"))?; + /* + Combine the require caller source with the wanted path + string to get a final path relative to pwd - it is definitely + relative to pwd because Lune will only load files relative to pwd + + Here we also take extra care to not perform any lossy conversion + and use os strings instead of Rust's utf-8 checked strings, in the + unlikely case someone out there uses luau with non-utf8 string requires + */ + let raw_source = RawOsStr::assert_from_raw_bytes(require_source.as_bytes()); + let raw_path = RawOsStr::assert_from_raw_bytes(require_path.as_bytes()); + let path_relative_to_pwd = PathBuf::from(&raw_source.to_os_str()) + .parent() + .unwrap() + .join(raw_path.to_os_str()); + // Create a lossless lua string from the pathbuf and finally call require + let raw_path_str = RawOsStr::new(path_relative_to_pwd.as_os_str()); + let lua_path_str = lua.create_string(raw_path_str.as_raw_bytes()); + // If the require call errors then we should also replace + // the path in the error message to improve user experience + let result: LuaResult<_> = require_original.call::<_, LuaValue>(lua_path_str); + match result { + Err(LuaError::CallbackError { traceback, cause }) => { + let before = format!( + "runtime error: cannot find '{}'", + path_relative_to_pwd.to_str().unwrap() + ); + let after = format!( + "Invalid require path '{}' ({})", + require_path.to_str().unwrap(), + path_relative_to_pwd.to_str().unwrap() + ); + let cause = Arc::new(LuaError::RuntimeError( + cause.to_string().replace(&before, &after), + )); + Err(LuaError::CallbackError { traceback, cause }) + } + Err(e) => Err(e), + Ok(result) => Ok(result), + } + })?; + // Override the original require global with our monkey-patched one + lua.globals().raw_set("require", new_require)?; + Ok(()) +} diff --git a/src/lib/lib.rs b/src/lib/lib.rs index 80fb30c..257ba78 100644 --- a/src/lib/lib.rs +++ b/src/lib/lib.rs @@ -1,4 +1,4 @@ -use std::{collections::HashSet, sync::Arc}; +use std::{collections::HashSet, process::ExitCode, sync::Arc}; use anyhow::{anyhow, bail, Result}; use mlua::prelude::*; @@ -8,7 +8,7 @@ pub mod globals; pub mod utils; use crate::{ - globals::{create_console, create_fs, create_net, create_process, create_task}, + globals::{create_console, create_fs, create_net, create_process, create_require, create_task}, utils::formatting::pretty_format_luau_error, }; @@ -21,6 +21,7 @@ pub enum LuneGlobal { Fs, Net, Process, + Require, Task, } @@ -31,6 +32,7 @@ impl LuneGlobal { Self::Fs, Self::Net, Self::Process, + Self::Require, Self::Task, ] } @@ -73,7 +75,7 @@ impl Lune { self } - pub async fn run(&self, name: &str, chunk: &str) -> Result { + pub async fn run(&self, name: &str, chunk: &str) -> Result { let (s, r) = smol::channel::unbounded::(); let lua = Arc::new(mlua::Lua::new()); let exec = Arc::new(LocalExecutor::new()); @@ -90,6 +92,7 @@ impl Lune { LuneGlobal::Fs => create_fs(&lua)?, LuneGlobal::Net => create_net(&lua)?, LuneGlobal::Process => create_process(&lua, self.args.clone())?, + LuneGlobal::Require => create_require(&lua)?, LuneGlobal::Task => create_task(&lua)?, } } @@ -100,7 +103,7 @@ impl Lune { sender.send(LuneMessage::Spawned).await?; let result = lua .load(&script_chunk) - .set_name(&script_name) + .set_name(&format!("={}", script_name)) .unwrap() .call_async::<_, LuaMultiValue>(LuaMultiValue::new()) .await; @@ -155,7 +158,7 @@ impl Lune { task_count += 1; } LuneMessage::LuaError(e) => { - eprintln!("{}", e); + eprintln!("{}", pretty_format_luau_error(&e)); got_error = true; task_count += 1; } @@ -171,11 +174,11 @@ impl Lune { // If we got an error, we will default to exiting // with code 1, unless a code was manually given if got_code { - Ok(exit_code) + Ok(ExitCode::from(exit_code)) } else if got_error { - Ok(1) + Ok(ExitCode::FAILURE) } else { - Ok(0) + Ok(ExitCode::SUCCESS) } } } @@ -183,9 +186,9 @@ impl Lune { #[cfg(test)] mod tests { use crate::Lune; - use anyhow::{bail, Result}; + use anyhow::Result; use smol::fs::read_to_string; - use std::env::current_dir; + use std::process::ExitCode; const ARGS: &[&str] = &["Foo", "Bar"]; @@ -193,12 +196,10 @@ mod tests { ($($name:ident: $value:expr,)*) => { $( #[test] - fn $name() -> Result<()> { + fn $name() -> Result { smol::block_on(async { - let path = current_dir() - .unwrap() - .join(format!("src/tests/{}.luau", $value)); - let script = read_to_string(&path) + let full_name = format!("src/tests/{}.luau", $value); + let script = read_to_string(&full_name) .await .unwrap(); let lune = Lune::new() @@ -210,11 +211,8 @@ mod tests { .collect() ) .with_all_globals(); - let exit_code = lune.run($value, &script).await?; - if exit_code != 0 { - bail!("Test exited with failure code {}", exit_code); - } - Ok(()) + let script_name = full_name.strip_suffix(".luau").unwrap(); + lune.run(&script_name, &script).await }) } )* @@ -227,15 +225,20 @@ mod tests { console_set_style: "console/set_style", fs_files: "fs/files", fs_dirs: "fs/dirs", - process_args: "process/args", - process_env: "process/env", - process_exit: "process/exit", - process_spawn: "process/spawn", net_request_codes: "net/request/codes", net_request_methods: "net/request/methods", net_request_redirect: "net/request/redirect", net_json_decode: "net/json/decode", net_json_encode: "net/json/encode", + process_args: "process/args", + process_env: "process/env", + process_exit: "process/exit", + process_spawn: "process/spawn", + require_children: "require/tests/children", + require_invalid: "require/tests/invalid", + require_nested: "require/tests/nested", + require_parents: "require/tests/parents", + require_siblings: "require/tests/siblings", task_cancel: "task/cancel", task_defer: "task/defer", task_delay: "task/delay", diff --git a/src/lib/utils/formatting.rs b/src/lib/utils/formatting.rs index ac3dc68..dc68b0f 100644 --- a/src/lib/utils/formatting.rs +++ b/src/lib/utils/formatting.rs @@ -204,23 +204,36 @@ pub fn pretty_format_luau_error(e: &LuaError) -> String { .lines() .map(|s| s.to_string()) .collect::>(); + let mut found_stack_begin = false; for (index, line) in err_lines.clone().iter().enumerate().rev() { if *line == "stack traceback:" { err_lines[index] = stack_begin; + found_stack_begin = true; break; } } // Add "Stack End" to the very end of the stack trace for symmetry - err_lines.push(stack_end); + if found_stack_begin { + err_lines.push(stack_end); + } err_lines.join("\n") } - LuaError::CallbackError { cause, traceback } => { + LuaError::CallbackError { traceback, cause } => { + // Find the best traceback (longest) and the root error message + let mut best_trace = traceback; + let mut root_cause = cause.as_ref(); + while let LuaError::CallbackError { cause, traceback } = root_cause { + if traceback.len() > best_trace.len() { + best_trace = traceback; + } + root_cause = cause; + } // Same error formatting as above format!( - "{}\n{}{}{}", - pretty_format_luau_error(cause.as_ref()), + "{}\n{}\n{}\n{}", + pretty_format_luau_error(root_cause), stack_begin, - traceback.strip_prefix("stack traceback:\n").unwrap(), + best_trace.strip_prefix("stack traceback:\n").unwrap(), stack_end ) } @@ -247,6 +260,8 @@ pub fn pretty_format_luau_error(e: &LuaError) -> String { let mut err_lines = err_string.lines().collect::>(); // Remove the script path from the error message // itself, it can be found in the stack trace + // FIXME: This no longer works now that we use + // an exact name when our lune script is loaded if let Some(first_line) = err_lines.first() { if first_line.starts_with("[string \"") { if let Some(closing_bracket) = first_line.find("]:") { @@ -260,7 +275,6 @@ pub fn pretty_format_luau_error(e: &LuaError) -> String { } } } - // Reformat stack trace lines, ignore lines that just mention C functions // Merge all lines back together into one string err_lines.join("\n") } diff --git a/src/tests/net/request/codes.luau b/src/tests/net/request/codes.luau index 72a0c12..c454608 100644 --- a/src/tests/net/request/codes.luau +++ b/src/tests/net/request/codes.luau @@ -1,4 +1,4 @@ -local util = require("src/tests/net/request/util") +local util = require("./util") local pass, fail = util.pass, util.fail pass("GET", "https://httpbin.org/status/200", "Request status code - 200") diff --git a/src/tests/net/request/methods.luau b/src/tests/net/request/methods.luau index 96dc5c9..beed786 100644 --- a/src/tests/net/request/methods.luau +++ b/src/tests/net/request/methods.luau @@ -1,4 +1,4 @@ -local util = require("src/tests/net/request/util") +local util = require("./util") local pass = util.pass -- stylua: ignore start diff --git a/src/tests/net/request/redirect.luau b/src/tests/net/request/redirect.luau index c3ec663..d8b235f 100644 --- a/src/tests/net/request/redirect.luau +++ b/src/tests/net/request/redirect.luau @@ -1,4 +1,4 @@ -local util = require("src/tests/net/request/util") +local util = require("./util") local pass = util.pass pass("GET", "https://httpbin.org/absolute-redirect/3", "Redirect 3 times") diff --git a/src/tests/require/modules/module.luau b/src/tests/require/modules/module.luau new file mode 100644 index 0000000..cb3159b --- /dev/null +++ b/src/tests/require/modules/module.luau @@ -0,0 +1,4 @@ +return { + Foo = "Bar", + Hello = "World", +} diff --git a/src/tests/require/tests/children.luau b/src/tests/require/tests/children.luau new file mode 100644 index 0000000..30ec035 --- /dev/null +++ b/src/tests/require/tests/children.luau @@ -0,0 +1,7 @@ +local module = require("./modules/module") + +assert(type(module) == "table", "Required module did not return a table") +assert(module.Foo == "Bar", "Required module did not contain correct values") +assert(module.Hello == "World", "Required module did not contain correct values") + +require("modules/module") diff --git a/src/tests/require/tests/invalid.luau b/src/tests/require/tests/invalid.luau new file mode 100644 index 0000000..5c5c693 --- /dev/null +++ b/src/tests/require/tests/invalid.luau @@ -0,0 +1,26 @@ +local function test(path: string) + local success, message = pcall(function() + local _ = require(path) :: any + end) + if success then + error(string.format("Invalid require at path '%s' succeeded", path)) + else + message = tostring(message) + if string.find(message, string.format("'%s'", path)) == nil then + error( + string.format( + "Invalid require did not mention path '%s' in its error message!\nMessage: %s", + path, + tostring(message) + ) + ) + end + end +end + +test("foo") +test("bar") +test("moduuuuule") +test("modules.nested") +test(" modules ") +test("mod" .. string.char(127) .. "ules") diff --git a/src/tests/require/tests/module.luau b/src/tests/require/tests/module.luau new file mode 100644 index 0000000..cb3159b --- /dev/null +++ b/src/tests/require/tests/module.luau @@ -0,0 +1,4 @@ +return { + Foo = "Bar", + Hello = "World", +} diff --git a/src/tests/require/tests/modules/module.luau b/src/tests/require/tests/modules/module.luau new file mode 100644 index 0000000..cb3159b --- /dev/null +++ b/src/tests/require/tests/modules/module.luau @@ -0,0 +1,4 @@ +return { + Foo = "Bar", + Hello = "World", +} diff --git a/src/tests/require/tests/modules/modules/module.luau b/src/tests/require/tests/modules/modules/module.luau new file mode 100644 index 0000000..cb3159b --- /dev/null +++ b/src/tests/require/tests/modules/modules/module.luau @@ -0,0 +1,4 @@ +return { + Foo = "Bar", + Hello = "World", +} diff --git a/src/tests/require/tests/modules/nested.luau b/src/tests/require/tests/modules/nested.luau new file mode 100644 index 0000000..9aa767a --- /dev/null +++ b/src/tests/require/tests/modules/nested.luau @@ -0,0 +1 @@ +return require("modules/module") diff --git a/src/tests/require/tests/nested.luau b/src/tests/require/tests/nested.luau new file mode 100644 index 0000000..68d2d10 --- /dev/null +++ b/src/tests/require/tests/nested.luau @@ -0,0 +1,7 @@ +local module = require("./modules/nested") + +assert(type(module) == "table", "Required module did not return a table") +assert(module.Foo == "Bar", "Required module did not contain correct values") +assert(module.Hello == "World", "Required module did not contain correct values") + +require("modules/nested") diff --git a/src/tests/require/tests/parents.luau b/src/tests/require/tests/parents.luau new file mode 100644 index 0000000..35abaa0 --- /dev/null +++ b/src/tests/require/tests/parents.luau @@ -0,0 +1,5 @@ +local module = require("../modules/module") + +assert(type(module) == "table", "Required module did not return a table") +assert(module.Foo == "Bar", "Required module did not contain correct values") +assert(module.Hello == "World", "Required module did not contain correct values") diff --git a/src/tests/require/tests/siblings.luau b/src/tests/require/tests/siblings.luau new file mode 100644 index 0000000..19cf4c6 --- /dev/null +++ b/src/tests/require/tests/siblings.luau @@ -0,0 +1,11 @@ +local module = require("./module") + +assert(type(module) == "table", "Required module did not return a table") +assert(module.Foo == "Bar", "Required module did not contain correct values") +assert(module.Hello == "World", "Required module did not contain correct values") + +require("./children") +require("./parents") + +require("children") +require("parents")