From da2846670bb1238bedf8875132f633a27a524d04 Mon Sep 17 00:00:00 2001
From: Filip Tibell <filip.tibell@gmail.com>
Date: Sat, 27 Jan 2024 23:35:10 +0100
Subject: [PATCH] Implement runtime handle struct for retrieving values back
 from spawned threads

---
 examples/callbacks.rs                |   8 +-
 examples/lua/scheduler_ordering.luau |  22 ++++--
 examples/scheduler_ordering.rs       |  10 ++-
 lib/handle.rs                        | 111 +++++++++++++++++++++++++++
 lib/lib.rs                           |   4 +
 lib/queue.rs                         |  41 +---------
 lib/runtime.rs                       |  72 +++++++++++++----
 lib/status.rs                        |  31 ++++++++
 lib/util.rs                          |  60 +++++++++++++++
 9 files changed, 295 insertions(+), 64 deletions(-)
 create mode 100644 lib/handle.rs
 create mode 100644 lib/status.rs

diff --git a/examples/callbacks.rs b/examples/callbacks.rs
index a51356d..cac13ed 100644
--- a/examples/callbacks.rs
+++ b/examples/callbacks.rs
@@ -1,4 +1,5 @@
 #![allow(clippy::missing_errors_doc)]
+#![allow(clippy::missing_panics_doc)]
 
 use mlua::prelude::*;
 use mlua_luau_runtime::Runtime;
@@ -23,13 +24,16 @@ pub fn main() -> LuaResult<()> {
         );
     });
 
-    // Load the main script into a runtime
+    // Load the main script into the runtime, and keep track of the thread we spawn
     let main = lua.load(MAIN_SCRIPT);
-    rt.spawn_thread(main, ())?;
+    let handle = rt.spawn_thread(main, ())?;
 
     // Run until completion
     block_on(rt.run());
 
+    // We should have gotten the error back from our script
+    assert!(handle.result(&lua).unwrap().is_err());
+
     Ok(())
 }
 
diff --git a/examples/lua/scheduler_ordering.luau b/examples/lua/scheduler_ordering.luau
index 264f503..b8aed74 100644
--- a/examples/lua/scheduler_ordering.luau
+++ b/examples/lua/scheduler_ordering.luau
@@ -1,26 +1,34 @@
 --!nocheck
 --!nolint UnknownGlobal
 
-print(1)
+local nums = {}
+local function insert(n: number)
+	table.insert(nums, n)
+	print(n)
+end
+
+insert(1)
 
 -- Defer will run at the end of the resumption cycle, but without yielding
 defer(function()
-	print(5)
+	insert(5)
 end)
 
 -- Spawn will instantly run up until the first yield, and must then be resumed manually ...
 spawn(function()
-	print(2)
+	insert(2)
 	coroutine.yield()
-	print("unreachable")
+	error("unreachable code")
 end)
 
 -- ... unless calling functions created using `lua.create_async_function(...)`,
 -- which will resume their calling thread with their result automatically
 spawn(function()
-	print(3)
+	insert(3)
 	sleep(1)
-	print(6)
+	insert(6)
 end)
 
-print(4)
+insert(4)
+
+return nums
diff --git a/examples/scheduler_ordering.rs b/examples/scheduler_ordering.rs
index 2d5800c..e28becb 100644
--- a/examples/scheduler_ordering.rs
+++ b/examples/scheduler_ordering.rs
@@ -1,4 +1,5 @@
 #![allow(clippy::missing_errors_doc)]
+#![allow(clippy::missing_panics_doc)]
 
 use std::time::{Duration, Instant};
 
@@ -28,13 +29,18 @@ pub fn main() -> LuaResult<()> {
         })?,
     )?;
 
-    // Load the main script into a runtime
+    // Load the main script into the runtime, and keep track of the thread we spawn
     let main = lua.load(MAIN_SCRIPT);
-    rt.spawn_thread(main, ())?;
+    let handle = rt.spawn_thread(main, ())?;
 
     // Run until completion
     block_on(rt.run());
 
+    // We should have gotten proper values back from our script
+    let res = handle.result(&lua).unwrap().unwrap();
+    let nums = Vec::<usize>::from_lua_multi(res, &lua)?;
+    assert_eq!(nums, vec![1, 2, 3, 4, 5, 6]);
+
     Ok(())
 }
 
diff --git a/lib/handle.rs b/lib/handle.rs
new file mode 100644
index 0000000..39acc6c
--- /dev/null
+++ b/lib/handle.rs
@@ -0,0 +1,111 @@
+#![allow(unused_imports)]
+#![allow(clippy::missing_panics_doc)]
+#![allow(clippy::module_name_repetitions)]
+
+use std::{cell::RefCell, rc::Rc};
+
+use mlua::prelude::*;
+
+use crate::{
+    runtime::Runtime,
+    status::Status,
+    util::{run_until_yield, ThreadWithArgs},
+    IntoLuaThread,
+};
+
+/**
+    A handle to a thread that has been spawned onto a [`Runtime`].
+
+    This handle contains a single public method, [`Handle::result`], which may
+    be used to extract the result of the thread, once it has finished running.
+*/
+#[derive(Debug, Clone)]
+pub struct Handle {
+    thread: Rc<RefCell<Option<ThreadWithArgs>>>,
+    result: Rc<RefCell<Option<(bool, LuaRegistryKey)>>>,
+}
+
+impl Handle {
+    pub(crate) fn new<'lua>(
+        lua: &'lua Lua,
+        thread: impl IntoLuaThread<'lua>,
+        args: impl IntoLuaMulti<'lua>,
+    ) -> LuaResult<Self> {
+        let thread = thread.into_lua_thread(lua)?;
+        let args = args.into_lua_multi(lua)?;
+
+        let packed = ThreadWithArgs::new(lua, thread, args)?;
+
+        Ok(Self {
+            thread: Rc::new(RefCell::new(Some(packed))),
+            result: Rc::new(RefCell::new(None)),
+        })
+    }
+
+    pub(crate) fn create_thread<'lua>(&self, lua: &'lua Lua) -> LuaResult<LuaThread<'lua>> {
+        let env = lua.create_table()?;
+        env.set("handle", self.clone())?;
+        lua.load("return handle:resume()")
+            .set_name("__runtime_handle")
+            .set_environment(env)
+            .into_lua_thread(lua)
+    }
+
+    fn take<'lua>(&self, lua: &'lua Lua) -> (LuaThread<'lua>, LuaMultiValue<'lua>) {
+        self.thread
+            .borrow_mut()
+            .take()
+            .expect("thread handle may only be taken once")
+            .into_inner(lua)
+    }
+
+    fn set<'lua>(&self, lua: &'lua Lua, result: &LuaResult<LuaMultiValue<'lua>>) -> LuaResult<()> {
+        self.result.borrow_mut().replace((
+            result.is_ok(),
+            match &result {
+                Ok(v) => lua.create_registry_value(v.clone().into_vec())?,
+                Err(e) => lua.create_registry_value(e.clone())?,
+            },
+        ));
+        Ok(())
+    }
+
+    /**
+        Extracts the result for this thread handle.
+
+        Depending on the current [`Runtime::status`], this method will return:
+
+        - [`Status::NotStarted`]: returns `None`.
+        - [`Status::Running`]: may return `Some(Ok(v))` or `Some(Err(e))`, but it is not guaranteed.
+        - [`Status::Completed`]: returns `Some(Ok(v))` or `Some(Err(e))`.
+    */
+    #[must_use]
+    pub fn result<'lua>(&self, lua: &'lua Lua) -> Option<LuaResult<LuaMultiValue<'lua>>> {
+        let res = self.result.borrow();
+        let (is_ok, key) = res.as_ref()?;
+        Some(if *is_ok {
+            let v = lua.registry_value(key).unwrap();
+            Ok(LuaMultiValue::from_vec(v))
+        } else {
+            Err(lua.registry_value(key).unwrap())
+        })
+    }
+}
+
+impl LuaUserData for Handle {
+    fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) {
+        methods.add_async_method("resume", |lua, this, (): ()| async move {
+            /*
+                1. Take the thread and args out of the handle
+                2. Run the thread until it yields or completes
+                3. Store the result of the thread in the lua registry
+                4. Return the result of the thread back to lua as well, so that
+                   it may be caught using the runtime and any error callback(s)
+            */
+            let (thread, args) = this.take(lua);
+            let result = run_until_yield(thread, args).await;
+            this.set(lua, &result)?;
+            result
+        });
+    }
+}
diff --git a/lib/lib.rs b/lib/lib.rs
index 627eebd..de4cc5e 100644
--- a/lib/lib.rs
+++ b/lib/lib.rs
@@ -1,8 +1,12 @@
 mod error_callback;
+mod handle;
 mod queue;
 mod runtime;
+mod status;
 mod traits;
 mod util;
 
+pub use handle::Handle;
 pub use runtime::Runtime;
+pub use status::Status;
 pub use traits::{IntoLuaThread, LuaSpawnExt};
diff --git a/lib/queue.rs b/lib/queue.rs
index 2dd2303..e0a6eeb 100644
--- a/lib/queue.rs
+++ b/lib/queue.rs
@@ -4,7 +4,7 @@ use concurrent_queue::ConcurrentQueue;
 use event_listener::Event;
 use mlua::prelude::*;
 
-use crate::IntoLuaThread;
+use crate::{util::ThreadWithArgs, IntoLuaThread};
 
 /**
     Queue for storing [`LuaThread`]s with associated arguments.
@@ -59,42 +59,3 @@ impl ThreadQueue {
         }
     }
 }
-
-/**
-    Representation of a [`LuaThread`] with its associated arguments currently stored in the Lua registry.
-*/
-#[derive(Debug)]
-struct ThreadWithArgs {
-    key_thread: LuaRegistryKey,
-    key_args: LuaRegistryKey,
-}
-
-impl ThreadWithArgs {
-    fn new<'lua>(
-        lua: &'lua Lua,
-        thread: LuaThread<'lua>,
-        args: LuaMultiValue<'lua>,
-    ) -> LuaResult<Self> {
-        let argsv = args.into_vec();
-
-        let key_thread = lua.create_registry_value(thread)?;
-        let key_args = lua.create_registry_value(argsv)?;
-
-        Ok(Self {
-            key_thread,
-            key_args,
-        })
-    }
-
-    fn into_inner(self, lua: &Lua) -> (LuaThread<'_>, LuaMultiValue<'_>) {
-        let thread = lua.registry_value(&self.key_thread).unwrap();
-        let argsv = lua.registry_value(&self.key_args).unwrap();
-
-        let args = LuaMultiValue::from_vec(argsv);
-
-        lua.remove_registry_value(self.key_thread).unwrap();
-        lua.remove_registry_value(self.key_args).unwrap();
-
-        (thread, args)
-    }
-}
diff --git a/lib/runtime.rs b/lib/runtime.rs
index 8dcb8bd..59d92ef 100644
--- a/lib/runtime.rs
+++ b/lib/runtime.rs
@@ -1,4 +1,10 @@
-use std::sync::{Arc, Weak};
+#![allow(clippy::module_name_repetitions)]
+
+use std::{
+    cell::Cell,
+    rc::Rc,
+    sync::{Arc, Weak},
+};
 
 use futures_lite::prelude::*;
 use mlua::prelude::*;
@@ -6,16 +12,22 @@ use mlua::prelude::*;
 use async_executor::{Executor, LocalExecutor};
 use tracing::Instrument;
 
+use crate::{status::Status, util::run_until_yield, Handle};
+
 use super::{
     error_callback::ThreadErrorCallback, queue::ThreadQueue, traits::IntoLuaThread,
     util::LuaThreadOrFunction,
 };
 
+/**
+    A runtime for running Lua threads and async tasks.
+*/
 pub struct Runtime<'lua> {
     lua: &'lua Lua,
     queue_spawn: ThreadQueue,
     queue_defer: ThreadQueue,
     error_callback: ThreadErrorCallback,
+    status: Rc<Cell<Status>>,
 }
 
 impl<'lua> Runtime<'lua> {
@@ -29,15 +41,24 @@ impl<'lua> Runtime<'lua> {
         let queue_spawn = ThreadQueue::new();
         let queue_defer = ThreadQueue::new();
         let error_callback = ThreadErrorCallback::default();
-
+        let status = Rc::new(Cell::new(Status::NotStarted));
         Runtime {
             lua,
             queue_spawn,
             queue_defer,
             error_callback,
+            status,
         }
     }
 
+    /**
+        Returns the current status of this runtime.
+    */
+    #[must_use]
+    pub fn status(&self) -> Status {
+        self.status.get()
+    }
+
     /**
         Sets the error callback for this runtime.
 
@@ -63,6 +84,12 @@ impl<'lua> Runtime<'lua> {
 
         Threads are guaranteed to be resumed in the order that they were pushed to the queue.
 
+        # Returns
+
+        Returns a [`Handle`] that can be used to retrieve the result of the thread.
+
+        Note that the result may not be available until [`Runtime::run`] completes.
+
         # Errors
 
         Errors when out of memory.
@@ -71,9 +98,15 @@ impl<'lua> Runtime<'lua> {
         &self,
         thread: impl IntoLuaThread<'lua>,
         args: impl IntoLuaMulti<'lua>,
-    ) -> LuaResult<()> {
+    ) -> LuaResult<Handle> {
         tracing::debug!(deferred = false, "new runtime thread");
-        self.queue_spawn.push_item(self.lua, thread, args)
+
+        let handle = Handle::new(self.lua, thread, args)?;
+        let handle_thread = handle.create_thread(self.lua)?;
+
+        self.queue_spawn.push_item(self.lua, handle_thread, ())?;
+
+        Ok(handle)
     }
 
     /**
@@ -83,6 +116,12 @@ impl<'lua> Runtime<'lua> {
 
         Threads are guaranteed to be resumed in the order that they were pushed to the queue.
 
+        # Returns
+
+        Returns a [`Handle`] that can be used to retrieve the result of the thread.
+
+        Note that the result may not be available until [`Runtime::run`] completes.
+
         # Errors
 
         Errors when out of memory.
@@ -91,9 +130,15 @@ impl<'lua> Runtime<'lua> {
         &self,
         thread: impl IntoLuaThread<'lua>,
         args: impl IntoLuaMulti<'lua>,
-    ) -> LuaResult<()> {
+    ) -> LuaResult<Handle> {
         tracing::debug!(deferred = true, "new runtime thread");
-        self.queue_defer.push_item(self.lua, thread, args)
+
+        let handle = Handle::new(self.lua, thread, args)?;
+        let handle_thread = handle.create_thread(self.lua)?;
+
+        self.queue_defer.push_item(self.lua, handle_thread, ())?;
+
+        Ok(handle)
     }
 
     /**
@@ -214,15 +259,10 @@ impl<'lua> Runtime<'lua> {
                 // NOTE: Thread may have been cancelled from Lua
                 // before we got here, so we need to check it again
                 if thread.status() == LuaThreadStatus::Resumable {
-                    let mut stream = thread.clone().into_async::<_, LuaValue>(args);
                     lua_exec
                         .spawn(async move {
-                            // Only run stream until first coroutine.yield or completion. We will
-                            // drop it right away to clear stack space since detached tasks dont drop
-                            // until the executor drops (https://github.com/smol-rs/smol/issues/294)
-                            let res = stream.next().await.unwrap();
-                            if let Err(e) = &res {
-                                self.error_callback.call(e);
+                            if let Err(e) = run_until_yield(thread, args).await {
+                                self.error_callback.call(&e);
                             }
                         })
                         .detach();
@@ -280,9 +320,15 @@ impl<'lua> Runtime<'lua> {
         };
 
         // Run the executor inside a span until all lua threads complete
+        self.status.set(Status::Running);
+        tracing::debug!("starting runtime");
+
         let span = tracing::debug_span!("run_executor");
         main_exec.run(fut).instrument(span.or_current()).await;
 
+        tracing::debug!("runtime completed");
+        self.status.set(Status::Completed);
+
         // Clean up
         self.lua.remove_app_data::<Weak<Executor>>();
     }
diff --git a/lib/status.rs b/lib/status.rs
new file mode 100644
index 0000000..31d707e
--- /dev/null
+++ b/lib/status.rs
@@ -0,0 +1,31 @@
+#![allow(clippy::module_name_repetitions)]
+
+/**
+    The current status of a runtime.
+*/
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
+pub enum Status {
+    /// The runtime has not yet started running.
+    NotStarted,
+    /// The runtime is currently running.
+    Running,
+    /// The runtime has completed.
+    Completed,
+}
+
+impl Status {
+    #[must_use]
+    pub const fn is_not_started(self) -> bool {
+        matches!(self, Self::NotStarted)
+    }
+
+    #[must_use]
+    pub const fn is_running(self) -> bool {
+        matches!(self, Self::Running)
+    }
+
+    #[must_use]
+    pub const fn is_completed(self) -> bool {
+        matches!(self, Self::Completed)
+    }
+}
diff --git a/lib/util.rs b/lib/util.rs
index 089c223..5001901 100644
--- a/lib/util.rs
+++ b/lib/util.rs
@@ -1,5 +1,65 @@
+use futures_lite::StreamExt;
 use mlua::prelude::*;
 
+/**
+    Runs a Lua thread until it manually yields (using coroutine.yield), errors, or completes.
+
+    Returns the values yielded by the thread, or the error that caused it to stop.
+*/
+pub(crate) async fn run_until_yield<'lua>(
+    thread: LuaThread<'lua>,
+    args: LuaMultiValue<'lua>,
+) -> LuaResult<LuaMultiValue<'lua>> {
+    let mut stream = thread.into_async(args);
+    /*
+        NOTE: It is very important that we drop the thread/stream as
+        soon as we are done, it takes up valuable Lua registry space
+        and detached tasks will not drop until the executor does
+
+        https://github.com/smol-rs/smol/issues/294
+    */
+    stream.next().await.unwrap()
+}
+
+/**
+    Representation of a [`LuaThread`] with its associated arguments currently stored in the Lua registry.
+*/
+#[derive(Debug)]
+pub(crate) struct ThreadWithArgs {
+    key_thread: LuaRegistryKey,
+    key_args: LuaRegistryKey,
+}
+
+impl ThreadWithArgs {
+    pub fn new<'lua>(
+        lua: &'lua Lua,
+        thread: LuaThread<'lua>,
+        args: LuaMultiValue<'lua>,
+    ) -> LuaResult<Self> {
+        let argsv = args.into_vec();
+
+        let key_thread = lua.create_registry_value(thread)?;
+        let key_args = lua.create_registry_value(argsv)?;
+
+        Ok(Self {
+            key_thread,
+            key_args,
+        })
+    }
+
+    pub fn into_inner(self, lua: &Lua) -> (LuaThread<'_>, LuaMultiValue<'_>) {
+        let thread = lua.registry_value(&self.key_thread).unwrap();
+        let argsv = lua.registry_value(&self.key_args).unwrap();
+
+        let args = LuaMultiValue::from_vec(argsv);
+
+        lua.remove_registry_value(self.key_thread).unwrap();
+        lua.remove_registry_value(self.key_args).unwrap();
+
+        (thread, args)
+    }
+}
+
 /**
     Wrapper struct to accept either a Lua thread or a Lua function as function argument.