diff --git a/CHANGELOG.md b/CHANGELOG.md
index 18d2646..ef386b6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -10,6 +10,38 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
## Unreleased
+### Added
+
+- Added support for multiple values for a single query, and multiple values for a single header, in `net.request`. This is a part of the HTTP specification that is not widely used but that may be useful in certain cases. To clarify:
+
+ - Single values remain unchanged and will work exactly the same as before.
+
+ ```lua
+ -- https://example.com/?foo=bar&baz=qux
+ local net = require("@lune/net")
+ net.request({
+ url = "example.com",
+ query = {
+ foo = "bar",
+ baz = "qux",
+ }
+ })
+ ```
+
+ - Multiple values _on a single query / header_ are represented as an ordered array of strings.
+
+ ```lua
+ -- https://example.com/?foo=first&foo=second&foo=third&bar=baz
+ local net = require("@lune/net")
+ net.request({
+ url = "example.com",
+ query = {
+ foo = { "first", "second", "third" },
+ bar = "baz",
+ }
+ })
+ ```
+
### Changed
- Update to Luau version `0.606`.
diff --git a/src/lune/builtins/net/config.rs b/src/lune/builtins/net/config.rs
index 59c1e3e..b754ca8 100644
--- a/src/lune/builtins/net/config.rs
+++ b/src/lune/builtins/net/config.rs
@@ -4,6 +4,8 @@ use mlua::prelude::*;
use reqwest::Method;
+use super::util::table_to_hash_map;
+
// Net request config
#[derive(Debug, Clone)]
@@ -45,17 +47,17 @@ impl<'lua> FromLua<'lua> for RequestConfigOptions {
}
#[derive(Debug, Clone)]
-pub struct RequestConfig<'a> {
+pub struct RequestConfig {
pub url: String,
pub method: Method,
- pub query: HashMap, LuaString<'a>>,
- pub headers: HashMap, LuaString<'a>>,
+ pub query: HashMap>,
+ pub headers: HashMap>,
pub body: Option>,
pub options: RequestConfigOptions,
}
-impl<'lua> FromLua<'lua> for RequestConfig<'lua> {
- fn from_lua(value: LuaValue<'lua>, lua: &'lua Lua) -> LuaResult {
+impl FromLua<'_> for RequestConfig {
+ fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult {
// If we just got a string we assume its a GET request to a given url
if let LuaValue::String(s) = value {
return Ok(Self {
@@ -72,9 +74,7 @@ impl<'lua> FromLua<'lua> for RequestConfig<'lua> {
// Extract url
let url = match tab.raw_get::<_, LuaString>("url") {
Ok(config_url) => Ok(config_url.to_string_lossy().to_string()),
- Err(_) => Err(LuaError::RuntimeError(
- "Missing 'url' in request config".to_string(),
- )),
+ Err(_) => Err(LuaError::runtime("Missing 'url' in request config")),
}?;
// Extract method
let method = match tab.raw_get::<_, LuaString>("method") {
@@ -83,26 +83,12 @@ impl<'lua> FromLua<'lua> for RequestConfig<'lua> {
};
// Extract query
let query = match tab.raw_get::<_, LuaTable>("query") {
- Ok(config_headers) => {
- let mut lua_headers = HashMap::new();
- for pair in config_headers.pairs::() {
- let (key, value) = pair?.to_owned();
- lua_headers.insert(key, value);
- }
- lua_headers
- }
+ Ok(tab) => table_to_hash_map(tab, "query")?,
Err(_) => HashMap::new(),
};
// Extract headers
let headers = match tab.raw_get::<_, LuaTable>("headers") {
- Ok(config_headers) => {
- let mut lua_headers = HashMap::new();
- for pair in config_headers.pairs::() {
- let (key, value) = pair?.to_owned();
- lua_headers.insert(key, value);
- }
- lua_headers
- }
+ Ok(tab) => table_to_hash_map(tab, "headers")?,
Err(_) => HashMap::new(),
};
// Extract body
diff --git a/src/lune/builtins/net/mod.rs b/src/lune/builtins/net/mod.rs
index 615136d..cf9e7ab 100644
--- a/src/lune/builtins/net/mod.rs
+++ b/src/lune/builtins/net/mod.rs
@@ -1,12 +1,10 @@
-use std::collections::HashMap;
-
use mlua::prelude::*;
-use hyper::header::{CONTENT_ENCODING, CONTENT_LENGTH};
+use hyper::header::CONTENT_ENCODING;
use crate::lune::{scheduler::Scheduler, util::TableBuilder};
-use self::server::create_server;
+use self::{server::create_server, util::header_map_to_table};
use super::serde::{
compress_decompress::{decompress, CompressDecompressFormat},
@@ -18,6 +16,7 @@ mod config;
mod processing;
mod response;
mod server;
+mod util;
mod websocket;
use client::{NetClient, NetClientBuilder};
@@ -61,18 +60,25 @@ fn net_json_decode<'lua>(lua: &'lua Lua, json: LuaString<'lua>) -> LuaResult(lua: &'lua Lua, config: RequestConfig<'lua>) -> LuaResult>
+async fn net_request<'lua>(lua: &'lua Lua, config: RequestConfig) -> LuaResult>
where
'lua: 'static, // FIXME: Get rid of static lifetime bound here
{
// Create and send the request
let client = NetClient::from_registry(lua);
let mut request = client.request(config.method, &config.url);
- for (query, value) in config.query {
- request = request.query(&[(query.to_str()?, value.to_str()?)]);
+ for (query, values) in config.query {
+ request = request.query(
+ &values
+ .iter()
+ .map(|v| (query.as_str(), v))
+ .collect::>(),
+ );
}
- for (header, value) in config.headers {
- request = request.header(header.to_str()?, value.to_str()?);
+ for (header, values) in config.headers {
+ for value in values {
+ request = request.header(header.as_str(), value);
+ }
}
let res = request
.body(config.body.unwrap_or_default())
@@ -82,44 +88,32 @@ where
// Extract status, headers
let res_status = res.status().as_u16();
let res_status_text = res.status().canonical_reason();
- let mut res_headers = res
- .headers()
- .iter()
- .map(|(name, value)| {
- (
- name.as_str().to_string(),
- value.to_str().unwrap().to_owned(),
- )
- })
- .collect::>();
+ let res_headers = res.headers().clone();
// Read response bytes
let mut res_bytes = res.bytes().await.into_lua_err()?.to_vec();
+ let mut res_decompressed = false;
// Check for extra options, decompression
if config.options.decompress {
- // NOTE: Header names are guaranteed to be lowercase because of the above
- // transformations of them into the hashmap, so we can compare directly
- let format = res_headers.iter().find_map(|(name, val)| {
- if name == CONTENT_ENCODING.as_str() {
- CompressDecompressFormat::detect_from_header_str(val)
- } else {
- None
- }
- });
- if let Some(format) = format {
+ let decompress_format = res_headers
+ .iter()
+ .find(|(name, _)| {
+ name.as_str()
+ .eq_ignore_ascii_case(CONTENT_ENCODING.as_str())
+ })
+ .and_then(|(_, value)| value.to_str().ok())
+ .and_then(CompressDecompressFormat::detect_from_header_str);
+ if let Some(format) = decompress_format {
res_bytes = decompress(format, res_bytes).await?;
- let content_encoding_header_str = CONTENT_ENCODING.as_str();
- let content_length_header_str = CONTENT_LENGTH.as_str();
- res_headers.retain(|name, _| {
- name != content_encoding_header_str && name != content_length_header_str
- });
+ res_decompressed = true;
}
}
// Construct and return a readonly lua table with results
+ let res_headers_lua = header_map_to_table(lua, res_headers, res_decompressed)?;
TableBuilder::new(lua)?
.with_value("ok", (200..300).contains(&res_status))?
.with_value("statusCode", res_status)?
.with_value("statusMessage", res_status_text)?
- .with_value("headers", res_headers)?
+ .with_value("headers", res_headers_lua)?
.with_value("body", lua.create_string(&res_bytes)?)?
.build_readonly()
}
diff --git a/src/lune/builtins/net/util.rs b/src/lune/builtins/net/util.rs
new file mode 100644
index 0000000..fa1e2ad
--- /dev/null
+++ b/src/lune/builtins/net/util.rs
@@ -0,0 +1,81 @@
+use std::collections::HashMap;
+
+use hyper::{
+ header::{CONTENT_ENCODING, CONTENT_LENGTH},
+ HeaderMap,
+};
+
+use mlua::prelude::*;
+
+use crate::lune::util::TableBuilder;
+
+pub fn header_map_to_table(
+ lua: &Lua,
+ headers: HeaderMap,
+ remove_content_headers: bool,
+) -> LuaResult {
+ let mut res_headers: HashMap> = HashMap::new();
+ for (name, value) in headers.iter() {
+ let name = name.as_str();
+ let value = value.to_str().unwrap().to_owned();
+ if let Some(existing) = res_headers.get_mut(name) {
+ existing.push(value);
+ } else {
+ res_headers.insert(name.to_owned(), vec![value]);
+ }
+ }
+
+ if remove_content_headers {
+ let content_encoding_header_str = CONTENT_ENCODING.as_str();
+ let content_length_header_str = CONTENT_LENGTH.as_str();
+ res_headers.retain(|name, _| {
+ name != content_encoding_header_str && name != content_length_header_str
+ });
+ }
+
+ let mut builder = TableBuilder::new(lua)?;
+ for (name, mut values) in res_headers {
+ if values.len() == 1 {
+ let value = values.pop().unwrap().into_lua(lua)?;
+ builder = builder.with_value(name, value)?;
+ } else {
+ let values = TableBuilder::new(lua)?
+ .with_sequential_values(values)?
+ .build_readonly()?
+ .into_lua(lua)?;
+ builder = builder.with_value(name, values)?;
+ }
+ }
+
+ builder.build_readonly()
+}
+
+pub fn table_to_hash_map(
+ tab: LuaTable,
+ tab_origin_key: &'static str,
+) -> LuaResult>> {
+ let mut map = HashMap::new();
+
+ for pair in tab.pairs::() {
+ let (key, value) = pair?;
+ match value {
+ LuaValue::String(s) => {
+ map.insert(key, vec![s.to_str()?.to_owned()]);
+ }
+ LuaValue::Table(t) => {
+ let mut values = Vec::new();
+ for value in t.sequence_values::() {
+ values.push(value?.to_str()?.to_owned());
+ }
+ map.insert(key, values);
+ }
+ _ => {
+ return Err(LuaError::runtime(format!(
+ "Value for '{tab_origin_key}' must be a string or array of strings",
+ )))
+ }
+ }
+ }
+
+ Ok(map)
+}
diff --git a/types/net.luau b/types/net.luau
index 52eb28d..4f45a00 100644
--- a/types/net.luau
+++ b/types/net.luau
@@ -1,5 +1,9 @@
export type HttpMethod = "GET" | "POST" | "PUT" | "DELETE" | "HEAD" | "OPTIONS" | "PATCH"
+type HttpQueryOrHeaderMap = { [string]: string | { string } }
+export type HttpQueryMap = HttpQueryOrHeaderMap
+export type HttpHeaderMap = HttpQueryOrHeaderMap
+
--[=[
@interface FetchParamsOptions
@within Net
@@ -33,8 +37,8 @@ export type FetchParams = {
url: string,
method: HttpMethod?,
body: string?,
- query: { [string]: string }?,
- headers: { [string]: string }?,
+ query: HttpQueryMap?,
+ headers: HttpHeaderMap?,
options: FetchParamsOptions?,
}
@@ -56,7 +60,7 @@ export type FetchResponse = {
ok: boolean,
statusCode: number,
statusMessage: string,
- headers: { [string]: string },
+ headers: HttpHeaderMap,
body: string,
}