diff --git a/Cargo.lock b/Cargo.lock index 1031946..8693530 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -315,14 +315,6 @@ version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" -[[package]] -name = "luau-scheduler-experiments" -version = "0.0.0" -dependencies = [ - "mlua", - "smol", -] - [[package]] name = "luau0-src" version = "0.7.11+luau606" @@ -521,6 +513,14 @@ dependencies = [ "futures-lite", ] +[[package]] +name = "smol-mlua" +version = "0.0.0" +dependencies = [ + "mlua", + "smol", +] + [[package]] name = "syn" version = "2.0.48" diff --git a/Cargo.toml b/Cargo.toml index 5133fdc..e880e69 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,8 +1,11 @@ [package] -name = "luau-scheduler-experiments" +name = "smol-mlua" version = "0.0.0" edition = "2021" [dependencies] smol = "2.0" mlua = { version = "0.9", features = ["luau", "luau-jit", "async"] } + +[lib] +path = "lib/lib.rs" diff --git a/lib/lib.rs b/lib/lib.rs new file mode 100644 index 0000000..a644945 --- /dev/null +++ b/lib/lib.rs @@ -0,0 +1,7 @@ +pub mod thread_callbacks; +pub mod thread_runtime; +pub mod thread_storage; +pub mod thread_util; + +pub use mlua; +pub use smol; diff --git a/lib/thread_callbacks.rs b/lib/thread_callbacks.rs new file mode 100644 index 0000000..55587d3 --- /dev/null +++ b/lib/thread_callbacks.rs @@ -0,0 +1,71 @@ +use mlua::prelude::*; + +type ErrorCallback = Box Fn(&'lua Lua, LuaThread<'lua>, LuaError) + 'static>; +type ValueCallback = Box Fn(&'lua Lua, LuaThread<'lua>, LuaValue<'lua>) + 'static>; + +#[derive(Default)] +pub struct ThreadCallbacks { + on_error: Option, + on_value: Option, +} + +impl ThreadCallbacks { + pub fn new() -> ThreadCallbacks { + Default::default() + } + + pub fn on_error(mut self, f: F) -> Self + where + F: Fn(&Lua, LuaThread, LuaError) + 'static, + { + self.on_error.replace(Box::new(f)); + self + } + + pub fn on_value(mut self, f: F) -> Self + where + F: Fn(&Lua, LuaThread, LuaValue) + 'static, + { + self.on_value.replace(Box::new(f)); + self + } + + pub fn inject(self, lua: &Lua) { + // Create functions to forward errors & values + if let Some(f) = self.on_error { + lua.set_named_registry_value( + "__forward__error", + lua.create_function(move |lua, (thread, err): (LuaThread, LuaError)| { + f(lua, thread, err); + Ok(()) + }) + .expect("failed to create error callback function"), + ) + .expect("failed to store error callback function"); + } + + if let Some(f) = self.on_value { + lua.set_named_registry_value( + "__forward__value", + lua.create_function(move |lua, (thread, val): (LuaThread, LuaValue)| { + f(lua, thread, val); + Ok(()) + }) + .expect("failed to create value callback function"), + ) + .expect("failed to store value callback function"); + } + } + + pub(crate) fn forward_error(lua: &Lua, thread: LuaThread, error: LuaError) { + if let Ok(f) = lua.named_registry_value::("__forward__error") { + f.call::<_, ()>((thread, error)).unwrap(); + } + } + + pub(crate) fn forward_value(lua: &Lua, thread: LuaThread, value: LuaValue) { + if let Ok(f) = lua.named_registry_value::("__forward__value") { + f.call::<_, ()>((thread, value)).unwrap(); + } + } +} diff --git a/src/thread_runtime.rs b/lib/thread_runtime.rs similarity index 89% rename from src/thread_runtime.rs rename to lib/thread_runtime.rs index cee3dbf..8b4de35 100644 --- a/src/thread_runtime.rs +++ b/lib/thread_runtime.rs @@ -10,8 +10,9 @@ use smol::{ }; use super::{ + thread_callbacks::ThreadCallbacks, + thread_storage::ThreadWithArgs, thread_util::{IntoLuaThread, LuaThreadOrFunction}, - ThreadWithArgs, }; pub struct ThreadRuntime { @@ -56,10 +57,6 @@ impl ThreadRuntime { // and only if we get the pending value back we can spawn to async executor let pending: LuaValue = lua.registry_value(&pending_key)?; match thread.resume::<_, LuaValue>(args.clone()) { - Err(e) => { - eprintln!("{:?}", e); - // TODO: Forward error - } Ok(v) if v == pending => { let stored = ThreadWithArgs::new(lua, thread.clone(), args); q_spawn.lock_blocking().push(stored); @@ -68,9 +65,8 @@ impl ThreadRuntime { LuaError::runtime("Tried to spawn thread to a dropped queue") })?; } - Ok(_) => { - // TODO: Forward value - } + Ok(v) => ThreadCallbacks::forward_value(lua, thread.clone(), v), + Err(e) => ThreadCallbacks::forward_error(lua, thread.clone(), e), } Ok(thread) } else { @@ -122,17 +118,19 @@ impl ThreadRuntime { lua: &'lua Lua, thread: impl IntoLuaThread<'lua>, args: impl IntoLuaMulti<'lua>, - ) { + ) -> LuaThread<'lua> { let thread = thread .into_lua_thread(lua) .expect("failed to create thread"); let args = args.into_lua_multi(lua).expect("failed to create args"); - let stored = ThreadWithArgs::new(lua, thread, args); + let stored = ThreadWithArgs::new(lua, thread.clone(), args); self.queue_spawn.lock_blocking().push(stored); self.queue_status.replace(true); self.tx.try_send(()).unwrap(); // Unwrap is safe since this struct also holds the receiver + + thread } /** @@ -177,21 +175,16 @@ impl ThreadRuntime { // before we got here, so we need to check it again let (thread, args) = queued_thread.into_inner(lua); if thread.status() == LuaThreadStatus::Resumable { - let mut stream = thread.into_async::<_, LuaValue>(args); + let mut stream = thread.clone().into_async::<_, LuaValue>(args); lua_exec .spawn(async move { // Only run stream until first coroutine.yield or completion. We will // drop it right away to clear stack space since detached tasks dont drop // until the executor drops https://github.com/smol-rs/smol/issues/294 match stream.next().await.unwrap() { - Err(e) => { - eprintln!("{e}"); - // TODO: Forward error - } - Ok(_) => { - // TODO: Forward value - } - } + Ok(v) => ThreadCallbacks::forward_value(lua, thread, v), + Err(e) => ThreadCallbacks::forward_error(lua, thread, e), + }; }) .detach(); } diff --git a/src/thread_storage.rs b/lib/thread_storage.rs similarity index 100% rename from src/thread_storage.rs rename to lib/thread_storage.rs diff --git a/src/thread_util.rs b/lib/thread_util.rs similarity index 100% rename from src/thread_util.rs rename to lib/thread_util.rs diff --git a/src/main.luau b/src/main.luau index 224e240..190b303 100644 --- a/src/main.luau +++ b/src/main.luau @@ -1,18 +1,19 @@ -for i = 1, 20 do - print("iteration " .. tostring(i) .. " of 20") - local thread = coroutine.running() +local start = os.clock() - local counter = 0 - for j = 1, 50_000 do - spawn(function() - wait() - counter += 1 - if counter == 50_000 then - print("completed iteration " .. tostring(i) .. " of 20") - spawn(thread) - end - end) - end +local thread = coroutine.running() - coroutine.yield() +local counter = 0 +for j = 1, 10_000 do + spawn(function() + wait() + counter += 1 + if counter == 10_000 then + print("completed") + spawn(thread) + end + end) end + +coroutine.yield() + +return os.clock() - start diff --git a/src/main.rs b/src/main.rs index af60383..499b588 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,12 +5,7 @@ use smol::*; const MAIN_SCRIPT: &str = include_str!("./main.luau"); -mod thread_runtime; -mod thread_storage; -mod thread_util; - -use thread_runtime::*; -use thread_storage::*; +use smol_mlua::{thread_callbacks::ThreadCallbacks, thread_runtime::ThreadRuntime}; pub fn main() -> LuaResult<()> { let start = Instant::now(); @@ -27,9 +22,28 @@ pub fn main() -> LuaResult<()> { })?, )?; - // Set up runtime (thread queue / async executors) and run main script until end + // Set up runtime (thread queue / async executors) let rt = ThreadRuntime::new(&lua)?; - rt.push_main(&lua, lua.load(MAIN_SCRIPT), ()); + let main = rt.push_main(&lua, lua.load(MAIN_SCRIPT), ()); + lua.set_named_registry_value("main", main)?; + + // Add callbacks to capture resulting value/error of main thread + ThreadCallbacks::new() + .on_value(|lua, thread, val| { + let main = lua.named_registry_value::("main").unwrap(); + if main == thread { + println!("main thread value: {:?}", val); + } + }) + .on_error(|lua, thread, err| { + let main = lua.named_registry_value::("main").unwrap(); + if main == thread { + eprintln!("main thread error: {:?}", err); + } + }) + .inject(&lua); + + // Run until end rt.run_blocking(&lua); println!("elapsed: {:?}", start.elapsed());