diff --git a/src/lune/builtins/net/processing.rs b/src/lune/builtins/net/processing.rs index d770665..837fd3e 100644 --- a/src/lune/builtins/net/processing.rs +++ b/src/lune/builtins/net/processing.rs @@ -8,7 +8,7 @@ use crate::lune::util::TableBuilder; static ID_COUNTER: AtomicUsize = AtomicUsize::new(0); -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)] pub(super) struct ProcessedRequestId(usize); impl ProcessedRequestId { @@ -21,7 +21,7 @@ impl ProcessedRequestId { } pub(super) struct ProcessedRequest { - id: ProcessedRequestId, + pub id: ProcessedRequestId, method: String, path: String, query: Vec<(String, String)>, diff --git a/src/lune/builtins/net/server.rs b/src/lune/builtins/net/server.rs index ab46365..167f10b 100644 --- a/src/lune/builtins/net/server.rs +++ b/src/lune/builtins/net/server.rs @@ -1,20 +1,24 @@ -use std::{convert::Infallible, net::SocketAddr, sync::Arc}; +use std::{collections::HashMap, convert::Infallible, net::SocketAddr, sync::Arc}; use hyper::{ server::{conn::AddrIncoming, Builder}, service::{make_service_fn, service_fn}, - Response, Server, + Server, }; +use hyper_tungstenite::{is_upgrade_request, upgrade, HyperWebsocket}; use mlua::prelude::*; -use tokio::sync::mpsc; +use tokio::sync::{mpsc, oneshot, Mutex}; -use crate::{ - lune::{scheduler::Scheduler, util::TableBuilder}, - LuneError, +use crate::lune::{ + scheduler::Scheduler, + util::{traits::LuaEmitErrorExt, TableBuilder}, }; -use super::{config::ServeConfig, processing::ProcessedRequest, response::NetServeResponse}; +use super::{ + config::ServeConfig, processing::ProcessedRequest, response::NetServeResponse, + websocket::NetWebSocket, +}; pub(super) fn bind_to_localhost(port: u16) -> LuaResult> { let addr = match SocketAddr::try_from(([127, 0, 0, 1], port)) { @@ -49,31 +53,55 @@ where // into our table with the stop function let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1); - // Communicate between background thread(s) and main lua thread using mpsc + // Communicate between background thread(s) and main lua thread using mpsc and oneshot let (tx_request, mut rx_request) = mpsc::channel::(64); - let (tx_websocket, mut rx_websocket) = mpsc::channel::<()>(64); + let (tx_websocket, mut rx_websocket) = mpsc::channel::(64); let tx_request_arc = Arc::new(tx_request); let tx_websocket_arc = Arc::new(tx_websocket); + let response_senders = Arc::new(Mutex::new(HashMap::new())); + let response_senders_bg = Arc::clone(&response_senders); + let response_senders_lua = Arc::clone(&response_senders_bg); + // Create our background service which will accept // requests, do some processing, then forward to lua + let has_websocket_handler = config.handle_web_socket.is_some(); let hyper_make_service = make_service_fn(move |_| { let tx_request = Arc::clone(&tx_request_arc); let tx_websocket = Arc::clone(&tx_websocket_arc); + let response_senders = Arc::clone(&response_senders_bg); - let handler = service_fn(move |req| { - // TODO: Check if we should upgrade to a - // websocket, handle the request differently + let handler = service_fn(move |mut req| { let tx_request = Arc::clone(&tx_request); let tx_websocket = Arc::clone(&tx_websocket); + let response_senders = Arc::clone(&response_senders); async move { - let processed = ProcessedRequest::from_request(req).await?; - if (tx_request.send(processed).await).is_err() { - return Err(LuaError::runtime("Lua handler is busy")); + // FUTURE: Improve error messages when lua is busy and queue is full + if has_websocket_handler && is_upgrade_request(&req) { + let (response, ws) = match upgrade(&mut req, None) { + Err(_) => return Err(LuaError::runtime("Failed to upgrade websocket")), + Ok(v) => v, + }; + if (tx_websocket.send(ws).await).is_err() { + return Err(LuaError::runtime("Lua handler is busy")); + } + Ok(response) + } else { + let processed = ProcessedRequest::from_request(req).await?; + let request_id = processed.id; + if (tx_request.send(processed).await).is_err() { + return Err(LuaError::runtime("Lua handler is busy")); + } + let (response_tx, response_rx) = oneshot::channel::(); + response_senders + .lock() + .await + .insert(request_id, response_tx); + match response_rx.await { + Err(_) => Err(LuaError::runtime("Internal Server Error")), + Ok(r) => r.into_response(), + } } - // TODO: Wait for response from lua - let res = Response::new("TODO".to_string()); - Ok::<_, LuaError>(res) } }); @@ -111,31 +139,42 @@ where match (req, sock) { (None, None) => Ok::<_, LuaError>(true), (Some(req), _) => { - let req_table = req.into_lua_table(lua)?; + let req_id = req.id; let req_handler = config.handle_request.clone(); + let req_table = req.into_lua_table(lua)?; let thread_id = sched.push_back(lua, req_handler, req_table)?; let thread_res = sched.wait_for_thread(lua, thread_id).await?; - // TODO: Send response back to other thread somehow - let handler_res = NetServeResponse::from_lua_multi(thread_res, lua)?; + let response = NetServeResponse::from_lua_multi(thread_res, lua)?; + let response_sender = response_senders_lua + .lock() + .await + .remove(&req_id) + .expect("Response channel was removed unexpectedly"); + + // NOTE: We ignore the error here, if the sender is no longer + // being listened to its because our client disconnected during + // handler being called, which is fine and should not emit errors + response_sender.send(response).ok(); Ok(false) } - (_, Some(_sock)) => { + (_, Some(sock)) => { + let sock = sock.await.into_lua_err()?; + let sock_handler = config .handle_web_socket .as_ref() .cloned() .expect("Got web socket but web socket handler is missing"); - - // TODO: Convert websocket into lua websocket struct, give as args - let thread_id = sched.push_back(lua, sock_handler, ())?; + let sock_table = NetWebSocket::new(sock).into_lua_table(lua)?; // NOTE: Web socket handler does not need to send any // response back, the websocket upgrade response is // automatically sent above in the background thread(s) - sched.wait_for_thread(lua, thread_id).await?; + let thread_id = sched.push_back(lua, sock_handler, sock_table)?; + let _thread_res = sched.wait_for_thread(lua, thread_id).await?; Ok(false) } @@ -145,7 +184,7 @@ where match handle_req_or_sock().await { Ok(true) => break, Ok(false) => continue, - Err(e) => eprintln!("{}", LuneError::from(e)), + Err(e) => lua.emit_error(e), } } }); diff --git a/tests/net/serve/websockets.luau b/tests/net/serve/websockets.luau index 0a49e74..ad53a38 100644 --- a/tests/net/serve/websockets.luau +++ b/tests/net/serve/websockets.luau @@ -4,7 +4,6 @@ local stdio = require("@lune/stdio") local task = require("@lune/task") local PORT = 8081 -local URL = `http://127.0.0.1:{PORT}` local WS_URL = `ws://127.0.0.1:{PORT}` local REQUEST = "Hello from client!" local RESPONSE = "Hello, lune!" @@ -19,7 +18,10 @@ end) local handle = net.serve(PORT, { handleRequest = function() - return RESPONSE + stdio.ewrite("Web socket should upgrade automatically, not pass to the request handler\n") + task.wait(1) + process.exit(1) + return "unreachable" end, handleWebSocket = function(socket) local socketMessage = socket.next() @@ -31,22 +33,9 @@ local handle = net.serve(PORT, { task.cancel(thread) --- Serve should respond to a request we send to it - -local thread2 = task.delay(1, function() - stdio.ewrite("Serve should respond to requests in a reasonable amount of time\n") - task.wait(1) - process.exit(1) -end) - -local response = net.request(URL).body -assert(response == RESPONSE, "Invalid response from server") - -task.cancel(thread2) - -- Web socket responses should also be responded to -local thread3 = task.delay(1, function() +local thread2 = task.delay(1, function() stdio.ewrite("Serve should respond to websockets in a reasonable amount of time\n") task.wait(1) process.exit(1) @@ -62,7 +51,7 @@ assert(socketMessage == RESPONSE, "Invalid web socket response from server") socket.close() -task.cancel(thread3) +task.cancel(thread2) -- Wait for the socket to close and make sure we can't send messages afterwards task.wait()