Properly handle blocking tasks spawned during async tasks in new scheduler

This commit is contained in:
Filip Tibell 2023-02-14 21:36:04 +01:00
parent 6f1ae83fbe
commit 3990b8e064
No known key found for this signature in database

View file

@ -91,11 +91,11 @@ pub struct Task {
#[derive(Debug)] #[derive(Debug)]
pub struct TaskSchedulerBackgroundTaskHandle { pub struct TaskSchedulerBackgroundTaskHandle {
unregistered: bool, unregistered: bool,
sender: mpsc::UnboundedSender<TaskSchedulerRegistrationMessage>, sender: mpsc::UnboundedSender<TaskSchedulerMessage>,
} }
impl TaskSchedulerBackgroundTaskHandle { impl TaskSchedulerBackgroundTaskHandle {
pub fn new(sender: mpsc::UnboundedSender<TaskSchedulerRegistrationMessage>) -> Self { pub fn new(sender: mpsc::UnboundedSender<TaskSchedulerMessage>) -> Self {
Self { Self {
unregistered: false, unregistered: false,
sender, sender,
@ -105,7 +105,7 @@ impl TaskSchedulerBackgroundTaskHandle {
pub fn unregister(mut self, result: LuaResult<()>) { pub fn unregister(mut self, result: LuaResult<()>) {
self.unregistered = true; self.unregistered = true;
self.sender self.sender
.send(TaskSchedulerRegistrationMessage::Terminated(result)) .send(TaskSchedulerMessage::Terminated(result))
.unwrap_or_else(|_| { .unwrap_or_else(|_| {
panic!( panic!(
"\ "\
@ -285,7 +285,8 @@ impl fmt::Display for TaskSchedulerResult {
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum TaskSchedulerRegistrationMessage { pub enum TaskSchedulerMessage {
NewBlockingTaskReady,
Spawned, Spawned,
Terminated(LuaResult<()>), Terminated(LuaResult<()>),
} }
@ -297,8 +298,8 @@ pub struct TaskScheduler<'fut> {
lua: &'static Lua, lua: &'static Lua,
tasks: Arc<Mutex<HashMap<TaskReference, Task>>>, tasks: Arc<Mutex<HashMap<TaskReference, Task>>>,
futures: Arc<AsyncMutex<FuturesUnordered<TaskFuture<'fut>>>>, futures: Arc<AsyncMutex<FuturesUnordered<TaskFuture<'fut>>>>,
futures_tx: mpsc::UnboundedSender<TaskSchedulerRegistrationMessage>, futures_tx: mpsc::UnboundedSender<TaskSchedulerMessage>,
futures_rx: Arc<AsyncMutex<mpsc::UnboundedReceiver<TaskSchedulerRegistrationMessage>>>, futures_rx: Arc<AsyncMutex<mpsc::UnboundedReceiver<TaskSchedulerMessage>>>,
futures_in_background: AtomicUsize, futures_in_background: AtomicUsize,
task_queue_instant: TaskSchedulerQueue, task_queue_instant: TaskSchedulerQueue,
task_queue_deferred: TaskSchedulerQueue, task_queue_deferred: TaskSchedulerQueue,
@ -424,7 +425,7 @@ impl<'fut> TaskScheduler<'fut> {
-- Here we have either yielded or finished the above task -- Here we have either yielded or finished the above task
``` ```
*/ */
fn queue_task( fn queue_blocking_task(
&self, &self,
kind: TaskKind, kind: TaskKind,
thread_or_function: LuaValue<'_>, thread_or_function: LuaValue<'_>,
@ -436,10 +437,26 @@ impl<'fut> TaskScheduler<'fut> {
panic!("Tried to schedule future using normal task schedule method") panic!("Tried to schedule future using normal task schedule method")
} }
let task_ref = self.create_task(kind, thread_or_function, thread_args, guid_to_reuse)?; let task_ref = self.create_task(kind, thread_or_function, thread_args, guid_to_reuse)?;
// Note that we create two new inner new
// scopes to drop mutexes as fast as possible
let num_prev_blocking_tasks = {
let (should_defer, num_prev_tasks, mut queue) = {
let queue_instant = self.task_queue_instant.lock().unwrap();
let queue_deferred = self.task_queue_deferred.lock().unwrap();
let num_prev_tasks = queue_instant.len() + queue_deferred.len();
(
kind == TaskKind::Deferred,
num_prev_tasks,
match kind { match kind {
TaskKind::Instant => { TaskKind::Instant => queue_instant,
let mut queue = self.task_queue_instant.lock().unwrap(); TaskKind::Deferred => queue_deferred,
if after_current_resume { TaskKind::Future => unreachable!(),
},
)
};
if should_defer {
queue.push_back(task_ref);
} else if after_current_resume {
assert!( assert!(
queue.len() > 0, queue.len() > 0,
"Cannot schedule a task after the first instant when task queue is empty" "Cannot schedule a task after the first instant when task queue is empty"
@ -448,13 +465,21 @@ impl<'fut> TaskScheduler<'fut> {
} else { } else {
queue.push_front(task_ref); queue.push_front(task_ref);
} }
} num_prev_tasks
TaskKind::Deferred => { };
// Deferred tasks should always schedule at the end of the deferred queue /*
let mut queue = self.task_queue_deferred.lock().unwrap(); If we had any previous task and are currently async
queue.push_back(task_ref); waiting on tasks, we should send a signal to wake up
} and run the new blocking task that was just queued
TaskKind::Future => unreachable!(),
This can happen in cases such as an async http
server waking up from a connection and then wanting to
run a lua callback in response, to create the.. response
*/
if num_prev_blocking_tasks == 0 {
self.futures_tx
.send(TaskSchedulerMessage::NewBlockingTaskReady)
.expect("Futures waker channel was closed")
} }
Ok(task_ref) Ok(task_ref)
} }
@ -462,7 +487,7 @@ impl<'fut> TaskScheduler<'fut> {
/** /**
Queues a new future to run on the task scheduler. Queues a new future to run on the task scheduler.
*/ */
fn queue_async( fn queue_async_task(
&self, &self,
thread_or_function: LuaValue<'_>, thread_or_function: LuaValue<'_>,
thread_args: Option<LuaMultiValue<'_>>, thread_args: Option<LuaMultiValue<'_>>,
@ -498,7 +523,7 @@ impl<'fut> TaskScheduler<'fut> {
thread_or_function: LuaValue<'_>, thread_or_function: LuaValue<'_>,
thread_args: LuaMultiValue<'_>, thread_args: LuaMultiValue<'_>,
) -> LuaResult<TaskReference> { ) -> LuaResult<TaskReference> {
self.queue_task( self.queue_blocking_task(
TaskKind::Instant, TaskKind::Instant,
thread_or_function, thread_or_function,
Some(thread_args), Some(thread_args),
@ -519,7 +544,7 @@ impl<'fut> TaskScheduler<'fut> {
thread_or_function: LuaValue<'_>, thread_or_function: LuaValue<'_>,
thread_args: LuaMultiValue<'_>, thread_args: LuaMultiValue<'_>,
) -> LuaResult<TaskReference> { ) -> LuaResult<TaskReference> {
self.queue_task( self.queue_blocking_task(
TaskKind::Instant, TaskKind::Instant,
thread_or_function, thread_or_function,
Some(thread_args), Some(thread_args),
@ -546,7 +571,7 @@ impl<'fut> TaskScheduler<'fut> {
thread_or_function: LuaValue<'_>, thread_or_function: LuaValue<'_>,
thread_args: LuaMultiValue<'_>, thread_args: LuaMultiValue<'_>,
) -> LuaResult<TaskReference> { ) -> LuaResult<TaskReference> {
self.queue_task( self.queue_blocking_task(
TaskKind::Deferred, TaskKind::Deferred,
thread_or_function, thread_or_function,
Some(thread_args), Some(thread_args),
@ -568,7 +593,7 @@ impl<'fut> TaskScheduler<'fut> {
thread_or_function: LuaValue<'_>, thread_or_function: LuaValue<'_>,
thread_args: LuaMultiValue<'_>, thread_args: LuaMultiValue<'_>,
) -> LuaResult<TaskReference> { ) -> LuaResult<TaskReference> {
self.queue_async(thread_or_function, Some(thread_args), None, async move { self.queue_async_task(thread_or_function, Some(thread_args), None, async move {
sleep(Duration::from_secs_f64(after_secs)).await; sleep(Duration::from_secs_f64(after_secs)).await;
Ok(None) Ok(None)
}) })
@ -586,7 +611,7 @@ impl<'fut> TaskScheduler<'fut> {
after_secs: f64, after_secs: f64,
thread_or_function: LuaValue<'_>, thread_or_function: LuaValue<'_>,
) -> LuaResult<TaskReference> { ) -> LuaResult<TaskReference> {
self.queue_async( self.queue_async_task(
thread_or_function, thread_or_function,
None, None,
// Wait should recycle the guid of the current task, // Wait should recycle the guid of the current task,
@ -615,7 +640,7 @@ impl<'fut> TaskScheduler<'fut> {
thread_or_function: LuaValue<'_>, thread_or_function: LuaValue<'_>,
fut: impl Future<Output = TaskFutureReturns<'fut>> + 'fut, fut: impl Future<Output = TaskFutureReturns<'fut>> + 'fut,
) -> LuaResult<TaskReference> { ) -> LuaResult<TaskReference> {
self.queue_async(thread_or_function, None, None, fut) self.queue_async_task(thread_or_function, None, None, fut)
} }
/** /**
@ -768,7 +793,7 @@ impl<'fut> TaskScheduler<'fut> {
pub fn register_background_task(&self) -> TaskSchedulerBackgroundTaskHandle { pub fn register_background_task(&self) -> TaskSchedulerBackgroundTaskHandle {
let sender = self.futures_tx.clone(); let sender = self.futures_tx.clone();
sender sender
.send(TaskSchedulerRegistrationMessage::Spawned) .send(TaskSchedulerMessage::Spawned)
.unwrap_or_else(|e| { .unwrap_or_else(|e| {
panic!( panic!(
"\ "\
@ -876,11 +901,12 @@ impl<'fut> TaskScheduler<'fut> {
}; };
if let Some(message) = message_opt { if let Some(message) = message_opt {
match message { match message {
TaskSchedulerRegistrationMessage::Spawned => { TaskSchedulerMessage::NewBlockingTaskReady => TaskSchedulerResult::new(self),
TaskSchedulerMessage::Spawned => {
self.futures_in_background.fetch_add(1, Ordering::Relaxed); self.futures_in_background.fetch_add(1, Ordering::Relaxed);
TaskSchedulerResult::new(self) TaskSchedulerResult::new(self)
} }
TaskSchedulerRegistrationMessage::Terminated(result) => { TaskSchedulerMessage::Terminated(result) => {
let prev = self.futures_in_background.fetch_sub(1, Ordering::Relaxed); let prev = self.futures_in_background.fetch_sub(1, Ordering::Relaxed);
if prev == 0 { if prev == 0 {
panic!( panic!(