Optimize tracking of thread results in mlua-luau-scheduler

This commit is contained in:
Filip Tibell 2025-04-30 13:35:42 +02:00
parent 3e80a0a1c4
commit b57fa6fad3
No known key found for this signature in database
3 changed files with 137 additions and 34 deletions

View file

@ -4,6 +4,7 @@ mod error_callback;
mod exit; mod exit;
mod functions; mod functions;
mod queue; mod queue;
mod result_event;
mod result_map; mod result_map;
mod scheduler; mod scheduler;
mod status; mod status;

View file

@ -0,0 +1,106 @@
use std::{
cell::RefCell,
future::Future,
pin::Pin,
rc::Rc,
task::{Context, Poll, Waker},
};
/**
State which is highly optimized for a single notification event.
`Some` means not notified yet, `None` means notified.
*/
#[derive(Debug, Default)]
struct OnceEventState {
wakers: RefCell<Option<Vec<Waker>>>,
}
impl OnceEventState {
fn new() -> Self {
Self {
wakers: RefCell::new(Some(Vec::new())),
}
}
}
/**
An event that may be notified exactly once.
May be cheaply cloned.
*/
#[derive(Debug, Clone, Default)]
pub struct OnceEvent {
state: Rc<OnceEventState>,
}
impl OnceEvent {
/**
Creates a new event that can be notified exactly once.
*/
pub fn new() -> Self {
let initial_state = OnceEventState::new();
Self {
state: Rc::new(initial_state),
}
}
/**
Notifies waiting listeners.
This is idempotent; subsequent calls do nothing.
*/
pub fn notify(&self) {
if let Some(wakers) = { self.state.wakers.borrow_mut().take() } {
for waker in wakers {
waker.wake();
}
}
}
/**
Creates a listener that implements `Future` and resolves when `notify` is called.
If `notify` has already been called, the future will resolve immediately.
*/
pub fn listen(&self) -> OnceListener {
OnceListener {
state: self.state.clone(),
}
}
}
/**
A listener that resolves when the event is notified.
May be cheaply cloned.
See [`OnceEvent`] for more information.
*/
#[derive(Debug)]
pub struct OnceListener {
state: Rc<OnceEventState>,
}
impl Future for OnceListener {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut wakers_guard = self.state.wakers.borrow_mut();
match &mut *wakers_guard {
Some(wakers) => {
// Not yet notified
if !wakers.iter().any(|w| w.will_wake(cx.waker())) {
wakers.push(cx.waker().clone());
}
Poll::Pending
}
None => {
// Already notified
Poll::Ready(())
}
}
}
}
impl Unpin for OnceListener {}

View file

@ -2,79 +2,75 @@
use std::{cell::RefCell, rc::Rc}; use std::{cell::RefCell, rc::Rc};
use event_listener::Event;
// NOTE: This is the hash algorithm that mlua also uses, so we
// are not adding any additional dependencies / bloat by using it.
use mlua::prelude::*; use mlua::prelude::*;
use rustc_hash::{FxHashMap, FxHashSet}; use rustc_hash::FxHashMap;
use crate::thread_id::ThreadId; use crate::{result_event::OnceEvent, thread_id::ThreadId};
struct ThreadResultMapInner { struct ThreadEvent {
tracked: FxHashSet<ThreadId>, result: Option<LuaResult<LuaMultiValue>>,
results: FxHashMap<ThreadId, LuaResult<LuaMultiValue>>, event: OnceEvent,
events: FxHashMap<ThreadId, Event>,
} }
impl ThreadResultMapInner { impl ThreadEvent {
fn new() -> Self { fn new() -> Self {
Self { Self {
tracked: FxHashSet::default(), result: None,
results: FxHashMap::default(), event: OnceEvent::new(),
events: FxHashMap::default(),
} }
} }
} }
#[derive(Clone)] #[derive(Clone)]
pub(crate) struct ThreadResultMap { pub(crate) struct ThreadResultMap {
inner: Rc<RefCell<ThreadResultMapInner>>, inner: Rc<RefCell<FxHashMap<ThreadId, ThreadEvent>>>,
} }
impl ThreadResultMap { impl ThreadResultMap {
pub fn new() -> Self { pub fn new() -> Self {
let inner = Rc::new(RefCell::new(ThreadResultMapInner::new())); let inner = Rc::new(RefCell::new(FxHashMap::default()));
Self { inner } Self { inner }
} }
#[inline(always)] #[inline(always)]
pub fn track(&self, id: ThreadId) { pub fn track(&self, id: ThreadId) {
let mut inner = self.inner.borrow_mut(); self.inner.borrow_mut().insert(id, ThreadEvent::new());
inner.tracked.insert(id);
inner.events.insert(id, Event::new());
} }
#[inline(always)] #[inline(always)]
pub fn is_tracked(&self, id: ThreadId) -> bool { pub fn is_tracked(&self, id: ThreadId) -> bool {
self.inner.borrow().tracked.contains(&id) self.inner.borrow().contains_key(&id)
} }
#[inline(always)] #[inline(always)]
pub fn insert(&self, id: ThreadId, result: LuaResult<LuaMultiValue>) { pub fn insert(&self, id: ThreadId, result: LuaResult<LuaMultiValue>) {
debug_assert!(self.is_tracked(id), "Thread must be tracked"); if let Some(tracker) = self.inner.borrow_mut().get_mut(&id) {
let mut inner = self.inner.borrow_mut(); tracker.result.replace(result);
inner.results.insert(id, result); tracker.event.notify();
if let Some(event) = inner.events.remove(&id) { } else {
event.notify(usize::MAX); panic!("Thread must be tracked");
} }
} }
#[inline(always)] #[inline(always)]
pub async fn listen(&self, id: ThreadId) { pub async fn listen(&self, id: ThreadId) {
let listener = { if let Some(listener) = {
let inner = self.inner.borrow(); let inner = self.inner.borrow();
let event = inner.events.get(&id); let tracker = inner.get(&id);
event.map(Event::listen) tracker.map(|t| t.event.listen())
}; } {
listener.expect("Thread must be tracked").await; listener.await;
} else {
panic!("Thread must be tracked");
}
} }
#[inline(always)] #[inline(always)]
pub fn remove(&self, id: ThreadId) -> Option<LuaResult<LuaMultiValue>> { pub fn remove(&self, id: ThreadId) -> Option<LuaResult<LuaMultiValue>> {
let mut inner = self.inner.borrow_mut(); if let Some(mut tracker) = self.inner.borrow_mut().remove(&id) {
let res = inner.results.remove(&id)?; tracker.result.take()
inner.tracked.remove(&id); } else {
inner.events.remove(&id); None
Some(res) }
} }
} }