diff --git a/src/main.luau b/src/main.luau index 742f14e..bfa13d2 100644 --- a/src/main.luau +++ b/src/main.luau @@ -3,16 +3,18 @@ for i = 1, 5 do local thread = coroutine.running() local counter = 0 - for j = 1, 10_000 do + for j = 1, 50_000 do spawn(function() - wait(0.1 * math.random()) + wait(0.1 + 0.1 * math.random()) counter += 1 - if counter == 10_000 then + if counter == 50_000 then print("completed iteration " .. tostring(i) .. " of 5") spawn(thread) end end) end + wait(0.1 * math.random()) + coroutine.yield() end diff --git a/src/thread_runtime.rs b/src/thread_runtime.rs index 43f1804..70ebf72 100644 --- a/src/thread_runtime.rs +++ b/src/thread_runtime.rs @@ -3,8 +3,9 @@ use std::{collections::VecDeque, rc::Rc}; use mlua::prelude::*; use smol::{ channel::{Receiver, Sender}, - future::race, + future::{race, yield_now}, lock::Mutex, + stream::StreamExt, *, }; @@ -14,6 +15,7 @@ use super::{ }; pub struct ThreadRuntime { + pending_key: LuaRegistryKey, queue: Rc>>, tx: Sender<()>, rx: Receiver<()>, @@ -72,7 +74,25 @@ impl ThreadRuntime { lua.globals().set("spawn", fn_spawn)?; lua.globals().set("defer", fn_defer)?; - Ok(ThreadRuntime { queue, tx, rx }) + // HACK: Extract mlua "pending" constant value + let pending = lua + .create_async_function(|_, ()| async move { + yield_now().await; + Ok(()) + }) + .unwrap() + .into_lua_thread(lua) + .unwrap() + .resume::<_, LuaValue>(()) + .unwrap(); + let pending_key = lua.create_registry_value(pending).unwrap(); + + Ok(ThreadRuntime { + pending_key, + queue, + tx, + rx, + }) } /** @@ -110,39 +130,43 @@ impl ThreadRuntime { // executor forward, until all lua threads finish let fut = async { loop { - let did_spawn = race( + race( // Wait for next futures step... async { lua_exec.tick().await; - false }, // ...or for a new thread to arrive async { self.rx.recv().await.ok(); - true }, ) .await; // If a new thread was spawned onto queue, we // must drain it and schedule on the executor - if did_spawn { - let queued_threads = self.queue.lock().await.drain(..).collect::>(); - for queued_thread in queued_threads { - // NOTE: Thread may have been cancelled from lua - // before we got here, so we need to check it again - let (thread, args) = queued_thread.into_inner(lua); - if thread.status() == LuaThreadStatus::Resumable { - let fut = thread.into_async::<_, ()>(args); - lua_exec - .spawn(async move { - match fut.await { - Ok(()) => {} - Err(e) => eprintln!("{e}"), + for queued_thread in self.queue.lock().await.drain(..) { + // NOTE: Thread may have been cancelled from lua + // before we got here, so we need to check it again + let (thread, args) = queued_thread.into_inner(lua); + if thread.status() == LuaThreadStatus::Resumable { + let pending = lua.registry_value(&self.pending_key).unwrap(); + let mut stream = thread.into_async::<_, LuaValue>(args); + + // Keep resuming the thread until we get a + // value that is not the mlua pending value + let fut = async move { + while let Some(res) = stream.next().await { + match res { + Err(e) => eprintln!("{e}"), + Ok(v) if v != pending => { + break; } - }) - .detach(); - } + Ok(_) => {} + } + } + }; + + lua_exec.spawn(fut).detach(); } }