diff --git a/lib/functions.rs b/lib/functions.rs index 876d76d..db33748 100644 --- a/lib/functions.rs +++ b/lib/functions.rs @@ -1,5 +1,5 @@ #![allow(unused_imports)] -#![allow(clippy::module_name_repetitions)] +#![allow(clippy::too_many_lines)] use std::process::ExitCode; @@ -26,10 +26,34 @@ exit(...) yield() "; +const WRAP_IMPL_LUA: &str = r" +local t = create(...) +return function(...) + local results = { resume(f, ...) } + if results[1] then + return select(2, unpack(results)) + else + error(results[2], 2) + end +end +"; + /** A collection of lua functions that may be called to interact with a [`Runtime`]. */ pub struct Functions<'lua> { + /** + Implementation of `coroutine.resume` that handles async polling properly. + + Defers onto the runtime queue if the thread calls an async function. + */ + pub resume: LuaFunction<'lua>, + /** + Implementation of `coroutine.wrap` that handles async polling properly. + + Defers onto the runtime queue if the thread calls an async function. + */ + pub wrap: LuaFunction<'lua>, /** Resumes a function / thread once instantly, and runs until first yield. @@ -84,6 +108,57 @@ impl<'lua> Functions<'lua> { .expect(ERR_METADATA_NOT_ATTACHED) .clone(); + let resume_queue = defer_queue.clone(); + let resume_map = result_map.clone(); + let resume = + lua.create_function(move |lua, (thread, args): (LuaThread, LuaMultiValue)| { + match thread.resume::<_, LuaMultiValue>(args.clone()) { + Ok(v) => { + if v.get(0).map(is_poll_pending).unwrap_or_default() { + // Pending, defer to scheduler and return nil + resume_queue.push_item(lua, &thread, args)?; + (true, LuaValue::Nil).into_lua_multi(lua) + } else { + // Not pending, store the value if thread is done + if thread.status() != LuaThreadStatus::Resumable { + let id = ThreadId::from(&thread); + if resume_map.is_tracked(id) { + let res = ThreadResult::new(Ok(v.clone()), lua); + resume_map.insert(id, res); + } + } + (true, v).into_lua_multi(lua) + } + } + Err(e) => { + // Not pending, store the error + let id = ThreadId::from(&thread); + if resume_map.is_tracked(id) { + let res = ThreadResult::new(Err(e.clone()), lua); + resume_map.insert(id, res); + } + (false, e.to_string()).into_lua_multi(lua) + } + } + })?; + + let wrap_env = lua.create_table_from(vec![ + ("resume", resume.clone()), + ("error", lua.globals().get::<_, LuaFunction>("error")?), + ("select", lua.globals().get::<_, LuaFunction>("select")?), + ( + "create", + lua.globals() + .get::<_, LuaTable>("coroutine")? + .get::<_, LuaFunction>("create")?, + ), + ])?; + let wrap = lua + .load(WRAP_IMPL_LUA) + .set_name("=__runtime_wrap") + .set_environment(wrap_env) + .into_function()?; + let spawn_map = result_map.clone(); let spawn = lua.create_function( move |lua, (tof, args): (LuaThreadOrFunction, LuaMultiValue)| { @@ -96,11 +171,13 @@ impl<'lua> Functions<'lua> { if v.get(0).map(is_poll_pending).unwrap_or_default() { spawn_queue.push_item(lua, &thread, args)?; } else { - // Not pending, store the value - let id = ThreadId::from(&thread); - if spawn_map.is_tracked(id) { - let res = ThreadResult::new(Ok(v), lua); - spawn_map.insert(id, res); + // Not pending, store the value if thread is done + if thread.status() != LuaThreadStatus::Resumable { + let id = ThreadId::from(&thread); + if spawn_map.is_tracked(id) { + let res = ThreadResult::new(Ok(v), lua); + spawn_map.insert(id, res); + } } } } @@ -165,6 +242,8 @@ impl<'lua> Functions<'lua> { .into_function()?; Ok(Self { + resume, + wrap, spawn, defer, cancel, @@ -172,3 +251,24 @@ impl<'lua> Functions<'lua> { }) } } + +impl Functions<'_> { + /** + Injects [`Runtime`]-compatible functions into the given [`Lua`] instance. + + This will overwrite the following functions: + + - `coroutine.resume` + - `coroutine.wrap` + + # Errors + + Errors when out of memory, or if default Lua globals are missing. + */ + pub fn inject_compat(&self, lua: &Lua) -> LuaResult<()> { + let co: LuaTable = lua.globals().get("coroutine")?; + co.set("resume", self.resume.clone())?; + co.set("wrap", self.wrap.clone())?; + Ok(()) + } +} diff --git a/lib/result_map.rs b/lib/result_map.rs index 5907f50..fe08a5f 100644 --- a/lib/result_map.rs +++ b/lib/result_map.rs @@ -35,7 +35,6 @@ impl ThreadResultMap { self.tracked.borrow().contains(&id) } - #[inline] pub fn insert(&self, id: ThreadId, result: ThreadResult) { debug_assert!(self.is_tracked(id), "Thread must be tracked"); self.results.borrow_mut().insert(id, result); @@ -44,7 +43,6 @@ impl ThreadResultMap { } } - #[inline] pub async fn listen(&self, id: ThreadId) { debug_assert!(self.is_tracked(id), "Thread must be tracked"); if !self.results.borrow().contains_key(&id) { diff --git a/lib/runtime.rs b/lib/runtime.rs index 213d137..1c26c97 100644 --- a/lib/runtime.rs +++ b/lib/runtime.rs @@ -329,13 +329,22 @@ impl<'lua> Runtime<'lua> { // Spawn it on the executor and store the result when done local_exec .spawn(async move { - let res = run_until_yield(thread, args).await; - if let Err(e) = res.as_ref() { - self.error_callback.call(e); - } if id_tracked { - let thread_res = ThreadResult::new(res, self.lua); - result_map_inner.unwrap().insert(id, thread_res); + // Run until yield and check if we got a final result + let res = run_until_yield(thread.clone(), args).await; + if let Err(e) = res.as_ref() { + self.error_callback.call(e); + } + if thread.status() != LuaThreadStatus::Resumable { + let thread_res = ThreadResult::new(res, self.lua); + result_map_inner.unwrap().insert(id, thread_res); + } + } else { + // Just run until yield + let res = run_until_yield(thread, args).await; + if let Err(e) = res.as_ref() { + self.error_callback.call(e); + } } }) .detach();