Add runtime-compatible versions of coroutine resume and wrap to functions struct

This commit is contained in:
Filip Tibell 2024-02-01 14:35:32 +01:00
parent 743d1075bf
commit b4bc15d4ce
No known key found for this signature in database
3 changed files with 121 additions and 14 deletions

View file

@ -1,5 +1,5 @@
#![allow(unused_imports)]
#![allow(clippy::module_name_repetitions)]
#![allow(clippy::too_many_lines)]
use std::process::ExitCode;
@ -26,10 +26,34 @@ exit(...)
yield()
";
const WRAP_IMPL_LUA: &str = r"
local t = create(...)
return function(...)
local results = { resume(f, ...) }
if results[1] then
return select(2, unpack(results))
else
error(results[2], 2)
end
end
";
/**
A collection of lua functions that may be called to interact with a [`Runtime`].
*/
pub struct Functions<'lua> {
/**
Implementation of `coroutine.resume` that handles async polling properly.
Defers onto the runtime queue if the thread calls an async function.
*/
pub resume: LuaFunction<'lua>,
/**
Implementation of `coroutine.wrap` that handles async polling properly.
Defers onto the runtime queue if the thread calls an async function.
*/
pub wrap: LuaFunction<'lua>,
/**
Resumes a function / thread once instantly, and runs until first yield.
@ -84,6 +108,57 @@ impl<'lua> Functions<'lua> {
.expect(ERR_METADATA_NOT_ATTACHED)
.clone();
let resume_queue = defer_queue.clone();
let resume_map = result_map.clone();
let resume =
lua.create_function(move |lua, (thread, args): (LuaThread, LuaMultiValue)| {
match thread.resume::<_, LuaMultiValue>(args.clone()) {
Ok(v) => {
if v.get(0).map(is_poll_pending).unwrap_or_default() {
// Pending, defer to scheduler and return nil
resume_queue.push_item(lua, &thread, args)?;
(true, LuaValue::Nil).into_lua_multi(lua)
} else {
// Not pending, store the value if thread is done
if thread.status() != LuaThreadStatus::Resumable {
let id = ThreadId::from(&thread);
if resume_map.is_tracked(id) {
let res = ThreadResult::new(Ok(v.clone()), lua);
resume_map.insert(id, res);
}
}
(true, v).into_lua_multi(lua)
}
}
Err(e) => {
// Not pending, store the error
let id = ThreadId::from(&thread);
if resume_map.is_tracked(id) {
let res = ThreadResult::new(Err(e.clone()), lua);
resume_map.insert(id, res);
}
(false, e.to_string()).into_lua_multi(lua)
}
}
})?;
let wrap_env = lua.create_table_from(vec![
("resume", resume.clone()),
("error", lua.globals().get::<_, LuaFunction>("error")?),
("select", lua.globals().get::<_, LuaFunction>("select")?),
(
"create",
lua.globals()
.get::<_, LuaTable>("coroutine")?
.get::<_, LuaFunction>("create")?,
),
])?;
let wrap = lua
.load(WRAP_IMPL_LUA)
.set_name("=__runtime_wrap")
.set_environment(wrap_env)
.into_function()?;
let spawn_map = result_map.clone();
let spawn = lua.create_function(
move |lua, (tof, args): (LuaThreadOrFunction, LuaMultiValue)| {
@ -96,11 +171,13 @@ impl<'lua> Functions<'lua> {
if v.get(0).map(is_poll_pending).unwrap_or_default() {
spawn_queue.push_item(lua, &thread, args)?;
} else {
// Not pending, store the value
let id = ThreadId::from(&thread);
if spawn_map.is_tracked(id) {
let res = ThreadResult::new(Ok(v), lua);
spawn_map.insert(id, res);
// Not pending, store the value if thread is done
if thread.status() != LuaThreadStatus::Resumable {
let id = ThreadId::from(&thread);
if spawn_map.is_tracked(id) {
let res = ThreadResult::new(Ok(v), lua);
spawn_map.insert(id, res);
}
}
}
}
@ -165,6 +242,8 @@ impl<'lua> Functions<'lua> {
.into_function()?;
Ok(Self {
resume,
wrap,
spawn,
defer,
cancel,
@ -172,3 +251,24 @@ impl<'lua> Functions<'lua> {
})
}
}
impl Functions<'_> {
/**
Injects [`Runtime`]-compatible functions into the given [`Lua`] instance.
This will overwrite the following functions:
- `coroutine.resume`
- `coroutine.wrap`
# Errors
Errors when out of memory, or if default Lua globals are missing.
*/
pub fn inject_compat(&self, lua: &Lua) -> LuaResult<()> {
let co: LuaTable = lua.globals().get("coroutine")?;
co.set("resume", self.resume.clone())?;
co.set("wrap", self.wrap.clone())?;
Ok(())
}
}

View file

@ -35,7 +35,6 @@ impl ThreadResultMap {
self.tracked.borrow().contains(&id)
}
#[inline]
pub fn insert(&self, id: ThreadId, result: ThreadResult) {
debug_assert!(self.is_tracked(id), "Thread must be tracked");
self.results.borrow_mut().insert(id, result);
@ -44,7 +43,6 @@ impl ThreadResultMap {
}
}
#[inline]
pub async fn listen(&self, id: ThreadId) {
debug_assert!(self.is_tracked(id), "Thread must be tracked");
if !self.results.borrow().contains_key(&id) {

View file

@ -329,13 +329,22 @@ impl<'lua> Runtime<'lua> {
// Spawn it on the executor and store the result when done
local_exec
.spawn(async move {
let res = run_until_yield(thread, args).await;
if let Err(e) = res.as_ref() {
self.error_callback.call(e);
}
if id_tracked {
let thread_res = ThreadResult::new(res, self.lua);
result_map_inner.unwrap().insert(id, thread_res);
// Run until yield and check if we got a final result
let res = run_until_yield(thread.clone(), args).await;
if let Err(e) = res.as_ref() {
self.error_callback.call(e);
}
if thread.status() != LuaThreadStatus::Resumable {
let thread_res = ThreadResult::new(res, self.lua);
result_map_inner.unwrap().insert(id, thread_res);
}
} else {
// Just run until yield
let res = run_until_yield(thread, args).await;
if let Err(e) = res.as_ref() {
self.error_callback.call(e);
}
}
})
.detach();