Implement stdio

This commit is contained in:
Filip Tibell 2024-01-17 10:00:46 +01:00
parent 919aac3043
commit ec135b8a39
No known key found for this signature in database

View file

@ -4,6 +4,7 @@ use dashmap::DashMap;
use gxhash::GxHashMap;
use mlua::prelude::*;
use tokio::{
io::{self, AsyncWriteExt},
runtime::Runtime as TokioRuntime,
select, spawn,
sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
@ -36,14 +37,18 @@ enum Message {
Cancel(ThreadId),
Sleep(ThreadId, Duration),
Error(ThreadId, LuaError),
WriteStdout(Vec<u8>),
WriteStderr(Vec<u8>),
}
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
enum StatsCounter {
Resumed,
Cancelled,
Slept,
Errored,
ThreadResumed,
ThreadCancelled,
ThreadSlept,
ThreadErrored,
WriteStdout,
WriteStderr,
}
#[derive(Debug, Clone)]
@ -130,6 +135,22 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu
}),
)?;
g.set(
"__scheduler__writeStdout",
LuaFunction::wrap(move |lua, data: Vec<u8>| {
send_message(lua, Message::WriteStdout(data));
Ok(())
}),
)?;
g.set(
"__scheduler__writeStderr",
LuaFunction::wrap(move |lua, data: Vec<u8>| {
send_message(lua, Message::WriteStderr(data));
Ok(())
}),
)?;
g.set("wait", lua.load(WAIT_IMPL).into_function()?)?;
let mut yielded_threads = ThreadMap::default();
@ -151,9 +172,8 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu
// Resume as many threads as possible
for (thread_id, thread) in runnable_threads.drain() {
stats.incr(StatsCounter::Resumed);
stats.incr(StatsCounter::ThreadResumed);
if let Err(e) = thread.resume::<_, ()>(()) {
stats.incr(StatsCounter::Errored);
send_message(&lua, Message::Error(thread_id, e));
}
if thread.status() == LuaThreadStatus::Resumable {
@ -175,7 +195,7 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu
Message::Cancel(thread_id) => {
yielded_threads.remove(&thread_id);
runnable_threads.remove(&thread_id);
stats.incr(StatsCounter::Cancelled);
stats.incr(StatsCounter::ThreadCancelled);
}
_ => unreachable!(),
};
@ -198,27 +218,60 @@ async fn main_async_task(
tx: MessageSender,
stats: Stats,
) -> LuaResult<()> {
// Set up message processor
let process_message = |message| match message {
Message::Sleep(thread_id, duration) => {
stats.incr(StatsCounter::Slept);
let tx = tx.clone();
spawn(async move {
sleep(duration).await;
let _ = tx.send(Message::Resume(thread_id));
});
}
Message::Error(_, e) => {
eprintln!("Lua error: {e}");
}
_ => unreachable!(),
};
let mut handle_stdout = io::stdout();
let mut handle_stderr = io::stderr();
// Wait for at least one message, but try to receive as many as possible
let mut messages = Vec::new();
while let Some(message) = rx.recv().await {
process_message(message);
messages.push(message);
while let Ok(message) = rx.try_recv() {
process_message(message);
messages.push(message);
}
// Handle all messages
let mut wrote_stdout = false;
let mut wrote_stderr = false;
for message in messages.drain(..) {
match message {
Message::Sleep(_, _) => stats.incr(StatsCounter::ThreadSlept),
Message::Error(_, _) => stats.incr(StatsCounter::ThreadErrored),
Message::WriteStdout(_) => stats.incr(StatsCounter::WriteStdout),
Message::WriteStderr(_) => stats.incr(StatsCounter::WriteStderr),
_ => unreachable!(),
}
match message {
Message::Sleep(thread_id, duration) => {
let tx = tx.clone();
spawn(async move {
sleep(duration).await;
tx.send(Message::Resume(thread_id))
});
}
Message::Error(_, e) => {
wrote_stderr = true;
handle_stderr.write_all(b"Lua error: ").await?;
handle_stderr.write_all(e.to_string().as_bytes()).await?;
}
Message::WriteStdout(data) => {
wrote_stdout = true;
handle_stdout.write_all(&data).await?;
}
Message::WriteStderr(data) => {
wrote_stderr = true;
handle_stderr.write_all(&data).await?;
}
_ => unreachable!(),
}
}
// Flush streams if we wrote anything to them
if wrote_stdout {
handle_stdout.flush().await?;
}
if wrote_stderr {
handle_stderr.flush().await?;
}
}