Add back some old structs and utils for net requests + most tests now pass

This commit is contained in:
Filip Tibell 2025-04-26 21:00:35 +02:00
parent 07744d0079
commit 14197d9398
10 changed files with 426 additions and 129 deletions

1
Cargo.lock generated
View file

@ -1869,6 +1869,7 @@ dependencies = [
"pin-project-lite",
"rustls 0.23.26",
"rustls-pki-types",
"url",
"webpki",
"webpki-roots 0.26.8",
]

View file

@ -29,6 +29,7 @@ hyper = { version = "1.6", features = ["http1", "client", "server"] }
pin-project-lite = "0.2"
rustls = "0.23"
rustls-pki-types = "1.11"
url = "2.5"
webpki = "0.22"
webpki-roots = "0.26"

View file

@ -0,0 +1,142 @@
use std::collections::HashMap;
use bstr::{BString, ByteSlice};
use hyper::Method;
use mlua::prelude::*;
use crate::shared::headers::table_to_hash_map;
#[derive(Debug, Clone)]
pub struct RequestConfigOptions {
pub decompress: bool,
}
impl Default for RequestConfigOptions {
fn default() -> Self {
Self { decompress: true }
}
}
impl FromLua for RequestConfigOptions {
fn from_lua(value: LuaValue, _: &Lua) -> LuaResult<Self> {
if let LuaValue::Nil = value {
// Nil means default options
Ok(Self::default())
} else if let LuaValue::Table(tab) = value {
// Table means custom options
let decompress = match tab.get::<Option<bool>>("decompress") {
Ok(decomp) => Ok(decomp.unwrap_or(true)),
Err(_) => Err(LuaError::RuntimeError(
"Invalid option value for 'decompress' in request config options".to_string(),
)),
}?;
Ok(Self { decompress })
} else {
// Anything else is invalid
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "RequestConfigOptions".to_string(),
message: Some(format!(
"Invalid request config options - expected table or nil, got {}",
value.type_name()
)),
})
}
}
}
#[derive(Debug, Clone)]
pub struct RequestConfig {
pub url: String,
pub method: Method,
pub query: HashMap<String, Vec<String>>,
pub headers: HashMap<String, Vec<String>>,
pub body: Option<Vec<u8>>,
pub options: RequestConfigOptions,
}
impl FromLua for RequestConfig {
fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult<Self> {
// If we just got a string we assume its a GET request to a given url
if let LuaValue::String(s) = value {
Ok(Self {
url: s.to_string_lossy().to_string(),
method: Method::GET,
query: HashMap::new(),
headers: HashMap::new(),
body: None,
options: RequestConfigOptions::default(),
})
} else if let LuaValue::Table(tab) = value {
// If we got a table we are able to configure the entire request
// Extract url
let url = match tab.get::<LuaString>("url") {
Ok(config_url) => Ok(config_url.to_string_lossy().to_string()),
Err(_) => Err(LuaError::runtime("Missing 'url' in request config")),
}?;
// Extract method
let method = match tab.get::<LuaString>("method") {
Ok(config_method) => config_method.to_string_lossy().trim().to_ascii_uppercase(),
Err(_) => "GET".to_string(),
};
// Extract query
let query = match tab.get::<LuaTable>("query") {
Ok(tab) => table_to_hash_map(tab, "query")?,
Err(_) => HashMap::new(),
};
// Extract headers
let headers = match tab.get::<LuaTable>("headers") {
Ok(tab) => table_to_hash_map(tab, "headers")?,
Err(_) => HashMap::new(),
};
// Extract body
let body = match tab.get::<BString>("body") {
Ok(config_body) => Some(config_body.as_bytes().to_owned()),
Err(_) => None,
};
// Convert method string into proper enum
let method = method.trim().to_ascii_uppercase();
let method = match method.as_ref() {
"GET" => Ok(Method::GET),
"POST" => Ok(Method::POST),
"PUT" => Ok(Method::PUT),
"DELETE" => Ok(Method::DELETE),
"HEAD" => Ok(Method::HEAD),
"OPTIONS" => Ok(Method::OPTIONS),
"PATCH" => Ok(Method::PATCH),
_ => Err(LuaError::RuntimeError(format!(
"Invalid request config method '{}'",
&method
))),
}?;
// Parse any extra options given
let options = match tab.get::<LuaValue>("options") {
Ok(opts) => RequestConfigOptions::from_lua(opts, lua)?,
Err(_) => RequestConfigOptions::default(),
};
// All good, validated and we got what we need
Ok(Self {
url,
method,
query,
headers,
body,
options,
})
} else {
// Anything else is invalid
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "RequestConfig".to_string(),
message: Some(format!(
"Invalid request config - expected string or table, got {}",
value.type_name()
)),
})
}
}
}

View file

@ -1,4 +1,2 @@
mod request;
mod stream;
pub use self::request::{Request, Response};
pub mod config;
pub mod stream;

View file

@ -1,115 +0,0 @@
use bstr::BString;
use futures_lite::prelude::*;
use http_body_util::{BodyStream, Full};
use hyper::{
body::{Bytes, Incoming},
client::conn::http1::handshake,
Method, Request as HyperRequest, Response as HyperResponse,
};
use mlua::prelude::*;
use crate::{
client::stream::HttpRequestStream,
shared::hyper::{HyperExecutor, HyperIo},
};
#[derive(Debug, Clone)]
pub struct Request {
inner: HyperRequest<Full<Bytes>>,
}
impl Request {
pub async fn send(self, lua: Lua) -> LuaResult<Response> {
let stream = HttpRequestStream::connect(self.inner.uri()).await?;
let (mut sender, conn) = handshake(HyperIo::from(stream))
.await
.map_err(LuaError::external)?;
HyperExecutor::execute(lua, conn);
let incoming = sender
.send_request(self.inner)
.await
.map_err(LuaError::external)?;
Response::from_incoming(incoming).await
}
}
impl FromLua for Request {
fn from_lua(value: LuaValue, _lua: &Lua) -> LuaResult<Self> {
if let LuaValue::String(s) = value {
// We got a string, assume it's a URL + GET method
let uri = s.to_str()?;
Ok(Self {
inner: HyperRequest::builder()
.uri(uri.as_ref())
.body(Full::new(Bytes::new()))
.into_lua_err()?,
})
} else if let LuaValue::Table(t) = value {
// URL is always required with table options
let url = t.get::<String>("url")?;
let builder = HyperRequest::builder().uri(url);
// Add method, if provided
let builder = match t.get::<Option<String>>("method") {
Ok(Some(method)) => builder.method(method.as_str()),
Ok(None) => builder.method(Method::GET),
Err(e) => return Err(e),
};
// Add body, if provided
let builder = match t.get::<Option<BString>>("body") {
Ok(Some(body)) => builder.body(Full::new(body.to_vec().into())),
Ok(None) => builder.body(Full::new(Bytes::new())),
Err(e) => return Err(e),
};
Ok(Self {
inner: builder.into_lua_err()?,
})
} else {
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: String::from("HttpRequest"),
message: Some(String::from("HttpRequest must be a string or table")),
})
}
}
}
#[derive(Debug, Clone)]
pub struct Response {
inner: HyperResponse<Full<Bytes>>,
}
impl Response {
pub async fn from_incoming(incoming: HyperResponse<Incoming>) -> LuaResult<Self> {
let (parts, body) = incoming.into_parts();
let body = BodyStream::new(body)
.try_fold(Vec::<u8>::new(), |mut body, chunk| {
if let Some(chunk) = chunk.data_ref() {
body.extend_from_slice(chunk);
}
Ok(body)
})
.await
.into_lua_err()?;
let bytes = Full::new(Bytes::from(body));
let inner = HyperResponse::from_parts(parts, bytes);
Ok(Self { inner })
}
}
impl LuaUserData for Response {
fn add_fields<F: LuaUserDataFields<Self>>(fields: &mut F) {
fields.add_field_method_get("ok", |_, this| Ok(this.inner.status().is_success()));
fields.add_field_method_get("status", |_, this| Ok(this.inner.status().as_u16()));
}
}

View file

@ -1,16 +1,15 @@
#![allow(clippy::cargo_common_metadata)]
use lune_utils::TableBuilder;
use mlua::prelude::*;
use lune_utils::TableBuilder;
mod client;
mod server;
mod url;
use self::client::{Request, Response};
pub(crate) mod client;
pub(crate) mod server;
pub(crate) mod shared;
pub(crate) mod url;
use self::client::config::RequestConfig;
use self::shared::{request::Request, response::Response};
const TYPEDEFS: &str = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/types.d.luau"));
@ -39,6 +38,8 @@ pub fn module(lua: Lua) -> LuaResult<LuaTable> {
.build_readonly()
}
async fn net_request(lua: Lua, req: Request) -> LuaResult<Response> {
req.send(lua).await
async fn net_request(lua: Lua, config: RequestConfig) -> LuaResult<Response> {
Request::from_config(config, lua.clone())?
.send(lua.clone())
.await
}

View file

@ -0,0 +1,95 @@
use std::collections::HashMap;
use hyper::{
header::{CONTENT_ENCODING, CONTENT_LENGTH},
HeaderMap,
};
use lune_utils::TableBuilder;
use mlua::prelude::*;
pub fn create_user_agent_header(lua: &Lua) -> LuaResult<String> {
let version_global = lua
.globals()
.get::<LuaString>("_VERSION")
.expect("Missing _VERSION global");
let version_global_str = version_global
.to_str()
.context("Invalid utf8 found in _VERSION global")?;
let (package_name, full_version) = version_global_str.split_once(' ').unwrap();
Ok(format!("{}/{}", package_name.to_lowercase(), full_version))
}
pub fn header_map_to_table(
lua: &Lua,
headers: HeaderMap,
remove_content_headers: bool,
) -> LuaResult<LuaTable> {
let mut res_headers = HashMap::<String, Vec<String>>::new();
for (name, value) in &headers {
let name = name.as_str();
let value = value.to_str().unwrap().to_owned();
if let Some(existing) = res_headers.get_mut(name) {
existing.push(value);
} else {
res_headers.insert(name.to_owned(), vec![value]);
}
}
if remove_content_headers {
let content_encoding_header_str = CONTENT_ENCODING.as_str();
let content_length_header_str = CONTENT_LENGTH.as_str();
res_headers.retain(|name, _| {
name != content_encoding_header_str && name != content_length_header_str
});
}
let mut builder = TableBuilder::new(lua.clone())?;
for (name, mut values) in res_headers {
if values.len() == 1 {
let value = values.pop().unwrap().into_lua(lua)?;
builder = builder.with_value(name, value)?;
} else {
let values = TableBuilder::new(lua.clone())?
.with_sequential_values(values)?
.build_readonly()?
.into_lua(lua)?;
builder = builder.with_value(name, values)?;
}
}
builder.build_readonly()
}
pub fn table_to_hash_map(
tab: LuaTable,
tab_origin_key: &'static str,
) -> LuaResult<HashMap<String, Vec<String>>> {
let mut map = HashMap::new();
for pair in tab.pairs::<String, LuaValue>() {
let (key, value) = pair?;
match value {
LuaValue::String(s) => {
map.insert(key, vec![s.to_str()?.to_owned()]);
}
LuaValue::Table(t) => {
let mut values = Vec::new();
for value in t.sequence_values::<LuaString>() {
values.push(value?.to_str()?.to_owned());
}
map.insert(key, values);
}
_ => {
return Err(LuaError::runtime(format!(
"Value for '{tab_origin_key}' must be a string or array of strings",
)))
}
}
}
Ok(map)
}

View file

@ -1 +1,4 @@
pub mod headers;
pub mod hyper;
pub mod request;
pub mod response;

View file

@ -0,0 +1,100 @@
use http_body_util::Full;
use hyper::{
body::Bytes,
client::conn::http1::handshake,
header::{HeaderName, HeaderValue, USER_AGENT},
HeaderMap, Request as HyperRequest,
};
use mlua::prelude::*;
use url::Url;
use crate::{
client::{config::RequestConfig, stream::HttpRequestStream},
shared::{
headers::create_user_agent_header,
hyper::{HyperExecutor, HyperIo},
response::Response,
},
};
#[derive(Debug, Clone)]
pub struct Request {
inner: HyperRequest<Full<Bytes>>,
}
impl Request {
pub fn from_config(config: RequestConfig, lua: Lua) -> LuaResult<Self> {
// 1. Parse the URL and make sure it is valid
let mut url = Url::parse(&config.url).into_lua_err()?;
// 2. Append any query pairs passed as a table
{
let mut query = url.query_pairs_mut();
for (key, values) in config.query {
for value in values {
query.append_pair(&key, &value);
}
}
}
// 3. Create the inner request builder
let mut builder = HyperRequest::builder()
.method(config.method)
.uri(url.as_str());
// 4. Append any headers passed as a table - builder
// headers may be None if builder is already invalid
if let Some(headers) = builder.headers_mut() {
for (key, values) in config.headers {
let key = HeaderName::from_bytes(key.as_bytes()).into_lua_err()?;
for value in values {
let value = HeaderValue::from_str(&value).into_lua_err()?;
headers.insert(key.clone(), value);
}
}
}
// 5. Convert request body bytes to the proper Body
// type that Hyper expects, if we got any bytes
let body = config
.body
.map(Bytes::from)
.map(Full::new)
.unwrap_or_default();
// 6. Finally, attach the body, verifying that the request
// is valid, and attach a user agent if not already set
let mut inner = builder.body(body).into_lua_err()?;
add_default_headers(&lua, inner.headers_mut())?;
Ok(Self { inner })
}
pub async fn send(self, lua: Lua) -> LuaResult<Response> {
let stream = HttpRequestStream::connect(self.inner.uri()).await?;
let (mut sender, conn) = handshake(HyperIo::from(stream))
.await
.map_err(LuaError::external)?;
HyperExecutor::execute(lua, conn);
let incoming = sender
.send_request(self.inner)
.await
.map_err(LuaError::external)?;
Response::from_incoming(incoming).await
}
}
fn add_default_headers(lua: &Lua, headers: &mut HeaderMap) -> LuaResult<()> {
if !headers.contains_key(USER_AGENT) {
let ua = create_user_agent_header(lua)?;
let ua = HeaderValue::from_str(&ua).into_lua_err()?;
headers.insert(USER_AGENT, ua);
}
Ok(())
}

View file

@ -0,0 +1,71 @@
use futures_lite::prelude::*;
use http_body_util::BodyStream;
use hyper::{
body::{Bytes, Incoming},
HeaderMap, Response as HyperResponse,
};
use mlua::prelude::*;
use crate::shared::headers::header_map_to_table;
#[derive(Debug, Clone)]
pub struct Response {
inner: HyperResponse<Bytes>,
}
impl Response {
pub async fn from_incoming(incoming: HyperResponse<Incoming>) -> LuaResult<Self> {
let (parts, body) = incoming.into_parts();
let body = BodyStream::new(body)
.try_fold(Vec::<u8>::new(), |mut body, chunk| {
if let Some(chunk) = chunk.data_ref() {
body.extend_from_slice(chunk);
}
Ok(body)
})
.await
.into_lua_err()?;
let bytes = Bytes::from(body);
let inner = HyperResponse::from_parts(parts, bytes);
Ok(Self { inner })
}
pub fn status_ok(&self) -> bool {
self.inner.status().is_success()
}
pub fn status_code(&self) -> u16 {
self.inner.status().as_u16()
}
pub fn status_message(&self) -> &str {
self.inner.status().canonical_reason().unwrap_or_default()
}
pub fn headers(&self) -> &HeaderMap {
self.inner.headers()
}
pub fn body(&self) -> &[u8] {
self.inner.body()
}
}
impl LuaUserData for Response {
fn add_fields<F: LuaUserDataFields<Self>>(fields: &mut F) {
fields.add_field_method_get("ok", |_, this| Ok(this.status_ok()));
fields.add_field_method_get("statusCode", |_, this| Ok(this.status_code()));
fields.add_field_method_get("statusMessage", |lua, this| {
lua.create_string(this.status_message())
});
fields.add_field_method_get("headers", |lua, this| {
header_map_to_table(lua, this.headers().clone(), false)
});
fields.add_field_method_get("body", |lua, this| lua.create_string(this.body()));
}
}