mirror of
https://github.com/lune-org/mlua-luau-scheduler.git
synced 2025-04-10 21:40:55 +01:00
Add runtime-compatible versions of coroutine resume and wrap to functions struct
This commit is contained in:
parent
743d1075bf
commit
b4bc15d4ce
3 changed files with 121 additions and 14 deletions
104
lib/functions.rs
104
lib/functions.rs
|
@ -1,5 +1,5 @@
|
||||||
#![allow(unused_imports)]
|
#![allow(unused_imports)]
|
||||||
#![allow(clippy::module_name_repetitions)]
|
#![allow(clippy::too_many_lines)]
|
||||||
|
|
||||||
use std::process::ExitCode;
|
use std::process::ExitCode;
|
||||||
|
|
||||||
|
@ -26,10 +26,34 @@ exit(...)
|
||||||
yield()
|
yield()
|
||||||
";
|
";
|
||||||
|
|
||||||
|
const WRAP_IMPL_LUA: &str = r"
|
||||||
|
local t = create(...)
|
||||||
|
return function(...)
|
||||||
|
local results = { resume(f, ...) }
|
||||||
|
if results[1] then
|
||||||
|
return select(2, unpack(results))
|
||||||
|
else
|
||||||
|
error(results[2], 2)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
";
|
||||||
|
|
||||||
/**
|
/**
|
||||||
A collection of lua functions that may be called to interact with a [`Runtime`].
|
A collection of lua functions that may be called to interact with a [`Runtime`].
|
||||||
*/
|
*/
|
||||||
pub struct Functions<'lua> {
|
pub struct Functions<'lua> {
|
||||||
|
/**
|
||||||
|
Implementation of `coroutine.resume` that handles async polling properly.
|
||||||
|
|
||||||
|
Defers onto the runtime queue if the thread calls an async function.
|
||||||
|
*/
|
||||||
|
pub resume: LuaFunction<'lua>,
|
||||||
|
/**
|
||||||
|
Implementation of `coroutine.wrap` that handles async polling properly.
|
||||||
|
|
||||||
|
Defers onto the runtime queue if the thread calls an async function.
|
||||||
|
*/
|
||||||
|
pub wrap: LuaFunction<'lua>,
|
||||||
/**
|
/**
|
||||||
Resumes a function / thread once instantly, and runs until first yield.
|
Resumes a function / thread once instantly, and runs until first yield.
|
||||||
|
|
||||||
|
@ -84,6 +108,57 @@ impl<'lua> Functions<'lua> {
|
||||||
.expect(ERR_METADATA_NOT_ATTACHED)
|
.expect(ERR_METADATA_NOT_ATTACHED)
|
||||||
.clone();
|
.clone();
|
||||||
|
|
||||||
|
let resume_queue = defer_queue.clone();
|
||||||
|
let resume_map = result_map.clone();
|
||||||
|
let resume =
|
||||||
|
lua.create_function(move |lua, (thread, args): (LuaThread, LuaMultiValue)| {
|
||||||
|
match thread.resume::<_, LuaMultiValue>(args.clone()) {
|
||||||
|
Ok(v) => {
|
||||||
|
if v.get(0).map(is_poll_pending).unwrap_or_default() {
|
||||||
|
// Pending, defer to scheduler and return nil
|
||||||
|
resume_queue.push_item(lua, &thread, args)?;
|
||||||
|
(true, LuaValue::Nil).into_lua_multi(lua)
|
||||||
|
} else {
|
||||||
|
// Not pending, store the value if thread is done
|
||||||
|
if thread.status() != LuaThreadStatus::Resumable {
|
||||||
|
let id = ThreadId::from(&thread);
|
||||||
|
if resume_map.is_tracked(id) {
|
||||||
|
let res = ThreadResult::new(Ok(v.clone()), lua);
|
||||||
|
resume_map.insert(id, res);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(true, v).into_lua_multi(lua)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
// Not pending, store the error
|
||||||
|
let id = ThreadId::from(&thread);
|
||||||
|
if resume_map.is_tracked(id) {
|
||||||
|
let res = ThreadResult::new(Err(e.clone()), lua);
|
||||||
|
resume_map.insert(id, res);
|
||||||
|
}
|
||||||
|
(false, e.to_string()).into_lua_multi(lua)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let wrap_env = lua.create_table_from(vec![
|
||||||
|
("resume", resume.clone()),
|
||||||
|
("error", lua.globals().get::<_, LuaFunction>("error")?),
|
||||||
|
("select", lua.globals().get::<_, LuaFunction>("select")?),
|
||||||
|
(
|
||||||
|
"create",
|
||||||
|
lua.globals()
|
||||||
|
.get::<_, LuaTable>("coroutine")?
|
||||||
|
.get::<_, LuaFunction>("create")?,
|
||||||
|
),
|
||||||
|
])?;
|
||||||
|
let wrap = lua
|
||||||
|
.load(WRAP_IMPL_LUA)
|
||||||
|
.set_name("=__runtime_wrap")
|
||||||
|
.set_environment(wrap_env)
|
||||||
|
.into_function()?;
|
||||||
|
|
||||||
let spawn_map = result_map.clone();
|
let spawn_map = result_map.clone();
|
||||||
let spawn = lua.create_function(
|
let spawn = lua.create_function(
|
||||||
move |lua, (tof, args): (LuaThreadOrFunction, LuaMultiValue)| {
|
move |lua, (tof, args): (LuaThreadOrFunction, LuaMultiValue)| {
|
||||||
|
@ -96,7 +171,8 @@ impl<'lua> Functions<'lua> {
|
||||||
if v.get(0).map(is_poll_pending).unwrap_or_default() {
|
if v.get(0).map(is_poll_pending).unwrap_or_default() {
|
||||||
spawn_queue.push_item(lua, &thread, args)?;
|
spawn_queue.push_item(lua, &thread, args)?;
|
||||||
} else {
|
} else {
|
||||||
// Not pending, store the value
|
// Not pending, store the value if thread is done
|
||||||
|
if thread.status() != LuaThreadStatus::Resumable {
|
||||||
let id = ThreadId::from(&thread);
|
let id = ThreadId::from(&thread);
|
||||||
if spawn_map.is_tracked(id) {
|
if spawn_map.is_tracked(id) {
|
||||||
let res = ThreadResult::new(Ok(v), lua);
|
let res = ThreadResult::new(Ok(v), lua);
|
||||||
|
@ -104,6 +180,7 @@ impl<'lua> Functions<'lua> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error_callback.call(&e);
|
error_callback.call(&e);
|
||||||
// Not pending, store the error
|
// Not pending, store the error
|
||||||
|
@ -165,6 +242,8 @@ impl<'lua> Functions<'lua> {
|
||||||
.into_function()?;
|
.into_function()?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
|
resume,
|
||||||
|
wrap,
|
||||||
spawn,
|
spawn,
|
||||||
defer,
|
defer,
|
||||||
cancel,
|
cancel,
|
||||||
|
@ -172,3 +251,24 @@ impl<'lua> Functions<'lua> {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl Functions<'_> {
|
||||||
|
/**
|
||||||
|
Injects [`Runtime`]-compatible functions into the given [`Lua`] instance.
|
||||||
|
|
||||||
|
This will overwrite the following functions:
|
||||||
|
|
||||||
|
- `coroutine.resume`
|
||||||
|
- `coroutine.wrap`
|
||||||
|
|
||||||
|
# Errors
|
||||||
|
|
||||||
|
Errors when out of memory, or if default Lua globals are missing.
|
||||||
|
*/
|
||||||
|
pub fn inject_compat(&self, lua: &Lua) -> LuaResult<()> {
|
||||||
|
let co: LuaTable = lua.globals().get("coroutine")?;
|
||||||
|
co.set("resume", self.resume.clone())?;
|
||||||
|
co.set("wrap", self.wrap.clone())?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -35,7 +35,6 @@ impl ThreadResultMap {
|
||||||
self.tracked.borrow().contains(&id)
|
self.tracked.borrow().contains(&id)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
|
||||||
pub fn insert(&self, id: ThreadId, result: ThreadResult) {
|
pub fn insert(&self, id: ThreadId, result: ThreadResult) {
|
||||||
debug_assert!(self.is_tracked(id), "Thread must be tracked");
|
debug_assert!(self.is_tracked(id), "Thread must be tracked");
|
||||||
self.results.borrow_mut().insert(id, result);
|
self.results.borrow_mut().insert(id, result);
|
||||||
|
@ -44,7 +43,6 @@ impl ThreadResultMap {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
|
||||||
pub async fn listen(&self, id: ThreadId) {
|
pub async fn listen(&self, id: ThreadId) {
|
||||||
debug_assert!(self.is_tracked(id), "Thread must be tracked");
|
debug_assert!(self.is_tracked(id), "Thread must be tracked");
|
||||||
if !self.results.borrow().contains_key(&id) {
|
if !self.results.borrow().contains_key(&id) {
|
||||||
|
|
|
@ -329,13 +329,22 @@ impl<'lua> Runtime<'lua> {
|
||||||
// Spawn it on the executor and store the result when done
|
// Spawn it on the executor and store the result when done
|
||||||
local_exec
|
local_exec
|
||||||
.spawn(async move {
|
.spawn(async move {
|
||||||
|
if id_tracked {
|
||||||
|
// Run until yield and check if we got a final result
|
||||||
|
let res = run_until_yield(thread.clone(), args).await;
|
||||||
|
if let Err(e) = res.as_ref() {
|
||||||
|
self.error_callback.call(e);
|
||||||
|
}
|
||||||
|
if thread.status() != LuaThreadStatus::Resumable {
|
||||||
|
let thread_res = ThreadResult::new(res, self.lua);
|
||||||
|
result_map_inner.unwrap().insert(id, thread_res);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Just run until yield
|
||||||
let res = run_until_yield(thread, args).await;
|
let res = run_until_yield(thread, args).await;
|
||||||
if let Err(e) = res.as_ref() {
|
if let Err(e) = res.as_ref() {
|
||||||
self.error_callback.call(e);
|
self.error_callback.call(e);
|
||||||
}
|
}
|
||||||
if id_tracked {
|
|
||||||
let thread_res = ThreadResult::new(res, self.lua);
|
|
||||||
result_map_inner.unwrap().insert(id, thread_res);
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.detach();
|
.detach();
|
||||||
|
|
Loading…
Add table
Reference in a new issue