Implement bulk of new require behavior

This commit is contained in:
Filip Tibell 2023-08-19 15:31:17 -05:00
parent 7d73601a58
commit bcef44e286
5 changed files with 232 additions and 16 deletions

7
Cargo.lock generated
View file

@ -1049,6 +1049,7 @@ dependencies = [
"mlua", "mlua",
"once_cell", "once_cell",
"os_str_bytes", "os_str_bytes",
"path-clean",
"pin-project", "pin-project",
"rand", "rand",
"rbx_binary", "rbx_binary",
@ -1267,6 +1268,12 @@ version = "1.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c"
[[package]]
name = "path-clean"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17359afc20d7ab31fdb42bb844c8b3bb1dabd7dcf7e68428492da7f16966fcef"
[[package]] [[package]]
name = "percent-encoding" name = "percent-encoding"
version = "2.3.0" version = "2.3.0"

View file

@ -71,6 +71,7 @@ async-trait = "0.1"
dialoguer = "0.10" dialoguer = "0.10"
dunce = "1.0" dunce = "1.0"
lz4_flex = "0.11" lz4_flex = "0.11"
path-clean = "1.0"
pin-project = "1.0" pin-project = "1.0"
os_str_bytes = "6.4" os_str_bytes = "6.4"
urlencoding = "2.1" urlencoding = "2.1"

View file

@ -1,32 +1,231 @@
use std::{collections::HashMap, env, path::PathBuf, sync::Arc};
use mlua::prelude::*; use mlua::prelude::*;
use tokio::{fs, sync::Mutex as AsyncMutex};
use crate::lune::scheduler::{IntoLuaOwnedThread, Scheduler, SchedulerThreadId};
const REGISTRY_KEY: &str = "RequireContext"; const REGISTRY_KEY: &str = "RequireContext";
// TODO: Store current file path for each thread in #[derive(Debug, Clone)]
// this context somehow, as well as built-in libraries
#[derive(Clone)]
pub(super) struct RequireContext { pub(super) struct RequireContext {
pub(super) use_absolute_paths: bool, use_absolute_paths: bool,
working_directory: PathBuf,
cache_results: Arc<AsyncMutex<HashMap<PathBuf, LuaResult<LuaRegistryKey>>>>,
cache_pending: Arc<AsyncMutex<HashMap<PathBuf, SchedulerThreadId>>>,
} }
impl RequireContext { impl RequireContext {
pub fn new() -> Self { /**
Self { Creates a new require context for the given [`Lua`] struct.
Note that this require context is global and only one require
context should be created per [`Lua`] struct, creating more
than one context may lead to undefined require-behavior.
*/
pub fn create(lua: &Lua) {
let this = Self {
// TODO: Set to false by default, load some kind of config // TODO: Set to false by default, load some kind of config
// or env var to check if we should be using absolute paths // or env var to check if we should be using absolute paths
use_absolute_paths: true, use_absolute_paths: true,
} working_directory: env::current_dir().expect("Failed to get current working directory"),
} cache_results: Arc::new(AsyncMutex::new(HashMap::new())),
cache_pending: Arc::new(AsyncMutex::new(HashMap::new())),
pub fn from_registry(lua: &Lua) -> Self { };
lua.named_registry_value(REGISTRY_KEY) lua.set_named_registry_value(REGISTRY_KEY, this)
.expect("Missing require context in lua registry")
}
pub fn insert_into_registry(self, lua: &Lua) {
lua.set_named_registry_value(REGISTRY_KEY, self)
.expect("Failed to insert RequireContext into registry"); .expect("Failed to insert RequireContext into registry");
} }
/**
If `require` should use absolute paths or not.
*/
pub fn use_absolute_paths(&self) -> bool {
self.use_absolute_paths
}
/**
Transforms the path into an absolute path.
If the given path is already an absolute path, this
will only resolve path segments such as `./`, `../`, ...
If the given path is not absolute, it first gets transformed into an
absolute path by prepending the path to the current working directory.
*/
fn abs_path(&self, path: impl AsRef<str>) -> PathBuf {
let path = path_clean::clean(path.as_ref());
if path.is_absolute() {
path
} else {
self.working_directory.join(path)
}
}
/**
Checks if the given path has a cached require result.
The cache uses absolute paths, so any given relative
path will first be transformed into an absolute path.
*/
pub fn is_cached(&self, path: impl AsRef<str>) -> LuaResult<bool> {
let path = self.abs_path(path);
let is_cached = self
.cache_results
.try_lock()
.expect("RequireContext may not be used from multiple threads")
.contains_key(&path);
Ok(is_cached)
}
/**
Checks if the given path is currently being used in `require`.
The cache uses absolute paths, so any given relative
path will first be transformed into an absolute path.
*/
pub fn is_pending(&self, path: impl AsRef<str>) -> LuaResult<bool> {
let path = self.abs_path(path);
let is_pending = self
.cache_pending
.try_lock()
.expect("RequireContext may not be used from multiple threads")
.contains_key(&path);
Ok(is_pending)
}
/**
Gets the resulting value from the require cache.
Will panic if the path has not been cached, use [`is_cached`] first.
The cache uses absolute paths, so any given relative
path will first be transformed into an absolute path.
*/
pub fn get_from_cache<'lua>(
&'lua self,
lua: &'lua Lua,
path: impl AsRef<str> + 'lua,
) -> LuaResult<LuaMultiValue<'lua>> {
let path = self.abs_path(path);
let results = self
.cache_results
.try_lock()
.expect("RequireContext may not be used from multiple threads");
let cached = results
.get(&path)
.expect("Path does not exist in results cache");
match cached {
Err(e) => Err(e.clone()),
Ok(key) => {
let multi_vec = lua
.registry_value::<Vec<LuaValue>>(key)
.expect("Missing require result in lua registry");
Ok(LuaMultiValue::from_vec(multi_vec))
}
}
}
/**
Waits for the resulting value from the require cache.
Will panic if the path has not been cached, use [`is_cached`] first.
The cache uses absolute paths, so any given relative
path will first be transformed into an absolute path.
*/
pub async fn wait_for_cache<'lua>(
&'lua self,
lua: &'lua Lua,
path: impl AsRef<str> + 'lua,
) -> LuaResult<LuaMultiValue<'lua>> {
let path = self.abs_path(path);
let sched = lua
.app_data_ref::<&Scheduler>()
.expect("Lua struct is missing scheduler");
let thread_id = {
let pending = self
.cache_pending
.try_lock()
.expect("RequireContext may not be used from multiple threads");
let thread_id = pending
.get(&path)
.expect("Path is not currently pending require");
*thread_id
};
sched.wait_for_thread(thread_id).await
}
/**
Loads (requires) the file at the given path.
The cache uses absolute paths, so any given relative
path will first be transformed into an absolute path.
*/
pub async fn load<'lua>(
&'lua self,
lua: &'lua Lua,
path: impl AsRef<str> + 'lua,
) -> LuaResult<LuaMultiValue<'lua>> {
let path = self.abs_path(path);
let sched = lua
.app_data_ref::<&Scheduler>()
.expect("Lua struct is missing scheduler");
// TODO: Store any fs error in the cache, too
let file_contents = fs::read(&path).await?;
// TODO: Store any lua loading/parsing error in the cache, too
// TODO: Set chunk name as file name relative to cwd
let file_thread = lua
.load(file_contents)
.into_function()?
.into_owned_lua_thread(lua)?;
// Schedule the thread to run and store the pending thread id in the require context
let thread_id = {
let thread_id = sched.push_back(file_thread, ())?;
self.cache_pending
.try_lock()
.expect("RequireContext may not be used from multiple threads")
.insert(path.clone(), thread_id);
thread_id
};
// Wait for the thread to finish running
let thread_res = sched.wait_for_thread(thread_id).await;
// Clone the result and store it in the cache, note
// that cloning a [`LuaValue`] will still refer to
// the same underlying lua data and indentity
let result = match thread_res.clone() {
Err(e) => Err(e),
Ok(multi) => {
let multi_vec = multi.into_vec();
let multi_key = lua
.create_registry_value(multi_vec)
.expect("Failed to store require result in registry");
Ok(multi_key)
}
};
// NOTE: We use the async lock and not try_lock here because
// some other thread may be wanting to insert into the require
// cache at the same time, and that's not an actual error case
self.cache_results.lock().await.insert(path.clone(), result);
// Remove the pending thread id from the require context
self.cache_pending
.try_lock()
.expect("RequireContext may not be used from multiple threads")
.remove(&path)
.expect("Pending require thread id was unexpectedly removed");
thread_res
}
} }
impl LuaUserData for RequireContext {} impl LuaUserData for RequireContext {}
@ -41,3 +240,11 @@ impl<'lua> FromLua<'lua> for RequireContext {
unreachable!("RequireContext should only be used from registry") unreachable!("RequireContext should only be used from registry")
} }
} }
impl<'lua> From<&'lua Lua> for RequireContext {
fn from(value: &'lua Lua) -> Self {
value
.named_registry_value(REGISTRY_KEY)
.expect("Missing require context in lua registry")
}
}

View file

@ -11,10 +11,10 @@ mod builtin;
mod relative; mod relative;
pub fn create(lua: &'static Lua) -> LuaResult<impl IntoLua<'_>> { pub fn create(lua: &'static Lua) -> LuaResult<impl IntoLua<'_>> {
RequireContext::new().insert_into_registry(lua); RequireContext::create(lua);
lua.create_async_function(|lua, path: LuaString| async move { lua.create_async_function(|lua, path: LuaString| async move {
let context = RequireContext::from_registry(lua); let context = RequireContext::from(lua);
let path = path let path = path
.to_str() .to_str()
@ -32,7 +32,7 @@ pub fn create(lua: &'static Lua) -> LuaResult<impl IntoLua<'_>> {
"Require with custom alias must contain '/' delimiter", "Require with custom alias must contain '/' delimiter",
))?; ))?;
alias::require(lua, context, alias, name).await alias::require(lua, context, alias, name).await
} else if context.use_absolute_paths { } else if context.use_absolute_paths() {
absolute::require(lua, context, &path).await absolute::require(lua, context, &path).await
} else { } else {
relative::require(lua, context, &path).await relative::require(lua, context, &path).await

View file

@ -17,11 +17,12 @@ mod impl_async;
mod impl_runner; mod impl_runner;
mod impl_threads; mod impl_threads;
pub use self::thread::SchedulerThreadId;
pub use self::traits::*; pub use self::traits::*;
use self::{ use self::{
state::SchedulerState, state::SchedulerState,
thread::{SchedulerThread, SchedulerThreadId, SchedulerThreadSender}, thread::{SchedulerThread, SchedulerThreadSender},
}; };
type SchedulerFuture<'fut> = Pin<Box<dyn Future<Output = ()> + 'fut>>; type SchedulerFuture<'fut> = Pin<Box<dyn Future<Output = ()> + 'fut>>;