Implement thread callbacks, make a lib

This commit is contained in:
Filip Tibell 2024-01-19 11:14:00 +01:00
parent d3e0d5f8c2
commit c6c4c2fd40
No known key found for this signature in database
9 changed files with 140 additions and 51 deletions

16
Cargo.lock generated
View file

@ -315,14 +315,6 @@ version = "0.4.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c"
[[package]]
name = "luau-scheduler-experiments"
version = "0.0.0"
dependencies = [
"mlua",
"smol",
]
[[package]]
name = "luau0-src"
version = "0.7.11+luau606"
@ -521,6 +513,14 @@ dependencies = [
"futures-lite",
]
[[package]]
name = "smol-mlua"
version = "0.0.0"
dependencies = [
"mlua",
"smol",
]
[[package]]
name = "syn"
version = "2.0.48"

View file

@ -1,8 +1,11 @@
[package]
name = "luau-scheduler-experiments"
name = "smol-mlua"
version = "0.0.0"
edition = "2021"
[dependencies]
smol = "2.0"
mlua = { version = "0.9", features = ["luau", "luau-jit", "async"] }
[lib]
path = "lib/lib.rs"

7
lib/lib.rs Normal file
View file

@ -0,0 +1,7 @@
pub mod thread_callbacks;
pub mod thread_runtime;
pub mod thread_storage;
pub mod thread_util;
pub use mlua;
pub use smol;

71
lib/thread_callbacks.rs Normal file
View file

@ -0,0 +1,71 @@
use mlua::prelude::*;
type ErrorCallback = Box<dyn for<'lua> Fn(&'lua Lua, LuaThread<'lua>, LuaError) + 'static>;
type ValueCallback = Box<dyn for<'lua> Fn(&'lua Lua, LuaThread<'lua>, LuaValue<'lua>) + 'static>;
#[derive(Default)]
pub struct ThreadCallbacks {
on_error: Option<ErrorCallback>,
on_value: Option<ValueCallback>,
}
impl ThreadCallbacks {
pub fn new() -> ThreadCallbacks {
Default::default()
}
pub fn on_error<F>(mut self, f: F) -> Self
where
F: Fn(&Lua, LuaThread, LuaError) + 'static,
{
self.on_error.replace(Box::new(f));
self
}
pub fn on_value<F>(mut self, f: F) -> Self
where
F: Fn(&Lua, LuaThread, LuaValue) + 'static,
{
self.on_value.replace(Box::new(f));
self
}
pub fn inject(self, lua: &Lua) {
// Create functions to forward errors & values
if let Some(f) = self.on_error {
lua.set_named_registry_value(
"__forward__error",
lua.create_function(move |lua, (thread, err): (LuaThread, LuaError)| {
f(lua, thread, err);
Ok(())
})
.expect("failed to create error callback function"),
)
.expect("failed to store error callback function");
}
if let Some(f) = self.on_value {
lua.set_named_registry_value(
"__forward__value",
lua.create_function(move |lua, (thread, val): (LuaThread, LuaValue)| {
f(lua, thread, val);
Ok(())
})
.expect("failed to create value callback function"),
)
.expect("failed to store value callback function");
}
}
pub(crate) fn forward_error(lua: &Lua, thread: LuaThread, error: LuaError) {
if let Ok(f) = lua.named_registry_value::<LuaFunction>("__forward__error") {
f.call::<_, ()>((thread, error)).unwrap();
}
}
pub(crate) fn forward_value(lua: &Lua, thread: LuaThread, value: LuaValue) {
if let Ok(f) = lua.named_registry_value::<LuaFunction>("__forward__value") {
f.call::<_, ()>((thread, value)).unwrap();
}
}
}

View file

@ -10,8 +10,9 @@ use smol::{
};
use super::{
thread_callbacks::ThreadCallbacks,
thread_storage::ThreadWithArgs,
thread_util::{IntoLuaThread, LuaThreadOrFunction},
ThreadWithArgs,
};
pub struct ThreadRuntime {
@ -56,10 +57,6 @@ impl ThreadRuntime {
// and only if we get the pending value back we can spawn to async executor
let pending: LuaValue = lua.registry_value(&pending_key)?;
match thread.resume::<_, LuaValue>(args.clone()) {
Err(e) => {
eprintln!("{:?}", e);
// TODO: Forward error
}
Ok(v) if v == pending => {
let stored = ThreadWithArgs::new(lua, thread.clone(), args);
q_spawn.lock_blocking().push(stored);
@ -68,9 +65,8 @@ impl ThreadRuntime {
LuaError::runtime("Tried to spawn thread to a dropped queue")
})?;
}
Ok(_) => {
// TODO: Forward value
}
Ok(v) => ThreadCallbacks::forward_value(lua, thread.clone(), v),
Err(e) => ThreadCallbacks::forward_error(lua, thread.clone(), e),
}
Ok(thread)
} else {
@ -122,17 +118,19 @@ impl ThreadRuntime {
lua: &'lua Lua,
thread: impl IntoLuaThread<'lua>,
args: impl IntoLuaMulti<'lua>,
) {
) -> LuaThread<'lua> {
let thread = thread
.into_lua_thread(lua)
.expect("failed to create thread");
let args = args.into_lua_multi(lua).expect("failed to create args");
let stored = ThreadWithArgs::new(lua, thread, args);
let stored = ThreadWithArgs::new(lua, thread.clone(), args);
self.queue_spawn.lock_blocking().push(stored);
self.queue_status.replace(true);
self.tx.try_send(()).unwrap(); // Unwrap is safe since this struct also holds the receiver
thread
}
/**
@ -177,21 +175,16 @@ impl ThreadRuntime {
// before we got here, so we need to check it again
let (thread, args) = queued_thread.into_inner(lua);
if thread.status() == LuaThreadStatus::Resumable {
let mut stream = thread.into_async::<_, LuaValue>(args);
let mut stream = thread.clone().into_async::<_, LuaValue>(args);
lua_exec
.spawn(async move {
// Only run stream until first coroutine.yield or completion. We will
// drop it right away to clear stack space since detached tasks dont drop
// until the executor drops https://github.com/smol-rs/smol/issues/294
match stream.next().await.unwrap() {
Err(e) => {
eprintln!("{e}");
// TODO: Forward error
}
Ok(_) => {
// TODO: Forward value
}
}
Ok(v) => ThreadCallbacks::forward_value(lua, thread, v),
Err(e) => ThreadCallbacks::forward_error(lua, thread, e),
};
})
.detach();
}

View file

@ -1,18 +1,19 @@
for i = 1, 20 do
print("iteration " .. tostring(i) .. " of 20")
local thread = coroutine.running()
local start = os.clock()
local counter = 0
for j = 1, 50_000 do
spawn(function()
wait()
counter += 1
if counter == 50_000 then
print("completed iteration " .. tostring(i) .. " of 20")
spawn(thread)
end
end)
end
local thread = coroutine.running()
coroutine.yield()
local counter = 0
for j = 1, 10_000 do
spawn(function()
wait()
counter += 1
if counter == 10_000 then
print("completed")
spawn(thread)
end
end)
end
coroutine.yield()
return os.clock() - start

View file

@ -5,12 +5,7 @@ use smol::*;
const MAIN_SCRIPT: &str = include_str!("./main.luau");
mod thread_runtime;
mod thread_storage;
mod thread_util;
use thread_runtime::*;
use thread_storage::*;
use smol_mlua::{thread_callbacks::ThreadCallbacks, thread_runtime::ThreadRuntime};
pub fn main() -> LuaResult<()> {
let start = Instant::now();
@ -27,9 +22,28 @@ pub fn main() -> LuaResult<()> {
})?,
)?;
// Set up runtime (thread queue / async executors) and run main script until end
// Set up runtime (thread queue / async executors)
let rt = ThreadRuntime::new(&lua)?;
rt.push_main(&lua, lua.load(MAIN_SCRIPT), ());
let main = rt.push_main(&lua, lua.load(MAIN_SCRIPT), ());
lua.set_named_registry_value("main", main)?;
// Add callbacks to capture resulting value/error of main thread
ThreadCallbacks::new()
.on_value(|lua, thread, val| {
let main = lua.named_registry_value::<LuaThread>("main").unwrap();
if main == thread {
println!("main thread value: {:?}", val);
}
})
.on_error(|lua, thread, err| {
let main = lua.named_registry_value::<LuaThread>("main").unwrap();
if main == thread {
eprintln!("main thread error: {:?}", err);
}
})
.inject(&lua);
// Run until end
rt.run_blocking(&lua);
println!("elapsed: {:?}", start.elapsed());