Implement new builtins scope for require

This commit is contained in:
Filip Tibell 2023-03-21 15:06:28 +01:00
parent 172ab16823
commit 29a3b41e15
No known key found for this signature in database
7 changed files with 154 additions and 219 deletions

View file

@ -177,7 +177,7 @@ impl Cli {
// Create a new lune object with all globals & run the script
let result = Lune::new()
.with_args(self.script_args)
.try_run(&script_display_name, strip_shebang(script_contents))
.run(&script_display_name, strip_shebang(script_contents))
.await;
Ok(match result {
Err(err) => {

View file

@ -1,5 +1,3 @@
use std::fmt::{Display, Formatter, Result as FmtResult};
use mlua::prelude::*;
mod fs;
@ -10,119 +8,42 @@ mod stdio;
mod task;
mod top_level;
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum LuneGlobal {
Fs,
Net,
Process { args: Vec<String> },
Require,
Stdio,
Task,
TopLevel,
}
impl Display for LuneGlobal {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
write!(
f,
"{}",
match self {
Self::Fs => "fs",
Self::Net => "net",
Self::Process { .. } => "process",
Self::Require => "require",
Self::Stdio => "stdio",
Self::Task => "task",
Self::TopLevel => "toplevel",
}
)
}
}
impl LuneGlobal {
/**
Create a vector that contains all available Lune globals, with
the [`LuneGlobal::Process`] global containing the given args.
*/
pub fn all<S: AsRef<str>>(args: &[S]) -> Vec<Self> {
vec![
Self::Fs,
Self::Net,
Self::Process {
args: args.iter().map(|s| s.as_ref().to_string()).collect(),
},
Self::Require,
Self::Stdio,
Self::Task,
Self::TopLevel,
]
}
/**
Checks if this Lune global is a proxy global.
A proxy global is a global that re-implements or proxies functionality of one or
more existing lua globals, and may store internal references to the original global(s).
This means that proxy globals should only be injected into a lua global
environment once, since injecting twice or more will potentially break the
functionality of the proxy global and / or cause undefined behavior.
*/
pub fn is_proxy(&self) -> bool {
matches!(self, Self::Require | Self::TopLevel)
}
/**
Checks if this Lune global is an injector.
An injector is similar to a proxy global but will inject
value(s) into the global lua environment during creation,
to ensure correct usage and compatibility with base Luau.
*/
pub fn is_injector(&self) -> bool {
matches!(self, Self::Task)
}
/**
Creates the [`mlua::Table`] value for this Lune global.
Note that proxy globals should be handled with special care and that [`LuneGlobal::inject()`]
should be preferred over manually creating and manipulating the value(s) of any Lune global.
*/
pub fn value(&self, lua: &'static Lua) -> LuaResult<LuaTable> {
match self {
LuneGlobal::Fs => fs::create(lua),
LuneGlobal::Net => net::create(lua),
LuneGlobal::Process { args } => process::create(lua, args.clone()),
LuneGlobal::Require => require::create(lua),
LuneGlobal::Stdio => stdio::create(lua),
LuneGlobal::Task => task::create(lua),
LuneGlobal::TopLevel => top_level::create(lua),
}
}
/**
Injects the Lune global into a lua global environment.
This takes ownership since proxy Lune globals should
only ever be injected into a lua global environment once.
Refer to [`LuneGlobal::is_top_level()`] for more info on proxy globals.
*/
pub fn inject(self, lua: &'static Lua) -> LuaResult<()> {
let globals = lua.globals();
let table = self.value(lua)?;
// NOTE: Top level globals are special, the values
// *in* the table they return should be set directly,
// instead of setting the table itself as the global
if self.is_proxy() {
for pair in table.pairs::<LuaValue, LuaValue>() {
let (key, value) = pair?;
globals.raw_set(key, value)?;
}
Ok(())
} else {
globals.raw_set(self.to_string(), table)
}
}
pub fn create(lua: &'static Lua, args: Vec<String>) -> LuaResult<()> {
// Create all builtins
let builtins = vec![
("fs", fs::create(lua)?),
("net", net::create(lua)?),
("process", process::create(lua, args)?),
("stdio", stdio::create(lua)?),
("task", task::create(lua)?),
];
// TODO: Remove this when we have proper LSP support for custom require types
let lua_globals = lua.globals();
for (name, builtin) in &builtins {
lua_globals.set(*name, builtin.clone())?;
}
// Create our importer (require) with builtins
let require_fn = require::create(lua, builtins)?;
// Create all top-level globals
let globals = vec![
("require", require_fn),
("print", lua.create_function(top_level::top_level_print)?),
("warn", lua.create_function(top_level::top_level_warn)?),
("error", lua.create_function(top_level::top_level_error)?),
(
"printinfo",
lua.create_function(top_level::top_level_printinfo)?,
),
];
// Set top-level globals and seal them
for (name, global) in globals {
lua_globals.set(name, global)?;
}
lua_globals.set_readonly(true);
Ok(())
}

View file

@ -35,7 +35,11 @@ struct RequireContext<'lua> {
}
impl<'lua> RequireContext<'lua> {
pub fn new() -> Self {
pub fn new<K, V>(lua: &'static Lua, builtins_vec: Vec<(K, V)>) -> LuaResult<Self>
where
K: Into<String>,
V: ToLua<'static>,
{
let mut pwd = current_dir()
.expect("Failed to access current working directory")
.to_string_lossy()
@ -43,10 +47,15 @@ impl<'lua> RequireContext<'lua> {
if !pwd.ends_with(path::MAIN_SEPARATOR) {
pwd = format!("{pwd}{}", path::MAIN_SEPARATOR)
}
Self {
pwd,
..Default::default()
let mut builtins = HashMap::new();
for (key, value) in builtins_vec {
builtins.insert(key.into(), value.to_lua_multi(lua)?);
}
Ok(Self {
pwd,
builtins: Arc::new(builtins),
..Default::default()
})
}
pub fn is_locked(&self, absolute_path: &str) -> bool {
@ -206,8 +215,12 @@ async fn load<'lua>(
result
}
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
let require_context = RequireContext::new();
pub fn create<K, V>(lua: &'static Lua, builtins: Vec<(K, V)>) -> LuaResult<LuaFunction>
where
K: Clone + Into<String>,
V: Clone + ToLua<'static>,
{
let require_context = RequireContext::new(lua, builtins)?;
let require_yield: LuaFunction = lua.named_registry_value("co.yield")?;
let require_info: LuaFunction = lua.named_registry_value("dbg.info")?;
let require_print: LuaFunction = lua.named_registry_value("print")?;
@ -244,8 +257,5 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
.set_name("require")?
.set_environment(require_env)?
.into_function()?;
TableBuilder::new(lua)?
.with_value("require", require_fn_lua)?
.build_readonly()
Ok(require_fn_lua)
}

View file

@ -1,57 +1,55 @@
use mlua::prelude::*;
use crate::{
lua::stdio::formatting::{format_label, pretty_format_multi_value},
lua::table::TableBuilder,
};
use crate::lua::stdio::formatting::{format_label, pretty_format_multi_value};
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
// HACK: We need to preserve the default behavior of the
// print and error functions, for pcall and such, which
// is really tricky to do from scratch so we will just
// proxy the default print and error functions here
TableBuilder::new(lua)?
.with_function("print", |lua, args: LuaMultiValue| {
let formatted = pretty_format_multi_value(&args)?;
let print: LuaFunction = lua.named_registry_value("print")?;
print.call(formatted)?;
Ok(())
})?
.with_function("info", |lua, args: LuaMultiValue| {
let print: LuaFunction = lua.named_registry_value("print")?;
print.call(format!(
"{}\n{}",
format_label("info"),
pretty_format_multi_value(&args)?
))?;
Ok(())
})?
.with_function("warn", |lua, args: LuaMultiValue| {
let print: LuaFunction = lua.named_registry_value("print")?;
print.call(format!(
"{}\n{}",
format_label("warn"),
pretty_format_multi_value(&args)?
))?;
Ok(())
})?
.with_function("error", |lua, (arg, level): (LuaValue, Option<u32>)| {
let error: LuaFunction = lua.named_registry_value("error")?;
let trace: LuaFunction = lua.named_registry_value("dbg.trace")?;
error.call((
LuaError::CallbackError {
traceback: format!("override traceback:{}", trace.call::<_, String>(())?),
cause: LuaError::external(format!(
"{}\n{}",
format_label("error"),
pretty_format_multi_value(&arg.to_lua_multi(lua)?)?
))
.into(),
},
level,
))?;
Ok(())
})?
// TODO: Add an override for tostring that formats errors in a nicer way
.build_readonly()
// HACK: We need to preserve the default behavior of the
// print and error functions, for pcall and such, which
// is really tricky to do from scratch so we will just
// proxy the default print and error functions here
pub fn top_level_print(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
let formatted = pretty_format_multi_value(&args)?;
let print: LuaFunction = lua.named_registry_value("print")?;
print.call(formatted)?;
Ok(())
}
pub fn top_level_printinfo(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
let print: LuaFunction = lua.named_registry_value("print")?;
print.call(format!(
"{}\n{}",
format_label("info"),
pretty_format_multi_value(&args)?
))?;
Ok(())
}
pub fn top_level_warn(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> {
let print: LuaFunction = lua.named_registry_value("print")?;
print.call(format!(
"{}\n{}",
format_label("warn"),
pretty_format_multi_value(&args)?
))?;
Ok(())
}
pub fn top_level_error(lua: &Lua, (arg, level): (LuaValue, Option<u32>)) -> LuaResult<()> {
let error: LuaFunction = lua.named_registry_value("error")?;
let trace: LuaFunction = lua.named_registry_value("dbg.trace")?;
error.call((
LuaError::CallbackError {
traceback: format!("override traceback:{}", trace.call::<_, String>(())?),
cause: LuaError::external(format!(
"{}\n{}",
format_label("error"),
pretty_format_multi_value(&arg.to_lua_multi(lua)?)?
))
.into(),
},
level,
))?;
Ok(())
}
// TODO: Add an override for tostring that formats errors in a nicer way

View file

@ -12,8 +12,6 @@ mod error;
mod tests;
pub use error::LuneError;
pub use globals::LuneGlobal;
pub use lua::create_lune_lua;
#[derive(Clone, Debug, Default)]
pub struct Lune {
@ -52,39 +50,28 @@ impl Lune {
both live for the remainer of the program, and that this leaks memory using
[`Box::leak`] that will then get deallocated when the program exits.
*/
pub async fn try_run(
&self,
script_name: impl AsRef<str>,
script_contents: impl AsRef<[u8]>,
) -> Result<ExitCode, LuneError> {
self.run(script_name, script_contents, true)
.await
.map_err(LuneError::from_lua_error)
}
/**
Tries to run a Lune script directly, returning the underlying
result type if loading the script raises a lua error.
Passing `false` as the third argument `emit_prettified_errors` will
bypass any additional error formatting automatically done by Lune
for errors that are emitted while the script is running.
Behavior is otherwise exactly the same as `try_run`
and `try_run` should be preferred in all other cases.
*/
#[doc(hidden)]
pub async fn run(
&self,
script_name: impl AsRef<str>,
script_contents: impl AsRef<[u8]>,
emit_prettified_errors: bool,
) -> Result<ExitCode, LuneError> {
self.run_inner(script_name, script_contents)
.await
.map_err(LuneError::from_lua_error)
}
async fn run_inner(
&self,
script_name: impl AsRef<str>,
script_contents: impl AsRef<[u8]>,
) -> Result<ExitCode, LuaError> {
// Create our special lune-flavored Lua object with extra registry values
let lua = create_lune_lua()?;
// Create our task scheduler
let lua = lua::create_lune_lua()?;
// Create our task scheduler and all globals
// NOTE: Some globals require the task scheduler to exist on startup
let sched = TaskScheduler::new(lua)?.into_static();
lua.set_app_data(sched);
globals::create(lua, self.args.clone())?;
// Create the main thread and schedule it
let main_chunk = lua
.load(script_contents.as_ref())
@ -93,11 +80,6 @@ impl Lune {
let main_thread = lua.create_thread(main_chunk)?;
let main_thread_args = LuaValue::Nil.to_lua_multi(lua)?;
sched.schedule_blocking(main_thread, main_thread_args)?;
// Create our wanted lune globals, some of these need
// the task scheduler be available during construction
for global in LuneGlobal::all(&self.args) {
global.inject(lua)?;
}
// Keep running the scheduler until there are either no tasks
// left to run, or until a task requests to exit the process
let exit_code = LocalSet::new()
@ -106,11 +88,7 @@ impl Lune {
loop {
let result = sched.resume_queue().await;
if let Some(err) = result.get_lua_error() {
if emit_prettified_errors {
eprintln!("{}", LuneError::from_lua_error(err));
} else {
eprintln!("{err}");
}
eprintln!("{}", LuneError::from_lua_error(err));
got_error = true;
}
if result.is_done() {

View file

@ -37,7 +37,7 @@ macro_rules! create_tests {
.trim_end_matches(".luau")
.trim_end_matches(".lua")
.to_string();
let exit_code = lune.try_run(&script_name, &script).await?;
let exit_code = lune.run(&script_name, &script).await?;
Ok(exit_code)
}
)* }
@ -63,6 +63,7 @@ create_tests! {
require_async: "globals/require/tests/async",
require_async_concurrent: "globals/require/tests/async_concurrent",
require_async_sequential: "globals/require/tests/async_sequential",
require_builtins: "globals/require/tests/builtins",
require_children: "globals/require/tests/children",
require_invalid: "globals/require/tests/invalid",
require_nested: "globals/require/tests/nested",

View file

@ -0,0 +1,27 @@
local fs = require("@lune/fs") :: typeof(fs)
local net = require("@lune/net") :: typeof(net)
local process = require("@lune/process") :: typeof(process)
local stdio = require("@lune/stdio") :: typeof(stdio)
local task = require("@lune/task") :: typeof(task)
assert(type(fs.move) == "function")
assert(type(net.request) == "function")
assert(type(process.cwd) == "string")
assert(type(stdio.format("")) == "string")
assert(type(task.spawn(function() end)) == "thread")
assert(not pcall(function()
return require("@") :: any
end))
assert(not pcall(function()
return require("@lune") :: any
end))
assert(not pcall(function()
return require("@lune/") :: any
end))
assert(not pcall(function()
return require("@src") :: any
end))