mirror of
https://github.com/CompeyDev/lune-packaging.git
synced 2025-01-09 12:19:09 +00:00
Make new async require work with concurrent requires
This commit is contained in:
parent
129512b067
commit
bf574607fb
30 changed files with 160 additions and 57 deletions
10
packages/lib/src/builtins/mod.rs
Normal file
10
packages/lib/src/builtins/mod.rs
Normal file
|
@ -0,0 +1,10 @@
|
|||
pub(crate) mod fs;
|
||||
pub(crate) mod net;
|
||||
pub(crate) mod process;
|
||||
pub(crate) mod serde;
|
||||
pub(crate) mod stdio;
|
||||
pub(crate) mod task;
|
||||
pub(crate) mod top_level;
|
||||
|
||||
#[cfg(feature = "roblox")]
|
||||
pub(crate) mod roblox;
|
|
@ -7,14 +7,14 @@ use crate::lua::stdio::formatting::{format_label, pretty_format_multi_value};
|
|||
// is really tricky to do from scratch so we will just
|
||||
// proxy the default print and error functions here
|
||||
|
||||
pub fn top_level_print(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
|
||||
pub fn print(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
|
||||
let formatted = pretty_format_multi_value(&args)?;
|
||||
let print: LuaFunction = lua.named_registry_value("print")?;
|
||||
print.call(formatted)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn top_level_printinfo(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
|
||||
pub fn printinfo(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
|
||||
let print: LuaFunction = lua.named_registry_value("print")?;
|
||||
print.call(format!(
|
||||
"{}\n{}",
|
||||
|
@ -24,7 +24,7 @@ pub fn top_level_printinfo(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn top_level_warn(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
|
||||
pub fn warn(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
|
||||
let print: LuaFunction = lua.named_registry_value("print")?;
|
||||
print.call(format!(
|
||||
"{}\n{}",
|
||||
|
@ -34,7 +34,7 @@ pub fn top_level_warn(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
pub fn top_level_error(lua: &Lua, (arg, level): (LuaValue, Option<u32>)) -> LuaResult<()> {
|
||||
pub fn error(lua: &Lua, (arg, level): (LuaValue, Option<u32>)) -> LuaResult<()> {
|
||||
let error: LuaFunction = lua.named_registry_value("error")?;
|
||||
let trace: LuaFunction = lua.named_registry_value("dbg.trace")?;
|
||||
error.call((
|
|
@ -1,29 +1,23 @@
|
|||
use mlua::prelude::*;
|
||||
|
||||
mod fs;
|
||||
mod net;
|
||||
mod process;
|
||||
mod require;
|
||||
#[cfg(feature = "roblox")]
|
||||
mod roblox;
|
||||
mod serde;
|
||||
mod stdio;
|
||||
mod task;
|
||||
mod top_level;
|
||||
mod require_waker;
|
||||
|
||||
use crate::builtins::{self, top_level};
|
||||
|
||||
const BUILTINS_AS_GLOBALS: &[&str] = &["fs", "net", "process", "stdio", "task"];
|
||||
|
||||
pub fn create(lua: &'static Lua, args: Vec<String>) -> LuaResult<()> {
|
||||
// Create all builtins
|
||||
let builtins = vec![
|
||||
("fs", fs::create(lua)?),
|
||||
("net", net::create(lua)?),
|
||||
("process", process::create(lua, args)?),
|
||||
("fs", builtins::fs::create(lua)?),
|
||||
("net", builtins::net::create(lua)?),
|
||||
("process", builtins::process::create(lua, args)?),
|
||||
("serde", builtins::serde::create(lua)?),
|
||||
("stdio", builtins::stdio::create(lua)?),
|
||||
("task", builtins::task::create(lua)?),
|
||||
#[cfg(feature = "roblox")]
|
||||
("roblox", roblox::create(lua)?),
|
||||
("serde", self::serde::create(lua)?),
|
||||
("stdio", stdio::create(lua)?),
|
||||
("task", task::create(lua)?),
|
||||
("roblox", builtins::roblox::create(lua)?),
|
||||
];
|
||||
|
||||
// TODO: Remove this when we have proper LSP support for custom
|
||||
|
@ -40,13 +34,10 @@ pub fn create(lua: &'static Lua, args: Vec<String>) -> LuaResult<()> {
|
|||
// Create all top-level globals
|
||||
let globals = vec![
|
||||
("require", require_fn),
|
||||
("print", lua.create_function(top_level::top_level_print)?),
|
||||
("warn", lua.create_function(top_level::top_level_warn)?),
|
||||
("error", lua.create_function(top_level::top_level_error)?),
|
||||
(
|
||||
"printinfo",
|
||||
lua.create_function(top_level::top_level_printinfo)?,
|
||||
),
|
||||
("print", lua.create_function(top_level::print)?),
|
||||
("warn", lua.create_function(top_level::warn)?),
|
||||
("error", lua.create_function(top_level::error)?),
|
||||
("printinfo", lua.create_function(top_level::printinfo)?),
|
||||
];
|
||||
|
||||
// Set top-level globals and seal them
|
|
@ -9,12 +9,15 @@ use std::{
|
|||
use dunce::canonicalize;
|
||||
use mlua::prelude::*;
|
||||
use tokio::fs;
|
||||
use tokio::sync::Mutex as AsyncMutex;
|
||||
|
||||
use crate::lua::{
|
||||
table::TableBuilder,
|
||||
task::{TaskScheduler, TaskSchedulerScheduleExt},
|
||||
};
|
||||
|
||||
use super::require_waker::{RequireWakerFuture, RequireWakerState};
|
||||
|
||||
const REQUIRE_IMPL_LUA: &str = r#"
|
||||
local source = info(1, "s")
|
||||
if source == '[string "require"]' then
|
||||
|
@ -24,21 +27,24 @@ load(context, source, ...)
|
|||
return yield()
|
||||
"#;
|
||||
|
||||
type RequireWakersVec<'lua> = Vec<Arc<AsyncMutex<RequireWakerState<'lua>>>>;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
struct RequireContext<'lua> {
|
||||
// NOTE: We need to use arc here so that mlua clones
|
||||
// the reference and not the entire inner value(s)
|
||||
builtins: Arc<HashMap<String, LuaMultiValue<'lua>>>,
|
||||
cached: Arc<RefCell<HashMap<String, LuaResult<LuaMultiValue<'lua>>>>>,
|
||||
wakers: Arc<RefCell<HashMap<String, RequireWakersVec<'lua>>>>,
|
||||
locks: Arc<RefCell<HashSet<String>>>,
|
||||
pwd: String,
|
||||
}
|
||||
|
||||
impl<'lua> RequireContext<'lua> {
|
||||
pub fn new<K, V>(lua: &'static Lua, builtins_vec: Vec<(K, V)>) -> LuaResult<Self>
|
||||
pub fn new<K, V>(lua: &'lua Lua, builtins_vec: Vec<(K, V)>) -> LuaResult<Self>
|
||||
where
|
||||
K: Into<String>,
|
||||
V: ToLua<'static>,
|
||||
V: ToLua<'lua>,
|
||||
{
|
||||
let mut pwd = current_dir()
|
||||
.expect("Failed to access current working directory")
|
||||
|
@ -79,10 +85,29 @@ impl<'lua> RequireContext<'lua> {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn set_cached(&self, absolute_path: String, result: &LuaResult<LuaMultiValue<'lua>>) {
|
||||
pub fn set_cached(&self, absolute_path: &str, result: &LuaResult<LuaMultiValue<'lua>>) {
|
||||
self.cached
|
||||
.borrow_mut()
|
||||
.insert(absolute_path, result.clone());
|
||||
.insert(absolute_path.to_string(), result.clone());
|
||||
if let Some(wakers) = self.wakers.borrow_mut().remove(absolute_path) {
|
||||
for waker in wakers {
|
||||
waker
|
||||
.try_lock()
|
||||
.expect("Failed to lock waker")
|
||||
.finalize(result.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn wait_for_cache(self, absolute_path: &str) -> RequireWakerFuture<'lua> {
|
||||
let state = RequireWakerState::new();
|
||||
let fut = RequireWakerFuture::new(&state);
|
||||
self.wakers
|
||||
.borrow_mut()
|
||||
.entry(absolute_path.to_string())
|
||||
.or_insert_with(Vec::new)
|
||||
.push(Arc::clone(&state));
|
||||
fut
|
||||
}
|
||||
|
||||
pub fn get_paths(
|
||||
|
@ -124,7 +149,7 @@ impl<'lua> LuaUserData for RequireContext<'lua> {}
|
|||
|
||||
fn load_builtin<'lua>(
|
||||
_lua: &'lua Lua,
|
||||
context: &RequireContext<'lua>,
|
||||
context: RequireContext<'lua>,
|
||||
module_name: String,
|
||||
_has_acquired_lock: bool,
|
||||
) -> LuaResult<LuaMultiValue<'lua>> {
|
||||
|
@ -139,7 +164,7 @@ fn load_builtin<'lua>(
|
|||
|
||||
async fn load_file<'lua>(
|
||||
lua: &'lua Lua,
|
||||
context: &RequireContext<'lua>,
|
||||
context: RequireContext<'lua>,
|
||||
absolute_path: String,
|
||||
relative_path: String,
|
||||
has_acquired_lock: bool,
|
||||
|
@ -149,9 +174,7 @@ async fn load_file<'lua>(
|
|||
Some(cached) => cached,
|
||||
None => {
|
||||
if !has_acquired_lock {
|
||||
return Err(LuaError::RuntimeError(
|
||||
"Failed to get require lock".to_string(),
|
||||
));
|
||||
return context.wait_for_cache(&absolute_path).await;
|
||||
}
|
||||
// Try to read the wanted file, note that we use bytes instead of reading
|
||||
// to a string since lua scripts are not necessarily valid utf-8 strings
|
||||
|
@ -173,9 +196,10 @@ async fn load_file<'lua>(
|
|||
let task = sched.schedule_blocking(loaded_thread, LuaMultiValue::new())?;
|
||||
sched.wait_for_task_completion(task)
|
||||
};
|
||||
// Wait for the thread to finish running, cache + return our result
|
||||
// Wait for the thread to finish running, cache + return our result,
|
||||
// notify any other threads that are also waiting on this to finish
|
||||
let rets = task_fut.await;
|
||||
context.set_cached(absolute_path, &rets);
|
||||
context.set_cached(&absolute_path, &rets);
|
||||
rets
|
||||
}
|
||||
}
|
||||
|
@ -190,7 +214,12 @@ async fn load<'lua>(
|
|||
) -> LuaResult<LuaMultiValue<'lua>> {
|
||||
let result = if absolute_path == relative_path && absolute_path.starts_with('@') {
|
||||
if let Some(module_name) = absolute_path.strip_prefix("@lune/") {
|
||||
load_builtin(lua, &context, module_name.to_string(), has_acquired_lock)
|
||||
load_builtin(
|
||||
lua,
|
||||
context.clone(),
|
||||
module_name.to_string(),
|
||||
has_acquired_lock,
|
||||
)
|
||||
} else {
|
||||
// FUTURE: '@' can be used a special prefix for users to set their own
|
||||
// paths relative to a project file, similar to typescript paths config
|
||||
|
@ -202,7 +231,7 @@ async fn load<'lua>(
|
|||
} else {
|
||||
load_file(
|
||||
lua,
|
||||
&context,
|
||||
context.clone(),
|
||||
absolute_path.to_string(),
|
||||
relative_path,
|
||||
has_acquired_lock,
|
66
packages/lib/src/importer/require_waker.rs
Normal file
66
packages/lib/src/importer/require_waker.rs
Normal file
|
@ -0,0 +1,66 @@
|
|||
use std::{
|
||||
future::Future,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll, Waker},
|
||||
};
|
||||
|
||||
use tokio::sync::Mutex as AsyncMutex;
|
||||
|
||||
use mlua::prelude::*;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(super) struct RequireWakerState<'lua> {
|
||||
rets: Option<LuaResult<LuaMultiValue<'lua>>>,
|
||||
waker: Option<Waker>,
|
||||
}
|
||||
|
||||
impl<'lua> RequireWakerState<'lua> {
|
||||
pub fn new() -> Arc<AsyncMutex<Self>> {
|
||||
Arc::new(AsyncMutex::new(RequireWakerState {
|
||||
rets: None,
|
||||
waker: None,
|
||||
}))
|
||||
}
|
||||
|
||||
pub fn finalize(&mut self, rets: LuaResult<LuaMultiValue<'lua>>) {
|
||||
self.rets = Some(rets);
|
||||
if let Some(waker) = self.waker.take() {
|
||||
waker.wake();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(super) struct RequireWakerFuture<'lua> {
|
||||
state: Arc<AsyncMutex<RequireWakerState<'lua>>>,
|
||||
}
|
||||
|
||||
impl<'lua> RequireWakerFuture<'lua> {
|
||||
pub fn new(state: &Arc<AsyncMutex<RequireWakerState<'lua>>>) -> Self {
|
||||
Self {
|
||||
state: Arc::clone(state),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'lua> Clone for RequireWakerFuture<'lua> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
state: Arc::clone(&self.state),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'lua> Future for RequireWakerFuture<'lua> {
|
||||
type Output = LuaResult<LuaMultiValue<'lua>>;
|
||||
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
let mut shared_state = self.state.try_lock().unwrap();
|
||||
if let Some(rets) = shared_state.rets.clone() {
|
||||
Poll::Ready(rets)
|
||||
} else {
|
||||
shared_state.waker = Some(cx.waker().clone());
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
}
|
|
@ -4,7 +4,8 @@ use lua::task::{TaskScheduler, TaskSchedulerResumeExt, TaskSchedulerScheduleExt}
|
|||
use mlua::prelude::*;
|
||||
use tokio::task::LocalSet;
|
||||
|
||||
pub(crate) mod globals;
|
||||
pub(crate) mod builtins;
|
||||
pub(crate) mod importer;
|
||||
pub(crate) mod lua;
|
||||
|
||||
mod error;
|
||||
|
@ -71,7 +72,7 @@ impl Lune {
|
|||
// NOTE: Some globals require the task scheduler to exist on startup
|
||||
let sched = TaskScheduler::new(lua)?.into_static();
|
||||
lua.set_app_data(sched);
|
||||
globals::create(lua, self.args.clone())?;
|
||||
importer::create(lua, self.args.clone())?;
|
||||
// Create the main thread and schedule it
|
||||
let main_chunk = lua
|
||||
.load(script_contents.as_ref())
|
||||
|
|
|
@ -52,7 +52,7 @@ pub struct TaskScheduler<'fut> {
|
|||
pub(super) tasks_current: Cell<Option<TaskReference>>,
|
||||
pub(super) tasks_queue_blocking: RefCell<VecDeque<TaskReference>>,
|
||||
pub(super) tasks_waiter_states:
|
||||
RefCell<HashMap<TaskReference, Arc<AsyncMutex<TaskWaiterState<'fut>>>>>,
|
||||
RefCell<HashMap<TaskReference, Vec<Arc<AsyncMutex<TaskWaiterState<'fut>>>>>>,
|
||||
pub(super) tasks_current_lua_error: Arc<AsyncMutex<Option<LuaError>>>,
|
||||
// Future tasks & objects for waking
|
||||
pub(super) futures: AsyncMutex<FuturesUnordered<TaskFuture<'fut>>>,
|
||||
|
@ -452,9 +452,13 @@ impl<'fut> TaskScheduler<'fut> {
|
|||
panic!("Task does not exist in scheduler")
|
||||
}
|
||||
let state = TaskWaiterState::new();
|
||||
self.tasks_waiter_states
|
||||
.borrow_mut()
|
||||
.insert(reference, Arc::clone(&state));
|
||||
{
|
||||
let mut all_states = self.tasks_waiter_states.borrow_mut();
|
||||
all_states
|
||||
.entry(reference)
|
||||
.or_insert_with(Vec::new)
|
||||
.push(Arc::clone(&state));
|
||||
}
|
||||
TaskWaiterFuture::new(&state).await
|
||||
}
|
||||
|
||||
|
@ -467,8 +471,10 @@ impl<'fut> TaskScheduler<'fut> {
|
|||
reference: TaskReference,
|
||||
result: LuaResult<LuaMultiValue<'fut>>,
|
||||
) {
|
||||
if let Some(waiter_state) = self.tasks_waiter_states.borrow_mut().remove(&reference) {
|
||||
waiter_state.try_lock().unwrap().finalize(result);
|
||||
if let Some(waiter_states) = self.tasks_waiter_states.borrow_mut().remove(&reference) {
|
||||
for waiter_state in waiter_states {
|
||||
waiter_state.try_lock().unwrap().finalize(result.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -60,15 +60,15 @@ create_tests! {
|
|||
process_env: "process/env",
|
||||
process_exit: "process/exit",
|
||||
process_spawn: "process/spawn",
|
||||
require_async: "globals/require/tests/async",
|
||||
require_async_concurrent: "globals/require/tests/async_concurrent",
|
||||
require_async_sequential: "globals/require/tests/async_sequential",
|
||||
require_builtins: "globals/require/tests/builtins",
|
||||
require_children: "globals/require/tests/children",
|
||||
require_invalid: "globals/require/tests/invalid",
|
||||
require_nested: "globals/require/tests/nested",
|
||||
require_parents: "globals/require/tests/parents",
|
||||
require_siblings: "globals/require/tests/siblings",
|
||||
require_async: "require/tests/async",
|
||||
require_async_concurrent: "require/tests/async_concurrent",
|
||||
require_async_sequential: "require/tests/async_sequential",
|
||||
require_builtins: "require/tests/builtins",
|
||||
require_children: "require/tests/children",
|
||||
require_invalid: "require/tests/invalid",
|
||||
require_nested: "require/tests/nested",
|
||||
require_parents: "require/tests/parents",
|
||||
require_siblings: "require/tests/siblings",
|
||||
// TODO: Uncomment this test, it is commented out right
|
||||
// now to let CI pass so that we can make a new release
|
||||
// global_coroutine: "globals/coroutine",
|
||||
|
|
Loading…
Reference in a new issue