mirror of
https://github.com/lune-org/lune.git
synced 2024-12-12 13:00:37 +00:00
Finish implementing async require, add test cases
This commit is contained in:
parent
0975a6180b
commit
172ab16823
14 changed files with 400 additions and 165 deletions
10
CHANGELOG.md
10
CHANGELOG.md
|
@ -12,7 +12,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
|||
|
||||
### Added
|
||||
|
||||
- Builtin modules such as `fs`, `net` and others can now be imported using `require("@lune/fs")`, `require("@lune/net")` ..
|
||||
- `require` has been reimplemented and overhauled in several ways:
|
||||
|
||||
- Builtin modules such as `fs`, `net` and others can now be imported using `require("@lune/fs")`, `require("@lune/net")` ... <br />
|
||||
This is the first step towards moving away from adding each library as a global, and allowing Lune to have more built-in libraries.
|
||||
|
||||
- Requiring a script is now completely asynchronous and will not block lua threads other than the caller.
|
||||
- Requiring a script will no longer error when using async APIs in the main body of the required script.
|
||||
|
||||
Behavior otherwise stays the same, and requires are still relative to file unless the special `@` prefix is used.
|
||||
|
||||
### Removed
|
||||
|
||||
|
|
|
@ -136,7 +136,7 @@ fn process_env_iter<'lua>(
|
|||
lua: &'lua Lua,
|
||||
(_, _): (LuaValue<'lua>, ()),
|
||||
) -> LuaResult<LuaFunction<'lua>> {
|
||||
let mut vars = env::vars_os();
|
||||
let mut vars = env::vars_os().collect::<Vec<_>>().into_iter();
|
||||
lua.create_function_mut(move |lua, _: ()| match vars.next() {
|
||||
Some((key, value)) => {
|
||||
let raw_key = RawOsString::new(key);
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
use std::{
|
||||
cell::RefCell,
|
||||
collections::HashMap,
|
||||
collections::{HashMap, HashSet},
|
||||
env::current_dir,
|
||||
path::{self, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use dunce::canonicalize;
|
||||
use mlua::prelude::*;
|
||||
use tokio::{fs, sync::oneshot};
|
||||
use tokio::fs;
|
||||
|
||||
use crate::lua::{
|
||||
table::TableBuilder,
|
||||
|
@ -19,14 +20,17 @@ local source = info(1, "s")
|
|||
if source == '[string "require"]' then
|
||||
source = info(2, "s")
|
||||
end
|
||||
local absolute, relative = paths(context, source, ...)
|
||||
return load(context, absolute, relative)
|
||||
load(context, source, ...)
|
||||
return yield()
|
||||
"#;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
struct RequireContext<'lua> {
|
||||
builtins: HashMap<String, LuaMultiValue<'lua>>,
|
||||
cached: RefCell<HashMap<String, LuaResult<LuaMultiValue<'lua>>>>,
|
||||
// NOTE: We need to use arc here so that mlua clones
|
||||
// the reference and not the entire inner value(s)
|
||||
builtins: Arc<HashMap<String, LuaMultiValue<'lua>>>,
|
||||
cached: Arc<RefCell<HashMap<String, LuaResult<LuaMultiValue<'lua>>>>>,
|
||||
locks: Arc<RefCell<HashSet<String>>>,
|
||||
pwd: String,
|
||||
}
|
||||
|
||||
|
@ -44,12 +48,36 @@ impl<'lua> RequireContext<'lua> {
|
|||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_locked(&self, absolute_path: &str) -> bool {
|
||||
self.locks.borrow().contains(absolute_path)
|
||||
}
|
||||
|
||||
impl<'lua> LuaUserData for RequireContext<'lua> {}
|
||||
pub fn set_locked(&self, absolute_path: &str) -> bool {
|
||||
self.locks.borrow_mut().insert(absolute_path.to_string())
|
||||
}
|
||||
|
||||
fn paths(
|
||||
context: RequireContext,
|
||||
pub fn set_unlocked(&self, absolute_path: &str) -> bool {
|
||||
self.locks.borrow_mut().remove(absolute_path)
|
||||
}
|
||||
|
||||
pub fn try_acquire_lock_sync(&self, absolute_path: &str) -> bool {
|
||||
if self.is_locked(absolute_path) {
|
||||
false
|
||||
} else {
|
||||
self.set_locked(absolute_path);
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_cached(&self, absolute_path: String, result: &LuaResult<LuaMultiValue<'lua>>) {
|
||||
self.cached
|
||||
.borrow_mut()
|
||||
.insert(absolute_path, result.clone());
|
||||
}
|
||||
|
||||
pub fn get_paths(
|
||||
&self,
|
||||
require_source: String,
|
||||
require_path: String,
|
||||
) -> LuaResult<(String, String)> {
|
||||
|
@ -78,14 +106,18 @@ fn paths(
|
|||
}
|
||||
};
|
||||
let absolute = file_path.to_string_lossy().to_string();
|
||||
let relative = absolute.trim_start_matches(&context.pwd).to_string();
|
||||
let relative = absolute.trim_start_matches(&self.pwd).to_string();
|
||||
Ok((absolute, relative))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'lua> LuaUserData for RequireContext<'lua> {}
|
||||
|
||||
fn load_builtin<'lua>(
|
||||
_lua: &'lua Lua,
|
||||
context: RequireContext<'lua>,
|
||||
context: &RequireContext<'lua>,
|
||||
module_name: String,
|
||||
_has_acquired_lock: bool,
|
||||
) -> LuaResult<LuaMultiValue<'lua>> {
|
||||
match context.builtins.get(&module_name) {
|
||||
Some(module) => Ok(module.clone()),
|
||||
|
@ -98,14 +130,20 @@ fn load_builtin<'lua>(
|
|||
|
||||
async fn load_file<'lua>(
|
||||
lua: &'lua Lua,
|
||||
context: RequireContext<'lua>,
|
||||
context: &RequireContext<'lua>,
|
||||
absolute_path: String,
|
||||
relative_path: String,
|
||||
has_acquired_lock: bool,
|
||||
) -> LuaResult<LuaMultiValue<'lua>> {
|
||||
let cached = { context.cached.borrow().get(&absolute_path).cloned() };
|
||||
match cached {
|
||||
Some(cached) => cached,
|
||||
None => {
|
||||
if !has_acquired_lock {
|
||||
return Err(LuaError::RuntimeError(
|
||||
"Failed to get require lock".to_string(),
|
||||
));
|
||||
}
|
||||
// Try to read the wanted file, note that we use bytes instead of reading
|
||||
// to a string since lua scripts are not necessarily valid utf-8 strings
|
||||
let contents = fs::read(&absolute_path).await.map_err(LuaError::external)?;
|
||||
|
@ -120,21 +158,15 @@ async fn load_file<'lua>(
|
|||
.set_name(path_relative_no_extension)?
|
||||
.into_function()?;
|
||||
let loaded_thread = lua.create_thread(loaded_func)?;
|
||||
// Run the thread and provide a channel that will
|
||||
// then get its result received when it finishes
|
||||
let (tx, rx) = oneshot::channel();
|
||||
{
|
||||
// Run the thread and wait for completion using the native task scheduler waker
|
||||
let task_fut = {
|
||||
let sched = lua.app_data_ref::<&TaskScheduler>().unwrap();
|
||||
let task = sched.schedule_blocking(loaded_thread, LuaMultiValue::new())?;
|
||||
sched.set_task_result_sender(task, tx);
|
||||
}
|
||||
sched.wait_for_task_completion(task)
|
||||
};
|
||||
// Wait for the thread to finish running, cache + return our result
|
||||
// FIXME: This waits indefinitely for nested requires for some reason
|
||||
let rets = rx.await.expect("Sender was dropped during require");
|
||||
context
|
||||
.cached
|
||||
.borrow_mut()
|
||||
.insert(absolute_path, rets.clone());
|
||||
let rets = task_fut.await;
|
||||
context.set_cached(absolute_path, &rets);
|
||||
rets
|
||||
}
|
||||
}
|
||||
|
@ -145,35 +177,66 @@ async fn load<'lua>(
|
|||
context: RequireContext<'lua>,
|
||||
absolute_path: String,
|
||||
relative_path: String,
|
||||
has_acquired_lock: bool,
|
||||
) -> LuaResult<LuaMultiValue<'lua>> {
|
||||
if absolute_path == relative_path && absolute_path.starts_with('@') {
|
||||
let result = if absolute_path == relative_path && absolute_path.starts_with('@') {
|
||||
if let Some(module_name) = absolute_path.strip_prefix("@lune/") {
|
||||
load_builtin(lua, context, module_name.to_string())
|
||||
load_builtin(lua, &context, module_name.to_string(), has_acquired_lock)
|
||||
} else {
|
||||
// FUTURE: '@' can be used a special prefix for users to set their own
|
||||
// paths relative to a project file, similar to typescript paths config
|
||||
// https://www.typescriptlang.org/tsconfig#paths
|
||||
Err(LuaError::RuntimeError(
|
||||
"Require paths prefixed by '@' are not yet supported".to_string(),
|
||||
))
|
||||
}
|
||||
} else {
|
||||
load_file(lua, context, absolute_path, relative_path).await
|
||||
load_file(
|
||||
lua,
|
||||
&context,
|
||||
absolute_path.to_string(),
|
||||
relative_path,
|
||||
has_acquired_lock,
|
||||
)
|
||||
.await
|
||||
};
|
||||
if has_acquired_lock {
|
||||
context.set_unlocked(&absolute_path);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
|
||||
let require_context = RequireContext::new();
|
||||
let require_print: LuaFunction = lua.named_registry_value("print")?;
|
||||
let require_yield: LuaFunction = lua.named_registry_value("co.yield")?;
|
||||
let require_info: LuaFunction = lua.named_registry_value("dbg.info")?;
|
||||
let require_print: LuaFunction = lua.named_registry_value("print")?;
|
||||
|
||||
let require_env = TableBuilder::new(lua)?
|
||||
.with_value("context", require_context)?
|
||||
.with_value("print", require_print)?
|
||||
.with_value("yield", require_yield)?
|
||||
.with_value("info", require_info)?
|
||||
.with_function("paths", |_, (context, require_source, require_path)| {
|
||||
paths(context, require_source, require_path)
|
||||
})?
|
||||
.with_async_function("load", |lua, (context, require_source, require_path)| {
|
||||
load(lua, context, require_source, require_path)
|
||||
})?
|
||||
.with_value("print", require_print)?
|
||||
.with_function(
|
||||
"load",
|
||||
|lua, (context, require_source, require_path): (RequireContext, String, String)| {
|
||||
let (absolute_path, relative_path) =
|
||||
context.get_paths(require_source, require_path)?;
|
||||
// NOTE: We can not acquire the lock in the async part of the require
|
||||
// load process since several requires may have happened for the
|
||||
// same path before the async load task even gets a chance to run
|
||||
let has_lock = context.try_acquire_lock_sync(&absolute_path);
|
||||
let fut = load(lua, context, absolute_path, relative_path, has_lock);
|
||||
let sched = lua
|
||||
.app_data_ref::<&TaskScheduler>()
|
||||
.expect("Missing task scheduler as a lua app data");
|
||||
sched.queue_async_task_inherited(lua.current_thread(), None, async {
|
||||
let rets = fut.await?;
|
||||
let mult = rets.to_lua_multi(lua)?;
|
||||
Ok(Some(mult))
|
||||
})
|
||||
},
|
||||
)?
|
||||
.build_readonly()?;
|
||||
|
||||
let require_fn_lua = lua
|
||||
|
|
|
@ -158,13 +158,13 @@ fn coroutine_resume<'lua>(
|
|||
let result = match value {
|
||||
LuaThreadOrTaskReference::Thread(t) => {
|
||||
let task = sched.create_task(TaskKind::Instant, t, None, true)?;
|
||||
sched.resume_task(task)
|
||||
sched.resume_task(task, None)
|
||||
}
|
||||
LuaThreadOrTaskReference::TaskReference(t) => sched.resume_task(t),
|
||||
LuaThreadOrTaskReference::TaskReference(t) => sched.resume_task(t, None),
|
||||
};
|
||||
sched.force_set_current_task(Some(current));
|
||||
match result {
|
||||
Ok(rets) => Ok((true, rets)),
|
||||
Ok(rets) => Ok((true, rets.1)),
|
||||
Err(e) => Ok((false, e.to_lua_multi(lua)?)),
|
||||
}
|
||||
}
|
||||
|
@ -187,8 +187,11 @@ fn coroutine_wrap<'lua>(lua: &'lua Lua, func: LuaFunction) -> LuaResult<LuaFunct
|
|||
let result = lua
|
||||
.app_data_ref::<&TaskScheduler>()
|
||||
.unwrap()
|
||||
.resume_task_override(task, Ok(args));
|
||||
.resume_task(task, Some(Ok(args)));
|
||||
sched.force_set_current_task(Some(current));
|
||||
result
|
||||
match result {
|
||||
Ok(rets) => Ok(rets.1),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -76,10 +76,13 @@ impl TaskSchedulerResumeExt for TaskScheduler<'_> {
|
|||
Resumes the next queued Lua task, if one exists, blocking
|
||||
the current thread until it either yields or finishes.
|
||||
*/
|
||||
fn resume_next_blocking_task(
|
||||
scheduler: &TaskScheduler<'_>,
|
||||
override_args: Option<LuaResult<LuaMultiValue>>,
|
||||
) -> TaskSchedulerState {
|
||||
fn resume_next_blocking_task<'sched, 'args>(
|
||||
scheduler: &TaskScheduler<'sched>,
|
||||
override_args: Option<LuaResult<LuaMultiValue<'args>>>,
|
||||
) -> TaskSchedulerState
|
||||
where
|
||||
'args: 'sched,
|
||||
{
|
||||
match {
|
||||
let mut queue_guard = scheduler.tasks_queue_blocking.borrow_mut();
|
||||
let task = queue_guard.pop_front();
|
||||
|
@ -87,15 +90,16 @@ fn resume_next_blocking_task(
|
|||
task
|
||||
} {
|
||||
None => TaskSchedulerState::new(scheduler),
|
||||
Some(task) => match override_args {
|
||||
Some(args) => match scheduler.resume_task_override(task, args) {
|
||||
Some(task) => match scheduler.resume_task(task, override_args) {
|
||||
Err(task_err) => {
|
||||
scheduler.wake_completed_task(task, Err(task_err.clone()));
|
||||
TaskSchedulerState::err(scheduler, task_err)
|
||||
}
|
||||
Ok(rets) if rets.0 == LuaThreadStatus::Unresumable => {
|
||||
scheduler.wake_completed_task(task, Ok(rets.1));
|
||||
TaskSchedulerState::new(scheduler)
|
||||
}
|
||||
Ok(_) => TaskSchedulerState::new(scheduler),
|
||||
Err(task_err) => TaskSchedulerState::err(scheduler, task_err),
|
||||
},
|
||||
None => match scheduler.resume_task(task) {
|
||||
Ok(_) => TaskSchedulerState::new(scheduler),
|
||||
Err(task_err) => TaskSchedulerState::err(scheduler, task_err),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ mod scheduler_message;
|
|||
mod scheduler_state;
|
||||
mod task_kind;
|
||||
mod task_reference;
|
||||
mod task_waiter;
|
||||
|
||||
pub use ext::*;
|
||||
pub use proxy::*;
|
||||
|
|
|
@ -9,12 +9,14 @@ use std::{
|
|||
use futures_util::{future::LocalBoxFuture, stream::FuturesUnordered, Future};
|
||||
use mlua::prelude::*;
|
||||
|
||||
use tokio::sync::{mpsc, oneshot, Mutex as AsyncMutex};
|
||||
use tokio::sync::{mpsc, Mutex as AsyncMutex};
|
||||
|
||||
use super::scheduler_message::TaskSchedulerMessage;
|
||||
use super::{
|
||||
scheduler_message::TaskSchedulerMessage,
|
||||
task_waiter::{TaskWaiterFuture, TaskWaiterState},
|
||||
};
|
||||
pub use super::{task_kind::TaskKind, task_reference::TaskReference};
|
||||
|
||||
type TaskResultSender = oneshot::Sender<LuaResult<LuaMultiValue<'static>>>;
|
||||
type TaskFutureRets<'fut> = LuaResult<Option<LuaMultiValue<'fut>>>;
|
||||
type TaskFuture<'fut> = LocalBoxFuture<'fut, (Option<TaskReference>, TaskFutureRets<'fut>)>;
|
||||
|
||||
|
@ -49,8 +51,9 @@ pub struct TaskScheduler<'fut> {
|
|||
pub(super) tasks_count: Cell<usize>,
|
||||
pub(super) tasks_current: Cell<Option<TaskReference>>,
|
||||
pub(super) tasks_queue_blocking: RefCell<VecDeque<TaskReference>>,
|
||||
pub(super) tasks_result_senders: RefCell<HashMap<TaskReference, TaskResultSender>>,
|
||||
pub(super) tasks_current_lua_error: Arc<RefCell<Option<LuaError>>>,
|
||||
pub(super) tasks_waiter_states:
|
||||
RefCell<HashMap<TaskReference, Arc<AsyncMutex<TaskWaiterState<'fut>>>>>,
|
||||
pub(super) tasks_current_lua_error: Arc<AsyncMutex<Option<LuaError>>>,
|
||||
// Future tasks & objects for waking
|
||||
pub(super) futures: AsyncMutex<FuturesUnordered<TaskFuture<'fut>>>,
|
||||
pub(super) futures_count: Cell<usize>,
|
||||
|
@ -65,12 +68,14 @@ impl<'fut> TaskScheduler<'fut> {
|
|||
*/
|
||||
pub fn new(lua: &'static Lua) -> LuaResult<Self> {
|
||||
let (tx, rx) = mpsc::unbounded_channel();
|
||||
let tasks_current_lua_error = Arc::new(RefCell::new(None));
|
||||
let tasks_current_lua_error = Arc::new(AsyncMutex::new(None));
|
||||
let tasks_current_lua_error_inner = tasks_current_lua_error.clone();
|
||||
lua.set_interrupt(move || match tasks_current_lua_error_inner.take() {
|
||||
lua.set_interrupt(
|
||||
move || match tasks_current_lua_error_inner.try_lock().unwrap().take() {
|
||||
Some(err) => Err(err),
|
||||
None => Ok(LuaVmState::Continue),
|
||||
});
|
||||
},
|
||||
);
|
||||
Ok(Self {
|
||||
lua,
|
||||
guid: Cell::new(0),
|
||||
|
@ -79,7 +84,7 @@ impl<'fut> TaskScheduler<'fut> {
|
|||
tasks_count: Cell::new(0),
|
||||
tasks_current: Cell::new(None),
|
||||
tasks_queue_blocking: RefCell::new(VecDeque::new()),
|
||||
tasks_result_senders: RefCell::new(HashMap::new()),
|
||||
tasks_waiter_states: RefCell::new(HashMap::new()),
|
||||
tasks_current_lua_error,
|
||||
futures: AsyncMutex::new(FuturesUnordered::new()),
|
||||
futures_tx: tx,
|
||||
|
@ -273,8 +278,6 @@ impl<'fut> TaskScheduler<'fut> {
|
|||
TaskKind::Future => self.futures_count.set(self.futures_count.get() - 1),
|
||||
_ => self.tasks_count.set(self.tasks_count.get() - 1),
|
||||
}
|
||||
// Remove any sender
|
||||
self.tasks_result_senders.borrow_mut().remove(task_ref);
|
||||
// NOTE: We need to close the thread here to
|
||||
// make 100% sure that nothing can resume it
|
||||
let close: LuaFunction = self.lua.named_registry_value("co.close")?;
|
||||
|
@ -296,14 +299,21 @@ impl<'fut> TaskScheduler<'fut> {
|
|||
|
||||
This will be a no-op if the task no longer exists.
|
||||
*/
|
||||
pub fn resume_task(&self, reference: TaskReference) -> LuaResult<LuaMultiValue> {
|
||||
pub fn resume_task<'a, 'r>(
|
||||
&self,
|
||||
reference: TaskReference,
|
||||
override_args: Option<LuaResult<LuaMultiValue<'a>>>,
|
||||
) -> LuaResult<(LuaThreadStatus, LuaMultiValue<'r>)>
|
||||
where
|
||||
'a: 'r,
|
||||
{
|
||||
// Fetch and check if the task was removed, if it got
|
||||
// removed it means it was intentionally cancelled
|
||||
let task = {
|
||||
let mut tasks = self.tasks.borrow_mut();
|
||||
match tasks.remove(&reference) {
|
||||
Some(task) => task,
|
||||
None => return Ok(LuaMultiValue::new()),
|
||||
None => return Ok((LuaThreadStatus::Unresumable, LuaMultiValue::new())),
|
||||
}
|
||||
};
|
||||
// Decrement the corresponding task counter
|
||||
|
@ -324,66 +334,27 @@ impl<'fut> TaskScheduler<'fut> {
|
|||
// We got everything we need and our references
|
||||
// were cleaned up properly, resume the thread
|
||||
self.tasks_current.set(Some(reference));
|
||||
let rets = match thread_args {
|
||||
Some(args) => thread.resume(args),
|
||||
None => thread.resume(()),
|
||||
};
|
||||
self.tasks_current.set(None);
|
||||
// If we have a result sender for this task, we should run it if the thread finished
|
||||
if thread.status() != LuaThreadStatus::Resumable {
|
||||
if let Some(sender) = self.tasks_result_senders.borrow_mut().remove(&reference) {
|
||||
let _ = sender.send(rets.clone());
|
||||
}
|
||||
}
|
||||
rets
|
||||
}
|
||||
|
||||
/**
|
||||
Resumes a task, if the task still exists in the scheduler, using the given arguments.
|
||||
|
||||
A task may no longer exist in the scheduler if it has been manually
|
||||
cancelled and removed by calling [`TaskScheduler::cancel_task()`].
|
||||
|
||||
This will be a no-op if the task no longer exists.
|
||||
*/
|
||||
pub fn resume_task_override<'a>(
|
||||
&self,
|
||||
reference: TaskReference,
|
||||
override_args: LuaResult<LuaMultiValue<'a>>,
|
||||
) -> LuaResult<LuaMultiValue<'a>> {
|
||||
// Fetch and check if the task was removed, if it got
|
||||
// removed it means it was intentionally cancelled
|
||||
let task = {
|
||||
let mut tasks = self.tasks.borrow_mut();
|
||||
match tasks.remove(&reference) {
|
||||
Some(task) => task,
|
||||
None => return Ok(LuaMultiValue::new()),
|
||||
}
|
||||
};
|
||||
// Decrement the corresponding task counter
|
||||
match task.kind {
|
||||
TaskKind::Future => self.futures_count.set(self.futures_count.get() - 1),
|
||||
_ => self.tasks_count.set(self.tasks_count.get() - 1),
|
||||
}
|
||||
// Fetch and remove the thread to resume + its arguments
|
||||
let thread: LuaThread = self.lua.registry_value(&task.thread)?;
|
||||
self.lua.remove_registry_value(task.thread)?;
|
||||
self.lua.remove_registry_value(task.args)?;
|
||||
// We got everything we need and our references
|
||||
// were cleaned up properly, resume the thread
|
||||
self.tasks_current.set(Some(reference));
|
||||
let rets = match override_args {
|
||||
Some(override_res) => match override_res {
|
||||
Ok(args) => thread.resume(args),
|
||||
Err(e) => {
|
||||
// NOTE: Setting this error here means that when the thread
|
||||
// is resumed it will error instantly, so we don't need
|
||||
// to call it with proper args, empty args is fine
|
||||
self.tasks_current_lua_error.replace(Some(e));
|
||||
self.tasks_current_lua_error.try_lock().unwrap().replace(e);
|
||||
thread.resume(())
|
||||
}
|
||||
Ok(args) => thread.resume(args),
|
||||
},
|
||||
None => match thread_args {
|
||||
Some(args) => thread.resume(args),
|
||||
None => thread.resume(()),
|
||||
},
|
||||
};
|
||||
self.tasks_current.set(None);
|
||||
rets
|
||||
match rets {
|
||||
Ok(rets) => Ok((thread.status(), rets)),
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -446,9 +417,58 @@ impl<'fut> TaskScheduler<'fut> {
|
|||
Ok(task_ref)
|
||||
}
|
||||
|
||||
pub(crate) fn set_task_result_sender(&self, task_ref: TaskReference, sender: TaskResultSender) {
|
||||
self.tasks_result_senders
|
||||
/**
|
||||
Queues a new future to run on the task scheduler,
|
||||
inheriting the task id of the currently running task.
|
||||
*/
|
||||
pub(crate) fn queue_async_task_inherited(
|
||||
&self,
|
||||
thread: LuaThread<'_>,
|
||||
thread_args: Option<LuaMultiValue<'_>>,
|
||||
fut: impl Future<Output = TaskFutureRets<'fut>> + 'fut,
|
||||
) -> LuaResult<TaskReference> {
|
||||
let task_ref = self.create_task(TaskKind::Future, thread, thread_args, true)?;
|
||||
let futs = self
|
||||
.futures
|
||||
.try_lock()
|
||||
.expect("Tried to add future to queue during futures resumption");
|
||||
futs.push(Box::pin(async move {
|
||||
let result = fut.await;
|
||||
(Some(task_ref), result)
|
||||
}));
|
||||
Ok(task_ref)
|
||||
}
|
||||
|
||||
/**
|
||||
Waits for a task to complete.
|
||||
|
||||
Panics if the task is not currently in the scheduler.
|
||||
*/
|
||||
pub(crate) async fn wait_for_task_completion(
|
||||
&self,
|
||||
reference: TaskReference,
|
||||
) -> LuaResult<LuaMultiValue> {
|
||||
if !self.tasks.borrow().contains_key(&reference) {
|
||||
panic!("Task does not exist in scheduler")
|
||||
}
|
||||
let state = TaskWaiterState::new();
|
||||
self.tasks_waiter_states
|
||||
.borrow_mut()
|
||||
.insert(task_ref, sender);
|
||||
.insert(reference, Arc::clone(&state));
|
||||
TaskWaiterFuture::new(&state).await
|
||||
}
|
||||
|
||||
/**
|
||||
Wakes a task that has been completed and may have external code
|
||||
waiting on it using [`TaskScheduler::wait_for_task_completion`].
|
||||
*/
|
||||
pub(super) fn wake_completed_task(
|
||||
&self,
|
||||
reference: TaskReference,
|
||||
result: LuaResult<LuaMultiValue<'fut>>,
|
||||
) {
|
||||
if let Some(waiter_state) = self.tasks_waiter_states.borrow_mut().remove(&reference) {
|
||||
waiter_state.try_lock().unwrap().finalize(result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
use std::fmt;
|
||||
use std::{
|
||||
fmt,
|
||||
hash::{Hash, Hasher},
|
||||
};
|
||||
|
||||
use mlua::prelude::*;
|
||||
|
||||
|
@ -6,7 +9,7 @@ use super::task_kind::TaskKind;
|
|||
|
||||
/// A lightweight, copyable struct that represents a
|
||||
/// task in the scheduler and is accessible from Lua
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct TaskReference {
|
||||
kind: TaskKind,
|
||||
guid: usize,
|
||||
|
@ -32,4 +35,17 @@ impl fmt::Display for TaskReference {
|
|||
}
|
||||
}
|
||||
|
||||
impl Eq for TaskReference {}
|
||||
impl PartialEq for TaskReference {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.guid == other.guid
|
||||
}
|
||||
}
|
||||
|
||||
impl Hash for TaskReference {
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
self.guid.hash(state);
|
||||
}
|
||||
}
|
||||
|
||||
impl LuaUserData for TaskReference {}
|
||||
|
|
66
packages/lib/src/lua/task/task_waiter.rs
Normal file
66
packages/lib/src/lua/task/task_waiter.rs
Normal file
|
@ -0,0 +1,66 @@
|
|||
use std::{
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll, Waker},
|
||||
};
|
||||
|
||||
use tokio::sync::Mutex as AsyncMutex;
|
||||
|
||||
use mlua::prelude::*;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(super) struct TaskWaiterState<'fut> {
|
||||
rets: Option<LuaResult<LuaMultiValue<'fut>>>,
|
||||
waker: Option<Waker>,
|
||||
}
|
||||
|
||||
impl<'fut> TaskWaiterState<'fut> {
|
||||
pub fn new() -> Arc<AsyncMutex<Self>> {
|
||||
Arc::new(AsyncMutex::new(TaskWaiterState {
|
||||
rets: None,
|
||||
waker: None,
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn finalize(&mut self, rets: LuaResult<LuaMultiValue<'fut>>) {
|
||||
self.rets = Some(rets);
|
||||
if let Some(waker) = self.waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(super) struct TaskWaiterFuture<'fut> {
|
||||
state: Arc<AsyncMutex<TaskWaiterState<'fut>>>,
|
||||
}
|
||||
|
||||
impl<'fut> TaskWaiterFuture<'fut> {
|
||||
pub fn new(state: &Arc<AsyncMutex<TaskWaiterState<'fut>>>) -> Self {
|
||||
Self {
|
||||
state: Arc::clone(state),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'fut> Clone for TaskWaiterFuture<'fut> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
state: Arc::clone(&self.state),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'fut> Future for TaskWaiterFuture<'fut> {
|
||||
type Output = LuaResult<LuaMultiValue<'fut>>;
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let mut shared_state = self.state.try_lock().unwrap();
|
||||
if let Some(rets) = shared_state.rets.clone() {
|
||||
Poll::Ready(rets)
|
||||
} else {
|
||||
shared_state.waker = Some(cx.waker().clone());
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
|
@ -60,6 +60,9 @@ create_tests! {
|
|||
process_env: "process/env",
|
||||
process_exit: "process/exit",
|
||||
process_spawn: "process/spawn",
|
||||
require_async: "globals/require/tests/async",
|
||||
require_async_concurrent: "globals/require/tests/async_concurrent",
|
||||
require_async_sequential: "globals/require/tests/async_sequential",
|
||||
require_children: "globals/require/tests/children",
|
||||
require_invalid: "globals/require/tests/invalid",
|
||||
require_nested: "globals/require/tests/nested",
|
||||
|
|
9
tests/globals/require/tests/async.luau
Normal file
9
tests/globals/require/tests/async.luau
Normal file
|
@ -0,0 +1,9 @@
|
|||
local module = require("./modules/async")
|
||||
|
||||
assert(type(module) == "table", "Required module did not return a table")
|
||||
assert(module.Foo == "Bar", "Required module did not contain correct values")
|
||||
assert(module.Hello == "World", "Required module did not contain correct values")
|
||||
|
||||
module = require("modules/async")
|
||||
assert(module.Foo == "Bar", "Required module did not contain correct values")
|
||||
assert(module.Hello == "World", "Required module did not contain correct values")
|
22
tests/globals/require/tests/async_concurrent.luau
Normal file
22
tests/globals/require/tests/async_concurrent.luau
Normal file
|
@ -0,0 +1,22 @@
|
|||
local module1
|
||||
local module2
|
||||
|
||||
task.defer(function()
|
||||
module2 = require("./modules/async")
|
||||
end)
|
||||
|
||||
task.spawn(function()
|
||||
module1 = require("./modules/async")
|
||||
end)
|
||||
|
||||
task.wait(1)
|
||||
|
||||
assert(type(module1) == "table", "Required module1 did not return a table")
|
||||
assert(module1.Foo == "Bar", "Required module1 did not contain correct values")
|
||||
assert(module1.Hello == "World", "Required module1 did not contain correct values")
|
||||
|
||||
assert(type(module2) == "table", "Required module2 did not return a table")
|
||||
assert(module2.Foo == "Bar", "Required module2 did not contain correct values")
|
||||
assert(module2.Hello == "World", "Required module2 did not contain correct values")
|
||||
|
||||
assert(module1 == module2, "Required modules should point to the same return value")
|
14
tests/globals/require/tests/async_sequential.luau
Normal file
14
tests/globals/require/tests/async_sequential.luau
Normal file
|
@ -0,0 +1,14 @@
|
|||
local module1 = require("./modules/async")
|
||||
local module2 = require("./modules/async")
|
||||
|
||||
task.wait(1)
|
||||
|
||||
assert(type(module1) == "table", "Required module1 did not return a table")
|
||||
assert(module1.Foo == "Bar", "Required module1 did not contain correct values")
|
||||
assert(module1.Hello == "World", "Required module1 did not contain correct values")
|
||||
|
||||
assert(type(module2) == "table", "Required module2 did not return a table")
|
||||
assert(module2.Foo == "Bar", "Required module2 did not contain correct values")
|
||||
assert(module2.Hello == "World", "Required module2 did not contain correct values")
|
||||
|
||||
assert(module1 == module2, "Required modules should point to the same return value")
|
6
tests/globals/require/tests/modules/async.luau
Normal file
6
tests/globals/require/tests/modules/async.luau
Normal file
|
@ -0,0 +1,6 @@
|
|||
task.wait(0.25)
|
||||
|
||||
return {
|
||||
Foo = "Bar",
|
||||
Hello = "World",
|
||||
}
|
Loading…
Reference in a new issue