diff --git a/crates/lune-std-net/src/client/config.rs b/crates/lune-std-net/src/client/config.rs index de78feb..d0668b4 100644 --- a/crates/lune-std-net/src/client/config.rs +++ b/crates/lune-std-net/src/client/config.rs @@ -1,10 +1,10 @@ use std::collections::HashMap; use bstr::{BString, ByteSlice}; -use hyper::Method; +use hyper::{header::USER_AGENT, Method}; use mlua::prelude::*; -use crate::shared::headers::table_to_hash_map; +use crate::shared::headers::{create_user_agent_header, table_to_hash_map}; #[derive(Debug, Clone)] pub struct RequestConfigOptions { @@ -86,7 +86,7 @@ impl FromLua for RequestConfig { Err(_) => HashMap::new(), }; // Extract headers - let headers = match tab.get::("headers") { + let mut headers = match tab.get::("headers") { Ok(tab) => table_to_hash_map(tab, "headers")?, Err(_) => HashMap::new(), }; @@ -118,6 +118,9 @@ impl FromLua for RequestConfig { Err(_) => RequestConfigOptions::default(), }; + // Finally, add any default headers, if applicable + add_default_headers(lua, &mut headers)?; + // All good, validated and we got what we need Ok(Self { url, @@ -140,3 +143,12 @@ impl FromLua for RequestConfig { } } } + +fn add_default_headers(lua: &Lua, headers: &mut HashMap>) -> LuaResult<()> { + if !headers.contains_key(USER_AGENT.as_str()) { + let ua = create_user_agent_header(lua)?; + headers.insert(USER_AGENT.to_string(), vec![ua]); + } + + Ok(()) +} diff --git a/crates/lune-std-net/src/lib.rs b/crates/lune-std-net/src/lib.rs index 354e08e..eaeafbc 100644 --- a/crates/lune-std-net/src/lib.rs +++ b/crates/lune-std-net/src/lib.rs @@ -43,6 +43,5 @@ pub fn module(lua: Lua) -> LuaResult { } async fn net_request(lua: Lua, config: RequestConfig) -> LuaResult { - let request = Request::from_config(config, lua.clone())?; - self::client::send_request(request, lua.clone()).await + self::client::send_request(Request::try_from(config)?, lua).await } diff --git a/crates/lune-std-net/src/shared/request.rs b/crates/lune-std-net/src/shared/request.rs index 807a30d..cc09556 100644 --- a/crates/lune-std-net/src/shared/request.rs +++ b/crates/lune-std-net/src/shared/request.rs @@ -6,7 +6,7 @@ use url::Url; use hyper::{ body::{Body as _, Bytes, Incoming}, - header::{HeaderName, HeaderValue, USER_AGENT}, + header::{HeaderName, HeaderValue}, HeaderMap, Method, Request as HyperRequest, }; @@ -14,7 +14,7 @@ use mlua::prelude::*; use crate::{ client::config::RequestConfig, - shared::headers::{create_user_agent_header, hash_map_to_table, header_map_to_table}, + shared::headers::{hash_map_to_table, header_map_to_table}, }; #[derive(Debug, Clone)] @@ -27,57 +27,6 @@ pub struct Request { } impl Request { - /** - Creates a new request that is ready to be sent from a request configuration. - */ - pub fn from_config(config: RequestConfig, lua: Lua) -> LuaResult { - // 1. Parse the URL and make sure it is valid - let mut url = Url::parse(&config.url).into_lua_err()?; - - // 2. Append any query pairs passed as a table - { - let mut query = url.query_pairs_mut(); - for (key, values) in config.query { - for value in values { - query.append_pair(&key, &value); - } - } - } - - // 3. Create the inner request builder - let mut builder = HyperRequest::builder() - .method(config.method) - .uri(url.as_str()); - - // 4. Append any headers passed as a table - builder - // headers may be None if builder is already invalid - if let Some(headers) = builder.headers_mut() { - for (key, values) in config.headers { - let key = HeaderName::from_bytes(key.as_bytes()).into_lua_err()?; - for value in values { - let value = HeaderValue::from_str(&value).into_lua_err()?; - headers.insert(key.clone(), value); - } - } - } - - // 5. Convert request body bytes to the proper Body - // type that Hyper expects, if we got any bytes - let body = config.body.map(Bytes::from).unwrap_or_default(); - - // 6. Finally, attach the body, verifying that the request - // is valid, and attach a user agent if not already set - let mut inner = builder.body(body).into_lua_err()?; - - add_default_headers(&lua, inner.headers_mut())?; - - Ok(Self { - inner, - redirects: 0, - decompress: config.options.decompress, - }) - } - /** Creates a new request from a raw incoming request. */ @@ -173,14 +122,52 @@ impl Request { } } -fn add_default_headers(lua: &Lua, headers: &mut HeaderMap) -> LuaResult<()> { - if !headers.contains_key(USER_AGENT) { - let ua = create_user_agent_header(lua)?; - let ua = HeaderValue::from_str(&ua).into_lua_err()?; - headers.insert(USER_AGENT, ua); - } +impl TryFrom for Request { + type Error = LuaError; + fn try_from(config: RequestConfig) -> Result { + // 1. Parse the URL and make sure it is valid + let mut url = Url::parse(&config.url).into_lua_err()?; - Ok(()) + // 2. Append any query pairs passed as a table + { + let mut query = url.query_pairs_mut(); + for (key, values) in config.query { + for value in values { + query.append_pair(&key, &value); + } + } + } + + // 3. Create the inner request builder + let mut builder = HyperRequest::builder() + .method(config.method) + .uri(url.as_str()); + + // 4. Append any headers passed as a table - builder + // headers may be None if builder is already invalid + if let Some(headers) = builder.headers_mut() { + for (key, values) in config.headers { + let key = HeaderName::from_bytes(key.as_bytes()).into_lua_err()?; + for value in values { + let value = HeaderValue::from_str(&value).into_lua_err()?; + headers.insert(key.clone(), value); + } + } + } + + // 5. Convert request body bytes to the proper Body + // type that Hyper expects, if we got any bytes + let body = config.body.map(Bytes::from).unwrap_or_default(); + + // 6. Finally, attach the body, verifying that the request is valid + let inner = builder.body(body).into_lua_err()?; + + Ok(Self { + inner, + redirects: 0, + decompress: config.options.decompress, + }) + } } impl LuaUserData for Request { diff --git a/crates/lune-std-net/src/shared/response.rs b/crates/lune-std-net/src/shared/response.rs index c2cb61d..d3311b2 100644 --- a/crates/lune-std-net/src/shared/response.rs +++ b/crates/lune-std-net/src/shared/response.rs @@ -20,36 +20,6 @@ pub struct Response { } impl Response { - /** - Creates a new response that is ready to be sent from a response configuration. - */ - pub fn from_config(config: ResponseConfig, _lua: Lua) -> LuaResult { - // 1. Create the inner response builder - let mut builder = HyperResponse::builder().status(config.status); - - // 2. Append any headers passed as a table - builder - // headers may be None if builder is already invalid - if let Some(headers) = builder.headers_mut() { - for (key, values) in config.headers { - let key = HeaderName::from_bytes(key.as_bytes()).into_lua_err()?; - for value in values { - let value = HeaderValue::from_str(&value).into_lua_err()?; - headers.insert(key.clone(), value); - } - } - } - - // 3. Convert response body bytes to the proper Body - // type that Hyper expects, if we got any bytes - let body = config.body.map(Bytes::from).unwrap_or_default(); - - // 4. Finally, attach the body, verifying that the response is valid - Ok(Self { - inner: builder.body(body).into_lua_err()?, - decompressed: false, - }) - } - /** Creates a new response from a raw incoming response. */ @@ -133,6 +103,36 @@ impl Response { } } +impl TryFrom for Response { + type Error = LuaError; + fn try_from(config: ResponseConfig) -> Result { + // 1. Create the inner response builder + let mut builder = HyperResponse::builder().status(config.status); + + // 2. Append any headers passed as a table - builder + // headers may be None if builder is already invalid + if let Some(headers) = builder.headers_mut() { + for (key, values) in config.headers { + let key = HeaderName::from_bytes(key.as_bytes()).into_lua_err()?; + for value in values { + let value = HeaderValue::from_str(&value).into_lua_err()?; + headers.insert(key.clone(), value); + } + } + } + + // 3. Convert response body bytes to the proper Body + // type that Hyper expects, if we got any bytes + let body = config.body.map(Bytes::from).unwrap_or_default(); + + // 4. Finally, attach the body, verifying that the response is valid + Ok(Self { + inner: builder.body(body).into_lua_err()?, + decompressed: false, + }) + } +} + impl LuaUserData for Response { fn add_fields>(fields: &mut F) { fields.add_field_method_get("ok", |_, this| Ok(this.status_ok()));