diff --git a/crates/mlua-luau-scheduler/src/lib.rs b/crates/mlua-luau-scheduler/src/lib.rs index 7b82595..8c8b201 100644 --- a/crates/mlua-luau-scheduler/src/lib.rs +++ b/crates/mlua-luau-scheduler/src/lib.rs @@ -4,6 +4,7 @@ mod error_callback; mod exit; mod functions; mod queue; +mod result_event; mod result_map; mod scheduler; mod status; diff --git a/crates/mlua-luau-scheduler/src/result_event.rs b/crates/mlua-luau-scheduler/src/result_event.rs new file mode 100644 index 0000000..77bdd83 --- /dev/null +++ b/crates/mlua-luau-scheduler/src/result_event.rs @@ -0,0 +1,106 @@ +use std::{ + cell::RefCell, + future::Future, + pin::Pin, + rc::Rc, + task::{Context, Poll, Waker}, +}; + +/** + State which is highly optimized for a single notification event. + + `Some` means not notified yet, `None` means notified. +*/ +#[derive(Debug, Default)] +struct OnceEventState { + wakers: RefCell>>, +} + +impl OnceEventState { + fn new() -> Self { + Self { + wakers: RefCell::new(Some(Vec::new())), + } + } +} + +/** + An event that may be notified exactly once. + + May be cheaply cloned. +*/ +#[derive(Debug, Clone, Default)] +pub struct OnceEvent { + state: Rc, +} + +impl OnceEvent { + /** + Creates a new event that can be notified exactly once. + */ + pub fn new() -> Self { + let initial_state = OnceEventState::new(); + Self { + state: Rc::new(initial_state), + } + } + + /** + Notifies waiting listeners. + + This is idempotent; subsequent calls do nothing. + */ + pub fn notify(&self) { + if let Some(wakers) = { self.state.wakers.borrow_mut().take() } { + for waker in wakers { + waker.wake(); + } + } + } + + /** + Creates a listener that implements `Future` and resolves when `notify` is called. + + If `notify` has already been called, the future will resolve immediately. + */ + pub fn listen(&self) -> OnceListener { + OnceListener { + state: self.state.clone(), + } + } +} + +/** + A listener that resolves when the event is notified. + + May be cheaply cloned. + + See [`OnceEvent`] for more information. +*/ +#[derive(Debug)] +pub struct OnceListener { + state: Rc, +} + +impl Future for OnceListener { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut wakers_guard = self.state.wakers.borrow_mut(); + match &mut *wakers_guard { + Some(wakers) => { + // Not yet notified + if !wakers.iter().any(|w| w.will_wake(cx.waker())) { + wakers.push(cx.waker().clone()); + } + Poll::Pending + } + None => { + // Already notified + Poll::Ready(()) + } + } + } +} + +impl Unpin for OnceListener {} diff --git a/crates/mlua-luau-scheduler/src/result_map.rs b/crates/mlua-luau-scheduler/src/result_map.rs index e510131..456c0c6 100644 --- a/crates/mlua-luau-scheduler/src/result_map.rs +++ b/crates/mlua-luau-scheduler/src/result_map.rs @@ -2,79 +2,75 @@ use std::{cell::RefCell, rc::Rc}; -use event_listener::Event; -// NOTE: This is the hash algorithm that mlua also uses, so we -// are not adding any additional dependencies / bloat by using it. use mlua::prelude::*; -use rustc_hash::{FxHashMap, FxHashSet}; +use rustc_hash::FxHashMap; -use crate::thread_id::ThreadId; +use crate::{result_event::OnceEvent, thread_id::ThreadId}; -struct ThreadResultMapInner { - tracked: FxHashSet, - results: FxHashMap>, - events: FxHashMap, +struct ThreadEvent { + result: Option>, + event: OnceEvent, } -impl ThreadResultMapInner { +impl ThreadEvent { fn new() -> Self { Self { - tracked: FxHashSet::default(), - results: FxHashMap::default(), - events: FxHashMap::default(), + result: None, + event: OnceEvent::new(), } } } #[derive(Clone)] pub(crate) struct ThreadResultMap { - inner: Rc>, + inner: Rc>>, } impl ThreadResultMap { pub fn new() -> Self { - let inner = Rc::new(RefCell::new(ThreadResultMapInner::new())); + let inner = Rc::new(RefCell::new(FxHashMap::default())); Self { inner } } #[inline(always)] pub fn track(&self, id: ThreadId) { - let mut inner = self.inner.borrow_mut(); - inner.tracked.insert(id); - inner.events.insert(id, Event::new()); + self.inner.borrow_mut().insert(id, ThreadEvent::new()); } #[inline(always)] pub fn is_tracked(&self, id: ThreadId) -> bool { - self.inner.borrow().tracked.contains(&id) + self.inner.borrow().contains_key(&id) } #[inline(always)] pub fn insert(&self, id: ThreadId, result: LuaResult) { - debug_assert!(self.is_tracked(id), "Thread must be tracked"); - let mut inner = self.inner.borrow_mut(); - inner.results.insert(id, result); - if let Some(event) = inner.events.remove(&id) { - event.notify(usize::MAX); + if let Some(tracker) = self.inner.borrow_mut().get_mut(&id) { + tracker.result.replace(result); + tracker.event.notify(); + } else { + panic!("Thread must be tracked"); } } #[inline(always)] pub async fn listen(&self, id: ThreadId) { - let listener = { + if let Some(listener) = { let inner = self.inner.borrow(); - let event = inner.events.get(&id); - event.map(Event::listen) - }; - listener.expect("Thread must be tracked").await; + let tracker = inner.get(&id); + tracker.map(|t| t.event.listen()) + } { + listener.await; + } else { + panic!("Thread must be tracked"); + } } #[inline(always)] pub fn remove(&self, id: ThreadId) -> Option> { - let mut inner = self.inner.borrow_mut(); - let res = inner.results.remove(&id)?; - inner.tracked.remove(&id); - inner.events.remove(&id); - Some(res) + if let Some(mut tracker) = self.inner.borrow_mut().remove(&id) { + tracker.result.take() + } else { + None + } } }