split require into functions

This commit is contained in:
highflowey 2024-08-23 18:31:39 +03:30
parent 3b56b4159d
commit 42873f6383
2 changed files with 55 additions and 36 deletions

View file

@ -89,44 +89,57 @@ impl RequireContext {
Ok(multi)
}
pub(crate) async fn wait_for_pending<'lua>(
lua: &'lua Lua,
path_abs: &'_ PathBuf,
) -> Result<(), RequireError> {
let data_ref = lua
.app_data_ref::<RequireContextData>()
.ok_or(RequireError::RequireContextNotFound)?;
let pending = data_ref.pending.try_lock()?;
if let Some(sender) = pending.get(path_abs) {
let mut receiver = sender.subscribe();
// unlock mutex before using async
drop(pending);
receiver.recv().await?;
}
Ok(())
}
pub(crate) async fn from_cache<'lua>(
lua: &'lua Lua,
path_abs: &'_ PathBuf,
) -> Result<Option<LuaMultiValue<'lua>>, RequireError> {
let data_ref = lua
.app_data_ref::<RequireContextData>()
.ok_or(RequireError::RequireContextNotFound)?;
let cache = data_ref.cache.lock().await;
match cache.get(path_abs) {
Some(cached) => {
let multi_vec = lua.registry_value::<Vec<LuaValue>>(cached)?;
Ok(Some(LuaMultiValue::from_vec(multi_vec)))
}
None => Ok(None),
}
}
pub(crate) async fn require(
lua: &Lua,
path_rel: PathBuf,
path_abs: PathBuf,
) -> Result<LuaMultiValue, RequireError> {
// wait for module to be required
// if its pending somewhere else
{
let data_ref = lua
.app_data_ref::<RequireContextData>()
.ok_or(RequireError::RequireContextNotFound)?;
Self::wait_for_pending(lua, &path_abs).await?;
let pending = data_ref.pending.try_lock()?;
if let Some(sender) = pending.get(&path_abs) {
let mut receiver = sender.subscribe();
// unlock mutex before using async
drop(pending);
receiver.recv().await?;
}
}
// get module from cache
// *if* its cached
{
let data_ref = lua
.app_data_ref::<RequireContextData>()
.ok_or(RequireError::RequireContextNotFound)?;
let cache = data_ref.cache.lock().await;
if let Some(cached) = cache.get(&path_abs) {
let multi_vec = lua.registry_value::<Vec<LuaValue>>(cached)?;
return Ok(LuaMultiValue::from_vec(multi_vec));
}
if let Some(cached) = Self::from_cache(lua, &path_abs).await? {
return Ok(cached);
}
// create a broadcast channel
@ -135,7 +148,7 @@ impl RequireContext {
.app_data_ref::<RequireContextData>()
.ok_or(RequireError::RequireContextNotFound)?;
let (broadcast_tx, _) = broadcast::channel(1);
let broadcast_tx = broadcast::Sender::new(1);
{
let mut pending = data_ref.pending.try_lock()?;
@ -162,8 +175,7 @@ impl RequireContext {
.get_thread_result(thread_id)
.ok_or(RequireError::ThreadReturnedNone)??;
let mutli_clone = multi.clone();
let multi_reg = lua.create_registry_value(mutli_clone.into_vec())?;
let multi_reg = lua.create_registry_value(multi.into_vec())?;
let data_ref = lua
.app_data_ref::<RequireContextData>()
@ -184,7 +196,12 @@ impl RequireContext {
broadcast_tx.send(()).ok();
Ok(multi)
match Self::from_cache(lua, &path_abs).await? {
Some(cached) => Ok(cached),
None => Err(RequireError::CacheNotFound(
path_rel.to_string_lossy().to_string(),
)),
}
}
/**

View file

@ -25,6 +25,8 @@ pub enum RequireError {
StdMemberNotFound(String, String),
#[error("Thread result returned none")]
ThreadReturnedNone,
#[error("Could not get '{0}' from cache")]
CacheNotFound(String),
#[error("IOError: {0}")]
IOError(#[from] std::io::Error),