mirror of
https://github.com/lune-org/mlua-luau-scheduler.git
synced 2025-04-03 18:10:55 +01:00
Add back ability to async wait for threads to complete
This commit is contained in:
parent
ecbd5149f8
commit
93a56e28c5
5 changed files with 74 additions and 8 deletions
|
@ -32,7 +32,7 @@ pub fn main() -> LuaResult<()> {
|
||||||
block_on(rt.run());
|
block_on(rt.run());
|
||||||
|
|
||||||
// We should have gotten the error back from our script
|
// We should have gotten the error back from our script
|
||||||
assert!(rt.thread_result(id).unwrap().is_err());
|
assert!(rt.get_thread_result(id).unwrap().is_err());
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -38,7 +38,7 @@ pub fn main() -> LuaResult<()> {
|
||||||
block_on(rt.run());
|
block_on(rt.run());
|
||||||
|
|
||||||
// We should have gotten proper values back from our script
|
// We should have gotten proper values back from our script
|
||||||
let res = rt.thread_result(id).unwrap().unwrap();
|
let res = rt.get_thread_result(id).unwrap().unwrap();
|
||||||
let nums = Vec::<usize>::from_lua_multi(res, &lua)?;
|
let nums = Vec::<usize>::from_lua_multi(res, &lua)?;
|
||||||
assert_eq!(nums, vec![1, 2, 3, 4, 5, 6]);
|
assert_eq!(nums, vec![1, 2, 3, 4, 5, 6]);
|
||||||
|
|
||||||
|
|
|
@ -2,6 +2,7 @@
|
||||||
|
|
||||||
use std::{cell::RefCell, rc::Rc};
|
use std::{cell::RefCell, rc::Rc};
|
||||||
|
|
||||||
|
use event_listener::Event;
|
||||||
// NOTE: This is the hash algorithm that mlua also uses, so we
|
// NOTE: This is the hash algorithm that mlua also uses, so we
|
||||||
// are not adding any additional dependencies / bloat by using it.
|
// are not adding any additional dependencies / bloat by using it.
|
||||||
use rustc_hash::{FxHashMap, FxHashSet};
|
use rustc_hash::{FxHashMap, FxHashSet};
|
||||||
|
@ -11,14 +12,16 @@ use crate::{thread_id::ThreadId, util::ThreadResult};
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub(crate) struct ThreadResultMap {
|
pub(crate) struct ThreadResultMap {
|
||||||
tracked: Rc<RefCell<FxHashSet<ThreadId>>>,
|
tracked: Rc<RefCell<FxHashSet<ThreadId>>>,
|
||||||
inner: Rc<RefCell<FxHashMap<ThreadId, ThreadResult>>>,
|
results: Rc<RefCell<FxHashMap<ThreadId, ThreadResult>>>,
|
||||||
|
events: Rc<RefCell<FxHashMap<ThreadId, Rc<Event>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ThreadResultMap {
|
impl ThreadResultMap {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
tracked: Rc::new(RefCell::new(FxHashSet::default())),
|
tracked: Rc::new(RefCell::new(FxHashSet::default())),
|
||||||
inner: Rc::new(RefCell::new(FxHashMap::default())),
|
results: Rc::new(RefCell::new(FxHashMap::default())),
|
||||||
|
events: Rc::new(RefCell::new(FxHashMap::default())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -32,15 +35,32 @@ impl ThreadResultMap {
|
||||||
self.tracked.borrow().contains(&id)
|
self.tracked.borrow().contains(&id)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline(always)]
|
#[inline]
|
||||||
pub fn insert(&self, id: ThreadId, result: ThreadResult) {
|
pub fn insert(&self, id: ThreadId, result: ThreadResult) {
|
||||||
debug_assert!(self.is_tracked(id), "Thread must be tracked");
|
debug_assert!(self.is_tracked(id), "Thread must be tracked");
|
||||||
self.inner.borrow_mut().insert(id, result);
|
self.results.borrow_mut().insert(id, result);
|
||||||
|
if let Some(event) = self.events.borrow_mut().remove(&id) {
|
||||||
|
event.notify(usize::MAX);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub async fn listen(&self, id: ThreadId) {
|
||||||
|
debug_assert!(self.is_tracked(id), "Thread must be tracked");
|
||||||
|
if !self.results.borrow().contains_key(&id) {
|
||||||
|
let listener = {
|
||||||
|
let mut events = self.events.borrow_mut();
|
||||||
|
let event = events.entry(id).or_insert_with(|| Rc::new(Event::new()));
|
||||||
|
event.listen()
|
||||||
|
};
|
||||||
|
listener.await;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn remove(&self, id: ThreadId) -> Option<ThreadResult> {
|
pub fn remove(&self, id: ThreadId) -> Option<ThreadResult> {
|
||||||
let res = self.inner.borrow_mut().remove(&id)?;
|
let res = self.results.borrow_mut().remove(&id)?;
|
||||||
self.tracked.borrow_mut().remove(&id);
|
self.tracked.borrow_mut().remove(&id);
|
||||||
|
self.events.borrow_mut().remove(&id);
|
||||||
Some(res)
|
Some(res)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -214,10 +214,19 @@ impl<'lua> Runtime<'lua> {
|
||||||
Any subsequent calls after this method returns `Some` will return `None`.
|
Any subsequent calls after this method returns `Some` will return `None`.
|
||||||
*/
|
*/
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn thread_result(&self, id: ThreadId) -> Option<LuaResult<LuaMultiValue<'lua>>> {
|
pub fn get_thread_result(&self, id: ThreadId) -> Option<LuaResult<LuaMultiValue<'lua>>> {
|
||||||
self.result_map.remove(id).map(|r| r.value(self.lua))
|
self.result_map.remove(id).map(|r| r.value(self.lua))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
Waits for the [`LuaThread`] with the given [`ThreadId`] to complete.
|
||||||
|
|
||||||
|
This will return instantly if the thread has already completed.
|
||||||
|
*/
|
||||||
|
pub async fn wait_for_thread(&self, id: ThreadId) {
|
||||||
|
self.result_map.listen(id).await;
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
Runs the runtime until all Lua threads have completed.
|
Runs the runtime until all Lua threads have completed.
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ use async_executor::{Executor, Task};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
queue::{DeferredThreadQueue, FuturesQueue, SpawnedThreadQueue},
|
queue::{DeferredThreadQueue, FuturesQueue, SpawnedThreadQueue},
|
||||||
|
result_map::ThreadResultMap,
|
||||||
runtime::Runtime,
|
runtime::Runtime,
|
||||||
thread_id::ThreadId,
|
thread_id::ThreadId,
|
||||||
};
|
};
|
||||||
|
@ -93,6 +94,28 @@ pub trait LuaRuntimeExt<'lua> {
|
||||||
args: impl IntoLuaMulti<'lua>,
|
args: impl IntoLuaMulti<'lua>,
|
||||||
) -> LuaResult<ThreadId>;
|
) -> LuaResult<ThreadId>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
Gets the result of the given thread.
|
||||||
|
|
||||||
|
See [`Runtime::get_thread_result`] for more information.
|
||||||
|
|
||||||
|
# Panics
|
||||||
|
|
||||||
|
Panics if called outside of a running [`Runtime`].
|
||||||
|
*/
|
||||||
|
fn get_thread_result(&'lua self, id: ThreadId) -> Option<LuaResult<LuaMultiValue<'lua>>>;
|
||||||
|
|
||||||
|
/**
|
||||||
|
Waits for the given thread to complete.
|
||||||
|
|
||||||
|
See [`Runtime::wait_for_thread`] for more information.
|
||||||
|
|
||||||
|
# Panics
|
||||||
|
|
||||||
|
Panics if called outside of a running [`Runtime`].
|
||||||
|
*/
|
||||||
|
fn wait_for_thread(&'lua self, id: ThreadId) -> impl Future<Output = ()>;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
Spawns the given future on the current executor and returns its [`Task`].
|
Spawns the given future on the current executor and returns its [`Task`].
|
||||||
|
|
||||||
|
@ -198,6 +221,20 @@ impl<'lua> LuaRuntimeExt<'lua> for Lua {
|
||||||
queue.push_item(self, thread, args)
|
queue.push_item(self, thread, args)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_thread_result(&'lua self, id: ThreadId) -> Option<LuaResult<LuaMultiValue<'lua>>> {
|
||||||
|
let map = self
|
||||||
|
.app_data_ref::<ThreadResultMap>()
|
||||||
|
.expect("lua threads results can only be retrieved within a runtime");
|
||||||
|
map.remove(id).map(|r| r.value(self))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn wait_for_thread(&'lua self, id: ThreadId) -> impl Future<Output = ()> {
|
||||||
|
let map = self
|
||||||
|
.app_data_ref::<ThreadResultMap>()
|
||||||
|
.expect("lua threads results can only be retrieved within a runtime");
|
||||||
|
async move { map.listen(id).await }
|
||||||
|
}
|
||||||
|
|
||||||
fn spawn<T: Send + 'static>(&self, fut: impl Future<Output = T> + Send + 'static) -> Task<T> {
|
fn spawn<T: Send + 'static>(&self, fut: impl Future<Output = T> + Send + 'static) -> Task<T> {
|
||||||
let exec = self
|
let exec = self
|
||||||
.app_data_ref::<WeakArc<Executor>>()
|
.app_data_ref::<WeakArc<Executor>>()
|
||||||
|
|
Loading…
Add table
Reference in a new issue