Get rid of thread handles in favor of simple id-based map

This commit is contained in:
Filip Tibell 2024-01-31 20:17:15 +01:00
parent 5820858147
commit 013537b27b
No known key found for this signature in database
10 changed files with 153 additions and 179 deletions

View file

@ -26,13 +26,13 @@ pub fn main() -> LuaResult<()> {
// Load the main script into the runtime, and keep track of the thread we spawn
let main = lua.load(MAIN_SCRIPT);
let handle = rt.push_thread_front(main, ())?;
let id = rt.push_thread_front(main, ())?;
// Run until completion
block_on(rt.run());
// We should have gotten the error back from our script
assert!(handle.result(&lua).unwrap().is_err());
assert!(rt.thread_result(id).unwrap().is_err());
Ok(())
}

View file

@ -32,13 +32,13 @@ pub fn main() -> LuaResult<()> {
// Load the main script into the runtime, and keep track of the thread we spawn
let main = lua.load(MAIN_SCRIPT);
let handle = rt.push_thread_front(main, ())?;
let id = rt.push_thread_front(main, ())?;
// Run until completion
block_on(rt.run());
// We should have gotten proper values back from our script
let res = handle.result(&lua).unwrap().unwrap();
let res = rt.thread_result(id).unwrap().unwrap();
let nums = Vec::<usize>::from_lua_multi(res, &lua)?;
assert_eq!(nums, vec![1, 2, 3, 4, 5, 6]);

View file

@ -6,8 +6,10 @@ use mlua::prelude::*;
use crate::{
error_callback::ThreadErrorCallback,
queue::{DeferredThreadQueue, SpawnedThreadQueue},
result_map::ThreadResultMap,
runtime::Runtime,
util::LuaThreadOrFunction,
thread_id::ThreadId,
util::{is_poll_pending, LuaThreadOrFunction, ThreadResult},
};
const ERR_METADATA_NOT_ATTACHED: &str = "\
@ -63,24 +65,39 @@ impl<'lua> Functions<'lua> {
.app_data_ref::<ThreadErrorCallback>()
.expect(ERR_METADATA_NOT_ATTACHED)
.clone();
let result_map = lua
.app_data_ref::<ThreadResultMap>()
.expect(ERR_METADATA_NOT_ATTACHED)
.clone();
let spawn_map = result_map.clone();
let spawn = lua.create_function(
move |lua, (tof, args): (LuaThreadOrFunction, LuaMultiValue)| {
let thread = tof.into_thread(lua)?;
if thread.status() == LuaThreadStatus::Resumable {
// NOTE: We need to resume the thread once instantly for correct behavior,
// and only if we get the pending value back we can spawn to async executor
match thread.resume::<_, LuaValue>(args.clone()) {
match thread.resume::<_, LuaMultiValue>(args.clone()) {
Ok(v) => {
if v.as_light_userdata()
.map(|l| l == Lua::poll_pending())
.unwrap_or_default()
{
if v.get(0).map(is_poll_pending).unwrap_or_default() {
spawn_queue.push_item(lua, &thread, args)?;
} else {
// Not pending, store the value
let id = ThreadId::from(&thread);
if spawn_map.is_tracked(id) {
let res = ThreadResult::new(Ok(v), lua);
spawn_map.insert(id, res);
}
}
}
Err(e) => {
error_callback.call(&e);
// Not pending, store the error
let id = ThreadId::from(&thread);
if spawn_map.is_tracked(id) {
let res = ThreadResult::new(Err(e), lua);
spawn_map.insert(id, res);
}
}
};
}

View file

@ -1,130 +0,0 @@
#![allow(unused_imports)]
#![allow(clippy::missing_panics_doc)]
#![allow(clippy::module_name_repetitions)]
use std::{
cell::{Cell, RefCell},
rc::Rc,
};
use event_listener::Event;
use mlua::prelude::*;
use crate::{
runtime::Runtime,
status::Status,
traits::IntoLuaThread,
util::{run_until_yield, ThreadResult, ThreadWithArgs},
};
/**
A handle to a thread that has been spawned onto a [`Runtime`].
This handle contains a public method, [`Handle::result`], which may
be used to extract the result of the thread, once it finishes running.
A result may be waited for using the [`Handle::listen`] method.
*/
#[derive(Debug, Clone)]
pub struct Handle {
thread: Rc<RefCell<Option<ThreadWithArgs>>>,
result: Rc<RefCell<Option<ThreadResult>>>,
status: Rc<Cell<bool>>,
event: Rc<Event>,
}
impl Handle {
pub(crate) fn new<'lua>(
lua: &'lua Lua,
thread: impl IntoLuaThread<'lua>,
args: impl IntoLuaMulti<'lua>,
) -> LuaResult<Self> {
let thread = thread.into_lua_thread(lua)?;
let args = args.into_lua_multi(lua)?;
let packed = ThreadWithArgs::new(lua, thread, args)?;
Ok(Self {
thread: Rc::new(RefCell::new(Some(packed))),
result: Rc::new(RefCell::new(None)),
status: Rc::new(Cell::new(false)),
event: Rc::new(Event::new()),
})
}
pub(crate) fn create_thread<'lua>(&self, lua: &'lua Lua) -> LuaResult<LuaThread<'lua>> {
let env = lua.create_table()?;
env.set("handle", self.clone())?;
lua.load("return handle:resume()")
.set_name("__runtime_handle")
.set_environment(env)
.into_lua_thread(lua)
}
fn take<'lua>(&self, lua: &'lua Lua) -> (LuaThread<'lua>, LuaMultiValue<'lua>) {
self.thread
.borrow_mut()
.take()
.expect("thread handle may only be taken once")
.into_inner(lua)
}
fn set<'lua>(&self, lua: &'lua Lua, result: &LuaResult<LuaMultiValue<'lua>>, is_final: bool) {
self.result
.borrow_mut()
.replace(ThreadResult::new(result.clone(), lua));
self.status.replace(is_final);
if is_final {
self.event.notify(usize::MAX);
}
}
/**
Extracts the result for this thread handle.
Depending on the current [`Runtime::status`], this method will return:
- [`Status::NotStarted`]: returns `None`.
- [`Status::Running`]: may return `Some(Ok(v))` or `Some(Err(e))`, but it is not guaranteed.
- [`Status::Completed`]: returns `Some(Ok(v))` or `Some(Err(e))`.
Note that this method also takes the value out of the handle, so it may only be called once.
Any subsequent calls after this method returns `Some` will return `None`.
*/
#[must_use]
pub fn result<'lua>(&self, lua: &'lua Lua) -> Option<LuaResult<LuaMultiValue<'lua>>> {
let mut res = self.result.borrow_mut();
res.take().map(|r| r.value(lua))
}
/**
Waits for this handle to have its final result available.
Does not wait if the final result is already available.
*/
pub async fn listen(&self) {
if !self.status.get() {
self.event.listen().await;
}
}
}
impl LuaUserData for Handle {
fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_async_method("resume", |lua, this, (): ()| async move {
/*
1. Take the thread and args out of the handle
2. Run the thread until it yields or completes
3. Store the result of the thread in the lua registry
4. Return the result of the thread back to lua as well, so that
it may be caught using the runtime and any error callback(s)
*/
let (thread, args) = this.take(lua);
let result = run_until_yield(thread.clone(), args).await;
let is_final = thread.status() != LuaThreadStatus::Resumable;
this.set(lua, &result, is_final);
result
});
}
}

View file

@ -1,7 +1,7 @@
mod error_callback;
mod functions;
mod handle;
mod queue;
mod result_map;
mod runtime;
mod status;
mod thread_id;
@ -9,7 +9,6 @@ mod traits;
mod util;
pub use functions::Functions;
pub use handle::Handle;
pub use runtime::Runtime;
pub use status::Status;
pub use thread_id::ThreadId;

View file

@ -6,7 +6,7 @@ use event_listener::Event;
use futures_lite::{Future, FutureExt};
use mlua::prelude::*;
use crate::{handle::Handle, traits::IntoLuaThread, util::ThreadWithArgs};
use crate::{traits::IntoLuaThread, util::ThreadWithArgs, ThreadId};
/**
Queue for storing [`LuaThread`]s with associated arguments.
@ -32,31 +32,18 @@ impl ThreadQueue {
lua: &'lua Lua,
thread: impl IntoLuaThread<'lua>,
args: impl IntoLuaMulti<'lua>,
) -> LuaResult<()> {
) -> LuaResult<ThreadId> {
let thread = thread.into_lua_thread(lua)?;
let args = args.into_lua_multi(lua)?;
tracing::trace!("pushing item to queue with {} args", args.len());
let id = ThreadId::from(&thread);
let stored = ThreadWithArgs::new(lua, thread, args)?;
self.queue.push(stored).into_lua_err()?;
self.event.notify(usize::MAX);
Ok(())
}
pub fn push_item_with_handle<'lua>(
&self,
lua: &'lua Lua,
thread: impl IntoLuaThread<'lua>,
args: impl IntoLuaMulti<'lua>,
) -> LuaResult<Handle> {
let handle = Handle::new(lua, thread, args)?;
let handle_thread = handle.create_thread(lua)?;
self.push_item(lua, handle_thread, ())?;
Ok(handle)
Ok(id)
}
pub fn drain_items<'outer, 'lua>(

41
lib/result_map.rs Normal file
View file

@ -0,0 +1,41 @@
use std::{
cell::RefCell,
collections::{HashMap, HashSet},
rc::Rc,
};
use crate::{thread_id::ThreadId, util::ThreadResult};
#[derive(Clone)]
pub(crate) struct ThreadResultMap {
tracked: Rc<RefCell<HashSet<ThreadId>>>,
inner: Rc<RefCell<HashMap<ThreadId, ThreadResult>>>,
}
impl ThreadResultMap {
pub fn new() -> Self {
Self {
tracked: Rc::new(RefCell::new(HashSet::new())),
inner: Rc::new(RefCell::new(HashMap::new())),
}
}
pub fn track(&self, id: ThreadId) {
self.tracked.borrow_mut().insert(id);
}
pub fn is_tracked(&self, id: ThreadId) -> bool {
self.tracked.borrow().contains(&id)
}
pub fn insert(&self, id: ThreadId, result: ThreadResult) {
assert!(self.is_tracked(id), "Thread must be tracked");
self.inner.borrow_mut().insert(id, result);
}
pub fn remove(&self, id: ThreadId) -> Option<ThreadResult> {
let res = self.inner.borrow_mut().remove(&id)?;
self.tracked.borrow_mut().remove(&id);
Some(res)
}
}

View file

@ -14,11 +14,12 @@ use tracing::Instrument;
use crate::{
error_callback::ThreadErrorCallback,
handle::Handle,
queue::{DeferredThreadQueue, FuturesQueue, SpawnedThreadQueue},
result_map::ThreadResultMap,
status::Status,
thread_id::ThreadId,
traits::IntoLuaThread,
util::run_until_yield,
util::{run_until_yield, ThreadResult},
};
const ERR_METADATA_ALREADY_ATTACHED: &str = "\
@ -45,6 +46,7 @@ pub struct Runtime<'lua> {
queue_spawn: SpawnedThreadQueue,
queue_defer: DeferredThreadQueue,
error_callback: ThreadErrorCallback,
result_map: ThreadResultMap,
status: Rc<Cell<Status>>,
}
@ -63,7 +65,7 @@ impl<'lua> Runtime<'lua> {
let queue_spawn = SpawnedThreadQueue::new();
let queue_defer = DeferredThreadQueue::new();
let error_callback = ThreadErrorCallback::default();
let status = Rc::new(Cell::new(Status::NotStarted));
let result_map = ThreadResultMap::new();
assert!(
lua.app_data_ref::<SpawnedThreadQueue>().is_none(),
@ -77,16 +79,24 @@ impl<'lua> Runtime<'lua> {
lua.app_data_ref::<ThreadErrorCallback>().is_none(),
"{ERR_METADATA_ALREADY_ATTACHED}"
);
assert!(
lua.app_data_ref::<ThreadResultMap>().is_none(),
"{ERR_METADATA_ALREADY_ATTACHED}"
);
lua.set_app_data(queue_spawn.clone());
lua.set_app_data(queue_defer.clone());
lua.set_app_data(error_callback.clone());
lua.set_app_data(result_map.clone());
let status = Rc::new(Cell::new(Status::NotStarted));
Runtime {
lua,
queue_spawn,
queue_defer,
error_callback,
result_map,
status,
}
}
@ -142,7 +152,7 @@ impl<'lua> Runtime<'lua> {
# Returns
Returns a [`Handle`] that can be used to retrieve the result of the thread.
Returns a [`ThreadId`] that can be used to retrieve the result of the thread.
Note that the result may not be available until [`Runtime::run`] completes.
@ -154,10 +164,11 @@ impl<'lua> Runtime<'lua> {
&self,
thread: impl IntoLuaThread<'lua>,
args: impl IntoLuaMulti<'lua>,
) -> LuaResult<Handle> {
) -> LuaResult<ThreadId> {
tracing::debug!(deferred = false, "new runtime thread");
self.queue_spawn
.push_item_with_handle(self.lua, thread, args)
let id = self.queue_spawn.push_item(self.lua, thread, args)?;
self.result_map.track(id);
Ok(id)
}
/**
@ -169,7 +180,7 @@ impl<'lua> Runtime<'lua> {
# Returns
Returns a [`Handle`] that can be used to retrieve the result of the thread.
Returns a [`ThreadId`] that can be used to retrieve the result of the thread.
Note that the result may not be available until [`Runtime::run`] completes.
@ -181,10 +192,30 @@ impl<'lua> Runtime<'lua> {
&self,
thread: impl IntoLuaThread<'lua>,
args: impl IntoLuaMulti<'lua>,
) -> LuaResult<Handle> {
) -> LuaResult<ThreadId> {
tracing::debug!(deferred = true, "new runtime thread");
self.queue_defer
.push_item_with_handle(self.lua, thread, args)
let id = self.queue_defer.push_item(self.lua, thread, args)?;
self.result_map.track(id);
Ok(id)
}
/**
Gets the tracked result for the [`LuaThread`] with the given [`ThreadId`].
Depending on the current [`Runtime::status`], this method will return:
- [`Status::NotStarted`]: returns `None`.
- [`Status::Running`]: may return `Some(Ok(v))` or `Some(Err(e))`, but it is not guaranteed.
- [`Status::Completed`]: returns `Some(Ok(v))` or `Some(Err(e))`.
Note that this method also takes the value out of the runtime and
stops tracking the given thread, so it may only be called once.
Any subsequent calls after this method returns `Some` will return `None`.
*/
#[must_use]
pub fn thread_result(&self, id: ThreadId) -> Option<LuaResult<LuaMultiValue<'lua>>> {
self.result_map.remove(id).map(|r| r.value(self.lua))
}
/**
@ -245,14 +276,29 @@ impl<'lua> Runtime<'lua> {
when there are new Lua threads to enqueue and potentially more work to be done.
*/
let fut = async {
let result_map = self.result_map.clone();
let process_thread = |thread: LuaThread<'lua>, args| {
// NOTE: Thread may have been cancelled from Lua
// before we got here, so we need to check it again
if thread.status() == LuaThreadStatus::Resumable {
// Check if we should be tracking this thread
let id = ThreadId::from(&thread);
let id_tracked = result_map.is_tracked(id);
let result_map_inner = if id_tracked {
Some(result_map.clone())
} else {
None
};
// Spawn it on the executor and store the result when done
local_exec
.spawn(async move {
if let Err(e) = run_until_yield(thread, args).await {
self.error_callback.call(&e);
let res = run_until_yield(thread, args).await;
if let Err(e) = res.as_ref() {
self.error_callback.call(e);
}
if id_tracked {
let thread_res = ThreadResult::new(res, self.lua);
result_map_inner.unwrap().insert(id, thread_res);
}
})
.detach();
@ -352,5 +398,8 @@ impl Drop for Runtime<'_> {
self.lua
.remove_app_data::<ThreadErrorCallback>()
.expect(ERR_METADATA_REMOVED);
self.lua
.remove_app_data::<ThreadResultMap>()
.expect(ERR_METADATA_REMOVED);
}
}

View file

@ -8,9 +8,9 @@ use mlua::prelude::*;
use async_executor::{Executor, Task};
use crate::{
handle::Handle,
queue::{DeferredThreadQueue, FuturesQueue, SpawnedThreadQueue},
runtime::Runtime,
thread_id::ThreadId,
};
/**
@ -76,7 +76,7 @@ pub trait LuaRuntimeExt<'lua> {
&'lua self,
thread: impl IntoLuaThread<'lua>,
args: impl IntoLuaMulti<'lua>,
) -> LuaResult<Handle>;
) -> LuaResult<ThreadId>;
/**
Pushes (defers) a lua thread to the **back** of the current runtime.
@ -91,7 +91,7 @@ pub trait LuaRuntimeExt<'lua> {
&'lua self,
thread: impl IntoLuaThread<'lua>,
args: impl IntoLuaMulti<'lua>,
) -> LuaResult<Handle>;
) -> LuaResult<ThreadId>;
/**
Spawns the given future on the current executor and returns its [`Task`].
@ -180,22 +180,22 @@ impl<'lua> LuaRuntimeExt<'lua> for Lua {
&'lua self,
thread: impl IntoLuaThread<'lua>,
args: impl IntoLuaMulti<'lua>,
) -> LuaResult<Handle> {
) -> LuaResult<ThreadId> {
let queue = self
.app_data_ref::<SpawnedThreadQueue>()
.expect("lua threads can only be pushed within a runtime");
queue.push_item_with_handle(self, thread, args)
queue.push_item(self, thread, args)
}
fn push_thread_back(
&'lua self,
thread: impl IntoLuaThread<'lua>,
args: impl IntoLuaMulti<'lua>,
) -> LuaResult<Handle> {
) -> LuaResult<ThreadId> {
let queue = self
.app_data_ref::<DeferredThreadQueue>()
.expect("lua threads can only be pushed within a runtime");
queue.push_item_with_handle(self, thread, args)
queue.push_item(self, thread, args)
}
fn spawn<T: Send + 'static>(&self, fut: impl Future<Output = T> + Send + 'static) -> Task<T> {

View file

@ -21,6 +21,17 @@ pub(crate) async fn run_until_yield<'lua>(
stream.next().await.unwrap()
}
/**
Checks if the given [`LuaValue`] is the async `POLL_PENDING` constant.
*/
#[inline]
pub(crate) fn is_poll_pending(value: &LuaValue) -> bool {
value
.as_light_userdata()
.map(|l| l == Lua::poll_pending())
.unwrap_or_default()
}
/**
Representation of a [`LuaResult`] with an associated [`LuaMultiValue`] currently stored in the Lua registry.
*/