diff --git a/Cargo.lock b/Cargo.lock index 4ca0162..bf9fb6e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1869,6 +1869,7 @@ dependencies = [ "pin-project-lite", "rustls 0.23.26", "rustls-pki-types", + "url", "webpki", "webpki-roots 0.26.8", ] diff --git a/crates/lune-std-net/Cargo.toml b/crates/lune-std-net/Cargo.toml index 65d74ae..0a70395 100644 --- a/crates/lune-std-net/Cargo.toml +++ b/crates/lune-std-net/Cargo.toml @@ -29,6 +29,7 @@ hyper = { version = "1.6", features = ["http1", "client", "server"] } pin-project-lite = "0.2" rustls = "0.23" rustls-pki-types = "1.11" +url = "2.5" webpki = "0.22" webpki-roots = "0.26" diff --git a/crates/lune-std-net/src/client/config.rs b/crates/lune-std-net/src/client/config.rs new file mode 100644 index 0000000..de78feb --- /dev/null +++ b/crates/lune-std-net/src/client/config.rs @@ -0,0 +1,142 @@ +use std::collections::HashMap; + +use bstr::{BString, ByteSlice}; +use hyper::Method; +use mlua::prelude::*; + +use crate::shared::headers::table_to_hash_map; + +#[derive(Debug, Clone)] +pub struct RequestConfigOptions { + pub decompress: bool, +} + +impl Default for RequestConfigOptions { + fn default() -> Self { + Self { decompress: true } + } +} + +impl FromLua for RequestConfigOptions { + fn from_lua(value: LuaValue, _: &Lua) -> LuaResult { + if let LuaValue::Nil = value { + // Nil means default options + Ok(Self::default()) + } else if let LuaValue::Table(tab) = value { + // Table means custom options + let decompress = match tab.get::>("decompress") { + Ok(decomp) => Ok(decomp.unwrap_or(true)), + Err(_) => Err(LuaError::RuntimeError( + "Invalid option value for 'decompress' in request config options".to_string(), + )), + }?; + Ok(Self { decompress }) + } else { + // Anything else is invalid + Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "RequestConfigOptions".to_string(), + message: Some(format!( + "Invalid request config options - expected table or nil, got {}", + value.type_name() + )), + }) + } + } +} + +#[derive(Debug, Clone)] +pub struct RequestConfig { + pub url: String, + pub method: Method, + pub query: HashMap>, + pub headers: HashMap>, + pub body: Option>, + pub options: RequestConfigOptions, +} + +impl FromLua for RequestConfig { + fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult { + // If we just got a string we assume its a GET request to a given url + if let LuaValue::String(s) = value { + Ok(Self { + url: s.to_string_lossy().to_string(), + method: Method::GET, + query: HashMap::new(), + headers: HashMap::new(), + body: None, + options: RequestConfigOptions::default(), + }) + } else if let LuaValue::Table(tab) = value { + // If we got a table we are able to configure the entire request + + // Extract url + let url = match tab.get::("url") { + Ok(config_url) => Ok(config_url.to_string_lossy().to_string()), + Err(_) => Err(LuaError::runtime("Missing 'url' in request config")), + }?; + // Extract method + let method = match tab.get::("method") { + Ok(config_method) => config_method.to_string_lossy().trim().to_ascii_uppercase(), + Err(_) => "GET".to_string(), + }; + // Extract query + let query = match tab.get::("query") { + Ok(tab) => table_to_hash_map(tab, "query")?, + Err(_) => HashMap::new(), + }; + // Extract headers + let headers = match tab.get::("headers") { + Ok(tab) => table_to_hash_map(tab, "headers")?, + Err(_) => HashMap::new(), + }; + // Extract body + let body = match tab.get::("body") { + Ok(config_body) => Some(config_body.as_bytes().to_owned()), + Err(_) => None, + }; + + // Convert method string into proper enum + let method = method.trim().to_ascii_uppercase(); + let method = match method.as_ref() { + "GET" => Ok(Method::GET), + "POST" => Ok(Method::POST), + "PUT" => Ok(Method::PUT), + "DELETE" => Ok(Method::DELETE), + "HEAD" => Ok(Method::HEAD), + "OPTIONS" => Ok(Method::OPTIONS), + "PATCH" => Ok(Method::PATCH), + _ => Err(LuaError::RuntimeError(format!( + "Invalid request config method '{}'", + &method + ))), + }?; + + // Parse any extra options given + let options = match tab.get::("options") { + Ok(opts) => RequestConfigOptions::from_lua(opts, lua)?, + Err(_) => RequestConfigOptions::default(), + }; + + // All good, validated and we got what we need + Ok(Self { + url, + method, + query, + headers, + body, + options, + }) + } else { + // Anything else is invalid + Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "RequestConfig".to_string(), + message: Some(format!( + "Invalid request config - expected string or table, got {}", + value.type_name() + )), + }) + } + } +} diff --git a/crates/lune-std-net/src/client/mod.rs b/crates/lune-std-net/src/client/mod.rs index 5b79ed0..dec0006 100644 --- a/crates/lune-std-net/src/client/mod.rs +++ b/crates/lune-std-net/src/client/mod.rs @@ -1,4 +1,2 @@ -mod request; -mod stream; - -pub use self::request::{Request, Response}; +pub mod config; +pub mod stream; diff --git a/crates/lune-std-net/src/client/request.rs b/crates/lune-std-net/src/client/request.rs deleted file mode 100644 index 39597b8..0000000 --- a/crates/lune-std-net/src/client/request.rs +++ /dev/null @@ -1,115 +0,0 @@ -use bstr::BString; -use futures_lite::prelude::*; -use http_body_util::{BodyStream, Full}; -use hyper::{ - body::{Bytes, Incoming}, - client::conn::http1::handshake, - Method, Request as HyperRequest, Response as HyperResponse, -}; - -use mlua::prelude::*; - -use crate::{ - client::stream::HttpRequestStream, - shared::hyper::{HyperExecutor, HyperIo}, -}; - -#[derive(Debug, Clone)] -pub struct Request { - inner: HyperRequest>, -} - -impl Request { - pub async fn send(self, lua: Lua) -> LuaResult { - let stream = HttpRequestStream::connect(self.inner.uri()).await?; - - let (mut sender, conn) = handshake(HyperIo::from(stream)) - .await - .map_err(LuaError::external)?; - - HyperExecutor::execute(lua, conn); - - let incoming = sender - .send_request(self.inner) - .await - .map_err(LuaError::external)?; - - Response::from_incoming(incoming).await - } -} - -impl FromLua for Request { - fn from_lua(value: LuaValue, _lua: &Lua) -> LuaResult { - if let LuaValue::String(s) = value { - // We got a string, assume it's a URL + GET method - let uri = s.to_str()?; - Ok(Self { - inner: HyperRequest::builder() - .uri(uri.as_ref()) - .body(Full::new(Bytes::new())) - .into_lua_err()?, - }) - } else if let LuaValue::Table(t) = value { - // URL is always required with table options - let url = t.get::("url")?; - let builder = HyperRequest::builder().uri(url); - - // Add method, if provided - let builder = match t.get::>("method") { - Ok(Some(method)) => builder.method(method.as_str()), - Ok(None) => builder.method(Method::GET), - Err(e) => return Err(e), - }; - - // Add body, if provided - let builder = match t.get::>("body") { - Ok(Some(body)) => builder.body(Full::new(body.to_vec().into())), - Ok(None) => builder.body(Full::new(Bytes::new())), - Err(e) => return Err(e), - }; - - Ok(Self { - inner: builder.into_lua_err()?, - }) - } else { - Err(LuaError::FromLuaConversionError { - from: value.type_name(), - to: String::from("HttpRequest"), - message: Some(String::from("HttpRequest must be a string or table")), - }) - } - } -} - -#[derive(Debug, Clone)] -pub struct Response { - inner: HyperResponse>, -} - -impl Response { - pub async fn from_incoming(incoming: HyperResponse) -> LuaResult { - let (parts, body) = incoming.into_parts(); - - let body = BodyStream::new(body) - .try_fold(Vec::::new(), |mut body, chunk| { - if let Some(chunk) = chunk.data_ref() { - body.extend_from_slice(chunk); - } - Ok(body) - }) - .await - .into_lua_err()?; - - let bytes = Full::new(Bytes::from(body)); - let inner = HyperResponse::from_parts(parts, bytes); - - Ok(Self { inner }) - } -} - -impl LuaUserData for Response { - fn add_fields>(fields: &mut F) { - fields.add_field_method_get("ok", |_, this| Ok(this.inner.status().is_success())); - fields.add_field_method_get("status", |_, this| Ok(this.inner.status().as_u16())); - } -} diff --git a/crates/lune-std-net/src/lib.rs b/crates/lune-std-net/src/lib.rs index f39765f..dbdaefd 100644 --- a/crates/lune-std-net/src/lib.rs +++ b/crates/lune-std-net/src/lib.rs @@ -1,16 +1,15 @@ #![allow(clippy::cargo_common_metadata)] +use lune_utils::TableBuilder; use mlua::prelude::*; -use lune_utils::TableBuilder; - -mod client; -mod server; -mod url; - -use self::client::{Request, Response}; - +pub(crate) mod client; +pub(crate) mod server; pub(crate) mod shared; +pub(crate) mod url; + +use self::client::config::RequestConfig; +use self::shared::{request::Request, response::Response}; const TYPEDEFS: &str = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/types.d.luau")); @@ -39,6 +38,8 @@ pub fn module(lua: Lua) -> LuaResult { .build_readonly() } -async fn net_request(lua: Lua, req: Request) -> LuaResult { - req.send(lua).await +async fn net_request(lua: Lua, config: RequestConfig) -> LuaResult { + Request::from_config(config, lua.clone())? + .send(lua.clone()) + .await } diff --git a/crates/lune-std-net/src/shared/headers.rs b/crates/lune-std-net/src/shared/headers.rs new file mode 100644 index 0000000..d87a8e5 --- /dev/null +++ b/crates/lune-std-net/src/shared/headers.rs @@ -0,0 +1,95 @@ +use std::collections::HashMap; + +use hyper::{ + header::{CONTENT_ENCODING, CONTENT_LENGTH}, + HeaderMap, +}; + +use lune_utils::TableBuilder; +use mlua::prelude::*; + +pub fn create_user_agent_header(lua: &Lua) -> LuaResult { + let version_global = lua + .globals() + .get::("_VERSION") + .expect("Missing _VERSION global"); + + let version_global_str = version_global + .to_str() + .context("Invalid utf8 found in _VERSION global")?; + + let (package_name, full_version) = version_global_str.split_once(' ').unwrap(); + + Ok(format!("{}/{}", package_name.to_lowercase(), full_version)) +} + +pub fn header_map_to_table( + lua: &Lua, + headers: HeaderMap, + remove_content_headers: bool, +) -> LuaResult { + let mut res_headers = HashMap::>::new(); + for (name, value) in &headers { + let name = name.as_str(); + let value = value.to_str().unwrap().to_owned(); + if let Some(existing) = res_headers.get_mut(name) { + existing.push(value); + } else { + res_headers.insert(name.to_owned(), vec![value]); + } + } + + if remove_content_headers { + let content_encoding_header_str = CONTENT_ENCODING.as_str(); + let content_length_header_str = CONTENT_LENGTH.as_str(); + res_headers.retain(|name, _| { + name != content_encoding_header_str && name != content_length_header_str + }); + } + + let mut builder = TableBuilder::new(lua.clone())?; + for (name, mut values) in res_headers { + if values.len() == 1 { + let value = values.pop().unwrap().into_lua(lua)?; + builder = builder.with_value(name, value)?; + } else { + let values = TableBuilder::new(lua.clone())? + .with_sequential_values(values)? + .build_readonly()? + .into_lua(lua)?; + builder = builder.with_value(name, values)?; + } + } + + builder.build_readonly() +} + +pub fn table_to_hash_map( + tab: LuaTable, + tab_origin_key: &'static str, +) -> LuaResult>> { + let mut map = HashMap::new(); + + for pair in tab.pairs::() { + let (key, value) = pair?; + match value { + LuaValue::String(s) => { + map.insert(key, vec![s.to_str()?.to_owned()]); + } + LuaValue::Table(t) => { + let mut values = Vec::new(); + for value in t.sequence_values::() { + values.push(value?.to_str()?.to_owned()); + } + map.insert(key, values); + } + _ => { + return Err(LuaError::runtime(format!( + "Value for '{tab_origin_key}' must be a string or array of strings", + ))) + } + } + } + + Ok(map) +} diff --git a/crates/lune-std-net/src/shared/mod.rs b/crates/lune-std-net/src/shared/mod.rs index 4342b3c..5f8736e 100644 --- a/crates/lune-std-net/src/shared/mod.rs +++ b/crates/lune-std-net/src/shared/mod.rs @@ -1 +1,4 @@ +pub mod headers; pub mod hyper; +pub mod request; +pub mod response; diff --git a/crates/lune-std-net/src/shared/request.rs b/crates/lune-std-net/src/shared/request.rs new file mode 100644 index 0000000..b39123c --- /dev/null +++ b/crates/lune-std-net/src/shared/request.rs @@ -0,0 +1,100 @@ +use http_body_util::Full; + +use hyper::{ + body::Bytes, + client::conn::http1::handshake, + header::{HeaderName, HeaderValue, USER_AGENT}, + HeaderMap, Request as HyperRequest, +}; + +use mlua::prelude::*; +use url::Url; + +use crate::{ + client::{config::RequestConfig, stream::HttpRequestStream}, + shared::{ + headers::create_user_agent_header, + hyper::{HyperExecutor, HyperIo}, + response::Response, + }, +}; + +#[derive(Debug, Clone)] +pub struct Request { + inner: HyperRequest>, +} + +impl Request { + pub fn from_config(config: RequestConfig, lua: Lua) -> LuaResult { + // 1. Parse the URL and make sure it is valid + let mut url = Url::parse(&config.url).into_lua_err()?; + + // 2. Append any query pairs passed as a table + { + let mut query = url.query_pairs_mut(); + for (key, values) in config.query { + for value in values { + query.append_pair(&key, &value); + } + } + } + + // 3. Create the inner request builder + let mut builder = HyperRequest::builder() + .method(config.method) + .uri(url.as_str()); + + // 4. Append any headers passed as a table - builder + // headers may be None if builder is already invalid + if let Some(headers) = builder.headers_mut() { + for (key, values) in config.headers { + let key = HeaderName::from_bytes(key.as_bytes()).into_lua_err()?; + for value in values { + let value = HeaderValue::from_str(&value).into_lua_err()?; + headers.insert(key.clone(), value); + } + } + } + + // 5. Convert request body bytes to the proper Body + // type that Hyper expects, if we got any bytes + let body = config + .body + .map(Bytes::from) + .map(Full::new) + .unwrap_or_default(); + + // 6. Finally, attach the body, verifying that the request + // is valid, and attach a user agent if not already set + let mut inner = builder.body(body).into_lua_err()?; + add_default_headers(&lua, inner.headers_mut())?; + Ok(Self { inner }) + } + + pub async fn send(self, lua: Lua) -> LuaResult { + let stream = HttpRequestStream::connect(self.inner.uri()).await?; + + let (mut sender, conn) = handshake(HyperIo::from(stream)) + .await + .map_err(LuaError::external)?; + + HyperExecutor::execute(lua, conn); + + let incoming = sender + .send_request(self.inner) + .await + .map_err(LuaError::external)?; + + Response::from_incoming(incoming).await + } +} + +fn add_default_headers(lua: &Lua, headers: &mut HeaderMap) -> LuaResult<()> { + if !headers.contains_key(USER_AGENT) { + let ua = create_user_agent_header(lua)?; + let ua = HeaderValue::from_str(&ua).into_lua_err()?; + headers.insert(USER_AGENT, ua); + } + + Ok(()) +} diff --git a/crates/lune-std-net/src/shared/response.rs b/crates/lune-std-net/src/shared/response.rs new file mode 100644 index 0000000..e746c68 --- /dev/null +++ b/crates/lune-std-net/src/shared/response.rs @@ -0,0 +1,71 @@ +use futures_lite::prelude::*; +use http_body_util::BodyStream; + +use hyper::{ + body::{Bytes, Incoming}, + HeaderMap, Response as HyperResponse, +}; + +use mlua::prelude::*; + +use crate::shared::headers::header_map_to_table; + +#[derive(Debug, Clone)] +pub struct Response { + inner: HyperResponse, +} + +impl Response { + pub async fn from_incoming(incoming: HyperResponse) -> LuaResult { + let (parts, body) = incoming.into_parts(); + + let body = BodyStream::new(body) + .try_fold(Vec::::new(), |mut body, chunk| { + if let Some(chunk) = chunk.data_ref() { + body.extend_from_slice(chunk); + } + Ok(body) + }) + .await + .into_lua_err()?; + + let bytes = Bytes::from(body); + let inner = HyperResponse::from_parts(parts, bytes); + + Ok(Self { inner }) + } + + pub fn status_ok(&self) -> bool { + self.inner.status().is_success() + } + + pub fn status_code(&self) -> u16 { + self.inner.status().as_u16() + } + + pub fn status_message(&self) -> &str { + self.inner.status().canonical_reason().unwrap_or_default() + } + + pub fn headers(&self) -> &HeaderMap { + self.inner.headers() + } + + pub fn body(&self) -> &[u8] { + self.inner.body() + } +} + +impl LuaUserData for Response { + fn add_fields>(fields: &mut F) { + fields.add_field_method_get("ok", |_, this| Ok(this.status_ok())); + fields.add_field_method_get("statusCode", |_, this| Ok(this.status_code())); + fields.add_field_method_get("statusMessage", |lua, this| { + lua.create_string(this.status_message()) + }); + fields.add_field_method_get("headers", |lua, this| { + header_map_to_table(lua, this.headers().clone(), false) + }); + fields.add_field_method_get("body", |lua, this| lua.create_string(this.body())); + } +}