Re-implement task library

This commit is contained in:
Filip Tibell 2023-01-22 23:00:09 -05:00
parent d50095e974
commit 54846648fd
No known key found for this signature in database
6 changed files with 143 additions and 127 deletions

View file

@ -74,22 +74,22 @@ globals:
- required: false
type: table
# Task
# task.cancel:
# task.defer:
# args:
# - type: thread | function
# - type: "..."
# task.delay:
# args:
# - required: false
# type: number
# - type: thread | function
# - type: "..."
# task.spawn:
# args:
# - type: thread | function
# - type: "..."
# task.wait:
# args:
# - required: false
# type: number
task.cancel:
task.defer:
args:
- type: thread | function
- type: "..."
task.delay:
args:
- required: false
type: number
- type: thread | function
- type: "..."
task.spawn:
args:
- type: thread | function
- type: "..."
task.wait:
args:
- required: false
type: number

View file

@ -52,12 +52,10 @@ declare process: {
},
}
--[[
declare task: {
cancel: (t: thread) -> (),
defer: <A..., R...>(f: thread | (A...) -> (R...), A...) -> (R...),
delay: <A..., R...>(duration: number?, f: thread | (A...) -> (R...), A...) -> (R...),
spawn: <A..., R...>(f: thread | (A...) -> (R...), A...) -> (R...),
defer: <T...>(f: thread | (T...) -> (...any), T...) -> thread,
delay: <T...>(duration: number?, f: thread | (T...) -> (...any), T...) -> thread,
spawn: <T...>(f: thread | (T...) -> (...any), T...) -> thread,
wait: (duration: number?) -> (number),
}
]]

View file

@ -5,80 +5,10 @@ use smol::Timer;
use crate::utils::table_builder::TableBuilder;
const TASK_LIB: &str = include_str!("../luau/task.luau");
pub async fn create(lua: &Lua) -> LuaResult<()> {
lua.globals().raw_set(
"task",
TableBuilder::new(lua)?
.with_async_function("cancel", task_cancel)?
.with_async_function("defer", task_defer)?
.with_async_function("delay", task_delay)?
.with_async_function("spawn", task_spawn)?
.with_async_function("wait", task_wait)?
.build_readonly()?,
)
}
fn get_or_create_thread_from_arg<'a>(lua: &'a Lua, arg: LuaValue<'a>) -> LuaResult<LuaThread<'a>> {
match arg {
LuaValue::Thread(thread) => Ok(thread),
LuaValue::Function(func) => Ok(lua.create_thread(func)?),
val => Err(LuaError::RuntimeError(format!(
"Expected type thread or function, got {}",
val.type_name()
))),
}
}
async fn resume_thread(lua: &Lua, thread: LuaThread<'_>, args: LuaMultiValue<'_>) -> LuaResult<()> {
let coroutine: LuaTable = lua.globals().raw_get("coroutine")?;
let resume: LuaFunction = coroutine.raw_get("resume")?;
// FIXME: This is blocking, we should spawn a local tokio task,
// but doing that moves "thread" and "args", that both have
// the lifetime of the outer function, so it doesn't work
resume.call_async((thread, args)).await?;
Ok(())
}
async fn task_cancel(lua: &Lua, thread: LuaThread<'_>) -> LuaResult<()> {
let coroutine: LuaTable = lua.globals().raw_get("coroutine")?;
let close: LuaFunction = coroutine.raw_get("close")?;
close.call_async(thread).await?;
Ok(())
}
async fn task_defer<'a>(
lua: &'a Lua,
(tof, args): (LuaValue<'a>, LuaMultiValue<'a>),
) -> LuaResult<LuaThread<'a>> {
// TODO: Defer (sleep a minimum amount of time)
let thread = get_or_create_thread_from_arg(lua, tof)?;
resume_thread(lua, thread.clone(), args).await?;
Ok(thread)
}
async fn task_delay<'a>(
lua: &'a Lua,
(_delay, tof, args): (Option<f32>, LuaValue<'a>, LuaMultiValue<'a>),
) -> LuaResult<LuaThread<'a>> {
// TODO: Delay by the amount of time wanted
let thread = get_or_create_thread_from_arg(lua, tof)?;
resume_thread(lua, thread.clone(), args).await?;
Ok(thread)
}
async fn task_spawn<'a>(
lua: &'a Lua,
(tof, args): (LuaValue<'a>, LuaMultiValue<'a>),
) -> LuaResult<LuaThread<'a>> {
let thread = get_or_create_thread_from_arg(lua, tof)?;
resume_thread(lua, thread.clone(), args).await?;
Ok(thread)
}
// FIXME: It doesn't seem possible to properly make an async wait
// function with mlua right now, something breaks when using
// the async wait function inside of a coroutine
async fn task_wait(_: &Lua, duration: Option<f32>) -> LuaResult<f32> {
let wait = lua.create_async_function(move |_, duration: Option<f32>| async move {
let start = Instant::now();
Timer::after(
duration
@ -88,4 +18,20 @@ async fn task_wait(_: &Lua, duration: Option<f32>) -> LuaResult<f32> {
.await;
let end = Instant::now();
Ok((end - start).as_secs_f32())
})?;
let task_lib: LuaTable = lua
.load(TASK_LIB)
.set_name("task")?
.call_async(wait.clone())
.await?;
lua.globals().raw_set(
"task",
TableBuilder::new(lua)?
.with_value("cancel", task_lib.raw_get::<_, LuaFunction>("cancel")?)?
.with_value("defer", task_lib.raw_get::<_, LuaFunction>("defer")?)?
.with_value("delay", task_lib.raw_get::<_, LuaFunction>("delay")?)?
.with_value("spawn", task_lib.raw_get::<_, LuaFunction>("spawn")?)?
.with_value("wait", wait)?
.build_readonly()?,
)
}

View file

@ -2,6 +2,7 @@ use std::collections::HashSet;
use anyhow::{bail, Result};
use mlua::prelude::*;
use smol::LocalExecutor;
pub mod globals;
pub mod utils;
@ -61,8 +62,9 @@ impl Lune {
}
pub async fn run(&self, name: &str, chunk: &str) -> Result<()> {
smol::block_on(async {
let lua = Lua::new();
let exec = LocalExecutor::new();
smol::block_on(exec.run(async {
for global in &self.globals {
match &global {
LuneGlobal::Console => create_console(&lua).await?,
@ -72,7 +74,11 @@ impl Lune {
LuneGlobal::Task => create_task(&lua).await?,
}
}
let result = lua.load(chunk).set_name(name)?.exec_async().await;
let result = lua
.load(chunk)
.set_name(name)?
.call_async::<_, LuaMultiValue>(LuaMultiValue::new())
.await;
match result {
Ok(_) => Ok(()),
Err(e) => {
@ -87,7 +93,7 @@ impl Lune {
}
}
}
})
}))
}
}
@ -139,11 +145,10 @@ mod tests {
net_request_redirect: "net/request/redirect",
net_json_decode: "net/json/decode",
net_json_encode: "net/json/encode",
// FIXME: Re-enable these tests for doing more work on the task library
// task_cancel: "task/cancel",
// task_defer: "task/defer",
// task_delay: "task/delay",
// task_spawn: "task/spawn",
// task_wait: "task/wait",
task_cancel: "task/cancel",
task_defer: "task/defer",
task_delay: "task/delay",
task_spawn: "task/spawn",
task_wait: "task/wait",
}
}

67
src/lib/luau/task.luau Normal file
View file

@ -0,0 +1,67 @@
type ThreadOrFunction = thread | (...any) -> ...any
-- NOTE: The async wait function gets passed in here by task.rs,
-- the same function will then be used for the global task library
local wait: (seconds: number?) -> number = ...
local task = {}
function task.cancel(thread: unknown)
if type(thread) ~= "thread" then
error(string.format("Argument #1 must be a thread, got %s", typeof(thread)))
else
coroutine.close(thread)
end
end
function task.defer(tof: ThreadOrFunction, ...: any): thread
if type(tof) == "thread" then
local thread = coroutine.create(function(...)
wait()
coroutine.resume(tof, ...)
end)
coroutine.resume(thread, ...)
return thread
elseif type(tof) == "function" then
local thread = coroutine.create(tof)
task.defer(thread, ...)
return thread
else
error(string.format("Argument #1 must be a thread or function, got %s", typeof(tof)))
end
end
function task.delay(duration: number?, tof: ThreadOrFunction, ...: any): thread
if duration ~= nil and type(duration) ~= "number" then
error(string.format("Argument #1 must be a number or nil, got %s", typeof(duration)))
end
if type(tof) == "thread" then
local thread = coroutine.create(function(...)
wait(duration)
coroutine.resume(tof, ...)
end)
coroutine.resume(thread, ...)
return thread
elseif type(tof) == "function" then
local thread = coroutine.create(tof)
task.delay(duration, thread, ...)
return thread
else
error(string.format("Argument #2 must be a thread or function, got %s", typeof(tof)))
end
end
function task.spawn(tof: ThreadOrFunction, ...: any): thread
if type(tof) == "thread" then
coroutine.resume(tof, ...)
return tof
elseif type(tof) == "function" then
local thread = coroutine.create(tof)
coroutine.resume(thread, ...)
return thread
else
error(string.format("Argument #1 must be a thread or function, got %s", typeof(tof)))
end
end
return task

View file

@ -37,19 +37,19 @@ measure(1 / 10)
-- Wait should work in other threads, too
local flag: boolean = false
task.spawn(function()
task.wait(0.1)
flag = true
end)
assert(not flag, "Wait failed for a task-spawned thread (1)")
task.wait(0.2)
assert(flag, "Wait failed for a task-spawned thread (2)")
local flag2: boolean = false
coroutine.wrap(function()
task.wait(0.1)
flag2 = true
flag = true
end)()
assert(not flag2, "Wait failed for a coroutine (1)")
assert(not flag, "Wait failed for a coroutine (1)")
task.wait(0.2)
assert(flag2, "Wait failed for a coroutine (2)")
assert(flag, "Wait failed for a coroutine (2)")
local flag2: boolean = false
task.spawn(function()
task.wait(0.1)
flag2 = true
end)
assert(not flag2, "Wait failed for a task-spawned thread (1)")
task.wait(0.2)
assert(flag2, "Wait failed for a task-spawned thread (2)")