From 65def158d235e6b63f3b5a4ebd6581b846397ca8 Mon Sep 17 00:00:00 2001 From: Filip Tibell Date: Mon, 12 Feb 2024 16:09:40 +0100 Subject: [PATCH] Add back util and config files --- src/lune/builtins/net/config.rs | 195 ++++++++++++++++++++++++++++++++ src/lune/builtins/net/mod.rs | 14 ++- src/lune/builtins/net/util.rs | 81 +++++++++++++ 3 files changed, 287 insertions(+), 3 deletions(-) create mode 100644 src/lune/builtins/net/config.rs create mode 100644 src/lune/builtins/net/util.rs diff --git a/src/lune/builtins/net/config.rs b/src/lune/builtins/net/config.rs new file mode 100644 index 0000000..030288e --- /dev/null +++ b/src/lune/builtins/net/config.rs @@ -0,0 +1,195 @@ +use std::collections::HashMap; + +use mlua::prelude::*; + +use reqwest::Method; + +use super::util::table_to_hash_map; + +// Net request config + +#[derive(Debug, Clone)] +pub struct RequestConfigOptions { + pub decompress: bool, +} + +impl Default for RequestConfigOptions { + fn default() -> Self { + Self { decompress: true } + } +} + +impl<'lua> FromLua<'lua> for RequestConfigOptions { + fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult { + // Nil means default options, table means custom options + if let LuaValue::Nil = value { + return Ok(Self::default()); + } else if let LuaValue::Table(tab) = value { + // Extract flags + let decompress = match tab.raw_get::<_, Option>("decompress") { + Ok(decomp) => Ok(decomp.unwrap_or(true)), + Err(_) => Err(LuaError::RuntimeError( + "Invalid option value for 'decompress' in request config options".to_string(), + )), + }?; + return Ok(Self { decompress }); + } + // Anything else is invalid + Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "RequestConfigOptions", + message: Some(format!( + "Invalid request config options - expected table or nil, got {}", + value.type_name() + )), + }) + } +} + +#[derive(Debug, Clone)] +pub struct RequestConfig { + pub url: String, + pub method: Method, + pub query: HashMap>, + pub headers: HashMap>, + pub body: Option>, + pub options: RequestConfigOptions, +} + +impl FromLua<'_> for RequestConfig { + fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult { + // If we just got a string we assume its a GET request to a given url + if let LuaValue::String(s) = value { + return Ok(Self { + url: s.to_string_lossy().to_string(), + method: Method::GET, + query: HashMap::new(), + headers: HashMap::new(), + body: None, + options: Default::default(), + }); + } + // If we got a table we are able to configure the entire request + if let LuaValue::Table(tab) = value { + // Extract url + let url = match tab.raw_get::<_, LuaString>("url") { + Ok(config_url) => Ok(config_url.to_string_lossy().to_string()), + Err(_) => Err(LuaError::runtime("Missing 'url' in request config")), + }?; + // Extract method + let method = match tab.raw_get::<_, LuaString>("method") { + Ok(config_method) => config_method.to_string_lossy().trim().to_ascii_uppercase(), + Err(_) => "GET".to_string(), + }; + // Extract query + let query = match tab.raw_get::<_, LuaTable>("query") { + Ok(tab) => table_to_hash_map(tab, "query")?, + Err(_) => HashMap::new(), + }; + // Extract headers + let headers = match tab.raw_get::<_, LuaTable>("headers") { + Ok(tab) => table_to_hash_map(tab, "headers")?, + Err(_) => HashMap::new(), + }; + // Extract body + let body = match tab.raw_get::<_, LuaString>("body") { + Ok(config_body) => Some(config_body.as_bytes().to_owned()), + Err(_) => None, + }; + // Convert method string into proper enum + let method = method.trim().to_ascii_uppercase(); + let method = match method.as_ref() { + "GET" => Ok(Method::GET), + "POST" => Ok(Method::POST), + "PUT" => Ok(Method::PUT), + "DELETE" => Ok(Method::DELETE), + "HEAD" => Ok(Method::HEAD), + "OPTIONS" => Ok(Method::OPTIONS), + "PATCH" => Ok(Method::PATCH), + _ => Err(LuaError::RuntimeError(format!( + "Invalid request config method '{}'", + &method + ))), + }?; + // Parse any extra options given + let options = match tab.raw_get::<_, LuaValue>("options") { + Ok(opts) => RequestConfigOptions::from_lua(opts, lua)?, + Err(_) => RequestConfigOptions::default(), + }; + // All good, validated and we got what we need + return Ok(Self { + url, + method, + query, + headers, + body, + options, + }); + }; + // Anything else is invalid + Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "RequestConfig", + message: Some(format!( + "Invalid request config - expected string or table, got {}", + value.type_name() + )), + }) + } +} + +// Net serve config + +#[derive(Debug)] +pub struct ServeConfig<'a> { + pub handle_request: LuaFunction<'a>, + pub handle_web_socket: Option>, + pub address: Option>, +} + +impl<'lua> FromLua<'lua> for ServeConfig<'lua> { + fn from_lua(value: LuaValue<'lua>, lua: &'lua Lua) -> LuaResult { + let message = match &value { + LuaValue::Function(f) => { + return Ok(ServeConfig { + handle_request: f.clone(), + handle_web_socket: None, + address: None, + }) + } + LuaValue::Table(t) => { + let handle_request: Option = t.raw_get("handleRequest")?; + let handle_web_socket: Option = t.raw_get("handleWebSocket")?; + let address: Option = t.raw_get("address")?; + if handle_request.is_some() || handle_web_socket.is_some() { + return Ok(ServeConfig { + handle_request: handle_request.unwrap_or_else(|| { + let chunk = r#" + return { + status = 426, + body = "Upgrade Required", + headers = { + Upgrade = "websocket", + }, + } + "#; + lua.load(chunk) + .into_function() + .expect("Failed to create default http responder function") + }), + handle_web_socket, + address, + }); + } else { + Some("Missing handleRequest and / or handleWebSocket".to_string()) + } + } + _ => None, + }; + Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "ServeConfig", + message, + }) + } +} diff --git a/src/lune/builtins/net/mod.rs b/src/lune/builtins/net/mod.rs index 49b6886..c5cb035 100644 --- a/src/lune/builtins/net/mod.rs +++ b/src/lune/builtins/net/mod.rs @@ -2,8 +2,13 @@ use mlua::prelude::*; +mod config; +mod util; + use crate::lune::util::TableBuilder; +use self::config::{RequestConfig, ServeConfig}; + use super::serde::encode_decode::{EncodeDecodeConfig, EncodeDecodeFormat}; pub fn create(lua: &Lua) -> LuaResult { @@ -38,15 +43,18 @@ fn net_json_decode<'lua>(lua: &'lua Lua, json: LuaString<'lua>) -> LuaResult(lua: &'lua Lua, config: ()) -> LuaResult> { +async fn net_request<'lua>(lua: &'lua Lua, config: RequestConfig) -> LuaResult> { unimplemented!() } -async fn net_socket<'lua>(lua: &'lua Lua, url: String) -> LuaResult { +async fn net_socket<'lua>(lua: &'lua Lua, url: String) -> LuaResult> { unimplemented!() } -async fn net_serve<'lua>(lua: &'lua Lua, (port, config): (u16, ())) -> LuaResult> { +async fn net_serve<'lua>( + lua: &'lua Lua, + (port, config): (u16, ServeConfig<'lua>), +) -> LuaResult> { unimplemented!() } diff --git a/src/lune/builtins/net/util.rs b/src/lune/builtins/net/util.rs new file mode 100644 index 0000000..fa1e2ad --- /dev/null +++ b/src/lune/builtins/net/util.rs @@ -0,0 +1,81 @@ +use std::collections::HashMap; + +use hyper::{ + header::{CONTENT_ENCODING, CONTENT_LENGTH}, + HeaderMap, +}; + +use mlua::prelude::*; + +use crate::lune::util::TableBuilder; + +pub fn header_map_to_table( + lua: &Lua, + headers: HeaderMap, + remove_content_headers: bool, +) -> LuaResult { + let mut res_headers: HashMap> = HashMap::new(); + for (name, value) in headers.iter() { + let name = name.as_str(); + let value = value.to_str().unwrap().to_owned(); + if let Some(existing) = res_headers.get_mut(name) { + existing.push(value); + } else { + res_headers.insert(name.to_owned(), vec![value]); + } + } + + if remove_content_headers { + let content_encoding_header_str = CONTENT_ENCODING.as_str(); + let content_length_header_str = CONTENT_LENGTH.as_str(); + res_headers.retain(|name, _| { + name != content_encoding_header_str && name != content_length_header_str + }); + } + + let mut builder = TableBuilder::new(lua)?; + for (name, mut values) in res_headers { + if values.len() == 1 { + let value = values.pop().unwrap().into_lua(lua)?; + builder = builder.with_value(name, value)?; + } else { + let values = TableBuilder::new(lua)? + .with_sequential_values(values)? + .build_readonly()? + .into_lua(lua)?; + builder = builder.with_value(name, values)?; + } + } + + builder.build_readonly() +} + +pub fn table_to_hash_map( + tab: LuaTable, + tab_origin_key: &'static str, +) -> LuaResult>> { + let mut map = HashMap::new(); + + for pair in tab.pairs::() { + let (key, value) = pair?; + match value { + LuaValue::String(s) => { + map.insert(key, vec![s.to_str()?.to_owned()]); + } + LuaValue::Table(t) => { + let mut values = Vec::new(); + for value in t.sequence_values::() { + values.push(value?.to_str()?.to_owned()); + } + map.insert(key, values); + } + _ => { + return Err(LuaError::runtime(format!( + "Value for '{tab_origin_key}' must be a string or array of strings", + ))) + } + } + } + + Ok(map) +}