Implement args passing, more speed

This commit is contained in:
Filip Tibell 2024-01-17 11:02:21 +01:00
parent ec135b8a39
commit 084c8a1a21
No known key found for this signature in database
2 changed files with 216 additions and 23 deletions

184
src/args.rs Normal file
View file

@ -0,0 +1,184 @@
use std::time::Duration;
use mlua::prelude::*;
#[derive(Debug, Default)]
pub enum Arg {
#[default]
Nil,
Bool(bool),
Number(f64),
String(String),
}
impl IntoLua<'_> for Arg {
#[inline]
fn into_lua(self, lua: &Lua) -> LuaResult<LuaValue> {
match self {
Arg::Nil => Ok(LuaValue::Nil),
Arg::Bool(b) => Ok(LuaValue::Boolean(b)),
Arg::Number(n) => Ok(LuaValue::Number(n)),
Arg::String(s) => Ok(LuaValue::String(lua.create_string(&s)?)),
}
}
}
// Primitives
impl From<()> for Arg {
#[inline]
fn from(_: ()) -> Self {
Arg::Nil
}
}
impl From<bool> for Arg {
#[inline]
fn from(b: bool) -> Self {
Arg::Bool(b)
}
}
impl From<f64> for Arg {
#[inline]
fn from(n: f64) -> Self {
Arg::Number(n)
}
}
impl From<String> for Arg {
#[inline]
fn from(s: String) -> Self {
Arg::String(s)
}
}
// Other types
impl From<Duration> for Arg {
#[inline]
fn from(d: Duration) -> Self {
Arg::Number(d.as_secs_f64())
}
}
// Multi args
#[derive(Debug, Default)]
pub struct Args {
inner: Vec<Arg>,
}
impl Args {
#[inline]
pub fn new() -> Self {
Self::default()
}
}
impl IntoLuaMulti<'_> for Args {
#[inline]
fn into_lua_multi(self, lua: &Lua) -> LuaResult<LuaMultiValue> {
let mut values = Vec::new();
for arg in self.inner {
values.push(arg.into_lua(lua)?);
}
Ok(LuaMultiValue::from_vec(values))
}
}
// Boilerplate
impl<T> From<T> for Args
where
T: Into<Arg>,
{
#[inline]
fn from(t: T) -> Self {
Args {
inner: vec![t.into()],
}
}
}
impl<T0, T1> From<(T0, T1)> for Args
where
T0: Into<Arg>,
T1: Into<Arg>,
{
#[inline]
fn from((t0, t1): (T0, T1)) -> Self {
Args {
inner: vec![t0.into(), t1.into()],
}
}
}
impl<T0, T1, T2> From<(T0, T1, T2)> for Args
where
T0: Into<Arg>,
T1: Into<Arg>,
T2: Into<Arg>,
{
#[inline]
fn from((t0, t1, t2): (T0, T1, T2)) -> Self {
Args {
inner: vec![t0.into(), t1.into(), t2.into()],
}
}
}
impl<T0, T1, T2, T3> From<(T0, T1, T2, T3)> for Args
where
T0: Into<Arg>,
T1: Into<Arg>,
T2: Into<Arg>,
T3: Into<Arg>,
{
#[inline]
fn from((t0, t1, t2, t3): (T0, T1, T2, T3)) -> Self {
Args {
inner: vec![t0.into(), t1.into(), t2.into(), t3.into()],
}
}
}
impl<T0, T1, T2, T3, T4> From<(T0, T1, T2, T3, T4)> for Args
where
T0: Into<Arg>,
T1: Into<Arg>,
T2: Into<Arg>,
T3: Into<Arg>,
T4: Into<Arg>,
{
#[inline]
fn from((t0, t1, t2, t3, t4): (T0, T1, T2, T3, T4)) -> Self {
Args {
inner: vec![t0.into(), t1.into(), t2.into(), t3.into(), t4.into()],
}
}
}
impl<T0, T1, T2, T3, T4, T5> From<(T0, T1, T2, T3, T4, T5)> for Args
where
T0: Into<Arg>,
T1: Into<Arg>,
T2: Into<Arg>,
T3: Into<Arg>,
T4: Into<Arg>,
T5: Into<Arg>,
{
#[inline]
fn from((t0, t1, t2, t3, t4, t5): (T0, T1, T2, T3, T4, T5)) -> Self {
Args {
inner: vec![
t0.into(),
t1.into(),
t2.into(),
t3.into(),
t4.into(),
t5.into(),
],
}
}
}

View file

@ -12,31 +12,32 @@ use tokio::{
time::{sleep, Instant},
};
mod args;
mod thread_id;
use args::Args;
use thread_id::ThreadId;
const NUM_TEST_BATCHES: usize = 20;
const NUM_TEST_THREADS: usize = 50_000;
const MAIN_CHUNK: &str = r#"
wait(0.001 * math.random())
wait(0.01 * math.random())
"#;
const WAIT_IMPL: &str = r#"
__scheduler__resumeAfter(...)
coroutine.yield()
return coroutine.yield()
"#;
type ThreadMap<'lua> = GxHashMap<ThreadId, LuaThread<'lua>>;
type MessageSender = UnboundedSender<Message>;
type MessageReceiver = UnboundedReceiver<Message>;
enum Message {
Resume(ThreadId),
Resume(ThreadId, Args),
Cancel(ThreadId),
Sleep(ThreadId, Duration),
Error(ThreadId, LuaError),
Sleep(ThreadId, Instant, Duration),
Error(ThreadId, Box<LuaError>),
WriteStdout(Vec<u8>),
WriteStderr(Vec<u8>),
}
@ -120,8 +121,9 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu
"__scheduler__resumeAfter",
LuaFunction::wrap(move |lua, duration: f64| {
let thread_id = ThreadId::from(lua.current_thread());
let yielded_at = Instant::now();
let duration = Duration::from_secs_f64(duration);
send_message(lua, Message::Sleep(thread_id, duration));
send_message(lua, Message::Sleep(thread_id, yielded_at, duration));
Ok(())
}),
)?;
@ -137,31 +139,33 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu
g.set(
"__scheduler__writeStdout",
LuaFunction::wrap(move |lua, data: Vec<u8>| {
send_message(lua, Message::WriteStdout(data));
LuaFunction::wrap(move |lua, s: LuaString| {
let bytes = s.as_bytes().to_vec();
send_message(lua, Message::WriteStdout(bytes));
Ok(())
}),
)?;
g.set(
"__scheduler__writeStderr",
LuaFunction::wrap(move |lua, data: Vec<u8>| {
send_message(lua, Message::WriteStderr(data));
LuaFunction::wrap(move |lua, s: LuaString| {
let bytes = s.as_bytes().to_vec();
send_message(lua, Message::WriteStderr(bytes));
Ok(())
}),
)?;
g.set("wait", lua.load(WAIT_IMPL).into_function()?)?;
let mut yielded_threads = ThreadMap::default();
let mut runnable_threads = ThreadMap::default();
let mut yielded_threads: GxHashMap<ThreadId, LuaThread> = GxHashMap::default();
let mut runnable_threads: GxHashMap<ThreadId, (LuaThread, Args)> = GxHashMap::default();
println!("Running {NUM_TEST_BATCHES} batches");
for _ in 0..NUM_TEST_BATCHES {
let main_fn = lua.load(MAIN_CHUNK).into_function()?;
for _ in 0..NUM_TEST_THREADS {
let thread = lua.create_thread(main_fn.clone())?;
runnable_threads.insert(ThreadId::from(&thread), thread);
runnable_threads.insert(ThreadId::from(&thread), (thread, Args::new()));
}
loop {
@ -171,10 +175,10 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu
}
// Resume as many threads as possible
for (thread_id, thread) in runnable_threads.drain() {
for (thread_id, (thread, args)) in runnable_threads.drain() {
stats.incr(StatsCounter::ThreadResumed);
if let Err(e) = thread.resume::<_, ()>(()) {
send_message(&lua, Message::Error(thread_id, e));
if let Err(e) = thread.resume::<_, ()>(args) {
send_message(&lua, Message::Error(thread_id, Box::new(e)));
}
if thread.status() == LuaThreadStatus::Resumable {
yielded_threads.insert(thread_id, thread);
@ -187,9 +191,9 @@ fn main_lua_task(mut rx: MessageReceiver, tx: MessageSender, stats: Stats) -> Lu
// Wait for at least one message, but try to receive as many as possible
let mut process_message = |message| match message {
Message::Resume(thread_id) => {
Message::Resume(thread_id, args) => {
if let Some(thread) = yielded_threads.remove(&thread_id) {
runnable_threads.insert(thread_id, thread);
runnable_threads.insert(thread_id, (thread, args));
}
}
Message::Cancel(thread_id) => {
@ -234,7 +238,7 @@ async fn main_async_task(
let mut wrote_stderr = false;
for message in messages.drain(..) {
match message {
Message::Sleep(_, _) => stats.incr(StatsCounter::ThreadSlept),
Message::Sleep(_, _, _) => stats.incr(StatsCounter::ThreadSlept),
Message::Error(_, _) => stats.incr(StatsCounter::ThreadErrored),
Message::WriteStdout(_) => stats.incr(StatsCounter::WriteStdout),
Message::WriteStderr(_) => stats.incr(StatsCounter::WriteStderr),
@ -242,11 +246,12 @@ async fn main_async_task(
}
match message {
Message::Sleep(thread_id, duration) => {
Message::Sleep(thread_id, yielded_at, duration) => {
let tx = tx.clone();
spawn(async move {
sleep(duration).await;
tx.send(Message::Resume(thread_id))
let elapsed = Instant::now() - yielded_at;
tx.send(Message::Resume(thread_id, Args::from(elapsed)))
});
}
Message::Error(_, e) => {
@ -275,5 +280,9 @@ async fn main_async_task(
}
}
// Flush stdio one extra final time, just in case
handle_stdout.flush().await?;
handle_stderr.flush().await?;
Ok(())
}