Finish implementing async require, add test cases

This commit is contained in:
Filip Tibell 2023-03-21 11:07:42 +01:00
parent 0975a6180b
commit 172ab16823
No known key found for this signature in database
14 changed files with 400 additions and 165 deletions

View file

@ -12,7 +12,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added ### Added
- Builtin modules such as `fs`, `net` and others can now be imported using `require("@lune/fs")`, `require("@lune/net")` .. - `require` has been reimplemented and overhauled in several ways:
- Builtin modules such as `fs`, `net` and others can now be imported using `require("@lune/fs")`, `require("@lune/net")` ... <br />
This is the first step towards moving away from adding each library as a global, and allowing Lune to have more built-in libraries.
- Requiring a script is now completely asynchronous and will not block lua threads other than the caller.
- Requiring a script will no longer error when using async APIs in the main body of the required script.
Behavior otherwise stays the same, and requires are still relative to file unless the special `@` prefix is used.
### Removed ### Removed

View file

@ -136,7 +136,7 @@ fn process_env_iter<'lua>(
lua: &'lua Lua, lua: &'lua Lua,
(_, _): (LuaValue<'lua>, ()), (_, _): (LuaValue<'lua>, ()),
) -> LuaResult<LuaFunction<'lua>> { ) -> LuaResult<LuaFunction<'lua>> {
let mut vars = env::vars_os(); let mut vars = env::vars_os().collect::<Vec<_>>().into_iter();
lua.create_function_mut(move |lua, _: ()| match vars.next() { lua.create_function_mut(move |lua, _: ()| match vars.next() {
Some((key, value)) => { Some((key, value)) => {
let raw_key = RawOsString::new(key); let raw_key = RawOsString::new(key);

View file

@ -1,13 +1,14 @@
use std::{ use std::{
cell::RefCell, cell::RefCell,
collections::HashMap, collections::{HashMap, HashSet},
env::current_dir, env::current_dir,
path::{self, PathBuf}, path::{self, PathBuf},
sync::Arc,
}; };
use dunce::canonicalize; use dunce::canonicalize;
use mlua::prelude::*; use mlua::prelude::*;
use tokio::{fs, sync::oneshot}; use tokio::fs;
use crate::lua::{ use crate::lua::{
table::TableBuilder, table::TableBuilder,
@ -19,14 +20,17 @@ local source = info(1, "s")
if source == '[string "require"]' then if source == '[string "require"]' then
source = info(2, "s") source = info(2, "s")
end end
local absolute, relative = paths(context, source, ...) load(context, source, ...)
return load(context, absolute, relative) return yield()
"#; "#;
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
struct RequireContext<'lua> { struct RequireContext<'lua> {
builtins: HashMap<String, LuaMultiValue<'lua>>, // NOTE: We need to use arc here so that mlua clones
cached: RefCell<HashMap<String, LuaResult<LuaMultiValue<'lua>>>>, // the reference and not the entire inner value(s)
builtins: Arc<HashMap<String, LuaMultiValue<'lua>>>,
cached: Arc<RefCell<HashMap<String, LuaResult<LuaMultiValue<'lua>>>>>,
locks: Arc<RefCell<HashSet<String>>>,
pwd: String, pwd: String,
} }
@ -44,12 +48,36 @@ impl<'lua> RequireContext<'lua> {
..Default::default() ..Default::default()
} }
} }
pub fn is_locked(&self, absolute_path: &str) -> bool {
self.locks.borrow().contains(absolute_path)
} }
impl<'lua> LuaUserData for RequireContext<'lua> {} pub fn set_locked(&self, absolute_path: &str) -> bool {
self.locks.borrow_mut().insert(absolute_path.to_string())
}
fn paths( pub fn set_unlocked(&self, absolute_path: &str) -> bool {
context: RequireContext, self.locks.borrow_mut().remove(absolute_path)
}
pub fn try_acquire_lock_sync(&self, absolute_path: &str) -> bool {
if self.is_locked(absolute_path) {
false
} else {
self.set_locked(absolute_path);
true
}
}
pub fn set_cached(&self, absolute_path: String, result: &LuaResult<LuaMultiValue<'lua>>) {
self.cached
.borrow_mut()
.insert(absolute_path, result.clone());
}
pub fn get_paths(
&self,
require_source: String, require_source: String,
require_path: String, require_path: String,
) -> LuaResult<(String, String)> { ) -> LuaResult<(String, String)> {
@ -78,14 +106,18 @@ fn paths(
} }
}; };
let absolute = file_path.to_string_lossy().to_string(); let absolute = file_path.to_string_lossy().to_string();
let relative = absolute.trim_start_matches(&context.pwd).to_string(); let relative = absolute.trim_start_matches(&self.pwd).to_string();
Ok((absolute, relative)) Ok((absolute, relative))
} }
}
impl<'lua> LuaUserData for RequireContext<'lua> {}
fn load_builtin<'lua>( fn load_builtin<'lua>(
_lua: &'lua Lua, _lua: &'lua Lua,
context: RequireContext<'lua>, context: &RequireContext<'lua>,
module_name: String, module_name: String,
_has_acquired_lock: bool,
) -> LuaResult<LuaMultiValue<'lua>> { ) -> LuaResult<LuaMultiValue<'lua>> {
match context.builtins.get(&module_name) { match context.builtins.get(&module_name) {
Some(module) => Ok(module.clone()), Some(module) => Ok(module.clone()),
@ -98,14 +130,20 @@ fn load_builtin<'lua>(
async fn load_file<'lua>( async fn load_file<'lua>(
lua: &'lua Lua, lua: &'lua Lua,
context: RequireContext<'lua>, context: &RequireContext<'lua>,
absolute_path: String, absolute_path: String,
relative_path: String, relative_path: String,
has_acquired_lock: bool,
) -> LuaResult<LuaMultiValue<'lua>> { ) -> LuaResult<LuaMultiValue<'lua>> {
let cached = { context.cached.borrow().get(&absolute_path).cloned() }; let cached = { context.cached.borrow().get(&absolute_path).cloned() };
match cached { match cached {
Some(cached) => cached, Some(cached) => cached,
None => { None => {
if !has_acquired_lock {
return Err(LuaError::RuntimeError(
"Failed to get require lock".to_string(),
));
}
// Try to read the wanted file, note that we use bytes instead of reading // Try to read the wanted file, note that we use bytes instead of reading
// to a string since lua scripts are not necessarily valid utf-8 strings // to a string since lua scripts are not necessarily valid utf-8 strings
let contents = fs::read(&absolute_path).await.map_err(LuaError::external)?; let contents = fs::read(&absolute_path).await.map_err(LuaError::external)?;
@ -120,21 +158,15 @@ async fn load_file<'lua>(
.set_name(path_relative_no_extension)? .set_name(path_relative_no_extension)?
.into_function()?; .into_function()?;
let loaded_thread = lua.create_thread(loaded_func)?; let loaded_thread = lua.create_thread(loaded_func)?;
// Run the thread and provide a channel that will // Run the thread and wait for completion using the native task scheduler waker
// then get its result received when it finishes let task_fut = {
let (tx, rx) = oneshot::channel();
{
let sched = lua.app_data_ref::<&TaskScheduler>().unwrap(); let sched = lua.app_data_ref::<&TaskScheduler>().unwrap();
let task = sched.schedule_blocking(loaded_thread, LuaMultiValue::new())?; let task = sched.schedule_blocking(loaded_thread, LuaMultiValue::new())?;
sched.set_task_result_sender(task, tx); sched.wait_for_task_completion(task)
} };
// Wait for the thread to finish running, cache + return our result // Wait for the thread to finish running, cache + return our result
// FIXME: This waits indefinitely for nested requires for some reason let rets = task_fut.await;
let rets = rx.await.expect("Sender was dropped during require"); context.set_cached(absolute_path, &rets);
context
.cached
.borrow_mut()
.insert(absolute_path, rets.clone());
rets rets
} }
} }
@ -145,35 +177,66 @@ async fn load<'lua>(
context: RequireContext<'lua>, context: RequireContext<'lua>,
absolute_path: String, absolute_path: String,
relative_path: String, relative_path: String,
has_acquired_lock: bool,
) -> LuaResult<LuaMultiValue<'lua>> { ) -> LuaResult<LuaMultiValue<'lua>> {
if absolute_path == relative_path && absolute_path.starts_with('@') { let result = if absolute_path == relative_path && absolute_path.starts_with('@') {
if let Some(module_name) = absolute_path.strip_prefix("@lune/") { if let Some(module_name) = absolute_path.strip_prefix("@lune/") {
load_builtin(lua, context, module_name.to_string()) load_builtin(lua, &context, module_name.to_string(), has_acquired_lock)
} else { } else {
// FUTURE: '@' can be used a special prefix for users to set their own
// paths relative to a project file, similar to typescript paths config
// https://www.typescriptlang.org/tsconfig#paths
Err(LuaError::RuntimeError( Err(LuaError::RuntimeError(
"Require paths prefixed by '@' are not yet supported".to_string(), "Require paths prefixed by '@' are not yet supported".to_string(),
)) ))
} }
} else { } else {
load_file(lua, context, absolute_path, relative_path).await load_file(
lua,
&context,
absolute_path.to_string(),
relative_path,
has_acquired_lock,
)
.await
};
if has_acquired_lock {
context.set_unlocked(&absolute_path);
} }
result
} }
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> { pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
let require_context = RequireContext::new(); let require_context = RequireContext::new();
let require_print: LuaFunction = lua.named_registry_value("print")?; let require_yield: LuaFunction = lua.named_registry_value("co.yield")?;
let require_info: LuaFunction = lua.named_registry_value("dbg.info")?; let require_info: LuaFunction = lua.named_registry_value("dbg.info")?;
let require_print: LuaFunction = lua.named_registry_value("print")?;
let require_env = TableBuilder::new(lua)? let require_env = TableBuilder::new(lua)?
.with_value("context", require_context)? .with_value("context", require_context)?
.with_value("print", require_print)? .with_value("yield", require_yield)?
.with_value("info", require_info)? .with_value("info", require_info)?
.with_function("paths", |_, (context, require_source, require_path)| { .with_value("print", require_print)?
paths(context, require_source, require_path) .with_function(
})? "load",
.with_async_function("load", |lua, (context, require_source, require_path)| { |lua, (context, require_source, require_path): (RequireContext, String, String)| {
load(lua, context, require_source, require_path) let (absolute_path, relative_path) =
})? context.get_paths(require_source, require_path)?;
// NOTE: We can not acquire the lock in the async part of the require
// load process since several requires may have happened for the
// same path before the async load task even gets a chance to run
let has_lock = context.try_acquire_lock_sync(&absolute_path);
let fut = load(lua, context, absolute_path, relative_path, has_lock);
let sched = lua
.app_data_ref::<&TaskScheduler>()
.expect("Missing task scheduler as a lua app data");
sched.queue_async_task_inherited(lua.current_thread(), None, async {
let rets = fut.await?;
let mult = rets.to_lua_multi(lua)?;
Ok(Some(mult))
})
},
)?
.build_readonly()?; .build_readonly()?;
let require_fn_lua = lua let require_fn_lua = lua

View file

@ -158,13 +158,13 @@ fn coroutine_resume<'lua>(
let result = match value { let result = match value {
LuaThreadOrTaskReference::Thread(t) => { LuaThreadOrTaskReference::Thread(t) => {
let task = sched.create_task(TaskKind::Instant, t, None, true)?; let task = sched.create_task(TaskKind::Instant, t, None, true)?;
sched.resume_task(task) sched.resume_task(task, None)
} }
LuaThreadOrTaskReference::TaskReference(t) => sched.resume_task(t), LuaThreadOrTaskReference::TaskReference(t) => sched.resume_task(t, None),
}; };
sched.force_set_current_task(Some(current)); sched.force_set_current_task(Some(current));
match result { match result {
Ok(rets) => Ok((true, rets)), Ok(rets) => Ok((true, rets.1)),
Err(e) => Ok((false, e.to_lua_multi(lua)?)), Err(e) => Ok((false, e.to_lua_multi(lua)?)),
} }
} }
@ -187,8 +187,11 @@ fn coroutine_wrap<'lua>(lua: &'lua Lua, func: LuaFunction) -> LuaResult<LuaFunct
let result = lua let result = lua
.app_data_ref::<&TaskScheduler>() .app_data_ref::<&TaskScheduler>()
.unwrap() .unwrap()
.resume_task_override(task, Ok(args)); .resume_task(task, Some(Ok(args)));
sched.force_set_current_task(Some(current)); sched.force_set_current_task(Some(current));
result match result {
Ok(rets) => Ok(rets.1),
Err(e) => Err(e),
}
}) })
} }

View file

@ -76,10 +76,13 @@ impl TaskSchedulerResumeExt for TaskScheduler<'_> {
Resumes the next queued Lua task, if one exists, blocking Resumes the next queued Lua task, if one exists, blocking
the current thread until it either yields or finishes. the current thread until it either yields or finishes.
*/ */
fn resume_next_blocking_task( fn resume_next_blocking_task<'sched, 'args>(
scheduler: &TaskScheduler<'_>, scheduler: &TaskScheduler<'sched>,
override_args: Option<LuaResult<LuaMultiValue>>, override_args: Option<LuaResult<LuaMultiValue<'args>>>,
) -> TaskSchedulerState { ) -> TaskSchedulerState
where
'args: 'sched,
{
match { match {
let mut queue_guard = scheduler.tasks_queue_blocking.borrow_mut(); let mut queue_guard = scheduler.tasks_queue_blocking.borrow_mut();
let task = queue_guard.pop_front(); let task = queue_guard.pop_front();
@ -87,15 +90,16 @@ fn resume_next_blocking_task(
task task
} { } {
None => TaskSchedulerState::new(scheduler), None => TaskSchedulerState::new(scheduler),
Some(task) => match override_args { Some(task) => match scheduler.resume_task(task, override_args) {
Some(args) => match scheduler.resume_task_override(task, args) { Err(task_err) => {
scheduler.wake_completed_task(task, Err(task_err.clone()));
TaskSchedulerState::err(scheduler, task_err)
}
Ok(rets) if rets.0 == LuaThreadStatus::Unresumable => {
scheduler.wake_completed_task(task, Ok(rets.1));
TaskSchedulerState::new(scheduler)
}
Ok(_) => TaskSchedulerState::new(scheduler), Ok(_) => TaskSchedulerState::new(scheduler),
Err(task_err) => TaskSchedulerState::err(scheduler, task_err),
},
None => match scheduler.resume_task(task) {
Ok(_) => TaskSchedulerState::new(scheduler),
Err(task_err) => TaskSchedulerState::err(scheduler, task_err),
},
}, },
} }
} }

View file

@ -6,6 +6,7 @@ mod scheduler_message;
mod scheduler_state; mod scheduler_state;
mod task_kind; mod task_kind;
mod task_reference; mod task_reference;
mod task_waiter;
pub use ext::*; pub use ext::*;
pub use proxy::*; pub use proxy::*;

View file

@ -9,12 +9,14 @@ use std::{
use futures_util::{future::LocalBoxFuture, stream::FuturesUnordered, Future}; use futures_util::{future::LocalBoxFuture, stream::FuturesUnordered, Future};
use mlua::prelude::*; use mlua::prelude::*;
use tokio::sync::{mpsc, oneshot, Mutex as AsyncMutex}; use tokio::sync::{mpsc, Mutex as AsyncMutex};
use super::scheduler_message::TaskSchedulerMessage; use super::{
scheduler_message::TaskSchedulerMessage,
task_waiter::{TaskWaiterFuture, TaskWaiterState},
};
pub use super::{task_kind::TaskKind, task_reference::TaskReference}; pub use super::{task_kind::TaskKind, task_reference::TaskReference};
type TaskResultSender = oneshot::Sender<LuaResult<LuaMultiValue<'static>>>;
type TaskFutureRets<'fut> = LuaResult<Option<LuaMultiValue<'fut>>>; type TaskFutureRets<'fut> = LuaResult<Option<LuaMultiValue<'fut>>>;
type TaskFuture<'fut> = LocalBoxFuture<'fut, (Option<TaskReference>, TaskFutureRets<'fut>)>; type TaskFuture<'fut> = LocalBoxFuture<'fut, (Option<TaskReference>, TaskFutureRets<'fut>)>;
@ -49,8 +51,9 @@ pub struct TaskScheduler<'fut> {
pub(super) tasks_count: Cell<usize>, pub(super) tasks_count: Cell<usize>,
pub(super) tasks_current: Cell<Option<TaskReference>>, pub(super) tasks_current: Cell<Option<TaskReference>>,
pub(super) tasks_queue_blocking: RefCell<VecDeque<TaskReference>>, pub(super) tasks_queue_blocking: RefCell<VecDeque<TaskReference>>,
pub(super) tasks_result_senders: RefCell<HashMap<TaskReference, TaskResultSender>>, pub(super) tasks_waiter_states:
pub(super) tasks_current_lua_error: Arc<RefCell<Option<LuaError>>>, RefCell<HashMap<TaskReference, Arc<AsyncMutex<TaskWaiterState<'fut>>>>>,
pub(super) tasks_current_lua_error: Arc<AsyncMutex<Option<LuaError>>>,
// Future tasks & objects for waking // Future tasks & objects for waking
pub(super) futures: AsyncMutex<FuturesUnordered<TaskFuture<'fut>>>, pub(super) futures: AsyncMutex<FuturesUnordered<TaskFuture<'fut>>>,
pub(super) futures_count: Cell<usize>, pub(super) futures_count: Cell<usize>,
@ -65,12 +68,14 @@ impl<'fut> TaskScheduler<'fut> {
*/ */
pub fn new(lua: &'static Lua) -> LuaResult<Self> { pub fn new(lua: &'static Lua) -> LuaResult<Self> {
let (tx, rx) = mpsc::unbounded_channel(); let (tx, rx) = mpsc::unbounded_channel();
let tasks_current_lua_error = Arc::new(RefCell::new(None)); let tasks_current_lua_error = Arc::new(AsyncMutex::new(None));
let tasks_current_lua_error_inner = tasks_current_lua_error.clone(); let tasks_current_lua_error_inner = tasks_current_lua_error.clone();
lua.set_interrupt(move || match tasks_current_lua_error_inner.take() { lua.set_interrupt(
move || match tasks_current_lua_error_inner.try_lock().unwrap().take() {
Some(err) => Err(err), Some(err) => Err(err),
None => Ok(LuaVmState::Continue), None => Ok(LuaVmState::Continue),
}); },
);
Ok(Self { Ok(Self {
lua, lua,
guid: Cell::new(0), guid: Cell::new(0),
@ -79,7 +84,7 @@ impl<'fut> TaskScheduler<'fut> {
tasks_count: Cell::new(0), tasks_count: Cell::new(0),
tasks_current: Cell::new(None), tasks_current: Cell::new(None),
tasks_queue_blocking: RefCell::new(VecDeque::new()), tasks_queue_blocking: RefCell::new(VecDeque::new()),
tasks_result_senders: RefCell::new(HashMap::new()), tasks_waiter_states: RefCell::new(HashMap::new()),
tasks_current_lua_error, tasks_current_lua_error,
futures: AsyncMutex::new(FuturesUnordered::new()), futures: AsyncMutex::new(FuturesUnordered::new()),
futures_tx: tx, futures_tx: tx,
@ -273,8 +278,6 @@ impl<'fut> TaskScheduler<'fut> {
TaskKind::Future => self.futures_count.set(self.futures_count.get() - 1), TaskKind::Future => self.futures_count.set(self.futures_count.get() - 1),
_ => self.tasks_count.set(self.tasks_count.get() - 1), _ => self.tasks_count.set(self.tasks_count.get() - 1),
} }
// Remove any sender
self.tasks_result_senders.borrow_mut().remove(task_ref);
// NOTE: We need to close the thread here to // NOTE: We need to close the thread here to
// make 100% sure that nothing can resume it // make 100% sure that nothing can resume it
let close: LuaFunction = self.lua.named_registry_value("co.close")?; let close: LuaFunction = self.lua.named_registry_value("co.close")?;
@ -296,14 +299,21 @@ impl<'fut> TaskScheduler<'fut> {
This will be a no-op if the task no longer exists. This will be a no-op if the task no longer exists.
*/ */
pub fn resume_task(&self, reference: TaskReference) -> LuaResult<LuaMultiValue> { pub fn resume_task<'a, 'r>(
&self,
reference: TaskReference,
override_args: Option<LuaResult<LuaMultiValue<'a>>>,
) -> LuaResult<(LuaThreadStatus, LuaMultiValue<'r>)>
where
'a: 'r,
{
// Fetch and check if the task was removed, if it got // Fetch and check if the task was removed, if it got
// removed it means it was intentionally cancelled // removed it means it was intentionally cancelled
let task = { let task = {
let mut tasks = self.tasks.borrow_mut(); let mut tasks = self.tasks.borrow_mut();
match tasks.remove(&reference) { match tasks.remove(&reference) {
Some(task) => task, Some(task) => task,
None => return Ok(LuaMultiValue::new()), None => return Ok((LuaThreadStatus::Unresumable, LuaMultiValue::new())),
} }
}; };
// Decrement the corresponding task counter // Decrement the corresponding task counter
@ -324,66 +334,27 @@ impl<'fut> TaskScheduler<'fut> {
// We got everything we need and our references // We got everything we need and our references
// were cleaned up properly, resume the thread // were cleaned up properly, resume the thread
self.tasks_current.set(Some(reference)); self.tasks_current.set(Some(reference));
let rets = match thread_args {
Some(args) => thread.resume(args),
None => thread.resume(()),
};
self.tasks_current.set(None);
// If we have a result sender for this task, we should run it if the thread finished
if thread.status() != LuaThreadStatus::Resumable {
if let Some(sender) = self.tasks_result_senders.borrow_mut().remove(&reference) {
let _ = sender.send(rets.clone());
}
}
rets
}
/**
Resumes a task, if the task still exists in the scheduler, using the given arguments.
A task may no longer exist in the scheduler if it has been manually
cancelled and removed by calling [`TaskScheduler::cancel_task()`].
This will be a no-op if the task no longer exists.
*/
pub fn resume_task_override<'a>(
&self,
reference: TaskReference,
override_args: LuaResult<LuaMultiValue<'a>>,
) -> LuaResult<LuaMultiValue<'a>> {
// Fetch and check if the task was removed, if it got
// removed it means it was intentionally cancelled
let task = {
let mut tasks = self.tasks.borrow_mut();
match tasks.remove(&reference) {
Some(task) => task,
None => return Ok(LuaMultiValue::new()),
}
};
// Decrement the corresponding task counter
match task.kind {
TaskKind::Future => self.futures_count.set(self.futures_count.get() - 1),
_ => self.tasks_count.set(self.tasks_count.get() - 1),
}
// Fetch and remove the thread to resume + its arguments
let thread: LuaThread = self.lua.registry_value(&task.thread)?;
self.lua.remove_registry_value(task.thread)?;
self.lua.remove_registry_value(task.args)?;
// We got everything we need and our references
// were cleaned up properly, resume the thread
self.tasks_current.set(Some(reference));
let rets = match override_args { let rets = match override_args {
Some(override_res) => match override_res {
Ok(args) => thread.resume(args),
Err(e) => { Err(e) => {
// NOTE: Setting this error here means that when the thread // NOTE: Setting this error here means that when the thread
// is resumed it will error instantly, so we don't need // is resumed it will error instantly, so we don't need
// to call it with proper args, empty args is fine // to call it with proper args, empty args is fine
self.tasks_current_lua_error.replace(Some(e)); self.tasks_current_lua_error.try_lock().unwrap().replace(e);
thread.resume(()) thread.resume(())
} }
Ok(args) => thread.resume(args), },
None => match thread_args {
Some(args) => thread.resume(args),
None => thread.resume(()),
},
}; };
self.tasks_current.set(None); self.tasks_current.set(None);
rets match rets {
Ok(rets) => Ok((thread.status(), rets)),
Err(e) => Err(e),
}
} }
/** /**
@ -446,9 +417,58 @@ impl<'fut> TaskScheduler<'fut> {
Ok(task_ref) Ok(task_ref)
} }
pub(crate) fn set_task_result_sender(&self, task_ref: TaskReference, sender: TaskResultSender) { /**
self.tasks_result_senders Queues a new future to run on the task scheduler,
inheriting the task id of the currently running task.
*/
pub(crate) fn queue_async_task_inherited(
&self,
thread: LuaThread<'_>,
thread_args: Option<LuaMultiValue<'_>>,
fut: impl Future<Output = TaskFutureRets<'fut>> + 'fut,
) -> LuaResult<TaskReference> {
let task_ref = self.create_task(TaskKind::Future, thread, thread_args, true)?;
let futs = self
.futures
.try_lock()
.expect("Tried to add future to queue during futures resumption");
futs.push(Box::pin(async move {
let result = fut.await;
(Some(task_ref), result)
}));
Ok(task_ref)
}
/**
Waits for a task to complete.
Panics if the task is not currently in the scheduler.
*/
pub(crate) async fn wait_for_task_completion(
&self,
reference: TaskReference,
) -> LuaResult<LuaMultiValue> {
if !self.tasks.borrow().contains_key(&reference) {
panic!("Task does not exist in scheduler")
}
let state = TaskWaiterState::new();
self.tasks_waiter_states
.borrow_mut() .borrow_mut()
.insert(task_ref, sender); .insert(reference, Arc::clone(&state));
TaskWaiterFuture::new(&state).await
}
/**
Wakes a task that has been completed and may have external code
waiting on it using [`TaskScheduler::wait_for_task_completion`].
*/
pub(super) fn wake_completed_task(
&self,
reference: TaskReference,
result: LuaResult<LuaMultiValue<'fut>>,
) {
if let Some(waiter_state) = self.tasks_waiter_states.borrow_mut().remove(&reference) {
waiter_state.try_lock().unwrap().finalize(result);
}
} }
} }

View file

@ -1,4 +1,7 @@
use std::fmt; use std::{
fmt,
hash::{Hash, Hasher},
};
use mlua::prelude::*; use mlua::prelude::*;
@ -6,7 +9,7 @@ use super::task_kind::TaskKind;
/// A lightweight, copyable struct that represents a /// A lightweight, copyable struct that represents a
/// task in the scheduler and is accessible from Lua /// task in the scheduler and is accessible from Lua
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[derive(Debug, Clone, Copy)]
pub struct TaskReference { pub struct TaskReference {
kind: TaskKind, kind: TaskKind,
guid: usize, guid: usize,
@ -32,4 +35,17 @@ impl fmt::Display for TaskReference {
} }
} }
impl Eq for TaskReference {}
impl PartialEq for TaskReference {
fn eq(&self, other: &Self) -> bool {
self.guid == other.guid
}
}
impl Hash for TaskReference {
fn hash<H: Hasher>(&self, state: &mut H) {
self.guid.hash(state);
}
}
impl LuaUserData for TaskReference {} impl LuaUserData for TaskReference {}

View file

@ -0,0 +1,66 @@
use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll, Waker},
};
use tokio::sync::Mutex as AsyncMutex;
use mlua::prelude::*;
#[derive(Debug, Clone)]
pub(super) struct TaskWaiterState<'fut> {
rets: Option<LuaResult<LuaMultiValue<'fut>>>,
waker: Option<Waker>,
}
impl<'fut> TaskWaiterState<'fut> {
pub fn new() -> Arc<AsyncMutex<Self>> {
Arc::new(AsyncMutex::new(TaskWaiterState {
rets: None,
waker: None,
}))
}
pub fn finalize(&mut self, rets: LuaResult<LuaMultiValue<'fut>>) {
self.rets = Some(rets);
if let Some(waker) = self.waker.take() {
waker.wake();
}
}
}
#[derive(Debug)]
pub(super) struct TaskWaiterFuture<'fut> {
state: Arc<AsyncMutex<TaskWaiterState<'fut>>>,
}
impl<'fut> TaskWaiterFuture<'fut> {
pub fn new(state: &Arc<AsyncMutex<TaskWaiterState<'fut>>>) -> Self {
Self {
state: Arc::clone(state),
}
}
}
impl<'fut> Clone for TaskWaiterFuture<'fut> {
fn clone(&self) -> Self {
Self {
state: Arc::clone(&self.state),
}
}
}
impl<'fut> Future for TaskWaiterFuture<'fut> {
type Output = LuaResult<LuaMultiValue<'fut>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut shared_state = self.state.try_lock().unwrap();
if let Some(rets) = shared_state.rets.clone() {
Poll::Ready(rets)
} else {
shared_state.waker = Some(cx.waker().clone());
Poll::Pending
}
}
}

View file

@ -60,6 +60,9 @@ create_tests! {
process_env: "process/env", process_env: "process/env",
process_exit: "process/exit", process_exit: "process/exit",
process_spawn: "process/spawn", process_spawn: "process/spawn",
require_async: "globals/require/tests/async",
require_async_concurrent: "globals/require/tests/async_concurrent",
require_async_sequential: "globals/require/tests/async_sequential",
require_children: "globals/require/tests/children", require_children: "globals/require/tests/children",
require_invalid: "globals/require/tests/invalid", require_invalid: "globals/require/tests/invalid",
require_nested: "globals/require/tests/nested", require_nested: "globals/require/tests/nested",

View file

@ -0,0 +1,9 @@
local module = require("./modules/async")
assert(type(module) == "table", "Required module did not return a table")
assert(module.Foo == "Bar", "Required module did not contain correct values")
assert(module.Hello == "World", "Required module did not contain correct values")
module = require("modules/async")
assert(module.Foo == "Bar", "Required module did not contain correct values")
assert(module.Hello == "World", "Required module did not contain correct values")

View file

@ -0,0 +1,22 @@
local module1
local module2
task.defer(function()
module2 = require("./modules/async")
end)
task.spawn(function()
module1 = require("./modules/async")
end)
task.wait(1)
assert(type(module1) == "table", "Required module1 did not return a table")
assert(module1.Foo == "Bar", "Required module1 did not contain correct values")
assert(module1.Hello == "World", "Required module1 did not contain correct values")
assert(type(module2) == "table", "Required module2 did not return a table")
assert(module2.Foo == "Bar", "Required module2 did not contain correct values")
assert(module2.Hello == "World", "Required module2 did not contain correct values")
assert(module1 == module2, "Required modules should point to the same return value")

View file

@ -0,0 +1,14 @@
local module1 = require("./modules/async")
local module2 = require("./modules/async")
task.wait(1)
assert(type(module1) == "table", "Required module1 did not return a table")
assert(module1.Foo == "Bar", "Required module1 did not contain correct values")
assert(module1.Hello == "World", "Required module1 did not contain correct values")
assert(type(module2) == "table", "Required module2 did not return a table")
assert(module2.Foo == "Bar", "Required module2 did not contain correct values")
assert(module2.Hello == "World", "Required module2 did not contain correct values")
assert(module1 == module2, "Required modules should point to the same return value")

View file

@ -0,0 +1,6 @@
task.wait(0.25)
return {
Foo = "Bar",
Hello = "World",
}