Only use threads and nothing else to simplify task scheduler

This commit is contained in:
Filip Tibell 2023-02-17 00:08:24 +01:00
parent 546ebbd349
commit 7f17ab0063
No known key found for this signature in database
6 changed files with 200 additions and 150 deletions

View file

@ -8,28 +8,31 @@ use crate::{
utils::table::TableBuilder, utils::table::TableBuilder,
}; };
const ERR_MISSING_SCHEDULER: &str = "Missing task scheduler - make sure it is added as a lua app data before the first scheduler resumption";
const TASK_SPAWN_IMPL_LUA: &str = r#"
-- Schedule the current thread at the front
scheduleNext(thread())
-- Schedule the wanted task arg at the front,
-- the previous schedule now comes right after
local task = scheduleNext(...)
-- Give control over to the scheduler, which will
-- resume the above tasks in order when its ready
yield()
return task
"#;
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable<'static>> { pub fn create(lua: &'static Lua) -> LuaResult<LuaTable<'static>> {
// The spawn function needs special treatment, lua.app_data_ref::<&TaskScheduler>()
// we need to yield right away to allow the .expect("Missing task scheduler in app data");
// spawned task to run until first yield /*
1. Schedule the current thread at the front
2. Schedule the wanted task arg at the front,
the previous schedule now comes right after
3. Give control over to the scheduler, which will
resume the above tasks in order when its ready
The spawn function needs special treatment,
we need to yield right away to allow the
spawned task to run until first yield
*/
let task_spawn_env_thread: LuaFunction = lua.named_registry_value("co.thread")?; let task_spawn_env_thread: LuaFunction = lua.named_registry_value("co.thread")?;
let task_spawn_env_yield: LuaFunction = lua.named_registry_value("co.yield")?; let task_spawn_env_yield: LuaFunction = lua.named_registry_value("co.yield")?;
let task_spawn = lua let task_spawn = lua
.load(TASK_SPAWN_IMPL_LUA) .load(
"
scheduleNext(thread())
local task = scheduleNext(...)
yield()
return task
",
)
.set_name("=task.spawn")? .set_name("=task.spawn")?
.set_environment( .set_environment(
TableBuilder::new(lua)? TableBuilder::new(lua)?
@ -37,11 +40,9 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable<'static>> {
.with_value("yield", task_spawn_env_yield)? .with_value("yield", task_spawn_env_yield)?
.with_function( .with_function(
"scheduleNext", "scheduleNext",
|lua, (tof, args): (LuaValue, LuaMultiValue)| { |lua, (tof, args): (LuaThreadOrFunction, LuaMultiValue)| {
let sched = lua let sched = lua.app_data_ref::<&TaskScheduler>().unwrap();
.app_data_ref::<&TaskScheduler>() sched.schedule_blocking(tof.into_thread(lua)?, args)
.expect(ERR_MISSING_SCHEDULER);
sched.schedule_blocking(tof, args)
}, },
)? )?
.build_readonly()?, .build_readonly()?,
@ -50,70 +51,16 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable<'static>> {
// We want the task scheduler to be transparent, // We want the task scheduler to be transparent,
// but it does not return real lua threads, so // but it does not return real lua threads, so
// we need to override some globals to fake it // we need to override some globals to fake it
let type_original: LuaFunction = lua.named_registry_value("type")?;
let type_proxy = lua.create_function(move |_, value: LuaValue| {
if let LuaValue::UserData(u) = &value {
if u.is::<TaskReference>() {
return Ok(LuaValue::String(lua.create_string("thread")?));
}
}
type_original.call(value)
})?;
let typeof_original: LuaFunction = lua.named_registry_value("typeof")?;
let typeof_proxy = lua.create_function(move |_, value: LuaValue| {
if let LuaValue::UserData(u) = &value {
if u.is::<TaskReference>() {
return Ok(LuaValue::String(lua.create_string("thread")?));
}
}
typeof_original.call(value)
})?;
let globals = lua.globals(); let globals = lua.globals();
globals.set("type", type_proxy)?; globals.set("type", lua.create_function(proxy_type)?)?;
globals.set("typeof", typeof_proxy)?; globals.set("typeof", lua.create_function(proxy_typeof)?)?;
// Functions in the built-in coroutine library also need to be // Functions in the built-in coroutine library also need to be
// replaced, these are a bit different than the ones above because // replaced, these are a bit different than the ones above because
// calling resume or the function that wrap returns must return // calling resume or the function that wrap returns must return
// whatever lua value(s) that the thread or task yielded back // whatever lua value(s) that the thread or task yielded back
let coroutine = globals.get::<_, LuaTable>("coroutine")?; let coroutine = globals.get::<_, LuaTable>("coroutine")?;
coroutine.set( coroutine.set("resume", lua.create_function(coroutine_resume)?)?;
"resume", coroutine.set("wrap", lua.create_function(coroutine_wrap)?)?;
lua.create_function(|lua, value: LuaValue| {
let tname = value.type_name();
if let LuaValue::Thread(thread) = value {
let sched = lua
.app_data_ref::<&TaskScheduler>()
.expect(ERR_MISSING_SCHEDULER);
let task =
sched.create_task(TaskKind::Instant, LuaValue::Thread(thread), None, None)?;
sched.resume_task(task, None)
} else if let Ok(task) = TaskReference::from_lua(value, lua) {
lua.app_data_ref::<&TaskScheduler>()
.expect(ERR_MISSING_SCHEDULER)
.resume_task(task, None)
} else {
Err(LuaError::RuntimeError(format!(
"Argument #1 must be a thread, got {tname}",
)))
}
})?,
)?;
coroutine.set(
"wrap",
lua.create_function(|lua, func: LuaFunction| {
let sched = lua
.app_data_ref::<&TaskScheduler>()
.expect(ERR_MISSING_SCHEDULER);
let task =
sched.create_task(TaskKind::Instant, LuaValue::Function(func), None, None)?;
lua.create_function(move |lua, args: LuaMultiValue| {
let sched = lua
.app_data_ref::<&TaskScheduler>()
.expect(ERR_MISSING_SCHEDULER);
sched.resume_task(task, Some(Ok(args)))
})
})?,
)?;
// All good, return the task scheduler lib // All good, return the task scheduler lib
TableBuilder::new(lua)? TableBuilder::new(lua)?
.with_value("spawn", task_spawn)? .with_value("spawn", task_spawn)?
@ -124,29 +71,99 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable<'static>> {
.build_readonly() .build_readonly()
} }
/*
Proxy enum to deal with both threads & functions
*/
enum LuaThreadOrFunction<'lua> {
Thread(LuaThread<'lua>),
Function(LuaFunction<'lua>),
}
impl<'lua> LuaThreadOrFunction<'lua> {
fn into_thread(self, lua: &'lua Lua) -> LuaResult<LuaThread<'lua>> {
match self {
Self::Thread(t) => Ok(t),
Self::Function(f) => lua.create_thread(f),
}
}
}
impl<'lua> FromLua<'lua> for LuaThreadOrFunction<'lua> {
fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult<Self> {
match value {
LuaValue::Thread(t) => Ok(Self::Thread(t)),
LuaValue::Function(f) => Ok(Self::Function(f)),
value => Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "LuaThreadOrFunction",
message: Some(format!(
"Expected thread or function, got '{}'",
value.type_name()
)),
}),
}
}
}
/*
Proxy enum to deal with both threads & task scheduler task references
*/
enum LuaThreadOrTaskReference<'lua> {
Thread(LuaThread<'lua>),
TaskReference(TaskReference),
}
impl<'lua> FromLua<'lua> for LuaThreadOrTaskReference<'lua> {
fn from_lua(value: LuaValue<'lua>, lua: &'lua Lua) -> LuaResult<Self> {
let tname = value.type_name();
match value {
LuaValue::Thread(t) => Ok(Self::Thread(t)),
LuaValue::UserData(u) => {
if let Ok(task) = TaskReference::from_lua(LuaValue::UserData(u), lua) {
Ok(Self::TaskReference(task))
} else {
Err(LuaError::FromLuaConversionError {
from: tname,
to: "thread",
message: Some(format!("Expected thread, got '{tname}'")),
})
}
}
_ => Err(LuaError::FromLuaConversionError {
from: tname,
to: "thread",
message: Some(format!("Expected thread, got '{tname}'")),
}),
}
}
}
/*
Basic task functions
*/
fn task_cancel(lua: &Lua, task: TaskReference) -> LuaResult<()> { fn task_cancel(lua: &Lua, task: TaskReference) -> LuaResult<()> {
let sched = lua let sched = lua.app_data_ref::<&TaskScheduler>().unwrap();
.app_data_ref::<&TaskScheduler>()
.expect(ERR_MISSING_SCHEDULER);
sched.remove_task(task)?; sched.remove_task(task)?;
Ok(()) Ok(())
} }
fn task_defer(lua: &Lua, (tof, args): (LuaValue, LuaMultiValue)) -> LuaResult<TaskReference> { fn task_defer(
let sched = lua lua: &Lua,
.app_data_ref::<&TaskScheduler>() (tof, args): (LuaThreadOrFunction, LuaMultiValue),
.expect(ERR_MISSING_SCHEDULER); ) -> LuaResult<TaskReference> {
sched.schedule_blocking_deferred(tof, args) let sched = lua.app_data_ref::<&TaskScheduler>().unwrap();
sched.schedule_blocking_deferred(tof.into_thread(lua)?, args)
} }
fn task_delay( fn task_delay(
lua: &Lua, lua: &Lua,
(secs, tof, args): (f64, LuaValue, LuaMultiValue), (secs, tof, args): (f64, LuaThreadOrFunction, LuaMultiValue),
) -> LuaResult<TaskReference> { ) -> LuaResult<TaskReference> {
let sched = lua let sched = lua.app_data_ref::<&TaskScheduler>().unwrap();
.app_data_ref::<&TaskScheduler>() sched.schedule_blocking_after_seconds(secs, tof.into_thread(lua)?, args)
.expect(ERR_MISSING_SCHEDULER);
sched.schedule_blocking_after_seconds(secs, tof, args)
} }
async fn task_wait(_: &Lua, secs: Option<f64>) -> LuaResult<f64> { async fn task_wait(_: &Lua, secs: Option<f64>) -> LuaResult<f64> {
@ -154,3 +171,62 @@ async fn task_wait(_: &Lua, secs: Option<f64>) -> LuaResult<f64> {
sleep(Duration::from_secs_f64(secs.unwrap_or_default())).await; sleep(Duration::from_secs_f64(secs.unwrap_or_default())).await;
Ok(start.elapsed().as_secs_f64()) Ok(start.elapsed().as_secs_f64())
} }
/*
Type getter overrides for compat with task scheduler
*/
fn proxy_type<'lua>(lua: &'lua Lua, value: LuaValue<'lua>) -> LuaResult<LuaString<'lua>> {
if let LuaValue::UserData(u) = &value {
if u.is::<TaskReference>() {
return lua.create_string("thread");
}
}
lua.named_registry_value::<_, LuaFunction>("type")?
.call(value)
}
fn proxy_typeof<'lua>(lua: &'lua Lua, value: LuaValue<'lua>) -> LuaResult<LuaString<'lua>> {
if let LuaValue::UserData(u) = &value {
if u.is::<TaskReference>() {
return lua.create_string("thread");
}
}
lua.named_registry_value::<_, LuaFunction>("typeof")?
.call(value)
}
/*
Coroutine library overrides for compat with task scheduler
*/
fn coroutine_resume<'lua>(
lua: &'lua Lua,
value: LuaThreadOrTaskReference,
) -> LuaResult<LuaMultiValue<'lua>> {
match value {
LuaThreadOrTaskReference::Thread(t) => {
let sched = lua.app_data_ref::<&TaskScheduler>().unwrap();
let task = sched.create_task(TaskKind::Instant, t, None, None)?;
sched.resume_task(task, None)
}
LuaThreadOrTaskReference::TaskReference(t) => lua
.app_data_ref::<&TaskScheduler>()
.unwrap()
.resume_task(t, None),
}
}
fn coroutine_wrap<'lua>(lua: &'lua Lua, func: LuaFunction) -> LuaResult<LuaFunction<'lua>> {
let task = lua.app_data_ref::<&TaskScheduler>().unwrap().create_task(
TaskKind::Instant,
lua.create_thread(func)?,
None,
None,
)?;
lua.create_function(move |lua, args: LuaMultiValue| {
lua.app_data_ref::<&TaskScheduler>()
.unwrap()
.resume_task(task, Some(Ok(args)))
})
}

View file

@ -91,19 +91,19 @@ impl Lune {
) -> Result<ExitCode, LuaError> { ) -> Result<ExitCode, LuaError> {
// Create our special lune-flavored Lua object with extra registry values // Create our special lune-flavored Lua object with extra registry values
let lua = create_lune_lua().expect("Failed to create Lua object"); let lua = create_lune_lua().expect("Failed to create Lua object");
// Create our task scheduler and schedule the main thread on it // Create our task scheduler
let sched = TaskScheduler::new(lua)?.into_static(); let sched = TaskScheduler::new(lua)?.into_static();
lua.set_app_data(sched); lua.set_app_data(sched);
sched.schedule_blocking( // Create the main thread and schedule it
LuaValue::Function( let main_chunk = lua
lua.load(script_contents) .load(script_contents)
.set_name(script_name) .set_name(script_name)
.unwrap() .unwrap()
.into_function() .into_function()
.unwrap(), .unwrap();
), let main_thread = lua.create_thread(main_chunk).unwrap();
LuaValue::Nil.to_lua_multi(lua)?, let main_thread_args = LuaValue::Nil.to_lua_multi(lua)?;
)?; sched.schedule_blocking(main_thread, main_thread_args)?;
// Create our wanted lune globals, some of these need // Create our wanted lune globals, some of these need
// the task scheduler be available during construction // the task scheduler be available during construction
for global in self.includes.clone() { for global in self.includes.clone() {

View file

@ -38,7 +38,7 @@ impl LuaAsyncExt for &'static Lua {
let sched = lua let sched = lua
.app_data_ref::<&TaskScheduler>() .app_data_ref::<&TaskScheduler>()
.expect("Missing task scheduler as a lua app data"); .expect("Missing task scheduler as a lua app data");
sched.queue_async_task(LuaValue::Thread(thread), None, None, async { sched.queue_async_task(thread, None, None, async {
let rets = fut.await?; let rets = fut.await?;
let mult = rets.to_lua_multi(lua)?; let mult = rets.to_lua_multi(lua)?;
Ok(Some(mult)) Ok(Some(mult))

View file

@ -22,7 +22,7 @@ pub trait TaskSchedulerAsyncExt<'fut> {
fn schedule_async<'sched, R, F, FR>( fn schedule_async<'sched, R, F, FR>(
&'sched self, &'sched self,
thread_or_function: LuaValue<'_>, thread: LuaThread<'_>,
func: F, func: F,
) -> LuaResult<TaskReference> ) -> LuaResult<TaskReference>
where where
@ -73,7 +73,7 @@ impl<'fut> TaskSchedulerAsyncExt<'fut> for TaskScheduler<'fut> {
*/ */
fn schedule_async<'sched, R, F, FR>( fn schedule_async<'sched, R, F, FR>(
&'sched self, &'sched self,
thread_or_function: LuaValue<'_>, thread: LuaThread<'_>,
func: F, func: F,
) -> LuaResult<TaskReference> ) -> LuaResult<TaskReference>
where where
@ -82,7 +82,7 @@ impl<'fut> TaskSchedulerAsyncExt<'fut> for TaskScheduler<'fut> {
F: 'static + Fn(&'static Lua) -> FR, F: 'static + Fn(&'static Lua) -> FR,
FR: 'static + Future<Output = LuaResult<R>>, FR: 'static + Future<Output = LuaResult<R>>,
{ {
self.queue_async_task(thread_or_function, None, None, async move { self.queue_async_task(thread, None, None, async move {
match func(self.lua).await { match func(self.lua).await {
Ok(res) => match res.to_lua_multi(self.lua) { Ok(res) => match res.to_lua_multi(self.lua) {
Ok(multi) => Ok(Some(multi)), Ok(multi) => Ok(Some(multi)),

View file

@ -16,20 +16,20 @@ use super::super::{scheduler::TaskKind, scheduler::TaskReference, scheduler::Tas
pub trait TaskSchedulerScheduleExt { pub trait TaskSchedulerScheduleExt {
fn schedule_blocking( fn schedule_blocking(
&self, &self,
thread_or_function: LuaValue<'_>, thread: LuaThread<'_>,
thread_args: LuaMultiValue<'_>, thread_args: LuaMultiValue<'_>,
) -> LuaResult<TaskReference>; ) -> LuaResult<TaskReference>;
fn schedule_blocking_deferred( fn schedule_blocking_deferred(
&self, &self,
thread_or_function: LuaValue<'_>, thread: LuaThread<'_>,
thread_args: LuaMultiValue<'_>, thread_args: LuaMultiValue<'_>,
) -> LuaResult<TaskReference>; ) -> LuaResult<TaskReference>;
fn schedule_blocking_after_seconds( fn schedule_blocking_after_seconds(
&self, &self,
after_secs: f64, after_secs: f64,
thread_or_function: LuaValue<'_>, thread: LuaThread<'_>,
thread_args: LuaMultiValue<'_>, thread_args: LuaMultiValue<'_>,
) -> LuaResult<TaskReference>; ) -> LuaResult<TaskReference>;
} }
@ -49,15 +49,10 @@ impl TaskSchedulerScheduleExt for TaskScheduler<'_> {
*/ */
fn schedule_blocking( fn schedule_blocking(
&self, &self,
thread_or_function: LuaValue<'_>, thread: LuaThread<'_>,
thread_args: LuaMultiValue<'_>, thread_args: LuaMultiValue<'_>,
) -> LuaResult<TaskReference> { ) -> LuaResult<TaskReference> {
self.queue_blocking_task( self.queue_blocking_task(TaskKind::Instant, thread, Some(thread_args), None)
TaskKind::Instant,
thread_or_function,
Some(thread_args),
None,
)
} }
/** /**
@ -69,15 +64,10 @@ impl TaskSchedulerScheduleExt for TaskScheduler<'_> {
*/ */
fn schedule_blocking_deferred( fn schedule_blocking_deferred(
&self, &self,
thread_or_function: LuaValue<'_>, thread: LuaThread<'_>,
thread_args: LuaMultiValue<'_>, thread_args: LuaMultiValue<'_>,
) -> LuaResult<TaskReference> { ) -> LuaResult<TaskReference> {
self.queue_blocking_task( self.queue_blocking_task(TaskKind::Deferred, thread, Some(thread_args), None)
TaskKind::Deferred,
thread_or_function,
Some(thread_args),
None,
)
} }
/** /**
@ -90,10 +80,10 @@ impl TaskSchedulerScheduleExt for TaskScheduler<'_> {
fn schedule_blocking_after_seconds( fn schedule_blocking_after_seconds(
&self, &self,
after_secs: f64, after_secs: f64,
thread_or_function: LuaValue<'_>, thread: LuaThread<'_>,
thread_args: LuaMultiValue<'_>, thread_args: LuaMultiValue<'_>,
) -> LuaResult<TaskReference> { ) -> LuaResult<TaskReference> {
self.queue_async_task(thread_or_function, Some(thread_args), None, async move { self.queue_async_task(thread, Some(thread_args), None, async move {
sleep(Duration::from_secs_f64(after_secs)).await; sleep(Duration::from_secs_f64(after_secs)).await;
Ok(None) Ok(None)
}) })

View file

@ -121,27 +121,16 @@ impl<'fut> TaskScheduler<'fut> {
pub fn create_task( pub fn create_task(
&self, &self,
kind: TaskKind, kind: TaskKind,
thread_or_function: LuaValue<'_>, thread: LuaThread<'_>,
thread_args: Option<LuaMultiValue<'_>>, thread_args: Option<LuaMultiValue<'_>>,
guid_to_reuse: Option<usize>, guid_to_reuse: Option<usize>,
) -> LuaResult<TaskReference> { ) -> LuaResult<TaskReference> {
// Get or create a thread from the given argument
let task_thread = match thread_or_function {
LuaValue::Thread(t) => t,
LuaValue::Function(f) => self.lua.create_thread(f)?,
value => {
return Err(LuaError::RuntimeError(format!(
"Argument must be a thread or function, got {}",
value.type_name()
)))
}
};
// Store the thread and its arguments in the registry // Store the thread and its arguments in the registry
// NOTE: We must convert to a vec since multis // NOTE: We must convert to a vec since multis
// can't be stored in the registry directly // can't be stored in the registry directly
let task_args_vec: Option<Vec<LuaValue>> = thread_args.map(|opt| opt.into_vec()); let task_args_vec: Option<Vec<LuaValue>> = thread_args.map(|opt| opt.into_vec());
let task_args_key: LuaRegistryKey = self.lua.create_registry_value(task_args_vec)?; let task_args_key: LuaRegistryKey = self.lua.create_registry_value(task_args_vec)?;
let task_thread_key: LuaRegistryKey = self.lua.create_registry_value(task_thread)?; let task_thread_key: LuaRegistryKey = self.lua.create_registry_value(thread)?;
// Create the full task struct // Create the full task struct
let task = Task { let task = Task {
thread: task_thread_key, thread: task_thread_key,
@ -264,14 +253,14 @@ impl<'fut> TaskScheduler<'fut> {
pub(crate) fn queue_blocking_task( pub(crate) fn queue_blocking_task(
&self, &self,
kind: TaskKind, kind: TaskKind,
thread_or_function: LuaValue<'_>, thread: LuaThread<'_>,
thread_args: Option<LuaMultiValue<'_>>, thread_args: Option<LuaMultiValue<'_>>,
guid_to_reuse: Option<usize>, guid_to_reuse: Option<usize>,
) -> LuaResult<TaskReference> { ) -> LuaResult<TaskReference> {
if kind == TaskKind::Future { if kind == TaskKind::Future {
panic!("Tried to schedule future using normal task schedule method") panic!("Tried to schedule future using normal task schedule method")
} }
let task_ref = self.create_task(kind, thread_or_function, thread_args, guid_to_reuse)?; let task_ref = self.create_task(kind, thread, thread_args, guid_to_reuse)?;
// Add the task to the front of the queue, unless it // Add the task to the front of the queue, unless it
// should be deferred, in that case add it to the back // should be deferred, in that case add it to the back
let mut queue = self.tasks_queue_blocking.borrow_mut(); let mut queue = self.tasks_queue_blocking.borrow_mut();
@ -303,17 +292,12 @@ impl<'fut> TaskScheduler<'fut> {
*/ */
pub(crate) fn queue_async_task( pub(crate) fn queue_async_task(
&self, &self,
thread_or_function: LuaValue<'_>, thread: LuaThread<'_>,
thread_args: Option<LuaMultiValue<'_>>, thread_args: Option<LuaMultiValue<'_>>,
guid_to_reuse: Option<usize>, guid_to_reuse: Option<usize>,
fut: impl Future<Output = TaskFutureRets<'fut>> + 'fut, fut: impl Future<Output = TaskFutureRets<'fut>> + 'fut,
) -> LuaResult<TaskReference> { ) -> LuaResult<TaskReference> {
let task_ref = self.create_task( let task_ref = self.create_task(TaskKind::Future, thread, thread_args, guid_to_reuse)?;
TaskKind::Future,
thread_or_function,
thread_args,
guid_to_reuse,
)?;
let futs = self let futs = self
.futures .futures
.try_lock() .try_lock()