diff --git a/src/lune/scheduler/state.rs b/src/lune/scheduler/state.rs index 6050810..f2df7ee 100644 --- a/src/lune/scheduler/state.rs +++ b/src/lune/scheduler/state.rs @@ -1,30 +1,34 @@ use std::{ - cell::RefCell, collections::HashMap, - sync::atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering}, + sync::{ + atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering}, + Arc, Mutex, + }, }; use mlua::Error as LuaError; use super::SchedulerThreadId; +/** + Internal state for a [`Scheduler`]. + + This scheduler state uses atomic operations for everything + except lua error storage, and is completely thread safe. +*/ #[derive(Debug, Default)] pub struct SchedulerState { exit_state: AtomicBool, exit_code: AtomicU8, num_resumptions: AtomicUsize, num_errors: AtomicUsize, - // TODO: Use Arc> to make these thread and borrow safe - thread_id: RefCell>, - thread_errors: RefCell>, + thread_id: Arc>>, + thread_errors: Arc>>, } impl SchedulerState { /** Creates a new scheduler state. - - This scheduler state uses atomic operations for everything - except lua resumption errors, and is completely thread safe. */ pub fn new() -> Self { Self::default() @@ -80,18 +84,33 @@ impl SchedulerState { Gets the currently running lua scheduler thread id, if any. */ pub fn get_current_thread_id(&self) -> Option { - *self.thread_id.borrow() + *self + .thread_id + .lock() + .expect("Failed to lock current thread id") } /** Sets the currently running lua scheduler thread id. - This should be set to `Some(id)` just before resuming a lua - thread, and `None` while no lua thread is being resumed. + This must be set to `Some(id)` just before resuming a lua thread, + and `None` while no lua thread is being resumed. If set to `Some` + while the current thread id is also `Some`, this will panic. + + Must only be set once per thread id, although this + is not checked at runtime for performance reasons. */ pub fn set_current_thread_id(&self, id: Option) { - self.thread_id.replace(id); self.num_resumptions.fetch_add(1, Ordering::Relaxed); + let mut thread_id = self + .thread_id + .lock() + .expect("Failed to lock current thread id"); + assert!( + id.is_none() || thread_id.is_none(), + "Current thread id can not be overwritten" + ); + *thread_id = id; } /** @@ -100,7 +119,11 @@ impl SchedulerState { Note that this removes the error from the scheduler state completely. */ pub fn get_thread_error(&self, id: SchedulerThreadId) -> Option { - self.thread_errors.borrow_mut().remove(&id) + let mut thread_errors = self + .thread_errors + .lock() + .expect("Failed to lock thread errors"); + thread_errors.remove(&id) } /** @@ -109,6 +132,10 @@ impl SchedulerState { Note that this will replace any already existing [`LuaError`]. */ pub fn set_thread_error(&self, id: SchedulerThreadId, err: LuaError) { - self.thread_errors.borrow_mut().insert(id, err); + let mut thread_errors = self + .thread_errors + .lock() + .expect("Failed to lock thread errors"); + thread_errors.insert(id, err); } }