diff --git a/src/lib/globals/console.rs b/src/lib/globals/console.rs index 14b543f..783ad7e 100644 --- a/src/lib/globals/console.rs +++ b/src/lib/globals/console.rs @@ -5,7 +5,7 @@ use crate::utils::{ table_builder::ReadonlyTableBuilder, }; -pub async fn create(lua: Lua) -> Result { +pub async fn create(lua: &Lua) -> Result<()> { let print = |args: &MultiValue, throw: bool| -> Result<()> { let s = pretty_format_multi_value(args)?; if throw { @@ -18,7 +18,7 @@ pub async fn create(lua: Lua) -> Result { }; lua.globals().raw_set( "console", - ReadonlyTableBuilder::new(&lua)? + ReadonlyTableBuilder::new(lua)? .with_function("resetColor", |_, _: ()| print_color("reset"))? .with_function("setColor", |_, color: String| print_color(color))? .with_function("resetStyle", |_, _: ()| print_style("reset"))? @@ -40,6 +40,5 @@ pub async fn create(lua: Lua) -> Result { print(&args, true) })? .build()?, - )?; - Ok(lua) + ) } diff --git a/src/lib/globals/fs.rs b/src/lib/globals/fs.rs index ab3c1b1..4c3f3a5 100644 --- a/src/lib/globals/fs.rs +++ b/src/lib/globals/fs.rs @@ -5,10 +5,10 @@ use tokio::fs; use crate::utils::table_builder::ReadonlyTableBuilder; -pub async fn create(lua: Lua) -> Result { +pub async fn create(lua: &Lua) -> Result<()> { lua.globals().raw_set( "fs", - ReadonlyTableBuilder::new(&lua)? + ReadonlyTableBuilder::new(lua)? .with_async_function("readFile", fs_read_file)? .with_async_function("readDir", fs_read_dir)? .with_async_function("writeFile", fs_write_file)? @@ -18,8 +18,7 @@ pub async fn create(lua: Lua) -> Result { .with_async_function("isFile", fs_is_file)? .with_async_function("isDir", fs_is_dir)? .build()?, - )?; - Ok(lua) + ) } async fn fs_read_file(_: &Lua, path: String) -> Result { diff --git a/src/lib/globals/net.rs b/src/lib/globals/net.rs index a1c5724..6ee0c3a 100644 --- a/src/lib/globals/net.rs +++ b/src/lib/globals/net.rs @@ -8,16 +8,15 @@ use reqwest::{ use crate::utils::{net::get_request_user_agent_header, table_builder::ReadonlyTableBuilder}; -pub async fn create(lua: Lua) -> Result { +pub async fn create(lua: &Lua) -> Result<()> { lua.globals().raw_set( "net", - ReadonlyTableBuilder::new(&lua)? + ReadonlyTableBuilder::new(lua)? .with_function("jsonEncode", net_json_encode)? .with_function("jsonDecode", net_json_decode)? .with_async_function("request", net_request)? .build()?, - )?; - Ok(lua) + ) } fn net_json_encode(_: &Lua, (val, pretty): (Value, Option)) -> Result { diff --git a/src/lib/globals/process.rs b/src/lib/globals/process.rs index b5e472a..ae18a5b 100644 --- a/src/lib/globals/process.rs +++ b/src/lib/globals/process.rs @@ -9,7 +9,7 @@ use tokio::process::Command; use crate::utils::table_builder::ReadonlyTableBuilder; -pub async fn create(lua: Lua, args_vec: Vec) -> Result { +pub async fn create(lua: &Lua, args_vec: Vec) -> Result<()> { // Create readonly args array let inner_args = lua.create_table()?; for arg in &args_vec { @@ -38,14 +38,13 @@ pub async fn create(lua: Lua, args_vec: Vec) -> Result { // Create the full process table lua.globals().raw_set( "process", - ReadonlyTableBuilder::new(&lua)? + ReadonlyTableBuilder::new(lua)? .with_table("args", inner_args)? .with_table("env", inner_env)? .with_function("exit", process_exit)? .with_async_function("spawn", process_spawn)? .build()?, - )?; - Ok(lua) + ) } fn process_env_get<'lua>(lua: &'lua Lua, (_, key): (Value<'lua>, String)) -> Result> { diff --git a/src/lib/globals/task.rs b/src/lib/globals/task.rs index 1a74252..f58251b 100644 --- a/src/lib/globals/task.rs +++ b/src/lib/globals/task.rs @@ -1,27 +1,80 @@ use std::time::Duration; -use mlua::{Lua, Result}; -use tokio::time; +use mlua::{Error, Function, Lua, Result, Table, Thread, Value, Variadic}; +use tokio::time::{self, Instant}; use crate::utils::table_builder::ReadonlyTableBuilder; -const DEFAULT_SLEEP_DURATION: f32 = 1.0 / 60.0; - -pub async fn create(lua: Lua) -> Result { +pub async fn create(lua: &Lua) -> Result<()> { lua.globals().raw_set( "task", - ReadonlyTableBuilder::new(&lua)? + ReadonlyTableBuilder::new(lua)? + .with_async_function("cancel", task_cancel)? + .with_async_function("defer", task_defer)? + .with_async_function("delay", task_delay)? + .with_async_function("spawn", task_spawn)? .with_async_function("wait", task_wait)? .build()?, - )?; - Ok(lua) + ) } -// FIXME: It does seem possible to properly make an async wait -// function with mlua right now, something breaks when using -// async wait functions inside of coroutines -async fn task_wait(_: &Lua, duration: Option) -> Result { - let secs = duration.unwrap_or(DEFAULT_SLEEP_DURATION); - time::sleep(Duration::from_secs_f32(secs)).await; - Ok(secs) +fn get_thread_from_arg<'a>(lua: &'a Lua, thread_or_function_arg: Value<'a>) -> Result> { + Ok(match thread_or_function_arg { + Value::Thread(thread) => thread, + Value::Function(func) => lua.create_thread(func)?, + val => { + return Err(Error::RuntimeError(format!( + "Expected type thread or function, got {}", + val.type_name() + ))) + } + }) +} + +async fn task_cancel(lua: &Lua, thread: Thread<'_>) -> Result<()> { + let coroutine: Table = lua.globals().raw_get("coroutine")?; + let close: Function = coroutine.raw_get("close")?; + close.call_async(thread).await?; + Ok(()) +} + +async fn task_defer<'a>(lua: &Lua, (tof, args): (Value<'a>, Variadic>)) -> Result<()> { + task_wait(lua, None).await?; + get_thread_from_arg(lua, tof)? + .into_async::<_, Variadic>>(args) + .await?; + Ok(()) +} + +async fn task_delay<'a>( + lua: &Lua, + (delay, tof, args): (Option, Value<'a>, Variadic>), +) -> Result<()> { + task_wait(lua, delay).await?; + get_thread_from_arg(lua, tof)? + .into_async::<_, Variadic>>(args) + .await?; + Ok(()) +} + +async fn task_spawn<'a>(lua: &Lua, (tof, args): (Value<'a>, Variadic>)) -> Result<()> { + get_thread_from_arg(lua, tof)? + .into_async::<_, Variadic>>(args) + .await?; + Ok(()) +} + +// FIXME: It doesn't seem possible to properly make an async wait +// function with mlua right now, something breaks when using +// the async wait function inside of a coroutine +async fn task_wait(_: &Lua, duration: Option) -> Result { + let start = Instant::now(); + time::sleep( + duration + .map(Duration::from_secs_f32) + .unwrap_or(Duration::ZERO), + ) + .await; + let end = Instant::now(); + Ok((end - start).as_secs_f32()) } diff --git a/src/lib/lib.rs b/src/lib/lib.rs index f45f2de..751de40 100644 --- a/src/lib/lib.rs +++ b/src/lib/lib.rs @@ -2,6 +2,7 @@ use std::collections::HashSet; use anyhow::{bail, Result}; use mlua::Lua; +use tokio::task; pub mod globals; pub mod utils; @@ -61,25 +62,40 @@ impl Lune { } pub async fn run(&self, name: &str, chunk: &str) -> Result<()> { - let mut lua = Lua::new(); - for global in &self.globals { - lua = match &global { - LuneGlobal::Console => create_console(lua).await?, - LuneGlobal::Fs => create_fs(lua).await?, - LuneGlobal::Net => create_net(lua).await?, - LuneGlobal::Process => create_process(lua, self.args.clone()).await?, - LuneGlobal::Task => create_task(lua).await?, - } - } - let result = lua.load(chunk).set_name(name)?.exec_async().await; - match result { - Ok(_) => Ok(()), - Err(e) => bail!( - "\n{}\n{}", - format_label("ERROR"), - pretty_format_luau_error(&e) - ), - } + let run_name = name.to_owned(); + let run_chunk = chunk.to_owned(); + let run_globals = self.globals.to_owned(); + let run_args = self.args.to_owned(); + // Spawn a thread-local task so that we can then spawn + // more tasks in our globals without the Send requirement + let local = task::LocalSet::new(); + local + .run_until(async move { + task::spawn_local(async move { + let lua = Lua::new(); + for global in &run_globals { + match &global { + LuneGlobal::Console => create_console(&lua).await?, + LuneGlobal::Fs => create_fs(&lua).await?, + LuneGlobal::Net => create_net(&lua).await?, + LuneGlobal::Process => create_process(&lua, run_args.clone()).await?, + LuneGlobal::Task => create_task(&lua).await?, + } + } + let result = lua.load(&run_chunk).set_name(&run_name)?.exec_async().await; + match result { + Ok(_) => Ok(()), + Err(e) => bail!( + "\n{}\n{}", + format_label("ERROR"), + pretty_format_luau_error(&e) + ), + } + }) + .await + .unwrap() + }) + .await } } diff --git a/src/lib/luau/task.luau b/src/lib/luau/task.luau deleted file mode 100644 index b598086..0000000 --- a/src/lib/luau/task.luau +++ /dev/null @@ -1,112 +0,0 @@ -local MINIMUM_DELAY_TIME = 1 / 100 - -type ThreadOrFunction = thread | (A...) -> R... -type AnyThreadOrFunction = ThreadOrFunction<...any, ...any> - -type WaitingThreadKind = "Normal" | "Deferred" | "Delayed" -type WaitingThread = { - idx: number, - kind: WaitingThreadKind, - thread: thread, - args: { [number]: any, n: number }, -} - -local waitingThreadCounter = 0 -local waitingThreads: { WaitingThread } = {} - -local function scheduleWaitingThreads() - -- Grab currently waiting threads and clear the queue but keep capacity - local threadsToResume: { WaitingThread } = table.clone(waitingThreads) - table.clear(waitingThreads) - table.sort(threadsToResume, function(t0, t1) - local k0: WaitingThreadKind = t0.kind - local k1: WaitingThreadKind = t1.kind - if k0 == k1 then - return t0.idx < t1.idx - end - if k0 == "Normal" then - return true - elseif k1 == "Normal" then - return false - elseif k0 == "Deferred" then - return true - else - return false - end - end) - -- Resume threads in order, giving args & waiting if necessary - for _, waitingThread in threadsToResume do - coroutine.resume( - waitingThread.thread, - table.unpack(waitingThread.args, 1, waitingThread.args.n) - ) - end -end - -local function insertWaitingThread(kind: WaitingThreadKind, tof: AnyThreadOrFunction, ...: any) - if typeof(tof) ~= "thread" and typeof(tof) ~= "function" then - if tof == nil then - error("Expected thread or function, got nil", 3) - end - error( - string.format("Expected thread or function, got %s %s", typeof(tof), tostring(tof)), - 3 - ) - end - local thread = if type(tof) == "function" then coroutine.create(tof) else tof - waitingThreadCounter += 1 - local waitingThread: WaitingThread = { - idx = waitingThreadCounter, - kind = kind, - thread = thread, - args = table.pack(...), - } - table.insert(waitingThreads, waitingThread) - return waitingThread -end - -local function cancel(thread: unknown) - if typeof(thread) ~= "thread" then - if thread == nil then - error("Expected thread, got nil", 2) - end - error(string.format("Expected thread, got %s %s", typeof(thread), tostring(thread)), 2) - else - coroutine.close(thread) - end -end - -local function defer(tof: AnyThreadOrFunction, ...: any): thread - local waiting = insertWaitingThread("Deferred", tof, ...) - local original = waiting.thread - waiting.thread = coroutine.create(function(...) - task.wait(1 / 1_000_000) - coroutine.resume(original, ...) - end) - scheduleWaitingThreads() - return waiting.thread -end - -local function delay(delay: number?, tof: AnyThreadOrFunction, ...: any): thread - local waiting = insertWaitingThread("Delayed", tof, ...) - local original = waiting.thread - waiting.thread = coroutine.create(function(...) - task.wait(math.max(MINIMUM_DELAY_TIME, delay or 0)) - coroutine.resume(original, ...) - end) - scheduleWaitingThreads() - return waiting.thread -end - -local function spawn(tof: AnyThreadOrFunction, ...: any): thread - local waiting = insertWaitingThread("Normal", tof, ...) - scheduleWaitingThreads() - return waiting.thread -end - -return { - cancel = cancel, - defer = defer, - delay = delay, - spawn = spawn, -}