diff --git a/packages/lib/src/globals/fs.rs b/packages/lib/src/builtins/fs.rs similarity index 100% rename from packages/lib/src/globals/fs.rs rename to packages/lib/src/builtins/fs.rs diff --git a/packages/lib/src/builtins/mod.rs b/packages/lib/src/builtins/mod.rs new file mode 100644 index 0000000..4140297 --- /dev/null +++ b/packages/lib/src/builtins/mod.rs @@ -0,0 +1,10 @@ +pub(crate) mod fs; +pub(crate) mod net; +pub(crate) mod process; +pub(crate) mod serde; +pub(crate) mod stdio; +pub(crate) mod task; +pub(crate) mod top_level; + +#[cfg(feature = "roblox")] +pub(crate) mod roblox; diff --git a/packages/lib/src/globals/net.rs b/packages/lib/src/builtins/net.rs similarity index 100% rename from packages/lib/src/globals/net.rs rename to packages/lib/src/builtins/net.rs diff --git a/packages/lib/src/globals/process.rs b/packages/lib/src/builtins/process.rs similarity index 100% rename from packages/lib/src/globals/process.rs rename to packages/lib/src/builtins/process.rs diff --git a/packages/lib/src/globals/roblox.rs b/packages/lib/src/builtins/roblox.rs similarity index 100% rename from packages/lib/src/globals/roblox.rs rename to packages/lib/src/builtins/roblox.rs diff --git a/packages/lib/src/globals/serde.rs b/packages/lib/src/builtins/serde.rs similarity index 100% rename from packages/lib/src/globals/serde.rs rename to packages/lib/src/builtins/serde.rs diff --git a/packages/lib/src/globals/stdio.rs b/packages/lib/src/builtins/stdio.rs similarity index 100% rename from packages/lib/src/globals/stdio.rs rename to packages/lib/src/builtins/stdio.rs diff --git a/packages/lib/src/globals/task.rs b/packages/lib/src/builtins/task.rs similarity index 100% rename from packages/lib/src/globals/task.rs rename to packages/lib/src/builtins/task.rs diff --git a/packages/lib/src/globals/top_level.rs b/packages/lib/src/builtins/top_level.rs similarity index 82% rename from packages/lib/src/globals/top_level.rs rename to packages/lib/src/builtins/top_level.rs index 376c6a6..b015593 100644 --- a/packages/lib/src/globals/top_level.rs +++ b/packages/lib/src/builtins/top_level.rs @@ -7,14 +7,14 @@ use crate::lua::stdio::formatting::{format_label, pretty_format_multi_value}; // is really tricky to do from scratch so we will just // proxy the default print and error functions here -pub fn top_level_print(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> { +pub fn print(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> { let formatted = pretty_format_multi_value(&args)?; let print: LuaFunction = lua.named_registry_value("print")?; print.call(formatted)?; Ok(()) } -pub fn top_level_printinfo(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> { +pub fn printinfo(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> { let print: LuaFunction = lua.named_registry_value("print")?; print.call(format!( "{}\n{}", @@ -24,7 +24,7 @@ pub fn top_level_printinfo(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> { Ok(()) } -pub fn top_level_warn(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> { +pub fn warn(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> { let print: LuaFunction = lua.named_registry_value("print")?; print.call(format!( "{}\n{}", @@ -34,7 +34,7 @@ pub fn top_level_warn(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> { Ok(()) } -pub fn top_level_error(lua: &Lua, (arg, level): (LuaValue, Option)) -> LuaResult<()> { +pub fn error(lua: &Lua, (arg, level): (LuaValue, Option)) -> LuaResult<()> { let error: LuaFunction = lua.named_registry_value("error")?; let trace: LuaFunction = lua.named_registry_value("dbg.trace")?; error.call(( diff --git a/packages/lib/src/globals/mod.rs b/packages/lib/src/importer/mod.rs similarity index 58% rename from packages/lib/src/globals/mod.rs rename to packages/lib/src/importer/mod.rs index 3172956..e1fb93b 100644 --- a/packages/lib/src/globals/mod.rs +++ b/packages/lib/src/importer/mod.rs @@ -1,29 +1,23 @@ use mlua::prelude::*; -mod fs; -mod net; -mod process; mod require; -#[cfg(feature = "roblox")] -mod roblox; -mod serde; -mod stdio; -mod task; -mod top_level; +mod require_waker; + +use crate::builtins::{self, top_level}; const BUILTINS_AS_GLOBALS: &[&str] = &["fs", "net", "process", "stdio", "task"]; pub fn create(lua: &'static Lua, args: Vec) -> LuaResult<()> { // Create all builtins let builtins = vec![ - ("fs", fs::create(lua)?), - ("net", net::create(lua)?), - ("process", process::create(lua, args)?), + ("fs", builtins::fs::create(lua)?), + ("net", builtins::net::create(lua)?), + ("process", builtins::process::create(lua, args)?), + ("serde", builtins::serde::create(lua)?), + ("stdio", builtins::stdio::create(lua)?), + ("task", builtins::task::create(lua)?), #[cfg(feature = "roblox")] - ("roblox", roblox::create(lua)?), - ("serde", self::serde::create(lua)?), - ("stdio", stdio::create(lua)?), - ("task", task::create(lua)?), + ("roblox", builtins::roblox::create(lua)?), ]; // TODO: Remove this when we have proper LSP support for custom @@ -40,13 +34,10 @@ pub fn create(lua: &'static Lua, args: Vec) -> LuaResult<()> { // Create all top-level globals let globals = vec![ ("require", require_fn), - ("print", lua.create_function(top_level::top_level_print)?), - ("warn", lua.create_function(top_level::top_level_warn)?), - ("error", lua.create_function(top_level::top_level_error)?), - ( - "printinfo", - lua.create_function(top_level::top_level_printinfo)?, - ), + ("print", lua.create_function(top_level::print)?), + ("warn", lua.create_function(top_level::warn)?), + ("error", lua.create_function(top_level::error)?), + ("printinfo", lua.create_function(top_level::printinfo)?), ]; // Set top-level globals and seal them diff --git a/packages/lib/src/globals/require.rs b/packages/lib/src/importer/require.rs similarity index 83% rename from packages/lib/src/globals/require.rs rename to packages/lib/src/importer/require.rs index 6b6e738..9b21b66 100644 --- a/packages/lib/src/globals/require.rs +++ b/packages/lib/src/importer/require.rs @@ -9,12 +9,15 @@ use std::{ use dunce::canonicalize; use mlua::prelude::*; use tokio::fs; +use tokio::sync::Mutex as AsyncMutex; use crate::lua::{ table::TableBuilder, task::{TaskScheduler, TaskSchedulerScheduleExt}, }; +use super::require_waker::{RequireWakerFuture, RequireWakerState}; + const REQUIRE_IMPL_LUA: &str = r#" local source = info(1, "s") if source == '[string "require"]' then @@ -24,21 +27,24 @@ load(context, source, ...) return yield() "#; +type RequireWakersVec<'lua> = Vec>>>; + #[derive(Debug, Clone, Default)] struct RequireContext<'lua> { // NOTE: We need to use arc here so that mlua clones // the reference and not the entire inner value(s) builtins: Arc>>, cached: Arc>>>>, + wakers: Arc>>>, locks: Arc>>, pwd: String, } impl<'lua> RequireContext<'lua> { - pub fn new(lua: &'static Lua, builtins_vec: Vec<(K, V)>) -> LuaResult + pub fn new(lua: &'lua Lua, builtins_vec: Vec<(K, V)>) -> LuaResult where K: Into, - V: ToLua<'static>, + V: ToLua<'lua>, { let mut pwd = current_dir() .expect("Failed to access current working directory") @@ -79,10 +85,29 @@ impl<'lua> RequireContext<'lua> { } } - pub fn set_cached(&self, absolute_path: String, result: &LuaResult>) { + pub fn set_cached(&self, absolute_path: &str, result: &LuaResult>) { self.cached .borrow_mut() - .insert(absolute_path, result.clone()); + .insert(absolute_path.to_string(), result.clone()); + if let Some(wakers) = self.wakers.borrow_mut().remove(absolute_path) { + for waker in wakers { + waker + .try_lock() + .expect("Failed to lock waker") + .finalize(result.clone()); + } + } + } + + pub fn wait_for_cache(self, absolute_path: &str) -> RequireWakerFuture<'lua> { + let state = RequireWakerState::new(); + let fut = RequireWakerFuture::new(&state); + self.wakers + .borrow_mut() + .entry(absolute_path.to_string()) + .or_insert_with(Vec::new) + .push(Arc::clone(&state)); + fut } pub fn get_paths( @@ -124,7 +149,7 @@ impl<'lua> LuaUserData for RequireContext<'lua> {} fn load_builtin<'lua>( _lua: &'lua Lua, - context: &RequireContext<'lua>, + context: RequireContext<'lua>, module_name: String, _has_acquired_lock: bool, ) -> LuaResult> { @@ -139,7 +164,7 @@ fn load_builtin<'lua>( async fn load_file<'lua>( lua: &'lua Lua, - context: &RequireContext<'lua>, + context: RequireContext<'lua>, absolute_path: String, relative_path: String, has_acquired_lock: bool, @@ -149,9 +174,7 @@ async fn load_file<'lua>( Some(cached) => cached, None => { if !has_acquired_lock { - return Err(LuaError::RuntimeError( - "Failed to get require lock".to_string(), - )); + return context.wait_for_cache(&absolute_path).await; } // Try to read the wanted file, note that we use bytes instead of reading // to a string since lua scripts are not necessarily valid utf-8 strings @@ -173,9 +196,10 @@ async fn load_file<'lua>( let task = sched.schedule_blocking(loaded_thread, LuaMultiValue::new())?; sched.wait_for_task_completion(task) }; - // Wait for the thread to finish running, cache + return our result + // Wait for the thread to finish running, cache + return our result, + // notify any other threads that are also waiting on this to finish let rets = task_fut.await; - context.set_cached(absolute_path, &rets); + context.set_cached(&absolute_path, &rets); rets } } @@ -190,7 +214,12 @@ async fn load<'lua>( ) -> LuaResult> { let result = if absolute_path == relative_path && absolute_path.starts_with('@') { if let Some(module_name) = absolute_path.strip_prefix("@lune/") { - load_builtin(lua, &context, module_name.to_string(), has_acquired_lock) + load_builtin( + lua, + context.clone(), + module_name.to_string(), + has_acquired_lock, + ) } else { // FUTURE: '@' can be used a special prefix for users to set their own // paths relative to a project file, similar to typescript paths config @@ -202,7 +231,7 @@ async fn load<'lua>( } else { load_file( lua, - &context, + context.clone(), absolute_path.to_string(), relative_path, has_acquired_lock, diff --git a/packages/lib/src/importer/require_waker.rs b/packages/lib/src/importer/require_waker.rs new file mode 100644 index 0000000..4a43b87 --- /dev/null +++ b/packages/lib/src/importer/require_waker.rs @@ -0,0 +1,66 @@ +use std::{ + future::Future, + pin::Pin, + sync::Arc, + task::{Context, Poll, Waker}, +}; + +use tokio::sync::Mutex as AsyncMutex; + +use mlua::prelude::*; + +#[derive(Debug, Clone)] +pub(super) struct RequireWakerState<'lua> { + rets: Option>>, + waker: Option, +} + +impl<'lua> RequireWakerState<'lua> { + pub fn new() -> Arc> { + Arc::new(AsyncMutex::new(RequireWakerState { + rets: None, + waker: None, + })) + } + + pub fn finalize(&mut self, rets: LuaResult>) { + self.rets = Some(rets); + if let Some(waker) = self.waker.take() { + waker.wake(); + } + } +} + +#[derive(Debug)] +pub(super) struct RequireWakerFuture<'lua> { + state: Arc>>, +} + +impl<'lua> RequireWakerFuture<'lua> { + pub fn new(state: &Arc>>) -> Self { + Self { + state: Arc::clone(state), + } + } +} + +impl<'lua> Clone for RequireWakerFuture<'lua> { + fn clone(&self) -> Self { + Self { + state: Arc::clone(&self.state), + } + } +} + +impl<'lua> Future for RequireWakerFuture<'lua> { + type Output = LuaResult>; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut shared_state = self.state.try_lock().unwrap(); + if let Some(rets) = shared_state.rets.clone() { + Poll::Ready(rets) + } else { + shared_state.waker = Some(cx.waker().clone()); + Poll::Pending + } + } +} diff --git a/packages/lib/src/lib.rs b/packages/lib/src/lib.rs index 519ce04..ec579cc 100644 --- a/packages/lib/src/lib.rs +++ b/packages/lib/src/lib.rs @@ -4,7 +4,8 @@ use lua::task::{TaskScheduler, TaskSchedulerResumeExt, TaskSchedulerScheduleExt} use mlua::prelude::*; use tokio::task::LocalSet; -pub(crate) mod globals; +pub(crate) mod builtins; +pub(crate) mod importer; pub(crate) mod lua; mod error; @@ -71,7 +72,7 @@ impl Lune { // NOTE: Some globals require the task scheduler to exist on startup let sched = TaskScheduler::new(lua)?.into_static(); lua.set_app_data(sched); - globals::create(lua, self.args.clone())?; + importer::create(lua, self.args.clone())?; // Create the main thread and schedule it let main_chunk = lua .load(script_contents.as_ref()) diff --git a/packages/lib/src/lua/task/scheduler.rs b/packages/lib/src/lua/task/scheduler.rs index 834a379..922a4dd 100644 --- a/packages/lib/src/lua/task/scheduler.rs +++ b/packages/lib/src/lua/task/scheduler.rs @@ -52,7 +52,7 @@ pub struct TaskScheduler<'fut> { pub(super) tasks_current: Cell>, pub(super) tasks_queue_blocking: RefCell>, pub(super) tasks_waiter_states: - RefCell>>>>, + RefCell>>>>>, pub(super) tasks_current_lua_error: Arc>>, // Future tasks & objects for waking pub(super) futures: AsyncMutex>>, @@ -452,9 +452,13 @@ impl<'fut> TaskScheduler<'fut> { panic!("Task does not exist in scheduler") } let state = TaskWaiterState::new(); - self.tasks_waiter_states - .borrow_mut() - .insert(reference, Arc::clone(&state)); + { + let mut all_states = self.tasks_waiter_states.borrow_mut(); + all_states + .entry(reference) + .or_insert_with(Vec::new) + .push(Arc::clone(&state)); + } TaskWaiterFuture::new(&state).await } @@ -467,8 +471,10 @@ impl<'fut> TaskScheduler<'fut> { reference: TaskReference, result: LuaResult>, ) { - if let Some(waiter_state) = self.tasks_waiter_states.borrow_mut().remove(&reference) { - waiter_state.try_lock().unwrap().finalize(result); + if let Some(waiter_states) = self.tasks_waiter_states.borrow_mut().remove(&reference) { + for waiter_state in waiter_states { + waiter_state.try_lock().unwrap().finalize(result.clone()); + } } } } diff --git a/packages/lib/src/tests.rs b/packages/lib/src/tests.rs index 1a5da82..8180894 100644 --- a/packages/lib/src/tests.rs +++ b/packages/lib/src/tests.rs @@ -60,15 +60,15 @@ create_tests! { process_env: "process/env", process_exit: "process/exit", process_spawn: "process/spawn", - require_async: "globals/require/tests/async", - require_async_concurrent: "globals/require/tests/async_concurrent", - require_async_sequential: "globals/require/tests/async_sequential", - require_builtins: "globals/require/tests/builtins", - require_children: "globals/require/tests/children", - require_invalid: "globals/require/tests/invalid", - require_nested: "globals/require/tests/nested", - require_parents: "globals/require/tests/parents", - require_siblings: "globals/require/tests/siblings", + require_async: "require/tests/async", + require_async_concurrent: "require/tests/async_concurrent", + require_async_sequential: "require/tests/async_sequential", + require_builtins: "require/tests/builtins", + require_children: "require/tests/children", + require_invalid: "require/tests/invalid", + require_nested: "require/tests/nested", + require_parents: "require/tests/parents", + require_siblings: "require/tests/siblings", // TODO: Uncomment this test, it is commented out right // now to let CI pass so that we can make a new release // global_coroutine: "globals/coroutine", diff --git a/tests/globals/require/modules/module.luau b/tests/require/modules/module.luau similarity index 100% rename from tests/globals/require/modules/module.luau rename to tests/require/modules/module.luau diff --git a/tests/globals/require/tests/async.luau b/tests/require/tests/async.luau similarity index 100% rename from tests/globals/require/tests/async.luau rename to tests/require/tests/async.luau diff --git a/tests/globals/require/tests/async_concurrent.luau b/tests/require/tests/async_concurrent.luau similarity index 100% rename from tests/globals/require/tests/async_concurrent.luau rename to tests/require/tests/async_concurrent.luau diff --git a/tests/globals/require/tests/async_sequential.luau b/tests/require/tests/async_sequential.luau similarity index 100% rename from tests/globals/require/tests/async_sequential.luau rename to tests/require/tests/async_sequential.luau diff --git a/tests/globals/require/tests/builtins.luau b/tests/require/tests/builtins.luau similarity index 100% rename from tests/globals/require/tests/builtins.luau rename to tests/require/tests/builtins.luau diff --git a/tests/globals/require/tests/children.luau b/tests/require/tests/children.luau similarity index 100% rename from tests/globals/require/tests/children.luau rename to tests/require/tests/children.luau diff --git a/tests/globals/require/tests/invalid.luau b/tests/require/tests/invalid.luau similarity index 100% rename from tests/globals/require/tests/invalid.luau rename to tests/require/tests/invalid.luau diff --git a/tests/globals/require/tests/module.luau b/tests/require/tests/module.luau similarity index 100% rename from tests/globals/require/tests/module.luau rename to tests/require/tests/module.luau diff --git a/tests/globals/require/tests/modules/async.luau b/tests/require/tests/modules/async.luau similarity index 100% rename from tests/globals/require/tests/modules/async.luau rename to tests/require/tests/modules/async.luau diff --git a/tests/globals/require/tests/modules/module.luau b/tests/require/tests/modules/module.luau similarity index 100% rename from tests/globals/require/tests/modules/module.luau rename to tests/require/tests/modules/module.luau diff --git a/tests/globals/require/tests/modules/modules/module.luau b/tests/require/tests/modules/modules/module.luau similarity index 100% rename from tests/globals/require/tests/modules/modules/module.luau rename to tests/require/tests/modules/modules/module.luau diff --git a/tests/globals/require/tests/modules/nested.luau b/tests/require/tests/modules/nested.luau similarity index 100% rename from tests/globals/require/tests/modules/nested.luau rename to tests/require/tests/modules/nested.luau diff --git a/tests/globals/require/tests/nested.luau b/tests/require/tests/nested.luau similarity index 100% rename from tests/globals/require/tests/nested.luau rename to tests/require/tests/nested.luau diff --git a/tests/globals/require/tests/parents.luau b/tests/require/tests/parents.luau similarity index 100% rename from tests/globals/require/tests/parents.luau rename to tests/require/tests/parents.luau diff --git a/tests/globals/require/tests/siblings.luau b/tests/require/tests/siblings.luau similarity index 100% rename from tests/globals/require/tests/siblings.luau rename to tests/require/tests/siblings.luau