From f4e13bf8b73565bbbed8aef32c95f01a13a91e0b Mon Sep 17 00:00:00 2001
From: Filip Tibell <filip.tibell@gmail.com>
Date: Fri, 19 Jan 2024 12:40:10 +0100
Subject: [PATCH] Implement extension trait to be able to spawn Send futures
 from lua

---
 lib/lib.rs     |  2 +-
 lib/runtime.rs | 14 ++++++++++--
 lib/traits.rs  | 60 ++++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 73 insertions(+), 3 deletions(-)

diff --git a/lib/lib.rs b/lib/lib.rs
index c5a7205..5aa95b6 100644
--- a/lib/lib.rs
+++ b/lib/lib.rs
@@ -9,4 +9,4 @@ pub use smol;
 
 pub use callbacks::Callbacks;
 pub use runtime::Runtime;
-pub use traits::IntoLuaThread;
+pub use traits::{IntoLuaThread, LuaExecutorExt};
diff --git a/lib/runtime.rs b/lib/runtime.rs
index 7ec49ab..3d1f4aa 100644
--- a/lib/runtime.rs
+++ b/lib/runtime.rs
@@ -1,4 +1,4 @@
-use std::{cell::Cell, rc::Rc};
+use std::{cell::Cell, rc::Rc, sync::Arc};
 
 use mlua::prelude::*;
 use smol::{
@@ -18,6 +18,7 @@ const GLOBAL_NAME_DEFER: &str = "__runtime__defer";
 
 pub struct Runtime {
     queue_status: Rc<Cell<bool>>,
+    // TODO: Use something better than Rc<Mutex<Vec<...>>>
     queue_spawn: Rc<Mutex<Vec<ThreadWithArgs>>>,
     queue_defer: Rc<Mutex<Vec<ThreadWithArgs>>>,
     tx: Sender<()>,
@@ -46,6 +47,10 @@ impl Runtime {
             .resume::<_, LuaValue>(())?;
         let pending_key = lua.create_registry_value(pending)?;
 
+        // TODO: Generalize these two functions below so we
+        // dont need to duplicate the same exact thing for
+        // spawn and defer which is prone to human error
+
         // Create spawn function (push to start of queue)
         let b_spawn = Rc::clone(&queue_status);
         let q_spawn = Rc::clone(&queue_spawn);
@@ -142,7 +147,12 @@ impl Runtime {
     pub async fn run_async(&self, lua: &Lua) {
         // Create new executors to use
         let lua_exec = LocalExecutor::new();
-        let main_exec = Executor::new();
+        let main_exec = Arc::new(Executor::new());
+
+        // TODO: Create multiple executors for work stealing
+
+        // Store the main executor in lua for LuaExecutorExt trait
+        lua.set_app_data(Arc::downgrade(&main_exec));
 
         // Tick local lua executor while also driving main
         // executor forward, until all lua threads finish
diff --git a/lib/traits.rs b/lib/traits.rs
index ec9c04f..1ef509c 100644
--- a/lib/traits.rs
+++ b/lib/traits.rs
@@ -1,4 +1,7 @@
+use std::{future::Future, sync::Weak};
+
 use mlua::prelude::*;
+use smol::{Executor, Task};
 
 /**
     Trait for any struct that can be turned into an [`LuaThread`]
@@ -32,3 +35,60 @@ impl<'lua> IntoLuaThread<'lua> for LuaChunk<'lua, '_> {
         lua.create_thread(self.into_function()?)
     }
 }
+
+/**
+    Trait for spawning `Send` futures on the current executor.
+
+    For spawning non-`Send` futures on the same local executor as a [`Lua`]
+    VM instance, [`Lua::create_async_function`] should be used instead.
+*/
+pub trait LuaExecutorExt<'lua> {
+    /**
+        Spawns the given future on the current executor and returns its [`Task`].
+
+        ### Panics
+
+        Panics if called outside of a [`Runtime`].
+
+        ### Example usage
+
+        ```rust
+        use mlua::prelude::*;
+        use smol_mlua::{Runtime, LuaExecutorExt};
+
+        fn main() -> LuaResult<()> {
+            let lua = Lua::new();
+
+            lua.globals().set(
+                "spawnBackgroundTask",
+                lua.create_async_function(|lua, ()| async move {
+                    lua.spawn(async move {
+                        println!("Hello from background task!");
+                    }).await;
+                    Ok(())
+                })?
+            )?;
+
+            let rt = Runtime::new(&lua)?;
+            rt.push_main(&lua, lua.load("spawnBackgroundTask()"), ());
+            rt.run_blocking(&lua);
+
+            Ok(())
+        }
+        ```
+
+        [`Runtime`]: crate::Runtime
+    */
+    fn spawn<T: Send + 'static>(&self, fut: impl Future<Output = T> + Send + 'static) -> Task<T>;
+}
+
+impl<'lua> LuaExecutorExt<'lua> for Lua {
+    fn spawn<T: Send + 'static>(&self, fut: impl Future<Output = T> + Send + 'static) -> Task<T> {
+        let exec = self
+            .app_data_ref::<Weak<Executor>>()
+            .expect("futures can only be spawned within a runtime")
+            .upgrade()
+            .expect("executor was dropped");
+        exec.spawn(fut)
+    }
+}