Implement proper async require & error handling

This commit is contained in:
Filip Tibell 2023-02-17 15:03:13 +01:00
parent 7f17ab0063
commit 58ce046394
No known key found for this signature in database
22 changed files with 412 additions and 132 deletions

View file

@ -1,96 +1,107 @@
use std::{ use std::{
env::{self, current_dir}, env::{self, current_dir},
io,
path::PathBuf, path::PathBuf,
sync::Arc,
}; };
use mlua::prelude::*; use mlua::prelude::*;
use os_str_bytes::{OsStrBytes, RawOsStr}; use tokio::fs;
use crate::utils::table::TableBuilder; use crate::utils::table::TableBuilder;
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> { pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
// Preserve original require behavior if we have a special env var set // Preserve original require behavior if we have a special env var set,
// returning an empty table since there are no globals to overwrite
if env::var_os("LUAU_PWD_REQUIRE").is_some() { if env::var_os("LUAU_PWD_REQUIRE").is_some() {
// Return an empty table since there are no globals to overwrite
return TableBuilder::new(lua)?.build_readonly(); return TableBuilder::new(lua)?.build_readonly();
} }
/* // Store the current pwd, and make helper functions for path conversions
Store the current working directory so that we can use it later let require_pwd = current_dir()?.to_string_lossy().to_string();
and remove it from require paths in error messages, showing
absolute paths is bad ux and we should try to avoid it
Throughout this function we also take extra care to not perform any lossy
conversion and use os strings instead of Rust's utf-8 checked strings,
just in case someone out there uses luau with non-utf8 string requires
*/
let pwd = lua.create_string(&current_dir()?.to_raw_bytes())?;
lua.set_named_registry_value("pwd", pwd)?;
/*
Create a new function that fetches the file name from the current thread,
sets the luau module lookup path to be the exact script we are looking
for, and then runs the original require function with the wanted path
*/
let new_require = lua.create_function(|lua, require_path: LuaString| {
let require_pwd: LuaString = lua.named_registry_value("pwd")?;
let require_fn: LuaFunction = lua.named_registry_value("require")?;
let require_info: LuaFunction = lua.named_registry_value("dbg.info")?; let require_info: LuaFunction = lua.named_registry_value("dbg.info")?;
let require_source: LuaString = require_info.call((2, "s"))?; let require_error: LuaFunction = lua.named_registry_value("error")?;
/* let require_get_abs_rel_paths = lua
Combine the require caller source with the wanted path .create_function(
string to get a final path relative to pwd - it is definitely |_, (require_pwd, require_source, require_path): (String, String, String)| {
relative to pwd because Lune will only load files relative to pwd
*/
let raw_pwd_str = RawOsStr::assert_from_raw_bytes(require_pwd.as_bytes());
let raw_source = RawOsStr::assert_from_raw_bytes(require_source.as_bytes());
let raw_path = RawOsStr::assert_from_raw_bytes(require_path.as_bytes());
let mut path_relative_to_pwd = PathBuf::from( let mut path_relative_to_pwd = PathBuf::from(
&raw_source &require_source
.trim_start_matches("[string \"") .trim_start_matches("[string \"")
.trim_end_matches("\"]") .trim_end_matches("\"]"),
.to_os_str(),
) )
.parent() .parent()
.unwrap() .unwrap()
.join(raw_path.to_os_str()); .join(require_path);
// Try to normalize and resolve relative path segments such as './' and '../' // Try to normalize and resolve relative path segments such as './' and '../'
if let Ok(canonicalized) = path_relative_to_pwd.with_extension("luau").canonicalize() { if let Ok(canonicalized) =
path_relative_to_pwd = canonicalized.with_extension(""); path_relative_to_pwd.with_extension("luau").canonicalize()
{
path_relative_to_pwd = canonicalized;
} }
if let Ok(canonicalized) = path_relative_to_pwd.with_extension("lua").canonicalize() { if let Ok(canonicalized) = path_relative_to_pwd.with_extension("lua").canonicalize()
path_relative_to_pwd = canonicalized.with_extension(""); {
path_relative_to_pwd = canonicalized;
} }
if let Ok(stripped) = path_relative_to_pwd.strip_prefix(&raw_pwd_str.to_os_str()) { let absolute = path_relative_to_pwd.to_string_lossy().to_string();
path_relative_to_pwd = stripped.to_path_buf(); let relative = absolute.trim_start_matches(&require_pwd).to_string();
} Ok((absolute, relative))
// Create a lossless lua string from the pathbuf and finally call require },
let raw_path_str = RawOsStr::new(path_relative_to_pwd.as_os_str()); )?
let lua_path_str = lua.create_string(raw_path_str.as_raw_bytes()); .bind(require_pwd)?;
// If the require call errors then we should also replace /*
// the path in the error message to improve user experience We need to get the source file where require was
let result: LuaResult<_> = require_fn.call::<_, LuaValue>(lua_path_str); called to be able to do path-relative requires,
match result { so we make a small wrapper to do that here, this
Err(LuaError::CallbackError { traceback, cause }) => { will then call our actual async require function
let before = format!(
"runtime error: cannot find '{}'", This must be done in lua because due to how our
path_relative_to_pwd.to_str().unwrap() scheduler works mlua can not preserve debug info
); */
let after = format!( let require_env = TableBuilder::new(lua)?
"Invalid require path '{}' ({})", .with_value("loaded", lua.create_table()?)?
require_path.to_str().unwrap(), .with_value("cache", lua.create_table()?)?
path_relative_to_pwd.to_str().unwrap() .with_value("info", require_info)?
); .with_value("error", require_error)?
let cause = Arc::new(LuaError::RuntimeError( .with_value("paths", require_get_abs_rel_paths)?
cause.to_string().replace(&before, &after), .with_async_function("load", load_file)?
)); .build_readonly()?;
Err(LuaError::CallbackError { traceback, cause }) let require_fn_lua = lua
} .load(
Err(e) => Err(e), r#"
Ok(result) => Ok(result), local source = info(2, "s")
} local absolute, relative = paths(source, ...)
})?; if loaded[absolute] ~= true then
// Override the original require global with our monkey-patched one local first, second = load(absolute, relative)
if first == nil or second ~= nil then
error("Module did not return exactly one value")
end
loaded[absolute] = true
cache[absolute] = first
return first
else
return cache[absolute]
end
"#,
)
.set_name("require")?
.set_environment(require_env)?
.into_function()?;
TableBuilder::new(lua)? TableBuilder::new(lua)?
.with_value("require", new_require)? .with_value("require", require_fn_lua)?
.build_readonly() .build_readonly()
} }
async fn load_file(
lua: &Lua,
(path_absolute, path_relative): (String, String),
) -> LuaResult<LuaValue> {
// Try to read the wanted file, note that we use bytes instead of reading
// to a string since lua scripts are not necessarily valid utf-8 strings
match fs::read(&path_absolute).await {
Ok(contents) => lua.load(&contents).set_name(path_relative)?.eval(),
Err(e) => match e.kind() {
io::ErrorKind::NotFound => Err(LuaError::RuntimeError(format!(
"No lua module exists at the path '{path_relative}'"
))),
_ => Err(LuaError::external(e)),
},
}
}

View file

@ -37,16 +37,21 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
})? })?
.with_function("error", |lua, (arg, level): (LuaValue, Option<u32>)| { .with_function("error", |lua, (arg, level): (LuaValue, Option<u32>)| {
let error: LuaFunction = lua.named_registry_value("error")?; let error: LuaFunction = lua.named_registry_value("error")?;
let multi = arg.to_lua_multi(lua)?; let trace: LuaFunction = lua.named_registry_value("dbg.trace")?;
error.call(( error.call((
format!( LuaError::CallbackError {
traceback: format!("override traceback:{}", trace.call::<_, String>(())?),
cause: LuaError::external(format!(
"{}\n{}", "{}\n{}",
format_label("error"), format_label("error"),
pretty_format_multi_value(&multi)? pretty_format_multi_value(&arg.to_lua_multi(lua)?)?
), ))
.into(),
},
level, level,
))?; ))?;
Ok(()) Ok(())
})? })?
// TODO: Add an override for tostring that formats errors in a nicer way
.build_readonly() .build_readonly()
} }

View file

@ -118,9 +118,8 @@ impl Lune {
let mut got_error = false; let mut got_error = false;
loop { loop {
let result = sched.resume_queue().await; let result = sched.resume_queue().await;
// println!("{result}");
if let Some(err) = result.get_lua_error() { if let Some(err) = result.get_lua_error() {
eprintln!("{}", pretty_format_luau_error(&err)); eprintln!("{}", pretty_format_luau_error(&err, true));
got_error = true; got_error = true;
} }
if result.is_done() { if result.is_done() {

View file

@ -3,29 +3,38 @@ use mlua::prelude::*;
/* /*
- Level 0 is the call to info - Level 0 is the call to info
- Level 1 is the load call in create() below where we load this into a function - Level 1 is the load call in create() below where we load this into a function
- Level 2 is the call to the scheduler, probably, but we can't know for sure so we start at 2 - Level 2 is the call to the trace, which we also want to skip, so start at 3
Also note that we must match the mlua traceback format here so that we
can pattern match and beautify it properly later on when outputting it
*/ */
const TRACE_IMPL_LUA: &str = r#" const TRACE_IMPL_LUA: &str = r#"
local lines = {} local lines = {}
for level = 2, 2^8 do for level = 3, 16 do
local parts = {}
local source, line, name = info(level, "sln") local source, line, name = info(level, "sln")
if source then if source then
if line then push(parts, source)
if name and #name > 0 then
push(lines, format(" Script '%s', Line %d - function %s", source, line, name))
else
push(lines, format(" Script '%s', Line %d", source, line))
end
elseif name and #name > 0 then
push(lines, format(" Script '%s' - function %s", source, name))
else
push(lines, format(" Script '%s'", source))
end
elseif name then
push(lines, format("[Lune] - function %s", source, name))
else else
break break
end end
if line == -1 then
line = nil
end
if name and #name <= 0 then
name = nil
end
if line then
push(parts, format("%d", line))
end
if name and #parts > 1 then
push(parts, format(" in function '%s'", name))
elseif name then
push(parts, format("in function '%s'", name))
end
if #parts > 0 then
push(lines, concat(parts, ":"))
end
end end
if #lines > 0 then if #lines > 0 then
return concat(lines, "\n") return concat(lines, "\n")
@ -49,12 +58,20 @@ end
* `"type"` -> `type` * `"type"` -> `type`
* `"typeof"` -> `typeof` * `"typeof"` -> `typeof`
--- ---
* `"pcall"` -> `pcall`
* `"xpcall"` -> `xpcall`
---
* `"tostring"` -> `tostring`
* `"tonumber"` -> `tonumber`
---
* `"co.thread"` -> `coroutine.running` * `"co.thread"` -> `coroutine.running`
* `"co.yield"` -> `coroutine.yield` * `"co.yield"` -> `coroutine.yield`
* `"co.close"` -> `coroutine.close` * `"co.close"` -> `coroutine.close`
--- ---
* `"dbg.info"` -> `debug.info` * `"dbg.info"` -> `debug.info`
* `"dbg.trace"` -> `debug.traceback` * `"dbg.trace"` -> `debug.traceback`
* `"dbg.iserr"` -> `<custom function>`
* `"dbg.makeerr"` -> `<custom function>`
--- ---
*/ */
pub fn create() -> LuaResult<&'static Lua> { pub fn create() -> LuaResult<&'static Lua> {
@ -72,23 +89,43 @@ pub fn create() -> LuaResult<&'static Lua> {
lua.set_named_registry_value("error", globals.get::<_, LuaFunction>("error")?)?; lua.set_named_registry_value("error", globals.get::<_, LuaFunction>("error")?)?;
lua.set_named_registry_value("type", globals.get::<_, LuaFunction>("type")?)?; lua.set_named_registry_value("type", globals.get::<_, LuaFunction>("type")?)?;
lua.set_named_registry_value("typeof", globals.get::<_, LuaFunction>("typeof")?)?; lua.set_named_registry_value("typeof", globals.get::<_, LuaFunction>("typeof")?)?;
lua.set_named_registry_value("xpcall", globals.get::<_, LuaFunction>("xpcall")?)?;
lua.set_named_registry_value("pcall", globals.get::<_, LuaFunction>("pcall")?)?;
lua.set_named_registry_value("tostring", globals.get::<_, LuaFunction>("tostring")?)?;
lua.set_named_registry_value("tonumber", globals.get::<_, LuaFunction>("tonumber")?)?;
lua.set_named_registry_value("co.thread", coroutine.get::<_, LuaFunction>("running")?)?; lua.set_named_registry_value("co.thread", coroutine.get::<_, LuaFunction>("running")?)?;
lua.set_named_registry_value("co.yield", coroutine.get::<_, LuaFunction>("yield")?)?; lua.set_named_registry_value("co.yield", coroutine.get::<_, LuaFunction>("yield")?)?;
lua.set_named_registry_value("co.close", coroutine.get::<_, LuaFunction>("close")?)?; lua.set_named_registry_value("co.close", coroutine.get::<_, LuaFunction>("close")?)?;
lua.set_named_registry_value("dbg.info", debug.get::<_, LuaFunction>("info")?)?; lua.set_named_registry_value("dbg.info", debug.get::<_, LuaFunction>("info")?)?;
lua.set_named_registry_value("tab.pack", table.get::<_, LuaFunction>("pack")?)?;
lua.set_named_registry_value("tab.unpack", table.get::<_, LuaFunction>("unpack")?)?;
// Create a function that can be called from lua to check if a value is a mlua error,
// this will be used in async environments for proper error handling and throwing, as
// well as a function that can be called to make a callback error with a traceback from lua
let dbg_is_err_fn =
lua.create_function(move |_, value: LuaValue| Ok(matches!(value, LuaValue::Error(_))))?;
let dbg_make_err_fn = lua.create_function(|_, (cause, traceback): (LuaError, String)| {
Ok(LuaError::CallbackError {
traceback,
cause: cause.into(),
})
})?;
// Create a trace function that can be called to obtain a full stack trace from // Create a trace function that can be called to obtain a full stack trace from
// lua, this is not possible to do from rust when using our manual scheduler // lua, this is not possible to do from rust when using our manual scheduler
let trace_env = lua.create_table_with_capacity(0, 1)?; let dbg_trace_env = lua.create_table_with_capacity(0, 1)?;
trace_env.set("info", debug.get::<_, LuaFunction>("info")?)?; dbg_trace_env.set("info", debug.get::<_, LuaFunction>("info")?)?;
trace_env.set("push", table.get::<_, LuaFunction>("insert")?)?; dbg_trace_env.set("push", table.get::<_, LuaFunction>("insert")?)?;
trace_env.set("concat", table.get::<_, LuaFunction>("concat")?)?; dbg_trace_env.set("concat", table.get::<_, LuaFunction>("concat")?)?;
trace_env.set("format", string.get::<_, LuaFunction>("format")?)?; dbg_trace_env.set("format", string.get::<_, LuaFunction>("format")?)?;
let trace_fn = lua let dbg_trace_fn = lua
.load(TRACE_IMPL_LUA) .load(TRACE_IMPL_LUA)
.set_name("=dbg.trace")? .set_name("=dbg.trace")?
.set_environment(trace_env)? .set_environment(dbg_trace_env)?
.into_function()?; .into_function()?;
lua.set_named_registry_value("dbg.trace", trace_fn)?; lua.set_named_registry_value("dbg.trace", dbg_trace_fn)?;
lua.set_named_registry_value("dbg.iserr", dbg_is_err_fn)?;
lua.set_named_registry_value("dbg.makeerr", dbg_make_err_fn)?;
// All done // All done
Ok(lua) Ok(lua)
} }

View file

@ -26,9 +26,19 @@ impl LuaAsyncExt for &'static Lua {
F: 'static + Fn(&'static Lua, A) -> FR, F: 'static + Fn(&'static Lua, A) -> FR,
FR: 'static + Future<Output = LuaResult<R>>, FR: 'static + Future<Output = LuaResult<R>>,
{ {
let async_env_make_err: LuaFunction = self.named_registry_value("dbg.makeerr")?;
let async_env_is_err: LuaFunction = self.named_registry_value("dbg.iserr")?;
let async_env_trace: LuaFunction = self.named_registry_value("dbg.trace")?;
let async_env_error: LuaFunction = self.named_registry_value("error")?;
let async_env_unpack: LuaFunction = self.named_registry_value("tab.unpack")?;
let async_env_thread: LuaFunction = self.named_registry_value("co.thread")?; let async_env_thread: LuaFunction = self.named_registry_value("co.thread")?;
let async_env_yield: LuaFunction = self.named_registry_value("co.yield")?; let async_env_yield: LuaFunction = self.named_registry_value("co.yield")?;
let async_env = TableBuilder::new(self)? let async_env = TableBuilder::new(self)?
.with_value("makeError", async_env_make_err)?
.with_value("isError", async_env_is_err)?
.with_value("trace", async_env_trace)?
.with_value("error", async_env_error)?
.with_value("unpack", async_env_unpack)?
.with_value("thread", async_env_thread)? .with_value("thread", async_env_thread)?
.with_value("yield", async_env_yield)? .with_value("yield", async_env_yield)?
.with_function( .with_function(
@ -50,7 +60,12 @@ impl LuaAsyncExt for &'static Lua {
.load( .load(
" "
resumeAsync(thread(), ...) resumeAsync(thread(), ...)
return yield() local results = { yield() }
if isError(results[1]) then
error(makeError(results[1], trace()))
else
return unpack(results)
end
", ",
) )
.set_name("asyncWrapper")? .set_name("asyncWrapper")?

View file

@ -225,8 +225,18 @@ impl<'fut> TaskScheduler<'fut> {
self.guid_running.set(Some(reference.id())); self.guid_running.set(Some(reference.id()));
let rets = match args_opt_res { let rets = match args_opt_res {
Some(args_res) => match args_res { Some(args_res) => match args_res {
Err(err) => Err(err), // FIXME: We need to throw this error in lua to let pcall & friends handle it properly /*
Ok(args) => thread.resume::<_, LuaMultiValue>(args), HACK: Resuming with an error here only works because the Rust
functions that we register and that may return lua errors are
also error-aware and wrapped in a special wrapper that checks
if the returned value is a lua error userdata, then throws it
Also note that this only happens for our custom async functions
that may pass errors as arguments when resuming tasks, other
native mlua functions will handle this and dont need wrapping
*/
Err(err) => thread.resume(err),
Ok(args) => thread.resume(args),
}, },
None => thread.resume(()), None => thread.resume(()),
}; };

View file

@ -55,11 +55,14 @@ create_tests! {
process_env: "process/env", process_env: "process/env",
process_exit: "process/exit", process_exit: "process/exit",
process_spawn: "process/spawn", process_spawn: "process/spawn",
require_children: "require/tests/children", require_children: "globals/require/tests/children",
require_invalid: "require/tests/invalid", require_invalid: "globals/require/tests/invalid",
require_nested: "require/tests/nested", require_nested: "globals/require/tests/nested",
require_parents: "require/tests/parents", require_parents: "globals/require/tests/parents",
require_siblings: "require/tests/siblings", require_siblings: "globals/require/tests/siblings",
global_pcall: "globals/pcall",
global_type: "globals/type",
global_typeof: "globals/typeof",
stdio_format: "stdio/format", stdio_format: "stdio/format",
stdio_color: "stdio/color", stdio_color: "stdio/color",
stdio_style: "stdio/style", stdio_style: "stdio/style",

View file

@ -1,6 +1,6 @@
use std::fmt::Write; use std::fmt::Write;
use console::{style, Style}; use console::{colors_enabled, set_colors_enabled, style, Style};
use lazy_static::lazy_static; use lazy_static::lazy_static;
use mlua::prelude::*; use mlua::prelude::*;
@ -178,7 +178,7 @@ pub fn pretty_format_value(
} }
} }
LuaValue::LightUserData(_) => write!(buffer, "{}", COLOR_PURPLE.apply_to("<userdata>"))?, LuaValue::LightUserData(_) => write!(buffer, "{}", COLOR_PURPLE.apply_to("<userdata>"))?,
_ => write!(buffer, "{}", STYLE_DIM.apply_to("?"))?, LuaValue::Error(e) => write!(buffer, "{}", pretty_format_luau_error(e, false),)?,
} }
Ok(()) Ok(())
} }
@ -200,7 +200,13 @@ pub fn pretty_format_multi_value(multi: &LuaMultiValue) -> LuaResult<String> {
Ok(buffer) Ok(buffer)
} }
pub fn pretty_format_luau_error(e: &LuaError) -> String { pub fn pretty_format_luau_error(e: &LuaError, colorized: bool) -> String {
let previous_colors_enabled = if !colorized {
set_colors_enabled(false);
Some(colors_enabled())
} else {
None
};
let stack_begin = format!("[{}]", COLOR_BLUE.apply_to("Stack Begin")); let stack_begin = format!("[{}]", COLOR_BLUE.apply_to("Stack Begin"));
let stack_end = format!("[{}]", COLOR_BLUE.apply_to("Stack End")); let stack_end = format!("[{}]", COLOR_BLUE.apply_to("Stack End"));
let err_string = match e { let err_string = match e {
@ -218,23 +224,33 @@ pub fn pretty_format_luau_error(e: &LuaError) -> String {
let mut found_stack_begin = false; let mut found_stack_begin = false;
for (index, line) in err_lines.clone().iter().enumerate().rev() { for (index, line) in err_lines.clone().iter().enumerate().rev() {
if *line == "stack traceback:" { if *line == "stack traceback:" {
err_lines[index] = stack_begin; err_lines[index] = stack_begin.clone();
found_stack_begin = true; found_stack_begin = true;
break; break;
} }
} }
// Add "Stack End" to the very end of the stack trace for symmetry // Add "Stack End" to the very end of the stack trace for symmetry
if found_stack_begin { if found_stack_begin {
err_lines.push(stack_end); err_lines.push(stack_end.clone());
} }
err_lines.join("\n") err_lines.join("\n")
} }
LuaError::CallbackError { traceback, cause } => { LuaError::CallbackError { traceback, cause } => {
// Find the best traceback (most lines) and the root error message // Find the best traceback (most lines) and the root error message
let mut best_trace = traceback; // The traceback may also start with "override traceback:" which
// means it was passed from somewhere that wants a custom trace,
// so we should then respect that and get the best override instead
let mut best_trace: &str = traceback;
let mut root_cause = cause.as_ref(); let mut root_cause = cause.as_ref();
let mut trace_override = false;
while let LuaError::CallbackError { cause, traceback } = root_cause { while let LuaError::CallbackError { cause, traceback } = root_cause {
if traceback.lines().count() > best_trace.len() { let is_override = traceback.starts_with("override traceback:");
if is_override {
if !trace_override || traceback.lines().count() > best_trace.len() {
best_trace = traceback.strip_prefix("override traceback:").unwrap();
trace_override = true;
}
} else if !trace_override && traceback.lines().count() > best_trace.len() {
best_trace = traceback; best_trace = traceback;
} }
root_cause = cause; root_cause = cause;
@ -242,15 +258,19 @@ pub fn pretty_format_luau_error(e: &LuaError) -> String {
// If we got a runtime error with an embedded traceback, we should // If we got a runtime error with an embedded traceback, we should
// use that instead since it generally contains more information // use that instead since it generally contains more information
if matches!(root_cause, LuaError::RuntimeError(e) if e.contains("stack traceback:")) { if matches!(root_cause, LuaError::RuntimeError(e) if e.contains("stack traceback:")) {
pretty_format_luau_error(root_cause) pretty_format_luau_error(root_cause, colorized)
} else { } else {
// Otherwise we format whatever root error we got using // Otherwise we format whatever root error we got using
// the same error formatting as for above runtime errors // the same error formatting as for above runtime errors
format!( format!(
"{}\n{}\n{}\n{}", "{}\n{}\n{}\n{}",
pretty_format_luau_error(root_cause), pretty_format_luau_error(root_cause, colorized),
stack_begin, stack_begin,
best_trace.strip_prefix("stack traceback:\n").unwrap(), if best_trace.starts_with("stack traceback:") {
best_trace.strip_prefix("stack traceback:\n").unwrap()
} else {
best_trace
},
stack_end stack_end
) )
} }
@ -269,11 +289,13 @@ pub fn pretty_format_luau_error(e: &LuaError) -> String {
} }
e => format!("{e}"), e => format!("{e}"),
}; };
let mut err_lines = err_string.lines().collect::<Vec<_>>(); // Re-enable colors if they were previously enabled
if let Some(true) = previous_colors_enabled {
set_colors_enabled(true)
}
// Remove the script path from the error message // Remove the script path from the error message
// itself, it can be found in the stack trace // itself, it can be found in the stack trace
// FIXME: This no longer works now that we use let mut err_lines = err_string.lines().collect::<Vec<_>>();
// an exact name when our lune script is loaded
if let Some(first_line) = err_lines.first() { if let Some(first_line) = err_lines.first() {
if first_line.starts_with("[string \"") { if first_line.starts_with("[string \"") {
if let Some(closing_bracket) = first_line.find("]:") { if let Some(closing_bracket) = first_line.find("]:") {
@ -287,6 +309,120 @@ pub fn pretty_format_luau_error(e: &LuaError) -> String {
} }
} }
} }
// Merge all lines back together into one string // Find where the stack trace stars and ends
err_lines.join("\n") let stack_begin_idx =
err_lines.iter().enumerate().find_map(
|(i, line)| {
if *line == stack_begin {
Some(i)
} else {
None
}
},
);
let stack_end_idx =
err_lines.iter().enumerate().find_map(
|(i, line)| {
if *line == stack_end {
Some(i)
} else {
None
}
},
);
// If we have a stack trace, we should transform the formatting from the
// default mlua formatting into something more friendly, similar to Roblox
if let (Some(idx_start), Some(idx_end)) = (stack_begin_idx, stack_end_idx) {
let stack_lines = err_lines
.iter()
.enumerate()
// Filter out stack lines
.filter_map(|(idx, line)| {
if idx > idx_start && idx < idx_end {
Some(*line)
} else {
None
}
})
// Transform from mlua format into friendly format, while also
// ensuring that leading whitespace / indentation is consistent
.map(transform_stack_line)
.collect::<Vec<_>>();
fix_error_nitpicks(format!(
"{}\n{}\n{}\n{}",
err_lines
.iter()
.take(idx_start)
.copied()
.collect::<Vec<_>>()
.join("\n"),
stack_begin,
stack_lines.join("\n"),
stack_end,
))
} else {
fix_error_nitpicks(err_string)
}
}
fn transform_stack_line(line: &str) -> String {
match (line.find('['), line.find(']')) {
(Some(idx_start), Some(idx_end)) => {
let name = line[idx_start..idx_end + 1]
.trim_start_matches('[')
.trim_start_matches("string ")
.trim_start_matches('"')
.trim_end_matches(']')
.trim_end_matches('"');
let after_name = &line[idx_end + 1..];
let line_num = match after_name.find(':') {
Some(lineno_start) => match after_name[lineno_start + 1..].find(':') {
Some(lineno_end) => &after_name[lineno_start + 1..lineno_end + 1],
None => match after_name.contains("in function") {
false => &after_name[lineno_start + 1..],
true => "",
},
},
None => "",
};
let func_name = match after_name.find("in function ") {
Some(func_start) => after_name[func_start + 12..]
.trim()
.trim_end_matches('\'')
.trim_start_matches('\'')
.trim_start_matches("_G."),
None => "",
};
let mut result = String::new();
write!(
result,
" Script '{}'",
match name {
"C" => "[C]",
name => name,
},
)
.unwrap();
if !line_num.is_empty() {
write!(result, ", Line {line_num}").unwrap();
}
if !func_name.is_empty() {
write!(result, " - function {func_name}").unwrap();
}
result
}
(_, _) => line.to_string(),
}
}
fn fix_error_nitpicks(full_message: String) -> String {
full_message
// Hacky fix for our custom require appearing as a normal script
.replace("'require', Line 5", "'[C]' - function require")
.replace("'require', Line 7", "'[C]' - function require")
// Fix error calls in custom script chunks coming through
.replace(
"'[C]' - function error\n Script '[C]' - function require",
"'[C]' - function require",
)
} }

39
tests/globals/pcall.luau Normal file
View file

@ -0,0 +1,39 @@
local function test(f, ...)
local success, message = pcall(f, ...)
assert(not success, "Function did not throw an error")
assert(type(message) == "userdata", "Pcall did not return a proper error")
end
-- These are not async but should be pcallable
test(error, "Test error", 2)
-- Net request is async and will throw a DNS error here for the weird address
test(net.request, "https://wxyz.google.com")
-- Net serve is async and will throw an OS error when trying to serve twice on the same port
local handle = net.serve(8080, function()
return ""
end)
task.delay(0, function()
handle.stop()
end)
test(net.serve, 8080, function() end)
local function e()
task.spawn(function()
task.defer(function()
task.delay(0, function()
error({
Hello = "World",
})
end)
end)
end)
end
task.defer(e)

View file

View file

@ -5,8 +5,8 @@ local function test(path: string)
if success then if success then
error(string.format("Invalid require at path '%s' succeeded", path)) error(string.format("Invalid require at path '%s' succeeded", path))
else else
message = tostring(message) print(message)
if string.find(message, string.format("'%s'", path)) == nil then if string.find(message, string.format("%s'", path)) == nil then
error( error(
string.format( string.format(
"Invalid require did not mention path '%s' in its error message!\nMessage: %s", "Invalid require did not mention path '%s' in its error message!\nMessage: %s",

11
tests/globals/type.luau Normal file
View file

@ -0,0 +1,11 @@
local function f() end
local thread1 = coroutine.create(f)
local thread2 = task.spawn(f)
local thread3 = task.defer(f)
local thread4 = task.delay(0, f)
assert(type(thread1) == "thread", "Calling type() did not return 'thread' after coroutine.create")
assert(type(thread2) == "thread", "Calling type() did not return 'thread' after task.spawn")
assert(type(thread3) == "thread", "Calling type() did not return 'thread' after task.defer")
assert(type(thread4) == "thread", "Calling type() did not return 'thread' after delay")

14
tests/globals/typeof.luau Normal file
View file

@ -0,0 +1,14 @@
local function f() end
local thread1 = coroutine.create(f)
local thread2 = task.spawn(f)
local thread3 = task.defer(f)
local thread4 = task.delay(0, f)
assert(
typeof(thread1) == "thread",
"Calling typeof() did not return 'thread' after coroutine.create"
)
assert(typeof(thread2) == "thread", "Calling typeof() did not return 'thread' after task.spawn")
assert(typeof(thread3) == "thread", "Calling typeof() did not return 'thread' after task.defer")
assert(typeof(thread4) == "thread", "Calling typeof() did not return 'thread' after delay")