diff --git a/packages/lib/src/lua/net/mod.rs b/packages/lib/src/lua/net/mod.rs index 4042953..c52b780 100644 --- a/packages/lib/src/lua/net/mod.rs +++ b/packages/lib/src/lua/net/mod.rs @@ -1,9 +1,11 @@ mod client; mod config; +mod response; mod server; mod websocket; pub use client::{NetClient, NetClientBuilder}; pub use config::{RequestConfig, ServeConfig}; +pub use response::{NetServeResponse, NetServeResponseKind}; pub use server::{NetLocalExec, NetService}; pub use websocket::NetWebSocket; diff --git a/packages/lib/src/lua/net/response.rs b/packages/lib/src/lua/net/response.rs new file mode 100644 index 0000000..aeff87c --- /dev/null +++ b/packages/lib/src/lua/net/response.rs @@ -0,0 +1,106 @@ +use std::collections::HashMap; + +use hyper::{Body, Response}; +use mlua::prelude::*; + +#[derive(Debug, Clone, Copy)] +pub enum NetServeResponseKind { + PlainText, + Table, +} + +#[derive(Debug, Clone)] +pub struct NetServeResponse { + kind: NetServeResponseKind, + status: u16, + headers: HashMap>, + body: Option>, +} + +impl NetServeResponse { + pub fn into_response(self) -> LuaResult> { + Ok(match self.kind { + NetServeResponseKind::PlainText => Response::builder() + .status(200) + .header("Content-Type", "text/plain") + .body(Body::from(self.body.unwrap())) + .map_err(LuaError::external)?, + NetServeResponseKind::Table => { + let mut response = Response::builder(); + for (key, value) in self.headers { + response = response.header(&key, value); + } + response + .status(self.status) + .body(Body::from(self.body.unwrap_or_default())) + .map_err(LuaError::external)? + } + }) + } +} + +impl<'lua> FromLua<'lua> for NetServeResponse { + fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult { + match value { + // Plain strings from the handler are plaintext responses + LuaValue::String(s) => Ok(Self { + kind: NetServeResponseKind::PlainText, + status: 200, + headers: HashMap::new(), + body: Some(s.as_bytes().to_vec()), + }), + // Tables are more detailed responses with potential status, headers, body + LuaValue::Table(t) => { + let status: Option = t.get("status")?; + let headers: Option = t.get("headers")?; + let body: Option = t.get("body")?; + + let mut headers_map = HashMap::new(); + if let Some(headers) = headers { + for pair in headers.pairs::() { + let (h, v) = pair?; + headers_map.insert(h, v.as_bytes().to_vec()); + } + } + + let body_bytes = body.map(|s| s.as_bytes().to_vec()); + + Ok(Self { + kind: NetServeResponseKind::Table, + status: status.unwrap_or(200), + headers: headers_map, + body: body_bytes, + }) + } + // Anything else is an error + value => Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "NetServeResponse", + message: None, + }), + } + } +} + +impl<'lua> ToLua<'lua> for NetServeResponse { + fn to_lua(self, lua: &'lua Lua) -> LuaResult> { + if self.headers.len() > i32::MAX as usize { + return Err(LuaError::ToLuaConversionError { + from: "NetServeResponse", + to: "table", + message: Some("Too many header values".to_string()), + }); + } + let body = self.body.map(|b| lua.create_string(&b)).transpose()?; + let headers = lua.create_table_with_capacity(0, self.headers.len() as i32)?; + for (key, value) in self.headers { + headers.set(key, lua.create_string(&value)?)?; + } + let table = lua.create_table_with_capacity(0, 3)?; + table.set("status", self.status)?; + table.set("headers", headers)?; + table.set("body", body)?; + table.set_readonly(true); + Ok(LuaValue::Table(table)) + } +} diff --git a/packages/lib/src/lua/net/server.rs b/packages/lib/src/lua/net/server.rs index a530acb..de37f99 100644 --- a/packages/lib/src/lua/net/server.rs +++ b/packages/lib/src/lua/net/server.rs @@ -17,7 +17,7 @@ use crate::{ utils::table::TableBuilder, }; -use super::NetWebSocket; +use super::{NetServeResponse, NetWebSocket}; // Hyper service implementation for net, lots of boilerplate here // but make_svc and make_svc_function do not work for what we need @@ -77,7 +77,6 @@ impl Service> for NetServiceInner { let (parts, body) = req.into_parts(); Box::pin(async move { // Convert request body into bytes, extract handler - // function & lune message sender to use later let bytes = to_bytes(body).await.map_err(LuaError::external)?; let handler: LuaFunction = lua.registry_value(&key)?; // Create a readonly table for the request query params @@ -112,53 +111,22 @@ impl Service> for NetServiceInner { .with_value("headers", header_map)? .with_value("body", lua.create_string(&bytes)?)? .build_readonly()?; - // TODO: Make some kind of NetServeResponse type with a - // FromLua implementation instead, this is a bit messy - // and does not send errors to the scheduler properly - match handler.call(request) { - // Plain strings from the handler are plaintext responses - Ok(LuaValue::String(s)) => Ok(Response::builder() - .status(200) - .header("Content-Type", "text/plain") - .body(Body::from(s.as_bytes().to_vec())) - .unwrap()), - // Tables are more detailed responses with potential status, headers, body - Ok(LuaValue::Table(t)) => { - let status = t.get::<_, Option>("status")?.unwrap_or(200); - let mut resp = Response::builder().status(status); - - if let Some(headers) = t.get::<_, Option>("headers")? { - for pair in headers.pairs::() { - let (h, v) = pair?; - resp = resp.header(&h, v.as_bytes()); - } - } - - let body = t - .get::<_, Option>("body")? - .map(|b| Body::from(b.as_bytes().to_vec())) - .unwrap_or_else(Body::empty); - - Ok(resp.body(body).unwrap()) - } - // If the handler returns an error, generate a 5xx response - Err(_) => { - // TODO: Send above error to task scheduler so that it can emit properly - Ok(Response::builder() - .status(500) - .body(Body::from("Internal Server Error")) - .unwrap()) - } - // If the handler returns a value that is of an invalid type, - // this should also be an error, so generate a 5xx response - Ok(_) => { - // TODO: Implement the type in the above todo - Ok(Response::builder() - .status(500) - .body(Body::from("Internal Server Error")) - .unwrap()) - } - } + let response: LuaResult = handler.call(request); + // Send below errors to task scheduler so that they can emit properly + let lua_error = match response { + Ok(r) => match r.into_response() { + Ok(res) => return Ok(res), + Err(err) => err, + }, + Err(err) => err, + }; + lua.app_data_ref::<&TaskScheduler>() + .expect("Missing task scheduler") + .forward_lua_error(lua_error); + Ok(Response::builder() + .status(500) + .body(Body::from("Internal Server Error")) + .unwrap()) }) } } diff --git a/packages/lib/src/lua/task/ext/resume_ext.rs b/packages/lib/src/lua/task/ext/resume_ext.rs index 107a9b4..f6d099c 100644 --- a/packages/lib/src/lua/task/ext/resume_ext.rs +++ b/packages/lib/src/lua/task/ext/resume_ext.rs @@ -143,6 +143,7 @@ async fn receive_next_message(scheduler: &TaskScheduler<'_>) -> TaskSchedulerSta if let Some(message) = message_opt { match message { TaskSchedulerMessage::NewBlockingTaskReady => TaskSchedulerState::new(scheduler), + TaskSchedulerMessage::NewLuaErrorReady(err) => TaskSchedulerState::err(scheduler, err), TaskSchedulerMessage::Spawned => { let prev = scheduler.futures_background_count.get(); scheduler.futures_background_count.set(prev + 1); diff --git a/packages/lib/src/lua/task/message.rs b/packages/lib/src/lua/task/message.rs index 463aff7..cddc015 100644 --- a/packages/lib/src/lua/task/message.rs +++ b/packages/lib/src/lua/task/message.rs @@ -5,6 +5,7 @@ use mlua::prelude::*; #[derive(Debug, Clone)] pub enum TaskSchedulerMessage { NewBlockingTaskReady, + NewLuaErrorReady(LuaError), Spawned, Terminated(LuaResult<()>), } diff --git a/packages/lib/src/lua/task/scheduler.rs b/packages/lib/src/lua/task/scheduler.rs index 7573f65..6c6c53e 100644 --- a/packages/lib/src/lua/task/scheduler.rs +++ b/packages/lib/src/lua/task/scheduler.rs @@ -109,6 +109,35 @@ impl<'fut> TaskScheduler<'fut> { self.exit_code.set(Some(code)); } + /** + Forwards a lua error to be emitted as soon as possible, + after any current blocking / queued tasks have been resumed. + + Useful when an async function may call into Lua and get a + result back, without erroring out of the entire async block. + */ + pub fn forward_lua_error(&self, err: LuaError) { + let sender = self.futures_tx.clone(); + sender + .send(TaskSchedulerMessage::NewLuaErrorReady(err)) + .unwrap_or_else(|e| { + panic!( + "\ + \nFailed to forward lua error - this is an internal error! \ + \nPlease report it at {} \ + \nDetails: {e} \ + ", + env!("CARGO_PKG_REPOSITORY") + ) + }); + } + + /** + Forces the current task to be set to the given reference. + + Useful if a task is to be resumed externally but full + compatibility with the task scheduler is still necessary. + */ pub(crate) fn force_set_current_task(&self, reference: Option) { self.tasks_current.set(reference); }