diff --git a/packages/lib/src/globals/require.rs b/packages/lib/src/globals/require.rs index edacfc8..2040600 100644 --- a/packages/lib/src/globals/require.rs +++ b/packages/lib/src/globals/require.rs @@ -19,18 +19,18 @@ local source = info(1, "s") if source == '[string "require"]' then source = info(2, "s") end -local absolute, relative = importer:paths(source, ...) -return importer:load(thread(), absolute, relative) +local absolute, relative = paths(context, source, ...) +return load(context, absolute, relative) "#; #[derive(Debug, Clone, Default)] -struct Importer<'lua> { +struct RequireContext<'lua> { builtins: HashMap>, cached: RefCell>>>, pwd: String, } -impl<'lua> Importer<'lua> { +impl<'lua> RequireContext<'lua> { pub fn new() -> Self { let mut pwd = current_dir() .expect("Failed to access current working directory") @@ -44,134 +44,136 @@ impl<'lua> Importer<'lua> { ..Default::default() } } +} - fn paths(&self, require_source: String, require_path: String) -> LuaResult<(String, String)> { - if require_path.starts_with('@') { - return Ok((require_path.clone(), require_path)); - } - let path_relative_to_pwd = PathBuf::from( - &require_source - .trim_start_matches("[string \"") - .trim_end_matches("\"]"), - ) - .parent() - .unwrap() - .join(&require_path); - // Try to normalize and resolve relative path segments such as './' and '../' - let file_path = match ( - canonicalize(path_relative_to_pwd.with_extension("luau")), - canonicalize(path_relative_to_pwd.with_extension("lua")), - ) { - (Ok(luau), _) => luau, - (_, Ok(lua)) => lua, - _ => { - return Err(LuaError::RuntimeError(format!( - "File does not exist at path '{require_path}'" - ))) - } - }; - let absolute = file_path.to_string_lossy().to_string(); - let relative = absolute.trim_start_matches(&self.pwd).to_string(); - Ok((absolute, relative)) +impl<'lua> LuaUserData for RequireContext<'lua> {} + +fn paths( + context: RequireContext, + require_source: String, + require_path: String, +) -> LuaResult<(String, String)> { + if require_path.starts_with('@') { + return Ok((require_path.clone(), require_path)); } - - fn load_builtin(&self, module_name: &str) -> LuaResult { - match self.builtins.get(module_name) { - Some(module) => Ok(module.clone()), - None => Err(LuaError::RuntimeError(format!( - "No builtin module exists with the name '{}'", - module_name - ))), + let path_relative_to_pwd = PathBuf::from( + &require_source + .trim_start_matches("[string \"") + .trim_end_matches("\"]"), + ) + .parent() + .unwrap() + .join(&require_path); + // Try to normalize and resolve relative path segments such as './' and '../' + let file_path = match ( + canonicalize(path_relative_to_pwd.with_extension("luau")), + canonicalize(path_relative_to_pwd.with_extension("lua")), + ) { + (Ok(luau), _) => luau, + (_, Ok(lua)) => lua, + _ => { + return Err(LuaError::RuntimeError(format!( + "File does not exist at path '{require_path}'" + ))) } - } + }; + let absolute = file_path.to_string_lossy().to_string(); + let relative = absolute.trim_start_matches(&context.pwd).to_string(); + Ok((absolute, relative)) +} - async fn load_file( - &self, - lua: &'lua Lua, - absolute_path: String, - relative_path: String, - ) -> LuaResult { - let cached = { self.cached.borrow().get(&absolute_path).cloned() }; - match cached { - Some(cached) => cached, - None => { - // Try to read the wanted file, note that we use bytes instead of reading - // to a string since lua scripts are not necessarily valid utf-8 strings - let contents = fs::read(&absolute_path).await.map_err(LuaError::external)?; - // Use a name without extensions for loading the chunk, some - // other code assumes the require path is without extensions - let path_relative_no_extension = relative_path - .trim_end_matches(".lua") - .trim_end_matches(".luau"); - // Load the file into a thread - let loaded_func = lua - .load(&contents) - .set_name(path_relative_no_extension)? - .into_function()?; - let loaded_thread = lua.create_thread(loaded_func)?; - // Run the thread and provide a channel that will - // then get its result received when it finishes - let (tx, rx) = oneshot::channel(); - { - let sched = lua.app_data_ref::<&TaskScheduler>().unwrap(); - let task = sched.schedule_blocking(loaded_thread, LuaMultiValue::new())?; - sched.set_task_result_sender(task, tx); - } - // Wait for the thread to finish running, cache + return our result - let rets = rx.await.expect("Sender was dropped during require"); - self.cached.borrow_mut().insert(absolute_path, rets.clone()); - rets - } - } +fn load_builtin<'lua>( + _lua: &'lua Lua, + context: RequireContext<'lua>, + module_name: String, +) -> LuaResult> { + match context.builtins.get(&module_name) { + Some(module) => Ok(module.clone()), + None => Err(LuaError::RuntimeError(format!( + "No builtin module exists with the name '{}'", + module_name + ))), } +} - async fn load( - &self, - lua: &'lua Lua, - absolute_path: String, - relative_path: String, - ) -> LuaResult { - if absolute_path == relative_path && absolute_path.starts_with('@') { - if let Some(module_name) = absolute_path.strip_prefix("@lune/") { - self.load_builtin(module_name) - } else { - Err(LuaError::RuntimeError( - "Require paths prefixed by '@' are not yet supported".to_string(), - )) +async fn load_file<'lua>( + lua: &'lua Lua, + context: RequireContext<'lua>, + absolute_path: String, + relative_path: String, +) -> LuaResult> { + let cached = { context.cached.borrow().get(&absolute_path).cloned() }; + match cached { + Some(cached) => cached, + None => { + // Try to read the wanted file, note that we use bytes instead of reading + // to a string since lua scripts are not necessarily valid utf-8 strings + let contents = fs::read(&absolute_path).await.map_err(LuaError::external)?; + // Use a name without extensions for loading the chunk, some + // other code assumes the require path is without extensions + let path_relative_no_extension = relative_path + .trim_end_matches(".lua") + .trim_end_matches(".luau"); + // Load the file into a thread + let loaded_func = lua + .load(&contents) + .set_name(path_relative_no_extension)? + .into_function()?; + let loaded_thread = lua.create_thread(loaded_func)?; + // Run the thread and provide a channel that will + // then get its result received when it finishes + let (tx, rx) = oneshot::channel(); + { + let sched = lua.app_data_ref::<&TaskScheduler>().unwrap(); + let task = sched.schedule_blocking(loaded_thread, LuaMultiValue::new())?; + sched.set_task_result_sender(task, tx); } - } else { - self.load_file(lua, absolute_path, relative_path).await + // Wait for the thread to finish running, cache + return our result + // FIXME: This waits indefinitely for nested requires for some reason + let rets = rx.await.expect("Sender was dropped during require"); + context + .cached + .borrow_mut() + .insert(absolute_path, rets.clone()); + rets } } } -impl<'i> LuaUserData for Importer<'i> { - fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) { - methods.add_method( - "paths", - |_, this, (require_source, require_path): (String, String)| { - this.paths(require_source, require_path) - }, - ); - methods.add_method( - "load", - |lua, this, (thread, absolute_path, relative_path): (LuaThread, String, String)| { - // TODO: Make this work - // this.load(lua, absolute_path, relative_path) - Ok(()) - }, - ); +async fn load<'lua>( + lua: &'lua Lua, + context: RequireContext<'lua>, + absolute_path: String, + relative_path: String, +) -> LuaResult> { + if absolute_path == relative_path && absolute_path.starts_with('@') { + if let Some(module_name) = absolute_path.strip_prefix("@lune/") { + load_builtin(lua, context, module_name.to_string()) + } else { + Err(LuaError::RuntimeError( + "Require paths prefixed by '@' are not yet supported".to_string(), + )) + } + } else { + load_file(lua, context, absolute_path, relative_path).await } } pub fn create(lua: &'static Lua) -> LuaResult { - let require_importer = Importer::new(); - let require_thread: LuaFunction = lua.named_registry_value("co.thread")?; + let require_context = RequireContext::new(); + let require_print: LuaFunction = lua.named_registry_value("print")?; let require_info: LuaFunction = lua.named_registry_value("dbg.info")?; + let require_env = TableBuilder::new(lua)? - .with_value("importer", require_importer)? - .with_value("thread", require_thread)? + .with_value("context", require_context)? + .with_value("print", require_print)? .with_value("info", require_info)? + .with_function("paths", |_, (context, require_source, require_path)| { + paths(context, require_source, require_path) + })? + .with_async_function("load", |lua, (context, require_source, require_path)| { + load(lua, context, require_source, require_path) + })? .build_readonly()?; let require_fn_lua = lua