Re-implement task library, take two

This commit is contained in:
Filip Tibell 2023-01-23 02:38:32 -05:00
parent 54846648fd
commit 19829d7cf4
No known key found for this signature in database
8 changed files with 289 additions and 136 deletions

View file

@ -5,7 +5,7 @@ use crate::utils::{
table_builder::TableBuilder, table_builder::TableBuilder,
}; };
pub async fn create(lua: &Lua) -> LuaResult<()> { pub fn create(lua: &Lua) -> LuaResult<()> {
let print = |args: &LuaMultiValue, throw: bool| -> LuaResult<()> { let print = |args: &LuaMultiValue, throw: bool| -> LuaResult<()> {
let s = pretty_format_multi_value(args)?; let s = pretty_format_multi_value(args)?;
if throw { if throw {

View file

@ -5,7 +5,7 @@ use smol::{fs, prelude::*};
use crate::utils::table_builder::TableBuilder; use crate::utils::table_builder::TableBuilder;
pub async fn create(lua: &Lua) -> LuaResult<()> { pub fn create(lua: &Lua) -> LuaResult<()> {
lua.globals().raw_set( lua.globals().raw_set(
"fs", "fs",
TableBuilder::new(lua)? TableBuilder::new(lua)?

View file

@ -4,7 +4,7 @@ use mlua::prelude::*;
use crate::utils::{net::get_request_user_agent_header, table_builder::TableBuilder}; use crate::utils::{net::get_request_user_agent_header, table_builder::TableBuilder};
pub async fn create(lua: &Lua) -> LuaResult<()> { pub fn create(lua: &Lua) -> LuaResult<()> {
lua.globals().raw_set( lua.globals().raw_set(
"net", "net",
TableBuilder::new(lua)? TableBuilder::new(lua)?

View file

@ -9,7 +9,7 @@ use smol::process::Command;
use crate::utils::table_builder::TableBuilder; use crate::utils::table_builder::TableBuilder;
pub async fn create(lua: &Lua, args_vec: Vec<String>) -> LuaResult<()> { pub fn create(lua: &Lua, args_vec: Vec<String>) -> LuaResult<()> {
// Create readonly args array // Create readonly args array
let args_tab = TableBuilder::new(lua)? let args_tab = TableBuilder::new(lua)?
.with_sequential_values(args_vec)? .with_sequential_values(args_vec)?

View file

@ -1,37 +1,220 @@
use std::time::{Duration, Instant}; // TODO: Figure out a good way to remove all the shared boilerplate from these functions
use std::{
sync::Weak,
time::{Duration, Instant},
};
use mlua::prelude::*; use mlua::prelude::*;
use smol::Timer; use smol::{channel::Sender, LocalExecutor, Timer};
use crate::utils::table_builder::TableBuilder; use crate::{utils::table_builder::TableBuilder, LuneMessage};
const TASK_LIB: &str = include_str!("../luau/task.luau"); pub fn create(lua: &Lua) -> LuaResult<()> {
pub async fn create(lua: &Lua) -> LuaResult<()> {
let wait = lua.create_async_function(move |_, duration: Option<f32>| async move {
let start = Instant::now();
Timer::after(
duration
.map(Duration::from_secs_f32)
.unwrap_or(Duration::ZERO),
)
.await;
let end = Instant::now();
Ok((end - start).as_secs_f32())
})?;
let task_lib: LuaTable = lua
.load(TASK_LIB)
.set_name("task")?
.call_async(wait.clone())
.await?;
lua.globals().raw_set( lua.globals().raw_set(
"task", "task",
TableBuilder::new(lua)? TableBuilder::new(lua)?
.with_value("cancel", task_lib.raw_get::<_, LuaFunction>("cancel")?)? .with_async_function("cancel", task_cancel)?
.with_value("defer", task_lib.raw_get::<_, LuaFunction>("defer")?)? .with_async_function("delay", task_delay)?
.with_value("delay", task_lib.raw_get::<_, LuaFunction>("delay")?)? .with_async_function("defer", task_defer)?
.with_value("spawn", task_lib.raw_get::<_, LuaFunction>("spawn")?)? .with_async_function("spawn", task_spawn)?
.with_value("wait", wait)? .with_async_function("wait", task_wait)?
.build_readonly()?, .build_readonly()?,
) )
} }
fn tof_to_thread<'a>(lua: &'a Lua, tof: LuaValue<'a>) -> LuaResult<LuaThread<'a>> {
match tof {
LuaValue::Thread(t) => Ok(t),
LuaValue::Function(f) => Ok(lua.create_thread(f)?),
value => Err(LuaError::RuntimeError(format!(
"Argument must be a thread or function, got {}",
value.type_name()
))),
}
}
async fn task_cancel<'a>(lua: &'a Lua, thread: LuaThread<'a>) -> LuaResult<()> {
let coroutine: LuaTable = lua.globals().raw_get("coroutine")?;
let close: LuaFunction = coroutine.raw_get("close")?;
close.call_async(thread).await?;
Ok(())
}
async fn task_defer<'a>(task_lua: &'a Lua, tof: LuaValue<'a>) -> LuaResult<LuaThread<'a>> {
// Boilerplate to get arc-ed lua & async executor
let lua = task_lua
.app_data_ref::<Weak<Lua>>()
.unwrap()
.upgrade()
.unwrap();
let exec = task_lua
.app_data_ref::<Weak<LocalExecutor>>()
.unwrap()
.upgrade()
.unwrap();
let sender = task_lua
.app_data_ref::<Weak<Sender<LuneMessage>>>()
.unwrap()
.upgrade()
.unwrap();
// Spawn a new detached thread
sender
.send(LuneMessage::Spawned)
.await
.map_err(LuaError::external)?;
let thread = tof_to_thread(&lua, tof)?;
let thread_key = lua.create_registry_value(thread)?;
let thread_to_return = task_lua.registry_value(&thread_key)?;
let thread_sender = sender.clone();
exec.spawn(async move {
let result = async {
task_wait(&lua, None).await?;
let thread = lua.registry_value::<LuaThread>(&thread_key)?;
if thread.status() == LuaThreadStatus::Resumable {
thread.into_async::<_, LuaMultiValue>(()).await?;
}
Ok::<_, LuaError>(())
};
thread_sender
.send(match result.await {
Ok(_) => LuneMessage::Finished,
Err(e) => LuneMessage::LuaError(e),
})
.await
})
.detach();
sender
.send(LuneMessage::Finished)
.await
.map_err(LuaError::external)?;
Ok(thread_to_return)
}
async fn task_delay<'a>(
task_lua: &'a Lua,
(duration, tof): (Option<f32>, LuaValue<'a>),
) -> LuaResult<LuaThread<'a>> {
// Boilerplate to get arc-ed lua & async executor
let lua = task_lua
.app_data_ref::<Weak<Lua>>()
.unwrap()
.upgrade()
.unwrap();
let exec = task_lua
.app_data_ref::<Weak<LocalExecutor>>()
.unwrap()
.upgrade()
.unwrap();
let sender = task_lua
.app_data_ref::<Weak<Sender<LuneMessage>>>()
.unwrap()
.upgrade()
.unwrap();
// Spawn a new detached thread
sender
.send(LuneMessage::Spawned)
.await
.map_err(LuaError::external)?;
let thread = tof_to_thread(&lua, tof)?;
let thread_key = lua.create_registry_value(thread)?;
let thread_to_return = task_lua.registry_value(&thread_key)?;
let thread_sender = sender.clone();
exec.spawn(async move {
let result = async {
task_wait(&lua, duration).await?;
let thread = lua.registry_value::<LuaThread>(&thread_key)?;
if thread.status() == LuaThreadStatus::Resumable {
thread.into_async::<_, LuaMultiValue>(()).await?;
}
Ok::<_, LuaError>(())
};
thread_sender
.send(match result.await {
Ok(_) => LuneMessage::Finished,
Err(e) => LuneMessage::LuaError(e),
})
.await
})
.detach();
sender
.send(LuneMessage::Finished)
.await
.map_err(LuaError::external)?;
Ok(thread_to_return)
}
async fn task_spawn<'a>(task_lua: &'a Lua, tof: LuaValue<'a>) -> LuaResult<LuaThread<'a>> {
// Boilerplate to get arc-ed lua & async executor
let lua = task_lua
.app_data_ref::<Weak<Lua>>()
.unwrap()
.upgrade()
.unwrap();
let exec = task_lua
.app_data_ref::<Weak<LocalExecutor>>()
.unwrap()
.upgrade()
.unwrap();
let sender = task_lua
.app_data_ref::<Weak<Sender<LuneMessage>>>()
.unwrap()
.upgrade()
.unwrap();
// Spawn a new detached thread
sender
.send(LuneMessage::Spawned)
.await
.map_err(LuaError::external)?;
let thread = tof_to_thread(&lua, tof)?;
let thread_key = lua.create_registry_value(thread)?;
let thread_to_return = task_lua.registry_value(&thread_key)?;
let thread_sender = sender.clone();
// FIXME: This does not run the thread instantly
exec.spawn(async move {
let result = async {
let thread = lua.registry_value::<LuaThread>(&thread_key)?;
if thread.status() == LuaThreadStatus::Resumable {
thread.into_async::<_, LuaMultiValue>(()).await?;
}
Ok::<_, LuaError>(())
};
thread_sender
.send(match result.await {
Ok(_) => LuneMessage::Finished,
Err(e) => LuneMessage::LuaError(e),
})
.await
})
.detach();
sender
.send(LuneMessage::Finished)
.await
.map_err(LuaError::external)?;
Ok(thread_to_return)
}
async fn task_wait(lua: &Lua, duration: Option<f32>) -> LuaResult<f32> {
let sender = lua
.app_data_ref::<Weak<Sender<LuneMessage>>>()
.unwrap()
.upgrade()
.unwrap();
sender
.send(LuneMessage::Spawned)
.await
.map_err(LuaError::external)?;
let start = Instant::now();
Timer::after(
duration
.map(Duration::from_secs_f32)
.unwrap_or(Duration::ZERO),
)
.await;
let end = Instant::now();
sender
.send(LuneMessage::Finished)
.await
.map_err(LuaError::external)?;
Ok((end - start).as_secs_f32())
}

View file

@ -1,6 +1,6 @@
use std::collections::HashSet; use std::{collections::HashSet, sync::Arc};
use anyhow::{bail, Result}; use anyhow::{anyhow, bail, Result};
use mlua::prelude::*; use mlua::prelude::*;
use smol::LocalExecutor; use smol::LocalExecutor;
@ -33,6 +33,13 @@ impl LuneGlobal {
} }
} }
pub(crate) enum LuneMessage {
Spawned,
Finished,
Error(anyhow::Error),
LuaError(mlua::Error),
}
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default)]
pub struct Lune { pub struct Lune {
globals: HashSet<LuneGlobal>, globals: HashSet<LuneGlobal>,
@ -62,37 +69,76 @@ impl Lune {
} }
pub async fn run(&self, name: &str, chunk: &str) -> Result<()> { pub async fn run(&self, name: &str, chunk: &str) -> Result<()> {
let lua = Lua::new(); let (s, r) = smol::channel::unbounded::<LuneMessage>();
let exec = LocalExecutor::new(); let lua = Arc::new(mlua::Lua::new());
smol::block_on(exec.run(async { let exec = Arc::new(LocalExecutor::new());
for global in &self.globals { let sender = Arc::new(s);
match &global { let receiver = Arc::new(r);
LuneGlobal::Console => create_console(&lua).await?, lua.set_app_data(Arc::downgrade(&lua));
LuneGlobal::Fs => create_fs(&lua).await?, lua.set_app_data(Arc::downgrade(&exec));
LuneGlobal::Net => create_net(&lua).await?, lua.set_app_data(Arc::downgrade(&sender));
LuneGlobal::Process => create_process(&lua, self.args.clone()).await?, lua.set_app_data(Arc::downgrade(&receiver));
LuneGlobal::Task => create_task(&lua).await?, // Add in wanted lune globals
} for global in &self.globals {
match &global {
LuneGlobal::Console => create_console(&lua)?,
LuneGlobal::Fs => create_fs(&lua)?,
LuneGlobal::Net => create_net(&lua)?,
LuneGlobal::Process => create_process(&lua, self.args.clone())?,
LuneGlobal::Task => create_task(&lua)?,
} }
}
// Spawn the main thread from our entrypoint script
let script_name = name.to_string();
let script_chunk = chunk.to_string();
exec.spawn(async move {
sender.send(LuneMessage::Spawned).await?;
let result = lua let result = lua
.load(chunk) .load(&script_chunk)
.set_name(name)? .set_name(&script_name)
.unwrap()
.call_async::<_, LuaMultiValue>(LuaMultiValue::new()) .call_async::<_, LuaMultiValue>(LuaMultiValue::new())
.await; .await;
match result { let message = match result {
Ok(_) => Ok(()), Ok(_) => LuneMessage::Finished,
Err(e) => { Err(e) => LuneMessage::Error(if cfg!(test) {
if cfg!(test) { anyhow!("{}", pretty_format_luau_error(&e))
bail!(pretty_format_luau_error(&e)) } else {
} else { anyhow!(
bail!( "\n{}\n{}",
"\n{}\n{}", format_label("ERROR"),
format_label("ERROR"), pretty_format_luau_error(&e)
pretty_format_luau_error(&e) )
) }),
};
sender.send(message).await
})
.detach();
// Run the executor until there are no tasks left
let mut task_count = 1;
smol::block_on(exec.run(async {
while let Ok(message) = receiver.recv().await {
match message {
LuneMessage::Spawned => {
task_count += 1;
}
LuneMessage::Finished => {
task_count -= 1;
if task_count <= 0 {
break;
}
}
LuneMessage::Error(e) => {
task_count -= 1;
bail!("{}", e)
}
LuneMessage::LuaError(e) => {
task_count -= 1;
bail!("{}", e)
} }
} }
} }
Ok(())
})) }))
} }
} }

View file

@ -1,67 +0,0 @@
type ThreadOrFunction = thread | (...any) -> ...any
-- NOTE: The async wait function gets passed in here by task.rs,
-- the same function will then be used for the global task library
local wait: (seconds: number?) -> number = ...
local task = {}
function task.cancel(thread: unknown)
if type(thread) ~= "thread" then
error(string.format("Argument #1 must be a thread, got %s", typeof(thread)))
else
coroutine.close(thread)
end
end
function task.defer(tof: ThreadOrFunction, ...: any): thread
if type(tof) == "thread" then
local thread = coroutine.create(function(...)
wait()
coroutine.resume(tof, ...)
end)
coroutine.resume(thread, ...)
return thread
elseif type(tof) == "function" then
local thread = coroutine.create(tof)
task.defer(thread, ...)
return thread
else
error(string.format("Argument #1 must be a thread or function, got %s", typeof(tof)))
end
end
function task.delay(duration: number?, tof: ThreadOrFunction, ...: any): thread
if duration ~= nil and type(duration) ~= "number" then
error(string.format("Argument #1 must be a number or nil, got %s", typeof(duration)))
end
if type(tof) == "thread" then
local thread = coroutine.create(function(...)
wait(duration)
coroutine.resume(tof, ...)
end)
coroutine.resume(thread, ...)
return thread
elseif type(tof) == "function" then
local thread = coroutine.create(tof)
task.delay(duration, thread, ...)
return thread
else
error(string.format("Argument #2 must be a thread or function, got %s", typeof(tof)))
end
end
function task.spawn(tof: ThreadOrFunction, ...: any): thread
if type(tof) == "thread" then
coroutine.resume(tof, ...)
return tof
elseif type(tof) == "function" then
local thread = coroutine.create(tof)
coroutine.resume(thread, ...)
return thread
else
error(string.format("Argument #1 must be a thread or function, got %s", typeof(tof)))
end
end
return task

View file

@ -37,19 +37,10 @@ measure(1 / 10)
-- Wait should work in other threads, too -- Wait should work in other threads, too
local flag: boolean = false local flag: boolean = false
coroutine.wrap(function()
task.wait(0.1)
flag = true
end)()
assert(not flag, "Wait failed for a coroutine (1)")
task.wait(0.2)
assert(flag, "Wait failed for a coroutine (2)")
local flag2: boolean = false
task.spawn(function() task.spawn(function()
task.wait(0.1) task.wait(0.1)
flag2 = true flag = true
end) end)
assert(not flag2, "Wait failed for a task-spawned thread (1)") assert(not flag, "Wait failed for a task-spawned thread (1)")
task.wait(0.2) task.wait(0.2)
assert(flag2, "Wait failed for a task-spawned thread (2)") assert(flag, "Wait failed for a task-spawned thread (2)")