diff --git a/Cargo.toml b/Cargo.toml index ee28b7b..660115e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,10 +22,6 @@ test = true name = "callbacks" test = true -[[example]] -name = "captures" -test = true - [[example]] name = "lots_of_threads" test = true diff --git a/README.md b/README.md index b978c12..c0d4fd0 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,10 @@
-Integration between [smol](https://crates.io/crates/smol) and [mlua](https://crates.io/crates/mlua) that provides a fully functional and asynchronous Luau runtime using smol executor(s). +Integration between [smol] and [mlua] that provides a fully functional and asynchronous Luau runtime using smol executor(s). + +[smol]: https://crates.io/crates/smol +[mlua]: https://crates.io/crates/mlua ## Example Usage @@ -25,11 +28,9 @@ Integration between [smol](https://crates.io/crates/smol) and [mlua](https://cra ```rs use std::time::{Duration, Instant}; -use smol_mlua::{ - mlua::prelude::*, - smol::{Timer, io, fs::read_to_string}, - Runtime, -}; +use mlua::prelude::*; +use smol::{Timer, io, fs::read_to_string} +use smol_mlua::Runtime; ``` ### 2. Set up lua environment @@ -68,15 +69,15 @@ lua.globals().set( ```rs let rt = Runtime::new(&lua)?; -// We can create multiple lua threads +// We can create multiple lua threads ... let sleepThread = lua.load("sleep(0.1)"); let fileThread = lua.load("readFile(\"Cargo.toml\")"); -// Put them all into the runtime -rt.push_thread(sleepThread, ()); -rt.push_thread(fileThread, ()); +// ... spawn them both onto the runtime ... +rt.spawn_thread(sleepThread, ()); +rt.spawn_thread(fileThread, ()); -// And run either async or blocking, until above threads finish +// ... and run either async or blocking, until they finish rt.run_async().await; rt.run_blocking(); ``` diff --git a/examples/basic_sleep.rs b/examples/basic_sleep.rs index d5a427a..52f376d 100644 --- a/examples/basic_sleep.rs +++ b/examples/basic_sleep.rs @@ -1,10 +1,8 @@ use std::time::{Duration, Instant}; -use smol_mlua::{ - mlua::prelude::{Lua, LuaResult}, - smol::Timer, - Runtime, -}; +use mlua::prelude::*; +use smol::Timer; +use smol_mlua::Runtime; const MAIN_SCRIPT: &str = include_str!("./lua/basic_sleep.luau"); @@ -23,7 +21,7 @@ pub fn main() -> LuaResult<()> { // Load the main script into a runtime and run it until completion let rt = Runtime::new(&lua)?; let main = lua.load(MAIN_SCRIPT); - rt.push_thread(main, ()); + rt.spawn_thread(main, ())?; rt.run_blocking(); Ok(()) diff --git a/examples/basic_spawn.rs b/examples/basic_spawn.rs index c43f8f6..a8b1ad9 100644 --- a/examples/basic_spawn.rs +++ b/examples/basic_spawn.rs @@ -1,10 +1,6 @@ -use mlua::ExternalResult; -use smol::io; -use smol_mlua::{ - mlua::prelude::{Lua, LuaResult}, - smol::fs::read_to_string, - LuaExecutorExt, Runtime, -}; +use mlua::prelude::*; +use smol::{fs::read_to_string, io}; +use smol_mlua::{LuaSpawnExt, Runtime}; const MAIN_SCRIPT: &str = include_str!("./lua/basic_spawn.luau"); @@ -29,7 +25,7 @@ pub fn main() -> LuaResult<()> { // Load the main script into a runtime and run it until completion let rt = Runtime::new(&lua)?; let main = lua.load(MAIN_SCRIPT); - rt.push_thread(main, ()); + rt.spawn_thread(main, ())?; rt.run_blocking(); Ok(()) diff --git a/examples/callbacks.rs b/examples/callbacks.rs index 447c8ce..0de42a0 100644 --- a/examples/callbacks.rs +++ b/examples/callbacks.rs @@ -1,7 +1,5 @@ -use smol_mlua::{ - mlua::prelude::{Lua, LuaResult}, - Callbacks, Runtime, -}; +use mlua::prelude::*; +use smol_mlua::Runtime; const MAIN_SCRIPT: &str = include_str!("./lua/callbacks.luau"); @@ -11,17 +9,17 @@ pub fn main() -> LuaResult<()> { // Create a new runtime with custom callbacks let rt = Runtime::new(&lua)?; - rt.set_callbacks(Callbacks::default().on_error(|_, _, e| { + rt.set_error_callback(|e| { println!( "Captured error from Lua!\n{}\n{e}\n{}", "-".repeat(15), "-".repeat(15) ); - })); + }); // Load and run the main script until completion let main = lua.load(MAIN_SCRIPT); - rt.push_thread(main, ()); + rt.spawn_thread(main, ())?; rt.run_blocking(); Ok(()) diff --git a/examples/captures.rs b/examples/captures.rs deleted file mode 100644 index 2d2f8c6..0000000 --- a/examples/captures.rs +++ /dev/null @@ -1,92 +0,0 @@ -use std::{ - rc::Rc, - time::{Duration, Instant}, -}; - -use smol_mlua::{ - mlua::prelude::{Lua, LuaResult, LuaThread, LuaValue}, - smol::{lock::Mutex, Timer}, - Callbacks, IntoLuaThread, Runtime, -}; - -const MAIN_SCRIPT: &str = include_str!("./lua/captures.luau"); - -pub fn main() -> LuaResult<()> { - // Set up persistent lua environment - let lua = Lua::new(); - lua.globals().set( - "sleep", - lua.create_async_function(|_, duration: Option| async move { - let duration = duration.unwrap_or_default().max(1.0 / 250.0); - let before = Instant::now(); - let after = Timer::after(Duration::from_secs_f64(duration)).await; - Ok((after - before).as_secs_f64()) - })?, - )?; - - // Load and run the main script a few times for the purposes of this example - for _ in 0..20 { - println!("..."); - match run(&lua, lua.load(MAIN_SCRIPT)) { - Err(e) => eprintln!("Errored:\n{e}"), - Ok(v) => println!("Returned value:\n{v:?}"), - } - } - - Ok(()) -} - -/** - Wrapper function to run the given `main` thread on a new [`Runtime`]. - - Waits for all threads to finish, including the main thread, and - returns the value or error of the main thread once exited. -*/ -fn run<'lua>(lua: &'lua Lua, main: impl IntoLuaThread<'lua>) -> LuaResult { - // Set up runtime (thread queue / async executors) - let rt = Runtime::new(lua)?; - let thread = rt.push_thread(main, ()); - lua.set_named_registry_value("mainThread", thread)?; - - // Create callbacks to capture resulting value/error of main thread, - // we need to do some tricks to get around the lifetime issues with 'lua - // being different inside the callback vs. outside the callback, for LuaValue - let captured_error = Rc::new(Mutex::new(None)); - let captured_error_inner = Rc::clone(&captured_error); - rt.set_callbacks( - Callbacks::new() - .on_value(|lua, thread, val| { - let main: LuaThread = lua.named_registry_value("mainThread").unwrap(); - if main == thread { - lua.set_named_registry_value("mainValue", val).unwrap(); - } - }) - .on_error(move |lua, thread, err| { - let main: LuaThread = lua.named_registry_value("mainThread").unwrap(); - if main == thread { - captured_error_inner.lock_blocking().replace(err); - } - }), - ); - - // Run until end - rt.run_blocking(); - - // Extract value and error from their containers - let err_opt = { captured_error.lock_blocking().take() }; - let val_opt = lua.named_registry_value("mainValue").ok(); - - // Check result - if let Some(err) = err_opt { - Err(err) - } else if let Some(val) = val_opt { - Ok(val) - } else { - unreachable!("No value or error captured from main thread"); - } -} - -#[test] -fn test_captures() -> LuaResult<()> { - main() -} diff --git a/examples/lots_of_threads.rs b/examples/lots_of_threads.rs index 10d7c6c..d4d17bd 100644 --- a/examples/lots_of_threads.rs +++ b/examples/lots_of_threads.rs @@ -1,18 +1,25 @@ use std::time::Duration; -use smol_mlua::{ - mlua::prelude::{Lua, LuaResult}, - smol::Timer, - Runtime, -}; +use mlua::prelude::*; +use smol::Timer; +use smol_mlua::Runtime; const MAIN_SCRIPT: &str = include_str!("./lua/lots_of_threads.luau"); const ONE_NANOSECOND: Duration = Duration::from_nanos(1); pub fn main() -> LuaResult<()> { - // Set up persistent lua environment - let lua = Lua::new(); + // Set up persistent lua environment, note that we enable thread reuse for + // mlua's internal async handling since we will be spawning lots of threads + let lua = Lua::new_with( + LuaStdLib::ALL, + LuaOptions::new() + .catch_rust_panics(false) + .thread_pool_size(10_000), + )?; + let rt = Runtime::new(&lua)?; + + lua.globals().set("spawn", rt.create_spawn_function()?)?; lua.globals().set( "sleep", lua.create_async_function(|_, ()| async move { @@ -23,10 +30,9 @@ pub fn main() -> LuaResult<()> { })?, )?; - // Load the main script into a runtime and run it until completion - let rt = Runtime::new(&lua)?; + // Load the main script into the runtime and run it until completion let main = lua.load(MAIN_SCRIPT); - rt.push_thread(main, ()); + rt.spawn_thread(main, ())?; rt.run_blocking(); Ok(()) diff --git a/examples/lua/captures.luau b/examples/lua/captures.luau deleted file mode 100644 index 74661af..0000000 --- a/examples/lua/captures.luau +++ /dev/null @@ -1,23 +0,0 @@ ---!nocheck ---!nolint UnknownGlobal - -if math.random() < 0.25 then - error("Unlucky error!") -end - -local main = coroutine.running() -local start = os.clock() - -local counter = 0 -for j = 1, 10_000 do - __runtime__spawn(function() - sleep() - counter += 1 - if counter == 10_000 then - local elapsed = os.clock() - start - __runtime__spawn(main, elapsed) - end - end) -end - -return coroutine.yield() diff --git a/examples/lua/lots_of_threads.luau b/examples/lua/lots_of_threads.luau index 3144f15..3958284 100644 --- a/examples/lua/lots_of_threads.luau +++ b/examples/lua/lots_of_threads.luau @@ -13,11 +13,11 @@ for i = 1, NUM_BATCHES do local counter = 0 for j = 1, NUM_THREADS do - __runtime__spawn(function() + spawn(function() sleep() counter += 1 if counter == NUM_THREADS then - __runtime__spawn(thread) + spawn(thread) end end) end diff --git a/examples/lua/scheduler_ordering.luau b/examples/lua/scheduler_ordering.luau index 945cb5c..264f503 100644 --- a/examples/lua/scheduler_ordering.luau +++ b/examples/lua/scheduler_ordering.luau @@ -4,12 +4,12 @@ print(1) -- Defer will run at the end of the resumption cycle, but without yielding -__runtime__defer(function() +defer(function() print(5) end) -- Spawn will instantly run up until the first yield, and must then be resumed manually ... -__runtime__spawn(function() +spawn(function() print(2) coroutine.yield() print("unreachable") @@ -17,7 +17,7 @@ end) -- ... unless calling functions created using `lua.create_async_function(...)`, -- which will resume their calling thread with their result automatically -__runtime__spawn(function() +spawn(function() print(3) sleep(1) print(6) diff --git a/examples/scheduler_ordering.rs b/examples/scheduler_ordering.rs index ce8e7a1..6aa5b61 100644 --- a/examples/scheduler_ordering.rs +++ b/examples/scheduler_ordering.rs @@ -1,16 +1,18 @@ use std::time::{Duration, Instant}; -use smol_mlua::{ - mlua::prelude::{Lua, LuaResult}, - smol::Timer, - Runtime, -}; +use mlua::prelude::*; +use smol::Timer; +use smol_mlua::Runtime; const MAIN_SCRIPT: &str = include_str!("./lua/scheduler_ordering.luau"); pub fn main() -> LuaResult<()> { // Set up persistent lua environment let lua = Lua::new(); + let rt = Runtime::new(&lua)?; + + lua.globals().set("spawn", rt.create_spawn_function()?)?; + lua.globals().set("defer", rt.create_defer_function()?)?; lua.globals().set( "sleep", lua.create_async_function(|_, duration: Option| async move { @@ -22,9 +24,8 @@ pub fn main() -> LuaResult<()> { )?; // Load the main script into a runtime and run it until completion - let rt = Runtime::new(&lua)?; let main = lua.load(MAIN_SCRIPT); - rt.push_thread(main, ()); + rt.spawn_thread(main, ())?; rt.run_blocking(); Ok(()) diff --git a/lib/callbacks.rs b/lib/callbacks.rs deleted file mode 100644 index 7b86de0..0000000 --- a/lib/callbacks.rs +++ /dev/null @@ -1,128 +0,0 @@ -use mlua::prelude::*; - -type ValueCallback = Box Fn(&'lua Lua, LuaThread<'lua>, LuaValue<'lua>) + 'static>; -type ErrorCallback = Box Fn(&'lua Lua, LuaThread<'lua>, LuaError) + 'static>; - -const FORWARD_VALUE_KEY: &str = "__runtime__forwardValue"; -const FORWARD_ERROR_KEY: &str = "__runtime__forwardError"; - -/** - A set of callbacks for thread values and errors. - - These callbacks are used to forward values and errors from - Lua threads back to Rust. By default, the runtime will print - any errors to stderr and not do any operations with values. - - You can set your own callbacks using the `on_value` and `on_error` builder methods. -*/ -pub struct Callbacks { - on_value: Option, - on_error: Option, -} - -impl Callbacks { - /** - Creates a new set of callbacks with no callbacks set. - */ - pub fn new() -> Self { - Self { - on_value: None, - on_error: None, - } - } - - /** - Sets the callback for thread values being yielded / returned. - */ - pub fn on_value(mut self, f: F) -> Self - where - F: Fn(&Lua, LuaThread, LuaValue) + 'static, - { - self.on_value.replace(Box::new(f)); - self - } - - /** - Sets the callback for thread errors. - */ - pub fn on_error(mut self, f: F) -> Self - where - F: Fn(&Lua, LuaThread, LuaError) + 'static, - { - self.on_error.replace(Box::new(f)); - self - } - - /** - Removes any current thread value callback. - */ - pub fn without_value_callback(mut self) -> Self { - self.on_value.take(); - self - } - - /** - Removes any current thread error callback. - */ - pub fn without_error_callback(mut self) -> Self { - self.on_error.take(); - self - } - - pub(crate) fn inject(self, lua: &Lua) { - // Remove any previously injected callbacks - lua.unset_named_registry_value(FORWARD_VALUE_KEY).ok(); - lua.unset_named_registry_value(FORWARD_ERROR_KEY).ok(); - - // Create functions to forward values & errors - if let Some(f) = self.on_value { - lua.set_named_registry_value( - FORWARD_VALUE_KEY, - lua.create_function(move |lua, (thread, val): (LuaThread, LuaValue)| { - f(lua, thread, val); - Ok(()) - }) - .expect("failed to create value callback function"), - ) - .expect("failed to store value callback function"); - } - - if let Some(f) = self.on_error { - lua.set_named_registry_value( - FORWARD_ERROR_KEY, - lua.create_function(move |lua, (thread, err): (LuaThread, LuaError)| { - f(lua, thread, err); - Ok(()) - }) - .expect("failed to create error callback function"), - ) - .expect("failed to store error callback function"); - } - } - - pub(crate) fn forward_value(lua: &Lua, thread: LuaThread, value: LuaValue) { - if let Ok(f) = lua.named_registry_value::(FORWARD_VALUE_KEY) { - f.call::<_, ()>((thread, value)).unwrap(); - } - } - - pub(crate) fn forward_error(lua: &Lua, thread: LuaThread, error: LuaError) { - if let Ok(f) = lua.named_registry_value::(FORWARD_ERROR_KEY) { - f.call::<_, ()>((thread, error)).unwrap(); - } - } -} - -impl Default for Callbacks { - fn default() -> Self { - Callbacks { - on_value: Some(Box::new(default_value_callback)), - on_error: Some(Box::new(default_error_callback)), - } - } -} - -fn default_value_callback(_: &Lua, _: LuaThread, _: LuaValue) {} -fn default_error_callback(_: &Lua, _: LuaThread, e: LuaError) { - eprintln!("{e}"); -} diff --git a/lib/error_callback.rs b/lib/error_callback.rs new file mode 100644 index 0000000..1e9f04b --- /dev/null +++ b/lib/error_callback.rs @@ -0,0 +1,52 @@ +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; + +use mlua::prelude::*; +use smol::lock::Mutex; + +type ErrorCallback = Box; + +#[derive(Clone)] +pub(crate) struct ThreadErrorCallback { + exists: Arc, + inner: Arc>>, +} + +impl ThreadErrorCallback { + pub fn new() -> Self { + Self { + exists: Arc::new(AtomicBool::new(false)), + inner: Arc::new(Mutex::new(None)), + } + } + + pub fn new_default() -> Self { + let this = Self::new(); + this.replace(default_error_callback); + this + } + + pub fn replace(&self, callback: impl Fn(LuaError) + Send + 'static) { + self.exists.store(true, Ordering::Relaxed); + self.inner.lock_blocking().replace(Box::new(callback)); + } + + pub fn clear(&self) { + self.exists.store(false, Ordering::Relaxed); + self.inner.lock_blocking().take(); + } + + pub fn call(&self, error: &LuaError) { + if self.exists.load(Ordering::Relaxed) { + if let Some(cb) = &*self.inner.lock_blocking() { + cb(error.clone()); + } + } + } +} + +fn default_error_callback(e: LuaError) { + eprintln!("{e}"); +} diff --git a/lib/lib.rs b/lib/lib.rs index 5aa95b6..627eebd 100644 --- a/lib/lib.rs +++ b/lib/lib.rs @@ -1,12 +1,8 @@ -mod callbacks; +mod error_callback; +mod queue; mod runtime; -mod storage; mod traits; mod util; -pub use mlua; -pub use smol; - -pub use callbacks::Callbacks; pub use runtime::Runtime; -pub use traits::{IntoLuaThread, LuaExecutorExt}; +pub use traits::{IntoLuaThread, LuaSpawnExt}; diff --git a/lib/queue.rs b/lib/queue.rs new file mode 100644 index 0000000..6ac9f27 --- /dev/null +++ b/lib/queue.rs @@ -0,0 +1,107 @@ +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, +}; + +use mlua::prelude::*; +use smol::{ + channel::{unbounded, Receiver, Sender}, + lock::Mutex, +}; + +use crate::IntoLuaThread; + +const ERR_OOM: &str = "out of memory"; + +/** + Queue for storing [`LuaThread`]s with associated arguments. + + Provides methods for pushing and draining the queue, as + well as listening for new items being pushed to the queue. +*/ +#[derive(Debug, Clone)] +pub struct ThreadQueue { + queue: Arc>>, + status: Arc, + signal_tx: Sender<()>, + signal_rx: Receiver<()>, +} + +impl ThreadQueue { + pub fn new() -> Self { + let (signal_tx, signal_rx) = unbounded(); + Self { + queue: Arc::new(Mutex::new(Vec::new())), + status: Arc::new(AtomicBool::new(false)), + signal_tx, + signal_rx, + } + } + + pub fn has_threads(&self) -> bool { + self.status.load(Ordering::SeqCst) + } + + pub fn push<'lua>( + &self, + lua: &'lua Lua, + thread: impl IntoLuaThread<'lua>, + args: impl IntoLuaMulti<'lua>, + ) -> LuaResult<()> { + let thread = thread.into_lua_thread(lua)?; + let args = args.into_lua_multi(lua)?; + let stored = ThreadWithArgs::new(lua, thread, args); + + self.queue.lock_blocking().push(stored); + self.status.store(true, Ordering::SeqCst); + self.signal_tx.try_send(()).unwrap(); + + Ok(()) + } + + pub async fn drain<'lua>(&self, lua: &'lua Lua) -> Vec<(LuaThread<'lua>, LuaMultiValue<'lua>)> { + let mut queue = self.queue.lock().await; + let drained = queue.drain(..).map(|s| s.into_inner(lua)).collect(); + self.status.store(false, Ordering::SeqCst); + drained + } + + pub async fn recv(&self) { + self.signal_rx.recv().await.unwrap(); + } +} + +/** + Representation of a [`LuaThread`] with associated arguments currently stored in the Lua registry. +*/ +#[derive(Debug)] +struct ThreadWithArgs { + key_thread: LuaRegistryKey, + key_args: LuaRegistryKey, +} + +impl ThreadWithArgs { + pub fn new<'lua>(lua: &'lua Lua, thread: LuaThread<'lua>, args: LuaMultiValue<'lua>) -> Self { + let argsv = args.into_vec(); + + let key_thread = lua.create_registry_value(thread).expect(ERR_OOM); + let key_args = lua.create_registry_value(argsv).expect(ERR_OOM); + + Self { + key_thread, + key_args, + } + } + + pub fn into_inner(self, lua: &Lua) -> (LuaThread<'_>, LuaMultiValue<'_>) { + let thread = lua.registry_value(&self.key_thread).unwrap(); + let argsv = lua.registry_value(&self.key_args).unwrap(); + + let args = LuaMultiValue::from_vec(argsv); + + lua.remove_registry_value(self.key_thread).unwrap(); + lua.remove_registry_value(self.key_args).unwrap(); + + (thread, args) + } +} diff --git a/lib/runtime.rs b/lib/runtime.rs index 661bf7d..7887bb8 100644 --- a/lib/runtime.rs +++ b/lib/runtime.rs @@ -1,160 +1,149 @@ -use std::{cell::Cell, rc::Rc, sync::Arc}; +use std::sync::Arc; use mlua::prelude::*; use smol::prelude::*; -use smol::{ - block_on, - channel::{unbounded, Receiver, Sender}, - lock::Mutex, - Executor, LocalExecutor, -}; +use smol::{block_on, Executor, LocalExecutor}; use super::{ - callbacks::Callbacks, storage::ThreadWithArgs, traits::IntoLuaThread, util::LuaThreadOrFunction, + error_callback::ThreadErrorCallback, + queue::ThreadQueue, + traits::IntoLuaThread, + util::{is_poll_pending, LuaThreadOrFunction}, }; -const GLOBAL_NAME_SPAWN: &str = "__runtime__spawn"; -const GLOBAL_NAME_DEFER: &str = "__runtime__defer"; - pub struct Runtime<'lua> { lua: &'lua Lua, - queue_status: Rc>, - // TODO: Use something better than Rc>> - queue_spawn: Rc>>, - queue_defer: Rc>>, - tx: Sender<()>, - rx: Receiver<()>, + queue_spawn: ThreadQueue, + queue_defer: ThreadQueue, + error_callback: ThreadErrorCallback, } impl<'lua> Runtime<'lua> { /** Creates a new runtime for the given Lua state. - This will inject some functions to interact with the scheduler / executor, - as well as the default [`Callbacks`] for thread values and errors. + This runtime will have a default error callback that prints errors to stderr. */ pub fn new(lua: &'lua Lua) -> LuaResult> { - let queue_status = Rc::new(Cell::new(false)); - let queue_spawn = Rc::new(Mutex::new(Vec::new())); - let queue_defer = Rc::new(Mutex::new(Vec::new())); - let (tx, rx) = unbounded(); - - // HACK: Extract mlua "pending" constant value and store it - let pending = lua - .create_async_function(|_, ()| async move { - smol::future::yield_now().await; - Ok(()) - })? - .into_lua_thread(lua)? - .resume::<_, LuaValue>(())?; - let pending_key = lua.create_registry_value(pending)?; - - // TODO: Generalize these two functions below so we - // dont need to duplicate the same exact thing for - // spawn and defer which is prone to human error - - // Create spawn function (push to start of queue) - let b_spawn = Rc::clone(&queue_status); - let q_spawn = Rc::clone(&queue_spawn); - let tx_spawn = tx.clone(); - let fn_spawn = lua.create_function( - move |lua, (tof, args): (LuaThreadOrFunction, LuaMultiValue)| { - let thread = tof.into_thread(lua)?; - if thread.status() == LuaThreadStatus::Resumable { - // HACK: We need to resume the thread once instantly for correct behavior, - // and only if we get the pending value back we can spawn to async executor - let pending: LuaValue = lua.registry_value(&pending_key)?; - match thread.resume::<_, LuaValue>(args.clone()) { - Ok(v) if v == pending => { - let stored = ThreadWithArgs::new(lua, thread.clone(), args); - q_spawn.lock_blocking().push(stored); - b_spawn.replace(true); - tx_spawn.try_send(()).map_err(|_| { - LuaError::runtime("Tried to spawn thread to a dropped queue") - })?; - } - Ok(v) => Callbacks::forward_value(lua, thread.clone(), v), - Err(e) => Callbacks::forward_error(lua, thread.clone(), e), - } - Ok(thread) - } else { - Err(LuaError::runtime("Tried to spawn non-resumable thread")) - } - }, - )?; - - // Create defer function (push to end of queue) - let b_defer = Rc::clone(&queue_status); - let q_defer = Rc::clone(&queue_defer); - let tx_defer = tx.clone(); - let fn_defer = lua.create_function( - move |lua, (tof, args): (LuaThreadOrFunction, LuaMultiValue)| { - let thread = tof.into_thread(lua)?; - if thread.status() == LuaThreadStatus::Resumable { - let stored = ThreadWithArgs::new(lua, thread.clone(), args); - q_defer.lock_blocking().push(stored); - b_defer.replace(true); - tx_defer.try_send(()).map_err(|_| { - LuaError::runtime("Tried to defer thread to a dropped queue") - })?; - Ok(thread) - } else { - Err(LuaError::runtime("Tried to defer non-resumable thread")) - } - }, - )?; - - // Store them both as globals - lua.globals().set(GLOBAL_NAME_SPAWN, fn_spawn)?; - lua.globals().set(GLOBAL_NAME_DEFER, fn_defer)?; - - // Finally, inject default callbacks - Callbacks::default().inject(lua); + let queue_spawn = ThreadQueue::new(); + let queue_defer = ThreadQueue::new(); + let error_callback = ThreadErrorCallback::new_default(); Ok(Runtime { lua, - queue_status, queue_spawn, queue_defer, - tx, - rx, + error_callback, }) } /** - Sets the callbacks for this runtime. + Sets the error callback for this runtime. - This will overwrite any previously set callbacks, including default ones. + This callback will be called whenever a Lua thread errors. + + Overwrites any previous error callback. */ - pub fn set_callbacks(&self, callbacks: Callbacks) { - callbacks.inject(self.lua); + pub fn set_error_callback(&self, callback: impl Fn(LuaError) + Send + 'static) { + self.error_callback.replace(callback); } /** - Pushes a chunk / function / thread to the runtime queue. + Clears the error callback for this runtime. + + This will remove any current error callback, including default(s). + */ + pub fn remove_error_callback(&self) { + self.error_callback.clear(); + } + + /** + Spawns a chunk / function / thread onto the runtime queue. Threads are guaranteed to be resumed in the order that they were pushed to the queue. */ - pub fn push_thread( + pub fn spawn_thread( &self, thread: impl IntoLuaThread<'lua>, args: impl IntoLuaMulti<'lua>, - ) -> LuaThread<'lua> { - let thread = thread - .into_lua_thread(self.lua) - .expect("failed to create thread"); - let args = args - .into_lua_multi(self.lua) - .expect("failed to create args"); + ) -> LuaResult<()> { + let thread = thread.into_lua_thread(self.lua)?; + let args = args.into_lua_multi(self.lua)?; - let stored = ThreadWithArgs::new(self.lua, thread.clone(), args); + self.queue_spawn.push(self.lua, thread, args)?; - self.queue_spawn.lock_blocking().push(stored); - self.queue_status.replace(true); - self.tx.try_send(()).unwrap(); // Unwrap is safe since this struct also holds the receiver + Ok(()) + } - thread + /** + Defers a chunk / function / thread onto the runtime queue. + + Deferred threads are guaranteed to run after all spawned threads either yield or complete. + + Threads are guaranteed to be resumed in the order that they were pushed to the queue. + */ + pub fn defer_thread( + &self, + thread: impl IntoLuaThread<'lua>, + args: impl IntoLuaMulti<'lua>, + ) -> LuaResult<()> { + let thread = thread.into_lua_thread(self.lua)?; + let args = args.into_lua_multi(self.lua)?; + + self.queue_defer.push(self.lua, thread, args)?; + + Ok(()) + } + + /** + Creates a lua function that can be used to spawn threads / functions onto the runtime queue. + + The function takes a thread or function as the first argument, and any variadic arguments as the rest. + */ + pub fn create_spawn_function(&self) -> LuaResult> { + let error_callback = self.error_callback.clone(); + let spawn_queue = self.queue_spawn.clone(); + self.lua.create_function( + move |lua, (tof, args): (LuaThreadOrFunction, LuaMultiValue)| { + let thread = tof.into_thread(lua)?; + if thread.status() == LuaThreadStatus::Resumable { + // NOTE: We need to resume the thread once instantly for correct behavior, + // and only if we get the pending value back we can spawn to async executor + match thread.resume::<_, LuaValue>(args.clone()) { + Ok(v) => { + if is_poll_pending(&v) { + spawn_queue.push(lua, &thread, args)?; + } + } + Err(e) => { + error_callback.call(&e); + } + }; + } + Ok(thread) + }, + ) + } + + /** + Creates a lua function that can be used to defer threads / functions onto the runtime queue. + + The function takes a thread or function as the first argument, and any variadic arguments as the rest. + + Deferred threads are guaranteed to run after all spawned threads either yield or complete. + */ + pub fn create_defer_function(&self) -> LuaResult> { + let defer_queue = self.queue_defer.clone(); + self.lua.create_function( + move |lua, (tof, args): (LuaThreadOrFunction, LuaMultiValue)| { + let thread = tof.into_thread(lua)?; + if thread.status() == LuaThreadStatus::Resumable { + defer_queue.push(lua, &thread, args)?; + } + Ok(thread) + }, + ) } /** @@ -170,7 +159,7 @@ impl<'lua> Runtime<'lua> { let lua_exec = LocalExecutor::new(); let main_exec = Arc::new(Executor::new()); - // Store the main executor in lua for LuaExecutorExt trait + // Store the main executor in lua for spawner trait self.lua.set_app_data(Arc::downgrade(&main_exec)); // Tick local lua executor while also driving main @@ -179,9 +168,8 @@ impl<'lua> Runtime<'lua> { loop { // Wait for a new thread to arrive __or__ next futures step, prioritizing // new threads, so we don't accidentally exit when there is more work to do - let fut_recv = async { - self.rx.recv().await.ok(); - }; + let fut_spawn = self.queue_spawn.recv(); + let fut_defer = self.queue_defer.recv(); let fut_tick = async { lua_exec.tick().await; // Do as much work as possible @@ -191,18 +179,18 @@ impl<'lua> Runtime<'lua> { } } }; - fut_recv.or(fut_tick).await; - // If a new thread was spawned onto any queue, we - // must drain them and schedule on the executor - if self.queue_status.get() { + fut_spawn.or(fut_defer).or(fut_tick).await; + + // If a new thread was spawned onto any queue, + // we must drain them and schedule on the executor + if self.queue_spawn.has_threads() || self.queue_defer.has_threads() { let mut queued_threads = Vec::new(); - queued_threads.extend(self.queue_spawn.lock().await.drain(..)); - queued_threads.extend(self.queue_defer.lock().await.drain(..)); - for queued_thread in queued_threads { + queued_threads.extend(self.queue_spawn.drain(self.lua).await); + queued_threads.extend(self.queue_defer.drain(self.lua).await); + for (thread, args) in queued_threads { // NOTE: Thread may have been cancelled from lua // before we got here, so we need to check it again - let (thread, args) = queued_thread.into_inner(self.lua); if thread.status() == LuaThreadStatus::Resumable { let mut stream = thread.clone().into_async::<_, LuaValue>(args); lua_exec @@ -210,10 +198,11 @@ impl<'lua> Runtime<'lua> { // Only run stream until first coroutine.yield or completion. We will // drop it right away to clear stack space since detached tasks dont drop // until the executor drops https://github.com/smol-rs/smol/issues/294 - match stream.next().await.unwrap() { - Ok(v) => Callbacks::forward_value(self.lua, thread, v), - Err(e) => Callbacks::forward_error(self.lua, thread, e), - }; + let res = stream.next().await.unwrap(); + if let Err(e) = &res { + self.error_callback.call(e); + } + // TODO: Figure out how to give this result to caller of spawn_thread/defer_thread }) .detach(); } diff --git a/lib/storage.rs b/lib/storage.rs deleted file mode 100644 index f067983..0000000 --- a/lib/storage.rs +++ /dev/null @@ -1,43 +0,0 @@ -use mlua::prelude::*; - -#[derive(Debug)] -pub(crate) struct ThreadWithArgs { - key_thread: LuaRegistryKey, - key_args: LuaRegistryKey, -} - -impl ThreadWithArgs { - pub fn new<'lua>(lua: &'lua Lua, thread: LuaThread<'lua>, args: LuaMultiValue<'lua>) -> Self { - let args_vec = args.into_vec(); - - let key_thread = lua - .create_registry_value(thread) - .expect("Failed to store thread in registry - out of memory"); - let key_args = lua - .create_registry_value(args_vec) - .expect("Failed to store thread args in registry - out of memory"); - - Self { - key_thread, - key_args, - } - } - - pub fn into_inner(self, lua: &Lua) -> (LuaThread<'_>, LuaMultiValue<'_>) { - let thread = lua - .registry_value(&self.key_thread) - .expect("Failed to get thread from registry"); - let args_vec = lua - .registry_value(&self.key_args) - .expect("Failed to get thread args from registry"); - - let args = LuaMultiValue::from_vec(args_vec); - - lua.remove_registry_value(self.key_thread) - .expect("Failed to remove thread from registry"); - lua.remove_registry_value(self.key_args) - .expect("Failed to remove thread args from registry"); - - (thread, args) - } -} diff --git a/lib/traits.rs b/lib/traits.rs index 3e2a1a5..95e7f52 100644 --- a/lib/traits.rs +++ b/lib/traits.rs @@ -36,13 +36,22 @@ impl<'lua> IntoLuaThread<'lua> for LuaChunk<'lua, '_> { } } +impl<'lua, T> IntoLuaThread<'lua> for &T +where + T: IntoLuaThread<'lua> + Clone, +{ + fn into_lua_thread(self, lua: &'lua Lua) -> LuaResult> { + self.clone().into_lua_thread(lua) + } +} + /** Trait for spawning `Send` futures on the current executor. - For spawning non-`Send` futures on the same local executor as a [`Lua`] + For spawning `!Send` futures on the same local executor as a [`Lua`] VM instance, [`Lua::create_async_function`] should be used instead. */ -pub trait LuaExecutorExt<'lua> { +pub trait LuaSpawnExt<'lua> { /** Spawns the given future on the current executor and returns its [`Task`]. @@ -54,7 +63,7 @@ pub trait LuaExecutorExt<'lua> { ```rust use mlua::prelude::*; - use smol_mlua::{Runtime, LuaExecutorExt}; + use smol_mlua::{Runtime, LuaSpawnExt}; fn main() -> LuaResult<()> { let lua = Lua::new(); @@ -70,7 +79,7 @@ pub trait LuaExecutorExt<'lua> { )?; let rt = Runtime::new(&lua)?; - rt.push_thread(lua.load("spawnBackgroundTask()"), ()); + rt.spawn_thread(lua.load("spawnBackgroundTask()"), ()); rt.run_blocking(); Ok(()) @@ -82,7 +91,7 @@ pub trait LuaExecutorExt<'lua> { fn spawn(&self, fut: impl Future + Send + 'static) -> Task; } -impl<'lua> LuaExecutorExt<'lua> for Lua { +impl<'lua> LuaSpawnExt<'lua> for Lua { fn spawn(&self, fut: impl Future + Send + 'static) -> Task { let exec = self .app_data_ref::>() diff --git a/lib/util.rs b/lib/util.rs index 089c223..d0af4e2 100644 --- a/lib/util.rs +++ b/lib/util.rs @@ -1,5 +1,42 @@ +use std::cell::OnceCell; + use mlua::prelude::*; +use crate::IntoLuaThread; + +thread_local! { + static POLL_PENDING: OnceCell = OnceCell::new(); +} + +fn get_poll_pending(lua: &Lua) -> LuaResult { + let yielder_fn = lua.create_async_function(|_, ()| async move { + smol::future::yield_now().await; + Ok(()) + })?; + + yielder_fn + .into_lua_thread(lua)? + .resume::<_, LuaLightUserData>(()) +} + +#[inline] +pub(crate) fn is_poll_pending(value: &LuaValue) -> bool { + // TODO: Replace with Lua::poll_pending() when it's available + + let pp = POLL_PENDING.with(|cell| { + *cell.get_or_init(|| { + let lua = Lua::new().into_static(); + let pending = get_poll_pending(lua).unwrap(); + // SAFETY: We only use the Lua state for the lifetime of this function, + // and the "poll pending" light userdata / pointer is completely static. + drop(unsafe { Lua::from_static(lua) }); + pending + }) + }); + + matches!(value, LuaValue::LightUserData(u) if u == &pp) +} + /** Wrapper struct to accept either a Lua thread or a Lua function as function argument.