diff --git a/crates/lune-std/src/globals/require/context.rs b/crates/lune-std/src/globals/require/context.rs index 87550f2..9199b8f 100644 --- a/crates/lune-std/src/globals/require/context.rs +++ b/crates/lune-std/src/globals/require/context.rs @@ -89,44 +89,57 @@ impl RequireContext { Ok(multi) } + pub(crate) async fn wait_for_pending<'lua>( + lua: &'lua Lua, + path_abs: &'_ PathBuf, + ) -> Result<(), RequireError> { + let data_ref = lua + .app_data_ref::() + .ok_or(RequireError::RequireContextNotFound)?; + + let pending = data_ref.pending.try_lock()?; + + if let Some(sender) = pending.get(path_abs) { + let mut receiver = sender.subscribe(); + + // unlock mutex before using async + drop(pending); + + receiver.recv().await?; + } + + Ok(()) + } + + pub(crate) async fn from_cache<'lua>( + lua: &'lua Lua, + path_abs: &'_ PathBuf, + ) -> Result>, RequireError> { + let data_ref = lua + .app_data_ref::() + .ok_or(RequireError::RequireContextNotFound)?; + + let cache = data_ref.cache.lock().await; + + match cache.get(path_abs) { + Some(cached) => { + let multi_vec = lua.registry_value::>(cached)?; + + Ok(Some(LuaMultiValue::from_vec(multi_vec))) + } + None => Ok(None), + } + } + pub(crate) async fn require( lua: &Lua, path_rel: PathBuf, path_abs: PathBuf, ) -> Result { - // wait for module to be required - // if its pending somewhere else - { - let data_ref = lua - .app_data_ref::() - .ok_or(RequireError::RequireContextNotFound)?; + Self::wait_for_pending(lua, &path_abs).await?; - let pending = data_ref.pending.try_lock()?; - - if let Some(sender) = pending.get(&path_abs) { - let mut receiver = sender.subscribe(); - - // unlock mutex before using async - drop(pending); - - receiver.recv().await?; - } - } - - // get module from cache - // *if* its cached - { - let data_ref = lua - .app_data_ref::() - .ok_or(RequireError::RequireContextNotFound)?; - - let cache = data_ref.cache.lock().await; - - if let Some(cached) = cache.get(&path_abs) { - let multi_vec = lua.registry_value::>(cached)?; - - return Ok(LuaMultiValue::from_vec(multi_vec)); - } + if let Some(cached) = Self::from_cache(lua, &path_abs).await? { + return Ok(cached); } // create a broadcast channel @@ -135,7 +148,7 @@ impl RequireContext { .app_data_ref::() .ok_or(RequireError::RequireContextNotFound)?; - let (broadcast_tx, _) = broadcast::channel(1); + let broadcast_tx = broadcast::Sender::new(1); { let mut pending = data_ref.pending.try_lock()?; @@ -162,8 +175,7 @@ impl RequireContext { .get_thread_result(thread_id) .ok_or(RequireError::ThreadReturnedNone)??; - let mutli_clone = multi.clone(); - let multi_reg = lua.create_registry_value(mutli_clone.into_vec())?; + let multi_reg = lua.create_registry_value(multi.into_vec())?; let data_ref = lua .app_data_ref::() @@ -184,7 +196,12 @@ impl RequireContext { broadcast_tx.send(()).ok(); - Ok(multi) + match Self::from_cache(lua, &path_abs).await? { + Some(cached) => Ok(cached), + None => Err(RequireError::CacheNotFound( + path_rel.to_string_lossy().to_string(), + )), + } } /** diff --git a/crates/lune-std/src/globals/require/mod.rs b/crates/lune-std/src/globals/require/mod.rs index ca6d339..0d2735c 100644 --- a/crates/lune-std/src/globals/require/mod.rs +++ b/crates/lune-std/src/globals/require/mod.rs @@ -25,6 +25,8 @@ pub enum RequireError { StdMemberNotFound(String, String), #[error("Thread result returned none")] ThreadReturnedNone, + #[error("Could not get '{0}' from cache")] + CacheNotFound(String), #[error("IOError: {0}")] IOError(#[from] std::io::Error),