Properly handle tasks that switch context in new scheduler

This commit is contained in:
Filip Tibell 2023-02-13 23:36:30 +01:00
parent 879d6723a3
commit b1b69c7d94
No known key found for this signature in database
4 changed files with 127 additions and 56 deletions

View file

@ -21,7 +21,7 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable<'static>> {
// Create task spawning functions that add tasks to the scheduler
let task_cancel = lua.create_function(|lua, task: TaskReference| {
let sched = lua.app_data_mut::<&TaskScheduler>().unwrap();
sched.cancel_task(task)?;
sched.remove_task(task)?;
Ok(())
})?;
let task_defer = lua.create_function(|lua, (tof, args): (LuaValue, LuaMultiValue)| {

View file

@ -97,6 +97,7 @@ impl Lune {
let coroutine: LuaTable = lua.globals().get("coroutine")?;
lua.set_named_registry_value("co.thread", coroutine.get::<_, LuaFunction>("running")?)?;
lua.set_named_registry_value("co.yield", coroutine.get::<_, LuaFunction>("yield")?)?;
lua.set_named_registry_value("co.close", coroutine.get::<_, LuaFunction>("close")?)?;
let debug: LuaTable = lua.globals().raw_get("debug")?;
lua.set_named_registry_value("dbg.info", debug.get::<_, LuaFunction>("info")?)?;
// Add in wanted lune globals

View file

@ -10,7 +10,7 @@ use std::{
time::Duration,
};
use futures_util::{future::BoxFuture, stream::FuturesUnordered, StreamExt};
use futures_util::{future::LocalBoxFuture, stream::FuturesUnordered, Future, StreamExt};
use mlua::prelude::*;
use tokio::{
@ -21,8 +21,8 @@ use tokio::{
type TaskSchedulerQueue = Arc<Mutex<VecDeque<TaskReference>>>;
type TaskFutureArgsOverride<'fut> = Option<Vec<LuaValue<'fut>>>;
type TaskFutureResult<'fut> = (TaskReference, LuaResult<TaskFutureArgsOverride<'fut>>);
type TaskFuture<'fut> = BoxFuture<'fut, TaskFutureResult<'fut>>;
type TaskFutureReturns<'fut> = LuaResult<TaskFutureArgsOverride<'fut>>;
type TaskFuture<'fut> = LocalBoxFuture<'fut, (TaskReference, TaskFutureReturns<'fut>)>;
/// An enum representing different kinds of tasks
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
@ -112,13 +112,14 @@ pub enum TaskSchedulerResult {
#[derive(Debug)]
pub struct TaskScheduler<'fut> {
lua: &'static Lua,
guid: AtomicUsize,
tasks: Arc<Mutex<HashMap<TaskReference, Task>>>,
futures: Arc<AsyncMutex<FuturesUnordered<TaskFuture<'fut>>>>,
task_queue_instant: TaskSchedulerQueue,
task_queue_deferred: TaskSchedulerQueue,
exit_code_set: AtomicBool,
exit_code: Arc<Mutex<ExitCode>>,
guid: AtomicUsize,
guid_running_task: AtomicUsize,
}
impl<'fut> TaskScheduler<'fut> {
@ -128,13 +129,16 @@ impl<'fut> TaskScheduler<'fut> {
pub fn new(lua: &'static Lua) -> LuaResult<Self> {
Ok(Self {
lua,
guid: AtomicUsize::new(0),
tasks: Arc::new(Mutex::new(HashMap::new())),
futures: Arc::new(AsyncMutex::new(FuturesUnordered::new())),
task_queue_instant: Arc::new(Mutex::new(VecDeque::new())),
task_queue_deferred: Arc::new(Mutex::new(VecDeque::new())),
exit_code_set: AtomicBool::new(false),
exit_code: Arc::new(Mutex::new(ExitCode::SUCCESS)),
// Global ids must start at 1, since 0 is a special
// value for guid_running_task that means "no task"
guid: AtomicUsize::new(1),
guid_running_task: AtomicUsize::new(0),
})
}
@ -196,6 +200,7 @@ impl<'fut> TaskScheduler<'fut> {
kind: TaskKind,
thread_or_function: LuaValue<'_>,
thread_args: Option<LuaMultiValue<'_>>,
guid_to_reuse: Option<usize>,
) -> LuaResult<TaskReference> {
// Get or create a thread from the given argument
let task_thread = match thread_or_function {
@ -213,23 +218,29 @@ impl<'fut> TaskScheduler<'fut> {
let task_args_key: LuaRegistryKey = self.lua.create_registry_value(task_args_vec)?;
let task_thread_key: LuaRegistryKey = self.lua.create_registry_value(task_thread)?;
// Create the full task struct
let guid = self.guid.fetch_add(1, Ordering::Relaxed) + 1;
let queued_at = Instant::now();
let task = Task {
thread: task_thread_key,
args: task_args_key,
queued_at,
};
// Add it to the scheduler
// Create the task ref to use
let task_ref = if let Some(reusable_guid) = guid_to_reuse {
TaskReference::new(kind, reusable_guid)
} else {
let guid = self.guid.fetch_add(1, Ordering::Relaxed);
TaskReference::new(kind, guid)
};
// Add the task to the scheduler
{
let mut tasks = self.tasks.lock().unwrap();
tasks.insert(TaskReference::new(kind, guid), task);
tasks.insert(task_ref, task);
}
Ok(TaskReference::new(kind, guid))
Ok(task_ref)
}
/**
Schedules a new task to run on the task scheduler.
Queues a new task to run on the task scheduler.
When we want to schedule a task to resume instantly after the
currently running task we should pass `after_current_resume = true`.
@ -244,17 +255,18 @@ impl<'fut> TaskScheduler<'fut> {
-- Here we have either yielded or finished the above task
```
*/
fn schedule(
fn queue_task(
&self,
kind: TaskKind,
thread_or_function: LuaValue<'_>,
thread_args: Option<LuaMultiValue<'_>>,
guid_to_reuse: Option<usize>,
after_current_resume: bool,
) -> LuaResult<TaskReference> {
if kind == TaskKind::Future {
panic!("Tried to schedule future using normal task schedule method")
}
let task_ref = self.create_task(kind, thread_or_function, thread_args)?;
let task_ref = self.create_task(kind, thread_or_function, thread_args, guid_to_reuse)?;
match kind {
TaskKind::Instant => {
let mut queue = self.task_queue_instant.lock().unwrap();
@ -278,6 +290,33 @@ impl<'fut> TaskScheduler<'fut> {
Ok(task_ref)
}
/**
Queues a new future to run on the task scheduler.
*/
fn queue_async(
&self,
thread_or_function: LuaValue<'_>,
thread_args: Option<LuaMultiValue<'_>>,
guid_to_reuse: Option<usize>,
fut: impl Future<Output = TaskFutureReturns<'fut>> + 'fut,
) -> LuaResult<TaskReference> {
let task_ref = self.create_task(
TaskKind::Future,
thread_or_function,
thread_args,
guid_to_reuse,
)?;
let futs = self
.futures
.try_lock()
.expect("Failed to get lock on futures");
futs.push(Box::pin(async move {
let result = fut.await;
(task_ref, result)
}));
Ok(task_ref)
}
/**
Schedules a lua thread or function to resume ***first*** during this
resumption point, ***skipping ahead*** of any other currently queued tasks.
@ -290,10 +329,11 @@ impl<'fut> TaskScheduler<'fut> {
thread_or_function: LuaValue<'_>,
thread_args: LuaMultiValue<'_>,
) -> LuaResult<TaskReference> {
self.schedule(
self.queue_task(
TaskKind::Instant,
thread_or_function,
Some(thread_args),
None,
false,
)
}
@ -310,10 +350,17 @@ impl<'fut> TaskScheduler<'fut> {
thread_or_function: LuaValue<'_>,
thread_args: LuaMultiValue<'_>,
) -> LuaResult<TaskReference> {
self.schedule(
self.queue_task(
TaskKind::Instant,
thread_or_function,
Some(thread_args),
// This should recycle the guid of the current task,
// since it will only be called to schedule resuming
// current thread after it gives resumption to another
match self.guid_running_task.load(Ordering::Relaxed) {
0 => panic!("Tried to schedule with no task running"),
guid => Some(guid),
},
true,
)
}
@ -330,10 +377,11 @@ impl<'fut> TaskScheduler<'fut> {
thread_or_function: LuaValue<'_>,
thread_args: LuaMultiValue<'_>,
) -> LuaResult<TaskReference> {
self.schedule(
self.queue_task(
TaskKind::Deferred,
thread_or_function,
Some(thread_args),
None,
false,
)
}
@ -351,16 +399,10 @@ impl<'fut> TaskScheduler<'fut> {
thread_or_function: LuaValue<'_>,
thread_args: LuaMultiValue<'_>,
) -> LuaResult<TaskReference> {
let task_ref = self.create_task(TaskKind::Future, thread_or_function, Some(thread_args))?;
let futs = self
.futures
.try_lock()
.expect("Failed to get lock on futures");
futs.push(Box::pin(async move {
self.queue_async(thread_or_function, Some(thread_args), None, async move {
sleep(Duration::from_secs_f64(after_secs)).await;
(task_ref, Ok(None))
}));
Ok(task_ref)
Ok(None)
})
}
/**
@ -375,19 +417,37 @@ impl<'fut> TaskScheduler<'fut> {
after_secs: f64,
thread_or_function: LuaValue<'_>,
) -> LuaResult<TaskReference> {
// TODO: Wait should inherit the guid of the current task,
// this will ensure that TaskReferences are identical and
// that any waits inside of spawned tasks will also cancel
let task_ref = self.create_task(TaskKind::Future, thread_or_function, None)?;
let futs = self
.futures
.try_lock()
.expect("Failed to get lock on futures");
futs.push(Box::pin(async move {
sleep(Duration::from_secs_f64(after_secs)).await;
(task_ref, Ok(None))
}));
Ok(task_ref)
self.queue_async(
thread_or_function,
None,
// Wait should recycle the guid of the current task,
// which ensures that the TaskReference is identical and
// that any waits inside of spawned tasks will also cancel
match self.guid_running_task.load(Ordering::Relaxed) {
0 => panic!("Tried to schedule waiting task with no task running"),
guid => Some(guid),
},
async move {
sleep(Duration::from_secs_f64(after_secs)).await;
Ok(None)
},
)
}
/**
Schedules a lua thread or function
to be resumed after running a future.
The given lua thread or function will be resumed
using the optional arguments returned by the future.
*/
#[allow(dead_code)]
pub fn schedule_async(
&self,
thread_or_function: LuaValue<'_>,
fut: impl Future<Output = TaskFutureReturns<'fut>> + 'fut,
) -> LuaResult<TaskReference> {
self.queue_async(thread_or_function, None, None, fut)
}
/**
@ -409,26 +469,36 @@ impl<'fut> TaskScheduler<'fut> {
to a task that no longer exists in the scheduler, and calling
this method with one of those references will return `false`.
*/
pub fn cancel_task(&self, reference: TaskReference) -> LuaResult<bool> {
pub fn remove_task(&self, reference: TaskReference) -> LuaResult<bool> {
/*
Remove the task from the task list and the Lua registry
This is all we need to do since resume_task will always
ignore resumption of any task that no longer exists there
This does lead to having some amount of "junk" tasks and futures
built up in the queues but these will get cleaned up and not block
This does lead to having some amount of "junk" futures that will
build up in the queue but these will get cleaned up and not block
the program from exiting since the scheduler only runs until there
are no tasks left in the task list, the queues do not matter there
are no tasks left in the task list, the futures do not matter there
*/
let mut found = false;
let mut tasks = self.tasks.lock().unwrap();
if let Some(task) = tasks.remove(&reference) {
self.lua.remove_registry_value(task.thread)?;
self.lua.remove_registry_value(task.args)?;
Ok(true)
} else {
Ok(false)
// Unfortunately we have to loop through to find which task
// references to remove instead of removing directly since
// tasks can switch kinds between instant, deferred, future
let tasks_to_remove: Vec<_> = tasks
.keys()
.filter(|task_ref| task_ref.guid == reference.guid)
.copied()
.collect();
for task_ref in tasks_to_remove {
if let Some(task) = tasks.remove(&task_ref) {
self.lua.remove_registry_value(task.thread)?;
self.lua.remove_registry_value(task.args)?;
found = true;
}
}
Ok(found)
}
/**
@ -444,6 +514,8 @@ impl<'fut> TaskScheduler<'fut> {
reference: TaskReference,
override_args: Option<Vec<LuaValue>>,
) -> LuaResult<()> {
self.guid_running_task
.store(reference.guid, Ordering::Relaxed);
let task = {
let mut tasks = self.tasks.lock().unwrap();
match tasks.remove(&reference) {
@ -452,12 +524,14 @@ impl<'fut> TaskScheduler<'fut> {
}
};
let thread: LuaThread = self.lua.registry_value(&task.thread)?;
let args = override_args.or_else(|| {
let args_vec_opt = override_args.or_else(|| {
self.lua
.registry_value::<Option<Vec<LuaValue>>>(&task.args)
.expect("Failed to get stored args for task")
});
if let Some(args) = args {
self.lua.remove_registry_value(task.thread)?;
self.lua.remove_registry_value(task.args)?;
if let Some(args) = args_vec_opt {
thread.resume::<_, LuaMultiValue>(LuaMultiValue::from_vec(args))?;
} else {
/*
@ -473,8 +547,7 @@ impl<'fut> TaskScheduler<'fut> {
let elapsed = task.queued_at.elapsed().as_secs_f64();
thread.resume::<_, LuaMultiValue>(elapsed)?;
}
self.lua.remove_registry_value(task.thread)?;
self.lua.remove_registry_value(task.args)?;
self.guid_running_task.store(0, Ordering::Relaxed);
Ok(())
}
@ -566,7 +639,7 @@ impl<'fut> TaskScheduler<'fut> {
(task, Err(fut_err)) => {
// Future errored, don't resume its associated task
// and make sure to cancel / remove it completely
let error_prefer_cancel = match self.cancel_task(task) {
let error_prefer_cancel = match self.remove_task(task) {
Err(cancel_err) => cancel_err,
Ok(_) => fut_err,
};

View file

@ -21,13 +21,10 @@ assert(not flag2, "Cancel should handle delayed threads")
local flag3: number = 1
local thread3 = task.spawn(function()
print("1")
task.wait(0.1)
flag3 = 2
print("2")
task.wait(0.2)
flag3 = 3
print("3")
end)
task.wait(0.2)
task.cancel(thread3)