mirror of
https://github.com/CompeyDev/lune-packaging.git
synced 2025-01-09 12:19:09 +00:00
Initial implementation of proper task scheduler, no async yet
This commit is contained in:
parent
bb182033b9
commit
fc5de3c8d5
23 changed files with 685 additions and 430 deletions
|
@ -22,7 +22,7 @@ mod tests;
|
|||
|
||||
use cli::Cli;
|
||||
|
||||
#[tokio::main]
|
||||
#[tokio::main(flavor = "multi_thread")]
|
||||
async fn main() -> Result<ExitCode> {
|
||||
Cli::parse().run().await
|
||||
}
|
||||
|
|
|
@ -36,20 +36,20 @@ async fn ensure_file_exists_and_is_not_json(file_name: &str) -> Result<()> {
|
|||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn list() -> Result<()> {
|
||||
Cli::list().run().await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn download_selene_types() -> Result<()> {
|
||||
run_cli(Cli::download_selene_types()).await?;
|
||||
ensure_file_exists_and_is_not_json(LUNE_SELENE_FILE_NAME).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn download_luau_types() -> Result<()> {
|
||||
run_cli(Cli::download_luau_types()).await?;
|
||||
ensure_file_exists_and_is_not_json(LUNE_LUAU_FILE_NAME).await?;
|
||||
|
|
|
@ -2,6 +2,8 @@ use std::fmt::{Display, Formatter, Result as FmtResult};
|
|||
|
||||
use mlua::prelude::*;
|
||||
|
||||
use crate::lua::task::TaskScheduler;
|
||||
|
||||
mod fs;
|
||||
mod net;
|
||||
mod process;
|
||||
|
@ -78,14 +80,18 @@ impl LuneGlobal {
|
|||
Note that proxy globals should be handled with special care and that [`LuneGlobal::inject()`]
|
||||
should be preferred over manually creating and manipulating the value(s) of any Lune global.
|
||||
*/
|
||||
pub fn value(&self, lua: &'static Lua) -> LuaResult<LuaTable> {
|
||||
pub fn value(
|
||||
&self,
|
||||
lua: &'static Lua,
|
||||
scheduler: &'static TaskScheduler,
|
||||
) -> LuaResult<LuaTable> {
|
||||
match self {
|
||||
LuneGlobal::Fs => fs::create(lua),
|
||||
LuneGlobal::Net => net::create(lua),
|
||||
LuneGlobal::Process { args } => process::create(lua, args.clone()),
|
||||
LuneGlobal::Require => require::create(lua),
|
||||
LuneGlobal::Stdio => stdio::create(lua),
|
||||
LuneGlobal::Task => task::create(lua),
|
||||
LuneGlobal::Task => task::create(lua, scheduler),
|
||||
LuneGlobal::TopLevel => top_level::create(lua),
|
||||
}
|
||||
}
|
||||
|
@ -98,9 +104,9 @@ impl LuneGlobal {
|
|||
|
||||
Refer to [`LuneGlobal::is_top_level()`] for more info on proxy globals.
|
||||
*/
|
||||
pub fn inject(self, lua: &'static Lua) -> LuaResult<()> {
|
||||
pub fn inject(self, lua: &'static Lua, scheduler: &'static TaskScheduler) -> LuaResult<()> {
|
||||
let globals = lua.globals();
|
||||
let table = self.value(lua)?;
|
||||
let table = self.value(lua, scheduler)?;
|
||||
// NOTE: Top level globals are special, the values
|
||||
// *in* the table they return should be set directly,
|
||||
// instead of setting the table itself as the global
|
||||
|
|
|
@ -17,10 +17,7 @@ use tokio::{sync::mpsc, task};
|
|||
|
||||
use crate::{
|
||||
lua::net::{NetClient, NetClientBuilder, NetWebSocketClient, NetWebSocketServer, ServeConfig},
|
||||
utils::{
|
||||
message::LuneMessage, net::get_request_user_agent_header, table::TableBuilder,
|
||||
task::send_message,
|
||||
},
|
||||
utils::{net::get_request_user_agent_header, table::TableBuilder},
|
||||
};
|
||||
|
||||
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
|
||||
|
@ -29,7 +26,7 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
|
|||
let client = NetClientBuilder::new()
|
||||
.headers(&[("User-Agent", get_request_user_agent_header())])?
|
||||
.build()?;
|
||||
lua.set_named_registry_value("NetClient", client)?;
|
||||
lua.set_named_registry_value("net.client", client)?;
|
||||
// Create the global table for net
|
||||
TableBuilder::new(lua)?
|
||||
.with_function("jsonEncode", net_json_encode)?
|
||||
|
@ -54,7 +51,7 @@ fn net_json_decode(lua: &'static Lua, json: String) -> LuaResult<LuaValue> {
|
|||
}
|
||||
|
||||
async fn net_request<'a>(lua: &'static Lua, config: LuaValue<'a>) -> LuaResult<LuaTable<'a>> {
|
||||
let client: NetClient = lua.named_registry_value("NetClient")?;
|
||||
let client: NetClient = lua.named_registry_value("net.client")?;
|
||||
// Extract stuff from config and make sure its all valid
|
||||
let (url, method, headers, body) = match config {
|
||||
LuaValue::String(s) => {
|
||||
|
@ -177,20 +174,9 @@ async fn net_serve<'a>(
|
|||
shutdown_rx.recv().await.unwrap();
|
||||
shutdown_rx.close();
|
||||
});
|
||||
// Make sure we register the thread properly by sending messages
|
||||
// when the server starts up and when it shuts down or errors
|
||||
send_message(lua, LuneMessage::Spawned).await?;
|
||||
task::spawn_local(async move {
|
||||
let res = server.await.map_err(LuaError::external);
|
||||
let _ = send_message(
|
||||
lua,
|
||||
match res {
|
||||
Err(e) => LuneMessage::LuaError(e),
|
||||
Ok(_) => LuneMessage::Finished,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
});
|
||||
// TODO: Spawn a new scheduler future with this so we don't block
|
||||
// and make sure that we register it properly to prevent shutdown
|
||||
server.await.map_err(LuaError::external)?;
|
||||
// Create a new read-only table that contains methods
|
||||
// for manipulating server behavior and shutting it down
|
||||
let handle_stop = move |_, _: ()| {
|
||||
|
@ -313,8 +299,8 @@ impl Service<Request<Body>> for NetService {
|
|||
Ok(resp.body(body).unwrap())
|
||||
}
|
||||
// If the handler returns an error, generate a 5xx response
|
||||
Err(err) => {
|
||||
send_message(lua, LuneMessage::LuaError(err.to_lua_err())).await?;
|
||||
Err(_) => {
|
||||
// TODO: Send above error to task scheduler so that it can emit properly
|
||||
Ok(Response::builder()
|
||||
.status(500)
|
||||
.body(Body::from("Internal Server Error"))
|
||||
|
@ -323,10 +309,11 @@ impl Service<Request<Body>> for NetService {
|
|||
// If the handler returns a value that is of an invalid type,
|
||||
// this should also be an error, so generate a 5xx response
|
||||
Ok(value) => {
|
||||
send_message(lua, LuneMessage::LuaError(LuaError::RuntimeError(format!(
|
||||
// TODO: Send below error to task scheduler so that it can emit properly
|
||||
let _ = LuaError::RuntimeError(format!(
|
||||
"Expected net serve handler to return a value of type 'string' or 'table', got '{}'",
|
||||
value.type_name()
|
||||
)))).await?;
|
||||
));
|
||||
Ok(Response::builder()
|
||||
.status(500)
|
||||
.body(Body::from("Internal Server Error"))
|
||||
|
|
|
@ -1,21 +1,35 @@
|
|||
use std::{collections::HashMap, env, path::PathBuf, process::Stdio};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
env,
|
||||
path::PathBuf,
|
||||
process::{ExitCode, Stdio},
|
||||
};
|
||||
|
||||
use directories::UserDirs;
|
||||
use mlua::prelude::*;
|
||||
use os_str_bytes::RawOsString;
|
||||
use tokio::process::Command;
|
||||
|
||||
use crate::utils::{
|
||||
process::{exit_and_yield_forever, pipe_and_inherit_child_process_stdio},
|
||||
table::TableBuilder,
|
||||
use crate::{
|
||||
lua::task::TaskScheduler,
|
||||
utils::{process::pipe_and_inherit_child_process_stdio, table::TableBuilder},
|
||||
};
|
||||
|
||||
const PROCESS_EXIT_IMPL_LUA: &str = r#"
|
||||
exit(...)
|
||||
yield()
|
||||
"#;
|
||||
|
||||
pub fn create(lua: &'static Lua, args_vec: Vec<String>) -> LuaResult<LuaTable> {
|
||||
let cwd = env::current_dir()?.canonicalize()?;
|
||||
let mut cwd_str = cwd.to_string_lossy().to_string();
|
||||
if !cwd_str.ends_with('/') {
|
||||
cwd_str = format!("{cwd_str}/");
|
||||
}
|
||||
let cwd_str = {
|
||||
let cwd = env::current_dir()?.canonicalize()?;
|
||||
let cwd_str = cwd.to_string_lossy().to_string();
|
||||
if !cwd_str.ends_with('/') {
|
||||
format!("{cwd_str}/")
|
||||
} else {
|
||||
cwd_str
|
||||
}
|
||||
};
|
||||
// Create readonly args array
|
||||
let args_tab = TableBuilder::new(lua)?
|
||||
.with_sequential_values(args_vec)?
|
||||
|
@ -30,12 +44,31 @@ pub fn create(lua: &'static Lua, args_vec: Vec<String>) -> LuaResult<LuaTable> {
|
|||
.build_readonly()?,
|
||||
)?
|
||||
.build_readonly()?;
|
||||
// Create our process exit function, this is a bit involved since
|
||||
// we have no way to yield from c / rust, we need to load a lua
|
||||
// chunk that will set the exit code and yield for us instead
|
||||
let process_exit_env_yield: LuaFunction = lua.named_registry_value("co.yield")?;
|
||||
let process_exit_env_exit: LuaFunction = lua.create_function(|lua, code: Option<u8>| {
|
||||
let exit_code = code.map_or(ExitCode::SUCCESS, ExitCode::from);
|
||||
let sched = &mut lua.app_data_mut::<&TaskScheduler>().unwrap();
|
||||
sched.set_exit_code(exit_code);
|
||||
Ok(())
|
||||
})?;
|
||||
let process_exit = lua
|
||||
.load(PROCESS_EXIT_IMPL_LUA)
|
||||
.set_environment(
|
||||
TableBuilder::new(lua)?
|
||||
.with_value("yield", process_exit_env_yield)?
|
||||
.with_value("exit", process_exit_env_exit)?
|
||||
.build_readonly()?,
|
||||
)?
|
||||
.into_function()?;
|
||||
// Create the full process table
|
||||
TableBuilder::new(lua)?
|
||||
.with_value("args", args_tab)?
|
||||
.with_value("cwd", cwd_str)?
|
||||
.with_value("env", env_tab)?
|
||||
.with_async_function("exit", process_exit)?
|
||||
.with_value("exit", process_exit)?
|
||||
.with_async_function("spawn", process_spawn)?
|
||||
.build_readonly()
|
||||
}
|
||||
|
@ -109,10 +142,6 @@ fn process_env_iter<'lua>(
|
|||
})
|
||||
}
|
||||
|
||||
async fn process_exit(lua: &'static Lua, exit_code: Option<u8>) -> LuaResult<()> {
|
||||
exit_and_yield_forever(lua, exit_code).await
|
||||
}
|
||||
|
||||
async fn process_spawn<'a>(
|
||||
lua: &'static Lua,
|
||||
(mut program, args, options): (String, Option<Vec<String>>, Option<LuaTable<'a>>),
|
||||
|
|
|
@ -10,12 +10,10 @@ use os_str_bytes::{OsStrBytes, RawOsStr};
|
|||
use crate::utils::table::TableBuilder;
|
||||
|
||||
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
|
||||
let require: LuaFunction = lua.globals().raw_get("require")?;
|
||||
// Preserve original require behavior if we have a special env var set
|
||||
if env::var_os("LUAU_PWD_REQUIRE").is_some() {
|
||||
return TableBuilder::new(lua)?
|
||||
.with_value("require", require)?
|
||||
.build_readonly();
|
||||
// Return an empty table since there are no globals to overwrite
|
||||
return TableBuilder::new(lua)?.build_readonly();
|
||||
}
|
||||
/*
|
||||
Store the current working directory so that we can use it later
|
||||
|
@ -27,24 +25,17 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
|
|||
just in case someone out there uses luau with non-utf8 string requires
|
||||
*/
|
||||
let pwd = lua.create_string(¤t_dir()?.to_raw_bytes())?;
|
||||
lua.set_named_registry_value("require_pwd", pwd)?;
|
||||
// Fetch the debug info function and store it in the registry
|
||||
// - we will use it to fetch the current scripts file name
|
||||
let debug: LuaTable = lua.globals().raw_get("debug")?;
|
||||
let info: LuaFunction = debug.raw_get("info")?;
|
||||
lua.set_named_registry_value("require_getinfo", info)?;
|
||||
// Store the original require function in the registry
|
||||
lua.set_named_registry_value("require_original", require)?;
|
||||
lua.set_named_registry_value("pwd", pwd)?;
|
||||
/*
|
||||
Create a new function that fetches the file name from the current thread,
|
||||
sets the luau module lookup path to be the exact script we are looking
|
||||
for, and then runs the original require function with the wanted path
|
||||
*/
|
||||
let new_require = lua.create_function(|lua, require_path: LuaString| {
|
||||
let require_pwd: LuaString = lua.named_registry_value("require_pwd")?;
|
||||
let require_original: LuaFunction = lua.named_registry_value("require_original")?;
|
||||
let require_getinfo: LuaFunction = lua.named_registry_value("require_getinfo")?;
|
||||
let require_source: LuaString = require_getinfo.call((2, "s"))?;
|
||||
let require_pwd: LuaString = lua.named_registry_value("pwd")?;
|
||||
let require_fn: LuaFunction = lua.named_registry_value("require")?;
|
||||
let require_info: LuaFunction = lua.named_registry_value("dbg.info")?;
|
||||
let require_source: LuaString = require_info.call((2, "s"))?;
|
||||
/*
|
||||
Combine the require caller source with the wanted path
|
||||
string to get a final path relative to pwd - it is definitely
|
||||
|
@ -53,10 +44,15 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
|
|||
let raw_pwd_str = RawOsStr::assert_from_raw_bytes(require_pwd.as_bytes());
|
||||
let raw_source = RawOsStr::assert_from_raw_bytes(require_source.as_bytes());
|
||||
let raw_path = RawOsStr::assert_from_raw_bytes(require_path.as_bytes());
|
||||
let mut path_relative_to_pwd = PathBuf::from(&raw_source.to_os_str())
|
||||
.parent()
|
||||
.unwrap()
|
||||
.join(raw_path.to_os_str());
|
||||
let mut path_relative_to_pwd = PathBuf::from(
|
||||
&raw_source
|
||||
.trim_start_matches("[string \"")
|
||||
.trim_end_matches("\"]")
|
||||
.to_os_str(),
|
||||
)
|
||||
.parent()
|
||||
.unwrap()
|
||||
.join(raw_path.to_os_str());
|
||||
// Try to normalize and resolve relative path segments such as './' and '../'
|
||||
if let Ok(canonicalized) = path_relative_to_pwd.with_extension("luau").canonicalize() {
|
||||
path_relative_to_pwd = canonicalized.with_extension("");
|
||||
|
@ -72,7 +68,7 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
|
|||
let lua_path_str = lua.create_string(raw_path_str.as_raw_bytes());
|
||||
// If the require call errors then we should also replace
|
||||
// the path in the error message to improve user experience
|
||||
let result: LuaResult<_> = require_original.call::<_, LuaValue>(lua_path_str);
|
||||
let result: LuaResult<_> = require_fn.call::<_, LuaValue>(lua_path_str);
|
||||
match result {
|
||||
Err(LuaError::CallbackError { traceback, cause }) => {
|
||||
let before = format!(
|
||||
|
|
|
@ -1,150 +1,84 @@
|
|||
use std::time::{Duration, Instant};
|
||||
|
||||
use mlua::prelude::*;
|
||||
use tokio::time;
|
||||
|
||||
use crate::utils::{
|
||||
table::TableBuilder,
|
||||
task::{run_registered_task, TaskRunMode},
|
||||
use crate::{
|
||||
lua::task::{TaskReference, TaskScheduler},
|
||||
utils::table::TableBuilder,
|
||||
};
|
||||
|
||||
const MINIMUM_WAIT_OR_DELAY_DURATION: f32 = 10.0 / 1_000.0; // 10ms
|
||||
const TASK_WAIT_IMPL_LUA: &str = r#"
|
||||
resume_after(thread(), ...)
|
||||
return yield()
|
||||
"#;
|
||||
|
||||
// TODO: We should probably keep track of all threads in a scheduler userdata
|
||||
// that takes care of scheduling in a better way, and it should keep resuming
|
||||
// threads until it encounters a delayed / waiting thread, then task:sleep
|
||||
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
|
||||
// HACK: There is no way to call coroutine.close directly from the mlua
|
||||
// crate, so we need to fetch the function and store it in the registry
|
||||
let coroutine: LuaTable = lua.globals().raw_get("coroutine")?;
|
||||
let close: LuaFunction = coroutine.raw_get("close")?;
|
||||
lua.set_named_registry_value("coroutine.close", close)?;
|
||||
// HACK: coroutine.resume has some weird scheduling issues, but our custom
|
||||
// task.spawn implementation is more or less a replacement for it, so we
|
||||
// overwrite the original coroutine.resume function with it to fix that
|
||||
coroutine.raw_set("resume", lua.create_async_function(task_spawn)?)?;
|
||||
// Rest of the task library is normal, just async functions, no metatable
|
||||
pub fn create(
|
||||
lua: &'static Lua,
|
||||
scheduler: &'static TaskScheduler,
|
||||
) -> LuaResult<LuaTable<'static>> {
|
||||
lua.set_app_data(scheduler);
|
||||
// Create task spawning functions that add tasks to the scheduler
|
||||
let task_spawn = lua.create_function(|lua, (tof, args): (LuaValue, LuaMultiValue)| {
|
||||
let sched = &mut lua.app_data_mut::<&TaskScheduler>().unwrap();
|
||||
sched.schedule_instant(tof, args)
|
||||
})?;
|
||||
let task_defer = lua.create_function(|lua, (tof, args): (LuaValue, LuaMultiValue)| {
|
||||
let sched = &mut lua.app_data_mut::<&TaskScheduler>().unwrap();
|
||||
sched.schedule_deferred(tof, args)
|
||||
})?;
|
||||
let task_delay =
|
||||
lua.create_function(|lua, (secs, tof, args): (f64, LuaValue, LuaMultiValue)| {
|
||||
let sched = &mut lua.app_data_mut::<&TaskScheduler>().unwrap();
|
||||
sched.schedule_delayed(secs, tof, args)
|
||||
})?;
|
||||
// Create our task wait function, this is a bit different since
|
||||
// we have no way to yield from c / rust, we need to load a
|
||||
// lua chunk that schedules and yields for us instead
|
||||
let task_wait_env_thread: LuaFunction = lua.named_registry_value("co.thread")?;
|
||||
let task_wait_env_yield: LuaFunction = lua.named_registry_value("co.yield")?;
|
||||
let task_wait = lua
|
||||
.load(TASK_WAIT_IMPL_LUA)
|
||||
.set_environment(
|
||||
TableBuilder::new(lua)?
|
||||
.with_value("thread", task_wait_env_thread)?
|
||||
.with_value("yield", task_wait_env_yield)?
|
||||
.with_function(
|
||||
"resume_after",
|
||||
|lua, (thread, secs): (LuaThread, Option<f64>)| {
|
||||
let sched = &mut lua.app_data_mut::<&TaskScheduler>().unwrap();
|
||||
sched.resume_after(secs.unwrap_or(0f64), thread)
|
||||
},
|
||||
)?
|
||||
.build_readonly()?,
|
||||
)?
|
||||
.into_function()?;
|
||||
// We want the task scheduler to be transparent,
|
||||
// but it does not return real lua threads, so
|
||||
// we need to override some globals to fake it
|
||||
let globals = lua.globals();
|
||||
let type_original: LuaFunction = globals.get("type")?;
|
||||
let type_proxy = lua.create_function(move |_, value: LuaValue| {
|
||||
if let LuaValue::UserData(u) = &value {
|
||||
if u.is::<TaskReference>() {
|
||||
return Ok(LuaValue::String(lua.create_string("thread")?));
|
||||
}
|
||||
}
|
||||
type_original.call(value)
|
||||
})?;
|
||||
let typeof_original: LuaFunction = globals.get("typeof")?;
|
||||
let typeof_proxy = lua.create_function(move |_, value: LuaValue| {
|
||||
if let LuaValue::UserData(u) = &value {
|
||||
if u.is::<TaskReference>() {
|
||||
return Ok(LuaValue::String(lua.create_string("thread")?));
|
||||
}
|
||||
}
|
||||
typeof_original.call(value)
|
||||
})?;
|
||||
globals.set("type", type_proxy)?;
|
||||
globals.set("typeof", typeof_proxy)?;
|
||||
// All good, return the task scheduler lib
|
||||
TableBuilder::new(lua)?
|
||||
.with_async_function("cancel", task_cancel)?
|
||||
.with_async_function("delay", task_delay)?
|
||||
.with_async_function("defer", task_defer)?
|
||||
.with_async_function("spawn", task_spawn)?
|
||||
.with_async_function("wait", task_wait)?
|
||||
.with_value("spawn", task_spawn)?
|
||||
.with_value("defer", task_defer)?
|
||||
.with_value("delay", task_delay)?
|
||||
.with_value("wait", task_wait)?
|
||||
.build_readonly()
|
||||
}
|
||||
|
||||
fn tof_to_thread<'a>(
|
||||
lua: &'static Lua,
|
||||
thread_or_function: LuaValue<'a>,
|
||||
) -> LuaResult<LuaThread<'a>> {
|
||||
match thread_or_function {
|
||||
LuaValue::Thread(t) => Ok(t),
|
||||
LuaValue::Function(f) => Ok(lua.create_thread(f)?),
|
||||
value => Err(LuaError::RuntimeError(format!(
|
||||
"Argument must be a thread or function, got {}",
|
||||
value.type_name()
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
async fn task_cancel<'a>(lua: &'static Lua, thread: LuaThread<'a>) -> LuaResult<()> {
|
||||
let close: LuaFunction = lua.named_registry_value("coroutine.close")?;
|
||||
close.call_async::<_, LuaMultiValue>(thread).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn task_defer<'a>(
|
||||
lua: &'static Lua,
|
||||
(tof, args): (LuaValue<'a>, LuaMultiValue<'a>),
|
||||
) -> LuaResult<LuaThread<'a>> {
|
||||
let task_thread = tof_to_thread(lua, tof)?;
|
||||
let task_thread_key = lua.create_registry_value(task_thread)?;
|
||||
let task_args_key = lua.create_registry_value(args.into_vec())?;
|
||||
let lua_thread_to_return = lua.registry_value(&task_thread_key)?;
|
||||
run_registered_task(lua, TaskRunMode::Deferred, async move {
|
||||
let thread: LuaThread = lua.registry_value(&task_thread_key)?;
|
||||
let argsv: Vec<LuaValue> = lua.registry_value(&task_args_key)?;
|
||||
let args = LuaMultiValue::from_vec(argsv);
|
||||
if thread.status() == LuaThreadStatus::Resumable {
|
||||
let _: LuaMultiValue = thread.into_async(args).await?;
|
||||
}
|
||||
lua.remove_registry_value(task_thread_key)?;
|
||||
lua.remove_registry_value(task_args_key)?;
|
||||
Ok(())
|
||||
})
|
||||
.await?;
|
||||
Ok(lua_thread_to_return)
|
||||
}
|
||||
|
||||
async fn task_delay<'a>(
|
||||
lua: &'static Lua,
|
||||
(duration, tof, args): (Option<f32>, LuaValue<'a>, LuaMultiValue<'a>),
|
||||
) -> LuaResult<LuaThread<'a>> {
|
||||
let task_thread = tof_to_thread(lua, tof)?;
|
||||
let task_thread_key = lua.create_registry_value(task_thread)?;
|
||||
let task_args_key = lua.create_registry_value(args.into_vec())?;
|
||||
let lua_thread_to_return = lua.registry_value(&task_thread_key)?;
|
||||
let start = Instant::now();
|
||||
let dur = Duration::from_secs_f32(
|
||||
duration
|
||||
.map(|d| d.max(MINIMUM_WAIT_OR_DELAY_DURATION))
|
||||
.unwrap_or(MINIMUM_WAIT_OR_DELAY_DURATION),
|
||||
);
|
||||
run_registered_task(lua, TaskRunMode::Instant, async move {
|
||||
let thread: LuaThread = lua.registry_value(&task_thread_key)?;
|
||||
// NOTE: We are somewhat busy-waiting here, but we have to do this to make sure
|
||||
// that delayed+cancelled threads do not prevent the tokio runtime from finishing
|
||||
while thread.status() == LuaThreadStatus::Resumable && start.elapsed() < dur {
|
||||
time::sleep(Duration::from_millis(1)).await;
|
||||
}
|
||||
if thread.status() == LuaThreadStatus::Resumable {
|
||||
let argsv: Vec<LuaValue> = lua.registry_value(&task_args_key)?;
|
||||
let args = LuaMultiValue::from_vec(argsv);
|
||||
let _: LuaMultiValue = thread.into_async(args).await?;
|
||||
}
|
||||
lua.remove_registry_value(task_thread_key)?;
|
||||
lua.remove_registry_value(task_args_key)?;
|
||||
Ok(())
|
||||
})
|
||||
.await?;
|
||||
Ok(lua_thread_to_return)
|
||||
}
|
||||
|
||||
async fn task_spawn<'a>(
|
||||
lua: &'static Lua,
|
||||
(tof, args): (LuaValue<'a>, LuaMultiValue<'a>),
|
||||
) -> LuaResult<LuaThread<'a>> {
|
||||
let task_thread = tof_to_thread(lua, tof)?;
|
||||
let task_thread_key = lua.create_registry_value(task_thread)?;
|
||||
let task_args_key = lua.create_registry_value(args.into_vec())?;
|
||||
let lua_thread_to_return = lua.registry_value(&task_thread_key)?;
|
||||
run_registered_task(lua, TaskRunMode::Instant, async move {
|
||||
let thread: LuaThread = lua.registry_value(&task_thread_key)?;
|
||||
let argsv: Vec<LuaValue> = lua.registry_value(&task_args_key)?;
|
||||
let args = LuaMultiValue::from_vec(argsv);
|
||||
if thread.status() == LuaThreadStatus::Resumable {
|
||||
let _: LuaMultiValue = thread.into_async(args).await?;
|
||||
}
|
||||
lua.remove_registry_value(task_thread_key)?;
|
||||
lua.remove_registry_value(task_args_key)?;
|
||||
Ok(())
|
||||
})
|
||||
.await?;
|
||||
Ok(lua_thread_to_return)
|
||||
}
|
||||
|
||||
async fn task_wait(lua: &'static Lua, duration: Option<f32>) -> LuaResult<f32> {
|
||||
let start = Instant::now();
|
||||
run_registered_task(lua, TaskRunMode::Blocking, async move {
|
||||
time::sleep(Duration::from_secs_f32(
|
||||
duration
|
||||
.map(|d| d.max(MINIMUM_WAIT_OR_DELAY_DURATION))
|
||||
.unwrap_or(MINIMUM_WAIT_OR_DELAY_DURATION),
|
||||
))
|
||||
.await;
|
||||
Ok(())
|
||||
})
|
||||
.await?;
|
||||
let end = Instant::now();
|
||||
Ok((end - start).as_secs_f32())
|
||||
}
|
||||
|
|
|
@ -6,15 +6,10 @@ use crate::utils::{
|
|||
};
|
||||
|
||||
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
|
||||
let globals = lua.globals();
|
||||
// HACK: We need to preserve the default behavior of the
|
||||
// print and error functions, for pcall and such, which
|
||||
// is really tricky to do from scratch so we will just
|
||||
// proxy the default print and error functions here
|
||||
let print_fn: LuaFunction = globals.raw_get("print")?;
|
||||
let error_fn: LuaFunction = globals.raw_get("error")?;
|
||||
lua.set_named_registry_value("print", print_fn)?;
|
||||
lua.set_named_registry_value("error", error_fn)?;
|
||||
TableBuilder::new(lua)?
|
||||
.with_function("print", |lua, args: LuaMultiValue| {
|
||||
let formatted = pretty_format_multi_value(&args)?;
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
use std::{collections::HashSet, process::ExitCode, sync::Arc};
|
||||
use std::{collections::HashSet, process::ExitCode};
|
||||
|
||||
use lua::task::TaskScheduler;
|
||||
use mlua::prelude::*;
|
||||
use tokio::{sync::mpsc, task};
|
||||
use utils::task::send_message;
|
||||
use tokio::task::LocalSet;
|
||||
|
||||
pub(crate) mod globals;
|
||||
pub(crate) mod lua;
|
||||
|
@ -11,7 +11,7 @@ pub(crate) mod utils;
|
|||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
use crate::utils::{formatting::pretty_format_luau_error, message::LuneMessage};
|
||||
use crate::utils::formatting::pretty_format_luau_error;
|
||||
|
||||
pub use globals::LuneGlobal;
|
||||
|
||||
|
@ -75,12 +75,12 @@ impl Lune {
|
|||
This will create a new sandboxed Luau environment with the configured
|
||||
globals and arguments, running inside of a [`tokio::task::LocalSet`].
|
||||
|
||||
Some Lune globals such as [`LuneGlobal::Process`] may spawn
|
||||
separate tokio tasks on other threads, but the Luau environment
|
||||
Some Lune globals such as [`LuneGlobal::Process`] and [`LuneGlobal::Net`]
|
||||
may spawn separate tokio tasks on other threads, but the Luau environment
|
||||
itself is guaranteed to run on a single thread in the local set.
|
||||
|
||||
Note that this will create a static Lua instance that will live
|
||||
for the remainer of the program, and that this leaks memory using
|
||||
Note that this will create a static Lua instance and task scheduler which both
|
||||
will live for the remainer of the program, and that this leaks memory using
|
||||
[`Box::leak`] that will then get deallocated when the program exits.
|
||||
*/
|
||||
pub async fn run(
|
||||
|
@ -88,92 +88,64 @@ impl Lune {
|
|||
script_name: &str,
|
||||
script_contents: &str,
|
||||
) -> Result<ExitCode, LuaError> {
|
||||
let task_set = task::LocalSet::new();
|
||||
let (sender, mut receiver) = mpsc::channel::<LuneMessage>(64);
|
||||
let set = LocalSet::new();
|
||||
let lua = Lua::new().into_static();
|
||||
let snd = Arc::new(sender);
|
||||
lua.set_app_data(Arc::downgrade(&snd));
|
||||
let sched = TaskScheduler::new(lua)?.into_static();
|
||||
lua.set_app_data(sched);
|
||||
// Store original lua global functions in the registry so we can use
|
||||
// them later without passing them around and dealing with lifetimes
|
||||
lua.set_named_registry_value("require", lua.globals().get::<_, LuaFunction>("require")?)?;
|
||||
lua.set_named_registry_value("print", lua.globals().get::<_, LuaFunction>("print")?)?;
|
||||
lua.set_named_registry_value("error", lua.globals().get::<_, LuaFunction>("error")?)?;
|
||||
let coroutine: LuaTable = lua.globals().get("coroutine")?;
|
||||
lua.set_named_registry_value("co.thread", coroutine.get::<_, LuaFunction>("running")?)?;
|
||||
lua.set_named_registry_value("co.yield", coroutine.get::<_, LuaFunction>("yield")?)?;
|
||||
let debug: LuaTable = lua.globals().raw_get("debug")?;
|
||||
lua.set_named_registry_value("dbg.info", debug.get::<_, LuaFunction>("info")?)?;
|
||||
// Add in wanted lune globals
|
||||
for global in self.includes.clone() {
|
||||
if !self.excludes.contains(&global) {
|
||||
global.inject(lua)?;
|
||||
global.inject(lua, sched)?;
|
||||
}
|
||||
}
|
||||
// Spawn the main thread from our entrypoint script
|
||||
let script_name = script_name.to_string();
|
||||
let script_chunk = script_contents.to_string();
|
||||
send_message(lua, LuneMessage::Spawned).await?;
|
||||
task_set.spawn_local(async move {
|
||||
let result = lua
|
||||
.load(&script_chunk)
|
||||
.set_name(&format!("={script_name}"))
|
||||
.unwrap()
|
||||
.eval_async::<LuaValue>()
|
||||
.await;
|
||||
send_message(
|
||||
lua,
|
||||
match result {
|
||||
Err(e) => LuneMessage::LuaError(e),
|
||||
Ok(_) => LuneMessage::Finished,
|
||||
},
|
||||
)
|
||||
.await
|
||||
});
|
||||
// Run the executor until there are no tasks left,
|
||||
// taking care to not exit right away for errors
|
||||
let (got_code, got_error, exit_code) = task_set
|
||||
// Schedule the main thread on the task scheduler
|
||||
sched.schedule_instant(
|
||||
LuaValue::Function(
|
||||
lua.load(script_contents)
|
||||
.set_name(script_name)
|
||||
.unwrap()
|
||||
.into_function()
|
||||
.unwrap(),
|
||||
),
|
||||
LuaValue::Nil.to_lua_multi(lua)?,
|
||||
)?;
|
||||
// Keep running the scheduler until there are either no tasks
|
||||
// left to run, or until some task requests to exit the process
|
||||
let exit_code = set
|
||||
.run_until(async {
|
||||
let mut task_count = 0;
|
||||
let mut got_error = false;
|
||||
let mut got_code = false;
|
||||
let mut exit_code = 0;
|
||||
while let Some(message) = receiver.recv().await {
|
||||
// Make sure our task-count-modifying messages are sent correctly, one
|
||||
// task spawned must always correspond to one task finished / errored
|
||||
match &message {
|
||||
LuneMessage::Exit(_) => {}
|
||||
LuneMessage::Spawned => {}
|
||||
message => {
|
||||
if task_count == 0 {
|
||||
return Err(format!(
|
||||
"Got message while task count was 0!\nMessage: {message:#?}"
|
||||
));
|
||||
while let Some(result) = sched.resume_queue().await {
|
||||
match result {
|
||||
Err(e) => {
|
||||
eprintln!("{}", pretty_format_luau_error(&e));
|
||||
got_error = true;
|
||||
}
|
||||
Ok(status) => {
|
||||
if let Some(exit_code) = status.exit_code {
|
||||
return exit_code;
|
||||
} else if status.num_total == 0 {
|
||||
return ExitCode::SUCCESS;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Handle whatever message we got
|
||||
match message {
|
||||
LuneMessage::Exit(code) => {
|
||||
exit_code = code;
|
||||
got_code = true;
|
||||
break;
|
||||
}
|
||||
LuneMessage::Spawned => task_count += 1,
|
||||
LuneMessage::Finished => task_count -= 1,
|
||||
LuneMessage::LuaError(e) => {
|
||||
eprintln!("{}", pretty_format_luau_error(&e));
|
||||
got_error = true;
|
||||
task_count -= 1;
|
||||
}
|
||||
};
|
||||
// If there are no tasks left running, it is now
|
||||
// safe to close the receiver and end execution
|
||||
if task_count == 0 {
|
||||
receiver.close();
|
||||
}
|
||||
}
|
||||
Ok((got_code, got_error, exit_code))
|
||||
if got_error {
|
||||
ExitCode::FAILURE
|
||||
} else {
|
||||
ExitCode::SUCCESS
|
||||
}
|
||||
})
|
||||
.await
|
||||
.map_err(LuaError::external)?;
|
||||
// If we got an error, we will default to exiting
|
||||
// with code 1, unless a code was manually given
|
||||
if got_code {
|
||||
Ok(ExitCode::from(exit_code))
|
||||
} else if got_error {
|
||||
Ok(ExitCode::FAILURE)
|
||||
} else {
|
||||
Ok(ExitCode::SUCCESS)
|
||||
}
|
||||
.await;
|
||||
Ok(exit_code)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1 +1,2 @@
|
|||
pub mod net;
|
||||
pub mod task;
|
||||
|
|
3
packages/lib/src/lua/task/mod.rs
Normal file
3
packages/lib/src/lua/task/mod.rs
Normal file
|
@ -0,0 +1,3 @@
|
|||
mod scheduler;
|
||||
|
||||
pub use scheduler::*;
|
406
packages/lib/src/lua/task/scheduler.rs
Normal file
406
packages/lib/src/lua/task/scheduler.rs
Normal file
|
@ -0,0 +1,406 @@
|
|||
use std::{
|
||||
collections::{HashMap, VecDeque},
|
||||
fmt,
|
||||
process::ExitCode,
|
||||
sync::{
|
||||
atomic::{AtomicBool, AtomicUsize, Ordering},
|
||||
Arc, Mutex,
|
||||
},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use mlua::prelude::*;
|
||||
|
||||
use tokio::time::{sleep, Instant};
|
||||
|
||||
type TaskSchedulerQueue = Arc<Mutex<VecDeque<TaskReference>>>;
|
||||
|
||||
/// An enum representing different kinds of tasks
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
pub enum TaskKind {
|
||||
Instant,
|
||||
Deferred,
|
||||
Yielded,
|
||||
}
|
||||
|
||||
impl fmt::Display for TaskKind {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let name: &'static str = match self {
|
||||
TaskKind::Instant => "Instant",
|
||||
TaskKind::Deferred => "Deferred",
|
||||
TaskKind::Yielded => "Yielded",
|
||||
};
|
||||
write!(f, "{name}")
|
||||
}
|
||||
}
|
||||
|
||||
/// A lightweight, clonable struct that represents a
|
||||
/// task in the scheduler and is accessible from Lua
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
|
||||
pub struct TaskReference {
|
||||
kind: TaskKind,
|
||||
guid: usize,
|
||||
queued_target: Option<Instant>,
|
||||
}
|
||||
|
||||
impl TaskReference {
|
||||
pub const fn new(kind: TaskKind, guid: usize, queued_target: Option<Instant>) -> Self {
|
||||
Self {
|
||||
kind,
|
||||
guid,
|
||||
queued_target,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for TaskReference {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "TaskReference({} - {})", self.kind, self.guid)
|
||||
}
|
||||
}
|
||||
|
||||
impl LuaUserData for TaskReference {}
|
||||
|
||||
impl From<&Task> for TaskReference {
|
||||
fn from(value: &Task) -> Self {
|
||||
Self::new(value.kind, value.guid, value.queued_target)
|
||||
}
|
||||
}
|
||||
|
||||
/// A struct representing a task contained in the task scheduler
|
||||
#[derive(Debug)]
|
||||
pub struct Task {
|
||||
kind: TaskKind,
|
||||
guid: usize,
|
||||
thread: LuaRegistryKey,
|
||||
args: LuaRegistryKey,
|
||||
queued_at: Instant,
|
||||
queued_target: Option<Instant>,
|
||||
}
|
||||
|
||||
/// A struct representing the current status of the task scheduler
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct TaskSchedulerStatus {
|
||||
pub exit_code: Option<ExitCode>,
|
||||
pub num_instant: usize,
|
||||
pub num_deferred: usize,
|
||||
pub num_yielded: usize,
|
||||
pub num_total: usize,
|
||||
}
|
||||
|
||||
impl fmt::Display for TaskSchedulerStatus {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"TaskSchedulerStatus(\nInstant: {}\nDeferred: {}\nYielded: {}\nTotal: {})",
|
||||
self.num_instant, self.num_deferred, self.num_yielded, self.num_total
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// A task scheduler that implements task queues
|
||||
/// with instant, deferred, and delayed tasks
|
||||
#[derive(Debug)]
|
||||
pub struct TaskScheduler {
|
||||
lua: &'static Lua,
|
||||
guid: AtomicUsize,
|
||||
running: bool,
|
||||
tasks: Arc<Mutex<HashMap<TaskReference, Task>>>,
|
||||
task_queue_instant: TaskSchedulerQueue,
|
||||
task_queue_deferred: TaskSchedulerQueue,
|
||||
task_queue_yielded: TaskSchedulerQueue,
|
||||
exit_code_set: AtomicBool,
|
||||
exit_code: Arc<Mutex<ExitCode>>,
|
||||
}
|
||||
|
||||
impl TaskScheduler {
|
||||
pub fn new(lua: &'static Lua) -> LuaResult<Self> {
|
||||
Ok(Self {
|
||||
lua,
|
||||
guid: AtomicUsize::new(0),
|
||||
running: false,
|
||||
tasks: Arc::new(Mutex::new(HashMap::new())),
|
||||
task_queue_instant: Arc::new(Mutex::new(VecDeque::new())),
|
||||
task_queue_deferred: Arc::new(Mutex::new(VecDeque::new())),
|
||||
task_queue_yielded: Arc::new(Mutex::new(VecDeque::new())),
|
||||
exit_code_set: AtomicBool::new(false),
|
||||
exit_code: Arc::new(Mutex::new(ExitCode::SUCCESS)),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn into_static(self) -> &'static Self {
|
||||
Box::leak(Box::new(self))
|
||||
}
|
||||
|
||||
pub fn status(&self) -> TaskSchedulerStatus {
|
||||
let counts = {
|
||||
(
|
||||
self.task_queue_instant.lock().unwrap().len(),
|
||||
self.task_queue_deferred.lock().unwrap().len(),
|
||||
self.task_queue_yielded.lock().unwrap().len(),
|
||||
)
|
||||
};
|
||||
let num_total = counts.0 + counts.1 + counts.2;
|
||||
let exit_code = if self.exit_code_set.load(Ordering::Relaxed) {
|
||||
Some(*self.exit_code.lock().unwrap())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
TaskSchedulerStatus {
|
||||
exit_code,
|
||||
num_instant: counts.0,
|
||||
num_deferred: counts.1,
|
||||
num_yielded: counts.2,
|
||||
num_total,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_exit_code(&self, code: ExitCode) {
|
||||
self.exit_code_set.store(true, Ordering::Relaxed);
|
||||
*self.exit_code.lock().unwrap() = code
|
||||
}
|
||||
|
||||
fn schedule<'a>(
|
||||
&self,
|
||||
kind: TaskKind,
|
||||
tof: LuaValue<'a>,
|
||||
args: Option<LuaMultiValue<'a>>,
|
||||
delay: Option<f64>,
|
||||
) -> LuaResult<TaskReference> {
|
||||
// Get or create a thread from the given argument
|
||||
let task_thread = match tof {
|
||||
LuaValue::Thread(t) => t,
|
||||
LuaValue::Function(f) => self.lua.create_thread(f)?,
|
||||
value => {
|
||||
return Err(LuaError::RuntimeError(format!(
|
||||
"Argument must be a thread or function, got {}",
|
||||
value.type_name()
|
||||
)))
|
||||
}
|
||||
};
|
||||
// Store the thread and its arguments in the registry
|
||||
let task_args_vec = args.map(|opt| opt.into_vec());
|
||||
let task_thread_key = self.lua.create_registry_value(task_thread)?;
|
||||
let task_args_key = self.lua.create_registry_value(task_args_vec)?;
|
||||
// Create the full task struct
|
||||
let guid = self.guid.fetch_add(1, Ordering::Relaxed) + 1;
|
||||
let queued_at = Instant::now();
|
||||
let queued_target = delay.map(|secs| queued_at + Duration::from_secs_f64(secs));
|
||||
let task = Task {
|
||||
kind,
|
||||
guid,
|
||||
thread: task_thread_key,
|
||||
args: task_args_key,
|
||||
queued_at,
|
||||
queued_target,
|
||||
};
|
||||
// Create the task ref (before adding the task to the scheduler)
|
||||
let task_ref = TaskReference::from(&task);
|
||||
// Add it to the scheduler
|
||||
{
|
||||
let mut tasks = self.tasks.lock().unwrap();
|
||||
tasks.insert(task_ref, task);
|
||||
}
|
||||
match kind {
|
||||
TaskKind::Instant => {
|
||||
// If we have a currently running task and we spawned an
|
||||
// instant task here it should run right after the currently
|
||||
// running task, so put it at the front of the task queue
|
||||
let mut queue = self.task_queue_instant.lock().unwrap();
|
||||
if self.running {
|
||||
queue.push_front(task_ref);
|
||||
} else {
|
||||
queue.push_back(task_ref);
|
||||
}
|
||||
}
|
||||
TaskKind::Deferred => {
|
||||
// Deferred tasks should always schedule
|
||||
// at the very end of the deferred queue
|
||||
let mut queue = self.task_queue_deferred.lock().unwrap();
|
||||
queue.push_back(task_ref);
|
||||
}
|
||||
TaskKind::Yielded => {
|
||||
// Find the first task that is scheduled after this one and insert before it,
|
||||
// this will ensure that our list of delayed tasks is sorted and we can grab
|
||||
// the very first one to figure out how long to yield until the next cycle
|
||||
let mut queue = self.task_queue_yielded.lock().unwrap();
|
||||
let idx = queue
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find_map(|(idx, t)| {
|
||||
if t.queued_target > queued_target {
|
||||
Some(idx)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.unwrap_or(queue.len());
|
||||
queue.insert(idx, task_ref);
|
||||
}
|
||||
}
|
||||
Ok(task_ref)
|
||||
}
|
||||
|
||||
pub fn schedule_instant<'a>(
|
||||
&self,
|
||||
tof: LuaValue<'a>,
|
||||
args: LuaMultiValue<'a>,
|
||||
) -> LuaResult<TaskReference> {
|
||||
self.schedule(TaskKind::Instant, tof, Some(args), None)
|
||||
}
|
||||
|
||||
pub fn schedule_deferred<'a>(
|
||||
&self,
|
||||
tof: LuaValue<'a>,
|
||||
args: LuaMultiValue<'a>,
|
||||
) -> LuaResult<TaskReference> {
|
||||
self.schedule(TaskKind::Deferred, tof, Some(args), None)
|
||||
}
|
||||
|
||||
pub fn schedule_delayed<'a>(
|
||||
&self,
|
||||
secs: f64,
|
||||
tof: LuaValue<'a>,
|
||||
args: LuaMultiValue<'a>,
|
||||
) -> LuaResult<TaskReference> {
|
||||
self.schedule(TaskKind::Yielded, tof, Some(args), Some(secs))
|
||||
}
|
||||
|
||||
pub fn resume_after(&self, secs: f64, thread: LuaThread<'_>) -> LuaResult<TaskReference> {
|
||||
self.schedule(
|
||||
TaskKind::Yielded,
|
||||
LuaValue::Thread(thread),
|
||||
None,
|
||||
Some(secs),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn cancel(&self, reference: TaskReference) -> bool {
|
||||
let queue_mutex = match reference.kind {
|
||||
TaskKind::Instant => &self.task_queue_instant,
|
||||
TaskKind::Deferred => &self.task_queue_deferred,
|
||||
TaskKind::Yielded => &self.task_queue_yielded,
|
||||
};
|
||||
let mut queue = queue_mutex.lock().unwrap();
|
||||
let mut found = false;
|
||||
queue.retain(|task| {
|
||||
if task.guid == reference.guid {
|
||||
found = true;
|
||||
false
|
||||
} else {
|
||||
true
|
||||
}
|
||||
});
|
||||
found
|
||||
}
|
||||
|
||||
pub fn resume_task(&self, reference: TaskReference) -> LuaResult<()> {
|
||||
let task = {
|
||||
let mut tasks = self.tasks.lock().unwrap();
|
||||
match tasks.remove(&reference) {
|
||||
Some(task) => task,
|
||||
None => {
|
||||
return Err(LuaError::RuntimeError(format!(
|
||||
"Task does not exist in scheduler: {reference}"
|
||||
)))
|
||||
}
|
||||
}
|
||||
};
|
||||
let thread: LuaThread = self.lua.registry_value(&task.thread)?;
|
||||
let args: Option<Vec<LuaValue>> = self.lua.registry_value(&task.args)?;
|
||||
if let Some(args) = args {
|
||||
thread.resume::<_, LuaMultiValue>(LuaMultiValue::from_vec(args))?;
|
||||
} else {
|
||||
let elapsed = task.queued_at.elapsed().as_secs_f64();
|
||||
thread.resume::<_, LuaMultiValue>(elapsed)?;
|
||||
}
|
||||
self.lua.remove_registry_value(task.thread)?;
|
||||
self.lua.remove_registry_value(task.args)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_queue(&self, kind: TaskKind) -> &TaskSchedulerQueue {
|
||||
match kind {
|
||||
TaskKind::Instant => &self.task_queue_instant,
|
||||
TaskKind::Deferred => &self.task_queue_deferred,
|
||||
TaskKind::Yielded => &self.task_queue_yielded,
|
||||
}
|
||||
}
|
||||
|
||||
fn next_queue_task(&self, kind: TaskKind) -> Option<TaskReference> {
|
||||
let task = {
|
||||
let queue_guard = self.get_queue(kind).lock().unwrap();
|
||||
queue_guard.front().copied()
|
||||
};
|
||||
task
|
||||
}
|
||||
|
||||
fn resume_next_queue_task(&self, kind: TaskKind) -> Option<LuaResult<TaskSchedulerStatus>> {
|
||||
match {
|
||||
let mut queue_guard = self.get_queue(kind).lock().unwrap();
|
||||
queue_guard.pop_front()
|
||||
} {
|
||||
None => {
|
||||
let status = self.status();
|
||||
if status.num_total > 0 {
|
||||
Some(Ok(status))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
Some(t) => match self.resume_task(t) {
|
||||
Ok(_) => Some(Ok(self.status())),
|
||||
Err(e) => Some(Err(e)),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn resume_queue(&self) -> Option<LuaResult<TaskSchedulerStatus>> {
|
||||
let now = Instant::now();
|
||||
let status = self.status();
|
||||
/*
|
||||
Resume tasks in the internal queue, in this order:
|
||||
|
||||
1. Tasks from task.spawn, this includes the main thread
|
||||
2. Tasks from task.defer
|
||||
3. Tasks from task.delay OR futures, whichever comes first
|
||||
4. Tasks from futures
|
||||
*/
|
||||
if status.num_instant > 0 {
|
||||
self.resume_next_queue_task(TaskKind::Instant)
|
||||
} else if status.num_deferred > 0 {
|
||||
self.resume_next_queue_task(TaskKind::Deferred)
|
||||
} else if status.num_yielded > 0 {
|
||||
// 3. Threads from task.delay or task.wait, futures
|
||||
let next_yield_target = self
|
||||
.next_queue_task(TaskKind::Yielded)
|
||||
.expect("Yielded task missing but status count is > 0")
|
||||
.queued_target
|
||||
.expect("Yielded task is missing queued target");
|
||||
// Resume this yielding task if its target time has passed
|
||||
if now >= next_yield_target {
|
||||
self.resume_next_queue_task(TaskKind::Yielded)
|
||||
} else {
|
||||
/*
|
||||
Await the first future to be ready
|
||||
|
||||
- If it is the sleep fut then we will return and the next
|
||||
call to resume_queue will then resume that yielded task
|
||||
|
||||
- If it is a future then we resume the corresponding task
|
||||
that is has stored in the future-specific task queue
|
||||
*/
|
||||
sleep(next_yield_target - now).await;
|
||||
// TODO: Implement this, for now we only await sleep
|
||||
// since the task scheduler doesn't support futures
|
||||
Some(Ok(self.status()))
|
||||
}
|
||||
} else {
|
||||
// 4. Just futures
|
||||
|
||||
// TODO: Await the first future to be ready
|
||||
// and resume the corresponding task for it
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
|
@ -11,7 +11,7 @@ const ARGS: &[&str] = &["Foo", "Bar"];
|
|||
|
||||
macro_rules! create_tests {
|
||||
($($name:ident: $value:expr,)*) => { $(
|
||||
#[tokio::test]
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn $name() -> Result<ExitCode> {
|
||||
// Disable styling for stdout and stderr since
|
||||
// some tests rely on output not being styled
|
||||
|
|
|
@ -4,6 +4,8 @@ use console::{style, Style};
|
|||
use lazy_static::lazy_static;
|
||||
use mlua::prelude::*;
|
||||
|
||||
use crate::lua::task::TaskReference;
|
||||
|
||||
const MAX_FORMAT_DEPTH: usize = 4;
|
||||
|
||||
const INDENT: &str = " ";
|
||||
|
@ -165,9 +167,17 @@ pub fn pretty_format_value(
|
|||
)?,
|
||||
LuaValue::Thread(_) => write!(buffer, "{}", COLOR_PURPLE.apply_to("<thread>"))?,
|
||||
LuaValue::Function(_) => write!(buffer, "{}", COLOR_PURPLE.apply_to("<function>"))?,
|
||||
LuaValue::UserData(_) | LuaValue::LightUserData(_) => {
|
||||
write!(buffer, "{}", COLOR_PURPLE.apply_to("<userdata>"))?
|
||||
LuaValue::UserData(u) => {
|
||||
if u.is::<TaskReference>() {
|
||||
// Task references must be transparent
|
||||
// to lua and pretend to be normal lua
|
||||
// threads for compatibility purposes
|
||||
write!(buffer, "{}", COLOR_PURPLE.apply_to("<thread>"))?
|
||||
} else {
|
||||
write!(buffer, "{}", COLOR_PURPLE.apply_to("<userdata>"))?
|
||||
}
|
||||
}
|
||||
LuaValue::LightUserData(_) => write!(buffer, "{}", COLOR_PURPLE.apply_to("<userdata>"))?,
|
||||
_ => write!(buffer, "{}", STYLE_DIM.apply_to("?"))?,
|
||||
}
|
||||
Ok(())
|
||||
|
@ -220,23 +230,30 @@ pub fn pretty_format_luau_error(e: &LuaError) -> String {
|
|||
err_lines.join("\n")
|
||||
}
|
||||
LuaError::CallbackError { traceback, cause } => {
|
||||
// Find the best traceback (longest) and the root error message
|
||||
// Find the best traceback (most lines) and the root error message
|
||||
let mut best_trace = traceback;
|
||||
let mut root_cause = cause.as_ref();
|
||||
while let LuaError::CallbackError { cause, traceback } = root_cause {
|
||||
if traceback.len() > best_trace.len() {
|
||||
if traceback.lines().count() > best_trace.len() {
|
||||
best_trace = traceback;
|
||||
}
|
||||
root_cause = cause;
|
||||
}
|
||||
// Same error formatting as above
|
||||
format!(
|
||||
"{}\n{}\n{}\n{}",
|
||||
pretty_format_luau_error(root_cause),
|
||||
stack_begin,
|
||||
best_trace.strip_prefix("stack traceback:\n").unwrap(),
|
||||
stack_end
|
||||
)
|
||||
// If we got a runtime error with an embedded traceback, we should
|
||||
// use that instead since it generally contains more information
|
||||
if matches!(root_cause, LuaError::RuntimeError(e) if e.contains("stack traceback:")) {
|
||||
pretty_format_luau_error(root_cause)
|
||||
} else {
|
||||
// Otherwise we format whatever root error we got using
|
||||
// the same error formatting as for above runtime errors
|
||||
format!(
|
||||
"{}\n{}\n{}\n{}",
|
||||
pretty_format_luau_error(root_cause),
|
||||
stack_begin,
|
||||
best_trace.strip_prefix("stack traceback:\n").unwrap(),
|
||||
stack_end
|
||||
)
|
||||
}
|
||||
}
|
||||
LuaError::ToLuaConversionError { from, to, message } => {
|
||||
let msg = message
|
||||
|
|
|
@ -1,9 +0,0 @@
|
|||
use mlua::prelude::*;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum LuneMessage {
|
||||
Exit(u8),
|
||||
Spawned,
|
||||
Finished,
|
||||
LuaError(LuaError),
|
||||
}
|
|
@ -1,7 +1,5 @@
|
|||
pub mod formatting;
|
||||
pub mod futures;
|
||||
pub mod message;
|
||||
pub mod net;
|
||||
pub mod process;
|
||||
pub mod table;
|
||||
pub mod task;
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
use std::{process::ExitStatus, time::Duration};
|
||||
use std::process::ExitStatus;
|
||||
|
||||
use mlua::prelude::*;
|
||||
use tokio::{io, process::Child, task::spawn, time::sleep};
|
||||
use tokio::{io, process::Child, task::spawn};
|
||||
|
||||
use crate::utils::{futures::AsyncTeeWriter, message::LuneMessage};
|
||||
|
||||
use super::task::send_message;
|
||||
use crate::utils::futures::AsyncTeeWriter;
|
||||
|
||||
pub async fn pipe_and_inherit_child_process_stdio(
|
||||
mut child: Child,
|
||||
|
@ -42,13 +40,3 @@ pub async fn pipe_and_inherit_child_process_stdio(
|
|||
|
||||
Ok::<_, LuaError>((status, stdout_buffer?, stderr_buffer?))
|
||||
}
|
||||
|
||||
pub async fn exit_and_yield_forever(lua: &'static Lua, exit_code: Option<u8>) -> LuaResult<()> {
|
||||
// Send an exit signal to the main thread, which
|
||||
// will try to exit safely and as soon as possible
|
||||
send_message(lua, LuneMessage::Exit(exit_code.unwrap_or(0))).await?;
|
||||
// Make sure to block the rest of this thread indefinitely since
|
||||
// the main thread may not register the exit signal right away
|
||||
sleep(Duration::MAX).await;
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -1,76 +0,0 @@
|
|||
use std::fmt::{self, Debug};
|
||||
use std::future::Future;
|
||||
use std::sync::Weak;
|
||||
|
||||
use mlua::prelude::*;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use tokio::task;
|
||||
|
||||
use crate::utils::message::LuneMessage;
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
|
||||
pub enum TaskRunMode {
|
||||
Blocking,
|
||||
Instant,
|
||||
Deferred,
|
||||
}
|
||||
|
||||
impl fmt::Display for TaskRunMode {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Blocking => write!(f, "Blocking"),
|
||||
Self::Instant => write!(f, "Instant"),
|
||||
Self::Deferred => write!(f, "Deferred"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_message(lua: &'static Lua, message: LuneMessage) -> LuaResult<()> {
|
||||
let sender = lua
|
||||
.app_data_ref::<Weak<Sender<LuneMessage>>>()
|
||||
.unwrap()
|
||||
.upgrade()
|
||||
.unwrap();
|
||||
sender.send(message).await.map_err(LuaError::external)
|
||||
}
|
||||
|
||||
pub async fn run_registered_task<T>(
|
||||
lua: &'static Lua,
|
||||
mode: TaskRunMode,
|
||||
to_run: impl Future<Output = LuaResult<T>> + 'static,
|
||||
) -> LuaResult<()> {
|
||||
// Send a message that we have started our task
|
||||
send_message(lua, LuneMessage::Spawned).await?;
|
||||
// Run the new task separately from the current one using the executor
|
||||
let task = task::spawn_local(async move {
|
||||
// HACK: For deferred tasks we yield a bunch of times to try and ensure
|
||||
// we run our task at the very end of the async queue, this can fail if
|
||||
// the user creates a bunch of interleaved deferred and normal tasks
|
||||
if mode == TaskRunMode::Deferred {
|
||||
for _ in 0..64 {
|
||||
task::yield_now().await;
|
||||
}
|
||||
}
|
||||
send_message(
|
||||
lua,
|
||||
match to_run.await {
|
||||
Ok(_) => LuneMessage::Finished,
|
||||
Err(LuaError::CoroutineInactive) => LuneMessage::Finished, // Task was canceled
|
||||
Err(e) => LuneMessage::LuaError(e),
|
||||
},
|
||||
)
|
||||
.await
|
||||
});
|
||||
// Wait for the task to complete if we want this call to be blocking
|
||||
// Any lua errors will be sent through the message channel back
|
||||
// to the main thread which will then handle them properly
|
||||
if mode == TaskRunMode::Blocking {
|
||||
task.await
|
||||
.map_err(LuaError::external)?
|
||||
.map_err(LuaError::external)?;
|
||||
}
|
||||
// Yield once right away to let the above spawned task start working
|
||||
// instantly, forcing it to run until completion or until it yields
|
||||
task::yield_now().await;
|
||||
Ok(())
|
||||
}
|
|
@ -1,3 +1 @@
|
|||
stdio.write("Hello, stdout!")
|
||||
|
||||
process.exit(0)
|
||||
|
|
|
@ -10,18 +10,18 @@ task.defer(function()
|
|||
flag = true
|
||||
end)
|
||||
assert(not flag, "Defer should not run instantly or block")
|
||||
task.wait(0.1)
|
||||
task.wait(0.05)
|
||||
assert(flag, "Defer should run")
|
||||
|
||||
-- Deferred functions should work with yielding
|
||||
|
||||
local flag2: boolean = false
|
||||
task.defer(function()
|
||||
task.wait(0.1)
|
||||
task.wait(0.05)
|
||||
flag2 = true
|
||||
end)
|
||||
assert(not flag2, "Defer should work with yielding (1)")
|
||||
task.wait(0.2)
|
||||
task.wait(0.1)
|
||||
assert(flag2, "Defer should work with yielding (2)")
|
||||
|
||||
-- Deferred functions should run after other spawned threads
|
||||
|
|
|
@ -10,20 +10,20 @@ task.delay(0, function()
|
|||
flag = true
|
||||
end)
|
||||
assert(not flag, "Delay should not run instantly or block")
|
||||
task.wait(1 / 60)
|
||||
task.wait(0.05)
|
||||
assert(flag, "Delay should run after the wanted duration")
|
||||
|
||||
-- Delayed functions should work with yielding
|
||||
|
||||
local flag2: boolean = false
|
||||
task.delay(0.2, function()
|
||||
task.delay(0.05, function()
|
||||
flag2 = true
|
||||
task.wait(0.4)
|
||||
task.wait(0.1)
|
||||
flag2 = false
|
||||
end)
|
||||
task.wait(0.4)
|
||||
task.wait(0.1)
|
||||
assert(flag, "Delay should work with yielding (1)")
|
||||
task.wait(0.4)
|
||||
task.wait(0.1)
|
||||
assert(not flag2, "Delay should work with yielding (2)")
|
||||
|
||||
-- Varargs should get passed correctly
|
||||
|
|
|
@ -15,11 +15,11 @@ assert(flag, "Spawn should run instantly")
|
|||
|
||||
local flag2: boolean = false
|
||||
task.spawn(function()
|
||||
task.wait(0.1)
|
||||
task.wait(0.05)
|
||||
flag2 = true
|
||||
end)
|
||||
assert(not flag2, "Spawn should work with yielding (1)")
|
||||
task.wait(0.2)
|
||||
task.wait(0.1)
|
||||
assert(flag2, "Spawn should work with yielding (2)")
|
||||
|
||||
-- Spawned functions should be able to run threads created with the coroutine global
|
||||
|
|
|
@ -5,18 +5,28 @@ local EPSILON = 1 / 100
|
|||
local function test(expected: number)
|
||||
local start = os.clock()
|
||||
local returned = task.wait(expected)
|
||||
if typeof(returned) ~= "number" then
|
||||
error(
|
||||
string.format(
|
||||
"Expected task.wait to return a number, got %s %s",
|
||||
typeof(returned),
|
||||
stdio.format(returned)
|
||||
),
|
||||
2
|
||||
)
|
||||
end
|
||||
local elapsed = (os.clock() - start)
|
||||
local difference = math.abs(elapsed - returned)
|
||||
local difference = math.abs(elapsed - expected)
|
||||
if difference > EPSILON then
|
||||
error(
|
||||
string.format(
|
||||
"Elapsed time diverged too much from argument!"
|
||||
.. "\nGot argument of %.3fs and elapsed time of %.3fs"
|
||||
.. "\nGot maximum difference of %.3fs and real difference of %.3fs",
|
||||
expected,
|
||||
elapsed,
|
||||
EPSILON,
|
||||
difference
|
||||
.. "\nGot argument of %.3fms and elapsed time of %.3fms"
|
||||
.. "\nGot maximum difference of %.3fms and real difference of %.3fms",
|
||||
expected * 1_000,
|
||||
elapsed * 1_000,
|
||||
EPSILON * 1_000,
|
||||
difference * 1_000
|
||||
)
|
||||
)
|
||||
end
|
||||
|
|
Loading…
Reference in a new issue