diff --git a/Cargo.lock b/Cargo.lock index 6901c73..7481ec4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -81,12 +81,51 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "dashmap" +version = "5.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "978747c1d849a7d2ee5e8adc0159961c48fb7e5db2f06af6723b80123bb53856" +dependencies = [ + "cfg-if", + "hashbrown", + "lock_api", + "once_cell", + "parking_lot_core", +] + +[[package]] +name = "getrandom" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "190092ea657667030ac6a35e305e62fc4dd69fd98ac98631e5d3a2b1575a12b5" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "gimli" version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +[[package]] +name = "gxhash" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09f0c897148ec6ff3ca864b7c886df75e6ba09972d206bd9a89af0c18c992253" +dependencies = [ + "rand", +] + +[[package]] +name = "hashbrown" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" + [[package]] name = "hermit-abi" version = "0.3.3" @@ -124,6 +163,8 @@ name = "luau-scheduler-experiments" version = "0.0.0" dependencies = [ "anyhow", + "dashmap", + "gxhash", "mlua", "tokio", ] @@ -258,6 +299,12 @@ version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69d3587f8a9e599cc7ec2c00e331f71c4e69a5f9a4b8a6efd5b07466b9736f9a" +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + [[package]] name = "proc-macro2" version = "1.0.76" @@ -276,6 +323,36 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + [[package]] name = "redox_syscall" version = "0.4.1" diff --git a/Cargo.toml b/Cargo.toml index 567d40d..ac60e03 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,5 +5,8 @@ edition = "2021" [dependencies] anyhow = "1.0" +dashmap = "5.5" +gxhash = "2.3" + tokio = { version = "1.0", features = ["full"] } mlua = { version = "0.9", features = ["luau", "luau-jit"] } diff --git a/src/main.rs b/src/main.rs index fbbb769..f39766e 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,9 +1,8 @@ -use std::{collections::HashMap, time::Duration}; +use std::{sync::Arc, time::Duration}; +use dashmap::DashMap; +use gxhash::GxHashMap; use mlua::prelude::*; - -mod thread_id; -use thread_id::ThreadId; use tokio::{ runtime::Runtime as TokioRuntime, select, spawn, @@ -12,6 +11,9 @@ use tokio::{ time::{sleep, Instant}, }; +mod thread_id; +use thread_id::ThreadId; + const NUM_TEST_BATCHES: usize = 20; const NUM_TEST_THREADS: usize = 50_000; @@ -24,14 +26,50 @@ __scheduler__resumeAfter(...) coroutine.yield() "#; -type RuntimeSender = UnboundedSender; -type RuntimeReceiver = UnboundedReceiver; +type ThreadMap<'lua> = GxHashMap>; -#[derive(Debug, Clone, Copy)] -enum RuntimeMessage { +type MessageSender = UnboundedSender; +type MessageReceiver = UnboundedReceiver; + +enum Message { Resume(ThreadId), Cancel(ThreadId), - Yield(ThreadId, Duration), + Sleep(ThreadId, Duration), + Error(ThreadId, LuaError), +} + +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +enum StatsCounter { + Resumed, + Cancelled, + Slept, + Errored, +} + +#[derive(Debug, Clone)] +struct Stats { + start: Instant, + counters: Arc>, +} + +impl Stats { + fn new() -> Self { + Self { + start: Instant::now(), + counters: Arc::new(DashMap::new()), + } + } + + fn incr(&self, counter: StatsCounter) { + self.counters + .entry(counter) + .and_modify(|c| *c += 1) + .or_insert(1); + } + + fn elapsed(&self) -> Duration { + Instant::now() - self.start + } } fn main() { @@ -39,27 +77,35 @@ fn main() { let set = LocalSet::new(); let _guard = set.enter(); - let (msg_tx, lua_rx) = unbounded_channel::(); - let (lua_tx, msg_rx) = unbounded_channel::(); + let (msg_tx, lua_rx) = unbounded_channel::(); + let (lua_tx, msg_rx) = unbounded_channel::(); + + let stats = Stats::new(); + let stats_inner = stats.clone(); set.block_on(&rt, async { - // TODO: Handle result - let _ = select! { - r = spawn_blocking(|| lua_main(lua_rx, lua_tx)) => r, - r = spawn(sched_main(msg_rx, msg_tx)) => r, + let res = select! { + r = spawn(main_async_task(msg_rx, msg_tx, stats_inner.clone())) => r, + r = spawn_blocking(|| main_lua_task(lua_rx, lua_tx, stats_inner)) => r, }; + if let Err(e) = res { + eprintln!("Runtime fatal error: {e}"); + } }); + + println!("Finished running in {:?}", stats.elapsed()); + println!("Thread counters: {:#?}", stats.counters); } -fn lua_main(mut rx: RuntimeReceiver, tx: RuntimeSender) -> LuaResult<()> { +fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> LuaResult<()> { let lua = Lua::new(); let g = lua.globals(); lua.enable_jit(true); lua.set_app_data(tx.clone()); - let send_message = |lua: &Lua, msg: RuntimeMessage| { - lua.app_data_ref::() + let send_message = |lua: &Lua, msg: Message| { + lua.app_data_ref::() .unwrap() .send(msg) .unwrap(); @@ -70,7 +116,7 @@ fn lua_main(mut rx: RuntimeReceiver, tx: RuntimeSender) -> LuaResult<()> { LuaFunction::wrap(move |lua, duration: f64| { let thread_id = ThreadId::from(lua.current_thread()); let duration = Duration::from_secs_f64(duration); - send_message(lua, RuntimeMessage::Yield(thread_id, duration)); + send_message(lua, Message::Sleep(thread_id, duration)); Ok(()) }), )?; @@ -79,21 +125,18 @@ fn lua_main(mut rx: RuntimeReceiver, tx: RuntimeSender) -> LuaResult<()> { "__scheduler__cancel", LuaFunction::wrap(move |lua, thread: LuaThread| { let thread_id = ThreadId::from(thread); - send_message(lua, RuntimeMessage::Cancel(thread_id)); + send_message(lua, Message::Cancel(thread_id)); Ok(()) }), )?; g.set("wait", lua.load(WAIT_IMPL).into_function()?)?; - let mut yielded_threads: HashMap = HashMap::new(); - let mut runnable_threads: HashMap = HashMap::new(); - - let before = Instant::now(); - - for n in 1..=NUM_TEST_BATCHES { - println!("Running batch {n} of {NUM_TEST_BATCHES}"); + let mut yielded_threads = ThreadMap::default(); + let mut runnable_threads = ThreadMap::default(); + println!("Running {NUM_TEST_BATCHES} batches"); + for _ in 0..NUM_TEST_BATCHES { let main_fn = lua.load(MAIN_CHUNK).into_function()?; for _ in 0..NUM_TEST_THREADS { let thread = lua.create_thread(main_fn.clone())?; @@ -108,26 +151,31 @@ fn lua_main(mut rx: RuntimeReceiver, tx: RuntimeSender) -> LuaResult<()> { // Resume as many threads as possible for (thread_id, thread) in runnable_threads.drain() { - thread.resume(())?; + stats.incr(StatsCounter::Resumed); + if let Err(e) = thread.resume::<_, ()>(()) { + stats.incr(StatsCounter::Errored); + send_message(&lua, Message::Error(thread_id, e)); + } if thread.status() == LuaThreadStatus::Resumable { yielded_threads.insert(thread_id, thread); } } if yielded_threads.is_empty() { - break; // All threads ran and we don't have any async task that can spawn more + break; // All threads ran, and we don't have any async task that can spawn more } // Wait for at least one message, but try to receive as many as possible let mut process_message = |message| match message { - RuntimeMessage::Resume(thread_id) => { + Message::Resume(thread_id) => { if let Some(thread) = yielded_threads.remove(&thread_id) { runnable_threads.insert(thread_id, thread); } } - RuntimeMessage::Cancel(thread_id) => { + Message::Cancel(thread_id) => { yielded_threads.remove(&thread_id); runnable_threads.remove(&thread_id); + stats.incr(StatsCounter::Cancelled); } _ => unreachable!(), }; @@ -142,28 +190,35 @@ fn lua_main(mut rx: RuntimeReceiver, tx: RuntimeSender) -> LuaResult<()> { } } - let after = Instant::now(); - - println!( - "Ran {} threads in {:?}", - NUM_TEST_BATCHES * NUM_TEST_THREADS, - after - before - ); - Ok(()) } -async fn sched_main(mut rx: RuntimeReceiver, tx: RuntimeSender) -> LuaResult<()> { +async fn main_async_task( + mut rx: MessageReceiver, + tx: MessageSender, + stats: Stats, +) -> LuaResult<()> { + // Set up message processor + let process_message = |message| match message { + Message::Sleep(thread_id, duration) => { + stats.incr(StatsCounter::Slept); + let tx = tx.clone(); + spawn(async move { + sleep(duration).await; + let _ = tx.send(Message::Resume(thread_id)); + }); + } + Message::Error(_, e) => { + eprintln!("Lua error: {e}"); + } + _ => unreachable!(), + }; + + // Wait for at least one message, but try to receive as many as possible while let Some(message) = rx.recv().await { - match message { - RuntimeMessage::Yield(thread_id, duration) => { - let tx = tx.clone(); - spawn(async move { - sleep(duration).await; - let _ = tx.send(RuntimeMessage::Resume(thread_id)); - }); - } - _ => unreachable!(), + process_message(message); + while let Ok(message) = rx.try_recv() { + process_message(message); } }