diff --git a/packages/cli/src/cli.rs b/packages/cli/src/cli.rs index 897e1b1..623dc7c 100644 --- a/packages/cli/src/cli.rs +++ b/packages/cli/src/cli.rs @@ -159,7 +159,7 @@ impl Cli { // Display the file path relative to cwd with no extensions in stack traces let file_display_name = file_path.with_extension("").display().to_string(); // Create a new lune object with all globals & run the script - let lune = Lune::new().with_all_globals_and_args(self.script_args); + let lune = Lune::new().with_args(self.script_args); let result = lune.run(&file_display_name, &file_contents).await; Ok(match result { Err(e) => { diff --git a/packages/lib/src/globals/require.rs b/packages/lib/src/globals/require.rs index b974739..938b4f3 100644 --- a/packages/lib/src/globals/require.rs +++ b/packages/lib/src/globals/require.rs @@ -1,11 +1,10 @@ use std::{ env::{self, current_dir}, - io, + fs, path::PathBuf, }; use mlua::prelude::*; -use tokio::fs; use crate::utils::table::TableBuilder; @@ -15,37 +14,63 @@ pub fn create(lua: &'static Lua) -> LuaResult { if env::var_os("LUAU_PWD_REQUIRE").is_some() { return TableBuilder::new(lua)?.build_readonly(); } - // Store the current pwd, and make helper functions for path conversions - let require_pwd = current_dir()?.to_string_lossy().to_string(); + // Store the current pwd, and make the functions for path conversions & loading a file + let mut require_pwd = current_dir()?.to_string_lossy().to_string(); + if !require_pwd.ends_with('/') { + require_pwd = format!("{require_pwd}/") + } let require_info: LuaFunction = lua.named_registry_value("dbg.info")?; let require_error: LuaFunction = lua.named_registry_value("error")?; let require_get_abs_rel_paths = lua .create_function( |_, (require_pwd, require_source, require_path): (String, String, String)| { - let mut path_relative_to_pwd = PathBuf::from( + let path_relative_to_pwd = PathBuf::from( &require_source .trim_start_matches("[string \"") .trim_end_matches("\"]"), ) .parent() .unwrap() - .join(require_path); + .join(&require_path); // Try to normalize and resolve relative path segments such as './' and '../' - if let Ok(canonicalized) = - path_relative_to_pwd.with_extension("luau").canonicalize() - { - path_relative_to_pwd = canonicalized; - } - if let Ok(canonicalized) = path_relative_to_pwd.with_extension("lua").canonicalize() - { - path_relative_to_pwd = canonicalized; - } - let absolute = path_relative_to_pwd.to_string_lossy().to_string(); + let file_path = match ( + path_relative_to_pwd.with_extension("luau").canonicalize(), + path_relative_to_pwd.with_extension("lua").canonicalize(), + ) { + (Ok(luau), _) => luau, + (_, Ok(lua)) => lua, + _ => { + return Err(LuaError::RuntimeError(format!( + "File does not exist at path '{require_path}'" + ))) + } + }; + let absolute = file_path.to_string_lossy().to_string(); let relative = absolute.trim_start_matches(&require_pwd).to_string(); Ok((absolute, relative)) }, )? .bind(require_pwd)?; + // Note that file loading must be blocking to guarantee the require cache works, if it + // were async then one lua script may require a module during the file reading process + let require_get_loaded_file = lua.create_function( + |lua: &Lua, (path_absolute, path_relative): (String, String)| { + // Use a name without extensions for loading the chunk, the + // above code assumes the require path is without extensions + let path_relative_no_extension = path_relative + .trim_end_matches(".lua") + .trim_end_matches(".luau"); + // Try to read the wanted file, note that we use bytes instead of reading + // to a string since lua scripts are not necessarily valid utf-8 strings + match fs::read(path_absolute) { + Ok(contents) => lua + .load(&contents) + .set_name(path_relative_no_extension)? + .eval::(), + Err(e) => Err(LuaError::external(e)), + } + }, + )?; /* We need to get the source file where require was called to be able to do path-relative requires, @@ -61,12 +86,15 @@ pub fn create(lua: &'static Lua) -> LuaResult { .with_value("info", require_info)? .with_value("error", require_error)? .with_value("paths", require_get_abs_rel_paths)? - .with_async_function("load", load_file)? + .with_value("load", require_get_loaded_file)? .build_readonly()?; let require_fn_lua = lua .load( r#" - local source = info(2, "s") + local source = info(1, "s") + if source == '[string "require"]' then + source = info(2, "s") + end local absolute, relative = paths(source, ...) if loaded[absolute] ~= true then local first, second = load(absolute, relative) @@ -88,20 +116,3 @@ pub fn create(lua: &'static Lua) -> LuaResult { .with_value("require", require_fn_lua)? .build_readonly() } - -async fn load_file( - lua: &Lua, - (path_absolute, path_relative): (String, String), -) -> LuaResult { - // Try to read the wanted file, note that we use bytes instead of reading - // to a string since lua scripts are not necessarily valid utf-8 strings - match fs::read(&path_absolute).await { - Ok(contents) => lua.load(&contents).set_name(path_relative)?.eval(), - Err(e) => match e.kind() { - io::ErrorKind::NotFound => Err(LuaError::RuntimeError(format!( - "No lua module exists at the path '{path_relative}'" - ))), - _ => Err(LuaError::external(e)), - }, - } -} diff --git a/packages/lib/src/globals/task.rs b/packages/lib/src/globals/task.rs index f2f3bff..fafa398 100644 --- a/packages/lib/src/globals/task.rs +++ b/packages/lib/src/globals/task.rs @@ -1,10 +1,13 @@ -use std::time::Duration; - use mlua::prelude::*; -use tokio::time::{sleep, Instant}; use crate::{ - lua::task::{TaskKind, TaskReference, TaskScheduler, TaskSchedulerScheduleExt}, + lua::{ + async_ext::LuaAsyncExt, + task::{ + LuaThreadOrFunction, LuaThreadOrTaskReference, TaskKind, TaskReference, TaskScheduler, + TaskSchedulerScheduleExt, + }, + }, utils::table::TableBuilder, }; @@ -22,7 +25,6 @@ pub fn create(lua: &'static Lua) -> LuaResult> { we need to yield right away to allow the spawned task to run until first yield */ - let task_spawn_env_thread: LuaFunction = lua.named_registry_value("co.thread")?; let task_spawn_env_yield: LuaFunction = lua.named_registry_value("co.yield")?; let task_spawn = lua .load( @@ -33,10 +35,10 @@ pub fn create(lua: &'static Lua) -> LuaResult> { return task ", ) - .set_name("=task.spawn")? + .set_name("task.spawn")? .set_environment( TableBuilder::new(lua)? - .with_value("thread", task_spawn_env_thread)? + .with_function("thread", |lua, _: ()| Ok(lua.current_thread()))? .with_value("yield", task_spawn_env_yield)? .with_function( "scheduleNext", @@ -63,83 +65,14 @@ pub fn create(lua: &'static Lua) -> LuaResult> { coroutine.set("wrap", lua.create_function(coroutine_wrap)?)?; // All good, return the task scheduler lib TableBuilder::new(lua)? + .with_value("wait", lua.create_waiter_function()?)? .with_value("spawn", task_spawn)? .with_function("cancel", task_cancel)? .with_function("defer", task_defer)? .with_function("delay", task_delay)? - .with_async_function("wait", task_wait)? .build_readonly() } -/* - Proxy enum to deal with both threads & functions -*/ - -enum LuaThreadOrFunction<'lua> { - Thread(LuaThread<'lua>), - Function(LuaFunction<'lua>), -} - -impl<'lua> LuaThreadOrFunction<'lua> { - fn into_thread(self, lua: &'lua Lua) -> LuaResult> { - match self { - Self::Thread(t) => Ok(t), - Self::Function(f) => lua.create_thread(f), - } - } -} - -impl<'lua> FromLua<'lua> for LuaThreadOrFunction<'lua> { - fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult { - match value { - LuaValue::Thread(t) => Ok(Self::Thread(t)), - LuaValue::Function(f) => Ok(Self::Function(f)), - value => Err(LuaError::FromLuaConversionError { - from: value.type_name(), - to: "LuaThreadOrFunction", - message: Some(format!( - "Expected thread or function, got '{}'", - value.type_name() - )), - }), - } - } -} - -/* - Proxy enum to deal with both threads & task scheduler task references -*/ - -enum LuaThreadOrTaskReference<'lua> { - Thread(LuaThread<'lua>), - TaskReference(TaskReference), -} - -impl<'lua> FromLua<'lua> for LuaThreadOrTaskReference<'lua> { - fn from_lua(value: LuaValue<'lua>, lua: &'lua Lua) -> LuaResult { - let tname = value.type_name(); - match value { - LuaValue::Thread(t) => Ok(Self::Thread(t)), - LuaValue::UserData(u) => { - if let Ok(task) = TaskReference::from_lua(LuaValue::UserData(u), lua) { - Ok(Self::TaskReference(task)) - } else { - Err(LuaError::FromLuaConversionError { - from: tname, - to: "thread", - message: Some(format!("Expected thread, got '{tname}'")), - }) - } - } - _ => Err(LuaError::FromLuaConversionError { - from: tname, - to: "thread", - message: Some(format!("Expected thread, got '{tname}'")), - }), - } - } -} - /* Basic task functions */ @@ -166,12 +99,6 @@ fn task_delay( sched.schedule_blocking_after_seconds(secs, tof.into_thread(lua)?, args) } -async fn task_wait(_: &Lua, secs: Option) -> LuaResult { - let start = Instant::now(); - sleep(Duration::from_secs_f64(secs.unwrap_or_default())).await; - Ok(start.elapsed().as_secs_f64()) -} - /* Type getter overrides for compat with task scheduler */ @@ -207,7 +134,7 @@ fn coroutine_resume<'lua>( match value { LuaThreadOrTaskReference::Thread(t) => { let sched = lua.app_data_ref::<&TaskScheduler>().unwrap(); - let task = sched.create_task(TaskKind::Instant, t, None, None)?; + let task = sched.create_task(TaskKind::Instant, t, None, true)?; sched.resume_task(task, None) } LuaThreadOrTaskReference::TaskReference(t) => lua @@ -222,7 +149,7 @@ fn coroutine_wrap<'lua>(lua: &'lua Lua, func: LuaFunction) -> LuaResult() diff --git a/packages/lib/src/lib.rs b/packages/lib/src/lib.rs index bae32a8..e25bc17 100644 --- a/packages/lib/src/lib.rs +++ b/packages/lib/src/lib.rs @@ -1,4 +1,4 @@ -use std::{collections::HashSet, process::ExitCode}; +use std::process::ExitCode; use lua::task::{TaskScheduler, TaskSchedulerResumeExt, TaskSchedulerScheduleExt}; use mlua::prelude::*; @@ -18,8 +18,7 @@ pub use lua::create_lune_lua; #[derive(Clone, Debug, Default)] pub struct Lune { - includes: HashSet, - excludes: HashSet, + args: Vec, } impl Lune { @@ -31,42 +30,13 @@ impl Lune { } /** - Include a global in the lua environment created for running a Lune script. + Arguments to give in `process.args` for a Lune script. */ - pub fn with_global(mut self, global: LuneGlobal) -> Self { - self.includes.insert(global); - self - } - - /** - Include all globals in the lua environment created for running a Lune script. - */ - pub fn with_all_globals(mut self) -> Self { - for global in LuneGlobal::all::(&[]) { - self.includes.insert(global); - } - self - } - - /** - Include all globals in the lua environment created for running a - Lune script, as well as supplying args for [`LuneGlobal::Process`]. - */ - pub fn with_all_globals_and_args(mut self, args: Vec) -> Self { - for global in LuneGlobal::all(&args) { - self.includes.insert(global); - } - self - } - - /** - Exclude a global from the lua environment created for running a Lune script. - - This should be preferred over manually iterating and filtering - which Lune globals to add to the global environment. - */ - pub fn without_global(mut self, global: LuneGlobal) -> Self { - self.excludes.insert(global); + pub fn with_args(mut self, args: V) -> Self + where + V: Into>, + { + self.args = args.into(); self } @@ -76,12 +46,11 @@ impl Lune { This will create a new sandboxed Luau environment with the configured globals and arguments, running inside of a [`tokio::task::LocalSet`]. - Some Lune globals such as [`LuneGlobal::Process`] and [`LuneGlobal::Net`] - may spawn separate tokio tasks on other threads, but the Luau environment - itself is guaranteed to run on a single thread in the local set. + Some Lune globals may spawn separate tokio tasks on other threads, but the Luau + environment itself is guaranteed to run on a single thread in the local set. - Note that this will create a static Lua instance and task scheduler which both - will live for the remainer of the program, and that this leaks memory using + Note that this will create a static Lua instance and task scheduler that will + both live for the remainer of the program, and that this leaks memory using [`Box::leak`] that will then get deallocated when the program exits. */ pub async fn run( @@ -106,10 +75,8 @@ impl Lune { sched.schedule_blocking(main_thread, main_thread_args)?; // Create our wanted lune globals, some of these need // the task scheduler be available during construction - for global in self.includes.clone() { - if !self.excludes.contains(&global) { - global.inject(lua)?; - } + for global in LuneGlobal::all(&self.args) { + global.inject(lua)?; } // Keep running the scheduler until there are either no tasks // left to run, or until a task requests to exit the process diff --git a/packages/lib/src/lua/ext.rs b/packages/lib/src/lua/async_ext.rs similarity index 67% rename from packages/lib/src/lua/ext.rs rename to packages/lib/src/lua/async_ext.rs index a933ce0..e90e5d9 100644 --- a/packages/lib/src/lua/ext.rs +++ b/packages/lib/src/lua/async_ext.rs @@ -4,6 +4,8 @@ use mlua::prelude::*; use crate::{lua::task::TaskScheduler, utils::table::TableBuilder}; +use super::task::TaskSchedulerAsyncExt; + #[async_trait(?Send)] pub trait LuaAsyncExt { fn create_async_function<'lua, A, R, F, FR>(self, func: F) -> LuaResult> @@ -12,6 +14,8 @@ pub trait LuaAsyncExt { R: ToLuaMulti<'static>, F: 'static + Fn(&'static Lua, A) -> FR, FR: 'static + Future>; + + fn create_waiter_function<'lua>(self) -> LuaResult>; } impl LuaAsyncExt for &'static Lua { @@ -31,7 +35,6 @@ impl LuaAsyncExt for &'static Lua { let async_env_trace: LuaFunction = self.named_registry_value("dbg.trace")?; let async_env_error: LuaFunction = self.named_registry_value("error")?; let async_env_unpack: LuaFunction = self.named_registry_value("tab.unpack")?; - let async_env_thread: LuaFunction = self.named_registry_value("co.thread")?; let async_env_yield: LuaFunction = self.named_registry_value("co.yield")?; let async_env = TableBuilder::new(self)? .with_value("makeError", async_env_make_err)? @@ -39,8 +42,8 @@ impl LuaAsyncExt for &'static Lua { .with_value("trace", async_env_trace)? .with_value("error", async_env_error)? .with_value("unpack", async_env_unpack)? - .with_value("thread", async_env_thread)? .with_value("yield", async_env_yield)? + .with_function("thread", |lua, _: ()| Ok(lua.current_thread()))? .with_function( "resumeAsync", move |lua: &Lua, (thread, args): (LuaThread, A)| { @@ -48,7 +51,7 @@ impl LuaAsyncExt for &'static Lua { let sched = lua .app_data_ref::<&TaskScheduler>() .expect("Missing task scheduler as a lua app data"); - sched.queue_async_task(thread, None, None, async { + sched.queue_async_task(thread, None, async { let rets = fut.await?; let mult = rets.to_lua_multi(lua)?; Ok(Some(mult)) @@ -68,7 +71,36 @@ impl LuaAsyncExt for &'static Lua { end ", ) - .set_name("asyncWrapper")? + .set_name("async")? + .set_environment(async_env)? + .into_function()?; + Ok(async_func) + } + + /** + Creates a special async function that waits the + desired amount of time, inheriting the guid of the + current thread / task for proper cancellation. + */ + fn create_waiter_function<'lua>(self) -> LuaResult> { + let async_env_yield: LuaFunction = self.named_registry_value("co.yield")?; + let async_env = TableBuilder::new(self)? + .with_value("yield", async_env_yield)? + .with_function("resumeAfter", move |lua: &Lua, duration: Option| { + let sched = lua + .app_data_ref::<&TaskScheduler>() + .expect("Missing task scheduler as a lua app data"); + sched.schedule_wait(lua.current_thread(), duration) + })? + .build_readonly()?; + let async_func = self + .load( + " + resumeAfter(...) + return yield() + ", + ) + .set_name("wait")? .set_environment(async_env)? .into_function()?; Ok(async_func) diff --git a/packages/lib/src/lua/create.rs b/packages/lib/src/lua/create.rs index bb1a18a..8ba500a 100644 --- a/packages/lib/src/lua/create.rs +++ b/packages/lib/src/lua/create.rs @@ -64,7 +64,6 @@ end * `"tostring"` -> `tostring` * `"tonumber"` -> `tonumber` --- - * `"co.thread"` -> `coroutine.running` * `"co.yield"` -> `coroutine.yield` * `"co.close"` -> `coroutine.close` --- @@ -93,7 +92,6 @@ pub fn create() -> LuaResult<&'static Lua> { lua.set_named_registry_value("pcall", globals.get::<_, LuaFunction>("pcall")?)?; lua.set_named_registry_value("tostring", globals.get::<_, LuaFunction>("tostring")?)?; lua.set_named_registry_value("tonumber", globals.get::<_, LuaFunction>("tonumber")?)?; - lua.set_named_registry_value("co.thread", coroutine.get::<_, LuaFunction>("running")?)?; lua.set_named_registry_value("co.yield", coroutine.get::<_, LuaFunction>("yield")?)?; lua.set_named_registry_value("co.close", coroutine.get::<_, LuaFunction>("close")?)?; lua.set_named_registry_value("dbg.info", debug.get::<_, LuaFunction>("info")?)?; diff --git a/packages/lib/src/lua/mod.rs b/packages/lib/src/lua/mod.rs index 752b458..0323230 100644 --- a/packages/lib/src/lua/mod.rs +++ b/packages/lib/src/lua/mod.rs @@ -1,6 +1,6 @@ mod create; -pub mod ext; +pub mod async_ext; pub mod net; pub mod stdio; pub mod task; diff --git a/packages/lib/src/lua/task/ext/async_ext.rs b/packages/lib/src/lua/task/ext/async_ext.rs index deaf59b..2fbf8a2 100644 --- a/packages/lib/src/lua/task/ext/async_ext.rs +++ b/packages/lib/src/lua/task/ext/async_ext.rs @@ -1,7 +1,12 @@ +use std::time::Duration; + use async_trait::async_trait; use futures_util::Future; use mlua::prelude::*; +use tokio::time::{sleep, Instant}; + +use crate::lua::task::TaskKind; use super::super::{ async_handle::TaskSchedulerAsyncHandle, message::TaskSchedulerMessage, @@ -30,6 +35,12 @@ pub trait TaskSchedulerAsyncExt<'fut> { R: ToLuaMulti<'static>, F: 'static + Fn(&'static Lua) -> FR, FR: 'static + Future>; + + fn schedule_wait( + &'fut self, + reference: LuaThread<'_>, + duration: Option, + ) -> LuaResult; } /* @@ -82,7 +93,7 @@ impl<'fut> TaskSchedulerAsyncExt<'fut> for TaskScheduler<'fut> { F: 'static + Fn(&'static Lua) -> FR, FR: 'static + Future>, { - self.queue_async_task(thread, None, None, async move { + self.queue_async_task(thread, None, async move { match func(self.lua).await { Ok(res) => match res.to_lua_multi(self.lua) { Ok(multi) => Ok(Some(multi)), @@ -92,4 +103,30 @@ impl<'fut> TaskSchedulerAsyncExt<'fut> for TaskScheduler<'fut> { } }) } + + /** + Schedules a task reference to be resumed after a certain amount of time. + + The given task will be resumed with the elapsed time as its one and only argument. + */ + fn schedule_wait( + &'fut self, + thread: LuaThread<'_>, + duration: Option, + ) -> LuaResult { + let reference = self.create_task(TaskKind::Future, thread, None, true)?; + // Insert the future + let futs = self + .futures + .try_lock() + .expect("Tried to add future to queue during futures resumption"); + futs.push(Box::pin(async move { + let before = Instant::now(); + sleep(Duration::from_secs_f64(duration.unwrap_or_default())).await; + let elapsed_secs = before.elapsed().as_secs_f64(); + let args = elapsed_secs.to_lua_multi(self.lua).unwrap(); + (Some(reference), Ok(Some(args))) + })); + Ok(reference) + } } diff --git a/packages/lib/src/lua/task/ext/resume_ext.rs b/packages/lib/src/lua/task/ext/resume_ext.rs index a036a67..91066bf 100644 --- a/packages/lib/src/lua/task/ext/resume_ext.rs +++ b/packages/lib/src/lua/task/ext/resume_ext.rs @@ -124,12 +124,15 @@ async fn resume_next_async_task(scheduler: &TaskScheduler<'_>) -> TaskSchedulerS .await .expect("Tried to resume next queued future but none are queued") }; - // Promote this future task to a blocking task and resume it - // right away, also taking care to not borrow mutably twice - // by dropping this guard before trying to resume it - let mut queue_guard = scheduler.tasks_queue_blocking.borrow_mut(); - queue_guard.push_front(task); - drop(queue_guard); + // The future might not return a reference that it wants to resume + if let Some(task) = task { + // Promote this future task to a blocking task and resume it + // right away, also taking care to not borrow mutably twice + // by dropping this guard before trying to resume it + let mut queue_guard = scheduler.tasks_queue_blocking.borrow_mut(); + queue_guard.push_front(task); + drop(queue_guard); + } resume_next_blocking_task(scheduler, result.transpose()) } diff --git a/packages/lib/src/lua/task/ext/schedule_ext.rs b/packages/lib/src/lua/task/ext/schedule_ext.rs index 57f7da9..244a344 100644 --- a/packages/lib/src/lua/task/ext/schedule_ext.rs +++ b/packages/lib/src/lua/task/ext/schedule_ext.rs @@ -52,7 +52,7 @@ impl TaskSchedulerScheduleExt for TaskScheduler<'_> { thread: LuaThread<'_>, thread_args: LuaMultiValue<'_>, ) -> LuaResult { - self.queue_blocking_task(TaskKind::Instant, thread, Some(thread_args), None) + self.queue_blocking_task(TaskKind::Instant, thread, Some(thread_args)) } /** @@ -67,7 +67,7 @@ impl TaskSchedulerScheduleExt for TaskScheduler<'_> { thread: LuaThread<'_>, thread_args: LuaMultiValue<'_>, ) -> LuaResult { - self.queue_blocking_task(TaskKind::Deferred, thread, Some(thread_args), None) + self.queue_blocking_task(TaskKind::Deferred, thread, Some(thread_args)) } /** @@ -83,7 +83,7 @@ impl TaskSchedulerScheduleExt for TaskScheduler<'_> { thread: LuaThread<'_>, thread_args: LuaMultiValue<'_>, ) -> LuaResult { - self.queue_async_task(thread, Some(thread_args), None, async move { + self.queue_async_task(thread, Some(thread_args), async move { sleep(Duration::from_secs_f64(after_secs)).await; Ok(None) }) diff --git a/packages/lib/src/lua/task/mod.rs b/packages/lib/src/lua/task/mod.rs index 666eb23..dae3ac5 100644 --- a/packages/lib/src/lua/task/mod.rs +++ b/packages/lib/src/lua/task/mod.rs @@ -1,10 +1,12 @@ mod async_handle; mod ext; mod message; +mod proxy; mod result; mod scheduler; mod task_kind; mod task_reference; pub use ext::*; +pub use proxy::*; pub use scheduler::*; diff --git a/packages/lib/src/lua/task/proxy.rs b/packages/lib/src/lua/task/proxy.rs new file mode 100644 index 0000000..4fd2762 --- /dev/null +++ b/packages/lib/src/lua/task/proxy.rs @@ -0,0 +1,116 @@ +use mlua::prelude::*; + +use super::TaskReference; + +/* + Proxy enum to deal with both threads & functions +*/ + +#[derive(Debug, Clone)] +pub enum LuaThreadOrFunction<'lua> { + Thread(LuaThread<'lua>), + Function(LuaFunction<'lua>), +} + +impl<'lua> LuaThreadOrFunction<'lua> { + pub fn into_thread(self, lua: &'lua Lua) -> LuaResult> { + match self { + Self::Thread(t) => Ok(t), + Self::Function(f) => lua.create_thread(f), + } + } +} + +impl<'lua> From> for LuaThreadOrFunction<'lua> { + fn from(value: LuaThread<'lua>) -> Self { + Self::Thread(value) + } +} + +impl<'lua> From> for LuaThreadOrFunction<'lua> { + fn from(value: LuaFunction<'lua>) -> Self { + Self::Function(value) + } +} + +impl<'lua> FromLua<'lua> for LuaThreadOrFunction<'lua> { + fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult { + match value { + LuaValue::Thread(t) => Ok(Self::Thread(t)), + LuaValue::Function(f) => Ok(Self::Function(f)), + value => Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "LuaThreadOrFunction", + message: Some(format!( + "Expected thread or function, got '{}'", + value.type_name() + )), + }), + } + } +} + +impl<'lua> ToLua<'lua> for LuaThreadOrFunction<'lua> { + fn to_lua(self, _: &'lua Lua) -> LuaResult> { + match self { + Self::Thread(t) => Ok(LuaValue::Thread(t)), + Self::Function(f) => Ok(LuaValue::Function(f)), + } + } +} + +/* + Proxy enum to deal with both threads & task scheduler task references +*/ + +#[derive(Debug, Clone)] +pub enum LuaThreadOrTaskReference<'lua> { + Thread(LuaThread<'lua>), + TaskReference(TaskReference), +} + +impl<'lua> From> for LuaThreadOrTaskReference<'lua> { + fn from(value: LuaThread<'lua>) -> Self { + Self::Thread(value) + } +} + +impl<'lua> From for LuaThreadOrTaskReference<'lua> { + fn from(value: TaskReference) -> Self { + Self::TaskReference(value) + } +} + +impl<'lua> FromLua<'lua> for LuaThreadOrTaskReference<'lua> { + fn from_lua(value: LuaValue<'lua>, lua: &'lua Lua) -> LuaResult { + let tname = value.type_name(); + match value { + LuaValue::Thread(t) => Ok(Self::Thread(t)), + LuaValue::UserData(u) => { + if let Ok(task) = TaskReference::from_lua(LuaValue::UserData(u), lua) { + Ok(Self::TaskReference(task)) + } else { + Err(LuaError::FromLuaConversionError { + from: tname, + to: "thread", + message: Some(format!("Expected thread, got '{tname}'")), + }) + } + } + _ => Err(LuaError::FromLuaConversionError { + from: tname, + to: "thread", + message: Some(format!("Expected thread, got '{tname}'")), + }), + } + } +} + +impl<'lua> ToLua<'lua> for LuaThreadOrTaskReference<'lua> { + fn to_lua(self, lua: &'lua Lua) -> LuaResult> { + match self { + Self::TaskReference(t) => t.to_lua(lua), + Self::Thread(t) => Ok(LuaValue::Thread(t)), + } + } +} diff --git a/packages/lib/src/lua/task/scheduler.rs b/packages/lib/src/lua/task/scheduler.rs index b7de186..a7d2463 100644 --- a/packages/lib/src/lua/task/scheduler.rs +++ b/packages/lib/src/lua/task/scheduler.rs @@ -14,7 +14,7 @@ use super::message::TaskSchedulerMessage; pub use super::{task_kind::TaskKind, task_reference::TaskReference}; type TaskFutureRets<'fut> = LuaResult>>; -type TaskFuture<'fut> = LocalBoxFuture<'fut, (TaskReference, TaskFutureRets<'fut>)>; +type TaskFuture<'fut> = LocalBoxFuture<'fut, (Option, TaskFutureRets<'fut>)>; /// A struct representing a task contained in the task scheduler #[derive(Debug)] @@ -40,10 +40,10 @@ pub struct TaskScheduler<'fut> { // Internal state & flags pub(super) lua: &'static Lua, pub(super) guid: Cell, - pub(super) guid_running: Cell>, pub(super) exit_code: Cell>, // Blocking tasks pub(super) tasks: RefCell>, + pub(super) tasks_current: Cell>, pub(super) tasks_queue_blocking: RefCell>, // Future tasks & objects for waking pub(super) futures: AsyncMutex>>, @@ -61,9 +61,9 @@ impl<'fut> TaskScheduler<'fut> { Ok(Self { lua, guid: Cell::new(0), - guid_running: Cell::new(None), exit_code: Cell::new(None), tasks: RefCell::new(HashMap::new()), + tasks_current: Cell::new(None), tasks_queue_blocking: RefCell::new(VecDeque::new()), futures: AsyncMutex::new(FuturesUnordered::new()), futures_tx: tx, @@ -109,6 +109,14 @@ impl<'fut> TaskScheduler<'fut> { self.tasks.borrow().contains_key(&reference) } + /** + Returns the currently running task, if any. + */ + #[allow(dead_code)] + pub fn current_task(&self) -> Option { + self.tasks_current.get() + } + /** Creates a new task, storing a new Lua thread for it, as well as the arguments to give the @@ -123,7 +131,7 @@ impl<'fut> TaskScheduler<'fut> { kind: TaskKind, thread: LuaThread<'_>, thread_args: Option>, - guid_to_reuse: Option, + inherit_current_guid: bool, ) -> LuaResult { // Store the thread and its arguments in the registry // NOTE: We must convert to a vec since multis @@ -137,19 +145,22 @@ impl<'fut> TaskScheduler<'fut> { args: task_args_key, }; // Create the task ref to use - let task_ref = if let Some(reusable_guid) = guid_to_reuse { - TaskReference::new(kind, reusable_guid) + let guid = if inherit_current_guid { + self.current_task() + .expect("No current guid to inherit") + .id() } else { let guid = self.guid.get(); self.guid.set(guid + 1); - TaskReference::new(kind, guid) + guid }; + let reference = TaskReference::new(kind, guid); // Add the task to the scheduler { let mut tasks = self.tasks.borrow_mut(); - tasks.insert(task_ref, task); + tasks.insert(reference, task); } - Ok(task_ref) + Ok(reference) } /** @@ -181,8 +192,13 @@ impl<'fut> TaskScheduler<'fut> { .filter(|task_ref| task_ref.id() == reference.id()) .copied() .collect(); - for task_ref in tasks_to_remove { - if let Some(task) = tasks.remove(&task_ref) { + for task_ref in &tasks_to_remove { + if let Some(task) = tasks.remove(task_ref) { + // NOTE: We need to close the thread here to + // make 100% sure that nothing can resume it + let close: LuaFunction = self.lua.named_registry_value("co.close")?; + let thread: LuaThread = self.lua.registry_value(&task.thread)?; + close.call(thread)?; self.lua.remove_registry_value(task.thread)?; self.lua.remove_registry_value(task.args)?; found = true; @@ -204,13 +220,16 @@ impl<'fut> TaskScheduler<'fut> { reference: TaskReference, override_args: Option>>, ) -> LuaResult> { + // Fetch and check if the task was removed, if it got + // removed it means it was intentionally cancelled let task = { let mut tasks = self.tasks.borrow_mut(); match tasks.remove(&reference) { Some(task) => task, - None => return Ok(LuaMultiValue::new()), // Task was removed + None => return Ok(LuaMultiValue::new()), } }; + // Fetch and remove the thread to resume + its arguments let thread: LuaThread = self.lua.registry_value(&task.thread)?; let args_opt_res = override_args.or_else(|| { Ok(self @@ -222,7 +241,9 @@ impl<'fut> TaskScheduler<'fut> { }); self.lua.remove_registry_value(task.thread)?; self.lua.remove_registry_value(task.args)?; - self.guid_running.set(Some(reference.id())); + // We got everything we need and our references + // were cleaned up properly, resume the thread + self.tasks_current.set(Some(reference)); let rets = match args_opt_res { Some(args_res) => match args_res { /* @@ -235,12 +256,12 @@ impl<'fut> TaskScheduler<'fut> { that may pass errors as arguments when resuming tasks, other native mlua functions will handle this and dont need wrapping */ - Err(err) => thread.resume(err), + Err(e) => thread.resume(e), Ok(args) => thread.resume(args), }, None => thread.resume(()), }; - self.guid_running.set(None); + self.tasks_current.set(None); rets } @@ -265,12 +286,11 @@ impl<'fut> TaskScheduler<'fut> { kind: TaskKind, thread: LuaThread<'_>, thread_args: Option>, - guid_to_reuse: Option, ) -> LuaResult { if kind == TaskKind::Future { panic!("Tried to schedule future using normal task schedule method") } - let task_ref = self.create_task(kind, thread, thread_args, guid_to_reuse)?; + let task_ref = self.create_task(kind, thread, thread_args, false)?; // Add the task to the front of the queue, unless it // should be deferred, in that case add it to the back let mut queue = self.tasks_queue_blocking.borrow_mut(); @@ -304,17 +324,16 @@ impl<'fut> TaskScheduler<'fut> { &self, thread: LuaThread<'_>, thread_args: Option>, - guid_to_reuse: Option, fut: impl Future> + 'fut, ) -> LuaResult { - let task_ref = self.create_task(TaskKind::Future, thread, thread_args, guid_to_reuse)?; + let task_ref = self.create_task(TaskKind::Future, thread, thread_args, false)?; let futs = self .futures .try_lock() - .expect("Failed to get lock on futures"); + .expect("Tried to add future to queue during futures resumption"); futs.push(Box::pin(async move { let result = fut.await; - (task_ref, result) + (Some(task_ref), result) })); Ok(task_ref) } diff --git a/packages/lib/src/lua/task/task_reference.rs b/packages/lib/src/lua/task/task_reference.rs index bfc6fbf..8df11c3 100644 --- a/packages/lib/src/lua/task/task_reference.rs +++ b/packages/lib/src/lua/task/task_reference.rs @@ -24,7 +24,11 @@ impl TaskReference { impl fmt::Display for TaskReference { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "TaskReference({} - {})", self.kind, self.guid) + if self.guid == 0 { + write!(f, "TaskReference(MAIN)") + } else { + write!(f, "TaskReference({} - {})", self.kind, self.guid) + } } } diff --git a/packages/lib/src/tests.rs b/packages/lib/src/tests.rs index 5665e93..9e46c31 100644 --- a/packages/lib/src/tests.rs +++ b/packages/lib/src/tests.rs @@ -26,12 +26,12 @@ macro_rules! create_tests { // The rest of the test logic can continue as normal let full_name = format!("tests/{}.luau", $value); let script = read_to_string(&full_name).await?; - let lune = Lune::new().with_all_globals_and_args( + let lune = Lune::new().with_args( ARGS .clone() .iter() .map(ToString::to_string) - .collect() + .collect::>() ); let script_name = full_name.strip_suffix(".luau").unwrap(); let exit_code = lune.run(&script_name, &script).await?; diff --git a/packages/lib/src/utils/formatting.rs b/packages/lib/src/utils/formatting.rs index cddf01c..e15395e 100644 --- a/packages/lib/src/utils/formatting.rs +++ b/packages/lib/src/utils/formatting.rs @@ -240,18 +240,21 @@ pub fn pretty_format_luau_error(e: &LuaError, colorized: bool) -> String { // The traceback may also start with "override traceback:" which // means it was passed from somewhere that wants a custom trace, // so we should then respect that and get the best override instead - let mut best_trace: &str = traceback; + let mut full_trace = traceback.to_string(); let mut root_cause = cause.as_ref(); let mut trace_override = false; while let LuaError::CallbackError { cause, traceback } = root_cause { let is_override = traceback.starts_with("override traceback:"); if is_override { - if !trace_override || traceback.lines().count() > best_trace.len() { - best_trace = traceback.strip_prefix("override traceback:").unwrap(); + if !trace_override || traceback.lines().count() > full_trace.len() { + full_trace = traceback + .strip_prefix("override traceback:") + .unwrap() + .to_string(); trace_override = true; } - } else if !trace_override && traceback.lines().count() > best_trace.len() { - best_trace = traceback; + } else if !trace_override { + full_trace = format!("{traceback}\n{full_trace}"); } root_cause = cause; } @@ -266,10 +269,10 @@ pub fn pretty_format_luau_error(e: &LuaError, colorized: bool) -> String { "{}\n{}\n{}\n{}", pretty_format_luau_error(root_cause, colorized), stack_begin, - if best_trace.starts_with("stack traceback:") { - best_trace.strip_prefix("stack traceback:\n").unwrap() + if full_trace.starts_with("stack traceback:") { + full_trace.strip_prefix("stack traceback:\n").unwrap() } else { - best_trace + &full_trace }, stack_end ) @@ -378,7 +381,8 @@ fn transform_stack_line(line: &str) -> String { let line_num = match after_name.find(':') { Some(lineno_start) => match after_name[lineno_start + 1..].find(':') { Some(lineno_end) => &after_name[lineno_start + 1..lineno_end + 1], - None => match after_name.contains("in function") { + None => match after_name.contains("in function") || after_name.contains("in ?") + { false => &after_name[lineno_start + 1..], true => "", }, @@ -418,11 +422,18 @@ fn transform_stack_line(line: &str) -> String { fn fix_error_nitpicks(full_message: String) -> String { full_message // Hacky fix for our custom require appearing as a normal script + // TODO: It's probably better to pull in the regex crate here .. .replace("'require', Line 5", "'[C]' - function require") .replace("'require', Line 7", "'[C]' - function require") + .replace("'require', Line 8", "'[C]' - function require") // Fix error calls in custom script chunks coming through .replace( "'[C]' - function error\n Script '[C]' - function require", "'[C]' - function require", ) + // Fix strange double require + .replace( + "'[C]' - function require - function require", + "'[C]' - function require", + ) } diff --git a/packages/lib/src/utils/table.rs b/packages/lib/src/utils/table.rs index e11ddce..002c841 100644 --- a/packages/lib/src/utils/table.rs +++ b/packages/lib/src/utils/table.rs @@ -2,7 +2,7 @@ use std::future::Future; use mlua::prelude::*; -use crate::lua::ext::LuaAsyncExt; +use crate::lua::async_ext::LuaAsyncExt; pub struct TableBuilder { lua: &'static Lua, diff --git a/tests/globals/pcall.luau b/tests/globals/pcall.luau index 7658899..ceb1c69 100644 --- a/tests/globals/pcall.luau +++ b/tests/globals/pcall.luau @@ -1,3 +1,6 @@ +local PORT = 9090 -- NOTE: This must be different from +-- net tests to let them run in parallel with this file + local function test(f, ...) local success, message = pcall(f, ...) assert(not success, "Function did not throw an error") @@ -14,7 +17,7 @@ test(net.request, "https://wxyz.google.com") -- Net serve is async and will throw an OS error when trying to serve twice on the same port -local handle = net.serve(8080, function() +local handle = net.serve(PORT, function() return "" end) @@ -22,18 +25,4 @@ task.delay(0, function() handle.stop() end) -test(net.serve, 8080, function() end) - -local function e() - task.spawn(function() - task.defer(function() - task.delay(0, function() - error({ - Hello = "World", - }) - end) - end) - end) -end - -task.defer(e) +test(net.serve, PORT, function() end) diff --git a/tests/globals/require/tests/children.luau b/tests/globals/require/tests/children.luau index 79dcf94..c3ed971 100644 --- a/tests/globals/require/tests/children.luau +++ b/tests/globals/require/tests/children.luau @@ -4,6 +4,8 @@ assert(type(module) == "table", "Required module did not return a table") assert(module.Foo == "Bar", "Required module did not contain correct values") assert(module.Hello == "World", "Required module did not contain correct values") -require("modules/module") +module = require("modules/module") +assert(module.Foo == "Bar", "Required module did not contain correct values") +assert(module.Hello == "World", "Required module did not contain correct values") return true diff --git a/tests/globals/require/tests/foo.lua b/tests/globals/require/tests/foo.lua deleted file mode 100644 index e69de29..0000000 diff --git a/tests/globals/require/tests/invalid.luau b/tests/globals/require/tests/invalid.luau index 2dbaaab..52a3574 100644 --- a/tests/globals/require/tests/invalid.luau +++ b/tests/globals/require/tests/invalid.luau @@ -5,7 +5,7 @@ local function test(path: string) if success then error(string.format("Invalid require at path '%s' succeeded", path)) else - print(message) + message = tostring(message) if string.find(message, string.format("%s'", path)) == nil then error( string.format( diff --git a/tests/globals/require/tests/nested.luau b/tests/globals/require/tests/nested.luau index 68d2d10..e0c80ee 100644 --- a/tests/globals/require/tests/nested.luau +++ b/tests/globals/require/tests/nested.luau @@ -4,4 +4,6 @@ assert(type(module) == "table", "Required module did not return a table") assert(module.Foo == "Bar", "Required module did not contain correct values") assert(module.Hello == "World", "Required module did not contain correct values") -require("modules/nested") +module = require("modules/nested") +assert(module.Foo == "Bar", "Required module did not contain correct values") +assert(module.Hello == "World", "Required module did not contain correct values") diff --git a/tests/net/request/util.luau b/tests/net/request/util.luau index 562b650..ae020a2 100644 --- a/tests/net/request/util.luau +++ b/tests/net/request/util.luau @@ -15,7 +15,7 @@ function util.fail(method, url, message) method = method, url = url, }) - if not response.ok then + if response.ok then error(string.format("%s passed!\nResponse: %s", message, stdio.format(response))) end end