Initial implementation of proper task scheduler, no async yet

This commit is contained in:
Filip Tibell 2023-02-13 15:28:18 +01:00
parent bb182033b9
commit fc5de3c8d5
No known key found for this signature in database
23 changed files with 685 additions and 430 deletions

View file

@ -22,7 +22,7 @@ mod tests;
use cli::Cli;
#[tokio::main]
#[tokio::main(flavor = "multi_thread")]
async fn main() -> Result<ExitCode> {
Cli::parse().run().await
}

View file

@ -36,20 +36,20 @@ async fn ensure_file_exists_and_is_not_json(file_name: &str) -> Result<()> {
}
}
#[tokio::test]
#[tokio::test(flavor = "multi_thread")]
async fn list() -> Result<()> {
Cli::list().run().await?;
Ok(())
}
#[tokio::test]
#[tokio::test(flavor = "multi_thread")]
async fn download_selene_types() -> Result<()> {
run_cli(Cli::download_selene_types()).await?;
ensure_file_exists_and_is_not_json(LUNE_SELENE_FILE_NAME).await?;
Ok(())
}
#[tokio::test]
#[tokio::test(flavor = "multi_thread")]
async fn download_luau_types() -> Result<()> {
run_cli(Cli::download_luau_types()).await?;
ensure_file_exists_and_is_not_json(LUNE_LUAU_FILE_NAME).await?;

View file

@ -2,6 +2,8 @@ use std::fmt::{Display, Formatter, Result as FmtResult};
use mlua::prelude::*;
use crate::lua::task::TaskScheduler;
mod fs;
mod net;
mod process;
@ -78,14 +80,18 @@ impl LuneGlobal {
Note that proxy globals should be handled with special care and that [`LuneGlobal::inject()`]
should be preferred over manually creating and manipulating the value(s) of any Lune global.
*/
pub fn value(&self, lua: &'static Lua) -> LuaResult<LuaTable> {
pub fn value(
&self,
lua: &'static Lua,
scheduler: &'static TaskScheduler,
) -> LuaResult<LuaTable> {
match self {
LuneGlobal::Fs => fs::create(lua),
LuneGlobal::Net => net::create(lua),
LuneGlobal::Process { args } => process::create(lua, args.clone()),
LuneGlobal::Require => require::create(lua),
LuneGlobal::Stdio => stdio::create(lua),
LuneGlobal::Task => task::create(lua),
LuneGlobal::Task => task::create(lua, scheduler),
LuneGlobal::TopLevel => top_level::create(lua),
}
}
@ -98,9 +104,9 @@ impl LuneGlobal {
Refer to [`LuneGlobal::is_top_level()`] for more info on proxy globals.
*/
pub fn inject(self, lua: &'static Lua) -> LuaResult<()> {
pub fn inject(self, lua: &'static Lua, scheduler: &'static TaskScheduler) -> LuaResult<()> {
let globals = lua.globals();
let table = self.value(lua)?;
let table = self.value(lua, scheduler)?;
// NOTE: Top level globals are special, the values
// *in* the table they return should be set directly,
// instead of setting the table itself as the global

View file

@ -17,10 +17,7 @@ use tokio::{sync::mpsc, task};
use crate::{
lua::net::{NetClient, NetClientBuilder, NetWebSocketClient, NetWebSocketServer, ServeConfig},
utils::{
message::LuneMessage, net::get_request_user_agent_header, table::TableBuilder,
task::send_message,
},
utils::{net::get_request_user_agent_header, table::TableBuilder},
};
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
@ -29,7 +26,7 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
let client = NetClientBuilder::new()
.headers(&[("User-Agent", get_request_user_agent_header())])?
.build()?;
lua.set_named_registry_value("NetClient", client)?;
lua.set_named_registry_value("net.client", client)?;
// Create the global table for net
TableBuilder::new(lua)?
.with_function("jsonEncode", net_json_encode)?
@ -54,7 +51,7 @@ fn net_json_decode(lua: &'static Lua, json: String) -> LuaResult<LuaValue> {
}
async fn net_request<'a>(lua: &'static Lua, config: LuaValue<'a>) -> LuaResult<LuaTable<'a>> {
let client: NetClient = lua.named_registry_value("NetClient")?;
let client: NetClient = lua.named_registry_value("net.client")?;
// Extract stuff from config and make sure its all valid
let (url, method, headers, body) = match config {
LuaValue::String(s) => {
@ -177,20 +174,9 @@ async fn net_serve<'a>(
shutdown_rx.recv().await.unwrap();
shutdown_rx.close();
});
// Make sure we register the thread properly by sending messages
// when the server starts up and when it shuts down or errors
send_message(lua, LuneMessage::Spawned).await?;
task::spawn_local(async move {
let res = server.await.map_err(LuaError::external);
let _ = send_message(
lua,
match res {
Err(e) => LuneMessage::LuaError(e),
Ok(_) => LuneMessage::Finished,
},
)
.await;
});
// TODO: Spawn a new scheduler future with this so we don't block
// and make sure that we register it properly to prevent shutdown
server.await.map_err(LuaError::external)?;
// Create a new read-only table that contains methods
// for manipulating server behavior and shutting it down
let handle_stop = move |_, _: ()| {
@ -313,8 +299,8 @@ impl Service<Request<Body>> for NetService {
Ok(resp.body(body).unwrap())
}
// If the handler returns an error, generate a 5xx response
Err(err) => {
send_message(lua, LuneMessage::LuaError(err.to_lua_err())).await?;
Err(_) => {
// TODO: Send above error to task scheduler so that it can emit properly
Ok(Response::builder()
.status(500)
.body(Body::from("Internal Server Error"))
@ -323,10 +309,11 @@ impl Service<Request<Body>> for NetService {
// If the handler returns a value that is of an invalid type,
// this should also be an error, so generate a 5xx response
Ok(value) => {
send_message(lua, LuneMessage::LuaError(LuaError::RuntimeError(format!(
// TODO: Send below error to task scheduler so that it can emit properly
let _ = LuaError::RuntimeError(format!(
"Expected net serve handler to return a value of type 'string' or 'table', got '{}'",
value.type_name()
)))).await?;
));
Ok(Response::builder()
.status(500)
.body(Body::from("Internal Server Error"))

View file

@ -1,21 +1,35 @@
use std::{collections::HashMap, env, path::PathBuf, process::Stdio};
use std::{
collections::HashMap,
env,
path::PathBuf,
process::{ExitCode, Stdio},
};
use directories::UserDirs;
use mlua::prelude::*;
use os_str_bytes::RawOsString;
use tokio::process::Command;
use crate::utils::{
process::{exit_and_yield_forever, pipe_and_inherit_child_process_stdio},
table::TableBuilder,
use crate::{
lua::task::TaskScheduler,
utils::{process::pipe_and_inherit_child_process_stdio, table::TableBuilder},
};
const PROCESS_EXIT_IMPL_LUA: &str = r#"
exit(...)
yield()
"#;
pub fn create(lua: &'static Lua, args_vec: Vec<String>) -> LuaResult<LuaTable> {
let cwd_str = {
let cwd = env::current_dir()?.canonicalize()?;
let mut cwd_str = cwd.to_string_lossy().to_string();
let cwd_str = cwd.to_string_lossy().to_string();
if !cwd_str.ends_with('/') {
cwd_str = format!("{cwd_str}/");
format!("{cwd_str}/")
} else {
cwd_str
}
};
// Create readonly args array
let args_tab = TableBuilder::new(lua)?
.with_sequential_values(args_vec)?
@ -30,12 +44,31 @@ pub fn create(lua: &'static Lua, args_vec: Vec<String>) -> LuaResult<LuaTable> {
.build_readonly()?,
)?
.build_readonly()?;
// Create our process exit function, this is a bit involved since
// we have no way to yield from c / rust, we need to load a lua
// chunk that will set the exit code and yield for us instead
let process_exit_env_yield: LuaFunction = lua.named_registry_value("co.yield")?;
let process_exit_env_exit: LuaFunction = lua.create_function(|lua, code: Option<u8>| {
let exit_code = code.map_or(ExitCode::SUCCESS, ExitCode::from);
let sched = &mut lua.app_data_mut::<&TaskScheduler>().unwrap();
sched.set_exit_code(exit_code);
Ok(())
})?;
let process_exit = lua
.load(PROCESS_EXIT_IMPL_LUA)
.set_environment(
TableBuilder::new(lua)?
.with_value("yield", process_exit_env_yield)?
.with_value("exit", process_exit_env_exit)?
.build_readonly()?,
)?
.into_function()?;
// Create the full process table
TableBuilder::new(lua)?
.with_value("args", args_tab)?
.with_value("cwd", cwd_str)?
.with_value("env", env_tab)?
.with_async_function("exit", process_exit)?
.with_value("exit", process_exit)?
.with_async_function("spawn", process_spawn)?
.build_readonly()
}
@ -109,10 +142,6 @@ fn process_env_iter<'lua>(
})
}
async fn process_exit(lua: &'static Lua, exit_code: Option<u8>) -> LuaResult<()> {
exit_and_yield_forever(lua, exit_code).await
}
async fn process_spawn<'a>(
lua: &'static Lua,
(mut program, args, options): (String, Option<Vec<String>>, Option<LuaTable<'a>>),

View file

@ -10,12 +10,10 @@ use os_str_bytes::{OsStrBytes, RawOsStr};
use crate::utils::table::TableBuilder;
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
let require: LuaFunction = lua.globals().raw_get("require")?;
// Preserve original require behavior if we have a special env var set
if env::var_os("LUAU_PWD_REQUIRE").is_some() {
return TableBuilder::new(lua)?
.with_value("require", require)?
.build_readonly();
// Return an empty table since there are no globals to overwrite
return TableBuilder::new(lua)?.build_readonly();
}
/*
Store the current working directory so that we can use it later
@ -27,24 +25,17 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
just in case someone out there uses luau with non-utf8 string requires
*/
let pwd = lua.create_string(&current_dir()?.to_raw_bytes())?;
lua.set_named_registry_value("require_pwd", pwd)?;
// Fetch the debug info function and store it in the registry
// - we will use it to fetch the current scripts file name
let debug: LuaTable = lua.globals().raw_get("debug")?;
let info: LuaFunction = debug.raw_get("info")?;
lua.set_named_registry_value("require_getinfo", info)?;
// Store the original require function in the registry
lua.set_named_registry_value("require_original", require)?;
lua.set_named_registry_value("pwd", pwd)?;
/*
Create a new function that fetches the file name from the current thread,
sets the luau module lookup path to be the exact script we are looking
for, and then runs the original require function with the wanted path
*/
let new_require = lua.create_function(|lua, require_path: LuaString| {
let require_pwd: LuaString = lua.named_registry_value("require_pwd")?;
let require_original: LuaFunction = lua.named_registry_value("require_original")?;
let require_getinfo: LuaFunction = lua.named_registry_value("require_getinfo")?;
let require_source: LuaString = require_getinfo.call((2, "s"))?;
let require_pwd: LuaString = lua.named_registry_value("pwd")?;
let require_fn: LuaFunction = lua.named_registry_value("require")?;
let require_info: LuaFunction = lua.named_registry_value("dbg.info")?;
let require_source: LuaString = require_info.call((2, "s"))?;
/*
Combine the require caller source with the wanted path
string to get a final path relative to pwd - it is definitely
@ -53,7 +44,12 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
let raw_pwd_str = RawOsStr::assert_from_raw_bytes(require_pwd.as_bytes());
let raw_source = RawOsStr::assert_from_raw_bytes(require_source.as_bytes());
let raw_path = RawOsStr::assert_from_raw_bytes(require_path.as_bytes());
let mut path_relative_to_pwd = PathBuf::from(&raw_source.to_os_str())
let mut path_relative_to_pwd = PathBuf::from(
&raw_source
.trim_start_matches("[string \"")
.trim_end_matches("\"]")
.to_os_str(),
)
.parent()
.unwrap()
.join(raw_path.to_os_str());
@ -72,7 +68,7 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
let lua_path_str = lua.create_string(raw_path_str.as_raw_bytes());
// If the require call errors then we should also replace
// the path in the error message to improve user experience
let result: LuaResult<_> = require_original.call::<_, LuaValue>(lua_path_str);
let result: LuaResult<_> = require_fn.call::<_, LuaValue>(lua_path_str);
match result {
Err(LuaError::CallbackError { traceback, cause }) => {
let before = format!(

View file

@ -1,150 +1,84 @@
use std::time::{Duration, Instant};
use mlua::prelude::*;
use tokio::time;
use crate::utils::{
table::TableBuilder,
task::{run_registered_task, TaskRunMode},
use crate::{
lua::task::{TaskReference, TaskScheduler},
utils::table::TableBuilder,
};
const MINIMUM_WAIT_OR_DELAY_DURATION: f32 = 10.0 / 1_000.0; // 10ms
const TASK_WAIT_IMPL_LUA: &str = r#"
resume_after(thread(), ...)
return yield()
"#;
// TODO: We should probably keep track of all threads in a scheduler userdata
// that takes care of scheduling in a better way, and it should keep resuming
// threads until it encounters a delayed / waiting thread, then task:sleep
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
// HACK: There is no way to call coroutine.close directly from the mlua
// crate, so we need to fetch the function and store it in the registry
let coroutine: LuaTable = lua.globals().raw_get("coroutine")?;
let close: LuaFunction = coroutine.raw_get("close")?;
lua.set_named_registry_value("coroutine.close", close)?;
// HACK: coroutine.resume has some weird scheduling issues, but our custom
// task.spawn implementation is more or less a replacement for it, so we
// overwrite the original coroutine.resume function with it to fix that
coroutine.raw_set("resume", lua.create_async_function(task_spawn)?)?;
// Rest of the task library is normal, just async functions, no metatable
pub fn create(
lua: &'static Lua,
scheduler: &'static TaskScheduler,
) -> LuaResult<LuaTable<'static>> {
lua.set_app_data(scheduler);
// Create task spawning functions that add tasks to the scheduler
let task_spawn = lua.create_function(|lua, (tof, args): (LuaValue, LuaMultiValue)| {
let sched = &mut lua.app_data_mut::<&TaskScheduler>().unwrap();
sched.schedule_instant(tof, args)
})?;
let task_defer = lua.create_function(|lua, (tof, args): (LuaValue, LuaMultiValue)| {
let sched = &mut lua.app_data_mut::<&TaskScheduler>().unwrap();
sched.schedule_deferred(tof, args)
})?;
let task_delay =
lua.create_function(|lua, (secs, tof, args): (f64, LuaValue, LuaMultiValue)| {
let sched = &mut lua.app_data_mut::<&TaskScheduler>().unwrap();
sched.schedule_delayed(secs, tof, args)
})?;
// Create our task wait function, this is a bit different since
// we have no way to yield from c / rust, we need to load a
// lua chunk that schedules and yields for us instead
let task_wait_env_thread: LuaFunction = lua.named_registry_value("co.thread")?;
let task_wait_env_yield: LuaFunction = lua.named_registry_value("co.yield")?;
let task_wait = lua
.load(TASK_WAIT_IMPL_LUA)
.set_environment(
TableBuilder::new(lua)?
.with_async_function("cancel", task_cancel)?
.with_async_function("delay", task_delay)?
.with_async_function("defer", task_defer)?
.with_async_function("spawn", task_spawn)?
.with_async_function("wait", task_wait)?
.with_value("thread", task_wait_env_thread)?
.with_value("yield", task_wait_env_yield)?
.with_function(
"resume_after",
|lua, (thread, secs): (LuaThread, Option<f64>)| {
let sched = &mut lua.app_data_mut::<&TaskScheduler>().unwrap();
sched.resume_after(secs.unwrap_or(0f64), thread)
},
)?
.build_readonly()?,
)?
.into_function()?;
// We want the task scheduler to be transparent,
// but it does not return real lua threads, so
// we need to override some globals to fake it
let globals = lua.globals();
let type_original: LuaFunction = globals.get("type")?;
let type_proxy = lua.create_function(move |_, value: LuaValue| {
if let LuaValue::UserData(u) = &value {
if u.is::<TaskReference>() {
return Ok(LuaValue::String(lua.create_string("thread")?));
}
}
type_original.call(value)
})?;
let typeof_original: LuaFunction = globals.get("typeof")?;
let typeof_proxy = lua.create_function(move |_, value: LuaValue| {
if let LuaValue::UserData(u) = &value {
if u.is::<TaskReference>() {
return Ok(LuaValue::String(lua.create_string("thread")?));
}
}
typeof_original.call(value)
})?;
globals.set("type", type_proxy)?;
globals.set("typeof", typeof_proxy)?;
// All good, return the task scheduler lib
TableBuilder::new(lua)?
.with_value("spawn", task_spawn)?
.with_value("defer", task_defer)?
.with_value("delay", task_delay)?
.with_value("wait", task_wait)?
.build_readonly()
}
fn tof_to_thread<'a>(
lua: &'static Lua,
thread_or_function: LuaValue<'a>,
) -> LuaResult<LuaThread<'a>> {
match thread_or_function {
LuaValue::Thread(t) => Ok(t),
LuaValue::Function(f) => Ok(lua.create_thread(f)?),
value => Err(LuaError::RuntimeError(format!(
"Argument must be a thread or function, got {}",
value.type_name()
))),
}
}
async fn task_cancel<'a>(lua: &'static Lua, thread: LuaThread<'a>) -> LuaResult<()> {
let close: LuaFunction = lua.named_registry_value("coroutine.close")?;
close.call_async::<_, LuaMultiValue>(thread).await?;
Ok(())
}
async fn task_defer<'a>(
lua: &'static Lua,
(tof, args): (LuaValue<'a>, LuaMultiValue<'a>),
) -> LuaResult<LuaThread<'a>> {
let task_thread = tof_to_thread(lua, tof)?;
let task_thread_key = lua.create_registry_value(task_thread)?;
let task_args_key = lua.create_registry_value(args.into_vec())?;
let lua_thread_to_return = lua.registry_value(&task_thread_key)?;
run_registered_task(lua, TaskRunMode::Deferred, async move {
let thread: LuaThread = lua.registry_value(&task_thread_key)?;
let argsv: Vec<LuaValue> = lua.registry_value(&task_args_key)?;
let args = LuaMultiValue::from_vec(argsv);
if thread.status() == LuaThreadStatus::Resumable {
let _: LuaMultiValue = thread.into_async(args).await?;
}
lua.remove_registry_value(task_thread_key)?;
lua.remove_registry_value(task_args_key)?;
Ok(())
})
.await?;
Ok(lua_thread_to_return)
}
async fn task_delay<'a>(
lua: &'static Lua,
(duration, tof, args): (Option<f32>, LuaValue<'a>, LuaMultiValue<'a>),
) -> LuaResult<LuaThread<'a>> {
let task_thread = tof_to_thread(lua, tof)?;
let task_thread_key = lua.create_registry_value(task_thread)?;
let task_args_key = lua.create_registry_value(args.into_vec())?;
let lua_thread_to_return = lua.registry_value(&task_thread_key)?;
let start = Instant::now();
let dur = Duration::from_secs_f32(
duration
.map(|d| d.max(MINIMUM_WAIT_OR_DELAY_DURATION))
.unwrap_or(MINIMUM_WAIT_OR_DELAY_DURATION),
);
run_registered_task(lua, TaskRunMode::Instant, async move {
let thread: LuaThread = lua.registry_value(&task_thread_key)?;
// NOTE: We are somewhat busy-waiting here, but we have to do this to make sure
// that delayed+cancelled threads do not prevent the tokio runtime from finishing
while thread.status() == LuaThreadStatus::Resumable && start.elapsed() < dur {
time::sleep(Duration::from_millis(1)).await;
}
if thread.status() == LuaThreadStatus::Resumable {
let argsv: Vec<LuaValue> = lua.registry_value(&task_args_key)?;
let args = LuaMultiValue::from_vec(argsv);
let _: LuaMultiValue = thread.into_async(args).await?;
}
lua.remove_registry_value(task_thread_key)?;
lua.remove_registry_value(task_args_key)?;
Ok(())
})
.await?;
Ok(lua_thread_to_return)
}
async fn task_spawn<'a>(
lua: &'static Lua,
(tof, args): (LuaValue<'a>, LuaMultiValue<'a>),
) -> LuaResult<LuaThread<'a>> {
let task_thread = tof_to_thread(lua, tof)?;
let task_thread_key = lua.create_registry_value(task_thread)?;
let task_args_key = lua.create_registry_value(args.into_vec())?;
let lua_thread_to_return = lua.registry_value(&task_thread_key)?;
run_registered_task(lua, TaskRunMode::Instant, async move {
let thread: LuaThread = lua.registry_value(&task_thread_key)?;
let argsv: Vec<LuaValue> = lua.registry_value(&task_args_key)?;
let args = LuaMultiValue::from_vec(argsv);
if thread.status() == LuaThreadStatus::Resumable {
let _: LuaMultiValue = thread.into_async(args).await?;
}
lua.remove_registry_value(task_thread_key)?;
lua.remove_registry_value(task_args_key)?;
Ok(())
})
.await?;
Ok(lua_thread_to_return)
}
async fn task_wait(lua: &'static Lua, duration: Option<f32>) -> LuaResult<f32> {
let start = Instant::now();
run_registered_task(lua, TaskRunMode::Blocking, async move {
time::sleep(Duration::from_secs_f32(
duration
.map(|d| d.max(MINIMUM_WAIT_OR_DELAY_DURATION))
.unwrap_or(MINIMUM_WAIT_OR_DELAY_DURATION),
))
.await;
Ok(())
})
.await?;
let end = Instant::now();
Ok((end - start).as_secs_f32())
}

View file

@ -6,15 +6,10 @@ use crate::utils::{
};
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
let globals = lua.globals();
// HACK: We need to preserve the default behavior of the
// print and error functions, for pcall and such, which
// is really tricky to do from scratch so we will just
// proxy the default print and error functions here
let print_fn: LuaFunction = globals.raw_get("print")?;
let error_fn: LuaFunction = globals.raw_get("error")?;
lua.set_named_registry_value("print", print_fn)?;
lua.set_named_registry_value("error", error_fn)?;
TableBuilder::new(lua)?
.with_function("print", |lua, args: LuaMultiValue| {
let formatted = pretty_format_multi_value(&args)?;

View file

@ -1,8 +1,8 @@
use std::{collections::HashSet, process::ExitCode, sync::Arc};
use std::{collections::HashSet, process::ExitCode};
use lua::task::TaskScheduler;
use mlua::prelude::*;
use tokio::{sync::mpsc, task};
use utils::task::send_message;
use tokio::task::LocalSet;
pub(crate) mod globals;
pub(crate) mod lua;
@ -11,7 +11,7 @@ pub(crate) mod utils;
#[cfg(test)]
mod tests;
use crate::utils::{formatting::pretty_format_luau_error, message::LuneMessage};
use crate::utils::formatting::pretty_format_luau_error;
pub use globals::LuneGlobal;
@ -75,12 +75,12 @@ impl Lune {
This will create a new sandboxed Luau environment with the configured
globals and arguments, running inside of a [`tokio::task::LocalSet`].
Some Lune globals such as [`LuneGlobal::Process`] may spawn
separate tokio tasks on other threads, but the Luau environment
Some Lune globals such as [`LuneGlobal::Process`] and [`LuneGlobal::Net`]
may spawn separate tokio tasks on other threads, but the Luau environment
itself is guaranteed to run on a single thread in the local set.
Note that this will create a static Lua instance that will live
for the remainer of the program, and that this leaks memory using
Note that this will create a static Lua instance and task scheduler which both
will live for the remainer of the program, and that this leaks memory using
[`Box::leak`] that will then get deallocated when the program exits.
*/
pub async fn run(
@ -88,92 +88,64 @@ impl Lune {
script_name: &str,
script_contents: &str,
) -> Result<ExitCode, LuaError> {
let task_set = task::LocalSet::new();
let (sender, mut receiver) = mpsc::channel::<LuneMessage>(64);
let set = LocalSet::new();
let lua = Lua::new().into_static();
let snd = Arc::new(sender);
lua.set_app_data(Arc::downgrade(&snd));
let sched = TaskScheduler::new(lua)?.into_static();
lua.set_app_data(sched);
// Store original lua global functions in the registry so we can use
// them later without passing them around and dealing with lifetimes
lua.set_named_registry_value("require", lua.globals().get::<_, LuaFunction>("require")?)?;
lua.set_named_registry_value("print", lua.globals().get::<_, LuaFunction>("print")?)?;
lua.set_named_registry_value("error", lua.globals().get::<_, LuaFunction>("error")?)?;
let coroutine: LuaTable = lua.globals().get("coroutine")?;
lua.set_named_registry_value("co.thread", coroutine.get::<_, LuaFunction>("running")?)?;
lua.set_named_registry_value("co.yield", coroutine.get::<_, LuaFunction>("yield")?)?;
let debug: LuaTable = lua.globals().raw_get("debug")?;
lua.set_named_registry_value("dbg.info", debug.get::<_, LuaFunction>("info")?)?;
// Add in wanted lune globals
for global in self.includes.clone() {
if !self.excludes.contains(&global) {
global.inject(lua)?;
global.inject(lua, sched)?;
}
}
// Spawn the main thread from our entrypoint script
let script_name = script_name.to_string();
let script_chunk = script_contents.to_string();
send_message(lua, LuneMessage::Spawned).await?;
task_set.spawn_local(async move {
let result = lua
.load(&script_chunk)
.set_name(&format!("={script_name}"))
// Schedule the main thread on the task scheduler
sched.schedule_instant(
LuaValue::Function(
lua.load(script_contents)
.set_name(script_name)
.unwrap()
.eval_async::<LuaValue>()
.await;
send_message(
lua,
match result {
Err(e) => LuneMessage::LuaError(e),
Ok(_) => LuneMessage::Finished,
},
)
.await
});
// Run the executor until there are no tasks left,
// taking care to not exit right away for errors
let (got_code, got_error, exit_code) = task_set
.into_function()
.unwrap(),
),
LuaValue::Nil.to_lua_multi(lua)?,
)?;
// Keep running the scheduler until there are either no tasks
// left to run, or until some task requests to exit the process
let exit_code = set
.run_until(async {
let mut task_count = 0;
let mut got_error = false;
let mut got_code = false;
let mut exit_code = 0;
while let Some(message) = receiver.recv().await {
// Make sure our task-count-modifying messages are sent correctly, one
// task spawned must always correspond to one task finished / errored
match &message {
LuneMessage::Exit(_) => {}
LuneMessage::Spawned => {}
message => {
if task_count == 0 {
return Err(format!(
"Got message while task count was 0!\nMessage: {message:#?}"
));
}
}
}
// Handle whatever message we got
match message {
LuneMessage::Exit(code) => {
exit_code = code;
got_code = true;
break;
}
LuneMessage::Spawned => task_count += 1,
LuneMessage::Finished => task_count -= 1,
LuneMessage::LuaError(e) => {
while let Some(result) = sched.resume_queue().await {
match result {
Err(e) => {
eprintln!("{}", pretty_format_luau_error(&e));
got_error = true;
task_count -= 1;
}
};
// If there are no tasks left running, it is now
// safe to close the receiver and end execution
if task_count == 0 {
receiver.close();
Ok(status) => {
if let Some(exit_code) = status.exit_code {
return exit_code;
} else if status.num_total == 0 {
return ExitCode::SUCCESS;
}
}
Ok((got_code, got_error, exit_code))
})
.await
.map_err(LuaError::external)?;
// If we got an error, we will default to exiting
// with code 1, unless a code was manually given
if got_code {
Ok(ExitCode::from(exit_code))
} else if got_error {
Ok(ExitCode::FAILURE)
}
}
if got_error {
ExitCode::FAILURE
} else {
Ok(ExitCode::SUCCESS)
}
ExitCode::SUCCESS
}
})
.await;
Ok(exit_code)
}
}

View file

@ -1 +1,2 @@
pub mod net;
pub mod task;

View file

@ -0,0 +1,3 @@
mod scheduler;
pub use scheduler::*;

View file

@ -0,0 +1,406 @@
use std::{
collections::{HashMap, VecDeque},
fmt,
process::ExitCode,
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc, Mutex,
},
time::Duration,
};
use mlua::prelude::*;
use tokio::time::{sleep, Instant};
type TaskSchedulerQueue = Arc<Mutex<VecDeque<TaskReference>>>;
/// An enum representing different kinds of tasks
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum TaskKind {
Instant,
Deferred,
Yielded,
}
impl fmt::Display for TaskKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let name: &'static str = match self {
TaskKind::Instant => "Instant",
TaskKind::Deferred => "Deferred",
TaskKind::Yielded => "Yielded",
};
write!(f, "{name}")
}
}
/// A lightweight, clonable struct that represents a
/// task in the scheduler and is accessible from Lua
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct TaskReference {
kind: TaskKind,
guid: usize,
queued_target: Option<Instant>,
}
impl TaskReference {
pub const fn new(kind: TaskKind, guid: usize, queued_target: Option<Instant>) -> Self {
Self {
kind,
guid,
queued_target,
}
}
}
impl fmt::Display for TaskReference {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "TaskReference({} - {})", self.kind, self.guid)
}
}
impl LuaUserData for TaskReference {}
impl From<&Task> for TaskReference {
fn from(value: &Task) -> Self {
Self::new(value.kind, value.guid, value.queued_target)
}
}
/// A struct representing a task contained in the task scheduler
#[derive(Debug)]
pub struct Task {
kind: TaskKind,
guid: usize,
thread: LuaRegistryKey,
args: LuaRegistryKey,
queued_at: Instant,
queued_target: Option<Instant>,
}
/// A struct representing the current status of the task scheduler
#[derive(Debug, Clone, Copy)]
pub struct TaskSchedulerStatus {
pub exit_code: Option<ExitCode>,
pub num_instant: usize,
pub num_deferred: usize,
pub num_yielded: usize,
pub num_total: usize,
}
impl fmt::Display for TaskSchedulerStatus {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"TaskSchedulerStatus(\nInstant: {}\nDeferred: {}\nYielded: {}\nTotal: {})",
self.num_instant, self.num_deferred, self.num_yielded, self.num_total
)
}
}
/// A task scheduler that implements task queues
/// with instant, deferred, and delayed tasks
#[derive(Debug)]
pub struct TaskScheduler {
lua: &'static Lua,
guid: AtomicUsize,
running: bool,
tasks: Arc<Mutex<HashMap<TaskReference, Task>>>,
task_queue_instant: TaskSchedulerQueue,
task_queue_deferred: TaskSchedulerQueue,
task_queue_yielded: TaskSchedulerQueue,
exit_code_set: AtomicBool,
exit_code: Arc<Mutex<ExitCode>>,
}
impl TaskScheduler {
pub fn new(lua: &'static Lua) -> LuaResult<Self> {
Ok(Self {
lua,
guid: AtomicUsize::new(0),
running: false,
tasks: Arc::new(Mutex::new(HashMap::new())),
task_queue_instant: Arc::new(Mutex::new(VecDeque::new())),
task_queue_deferred: Arc::new(Mutex::new(VecDeque::new())),
task_queue_yielded: Arc::new(Mutex::new(VecDeque::new())),
exit_code_set: AtomicBool::new(false),
exit_code: Arc::new(Mutex::new(ExitCode::SUCCESS)),
})
}
pub fn into_static(self) -> &'static Self {
Box::leak(Box::new(self))
}
pub fn status(&self) -> TaskSchedulerStatus {
let counts = {
(
self.task_queue_instant.lock().unwrap().len(),
self.task_queue_deferred.lock().unwrap().len(),
self.task_queue_yielded.lock().unwrap().len(),
)
};
let num_total = counts.0 + counts.1 + counts.2;
let exit_code = if self.exit_code_set.load(Ordering::Relaxed) {
Some(*self.exit_code.lock().unwrap())
} else {
None
};
TaskSchedulerStatus {
exit_code,
num_instant: counts.0,
num_deferred: counts.1,
num_yielded: counts.2,
num_total,
}
}
pub fn set_exit_code(&self, code: ExitCode) {
self.exit_code_set.store(true, Ordering::Relaxed);
*self.exit_code.lock().unwrap() = code
}
fn schedule<'a>(
&self,
kind: TaskKind,
tof: LuaValue<'a>,
args: Option<LuaMultiValue<'a>>,
delay: Option<f64>,
) -> LuaResult<TaskReference> {
// Get or create a thread from the given argument
let task_thread = match tof {
LuaValue::Thread(t) => t,
LuaValue::Function(f) => self.lua.create_thread(f)?,
value => {
return Err(LuaError::RuntimeError(format!(
"Argument must be a thread or function, got {}",
value.type_name()
)))
}
};
// Store the thread and its arguments in the registry
let task_args_vec = args.map(|opt| opt.into_vec());
let task_thread_key = self.lua.create_registry_value(task_thread)?;
let task_args_key = self.lua.create_registry_value(task_args_vec)?;
// Create the full task struct
let guid = self.guid.fetch_add(1, Ordering::Relaxed) + 1;
let queued_at = Instant::now();
let queued_target = delay.map(|secs| queued_at + Duration::from_secs_f64(secs));
let task = Task {
kind,
guid,
thread: task_thread_key,
args: task_args_key,
queued_at,
queued_target,
};
// Create the task ref (before adding the task to the scheduler)
let task_ref = TaskReference::from(&task);
// Add it to the scheduler
{
let mut tasks = self.tasks.lock().unwrap();
tasks.insert(task_ref, task);
}
match kind {
TaskKind::Instant => {
// If we have a currently running task and we spawned an
// instant task here it should run right after the currently
// running task, so put it at the front of the task queue
let mut queue = self.task_queue_instant.lock().unwrap();
if self.running {
queue.push_front(task_ref);
} else {
queue.push_back(task_ref);
}
}
TaskKind::Deferred => {
// Deferred tasks should always schedule
// at the very end of the deferred queue
let mut queue = self.task_queue_deferred.lock().unwrap();
queue.push_back(task_ref);
}
TaskKind::Yielded => {
// Find the first task that is scheduled after this one and insert before it,
// this will ensure that our list of delayed tasks is sorted and we can grab
// the very first one to figure out how long to yield until the next cycle
let mut queue = self.task_queue_yielded.lock().unwrap();
let idx = queue
.iter()
.enumerate()
.find_map(|(idx, t)| {
if t.queued_target > queued_target {
Some(idx)
} else {
None
}
})
.unwrap_or(queue.len());
queue.insert(idx, task_ref);
}
}
Ok(task_ref)
}
pub fn schedule_instant<'a>(
&self,
tof: LuaValue<'a>,
args: LuaMultiValue<'a>,
) -> LuaResult<TaskReference> {
self.schedule(TaskKind::Instant, tof, Some(args), None)
}
pub fn schedule_deferred<'a>(
&self,
tof: LuaValue<'a>,
args: LuaMultiValue<'a>,
) -> LuaResult<TaskReference> {
self.schedule(TaskKind::Deferred, tof, Some(args), None)
}
pub fn schedule_delayed<'a>(
&self,
secs: f64,
tof: LuaValue<'a>,
args: LuaMultiValue<'a>,
) -> LuaResult<TaskReference> {
self.schedule(TaskKind::Yielded, tof, Some(args), Some(secs))
}
pub fn resume_after(&self, secs: f64, thread: LuaThread<'_>) -> LuaResult<TaskReference> {
self.schedule(
TaskKind::Yielded,
LuaValue::Thread(thread),
None,
Some(secs),
)
}
pub fn cancel(&self, reference: TaskReference) -> bool {
let queue_mutex = match reference.kind {
TaskKind::Instant => &self.task_queue_instant,
TaskKind::Deferred => &self.task_queue_deferred,
TaskKind::Yielded => &self.task_queue_yielded,
};
let mut queue = queue_mutex.lock().unwrap();
let mut found = false;
queue.retain(|task| {
if task.guid == reference.guid {
found = true;
false
} else {
true
}
});
found
}
pub fn resume_task(&self, reference: TaskReference) -> LuaResult<()> {
let task = {
let mut tasks = self.tasks.lock().unwrap();
match tasks.remove(&reference) {
Some(task) => task,
None => {
return Err(LuaError::RuntimeError(format!(
"Task does not exist in scheduler: {reference}"
)))
}
}
};
let thread: LuaThread = self.lua.registry_value(&task.thread)?;
let args: Option<Vec<LuaValue>> = self.lua.registry_value(&task.args)?;
if let Some(args) = args {
thread.resume::<_, LuaMultiValue>(LuaMultiValue::from_vec(args))?;
} else {
let elapsed = task.queued_at.elapsed().as_secs_f64();
thread.resume::<_, LuaMultiValue>(elapsed)?;
}
self.lua.remove_registry_value(task.thread)?;
self.lua.remove_registry_value(task.args)?;
Ok(())
}
fn get_queue(&self, kind: TaskKind) -> &TaskSchedulerQueue {
match kind {
TaskKind::Instant => &self.task_queue_instant,
TaskKind::Deferred => &self.task_queue_deferred,
TaskKind::Yielded => &self.task_queue_yielded,
}
}
fn next_queue_task(&self, kind: TaskKind) -> Option<TaskReference> {
let task = {
let queue_guard = self.get_queue(kind).lock().unwrap();
queue_guard.front().copied()
};
task
}
fn resume_next_queue_task(&self, kind: TaskKind) -> Option<LuaResult<TaskSchedulerStatus>> {
match {
let mut queue_guard = self.get_queue(kind).lock().unwrap();
queue_guard.pop_front()
} {
None => {
let status = self.status();
if status.num_total > 0 {
Some(Ok(status))
} else {
None
}
}
Some(t) => match self.resume_task(t) {
Ok(_) => Some(Ok(self.status())),
Err(e) => Some(Err(e)),
},
}
}
pub async fn resume_queue(&self) -> Option<LuaResult<TaskSchedulerStatus>> {
let now = Instant::now();
let status = self.status();
/*
Resume tasks in the internal queue, in this order:
1. Tasks from task.spawn, this includes the main thread
2. Tasks from task.defer
3. Tasks from task.delay OR futures, whichever comes first
4. Tasks from futures
*/
if status.num_instant > 0 {
self.resume_next_queue_task(TaskKind::Instant)
} else if status.num_deferred > 0 {
self.resume_next_queue_task(TaskKind::Deferred)
} else if status.num_yielded > 0 {
// 3. Threads from task.delay or task.wait, futures
let next_yield_target = self
.next_queue_task(TaskKind::Yielded)
.expect("Yielded task missing but status count is > 0")
.queued_target
.expect("Yielded task is missing queued target");
// Resume this yielding task if its target time has passed
if now >= next_yield_target {
self.resume_next_queue_task(TaskKind::Yielded)
} else {
/*
Await the first future to be ready
- If it is the sleep fut then we will return and the next
call to resume_queue will then resume that yielded task
- If it is a future then we resume the corresponding task
that is has stored in the future-specific task queue
*/
sleep(next_yield_target - now).await;
// TODO: Implement this, for now we only await sleep
// since the task scheduler doesn't support futures
Some(Ok(self.status()))
}
} else {
// 4. Just futures
// TODO: Await the first future to be ready
// and resume the corresponding task for it
None
}
}
}

View file

@ -11,7 +11,7 @@ const ARGS: &[&str] = &["Foo", "Bar"];
macro_rules! create_tests {
($($name:ident: $value:expr,)*) => { $(
#[tokio::test]
#[tokio::test(flavor = "multi_thread")]
async fn $name() -> Result<ExitCode> {
// Disable styling for stdout and stderr since
// some tests rely on output not being styled

View file

@ -4,6 +4,8 @@ use console::{style, Style};
use lazy_static::lazy_static;
use mlua::prelude::*;
use crate::lua::task::TaskReference;
const MAX_FORMAT_DEPTH: usize = 4;
const INDENT: &str = " ";
@ -165,9 +167,17 @@ pub fn pretty_format_value(
)?,
LuaValue::Thread(_) => write!(buffer, "{}", COLOR_PURPLE.apply_to("<thread>"))?,
LuaValue::Function(_) => write!(buffer, "{}", COLOR_PURPLE.apply_to("<function>"))?,
LuaValue::UserData(_) | LuaValue::LightUserData(_) => {
LuaValue::UserData(u) => {
if u.is::<TaskReference>() {
// Task references must be transparent
// to lua and pretend to be normal lua
// threads for compatibility purposes
write!(buffer, "{}", COLOR_PURPLE.apply_to("<thread>"))?
} else {
write!(buffer, "{}", COLOR_PURPLE.apply_to("<userdata>"))?
}
}
LuaValue::LightUserData(_) => write!(buffer, "{}", COLOR_PURPLE.apply_to("<userdata>"))?,
_ => write!(buffer, "{}", STYLE_DIM.apply_to("?"))?,
}
Ok(())
@ -220,16 +230,22 @@ pub fn pretty_format_luau_error(e: &LuaError) -> String {
err_lines.join("\n")
}
LuaError::CallbackError { traceback, cause } => {
// Find the best traceback (longest) and the root error message
// Find the best traceback (most lines) and the root error message
let mut best_trace = traceback;
let mut root_cause = cause.as_ref();
while let LuaError::CallbackError { cause, traceback } = root_cause {
if traceback.len() > best_trace.len() {
if traceback.lines().count() > best_trace.len() {
best_trace = traceback;
}
root_cause = cause;
}
// Same error formatting as above
// If we got a runtime error with an embedded traceback, we should
// use that instead since it generally contains more information
if matches!(root_cause, LuaError::RuntimeError(e) if e.contains("stack traceback:")) {
pretty_format_luau_error(root_cause)
} else {
// Otherwise we format whatever root error we got using
// the same error formatting as for above runtime errors
format!(
"{}\n{}\n{}\n{}",
pretty_format_luau_error(root_cause),
@ -238,6 +254,7 @@ pub fn pretty_format_luau_error(e: &LuaError) -> String {
stack_end
)
}
}
LuaError::ToLuaConversionError { from, to, message } => {
let msg = message
.clone()

View file

@ -1,9 +0,0 @@
use mlua::prelude::*;
#[derive(Debug, Clone)]
pub enum LuneMessage {
Exit(u8),
Spawned,
Finished,
LuaError(LuaError),
}

View file

@ -1,7 +1,5 @@
pub mod formatting;
pub mod futures;
pub mod message;
pub mod net;
pub mod process;
pub mod table;
pub mod task;

View file

@ -1,11 +1,9 @@
use std::{process::ExitStatus, time::Duration};
use std::process::ExitStatus;
use mlua::prelude::*;
use tokio::{io, process::Child, task::spawn, time::sleep};
use tokio::{io, process::Child, task::spawn};
use crate::utils::{futures::AsyncTeeWriter, message::LuneMessage};
use super::task::send_message;
use crate::utils::futures::AsyncTeeWriter;
pub async fn pipe_and_inherit_child_process_stdio(
mut child: Child,
@ -42,13 +40,3 @@ pub async fn pipe_and_inherit_child_process_stdio(
Ok::<_, LuaError>((status, stdout_buffer?, stderr_buffer?))
}
pub async fn exit_and_yield_forever(lua: &'static Lua, exit_code: Option<u8>) -> LuaResult<()> {
// Send an exit signal to the main thread, which
// will try to exit safely and as soon as possible
send_message(lua, LuneMessage::Exit(exit_code.unwrap_or(0))).await?;
// Make sure to block the rest of this thread indefinitely since
// the main thread may not register the exit signal right away
sleep(Duration::MAX).await;
Ok(())
}

View file

@ -1,76 +0,0 @@
use std::fmt::{self, Debug};
use std::future::Future;
use std::sync::Weak;
use mlua::prelude::*;
use tokio::sync::mpsc::Sender;
use tokio::task;
use crate::utils::message::LuneMessage;
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TaskRunMode {
Blocking,
Instant,
Deferred,
}
impl fmt::Display for TaskRunMode {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Self::Blocking => write!(f, "Blocking"),
Self::Instant => write!(f, "Instant"),
Self::Deferred => write!(f, "Deferred"),
}
}
}
pub async fn send_message(lua: &'static Lua, message: LuneMessage) -> LuaResult<()> {
let sender = lua
.app_data_ref::<Weak<Sender<LuneMessage>>>()
.unwrap()
.upgrade()
.unwrap();
sender.send(message).await.map_err(LuaError::external)
}
pub async fn run_registered_task<T>(
lua: &'static Lua,
mode: TaskRunMode,
to_run: impl Future<Output = LuaResult<T>> + 'static,
) -> LuaResult<()> {
// Send a message that we have started our task
send_message(lua, LuneMessage::Spawned).await?;
// Run the new task separately from the current one using the executor
let task = task::spawn_local(async move {
// HACK: For deferred tasks we yield a bunch of times to try and ensure
// we run our task at the very end of the async queue, this can fail if
// the user creates a bunch of interleaved deferred and normal tasks
if mode == TaskRunMode::Deferred {
for _ in 0..64 {
task::yield_now().await;
}
}
send_message(
lua,
match to_run.await {
Ok(_) => LuneMessage::Finished,
Err(LuaError::CoroutineInactive) => LuneMessage::Finished, // Task was canceled
Err(e) => LuneMessage::LuaError(e),
},
)
.await
});
// Wait for the task to complete if we want this call to be blocking
// Any lua errors will be sent through the message channel back
// to the main thread which will then handle them properly
if mode == TaskRunMode::Blocking {
task.await
.map_err(LuaError::external)?
.map_err(LuaError::external)?;
}
// Yield once right away to let the above spawned task start working
// instantly, forcing it to run until completion or until it yields
task::yield_now().await;
Ok(())
}

View file

@ -1,3 +1 @@
stdio.write("Hello, stdout!")
process.exit(0)

View file

@ -10,18 +10,18 @@ task.defer(function()
flag = true
end)
assert(not flag, "Defer should not run instantly or block")
task.wait(0.1)
task.wait(0.05)
assert(flag, "Defer should run")
-- Deferred functions should work with yielding
local flag2: boolean = false
task.defer(function()
task.wait(0.1)
task.wait(0.05)
flag2 = true
end)
assert(not flag2, "Defer should work with yielding (1)")
task.wait(0.2)
task.wait(0.1)
assert(flag2, "Defer should work with yielding (2)")
-- Deferred functions should run after other spawned threads

View file

@ -10,20 +10,20 @@ task.delay(0, function()
flag = true
end)
assert(not flag, "Delay should not run instantly or block")
task.wait(1 / 60)
task.wait(0.05)
assert(flag, "Delay should run after the wanted duration")
-- Delayed functions should work with yielding
local flag2: boolean = false
task.delay(0.2, function()
task.delay(0.05, function()
flag2 = true
task.wait(0.4)
task.wait(0.1)
flag2 = false
end)
task.wait(0.4)
task.wait(0.1)
assert(flag, "Delay should work with yielding (1)")
task.wait(0.4)
task.wait(0.1)
assert(not flag2, "Delay should work with yielding (2)")
-- Varargs should get passed correctly

View file

@ -15,11 +15,11 @@ assert(flag, "Spawn should run instantly")
local flag2: boolean = false
task.spawn(function()
task.wait(0.1)
task.wait(0.05)
flag2 = true
end)
assert(not flag2, "Spawn should work with yielding (1)")
task.wait(0.2)
task.wait(0.1)
assert(flag2, "Spawn should work with yielding (2)")
-- Spawned functions should be able to run threads created with the coroutine global

View file

@ -5,18 +5,28 @@ local EPSILON = 1 / 100
local function test(expected: number)
local start = os.clock()
local returned = task.wait(expected)
if typeof(returned) ~= "number" then
error(
string.format(
"Expected task.wait to return a number, got %s %s",
typeof(returned),
stdio.format(returned)
),
2
)
end
local elapsed = (os.clock() - start)
local difference = math.abs(elapsed - returned)
local difference = math.abs(elapsed - expected)
if difference > EPSILON then
error(
string.format(
"Elapsed time diverged too much from argument!"
.. "\nGot argument of %.3fs and elapsed time of %.3fs"
.. "\nGot maximum difference of %.3fs and real difference of %.3fs",
expected,
elapsed,
EPSILON,
difference
.. "\nGot argument of %.3fms and elapsed time of %.3fms"
.. "\nGot maximum difference of %.3fms and real difference of %.3fms",
expected * 1_000,
elapsed * 1_000,
EPSILON * 1_000,
difference * 1_000
)
)
end