Implement functionality necessary for relative path requires

This commit is contained in:
Filip Tibell 2023-08-19 23:00:05 -05:00
parent a91e24eb01
commit d6c31f67ba
3 changed files with 93 additions and 35 deletions

View file

@ -12,7 +12,7 @@ const REGISTRY_KEY: &str = "RequireContext";
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(super) struct RequireContext { pub(super) struct RequireContext {
use_absolute_paths: bool, use_cwd_relative_paths: bool,
working_directory: PathBuf, working_directory: PathBuf,
cache_builtins: Arc<AsyncMutex<HashMap<LuneBuiltin, LuaResult<LuaRegistryKey>>>>, cache_builtins: Arc<AsyncMutex<HashMap<LuneBuiltin, LuaResult<LuaRegistryKey>>>>,
cache_results: Arc<AsyncMutex<HashMap<PathBuf, LuaResult<LuaRegistryKey>>>>, cache_results: Arc<AsyncMutex<HashMap<PathBuf, LuaResult<LuaRegistryKey>>>>,
@ -30,9 +30,9 @@ impl RequireContext {
pub fn new() -> Self { pub fn new() -> Self {
let cwd = env::current_dir().expect("Failed to get current working directory"); let cwd = env::current_dir().expect("Failed to get current working directory");
Self { Self {
// TODO: Set to false by default, load some kind of config // FUTURE: We could load some kind of config or env var
// or env var to check if we should be using absolute paths // to check if we should be using cwd-relative paths
use_absolute_paths: true, use_cwd_relative_paths: false,
working_directory: cwd, working_directory: cwd,
cache_builtins: Arc::new(AsyncMutex::new(HashMap::new())), cache_builtins: Arc::new(AsyncMutex::new(HashMap::new())),
cache_results: Arc::new(AsyncMutex::new(HashMap::new())), cache_results: Arc::new(AsyncMutex::new(HashMap::new())),
@ -41,10 +41,10 @@ impl RequireContext {
} }
/** /**
If `require` should use absolute paths or not. If `require` should use cwd-relative paths or not.
*/ */
pub fn use_absolute_paths(&self) -> bool { pub fn use_cwd_relative_paths(&self) -> bool {
self.use_absolute_paths self.use_cwd_relative_paths
} }
/** /**

View file

@ -1,6 +1,6 @@
use mlua::prelude::*; use mlua::prelude::*;
use crate::lune::scheduler::LuaSchedulerExt; use crate::lune::{scheduler::LuaSchedulerExt, util::TableBuilder};
mod context; mod context;
use context::RequireContext; use context::RequireContext;
@ -10,35 +10,92 @@ mod alias;
mod builtin; mod builtin;
mod relative; mod relative;
const REQUIRE_IMPL: &str = r#"
return require(source(), ...)
"#;
pub fn create(lua: &'static Lua) -> LuaResult<impl IntoLua<'_>> { pub fn create(lua: &'static Lua) -> LuaResult<impl IntoLua<'_>> {
lua.set_app_data(RequireContext::new()); lua.set_app_data(RequireContext::new());
lua.create_async_function(|lua, path: LuaString| async move {
let path = path
.to_str()
.into_lua_err()
.context("Failed to parse require path as string")?
.to_string();
let context = lua /*
.app_data_ref() Require implementation needs a few workarounds:
.expect("Failed to get RequireContext from app data");
let res = if let Some(builtin_name) = path - Async functions run outside of the lua resumption cycle,
.strip_prefix("@lune/") so the current lua thread, as well as its stack/debug info
.map(|name| name.to_ascii_lowercase()) is not available, meaning we have to use a normal function
{
builtin::require(lua, &context, &builtin_name).await
} else if let Some(aliased_path) = path.strip_prefix('@') {
let (alias, name) = aliased_path.split_once('/').ok_or(LuaError::runtime(
"Require with custom alias must contain '/' delimiter",
))?;
alias::require(lua, &context, alias, name).await
} else if context.use_absolute_paths() {
absolute::require(lua, &context, &path).await
} else {
relative::require(lua, &context, &path).await
};
res.clone() - Using the async require function directly in another lua function
}) would mean yielding across the metamethod/c-call boundary, meaning
we have to first load our two functions into a normal lua chunk
and then load that new chunk into our final require function
Also note that we inspect the stack at level 2:
1. The current c / rust function
2. The wrapper lua chunk defined above
3. The lua chunk we are require-ing from
*/
let require_fn = lua.create_async_function(require)?;
let get_source_fn = lua.create_function(move |lua, _: ()| match lua.inspect_stack(2) {
None => Err(LuaError::runtime(
"Failed to get stack info for require source",
)),
Some(info) => match info.source().source {
None => Err(LuaError::runtime(
"Stack info is missing source for require",
)),
Some(source) => lua.create_string(source.as_bytes()),
},
})?;
let require_env = TableBuilder::new(lua)?
.with_value("source", get_source_fn)?
.with_value("require", require_fn)?
.build_readonly()?;
lua.load(REQUIRE_IMPL)
.set_name("require")
.set_environment(require_env)
.into_function()
}
async fn require<'lua>(
lua: &'lua Lua,
(source, path): (LuaString<'lua>, LuaString<'lua>),
) -> LuaResult<LuaMultiValue<'lua>>
where
'lua: 'static, // FIXME: Remove static lifetime bound here when builtin libraries no longer need it
{
let source = source
.to_str()
.into_lua_err()
.context("Failed to parse require source as string")?
.to_string();
let path = path
.to_str()
.into_lua_err()
.context("Failed to parse require path as string")?
.to_string();
let context = lua
.app_data_ref()
.expect("Failed to get RequireContext from app data");
if let Some(builtin_name) = path
.strip_prefix("@lune/")
.map(|name| name.to_ascii_lowercase())
{
builtin::require(lua, &context, &builtin_name).await
} else if let Some(aliased_path) = path.strip_prefix('@') {
let (alias, name) = aliased_path.split_once('/').ok_or(LuaError::runtime(
"Require with custom alias must contain '/' delimiter",
))?;
alias::require(lua, &context, alias, name).await
} else if context.use_cwd_relative_paths() {
absolute::require(lua, &context, &path).await
} else {
relative::require(lua, &context, &source, &path).await
}
} }

View file

@ -5,12 +5,13 @@ use super::context::*;
pub(super) async fn require<'lua, 'ctx>( pub(super) async fn require<'lua, 'ctx>(
_lua: &'lua Lua, _lua: &'lua Lua,
_ctx: &'ctx RequireContext, _ctx: &'ctx RequireContext,
source: &str,
path: &str, path: &str,
) -> LuaResult<LuaMultiValue<'lua>> ) -> LuaResult<LuaMultiValue<'lua>>
where where
'lua: 'ctx, 'lua: 'ctx,
{ {
Err(LuaError::runtime(format!( Err(LuaError::runtime(format!(
"TODO: Support require for absolute paths (tried to require '{path}')" "TODO: Support require for absolute paths (tried to require '{path}' from '{source}')"
))) )))
} }