diff --git a/crates/mlua-luau-scheduler/src/result_map.rs b/crates/mlua-luau-scheduler/src/result_map.rs index 39a91b1..e510131 100644 --- a/crates/mlua-luau-scheduler/src/result_map.rs +++ b/crates/mlua-luau-scheduler/src/result_map.rs @@ -13,7 +13,7 @@ use crate::thread_id::ThreadId; struct ThreadResultMapInner { tracked: FxHashSet, results: FxHashMap>, - events: FxHashMap>, + events: FxHashMap, } impl ThreadResultMapInner { @@ -39,7 +39,9 @@ impl ThreadResultMap { #[inline(always)] pub fn track(&self, id: ThreadId) { - self.inner.borrow_mut().tracked.insert(id); + let mut inner = self.inner.borrow_mut(); + inner.tracked.insert(id); + inner.events.insert(id, Event::new()); } #[inline(always)] @@ -47,6 +49,7 @@ impl ThreadResultMap { self.inner.borrow().tracked.contains(&id) } + #[inline(always)] pub fn insert(&self, id: ThreadId, result: LuaResult) { debug_assert!(self.is_tracked(id), "Thread must be tracked"); let mut inner = self.inner.borrow_mut(); @@ -56,21 +59,17 @@ impl ThreadResultMap { } } + #[inline(always)] pub async fn listen(&self, id: ThreadId) { - debug_assert!(self.is_tracked(id), "Thread must be tracked"); - if !self.inner.borrow().results.contains_key(&id) { - let listener = { - let mut inner = self.inner.borrow_mut(); - let event = inner - .events - .entry(id) - .or_insert_with(|| Rc::new(Event::new())); - event.listen() - }; - listener.await; - } + let listener = { + let inner = self.inner.borrow(); + let event = inner.events.get(&id); + event.map(Event::listen) + }; + listener.expect("Thread must be tracked").await; } + #[inline(always)] pub fn remove(&self, id: ThreadId) -> Option> { let mut inner = self.inner.borrow_mut(); let res = inner.results.remove(&id)?; diff --git a/crates/mlua-luau-scheduler/src/thread_id.rs b/crates/mlua-luau-scheduler/src/thread_id.rs index 89eb126..ba31a10 100644 --- a/crates/mlua-luau-scheduler/src/thread_id.rs +++ b/crates/mlua-luau-scheduler/src/thread_id.rs @@ -1,4 +1,7 @@ -use std::hash::{Hash, Hasher}; +use std::{ + ffi::c_void, + hash::{Hash, Hasher}, +}; use mlua::prelude::*; @@ -12,13 +15,13 @@ use mlua::prelude::*; */ #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct ThreadId { - inner: usize, + inner: *const c_void, } impl From<&LuaThread> for ThreadId { fn from(thread: &LuaThread) -> Self { Self { - inner: thread.to_pointer() as usize, + inner: thread.to_pointer(), } } }