diff --git a/Cargo.lock b/Cargo.lock index 2ca13cd..47f3717 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1049,6 +1049,7 @@ dependencies = [ "mlua", "once_cell", "os_str_bytes", + "path-clean", "pin-project", "rand", "rbx_binary", @@ -1267,6 +1268,12 @@ version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" +[[package]] +name = "path-clean" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "17359afc20d7ab31fdb42bb844c8b3bb1dabd7dcf7e68428492da7f16966fcef" + [[package]] name = "percent-encoding" version = "2.3.0" diff --git a/Cargo.toml b/Cargo.toml index 61ad355..255bc92 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -71,6 +71,7 @@ async-trait = "0.1" dialoguer = "0.10" dunce = "1.0" lz4_flex = "0.11" +path-clean = "1.0" pin-project = "1.0" os_str_bytes = "6.4" urlencoding = "2.1" diff --git a/src/lune/globals/require/context.rs b/src/lune/globals/require/context.rs index 54fd981..c2863ae 100644 --- a/src/lune/globals/require/context.rs +++ b/src/lune/globals/require/context.rs @@ -1,31 +1,230 @@ +use std::{collections::HashMap, env, path::PathBuf, sync::Arc}; + use mlua::prelude::*; +use tokio::{fs, sync::Mutex as AsyncMutex}; + +use crate::lune::scheduler::{IntoLuaOwnedThread, Scheduler, SchedulerThreadId}; const REGISTRY_KEY: &str = "RequireContext"; -// TODO: Store current file path for each thread in -// this context somehow, as well as built-in libraries -#[derive(Clone)] +#[derive(Debug, Clone)] pub(super) struct RequireContext { - pub(super) use_absolute_paths: bool, + use_absolute_paths: bool, + working_directory: PathBuf, + cache_results: Arc>>>, + cache_pending: Arc>>, } impl RequireContext { - pub fn new() -> Self { - Self { + /** + Creates a new require context for the given [`Lua`] struct. + + Note that this require context is global and only one require + context should be created per [`Lua`] struct, creating more + than one context may lead to undefined require-behavior. + */ + pub fn create(lua: &Lua) { + let this = Self { // TODO: Set to false by default, load some kind of config // or env var to check if we should be using absolute paths use_absolute_paths: true, + working_directory: env::current_dir().expect("Failed to get current working directory"), + cache_results: Arc::new(AsyncMutex::new(HashMap::new())), + cache_pending: Arc::new(AsyncMutex::new(HashMap::new())), + }; + lua.set_named_registry_value(REGISTRY_KEY, this) + .expect("Failed to insert RequireContext into registry"); + } + + /** + If `require` should use absolute paths or not. + */ + pub fn use_absolute_paths(&self) -> bool { + self.use_absolute_paths + } + + /** + Transforms the path into an absolute path. + + If the given path is already an absolute path, this + will only resolve path segments such as `./`, `../`, ... + + If the given path is not absolute, it first gets transformed into an + absolute path by prepending the path to the current working directory. + */ + fn abs_path(&self, path: impl AsRef) -> PathBuf { + let path = path_clean::clean(path.as_ref()); + if path.is_absolute() { + path + } else { + self.working_directory.join(path) } } - pub fn from_registry(lua: &Lua) -> Self { - lua.named_registry_value(REGISTRY_KEY) - .expect("Missing require context in lua registry") + /** + Checks if the given path has a cached require result. + + The cache uses absolute paths, so any given relative + path will first be transformed into an absolute path. + */ + pub fn is_cached(&self, path: impl AsRef) -> LuaResult { + let path = self.abs_path(path); + let is_cached = self + .cache_results + .try_lock() + .expect("RequireContext may not be used from multiple threads") + .contains_key(&path); + Ok(is_cached) } - pub fn insert_into_registry(self, lua: &Lua) { - lua.set_named_registry_value(REGISTRY_KEY, self) - .expect("Failed to insert RequireContext into registry"); + /** + Checks if the given path is currently being used in `require`. + + The cache uses absolute paths, so any given relative + path will first be transformed into an absolute path. + */ + pub fn is_pending(&self, path: impl AsRef) -> LuaResult { + let path = self.abs_path(path); + let is_pending = self + .cache_pending + .try_lock() + .expect("RequireContext may not be used from multiple threads") + .contains_key(&path); + Ok(is_pending) + } + + /** + Gets the resulting value from the require cache. + + Will panic if the path has not been cached, use [`is_cached`] first. + + The cache uses absolute paths, so any given relative + path will first be transformed into an absolute path. + */ + pub fn get_from_cache<'lua>( + &'lua self, + lua: &'lua Lua, + path: impl AsRef + 'lua, + ) -> LuaResult> { + let path = self.abs_path(path); + + let results = self + .cache_results + .try_lock() + .expect("RequireContext may not be used from multiple threads"); + + let cached = results + .get(&path) + .expect("Path does not exist in results cache"); + match cached { + Err(e) => Err(e.clone()), + Ok(key) => { + let multi_vec = lua + .registry_value::>(key) + .expect("Missing require result in lua registry"); + Ok(LuaMultiValue::from_vec(multi_vec)) + } + } + } + + /** + Waits for the resulting value from the require cache. + + Will panic if the path has not been cached, use [`is_cached`] first. + + The cache uses absolute paths, so any given relative + path will first be transformed into an absolute path. + */ + pub async fn wait_for_cache<'lua>( + &'lua self, + lua: &'lua Lua, + path: impl AsRef + 'lua, + ) -> LuaResult> { + let path = self.abs_path(path); + let sched = lua + .app_data_ref::<&Scheduler>() + .expect("Lua struct is missing scheduler"); + + let thread_id = { + let pending = self + .cache_pending + .try_lock() + .expect("RequireContext may not be used from multiple threads"); + let thread_id = pending + .get(&path) + .expect("Path is not currently pending require"); + *thread_id + }; + + sched.wait_for_thread(thread_id).await + } + + /** + Loads (requires) the file at the given path. + + The cache uses absolute paths, so any given relative + path will first be transformed into an absolute path. + */ + pub async fn load<'lua>( + &'lua self, + lua: &'lua Lua, + path: impl AsRef + 'lua, + ) -> LuaResult> { + let path = self.abs_path(path); + let sched = lua + .app_data_ref::<&Scheduler>() + .expect("Lua struct is missing scheduler"); + + // TODO: Store any fs error in the cache, too + let file_contents = fs::read(&path).await?; + + // TODO: Store any lua loading/parsing error in the cache, too + // TODO: Set chunk name as file name relative to cwd + let file_thread = lua + .load(file_contents) + .into_function()? + .into_owned_lua_thread(lua)?; + + // Schedule the thread to run and store the pending thread id in the require context + let thread_id = { + let thread_id = sched.push_back(file_thread, ())?; + self.cache_pending + .try_lock() + .expect("RequireContext may not be used from multiple threads") + .insert(path.clone(), thread_id); + thread_id + }; + + // Wait for the thread to finish running + let thread_res = sched.wait_for_thread(thread_id).await; + + // Clone the result and store it in the cache, note + // that cloning a [`LuaValue`] will still refer to + // the same underlying lua data and indentity + let result = match thread_res.clone() { + Err(e) => Err(e), + Ok(multi) => { + let multi_vec = multi.into_vec(); + let multi_key = lua + .create_registry_value(multi_vec) + .expect("Failed to store require result in registry"); + Ok(multi_key) + } + }; + + // NOTE: We use the async lock and not try_lock here because + // some other thread may be wanting to insert into the require + // cache at the same time, and that's not an actual error case + self.cache_results.lock().await.insert(path.clone(), result); + + // Remove the pending thread id from the require context + self.cache_pending + .try_lock() + .expect("RequireContext may not be used from multiple threads") + .remove(&path) + .expect("Pending require thread id was unexpectedly removed"); + + thread_res } } @@ -41,3 +240,11 @@ impl<'lua> FromLua<'lua> for RequireContext { unreachable!("RequireContext should only be used from registry") } } + +impl<'lua> From<&'lua Lua> for RequireContext { + fn from(value: &'lua Lua) -> Self { + value + .named_registry_value(REGISTRY_KEY) + .expect("Missing require context in lua registry") + } +} diff --git a/src/lune/globals/require/mod.rs b/src/lune/globals/require/mod.rs index 8315687..7418a85 100644 --- a/src/lune/globals/require/mod.rs +++ b/src/lune/globals/require/mod.rs @@ -11,10 +11,10 @@ mod builtin; mod relative; pub fn create(lua: &'static Lua) -> LuaResult> { - RequireContext::new().insert_into_registry(lua); + RequireContext::create(lua); lua.create_async_function(|lua, path: LuaString| async move { - let context = RequireContext::from_registry(lua); + let context = RequireContext::from(lua); let path = path .to_str() @@ -32,7 +32,7 @@ pub fn create(lua: &'static Lua) -> LuaResult> { "Require with custom alias must contain '/' delimiter", ))?; alias::require(lua, context, alias, name).await - } else if context.use_absolute_paths { + } else if context.use_absolute_paths() { absolute::require(lua, context, &path).await } else { relative::require(lua, context, &path).await diff --git a/src/lune/scheduler/mod.rs b/src/lune/scheduler/mod.rs index 8b970a0..398b1d5 100644 --- a/src/lune/scheduler/mod.rs +++ b/src/lune/scheduler/mod.rs @@ -17,11 +17,12 @@ mod impl_async; mod impl_runner; mod impl_threads; +pub use self::thread::SchedulerThreadId; pub use self::traits::*; use self::{ state::SchedulerState, - thread::{SchedulerThread, SchedulerThreadId, SchedulerThreadSender}, + thread::{SchedulerThread, SchedulerThreadSender}, }; type SchedulerFuture<'fut> = Pin + 'fut>>;