Fix async require cache, unify relative and cwd-relative require functions

This commit is contained in:
Filip Tibell 2023-08-20 11:46:38 -05:00
parent d6c31f67ba
commit 9182427a0a
5 changed files with 150 additions and 135 deletions

View file

@ -1,20 +0,0 @@
use mlua::prelude::*;
use super::context::*;
pub(super) async fn require<'lua, 'ctx>(
lua: &'lua Lua,
ctx: &'ctx RequireContext,
path: &str,
) -> LuaResult<LuaMultiValue<'lua>>
where
'lua: 'ctx,
{
if ctx.is_cached(path)? {
ctx.get_from_cache(lua, path)
} else if ctx.is_pending(path)? {
ctx.wait_for_cache(lua, path).await
} else {
ctx.load(lua, path).await
}
}

View file

@ -1,22 +1,39 @@
use std::{collections::HashMap, env, path::PathBuf, sync::Arc}; use std::{
collections::HashMap,
env,
path::{Path, PathBuf},
sync::Arc,
};
use mlua::prelude::*; use mlua::prelude::*;
use tokio::{fs, sync::Mutex as AsyncMutex}; use tokio::{
fs,
sync::{
broadcast::{self, Sender},
Mutex as AsyncMutex,
},
};
use crate::lune::{ use crate::lune::{
builtins::LuneBuiltin, builtins::LuneBuiltin,
scheduler::{IntoLuaOwnedThread, Scheduler, SchedulerThreadId}, scheduler::{IntoLuaOwnedThread, Scheduler},
}; };
const REGISTRY_KEY: &str = "RequireContext"; const REGISTRY_KEY: &str = "RequireContext";
/**
Context containing cached results for all `require` operations.
The cache uses absolute paths, so any given relative
path will first be transformed into an absolute path.
*/
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(super) struct RequireContext { pub(super) struct RequireContext {
use_cwd_relative_paths: bool, use_cwd_relative_paths: bool,
working_directory: PathBuf, working_directory: PathBuf,
cache_builtins: Arc<AsyncMutex<HashMap<LuneBuiltin, LuaResult<LuaRegistryKey>>>>, cache_builtins: Arc<AsyncMutex<HashMap<LuneBuiltin, LuaResult<LuaRegistryKey>>>>,
cache_results: Arc<AsyncMutex<HashMap<PathBuf, LuaResult<LuaRegistryKey>>>>, cache_results: Arc<AsyncMutex<HashMap<PathBuf, LuaResult<LuaRegistryKey>>>>,
cache_pending: Arc<AsyncMutex<HashMap<PathBuf, SchedulerThreadId>>>, cache_pending: Arc<AsyncMutex<HashMap<PathBuf, Sender<()>>>>,
} }
impl RequireContext { impl RequireContext {
@ -41,59 +58,58 @@ impl RequireContext {
} }
/** /**
If `require` should use cwd-relative paths or not. Resolves the given `source` and `path` into require paths
to use, based on the current require context settings.
This will resolve path segments such as `./`, `../`, ..., and
if the resolved path is not an absolute path, will create an
absolute path by prepending the current working directory.
*/ */
pub fn use_cwd_relative_paths(&self) -> bool { pub fn resolve_paths(
self.use_cwd_relative_paths &self,
} source: impl AsRef<str>,
path: impl AsRef<str>,
/** ) -> LuaResult<(PathBuf, PathBuf)> {
Transforms the path into an absolute path. let path = if self.use_cwd_relative_paths {
PathBuf::from(path.as_ref())
If the given path is already an absolute path, this
will only resolve path segments such as `./`, `../`, ...
If the given path is not absolute, it first gets transformed into an
absolute path by prepending the path to the current working directory.
*/
fn abs_path(&self, path: impl AsRef<str>) -> PathBuf {
let path = path_clean::clean(path.as_ref());
if path.is_absolute() {
path
} else { } else {
self.working_directory.join(path) PathBuf::from(source.as_ref())
} .parent()
.ok_or_else(|| LuaError::runtime("Failed to get parent path of source"))?
.join(path.as_ref())
};
let rel_path = path_clean::clean(path);
let abs_path = if rel_path.is_absolute() {
rel_path.to_path_buf()
} else {
self.working_directory.join(&rel_path)
};
Ok((rel_path, abs_path))
} }
/** /**
Checks if the given path has a cached require result. Checks if the given path has a cached require result.
The cache uses absolute paths, so any given relative
path will first be transformed into an absolute path.
*/ */
pub fn is_cached(&self, path: impl AsRef<str>) -> LuaResult<bool> { pub fn is_cached(&self, abs_path: impl AsRef<Path>) -> LuaResult<bool> {
let path = self.abs_path(path);
let is_cached = self let is_cached = self
.cache_results .cache_results
.try_lock() .try_lock()
.expect("RequireContext may not be used from multiple threads") .expect("RequireContext may not be used from multiple threads")
.contains_key(&path); .contains_key(abs_path.as_ref());
Ok(is_cached) Ok(is_cached)
} }
/** /**
Checks if the given path is currently being used in `require`. Checks if the given path is currently being used in `require`.
The cache uses absolute paths, so any given relative
path will first be transformed into an absolute path.
*/ */
pub fn is_pending(&self, path: impl AsRef<str>) -> LuaResult<bool> { pub fn is_pending(&self, abs_path: impl AsRef<Path>) -> LuaResult<bool> {
let path = self.abs_path(path);
let is_pending = self let is_pending = self
.cache_pending .cache_pending
.try_lock() .try_lock()
.expect("RequireContext may not be used from multiple threads") .expect("RequireContext may not be used from multiple threads")
.contains_key(&path); .contains_key(abs_path.as_ref());
Ok(is_pending) Ok(is_pending)
} }
@ -101,24 +117,19 @@ impl RequireContext {
Gets the resulting value from the require cache. Gets the resulting value from the require cache.
Will panic if the path has not been cached, use [`is_cached`] first. Will panic if the path has not been cached, use [`is_cached`] first.
The cache uses absolute paths, so any given relative
path will first be transformed into an absolute path.
*/ */
pub fn get_from_cache<'lua>( pub fn get_from_cache<'lua>(
&self, &self,
lua: &'lua Lua, lua: &'lua Lua,
path: impl AsRef<str>, abs_path: impl AsRef<Path>,
) -> LuaResult<LuaMultiValue<'lua>> { ) -> LuaResult<LuaMultiValue<'lua>> {
let path = self.abs_path(path);
let results = self let results = self
.cache_results .cache_results
.try_lock() .try_lock()
.expect("RequireContext may not be used from multiple threads"); .expect("RequireContext may not be used from multiple threads");
let cached = results let cached = results
.get(&path) .get(abs_path.as_ref())
.expect("Path does not exist in results cache"); .expect("Path does not exist in results cache");
match cached { match cached {
Err(e) => Err(e.clone()), Err(e) => Err(e.clone()),
@ -135,77 +146,56 @@ impl RequireContext {
Waits for the resulting value from the require cache. Waits for the resulting value from the require cache.
Will panic if the path has not been cached, use [`is_cached`] first. Will panic if the path has not been cached, use [`is_cached`] first.
The cache uses absolute paths, so any given relative
path will first be transformed into an absolute path.
*/ */
pub async fn wait_for_cache<'lua>( pub async fn wait_for_cache<'lua>(
&self, &self,
lua: &'lua Lua, lua: &'lua Lua,
path: impl AsRef<str>, abs_path: impl AsRef<Path>,
) -> LuaResult<LuaMultiValue<'lua>> { ) -> LuaResult<LuaMultiValue<'lua>> {
let path = self.abs_path(path); let mut thread_recv = {
let sched = lua
.app_data_ref::<&Scheduler>()
.expect("Lua struct is missing scheduler");
let thread_id = {
let pending = self let pending = self
.cache_pending .cache_pending
.try_lock() .try_lock()
.expect("RequireContext may not be used from multiple threads"); .expect("RequireContext may not be used from multiple threads");
let thread_id = pending let thread_id = pending
.get(&path) .get(abs_path.as_ref())
.expect("Path is not currently pending require"); .expect("Path is not currently pending require");
*thread_id thread_id.subscribe()
}; };
sched.wait_for_thread(thread_id).await thread_recv.recv().await.into_lua_err()?;
self.get_from_cache(lua, abs_path.as_ref())
} }
/** async fn load(
Loads (requires) the file at the given path.
The cache uses absolute paths, so any given relative
path will first be transformed into an absolute path.
*/
pub async fn load<'lua>(
&self, &self,
lua: &'lua Lua, lua: &Lua,
path: impl AsRef<str>, abs_path: impl AsRef<Path>,
) -> LuaResult<LuaMultiValue<'lua>> { rel_path: impl AsRef<Path>,
let path = self.abs_path(path); ) -> LuaResult<LuaRegistryKey> {
let abs_path = abs_path.as_ref();
let rel_path = rel_path.as_ref();
let sched = lua let sched = lua
.app_data_ref::<&Scheduler>() .app_data_ref::<&Scheduler>()
.expect("Lua struct is missing scheduler"); .expect("Lua struct is missing scheduler");
// TODO: Store any fs error in the cache, too // Read the file at the given path, try to parse and
let file_contents = fs::read(&path).await?; // load it into a new lua thread that we can schedule
let file_contents = fs::read(&abs_path).await?;
// TODO: Store any lua loading/parsing error in the cache, too
// TODO: Set chunk name as file name relative to cwd
let file_thread = lua let file_thread = lua
.load(file_contents) .load(file_contents)
.set_name(rel_path.to_string_lossy().to_string())
.into_function()? .into_function()?
.into_owned_lua_thread(lua)?; .into_owned_lua_thread(lua)?;
// Schedule the thread to run and store the pending thread id in the require context // Schedule the thread to run, wait for it to finish running
let thread_id = { let thread_id = sched.push_back(file_thread, ())?;
let thread_id = sched.push_back(file_thread, ())?;
self.cache_pending
.try_lock()
.expect("RequireContext may not be used from multiple threads")
.insert(path.clone(), thread_id);
thread_id
};
// Wait for the thread to finish running
let thread_res = sched.wait_for_thread(thread_id).await; let thread_res = sched.wait_for_thread(thread_id).await;
// Clone the result and store it in the cache, note // Return the result of the thread, storing any lua value(s) in the registry
// that cloning a [`LuaValue`] will still refer to match thread_res {
// the same underlying lua data and indentity
let result = match thread_res.clone() {
Err(e) => Err(e), Err(e) => Err(e),
Ok(multi) => { Ok(multi) => {
let multi_vec = multi.into_vec(); let multi_vec = multi.into_vec();
@ -214,21 +204,62 @@ impl RequireContext {
.expect("Failed to store require result in registry"); .expect("Failed to store require result in registry");
Ok(multi_key) Ok(multi_key)
} }
}
}
/**
Loads (requires) the file at the given path.
*/
pub async fn load_with_caching<'lua>(
&self,
lua: &'lua Lua,
abs_path: impl AsRef<Path>,
rel_path: impl AsRef<Path>,
) -> LuaResult<LuaMultiValue<'lua>> {
let abs_path = abs_path.as_ref();
let rel_path = rel_path.as_ref();
// Set this abs path as currently pending
let (broadcast_tx, _) = broadcast::channel(1);
self.cache_pending
.try_lock()
.expect("RequireContext may not be used from multiple threads")
.insert(abs_path.to_path_buf(), broadcast_tx);
// Try to load at this abs path
let load_res = self.load(lua, abs_path, rel_path).await;
let load_val = match &load_res {
Err(e) => Err(e.clone()),
Ok(k) => {
let multi_vec = lua
.registry_value::<Vec<LuaValue>>(k)
.expect("Failed to fetch require result from registry");
Ok(LuaMultiValue::from_vec(multi_vec))
}
}; };
// NOTE: We use the async lock and not try_lock here because // NOTE: We use the async lock and not try_lock here because
// some other thread may be wanting to insert into the require // some other thread may be wanting to insert into the require
// cache at the same time, and that's not an actual error case // cache at the same time, and that's not an actual error case
self.cache_results.lock().await.insert(path.clone(), result); self.cache_results
.lock()
.await
.insert(abs_path.to_path_buf(), load_res);
// Remove the pending thread id from the require context // Remove the pending thread id from the require context,
self.cache_pending // broadcast a message to let any listeners know that this
// path has now finished the require process and is cached
let broadcast_tx = self
.cache_pending
.try_lock() .try_lock()
.expect("RequireContext may not be used from multiple threads") .expect("RequireContext may not be used from multiple threads")
.remove(&path) .remove(abs_path)
.expect("Pending require thread id was unexpectedly removed"); .expect("Pending require broadcaster was unexpectedly removed");
broadcast_tx
.send(())
.expect("Failed to send require broadcast");
thread_res load_val
} }
/** /**

View file

@ -5,10 +5,9 @@ use crate::lune::{scheduler::LuaSchedulerExt, util::TableBuilder};
mod context; mod context;
use context::RequireContext; use context::RequireContext;
mod absolute;
mod alias; mod alias;
mod builtin; mod builtin;
mod relative; mod path;
const REQUIRE_IMPL: &str = r#" const REQUIRE_IMPL: &str = r#"
return require(source(), ...) return require(source(), ...)
@ -67,6 +66,8 @@ async fn require<'lua>(
where where
'lua: 'static, // FIXME: Remove static lifetime bound here when builtin libraries no longer need it 'lua: 'static, // FIXME: Remove static lifetime bound here when builtin libraries no longer need it
{ {
// TODO: Use proper lua strings, os strings, to avoid lossy conversions
let source = source let source = source
.to_str() .to_str()
.into_lua_err() .into_lua_err()
@ -93,9 +94,7 @@ where
"Require with custom alias must contain '/' delimiter", "Require with custom alias must contain '/' delimiter",
))?; ))?;
alias::require(lua, &context, alias, name).await alias::require(lua, &context, alias, name).await
} else if context.use_cwd_relative_paths() {
absolute::require(lua, &context, &path).await
} else { } else {
relative::require(lua, &context, &source, &path).await path::require(lua, &context, &source, &path).await
} }
} }

View file

@ -0,0 +1,22 @@
use mlua::prelude::*;
use super::context::*;
pub(super) async fn require<'lua, 'ctx>(
lua: &'lua Lua,
ctx: &'ctx RequireContext,
source: &str,
path: &str,
) -> LuaResult<LuaMultiValue<'lua>>
where
'lua: 'ctx,
{
let (abs_path, rel_path) = ctx.resolve_paths(source, path)?;
if ctx.is_cached(&abs_path)? {
ctx.get_from_cache(lua, &abs_path)
} else if ctx.is_pending(&abs_path)? {
ctx.wait_for_cache(lua, &abs_path).await
} else {
ctx.load_with_caching(lua, &abs_path, &rel_path).await
}
}

View file

@ -1,17 +0,0 @@
use mlua::prelude::*;
use super::context::*;
pub(super) async fn require<'lua, 'ctx>(
_lua: &'lua Lua,
_ctx: &'ctx RequireContext,
source: &str,
path: &str,
) -> LuaResult<LuaMultiValue<'lua>>
where
'lua: 'ctx,
{
Err(LuaError::runtime(format!(
"TODO: Support require for absolute paths (tried to require '{path}' from '{source}')"
)))
}