Implement relative path requires, proper exit codes

This commit is contained in:
Filip Tibell 2023-01-24 02:05:54 -05:00
parent bfe852f034
commit 0d0bb3f178
No known key found for this signature in database
21 changed files with 226 additions and 50 deletions

View file

@ -6,7 +6,7 @@ print("Hello, lune! 🌙")
Using a function from another module Using a function from another module
]==] ]==]
local module = require(".lune/module") local module = require("./module")
module.sayHello() module.sayHello()
--[==[ --[==[
@ -36,8 +36,7 @@ end)
print("Spawning a delayed task that will run in 5 seconds") print("Spawning a delayed task that will run in 5 seconds")
task.delay(5, function() task.delay(5, function()
print() print("\n...")
print("...")
task.wait(1) task.wait(1)
print("Hello again!") print("Hello again!")
task.wait(1) task.wait(1)

View file

@ -5,6 +5,18 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## Unreleased
### Changed
- `require` now uses paths relative to the file instead of being relative to the current directory, which is consistent with almost all other languages but not original Lua / Luau - this is a breaking change but will allow for proper packaging of third-party modules and more in the future.
- Improved error message when an invalid file path is passed to `require`
- Much improved error formatting and stack traces
### Fixed
- Process termination will now always make sure all lua state is cleaned up before exiting, in all cases
## `0.0.6` - January 23rd, 2023 ## `0.0.6` - January 23rd, 2023
### Added ### Added

View file

@ -1,4 +1,4 @@
use std::fs::read_to_string; use std::{fs::read_to_string, process::ExitCode};
use anyhow::Result; use anyhow::Result;
use clap::{CommandFactory, Parser}; use clap::{CommandFactory, Parser};
@ -54,7 +54,7 @@ impl Cli {
} }
} }
pub async fn run(self) -> Result<()> { pub async fn run(self) -> Result<ExitCode> {
// Download definition files, if wanted // Download definition files, if wanted
let download_types_requested = self.download_selene_types || self.download_luau_types; let download_types_requested = self.download_selene_types || self.download_luau_types;
if download_types_requested { if download_types_requested {
@ -82,7 +82,7 @@ impl Cli {
// Only downloading types without running a script is completely // Only downloading types without running a script is completely
// fine, and we should just exit the program normally afterwards // fine, and we should just exit the program normally afterwards
if download_types_requested { if download_types_requested {
return Ok(()); return Ok(ExitCode::SUCCESS);
} }
// HACK: We know that we didn't get any arguments here but since // HACK: We know that we didn't get any arguments here but since
// script_path is optional clap will not error on its own, to fix // script_path is optional clap will not error on its own, to fix
@ -98,10 +98,13 @@ impl Cli {
let file_display_name = file_path.with_extension("").display().to_string(); let file_display_name = file_path.with_extension("").display().to_string();
// Create a new lune object with all globals & run the script // Create a new lune object with all globals & run the script
let lune = Lune::new().with_args(self.script_args).with_all_globals(); let lune = Lune::new().with_args(self.script_args).with_all_globals();
if let Err(e) = lune.run(&file_display_name, &file_contents).await { let result = lune.run(&file_display_name, &file_contents).await;
Ok(match result {
Err(e) => {
eprintln!("{e}"); eprintln!("{e}");
std::process::exit(1); ExitCode::from(1)
}; }
Ok(()) Ok(code) => code,
})
} }
} }

View file

@ -2,6 +2,8 @@
#![warn(clippy::cargo, clippy::pedantic)] #![warn(clippy::cargo, clippy::pedantic)]
#![allow(clippy::needless_pass_by_value, clippy::match_bool)] #![allow(clippy::needless_pass_by_value, clippy::match_bool)]
use std::process::ExitCode;
use anyhow::Result; use anyhow::Result;
use clap::Parser; use clap::Parser;
@ -10,10 +12,6 @@ mod utils;
use cli::Cli; use cli::Cli;
fn main() -> Result<()> { fn main() -> Result<ExitCode> {
smol::block_on(async { smol::block_on(async { Cli::parse().run().await })
let cli = Cli::parse();
cli.run().await?;
Ok(())
})
} }

View file

@ -2,10 +2,12 @@ mod console;
mod fs; mod fs;
mod net; mod net;
mod process; mod process;
mod require;
mod task; mod task;
pub use console::create as create_console; pub use console::create as create_console;
pub use fs::create as create_fs; pub use fs::create as create_fs;
pub use net::create as create_net; pub use net::create as create_net;
pub use process::create as create_process; pub use process::create as create_process;
pub use require::create as create_require;
pub use task::create as create_task; pub use task::create as create_task;

View file

@ -0,0 +1,72 @@
use std::{env, path::PathBuf, sync::Arc};
use mlua::prelude::*;
use os_str_bytes::RawOsStr;
pub fn create(lua: &Lua) -> LuaResult<()> {
// Preserve original require behavior if we have a special env var set
if env::var_os("LUAU_PWD_REQUIRE").is_some() {
return Ok(());
}
// Fetch the debug info function and store it in the registry
// - we will use it to fetch the current scripts file name
let debug: LuaTable = lua.globals().raw_get("debug")?;
let info: LuaFunction = debug.raw_get("info")?;
lua.set_named_registry_value("require_getinfo", info)?;
// Fetch the original require function and store it in the registry
let require: LuaFunction = lua.globals().raw_get("require")?;
lua.set_named_registry_value("require_original", require)?;
/*
Create a new function that fetches the file name from the current thread,
sets the luau module lookup path to be the exact script we are looking
for, and then runs the original require function with the wanted path
*/
let new_require = lua.create_function(|lua, require_path: LuaString| {
let require_original: LuaFunction = lua.named_registry_value("require_original")?;
let require_getinfo: LuaFunction = lua.named_registry_value("require_getinfo")?;
let require_source: LuaString = require_getinfo.call((2, "s"))?;
/*
Combine the require caller source with the wanted path
string to get a final path relative to pwd - it is definitely
relative to pwd because Lune will only load files relative to pwd
Here we also take extra care to not perform any lossy conversion
and use os strings instead of Rust's utf-8 checked strings, in the
unlikely case someone out there uses luau with non-utf8 string requires
*/
let raw_source = RawOsStr::assert_from_raw_bytes(require_source.as_bytes());
let raw_path = RawOsStr::assert_from_raw_bytes(require_path.as_bytes());
let path_relative_to_pwd = PathBuf::from(&raw_source.to_os_str())
.parent()
.unwrap()
.join(raw_path.to_os_str());
// Create a lossless lua string from the pathbuf and finally call require
let raw_path_str = RawOsStr::new(path_relative_to_pwd.as_os_str());
let lua_path_str = lua.create_string(raw_path_str.as_raw_bytes());
// If the require call errors then we should also replace
// the path in the error message to improve user experience
let result: LuaResult<_> = require_original.call::<_, LuaValue>(lua_path_str);
match result {
Err(LuaError::CallbackError { traceback, cause }) => {
let before = format!(
"runtime error: cannot find '{}'",
path_relative_to_pwd.to_str().unwrap()
);
let after = format!(
"Invalid require path '{}' ({})",
require_path.to_str().unwrap(),
path_relative_to_pwd.to_str().unwrap()
);
let cause = Arc::new(LuaError::RuntimeError(
cause.to_string().replace(&before, &after),
));
Err(LuaError::CallbackError { traceback, cause })
}
Err(e) => Err(e),
Ok(result) => Ok(result),
}
})?;
// Override the original require global with our monkey-patched one
lua.globals().raw_set("require", new_require)?;
Ok(())
}

View file

@ -1,4 +1,4 @@
use std::{collections::HashSet, sync::Arc}; use std::{collections::HashSet, process::ExitCode, sync::Arc};
use anyhow::{anyhow, bail, Result}; use anyhow::{anyhow, bail, Result};
use mlua::prelude::*; use mlua::prelude::*;
@ -8,7 +8,7 @@ pub mod globals;
pub mod utils; pub mod utils;
use crate::{ use crate::{
globals::{create_console, create_fs, create_net, create_process, create_task}, globals::{create_console, create_fs, create_net, create_process, create_require, create_task},
utils::formatting::pretty_format_luau_error, utils::formatting::pretty_format_luau_error,
}; };
@ -21,6 +21,7 @@ pub enum LuneGlobal {
Fs, Fs,
Net, Net,
Process, Process,
Require,
Task, Task,
} }
@ -31,6 +32,7 @@ impl LuneGlobal {
Self::Fs, Self::Fs,
Self::Net, Self::Net,
Self::Process, Self::Process,
Self::Require,
Self::Task, Self::Task,
] ]
} }
@ -73,7 +75,7 @@ impl Lune {
self self
} }
pub async fn run(&self, name: &str, chunk: &str) -> Result<u8> { pub async fn run(&self, name: &str, chunk: &str) -> Result<ExitCode> {
let (s, r) = smol::channel::unbounded::<LuneMessage>(); let (s, r) = smol::channel::unbounded::<LuneMessage>();
let lua = Arc::new(mlua::Lua::new()); let lua = Arc::new(mlua::Lua::new());
let exec = Arc::new(LocalExecutor::new()); let exec = Arc::new(LocalExecutor::new());
@ -90,6 +92,7 @@ impl Lune {
LuneGlobal::Fs => create_fs(&lua)?, LuneGlobal::Fs => create_fs(&lua)?,
LuneGlobal::Net => create_net(&lua)?, LuneGlobal::Net => create_net(&lua)?,
LuneGlobal::Process => create_process(&lua, self.args.clone())?, LuneGlobal::Process => create_process(&lua, self.args.clone())?,
LuneGlobal::Require => create_require(&lua)?,
LuneGlobal::Task => create_task(&lua)?, LuneGlobal::Task => create_task(&lua)?,
} }
} }
@ -100,7 +103,7 @@ impl Lune {
sender.send(LuneMessage::Spawned).await?; sender.send(LuneMessage::Spawned).await?;
let result = lua let result = lua
.load(&script_chunk) .load(&script_chunk)
.set_name(&script_name) .set_name(&format!("={}", script_name))
.unwrap() .unwrap()
.call_async::<_, LuaMultiValue>(LuaMultiValue::new()) .call_async::<_, LuaMultiValue>(LuaMultiValue::new())
.await; .await;
@ -155,7 +158,7 @@ impl Lune {
task_count += 1; task_count += 1;
} }
LuneMessage::LuaError(e) => { LuneMessage::LuaError(e) => {
eprintln!("{}", e); eprintln!("{}", pretty_format_luau_error(&e));
got_error = true; got_error = true;
task_count += 1; task_count += 1;
} }
@ -171,11 +174,11 @@ impl Lune {
// If we got an error, we will default to exiting // If we got an error, we will default to exiting
// with code 1, unless a code was manually given // with code 1, unless a code was manually given
if got_code { if got_code {
Ok(exit_code) Ok(ExitCode::from(exit_code))
} else if got_error { } else if got_error {
Ok(1) Ok(ExitCode::FAILURE)
} else { } else {
Ok(0) Ok(ExitCode::SUCCESS)
} }
} }
} }
@ -183,9 +186,9 @@ impl Lune {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::Lune; use crate::Lune;
use anyhow::{bail, Result}; use anyhow::Result;
use smol::fs::read_to_string; use smol::fs::read_to_string;
use std::env::current_dir; use std::process::ExitCode;
const ARGS: &[&str] = &["Foo", "Bar"]; const ARGS: &[&str] = &["Foo", "Bar"];
@ -193,12 +196,10 @@ mod tests {
($($name:ident: $value:expr,)*) => { ($($name:ident: $value:expr,)*) => {
$( $(
#[test] #[test]
fn $name() -> Result<()> { fn $name() -> Result<ExitCode> {
smol::block_on(async { smol::block_on(async {
let path = current_dir() let full_name = format!("src/tests/{}.luau", $value);
.unwrap() let script = read_to_string(&full_name)
.join(format!("src/tests/{}.luau", $value));
let script = read_to_string(&path)
.await .await
.unwrap(); .unwrap();
let lune = Lune::new() let lune = Lune::new()
@ -210,11 +211,8 @@ mod tests {
.collect() .collect()
) )
.with_all_globals(); .with_all_globals();
let exit_code = lune.run($value, &script).await?; let script_name = full_name.strip_suffix(".luau").unwrap();
if exit_code != 0 { lune.run(&script_name, &script).await
bail!("Test exited with failure code {}", exit_code);
}
Ok(())
}) })
} }
)* )*
@ -227,15 +225,20 @@ mod tests {
console_set_style: "console/set_style", console_set_style: "console/set_style",
fs_files: "fs/files", fs_files: "fs/files",
fs_dirs: "fs/dirs", fs_dirs: "fs/dirs",
process_args: "process/args",
process_env: "process/env",
process_exit: "process/exit",
process_spawn: "process/spawn",
net_request_codes: "net/request/codes", net_request_codes: "net/request/codes",
net_request_methods: "net/request/methods", net_request_methods: "net/request/methods",
net_request_redirect: "net/request/redirect", net_request_redirect: "net/request/redirect",
net_json_decode: "net/json/decode", net_json_decode: "net/json/decode",
net_json_encode: "net/json/encode", net_json_encode: "net/json/encode",
process_args: "process/args",
process_env: "process/env",
process_exit: "process/exit",
process_spawn: "process/spawn",
require_children: "require/tests/children",
require_invalid: "require/tests/invalid",
require_nested: "require/tests/nested",
require_parents: "require/tests/parents",
require_siblings: "require/tests/siblings",
task_cancel: "task/cancel", task_cancel: "task/cancel",
task_defer: "task/defer", task_defer: "task/defer",
task_delay: "task/delay", task_delay: "task/delay",

View file

@ -204,23 +204,36 @@ pub fn pretty_format_luau_error(e: &LuaError) -> String {
.lines() .lines()
.map(|s| s.to_string()) .map(|s| s.to_string())
.collect::<Vec<String>>(); .collect::<Vec<String>>();
let mut found_stack_begin = false;
for (index, line) in err_lines.clone().iter().enumerate().rev() { for (index, line) in err_lines.clone().iter().enumerate().rev() {
if *line == "stack traceback:" { if *line == "stack traceback:" {
err_lines[index] = stack_begin; err_lines[index] = stack_begin;
found_stack_begin = true;
break; break;
} }
} }
// Add "Stack End" to the very end of the stack trace for symmetry // Add "Stack End" to the very end of the stack trace for symmetry
if found_stack_begin {
err_lines.push(stack_end); err_lines.push(stack_end);
}
err_lines.join("\n") err_lines.join("\n")
} }
LuaError::CallbackError { cause, traceback } => { LuaError::CallbackError { traceback, cause } => {
// Find the best traceback (longest) and the root error message
let mut best_trace = traceback;
let mut root_cause = cause.as_ref();
while let LuaError::CallbackError { cause, traceback } = root_cause {
if traceback.len() > best_trace.len() {
best_trace = traceback;
}
root_cause = cause;
}
// Same error formatting as above // Same error formatting as above
format!( format!(
"{}\n{}{}{}", "{}\n{}\n{}\n{}",
pretty_format_luau_error(cause.as_ref()), pretty_format_luau_error(root_cause),
stack_begin, stack_begin,
traceback.strip_prefix("stack traceback:\n").unwrap(), best_trace.strip_prefix("stack traceback:\n").unwrap(),
stack_end stack_end
) )
} }
@ -247,6 +260,8 @@ pub fn pretty_format_luau_error(e: &LuaError) -> String {
let mut err_lines = err_string.lines().collect::<Vec<_>>(); let mut err_lines = err_string.lines().collect::<Vec<_>>();
// Remove the script path from the error message // Remove the script path from the error message
// itself, it can be found in the stack trace // itself, it can be found in the stack trace
// FIXME: This no longer works now that we use
// an exact name when our lune script is loaded
if let Some(first_line) = err_lines.first() { if let Some(first_line) = err_lines.first() {
if first_line.starts_with("[string \"") { if first_line.starts_with("[string \"") {
if let Some(closing_bracket) = first_line.find("]:") { if let Some(closing_bracket) = first_line.find("]:") {
@ -260,7 +275,6 @@ pub fn pretty_format_luau_error(e: &LuaError) -> String {
} }
} }
} }
// Reformat stack trace lines, ignore lines that just mention C functions
// Merge all lines back together into one string // Merge all lines back together into one string
err_lines.join("\n") err_lines.join("\n")
} }

View file

@ -1,4 +1,4 @@
local util = require("src/tests/net/request/util") local util = require("./util")
local pass, fail = util.pass, util.fail local pass, fail = util.pass, util.fail
pass("GET", "https://httpbin.org/status/200", "Request status code - 200") pass("GET", "https://httpbin.org/status/200", "Request status code - 200")

View file

@ -1,4 +1,4 @@
local util = require("src/tests/net/request/util") local util = require("./util")
local pass = util.pass local pass = util.pass
-- stylua: ignore start -- stylua: ignore start

View file

@ -1,4 +1,4 @@
local util = require("src/tests/net/request/util") local util = require("./util")
local pass = util.pass local pass = util.pass
pass("GET", "https://httpbin.org/absolute-redirect/3", "Redirect 3 times") pass("GET", "https://httpbin.org/absolute-redirect/3", "Redirect 3 times")

View file

@ -0,0 +1,4 @@
return {
Foo = "Bar",
Hello = "World",
}

View file

@ -0,0 +1,7 @@
local module = require("./modules/module")
assert(type(module) == "table", "Required module did not return a table")
assert(module.Foo == "Bar", "Required module did not contain correct values")
assert(module.Hello == "World", "Required module did not contain correct values")
require("modules/module")

View file

@ -0,0 +1,26 @@
local function test(path: string)
local success, message = pcall(function()
local _ = require(path) :: any
end)
if success then
error(string.format("Invalid require at path '%s' succeeded", path))
else
message = tostring(message)
if string.find(message, string.format("'%s'", path)) == nil then
error(
string.format(
"Invalid require did not mention path '%s' in its error message!\nMessage: %s",
path,
tostring(message)
)
)
end
end
end
test("foo")
test("bar")
test("moduuuuule")
test("modules.nested")
test(" modules ")
test("mod" .. string.char(127) .. "ules")

View file

@ -0,0 +1,4 @@
return {
Foo = "Bar",
Hello = "World",
}

View file

@ -0,0 +1,4 @@
return {
Foo = "Bar",
Hello = "World",
}

View file

@ -0,0 +1,4 @@
return {
Foo = "Bar",
Hello = "World",
}

View file

@ -0,0 +1 @@
return require("modules/module")

View file

@ -0,0 +1,7 @@
local module = require("./modules/nested")
assert(type(module) == "table", "Required module did not return a table")
assert(module.Foo == "Bar", "Required module did not contain correct values")
assert(module.Hello == "World", "Required module did not contain correct values")
require("modules/nested")

View file

@ -0,0 +1,5 @@
local module = require("../modules/module")
assert(type(module) == "table", "Required module did not return a table")
assert(module.Foo == "Bar", "Required module did not contain correct values")
assert(module.Hello == "World", "Required module did not contain correct values")

View file

@ -0,0 +1,11 @@
local module = require("./module")
assert(type(module) == "table", "Required module did not return a table")
assert(module.Foo == "Bar", "Required module did not contain correct values")
assert(module.Hello == "World", "Required module did not contain correct values")
require("./children")
require("./parents")
require("children")
require("parents")