From 58ce0463945800fe874bc119c1f4e654efa6fe9e Mon Sep 17 00:00:00 2001 From: Filip Tibell Date: Fri, 17 Feb 2023 15:03:13 +0100 Subject: [PATCH] Implement proper async require & error handling --- packages/lib/src/globals/require.rs | 163 +++++++++-------- packages/lib/src/globals/top_level.rs | 17 +- packages/lib/src/lib.rs | 3 +- packages/lib/src/lua/create.rs | 83 ++++++--- packages/lib/src/lua/ext.rs | 17 +- packages/lib/src/lua/task/scheduler.rs | 14 +- packages/lib/src/tests.rs | 13 +- packages/lib/src/utils/formatting.rs | 166 ++++++++++++++++-- tests/globals/pcall.luau | 39 ++++ .../{ => globals}/require/modules/module.luau | 0 .../{ => globals}/require/tests/children.luau | 0 tests/globals/require/tests/foo.lua | 0 .../{ => globals}/require/tests/invalid.luau | 4 +- tests/{ => globals}/require/tests/module.luau | 0 .../require/tests/modules/module.luau | 0 .../require/tests/modules/modules/module.luau | 0 .../require/tests/modules/nested.luau | 0 tests/{ => globals}/require/tests/nested.luau | 0 .../{ => globals}/require/tests/parents.luau | 0 .../{ => globals}/require/tests/siblings.luau | 0 tests/globals/type.luau | 11 ++ tests/globals/typeof.luau | 14 ++ 22 files changed, 412 insertions(+), 132 deletions(-) create mode 100644 tests/globals/pcall.luau rename tests/{ => globals}/require/modules/module.luau (100%) rename tests/{ => globals}/require/tests/children.luau (100%) create mode 100644 tests/globals/require/tests/foo.lua rename tests/{ => globals}/require/tests/invalid.luau (84%) rename tests/{ => globals}/require/tests/module.luau (100%) rename tests/{ => globals}/require/tests/modules/module.luau (100%) rename tests/{ => globals}/require/tests/modules/modules/module.luau (100%) rename tests/{ => globals}/require/tests/modules/nested.luau (100%) rename tests/{ => globals}/require/tests/nested.luau (100%) rename tests/{ => globals}/require/tests/parents.luau (100%) rename tests/{ => globals}/require/tests/siblings.luau (100%) create mode 100644 tests/globals/type.luau create mode 100644 tests/globals/typeof.luau diff --git a/packages/lib/src/globals/require.rs b/packages/lib/src/globals/require.rs index 6f39ef2..b974739 100644 --- a/packages/lib/src/globals/require.rs +++ b/packages/lib/src/globals/require.rs @@ -1,96 +1,107 @@ use std::{ env::{self, current_dir}, + io, path::PathBuf, - sync::Arc, }; use mlua::prelude::*; -use os_str_bytes::{OsStrBytes, RawOsStr}; +use tokio::fs; use crate::utils::table::TableBuilder; pub fn create(lua: &'static Lua) -> LuaResult { - // Preserve original require behavior if we have a special env var set + // Preserve original require behavior if we have a special env var set, + // returning an empty table since there are no globals to overwrite if env::var_os("LUAU_PWD_REQUIRE").is_some() { - // Return an empty table since there are no globals to overwrite return TableBuilder::new(lua)?.build_readonly(); } + // Store the current pwd, and make helper functions for path conversions + let require_pwd = current_dir()?.to_string_lossy().to_string(); + let require_info: LuaFunction = lua.named_registry_value("dbg.info")?; + let require_error: LuaFunction = lua.named_registry_value("error")?; + let require_get_abs_rel_paths = lua + .create_function( + |_, (require_pwd, require_source, require_path): (String, String, String)| { + let mut path_relative_to_pwd = PathBuf::from( + &require_source + .trim_start_matches("[string \"") + .trim_end_matches("\"]"), + ) + .parent() + .unwrap() + .join(require_path); + // Try to normalize and resolve relative path segments such as './' and '../' + if let Ok(canonicalized) = + path_relative_to_pwd.with_extension("luau").canonicalize() + { + path_relative_to_pwd = canonicalized; + } + if let Ok(canonicalized) = path_relative_to_pwd.with_extension("lua").canonicalize() + { + path_relative_to_pwd = canonicalized; + } + let absolute = path_relative_to_pwd.to_string_lossy().to_string(); + let relative = absolute.trim_start_matches(&require_pwd).to_string(); + Ok((absolute, relative)) + }, + )? + .bind(require_pwd)?; /* - Store the current working directory so that we can use it later - and remove it from require paths in error messages, showing - absolute paths is bad ux and we should try to avoid it + We need to get the source file where require was + called to be able to do path-relative requires, + so we make a small wrapper to do that here, this + will then call our actual async require function - Throughout this function we also take extra care to not perform any lossy - conversion and use os strings instead of Rust's utf-8 checked strings, - just in case someone out there uses luau with non-utf8 string requires + This must be done in lua because due to how our + scheduler works mlua can not preserve debug info */ - let pwd = lua.create_string(¤t_dir()?.to_raw_bytes())?; - lua.set_named_registry_value("pwd", pwd)?; - /* - Create a new function that fetches the file name from the current thread, - sets the luau module lookup path to be the exact script we are looking - for, and then runs the original require function with the wanted path - */ - let new_require = lua.create_function(|lua, require_path: LuaString| { - let require_pwd: LuaString = lua.named_registry_value("pwd")?; - let require_fn: LuaFunction = lua.named_registry_value("require")?; - let require_info: LuaFunction = lua.named_registry_value("dbg.info")?; - let require_source: LuaString = require_info.call((2, "s"))?; - /* - Combine the require caller source with the wanted path - string to get a final path relative to pwd - it is definitely - relative to pwd because Lune will only load files relative to pwd - */ - let raw_pwd_str = RawOsStr::assert_from_raw_bytes(require_pwd.as_bytes()); - let raw_source = RawOsStr::assert_from_raw_bytes(require_source.as_bytes()); - let raw_path = RawOsStr::assert_from_raw_bytes(require_path.as_bytes()); - let mut path_relative_to_pwd = PathBuf::from( - &raw_source - .trim_start_matches("[string \"") - .trim_end_matches("\"]") - .to_os_str(), + let require_env = TableBuilder::new(lua)? + .with_value("loaded", lua.create_table()?)? + .with_value("cache", lua.create_table()?)? + .with_value("info", require_info)? + .with_value("error", require_error)? + .with_value("paths", require_get_abs_rel_paths)? + .with_async_function("load", load_file)? + .build_readonly()?; + let require_fn_lua = lua + .load( + r#" + local source = info(2, "s") + local absolute, relative = paths(source, ...) + if loaded[absolute] ~= true then + local first, second = load(absolute, relative) + if first == nil or second ~= nil then + error("Module did not return exactly one value") + end + loaded[absolute] = true + cache[absolute] = first + return first + else + return cache[absolute] + end + "#, ) - .parent() - .unwrap() - .join(raw_path.to_os_str()); - // Try to normalize and resolve relative path segments such as './' and '../' - if let Ok(canonicalized) = path_relative_to_pwd.with_extension("luau").canonicalize() { - path_relative_to_pwd = canonicalized.with_extension(""); - } - if let Ok(canonicalized) = path_relative_to_pwd.with_extension("lua").canonicalize() { - path_relative_to_pwd = canonicalized.with_extension(""); - } - if let Ok(stripped) = path_relative_to_pwd.strip_prefix(&raw_pwd_str.to_os_str()) { - path_relative_to_pwd = stripped.to_path_buf(); - } - // Create a lossless lua string from the pathbuf and finally call require - let raw_path_str = RawOsStr::new(path_relative_to_pwd.as_os_str()); - let lua_path_str = lua.create_string(raw_path_str.as_raw_bytes()); - // If the require call errors then we should also replace - // the path in the error message to improve user experience - let result: LuaResult<_> = require_fn.call::<_, LuaValue>(lua_path_str); - match result { - Err(LuaError::CallbackError { traceback, cause }) => { - let before = format!( - "runtime error: cannot find '{}'", - path_relative_to_pwd.to_str().unwrap() - ); - let after = format!( - "Invalid require path '{}' ({})", - require_path.to_str().unwrap(), - path_relative_to_pwd.to_str().unwrap() - ); - let cause = Arc::new(LuaError::RuntimeError( - cause.to_string().replace(&before, &after), - )); - Err(LuaError::CallbackError { traceback, cause }) - } - Err(e) => Err(e), - Ok(result) => Ok(result), - } - })?; - // Override the original require global with our monkey-patched one + .set_name("require")? + .set_environment(require_env)? + .into_function()?; TableBuilder::new(lua)? - .with_value("require", new_require)? + .with_value("require", require_fn_lua)? .build_readonly() } + +async fn load_file( + lua: &Lua, + (path_absolute, path_relative): (String, String), +) -> LuaResult { + // Try to read the wanted file, note that we use bytes instead of reading + // to a string since lua scripts are not necessarily valid utf-8 strings + match fs::read(&path_absolute).await { + Ok(contents) => lua.load(&contents).set_name(path_relative)?.eval(), + Err(e) => match e.kind() { + io::ErrorKind::NotFound => Err(LuaError::RuntimeError(format!( + "No lua module exists at the path '{path_relative}'" + ))), + _ => Err(LuaError::external(e)), + }, + } +} diff --git a/packages/lib/src/globals/top_level.rs b/packages/lib/src/globals/top_level.rs index 16b3081..9eda825 100644 --- a/packages/lib/src/globals/top_level.rs +++ b/packages/lib/src/globals/top_level.rs @@ -37,16 +37,21 @@ pub fn create(lua: &'static Lua) -> LuaResult { })? .with_function("error", |lua, (arg, level): (LuaValue, Option)| { let error: LuaFunction = lua.named_registry_value("error")?; - let multi = arg.to_lua_multi(lua)?; + let trace: LuaFunction = lua.named_registry_value("dbg.trace")?; error.call(( - format!( - "{}\n{}", - format_label("error"), - pretty_format_multi_value(&multi)? - ), + LuaError::CallbackError { + traceback: format!("override traceback:{}", trace.call::<_, String>(())?), + cause: LuaError::external(format!( + "{}\n{}", + format_label("error"), + pretty_format_multi_value(&arg.to_lua_multi(lua)?)? + )) + .into(), + }, level, ))?; Ok(()) })? + // TODO: Add an override for tostring that formats errors in a nicer way .build_readonly() } diff --git a/packages/lib/src/lib.rs b/packages/lib/src/lib.rs index 8432347..bae32a8 100644 --- a/packages/lib/src/lib.rs +++ b/packages/lib/src/lib.rs @@ -118,9 +118,8 @@ impl Lune { let mut got_error = false; loop { let result = sched.resume_queue().await; - // println!("{result}"); if let Some(err) = result.get_lua_error() { - eprintln!("{}", pretty_format_luau_error(&err)); + eprintln!("{}", pretty_format_luau_error(&err, true)); got_error = true; } if result.is_done() { diff --git a/packages/lib/src/lua/create.rs b/packages/lib/src/lua/create.rs index 5379b24..bb1a18a 100644 --- a/packages/lib/src/lua/create.rs +++ b/packages/lib/src/lua/create.rs @@ -3,29 +3,38 @@ use mlua::prelude::*; /* - Level 0 is the call to info - Level 1 is the load call in create() below where we load this into a function - - Level 2 is the call to the scheduler, probably, but we can't know for sure so we start at 2 + - Level 2 is the call to the trace, which we also want to skip, so start at 3 + + Also note that we must match the mlua traceback format here so that we + can pattern match and beautify it properly later on when outputting it */ const TRACE_IMPL_LUA: &str = r#" local lines = {} -for level = 2, 2^8 do +for level = 3, 16 do + local parts = {} local source, line, name = info(level, "sln") if source then - if line then - if name and #name > 0 then - push(lines, format(" Script '%s', Line %d - function %s", source, line, name)) - else - push(lines, format(" Script '%s', Line %d", source, line)) - end - elseif name and #name > 0 then - push(lines, format(" Script '%s' - function %s", source, name)) - else - push(lines, format(" Script '%s'", source)) - end - elseif name then - push(lines, format("[Lune] - function %s", source, name)) + push(parts, source) else break end + if line == -1 then + line = nil + end + if name and #name <= 0 then + name = nil + end + if line then + push(parts, format("%d", line)) + end + if name and #parts > 1 then + push(parts, format(" in function '%s'", name)) + elseif name then + push(parts, format("in function '%s'", name)) + end + if #parts > 0 then + push(lines, concat(parts, ":")) + end end if #lines > 0 then return concat(lines, "\n") @@ -49,12 +58,20 @@ end * `"type"` -> `type` * `"typeof"` -> `typeof` --- + * `"pcall"` -> `pcall` + * `"xpcall"` -> `xpcall` + --- + * `"tostring"` -> `tostring` + * `"tonumber"` -> `tonumber` + --- * `"co.thread"` -> `coroutine.running` * `"co.yield"` -> `coroutine.yield` * `"co.close"` -> `coroutine.close` --- * `"dbg.info"` -> `debug.info` * `"dbg.trace"` -> `debug.traceback` + * `"dbg.iserr"` -> `` + * `"dbg.makeerr"` -> `` --- */ pub fn create() -> LuaResult<&'static Lua> { @@ -72,23 +89,43 @@ pub fn create() -> LuaResult<&'static Lua> { lua.set_named_registry_value("error", globals.get::<_, LuaFunction>("error")?)?; lua.set_named_registry_value("type", globals.get::<_, LuaFunction>("type")?)?; lua.set_named_registry_value("typeof", globals.get::<_, LuaFunction>("typeof")?)?; + lua.set_named_registry_value("xpcall", globals.get::<_, LuaFunction>("xpcall")?)?; + lua.set_named_registry_value("pcall", globals.get::<_, LuaFunction>("pcall")?)?; + lua.set_named_registry_value("tostring", globals.get::<_, LuaFunction>("tostring")?)?; + lua.set_named_registry_value("tonumber", globals.get::<_, LuaFunction>("tonumber")?)?; lua.set_named_registry_value("co.thread", coroutine.get::<_, LuaFunction>("running")?)?; lua.set_named_registry_value("co.yield", coroutine.get::<_, LuaFunction>("yield")?)?; lua.set_named_registry_value("co.close", coroutine.get::<_, LuaFunction>("close")?)?; lua.set_named_registry_value("dbg.info", debug.get::<_, LuaFunction>("info")?)?; + lua.set_named_registry_value("tab.pack", table.get::<_, LuaFunction>("pack")?)?; + lua.set_named_registry_value("tab.unpack", table.get::<_, LuaFunction>("unpack")?)?; + // Create a function that can be called from lua to check if a value is a mlua error, + // this will be used in async environments for proper error handling and throwing, as + // well as a function that can be called to make a callback error with a traceback from lua + let dbg_is_err_fn = + lua.create_function(move |_, value: LuaValue| Ok(matches!(value, LuaValue::Error(_))))?; + + let dbg_make_err_fn = lua.create_function(|_, (cause, traceback): (LuaError, String)| { + Ok(LuaError::CallbackError { + traceback, + cause: cause.into(), + }) + })?; // Create a trace function that can be called to obtain a full stack trace from // lua, this is not possible to do from rust when using our manual scheduler - let trace_env = lua.create_table_with_capacity(0, 1)?; - trace_env.set("info", debug.get::<_, LuaFunction>("info")?)?; - trace_env.set("push", table.get::<_, LuaFunction>("insert")?)?; - trace_env.set("concat", table.get::<_, LuaFunction>("concat")?)?; - trace_env.set("format", string.get::<_, LuaFunction>("format")?)?; - let trace_fn = lua + let dbg_trace_env = lua.create_table_with_capacity(0, 1)?; + dbg_trace_env.set("info", debug.get::<_, LuaFunction>("info")?)?; + dbg_trace_env.set("push", table.get::<_, LuaFunction>("insert")?)?; + dbg_trace_env.set("concat", table.get::<_, LuaFunction>("concat")?)?; + dbg_trace_env.set("format", string.get::<_, LuaFunction>("format")?)?; + let dbg_trace_fn = lua .load(TRACE_IMPL_LUA) .set_name("=dbg.trace")? - .set_environment(trace_env)? + .set_environment(dbg_trace_env)? .into_function()?; - lua.set_named_registry_value("dbg.trace", trace_fn)?; + lua.set_named_registry_value("dbg.trace", dbg_trace_fn)?; + lua.set_named_registry_value("dbg.iserr", dbg_is_err_fn)?; + lua.set_named_registry_value("dbg.makeerr", dbg_make_err_fn)?; // All done Ok(lua) } diff --git a/packages/lib/src/lua/ext.rs b/packages/lib/src/lua/ext.rs index 940e14e..a933ce0 100644 --- a/packages/lib/src/lua/ext.rs +++ b/packages/lib/src/lua/ext.rs @@ -26,9 +26,19 @@ impl LuaAsyncExt for &'static Lua { F: 'static + Fn(&'static Lua, A) -> FR, FR: 'static + Future>, { + let async_env_make_err: LuaFunction = self.named_registry_value("dbg.makeerr")?; + let async_env_is_err: LuaFunction = self.named_registry_value("dbg.iserr")?; + let async_env_trace: LuaFunction = self.named_registry_value("dbg.trace")?; + let async_env_error: LuaFunction = self.named_registry_value("error")?; + let async_env_unpack: LuaFunction = self.named_registry_value("tab.unpack")?; let async_env_thread: LuaFunction = self.named_registry_value("co.thread")?; let async_env_yield: LuaFunction = self.named_registry_value("co.yield")?; let async_env = TableBuilder::new(self)? + .with_value("makeError", async_env_make_err)? + .with_value("isError", async_env_is_err)? + .with_value("trace", async_env_trace)? + .with_value("error", async_env_error)? + .with_value("unpack", async_env_unpack)? .with_value("thread", async_env_thread)? .with_value("yield", async_env_yield)? .with_function( @@ -50,7 +60,12 @@ impl LuaAsyncExt for &'static Lua { .load( " resumeAsync(thread(), ...) - return yield() + local results = { yield() } + if isError(results[1]) then + error(makeError(results[1], trace())) + else + return unpack(results) + end ", ) .set_name("asyncWrapper")? diff --git a/packages/lib/src/lua/task/scheduler.rs b/packages/lib/src/lua/task/scheduler.rs index a38a3b0..b7de186 100644 --- a/packages/lib/src/lua/task/scheduler.rs +++ b/packages/lib/src/lua/task/scheduler.rs @@ -225,8 +225,18 @@ impl<'fut> TaskScheduler<'fut> { self.guid_running.set(Some(reference.id())); let rets = match args_opt_res { Some(args_res) => match args_res { - Err(err) => Err(err), // FIXME: We need to throw this error in lua to let pcall & friends handle it properly - Ok(args) => thread.resume::<_, LuaMultiValue>(args), + /* + HACK: Resuming with an error here only works because the Rust + functions that we register and that may return lua errors are + also error-aware and wrapped in a special wrapper that checks + if the returned value is a lua error userdata, then throws it + + Also note that this only happens for our custom async functions + that may pass errors as arguments when resuming tasks, other + native mlua functions will handle this and dont need wrapping + */ + Err(err) => thread.resume(err), + Ok(args) => thread.resume(args), }, None => thread.resume(()), }; diff --git a/packages/lib/src/tests.rs b/packages/lib/src/tests.rs index aa04bbb..5665e93 100644 --- a/packages/lib/src/tests.rs +++ b/packages/lib/src/tests.rs @@ -55,11 +55,14 @@ create_tests! { process_env: "process/env", process_exit: "process/exit", process_spawn: "process/spawn", - require_children: "require/tests/children", - require_invalid: "require/tests/invalid", - require_nested: "require/tests/nested", - require_parents: "require/tests/parents", - require_siblings: "require/tests/siblings", + require_children: "globals/require/tests/children", + require_invalid: "globals/require/tests/invalid", + require_nested: "globals/require/tests/nested", + require_parents: "globals/require/tests/parents", + require_siblings: "globals/require/tests/siblings", + global_pcall: "globals/pcall", + global_type: "globals/type", + global_typeof: "globals/typeof", stdio_format: "stdio/format", stdio_color: "stdio/color", stdio_style: "stdio/style", diff --git a/packages/lib/src/utils/formatting.rs b/packages/lib/src/utils/formatting.rs index e4ab5f7..cddf01c 100644 --- a/packages/lib/src/utils/formatting.rs +++ b/packages/lib/src/utils/formatting.rs @@ -1,6 +1,6 @@ use std::fmt::Write; -use console::{style, Style}; +use console::{colors_enabled, set_colors_enabled, style, Style}; use lazy_static::lazy_static; use mlua::prelude::*; @@ -178,7 +178,7 @@ pub fn pretty_format_value( } } LuaValue::LightUserData(_) => write!(buffer, "{}", COLOR_PURPLE.apply_to(""))?, - _ => write!(buffer, "{}", STYLE_DIM.apply_to("?"))?, + LuaValue::Error(e) => write!(buffer, "{}", pretty_format_luau_error(e, false),)?, } Ok(()) } @@ -200,7 +200,13 @@ pub fn pretty_format_multi_value(multi: &LuaMultiValue) -> LuaResult { Ok(buffer) } -pub fn pretty_format_luau_error(e: &LuaError) -> String { +pub fn pretty_format_luau_error(e: &LuaError, colorized: bool) -> String { + let previous_colors_enabled = if !colorized { + set_colors_enabled(false); + Some(colors_enabled()) + } else { + None + }; let stack_begin = format!("[{}]", COLOR_BLUE.apply_to("Stack Begin")); let stack_end = format!("[{}]", COLOR_BLUE.apply_to("Stack End")); let err_string = match e { @@ -218,23 +224,33 @@ pub fn pretty_format_luau_error(e: &LuaError) -> String { let mut found_stack_begin = false; for (index, line) in err_lines.clone().iter().enumerate().rev() { if *line == "stack traceback:" { - err_lines[index] = stack_begin; + err_lines[index] = stack_begin.clone(); found_stack_begin = true; break; } } // Add "Stack End" to the very end of the stack trace for symmetry if found_stack_begin { - err_lines.push(stack_end); + err_lines.push(stack_end.clone()); } err_lines.join("\n") } LuaError::CallbackError { traceback, cause } => { // Find the best traceback (most lines) and the root error message - let mut best_trace = traceback; + // The traceback may also start with "override traceback:" which + // means it was passed from somewhere that wants a custom trace, + // so we should then respect that and get the best override instead + let mut best_trace: &str = traceback; let mut root_cause = cause.as_ref(); + let mut trace_override = false; while let LuaError::CallbackError { cause, traceback } = root_cause { - if traceback.lines().count() > best_trace.len() { + let is_override = traceback.starts_with("override traceback:"); + if is_override { + if !trace_override || traceback.lines().count() > best_trace.len() { + best_trace = traceback.strip_prefix("override traceback:").unwrap(); + trace_override = true; + } + } else if !trace_override && traceback.lines().count() > best_trace.len() { best_trace = traceback; } root_cause = cause; @@ -242,15 +258,19 @@ pub fn pretty_format_luau_error(e: &LuaError) -> String { // If we got a runtime error with an embedded traceback, we should // use that instead since it generally contains more information if matches!(root_cause, LuaError::RuntimeError(e) if e.contains("stack traceback:")) { - pretty_format_luau_error(root_cause) + pretty_format_luau_error(root_cause, colorized) } else { // Otherwise we format whatever root error we got using // the same error formatting as for above runtime errors format!( "{}\n{}\n{}\n{}", - pretty_format_luau_error(root_cause), + pretty_format_luau_error(root_cause, colorized), stack_begin, - best_trace.strip_prefix("stack traceback:\n").unwrap(), + if best_trace.starts_with("stack traceback:") { + best_trace.strip_prefix("stack traceback:\n").unwrap() + } else { + best_trace + }, stack_end ) } @@ -269,11 +289,13 @@ pub fn pretty_format_luau_error(e: &LuaError) -> String { } e => format!("{e}"), }; - let mut err_lines = err_string.lines().collect::>(); + // Re-enable colors if they were previously enabled + if let Some(true) = previous_colors_enabled { + set_colors_enabled(true) + } // Remove the script path from the error message // itself, it can be found in the stack trace - // FIXME: This no longer works now that we use - // an exact name when our lune script is loaded + let mut err_lines = err_string.lines().collect::>(); if let Some(first_line) = err_lines.first() { if first_line.starts_with("[string \"") { if let Some(closing_bracket) = first_line.find("]:") { @@ -287,6 +309,120 @@ pub fn pretty_format_luau_error(e: &LuaError) -> String { } } } - // Merge all lines back together into one string - err_lines.join("\n") + // Find where the stack trace stars and ends + let stack_begin_idx = + err_lines.iter().enumerate().find_map( + |(i, line)| { + if *line == stack_begin { + Some(i) + } else { + None + } + }, + ); + let stack_end_idx = + err_lines.iter().enumerate().find_map( + |(i, line)| { + if *line == stack_end { + Some(i) + } else { + None + } + }, + ); + // If we have a stack trace, we should transform the formatting from the + // default mlua formatting into something more friendly, similar to Roblox + if let (Some(idx_start), Some(idx_end)) = (stack_begin_idx, stack_end_idx) { + let stack_lines = err_lines + .iter() + .enumerate() + // Filter out stack lines + .filter_map(|(idx, line)| { + if idx > idx_start && idx < idx_end { + Some(*line) + } else { + None + } + }) + // Transform from mlua format into friendly format, while also + // ensuring that leading whitespace / indentation is consistent + .map(transform_stack_line) + .collect::>(); + fix_error_nitpicks(format!( + "{}\n{}\n{}\n{}", + err_lines + .iter() + .take(idx_start) + .copied() + .collect::>() + .join("\n"), + stack_begin, + stack_lines.join("\n"), + stack_end, + )) + } else { + fix_error_nitpicks(err_string) + } +} + +fn transform_stack_line(line: &str) -> String { + match (line.find('['), line.find(']')) { + (Some(idx_start), Some(idx_end)) => { + let name = line[idx_start..idx_end + 1] + .trim_start_matches('[') + .trim_start_matches("string ") + .trim_start_matches('"') + .trim_end_matches(']') + .trim_end_matches('"'); + let after_name = &line[idx_end + 1..]; + let line_num = match after_name.find(':') { + Some(lineno_start) => match after_name[lineno_start + 1..].find(':') { + Some(lineno_end) => &after_name[lineno_start + 1..lineno_end + 1], + None => match after_name.contains("in function") { + false => &after_name[lineno_start + 1..], + true => "", + }, + }, + None => "", + }; + let func_name = match after_name.find("in function ") { + Some(func_start) => after_name[func_start + 12..] + .trim() + .trim_end_matches('\'') + .trim_start_matches('\'') + .trim_start_matches("_G."), + None => "", + }; + let mut result = String::new(); + write!( + result, + " Script '{}'", + match name { + "C" => "[C]", + name => name, + }, + ) + .unwrap(); + if !line_num.is_empty() { + write!(result, ", Line {line_num}").unwrap(); + } + if !func_name.is_empty() { + write!(result, " - function {func_name}").unwrap(); + } + result + } + (_, _) => line.to_string(), + } +} + +fn fix_error_nitpicks(full_message: String) -> String { + full_message + // Hacky fix for our custom require appearing as a normal script + .replace("'require', Line 5", "'[C]' - function require") + .replace("'require', Line 7", "'[C]' - function require") + // Fix error calls in custom script chunks coming through + .replace( + "'[C]' - function error\n Script '[C]' - function require", + "'[C]' - function require", + ) } diff --git a/tests/globals/pcall.luau b/tests/globals/pcall.luau new file mode 100644 index 0000000..7658899 --- /dev/null +++ b/tests/globals/pcall.luau @@ -0,0 +1,39 @@ +local function test(f, ...) + local success, message = pcall(f, ...) + assert(not success, "Function did not throw an error") + assert(type(message) == "userdata", "Pcall did not return a proper error") +end + +-- These are not async but should be pcallable + +test(error, "Test error", 2) + +-- Net request is async and will throw a DNS error here for the weird address + +test(net.request, "https://wxyz.google.com") + +-- Net serve is async and will throw an OS error when trying to serve twice on the same port + +local handle = net.serve(8080, function() + return "" +end) + +task.delay(0, function() + handle.stop() +end) + +test(net.serve, 8080, function() end) + +local function e() + task.spawn(function() + task.defer(function() + task.delay(0, function() + error({ + Hello = "World", + }) + end) + end) + end) +end + +task.defer(e) diff --git a/tests/require/modules/module.luau b/tests/globals/require/modules/module.luau similarity index 100% rename from tests/require/modules/module.luau rename to tests/globals/require/modules/module.luau diff --git a/tests/require/tests/children.luau b/tests/globals/require/tests/children.luau similarity index 100% rename from tests/require/tests/children.luau rename to tests/globals/require/tests/children.luau diff --git a/tests/globals/require/tests/foo.lua b/tests/globals/require/tests/foo.lua new file mode 100644 index 0000000..e69de29 diff --git a/tests/require/tests/invalid.luau b/tests/globals/require/tests/invalid.luau similarity index 84% rename from tests/require/tests/invalid.luau rename to tests/globals/require/tests/invalid.luau index 5c5c693..2dbaaab 100644 --- a/tests/require/tests/invalid.luau +++ b/tests/globals/require/tests/invalid.luau @@ -5,8 +5,8 @@ local function test(path: string) if success then error(string.format("Invalid require at path '%s' succeeded", path)) else - message = tostring(message) - if string.find(message, string.format("'%s'", path)) == nil then + print(message) + if string.find(message, string.format("%s'", path)) == nil then error( string.format( "Invalid require did not mention path '%s' in its error message!\nMessage: %s", diff --git a/tests/require/tests/module.luau b/tests/globals/require/tests/module.luau similarity index 100% rename from tests/require/tests/module.luau rename to tests/globals/require/tests/module.luau diff --git a/tests/require/tests/modules/module.luau b/tests/globals/require/tests/modules/module.luau similarity index 100% rename from tests/require/tests/modules/module.luau rename to tests/globals/require/tests/modules/module.luau diff --git a/tests/require/tests/modules/modules/module.luau b/tests/globals/require/tests/modules/modules/module.luau similarity index 100% rename from tests/require/tests/modules/modules/module.luau rename to tests/globals/require/tests/modules/modules/module.luau diff --git a/tests/require/tests/modules/nested.luau b/tests/globals/require/tests/modules/nested.luau similarity index 100% rename from tests/require/tests/modules/nested.luau rename to tests/globals/require/tests/modules/nested.luau diff --git a/tests/require/tests/nested.luau b/tests/globals/require/tests/nested.luau similarity index 100% rename from tests/require/tests/nested.luau rename to tests/globals/require/tests/nested.luau diff --git a/tests/require/tests/parents.luau b/tests/globals/require/tests/parents.luau similarity index 100% rename from tests/require/tests/parents.luau rename to tests/globals/require/tests/parents.luau diff --git a/tests/require/tests/siblings.luau b/tests/globals/require/tests/siblings.luau similarity index 100% rename from tests/require/tests/siblings.luau rename to tests/globals/require/tests/siblings.luau diff --git a/tests/globals/type.luau b/tests/globals/type.luau new file mode 100644 index 0000000..273b53c --- /dev/null +++ b/tests/globals/type.luau @@ -0,0 +1,11 @@ +local function f() end + +local thread1 = coroutine.create(f) +local thread2 = task.spawn(f) +local thread3 = task.defer(f) +local thread4 = task.delay(0, f) + +assert(type(thread1) == "thread", "Calling type() did not return 'thread' after coroutine.create") +assert(type(thread2) == "thread", "Calling type() did not return 'thread' after task.spawn") +assert(type(thread3) == "thread", "Calling type() did not return 'thread' after task.defer") +assert(type(thread4) == "thread", "Calling type() did not return 'thread' after delay") diff --git a/tests/globals/typeof.luau b/tests/globals/typeof.luau new file mode 100644 index 0000000..ac42c28 --- /dev/null +++ b/tests/globals/typeof.luau @@ -0,0 +1,14 @@ +local function f() end + +local thread1 = coroutine.create(f) +local thread2 = task.spawn(f) +local thread3 = task.defer(f) +local thread4 = task.delay(0, f) + +assert( + typeof(thread1) == "thread", + "Calling typeof() did not return 'thread' after coroutine.create" +) +assert(typeof(thread2) == "thread", "Calling typeof() did not return 'thread' after task.spawn") +assert(typeof(thread3) == "thread", "Calling typeof() did not return 'thread' after task.defer") +assert(typeof(thread4) == "thread", "Calling typeof() did not return 'thread' after delay")