diff --git a/packages/lib/src/globals/task.rs b/packages/lib/src/globals/task.rs index 55b5376..fb48f39 100644 --- a/packages/lib/src/globals/task.rs +++ b/packages/lib/src/globals/task.rs @@ -21,7 +21,7 @@ pub fn create(lua: &'static Lua) -> LuaResult> { // Create task spawning functions that add tasks to the scheduler let task_cancel = lua.create_function(|lua, task: TaskReference| { let sched = lua.app_data_mut::<&TaskScheduler>().unwrap(); - sched.cancel_task(task)?; + sched.remove_task(task)?; Ok(()) })?; let task_defer = lua.create_function(|lua, (tof, args): (LuaValue, LuaMultiValue)| { diff --git a/packages/lib/src/lib.rs b/packages/lib/src/lib.rs index 20afaaf..d462c28 100644 --- a/packages/lib/src/lib.rs +++ b/packages/lib/src/lib.rs @@ -97,6 +97,7 @@ impl Lune { let coroutine: LuaTable = lua.globals().get("coroutine")?; lua.set_named_registry_value("co.thread", coroutine.get::<_, LuaFunction>("running")?)?; lua.set_named_registry_value("co.yield", coroutine.get::<_, LuaFunction>("yield")?)?; + lua.set_named_registry_value("co.close", coroutine.get::<_, LuaFunction>("close")?)?; let debug: LuaTable = lua.globals().raw_get("debug")?; lua.set_named_registry_value("dbg.info", debug.get::<_, LuaFunction>("info")?)?; // Add in wanted lune globals diff --git a/packages/lib/src/lua/task/scheduler.rs b/packages/lib/src/lua/task/scheduler.rs index 6a1cafe..ef2a558 100644 --- a/packages/lib/src/lua/task/scheduler.rs +++ b/packages/lib/src/lua/task/scheduler.rs @@ -10,7 +10,7 @@ use std::{ time::Duration, }; -use futures_util::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; +use futures_util::{future::LocalBoxFuture, stream::FuturesUnordered, Future, StreamExt}; use mlua::prelude::*; use tokio::{ @@ -21,8 +21,8 @@ use tokio::{ type TaskSchedulerQueue = Arc>>; type TaskFutureArgsOverride<'fut> = Option>>; -type TaskFutureResult<'fut> = (TaskReference, LuaResult>); -type TaskFuture<'fut> = BoxFuture<'fut, TaskFutureResult<'fut>>; +type TaskFutureReturns<'fut> = LuaResult>; +type TaskFuture<'fut> = LocalBoxFuture<'fut, (TaskReference, TaskFutureReturns<'fut>)>; /// An enum representing different kinds of tasks #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -112,13 +112,14 @@ pub enum TaskSchedulerResult { #[derive(Debug)] pub struct TaskScheduler<'fut> { lua: &'static Lua, - guid: AtomicUsize, tasks: Arc>>, futures: Arc>>>, task_queue_instant: TaskSchedulerQueue, task_queue_deferred: TaskSchedulerQueue, exit_code_set: AtomicBool, exit_code: Arc>, + guid: AtomicUsize, + guid_running_task: AtomicUsize, } impl<'fut> TaskScheduler<'fut> { @@ -128,13 +129,16 @@ impl<'fut> TaskScheduler<'fut> { pub fn new(lua: &'static Lua) -> LuaResult { Ok(Self { lua, - guid: AtomicUsize::new(0), tasks: Arc::new(Mutex::new(HashMap::new())), futures: Arc::new(AsyncMutex::new(FuturesUnordered::new())), task_queue_instant: Arc::new(Mutex::new(VecDeque::new())), task_queue_deferred: Arc::new(Mutex::new(VecDeque::new())), exit_code_set: AtomicBool::new(false), exit_code: Arc::new(Mutex::new(ExitCode::SUCCESS)), + // Global ids must start at 1, since 0 is a special + // value for guid_running_task that means "no task" + guid: AtomicUsize::new(1), + guid_running_task: AtomicUsize::new(0), }) } @@ -196,6 +200,7 @@ impl<'fut> TaskScheduler<'fut> { kind: TaskKind, thread_or_function: LuaValue<'_>, thread_args: Option>, + guid_to_reuse: Option, ) -> LuaResult { // Get or create a thread from the given argument let task_thread = match thread_or_function { @@ -213,23 +218,29 @@ impl<'fut> TaskScheduler<'fut> { let task_args_key: LuaRegistryKey = self.lua.create_registry_value(task_args_vec)?; let task_thread_key: LuaRegistryKey = self.lua.create_registry_value(task_thread)?; // Create the full task struct - let guid = self.guid.fetch_add(1, Ordering::Relaxed) + 1; let queued_at = Instant::now(); let task = Task { thread: task_thread_key, args: task_args_key, queued_at, }; - // Add it to the scheduler + // Create the task ref to use + let task_ref = if let Some(reusable_guid) = guid_to_reuse { + TaskReference::new(kind, reusable_guid) + } else { + let guid = self.guid.fetch_add(1, Ordering::Relaxed); + TaskReference::new(kind, guid) + }; + // Add the task to the scheduler { let mut tasks = self.tasks.lock().unwrap(); - tasks.insert(TaskReference::new(kind, guid), task); + tasks.insert(task_ref, task); } - Ok(TaskReference::new(kind, guid)) + Ok(task_ref) } /** - Schedules a new task to run on the task scheduler. + Queues a new task to run on the task scheduler. When we want to schedule a task to resume instantly after the currently running task we should pass `after_current_resume = true`. @@ -244,17 +255,18 @@ impl<'fut> TaskScheduler<'fut> { -- Here we have either yielded or finished the above task ``` */ - fn schedule( + fn queue_task( &self, kind: TaskKind, thread_or_function: LuaValue<'_>, thread_args: Option>, + guid_to_reuse: Option, after_current_resume: bool, ) -> LuaResult { if kind == TaskKind::Future { panic!("Tried to schedule future using normal task schedule method") } - let task_ref = self.create_task(kind, thread_or_function, thread_args)?; + let task_ref = self.create_task(kind, thread_or_function, thread_args, guid_to_reuse)?; match kind { TaskKind::Instant => { let mut queue = self.task_queue_instant.lock().unwrap(); @@ -278,6 +290,33 @@ impl<'fut> TaskScheduler<'fut> { Ok(task_ref) } + /** + Queues a new future to run on the task scheduler. + */ + fn queue_async( + &self, + thread_or_function: LuaValue<'_>, + thread_args: Option>, + guid_to_reuse: Option, + fut: impl Future> + 'fut, + ) -> LuaResult { + let task_ref = self.create_task( + TaskKind::Future, + thread_or_function, + thread_args, + guid_to_reuse, + )?; + let futs = self + .futures + .try_lock() + .expect("Failed to get lock on futures"); + futs.push(Box::pin(async move { + let result = fut.await; + (task_ref, result) + })); + Ok(task_ref) + } + /** Schedules a lua thread or function to resume ***first*** during this resumption point, ***skipping ahead*** of any other currently queued tasks. @@ -290,10 +329,11 @@ impl<'fut> TaskScheduler<'fut> { thread_or_function: LuaValue<'_>, thread_args: LuaMultiValue<'_>, ) -> LuaResult { - self.schedule( + self.queue_task( TaskKind::Instant, thread_or_function, Some(thread_args), + None, false, ) } @@ -310,10 +350,17 @@ impl<'fut> TaskScheduler<'fut> { thread_or_function: LuaValue<'_>, thread_args: LuaMultiValue<'_>, ) -> LuaResult { - self.schedule( + self.queue_task( TaskKind::Instant, thread_or_function, Some(thread_args), + // This should recycle the guid of the current task, + // since it will only be called to schedule resuming + // current thread after it gives resumption to another + match self.guid_running_task.load(Ordering::Relaxed) { + 0 => panic!("Tried to schedule with no task running"), + guid => Some(guid), + }, true, ) } @@ -330,10 +377,11 @@ impl<'fut> TaskScheduler<'fut> { thread_or_function: LuaValue<'_>, thread_args: LuaMultiValue<'_>, ) -> LuaResult { - self.schedule( + self.queue_task( TaskKind::Deferred, thread_or_function, Some(thread_args), + None, false, ) } @@ -351,16 +399,10 @@ impl<'fut> TaskScheduler<'fut> { thread_or_function: LuaValue<'_>, thread_args: LuaMultiValue<'_>, ) -> LuaResult { - let task_ref = self.create_task(TaskKind::Future, thread_or_function, Some(thread_args))?; - let futs = self - .futures - .try_lock() - .expect("Failed to get lock on futures"); - futs.push(Box::pin(async move { + self.queue_async(thread_or_function, Some(thread_args), None, async move { sleep(Duration::from_secs_f64(after_secs)).await; - (task_ref, Ok(None)) - })); - Ok(task_ref) + Ok(None) + }) } /** @@ -375,19 +417,37 @@ impl<'fut> TaskScheduler<'fut> { after_secs: f64, thread_or_function: LuaValue<'_>, ) -> LuaResult { - // TODO: Wait should inherit the guid of the current task, - // this will ensure that TaskReferences are identical and - // that any waits inside of spawned tasks will also cancel - let task_ref = self.create_task(TaskKind::Future, thread_or_function, None)?; - let futs = self - .futures - .try_lock() - .expect("Failed to get lock on futures"); - futs.push(Box::pin(async move { - sleep(Duration::from_secs_f64(after_secs)).await; - (task_ref, Ok(None)) - })); - Ok(task_ref) + self.queue_async( + thread_or_function, + None, + // Wait should recycle the guid of the current task, + // which ensures that the TaskReference is identical and + // that any waits inside of spawned tasks will also cancel + match self.guid_running_task.load(Ordering::Relaxed) { + 0 => panic!("Tried to schedule waiting task with no task running"), + guid => Some(guid), + }, + async move { + sleep(Duration::from_secs_f64(after_secs)).await; + Ok(None) + }, + ) + } + + /** + Schedules a lua thread or function + to be resumed after running a future. + + The given lua thread or function will be resumed + using the optional arguments returned by the future. + */ + #[allow(dead_code)] + pub fn schedule_async( + &self, + thread_or_function: LuaValue<'_>, + fut: impl Future> + 'fut, + ) -> LuaResult { + self.queue_async(thread_or_function, None, None, fut) } /** @@ -409,26 +469,36 @@ impl<'fut> TaskScheduler<'fut> { to a task that no longer exists in the scheduler, and calling this method with one of those references will return `false`. */ - pub fn cancel_task(&self, reference: TaskReference) -> LuaResult { + pub fn remove_task(&self, reference: TaskReference) -> LuaResult { /* Remove the task from the task list and the Lua registry This is all we need to do since resume_task will always ignore resumption of any task that no longer exists there - This does lead to having some amount of "junk" tasks and futures - built up in the queues but these will get cleaned up and not block + This does lead to having some amount of "junk" futures that will + build up in the queue but these will get cleaned up and not block the program from exiting since the scheduler only runs until there - are no tasks left in the task list, the queues do not matter there + are no tasks left in the task list, the futures do not matter there */ + let mut found = false; let mut tasks = self.tasks.lock().unwrap(); - if let Some(task) = tasks.remove(&reference) { - self.lua.remove_registry_value(task.thread)?; - self.lua.remove_registry_value(task.args)?; - Ok(true) - } else { - Ok(false) + // Unfortunately we have to loop through to find which task + // references to remove instead of removing directly since + // tasks can switch kinds between instant, deferred, future + let tasks_to_remove: Vec<_> = tasks + .keys() + .filter(|task_ref| task_ref.guid == reference.guid) + .copied() + .collect(); + for task_ref in tasks_to_remove { + if let Some(task) = tasks.remove(&task_ref) { + self.lua.remove_registry_value(task.thread)?; + self.lua.remove_registry_value(task.args)?; + found = true; + } } + Ok(found) } /** @@ -444,6 +514,8 @@ impl<'fut> TaskScheduler<'fut> { reference: TaskReference, override_args: Option>, ) -> LuaResult<()> { + self.guid_running_task + .store(reference.guid, Ordering::Relaxed); let task = { let mut tasks = self.tasks.lock().unwrap(); match tasks.remove(&reference) { @@ -452,12 +524,14 @@ impl<'fut> TaskScheduler<'fut> { } }; let thread: LuaThread = self.lua.registry_value(&task.thread)?; - let args = override_args.or_else(|| { + let args_vec_opt = override_args.or_else(|| { self.lua .registry_value::>>(&task.args) .expect("Failed to get stored args for task") }); - if let Some(args) = args { + self.lua.remove_registry_value(task.thread)?; + self.lua.remove_registry_value(task.args)?; + if let Some(args) = args_vec_opt { thread.resume::<_, LuaMultiValue>(LuaMultiValue::from_vec(args))?; } else { /* @@ -473,8 +547,7 @@ impl<'fut> TaskScheduler<'fut> { let elapsed = task.queued_at.elapsed().as_secs_f64(); thread.resume::<_, LuaMultiValue>(elapsed)?; } - self.lua.remove_registry_value(task.thread)?; - self.lua.remove_registry_value(task.args)?; + self.guid_running_task.store(0, Ordering::Relaxed); Ok(()) } @@ -566,7 +639,7 @@ impl<'fut> TaskScheduler<'fut> { (task, Err(fut_err)) => { // Future errored, don't resume its associated task // and make sure to cancel / remove it completely - let error_prefer_cancel = match self.cancel_task(task) { + let error_prefer_cancel = match self.remove_task(task) { Err(cancel_err) => cancel_err, Ok(_) => fut_err, }; diff --git a/tests/task/cancel.luau b/tests/task/cancel.luau index f466434..dfbd0ce 100644 --- a/tests/task/cancel.luau +++ b/tests/task/cancel.luau @@ -21,13 +21,10 @@ assert(not flag2, "Cancel should handle delayed threads") local flag3: number = 1 local thread3 = task.spawn(function() - print("1") task.wait(0.1) flag3 = 2 - print("2") task.wait(0.2) flag3 = 3 - print("3") end) task.wait(0.2) task.cancel(thread3)