From 93a56e28c523d66ac41826c811fbefa29ef8d3d7 Mon Sep 17 00:00:00 2001 From: Filip Tibell Date: Thu, 1 Feb 2024 11:11:17 +0100 Subject: [PATCH] Add back ability to async wait for threads to complete --- examples/callbacks.rs | 2 +- examples/scheduler_ordering.rs | 2 +- lib/result_map.rs | 30 ++++++++++++++++++++++----- lib/runtime.rs | 11 +++++++++- lib/traits.rs | 37 ++++++++++++++++++++++++++++++++++ 5 files changed, 74 insertions(+), 8 deletions(-) diff --git a/examples/callbacks.rs b/examples/callbacks.rs index ee4da89..0cfbb31 100644 --- a/examples/callbacks.rs +++ b/examples/callbacks.rs @@ -32,7 +32,7 @@ pub fn main() -> LuaResult<()> { block_on(rt.run()); // We should have gotten the error back from our script - assert!(rt.thread_result(id).unwrap().is_err()); + assert!(rt.get_thread_result(id).unwrap().is_err()); Ok(()) } diff --git a/examples/scheduler_ordering.rs b/examples/scheduler_ordering.rs index d774ec6..88cba18 100644 --- a/examples/scheduler_ordering.rs +++ b/examples/scheduler_ordering.rs @@ -38,7 +38,7 @@ pub fn main() -> LuaResult<()> { block_on(rt.run()); // We should have gotten proper values back from our script - let res = rt.thread_result(id).unwrap().unwrap(); + let res = rt.get_thread_result(id).unwrap().unwrap(); let nums = Vec::::from_lua_multi(res, &lua)?; assert_eq!(nums, vec![1, 2, 3, 4, 5, 6]); diff --git a/lib/result_map.rs b/lib/result_map.rs index 4d406ff..5907f50 100644 --- a/lib/result_map.rs +++ b/lib/result_map.rs @@ -2,6 +2,7 @@ use std::{cell::RefCell, rc::Rc}; +use event_listener::Event; // NOTE: This is the hash algorithm that mlua also uses, so we // are not adding any additional dependencies / bloat by using it. use rustc_hash::{FxHashMap, FxHashSet}; @@ -11,14 +12,16 @@ use crate::{thread_id::ThreadId, util::ThreadResult}; #[derive(Clone)] pub(crate) struct ThreadResultMap { tracked: Rc>>, - inner: Rc>>, + results: Rc>>, + events: Rc>>>, } impl ThreadResultMap { pub fn new() -> Self { Self { tracked: Rc::new(RefCell::new(FxHashSet::default())), - inner: Rc::new(RefCell::new(FxHashMap::default())), + results: Rc::new(RefCell::new(FxHashMap::default())), + events: Rc::new(RefCell::new(FxHashMap::default())), } } @@ -32,15 +35,32 @@ impl ThreadResultMap { self.tracked.borrow().contains(&id) } - #[inline(always)] + #[inline] pub fn insert(&self, id: ThreadId, result: ThreadResult) { debug_assert!(self.is_tracked(id), "Thread must be tracked"); - self.inner.borrow_mut().insert(id, result); + self.results.borrow_mut().insert(id, result); + if let Some(event) = self.events.borrow_mut().remove(&id) { + event.notify(usize::MAX); + } + } + + #[inline] + pub async fn listen(&self, id: ThreadId) { + debug_assert!(self.is_tracked(id), "Thread must be tracked"); + if !self.results.borrow().contains_key(&id) { + let listener = { + let mut events = self.events.borrow_mut(); + let event = events.entry(id).or_insert_with(|| Rc::new(Event::new())); + event.listen() + }; + listener.await; + } } pub fn remove(&self, id: ThreadId) -> Option { - let res = self.inner.borrow_mut().remove(&id)?; + let res = self.results.borrow_mut().remove(&id)?; self.tracked.borrow_mut().remove(&id); + self.events.borrow_mut().remove(&id); Some(res) } } diff --git a/lib/runtime.rs b/lib/runtime.rs index 06315ff..937ef68 100644 --- a/lib/runtime.rs +++ b/lib/runtime.rs @@ -214,10 +214,19 @@ impl<'lua> Runtime<'lua> { Any subsequent calls after this method returns `Some` will return `None`. */ #[must_use] - pub fn thread_result(&self, id: ThreadId) -> Option>> { + pub fn get_thread_result(&self, id: ThreadId) -> Option>> { self.result_map.remove(id).map(|r| r.value(self.lua)) } + /** + Waits for the [`LuaThread`] with the given [`ThreadId`] to complete. + + This will return instantly if the thread has already completed. + */ + pub async fn wait_for_thread(&self, id: ThreadId) { + self.result_map.listen(id).await; + } + /** Runs the runtime until all Lua threads have completed. diff --git a/lib/traits.rs b/lib/traits.rs index 363f935..e724caf 100644 --- a/lib/traits.rs +++ b/lib/traits.rs @@ -9,6 +9,7 @@ use async_executor::{Executor, Task}; use crate::{ queue::{DeferredThreadQueue, FuturesQueue, SpawnedThreadQueue}, + result_map::ThreadResultMap, runtime::Runtime, thread_id::ThreadId, }; @@ -93,6 +94,28 @@ pub trait LuaRuntimeExt<'lua> { args: impl IntoLuaMulti<'lua>, ) -> LuaResult; + /** + Gets the result of the given thread. + + See [`Runtime::get_thread_result`] for more information. + + # Panics + + Panics if called outside of a running [`Runtime`]. + */ + fn get_thread_result(&'lua self, id: ThreadId) -> Option>>; + + /** + Waits for the given thread to complete. + + See [`Runtime::wait_for_thread`] for more information. + + # Panics + + Panics if called outside of a running [`Runtime`]. + */ + fn wait_for_thread(&'lua self, id: ThreadId) -> impl Future; + /** Spawns the given future on the current executor and returns its [`Task`]. @@ -198,6 +221,20 @@ impl<'lua> LuaRuntimeExt<'lua> for Lua { queue.push_item(self, thread, args) } + fn get_thread_result(&'lua self, id: ThreadId) -> Option>> { + let map = self + .app_data_ref::() + .expect("lua threads results can only be retrieved within a runtime"); + map.remove(id).map(|r| r.value(self)) + } + + fn wait_for_thread(&'lua self, id: ThreadId) -> impl Future { + let map = self + .app_data_ref::() + .expect("lua threads results can only be retrieved within a runtime"); + async move { map.listen(id).await } + } + fn spawn(&self, fut: impl Future + Send + 'static) -> Task { let exec = self .app_data_ref::>()