Reimplement net websockets

This commit is contained in:
Filip Tibell 2024-02-12 16:58:24 +01:00
parent 65def158d2
commit dfbe8a55d0
No known key found for this signature in database
6 changed files with 432 additions and 68 deletions

258
Cargo.lock generated
View file

@ -686,16 +686,26 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c34f04666d835ff5d62e058c3995147c06f42fe86ff053337632bca83e42702d"
[[package]]
name = "env_logger"
version = "0.10.2"
name = "env_filter"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cd405aab171cb85d6735e5c8d9db038c17d3ca007a4d2c25f337935c3d90580"
checksum = "a009aa4810eb158359dda09d0c87378e4bbb89b5a801f016885a4707ba24f7ea"
dependencies = [
"humantime",
"is-terminal",
"log",
"regex",
"termcolor",
]
[[package]]
name = "env_logger"
version = "0.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05e7cf40684ae96ade6232ed84582f40ce0a66efcd43a5117aef610534f8e0b8"
dependencies = [
"anstream",
"anstyle",
"env_filter",
"humantime",
"log",
]
[[package]]
@ -946,7 +956,26 @@ dependencies = [
"futures-core",
"futures-sink",
"futures-util",
"http",
"http 0.2.11",
"indexmap",
"slab",
"tokio",
"tokio-util",
"tracing",
]
[[package]]
name = "h2"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "31d030e59af851932b72ceebadf4a2b5986dba4c3b99dd2493f8273a0f151943"
dependencies = [
"bytes",
"fnv",
"futures-core",
"futures-sink",
"futures-util",
"http 1.0.0",
"indexmap",
"slab",
"tokio",
@ -992,6 +1021,17 @@ dependencies = [
"itoa",
]
[[package]]
name = "http"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b32afd38673a8016f7c9ae69e5af41a58f81b1d31689040f2f1959594ce194ea"
dependencies = [
"bytes",
"fnv",
"itoa",
]
[[package]]
name = "http-body"
version = "0.4.6"
@ -999,7 +1039,30 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2"
dependencies = [
"bytes",
"http",
"http 0.2.11",
"pin-project-lite",
]
[[package]]
name = "http-body"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1cac85db508abc24a2e48553ba12a996e87244a0395ce011e62b37158745d643"
dependencies = [
"bytes",
"http 1.0.0",
]
[[package]]
name = "http-body-util"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41cb79eb393015dadd30fc252023adb0b2400a0caee0fa2a077e6e21a551e840"
dependencies = [
"bytes",
"futures-util",
"http 1.0.0",
"http-body 1.0.0",
"pin-project-lite",
]
@ -1031,9 +1094,9 @@ dependencies = [
"futures-channel",
"futures-core",
"futures-util",
"h2",
"http",
"http-body",
"h2 0.3.24",
"http 0.2.11",
"http-body 0.4.6",
"httparse",
"httpdate",
"itoa",
@ -1045,6 +1108,26 @@ dependencies = [
"want",
]
[[package]]
name = "hyper"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fb5aa53871fc917b1a9ed87b683a5d86db645e23acb32c2e0785a353e522fb75"
dependencies = [
"bytes",
"futures-channel",
"futures-util",
"h2 0.4.2",
"http 1.0.0",
"http-body 1.0.0",
"httparse",
"httpdate",
"itoa",
"pin-project-lite",
"tokio",
"want",
]
[[package]]
name = "hyper-rustls"
version = "0.24.2"
@ -1052,26 +1135,44 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590"
dependencies = [
"futures-util",
"http",
"hyper",
"rustls",
"http 0.2.11",
"hyper 0.14.28",
"rustls 0.21.10",
"tokio",
"tokio-rustls",
"tokio-rustls 0.24.1",
]
[[package]]
name = "hyper-tungstenite"
version = "0.11.1"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7cc7dcb1ab67cd336f468a12491765672e61a3b6b148634dbfe2fe8acd3fe7d9"
checksum = "7a343d17fe7885302ed7252767dc7bb83609a874b6ff581142241ec4b73957ad"
dependencies = [
"hyper",
"http-body-util",
"hyper 1.1.0",
"hyper-util",
"pin-project-lite",
"tokio",
"tokio-tungstenite",
"tungstenite",
]
[[package]]
name = "hyper-util"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ca38ef113da30126bbff9cd1705f9273e15d45498615d138b0c20279ac7a76aa"
dependencies = [
"bytes",
"futures-util",
"http 1.0.0",
"http-body 1.0.0",
"hyper 1.1.0",
"pin-project-lite",
"socket2",
"tokio",
]
[[package]]
name = "iana-time-zone"
version = "0.1.60"
@ -1141,17 +1242,6 @@ version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3"
[[package]]
name = "is-terminal"
version = "0.4.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f23ff5ef2b80d608d61efee834934d862cd92461afc0560dedf493e4c033738b"
dependencies = [
"hermit-abi",
"libc",
"windows-sys 0.52.0",
]
[[package]]
name = "itertools"
version = "0.12.1"
@ -1266,7 +1356,7 @@ dependencies = [
"env_logger",
"futures-util",
"glam",
"hyper",
"hyper 1.1.0",
"hyper-tungstenite",
"include_dir",
"itertools",
@ -1510,9 +1600,9 @@ dependencies = [
[[package]]
name = "os_str_bytes"
version = "6.6.1"
version = "7.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2355d85b9a3786f481747ced0e0ff2ba35213a1f9bd406ed906554d7af805a1"
checksum = "7ac44c994af577c799b1b4bd80dc214701e349873ad894d6cdf96f4f7526e0b9"
dependencies = [
"memchr",
]
@ -1927,10 +2017,10 @@ dependencies = [
"encoding_rs",
"futures-core",
"futures-util",
"h2",
"http",
"http-body",
"hyper",
"h2 0.3.24",
"http 0.2.11",
"http-body 0.4.6",
"hyper 0.14.28",
"hyper-rustls",
"ipnet",
"js-sys",
@ -1939,7 +2029,7 @@ dependencies = [
"once_cell",
"percent-encoding",
"pin-project-lite",
"rustls",
"rustls 0.21.10",
"rustls-pemfile",
"serde",
"serde_json",
@ -1947,13 +2037,13 @@ dependencies = [
"sync_wrapper",
"system-configuration",
"tokio",
"tokio-rustls",
"tokio-rustls 0.24.1",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
"webpki-roots",
"webpki-roots 0.25.4",
"winreg 0.50.0",
]
@ -2056,10 +2146,24 @@ checksum = "f9d5a6813c0759e4609cd494e8e725babae6a2ca7b62a5536a13daaec6fcb7ba"
dependencies = [
"log",
"ring",
"rustls-webpki",
"rustls-webpki 0.101.7",
"sct",
]
[[package]]
name = "rustls"
version = "0.22.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e87c9956bd9807afa1f77e0f7594af32566e830e088a5576d27c5b6f30f49d41"
dependencies = [
"log",
"ring",
"rustls-pki-types",
"rustls-webpki 0.102.2",
"subtle",
"zeroize",
]
[[package]]
name = "rustls-pemfile"
version = "1.0.4"
@ -2069,6 +2173,12 @@ dependencies = [
"base64 0.21.7",
]
[[package]]
name = "rustls-pki-types"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0a716eb65e3158e90e17cd93d855216e27bde02745ab842f2cab4a39dba1bacf"
[[package]]
name = "rustls-webpki"
version = "0.101.7"
@ -2079,6 +2189,17 @@ dependencies = [
"untrusted",
]
[[package]]
name = "rustls-webpki"
version = "0.102.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "faaa0a62740bedb9b2ef5afa303da42764c012f743917351dc9a237ea1663610"
dependencies = [
"ring",
"rustls-pki-types",
"untrusted",
]
[[package]]
name = "rustyline"
version = "13.0.0"
@ -2386,6 +2507,12 @@ version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ee073c9e4cd00e28217186dbe12796d692868f432bf2e97ee73bed0c56dfa01"
[[package]]
name = "subtle"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc"
[[package]]
name = "syn"
version = "1.0.109"
@ -2447,15 +2574,6 @@ dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "termcolor"
version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755"
dependencies = [
"winapi-util",
]
[[package]]
name = "thiserror"
version = "1.0.57"
@ -2607,23 +2725,35 @@ version = "0.24.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081"
dependencies = [
"rustls",
"rustls 0.21.10",
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "775e0c0f0adb3a2f22a00c4745d728b479985fc15ee7ca6a2608388c5569860f"
dependencies = [
"rustls 0.22.2",
"rustls-pki-types",
"tokio",
]
[[package]]
name = "tokio-tungstenite"
version = "0.20.1"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "212d5dcb2a1ce06d81107c3d0ffa3121fe974b73f068c8282cb1c32328113b6c"
checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38"
dependencies = [
"futures-util",
"log",
"rustls",
"rustls 0.22.2",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tokio-rustls 0.25.0",
"tungstenite",
"webpki-roots",
"webpki-roots 0.26.1",
]
[[package]]
@ -2750,18 +2880,19 @@ checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]]
name = "tungstenite"
version = "0.20.1"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e3dac10fd62eaf6617d3a904ae222845979aec67c615d1c842b4002c7666fb9"
checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1"
dependencies = [
"byteorder 1.5.0",
"bytes",
"data-encoding",
"http",
"http 1.0.0",
"httparse",
"log",
"rand",
"rustls",
"rustls 0.22.2",
"rustls-pki-types",
"sha1 0.10.6",
"thiserror",
"url",
@ -2983,6 +3114,15 @@ version = "0.25.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1"
[[package]]
name = "webpki-roots"
version = "0.26.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b3de34ae270483955a94f4b21bdaaeb83d508bb84a01435f393818edb0012009"
dependencies = [
"rustls-pki-types",
]
[[package]]
name = "winapi"
version = "0.3.9"

View file

@ -82,7 +82,7 @@ urlencoding = "2.1"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tokio = { version = "1.24", features = ["full", "tracing"] }
os_str_bytes = { version = "6.4", features = ["conversions"] }
os_str_bytes = { version = "7.0", features = ["conversions"] }
mlua-luau-scheduler = { git = "https://github.com/lune-org/mlua-luau-scheduler", rev = "7c59d0c722215693839b2790a918fcc872e5ab93" }
mlua = { git = "https://github.com/mlua-rs/mlua.git", rev = "1754226c7440ec6c194d2d678ec083b621d46ceb", features = [
@ -108,12 +108,12 @@ toml = { version = "0.8", features = ["preserve_order"] }
### NET
hyper = { version = "0.14", features = ["full"] }
hyper-tungstenite = { version = "0.11" }
hyper = { version = "1.1", features = ["full"] }
hyper-tungstenite = { version = "0.13" }
reqwest = { version = "0.11", default-features = false, features = [
"rustls-tls",
] }
tokio-tungstenite = { version = "0.20", features = ["rustls-tls-webpki-roots"] }
tokio-tungstenite = { version = "0.21", features = ["rustls-tls-webpki-roots"] }
### DATETIME
chrono = "0.4"
@ -122,7 +122,7 @@ chrono_lc = "0.1"
### CLI
anyhow = { optional = true, version = "1.0" }
env_logger = { optional = true, version = "0.10" }
env_logger = { optional = true, version = "0.11" }
itertools = { optional = true, version = "0.12" }
clap = { optional = true, version = "4.1", features = ["derive"] }
include_dir = { optional = true, version = "0.7", features = ["glob"] }

View file

@ -4,10 +4,14 @@ use mlua::prelude::*;
mod config;
mod util;
mod websocket;
use crate::lune::util::TableBuilder;
use self::config::{RequestConfig, ServeConfig};
use self::{
config::{RequestConfig, ServeConfig},
websocket::NetWebSocket,
};
use super::serde::encode_decode::{EncodeDecodeConfig, EncodeDecodeFormat};
@ -43,12 +47,13 @@ fn net_json_decode<'lua>(lua: &'lua Lua, json: LuaString<'lua>) -> LuaResult<Lua
EncodeDecodeConfig::from(EncodeDecodeFormat::Json).deserialize_from_string(lua, json)
}
async fn net_request<'lua>(lua: &'lua Lua, config: RequestConfig) -> LuaResult<LuaTable<'lua>> {
async fn net_request(lua: &Lua, config: RequestConfig) -> LuaResult<LuaTable> {
unimplemented!()
}
async fn net_socket<'lua>(lua: &'lua Lua, url: String) -> LuaResult<LuaTable<'lua>> {
unimplemented!()
async fn net_socket(lua: &Lua, url: String) -> LuaResult<LuaTable> {
let (ws, _) = tokio_tungstenite::connect_async(url).await.into_lua_err()?;
NetWebSocket::new(ws).into_lua_table(lua)
}
async fn net_serve<'lua>(

View file

@ -0,0 +1,190 @@
use std::sync::{
atomic::{AtomicBool, AtomicU16, Ordering},
Arc,
};
use mlua::prelude::*;
use futures_util::{
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
};
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::Mutex as AsyncMutex,
};
use hyper_tungstenite::{
tungstenite::{
protocol::{frame::coding::CloseCode as WsCloseCode, CloseFrame as WsCloseFrame},
Message as WsMessage,
},
WebSocketStream,
};
use crate::lune::util::TableBuilder;
// Wrapper implementation for compatibility and changing colon syntax to dot syntax
const WEB_SOCKET_IMPL_LUA: &str = r#"
return freeze(setmetatable({
close = function(...)
return websocket:close(...)
end,
send = function(...)
return websocket:send(...)
end,
next = function(...)
return websocket:next(...)
end,
}, {
__index = function(self, key)
if key == "closeCode" then
return websocket.closeCode
end
end,
}))
"#;
#[derive(Debug)]
pub struct NetWebSocket<T> {
close_code_exists: Arc<AtomicBool>,
close_code_value: Arc<AtomicU16>,
read_stream: Arc<AsyncMutex<SplitStream<WebSocketStream<T>>>>,
write_stream: Arc<AsyncMutex<SplitSink<WebSocketStream<T>, WsMessage>>>,
}
impl<T> Clone for NetWebSocket<T> {
fn clone(&self) -> Self {
Self {
close_code_exists: Arc::clone(&self.close_code_exists),
close_code_value: Arc::clone(&self.close_code_value),
read_stream: Arc::clone(&self.read_stream),
write_stream: Arc::clone(&self.write_stream),
}
}
}
impl<T> NetWebSocket<T>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
pub fn new(value: WebSocketStream<T>) -> Self {
let (write, read) = value.split();
Self {
close_code_exists: Arc::new(AtomicBool::new(false)),
close_code_value: Arc::new(AtomicU16::new(0)),
read_stream: Arc::new(AsyncMutex::new(read)),
write_stream: Arc::new(AsyncMutex::new(write)),
}
}
fn get_close_code(&self) -> Option<u16> {
if self.close_code_exists.load(Ordering::Relaxed) {
Some(self.close_code_value.load(Ordering::Relaxed))
} else {
None
}
}
fn set_close_code(&self, code: u16) {
self.close_code_exists.store(true, Ordering::Relaxed);
self.close_code_value.store(code, Ordering::Relaxed);
}
pub async fn send(&self, msg: WsMessage) -> LuaResult<()> {
let mut ws = self.write_stream.lock().await;
ws.send(msg).await.into_lua_err()
}
pub async fn next(&self) -> LuaResult<Option<WsMessage>> {
let mut ws = self.read_stream.lock().await;
ws.next().await.transpose().into_lua_err()
}
pub async fn close(&self, code: Option<u16>) -> LuaResult<()> {
if self.close_code_exists.load(Ordering::Relaxed) {
return Err(LuaError::runtime("Socket has already been closed"));
}
self.send(WsMessage::Close(Some(WsCloseFrame {
code: match code {
Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code),
Some(code) => {
return Err(LuaError::runtime(format!(
"Close code must be between 1000 and 4999, got {code}"
)))
}
None => WsCloseCode::Normal,
},
reason: "".into(),
})))
.await?;
let mut ws = self.write_stream.lock().await;
ws.close().await.into_lua_err()
}
pub fn into_lua_table(self, lua: &Lua) -> LuaResult<LuaTable> {
let setmetatable = lua.globals().get::<_, LuaFunction>("setmetatable")?;
let table_freeze = lua
.globals()
.get::<_, LuaTable>("table")?
.get::<_, LuaFunction>("freeze")?;
let env = TableBuilder::new(lua)?
.with_value("websocket", self.clone())?
.with_value("setmetatable", setmetatable)?
.with_value("freeze", table_freeze)?
.build_readonly()?;
lua.load(WEB_SOCKET_IMPL_LUA)
.set_name("websocket")
.set_environment(env)
.eval()
}
}
impl<T> LuaUserData for NetWebSocket<T>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
fn add_fields<'lua, F: LuaUserDataFields<'lua, Self>>(fields: &mut F) {
fields.add_field_method_get("closeCode", |_, this| Ok(this.get_close_code()));
}
fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_async_method("close", |lua, this, code: Option<u16>| async move {
this.close(code).await
});
methods.add_async_method(
"send",
|_, this, (string, as_binary): (LuaString, Option<bool>)| async move {
this.send(if as_binary.unwrap_or_default() {
WsMessage::Binary(string.as_bytes().to_vec())
} else {
let s = string.to_str().into_lua_err()?;
WsMessage::Text(s.to_string())
})
.await
},
);
methods.add_async_method("next", |lua, this, _: ()| async move {
let msg = this.next().await?;
if let Some(WsMessage::Close(Some(frame))) = msg.as_ref() {
this.set_close_code(frame.code.into());
}
Ok(match msg {
Some(WsMessage::Binary(bin)) => LuaValue::String(lua.create_string(bin)?),
Some(WsMessage::Text(txt)) => LuaValue::String(lua.create_string(txt)?),
Some(WsMessage::Close(_)) | None => LuaValue::Nil,
// Ignore ping/pong/frame messages, they are handled by tungstenite
msg => unreachable!("Unhandled message: {:?}", msg),
})
});
}
}

View file

@ -68,6 +68,7 @@ create_tests! {
net_url_decode: "net/url/decode",
net_serve_requests: "net/serve/requests",
net_serve_websockets: "net/serve/websockets",
net_socket_basic: "net/socket/basic",
net_socket_wss: "net/socket/wss",
net_socket_wss_rw: "net/socket/wss_rw",

View file

@ -0,0 +1,28 @@
local net = require("@lune/net")
-- We're going to use Discord's WebSocket gateway server for testing
local socket = net.socket("wss://gateway.discord.gg/?v=10&encoding=json")
assert(type(socket.next) == "function", "next must be a function")
assert(type(socket.send) == "function", "send must be a function")
assert(type(socket.close) == "function", "close must be a function")
-- Request to close the socket
socket.close()
-- Drain remaining messages, until we got our close message
while socket.next() do
end
assert(type(socket.closeCode) == "number", "closeCode should exist after closing")
assert(socket.closeCode == 1000, "closeCode should be 1000 after closing")
local success, message = pcall(function()
socket.send("Hello, world!")
end)
assert(not success, "send should fail after closing")
assert(
string.find(tostring(message), "closed") or string.find(tostring(message), "closing"),
"send should fail with a message that the socket was closed"
)