mirror of
https://github.com/lune-org/mlua-luau-scheduler.git
synced 2025-04-04 10:30:56 +01:00
Implement thread callbacks, make a lib
This commit is contained in:
parent
d3e0d5f8c2
commit
c6c4c2fd40
9 changed files with 140 additions and 51 deletions
16
Cargo.lock
generated
16
Cargo.lock
generated
|
@ -315,14 +315,6 @@ version = "0.4.13"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
|
||||
|
||||
[[package]]
|
||||
name = "luau-scheduler-experiments"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"mlua",
|
||||
"smol",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "luau0-src"
|
||||
version = "0.7.11+luau606"
|
||||
|
@ -521,6 +513,14 @@ dependencies = [
|
|||
"futures-lite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "smol-mlua"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"mlua",
|
||||
"smol",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.48"
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
[package]
|
||||
name = "luau-scheduler-experiments"
|
||||
name = "smol-mlua"
|
||||
version = "0.0.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
smol = "2.0"
|
||||
mlua = { version = "0.9", features = ["luau", "luau-jit", "async"] }
|
||||
|
||||
[lib]
|
||||
path = "lib/lib.rs"
|
||||
|
|
7
lib/lib.rs
Normal file
7
lib/lib.rs
Normal file
|
@ -0,0 +1,7 @@
|
|||
pub mod thread_callbacks;
|
||||
pub mod thread_runtime;
|
||||
pub mod thread_storage;
|
||||
pub mod thread_util;
|
||||
|
||||
pub use mlua;
|
||||
pub use smol;
|
71
lib/thread_callbacks.rs
Normal file
71
lib/thread_callbacks.rs
Normal file
|
@ -0,0 +1,71 @@
|
|||
use mlua::prelude::*;
|
||||
|
||||
type ErrorCallback = Box<dyn for<'lua> Fn(&'lua Lua, LuaThread<'lua>, LuaError) + 'static>;
|
||||
type ValueCallback = Box<dyn for<'lua> Fn(&'lua Lua, LuaThread<'lua>, LuaValue<'lua>) + 'static>;
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct ThreadCallbacks {
|
||||
on_error: Option<ErrorCallback>,
|
||||
on_value: Option<ValueCallback>,
|
||||
}
|
||||
|
||||
impl ThreadCallbacks {
|
||||
pub fn new() -> ThreadCallbacks {
|
||||
Default::default()
|
||||
}
|
||||
|
||||
pub fn on_error<F>(mut self, f: F) -> Self
|
||||
where
|
||||
F: Fn(&Lua, LuaThread, LuaError) + 'static,
|
||||
{
|
||||
self.on_error.replace(Box::new(f));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn on_value<F>(mut self, f: F) -> Self
|
||||
where
|
||||
F: Fn(&Lua, LuaThread, LuaValue) + 'static,
|
||||
{
|
||||
self.on_value.replace(Box::new(f));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn inject(self, lua: &Lua) {
|
||||
// Create functions to forward errors & values
|
||||
if let Some(f) = self.on_error {
|
||||
lua.set_named_registry_value(
|
||||
"__forward__error",
|
||||
lua.create_function(move |lua, (thread, err): (LuaThread, LuaError)| {
|
||||
f(lua, thread, err);
|
||||
Ok(())
|
||||
})
|
||||
.expect("failed to create error callback function"),
|
||||
)
|
||||
.expect("failed to store error callback function");
|
||||
}
|
||||
|
||||
if let Some(f) = self.on_value {
|
||||
lua.set_named_registry_value(
|
||||
"__forward__value",
|
||||
lua.create_function(move |lua, (thread, val): (LuaThread, LuaValue)| {
|
||||
f(lua, thread, val);
|
||||
Ok(())
|
||||
})
|
||||
.expect("failed to create value callback function"),
|
||||
)
|
||||
.expect("failed to store value callback function");
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn forward_error(lua: &Lua, thread: LuaThread, error: LuaError) {
|
||||
if let Ok(f) = lua.named_registry_value::<LuaFunction>("__forward__error") {
|
||||
f.call::<_, ()>((thread, error)).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn forward_value(lua: &Lua, thread: LuaThread, value: LuaValue) {
|
||||
if let Ok(f) = lua.named_registry_value::<LuaFunction>("__forward__value") {
|
||||
f.call::<_, ()>((thread, value)).unwrap();
|
||||
}
|
||||
}
|
||||
}
|
|
@ -10,8 +10,9 @@ use smol::{
|
|||
};
|
||||
|
||||
use super::{
|
||||
thread_callbacks::ThreadCallbacks,
|
||||
thread_storage::ThreadWithArgs,
|
||||
thread_util::{IntoLuaThread, LuaThreadOrFunction},
|
||||
ThreadWithArgs,
|
||||
};
|
||||
|
||||
pub struct ThreadRuntime {
|
||||
|
@ -56,10 +57,6 @@ impl ThreadRuntime {
|
|||
// and only if we get the pending value back we can spawn to async executor
|
||||
let pending: LuaValue = lua.registry_value(&pending_key)?;
|
||||
match thread.resume::<_, LuaValue>(args.clone()) {
|
||||
Err(e) => {
|
||||
eprintln!("{:?}", e);
|
||||
// TODO: Forward error
|
||||
}
|
||||
Ok(v) if v == pending => {
|
||||
let stored = ThreadWithArgs::new(lua, thread.clone(), args);
|
||||
q_spawn.lock_blocking().push(stored);
|
||||
|
@ -68,9 +65,8 @@ impl ThreadRuntime {
|
|||
LuaError::runtime("Tried to spawn thread to a dropped queue")
|
||||
})?;
|
||||
}
|
||||
Ok(_) => {
|
||||
// TODO: Forward value
|
||||
}
|
||||
Ok(v) => ThreadCallbacks::forward_value(lua, thread.clone(), v),
|
||||
Err(e) => ThreadCallbacks::forward_error(lua, thread.clone(), e),
|
||||
}
|
||||
Ok(thread)
|
||||
} else {
|
||||
|
@ -122,17 +118,19 @@ impl ThreadRuntime {
|
|||
lua: &'lua Lua,
|
||||
thread: impl IntoLuaThread<'lua>,
|
||||
args: impl IntoLuaMulti<'lua>,
|
||||
) {
|
||||
) -> LuaThread<'lua> {
|
||||
let thread = thread
|
||||
.into_lua_thread(lua)
|
||||
.expect("failed to create thread");
|
||||
let args = args.into_lua_multi(lua).expect("failed to create args");
|
||||
|
||||
let stored = ThreadWithArgs::new(lua, thread, args);
|
||||
let stored = ThreadWithArgs::new(lua, thread.clone(), args);
|
||||
|
||||
self.queue_spawn.lock_blocking().push(stored);
|
||||
self.queue_status.replace(true);
|
||||
self.tx.try_send(()).unwrap(); // Unwrap is safe since this struct also holds the receiver
|
||||
|
||||
thread
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -177,21 +175,16 @@ impl ThreadRuntime {
|
|||
// before we got here, so we need to check it again
|
||||
let (thread, args) = queued_thread.into_inner(lua);
|
||||
if thread.status() == LuaThreadStatus::Resumable {
|
||||
let mut stream = thread.into_async::<_, LuaValue>(args);
|
||||
let mut stream = thread.clone().into_async::<_, LuaValue>(args);
|
||||
lua_exec
|
||||
.spawn(async move {
|
||||
// Only run stream until first coroutine.yield or completion. We will
|
||||
// drop it right away to clear stack space since detached tasks dont drop
|
||||
// until the executor drops https://github.com/smol-rs/smol/issues/294
|
||||
match stream.next().await.unwrap() {
|
||||
Err(e) => {
|
||||
eprintln!("{e}");
|
||||
// TODO: Forward error
|
||||
}
|
||||
Ok(_) => {
|
||||
// TODO: Forward value
|
||||
}
|
||||
}
|
||||
Ok(v) => ThreadCallbacks::forward_value(lua, thread, v),
|
||||
Err(e) => ThreadCallbacks::forward_error(lua, thread, e),
|
||||
};
|
||||
})
|
||||
.detach();
|
||||
}
|
|
@ -1,18 +1,19 @@
|
|||
for i = 1, 20 do
|
||||
print("iteration " .. tostring(i) .. " of 20")
|
||||
local thread = coroutine.running()
|
||||
local start = os.clock()
|
||||
|
||||
local counter = 0
|
||||
for j = 1, 50_000 do
|
||||
spawn(function()
|
||||
wait()
|
||||
counter += 1
|
||||
if counter == 50_000 then
|
||||
print("completed iteration " .. tostring(i) .. " of 20")
|
||||
spawn(thread)
|
||||
end
|
||||
end)
|
||||
end
|
||||
local thread = coroutine.running()
|
||||
|
||||
coroutine.yield()
|
||||
local counter = 0
|
||||
for j = 1, 10_000 do
|
||||
spawn(function()
|
||||
wait()
|
||||
counter += 1
|
||||
if counter == 10_000 then
|
||||
print("completed")
|
||||
spawn(thread)
|
||||
end
|
||||
end)
|
||||
end
|
||||
|
||||
coroutine.yield()
|
||||
|
||||
return os.clock() - start
|
||||
|
|
30
src/main.rs
30
src/main.rs
|
@ -5,12 +5,7 @@ use smol::*;
|
|||
|
||||
const MAIN_SCRIPT: &str = include_str!("./main.luau");
|
||||
|
||||
mod thread_runtime;
|
||||
mod thread_storage;
|
||||
mod thread_util;
|
||||
|
||||
use thread_runtime::*;
|
||||
use thread_storage::*;
|
||||
use smol_mlua::{thread_callbacks::ThreadCallbacks, thread_runtime::ThreadRuntime};
|
||||
|
||||
pub fn main() -> LuaResult<()> {
|
||||
let start = Instant::now();
|
||||
|
@ -27,9 +22,28 @@ pub fn main() -> LuaResult<()> {
|
|||
})?,
|
||||
)?;
|
||||
|
||||
// Set up runtime (thread queue / async executors) and run main script until end
|
||||
// Set up runtime (thread queue / async executors)
|
||||
let rt = ThreadRuntime::new(&lua)?;
|
||||
rt.push_main(&lua, lua.load(MAIN_SCRIPT), ());
|
||||
let main = rt.push_main(&lua, lua.load(MAIN_SCRIPT), ());
|
||||
lua.set_named_registry_value("main", main)?;
|
||||
|
||||
// Add callbacks to capture resulting value/error of main thread
|
||||
ThreadCallbacks::new()
|
||||
.on_value(|lua, thread, val| {
|
||||
let main = lua.named_registry_value::<LuaThread>("main").unwrap();
|
||||
if main == thread {
|
||||
println!("main thread value: {:?}", val);
|
||||
}
|
||||
})
|
||||
.on_error(|lua, thread, err| {
|
||||
let main = lua.named_registry_value::<LuaThread>("main").unwrap();
|
||||
if main == thread {
|
||||
eprintln!("main thread error: {:?}", err);
|
||||
}
|
||||
})
|
||||
.inject(&lua);
|
||||
|
||||
// Run until end
|
||||
rt.run_blocking(&lua);
|
||||
|
||||
println!("elapsed: {:?}", start.elapsed());
|
||||
|
|
Loading…
Add table
Reference in a new issue