From f366cc6feef37222a248530a0308a09125869f3d Mon Sep 17 00:00:00 2001 From: Filip Tibell Date: Wed, 17 Jan 2024 11:59:02 +0100 Subject: [PATCH] Split stuff into proper modules --- src/args.rs | 2 + src/lua.rs | 58 ++++++++++++++++++++++++ src/main.rs | 120 ++++++------------------------------------------- src/message.rs | 21 +++++++++ src/stats.rs | 40 +++++++++++++++++ 5 files changed, 135 insertions(+), 106 deletions(-) create mode 100644 src/lua.rs create mode 100644 src/message.rs create mode 100644 src/stats.rs diff --git a/src/args.rs b/src/args.rs index a6a86e8..860c4ce 100644 --- a/src/args.rs +++ b/src/args.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + use std::time::Duration; use mlua::prelude::*; diff --git a/src/lua.rs b/src/lua.rs new file mode 100644 index 0000000..2bda0ad --- /dev/null +++ b/src/lua.rs @@ -0,0 +1,58 @@ +use std::time::Duration; + +use mlua::prelude::*; +use tokio::time::Instant; + +use crate::{Message, MessageSender, ThreadId}; + +pub fn create_lua(tx: MessageSender) -> LuaResult { + let lua = Lua::new(); + lua.enable_jit(true); + lua.set_app_data(tx.clone()); + + // Resumption + let tx_resume = tx.clone(); + lua.globals().set( + "__scheduler__resumeAfter", + LuaFunction::wrap(move |lua, duration: f64| { + let thread_id = ThreadId::from(lua.current_thread()); + let yielded_at = Instant::now(); + let duration = Duration::from_secs_f64(duration); + tx_resume + .send(Message::Sleep(thread_id, yielded_at, duration)) + .into_lua_err() + }), + )?; + + // Cancellation + let tx_cancel = tx.clone(); + lua.globals().set( + "__scheduler__cancel", + LuaFunction::wrap(move |_, thread: LuaThread| { + let thread_id = ThreadId::from(thread); + tx_cancel.send(Message::Cancel(thread_id)).into_lua_err() + }), + )?; + + // Stdout + let tx_stdout = tx.clone(); + lua.globals().set( + "__scheduler__writeStdout", + LuaFunction::wrap(move |_, s: LuaString| { + let bytes = s.as_bytes().to_vec(); + tx_stdout.send(Message::WriteStdout(bytes)).into_lua_err() + }), + )?; + + // Stderr + let tx_stderr = tx.clone(); + lua.globals().set( + "__scheduler__writeStderr", + LuaFunction::wrap(move |_, s: LuaString| { + let bytes = s.as_bytes().to_vec(); + tx_stderr.send(Message::WriteStderr(bytes)).into_lua_err() + }), + )?; + + Ok(lua) +} diff --git a/src/main.rs b/src/main.rs index bd36f91..6704734 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,22 +1,25 @@ -use std::{sync::Arc, time::Duration}; - -use dashmap::DashMap; use gxhash::GxHashMap; use mlua::prelude::*; use tokio::{ io::{self, AsyncWriteExt}, runtime::Runtime as TokioRuntime, select, spawn, - sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender}, + sync::mpsc::{unbounded_channel, UnboundedReceiver}, task::{spawn_blocking, LocalSet}, time::{sleep, Instant}, }; mod args; +mod lua; +mod message; +mod stats; mod thread_id; -use args::Args; -use thread_id::ThreadId; +use args::*; +use lua::*; +use message::*; +use stats::*; +use thread_id::*; const NUM_TEST_BATCHES: usize = 20; const NUM_TEST_THREADS: usize = 50_000; @@ -30,54 +33,6 @@ __scheduler__resumeAfter(...) return coroutine.yield() "#; -type MessageSender = UnboundedSender; -type MessageReceiver = UnboundedReceiver; - -enum Message { - Resume(ThreadId, Args), - Cancel(ThreadId), - Sleep(ThreadId, Instant, Duration), - Error(ThreadId, Box), - WriteStdout(Vec), - WriteStderr(Vec), -} - -#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] -enum StatsCounter { - ThreadResumed, - ThreadCancelled, - ThreadSlept, - ThreadErrored, - WriteStdout, - WriteStderr, -} - -#[derive(Debug, Clone)] -struct Stats { - start: Instant, - counters: Arc>, -} - -impl Stats { - fn new() -> Self { - Self { - start: Instant::now(), - counters: Arc::new(DashMap::new()), - } - } - - fn incr(&self, counter: StatsCounter) { - self.counters - .entry(counter) - .and_modify(|c| *c += 1) - .or_insert(1); - } - - fn elapsed(&self) -> Duration { - Instant::now() - self.start - } -} - fn main() { let rt = TokioRuntime::new().unwrap(); let set = LocalSet::new(); @@ -104,58 +59,10 @@ fn main() { } fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> LuaResult<()> { - let lua = Lua::new(); - let g = lua.globals(); + let lua = create_lua(tx.clone())?; - lua.enable_jit(true); - lua.set_app_data(tx.clone()); - - let send_message = |lua: &Lua, msg: Message| { - lua.app_data_ref::() - .unwrap() - .send(msg) - .unwrap(); - }; - - g.set( - "__scheduler__resumeAfter", - LuaFunction::wrap(move |lua, duration: f64| { - let thread_id = ThreadId::from(lua.current_thread()); - let yielded_at = Instant::now(); - let duration = Duration::from_secs_f64(duration); - send_message(lua, Message::Sleep(thread_id, yielded_at, duration)); - Ok(()) - }), - )?; - - g.set( - "__scheduler__cancel", - LuaFunction::wrap(move |lua, thread: LuaThread| { - let thread_id = ThreadId::from(thread); - send_message(lua, Message::Cancel(thread_id)); - Ok(()) - }), - )?; - - g.set( - "__scheduler__writeStdout", - LuaFunction::wrap(move |lua, s: LuaString| { - let bytes = s.as_bytes().to_vec(); - send_message(lua, Message::WriteStdout(bytes)); - Ok(()) - }), - )?; - - g.set( - "__scheduler__writeStderr", - LuaFunction::wrap(move |lua, s: LuaString| { - let bytes = s.as_bytes().to_vec(); - send_message(lua, Message::WriteStderr(bytes)); - Ok(()) - }), - )?; - - g.set("wait", lua.load(WAIT_IMPL).into_function()?)?; + lua.globals() + .set("wait", lua.load(WAIT_IMPL).into_function()?)?; let mut yielded_threads: GxHashMap = GxHashMap::default(); let mut runnable_threads: GxHashMap = GxHashMap::default(); @@ -178,7 +85,8 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu for (thread_id, (thread, args)) in runnable_threads.drain() { stats.incr(StatsCounter::ThreadResumed); if let Err(e) = thread.resume::<_, ()>(args) { - send_message(&lua, Message::Error(thread_id, Box::new(e))); + tx.send(Message::Error(thread_id, Box::new(e))) + .expect("failed to send error to async task"); } if thread.status() == LuaThreadStatus::Resumable { yielded_threads.insert(thread_id, thread); diff --git a/src/message.rs b/src/message.rs new file mode 100644 index 0000000..7ca400a --- /dev/null +++ b/src/message.rs @@ -0,0 +1,21 @@ +use std::time::Duration; + +use mlua::prelude::*; +use tokio::{ + sync::mpsc::{UnboundedReceiver, UnboundedSender}, + time::Instant, +}; + +use crate::{Args, ThreadId}; + +pub type MessageSender = UnboundedSender; +pub type MessageReceiver = UnboundedReceiver; + +pub enum Message { + Resume(ThreadId, Args), + Cancel(ThreadId), + Sleep(ThreadId, Instant, Duration), + Error(ThreadId, Box), + WriteStdout(Vec), + WriteStderr(Vec), +} diff --git a/src/stats.rs b/src/stats.rs new file mode 100644 index 0000000..fb1d97f --- /dev/null +++ b/src/stats.rs @@ -0,0 +1,40 @@ +use std::{sync::Arc, time::Duration}; + +use dashmap::DashMap; +use tokio::time::Instant; + +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] +pub enum StatsCounter { + ThreadResumed, + ThreadCancelled, + ThreadSlept, + ThreadErrored, + WriteStdout, + WriteStderr, +} + +#[derive(Debug, Clone)] +pub struct Stats { + start: Instant, + pub counters: Arc>, +} + +impl Stats { + pub fn new() -> Self { + Self { + start: Instant::now(), + counters: Arc::new(DashMap::new()), + } + } + + pub fn incr(&self, counter: StatsCounter) { + self.counters + .entry(counter) + .and_modify(|c| *c += 1) + .or_insert(1); + } + + pub fn elapsed(&self) -> Duration { + Instant::now() - self.start + } +}