diff --git a/src/lua.rs b/src/lua.rs index 2bda0ad..c9dff0e 100644 --- a/src/lua.rs +++ b/src/lua.rs @@ -1,56 +1,39 @@ -use std::time::Duration; - use mlua::prelude::*; -use tokio::time::Instant; use crate::{Message, MessageSender, ThreadId}; -pub fn create_lua(tx: MessageSender) -> LuaResult { +pub fn create_lua(lua_tx: MessageSender, async_tx: MessageSender) -> LuaResult { let lua = Lua::new(); lua.enable_jit(true); - lua.set_app_data(tx.clone()); - - // Resumption - let tx_resume = tx.clone(); - lua.globals().set( - "__scheduler__resumeAfter", - LuaFunction::wrap(move |lua, duration: f64| { - let thread_id = ThreadId::from(lua.current_thread()); - let yielded_at = Instant::now(); - let duration = Duration::from_secs_f64(duration); - tx_resume - .send(Message::Sleep(thread_id, yielded_at, duration)) - .into_lua_err() - }), - )?; + lua.set_app_data(async_tx.clone()); // Cancellation - let tx_cancel = tx.clone(); + let cancel_tx = lua_tx.clone(); lua.globals().set( "__scheduler__cancel", LuaFunction::wrap(move |_, thread: LuaThread| { let thread_id = ThreadId::from(thread); - tx_cancel.send(Message::Cancel(thread_id)).into_lua_err() + cancel_tx.send(Message::Cancel(thread_id)).into_lua_err() }), )?; // Stdout - let tx_stdout = tx.clone(); + let stdout_tx = async_tx.clone(); lua.globals().set( "__scheduler__writeStdout", LuaFunction::wrap(move |_, s: LuaString| { let bytes = s.as_bytes().to_vec(); - tx_stdout.send(Message::WriteStdout(bytes)).into_lua_err() + stdout_tx.send(Message::WriteStdout(bytes)).into_lua_err() }), )?; // Stderr - let tx_stderr = tx.clone(); + let stderr_tx = async_tx.clone(); lua.globals().set( "__scheduler__writeStderr", LuaFunction::wrap(move |_, s: LuaString| { let bytes = s.as_bytes().to_vec(); - tx_stderr.send(Message::WriteStderr(bytes)).into_lua_err() + stderr_tx.send(Message::WriteStderr(bytes)).into_lua_err() }), )?; diff --git a/src/main.rs b/src/main.rs index 9d4a8cb..8064f54 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,3 +1,5 @@ +use std::time::Duration; + use gxhash::GxHashMap; use mlua::prelude::*; use tokio::{ @@ -24,6 +26,8 @@ use stats::*; use thread_id::*; use value::*; +use crate::lua_ext::LuaAsyncExt; + const NUM_TEST_BATCHES: usize = 20; const NUM_TEST_THREADS: usize = 50_000; @@ -31,26 +35,21 @@ const MAIN_CHUNK: &str = r#" wait(0.01 * math.random()) "#; -const WAIT_IMPL: &str = r#" -__scheduler__resumeAfter(...) -return coroutine.yield() -"#; - fn main() { let rt = TokioRuntime::new().unwrap(); let set = LocalSet::new(); let _guard = set.enter(); - let (msg_tx, lua_rx) = unbounded_channel::(); - let (lua_tx, msg_rx) = unbounded_channel::(); + let (async_tx, lua_rx) = unbounded_channel::(); + let (lua_tx, async_rx) = unbounded_channel::(); let stats = Stats::new(); let stats_inner = stats.clone(); set.block_on(&rt, async { let res = select! { - r = spawn(main_async_task(msg_rx, msg_tx, stats_inner.clone())) => r, - r = spawn_blocking(|| main_lua_task(lua_rx, lua_tx, stats_inner)) => r, + r = spawn(main_async_task(async_rx, stats_inner.clone())) => r, + r = spawn_blocking(move || main_lua_task(lua_rx, lua_tx, async_tx, stats_inner)) => r, }; if let Err(e) = res { eprintln!("Runtime fatal error: {e}"); @@ -61,8 +60,13 @@ fn main() { println!("Thread counters: {:#?}", stats.counters); } -fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> LuaResult<()> { - let lua = create_lua(tx.clone())?; +fn main_lua_task( + mut lua_rx: MessageReceiver, + lua_tx: MessageSender, + async_tx: MessageSender, + stats: Stats, +) -> LuaResult<()> { + let lua = create_lua(lua_tx.clone(), async_tx.clone())?; let error_storage = ErrorStorage::new(); let error_storage_interrupt = error_storage.clone(); @@ -71,8 +75,14 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu None => Ok(LuaVmState::Continue), }); - lua.globals() - .set("wait", lua.load(WAIT_IMPL).into_function()?)?; + lua.globals().set( + "wait", + lua.create_async_function(|_, duration: f64| async move { + let before = Instant::now(); + sleep(Duration::from_secs_f64(duration)).await; + Ok(Instant::now() - before) + })?, + )?; let mut yielded_threads = GxHashMap::default(); let mut runnable_threads = GxHashMap::default(); @@ -103,8 +113,10 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu } }; if let Err(e) = thread.resume::<_, ()>(args) { - tx.send(Message::WriteError(e)).unwrap(); + stats.incr(StatsCounter::ThreadErrored); + async_tx.send(Message::WriteError(e)).unwrap(); } else if thread.status() == LuaThreadStatus::Resumable { + stats.incr(StatsCounter::ThreadYielded); yielded_threads.insert(thread_id, thread); } } @@ -126,13 +138,13 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu runnable_threads.remove(&thread_id); stats.incr(StatsCounter::ThreadCancelled); } - _ => unreachable!(), + m => unreachable!("got non-lua message: {m:?}"), }; // Wait for at least one message, but try to receive as many as possible - if let Some(message) = rx.blocking_recv() { + if let Some(message) = lua_rx.blocking_recv() { process_message(message); - while let Ok(message) = rx.try_recv() { + while let Ok(message) = lua_rx.try_recv() { process_message(message); } } else { @@ -144,59 +156,39 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu Ok(()) } -async fn main_async_task( - mut rx: MessageReceiver, - tx: MessageSender, - stats: Stats, -) -> LuaResult<()> { +async fn main_async_task(mut async_rx: MessageReceiver, stats: Stats) -> LuaResult<()> { // Give stdio its own task, we don't need it to block the scheduler - let (tx_stdout, rx_stdout) = unbounded_channel(); - let (tx_stderr, rx_stderr) = unbounded_channel(); - let forward_stdout = |data| tx_stdout.send(data).ok(); - let forward_stderr = |data| tx_stderr.send(data).ok(); + let (stdout_tx, stdout_rx) = unbounded_channel(); + let (stderr_tx, stderr_rx) = unbounded_channel(); + let forward_stdout = |data| stdout_tx.send(data).ok(); + let forward_stderr = |data| stderr_tx.send(data).ok(); spawn(async move { - if let Err(e) = async_stdio_task(rx_stdout, rx_stderr).await { + if let Err(e) = async_stdio_task(stdout_rx, stderr_rx).await { eprintln!("Stdio fatal error: {e}"); } }); // Set up message processor - let process_message = |message| { - match message { - Message::Sleep(_, _, _) => stats.incr(StatsCounter::ThreadSlept), - Message::WriteError(_) => stats.incr(StatsCounter::ThreadErrored), - Message::WriteStdout(_) => stats.incr(StatsCounter::WriteStdout), - Message::WriteStderr(_) => stats.incr(StatsCounter::WriteStderr), - _ => unreachable!(), + let process_message = |message| match message { + Message::WriteError(e) => { + forward_stderr(b"Lua error: ".to_vec()); + forward_stderr(e.to_string().as_bytes().to_vec()); } - - match message { - Message::Sleep(thread_id, yielded_at, duration) => { - let tx = tx.clone(); - spawn(async move { - sleep(duration).await; - let elapsed = Instant::now() - yielded_at; - tx.send(Message::Resume(thread_id, Ok(AsyncValues::from(elapsed)))) - }); - } - Message::WriteError(e) => { - forward_stderr(b"Lua error: ".to_vec()); - forward_stderr(e.to_string().as_bytes().to_vec()); - } - Message::WriteStdout(data) => { - forward_stdout(data); - } - Message::WriteStderr(data) => { - forward_stderr(data); - } - _ => unreachable!(), + Message::WriteStdout(data) => { + forward_stdout(data); + stats.incr(StatsCounter::WriteStdout); } + Message::WriteStderr(data) => { + forward_stderr(data); + stats.incr(StatsCounter::WriteStderr); + } + _ => unreachable!(), }; // Wait for at least one message, but try to receive as many as possible - while let Some(message) = rx.recv().await { + while let Some(message) = async_rx.recv().await { process_message(message); - while let Ok(message) = rx.try_recv() { + while let Ok(message) = async_rx.try_recv() { process_message(message); } } @@ -205,22 +197,22 @@ async fn main_async_task( } async fn async_stdio_task( - mut rx_stdout: UnboundedReceiver>, - mut rx_stderr: UnboundedReceiver>, + mut stdout_rx: UnboundedReceiver>, + mut stderr_rx: UnboundedReceiver>, ) -> LuaResult<()> { let mut stdout = io::stdout(); let mut stderr = io::stderr(); loop { select! { - data = rx_stdout.recv() => match data { + data = stdout_rx.recv() => match data { None => break, // Main task exited Some(data) => { stdout.write_all(&data).await?; stdout.flush().await?; } }, - data = rx_stderr.recv() => match data { + data = stderr_rx.recv() => match data { None => break, // Main task exited Some(data) => { stderr.write_all(&data).await?; diff --git a/src/message.rs b/src/message.rs index 6157331..674becb 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,20 +1,15 @@ -use std::time::Duration; - use mlua::prelude::*; -use tokio::{ - sync::mpsc::{UnboundedReceiver, UnboundedSender}, - time::Instant, -}; +use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender}; use crate::{AsyncValues, ThreadId}; pub type MessageSender = UnboundedSender; pub type MessageReceiver = UnboundedReceiver; +#[derive(Debug)] pub enum Message { Resume(ThreadId, LuaResult), Cancel(ThreadId), - Sleep(ThreadId, Instant, Duration), WriteError(LuaError), WriteStdout(Vec), WriteStderr(Vec), diff --git a/src/stats.rs b/src/stats.rs index fb1d97f..55eaeca 100644 --- a/src/stats.rs +++ b/src/stats.rs @@ -6,8 +6,8 @@ use tokio::time::Instant; #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] pub enum StatsCounter { ThreadResumed, + ThreadYielded, ThreadCancelled, - ThreadSlept, ThreadErrored, WriteStdout, WriteStderr,