mirror of
https://github.com/lune-org/mlua-luau-scheduler.git
synced 2025-04-10 21:40:55 +01:00
Convert main to an example that also includes value capture
This commit is contained in:
parent
34529c0235
commit
0299568318
7 changed files with 153 additions and 102 deletions
|
@ -9,3 +9,6 @@ mlua = { version = "0.9", features = ["luau", "luau-jit", "async"] }
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
path = "lib/lib.rs"
|
path = "lib/lib.rs"
|
||||||
|
|
||||||
|
[examples]
|
||||||
|
main = "examples/main.rs"
|
||||||
|
|
25
examples/main.luau
Normal file
25
examples/main.luau
Normal file
|
@ -0,0 +1,25 @@
|
||||||
|
--!nocheck
|
||||||
|
--!nolint UnknownGlobal
|
||||||
|
|
||||||
|
|
||||||
|
if math.random() < 0.25 then
|
||||||
|
error("Unlucky error!")
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
|
local main = coroutine.running()
|
||||||
|
local start = os.clock()
|
||||||
|
|
||||||
|
local counter = 0
|
||||||
|
for j = 1, 10_000 do
|
||||||
|
__runtime__spawn(function()
|
||||||
|
wait()
|
||||||
|
counter += 1
|
||||||
|
if counter == 10_000 then
|
||||||
|
local elapsed = os.clock() - start
|
||||||
|
__runtime__spawn(main, elapsed)
|
||||||
|
end
|
||||||
|
end)
|
||||||
|
end
|
||||||
|
|
||||||
|
return coroutine.yield()
|
86
examples/main.rs
Normal file
86
examples/main.rs
Normal file
|
@ -0,0 +1,86 @@
|
||||||
|
use std::{
|
||||||
|
rc::Rc,
|
||||||
|
time::{Duration, Instant},
|
||||||
|
};
|
||||||
|
|
||||||
|
use smol_mlua::{
|
||||||
|
mlua::prelude::{Lua, LuaResult, LuaThread, LuaValue},
|
||||||
|
smol::{lock::Mutex, Timer},
|
||||||
|
Callbacks, IntoLuaThread, Runtime,
|
||||||
|
};
|
||||||
|
|
||||||
|
const MAIN_SCRIPT: &str = include_str!("./main.luau");
|
||||||
|
|
||||||
|
pub fn main() -> LuaResult<()> {
|
||||||
|
// Set up persistent lua environment
|
||||||
|
let lua = Lua::new();
|
||||||
|
lua.globals().set(
|
||||||
|
"wait",
|
||||||
|
lua.create_async_function(|_, duration: Option<f64>| async move {
|
||||||
|
let duration = duration.unwrap_or_default().max(1.0 / 250.0);
|
||||||
|
let before = Instant::now();
|
||||||
|
let after = Timer::after(Duration::from_secs_f64(duration)).await;
|
||||||
|
Ok((after - before).as_secs_f64())
|
||||||
|
})?,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// Load and run the main script a few times for the purposes of this example
|
||||||
|
for _ in 0..20 {
|
||||||
|
println!("...");
|
||||||
|
match run(&lua, lua.load(MAIN_SCRIPT)) {
|
||||||
|
Err(e) => eprintln!("Errored:\n{e}"),
|
||||||
|
Ok(v) => println!("Returned value:\n{v:?}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
Wrapper function to run the given `main` thread on a new [`Runtime`].
|
||||||
|
|
||||||
|
Waits for all threads to finish, including the main thread, and
|
||||||
|
returns the value or error of the main thread once exited.
|
||||||
|
*/
|
||||||
|
fn run<'lua>(lua: &'lua Lua, main: impl IntoLuaThread<'lua>) -> LuaResult<LuaValue> {
|
||||||
|
// Set up runtime (thread queue / async executors)
|
||||||
|
let rt = Runtime::new(lua)?;
|
||||||
|
let thread = rt.push_main(lua, main, ());
|
||||||
|
lua.set_named_registry_value("mainThread", thread)?;
|
||||||
|
|
||||||
|
// Add callbacks to capture resulting value/error of main thread,
|
||||||
|
// we need to do some tricks to get around lifetime issues with 'lua
|
||||||
|
// being different inside the callback vs outside the callback for LuaValue
|
||||||
|
let captured_error = Rc::new(Mutex::new(None));
|
||||||
|
let captured_error_inner = Rc::clone(&captured_error);
|
||||||
|
Callbacks::new()
|
||||||
|
.on_value(|lua, thread, val| {
|
||||||
|
let main: LuaThread = lua.named_registry_value("mainThread").unwrap();
|
||||||
|
if main == thread {
|
||||||
|
lua.set_named_registry_value("mainValue", val).unwrap();
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.on_error(move |lua, thread, err| {
|
||||||
|
let main: LuaThread = lua.named_registry_value("mainThread").unwrap();
|
||||||
|
if main == thread {
|
||||||
|
captured_error_inner.lock_blocking().replace(err);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.inject(lua);
|
||||||
|
|
||||||
|
// Run until end
|
||||||
|
rt.run_blocking(lua);
|
||||||
|
|
||||||
|
// Extract value and error from their containers
|
||||||
|
let err_opt = { captured_error.lock_blocking().take() };
|
||||||
|
let val_opt = lua.named_registry_value("mainValue").ok();
|
||||||
|
|
||||||
|
// Check result
|
||||||
|
if let Some(err) = err_opt {
|
||||||
|
Err(err)
|
||||||
|
} else if let Some(val) = val_opt {
|
||||||
|
Ok(val)
|
||||||
|
} else {
|
||||||
|
unreachable!("No value or error captured from main thread");
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,12 +1,15 @@
|
||||||
use mlua::prelude::*;
|
use mlua::prelude::*;
|
||||||
|
|
||||||
type ErrorCallback = Box<dyn for<'lua> Fn(&'lua Lua, LuaThread<'lua>, LuaError) + 'static>;
|
|
||||||
type ValueCallback = Box<dyn for<'lua> Fn(&'lua Lua, LuaThread<'lua>, LuaValue<'lua>) + 'static>;
|
type ValueCallback = Box<dyn for<'lua> Fn(&'lua Lua, LuaThread<'lua>, LuaValue<'lua>) + 'static>;
|
||||||
|
type ErrorCallback = Box<dyn for<'lua> Fn(&'lua Lua, LuaThread<'lua>, LuaError) + 'static>;
|
||||||
|
|
||||||
|
const FORWARD_VALUE_KEY: &str = "__runtime__forwardValue";
|
||||||
|
const FORWARD_ERROR_KEY: &str = "__runtime__forwardError";
|
||||||
|
|
||||||
#[derive(Default)]
|
#[derive(Default)]
|
||||||
pub struct Callbacks {
|
pub struct Callbacks {
|
||||||
on_error: Option<ErrorCallback>,
|
|
||||||
on_value: Option<ValueCallback>,
|
on_value: Option<ValueCallback>,
|
||||||
|
on_error: Option<ErrorCallback>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Callbacks {
|
impl Callbacks {
|
||||||
|
@ -14,14 +17,6 @@ impl Callbacks {
|
||||||
Default::default()
|
Default::default()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn on_error<F>(mut self, f: F) -> Self
|
|
||||||
where
|
|
||||||
F: Fn(&Lua, LuaThread, LuaError) + 'static,
|
|
||||||
{
|
|
||||||
self.on_error.replace(Box::new(f));
|
|
||||||
self
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn on_value<F>(mut self, f: F) -> Self
|
pub fn on_value<F>(mut self, f: F) -> Self
|
||||||
where
|
where
|
||||||
F: Fn(&Lua, LuaThread, LuaValue) + 'static,
|
F: Fn(&Lua, LuaThread, LuaValue) + 'static,
|
||||||
|
@ -30,23 +25,19 @@ impl Callbacks {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn inject(self, lua: &Lua) {
|
pub fn on_error<F>(mut self, f: F) -> Self
|
||||||
// Create functions to forward errors & values
|
where
|
||||||
if let Some(f) = self.on_error {
|
F: Fn(&Lua, LuaThread, LuaError) + 'static,
|
||||||
lua.set_named_registry_value(
|
{
|
||||||
"__forward__error",
|
self.on_error.replace(Box::new(f));
|
||||||
lua.create_function(move |lua, (thread, err): (LuaThread, LuaError)| {
|
self
|
||||||
f(lua, thread, err);
|
}
|
||||||
Ok(())
|
|
||||||
})
|
|
||||||
.expect("failed to create error callback function"),
|
|
||||||
)
|
|
||||||
.expect("failed to store error callback function");
|
|
||||||
}
|
|
||||||
|
|
||||||
|
pub fn inject(self, lua: &Lua) {
|
||||||
|
// Create functions to forward values & errors
|
||||||
if let Some(f) = self.on_value {
|
if let Some(f) = self.on_value {
|
||||||
lua.set_named_registry_value(
|
lua.set_named_registry_value(
|
||||||
"__forward__value",
|
FORWARD_VALUE_KEY,
|
||||||
lua.create_function(move |lua, (thread, val): (LuaThread, LuaValue)| {
|
lua.create_function(move |lua, (thread, val): (LuaThread, LuaValue)| {
|
||||||
f(lua, thread, val);
|
f(lua, thread, val);
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -55,17 +46,29 @@ impl Callbacks {
|
||||||
)
|
)
|
||||||
.expect("failed to store value callback function");
|
.expect("failed to store value callback function");
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn forward_error(lua: &Lua, thread: LuaThread, error: LuaError) {
|
if let Some(f) = self.on_error {
|
||||||
if let Ok(f) = lua.named_registry_value::<LuaFunction>("__forward__error") {
|
lua.set_named_registry_value(
|
||||||
f.call::<_, ()>((thread, error)).unwrap();
|
FORWARD_ERROR_KEY,
|
||||||
|
lua.create_function(move |lua, (thread, err): (LuaThread, LuaError)| {
|
||||||
|
f(lua, thread, err);
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.expect("failed to create error callback function"),
|
||||||
|
)
|
||||||
|
.expect("failed to store error callback function");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn forward_value(lua: &Lua, thread: LuaThread, value: LuaValue) {
|
pub(crate) fn forward_value(lua: &Lua, thread: LuaThread, value: LuaValue) {
|
||||||
if let Ok(f) = lua.named_registry_value::<LuaFunction>("__forward__value") {
|
if let Ok(f) = lua.named_registry_value::<LuaFunction>(FORWARD_VALUE_KEY) {
|
||||||
f.call::<_, ()>((thread, value)).unwrap();
|
f.call::<_, ()>((thread, value)).unwrap();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn forward_error(lua: &Lua, thread: LuaThread, error: LuaError) {
|
||||||
|
if let Ok(f) = lua.named_registry_value::<LuaFunction>(FORWARD_ERROR_KEY) {
|
||||||
|
f.call::<_, ()>((thread, error)).unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,6 +13,9 @@ use super::{
|
||||||
callbacks::Callbacks, storage::ThreadWithArgs, traits::IntoLuaThread, util::LuaThreadOrFunction,
|
callbacks::Callbacks, storage::ThreadWithArgs, traits::IntoLuaThread, util::LuaThreadOrFunction,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const GLOBAL_NAME_SPAWN: &str = "__runtime__spawn";
|
||||||
|
const GLOBAL_NAME_DEFER: &str = "__runtime__defer";
|
||||||
|
|
||||||
pub struct Runtime {
|
pub struct Runtime {
|
||||||
queue_status: Rc<Cell<bool>>,
|
queue_status: Rc<Cell<bool>>,
|
||||||
queue_spawn: Rc<Mutex<Vec<ThreadWithArgs>>>,
|
queue_spawn: Rc<Mutex<Vec<ThreadWithArgs>>>,
|
||||||
|
@ -94,10 +97,9 @@ impl Runtime {
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
// FUTURE: Store these as named registry values instead
|
// Store them both as globals
|
||||||
// so that they are not accessible from within user code
|
lua.globals().set(GLOBAL_NAME_SPAWN, fn_spawn)?;
|
||||||
lua.globals().set("spawn", fn_spawn)?;
|
lua.globals().set(GLOBAL_NAME_DEFER, fn_defer)?;
|
||||||
lua.globals().set("defer", fn_defer)?;
|
|
||||||
|
|
||||||
Ok(Runtime {
|
Ok(Runtime {
|
||||||
queue_status,
|
queue_status,
|
||||||
|
@ -196,7 +198,7 @@ impl Runtime {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
main_exec.run(fut).await
|
main_exec.run(fut).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
@ -1,19 +0,0 @@
|
||||||
local start = os.clock()
|
|
||||||
|
|
||||||
local thread = coroutine.running()
|
|
||||||
|
|
||||||
local counter = 0
|
|
||||||
for j = 1, 10_000 do
|
|
||||||
spawn(function()
|
|
||||||
wait()
|
|
||||||
counter += 1
|
|
||||||
if counter == 10_000 then
|
|
||||||
print("completed")
|
|
||||||
spawn(thread)
|
|
||||||
end
|
|
||||||
end)
|
|
||||||
end
|
|
||||||
|
|
||||||
coroutine.yield()
|
|
||||||
|
|
||||||
return os.clock() - start
|
|
49
src/main.rs
49
src/main.rs
|
@ -1,49 +0,0 @@
|
||||||
use std::time::{Duration, Instant};
|
|
||||||
|
|
||||||
use smol_mlua::{mlua::prelude::*, smol::*, Callbacks, Runtime};
|
|
||||||
|
|
||||||
const MAIN_SCRIPT: &str = include_str!("./main.luau");
|
|
||||||
|
|
||||||
pub fn main() -> LuaResult<()> {
|
|
||||||
let start = Instant::now();
|
|
||||||
let lua = Lua::new();
|
|
||||||
|
|
||||||
// Set up persistent lua environment
|
|
||||||
lua.globals().set(
|
|
||||||
"wait",
|
|
||||||
lua.create_async_function(|_, duration: Option<f64>| async move {
|
|
||||||
let duration = duration.unwrap_or_default().max(1.0 / 250.0);
|
|
||||||
let before = Instant::now();
|
|
||||||
let after = Timer::after(Duration::from_secs_f64(duration)).await;
|
|
||||||
Ok((after - before).as_secs_f64())
|
|
||||||
})?,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
// Set up runtime (thread queue / async executors)
|
|
||||||
let rt = Runtime::new(&lua)?;
|
|
||||||
let main = rt.push_main(&lua, lua.load(MAIN_SCRIPT), ());
|
|
||||||
lua.set_named_registry_value("main", main)?;
|
|
||||||
|
|
||||||
// Add callbacks to capture resulting value/error of main thread
|
|
||||||
Callbacks::new()
|
|
||||||
.on_value(|lua, thread, val| {
|
|
||||||
let main = lua.named_registry_value::<LuaThread>("main").unwrap();
|
|
||||||
if main == thread {
|
|
||||||
println!("main thread value: {:?}", val);
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.on_error(|lua, thread, err| {
|
|
||||||
let main = lua.named_registry_value::<LuaThread>("main").unwrap();
|
|
||||||
if main == thread {
|
|
||||||
eprintln!("main thread error: {:?}", err);
|
|
||||||
}
|
|
||||||
})
|
|
||||||
.inject(&lua);
|
|
||||||
|
|
||||||
// Run until end
|
|
||||||
rt.run_blocking(&lua);
|
|
||||||
|
|
||||||
println!("elapsed: {:?}", start.elapsed());
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
Loading…
Add table
Reference in a new issue