Refactor runtime and callbacks

* Improved ergonomics and flexibility for crate consumers
* Simplified callback mechanism for errors
* Factor out runtime thread queues into proper structs
* Misc performance improvements - approx 20% faster scheduler
This commit is contained in:
Filip Tibell 2024-01-24 19:50:25 +01:00
parent 913f575c74
commit 588fc46807
No known key found for this signature in database
19 changed files with 388 additions and 488 deletions

View file

@ -22,10 +22,6 @@ test = true
name = "callbacks"
test = true
[[example]]
name = "captures"
test = true
[[example]]
name = "lots_of_threads"
test = true

View file

@ -16,7 +16,10 @@
<br/>
Integration between [smol](https://crates.io/crates/smol) and [mlua](https://crates.io/crates/mlua) that provides a fully functional and asynchronous Luau runtime using smol executor(s).
Integration between [smol] and [mlua] that provides a fully functional and asynchronous Luau runtime using smol executor(s).
[smol]: https://crates.io/crates/smol
[mlua]: https://crates.io/crates/mlua
## Example Usage
@ -25,11 +28,9 @@ Integration between [smol](https://crates.io/crates/smol) and [mlua](https://cra
```rs
use std::time::{Duration, Instant};
use smol_mlua::{
mlua::prelude::*,
smol::{Timer, io, fs::read_to_string},
Runtime,
};
use mlua::prelude::*;
use smol::{Timer, io, fs::read_to_string}
use smol_mlua::Runtime;
```
### 2. Set up lua environment
@ -68,15 +69,15 @@ lua.globals().set(
```rs
let rt = Runtime::new(&lua)?;
// We can create multiple lua threads
// We can create multiple lua threads ...
let sleepThread = lua.load("sleep(0.1)");
let fileThread = lua.load("readFile(\"Cargo.toml\")");
// Put them all into the runtime
rt.push_thread(sleepThread, ());
rt.push_thread(fileThread, ());
// ... spawn them both onto the runtime ...
rt.spawn_thread(sleepThread, ());
rt.spawn_thread(fileThread, ());
// And run either async or blocking, until above threads finish
// ... and run either async or blocking, until they finish
rt.run_async().await;
rt.run_blocking();
```

View file

@ -1,10 +1,8 @@
use std::time::{Duration, Instant};
use smol_mlua::{
mlua::prelude::{Lua, LuaResult},
smol::Timer,
Runtime,
};
use mlua::prelude::*;
use smol::Timer;
use smol_mlua::Runtime;
const MAIN_SCRIPT: &str = include_str!("./lua/basic_sleep.luau");
@ -23,7 +21,7 @@ pub fn main() -> LuaResult<()> {
// Load the main script into a runtime and run it until completion
let rt = Runtime::new(&lua)?;
let main = lua.load(MAIN_SCRIPT);
rt.push_thread(main, ());
rt.spawn_thread(main, ())?;
rt.run_blocking();
Ok(())

View file

@ -1,10 +1,6 @@
use mlua::ExternalResult;
use smol::io;
use smol_mlua::{
mlua::prelude::{Lua, LuaResult},
smol::fs::read_to_string,
LuaExecutorExt, Runtime,
};
use mlua::prelude::*;
use smol::{fs::read_to_string, io};
use smol_mlua::{LuaSpawnExt, Runtime};
const MAIN_SCRIPT: &str = include_str!("./lua/basic_spawn.luau");
@ -29,7 +25,7 @@ pub fn main() -> LuaResult<()> {
// Load the main script into a runtime and run it until completion
let rt = Runtime::new(&lua)?;
let main = lua.load(MAIN_SCRIPT);
rt.push_thread(main, ());
rt.spawn_thread(main, ())?;
rt.run_blocking();
Ok(())

View file

@ -1,7 +1,5 @@
use smol_mlua::{
mlua::prelude::{Lua, LuaResult},
Callbacks, Runtime,
};
use mlua::prelude::*;
use smol_mlua::Runtime;
const MAIN_SCRIPT: &str = include_str!("./lua/callbacks.luau");
@ -11,17 +9,17 @@ pub fn main() -> LuaResult<()> {
// Create a new runtime with custom callbacks
let rt = Runtime::new(&lua)?;
rt.set_callbacks(Callbacks::default().on_error(|_, _, e| {
rt.set_error_callback(|e| {
println!(
"Captured error from Lua!\n{}\n{e}\n{}",
"-".repeat(15),
"-".repeat(15)
);
}));
});
// Load and run the main script until completion
let main = lua.load(MAIN_SCRIPT);
rt.push_thread(main, ());
rt.spawn_thread(main, ())?;
rt.run_blocking();
Ok(())

View file

@ -1,92 +0,0 @@
use std::{
rc::Rc,
time::{Duration, Instant},
};
use smol_mlua::{
mlua::prelude::{Lua, LuaResult, LuaThread, LuaValue},
smol::{lock::Mutex, Timer},
Callbacks, IntoLuaThread, Runtime,
};
const MAIN_SCRIPT: &str = include_str!("./lua/captures.luau");
pub fn main() -> LuaResult<()> {
// Set up persistent lua environment
let lua = Lua::new();
lua.globals().set(
"sleep",
lua.create_async_function(|_, duration: Option<f64>| async move {
let duration = duration.unwrap_or_default().max(1.0 / 250.0);
let before = Instant::now();
let after = Timer::after(Duration::from_secs_f64(duration)).await;
Ok((after - before).as_secs_f64())
})?,
)?;
// Load and run the main script a few times for the purposes of this example
for _ in 0..20 {
println!("...");
match run(&lua, lua.load(MAIN_SCRIPT)) {
Err(e) => eprintln!("Errored:\n{e}"),
Ok(v) => println!("Returned value:\n{v:?}"),
}
}
Ok(())
}
/**
Wrapper function to run the given `main` thread on a new [`Runtime`].
Waits for all threads to finish, including the main thread, and
returns the value or error of the main thread once exited.
*/
fn run<'lua>(lua: &'lua Lua, main: impl IntoLuaThread<'lua>) -> LuaResult<LuaValue> {
// Set up runtime (thread queue / async executors)
let rt = Runtime::new(lua)?;
let thread = rt.push_thread(main, ());
lua.set_named_registry_value("mainThread", thread)?;
// Create callbacks to capture resulting value/error of main thread,
// we need to do some tricks to get around the lifetime issues with 'lua
// being different inside the callback vs. outside the callback, for LuaValue
let captured_error = Rc::new(Mutex::new(None));
let captured_error_inner = Rc::clone(&captured_error);
rt.set_callbacks(
Callbacks::new()
.on_value(|lua, thread, val| {
let main: LuaThread = lua.named_registry_value("mainThread").unwrap();
if main == thread {
lua.set_named_registry_value("mainValue", val).unwrap();
}
})
.on_error(move |lua, thread, err| {
let main: LuaThread = lua.named_registry_value("mainThread").unwrap();
if main == thread {
captured_error_inner.lock_blocking().replace(err);
}
}),
);
// Run until end
rt.run_blocking();
// Extract value and error from their containers
let err_opt = { captured_error.lock_blocking().take() };
let val_opt = lua.named_registry_value("mainValue").ok();
// Check result
if let Some(err) = err_opt {
Err(err)
} else if let Some(val) = val_opt {
Ok(val)
} else {
unreachable!("No value or error captured from main thread");
}
}
#[test]
fn test_captures() -> LuaResult<()> {
main()
}

View file

@ -1,18 +1,25 @@
use std::time::Duration;
use smol_mlua::{
mlua::prelude::{Lua, LuaResult},
smol::Timer,
Runtime,
};
use mlua::prelude::*;
use smol::Timer;
use smol_mlua::Runtime;
const MAIN_SCRIPT: &str = include_str!("./lua/lots_of_threads.luau");
const ONE_NANOSECOND: Duration = Duration::from_nanos(1);
pub fn main() -> LuaResult<()> {
// Set up persistent lua environment
let lua = Lua::new();
// Set up persistent lua environment, note that we enable thread reuse for
// mlua's internal async handling since we will be spawning lots of threads
let lua = Lua::new_with(
LuaStdLib::ALL,
LuaOptions::new()
.catch_rust_panics(false)
.thread_pool_size(10_000),
)?;
let rt = Runtime::new(&lua)?;
lua.globals().set("spawn", rt.create_spawn_function()?)?;
lua.globals().set(
"sleep",
lua.create_async_function(|_, ()| async move {
@ -23,10 +30,9 @@ pub fn main() -> LuaResult<()> {
})?,
)?;
// Load the main script into a runtime and run it until completion
let rt = Runtime::new(&lua)?;
// Load the main script into the runtime and run it until completion
let main = lua.load(MAIN_SCRIPT);
rt.push_thread(main, ());
rt.spawn_thread(main, ())?;
rt.run_blocking();
Ok(())

View file

@ -1,23 +0,0 @@
--!nocheck
--!nolint UnknownGlobal
if math.random() < 0.25 then
error("Unlucky error!")
end
local main = coroutine.running()
local start = os.clock()
local counter = 0
for j = 1, 10_000 do
__runtime__spawn(function()
sleep()
counter += 1
if counter == 10_000 then
local elapsed = os.clock() - start
__runtime__spawn(main, elapsed)
end
end)
end
return coroutine.yield()

View file

@ -13,11 +13,11 @@ for i = 1, NUM_BATCHES do
local counter = 0
for j = 1, NUM_THREADS do
__runtime__spawn(function()
spawn(function()
sleep()
counter += 1
if counter == NUM_THREADS then
__runtime__spawn(thread)
spawn(thread)
end
end)
end

View file

@ -4,12 +4,12 @@
print(1)
-- Defer will run at the end of the resumption cycle, but without yielding
__runtime__defer(function()
defer(function()
print(5)
end)
-- Spawn will instantly run up until the first yield, and must then be resumed manually ...
__runtime__spawn(function()
spawn(function()
print(2)
coroutine.yield()
print("unreachable")
@ -17,7 +17,7 @@ end)
-- ... unless calling functions created using `lua.create_async_function(...)`,
-- which will resume their calling thread with their result automatically
__runtime__spawn(function()
spawn(function()
print(3)
sleep(1)
print(6)

View file

@ -1,16 +1,18 @@
use std::time::{Duration, Instant};
use smol_mlua::{
mlua::prelude::{Lua, LuaResult},
smol::Timer,
Runtime,
};
use mlua::prelude::*;
use smol::Timer;
use smol_mlua::Runtime;
const MAIN_SCRIPT: &str = include_str!("./lua/scheduler_ordering.luau");
pub fn main() -> LuaResult<()> {
// Set up persistent lua environment
let lua = Lua::new();
let rt = Runtime::new(&lua)?;
lua.globals().set("spawn", rt.create_spawn_function()?)?;
lua.globals().set("defer", rt.create_defer_function()?)?;
lua.globals().set(
"sleep",
lua.create_async_function(|_, duration: Option<f64>| async move {
@ -22,9 +24,8 @@ pub fn main() -> LuaResult<()> {
)?;
// Load the main script into a runtime and run it until completion
let rt = Runtime::new(&lua)?;
let main = lua.load(MAIN_SCRIPT);
rt.push_thread(main, ());
rt.spawn_thread(main, ())?;
rt.run_blocking();
Ok(())

View file

@ -1,128 +0,0 @@
use mlua::prelude::*;
type ValueCallback = Box<dyn for<'lua> Fn(&'lua Lua, LuaThread<'lua>, LuaValue<'lua>) + 'static>;
type ErrorCallback = Box<dyn for<'lua> Fn(&'lua Lua, LuaThread<'lua>, LuaError) + 'static>;
const FORWARD_VALUE_KEY: &str = "__runtime__forwardValue";
const FORWARD_ERROR_KEY: &str = "__runtime__forwardError";
/**
A set of callbacks for thread values and errors.
These callbacks are used to forward values and errors from
Lua threads back to Rust. By default, the runtime will print
any errors to stderr and not do any operations with values.
You can set your own callbacks using the `on_value` and `on_error` builder methods.
*/
pub struct Callbacks {
on_value: Option<ValueCallback>,
on_error: Option<ErrorCallback>,
}
impl Callbacks {
/**
Creates a new set of callbacks with no callbacks set.
*/
pub fn new() -> Self {
Self {
on_value: None,
on_error: None,
}
}
/**
Sets the callback for thread values being yielded / returned.
*/
pub fn on_value<F>(mut self, f: F) -> Self
where
F: Fn(&Lua, LuaThread, LuaValue) + 'static,
{
self.on_value.replace(Box::new(f));
self
}
/**
Sets the callback for thread errors.
*/
pub fn on_error<F>(mut self, f: F) -> Self
where
F: Fn(&Lua, LuaThread, LuaError) + 'static,
{
self.on_error.replace(Box::new(f));
self
}
/**
Removes any current thread value callback.
*/
pub fn without_value_callback(mut self) -> Self {
self.on_value.take();
self
}
/**
Removes any current thread error callback.
*/
pub fn without_error_callback(mut self) -> Self {
self.on_error.take();
self
}
pub(crate) fn inject(self, lua: &Lua) {
// Remove any previously injected callbacks
lua.unset_named_registry_value(FORWARD_VALUE_KEY).ok();
lua.unset_named_registry_value(FORWARD_ERROR_KEY).ok();
// Create functions to forward values & errors
if let Some(f) = self.on_value {
lua.set_named_registry_value(
FORWARD_VALUE_KEY,
lua.create_function(move |lua, (thread, val): (LuaThread, LuaValue)| {
f(lua, thread, val);
Ok(())
})
.expect("failed to create value callback function"),
)
.expect("failed to store value callback function");
}
if let Some(f) = self.on_error {
lua.set_named_registry_value(
FORWARD_ERROR_KEY,
lua.create_function(move |lua, (thread, err): (LuaThread, LuaError)| {
f(lua, thread, err);
Ok(())
})
.expect("failed to create error callback function"),
)
.expect("failed to store error callback function");
}
}
pub(crate) fn forward_value(lua: &Lua, thread: LuaThread, value: LuaValue) {
if let Ok(f) = lua.named_registry_value::<LuaFunction>(FORWARD_VALUE_KEY) {
f.call::<_, ()>((thread, value)).unwrap();
}
}
pub(crate) fn forward_error(lua: &Lua, thread: LuaThread, error: LuaError) {
if let Ok(f) = lua.named_registry_value::<LuaFunction>(FORWARD_ERROR_KEY) {
f.call::<_, ()>((thread, error)).unwrap();
}
}
}
impl Default for Callbacks {
fn default() -> Self {
Callbacks {
on_value: Some(Box::new(default_value_callback)),
on_error: Some(Box::new(default_error_callback)),
}
}
}
fn default_value_callback(_: &Lua, _: LuaThread, _: LuaValue) {}
fn default_error_callback(_: &Lua, _: LuaThread, e: LuaError) {
eprintln!("{e}");
}

52
lib/error_callback.rs Normal file
View file

@ -0,0 +1,52 @@
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use mlua::prelude::*;
use smol::lock::Mutex;
type ErrorCallback = Box<dyn Fn(LuaError) + Send + 'static>;
#[derive(Clone)]
pub(crate) struct ThreadErrorCallback {
exists: Arc<AtomicBool>,
inner: Arc<Mutex<Option<ErrorCallback>>>,
}
impl ThreadErrorCallback {
pub fn new() -> Self {
Self {
exists: Arc::new(AtomicBool::new(false)),
inner: Arc::new(Mutex::new(None)),
}
}
pub fn new_default() -> Self {
let this = Self::new();
this.replace(default_error_callback);
this
}
pub fn replace(&self, callback: impl Fn(LuaError) + Send + 'static) {
self.exists.store(true, Ordering::Relaxed);
self.inner.lock_blocking().replace(Box::new(callback));
}
pub fn clear(&self) {
self.exists.store(false, Ordering::Relaxed);
self.inner.lock_blocking().take();
}
pub fn call(&self, error: &LuaError) {
if self.exists.load(Ordering::Relaxed) {
if let Some(cb) = &*self.inner.lock_blocking() {
cb(error.clone());
}
}
}
}
fn default_error_callback(e: LuaError) {
eprintln!("{e}");
}

View file

@ -1,12 +1,8 @@
mod callbacks;
mod error_callback;
mod queue;
mod runtime;
mod storage;
mod traits;
mod util;
pub use mlua;
pub use smol;
pub use callbacks::Callbacks;
pub use runtime::Runtime;
pub use traits::{IntoLuaThread, LuaExecutorExt};
pub use traits::{IntoLuaThread, LuaSpawnExt};

107
lib/queue.rs Normal file
View file

@ -0,0 +1,107 @@
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc,
};
use mlua::prelude::*;
use smol::{
channel::{unbounded, Receiver, Sender},
lock::Mutex,
};
use crate::IntoLuaThread;
const ERR_OOM: &str = "out of memory";
/**
Queue for storing [`LuaThread`]s with associated arguments.
Provides methods for pushing and draining the queue, as
well as listening for new items being pushed to the queue.
*/
#[derive(Debug, Clone)]
pub struct ThreadQueue {
queue: Arc<Mutex<Vec<ThreadWithArgs>>>,
status: Arc<AtomicBool>,
signal_tx: Sender<()>,
signal_rx: Receiver<()>,
}
impl ThreadQueue {
pub fn new() -> Self {
let (signal_tx, signal_rx) = unbounded();
Self {
queue: Arc::new(Mutex::new(Vec::new())),
status: Arc::new(AtomicBool::new(false)),
signal_tx,
signal_rx,
}
}
pub fn has_threads(&self) -> bool {
self.status.load(Ordering::SeqCst)
}
pub fn push<'lua>(
&self,
lua: &'lua Lua,
thread: impl IntoLuaThread<'lua>,
args: impl IntoLuaMulti<'lua>,
) -> LuaResult<()> {
let thread = thread.into_lua_thread(lua)?;
let args = args.into_lua_multi(lua)?;
let stored = ThreadWithArgs::new(lua, thread, args);
self.queue.lock_blocking().push(stored);
self.status.store(true, Ordering::SeqCst);
self.signal_tx.try_send(()).unwrap();
Ok(())
}
pub async fn drain<'lua>(&self, lua: &'lua Lua) -> Vec<(LuaThread<'lua>, LuaMultiValue<'lua>)> {
let mut queue = self.queue.lock().await;
let drained = queue.drain(..).map(|s| s.into_inner(lua)).collect();
self.status.store(false, Ordering::SeqCst);
drained
}
pub async fn recv(&self) {
self.signal_rx.recv().await.unwrap();
}
}
/**
Representation of a [`LuaThread`] with associated arguments currently stored in the Lua registry.
*/
#[derive(Debug)]
struct ThreadWithArgs {
key_thread: LuaRegistryKey,
key_args: LuaRegistryKey,
}
impl ThreadWithArgs {
pub fn new<'lua>(lua: &'lua Lua, thread: LuaThread<'lua>, args: LuaMultiValue<'lua>) -> Self {
let argsv = args.into_vec();
let key_thread = lua.create_registry_value(thread).expect(ERR_OOM);
let key_args = lua.create_registry_value(argsv).expect(ERR_OOM);
Self {
key_thread,
key_args,
}
}
pub fn into_inner(self, lua: &Lua) -> (LuaThread<'_>, LuaMultiValue<'_>) {
let thread = lua.registry_value(&self.key_thread).unwrap();
let argsv = lua.registry_value(&self.key_args).unwrap();
let args = LuaMultiValue::from_vec(argsv);
lua.remove_registry_value(self.key_thread).unwrap();
lua.remove_registry_value(self.key_args).unwrap();
(thread, args)
}
}

View file

@ -1,160 +1,149 @@
use std::{cell::Cell, rc::Rc, sync::Arc};
use std::sync::Arc;
use mlua::prelude::*;
use smol::prelude::*;
use smol::{
block_on,
channel::{unbounded, Receiver, Sender},
lock::Mutex,
Executor, LocalExecutor,
};
use smol::{block_on, Executor, LocalExecutor};
use super::{
callbacks::Callbacks, storage::ThreadWithArgs, traits::IntoLuaThread, util::LuaThreadOrFunction,
error_callback::ThreadErrorCallback,
queue::ThreadQueue,
traits::IntoLuaThread,
util::{is_poll_pending, LuaThreadOrFunction},
};
const GLOBAL_NAME_SPAWN: &str = "__runtime__spawn";
const GLOBAL_NAME_DEFER: &str = "__runtime__defer";
pub struct Runtime<'lua> {
lua: &'lua Lua,
queue_status: Rc<Cell<bool>>,
// TODO: Use something better than Rc<Mutex<Vec<...>>>
queue_spawn: Rc<Mutex<Vec<ThreadWithArgs>>>,
queue_defer: Rc<Mutex<Vec<ThreadWithArgs>>>,
tx: Sender<()>,
rx: Receiver<()>,
queue_spawn: ThreadQueue,
queue_defer: ThreadQueue,
error_callback: ThreadErrorCallback,
}
impl<'lua> Runtime<'lua> {
/**
Creates a new runtime for the given Lua state.
This will inject some functions to interact with the scheduler / executor,
as well as the default [`Callbacks`] for thread values and errors.
This runtime will have a default error callback that prints errors to stderr.
*/
pub fn new(lua: &'lua Lua) -> LuaResult<Runtime<'lua>> {
let queue_status = Rc::new(Cell::new(false));
let queue_spawn = Rc::new(Mutex::new(Vec::new()));
let queue_defer = Rc::new(Mutex::new(Vec::new()));
let (tx, rx) = unbounded();
// HACK: Extract mlua "pending" constant value and store it
let pending = lua
.create_async_function(|_, ()| async move {
smol::future::yield_now().await;
Ok(())
})?
.into_lua_thread(lua)?
.resume::<_, LuaValue>(())?;
let pending_key = lua.create_registry_value(pending)?;
// TODO: Generalize these two functions below so we
// dont need to duplicate the same exact thing for
// spawn and defer which is prone to human error
// Create spawn function (push to start of queue)
let b_spawn = Rc::clone(&queue_status);
let q_spawn = Rc::clone(&queue_spawn);
let tx_spawn = tx.clone();
let fn_spawn = lua.create_function(
move |lua, (tof, args): (LuaThreadOrFunction, LuaMultiValue)| {
let thread = tof.into_thread(lua)?;
if thread.status() == LuaThreadStatus::Resumable {
// HACK: We need to resume the thread once instantly for correct behavior,
// and only if we get the pending value back we can spawn to async executor
let pending: LuaValue = lua.registry_value(&pending_key)?;
match thread.resume::<_, LuaValue>(args.clone()) {
Ok(v) if v == pending => {
let stored = ThreadWithArgs::new(lua, thread.clone(), args);
q_spawn.lock_blocking().push(stored);
b_spawn.replace(true);
tx_spawn.try_send(()).map_err(|_| {
LuaError::runtime("Tried to spawn thread to a dropped queue")
})?;
}
Ok(v) => Callbacks::forward_value(lua, thread.clone(), v),
Err(e) => Callbacks::forward_error(lua, thread.clone(), e),
}
Ok(thread)
} else {
Err(LuaError::runtime("Tried to spawn non-resumable thread"))
}
},
)?;
// Create defer function (push to end of queue)
let b_defer = Rc::clone(&queue_status);
let q_defer = Rc::clone(&queue_defer);
let tx_defer = tx.clone();
let fn_defer = lua.create_function(
move |lua, (tof, args): (LuaThreadOrFunction, LuaMultiValue)| {
let thread = tof.into_thread(lua)?;
if thread.status() == LuaThreadStatus::Resumable {
let stored = ThreadWithArgs::new(lua, thread.clone(), args);
q_defer.lock_blocking().push(stored);
b_defer.replace(true);
tx_defer.try_send(()).map_err(|_| {
LuaError::runtime("Tried to defer thread to a dropped queue")
})?;
Ok(thread)
} else {
Err(LuaError::runtime("Tried to defer non-resumable thread"))
}
},
)?;
// Store them both as globals
lua.globals().set(GLOBAL_NAME_SPAWN, fn_spawn)?;
lua.globals().set(GLOBAL_NAME_DEFER, fn_defer)?;
// Finally, inject default callbacks
Callbacks::default().inject(lua);
let queue_spawn = ThreadQueue::new();
let queue_defer = ThreadQueue::new();
let error_callback = ThreadErrorCallback::new_default();
Ok(Runtime {
lua,
queue_status,
queue_spawn,
queue_defer,
tx,
rx,
error_callback,
})
}
/**
Sets the callbacks for this runtime.
Sets the error callback for this runtime.
This will overwrite any previously set callbacks, including default ones.
This callback will be called whenever a Lua thread errors.
Overwrites any previous error callback.
*/
pub fn set_callbacks(&self, callbacks: Callbacks) {
callbacks.inject(self.lua);
pub fn set_error_callback(&self, callback: impl Fn(LuaError) + Send + 'static) {
self.error_callback.replace(callback);
}
/**
Pushes a chunk / function / thread to the runtime queue.
Clears the error callback for this runtime.
This will remove any current error callback, including default(s).
*/
pub fn remove_error_callback(&self) {
self.error_callback.clear();
}
/**
Spawns a chunk / function / thread onto the runtime queue.
Threads are guaranteed to be resumed in the order that they were pushed to the queue.
*/
pub fn push_thread(
pub fn spawn_thread(
&self,
thread: impl IntoLuaThread<'lua>,
args: impl IntoLuaMulti<'lua>,
) -> LuaThread<'lua> {
let thread = thread
.into_lua_thread(self.lua)
.expect("failed to create thread");
let args = args
.into_lua_multi(self.lua)
.expect("failed to create args");
) -> LuaResult<()> {
let thread = thread.into_lua_thread(self.lua)?;
let args = args.into_lua_multi(self.lua)?;
let stored = ThreadWithArgs::new(self.lua, thread.clone(), args);
self.queue_spawn.push(self.lua, thread, args)?;
self.queue_spawn.lock_blocking().push(stored);
self.queue_status.replace(true);
self.tx.try_send(()).unwrap(); // Unwrap is safe since this struct also holds the receiver
Ok(())
}
thread
/**
Defers a chunk / function / thread onto the runtime queue.
Deferred threads are guaranteed to run after all spawned threads either yield or complete.
Threads are guaranteed to be resumed in the order that they were pushed to the queue.
*/
pub fn defer_thread(
&self,
thread: impl IntoLuaThread<'lua>,
args: impl IntoLuaMulti<'lua>,
) -> LuaResult<()> {
let thread = thread.into_lua_thread(self.lua)?;
let args = args.into_lua_multi(self.lua)?;
self.queue_defer.push(self.lua, thread, args)?;
Ok(())
}
/**
Creates a lua function that can be used to spawn threads / functions onto the runtime queue.
The function takes a thread or function as the first argument, and any variadic arguments as the rest.
*/
pub fn create_spawn_function(&self) -> LuaResult<LuaFunction<'lua>> {
let error_callback = self.error_callback.clone();
let spawn_queue = self.queue_spawn.clone();
self.lua.create_function(
move |lua, (tof, args): (LuaThreadOrFunction, LuaMultiValue)| {
let thread = tof.into_thread(lua)?;
if thread.status() == LuaThreadStatus::Resumable {
// NOTE: We need to resume the thread once instantly for correct behavior,
// and only if we get the pending value back we can spawn to async executor
match thread.resume::<_, LuaValue>(args.clone()) {
Ok(v) => {
if is_poll_pending(&v) {
spawn_queue.push(lua, &thread, args)?;
}
}
Err(e) => {
error_callback.call(&e);
}
};
}
Ok(thread)
},
)
}
/**
Creates a lua function that can be used to defer threads / functions onto the runtime queue.
The function takes a thread or function as the first argument, and any variadic arguments as the rest.
Deferred threads are guaranteed to run after all spawned threads either yield or complete.
*/
pub fn create_defer_function(&self) -> LuaResult<LuaFunction<'lua>> {
let defer_queue = self.queue_defer.clone();
self.lua.create_function(
move |lua, (tof, args): (LuaThreadOrFunction, LuaMultiValue)| {
let thread = tof.into_thread(lua)?;
if thread.status() == LuaThreadStatus::Resumable {
defer_queue.push(lua, &thread, args)?;
}
Ok(thread)
},
)
}
/**
@ -170,7 +159,7 @@ impl<'lua> Runtime<'lua> {
let lua_exec = LocalExecutor::new();
let main_exec = Arc::new(Executor::new());
// Store the main executor in lua for LuaExecutorExt trait
// Store the main executor in lua for spawner trait
self.lua.set_app_data(Arc::downgrade(&main_exec));
// Tick local lua executor while also driving main
@ -179,9 +168,8 @@ impl<'lua> Runtime<'lua> {
loop {
// Wait for a new thread to arrive __or__ next futures step, prioritizing
// new threads, so we don't accidentally exit when there is more work to do
let fut_recv = async {
self.rx.recv().await.ok();
};
let fut_spawn = self.queue_spawn.recv();
let fut_defer = self.queue_defer.recv();
let fut_tick = async {
lua_exec.tick().await;
// Do as much work as possible
@ -191,18 +179,18 @@ impl<'lua> Runtime<'lua> {
}
}
};
fut_recv.or(fut_tick).await;
// If a new thread was spawned onto any queue, we
// must drain them and schedule on the executor
if self.queue_status.get() {
fut_spawn.or(fut_defer).or(fut_tick).await;
// If a new thread was spawned onto any queue,
// we must drain them and schedule on the executor
if self.queue_spawn.has_threads() || self.queue_defer.has_threads() {
let mut queued_threads = Vec::new();
queued_threads.extend(self.queue_spawn.lock().await.drain(..));
queued_threads.extend(self.queue_defer.lock().await.drain(..));
for queued_thread in queued_threads {
queued_threads.extend(self.queue_spawn.drain(self.lua).await);
queued_threads.extend(self.queue_defer.drain(self.lua).await);
for (thread, args) in queued_threads {
// NOTE: Thread may have been cancelled from lua
// before we got here, so we need to check it again
let (thread, args) = queued_thread.into_inner(self.lua);
if thread.status() == LuaThreadStatus::Resumable {
let mut stream = thread.clone().into_async::<_, LuaValue>(args);
lua_exec
@ -210,10 +198,11 @@ impl<'lua> Runtime<'lua> {
// Only run stream until first coroutine.yield or completion. We will
// drop it right away to clear stack space since detached tasks dont drop
// until the executor drops https://github.com/smol-rs/smol/issues/294
match stream.next().await.unwrap() {
Ok(v) => Callbacks::forward_value(self.lua, thread, v),
Err(e) => Callbacks::forward_error(self.lua, thread, e),
};
let res = stream.next().await.unwrap();
if let Err(e) = &res {
self.error_callback.call(e);
}
// TODO: Figure out how to give this result to caller of spawn_thread/defer_thread
})
.detach();
}

View file

@ -1,43 +0,0 @@
use mlua::prelude::*;
#[derive(Debug)]
pub(crate) struct ThreadWithArgs {
key_thread: LuaRegistryKey,
key_args: LuaRegistryKey,
}
impl ThreadWithArgs {
pub fn new<'lua>(lua: &'lua Lua, thread: LuaThread<'lua>, args: LuaMultiValue<'lua>) -> Self {
let args_vec = args.into_vec();
let key_thread = lua
.create_registry_value(thread)
.expect("Failed to store thread in registry - out of memory");
let key_args = lua
.create_registry_value(args_vec)
.expect("Failed to store thread args in registry - out of memory");
Self {
key_thread,
key_args,
}
}
pub fn into_inner(self, lua: &Lua) -> (LuaThread<'_>, LuaMultiValue<'_>) {
let thread = lua
.registry_value(&self.key_thread)
.expect("Failed to get thread from registry");
let args_vec = lua
.registry_value(&self.key_args)
.expect("Failed to get thread args from registry");
let args = LuaMultiValue::from_vec(args_vec);
lua.remove_registry_value(self.key_thread)
.expect("Failed to remove thread from registry");
lua.remove_registry_value(self.key_args)
.expect("Failed to remove thread args from registry");
(thread, args)
}
}

View file

@ -36,13 +36,22 @@ impl<'lua> IntoLuaThread<'lua> for LuaChunk<'lua, '_> {
}
}
impl<'lua, T> IntoLuaThread<'lua> for &T
where
T: IntoLuaThread<'lua> + Clone,
{
fn into_lua_thread(self, lua: &'lua Lua) -> LuaResult<LuaThread<'lua>> {
self.clone().into_lua_thread(lua)
}
}
/**
Trait for spawning `Send` futures on the current executor.
For spawning non-`Send` futures on the same local executor as a [`Lua`]
For spawning `!Send` futures on the same local executor as a [`Lua`]
VM instance, [`Lua::create_async_function`] should be used instead.
*/
pub trait LuaExecutorExt<'lua> {
pub trait LuaSpawnExt<'lua> {
/**
Spawns the given future on the current executor and returns its [`Task`].
@ -54,7 +63,7 @@ pub trait LuaExecutorExt<'lua> {
```rust
use mlua::prelude::*;
use smol_mlua::{Runtime, LuaExecutorExt};
use smol_mlua::{Runtime, LuaSpawnExt};
fn main() -> LuaResult<()> {
let lua = Lua::new();
@ -70,7 +79,7 @@ pub trait LuaExecutorExt<'lua> {
)?;
let rt = Runtime::new(&lua)?;
rt.push_thread(lua.load("spawnBackgroundTask()"), ());
rt.spawn_thread(lua.load("spawnBackgroundTask()"), ());
rt.run_blocking();
Ok(())
@ -82,7 +91,7 @@ pub trait LuaExecutorExt<'lua> {
fn spawn<T: Send + 'static>(&self, fut: impl Future<Output = T> + Send + 'static) -> Task<T>;
}
impl<'lua> LuaExecutorExt<'lua> for Lua {
impl<'lua> LuaSpawnExt<'lua> for Lua {
fn spawn<T: Send + 'static>(&self, fut: impl Future<Output = T> + Send + 'static) -> Task<T> {
let exec = self
.app_data_ref::<Weak<Executor>>()

View file

@ -1,5 +1,42 @@
use std::cell::OnceCell;
use mlua::prelude::*;
use crate::IntoLuaThread;
thread_local! {
static POLL_PENDING: OnceCell<LuaLightUserData> = OnceCell::new();
}
fn get_poll_pending(lua: &Lua) -> LuaResult<LuaLightUserData> {
let yielder_fn = lua.create_async_function(|_, ()| async move {
smol::future::yield_now().await;
Ok(())
})?;
yielder_fn
.into_lua_thread(lua)?
.resume::<_, LuaLightUserData>(())
}
#[inline]
pub(crate) fn is_poll_pending(value: &LuaValue) -> bool {
// TODO: Replace with Lua::poll_pending() when it's available
let pp = POLL_PENDING.with(|cell| {
*cell.get_or_init(|| {
let lua = Lua::new().into_static();
let pending = get_poll_pending(lua).unwrap();
// SAFETY: We only use the Lua state for the lifetime of this function,
// and the "poll pending" light userdata / pointer is completely static.
drop(unsafe { Lua::from_static(lua) });
pending
})
});
matches!(value, LuaValue::LightUserData(u) if u == &pp)
}
/**
Wrapper struct to accept either a Lua thread or a Lua function as function argument.