From 981d3235566b76672bf5ef4f883e08f01e7eaa4c Mon Sep 17 00:00:00 2001 From: AshleyFlow Date: Thu, 17 Oct 2024 13:47:53 +0330 Subject: [PATCH] eliminate unnecessary checks for files in require --- .../lune-std/src/globals/require/context.rs | 86 +++++++++++-------- crates/lune-std/src/globals/require/mod.rs | 63 ++++++++++---- crates/lune-std/src/globals/require/path.rs | 72 +--------------- 3 files changed, 100 insertions(+), 121 deletions(-) diff --git a/crates/lune-std/src/globals/require/context.rs b/crates/lune-std/src/globals/require/context.rs index 0688e44..ceaeee3 100644 --- a/crates/lune-std/src/globals/require/context.rs +++ b/crates/lune-std/src/globals/require/context.rs @@ -43,7 +43,7 @@ impl RequireContext { pub(crate) fn std_exists(lua: &Lua, alias: &str) -> Result { let data_ref = lua .app_data_ref::() - .ok_or(RequireError::RequireContextNotFound)?; + .ok_or_else(|| RequireError::RequireContextNotFound)?; Ok(data_ref.std.contains_key(alias)) } @@ -54,7 +54,7 @@ impl RequireContext { ) -> Result, RequireError> { let data_ref = lua .app_data_ref::() - .ok_or(RequireError::RequireContextNotFound)?; + .ok_or_else(|| RequireError::RequireContextNotFound)?; if let Some(cached) = data_ref.std_cache.get(&require_alias) { let multi_vec = lua.registry_value::>(cached)?; @@ -62,17 +62,17 @@ impl RequireContext { return Ok(LuaMultiValue::from_vec(multi_vec)); } - let libraries = data_ref.std.get(&require_alias.alias.as_str()).ok_or( - RequireError::InvalidStdAlias(require_alias.alias.to_string()), - )?; + let libraries = data_ref + .std + .get(&require_alias.alias.as_str()) + .ok_or_else(|| RequireError::InvalidStdAlias(require_alias.alias.to_string()))?; - let std = - libraries - .get(require_alias.path.as_str()) - .ok_or(RequireError::StdMemberNotFound( - require_alias.path.to_string(), - require_alias.alias.to_string(), - ))?; + let std = libraries.get(require_alias.path.as_str()).ok_or_else(|| { + RequireError::StdMemberNotFound( + require_alias.path.to_string(), + require_alias.alias.to_string(), + ) + })?; let multi = std.module(lua)?; let mutli_clone = multi.clone(); @@ -82,7 +82,7 @@ impl RequireContext { let mut data = lua .app_data_mut::() - .ok_or(RequireError::RequireContextNotFound)?; + .ok_or_else(|| RequireError::RequireContextNotFound)?; data.std_cache.insert(require_alias, multi_reg); @@ -95,7 +95,7 @@ impl RequireContext { ) -> Result<(), RequireError> { let data_ref = lua .app_data_ref::() - .ok_or(RequireError::RequireContextNotFound)?; + .ok_or_else(|| RequireError::RequireContextNotFound)?; let pending = data_ref.pending.try_lock()?; @@ -111,13 +111,33 @@ impl RequireContext { Ok(()) } + fn is_pending(lua: &Lua, path_abs: &PathBuf) -> Result { + let data_ref = lua + .app_data_ref::() + .ok_or_else(|| RequireError::RequireContextNotFound)?; + + let pending = data_ref.pending.try_lock()?; + + Ok(pending.get(path_abs).is_some()) + } + + fn is_cached(lua: &Lua, path_abs: &PathBuf) -> Result { + let data_ref = lua + .app_data_ref::() + .ok_or_else(|| RequireError::RequireContextNotFound)?; + + let cache = data_ref.cache.try_lock()?; + + Ok(cache.get(path_abs).is_some()) + } + async fn from_cache<'lua>( lua: &'lua Lua, path_abs: &'_ PathBuf, - ) -> Result>, RequireError> { + ) -> Result, RequireError> { let data_ref = lua .app_data_ref::() - .ok_or(RequireError::RequireContextNotFound)?; + .ok_or_else(|| RequireError::RequireContextNotFound)?; let cache = data_ref.cache.lock().await; @@ -125,28 +145,32 @@ impl RequireContext { Some(cached) => { let multi_vec = lua.registry_value::>(cached)?; - Ok(Some(LuaMultiValue::from_vec(multi_vec))) + Ok(LuaMultiValue::from_vec(multi_vec)) } - None => Ok(None), + None => Err(RequireError::CacheNotFound( + path_abs.to_string_lossy().to_string(), + )), } } pub(crate) async fn require( lua: &Lua, - path_rel: PathBuf, path_abs: PathBuf, ) -> Result { - Self::wait_for_pending(lua, &path_abs).await?; - - if let Some(cached) = Self::from_cache(lua, &path_abs).await? { - return Ok(cached); + if Self::is_pending(lua, &path_abs)? { + Self::wait_for_pending(lua, &path_abs).await?; + return Self::from_cache(lua, &path_abs).await; + } else if Self::is_cached(lua, &path_abs)? { + return Self::from_cache(lua, &path_abs).await; } + let content = fs::read_to_string(&path_abs).await?; + // create a broadcast channel { let data_ref = lua .app_data_ref::() - .ok_or(RequireError::RequireContextNotFound)?; + .ok_or_else(|| RequireError::RequireContextNotFound)?; let broadcast_tx = broadcast::Sender::new(1); @@ -156,7 +180,6 @@ impl RequireContext { } } - let content = fs::read_to_string(&path_abs).await?; let thread = lua .load(&content) .set_name(path_abs.to_string_lossy()) @@ -168,13 +191,13 @@ impl RequireContext { let multi = lua .get_thread_result(thread_id) - .ok_or(RequireError::ThreadReturnedNone)??; + .ok_or_else(|| RequireError::ThreadReturnedNone)??; let multi_reg = lua.create_registry_value(multi.into_vec())?; let data_ref = lua .app_data_ref::() - .ok_or(RequireError::RequireContextNotFound)?; + .ok_or_else(|| RequireError::RequireContextNotFound)?; data_ref .cache @@ -191,12 +214,7 @@ impl RequireContext { broadcast_tx.send(()).ok(); - match Self::from_cache(lua, &path_abs).await? { - Some(cached) => Ok(cached), - None => Err(RequireError::CacheNotFound( - path_rel.to_string_lossy().to_string(), - )), - } + Self::from_cache(lua, &path_abs).await } /** @@ -226,7 +244,7 @@ impl RequireContext { ) -> Result<(), RequireError> { let mut data = lua .app_data_mut::() - .ok_or(RequireError::RequireContextNotFound)?; + .ok_or_else(|| RequireError::RequireContextNotFound)?; if let Some(map) = data.std.get_mut(alias) { map.insert(std.name(), Box::new(std)); diff --git a/crates/lune-std/src/globals/require/mod.rs b/crates/lune-std/src/globals/require/mod.rs index 7b9415b..5b67c36 100644 --- a/crates/lune-std/src/globals/require/mod.rs +++ b/crates/lune-std/src/globals/require/mod.rs @@ -3,8 +3,9 @@ use crate::{ path::get_parent_path, LuneStandardLibrary, }; +use lune_utils::path::clean_path_and_make_absolute; use mlua::prelude::*; -use path::resolve_path; +use path::append_extension; use std::path::PathBuf; use thiserror::Error; @@ -38,6 +39,44 @@ pub enum RequireError { LuaError(#[from] mlua::Error), } +/** +tries different extensions on the path and if all alternatives fail, we'll try to look for an init file + */ +async fn try_alternatives(lua: &Lua, require_path_abs: PathBuf) -> LuaResult { + for ext in ["lua", "luau"] { + // try the path with ext + let ext_path = append_extension(&require_path_abs, ext); + + match context::RequireContext::require(lua, ext_path).await { + Ok(res) => return Ok(res), + Err(err) => { + if !matches!(err, RequireError::IOError(_)) { + return Err(err).into_lua_err(); + }; + } + }; + } + + for ext in ["lua", "luau"] { + // append init to path and try it with ext + let ext_path = append_extension(require_path_abs.join("init"), ext); + + match context::RequireContext::require(lua, ext_path).await { + Ok(res) => return Ok(res), + Err(err) => { + if !matches!(err, RequireError::IOError(_)) { + return Err(err).into_lua_err(); + }; + } + }; + } + + Err(RequireError::InvalidRequire( + require_path_abs.to_string_lossy().to_string(), + )) + .into_lua_err() +} + async fn lua_require(lua: &Lua, path: String) -> LuaResult { let require_path_rel = PathBuf::from(path); let require_alias = RequireAlias::from_path(&require_path_rel).into_lua_err()?; @@ -46,29 +85,19 @@ async fn lua_require(lua: &Lua, path: String) -> LuaResult { if context::RequireContext::std_exists(lua, &require_alias.alias).into_lua_err()? { context::RequireContext::require_std(lua, require_alias).into_lua_err() } else { - let require_path_abs = resolve_path( - &Luaurc::resolve_path(lua, &require_alias) + let require_path_abs = clean_path_and_make_absolute( + Luaurc::resolve_path(lua, &require_alias) .await .into_lua_err()?, - ) - .await?; + ); - context::RequireContext::require(lua, require_path_rel, require_path_abs) - .await - .into_lua_err() + try_alternatives(lua, require_path_abs).await } } else { let parent_path = get_parent_path(lua)?; - let require_path_abs = resolve_path(&parent_path.join(&require_path_rel)) - .await - .map_err(|_| { - RequireError::InvalidRequire(require_path_rel.to_string_lossy().to_string()) - }) - .into_lua_err()?; + let require_path_abs = clean_path_and_make_absolute(parent_path.join(&require_path_rel)); - context::RequireContext::require(lua, require_path_rel, require_path_abs) - .await - .into_lua_err() + try_alternatives(lua, require_path_abs).await } } diff --git a/crates/lune-std/src/globals/require/path.rs b/crates/lune-std/src/globals/require/path.rs index c00ad55..b1e932e 100644 --- a/crates/lune-std/src/globals/require/path.rs +++ b/crates/lune-std/src/globals/require/path.rs @@ -1,72 +1,4 @@ -use mlua::prelude::*; -use std::path::{Component, Path, PathBuf}; -use tokio::fs; - -/** - -tries these alternatives on given path if path doesn't exist - -* .lua and .luau extension -* path.join("init.luau") and path.join("init.lua") - - */ -pub async fn resolve_path(path: &Path) -> LuaResult { - let init_path = &path.join("init"); - - for ext in ["lua", "luau"] { - // try extension on given path - let path = append_extension(path, ext); - - if fs::try_exists(&path).await? { - return Ok(normalize_path(&path)); - }; - - // try extension on given path's init - let init_path = append_extension(init_path, ext); - - if fs::try_exists(&init_path).await? { - return Ok(normalize_path(&init_path)); - }; - } - - Err(LuaError::runtime("Could not resolve path")) -} - -/** - -Removes useless components from the given path - -### Example - -`./path/./path` turns into `./path/path` - - */ -pub fn normalize_path(path: &Path) -> PathBuf { - let mut components = path.components().peekable(); - let mut ret = if let Some(c @ Component::Prefix(..)) = components.clone().peek() { - components.next(); - PathBuf::from(c.as_os_str()) - } else { - PathBuf::new() - }; - - for component in components { - match component { - Component::Prefix(..) => unreachable!(), - Component::RootDir => { - ret.push(component.as_os_str()); - } - Component::CurDir => {} - Component::ParentDir => { - ret.pop(); - } - Component::Normal(c) => { - ret.push(c); - } - } - } - ret -} +use std::path::PathBuf; /** @@ -77,7 +9,7 @@ adds extension to path without replacing it's current extensions appending `.luau` to `path/path.config` will return `path/path.config.luau` */ -fn append_extension(path: impl Into, ext: &'static str) -> PathBuf { +pub fn append_extension(path: impl Into, ext: &'static str) -> PathBuf { let mut new: PathBuf = path.into(); match new.extension() { // FUTURE: There's probably a better way to do this than converting to a lossy string