diff --git a/Cargo.lock b/Cargo.lock index e7cc7bf..73702ba 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -297,6 +297,22 @@ version = "4.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b75356056920673b02621b35afd0f7dda9306d03c79a30f5c56c44cf256e3de" +[[package]] +name = "async-tungstenite" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef0f7efedeac57d9b26170f72965ecfd31473ca52ca7a64e925b0b6f5f079886" +dependencies = [ + "atomic-waker", + "futures-core", + "futures-io", + "futures-task", + "futures-util", + "log", + "pin-project-lite", + "tungstenite", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -801,6 +817,12 @@ dependencies = [ "typenum", ] +[[package]] +name = "data-encoding" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476" + [[package]] name = "deflate64" version = "0.1.9" @@ -1079,6 +1101,20 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.31" @@ -1086,6 +1122,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -1153,9 +1190,13 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ + "futures-channel", "futures-core", + "futures-io", "futures-macro", + "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", "slab", @@ -1857,8 +1898,10 @@ dependencies = [ "async-io", "async-lock", "async-net", + "async-tungstenite", "blocking", "bstr", + "futures", "futures-lite", "futures-rustls", "http-body-util", @@ -3691,6 +3734,23 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tungstenite" +version = "0.26.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13" +dependencies = [ + "bytes", + "data-encoding", + "http 1.3.1", + "httparse", + "log", + "rand 0.9.1", + "sha1 0.10.6", + "thiserror 2.0.12", + "utf-8", +] + [[package]] name = "typeid" version = "1.0.3" @@ -3763,6 +3823,12 @@ dependencies = [ "serde", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf16_iter" version = "1.0.5" diff --git a/crates/lune-std-net/Cargo.toml b/crates/lune-std-net/Cargo.toml index 9c6c70a..f5ec79d 100644 --- a/crates/lune-std-net/Cargo.toml +++ b/crates/lune-std-net/Cargo.toml @@ -21,8 +21,10 @@ async-executor = "1.13" async-io = "2.4" async-lock = "3.4" async-net = "2.0" +async-tungstenite = "0.29" blocking = "1.6" bstr = "1.9" +futures = { version = "0.3", default-features = false, features = ["std"] } futures-lite = "2.6" futures-rustls = "0.26" http-body-util = "0.1" diff --git a/crates/lune-std-net/src/client/stream.rs b/crates/lune-std-net/src/client/http_stream.rs similarity index 50% rename from crates/lune-std-net/src/client/stream.rs rename to crates/lune-std-net/src/client/http_stream.rs index 2d392cf..2aba704 100644 --- a/crates/lune-std-net/src/client/stream.rs +++ b/crates/lune-std-net/src/client/http_stream.rs @@ -1,50 +1,41 @@ use std::{ io, pin::Pin, - sync::{Arc, LazyLock}, + sync::Arc, task::{Context, Poll}, }; use async_net::TcpStream; use futures_lite::prelude::*; use futures_rustls::{TlsConnector, TlsStream}; -use hyper::Uri; -use rustls::ClientConfig; use rustls_pki_types::ServerName; +use url::Url; -static CLIENT_CONFIG: LazyLock> = LazyLock::new(|| { - rustls::ClientConfig::builder() - .with_root_certificates(rustls::RootCertStore { - roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(), - }) - .with_no_client_auth() - .into() -}); +use crate::client::rustls::CLIENT_CONFIG; -pub enum HttpRequestStream { +#[derive(Debug)] +pub enum HttpStream { Plain(TcpStream), Tls(TlsStream), } -impl HttpRequestStream { - pub async fn connect(url: &Uri) -> Result { +impl HttpStream { + pub async fn connect(url: Url) -> Result { let Some(host) = url.host() else { return Err(make_err("unknown or missing host")); }; - let Some(scheme) = url.scheme_str() else { - return Err(make_err("unknown scheme")); + let Some(port) = url.port_or_known_default() else { + return Err(make_err("unknown or missing port")); }; - let (use_tls, port) = match scheme { - "http" => (false, 80), - "https" => (true, 443), + let use_tls = match url.scheme() { + "http" => false, + "https" => true, s => return Err(make_err(format!("unsupported scheme: {s}"))), }; - let stream = { - let port = url.port_u16().unwrap_or(port); - TcpStream::connect((host, port)).await? - }; + let host = host.to_string(); + let stream = TcpStream::connect((host.clone(), port)).await?; let stream = if use_tls { let servname = ServerName::try_from(host).map_err(make_err)?.to_owned(); @@ -59,42 +50,42 @@ impl HttpRequestStream { } } -impl AsyncRead for HttpRequestStream { +impl AsyncRead for HttpStream { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { match &mut *self { - HttpRequestStream::Plain(stream) => Pin::new(stream).poll_read(cx, buf), - HttpRequestStream::Tls(stream) => Pin::new(stream).poll_read(cx, buf), + HttpStream::Plain(stream) => Pin::new(stream).poll_read(cx, buf), + HttpStream::Tls(stream) => Pin::new(stream).poll_read(cx, buf), } } } -impl AsyncWrite for HttpRequestStream { +impl AsyncWrite for HttpStream { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { match &mut *self { - HttpRequestStream::Plain(stream) => Pin::new(stream).poll_write(cx, buf), - HttpRequestStream::Tls(stream) => Pin::new(stream).poll_write(cx, buf), + HttpStream::Plain(stream) => Pin::new(stream).poll_write(cx, buf), + HttpStream::Tls(stream) => Pin::new(stream).poll_write(cx, buf), } } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match &mut *self { - HttpRequestStream::Plain(stream) => Pin::new(stream).poll_close(cx), - HttpRequestStream::Tls(stream) => Pin::new(stream).poll_close(cx), + HttpStream::Plain(stream) => Pin::new(stream).poll_close(cx), + HttpStream::Tls(stream) => Pin::new(stream).poll_close(cx), } } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { match &mut *self { - HttpRequestStream::Plain(stream) => Pin::new(stream).poll_flush(cx), - HttpRequestStream::Tls(stream) => Pin::new(stream).poll_flush(cx), + HttpStream::Plain(stream) => Pin::new(stream).poll_flush(cx), + HttpStream::Tls(stream) => Pin::new(stream).poll_flush(cx), } } } diff --git a/crates/lune-std-net/src/client/mod.rs b/crates/lune-std-net/src/client/mod.rs index fda1a76..743d9d3 100644 --- a/crates/lune-std-net/src/client/mod.rs +++ b/crates/lune-std-net/src/client/mod.rs @@ -6,21 +6,33 @@ use hyper::{ }; use mlua::prelude::*; +use url::Url; use crate::{ - client::stream::HttpRequestStream, + client::{http_stream::HttpStream, ws_stream::WsStream}, shared::{ hyper::{HyperExecutor, HyperIo}, request::Request, response::Response, + websocket::Websocket, }, }; pub mod config; -pub mod stream; +pub mod http_stream; +pub mod rustls; +pub mod ws_stream; const MAX_REDIRECTS: usize = 10; +/** + Connects to a websocket at the given URL. +*/ +pub async fn connect_websocket(url: Url) -> LuaResult> { + let stream = WsStream::connect(url).await?; + Ok(Websocket::from(stream)) +} + /** Sends the request and returns the final response. @@ -28,8 +40,14 @@ const MAX_REDIRECTS: usize = 10; modifying the request method and body as necessary. */ pub async fn send_request(mut request: Request, lua: Lua) -> LuaResult { + let url = request + .inner + .uri() + .to_string() + .parse::() + .expect("uri is valid"); loop { - let stream = HttpRequestStream::connect(request.inner.uri()).await?; + let stream = HttpStream::connect(url.clone()).await?; let (mut sender, conn) = handshake(HyperIo::from(stream)).await.into_lua_err()?; diff --git a/crates/lune-std-net/src/client/rustls.rs b/crates/lune-std-net/src/client/rustls.rs new file mode 100644 index 0000000..ea864ab --- /dev/null +++ b/crates/lune-std-net/src/client/rustls.rs @@ -0,0 +1,12 @@ +use std::sync::{Arc, LazyLock}; + +use rustls::ClientConfig; + +pub static CLIENT_CONFIG: LazyLock> = LazyLock::new(|| { + rustls::ClientConfig::builder() + .with_root_certificates(rustls::RootCertStore { + roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(), + }) + .with_no_client_auth() + .into() +}); diff --git a/crates/lune-std-net/src/client/ws_stream.rs b/crates/lune-std-net/src/client/ws_stream.rs new file mode 100644 index 0000000..03537ec --- /dev/null +++ b/crates/lune-std-net/src/client/ws_stream.rs @@ -0,0 +1,114 @@ +use std::{ + io, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use async_net::TcpStream; +use async_tungstenite::{ + tungstenite::{Error as TungsteniteError, Message, Result as TungsteniteResult}, + WebSocketStream as TungsteniteStream, +}; +use futures::Sink; +use futures_lite::prelude::*; +use futures_rustls::{TlsConnector, TlsStream}; +use rustls_pki_types::ServerName; +use url::Url; + +use crate::client::rustls::CLIENT_CONFIG; + +#[derive(Debug)] +pub enum WsStream { + Plain(TungsteniteStream), + Tls(TungsteniteStream>), +} + +impl WsStream { + pub async fn connect(url: Url) -> Result { + let Some(host) = url.host() else { + return Err(make_err("unknown or missing host")); + }; + let Some(port) = url.port_or_known_default() else { + return Err(make_err("unknown or missing port")); + }; + + let use_tls = match url.scheme() { + "ws" => false, + "wss" => true, + s => return Err(make_err(format!("unsupported scheme: {s}"))), + }; + + let host = host.to_string(); + let stream = TcpStream::connect((host.clone(), port)).await?; + + let stream = if use_tls { + let servname = ServerName::try_from(host).map_err(make_err)?.to_owned(); + let connector = TlsConnector::from(Arc::clone(&CLIENT_CONFIG)); + + let stream = connector.connect(servname, stream).await?; + let stream = TlsStream::Client(stream); + + let stream = async_tungstenite::client_async(url.to_string(), stream) + .await + .map_err(make_err)? + .0; + Self::Tls(stream) + } else { + let stream = async_tungstenite::client_async(url.to_string(), stream) + .await + .map_err(make_err)? + .0; + Self::Plain(stream) + }; + + Ok(stream) + } +} + +impl Sink for WsStream { + type Error = TungsteniteError; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + WsStream::Plain(s) => Pin::new(s).poll_ready(cx), + WsStream::Tls(s) => Pin::new(s).poll_ready(cx), + } + } + + fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { + match &mut *self { + WsStream::Plain(s) => Pin::new(s).start_send(item), + WsStream::Tls(s) => Pin::new(s).start_send(item), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + WsStream::Plain(s) => Pin::new(s).poll_flush(cx), + WsStream::Tls(s) => Pin::new(s).poll_flush(cx), + } + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + WsStream::Plain(s) => Pin::new(s).poll_close(cx), + WsStream::Tls(s) => Pin::new(s).poll_close(cx), + } + } +} + +impl Stream for WsStream { + type Item = TungsteniteResult; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + WsStream::Plain(s) => Pin::new(s).poll_next(cx), + WsStream::Tls(s) => Pin::new(s).poll_next(cx), + } + } +} + +fn make_err(e: impl ToString) -> io::Error { + io::Error::new(io::ErrorKind::Other, e.to_string()) +} diff --git a/crates/lune-std-net/src/lib.rs b/crates/lune-std-net/src/lib.rs index 1effacf..4ef9839 100644 --- a/crates/lune-std-net/src/lib.rs +++ b/crates/lune-std-net/src/lib.rs @@ -9,9 +9,9 @@ pub(crate) mod shared; pub(crate) mod url; use self::{ - client::config::RequestConfig, + client::{config::RequestConfig, ws_stream::WsStream}, server::{config::ServeConfig, handle::ServeHandle}, - shared::{request::Request, response::Response}, + shared::{request::Request, response::Response, websocket::Websocket}, }; const TYPEDEFS: &str = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/types.d.luau")); @@ -34,7 +34,7 @@ pub fn typedefs() -> String { pub fn module(lua: Lua) -> LuaResult { TableBuilder::new(lua)? .with_async_function("request", net_request)? - // .with_async_function("socket", net_socket)? + .with_async_function("socket", net_socket)? .with_async_function("serve", net_serve)? .with_function("urlEncode", net_url_encode)? .with_function("urlDecode", net_url_decode)? @@ -45,6 +45,11 @@ async fn net_request(lua: Lua, config: RequestConfig) -> LuaResult { self::client::send_request(Request::try_from(config)?, lua).await } +async fn net_socket(_: Lua, url: String) -> LuaResult> { + let url = url.parse().into_lua_err()?; + self::client::connect_websocket(url).await +} + async fn net_serve(lua: Lua, (port, config): (u16, ServeConfig)) -> LuaResult { self::server::serve(lua, port, config).await } diff --git a/crates/lune-std-net/src/shared/mod.rs b/crates/lune-std-net/src/shared/mod.rs index 43d3f74..4c02f02 100644 --- a/crates/lune-std-net/src/shared/mod.rs +++ b/crates/lune-std-net/src/shared/mod.rs @@ -4,3 +4,4 @@ pub mod hyper; pub mod incoming; pub mod request; pub mod response; +pub mod websocket; diff --git a/crates/lune-std-net/src/shared/websocket.rs b/crates/lune-std-net/src/shared/websocket.rs new file mode 100644 index 0000000..1d22947 --- /dev/null +++ b/crates/lune-std-net/src/shared/websocket.rs @@ -0,0 +1,143 @@ +use std::{ + error::Error, + sync::{ + atomic::{AtomicBool, AtomicU16, Ordering}, + Arc, + }, +}; + +use async_lock::Mutex as AsyncMutex; +use async_tungstenite::tungstenite::{ + protocol::{frame::coding::CloseCode, CloseFrame}, + Message as TungsteniteMessage, Result as TungsteniteResult, Utf8Bytes, +}; +use bstr::{BString, ByteSlice}; +use futures::{ + stream::{SplitSink, SplitStream}, + Sink, SinkExt, Stream, StreamExt, +}; +use hyper::body::Bytes; + +use mlua::prelude::*; + +#[derive(Debug, Clone)] +pub struct Websocket { + close_code_exists: Arc, + close_code_value: Arc, + read_stream: Arc>>, + write_stream: Arc>>, +} + +impl Websocket +where + T: Stream> + Sink + 'static, + >::Error: Into>, +{ + fn get_close_code(&self) -> Option { + if self.close_code_exists.load(Ordering::Relaxed) { + Some(self.close_code_value.load(Ordering::Relaxed)) + } else { + None + } + } + + fn set_close_code(&self, code: u16) { + self.close_code_exists.store(true, Ordering::Relaxed); + self.close_code_value.store(code, Ordering::Relaxed); + } + + pub async fn send(&self, msg: TungsteniteMessage) -> LuaResult<()> { + let mut ws = self.write_stream.lock().await; + ws.send(msg).await.into_lua_err() + } + + pub async fn next(&self) -> LuaResult> { + let mut ws = self.read_stream.lock().await; + ws.next().await.transpose().into_lua_err() + } + + pub async fn close(&self, code: Option) -> LuaResult<()> { + if self.close_code_exists.load(Ordering::Relaxed) { + return Err(LuaError::runtime("Socket has already been closed")); + } + + self.send(TungsteniteMessage::Close(Some(CloseFrame { + code: match code { + Some(code) if (1000..=4999).contains(&code) => CloseCode::from(code), + Some(code) => { + return Err(LuaError::runtime(format!( + "Close code must be between 1000 and 4999, got {code}" + ))) + } + None => CloseCode::Normal, + }, + reason: "".into(), + }))) + .await?; + + let mut ws = self.write_stream.lock().await; + ws.close().await.into_lua_err() + } +} + +impl From for Websocket +where + T: Stream> + Sink + 'static, + >::Error: Into>, +{ + fn from(value: T) -> Self { + let (write, read) = value.split(); + + Self { + close_code_exists: Arc::new(AtomicBool::new(false)), + close_code_value: Arc::new(AtomicU16::new(0)), + read_stream: Arc::new(AsyncMutex::new(read)), + write_stream: Arc::new(AsyncMutex::new(write)), + } + } +} + +impl LuaUserData for Websocket +where + T: Stream> + Sink + 'static, + >::Error: Into>, +{ + fn add_fields>(fields: &mut F) { + fields.add_field_method_get("closeCode", |_, this| Ok(this.get_close_code())); + } + + fn add_methods>(methods: &mut M) { + methods.add_async_method("close", |_, this, code: Option| async move { + this.close(code).await + }); + + methods.add_async_method( + "send", + |_, this, (string, as_binary): (BString, Option)| async move { + this.send(if as_binary.unwrap_or_default() { + TungsteniteMessage::Binary(Bytes::from(string.to_vec())) + } else { + let s = string.to_str().into_lua_err()?; + TungsteniteMessage::Text(Utf8Bytes::from(s)) + }) + .await + }, + ); + + methods.add_async_method("next", |lua, this, (): ()| async move { + let msg = this.next().await?; + + if let Some(TungsteniteMessage::Close(Some(frame))) = msg.as_ref() { + this.set_close_code(frame.code.into()); + } + + Ok(match msg { + Some(TungsteniteMessage::Binary(bin)) => LuaValue::String(lua.create_string(bin)?), + Some(TungsteniteMessage::Text(txt)) => LuaValue::String(lua.create_string(txt)?), + Some(TungsteniteMessage::Close(_)) | None => LuaValue::Nil, + // Ignore ping/pong/frame messages, they are handled by tungstenite + msg => unreachable!("Unhandled message: {:?}", msg), + }) + }); + } +}