diff --git a/packages/lib/src/globals/task.rs b/packages/lib/src/globals/task.rs index 07127a1..f2f3bff 100644 --- a/packages/lib/src/globals/task.rs +++ b/packages/lib/src/globals/task.rs @@ -8,28 +8,31 @@ use crate::{ utils::table::TableBuilder, }; -const ERR_MISSING_SCHEDULER: &str = "Missing task scheduler - make sure it is added as a lua app data before the first scheduler resumption"; - -const TASK_SPAWN_IMPL_LUA: &str = r#" --- Schedule the current thread at the front -scheduleNext(thread()) --- Schedule the wanted task arg at the front, --- the previous schedule now comes right after -local task = scheduleNext(...) --- Give control over to the scheduler, which will --- resume the above tasks in order when its ready -yield() -return task -"#; - pub fn create(lua: &'static Lua) -> LuaResult> { - // The spawn function needs special treatment, - // we need to yield right away to allow the - // spawned task to run until first yield + lua.app_data_ref::<&TaskScheduler>() + .expect("Missing task scheduler in app data"); + /* + 1. Schedule the current thread at the front + 2. Schedule the wanted task arg at the front, + the previous schedule now comes right after + 3. Give control over to the scheduler, which will + resume the above tasks in order when its ready + + The spawn function needs special treatment, + we need to yield right away to allow the + spawned task to run until first yield + */ let task_spawn_env_thread: LuaFunction = lua.named_registry_value("co.thread")?; let task_spawn_env_yield: LuaFunction = lua.named_registry_value("co.yield")?; let task_spawn = lua - .load(TASK_SPAWN_IMPL_LUA) + .load( + " + scheduleNext(thread()) + local task = scheduleNext(...) + yield() + return task + ", + ) .set_name("=task.spawn")? .set_environment( TableBuilder::new(lua)? @@ -37,11 +40,9 @@ pub fn create(lua: &'static Lua) -> LuaResult> { .with_value("yield", task_spawn_env_yield)? .with_function( "scheduleNext", - |lua, (tof, args): (LuaValue, LuaMultiValue)| { - let sched = lua - .app_data_ref::<&TaskScheduler>() - .expect(ERR_MISSING_SCHEDULER); - sched.schedule_blocking(tof, args) + |lua, (tof, args): (LuaThreadOrFunction, LuaMultiValue)| { + let sched = lua.app_data_ref::<&TaskScheduler>().unwrap(); + sched.schedule_blocking(tof.into_thread(lua)?, args) }, )? .build_readonly()?, @@ -50,70 +51,16 @@ pub fn create(lua: &'static Lua) -> LuaResult> { // We want the task scheduler to be transparent, // but it does not return real lua threads, so // we need to override some globals to fake it - let type_original: LuaFunction = lua.named_registry_value("type")?; - let type_proxy = lua.create_function(move |_, value: LuaValue| { - if let LuaValue::UserData(u) = &value { - if u.is::() { - return Ok(LuaValue::String(lua.create_string("thread")?)); - } - } - type_original.call(value) - })?; - let typeof_original: LuaFunction = lua.named_registry_value("typeof")?; - let typeof_proxy = lua.create_function(move |_, value: LuaValue| { - if let LuaValue::UserData(u) = &value { - if u.is::() { - return Ok(LuaValue::String(lua.create_string("thread")?)); - } - } - typeof_original.call(value) - })?; let globals = lua.globals(); - globals.set("type", type_proxy)?; - globals.set("typeof", typeof_proxy)?; + globals.set("type", lua.create_function(proxy_type)?)?; + globals.set("typeof", lua.create_function(proxy_typeof)?)?; // Functions in the built-in coroutine library also need to be // replaced, these are a bit different than the ones above because // calling resume or the function that wrap returns must return // whatever lua value(s) that the thread or task yielded back let coroutine = globals.get::<_, LuaTable>("coroutine")?; - coroutine.set( - "resume", - lua.create_function(|lua, value: LuaValue| { - let tname = value.type_name(); - if let LuaValue::Thread(thread) = value { - let sched = lua - .app_data_ref::<&TaskScheduler>() - .expect(ERR_MISSING_SCHEDULER); - let task = - sched.create_task(TaskKind::Instant, LuaValue::Thread(thread), None, None)?; - sched.resume_task(task, None) - } else if let Ok(task) = TaskReference::from_lua(value, lua) { - lua.app_data_ref::<&TaskScheduler>() - .expect(ERR_MISSING_SCHEDULER) - .resume_task(task, None) - } else { - Err(LuaError::RuntimeError(format!( - "Argument #1 must be a thread, got {tname}", - ))) - } - })?, - )?; - coroutine.set( - "wrap", - lua.create_function(|lua, func: LuaFunction| { - let sched = lua - .app_data_ref::<&TaskScheduler>() - .expect(ERR_MISSING_SCHEDULER); - let task = - sched.create_task(TaskKind::Instant, LuaValue::Function(func), None, None)?; - lua.create_function(move |lua, args: LuaMultiValue| { - let sched = lua - .app_data_ref::<&TaskScheduler>() - .expect(ERR_MISSING_SCHEDULER); - sched.resume_task(task, Some(Ok(args))) - }) - })?, - )?; + coroutine.set("resume", lua.create_function(coroutine_resume)?)?; + coroutine.set("wrap", lua.create_function(coroutine_wrap)?)?; // All good, return the task scheduler lib TableBuilder::new(lua)? .with_value("spawn", task_spawn)? @@ -124,29 +71,99 @@ pub fn create(lua: &'static Lua) -> LuaResult> { .build_readonly() } +/* + Proxy enum to deal with both threads & functions +*/ + +enum LuaThreadOrFunction<'lua> { + Thread(LuaThread<'lua>), + Function(LuaFunction<'lua>), +} + +impl<'lua> LuaThreadOrFunction<'lua> { + fn into_thread(self, lua: &'lua Lua) -> LuaResult> { + match self { + Self::Thread(t) => Ok(t), + Self::Function(f) => lua.create_thread(f), + } + } +} + +impl<'lua> FromLua<'lua> for LuaThreadOrFunction<'lua> { + fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult { + match value { + LuaValue::Thread(t) => Ok(Self::Thread(t)), + LuaValue::Function(f) => Ok(Self::Function(f)), + value => Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "LuaThreadOrFunction", + message: Some(format!( + "Expected thread or function, got '{}'", + value.type_name() + )), + }), + } + } +} + +/* + Proxy enum to deal with both threads & task scheduler task references +*/ + +enum LuaThreadOrTaskReference<'lua> { + Thread(LuaThread<'lua>), + TaskReference(TaskReference), +} + +impl<'lua> FromLua<'lua> for LuaThreadOrTaskReference<'lua> { + fn from_lua(value: LuaValue<'lua>, lua: &'lua Lua) -> LuaResult { + let tname = value.type_name(); + match value { + LuaValue::Thread(t) => Ok(Self::Thread(t)), + LuaValue::UserData(u) => { + if let Ok(task) = TaskReference::from_lua(LuaValue::UserData(u), lua) { + Ok(Self::TaskReference(task)) + } else { + Err(LuaError::FromLuaConversionError { + from: tname, + to: "thread", + message: Some(format!("Expected thread, got '{tname}'")), + }) + } + } + _ => Err(LuaError::FromLuaConversionError { + from: tname, + to: "thread", + message: Some(format!("Expected thread, got '{tname}'")), + }), + } + } +} + +/* + Basic task functions +*/ + fn task_cancel(lua: &Lua, task: TaskReference) -> LuaResult<()> { - let sched = lua - .app_data_ref::<&TaskScheduler>() - .expect(ERR_MISSING_SCHEDULER); + let sched = lua.app_data_ref::<&TaskScheduler>().unwrap(); sched.remove_task(task)?; Ok(()) } -fn task_defer(lua: &Lua, (tof, args): (LuaValue, LuaMultiValue)) -> LuaResult { - let sched = lua - .app_data_ref::<&TaskScheduler>() - .expect(ERR_MISSING_SCHEDULER); - sched.schedule_blocking_deferred(tof, args) +fn task_defer( + lua: &Lua, + (tof, args): (LuaThreadOrFunction, LuaMultiValue), +) -> LuaResult { + let sched = lua.app_data_ref::<&TaskScheduler>().unwrap(); + sched.schedule_blocking_deferred(tof.into_thread(lua)?, args) } fn task_delay( lua: &Lua, - (secs, tof, args): (f64, LuaValue, LuaMultiValue), + (secs, tof, args): (f64, LuaThreadOrFunction, LuaMultiValue), ) -> LuaResult { - let sched = lua - .app_data_ref::<&TaskScheduler>() - .expect(ERR_MISSING_SCHEDULER); - sched.schedule_blocking_after_seconds(secs, tof, args) + let sched = lua.app_data_ref::<&TaskScheduler>().unwrap(); + sched.schedule_blocking_after_seconds(secs, tof.into_thread(lua)?, args) } async fn task_wait(_: &Lua, secs: Option) -> LuaResult { @@ -154,3 +171,62 @@ async fn task_wait(_: &Lua, secs: Option) -> LuaResult { sleep(Duration::from_secs_f64(secs.unwrap_or_default())).await; Ok(start.elapsed().as_secs_f64()) } + +/* + Type getter overrides for compat with task scheduler +*/ + +fn proxy_type<'lua>(lua: &'lua Lua, value: LuaValue<'lua>) -> LuaResult> { + if let LuaValue::UserData(u) = &value { + if u.is::() { + return lua.create_string("thread"); + } + } + lua.named_registry_value::<_, LuaFunction>("type")? + .call(value) +} + +fn proxy_typeof<'lua>(lua: &'lua Lua, value: LuaValue<'lua>) -> LuaResult> { + if let LuaValue::UserData(u) = &value { + if u.is::() { + return lua.create_string("thread"); + } + } + lua.named_registry_value::<_, LuaFunction>("typeof")? + .call(value) +} + +/* + Coroutine library overrides for compat with task scheduler +*/ + +fn coroutine_resume<'lua>( + lua: &'lua Lua, + value: LuaThreadOrTaskReference, +) -> LuaResult> { + match value { + LuaThreadOrTaskReference::Thread(t) => { + let sched = lua.app_data_ref::<&TaskScheduler>().unwrap(); + let task = sched.create_task(TaskKind::Instant, t, None, None)?; + sched.resume_task(task, None) + } + LuaThreadOrTaskReference::TaskReference(t) => lua + .app_data_ref::<&TaskScheduler>() + .unwrap() + .resume_task(t, None), + } +} + +fn coroutine_wrap<'lua>(lua: &'lua Lua, func: LuaFunction) -> LuaResult> { + let task = lua.app_data_ref::<&TaskScheduler>().unwrap().create_task( + TaskKind::Instant, + lua.create_thread(func)?, + None, + None, + )?; + lua.create_function(move |lua, args: LuaMultiValue| { + lua.app_data_ref::<&TaskScheduler>() + .unwrap() + .resume_task(task, Some(Ok(args))) + }) +} diff --git a/packages/lib/src/lib.rs b/packages/lib/src/lib.rs index 8234428..8432347 100644 --- a/packages/lib/src/lib.rs +++ b/packages/lib/src/lib.rs @@ -91,19 +91,19 @@ impl Lune { ) -> Result { // Create our special lune-flavored Lua object with extra registry values let lua = create_lune_lua().expect("Failed to create Lua object"); - // Create our task scheduler and schedule the main thread on it + // Create our task scheduler let sched = TaskScheduler::new(lua)?.into_static(); lua.set_app_data(sched); - sched.schedule_blocking( - LuaValue::Function( - lua.load(script_contents) - .set_name(script_name) - .unwrap() - .into_function() - .unwrap(), - ), - LuaValue::Nil.to_lua_multi(lua)?, - )?; + // Create the main thread and schedule it + let main_chunk = lua + .load(script_contents) + .set_name(script_name) + .unwrap() + .into_function() + .unwrap(); + let main_thread = lua.create_thread(main_chunk).unwrap(); + let main_thread_args = LuaValue::Nil.to_lua_multi(lua)?; + sched.schedule_blocking(main_thread, main_thread_args)?; // Create our wanted lune globals, some of these need // the task scheduler be available during construction for global in self.includes.clone() { diff --git a/packages/lib/src/lua/ext.rs b/packages/lib/src/lua/ext.rs index ce715ef..940e14e 100644 --- a/packages/lib/src/lua/ext.rs +++ b/packages/lib/src/lua/ext.rs @@ -38,7 +38,7 @@ impl LuaAsyncExt for &'static Lua { let sched = lua .app_data_ref::<&TaskScheduler>() .expect("Missing task scheduler as a lua app data"); - sched.queue_async_task(LuaValue::Thread(thread), None, None, async { + sched.queue_async_task(thread, None, None, async { let rets = fut.await?; let mult = rets.to_lua_multi(lua)?; Ok(Some(mult)) diff --git a/packages/lib/src/lua/task/ext/async_ext.rs b/packages/lib/src/lua/task/ext/async_ext.rs index 97223bd..deaf59b 100644 --- a/packages/lib/src/lua/task/ext/async_ext.rs +++ b/packages/lib/src/lua/task/ext/async_ext.rs @@ -22,7 +22,7 @@ pub trait TaskSchedulerAsyncExt<'fut> { fn schedule_async<'sched, R, F, FR>( &'sched self, - thread_or_function: LuaValue<'_>, + thread: LuaThread<'_>, func: F, ) -> LuaResult where @@ -73,7 +73,7 @@ impl<'fut> TaskSchedulerAsyncExt<'fut> for TaskScheduler<'fut> { */ fn schedule_async<'sched, R, F, FR>( &'sched self, - thread_or_function: LuaValue<'_>, + thread: LuaThread<'_>, func: F, ) -> LuaResult where @@ -82,7 +82,7 @@ impl<'fut> TaskSchedulerAsyncExt<'fut> for TaskScheduler<'fut> { F: 'static + Fn(&'static Lua) -> FR, FR: 'static + Future>, { - self.queue_async_task(thread_or_function, None, None, async move { + self.queue_async_task(thread, None, None, async move { match func(self.lua).await { Ok(res) => match res.to_lua_multi(self.lua) { Ok(multi) => Ok(Some(multi)), diff --git a/packages/lib/src/lua/task/ext/schedule_ext.rs b/packages/lib/src/lua/task/ext/schedule_ext.rs index 9455a53..57f7da9 100644 --- a/packages/lib/src/lua/task/ext/schedule_ext.rs +++ b/packages/lib/src/lua/task/ext/schedule_ext.rs @@ -16,20 +16,20 @@ use super::super::{scheduler::TaskKind, scheduler::TaskReference, scheduler::Tas pub trait TaskSchedulerScheduleExt { fn schedule_blocking( &self, - thread_or_function: LuaValue<'_>, + thread: LuaThread<'_>, thread_args: LuaMultiValue<'_>, ) -> LuaResult; fn schedule_blocking_deferred( &self, - thread_or_function: LuaValue<'_>, + thread: LuaThread<'_>, thread_args: LuaMultiValue<'_>, ) -> LuaResult; fn schedule_blocking_after_seconds( &self, after_secs: f64, - thread_or_function: LuaValue<'_>, + thread: LuaThread<'_>, thread_args: LuaMultiValue<'_>, ) -> LuaResult; } @@ -49,15 +49,10 @@ impl TaskSchedulerScheduleExt for TaskScheduler<'_> { */ fn schedule_blocking( &self, - thread_or_function: LuaValue<'_>, + thread: LuaThread<'_>, thread_args: LuaMultiValue<'_>, ) -> LuaResult { - self.queue_blocking_task( - TaskKind::Instant, - thread_or_function, - Some(thread_args), - None, - ) + self.queue_blocking_task(TaskKind::Instant, thread, Some(thread_args), None) } /** @@ -69,15 +64,10 @@ impl TaskSchedulerScheduleExt for TaskScheduler<'_> { */ fn schedule_blocking_deferred( &self, - thread_or_function: LuaValue<'_>, + thread: LuaThread<'_>, thread_args: LuaMultiValue<'_>, ) -> LuaResult { - self.queue_blocking_task( - TaskKind::Deferred, - thread_or_function, - Some(thread_args), - None, - ) + self.queue_blocking_task(TaskKind::Deferred, thread, Some(thread_args), None) } /** @@ -90,10 +80,10 @@ impl TaskSchedulerScheduleExt for TaskScheduler<'_> { fn schedule_blocking_after_seconds( &self, after_secs: f64, - thread_or_function: LuaValue<'_>, + thread: LuaThread<'_>, thread_args: LuaMultiValue<'_>, ) -> LuaResult { - self.queue_async_task(thread_or_function, Some(thread_args), None, async move { + self.queue_async_task(thread, Some(thread_args), None, async move { sleep(Duration::from_secs_f64(after_secs)).await; Ok(None) }) diff --git a/packages/lib/src/lua/task/scheduler.rs b/packages/lib/src/lua/task/scheduler.rs index 9af693a..a38a3b0 100644 --- a/packages/lib/src/lua/task/scheduler.rs +++ b/packages/lib/src/lua/task/scheduler.rs @@ -121,27 +121,16 @@ impl<'fut> TaskScheduler<'fut> { pub fn create_task( &self, kind: TaskKind, - thread_or_function: LuaValue<'_>, + thread: LuaThread<'_>, thread_args: Option>, guid_to_reuse: Option, ) -> LuaResult { - // Get or create a thread from the given argument - let task_thread = match thread_or_function { - LuaValue::Thread(t) => t, - LuaValue::Function(f) => self.lua.create_thread(f)?, - value => { - return Err(LuaError::RuntimeError(format!( - "Argument must be a thread or function, got {}", - value.type_name() - ))) - } - }; // Store the thread and its arguments in the registry // NOTE: We must convert to a vec since multis // can't be stored in the registry directly let task_args_vec: Option> = thread_args.map(|opt| opt.into_vec()); let task_args_key: LuaRegistryKey = self.lua.create_registry_value(task_args_vec)?; - let task_thread_key: LuaRegistryKey = self.lua.create_registry_value(task_thread)?; + let task_thread_key: LuaRegistryKey = self.lua.create_registry_value(thread)?; // Create the full task struct let task = Task { thread: task_thread_key, @@ -264,14 +253,14 @@ impl<'fut> TaskScheduler<'fut> { pub(crate) fn queue_blocking_task( &self, kind: TaskKind, - thread_or_function: LuaValue<'_>, + thread: LuaThread<'_>, thread_args: Option>, guid_to_reuse: Option, ) -> LuaResult { if kind == TaskKind::Future { panic!("Tried to schedule future using normal task schedule method") } - let task_ref = self.create_task(kind, thread_or_function, thread_args, guid_to_reuse)?; + let task_ref = self.create_task(kind, thread, thread_args, guid_to_reuse)?; // Add the task to the front of the queue, unless it // should be deferred, in that case add it to the back let mut queue = self.tasks_queue_blocking.borrow_mut(); @@ -303,17 +292,12 @@ impl<'fut> TaskScheduler<'fut> { */ pub(crate) fn queue_async_task( &self, - thread_or_function: LuaValue<'_>, + thread: LuaThread<'_>, thread_args: Option>, guid_to_reuse: Option, fut: impl Future> + 'fut, ) -> LuaResult { - let task_ref = self.create_task( - TaskKind::Future, - thread_or_function, - thread_args, - guid_to_reuse, - )?; + let task_ref = self.create_task(TaskKind::Future, thread, thread_args, guid_to_reuse)?; let futs = self .futures .try_lock()