diff --git a/src/lune/builtins/net/config.rs b/src/lune/builtins/net/config.rs index 3526579..725305a 100644 --- a/src/lune/builtins/net/config.rs +++ b/src/lune/builtins/net/config.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::{collections::HashMap, net::Ipv4Addr}; use mlua::prelude::*; @@ -6,6 +6,18 @@ use reqwest::Method; use super::util::table_to_hash_map; +const DEFAULT_IP_ADDRESS: Ipv4Addr = Ipv4Addr::new(127, 0, 0, 1); + +const WEB_SOCKET_UPDGRADE_REQUEST_HANDLER: &str = r#" +return { + status = 426, + body = "Upgrade Required", + headers = { + Upgrade = "websocket", + }, +} +"#; + // Net request config #[derive(Debug, Clone)] @@ -21,28 +33,29 @@ impl Default for RequestConfigOptions { impl<'lua> FromLua<'lua> for RequestConfigOptions { fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult { - // Nil means default options, table means custom options if let LuaValue::Nil = value { - return Ok(Self::default()); + // Nil means default options + Ok(Self::default()) } else if let LuaValue::Table(tab) = value { - // Extract flags + // Table means custom options let decompress = match tab.get::<_, Option>("decompress") { Ok(decomp) => Ok(decomp.unwrap_or(true)), Err(_) => Err(LuaError::RuntimeError( "Invalid option value for 'decompress' in request config options".to_string(), )), }?; - return Ok(Self { decompress }); + Ok(Self { decompress }) + } else { + // Anything else is invalid + Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "RequestConfigOptions", + message: Some(format!( + "Invalid request config options - expected table or nil, got {}", + value.type_name() + )), + }) } - // Anything else is invalid - Err(LuaError::FromLuaConversionError { - from: value.type_name(), - to: "RequestConfigOptions", - message: Some(format!( - "Invalid request config options - expected table or nil, got {}", - value.type_name() - )), - }) } } @@ -60,17 +73,16 @@ impl FromLua<'_> for RequestConfig { fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult { // If we just got a string we assume its a GET request to a given url if let LuaValue::String(s) = value { - return Ok(Self { + Ok(Self { url: s.to_string_lossy().to_string(), method: Method::GET, query: HashMap::new(), headers: HashMap::new(), body: None, options: Default::default(), - }); - } - // If we got a table we are able to configure the entire request - if let LuaValue::Table(tab) = value { + }) + } else if let LuaValue::Table(tab) = value { + // If we got a table we are able to configure the entire request // Extract url let url = match tab.get::<_, LuaString>("url") { Ok(config_url) => Ok(config_url.to_string_lossy().to_string()), @@ -117,24 +129,25 @@ impl FromLua<'_> for RequestConfig { Err(_) => RequestConfigOptions::default(), }; // All good, validated and we got what we need - return Ok(Self { + Ok(Self { url, method, query, headers, body, options, - }); - }; - // Anything else is invalid - Err(LuaError::FromLuaConversionError { - from: value.type_name(), - to: "RequestConfig", - message: Some(format!( - "Invalid request config - expected string or table, got {}", - value.type_name() - )), - }) + }) + } else { + // Anything else is invalid + Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "RequestConfig", + message: Some(format!( + "Invalid request config - expected string or table, got {}", + value.type_name() + )), + }) + } } } @@ -142,54 +155,72 @@ impl FromLua<'_> for RequestConfig { #[derive(Debug)] pub struct ServeConfig<'a> { + pub address: Ipv4Addr, pub handle_request: LuaFunction<'a>, pub handle_web_socket: Option>, - pub address: Option>, } impl<'lua> FromLua<'lua> for ServeConfig<'lua> { fn from_lua(value: LuaValue<'lua>, lua: &'lua Lua) -> LuaResult { - let message = match &value { - LuaValue::Function(f) => { - return Ok(ServeConfig { - handle_request: f.clone(), - handle_web_socket: None, - address: None, + if let LuaValue::Function(f) = &value { + // Single function = request handler, rest is default + Ok(ServeConfig { + handle_request: f.clone(), + handle_web_socket: None, + address: DEFAULT_IP_ADDRESS.clone(), + }) + } else if let LuaValue::Table(t) = &value { + // Table means custom options + let address: Option = t.get("address")?; + let handle_request: Option = t.get("handleRequest")?; + let handle_web_socket: Option = t.get("handleWebSocket")?; + if handle_request.is_some() || handle_web_socket.is_some() { + let address: Ipv4Addr = match &address { + Some(addr) => { + let addr_str = addr.to_str()?; + + addr_str + .trim_start_matches("http://") + .trim_start_matches("https://") + .parse() + .map_err(|_e| LuaError::FromLuaConversionError { + from: value.type_name(), + to: "ServeConfig", + message: Some(format!( + "IP address format is incorrect - \ + expected an IP in the form 'http://0.0.0.0' or '0.0.0.0', \ + got '{addr_str}'" + )), + })? + } + None => DEFAULT_IP_ADDRESS, + }; + + Ok(Self { + address, + handle_request: handle_request.unwrap_or_else(|| { + lua.load(WEB_SOCKET_UPDGRADE_REQUEST_HANDLER) + .into_function() + .expect("Failed to create default http responder function") + }), + handle_web_socket, + }) + } else { + Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "ServeConfig", + message: Some(String::from( + "Invalid serve config - expected table with 'handleRequest' or 'handleWebSocket' function", + )), }) } - LuaValue::Table(t) => { - let handle_request: Option = t.get("handleRequest")?; - let handle_web_socket: Option = t.get("handleWebSocket")?; - let address: Option = t.get("address")?; - if handle_request.is_some() || handle_web_socket.is_some() { - return Ok(ServeConfig { - handle_request: handle_request.unwrap_or_else(|| { - let chunk = r#" - return { - status = 426, - body = "Upgrade Required", - headers = { - Upgrade = "websocket", - }, - } - "#; - lua.load(chunk) - .into_function() - .expect("Failed to create default http responder function") - }), - handle_web_socket, - address, - }); - } else { - Some("Missing handleRequest and / or handleWebSocket".to_string()) - } - } - _ => None, - }; - Err(LuaError::FromLuaConversionError { - from: value.type_name(), - to: "ServeConfig", - message, - }) + } else { + // Anything else is invalid + Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "ServeConfig", + message: None, + }) + } } }