diff --git a/src/lune/builtins/net/config.rs b/src/lune/builtins/net/config.rs index b754ca8..030288e 100644 --- a/src/lune/builtins/net/config.rs +++ b/src/lune/builtins/net/config.rs @@ -140,9 +140,11 @@ impl FromLua<'_> for RequestConfig { // Net serve config +#[derive(Debug)] pub struct ServeConfig<'a> { pub handle_request: LuaFunction<'a>, pub handle_web_socket: Option>, + pub address: Option>, } impl<'lua> FromLua<'lua> for ServeConfig<'lua> { @@ -152,11 +154,13 @@ impl<'lua> FromLua<'lua> for ServeConfig<'lua> { return Ok(ServeConfig { handle_request: f.clone(), handle_web_socket: None, + address: None, }) } LuaValue::Table(t) => { let handle_request: Option = t.raw_get("handleRequest")?; let handle_web_socket: Option = t.raw_get("handleWebSocket")?; + let address: Option = t.raw_get("address")?; if handle_request.is_some() || handle_web_socket.is_some() { return Ok(ServeConfig { handle_request: handle_request.unwrap_or_else(|| { @@ -174,6 +178,7 @@ impl<'lua> FromLua<'lua> for ServeConfig<'lua> { .expect("Failed to create default http responder function") }), handle_web_socket, + address, }); } else { Some("Missing handleRequest and / or handleWebSocket".to_string()) diff --git a/src/lune/builtins/net/mod.rs b/src/lune/builtins/net/mod.rs index cf9e7ab..78c3397 100644 --- a/src/lune/builtins/net/mod.rs +++ b/src/lune/builtins/net/mod.rs @@ -1,10 +1,15 @@ +use std::net::Ipv4Addr; + use mlua::prelude::*; use hyper::header::CONTENT_ENCODING; use crate::lune::{scheduler::Scheduler, util::TableBuilder}; -use self::{server::create_server, util::header_map_to_table}; +use self::{ + server::{bind_to_addr, create_server}, + util::header_map_to_table, +}; use super::serde::{ compress_decompress::{decompress, CompressDecompressFormat}, @@ -21,9 +26,10 @@ mod websocket; use client::{NetClient, NetClientBuilder}; use config::{RequestConfig, ServeConfig}; -use server::bind_to_localhost; use websocket::NetWebSocket; +const DEFAULT_IP_ADDRESS: Ipv4Addr = Ipv4Addr::new(127, 0, 0, 1); + pub fn create(lua: &'static Lua) -> LuaResult { NetClientBuilder::new() .headers(&[("User-Agent", create_user_agent_header())])? @@ -137,7 +143,22 @@ where .app_data_ref::<&Scheduler>() .expect("Lua struct is missing scheduler"); - let builder = bind_to_localhost(port)?; + let address: Ipv4Addr = match &config.address { + Some(addr) => { + let addr_str = addr.to_str()?; + + addr_str + .trim_start_matches("http://") + .trim_start_matches("https://") + .parse() + .map_err(|_e| LuaError::RuntimeError(format!( + "IP address format is incorrect (expected an IP in the form 'http://0.0.0.0' or '0.0.0.0', got '{addr_str}')" + )))? + } + None => DEFAULT_IP_ADDRESS, + }; + + let builder = bind_to_addr(address, port)?; create_server(lua, &sched, config, builder) } diff --git a/src/lune/builtins/net/server.rs b/src/lune/builtins/net/server.rs index 97f22c7..8fb4612 100644 --- a/src/lune/builtins/net/server.rs +++ b/src/lune/builtins/net/server.rs @@ -1,4 +1,9 @@ -use std::{collections::HashMap, convert::Infallible, net::SocketAddr, sync::Arc}; +use std::{ + collections::HashMap, + convert::Infallible, + net::{Ipv4Addr, SocketAddr}, + sync::Arc, +}; use hyper::{ server::{conn::AddrIncoming, Builder}, @@ -20,12 +25,13 @@ use super::{ websocket::NetWebSocket, }; -pub(super) fn bind_to_localhost(port: u16) -> LuaResult> { - let addr = SocketAddr::from(([127, 0, 0, 1], port)); +pub(super) fn bind_to_addr(address: Ipv4Addr, port: u16) -> LuaResult> { + let addr = SocketAddr::from((address, port)); + match Server::try_bind(&addr) { Ok(b) => Ok(b), Err(e) => Err(LuaError::external(format!( - "Failed to bind to localhost on port {port}\n{}", + "Failed to bind to {addr}\n{}", e.to_string() .replace("error creating server listener: ", "> ") ))), diff --git a/tests/net/serve/requests.luau b/tests/net/serve/requests.luau index b17ad2d..c588992 100644 --- a/tests/net/serve/requests.luau +++ b/tests/net/serve/requests.luau @@ -5,8 +5,13 @@ local task = require("@lune/task") local PORT = 8080 local URL = `http://127.0.0.1:{PORT}` +local URL_EXTERNAL = `http://0.0.0.0` local RESPONSE = "Hello, lune!" +-- A server should never be running before testing +local isRunning = pcall(net.request, URL) +assert(not isRunning, `a server is already running at {URL}`) + -- Serve should not block the thread from continuing local thread = task.delay(1, function() @@ -77,3 +82,32 @@ assert( or string.find(message, "shut down"), "The error message for calling stop twice on the net serve handle should be descriptive" ) + +-- Serve should be able to bind to other IP addresses +local handle2 = net.serve(PORT, { + address = URL_EXTERNAL, + handleRequest = function(request) + return `Response from {URL_EXTERNAL}:{PORT}` + end, +}) + +-- And any requests to that IP should succeed +local response3 = net.request(`{URL_EXTERNAL}:{PORT}`).body +assert(response3 ~= nil, "Invalid response from server") + +handle2.stop() + +-- Attempting to serve with a malformed IP address should throw an error +local success3 = pcall(function() + net.serve(8080, { + address = "a.b.c.d", + handleRequest = function() + return RESPONSE + end, + }) +end) + +assert(not success3, "Server was created with malformed address") + +-- We have to manually exit so Windows CI doesn't get stuck forever +process.exit(0)