diff --git a/src/lib/globals/console.rs b/src/lib/globals/console.rs index b790326..79bc86d 100644 --- a/src/lib/globals/console.rs +++ b/src/lib/globals/console.rs @@ -5,7 +5,7 @@ use crate::utils::{ table_builder::TableBuilder, }; -pub async fn create(lua: &Lua) -> LuaResult<()> { +pub fn create(lua: &Lua) -> LuaResult<()> { let print = |args: &LuaMultiValue, throw: bool| -> LuaResult<()> { let s = pretty_format_multi_value(args)?; if throw { diff --git a/src/lib/globals/fs.rs b/src/lib/globals/fs.rs index a33f1d0..7154052 100644 --- a/src/lib/globals/fs.rs +++ b/src/lib/globals/fs.rs @@ -5,7 +5,7 @@ use smol::{fs, prelude::*}; use crate::utils::table_builder::TableBuilder; -pub async fn create(lua: &Lua) -> LuaResult<()> { +pub fn create(lua: &Lua) -> LuaResult<()> { lua.globals().raw_set( "fs", TableBuilder::new(lua)? diff --git a/src/lib/globals/net.rs b/src/lib/globals/net.rs index 002d33b..85381fe 100644 --- a/src/lib/globals/net.rs +++ b/src/lib/globals/net.rs @@ -4,7 +4,7 @@ use mlua::prelude::*; use crate::utils::{net::get_request_user_agent_header, table_builder::TableBuilder}; -pub async fn create(lua: &Lua) -> LuaResult<()> { +pub fn create(lua: &Lua) -> LuaResult<()> { lua.globals().raw_set( "net", TableBuilder::new(lua)? diff --git a/src/lib/globals/process.rs b/src/lib/globals/process.rs index f6e86cd..bff727c 100644 --- a/src/lib/globals/process.rs +++ b/src/lib/globals/process.rs @@ -9,7 +9,7 @@ use smol::process::Command; use crate::utils::table_builder::TableBuilder; -pub async fn create(lua: &Lua, args_vec: Vec) -> LuaResult<()> { +pub fn create(lua: &Lua, args_vec: Vec) -> LuaResult<()> { // Create readonly args array let args_tab = TableBuilder::new(lua)? .with_sequential_values(args_vec)? diff --git a/src/lib/globals/task.rs b/src/lib/globals/task.rs index 53ca97a..2e36f02 100644 --- a/src/lib/globals/task.rs +++ b/src/lib/globals/task.rs @@ -1,37 +1,220 @@ -use std::time::{Duration, Instant}; +// TODO: Figure out a good way to remove all the shared boilerplate from these functions + +use std::{ + sync::Weak, + time::{Duration, Instant}, +}; use mlua::prelude::*; -use smol::Timer; +use smol::{channel::Sender, LocalExecutor, Timer}; -use crate::utils::table_builder::TableBuilder; +use crate::{utils::table_builder::TableBuilder, LuneMessage}; -const TASK_LIB: &str = include_str!("../luau/task.luau"); - -pub async fn create(lua: &Lua) -> LuaResult<()> { - let wait = lua.create_async_function(move |_, duration: Option| async move { - let start = Instant::now(); - Timer::after( - duration - .map(Duration::from_secs_f32) - .unwrap_or(Duration::ZERO), - ) - .await; - let end = Instant::now(); - Ok((end - start).as_secs_f32()) - })?; - let task_lib: LuaTable = lua - .load(TASK_LIB) - .set_name("task")? - .call_async(wait.clone()) - .await?; +pub fn create(lua: &Lua) -> LuaResult<()> { lua.globals().raw_set( "task", TableBuilder::new(lua)? - .with_value("cancel", task_lib.raw_get::<_, LuaFunction>("cancel")?)? - .with_value("defer", task_lib.raw_get::<_, LuaFunction>("defer")?)? - .with_value("delay", task_lib.raw_get::<_, LuaFunction>("delay")?)? - .with_value("spawn", task_lib.raw_get::<_, LuaFunction>("spawn")?)? - .with_value("wait", wait)? + .with_async_function("cancel", task_cancel)? + .with_async_function("delay", task_delay)? + .with_async_function("defer", task_defer)? + .with_async_function("spawn", task_spawn)? + .with_async_function("wait", task_wait)? .build_readonly()?, ) } + +fn tof_to_thread<'a>(lua: &'a Lua, tof: LuaValue<'a>) -> LuaResult> { + match tof { + LuaValue::Thread(t) => Ok(t), + LuaValue::Function(f) => Ok(lua.create_thread(f)?), + value => Err(LuaError::RuntimeError(format!( + "Argument must be a thread or function, got {}", + value.type_name() + ))), + } +} + +async fn task_cancel<'a>(lua: &'a Lua, thread: LuaThread<'a>) -> LuaResult<()> { + let coroutine: LuaTable = lua.globals().raw_get("coroutine")?; + let close: LuaFunction = coroutine.raw_get("close")?; + close.call_async(thread).await?; + Ok(()) +} + +async fn task_defer<'a>(task_lua: &'a Lua, tof: LuaValue<'a>) -> LuaResult> { + // Boilerplate to get arc-ed lua & async executor + let lua = task_lua + .app_data_ref::>() + .unwrap() + .upgrade() + .unwrap(); + let exec = task_lua + .app_data_ref::>() + .unwrap() + .upgrade() + .unwrap(); + let sender = task_lua + .app_data_ref::>>() + .unwrap() + .upgrade() + .unwrap(); + // Spawn a new detached thread + sender + .send(LuneMessage::Spawned) + .await + .map_err(LuaError::external)?; + let thread = tof_to_thread(&lua, tof)?; + let thread_key = lua.create_registry_value(thread)?; + let thread_to_return = task_lua.registry_value(&thread_key)?; + let thread_sender = sender.clone(); + exec.spawn(async move { + let result = async { + task_wait(&lua, None).await?; + let thread = lua.registry_value::(&thread_key)?; + if thread.status() == LuaThreadStatus::Resumable { + thread.into_async::<_, LuaMultiValue>(()).await?; + } + Ok::<_, LuaError>(()) + }; + thread_sender + .send(match result.await { + Ok(_) => LuneMessage::Finished, + Err(e) => LuneMessage::LuaError(e), + }) + .await + }) + .detach(); + sender + .send(LuneMessage::Finished) + .await + .map_err(LuaError::external)?; + Ok(thread_to_return) +} + +async fn task_delay<'a>( + task_lua: &'a Lua, + (duration, tof): (Option, LuaValue<'a>), +) -> LuaResult> { + // Boilerplate to get arc-ed lua & async executor + let lua = task_lua + .app_data_ref::>() + .unwrap() + .upgrade() + .unwrap(); + let exec = task_lua + .app_data_ref::>() + .unwrap() + .upgrade() + .unwrap(); + let sender = task_lua + .app_data_ref::>>() + .unwrap() + .upgrade() + .unwrap(); + // Spawn a new detached thread + sender + .send(LuneMessage::Spawned) + .await + .map_err(LuaError::external)?; + let thread = tof_to_thread(&lua, tof)?; + let thread_key = lua.create_registry_value(thread)?; + let thread_to_return = task_lua.registry_value(&thread_key)?; + let thread_sender = sender.clone(); + exec.spawn(async move { + let result = async { + task_wait(&lua, duration).await?; + let thread = lua.registry_value::(&thread_key)?; + if thread.status() == LuaThreadStatus::Resumable { + thread.into_async::<_, LuaMultiValue>(()).await?; + } + Ok::<_, LuaError>(()) + }; + thread_sender + .send(match result.await { + Ok(_) => LuneMessage::Finished, + Err(e) => LuneMessage::LuaError(e), + }) + .await + }) + .detach(); + sender + .send(LuneMessage::Finished) + .await + .map_err(LuaError::external)?; + Ok(thread_to_return) +} + +async fn task_spawn<'a>(task_lua: &'a Lua, tof: LuaValue<'a>) -> LuaResult> { + // Boilerplate to get arc-ed lua & async executor + let lua = task_lua + .app_data_ref::>() + .unwrap() + .upgrade() + .unwrap(); + let exec = task_lua + .app_data_ref::>() + .unwrap() + .upgrade() + .unwrap(); + let sender = task_lua + .app_data_ref::>>() + .unwrap() + .upgrade() + .unwrap(); + // Spawn a new detached thread + sender + .send(LuneMessage::Spawned) + .await + .map_err(LuaError::external)?; + let thread = tof_to_thread(&lua, tof)?; + let thread_key = lua.create_registry_value(thread)?; + let thread_to_return = task_lua.registry_value(&thread_key)?; + let thread_sender = sender.clone(); + // FIXME: This does not run the thread instantly + exec.spawn(async move { + let result = async { + let thread = lua.registry_value::(&thread_key)?; + if thread.status() == LuaThreadStatus::Resumable { + thread.into_async::<_, LuaMultiValue>(()).await?; + } + Ok::<_, LuaError>(()) + }; + thread_sender + .send(match result.await { + Ok(_) => LuneMessage::Finished, + Err(e) => LuneMessage::LuaError(e), + }) + .await + }) + .detach(); + sender + .send(LuneMessage::Finished) + .await + .map_err(LuaError::external)?; + Ok(thread_to_return) +} + +async fn task_wait(lua: &Lua, duration: Option) -> LuaResult { + let sender = lua + .app_data_ref::>>() + .unwrap() + .upgrade() + .unwrap(); + sender + .send(LuneMessage::Spawned) + .await + .map_err(LuaError::external)?; + let start = Instant::now(); + Timer::after( + duration + .map(Duration::from_secs_f32) + .unwrap_or(Duration::ZERO), + ) + .await; + let end = Instant::now(); + sender + .send(LuneMessage::Finished) + .await + .map_err(LuaError::external)?; + Ok((end - start).as_secs_f32()) +} diff --git a/src/lib/lib.rs b/src/lib/lib.rs index 6d18d6a..312d138 100644 --- a/src/lib/lib.rs +++ b/src/lib/lib.rs @@ -1,6 +1,6 @@ -use std::collections::HashSet; +use std::{collections::HashSet, sync::Arc}; -use anyhow::{bail, Result}; +use anyhow::{anyhow, bail, Result}; use mlua::prelude::*; use smol::LocalExecutor; @@ -33,6 +33,13 @@ impl LuneGlobal { } } +pub(crate) enum LuneMessage { + Spawned, + Finished, + Error(anyhow::Error), + LuaError(mlua::Error), +} + #[derive(Clone, Debug, Default)] pub struct Lune { globals: HashSet, @@ -62,37 +69,76 @@ impl Lune { } pub async fn run(&self, name: &str, chunk: &str) -> Result<()> { - let lua = Lua::new(); - let exec = LocalExecutor::new(); - smol::block_on(exec.run(async { - for global in &self.globals { - match &global { - LuneGlobal::Console => create_console(&lua).await?, - LuneGlobal::Fs => create_fs(&lua).await?, - LuneGlobal::Net => create_net(&lua).await?, - LuneGlobal::Process => create_process(&lua, self.args.clone()).await?, - LuneGlobal::Task => create_task(&lua).await?, - } + let (s, r) = smol::channel::unbounded::(); + let lua = Arc::new(mlua::Lua::new()); + let exec = Arc::new(LocalExecutor::new()); + let sender = Arc::new(s); + let receiver = Arc::new(r); + lua.set_app_data(Arc::downgrade(&lua)); + lua.set_app_data(Arc::downgrade(&exec)); + lua.set_app_data(Arc::downgrade(&sender)); + lua.set_app_data(Arc::downgrade(&receiver)); + // Add in wanted lune globals + for global in &self.globals { + match &global { + LuneGlobal::Console => create_console(&lua)?, + LuneGlobal::Fs => create_fs(&lua)?, + LuneGlobal::Net => create_net(&lua)?, + LuneGlobal::Process => create_process(&lua, self.args.clone())?, + LuneGlobal::Task => create_task(&lua)?, } + } + // Spawn the main thread from our entrypoint script + let script_name = name.to_string(); + let script_chunk = chunk.to_string(); + exec.spawn(async move { + sender.send(LuneMessage::Spawned).await?; let result = lua - .load(chunk) - .set_name(name)? + .load(&script_chunk) + .set_name(&script_name) + .unwrap() .call_async::<_, LuaMultiValue>(LuaMultiValue::new()) .await; - match result { - Ok(_) => Ok(()), - Err(e) => { - if cfg!(test) { - bail!(pretty_format_luau_error(&e)) - } else { - bail!( - "\n{}\n{}", - format_label("ERROR"), - pretty_format_luau_error(&e) - ) + let message = match result { + Ok(_) => LuneMessage::Finished, + Err(e) => LuneMessage::Error(if cfg!(test) { + anyhow!("{}", pretty_format_luau_error(&e)) + } else { + anyhow!( + "\n{}\n{}", + format_label("ERROR"), + pretty_format_luau_error(&e) + ) + }), + }; + sender.send(message).await + }) + .detach(); + // Run the executor until there are no tasks left + let mut task_count = 1; + smol::block_on(exec.run(async { + while let Ok(message) = receiver.recv().await { + match message { + LuneMessage::Spawned => { + task_count += 1; + } + LuneMessage::Finished => { + task_count -= 1; + if task_count <= 0 { + break; + } + } + LuneMessage::Error(e) => { + task_count -= 1; + bail!("{}", e) + } + LuneMessage::LuaError(e) => { + task_count -= 1; + bail!("{}", e) } } } + Ok(()) })) } } diff --git a/src/lib/luau/task.luau b/src/lib/luau/task.luau deleted file mode 100644 index 7539da4..0000000 --- a/src/lib/luau/task.luau +++ /dev/null @@ -1,67 +0,0 @@ -type ThreadOrFunction = thread | (...any) -> ...any - --- NOTE: The async wait function gets passed in here by task.rs, --- the same function will then be used for the global task library -local wait: (seconds: number?) -> number = ... - -local task = {} - -function task.cancel(thread: unknown) - if type(thread) ~= "thread" then - error(string.format("Argument #1 must be a thread, got %s", typeof(thread))) - else - coroutine.close(thread) - end -end - -function task.defer(tof: ThreadOrFunction, ...: any): thread - if type(tof) == "thread" then - local thread = coroutine.create(function(...) - wait() - coroutine.resume(tof, ...) - end) - coroutine.resume(thread, ...) - return thread - elseif type(tof) == "function" then - local thread = coroutine.create(tof) - task.defer(thread, ...) - return thread - else - error(string.format("Argument #1 must be a thread or function, got %s", typeof(tof))) - end -end - -function task.delay(duration: number?, tof: ThreadOrFunction, ...: any): thread - if duration ~= nil and type(duration) ~= "number" then - error(string.format("Argument #1 must be a number or nil, got %s", typeof(duration))) - end - if type(tof) == "thread" then - local thread = coroutine.create(function(...) - wait(duration) - coroutine.resume(tof, ...) - end) - coroutine.resume(thread, ...) - return thread - elseif type(tof) == "function" then - local thread = coroutine.create(tof) - task.delay(duration, thread, ...) - return thread - else - error(string.format("Argument #2 must be a thread or function, got %s", typeof(tof))) - end -end - -function task.spawn(tof: ThreadOrFunction, ...: any): thread - if type(tof) == "thread" then - coroutine.resume(tof, ...) - return tof - elseif type(tof) == "function" then - local thread = coroutine.create(tof) - coroutine.resume(thread, ...) - return thread - else - error(string.format("Argument #1 must be a thread or function, got %s", typeof(tof))) - end -end - -return task diff --git a/src/tests/task/wait.luau b/src/tests/task/wait.luau index fae8489..c92c733 100644 --- a/src/tests/task/wait.luau +++ b/src/tests/task/wait.luau @@ -37,19 +37,10 @@ measure(1 / 10) -- Wait should work in other threads, too local flag: boolean = false -coroutine.wrap(function() - task.wait(0.1) - flag = true -end)() -assert(not flag, "Wait failed for a coroutine (1)") -task.wait(0.2) -assert(flag, "Wait failed for a coroutine (2)") - -local flag2: boolean = false task.spawn(function() task.wait(0.1) - flag2 = true + flag = true end) -assert(not flag2, "Wait failed for a task-spawned thread (1)") +assert(not flag, "Wait failed for a task-spawned thread (1)") task.wait(0.2) -assert(flag2, "Wait failed for a task-spawned thread (2)") +assert(flag, "Wait failed for a task-spawned thread (2)")