Re-implement task library

This commit is contained in:
Filip Tibell 2023-01-22 23:00:09 -05:00
parent d50095e974
commit 54846648fd
No known key found for this signature in database
6 changed files with 143 additions and 127 deletions

View file

@ -74,22 +74,22 @@ globals:
- required: false - required: false
type: table type: table
# Task # Task
# task.cancel: task.cancel:
# task.defer: task.defer:
# args: args:
# - type: thread | function - type: thread | function
# - type: "..." - type: "..."
# task.delay: task.delay:
# args: args:
# - required: false - required: false
# type: number type: number
# - type: thread | function - type: thread | function
# - type: "..." - type: "..."
# task.spawn: task.spawn:
# args: args:
# - type: thread | function - type: thread | function
# - type: "..." - type: "..."
# task.wait: task.wait:
# args: args:
# - required: false - required: false
# type: number type: number

View file

@ -52,12 +52,10 @@ declare process: {
}, },
} }
--[[ declare task: {
declare task: { cancel: (t: thread) -> (),
cancel: (t: thread) -> (), defer: <T...>(f: thread | (T...) -> (...any), T...) -> thread,
defer: <A..., R...>(f: thread | (A...) -> (R...), A...) -> (R...), delay: <T...>(duration: number?, f: thread | (T...) -> (...any), T...) -> thread,
delay: <A..., R...>(duration: number?, f: thread | (A...) -> (R...), A...) -> (R...), spawn: <T...>(f: thread | (T...) -> (...any), T...) -> thread,
spawn: <A..., R...>(f: thread | (A...) -> (R...), A...) -> (R...), wait: (duration: number?) -> (number),
wait: (duration: number?) -> (number), }
}
]]

View file

@ -5,87 +5,33 @@ use smol::Timer;
use crate::utils::table_builder::TableBuilder; use crate::utils::table_builder::TableBuilder;
const TASK_LIB: &str = include_str!("../luau/task.luau");
pub async fn create(lua: &Lua) -> LuaResult<()> { pub async fn create(lua: &Lua) -> LuaResult<()> {
let wait = lua.create_async_function(move |_, duration: Option<f32>| async move {
let start = Instant::now();
Timer::after(
duration
.map(Duration::from_secs_f32)
.unwrap_or(Duration::ZERO),
)
.await;
let end = Instant::now();
Ok((end - start).as_secs_f32())
})?;
let task_lib: LuaTable = lua
.load(TASK_LIB)
.set_name("task")?
.call_async(wait.clone())
.await?;
lua.globals().raw_set( lua.globals().raw_set(
"task", "task",
TableBuilder::new(lua)? TableBuilder::new(lua)?
.with_async_function("cancel", task_cancel)? .with_value("cancel", task_lib.raw_get::<_, LuaFunction>("cancel")?)?
.with_async_function("defer", task_defer)? .with_value("defer", task_lib.raw_get::<_, LuaFunction>("defer")?)?
.with_async_function("delay", task_delay)? .with_value("delay", task_lib.raw_get::<_, LuaFunction>("delay")?)?
.with_async_function("spawn", task_spawn)? .with_value("spawn", task_lib.raw_get::<_, LuaFunction>("spawn")?)?
.with_async_function("wait", task_wait)? .with_value("wait", wait)?
.build_readonly()?, .build_readonly()?,
) )
} }
fn get_or_create_thread_from_arg<'a>(lua: &'a Lua, arg: LuaValue<'a>) -> LuaResult<LuaThread<'a>> {
match arg {
LuaValue::Thread(thread) => Ok(thread),
LuaValue::Function(func) => Ok(lua.create_thread(func)?),
val => Err(LuaError::RuntimeError(format!(
"Expected type thread or function, got {}",
val.type_name()
))),
}
}
async fn resume_thread(lua: &Lua, thread: LuaThread<'_>, args: LuaMultiValue<'_>) -> LuaResult<()> {
let coroutine: LuaTable = lua.globals().raw_get("coroutine")?;
let resume: LuaFunction = coroutine.raw_get("resume")?;
// FIXME: This is blocking, we should spawn a local tokio task,
// but doing that moves "thread" and "args", that both have
// the lifetime of the outer function, so it doesn't work
resume.call_async((thread, args)).await?;
Ok(())
}
async fn task_cancel(lua: &Lua, thread: LuaThread<'_>) -> LuaResult<()> {
let coroutine: LuaTable = lua.globals().raw_get("coroutine")?;
let close: LuaFunction = coroutine.raw_get("close")?;
close.call_async(thread).await?;
Ok(())
}
async fn task_defer<'a>(
lua: &'a Lua,
(tof, args): (LuaValue<'a>, LuaMultiValue<'a>),
) -> LuaResult<LuaThread<'a>> {
// TODO: Defer (sleep a minimum amount of time)
let thread = get_or_create_thread_from_arg(lua, tof)?;
resume_thread(lua, thread.clone(), args).await?;
Ok(thread)
}
async fn task_delay<'a>(
lua: &'a Lua,
(_delay, tof, args): (Option<f32>, LuaValue<'a>, LuaMultiValue<'a>),
) -> LuaResult<LuaThread<'a>> {
// TODO: Delay by the amount of time wanted
let thread = get_or_create_thread_from_arg(lua, tof)?;
resume_thread(lua, thread.clone(), args).await?;
Ok(thread)
}
async fn task_spawn<'a>(
lua: &'a Lua,
(tof, args): (LuaValue<'a>, LuaMultiValue<'a>),
) -> LuaResult<LuaThread<'a>> {
let thread = get_or_create_thread_from_arg(lua, tof)?;
resume_thread(lua, thread.clone(), args).await?;
Ok(thread)
}
// FIXME: It doesn't seem possible to properly make an async wait
// function with mlua right now, something breaks when using
// the async wait function inside of a coroutine
async fn task_wait(_: &Lua, duration: Option<f32>) -> LuaResult<f32> {
let start = Instant::now();
Timer::after(
duration
.map(Duration::from_secs_f32)
.unwrap_or(Duration::ZERO),
)
.await;
let end = Instant::now();
Ok((end - start).as_secs_f32())
}

View file

@ -2,6 +2,7 @@ use std::collections::HashSet;
use anyhow::{bail, Result}; use anyhow::{bail, Result};
use mlua::prelude::*; use mlua::prelude::*;
use smol::LocalExecutor;
pub mod globals; pub mod globals;
pub mod utils; pub mod utils;
@ -61,8 +62,9 @@ impl Lune {
} }
pub async fn run(&self, name: &str, chunk: &str) -> Result<()> { pub async fn run(&self, name: &str, chunk: &str) -> Result<()> {
smol::block_on(async { let lua = Lua::new();
let lua = Lua::new(); let exec = LocalExecutor::new();
smol::block_on(exec.run(async {
for global in &self.globals { for global in &self.globals {
match &global { match &global {
LuneGlobal::Console => create_console(&lua).await?, LuneGlobal::Console => create_console(&lua).await?,
@ -72,7 +74,11 @@ impl Lune {
LuneGlobal::Task => create_task(&lua).await?, LuneGlobal::Task => create_task(&lua).await?,
} }
} }
let result = lua.load(chunk).set_name(name)?.exec_async().await; let result = lua
.load(chunk)
.set_name(name)?
.call_async::<_, LuaMultiValue>(LuaMultiValue::new())
.await;
match result { match result {
Ok(_) => Ok(()), Ok(_) => Ok(()),
Err(e) => { Err(e) => {
@ -87,7 +93,7 @@ impl Lune {
} }
} }
} }
}) }))
} }
} }
@ -139,11 +145,10 @@ mod tests {
net_request_redirect: "net/request/redirect", net_request_redirect: "net/request/redirect",
net_json_decode: "net/json/decode", net_json_decode: "net/json/decode",
net_json_encode: "net/json/encode", net_json_encode: "net/json/encode",
// FIXME: Re-enable these tests for doing more work on the task library task_cancel: "task/cancel",
// task_cancel: "task/cancel", task_defer: "task/defer",
// task_defer: "task/defer", task_delay: "task/delay",
// task_delay: "task/delay", task_spawn: "task/spawn",
// task_spawn: "task/spawn", task_wait: "task/wait",
// task_wait: "task/wait",
} }
} }

67
src/lib/luau/task.luau Normal file
View file

@ -0,0 +1,67 @@
type ThreadOrFunction = thread | (...any) -> ...any
-- NOTE: The async wait function gets passed in here by task.rs,
-- the same function will then be used for the global task library
local wait: (seconds: number?) -> number = ...
local task = {}
function task.cancel(thread: unknown)
if type(thread) ~= "thread" then
error(string.format("Argument #1 must be a thread, got %s", typeof(thread)))
else
coroutine.close(thread)
end
end
function task.defer(tof: ThreadOrFunction, ...: any): thread
if type(tof) == "thread" then
local thread = coroutine.create(function(...)
wait()
coroutine.resume(tof, ...)
end)
coroutine.resume(thread, ...)
return thread
elseif type(tof) == "function" then
local thread = coroutine.create(tof)
task.defer(thread, ...)
return thread
else
error(string.format("Argument #1 must be a thread or function, got %s", typeof(tof)))
end
end
function task.delay(duration: number?, tof: ThreadOrFunction, ...: any): thread
if duration ~= nil and type(duration) ~= "number" then
error(string.format("Argument #1 must be a number or nil, got %s", typeof(duration)))
end
if type(tof) == "thread" then
local thread = coroutine.create(function(...)
wait(duration)
coroutine.resume(tof, ...)
end)
coroutine.resume(thread, ...)
return thread
elseif type(tof) == "function" then
local thread = coroutine.create(tof)
task.delay(duration, thread, ...)
return thread
else
error(string.format("Argument #2 must be a thread or function, got %s", typeof(tof)))
end
end
function task.spawn(tof: ThreadOrFunction, ...: any): thread
if type(tof) == "thread" then
coroutine.resume(tof, ...)
return tof
elseif type(tof) == "function" then
local thread = coroutine.create(tof)
coroutine.resume(thread, ...)
return thread
else
error(string.format("Argument #1 must be a thread or function, got %s", typeof(tof)))
end
end
return task

View file

@ -37,19 +37,19 @@ measure(1 / 10)
-- Wait should work in other threads, too -- Wait should work in other threads, too
local flag: boolean = false local flag: boolean = false
task.spawn(function()
task.wait(0.1)
flag = true
end)
assert(not flag, "Wait failed for a task-spawned thread (1)")
task.wait(0.2)
assert(flag, "Wait failed for a task-spawned thread (2)")
local flag2: boolean = false
coroutine.wrap(function() coroutine.wrap(function()
task.wait(0.1) task.wait(0.1)
flag2 = true flag = true
end)() end)()
assert(not flag2, "Wait failed for a coroutine (1)") assert(not flag, "Wait failed for a coroutine (1)")
task.wait(0.2) task.wait(0.2)
assert(flag2, "Wait failed for a coroutine (2)") assert(flag, "Wait failed for a coroutine (2)")
local flag2: boolean = false
task.spawn(function()
task.wait(0.1)
flag2 = true
end)
assert(not flag2, "Wait failed for a task-spawned thread (1)")
task.wait(0.2)
assert(flag2, "Wait failed for a task-spawned thread (2)")