diff --git a/Cargo.toml b/Cargo.toml
index ee28b7b..660115e 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -22,10 +22,6 @@ test = true
name = "callbacks"
test = true
-[[example]]
-name = "captures"
-test = true
-
[[example]]
name = "lots_of_threads"
test = true
diff --git a/README.md b/README.md
index b978c12..c0d4fd0 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,10 @@
-Integration between [smol](https://crates.io/crates/smol) and [mlua](https://crates.io/crates/mlua) that provides a fully functional and asynchronous Luau runtime using smol executor(s).
+Integration between [smol] and [mlua] that provides a fully functional and asynchronous Luau runtime using smol executor(s).
+
+[smol]: https://crates.io/crates/smol
+[mlua]: https://crates.io/crates/mlua
## Example Usage
@@ -25,11 +28,9 @@ Integration between [smol](https://crates.io/crates/smol) and [mlua](https://cra
```rs
use std::time::{Duration, Instant};
-use smol_mlua::{
- mlua::prelude::*,
- smol::{Timer, io, fs::read_to_string},
- Runtime,
-};
+use mlua::prelude::*;
+use smol::{Timer, io, fs::read_to_string}
+use smol_mlua::Runtime;
```
### 2. Set up lua environment
@@ -68,15 +69,15 @@ lua.globals().set(
```rs
let rt = Runtime::new(&lua)?;
-// We can create multiple lua threads
+// We can create multiple lua threads ...
let sleepThread = lua.load("sleep(0.1)");
let fileThread = lua.load("readFile(\"Cargo.toml\")");
-// Put them all into the runtime
-rt.push_thread(sleepThread, ());
-rt.push_thread(fileThread, ());
+// ... spawn them both onto the runtime ...
+rt.spawn_thread(sleepThread, ());
+rt.spawn_thread(fileThread, ());
-// And run either async or blocking, until above threads finish
+// ... and run either async or blocking, until they finish
rt.run_async().await;
rt.run_blocking();
```
diff --git a/examples/basic_sleep.rs b/examples/basic_sleep.rs
index d5a427a..52f376d 100644
--- a/examples/basic_sleep.rs
+++ b/examples/basic_sleep.rs
@@ -1,10 +1,8 @@
use std::time::{Duration, Instant};
-use smol_mlua::{
- mlua::prelude::{Lua, LuaResult},
- smol::Timer,
- Runtime,
-};
+use mlua::prelude::*;
+use smol::Timer;
+use smol_mlua::Runtime;
const MAIN_SCRIPT: &str = include_str!("./lua/basic_sleep.luau");
@@ -23,7 +21,7 @@ pub fn main() -> LuaResult<()> {
// Load the main script into a runtime and run it until completion
let rt = Runtime::new(&lua)?;
let main = lua.load(MAIN_SCRIPT);
- rt.push_thread(main, ());
+ rt.spawn_thread(main, ())?;
rt.run_blocking();
Ok(())
diff --git a/examples/basic_spawn.rs b/examples/basic_spawn.rs
index c43f8f6..a8b1ad9 100644
--- a/examples/basic_spawn.rs
+++ b/examples/basic_spawn.rs
@@ -1,10 +1,6 @@
-use mlua::ExternalResult;
-use smol::io;
-use smol_mlua::{
- mlua::prelude::{Lua, LuaResult},
- smol::fs::read_to_string,
- LuaExecutorExt, Runtime,
-};
+use mlua::prelude::*;
+use smol::{fs::read_to_string, io};
+use smol_mlua::{LuaSpawnExt, Runtime};
const MAIN_SCRIPT: &str = include_str!("./lua/basic_spawn.luau");
@@ -29,7 +25,7 @@ pub fn main() -> LuaResult<()> {
// Load the main script into a runtime and run it until completion
let rt = Runtime::new(&lua)?;
let main = lua.load(MAIN_SCRIPT);
- rt.push_thread(main, ());
+ rt.spawn_thread(main, ())?;
rt.run_blocking();
Ok(())
diff --git a/examples/callbacks.rs b/examples/callbacks.rs
index 447c8ce..0de42a0 100644
--- a/examples/callbacks.rs
+++ b/examples/callbacks.rs
@@ -1,7 +1,5 @@
-use smol_mlua::{
- mlua::prelude::{Lua, LuaResult},
- Callbacks, Runtime,
-};
+use mlua::prelude::*;
+use smol_mlua::Runtime;
const MAIN_SCRIPT: &str = include_str!("./lua/callbacks.luau");
@@ -11,17 +9,17 @@ pub fn main() -> LuaResult<()> {
// Create a new runtime with custom callbacks
let rt = Runtime::new(&lua)?;
- rt.set_callbacks(Callbacks::default().on_error(|_, _, e| {
+ rt.set_error_callback(|e| {
println!(
"Captured error from Lua!\n{}\n{e}\n{}",
"-".repeat(15),
"-".repeat(15)
);
- }));
+ });
// Load and run the main script until completion
let main = lua.load(MAIN_SCRIPT);
- rt.push_thread(main, ());
+ rt.spawn_thread(main, ())?;
rt.run_blocking();
Ok(())
diff --git a/examples/captures.rs b/examples/captures.rs
deleted file mode 100644
index 2d2f8c6..0000000
--- a/examples/captures.rs
+++ /dev/null
@@ -1,92 +0,0 @@
-use std::{
- rc::Rc,
- time::{Duration, Instant},
-};
-
-use smol_mlua::{
- mlua::prelude::{Lua, LuaResult, LuaThread, LuaValue},
- smol::{lock::Mutex, Timer},
- Callbacks, IntoLuaThread, Runtime,
-};
-
-const MAIN_SCRIPT: &str = include_str!("./lua/captures.luau");
-
-pub fn main() -> LuaResult<()> {
- // Set up persistent lua environment
- let lua = Lua::new();
- lua.globals().set(
- "sleep",
- lua.create_async_function(|_, duration: Option| async move {
- let duration = duration.unwrap_or_default().max(1.0 / 250.0);
- let before = Instant::now();
- let after = Timer::after(Duration::from_secs_f64(duration)).await;
- Ok((after - before).as_secs_f64())
- })?,
- )?;
-
- // Load and run the main script a few times for the purposes of this example
- for _ in 0..20 {
- println!("...");
- match run(&lua, lua.load(MAIN_SCRIPT)) {
- Err(e) => eprintln!("Errored:\n{e}"),
- Ok(v) => println!("Returned value:\n{v:?}"),
- }
- }
-
- Ok(())
-}
-
-/**
- Wrapper function to run the given `main` thread on a new [`Runtime`].
-
- Waits for all threads to finish, including the main thread, and
- returns the value or error of the main thread once exited.
-*/
-fn run<'lua>(lua: &'lua Lua, main: impl IntoLuaThread<'lua>) -> LuaResult {
- // Set up runtime (thread queue / async executors)
- let rt = Runtime::new(lua)?;
- let thread = rt.push_thread(main, ());
- lua.set_named_registry_value("mainThread", thread)?;
-
- // Create callbacks to capture resulting value/error of main thread,
- // we need to do some tricks to get around the lifetime issues with 'lua
- // being different inside the callback vs. outside the callback, for LuaValue
- let captured_error = Rc::new(Mutex::new(None));
- let captured_error_inner = Rc::clone(&captured_error);
- rt.set_callbacks(
- Callbacks::new()
- .on_value(|lua, thread, val| {
- let main: LuaThread = lua.named_registry_value("mainThread").unwrap();
- if main == thread {
- lua.set_named_registry_value("mainValue", val).unwrap();
- }
- })
- .on_error(move |lua, thread, err| {
- let main: LuaThread = lua.named_registry_value("mainThread").unwrap();
- if main == thread {
- captured_error_inner.lock_blocking().replace(err);
- }
- }),
- );
-
- // Run until end
- rt.run_blocking();
-
- // Extract value and error from their containers
- let err_opt = { captured_error.lock_blocking().take() };
- let val_opt = lua.named_registry_value("mainValue").ok();
-
- // Check result
- if let Some(err) = err_opt {
- Err(err)
- } else if let Some(val) = val_opt {
- Ok(val)
- } else {
- unreachable!("No value or error captured from main thread");
- }
-}
-
-#[test]
-fn test_captures() -> LuaResult<()> {
- main()
-}
diff --git a/examples/lots_of_threads.rs b/examples/lots_of_threads.rs
index 10d7c6c..d4d17bd 100644
--- a/examples/lots_of_threads.rs
+++ b/examples/lots_of_threads.rs
@@ -1,18 +1,25 @@
use std::time::Duration;
-use smol_mlua::{
- mlua::prelude::{Lua, LuaResult},
- smol::Timer,
- Runtime,
-};
+use mlua::prelude::*;
+use smol::Timer;
+use smol_mlua::Runtime;
const MAIN_SCRIPT: &str = include_str!("./lua/lots_of_threads.luau");
const ONE_NANOSECOND: Duration = Duration::from_nanos(1);
pub fn main() -> LuaResult<()> {
- // Set up persistent lua environment
- let lua = Lua::new();
+ // Set up persistent lua environment, note that we enable thread reuse for
+ // mlua's internal async handling since we will be spawning lots of threads
+ let lua = Lua::new_with(
+ LuaStdLib::ALL,
+ LuaOptions::new()
+ .catch_rust_panics(false)
+ .thread_pool_size(10_000),
+ )?;
+ let rt = Runtime::new(&lua)?;
+
+ lua.globals().set("spawn", rt.create_spawn_function()?)?;
lua.globals().set(
"sleep",
lua.create_async_function(|_, ()| async move {
@@ -23,10 +30,9 @@ pub fn main() -> LuaResult<()> {
})?,
)?;
- // Load the main script into a runtime and run it until completion
- let rt = Runtime::new(&lua)?;
+ // Load the main script into the runtime and run it until completion
let main = lua.load(MAIN_SCRIPT);
- rt.push_thread(main, ());
+ rt.spawn_thread(main, ())?;
rt.run_blocking();
Ok(())
diff --git a/examples/lua/captures.luau b/examples/lua/captures.luau
deleted file mode 100644
index 74661af..0000000
--- a/examples/lua/captures.luau
+++ /dev/null
@@ -1,23 +0,0 @@
---!nocheck
---!nolint UnknownGlobal
-
-if math.random() < 0.25 then
- error("Unlucky error!")
-end
-
-local main = coroutine.running()
-local start = os.clock()
-
-local counter = 0
-for j = 1, 10_000 do
- __runtime__spawn(function()
- sleep()
- counter += 1
- if counter == 10_000 then
- local elapsed = os.clock() - start
- __runtime__spawn(main, elapsed)
- end
- end)
-end
-
-return coroutine.yield()
diff --git a/examples/lua/lots_of_threads.luau b/examples/lua/lots_of_threads.luau
index 3144f15..3958284 100644
--- a/examples/lua/lots_of_threads.luau
+++ b/examples/lua/lots_of_threads.luau
@@ -13,11 +13,11 @@ for i = 1, NUM_BATCHES do
local counter = 0
for j = 1, NUM_THREADS do
- __runtime__spawn(function()
+ spawn(function()
sleep()
counter += 1
if counter == NUM_THREADS then
- __runtime__spawn(thread)
+ spawn(thread)
end
end)
end
diff --git a/examples/lua/scheduler_ordering.luau b/examples/lua/scheduler_ordering.luau
index 945cb5c..264f503 100644
--- a/examples/lua/scheduler_ordering.luau
+++ b/examples/lua/scheduler_ordering.luau
@@ -4,12 +4,12 @@
print(1)
-- Defer will run at the end of the resumption cycle, but without yielding
-__runtime__defer(function()
+defer(function()
print(5)
end)
-- Spawn will instantly run up until the first yield, and must then be resumed manually ...
-__runtime__spawn(function()
+spawn(function()
print(2)
coroutine.yield()
print("unreachable")
@@ -17,7 +17,7 @@ end)
-- ... unless calling functions created using `lua.create_async_function(...)`,
-- which will resume their calling thread with their result automatically
-__runtime__spawn(function()
+spawn(function()
print(3)
sleep(1)
print(6)
diff --git a/examples/scheduler_ordering.rs b/examples/scheduler_ordering.rs
index ce8e7a1..6aa5b61 100644
--- a/examples/scheduler_ordering.rs
+++ b/examples/scheduler_ordering.rs
@@ -1,16 +1,18 @@
use std::time::{Duration, Instant};
-use smol_mlua::{
- mlua::prelude::{Lua, LuaResult},
- smol::Timer,
- Runtime,
-};
+use mlua::prelude::*;
+use smol::Timer;
+use smol_mlua::Runtime;
const MAIN_SCRIPT: &str = include_str!("./lua/scheduler_ordering.luau");
pub fn main() -> LuaResult<()> {
// Set up persistent lua environment
let lua = Lua::new();
+ let rt = Runtime::new(&lua)?;
+
+ lua.globals().set("spawn", rt.create_spawn_function()?)?;
+ lua.globals().set("defer", rt.create_defer_function()?)?;
lua.globals().set(
"sleep",
lua.create_async_function(|_, duration: Option| async move {
@@ -22,9 +24,8 @@ pub fn main() -> LuaResult<()> {
)?;
// Load the main script into a runtime and run it until completion
- let rt = Runtime::new(&lua)?;
let main = lua.load(MAIN_SCRIPT);
- rt.push_thread(main, ());
+ rt.spawn_thread(main, ())?;
rt.run_blocking();
Ok(())
diff --git a/lib/callbacks.rs b/lib/callbacks.rs
deleted file mode 100644
index 7b86de0..0000000
--- a/lib/callbacks.rs
+++ /dev/null
@@ -1,128 +0,0 @@
-use mlua::prelude::*;
-
-type ValueCallback = Box Fn(&'lua Lua, LuaThread<'lua>, LuaValue<'lua>) + 'static>;
-type ErrorCallback = Box Fn(&'lua Lua, LuaThread<'lua>, LuaError) + 'static>;
-
-const FORWARD_VALUE_KEY: &str = "__runtime__forwardValue";
-const FORWARD_ERROR_KEY: &str = "__runtime__forwardError";
-
-/**
- A set of callbacks for thread values and errors.
-
- These callbacks are used to forward values and errors from
- Lua threads back to Rust. By default, the runtime will print
- any errors to stderr and not do any operations with values.
-
- You can set your own callbacks using the `on_value` and `on_error` builder methods.
-*/
-pub struct Callbacks {
- on_value: Option,
- on_error: Option,
-}
-
-impl Callbacks {
- /**
- Creates a new set of callbacks with no callbacks set.
- */
- pub fn new() -> Self {
- Self {
- on_value: None,
- on_error: None,
- }
- }
-
- /**
- Sets the callback for thread values being yielded / returned.
- */
- pub fn on_value(mut self, f: F) -> Self
- where
- F: Fn(&Lua, LuaThread, LuaValue) + 'static,
- {
- self.on_value.replace(Box::new(f));
- self
- }
-
- /**
- Sets the callback for thread errors.
- */
- pub fn on_error(mut self, f: F) -> Self
- where
- F: Fn(&Lua, LuaThread, LuaError) + 'static,
- {
- self.on_error.replace(Box::new(f));
- self
- }
-
- /**
- Removes any current thread value callback.
- */
- pub fn without_value_callback(mut self) -> Self {
- self.on_value.take();
- self
- }
-
- /**
- Removes any current thread error callback.
- */
- pub fn without_error_callback(mut self) -> Self {
- self.on_error.take();
- self
- }
-
- pub(crate) fn inject(self, lua: &Lua) {
- // Remove any previously injected callbacks
- lua.unset_named_registry_value(FORWARD_VALUE_KEY).ok();
- lua.unset_named_registry_value(FORWARD_ERROR_KEY).ok();
-
- // Create functions to forward values & errors
- if let Some(f) = self.on_value {
- lua.set_named_registry_value(
- FORWARD_VALUE_KEY,
- lua.create_function(move |lua, (thread, val): (LuaThread, LuaValue)| {
- f(lua, thread, val);
- Ok(())
- })
- .expect("failed to create value callback function"),
- )
- .expect("failed to store value callback function");
- }
-
- if let Some(f) = self.on_error {
- lua.set_named_registry_value(
- FORWARD_ERROR_KEY,
- lua.create_function(move |lua, (thread, err): (LuaThread, LuaError)| {
- f(lua, thread, err);
- Ok(())
- })
- .expect("failed to create error callback function"),
- )
- .expect("failed to store error callback function");
- }
- }
-
- pub(crate) fn forward_value(lua: &Lua, thread: LuaThread, value: LuaValue) {
- if let Ok(f) = lua.named_registry_value::(FORWARD_VALUE_KEY) {
- f.call::<_, ()>((thread, value)).unwrap();
- }
- }
-
- pub(crate) fn forward_error(lua: &Lua, thread: LuaThread, error: LuaError) {
- if let Ok(f) = lua.named_registry_value::(FORWARD_ERROR_KEY) {
- f.call::<_, ()>((thread, error)).unwrap();
- }
- }
-}
-
-impl Default for Callbacks {
- fn default() -> Self {
- Callbacks {
- on_value: Some(Box::new(default_value_callback)),
- on_error: Some(Box::new(default_error_callback)),
- }
- }
-}
-
-fn default_value_callback(_: &Lua, _: LuaThread, _: LuaValue) {}
-fn default_error_callback(_: &Lua, _: LuaThread, e: LuaError) {
- eprintln!("{e}");
-}
diff --git a/lib/error_callback.rs b/lib/error_callback.rs
new file mode 100644
index 0000000..1e9f04b
--- /dev/null
+++ b/lib/error_callback.rs
@@ -0,0 +1,52 @@
+use std::sync::{
+ atomic::{AtomicBool, Ordering},
+ Arc,
+};
+
+use mlua::prelude::*;
+use smol::lock::Mutex;
+
+type ErrorCallback = Box;
+
+#[derive(Clone)]
+pub(crate) struct ThreadErrorCallback {
+ exists: Arc,
+ inner: Arc>>,
+}
+
+impl ThreadErrorCallback {
+ pub fn new() -> Self {
+ Self {
+ exists: Arc::new(AtomicBool::new(false)),
+ inner: Arc::new(Mutex::new(None)),
+ }
+ }
+
+ pub fn new_default() -> Self {
+ let this = Self::new();
+ this.replace(default_error_callback);
+ this
+ }
+
+ pub fn replace(&self, callback: impl Fn(LuaError) + Send + 'static) {
+ self.exists.store(true, Ordering::Relaxed);
+ self.inner.lock_blocking().replace(Box::new(callback));
+ }
+
+ pub fn clear(&self) {
+ self.exists.store(false, Ordering::Relaxed);
+ self.inner.lock_blocking().take();
+ }
+
+ pub fn call(&self, error: &LuaError) {
+ if self.exists.load(Ordering::Relaxed) {
+ if let Some(cb) = &*self.inner.lock_blocking() {
+ cb(error.clone());
+ }
+ }
+ }
+}
+
+fn default_error_callback(e: LuaError) {
+ eprintln!("{e}");
+}
diff --git a/lib/lib.rs b/lib/lib.rs
index 5aa95b6..627eebd 100644
--- a/lib/lib.rs
+++ b/lib/lib.rs
@@ -1,12 +1,8 @@
-mod callbacks;
+mod error_callback;
+mod queue;
mod runtime;
-mod storage;
mod traits;
mod util;
-pub use mlua;
-pub use smol;
-
-pub use callbacks::Callbacks;
pub use runtime::Runtime;
-pub use traits::{IntoLuaThread, LuaExecutorExt};
+pub use traits::{IntoLuaThread, LuaSpawnExt};
diff --git a/lib/queue.rs b/lib/queue.rs
new file mode 100644
index 0000000..6ac9f27
--- /dev/null
+++ b/lib/queue.rs
@@ -0,0 +1,107 @@
+use std::sync::{
+ atomic::{AtomicBool, Ordering},
+ Arc,
+};
+
+use mlua::prelude::*;
+use smol::{
+ channel::{unbounded, Receiver, Sender},
+ lock::Mutex,
+};
+
+use crate::IntoLuaThread;
+
+const ERR_OOM: &str = "out of memory";
+
+/**
+ Queue for storing [`LuaThread`]s with associated arguments.
+
+ Provides methods for pushing and draining the queue, as
+ well as listening for new items being pushed to the queue.
+*/
+#[derive(Debug, Clone)]
+pub struct ThreadQueue {
+ queue: Arc>>,
+ status: Arc,
+ signal_tx: Sender<()>,
+ signal_rx: Receiver<()>,
+}
+
+impl ThreadQueue {
+ pub fn new() -> Self {
+ let (signal_tx, signal_rx) = unbounded();
+ Self {
+ queue: Arc::new(Mutex::new(Vec::new())),
+ status: Arc::new(AtomicBool::new(false)),
+ signal_tx,
+ signal_rx,
+ }
+ }
+
+ pub fn has_threads(&self) -> bool {
+ self.status.load(Ordering::SeqCst)
+ }
+
+ pub fn push<'lua>(
+ &self,
+ lua: &'lua Lua,
+ thread: impl IntoLuaThread<'lua>,
+ args: impl IntoLuaMulti<'lua>,
+ ) -> LuaResult<()> {
+ let thread = thread.into_lua_thread(lua)?;
+ let args = args.into_lua_multi(lua)?;
+ let stored = ThreadWithArgs::new(lua, thread, args);
+
+ self.queue.lock_blocking().push(stored);
+ self.status.store(true, Ordering::SeqCst);
+ self.signal_tx.try_send(()).unwrap();
+
+ Ok(())
+ }
+
+ pub async fn drain<'lua>(&self, lua: &'lua Lua) -> Vec<(LuaThread<'lua>, LuaMultiValue<'lua>)> {
+ let mut queue = self.queue.lock().await;
+ let drained = queue.drain(..).map(|s| s.into_inner(lua)).collect();
+ self.status.store(false, Ordering::SeqCst);
+ drained
+ }
+
+ pub async fn recv(&self) {
+ self.signal_rx.recv().await.unwrap();
+ }
+}
+
+/**
+ Representation of a [`LuaThread`] with associated arguments currently stored in the Lua registry.
+*/
+#[derive(Debug)]
+struct ThreadWithArgs {
+ key_thread: LuaRegistryKey,
+ key_args: LuaRegistryKey,
+}
+
+impl ThreadWithArgs {
+ pub fn new<'lua>(lua: &'lua Lua, thread: LuaThread<'lua>, args: LuaMultiValue<'lua>) -> Self {
+ let argsv = args.into_vec();
+
+ let key_thread = lua.create_registry_value(thread).expect(ERR_OOM);
+ let key_args = lua.create_registry_value(argsv).expect(ERR_OOM);
+
+ Self {
+ key_thread,
+ key_args,
+ }
+ }
+
+ pub fn into_inner(self, lua: &Lua) -> (LuaThread<'_>, LuaMultiValue<'_>) {
+ let thread = lua.registry_value(&self.key_thread).unwrap();
+ let argsv = lua.registry_value(&self.key_args).unwrap();
+
+ let args = LuaMultiValue::from_vec(argsv);
+
+ lua.remove_registry_value(self.key_thread).unwrap();
+ lua.remove_registry_value(self.key_args).unwrap();
+
+ (thread, args)
+ }
+}
diff --git a/lib/runtime.rs b/lib/runtime.rs
index 661bf7d..7887bb8 100644
--- a/lib/runtime.rs
+++ b/lib/runtime.rs
@@ -1,160 +1,149 @@
-use std::{cell::Cell, rc::Rc, sync::Arc};
+use std::sync::Arc;
use mlua::prelude::*;
use smol::prelude::*;
-use smol::{
- block_on,
- channel::{unbounded, Receiver, Sender},
- lock::Mutex,
- Executor, LocalExecutor,
-};
+use smol::{block_on, Executor, LocalExecutor};
use super::{
- callbacks::Callbacks, storage::ThreadWithArgs, traits::IntoLuaThread, util::LuaThreadOrFunction,
+ error_callback::ThreadErrorCallback,
+ queue::ThreadQueue,
+ traits::IntoLuaThread,
+ util::{is_poll_pending, LuaThreadOrFunction},
};
-const GLOBAL_NAME_SPAWN: &str = "__runtime__spawn";
-const GLOBAL_NAME_DEFER: &str = "__runtime__defer";
-
pub struct Runtime<'lua> {
lua: &'lua Lua,
- queue_status: Rc>,
- // TODO: Use something better than Rc>>
- queue_spawn: Rc>>,
- queue_defer: Rc>>,
- tx: Sender<()>,
- rx: Receiver<()>,
+ queue_spawn: ThreadQueue,
+ queue_defer: ThreadQueue,
+ error_callback: ThreadErrorCallback,
}
impl<'lua> Runtime<'lua> {
/**
Creates a new runtime for the given Lua state.
- This will inject some functions to interact with the scheduler / executor,
- as well as the default [`Callbacks`] for thread values and errors.
+ This runtime will have a default error callback that prints errors to stderr.
*/
pub fn new(lua: &'lua Lua) -> LuaResult> {
- let queue_status = Rc::new(Cell::new(false));
- let queue_spawn = Rc::new(Mutex::new(Vec::new()));
- let queue_defer = Rc::new(Mutex::new(Vec::new()));
- let (tx, rx) = unbounded();
-
- // HACK: Extract mlua "pending" constant value and store it
- let pending = lua
- .create_async_function(|_, ()| async move {
- smol::future::yield_now().await;
- Ok(())
- })?
- .into_lua_thread(lua)?
- .resume::<_, LuaValue>(())?;
- let pending_key = lua.create_registry_value(pending)?;
-
- // TODO: Generalize these two functions below so we
- // dont need to duplicate the same exact thing for
- // spawn and defer which is prone to human error
-
- // Create spawn function (push to start of queue)
- let b_spawn = Rc::clone(&queue_status);
- let q_spawn = Rc::clone(&queue_spawn);
- let tx_spawn = tx.clone();
- let fn_spawn = lua.create_function(
- move |lua, (tof, args): (LuaThreadOrFunction, LuaMultiValue)| {
- let thread = tof.into_thread(lua)?;
- if thread.status() == LuaThreadStatus::Resumable {
- // HACK: We need to resume the thread once instantly for correct behavior,
- // and only if we get the pending value back we can spawn to async executor
- let pending: LuaValue = lua.registry_value(&pending_key)?;
- match thread.resume::<_, LuaValue>(args.clone()) {
- Ok(v) if v == pending => {
- let stored = ThreadWithArgs::new(lua, thread.clone(), args);
- q_spawn.lock_blocking().push(stored);
- b_spawn.replace(true);
- tx_spawn.try_send(()).map_err(|_| {
- LuaError::runtime("Tried to spawn thread to a dropped queue")
- })?;
- }
- Ok(v) => Callbacks::forward_value(lua, thread.clone(), v),
- Err(e) => Callbacks::forward_error(lua, thread.clone(), e),
- }
- Ok(thread)
- } else {
- Err(LuaError::runtime("Tried to spawn non-resumable thread"))
- }
- },
- )?;
-
- // Create defer function (push to end of queue)
- let b_defer = Rc::clone(&queue_status);
- let q_defer = Rc::clone(&queue_defer);
- let tx_defer = tx.clone();
- let fn_defer = lua.create_function(
- move |lua, (tof, args): (LuaThreadOrFunction, LuaMultiValue)| {
- let thread = tof.into_thread(lua)?;
- if thread.status() == LuaThreadStatus::Resumable {
- let stored = ThreadWithArgs::new(lua, thread.clone(), args);
- q_defer.lock_blocking().push(stored);
- b_defer.replace(true);
- tx_defer.try_send(()).map_err(|_| {
- LuaError::runtime("Tried to defer thread to a dropped queue")
- })?;
- Ok(thread)
- } else {
- Err(LuaError::runtime("Tried to defer non-resumable thread"))
- }
- },
- )?;
-
- // Store them both as globals
- lua.globals().set(GLOBAL_NAME_SPAWN, fn_spawn)?;
- lua.globals().set(GLOBAL_NAME_DEFER, fn_defer)?;
-
- // Finally, inject default callbacks
- Callbacks::default().inject(lua);
+ let queue_spawn = ThreadQueue::new();
+ let queue_defer = ThreadQueue::new();
+ let error_callback = ThreadErrorCallback::new_default();
Ok(Runtime {
lua,
- queue_status,
queue_spawn,
queue_defer,
- tx,
- rx,
+ error_callback,
})
}
/**
- Sets the callbacks for this runtime.
+ Sets the error callback for this runtime.
- This will overwrite any previously set callbacks, including default ones.
+ This callback will be called whenever a Lua thread errors.
+
+ Overwrites any previous error callback.
*/
- pub fn set_callbacks(&self, callbacks: Callbacks) {
- callbacks.inject(self.lua);
+ pub fn set_error_callback(&self, callback: impl Fn(LuaError) + Send + 'static) {
+ self.error_callback.replace(callback);
}
/**
- Pushes a chunk / function / thread to the runtime queue.
+ Clears the error callback for this runtime.
+
+ This will remove any current error callback, including default(s).
+ */
+ pub fn remove_error_callback(&self) {
+ self.error_callback.clear();
+ }
+
+ /**
+ Spawns a chunk / function / thread onto the runtime queue.
Threads are guaranteed to be resumed in the order that they were pushed to the queue.
*/
- pub fn push_thread(
+ pub fn spawn_thread(
&self,
thread: impl IntoLuaThread<'lua>,
args: impl IntoLuaMulti<'lua>,
- ) -> LuaThread<'lua> {
- let thread = thread
- .into_lua_thread(self.lua)
- .expect("failed to create thread");
- let args = args
- .into_lua_multi(self.lua)
- .expect("failed to create args");
+ ) -> LuaResult<()> {
+ let thread = thread.into_lua_thread(self.lua)?;
+ let args = args.into_lua_multi(self.lua)?;
- let stored = ThreadWithArgs::new(self.lua, thread.clone(), args);
+ self.queue_spawn.push(self.lua, thread, args)?;
- self.queue_spawn.lock_blocking().push(stored);
- self.queue_status.replace(true);
- self.tx.try_send(()).unwrap(); // Unwrap is safe since this struct also holds the receiver
+ Ok(())
+ }
- thread
+ /**
+ Defers a chunk / function / thread onto the runtime queue.
+
+ Deferred threads are guaranteed to run after all spawned threads either yield or complete.
+
+ Threads are guaranteed to be resumed in the order that they were pushed to the queue.
+ */
+ pub fn defer_thread(
+ &self,
+ thread: impl IntoLuaThread<'lua>,
+ args: impl IntoLuaMulti<'lua>,
+ ) -> LuaResult<()> {
+ let thread = thread.into_lua_thread(self.lua)?;
+ let args = args.into_lua_multi(self.lua)?;
+
+ self.queue_defer.push(self.lua, thread, args)?;
+
+ Ok(())
+ }
+
+ /**
+ Creates a lua function that can be used to spawn threads / functions onto the runtime queue.
+
+ The function takes a thread or function as the first argument, and any variadic arguments as the rest.
+ */
+ pub fn create_spawn_function(&self) -> LuaResult> {
+ let error_callback = self.error_callback.clone();
+ let spawn_queue = self.queue_spawn.clone();
+ self.lua.create_function(
+ move |lua, (tof, args): (LuaThreadOrFunction, LuaMultiValue)| {
+ let thread = tof.into_thread(lua)?;
+ if thread.status() == LuaThreadStatus::Resumable {
+ // NOTE: We need to resume the thread once instantly for correct behavior,
+ // and only if we get the pending value back we can spawn to async executor
+ match thread.resume::<_, LuaValue>(args.clone()) {
+ Ok(v) => {
+ if is_poll_pending(&v) {
+ spawn_queue.push(lua, &thread, args)?;
+ }
+ }
+ Err(e) => {
+ error_callback.call(&e);
+ }
+ };
+ }
+ Ok(thread)
+ },
+ )
+ }
+
+ /**
+ Creates a lua function that can be used to defer threads / functions onto the runtime queue.
+
+ The function takes a thread or function as the first argument, and any variadic arguments as the rest.
+
+ Deferred threads are guaranteed to run after all spawned threads either yield or complete.
+ */
+ pub fn create_defer_function(&self) -> LuaResult> {
+ let defer_queue = self.queue_defer.clone();
+ self.lua.create_function(
+ move |lua, (tof, args): (LuaThreadOrFunction, LuaMultiValue)| {
+ let thread = tof.into_thread(lua)?;
+ if thread.status() == LuaThreadStatus::Resumable {
+ defer_queue.push(lua, &thread, args)?;
+ }
+ Ok(thread)
+ },
+ )
}
/**
@@ -170,7 +159,7 @@ impl<'lua> Runtime<'lua> {
let lua_exec = LocalExecutor::new();
let main_exec = Arc::new(Executor::new());
- // Store the main executor in lua for LuaExecutorExt trait
+ // Store the main executor in lua for spawner trait
self.lua.set_app_data(Arc::downgrade(&main_exec));
// Tick local lua executor while also driving main
@@ -179,9 +168,8 @@ impl<'lua> Runtime<'lua> {
loop {
// Wait for a new thread to arrive __or__ next futures step, prioritizing
// new threads, so we don't accidentally exit when there is more work to do
- let fut_recv = async {
- self.rx.recv().await.ok();
- };
+ let fut_spawn = self.queue_spawn.recv();
+ let fut_defer = self.queue_defer.recv();
let fut_tick = async {
lua_exec.tick().await;
// Do as much work as possible
@@ -191,18 +179,18 @@ impl<'lua> Runtime<'lua> {
}
}
};
- fut_recv.or(fut_tick).await;
- // If a new thread was spawned onto any queue, we
- // must drain them and schedule on the executor
- if self.queue_status.get() {
+ fut_spawn.or(fut_defer).or(fut_tick).await;
+
+ // If a new thread was spawned onto any queue,
+ // we must drain them and schedule on the executor
+ if self.queue_spawn.has_threads() || self.queue_defer.has_threads() {
let mut queued_threads = Vec::new();
- queued_threads.extend(self.queue_spawn.lock().await.drain(..));
- queued_threads.extend(self.queue_defer.lock().await.drain(..));
- for queued_thread in queued_threads {
+ queued_threads.extend(self.queue_spawn.drain(self.lua).await);
+ queued_threads.extend(self.queue_defer.drain(self.lua).await);
+ for (thread, args) in queued_threads {
// NOTE: Thread may have been cancelled from lua
// before we got here, so we need to check it again
- let (thread, args) = queued_thread.into_inner(self.lua);
if thread.status() == LuaThreadStatus::Resumable {
let mut stream = thread.clone().into_async::<_, LuaValue>(args);
lua_exec
@@ -210,10 +198,11 @@ impl<'lua> Runtime<'lua> {
// Only run stream until first coroutine.yield or completion. We will
// drop it right away to clear stack space since detached tasks dont drop
// until the executor drops https://github.com/smol-rs/smol/issues/294
- match stream.next().await.unwrap() {
- Ok(v) => Callbacks::forward_value(self.lua, thread, v),
- Err(e) => Callbacks::forward_error(self.lua, thread, e),
- };
+ let res = stream.next().await.unwrap();
+ if let Err(e) = &res {
+ self.error_callback.call(e);
+ }
+ // TODO: Figure out how to give this result to caller of spawn_thread/defer_thread
})
.detach();
}
diff --git a/lib/storage.rs b/lib/storage.rs
deleted file mode 100644
index f067983..0000000
--- a/lib/storage.rs
+++ /dev/null
@@ -1,43 +0,0 @@
-use mlua::prelude::*;
-
-#[derive(Debug)]
-pub(crate) struct ThreadWithArgs {
- key_thread: LuaRegistryKey,
- key_args: LuaRegistryKey,
-}
-
-impl ThreadWithArgs {
- pub fn new<'lua>(lua: &'lua Lua, thread: LuaThread<'lua>, args: LuaMultiValue<'lua>) -> Self {
- let args_vec = args.into_vec();
-
- let key_thread = lua
- .create_registry_value(thread)
- .expect("Failed to store thread in registry - out of memory");
- let key_args = lua
- .create_registry_value(args_vec)
- .expect("Failed to store thread args in registry - out of memory");
-
- Self {
- key_thread,
- key_args,
- }
- }
-
- pub fn into_inner(self, lua: &Lua) -> (LuaThread<'_>, LuaMultiValue<'_>) {
- let thread = lua
- .registry_value(&self.key_thread)
- .expect("Failed to get thread from registry");
- let args_vec = lua
- .registry_value(&self.key_args)
- .expect("Failed to get thread args from registry");
-
- let args = LuaMultiValue::from_vec(args_vec);
-
- lua.remove_registry_value(self.key_thread)
- .expect("Failed to remove thread from registry");
- lua.remove_registry_value(self.key_args)
- .expect("Failed to remove thread args from registry");
-
- (thread, args)
- }
-}
diff --git a/lib/traits.rs b/lib/traits.rs
index 3e2a1a5..95e7f52 100644
--- a/lib/traits.rs
+++ b/lib/traits.rs
@@ -36,13 +36,22 @@ impl<'lua> IntoLuaThread<'lua> for LuaChunk<'lua, '_> {
}
}
+impl<'lua, T> IntoLuaThread<'lua> for &T
+where
+ T: IntoLuaThread<'lua> + Clone,
+{
+ fn into_lua_thread(self, lua: &'lua Lua) -> LuaResult> {
+ self.clone().into_lua_thread(lua)
+ }
+}
+
/**
Trait for spawning `Send` futures on the current executor.
- For spawning non-`Send` futures on the same local executor as a [`Lua`]
+ For spawning `!Send` futures on the same local executor as a [`Lua`]
VM instance, [`Lua::create_async_function`] should be used instead.
*/
-pub trait LuaExecutorExt<'lua> {
+pub trait LuaSpawnExt<'lua> {
/**
Spawns the given future on the current executor and returns its [`Task`].
@@ -54,7 +63,7 @@ pub trait LuaExecutorExt<'lua> {
```rust
use mlua::prelude::*;
- use smol_mlua::{Runtime, LuaExecutorExt};
+ use smol_mlua::{Runtime, LuaSpawnExt};
fn main() -> LuaResult<()> {
let lua = Lua::new();
@@ -70,7 +79,7 @@ pub trait LuaExecutorExt<'lua> {
)?;
let rt = Runtime::new(&lua)?;
- rt.push_thread(lua.load("spawnBackgroundTask()"), ());
+ rt.spawn_thread(lua.load("spawnBackgroundTask()"), ());
rt.run_blocking();
Ok(())
@@ -82,7 +91,7 @@ pub trait LuaExecutorExt<'lua> {
fn spawn(&self, fut: impl Future |