Make new async require work with concurrent requires

This commit is contained in:
Filip Tibell 2023-03-22 18:16:45 +01:00
parent 129512b067
commit bf574607fb
No known key found for this signature in database
30 changed files with 160 additions and 57 deletions

View file

@ -0,0 +1,10 @@
pub(crate) mod fs;
pub(crate) mod net;
pub(crate) mod process;
pub(crate) mod serde;
pub(crate) mod stdio;
pub(crate) mod task;
pub(crate) mod top_level;
#[cfg(feature = "roblox")]
pub(crate) mod roblox;

View file

@ -7,14 +7,14 @@ use crate::lua::stdio::formatting::{format_label, pretty_format_multi_value};
// is really tricky to do from scratch so we will just
// proxy the default print and error functions here
pub fn top_level_print(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
pub fn print(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
let formatted = pretty_format_multi_value(&args)?;
let print: LuaFunction = lua.named_registry_value("print")?;
print.call(formatted)?;
Ok(())
}
pub fn top_level_printinfo(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
pub fn printinfo(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
let print: LuaFunction = lua.named_registry_value("print")?;
print.call(format!(
"{}\n{}",
@ -24,7 +24,7 @@ pub fn top_level_printinfo(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
Ok(())
}
pub fn top_level_warn(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
pub fn warn(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
let print: LuaFunction = lua.named_registry_value("print")?;
print.call(format!(
"{}\n{}",
@ -34,7 +34,7 @@ pub fn top_level_warn(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
Ok(())
}
pub fn top_level_error(lua: &Lua, (arg, level): (LuaValue, Option<u32>)) -> LuaResult<()> {
pub fn error(lua: &Lua, (arg, level): (LuaValue, Option<u32>)) -> LuaResult<()> {
let error: LuaFunction = lua.named_registry_value("error")?;
let trace: LuaFunction = lua.named_registry_value("dbg.trace")?;
error.call((

View file

@ -1,29 +1,23 @@
use mlua::prelude::*;
mod fs;
mod net;
mod process;
mod require;
#[cfg(feature = "roblox")]
mod roblox;
mod serde;
mod stdio;
mod task;
mod top_level;
mod require_waker;
use crate::builtins::{self, top_level};
const BUILTINS_AS_GLOBALS: &[&str] = &["fs", "net", "process", "stdio", "task"];
pub fn create(lua: &'static Lua, args: Vec<String>) -> LuaResult<()> {
// Create all builtins
let builtins = vec![
("fs", fs::create(lua)?),
("net", net::create(lua)?),
("process", process::create(lua, args)?),
("fs", builtins::fs::create(lua)?),
("net", builtins::net::create(lua)?),
("process", builtins::process::create(lua, args)?),
("serde", builtins::serde::create(lua)?),
("stdio", builtins::stdio::create(lua)?),
("task", builtins::task::create(lua)?),
#[cfg(feature = "roblox")]
("roblox", roblox::create(lua)?),
("serde", self::serde::create(lua)?),
("stdio", stdio::create(lua)?),
("task", task::create(lua)?),
("roblox", builtins::roblox::create(lua)?),
];
// TODO: Remove this when we have proper LSP support for custom
@ -40,13 +34,10 @@ pub fn create(lua: &'static Lua, args: Vec<String>) -> LuaResult<()> {
// Create all top-level globals
let globals = vec![
("require", require_fn),
("print", lua.create_function(top_level::top_level_print)?),
("warn", lua.create_function(top_level::top_level_warn)?),
("error", lua.create_function(top_level::top_level_error)?),
(
"printinfo",
lua.create_function(top_level::top_level_printinfo)?,
),
("print", lua.create_function(top_level::print)?),
("warn", lua.create_function(top_level::warn)?),
("error", lua.create_function(top_level::error)?),
("printinfo", lua.create_function(top_level::printinfo)?),
];
// Set top-level globals and seal them

View file

@ -9,12 +9,15 @@ use std::{
use dunce::canonicalize;
use mlua::prelude::*;
use tokio::fs;
use tokio::sync::Mutex as AsyncMutex;
use crate::lua::{
table::TableBuilder,
task::{TaskScheduler, TaskSchedulerScheduleExt},
};
use super::require_waker::{RequireWakerFuture, RequireWakerState};
const REQUIRE_IMPL_LUA: &str = r#"
local source = info(1, "s")
if source == '[string "require"]' then
@ -24,21 +27,24 @@ load(context, source, ...)
return yield()
"#;
type RequireWakersVec<'lua> = Vec<Arc<AsyncMutex<RequireWakerState<'lua>>>>;
#[derive(Debug, Clone, Default)]
struct RequireContext<'lua> {
// NOTE: We need to use arc here so that mlua clones
// the reference and not the entire inner value(s)
builtins: Arc<HashMap<String, LuaMultiValue<'lua>>>,
cached: Arc<RefCell<HashMap<String, LuaResult<LuaMultiValue<'lua>>>>>,
wakers: Arc<RefCell<HashMap<String, RequireWakersVec<'lua>>>>,
locks: Arc<RefCell<HashSet<String>>>,
pwd: String,
}
impl<'lua> RequireContext<'lua> {
pub fn new<K, V>(lua: &'static Lua, builtins_vec: Vec<(K, V)>) -> LuaResult<Self>
pub fn new<K, V>(lua: &'lua Lua, builtins_vec: Vec<(K, V)>) -> LuaResult<Self>
where
K: Into<String>,
V: ToLua<'static>,
V: ToLua<'lua>,
{
let mut pwd = current_dir()
.expect("Failed to access current working directory")
@ -79,10 +85,29 @@ impl<'lua> RequireContext<'lua> {
}
}
pub fn set_cached(&self, absolute_path: String, result: &LuaResult<LuaMultiValue<'lua>>) {
pub fn set_cached(&self, absolute_path: &str, result: &LuaResult<LuaMultiValue<'lua>>) {
self.cached
.borrow_mut()
.insert(absolute_path, result.clone());
.insert(absolute_path.to_string(), result.clone());
if let Some(wakers) = self.wakers.borrow_mut().remove(absolute_path) {
for waker in wakers {
waker
.try_lock()
.expect("Failed to lock waker")
.finalize(result.clone());
}
}
}
pub fn wait_for_cache(self, absolute_path: &str) -> RequireWakerFuture<'lua> {
let state = RequireWakerState::new();
let fut = RequireWakerFuture::new(&state);
self.wakers
.borrow_mut()
.entry(absolute_path.to_string())
.or_insert_with(Vec::new)
.push(Arc::clone(&state));
fut
}
pub fn get_paths(
@ -124,7 +149,7 @@ impl<'lua> LuaUserData for RequireContext<'lua> {}
fn load_builtin<'lua>(
_lua: &'lua Lua,
context: &RequireContext<'lua>,
context: RequireContext<'lua>,
module_name: String,
_has_acquired_lock: bool,
) -> LuaResult<LuaMultiValue<'lua>> {
@ -139,7 +164,7 @@ fn load_builtin<'lua>(
async fn load_file<'lua>(
lua: &'lua Lua,
context: &RequireContext<'lua>,
context: RequireContext<'lua>,
absolute_path: String,
relative_path: String,
has_acquired_lock: bool,
@ -149,9 +174,7 @@ async fn load_file<'lua>(
Some(cached) => cached,
None => {
if !has_acquired_lock {
return Err(LuaError::RuntimeError(
"Failed to get require lock".to_string(),
));
return context.wait_for_cache(&absolute_path).await;
}
// Try to read the wanted file, note that we use bytes instead of reading
// to a string since lua scripts are not necessarily valid utf-8 strings
@ -173,9 +196,10 @@ async fn load_file<'lua>(
let task = sched.schedule_blocking(loaded_thread, LuaMultiValue::new())?;
sched.wait_for_task_completion(task)
};
// Wait for the thread to finish running, cache + return our result
// Wait for the thread to finish running, cache + return our result,
// notify any other threads that are also waiting on this to finish
let rets = task_fut.await;
context.set_cached(absolute_path, &rets);
context.set_cached(&absolute_path, &rets);
rets
}
}
@ -190,7 +214,12 @@ async fn load<'lua>(
) -> LuaResult<LuaMultiValue<'lua>> {
let result = if absolute_path == relative_path && absolute_path.starts_with('@') {
if let Some(module_name) = absolute_path.strip_prefix("@lune/") {
load_builtin(lua, &context, module_name.to_string(), has_acquired_lock)
load_builtin(
lua,
context.clone(),
module_name.to_string(),
has_acquired_lock,
)
} else {
// FUTURE: '@' can be used a special prefix for users to set their own
// paths relative to a project file, similar to typescript paths config
@ -202,7 +231,7 @@ async fn load<'lua>(
} else {
load_file(
lua,
&context,
context.clone(),
absolute_path.to_string(),
relative_path,
has_acquired_lock,

View file

@ -0,0 +1,66 @@
use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll, Waker},
};
use tokio::sync::Mutex as AsyncMutex;
use mlua::prelude::*;
#[derive(Debug, Clone)]
pub(super) struct RequireWakerState<'lua> {
rets: Option<LuaResult<LuaMultiValue<'lua>>>,
waker: Option<Waker>,
}
impl<'lua> RequireWakerState<'lua> {
pub fn new() -> Arc<AsyncMutex<Self>> {
Arc::new(AsyncMutex::new(RequireWakerState {
rets: None,
waker: None,
}))
}
pub fn finalize(&mut self, rets: LuaResult<LuaMultiValue<'lua>>) {
self.rets = Some(rets);
if let Some(waker) = self.waker.take() {
waker.wake();
}
}
}
#[derive(Debug)]
pub(super) struct RequireWakerFuture<'lua> {
state: Arc<AsyncMutex<RequireWakerState<'lua>>>,
}
impl<'lua> RequireWakerFuture<'lua> {
pub fn new(state: &Arc<AsyncMutex<RequireWakerState<'lua>>>) -> Self {
Self {
state: Arc::clone(state),
}
}
}
impl<'lua> Clone for RequireWakerFuture<'lua> {
fn clone(&self) -> Self {
Self {
state: Arc::clone(&self.state),
}
}
}
impl<'lua> Future for RequireWakerFuture<'lua> {
type Output = LuaResult<LuaMultiValue<'lua>>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut shared_state = self.state.try_lock().unwrap();
if let Some(rets) = shared_state.rets.clone() {
Poll::Ready(rets)
} else {
shared_state.waker = Some(cx.waker().clone());
Poll::Pending
}
}
}

View file

@ -4,7 +4,8 @@ use lua::task::{TaskScheduler, TaskSchedulerResumeExt, TaskSchedulerScheduleExt}
use mlua::prelude::*;
use tokio::task::LocalSet;
pub(crate) mod globals;
pub(crate) mod builtins;
pub(crate) mod importer;
pub(crate) mod lua;
mod error;
@ -71,7 +72,7 @@ impl Lune {
// NOTE: Some globals require the task scheduler to exist on startup
let sched = TaskScheduler::new(lua)?.into_static();
lua.set_app_data(sched);
globals::create(lua, self.args.clone())?;
importer::create(lua, self.args.clone())?;
// Create the main thread and schedule it
let main_chunk = lua
.load(script_contents.as_ref())

View file

@ -52,7 +52,7 @@ pub struct TaskScheduler<'fut> {
pub(super) tasks_current: Cell<Option<TaskReference>>,
pub(super) tasks_queue_blocking: RefCell<VecDeque<TaskReference>>,
pub(super) tasks_waiter_states:
RefCell<HashMap<TaskReference, Arc<AsyncMutex<TaskWaiterState<'fut>>>>>,
RefCell<HashMap<TaskReference, Vec<Arc<AsyncMutex<TaskWaiterState<'fut>>>>>>,
pub(super) tasks_current_lua_error: Arc<AsyncMutex<Option<LuaError>>>,
// Future tasks & objects for waking
pub(super) futures: AsyncMutex<FuturesUnordered<TaskFuture<'fut>>>,
@ -452,9 +452,13 @@ impl<'fut> TaskScheduler<'fut> {
panic!("Task does not exist in scheduler")
}
let state = TaskWaiterState::new();
self.tasks_waiter_states
.borrow_mut()
.insert(reference, Arc::clone(&state));
{
let mut all_states = self.tasks_waiter_states.borrow_mut();
all_states
.entry(reference)
.or_insert_with(Vec::new)
.push(Arc::clone(&state));
}
TaskWaiterFuture::new(&state).await
}
@ -467,8 +471,10 @@ impl<'fut> TaskScheduler<'fut> {
reference: TaskReference,
result: LuaResult<LuaMultiValue<'fut>>,
) {
if let Some(waiter_state) = self.tasks_waiter_states.borrow_mut().remove(&reference) {
waiter_state.try_lock().unwrap().finalize(result);
if let Some(waiter_states) = self.tasks_waiter_states.borrow_mut().remove(&reference) {
for waiter_state in waiter_states {
waiter_state.try_lock().unwrap().finalize(result.clone());
}
}
}
}

View file

@ -60,15 +60,15 @@ create_tests! {
process_env: "process/env",
process_exit: "process/exit",
process_spawn: "process/spawn",
require_async: "globals/require/tests/async",
require_async_concurrent: "globals/require/tests/async_concurrent",
require_async_sequential: "globals/require/tests/async_sequential",
require_builtins: "globals/require/tests/builtins",
require_children: "globals/require/tests/children",
require_invalid: "globals/require/tests/invalid",
require_nested: "globals/require/tests/nested",
require_parents: "globals/require/tests/parents",
require_siblings: "globals/require/tests/siblings",
require_async: "require/tests/async",
require_async_concurrent: "require/tests/async_concurrent",
require_async_sequential: "require/tests/async_sequential",
require_builtins: "require/tests/builtins",
require_children: "require/tests/children",
require_invalid: "require/tests/invalid",
require_nested: "require/tests/nested",
require_parents: "require/tests/parents",
require_siblings: "require/tests/siblings",
// TODO: Uncomment this test, it is commented out right
// now to let CI pass so that we can make a new release
// global_coroutine: "globals/coroutine",