diff --git a/packages/lib/src/globals/net.rs b/packages/lib/src/globals/net.rs index 1724c2d..41f32f8 100644 --- a/packages/lib/src/globals/net.rs +++ b/packages/lib/src/globals/net.rs @@ -149,7 +149,9 @@ async fn net_socket<'a>(lua: &'static Lua, url: String) -> LuaResult { let (ws, _) = tokio_tungstenite::connect_async(url) .await .map_err(LuaError::external)?; - todo!() + Err(LuaError::RuntimeError( + "Client websockets are not yet implemented".to_string(), + )) // let sock = NetWebSocketClient::from(ws); // let table = sock.into_lua_table(lua)?; // Ok(table) @@ -159,6 +161,11 @@ async fn net_serve<'a>( lua: &'static Lua, (port, config): (u16, ServeConfig<'a>), ) -> LuaResult> { + if config.handle_web_socket.is_some() { + return Err(LuaError::RuntimeError( + "Server websockets are not yet implemented".to_string(), + )); + } // Note that we need to use a mpsc here and not // a oneshot channel since we move the sender // into our table with the stop function @@ -168,6 +175,10 @@ async fn net_serve<'a>( lua.create_registry_value(handler) .expect("Failed to store websocket handler") })); + // Register a background task to prevent + // the task scheduler from exiting early + let sched = lua.app_data_mut::<&TaskScheduler>().unwrap(); + let task = sched.register_background_task(); let server = Server::bind(&([127, 0, 0, 1], port).into()) .http1_only(true) .http1_keepalive(true) @@ -178,12 +189,15 @@ async fn net_serve<'a>( server_websocket_callback, )) .with_graceful_shutdown(async move { - shutdown_rx.recv().await.unwrap(); + shutdown_rx + .recv() + .await + .expect("Server was stopped instantly"); shutdown_rx.close(); + task.unregister(Ok(())); }); - // TODO: Spawn a new scheduler future with this so we don't block - // and make sure that we register it properly to prevent shutdown - server.await.map_err(LuaError::external)?; + // Spawn a new tokio task so we don't block + task::spawn_local(server); // Create a new read-only table that contains methods // for manipulating server behavior and shutting it down let handle_stop = move |_, _: ()| { diff --git a/packages/lib/src/lib.rs b/packages/lib/src/lib.rs index 0cc0b82..e05c494 100644 --- a/packages/lib/src/lib.rs +++ b/packages/lib/src/lib.rs @@ -1,6 +1,6 @@ use std::{collections::HashSet, process::ExitCode}; -use lua::task::{TaskScheduler, TaskSchedulerResult}; +use lua::task::TaskScheduler; use mlua::prelude::*; use tokio::task::LocalSet; @@ -124,26 +124,21 @@ impl Lune { // left to run, or until a task requests to exit the process let exit_code = LocalSet::new() .run_until(async move { - loop { - let mut got_error = false; - let state = match sched.resume_queue().await { - TaskSchedulerResult::TaskSuccessful { state } => state, - TaskSchedulerResult::TaskErrored { state, error } => { - eprintln!("{}", pretty_format_luau_error(&error)); - got_error = true; - state - } - TaskSchedulerResult::Finished { state } => state, - }; - if let Some(exit_code) = state.exit_code { - return exit_code; - } else if state.num_total == 0 { - if got_error { - return ExitCode::FAILURE; - } else { - return ExitCode::SUCCESS; - } + let mut got_error = false; + let mut result = sched.resume_queue().await; + while !result.is_done() { + if let Some(err) = result.get_lua_error() { + eprintln!("{}", pretty_format_luau_error(&err)); + got_error = true; } + result = sched.resume_queue().await; + } + if let Some(exit_code) = result.get_exit_code() { + exit_code + } else if got_error { + ExitCode::FAILURE + } else { + ExitCode::SUCCESS } }) .await; diff --git a/packages/lib/src/lua/task/scheduler.rs b/packages/lib/src/lua/task/scheduler.rs index c116af3..3ab8f68 100644 --- a/packages/lib/src/lua/task/scheduler.rs +++ b/packages/lib/src/lua/task/scheduler.rs @@ -14,7 +14,7 @@ use futures_util::{future::LocalBoxFuture, stream::FuturesUnordered, Future, Str use mlua::prelude::*; use tokio::{ - sync::Mutex as AsyncMutex, + sync::{mpsc, Mutex as AsyncMutex}, time::{sleep, Instant}, }; @@ -50,7 +50,7 @@ impl fmt::Display for TaskKind { } } -/// A lightweight, clonable struct that represents a +/// A lightweight, copyable struct that represents a /// task in the scheduler and is accessible from Lua #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct TaskReference { @@ -80,38 +80,121 @@ pub struct Task { queued_at: Instant, } -/// A struct representing the current status of the task scheduler -#[derive(Debug, Clone, Copy)] -pub struct TaskSchedulerState { - pub exit_code: Option, - pub num_instant: usize, - pub num_deferred: usize, - pub num_future: usize, - pub num_total: usize, +/** + A handle to a registered background task. + + [`TaskSchedulerUnregistrar::unregister`] must be + called upon completion of the background task to + prevent the task scheduler from running indefinitely. +*/ +#[must_use = "Background tasks must be unregistered"] +#[derive(Debug)] +pub struct TaskSchedulerBackgroundTaskHandle { + sender: mpsc::UnboundedSender, } -impl fmt::Display for TaskSchedulerState { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "TaskSchedulerStatus(\nInstant: {}\nDeferred: {}\nYielded: {}\nTotal: {})", - self.num_instant, self.num_deferred, self.num_future, self.num_total - ) +impl TaskSchedulerBackgroundTaskHandle { + pub fn unregister(self, result: LuaResult<()>) { + self.sender + .send(TaskSchedulerRegistrationMessage::Terminated(result)) + .unwrap_or_else(|_| { + panic!( + "\ + \nFailed to unregister background task - this is an internal error! \ + \nPlease report it at {} \ + \nDetails: Manual \ + ", + env!("CARGO_PKG_REPOSITORY") + ) + }); + } +} + +/// A struct representing the current state of the task scheduler +#[derive(Debug, Clone)] +pub struct TaskSchedulerResult { + lua_error: Option, + exit_code: Option, + num_instant: usize, + num_deferred: usize, + num_futures: usize, + num_spawned: usize, + num_total: usize, +} + +impl TaskSchedulerResult { + fn new(sched: &TaskScheduler) -> Self { + const MESSAGE: &str = + "Failed to get lock - make sure not to call during task scheduler resumption"; + Self { + lua_error: None, + exit_code: if sched.exit_code_set.load(Ordering::Relaxed) { + Some(*sched.exit_code.try_lock().expect(MESSAGE)) + } else { + None + }, + num_instant: sched.task_queue_instant.try_lock().expect(MESSAGE).len(), + num_deferred: sched.task_queue_deferred.try_lock().expect(MESSAGE).len(), + num_futures: sched.futures.try_lock().expect(MESSAGE).len(), + num_spawned: sched.futures_counter.load(Ordering::Relaxed), + num_total: sched.tasks.try_lock().expect(MESSAGE).len(), + } + } + + fn err(sched: &TaskScheduler, err: LuaError) -> Self { + let mut this = Self::new(sched); + this.lua_error = Some(err); + this + } + + /** + Returns a clone of the error from + this task scheduler result, if any. + */ + pub fn get_lua_error(&self) -> Option { + self.lua_error.clone() + } + + /** + Returns a clone of the exit code from + this task scheduler result, if any. + */ + pub fn get_exit_code(&self) -> Option { + self.exit_code + } + + /** + Returns `true` if the task scheduler is still busy, + meaning it still has lua threads left to run. + */ + #[allow(dead_code)] + pub fn is_busy(&self) -> bool { + self.num_total > 0 + } + + /** + Returns `true` if the task scheduler is done, + meaning it has no lua threads left to run, and + no spawned tasks are running in the background. + */ + pub fn is_done(&self) -> bool { + self.num_total == 0 && self.num_spawned == 0 + } + + /** + Returns `true` if the task scheduler has finished all + lua threads, but still has background tasks running. + */ + #[allow(dead_code)] + pub fn is_background(&self) -> bool { + self.num_total == 0 && self.num_spawned > 0 } } #[derive(Debug, Clone)] -pub enum TaskSchedulerResult { - Finished { - state: TaskSchedulerState, - }, - TaskErrored { - error: LuaError, - state: TaskSchedulerState, - }, - TaskSuccessful { - state: TaskSchedulerState, - }, +pub enum TaskSchedulerRegistrationMessage { + Spawned, + Terminated(LuaResult<()>), } /// A task scheduler that implements task queues @@ -121,6 +204,9 @@ pub struct TaskScheduler<'fut> { lua: &'static Lua, tasks: Arc>>, futures: Arc>>>, + futures_tx: mpsc::UnboundedSender, + futures_rx: Arc>>, + futures_counter: AtomicUsize, task_queue_instant: TaskSchedulerQueue, task_queue_deferred: TaskSchedulerQueue, exit_code_set: AtomicBool, @@ -134,10 +220,14 @@ impl<'fut> TaskScheduler<'fut> { Creates a new task scheduler. */ pub fn new(lua: &'static Lua) -> LuaResult { + let (tx, rx) = mpsc::unbounded_channel(); Ok(Self { lua, tasks: Arc::new(Mutex::new(HashMap::new())), futures: Arc::new(AsyncMutex::new(FuturesUnordered::new())), + futures_tx: tx, + futures_rx: Arc::new(AsyncMutex::new(rx)), + futures_counter: AtomicUsize::new(0), task_queue_instant: Arc::new(Mutex::new(VecDeque::new())), task_queue_deferred: Arc::new(Mutex::new(VecDeque::new())), exit_code_set: AtomicBool::new(false), @@ -162,27 +252,6 @@ impl<'fut> TaskScheduler<'fut> { Box::leak(Box::new(self)) } - /** - Gets the current state of the task scheduler. - - Panics if called during any of the task scheduler resumption phases. - */ - pub fn state(&self) -> TaskSchedulerState { - const MESSAGE: &str = - "Failed to get lock - make sure not to call during task scheduler resumption"; - TaskSchedulerState { - exit_code: if self.exit_code_set.load(Ordering::Relaxed) { - Some(*self.exit_code.try_lock().expect(MESSAGE)) - } else { - None - }, - num_instant: self.task_queue_instant.try_lock().expect(MESSAGE).len(), - num_deferred: self.task_queue_deferred.try_lock().expect(MESSAGE).len(), - num_future: self.futures.try_lock().expect(MESSAGE).len(), - num_total: self.tasks.try_lock().expect(MESSAGE).len(), - } - } - /** Stores the exit code for the task scheduler. @@ -593,6 +662,33 @@ impl<'fut> TaskScheduler<'fut> { Ok(()) } + /** + Registers a new background task with the task scheduler. + + This will ensure that the task scheduler keeps running until a + call to [`TaskScheduler::deregister_background_task`] is made. + + The returned [`TaskSchedulerUnregistrar::unregister`] + must be called upon completion of the background task to + prevent the task scheduler from running indefinitely. + */ + pub fn register_background_task(&self) -> TaskSchedulerBackgroundTaskHandle { + let sender = self.futures_tx.clone(); + sender + .send(TaskSchedulerRegistrationMessage::Spawned) + .unwrap_or_else(|e| { + panic!( + "\ + \nFailed to unregister background task - this is an internal error! \ + \nPlease report it at {} \ + \nDetails: {e} \ + ", + env!("CARGO_PKG_REPOSITORY") + ) + }); + TaskSchedulerBackgroundTaskHandle { sender } + } + /** Retrieves the queue for a specific kind of task. @@ -609,18 +705,6 @@ impl<'fut> TaskScheduler<'fut> { } } - /** - Checks if a future exists in the task queue. - - Panics if called during resumption of the futures task queue. - */ - fn next_queue_future_exists(&self) -> bool { - let futs = self.futures.try_lock().expect( - "Failed to get lock on futures - make sure not to call during futures resumption", - ); - !futs.is_empty() - } - /** Resumes the next queued Lua task, if one exists, blocking the current thread until it either yields or finishes. @@ -634,26 +718,10 @@ impl<'fut> TaskScheduler<'fut> { let mut queue_guard = self.get_queue(kind).lock().unwrap(); queue_guard.pop_front() } { - None => { - let status = self.state(); - if status.num_total > 0 { - TaskSchedulerResult::TaskSuccessful { - state: self.state(), - } - } else { - TaskSchedulerResult::Finished { - state: self.state(), - } - } - } + None => TaskSchedulerResult::new(self), Some(task) => match self.resume_task(task, override_args) { - Ok(()) => TaskSchedulerResult::TaskSuccessful { - state: self.state(), - }, - Err(task_err) => TaskSchedulerResult::TaskErrored { - error: task_err, - state: self.state(), - }, + Ok(()) => TaskSchedulerResult::new(self), + Err(task_err) => TaskSchedulerResult::err(self, task_err), }, } } @@ -680,15 +748,15 @@ impl<'fut> TaskScheduler<'fut> { match result { (task, Err(fut_err)) => { // Future errored, don't resume its associated task - // and make sure to cancel / remove it completely - let error_prefer_cancel = match self.remove_task(task) { - Err(cancel_err) => cancel_err, - Ok(_) => fut_err, - }; - TaskSchedulerResult::TaskErrored { - error: error_prefer_cancel, - state: self.state(), - } + // and make sure to cancel / remove it completely, if removal + // also errors then we send that error back instead of the future's error + TaskSchedulerResult::err( + self, + match self.remove_task(task) { + Err(cancel_err) => cancel_err, + Ok(_) => fut_err, + }, + ) } (task, Ok(args)) => { // Promote this future task to an instant task @@ -702,6 +770,46 @@ impl<'fut> TaskScheduler<'fut> { } } + /** + Awaits the next background task registration + message, if any messages exist in the queue. + + This is a no-op if there are no messages. + */ + async fn receive_next_message(&self) -> TaskSchedulerResult { + let message_opt = { + let mut rx = self.futures_rx.lock().await; + rx.recv().await + }; + if let Some(message) = message_opt { + match message { + TaskSchedulerRegistrationMessage::Spawned => { + self.futures_counter.fetch_add(1, Ordering::Relaxed); + TaskSchedulerResult::new(self) + } + TaskSchedulerRegistrationMessage::Terminated(result) => { + let prev = self.futures_counter.fetch_sub(1, Ordering::Relaxed); + if prev == 0 { + panic!( + r#" + Terminated a background task without it running - this is an internal error! + Please report it at {} + "#, + env!("CARGO_PKG_REPOSITORY") + ) + } + if let Err(e) = result { + TaskSchedulerResult::err(self, e) + } else { + TaskSchedulerResult::new(self) + } + } + } + } else { + TaskSchedulerResult::new(self) + } + } + /** Resumes the task scheduler queue. @@ -712,27 +820,36 @@ impl<'fut> TaskScheduler<'fut> { futures concurrently, awaiting the first one to be ready for resumption. */ pub async fn resume_queue(&self) -> TaskSchedulerResult { - let status = self.state(); + let current = TaskSchedulerResult::new(self); /* Resume tasks in the internal queue, in this order: - 1. Tasks from task.spawn, this includes the main thread - 2. Tasks from task.defer - 3. Tasks from task.delay / task.wait / native futures, first ready first resumed + * 🛑 = blocking - lua tasks, in order + * ⏳ = async - first come, first serve + + 1. 🛑 Tasks from task.spawn and the main thread + 2. 🛑 Tasks from task.defer + 3. ⏳ Tasks from task.delay / task.wait, spawned background tasks */ - if status.num_instant > 0 { + if current.num_instant > 0 { self.resume_next_queue_task(TaskKind::Instant, None) - } else if status.num_deferred > 0 { + } else if current.num_deferred > 0 { self.resume_next_queue_task(TaskKind::Deferred, None) - } else { - // 3. Threads from task.delay or task.wait, futures - if self.next_queue_future_exists() { - self.resume_next_queue_future().await - } else { - TaskSchedulerResult::Finished { - state: self.state(), - } + } else if current.num_futures > 0 && current.num_spawned > 0 { + // Futures, spawned background tasks + tokio::select! { + result = self.resume_next_queue_future() => result, + result = self.receive_next_message() => result, } + } else if current.num_futures > 0 { + // Futures + self.resume_next_queue_future().await + } else if current.num_spawned > 0 { + // Only spawned background tasks, these may then + // spawn new lua tasks and "wake up" the scheduler + self.receive_next_message().await + } else { + TaskSchedulerResult::new(self) } } } diff --git a/tests/net/serve/websockets.luau b/tests/net/serve/websockets.luau index 0881b66..5e5985e 100644 --- a/tests/net/serve/websockets.luau +++ b/tests/net/serve/websockets.luau @@ -60,5 +60,4 @@ assert( ) -- Stop the server to end the test - handle2.stop()