diff --git a/src/lune/builtins/net/mod.rs b/src/lune/builtins/net/mod.rs index 85aae45..615136d 100644 --- a/src/lune/builtins/net/mod.rs +++ b/src/lune/builtins/net/mod.rs @@ -15,13 +15,13 @@ use super::serde::{ mod client; mod config; +mod processing; mod response; mod server; mod websocket; use client::{NetClient, NetClientBuilder}; use config::{RequestConfig, ServeConfig}; -use response::NetServeResponse; use server::bind_to_localhost; use websocket::NetWebSocket; diff --git a/src/lune/builtins/net/processing.rs b/src/lune/builtins/net/processing.rs new file mode 100644 index 0000000..d770665 --- /dev/null +++ b/src/lune/builtins/net/processing.rs @@ -0,0 +1,101 @@ +use std::sync::atomic::{AtomicUsize, Ordering}; + +use hyper::{body::to_bytes, Body, Request}; + +use mlua::prelude::*; + +use crate::lune::util::TableBuilder; + +static ID_COUNTER: AtomicUsize = AtomicUsize::new(0); + +#[derive(Debug, Clone, Copy)] +pub(super) struct ProcessedRequestId(usize); + +impl ProcessedRequestId { + pub fn new() -> Self { + // NOTE: This may overflow after a couple billion requests, + // but that's completely fine... unless a request is still + // alive after billions more arrive and need to be handled + Self(ID_COUNTER.fetch_add(1, Ordering::Relaxed)) + } +} + +pub(super) struct ProcessedRequest { + id: ProcessedRequestId, + method: String, + path: String, + query: Vec<(String, String)>, + headers: Vec<(String, Vec)>, + body: Vec, +} + +impl ProcessedRequest { + pub async fn from_request(req: Request) -> LuaResult { + let (head, body) = req.into_parts(); + + // FUTURE: We can do extra processing like async decompression here + let body = match to_bytes(body).await { + Err(_) => return Err(LuaError::runtime("Failed to read request body bytes")), + Ok(b) => b.to_vec(), + }; + + let method = head.method.to_string().to_ascii_uppercase(); + + let mut path = head.uri.path().to_string(); + if path.is_empty() { + path = "/".to_string(); + } + + let query = head + .uri + .query() + .unwrap_or_default() + .split('&') + .filter_map(|q| q.split_once('=')) + .map(|(k, v)| (k.to_string(), v.to_string())) + .collect(); + + let mut headers = Vec::new(); + let mut header_name = String::new(); + for (name_opt, value) in head.headers.into_iter() { + if let Some(name) = name_opt { + header_name = name.to_string(); + } + headers.push((header_name.clone(), value.as_bytes().to_vec())) + } + + let id = ProcessedRequestId::new(); + + Ok(Self { + id, + method, + path, + query, + headers, + body, + }) + } + + pub fn into_lua_table(self, lua: &Lua) -> LuaResult { + // FUTURE: Make inner tables for query keys that have multiple values? + let query = lua.create_table_with_capacity(0, self.query.len())?; + for (key, value) in self.query.into_iter() { + query.set(key, value)?; + } + + let headers = lua.create_table_with_capacity(0, self.headers.len())?; + for (key, value) in self.headers.into_iter() { + headers.set(key, lua.create_string(value)?)?; + } + + let body = lua.create_string(self.body)?; + + TableBuilder::new(lua)? + .with_value("method", self.method)? + .with_value("path", self.path)? + .with_value("query", query)? + .with_value("headers", headers)? + .with_value("body", body)? + .build_readonly() + } +} diff --git a/src/lune/builtins/net/response.rs b/src/lune/builtins/net/response.rs index fa2e748..d14646a 100644 --- a/src/lune/builtins/net/response.rs +++ b/src/lune/builtins/net/response.rs @@ -9,7 +9,7 @@ pub enum NetServeResponseKind { Table, } -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct NetServeResponse { kind: NetServeResponseKind, status: u16, @@ -81,26 +81,3 @@ impl<'lua> FromLua<'lua> for NetServeResponse { } } } - -impl<'lua> IntoLua<'lua> for NetServeResponse { - fn into_lua(self, lua: &'lua Lua) -> LuaResult> { - if self.headers.len() > i32::MAX as usize { - return Err(LuaError::ToLuaConversionError { - from: "NetServeResponse", - to: "table", - message: Some("Too many header values".to_string()), - }); - } - let body = self.body.map(|b| lua.create_string(b)).transpose()?; - let headers = lua.create_table_with_capacity(0, self.headers.len())?; - for (key, value) in self.headers { - headers.set(key, lua.create_string(&value)?)?; - } - let table = lua.create_table_with_capacity(0, 3)?; - table.set("status", self.status)?; - table.set("headers", headers)?; - table.set("body", body)?; - table.set_readonly(true); - Ok(LuaValue::Table(table)) - } -} diff --git a/src/lune/builtins/net/server.rs b/src/lune/builtins/net/server.rs index c99a225..ab46365 100644 --- a/src/lune/builtins/net/server.rs +++ b/src/lune/builtins/net/server.rs @@ -1,16 +1,20 @@ -use std::{convert::Infallible, net::SocketAddr}; +use std::{convert::Infallible, net::SocketAddr, sync::Arc}; use hyper::{ server::{conn::AddrIncoming, Builder}, service::{make_service_fn, service_fn}, Response, Server, }; + use mlua::prelude::*; use tokio::sync::mpsc; -use crate::lune::{scheduler::Scheduler, util::TableBuilder}; +use crate::{ + lune::{scheduler::Scheduler, util::TableBuilder}, + LuneError, +}; -use super::config::ServeConfig; +use super::{config::ServeConfig, processing::ProcessedRequest, response::NetServeResponse}; pub(super) fn bind_to_localhost(port: u16) -> LuaResult> { let addr = match SocketAddr::try_from(([127, 0, 0, 1], port)) { @@ -45,19 +49,43 @@ where // into our table with the stop function let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1); - // Spawn a scheduler background task, communicate using mpsc - // channels, do any heavy lifting possible in background thread - let (tx_request, mut rx_request) = mpsc::channel::<()>(64); + // Communicate between background thread(s) and main lua thread using mpsc + let (tx_request, mut rx_request) = mpsc::channel::(64); let (tx_websocket, mut rx_websocket) = mpsc::channel::<()>(64); + let tx_request_arc = Arc::new(tx_request); + let tx_websocket_arc = Arc::new(tx_websocket); + + // Create our background service which will accept + // requests, do some processing, then forward to lua + let hyper_make_service = make_service_fn(move |_| { + let tx_request = Arc::clone(&tx_request_arc); + let tx_websocket = Arc::clone(&tx_websocket_arc); + + let handler = service_fn(move |req| { + // TODO: Check if we should upgrade to a + // websocket, handle the request differently + let tx_request = Arc::clone(&tx_request); + let tx_websocket = Arc::clone(&tx_websocket); + async move { + let processed = ProcessedRequest::from_request(req).await?; + if (tx_request.send(processed).await).is_err() { + return Err(LuaError::runtime("Lua handler is busy")); + } + // TODO: Wait for response from lua + let res = Response::new("TODO".to_string()); + Ok::<_, LuaError>(res) + } + }); + + async move { Ok::<_, Infallible>(handler) } + }); + + // Start up our service sched.spawn(async move { let result = builder - .serve(make_service_fn(|_| async move { - Ok::<_, Infallible>(service_fn(|_req| async move { - // TODO: Send this request back to lua - let res = Response::new("TODO".to_string()); - Ok::<_, Infallible>(res) - })) - })) + .http1_only(true) // Web sockets can only use http1 + .http1_keepalive(true) // Web sockets must be kept alive + .serve(hyper_make_service) .with_graceful_shutdown(async move { shutdown_rx.recv().await; }); @@ -66,42 +94,58 @@ where } }); - // Spawn a local thread with access to lua, this will get - // requests and sockets to handle using our lua handlers + // Spawn a local thread with access to lua and the same lifetime sched.spawn_local(async move { loop { + // Wait for either a request or a websocket to handle, + // if we got neither it means both channels were dropped + // and our server has stopped, either gracefully or panic let (req, sock) = tokio::select! { req = rx_request.recv() => (req, None), sock = rx_websocket.recv() => (None, sock), }; - if req.is_none() && sock.is_none() { - break; - } - if let Some(_req) = req { - // TODO: Convert request into lua request struct - let thread_id = sched - .push_back(lua, config.handle_request.clone(), ()) - .expect("Failed to spawn net serve handler"); - // TODO: Send response back to other thread somehow - match sched.wait_for_thread(lua, thread_id).await { - Err(e) => eprintln!("Net serve handler error: {e}"), - Ok(v) => println!("Net serve handler result: {v:?}"), - }; - } - if let Some(_sock) = sock { - let handle_web_socket = config - .handle_web_socket - .as_ref() - .expect("Got web socket but web socket handler is missing"); - // TODO: Convert request into lua request struct - let thread_id = sched - .push_back(lua, handle_web_socket.clone(), ()) - .expect("Failed to spawn net websocket handler"); - // TODO: Send response back to other thread somehow - match sched.wait_for_thread(lua, thread_id).await { - Err(e) => eprintln!("Net websocket handler error: {e}"), - Ok(v) => println!("Net websocket handler result: {v:?}"), - }; + + // NOTE: The closure here is not really necessary, we + // make the closure so that we can use the `?` operator + let handle_req_or_sock = || async { + match (req, sock) { + (None, None) => Ok::<_, LuaError>(true), + (Some(req), _) => { + let req_table = req.into_lua_table(lua)?; + let req_handler = config.handle_request.clone(); + + let thread_id = sched.push_back(lua, req_handler, req_table)?; + let thread_res = sched.wait_for_thread(lua, thread_id).await?; + + // TODO: Send response back to other thread somehow + let handler_res = NetServeResponse::from_lua_multi(thread_res, lua)?; + + Ok(false) + } + (_, Some(_sock)) => { + let sock_handler = config + .handle_web_socket + .as_ref() + .cloned() + .expect("Got web socket but web socket handler is missing"); + + // TODO: Convert websocket into lua websocket struct, give as args + let thread_id = sched.push_back(lua, sock_handler, ())?; + + // NOTE: Web socket handler does not need to send any + // response back, the websocket upgrade response is + // automatically sent above in the background thread(s) + sched.wait_for_thread(lua, thread_id).await?; + + Ok(false) + } + } + }; + + match handle_req_or_sock().await { + Ok(true) => break, + Ok(false) => continue, + Err(e) => eprintln!("{}", LuneError::from(e)), } } }); diff --git a/tests/net/serve/requests.luau b/tests/net/serve/requests.luau index 1762248..1d033a1 100644 --- a/tests/net/serve/requests.luau +++ b/tests/net/serve/requests.luau @@ -16,8 +16,8 @@ local thread = task.delay(1, function() end) local handle = net.serve(PORT, function(request) - print("Request:", request) - print("Responding with", RESPONSE) + -- print("Request:", request) + -- print("Responding with", RESPONSE) assert(request.path == "/some/path") assert(request.query.key == "param2") assert(request.query.key2 == "param3")