mirror of
https://github.com/lune-org/mlua-luau-scheduler.git
synced 2025-04-19 03:13:46 +01:00
Implement thread callbacks, make a lib
This commit is contained in:
parent
d3e0d5f8c2
commit
c6c4c2fd40
9 changed files with 140 additions and 51 deletions
16
Cargo.lock
generated
16
Cargo.lock
generated
|
@ -315,14 +315,6 @@ version = "0.4.13"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
|
checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "luau-scheduler-experiments"
|
|
||||||
version = "0.0.0"
|
|
||||||
dependencies = [
|
|
||||||
"mlua",
|
|
||||||
"smol",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "luau0-src"
|
name = "luau0-src"
|
||||||
version = "0.7.11+luau606"
|
version = "0.7.11+luau606"
|
||||||
|
@ -521,6 +513,14 @@ dependencies = [
|
||||||
"futures-lite",
|
"futures-lite",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "smol-mlua"
|
||||||
|
version = "0.0.0"
|
||||||
|
dependencies = [
|
||||||
|
"mlua",
|
||||||
|
"smol",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "syn"
|
name = "syn"
|
||||||
version = "2.0.48"
|
version = "2.0.48"
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
[package]
|
[package]
|
||||||
name = "luau-scheduler-experiments"
|
name = "smol-mlua"
|
||||||
version = "0.0.0"
|
version = "0.0.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
smol = "2.0"
|
smol = "2.0"
|
||||||
mlua = { version = "0.9", features = ["luau", "luau-jit", "async"] }
|
mlua = { version = "0.9", features = ["luau", "luau-jit", "async"] }
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
path = "lib/lib.rs"
|
||||||
|
|
7
lib/lib.rs
Normal file
7
lib/lib.rs
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
pub mod thread_callbacks;
|
||||||
|
pub mod thread_runtime;
|
||||||
|
pub mod thread_storage;
|
||||||
|
pub mod thread_util;
|
||||||
|
|
||||||
|
pub use mlua;
|
||||||
|
pub use smol;
|
71
lib/thread_callbacks.rs
Normal file
71
lib/thread_callbacks.rs
Normal file
|
@ -0,0 +1,71 @@
|
||||||
|
use mlua::prelude::*;
|
||||||
|
|
||||||
|
type ErrorCallback = Box<dyn for<'lua> Fn(&'lua Lua, LuaThread<'lua>, LuaError) + 'static>;
|
||||||
|
type ValueCallback = Box<dyn for<'lua> Fn(&'lua Lua, LuaThread<'lua>, LuaValue<'lua>) + 'static>;
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct ThreadCallbacks {
|
||||||
|
on_error: Option<ErrorCallback>,
|
||||||
|
on_value: Option<ValueCallback>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ThreadCallbacks {
|
||||||
|
pub fn new() -> ThreadCallbacks {
|
||||||
|
Default::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn on_error<F>(mut self, f: F) -> Self
|
||||||
|
where
|
||||||
|
F: Fn(&Lua, LuaThread, LuaError) + 'static,
|
||||||
|
{
|
||||||
|
self.on_error.replace(Box::new(f));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn on_value<F>(mut self, f: F) -> Self
|
||||||
|
where
|
||||||
|
F: Fn(&Lua, LuaThread, LuaValue) + 'static,
|
||||||
|
{
|
||||||
|
self.on_value.replace(Box::new(f));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn inject(self, lua: &Lua) {
|
||||||
|
// Create functions to forward errors & values
|
||||||
|
if let Some(f) = self.on_error {
|
||||||
|
lua.set_named_registry_value(
|
||||||
|
"__forward__error",
|
||||||
|
lua.create_function(move |lua, (thread, err): (LuaThread, LuaError)| {
|
||||||
|
f(lua, thread, err);
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.expect("failed to create error callback function"),
|
||||||
|
)
|
||||||
|
.expect("failed to store error callback function");
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(f) = self.on_value {
|
||||||
|
lua.set_named_registry_value(
|
||||||
|
"__forward__value",
|
||||||
|
lua.create_function(move |lua, (thread, val): (LuaThread, LuaValue)| {
|
||||||
|
f(lua, thread, val);
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.expect("failed to create value callback function"),
|
||||||
|
)
|
||||||
|
.expect("failed to store value callback function");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn forward_error(lua: &Lua, thread: LuaThread, error: LuaError) {
|
||||||
|
if let Ok(f) = lua.named_registry_value::<LuaFunction>("__forward__error") {
|
||||||
|
f.call::<_, ()>((thread, error)).unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn forward_value(lua: &Lua, thread: LuaThread, value: LuaValue) {
|
||||||
|
if let Ok(f) = lua.named_registry_value::<LuaFunction>("__forward__value") {
|
||||||
|
f.call::<_, ()>((thread, value)).unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -10,8 +10,9 @@ use smol::{
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
|
thread_callbacks::ThreadCallbacks,
|
||||||
|
thread_storage::ThreadWithArgs,
|
||||||
thread_util::{IntoLuaThread, LuaThreadOrFunction},
|
thread_util::{IntoLuaThread, LuaThreadOrFunction},
|
||||||
ThreadWithArgs,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
pub struct ThreadRuntime {
|
pub struct ThreadRuntime {
|
||||||
|
@ -56,10 +57,6 @@ impl ThreadRuntime {
|
||||||
// and only if we get the pending value back we can spawn to async executor
|
// and only if we get the pending value back we can spawn to async executor
|
||||||
let pending: LuaValue = lua.registry_value(&pending_key)?;
|
let pending: LuaValue = lua.registry_value(&pending_key)?;
|
||||||
match thread.resume::<_, LuaValue>(args.clone()) {
|
match thread.resume::<_, LuaValue>(args.clone()) {
|
||||||
Err(e) => {
|
|
||||||
eprintln!("{:?}", e);
|
|
||||||
// TODO: Forward error
|
|
||||||
}
|
|
||||||
Ok(v) if v == pending => {
|
Ok(v) if v == pending => {
|
||||||
let stored = ThreadWithArgs::new(lua, thread.clone(), args);
|
let stored = ThreadWithArgs::new(lua, thread.clone(), args);
|
||||||
q_spawn.lock_blocking().push(stored);
|
q_spawn.lock_blocking().push(stored);
|
||||||
|
@ -68,9 +65,8 @@ impl ThreadRuntime {
|
||||||
LuaError::runtime("Tried to spawn thread to a dropped queue")
|
LuaError::runtime("Tried to spawn thread to a dropped queue")
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
Ok(_) => {
|
Ok(v) => ThreadCallbacks::forward_value(lua, thread.clone(), v),
|
||||||
// TODO: Forward value
|
Err(e) => ThreadCallbacks::forward_error(lua, thread.clone(), e),
|
||||||
}
|
|
||||||
}
|
}
|
||||||
Ok(thread)
|
Ok(thread)
|
||||||
} else {
|
} else {
|
||||||
|
@ -122,17 +118,19 @@ impl ThreadRuntime {
|
||||||
lua: &'lua Lua,
|
lua: &'lua Lua,
|
||||||
thread: impl IntoLuaThread<'lua>,
|
thread: impl IntoLuaThread<'lua>,
|
||||||
args: impl IntoLuaMulti<'lua>,
|
args: impl IntoLuaMulti<'lua>,
|
||||||
) {
|
) -> LuaThread<'lua> {
|
||||||
let thread = thread
|
let thread = thread
|
||||||
.into_lua_thread(lua)
|
.into_lua_thread(lua)
|
||||||
.expect("failed to create thread");
|
.expect("failed to create thread");
|
||||||
let args = args.into_lua_multi(lua).expect("failed to create args");
|
let args = args.into_lua_multi(lua).expect("failed to create args");
|
||||||
|
|
||||||
let stored = ThreadWithArgs::new(lua, thread, args);
|
let stored = ThreadWithArgs::new(lua, thread.clone(), args);
|
||||||
|
|
||||||
self.queue_spawn.lock_blocking().push(stored);
|
self.queue_spawn.lock_blocking().push(stored);
|
||||||
self.queue_status.replace(true);
|
self.queue_status.replace(true);
|
||||||
self.tx.try_send(()).unwrap(); // Unwrap is safe since this struct also holds the receiver
|
self.tx.try_send(()).unwrap(); // Unwrap is safe since this struct also holds the receiver
|
||||||
|
|
||||||
|
thread
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
@ -177,21 +175,16 @@ impl ThreadRuntime {
|
||||||
// before we got here, so we need to check it again
|
// before we got here, so we need to check it again
|
||||||
let (thread, args) = queued_thread.into_inner(lua);
|
let (thread, args) = queued_thread.into_inner(lua);
|
||||||
if thread.status() == LuaThreadStatus::Resumable {
|
if thread.status() == LuaThreadStatus::Resumable {
|
||||||
let mut stream = thread.into_async::<_, LuaValue>(args);
|
let mut stream = thread.clone().into_async::<_, LuaValue>(args);
|
||||||
lua_exec
|
lua_exec
|
||||||
.spawn(async move {
|
.spawn(async move {
|
||||||
// Only run stream until first coroutine.yield or completion. We will
|
// Only run stream until first coroutine.yield or completion. We will
|
||||||
// drop it right away to clear stack space since detached tasks dont drop
|
// drop it right away to clear stack space since detached tasks dont drop
|
||||||
// until the executor drops https://github.com/smol-rs/smol/issues/294
|
// until the executor drops https://github.com/smol-rs/smol/issues/294
|
||||||
match stream.next().await.unwrap() {
|
match stream.next().await.unwrap() {
|
||||||
Err(e) => {
|
Ok(v) => ThreadCallbacks::forward_value(lua, thread, v),
|
||||||
eprintln!("{e}");
|
Err(e) => ThreadCallbacks::forward_error(lua, thread, e),
|
||||||
// TODO: Forward error
|
};
|
||||||
}
|
|
||||||
Ok(_) => {
|
|
||||||
// TODO: Forward value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
.detach();
|
.detach();
|
||||||
}
|
}
|
|
@ -1,18 +1,19 @@
|
||||||
for i = 1, 20 do
|
local start = os.clock()
|
||||||
print("iteration " .. tostring(i) .. " of 20")
|
|
||||||
local thread = coroutine.running()
|
|
||||||
|
|
||||||
local counter = 0
|
local thread = coroutine.running()
|
||||||
for j = 1, 50_000 do
|
|
||||||
spawn(function()
|
|
||||||
wait()
|
|
||||||
counter += 1
|
|
||||||
if counter == 50_000 then
|
|
||||||
print("completed iteration " .. tostring(i) .. " of 20")
|
|
||||||
spawn(thread)
|
|
||||||
end
|
|
||||||
end)
|
|
||||||
end
|
|
||||||
|
|
||||||
coroutine.yield()
|
local counter = 0
|
||||||
|
for j = 1, 10_000 do
|
||||||
|
spawn(function()
|
||||||
|
wait()
|
||||||
|
counter += 1
|
||||||
|
if counter == 10_000 then
|
||||||
|
print("completed")
|
||||||
|
spawn(thread)
|
||||||
|
end
|
||||||
|
end)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
coroutine.yield()
|
||||||
|
|
||||||
|
return os.clock() - start
|
||||||
|
|
30
src/main.rs
30
src/main.rs
|
@ -5,12 +5,7 @@ use smol::*;
|
||||||
|
|
||||||
const MAIN_SCRIPT: &str = include_str!("./main.luau");
|
const MAIN_SCRIPT: &str = include_str!("./main.luau");
|
||||||
|
|
||||||
mod thread_runtime;
|
use smol_mlua::{thread_callbacks::ThreadCallbacks, thread_runtime::ThreadRuntime};
|
||||||
mod thread_storage;
|
|
||||||
mod thread_util;
|
|
||||||
|
|
||||||
use thread_runtime::*;
|
|
||||||
use thread_storage::*;
|
|
||||||
|
|
||||||
pub fn main() -> LuaResult<()> {
|
pub fn main() -> LuaResult<()> {
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
@ -27,9 +22,28 @@ pub fn main() -> LuaResult<()> {
|
||||||
})?,
|
})?,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
// Set up runtime (thread queue / async executors) and run main script until end
|
// Set up runtime (thread queue / async executors)
|
||||||
let rt = ThreadRuntime::new(&lua)?;
|
let rt = ThreadRuntime::new(&lua)?;
|
||||||
rt.push_main(&lua, lua.load(MAIN_SCRIPT), ());
|
let main = rt.push_main(&lua, lua.load(MAIN_SCRIPT), ());
|
||||||
|
lua.set_named_registry_value("main", main)?;
|
||||||
|
|
||||||
|
// Add callbacks to capture resulting value/error of main thread
|
||||||
|
ThreadCallbacks::new()
|
||||||
|
.on_value(|lua, thread, val| {
|
||||||
|
let main = lua.named_registry_value::<LuaThread>("main").unwrap();
|
||||||
|
if main == thread {
|
||||||
|
println!("main thread value: {:?}", val);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.on_error(|lua, thread, err| {
|
||||||
|
let main = lua.named_registry_value::<LuaThread>("main").unwrap();
|
||||||
|
if main == thread {
|
||||||
|
eprintln!("main thread error: {:?}", err);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.inject(&lua);
|
||||||
|
|
||||||
|
// Run until end
|
||||||
rt.run_blocking(&lua);
|
rt.run_blocking(&lua);
|
||||||
|
|
||||||
println!("elapsed: {:?}", start.elapsed());
|
println!("elapsed: {:?}", start.elapsed());
|
||||||
|
|
Loading…
Add table
Reference in a new issue