Ensure thread safety of scheduler state

This commit is contained in:
Filip Tibell 2023-08-19 17:09:13 -05:00
parent b1847bf84c
commit 4acc730d38

View file

@ -1,30 +1,34 @@
use std::{
cell::RefCell,
collections::HashMap,
sync::atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering},
sync::{
atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering},
Arc, Mutex,
},
};
use mlua::Error as LuaError;
use super::SchedulerThreadId;
/**
Internal state for a [`Scheduler`].
This scheduler state uses atomic operations for everything
except lua error storage, and is completely thread safe.
*/
#[derive(Debug, Default)]
pub struct SchedulerState {
exit_state: AtomicBool,
exit_code: AtomicU8,
num_resumptions: AtomicUsize,
num_errors: AtomicUsize,
// TODO: Use Arc<Mutex<T>> to make these thread and borrow safe
thread_id: RefCell<Option<SchedulerThreadId>>,
thread_errors: RefCell<HashMap<SchedulerThreadId, LuaError>>,
thread_id: Arc<Mutex<Option<SchedulerThreadId>>>,
thread_errors: Arc<Mutex<HashMap<SchedulerThreadId, LuaError>>>,
}
impl SchedulerState {
/**
Creates a new scheduler state.
This scheduler state uses atomic operations for everything
except lua resumption errors, and is completely thread safe.
*/
pub fn new() -> Self {
Self::default()
@ -80,18 +84,33 @@ impl SchedulerState {
Gets the currently running lua scheduler thread id, if any.
*/
pub fn get_current_thread_id(&self) -> Option<SchedulerThreadId> {
*self.thread_id.borrow()
*self
.thread_id
.lock()
.expect("Failed to lock current thread id")
}
/**
Sets the currently running lua scheduler thread id.
This should be set to `Some(id)` just before resuming a lua
thread, and `None` while no lua thread is being resumed.
This must be set to `Some(id)` just before resuming a lua thread,
and `None` while no lua thread is being resumed. If set to `Some`
while the current thread id is also `Some`, this will panic.
Must only be set once per thread id, although this
is not checked at runtime for performance reasons.
*/
pub fn set_current_thread_id(&self, id: Option<SchedulerThreadId>) {
self.thread_id.replace(id);
self.num_resumptions.fetch_add(1, Ordering::Relaxed);
let mut thread_id = self
.thread_id
.lock()
.expect("Failed to lock current thread id");
assert!(
id.is_none() || thread_id.is_none(),
"Current thread id can not be overwritten"
);
*thread_id = id;
}
/**
@ -100,7 +119,11 @@ impl SchedulerState {
Note that this removes the error from the scheduler state completely.
*/
pub fn get_thread_error(&self, id: SchedulerThreadId) -> Option<LuaError> {
self.thread_errors.borrow_mut().remove(&id)
let mut thread_errors = self
.thread_errors
.lock()
.expect("Failed to lock thread errors");
thread_errors.remove(&id)
}
/**
@ -109,6 +132,10 @@ impl SchedulerState {
Note that this will replace any already existing [`LuaError`].
*/
pub fn set_thread_error(&self, id: SchedulerThreadId, err: LuaError) {
self.thread_errors.borrow_mut().insert(id, err);
let mut thread_errors = self
.thread_errors
.lock()
.expect("Failed to lock thread errors");
thread_errors.insert(id, err);
}
}