Implement new builtins scope for require

This commit is contained in:
Filip Tibell 2023-03-21 15:06:28 +01:00
parent 172ab16823
commit 29a3b41e15
No known key found for this signature in database
7 changed files with 154 additions and 219 deletions

View file

@ -177,7 +177,7 @@ impl Cli {
// Create a new lune object with all globals & run the script // Create a new lune object with all globals & run the script
let result = Lune::new() let result = Lune::new()
.with_args(self.script_args) .with_args(self.script_args)
.try_run(&script_display_name, strip_shebang(script_contents)) .run(&script_display_name, strip_shebang(script_contents))
.await; .await;
Ok(match result { Ok(match result {
Err(err) => { Err(err) => {

View file

@ -1,5 +1,3 @@
use std::fmt::{Display, Formatter, Result as FmtResult};
use mlua::prelude::*; use mlua::prelude::*;
mod fs; mod fs;
@ -10,119 +8,42 @@ mod stdio;
mod task; mod task;
mod top_level; mod top_level;
#[derive(Clone, Debug, PartialEq, Eq, Hash)] pub fn create(lua: &'static Lua, args: Vec<String>) -> LuaResult<()> {
pub enum LuneGlobal { // Create all builtins
Fs, let builtins = vec![
Net, ("fs", fs::create(lua)?),
Process { args: Vec<String> }, ("net", net::create(lua)?),
Require, ("process", process::create(lua, args)?),
Stdio, ("stdio", stdio::create(lua)?),
Task, ("task", task::create(lua)?),
TopLevel, ];
}
impl Display for LuneGlobal { // TODO: Remove this when we have proper LSP support for custom require types
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { let lua_globals = lua.globals();
write!( for (name, builtin) in &builtins {
f, lua_globals.set(*name, builtin.clone())?;
"{}",
match self {
Self::Fs => "fs",
Self::Net => "net",
Self::Process { .. } => "process",
Self::Require => "require",
Self::Stdio => "stdio",
Self::Task => "task",
Self::TopLevel => "toplevel",
}
)
}
}
impl LuneGlobal {
/**
Create a vector that contains all available Lune globals, with
the [`LuneGlobal::Process`] global containing the given args.
*/
pub fn all<S: AsRef<str>>(args: &[S]) -> Vec<Self> {
vec![
Self::Fs,
Self::Net,
Self::Process {
args: args.iter().map(|s| s.as_ref().to_string()).collect(),
},
Self::Require,
Self::Stdio,
Self::Task,
Self::TopLevel,
]
} }
/** // Create our importer (require) with builtins
Checks if this Lune global is a proxy global. let require_fn = require::create(lua, builtins)?;
A proxy global is a global that re-implements or proxies functionality of one or // Create all top-level globals
more existing lua globals, and may store internal references to the original global(s). let globals = vec![
("require", require_fn),
("print", lua.create_function(top_level::top_level_print)?),
("warn", lua.create_function(top_level::top_level_warn)?),
("error", lua.create_function(top_level::top_level_error)?),
(
"printinfo",
lua.create_function(top_level::top_level_printinfo)?,
),
];
This means that proxy globals should only be injected into a lua global // Set top-level globals and seal them
environment once, since injecting twice or more will potentially break the for (name, global) in globals {
functionality of the proxy global and / or cause undefined behavior. lua_globals.set(name, global)?;
*/
pub fn is_proxy(&self) -> bool {
matches!(self, Self::Require | Self::TopLevel)
} }
lua_globals.set_readonly(true);
/**
Checks if this Lune global is an injector.
An injector is similar to a proxy global but will inject
value(s) into the global lua environment during creation,
to ensure correct usage and compatibility with base Luau.
*/
pub fn is_injector(&self) -> bool {
matches!(self, Self::Task)
}
/**
Creates the [`mlua::Table`] value for this Lune global.
Note that proxy globals should be handled with special care and that [`LuneGlobal::inject()`]
should be preferred over manually creating and manipulating the value(s) of any Lune global.
*/
pub fn value(&self, lua: &'static Lua) -> LuaResult<LuaTable> {
match self {
LuneGlobal::Fs => fs::create(lua),
LuneGlobal::Net => net::create(lua),
LuneGlobal::Process { args } => process::create(lua, args.clone()),
LuneGlobal::Require => require::create(lua),
LuneGlobal::Stdio => stdio::create(lua),
LuneGlobal::Task => task::create(lua),
LuneGlobal::TopLevel => top_level::create(lua),
}
}
/**
Injects the Lune global into a lua global environment.
This takes ownership since proxy Lune globals should
only ever be injected into a lua global environment once.
Refer to [`LuneGlobal::is_top_level()`] for more info on proxy globals.
*/
pub fn inject(self, lua: &'static Lua) -> LuaResult<()> {
let globals = lua.globals();
let table = self.value(lua)?;
// NOTE: Top level globals are special, the values
// *in* the table they return should be set directly,
// instead of setting the table itself as the global
if self.is_proxy() {
for pair in table.pairs::<LuaValue, LuaValue>() {
let (key, value) = pair?;
globals.raw_set(key, value)?;
}
Ok(()) Ok(())
} else {
globals.raw_set(self.to_string(), table)
}
}
} }

View file

@ -35,7 +35,11 @@ struct RequireContext<'lua> {
} }
impl<'lua> RequireContext<'lua> { impl<'lua> RequireContext<'lua> {
pub fn new() -> Self { pub fn new<K, V>(lua: &'static Lua, builtins_vec: Vec<(K, V)>) -> LuaResult<Self>
where
K: Into<String>,
V: ToLua<'static>,
{
let mut pwd = current_dir() let mut pwd = current_dir()
.expect("Failed to access current working directory") .expect("Failed to access current working directory")
.to_string_lossy() .to_string_lossy()
@ -43,10 +47,15 @@ impl<'lua> RequireContext<'lua> {
if !pwd.ends_with(path::MAIN_SEPARATOR) { if !pwd.ends_with(path::MAIN_SEPARATOR) {
pwd = format!("{pwd}{}", path::MAIN_SEPARATOR) pwd = format!("{pwd}{}", path::MAIN_SEPARATOR)
} }
Self { let mut builtins = HashMap::new();
pwd, for (key, value) in builtins_vec {
..Default::default() builtins.insert(key.into(), value.to_lua_multi(lua)?);
} }
Ok(Self {
pwd,
builtins: Arc::new(builtins),
..Default::default()
})
} }
pub fn is_locked(&self, absolute_path: &str) -> bool { pub fn is_locked(&self, absolute_path: &str) -> bool {
@ -206,8 +215,12 @@ async fn load<'lua>(
result result
} }
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> { pub fn create<K, V>(lua: &'static Lua, builtins: Vec<(K, V)>) -> LuaResult<LuaFunction>
let require_context = RequireContext::new(); where
K: Clone + Into<String>,
V: Clone + ToLua<'static>,
{
let require_context = RequireContext::new(lua, builtins)?;
let require_yield: LuaFunction = lua.named_registry_value("co.yield")?; let require_yield: LuaFunction = lua.named_registry_value("co.yield")?;
let require_info: LuaFunction = lua.named_registry_value("dbg.info")?; let require_info: LuaFunction = lua.named_registry_value("dbg.info")?;
let require_print: LuaFunction = lua.named_registry_value("print")?; let require_print: LuaFunction = lua.named_registry_value("print")?;
@ -244,8 +257,5 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
.set_name("require")? .set_name("require")?
.set_environment(require_env)? .set_environment(require_env)?
.into_function()?; .into_function()?;
Ok(require_fn_lua)
TableBuilder::new(lua)?
.with_value("require", require_fn_lua)?
.build_readonly()
} }

View file

@ -1,23 +1,20 @@
use mlua::prelude::*; use mlua::prelude::*;
use crate::{ use crate::lua::stdio::formatting::{format_label, pretty_format_multi_value};
lua::stdio::formatting::{format_label, pretty_format_multi_value},
lua::table::TableBuilder,
};
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> { // HACK: We need to preserve the default behavior of the
// HACK: We need to preserve the default behavior of the // print and error functions, for pcall and such, which
// print and error functions, for pcall and such, which // is really tricky to do from scratch so we will just
// is really tricky to do from scratch so we will just // proxy the default print and error functions here
// proxy the default print and error functions here
TableBuilder::new(lua)? pub fn top_level_print(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
.with_function("print", |lua, args: LuaMultiValue| {
let formatted = pretty_format_multi_value(&args)?; let formatted = pretty_format_multi_value(&args)?;
let print: LuaFunction = lua.named_registry_value("print")?; let print: LuaFunction = lua.named_registry_value("print")?;
print.call(formatted)?; print.call(formatted)?;
Ok(()) Ok(())
})? }
.with_function("info", |lua, args: LuaMultiValue| {
pub fn top_level_printinfo(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
let print: LuaFunction = lua.named_registry_value("print")?; let print: LuaFunction = lua.named_registry_value("print")?;
print.call(format!( print.call(format!(
"{}\n{}", "{}\n{}",
@ -25,8 +22,9 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
pretty_format_multi_value(&args)? pretty_format_multi_value(&args)?
))?; ))?;
Ok(()) Ok(())
})? }
.with_function("warn", |lua, args: LuaMultiValue| {
pub fn top_level_warn(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
let print: LuaFunction = lua.named_registry_value("print")?; let print: LuaFunction = lua.named_registry_value("print")?;
print.call(format!( print.call(format!(
"{}\n{}", "{}\n{}",
@ -34,8 +32,9 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
pretty_format_multi_value(&args)? pretty_format_multi_value(&args)?
))?; ))?;
Ok(()) Ok(())
})? }
.with_function("error", |lua, (arg, level): (LuaValue, Option<u32>)| {
pub fn top_level_error(lua: &Lua, (arg, level): (LuaValue, Option<u32>)) -> LuaResult<()> {
let error: LuaFunction = lua.named_registry_value("error")?; let error: LuaFunction = lua.named_registry_value("error")?;
let trace: LuaFunction = lua.named_registry_value("dbg.trace")?; let trace: LuaFunction = lua.named_registry_value("dbg.trace")?;
error.call(( error.call((
@ -51,7 +50,6 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
level, level,
))?; ))?;
Ok(()) Ok(())
})?
// TODO: Add an override for tostring that formats errors in a nicer way
.build_readonly()
} }
// TODO: Add an override for tostring that formats errors in a nicer way

View file

@ -12,8 +12,6 @@ mod error;
mod tests; mod tests;
pub use error::LuneError; pub use error::LuneError;
pub use globals::LuneGlobal;
pub use lua::create_lune_lua;
#[derive(Clone, Debug, Default)] #[derive(Clone, Debug, Default)]
pub struct Lune { pub struct Lune {
@ -52,39 +50,28 @@ impl Lune {
both live for the remainer of the program, and that this leaks memory using both live for the remainer of the program, and that this leaks memory using
[`Box::leak`] that will then get deallocated when the program exits. [`Box::leak`] that will then get deallocated when the program exits.
*/ */
pub async fn try_run(
&self,
script_name: impl AsRef<str>,
script_contents: impl AsRef<[u8]>,
) -> Result<ExitCode, LuneError> {
self.run(script_name, script_contents, true)
.await
.map_err(LuneError::from_lua_error)
}
/**
Tries to run a Lune script directly, returning the underlying
result type if loading the script raises a lua error.
Passing `false` as the third argument `emit_prettified_errors` will
bypass any additional error formatting automatically done by Lune
for errors that are emitted while the script is running.
Behavior is otherwise exactly the same as `try_run`
and `try_run` should be preferred in all other cases.
*/
#[doc(hidden)]
pub async fn run( pub async fn run(
&self, &self,
script_name: impl AsRef<str>, script_name: impl AsRef<str>,
script_contents: impl AsRef<[u8]>, script_contents: impl AsRef<[u8]>,
emit_prettified_errors: bool, ) -> Result<ExitCode, LuneError> {
self.run_inner(script_name, script_contents)
.await
.map_err(LuneError::from_lua_error)
}
async fn run_inner(
&self,
script_name: impl AsRef<str>,
script_contents: impl AsRef<[u8]>,
) -> Result<ExitCode, LuaError> { ) -> Result<ExitCode, LuaError> {
// Create our special lune-flavored Lua object with extra registry values // Create our special lune-flavored Lua object with extra registry values
let lua = create_lune_lua()?; let lua = lua::create_lune_lua()?;
// Create our task scheduler // Create our task scheduler and all globals
// NOTE: Some globals require the task scheduler to exist on startup
let sched = TaskScheduler::new(lua)?.into_static(); let sched = TaskScheduler::new(lua)?.into_static();
lua.set_app_data(sched); lua.set_app_data(sched);
globals::create(lua, self.args.clone())?;
// Create the main thread and schedule it // Create the main thread and schedule it
let main_chunk = lua let main_chunk = lua
.load(script_contents.as_ref()) .load(script_contents.as_ref())
@ -93,11 +80,6 @@ impl Lune {
let main_thread = lua.create_thread(main_chunk)?; let main_thread = lua.create_thread(main_chunk)?;
let main_thread_args = LuaValue::Nil.to_lua_multi(lua)?; let main_thread_args = LuaValue::Nil.to_lua_multi(lua)?;
sched.schedule_blocking(main_thread, main_thread_args)?; sched.schedule_blocking(main_thread, main_thread_args)?;
// Create our wanted lune globals, some of these need
// the task scheduler be available during construction
for global in LuneGlobal::all(&self.args) {
global.inject(lua)?;
}
// Keep running the scheduler until there are either no tasks // Keep running the scheduler until there are either no tasks
// left to run, or until a task requests to exit the process // left to run, or until a task requests to exit the process
let exit_code = LocalSet::new() let exit_code = LocalSet::new()
@ -106,11 +88,7 @@ impl Lune {
loop { loop {
let result = sched.resume_queue().await; let result = sched.resume_queue().await;
if let Some(err) = result.get_lua_error() { if let Some(err) = result.get_lua_error() {
if emit_prettified_errors {
eprintln!("{}", LuneError::from_lua_error(err)); eprintln!("{}", LuneError::from_lua_error(err));
} else {
eprintln!("{err}");
}
got_error = true; got_error = true;
} }
if result.is_done() { if result.is_done() {

View file

@ -37,7 +37,7 @@ macro_rules! create_tests {
.trim_end_matches(".luau") .trim_end_matches(".luau")
.trim_end_matches(".lua") .trim_end_matches(".lua")
.to_string(); .to_string();
let exit_code = lune.try_run(&script_name, &script).await?; let exit_code = lune.run(&script_name, &script).await?;
Ok(exit_code) Ok(exit_code)
} }
)* } )* }
@ -63,6 +63,7 @@ create_tests! {
require_async: "globals/require/tests/async", require_async: "globals/require/tests/async",
require_async_concurrent: "globals/require/tests/async_concurrent", require_async_concurrent: "globals/require/tests/async_concurrent",
require_async_sequential: "globals/require/tests/async_sequential", require_async_sequential: "globals/require/tests/async_sequential",
require_builtins: "globals/require/tests/builtins",
require_children: "globals/require/tests/children", require_children: "globals/require/tests/children",
require_invalid: "globals/require/tests/invalid", require_invalid: "globals/require/tests/invalid",
require_nested: "globals/require/tests/nested", require_nested: "globals/require/tests/nested",

View file

@ -0,0 +1,27 @@
local fs = require("@lune/fs") :: typeof(fs)
local net = require("@lune/net") :: typeof(net)
local process = require("@lune/process") :: typeof(process)
local stdio = require("@lune/stdio") :: typeof(stdio)
local task = require("@lune/task") :: typeof(task)
assert(type(fs.move) == "function")
assert(type(net.request) == "function")
assert(type(process.cwd) == "string")
assert(type(stdio.format("")) == "string")
assert(type(task.spawn(function() end)) == "thread")
assert(not pcall(function()
return require("@") :: any
end))
assert(not pcall(function()
return require("@lune") :: any
end))
assert(not pcall(function()
return require("@lune/") :: any
end))
assert(not pcall(function()
return require("@src") :: any
end))