Implement relative path requires, proper exit codes

This commit is contained in:
Filip Tibell 2023-01-24 02:05:54 -05:00
parent bfe852f034
commit 0d0bb3f178
No known key found for this signature in database
21 changed files with 226 additions and 50 deletions

View file

@ -6,7 +6,7 @@ print("Hello, lune! 🌙")
Using a function from another module
]==]
local module = require(".lune/module")
local module = require("./module")
module.sayHello()
--[==[
@ -36,8 +36,7 @@ end)
print("Spawning a delayed task that will run in 5 seconds")
task.delay(5, function()
print()
print("...")
print("\n...")
task.wait(1)
print("Hello again!")
task.wait(1)

View file

@ -5,6 +5,18 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## Unreleased
### Changed
- `require` now uses paths relative to the file instead of being relative to the current directory, which is consistent with almost all other languages but not original Lua / Luau - this is a breaking change but will allow for proper packaging of third-party modules and more in the future.
- Improved error message when an invalid file path is passed to `require`
- Much improved error formatting and stack traces
### Fixed
- Process termination will now always make sure all lua state is cleaned up before exiting, in all cases
## `0.0.6` - January 23rd, 2023
### Added

View file

@ -1,4 +1,4 @@
use std::fs::read_to_string;
use std::{fs::read_to_string, process::ExitCode};
use anyhow::Result;
use clap::{CommandFactory, Parser};
@ -54,7 +54,7 @@ impl Cli {
}
}
pub async fn run(self) -> Result<()> {
pub async fn run(self) -> Result<ExitCode> {
// Download definition files, if wanted
let download_types_requested = self.download_selene_types || self.download_luau_types;
if download_types_requested {
@ -82,7 +82,7 @@ impl Cli {
// Only downloading types without running a script is completely
// fine, and we should just exit the program normally afterwards
if download_types_requested {
return Ok(());
return Ok(ExitCode::SUCCESS);
}
// HACK: We know that we didn't get any arguments here but since
// script_path is optional clap will not error on its own, to fix
@ -98,10 +98,13 @@ impl Cli {
let file_display_name = file_path.with_extension("").display().to_string();
// Create a new lune object with all globals & run the script
let lune = Lune::new().with_args(self.script_args).with_all_globals();
if let Err(e) = lune.run(&file_display_name, &file_contents).await {
eprintln!("{e}");
std::process::exit(1);
};
Ok(())
let result = lune.run(&file_display_name, &file_contents).await;
Ok(match result {
Err(e) => {
eprintln!("{e}");
ExitCode::from(1)
}
Ok(code) => code,
})
}
}

View file

@ -2,6 +2,8 @@
#![warn(clippy::cargo, clippy::pedantic)]
#![allow(clippy::needless_pass_by_value, clippy::match_bool)]
use std::process::ExitCode;
use anyhow::Result;
use clap::Parser;
@ -10,10 +12,6 @@ mod utils;
use cli::Cli;
fn main() -> Result<()> {
smol::block_on(async {
let cli = Cli::parse();
cli.run().await?;
Ok(())
})
fn main() -> Result<ExitCode> {
smol::block_on(async { Cli::parse().run().await })
}

View file

@ -2,10 +2,12 @@ mod console;
mod fs;
mod net;
mod process;
mod require;
mod task;
pub use console::create as create_console;
pub use fs::create as create_fs;
pub use net::create as create_net;
pub use process::create as create_process;
pub use require::create as create_require;
pub use task::create as create_task;

View file

@ -0,0 +1,72 @@
use std::{env, path::PathBuf, sync::Arc};
use mlua::prelude::*;
use os_str_bytes::RawOsStr;
pub fn create(lua: &Lua) -> LuaResult<()> {
// Preserve original require behavior if we have a special env var set
if env::var_os("LUAU_PWD_REQUIRE").is_some() {
return Ok(());
}
// Fetch the debug info function and store it in the registry
// - we will use it to fetch the current scripts file name
let debug: LuaTable = lua.globals().raw_get("debug")?;
let info: LuaFunction = debug.raw_get("info")?;
lua.set_named_registry_value("require_getinfo", info)?;
// Fetch the original require function and store it in the registry
let require: LuaFunction = lua.globals().raw_get("require")?;
lua.set_named_registry_value("require_original", require)?;
/*
Create a new function that fetches the file name from the current thread,
sets the luau module lookup path to be the exact script we are looking
for, and then runs the original require function with the wanted path
*/
let new_require = lua.create_function(|lua, require_path: LuaString| {
let require_original: LuaFunction = lua.named_registry_value("require_original")?;
let require_getinfo: LuaFunction = lua.named_registry_value("require_getinfo")?;
let require_source: LuaString = require_getinfo.call((2, "s"))?;
/*
Combine the require caller source with the wanted path
string to get a final path relative to pwd - it is definitely
relative to pwd because Lune will only load files relative to pwd
Here we also take extra care to not perform any lossy conversion
and use os strings instead of Rust's utf-8 checked strings, in the
unlikely case someone out there uses luau with non-utf8 string requires
*/
let raw_source = RawOsStr::assert_from_raw_bytes(require_source.as_bytes());
let raw_path = RawOsStr::assert_from_raw_bytes(require_path.as_bytes());
let path_relative_to_pwd = PathBuf::from(&raw_source.to_os_str())
.parent()
.unwrap()
.join(raw_path.to_os_str());
// Create a lossless lua string from the pathbuf and finally call require
let raw_path_str = RawOsStr::new(path_relative_to_pwd.as_os_str());
let lua_path_str = lua.create_string(raw_path_str.as_raw_bytes());
// If the require call errors then we should also replace
// the path in the error message to improve user experience
let result: LuaResult<_> = require_original.call::<_, LuaValue>(lua_path_str);
match result {
Err(LuaError::CallbackError { traceback, cause }) => {
let before = format!(
"runtime error: cannot find '{}'",
path_relative_to_pwd.to_str().unwrap()
);
let after = format!(
"Invalid require path '{}' ({})",
require_path.to_str().unwrap(),
path_relative_to_pwd.to_str().unwrap()
);
let cause = Arc::new(LuaError::RuntimeError(
cause.to_string().replace(&before, &after),
));
Err(LuaError::CallbackError { traceback, cause })
}
Err(e) => Err(e),
Ok(result) => Ok(result),
}
})?;
// Override the original require global with our monkey-patched one
lua.globals().raw_set("require", new_require)?;
Ok(())
}

View file

@ -1,4 +1,4 @@
use std::{collections::HashSet, sync::Arc};
use std::{collections::HashSet, process::ExitCode, sync::Arc};
use anyhow::{anyhow, bail, Result};
use mlua::prelude::*;
@ -8,7 +8,7 @@ pub mod globals;
pub mod utils;
use crate::{
globals::{create_console, create_fs, create_net, create_process, create_task},
globals::{create_console, create_fs, create_net, create_process, create_require, create_task},
utils::formatting::pretty_format_luau_error,
};
@ -21,6 +21,7 @@ pub enum LuneGlobal {
Fs,
Net,
Process,
Require,
Task,
}
@ -31,6 +32,7 @@ impl LuneGlobal {
Self::Fs,
Self::Net,
Self::Process,
Self::Require,
Self::Task,
]
}
@ -73,7 +75,7 @@ impl Lune {
self
}
pub async fn run(&self, name: &str, chunk: &str) -> Result<u8> {
pub async fn run(&self, name: &str, chunk: &str) -> Result<ExitCode> {
let (s, r) = smol::channel::unbounded::<LuneMessage>();
let lua = Arc::new(mlua::Lua::new());
let exec = Arc::new(LocalExecutor::new());
@ -90,6 +92,7 @@ impl Lune {
LuneGlobal::Fs => create_fs(&lua)?,
LuneGlobal::Net => create_net(&lua)?,
LuneGlobal::Process => create_process(&lua, self.args.clone())?,
LuneGlobal::Require => create_require(&lua)?,
LuneGlobal::Task => create_task(&lua)?,
}
}
@ -100,7 +103,7 @@ impl Lune {
sender.send(LuneMessage::Spawned).await?;
let result = lua
.load(&script_chunk)
.set_name(&script_name)
.set_name(&format!("={}", script_name))
.unwrap()
.call_async::<_, LuaMultiValue>(LuaMultiValue::new())
.await;
@ -155,7 +158,7 @@ impl Lune {
task_count += 1;
}
LuneMessage::LuaError(e) => {
eprintln!("{}", e);
eprintln!("{}", pretty_format_luau_error(&e));
got_error = true;
task_count += 1;
}
@ -171,11 +174,11 @@ impl Lune {
// If we got an error, we will default to exiting
// with code 1, unless a code was manually given
if got_code {
Ok(exit_code)
Ok(ExitCode::from(exit_code))
} else if got_error {
Ok(1)
Ok(ExitCode::FAILURE)
} else {
Ok(0)
Ok(ExitCode::SUCCESS)
}
}
}
@ -183,9 +186,9 @@ impl Lune {
#[cfg(test)]
mod tests {
use crate::Lune;
use anyhow::{bail, Result};
use anyhow::Result;
use smol::fs::read_to_string;
use std::env::current_dir;
use std::process::ExitCode;
const ARGS: &[&str] = &["Foo", "Bar"];
@ -193,12 +196,10 @@ mod tests {
($($name:ident: $value:expr,)*) => {
$(
#[test]
fn $name() -> Result<()> {
fn $name() -> Result<ExitCode> {
smol::block_on(async {
let path = current_dir()
.unwrap()
.join(format!("src/tests/{}.luau", $value));
let script = read_to_string(&path)
let full_name = format!("src/tests/{}.luau", $value);
let script = read_to_string(&full_name)
.await
.unwrap();
let lune = Lune::new()
@ -210,11 +211,8 @@ mod tests {
.collect()
)
.with_all_globals();
let exit_code = lune.run($value, &script).await?;
if exit_code != 0 {
bail!("Test exited with failure code {}", exit_code);
}
Ok(())
let script_name = full_name.strip_suffix(".luau").unwrap();
lune.run(&script_name, &script).await
})
}
)*
@ -227,15 +225,20 @@ mod tests {
console_set_style: "console/set_style",
fs_files: "fs/files",
fs_dirs: "fs/dirs",
process_args: "process/args",
process_env: "process/env",
process_exit: "process/exit",
process_spawn: "process/spawn",
net_request_codes: "net/request/codes",
net_request_methods: "net/request/methods",
net_request_redirect: "net/request/redirect",
net_json_decode: "net/json/decode",
net_json_encode: "net/json/encode",
process_args: "process/args",
process_env: "process/env",
process_exit: "process/exit",
process_spawn: "process/spawn",
require_children: "require/tests/children",
require_invalid: "require/tests/invalid",
require_nested: "require/tests/nested",
require_parents: "require/tests/parents",
require_siblings: "require/tests/siblings",
task_cancel: "task/cancel",
task_defer: "task/defer",
task_delay: "task/delay",

View file

@ -204,23 +204,36 @@ pub fn pretty_format_luau_error(e: &LuaError) -> String {
.lines()
.map(|s| s.to_string())
.collect::<Vec<String>>();
let mut found_stack_begin = false;
for (index, line) in err_lines.clone().iter().enumerate().rev() {
if *line == "stack traceback:" {
err_lines[index] = stack_begin;
found_stack_begin = true;
break;
}
}
// Add "Stack End" to the very end of the stack trace for symmetry
err_lines.push(stack_end);
if found_stack_begin {
err_lines.push(stack_end);
}
err_lines.join("\n")
}
LuaError::CallbackError { cause, traceback } => {
LuaError::CallbackError { traceback, cause } => {
// Find the best traceback (longest) and the root error message
let mut best_trace = traceback;
let mut root_cause = cause.as_ref();
while let LuaError::CallbackError { cause, traceback } = root_cause {
if traceback.len() > best_trace.len() {
best_trace = traceback;
}
root_cause = cause;
}
// Same error formatting as above
format!(
"{}\n{}{}{}",
pretty_format_luau_error(cause.as_ref()),
"{}\n{}\n{}\n{}",
pretty_format_luau_error(root_cause),
stack_begin,
traceback.strip_prefix("stack traceback:\n").unwrap(),
best_trace.strip_prefix("stack traceback:\n").unwrap(),
stack_end
)
}
@ -247,6 +260,8 @@ pub fn pretty_format_luau_error(e: &LuaError) -> String {
let mut err_lines = err_string.lines().collect::<Vec<_>>();
// Remove the script path from the error message
// itself, it can be found in the stack trace
// FIXME: This no longer works now that we use
// an exact name when our lune script is loaded
if let Some(first_line) = err_lines.first() {
if first_line.starts_with("[string \"") {
if let Some(closing_bracket) = first_line.find("]:") {
@ -260,7 +275,6 @@ pub fn pretty_format_luau_error(e: &LuaError) -> String {
}
}
}
// Reformat stack trace lines, ignore lines that just mention C functions
// Merge all lines back together into one string
err_lines.join("\n")
}

View file

@ -1,4 +1,4 @@
local util = require("src/tests/net/request/util")
local util = require("./util")
local pass, fail = util.pass, util.fail
pass("GET", "https://httpbin.org/status/200", "Request status code - 200")

View file

@ -1,4 +1,4 @@
local util = require("src/tests/net/request/util")
local util = require("./util")
local pass = util.pass
-- stylua: ignore start

View file

@ -1,4 +1,4 @@
local util = require("src/tests/net/request/util")
local util = require("./util")
local pass = util.pass
pass("GET", "https://httpbin.org/absolute-redirect/3", "Redirect 3 times")

View file

@ -0,0 +1,4 @@
return {
Foo = "Bar",
Hello = "World",
}

View file

@ -0,0 +1,7 @@
local module = require("./modules/module")
assert(type(module) == "table", "Required module did not return a table")
assert(module.Foo == "Bar", "Required module did not contain correct values")
assert(module.Hello == "World", "Required module did not contain correct values")
require("modules/module")

View file

@ -0,0 +1,26 @@
local function test(path: string)
local success, message = pcall(function()
local _ = require(path) :: any
end)
if success then
error(string.format("Invalid require at path '%s' succeeded", path))
else
message = tostring(message)
if string.find(message, string.format("'%s'", path)) == nil then
error(
string.format(
"Invalid require did not mention path '%s' in its error message!\nMessage: %s",
path,
tostring(message)
)
)
end
end
end
test("foo")
test("bar")
test("moduuuuule")
test("modules.nested")
test(" modules ")
test("mod" .. string.char(127) .. "ules")

View file

@ -0,0 +1,4 @@
return {
Foo = "Bar",
Hello = "World",
}

View file

@ -0,0 +1,4 @@
return {
Foo = "Bar",
Hello = "World",
}

View file

@ -0,0 +1,4 @@
return {
Foo = "Bar",
Hello = "World",
}

View file

@ -0,0 +1 @@
return require("modules/module")

View file

@ -0,0 +1,7 @@
local module = require("./modules/nested")
assert(type(module) == "table", "Required module did not return a table")
assert(module.Foo == "Bar", "Required module did not contain correct values")
assert(module.Hello == "World", "Required module did not contain correct values")
require("modules/nested")

View file

@ -0,0 +1,5 @@
local module = require("../modules/module")
assert(type(module) == "table", "Required module did not return a table")
assert(module.Foo == "Bar", "Required module did not contain correct values")
assert(module.Hello == "World", "Required module did not contain correct values")

View file

@ -0,0 +1,11 @@
local module = require("./module")
assert(type(module) == "table", "Required module did not return a table")
assert(module.Foo == "Bar", "Required module did not contain correct values")
assert(module.Hello == "World", "Required module did not contain correct values")
require("./children")
require("./parents")
require("children")
require("parents")