diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ccd11e..91d0352 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added -- Builtin modules such as `fs`, `net` and others can now be imported using `require("@lune/fs")`, `require("@lune/net")` .. +- `require` has been reimplemented and overhauled in several ways: + + - Builtin modules such as `fs`, `net` and others can now be imported using `require("@lune/fs")`, `require("@lune/net")` ...
+ This is the first step towards moving away from adding each library as a global, and allowing Lune to have more built-in libraries. + + - Requiring a script is now completely asynchronous and will not block lua threads other than the caller. + - Requiring a script will no longer error when using async APIs in the main body of the required script. + + Behavior otherwise stays the same, and requires are still relative to file unless the special `@` prefix is used. ### Removed diff --git a/packages/lib/src/globals/process.rs b/packages/lib/src/globals/process.rs index 8f268d3..360d696 100644 --- a/packages/lib/src/globals/process.rs +++ b/packages/lib/src/globals/process.rs @@ -136,7 +136,7 @@ fn process_env_iter<'lua>( lua: &'lua Lua, (_, _): (LuaValue<'lua>, ()), ) -> LuaResult> { - let mut vars = env::vars_os(); + let mut vars = env::vars_os().collect::>().into_iter(); lua.create_function_mut(move |lua, _: ()| match vars.next() { Some((key, value)) => { let raw_key = RawOsString::new(key); diff --git a/packages/lib/src/globals/require.rs b/packages/lib/src/globals/require.rs index 2040600..802d52b 100644 --- a/packages/lib/src/globals/require.rs +++ b/packages/lib/src/globals/require.rs @@ -1,13 +1,14 @@ use std::{ cell::RefCell, - collections::HashMap, + collections::{HashMap, HashSet}, env::current_dir, path::{self, PathBuf}, + sync::Arc, }; use dunce::canonicalize; use mlua::prelude::*; -use tokio::{fs, sync::oneshot}; +use tokio::fs; use crate::lua::{ table::TableBuilder, @@ -19,14 +20,17 @@ local source = info(1, "s") if source == '[string "require"]' then source = info(2, "s") end -local absolute, relative = paths(context, source, ...) -return load(context, absolute, relative) +load(context, source, ...) +return yield() "#; #[derive(Debug, Clone, Default)] struct RequireContext<'lua> { - builtins: HashMap>, - cached: RefCell>>>, + // NOTE: We need to use arc here so that mlua clones + // the reference and not the entire inner value(s) + builtins: Arc>>, + cached: Arc>>>>, + locks: Arc>>, pwd: String, } @@ -44,48 +48,76 @@ impl<'lua> RequireContext<'lua> { ..Default::default() } } + + pub fn is_locked(&self, absolute_path: &str) -> bool { + self.locks.borrow().contains(absolute_path) + } + + pub fn set_locked(&self, absolute_path: &str) -> bool { + self.locks.borrow_mut().insert(absolute_path.to_string()) + } + + pub fn set_unlocked(&self, absolute_path: &str) -> bool { + self.locks.borrow_mut().remove(absolute_path) + } + + pub fn try_acquire_lock_sync(&self, absolute_path: &str) -> bool { + if self.is_locked(absolute_path) { + false + } else { + self.set_locked(absolute_path); + true + } + } + + pub fn set_cached(&self, absolute_path: String, result: &LuaResult>) { + self.cached + .borrow_mut() + .insert(absolute_path, result.clone()); + } + + pub fn get_paths( + &self, + require_source: String, + require_path: String, + ) -> LuaResult<(String, String)> { + if require_path.starts_with('@') { + return Ok((require_path.clone(), require_path)); + } + let path_relative_to_pwd = PathBuf::from( + &require_source + .trim_start_matches("[string \"") + .trim_end_matches("\"]"), + ) + .parent() + .unwrap() + .join(&require_path); + // Try to normalize and resolve relative path segments such as './' and '../' + let file_path = match ( + canonicalize(path_relative_to_pwd.with_extension("luau")), + canonicalize(path_relative_to_pwd.with_extension("lua")), + ) { + (Ok(luau), _) => luau, + (_, Ok(lua)) => lua, + _ => { + return Err(LuaError::RuntimeError(format!( + "File does not exist at path '{require_path}'" + ))) + } + }; + let absolute = file_path.to_string_lossy().to_string(); + let relative = absolute.trim_start_matches(&self.pwd).to_string(); + Ok((absolute, relative)) + } } impl<'lua> LuaUserData for RequireContext<'lua> {} -fn paths( - context: RequireContext, - require_source: String, - require_path: String, -) -> LuaResult<(String, String)> { - if require_path.starts_with('@') { - return Ok((require_path.clone(), require_path)); - } - let path_relative_to_pwd = PathBuf::from( - &require_source - .trim_start_matches("[string \"") - .trim_end_matches("\"]"), - ) - .parent() - .unwrap() - .join(&require_path); - // Try to normalize and resolve relative path segments such as './' and '../' - let file_path = match ( - canonicalize(path_relative_to_pwd.with_extension("luau")), - canonicalize(path_relative_to_pwd.with_extension("lua")), - ) { - (Ok(luau), _) => luau, - (_, Ok(lua)) => lua, - _ => { - return Err(LuaError::RuntimeError(format!( - "File does not exist at path '{require_path}'" - ))) - } - }; - let absolute = file_path.to_string_lossy().to_string(); - let relative = absolute.trim_start_matches(&context.pwd).to_string(); - Ok((absolute, relative)) -} - fn load_builtin<'lua>( _lua: &'lua Lua, - context: RequireContext<'lua>, + context: &RequireContext<'lua>, module_name: String, + _has_acquired_lock: bool, ) -> LuaResult> { match context.builtins.get(&module_name) { Some(module) => Ok(module.clone()), @@ -98,14 +130,20 @@ fn load_builtin<'lua>( async fn load_file<'lua>( lua: &'lua Lua, - context: RequireContext<'lua>, + context: &RequireContext<'lua>, absolute_path: String, relative_path: String, + has_acquired_lock: bool, ) -> LuaResult> { let cached = { context.cached.borrow().get(&absolute_path).cloned() }; match cached { Some(cached) => cached, None => { + if !has_acquired_lock { + return Err(LuaError::RuntimeError( + "Failed to get require lock".to_string(), + )); + } // Try to read the wanted file, note that we use bytes instead of reading // to a string since lua scripts are not necessarily valid utf-8 strings let contents = fs::read(&absolute_path).await.map_err(LuaError::external)?; @@ -120,21 +158,15 @@ async fn load_file<'lua>( .set_name(path_relative_no_extension)? .into_function()?; let loaded_thread = lua.create_thread(loaded_func)?; - // Run the thread and provide a channel that will - // then get its result received when it finishes - let (tx, rx) = oneshot::channel(); - { + // Run the thread and wait for completion using the native task scheduler waker + let task_fut = { let sched = lua.app_data_ref::<&TaskScheduler>().unwrap(); let task = sched.schedule_blocking(loaded_thread, LuaMultiValue::new())?; - sched.set_task_result_sender(task, tx); - } + sched.wait_for_task_completion(task) + }; // Wait for the thread to finish running, cache + return our result - // FIXME: This waits indefinitely for nested requires for some reason - let rets = rx.await.expect("Sender was dropped during require"); - context - .cached - .borrow_mut() - .insert(absolute_path, rets.clone()); + let rets = task_fut.await; + context.set_cached(absolute_path, &rets); rets } } @@ -145,35 +177,66 @@ async fn load<'lua>( context: RequireContext<'lua>, absolute_path: String, relative_path: String, + has_acquired_lock: bool, ) -> LuaResult> { - if absolute_path == relative_path && absolute_path.starts_with('@') { + let result = if absolute_path == relative_path && absolute_path.starts_with('@') { if let Some(module_name) = absolute_path.strip_prefix("@lune/") { - load_builtin(lua, context, module_name.to_string()) + load_builtin(lua, &context, module_name.to_string(), has_acquired_lock) } else { + // FUTURE: '@' can be used a special prefix for users to set their own + // paths relative to a project file, similar to typescript paths config + // https://www.typescriptlang.org/tsconfig#paths Err(LuaError::RuntimeError( "Require paths prefixed by '@' are not yet supported".to_string(), )) } } else { - load_file(lua, context, absolute_path, relative_path).await + load_file( + lua, + &context, + absolute_path.to_string(), + relative_path, + has_acquired_lock, + ) + .await + }; + if has_acquired_lock { + context.set_unlocked(&absolute_path); } + result } pub fn create(lua: &'static Lua) -> LuaResult { let require_context = RequireContext::new(); - let require_print: LuaFunction = lua.named_registry_value("print")?; + let require_yield: LuaFunction = lua.named_registry_value("co.yield")?; let require_info: LuaFunction = lua.named_registry_value("dbg.info")?; + let require_print: LuaFunction = lua.named_registry_value("print")?; let require_env = TableBuilder::new(lua)? .with_value("context", require_context)? - .with_value("print", require_print)? + .with_value("yield", require_yield)? .with_value("info", require_info)? - .with_function("paths", |_, (context, require_source, require_path)| { - paths(context, require_source, require_path) - })? - .with_async_function("load", |lua, (context, require_source, require_path)| { - load(lua, context, require_source, require_path) - })? + .with_value("print", require_print)? + .with_function( + "load", + |lua, (context, require_source, require_path): (RequireContext, String, String)| { + let (absolute_path, relative_path) = + context.get_paths(require_source, require_path)?; + // NOTE: We can not acquire the lock in the async part of the require + // load process since several requires may have happened for the + // same path before the async load task even gets a chance to run + let has_lock = context.try_acquire_lock_sync(&absolute_path); + let fut = load(lua, context, absolute_path, relative_path, has_lock); + let sched = lua + .app_data_ref::<&TaskScheduler>() + .expect("Missing task scheduler as a lua app data"); + sched.queue_async_task_inherited(lua.current_thread(), None, async { + let rets = fut.await?; + let mult = rets.to_lua_multi(lua)?; + Ok(Some(mult)) + }) + }, + )? .build_readonly()?; let require_fn_lua = lua diff --git a/packages/lib/src/globals/task.rs b/packages/lib/src/globals/task.rs index a86fe0e..fe5c9ee 100644 --- a/packages/lib/src/globals/task.rs +++ b/packages/lib/src/globals/task.rs @@ -158,13 +158,13 @@ fn coroutine_resume<'lua>( let result = match value { LuaThreadOrTaskReference::Thread(t) => { let task = sched.create_task(TaskKind::Instant, t, None, true)?; - sched.resume_task(task) + sched.resume_task(task, None) } - LuaThreadOrTaskReference::TaskReference(t) => sched.resume_task(t), + LuaThreadOrTaskReference::TaskReference(t) => sched.resume_task(t, None), }; sched.force_set_current_task(Some(current)); match result { - Ok(rets) => Ok((true, rets)), + Ok(rets) => Ok((true, rets.1)), Err(e) => Ok((false, e.to_lua_multi(lua)?)), } } @@ -187,8 +187,11 @@ fn coroutine_wrap<'lua>(lua: &'lua Lua, func: LuaFunction) -> LuaResult() .unwrap() - .resume_task_override(task, Ok(args)); + .resume_task(task, Some(Ok(args))); sched.force_set_current_task(Some(current)); - result + match result { + Ok(rets) => Ok(rets.1), + Err(e) => Err(e), + } }) } diff --git a/packages/lib/src/lua/task/ext/resume_ext.rs b/packages/lib/src/lua/task/ext/resume_ext.rs index 74c550f..9fecaae 100644 --- a/packages/lib/src/lua/task/ext/resume_ext.rs +++ b/packages/lib/src/lua/task/ext/resume_ext.rs @@ -76,10 +76,13 @@ impl TaskSchedulerResumeExt for TaskScheduler<'_> { Resumes the next queued Lua task, if one exists, blocking the current thread until it either yields or finishes. */ -fn resume_next_blocking_task( - scheduler: &TaskScheduler<'_>, - override_args: Option>, -) -> TaskSchedulerState { +fn resume_next_blocking_task<'sched, 'args>( + scheduler: &TaskScheduler<'sched>, + override_args: Option>>, +) -> TaskSchedulerState +where + 'args: 'sched, +{ match { let mut queue_guard = scheduler.tasks_queue_blocking.borrow_mut(); let task = queue_guard.pop_front(); @@ -87,15 +90,16 @@ fn resume_next_blocking_task( task } { None => TaskSchedulerState::new(scheduler), - Some(task) => match override_args { - Some(args) => match scheduler.resume_task_override(task, args) { - Ok(_) => TaskSchedulerState::new(scheduler), - Err(task_err) => TaskSchedulerState::err(scheduler, task_err), - }, - None => match scheduler.resume_task(task) { - Ok(_) => TaskSchedulerState::new(scheduler), - Err(task_err) => TaskSchedulerState::err(scheduler, task_err), - }, + Some(task) => match scheduler.resume_task(task, override_args) { + Err(task_err) => { + scheduler.wake_completed_task(task, Err(task_err.clone())); + TaskSchedulerState::err(scheduler, task_err) + } + Ok(rets) if rets.0 == LuaThreadStatus::Unresumable => { + scheduler.wake_completed_task(task, Ok(rets.1)); + TaskSchedulerState::new(scheduler) + } + Ok(_) => TaskSchedulerState::new(scheduler), }, } } @@ -158,9 +162,9 @@ async fn receive_next_message(scheduler: &TaskScheduler<'_>) -> TaskSchedulerSta if prev == 0 { panic!( r#" - Terminated a background task without it running - this is an internal error! - Please report it at {} - "#, + Terminated a background task without it running - this is an internal error! + Please report it at {} + "#, env!("CARGO_PKG_REPOSITORY") ) } diff --git a/packages/lib/src/lua/task/mod.rs b/packages/lib/src/lua/task/mod.rs index f9d1813..6f8df44 100644 --- a/packages/lib/src/lua/task/mod.rs +++ b/packages/lib/src/lua/task/mod.rs @@ -6,6 +6,7 @@ mod scheduler_message; mod scheduler_state; mod task_kind; mod task_reference; +mod task_waiter; pub use ext::*; pub use proxy::*; diff --git a/packages/lib/src/lua/task/scheduler.rs b/packages/lib/src/lua/task/scheduler.rs index 460f62e..834a379 100644 --- a/packages/lib/src/lua/task/scheduler.rs +++ b/packages/lib/src/lua/task/scheduler.rs @@ -9,12 +9,14 @@ use std::{ use futures_util::{future::LocalBoxFuture, stream::FuturesUnordered, Future}; use mlua::prelude::*; -use tokio::sync::{mpsc, oneshot, Mutex as AsyncMutex}; +use tokio::sync::{mpsc, Mutex as AsyncMutex}; -use super::scheduler_message::TaskSchedulerMessage; +use super::{ + scheduler_message::TaskSchedulerMessage, + task_waiter::{TaskWaiterFuture, TaskWaiterState}, +}; pub use super::{task_kind::TaskKind, task_reference::TaskReference}; -type TaskResultSender = oneshot::Sender>>; type TaskFutureRets<'fut> = LuaResult>>; type TaskFuture<'fut> = LocalBoxFuture<'fut, (Option, TaskFutureRets<'fut>)>; @@ -49,8 +51,9 @@ pub struct TaskScheduler<'fut> { pub(super) tasks_count: Cell, pub(super) tasks_current: Cell>, pub(super) tasks_queue_blocking: RefCell>, - pub(super) tasks_result_senders: RefCell>, - pub(super) tasks_current_lua_error: Arc>>, + pub(super) tasks_waiter_states: + RefCell>>>>, + pub(super) tasks_current_lua_error: Arc>>, // Future tasks & objects for waking pub(super) futures: AsyncMutex>>, pub(super) futures_count: Cell, @@ -65,12 +68,14 @@ impl<'fut> TaskScheduler<'fut> { */ pub fn new(lua: &'static Lua) -> LuaResult { let (tx, rx) = mpsc::unbounded_channel(); - let tasks_current_lua_error = Arc::new(RefCell::new(None)); + let tasks_current_lua_error = Arc::new(AsyncMutex::new(None)); let tasks_current_lua_error_inner = tasks_current_lua_error.clone(); - lua.set_interrupt(move || match tasks_current_lua_error_inner.take() { - Some(err) => Err(err), - None => Ok(LuaVmState::Continue), - }); + lua.set_interrupt( + move || match tasks_current_lua_error_inner.try_lock().unwrap().take() { + Some(err) => Err(err), + None => Ok(LuaVmState::Continue), + }, + ); Ok(Self { lua, guid: Cell::new(0), @@ -79,7 +84,7 @@ impl<'fut> TaskScheduler<'fut> { tasks_count: Cell::new(0), tasks_current: Cell::new(None), tasks_queue_blocking: RefCell::new(VecDeque::new()), - tasks_result_senders: RefCell::new(HashMap::new()), + tasks_waiter_states: RefCell::new(HashMap::new()), tasks_current_lua_error, futures: AsyncMutex::new(FuturesUnordered::new()), futures_tx: tx, @@ -273,8 +278,6 @@ impl<'fut> TaskScheduler<'fut> { TaskKind::Future => self.futures_count.set(self.futures_count.get() - 1), _ => self.tasks_count.set(self.tasks_count.get() - 1), } - // Remove any sender - self.tasks_result_senders.borrow_mut().remove(task_ref); // NOTE: We need to close the thread here to // make 100% sure that nothing can resume it let close: LuaFunction = self.lua.named_registry_value("co.close")?; @@ -296,14 +299,21 @@ impl<'fut> TaskScheduler<'fut> { This will be a no-op if the task no longer exists. */ - pub fn resume_task(&self, reference: TaskReference) -> LuaResult { + pub fn resume_task<'a, 'r>( + &self, + reference: TaskReference, + override_args: Option>>, + ) -> LuaResult<(LuaThreadStatus, LuaMultiValue<'r>)> + where + 'a: 'r, + { // Fetch and check if the task was removed, if it got // removed it means it was intentionally cancelled let task = { let mut tasks = self.tasks.borrow_mut(); match tasks.remove(&reference) { Some(task) => task, - None => return Ok(LuaMultiValue::new()), + None => return Ok((LuaThreadStatus::Unresumable, LuaMultiValue::new())), } }; // Decrement the corresponding task counter @@ -324,66 +334,27 @@ impl<'fut> TaskScheduler<'fut> { // We got everything we need and our references // were cleaned up properly, resume the thread self.tasks_current.set(Some(reference)); - let rets = match thread_args { - Some(args) => thread.resume(args), - None => thread.resume(()), - }; - self.tasks_current.set(None); - // If we have a result sender for this task, we should run it if the thread finished - if thread.status() != LuaThreadStatus::Resumable { - if let Some(sender) = self.tasks_result_senders.borrow_mut().remove(&reference) { - let _ = sender.send(rets.clone()); - } - } - rets - } - - /** - Resumes a task, if the task still exists in the scheduler, using the given arguments. - - A task may no longer exist in the scheduler if it has been manually - cancelled and removed by calling [`TaskScheduler::cancel_task()`]. - - This will be a no-op if the task no longer exists. - */ - pub fn resume_task_override<'a>( - &self, - reference: TaskReference, - override_args: LuaResult>, - ) -> LuaResult> { - // Fetch and check if the task was removed, if it got - // removed it means it was intentionally cancelled - let task = { - let mut tasks = self.tasks.borrow_mut(); - match tasks.remove(&reference) { - Some(task) => task, - None => return Ok(LuaMultiValue::new()), - } - }; - // Decrement the corresponding task counter - match task.kind { - TaskKind::Future => self.futures_count.set(self.futures_count.get() - 1), - _ => self.tasks_count.set(self.tasks_count.get() - 1), - } - // Fetch and remove the thread to resume + its arguments - let thread: LuaThread = self.lua.registry_value(&task.thread)?; - self.lua.remove_registry_value(task.thread)?; - self.lua.remove_registry_value(task.args)?; - // We got everything we need and our references - // were cleaned up properly, resume the thread - self.tasks_current.set(Some(reference)); let rets = match override_args { - Err(e) => { - // NOTE: Setting this error here means that when the thread - // is resumed it will error instantly, so we don't need - // to call it with proper args, empty args is fine - self.tasks_current_lua_error.replace(Some(e)); - thread.resume(()) - } - Ok(args) => thread.resume(args), + Some(override_res) => match override_res { + Ok(args) => thread.resume(args), + Err(e) => { + // NOTE: Setting this error here means that when the thread + // is resumed it will error instantly, so we don't need + // to call it with proper args, empty args is fine + self.tasks_current_lua_error.try_lock().unwrap().replace(e); + thread.resume(()) + } + }, + None => match thread_args { + Some(args) => thread.resume(args), + None => thread.resume(()), + }, }; self.tasks_current.set(None); - rets + match rets { + Ok(rets) => Ok((thread.status(), rets)), + Err(e) => Err(e), + } } /** @@ -446,9 +417,58 @@ impl<'fut> TaskScheduler<'fut> { Ok(task_ref) } - pub(crate) fn set_task_result_sender(&self, task_ref: TaskReference, sender: TaskResultSender) { - self.tasks_result_senders + /** + Queues a new future to run on the task scheduler, + inheriting the task id of the currently running task. + */ + pub(crate) fn queue_async_task_inherited( + &self, + thread: LuaThread<'_>, + thread_args: Option>, + fut: impl Future> + 'fut, + ) -> LuaResult { + let task_ref = self.create_task(TaskKind::Future, thread, thread_args, true)?; + let futs = self + .futures + .try_lock() + .expect("Tried to add future to queue during futures resumption"); + futs.push(Box::pin(async move { + let result = fut.await; + (Some(task_ref), result) + })); + Ok(task_ref) + } + + /** + Waits for a task to complete. + + Panics if the task is not currently in the scheduler. + */ + pub(crate) async fn wait_for_task_completion( + &self, + reference: TaskReference, + ) -> LuaResult { + if !self.tasks.borrow().contains_key(&reference) { + panic!("Task does not exist in scheduler") + } + let state = TaskWaiterState::new(); + self.tasks_waiter_states .borrow_mut() - .insert(task_ref, sender); + .insert(reference, Arc::clone(&state)); + TaskWaiterFuture::new(&state).await + } + + /** + Wakes a task that has been completed and may have external code + waiting on it using [`TaskScheduler::wait_for_task_completion`]. + */ + pub(super) fn wake_completed_task( + &self, + reference: TaskReference, + result: LuaResult>, + ) { + if let Some(waiter_state) = self.tasks_waiter_states.borrow_mut().remove(&reference) { + waiter_state.try_lock().unwrap().finalize(result); + } } } diff --git a/packages/lib/src/lua/task/task_reference.rs b/packages/lib/src/lua/task/task_reference.rs index 8df11c3..de1967a 100644 --- a/packages/lib/src/lua/task/task_reference.rs +++ b/packages/lib/src/lua/task/task_reference.rs @@ -1,4 +1,7 @@ -use std::fmt; +use std::{ + fmt, + hash::{Hash, Hasher}, +}; use mlua::prelude::*; @@ -6,7 +9,7 @@ use super::task_kind::TaskKind; /// A lightweight, copyable struct that represents a /// task in the scheduler and is accessible from Lua -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[derive(Debug, Clone, Copy)] pub struct TaskReference { kind: TaskKind, guid: usize, @@ -32,4 +35,17 @@ impl fmt::Display for TaskReference { } } +impl Eq for TaskReference {} +impl PartialEq for TaskReference { + fn eq(&self, other: &Self) -> bool { + self.guid == other.guid + } +} + +impl Hash for TaskReference { + fn hash(&self, state: &mut H) { + self.guid.hash(state); + } +} + impl LuaUserData for TaskReference {} diff --git a/packages/lib/src/lua/task/task_waiter.rs b/packages/lib/src/lua/task/task_waiter.rs new file mode 100644 index 0000000..3be53b2 --- /dev/null +++ b/packages/lib/src/lua/task/task_waiter.rs @@ -0,0 +1,66 @@ +use std::{ + future::Future, + pin::Pin, + sync::Arc, + task::{Context, Poll, Waker}, +}; + +use tokio::sync::Mutex as AsyncMutex; + +use mlua::prelude::*; + +#[derive(Debug, Clone)] +pub(super) struct TaskWaiterState<'fut> { + rets: Option>>, + waker: Option, +} + +impl<'fut> TaskWaiterState<'fut> { + pub fn new() -> Arc> { + Arc::new(AsyncMutex::new(TaskWaiterState { + rets: None, + waker: None, + })) + } + + pub fn finalize(&mut self, rets: LuaResult>) { + self.rets = Some(rets); + if let Some(waker) = self.waker.take() { + waker.wake(); + } + } +} + +#[derive(Debug)] +pub(super) struct TaskWaiterFuture<'fut> { + state: Arc>>, +} + +impl<'fut> TaskWaiterFuture<'fut> { + pub fn new(state: &Arc>>) -> Self { + Self { + state: Arc::clone(state), + } + } +} + +impl<'fut> Clone for TaskWaiterFuture<'fut> { + fn clone(&self) -> Self { + Self { + state: Arc::clone(&self.state), + } + } +} + +impl<'fut> Future for TaskWaiterFuture<'fut> { + type Output = LuaResult>; + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut shared_state = self.state.try_lock().unwrap(); + if let Some(rets) = shared_state.rets.clone() { + Poll::Ready(rets) + } else { + shared_state.waker = Some(cx.waker().clone()); + Poll::Pending + } + } +} diff --git a/packages/lib/src/tests.rs b/packages/lib/src/tests.rs index c30117d..f4ba14f 100644 --- a/packages/lib/src/tests.rs +++ b/packages/lib/src/tests.rs @@ -60,6 +60,9 @@ create_tests! { process_env: "process/env", process_exit: "process/exit", process_spawn: "process/spawn", + require_async: "globals/require/tests/async", + require_async_concurrent: "globals/require/tests/async_concurrent", + require_async_sequential: "globals/require/tests/async_sequential", require_children: "globals/require/tests/children", require_invalid: "globals/require/tests/invalid", require_nested: "globals/require/tests/nested", diff --git a/tests/globals/require/tests/async.luau b/tests/globals/require/tests/async.luau new file mode 100644 index 0000000..6681695 --- /dev/null +++ b/tests/globals/require/tests/async.luau @@ -0,0 +1,9 @@ +local module = require("./modules/async") + +assert(type(module) == "table", "Required module did not return a table") +assert(module.Foo == "Bar", "Required module did not contain correct values") +assert(module.Hello == "World", "Required module did not contain correct values") + +module = require("modules/async") +assert(module.Foo == "Bar", "Required module did not contain correct values") +assert(module.Hello == "World", "Required module did not contain correct values") diff --git a/tests/globals/require/tests/async_concurrent.luau b/tests/globals/require/tests/async_concurrent.luau new file mode 100644 index 0000000..e0da822 --- /dev/null +++ b/tests/globals/require/tests/async_concurrent.luau @@ -0,0 +1,22 @@ +local module1 +local module2 + +task.defer(function() + module2 = require("./modules/async") +end) + +task.spawn(function() + module1 = require("./modules/async") +end) + +task.wait(1) + +assert(type(module1) == "table", "Required module1 did not return a table") +assert(module1.Foo == "Bar", "Required module1 did not contain correct values") +assert(module1.Hello == "World", "Required module1 did not contain correct values") + +assert(type(module2) == "table", "Required module2 did not return a table") +assert(module2.Foo == "Bar", "Required module2 did not contain correct values") +assert(module2.Hello == "World", "Required module2 did not contain correct values") + +assert(module1 == module2, "Required modules should point to the same return value") diff --git a/tests/globals/require/tests/async_sequential.luau b/tests/globals/require/tests/async_sequential.luau new file mode 100644 index 0000000..5bedc3b --- /dev/null +++ b/tests/globals/require/tests/async_sequential.luau @@ -0,0 +1,14 @@ +local module1 = require("./modules/async") +local module2 = require("./modules/async") + +task.wait(1) + +assert(type(module1) == "table", "Required module1 did not return a table") +assert(module1.Foo == "Bar", "Required module1 did not contain correct values") +assert(module1.Hello == "World", "Required module1 did not contain correct values") + +assert(type(module2) == "table", "Required module2 did not return a table") +assert(module2.Foo == "Bar", "Required module2 did not contain correct values") +assert(module2.Hello == "World", "Required module2 did not contain correct values") + +assert(module1 == module2, "Required modules should point to the same return value") diff --git a/tests/globals/require/tests/modules/async.luau b/tests/globals/require/tests/modules/async.luau new file mode 100644 index 0000000..86bb775 --- /dev/null +++ b/tests/globals/require/tests/modules/async.luau @@ -0,0 +1,6 @@ +task.wait(0.25) + +return { + Foo = "Bar", + Hello = "World", +}