diff --git a/CHANGELOG.md b/CHANGELOG.md index 42eecd2..39fe2b7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 assert(decompressed == INPUT) ``` +- Added automatic decompression for compressed responses when using `net.request`. + This behavior can be disabled by passing `options = { decompress = false }` in request params. + - Added several new instance methods in the `roblox` builtin library: - [`Instance:AddTag`](https://create.roblox.com/docs/reference/engine/classes/Instance#AddTag) - [`Instance:GetTags`](https://create.roblox.com/docs/reference/engine/classes/Instance#GetTags) diff --git a/docs/typedefs/Net.luau b/docs/typedefs/Net.luau index 2869a92..139282b 100644 --- a/docs/typedefs/Net.luau +++ b/docs/typedefs/Net.luau @@ -1,5 +1,19 @@ export type HttpMethod = "GET" | "POST" | "PUT" | "DELETE" | "HEAD" | "OPTIONS" | "PATCH" +--[=[ + @type FetchParamsOptions + @within Net + + Extra options for `FetchParams`. + + This is a dictionary that may contain one or more of the following values: + + * `decompress` - If the request body should be automatically decompressed when possible. Defaults to `true` +]=] +export type FetchParamsOptions = { + decompress: boolean?, +} + --[=[ @type FetchParams @within Net @@ -10,16 +24,18 @@ export type HttpMethod = "GET" | "POST" | "PUT" | "DELETE" | "HEAD" | "OPTIONS" * `url` - The URL to send a request to. This is always required * `method` - The HTTP method verb, such as `"GET"`, `"POST"`, `"PATCH"`, `"PUT"`, or `"DELETE"`. Defaults to `"GET"` + * `body` - The request body * `query` - A table of key-value pairs representing query parameters in the request path * `headers` - A table of key-value pairs representing headers - * `body` - The request body + * `options` - Extra options for things such as automatic decompression of response bodies ]=] export type FetchParams = { url: string, method: HttpMethod?, + body: string?, query: { [string]: string }?, headers: { [string]: string }?, - body: string?, + options: FetchParamsOptions?, } --[=[ diff --git a/packages/lib/src/builtins/net.rs b/packages/lib/src/builtins/net.rs index 413f966..31b58c3 100644 --- a/packages/lib/src/builtins/net.rs +++ b/packages/lib/src/builtins/net.rs @@ -3,7 +3,10 @@ use std::collections::HashMap; use mlua::prelude::*; use console::style; -use hyper::Server; +use hyper::{ + header::{CONTENT_ENCODING, CONTENT_LENGTH}, + Server, +}; use tokio::{sync::mpsc, task}; use crate::lua::{ @@ -11,7 +14,7 @@ use crate::lua::{ NetClient, NetClientBuilder, NetLocalExec, NetService, NetWebSocket, RequestConfig, ServeConfig, }, - serde::{EncodeDecodeConfig, EncodeDecodeFormat}, + serde::{decompress, CompressDecompressFormat, EncodeDecodeConfig, EncodeDecodeFormat}, table::TableBuilder, task::{TaskScheduler, TaskSchedulerAsyncExt}, }; @@ -74,13 +77,38 @@ async fn net_request<'a>(lua: &'static Lua, config: RequestConfig<'a>) -> LuaRes // Extract status, headers let res_status = res.status().as_u16(); let res_status_text = res.status().canonical_reason(); - let res_headers = res + let mut res_headers = res .headers() .iter() - .map(|(name, value)| (name.to_string(), value.to_str().unwrap().to_owned())) + .map(|(name, value)| { + ( + name.as_str().to_string(), + value.to_str().unwrap().to_owned(), + ) + }) .collect::>(); // Read response bytes - let res_bytes = res.bytes().await.map_err(LuaError::external)?; + let mut res_bytes = res.bytes().await.map_err(LuaError::external)?.to_vec(); + // Check for extra options, decompression + if config.options.decompress { + // NOTE: Header names are guaranteed to be lowercase because of the above + // transformations of them into the hashmap, so we can compare directly + let format = res_headers.iter().find_map(|(name, val)| { + if name == CONTENT_ENCODING.as_str() { + CompressDecompressFormat::detect_from_header_str(val) + } else { + None + } + }); + if let Some(format) = format { + res_bytes = decompress(format, res_bytes).await?; + let content_encoding_header_str = CONTENT_ENCODING.as_str(); + let content_length_header_str = CONTENT_LENGTH.as_str(); + res_headers.retain(|name, _| { + name != content_encoding_header_str && name != content_length_header_str + }); + } + } // Construct and return a readonly lua table with results TableBuilder::new(lua)? .with_value("ok", (200..300).contains(&res_status))? diff --git a/packages/lib/src/lua/net/config.rs b/packages/lib/src/lua/net/config.rs index a240b36..59c1e3e 100644 --- a/packages/lib/src/lua/net/config.rs +++ b/packages/lib/src/lua/net/config.rs @@ -6,16 +6,56 @@ use reqwest::Method; // Net request config +#[derive(Debug, Clone)] +pub struct RequestConfigOptions { + pub decompress: bool, +} + +impl Default for RequestConfigOptions { + fn default() -> Self { + Self { decompress: true } + } +} + +impl<'lua> FromLua<'lua> for RequestConfigOptions { + fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult { + // Nil means default options, table means custom options + if let LuaValue::Nil = value { + return Ok(Self::default()); + } else if let LuaValue::Table(tab) = value { + // Extract flags + let decompress = match tab.raw_get::<_, Option>("decompress") { + Ok(decomp) => Ok(decomp.unwrap_or(true)), + Err(_) => Err(LuaError::RuntimeError( + "Invalid option value for 'decompress' in request config options".to_string(), + )), + }?; + return Ok(Self { decompress }); + } + // Anything else is invalid + Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "RequestConfigOptions", + message: Some(format!( + "Invalid request config options - expected table or nil, got {}", + value.type_name() + )), + }) + } +} + +#[derive(Debug, Clone)] pub struct RequestConfig<'a> { pub url: String, pub method: Method, pub query: HashMap, LuaString<'a>>, pub headers: HashMap, LuaString<'a>>, pub body: Option>, + pub options: RequestConfigOptions, } impl<'lua> FromLua<'lua> for RequestConfig<'lua> { - fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult { + fn from_lua(value: LuaValue<'lua>, lua: &'lua Lua) -> LuaResult { // If we just got a string we assume its a GET request to a given url if let LuaValue::String(s) = value { return Ok(Self { @@ -24,6 +64,7 @@ impl<'lua> FromLua<'lua> for RequestConfig<'lua> { query: HashMap::new(), headers: HashMap::new(), body: None, + options: Default::default(), }); } // If we got a table we are able to configure the entire request @@ -84,6 +125,11 @@ impl<'lua> FromLua<'lua> for RequestConfig<'lua> { &method ))), }?; + // Parse any extra options given + let options = match tab.raw_get::<_, LuaValue>("options") { + Ok(opts) => RequestConfigOptions::from_lua(opts, lua)?, + Err(_) => RequestConfigOptions::default(), + }; // All good, validated and we got what we need return Ok(Self { url, @@ -91,6 +137,7 @@ impl<'lua> FromLua<'lua> for RequestConfig<'lua> { query, headers, body, + options, }); }; // Anything else is invalid diff --git a/packages/lib/src/lua/serde/compress_decompress.rs b/packages/lib/src/lua/serde/compress_decompress.rs index f749abd..ca5d604 100644 --- a/packages/lib/src/lua/serde/compress_decompress.rs +++ b/packages/lib/src/lua/serde/compress_decompress.rs @@ -54,9 +54,9 @@ impl CompressDecompressFormat { pub fn detect_from_header_str(header: impl AsRef) -> Option { // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Encoding#directives match header.as_ref().to_ascii_lowercase().trim() { - "br" => Some(Self::Brotli), + "br" | "brotli" => Some(Self::Brotli), "deflate" => Some(Self::ZLib), - "gzip" => Some(Self::GZip), + "gz" | "gzip" => Some(Self::GZip), _ => None, } }