diff --git a/lune.yml b/lune.yml index a8c4dbb..11b736a 100644 --- a/lune.yml +++ b/lune.yml @@ -74,22 +74,22 @@ globals: - required: false type: table # Task - # task.cancel: - # task.defer: - # args: - # - type: thread | function - # - type: "..." - # task.delay: - # args: - # - required: false - # type: number - # - type: thread | function - # - type: "..." - # task.spawn: - # args: - # - type: thread | function - # - type: "..." - # task.wait: - # args: - # - required: false - # type: number + task.cancel: + task.defer: + args: + - type: thread | function + - type: "..." + task.delay: + args: + - required: false + type: number + - type: thread | function + - type: "..." + task.spawn: + args: + - type: thread | function + - type: "..." + task.wait: + args: + - required: false + type: number diff --git a/luneTypes.d.luau b/luneTypes.d.luau index 9602285..791075e 100644 --- a/luneTypes.d.luau +++ b/luneTypes.d.luau @@ -52,12 +52,10 @@ declare process: { }, } ---[[ - declare task: { - cancel: (t: thread) -> (), - defer: (f: thread | (A...) -> (R...), A...) -> (R...), - delay: (duration: number?, f: thread | (A...) -> (R...), A...) -> (R...), - spawn: (f: thread | (A...) -> (R...), A...) -> (R...), - wait: (duration: number?) -> (number), - } -]] +declare task: { + cancel: (t: thread) -> (), + defer: (f: thread | (T...) -> (...any), T...) -> thread, + delay: (duration: number?, f: thread | (T...) -> (...any), T...) -> thread, + spawn: (f: thread | (T...) -> (...any), T...) -> thread, + wait: (duration: number?) -> (number), +} diff --git a/src/lib/globals/task.rs b/src/lib/globals/task.rs index 0b80fc1..53ca97a 100644 --- a/src/lib/globals/task.rs +++ b/src/lib/globals/task.rs @@ -5,87 +5,33 @@ use smol::Timer; use crate::utils::table_builder::TableBuilder; +const TASK_LIB: &str = include_str!("../luau/task.luau"); + pub async fn create(lua: &Lua) -> LuaResult<()> { + let wait = lua.create_async_function(move |_, duration: Option| async move { + let start = Instant::now(); + Timer::after( + duration + .map(Duration::from_secs_f32) + .unwrap_or(Duration::ZERO), + ) + .await; + let end = Instant::now(); + Ok((end - start).as_secs_f32()) + })?; + let task_lib: LuaTable = lua + .load(TASK_LIB) + .set_name("task")? + .call_async(wait.clone()) + .await?; lua.globals().raw_set( "task", TableBuilder::new(lua)? - .with_async_function("cancel", task_cancel)? - .with_async_function("defer", task_defer)? - .with_async_function("delay", task_delay)? - .with_async_function("spawn", task_spawn)? - .with_async_function("wait", task_wait)? + .with_value("cancel", task_lib.raw_get::<_, LuaFunction>("cancel")?)? + .with_value("defer", task_lib.raw_get::<_, LuaFunction>("defer")?)? + .with_value("delay", task_lib.raw_get::<_, LuaFunction>("delay")?)? + .with_value("spawn", task_lib.raw_get::<_, LuaFunction>("spawn")?)? + .with_value("wait", wait)? .build_readonly()?, ) } - -fn get_or_create_thread_from_arg<'a>(lua: &'a Lua, arg: LuaValue<'a>) -> LuaResult> { - match arg { - LuaValue::Thread(thread) => Ok(thread), - LuaValue::Function(func) => Ok(lua.create_thread(func)?), - val => Err(LuaError::RuntimeError(format!( - "Expected type thread or function, got {}", - val.type_name() - ))), - } -} - -async fn resume_thread(lua: &Lua, thread: LuaThread<'_>, args: LuaMultiValue<'_>) -> LuaResult<()> { - let coroutine: LuaTable = lua.globals().raw_get("coroutine")?; - let resume: LuaFunction = coroutine.raw_get("resume")?; - // FIXME: This is blocking, we should spawn a local tokio task, - // but doing that moves "thread" and "args", that both have - // the lifetime of the outer function, so it doesn't work - resume.call_async((thread, args)).await?; - Ok(()) -} - -async fn task_cancel(lua: &Lua, thread: LuaThread<'_>) -> LuaResult<()> { - let coroutine: LuaTable = lua.globals().raw_get("coroutine")?; - let close: LuaFunction = coroutine.raw_get("close")?; - close.call_async(thread).await?; - Ok(()) -} - -async fn task_defer<'a>( - lua: &'a Lua, - (tof, args): (LuaValue<'a>, LuaMultiValue<'a>), -) -> LuaResult> { - // TODO: Defer (sleep a minimum amount of time) - let thread = get_or_create_thread_from_arg(lua, tof)?; - resume_thread(lua, thread.clone(), args).await?; - Ok(thread) -} - -async fn task_delay<'a>( - lua: &'a Lua, - (_delay, tof, args): (Option, LuaValue<'a>, LuaMultiValue<'a>), -) -> LuaResult> { - // TODO: Delay by the amount of time wanted - let thread = get_or_create_thread_from_arg(lua, tof)?; - resume_thread(lua, thread.clone(), args).await?; - Ok(thread) -} - -async fn task_spawn<'a>( - lua: &'a Lua, - (tof, args): (LuaValue<'a>, LuaMultiValue<'a>), -) -> LuaResult> { - let thread = get_or_create_thread_from_arg(lua, tof)?; - resume_thread(lua, thread.clone(), args).await?; - Ok(thread) -} - -// FIXME: It doesn't seem possible to properly make an async wait -// function with mlua right now, something breaks when using -// the async wait function inside of a coroutine -async fn task_wait(_: &Lua, duration: Option) -> LuaResult { - let start = Instant::now(); - Timer::after( - duration - .map(Duration::from_secs_f32) - .unwrap_or(Duration::ZERO), - ) - .await; - let end = Instant::now(); - Ok((end - start).as_secs_f32()) -} diff --git a/src/lib/lib.rs b/src/lib/lib.rs index e82543e..6d18d6a 100644 --- a/src/lib/lib.rs +++ b/src/lib/lib.rs @@ -2,6 +2,7 @@ use std::collections::HashSet; use anyhow::{bail, Result}; use mlua::prelude::*; +use smol::LocalExecutor; pub mod globals; pub mod utils; @@ -61,8 +62,9 @@ impl Lune { } pub async fn run(&self, name: &str, chunk: &str) -> Result<()> { - smol::block_on(async { - let lua = Lua::new(); + let lua = Lua::new(); + let exec = LocalExecutor::new(); + smol::block_on(exec.run(async { for global in &self.globals { match &global { LuneGlobal::Console => create_console(&lua).await?, @@ -72,7 +74,11 @@ impl Lune { LuneGlobal::Task => create_task(&lua).await?, } } - let result = lua.load(chunk).set_name(name)?.exec_async().await; + let result = lua + .load(chunk) + .set_name(name)? + .call_async::<_, LuaMultiValue>(LuaMultiValue::new()) + .await; match result { Ok(_) => Ok(()), Err(e) => { @@ -87,7 +93,7 @@ impl Lune { } } } - }) + })) } } @@ -139,11 +145,10 @@ mod tests { net_request_redirect: "net/request/redirect", net_json_decode: "net/json/decode", net_json_encode: "net/json/encode", - // FIXME: Re-enable these tests for doing more work on the task library - // task_cancel: "task/cancel", - // task_defer: "task/defer", - // task_delay: "task/delay", - // task_spawn: "task/spawn", - // task_wait: "task/wait", + task_cancel: "task/cancel", + task_defer: "task/defer", + task_delay: "task/delay", + task_spawn: "task/spawn", + task_wait: "task/wait", } } diff --git a/src/lib/luau/task.luau b/src/lib/luau/task.luau new file mode 100644 index 0000000..7539da4 --- /dev/null +++ b/src/lib/luau/task.luau @@ -0,0 +1,67 @@ +type ThreadOrFunction = thread | (...any) -> ...any + +-- NOTE: The async wait function gets passed in here by task.rs, +-- the same function will then be used for the global task library +local wait: (seconds: number?) -> number = ... + +local task = {} + +function task.cancel(thread: unknown) + if type(thread) ~= "thread" then + error(string.format("Argument #1 must be a thread, got %s", typeof(thread))) + else + coroutine.close(thread) + end +end + +function task.defer(tof: ThreadOrFunction, ...: any): thread + if type(tof) == "thread" then + local thread = coroutine.create(function(...) + wait() + coroutine.resume(tof, ...) + end) + coroutine.resume(thread, ...) + return thread + elseif type(tof) == "function" then + local thread = coroutine.create(tof) + task.defer(thread, ...) + return thread + else + error(string.format("Argument #1 must be a thread or function, got %s", typeof(tof))) + end +end + +function task.delay(duration: number?, tof: ThreadOrFunction, ...: any): thread + if duration ~= nil and type(duration) ~= "number" then + error(string.format("Argument #1 must be a number or nil, got %s", typeof(duration))) + end + if type(tof) == "thread" then + local thread = coroutine.create(function(...) + wait(duration) + coroutine.resume(tof, ...) + end) + coroutine.resume(thread, ...) + return thread + elseif type(tof) == "function" then + local thread = coroutine.create(tof) + task.delay(duration, thread, ...) + return thread + else + error(string.format("Argument #2 must be a thread or function, got %s", typeof(tof))) + end +end + +function task.spawn(tof: ThreadOrFunction, ...: any): thread + if type(tof) == "thread" then + coroutine.resume(tof, ...) + return tof + elseif type(tof) == "function" then + local thread = coroutine.create(tof) + coroutine.resume(thread, ...) + return thread + else + error(string.format("Argument #1 must be a thread or function, got %s", typeof(tof))) + end +end + +return task diff --git a/src/tests/task/wait.luau b/src/tests/task/wait.luau index 0cb7b5c..fae8489 100644 --- a/src/tests/task/wait.luau +++ b/src/tests/task/wait.luau @@ -37,19 +37,19 @@ measure(1 / 10) -- Wait should work in other threads, too local flag: boolean = false -task.spawn(function() - task.wait(0.1) - flag = true -end) -assert(not flag, "Wait failed for a task-spawned thread (1)") -task.wait(0.2) -assert(flag, "Wait failed for a task-spawned thread (2)") - -local flag2: boolean = false coroutine.wrap(function() task.wait(0.1) - flag2 = true + flag = true end)() -assert(not flag2, "Wait failed for a coroutine (1)") +assert(not flag, "Wait failed for a coroutine (1)") task.wait(0.2) -assert(flag2, "Wait failed for a coroutine (2)") +assert(flag, "Wait failed for a coroutine (2)") + +local flag2: boolean = false +task.spawn(function() + task.wait(0.1) + flag2 = true +end) +assert(not flag2, "Wait failed for a task-spawned thread (1)") +task.wait(0.2) +assert(flag2, "Wait failed for a task-spawned thread (2)")