Add back ability to async wait for threads to complete

This commit is contained in:
Filip Tibell 2024-02-01 11:11:17 +01:00
parent ecbd5149f8
commit 93a56e28c5
No known key found for this signature in database
5 changed files with 74 additions and 8 deletions

View file

@ -32,7 +32,7 @@ pub fn main() -> LuaResult<()> {
block_on(rt.run());
// We should have gotten the error back from our script
assert!(rt.thread_result(id).unwrap().is_err());
assert!(rt.get_thread_result(id).unwrap().is_err());
Ok(())
}

View file

@ -38,7 +38,7 @@ pub fn main() -> LuaResult<()> {
block_on(rt.run());
// We should have gotten proper values back from our script
let res = rt.thread_result(id).unwrap().unwrap();
let res = rt.get_thread_result(id).unwrap().unwrap();
let nums = Vec::<usize>::from_lua_multi(res, &lua)?;
assert_eq!(nums, vec![1, 2, 3, 4, 5, 6]);

View file

@ -2,6 +2,7 @@
use std::{cell::RefCell, rc::Rc};
use event_listener::Event;
// NOTE: This is the hash algorithm that mlua also uses, so we
// are not adding any additional dependencies / bloat by using it.
use rustc_hash::{FxHashMap, FxHashSet};
@ -11,14 +12,16 @@ use crate::{thread_id::ThreadId, util::ThreadResult};
#[derive(Clone)]
pub(crate) struct ThreadResultMap {
tracked: Rc<RefCell<FxHashSet<ThreadId>>>,
inner: Rc<RefCell<FxHashMap<ThreadId, ThreadResult>>>,
results: Rc<RefCell<FxHashMap<ThreadId, ThreadResult>>>,
events: Rc<RefCell<FxHashMap<ThreadId, Rc<Event>>>>,
}
impl ThreadResultMap {
pub fn new() -> Self {
Self {
tracked: Rc::new(RefCell::new(FxHashSet::default())),
inner: Rc::new(RefCell::new(FxHashMap::default())),
results: Rc::new(RefCell::new(FxHashMap::default())),
events: Rc::new(RefCell::new(FxHashMap::default())),
}
}
@ -32,15 +35,32 @@ impl ThreadResultMap {
self.tracked.borrow().contains(&id)
}
#[inline(always)]
#[inline]
pub fn insert(&self, id: ThreadId, result: ThreadResult) {
debug_assert!(self.is_tracked(id), "Thread must be tracked");
self.inner.borrow_mut().insert(id, result);
self.results.borrow_mut().insert(id, result);
if let Some(event) = self.events.borrow_mut().remove(&id) {
event.notify(usize::MAX);
}
}
#[inline]
pub async fn listen(&self, id: ThreadId) {
debug_assert!(self.is_tracked(id), "Thread must be tracked");
if !self.results.borrow().contains_key(&id) {
let listener = {
let mut events = self.events.borrow_mut();
let event = events.entry(id).or_insert_with(|| Rc::new(Event::new()));
event.listen()
};
listener.await;
}
}
pub fn remove(&self, id: ThreadId) -> Option<ThreadResult> {
let res = self.inner.borrow_mut().remove(&id)?;
let res = self.results.borrow_mut().remove(&id)?;
self.tracked.borrow_mut().remove(&id);
self.events.borrow_mut().remove(&id);
Some(res)
}
}

View file

@ -214,10 +214,19 @@ impl<'lua> Runtime<'lua> {
Any subsequent calls after this method returns `Some` will return `None`.
*/
#[must_use]
pub fn thread_result(&self, id: ThreadId) -> Option<LuaResult<LuaMultiValue<'lua>>> {
pub fn get_thread_result(&self, id: ThreadId) -> Option<LuaResult<LuaMultiValue<'lua>>> {
self.result_map.remove(id).map(|r| r.value(self.lua))
}
/**
Waits for the [`LuaThread`] with the given [`ThreadId`] to complete.
This will return instantly if the thread has already completed.
*/
pub async fn wait_for_thread(&self, id: ThreadId) {
self.result_map.listen(id).await;
}
/**
Runs the runtime until all Lua threads have completed.

View file

@ -9,6 +9,7 @@ use async_executor::{Executor, Task};
use crate::{
queue::{DeferredThreadQueue, FuturesQueue, SpawnedThreadQueue},
result_map::ThreadResultMap,
runtime::Runtime,
thread_id::ThreadId,
};
@ -93,6 +94,28 @@ pub trait LuaRuntimeExt<'lua> {
args: impl IntoLuaMulti<'lua>,
) -> LuaResult<ThreadId>;
/**
Gets the result of the given thread.
See [`Runtime::get_thread_result`] for more information.
# Panics
Panics if called outside of a running [`Runtime`].
*/
fn get_thread_result(&'lua self, id: ThreadId) -> Option<LuaResult<LuaMultiValue<'lua>>>;
/**
Waits for the given thread to complete.
See [`Runtime::wait_for_thread`] for more information.
# Panics
Panics if called outside of a running [`Runtime`].
*/
fn wait_for_thread(&'lua self, id: ThreadId) -> impl Future<Output = ()>;
/**
Spawns the given future on the current executor and returns its [`Task`].
@ -198,6 +221,20 @@ impl<'lua> LuaRuntimeExt<'lua> for Lua {
queue.push_item(self, thread, args)
}
fn get_thread_result(&'lua self, id: ThreadId) -> Option<LuaResult<LuaMultiValue<'lua>>> {
let map = self
.app_data_ref::<ThreadResultMap>()
.expect("lua threads results can only be retrieved within a runtime");
map.remove(id).map(|r| r.value(self))
}
fn wait_for_thread(&'lua self, id: ThreadId) -> impl Future<Output = ()> {
let map = self
.app_data_ref::<ThreadResultMap>()
.expect("lua threads results can only be retrieved within a runtime");
async move { map.listen(id).await }
}
fn spawn<T: Send + 'static>(&self, fut: impl Future<Output = T> + Send + 'static) -> Task<T> {
let exec = self
.app_data_ref::<WeakArc<Executor>>()