From dfbe8a55d069c909b5b83d9fd4794c3c5d6ea995 Mon Sep 17 00:00:00 2001 From: Filip Tibell Date: Mon, 12 Feb 2024 16:58:24 +0100 Subject: [PATCH] Reimplement net websockets --- Cargo.lock | 258 ++++++++++++++++++++++------- Cargo.toml | 10 +- src/lune/builtins/net/mod.rs | 13 +- src/lune/builtins/net/websocket.rs | 190 +++++++++++++++++++++ src/tests.rs | 1 + tests/net/socket/basic.luau | 28 ++++ 6 files changed, 432 insertions(+), 68 deletions(-) create mode 100644 src/lune/builtins/net/websocket.rs create mode 100644 tests/net/socket/basic.luau diff --git a/Cargo.lock b/Cargo.lock index fb2887d..189f4e6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -686,16 +686,26 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d" [[package]] -name = "env_logger" -version = "0.10.2" +name = "env_filter" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4cd405aab171cb85d6735e5c8d9db038c17d3ca007a4d2c25f337935c3d90580" +checksum = "a009aa4810eb158359dda09d0c87378e4bbb89b5a801f016885a4707ba24f7ea" dependencies = [ - "humantime", - "is-terminal", "log", "regex", - "termcolor", +] + +[[package]] +name = "env_logger" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05e7cf40684ae96ade6232ed84582f40ce0a66efcd43a5117aef610534f8e0b8" +dependencies = [ + "anstream", + "anstyle", + "env_filter", + "humantime", + "log", ] [[package]] @@ -946,7 +956,26 @@ dependencies = [ "futures-core", "futures-sink", "futures-util", - "http", + "http 0.2.11", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + +[[package]] +name = "h2" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31d030e59af851932b72ceebadf4a2b5986dba4c3b99dd2493f8273a0f151943" +dependencies = [ + "bytes", + "fnv", + "futures-core", + "futures-sink", + "futures-util", + "http 1.0.0", "indexmap", "slab", "tokio", @@ -992,6 +1021,17 @@ dependencies = [ "itoa", ] +[[package]] +name = "http" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b32afd38673a8016f7c9ae69e5af41a58f81b1d31689040f2f1959594ce194ea" +dependencies = [ + "bytes", + "fnv", + "itoa", +] + [[package]] name = "http-body" version = "0.4.6" @@ -999,7 +1039,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2" dependencies = [ "bytes", - "http", + "http 0.2.11", + "pin-project-lite", +] + +[[package]] +name = "http-body" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643" +dependencies = [ + "bytes", + "http 1.0.0", +] + +[[package]] +name = "http-body-util" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "41cb79eb393015dadd30fc252023adb0b2400a0caee0fa2a077e6e21a551e840" +dependencies = [ + "bytes", + "futures-util", + "http 1.0.0", + "http-body 1.0.0", "pin-project-lite", ] @@ -1031,9 +1094,9 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2", - "http", - "http-body", + "h2 0.3.24", + "http 0.2.11", + "http-body 0.4.6", "httparse", "httpdate", "itoa", @@ -1045,6 +1108,26 @@ dependencies = [ "want", ] +[[package]] +name = "hyper" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb5aa53871fc917b1a9ed87b683a5d86db645e23acb32c2e0785a353e522fb75" +dependencies = [ + "bytes", + "futures-channel", + "futures-util", + "h2 0.4.2", + "http 1.0.0", + "http-body 1.0.0", + "httparse", + "httpdate", + "itoa", + "pin-project-lite", + "tokio", + "want", +] + [[package]] name = "hyper-rustls" version = "0.24.2" @@ -1052,26 +1135,44 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" dependencies = [ "futures-util", - "http", - "hyper", - "rustls", + "http 0.2.11", + "hyper 0.14.28", + "rustls 0.21.10", "tokio", - "tokio-rustls", + "tokio-rustls 0.24.1", ] [[package]] name = "hyper-tungstenite" -version = "0.11.1" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7cc7dcb1ab67cd336f468a12491765672e61a3b6b148634dbfe2fe8acd3fe7d9" +checksum = "7a343d17fe7885302ed7252767dc7bb83609a874b6ff581142241ec4b73957ad" dependencies = [ - "hyper", + "http-body-util", + "hyper 1.1.0", + "hyper-util", "pin-project-lite", "tokio", "tokio-tungstenite", "tungstenite", ] +[[package]] +name = "hyper-util" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca38ef113da30126bbff9cd1705f9273e15d45498615d138b0c20279ac7a76aa" +dependencies = [ + "bytes", + "futures-util", + "http 1.0.0", + "http-body 1.0.0", + "hyper 1.1.0", + "pin-project-lite", + "socket2", + "tokio", +] + [[package]] name = "iana-time-zone" version = "0.1.60" @@ -1141,17 +1242,6 @@ version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" -[[package]] -name = "is-terminal" -version = "0.4.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f23ff5ef2b80d608d61efee834934d862cd92461afc0560dedf493e4c033738b" -dependencies = [ - "hermit-abi", - "libc", - "windows-sys 0.52.0", -] - [[package]] name = "itertools" version = "0.12.1" @@ -1266,7 +1356,7 @@ dependencies = [ "env_logger", "futures-util", "glam", - "hyper", + "hyper 1.1.0", "hyper-tungstenite", "include_dir", "itertools", @@ -1510,9 +1600,9 @@ dependencies = [ [[package]] name = "os_str_bytes" -version = "6.6.1" +version = "7.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2355d85b9a3786f481747ced0e0ff2ba35213a1f9bd406ed906554d7af805a1" +checksum = "7ac44c994af577c799b1b4bd80dc214701e349873ad894d6cdf96f4f7526e0b9" dependencies = [ "memchr", ] @@ -1927,10 +2017,10 @@ dependencies = [ "encoding_rs", "futures-core", "futures-util", - "h2", - "http", - "http-body", - "hyper", + "h2 0.3.24", + "http 0.2.11", + "http-body 0.4.6", + "hyper 0.14.28", "hyper-rustls", "ipnet", "js-sys", @@ -1939,7 +2029,7 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls", + "rustls 0.21.10", "rustls-pemfile", "serde", "serde_json", @@ -1947,13 +2037,13 @@ dependencies = [ "sync_wrapper", "system-configuration", "tokio", - "tokio-rustls", + "tokio-rustls 0.24.1", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", "web-sys", - "webpki-roots", + "webpki-roots 0.25.4", "winreg 0.50.0", ] @@ -2056,10 +2146,24 @@ checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba" dependencies = [ "log", "ring", - "rustls-webpki", + "rustls-webpki 0.101.7", "sct", ] +[[package]] +name = "rustls" +version = "0.22.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e87c9956bd9807afa1f77e0f7594af32566e830e088a5576d27c5b6f30f49d41" +dependencies = [ + "log", + "ring", + "rustls-pki-types", + "rustls-webpki 0.102.2", + "subtle", + "zeroize", +] + [[package]] name = "rustls-pemfile" version = "1.0.4" @@ -2069,6 +2173,12 @@ dependencies = [ "base64 0.21.7", ] +[[package]] +name = "rustls-pki-types" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a716eb65e3158e90e17cd93d855216e27bde02745ab842f2cab4a39dba1bacf" + [[package]] name = "rustls-webpki" version = "0.101.7" @@ -2079,6 +2189,17 @@ dependencies = [ "untrusted", ] +[[package]] +name = "rustls-webpki" +version = "0.102.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "faaa0a62740bedb9b2ef5afa303da42764c012f743917351dc9a237ea1663610" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "rustyline" version = "13.0.0" @@ -2386,6 +2507,12 @@ version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5ee073c9e4cd00e28217186dbe12796d692868f432bf2e97ee73bed0c56dfa01" +[[package]] +name = "subtle" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" + [[package]] name = "syn" version = "1.0.109" @@ -2447,15 +2574,6 @@ dependencies = [ "windows-sys 0.52.0", ] -[[package]] -name = "termcolor" -version = "1.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" -dependencies = [ - "winapi-util", -] - [[package]] name = "thiserror" version = "1.0.57" @@ -2607,23 +2725,35 @@ version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" dependencies = [ - "rustls", + "rustls 0.21.10", + "tokio", +] + +[[package]] +name = "tokio-rustls" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "775e0c0f0adb3a2f22a00c4745d728b479985fc15ee7ca6a2608388c5569860f" +dependencies = [ + "rustls 0.22.2", + "rustls-pki-types", "tokio", ] [[package]] name = "tokio-tungstenite" -version = "0.20.1" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "212d5dcb2a1ce06d81107c3d0ffa3121fe974b73f068c8282cb1c32328113b6c" +checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" dependencies = [ "futures-util", "log", - "rustls", + "rustls 0.22.2", + "rustls-pki-types", "tokio", - "tokio-rustls", + "tokio-rustls 0.25.0", "tungstenite", - "webpki-roots", + "webpki-roots 0.26.1", ] [[package]] @@ -2750,18 +2880,19 @@ checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" [[package]] name = "tungstenite" -version = "0.20.1" +version = "0.21.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e3dac10fd62eaf6617d3a904ae222845979aec67c615d1c842b4002c7666fb9" +checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1" dependencies = [ "byteorder 1.5.0", "bytes", "data-encoding", - "http", + "http 1.0.0", "httparse", "log", "rand", - "rustls", + "rustls 0.22.2", + "rustls-pki-types", "sha1 0.10.6", "thiserror", "url", @@ -2983,6 +3114,15 @@ version = "0.25.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" +[[package]] +name = "webpki-roots" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3de34ae270483955a94f4b21bdaaeb83d508bb84a01435f393818edb0012009" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "winapi" version = "0.3.9" diff --git a/Cargo.toml b/Cargo.toml index 4ce4a86..12fef19 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -82,7 +82,7 @@ urlencoding = "2.1" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } tokio = { version = "1.24", features = ["full", "tracing"] } -os_str_bytes = { version = "6.4", features = ["conversions"] } +os_str_bytes = { version = "7.0", features = ["conversions"] } mlua-luau-scheduler = { git = "https://github.com/lune-org/mlua-luau-scheduler", rev = "7c59d0c722215693839b2790a918fcc872e5ab93" } mlua = { git = "https://github.com/mlua-rs/mlua.git", rev = "1754226c7440ec6c194d2d678ec083b621d46ceb", features = [ @@ -108,12 +108,12 @@ toml = { version = "0.8", features = ["preserve_order"] } ### NET -hyper = { version = "0.14", features = ["full"] } -hyper-tungstenite = { version = "0.11" } +hyper = { version = "1.1", features = ["full"] } +hyper-tungstenite = { version = "0.13" } reqwest = { version = "0.11", default-features = false, features = [ "rustls-tls", ] } -tokio-tungstenite = { version = "0.20", features = ["rustls-tls-webpki-roots"] } +tokio-tungstenite = { version = "0.21", features = ["rustls-tls-webpki-roots"] } ### DATETIME chrono = "0.4" @@ -122,7 +122,7 @@ chrono_lc = "0.1" ### CLI anyhow = { optional = true, version = "1.0" } -env_logger = { optional = true, version = "0.10" } +env_logger = { optional = true, version = "0.11" } itertools = { optional = true, version = "0.12" } clap = { optional = true, version = "4.1", features = ["derive"] } include_dir = { optional = true, version = "0.7", features = ["glob"] } diff --git a/src/lune/builtins/net/mod.rs b/src/lune/builtins/net/mod.rs index c5cb035..2567a38 100644 --- a/src/lune/builtins/net/mod.rs +++ b/src/lune/builtins/net/mod.rs @@ -4,10 +4,14 @@ use mlua::prelude::*; mod config; mod util; +mod websocket; use crate::lune::util::TableBuilder; -use self::config::{RequestConfig, ServeConfig}; +use self::{ + config::{RequestConfig, ServeConfig}, + websocket::NetWebSocket, +}; use super::serde::encode_decode::{EncodeDecodeConfig, EncodeDecodeFormat}; @@ -43,12 +47,13 @@ fn net_json_decode<'lua>(lua: &'lua Lua, json: LuaString<'lua>) -> LuaResult(lua: &'lua Lua, config: RequestConfig) -> LuaResult> { +async fn net_request(lua: &Lua, config: RequestConfig) -> LuaResult { unimplemented!() } -async fn net_socket<'lua>(lua: &'lua Lua, url: String) -> LuaResult> { - unimplemented!() +async fn net_socket(lua: &Lua, url: String) -> LuaResult { + let (ws, _) = tokio_tungstenite::connect_async(url).await.into_lua_err()?; + NetWebSocket::new(ws).into_lua_table(lua) } async fn net_serve<'lua>( diff --git a/src/lune/builtins/net/websocket.rs b/src/lune/builtins/net/websocket.rs new file mode 100644 index 0000000..fb7233c --- /dev/null +++ b/src/lune/builtins/net/websocket.rs @@ -0,0 +1,190 @@ +use std::sync::{ + atomic::{AtomicBool, AtomicU16, Ordering}, + Arc, +}; + +use mlua::prelude::*; + +use futures_util::{ + stream::{SplitSink, SplitStream}, + SinkExt, StreamExt, +}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::Mutex as AsyncMutex, +}; + +use hyper_tungstenite::{ + tungstenite::{ + protocol::{frame::coding::CloseCode as WsCloseCode, CloseFrame as WsCloseFrame}, + Message as WsMessage, + }, + WebSocketStream, +}; + +use crate::lune::util::TableBuilder; + +// Wrapper implementation for compatibility and changing colon syntax to dot syntax +const WEB_SOCKET_IMPL_LUA: &str = r#" +return freeze(setmetatable({ + close = function(...) + return websocket:close(...) + end, + send = function(...) + return websocket:send(...) + end, + next = function(...) + return websocket:next(...) + end, +}, { + __index = function(self, key) + if key == "closeCode" then + return websocket.closeCode + end + end, +})) +"#; + +#[derive(Debug)] +pub struct NetWebSocket { + close_code_exists: Arc, + close_code_value: Arc, + read_stream: Arc>>>, + write_stream: Arc, WsMessage>>>, +} + +impl Clone for NetWebSocket { + fn clone(&self) -> Self { + Self { + close_code_exists: Arc::clone(&self.close_code_exists), + close_code_value: Arc::clone(&self.close_code_value), + read_stream: Arc::clone(&self.read_stream), + write_stream: Arc::clone(&self.write_stream), + } + } +} + +impl NetWebSocket +where + T: AsyncRead + AsyncWrite + Unpin + 'static, +{ + pub fn new(value: WebSocketStream) -> Self { + let (write, read) = value.split(); + + Self { + close_code_exists: Arc::new(AtomicBool::new(false)), + close_code_value: Arc::new(AtomicU16::new(0)), + read_stream: Arc::new(AsyncMutex::new(read)), + write_stream: Arc::new(AsyncMutex::new(write)), + } + } + + fn get_close_code(&self) -> Option { + if self.close_code_exists.load(Ordering::Relaxed) { + Some(self.close_code_value.load(Ordering::Relaxed)) + } else { + None + } + } + + fn set_close_code(&self, code: u16) { + self.close_code_exists.store(true, Ordering::Relaxed); + self.close_code_value.store(code, Ordering::Relaxed); + } + + pub async fn send(&self, msg: WsMessage) -> LuaResult<()> { + let mut ws = self.write_stream.lock().await; + ws.send(msg).await.into_lua_err() + } + + pub async fn next(&self) -> LuaResult> { + let mut ws = self.read_stream.lock().await; + ws.next().await.transpose().into_lua_err() + } + + pub async fn close(&self, code: Option) -> LuaResult<()> { + if self.close_code_exists.load(Ordering::Relaxed) { + return Err(LuaError::runtime("Socket has already been closed")); + } + + self.send(WsMessage::Close(Some(WsCloseFrame { + code: match code { + Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code), + Some(code) => { + return Err(LuaError::runtime(format!( + "Close code must be between 1000 and 4999, got {code}" + ))) + } + None => WsCloseCode::Normal, + }, + reason: "".into(), + }))) + .await?; + + let mut ws = self.write_stream.lock().await; + ws.close().await.into_lua_err() + } + + pub fn into_lua_table(self, lua: &Lua) -> LuaResult { + let setmetatable = lua.globals().get::<_, LuaFunction>("setmetatable")?; + let table_freeze = lua + .globals() + .get::<_, LuaTable>("table")? + .get::<_, LuaFunction>("freeze")?; + + let env = TableBuilder::new(lua)? + .with_value("websocket", self.clone())? + .with_value("setmetatable", setmetatable)? + .with_value("freeze", table_freeze)? + .build_readonly()?; + + lua.load(WEB_SOCKET_IMPL_LUA) + .set_name("websocket") + .set_environment(env) + .eval() + } +} + +impl LuaUserData for NetWebSocket +where + T: AsyncRead + AsyncWrite + Unpin + 'static, +{ + fn add_fields<'lua, F: LuaUserDataFields<'lua, Self>>(fields: &mut F) { + fields.add_field_method_get("closeCode", |_, this| Ok(this.get_close_code())); + } + + fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) { + methods.add_async_method("close", |lua, this, code: Option| async move { + this.close(code).await + }); + + methods.add_async_method( + "send", + |_, this, (string, as_binary): (LuaString, Option)| async move { + this.send(if as_binary.unwrap_or_default() { + WsMessage::Binary(string.as_bytes().to_vec()) + } else { + let s = string.to_str().into_lua_err()?; + WsMessage::Text(s.to_string()) + }) + .await + }, + ); + + methods.add_async_method("next", |lua, this, _: ()| async move { + let msg = this.next().await?; + + if let Some(WsMessage::Close(Some(frame))) = msg.as_ref() { + this.set_close_code(frame.code.into()); + } + + Ok(match msg { + Some(WsMessage::Binary(bin)) => LuaValue::String(lua.create_string(bin)?), + Some(WsMessage::Text(txt)) => LuaValue::String(lua.create_string(txt)?), + Some(WsMessage::Close(_)) | None => LuaValue::Nil, + // Ignore ping/pong/frame messages, they are handled by tungstenite + msg => unreachable!("Unhandled message: {:?}", msg), + }) + }); + } +} diff --git a/src/tests.rs b/src/tests.rs index 8f08cde..048343d 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -68,6 +68,7 @@ create_tests! { net_url_decode: "net/url/decode", net_serve_requests: "net/serve/requests", net_serve_websockets: "net/serve/websockets", + net_socket_basic: "net/socket/basic", net_socket_wss: "net/socket/wss", net_socket_wss_rw: "net/socket/wss_rw", diff --git a/tests/net/socket/basic.luau b/tests/net/socket/basic.luau new file mode 100644 index 0000000..2eb4a3b --- /dev/null +++ b/tests/net/socket/basic.luau @@ -0,0 +1,28 @@ +local net = require("@lune/net") + +-- We're going to use Discord's WebSocket gateway server for testing +local socket = net.socket("wss://gateway.discord.gg/?v=10&encoding=json") + +assert(type(socket.next) == "function", "next must be a function") +assert(type(socket.send) == "function", "send must be a function") +assert(type(socket.close) == "function", "close must be a function") + +-- Request to close the socket +socket.close() + +-- Drain remaining messages, until we got our close message +while socket.next() do +end + +assert(type(socket.closeCode) == "number", "closeCode should exist after closing") +assert(socket.closeCode == 1000, "closeCode should be 1000 after closing") + +local success, message = pcall(function() + socket.send("Hello, world!") +end) + +assert(not success, "send should fail after closing") +assert( + string.find(tostring(message), "closed") or string.find(tostring(message), "closing"), + "send should fail with a message that the socket was closed" +)