mirror of
https://github.com/lune-org/mlua-luau-scheduler.git
synced 2025-04-03 18:10:55 +01:00
Implement async functions and error storage
This commit is contained in:
parent
f366cc6fee
commit
bfb18064c8
4 changed files with 105 additions and 10 deletions
36
src/error_storage.rs
Normal file
36
src/error_storage.rs
Normal file
|
@ -0,0 +1,36 @@
|
|||
use std::sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc, Mutex,
|
||||
};
|
||||
|
||||
use mlua::prelude::*;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ErrorStorage {
|
||||
is_some: Arc<AtomicBool>,
|
||||
inner: Arc<Mutex<Option<LuaError>>>,
|
||||
}
|
||||
|
||||
impl ErrorStorage {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
is_some: Arc::new(AtomicBool::new(false)),
|
||||
inner: Arc::new(Mutex::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn take(&self) -> Option<LuaError> {
|
||||
if self.is_some.load(Ordering::Relaxed) {
|
||||
self.inner.lock().unwrap().take()
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn replace(&self, e: LuaError) {
|
||||
self.is_some.store(true, Ordering::Relaxed);
|
||||
self.inner.lock().unwrap().replace(e);
|
||||
}
|
||||
}
|
42
src/lua_ext.rs
Normal file
42
src/lua_ext.rs
Normal file
|
@ -0,0 +1,42 @@
|
|||
use std::future::Future;
|
||||
|
||||
use mlua::prelude::*;
|
||||
use tokio::spawn;
|
||||
|
||||
use crate::{Args, Message, MessageSender, ThreadId};
|
||||
|
||||
pub trait LuaSchedulerExt<'lua> {
|
||||
fn create_async_function<A, R, F, FR>(&'lua self, func: F) -> LuaResult<LuaFunction<'lua>>
|
||||
where
|
||||
A: FromLuaMulti<'lua> + 'static,
|
||||
R: Into<Args> + Send + 'static,
|
||||
F: Fn(&'lua Lua, A) -> FR + 'static,
|
||||
FR: Future<Output = LuaResult<R>> + Send + 'static;
|
||||
}
|
||||
|
||||
impl<'lua> LuaSchedulerExt<'lua> for Lua {
|
||||
fn create_async_function<A, R, F, FR>(&'lua self, func: F) -> LuaResult<LuaFunction<'lua>>
|
||||
where
|
||||
A: FromLuaMulti<'lua> + 'static,
|
||||
R: Into<Args> + Send + 'static,
|
||||
F: Fn(&'lua Lua, A) -> FR + 'static,
|
||||
FR: Future<Output = LuaResult<R>> + Send + 'static,
|
||||
{
|
||||
let tx = self.app_data_ref::<MessageSender>().unwrap().clone();
|
||||
|
||||
self.create_function(move |lua, args: A| {
|
||||
let thread_id = ThreadId::from(lua.current_thread());
|
||||
let fut = func(lua, args);
|
||||
let tx = tx.clone();
|
||||
|
||||
spawn(async move {
|
||||
tx.send(match fut.await {
|
||||
Ok(args) => Message::Resume(thread_id, Ok(args.into())),
|
||||
Err(e) => Message::Resume(thread_id, Err(e)),
|
||||
})
|
||||
});
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
}
|
35
src/main.rs
35
src/main.rs
|
@ -10,12 +10,15 @@ use tokio::{
|
|||
};
|
||||
|
||||
mod args;
|
||||
mod error_storage;
|
||||
mod lua;
|
||||
mod lua_ext;
|
||||
mod message;
|
||||
mod stats;
|
||||
mod thread_id;
|
||||
|
||||
use args::*;
|
||||
use error_storage::*;
|
||||
use lua::*;
|
||||
use message::*;
|
||||
use stats::*;
|
||||
|
@ -61,18 +64,25 @@ fn main() {
|
|||
fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> LuaResult<()> {
|
||||
let lua = create_lua(tx.clone())?;
|
||||
|
||||
let error_storage = ErrorStorage::new();
|
||||
let error_storage_interrupt = error_storage.clone();
|
||||
lua.set_interrupt(move |_| match error_storage_interrupt.take() {
|
||||
Some(e) => Err(e),
|
||||
None => Ok(LuaVmState::Continue),
|
||||
});
|
||||
|
||||
lua.globals()
|
||||
.set("wait", lua.load(WAIT_IMPL).into_function()?)?;
|
||||
|
||||
let mut yielded_threads: GxHashMap<ThreadId, LuaThread> = GxHashMap::default();
|
||||
let mut runnable_threads: GxHashMap<ThreadId, (LuaThread, Args)> = GxHashMap::default();
|
||||
let mut yielded_threads = GxHashMap::default();
|
||||
let mut runnable_threads = GxHashMap::default();
|
||||
|
||||
println!("Running {NUM_TEST_BATCHES} batches");
|
||||
for _ in 0..NUM_TEST_BATCHES {
|
||||
let main_fn = lua.load(MAIN_CHUNK).into_function()?;
|
||||
for _ in 0..NUM_TEST_THREADS {
|
||||
let thread = lua.create_thread(main_fn.clone())?;
|
||||
runnable_threads.insert(ThreadId::from(&thread), (thread, Args::new()));
|
||||
runnable_threads.insert(ThreadId::from(&thread), (thread, Ok(Args::new())));
|
||||
}
|
||||
|
||||
loop {
|
||||
|
@ -82,13 +92,20 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu
|
|||
}
|
||||
|
||||
// Resume as many threads as possible
|
||||
for (thread_id, (thread, args)) in runnable_threads.drain() {
|
||||
for (thread_id, (thread, res)) in runnable_threads.drain() {
|
||||
stats.incr(StatsCounter::ThreadResumed);
|
||||
// NOTE: If we got an error we don't need to resume with any args
|
||||
let args = match res {
|
||||
Ok(a) => a,
|
||||
Err(e) => {
|
||||
error_storage.replace(e);
|
||||
Args::from(())
|
||||
}
|
||||
};
|
||||
if let Err(e) = thread.resume::<_, ()>(args) {
|
||||
tx.send(Message::Error(thread_id, Box::new(e)))
|
||||
.expect("failed to send error to async task");
|
||||
}
|
||||
if thread.status() == LuaThreadStatus::Resumable {
|
||||
} else if thread.status() == LuaThreadStatus::Resumable {
|
||||
yielded_threads.insert(thread_id, thread);
|
||||
}
|
||||
}
|
||||
|
@ -100,9 +117,9 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu
|
|||
// Set up message processor - we mutably borrow both yielded_threads and runnable_threads
|
||||
// so we can't really do this outside of the loop, but it compiles down to the same thing
|
||||
let mut process_message = |message| match message {
|
||||
Message::Resume(thread_id, args) => {
|
||||
Message::Resume(thread_id, res) => {
|
||||
if let Some(thread) = yielded_threads.remove(&thread_id) {
|
||||
runnable_threads.insert(thread_id, (thread, args));
|
||||
runnable_threads.insert(thread_id, (thread, res));
|
||||
}
|
||||
}
|
||||
Message::Cancel(thread_id) => {
|
||||
|
@ -160,7 +177,7 @@ async fn main_async_task(
|
|||
spawn(async move {
|
||||
sleep(duration).await;
|
||||
let elapsed = Instant::now() - yielded_at;
|
||||
tx.send(Message::Resume(thread_id, Args::from(elapsed)))
|
||||
tx.send(Message::Resume(thread_id, Ok(Args::from(elapsed))))
|
||||
});
|
||||
}
|
||||
Message::Error(_, e) => {
|
||||
|
|
|
@ -12,7 +12,7 @@ pub type MessageSender = UnboundedSender<Message>;
|
|||
pub type MessageReceiver = UnboundedReceiver<Message>;
|
||||
|
||||
pub enum Message {
|
||||
Resume(ThreadId, Args),
|
||||
Resume(ThreadId, LuaResult<Args>),
|
||||
Cancel(ThreadId),
|
||||
Sleep(ThreadId, Instant, Duration),
|
||||
Error(ThreadId, Box<LuaError>),
|
||||
|
|
Loading…
Add table
Reference in a new issue