refactor(net, fs): accept buffers and strings instead of just buffers

This commit is contained in:
Erica Marigold 2024-04-19 14:24:04 +05:30
parent a35c7fc3c4
commit 716e8aae8e
No known key found for this signature in database
GPG key ID: 2768CC0C23D245D1
8 changed files with 14 additions and 40 deletions

2
Cargo.lock generated
View file

@ -330,6 +330,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05efc5cfd9110c8416e471df0e96702d58690178e206e61b7173706673c93706" checksum = "05efc5cfd9110c8416e471df0e96702d58690178e206e61b7173706673c93706"
dependencies = [ dependencies = [
"memchr", "memchr",
"regex-automata 0.4.6",
"serde", "serde",
] ]
@ -1350,6 +1351,7 @@ dependencies = [
"async-compression", "async-compression",
"async-trait", "async-trait",
"blocking", "blocking",
"bstr",
"chrono", "chrono",
"chrono_lc", "chrono_lc",
"clap", "clap",

View file

@ -75,6 +75,7 @@ path-clean = "1.0"
pathdiff = "0.2" pathdiff = "0.2"
pin-project = "1.0" pin-project = "1.0"
urlencoding = "2.1" urlencoding = "2.1"
bstr = "1.9.1"
### RUNTIME ### RUNTIME
@ -85,7 +86,7 @@ tokio = { version = "1.24", features = ["full", "tracing"] }
os_str_bytes = { version = "7.0", features = ["conversions"] } os_str_bytes = { version = "7.0", features = ["conversions"] }
mlua-luau-scheduler = { version = "0.0.2" } mlua-luau-scheduler = { version = "0.0.2" }
mlua = { version = "0.9.6", features = [ mlua = { version = "0.9.7", features = [
"luau", "luau",
"luau-jit", "luau-jit",
"async", "async",

View file

@ -1,10 +1,10 @@
use std::io::ErrorKind as IoErrorKind; use std::io::ErrorKind as IoErrorKind;
use std::path::{PathBuf, MAIN_SEPARATOR}; use std::path::{PathBuf, MAIN_SEPARATOR};
use bstr::{BString, ByteSlice};
use mlua::prelude::*; use mlua::prelude::*;
use tokio::fs; use tokio::fs;
use crate::lune::util::buffer::{buf_to_str, create_lua_buffer};
use crate::lune::util::TableBuilder; use crate::lune::util::TableBuilder;
mod copy; mod copy;
@ -31,10 +31,10 @@ pub fn create(lua: &Lua) -> LuaResult<LuaTable> {
.build_readonly() .build_readonly()
} }
async fn fs_read_file(lua: &Lua, path: String) -> LuaResult<LuaValue> { async fn fs_read_file(lua: &Lua, path: String) -> LuaResult<LuaString> {
let bytes = fs::read(&path).await.into_lua_err()?; let bytes = fs::read(&path).await.into_lua_err()?;
create_lua_buffer(lua, bytes) lua.create_string(bytes)
} }
async fn fs_read_dir(_: &Lua, path: String) -> LuaResult<Vec<String>> { async fn fs_read_dir(_: &Lua, path: String) -> LuaResult<Vec<String>> {
@ -68,8 +68,8 @@ async fn fs_read_dir(_: &Lua, path: String) -> LuaResult<Vec<String>> {
async fn fs_write_file(lua: &Lua, (path, contents): (String, LuaValue<'_>)) -> LuaResult<()> { async fn fs_write_file(lua: &Lua, (path, contents): (String, LuaValue<'_>)) -> LuaResult<()> {
let contents_str = match contents { let contents_str = match contents {
LuaValue::String(str) => Ok(str.to_str()?.to_string()), LuaValue::String(str) => Ok(BString::from(str.to_str()?)),
LuaValue::UserData(inner) => Ok(buf_to_str(lua, LuaValue::UserData(inner))?), LuaValue::UserData(inner) => lua.unpack::<BString>(LuaValue::UserData(inner)),
other => Err(LuaError::runtime(format!( other => Err(LuaError::runtime(format!(
"Expected type string or buffer, got {}", "Expected type string or buffer, got {}",
other.type_name() other.type_name()

View file

@ -3,12 +3,11 @@ use std::{
net::{IpAddr, Ipv4Addr}, net::{IpAddr, Ipv4Addr},
}; };
use bstr::{BString, ByteSlice};
use mlua::prelude::*; use mlua::prelude::*;
use reqwest::Method; use reqwest::Method;
use crate::lune::util::buffer::buf_to_str;
use super::util::table_to_hash_map; use super::util::table_to_hash_map;
const DEFAULT_IP_ADDRESS: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); const DEFAULT_IP_ADDRESS: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
@ -109,7 +108,7 @@ impl FromLua<'_> for RequestConfig {
Err(_) => HashMap::new(), Err(_) => HashMap::new(),
}; };
// Extract body // Extract body
let body = match tab.get::<_, LuaString>("body") { let body = match tab.get::<_, BString>("body") {
Ok(config_body) => Some(config_body.as_bytes().to_owned()), Ok(config_body) => Some(config_body.as_bytes().to_owned()),
Err(_) => None, Err(_) => None,
}; };

View file

@ -1,5 +1,6 @@
use std::str::FromStr; use std::str::FromStr;
use bstr::{BString, ByteSlice};
use http_body_util::Full; use http_body_util::Full;
use hyper::{ use hyper::{
body::Bytes, body::Bytes,
@ -56,7 +57,7 @@ impl FromLua<'_> for LuaResponse {
LuaValue::Table(t) => { LuaValue::Table(t) => {
let status: Option<u16> = t.get("status")?; let status: Option<u16> = t.get("status")?;
let headers: Option<LuaTable> = t.get("headers")?; let headers: Option<LuaTable> = t.get("headers")?;
let body: Option<LuaString> = t.get("body")?; let body: Option<BString> = t.get("body")?;
let mut headers_map = HeaderMap::new(); let mut headers_map = HeaderMap::new();
if let Some(headers) = headers { if let Some(headers) = headers {

View file

@ -22,7 +22,7 @@ use hyper_tungstenite::{
WebSocketStream, WebSocketStream,
}; };
use crate::lune::util::{buffer::buf_to_str, TableBuilder}; use crate::lune::util::TableBuilder;
// Wrapper implementation for compatibility and changing colon syntax to dot syntax // Wrapper implementation for compatibility and changing colon syntax to dot syntax
const WEB_SOCKET_IMPL_LUA: &str = r#" const WEB_SOCKET_IMPL_LUA: &str = r#"

View file

@ -1,28 +0,0 @@
use mlua::{IntoLua, Lua, Result, Value};
const BYTES_TO_BUF_IMPL: &str = r#"
local tbl = select(1, ...)
local buf = buffer.create(#tbl * 4) -- Each u32 is 4 bytes
for offset, byte in tbl do
buffer.writeu32(buf, offset, byte)
end
return buf
"#;
const BUF_TO_STR_IMPL: &str = "return buffer.tostring(select(1, ...))";
pub fn create_lua_buffer(lua: &Lua, bytes: impl AsRef<[u8]>) -> Result<Value> {
let lua_bytes = bytes.as_ref().into_lua(lua)?;
let buf_constructor = lua.load(BYTES_TO_BUF_IMPL).into_function()?;
buf_constructor.call::<_, Value>(lua_bytes)
}
pub fn buf_to_str(lua: &Lua, buf: Value<'_>) -> Result<String> {
let str_constructor = lua.load(BUF_TO_STR_IMPL).into_function()?;
str_constructor.call(buf)
}

View file

@ -1,6 +1,5 @@
mod table_builder; mod table_builder;
pub mod buffer;
pub mod formatting; pub mod formatting;
pub mod luaurc; pub mod luaurc;
pub mod paths; pub mod paths;