diff --git a/src/lune/scheduler/impl_async.rs b/src/lune/scheduler/impl_async.rs index f57ec3b..6f9a83e 100644 --- a/src/lune/scheduler/impl_async.rs +++ b/src/lune/scheduler/impl_async.rs @@ -39,12 +39,8 @@ where self.schedule_future(async move { match fut.await.and_then(|rets| rets.into_lua_multi(self.lua)) { Err(e) => { - self.state.set_lua_error(e); - // NOTE: We push the thread to the front of the scheduler - // to ensure that it runs first to be able to catch the - // stored error from within the scheduler lua interrupt - self.push_front(thread, ()) - .expect("Failed to schedule future thread"); + self.push_err(thread, e) + .expect("Failed to schedule future err thread"); } Ok(v) => { self.push_back(thread, v) diff --git a/src/lune/scheduler/impl_runner.rs b/src/lune/scheduler/impl_runner.rs index 088bf23..82617cf 100644 --- a/src/lune/scheduler/impl_runner.rs +++ b/src/lune/scheduler/impl_runner.rs @@ -23,29 +23,45 @@ where let mut resumed_any = false; - while let Some((thread, args, sender)) = self + // Pop threads from the scheduler until there are none left + while let Some(thread) = self .pop_thread() .expect("Failed to pop thread from scheduler") { + // Deconstruct the scheduler thread into its parts + let thread_id = thread.id(); + let (thread, args) = thread.into_inner(self.lua); + + // Resume the thread, ensuring that the schedulers + // current thread id is set correctly for error catching + self.state.set_current_thread_id(Some(thread_id)); let res = thread.resume::<_, LuaMultiValue>(args); - self.state.add_resumption(); + self.state.set_current_thread_id(None); + resumed_any = true; + // If we got any resumption (lua-side) error, increment + // the error count of the scheduler so we can exit with + // a non-zero exit code, and print it out to stderr + // TODO: Pretty print the lua error here if let Err(err) = &res { - self.state.add_error(); - eprint!("{err}"); // TODO: Pretty print the lua error here + self.state.increment_error_count(); + eprint!("{err}"); } - if sender.receiver_count() > 0 { - sender - .send(res.map(|v| { - Arc::new( - self.lua - .create_registry_value(v.into_vec()) - .expect("Failed to store return values in registry"), - ) - })) - .expect("Failed to broadcast return values of thread"); + // Send results of resuming this thread to any listeners + if let Some(sender) = self.thread_senders.borrow_mut().remove(&thread_id) { + if sender.receiver_count() > 0 { + sender + .send(res.map(|v| { + Arc::new( + self.lua + .create_registry_value(v.into_vec()) + .expect("Failed to store return values in registry"), + ) + })) + .expect("Failed to broadcast return values of thread"); + } } if self.state.has_exit_code() { diff --git a/src/lune/scheduler/impl_threads.rs b/src/lune/scheduler/impl_threads.rs index 2f8969f..765bbb7 100644 --- a/src/lune/scheduler/impl_threads.rs +++ b/src/lune/scheduler/impl_threads.rs @@ -27,9 +27,7 @@ where Returns `None` if there are no threads left to run. */ - pub(super) fn pop_thread( - &self, - ) -> LuaResult, SchedulerThreadSender)>> { + pub(super) fn pop_thread(&self) -> LuaResult> { match self .threads .try_borrow_mut() @@ -37,20 +35,31 @@ where .context("Failed to borrow threads vec")? .pop_front() { - Some(thread) => { - let thread_id = &thread.id(); - let (thread, args) = thread.into_inner(self.lua); - let sender = self - .thread_senders - .borrow_mut() - .remove(thread_id) - .expect("Missing thread sender"); - Ok(Some((thread, args, sender))) - } + Some(thread) => Ok(Some(thread)), None => Ok(None), } } + /** + Schedules the `thread` to be resumed with the given [`LuaError`]. + */ + pub fn push_err(&self, thread: impl IntoLuaOwnedThread, err: LuaError) -> LuaResult<()> { + let thread = thread.into_owned_lua_thread(self.lua)?; + let args = LuaMultiValue::new(); // Will be resumed with error, don't need real args + + let thread = SchedulerThread::new(self.lua, thread, args)?; + let thread_id = thread.id(); + + self.state.set_thread_error(thread_id, err); + self.threads + .try_borrow_mut() + .into_lua_err() + .context("Failed to borrow threads vec")? + .push_front(thread); + + Ok(()) + } + /** Schedules the `thread` to be resumed with the given `args` right away, before any other currently scheduled threads. diff --git a/src/lune/scheduler/mod.rs b/src/lune/scheduler/mod.rs index 520a323..6f457d7 100644 --- a/src/lune/scheduler/mod.rs +++ b/src/lune/scheduler/mod.rs @@ -52,12 +52,18 @@ impl<'lua, 'fut> Scheduler<'lua, 'fut> { futures: Arc::new(AsyncMutex::new(FuturesUnordered::new())), }; - // HACK: Propagate errors given to the scheduler back to their lua threads + // Propagate errors given to the scheduler back to their lua threads // FUTURE: Do profiling and anything else we need inside of this interrupt let state = this.state.clone(); - lua.set_interrupt(move |_| match state.get_lua_error() { - Some(e) => Err(e), - None => Ok(LuaVmState::Continue), + lua.set_interrupt(move |_| { + if let Some(id) = state.get_current_thread_id() { + match state.get_thread_error(id) { + Some(e) => Err(e), + None => Ok(LuaVmState::Continue), + } + } else { + Ok(LuaVmState::Continue) + } }); this diff --git a/src/lune/scheduler/state.rs b/src/lune/scheduler/state.rs index 3a0b12a..1c38670 100644 --- a/src/lune/scheduler/state.rs +++ b/src/lune/scheduler/state.rs @@ -1,17 +1,21 @@ use std::{ cell::RefCell, + collections::HashMap, sync::atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering}, }; use mlua::Error as LuaError; +use super::SchedulerThreadId; + #[derive(Debug, Default)] pub struct SchedulerState { exit_state: AtomicBool, exit_code: AtomicU8, num_resumptions: AtomicUsize, num_errors: AtomicUsize, - lua_error: RefCell>, + thread_id: RefCell>, + thread_errors: RefCell>, } impl SchedulerState { @@ -19,11 +23,7 @@ impl SchedulerState { Self::default() } - pub fn add_resumption(&self) { - self.num_resumptions.fetch_add(1, Ordering::Relaxed); - } - - pub fn add_error(&self) { + pub fn increment_error_count(&self) { self.num_errors.fetch_add(1, Ordering::Relaxed); } @@ -48,11 +48,20 @@ impl SchedulerState { self.exit_code.store(code.into(), Ordering::SeqCst); } - pub fn get_lua_error(&self) -> Option { - self.lua_error.take() + pub fn get_current_thread_id(&self) -> Option { + *self.thread_id.borrow() } - pub fn set_lua_error(&self, e: LuaError) { - self.lua_error.replace(Some(e)); + pub fn set_current_thread_id(&self, id: Option) { + self.thread_id.replace(id); + self.num_resumptions.fetch_add(1, Ordering::Relaxed); + } + + pub fn get_thread_error(&self, id: SchedulerThreadId) -> Option { + self.thread_errors.borrow_mut().remove(&id) + } + + pub fn set_thread_error(&self, id: SchedulerThreadId, err: LuaError) { + self.thread_errors.borrow_mut().insert(id, err); } }