diff --git a/Cargo.toml b/Cargo.toml index e880e69..49144ac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,3 +9,6 @@ mlua = { version = "0.9", features = ["luau", "luau-jit", "async"] } [lib] path = "lib/lib.rs" + +[examples] +main = "examples/main.rs" diff --git a/examples/main.luau b/examples/main.luau new file mode 100644 index 0000000..f598983 --- /dev/null +++ b/examples/main.luau @@ -0,0 +1,25 @@ +--!nocheck +--!nolint UnknownGlobal + + +if math.random() < 0.25 then + error("Unlucky error!") +end + + +local main = coroutine.running() +local start = os.clock() + +local counter = 0 +for j = 1, 10_000 do + __runtime__spawn(function() + wait() + counter += 1 + if counter == 10_000 then + local elapsed = os.clock() - start + __runtime__spawn(main, elapsed) + end + end) +end + +return coroutine.yield() diff --git a/examples/main.rs b/examples/main.rs new file mode 100644 index 0000000..f1ed537 --- /dev/null +++ b/examples/main.rs @@ -0,0 +1,86 @@ +use std::{ + rc::Rc, + time::{Duration, Instant}, +}; + +use smol_mlua::{ + mlua::prelude::{Lua, LuaResult, LuaThread, LuaValue}, + smol::{lock::Mutex, Timer}, + Callbacks, IntoLuaThread, Runtime, +}; + +const MAIN_SCRIPT: &str = include_str!("./main.luau"); + +pub fn main() -> LuaResult<()> { + // Set up persistent lua environment + let lua = Lua::new(); + lua.globals().set( + "wait", + lua.create_async_function(|_, duration: Option| async move { + let duration = duration.unwrap_or_default().max(1.0 / 250.0); + let before = Instant::now(); + let after = Timer::after(Duration::from_secs_f64(duration)).await; + Ok((after - before).as_secs_f64()) + })?, + )?; + + // Load and run the main script a few times for the purposes of this example + for _ in 0..20 { + println!("..."); + match run(&lua, lua.load(MAIN_SCRIPT)) { + Err(e) => eprintln!("Errored:\n{e}"), + Ok(v) => println!("Returned value:\n{v:?}"), + } + } + + Ok(()) +} + +/** + Wrapper function to run the given `main` thread on a new [`Runtime`]. + + Waits for all threads to finish, including the main thread, and + returns the value or error of the main thread once exited. +*/ +fn run<'lua>(lua: &'lua Lua, main: impl IntoLuaThread<'lua>) -> LuaResult { + // Set up runtime (thread queue / async executors) + let rt = Runtime::new(lua)?; + let thread = rt.push_main(lua, main, ()); + lua.set_named_registry_value("mainThread", thread)?; + + // Add callbacks to capture resulting value/error of main thread, + // we need to do some tricks to get around lifetime issues with 'lua + // being different inside the callback vs outside the callback for LuaValue + let captured_error = Rc::new(Mutex::new(None)); + let captured_error_inner = Rc::clone(&captured_error); + Callbacks::new() + .on_value(|lua, thread, val| { + let main: LuaThread = lua.named_registry_value("mainThread").unwrap(); + if main == thread { + lua.set_named_registry_value("mainValue", val).unwrap(); + } + }) + .on_error(move |lua, thread, err| { + let main: LuaThread = lua.named_registry_value("mainThread").unwrap(); + if main == thread { + captured_error_inner.lock_blocking().replace(err); + } + }) + .inject(lua); + + // Run until end + rt.run_blocking(lua); + + // Extract value and error from their containers + let err_opt = { captured_error.lock_blocking().take() }; + let val_opt = lua.named_registry_value("mainValue").ok(); + + // Check result + if let Some(err) = err_opt { + Err(err) + } else if let Some(val) = val_opt { + Ok(val) + } else { + unreachable!("No value or error captured from main thread"); + } +} diff --git a/lib/callbacks.rs b/lib/callbacks.rs index 180c25c..1966b32 100644 --- a/lib/callbacks.rs +++ b/lib/callbacks.rs @@ -1,12 +1,15 @@ use mlua::prelude::*; -type ErrorCallback = Box Fn(&'lua Lua, LuaThread<'lua>, LuaError) + 'static>; type ValueCallback = Box Fn(&'lua Lua, LuaThread<'lua>, LuaValue<'lua>) + 'static>; +type ErrorCallback = Box Fn(&'lua Lua, LuaThread<'lua>, LuaError) + 'static>; + +const FORWARD_VALUE_KEY: &str = "__runtime__forwardValue"; +const FORWARD_ERROR_KEY: &str = "__runtime__forwardError"; #[derive(Default)] pub struct Callbacks { - on_error: Option, on_value: Option, + on_error: Option, } impl Callbacks { @@ -14,14 +17,6 @@ impl Callbacks { Default::default() } - pub fn on_error(mut self, f: F) -> Self - where - F: Fn(&Lua, LuaThread, LuaError) + 'static, - { - self.on_error.replace(Box::new(f)); - self - } - pub fn on_value(mut self, f: F) -> Self where F: Fn(&Lua, LuaThread, LuaValue) + 'static, @@ -30,23 +25,19 @@ impl Callbacks { self } - pub fn inject(self, lua: &Lua) { - // Create functions to forward errors & values - if let Some(f) = self.on_error { - lua.set_named_registry_value( - "__forward__error", - lua.create_function(move |lua, (thread, err): (LuaThread, LuaError)| { - f(lua, thread, err); - Ok(()) - }) - .expect("failed to create error callback function"), - ) - .expect("failed to store error callback function"); - } + pub fn on_error(mut self, f: F) -> Self + where + F: Fn(&Lua, LuaThread, LuaError) + 'static, + { + self.on_error.replace(Box::new(f)); + self + } + pub fn inject(self, lua: &Lua) { + // Create functions to forward values & errors if let Some(f) = self.on_value { lua.set_named_registry_value( - "__forward__value", + FORWARD_VALUE_KEY, lua.create_function(move |lua, (thread, val): (LuaThread, LuaValue)| { f(lua, thread, val); Ok(()) @@ -55,17 +46,29 @@ impl Callbacks { ) .expect("failed to store value callback function"); } - } - pub(crate) fn forward_error(lua: &Lua, thread: LuaThread, error: LuaError) { - if let Ok(f) = lua.named_registry_value::("__forward__error") { - f.call::<_, ()>((thread, error)).unwrap(); + if let Some(f) = self.on_error { + lua.set_named_registry_value( + FORWARD_ERROR_KEY, + lua.create_function(move |lua, (thread, err): (LuaThread, LuaError)| { + f(lua, thread, err); + Ok(()) + }) + .expect("failed to create error callback function"), + ) + .expect("failed to store error callback function"); } } pub(crate) fn forward_value(lua: &Lua, thread: LuaThread, value: LuaValue) { - if let Ok(f) = lua.named_registry_value::("__forward__value") { + if let Ok(f) = lua.named_registry_value::(FORWARD_VALUE_KEY) { f.call::<_, ()>((thread, value)).unwrap(); } } + + pub(crate) fn forward_error(lua: &Lua, thread: LuaThread, error: LuaError) { + if let Ok(f) = lua.named_registry_value::(FORWARD_ERROR_KEY) { + f.call::<_, ()>((thread, error)).unwrap(); + } + } } diff --git a/lib/runtime.rs b/lib/runtime.rs index ff46706..7ec49ab 100644 --- a/lib/runtime.rs +++ b/lib/runtime.rs @@ -13,6 +13,9 @@ use super::{ callbacks::Callbacks, storage::ThreadWithArgs, traits::IntoLuaThread, util::LuaThreadOrFunction, }; +const GLOBAL_NAME_SPAWN: &str = "__runtime__spawn"; +const GLOBAL_NAME_DEFER: &str = "__runtime__defer"; + pub struct Runtime { queue_status: Rc>, queue_spawn: Rc>>, @@ -94,10 +97,9 @@ impl Runtime { }, )?; - // FUTURE: Store these as named registry values instead - // so that they are not accessible from within user code - lua.globals().set("spawn", fn_spawn)?; - lua.globals().set("defer", fn_defer)?; + // Store them both as globals + lua.globals().set(GLOBAL_NAME_SPAWN, fn_spawn)?; + lua.globals().set(GLOBAL_NAME_DEFER, fn_defer)?; Ok(Runtime { queue_status, @@ -196,7 +198,7 @@ impl Runtime { } }; - main_exec.run(fut).await + main_exec.run(fut).await; } /** diff --git a/src/main.luau b/src/main.luau deleted file mode 100644 index 190b303..0000000 --- a/src/main.luau +++ /dev/null @@ -1,19 +0,0 @@ -local start = os.clock() - -local thread = coroutine.running() - -local counter = 0 -for j = 1, 10_000 do - spawn(function() - wait() - counter += 1 - if counter == 10_000 then - print("completed") - spawn(thread) - end - end) -end - -coroutine.yield() - -return os.clock() - start diff --git a/src/main.rs b/src/main.rs deleted file mode 100644 index 88fd917..0000000 --- a/src/main.rs +++ /dev/null @@ -1,49 +0,0 @@ -use std::time::{Duration, Instant}; - -use smol_mlua::{mlua::prelude::*, smol::*, Callbacks, Runtime}; - -const MAIN_SCRIPT: &str = include_str!("./main.luau"); - -pub fn main() -> LuaResult<()> { - let start = Instant::now(); - let lua = Lua::new(); - - // Set up persistent lua environment - lua.globals().set( - "wait", - lua.create_async_function(|_, duration: Option| async move { - let duration = duration.unwrap_or_default().max(1.0 / 250.0); - let before = Instant::now(); - let after = Timer::after(Duration::from_secs_f64(duration)).await; - Ok((after - before).as_secs_f64()) - })?, - )?; - - // Set up runtime (thread queue / async executors) - let rt = Runtime::new(&lua)?; - let main = rt.push_main(&lua, lua.load(MAIN_SCRIPT), ()); - lua.set_named_registry_value("main", main)?; - - // Add callbacks to capture resulting value/error of main thread - Callbacks::new() - .on_value(|lua, thread, val| { - let main = lua.named_registry_value::("main").unwrap(); - if main == thread { - println!("main thread value: {:?}", val); - } - }) - .on_error(|lua, thread, err| { - let main = lua.named_registry_value::("main").unwrap(); - if main == thread { - eprintln!("main thread error: {:?}", err); - } - }) - .inject(&lua); - - // Run until end - rt.run_blocking(&lua); - - println!("elapsed: {:?}", start.elapsed()); - - Ok(()) -}