From 113f290c59cb3c31b377481ca10c7e99ea218c4b Mon Sep 17 00:00:00 2001 From: Filip Tibell Date: Tue, 24 Jan 2023 12:24:57 -0500 Subject: [PATCH] Implement argument passing for task lib --- src/lib/globals/task.rs | 30 +++++++++++++++++++++--------- src/tests/task/defer.luau | 22 ++++++++++++++++++---- src/tests/task/delay.luau | 22 ++++++++++++++++++---- src/tests/task/spawn.luau | 22 ++++++++++++++++++---- 4 files changed, 75 insertions(+), 21 deletions(-) diff --git a/src/lib/globals/task.rs b/src/lib/globals/task.rs index 43c6396..f17c927 100644 --- a/src/lib/globals/task.rs +++ b/src/lib/globals/task.rs @@ -55,19 +55,23 @@ async fn task_cancel<'a>(lua: &'a Lua, thread: LuaThread<'a>) -> LuaResult<()> { async fn task_defer<'a>( lua: &'a Lua, - (tof, _args): (LuaValue<'a>, LuaMultiValue<'a>), + (tof, args): (LuaValue<'a>, LuaMultiValue<'a>), ) -> LuaResult> { // Spawn a new detached task using a lua reference that we can use inside of our task let task_lua = lua.app_data_ref::>().unwrap().upgrade().unwrap(); let task_thread = tof_to_thread(lua, tof)?; let task_thread_key = lua.create_registry_value(task_thread)?; + let task_args_key = lua.create_registry_value(args.into_vec())?; let lua_thread_to_return = lua.registry_value(&task_thread_key)?; run_registered_task(lua, TaskRunMode::Deferred, async move { - let thread = task_lua.registry_value::(&task_thread_key)?; + let thread: LuaThread = task_lua.registry_value(&task_thread_key)?; + let argsv: Vec = task_lua.registry_value(&task_args_key)?; + let args = LuaMultiValue::from_vec(argsv); if thread.status() == LuaThreadStatus::Resumable { - thread.into_async::<_, LuaMultiValue>(()).await?; + let _: LuaMultiValue = thread.into_async(args).await?; } task_lua.remove_registry_value(task_thread_key)?; + task_lua.remove_registry_value(task_args_key)?; Ok(()) }) .await?; @@ -76,20 +80,24 @@ async fn task_defer<'a>( async fn task_delay<'a>( lua: &'a Lua, - (duration, tof, _args): (Option, LuaValue<'a>, LuaMultiValue<'a>), + (duration, tof, args): (Option, LuaValue<'a>, LuaMultiValue<'a>), ) -> LuaResult> { // Spawn a new detached task using a lua reference that we can use inside of our task let task_lua = lua.app_data_ref::>().unwrap().upgrade().unwrap(); let task_thread = tof_to_thread(lua, tof)?; let task_thread_key = lua.create_registry_value(task_thread)?; + let task_args_key = lua.create_registry_value(args.into_vec())?; let lua_thread_to_return = lua.registry_value(&task_thread_key)?; run_registered_task(lua, TaskRunMode::Deferred, async move { task_wait(&task_lua, duration).await?; - let thread = task_lua.registry_value::(&task_thread_key)?; + let thread: LuaThread = task_lua.registry_value(&task_thread_key)?; + let argsv: Vec = task_lua.registry_value(&task_args_key)?; + let args = LuaMultiValue::from_vec(argsv); if thread.status() == LuaThreadStatus::Resumable { - thread.into_async::<_, LuaMultiValue>(()).await?; + let _: LuaMultiValue = thread.into_async(args).await?; } task_lua.remove_registry_value(task_thread_key)?; + task_lua.remove_registry_value(task_args_key)?; Ok(()) }) .await?; @@ -98,19 +106,23 @@ async fn task_delay<'a>( async fn task_spawn<'a>( lua: &'a Lua, - (tof, _args): (LuaValue<'a>, LuaMultiValue<'a>), + (tof, args): (LuaValue<'a>, LuaMultiValue<'a>), ) -> LuaResult> { // Spawn a new detached task using a lua reference that we can use inside of our task let task_lua = lua.app_data_ref::>().unwrap().upgrade().unwrap(); let task_thread = tof_to_thread(lua, tof)?; let task_thread_key = lua.create_registry_value(task_thread)?; + let task_args_key = lua.create_registry_value(args.into_vec())?; let lua_thread_to_return = lua.registry_value(&task_thread_key)?; run_registered_task(lua, TaskRunMode::Instant, async move { - let thread = task_lua.registry_value::(&task_thread_key)?; + let thread: LuaThread = task_lua.registry_value(&task_thread_key)?; + let argsv: Vec = task_lua.registry_value(&task_args_key)?; + let args = LuaMultiValue::from_vec(argsv); if thread.status() == LuaThreadStatus::Resumable { - thread.into_async::<_, LuaMultiValue>(()).await?; + let _: LuaMultiValue = thread.into_async(args).await?; } task_lua.remove_registry_value(task_thread_key)?; + task_lua.remove_registry_value(task_args_key)?; Ok(()) }) .await?; diff --git a/src/tests/task/defer.luau b/src/tests/task/defer.luau index 4766f18..dceb8de 100644 --- a/src/tests/task/defer.luau +++ b/src/tests/task/defer.luau @@ -40,10 +40,24 @@ assert(not flag2, "Defer should run after spawned threads") -- Varargs should get passed correctly -local function f(arg1: string, arg2: number, f2: (...any) -> ...any) - assert(type(arg1) == "string", "Invalid arg 1 passed to function") - assert(type(arg2) == "number", "Invalid arg 2 passed to function") - assert(type(arg3) == "function", "Invalid arg 3 passed to function") +local function fcheck(index: number, type: string, value: any) + if typeof(value) ~= type then + console.error( + string.format( + "Expected argument #%d to be of type %s, got %s", + index, + type, + console.format(value) + ) + ) + process.exit(1) + end +end + +local function f(...: any) + fcheck(1, "string", select(1, ...)) + fcheck(2, "number", select(2, ...)) + fcheck(3, "function", select(3, ...)) end task.defer(f, "", 1, f) diff --git a/src/tests/task/delay.luau b/src/tests/task/delay.luau index 393cf73..abedffe 100644 --- a/src/tests/task/delay.luau +++ b/src/tests/task/delay.luau @@ -28,10 +28,24 @@ assert(not flag2, "Delay should work with yielding (2)") -- Varargs should get passed correctly -local function f(arg1: string, arg2: number, f2: (...any) -> ...any) - assert(type(arg1) == "string", "Invalid arg 1 passed to function") - assert(type(arg2) == "number", "Invalid arg 2 passed to function") - assert(type(arg3) == "function", "Invalid arg 3 passed to function") +local function fcheck(index: number, type: string, value: any) + if typeof(value) ~= type then + console.error( + string.format( + "Expected argument #%d to be of type %s, got %s", + index, + type, + console.format(value) + ) + ) + process.exit(1) + end +end + +local function f(...: any) + fcheck(1, "string", select(1, ...)) + fcheck(2, "number", select(2, ...)) + fcheck(3, "function", select(3, ...)) end task.delay(0, f, "", 1, f) diff --git a/src/tests/task/spawn.luau b/src/tests/task/spawn.luau index 121981d..08f1ae8 100644 --- a/src/tests/task/spawn.luau +++ b/src/tests/task/spawn.luau @@ -33,10 +33,24 @@ assert(flag3, "Spawn should run threads made from coroutine.create") -- Varargs should get passed correctly -local function f(arg1: string, arg2: number, f2: (...any) -> ...any) - assert(type(arg1) == "string", "Invalid arg 1 passed to function") - assert(type(arg2) == "number", "Invalid arg 2 passed to function") - assert(type(arg3) == "function", "Invalid arg 3 passed to function") +local function fcheck(index: number, type: string, value: any) + if typeof(value) ~= type then + console.error( + string.format( + "Expected argument #%d to be of type %s, got %s", + index, + type, + console.format(value) + ) + ) + process.exit(1) + end +end + +local function f(...: any) + fcheck(1, "string", select(1, ...)) + fcheck(2, "number", select(2, ...)) + fcheck(3, "function", select(3, ...)) end task.spawn(f, "", 1, f)