diff --git a/Cargo.lock b/Cargo.lock index 5f6fd8c..eca852d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1610,8 +1610,21 @@ dependencies = [ name = "lune-std-net" version = "0.1.0" dependencies = [ + "bstr", + "futures-util", + "http 1.1.0", + "http-body-util", + "hyper 1.3.1", + "hyper-tungstenite", + "hyper-util", + "lune-std-serde", "lune-utils", "mlua", + "mlua-luau-scheduler 0.0.1", + "reqwest", + "tokio", + "tokio-tungstenite", + "urlencoding", ] [[package]] diff --git a/crates/lune-std-net/Cargo.toml b/crates/lune-std-net/Cargo.toml index 4e28dbf..ded2a71 100644 --- a/crates/lune-std-net/Cargo.toml +++ b/crates/lune-std-net/Cargo.toml @@ -12,5 +12,22 @@ workspace = true [dependencies] mlua = { version = "0.9.7", features = ["luau"] } +mlua-luau-scheduler = "0.0.1" + +bstr = "1.9" +futures-util = "0.3" +hyper = { version = "1.1", features = ["full"] } +hyper-util = { version = "0.1", features = ["full"] } +http = "1.0" +http-body-util = { version = "0.1" } +hyper-tungstenite = { version = "0.13" } +reqwest = { version = "0.11", default-features = false, features = [ + "rustls-tls", +] } +tokio-tungstenite = { version = "0.21", features = ["rustls-tls-webpki-roots"] } +urlencoding = "2.1" + +tokio = { version = "1", default-features = false, features = ["sync", "net"] } lune-utils = { version = "0.1.0", path = "../lune-utils" } +lune-std-serde = { version = "0.1.0", path = "../lune-std-serde" } diff --git a/crates/lune-std-net/src/client.rs b/crates/lune-std-net/src/client.rs new file mode 100644 index 0000000..cae56bf --- /dev/null +++ b/crates/lune-std-net/src/client.rs @@ -0,0 +1,163 @@ +use std::str::FromStr; + +use mlua::prelude::*; + +use reqwest::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_ENCODING}; + +use lune_std_serde::{decompress, CompressDecompressFormat}; +use lune_utils::TableBuilder; + +use super::{config::RequestConfig, util::header_map_to_table}; + +const REGISTRY_KEY: &str = "NetClient"; + +pub struct NetClientBuilder { + builder: reqwest::ClientBuilder, +} + +impl NetClientBuilder { + pub fn new() -> NetClientBuilder { + Self { + builder: reqwest::ClientBuilder::new(), + } + } + + pub fn headers(mut self, headers: &[(K, V)]) -> LuaResult + where + K: AsRef, + V: AsRef<[u8]>, + { + let mut map = HeaderMap::new(); + for (key, val) in headers { + let hkey = HeaderName::from_str(key.as_ref()).into_lua_err()?; + let hval = HeaderValue::from_bytes(val.as_ref()).into_lua_err()?; + map.insert(hkey, hval); + } + self.builder = self.builder.default_headers(map); + Ok(self) + } + + pub fn build(self) -> LuaResult { + let client = self.builder.build().into_lua_err()?; + Ok(NetClient { inner: client }) + } +} + +#[derive(Debug, Clone)] +pub struct NetClient { + inner: reqwest::Client, +} + +impl NetClient { + pub fn from_registry(lua: &Lua) -> Self { + lua.named_registry_value(REGISTRY_KEY) + .expect("Failed to get NetClient from lua registry") + } + + pub fn into_registry(self, lua: &Lua) { + lua.set_named_registry_value(REGISTRY_KEY, self) + .expect("Failed to store NetClient in lua registry"); + } + + pub async fn request(&self, config: RequestConfig) -> LuaResult { + // Create and send the request + let mut request = self.inner.request(config.method, config.url); + for (query, values) in config.query { + request = request.query( + &values + .iter() + .map(|v| (query.as_str(), v)) + .collect::>(), + ); + } + for (header, values) in config.headers { + for value in values { + request = request.header(header.as_str(), value); + } + } + let res = request + .body(config.body.unwrap_or_default()) + .send() + .await + .into_lua_err()?; + + // Extract status, headers + let res_status = res.status().as_u16(); + let res_status_text = res.status().canonical_reason(); + let res_headers = res.headers().clone(); + + // Read response bytes + let mut res_bytes = res.bytes().await.into_lua_err()?.to_vec(); + let mut res_decompressed = false; + + // Check for extra options, decompression + if config.options.decompress { + let decompress_format = res_headers + .iter() + .find(|(name, _)| { + name.as_str() + .eq_ignore_ascii_case(CONTENT_ENCODING.as_str()) + }) + .and_then(|(_, value)| value.to_str().ok()) + .and_then(CompressDecompressFormat::detect_from_header_str); + if let Some(format) = decompress_format { + res_bytes = decompress(res_bytes, format).await?; + res_decompressed = true; + } + } + + Ok(NetClientResponse { + ok: (200..300).contains(&res_status), + status_code: res_status, + status_message: res_status_text.unwrap_or_default().to_string(), + headers: res_headers, + body: res_bytes, + body_decompressed: res_decompressed, + }) + } +} + +impl LuaUserData for NetClient {} + +impl FromLua<'_> for NetClient { + fn from_lua(value: LuaValue, _: &Lua) -> LuaResult { + if let LuaValue::UserData(ud) = value { + if let Ok(ctx) = ud.borrow::() { + return Ok(ctx.clone()); + } + } + unreachable!("NetClient should only be used from registry") + } +} + +impl From<&Lua> for NetClient { + fn from(value: &Lua) -> Self { + value + .named_registry_value(REGISTRY_KEY) + .expect("Missing require context in lua registry") + } +} + +pub struct NetClientResponse { + ok: bool, + status_code: u16, + status_message: String, + headers: HeaderMap, + body: Vec, + body_decompressed: bool, +} + +impl NetClientResponse { + pub fn into_lua_table(self, lua: &Lua) -> LuaResult { + TableBuilder::new(lua)? + .with_value("ok", self.ok)? + .with_value("statusCode", self.status_code)? + .with_value("statusMessage", self.status_message)? + .with_value( + "headers", + header_map_to_table(lua, self.headers, self.body_decompressed)?, + )? + .with_value("body", lua.create_string(&self.body)?)? + .build_readonly() + } +} diff --git a/crates/lune-std-net/src/config.rs b/crates/lune-std-net/src/config.rs new file mode 100644 index 0000000..6368d6d --- /dev/null +++ b/crates/lune-std-net/src/config.rs @@ -0,0 +1,231 @@ +use std::{ + collections::HashMap, + net::{IpAddr, Ipv4Addr}, +}; + +use bstr::{BString, ByteSlice}; +use mlua::prelude::*; + +use reqwest::Method; + +use super::util::table_to_hash_map; + +const DEFAULT_IP_ADDRESS: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); + +const WEB_SOCKET_UPDGRADE_REQUEST_HANDLER: &str = r#" +return { + status = 426, + body = "Upgrade Required", + headers = { + Upgrade = "websocket", + }, +} +"#; + +// Net request config + +#[derive(Debug, Clone)] +pub struct RequestConfigOptions { + pub decompress: bool, +} + +impl Default for RequestConfigOptions { + fn default() -> Self { + Self { decompress: true } + } +} + +impl<'lua> FromLua<'lua> for RequestConfigOptions { + fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult { + if let LuaValue::Nil = value { + // Nil means default options + Ok(Self::default()) + } else if let LuaValue::Table(tab) = value { + // Table means custom options + let decompress = match tab.get::<_, Option>("decompress") { + Ok(decomp) => Ok(decomp.unwrap_or(true)), + Err(_) => Err(LuaError::RuntimeError( + "Invalid option value for 'decompress' in request config options".to_string(), + )), + }?; + Ok(Self { decompress }) + } else { + // Anything else is invalid + Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "RequestConfigOptions", + message: Some(format!( + "Invalid request config options - expected table or nil, got {}", + value.type_name() + )), + }) + } + } +} + +#[derive(Debug, Clone)] +pub struct RequestConfig { + pub url: String, + pub method: Method, + pub query: HashMap>, + pub headers: HashMap>, + pub body: Option>, + pub options: RequestConfigOptions, +} + +impl FromLua<'_> for RequestConfig { + fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult { + // If we just got a string we assume its a GET request to a given url + if let LuaValue::String(s) = value { + Ok(Self { + url: s.to_string_lossy().to_string(), + method: Method::GET, + query: HashMap::new(), + headers: HashMap::new(), + body: None, + options: RequestConfigOptions::default(), + }) + } else if let LuaValue::Table(tab) = value { + // If we got a table we are able to configure the entire request + // Extract url + let url = match tab.get::<_, LuaString>("url") { + Ok(config_url) => Ok(config_url.to_string_lossy().to_string()), + Err(_) => Err(LuaError::runtime("Missing 'url' in request config")), + }?; + // Extract method + let method = match tab.get::<_, LuaString>("method") { + Ok(config_method) => config_method.to_string_lossy().trim().to_ascii_uppercase(), + Err(_) => "GET".to_string(), + }; + // Extract query + let query = match tab.get::<_, LuaTable>("query") { + Ok(tab) => table_to_hash_map(tab, "query")?, + Err(_) => HashMap::new(), + }; + // Extract headers + let headers = match tab.get::<_, LuaTable>("headers") { + Ok(tab) => table_to_hash_map(tab, "headers")?, + Err(_) => HashMap::new(), + }; + // Extract body + let body = match tab.get::<_, BString>("body") { + Ok(config_body) => Some(config_body.as_bytes().to_owned()), + Err(_) => None, + }; + + // Convert method string into proper enum + let method = method.trim().to_ascii_uppercase(); + let method = match method.as_ref() { + "GET" => Ok(Method::GET), + "POST" => Ok(Method::POST), + "PUT" => Ok(Method::PUT), + "DELETE" => Ok(Method::DELETE), + "HEAD" => Ok(Method::HEAD), + "OPTIONS" => Ok(Method::OPTIONS), + "PATCH" => Ok(Method::PATCH), + _ => Err(LuaError::RuntimeError(format!( + "Invalid request config method '{}'", + &method + ))), + }?; + // Parse any extra options given + let options = match tab.get::<_, LuaValue>("options") { + Ok(opts) => RequestConfigOptions::from_lua(opts, lua)?, + Err(_) => RequestConfigOptions::default(), + }; + // All good, validated and we got what we need + Ok(Self { + url, + method, + query, + headers, + body, + options, + }) + } else { + // Anything else is invalid + Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "RequestConfig", + message: Some(format!( + "Invalid request config - expected string or table, got {}", + value.type_name() + )), + }) + } + } +} + +// Net serve config + +#[derive(Debug)] +pub struct ServeConfig<'a> { + pub address: IpAddr, + pub handle_request: LuaFunction<'a>, + pub handle_web_socket: Option>, +} + +impl<'lua> FromLua<'lua> for ServeConfig<'lua> { + fn from_lua(value: LuaValue<'lua>, lua: &'lua Lua) -> LuaResult { + if let LuaValue::Function(f) = &value { + // Single function = request handler, rest is default + Ok(ServeConfig { + handle_request: f.clone(), + handle_web_socket: None, + address: DEFAULT_IP_ADDRESS, + }) + } else if let LuaValue::Table(t) = &value { + // Table means custom options + let address: Option = t.get("address")?; + let handle_request: Option = t.get("handleRequest")?; + let handle_web_socket: Option = t.get("handleWebSocket")?; + if handle_request.is_some() || handle_web_socket.is_some() { + let address: IpAddr = match &address { + Some(addr) => { + let addr_str = addr.to_str()?; + + addr_str + .trim_start_matches("http://") + .trim_start_matches("https://") + .parse() + .map_err(|_e| LuaError::FromLuaConversionError { + from: value.type_name(), + to: "ServeConfig", + message: Some(format!( + "IP address format is incorrect - \ + expected an IP in the form 'http://0.0.0.0' or '0.0.0.0', \ + got '{addr_str}'" + )), + })? + } + None => DEFAULT_IP_ADDRESS, + }; + + Ok(Self { + address, + handle_request: handle_request.unwrap_or_else(|| { + lua.load(WEB_SOCKET_UPDGRADE_REQUEST_HANDLER) + .into_function() + .expect("Failed to create default http responder function") + }), + handle_web_socket, + }) + } else { + Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "ServeConfig", + message: Some(String::from( + "Invalid serve config - expected table with 'handleRequest' or 'handleWebSocket' function", + )), + }) + } + } else { + // Anything else is invalid + Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "ServeConfig", + message: None, + }) + } + } +} diff --git a/crates/lune-std-net/src/lib.rs b/crates/lune-std-net/src/lib.rs index 24b8917..3f42889 100644 --- a/crates/lune-std-net/src/lib.rs +++ b/crates/lune-std-net/src/lib.rs @@ -1,9 +1,27 @@ #![allow(clippy::cargo_common_metadata)] +use bstr::BString; use mlua::prelude::*; +use mlua_luau_scheduler::LuaSpawnExt; + +mod client; +mod config; +mod server; +mod util; +mod websocket; use lune_utils::TableBuilder; +use self::{ + client::{NetClient, NetClientBuilder}, + config::{RequestConfig, ServeConfig}, + server::serve, + util::create_user_agent_header, + websocket::NetWebSocket, +}; + +use lune_std_serde::{decode, encode, EncodeDecodeConfig, EncodeDecodeFormat}; + /** Creates the `net` standard library module. @@ -12,5 +30,73 @@ use lune_utils::TableBuilder; Errors when out of memory. */ pub fn module(lua: &Lua) -> LuaResult { - TableBuilder::new(lua)?.build_readonly() + NetClientBuilder::new() + .headers(&[("User-Agent", create_user_agent_header(lua)?)])? + .build()? + .into_registry(lua); + TableBuilder::new(lua)? + .with_function("jsonEncode", net_json_encode)? + .with_function("jsonDecode", net_json_decode)? + .with_async_function("request", net_request)? + .with_async_function("socket", net_socket)? + .with_async_function("serve", net_serve)? + .with_function("urlEncode", net_url_encode)? + .with_function("urlDecode", net_url_decode)? + .build_readonly() +} + +fn net_json_encode<'lua>( + lua: &'lua Lua, + (val, pretty): (LuaValue<'lua>, Option), +) -> LuaResult> { + let config = EncodeDecodeConfig::from((EncodeDecodeFormat::Json, pretty.unwrap_or_default())); + encode(val, lua, config) +} + +fn net_json_decode(lua: &Lua, json: BString) -> LuaResult { + let config = EncodeDecodeConfig::from(EncodeDecodeFormat::Json); + decode(json, lua, config) +} + +async fn net_request(lua: &Lua, config: RequestConfig) -> LuaResult { + let client = NetClient::from_registry(lua); + // NOTE: We spawn the request as a background task to free up resources in lua + let res = lua.spawn(async move { client.request(config).await }); + res.await?.into_lua_table(lua) +} + +async fn net_socket(lua: &Lua, url: String) -> LuaResult { + let (ws, _) = tokio_tungstenite::connect_async(url).await.into_lua_err()?; + NetWebSocket::new(ws).into_lua_table(lua) +} + +async fn net_serve<'lua>( + lua: &'lua Lua, + (port, config): (u16, ServeConfig<'lua>), +) -> LuaResult> { + serve(lua, port, config).await +} + +fn net_url_encode<'lua>( + lua: &'lua Lua, + (lua_string, as_binary): (LuaString<'lua>, Option), +) -> LuaResult> { + if matches!(as_binary, Some(true)) { + urlencoding::encode_binary(lua_string.as_bytes()).into_lua(lua) + } else { + urlencoding::encode(lua_string.to_str()?).into_lua(lua) + } +} + +fn net_url_decode<'lua>( + lua: &'lua Lua, + (lua_string, as_binary): (LuaString<'lua>, Option), +) -> LuaResult> { + if matches!(as_binary, Some(true)) { + urlencoding::decode_binary(lua_string.as_bytes()).into_lua(lua) + } else { + urlencoding::decode(lua_string.to_str()?) + .map_err(|e| LuaError::RuntimeError(format!("Encountered invalid encoding - {e}")))? + .into_lua(lua) + } } diff --git a/crates/lune-std-net/src/server/keys.rs b/crates/lune-std-net/src/server/keys.rs new file mode 100644 index 0000000..9dac06a --- /dev/null +++ b/crates/lune-std-net/src/server/keys.rs @@ -0,0 +1,61 @@ +use std::sync::atomic::{AtomicUsize, Ordering}; + +use mlua::prelude::*; + +#[derive(Debug, Clone, Copy)] +pub(super) struct SvcKeys { + key_request: &'static str, + key_websocket: Option<&'static str>, +} + +impl SvcKeys { + pub(super) fn new<'lua>( + lua: &'lua Lua, + handle_request: LuaFunction<'lua>, + handle_websocket: Option>, + ) -> LuaResult { + static SERVE_COUNTER: AtomicUsize = AtomicUsize::new(0); + let count = SERVE_COUNTER.fetch_add(1, Ordering::Relaxed); + + // NOTE: We leak strings here, but this is an acceptable tradeoff since programs + // generally only start one or a couple of servers and they are usually never dropped. + // Leaking here lets us keep this struct Copy and access the request handler callbacks + // very performantly, significantly reducing the per-request overhead of the server. + let key_request: &'static str = + Box::leak(format!("__net_serve_request_{count}").into_boxed_str()); + let key_websocket: Option<&'static str> = if handle_websocket.is_some() { + Some(Box::leak( + format!("__net_serve_websocket_{count}").into_boxed_str(), + )) + } else { + None + }; + + lua.set_named_registry_value(key_request, handle_request)?; + if let Some(key) = key_websocket { + lua.set_named_registry_value(key, handle_websocket.unwrap())?; + } + + Ok(Self { + key_request, + key_websocket, + }) + } + + pub(super) fn has_websocket_handler(&self) -> bool { + self.key_websocket.is_some() + } + + pub(super) fn request_handler<'lua>(&self, lua: &'lua Lua) -> LuaResult> { + lua.named_registry_value(self.key_request) + } + + pub(super) fn websocket_handler<'lua>( + &self, + lua: &'lua Lua, + ) -> LuaResult>> { + self.key_websocket + .map(|key| lua.named_registry_value(key)) + .transpose() + } +} diff --git a/crates/lune-std-net/src/server/mod.rs b/crates/lune-std-net/src/server/mod.rs new file mode 100644 index 0000000..7cfab9d --- /dev/null +++ b/crates/lune-std-net/src/server/mod.rs @@ -0,0 +1,105 @@ +use std::{ + net::SocketAddr, + rc::{Rc, Weak}, +}; + +use hyper::server::conn::http1; +use hyper_util::rt::TokioIo; +use tokio::{net::TcpListener, pin}; + +use mlua::prelude::*; +use mlua_luau_scheduler::LuaSpawnExt; + +use lune_utils::TableBuilder; + +use super::config::ServeConfig; + +mod keys; +mod request; +mod response; +mod service; + +use keys::SvcKeys; +use service::Svc; + +pub async fn serve<'lua>( + lua: &'lua Lua, + port: u16, + config: ServeConfig<'lua>, +) -> LuaResult> { + let addr: SocketAddr = (config.address, port).into(); + let listener = TcpListener::bind(addr).await?; + + let (lua_svc, lua_inner) = { + let rc = lua + .app_data_ref::>() + .expect("Missing weak lua ref") + .upgrade() + .expect("Lua was dropped unexpectedly"); + (Rc::clone(&rc), rc) + }; + + let keys = SvcKeys::new(lua, config.handle_request, config.handle_web_socket)?; + let svc = Svc { + lua: lua_svc, + addr, + keys, + }; + + let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false); + lua.spawn_local(async move { + let mut shutdown_rx_outer = shutdown_rx.clone(); + loop { + // Create futures for accepting new connections and shutting down + let fut_shutdown = shutdown_rx_outer.changed(); + let fut_accept = async { + let stream = match listener.accept().await { + Err(_) => return, + Ok((s, _)) => s, + }; + + let io = TokioIo::new(stream); + let svc = svc.clone(); + let mut shutdown_rx_inner = shutdown_rx.clone(); + + lua_inner.spawn_local(async move { + let conn = http1::Builder::new() + .keep_alive(true) // Web sockets need this + .serve_connection(io, svc) + .with_upgrades(); + // NOTE: Because we need to use keep_alive for websockets, we need to + // also manually poll this future and handle the shutdown signal here + pin!(conn); + tokio::select! { + _ = conn.as_mut() => {} + _ = shutdown_rx_inner.changed() => { + conn.as_mut().graceful_shutdown(); + } + } + }); + }; + + // Wait for either a new connection or a shutdown signal + tokio::select! { + () = fut_accept => {} + res = fut_shutdown => { + // NOTE: We will only get a RecvError here if the serve handle is dropped, + // this means lua has garbage collected it and the user does not want + // to manually stop the server using the serve handle. Run forever. + if res.is_ok() { + break; + } + } + } + } + }); + + TableBuilder::new(lua)? + .with_value("ip", addr.ip().to_string())? + .with_value("port", addr.port())? + .with_function("stop", move |_, (): ()| match shutdown_tx.send(true) { + Ok(()) => Ok(()), + Err(_) => Err(LuaError::runtime("Server already stopped")), + })? + .build_readonly() +} diff --git a/crates/lune-std-net/src/server/request.rs b/crates/lune-std-net/src/server/request.rs new file mode 100644 index 0000000..f3de802 --- /dev/null +++ b/crates/lune-std-net/src/server/request.rs @@ -0,0 +1,54 @@ +use std::{collections::HashMap, net::SocketAddr}; + +use http::request::Parts; + +use mlua::prelude::*; + +use lune_utils::TableBuilder; + +pub(super) struct LuaRequest { + pub(super) _remote_addr: SocketAddr, + pub(super) head: Parts, + pub(super) body: Vec, +} + +impl LuaRequest { + pub fn into_lua_table(self, lua: &Lua) -> LuaResult { + let method = self.head.method.as_str().to_string(); + let path = self.head.uri.path().to_string(); + let body = lua.create_string(&self.body)?; + + let query: HashMap = self + .head + .uri + .query() + .unwrap_or_default() + .split('&') + .filter_map(|q| q.split_once('=')) + .map(|(k, v)| { + let k = lua.create_string(k)?; + let v = lua.create_string(v)?; + Ok((k, v)) + }) + .collect::>()?; + + let headers: HashMap = self + .head + .headers + .iter() + .map(|(k, v)| { + let k = lua.create_string(k.as_str())?; + let v = lua.create_string(v.as_bytes())?; + Ok((k, v)) + }) + .collect::>()?; + + TableBuilder::new(lua)? + .with_value("method", method)? + .with_value("path", path)? + .with_value("query", query)? + .with_value("headers", headers)? + .with_value("body", body)? + .build() + } +} diff --git a/crates/lune-std-net/src/server/response.rs b/crates/lune-std-net/src/server/response.rs new file mode 100644 index 0000000..240a7cd --- /dev/null +++ b/crates/lune-std-net/src/server/response.rs @@ -0,0 +1,89 @@ +use std::str::FromStr; + +use bstr::{BString, ByteSlice}; +use http_body_util::Full; +use hyper::{ + body::Bytes, + header::{HeaderName, HeaderValue}, + HeaderMap, Response, +}; + +use mlua::prelude::*; + +#[derive(Debug, Clone, Copy)] +pub(super) enum LuaResponseKind { + PlainText, + Table, +} + +pub(super) struct LuaResponse { + pub(super) kind: LuaResponseKind, + pub(super) status: u16, + pub(super) headers: HeaderMap, + pub(super) body: Option>, +} + +impl LuaResponse { + pub(super) fn into_response(self) -> LuaResult>> { + Ok(match self.kind { + LuaResponseKind::PlainText => Response::builder() + .status(200) + .header("Content-Type", "text/plain") + .body(Full::new(Bytes::from(self.body.unwrap()))) + .into_lua_err()?, + LuaResponseKind::Table => { + let mut response = Response::builder() + .status(self.status) + .body(Full::new(Bytes::from(self.body.unwrap_or_default()))) + .into_lua_err()?; + response.headers_mut().extend(self.headers); + response + } + }) + } +} + +impl FromLua<'_> for LuaResponse { + fn from_lua(value: LuaValue, _: &Lua) -> LuaResult { + match value { + // Plain strings from the handler are plaintext responses + LuaValue::String(s) => Ok(Self { + kind: LuaResponseKind::PlainText, + status: 200, + headers: HeaderMap::new(), + body: Some(s.as_bytes().to_vec()), + }), + // Tables are more detailed responses with potential status, headers, body + LuaValue::Table(t) => { + let status: Option = t.get("status")?; + let headers: Option = t.get("headers")?; + let body: Option = t.get("body")?; + + let mut headers_map = HeaderMap::new(); + if let Some(headers) = headers { + for pair in headers.pairs::() { + let (h, v) = pair?; + let name = HeaderName::from_str(&h).into_lua_err()?; + let value = HeaderValue::from_bytes(v.as_bytes()).into_lua_err()?; + headers_map.insert(name, value); + } + } + + let body_bytes = body.map(|s| s.as_bytes().to_vec()); + + Ok(Self { + kind: LuaResponseKind::Table, + status: status.unwrap_or(200), + headers: headers_map, + body: body_bytes, + }) + } + // Anything else is an error + value => Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "NetServeResponse", + message: None, + }), + } + } +} diff --git a/crates/lune-std-net/src/server/service.rs b/crates/lune-std-net/src/server/service.rs new file mode 100644 index 0000000..7bc7e53 --- /dev/null +++ b/crates/lune-std-net/src/server/service.rs @@ -0,0 +1,82 @@ +use std::{future::Future, net::SocketAddr, pin::Pin, rc::Rc}; + +use http_body_util::{BodyExt, Full}; +use hyper::{ + body::{Bytes, Incoming}, + service::Service, + Request, Response, +}; +use hyper_tungstenite::{is_upgrade_request, upgrade}; + +use mlua::prelude::*; +use mlua_luau_scheduler::{LuaSchedulerExt, LuaSpawnExt}; + +use super::{ + super::websocket::NetWebSocket, keys::SvcKeys, request::LuaRequest, response::LuaResponse, +}; + +#[derive(Debug, Clone)] +pub(super) struct Svc { + pub(super) lua: Rc, + pub(super) addr: SocketAddr, + pub(super) keys: SvcKeys, +} + +impl Service> for Svc { + type Response = Response>; + type Error = LuaError; + type Future = Pin>>>; + + fn call(&self, req: Request) -> Self::Future { + let lua = self.lua.clone(); + let addr = self.addr; + let keys = self.keys; + + if keys.has_websocket_handler() && is_upgrade_request(&req) { + Box::pin(async move { + let (res, sock) = upgrade(req, None).into_lua_err()?; + + let lua_inner = lua.clone(); + lua.spawn_local(async move { + let sock = sock.await.unwrap(); + let lua_sock = NetWebSocket::new(sock); + let lua_tab = lua_sock.into_lua_table(&lua_inner).unwrap(); + + let handler_websocket: LuaFunction = + keys.websocket_handler(&lua_inner).unwrap().unwrap(); + + lua_inner + .push_thread_back(handler_websocket, lua_tab) + .unwrap(); + }); + + Ok(res) + }) + } else { + let (head, body) = req.into_parts(); + + Box::pin(async move { + let handler_request: LuaFunction = keys.request_handler(&lua).unwrap(); + + let body = body.collect().await.into_lua_err()?; + let body = body.to_bytes().to_vec(); + + let lua_req = LuaRequest { + _remote_addr: addr, + head, + body, + }; + let lua_req_table = lua_req.into_lua_table(&lua)?; + + let thread_id = lua.push_thread_back(handler_request, lua_req_table)?; + lua.track_thread(thread_id); + lua.wait_for_thread(thread_id).await; + let thread_res = lua + .get_thread_result(thread_id) + .expect("Missing handler thread result")?; + + LuaResponse::from_lua_multi(thread_res, &lua)?.into_response() + }) + } + } +} diff --git a/crates/lune-std-net/src/util.rs b/crates/lune-std-net/src/util.rs new file mode 100644 index 0000000..ca79967 --- /dev/null +++ b/crates/lune-std-net/src/util.rs @@ -0,0 +1,94 @@ +use std::collections::HashMap; + +use hyper::header::{CONTENT_ENCODING, CONTENT_LENGTH}; +use reqwest::header::HeaderMap; + +use mlua::prelude::*; + +use lune_utils::TableBuilder; + +pub fn create_user_agent_header(lua: &Lua) -> LuaResult { + let version_global = lua + .globals() + .get::<_, LuaString>("_VERSION") + .expect("Missing _VERSION global"); + + let version_global_str = version_global + .to_str() + .context("Invalid utf8 found in _VERSION global")?; + + let (package_name, full_version) = version_global_str.split_once(' ').unwrap(); + + Ok(format!("{}/{}", package_name.to_lowercase(), full_version)) +} + +pub fn header_map_to_table( + lua: &Lua, + headers: HeaderMap, + remove_content_headers: bool, +) -> LuaResult { + let mut res_headers: HashMap> = HashMap::new(); + for (name, value) in &headers { + let name = name.as_str(); + let value = value.to_str().unwrap().to_owned(); + if let Some(existing) = res_headers.get_mut(name) { + existing.push(value); + } else { + res_headers.insert(name.to_owned(), vec![value]); + } + } + + if remove_content_headers { + let content_encoding_header_str = CONTENT_ENCODING.as_str(); + let content_length_header_str = CONTENT_LENGTH.as_str(); + res_headers.retain(|name, _| { + name != content_encoding_header_str && name != content_length_header_str + }); + } + + let mut builder = TableBuilder::new(lua)?; + for (name, mut values) in res_headers { + if values.len() == 1 { + let value = values.pop().unwrap().into_lua(lua)?; + builder = builder.with_value(name, value)?; + } else { + let values = TableBuilder::new(lua)? + .with_sequential_values(values)? + .build_readonly()? + .into_lua(lua)?; + builder = builder.with_value(name, values)?; + } + } + + builder.build_readonly() +} + +pub fn table_to_hash_map( + tab: LuaTable, + tab_origin_key: &'static str, +) -> LuaResult>> { + let mut map = HashMap::new(); + + for pair in tab.pairs::() { + let (key, value) = pair?; + match value { + LuaValue::String(s) => { + map.insert(key, vec![s.to_str()?.to_owned()]); + } + LuaValue::Table(t) => { + let mut values = Vec::new(); + for value in t.sequence_values::() { + values.push(value?.to_str()?.to_owned()); + } + map.insert(key, values); + } + _ => { + return Err(LuaError::runtime(format!( + "Value for '{tab_origin_key}' must be a string or array of strings", + ))) + } + } + } + + Ok(map) +} diff --git a/crates/lune-std-net/src/websocket.rs b/crates/lune-std-net/src/websocket.rs new file mode 100644 index 0000000..ae2208a --- /dev/null +++ b/crates/lune-std-net/src/websocket.rs @@ -0,0 +1,191 @@ +use std::sync::{ + atomic::{AtomicBool, AtomicU16, Ordering}, + Arc, +}; + +use bstr::{BString, ByteSlice}; +use mlua::prelude::*; + +use futures_util::{ + stream::{SplitSink, SplitStream}, + SinkExt, StreamExt, +}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::Mutex as AsyncMutex, +}; + +use hyper_tungstenite::{ + tungstenite::{ + protocol::{frame::coding::CloseCode as WsCloseCode, CloseFrame as WsCloseFrame}, + Message as WsMessage, + }, + WebSocketStream, +}; + +use lune_utils::TableBuilder; + +// Wrapper implementation for compatibility and changing colon syntax to dot syntax +const WEB_SOCKET_IMPL_LUA: &str = r#" +return freeze(setmetatable({ + close = function(...) + return websocket:close(...) + end, + send = function(...) + return websocket:send(...) + end, + next = function(...) + return websocket:next(...) + end, +}, { + __index = function(self, key) + if key == "closeCode" then + return websocket.closeCode + end + end, +})) +"#; + +#[derive(Debug)] +pub struct NetWebSocket { + close_code_exists: Arc, + close_code_value: Arc, + read_stream: Arc>>>, + write_stream: Arc, WsMessage>>>, +} + +impl Clone for NetWebSocket { + fn clone(&self) -> Self { + Self { + close_code_exists: Arc::clone(&self.close_code_exists), + close_code_value: Arc::clone(&self.close_code_value), + read_stream: Arc::clone(&self.read_stream), + write_stream: Arc::clone(&self.write_stream), + } + } +} + +impl NetWebSocket +where + T: AsyncRead + AsyncWrite + Unpin + 'static, +{ + pub fn new(value: WebSocketStream) -> Self { + let (write, read) = value.split(); + + Self { + close_code_exists: Arc::new(AtomicBool::new(false)), + close_code_value: Arc::new(AtomicU16::new(0)), + read_stream: Arc::new(AsyncMutex::new(read)), + write_stream: Arc::new(AsyncMutex::new(write)), + } + } + + fn get_close_code(&self) -> Option { + if self.close_code_exists.load(Ordering::Relaxed) { + Some(self.close_code_value.load(Ordering::Relaxed)) + } else { + None + } + } + + fn set_close_code(&self, code: u16) { + self.close_code_exists.store(true, Ordering::Relaxed); + self.close_code_value.store(code, Ordering::Relaxed); + } + + pub async fn send(&self, msg: WsMessage) -> LuaResult<()> { + let mut ws = self.write_stream.lock().await; + ws.send(msg).await.into_lua_err() + } + + pub async fn next(&self) -> LuaResult> { + let mut ws = self.read_stream.lock().await; + ws.next().await.transpose().into_lua_err() + } + + pub async fn close(&self, code: Option) -> LuaResult<()> { + if self.close_code_exists.load(Ordering::Relaxed) { + return Err(LuaError::runtime("Socket has already been closed")); + } + + self.send(WsMessage::Close(Some(WsCloseFrame { + code: match code { + Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code), + Some(code) => { + return Err(LuaError::runtime(format!( + "Close code must be between 1000 and 4999, got {code}" + ))) + } + None => WsCloseCode::Normal, + }, + reason: "".into(), + }))) + .await?; + + let mut ws = self.write_stream.lock().await; + ws.close().await.into_lua_err() + } + + pub fn into_lua_table(self, lua: &Lua) -> LuaResult { + let setmetatable = lua.globals().get::<_, LuaFunction>("setmetatable")?; + let table_freeze = lua + .globals() + .get::<_, LuaTable>("table")? + .get::<_, LuaFunction>("freeze")?; + + let env = TableBuilder::new(lua)? + .with_value("websocket", self.clone())? + .with_value("setmetatable", setmetatable)? + .with_value("freeze", table_freeze)? + .build_readonly()?; + + lua.load(WEB_SOCKET_IMPL_LUA) + .set_name("websocket") + .set_environment(env) + .eval() + } +} + +impl LuaUserData for NetWebSocket +where + T: AsyncRead + AsyncWrite + Unpin + 'static, +{ + fn add_fields<'lua, F: LuaUserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("closeCode", |_, this| Ok(this.get_close_code())); + } + + fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) { + methods.add_async_method("close", |_, this, code: Option| async move { + this.close(code).await + }); + + methods.add_async_method( + "send", + |_, this, (string, as_binary): (BString, Option)| async move { + this.send(if as_binary.unwrap_or_default() { + WsMessage::Binary(string.as_bytes().to_vec()) + } else { + let s = string.to_str().into_lua_err()?; + WsMessage::Text(s.to_string()) + }) + .await + }, + ); + + methods.add_async_method("next", |lua, this, (): ()| async move { + let msg = this.next().await?; + + if let Some(WsMessage::Close(Some(frame))) = msg.as_ref() { + this.set_close_code(frame.code.into()); + } + + Ok(match msg { + Some(WsMessage::Binary(bin)) => LuaValue::String(lua.create_string(bin)?), + Some(WsMessage::Text(txt)) => LuaValue::String(lua.create_string(txt)?), + Some(WsMessage::Close(_)) | None => LuaValue::Nil, + // Ignore ping/pong/frame messages, they are handled by tungstenite + msg => unreachable!("Unhandled message: {:?}", msg), + }) + }); + } +}