Net config parsing restructure

This commit is contained in:
Filip Tibell 2024-02-13 16:57:38 +01:00
parent 3702bc98bd
commit fc60d5b031
No known key found for this signature in database

View file

@ -1,4 +1,4 @@
use std::collections::HashMap; use std::{collections::HashMap, net::Ipv4Addr};
use mlua::prelude::*; use mlua::prelude::*;
@ -6,6 +6,18 @@ use reqwest::Method;
use super::util::table_to_hash_map; use super::util::table_to_hash_map;
const DEFAULT_IP_ADDRESS: Ipv4Addr = Ipv4Addr::new(127, 0, 0, 1);
const WEB_SOCKET_UPDGRADE_REQUEST_HANDLER: &str = r#"
return {
status = 426,
body = "Upgrade Required",
headers = {
Upgrade = "websocket",
},
}
"#;
// Net request config // Net request config
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -21,19 +33,19 @@ impl Default for RequestConfigOptions {
impl<'lua> FromLua<'lua> for RequestConfigOptions { impl<'lua> FromLua<'lua> for RequestConfigOptions {
fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult<Self> { fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult<Self> {
// Nil means default options, table means custom options
if let LuaValue::Nil = value { if let LuaValue::Nil = value {
return Ok(Self::default()); // Nil means default options
Ok(Self::default())
} else if let LuaValue::Table(tab) = value { } else if let LuaValue::Table(tab) = value {
// Extract flags // Table means custom options
let decompress = match tab.get::<_, Option<bool>>("decompress") { let decompress = match tab.get::<_, Option<bool>>("decompress") {
Ok(decomp) => Ok(decomp.unwrap_or(true)), Ok(decomp) => Ok(decomp.unwrap_or(true)),
Err(_) => Err(LuaError::RuntimeError( Err(_) => Err(LuaError::RuntimeError(
"Invalid option value for 'decompress' in request config options".to_string(), "Invalid option value for 'decompress' in request config options".to_string(),
)), )),
}?; }?;
return Ok(Self { decompress }); Ok(Self { decompress })
} } else {
// Anything else is invalid // Anything else is invalid
Err(LuaError::FromLuaConversionError { Err(LuaError::FromLuaConversionError {
from: value.type_name(), from: value.type_name(),
@ -45,6 +57,7 @@ impl<'lua> FromLua<'lua> for RequestConfigOptions {
}) })
} }
} }
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct RequestConfig { pub struct RequestConfig {
@ -60,17 +73,16 @@ impl FromLua<'_> for RequestConfig {
fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult<Self> { fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult<Self> {
// If we just got a string we assume its a GET request to a given url // If we just got a string we assume its a GET request to a given url
if let LuaValue::String(s) = value { if let LuaValue::String(s) = value {
return Ok(Self { Ok(Self {
url: s.to_string_lossy().to_string(), url: s.to_string_lossy().to_string(),
method: Method::GET, method: Method::GET,
query: HashMap::new(), query: HashMap::new(),
headers: HashMap::new(), headers: HashMap::new(),
body: None, body: None,
options: Default::default(), options: Default::default(),
}); })
} } else if let LuaValue::Table(tab) = value {
// If we got a table we are able to configure the entire request // If we got a table we are able to configure the entire request
if let LuaValue::Table(tab) = value {
// Extract url // Extract url
let url = match tab.get::<_, LuaString>("url") { let url = match tab.get::<_, LuaString>("url") {
Ok(config_url) => Ok(config_url.to_string_lossy().to_string()), Ok(config_url) => Ok(config_url.to_string_lossy().to_string()),
@ -117,15 +129,15 @@ impl FromLua<'_> for RequestConfig {
Err(_) => RequestConfigOptions::default(), Err(_) => RequestConfigOptions::default(),
}; };
// All good, validated and we got what we need // All good, validated and we got what we need
return Ok(Self { Ok(Self {
url, url,
method, method,
query, query,
headers, headers,
body, body,
options, options,
}); })
}; } else {
// Anything else is invalid // Anything else is invalid
Err(LuaError::FromLuaConversionError { Err(LuaError::FromLuaConversionError {
from: value.type_name(), from: value.type_name(),
@ -137,59 +149,78 @@ impl FromLua<'_> for RequestConfig {
}) })
} }
} }
}
// Net serve config // Net serve config
#[derive(Debug)] #[derive(Debug)]
pub struct ServeConfig<'a> { pub struct ServeConfig<'a> {
pub address: Ipv4Addr,
pub handle_request: LuaFunction<'a>, pub handle_request: LuaFunction<'a>,
pub handle_web_socket: Option<LuaFunction<'a>>, pub handle_web_socket: Option<LuaFunction<'a>>,
pub address: Option<LuaString<'a>>,
} }
impl<'lua> FromLua<'lua> for ServeConfig<'lua> { impl<'lua> FromLua<'lua> for ServeConfig<'lua> {
fn from_lua(value: LuaValue<'lua>, lua: &'lua Lua) -> LuaResult<Self> { fn from_lua(value: LuaValue<'lua>, lua: &'lua Lua) -> LuaResult<Self> {
let message = match &value { if let LuaValue::Function(f) = &value {
LuaValue::Function(f) => { // Single function = request handler, rest is default
return Ok(ServeConfig { Ok(ServeConfig {
handle_request: f.clone(), handle_request: f.clone(),
handle_web_socket: None, handle_web_socket: None,
address: None, address: DEFAULT_IP_ADDRESS.clone(),
}) })
} } else if let LuaValue::Table(t) = &value {
LuaValue::Table(t) => { // Table means custom options
let address: Option<LuaString> = t.get("address")?;
let handle_request: Option<LuaFunction> = t.get("handleRequest")?; let handle_request: Option<LuaFunction> = t.get("handleRequest")?;
let handle_web_socket: Option<LuaFunction> = t.get("handleWebSocket")?; let handle_web_socket: Option<LuaFunction> = t.get("handleWebSocket")?;
let address: Option<LuaString> = t.get("address")?;
if handle_request.is_some() || handle_web_socket.is_some() { if handle_request.is_some() || handle_web_socket.is_some() {
return Ok(ServeConfig { let address: Ipv4Addr = match &address {
handle_request: handle_request.unwrap_or_else(|| { Some(addr) => {
let chunk = r#" let addr_str = addr.to_str()?;
return {
status = 426, addr_str
body = "Upgrade Required", .trim_start_matches("http://")
headers = { .trim_start_matches("https://")
Upgrade = "websocket", .parse()
}, .map_err(|_e| LuaError::FromLuaConversionError {
from: value.type_name(),
to: "ServeConfig",
message: Some(format!(
"IP address format is incorrect - \
expected an IP in the form 'http://0.0.0.0' or '0.0.0.0', \
got '{addr_str}'"
)),
})?
} }
"#; None => DEFAULT_IP_ADDRESS,
lua.load(chunk) };
Ok(Self {
address,
handle_request: handle_request.unwrap_or_else(|| {
lua.load(WEB_SOCKET_UPDGRADE_REQUEST_HANDLER)
.into_function() .into_function()
.expect("Failed to create default http responder function") .expect("Failed to create default http responder function")
}), }),
handle_web_socket, handle_web_socket,
address, })
});
} else { } else {
Some("Missing handleRequest and / or handleWebSocket".to_string())
}
}
_ => None,
};
Err(LuaError::FromLuaConversionError { Err(LuaError::FromLuaConversionError {
from: value.type_name(), from: value.type_name(),
to: "ServeConfig", to: "ServeConfig",
message, message: Some(String::from(
"Invalid serve config - expected table with 'handleRequest' or 'handleWebSocket' function",
)),
})
}
} else {
// Anything else is invalid
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "ServeConfig",
message: None,
}) })
} }
} }
}