Implement async functions and error storage

This commit is contained in:
Filip Tibell 2024-01-17 12:45:42 +01:00
parent f366cc6fee
commit bfb18064c8
No known key found for this signature in database
4 changed files with 105 additions and 10 deletions

36
src/error_storage.rs Normal file
View file

@ -0,0 +1,36 @@
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc, Mutex,
};
use mlua::prelude::*;
#[derive(Debug, Clone)]
pub struct ErrorStorage {
is_some: Arc<AtomicBool>,
inner: Arc<Mutex<Option<LuaError>>>,
}
impl ErrorStorage {
pub fn new() -> Self {
Self {
is_some: Arc::new(AtomicBool::new(false)),
inner: Arc::new(Mutex::new(None)),
}
}
#[inline]
pub fn take(&self) -> Option<LuaError> {
if self.is_some.load(Ordering::Relaxed) {
self.inner.lock().unwrap().take()
} else {
None
}
}
#[inline]
pub fn replace(&self, e: LuaError) {
self.is_some.store(true, Ordering::Relaxed);
self.inner.lock().unwrap().replace(e);
}
}

42
src/lua_ext.rs Normal file
View file

@ -0,0 +1,42 @@
use std::future::Future;
use mlua::prelude::*;
use tokio::spawn;
use crate::{Args, Message, MessageSender, ThreadId};
pub trait LuaSchedulerExt<'lua> {
fn create_async_function<A, R, F, FR>(&'lua self, func: F) -> LuaResult<LuaFunction<'lua>>
where
A: FromLuaMulti<'lua> + 'static,
R: Into<Args> + Send + 'static,
F: Fn(&'lua Lua, A) -> FR + 'static,
FR: Future<Output = LuaResult<R>> + Send + 'static;
}
impl<'lua> LuaSchedulerExt<'lua> for Lua {
fn create_async_function<A, R, F, FR>(&'lua self, func: F) -> LuaResult<LuaFunction<'lua>>
where
A: FromLuaMulti<'lua> + 'static,
R: Into<Args> + Send + 'static,
F: Fn(&'lua Lua, A) -> FR + 'static,
FR: Future<Output = LuaResult<R>> + Send + 'static,
{
let tx = self.app_data_ref::<MessageSender>().unwrap().clone();
self.create_function(move |lua, args: A| {
let thread_id = ThreadId::from(lua.current_thread());
let fut = func(lua, args);
let tx = tx.clone();
spawn(async move {
tx.send(match fut.await {
Ok(args) => Message::Resume(thread_id, Ok(args.into())),
Err(e) => Message::Resume(thread_id, Err(e)),
})
});
Ok(())
})
}
}

View file

@ -10,12 +10,15 @@ use tokio::{
};
mod args;
mod error_storage;
mod lua;
mod lua_ext;
mod message;
mod stats;
mod thread_id;
use args::*;
use error_storage::*;
use lua::*;
use message::*;
use stats::*;
@ -61,18 +64,25 @@ fn main() {
fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> LuaResult<()> {
let lua = create_lua(tx.clone())?;
let error_storage = ErrorStorage::new();
let error_storage_interrupt = error_storage.clone();
lua.set_interrupt(move |_| match error_storage_interrupt.take() {
Some(e) => Err(e),
None => Ok(LuaVmState::Continue),
});
lua.globals()
.set("wait", lua.load(WAIT_IMPL).into_function()?)?;
let mut yielded_threads: GxHashMap<ThreadId, LuaThread> = GxHashMap::default();
let mut runnable_threads: GxHashMap<ThreadId, (LuaThread, Args)> = GxHashMap::default();
let mut yielded_threads = GxHashMap::default();
let mut runnable_threads = GxHashMap::default();
println!("Running {NUM_TEST_BATCHES} batches");
for _ in 0..NUM_TEST_BATCHES {
let main_fn = lua.load(MAIN_CHUNK).into_function()?;
for _ in 0..NUM_TEST_THREADS {
let thread = lua.create_thread(main_fn.clone())?;
runnable_threads.insert(ThreadId::from(&thread), (thread, Args::new()));
runnable_threads.insert(ThreadId::from(&thread), (thread, Ok(Args::new())));
}
loop {
@ -82,13 +92,20 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu
}
// Resume as many threads as possible
for (thread_id, (thread, args)) in runnable_threads.drain() {
for (thread_id, (thread, res)) in runnable_threads.drain() {
stats.incr(StatsCounter::ThreadResumed);
// NOTE: If we got an error we don't need to resume with any args
let args = match res {
Ok(a) => a,
Err(e) => {
error_storage.replace(e);
Args::from(())
}
};
if let Err(e) = thread.resume::<_, ()>(args) {
tx.send(Message::Error(thread_id, Box::new(e)))
.expect("failed to send error to async task");
}
if thread.status() == LuaThreadStatus::Resumable {
} else if thread.status() == LuaThreadStatus::Resumable {
yielded_threads.insert(thread_id, thread);
}
}
@ -100,9 +117,9 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu
// Set up message processor - we mutably borrow both yielded_threads and runnable_threads
// so we can't really do this outside of the loop, but it compiles down to the same thing
let mut process_message = |message| match message {
Message::Resume(thread_id, args) => {
Message::Resume(thread_id, res) => {
if let Some(thread) = yielded_threads.remove(&thread_id) {
runnable_threads.insert(thread_id, (thread, args));
runnable_threads.insert(thread_id, (thread, res));
}
}
Message::Cancel(thread_id) => {
@ -160,7 +177,7 @@ async fn main_async_task(
spawn(async move {
sleep(duration).await;
let elapsed = Instant::now() - yielded_at;
tx.send(Message::Resume(thread_id, Args::from(elapsed)))
tx.send(Message::Resume(thread_id, Ok(Args::from(elapsed))))
});
}
Message::Error(_, e) => {

View file

@ -12,7 +12,7 @@ pub type MessageSender = UnboundedSender<Message>;
pub type MessageReceiver = UnboundedReceiver<Message>;
pub enum Message {
Resume(ThreadId, Args),
Resume(ThreadId, LuaResult<Args>),
Cancel(ThreadId),
Sleep(ThreadId, Instant, Duration),
Error(ThreadId, Box<LuaError>),