Implement zero-copy hyper body type that wraps over lua values

This commit is contained in:
Filip Tibell 2025-04-29 23:00:03 +02:00
parent 4079842a33
commit ac8c809a20
No known key found for this signature in database
14 changed files with 354 additions and 125 deletions

View file

@ -8,6 +8,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## Unreleased
### Changed
- Significantly improved performance of `net.request` and `net.serve` when handling large request bodies
## `0.9.1` - April 29th, 2025 ## `0.9.1` - April 29th, 2025
### Added ### Added

View file

@ -0,0 +1,59 @@
use hyper::body::{Buf, Bytes};
use super::inner::ReadableBodyInner;
/**
The cursor keeping track of inner data and its position for a readable body.
*/
#[derive(Debug, Clone)]
pub struct ReadableBodyCursor {
inner: ReadableBodyInner,
start: usize,
}
impl ReadableBodyCursor {
pub fn len(&self) -> usize {
self.inner.len()
}
pub fn as_slice(&self) -> &[u8] {
&self.inner.as_slice()[self.start..]
}
pub fn advance(&mut self, cnt: usize) {
self.start += cnt;
if self.start > self.inner.len() {
self.start = self.inner.len();
}
}
pub fn into_bytes(self) -> Bytes {
self.inner.into_bytes()
}
}
impl Buf for ReadableBodyCursor {
fn remaining(&self) -> usize {
self.len().saturating_sub(self.start)
}
fn chunk(&self) -> &[u8] {
self.as_slice()
}
fn advance(&mut self, cnt: usize) {
self.advance(cnt);
}
}
impl<T> From<T> for ReadableBodyCursor
where
T: Into<ReadableBodyInner>,
{
fn from(value: T) -> Self {
Self {
inner: value.into(),
start: 0,
}
}
}

View file

@ -1,4 +1,4 @@
use http_body_util::{BodyExt, Full}; use http_body_util::BodyExt;
use hyper::{ use hyper::{
body::{Bytes, Incoming}, body::{Bytes, Incoming},
header::CONTENT_ENCODING, header::CONTENT_ENCODING,
@ -33,11 +33,3 @@ pub async fn handle_incoming_body(
Ok((body, was_decompressed)) Ok((body, was_decompressed))
} }
pub fn bytes_to_full(bytes: Bytes) -> Full<Bytes> {
if bytes.is_empty() {
Full::default()
} else {
Full::new(bytes)
}
}

View file

@ -0,0 +1,110 @@
use hyper::body::{Buf as _, Bytes};
use mlua::{prelude::*, Buffer as LuaBuffer};
/**
The inner data for a readable body.
*/
#[derive(Debug, Clone)]
pub enum ReadableBodyInner {
Bytes(Bytes),
String(String),
LuaString(LuaString),
LuaBuffer(LuaBuffer),
}
impl ReadableBodyInner {
pub fn len(&self) -> usize {
match self {
Self::Bytes(b) => b.len(),
Self::String(s) => s.len(),
Self::LuaString(s) => s.as_bytes().len(),
Self::LuaBuffer(b) => b.len(),
}
}
pub fn as_slice(&self) -> &[u8] {
/*
SAFETY: Reading lua strings and lua buffers as raw slices is safe while we can
guarantee that the inner Lua value + main lua struct has not yet been dropped
1. Buffers are fixed-size and guaranteed to never resize
2. We do not expose any method for writing to the body, only reading
3. We guarantee that net.request and net.serve futures are only driven forward
while we also know that the Lua + scheduler pair have not yet been dropped
4. Any writes from within lua to a buffer, are considered user error,
and are not unsafe, since the only possible outcome with the above
guarantees is invalid / mangled contents in request / response bodies
*/
match self {
Self::Bytes(b) => b.chunk(),
Self::String(s) => s.as_bytes(),
Self::LuaString(s) => unsafe {
// BorrowedBytes would not let us return a plain slice here,
// which is what the Buf implementation below needs - we need to
// do a little hack here to re-create the slice without a lifetime
let b = s.as_bytes();
let ptr = b.as_ptr();
let len = b.len();
std::slice::from_raw_parts(ptr, len)
},
Self::LuaBuffer(b) => unsafe {
// Similar to above, we need to get the raw slice for the buffer,
// which is a bit trickier here because Buffer has a read + write
// interface instead of using slices for some unknown reason
let v = LuaValue::Buffer(b.clone());
let ptr = v.to_pointer().cast::<u8>();
let len = b.len();
std::slice::from_raw_parts(ptr, len)
},
}
}
pub fn into_bytes(self) -> Bytes {
match self {
Self::Bytes(b) => b,
Self::String(s) => Bytes::from(s),
Self::LuaString(s) => Bytes::from(s.as_bytes().to_vec()),
Self::LuaBuffer(b) => Bytes::from(b.to_vec()),
}
}
}
impl From<&'static str> for ReadableBodyInner {
fn from(value: &'static str) -> Self {
Self::Bytes(Bytes::from(value))
}
}
impl From<Vec<u8>> for ReadableBodyInner {
fn from(value: Vec<u8>) -> Self {
Self::Bytes(Bytes::from(value))
}
}
impl From<Bytes> for ReadableBodyInner {
fn from(value: Bytes) -> Self {
Self::Bytes(value)
}
}
impl From<String> for ReadableBodyInner {
fn from(value: String) -> Self {
Self::String(value)
}
}
impl From<LuaString> for ReadableBodyInner {
fn from(value: LuaString) -> Self {
Self::LuaString(value)
}
}
impl From<LuaBuffer> for ReadableBodyInner {
fn from(value: LuaBuffer) -> Self {
Self::LuaBuffer(value)
}
}

View file

@ -0,0 +1,11 @@
#![allow(unused_imports)]
mod cursor;
mod incoming;
mod inner;
mod readable;
pub use self::cursor::ReadableBodyCursor;
pub use self::incoming::handle_incoming_body;
pub use self::inner::ReadableBodyInner;
pub use self::readable::ReadableBody;

View file

@ -0,0 +1,105 @@
use std::convert::Infallible;
use std::pin::Pin;
use std::task::{Context, Poll};
use hyper::body::{Body, Bytes, Frame, SizeHint};
use mlua::prelude::*;
use super::cursor::ReadableBodyCursor;
/**
Zero-copy wrapper for a readable body.
Provides methods to read bytes that can be safely used if, and only
if, the respective Lua struct for the body has not yet been dropped.
If the body was created from a `Vec<u8>`, `Bytes`, or a `String`, reading
bytes is always safe and does not go through any additional indirections.
*/
#[derive(Debug, Clone)]
pub struct ReadableBody {
cursor: Option<ReadableBodyCursor>,
}
impl ReadableBody {
pub const fn empty() -> Self {
Self { cursor: None }
}
pub fn as_slice(&self) -> &[u8] {
match self.cursor.as_ref() {
Some(cursor) => cursor.as_slice(),
None => &[],
}
}
pub fn into_bytes(self) -> Bytes {
match self.cursor {
Some(cursor) => cursor.into_bytes(),
None => Bytes::new(),
}
}
}
impl Body for ReadableBody {
type Data = ReadableBodyCursor;
type Error = Infallible;
fn poll_frame(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
Poll::Ready(self.cursor.take().map(|d| Ok(Frame::data(d))))
}
fn is_end_stream(&self) -> bool {
self.cursor.is_none()
}
fn size_hint(&self) -> SizeHint {
self.cursor.as_ref().map_or_else(
|| SizeHint::with_exact(0),
|c| SizeHint::with_exact(c.len() as u64),
)
}
}
impl<T> From<T> for ReadableBody
where
T: Into<ReadableBodyCursor>,
{
fn from(value: T) -> Self {
Self {
cursor: Some(value.into()),
}
}
}
impl<T> From<Option<T>> for ReadableBody
where
T: Into<ReadableBodyCursor>,
{
fn from(value: Option<T>) -> Self {
Self {
cursor: value.map(Into::into),
}
}
}
impl FromLua for ReadableBody {
fn from_lua(value: LuaValue, _: &Lua) -> LuaResult<Self> {
match value {
LuaValue::Nil => Ok(Self::empty()),
LuaValue::String(str) => Ok(Self::from(str)),
LuaValue::Buffer(buf) => Ok(Self::from(buf)),
v => Err(LuaError::FromLuaConversionError {
from: v.type_name(),
to: "Body".to_string(),
message: Some(format!(
"Invalid body - expected string or buffer, got {}",
v.type_name()
)),
}),
}
}
}

View file

@ -1,5 +1,6 @@
use http_body_util::Full;
use hyper::{ use hyper::{
body::{Bytes, Incoming}, body::Incoming,
client::conn::http1::handshake, client::conn::http1::handshake,
header::{HeaderValue, ACCEPT, CONTENT_LENGTH, HOST, LOCATION, USER_AGENT}, header::{HeaderValue, ACCEPT, CONTENT_LENGTH, HOST, LOCATION, USER_AGENT},
Method, Request as HyperRequest, Response as HyperResponse, Uri, Method, Request as HyperRequest, Response as HyperResponse, Uri,
@ -9,6 +10,7 @@ use mlua::prelude::*;
use url::Url; use url::Url;
use crate::{ use crate::{
body::ReadableBody,
client::{http_stream::HttpStream, ws_stream::WsStream}, client::{http_stream::HttpStream, ws_stream::WsStream},
shared::{ shared::{
headers::create_user_agent_header, headers::create_user_agent_header,
@ -61,7 +63,7 @@ pub async fn send_request(mut request: Request, lua: Lua) -> LuaResult<Response>
request.inner.headers_mut().insert(USER_AGENT, ua); request.inner.headers_mut().insert(USER_AGENT, ua);
} }
if !request.headers().contains_key(CONTENT_LENGTH.as_str()) && request.method() != Method::GET { if !request.headers().contains_key(CONTENT_LENGTH.as_str()) && request.method() != Method::GET {
let len = request.inner.body().len().to_string(); let len = request.body().len().to_string();
let len = HeaderValue::from_str(&len).into_lua_err()?; let len = HeaderValue::from_str(&len).into_lua_err()?;
request.inner.headers_mut().insert(CONTENT_LENGTH, len); request.inner.headers_mut().insert(CONTENT_LENGTH, len);
} }
@ -78,18 +80,19 @@ pub async fn send_request(mut request: Request, lua: Lua) -> LuaResult<Response>
HyperExecutor::execute(lua.clone(), conn); HyperExecutor::execute(lua.clone(), conn);
let incoming = sender let (parts, body) = request.clone_inner().into_parts();
.send_request(request.as_full()) let data = HyperRequest::from_parts(parts, Full::new(body.into_bytes()));
.await let incoming = sender.send_request(data).await.into_lua_err()?;
.into_lua_err()?;
if let Some((new_method, new_uri)) = check_redirect(&request.inner, &incoming) { if let Some((new_method, new_uri)) =
check_redirect(request.inner.method().clone(), &incoming)
{
if request.redirects.is_some_and(|r| r >= MAX_REDIRECTS) { if request.redirects.is_some_and(|r| r >= MAX_REDIRECTS) {
return Err(LuaError::external("Too many redirects")); return Err(LuaError::external("Too many redirects"));
} }
if new_method == Method::GET { if new_method == Method::GET {
*request.inner.body_mut() = Bytes::new(); *request.inner.body_mut() = ReadableBody::empty();
} }
*request.inner.method_mut() = new_method; *request.inner.method_mut() = new_method;
@ -104,10 +107,7 @@ pub async fn send_request(mut request: Request, lua: Lua) -> LuaResult<Response>
} }
} }
fn check_redirect( fn check_redirect(method: Method, response: &HyperResponse<Incoming>) -> Option<(Method, Uri)> {
request: &HyperRequest<Bytes>,
response: &HyperResponse<Incoming>,
) -> Option<(Method, Uri)> {
if !response.status().is_redirection() { if !response.status().is_redirection() {
return None; return None;
} }
@ -118,7 +118,7 @@ fn check_redirect(
let method = match response.status().as_u16() { let method = match response.status().as_u16() {
301..=303 => Method::GET, 301..=303 => Method::GET,
_ => request.method().clone(), _ => method,
}; };
Some((method, location)) Some((method, location))

View file

@ -3,6 +3,7 @@
use lune_utils::TableBuilder; use lune_utils::TableBuilder;
use mlua::prelude::*; use mlua::prelude::*;
pub(crate) mod body;
pub(crate) mod client; pub(crate) mod client;
pub(crate) mod server; pub(crate) mod server;
pub(crate) mod shared; pub(crate) mod shared;

View file

@ -1,17 +1,16 @@
use std::{future::Future, net::SocketAddr, pin::Pin}; use std::{future::Future, net::SocketAddr, pin::Pin};
use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream}; use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
use http_body_util::Full;
use hyper::{ use hyper::{
body::{Bytes, Incoming}, body::Incoming, service::Service as HyperService, Request as HyperRequest,
service::Service as HyperService, Response as HyperResponse, StatusCode,
Request as HyperRequest, Response as HyperResponse, StatusCode,
}; };
use mlua::prelude::*; use mlua::prelude::*;
use mlua_luau_scheduler::{LuaSchedulerExt, LuaSpawnExt}; use mlua_luau_scheduler::{LuaSchedulerExt, LuaSpawnExt};
use crate::{ use crate::{
body::ReadableBody,
server::{ server::{
config::ServeConfig, config::ServeConfig,
upgrade::{is_upgrade_request, make_upgrade_response}, upgrade::{is_upgrade_request, make_upgrade_response},
@ -27,7 +26,7 @@ pub(super) struct Service {
} }
impl HyperService<HyperRequest<Incoming>> for Service { impl HyperService<HyperRequest<Incoming>> for Service {
type Response = HyperResponse<Full<Bytes>>; type Response = HyperResponse<ReadableBody>;
type Error = LuaError; type Error = LuaError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>; type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
@ -41,7 +40,7 @@ impl HyperService<HyperRequest<Incoming>> for Service {
Err(err) => { Err(err) => {
return Ok(HyperResponse::builder() return Ok(HyperResponse::builder()
.status(StatusCode::BAD_REQUEST) .status(StatusCode::BAD_REQUEST)
.body(Full::new(Bytes::from(err.to_string()))) .body(ReadableBody::from(err.to_string()))
.unwrap()) .unwrap())
} }
}; };
@ -70,7 +69,7 @@ impl HyperService<HyperRequest<Incoming>> for Service {
// TODO: Propagate the error somehow? // TODO: Propagate the error somehow?
Ok(HyperResponse::builder() Ok(HyperResponse::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR) .status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Full::new(Bytes::from("Lune: Internal server error"))) .body(ReadableBody::from("Lune: Internal server error"))
.unwrap()) .unwrap())
} }
} }
@ -83,7 +82,7 @@ async fn handle_request(
handler: LuaFunction, handler: LuaFunction,
request: HyperRequest<Incoming>, request: HyperRequest<Incoming>,
address: SocketAddr, address: SocketAddr,
) -> LuaResult<HyperResponse<Full<Bytes>>> { ) -> LuaResult<HyperResponse<ReadableBody>> {
let request = Request::from_incoming(request, true) let request = Request::from_incoming(request, true)
.await? .await?
.with_address(address); .with_address(address);
@ -97,7 +96,7 @@ async fn handle_request(
.expect("Missing handler thread result")?; .expect("Missing handler thread result")?;
let response = Response::from_lua_multi(thread_res, &lua)?; let response = Response::from_lua_multi(thread_res, &lua)?;
Ok(response.into_full()) Ok(response.into_inner())
} }
async fn handle_websocket( async fn handle_websocket(

View file

@ -1,12 +1,13 @@
use async_tungstenite::tungstenite::{error::ProtocolError, handshake::derive_accept_key}; use async_tungstenite::tungstenite::{error::ProtocolError, handshake::derive_accept_key};
use http_body_util::Full;
use hyper::{ use hyper::{
body::{Bytes, Incoming}, body::Incoming,
header::{HeaderName, CONNECTION, UPGRADE}, header::{HeaderName, CONNECTION, UPGRADE},
HeaderMap, Request as HyperRequest, Response as HyperResponse, StatusCode, HeaderMap, Request as HyperRequest, Response as HyperResponse, StatusCode,
}; };
use crate::body::ReadableBody;
const SEC_WEBSOCKET_VERSION: HeaderName = HeaderName::from_static("sec-websocket-version"); const SEC_WEBSOCKET_VERSION: HeaderName = HeaderName::from_static("sec-websocket-version");
const SEC_WEBSOCKET_KEY: HeaderName = HeaderName::from_static("sec-websocket-key"); const SEC_WEBSOCKET_KEY: HeaderName = HeaderName::from_static("sec-websocket-key");
const SEC_WEBSOCKET_ACCEPT: HeaderName = HeaderName::from_static("sec-websocket-accept"); const SEC_WEBSOCKET_ACCEPT: HeaderName = HeaderName::from_static("sec-websocket-accept");
@ -31,7 +32,7 @@ pub fn is_upgrade_request(request: &HyperRequest<Incoming>) -> bool {
pub fn make_upgrade_response( pub fn make_upgrade_response(
request: &HyperRequest<Incoming>, request: &HyperRequest<Incoming>,
) -> Result<HyperResponse<Full<Bytes>>, ProtocolError> { ) -> Result<HyperResponse<ReadableBody>, ProtocolError> {
let key = request let key = request
.headers() .headers()
.get(SEC_WEBSOCKET_KEY) .get(SEC_WEBSOCKET_KEY)
@ -50,6 +51,6 @@ pub fn make_upgrade_response(
.header(CONNECTION, "upgrade") .header(CONNECTION, "upgrade")
.header(UPGRADE, "websocket") .header(UPGRADE, "websocket")
.header(SEC_WEBSOCKET_ACCEPT, derive_accept_key(key.as_bytes())) .header(SEC_WEBSOCKET_ACCEPT, derive_accept_key(key.as_bytes()))
.body(Full::new(Bytes::from("switching to websocket protocol"))) .body(ReadableBody::from("switching to websocket protocol"))
.unwrap()) .unwrap())
} }

View file

@ -1,25 +1,9 @@
use hyper::{ use hyper::{
body::Bytes,
header::{HeaderName, HeaderValue}, header::{HeaderName, HeaderValue},
HeaderMap, Method, HeaderMap, Method,
}; };
use mlua::prelude::*;
pub fn lua_value_to_bytes(value: &LuaValue) -> LuaResult<Bytes> { use mlua::prelude::*;
match value {
LuaValue::Nil => Ok(Bytes::new()),
LuaValue::Buffer(buf) => Ok(Bytes::from(buf.to_vec())),
LuaValue::String(str) => Ok(Bytes::copy_from_slice(&str.as_bytes())),
v => Err(LuaError::FromLuaConversionError {
from: v.type_name(),
to: "Bytes".to_string(),
message: Some(format!(
"Invalid body - expected string or buffer, got {}",
v.type_name()
)),
}),
}
}
pub fn lua_value_to_method(value: &LuaValue) -> LuaResult<Method> { pub fn lua_value_to_method(value: &LuaValue) -> LuaResult<Method> {
match value { match value {

View file

@ -1,4 +1,3 @@
pub mod body;
pub mod futures; pub mod futures;
pub mod headers; pub mod headers;
pub mod hyper; pub mod hyper;

View file

@ -1,19 +1,17 @@
use std::{collections::HashMap, net::SocketAddr}; use std::{collections::HashMap, net::SocketAddr};
use http_body_util::Full;
use url::Url; use url::Url;
use hyper::{ use hyper::{body::Incoming, HeaderMap, Method, Request as HyperRequest};
body::{Bytes, Incoming},
HeaderMap, Method, Request as HyperRequest,
};
use mlua::prelude::*; use mlua::prelude::*;
use crate::shared::{ use crate::{
body::{bytes_to_full, handle_incoming_body}, body::{handle_incoming_body, ReadableBody},
shared::{
headers::{hash_map_to_table, header_map_to_table}, headers::{hash_map_to_table, header_map_to_table},
lua::{lua_table_to_header_map, lua_value_to_bytes, lua_value_to_method}, lua::{lua_table_to_header_map, lua_value_to_method},
},
}; };
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
@ -57,9 +55,7 @@ impl FromLua for RequestOptions {
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Request { pub struct Request {
// NOTE: We use Bytes instead of Full<Bytes> to avoid pub(crate) inner: HyperRequest<ReadableBody>,
// needing async when getting a reference to the body
pub(crate) inner: HyperRequest<Bytes>,
pub(crate) address: Option<SocketAddr>, pub(crate) address: Option<SocketAddr>,
pub(crate) redirects: Option<usize>, pub(crate) redirects: Option<usize>,
pub(crate) decompress: bool, pub(crate) decompress: bool,
@ -78,7 +74,7 @@ impl Request {
let (body, decompress) = handle_incoming_body(&parts.headers, body, decompress).await?; let (body, decompress) = handle_incoming_body(&parts.headers, body, decompress).await?;
Ok(Self { Ok(Self {
inner: HyperRequest::from_parts(parts, body), inner: HyperRequest::from_parts(parts, ReadableBody::from(body)),
address: None, address: None,
redirects: None, redirects: None,
decompress, decompress,
@ -137,37 +133,23 @@ impl Request {
Returns the body of the request. Returns the body of the request.
*/ */
pub fn body(&self) -> &[u8] { pub fn body(&self) -> &[u8] {
self.inner.body() self.inner.body().as_slice()
} }
/** /**
Clones the inner `hyper` request with its body Clones the inner `hyper` request.
type modified to `Full<Bytes>` for sending.
*/ */
#[allow(dead_code)] #[allow(dead_code)]
pub fn as_full(&self) -> HyperRequest<Full<Bytes>> { pub fn clone_inner(&self) -> HyperRequest<ReadableBody> {
let mut builder = HyperRequest::builder() self.inner.clone()
.version(self.inner.version())
.method(self.inner.method())
.uri(self.inner.uri());
builder
.headers_mut()
.expect("request was valid")
.extend(self.inner.headers().clone());
let body = bytes_to_full(self.inner.body().clone());
builder.body(body).expect("request was valid")
} }
/** /**
Takes the inner `hyper` request with its body Takes the inner `hyper` request by ownership.
type modified to `Full<Bytes>` for sending.
*/ */
#[allow(dead_code)] #[allow(dead_code)]
pub fn into_full(self) -> HyperRequest<Full<Bytes>> { pub fn into_inner(self) -> HyperRequest<ReadableBody> {
let (parts, body) = self.inner.into_parts(); self.inner
HyperRequest::from_parts(parts, bytes_to_full(body))
} }
} }
@ -179,7 +161,7 @@ impl FromLua for Request {
let uri = s.to_str()?; let uri = s.to_str()?;
let uri = uri.parse().into_lua_err()?; let uri = uri.parse().into_lua_err()?;
let mut request = HyperRequest::new(Bytes::new()); let mut request = HyperRequest::new(ReadableBody::empty());
*request.uri_mut() = uri; *request.uri_mut() = uri;
Ok(Self { Ok(Self {
@ -221,8 +203,7 @@ impl FromLua for Request {
.unwrap_or_default(); .unwrap_or_default();
// Extract body // Extract body
let body = tab.get::<LuaValue>("body")?; let body = tab.get::<ReadableBody>("body")?;
let body = lua_value_to_bytes(&body)?;
// Build the full request // Build the full request
let mut request = HyperRequest::new(body); let mut request = HyperRequest::new(body);

View file

@ -1,24 +1,19 @@
use http_body_util::Full;
use hyper::{ use hyper::{
body::{Bytes, Incoming}, body::Incoming,
header::{HeaderValue, CONTENT_TYPE}, header::{HeaderValue, CONTENT_TYPE},
HeaderMap, Response as HyperResponse, StatusCode, HeaderMap, Response as HyperResponse, StatusCode,
}; };
use mlua::prelude::*; use mlua::prelude::*;
use crate::shared::{ use crate::{
body::{bytes_to_full, handle_incoming_body}, body::{handle_incoming_body, ReadableBody},
headers::header_map_to_table, shared::{headers::header_map_to_table, lua::lua_table_to_header_map},
lua::{lua_table_to_header_map, lua_value_to_bytes},
}; };
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Response { pub struct Response {
// NOTE: We use Bytes instead of Full<Bytes> to avoid pub(crate) inner: HyperResponse<ReadableBody>,
// needing async when getting a reference to the body
pub(crate) inner: HyperResponse<Bytes>,
pub(crate) decompressed: bool, pub(crate) decompressed: bool,
} }
@ -35,7 +30,7 @@ impl Response {
let (body, decompressed) = handle_incoming_body(&parts.headers, body, decompress).await?; let (body, decompressed) = handle_incoming_body(&parts.headers, body, decompress).await?;
Ok(Self { Ok(Self {
inner: HyperResponse::from_parts(parts, body), inner: HyperResponse::from_parts(parts, ReadableBody::from(body)),
decompressed, decompressed,
}) })
} }
@ -72,42 +67,29 @@ impl Response {
Returns the body of the response. Returns the body of the response.
*/ */
pub fn body(&self) -> &[u8] { pub fn body(&self) -> &[u8] {
self.inner.body() self.inner.body().as_slice()
} }
/** /**
Clones the inner `hyper` response with its body Clones the inner `hyper` response.
type modified to `Full<Bytes>` for sending.
*/ */
#[allow(dead_code)] #[allow(dead_code)]
pub fn as_full(&self) -> HyperResponse<Full<Bytes>> { pub fn clone_inner(&self) -> HyperResponse<ReadableBody> {
let mut builder = HyperResponse::builder() self.inner.clone()
.version(self.inner.version())
.status(self.inner.status());
builder
.headers_mut()
.expect("request was valid")
.extend(self.inner.headers().clone());
let body = bytes_to_full(self.inner.body().clone());
builder.body(body).expect("request was valid")
} }
/** /**
Takes the inner `hyper` response with its body Takes the inner `hyper` response by ownership.
type modified to `Full<Bytes>` for sending.
*/ */
#[allow(dead_code)] #[allow(dead_code)]
pub fn into_full(self) -> HyperResponse<Full<Bytes>> { pub fn into_inner(self) -> HyperResponse<ReadableBody> {
let (parts, body) = self.inner.into_parts(); self.inner
HyperResponse::from_parts(parts, bytes_to_full(body))
} }
} }
impl FromLua for Response { impl FromLua for Response {
fn from_lua(value: LuaValue, _: &Lua) -> LuaResult<Self> { fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult<Self> {
if let Ok(body) = lua_value_to_bytes(&value) { if let Ok(body) = ReadableBody::from_lua(value.clone(), lua) {
// String or buffer is always a 200 text/plain response // String or buffer is always a 200 text/plain response
let mut response = HyperResponse::new(body); let mut response = HyperResponse::new(body);
response response
@ -130,8 +112,7 @@ impl FromLua for Response {
.unwrap_or_default(); .unwrap_or_default();
// Extract body // Extract body
let body = tab.get::<LuaValue>("body")?; let body = tab.get::<ReadableBody>("body")?;
let body = lua_value_to_bytes(&body)?;
// Build the full response // Build the full response
let mut response = HyperResponse::new(body); let mut response = HyperResponse::new(body);