From c9ce29741b0e658378ee1cd91210a1b7e9e5b068 Mon Sep 17 00:00:00 2001 From: Filip Tibell Date: Sat, 30 Dec 2023 17:38:58 +0100 Subject: [PATCH] Add support for multiple query & header values in net request --- CHANGELOG.md | 32 +++++++++++++ src/lune/builtins/net/config.rs | 34 ++++---------- src/lune/builtins/net/mod.rs | 64 ++++++++++++-------------- src/lune/builtins/net/util.rs | 81 +++++++++++++++++++++++++++++++++ types/net.luau | 10 ++-- 5 files changed, 159 insertions(+), 62 deletions(-) create mode 100644 src/lune/builtins/net/util.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 18d2646..ef386b6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,38 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Added + +- Added support for multiple values for a single query, and multiple values for a single header, in `net.request`. This is a part of the HTTP specification that is not widely used but that may be useful in certain cases. To clarify: + + - Single values remain unchanged and will work exactly the same as before.
+ + ```lua + -- https://example.com/?foo=bar&baz=qux + local net = require("@lune/net") + net.request({ + url = "example.com", + query = { + foo = "bar", + baz = "qux", + } + }) + ``` + + - Multiple values _on a single query / header_ are represented as an ordered array of strings.
+ + ```lua + -- https://example.com/?foo=first&foo=second&foo=third&bar=baz + local net = require("@lune/net") + net.request({ + url = "example.com", + query = { + foo = { "first", "second", "third" }, + bar = "baz", + } + }) + ``` + ### Changed - Update to Luau version `0.606`. diff --git a/src/lune/builtins/net/config.rs b/src/lune/builtins/net/config.rs index 59c1e3e..b754ca8 100644 --- a/src/lune/builtins/net/config.rs +++ b/src/lune/builtins/net/config.rs @@ -4,6 +4,8 @@ use mlua::prelude::*; use reqwest::Method; +use super::util::table_to_hash_map; + // Net request config #[derive(Debug, Clone)] @@ -45,17 +47,17 @@ impl<'lua> FromLua<'lua> for RequestConfigOptions { } #[derive(Debug, Clone)] -pub struct RequestConfig<'a> { +pub struct RequestConfig { pub url: String, pub method: Method, - pub query: HashMap, LuaString<'a>>, - pub headers: HashMap, LuaString<'a>>, + pub query: HashMap>, + pub headers: HashMap>, pub body: Option>, pub options: RequestConfigOptions, } -impl<'lua> FromLua<'lua> for RequestConfig<'lua> { - fn from_lua(value: LuaValue<'lua>, lua: &'lua Lua) -> LuaResult { +impl FromLua<'_> for RequestConfig { + fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult { // If we just got a string we assume its a GET request to a given url if let LuaValue::String(s) = value { return Ok(Self { @@ -72,9 +74,7 @@ impl<'lua> FromLua<'lua> for RequestConfig<'lua> { // Extract url let url = match tab.raw_get::<_, LuaString>("url") { Ok(config_url) => Ok(config_url.to_string_lossy().to_string()), - Err(_) => Err(LuaError::RuntimeError( - "Missing 'url' in request config".to_string(), - )), + Err(_) => Err(LuaError::runtime("Missing 'url' in request config")), }?; // Extract method let method = match tab.raw_get::<_, LuaString>("method") { @@ -83,26 +83,12 @@ impl<'lua> FromLua<'lua> for RequestConfig<'lua> { }; // Extract query let query = match tab.raw_get::<_, LuaTable>("query") { - Ok(config_headers) => { - let mut lua_headers = HashMap::new(); - for pair in config_headers.pairs::() { - let (key, value) = pair?.to_owned(); - lua_headers.insert(key, value); - } - lua_headers - } + Ok(tab) => table_to_hash_map(tab, "query")?, Err(_) => HashMap::new(), }; // Extract headers let headers = match tab.raw_get::<_, LuaTable>("headers") { - Ok(config_headers) => { - let mut lua_headers = HashMap::new(); - for pair in config_headers.pairs::() { - let (key, value) = pair?.to_owned(); - lua_headers.insert(key, value); - } - lua_headers - } + Ok(tab) => table_to_hash_map(tab, "headers")?, Err(_) => HashMap::new(), }; // Extract body diff --git a/src/lune/builtins/net/mod.rs b/src/lune/builtins/net/mod.rs index 615136d..cf9e7ab 100644 --- a/src/lune/builtins/net/mod.rs +++ b/src/lune/builtins/net/mod.rs @@ -1,12 +1,10 @@ -use std::collections::HashMap; - use mlua::prelude::*; -use hyper::header::{CONTENT_ENCODING, CONTENT_LENGTH}; +use hyper::header::CONTENT_ENCODING; use crate::lune::{scheduler::Scheduler, util::TableBuilder}; -use self::server::create_server; +use self::{server::create_server, util::header_map_to_table}; use super::serde::{ compress_decompress::{decompress, CompressDecompressFormat}, @@ -18,6 +16,7 @@ mod config; mod processing; mod response; mod server; +mod util; mod websocket; use client::{NetClient, NetClientBuilder}; @@ -61,18 +60,25 @@ fn net_json_decode<'lua>(lua: &'lua Lua, json: LuaString<'lua>) -> LuaResult(lua: &'lua Lua, config: RequestConfig<'lua>) -> LuaResult> +async fn net_request<'lua>(lua: &'lua Lua, config: RequestConfig) -> LuaResult> where 'lua: 'static, // FIXME: Get rid of static lifetime bound here { // Create and send the request let client = NetClient::from_registry(lua); let mut request = client.request(config.method, &config.url); - for (query, value) in config.query { - request = request.query(&[(query.to_str()?, value.to_str()?)]); + for (query, values) in config.query { + request = request.query( + &values + .iter() + .map(|v| (query.as_str(), v)) + .collect::>(), + ); } - for (header, value) in config.headers { - request = request.header(header.to_str()?, value.to_str()?); + for (header, values) in config.headers { + for value in values { + request = request.header(header.as_str(), value); + } } let res = request .body(config.body.unwrap_or_default()) @@ -82,44 +88,32 @@ where // Extract status, headers let res_status = res.status().as_u16(); let res_status_text = res.status().canonical_reason(); - let mut res_headers = res - .headers() - .iter() - .map(|(name, value)| { - ( - name.as_str().to_string(), - value.to_str().unwrap().to_owned(), - ) - }) - .collect::>(); + let res_headers = res.headers().clone(); // Read response bytes let mut res_bytes = res.bytes().await.into_lua_err()?.to_vec(); + let mut res_decompressed = false; // Check for extra options, decompression if config.options.decompress { - // NOTE: Header names are guaranteed to be lowercase because of the above - // transformations of them into the hashmap, so we can compare directly - let format = res_headers.iter().find_map(|(name, val)| { - if name == CONTENT_ENCODING.as_str() { - CompressDecompressFormat::detect_from_header_str(val) - } else { - None - } - }); - if let Some(format) = format { + let decompress_format = res_headers + .iter() + .find(|(name, _)| { + name.as_str() + .eq_ignore_ascii_case(CONTENT_ENCODING.as_str()) + }) + .and_then(|(_, value)| value.to_str().ok()) + .and_then(CompressDecompressFormat::detect_from_header_str); + if let Some(format) = decompress_format { res_bytes = decompress(format, res_bytes).await?; - let content_encoding_header_str = CONTENT_ENCODING.as_str(); - let content_length_header_str = CONTENT_LENGTH.as_str(); - res_headers.retain(|name, _| { - name != content_encoding_header_str && name != content_length_header_str - }); + res_decompressed = true; } } // Construct and return a readonly lua table with results + let res_headers_lua = header_map_to_table(lua, res_headers, res_decompressed)?; TableBuilder::new(lua)? .with_value("ok", (200..300).contains(&res_status))? .with_value("statusCode", res_status)? .with_value("statusMessage", res_status_text)? - .with_value("headers", res_headers)? + .with_value("headers", res_headers_lua)? .with_value("body", lua.create_string(&res_bytes)?)? .build_readonly() } diff --git a/src/lune/builtins/net/util.rs b/src/lune/builtins/net/util.rs new file mode 100644 index 0000000..fa1e2ad --- /dev/null +++ b/src/lune/builtins/net/util.rs @@ -0,0 +1,81 @@ +use std::collections::HashMap; + +use hyper::{ + header::{CONTENT_ENCODING, CONTENT_LENGTH}, + HeaderMap, +}; + +use mlua::prelude::*; + +use crate::lune::util::TableBuilder; + +pub fn header_map_to_table( + lua: &Lua, + headers: HeaderMap, + remove_content_headers: bool, +) -> LuaResult { + let mut res_headers: HashMap> = HashMap::new(); + for (name, value) in headers.iter() { + let name = name.as_str(); + let value = value.to_str().unwrap().to_owned(); + if let Some(existing) = res_headers.get_mut(name) { + existing.push(value); + } else { + res_headers.insert(name.to_owned(), vec![value]); + } + } + + if remove_content_headers { + let content_encoding_header_str = CONTENT_ENCODING.as_str(); + let content_length_header_str = CONTENT_LENGTH.as_str(); + res_headers.retain(|name, _| { + name != content_encoding_header_str && name != content_length_header_str + }); + } + + let mut builder = TableBuilder::new(lua)?; + for (name, mut values) in res_headers { + if values.len() == 1 { + let value = values.pop().unwrap().into_lua(lua)?; + builder = builder.with_value(name, value)?; + } else { + let values = TableBuilder::new(lua)? + .with_sequential_values(values)? + .build_readonly()? + .into_lua(lua)?; + builder = builder.with_value(name, values)?; + } + } + + builder.build_readonly() +} + +pub fn table_to_hash_map( + tab: LuaTable, + tab_origin_key: &'static str, +) -> LuaResult>> { + let mut map = HashMap::new(); + + for pair in tab.pairs::() { + let (key, value) = pair?; + match value { + LuaValue::String(s) => { + map.insert(key, vec![s.to_str()?.to_owned()]); + } + LuaValue::Table(t) => { + let mut values = Vec::new(); + for value in t.sequence_values::() { + values.push(value?.to_str()?.to_owned()); + } + map.insert(key, values); + } + _ => { + return Err(LuaError::runtime(format!( + "Value for '{tab_origin_key}' must be a string or array of strings", + ))) + } + } + } + + Ok(map) +} diff --git a/types/net.luau b/types/net.luau index 52eb28d..4f45a00 100644 --- a/types/net.luau +++ b/types/net.luau @@ -1,5 +1,9 @@ export type HttpMethod = "GET" | "POST" | "PUT" | "DELETE" | "HEAD" | "OPTIONS" | "PATCH" +type HttpQueryOrHeaderMap = { [string]: string | { string } } +export type HttpQueryMap = HttpQueryOrHeaderMap +export type HttpHeaderMap = HttpQueryOrHeaderMap + --[=[ @interface FetchParamsOptions @within Net @@ -33,8 +37,8 @@ export type FetchParams = { url: string, method: HttpMethod?, body: string?, - query: { [string]: string }?, - headers: { [string]: string }?, + query: HttpQueryMap?, + headers: HttpHeaderMap?, options: FetchParamsOptions?, } @@ -56,7 +60,7 @@ export type FetchResponse = { ok: boolean, statusCode: number, statusMessage: string, - headers: { [string]: string }, + headers: HttpHeaderMap, body: string, }