mirror of
https://github.com/lune-org/mlua-luau-scheduler.git
synced 2025-04-03 18:10:55 +01:00
Add back ability to async wait for threads to complete
This commit is contained in:
parent
ecbd5149f8
commit
93a56e28c5
5 changed files with 74 additions and 8 deletions
|
@ -32,7 +32,7 @@ pub fn main() -> LuaResult<()> {
|
|||
block_on(rt.run());
|
||||
|
||||
// We should have gotten the error back from our script
|
||||
assert!(rt.thread_result(id).unwrap().is_err());
|
||||
assert!(rt.get_thread_result(id).unwrap().is_err());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -38,7 +38,7 @@ pub fn main() -> LuaResult<()> {
|
|||
block_on(rt.run());
|
||||
|
||||
// We should have gotten proper values back from our script
|
||||
let res = rt.thread_result(id).unwrap().unwrap();
|
||||
let res = rt.get_thread_result(id).unwrap().unwrap();
|
||||
let nums = Vec::<usize>::from_lua_multi(res, &lua)?;
|
||||
assert_eq!(nums, vec![1, 2, 3, 4, 5, 6]);
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
use std::{cell::RefCell, rc::Rc};
|
||||
|
||||
use event_listener::Event;
|
||||
// NOTE: This is the hash algorithm that mlua also uses, so we
|
||||
// are not adding any additional dependencies / bloat by using it.
|
||||
use rustc_hash::{FxHashMap, FxHashSet};
|
||||
|
@ -11,14 +12,16 @@ use crate::{thread_id::ThreadId, util::ThreadResult};
|
|||
#[derive(Clone)]
|
||||
pub(crate) struct ThreadResultMap {
|
||||
tracked: Rc<RefCell<FxHashSet<ThreadId>>>,
|
||||
inner: Rc<RefCell<FxHashMap<ThreadId, ThreadResult>>>,
|
||||
results: Rc<RefCell<FxHashMap<ThreadId, ThreadResult>>>,
|
||||
events: Rc<RefCell<FxHashMap<ThreadId, Rc<Event>>>>,
|
||||
}
|
||||
|
||||
impl ThreadResultMap {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
tracked: Rc::new(RefCell::new(FxHashSet::default())),
|
||||
inner: Rc::new(RefCell::new(FxHashMap::default())),
|
||||
results: Rc::new(RefCell::new(FxHashMap::default())),
|
||||
events: Rc::new(RefCell::new(FxHashMap::default())),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -32,15 +35,32 @@ impl ThreadResultMap {
|
|||
self.tracked.borrow().contains(&id)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
#[inline]
|
||||
pub fn insert(&self, id: ThreadId, result: ThreadResult) {
|
||||
debug_assert!(self.is_tracked(id), "Thread must be tracked");
|
||||
self.inner.borrow_mut().insert(id, result);
|
||||
self.results.borrow_mut().insert(id, result);
|
||||
if let Some(event) = self.events.borrow_mut().remove(&id) {
|
||||
event.notify(usize::MAX);
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub async fn listen(&self, id: ThreadId) {
|
||||
debug_assert!(self.is_tracked(id), "Thread must be tracked");
|
||||
if !self.results.borrow().contains_key(&id) {
|
||||
let listener = {
|
||||
let mut events = self.events.borrow_mut();
|
||||
let event = events.entry(id).or_insert_with(|| Rc::new(Event::new()));
|
||||
event.listen()
|
||||
};
|
||||
listener.await;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn remove(&self, id: ThreadId) -> Option<ThreadResult> {
|
||||
let res = self.inner.borrow_mut().remove(&id)?;
|
||||
let res = self.results.borrow_mut().remove(&id)?;
|
||||
self.tracked.borrow_mut().remove(&id);
|
||||
self.events.borrow_mut().remove(&id);
|
||||
Some(res)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -214,10 +214,19 @@ impl<'lua> Runtime<'lua> {
|
|||
Any subsequent calls after this method returns `Some` will return `None`.
|
||||
*/
|
||||
#[must_use]
|
||||
pub fn thread_result(&self, id: ThreadId) -> Option<LuaResult<LuaMultiValue<'lua>>> {
|
||||
pub fn get_thread_result(&self, id: ThreadId) -> Option<LuaResult<LuaMultiValue<'lua>>> {
|
||||
self.result_map.remove(id).map(|r| r.value(self.lua))
|
||||
}
|
||||
|
||||
/**
|
||||
Waits for the [`LuaThread`] with the given [`ThreadId`] to complete.
|
||||
|
||||
This will return instantly if the thread has already completed.
|
||||
*/
|
||||
pub async fn wait_for_thread(&self, id: ThreadId) {
|
||||
self.result_map.listen(id).await;
|
||||
}
|
||||
|
||||
/**
|
||||
Runs the runtime until all Lua threads have completed.
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ use async_executor::{Executor, Task};
|
|||
|
||||
use crate::{
|
||||
queue::{DeferredThreadQueue, FuturesQueue, SpawnedThreadQueue},
|
||||
result_map::ThreadResultMap,
|
||||
runtime::Runtime,
|
||||
thread_id::ThreadId,
|
||||
};
|
||||
|
@ -93,6 +94,28 @@ pub trait LuaRuntimeExt<'lua> {
|
|||
args: impl IntoLuaMulti<'lua>,
|
||||
) -> LuaResult<ThreadId>;
|
||||
|
||||
/**
|
||||
Gets the result of the given thread.
|
||||
|
||||
See [`Runtime::get_thread_result`] for more information.
|
||||
|
||||
# Panics
|
||||
|
||||
Panics if called outside of a running [`Runtime`].
|
||||
*/
|
||||
fn get_thread_result(&'lua self, id: ThreadId) -> Option<LuaResult<LuaMultiValue<'lua>>>;
|
||||
|
||||
/**
|
||||
Waits for the given thread to complete.
|
||||
|
||||
See [`Runtime::wait_for_thread`] for more information.
|
||||
|
||||
# Panics
|
||||
|
||||
Panics if called outside of a running [`Runtime`].
|
||||
*/
|
||||
fn wait_for_thread(&'lua self, id: ThreadId) -> impl Future<Output = ()>;
|
||||
|
||||
/**
|
||||
Spawns the given future on the current executor and returns its [`Task`].
|
||||
|
||||
|
@ -198,6 +221,20 @@ impl<'lua> LuaRuntimeExt<'lua> for Lua {
|
|||
queue.push_item(self, thread, args)
|
||||
}
|
||||
|
||||
fn get_thread_result(&'lua self, id: ThreadId) -> Option<LuaResult<LuaMultiValue<'lua>>> {
|
||||
let map = self
|
||||
.app_data_ref::<ThreadResultMap>()
|
||||
.expect("lua threads results can only be retrieved within a runtime");
|
||||
map.remove(id).map(|r| r.value(self))
|
||||
}
|
||||
|
||||
fn wait_for_thread(&'lua self, id: ThreadId) -> impl Future<Output = ()> {
|
||||
let map = self
|
||||
.app_data_ref::<ThreadResultMap>()
|
||||
.expect("lua threads results can only be retrieved within a runtime");
|
||||
async move { map.listen(id).await }
|
||||
}
|
||||
|
||||
fn spawn<T: Send + 'static>(&self, fut: impl Future<Output = T> + Send + 'static) -> Task<T> {
|
||||
let exec = self
|
||||
.app_data_ref::<WeakArc<Executor>>()
|
||||
|
|
Loading…
Add table
Reference in a new issue