eliminate unnecessary checks for files in require

This commit is contained in:
AshleyFlow 2024-10-17 13:47:53 +03:30
parent 6ce4563655
commit 981d323556
3 changed files with 100 additions and 121 deletions

View file

@ -43,7 +43,7 @@ impl RequireContext {
pub(crate) fn std_exists(lua: &Lua, alias: &str) -> Result<bool, RequireError> {
let data_ref = lua
.app_data_ref::<RequireContextData>()
.ok_or(RequireError::RequireContextNotFound)?;
.ok_or_else(|| RequireError::RequireContextNotFound)?;
Ok(data_ref.std.contains_key(alias))
}
@ -54,7 +54,7 @@ impl RequireContext {
) -> Result<LuaMultiValue<'_>, RequireError> {
let data_ref = lua
.app_data_ref::<RequireContextData>()
.ok_or(RequireError::RequireContextNotFound)?;
.ok_or_else(|| RequireError::RequireContextNotFound)?;
if let Some(cached) = data_ref.std_cache.get(&require_alias) {
let multi_vec = lua.registry_value::<Vec<LuaValue>>(cached)?;
@ -62,17 +62,17 @@ impl RequireContext {
return Ok(LuaMultiValue::from_vec(multi_vec));
}
let libraries = data_ref.std.get(&require_alias.alias.as_str()).ok_or(
RequireError::InvalidStdAlias(require_alias.alias.to_string()),
)?;
let libraries = data_ref
.std
.get(&require_alias.alias.as_str())
.ok_or_else(|| RequireError::InvalidStdAlias(require_alias.alias.to_string()))?;
let std =
libraries
.get(require_alias.path.as_str())
.ok_or(RequireError::StdMemberNotFound(
let std = libraries.get(require_alias.path.as_str()).ok_or_else(|| {
RequireError::StdMemberNotFound(
require_alias.path.to_string(),
require_alias.alias.to_string(),
))?;
)
})?;
let multi = std.module(lua)?;
let mutli_clone = multi.clone();
@ -82,7 +82,7 @@ impl RequireContext {
let mut data = lua
.app_data_mut::<RequireContextData>()
.ok_or(RequireError::RequireContextNotFound)?;
.ok_or_else(|| RequireError::RequireContextNotFound)?;
data.std_cache.insert(require_alias, multi_reg);
@ -95,7 +95,7 @@ impl RequireContext {
) -> Result<(), RequireError> {
let data_ref = lua
.app_data_ref::<RequireContextData>()
.ok_or(RequireError::RequireContextNotFound)?;
.ok_or_else(|| RequireError::RequireContextNotFound)?;
let pending = data_ref.pending.try_lock()?;
@ -111,13 +111,33 @@ impl RequireContext {
Ok(())
}
fn is_pending(lua: &Lua, path_abs: &PathBuf) -> Result<bool, RequireError> {
let data_ref = lua
.app_data_ref::<RequireContextData>()
.ok_or_else(|| RequireError::RequireContextNotFound)?;
let pending = data_ref.pending.try_lock()?;
Ok(pending.get(path_abs).is_some())
}
fn is_cached(lua: &Lua, path_abs: &PathBuf) -> Result<bool, RequireError> {
let data_ref = lua
.app_data_ref::<RequireContextData>()
.ok_or_else(|| RequireError::RequireContextNotFound)?;
let cache = data_ref.cache.try_lock()?;
Ok(cache.get(path_abs).is_some())
}
async fn from_cache<'lua>(
lua: &'lua Lua,
path_abs: &'_ PathBuf,
) -> Result<Option<LuaMultiValue<'lua>>, RequireError> {
) -> Result<LuaMultiValue<'lua>, RequireError> {
let data_ref = lua
.app_data_ref::<RequireContextData>()
.ok_or(RequireError::RequireContextNotFound)?;
.ok_or_else(|| RequireError::RequireContextNotFound)?;
let cache = data_ref.cache.lock().await;
@ -125,28 +145,32 @@ impl RequireContext {
Some(cached) => {
let multi_vec = lua.registry_value::<Vec<LuaValue>>(cached)?;
Ok(Some(LuaMultiValue::from_vec(multi_vec)))
Ok(LuaMultiValue::from_vec(multi_vec))
}
None => Ok(None),
None => Err(RequireError::CacheNotFound(
path_abs.to_string_lossy().to_string(),
)),
}
}
pub(crate) async fn require(
lua: &Lua,
path_rel: PathBuf,
path_abs: PathBuf,
) -> Result<LuaMultiValue, RequireError> {
if Self::is_pending(lua, &path_abs)? {
Self::wait_for_pending(lua, &path_abs).await?;
if let Some(cached) = Self::from_cache(lua, &path_abs).await? {
return Ok(cached);
return Self::from_cache(lua, &path_abs).await;
} else if Self::is_cached(lua, &path_abs)? {
return Self::from_cache(lua, &path_abs).await;
}
let content = fs::read_to_string(&path_abs).await?;
// create a broadcast channel
{
let data_ref = lua
.app_data_ref::<RequireContextData>()
.ok_or(RequireError::RequireContextNotFound)?;
.ok_or_else(|| RequireError::RequireContextNotFound)?;
let broadcast_tx = broadcast::Sender::new(1);
@ -156,7 +180,6 @@ impl RequireContext {
}
}
let content = fs::read_to_string(&path_abs).await?;
let thread = lua
.load(&content)
.set_name(path_abs.to_string_lossy())
@ -168,13 +191,13 @@ impl RequireContext {
let multi = lua
.get_thread_result(thread_id)
.ok_or(RequireError::ThreadReturnedNone)??;
.ok_or_else(|| RequireError::ThreadReturnedNone)??;
let multi_reg = lua.create_registry_value(multi.into_vec())?;
let data_ref = lua
.app_data_ref::<RequireContextData>()
.ok_or(RequireError::RequireContextNotFound)?;
.ok_or_else(|| RequireError::RequireContextNotFound)?;
data_ref
.cache
@ -191,12 +214,7 @@ impl RequireContext {
broadcast_tx.send(()).ok();
match Self::from_cache(lua, &path_abs).await? {
Some(cached) => Ok(cached),
None => Err(RequireError::CacheNotFound(
path_rel.to_string_lossy().to_string(),
)),
}
Self::from_cache(lua, &path_abs).await
}
/**
@ -226,7 +244,7 @@ impl RequireContext {
) -> Result<(), RequireError> {
let mut data = lua
.app_data_mut::<RequireContextData>()
.ok_or(RequireError::RequireContextNotFound)?;
.ok_or_else(|| RequireError::RequireContextNotFound)?;
if let Some(map) = data.std.get_mut(alias) {
map.insert(std.name(), Box::new(std));

View file

@ -3,8 +3,9 @@ use crate::{
path::get_parent_path,
LuneStandardLibrary,
};
use lune_utils::path::clean_path_and_make_absolute;
use mlua::prelude::*;
use path::resolve_path;
use path::append_extension;
use std::path::PathBuf;
use thiserror::Error;
@ -38,6 +39,44 @@ pub enum RequireError {
LuaError(#[from] mlua::Error),
}
/**
tries different extensions on the path and if all alternatives fail, we'll try to look for an init file
*/
async fn try_alternatives(lua: &Lua, require_path_abs: PathBuf) -> LuaResult<LuaMultiValue> {
for ext in ["lua", "luau"] {
// try the path with ext
let ext_path = append_extension(&require_path_abs, ext);
match context::RequireContext::require(lua, ext_path).await {
Ok(res) => return Ok(res),
Err(err) => {
if !matches!(err, RequireError::IOError(_)) {
return Err(err).into_lua_err();
};
}
};
}
for ext in ["lua", "luau"] {
// append init to path and try it with ext
let ext_path = append_extension(require_path_abs.join("init"), ext);
match context::RequireContext::require(lua, ext_path).await {
Ok(res) => return Ok(res),
Err(err) => {
if !matches!(err, RequireError::IOError(_)) {
return Err(err).into_lua_err();
};
}
};
}
Err(RequireError::InvalidRequire(
require_path_abs.to_string_lossy().to_string(),
))
.into_lua_err()
}
async fn lua_require(lua: &Lua, path: String) -> LuaResult<LuaMultiValue> {
let require_path_rel = PathBuf::from(path);
let require_alias = RequireAlias::from_path(&require_path_rel).into_lua_err()?;
@ -46,29 +85,19 @@ async fn lua_require(lua: &Lua, path: String) -> LuaResult<LuaMultiValue> {
if context::RequireContext::std_exists(lua, &require_alias.alias).into_lua_err()? {
context::RequireContext::require_std(lua, require_alias).into_lua_err()
} else {
let require_path_abs = resolve_path(
&Luaurc::resolve_path(lua, &require_alias)
let require_path_abs = clean_path_and_make_absolute(
Luaurc::resolve_path(lua, &require_alias)
.await
.into_lua_err()?,
)
.await?;
);
context::RequireContext::require(lua, require_path_rel, require_path_abs)
.await
.into_lua_err()
try_alternatives(lua, require_path_abs).await
}
} else {
let parent_path = get_parent_path(lua)?;
let require_path_abs = resolve_path(&parent_path.join(&require_path_rel))
.await
.map_err(|_| {
RequireError::InvalidRequire(require_path_rel.to_string_lossy().to_string())
})
.into_lua_err()?;
let require_path_abs = clean_path_and_make_absolute(parent_path.join(&require_path_rel));
context::RequireContext::require(lua, require_path_rel, require_path_abs)
.await
.into_lua_err()
try_alternatives(lua, require_path_abs).await
}
}

View file

@ -1,72 +1,4 @@
use mlua::prelude::*;
use std::path::{Component, Path, PathBuf};
use tokio::fs;
/**
tries these alternatives on given path if path doesn't exist
* .lua and .luau extension
* path.join("init.luau") and path.join("init.lua")
*/
pub async fn resolve_path(path: &Path) -> LuaResult<PathBuf> {
let init_path = &path.join("init");
for ext in ["lua", "luau"] {
// try extension on given path
let path = append_extension(path, ext);
if fs::try_exists(&path).await? {
return Ok(normalize_path(&path));
};
// try extension on given path's init
let init_path = append_extension(init_path, ext);
if fs::try_exists(&init_path).await? {
return Ok(normalize_path(&init_path));
};
}
Err(LuaError::runtime("Could not resolve path"))
}
/**
Removes useless components from the given path
### Example
`./path/./path` turns into `./path/path`
*/
pub fn normalize_path(path: &Path) -> PathBuf {
let mut components = path.components().peekable();
let mut ret = if let Some(c @ Component::Prefix(..)) = components.clone().peek() {
components.next();
PathBuf::from(c.as_os_str())
} else {
PathBuf::new()
};
for component in components {
match component {
Component::Prefix(..) => unreachable!(),
Component::RootDir => {
ret.push(component.as_os_str());
}
Component::CurDir => {}
Component::ParentDir => {
ret.pop();
}
Component::Normal(c) => {
ret.push(c);
}
}
}
ret
}
use std::path::PathBuf;
/**
@ -77,7 +9,7 @@ adds extension to path without replacing it's current extensions
appending `.luau` to `path/path.config` will return `path/path.config.luau`
*/
fn append_extension(path: impl Into<PathBuf>, ext: &'static str) -> PathBuf {
pub fn append_extension(path: impl Into<PathBuf>, ext: &'static str) -> PathBuf {
let mut new: PathBuf = path.into();
match new.extension() {
// FUTURE: There's probably a better way to do this than converting to a lossy string