diff --git a/packages/lib/src/globals/net.rs b/packages/lib/src/globals/net.rs index df544d9..864d959 100644 --- a/packages/lib/src/globals/net.rs +++ b/packages/lib/src/globals/net.rs @@ -1,25 +1,15 @@ -use std::{ - collections::HashMap, - future::Future, - pin::Pin, - sync::Arc, - task::{Context, Poll}, -}; +use std::collections::HashMap; -use console::style; use mlua::prelude::*; -use hyper::{body::to_bytes, server::conn::AddrStream, service::Service}; -use hyper::{Body, Request, Response, Server}; -use hyper_tungstenite::{is_upgrade_request as is_ws_upgrade_request, upgrade as ws_upgrade}; - -use reqwest::Method; +use console::style; +use hyper::Server; use tokio::{sync::mpsc, task}; use crate::{ lua::{ // net::{NetWebSocketClient, NetWebSocketServer}, - net::{NetClient, NetClientBuilder, ServeConfig}, + net::{NetClient, NetClientBuilder, NetLocalExec, NetService, RequestConfig, ServeConfig}, task::TaskScheduler, }, utils::{net::get_request_user_agent_header, table::TableBuilder}, @@ -27,7 +17,8 @@ use crate::{ pub fn create(lua: &'static Lua) -> LuaResult { // Create a reusable client for performing our - // web requests and store it in the lua registry + // web requests and store it in the lua registry, + // allowing us to reuse headers and internal structs let client = NetClientBuilder::new() .headers(&[("User-Agent", get_request_user_agent_header())])? .build()?; @@ -55,74 +46,15 @@ fn net_json_decode(lua: &'static Lua, json: String) -> LuaResult { lua.to_value(&json) } -async fn net_request<'a>(lua: &'static Lua, config: LuaValue<'a>) -> LuaResult> { - let client: NetClient = lua.named_registry_value("net.client")?; - // Extract stuff from config and make sure its all valid - let (url, method, headers, body) = match config { - LuaValue::String(s) => { - let url = s.to_string_lossy().to_string(); - let method = "GET".to_string(); - Ok((url, method, HashMap::new(), None)) - } - LuaValue::Table(tab) => { - // Extract url - let url = match tab.raw_get::<_, LuaString>("url") { - Ok(config_url) => Ok(config_url.to_string_lossy().to_string()), - Err(_) => Err(LuaError::RuntimeError( - "Missing 'url' in request config".to_string(), - )), - }?; - // Extract method - let method = match tab.raw_get::<_, LuaString>("method") { - Ok(config_method) => config_method.to_string_lossy().trim().to_ascii_uppercase(), - Err(_) => "GET".to_string(), - }; - // Extract headers - let headers = match tab.raw_get::<_, LuaTable>("headers") { - Ok(config_headers) => { - let mut lua_headers = HashMap::new(); - for pair in config_headers.pairs::() { - let (key, value) = pair?.to_owned(); - lua_headers.insert(key, value); - } - lua_headers - } - Err(_) => HashMap::new(), - }; - // Extract body - let body = match tab.raw_get::<_, LuaString>("body") { - Ok(config_body) => Some(config_body.as_bytes().to_owned()), - Err(_) => None, - }; - Ok((url, method, headers, body)) - } - value => Err(LuaError::RuntimeError(format!( - "Invalid request config - expected string or table, got {}", - value.type_name() - ))), - }?; - // Convert method string into proper enum - let method = method.trim().to_ascii_uppercase(); - let method = match method.as_ref() { - "GET" => Ok(Method::GET), - "POST" => Ok(Method::POST), - "PUT" => Ok(Method::PUT), - "DELETE" => Ok(Method::DELETE), - "HEAD" => Ok(Method::HEAD), - "OPTIONS" => Ok(Method::OPTIONS), - "PATCH" => Ok(Method::PATCH), - _ => Err(LuaError::RuntimeError(format!( - "Invalid request config method '{}'", - &method - ))), - }?; +async fn net_request<'a>(lua: &'static Lua, config: RequestConfig<'a>) -> LuaResult> { // Create and send the request - let mut request = client.request(method, &url); - for (header, value) in headers { + let client: NetClient = lua.named_registry_value("net.client")?; + let mut request = client.request(config.method, &config.url); + for (header, value) in config.headers { request = request.header(header.to_str()?, value.to_str()?); } let res = request - .body(body.unwrap_or_default()) + .body(config.body.unwrap_or_default()) .send() .await .map_err(LuaError::external)?; @@ -146,13 +78,13 @@ async fn net_request<'a>(lua: &'static Lua, config: LuaValue<'a>) -> LuaResult(lua: &'static Lua, url: String) -> LuaResult { - let (ws, _) = tokio_tungstenite::connect_async(url) - .await - .map_err(LuaError::external)?; +async fn net_socket<'a>(_lua: &'static Lua, _url: String) -> LuaResult { Err(LuaError::RuntimeError( "Client websockets are not yet implemented".to_string(), )) + // let (ws, _) = tokio_tungstenite::connect_async(url) + // .await + // .map_err(LuaError::external)?; // let sock = NetWebSocketClient::from(ws); // let table = sock.into_lua_table(lua)?; // Ok(table) @@ -171,11 +103,11 @@ async fn net_serve<'a>( // a oneshot channel since we move the sender // into our table with the stop function let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1); - let server_request_callback = Arc::new(lua.create_registry_value(config.handle_request)?); - let server_websocket_callback = Arc::new(config.handle_web_socket.map(|handler| { + let server_request_callback = lua.create_registry_value(config.handle_request)?; + let server_websocket_callback = config.handle_web_socket.map(|handler| { lua.create_registry_value(handler) .expect("Failed to store websocket handler") - })); + }); let sched = lua.app_data_mut::<&TaskScheduler>().unwrap(); // Bind first to make sure that we can bind to this address let bound = match Server::try_bind(&([127, 0, 0, 1], port).into()) { @@ -190,14 +122,14 @@ async fn net_serve<'a>( } Ok(bound) => bound, }; - // Register a background task to prevent - // the task scheduler from exiting early + // Register a background task to prevent the task scheduler from + // exiting early and start up our web server on the bound address let task = sched.register_background_task(); let server = bound - .http1_only(true) - .http1_keepalive(true) - .executor(LocalExec) - .serve(MakeNetService( + .http1_only(true) // Web sockets can only use http1 + .http1_keepalive(true) // Web sockets must be kept alive + .executor(NetLocalExec) + .serve(NetService::new( lua, server_request_callback, server_websocket_callback, @@ -214,191 +146,13 @@ async fn net_serve<'a>( task::spawn_local(server); // Create a new read-only table that contains methods // for manipulating server behavior and shutting it down - let handle_stop = move |_, _: ()| { - if shutdown_tx.try_send(()).is_err() { - Err(LuaError::RuntimeError( - "Server has already been stopped".to_string(), - )) - } else { - Ok(()) - } + let handle_stop = move |_, _: ()| match shutdown_tx.try_send(()) { + Ok(_) => Ok(()), + Err(_) => Err(LuaError::RuntimeError( + "Server has already been stopped".to_string(), + )), }; TableBuilder::new(lua)? .with_function("stop", handle_stop)? .build_readonly() } - -// Hyper service implementation for net, lots of boilerplate here -// but make_svc and make_svc_function do not work for what we need - -pub struct NetService( - &'static Lua, - Arc, - Arc>, -); - -impl Service> for NetService { - type Response = Response; - type Error = LuaError; - type Future = Pin>>>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, mut req: Request) -> Self::Future { - let lua = self.0; - if self.2.is_some() && is_ws_upgrade_request(&req) { - // Websocket upgrade request + websocket handler exists, - // we should now upgrade this connection to a websocket - // and then call our handler with a new socket object - let kopt = self.2.clone(); - let key = kopt.as_ref().as_ref().unwrap(); - let handler: LuaFunction = lua.registry_value(key).expect("Missing websocket handler"); - let (response, ws) = ws_upgrade(&mut req, None).expect("Failed to upgrade websocket"); - // TODO: This should be spawned as part of the scheduler, - // the scheduler may exit early and cancel this even though what - // we want here is a long-running task that keeps the program alive - task::spawn_local(async move { - // Create our new full websocket object, then - // schedule our handler to get called asap - let ws = ws.await.map_err(LuaError::external)?; - // let sock = NetWebSocketServer::from(ws); - // let table = sock.into_lua_table(lua)?; - // let sched = lua.app_data_mut::<&TaskScheduler>().unwrap(); - // sched.schedule_current_resume( - // LuaValue::Function(handler), - // LuaMultiValue::from_vec(vec![LuaValue::Table(table)]), - // ) - Ok::<_, LuaError>(()) - }); - Box::pin(async move { Ok(response) }) - } else { - // Got a normal http request or no websocket handler - // exists, just call the http request handler - let key = self.1.clone(); - let (parts, body) = req.into_parts(); - Box::pin(async move { - // Convert request body into bytes, extract handler - // function & lune message sender to use later - let bytes = to_bytes(body).await.map_err(LuaError::external)?; - let handler: LuaFunction = lua.registry_value(&key)?; - // Create a readonly table for the request query params - let query_params = TableBuilder::new(lua)? - .with_values( - parts - .uri - .query() - .unwrap_or_default() - .split('&') - .filter_map(|q| q.split_once('=')) - .collect(), - )? - .build_readonly()?; - // Do the same for headers - let header_map = TableBuilder::new(lua)? - .with_values( - parts - .headers - .iter() - .map(|(name, value)| { - (name.to_string(), value.to_str().unwrap().to_string()) - }) - .collect(), - )? - .build_readonly()?; - // Create a readonly table with request info to pass to the handler - let request = TableBuilder::new(lua)? - .with_value("path", parts.uri.path())? - .with_value("query", query_params)? - .with_value("method", parts.method.as_str())? - .with_value("headers", header_map)? - .with_value("body", lua.create_string(&bytes)?)? - .build_readonly()?; - match handler.call(request) { - // Plain strings from the handler are plaintext responses - Ok(LuaValue::String(s)) => Ok(Response::builder() - .status(200) - .header("Content-Type", "text/plain") - .body(Body::from(s.as_bytes().to_vec())) - .unwrap()), - // Tables are more detailed responses with potential status, headers, body - Ok(LuaValue::Table(t)) => { - let status = t.get::<_, Option>("status")?.unwrap_or(200); - let mut resp = Response::builder().status(status); - - if let Some(headers) = t.get::<_, Option>("headers")? { - for pair in headers.pairs::() { - let (h, v) = pair?; - resp = resp.header(&h, v.as_bytes()); - } - } - - let body = t - .get::<_, Option>("body")? - .map(|b| Body::from(b.as_bytes().to_vec())) - .unwrap_or_else(Body::empty); - - Ok(resp.body(body).unwrap()) - } - // If the handler returns an error, generate a 5xx response - Err(_) => { - // TODO: Send above error to task scheduler so that it can emit properly - Ok(Response::builder() - .status(500) - .body(Body::from("Internal Server Error")) - .unwrap()) - } - // If the handler returns a value that is of an invalid type, - // this should also be an error, so generate a 5xx response - Ok(value) => { - // TODO: Send below error to task scheduler so that it can emit properly - let _ = LuaError::RuntimeError(format!( - "Expected net serve handler to return a value of type 'string' or 'table', got '{}'", - value.type_name() - )); - Ok(Response::builder() - .status(500) - .body(Body::from("Internal Server Error")) - .unwrap()) - } - } - }) - } - } -} - -struct MakeNetService( - &'static Lua, - Arc, - Arc>, -); - -impl Service<&AddrStream> for MakeNetService { - type Response = NetService; - type Error = hyper::Error; - type Future = Pin>>>; - - fn poll_ready(&mut self, _: &mut Context) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, _: &AddrStream) -> Self::Future { - let lua = self.0; - let key1 = self.1.clone(); - let key2 = self.2.clone(); - Box::pin(async move { Ok(NetService(lua, key1, key2)) }) - } -} - -#[derive(Clone, Copy, Debug)] -struct LocalExec; - -impl hyper::rt::Executor for LocalExec -where - F: std::future::Future + 'static, // not requiring `Send` -{ - fn execute(&self, fut: F) { - task::spawn_local(fut); - } -} diff --git a/packages/lib/src/lua/net/config.rs b/packages/lib/src/lua/net/config.rs index f382a28..981af09 100644 --- a/packages/lib/src/lua/net/config.rs +++ b/packages/lib/src/lua/net/config.rs @@ -1,5 +1,93 @@ +use std::collections::HashMap; + use mlua::prelude::*; +use reqwest::Method; + +// Net request config + +pub struct RequestConfig<'a> { + pub url: String, + pub method: Method, + pub headers: HashMap, LuaString<'a>>, + pub body: Option>, +} + +impl<'lua> FromLua<'lua> for RequestConfig<'lua> { + fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult { + // If we just got a string we assume its a GET request to a given url + if let LuaValue::String(s) = value { + return Ok(Self { + url: s.to_string_lossy().to_string(), + method: Method::GET, + headers: HashMap::new(), + body: None, + }); + } + // If we got a table we are able to configure the entire request + if let LuaValue::Table(tab) = value { + // Extract url + let url = match tab.raw_get::<_, LuaString>("url") { + Ok(config_url) => Ok(config_url.to_string_lossy().to_string()), + Err(_) => Err(LuaError::RuntimeError( + "Missing 'url' in request config".to_string(), + )), + }?; + // Extract method + let method = match tab.raw_get::<_, LuaString>("method") { + Ok(config_method) => config_method.to_string_lossy().trim().to_ascii_uppercase(), + Err(_) => "GET".to_string(), + }; + // Extract headers + let headers = match tab.raw_get::<_, LuaTable>("headers") { + Ok(config_headers) => { + let mut lua_headers = HashMap::new(); + for pair in config_headers.pairs::() { + let (key, value) = pair?.to_owned(); + lua_headers.insert(key, value); + } + lua_headers + } + Err(_) => HashMap::new(), + }; + // Extract body + let body = match tab.raw_get::<_, LuaString>("body") { + Ok(config_body) => Some(config_body.as_bytes().to_owned()), + Err(_) => None, + }; + // Convert method string into proper enum + let method = method.trim().to_ascii_uppercase(); + let method = match method.as_ref() { + "GET" => Ok(Method::GET), + "POST" => Ok(Method::POST), + "PUT" => Ok(Method::PUT), + "DELETE" => Ok(Method::DELETE), + "HEAD" => Ok(Method::HEAD), + "OPTIONS" => Ok(Method::OPTIONS), + "PATCH" => Ok(Method::PATCH), + _ => Err(LuaError::RuntimeError(format!( + "Invalid request config method '{}'", + &method + ))), + }?; + // All good, validated and we got what we need + return Ok(Self { + url, + method, + headers, + body, + }); + }; + // Anything else is invalid + Err(LuaError::RuntimeError(format!( + "Invalid request config - expected string or table, got {}", + value.type_name() + ))) + } +} + +// Net serve config + pub struct ServeConfig<'a> { pub handle_request: LuaFunction<'a>, pub handle_web_socket: Option>, diff --git a/packages/lib/src/lua/net/mod.rs b/packages/lib/src/lua/net/mod.rs index 637f49e..f75970c 100644 --- a/packages/lib/src/lua/net/mod.rs +++ b/packages/lib/src/lua/net/mod.rs @@ -1,9 +1,11 @@ mod client; mod config; +mod server; // mod ws_client; // mod ws_server; pub use client::{NetClient, NetClientBuilder}; -pub use config::ServeConfig; +pub use config::{RequestConfig, ServeConfig}; +pub use server::{NetLocalExec, NetService}; // pub use ws_client::NetWebSocketClient; // pub use ws_server::NetWebSocketServer; diff --git a/packages/lib/src/lua/net/server.rs b/packages/lib/src/lua/net/server.rs new file mode 100644 index 0000000..7ec690d --- /dev/null +++ b/packages/lib/src/lua/net/server.rs @@ -0,0 +1,200 @@ +use std::{ + future::Future, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use mlua::prelude::*; + +use hyper::{body::to_bytes, server::conn::AddrStream, service::Service}; +use hyper::{Body, Request, Response}; +use hyper_tungstenite::{is_upgrade_request as is_ws_upgrade_request, upgrade as ws_upgrade}; +use tokio::task; + +use crate::utils::table::TableBuilder; + +// Hyper service implementation for net, lots of boilerplate here +// but make_svc and make_svc_function do not work for what we need + +pub struct NetServiceInner( + &'static Lua, + Arc, + Arc>, +); + +impl Service> for NetServiceInner { + type Response = Response; + type Error = LuaError; + type Future = Pin>>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + let lua = self.0; + if self.2.is_some() && is_ws_upgrade_request(&req) { + // Websocket upgrade request + websocket handler exists, + // we should now upgrade this connection to a websocket + // and then call our handler with a new socket object + let kopt = self.2.clone(); + let key = kopt.as_ref().as_ref().unwrap(); + let _handler: LuaFunction = lua.registry_value(key).expect("Missing websocket handler"); + let (response, ws) = ws_upgrade(&mut req, None).expect("Failed to upgrade websocket"); + // TODO: This should be spawned as part of the scheduler, + // the scheduler may exit early and cancel this even though what + // we want here is a long-running task that keeps the program alive + task::spawn_local(async move { + // Create our new full websocket object, then + // schedule our handler to get called asap + let _ws = ws.await.map_err(LuaError::external)?; + // let sock = NetWebSocketServer::from(ws); + // let table = sock.into_lua_table(lua)?; + // let sched = lua.app_data_mut::<&TaskScheduler>().unwrap(); + // sched.schedule_current_resume( + // LuaValue::Function(handler), + // LuaMultiValue::from_vec(vec![LuaValue::Table(table)]), + // ) + Ok::<_, LuaError>(()) + }); + Box::pin(async move { Ok(response) }) + } else { + // Got a normal http request or no websocket handler + // exists, just call the http request handler + let key = self.1.clone(); + let (parts, body) = req.into_parts(); + Box::pin(async move { + // Convert request body into bytes, extract handler + // function & lune message sender to use later + let bytes = to_bytes(body).await.map_err(LuaError::external)?; + let handler: LuaFunction = lua.registry_value(&key)?; + // Create a readonly table for the request query params + let query_params = TableBuilder::new(lua)? + .with_values( + parts + .uri + .query() + .unwrap_or_default() + .split('&') + .filter_map(|q| q.split_once('=')) + .collect(), + )? + .build_readonly()?; + // Do the same for headers + let header_map = TableBuilder::new(lua)? + .with_values( + parts + .headers + .iter() + .map(|(name, value)| { + (name.to_string(), value.to_str().unwrap().to_string()) + }) + .collect(), + )? + .build_readonly()?; + // Create a readonly table with request info to pass to the handler + let request = TableBuilder::new(lua)? + .with_value("path", parts.uri.path())? + .with_value("query", query_params)? + .with_value("method", parts.method.as_str())? + .with_value("headers", header_map)? + .with_value("body", lua.create_string(&bytes)?)? + .build_readonly()?; + match handler.call(request) { + // Plain strings from the handler are plaintext responses + Ok(LuaValue::String(s)) => Ok(Response::builder() + .status(200) + .header("Content-Type", "text/plain") + .body(Body::from(s.as_bytes().to_vec())) + .unwrap()), + // Tables are more detailed responses with potential status, headers, body + Ok(LuaValue::Table(t)) => { + let status = t.get::<_, Option>("status")?.unwrap_or(200); + let mut resp = Response::builder().status(status); + + if let Some(headers) = t.get::<_, Option>("headers")? { + for pair in headers.pairs::() { + let (h, v) = pair?; + resp = resp.header(&h, v.as_bytes()); + } + } + + let body = t + .get::<_, Option>("body")? + .map(|b| Body::from(b.as_bytes().to_vec())) + .unwrap_or_else(Body::empty); + + Ok(resp.body(body).unwrap()) + } + // If the handler returns an error, generate a 5xx response + Err(_) => { + // TODO: Send above error to task scheduler so that it can emit properly + Ok(Response::builder() + .status(500) + .body(Body::from("Internal Server Error")) + .unwrap()) + } + // If the handler returns a value that is of an invalid type, + // this should also be an error, so generate a 5xx response + Ok(value) => { + // TODO: Send below error to task scheduler so that it can emit properly + let _ = LuaError::RuntimeError(format!( + "Expected net serve handler to return a value of type 'string' or 'table', got '{}'", + value.type_name() + )); + Ok(Response::builder() + .status(500) + .body(Body::from("Internal Server Error")) + .unwrap()) + } + } + }) + } + } +} + +pub struct NetService( + &'static Lua, + Arc, + Arc>, +); + +impl NetService { + pub fn new( + lua: &'static Lua, + callback_http: LuaRegistryKey, + callback_websocket: Option, + ) -> Self { + Self(lua, Arc::new(callback_http), Arc::new(callback_websocket)) + } +} + +impl Service<&AddrStream> for NetService { + type Response = NetServiceInner; + type Error = hyper::Error; + type Future = Pin>>>; + + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _: &AddrStream) -> Self::Future { + let lua = self.0; + let key1 = self.1.clone(); + let key2 = self.2.clone(); + Box::pin(async move { Ok(NetServiceInner(lua, key1, key2)) }) + } +} + +#[derive(Clone, Copy, Debug)] +pub struct NetLocalExec; + +impl hyper::rt::Executor for NetLocalExec +where + F: std::future::Future + 'static, // not requiring `Send` +{ + fn execute(&self, fut: F) { + task::spawn_local(fut); + } +}