From 6c97003571cab72b92af3ab82feeb87ecc317b27 Mon Sep 17 00:00:00 2001 From: Filip Tibell Date: Sat, 11 Feb 2023 22:40:14 +0100 Subject: [PATCH] Implement web sockets for net serve --- CHANGELOG.md | 26 ++++++++ lune.yml | 2 +- luneTypes.d.luau | 18 +++++- packages/lib/src/globals/net.rs | 85 ++++++++++---------------- packages/lib/src/lib.rs | 1 + packages/lib/src/lua/mod.rs | 1 + packages/lib/src/lua/net/client.rs | 49 +++++++++++++++ packages/lib/src/lua/net/config.rs | 50 +++++++++++++++ packages/lib/src/lua/net/mod.rs | 7 +++ packages/lib/src/lua/net/ws_server.rs | 87 +++++++++++++++++++++++++++ packages/lib/src/utils/net.rs | 18 ------ packages/lib/src/utils/table.rs | 44 +++++++------- tests/net/serve.luau | 31 +++++++++- 13 files changed, 320 insertions(+), 99 deletions(-) create mode 100644 packages/lib/src/lua/mod.rs create mode 100644 packages/lib/src/lua/net/client.rs create mode 100644 packages/lib/src/lua/net/config.rs create mode 100644 packages/lib/src/lua/net/mod.rs create mode 100644 packages/lib/src/lua/net/ws_server.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index 74f3ee3..ad97a7f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,31 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- `net.serve` now supports web sockets in addition to normal http requests! + + Example usage: + + ```lua + net.serve(8080, { + handleRequest = function(request) + return "Hello, world!" + end, + handleWebSocket = function(socket) + task.delay(10, function() + socket.send("Timed out!") + socket.close() + end) + -- This will yield waiting for new messages, and will break + -- when the socket was closed by either the server or client + for message in socket do + if message == "Ping" then + socket.send("Pong") + end + end + end, + }) + ``` + - `net.serve` now returns a `NetServeHandle` which can be used to stop serving requests safely. Example usage: @@ -19,6 +44,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 end) print("Shutting down after 1 second...") + task.wait(1) handle.stop() print("Shut down succesfully") ``` diff --git a/lune.yml b/lune.yml index cb5f07a..fefb73e 100644 --- a/lune.yml +++ b/lune.yml @@ -48,7 +48,7 @@ globals: net.serve: args: - type: number - - type: function + - type: function | table # Processs process.args: property: read-only diff --git a/luneTypes.d.luau b/luneTypes.d.luau index f1b997a..31c9244 100644 --- a/luneTypes.d.luau +++ b/luneTypes.d.luau @@ -153,10 +153,24 @@ export type NetResponse = { body: string?, } +type NetServeHttpHandler = (request: NetRequest) -> (string | NetResponse) +type NetServeWebSocketHandler = (socket: NetWebSocket) -> () + +export type NetServeConfig = { + handleRequest: NetServeHttpHandler?, + handleWebSocket: NetServeWebSocketHandler?, +} + export type NetServeHandle = { stop: () -> (), } +declare class NetWebSocket + close: () -> () + send: (message: string) -> () + function __iter(self): () -> string +end + --[=[ @class net @@ -183,9 +197,9 @@ declare net: { until the `stop` function on the returned `NetServeHandle` has been called. @param port The port to use for the server - @param handler The handler function to use for the server + @param handlerOrConfig The handler function or config to use for the server ]=] - serve: (port: number, handler: (request: NetRequest) -> (string | NetResponse)) -> NetServeHandle, + serve: (port: number, handlerOrConfig: NetServeHttpHandler | NetServeConfig) -> NetServeHandle, --[=[ @within net diff --git a/packages/lib/src/globals/net.rs b/packages/lib/src/globals/net.rs index 39132f7..020cb61 100644 --- a/packages/lib/src/globals/net.rs +++ b/packages/lib/src/globals/net.rs @@ -8,40 +8,27 @@ use std::{ use mlua::prelude::*; -use hyper::{body::to_bytes, http::HeaderValue, server::conn::AddrStream, service::Service}; -use hyper::{Body, HeaderMap, Request, Response, Server}; -use hyper_tungstenite::{ - is_upgrade_request as is_ws_upgrade_request, tungstenite::Message as WsMessage, - upgrade as ws_upgrade, -}; +use hyper::{body::to_bytes, server::conn::AddrStream, service::Service}; +use hyper::{Body, Request, Response, Server}; +use hyper_tungstenite::{is_upgrade_request as is_ws_upgrade_request, upgrade as ws_upgrade}; -use futures_util::{SinkExt, StreamExt}; -use reqwest::{ClientBuilder, Method}; +use reqwest::Method; use tokio::{ sync::mpsc::{self, Sender}, task, }; -use crate::utils::{ - message::LuneMessage, - net::{get_request_user_agent_header, NetClient}, - table::TableBuilder, +use crate::{ + lua::net::{NetClient, NetClientBuilder, NetWebSocketServer, ServeConfig}, + utils::{message::LuneMessage, net::get_request_user_agent_header, table::TableBuilder}, }; pub fn create(lua: &'static Lua) -> LuaResult { // Create a reusable client for performing our // web requests and store it in the lua registry - let mut default_headers = HeaderMap::new(); - default_headers.insert( - "User-Agent", - HeaderValue::from_str(&get_request_user_agent_header()).map_err(LuaError::external)?, - ); - let client = NetClient::new( - ClientBuilder::new() - .default_headers(default_headers) - .build() - .map_err(LuaError::external)?, - ); + let client = NetClientBuilder::new() + .headers(&[("User-Agent", get_request_user_agent_header())])? + .build()?; lua.set_named_registry_value("NetClient", client)?; // Create the global table for net TableBuilder::new(lua)? @@ -158,17 +145,24 @@ async fn net_request<'a>(lua: &'static Lua, config: LuaValue<'a>) -> LuaResult( lua: &'static Lua, - (port, callback): (u16, LuaFunction<'a>), // TODO: Parse options as either callback or table with request callback + websocket callback + (port, config): (u16, ServeConfig<'a>), ) -> LuaResult> { // Note that we need to use a mpsc here and not // a oneshot channel since we move the sender // into our table with the stop function let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1); - let websocket_callback = Arc::new(None); // TODO: Store websocket callback, if given - let server_callback = Arc::new(lua.create_registry_value(callback)?); + let server_request_callback = Arc::new(lua.create_registry_value(config.handle_request)?); + let server_websocket_callback = Arc::new(config.handle_web_socket.map(|handler| { + lua.create_registry_value(handler) + .expect("Failed to store websocket handler") + })); let server = Server::bind(&([127, 0, 0, 1], port).into()) .executor(LocalExec) - .serve(MakeNetService(lua, server_callback, websocket_callback)) + .serve(MakeNetService( + lua, + server_request_callback, + server_websocket_callback, + )) .with_graceful_shutdown(async move { shutdown_rx.recv().await.unwrap(); shutdown_rx.close(); @@ -225,40 +219,25 @@ impl Service> for NetService { fn call(&mut self, mut req: Request) -> Self::Future { let lua = self.0; if self.2.is_some() && is_ws_upgrade_request(&req) { - // Websocket request + websocket handler exists, - // we should upgrade this connection to a websocket - // and then pass a socket object to our lua handler + // Websocket upgrade request + websocket handler exists, + // we should now upgrade this connection to a websocket + // and then call our handler with a new socket object let kopt = self.2.clone(); let key = kopt.as_ref().as_ref().unwrap(); let handler: LuaFunction = lua.registry_value(key).expect("Missing websocket handler"); let (response, ws) = ws_upgrade(&mut req, None).expect("Failed to upgrade websocket"); task::spawn_local(async move { - if let Ok(mut websocket) = ws.await { - // TODO: Create lua userdata websocket object - // with methods for interacting with the websocket - // TODO: Start waiting for messages when we know - // for sure that we have gotten a message handler - // and move the following logic into there instead - while let Some(message) = websocket.next().await { - // Create lua strings from websocket messages - if let Some(handler_str) = match message.map_err(LuaError::external)? { - WsMessage::Text(msg) => Some(lua.create_string(&msg)?), - WsMessage::Binary(msg) => Some(lua.create_string(&msg)?), - // Tungstenite takes care of these messages - WsMessage::Ping(_) => None, - WsMessage::Pong(_) => None, - WsMessage::Close(_) => None, - WsMessage::Frame(_) => None, - } { - // TODO: Call whatever lua handler we have registered, with our message string - } - } - } - Ok::<_, LuaError>(()) + // Create our new full websocket object + let ws = ws.await.map_err(LuaError::external)?; + let ws_lua = NetWebSocketServer::from(ws); + let ws_proper = ws_lua.into_proper(lua).await?; + // Call our handler with it + handler.call_async::<_, ()>(ws_proper).await }); Box::pin(async move { Ok(response) }) } else { - // Normal http request or no websocket handler exists, call the http request handler + // Got a normal http request or no websocket handler + // exists, just call the http request handler let key = self.1.clone(); let (parts, body) = req.into_parts(); Box::pin(async move { diff --git a/packages/lib/src/lib.rs b/packages/lib/src/lib.rs index e644d13..4ca37f9 100644 --- a/packages/lib/src/lib.rs +++ b/packages/lib/src/lib.rs @@ -4,6 +4,7 @@ use mlua::prelude::*; use tokio::{sync::mpsc, task}; pub(crate) mod globals; +pub(crate) mod lua; pub(crate) mod utils; #[cfg(test)] diff --git a/packages/lib/src/lua/mod.rs b/packages/lib/src/lua/mod.rs new file mode 100644 index 0000000..f9faf2f --- /dev/null +++ b/packages/lib/src/lua/mod.rs @@ -0,0 +1 @@ +pub mod net; diff --git a/packages/lib/src/lua/net/client.rs b/packages/lib/src/lua/net/client.rs new file mode 100644 index 0000000..6381f63 --- /dev/null +++ b/packages/lib/src/lua/net/client.rs @@ -0,0 +1,49 @@ +use std::str::FromStr; + +use mlua::prelude::*; + +use hyper::{header::HeaderName, http::HeaderValue, HeaderMap}; +use reqwest::{IntoUrl, Method, RequestBuilder}; + +pub struct NetClientBuilder { + builder: reqwest::ClientBuilder, +} + +impl NetClientBuilder { + pub fn new() -> NetClientBuilder { + Self { + builder: reqwest::ClientBuilder::new(), + } + } + + pub fn headers(mut self, headers: &[(K, V)]) -> LuaResult + where + K: AsRef, + V: AsRef<[u8]>, + { + let mut map = HeaderMap::new(); + for (key, val) in headers { + let hkey = HeaderName::from_str(key.as_ref()).map_err(LuaError::external)?; + let hval = HeaderValue::from_bytes(val.as_ref()).map_err(LuaError::external)?; + map.insert(hkey, hval); + } + self.builder = self.builder.default_headers(map); + Ok(self) + } + + pub fn build(self) -> LuaResult { + let client = self.builder.build().map_err(LuaError::external)?; + Ok(NetClient(client)) + } +} + +#[derive(Debug, Clone)] +pub struct NetClient(reqwest::Client); + +impl NetClient { + pub fn request(&self, method: Method, url: U) -> RequestBuilder { + self.0.request(method, url) + } +} + +impl LuaUserData for NetClient {} diff --git a/packages/lib/src/lua/net/config.rs b/packages/lib/src/lua/net/config.rs new file mode 100644 index 0000000..f382a28 --- /dev/null +++ b/packages/lib/src/lua/net/config.rs @@ -0,0 +1,50 @@ +use mlua::prelude::*; + +pub struct ServeConfig<'a> { + pub handle_request: LuaFunction<'a>, + pub handle_web_socket: Option>, +} + +impl<'lua> FromLua<'lua> for ServeConfig<'lua> { + fn from_lua(value: LuaValue<'lua>, lua: &'lua Lua) -> LuaResult { + let message = match &value { + LuaValue::Function(f) => { + return Ok(ServeConfig { + handle_request: f.clone(), + handle_web_socket: None, + }) + } + LuaValue::Table(t) => { + let handle_request: Option = t.raw_get("handleRequest")?; + let handle_web_socket: Option = t.raw_get("handleWebSocket")?; + if handle_request.is_some() || handle_web_socket.is_some() { + return Ok(ServeConfig { + handle_request: handle_request.unwrap_or_else(|| { + let chunk = r#" + return { + status = 426, + body = "Upgrade Required", + headers = { + Upgrade = "websocket", + }, + } + "#; + lua.load(chunk) + .into_function() + .expect("Failed to create default http responder function") + }), + handle_web_socket, + }); + } else { + Some("Missing handleRequest and / or handleWebSocket".to_string()) + } + } + _ => None, + }; + Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "ServeConfig", + message, + }) + } +} diff --git a/packages/lib/src/lua/net/mod.rs b/packages/lib/src/lua/net/mod.rs new file mode 100644 index 0000000..32dbf46 --- /dev/null +++ b/packages/lib/src/lua/net/mod.rs @@ -0,0 +1,7 @@ +mod client; +mod config; +mod ws_server; + +pub use client::{NetClient, NetClientBuilder}; +pub use config::ServeConfig; +pub use ws_server::NetWebSocketServer; diff --git a/packages/lib/src/lua/net/ws_server.rs b/packages/lib/src/lua/net/ws_server.rs new file mode 100644 index 0000000..ad9a6c7 --- /dev/null +++ b/packages/lib/src/lua/net/ws_server.rs @@ -0,0 +1,87 @@ +use std::sync::Arc; + +use mlua::prelude::*; + +use hyper::upgrade::Upgraded; +use hyper_tungstenite::{tungstenite::Message as WsMessage, WebSocketStream}; + +use futures_util::{SinkExt, StreamExt}; +use tokio::sync::Mutex; + +#[derive(Debug, Clone)] +pub struct NetWebSocketServer(Arc>>); + +impl NetWebSocketServer { + pub async fn close(&self) -> LuaResult<()> { + let mut ws = self.0.lock().await; + ws.close(None).await.map_err(LuaError::external)?; + Ok(()) + } + + pub async fn send(&self, msg: WsMessage) -> LuaResult<()> { + let mut ws = self.0.lock().await; + ws.send(msg).await.map_err(LuaError::external)?; + Ok(()) + } + + pub async fn next(&self) -> LuaResult> { + let mut ws = self.0.lock().await; + let item = ws.next().await.transpose(); + item.map_err(LuaError::external) + } + + pub async fn into_proper(self, lua: &'static Lua) -> LuaResult { + // HACK: This creates a new userdata that consumes and proxies this one, + // since there's no great way to implement this in pure async Rust + // and as a plain table without tons of strange lifetime issues + let chunk = r#" + return function(ws) + local proxy = newproxy(true) + local meta = getmetatable(proxy) + meta.__index = { + close = function() + return ws:close() + end, + send = function(...) + return ws:send(...) + end, + next = function() + return ws:next() + end, + } + meta.__iter = function() + return function() + return ws:next() + end + end + return proxy + end + "#; + lua.load(chunk).call_async(self).await + } +} + +impl LuaUserData for NetWebSocketServer { + fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) { + methods.add_async_method("close", |_, this, _: ()| async move { this.close().await }); + methods.add_async_method("send", |_, this, msg: String| async move { + this.send(WsMessage::Text(msg)).await + }); + methods.add_async_method("next", |lua, this, _: ()| async move { + match this.next().await? { + Some(msg) => Ok(match msg { + WsMessage::Binary(bin) => LuaValue::String(lua.create_string(&bin)?), + WsMessage::Text(txt) => LuaValue::String(lua.create_string(&txt)?), + _ => LuaValue::Nil, + }), + None => Ok(LuaValue::Nil), + } + }) + } +} + +impl From> for NetWebSocketServer { + fn from(value: WebSocketStream) -> Self { + Self(Arc::new(Mutex::new(value))) + } +} diff --git a/packages/lib/src/utils/net.rs b/packages/lib/src/utils/net.rs index 2083a05..9625650 100644 --- a/packages/lib/src/utils/net.rs +++ b/packages/lib/src/utils/net.rs @@ -1,21 +1,3 @@ -use mlua::prelude::*; -use reqwest::{IntoUrl, Method, RequestBuilder}; - -#[derive(Clone)] -pub struct NetClient(reqwest::Client); - -impl NetClient { - pub fn new(client: reqwest::Client) -> Self { - Self(client) - } - - pub fn request(&self, method: Method, url: U) -> RequestBuilder { - self.0.request(method, url) - } -} - -impl LuaUserData for NetClient {} - pub fn get_github_owner_and_repo() -> (String, String) { let (github_owner, github_repo) = env!("CARGO_PKG_REPOSITORY") .strip_prefix("https://github.com/") diff --git a/packages/lib/src/utils/table.rs b/packages/lib/src/utils/table.rs index 111b049..d6435df 100644 --- a/packages/lib/src/utils/table.rs +++ b/packages/lib/src/utils/table.rs @@ -2,22 +2,22 @@ use std::future::Future; use mlua::prelude::*; -pub struct TableBuilder<'lua> { - lua: &'lua Lua, - tab: LuaTable<'lua>, +pub struct TableBuilder { + lua: &'static Lua, + tab: LuaTable<'static>, } #[allow(dead_code)] -impl<'lua> TableBuilder<'lua> { - pub fn new(lua: &'lua Lua) -> LuaResult { +impl TableBuilder { + pub fn new(lua: &'static Lua) -> LuaResult { let tab = lua.create_table()?; Ok(Self { lua, tab }) } pub fn with_value(self, key: K, value: V) -> LuaResult where - K: ToLua<'lua>, - V: ToLua<'lua>, + K: ToLua<'static>, + V: ToLua<'static>, { self.tab.raw_set(key, value)?; Ok(self) @@ -25,8 +25,8 @@ impl<'lua> TableBuilder<'lua> { pub fn with_values(self, values: Vec<(K, V)>) -> LuaResult where - K: ToLua<'lua>, - V: ToLua<'lua>, + K: ToLua<'static>, + V: ToLua<'static>, { for (key, value) in values { self.tab.raw_set(key, value)?; @@ -36,7 +36,7 @@ impl<'lua> TableBuilder<'lua> { pub fn with_sequential_value(self, value: V) -> LuaResult where - V: ToLua<'lua>, + V: ToLua<'static>, { self.tab.raw_push(value)?; Ok(self) @@ -44,7 +44,7 @@ impl<'lua> TableBuilder<'lua> { pub fn with_sequential_values(self, values: Vec) -> LuaResult where - V: ToLua<'lua>, + V: ToLua<'static>, { for value in values { self.tab.raw_push(value)?; @@ -59,10 +59,10 @@ impl<'lua> TableBuilder<'lua> { pub fn with_function(self, key: K, func: F) -> LuaResult where - K: ToLua<'lua>, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + Fn(&'lua Lua, A) -> LuaResult, + K: ToLua<'static>, + A: FromLuaMulti<'static>, + R: ToLuaMulti<'static>, + F: 'static + Fn(&'static Lua, A) -> LuaResult, { let f = self.lua.create_function(func)?; self.with_value(key, LuaValue::Function(f)) @@ -70,22 +70,22 @@ impl<'lua> TableBuilder<'lua> { pub fn with_async_function(self, key: K, func: F) -> LuaResult where - K: ToLua<'lua>, - A: FromLuaMulti<'lua>, - R: ToLuaMulti<'lua>, - F: 'static + Fn(&'lua Lua, A) -> FR, - FR: 'lua + Future>, + K: ToLua<'static>, + A: FromLuaMulti<'static>, + R: ToLuaMulti<'static>, + F: 'static + Fn(&'static Lua, A) -> FR, + FR: 'static + Future>, { let f = self.lua.create_async_function(func)?; self.with_value(key, LuaValue::Function(f)) } - pub fn build_readonly(self) -> LuaResult> { + pub fn build_readonly(self) -> LuaResult> { self.tab.set_readonly(true); Ok(self.tab) } - pub fn build(self) -> LuaResult> { + pub fn build(self) -> LuaResult> { Ok(self.tab) } } diff --git a/tests/net/serve.luau b/tests/net/serve.luau index d17032a..a2f1587 100644 --- a/tests/net/serve.luau +++ b/tests/net/serve.luau @@ -1,6 +1,7 @@ +local PORT = 8080 local RESPONSE = "Hello, lune!" -local handle = net.serve(8080, function(request) +local handle = net.serve(PORT, function(request) -- info("Request:", request) -- info("Responding with", RESPONSE) assert(request.path == "/some/path") @@ -10,7 +11,7 @@ local handle = net.serve(8080, function(request) end) local response = - net.request("http://127.0.0.1:8080/some/path?key=param1&key=param2&key2=param3").body + net.request(`http://127.0.0.1:{PORT}/some/path?key=param1&key=param2&key2=param3`).body assert(response == RESPONSE, "Invalid response from server") handle.stop() @@ -21,7 +22,7 @@ task.wait() -- Sending a net request may error if there was -- a connection issue, we should handle that here -local success, response2 = pcall(net.request, "http://127.0.0.1:8080/") +local success, response2 = pcall(net.request, `http://127.0.0.1:{PORT}/`) if not success then local message = tostring(response2) assert( @@ -51,3 +52,27 @@ assert( or string.find(message, "shut down"), "The error message for calling stop twice on the net serve handle should be descriptive" ) + +--[[ + Serve should also take a full config with handler functions + + A server should also be able to start on the previously closed port +]] +local handle2 = net.serve(PORT, { + handleRequest = function() + return RESPONSE + end, + handleWebSocket = function(socket) + socket.close() + end, +}) + +local response3 = net.request(`http://127.0.0.1:{PORT}/`).body +assert(response3 == RESPONSE, "Invalid response from server") + +-- TODO: Test web sockets properly when we have a web socket client + +-- Stop the server and yield once more to end the test + +handle2.stop() +task.wait()