diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4ccd11e..91d0352 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -12,7 +12,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added
-- Builtin modules such as `fs`, `net` and others can now be imported using `require("@lune/fs")`, `require("@lune/net")` ..
+- `require` has been reimplemented and overhauled in several ways:
+
+ - Builtin modules such as `fs`, `net` and others can now be imported using `require("@lune/fs")`, `require("@lune/net")` ...
+ This is the first step towards moving away from adding each library as a global, and allowing Lune to have more built-in libraries.
+
+ - Requiring a script is now completely asynchronous and will not block lua threads other than the caller.
+ - Requiring a script will no longer error when using async APIs in the main body of the required script.
+
+ Behavior otherwise stays the same, and requires are still relative to file unless the special `@` prefix is used.
### Removed
diff --git a/packages/lib/src/globals/process.rs b/packages/lib/src/globals/process.rs
index 8f268d3..360d696 100644
--- a/packages/lib/src/globals/process.rs
+++ b/packages/lib/src/globals/process.rs
@@ -136,7 +136,7 @@ fn process_env_iter<'lua>(
lua: &'lua Lua,
(_, _): (LuaValue<'lua>, ()),
) -> LuaResult> {
- let mut vars = env::vars_os();
+ let mut vars = env::vars_os().collect::>().into_iter();
lua.create_function_mut(move |lua, _: ()| match vars.next() {
Some((key, value)) => {
let raw_key = RawOsString::new(key);
diff --git a/packages/lib/src/globals/require.rs b/packages/lib/src/globals/require.rs
index 2040600..802d52b 100644
--- a/packages/lib/src/globals/require.rs
+++ b/packages/lib/src/globals/require.rs
@@ -1,13 +1,14 @@
use std::{
cell::RefCell,
- collections::HashMap,
+ collections::{HashMap, HashSet},
env::current_dir,
path::{self, PathBuf},
+ sync::Arc,
};
use dunce::canonicalize;
use mlua::prelude::*;
-use tokio::{fs, sync::oneshot};
+use tokio::fs;
use crate::lua::{
table::TableBuilder,
@@ -19,14 +20,17 @@ local source = info(1, "s")
if source == '[string "require"]' then
source = info(2, "s")
end
-local absolute, relative = paths(context, source, ...)
-return load(context, absolute, relative)
+load(context, source, ...)
+return yield()
"#;
#[derive(Debug, Clone, Default)]
struct RequireContext<'lua> {
- builtins: HashMap>,
- cached: RefCell>>>,
+ // NOTE: We need to use arc here so that mlua clones
+ // the reference and not the entire inner value(s)
+ builtins: Arc>>,
+ cached: Arc>>>>,
+ locks: Arc>>,
pwd: String,
}
@@ -44,48 +48,76 @@ impl<'lua> RequireContext<'lua> {
..Default::default()
}
}
+
+ pub fn is_locked(&self, absolute_path: &str) -> bool {
+ self.locks.borrow().contains(absolute_path)
+ }
+
+ pub fn set_locked(&self, absolute_path: &str) -> bool {
+ self.locks.borrow_mut().insert(absolute_path.to_string())
+ }
+
+ pub fn set_unlocked(&self, absolute_path: &str) -> bool {
+ self.locks.borrow_mut().remove(absolute_path)
+ }
+
+ pub fn try_acquire_lock_sync(&self, absolute_path: &str) -> bool {
+ if self.is_locked(absolute_path) {
+ false
+ } else {
+ self.set_locked(absolute_path);
+ true
+ }
+ }
+
+ pub fn set_cached(&self, absolute_path: String, result: &LuaResult>) {
+ self.cached
+ .borrow_mut()
+ .insert(absolute_path, result.clone());
+ }
+
+ pub fn get_paths(
+ &self,
+ require_source: String,
+ require_path: String,
+ ) -> LuaResult<(String, String)> {
+ if require_path.starts_with('@') {
+ return Ok((require_path.clone(), require_path));
+ }
+ let path_relative_to_pwd = PathBuf::from(
+ &require_source
+ .trim_start_matches("[string \"")
+ .trim_end_matches("\"]"),
+ )
+ .parent()
+ .unwrap()
+ .join(&require_path);
+ // Try to normalize and resolve relative path segments such as './' and '../'
+ let file_path = match (
+ canonicalize(path_relative_to_pwd.with_extension("luau")),
+ canonicalize(path_relative_to_pwd.with_extension("lua")),
+ ) {
+ (Ok(luau), _) => luau,
+ (_, Ok(lua)) => lua,
+ _ => {
+ return Err(LuaError::RuntimeError(format!(
+ "File does not exist at path '{require_path}'"
+ )))
+ }
+ };
+ let absolute = file_path.to_string_lossy().to_string();
+ let relative = absolute.trim_start_matches(&self.pwd).to_string();
+ Ok((absolute, relative))
+ }
}
impl<'lua> LuaUserData for RequireContext<'lua> {}
-fn paths(
- context: RequireContext,
- require_source: String,
- require_path: String,
-) -> LuaResult<(String, String)> {
- if require_path.starts_with('@') {
- return Ok((require_path.clone(), require_path));
- }
- let path_relative_to_pwd = PathBuf::from(
- &require_source
- .trim_start_matches("[string \"")
- .trim_end_matches("\"]"),
- )
- .parent()
- .unwrap()
- .join(&require_path);
- // Try to normalize and resolve relative path segments such as './' and '../'
- let file_path = match (
- canonicalize(path_relative_to_pwd.with_extension("luau")),
- canonicalize(path_relative_to_pwd.with_extension("lua")),
- ) {
- (Ok(luau), _) => luau,
- (_, Ok(lua)) => lua,
- _ => {
- return Err(LuaError::RuntimeError(format!(
- "File does not exist at path '{require_path}'"
- )))
- }
- };
- let absolute = file_path.to_string_lossy().to_string();
- let relative = absolute.trim_start_matches(&context.pwd).to_string();
- Ok((absolute, relative))
-}
-
fn load_builtin<'lua>(
_lua: &'lua Lua,
- context: RequireContext<'lua>,
+ context: &RequireContext<'lua>,
module_name: String,
+ _has_acquired_lock: bool,
) -> LuaResult> {
match context.builtins.get(&module_name) {
Some(module) => Ok(module.clone()),
@@ -98,14 +130,20 @@ fn load_builtin<'lua>(
async fn load_file<'lua>(
lua: &'lua Lua,
- context: RequireContext<'lua>,
+ context: &RequireContext<'lua>,
absolute_path: String,
relative_path: String,
+ has_acquired_lock: bool,
) -> LuaResult> {
let cached = { context.cached.borrow().get(&absolute_path).cloned() };
match cached {
Some(cached) => cached,
None => {
+ if !has_acquired_lock {
+ return Err(LuaError::RuntimeError(
+ "Failed to get require lock".to_string(),
+ ));
+ }
// Try to read the wanted file, note that we use bytes instead of reading
// to a string since lua scripts are not necessarily valid utf-8 strings
let contents = fs::read(&absolute_path).await.map_err(LuaError::external)?;
@@ -120,21 +158,15 @@ async fn load_file<'lua>(
.set_name(path_relative_no_extension)?
.into_function()?;
let loaded_thread = lua.create_thread(loaded_func)?;
- // Run the thread and provide a channel that will
- // then get its result received when it finishes
- let (tx, rx) = oneshot::channel();
- {
+ // Run the thread and wait for completion using the native task scheduler waker
+ let task_fut = {
let sched = lua.app_data_ref::<&TaskScheduler>().unwrap();
let task = sched.schedule_blocking(loaded_thread, LuaMultiValue::new())?;
- sched.set_task_result_sender(task, tx);
- }
+ sched.wait_for_task_completion(task)
+ };
// Wait for the thread to finish running, cache + return our result
- // FIXME: This waits indefinitely for nested requires for some reason
- let rets = rx.await.expect("Sender was dropped during require");
- context
- .cached
- .borrow_mut()
- .insert(absolute_path, rets.clone());
+ let rets = task_fut.await;
+ context.set_cached(absolute_path, &rets);
rets
}
}
@@ -145,35 +177,66 @@ async fn load<'lua>(
context: RequireContext<'lua>,
absolute_path: String,
relative_path: String,
+ has_acquired_lock: bool,
) -> LuaResult> {
- if absolute_path == relative_path && absolute_path.starts_with('@') {
+ let result = if absolute_path == relative_path && absolute_path.starts_with('@') {
if let Some(module_name) = absolute_path.strip_prefix("@lune/") {
- load_builtin(lua, context, module_name.to_string())
+ load_builtin(lua, &context, module_name.to_string(), has_acquired_lock)
} else {
+ // FUTURE: '@' can be used a special prefix for users to set their own
+ // paths relative to a project file, similar to typescript paths config
+ // https://www.typescriptlang.org/tsconfig#paths
Err(LuaError::RuntimeError(
"Require paths prefixed by '@' are not yet supported".to_string(),
))
}
} else {
- load_file(lua, context, absolute_path, relative_path).await
+ load_file(
+ lua,
+ &context,
+ absolute_path.to_string(),
+ relative_path,
+ has_acquired_lock,
+ )
+ .await
+ };
+ if has_acquired_lock {
+ context.set_unlocked(&absolute_path);
}
+ result
}
pub fn create(lua: &'static Lua) -> LuaResult {
let require_context = RequireContext::new();
- let require_print: LuaFunction = lua.named_registry_value("print")?;
+ let require_yield: LuaFunction = lua.named_registry_value("co.yield")?;
let require_info: LuaFunction = lua.named_registry_value("dbg.info")?;
+ let require_print: LuaFunction = lua.named_registry_value("print")?;
let require_env = TableBuilder::new(lua)?
.with_value("context", require_context)?
- .with_value("print", require_print)?
+ .with_value("yield", require_yield)?
.with_value("info", require_info)?
- .with_function("paths", |_, (context, require_source, require_path)| {
- paths(context, require_source, require_path)
- })?
- .with_async_function("load", |lua, (context, require_source, require_path)| {
- load(lua, context, require_source, require_path)
- })?
+ .with_value("print", require_print)?
+ .with_function(
+ "load",
+ |lua, (context, require_source, require_path): (RequireContext, String, String)| {
+ let (absolute_path, relative_path) =
+ context.get_paths(require_source, require_path)?;
+ // NOTE: We can not acquire the lock in the async part of the require
+ // load process since several requires may have happened for the
+ // same path before the async load task even gets a chance to run
+ let has_lock = context.try_acquire_lock_sync(&absolute_path);
+ let fut = load(lua, context, absolute_path, relative_path, has_lock);
+ let sched = lua
+ .app_data_ref::<&TaskScheduler>()
+ .expect("Missing task scheduler as a lua app data");
+ sched.queue_async_task_inherited(lua.current_thread(), None, async {
+ let rets = fut.await?;
+ let mult = rets.to_lua_multi(lua)?;
+ Ok(Some(mult))
+ })
+ },
+ )?
.build_readonly()?;
let require_fn_lua = lua
diff --git a/packages/lib/src/globals/task.rs b/packages/lib/src/globals/task.rs
index a86fe0e..fe5c9ee 100644
--- a/packages/lib/src/globals/task.rs
+++ b/packages/lib/src/globals/task.rs
@@ -158,13 +158,13 @@ fn coroutine_resume<'lua>(
let result = match value {
LuaThreadOrTaskReference::Thread(t) => {
let task = sched.create_task(TaskKind::Instant, t, None, true)?;
- sched.resume_task(task)
+ sched.resume_task(task, None)
}
- LuaThreadOrTaskReference::TaskReference(t) => sched.resume_task(t),
+ LuaThreadOrTaskReference::TaskReference(t) => sched.resume_task(t, None),
};
sched.force_set_current_task(Some(current));
match result {
- Ok(rets) => Ok((true, rets)),
+ Ok(rets) => Ok((true, rets.1)),
Err(e) => Ok((false, e.to_lua_multi(lua)?)),
}
}
@@ -187,8 +187,11 @@ fn coroutine_wrap<'lua>(lua: &'lua Lua, func: LuaFunction) -> LuaResult()
.unwrap()
- .resume_task_override(task, Ok(args));
+ .resume_task(task, Some(Ok(args)));
sched.force_set_current_task(Some(current));
- result
+ match result {
+ Ok(rets) => Ok(rets.1),
+ Err(e) => Err(e),
+ }
})
}
diff --git a/packages/lib/src/lua/task/ext/resume_ext.rs b/packages/lib/src/lua/task/ext/resume_ext.rs
index 74c550f..9fecaae 100644
--- a/packages/lib/src/lua/task/ext/resume_ext.rs
+++ b/packages/lib/src/lua/task/ext/resume_ext.rs
@@ -76,10 +76,13 @@ impl TaskSchedulerResumeExt for TaskScheduler<'_> {
Resumes the next queued Lua task, if one exists, blocking
the current thread until it either yields or finishes.
*/
-fn resume_next_blocking_task(
- scheduler: &TaskScheduler<'_>,
- override_args: Option>,
-) -> TaskSchedulerState {
+fn resume_next_blocking_task<'sched, 'args>(
+ scheduler: &TaskScheduler<'sched>,
+ override_args: Option>>,
+) -> TaskSchedulerState
+where
+ 'args: 'sched,
+{
match {
let mut queue_guard = scheduler.tasks_queue_blocking.borrow_mut();
let task = queue_guard.pop_front();
@@ -87,15 +90,16 @@ fn resume_next_blocking_task(
task
} {
None => TaskSchedulerState::new(scheduler),
- Some(task) => match override_args {
- Some(args) => match scheduler.resume_task_override(task, args) {
- Ok(_) => TaskSchedulerState::new(scheduler),
- Err(task_err) => TaskSchedulerState::err(scheduler, task_err),
- },
- None => match scheduler.resume_task(task) {
- Ok(_) => TaskSchedulerState::new(scheduler),
- Err(task_err) => TaskSchedulerState::err(scheduler, task_err),
- },
+ Some(task) => match scheduler.resume_task(task, override_args) {
+ Err(task_err) => {
+ scheduler.wake_completed_task(task, Err(task_err.clone()));
+ TaskSchedulerState::err(scheduler, task_err)
+ }
+ Ok(rets) if rets.0 == LuaThreadStatus::Unresumable => {
+ scheduler.wake_completed_task(task, Ok(rets.1));
+ TaskSchedulerState::new(scheduler)
+ }
+ Ok(_) => TaskSchedulerState::new(scheduler),
},
}
}
@@ -158,9 +162,9 @@ async fn receive_next_message(scheduler: &TaskScheduler<'_>) -> TaskSchedulerSta
if prev == 0 {
panic!(
r#"
- Terminated a background task without it running - this is an internal error!
- Please report it at {}
- "#,
+ Terminated a background task without it running - this is an internal error!
+ Please report it at {}
+ "#,
env!("CARGO_PKG_REPOSITORY")
)
}
diff --git a/packages/lib/src/lua/task/mod.rs b/packages/lib/src/lua/task/mod.rs
index f9d1813..6f8df44 100644
--- a/packages/lib/src/lua/task/mod.rs
+++ b/packages/lib/src/lua/task/mod.rs
@@ -6,6 +6,7 @@ mod scheduler_message;
mod scheduler_state;
mod task_kind;
mod task_reference;
+mod task_waiter;
pub use ext::*;
pub use proxy::*;
diff --git a/packages/lib/src/lua/task/scheduler.rs b/packages/lib/src/lua/task/scheduler.rs
index 460f62e..834a379 100644
--- a/packages/lib/src/lua/task/scheduler.rs
+++ b/packages/lib/src/lua/task/scheduler.rs
@@ -9,12 +9,14 @@ use std::{
use futures_util::{future::LocalBoxFuture, stream::FuturesUnordered, Future};
use mlua::prelude::*;
-use tokio::sync::{mpsc, oneshot, Mutex as AsyncMutex};
+use tokio::sync::{mpsc, Mutex as AsyncMutex};
-use super::scheduler_message::TaskSchedulerMessage;
+use super::{
+ scheduler_message::TaskSchedulerMessage,
+ task_waiter::{TaskWaiterFuture, TaskWaiterState},
+};
pub use super::{task_kind::TaskKind, task_reference::TaskReference};
-type TaskResultSender = oneshot::Sender>>;
type TaskFutureRets<'fut> = LuaResult