mirror of
https://github.com/lune-org/mlua-luau-scheduler.git
synced 2025-04-07 12:00:58 +01:00
Implement async functions and error storage
This commit is contained in:
parent
f366cc6fee
commit
bfb18064c8
4 changed files with 105 additions and 10 deletions
36
src/error_storage.rs
Normal file
36
src/error_storage.rs
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
use std::sync::{
|
||||||
|
atomic::{AtomicBool, Ordering},
|
||||||
|
Arc, Mutex,
|
||||||
|
};
|
||||||
|
|
||||||
|
use mlua::prelude::*;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ErrorStorage {
|
||||||
|
is_some: Arc<AtomicBool>,
|
||||||
|
inner: Arc<Mutex<Option<LuaError>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ErrorStorage {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
is_some: Arc::new(AtomicBool::new(false)),
|
||||||
|
inner: Arc::new(Mutex::new(None)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn take(&self) -> Option<LuaError> {
|
||||||
|
if self.is_some.load(Ordering::Relaxed) {
|
||||||
|
self.inner.lock().unwrap().take()
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[inline]
|
||||||
|
pub fn replace(&self, e: LuaError) {
|
||||||
|
self.is_some.store(true, Ordering::Relaxed);
|
||||||
|
self.inner.lock().unwrap().replace(e);
|
||||||
|
}
|
||||||
|
}
|
42
src/lua_ext.rs
Normal file
42
src/lua_ext.rs
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
use std::future::Future;
|
||||||
|
|
||||||
|
use mlua::prelude::*;
|
||||||
|
use tokio::spawn;
|
||||||
|
|
||||||
|
use crate::{Args, Message, MessageSender, ThreadId};
|
||||||
|
|
||||||
|
pub trait LuaSchedulerExt<'lua> {
|
||||||
|
fn create_async_function<A, R, F, FR>(&'lua self, func: F) -> LuaResult<LuaFunction<'lua>>
|
||||||
|
where
|
||||||
|
A: FromLuaMulti<'lua> + 'static,
|
||||||
|
R: Into<Args> + Send + 'static,
|
||||||
|
F: Fn(&'lua Lua, A) -> FR + 'static,
|
||||||
|
FR: Future<Output = LuaResult<R>> + Send + 'static;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'lua> LuaSchedulerExt<'lua> for Lua {
|
||||||
|
fn create_async_function<A, R, F, FR>(&'lua self, func: F) -> LuaResult<LuaFunction<'lua>>
|
||||||
|
where
|
||||||
|
A: FromLuaMulti<'lua> + 'static,
|
||||||
|
R: Into<Args> + Send + 'static,
|
||||||
|
F: Fn(&'lua Lua, A) -> FR + 'static,
|
||||||
|
FR: Future<Output = LuaResult<R>> + Send + 'static,
|
||||||
|
{
|
||||||
|
let tx = self.app_data_ref::<MessageSender>().unwrap().clone();
|
||||||
|
|
||||||
|
self.create_function(move |lua, args: A| {
|
||||||
|
let thread_id = ThreadId::from(lua.current_thread());
|
||||||
|
let fut = func(lua, args);
|
||||||
|
let tx = tx.clone();
|
||||||
|
|
||||||
|
spawn(async move {
|
||||||
|
tx.send(match fut.await {
|
||||||
|
Ok(args) => Message::Resume(thread_id, Ok(args.into())),
|
||||||
|
Err(e) => Message::Resume(thread_id, Err(e)),
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
35
src/main.rs
35
src/main.rs
|
@ -10,12 +10,15 @@ use tokio::{
|
||||||
};
|
};
|
||||||
|
|
||||||
mod args;
|
mod args;
|
||||||
|
mod error_storage;
|
||||||
mod lua;
|
mod lua;
|
||||||
|
mod lua_ext;
|
||||||
mod message;
|
mod message;
|
||||||
mod stats;
|
mod stats;
|
||||||
mod thread_id;
|
mod thread_id;
|
||||||
|
|
||||||
use args::*;
|
use args::*;
|
||||||
|
use error_storage::*;
|
||||||
use lua::*;
|
use lua::*;
|
||||||
use message::*;
|
use message::*;
|
||||||
use stats::*;
|
use stats::*;
|
||||||
|
@ -61,18 +64,25 @@ fn main() {
|
||||||
fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> LuaResult<()> {
|
fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> LuaResult<()> {
|
||||||
let lua = create_lua(tx.clone())?;
|
let lua = create_lua(tx.clone())?;
|
||||||
|
|
||||||
|
let error_storage = ErrorStorage::new();
|
||||||
|
let error_storage_interrupt = error_storage.clone();
|
||||||
|
lua.set_interrupt(move |_| match error_storage_interrupt.take() {
|
||||||
|
Some(e) => Err(e),
|
||||||
|
None => Ok(LuaVmState::Continue),
|
||||||
|
});
|
||||||
|
|
||||||
lua.globals()
|
lua.globals()
|
||||||
.set("wait", lua.load(WAIT_IMPL).into_function()?)?;
|
.set("wait", lua.load(WAIT_IMPL).into_function()?)?;
|
||||||
|
|
||||||
let mut yielded_threads: GxHashMap<ThreadId, LuaThread> = GxHashMap::default();
|
let mut yielded_threads = GxHashMap::default();
|
||||||
let mut runnable_threads: GxHashMap<ThreadId, (LuaThread, Args)> = GxHashMap::default();
|
let mut runnable_threads = GxHashMap::default();
|
||||||
|
|
||||||
println!("Running {NUM_TEST_BATCHES} batches");
|
println!("Running {NUM_TEST_BATCHES} batches");
|
||||||
for _ in 0..NUM_TEST_BATCHES {
|
for _ in 0..NUM_TEST_BATCHES {
|
||||||
let main_fn = lua.load(MAIN_CHUNK).into_function()?;
|
let main_fn = lua.load(MAIN_CHUNK).into_function()?;
|
||||||
for _ in 0..NUM_TEST_THREADS {
|
for _ in 0..NUM_TEST_THREADS {
|
||||||
let thread = lua.create_thread(main_fn.clone())?;
|
let thread = lua.create_thread(main_fn.clone())?;
|
||||||
runnable_threads.insert(ThreadId::from(&thread), (thread, Args::new()));
|
runnable_threads.insert(ThreadId::from(&thread), (thread, Ok(Args::new())));
|
||||||
}
|
}
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
|
@ -82,13 +92,20 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resume as many threads as possible
|
// Resume as many threads as possible
|
||||||
for (thread_id, (thread, args)) in runnable_threads.drain() {
|
for (thread_id, (thread, res)) in runnable_threads.drain() {
|
||||||
stats.incr(StatsCounter::ThreadResumed);
|
stats.incr(StatsCounter::ThreadResumed);
|
||||||
|
// NOTE: If we got an error we don't need to resume with any args
|
||||||
|
let args = match res {
|
||||||
|
Ok(a) => a,
|
||||||
|
Err(e) => {
|
||||||
|
error_storage.replace(e);
|
||||||
|
Args::from(())
|
||||||
|
}
|
||||||
|
};
|
||||||
if let Err(e) = thread.resume::<_, ()>(args) {
|
if let Err(e) = thread.resume::<_, ()>(args) {
|
||||||
tx.send(Message::Error(thread_id, Box::new(e)))
|
tx.send(Message::Error(thread_id, Box::new(e)))
|
||||||
.expect("failed to send error to async task");
|
.expect("failed to send error to async task");
|
||||||
}
|
} else if thread.status() == LuaThreadStatus::Resumable {
|
||||||
if thread.status() == LuaThreadStatus::Resumable {
|
|
||||||
yielded_threads.insert(thread_id, thread);
|
yielded_threads.insert(thread_id, thread);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -100,9 +117,9 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu
|
||||||
// Set up message processor - we mutably borrow both yielded_threads and runnable_threads
|
// Set up message processor - we mutably borrow both yielded_threads and runnable_threads
|
||||||
// so we can't really do this outside of the loop, but it compiles down to the same thing
|
// so we can't really do this outside of the loop, but it compiles down to the same thing
|
||||||
let mut process_message = |message| match message {
|
let mut process_message = |message| match message {
|
||||||
Message::Resume(thread_id, args) => {
|
Message::Resume(thread_id, res) => {
|
||||||
if let Some(thread) = yielded_threads.remove(&thread_id) {
|
if let Some(thread) = yielded_threads.remove(&thread_id) {
|
||||||
runnable_threads.insert(thread_id, (thread, args));
|
runnable_threads.insert(thread_id, (thread, res));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Message::Cancel(thread_id) => {
|
Message::Cancel(thread_id) => {
|
||||||
|
@ -160,7 +177,7 @@ async fn main_async_task(
|
||||||
spawn(async move {
|
spawn(async move {
|
||||||
sleep(duration).await;
|
sleep(duration).await;
|
||||||
let elapsed = Instant::now() - yielded_at;
|
let elapsed = Instant::now() - yielded_at;
|
||||||
tx.send(Message::Resume(thread_id, Args::from(elapsed)))
|
tx.send(Message::Resume(thread_id, Ok(Args::from(elapsed))))
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
Message::Error(_, e) => {
|
Message::Error(_, e) => {
|
||||||
|
|
|
@ -12,7 +12,7 @@ pub type MessageSender = UnboundedSender<Message>;
|
||||||
pub type MessageReceiver = UnboundedReceiver<Message>;
|
pub type MessageReceiver = UnboundedReceiver<Message>;
|
||||||
|
|
||||||
pub enum Message {
|
pub enum Message {
|
||||||
Resume(ThreadId, Args),
|
Resume(ThreadId, LuaResult<Args>),
|
||||||
Cancel(ThreadId),
|
Cancel(ThreadId),
|
||||||
Sleep(ThreadId, Instant, Duration),
|
Sleep(ThreadId, Instant, Duration),
|
||||||
Error(ThreadId, Box<LuaError>),
|
Error(ThreadId, Box<LuaError>),
|
||||||
|
|
Loading…
Add table
Reference in a new issue