diff --git a/src/lune/builtins/fs/mod.rs b/src/lune/builtins/fs/mod.rs index 0d1fd8f..f1182e0 100644 --- a/src/lune/builtins/fs/mod.rs +++ b/src/lune/builtins/fs/mod.rs @@ -4,6 +4,7 @@ use std::path::{PathBuf, MAIN_SEPARATOR}; use mlua::prelude::*; use tokio::fs; +use crate::lune::util::buffer::{buf_to_str, create_lua_buffer}; use crate::lune::util::TableBuilder; mod copy; @@ -14,19 +15,6 @@ use copy::copy; use metadata::FsMetadata; use options::FsWriteOptions; -const BYTES_TO_BUF_IMPL: &str = r#" - local tbl = select(1, ...) - local buf = buffer.create(#tbl * 4) -- Each u32 is 4 bytes - - for offset, byte in tbl do - buffer.writeu32(buf, offset, byte) - end - - return buf -"#; - -const BUF_TO_STR_IMPL: &str = "return buffer.tostring(select(1, ...))"; - pub fn create(lua: &'static Lua) -> LuaResult { TableBuilder::new(lua)? .with_async_function("readFile", fs_read_file)? @@ -43,20 +31,6 @@ pub fn create(lua: &'static Lua) -> LuaResult { .build_readonly() } -fn create_lua_buffer(lua: &Lua, bytes: impl AsRef<[u8]>) -> LuaResult { - let lua_bytes = bytes.as_ref().into_lua(lua)?; - - let buf_constructor = lua.load(BYTES_TO_BUF_IMPL).into_function()?; - - buf_constructor.call::<_, LuaValue>(lua_bytes) -} - -fn buf_to_str(lua: &Lua, buf: LuaValue<'_>) -> LuaResult { - let str_constructor = lua.load(BUF_TO_STR_IMPL).into_function()?; - - str_constructor.call(buf) -} - async fn fs_read_file(lua: &Lua, path: String) -> LuaResult { let bytes = fs::read(&path).await.into_lua_err()?; diff --git a/src/lune/builtins/net/config.rs b/src/lune/builtins/net/config.rs index 030288e..6a3a0d9 100644 --- a/src/lune/builtins/net/config.rs +++ b/src/lune/builtins/net/config.rs @@ -4,6 +4,8 @@ use mlua::prelude::*; use reqwest::Method; +use crate::lune::util::buffer::buf_to_str; + use super::util::table_to_hash_map; // Net request config @@ -92,10 +94,16 @@ impl FromLua<'_> for RequestConfig { Err(_) => HashMap::new(), }; // Extract body - let body = match tab.raw_get::<_, LuaString>("body") { - Ok(config_body) => Some(config_body.as_bytes().to_owned()), - Err(_) => None, + let body = match tab.raw_get::<_, LuaValue>("body")? { + LuaValue::String(str) => Some(str.as_bytes().to_vec()), + LuaValue::UserData(inner) => Some( + buf_to_str(lua, LuaValue::UserData(inner))? + .as_bytes() + .to_vec(), + ), + _ => None, }; + // Convert method string into proper enum let method = method.trim().to_ascii_uppercase(); let method = match method.as_ref() { diff --git a/src/lune/builtins/net/mod.rs b/src/lune/builtins/net/mod.rs index 78c3397..f09bd1e 100644 --- a/src/lune/builtins/net/mod.rs +++ b/src/lune/builtins/net/mod.rs @@ -4,7 +4,10 @@ use mlua::prelude::*; use hyper::header::CONTENT_ENCODING; -use crate::lune::{scheduler::Scheduler, util::TableBuilder}; +use crate::lune::{ + scheduler::Scheduler, + util::{buffer::create_lua_buffer, TableBuilder}, +}; use self::{ server::{bind_to_addr, create_server}, @@ -120,7 +123,7 @@ where .with_value("statusCode", res_status)? .with_value("statusMessage", res_status_text)? .with_value("headers", res_headers_lua)? - .with_value("body", lua.create_string(&res_bytes)?)? + .with_value("body", create_lua_buffer(lua, &res_bytes)?)? .build_readonly() } diff --git a/src/lune/builtins/net/websocket.rs b/src/lune/builtins/net/websocket.rs index aa01c13..5678218 100644 --- a/src/lune/builtins/net/websocket.rs +++ b/src/lune/builtins/net/websocket.rs @@ -22,7 +22,7 @@ use hyper_tungstenite::{ }; use tokio_tungstenite::MaybeTlsStream; -use crate::lune::util::TableBuilder; +use crate::lune::util::{buffer::buf_to_str, TableBuilder}; const WEB_SOCKET_IMPL_LUA: &str = r#" return freeze(setmetatable({ @@ -178,20 +178,29 @@ where } async fn send<'lua, T>( - _lua: &'lua Lua, - (socket, string, as_binary): ( + lua: &'lua Lua, + (socket, data, as_binary): ( LuaUserDataRef<'lua, NetWebSocket>, - LuaString<'lua>, + LuaValue<'lua>, Option, ), ) -> LuaResult<()> where T: AsyncRead + AsyncWrite + Unpin, { + let string = match data { + LuaValue::String(str) => Ok(str.to_str()?.to_string()), + LuaValue::UserData(inner) => buf_to_str(lua, LuaValue::UserData(inner)), + other => Err(LuaError::runtime(format!( + "Expected data to be of type string or buffer, got {}", + other.type_name() + ))), + }?; + let msg = if matches!(as_binary, Some(true)) { WsMessage::Binary(string.as_bytes().to_vec()) } else { - let s = string.to_str().into_lua_err()?; + let s = string; WsMessage::Text(s.to_string()) }; let mut ws = socket.write_stream.lock().await; diff --git a/src/lune/util/buffer.rs b/src/lune/util/buffer.rs new file mode 100644 index 0000000..363eaaf --- /dev/null +++ b/src/lune/util/buffer.rs @@ -0,0 +1,28 @@ +use mlua::{IntoLua, Lua, Result, Value}; + +const BYTES_TO_BUF_IMPL: &str = r#" + local tbl = select(1, ...) + local buf = buffer.create(#tbl * 4) -- Each u32 is 4 bytes + + for offset, byte in tbl do + buffer.writeu32(buf, offset, byte) + end + + return buf +"#; + +const BUF_TO_STR_IMPL: &str = "return buffer.tostring(select(1, ...))"; + +pub fn create_lua_buffer(lua: &Lua, bytes: impl AsRef<[u8]>) -> Result { + let lua_bytes = bytes.as_ref().into_lua(lua)?; + + let buf_constructor = lua.load(BYTES_TO_BUF_IMPL).into_function()?; + + buf_constructor.call::<_, Value>(lua_bytes) +} + +pub fn buf_to_str(lua: &Lua, buf: Value<'_>) -> Result { + let str_constructor = lua.load(BUF_TO_STR_IMPL).into_function()?; + + str_constructor.call(buf) +} diff --git a/src/lune/util/mod.rs b/src/lune/util/mod.rs index 45e7512..1066006 100644 --- a/src/lune/util/mod.rs +++ b/src/lune/util/mod.rs @@ -1,5 +1,6 @@ mod table_builder; +pub mod buffer; pub mod formatting; pub mod futures; pub mod luaurc;