From c920d31126bf1e37b7ad89474d854b3892d69ea9 Mon Sep 17 00:00:00 2001 From: Filip Tibell Date: Wed, 17 Jan 2024 13:39:06 +0100 Subject: [PATCH] Expand async ext for creating non-send functions --- src/lua_ext.rs | 56 ++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 47 insertions(+), 9 deletions(-) diff --git a/src/lua_ext.rs b/src/lua_ext.rs index b8098d7..ec955ef 100644 --- a/src/lua_ext.rs +++ b/src/lua_ext.rs @@ -1,23 +1,36 @@ use std::future::Future; use mlua::prelude::*; -use tokio::spawn; +use tokio::{spawn, task::spawn_local}; use crate::{AsyncValues, Message, MessageSender, ThreadId}; -pub trait LuaSchedulerExt<'lua> { - fn create_async_function(&'lua self, func: F) -> LuaResult> +pub trait LuaAsyncExt<'lua> { + fn current_thread_id(&'lua self) -> ThreadId; + + fn create_async_function(&'lua self, f: F) -> LuaResult> where - A: FromLuaMulti<'lua> + 'static, + A: FromLuaMulti<'lua>, R: Into + Send + 'static, F: Fn(&'lua Lua, A) -> FR + 'static, FR: Future> + Send + 'static; + + fn create_local_async_function(&'lua self, f: F) -> LuaResult> + where + A: FromLuaMulti<'lua>, + R: Into + 'static, + F: Fn(&'lua Lua, A) -> FR + 'static, + FR: Future> + 'static; } -impl<'lua> LuaSchedulerExt<'lua> for Lua { - fn create_async_function(&'lua self, func: F) -> LuaResult> +impl<'lua> LuaAsyncExt<'lua> for Lua { + fn current_thread_id(&'lua self) -> ThreadId { + ThreadId::from(self.current_thread()) + } + + fn create_async_function(&'lua self, f: F) -> LuaResult> where - A: FromLuaMulti<'lua> + 'static, + A: FromLuaMulti<'lua>, R: Into + Send + 'static, F: Fn(&'lua Lua, A) -> FR + 'static, FR: Future> + Send + 'static, @@ -25,8 +38,8 @@ impl<'lua> LuaSchedulerExt<'lua> for Lua { let tx = self.app_data_ref::().unwrap().clone(); self.create_function(move |lua, args: A| { - let thread_id = ThreadId::from(lua.current_thread()); - let fut = func(lua, args); + let thread_id = lua.current_thread_id(); + let fut = f(lua, args); let tx = tx.clone(); spawn(async move { @@ -39,4 +52,29 @@ impl<'lua> LuaSchedulerExt<'lua> for Lua { Ok(()) }) } + + fn create_local_async_function(&'lua self, f: F) -> LuaResult> + where + A: FromLuaMulti<'lua>, + R: Into + 'static, + F: Fn(&'lua Lua, A) -> FR + 'static, + FR: Future> + 'static, + { + let tx = self.app_data_ref::().unwrap().clone(); + + self.create_function(move |lua, args: A| { + let thread_id = lua.current_thread_id(); + let fut = f(lua, args); + let tx = tx.clone(); + + spawn_local(async move { + tx.send(match fut.await { + Ok(args) => Message::Resume(thread_id, Ok(args.into())), + Err(e) => Message::Resume(thread_id, Err(e)), + }) + }); + + Ok(()) + }) + } }