Add support for buffers as arguments in builtin APIs (#148)

This commit is contained in:
Erica Marigold 2024-04-20 20:14:19 +05:30 committed by GitHub
parent 7fb48dfa1f
commit f830ce7fad
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 78 additions and 71 deletions

2
Cargo.lock generated
View file

@ -346,6 +346,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05efc5cfd9110c8416e471df0e96702d58690178e206e61b7173706673c93706" checksum = "05efc5cfd9110c8416e471df0e96702d58690178e206e61b7173706673c93706"
dependencies = [ dependencies = [
"memchr", "memchr",
"regex-automata 0.4.6",
"serde", "serde",
] ]
@ -1367,6 +1368,7 @@ dependencies = [
"async-trait", "async-trait",
"async_zip", "async_zip",
"blocking", "blocking",
"bstr",
"chrono", "chrono",
"chrono_lc", "chrono_lc",
"clap", "clap",

View file

@ -77,6 +77,7 @@ path-clean = "1.0"
pathdiff = "0.2" pathdiff = "0.2"
pin-project = "1.0" pin-project = "1.0"
urlencoding = "2.1" urlencoding = "2.1"
bstr = "1.9.1"
### RUNTIME ### RUNTIME
@ -87,7 +88,7 @@ tokio = { version = "1.24", features = ["full", "tracing"] }
os_str_bytes = { version = "7.0", features = ["conversions"] } os_str_bytes = { version = "7.0", features = ["conversions"] }
mlua-luau-scheduler = { version = "0.0.2" } mlua-luau-scheduler = { version = "0.0.2" }
mlua = { version = "0.9.6", features = [ mlua = { version = "0.9.7", features = [
"luau", "luau",
"luau-jit", "luau-jit",
"async", "async",

View file

@ -1,6 +1,7 @@
use std::io::ErrorKind as IoErrorKind; use std::io::ErrorKind as IoErrorKind;
use std::path::{PathBuf, MAIN_SEPARATOR}; use std::path::{PathBuf, MAIN_SEPARATOR};
use bstr::{BString, ByteSlice};
use mlua::prelude::*; use mlua::prelude::*;
use tokio::fs; use tokio::fs;
@ -32,6 +33,7 @@ pub fn create(lua: &Lua) -> LuaResult<LuaTable> {
async fn fs_read_file(lua: &Lua, path: String) -> LuaResult<LuaString> { async fn fs_read_file(lua: &Lua, path: String) -> LuaResult<LuaString> {
let bytes = fs::read(&path).await.into_lua_err()?; let bytes = fs::read(&path).await.into_lua_err()?;
lua.create_string(bytes) lua.create_string(bytes)
} }
@ -64,8 +66,8 @@ async fn fs_read_dir(_: &Lua, path: String) -> LuaResult<Vec<String>> {
Ok(dir_strings_no_prefix) Ok(dir_strings_no_prefix)
} }
async fn fs_write_file(_: &Lua, (path, contents): (String, LuaString<'_>)) -> LuaResult<()> { async fn fs_write_file(_: &Lua, (path, contents): (String, BString)) -> LuaResult<()> {
fs::write(&path, &contents.as_bytes()).await.into_lua_err() fs::write(&path, contents.as_bytes()).await.into_lua_err()
} }
async fn fs_write_dir(_: &Lua, path: String) -> LuaResult<()> { async fn fs_write_dir(_: &Lua, path: String) -> LuaResult<()> {

View file

@ -3,6 +3,7 @@ use std::{
net::{IpAddr, Ipv4Addr}, net::{IpAddr, Ipv4Addr},
}; };
use bstr::{BString, ByteSlice};
use mlua::prelude::*; use mlua::prelude::*;
use reqwest::Method; use reqwest::Method;
@ -107,10 +108,11 @@ impl FromLua<'_> for RequestConfig {
Err(_) => HashMap::new(), Err(_) => HashMap::new(),
}; };
// Extract body // Extract body
let body = match tab.get::<_, LuaString>("body") { let body = match tab.get::<_, BString>("body") {
Ok(config_body) => Some(config_body.as_bytes().to_owned()), Ok(config_body) => Some(config_body.as_bytes().to_owned()),
Err(_) => None, Err(_) => None,
}; };
// Convert method string into proper enum // Convert method string into proper enum
let method = method.trim().to_ascii_uppercase(); let method = method.trim().to_ascii_uppercase();
let method = match method.as_ref() { let method = match method.as_ref() {

View file

@ -1,5 +1,6 @@
#![allow(unused_variables)] #![allow(unused_variables)]
use bstr::BString;
use mlua::prelude::*; use mlua::prelude::*;
use mlua_luau_scheduler::LuaSpawnExt; use mlua_luau_scheduler::LuaSpawnExt;
@ -45,7 +46,7 @@ fn net_json_encode<'lua>(
.serialize_to_string(lua, val) .serialize_to_string(lua, val)
} }
fn net_json_decode<'lua>(lua: &'lua Lua, json: LuaString<'lua>) -> LuaResult<LuaValue<'lua>> { fn net_json_decode<'lua>(lua: &'lua Lua, json: BString) -> LuaResult<LuaValue<'lua>> {
EncodeDecodeConfig::from(EncodeDecodeFormat::Json).deserialize_from_string(lua, json) EncodeDecodeConfig::from(EncodeDecodeFormat::Json).deserialize_from_string(lua, json)
} }

View file

@ -1,5 +1,6 @@
use std::str::FromStr; use std::str::FromStr;
use bstr::{BString, ByteSlice};
use http_body_util::Full; use http_body_util::Full;
use hyper::{ use hyper::{
body::Bytes, body::Bytes,
@ -56,7 +57,7 @@ impl FromLua<'_> for LuaResponse {
LuaValue::Table(t) => { LuaValue::Table(t) => {
let status: Option<u16> = t.get("status")?; let status: Option<u16> = t.get("status")?;
let headers: Option<LuaTable> = t.get("headers")?; let headers: Option<LuaTable> = t.get("headers")?;
let body: Option<LuaString> = t.get("body")?; let body: Option<BString> = t.get("body")?;
let mut headers_map = HeaderMap::new(); let mut headers_map = HeaderMap::new();
if let Some(headers) = headers { if let Some(headers) = headers {

View file

@ -3,6 +3,7 @@ use std::sync::{
Arc, Arc,
}; };
use bstr::{BString, ByteSlice};
use mlua::prelude::*; use mlua::prelude::*;
use futures_util::{ use futures_util::{
@ -160,7 +161,7 @@ where
methods.add_async_method( methods.add_async_method(
"send", "send",
|_, this, (string, as_binary): (LuaString, Option<bool>)| async move { |_, this, (string, as_binary): (BString, Option<bool>)| async move {
this.send(if as_binary.unwrap_or_default() { this.send(if as_binary.unwrap_or_default() {
WsMessage::Binary(string.as_bytes().to_vec()) WsMessage::Binary(string.as_bytes().to_vec())
} else { } else {

View file

@ -1,3 +1,4 @@
use bstr::{BString, ByteSlice};
use mlua::prelude::*; use mlua::prelude::*;
use serde_json::Value as JsonValue; use serde_json::Value as JsonValue;
@ -89,7 +90,7 @@ impl EncodeDecodeConfig {
pub fn deserialize_from_string<'lua>( pub fn deserialize_from_string<'lua>(
self, self,
lua: &'lua Lua, lua: &'lua Lua,
string: LuaString<'lua>, string: BString,
) -> LuaResult<LuaValue<'lua>> { ) -> LuaResult<LuaValue<'lua>> {
let bytes = string.as_bytes(); let bytes = string.as_bytes();
match self.format { match self.format {

View file

@ -1,3 +1,4 @@
use bstr::BString;
use mlua::prelude::*; use mlua::prelude::*;
pub(super) mod compress_decompress; pub(super) mod compress_decompress;
@ -27,7 +28,7 @@ fn serde_encode<'lua>(
fn serde_decode<'lua>( fn serde_decode<'lua>(
lua: &'lua Lua, lua: &'lua Lua,
(format, str): (EncodeDecodeFormat, LuaString<'lua>), (format, str): (EncodeDecodeFormat, BString),
) -> LuaResult<LuaValue<'lua>> { ) -> LuaResult<LuaValue<'lua>> {
let config = EncodeDecodeConfig::from(format); let config = EncodeDecodeConfig::from(format);
config.deserialize_from_string(lua, str) config.deserialize_from_string(lua, str)
@ -35,7 +36,7 @@ fn serde_decode<'lua>(
async fn serde_compress<'lua>( async fn serde_compress<'lua>(
lua: &'lua Lua, lua: &'lua Lua,
(format, str): (CompressDecompressFormat, LuaString<'lua>), (format, str): (CompressDecompressFormat, BString),
) -> LuaResult<LuaString<'lua>> { ) -> LuaResult<LuaString<'lua>> {
let bytes = compress(format, str).await?; let bytes = compress(format, str).await?;
lua.create_string(bytes) lua.create_string(bytes)
@ -43,7 +44,7 @@ async fn serde_compress<'lua>(
async fn serde_decompress<'lua>( async fn serde_decompress<'lua>(
lua: &'lua Lua, lua: &'lua Lua,
(format, str): (CompressDecompressFormat, LuaString<'lua>), (format, str): (CompressDecompressFormat, BString),
) -> LuaResult<LuaString<'lua>> { ) -> LuaResult<LuaString<'lua>> {
let bytes = decompress(format, str).await?; let bytes = decompress(format, str).await?;
lua.create_string(bytes) lua.create_string(bytes)

View file

@ -50,15 +50,15 @@ assert(fs.isFile(TEMP_ROOT_PATH_2 .. "/foo/buzz"), "Missing copied file - root/f
-- Make sure the copied files are correct -- Make sure the copied files are correct
assert( assert(
fs.readFile(TEMP_ROOT_PATH_2 .. "/foo/bar/baz") == utils.binaryBlob, fs.readFile(TEMP_ROOT_PATH_2 .. "/foo/bar/baz") == buffer.tostring(utils.binaryBlob),
"Invalid copied file - root/foo/bar/baz" "Invalid copied file - root/foo/bar/baz"
) )
assert( assert(
fs.readFile(TEMP_ROOT_PATH_2 .. "/foo/fizz") == utils.binaryBlob, fs.readFile(TEMP_ROOT_PATH_2 .. "/foo/fizz") == buffer.tostring(utils.binaryBlob),
"Invalid copied file - root/foo/fizz" "Invalid copied file - root/foo/fizz"
) )
assert( assert(
fs.readFile(TEMP_ROOT_PATH_2 .. "/foo/buzz") == utils.binaryBlob, fs.readFile(TEMP_ROOT_PATH_2 .. "/foo/buzz") == buffer.tostring(utils.binaryBlob),
"Invalid copied file - root/foo/buzz" "Invalid copied file - root/foo/buzz"
) )

View file

@ -11,6 +11,8 @@ fs.writeDir(TEMP_ROOT_PATH)
-- Write both of our files -- Write both of our files
-- binaryBlob is of type buffer to make sure fs.writeFile
-- works with both strings and buffers
fs.writeFile(TEMP_ROOT_PATH .. "/test_binary", utils.binaryBlob) fs.writeFile(TEMP_ROOT_PATH .. "/test_binary", utils.binaryBlob)
fs.writeFile(TEMP_ROOT_PATH .. "/test_json.json", utils.jsonBlob) fs.writeFile(TEMP_ROOT_PATH .. "/test_json.json", utils.jsonBlob)
@ -18,7 +20,7 @@ fs.writeFile(TEMP_ROOT_PATH .. "/test_json.json", utils.jsonBlob)
-- wrote gets us back the original strings -- wrote gets us back the original strings
assert( assert(
fs.readFile(TEMP_ROOT_PATH .. "/test_binary") == utils.binaryBlob, fs.readFile(TEMP_ROOT_PATH .. "/test_binary") == buffer.tostring(utils.binaryBlob),
"Binary file round-trip resulted in different strings" "Binary file round-trip resulted in different strings"
) )

View file

@ -45,7 +45,7 @@ assert(metaFile.kind == "file", "File metadata kind was invalid")
local metaBefore = fs.metadata(TEMP_FILE_PATH) local metaBefore = fs.metadata(TEMP_FILE_PATH)
task.wait(1) task.wait(1)
fs.writeFile(TEMP_FILE_PATH, utils.binaryBlob .. "\n") fs.writeFile(TEMP_FILE_PATH, buffer.tostring(utils.binaryBlob) .. "\n")
local metaAfter = fs.metadata(TEMP_FILE_PATH) local metaAfter = fs.metadata(TEMP_FILE_PATH)
assert( assert(

View file

@ -20,7 +20,7 @@ fs.move("bin/move_test_json.json", "bin/moved_test_json.json")
-- wrote gets us back the original strings -- wrote gets us back the original strings
assert( assert(
fs.readFile("bin/moved_test_binary") == utils.binaryBlob, fs.readFile("bin/moved_test_binary") == buffer.tostring(utils.binaryBlob),
"Binary file round-trip resulted in different strings" "Binary file round-trip resulted in different strings"
) )

View file

@ -16,6 +16,6 @@ local jsonBlob = serde.encode("json", {
-- Return testing data and utils -- Return testing data and utils
return { return {
binaryBlob = binaryBlob, binaryBlob = buffer.fromstring(binaryBlob),
jsonBlob = jsonBlob, jsonBlob = jsonBlob,
} }

View file

@ -22,7 +22,9 @@ end)
task.wait(1) task.wait(1)
socket.send('{"op":1,"d":null}') local payload = '{"op":1,"d":null}'
socket.send(payload)
socket.send(buffer.fromstring(payload))
socket.close(1000) socket.close(1000)
task.cancel(delayedThread) task.cancel(delayedThread)

@ -1 +1 @@
Subproject commit 52f2c1a686e7b67d996005eeddf63b97b170a741 Subproject commit 655b5cc6a64024709d3662cc45ec4319c87de5a2

View file

@ -33,69 +33,60 @@ local TESTS: { Test } = {
} }
local failed = false local failed = false
for _, test in TESTS do local function testOperation(
local source = fs.readFile(test.Source) operationName: "Compress" | "Decompress",
local target = fs.readFile(test.Target) operation: (
format: serde.CompressDecompressFormat,
local success, compressed = pcall(serde.compress, test.Format, source) s: buffer | string
) -> string,
format: serde.CompressDecompressFormat,
source: string | buffer,
target: string
)
local success, res = pcall(operation, format, source)
if not success then if not success then
stdio.ewrite( stdio.ewrite(
string.format( string.format(
"Compressing source using '%s' format threw an error!\n%s", "%sing source using '%s' format threw an error!\n%s",
tostring(test.Format), operationName,
tostring(compressed) tostring(format),
tostring(res)
) )
) )
failed = true failed = true
continue elseif res ~= target then
elseif compressed ~= target then
stdio.ewrite( stdio.ewrite(
string.format( string.format(
"Compressing source using '%s' format did not produce target!\n", "%sing source using '%s' format did not produce target!\n",
tostring(test.Format) operationName,
tostring(format)
) )
) )
stdio.ewrite( stdio.ewrite(
string.format( string.format(
"Compressed (%d chars long):\n%s\nTarget (%d chars long):\n%s\n\n", "%sed (%d chars long):\n%s\nTarget (%d chars long):\n%s\n\n",
#compressed, operationName,
tostring(compressed), #res,
tostring(res),
#target, #target,
tostring(target) tostring(target)
) )
) )
failed = true failed = true
continue
end end
end
local success2, decompressed = pcall(serde.decompress, test.Format, target) for _, test in TESTS do
if not success2 then local source = fs.readFile(test.Source)
stdio.ewrite( local target = fs.readFile(test.Target)
string.format(
"Decompressing source using '%s' format threw an error!\n%s", -- Compression
tostring(test.Format), testOperation("Compress", serde.compress, test.Format, source, target)
tostring(decompressed) testOperation("Compress", serde.compress, test.Format, buffer.fromstring(source), target)
)
) -- Decompression
failed = true testOperation("Decompress", serde.decompress, test.Format, target, source)
continue testOperation("Decompress", serde.decompress, test.Format, buffer.fromstring(target), source)
elseif decompressed ~= source then
stdio.ewrite(
string.format(
"Decompressing target using '%s' format did not produce source!\n",
tostring(test.Format)
)
)
stdio.ewrite(
string.format(
"Decompressed (%d chars long):\n%s\n\n",
#decompressed,
tostring(decompressed)
)
)
failed = true
continue
end
end end
if failed then if failed then

View file

@ -144,7 +144,7 @@ end
@param path The path of the file @param path The path of the file
@param contents The contents of the file @param contents The contents of the file
]=] ]=]
function fs.writeFile(path: string, contents: string) end function fs.writeFile(path: string, contents: buffer | string) end
--[=[ --[=[
@within FS @within FS

View file

@ -36,7 +36,7 @@ export type FetchParamsOptions = {
export type FetchParams = { export type FetchParams = {
url: string, url: string,
method: HttpMethod?, method: HttpMethod?,
body: string?, body: (string | buffer)?,
query: HttpQueryMap?, query: HttpQueryMap?,
headers: HttpHeaderMap?, headers: HttpHeaderMap?,
options: FetchParamsOptions?, options: FetchParamsOptions?,
@ -101,7 +101,7 @@ export type ServeRequest = {
export type ServeResponse = { export type ServeResponse = {
status: number?, status: number?,
headers: { [string]: string }?, headers: { [string]: string }?,
body: string?, body: (string | buffer)?,
} }
type ServeHttpHandler = (request: ServeRequest) -> string | ServeResponse type ServeHttpHandler = (request: ServeRequest) -> string | ServeResponse
@ -174,7 +174,7 @@ export type ServeHandle = {
export type WebSocket = { export type WebSocket = {
closeCode: number?, closeCode: number?,
close: (code: number?) -> (), close: (code: number?) -> (),
send: (message: string, asBinaryMessage: boolean?) -> (), send: (message: (string | buffer)?, asBinaryMessage: boolean?) -> (),
next: () -> string?, next: () -> string?,
} }

View file

@ -70,7 +70,7 @@ end
@param encoded The string to decode @param encoded The string to decode
@return The decoded lua value @return The decoded lua value
]=] ]=]
function serde.decode(format: EncodeDecodeFormat, encoded: string): any function serde.decode(format: EncodeDecodeFormat, encoded: buffer | string): any
return nil :: any return nil :: any
end end
@ -93,7 +93,7 @@ end
@param s The string to compress @param s The string to compress
@return The compressed string @return The compressed string
]=] ]=]
function serde.compress(format: CompressDecompressFormat, s: string): string function serde.compress(format: CompressDecompressFormat, s: buffer | string): string
return nil :: any return nil :: any
end end
@ -116,7 +116,7 @@ end
@param s The string to decompress @param s The string to decompress
@return The decompressed string @return The decompressed string
]=] ]=]
function serde.decompress(format: CompressDecompressFormat, s: string): string function serde.decompress(format: CompressDecompressFormat, s: buffer | string): string
return nil :: any return nil :: any
end end