From 29a3b41e1580a3a2e87d4ff826d9e565a30b42d9 Mon Sep 17 00:00:00 2001 From: Filip Tibell Date: Tue, 21 Mar 2023 15:06:28 +0100 Subject: [PATCH] Implement new builtins scope for require --- packages/cli/src/cli.rs | 2 +- packages/lib/src/globals/mod.rs | 155 ++++++---------------- packages/lib/src/globals/require.rs | 30 +++-- packages/lib/src/globals/top_level.rs | 104 +++++++-------- packages/lib/src/lib.rs | 52 +++----- packages/lib/src/tests.rs | 3 +- tests/globals/require/tests/builtins.luau | 27 ++++ 7 files changed, 154 insertions(+), 219 deletions(-) create mode 100644 tests/globals/require/tests/builtins.luau diff --git a/packages/cli/src/cli.rs b/packages/cli/src/cli.rs index 3c0723a..e1d7541 100644 --- a/packages/cli/src/cli.rs +++ b/packages/cli/src/cli.rs @@ -177,7 +177,7 @@ impl Cli { // Create a new lune object with all globals & run the script let result = Lune::new() .with_args(self.script_args) - .try_run(&script_display_name, strip_shebang(script_contents)) + .run(&script_display_name, strip_shebang(script_contents)) .await; Ok(match result { Err(err) => { diff --git a/packages/lib/src/globals/mod.rs b/packages/lib/src/globals/mod.rs index 2ddff96..f5015f3 100644 --- a/packages/lib/src/globals/mod.rs +++ b/packages/lib/src/globals/mod.rs @@ -1,5 +1,3 @@ -use std::fmt::{Display, Formatter, Result as FmtResult}; - use mlua::prelude::*; mod fs; @@ -10,119 +8,42 @@ mod stdio; mod task; mod top_level; -#[derive(Clone, Debug, PartialEq, Eq, Hash)] -pub enum LuneGlobal { - Fs, - Net, - Process { args: Vec }, - Require, - Stdio, - Task, - TopLevel, -} - -impl Display for LuneGlobal { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - write!( - f, - "{}", - match self { - Self::Fs => "fs", - Self::Net => "net", - Self::Process { .. } => "process", - Self::Require => "require", - Self::Stdio => "stdio", - Self::Task => "task", - Self::TopLevel => "toplevel", - } - ) - } -} - -impl LuneGlobal { - /** - Create a vector that contains all available Lune globals, with - the [`LuneGlobal::Process`] global containing the given args. - */ - pub fn all>(args: &[S]) -> Vec { - vec![ - Self::Fs, - Self::Net, - Self::Process { - args: args.iter().map(|s| s.as_ref().to_string()).collect(), - }, - Self::Require, - Self::Stdio, - Self::Task, - Self::TopLevel, - ] - } - - /** - Checks if this Lune global is a proxy global. - - A proxy global is a global that re-implements or proxies functionality of one or - more existing lua globals, and may store internal references to the original global(s). - - This means that proxy globals should only be injected into a lua global - environment once, since injecting twice or more will potentially break the - functionality of the proxy global and / or cause undefined behavior. - */ - pub fn is_proxy(&self) -> bool { - matches!(self, Self::Require | Self::TopLevel) - } - - /** - Checks if this Lune global is an injector. - - An injector is similar to a proxy global but will inject - value(s) into the global lua environment during creation, - to ensure correct usage and compatibility with base Luau. - */ - pub fn is_injector(&self) -> bool { - matches!(self, Self::Task) - } - - /** - Creates the [`mlua::Table`] value for this Lune global. - - Note that proxy globals should be handled with special care and that [`LuneGlobal::inject()`] - should be preferred over manually creating and manipulating the value(s) of any Lune global. - */ - pub fn value(&self, lua: &'static Lua) -> LuaResult { - match self { - LuneGlobal::Fs => fs::create(lua), - LuneGlobal::Net => net::create(lua), - LuneGlobal::Process { args } => process::create(lua, args.clone()), - LuneGlobal::Require => require::create(lua), - LuneGlobal::Stdio => stdio::create(lua), - LuneGlobal::Task => task::create(lua), - LuneGlobal::TopLevel => top_level::create(lua), - } - } - - /** - Injects the Lune global into a lua global environment. - - This takes ownership since proxy Lune globals should - only ever be injected into a lua global environment once. - - Refer to [`LuneGlobal::is_top_level()`] for more info on proxy globals. - */ - pub fn inject(self, lua: &'static Lua) -> LuaResult<()> { - let globals = lua.globals(); - let table = self.value(lua)?; - // NOTE: Top level globals are special, the values - // *in* the table they return should be set directly, - // instead of setting the table itself as the global - if self.is_proxy() { - for pair in table.pairs::() { - let (key, value) = pair?; - globals.raw_set(key, value)?; - } - Ok(()) - } else { - globals.raw_set(self.to_string(), table) - } - } +pub fn create(lua: &'static Lua, args: Vec) -> LuaResult<()> { + // Create all builtins + let builtins = vec![ + ("fs", fs::create(lua)?), + ("net", net::create(lua)?), + ("process", process::create(lua, args)?), + ("stdio", stdio::create(lua)?), + ("task", task::create(lua)?), + ]; + + // TODO: Remove this when we have proper LSP support for custom require types + let lua_globals = lua.globals(); + for (name, builtin) in &builtins { + lua_globals.set(*name, builtin.clone())?; + } + + // Create our importer (require) with builtins + let require_fn = require::create(lua, builtins)?; + + // Create all top-level globals + let globals = vec![ + ("require", require_fn), + ("print", lua.create_function(top_level::top_level_print)?), + ("warn", lua.create_function(top_level::top_level_warn)?), + ("error", lua.create_function(top_level::top_level_error)?), + ( + "printinfo", + lua.create_function(top_level::top_level_printinfo)?, + ), + ]; + + // Set top-level globals and seal them + for (name, global) in globals { + lua_globals.set(name, global)?; + } + lua_globals.set_readonly(true); + + Ok(()) } diff --git a/packages/lib/src/globals/require.rs b/packages/lib/src/globals/require.rs index 802d52b..6b6e738 100644 --- a/packages/lib/src/globals/require.rs +++ b/packages/lib/src/globals/require.rs @@ -35,7 +35,11 @@ struct RequireContext<'lua> { } impl<'lua> RequireContext<'lua> { - pub fn new() -> Self { + pub fn new(lua: &'static Lua, builtins_vec: Vec<(K, V)>) -> LuaResult + where + K: Into, + V: ToLua<'static>, + { let mut pwd = current_dir() .expect("Failed to access current working directory") .to_string_lossy() @@ -43,10 +47,15 @@ impl<'lua> RequireContext<'lua> { if !pwd.ends_with(path::MAIN_SEPARATOR) { pwd = format!("{pwd}{}", path::MAIN_SEPARATOR) } - Self { - pwd, - ..Default::default() + let mut builtins = HashMap::new(); + for (key, value) in builtins_vec { + builtins.insert(key.into(), value.to_lua_multi(lua)?); } + Ok(Self { + pwd, + builtins: Arc::new(builtins), + ..Default::default() + }) } pub fn is_locked(&self, absolute_path: &str) -> bool { @@ -206,8 +215,12 @@ async fn load<'lua>( result } -pub fn create(lua: &'static Lua) -> LuaResult { - let require_context = RequireContext::new(); +pub fn create(lua: &'static Lua, builtins: Vec<(K, V)>) -> LuaResult +where + K: Clone + Into, + V: Clone + ToLua<'static>, +{ + let require_context = RequireContext::new(lua, builtins)?; let require_yield: LuaFunction = lua.named_registry_value("co.yield")?; let require_info: LuaFunction = lua.named_registry_value("dbg.info")?; let require_print: LuaFunction = lua.named_registry_value("print")?; @@ -244,8 +257,5 @@ pub fn create(lua: &'static Lua) -> LuaResult { .set_name("require")? .set_environment(require_env)? .into_function()?; - - TableBuilder::new(lua)? - .with_value("require", require_fn_lua)? - .build_readonly() + Ok(require_fn_lua) } diff --git a/packages/lib/src/globals/top_level.rs b/packages/lib/src/globals/top_level.rs index b9cf5e3..376c6a6 100644 --- a/packages/lib/src/globals/top_level.rs +++ b/packages/lib/src/globals/top_level.rs @@ -1,57 +1,55 @@ use mlua::prelude::*; -use crate::{ - lua::stdio::formatting::{format_label, pretty_format_multi_value}, - lua::table::TableBuilder, -}; +use crate::lua::stdio::formatting::{format_label, pretty_format_multi_value}; -pub fn create(lua: &'static Lua) -> LuaResult { - // HACK: We need to preserve the default behavior of the - // print and error functions, for pcall and such, which - // is really tricky to do from scratch so we will just - // proxy the default print and error functions here - TableBuilder::new(lua)? - .with_function("print", |lua, args: LuaMultiValue| { - let formatted = pretty_format_multi_value(&args)?; - let print: LuaFunction = lua.named_registry_value("print")?; - print.call(formatted)?; - Ok(()) - })? - .with_function("info", |lua, args: LuaMultiValue| { - let print: LuaFunction = lua.named_registry_value("print")?; - print.call(format!( - "{}\n{}", - format_label("info"), - pretty_format_multi_value(&args)? - ))?; - Ok(()) - })? - .with_function("warn", |lua, args: LuaMultiValue| { - let print: LuaFunction = lua.named_registry_value("print")?; - print.call(format!( - "{}\n{}", - format_label("warn"), - pretty_format_multi_value(&args)? - ))?; - Ok(()) - })? - .with_function("error", |lua, (arg, level): (LuaValue, Option)| { - let error: LuaFunction = lua.named_registry_value("error")?; - let trace: LuaFunction = lua.named_registry_value("dbg.trace")?; - error.call(( - LuaError::CallbackError { - traceback: format!("override traceback:{}", trace.call::<_, String>(())?), - cause: LuaError::external(format!( - "{}\n{}", - format_label("error"), - pretty_format_multi_value(&arg.to_lua_multi(lua)?)? - )) - .into(), - }, - level, - ))?; - Ok(()) - })? - // TODO: Add an override for tostring that formats errors in a nicer way - .build_readonly() +// HACK: We need to preserve the default behavior of the +// print and error functions, for pcall and such, which +// is really tricky to do from scratch so we will just +// proxy the default print and error functions here + +pub fn top_level_print(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> { + let formatted = pretty_format_multi_value(&args)?; + let print: LuaFunction = lua.named_registry_value("print")?; + print.call(formatted)?; + Ok(()) } + +pub fn top_level_printinfo(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> { + let print: LuaFunction = lua.named_registry_value("print")?; + print.call(format!( + "{}\n{}", + format_label("info"), + pretty_format_multi_value(&args)? + ))?; + Ok(()) +} + +pub fn top_level_warn(lua: &Lua, args: LuaMultiValue) -> LuaResult<()> { + let print: LuaFunction = lua.named_registry_value("print")?; + print.call(format!( + "{}\n{}", + format_label("warn"), + pretty_format_multi_value(&args)? + ))?; + Ok(()) +} + +pub fn top_level_error(lua: &Lua, (arg, level): (LuaValue, Option)) -> LuaResult<()> { + let error: LuaFunction = lua.named_registry_value("error")?; + let trace: LuaFunction = lua.named_registry_value("dbg.trace")?; + error.call(( + LuaError::CallbackError { + traceback: format!("override traceback:{}", trace.call::<_, String>(())?), + cause: LuaError::external(format!( + "{}\n{}", + format_label("error"), + pretty_format_multi_value(&arg.to_lua_multi(lua)?)? + )) + .into(), + }, + level, + ))?; + Ok(()) +} + +// TODO: Add an override for tostring that formats errors in a nicer way diff --git a/packages/lib/src/lib.rs b/packages/lib/src/lib.rs index b38f9da..519ce04 100644 --- a/packages/lib/src/lib.rs +++ b/packages/lib/src/lib.rs @@ -12,8 +12,6 @@ mod error; mod tests; pub use error::LuneError; -pub use globals::LuneGlobal; -pub use lua::create_lune_lua; #[derive(Clone, Debug, Default)] pub struct Lune { @@ -52,39 +50,28 @@ impl Lune { both live for the remainer of the program, and that this leaks memory using [`Box::leak`] that will then get deallocated when the program exits. */ - pub async fn try_run( - &self, - script_name: impl AsRef, - script_contents: impl AsRef<[u8]>, - ) -> Result { - self.run(script_name, script_contents, true) - .await - .map_err(LuneError::from_lua_error) - } - - /** - Tries to run a Lune script directly, returning the underlying - result type if loading the script raises a lua error. - - Passing `false` as the third argument `emit_prettified_errors` will - bypass any additional error formatting automatically done by Lune - for errors that are emitted while the script is running. - - Behavior is otherwise exactly the same as `try_run` - and `try_run` should be preferred in all other cases. - */ - #[doc(hidden)] pub async fn run( &self, script_name: impl AsRef, script_contents: impl AsRef<[u8]>, - emit_prettified_errors: bool, + ) -> Result { + self.run_inner(script_name, script_contents) + .await + .map_err(LuneError::from_lua_error) + } + + async fn run_inner( + &self, + script_name: impl AsRef, + script_contents: impl AsRef<[u8]>, ) -> Result { // Create our special lune-flavored Lua object with extra registry values - let lua = create_lune_lua()?; - // Create our task scheduler + let lua = lua::create_lune_lua()?; + // Create our task scheduler and all globals + // NOTE: Some globals require the task scheduler to exist on startup let sched = TaskScheduler::new(lua)?.into_static(); lua.set_app_data(sched); + globals::create(lua, self.args.clone())?; // Create the main thread and schedule it let main_chunk = lua .load(script_contents.as_ref()) @@ -93,11 +80,6 @@ impl Lune { let main_thread = lua.create_thread(main_chunk)?; let main_thread_args = LuaValue::Nil.to_lua_multi(lua)?; sched.schedule_blocking(main_thread, main_thread_args)?; - // Create our wanted lune globals, some of these need - // the task scheduler be available during construction - for global in LuneGlobal::all(&self.args) { - global.inject(lua)?; - } // Keep running the scheduler until there are either no tasks // left to run, or until a task requests to exit the process let exit_code = LocalSet::new() @@ -106,11 +88,7 @@ impl Lune { loop { let result = sched.resume_queue().await; if let Some(err) = result.get_lua_error() { - if emit_prettified_errors { - eprintln!("{}", LuneError::from_lua_error(err)); - } else { - eprintln!("{err}"); - } + eprintln!("{}", LuneError::from_lua_error(err)); got_error = true; } if result.is_done() { diff --git a/packages/lib/src/tests.rs b/packages/lib/src/tests.rs index f4ba14f..48b401d 100644 --- a/packages/lib/src/tests.rs +++ b/packages/lib/src/tests.rs @@ -37,7 +37,7 @@ macro_rules! create_tests { .trim_end_matches(".luau") .trim_end_matches(".lua") .to_string(); - let exit_code = lune.try_run(&script_name, &script).await?; + let exit_code = lune.run(&script_name, &script).await?; Ok(exit_code) } )* } @@ -63,6 +63,7 @@ create_tests! { require_async: "globals/require/tests/async", require_async_concurrent: "globals/require/tests/async_concurrent", require_async_sequential: "globals/require/tests/async_sequential", + require_builtins: "globals/require/tests/builtins", require_children: "globals/require/tests/children", require_invalid: "globals/require/tests/invalid", require_nested: "globals/require/tests/nested", diff --git a/tests/globals/require/tests/builtins.luau b/tests/globals/require/tests/builtins.luau new file mode 100644 index 0000000..f1b84e7 --- /dev/null +++ b/tests/globals/require/tests/builtins.luau @@ -0,0 +1,27 @@ +local fs = require("@lune/fs") :: typeof(fs) +local net = require("@lune/net") :: typeof(net) +local process = require("@lune/process") :: typeof(process) +local stdio = require("@lune/stdio") :: typeof(stdio) +local task = require("@lune/task") :: typeof(task) + +assert(type(fs.move) == "function") +assert(type(net.request) == "function") +assert(type(process.cwd) == "string") +assert(type(stdio.format("")) == "string") +assert(type(task.spawn(function() end)) == "thread") + +assert(not pcall(function() + return require("@") :: any +end)) + +assert(not pcall(function() + return require("@lune") :: any +end)) + +assert(not pcall(function() + return require("@lune/") :: any +end)) + +assert(not pcall(function() + return require("@src") :: any +end))