From 962d89fd40ecdb80a5cf03f5201bfcbcf6200943 Mon Sep 17 00:00:00 2001 From: Filip Tibell Date: Sat, 11 Feb 2023 15:09:06 +0100 Subject: [PATCH] Accept & upgrade websocket connections --- Cargo.lock | 170 +++++++++++++++++++++++- packages/lib/Cargo.toml | 2 + packages/lib/src/globals/net.rs | 221 +++++++++++++++++++------------- 3 files changed, 301 insertions(+), 92 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 82742ed..2fcfb99 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -23,6 +23,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + [[package]] name = "base64" version = "0.21.0" @@ -41,6 +47,15 @@ version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" +[[package]] +name = "block-buffer" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cce20737498f97b993470a6e536b8523f0af7892a4f928cceb1ac5e52ebe7e" +dependencies = [ + "generic-array", +] + [[package]] name = "bstr" version = "0.2.17" @@ -62,6 +77,12 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be0fdd54b507df8f22012890aadd099979befdba27713c767993f8380112ca7c" +[[package]] +name = "byteorder" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" + [[package]] name = "bytes" version = "1.4.0" @@ -136,6 +157,25 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" +[[package]] +name = "cpufeatures" +version = "0.2.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28d997bd5e24a5928dd43e46dc529867e207907fe0b239c3477d924f7f2ca320" +dependencies = [ + "libc", +] + +[[package]] +name = "crypto-common" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +dependencies = [ + "generic-array", + "typenum", +] + [[package]] name = "derive_more" version = "0.99.17" @@ -161,6 +201,16 @@ dependencies = [ "zeroize", ] +[[package]] +name = "digest" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8168378f4e5023e7218c89c891c0fd8ecdb5e5e4f18cb78f38cf245dd021e76f" +dependencies = [ + "block-buffer", + "crypto-common", +] + [[package]] name = "directories" version = "4.0.1" @@ -324,12 +374,23 @@ checksum = "9c1d6de3acfef38d2be4b1f543f553131788603495be83da675e180c8d6b7bd1" dependencies = [ "futures-core", "futures-macro", + "futures-sink", "futures-task", "pin-project-lite", "pin-utils", "slab", ] +[[package]] +name = "generic-array" +version = "0.14.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bff49e947297f3312447abdca79f45f4738097cc82b06e72054d2223f601f1b9" +dependencies = [ + "typenum", + "version_check", +] + [[package]] name = "getrandom" version = "0.2.8" @@ -458,6 +519,19 @@ dependencies = [ "tokio-rustls", ] +[[package]] +name = "hyper-tungstenite" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "880b8b1c98a5ec2a505c7c90db6d3f6f1f480af5655d9c5b55facc9382a5a5b5" +dependencies = [ + "hyper", + "pin-project", + "tokio", + "tokio-tungstenite", + "tungstenite", +] + [[package]] name = "idna" version = "0.3.0" @@ -607,7 +681,9 @@ dependencies = [ "console", "dialoguer", "directories", + "futures-util", "hyper", + "hyper-tungstenite", "lazy_static", "mlua", "os_str_bytes", @@ -799,6 +875,12 @@ version = "0.3.26" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ac9a59f73473f1b8d852421e59e64809f025994837ef743615c6d0c5b305160" +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + [[package]] name = "proc-macro-error" version = "1.0.4" @@ -847,6 +929,36 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + [[package]] name = "redox_syscall" version = "0.2.16" @@ -899,7 +1011,7 @@ version = "0.11.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "21eed90ec8570952d53b772ecf8f206aa1ec9a3d76b2521c56c42973f2d91ee9" dependencies = [ - "base64", + "base64 0.21.0", "bytes", "encoding_rs", "futures-core", @@ -994,7 +1106,7 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d194b56d58803a43635bdc398cd17e383d6f71f9182b9a192c127ca42494a59b" dependencies = [ - "base64", + "base64 0.21.0", ] [[package]] @@ -1068,6 +1180,17 @@ dependencies = [ "serde", ] +[[package]] +name = "sha1" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f04293dc80c3993519f2d7f6f511707ee7094fe0c6d3406feb330cdb3540eba3" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "shell-words" version = "1.1.0" @@ -1240,6 +1363,18 @@ dependencies = [ "webpki", ] +[[package]] +name = "tokio-tungstenite" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54319c93411147bced34cb5609a80e0a8e44c5999c93903a81cd866630ec0bfd" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.4" @@ -1286,6 +1421,31 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" +[[package]] +name = "tungstenite" +version = "0.18.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30ee6ab729cd4cf0fd55218530c4522ed30b7b6081752839b68fcec8d0960788" +dependencies = [ + "base64 0.13.1", + "byteorder", + "bytes", + "http", + "httparse", + "log", + "rand", + "sha1", + "thiserror", + "url", + "utf-8", +] + +[[package]] +name = "typenum" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" + [[package]] name = "unicode-bidi" version = "0.3.10" @@ -1330,6 +1490,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "version_check" version = "0.9.4" diff --git a/packages/lib/Cargo.toml b/packages/lib/Cargo.toml index bf6b16f..f4e18cb 100644 --- a/packages/lib/Cargo.toml +++ b/packages/lib/Cargo.toml @@ -24,10 +24,12 @@ reqwest.workspace = true dialoguer = "0.10.3" directories = "4.0.1" +futures-util = "0.3.26" pin-project = "1.0.12" os_str_bytes = "6.4.1" hyper = { version = "0.14.24", features = ["full"] } +hyper-tungstenite = { version = "0.9.0" } mlua = { version = "0.8.7", features = ["luau", "async", "serialize"] } [dev-dependencies] diff --git a/packages/lib/src/globals/net.rs b/packages/lib/src/globals/net.rs index 5cf454f..39132f7 100644 --- a/packages/lib/src/globals/net.rs +++ b/packages/lib/src/globals/net.rs @@ -10,6 +10,12 @@ use mlua::prelude::*; use hyper::{body::to_bytes, http::HeaderValue, server::conn::AddrStream, service::Service}; use hyper::{Body, HeaderMap, Request, Response, Server}; +use hyper_tungstenite::{ + is_upgrade_request as is_ws_upgrade_request, tungstenite::Message as WsMessage, + upgrade as ws_upgrade, +}; + +use futures_util::{SinkExt, StreamExt}; use reqwest::{ClientBuilder, Method}; use tokio::{ sync::mpsc::{self, Sender}, @@ -216,107 +222,142 @@ impl Service> for NetService { Poll::Ready(Ok(())) } - fn call(&mut self, req: Request) -> Self::Future { + fn call(&mut self, mut req: Request) -> Self::Future { let lua = self.0; - let key1 = self.1.clone(); - let _key2 = self.2.clone(); // TODO: This is the web socket callback - let (parts, body) = req.into_parts(); - Box::pin(async move { - // Convert request body into bytes, extract handler - // function & lune message sender to use later - let bytes = to_bytes(body).await.map_err(LuaError::external)?; - let handler: LuaFunction = lua.registry_value(&key1)?; - let sender = lua - .app_data_ref::>>() - .unwrap() - .upgrade() - .unwrap(); - // Create a readonly table for the request query params - let query_params = TableBuilder::new(lua)? - .with_values( - parts - .uri - .query() - .unwrap_or_default() - .split('&') - .filter_map(|q| q.split_once('=')) - .collect(), - )? - .build_readonly()?; - // Do the same for headers - let header_map = TableBuilder::new(lua)? - .with_values( - parts - .headers - .iter() - .map(|(name, value)| { - (name.to_string(), value.to_str().unwrap().to_string()) - }) - .collect(), - )? - .build_readonly()?; - // Create a readonly table with request info to pass to the handler - let request = TableBuilder::new(lua)? - .with_value("path", parts.uri.path())? - .with_value("query", query_params)? - .with_value("method", parts.method.as_str())? - .with_value("headers", header_map)? - .with_value("body", lua.create_string(&bytes)?)? - .build_readonly()?; - match handler.call_async(request).await { - // Plain strings from the handler are plaintext responses - Ok(LuaValue::String(s)) => Ok(Response::builder() - .status(200) - .header("Content-Type", "text/plain") - .body(Body::from(s.as_bytes().to_vec())) - .unwrap()), - // Tables are more detailed responses with potential status, headers, body - Ok(LuaValue::Table(t)) => { - let status = t.get::<_, Option>("status")?.unwrap_or(200); - let mut resp = Response::builder().status(status); - - if let Some(headers) = t.get::<_, Option>("headers")? { - for pair in headers.pairs::() { - let (h, v) = pair?; - resp = resp.header(&h, v.as_bytes()); + if self.2.is_some() && is_ws_upgrade_request(&req) { + // Websocket request + websocket handler exists, + // we should upgrade this connection to a websocket + // and then pass a socket object to our lua handler + let kopt = self.2.clone(); + let key = kopt.as_ref().as_ref().unwrap(); + let handler: LuaFunction = lua.registry_value(key).expect("Missing websocket handler"); + let (response, ws) = ws_upgrade(&mut req, None).expect("Failed to upgrade websocket"); + task::spawn_local(async move { + if let Ok(mut websocket) = ws.await { + // TODO: Create lua userdata websocket object + // with methods for interacting with the websocket + // TODO: Start waiting for messages when we know + // for sure that we have gotten a message handler + // and move the following logic into there instead + while let Some(message) = websocket.next().await { + // Create lua strings from websocket messages + if let Some(handler_str) = match message.map_err(LuaError::external)? { + WsMessage::Text(msg) => Some(lua.create_string(&msg)?), + WsMessage::Binary(msg) => Some(lua.create_string(&msg)?), + // Tungstenite takes care of these messages + WsMessage::Ping(_) => None, + WsMessage::Pong(_) => None, + WsMessage::Close(_) => None, + WsMessage::Frame(_) => None, + } { + // TODO: Call whatever lua handler we have registered, with our message string } } - - let body = t - .get::<_, Option>("body")? - .map(|b| Body::from(b.as_bytes().to_vec())) - .unwrap_or_else(Body::empty); - - Ok(resp.body(body).unwrap()) } - // If the handler returns an error, generate a 5xx response - Err(err) => { - sender - .send(LuneMessage::LuaError(err.to_lua_err())) - .await - .map_err(LuaError::external)?; - Ok(Response::builder() - .status(500) - .body(Body::from("Internal Server Error")) - .unwrap()) - } - // If the handler returns a value that is of an invalid type, - // this should also be an error, so generate a 5xx response - Ok(value) => { - sender + Ok::<_, LuaError>(()) + }); + Box::pin(async move { Ok(response) }) + } else { + // Normal http request or no websocket handler exists, call the http request handler + let key = self.1.clone(); + let (parts, body) = req.into_parts(); + Box::pin(async move { + // Convert request body into bytes, extract handler + // function & lune message sender to use later + let bytes = to_bytes(body).await.map_err(LuaError::external)?; + let handler: LuaFunction = lua.registry_value(&key)?; + let sender = lua + .app_data_ref::>>() + .unwrap() + .upgrade() + .unwrap(); + // Create a readonly table for the request query params + let query_params = TableBuilder::new(lua)? + .with_values( + parts + .uri + .query() + .unwrap_or_default() + .split('&') + .filter_map(|q| q.split_once('=')) + .collect(), + )? + .build_readonly()?; + // Do the same for headers + let header_map = TableBuilder::new(lua)? + .with_values( + parts + .headers + .iter() + .map(|(name, value)| { + (name.to_string(), value.to_str().unwrap().to_string()) + }) + .collect(), + )? + .build_readonly()?; + // Create a readonly table with request info to pass to the handler + let request = TableBuilder::new(lua)? + .with_value("path", parts.uri.path())? + .with_value("query", query_params)? + .with_value("method", parts.method.as_str())? + .with_value("headers", header_map)? + .with_value("body", lua.create_string(&bytes)?)? + .build_readonly()?; + match handler.call_async(request).await { + // Plain strings from the handler are plaintext responses + Ok(LuaValue::String(s)) => Ok(Response::builder() + .status(200) + .header("Content-Type", "text/plain") + .body(Body::from(s.as_bytes().to_vec())) + .unwrap()), + // Tables are more detailed responses with potential status, headers, body + Ok(LuaValue::Table(t)) => { + let status = t.get::<_, Option>("status")?.unwrap_or(200); + let mut resp = Response::builder().status(status); + + if let Some(headers) = t.get::<_, Option>("headers")? { + for pair in headers.pairs::() { + let (h, v) = pair?; + resp = resp.header(&h, v.as_bytes()); + } + } + + let body = t + .get::<_, Option>("body")? + .map(|b| Body::from(b.as_bytes().to_vec())) + .unwrap_or_else(Body::empty); + + Ok(resp.body(body).unwrap()) + } + // If the handler returns an error, generate a 5xx response + Err(err) => { + sender + .send(LuneMessage::LuaError(err.to_lua_err())) + .await + .map_err(LuaError::external)?; + Ok(Response::builder() + .status(500) + .body(Body::from("Internal Server Error")) + .unwrap()) + } + // If the handler returns a value that is of an invalid type, + // this should also be an error, so generate a 5xx response + Ok(value) => { + sender .send(LuneMessage::LuaError(LuaError::RuntimeError(format!( "Expected net serve handler to return a value of type 'string' or 'table', got '{}'", value.type_name() )))) .await .map_err(LuaError::external)?; - Ok(Response::builder() - .status(500) - .body(Body::from("Internal Server Error")) - .unwrap()) + Ok(Response::builder() + .status(500) + .body(Body::from("Internal Server Error")) + .unwrap()) + } } - } - }) + }) + } } }