diff --git a/src/error_storage.rs b/src/error_storage.rs new file mode 100644 index 0000000..a3fda8e --- /dev/null +++ b/src/error_storage.rs @@ -0,0 +1,36 @@ +use std::sync::{ + atomic::{AtomicBool, Ordering}, + Arc, Mutex, +}; + +use mlua::prelude::*; + +#[derive(Debug, Clone)] +pub struct ErrorStorage { + is_some: Arc, + inner: Arc>>, +} + +impl ErrorStorage { + pub fn new() -> Self { + Self { + is_some: Arc::new(AtomicBool::new(false)), + inner: Arc::new(Mutex::new(None)), + } + } + + #[inline] + pub fn take(&self) -> Option { + if self.is_some.load(Ordering::Relaxed) { + self.inner.lock().unwrap().take() + } else { + None + } + } + + #[inline] + pub fn replace(&self, e: LuaError) { + self.is_some.store(true, Ordering::Relaxed); + self.inner.lock().unwrap().replace(e); + } +} diff --git a/src/lua_ext.rs b/src/lua_ext.rs new file mode 100644 index 0000000..aa9654a --- /dev/null +++ b/src/lua_ext.rs @@ -0,0 +1,42 @@ +use std::future::Future; + +use mlua::prelude::*; +use tokio::spawn; + +use crate::{Args, Message, MessageSender, ThreadId}; + +pub trait LuaSchedulerExt<'lua> { + fn create_async_function(&'lua self, func: F) -> LuaResult> + where + A: FromLuaMulti<'lua> + 'static, + R: Into + Send + 'static, + F: Fn(&'lua Lua, A) -> FR + 'static, + FR: Future> + Send + 'static; +} + +impl<'lua> LuaSchedulerExt<'lua> for Lua { + fn create_async_function(&'lua self, func: F) -> LuaResult> + where + A: FromLuaMulti<'lua> + 'static, + R: Into + Send + 'static, + F: Fn(&'lua Lua, A) -> FR + 'static, + FR: Future> + Send + 'static, + { + let tx = self.app_data_ref::().unwrap().clone(); + + self.create_function(move |lua, args: A| { + let thread_id = ThreadId::from(lua.current_thread()); + let fut = func(lua, args); + let tx = tx.clone(); + + spawn(async move { + tx.send(match fut.await { + Ok(args) => Message::Resume(thread_id, Ok(args.into())), + Err(e) => Message::Resume(thread_id, Err(e)), + }) + }); + + Ok(()) + }) + } +} diff --git a/src/main.rs b/src/main.rs index 6704734..c3161b9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,12 +10,15 @@ use tokio::{ }; mod args; +mod error_storage; mod lua; +mod lua_ext; mod message; mod stats; mod thread_id; use args::*; +use error_storage::*; use lua::*; use message::*; use stats::*; @@ -61,18 +64,25 @@ fn main() { fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> LuaResult<()> { let lua = create_lua(tx.clone())?; + let error_storage = ErrorStorage::new(); + let error_storage_interrupt = error_storage.clone(); + lua.set_interrupt(move |_| match error_storage_interrupt.take() { + Some(e) => Err(e), + None => Ok(LuaVmState::Continue), + }); + lua.globals() .set("wait", lua.load(WAIT_IMPL).into_function()?)?; - let mut yielded_threads: GxHashMap = GxHashMap::default(); - let mut runnable_threads: GxHashMap = GxHashMap::default(); + let mut yielded_threads = GxHashMap::default(); + let mut runnable_threads = GxHashMap::default(); println!("Running {NUM_TEST_BATCHES} batches"); for _ in 0..NUM_TEST_BATCHES { let main_fn = lua.load(MAIN_CHUNK).into_function()?; for _ in 0..NUM_TEST_THREADS { let thread = lua.create_thread(main_fn.clone())?; - runnable_threads.insert(ThreadId::from(&thread), (thread, Args::new())); + runnable_threads.insert(ThreadId::from(&thread), (thread, Ok(Args::new()))); } loop { @@ -82,13 +92,20 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu } // Resume as many threads as possible - for (thread_id, (thread, args)) in runnable_threads.drain() { + for (thread_id, (thread, res)) in runnable_threads.drain() { stats.incr(StatsCounter::ThreadResumed); + // NOTE: If we got an error we don't need to resume with any args + let args = match res { + Ok(a) => a, + Err(e) => { + error_storage.replace(e); + Args::from(()) + } + }; if let Err(e) = thread.resume::<_, ()>(args) { tx.send(Message::Error(thread_id, Box::new(e))) .expect("failed to send error to async task"); - } - if thread.status() == LuaThreadStatus::Resumable { + } else if thread.status() == LuaThreadStatus::Resumable { yielded_threads.insert(thread_id, thread); } } @@ -100,9 +117,9 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu // Set up message processor - we mutably borrow both yielded_threads and runnable_threads // so we can't really do this outside of the loop, but it compiles down to the same thing let mut process_message = |message| match message { - Message::Resume(thread_id, args) => { + Message::Resume(thread_id, res) => { if let Some(thread) = yielded_threads.remove(&thread_id) { - runnable_threads.insert(thread_id, (thread, args)); + runnable_threads.insert(thread_id, (thread, res)); } } Message::Cancel(thread_id) => { @@ -160,7 +177,7 @@ async fn main_async_task( spawn(async move { sleep(duration).await; let elapsed = Instant::now() - yielded_at; - tx.send(Message::Resume(thread_id, Args::from(elapsed))) + tx.send(Message::Resume(thread_id, Ok(Args::from(elapsed)))) }); } Message::Error(_, e) => { diff --git a/src/message.rs b/src/message.rs index 7ca400a..282bd80 100644 --- a/src/message.rs +++ b/src/message.rs @@ -12,7 +12,7 @@ pub type MessageSender = UnboundedSender; pub type MessageReceiver = UnboundedReceiver; pub enum Message { - Resume(ThreadId, Args), + Resume(ThreadId, LuaResult), Cancel(ThreadId), Sleep(ThreadId, Instant, Duration), Error(ThreadId, Box),