Some final experiments

This commit is contained in:
Filip Tibell 2024-01-17 15:46:17 +01:00
parent dec5e940ed
commit aa1320daf1
No known key found for this signature in database
4 changed files with 64 additions and 94 deletions

View file

@ -1,56 +1,39 @@
use std::time::Duration;
use mlua::prelude::*; use mlua::prelude::*;
use tokio::time::Instant;
use crate::{Message, MessageSender, ThreadId}; use crate::{Message, MessageSender, ThreadId};
pub fn create_lua(tx: MessageSender) -> LuaResult<Lua> { pub fn create_lua(lua_tx: MessageSender, async_tx: MessageSender) -> LuaResult<Lua> {
let lua = Lua::new(); let lua = Lua::new();
lua.enable_jit(true); lua.enable_jit(true);
lua.set_app_data(tx.clone()); lua.set_app_data(async_tx.clone());
// Resumption
let tx_resume = tx.clone();
lua.globals().set(
"__scheduler__resumeAfter",
LuaFunction::wrap(move |lua, duration: f64| {
let thread_id = ThreadId::from(lua.current_thread());
let yielded_at = Instant::now();
let duration = Duration::from_secs_f64(duration);
tx_resume
.send(Message::Sleep(thread_id, yielded_at, duration))
.into_lua_err()
}),
)?;
// Cancellation // Cancellation
let tx_cancel = tx.clone(); let cancel_tx = lua_tx.clone();
lua.globals().set( lua.globals().set(
"__scheduler__cancel", "__scheduler__cancel",
LuaFunction::wrap(move |_, thread: LuaThread| { LuaFunction::wrap(move |_, thread: LuaThread| {
let thread_id = ThreadId::from(thread); let thread_id = ThreadId::from(thread);
tx_cancel.send(Message::Cancel(thread_id)).into_lua_err() cancel_tx.send(Message::Cancel(thread_id)).into_lua_err()
}), }),
)?; )?;
// Stdout // Stdout
let tx_stdout = tx.clone(); let stdout_tx = async_tx.clone();
lua.globals().set( lua.globals().set(
"__scheduler__writeStdout", "__scheduler__writeStdout",
LuaFunction::wrap(move |_, s: LuaString| { LuaFunction::wrap(move |_, s: LuaString| {
let bytes = s.as_bytes().to_vec(); let bytes = s.as_bytes().to_vec();
tx_stdout.send(Message::WriteStdout(bytes)).into_lua_err() stdout_tx.send(Message::WriteStdout(bytes)).into_lua_err()
}), }),
)?; )?;
// Stderr // Stderr
let tx_stderr = tx.clone(); let stderr_tx = async_tx.clone();
lua.globals().set( lua.globals().set(
"__scheduler__writeStderr", "__scheduler__writeStderr",
LuaFunction::wrap(move |_, s: LuaString| { LuaFunction::wrap(move |_, s: LuaString| {
let bytes = s.as_bytes().to_vec(); let bytes = s.as_bytes().to_vec();
tx_stderr.send(Message::WriteStderr(bytes)).into_lua_err() stderr_tx.send(Message::WriteStderr(bytes)).into_lua_err()
}), }),
)?; )?;

View file

@ -1,3 +1,5 @@
use std::time::Duration;
use gxhash::GxHashMap; use gxhash::GxHashMap;
use mlua::prelude::*; use mlua::prelude::*;
use tokio::{ use tokio::{
@ -24,6 +26,8 @@ use stats::*;
use thread_id::*; use thread_id::*;
use value::*; use value::*;
use crate::lua_ext::LuaAsyncExt;
const NUM_TEST_BATCHES: usize = 20; const NUM_TEST_BATCHES: usize = 20;
const NUM_TEST_THREADS: usize = 50_000; const NUM_TEST_THREADS: usize = 50_000;
@ -31,26 +35,21 @@ const MAIN_CHUNK: &str = r#"
wait(0.01 * math.random()) wait(0.01 * math.random())
"#; "#;
const WAIT_IMPL: &str = r#"
__scheduler__resumeAfter(...)
return coroutine.yield()
"#;
fn main() { fn main() {
let rt = TokioRuntime::new().unwrap(); let rt = TokioRuntime::new().unwrap();
let set = LocalSet::new(); let set = LocalSet::new();
let _guard = set.enter(); let _guard = set.enter();
let (msg_tx, lua_rx) = unbounded_channel::<Message>(); let (async_tx, lua_rx) = unbounded_channel::<Message>();
let (lua_tx, msg_rx) = unbounded_channel::<Message>(); let (lua_tx, async_rx) = unbounded_channel::<Message>();
let stats = Stats::new(); let stats = Stats::new();
let stats_inner = stats.clone(); let stats_inner = stats.clone();
set.block_on(&rt, async { set.block_on(&rt, async {
let res = select! { let res = select! {
r = spawn(main_async_task(msg_rx, msg_tx, stats_inner.clone())) => r, r = spawn(main_async_task(async_rx, stats_inner.clone())) => r,
r = spawn_blocking(|| main_lua_task(lua_rx, lua_tx, stats_inner)) => r, r = spawn_blocking(move || main_lua_task(lua_rx, lua_tx, async_tx, stats_inner)) => r,
}; };
if let Err(e) = res { if let Err(e) = res {
eprintln!("Runtime fatal error: {e}"); eprintln!("Runtime fatal error: {e}");
@ -61,8 +60,13 @@ fn main() {
println!("Thread counters: {:#?}", stats.counters); println!("Thread counters: {:#?}", stats.counters);
} }
fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> LuaResult<()> { fn main_lua_task(
let lua = create_lua(tx.clone())?; mut lua_rx: MessageReceiver,
lua_tx: MessageSender,
async_tx: MessageSender,
stats: Stats,
) -> LuaResult<()> {
let lua = create_lua(lua_tx.clone(), async_tx.clone())?;
let error_storage = ErrorStorage::new(); let error_storage = ErrorStorage::new();
let error_storage_interrupt = error_storage.clone(); let error_storage_interrupt = error_storage.clone();
@ -71,8 +75,14 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu
None => Ok(LuaVmState::Continue), None => Ok(LuaVmState::Continue),
}); });
lua.globals() lua.globals().set(
.set("wait", lua.load(WAIT_IMPL).into_function()?)?; "wait",
lua.create_async_function(|_, duration: f64| async move {
let before = Instant::now();
sleep(Duration::from_secs_f64(duration)).await;
Ok(Instant::now() - before)
})?,
)?;
let mut yielded_threads = GxHashMap::default(); let mut yielded_threads = GxHashMap::default();
let mut runnable_threads = GxHashMap::default(); let mut runnable_threads = GxHashMap::default();
@ -103,8 +113,10 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu
} }
}; };
if let Err(e) = thread.resume::<_, ()>(args) { if let Err(e) = thread.resume::<_, ()>(args) {
tx.send(Message::WriteError(e)).unwrap(); stats.incr(StatsCounter::ThreadErrored);
async_tx.send(Message::WriteError(e)).unwrap();
} else if thread.status() == LuaThreadStatus::Resumable { } else if thread.status() == LuaThreadStatus::Resumable {
stats.incr(StatsCounter::ThreadYielded);
yielded_threads.insert(thread_id, thread); yielded_threads.insert(thread_id, thread);
} }
} }
@ -126,13 +138,13 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu
runnable_threads.remove(&thread_id); runnable_threads.remove(&thread_id);
stats.incr(StatsCounter::ThreadCancelled); stats.incr(StatsCounter::ThreadCancelled);
} }
_ => unreachable!(), m => unreachable!("got non-lua message: {m:?}"),
}; };
// Wait for at least one message, but try to receive as many as possible // Wait for at least one message, but try to receive as many as possible
if let Some(message) = rx.blocking_recv() { if let Some(message) = lua_rx.blocking_recv() {
process_message(message); process_message(message);
while let Ok(message) = rx.try_recv() { while let Ok(message) = lua_rx.try_recv() {
process_message(message); process_message(message);
} }
} else { } else {
@ -144,59 +156,39 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu
Ok(()) Ok(())
} }
async fn main_async_task( async fn main_async_task(mut async_rx: MessageReceiver, stats: Stats) -> LuaResult<()> {
mut rx: MessageReceiver,
tx: MessageSender,
stats: Stats,
) -> LuaResult<()> {
// Give stdio its own task, we don't need it to block the scheduler // Give stdio its own task, we don't need it to block the scheduler
let (tx_stdout, rx_stdout) = unbounded_channel(); let (stdout_tx, stdout_rx) = unbounded_channel();
let (tx_stderr, rx_stderr) = unbounded_channel(); let (stderr_tx, stderr_rx) = unbounded_channel();
let forward_stdout = |data| tx_stdout.send(data).ok(); let forward_stdout = |data| stdout_tx.send(data).ok();
let forward_stderr = |data| tx_stderr.send(data).ok(); let forward_stderr = |data| stderr_tx.send(data).ok();
spawn(async move { spawn(async move {
if let Err(e) = async_stdio_task(rx_stdout, rx_stderr).await { if let Err(e) = async_stdio_task(stdout_rx, stderr_rx).await {
eprintln!("Stdio fatal error: {e}"); eprintln!("Stdio fatal error: {e}");
} }
}); });
// Set up message processor // Set up message processor
let process_message = |message| { let process_message = |message| match message {
match message { Message::WriteError(e) => {
Message::Sleep(_, _, _) => stats.incr(StatsCounter::ThreadSlept), forward_stderr(b"Lua error: ".to_vec());
Message::WriteError(_) => stats.incr(StatsCounter::ThreadErrored), forward_stderr(e.to_string().as_bytes().to_vec());
Message::WriteStdout(_) => stats.incr(StatsCounter::WriteStdout),
Message::WriteStderr(_) => stats.incr(StatsCounter::WriteStderr),
_ => unreachable!(),
} }
Message::WriteStdout(data) => {
match message { forward_stdout(data);
Message::Sleep(thread_id, yielded_at, duration) => { stats.incr(StatsCounter::WriteStdout);
let tx = tx.clone();
spawn(async move {
sleep(duration).await;
let elapsed = Instant::now() - yielded_at;
tx.send(Message::Resume(thread_id, Ok(AsyncValues::from(elapsed))))
});
}
Message::WriteError(e) => {
forward_stderr(b"Lua error: ".to_vec());
forward_stderr(e.to_string().as_bytes().to_vec());
}
Message::WriteStdout(data) => {
forward_stdout(data);
}
Message::WriteStderr(data) => {
forward_stderr(data);
}
_ => unreachable!(),
} }
Message::WriteStderr(data) => {
forward_stderr(data);
stats.incr(StatsCounter::WriteStderr);
}
_ => unreachable!(),
}; };
// Wait for at least one message, but try to receive as many as possible // Wait for at least one message, but try to receive as many as possible
while let Some(message) = rx.recv().await { while let Some(message) = async_rx.recv().await {
process_message(message); process_message(message);
while let Ok(message) = rx.try_recv() { while let Ok(message) = async_rx.try_recv() {
process_message(message); process_message(message);
} }
} }
@ -205,22 +197,22 @@ async fn main_async_task(
} }
async fn async_stdio_task( async fn async_stdio_task(
mut rx_stdout: UnboundedReceiver<Vec<u8>>, mut stdout_rx: UnboundedReceiver<Vec<u8>>,
mut rx_stderr: UnboundedReceiver<Vec<u8>>, mut stderr_rx: UnboundedReceiver<Vec<u8>>,
) -> LuaResult<()> { ) -> LuaResult<()> {
let mut stdout = io::stdout(); let mut stdout = io::stdout();
let mut stderr = io::stderr(); let mut stderr = io::stderr();
loop { loop {
select! { select! {
data = rx_stdout.recv() => match data { data = stdout_rx.recv() => match data {
None => break, // Main task exited None => break, // Main task exited
Some(data) => { Some(data) => {
stdout.write_all(&data).await?; stdout.write_all(&data).await?;
stdout.flush().await?; stdout.flush().await?;
} }
}, },
data = rx_stderr.recv() => match data { data = stderr_rx.recv() => match data {
None => break, // Main task exited None => break, // Main task exited
Some(data) => { Some(data) => {
stderr.write_all(&data).await?; stderr.write_all(&data).await?;

View file

@ -1,20 +1,15 @@
use std::time::Duration;
use mlua::prelude::*; use mlua::prelude::*;
use tokio::{ use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
sync::mpsc::{UnboundedReceiver, UnboundedSender},
time::Instant,
};
use crate::{AsyncValues, ThreadId}; use crate::{AsyncValues, ThreadId};
pub type MessageSender = UnboundedSender<Message>; pub type MessageSender = UnboundedSender<Message>;
pub type MessageReceiver = UnboundedReceiver<Message>; pub type MessageReceiver = UnboundedReceiver<Message>;
#[derive(Debug)]
pub enum Message { pub enum Message {
Resume(ThreadId, LuaResult<AsyncValues>), Resume(ThreadId, LuaResult<AsyncValues>),
Cancel(ThreadId), Cancel(ThreadId),
Sleep(ThreadId, Instant, Duration),
WriteError(LuaError), WriteError(LuaError),
WriteStdout(Vec<u8>), WriteStdout(Vec<u8>),
WriteStderr(Vec<u8>), WriteStderr(Vec<u8>),

View file

@ -6,8 +6,8 @@ use tokio::time::Instant;
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum StatsCounter { pub enum StatsCounter {
ThreadResumed, ThreadResumed,
ThreadYielded,
ThreadCancelled, ThreadCancelled,
ThreadSlept,
ThreadErrored, ThreadErrored,
WriteStdout, WriteStdout,
WriteStderr, WriteStderr,