diff --git a/CHANGELOG.md b/CHANGELOG.md index b8420c7..5f284fb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,12 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## Unreleased + +### Fixed + +- Fixed `task.delay` keeping the script running even if it was cancelled using `task.cancel` + ## `0.4.0` - February 11th, 2023 ### Added diff --git a/packages/lib/src/globals/net.rs b/packages/lib/src/globals/net.rs index e3e49c3..0dfd1ee 100644 --- a/packages/lib/src/globals/net.rs +++ b/packages/lib/src/globals/net.rs @@ -2,7 +2,7 @@ use std::{ collections::HashMap, future::Future, pin::Pin, - sync::{Arc, Weak}, + sync::Arc, task::{Context, Poll}, }; @@ -13,14 +13,14 @@ use hyper::{Body, Request, Response, Server}; use hyper_tungstenite::{is_upgrade_request as is_ws_upgrade_request, upgrade as ws_upgrade}; use reqwest::Method; -use tokio::{ - sync::mpsc::{self, Sender}, - task, -}; +use tokio::{sync::mpsc, task}; use crate::{ lua::net::{NetClient, NetClientBuilder, NetWebSocketClient, NetWebSocketServer, ServeConfig}, - utils::{message::LuneMessage, net::get_request_user_agent_header, table::TableBuilder}, + utils::{ + message::LuneMessage, net::get_request_user_agent_header, table::TableBuilder, + task::send_message, + }, }; pub fn create(lua: &'static Lua) -> LuaResult { @@ -179,18 +179,17 @@ async fn net_serve<'a>( }); // Make sure we register the thread properly by sending messages // when the server starts up and when it shuts down or errors - let server_sender = lua - .app_data_ref::>>() - .unwrap() - .upgrade() - .unwrap(); - let _ = server_sender.send(LuneMessage::Spawned).await; + send_message(lua, LuneMessage::Spawned).await?; task::spawn_local(async move { let res = server.await.map_err(LuaError::external); - let _ = match res { - Err(e) => server_sender.try_send(LuneMessage::LuaError(e)), - Ok(_) => server_sender.try_send(LuneMessage::Finished), - }; + let _ = send_message( + lua, + match res { + Err(e) => LuneMessage::LuaError(e), + Ok(_) => LuneMessage::Finished, + }, + ) + .await; }); // Create a new read-only table that contains methods // for manipulating server behavior and shutting it down @@ -255,11 +254,6 @@ impl Service> for NetService { // function & lune message sender to use later let bytes = to_bytes(body).await.map_err(LuaError::external)?; let handler: LuaFunction = lua.registry_value(&key)?; - let sender = lua - .app_data_ref::>>() - .unwrap() - .upgrade() - .unwrap(); // Create a readonly table for the request query params let query_params = TableBuilder::new(lua)? .with_values( @@ -320,10 +314,7 @@ impl Service> for NetService { } // If the handler returns an error, generate a 5xx response Err(err) => { - sender - .send(LuneMessage::LuaError(err.to_lua_err())) - .await - .map_err(LuaError::external)?; + send_message(lua, LuneMessage::LuaError(err.to_lua_err())).await?; Ok(Response::builder() .status(500) .body(Body::from("Internal Server Error")) @@ -332,13 +323,10 @@ impl Service> for NetService { // If the handler returns a value that is of an invalid type, // this should also be an error, so generate a 5xx response Ok(value) => { - sender - .send(LuneMessage::LuaError(LuaError::RuntimeError(format!( + send_message(lua, LuneMessage::LuaError(LuaError::RuntimeError(format!( "Expected net serve handler to return a value of type 'string' or 'table', got '{}'", value.type_name() - )))) - .await - .map_err(LuaError::external)?; + )))).await?; Ok(Response::builder() .status(500) .body(Body::from("Internal Server Error")) diff --git a/packages/lib/src/globals/task.rs b/packages/lib/src/globals/task.rs index 05c4c37..0bc728d 100644 --- a/packages/lib/src/globals/task.rs +++ b/packages/lib/src/globals/task.rs @@ -10,6 +10,9 @@ use crate::utils::{ const MINIMUM_WAIT_OR_DELAY_DURATION: f32 = 10.0 / 1_000.0; // 10ms +// TODO: We should probably keep track of all threads in a scheduler userdata +// that takes care of scheduling in a better way, and it should keep resuming +// threads until it encounters a delayed / waiting thread, then task:sleep pub fn create(lua: &'static Lua) -> LuaResult { // HACK: There is no way to call coroutine.close directly from the mlua // crate, so we need to fetch the function and store it in the registry @@ -81,12 +84,22 @@ async fn task_delay<'a>( let task_thread_key = lua.create_registry_value(task_thread)?; let task_args_key = lua.create_registry_value(args.into_vec())?; let lua_thread_to_return = lua.registry_value(&task_thread_key)?; - run_registered_task(lua, TaskRunMode::Deferred, async move { - task_wait(lua, duration).await?; + let start = Instant::now(); + let dur = Duration::from_secs_f32( + duration + .map(|d| d.max(MINIMUM_WAIT_OR_DELAY_DURATION)) + .unwrap_or(MINIMUM_WAIT_OR_DELAY_DURATION), + ); + run_registered_task(lua, TaskRunMode::Instant, async move { let thread: LuaThread = lua.registry_value(&task_thread_key)?; - let argsv: Vec = lua.registry_value(&task_args_key)?; - let args = LuaMultiValue::from_vec(argsv); + // NOTE: We are somewhat busy-waiting here, but we have to do this to make sure + // that delayed+cancelled threads do not prevent the tokio runtime from finishing + while thread.status() == LuaThreadStatus::Resumable && start.elapsed() < dur { + time::sleep(Duration::from_millis(1)).await; + } if thread.status() == LuaThreadStatus::Resumable { + let argsv: Vec = lua.registry_value(&task_args_key)?; + let args = LuaMultiValue::from_vec(argsv); let _: LuaMultiValue = thread.into_async(args).await?; } lua.remove_registry_value(task_thread_key)?; diff --git a/packages/lib/src/lib.rs b/packages/lib/src/lib.rs index 4ca37f9..8317f86 100644 --- a/packages/lib/src/lib.rs +++ b/packages/lib/src/lib.rs @@ -2,6 +2,7 @@ use std::{collections::HashSet, process::ExitCode, sync::Arc}; use mlua::prelude::*; use tokio::{sync::mpsc, task}; +use utils::task::send_message; pub(crate) mod globals; pub(crate) mod lua; @@ -101,11 +102,7 @@ impl Lune { // Spawn the main thread from our entrypoint script let script_name = script_name.to_string(); let script_chunk = script_contents.to_string(); - let script_sender = snd.clone(); - script_sender - .send(LuneMessage::Spawned) - .await - .map_err(LuaError::external)?; + send_message(lua, LuneMessage::Spawned).await?; task_set.spawn_local(async move { let result = lua .load(&script_chunk) @@ -113,10 +110,14 @@ impl Lune { .unwrap() .eval_async::() .await; - match result { - Err(e) => script_sender.send(LuneMessage::LuaError(e)).await, - Ok(_) => script_sender.send(LuneMessage::Finished).await, - } + send_message( + lua, + match result { + Err(e) => LuneMessage::LuaError(e), + Ok(_) => LuneMessage::Finished, + }, + ) + .await }); // Run the executor until there are no tasks left, // taking care to not exit right away for errors diff --git a/packages/lib/src/utils/process.rs b/packages/lib/src/utils/process.rs index e87991e..76b7a22 100644 --- a/packages/lib/src/utils/process.rs +++ b/packages/lib/src/utils/process.rs @@ -1,10 +1,12 @@ -use std::{process::ExitStatus, sync::Weak, time::Duration}; +use std::{process::ExitStatus, time::Duration}; use mlua::prelude::*; -use tokio::{io, process::Child, sync::mpsc::Sender, task::spawn, time::sleep}; +use tokio::{io, process::Child, task::spawn, time::sleep}; use crate::utils::{futures::AsyncTeeWriter, message::LuneMessage}; +use super::task::send_message; + pub async fn pipe_and_inherit_child_process_stdio( mut child: Child, ) -> LuaResult<(ExitStatus, Vec, Vec)> { @@ -42,17 +44,9 @@ pub async fn pipe_and_inherit_child_process_stdio( } pub async fn exit_and_yield_forever(lua: &'static Lua, exit_code: Option) -> LuaResult<()> { - let sender = lua - .app_data_ref::>>() - .unwrap() - .upgrade() - .unwrap(); // Send an exit signal to the main thread, which // will try to exit safely and as soon as possible - sender - .send(LuneMessage::Exit(exit_code.unwrap_or(0))) - .await - .map_err(LuaError::external)?; + send_message(lua, LuneMessage::Exit(exit_code.unwrap_or(0))).await?; // Make sure to block the rest of this thread indefinitely since // the main thread may not register the exit signal right away sleep(Duration::MAX).await; diff --git a/packages/lib/src/utils/task.rs b/packages/lib/src/utils/task.rs index dd4ff7c..db55a9d 100644 --- a/packages/lib/src/utils/task.rs +++ b/packages/lib/src/utils/task.rs @@ -25,24 +25,23 @@ impl fmt::Display for TaskRunMode { } } -pub async fn run_registered_task( - lua: &'static Lua, - mode: TaskRunMode, - to_run: impl Future> + 'static, -) -> LuaResult<()> { - // Fetch global reference to message sender +pub async fn send_message(lua: &'static Lua, message: LuneMessage) -> LuaResult<()> { let sender = lua .app_data_ref::>>() .unwrap() .upgrade() .unwrap(); + sender.send(message).await.map_err(LuaError::external) +} + +pub async fn run_registered_task( + lua: &'static Lua, + mode: TaskRunMode, + to_run: impl Future> + 'static, +) -> LuaResult<()> { // Send a message that we have started our task - sender - .send(LuneMessage::Spawned) - .await - .map_err(LuaError::external)?; + send_message(lua, LuneMessage::Spawned).await?; // Run the new task separately from the current one using the executor - let sender = sender.clone(); let task = task::spawn_local(async move { // HACK: For deferred tasks we yield a bunch of times to try and ensure // we run our task at the very end of the async queue, this can fail if @@ -52,13 +51,15 @@ pub async fn run_registered_task( task::yield_now().await; } } - sender - .send(match to_run.await { + send_message( + lua, + match to_run.await { Ok(_) => LuneMessage::Finished, Err(LuaError::CoroutineInactive) => LuneMessage::Finished, // Task was canceled Err(e) => LuneMessage::LuaError(e), - }) - .await + }, + ) + .await }); // Wait for the task to complete if we want this call to be blocking // Any lua errors will be sent through the message channel back diff --git a/tests/task/cancel.luau b/tests/task/cancel.luau index 9cbf06d..dfbd0ce 100644 --- a/tests/task/cancel.luau +++ b/tests/task/cancel.luau @@ -9,12 +9,13 @@ task.wait(0.1) assert(not flag, "Cancel should handle deferred threads") local flag2: boolean = false -local thread2 = task.delay(0, function() +local thread2 = task.delay(0.1, function() flag2 = true end) +task.wait(0) task.cancel(thread2) -task.wait(0.1) -assert(not flag2, "Cancel should handle deferred threads") +task.wait(0.2) +assert(not flag2, "Cancel should handle delayed threads") -- Cancellation should work with yields in spawned threads