More work on task scheduler

This commit is contained in:
Filip Tibell 2023-02-17 19:20:17 +01:00
parent 744b73db3d
commit c45c78bdc2
No known key found for this signature in database
23 changed files with 359 additions and 239 deletions

View file

@ -159,7 +159,7 @@ impl Cli {
// Display the file path relative to cwd with no extensions in stack traces
let file_display_name = file_path.with_extension("").display().to_string();
// Create a new lune object with all globals & run the script
let lune = Lune::new().with_all_globals_and_args(self.script_args);
let lune = Lune::new().with_args(self.script_args);
let result = lune.run(&file_display_name, &file_contents).await;
Ok(match result {
Err(e) => {

View file

@ -1,11 +1,10 @@
use std::{
env::{self, current_dir},
io,
fs,
path::PathBuf,
};
use mlua::prelude::*;
use tokio::fs;
use crate::utils::table::TableBuilder;
@ -15,37 +14,63 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
if env::var_os("LUAU_PWD_REQUIRE").is_some() {
return TableBuilder::new(lua)?.build_readonly();
}
// Store the current pwd, and make helper functions for path conversions
let require_pwd = current_dir()?.to_string_lossy().to_string();
// Store the current pwd, and make the functions for path conversions & loading a file
let mut require_pwd = current_dir()?.to_string_lossy().to_string();
if !require_pwd.ends_with('/') {
require_pwd = format!("{require_pwd}/")
}
let require_info: LuaFunction = lua.named_registry_value("dbg.info")?;
let require_error: LuaFunction = lua.named_registry_value("error")?;
let require_get_abs_rel_paths = lua
.create_function(
|_, (require_pwd, require_source, require_path): (String, String, String)| {
let mut path_relative_to_pwd = PathBuf::from(
let path_relative_to_pwd = PathBuf::from(
&require_source
.trim_start_matches("[string \"")
.trim_end_matches("\"]"),
)
.parent()
.unwrap()
.join(require_path);
.join(&require_path);
// Try to normalize and resolve relative path segments such as './' and '../'
if let Ok(canonicalized) =
path_relative_to_pwd.with_extension("luau").canonicalize()
{
path_relative_to_pwd = canonicalized;
}
if let Ok(canonicalized) = path_relative_to_pwd.with_extension("lua").canonicalize()
{
path_relative_to_pwd = canonicalized;
}
let absolute = path_relative_to_pwd.to_string_lossy().to_string();
let file_path = match (
path_relative_to_pwd.with_extension("luau").canonicalize(),
path_relative_to_pwd.with_extension("lua").canonicalize(),
) {
(Ok(luau), _) => luau,
(_, Ok(lua)) => lua,
_ => {
return Err(LuaError::RuntimeError(format!(
"File does not exist at path '{require_path}'"
)))
}
};
let absolute = file_path.to_string_lossy().to_string();
let relative = absolute.trim_start_matches(&require_pwd).to_string();
Ok((absolute, relative))
},
)?
.bind(require_pwd)?;
// Note that file loading must be blocking to guarantee the require cache works, if it
// were async then one lua script may require a module during the file reading process
let require_get_loaded_file = lua.create_function(
|lua: &Lua, (path_absolute, path_relative): (String, String)| {
// Use a name without extensions for loading the chunk, the
// above code assumes the require path is without extensions
let path_relative_no_extension = path_relative
.trim_end_matches(".lua")
.trim_end_matches(".luau");
// Try to read the wanted file, note that we use bytes instead of reading
// to a string since lua scripts are not necessarily valid utf-8 strings
match fs::read(path_absolute) {
Ok(contents) => lua
.load(&contents)
.set_name(path_relative_no_extension)?
.eval::<LuaValue>(),
Err(e) => Err(LuaError::external(e)),
}
},
)?;
/*
We need to get the source file where require was
called to be able to do path-relative requires,
@ -61,12 +86,15 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
.with_value("info", require_info)?
.with_value("error", require_error)?
.with_value("paths", require_get_abs_rel_paths)?
.with_async_function("load", load_file)?
.with_value("load", require_get_loaded_file)?
.build_readonly()?;
let require_fn_lua = lua
.load(
r#"
local source = info(2, "s")
local source = info(1, "s")
if source == '[string "require"]' then
source = info(2, "s")
end
local absolute, relative = paths(source, ...)
if loaded[absolute] ~= true then
local first, second = load(absolute, relative)
@ -88,20 +116,3 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
.with_value("require", require_fn_lua)?
.build_readonly()
}
async fn load_file(
lua: &Lua,
(path_absolute, path_relative): (String, String),
) -> LuaResult<LuaValue> {
// Try to read the wanted file, note that we use bytes instead of reading
// to a string since lua scripts are not necessarily valid utf-8 strings
match fs::read(&path_absolute).await {
Ok(contents) => lua.load(&contents).set_name(path_relative)?.eval(),
Err(e) => match e.kind() {
io::ErrorKind::NotFound => Err(LuaError::RuntimeError(format!(
"No lua module exists at the path '{path_relative}'"
))),
_ => Err(LuaError::external(e)),
},
}
}

View file

@ -1,10 +1,13 @@
use std::time::Duration;
use mlua::prelude::*;
use tokio::time::{sleep, Instant};
use crate::{
lua::task::{TaskKind, TaskReference, TaskScheduler, TaskSchedulerScheduleExt},
lua::{
async_ext::LuaAsyncExt,
task::{
LuaThreadOrFunction, LuaThreadOrTaskReference, TaskKind, TaskReference, TaskScheduler,
TaskSchedulerScheduleExt,
},
},
utils::table::TableBuilder,
};
@ -22,7 +25,6 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable<'static>> {
we need to yield right away to allow the
spawned task to run until first yield
*/
let task_spawn_env_thread: LuaFunction = lua.named_registry_value("co.thread")?;
let task_spawn_env_yield: LuaFunction = lua.named_registry_value("co.yield")?;
let task_spawn = lua
.load(
@ -33,10 +35,10 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable<'static>> {
return task
",
)
.set_name("=task.spawn")?
.set_name("task.spawn")?
.set_environment(
TableBuilder::new(lua)?
.with_value("thread", task_spawn_env_thread)?
.with_function("thread", |lua, _: ()| Ok(lua.current_thread()))?
.with_value("yield", task_spawn_env_yield)?
.with_function(
"scheduleNext",
@ -63,83 +65,14 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable<'static>> {
coroutine.set("wrap", lua.create_function(coroutine_wrap)?)?;
// All good, return the task scheduler lib
TableBuilder::new(lua)?
.with_value("wait", lua.create_waiter_function()?)?
.with_value("spawn", task_spawn)?
.with_function("cancel", task_cancel)?
.with_function("defer", task_defer)?
.with_function("delay", task_delay)?
.with_async_function("wait", task_wait)?
.build_readonly()
}
/*
Proxy enum to deal with both threads & functions
*/
enum LuaThreadOrFunction<'lua> {
Thread(LuaThread<'lua>),
Function(LuaFunction<'lua>),
}
impl<'lua> LuaThreadOrFunction<'lua> {
fn into_thread(self, lua: &'lua Lua) -> LuaResult<LuaThread<'lua>> {
match self {
Self::Thread(t) => Ok(t),
Self::Function(f) => lua.create_thread(f),
}
}
}
impl<'lua> FromLua<'lua> for LuaThreadOrFunction<'lua> {
fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult<Self> {
match value {
LuaValue::Thread(t) => Ok(Self::Thread(t)),
LuaValue::Function(f) => Ok(Self::Function(f)),
value => Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "LuaThreadOrFunction",
message: Some(format!(
"Expected thread or function, got '{}'",
value.type_name()
)),
}),
}
}
}
/*
Proxy enum to deal with both threads & task scheduler task references
*/
enum LuaThreadOrTaskReference<'lua> {
Thread(LuaThread<'lua>),
TaskReference(TaskReference),
}
impl<'lua> FromLua<'lua> for LuaThreadOrTaskReference<'lua> {
fn from_lua(value: LuaValue<'lua>, lua: &'lua Lua) -> LuaResult<Self> {
let tname = value.type_name();
match value {
LuaValue::Thread(t) => Ok(Self::Thread(t)),
LuaValue::UserData(u) => {
if let Ok(task) = TaskReference::from_lua(LuaValue::UserData(u), lua) {
Ok(Self::TaskReference(task))
} else {
Err(LuaError::FromLuaConversionError {
from: tname,
to: "thread",
message: Some(format!("Expected thread, got '{tname}'")),
})
}
}
_ => Err(LuaError::FromLuaConversionError {
from: tname,
to: "thread",
message: Some(format!("Expected thread, got '{tname}'")),
}),
}
}
}
/*
Basic task functions
*/
@ -166,12 +99,6 @@ fn task_delay(
sched.schedule_blocking_after_seconds(secs, tof.into_thread(lua)?, args)
}
async fn task_wait(_: &Lua, secs: Option<f64>) -> LuaResult<f64> {
let start = Instant::now();
sleep(Duration::from_secs_f64(secs.unwrap_or_default())).await;
Ok(start.elapsed().as_secs_f64())
}
/*
Type getter overrides for compat with task scheduler
*/
@ -207,7 +134,7 @@ fn coroutine_resume<'lua>(
match value {
LuaThreadOrTaskReference::Thread(t) => {
let sched = lua.app_data_ref::<&TaskScheduler>().unwrap();
let task = sched.create_task(TaskKind::Instant, t, None, None)?;
let task = sched.create_task(TaskKind::Instant, t, None, true)?;
sched.resume_task(task, None)
}
LuaThreadOrTaskReference::TaskReference(t) => lua
@ -222,7 +149,7 @@ fn coroutine_wrap<'lua>(lua: &'lua Lua, func: LuaFunction) -> LuaResult<LuaFunct
TaskKind::Instant,
lua.create_thread(func)?,
None,
None,
false,
)?;
lua.create_function(move |lua, args: LuaMultiValue| {
lua.app_data_ref::<&TaskScheduler>()

View file

@ -1,4 +1,4 @@
use std::{collections::HashSet, process::ExitCode};
use std::process::ExitCode;
use lua::task::{TaskScheduler, TaskSchedulerResumeExt, TaskSchedulerScheduleExt};
use mlua::prelude::*;
@ -18,8 +18,7 @@ pub use lua::create_lune_lua;
#[derive(Clone, Debug, Default)]
pub struct Lune {
includes: HashSet<LuneGlobal>,
excludes: HashSet<LuneGlobal>,
args: Vec<String>,
}
impl Lune {
@ -31,42 +30,13 @@ impl Lune {
}
/**
Include a global in the lua environment created for running a Lune script.
Arguments to give in `process.args` for a Lune script.
*/
pub fn with_global(mut self, global: LuneGlobal) -> Self {
self.includes.insert(global);
self
}
/**
Include all globals in the lua environment created for running a Lune script.
*/
pub fn with_all_globals(mut self) -> Self {
for global in LuneGlobal::all::<String>(&[]) {
self.includes.insert(global);
}
self
}
/**
Include all globals in the lua environment created for running a
Lune script, as well as supplying args for [`LuneGlobal::Process`].
*/
pub fn with_all_globals_and_args(mut self, args: Vec<String>) -> Self {
for global in LuneGlobal::all(&args) {
self.includes.insert(global);
}
self
}
/**
Exclude a global from the lua environment created for running a Lune script.
This should be preferred over manually iterating and filtering
which Lune globals to add to the global environment.
*/
pub fn without_global(mut self, global: LuneGlobal) -> Self {
self.excludes.insert(global);
pub fn with_args<V>(mut self, args: V) -> Self
where
V: Into<Vec<String>>,
{
self.args = args.into();
self
}
@ -76,12 +46,11 @@ impl Lune {
This will create a new sandboxed Luau environment with the configured
globals and arguments, running inside of a [`tokio::task::LocalSet`].
Some Lune globals such as [`LuneGlobal::Process`] and [`LuneGlobal::Net`]
may spawn separate tokio tasks on other threads, but the Luau environment
itself is guaranteed to run on a single thread in the local set.
Some Lune globals may spawn separate tokio tasks on other threads, but the Luau
environment itself is guaranteed to run on a single thread in the local set.
Note that this will create a static Lua instance and task scheduler which both
will live for the remainer of the program, and that this leaks memory using
Note that this will create a static Lua instance and task scheduler that will
both live for the remainer of the program, and that this leaks memory using
[`Box::leak`] that will then get deallocated when the program exits.
*/
pub async fn run(
@ -106,10 +75,8 @@ impl Lune {
sched.schedule_blocking(main_thread, main_thread_args)?;
// Create our wanted lune globals, some of these need
// the task scheduler be available during construction
for global in self.includes.clone() {
if !self.excludes.contains(&global) {
global.inject(lua)?;
}
for global in LuneGlobal::all(&self.args) {
global.inject(lua)?;
}
// Keep running the scheduler until there are either no tasks
// left to run, or until a task requests to exit the process

View file

@ -4,6 +4,8 @@ use mlua::prelude::*;
use crate::{lua::task::TaskScheduler, utils::table::TableBuilder};
use super::task::TaskSchedulerAsyncExt;
#[async_trait(?Send)]
pub trait LuaAsyncExt {
fn create_async_function<'lua, A, R, F, FR>(self, func: F) -> LuaResult<LuaFunction<'lua>>
@ -12,6 +14,8 @@ pub trait LuaAsyncExt {
R: ToLuaMulti<'static>,
F: 'static + Fn(&'static Lua, A) -> FR,
FR: 'static + Future<Output = LuaResult<R>>;
fn create_waiter_function<'lua>(self) -> LuaResult<LuaFunction<'lua>>;
}
impl LuaAsyncExt for &'static Lua {
@ -31,7 +35,6 @@ impl LuaAsyncExt for &'static Lua {
let async_env_trace: LuaFunction = self.named_registry_value("dbg.trace")?;
let async_env_error: LuaFunction = self.named_registry_value("error")?;
let async_env_unpack: LuaFunction = self.named_registry_value("tab.unpack")?;
let async_env_thread: LuaFunction = self.named_registry_value("co.thread")?;
let async_env_yield: LuaFunction = self.named_registry_value("co.yield")?;
let async_env = TableBuilder::new(self)?
.with_value("makeError", async_env_make_err)?
@ -39,8 +42,8 @@ impl LuaAsyncExt for &'static Lua {
.with_value("trace", async_env_trace)?
.with_value("error", async_env_error)?
.with_value("unpack", async_env_unpack)?
.with_value("thread", async_env_thread)?
.with_value("yield", async_env_yield)?
.with_function("thread", |lua, _: ()| Ok(lua.current_thread()))?
.with_function(
"resumeAsync",
move |lua: &Lua, (thread, args): (LuaThread, A)| {
@ -48,7 +51,7 @@ impl LuaAsyncExt for &'static Lua {
let sched = lua
.app_data_ref::<&TaskScheduler>()
.expect("Missing task scheduler as a lua app data");
sched.queue_async_task(thread, None, None, async {
sched.queue_async_task(thread, None, async {
let rets = fut.await?;
let mult = rets.to_lua_multi(lua)?;
Ok(Some(mult))
@ -68,7 +71,36 @@ impl LuaAsyncExt for &'static Lua {
end
",
)
.set_name("asyncWrapper")?
.set_name("async")?
.set_environment(async_env)?
.into_function()?;
Ok(async_func)
}
/**
Creates a special async function that waits the
desired amount of time, inheriting the guid of the
current thread / task for proper cancellation.
*/
fn create_waiter_function<'lua>(self) -> LuaResult<LuaFunction<'lua>> {
let async_env_yield: LuaFunction = self.named_registry_value("co.yield")?;
let async_env = TableBuilder::new(self)?
.with_value("yield", async_env_yield)?
.with_function("resumeAfter", move |lua: &Lua, duration: Option<f64>| {
let sched = lua
.app_data_ref::<&TaskScheduler>()
.expect("Missing task scheduler as a lua app data");
sched.schedule_wait(lua.current_thread(), duration)
})?
.build_readonly()?;
let async_func = self
.load(
"
resumeAfter(...)
return yield()
",
)
.set_name("wait")?
.set_environment(async_env)?
.into_function()?;
Ok(async_func)

View file

@ -64,7 +64,6 @@ end
* `"tostring"` -> `tostring`
* `"tonumber"` -> `tonumber`
---
* `"co.thread"` -> `coroutine.running`
* `"co.yield"` -> `coroutine.yield`
* `"co.close"` -> `coroutine.close`
---
@ -93,7 +92,6 @@ pub fn create() -> LuaResult<&'static Lua> {
lua.set_named_registry_value("pcall", globals.get::<_, LuaFunction>("pcall")?)?;
lua.set_named_registry_value("tostring", globals.get::<_, LuaFunction>("tostring")?)?;
lua.set_named_registry_value("tonumber", globals.get::<_, LuaFunction>("tonumber")?)?;
lua.set_named_registry_value("co.thread", coroutine.get::<_, LuaFunction>("running")?)?;
lua.set_named_registry_value("co.yield", coroutine.get::<_, LuaFunction>("yield")?)?;
lua.set_named_registry_value("co.close", coroutine.get::<_, LuaFunction>("close")?)?;
lua.set_named_registry_value("dbg.info", debug.get::<_, LuaFunction>("info")?)?;

View file

@ -1,6 +1,6 @@
mod create;
pub mod ext;
pub mod async_ext;
pub mod net;
pub mod stdio;
pub mod task;

View file

@ -1,7 +1,12 @@
use std::time::Duration;
use async_trait::async_trait;
use futures_util::Future;
use mlua::prelude::*;
use tokio::time::{sleep, Instant};
use crate::lua::task::TaskKind;
use super::super::{
async_handle::TaskSchedulerAsyncHandle, message::TaskSchedulerMessage,
@ -30,6 +35,12 @@ pub trait TaskSchedulerAsyncExt<'fut> {
R: ToLuaMulti<'static>,
F: 'static + Fn(&'static Lua) -> FR,
FR: 'static + Future<Output = LuaResult<R>>;
fn schedule_wait(
&'fut self,
reference: LuaThread<'_>,
duration: Option<f64>,
) -> LuaResult<TaskReference>;
}
/*
@ -82,7 +93,7 @@ impl<'fut> TaskSchedulerAsyncExt<'fut> for TaskScheduler<'fut> {
F: 'static + Fn(&'static Lua) -> FR,
FR: 'static + Future<Output = LuaResult<R>>,
{
self.queue_async_task(thread, None, None, async move {
self.queue_async_task(thread, None, async move {
match func(self.lua).await {
Ok(res) => match res.to_lua_multi(self.lua) {
Ok(multi) => Ok(Some(multi)),
@ -92,4 +103,30 @@ impl<'fut> TaskSchedulerAsyncExt<'fut> for TaskScheduler<'fut> {
}
})
}
/**
Schedules a task reference to be resumed after a certain amount of time.
The given task will be resumed with the elapsed time as its one and only argument.
*/
fn schedule_wait(
&'fut self,
thread: LuaThread<'_>,
duration: Option<f64>,
) -> LuaResult<TaskReference> {
let reference = self.create_task(TaskKind::Future, thread, None, true)?;
// Insert the future
let futs = self
.futures
.try_lock()
.expect("Tried to add future to queue during futures resumption");
futs.push(Box::pin(async move {
let before = Instant::now();
sleep(Duration::from_secs_f64(duration.unwrap_or_default())).await;
let elapsed_secs = before.elapsed().as_secs_f64();
let args = elapsed_secs.to_lua_multi(self.lua).unwrap();
(Some(reference), Ok(Some(args)))
}));
Ok(reference)
}
}

View file

@ -124,12 +124,15 @@ async fn resume_next_async_task(scheduler: &TaskScheduler<'_>) -> TaskSchedulerS
.await
.expect("Tried to resume next queued future but none are queued")
};
// Promote this future task to a blocking task and resume it
// right away, also taking care to not borrow mutably twice
// by dropping this guard before trying to resume it
let mut queue_guard = scheduler.tasks_queue_blocking.borrow_mut();
queue_guard.push_front(task);
drop(queue_guard);
// The future might not return a reference that it wants to resume
if let Some(task) = task {
// Promote this future task to a blocking task and resume it
// right away, also taking care to not borrow mutably twice
// by dropping this guard before trying to resume it
let mut queue_guard = scheduler.tasks_queue_blocking.borrow_mut();
queue_guard.push_front(task);
drop(queue_guard);
}
resume_next_blocking_task(scheduler, result.transpose())
}

View file

@ -52,7 +52,7 @@ impl TaskSchedulerScheduleExt for TaskScheduler<'_> {
thread: LuaThread<'_>,
thread_args: LuaMultiValue<'_>,
) -> LuaResult<TaskReference> {
self.queue_blocking_task(TaskKind::Instant, thread, Some(thread_args), None)
self.queue_blocking_task(TaskKind::Instant, thread, Some(thread_args))
}
/**
@ -67,7 +67,7 @@ impl TaskSchedulerScheduleExt for TaskScheduler<'_> {
thread: LuaThread<'_>,
thread_args: LuaMultiValue<'_>,
) -> LuaResult<TaskReference> {
self.queue_blocking_task(TaskKind::Deferred, thread, Some(thread_args), None)
self.queue_blocking_task(TaskKind::Deferred, thread, Some(thread_args))
}
/**
@ -83,7 +83,7 @@ impl TaskSchedulerScheduleExt for TaskScheduler<'_> {
thread: LuaThread<'_>,
thread_args: LuaMultiValue<'_>,
) -> LuaResult<TaskReference> {
self.queue_async_task(thread, Some(thread_args), None, async move {
self.queue_async_task(thread, Some(thread_args), async move {
sleep(Duration::from_secs_f64(after_secs)).await;
Ok(None)
})

View file

@ -1,10 +1,12 @@
mod async_handle;
mod ext;
mod message;
mod proxy;
mod result;
mod scheduler;
mod task_kind;
mod task_reference;
pub use ext::*;
pub use proxy::*;
pub use scheduler::*;

View file

@ -0,0 +1,116 @@
use mlua::prelude::*;
use super::TaskReference;
/*
Proxy enum to deal with both threads & functions
*/
#[derive(Debug, Clone)]
pub enum LuaThreadOrFunction<'lua> {
Thread(LuaThread<'lua>),
Function(LuaFunction<'lua>),
}
impl<'lua> LuaThreadOrFunction<'lua> {
pub fn into_thread(self, lua: &'lua Lua) -> LuaResult<LuaThread<'lua>> {
match self {
Self::Thread(t) => Ok(t),
Self::Function(f) => lua.create_thread(f),
}
}
}
impl<'lua> From<LuaThread<'lua>> for LuaThreadOrFunction<'lua> {
fn from(value: LuaThread<'lua>) -> Self {
Self::Thread(value)
}
}
impl<'lua> From<LuaFunction<'lua>> for LuaThreadOrFunction<'lua> {
fn from(value: LuaFunction<'lua>) -> Self {
Self::Function(value)
}
}
impl<'lua> FromLua<'lua> for LuaThreadOrFunction<'lua> {
fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult<Self> {
match value {
LuaValue::Thread(t) => Ok(Self::Thread(t)),
LuaValue::Function(f) => Ok(Self::Function(f)),
value => Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "LuaThreadOrFunction",
message: Some(format!(
"Expected thread or function, got '{}'",
value.type_name()
)),
}),
}
}
}
impl<'lua> ToLua<'lua> for LuaThreadOrFunction<'lua> {
fn to_lua(self, _: &'lua Lua) -> LuaResult<LuaValue<'lua>> {
match self {
Self::Thread(t) => Ok(LuaValue::Thread(t)),
Self::Function(f) => Ok(LuaValue::Function(f)),
}
}
}
/*
Proxy enum to deal with both threads & task scheduler task references
*/
#[derive(Debug, Clone)]
pub enum LuaThreadOrTaskReference<'lua> {
Thread(LuaThread<'lua>),
TaskReference(TaskReference),
}
impl<'lua> From<LuaThread<'lua>> for LuaThreadOrTaskReference<'lua> {
fn from(value: LuaThread<'lua>) -> Self {
Self::Thread(value)
}
}
impl<'lua> From<TaskReference> for LuaThreadOrTaskReference<'lua> {
fn from(value: TaskReference) -> Self {
Self::TaskReference(value)
}
}
impl<'lua> FromLua<'lua> for LuaThreadOrTaskReference<'lua> {
fn from_lua(value: LuaValue<'lua>, lua: &'lua Lua) -> LuaResult<Self> {
let tname = value.type_name();
match value {
LuaValue::Thread(t) => Ok(Self::Thread(t)),
LuaValue::UserData(u) => {
if let Ok(task) = TaskReference::from_lua(LuaValue::UserData(u), lua) {
Ok(Self::TaskReference(task))
} else {
Err(LuaError::FromLuaConversionError {
from: tname,
to: "thread",
message: Some(format!("Expected thread, got '{tname}'")),
})
}
}
_ => Err(LuaError::FromLuaConversionError {
from: tname,
to: "thread",
message: Some(format!("Expected thread, got '{tname}'")),
}),
}
}
}
impl<'lua> ToLua<'lua> for LuaThreadOrTaskReference<'lua> {
fn to_lua(self, lua: &'lua Lua) -> LuaResult<LuaValue<'lua>> {
match self {
Self::TaskReference(t) => t.to_lua(lua),
Self::Thread(t) => Ok(LuaValue::Thread(t)),
}
}
}

View file

@ -14,7 +14,7 @@ use super::message::TaskSchedulerMessage;
pub use super::{task_kind::TaskKind, task_reference::TaskReference};
type TaskFutureRets<'fut> = LuaResult<Option<LuaMultiValue<'fut>>>;
type TaskFuture<'fut> = LocalBoxFuture<'fut, (TaskReference, TaskFutureRets<'fut>)>;
type TaskFuture<'fut> = LocalBoxFuture<'fut, (Option<TaskReference>, TaskFutureRets<'fut>)>;
/// A struct representing a task contained in the task scheduler
#[derive(Debug)]
@ -40,10 +40,10 @@ pub struct TaskScheduler<'fut> {
// Internal state & flags
pub(super) lua: &'static Lua,
pub(super) guid: Cell<usize>,
pub(super) guid_running: Cell<Option<usize>>,
pub(super) exit_code: Cell<Option<ExitCode>>,
// Blocking tasks
pub(super) tasks: RefCell<HashMap<TaskReference, Task>>,
pub(super) tasks_current: Cell<Option<TaskReference>>,
pub(super) tasks_queue_blocking: RefCell<VecDeque<TaskReference>>,
// Future tasks & objects for waking
pub(super) futures: AsyncMutex<FuturesUnordered<TaskFuture<'fut>>>,
@ -61,9 +61,9 @@ impl<'fut> TaskScheduler<'fut> {
Ok(Self {
lua,
guid: Cell::new(0),
guid_running: Cell::new(None),
exit_code: Cell::new(None),
tasks: RefCell::new(HashMap::new()),
tasks_current: Cell::new(None),
tasks_queue_blocking: RefCell::new(VecDeque::new()),
futures: AsyncMutex::new(FuturesUnordered::new()),
futures_tx: tx,
@ -109,6 +109,14 @@ impl<'fut> TaskScheduler<'fut> {
self.tasks.borrow().contains_key(&reference)
}
/**
Returns the currently running task, if any.
*/
#[allow(dead_code)]
pub fn current_task(&self) -> Option<TaskReference> {
self.tasks_current.get()
}
/**
Creates a new task, storing a new Lua thread
for it, as well as the arguments to give the
@ -123,7 +131,7 @@ impl<'fut> TaskScheduler<'fut> {
kind: TaskKind,
thread: LuaThread<'_>,
thread_args: Option<LuaMultiValue<'_>>,
guid_to_reuse: Option<usize>,
inherit_current_guid: bool,
) -> LuaResult<TaskReference> {
// Store the thread and its arguments in the registry
// NOTE: We must convert to a vec since multis
@ -137,19 +145,22 @@ impl<'fut> TaskScheduler<'fut> {
args: task_args_key,
};
// Create the task ref to use
let task_ref = if let Some(reusable_guid) = guid_to_reuse {
TaskReference::new(kind, reusable_guid)
let guid = if inherit_current_guid {
self.current_task()
.expect("No current guid to inherit")
.id()
} else {
let guid = self.guid.get();
self.guid.set(guid + 1);
TaskReference::new(kind, guid)
guid
};
let reference = TaskReference::new(kind, guid);
// Add the task to the scheduler
{
let mut tasks = self.tasks.borrow_mut();
tasks.insert(task_ref, task);
tasks.insert(reference, task);
}
Ok(task_ref)
Ok(reference)
}
/**
@ -181,8 +192,13 @@ impl<'fut> TaskScheduler<'fut> {
.filter(|task_ref| task_ref.id() == reference.id())
.copied()
.collect();
for task_ref in tasks_to_remove {
if let Some(task) = tasks.remove(&task_ref) {
for task_ref in &tasks_to_remove {
if let Some(task) = tasks.remove(task_ref) {
// NOTE: We need to close the thread here to
// make 100% sure that nothing can resume it
let close: LuaFunction = self.lua.named_registry_value("co.close")?;
let thread: LuaThread = self.lua.registry_value(&task.thread)?;
close.call(thread)?;
self.lua.remove_registry_value(task.thread)?;
self.lua.remove_registry_value(task.args)?;
found = true;
@ -204,13 +220,16 @@ impl<'fut> TaskScheduler<'fut> {
reference: TaskReference,
override_args: Option<LuaResult<LuaMultiValue<'a>>>,
) -> LuaResult<LuaMultiValue<'a>> {
// Fetch and check if the task was removed, if it got
// removed it means it was intentionally cancelled
let task = {
let mut tasks = self.tasks.borrow_mut();
match tasks.remove(&reference) {
Some(task) => task,
None => return Ok(LuaMultiValue::new()), // Task was removed
None => return Ok(LuaMultiValue::new()),
}
};
// Fetch and remove the thread to resume + its arguments
let thread: LuaThread = self.lua.registry_value(&task.thread)?;
let args_opt_res = override_args.or_else(|| {
Ok(self
@ -222,7 +241,9 @@ impl<'fut> TaskScheduler<'fut> {
});
self.lua.remove_registry_value(task.thread)?;
self.lua.remove_registry_value(task.args)?;
self.guid_running.set(Some(reference.id()));
// We got everything we need and our references
// were cleaned up properly, resume the thread
self.tasks_current.set(Some(reference));
let rets = match args_opt_res {
Some(args_res) => match args_res {
/*
@ -235,12 +256,12 @@ impl<'fut> TaskScheduler<'fut> {
that may pass errors as arguments when resuming tasks, other
native mlua functions will handle this and dont need wrapping
*/
Err(err) => thread.resume(err),
Err(e) => thread.resume(e),
Ok(args) => thread.resume(args),
},
None => thread.resume(()),
};
self.guid_running.set(None);
self.tasks_current.set(None);
rets
}
@ -265,12 +286,11 @@ impl<'fut> TaskScheduler<'fut> {
kind: TaskKind,
thread: LuaThread<'_>,
thread_args: Option<LuaMultiValue<'_>>,
guid_to_reuse: Option<usize>,
) -> LuaResult<TaskReference> {
if kind == TaskKind::Future {
panic!("Tried to schedule future using normal task schedule method")
}
let task_ref = self.create_task(kind, thread, thread_args, guid_to_reuse)?;
let task_ref = self.create_task(kind, thread, thread_args, false)?;
// Add the task to the front of the queue, unless it
// should be deferred, in that case add it to the back
let mut queue = self.tasks_queue_blocking.borrow_mut();
@ -304,17 +324,16 @@ impl<'fut> TaskScheduler<'fut> {
&self,
thread: LuaThread<'_>,
thread_args: Option<LuaMultiValue<'_>>,
guid_to_reuse: Option<usize>,
fut: impl Future<Output = TaskFutureRets<'fut>> + 'fut,
) -> LuaResult<TaskReference> {
let task_ref = self.create_task(TaskKind::Future, thread, thread_args, guid_to_reuse)?;
let task_ref = self.create_task(TaskKind::Future, thread, thread_args, false)?;
let futs = self
.futures
.try_lock()
.expect("Failed to get lock on futures");
.expect("Tried to add future to queue during futures resumption");
futs.push(Box::pin(async move {
let result = fut.await;
(task_ref, result)
(Some(task_ref), result)
}));
Ok(task_ref)
}

View file

@ -24,7 +24,11 @@ impl TaskReference {
impl fmt::Display for TaskReference {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "TaskReference({} - {})", self.kind, self.guid)
if self.guid == 0 {
write!(f, "TaskReference(MAIN)")
} else {
write!(f, "TaskReference({} - {})", self.kind, self.guid)
}
}
}

View file

@ -26,12 +26,12 @@ macro_rules! create_tests {
// The rest of the test logic can continue as normal
let full_name = format!("tests/{}.luau", $value);
let script = read_to_string(&full_name).await?;
let lune = Lune::new().with_all_globals_and_args(
let lune = Lune::new().with_args(
ARGS
.clone()
.iter()
.map(ToString::to_string)
.collect()
.collect::<Vec<_>>()
);
let script_name = full_name.strip_suffix(".luau").unwrap();
let exit_code = lune.run(&script_name, &script).await?;

View file

@ -240,18 +240,21 @@ pub fn pretty_format_luau_error(e: &LuaError, colorized: bool) -> String {
// The traceback may also start with "override traceback:" which
// means it was passed from somewhere that wants a custom trace,
// so we should then respect that and get the best override instead
let mut best_trace: &str = traceback;
let mut full_trace = traceback.to_string();
let mut root_cause = cause.as_ref();
let mut trace_override = false;
while let LuaError::CallbackError { cause, traceback } = root_cause {
let is_override = traceback.starts_with("override traceback:");
if is_override {
if !trace_override || traceback.lines().count() > best_trace.len() {
best_trace = traceback.strip_prefix("override traceback:").unwrap();
if !trace_override || traceback.lines().count() > full_trace.len() {
full_trace = traceback
.strip_prefix("override traceback:")
.unwrap()
.to_string();
trace_override = true;
}
} else if !trace_override && traceback.lines().count() > best_trace.len() {
best_trace = traceback;
} else if !trace_override {
full_trace = format!("{traceback}\n{full_trace}");
}
root_cause = cause;
}
@ -266,10 +269,10 @@ pub fn pretty_format_luau_error(e: &LuaError, colorized: bool) -> String {
"{}\n{}\n{}\n{}",
pretty_format_luau_error(root_cause, colorized),
stack_begin,
if best_trace.starts_with("stack traceback:") {
best_trace.strip_prefix("stack traceback:\n").unwrap()
if full_trace.starts_with("stack traceback:") {
full_trace.strip_prefix("stack traceback:\n").unwrap()
} else {
best_trace
&full_trace
},
stack_end
)
@ -378,7 +381,8 @@ fn transform_stack_line(line: &str) -> String {
let line_num = match after_name.find(':') {
Some(lineno_start) => match after_name[lineno_start + 1..].find(':') {
Some(lineno_end) => &after_name[lineno_start + 1..lineno_end + 1],
None => match after_name.contains("in function") {
None => match after_name.contains("in function") || after_name.contains("in ?")
{
false => &after_name[lineno_start + 1..],
true => "",
},
@ -418,11 +422,18 @@ fn transform_stack_line(line: &str) -> String {
fn fix_error_nitpicks(full_message: String) -> String {
full_message
// Hacky fix for our custom require appearing as a normal script
// TODO: It's probably better to pull in the regex crate here ..
.replace("'require', Line 5", "'[C]' - function require")
.replace("'require', Line 7", "'[C]' - function require")
.replace("'require', Line 8", "'[C]' - function require")
// Fix error calls in custom script chunks coming through
.replace(
"'[C]' - function error\n Script '[C]' - function require",
"'[C]' - function require",
)
// Fix strange double require
.replace(
"'[C]' - function require - function require",
"'[C]' - function require",
)
}

View file

@ -2,7 +2,7 @@ use std::future::Future;
use mlua::prelude::*;
use crate::lua::ext::LuaAsyncExt;
use crate::lua::async_ext::LuaAsyncExt;
pub struct TableBuilder {
lua: &'static Lua,

View file

@ -1,3 +1,6 @@
local PORT = 9090 -- NOTE: This must be different from
-- net tests to let them run in parallel with this file
local function test(f, ...)
local success, message = pcall(f, ...)
assert(not success, "Function did not throw an error")
@ -14,7 +17,7 @@ test(net.request, "https://wxyz.google.com")
-- Net serve is async and will throw an OS error when trying to serve twice on the same port
local handle = net.serve(8080, function()
local handle = net.serve(PORT, function()
return ""
end)
@ -22,18 +25,4 @@ task.delay(0, function()
handle.stop()
end)
test(net.serve, 8080, function() end)
local function e()
task.spawn(function()
task.defer(function()
task.delay(0, function()
error({
Hello = "World",
})
end)
end)
end)
end
task.defer(e)
test(net.serve, PORT, function() end)

View file

@ -4,6 +4,8 @@ assert(type(module) == "table", "Required module did not return a table")
assert(module.Foo == "Bar", "Required module did not contain correct values")
assert(module.Hello == "World", "Required module did not contain correct values")
require("modules/module")
module = require("modules/module")
assert(module.Foo == "Bar", "Required module did not contain correct values")
assert(module.Hello == "World", "Required module did not contain correct values")
return true

View file

@ -5,7 +5,7 @@ local function test(path: string)
if success then
error(string.format("Invalid require at path '%s' succeeded", path))
else
print(message)
message = tostring(message)
if string.find(message, string.format("%s'", path)) == nil then
error(
string.format(

View file

@ -4,4 +4,6 @@ assert(type(module) == "table", "Required module did not return a table")
assert(module.Foo == "Bar", "Required module did not contain correct values")
assert(module.Hello == "World", "Required module did not contain correct values")
require("modules/nested")
module = require("modules/nested")
assert(module.Foo == "Bar", "Required module did not contain correct values")
assert(module.Hello == "World", "Required module did not contain correct values")

View file

@ -15,7 +15,7 @@ function util.fail(method, url, message)
method = method,
url = url,
})
if not response.ok then
if response.ok then
error(string.format("%s passed!\nResponse: %s", message, stdio.format(response)))
end
end