From ec135b8a39434104c3d97e92f5a6c815d8af203c Mon Sep 17 00:00:00 2001 From: Filip Tibell Date: Wed, 17 Jan 2024 10:00:46 +0100 Subject: [PATCH] Implement stdio --- src/main.rs | 101 +++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 77 insertions(+), 24 deletions(-) diff --git a/src/main.rs b/src/main.rs index f39766e..609e031 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,6 +4,7 @@ use dashmap::DashMap; use gxhash::GxHashMap; use mlua::prelude::*; use tokio::{ + io::{self, AsyncWriteExt}, runtime::Runtime as TokioRuntime, select, spawn, sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, @@ -36,14 +37,18 @@ enum Message { Cancel(ThreadId), Sleep(ThreadId, Duration), Error(ThreadId, LuaError), + WriteStdout(Vec), + WriteStderr(Vec), } #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] enum StatsCounter { - Resumed, - Cancelled, - Slept, - Errored, + ThreadResumed, + ThreadCancelled, + ThreadSlept, + ThreadErrored, + WriteStdout, + WriteStderr, } #[derive(Debug, Clone)] @@ -130,6 +135,22 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu }), )?; + g.set( + "__scheduler__writeStdout", + LuaFunction::wrap(move |lua, data: Vec| { + send_message(lua, Message::WriteStdout(data)); + Ok(()) + }), + )?; + + g.set( + "__scheduler__writeStderr", + LuaFunction::wrap(move |lua, data: Vec| { + send_message(lua, Message::WriteStderr(data)); + Ok(()) + }), + )?; + g.set("wait", lua.load(WAIT_IMPL).into_function()?)?; let mut yielded_threads = ThreadMap::default(); @@ -151,9 +172,8 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu // Resume as many threads as possible for (thread_id, thread) in runnable_threads.drain() { - stats.incr(StatsCounter::Resumed); + stats.incr(StatsCounter::ThreadResumed); if let Err(e) = thread.resume::<_, ()>(()) { - stats.incr(StatsCounter::Errored); send_message(&lua, Message::Error(thread_id, e)); } if thread.status() == LuaThreadStatus::Resumable { @@ -175,7 +195,7 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu Message::Cancel(thread_id) => { yielded_threads.remove(&thread_id); runnable_threads.remove(&thread_id); - stats.incr(StatsCounter::Cancelled); + stats.incr(StatsCounter::ThreadCancelled); } _ => unreachable!(), }; @@ -198,27 +218,60 @@ async fn main_async_task( tx: MessageSender, stats: Stats, ) -> LuaResult<()> { - // Set up message processor - let process_message = |message| match message { - Message::Sleep(thread_id, duration) => { - stats.incr(StatsCounter::Slept); - let tx = tx.clone(); - spawn(async move { - sleep(duration).await; - let _ = tx.send(Message::Resume(thread_id)); - }); - } - Message::Error(_, e) => { - eprintln!("Lua error: {e}"); - } - _ => unreachable!(), - }; + let mut handle_stdout = io::stdout(); + let mut handle_stderr = io::stderr(); // Wait for at least one message, but try to receive as many as possible + let mut messages = Vec::new(); while let Some(message) = rx.recv().await { - process_message(message); + messages.push(message); while let Ok(message) = rx.try_recv() { - process_message(message); + messages.push(message); + } + + // Handle all messages + let mut wrote_stdout = false; + let mut wrote_stderr = false; + for message in messages.drain(..) { + match message { + Message::Sleep(_, _) => stats.incr(StatsCounter::ThreadSlept), + Message::Error(_, _) => stats.incr(StatsCounter::ThreadErrored), + Message::WriteStdout(_) => stats.incr(StatsCounter::WriteStdout), + Message::WriteStderr(_) => stats.incr(StatsCounter::WriteStderr), + _ => unreachable!(), + } + + match message { + Message::Sleep(thread_id, duration) => { + let tx = tx.clone(); + spawn(async move { + sleep(duration).await; + tx.send(Message::Resume(thread_id)) + }); + } + Message::Error(_, e) => { + wrote_stderr = true; + handle_stderr.write_all(b"Lua error: ").await?; + handle_stderr.write_all(e.to_string().as_bytes()).await?; + } + Message::WriteStdout(data) => { + wrote_stdout = true; + handle_stdout.write_all(&data).await?; + } + Message::WriteStderr(data) => { + wrote_stderr = true; + handle_stderr.write_all(&data).await?; + } + _ => unreachable!(), + } + } + + // Flush streams if we wrote anything to them + if wrote_stdout { + handle_stdout.flush().await?; + } + if wrote_stderr { + handle_stderr.flush().await?; } }