diff --git a/src/lune/builtins/net/client.rs b/src/lune/builtins/net/client.rs new file mode 100644 index 0000000..5eb2527 --- /dev/null +++ b/src/lune/builtins/net/client.rs @@ -0,0 +1,165 @@ +use std::str::FromStr; + +use mlua::prelude::*; + +use reqwest::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_ENCODING}; + +use crate::lune::{ + builtins::serde::compress_decompress::{decompress, CompressDecompressFormat}, + util::TableBuilder, +}; + +use super::{config::RequestConfig, util::header_map_to_table}; + +const REGISTRY_KEY: &str = "NetClient"; + +pub struct NetClientBuilder { + builder: reqwest::ClientBuilder, +} + +impl NetClientBuilder { + pub fn new() -> NetClientBuilder { + Self { + builder: reqwest::ClientBuilder::new(), + } + } + + pub fn headers(mut self, headers: &[(K, V)]) -> LuaResult + where + K: AsRef, + V: AsRef<[u8]>, + { + let mut map = HeaderMap::new(); + for (key, val) in headers { + let hkey = HeaderName::from_str(key.as_ref()).into_lua_err()?; + let hval = HeaderValue::from_bytes(val.as_ref()).into_lua_err()?; + map.insert(hkey, hval); + } + self.builder = self.builder.default_headers(map); + Ok(self) + } + + pub fn build(self) -> LuaResult { + let client = self.builder.build().into_lua_err()?; + Ok(NetClient { inner: client }) + } +} + +#[derive(Debug, Clone)] +pub struct NetClient { + inner: reqwest::Client, +} + +impl NetClient { + pub fn from_registry(lua: &Lua) -> Self { + lua.named_registry_value(REGISTRY_KEY) + .expect("Failed to get NetClient from lua registry") + } + + pub fn into_registry(self, lua: &Lua) { + lua.set_named_registry_value(REGISTRY_KEY, self) + .expect("Failed to store NetClient in lua registry"); + } + + pub async fn request(&self, config: RequestConfig) -> LuaResult { + // Create and send the request + let mut request = self.inner.request(config.method, config.url); + for (query, values) in config.query { + request = request.query( + &values + .iter() + .map(|v| (query.as_str(), v)) + .collect::>(), + ); + } + for (header, values) in config.headers { + for value in values { + request = request.header(header.as_str(), value); + } + } + let res = request + .body(config.body.unwrap_or_default()) + .send() + .await + .into_lua_err()?; + + // Extract status, headers + let res_status = res.status().as_u16(); + let res_status_text = res.status().canonical_reason(); + let res_headers = res.headers().clone(); + + // Read response bytes + let mut res_bytes = res.bytes().await.into_lua_err()?.to_vec(); + let mut res_decompressed = false; + + // Check for extra options, decompression + if config.options.decompress { + let decompress_format = res_headers + .iter() + .find(|(name, _)| { + name.as_str() + .eq_ignore_ascii_case(CONTENT_ENCODING.as_str()) + }) + .and_then(|(_, value)| value.to_str().ok()) + .and_then(CompressDecompressFormat::detect_from_header_str); + if let Some(format) = decompress_format { + res_bytes = decompress(format, res_bytes).await?; + res_decompressed = true; + } + } + + Ok(NetClientResponse { + ok: (200..300).contains(&res_status), + status_code: res_status, + status_message: res_status_text.unwrap_or_default().to_string(), + headers: res_headers, + body: res_bytes, + body_decompressed: res_decompressed, + }) + } +} + +impl LuaUserData for NetClient {} + +impl FromLua<'_> for NetClient { + fn from_lua(value: LuaValue, _: &Lua) -> LuaResult { + if let LuaValue::UserData(ud) = value { + if let Ok(ctx) = ud.borrow::() { + return Ok(ctx.clone()); + } + } + unreachable!("NetClient should only be used from registry") + } +} + +impl From<&Lua> for NetClient { + fn from(value: &Lua) -> Self { + value + .named_registry_value(REGISTRY_KEY) + .expect("Missing require context in lua registry") + } +} + +pub struct NetClientResponse { + ok: bool, + status_code: u16, + status_message: String, + headers: HeaderMap, + body: Vec, + body_decompressed: bool, +} + +impl NetClientResponse { + pub fn into_lua_table(self, lua: &Lua) -> LuaResult { + TableBuilder::new(lua)? + .with_value("ok", self.ok)? + .with_value("statusCode", self.status_code)? + .with_value("statusMessage", self.status_message)? + .with_value( + "headers", + header_map_to_table(lua, self.headers, self.body_decompressed)?, + )? + .with_value("body", lua.create_string(&self.body)?)? + .build_readonly() + } +} diff --git a/src/lune/builtins/net/mod.rs b/src/lune/builtins/net/mod.rs index 2567a38..8b21472 100644 --- a/src/lune/builtins/net/mod.rs +++ b/src/lune/builtins/net/mod.rs @@ -1,7 +1,9 @@ #![allow(unused_variables)] use mlua::prelude::*; +use mlua_luau_scheduler::LuaSpawnExt; +mod client; mod config; mod util; mod websocket; @@ -9,13 +11,19 @@ mod websocket; use crate::lune::util::TableBuilder; use self::{ + client::{NetClient, NetClientBuilder}, config::{RequestConfig, ServeConfig}, + util::create_user_agent_header, websocket::NetWebSocket, }; use super::serde::encode_decode::{EncodeDecodeConfig, EncodeDecodeFormat}; pub fn create(lua: &Lua) -> LuaResult { + NetClientBuilder::new() + .headers(&[("User-Agent", create_user_agent_header())])? + .build()? + .into_registry(lua); TableBuilder::new(lua)? .with_function("jsonEncode", net_json_encode)? .with_function("jsonDecode", net_json_decode)? @@ -27,14 +35,6 @@ pub fn create(lua: &Lua) -> LuaResult { .build_readonly() } -fn _create_user_agent_header() -> String { - let (github_owner, github_repo) = env!("CARGO_PKG_REPOSITORY") - .trim_start_matches("https://github.com/") - .split_once('/') - .unwrap(); - format!("{github_owner}-{github_repo}-cli") -} - fn net_json_encode<'lua>( lua: &'lua Lua, (val, pretty): (LuaValue<'lua>, Option), @@ -48,7 +48,10 @@ fn net_json_decode<'lua>(lua: &'lua Lua, json: LuaString<'lua>) -> LuaResult LuaResult { - unimplemented!() + let client = NetClient::from_registry(lua); + // NOTE: We spawn the request as a background task to free up resources in lua + let res = lua.spawn(async move { client.request(config).await }); + res.await?.into_lua_table(lua) } async fn net_socket(lua: &Lua, url: String) -> LuaResult { diff --git a/src/lune/builtins/net/util.rs b/src/lune/builtins/net/util.rs index fa1e2ad..4603547 100644 --- a/src/lune/builtins/net/util.rs +++ b/src/lune/builtins/net/util.rs @@ -1,14 +1,20 @@ use std::collections::HashMap; -use hyper::{ - header::{CONTENT_ENCODING, CONTENT_LENGTH}, - HeaderMap, -}; +use hyper::header::{CONTENT_ENCODING, CONTENT_LENGTH}; +use reqwest::header::HeaderMap; use mlua::prelude::*; use crate::lune::util::TableBuilder; +pub fn create_user_agent_header() -> String { + let (github_owner, github_repo) = env!("CARGO_PKG_REPOSITORY") + .trim_start_matches("https://github.com/") + .split_once('/') + .unwrap(); + format!("{github_owner}-{github_repo}-cli") +} + pub fn header_map_to_table( lua: &Lua, headers: HeaderMap,