Add back ability to async wait for threads to complete

This commit is contained in:
Filip Tibell 2024-02-01 11:11:17 +01:00
parent ecbd5149f8
commit 93a56e28c5
No known key found for this signature in database
5 changed files with 74 additions and 8 deletions

View file

@ -32,7 +32,7 @@ pub fn main() -> LuaResult<()> {
block_on(rt.run()); block_on(rt.run());
// We should have gotten the error back from our script // We should have gotten the error back from our script
assert!(rt.thread_result(id).unwrap().is_err()); assert!(rt.get_thread_result(id).unwrap().is_err());
Ok(()) Ok(())
} }

View file

@ -38,7 +38,7 @@ pub fn main() -> LuaResult<()> {
block_on(rt.run()); block_on(rt.run());
// We should have gotten proper values back from our script // We should have gotten proper values back from our script
let res = rt.thread_result(id).unwrap().unwrap(); let res = rt.get_thread_result(id).unwrap().unwrap();
let nums = Vec::<usize>::from_lua_multi(res, &lua)?; let nums = Vec::<usize>::from_lua_multi(res, &lua)?;
assert_eq!(nums, vec![1, 2, 3, 4, 5, 6]); assert_eq!(nums, vec![1, 2, 3, 4, 5, 6]);

View file

@ -2,6 +2,7 @@
use std::{cell::RefCell, rc::Rc}; use std::{cell::RefCell, rc::Rc};
use event_listener::Event;
// NOTE: This is the hash algorithm that mlua also uses, so we // NOTE: This is the hash algorithm that mlua also uses, so we
// are not adding any additional dependencies / bloat by using it. // are not adding any additional dependencies / bloat by using it.
use rustc_hash::{FxHashMap, FxHashSet}; use rustc_hash::{FxHashMap, FxHashSet};
@ -11,14 +12,16 @@ use crate::{thread_id::ThreadId, util::ThreadResult};
#[derive(Clone)] #[derive(Clone)]
pub(crate) struct ThreadResultMap { pub(crate) struct ThreadResultMap {
tracked: Rc<RefCell<FxHashSet<ThreadId>>>, tracked: Rc<RefCell<FxHashSet<ThreadId>>>,
inner: Rc<RefCell<FxHashMap<ThreadId, ThreadResult>>>, results: Rc<RefCell<FxHashMap<ThreadId, ThreadResult>>>,
events: Rc<RefCell<FxHashMap<ThreadId, Rc<Event>>>>,
} }
impl ThreadResultMap { impl ThreadResultMap {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
tracked: Rc::new(RefCell::new(FxHashSet::default())), tracked: Rc::new(RefCell::new(FxHashSet::default())),
inner: Rc::new(RefCell::new(FxHashMap::default())), results: Rc::new(RefCell::new(FxHashMap::default())),
events: Rc::new(RefCell::new(FxHashMap::default())),
} }
} }
@ -32,15 +35,32 @@ impl ThreadResultMap {
self.tracked.borrow().contains(&id) self.tracked.borrow().contains(&id)
} }
#[inline(always)] #[inline]
pub fn insert(&self, id: ThreadId, result: ThreadResult) { pub fn insert(&self, id: ThreadId, result: ThreadResult) {
debug_assert!(self.is_tracked(id), "Thread must be tracked"); debug_assert!(self.is_tracked(id), "Thread must be tracked");
self.inner.borrow_mut().insert(id, result); self.results.borrow_mut().insert(id, result);
if let Some(event) = self.events.borrow_mut().remove(&id) {
event.notify(usize::MAX);
}
}
#[inline]
pub async fn listen(&self, id: ThreadId) {
debug_assert!(self.is_tracked(id), "Thread must be tracked");
if !self.results.borrow().contains_key(&id) {
let listener = {
let mut events = self.events.borrow_mut();
let event = events.entry(id).or_insert_with(|| Rc::new(Event::new()));
event.listen()
};
listener.await;
}
} }
pub fn remove(&self, id: ThreadId) -> Option<ThreadResult> { pub fn remove(&self, id: ThreadId) -> Option<ThreadResult> {
let res = self.inner.borrow_mut().remove(&id)?; let res = self.results.borrow_mut().remove(&id)?;
self.tracked.borrow_mut().remove(&id); self.tracked.borrow_mut().remove(&id);
self.events.borrow_mut().remove(&id);
Some(res) Some(res)
} }
} }

View file

@ -214,10 +214,19 @@ impl<'lua> Runtime<'lua> {
Any subsequent calls after this method returns `Some` will return `None`. Any subsequent calls after this method returns `Some` will return `None`.
*/ */
#[must_use] #[must_use]
pub fn thread_result(&self, id: ThreadId) -> Option<LuaResult<LuaMultiValue<'lua>>> { pub fn get_thread_result(&self, id: ThreadId) -> Option<LuaResult<LuaMultiValue<'lua>>> {
self.result_map.remove(id).map(|r| r.value(self.lua)) self.result_map.remove(id).map(|r| r.value(self.lua))
} }
/**
Waits for the [`LuaThread`] with the given [`ThreadId`] to complete.
This will return instantly if the thread has already completed.
*/
pub async fn wait_for_thread(&self, id: ThreadId) {
self.result_map.listen(id).await;
}
/** /**
Runs the runtime until all Lua threads have completed. Runs the runtime until all Lua threads have completed.

View file

@ -9,6 +9,7 @@ use async_executor::{Executor, Task};
use crate::{ use crate::{
queue::{DeferredThreadQueue, FuturesQueue, SpawnedThreadQueue}, queue::{DeferredThreadQueue, FuturesQueue, SpawnedThreadQueue},
result_map::ThreadResultMap,
runtime::Runtime, runtime::Runtime,
thread_id::ThreadId, thread_id::ThreadId,
}; };
@ -93,6 +94,28 @@ pub trait LuaRuntimeExt<'lua> {
args: impl IntoLuaMulti<'lua>, args: impl IntoLuaMulti<'lua>,
) -> LuaResult<ThreadId>; ) -> LuaResult<ThreadId>;
/**
Gets the result of the given thread.
See [`Runtime::get_thread_result`] for more information.
# Panics
Panics if called outside of a running [`Runtime`].
*/
fn get_thread_result(&'lua self, id: ThreadId) -> Option<LuaResult<LuaMultiValue<'lua>>>;
/**
Waits for the given thread to complete.
See [`Runtime::wait_for_thread`] for more information.
# Panics
Panics if called outside of a running [`Runtime`].
*/
fn wait_for_thread(&'lua self, id: ThreadId) -> impl Future<Output = ()>;
/** /**
Spawns the given future on the current executor and returns its [`Task`]. Spawns the given future on the current executor and returns its [`Task`].
@ -198,6 +221,20 @@ impl<'lua> LuaRuntimeExt<'lua> for Lua {
queue.push_item(self, thread, args) queue.push_item(self, thread, args)
} }
fn get_thread_result(&'lua self, id: ThreadId) -> Option<LuaResult<LuaMultiValue<'lua>>> {
let map = self
.app_data_ref::<ThreadResultMap>()
.expect("lua threads results can only be retrieved within a runtime");
map.remove(id).map(|r| r.value(self))
}
fn wait_for_thread(&'lua self, id: ThreadId) -> impl Future<Output = ()> {
let map = self
.app_data_ref::<ThreadResultMap>()
.expect("lua threads results can only be retrieved within a runtime");
async move { map.listen(id).await }
}
fn spawn<T: Send + 'static>(&self, fut: impl Future<Output = T> + Send + 'static) -> Task<T> { fn spawn<T: Send + 'static>(&self, fut: impl Future<Output = T> + Send + 'static) -> Task<T> {
let exec = self let exec = self
.app_data_ref::<WeakArc<Executor>>() .app_data_ref::<WeakArc<Executor>>()