Split stuff into proper modules

This commit is contained in:
Filip Tibell 2024-01-17 11:59:02 +01:00
parent 6143267ea5
commit f366cc6fee
No known key found for this signature in database
5 changed files with 135 additions and 106 deletions

View file

@ -1,3 +1,5 @@
#![allow(dead_code)]
use std::time::Duration;
use mlua::prelude::*;

58
src/lua.rs Normal file
View file

@ -0,0 +1,58 @@
use std::time::Duration;
use mlua::prelude::*;
use tokio::time::Instant;
use crate::{Message, MessageSender, ThreadId};
pub fn create_lua(tx: MessageSender) -> LuaResult<Lua> {
let lua = Lua::new();
lua.enable_jit(true);
lua.set_app_data(tx.clone());
// Resumption
let tx_resume = tx.clone();
lua.globals().set(
"__scheduler__resumeAfter",
LuaFunction::wrap(move |lua, duration: f64| {
let thread_id = ThreadId::from(lua.current_thread());
let yielded_at = Instant::now();
let duration = Duration::from_secs_f64(duration);
tx_resume
.send(Message::Sleep(thread_id, yielded_at, duration))
.into_lua_err()
}),
)?;
// Cancellation
let tx_cancel = tx.clone();
lua.globals().set(
"__scheduler__cancel",
LuaFunction::wrap(move |_, thread: LuaThread| {
let thread_id = ThreadId::from(thread);
tx_cancel.send(Message::Cancel(thread_id)).into_lua_err()
}),
)?;
// Stdout
let tx_stdout = tx.clone();
lua.globals().set(
"__scheduler__writeStdout",
LuaFunction::wrap(move |_, s: LuaString| {
let bytes = s.as_bytes().to_vec();
tx_stdout.send(Message::WriteStdout(bytes)).into_lua_err()
}),
)?;
// Stderr
let tx_stderr = tx.clone();
lua.globals().set(
"__scheduler__writeStderr",
LuaFunction::wrap(move |_, s: LuaString| {
let bytes = s.as_bytes().to_vec();
tx_stderr.send(Message::WriteStderr(bytes)).into_lua_err()
}),
)?;
Ok(lua)
}

View file

@ -1,22 +1,25 @@
use std::{sync::Arc, time::Duration};
use dashmap::DashMap;
use gxhash::GxHashMap;
use mlua::prelude::*;
use tokio::{
io::{self, AsyncWriteExt},
runtime::Runtime as TokioRuntime,
select, spawn,
sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender},
sync::mpsc::{unbounded_channel, UnboundedReceiver},
task::{spawn_blocking, LocalSet},
time::{sleep, Instant},
};
mod args;
mod lua;
mod message;
mod stats;
mod thread_id;
use args::Args;
use thread_id::ThreadId;
use args::*;
use lua::*;
use message::*;
use stats::*;
use thread_id::*;
const NUM_TEST_BATCHES: usize = 20;
const NUM_TEST_THREADS: usize = 50_000;
@ -30,54 +33,6 @@ __scheduler__resumeAfter(...)
return coroutine.yield()
"#;
type MessageSender = UnboundedSender<Message>;
type MessageReceiver = UnboundedReceiver<Message>;
enum Message {
Resume(ThreadId, Args),
Cancel(ThreadId),
Sleep(ThreadId, Instant, Duration),
Error(ThreadId, Box<LuaError>),
WriteStdout(Vec<u8>),
WriteStderr(Vec<u8>),
}
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
enum StatsCounter {
ThreadResumed,
ThreadCancelled,
ThreadSlept,
ThreadErrored,
WriteStdout,
WriteStderr,
}
#[derive(Debug, Clone)]
struct Stats {
start: Instant,
counters: Arc<DashMap<StatsCounter, usize>>,
}
impl Stats {
fn new() -> Self {
Self {
start: Instant::now(),
counters: Arc::new(DashMap::new()),
}
}
fn incr(&self, counter: StatsCounter) {
self.counters
.entry(counter)
.and_modify(|c| *c += 1)
.or_insert(1);
}
fn elapsed(&self) -> Duration {
Instant::now() - self.start
}
}
fn main() {
let rt = TokioRuntime::new().unwrap();
let set = LocalSet::new();
@ -104,58 +59,10 @@ fn main() {
}
fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> LuaResult<()> {
let lua = Lua::new();
let g = lua.globals();
let lua = create_lua(tx.clone())?;
lua.enable_jit(true);
lua.set_app_data(tx.clone());
let send_message = |lua: &Lua, msg: Message| {
lua.app_data_ref::<MessageSender>()
.unwrap()
.send(msg)
.unwrap();
};
g.set(
"__scheduler__resumeAfter",
LuaFunction::wrap(move |lua, duration: f64| {
let thread_id = ThreadId::from(lua.current_thread());
let yielded_at = Instant::now();
let duration = Duration::from_secs_f64(duration);
send_message(lua, Message::Sleep(thread_id, yielded_at, duration));
Ok(())
}),
)?;
g.set(
"__scheduler__cancel",
LuaFunction::wrap(move |lua, thread: LuaThread| {
let thread_id = ThreadId::from(thread);
send_message(lua, Message::Cancel(thread_id));
Ok(())
}),
)?;
g.set(
"__scheduler__writeStdout",
LuaFunction::wrap(move |lua, s: LuaString| {
let bytes = s.as_bytes().to_vec();
send_message(lua, Message::WriteStdout(bytes));
Ok(())
}),
)?;
g.set(
"__scheduler__writeStderr",
LuaFunction::wrap(move |lua, s: LuaString| {
let bytes = s.as_bytes().to_vec();
send_message(lua, Message::WriteStderr(bytes));
Ok(())
}),
)?;
g.set("wait", lua.load(WAIT_IMPL).into_function()?)?;
lua.globals()
.set("wait", lua.load(WAIT_IMPL).into_function()?)?;
let mut yielded_threads: GxHashMap<ThreadId, LuaThread> = GxHashMap::default();
let mut runnable_threads: GxHashMap<ThreadId, (LuaThread, Args)> = GxHashMap::default();
@ -178,7 +85,8 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu
for (thread_id, (thread, args)) in runnable_threads.drain() {
stats.incr(StatsCounter::ThreadResumed);
if let Err(e) = thread.resume::<_, ()>(args) {
send_message(&lua, Message::Error(thread_id, Box::new(e)));
tx.send(Message::Error(thread_id, Box::new(e)))
.expect("failed to send error to async task");
}
if thread.status() == LuaThreadStatus::Resumable {
yielded_threads.insert(thread_id, thread);

21
src/message.rs Normal file
View file

@ -0,0 +1,21 @@
use std::time::Duration;
use mlua::prelude::*;
use tokio::{
sync::mpsc::{UnboundedReceiver, UnboundedSender},
time::Instant,
};
use crate::{Args, ThreadId};
pub type MessageSender = UnboundedSender<Message>;
pub type MessageReceiver = UnboundedReceiver<Message>;
pub enum Message {
Resume(ThreadId, Args),
Cancel(ThreadId),
Sleep(ThreadId, Instant, Duration),
Error(ThreadId, Box<LuaError>),
WriteStdout(Vec<u8>),
WriteStderr(Vec<u8>),
}

40
src/stats.rs Normal file
View file

@ -0,0 +1,40 @@
use std::{sync::Arc, time::Duration};
use dashmap::DashMap;
use tokio::time::Instant;
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum StatsCounter {
ThreadResumed,
ThreadCancelled,
ThreadSlept,
ThreadErrored,
WriteStdout,
WriteStderr,
}
#[derive(Debug, Clone)]
pub struct Stats {
start: Instant,
pub counters: Arc<DashMap<StatsCounter, usize>>,
}
impl Stats {
pub fn new() -> Self {
Self {
start: Instant::now(),
counters: Arc::new(DashMap::new()),
}
}
pub fn incr(&self, counter: StatsCounter) {
self.counters
.entry(counter)
.and_modify(|c| *c += 1)
.or_insert(1);
}
pub fn elapsed(&self) -> Duration {
Instant::now() - self.start
}
}