Add support for buffers as arguments in builtin APIs (#148)

This commit is contained in:
Erica Marigold 2024-04-20 20:14:19 +05:30 committed by GitHub
parent 7fb48dfa1f
commit f830ce7fad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 78 additions and 71 deletions

2
Cargo.lock generated
View file

@ -346,6 +346,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05efc5cfd9110c8416e471df0e96702d58690178e206e61b7173706673c93706"
dependencies = [
"memchr",
"regex-automata 0.4.6",
"serde",
]
@ -1367,6 +1368,7 @@ dependencies = [
"async-trait",
"async_zip",
"blocking",
"bstr",
"chrono",
"chrono_lc",
"clap",

View file

@ -77,6 +77,7 @@ path-clean = "1.0"
pathdiff = "0.2"
pin-project = "1.0"
urlencoding = "2.1"
bstr = "1.9.1"
### RUNTIME
@ -87,7 +88,7 @@ tokio = { version = "1.24", features = ["full", "tracing"] }
os_str_bytes = { version = "7.0", features = ["conversions"] }
mlua-luau-scheduler = { version = "0.0.2" }
mlua = { version = "0.9.6", features = [
mlua = { version = "0.9.7", features = [
"luau",
"luau-jit",
"async",

View file

@ -1,6 +1,7 @@
use std::io::ErrorKind as IoErrorKind;
use std::path::{PathBuf, MAIN_SEPARATOR};
use bstr::{BString, ByteSlice};
use mlua::prelude::*;
use tokio::fs;
@ -32,6 +33,7 @@ pub fn create(lua: &Lua) -> LuaResult<LuaTable> {
async fn fs_read_file(lua: &Lua, path: String) -> LuaResult<LuaString> {
let bytes = fs::read(&path).await.into_lua_err()?;
lua.create_string(bytes)
}
@ -64,8 +66,8 @@ async fn fs_read_dir(_: &Lua, path: String) -> LuaResult<Vec<String>> {
Ok(dir_strings_no_prefix)
}
async fn fs_write_file(_: &Lua, (path, contents): (String, LuaString<'_>)) -> LuaResult<()> {
fs::write(&path, &contents.as_bytes()).await.into_lua_err()
async fn fs_write_file(_: &Lua, (path, contents): (String, BString)) -> LuaResult<()> {
fs::write(&path, contents.as_bytes()).await.into_lua_err()
}
async fn fs_write_dir(_: &Lua, path: String) -> LuaResult<()> {

View file

@ -3,6 +3,7 @@ use std::{
net::{IpAddr, Ipv4Addr},
};
use bstr::{BString, ByteSlice};
use mlua::prelude::*;
use reqwest::Method;
@ -107,10 +108,11 @@ impl FromLua<'_> for RequestConfig {
Err(_) => HashMap::new(),
};
// Extract body
let body = match tab.get::<_, LuaString>("body") {
let body = match tab.get::<_, BString>("body") {
Ok(config_body) => Some(config_body.as_bytes().to_owned()),
Err(_) => None,
};
// Convert method string into proper enum
let method = method.trim().to_ascii_uppercase();
let method = match method.as_ref() {

View file

@ -1,5 +1,6 @@
#![allow(unused_variables)]
use bstr::BString;
use mlua::prelude::*;
use mlua_luau_scheduler::LuaSpawnExt;
@ -45,7 +46,7 @@ fn net_json_encode<'lua>(
.serialize_to_string(lua, val)
}
fn net_json_decode<'lua>(lua: &'lua Lua, json: LuaString<'lua>) -> LuaResult<LuaValue<'lua>> {
fn net_json_decode<'lua>(lua: &'lua Lua, json: BString) -> LuaResult<LuaValue<'lua>> {
EncodeDecodeConfig::from(EncodeDecodeFormat::Json).deserialize_from_string(lua, json)
}

View file

@ -1,5 +1,6 @@
use std::str::FromStr;
use bstr::{BString, ByteSlice};
use http_body_util::Full;
use hyper::{
body::Bytes,
@ -56,7 +57,7 @@ impl FromLua<'_> for LuaResponse {
LuaValue::Table(t) => {
let status: Option<u16> = t.get("status")?;
let headers: Option<LuaTable> = t.get("headers")?;
let body: Option<LuaString> = t.get("body")?;
let body: Option<BString> = t.get("body")?;
let mut headers_map = HeaderMap::new();
if let Some(headers) = headers {

View file

@ -3,6 +3,7 @@ use std::sync::{
Arc,
};
use bstr::{BString, ByteSlice};
use mlua::prelude::*;
use futures_util::{
@ -160,7 +161,7 @@ where
methods.add_async_method(
"send",
|_, this, (string, as_binary): (LuaString, Option<bool>)| async move {
|_, this, (string, as_binary): (BString, Option<bool>)| async move {
this.send(if as_binary.unwrap_or_default() {
WsMessage::Binary(string.as_bytes().to_vec())
} else {

View file

@ -1,3 +1,4 @@
use bstr::{BString, ByteSlice};
use mlua::prelude::*;
use serde_json::Value as JsonValue;
@ -89,7 +90,7 @@ impl EncodeDecodeConfig {
pub fn deserialize_from_string<'lua>(
self,
lua: &'lua Lua,
string: LuaString<'lua>,
string: BString,
) -> LuaResult<LuaValue<'lua>> {
let bytes = string.as_bytes();
match self.format {

View file

@ -1,3 +1,4 @@
use bstr::BString;
use mlua::prelude::*;
pub(super) mod compress_decompress;
@ -27,7 +28,7 @@ fn serde_encode<'lua>(
fn serde_decode<'lua>(
lua: &'lua Lua,
(format, str): (EncodeDecodeFormat, LuaString<'lua>),
(format, str): (EncodeDecodeFormat, BString),
) -> LuaResult<LuaValue<'lua>> {
let config = EncodeDecodeConfig::from(format);
config.deserialize_from_string(lua, str)
@ -35,7 +36,7 @@ fn serde_decode<'lua>(
async fn serde_compress<'lua>(
lua: &'lua Lua,
(format, str): (CompressDecompressFormat, LuaString<'lua>),
(format, str): (CompressDecompressFormat, BString),
) -> LuaResult<LuaString<'lua>> {
let bytes = compress(format, str).await?;
lua.create_string(bytes)
@ -43,7 +44,7 @@ async fn serde_compress<'lua>(
async fn serde_decompress<'lua>(
lua: &'lua Lua,
(format, str): (CompressDecompressFormat, LuaString<'lua>),
(format, str): (CompressDecompressFormat, BString),
) -> LuaResult<LuaString<'lua>> {
let bytes = decompress(format, str).await?;
lua.create_string(bytes)

View file

@ -50,15 +50,15 @@ assert(fs.isFile(TEMP_ROOT_PATH_2 .. "/foo/buzz"), "Missing copied file - root/f
-- Make sure the copied files are correct
assert(
fs.readFile(TEMP_ROOT_PATH_2 .. "/foo/bar/baz") == utils.binaryBlob,
fs.readFile(TEMP_ROOT_PATH_2 .. "/foo/bar/baz") == buffer.tostring(utils.binaryBlob),
"Invalid copied file - root/foo/bar/baz"
)
assert(
fs.readFile(TEMP_ROOT_PATH_2 .. "/foo/fizz") == utils.binaryBlob,
fs.readFile(TEMP_ROOT_PATH_2 .. "/foo/fizz") == buffer.tostring(utils.binaryBlob),
"Invalid copied file - root/foo/fizz"
)
assert(
fs.readFile(TEMP_ROOT_PATH_2 .. "/foo/buzz") == utils.binaryBlob,
fs.readFile(TEMP_ROOT_PATH_2 .. "/foo/buzz") == buffer.tostring(utils.binaryBlob),
"Invalid copied file - root/foo/buzz"
)

View file

@ -11,6 +11,8 @@ fs.writeDir(TEMP_ROOT_PATH)
-- Write both of our files
-- binaryBlob is of type buffer to make sure fs.writeFile
-- works with both strings and buffers
fs.writeFile(TEMP_ROOT_PATH .. "/test_binary", utils.binaryBlob)
fs.writeFile(TEMP_ROOT_PATH .. "/test_json.json", utils.jsonBlob)
@ -18,7 +20,7 @@ fs.writeFile(TEMP_ROOT_PATH .. "/test_json.json", utils.jsonBlob)
-- wrote gets us back the original strings
assert(
fs.readFile(TEMP_ROOT_PATH .. "/test_binary") == utils.binaryBlob,
fs.readFile(TEMP_ROOT_PATH .. "/test_binary") == buffer.tostring(utils.binaryBlob),
"Binary file round-trip resulted in different strings"
)

View file

@ -45,7 +45,7 @@ assert(metaFile.kind == "file", "File metadata kind was invalid")
local metaBefore = fs.metadata(TEMP_FILE_PATH)
task.wait(1)
fs.writeFile(TEMP_FILE_PATH, utils.binaryBlob .. "\n")
fs.writeFile(TEMP_FILE_PATH, buffer.tostring(utils.binaryBlob) .. "\n")
local metaAfter = fs.metadata(TEMP_FILE_PATH)
assert(

View file

@ -20,7 +20,7 @@ fs.move("bin/move_test_json.json", "bin/moved_test_json.json")
-- wrote gets us back the original strings
assert(
fs.readFile("bin/moved_test_binary") == utils.binaryBlob,
fs.readFile("bin/moved_test_binary") == buffer.tostring(utils.binaryBlob),
"Binary file round-trip resulted in different strings"
)

View file

@ -16,6 +16,6 @@ local jsonBlob = serde.encode("json", {
-- Return testing data and utils
return {
binaryBlob = binaryBlob,
binaryBlob = buffer.fromstring(binaryBlob),
jsonBlob = jsonBlob,
}

View file

@ -22,7 +22,9 @@ end)
task.wait(1)
socket.send('{"op":1,"d":null}')
local payload = '{"op":1,"d":null}'
socket.send(payload)
socket.send(buffer.fromstring(payload))
socket.close(1000)
task.cancel(delayedThread)

@ -1 +1 @@
Subproject commit 52f2c1a686e7b67d996005eeddf63b97b170a741
Subproject commit 655b5cc6a64024709d3662cc45ec4319c87de5a2

View file

@ -33,69 +33,60 @@ local TESTS: { Test } = {
}
local failed = false
for _, test in TESTS do
local source = fs.readFile(test.Source)
local target = fs.readFile(test.Target)
local success, compressed = pcall(serde.compress, test.Format, source)
local function testOperation(
operationName: "Compress" | "Decompress",
operation: (
format: serde.CompressDecompressFormat,
s: buffer | string
) -> string,
format: serde.CompressDecompressFormat,
source: string | buffer,
target: string
)
local success, res = pcall(operation, format, source)
if not success then
stdio.ewrite(
string.format(
"Compressing source using '%s' format threw an error!\n%s",
tostring(test.Format),
tostring(compressed)
"%sing source using '%s' format threw an error!\n%s",
operationName,
tostring(format),
tostring(res)
)
)
failed = true
continue
elseif compressed ~= target then
elseif res ~= target then
stdio.ewrite(
string.format(
"Compressing source using '%s' format did not produce target!\n",
tostring(test.Format)
"%sing source using '%s' format did not produce target!\n",
operationName,
tostring(format)
)
)
stdio.ewrite(
string.format(
"Compressed (%d chars long):\n%s\nTarget (%d chars long):\n%s\n\n",
#compressed,
tostring(compressed),
"%sed (%d chars long):\n%s\nTarget (%d chars long):\n%s\n\n",
operationName,
#res,
tostring(res),
#target,
tostring(target)
)
)
failed = true
continue
end
end
local success2, decompressed = pcall(serde.decompress, test.Format, target)
if not success2 then
stdio.ewrite(
string.format(
"Decompressing source using '%s' format threw an error!\n%s",
tostring(test.Format),
tostring(decompressed)
)
)
failed = true
continue
elseif decompressed ~= source then
stdio.ewrite(
string.format(
"Decompressing target using '%s' format did not produce source!\n",
tostring(test.Format)
)
)
stdio.ewrite(
string.format(
"Decompressed (%d chars long):\n%s\n\n",
#decompressed,
tostring(decompressed)
)
)
failed = true
continue
end
for _, test in TESTS do
local source = fs.readFile(test.Source)
local target = fs.readFile(test.Target)
-- Compression
testOperation("Compress", serde.compress, test.Format, source, target)
testOperation("Compress", serde.compress, test.Format, buffer.fromstring(source), target)
-- Decompression
testOperation("Decompress", serde.decompress, test.Format, target, source)
testOperation("Decompress", serde.decompress, test.Format, buffer.fromstring(target), source)
end
if failed then

View file

@ -144,7 +144,7 @@ end
@param path The path of the file
@param contents The contents of the file
]=]
function fs.writeFile(path: string, contents: string) end
function fs.writeFile(path: string, contents: buffer | string) end
--[=[
@within FS

View file

@ -36,7 +36,7 @@ export type FetchParamsOptions = {
export type FetchParams = {
url: string,
method: HttpMethod?,
body: string?,
body: (string | buffer)?,
query: HttpQueryMap?,
headers: HttpHeaderMap?,
options: FetchParamsOptions?,
@ -101,7 +101,7 @@ export type ServeRequest = {
export type ServeResponse = {
status: number?,
headers: { [string]: string }?,
body: string?,
body: (string | buffer)?,
}
type ServeHttpHandler = (request: ServeRequest) -> string | ServeResponse
@ -174,7 +174,7 @@ export type ServeHandle = {
export type WebSocket = {
closeCode: number?,
close: (code: number?) -> (),
send: (message: string, asBinaryMessage: boolean?) -> (),
send: (message: (string | buffer)?, asBinaryMessage: boolean?) -> (),
next: () -> string?,
}

View file

@ -70,7 +70,7 @@ end
@param encoded The string to decode
@return The decoded lua value
]=]
function serde.decode(format: EncodeDecodeFormat, encoded: string): any
function serde.decode(format: EncodeDecodeFormat, encoded: buffer | string): any
return nil :: any
end
@ -93,7 +93,7 @@ end
@param s The string to compress
@return The compressed string
]=]
function serde.compress(format: CompressDecompressFormat, s: string): string
function serde.compress(format: CompressDecompressFormat, s: buffer | string): string
return nil :: any
end
@ -116,7 +116,7 @@ end
@param s The string to decompress
@return The decompressed string
]=]
function serde.decompress(format: CompressDecompressFormat, s: string): string
function serde.decompress(format: CompressDecompressFormat, s: buffer | string): string
return nil :: any
end