From 4cc983dbe6b9b02219809fce3d9ab94e15bd87cd Mon Sep 17 00:00:00 2001 From: Filip Tibell Date: Fri, 17 Feb 2023 21:06:33 +0100 Subject: [PATCH] Add coroutine test suite, implement status --- packages/lib/src/globals/task.rs | 45 +++++++++++----- packages/lib/src/lua/create.rs | 1 + packages/lib/src/lua/task/scheduler.rs | 20 ++++++- packages/lib/src/tests.rs | 1 + tests/globals/coroutine.luau | 74 ++++++++++++++++++++++++++ tests/task/wait.luau | 23 +------- 6 files changed, 130 insertions(+), 34 deletions(-) create mode 100644 tests/globals/coroutine.luau diff --git a/packages/lib/src/globals/task.rs b/packages/lib/src/globals/task.rs index a788934..9907876 100644 --- a/packages/lib/src/globals/task.rs +++ b/packages/lib/src/globals/task.rs @@ -61,6 +61,7 @@ pub fn create(lua: &'static Lua) -> LuaResult> { // calling resume or the function that wrap returns must return // whatever lua value(s) that the thread or task yielded back let coroutine = globals.get::<_, LuaTable>("coroutine")?; + coroutine.set("status", lua.create_function(coroutine_status)?)?; coroutine.set("resume", lua.create_function(coroutine_resume)?)?; coroutine.set("wrap", lua.create_function(coroutine_wrap)?)?; // All good, return the task scheduler lib @@ -127,26 +128,46 @@ fn proxy_typeof<'lua>(lua: &'lua Lua, value: LuaValue<'lua>) -> LuaResult( + lua: &'a Lua, + value: LuaThreadOrTaskReference<'a>, +) -> LuaResult> { + Ok(match value { + LuaThreadOrTaskReference::Thread(thread) => { + let get_status: LuaFunction = lua.named_registry_value("co.status")?; + get_status.call(thread)? + } + LuaThreadOrTaskReference::TaskReference(task) => { + let sched = lua.app_data_ref::<&TaskScheduler>().unwrap(); + sched + .get_task_status(task) + .unwrap_or_else(|| lua.create_string("dead").unwrap()) + } + }) +} + fn coroutine_resume<'lua>( lua: &'lua Lua, value: LuaThreadOrTaskReference, -) -> LuaResult> { - // FIXME: Resume should return true, return vals OR false, error message +) -> LuaResult<(bool, LuaMultiValue<'lua>)> { let sched = lua.app_data_ref::<&TaskScheduler>().unwrap(); - match value { + if sched.current_task().is_none() { + return Err(LuaError::RuntimeError( + "No current task to inherit".to_string(), + )); + } + let current = sched.current_task().unwrap(); + let result = match value { LuaThreadOrTaskReference::Thread(t) => { - if sched.current_task().is_none() { - return Err(LuaError::RuntimeError( - "No current task to inherit".to_string(), - )); - } let task = sched.create_task(TaskKind::Instant, t, None, true)?; - let current = sched.current_task().unwrap(); - let result = sched.resume_task(task, None); - sched.force_set_current_task(Some(current)); - result + sched.resume_task(task, None) } LuaThreadOrTaskReference::TaskReference(t) => sched.resume_task(t, None), + }; + sched.force_set_current_task(Some(current)); + match result { + Ok(rets) => Ok((true, rets)), + Err(e) => Ok((false, e.to_lua_multi(lua)?)), } } diff --git a/packages/lib/src/lua/create.rs b/packages/lib/src/lua/create.rs index 8ba500a..3aff66d 100644 --- a/packages/lib/src/lua/create.rs +++ b/packages/lib/src/lua/create.rs @@ -92,6 +92,7 @@ pub fn create() -> LuaResult<&'static Lua> { lua.set_named_registry_value("pcall", globals.get::<_, LuaFunction>("pcall")?)?; lua.set_named_registry_value("tostring", globals.get::<_, LuaFunction>("tostring")?)?; lua.set_named_registry_value("tonumber", globals.get::<_, LuaFunction>("tonumber")?)?; + lua.set_named_registry_value("co.status", coroutine.get::<_, LuaFunction>("status")?)?; lua.set_named_registry_value("co.yield", coroutine.get::<_, LuaFunction>("yield")?)?; lua.set_named_registry_value("co.close", coroutine.get::<_, LuaFunction>("close")?)?; lua.set_named_registry_value("dbg.info", debug.get::<_, LuaFunction>("info")?)?; diff --git a/packages/lib/src/lua/task/scheduler.rs b/packages/lib/src/lua/task/scheduler.rs index 954ff9f..a0499f6 100644 --- a/packages/lib/src/lua/task/scheduler.rs +++ b/packages/lib/src/lua/task/scheduler.rs @@ -116,11 +116,29 @@ impl<'fut> TaskScheduler<'fut> { /** Returns the currently running task, if any. */ - #[allow(dead_code)] pub fn current_task(&self) -> Option { self.tasks_current.get() } + /** + Returns the status of a specific task, if it exists in the scheduler. + */ + pub fn get_task_status(&self, reference: TaskReference) -> Option { + self.tasks.borrow().get(&reference).map(|task| { + let status: LuaFunction = self + .lua + .named_registry_value("co.status") + .expect("Missing coroutine status function in registry"); + let thread: LuaThread = self + .lua + .registry_value(&task.thread) + .expect("Task thread missing from registry"); + status + .call(thread) + .expect("Task thread failed to call status") + }) + } + /** Creates a new task, storing a new Lua thread for it, as well as the arguments to give the diff --git a/packages/lib/src/tests.rs b/packages/lib/src/tests.rs index 9e46c31..eccebea 100644 --- a/packages/lib/src/tests.rs +++ b/packages/lib/src/tests.rs @@ -60,6 +60,7 @@ create_tests! { require_nested: "globals/require/tests/nested", require_parents: "globals/require/tests/parents", require_siblings: "globals/require/tests/siblings", + global_coroutine: "globals/coroutine", global_pcall: "globals/pcall", global_type: "globals/type", global_typeof: "globals/typeof", diff --git a/tests/globals/coroutine.luau b/tests/globals/coroutine.luau new file mode 100644 index 0000000..f9b8f80 --- /dev/null +++ b/tests/globals/coroutine.luau @@ -0,0 +1,74 @@ +-- Coroutines should return true, ret values OR false, error + +local function pass() + coroutine.yield(1, 2, 3) + coroutine.yield(4, 5, 6) +end + +local function fail() + error("Error message") +end + +local thread1 = coroutine.create(pass) +local t10, t11, t12, t13 = coroutine.resume(thread1) +assert(t10 == true, "Coroutine resume should return true as first value unless errored") +assert(t11 == 1, "Coroutine resume should return values yielded to it (1)") +assert(t12 == 2, "Coroutine resume should return values yielded to it (2)") +assert(t13 == 3, "Coroutine resume should return values yielded to it (3)") + +local thread2 = coroutine.create(fail) +local t20, t21 = coroutine.resume(thread2) +assert(t20 == false, "Coroutine resume should return false as first value when errored") +assert(#tostring(t21) > 0, "Coroutine resume should return error as second if it errors") + +-- Coroutine suspended status should be correct + +assert( + coroutine.status(thread1) == "suspended", + "Coroutine status should return suspended properly" +) +assert(coroutine.status(thread2) == "dead", "Coroutine status should return dead properly") + +-- Coroutines should return values yielded after the first + +local t30, t31, t32, t33 = coroutine.resume(thread1) +assert(t30 == true, "Coroutine resume should return true as first value unless errored") +assert(t31 == 4, "Coroutine resume should return values yielded to it (4)") +assert(t32 == 5, "Coroutine resume should return values yielded to it (5)") +assert(t33 == 6, "Coroutine resume should return values yielded to it (6)") + +local t40, t41 = coroutine.resume(thread1) +assert(t40 == true, "Coroutine resume should return true as first value unless errored") +assert(t41 == nil, "Coroutine resume should return values yielded to it (7)") + +-- Coroutine dead status should be correct after first yielding + +assert(coroutine.status(thread1) == "dead", "Coroutine status should return dead properly") + +-- Resume should error for dead coroutines + +local success1 = coroutine.resume(thread1) +local success2 = coroutine.resume(thread2) + +assert(success1 == false, "Coroutine resume on dead coroutines should return false") +assert(success2 == false, "Coroutine resume on dead coroutines should return false") + +-- Wait should work inside native lua coroutines + +local flag: boolean = false +coroutine.resume(coroutine.create(function() + task.wait(0.1) + flag = true +end)) +assert(not flag, "Wait failed while inside coroutine (1)") +task.wait(0.2) +assert(flag, "Wait failed while inside coroutine (2)") + +local flag2: boolean = false +coroutine.wrap(function() + task.wait(0.1) + flag2 = true +end)() +assert(not flag2, "Wait failed while inside wrap (1)") +task.wait(0.2) +assert(flag2, "Wait failed while inside wrap (2)") diff --git a/tests/task/wait.luau b/tests/task/wait.luau index 437f30e..2023cd3 100644 --- a/tests/task/wait.luau +++ b/tests/task/wait.luau @@ -1,6 +1,6 @@ -- Wait should be accurate down to at least 10ms -local EPSILON = 1 / 100 +local EPSILON = 10 / 1_000 local function test(expected: number) local start = os.clock() @@ -44,8 +44,7 @@ measure(1 / 30) measure(1 / 20) measure(1 / 10) --- Wait should work in other threads, including --- ones created by the built-in coroutine library +-- Wait should work in other threads local flag: boolean = false task.spawn(function() @@ -55,21 +54,3 @@ end) assert(not flag, "Wait failed while inside task-spawned thread (1)") task.wait(0.2) assert(flag, "Wait failed while inside task-spawned thread (2)") - -local flag2: boolean = false -coroutine.resume(coroutine.create(function() - task.wait(0.1) - flag2 = true -end)) -assert(not flag2, "Wait failed while inside coroutine (1)") -task.wait(0.2) -assert(flag2, "Wait failed while inside coroutine (2)") - -local flag3: boolean = false -coroutine.wrap(function() - task.wait(0.1) - flag3 = true -end)() -assert(not flag3, "Wait failed while inside wrap (1)") -task.wait(0.2) -assert(flag3, "Wait failed while inside wrap (2)")