Make new async require work with concurrent requires

This commit is contained in:
Filip Tibell 2023-03-22 18:16:45 +01:00
parent 129512b067
commit bf574607fb
No known key found for this signature in database
30 changed files with 160 additions and 57 deletions

View file

@ -0,0 +1,10 @@
pub(crate) mod fs;
pub(crate) mod net;
pub(crate) mod process;
pub(crate) mod serde;
pub(crate) mod stdio;
pub(crate) mod task;
pub(crate) mod top_level;
#[cfg(feature = "roblox")]
pub(crate) mod roblox;

View file

@ -7,14 +7,14 @@ use crate::lua::stdio::formatting::{format_label, pretty_format_multi_value};
// is really tricky to do from scratch so we will just // is really tricky to do from scratch so we will just
// proxy the default print and error functions here // proxy the default print and error functions here
pub fn top_level_print(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> { pub fn print(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
let formatted = pretty_format_multi_value(&args)?; let formatted = pretty_format_multi_value(&args)?;
let print: LuaFunction = lua.named_registry_value("print")?; let print: LuaFunction = lua.named_registry_value("print")?;
print.call(formatted)?; print.call(formatted)?;
Ok(()) Ok(())
} }
pub fn top_level_printinfo(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> { pub fn printinfo(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
let print: LuaFunction = lua.named_registry_value("print")?; let print: LuaFunction = lua.named_registry_value("print")?;
print.call(format!( print.call(format!(
"{}\n{}", "{}\n{}",
@ -24,7 +24,7 @@ pub fn top_level_printinfo(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
Ok(()) Ok(())
} }
pub fn top_level_warn(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> { pub fn warn(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
let print: LuaFunction = lua.named_registry_value("print")?; let print: LuaFunction = lua.named_registry_value("print")?;
print.call(format!( print.call(format!(
"{}\n{}", "{}\n{}",
@ -34,7 +34,7 @@ pub fn top_level_warn(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
Ok(()) Ok(())
} }
pub fn top_level_error(lua: &Lua, (arg, level): (LuaValue, Option<u32>)) -> LuaResult<()> { pub fn error(lua: &Lua, (arg, level): (LuaValue, Option<u32>)) -> LuaResult<()> {
let error: LuaFunction = lua.named_registry_value("error")?; let error: LuaFunction = lua.named_registry_value("error")?;
let trace: LuaFunction = lua.named_registry_value("dbg.trace")?; let trace: LuaFunction = lua.named_registry_value("dbg.trace")?;
error.call(( error.call((

View file

@ -1,29 +1,23 @@
use mlua::prelude::*; use mlua::prelude::*;
mod fs;
mod net;
mod process;
mod require; mod require;
#[cfg(feature = "roblox")] mod require_waker;
mod roblox;
mod serde; use crate::builtins::{self, top_level};
mod stdio;
mod task;
mod top_level;
const BUILTINS_AS_GLOBALS: &[&str] = &["fs", "net", "process", "stdio", "task"]; const BUILTINS_AS_GLOBALS: &[&str] = &["fs", "net", "process", "stdio", "task"];
pub fn create(lua: &'static Lua, args: Vec<String>) -> LuaResult<()> { pub fn create(lua: &'static Lua, args: Vec<String>) -> LuaResult<()> {
// Create all builtins // Create all builtins
let builtins = vec![ let builtins = vec![
("fs", fs::create(lua)?), ("fs", builtins::fs::create(lua)?),
("net", net::create(lua)?), ("net", builtins::net::create(lua)?),
("process", process::create(lua, args)?), ("process", builtins::process::create(lua, args)?),
("serde", builtins::serde::create(lua)?),
("stdio", builtins::stdio::create(lua)?),
("task", builtins::task::create(lua)?),
#[cfg(feature = "roblox")] #[cfg(feature = "roblox")]
("roblox", roblox::create(lua)?), ("roblox", builtins::roblox::create(lua)?),
("serde", self::serde::create(lua)?),
("stdio", stdio::create(lua)?),
("task", task::create(lua)?),
]; ];
// TODO: Remove this when we have proper LSP support for custom // TODO: Remove this when we have proper LSP support for custom
@ -40,13 +34,10 @@ pub fn create(lua: &'static Lua, args: Vec<String>) -> LuaResult<()> {
// Create all top-level globals // Create all top-level globals
let globals = vec![ let globals = vec![
("require", require_fn), ("require", require_fn),
("print", lua.create_function(top_level::top_level_print)?), ("print", lua.create_function(top_level::print)?),
("warn", lua.create_function(top_level::top_level_warn)?), ("warn", lua.create_function(top_level::warn)?),
("error", lua.create_function(top_level::top_level_error)?), ("error", lua.create_function(top_level::error)?),
( ("printinfo", lua.create_function(top_level::printinfo)?),
"printinfo",
lua.create_function(top_level::top_level_printinfo)?,
),
]; ];
// Set top-level globals and seal them // Set top-level globals and seal them

View file

@ -9,12 +9,15 @@ use std::{
use dunce::canonicalize; use dunce::canonicalize;
use mlua::prelude::*; use mlua::prelude::*;
use tokio::fs; use tokio::fs;
use tokio::sync::Mutex as AsyncMutex;
use crate::lua::{ use crate::lua::{
table::TableBuilder, table::TableBuilder,
task::{TaskScheduler, TaskSchedulerScheduleExt}, task::{TaskScheduler, TaskSchedulerScheduleExt},
}; };
use super::require_waker::{RequireWakerFuture, RequireWakerState};
const REQUIRE_IMPL_LUA: &str = r#" const REQUIRE_IMPL_LUA: &str = r#"
local source = info(1, "s") local source = info(1, "s")
if source == '[string "require"]' then if source == '[string "require"]' then
@ -24,21 +27,24 @@ load(context, source, ...)
return yield() return yield()
"#; "#;
type RequireWakersVec<'lua> = Vec<Arc<AsyncMutex<RequireWakerState<'lua>>>>;
#[derive(Debug, Clone, Default)] #[derive(Debug, Clone, Default)]
struct RequireContext<'lua> { struct RequireContext<'lua> {
// NOTE: We need to use arc here so that mlua clones // NOTE: We need to use arc here so that mlua clones
// the reference and not the entire inner value(s) // the reference and not the entire inner value(s)
builtins: Arc<HashMap<String, LuaMultiValue<'lua>>>, builtins: Arc<HashMap<String, LuaMultiValue<'lua>>>,
cached: Arc<RefCell<HashMap<String, LuaResult<LuaMultiValue<'lua>>>>>, cached: Arc<RefCell<HashMap<String, LuaResult<LuaMultiValue<'lua>>>>>,
wakers: Arc<RefCell<HashMap<String, RequireWakersVec<'lua>>>>,
locks: Arc<RefCell<HashSet<String>>>, locks: Arc<RefCell<HashSet<String>>>,
pwd: String, pwd: String,
} }
impl<'lua> RequireContext<'lua> { impl<'lua> RequireContext<'lua> {
pub fn new<K, V>(lua: &'static Lua, builtins_vec: Vec<(K, V)>) -> LuaResult<Self> pub fn new<K, V>(lua: &'lua Lua, builtins_vec: Vec<(K, V)>) -> LuaResult<Self>
where where
K: Into<String>, K: Into<String>,
V: ToLua<'static>, V: ToLua<'lua>,
{ {
let mut pwd = current_dir() let mut pwd = current_dir()
.expect("Failed to access current working directory") .expect("Failed to access current working directory")
@ -79,10 +85,29 @@ impl<'lua> RequireContext<'lua> {
} }
} }
pub fn set_cached(&self, absolute_path: String, result: &LuaResult<LuaMultiValue<'lua>>) { pub fn set_cached(&self, absolute_path: &str, result: &LuaResult<LuaMultiValue<'lua>>) {
self.cached self.cached
.borrow_mut() .borrow_mut()
.insert(absolute_path, result.clone()); .insert(absolute_path.to_string(), result.clone());
if let Some(wakers) = self.wakers.borrow_mut().remove(absolute_path) {
for waker in wakers {
waker
.try_lock()
.expect("Failed to lock waker")
.finalize(result.clone());
}
}
}
pub fn wait_for_cache(self, absolute_path: &str) -> RequireWakerFuture<'lua> {
let state = RequireWakerState::new();
let fut = RequireWakerFuture::new(&state);
self.wakers
.borrow_mut()
.entry(absolute_path.to_string())
.or_insert_with(Vec::new)
.push(Arc::clone(&state));
fut
} }
pub fn get_paths( pub fn get_paths(
@ -124,7 +149,7 @@ impl<'lua> LuaUserData for RequireContext<'lua> {}
fn load_builtin<'lua>( fn load_builtin<'lua>(
_lua: &'lua Lua, _lua: &'lua Lua,
context: &RequireContext<'lua>, context: RequireContext<'lua>,
module_name: String, module_name: String,
_has_acquired_lock: bool, _has_acquired_lock: bool,
) -> LuaResult<LuaMultiValue<'lua>> { ) -> LuaResult<LuaMultiValue<'lua>> {
@ -139,7 +164,7 @@ fn load_builtin<'lua>(
async fn load_file<'lua>( async fn load_file<'lua>(
lua: &'lua Lua, lua: &'lua Lua,
context: &RequireContext<'lua>, context: RequireContext<'lua>,
absolute_path: String, absolute_path: String,
relative_path: String, relative_path: String,
has_acquired_lock: bool, has_acquired_lock: bool,
@ -149,9 +174,7 @@ async fn load_file<'lua>(
Some(cached) => cached, Some(cached) => cached,
None => { None => {
if !has_acquired_lock { if !has_acquired_lock {
return Err(LuaError::RuntimeError( return context.wait_for_cache(&absolute_path).await;
"Failed to get require lock".to_string(),
));
} }
// Try to read the wanted file, note that we use bytes instead of reading // Try to read the wanted file, note that we use bytes instead of reading
// to a string since lua scripts are not necessarily valid utf-8 strings // to a string since lua scripts are not necessarily valid utf-8 strings
@ -173,9 +196,10 @@ async fn load_file<'lua>(
let task = sched.schedule_blocking(loaded_thread, LuaMultiValue::new())?; let task = sched.schedule_blocking(loaded_thread, LuaMultiValue::new())?;
sched.wait_for_task_completion(task) sched.wait_for_task_completion(task)
}; };
// Wait for the thread to finish running, cache + return our result // Wait for the thread to finish running, cache + return our result,
// notify any other threads that are also waiting on this to finish
let rets = task_fut.await; let rets = task_fut.await;
context.set_cached(absolute_path, &rets); context.set_cached(&absolute_path, &rets);
rets rets
} }
} }
@ -190,7 +214,12 @@ async fn load<'lua>(
) -> LuaResult<LuaMultiValue<'lua>> { ) -> LuaResult<LuaMultiValue<'lua>> {
let result = if absolute_path == relative_path && absolute_path.starts_with('@') { let result = if absolute_path == relative_path && absolute_path.starts_with('@') {
if let Some(module_name) = absolute_path.strip_prefix("@lune/") { if let Some(module_name) = absolute_path.strip_prefix("@lune/") {
load_builtin(lua, &context, module_name.to_string(), has_acquired_lock) load_builtin(
lua,
context.clone(),
module_name.to_string(),
has_acquired_lock,
)
} else { } else {
// FUTURE: '@' can be used a special prefix for users to set their own // FUTURE: '@' can be used a special prefix for users to set their own
// paths relative to a project file, similar to typescript paths config // paths relative to a project file, similar to typescript paths config
@ -202,7 +231,7 @@ async fn load<'lua>(
} else { } else {
load_file( load_file(
lua, lua,
&context, context.clone(),
absolute_path.to_string(), absolute_path.to_string(),
relative_path, relative_path,
has_acquired_lock, has_acquired_lock,

View file

@ -0,0 +1,66 @@
use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll, Waker},
};
use tokio::sync::Mutex as AsyncMutex;
use mlua::prelude::*;
#[derive(Debug, Clone)]
pub(super) struct RequireWakerState<'lua> {
rets: Option<LuaResult<LuaMultiValue<'lua>>>,
waker: Option<Waker>,
}
impl<'lua> RequireWakerState<'lua> {
pub fn new() -> Arc<AsyncMutex<Self>> {
Arc::new(AsyncMutex::new(RequireWakerState {
rets: None,
waker: None,
}))
}
pub fn finalize(&mut self, rets: LuaResult<LuaMultiValue<'lua>>) {
self.rets = Some(rets);
if let Some(waker) = self.waker.take() {
waker.wake();
}
}
}
#[derive(Debug)]
pub(super) struct RequireWakerFuture<'lua> {
state: Arc<AsyncMutex<RequireWakerState<'lua>>>,
}
impl<'lua> RequireWakerFuture<'lua> {
pub fn new(state: &Arc<AsyncMutex<RequireWakerState<'lua>>>) -> Self {
Self {
state: Arc::clone(state),
}
}
}
impl<'lua> Clone for RequireWakerFuture<'lua> {
fn clone(&self) -> Self {
Self {
state: Arc::clone(&self.state),
}
}
}
impl<'lua> Future for RequireWakerFuture<'lua> {
type Output = LuaResult<LuaMultiValue<'lua>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut shared_state = self.state.try_lock().unwrap();
if let Some(rets) = shared_state.rets.clone() {
Poll::Ready(rets)
} else {
shared_state.waker = Some(cx.waker().clone());
Poll::Pending
}
}
}

View file

@ -4,7 +4,8 @@ use lua::task::{TaskScheduler, TaskSchedulerResumeExt, TaskSchedulerScheduleExt}
use mlua::prelude::*; use mlua::prelude::*;
use tokio::task::LocalSet; use tokio::task::LocalSet;
pub(crate) mod globals; pub(crate) mod builtins;
pub(crate) mod importer;
pub(crate) mod lua; pub(crate) mod lua;
mod error; mod error;
@ -71,7 +72,7 @@ impl Lune {
// NOTE: Some globals require the task scheduler to exist on startup // NOTE: Some globals require the task scheduler to exist on startup
let sched = TaskScheduler::new(lua)?.into_static(); let sched = TaskScheduler::new(lua)?.into_static();
lua.set_app_data(sched); lua.set_app_data(sched);
globals::create(lua, self.args.clone())?; importer::create(lua, self.args.clone())?;
// Create the main thread and schedule it // Create the main thread and schedule it
let main_chunk = lua let main_chunk = lua
.load(script_contents.as_ref()) .load(script_contents.as_ref())

View file

@ -52,7 +52,7 @@ pub struct TaskScheduler<'fut> {
pub(super) tasks_current: Cell<Option<TaskReference>>, pub(super) tasks_current: Cell<Option<TaskReference>>,
pub(super) tasks_queue_blocking: RefCell<VecDeque<TaskReference>>, pub(super) tasks_queue_blocking: RefCell<VecDeque<TaskReference>>,
pub(super) tasks_waiter_states: pub(super) tasks_waiter_states:
RefCell<HashMap<TaskReference, Arc<AsyncMutex<TaskWaiterState<'fut>>>>>, RefCell<HashMap<TaskReference, Vec<Arc<AsyncMutex<TaskWaiterState<'fut>>>>>>,
pub(super) tasks_current_lua_error: Arc<AsyncMutex<Option<LuaError>>>, pub(super) tasks_current_lua_error: Arc<AsyncMutex<Option<LuaError>>>,
// Future tasks & objects for waking // Future tasks & objects for waking
pub(super) futures: AsyncMutex<FuturesUnordered<TaskFuture<'fut>>>, pub(super) futures: AsyncMutex<FuturesUnordered<TaskFuture<'fut>>>,
@ -452,9 +452,13 @@ impl<'fut> TaskScheduler<'fut> {
panic!("Task does not exist in scheduler") panic!("Task does not exist in scheduler")
} }
let state = TaskWaiterState::new(); let state = TaskWaiterState::new();
self.tasks_waiter_states {
.borrow_mut() let mut all_states = self.tasks_waiter_states.borrow_mut();
.insert(reference, Arc::clone(&state)); all_states
.entry(reference)
.or_insert_with(Vec::new)
.push(Arc::clone(&state));
}
TaskWaiterFuture::new(&state).await TaskWaiterFuture::new(&state).await
} }
@ -467,8 +471,10 @@ impl<'fut> TaskScheduler<'fut> {
reference: TaskReference, reference: TaskReference,
result: LuaResult<LuaMultiValue<'fut>>, result: LuaResult<LuaMultiValue<'fut>>,
) { ) {
if let Some(waiter_state) = self.tasks_waiter_states.borrow_mut().remove(&reference) { if let Some(waiter_states) = self.tasks_waiter_states.borrow_mut().remove(&reference) {
waiter_state.try_lock().unwrap().finalize(result); for waiter_state in waiter_states {
waiter_state.try_lock().unwrap().finalize(result.clone());
}
} }
} }
} }

View file

@ -60,15 +60,15 @@ create_tests! {
process_env: "process/env", process_env: "process/env",
process_exit: "process/exit", process_exit: "process/exit",
process_spawn: "process/spawn", process_spawn: "process/spawn",
require_async: "globals/require/tests/async", require_async: "require/tests/async",
require_async_concurrent: "globals/require/tests/async_concurrent", require_async_concurrent: "require/tests/async_concurrent",
require_async_sequential: "globals/require/tests/async_sequential", require_async_sequential: "require/tests/async_sequential",
require_builtins: "globals/require/tests/builtins", require_builtins: "require/tests/builtins",
require_children: "globals/require/tests/children", require_children: "require/tests/children",
require_invalid: "globals/require/tests/invalid", require_invalid: "require/tests/invalid",
require_nested: "globals/require/tests/nested", require_nested: "require/tests/nested",
require_parents: "globals/require/tests/parents", require_parents: "require/tests/parents",
require_siblings: "globals/require/tests/siblings", require_siblings: "require/tests/siblings",
// TODO: Uncomment this test, it is commented out right // TODO: Uncomment this test, it is commented out right
// now to let CI pass so that we can make a new release // now to let CI pass so that we can make a new release
// global_coroutine: "globals/coroutine", // global_coroutine: "globals/coroutine",