diff --git a/examples/callbacks.rs b/examples/callbacks.rs index a51356d..cac13ed 100644 --- a/examples/callbacks.rs +++ b/examples/callbacks.rs @@ -1,4 +1,5 @@ #![allow(clippy::missing_errors_doc)] +#![allow(clippy::missing_panics_doc)] use mlua::prelude::*; use mlua_luau_runtime::Runtime; @@ -23,13 +24,16 @@ pub fn main() -> LuaResult<()> { ); }); - // Load the main script into a runtime + // Load the main script into the runtime, and keep track of the thread we spawn let main = lua.load(MAIN_SCRIPT); - rt.spawn_thread(main, ())?; + let handle = rt.spawn_thread(main, ())?; // Run until completion block_on(rt.run()); + // We should have gotten the error back from our script + assert!(handle.result(&lua).unwrap().is_err()); + Ok(()) } diff --git a/examples/lua/scheduler_ordering.luau b/examples/lua/scheduler_ordering.luau index 264f503..b8aed74 100644 --- a/examples/lua/scheduler_ordering.luau +++ b/examples/lua/scheduler_ordering.luau @@ -1,26 +1,34 @@ --!nocheck --!nolint UnknownGlobal -print(1) +local nums = {} +local function insert(n: number) + table.insert(nums, n) + print(n) +end + +insert(1) -- Defer will run at the end of the resumption cycle, but without yielding defer(function() - print(5) + insert(5) end) -- Spawn will instantly run up until the first yield, and must then be resumed manually ... spawn(function() - print(2) + insert(2) coroutine.yield() - print("unreachable") + error("unreachable code") end) -- ... unless calling functions created using `lua.create_async_function(...)`, -- which will resume their calling thread with their result automatically spawn(function() - print(3) + insert(3) sleep(1) - print(6) + insert(6) end) -print(4) +insert(4) + +return nums diff --git a/examples/scheduler_ordering.rs b/examples/scheduler_ordering.rs index 2d5800c..e28becb 100644 --- a/examples/scheduler_ordering.rs +++ b/examples/scheduler_ordering.rs @@ -1,4 +1,5 @@ #![allow(clippy::missing_errors_doc)] +#![allow(clippy::missing_panics_doc)] use std::time::{Duration, Instant}; @@ -28,13 +29,18 @@ pub fn main() -> LuaResult<()> { })?, )?; - // Load the main script into a runtime + // Load the main script into the runtime, and keep track of the thread we spawn let main = lua.load(MAIN_SCRIPT); - rt.spawn_thread(main, ())?; + let handle = rt.spawn_thread(main, ())?; // Run until completion block_on(rt.run()); + // We should have gotten proper values back from our script + let res = handle.result(&lua).unwrap().unwrap(); + let nums = Vec::::from_lua_multi(res, &lua)?; + assert_eq!(nums, vec![1, 2, 3, 4, 5, 6]); + Ok(()) } diff --git a/lib/handle.rs b/lib/handle.rs new file mode 100644 index 0000000..39acc6c --- /dev/null +++ b/lib/handle.rs @@ -0,0 +1,111 @@ +#![allow(unused_imports)] +#![allow(clippy::missing_panics_doc)] +#![allow(clippy::module_name_repetitions)] + +use std::{cell::RefCell, rc::Rc}; + +use mlua::prelude::*; + +use crate::{ + runtime::Runtime, + status::Status, + util::{run_until_yield, ThreadWithArgs}, + IntoLuaThread, +}; + +/** + A handle to a thread that has been spawned onto a [`Runtime`]. + + This handle contains a single public method, [`Handle::result`], which may + be used to extract the result of the thread, once it has finished running. +*/ +#[derive(Debug, Clone)] +pub struct Handle { + thread: Rc>>, + result: Rc>>, +} + +impl Handle { + pub(crate) fn new<'lua>( + lua: &'lua Lua, + thread: impl IntoLuaThread<'lua>, + args: impl IntoLuaMulti<'lua>, + ) -> LuaResult { + let thread = thread.into_lua_thread(lua)?; + let args = args.into_lua_multi(lua)?; + + let packed = ThreadWithArgs::new(lua, thread, args)?; + + Ok(Self { + thread: Rc::new(RefCell::new(Some(packed))), + result: Rc::new(RefCell::new(None)), + }) + } + + pub(crate) fn create_thread<'lua>(&self, lua: &'lua Lua) -> LuaResult> { + let env = lua.create_table()?; + env.set("handle", self.clone())?; + lua.load("return handle:resume()") + .set_name("__runtime_handle") + .set_environment(env) + .into_lua_thread(lua) + } + + fn take<'lua>(&self, lua: &'lua Lua) -> (LuaThread<'lua>, LuaMultiValue<'lua>) { + self.thread + .borrow_mut() + .take() + .expect("thread handle may only be taken once") + .into_inner(lua) + } + + fn set<'lua>(&self, lua: &'lua Lua, result: &LuaResult>) -> LuaResult<()> { + self.result.borrow_mut().replace(( + result.is_ok(), + match &result { + Ok(v) => lua.create_registry_value(v.clone().into_vec())?, + Err(e) => lua.create_registry_value(e.clone())?, + }, + )); + Ok(()) + } + + /** + Extracts the result for this thread handle. + + Depending on the current [`Runtime::status`], this method will return: + + - [`Status::NotStarted`]: returns `None`. + - [`Status::Running`]: may return `Some(Ok(v))` or `Some(Err(e))`, but it is not guaranteed. + - [`Status::Completed`]: returns `Some(Ok(v))` or `Some(Err(e))`. + */ + #[must_use] + pub fn result<'lua>(&self, lua: &'lua Lua) -> Option>> { + let res = self.result.borrow(); + let (is_ok, key) = res.as_ref()?; + Some(if *is_ok { + let v = lua.registry_value(key).unwrap(); + Ok(LuaMultiValue::from_vec(v)) + } else { + Err(lua.registry_value(key).unwrap()) + }) + } +} + +impl LuaUserData for Handle { + fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) { + methods.add_async_method("resume", |lua, this, (): ()| async move { + /* + 1. Take the thread and args out of the handle + 2. Run the thread until it yields or completes + 3. Store the result of the thread in the lua registry + 4. Return the result of the thread back to lua as well, so that + it may be caught using the runtime and any error callback(s) + */ + let (thread, args) = this.take(lua); + let result = run_until_yield(thread, args).await; + this.set(lua, &result)?; + result + }); + } +} diff --git a/lib/lib.rs b/lib/lib.rs index 627eebd..de4cc5e 100644 --- a/lib/lib.rs +++ b/lib/lib.rs @@ -1,8 +1,12 @@ mod error_callback; +mod handle; mod queue; mod runtime; +mod status; mod traits; mod util; +pub use handle::Handle; pub use runtime::Runtime; +pub use status::Status; pub use traits::{IntoLuaThread, LuaSpawnExt}; diff --git a/lib/queue.rs b/lib/queue.rs index 2dd2303..e0a6eeb 100644 --- a/lib/queue.rs +++ b/lib/queue.rs @@ -4,7 +4,7 @@ use concurrent_queue::ConcurrentQueue; use event_listener::Event; use mlua::prelude::*; -use crate::IntoLuaThread; +use crate::{util::ThreadWithArgs, IntoLuaThread}; /** Queue for storing [`LuaThread`]s with associated arguments. @@ -59,42 +59,3 @@ impl ThreadQueue { } } } - -/** - Representation of a [`LuaThread`] with its associated arguments currently stored in the Lua registry. -*/ -#[derive(Debug)] -struct ThreadWithArgs { - key_thread: LuaRegistryKey, - key_args: LuaRegistryKey, -} - -impl ThreadWithArgs { - fn new<'lua>( - lua: &'lua Lua, - thread: LuaThread<'lua>, - args: LuaMultiValue<'lua>, - ) -> LuaResult { - let argsv = args.into_vec(); - - let key_thread = lua.create_registry_value(thread)?; - let key_args = lua.create_registry_value(argsv)?; - - Ok(Self { - key_thread, - key_args, - }) - } - - fn into_inner(self, lua: &Lua) -> (LuaThread<'_>, LuaMultiValue<'_>) { - let thread = lua.registry_value(&self.key_thread).unwrap(); - let argsv = lua.registry_value(&self.key_args).unwrap(); - - let args = LuaMultiValue::from_vec(argsv); - - lua.remove_registry_value(self.key_thread).unwrap(); - lua.remove_registry_value(self.key_args).unwrap(); - - (thread, args) - } -} diff --git a/lib/runtime.rs b/lib/runtime.rs index 8dcb8bd..59d92ef 100644 --- a/lib/runtime.rs +++ b/lib/runtime.rs @@ -1,4 +1,10 @@ -use std::sync::{Arc, Weak}; +#![allow(clippy::module_name_repetitions)] + +use std::{ + cell::Cell, + rc::Rc, + sync::{Arc, Weak}, +}; use futures_lite::prelude::*; use mlua::prelude::*; @@ -6,16 +12,22 @@ use mlua::prelude::*; use async_executor::{Executor, LocalExecutor}; use tracing::Instrument; +use crate::{status::Status, util::run_until_yield, Handle}; + use super::{ error_callback::ThreadErrorCallback, queue::ThreadQueue, traits::IntoLuaThread, util::LuaThreadOrFunction, }; +/** + A runtime for running Lua threads and async tasks. +*/ pub struct Runtime<'lua> { lua: &'lua Lua, queue_spawn: ThreadQueue, queue_defer: ThreadQueue, error_callback: ThreadErrorCallback, + status: Rc>, } impl<'lua> Runtime<'lua> { @@ -29,15 +41,24 @@ impl<'lua> Runtime<'lua> { let queue_spawn = ThreadQueue::new(); let queue_defer = ThreadQueue::new(); let error_callback = ThreadErrorCallback::default(); - + let status = Rc::new(Cell::new(Status::NotStarted)); Runtime { lua, queue_spawn, queue_defer, error_callback, + status, } } + /** + Returns the current status of this runtime. + */ + #[must_use] + pub fn status(&self) -> Status { + self.status.get() + } + /** Sets the error callback for this runtime. @@ -63,6 +84,12 @@ impl<'lua> Runtime<'lua> { Threads are guaranteed to be resumed in the order that they were pushed to the queue. + # Returns + + Returns a [`Handle`] that can be used to retrieve the result of the thread. + + Note that the result may not be available until [`Runtime::run`] completes. + # Errors Errors when out of memory. @@ -71,9 +98,15 @@ impl<'lua> Runtime<'lua> { &self, thread: impl IntoLuaThread<'lua>, args: impl IntoLuaMulti<'lua>, - ) -> LuaResult<()> { + ) -> LuaResult { tracing::debug!(deferred = false, "new runtime thread"); - self.queue_spawn.push_item(self.lua, thread, args) + + let handle = Handle::new(self.lua, thread, args)?; + let handle_thread = handle.create_thread(self.lua)?; + + self.queue_spawn.push_item(self.lua, handle_thread, ())?; + + Ok(handle) } /** @@ -83,6 +116,12 @@ impl<'lua> Runtime<'lua> { Threads are guaranteed to be resumed in the order that they were pushed to the queue. + # Returns + + Returns a [`Handle`] that can be used to retrieve the result of the thread. + + Note that the result may not be available until [`Runtime::run`] completes. + # Errors Errors when out of memory. @@ -91,9 +130,15 @@ impl<'lua> Runtime<'lua> { &self, thread: impl IntoLuaThread<'lua>, args: impl IntoLuaMulti<'lua>, - ) -> LuaResult<()> { + ) -> LuaResult { tracing::debug!(deferred = true, "new runtime thread"); - self.queue_defer.push_item(self.lua, thread, args) + + let handle = Handle::new(self.lua, thread, args)?; + let handle_thread = handle.create_thread(self.lua)?; + + self.queue_defer.push_item(self.lua, handle_thread, ())?; + + Ok(handle) } /** @@ -214,15 +259,10 @@ impl<'lua> Runtime<'lua> { // NOTE: Thread may have been cancelled from Lua // before we got here, so we need to check it again if thread.status() == LuaThreadStatus::Resumable { - let mut stream = thread.clone().into_async::<_, LuaValue>(args); lua_exec .spawn(async move { - // Only run stream until first coroutine.yield or completion. We will - // drop it right away to clear stack space since detached tasks dont drop - // until the executor drops (https://github.com/smol-rs/smol/issues/294) - let res = stream.next().await.unwrap(); - if let Err(e) = &res { - self.error_callback.call(e); + if let Err(e) = run_until_yield(thread, args).await { + self.error_callback.call(&e); } }) .detach(); @@ -280,9 +320,15 @@ impl<'lua> Runtime<'lua> { }; // Run the executor inside a span until all lua threads complete + self.status.set(Status::Running); + tracing::debug!("starting runtime"); + let span = tracing::debug_span!("run_executor"); main_exec.run(fut).instrument(span.or_current()).await; + tracing::debug!("runtime completed"); + self.status.set(Status::Completed); + // Clean up self.lua.remove_app_data::>(); } diff --git a/lib/status.rs b/lib/status.rs new file mode 100644 index 0000000..31d707e --- /dev/null +++ b/lib/status.rs @@ -0,0 +1,31 @@ +#![allow(clippy::module_name_repetitions)] + +/** + The current status of a runtime. +*/ +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum Status { + /// The runtime has not yet started running. + NotStarted, + /// The runtime is currently running. + Running, + /// The runtime has completed. + Completed, +} + +impl Status { + #[must_use] + pub const fn is_not_started(self) -> bool { + matches!(self, Self::NotStarted) + } + + #[must_use] + pub const fn is_running(self) -> bool { + matches!(self, Self::Running) + } + + #[must_use] + pub const fn is_completed(self) -> bool { + matches!(self, Self::Completed) + } +} diff --git a/lib/util.rs b/lib/util.rs index 089c223..5001901 100644 --- a/lib/util.rs +++ b/lib/util.rs @@ -1,5 +1,65 @@ +use futures_lite::StreamExt; use mlua::prelude::*; +/** + Runs a Lua thread until it manually yields (using coroutine.yield), errors, or completes. + + Returns the values yielded by the thread, or the error that caused it to stop. +*/ +pub(crate) async fn run_until_yield<'lua>( + thread: LuaThread<'lua>, + args: LuaMultiValue<'lua>, +) -> LuaResult> { + let mut stream = thread.into_async(args); + /* + NOTE: It is very important that we drop the thread/stream as + soon as we are done, it takes up valuable Lua registry space + and detached tasks will not drop until the executor does + + https://github.com/smol-rs/smol/issues/294 + */ + stream.next().await.unwrap() +} + +/** + Representation of a [`LuaThread`] with its associated arguments currently stored in the Lua registry. +*/ +#[derive(Debug)] +pub(crate) struct ThreadWithArgs { + key_thread: LuaRegistryKey, + key_args: LuaRegistryKey, +} + +impl ThreadWithArgs { + pub fn new<'lua>( + lua: &'lua Lua, + thread: LuaThread<'lua>, + args: LuaMultiValue<'lua>, + ) -> LuaResult { + let argsv = args.into_vec(); + + let key_thread = lua.create_registry_value(thread)?; + let key_args = lua.create_registry_value(argsv)?; + + Ok(Self { + key_thread, + key_args, + }) + } + + pub fn into_inner(self, lua: &Lua) -> (LuaThread<'_>, LuaMultiValue<'_>) { + let thread = lua.registry_value(&self.key_thread).unwrap(); + let argsv = lua.registry_value(&self.key_args).unwrap(); + + let args = LuaMultiValue::from_vec(argsv); + + lua.remove_registry_value(self.key_thread).unwrap(); + lua.remove_registry_value(self.key_args).unwrap(); + + (thread, args) + } +} + /** Wrapper struct to accept either a Lua thread or a Lua function as function argument.