From fc5de3c8d5ba5f2b2fe9d3532fcd01bed7b802b1 Mon Sep 17 00:00:00 2001 From: Filip Tibell Date: Mon, 13 Feb 2023 15:28:18 +0100 Subject: [PATCH] Initial implementation of proper task scheduler, no async yet --- packages/cli/src/main.rs | 2 +- packages/cli/src/tests.rs | 6 +- packages/lib/src/globals/mod.rs | 14 +- packages/lib/src/globals/net.rs | 35 +-- packages/lib/src/globals/process.rs | 57 +++- packages/lib/src/globals/require.rs | 38 ++- packages/lib/src/globals/task.rs | 218 +++++-------- packages/lib/src/globals/top_level.rs | 5 - packages/lib/src/lib.rs | 136 ++++----- packages/lib/src/lua/mod.rs | 1 + packages/lib/src/lua/task/mod.rs | 3 + packages/lib/src/lua/task/scheduler.rs | 406 +++++++++++++++++++++++++ packages/lib/src/tests.rs | 2 +- packages/lib/src/utils/formatting.rs | 41 ++- packages/lib/src/utils/message.rs | 9 - packages/lib/src/utils/mod.rs | 2 - packages/lib/src/utils/process.rs | 18 +- packages/lib/src/utils/task.rs | 76 ----- tests/stdio/write.luau | 2 - tests/task/defer.luau | 6 +- tests/task/delay.luau | 10 +- tests/task/spawn.luau | 4 +- tests/task/wait.luau | 24 +- 23 files changed, 685 insertions(+), 430 deletions(-) create mode 100644 packages/lib/src/lua/task/mod.rs create mode 100644 packages/lib/src/lua/task/scheduler.rs delete mode 100644 packages/lib/src/utils/message.rs delete mode 100644 packages/lib/src/utils/task.rs diff --git a/packages/cli/src/main.rs b/packages/cli/src/main.rs index 833496b..9b57b5e 100644 --- a/packages/cli/src/main.rs +++ b/packages/cli/src/main.rs @@ -22,7 +22,7 @@ mod tests; use cli::Cli; -#[tokio::main] +#[tokio::main(flavor = "multi_thread")] async fn main() -> Result { Cli::parse().run().await } diff --git a/packages/cli/src/tests.rs b/packages/cli/src/tests.rs index 2d470f5..4032254 100644 --- a/packages/cli/src/tests.rs +++ b/packages/cli/src/tests.rs @@ -36,20 +36,20 @@ async fn ensure_file_exists_and_is_not_json(file_name: &str) -> Result<()> { } } -#[tokio::test] +#[tokio::test(flavor = "multi_thread")] async fn list() -> Result<()> { Cli::list().run().await?; Ok(()) } -#[tokio::test] +#[tokio::test(flavor = "multi_thread")] async fn download_selene_types() -> Result<()> { run_cli(Cli::download_selene_types()).await?; ensure_file_exists_and_is_not_json(LUNE_SELENE_FILE_NAME).await?; Ok(()) } -#[tokio::test] +#[tokio::test(flavor = "multi_thread")] async fn download_luau_types() -> Result<()> { run_cli(Cli::download_luau_types()).await?; ensure_file_exists_and_is_not_json(LUNE_LUAU_FILE_NAME).await?; diff --git a/packages/lib/src/globals/mod.rs b/packages/lib/src/globals/mod.rs index a0b27de..6688de2 100644 --- a/packages/lib/src/globals/mod.rs +++ b/packages/lib/src/globals/mod.rs @@ -2,6 +2,8 @@ use std::fmt::{Display, Formatter, Result as FmtResult}; use mlua::prelude::*; +use crate::lua::task::TaskScheduler; + mod fs; mod net; mod process; @@ -78,14 +80,18 @@ impl LuneGlobal { Note that proxy globals should be handled with special care and that [`LuneGlobal::inject()`] should be preferred over manually creating and manipulating the value(s) of any Lune global. */ - pub fn value(&self, lua: &'static Lua) -> LuaResult { + pub fn value( + &self, + lua: &'static Lua, + scheduler: &'static TaskScheduler, + ) -> LuaResult { match self { LuneGlobal::Fs => fs::create(lua), LuneGlobal::Net => net::create(lua), LuneGlobal::Process { args } => process::create(lua, args.clone()), LuneGlobal::Require => require::create(lua), LuneGlobal::Stdio => stdio::create(lua), - LuneGlobal::Task => task::create(lua), + LuneGlobal::Task => task::create(lua, scheduler), LuneGlobal::TopLevel => top_level::create(lua), } } @@ -98,9 +104,9 @@ impl LuneGlobal { Refer to [`LuneGlobal::is_top_level()`] for more info on proxy globals. */ - pub fn inject(self, lua: &'static Lua) -> LuaResult<()> { + pub fn inject(self, lua: &'static Lua, scheduler: &'static TaskScheduler) -> LuaResult<()> { let globals = lua.globals(); - let table = self.value(lua)?; + let table = self.value(lua, scheduler)?; // NOTE: Top level globals are special, the values // *in* the table they return should be set directly, // instead of setting the table itself as the global diff --git a/packages/lib/src/globals/net.rs b/packages/lib/src/globals/net.rs index 0dfd1ee..0c257c9 100644 --- a/packages/lib/src/globals/net.rs +++ b/packages/lib/src/globals/net.rs @@ -17,10 +17,7 @@ use tokio::{sync::mpsc, task}; use crate::{ lua::net::{NetClient, NetClientBuilder, NetWebSocketClient, NetWebSocketServer, ServeConfig}, - utils::{ - message::LuneMessage, net::get_request_user_agent_header, table::TableBuilder, - task::send_message, - }, + utils::{net::get_request_user_agent_header, table::TableBuilder}, }; pub fn create(lua: &'static Lua) -> LuaResult { @@ -29,7 +26,7 @@ pub fn create(lua: &'static Lua) -> LuaResult { let client = NetClientBuilder::new() .headers(&[("User-Agent", get_request_user_agent_header())])? .build()?; - lua.set_named_registry_value("NetClient", client)?; + lua.set_named_registry_value("net.client", client)?; // Create the global table for net TableBuilder::new(lua)? .with_function("jsonEncode", net_json_encode)? @@ -54,7 +51,7 @@ fn net_json_decode(lua: &'static Lua, json: String) -> LuaResult { } async fn net_request<'a>(lua: &'static Lua, config: LuaValue<'a>) -> LuaResult> { - let client: NetClient = lua.named_registry_value("NetClient")?; + let client: NetClient = lua.named_registry_value("net.client")?; // Extract stuff from config and make sure its all valid let (url, method, headers, body) = match config { LuaValue::String(s) => { @@ -177,20 +174,9 @@ async fn net_serve<'a>( shutdown_rx.recv().await.unwrap(); shutdown_rx.close(); }); - // Make sure we register the thread properly by sending messages - // when the server starts up and when it shuts down or errors - send_message(lua, LuneMessage::Spawned).await?; - task::spawn_local(async move { - let res = server.await.map_err(LuaError::external); - let _ = send_message( - lua, - match res { - Err(e) => LuneMessage::LuaError(e), - Ok(_) => LuneMessage::Finished, - }, - ) - .await; - }); + // TODO: Spawn a new scheduler future with this so we don't block + // and make sure that we register it properly to prevent shutdown + server.await.map_err(LuaError::external)?; // Create a new read-only table that contains methods // for manipulating server behavior and shutting it down let handle_stop = move |_, _: ()| { @@ -313,8 +299,8 @@ impl Service> for NetService { Ok(resp.body(body).unwrap()) } // If the handler returns an error, generate a 5xx response - Err(err) => { - send_message(lua, LuneMessage::LuaError(err.to_lua_err())).await?; + Err(_) => { + // TODO: Send above error to task scheduler so that it can emit properly Ok(Response::builder() .status(500) .body(Body::from("Internal Server Error")) @@ -323,10 +309,11 @@ impl Service> for NetService { // If the handler returns a value that is of an invalid type, // this should also be an error, so generate a 5xx response Ok(value) => { - send_message(lua, LuneMessage::LuaError(LuaError::RuntimeError(format!( + // TODO: Send below error to task scheduler so that it can emit properly + let _ = LuaError::RuntimeError(format!( "Expected net serve handler to return a value of type 'string' or 'table', got '{}'", value.type_name() - )))).await?; + )); Ok(Response::builder() .status(500) .body(Body::from("Internal Server Error")) diff --git a/packages/lib/src/globals/process.rs b/packages/lib/src/globals/process.rs index b6215e5..38708a2 100644 --- a/packages/lib/src/globals/process.rs +++ b/packages/lib/src/globals/process.rs @@ -1,21 +1,35 @@ -use std::{collections::HashMap, env, path::PathBuf, process::Stdio}; +use std::{ + collections::HashMap, + env, + path::PathBuf, + process::{ExitCode, Stdio}, +}; use directories::UserDirs; use mlua::prelude::*; use os_str_bytes::RawOsString; use tokio::process::Command; -use crate::utils::{ - process::{exit_and_yield_forever, pipe_and_inherit_child_process_stdio}, - table::TableBuilder, +use crate::{ + lua::task::TaskScheduler, + utils::{process::pipe_and_inherit_child_process_stdio, table::TableBuilder}, }; +const PROCESS_EXIT_IMPL_LUA: &str = r#" +exit(...) +yield() +"#; + pub fn create(lua: &'static Lua, args_vec: Vec) -> LuaResult { - let cwd = env::current_dir()?.canonicalize()?; - let mut cwd_str = cwd.to_string_lossy().to_string(); - if !cwd_str.ends_with('/') { - cwd_str = format!("{cwd_str}/"); - } + let cwd_str = { + let cwd = env::current_dir()?.canonicalize()?; + let cwd_str = cwd.to_string_lossy().to_string(); + if !cwd_str.ends_with('/') { + format!("{cwd_str}/") + } else { + cwd_str + } + }; // Create readonly args array let args_tab = TableBuilder::new(lua)? .with_sequential_values(args_vec)? @@ -30,12 +44,31 @@ pub fn create(lua: &'static Lua, args_vec: Vec) -> LuaResult { .build_readonly()?, )? .build_readonly()?; + // Create our process exit function, this is a bit involved since + // we have no way to yield from c / rust, we need to load a lua + // chunk that will set the exit code and yield for us instead + let process_exit_env_yield: LuaFunction = lua.named_registry_value("co.yield")?; + let process_exit_env_exit: LuaFunction = lua.create_function(|lua, code: Option| { + let exit_code = code.map_or(ExitCode::SUCCESS, ExitCode::from); + let sched = &mut lua.app_data_mut::<&TaskScheduler>().unwrap(); + sched.set_exit_code(exit_code); + Ok(()) + })?; + let process_exit = lua + .load(PROCESS_EXIT_IMPL_LUA) + .set_environment( + TableBuilder::new(lua)? + .with_value("yield", process_exit_env_yield)? + .with_value("exit", process_exit_env_exit)? + .build_readonly()?, + )? + .into_function()?; // Create the full process table TableBuilder::new(lua)? .with_value("args", args_tab)? .with_value("cwd", cwd_str)? .with_value("env", env_tab)? - .with_async_function("exit", process_exit)? + .with_value("exit", process_exit)? .with_async_function("spawn", process_spawn)? .build_readonly() } @@ -109,10 +142,6 @@ fn process_env_iter<'lua>( }) } -async fn process_exit(lua: &'static Lua, exit_code: Option) -> LuaResult<()> { - exit_and_yield_forever(lua, exit_code).await -} - async fn process_spawn<'a>( lua: &'static Lua, (mut program, args, options): (String, Option>, Option>), diff --git a/packages/lib/src/globals/require.rs b/packages/lib/src/globals/require.rs index a666043..6f39ef2 100644 --- a/packages/lib/src/globals/require.rs +++ b/packages/lib/src/globals/require.rs @@ -10,12 +10,10 @@ use os_str_bytes::{OsStrBytes, RawOsStr}; use crate::utils::table::TableBuilder; pub fn create(lua: &'static Lua) -> LuaResult { - let require: LuaFunction = lua.globals().raw_get("require")?; // Preserve original require behavior if we have a special env var set if env::var_os("LUAU_PWD_REQUIRE").is_some() { - return TableBuilder::new(lua)? - .with_value("require", require)? - .build_readonly(); + // Return an empty table since there are no globals to overwrite + return TableBuilder::new(lua)?.build_readonly(); } /* Store the current working directory so that we can use it later @@ -27,24 +25,17 @@ pub fn create(lua: &'static Lua) -> LuaResult { just in case someone out there uses luau with non-utf8 string requires */ let pwd = lua.create_string(¤t_dir()?.to_raw_bytes())?; - lua.set_named_registry_value("require_pwd", pwd)?; - // Fetch the debug info function and store it in the registry - // - we will use it to fetch the current scripts file name - let debug: LuaTable = lua.globals().raw_get("debug")?; - let info: LuaFunction = debug.raw_get("info")?; - lua.set_named_registry_value("require_getinfo", info)?; - // Store the original require function in the registry - lua.set_named_registry_value("require_original", require)?; + lua.set_named_registry_value("pwd", pwd)?; /* Create a new function that fetches the file name from the current thread, sets the luau module lookup path to be the exact script we are looking for, and then runs the original require function with the wanted path */ let new_require = lua.create_function(|lua, require_path: LuaString| { - let require_pwd: LuaString = lua.named_registry_value("require_pwd")?; - let require_original: LuaFunction = lua.named_registry_value("require_original")?; - let require_getinfo: LuaFunction = lua.named_registry_value("require_getinfo")?; - let require_source: LuaString = require_getinfo.call((2, "s"))?; + let require_pwd: LuaString = lua.named_registry_value("pwd")?; + let require_fn: LuaFunction = lua.named_registry_value("require")?; + let require_info: LuaFunction = lua.named_registry_value("dbg.info")?; + let require_source: LuaString = require_info.call((2, "s"))?; /* Combine the require caller source with the wanted path string to get a final path relative to pwd - it is definitely @@ -53,10 +44,15 @@ pub fn create(lua: &'static Lua) -> LuaResult { let raw_pwd_str = RawOsStr::assert_from_raw_bytes(require_pwd.as_bytes()); let raw_source = RawOsStr::assert_from_raw_bytes(require_source.as_bytes()); let raw_path = RawOsStr::assert_from_raw_bytes(require_path.as_bytes()); - let mut path_relative_to_pwd = PathBuf::from(&raw_source.to_os_str()) - .parent() - .unwrap() - .join(raw_path.to_os_str()); + let mut path_relative_to_pwd = PathBuf::from( + &raw_source + .trim_start_matches("[string \"") + .trim_end_matches("\"]") + .to_os_str(), + ) + .parent() + .unwrap() + .join(raw_path.to_os_str()); // Try to normalize and resolve relative path segments such as './' and '../' if let Ok(canonicalized) = path_relative_to_pwd.with_extension("luau").canonicalize() { path_relative_to_pwd = canonicalized.with_extension(""); @@ -72,7 +68,7 @@ pub fn create(lua: &'static Lua) -> LuaResult { let lua_path_str = lua.create_string(raw_path_str.as_raw_bytes()); // If the require call errors then we should also replace // the path in the error message to improve user experience - let result: LuaResult<_> = require_original.call::<_, LuaValue>(lua_path_str); + let result: LuaResult<_> = require_fn.call::<_, LuaValue>(lua_path_str); match result { Err(LuaError::CallbackError { traceback, cause }) => { let before = format!( diff --git a/packages/lib/src/globals/task.rs b/packages/lib/src/globals/task.rs index 0bc728d..379548a 100644 --- a/packages/lib/src/globals/task.rs +++ b/packages/lib/src/globals/task.rs @@ -1,150 +1,84 @@ -use std::time::{Duration, Instant}; - use mlua::prelude::*; -use tokio::time; -use crate::utils::{ - table::TableBuilder, - task::{run_registered_task, TaskRunMode}, +use crate::{ + lua::task::{TaskReference, TaskScheduler}, + utils::table::TableBuilder, }; -const MINIMUM_WAIT_OR_DELAY_DURATION: f32 = 10.0 / 1_000.0; // 10ms +const TASK_WAIT_IMPL_LUA: &str = r#" +resume_after(thread(), ...) +return yield() +"#; -// TODO: We should probably keep track of all threads in a scheduler userdata -// that takes care of scheduling in a better way, and it should keep resuming -// threads until it encounters a delayed / waiting thread, then task:sleep -pub fn create(lua: &'static Lua) -> LuaResult { - // HACK: There is no way to call coroutine.close directly from the mlua - // crate, so we need to fetch the function and store it in the registry - let coroutine: LuaTable = lua.globals().raw_get("coroutine")?; - let close: LuaFunction = coroutine.raw_get("close")?; - lua.set_named_registry_value("coroutine.close", close)?; - // HACK: coroutine.resume has some weird scheduling issues, but our custom - // task.spawn implementation is more or less a replacement for it, so we - // overwrite the original coroutine.resume function with it to fix that - coroutine.raw_set("resume", lua.create_async_function(task_spawn)?)?; - // Rest of the task library is normal, just async functions, no metatable +pub fn create( + lua: &'static Lua, + scheduler: &'static TaskScheduler, +) -> LuaResult> { + lua.set_app_data(scheduler); + // Create task spawning functions that add tasks to the scheduler + let task_spawn = lua.create_function(|lua, (tof, args): (LuaValue, LuaMultiValue)| { + let sched = &mut lua.app_data_mut::<&TaskScheduler>().unwrap(); + sched.schedule_instant(tof, args) + })?; + let task_defer = lua.create_function(|lua, (tof, args): (LuaValue, LuaMultiValue)| { + let sched = &mut lua.app_data_mut::<&TaskScheduler>().unwrap(); + sched.schedule_deferred(tof, args) + })?; + let task_delay = + lua.create_function(|lua, (secs, tof, args): (f64, LuaValue, LuaMultiValue)| { + let sched = &mut lua.app_data_mut::<&TaskScheduler>().unwrap(); + sched.schedule_delayed(secs, tof, args) + })?; + // Create our task wait function, this is a bit different since + // we have no way to yield from c / rust, we need to load a + // lua chunk that schedules and yields for us instead + let task_wait_env_thread: LuaFunction = lua.named_registry_value("co.thread")?; + let task_wait_env_yield: LuaFunction = lua.named_registry_value("co.yield")?; + let task_wait = lua + .load(TASK_WAIT_IMPL_LUA) + .set_environment( + TableBuilder::new(lua)? + .with_value("thread", task_wait_env_thread)? + .with_value("yield", task_wait_env_yield)? + .with_function( + "resume_after", + |lua, (thread, secs): (LuaThread, Option)| { + let sched = &mut lua.app_data_mut::<&TaskScheduler>().unwrap(); + sched.resume_after(secs.unwrap_or(0f64), thread) + }, + )? + .build_readonly()?, + )? + .into_function()?; + // We want the task scheduler to be transparent, + // but it does not return real lua threads, so + // we need to override some globals to fake it + let globals = lua.globals(); + let type_original: LuaFunction = globals.get("type")?; + let type_proxy = lua.create_function(move |_, value: LuaValue| { + if let LuaValue::UserData(u) = &value { + if u.is::() { + return Ok(LuaValue::String(lua.create_string("thread")?)); + } + } + type_original.call(value) + })?; + let typeof_original: LuaFunction = globals.get("typeof")?; + let typeof_proxy = lua.create_function(move |_, value: LuaValue| { + if let LuaValue::UserData(u) = &value { + if u.is::() { + return Ok(LuaValue::String(lua.create_string("thread")?)); + } + } + typeof_original.call(value) + })?; + globals.set("type", type_proxy)?; + globals.set("typeof", typeof_proxy)?; + // All good, return the task scheduler lib TableBuilder::new(lua)? - .with_async_function("cancel", task_cancel)? - .with_async_function("delay", task_delay)? - .with_async_function("defer", task_defer)? - .with_async_function("spawn", task_spawn)? - .with_async_function("wait", task_wait)? + .with_value("spawn", task_spawn)? + .with_value("defer", task_defer)? + .with_value("delay", task_delay)? + .with_value("wait", task_wait)? .build_readonly() } - -fn tof_to_thread<'a>( - lua: &'static Lua, - thread_or_function: LuaValue<'a>, -) -> LuaResult> { - match thread_or_function { - LuaValue::Thread(t) => Ok(t), - LuaValue::Function(f) => Ok(lua.create_thread(f)?), - value => Err(LuaError::RuntimeError(format!( - "Argument must be a thread or function, got {}", - value.type_name() - ))), - } -} - -async fn task_cancel<'a>(lua: &'static Lua, thread: LuaThread<'a>) -> LuaResult<()> { - let close: LuaFunction = lua.named_registry_value("coroutine.close")?; - close.call_async::<_, LuaMultiValue>(thread).await?; - Ok(()) -} - -async fn task_defer<'a>( - lua: &'static Lua, - (tof, args): (LuaValue<'a>, LuaMultiValue<'a>), -) -> LuaResult> { - let task_thread = tof_to_thread(lua, tof)?; - let task_thread_key = lua.create_registry_value(task_thread)?; - let task_args_key = lua.create_registry_value(args.into_vec())?; - let lua_thread_to_return = lua.registry_value(&task_thread_key)?; - run_registered_task(lua, TaskRunMode::Deferred, async move { - let thread: LuaThread = lua.registry_value(&task_thread_key)?; - let argsv: Vec = lua.registry_value(&task_args_key)?; - let args = LuaMultiValue::from_vec(argsv); - if thread.status() == LuaThreadStatus::Resumable { - let _: LuaMultiValue = thread.into_async(args).await?; - } - lua.remove_registry_value(task_thread_key)?; - lua.remove_registry_value(task_args_key)?; - Ok(()) - }) - .await?; - Ok(lua_thread_to_return) -} - -async fn task_delay<'a>( - lua: &'static Lua, - (duration, tof, args): (Option, LuaValue<'a>, LuaMultiValue<'a>), -) -> LuaResult> { - let task_thread = tof_to_thread(lua, tof)?; - let task_thread_key = lua.create_registry_value(task_thread)?; - let task_args_key = lua.create_registry_value(args.into_vec())?; - let lua_thread_to_return = lua.registry_value(&task_thread_key)?; - let start = Instant::now(); - let dur = Duration::from_secs_f32( - duration - .map(|d| d.max(MINIMUM_WAIT_OR_DELAY_DURATION)) - .unwrap_or(MINIMUM_WAIT_OR_DELAY_DURATION), - ); - run_registered_task(lua, TaskRunMode::Instant, async move { - let thread: LuaThread = lua.registry_value(&task_thread_key)?; - // NOTE: We are somewhat busy-waiting here, but we have to do this to make sure - // that delayed+cancelled threads do not prevent the tokio runtime from finishing - while thread.status() == LuaThreadStatus::Resumable && start.elapsed() < dur { - time::sleep(Duration::from_millis(1)).await; - } - if thread.status() == LuaThreadStatus::Resumable { - let argsv: Vec = lua.registry_value(&task_args_key)?; - let args = LuaMultiValue::from_vec(argsv); - let _: LuaMultiValue = thread.into_async(args).await?; - } - lua.remove_registry_value(task_thread_key)?; - lua.remove_registry_value(task_args_key)?; - Ok(()) - }) - .await?; - Ok(lua_thread_to_return) -} - -async fn task_spawn<'a>( - lua: &'static Lua, - (tof, args): (LuaValue<'a>, LuaMultiValue<'a>), -) -> LuaResult> { - let task_thread = tof_to_thread(lua, tof)?; - let task_thread_key = lua.create_registry_value(task_thread)?; - let task_args_key = lua.create_registry_value(args.into_vec())?; - let lua_thread_to_return = lua.registry_value(&task_thread_key)?; - run_registered_task(lua, TaskRunMode::Instant, async move { - let thread: LuaThread = lua.registry_value(&task_thread_key)?; - let argsv: Vec = lua.registry_value(&task_args_key)?; - let args = LuaMultiValue::from_vec(argsv); - if thread.status() == LuaThreadStatus::Resumable { - let _: LuaMultiValue = thread.into_async(args).await?; - } - lua.remove_registry_value(task_thread_key)?; - lua.remove_registry_value(task_args_key)?; - Ok(()) - }) - .await?; - Ok(lua_thread_to_return) -} - -async fn task_wait(lua: &'static Lua, duration: Option) -> LuaResult { - let start = Instant::now(); - run_registered_task(lua, TaskRunMode::Blocking, async move { - time::sleep(Duration::from_secs_f32( - duration - .map(|d| d.max(MINIMUM_WAIT_OR_DELAY_DURATION)) - .unwrap_or(MINIMUM_WAIT_OR_DELAY_DURATION), - )) - .await; - Ok(()) - }) - .await?; - let end = Instant::now(); - Ok((end - start).as_secs_f32()) -} diff --git a/packages/lib/src/globals/top_level.rs b/packages/lib/src/globals/top_level.rs index dc61f9f..16b3081 100644 --- a/packages/lib/src/globals/top_level.rs +++ b/packages/lib/src/globals/top_level.rs @@ -6,15 +6,10 @@ use crate::utils::{ }; pub fn create(lua: &'static Lua) -> LuaResult { - let globals = lua.globals(); // HACK: We need to preserve the default behavior of the // print and error functions, for pcall and such, which // is really tricky to do from scratch so we will just // proxy the default print and error functions here - let print_fn: LuaFunction = globals.raw_get("print")?; - let error_fn: LuaFunction = globals.raw_get("error")?; - lua.set_named_registry_value("print", print_fn)?; - lua.set_named_registry_value("error", error_fn)?; TableBuilder::new(lua)? .with_function("print", |lua, args: LuaMultiValue| { let formatted = pretty_format_multi_value(&args)?; diff --git a/packages/lib/src/lib.rs b/packages/lib/src/lib.rs index 8317f86..36080e7 100644 --- a/packages/lib/src/lib.rs +++ b/packages/lib/src/lib.rs @@ -1,8 +1,8 @@ -use std::{collections::HashSet, process::ExitCode, sync::Arc}; +use std::{collections::HashSet, process::ExitCode}; +use lua::task::TaskScheduler; use mlua::prelude::*; -use tokio::{sync::mpsc, task}; -use utils::task::send_message; +use tokio::task::LocalSet; pub(crate) mod globals; pub(crate) mod lua; @@ -11,7 +11,7 @@ pub(crate) mod utils; #[cfg(test)] mod tests; -use crate::utils::{formatting::pretty_format_luau_error, message::LuneMessage}; +use crate::utils::formatting::pretty_format_luau_error; pub use globals::LuneGlobal; @@ -75,12 +75,12 @@ impl Lune { This will create a new sandboxed Luau environment with the configured globals and arguments, running inside of a [`tokio::task::LocalSet`]. - Some Lune globals such as [`LuneGlobal::Process`] may spawn - separate tokio tasks on other threads, but the Luau environment + Some Lune globals such as [`LuneGlobal::Process`] and [`LuneGlobal::Net`] + may spawn separate tokio tasks on other threads, but the Luau environment itself is guaranteed to run on a single thread in the local set. - Note that this will create a static Lua instance that will live - for the remainer of the program, and that this leaks memory using + Note that this will create a static Lua instance and task scheduler which both + will live for the remainer of the program, and that this leaks memory using [`Box::leak`] that will then get deallocated when the program exits. */ pub async fn run( @@ -88,92 +88,64 @@ impl Lune { script_name: &str, script_contents: &str, ) -> Result { - let task_set = task::LocalSet::new(); - let (sender, mut receiver) = mpsc::channel::(64); + let set = LocalSet::new(); let lua = Lua::new().into_static(); - let snd = Arc::new(sender); - lua.set_app_data(Arc::downgrade(&snd)); + let sched = TaskScheduler::new(lua)?.into_static(); + lua.set_app_data(sched); + // Store original lua global functions in the registry so we can use + // them later without passing them around and dealing with lifetimes + lua.set_named_registry_value("require", lua.globals().get::<_, LuaFunction>("require")?)?; + lua.set_named_registry_value("print", lua.globals().get::<_, LuaFunction>("print")?)?; + lua.set_named_registry_value("error", lua.globals().get::<_, LuaFunction>("error")?)?; + let coroutine: LuaTable = lua.globals().get("coroutine")?; + lua.set_named_registry_value("co.thread", coroutine.get::<_, LuaFunction>("running")?)?; + lua.set_named_registry_value("co.yield", coroutine.get::<_, LuaFunction>("yield")?)?; + let debug: LuaTable = lua.globals().raw_get("debug")?; + lua.set_named_registry_value("dbg.info", debug.get::<_, LuaFunction>("info")?)?; // Add in wanted lune globals for global in self.includes.clone() { if !self.excludes.contains(&global) { - global.inject(lua)?; + global.inject(lua, sched)?; } } - // Spawn the main thread from our entrypoint script - let script_name = script_name.to_string(); - let script_chunk = script_contents.to_string(); - send_message(lua, LuneMessage::Spawned).await?; - task_set.spawn_local(async move { - let result = lua - .load(&script_chunk) - .set_name(&format!("={script_name}")) - .unwrap() - .eval_async::() - .await; - send_message( - lua, - match result { - Err(e) => LuneMessage::LuaError(e), - Ok(_) => LuneMessage::Finished, - }, - ) - .await - }); - // Run the executor until there are no tasks left, - // taking care to not exit right away for errors - let (got_code, got_error, exit_code) = task_set + // Schedule the main thread on the task scheduler + sched.schedule_instant( + LuaValue::Function( + lua.load(script_contents) + .set_name(script_name) + .unwrap() + .into_function() + .unwrap(), + ), + LuaValue::Nil.to_lua_multi(lua)?, + )?; + // Keep running the scheduler until there are either no tasks + // left to run, or until some task requests to exit the process + let exit_code = set .run_until(async { - let mut task_count = 0; let mut got_error = false; - let mut got_code = false; - let mut exit_code = 0; - while let Some(message) = receiver.recv().await { - // Make sure our task-count-modifying messages are sent correctly, one - // task spawned must always correspond to one task finished / errored - match &message { - LuneMessage::Exit(_) => {} - LuneMessage::Spawned => {} - message => { - if task_count == 0 { - return Err(format!( - "Got message while task count was 0!\nMessage: {message:#?}" - )); + while let Some(result) = sched.resume_queue().await { + match result { + Err(e) => { + eprintln!("{}", pretty_format_luau_error(&e)); + got_error = true; + } + Ok(status) => { + if let Some(exit_code) = status.exit_code { + return exit_code; + } else if status.num_total == 0 { + return ExitCode::SUCCESS; } } } - // Handle whatever message we got - match message { - LuneMessage::Exit(code) => { - exit_code = code; - got_code = true; - break; - } - LuneMessage::Spawned => task_count += 1, - LuneMessage::Finished => task_count -= 1, - LuneMessage::LuaError(e) => { - eprintln!("{}", pretty_format_luau_error(&e)); - got_error = true; - task_count -= 1; - } - }; - // If there are no tasks left running, it is now - // safe to close the receiver and end execution - if task_count == 0 { - receiver.close(); - } } - Ok((got_code, got_error, exit_code)) + if got_error { + ExitCode::FAILURE + } else { + ExitCode::SUCCESS + } }) - .await - .map_err(LuaError::external)?; - // If we got an error, we will default to exiting - // with code 1, unless a code was manually given - if got_code { - Ok(ExitCode::from(exit_code)) - } else if got_error { - Ok(ExitCode::FAILURE) - } else { - Ok(ExitCode::SUCCESS) - } + .await; + Ok(exit_code) } } diff --git a/packages/lib/src/lua/mod.rs b/packages/lib/src/lua/mod.rs index f9faf2f..3be432e 100644 --- a/packages/lib/src/lua/mod.rs +++ b/packages/lib/src/lua/mod.rs @@ -1 +1,2 @@ pub mod net; +pub mod task; diff --git a/packages/lib/src/lua/task/mod.rs b/packages/lib/src/lua/task/mod.rs new file mode 100644 index 0000000..968d87d --- /dev/null +++ b/packages/lib/src/lua/task/mod.rs @@ -0,0 +1,3 @@ +mod scheduler; + +pub use scheduler::*; diff --git a/packages/lib/src/lua/task/scheduler.rs b/packages/lib/src/lua/task/scheduler.rs new file mode 100644 index 0000000..5230111 --- /dev/null +++ b/packages/lib/src/lua/task/scheduler.rs @@ -0,0 +1,406 @@ +use std::{ + collections::{HashMap, VecDeque}, + fmt, + process::ExitCode, + sync::{ + atomic::{AtomicBool, AtomicUsize, Ordering}, + Arc, Mutex, + }, + time::Duration, +}; + +use mlua::prelude::*; + +use tokio::time::{sleep, Instant}; + +type TaskSchedulerQueue = Arc>>; + +/// An enum representing different kinds of tasks +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum TaskKind { + Instant, + Deferred, + Yielded, +} + +impl fmt::Display for TaskKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let name: &'static str = match self { + TaskKind::Instant => "Instant", + TaskKind::Deferred => "Deferred", + TaskKind::Yielded => "Yielded", + }; + write!(f, "{name}") + } +} + +/// A lightweight, clonable struct that represents a +/// task in the scheduler and is accessible from Lua +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct TaskReference { + kind: TaskKind, + guid: usize, + queued_target: Option, +} + +impl TaskReference { + pub const fn new(kind: TaskKind, guid: usize, queued_target: Option) -> Self { + Self { + kind, + guid, + queued_target, + } + } +} + +impl fmt::Display for TaskReference { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "TaskReference({} - {})", self.kind, self.guid) + } +} + +impl LuaUserData for TaskReference {} + +impl From<&Task> for TaskReference { + fn from(value: &Task) -> Self { + Self::new(value.kind, value.guid, value.queued_target) + } +} + +/// A struct representing a task contained in the task scheduler +#[derive(Debug)] +pub struct Task { + kind: TaskKind, + guid: usize, + thread: LuaRegistryKey, + args: LuaRegistryKey, + queued_at: Instant, + queued_target: Option, +} + +/// A struct representing the current status of the task scheduler +#[derive(Debug, Clone, Copy)] +pub struct TaskSchedulerStatus { + pub exit_code: Option, + pub num_instant: usize, + pub num_deferred: usize, + pub num_yielded: usize, + pub num_total: usize, +} + +impl fmt::Display for TaskSchedulerStatus { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "TaskSchedulerStatus(\nInstant: {}\nDeferred: {}\nYielded: {}\nTotal: {})", + self.num_instant, self.num_deferred, self.num_yielded, self.num_total + ) + } +} + +/// A task scheduler that implements task queues +/// with instant, deferred, and delayed tasks +#[derive(Debug)] +pub struct TaskScheduler { + lua: &'static Lua, + guid: AtomicUsize, + running: bool, + tasks: Arc>>, + task_queue_instant: TaskSchedulerQueue, + task_queue_deferred: TaskSchedulerQueue, + task_queue_yielded: TaskSchedulerQueue, + exit_code_set: AtomicBool, + exit_code: Arc>, +} + +impl TaskScheduler { + pub fn new(lua: &'static Lua) -> LuaResult { + Ok(Self { + lua, + guid: AtomicUsize::new(0), + running: false, + tasks: Arc::new(Mutex::new(HashMap::new())), + task_queue_instant: Arc::new(Mutex::new(VecDeque::new())), + task_queue_deferred: Arc::new(Mutex::new(VecDeque::new())), + task_queue_yielded: Arc::new(Mutex::new(VecDeque::new())), + exit_code_set: AtomicBool::new(false), + exit_code: Arc::new(Mutex::new(ExitCode::SUCCESS)), + }) + } + + pub fn into_static(self) -> &'static Self { + Box::leak(Box::new(self)) + } + + pub fn status(&self) -> TaskSchedulerStatus { + let counts = { + ( + self.task_queue_instant.lock().unwrap().len(), + self.task_queue_deferred.lock().unwrap().len(), + self.task_queue_yielded.lock().unwrap().len(), + ) + }; + let num_total = counts.0 + counts.1 + counts.2; + let exit_code = if self.exit_code_set.load(Ordering::Relaxed) { + Some(*self.exit_code.lock().unwrap()) + } else { + None + }; + TaskSchedulerStatus { + exit_code, + num_instant: counts.0, + num_deferred: counts.1, + num_yielded: counts.2, + num_total, + } + } + + pub fn set_exit_code(&self, code: ExitCode) { + self.exit_code_set.store(true, Ordering::Relaxed); + *self.exit_code.lock().unwrap() = code + } + + fn schedule<'a>( + &self, + kind: TaskKind, + tof: LuaValue<'a>, + args: Option>, + delay: Option, + ) -> LuaResult { + // Get or create a thread from the given argument + let task_thread = match tof { + LuaValue::Thread(t) => t, + LuaValue::Function(f) => self.lua.create_thread(f)?, + value => { + return Err(LuaError::RuntimeError(format!( + "Argument must be a thread or function, got {}", + value.type_name() + ))) + } + }; + // Store the thread and its arguments in the registry + let task_args_vec = args.map(|opt| opt.into_vec()); + let task_thread_key = self.lua.create_registry_value(task_thread)?; + let task_args_key = self.lua.create_registry_value(task_args_vec)?; + // Create the full task struct + let guid = self.guid.fetch_add(1, Ordering::Relaxed) + 1; + let queued_at = Instant::now(); + let queued_target = delay.map(|secs| queued_at + Duration::from_secs_f64(secs)); + let task = Task { + kind, + guid, + thread: task_thread_key, + args: task_args_key, + queued_at, + queued_target, + }; + // Create the task ref (before adding the task to the scheduler) + let task_ref = TaskReference::from(&task); + // Add it to the scheduler + { + let mut tasks = self.tasks.lock().unwrap(); + tasks.insert(task_ref, task); + } + match kind { + TaskKind::Instant => { + // If we have a currently running task and we spawned an + // instant task here it should run right after the currently + // running task, so put it at the front of the task queue + let mut queue = self.task_queue_instant.lock().unwrap(); + if self.running { + queue.push_front(task_ref); + } else { + queue.push_back(task_ref); + } + } + TaskKind::Deferred => { + // Deferred tasks should always schedule + // at the very end of the deferred queue + let mut queue = self.task_queue_deferred.lock().unwrap(); + queue.push_back(task_ref); + } + TaskKind::Yielded => { + // Find the first task that is scheduled after this one and insert before it, + // this will ensure that our list of delayed tasks is sorted and we can grab + // the very first one to figure out how long to yield until the next cycle + let mut queue = self.task_queue_yielded.lock().unwrap(); + let idx = queue + .iter() + .enumerate() + .find_map(|(idx, t)| { + if t.queued_target > queued_target { + Some(idx) + } else { + None + } + }) + .unwrap_or(queue.len()); + queue.insert(idx, task_ref); + } + } + Ok(task_ref) + } + + pub fn schedule_instant<'a>( + &self, + tof: LuaValue<'a>, + args: LuaMultiValue<'a>, + ) -> LuaResult { + self.schedule(TaskKind::Instant, tof, Some(args), None) + } + + pub fn schedule_deferred<'a>( + &self, + tof: LuaValue<'a>, + args: LuaMultiValue<'a>, + ) -> LuaResult { + self.schedule(TaskKind::Deferred, tof, Some(args), None) + } + + pub fn schedule_delayed<'a>( + &self, + secs: f64, + tof: LuaValue<'a>, + args: LuaMultiValue<'a>, + ) -> LuaResult { + self.schedule(TaskKind::Yielded, tof, Some(args), Some(secs)) + } + + pub fn resume_after(&self, secs: f64, thread: LuaThread<'_>) -> LuaResult { + self.schedule( + TaskKind::Yielded, + LuaValue::Thread(thread), + None, + Some(secs), + ) + } + + pub fn cancel(&self, reference: TaskReference) -> bool { + let queue_mutex = match reference.kind { + TaskKind::Instant => &self.task_queue_instant, + TaskKind::Deferred => &self.task_queue_deferred, + TaskKind::Yielded => &self.task_queue_yielded, + }; + let mut queue = queue_mutex.lock().unwrap(); + let mut found = false; + queue.retain(|task| { + if task.guid == reference.guid { + found = true; + false + } else { + true + } + }); + found + } + + pub fn resume_task(&self, reference: TaskReference) -> LuaResult<()> { + let task = { + let mut tasks = self.tasks.lock().unwrap(); + match tasks.remove(&reference) { + Some(task) => task, + None => { + return Err(LuaError::RuntimeError(format!( + "Task does not exist in scheduler: {reference}" + ))) + } + } + }; + let thread: LuaThread = self.lua.registry_value(&task.thread)?; + let args: Option> = self.lua.registry_value(&task.args)?; + if let Some(args) = args { + thread.resume::<_, LuaMultiValue>(LuaMultiValue::from_vec(args))?; + } else { + let elapsed = task.queued_at.elapsed().as_secs_f64(); + thread.resume::<_, LuaMultiValue>(elapsed)?; + } + self.lua.remove_registry_value(task.thread)?; + self.lua.remove_registry_value(task.args)?; + Ok(()) + } + + fn get_queue(&self, kind: TaskKind) -> &TaskSchedulerQueue { + match kind { + TaskKind::Instant => &self.task_queue_instant, + TaskKind::Deferred => &self.task_queue_deferred, + TaskKind::Yielded => &self.task_queue_yielded, + } + } + + fn next_queue_task(&self, kind: TaskKind) -> Option { + let task = { + let queue_guard = self.get_queue(kind).lock().unwrap(); + queue_guard.front().copied() + }; + task + } + + fn resume_next_queue_task(&self, kind: TaskKind) -> Option> { + match { + let mut queue_guard = self.get_queue(kind).lock().unwrap(); + queue_guard.pop_front() + } { + None => { + let status = self.status(); + if status.num_total > 0 { + Some(Ok(status)) + } else { + None + } + } + Some(t) => match self.resume_task(t) { + Ok(_) => Some(Ok(self.status())), + Err(e) => Some(Err(e)), + }, + } + } + + pub async fn resume_queue(&self) -> Option> { + let now = Instant::now(); + let status = self.status(); + /* + Resume tasks in the internal queue, in this order: + + 1. Tasks from task.spawn, this includes the main thread + 2. Tasks from task.defer + 3. Tasks from task.delay OR futures, whichever comes first + 4. Tasks from futures + */ + if status.num_instant > 0 { + self.resume_next_queue_task(TaskKind::Instant) + } else if status.num_deferred > 0 { + self.resume_next_queue_task(TaskKind::Deferred) + } else if status.num_yielded > 0 { + // 3. Threads from task.delay or task.wait, futures + let next_yield_target = self + .next_queue_task(TaskKind::Yielded) + .expect("Yielded task missing but status count is > 0") + .queued_target + .expect("Yielded task is missing queued target"); + // Resume this yielding task if its target time has passed + if now >= next_yield_target { + self.resume_next_queue_task(TaskKind::Yielded) + } else { + /* + Await the first future to be ready + + - If it is the sleep fut then we will return and the next + call to resume_queue will then resume that yielded task + + - If it is a future then we resume the corresponding task + that is has stored in the future-specific task queue + */ + sleep(next_yield_target - now).await; + // TODO: Implement this, for now we only await sleep + // since the task scheduler doesn't support futures + Some(Ok(self.status())) + } + } else { + // 4. Just futures + + // TODO: Await the first future to be ready + // and resume the corresponding task for it + None + } + } +} diff --git a/packages/lib/src/tests.rs b/packages/lib/src/tests.rs index 3fef645..da80e12 100644 --- a/packages/lib/src/tests.rs +++ b/packages/lib/src/tests.rs @@ -11,7 +11,7 @@ const ARGS: &[&str] = &["Foo", "Bar"]; macro_rules! create_tests { ($($name:ident: $value:expr,)*) => { $( - #[tokio::test] + #[tokio::test(flavor = "multi_thread")] async fn $name() -> Result { // Disable styling for stdout and stderr since // some tests rely on output not being styled diff --git a/packages/lib/src/utils/formatting.rs b/packages/lib/src/utils/formatting.rs index 6bfdb6d..e4ab5f7 100644 --- a/packages/lib/src/utils/formatting.rs +++ b/packages/lib/src/utils/formatting.rs @@ -4,6 +4,8 @@ use console::{style, Style}; use lazy_static::lazy_static; use mlua::prelude::*; +use crate::lua::task::TaskReference; + const MAX_FORMAT_DEPTH: usize = 4; const INDENT: &str = " "; @@ -165,9 +167,17 @@ pub fn pretty_format_value( )?, LuaValue::Thread(_) => write!(buffer, "{}", COLOR_PURPLE.apply_to(""))?, LuaValue::Function(_) => write!(buffer, "{}", COLOR_PURPLE.apply_to(""))?, - LuaValue::UserData(_) | LuaValue::LightUserData(_) => { - write!(buffer, "{}", COLOR_PURPLE.apply_to(""))? + LuaValue::UserData(u) => { + if u.is::() { + // Task references must be transparent + // to lua and pretend to be normal lua + // threads for compatibility purposes + write!(buffer, "{}", COLOR_PURPLE.apply_to(""))? + } else { + write!(buffer, "{}", COLOR_PURPLE.apply_to(""))? + } } + LuaValue::LightUserData(_) => write!(buffer, "{}", COLOR_PURPLE.apply_to(""))?, _ => write!(buffer, "{}", STYLE_DIM.apply_to("?"))?, } Ok(()) @@ -220,23 +230,30 @@ pub fn pretty_format_luau_error(e: &LuaError) -> String { err_lines.join("\n") } LuaError::CallbackError { traceback, cause } => { - // Find the best traceback (longest) and the root error message + // Find the best traceback (most lines) and the root error message let mut best_trace = traceback; let mut root_cause = cause.as_ref(); while let LuaError::CallbackError { cause, traceback } = root_cause { - if traceback.len() > best_trace.len() { + if traceback.lines().count() > best_trace.len() { best_trace = traceback; } root_cause = cause; } - // Same error formatting as above - format!( - "{}\n{}\n{}\n{}", - pretty_format_luau_error(root_cause), - stack_begin, - best_trace.strip_prefix("stack traceback:\n").unwrap(), - stack_end - ) + // If we got a runtime error with an embedded traceback, we should + // use that instead since it generally contains more information + if matches!(root_cause, LuaError::RuntimeError(e) if e.contains("stack traceback:")) { + pretty_format_luau_error(root_cause) + } else { + // Otherwise we format whatever root error we got using + // the same error formatting as for above runtime errors + format!( + "{}\n{}\n{}\n{}", + pretty_format_luau_error(root_cause), + stack_begin, + best_trace.strip_prefix("stack traceback:\n").unwrap(), + stack_end + ) + } } LuaError::ToLuaConversionError { from, to, message } => { let msg = message diff --git a/packages/lib/src/utils/message.rs b/packages/lib/src/utils/message.rs deleted file mode 100644 index 50f94c9..0000000 --- a/packages/lib/src/utils/message.rs +++ /dev/null @@ -1,9 +0,0 @@ -use mlua::prelude::*; - -#[derive(Debug, Clone)] -pub enum LuneMessage { - Exit(u8), - Spawned, - Finished, - LuaError(LuaError), -} diff --git a/packages/lib/src/utils/mod.rs b/packages/lib/src/utils/mod.rs index dcf93ed..61eebf6 100644 --- a/packages/lib/src/utils/mod.rs +++ b/packages/lib/src/utils/mod.rs @@ -1,7 +1,5 @@ pub mod formatting; pub mod futures; -pub mod message; pub mod net; pub mod process; pub mod table; -pub mod task; diff --git a/packages/lib/src/utils/process.rs b/packages/lib/src/utils/process.rs index 76b7a22..d211c38 100644 --- a/packages/lib/src/utils/process.rs +++ b/packages/lib/src/utils/process.rs @@ -1,11 +1,9 @@ -use std::{process::ExitStatus, time::Duration}; +use std::process::ExitStatus; use mlua::prelude::*; -use tokio::{io, process::Child, task::spawn, time::sleep}; +use tokio::{io, process::Child, task::spawn}; -use crate::utils::{futures::AsyncTeeWriter, message::LuneMessage}; - -use super::task::send_message; +use crate::utils::futures::AsyncTeeWriter; pub async fn pipe_and_inherit_child_process_stdio( mut child: Child, @@ -42,13 +40,3 @@ pub async fn pipe_and_inherit_child_process_stdio( Ok::<_, LuaError>((status, stdout_buffer?, stderr_buffer?)) } - -pub async fn exit_and_yield_forever(lua: &'static Lua, exit_code: Option) -> LuaResult<()> { - // Send an exit signal to the main thread, which - // will try to exit safely and as soon as possible - send_message(lua, LuneMessage::Exit(exit_code.unwrap_or(0))).await?; - // Make sure to block the rest of this thread indefinitely since - // the main thread may not register the exit signal right away - sleep(Duration::MAX).await; - Ok(()) -} diff --git a/packages/lib/src/utils/task.rs b/packages/lib/src/utils/task.rs deleted file mode 100644 index db55a9d..0000000 --- a/packages/lib/src/utils/task.rs +++ /dev/null @@ -1,76 +0,0 @@ -use std::fmt::{self, Debug}; -use std::future::Future; -use std::sync::Weak; - -use mlua::prelude::*; -use tokio::sync::mpsc::Sender; -use tokio::task; - -use crate::utils::message::LuneMessage; - -#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] -pub enum TaskRunMode { - Blocking, - Instant, - Deferred, -} - -impl fmt::Display for TaskRunMode { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Blocking => write!(f, "Blocking"), - Self::Instant => write!(f, "Instant"), - Self::Deferred => write!(f, "Deferred"), - } - } -} - -pub async fn send_message(lua: &'static Lua, message: LuneMessage) -> LuaResult<()> { - let sender = lua - .app_data_ref::>>() - .unwrap() - .upgrade() - .unwrap(); - sender.send(message).await.map_err(LuaError::external) -} - -pub async fn run_registered_task( - lua: &'static Lua, - mode: TaskRunMode, - to_run: impl Future> + 'static, -) -> LuaResult<()> { - // Send a message that we have started our task - send_message(lua, LuneMessage::Spawned).await?; - // Run the new task separately from the current one using the executor - let task = task::spawn_local(async move { - // HACK: For deferred tasks we yield a bunch of times to try and ensure - // we run our task at the very end of the async queue, this can fail if - // the user creates a bunch of interleaved deferred and normal tasks - if mode == TaskRunMode::Deferred { - for _ in 0..64 { - task::yield_now().await; - } - } - send_message( - lua, - match to_run.await { - Ok(_) => LuneMessage::Finished, - Err(LuaError::CoroutineInactive) => LuneMessage::Finished, // Task was canceled - Err(e) => LuneMessage::LuaError(e), - }, - ) - .await - }); - // Wait for the task to complete if we want this call to be blocking - // Any lua errors will be sent through the message channel back - // to the main thread which will then handle them properly - if mode == TaskRunMode::Blocking { - task.await - .map_err(LuaError::external)? - .map_err(LuaError::external)?; - } - // Yield once right away to let the above spawned task start working - // instantly, forcing it to run until completion or until it yields - task::yield_now().await; - Ok(()) -} diff --git a/tests/stdio/write.luau b/tests/stdio/write.luau index b77c81e..4736933 100644 --- a/tests/stdio/write.luau +++ b/tests/stdio/write.luau @@ -1,3 +1 @@ stdio.write("Hello, stdout!") - -process.exit(0) diff --git a/tests/task/defer.luau b/tests/task/defer.luau index df95338..01123f8 100644 --- a/tests/task/defer.luau +++ b/tests/task/defer.luau @@ -10,18 +10,18 @@ task.defer(function() flag = true end) assert(not flag, "Defer should not run instantly or block") -task.wait(0.1) +task.wait(0.05) assert(flag, "Defer should run") -- Deferred functions should work with yielding local flag2: boolean = false task.defer(function() - task.wait(0.1) + task.wait(0.05) flag2 = true end) assert(not flag2, "Defer should work with yielding (1)") -task.wait(0.2) +task.wait(0.1) assert(flag2, "Defer should work with yielding (2)") -- Deferred functions should run after other spawned threads diff --git a/tests/task/delay.luau b/tests/task/delay.luau index 7d00546..667eb5f 100644 --- a/tests/task/delay.luau +++ b/tests/task/delay.luau @@ -10,20 +10,20 @@ task.delay(0, function() flag = true end) assert(not flag, "Delay should not run instantly or block") -task.wait(1 / 60) +task.wait(0.05) assert(flag, "Delay should run after the wanted duration") -- Delayed functions should work with yielding local flag2: boolean = false -task.delay(0.2, function() +task.delay(0.05, function() flag2 = true - task.wait(0.4) + task.wait(0.1) flag2 = false end) -task.wait(0.4) +task.wait(0.1) assert(flag, "Delay should work with yielding (1)") -task.wait(0.4) +task.wait(0.1) assert(not flag2, "Delay should work with yielding (2)") -- Varargs should get passed correctly diff --git a/tests/task/spawn.luau b/tests/task/spawn.luau index 7080d05..f01344a 100644 --- a/tests/task/spawn.luau +++ b/tests/task/spawn.luau @@ -15,11 +15,11 @@ assert(flag, "Spawn should run instantly") local flag2: boolean = false task.spawn(function() - task.wait(0.1) + task.wait(0.05) flag2 = true end) assert(not flag2, "Spawn should work with yielding (1)") -task.wait(0.2) +task.wait(0.1) assert(flag2, "Spawn should work with yielding (2)") -- Spawned functions should be able to run threads created with the coroutine global diff --git a/tests/task/wait.luau b/tests/task/wait.luau index 9e6efc9..d7dde6f 100644 --- a/tests/task/wait.luau +++ b/tests/task/wait.luau @@ -5,18 +5,28 @@ local EPSILON = 1 / 100 local function test(expected: number) local start = os.clock() local returned = task.wait(expected) + if typeof(returned) ~= "number" then + error( + string.format( + "Expected task.wait to return a number, got %s %s", + typeof(returned), + stdio.format(returned) + ), + 2 + ) + end local elapsed = (os.clock() - start) - local difference = math.abs(elapsed - returned) + local difference = math.abs(elapsed - expected) if difference > EPSILON then error( string.format( "Elapsed time diverged too much from argument!" - .. "\nGot argument of %.3fs and elapsed time of %.3fs" - .. "\nGot maximum difference of %.3fs and real difference of %.3fs", - expected, - elapsed, - EPSILON, - difference + .. "\nGot argument of %.3fms and elapsed time of %.3fms" + .. "\nGot maximum difference of %.3fms and real difference of %.3fms", + expected * 1_000, + elapsed * 1_000, + EPSILON * 1_000, + difference * 1_000 ) ) end