From c484ae73d6ee2650764099f1be4aea9cdc3fb0b0 Mon Sep 17 00:00:00 2001 From: Filip Tibell Date: Sat, 19 Aug 2023 20:22:11 -0500 Subject: [PATCH] Add back net builtin --- src/lune/builtins/mod.rs | 5 + src/lune/builtins/net/client.rs | 49 ++++++ src/lune/builtins/net/config.rs | 204 +++++++++++++++++++++++++ src/lune/builtins/net/mod.rs | 238 +++++++++++++++++++++++++++++ src/lune/builtins/net/response.rs | 106 +++++++++++++ src/lune/builtins/net/server.rs | 180 ++++++++++++++++++++++ src/lune/builtins/net/websocket.rs | 229 +++++++++++++++++++++++++++ src/lune/builtins/serde/mod.rs | 6 +- 8 files changed, 1014 insertions(+), 3 deletions(-) create mode 100644 src/lune/builtins/net/client.rs create mode 100644 src/lune/builtins/net/config.rs create mode 100644 src/lune/builtins/net/mod.rs create mode 100644 src/lune/builtins/net/response.rs create mode 100644 src/lune/builtins/net/server.rs create mode 100644 src/lune/builtins/net/websocket.rs diff --git a/src/lune/builtins/mod.rs b/src/lune/builtins/mod.rs index 0db6571..2fc47cf 100644 --- a/src/lune/builtins/mod.rs +++ b/src/lune/builtins/mod.rs @@ -4,6 +4,7 @@ use mlua::prelude::*; mod fs; mod luau; +mod net; mod process; mod serde; mod stdio; @@ -16,6 +17,7 @@ mod roblox; pub enum LuneBuiltin { Fs, Luau, + Net, Task, Process, Serde, @@ -32,6 +34,7 @@ where match self { Self::Fs => "fs", Self::Luau => "luau", + Self::Net => "net", Self::Task => "task", Self::Process => "process", Self::Serde => "serde", @@ -45,6 +48,7 @@ where let res = match self { Self::Fs => fs::create(lua), Self::Luau => luau::create(lua), + Self::Net => net::create(lua), Self::Task => task::create(lua), Self::Process => process::create(lua), Self::Serde => serde::create(lua), @@ -68,6 +72,7 @@ impl FromStr for LuneBuiltin { match s.trim().to_ascii_lowercase().as_str() { "fs" => Ok(Self::Fs), "luau" => Ok(Self::Luau), + "net" => Ok(Self::Net), "task" => Ok(Self::Task), "process" => Ok(Self::Process), "serde" => Ok(Self::Serde), diff --git a/src/lune/builtins/net/client.rs b/src/lune/builtins/net/client.rs new file mode 100644 index 0000000..2120c58 --- /dev/null +++ b/src/lune/builtins/net/client.rs @@ -0,0 +1,49 @@ +use std::str::FromStr; + +use mlua::prelude::*; + +use hyper::{header::HeaderName, http::HeaderValue, HeaderMap}; +use reqwest::{IntoUrl, Method, RequestBuilder}; + +pub struct NetClientBuilder { + builder: reqwest::ClientBuilder, +} + +impl NetClientBuilder { + pub fn new() -> NetClientBuilder { + Self { + builder: reqwest::ClientBuilder::new(), + } + } + + pub fn headers(mut self, headers: &[(K, V)]) -> LuaResult + where + K: AsRef, + V: AsRef<[u8]>, + { + let mut map = HeaderMap::new(); + for (key, val) in headers { + let hkey = HeaderName::from_str(key.as_ref()).into_lua_err()?; + let hval = HeaderValue::from_bytes(val.as_ref()).into_lua_err()?; + map.insert(hkey, hval); + } + self.builder = self.builder.default_headers(map); + Ok(self) + } + + pub fn build(self) -> LuaResult { + let client = self.builder.build().into_lua_err()?; + Ok(NetClient(client)) + } +} + +#[derive(Debug, Clone)] +pub struct NetClient(reqwest::Client); + +impl NetClient { + pub fn request(&self, method: Method, url: U) -> RequestBuilder { + self.0.request(method, url) + } +} + +impl LuaUserData for NetClient {} diff --git a/src/lune/builtins/net/config.rs b/src/lune/builtins/net/config.rs new file mode 100644 index 0000000..59c1e3e --- /dev/null +++ b/src/lune/builtins/net/config.rs @@ -0,0 +1,204 @@ +use std::collections::HashMap; + +use mlua::prelude::*; + +use reqwest::Method; + +// Net request config + +#[derive(Debug, Clone)] +pub struct RequestConfigOptions { + pub decompress: bool, +} + +impl Default for RequestConfigOptions { + fn default() -> Self { + Self { decompress: true } + } +} + +impl<'lua> FromLua<'lua> for RequestConfigOptions { + fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult { + // Nil means default options, table means custom options + if let LuaValue::Nil = value { + return Ok(Self::default()); + } else if let LuaValue::Table(tab) = value { + // Extract flags + let decompress = match tab.raw_get::<_, Option>("decompress") { + Ok(decomp) => Ok(decomp.unwrap_or(true)), + Err(_) => Err(LuaError::RuntimeError( + "Invalid option value for 'decompress' in request config options".to_string(), + )), + }?; + return Ok(Self { decompress }); + } + // Anything else is invalid + Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "RequestConfigOptions", + message: Some(format!( + "Invalid request config options - expected table or nil, got {}", + value.type_name() + )), + }) + } +} + +#[derive(Debug, Clone)] +pub struct RequestConfig<'a> { + pub url: String, + pub method: Method, + pub query: HashMap, LuaString<'a>>, + pub headers: HashMap, LuaString<'a>>, + pub body: Option>, + pub options: RequestConfigOptions, +} + +impl<'lua> FromLua<'lua> for RequestConfig<'lua> { + fn from_lua(value: LuaValue<'lua>, lua: &'lua Lua) -> LuaResult { + // If we just got a string we assume its a GET request to a given url + if let LuaValue::String(s) = value { + return Ok(Self { + url: s.to_string_lossy().to_string(), + method: Method::GET, + query: HashMap::new(), + headers: HashMap::new(), + body: None, + options: Default::default(), + }); + } + // If we got a table we are able to configure the entire request + if let LuaValue::Table(tab) = value { + // Extract url + let url = match tab.raw_get::<_, LuaString>("url") { + Ok(config_url) => Ok(config_url.to_string_lossy().to_string()), + Err(_) => Err(LuaError::RuntimeError( + "Missing 'url' in request config".to_string(), + )), + }?; + // Extract method + let method = match tab.raw_get::<_, LuaString>("method") { + Ok(config_method) => config_method.to_string_lossy().trim().to_ascii_uppercase(), + Err(_) => "GET".to_string(), + }; + // Extract query + let query = match tab.raw_get::<_, LuaTable>("query") { + Ok(config_headers) => { + let mut lua_headers = HashMap::new(); + for pair in config_headers.pairs::() { + let (key, value) = pair?.to_owned(); + lua_headers.insert(key, value); + } + lua_headers + } + Err(_) => HashMap::new(), + }; + // Extract headers + let headers = match tab.raw_get::<_, LuaTable>("headers") { + Ok(config_headers) => { + let mut lua_headers = HashMap::new(); + for pair in config_headers.pairs::() { + let (key, value) = pair?.to_owned(); + lua_headers.insert(key, value); + } + lua_headers + } + Err(_) => HashMap::new(), + }; + // Extract body + let body = match tab.raw_get::<_, LuaString>("body") { + Ok(config_body) => Some(config_body.as_bytes().to_owned()), + Err(_) => None, + }; + // Convert method string into proper enum + let method = method.trim().to_ascii_uppercase(); + let method = match method.as_ref() { + "GET" => Ok(Method::GET), + "POST" => Ok(Method::POST), + "PUT" => Ok(Method::PUT), + "DELETE" => Ok(Method::DELETE), + "HEAD" => Ok(Method::HEAD), + "OPTIONS" => Ok(Method::OPTIONS), + "PATCH" => Ok(Method::PATCH), + _ => Err(LuaError::RuntimeError(format!( + "Invalid request config method '{}'", + &method + ))), + }?; + // Parse any extra options given + let options = match tab.raw_get::<_, LuaValue>("options") { + Ok(opts) => RequestConfigOptions::from_lua(opts, lua)?, + Err(_) => RequestConfigOptions::default(), + }; + // All good, validated and we got what we need + return Ok(Self { + url, + method, + query, + headers, + body, + options, + }); + }; + // Anything else is invalid + Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "RequestConfig", + message: Some(format!( + "Invalid request config - expected string or table, got {}", + value.type_name() + )), + }) + } +} + +// Net serve config + +pub struct ServeConfig<'a> { + pub handle_request: LuaFunction<'a>, + pub handle_web_socket: Option>, +} + +impl<'lua> FromLua<'lua> for ServeConfig<'lua> { + fn from_lua(value: LuaValue<'lua>, lua: &'lua Lua) -> LuaResult { + let message = match &value { + LuaValue::Function(f) => { + return Ok(ServeConfig { + handle_request: f.clone(), + handle_web_socket: None, + }) + } + LuaValue::Table(t) => { + let handle_request: Option = t.raw_get("handleRequest")?; + let handle_web_socket: Option = t.raw_get("handleWebSocket")?; + if handle_request.is_some() || handle_web_socket.is_some() { + return Ok(ServeConfig { + handle_request: handle_request.unwrap_or_else(|| { + let chunk = r#" + return { + status = 426, + body = "Upgrade Required", + headers = { + Upgrade = "websocket", + }, + } + "#; + lua.load(chunk) + .into_function() + .expect("Failed to create default http responder function") + }), + handle_web_socket, + }); + } else { + Some("Missing handleRequest and / or handleWebSocket".to_string()) + } + } + _ => None, + }; + Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "ServeConfig", + message, + }) + } +} diff --git a/src/lune/builtins/net/mod.rs b/src/lune/builtins/net/mod.rs new file mode 100644 index 0000000..95aecd9 --- /dev/null +++ b/src/lune/builtins/net/mod.rs @@ -0,0 +1,238 @@ +use std::collections::HashMap; + +use mlua::prelude::*; + +use console::style; +use hyper::{ + header::{CONTENT_ENCODING, CONTENT_LENGTH}, + Server, +}; +use tokio::{ + sync::{mpsc, oneshot}, + task, +}; + +use crate::lune::{scheduler::Scheduler, util::TableBuilder}; + +use super::serde::{ + compress_decompress::{decompress, CompressDecompressFormat}, + encode_decode::{EncodeDecodeConfig, EncodeDecodeFormat}, +}; + +mod client; +mod config; +mod response; +mod server; +mod websocket; + +use client::{NetClient, NetClientBuilder}; +use config::{RequestConfig, ServeConfig}; +use response::NetServeResponse; +use server::{NetLocalExec, NetService}; +use websocket::NetWebSocket; + +pub fn create(lua: &'static Lua) -> LuaResult { + // Create a reusable client for performing our + // web requests and store it in the lua registry, + // allowing us to reuse headers and internal structs + let client = NetClientBuilder::new() + .headers(&[("User-Agent", create_user_agent_header())])? + .build()?; + lua.set_named_registry_value("net.client", client)?; + // Create the global table for net + TableBuilder::new(lua)? + .with_function("jsonEncode", net_json_encode)? + .with_function("jsonDecode", net_json_decode)? + .with_async_function("request", net_request)? + .with_async_function("socket", net_socket)? + .with_async_function("serve", net_serve)? + .with_function("urlEncode", net_url_encode)? + .with_function("urlDecode", net_url_decode)? + .build_readonly() +} + +fn create_user_agent_header() -> String { + let (github_owner, github_repo) = env!("CARGO_PKG_REPOSITORY") + .trim_start_matches("https://github.com/") + .split_once('/') + .unwrap(); + format!("{github_owner}-{github_repo}-cli") +} + +fn net_json_encode<'lua>( + lua: &'lua Lua, + (val, pretty): (LuaValue<'lua>, Option), +) -> LuaResult> { + EncodeDecodeConfig::from((EncodeDecodeFormat::Json, pretty.unwrap_or_default())) + .serialize_to_string(lua, val) +} + +fn net_json_decode<'lua>(lua: &'lua Lua, json: LuaString<'lua>) -> LuaResult> { + EncodeDecodeConfig::from(EncodeDecodeFormat::Json).deserialize_from_string(lua, json) +} + +async fn net_request<'lua>(lua: &'lua Lua, config: RequestConfig<'lua>) -> LuaResult> +where + 'lua: 'static, // FIXME: Get rid of static lifetime bound here +{ + // Create and send the request + let client: LuaUserDataRef = lua.named_registry_value("net.client")?; + let mut request = client.request(config.method, &config.url); + for (query, value) in config.query { + request = request.query(&[(query.to_str()?, value.to_str()?)]); + } + for (header, value) in config.headers { + request = request.header(header.to_str()?, value.to_str()?); + } + let res = request + .body(config.body.unwrap_or_default()) + .send() + .await + .into_lua_err()?; + // Extract status, headers + let res_status = res.status().as_u16(); + let res_status_text = res.status().canonical_reason(); + let mut res_headers = res + .headers() + .iter() + .map(|(name, value)| { + ( + name.as_str().to_string(), + value.to_str().unwrap().to_owned(), + ) + }) + .collect::>(); + // Read response bytes + let mut res_bytes = res.bytes().await.into_lua_err()?.to_vec(); + // Check for extra options, decompression + if config.options.decompress { + // NOTE: Header names are guaranteed to be lowercase because of the above + // transformations of them into the hashmap, so we can compare directly + let format = res_headers.iter().find_map(|(name, val)| { + if name == CONTENT_ENCODING.as_str() { + CompressDecompressFormat::detect_from_header_str(val) + } else { + None + } + }); + if let Some(format) = format { + res_bytes = decompress(format, res_bytes).await?; + let content_encoding_header_str = CONTENT_ENCODING.as_str(); + let content_length_header_str = CONTENT_LENGTH.as_str(); + res_headers.retain(|name, _| { + name != content_encoding_header_str && name != content_length_header_str + }); + } + } + // Construct and return a readonly lua table with results + TableBuilder::new(lua)? + .with_value("ok", (200..300).contains(&res_status))? + .with_value("statusCode", res_status)? + .with_value("statusMessage", res_status_text)? + .with_value("headers", res_headers)? + .with_value("body", lua.create_string(&res_bytes)?)? + .build_readonly() +} + +async fn net_socket<'lua>(lua: &'lua Lua, url: String) -> LuaResult +where + 'lua: 'static, // FIXME: Get rid of static lifetime bound here +{ + let (ws, _) = tokio_tungstenite::connect_async(url).await.into_lua_err()?; + NetWebSocket::new(ws).into_lua_table(lua) +} + +async fn net_serve<'lua>( + lua: &'lua Lua, + (port, config): (u16, ServeConfig<'lua>), +) -> LuaResult> +where + 'lua: 'static, // FIXME: Get rid of static lifetime bound here +{ + // Note that we need to use a mpsc here and not + // a oneshot channel since we move the sender + // into our table with the stop function + let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1); + let server_request_callback = lua.create_registry_value(config.handle_request)?; + let server_websocket_callback = config.handle_web_socket.map(|handler| { + lua.create_registry_value(handler) + .expect("Failed to store websocket handler") + }); + let sched = lua + .app_data_ref::<&Scheduler>() + .expect("Lua struct is missing scheduler"); + // Bind first to make sure that we can bind to this address + let bound = match Server::try_bind(&([127, 0, 0, 1], port).into()) { + Err(e) => { + return Err(LuaError::external(format!( + "Failed to bind to localhost on port {port}\n{}", + format!("{e}").replace( + "error creating server listener: ", + &format!("{}", style("> ").dim()) + ) + ))); + } + Ok(bound) => bound, + }; + // Register a background task to prevent the task scheduler from + // exiting early and start up our web server on the bound address + // TODO: Implement background task registration in scheduler + let (background_tx, background_rx) = oneshot::channel(); + sched.schedule_future(async move { + let _ = background_rx.await; + }); + let server = bound + .http1_only(true) // Web sockets can only use http1 + .http1_keepalive(true) // Web sockets must be kept alive + .executor(NetLocalExec) + .serve(NetService::new( + lua, + server_request_callback, + server_websocket_callback, + )) + .with_graceful_shutdown(async move { + let _ = background_tx.send(()); + shutdown_rx + .recv() + .await + .expect("Server was stopped instantly"); + shutdown_rx.close(); + }); + // Spawn a new tokio task so we don't block + task::spawn_local(server); + // Create a new read-only table that contains methods + // for manipulating server behavior and shutting it down + let handle_stop = move |_, _: ()| match shutdown_tx.try_send(()) { + Ok(_) => Ok(()), + Err(_) => Err(LuaError::RuntimeError( + "Server has already been stopped".to_string(), + )), + }; + TableBuilder::new(lua)? + .with_function("stop", handle_stop)? + .build_readonly() +} + +fn net_url_encode<'lua>( + lua: &'lua Lua, + (lua_string, as_binary): (LuaString<'lua>, Option), +) -> LuaResult> { + if matches!(as_binary, Some(true)) { + urlencoding::encode_binary(lua_string.as_bytes()).into_lua(lua) + } else { + urlencoding::encode(lua_string.to_str()?).into_lua(lua) + } +} + +fn net_url_decode<'lua>( + lua: &'lua Lua, + (lua_string, as_binary): (LuaString<'lua>, Option), +) -> LuaResult> { + if matches!(as_binary, Some(true)) { + urlencoding::decode_binary(lua_string.as_bytes()).into_lua(lua) + } else { + urlencoding::decode(lua_string.to_str()?) + .map_err(|e| LuaError::RuntimeError(format!("Encountered invalid encoding - {e}")))? + .into_lua(lua) + } +} diff --git a/src/lune/builtins/net/response.rs b/src/lune/builtins/net/response.rs new file mode 100644 index 0000000..fa2e748 --- /dev/null +++ b/src/lune/builtins/net/response.rs @@ -0,0 +1,106 @@ +use std::collections::HashMap; + +use hyper::{Body, Response}; +use mlua::prelude::*; + +#[derive(Debug, Clone, Copy)] +pub enum NetServeResponseKind { + PlainText, + Table, +} + +#[derive(Debug, Clone)] +pub struct NetServeResponse { + kind: NetServeResponseKind, + status: u16, + headers: HashMap>, + body: Option>, +} + +impl NetServeResponse { + pub fn into_response(self) -> LuaResult> { + Ok(match self.kind { + NetServeResponseKind::PlainText => Response::builder() + .status(200) + .header("Content-Type", "text/plain") + .body(Body::from(self.body.unwrap())) + .into_lua_err()?, + NetServeResponseKind::Table => { + let mut response = Response::builder(); + for (key, value) in self.headers { + response = response.header(&key, value); + } + response + .status(self.status) + .body(Body::from(self.body.unwrap_or_default())) + .into_lua_err()? + } + }) + } +} + +impl<'lua> FromLua<'lua> for NetServeResponse { + fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult { + match value { + // Plain strings from the handler are plaintext responses + LuaValue::String(s) => Ok(Self { + kind: NetServeResponseKind::PlainText, + status: 200, + headers: HashMap::new(), + body: Some(s.as_bytes().to_vec()), + }), + // Tables are more detailed responses with potential status, headers, body + LuaValue::Table(t) => { + let status: Option = t.get("status")?; + let headers: Option = t.get("headers")?; + let body: Option = t.get("body")?; + + let mut headers_map = HashMap::new(); + if let Some(headers) = headers { + for pair in headers.pairs::() { + let (h, v) = pair?; + headers_map.insert(h, v.as_bytes().to_vec()); + } + } + + let body_bytes = body.map(|s| s.as_bytes().to_vec()); + + Ok(Self { + kind: NetServeResponseKind::Table, + status: status.unwrap_or(200), + headers: headers_map, + body: body_bytes, + }) + } + // Anything else is an error + value => Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "NetServeResponse", + message: None, + }), + } + } +} + +impl<'lua> IntoLua<'lua> for NetServeResponse { + fn into_lua(self, lua: &'lua Lua) -> LuaResult> { + if self.headers.len() > i32::MAX as usize { + return Err(LuaError::ToLuaConversionError { + from: "NetServeResponse", + to: "table", + message: Some("Too many header values".to_string()), + }); + } + let body = self.body.map(|b| lua.create_string(b)).transpose()?; + let headers = lua.create_table_with_capacity(0, self.headers.len())?; + for (key, value) in self.headers { + headers.set(key, lua.create_string(&value)?)?; + } + let table = lua.create_table_with_capacity(0, 3)?; + table.set("status", self.status)?; + table.set("headers", headers)?; + table.set("body", body)?; + table.set_readonly(true); + Ok(LuaValue::Table(table)) + } +} diff --git a/src/lune/builtins/net/server.rs b/src/lune/builtins/net/server.rs new file mode 100644 index 0000000..8fb515b --- /dev/null +++ b/src/lune/builtins/net/server.rs @@ -0,0 +1,180 @@ +use std::{ + future::Future, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use mlua::prelude::*; + +use hyper::{body::to_bytes, server::conn::AddrStream, service::Service}; +use hyper::{Body, Request, Response}; +use hyper_tungstenite::{is_upgrade_request as is_ws_upgrade_request, upgrade as ws_upgrade}; +use tokio::{sync::oneshot, task}; + +use crate::{ + lune::{scheduler::Scheduler, util::TableBuilder}, + LuneError, +}; + +use super::{NetServeResponse, NetWebSocket}; + +// Hyper service implementation for net, lots of boilerplate here +// but make_svc and make_svc_function do not work for what we need + +pub struct NetServiceInner( + &'static Lua, + Arc, + Arc>, +); + +impl Service> for NetServiceInner { + type Response = Response; + type Error = LuaError; + type Future = Pin>>>; + + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + let lua = self.0; + if self.2.is_some() && is_ws_upgrade_request(&req) { + // Websocket upgrade request + websocket handler exists, + // we should now upgrade this connection to a websocket + // and then call our handler with a new socket object + let kopt = self.2.clone(); + let key = kopt.as_ref().as_ref().unwrap(); + let handler: LuaFunction = lua.registry_value(key).expect("Missing websocket handler"); + let (response, ws) = ws_upgrade(&mut req, None).expect("Failed to upgrade websocket"); + // This should be spawned as a registered task, otherwise + // the scheduler may exit early and cancel this even though what + // we want here is a long-running task that keeps the program alive + let sched = lua + .app_data_ref::<&Scheduler>() + .expect("Lua struct is missing scheduler"); + // TODO: Implement background task registration in scheduler + let (background_tx, background_rx) = oneshot::channel(); + sched.schedule_future(async move { + let _ = background_rx.await; + }); + task::spawn_local(async move { + // Create our new full websocket object, then + // schedule our handler to get called asap + let ws = ws.await.into_lua_err()?; + let sock = NetWebSocket::new(ws).into_lua_table(lua)?; + let sched = lua + .app_data_ref::<&Scheduler>() + .expect("Lua struct is missing scheduler"); + let result = sched.push_front( + lua.create_thread(handler)?, + LuaMultiValue::from_vec(vec![LuaValue::Table(sock)]), + ); + let _ = background_tx.send(()); + result + }); + Box::pin(async move { Ok(response) }) + } else { + // Got a normal http request or no websocket handler + // exists, just call the http request handler + let key = self.1.clone(); + let (parts, body) = req.into_parts(); + Box::pin(async move { + // Convert request body into bytes, extract handler + let bytes = to_bytes(body).await.into_lua_err()?; + let handler: LuaFunction = lua.registry_value(&key)?; + // Create a readonly table for the request query params + let query_params = TableBuilder::new(lua)? + .with_values( + parts + .uri + .query() + .unwrap_or_default() + .split('&') + .filter_map(|q| q.split_once('=')) + .collect(), + )? + .build_readonly()?; + // Do the same for headers + let header_map = TableBuilder::new(lua)? + .with_values( + parts + .headers + .iter() + .map(|(name, value)| { + (name.to_string(), value.to_str().unwrap().to_string()) + }) + .collect(), + )? + .build_readonly()?; + // Create a readonly table with request info to pass to the handler + let request = TableBuilder::new(lua)? + .with_value("path", parts.uri.path())? + .with_value("query", query_params)? + .with_value("method", parts.method.as_str())? + .with_value("headers", header_map)? + .with_value("body", lua.create_string(&bytes)?)? + .build_readonly()?; + let response: LuaResult = handler.call(request); + // Send below errors to task scheduler so that they can emit properly + let lua_error = match response { + Ok(r) => match r.into_response() { + Ok(res) => return Ok(res), + Err(err) => err, + }, + Err(err) => err, + }; + eprintln!("{}", LuneError::from(lua_error)); + Ok(Response::builder() + .status(500) + .body(Body::from("Internal Server Error")) + .unwrap()) + }) + } + } +} + +pub struct NetService( + &'static Lua, + Arc, + Arc>, +); + +impl NetService { + pub fn new( + lua: &'static Lua, + callback_http: LuaRegistryKey, + callback_websocket: Option, + ) -> Self { + Self(lua, Arc::new(callback_http), Arc::new(callback_websocket)) + } +} + +impl Service<&AddrStream> for NetService { + type Response = NetServiceInner; + type Error = hyper::Error; + type Future = Pin>>>; + + fn poll_ready(&mut self, _: &mut Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn call(&mut self, _: &AddrStream) -> Self::Future { + let lua = self.0; + let key1 = self.1.clone(); + let key2 = self.2.clone(); + Box::pin(async move { Ok(NetServiceInner(lua, key1, key2)) }) + } +} + +#[derive(Clone, Copy, Debug)] +pub struct NetLocalExec; + +impl hyper::rt::Executor for NetLocalExec +where + F: std::future::Future + 'static, // not requiring `Send` +{ + fn execute(&self, fut: F) { + task::spawn_local(fut); + } +} diff --git a/src/lune/builtins/net/websocket.rs b/src/lune/builtins/net/websocket.rs new file mode 100644 index 0000000..36fc7b1 --- /dev/null +++ b/src/lune/builtins/net/websocket.rs @@ -0,0 +1,229 @@ +use std::{cell::Cell, sync::Arc}; + +use hyper::upgrade::Upgraded; +use mlua::prelude::*; + +use futures_util::{ + stream::{SplitSink, SplitStream}, + SinkExt, StreamExt, +}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::TcpStream, + sync::Mutex as AsyncMutex, +}; + +use hyper_tungstenite::{ + tungstenite::{ + protocol::{frame::coding::CloseCode as WsCloseCode, CloseFrame as WsCloseFrame}, + Message as WsMessage, + }, + WebSocketStream, +}; +use tokio_tungstenite::MaybeTlsStream; + +use crate::lune::util::TableBuilder; + +const WEB_SOCKET_IMPL_LUA: &str = r#" +return freeze(setmetatable({ + close = function(...) + return close(websocket, ...) + end, + send = function(...) + return send(websocket, ...) + end, + next = function(...) + return next(websocket, ...) + end, +}, { + __index = function(self, key) + if key == "closeCode" then + return close_code(websocket) + end + end, +})) +"#; + +#[derive(Debug)] +pub struct NetWebSocket { + close_code: Arc>>, + read_stream: Arc>>>, + write_stream: Arc, WsMessage>>>, +} + +impl Clone for NetWebSocket { + fn clone(&self) -> Self { + Self { + close_code: Arc::clone(&self.close_code), + read_stream: Arc::clone(&self.read_stream), + write_stream: Arc::clone(&self.write_stream), + } + } +} + +impl NetWebSocket +where + T: AsyncRead + AsyncWrite + Unpin, +{ + pub fn new(value: WebSocketStream) -> Self { + let (write, read) = value.split(); + + Self { + close_code: Arc::new(Cell::new(None)), + read_stream: Arc::new(AsyncMutex::new(read)), + write_stream: Arc::new(AsyncMutex::new(write)), + } + } + + fn into_lua_table_with_env<'lua>( + lua: &'lua Lua, + env: LuaTable<'lua>, + ) -> LuaResult> { + lua.load(WEB_SOCKET_IMPL_LUA) + .set_name("websocket") + .set_environment(env) + .eval() + } +} + +type NetWebSocketStreamClient = MaybeTlsStream; +impl NetWebSocket { + pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult { + let socket_env = TableBuilder::new(lua)? + .with_value("websocket", self)? + .with_function("close_code", close_code::)? + .with_async_function("close", close::)? + .with_async_function("send", send::)? + .with_async_function("next", next::)? + .with_value( + "setmetatable", + lua.named_registry_value::("tab.setmeta")?, + )? + .with_value( + "freeze", + lua.named_registry_value::("tab.freeze")?, + )? + .build_readonly()?; + Self::into_lua_table_with_env(lua, socket_env) + } +} + +type NetWebSocketStreamServer = Upgraded; +impl NetWebSocket { + pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult { + let socket_env = TableBuilder::new(lua)? + .with_value("websocket", self)? + .with_function("close_code", close_code::)? + .with_async_function("close", close::)? + .with_async_function("send", send::)? + .with_async_function("next", next::)? + .with_value( + "setmetatable", + lua.named_registry_value::("tab.setmeta")?, + )? + .with_value( + "freeze", + lua.named_registry_value::("tab.freeze")?, + )? + .build_readonly()?; + Self::into_lua_table_with_env(lua, socket_env) + } +} + +impl LuaUserData for NetWebSocket {} + +fn close_code<'lua, T>( + _lua: &'lua Lua, + socket: LuaUserDataRef<'lua, NetWebSocket>, +) -> LuaResult> +where + T: AsyncRead + AsyncWrite + Unpin, +{ + Ok(match socket.close_code.get() { + Some(code) => LuaValue::Number(code as f64), + None => LuaValue::Nil, + }) +} + +async fn close<'lua, T>( + _lua: &'lua Lua, + (socket, code): (LuaUserDataRef<'lua, NetWebSocket>, Option), +) -> LuaResult<()> +where + T: AsyncRead + AsyncWrite + Unpin, +{ + let mut ws = socket.write_stream.lock().await; + + ws.send(WsMessage::Close(Some(WsCloseFrame { + code: match code { + Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code), + Some(code) => { + return Err(LuaError::RuntimeError(format!( + "Close code must be between 1000 and 4999, got {code}" + ))) + } + None => WsCloseCode::Normal, + }, + reason: "".into(), + }))) + .await + .into_lua_err()?; + + let res = ws.close(); + res.await.into_lua_err() +} + +async fn send<'lua, T>( + _lua: &'lua Lua, + (socket, string, as_binary): ( + LuaUserDataRef<'lua, NetWebSocket>, + LuaString<'lua>, + Option, + ), +) -> LuaResult<()> +where + T: AsyncRead + AsyncWrite + Unpin, +{ + let msg = if matches!(as_binary, Some(true)) { + WsMessage::Binary(string.as_bytes().to_vec()) + } else { + let s = string.to_str().into_lua_err()?; + WsMessage::Text(s.to_string()) + }; + let mut ws = socket.write_stream.lock().await; + ws.send(msg).await.into_lua_err() +} + +async fn next<'lua, T>( + lua: &'lua Lua, + socket: LuaUserDataRef<'lua, NetWebSocket>, +) -> LuaResult> +where + T: AsyncRead + AsyncWrite + Unpin, +{ + let mut ws = socket.read_stream.lock().await; + let item = ws.next().await.transpose().into_lua_err(); + let msg = match item { + Ok(Some(WsMessage::Close(msg))) => { + if let Some(msg) = &msg { + socket.close_code.replace(Some(msg.code.into())); + } + Ok(Some(WsMessage::Close(msg))) + } + val => val, + }?; + while let Some(msg) = &msg { + let msg_string_opt = match msg { + WsMessage::Binary(bin) => Some(lua.create_string(bin)?), + WsMessage::Text(txt) => Some(lua.create_string(txt)?), + // Stop waiting for next message if we get a close message + WsMessage::Close(_) => return Ok(LuaValue::Nil), + // Ignore ping/pong/frame messages, they are handled by tungstenite + _ => None, + }; + if let Some(msg_string) = msg_string_opt { + return Ok(LuaValue::String(msg_string)); + } + } + Ok(LuaValue::Nil) +} diff --git a/src/lune/builtins/serde/mod.rs b/src/lune/builtins/serde/mod.rs index 5da31fe..4e76bce 100644 --- a/src/lune/builtins/serde/mod.rs +++ b/src/lune/builtins/serde/mod.rs @@ -1,9 +1,9 @@ use mlua::prelude::*; -mod compress_decompress; -use compress_decompress::{compress, decompress, CompressDecompressFormat}; +pub(super) mod compress_decompress; +pub(super) mod encode_decode; -mod encode_decode; +use compress_decompress::{compress, decompress, CompressDecompressFormat}; use encode_decode::{EncodeDecodeConfig, EncodeDecodeFormat}; use crate::lune::util::TableBuilder;