luaurc rewrite

This commit is contained in:
highflowey 2024-08-23 00:26:39 +03:30
parent ff83c401b8
commit ba2d7203f8
8 changed files with 101 additions and 744 deletions

View file

@ -1,73 +0,0 @@
use mlua::prelude::*;
use lune_utils::path::{clean_path_and_make_absolute, diff_path, get_current_dir};
use crate::luaurc::LuauRc;
use super::context::*;
pub(super) async fn require<'lua, 'ctx>(
lua: &'lua Lua,
ctx: &'ctx RequireContext,
source: &str,
alias: &str,
path: &str,
) -> LuaResult<LuaMultiValue<'lua>>
where
'lua: 'ctx,
{
let alias = alias.to_ascii_lowercase();
let parent = clean_path_and_make_absolute(source)
.parent()
.expect("how did a root path end up here..")
.to_path_buf();
// Try to gather the first luaurc and / or error we
// encounter to display better error messages to users
let mut first_luaurc = None;
let mut first_error = None;
let predicate = |rc: &LuauRc| {
if first_luaurc.is_none() {
first_luaurc.replace(rc.clone());
}
if let Err(e) = rc.validate() {
if first_error.is_none() {
first_error.replace(e);
}
false
} else {
rc.find_alias(&alias).is_some()
}
};
// Try to find a luaurc that contains the alias we're searching for
let luaurc = LuauRc::read_recursive(parent, predicate)
.await
.ok_or_else(|| {
if let Some(error) = first_error {
LuaError::runtime(format!("error while parsing .luaurc file: {error}"))
} else if let Some(luaurc) = first_luaurc {
LuaError::runtime(format!(
"failed to find alias '{alias}' - known aliases:\n{}",
luaurc
.aliases()
.iter()
.map(|(name, path)| format!(" {name} > {path}"))
.collect::<Vec<_>>()
.join("\n")
))
} else {
LuaError::runtime(format!("failed to find alias '{alias}' (no .luaurc)"))
}
})?;
// We now have our aliased path, our path require function just needs it
// in a slightly different format with both absolute + relative to cwd
let abs_path = luaurc.find_alias(&alias).unwrap().join(path);
let rel_path = diff_path(&abs_path, get_current_dir()).ok_or_else(|| {
LuaError::runtime(format!("failed to find relative path for alias '{alias}'"))
})?;
super::path::require_abs_rel(lua, ctx, abs_path, rel_path).await
}

View file

@ -1,289 +0,0 @@
use std::{
collections::HashMap,
path::{Path, PathBuf},
sync::Arc,
};
use mlua::prelude::*;
use mlua_luau_scheduler::LuaSchedulerExt;
use tokio::{
fs::read,
sync::{
broadcast::{self, Sender},
Mutex as AsyncMutex,
},
};
use lune_utils::path::{clean_path, clean_path_and_make_absolute};
use crate::library::LuneStandardLibrary;
/**
Context containing cached results for all `require` operations.
The cache uses absolute paths, so any given relative
path will first be transformed into an absolute path.
*/
#[derive(Debug, Clone)]
pub(super) struct RequireContext {
libraries: Arc<AsyncMutex<HashMap<LuneStandardLibrary, LuaResult<LuaRegistryKey>>>>,
results: Arc<AsyncMutex<HashMap<PathBuf, LuaResult<LuaRegistryKey>>>>,
pending: Arc<AsyncMutex<HashMap<PathBuf, Sender<()>>>>,
}
impl RequireContext {
/**
Creates a new require context for the given [`Lua`] struct.
Note that this require context is global and only one require
context should be created per [`Lua`] struct, creating more
than one context may lead to undefined require-behavior.
*/
pub fn new() -> Self {
Self {
libraries: Arc::new(AsyncMutex::new(HashMap::new())),
results: Arc::new(AsyncMutex::new(HashMap::new())),
pending: Arc::new(AsyncMutex::new(HashMap::new())),
}
}
/**
Resolves the given `source` and `path` into require paths
to use, based on the current require context settings.
This will resolve path segments such as `./`, `../`, ..., and
if the resolved path is not an absolute path, will create an
absolute path by prepending the current working directory.
*/
pub fn resolve_paths(
source: impl AsRef<str>,
path: impl AsRef<str>,
) -> LuaResult<(PathBuf, PathBuf)> {
let path = PathBuf::from(source.as_ref())
.parent()
.ok_or_else(|| LuaError::runtime("Failed to get parent path of source"))?
.join(path.as_ref());
let abs_path = clean_path_and_make_absolute(&path);
let rel_path = clean_path(path);
Ok((abs_path, rel_path))
}
/**
Checks if the given path has a cached require result.
*/
pub fn is_cached(&self, abs_path: impl AsRef<Path>) -> LuaResult<bool> {
let is_cached = self
.results
.try_lock()
.expect("RequireContext may not be used from multiple threads")
.contains_key(abs_path.as_ref());
Ok(is_cached)
}
/**
Checks if the given path is currently being used in `require`.
*/
pub fn is_pending(&self, abs_path: impl AsRef<Path>) -> LuaResult<bool> {
let is_pending = self
.pending
.try_lock()
.expect("RequireContext may not be used from multiple threads")
.contains_key(abs_path.as_ref());
Ok(is_pending)
}
/**
Gets the resulting value from the require cache.
Will panic if the path has not been cached, use [`is_cached`] first.
*/
pub fn get_from_cache<'lua>(
&self,
lua: &'lua Lua,
abs_path: impl AsRef<Path>,
) -> LuaResult<LuaMultiValue<'lua>> {
let results = self
.results
.try_lock()
.expect("RequireContext may not be used from multiple threads");
let cached = results
.get(abs_path.as_ref())
.expect("Path does not exist in results cache");
match cached {
Err(e) => Err(e.clone()),
Ok(k) => {
let multi_vec = lua
.registry_value::<Vec<LuaValue>>(k)
.expect("Missing require result in lua registry");
Ok(LuaMultiValue::from_vec(multi_vec))
}
}
}
/**
Waits for the resulting value from the require cache.
Will panic if the path has not been cached, use [`is_cached`] first.
*/
pub async fn wait_for_cache<'lua>(
&self,
lua: &'lua Lua,
abs_path: impl AsRef<Path>,
) -> LuaResult<LuaMultiValue<'lua>> {
let mut thread_recv = {
let pending = self
.pending
.try_lock()
.expect("RequireContext may not be used from multiple threads");
let thread_id = pending
.get(abs_path.as_ref())
.expect("Path is not currently pending require");
thread_id.subscribe()
};
thread_recv.recv().await.into_lua_err()?;
self.get_from_cache(lua, abs_path.as_ref())
}
async fn load<'lua>(
&self,
lua: &'lua Lua,
abs_path: impl AsRef<Path>,
rel_path: impl AsRef<Path>,
) -> LuaResult<LuaRegistryKey> {
let abs_path = abs_path.as_ref();
let rel_path = rel_path.as_ref();
// Read the file at the given path, try to parse and
// load it into a new lua thread that we can schedule
let file_contents = read(&abs_path).await?;
let file_thread = lua
.load(file_contents)
.set_name(rel_path.to_string_lossy().to_string());
// Schedule the thread to run, wait for it to finish running
let thread_id = lua.push_thread_back(file_thread, ())?;
lua.track_thread(thread_id);
lua.wait_for_thread(thread_id).await;
let thread_res = lua.get_thread_result(thread_id).unwrap();
// Return the result of the thread, storing any lua value(s) in the registry
match thread_res {
Err(e) => Err(e),
Ok(v) => {
let multi_vec = v.into_vec();
let multi_key = lua
.create_registry_value(multi_vec)
.expect("Failed to store require result in registry - out of memory");
Ok(multi_key)
}
}
}
/**
Loads (requires) the file at the given path.
*/
pub async fn load_with_caching<'lua>(
&self,
lua: &'lua Lua,
abs_path: impl AsRef<Path>,
rel_path: impl AsRef<Path>,
) -> LuaResult<LuaMultiValue<'lua>> {
let abs_path = abs_path.as_ref();
let rel_path = rel_path.as_ref();
// Set this abs path as currently pending
let (broadcast_tx, _) = broadcast::channel(1);
self.pending
.try_lock()
.expect("RequireContext may not be used from multiple threads")
.insert(abs_path.to_path_buf(), broadcast_tx);
// Try to load at this abs path
let load_res = self.load(lua, abs_path, rel_path).await;
let load_val = match &load_res {
Err(e) => Err(e.clone()),
Ok(k) => {
let multi_vec = lua
.registry_value::<Vec<LuaValue>>(k)
.expect("Failed to fetch require result from registry");
Ok(LuaMultiValue::from_vec(multi_vec))
}
};
// NOTE: We use the async lock and not try_lock here because
// some other thread may be wanting to insert into the require
// cache at the same time, and that's not an actual error case
self.results
.lock()
.await
.insert(abs_path.to_path_buf(), load_res);
// Remove the pending thread id from the require context,
// broadcast a message to let any listeners know that this
// path has now finished the require process and is cached
let broadcast_tx = self
.pending
.try_lock()
.expect("RequireContext may not be used from multiple threads")
.remove(abs_path)
.expect("Pending require broadcaster was unexpectedly removed");
broadcast_tx.send(()).ok();
load_val
}
/**
Loads (requires) the library with the given name.
*/
pub fn load_library<'lua>(
&self,
lua: &'lua Lua,
name: impl AsRef<str>,
) -> LuaResult<LuaMultiValue<'lua>> {
let library: LuneStandardLibrary = match name.as_ref().parse() {
Err(e) => return Err(LuaError::runtime(e)),
Ok(b) => b,
};
let mut cache = self
.libraries
.try_lock()
.expect("RequireContext may not be used from multiple threads");
if let Some(res) = cache.get(&library) {
return match res {
Err(e) => return Err(e.clone()),
Ok(key) => {
let multi_vec = lua
.registry_value::<Vec<LuaValue>>(key)
.expect("Missing library result in lua registry");
Ok(LuaMultiValue::from_vec(multi_vec))
}
};
};
let result = library.module(lua);
cache.insert(
library,
match result.clone() {
Err(e) => Err(e),
Ok(multi) => {
let multi_vec = multi.into_vec();
let multi_key = lua
.create_registry_value(multi_vec)
.expect("Failed to store require result in registry - out of memory");
Ok(multi_key)
}
},
);
result
}
}

View file

@ -1,14 +0,0 @@
use mlua::prelude::*;
use super::context::*;
pub(super) fn require<'lua, 'ctx>(
lua: &'lua Lua,
ctx: &'ctx RequireContext,
name: &str,
) -> LuaResult<LuaMultiValue<'lua>>
where
'lua: 'ctx,
{
ctx.load_library(lua, name)
}

View file

@ -1,93 +1,5 @@
use mlua::prelude::*;
use lune_utils::TableBuilder;
mod context;
use context::RequireContext;
mod alias;
mod library;
mod path;
const REQUIRE_IMPL: &str = r"
return require(source(), ...)
";
pub fn create(lua: &Lua) -> LuaResult<LuaValue> {
lua.set_app_data(RequireContext::new());
/*
Require implementation needs a few workarounds:
- Async functions run outside of the lua resumption cycle,
so the current lua thread, as well as its stack/debug info
is not available, meaning we have to use a normal function
- Using the async require function directly in another lua function
would mean yielding across the metamethod/c-call boundary, meaning
we have to first load our two functions into a normal lua chunk
and then load that new chunk into our final require function
Also note that we inspect the stack at level 2:
1. The current c / rust function
2. The wrapper lua chunk defined above
3. The lua chunk we are require-ing from
*/
let require_fn = lua.create_async_function(require)?;
let get_source_fn = lua.create_function(move |lua, (): ()| match lua.inspect_stack(2) {
None => Err(LuaError::runtime(
"Failed to get stack info for require source",
)),
Some(info) => match info.source().source {
None => Err(LuaError::runtime(
"Stack info is missing source for require",
)),
Some(source) => lua.create_string(source.as_bytes()),
},
})?;
let require_env = TableBuilder::new(lua)?
.with_value("source", get_source_fn)?
.with_value("require", require_fn)?
.build_readonly()?;
lua.load(REQUIRE_IMPL)
.set_name("require")
.set_environment(require_env)
.into_function()?
.into_lua(lua)
}
async fn require<'lua>(
lua: &'lua Lua,
(source, path): (LuaString<'lua>, LuaString<'lua>),
) -> LuaResult<LuaMultiValue<'lua>> {
let source = source
.to_str()
.into_lua_err()
.context("Failed to parse require source as string")?
.to_string();
let path = path
.to_str()
.into_lua_err()
.context("Failed to parse require path as string")?
.to_string();
let context = lua
.app_data_ref()
.expect("Failed to get RequireContext from app data");
if let Some(builtin_name) = path.strip_prefix("@lune/").map(str::to_ascii_lowercase) {
library::require(lua, &context, &builtin_name)
} else if let Some(aliased_path) = path.strip_prefix('@') {
let (alias, path) = aliased_path.split_once('/').ok_or(LuaError::runtime(
"Require with custom alias must contain '/' delimiter",
))?;
alias::require(lua, &context, &source, alias, path).await
} else {
path::require(lua, &context, &source, &path).await
}
todo!()
}

View file

@ -1,129 +0,0 @@
use std::path::{Path, PathBuf};
use mlua::prelude::*;
use mlua::Error::ExternalError;
use super::context::*;
pub(super) async fn require<'lua, 'ctx>(
lua: &'lua Lua,
ctx: &'ctx RequireContext,
source: &str,
path: &str,
) -> LuaResult<LuaMultiValue<'lua>>
where
'lua: 'ctx,
{
let (abs_path, rel_path) = RequireContext::resolve_paths(source, path)?;
require_abs_rel(lua, ctx, abs_path, rel_path).await
}
pub(super) async fn require_abs_rel<'lua, 'ctx>(
lua: &'lua Lua,
ctx: &'ctx RequireContext,
abs_path: PathBuf, // Absolute to filesystem
rel_path: PathBuf, // Relative to CWD (for displaying)
) -> LuaResult<LuaMultiValue<'lua>>
where
'lua: 'ctx,
{
// 1. Try to require the exact path
match require_inner(lua, ctx, &abs_path, &rel_path).await {
Ok(res) => return Ok(res),
Err(err) => {
if !is_file_not_found_error(&err) {
return Err(err);
}
}
}
// 2. Try to require the path with an added "luau" extension
// 3. Try to require the path with an added "lua" extension
for extension in ["luau", "lua"] {
match require_inner(
lua,
ctx,
&append_extension(&abs_path, extension),
&append_extension(&rel_path, extension),
)
.await
{
Ok(res) => return Ok(res),
Err(err) => {
if !is_file_not_found_error(&err) {
return Err(err);
}
}
}
}
// We didn't find any direct file paths, look
// for directories with "init" files in them...
let abs_init = abs_path.join("init");
let rel_init = rel_path.join("init");
// 4. Try to require the init path with an added "luau" extension
// 5. Try to require the init path with an added "lua" extension
for extension in ["luau", "lua"] {
match require_inner(
lua,
ctx,
&append_extension(&abs_init, extension),
&append_extension(&rel_init, extension),
)
.await
{
Ok(res) => return Ok(res),
Err(err) => {
if !is_file_not_found_error(&err) {
return Err(err);
}
}
}
}
// Nothing left to try, throw an error
Err(LuaError::runtime(format!(
"No file exists at the path '{}'",
rel_path.display()
)))
}
async fn require_inner<'lua, 'ctx>(
lua: &'lua Lua,
ctx: &'ctx RequireContext,
abs_path: impl AsRef<Path>,
rel_path: impl AsRef<Path>,
) -> LuaResult<LuaMultiValue<'lua>>
where
'lua: 'ctx,
{
let abs_path = abs_path.as_ref();
let rel_path = rel_path.as_ref();
if ctx.is_cached(abs_path)? {
ctx.get_from_cache(lua, abs_path)
} else if ctx.is_pending(abs_path)? {
ctx.wait_for_cache(lua, &abs_path).await
} else {
ctx.load_with_caching(lua, &abs_path, &rel_path).await
}
}
fn append_extension(path: impl Into<PathBuf>, ext: &'static str) -> PathBuf {
let mut new = path.into();
match new.extension() {
// FUTURE: There's probably a better way to do this than converting to a lossy string
Some(e) => new.set_extension(format!("{}.{ext}", e.to_string_lossy())),
None => new.set_extension(ext),
};
new
}
fn is_file_not_found_error(err: &LuaError) -> bool {
if let ExternalError(err) = err {
err.as_ref().downcast_ref::<std::io::Error>().is_some()
} else {
false
}
}

View file

@ -6,6 +6,7 @@ mod global;
mod globals;
mod library;
mod luaurc;
mod path;
pub use self::global::LuneStandardGlobal;
pub use self::globals::version::set_global_version;

View file

@ -1,164 +1,84 @@
use crate::path::get_parent_path;
use mlua::ExternalResult;
use serde::Deserialize;
use std::{
collections::HashMap,
path::{Path, PathBuf, MAIN_SEPARATOR},
sync::Arc,
env::current_dir,
path::{Path, PathBuf},
};
use tokio::fs;
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use tokio::fs::read;
use lune_utils::path::{clean_path, clean_path_and_make_absolute};
const LUAURC_FILE: &str = ".luaurc";
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
enum LuauLanguageMode {
NoCheck,
NonStrict,
Strict,
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RequireAlias<'a> {
pub alias: &'a str,
pub path: &'a str,
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct LuauRcConfig {
#[serde(skip_serializing_if = "Option::is_none")]
language_mode: Option<LuauLanguageMode>,
#[serde(skip_serializing_if = "Option::is_none")]
lint: Option<HashMap<String, JsonValue>>,
#[serde(skip_serializing_if = "Option::is_none")]
lint_errors: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
type_errors: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
globals: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
paths: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
aliases: Option<HashMap<String, String>>,
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct Luaurc {
aliases: Option<HashMap<String, PathBuf>>,
}
/**
A deserialized `.luaurc` file.
Contains utility methods for validating and searching for aliases.
*/
#[derive(Debug, Clone)]
pub struct LuauRc {
dir: Arc<Path>,
config: LuauRcConfig,
}
impl LuauRc {
/**
Reads a `.luaurc` file from the given directory.
If the file does not exist, or if it is invalid, this function returns `None`.
*/
pub async fn read(dir: impl AsRef<Path>) -> Option<Self> {
let dir = clean_path_and_make_absolute(dir);
let path = dir.join(LUAURC_FILE);
let bytes = read(&path).await.ok()?;
let config = serde_json::from_slice(&bytes).ok()?;
Some(Self {
dir: dir.into(),
config,
})
}
/**
Reads a `.luaurc` file from the given directory, and then recursively searches
for a `.luaurc` file in the parent directories if a predicate is not satisfied.
If no `.luaurc` file exists, or if they are invalid, this function returns `None`.
*/
pub async fn read_recursive(
dir: impl AsRef<Path>,
mut predicate: impl FnMut(&Self) -> bool,
) -> Option<Self> {
let mut current = clean_path_and_make_absolute(dir);
loop {
if let Some(rc) = Self::read(&current).await {
if predicate(&rc) {
return Some(rc);
}
}
if let Some(parent) = current.parent() {
current = parent.to_path_buf();
} else {
return None;
}
}
}
/**
Validates that the `.luaurc` file is correct.
This primarily validates aliases since they are not
validated during creation of the [`LuauRc`] struct.
# Errors
If an alias key is invalid.
*/
pub fn validate(&self) -> Result<(), String> {
if let Some(aliases) = &self.config.aliases {
for alias in aliases.keys() {
if !is_valid_alias_key(alias) {
return Err(format!("invalid alias key: {alias}"));
}
}
}
Ok(())
}
/**
Gets a copy of all aliases in the `.luaurc` file.
Will return an empty map if there are no aliases.
*/
#[must_use]
pub fn aliases(&self) -> HashMap<String, String> {
self.config.aliases.clone().unwrap_or_default()
}
/**
Finds an alias in the `.luaurc` file by name.
If the alias does not exist, this function returns `None`.
*/
#[must_use]
pub fn find_alias(&self, name: &str) -> Option<PathBuf> {
self.config.aliases.as_ref().and_then(|aliases| {
aliases.iter().find_map(|(alias, path)| {
if alias
.trim_end_matches(MAIN_SEPARATOR)
.eq_ignore_ascii_case(name)
&& is_valid_alias_key(alias)
{
Some(clean_path(self.dir.join(path)))
} else {
None
}
})
})
}
}
fn is_valid_alias_key(alias: impl AsRef<str>) -> bool {
let alias = alias.as_ref();
if alias.is_empty()
|| alias.starts_with('.')
|| alias.starts_with("..")
|| alias.chars().any(|c| c == MAIN_SEPARATOR)
/// Parses path into `RequireAlias` struct
///
/// ### Examples
///
/// `@lune/task` becomes `Some({ alias: "lune", path: "task" })`
///
/// `../path/script` becomes `None`
pub fn path_to_alias(path: &Path) -> Result<Option<RequireAlias>, mlua::Error> {
if let Some(aliased_path) = path
.to_str()
.ok_or(mlua::Error::runtime("Couldn't turn path into string"))?
.strip_prefix('@')
{
false // Paths are not valid alias keys
let (alias, path) = aliased_path.split_once('/').ok_or(mlua::Error::runtime(
"Require with alias doesn't contain '/'",
))?;
Ok(Some(RequireAlias { alias, path }))
} else {
alias.chars().all(is_valid_alias_char)
Ok(None)
}
}
fn is_valid_alias_char(c: char) -> bool {
c.is_ascii_alphanumeric() || c == '-' || c == '_' || c == '.'
async fn parse_luaurc(_: &mlua::Lua, path: &PathBuf) -> Result<Option<Luaurc>, mlua::Error> {
if fs::try_exists(path).await? {
let content = fs::read(path).await?;
serde_json::from_slice(&content).map(Some).into_lua_err()
} else {
Ok(None)
}
}
/// Searches for .luaurc recursively
/// until an alias for the provided `RequireAlias` is found
pub async fn resolve_require_alias<'lua>(
lua: &'lua mlua::Lua,
alias: &'lua RequireAlias<'lua>,
) -> Result<PathBuf, mlua::Error> {
let cwd = current_dir()?;
let parent = cwd.join(get_parent_path(lua)?);
let ancestors = parent.ancestors();
for path in ancestors {
if path.starts_with(&cwd) {
if let Some(luaurc) = parse_luaurc(lua, &parent.join(".luaurc")).await? {
if let Some(aliases) = luaurc.aliases {
if let Some(alias_path) = aliases.get(alias.alias) {
let resolved = path.join(alias_path.join(alias.path));
return Ok(resolved);
}
}
}
} else {
break;
}
}
Err(mlua::Error::runtime(format!(
"Coudln't find the alias '{}' in any .luaurc file",
alias.alias
)))
}

View file

@ -0,0 +1,29 @@
use std::path::PathBuf;
pub fn get_script_path(lua: &mlua::Lua) -> Result<PathBuf, mlua::Error> {
let Some(debug) = lua.inspect_stack(2) else {
return Err(mlua::Error::runtime("Failed to inspect stack"));
};
match debug
.source()
.source
.map(|raw_source| PathBuf::from(raw_source.to_string()))
{
Some(script) => Ok(script),
None => Err(mlua::Error::runtime(
"Failed to get path of the script that called require",
)),
}
}
pub fn get_parent_path(lua: &mlua::Lua) -> Result<PathBuf, mlua::Error> {
let script = get_script_path(lua)?;
match script.parent() {
Some(parent) => Ok(parent.to_path_buf()),
None => Err(mlua::Error::runtime(
"Failed to get parent of the script that called require",
)),
}
}