Properly handle tasks that switch context in new scheduler

This commit is contained in:
Filip Tibell 2023-02-13 23:36:30 +01:00
parent 879d6723a3
commit b1b69c7d94
No known key found for this signature in database
4 changed files with 127 additions and 56 deletions

View file

@ -21,7 +21,7 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable<'static>> {
// Create task spawning functions that add tasks to the scheduler // Create task spawning functions that add tasks to the scheduler
let task_cancel = lua.create_function(|lua, task: TaskReference| { let task_cancel = lua.create_function(|lua, task: TaskReference| {
let sched = lua.app_data_mut::<&TaskScheduler>().unwrap(); let sched = lua.app_data_mut::<&TaskScheduler>().unwrap();
sched.cancel_task(task)?; sched.remove_task(task)?;
Ok(()) Ok(())
})?; })?;
let task_defer = lua.create_function(|lua, (tof, args): (LuaValue, LuaMultiValue)| { let task_defer = lua.create_function(|lua, (tof, args): (LuaValue, LuaMultiValue)| {

View file

@ -97,6 +97,7 @@ impl Lune {
let coroutine: LuaTable = lua.globals().get("coroutine")?; let coroutine: LuaTable = lua.globals().get("coroutine")?;
lua.set_named_registry_value("co.thread", coroutine.get::<_, LuaFunction>("running")?)?; lua.set_named_registry_value("co.thread", coroutine.get::<_, LuaFunction>("running")?)?;
lua.set_named_registry_value("co.yield", coroutine.get::<_, LuaFunction>("yield")?)?; lua.set_named_registry_value("co.yield", coroutine.get::<_, LuaFunction>("yield")?)?;
lua.set_named_registry_value("co.close", coroutine.get::<_, LuaFunction>("close")?)?;
let debug: LuaTable = lua.globals().raw_get("debug")?; let debug: LuaTable = lua.globals().raw_get("debug")?;
lua.set_named_registry_value("dbg.info", debug.get::<_, LuaFunction>("info")?)?; lua.set_named_registry_value("dbg.info", debug.get::<_, LuaFunction>("info")?)?;
// Add in wanted lune globals // Add in wanted lune globals

View file

@ -10,7 +10,7 @@ use std::{
time::Duration, time::Duration,
}; };
use futures_util::{future::BoxFuture, stream::FuturesUnordered, StreamExt}; use futures_util::{future::LocalBoxFuture, stream::FuturesUnordered, Future, StreamExt};
use mlua::prelude::*; use mlua::prelude::*;
use tokio::{ use tokio::{
@ -21,8 +21,8 @@ use tokio::{
type TaskSchedulerQueue = Arc<Mutex<VecDeque<TaskReference>>>; type TaskSchedulerQueue = Arc<Mutex<VecDeque<TaskReference>>>;
type TaskFutureArgsOverride<'fut> = Option<Vec<LuaValue<'fut>>>; type TaskFutureArgsOverride<'fut> = Option<Vec<LuaValue<'fut>>>;
type TaskFutureResult<'fut> = (TaskReference, LuaResult<TaskFutureArgsOverride<'fut>>); type TaskFutureReturns<'fut> = LuaResult<TaskFutureArgsOverride<'fut>>;
type TaskFuture<'fut> = BoxFuture<'fut, TaskFutureResult<'fut>>; type TaskFuture<'fut> = LocalBoxFuture<'fut, (TaskReference, TaskFutureReturns<'fut>)>;
/// An enum representing different kinds of tasks /// An enum representing different kinds of tasks
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
@ -112,13 +112,14 @@ pub enum TaskSchedulerResult {
#[derive(Debug)] #[derive(Debug)]
pub struct TaskScheduler<'fut> { pub struct TaskScheduler<'fut> {
lua: &'static Lua, lua: &'static Lua,
guid: AtomicUsize,
tasks: Arc<Mutex<HashMap<TaskReference, Task>>>, tasks: Arc<Mutex<HashMap<TaskReference, Task>>>,
futures: Arc<AsyncMutex<FuturesUnordered<TaskFuture<'fut>>>>, futures: Arc<AsyncMutex<FuturesUnordered<TaskFuture<'fut>>>>,
task_queue_instant: TaskSchedulerQueue, task_queue_instant: TaskSchedulerQueue,
task_queue_deferred: TaskSchedulerQueue, task_queue_deferred: TaskSchedulerQueue,
exit_code_set: AtomicBool, exit_code_set: AtomicBool,
exit_code: Arc<Mutex<ExitCode>>, exit_code: Arc<Mutex<ExitCode>>,
guid: AtomicUsize,
guid_running_task: AtomicUsize,
} }
impl<'fut> TaskScheduler<'fut> { impl<'fut> TaskScheduler<'fut> {
@ -128,13 +129,16 @@ impl<'fut> TaskScheduler<'fut> {
pub fn new(lua: &'static Lua) -> LuaResult<Self> { pub fn new(lua: &'static Lua) -> LuaResult<Self> {
Ok(Self { Ok(Self {
lua, lua,
guid: AtomicUsize::new(0),
tasks: Arc::new(Mutex::new(HashMap::new())), tasks: Arc::new(Mutex::new(HashMap::new())),
futures: Arc::new(AsyncMutex::new(FuturesUnordered::new())), futures: Arc::new(AsyncMutex::new(FuturesUnordered::new())),
task_queue_instant: Arc::new(Mutex::new(VecDeque::new())), task_queue_instant: Arc::new(Mutex::new(VecDeque::new())),
task_queue_deferred: Arc::new(Mutex::new(VecDeque::new())), task_queue_deferred: Arc::new(Mutex::new(VecDeque::new())),
exit_code_set: AtomicBool::new(false), exit_code_set: AtomicBool::new(false),
exit_code: Arc::new(Mutex::new(ExitCode::SUCCESS)), exit_code: Arc::new(Mutex::new(ExitCode::SUCCESS)),
// Global ids must start at 1, since 0 is a special
// value for guid_running_task that means "no task"
guid: AtomicUsize::new(1),
guid_running_task: AtomicUsize::new(0),
}) })
} }
@ -196,6 +200,7 @@ impl<'fut> TaskScheduler<'fut> {
kind: TaskKind, kind: TaskKind,
thread_or_function: LuaValue<'_>, thread_or_function: LuaValue<'_>,
thread_args: Option<LuaMultiValue<'_>>, thread_args: Option<LuaMultiValue<'_>>,
guid_to_reuse: Option<usize>,
) -> LuaResult<TaskReference> { ) -> LuaResult<TaskReference> {
// Get or create a thread from the given argument // Get or create a thread from the given argument
let task_thread = match thread_or_function { let task_thread = match thread_or_function {
@ -213,23 +218,29 @@ impl<'fut> TaskScheduler<'fut> {
let task_args_key: LuaRegistryKey = self.lua.create_registry_value(task_args_vec)?; let task_args_key: LuaRegistryKey = self.lua.create_registry_value(task_args_vec)?;
let task_thread_key: LuaRegistryKey = self.lua.create_registry_value(task_thread)?; let task_thread_key: LuaRegistryKey = self.lua.create_registry_value(task_thread)?;
// Create the full task struct // Create the full task struct
let guid = self.guid.fetch_add(1, Ordering::Relaxed) + 1;
let queued_at = Instant::now(); let queued_at = Instant::now();
let task = Task { let task = Task {
thread: task_thread_key, thread: task_thread_key,
args: task_args_key, args: task_args_key,
queued_at, queued_at,
}; };
// Add it to the scheduler // Create the task ref to use
let task_ref = if let Some(reusable_guid) = guid_to_reuse {
TaskReference::new(kind, reusable_guid)
} else {
let guid = self.guid.fetch_add(1, Ordering::Relaxed);
TaskReference::new(kind, guid)
};
// Add the task to the scheduler
{ {
let mut tasks = self.tasks.lock().unwrap(); let mut tasks = self.tasks.lock().unwrap();
tasks.insert(TaskReference::new(kind, guid), task); tasks.insert(task_ref, task);
} }
Ok(TaskReference::new(kind, guid)) Ok(task_ref)
} }
/** /**
Schedules a new task to run on the task scheduler. Queues a new task to run on the task scheduler.
When we want to schedule a task to resume instantly after the When we want to schedule a task to resume instantly after the
currently running task we should pass `after_current_resume = true`. currently running task we should pass `after_current_resume = true`.
@ -244,17 +255,18 @@ impl<'fut> TaskScheduler<'fut> {
-- Here we have either yielded or finished the above task -- Here we have either yielded or finished the above task
``` ```
*/ */
fn schedule( fn queue_task(
&self, &self,
kind: TaskKind, kind: TaskKind,
thread_or_function: LuaValue<'_>, thread_or_function: LuaValue<'_>,
thread_args: Option<LuaMultiValue<'_>>, thread_args: Option<LuaMultiValue<'_>>,
guid_to_reuse: Option<usize>,
after_current_resume: bool, after_current_resume: bool,
) -> LuaResult<TaskReference> { ) -> LuaResult<TaskReference> {
if kind == TaskKind::Future { if kind == TaskKind::Future {
panic!("Tried to schedule future using normal task schedule method") panic!("Tried to schedule future using normal task schedule method")
} }
let task_ref = self.create_task(kind, thread_or_function, thread_args)?; let task_ref = self.create_task(kind, thread_or_function, thread_args, guid_to_reuse)?;
match kind { match kind {
TaskKind::Instant => { TaskKind::Instant => {
let mut queue = self.task_queue_instant.lock().unwrap(); let mut queue = self.task_queue_instant.lock().unwrap();
@ -278,6 +290,33 @@ impl<'fut> TaskScheduler<'fut> {
Ok(task_ref) Ok(task_ref)
} }
/**
Queues a new future to run on the task scheduler.
*/
fn queue_async(
&self,
thread_or_function: LuaValue<'_>,
thread_args: Option<LuaMultiValue<'_>>,
guid_to_reuse: Option<usize>,
fut: impl Future<Output = TaskFutureReturns<'fut>> + 'fut,
) -> LuaResult<TaskReference> {
let task_ref = self.create_task(
TaskKind::Future,
thread_or_function,
thread_args,
guid_to_reuse,
)?;
let futs = self
.futures
.try_lock()
.expect("Failed to get lock on futures");
futs.push(Box::pin(async move {
let result = fut.await;
(task_ref, result)
}));
Ok(task_ref)
}
/** /**
Schedules a lua thread or function to resume ***first*** during this Schedules a lua thread or function to resume ***first*** during this
resumption point, ***skipping ahead*** of any other currently queued tasks. resumption point, ***skipping ahead*** of any other currently queued tasks.
@ -290,10 +329,11 @@ impl<'fut> TaskScheduler<'fut> {
thread_or_function: LuaValue<'_>, thread_or_function: LuaValue<'_>,
thread_args: LuaMultiValue<'_>, thread_args: LuaMultiValue<'_>,
) -> LuaResult<TaskReference> { ) -> LuaResult<TaskReference> {
self.schedule( self.queue_task(
TaskKind::Instant, TaskKind::Instant,
thread_or_function, thread_or_function,
Some(thread_args), Some(thread_args),
None,
false, false,
) )
} }
@ -310,10 +350,17 @@ impl<'fut> TaskScheduler<'fut> {
thread_or_function: LuaValue<'_>, thread_or_function: LuaValue<'_>,
thread_args: LuaMultiValue<'_>, thread_args: LuaMultiValue<'_>,
) -> LuaResult<TaskReference> { ) -> LuaResult<TaskReference> {
self.schedule( self.queue_task(
TaskKind::Instant, TaskKind::Instant,
thread_or_function, thread_or_function,
Some(thread_args), Some(thread_args),
// This should recycle the guid of the current task,
// since it will only be called to schedule resuming
// current thread after it gives resumption to another
match self.guid_running_task.load(Ordering::Relaxed) {
0 => panic!("Tried to schedule with no task running"),
guid => Some(guid),
},
true, true,
) )
} }
@ -330,10 +377,11 @@ impl<'fut> TaskScheduler<'fut> {
thread_or_function: LuaValue<'_>, thread_or_function: LuaValue<'_>,
thread_args: LuaMultiValue<'_>, thread_args: LuaMultiValue<'_>,
) -> LuaResult<TaskReference> { ) -> LuaResult<TaskReference> {
self.schedule( self.queue_task(
TaskKind::Deferred, TaskKind::Deferred,
thread_or_function, thread_or_function,
Some(thread_args), Some(thread_args),
None,
false, false,
) )
} }
@ -351,16 +399,10 @@ impl<'fut> TaskScheduler<'fut> {
thread_or_function: LuaValue<'_>, thread_or_function: LuaValue<'_>,
thread_args: LuaMultiValue<'_>, thread_args: LuaMultiValue<'_>,
) -> LuaResult<TaskReference> { ) -> LuaResult<TaskReference> {
let task_ref = self.create_task(TaskKind::Future, thread_or_function, Some(thread_args))?; self.queue_async(thread_or_function, Some(thread_args), None, async move {
let futs = self
.futures
.try_lock()
.expect("Failed to get lock on futures");
futs.push(Box::pin(async move {
sleep(Duration::from_secs_f64(after_secs)).await; sleep(Duration::from_secs_f64(after_secs)).await;
(task_ref, Ok(None)) Ok(None)
})); })
Ok(task_ref)
} }
/** /**
@ -375,19 +417,37 @@ impl<'fut> TaskScheduler<'fut> {
after_secs: f64, after_secs: f64,
thread_or_function: LuaValue<'_>, thread_or_function: LuaValue<'_>,
) -> LuaResult<TaskReference> { ) -> LuaResult<TaskReference> {
// TODO: Wait should inherit the guid of the current task, self.queue_async(
// this will ensure that TaskReferences are identical and thread_or_function,
// that any waits inside of spawned tasks will also cancel None,
let task_ref = self.create_task(TaskKind::Future, thread_or_function, None)?; // Wait should recycle the guid of the current task,
let futs = self // which ensures that the TaskReference is identical and
.futures // that any waits inside of spawned tasks will also cancel
.try_lock() match self.guid_running_task.load(Ordering::Relaxed) {
.expect("Failed to get lock on futures"); 0 => panic!("Tried to schedule waiting task with no task running"),
futs.push(Box::pin(async move { guid => Some(guid),
sleep(Duration::from_secs_f64(after_secs)).await; },
(task_ref, Ok(None)) async move {
})); sleep(Duration::from_secs_f64(after_secs)).await;
Ok(task_ref) Ok(None)
},
)
}
/**
Schedules a lua thread or function
to be resumed after running a future.
The given lua thread or function will be resumed
using the optional arguments returned by the future.
*/
#[allow(dead_code)]
pub fn schedule_async(
&self,
thread_or_function: LuaValue<'_>,
fut: impl Future<Output = TaskFutureReturns<'fut>> + 'fut,
) -> LuaResult<TaskReference> {
self.queue_async(thread_or_function, None, None, fut)
} }
/** /**
@ -409,26 +469,36 @@ impl<'fut> TaskScheduler<'fut> {
to a task that no longer exists in the scheduler, and calling to a task that no longer exists in the scheduler, and calling
this method with one of those references will return `false`. this method with one of those references will return `false`.
*/ */
pub fn cancel_task(&self, reference: TaskReference) -> LuaResult<bool> { pub fn remove_task(&self, reference: TaskReference) -> LuaResult<bool> {
/* /*
Remove the task from the task list and the Lua registry Remove the task from the task list and the Lua registry
This is all we need to do since resume_task will always This is all we need to do since resume_task will always
ignore resumption of any task that no longer exists there ignore resumption of any task that no longer exists there
This does lead to having some amount of "junk" tasks and futures This does lead to having some amount of "junk" futures that will
built up in the queues but these will get cleaned up and not block build up in the queue but these will get cleaned up and not block
the program from exiting since the scheduler only runs until there the program from exiting since the scheduler only runs until there
are no tasks left in the task list, the queues do not matter there are no tasks left in the task list, the futures do not matter there
*/ */
let mut found = false;
let mut tasks = self.tasks.lock().unwrap(); let mut tasks = self.tasks.lock().unwrap();
if let Some(task) = tasks.remove(&reference) { // Unfortunately we have to loop through to find which task
self.lua.remove_registry_value(task.thread)?; // references to remove instead of removing directly since
self.lua.remove_registry_value(task.args)?; // tasks can switch kinds between instant, deferred, future
Ok(true) let tasks_to_remove: Vec<_> = tasks
} else { .keys()
Ok(false) .filter(|task_ref| task_ref.guid == reference.guid)
.copied()
.collect();
for task_ref in tasks_to_remove {
if let Some(task) = tasks.remove(&task_ref) {
self.lua.remove_registry_value(task.thread)?;
self.lua.remove_registry_value(task.args)?;
found = true;
}
} }
Ok(found)
} }
/** /**
@ -444,6 +514,8 @@ impl<'fut> TaskScheduler<'fut> {
reference: TaskReference, reference: TaskReference,
override_args: Option<Vec<LuaValue>>, override_args: Option<Vec<LuaValue>>,
) -> LuaResult<()> { ) -> LuaResult<()> {
self.guid_running_task
.store(reference.guid, Ordering::Relaxed);
let task = { let task = {
let mut tasks = self.tasks.lock().unwrap(); let mut tasks = self.tasks.lock().unwrap();
match tasks.remove(&reference) { match tasks.remove(&reference) {
@ -452,12 +524,14 @@ impl<'fut> TaskScheduler<'fut> {
} }
}; };
let thread: LuaThread = self.lua.registry_value(&task.thread)?; let thread: LuaThread = self.lua.registry_value(&task.thread)?;
let args = override_args.or_else(|| { let args_vec_opt = override_args.or_else(|| {
self.lua self.lua
.registry_value::<Option<Vec<LuaValue>>>(&task.args) .registry_value::<Option<Vec<LuaValue>>>(&task.args)
.expect("Failed to get stored args for task") .expect("Failed to get stored args for task")
}); });
if let Some(args) = args { self.lua.remove_registry_value(task.thread)?;
self.lua.remove_registry_value(task.args)?;
if let Some(args) = args_vec_opt {
thread.resume::<_, LuaMultiValue>(LuaMultiValue::from_vec(args))?; thread.resume::<_, LuaMultiValue>(LuaMultiValue::from_vec(args))?;
} else { } else {
/* /*
@ -473,8 +547,7 @@ impl<'fut> TaskScheduler<'fut> {
let elapsed = task.queued_at.elapsed().as_secs_f64(); let elapsed = task.queued_at.elapsed().as_secs_f64();
thread.resume::<_, LuaMultiValue>(elapsed)?; thread.resume::<_, LuaMultiValue>(elapsed)?;
} }
self.lua.remove_registry_value(task.thread)?; self.guid_running_task.store(0, Ordering::Relaxed);
self.lua.remove_registry_value(task.args)?;
Ok(()) Ok(())
} }
@ -566,7 +639,7 @@ impl<'fut> TaskScheduler<'fut> {
(task, Err(fut_err)) => { (task, Err(fut_err)) => {
// Future errored, don't resume its associated task // Future errored, don't resume its associated task
// and make sure to cancel / remove it completely // and make sure to cancel / remove it completely
let error_prefer_cancel = match self.cancel_task(task) { let error_prefer_cancel = match self.remove_task(task) {
Err(cancel_err) => cancel_err, Err(cancel_err) => cancel_err,
Ok(_) => fut_err, Ok(_) => fut_err,
}; };

View file

@ -21,13 +21,10 @@ assert(not flag2, "Cancel should handle delayed threads")
local flag3: number = 1 local flag3: number = 1
local thread3 = task.spawn(function() local thread3 = task.spawn(function()
print("1")
task.wait(0.1) task.wait(0.1)
flag3 = 2 flag3 = 2
print("2")
task.wait(0.2) task.wait(0.2)
flag3 = 3 flag3 = 3
print("3")
end) end)
task.wait(0.2) task.wait(0.2)
task.cancel(thread3) task.cancel(thread3)