Add coroutine test suite, implement status

This commit is contained in:
Filip Tibell 2023-02-17 21:06:33 +01:00
parent dbe6c18d3a
commit 4cc983dbe6
No known key found for this signature in database
6 changed files with 130 additions and 34 deletions

View file

@ -61,6 +61,7 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable<'static>> {
// calling resume or the function that wrap returns must return // calling resume or the function that wrap returns must return
// whatever lua value(s) that the thread or task yielded back // whatever lua value(s) that the thread or task yielded back
let coroutine = globals.get::<_, LuaTable>("coroutine")?; let coroutine = globals.get::<_, LuaTable>("coroutine")?;
coroutine.set("status", lua.create_function(coroutine_status)?)?;
coroutine.set("resume", lua.create_function(coroutine_resume)?)?; coroutine.set("resume", lua.create_function(coroutine_resume)?)?;
coroutine.set("wrap", lua.create_function(coroutine_wrap)?)?; coroutine.set("wrap", lua.create_function(coroutine_wrap)?)?;
// All good, return the task scheduler lib // All good, return the task scheduler lib
@ -127,26 +128,46 @@ fn proxy_typeof<'lua>(lua: &'lua Lua, value: LuaValue<'lua>) -> LuaResult<LuaStr
Coroutine library overrides for compat with task scheduler Coroutine library overrides for compat with task scheduler
*/ */
fn coroutine_status<'a>(
lua: &'a Lua,
value: LuaThreadOrTaskReference<'a>,
) -> LuaResult<LuaString<'a>> {
Ok(match value {
LuaThreadOrTaskReference::Thread(thread) => {
let get_status: LuaFunction = lua.named_registry_value("co.status")?;
get_status.call(thread)?
}
LuaThreadOrTaskReference::TaskReference(task) => {
let sched = lua.app_data_ref::<&TaskScheduler>().unwrap();
sched
.get_task_status(task)
.unwrap_or_else(|| lua.create_string("dead").unwrap())
}
})
}
fn coroutine_resume<'lua>( fn coroutine_resume<'lua>(
lua: &'lua Lua, lua: &'lua Lua,
value: LuaThreadOrTaskReference, value: LuaThreadOrTaskReference,
) -> LuaResult<LuaMultiValue<'lua>> { ) -> LuaResult<(bool, LuaMultiValue<'lua>)> {
// FIXME: Resume should return true, return vals OR false, error message
let sched = lua.app_data_ref::<&TaskScheduler>().unwrap(); let sched = lua.app_data_ref::<&TaskScheduler>().unwrap();
match value {
LuaThreadOrTaskReference::Thread(t) => {
if sched.current_task().is_none() { if sched.current_task().is_none() {
return Err(LuaError::RuntimeError( return Err(LuaError::RuntimeError(
"No current task to inherit".to_string(), "No current task to inherit".to_string(),
)); ));
} }
let task = sched.create_task(TaskKind::Instant, t, None, true)?;
let current = sched.current_task().unwrap(); let current = sched.current_task().unwrap();
let result = sched.resume_task(task, None); let result = match value {
sched.force_set_current_task(Some(current)); LuaThreadOrTaskReference::Thread(t) => {
result let task = sched.create_task(TaskKind::Instant, t, None, true)?;
sched.resume_task(task, None)
} }
LuaThreadOrTaskReference::TaskReference(t) => sched.resume_task(t, None), LuaThreadOrTaskReference::TaskReference(t) => sched.resume_task(t, None),
};
sched.force_set_current_task(Some(current));
match result {
Ok(rets) => Ok((true, rets)),
Err(e) => Ok((false, e.to_lua_multi(lua)?)),
} }
} }

View file

@ -92,6 +92,7 @@ pub fn create() -> LuaResult<&'static Lua> {
lua.set_named_registry_value("pcall", globals.get::<_, LuaFunction>("pcall")?)?; lua.set_named_registry_value("pcall", globals.get::<_, LuaFunction>("pcall")?)?;
lua.set_named_registry_value("tostring", globals.get::<_, LuaFunction>("tostring")?)?; lua.set_named_registry_value("tostring", globals.get::<_, LuaFunction>("tostring")?)?;
lua.set_named_registry_value("tonumber", globals.get::<_, LuaFunction>("tonumber")?)?; lua.set_named_registry_value("tonumber", globals.get::<_, LuaFunction>("tonumber")?)?;
lua.set_named_registry_value("co.status", coroutine.get::<_, LuaFunction>("status")?)?;
lua.set_named_registry_value("co.yield", coroutine.get::<_, LuaFunction>("yield")?)?; lua.set_named_registry_value("co.yield", coroutine.get::<_, LuaFunction>("yield")?)?;
lua.set_named_registry_value("co.close", coroutine.get::<_, LuaFunction>("close")?)?; lua.set_named_registry_value("co.close", coroutine.get::<_, LuaFunction>("close")?)?;
lua.set_named_registry_value("dbg.info", debug.get::<_, LuaFunction>("info")?)?; lua.set_named_registry_value("dbg.info", debug.get::<_, LuaFunction>("info")?)?;

View file

@ -116,11 +116,29 @@ impl<'fut> TaskScheduler<'fut> {
/** /**
Returns the currently running task, if any. Returns the currently running task, if any.
*/ */
#[allow(dead_code)]
pub fn current_task(&self) -> Option<TaskReference> { pub fn current_task(&self) -> Option<TaskReference> {
self.tasks_current.get() self.tasks_current.get()
} }
/**
Returns the status of a specific task, if it exists in the scheduler.
*/
pub fn get_task_status(&self, reference: TaskReference) -> Option<LuaString> {
self.tasks.borrow().get(&reference).map(|task| {
let status: LuaFunction = self
.lua
.named_registry_value("co.status")
.expect("Missing coroutine status function in registry");
let thread: LuaThread = self
.lua
.registry_value(&task.thread)
.expect("Task thread missing from registry");
status
.call(thread)
.expect("Task thread failed to call status")
})
}
/** /**
Creates a new task, storing a new Lua thread Creates a new task, storing a new Lua thread
for it, as well as the arguments to give the for it, as well as the arguments to give the

View file

@ -60,6 +60,7 @@ create_tests! {
require_nested: "globals/require/tests/nested", require_nested: "globals/require/tests/nested",
require_parents: "globals/require/tests/parents", require_parents: "globals/require/tests/parents",
require_siblings: "globals/require/tests/siblings", require_siblings: "globals/require/tests/siblings",
global_coroutine: "globals/coroutine",
global_pcall: "globals/pcall", global_pcall: "globals/pcall",
global_type: "globals/type", global_type: "globals/type",
global_typeof: "globals/typeof", global_typeof: "globals/typeof",

View file

@ -0,0 +1,74 @@
-- Coroutines should return true, ret values OR false, error
local function pass()
coroutine.yield(1, 2, 3)
coroutine.yield(4, 5, 6)
end
local function fail()
error("Error message")
end
local thread1 = coroutine.create(pass)
local t10, t11, t12, t13 = coroutine.resume(thread1)
assert(t10 == true, "Coroutine resume should return true as first value unless errored")
assert(t11 == 1, "Coroutine resume should return values yielded to it (1)")
assert(t12 == 2, "Coroutine resume should return values yielded to it (2)")
assert(t13 == 3, "Coroutine resume should return values yielded to it (3)")
local thread2 = coroutine.create(fail)
local t20, t21 = coroutine.resume(thread2)
assert(t20 == false, "Coroutine resume should return false as first value when errored")
assert(#tostring(t21) > 0, "Coroutine resume should return error as second if it errors")
-- Coroutine suspended status should be correct
assert(
coroutine.status(thread1) == "suspended",
"Coroutine status should return suspended properly"
)
assert(coroutine.status(thread2) == "dead", "Coroutine status should return dead properly")
-- Coroutines should return values yielded after the first
local t30, t31, t32, t33 = coroutine.resume(thread1)
assert(t30 == true, "Coroutine resume should return true as first value unless errored")
assert(t31 == 4, "Coroutine resume should return values yielded to it (4)")
assert(t32 == 5, "Coroutine resume should return values yielded to it (5)")
assert(t33 == 6, "Coroutine resume should return values yielded to it (6)")
local t40, t41 = coroutine.resume(thread1)
assert(t40 == true, "Coroutine resume should return true as first value unless errored")
assert(t41 == nil, "Coroutine resume should return values yielded to it (7)")
-- Coroutine dead status should be correct after first yielding
assert(coroutine.status(thread1) == "dead", "Coroutine status should return dead properly")
-- Resume should error for dead coroutines
local success1 = coroutine.resume(thread1)
local success2 = coroutine.resume(thread2)
assert(success1 == false, "Coroutine resume on dead coroutines should return false")
assert(success2 == false, "Coroutine resume on dead coroutines should return false")
-- Wait should work inside native lua coroutines
local flag: boolean = false
coroutine.resume(coroutine.create(function()
task.wait(0.1)
flag = true
end))
assert(not flag, "Wait failed while inside coroutine (1)")
task.wait(0.2)
assert(flag, "Wait failed while inside coroutine (2)")
local flag2: boolean = false
coroutine.wrap(function()
task.wait(0.1)
flag2 = true
end)()
assert(not flag2, "Wait failed while inside wrap (1)")
task.wait(0.2)
assert(flag2, "Wait failed while inside wrap (2)")

View file

@ -1,6 +1,6 @@
-- Wait should be accurate down to at least 10ms -- Wait should be accurate down to at least 10ms
local EPSILON = 1 / 100 local EPSILON = 10 / 1_000
local function test(expected: number) local function test(expected: number)
local start = os.clock() local start = os.clock()
@ -44,8 +44,7 @@ measure(1 / 30)
measure(1 / 20) measure(1 / 20)
measure(1 / 10) measure(1 / 10)
-- Wait should work in other threads, including -- Wait should work in other threads
-- ones created by the built-in coroutine library
local flag: boolean = false local flag: boolean = false
task.spawn(function() task.spawn(function()
@ -55,21 +54,3 @@ end)
assert(not flag, "Wait failed while inside task-spawned thread (1)") assert(not flag, "Wait failed while inside task-spawned thread (1)")
task.wait(0.2) task.wait(0.2)
assert(flag, "Wait failed while inside task-spawned thread (2)") assert(flag, "Wait failed while inside task-spawned thread (2)")
local flag2: boolean = false
coroutine.resume(coroutine.create(function()
task.wait(0.1)
flag2 = true
end))
assert(not flag2, "Wait failed while inside coroutine (1)")
task.wait(0.2)
assert(flag2, "Wait failed while inside coroutine (2)")
local flag3: boolean = false
coroutine.wrap(function()
task.wait(0.1)
flag3 = true
end)()
assert(not flag3, "Wait failed while inside wrap (1)")
task.wait(0.2)
assert(flag3, "Wait failed while inside wrap (2)")