// Mirror of https://github.com/CompeyDev/lune-packaging.git
// (synced 2025-01-10 20:49:10 +00:00; 290 lines, 9.9 KiB, Rust)
use std::{
|
|
cell::RefCell,
|
|
collections::{HashMap, HashSet},
|
|
env::current_dir,
|
|
path::{self, PathBuf},
|
|
sync::Arc,
|
|
};
|
|
|
|
use dunce::canonicalize;
|
|
use mlua::prelude::*;
|
|
use tokio::fs;
|
|
use tokio::sync::Mutex as AsyncMutex;
|
|
|
|
use crate::lua::{
|
|
table::TableBuilder,
|
|
task::{TaskScheduler, TaskSchedulerScheduleExt},
|
|
};
|
|
|
|
use super::require_waker::{RequireWakerFuture, RequireWakerState};
|
|
|
|
/// Lua source of the `require` function that is exposed to scripts.
///
/// The chunk asks `info` for the source of stack level 1 and, when that source
/// is the "require" chunk itself, retries with stack level 2. It then passes
/// `context`, the resolved source and the caller's arguments on to `load`, and
/// finally returns whatever `yield()` returns. The names `info`, `load`,
/// `context` and `yield` are expected to be provided by the chunk environment.
const REQUIRE_IMPL_LUA: &str = r#"
local source = info(1, "s")
if source == '[string "require"]' then
source = info(2, "s")
end
load(context, source, ...)
return yield()
"#;
|
|
|
|
/// Shared waker states of every pending `require` that is waiting on one path.
type RequireWakersVec<'lua> = Vec<Arc<AsyncMutex<RequireWakerState<'lua>>>>;
|
|
|
|
#[derive(Debug, Clone, Default)]
|
|
struct RequireContext<'lua> {
|
|
// NOTE: We need to use arc here so that mlua clones
|
|
// the reference and not the entire inner value(s)
|
|
builtins: Arc<HashMap<String, LuaMultiValue<'lua>>>,
|
|
cached: Arc<RefCell<HashMap<String, LuaResult<LuaMultiValue<'lua>>>>>,
|
|
wakers: Arc<RefCell<HashMap<String, RequireWakersVec<'lua>>>>,
|
|
locks: Arc<RefCell<HashSet<String>>>,
|
|
pwd: String,
|
|
}
|
|
|
|
impl<'lua> RequireContext<'lua> {
|
|
pub fn new<K, V>(lua: &'lua Lua, builtins_vec: Vec<(K, V)>) -> LuaResult<Self>
|
|
where
|
|
K: Into<String>,
|
|
V: ToLua<'lua>,
|
|
{
|
|
let mut pwd = current_dir()
|
|
.expect("Failed to access current working directory")
|
|
.to_string_lossy()
|
|
.to_string();
|
|
if !pwd.ends_with(path::MAIN_SEPARATOR) {
|
|
pwd = format!("{pwd}{}", path::MAIN_SEPARATOR)
|
|
}
|
|
let mut builtins = HashMap::new();
|
|
for (key, value) in builtins_vec {
|
|
builtins.insert(key.into(), value.to_lua_multi(lua)?);
|
|
}
|
|
Ok(Self {
|
|
pwd,
|
|
builtins: Arc::new(builtins),
|
|
..Default::default()
|
|
})
|
|
}
|
|
|
|
pub fn is_locked(&self, absolute_path: &str) -> bool {
|
|
self.locks.borrow().contains(absolute_path)
|
|
}
|
|
|
|
pub fn set_locked(&self, absolute_path: &str) -> bool {
|
|
self.locks.borrow_mut().insert(absolute_path.to_string())
|
|
}
|
|
|
|
pub fn set_unlocked(&self, absolute_path: &str) -> bool {
|
|
self.locks.borrow_mut().remove(absolute_path)
|
|
}
|
|
|
|
pub fn try_acquire_lock_sync(&self, absolute_path: &str) -> bool {
|
|
if self.is_locked(absolute_path) {
|
|
false
|
|
} else {
|
|
self.set_locked(absolute_path);
|
|
true
|
|
}
|
|
}
|
|
|
|
pub fn set_cached(&self, absolute_path: &str, result: &LuaResult<LuaMultiValue<'lua>>) {
|
|
self.cached
|
|
.borrow_mut()
|
|
.insert(absolute_path.to_string(), result.clone());
|
|
if let Some(wakers) = self.wakers.borrow_mut().remove(absolute_path) {
|
|
for waker in wakers {
|
|
waker
|
|
.try_lock()
|
|
.expect("Failed to lock waker")
|
|
.finalize(result.clone());
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn wait_for_cache(self, absolute_path: &str) -> RequireWakerFuture<'lua> {
|
|
let state = RequireWakerState::new();
|
|
let fut = RequireWakerFuture::new(&state);
|
|
self.wakers
|
|
.borrow_mut()
|
|
.entry(absolute_path.to_string())
|
|
.or_insert_with(Vec::new)
|
|
.push(Arc::clone(&state));
|
|
fut
|
|
}
|
|
|
|
pub fn get_paths(
|
|
&self,
|
|
require_source: String,
|
|
require_path: String,
|
|
) -> LuaResult<(String, String)> {
|
|
if require_path.starts_with('@') {
|
|
return Ok((require_path.clone(), require_path));
|
|
}
|
|
let path_relative_to_pwd = PathBuf::from(
|
|
&require_source
|
|
.trim_start_matches("[string \"")
|
|
.trim_end_matches("\"]"),
|
|
)
|
|
.parent()
|
|
.unwrap()
|
|
.join(&require_path);
|
|
// Try to normalize and resolve relative path segments such as './' and '../'
|
|
let file_path = match (
|
|
canonicalize(path_relative_to_pwd.with_extension("luau")),
|
|
canonicalize(path_relative_to_pwd.with_extension("lua")),
|
|
) {
|
|
(Ok(luau), _) => luau,
|
|
(_, Ok(lua)) => lua,
|
|
_ => {
|
|
return Err(LuaError::RuntimeError(format!(
|
|
"File does not exist at path '{require_path}'"
|
|
)))
|
|
}
|
|
};
|
|
let absolute = file_path.to_string_lossy().to_string();
|
|
let relative = absolute.trim_start_matches(&self.pwd).to_string();
|
|
Ok((absolute, relative))
|
|
}
|
|
}
|
|
|
|
// Marks the context as Lua userdata; the empty body adds no fields or methods
impl<'lua> LuaUserData for RequireContext<'lua> {}
|
|
|
|
fn load_builtin<'lua>(
|
|
_lua: &'lua Lua,
|
|
context: RequireContext<'lua>,
|
|
module_name: String,
|
|
_has_acquired_lock: bool,
|
|
) -> LuaResult<LuaMultiValue<'lua>> {
|
|
match context.builtins.get(&module_name) {
|
|
Some(module) => Ok(module.clone()),
|
|
None => Err(LuaError::RuntimeError(format!(
|
|
"No builtin module exists with the name '{}'",
|
|
module_name
|
|
))),
|
|
}
|
|
}
|
|
|
|
async fn load_file<'lua>(
|
|
lua: &'lua Lua,
|
|
context: RequireContext<'lua>,
|
|
absolute_path: String,
|
|
relative_path: String,
|
|
has_acquired_lock: bool,
|
|
) -> LuaResult<LuaMultiValue<'lua>> {
|
|
let cached = { context.cached.borrow().get(&absolute_path).cloned() };
|
|
match cached {
|
|
Some(cached) => cached,
|
|
None => {
|
|
if !has_acquired_lock {
|
|
return context.wait_for_cache(&absolute_path).await;
|
|
}
|
|
// Try to read the wanted file, note that we use bytes instead of reading
|
|
// to a string since lua scripts are not necessarily valid utf-8 strings
|
|
let contents = fs::read(&absolute_path).await.map_err(LuaError::external)?;
|
|
// Use a name without extensions for loading the chunk, some
|
|
// other code assumes the require path is without extensions
|
|
let path_relative_no_extension = relative_path
|
|
.trim_end_matches(".lua")
|
|
.trim_end_matches(".luau");
|
|
// Load the file into a thread
|
|
let loaded_func = lua
|
|
.load(&contents)
|
|
.set_name(path_relative_no_extension)?
|
|
.into_function()?;
|
|
let loaded_thread = lua.create_thread(loaded_func)?;
|
|
// Run the thread and wait for completion using the native task scheduler waker
|
|
let task_fut = {
|
|
let sched = lua.app_data_ref::<&TaskScheduler>().unwrap();
|
|
let task = sched.schedule_blocking(loaded_thread, LuaMultiValue::new())?;
|
|
sched.wait_for_task_completion(task)
|
|
};
|
|
// Wait for the thread to finish running, cache + return our result,
|
|
// notify any other threads that are also waiting on this to finish
|
|
let rets = task_fut.await;
|
|
context.set_cached(&absolute_path, &rets);
|
|
rets
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn load<'lua>(
|
|
lua: &'lua Lua,
|
|
context: RequireContext<'lua>,
|
|
absolute_path: String,
|
|
relative_path: String,
|
|
has_acquired_lock: bool,
|
|
) -> LuaResult<LuaMultiValue<'lua>> {
|
|
let result = if absolute_path == relative_path && absolute_path.starts_with('@') {
|
|
if let Some(module_name) = absolute_path.strip_prefix("@lune/") {
|
|
load_builtin(
|
|
lua,
|
|
context.clone(),
|
|
module_name.to_string(),
|
|
has_acquired_lock,
|
|
)
|
|
} else {
|
|
// FUTURE: '@' can be used a special prefix for users to set their own
|
|
// paths relative to a project file, similar to typescript paths config
|
|
// https://www.typescriptlang.org/tsconfig#paths
|
|
Err(LuaError::RuntimeError(
|
|
"Require paths prefixed by '@' are not yet supported".to_string(),
|
|
))
|
|
}
|
|
} else {
|
|
load_file(
|
|
lua,
|
|
context.clone(),
|
|
absolute_path.to_string(),
|
|
relative_path,
|
|
has_acquired_lock,
|
|
)
|
|
.await
|
|
};
|
|
if has_acquired_lock {
|
|
context.set_unlocked(&absolute_path);
|
|
}
|
|
result
|
|
}
|
|
|
|
pub fn create<K, V>(lua: &'static Lua, builtins: Vec<(K, V)>) -> LuaResult<LuaFunction>
|
|
where
|
|
K: Clone + Into<String>,
|
|
V: Clone + ToLua<'static>,
|
|
{
|
|
let require_context = RequireContext::new(lua, builtins)?;
|
|
let require_yield: LuaFunction = lua.named_registry_value("co.yield")?;
|
|
let require_info: LuaFunction = lua.named_registry_value("dbg.info")?;
|
|
let require_print: LuaFunction = lua.named_registry_value("print")?;
|
|
|
|
let require_env = TableBuilder::new(lua)?
|
|
.with_value("context", require_context)?
|
|
.with_value("yield", require_yield)?
|
|
.with_value("info", require_info)?
|
|
.with_value("print", require_print)?
|
|
.with_function(
|
|
"load",
|
|
|lua, (context, require_source, require_path): (RequireContext, String, String)| {
|
|
let (absolute_path, relative_path) =
|
|
context.get_paths(require_source, require_path)?;
|
|
// NOTE: We can not acquire the lock in the async part of the require
|
|
// load process since several requires may have happened for the
|
|
// same path before the async load task even gets a chance to run
|
|
let has_lock = context.try_acquire_lock_sync(&absolute_path);
|
|
let fut = load(lua, context, absolute_path, relative_path, has_lock);
|
|
let sched = lua
|
|
.app_data_ref::<&TaskScheduler>()
|
|
.expect("Missing task scheduler as a lua app data");
|
|
sched.queue_async_task_inherited(lua.current_thread(), None, async {
|
|
let rets = fut.await?;
|
|
let mult = rets.to_lua_multi(lua)?;
|
|
Ok(Some(mult))
|
|
})
|
|
},
|
|
)?
|
|
.build_readonly()?;
|
|
|
|
let require_fn_lua = lua
|
|
.load(REQUIRE_IMPL_LUA)
|
|
.set_name("require")?
|
|
.set_environment(require_env)?
|
|
.into_function()?;
|
|
Ok(require_fn_lua)
|
|
}
|