Full implementation for websocket streams and client

This commit is contained in:
Filip Tibell 2025-04-27 19:29:48 +02:00
parent a82cb1da33
commit 3f179ab4ec
No known key found for this signature in database
9 changed files with 391 additions and 39 deletions

66
Cargo.lock generated
View file

@ -297,6 +297,22 @@ version = "4.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b75356056920673b02621b35afd0f7dda9306d03c79a30f5c56c44cf256e3de"
[[package]]
name = "async-tungstenite"
version = "0.29.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef0f7efedeac57d9b26170f72965ecfd31473ca52ca7a64e925b0b6f5f079886"
dependencies = [
"atomic-waker",
"futures-core",
"futures-io",
"futures-task",
"futures-util",
"log",
"pin-project-lite",
"tungstenite",
]
[[package]]
name = "atomic-waker"
version = "1.1.2"
@ -801,6 +817,12 @@ dependencies = [
"typenum",
]
[[package]]
name = "data-encoding"
version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476"
[[package]]
name = "deflate64"
version = "0.1.9"
@ -1079,6 +1101,20 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "futures"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.31"
@ -1086,6 +1122,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10"
dependencies = [
"futures-core",
"futures-sink",
]
[[package]]
@ -1153,9 +1190,13 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
"pin-project-lite",
"pin-utils",
"slab",
@ -1857,8 +1898,10 @@ dependencies = [
"async-io",
"async-lock",
"async-net",
"async-tungstenite",
"blocking",
"bstr",
"futures",
"futures-lite",
"futures-rustls",
"http-body-util",
@ -3691,6 +3734,23 @@ version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]]
name = "tungstenite"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13"
dependencies = [
"bytes",
"data-encoding",
"http 1.3.1",
"httparse",
"log",
"rand 0.9.1",
"sha1 0.10.6",
"thiserror 2.0.12",
"utf-8",
]
[[package]]
name = "typeid"
version = "1.0.3"
@ -3763,6 +3823,12 @@ dependencies = [
"serde",
]
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]]
name = "utf16_iter"
version = "1.0.5"

View file

@ -21,8 +21,10 @@ async-executor = "1.13"
async-io = "2.4"
async-lock = "3.4"
async-net = "2.0"
async-tungstenite = "0.29"
blocking = "1.6"
bstr = "1.9"
futures = { version = "0.3", default-features = false, features = ["std"] }
futures-lite = "2.6"
futures-rustls = "0.26"
http-body-util = "0.1"

View file

@ -1,50 +1,41 @@
use std::{
io,
pin::Pin,
sync::{Arc, LazyLock},
sync::Arc,
task::{Context, Poll},
};
use async_net::TcpStream;
use futures_lite::prelude::*;
use futures_rustls::{TlsConnector, TlsStream};
use hyper::Uri;
use rustls::ClientConfig;
use rustls_pki_types::ServerName;
use url::Url;
static CLIENT_CONFIG: LazyLock<Arc<ClientConfig>> = LazyLock::new(|| {
rustls::ClientConfig::builder()
.with_root_certificates(rustls::RootCertStore {
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
})
.with_no_client_auth()
.into()
});
use crate::client::rustls::CLIENT_CONFIG;
pub enum HttpRequestStream {
#[derive(Debug)]
pub enum HttpStream {
Plain(TcpStream),
Tls(TlsStream<TcpStream>),
}
impl HttpRequestStream {
pub async fn connect(url: &Uri) -> Result<Self, io::Error> {
impl HttpStream {
pub async fn connect(url: Url) -> Result<Self, io::Error> {
let Some(host) = url.host() else {
return Err(make_err("unknown or missing host"));
};
let Some(scheme) = url.scheme_str() else {
return Err(make_err("unknown scheme"));
let Some(port) = url.port_or_known_default() else {
return Err(make_err("unknown or missing port"));
};
let (use_tls, port) = match scheme {
"http" => (false, 80),
"https" => (true, 443),
let use_tls = match url.scheme() {
"http" => false,
"https" => true,
s => return Err(make_err(format!("unsupported scheme: {s}"))),
};
let stream = {
let port = url.port_u16().unwrap_or(port);
TcpStream::connect((host, port)).await?
};
let host = host.to_string();
let stream = TcpStream::connect((host.clone(), port)).await?;
let stream = if use_tls {
let servname = ServerName::try_from(host).map_err(make_err)?.to_owned();
@ -59,42 +50,42 @@ impl HttpRequestStream {
}
}
impl AsyncRead for HttpRequestStream {
impl AsyncRead for HttpStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
match &mut *self {
HttpRequestStream::Plain(stream) => Pin::new(stream).poll_read(cx, buf),
HttpRequestStream::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
HttpStream::Plain(stream) => Pin::new(stream).poll_read(cx, buf),
HttpStream::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}
impl AsyncWrite for HttpRequestStream {
impl AsyncWrite for HttpStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match &mut *self {
HttpRequestStream::Plain(stream) => Pin::new(stream).poll_write(cx, buf),
HttpRequestStream::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
HttpStream::Plain(stream) => Pin::new(stream).poll_write(cx, buf),
HttpStream::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
}
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut *self {
HttpRequestStream::Plain(stream) => Pin::new(stream).poll_close(cx),
HttpRequestStream::Tls(stream) => Pin::new(stream).poll_close(cx),
HttpStream::Plain(stream) => Pin::new(stream).poll_close(cx),
HttpStream::Tls(stream) => Pin::new(stream).poll_close(cx),
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut *self {
HttpRequestStream::Plain(stream) => Pin::new(stream).poll_flush(cx),
HttpRequestStream::Tls(stream) => Pin::new(stream).poll_flush(cx),
HttpStream::Plain(stream) => Pin::new(stream).poll_flush(cx),
HttpStream::Tls(stream) => Pin::new(stream).poll_flush(cx),
}
}
}

View file

@ -6,21 +6,33 @@ use hyper::{
};
use mlua::prelude::*;
use url::Url;
use crate::{
client::stream::HttpRequestStream,
client::{http_stream::HttpStream, ws_stream::WsStream},
shared::{
hyper::{HyperExecutor, HyperIo},
request::Request,
response::Response,
websocket::Websocket,
},
};
pub mod config;
pub mod stream;
pub mod http_stream;
pub mod rustls;
pub mod ws_stream;
const MAX_REDIRECTS: usize = 10;
/**
Connects to a websocket at the given URL.
*/
pub async fn connect_websocket(url: Url) -> LuaResult<Websocket<WsStream>> {
let stream = WsStream::connect(url).await?;
Ok(Websocket::from(stream))
}
/**
Sends the request and returns the final response.
@ -28,8 +40,14 @@ const MAX_REDIRECTS: usize = 10;
modifying the request method and body as necessary.
*/
pub async fn send_request(mut request: Request, lua: Lua) -> LuaResult<Response> {
let url = request
.inner
.uri()
.to_string()
.parse::<Url>()
.expect("uri is valid");
loop {
let stream = HttpRequestStream::connect(request.inner.uri()).await?;
let stream = HttpStream::connect(url.clone()).await?;
let (mut sender, conn) = handshake(HyperIo::from(stream)).await.into_lua_err()?;

View file

@ -0,0 +1,12 @@
use std::sync::{Arc, LazyLock};
use rustls::ClientConfig;
pub static CLIENT_CONFIG: LazyLock<Arc<ClientConfig>> = LazyLock::new(|| {
rustls::ClientConfig::builder()
.with_root_certificates(rustls::RootCertStore {
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
})
.with_no_client_auth()
.into()
});

View file

@ -0,0 +1,114 @@
use std::{
io,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use async_net::TcpStream;
use async_tungstenite::{
tungstenite::{Error as TungsteniteError, Message, Result as TungsteniteResult},
WebSocketStream as TungsteniteStream,
};
use futures::Sink;
use futures_lite::prelude::*;
use futures_rustls::{TlsConnector, TlsStream};
use rustls_pki_types::ServerName;
use url::Url;
use crate::client::rustls::CLIENT_CONFIG;
#[derive(Debug)]
pub enum WsStream {
Plain(TungsteniteStream<TcpStream>),
Tls(TungsteniteStream<TlsStream<TcpStream>>),
}
impl WsStream {
pub async fn connect(url: Url) -> Result<Self, io::Error> {
let Some(host) = url.host() else {
return Err(make_err("unknown or missing host"));
};
let Some(port) = url.port_or_known_default() else {
return Err(make_err("unknown or missing port"));
};
let use_tls = match url.scheme() {
"ws" => false,
"wss" => true,
s => return Err(make_err(format!("unsupported scheme: {s}"))),
};
let host = host.to_string();
let stream = TcpStream::connect((host.clone(), port)).await?;
let stream = if use_tls {
let servname = ServerName::try_from(host).map_err(make_err)?.to_owned();
let connector = TlsConnector::from(Arc::clone(&CLIENT_CONFIG));
let stream = connector.connect(servname, stream).await?;
let stream = TlsStream::Client(stream);
let stream = async_tungstenite::client_async(url.to_string(), stream)
.await
.map_err(make_err)?
.0;
Self::Tls(stream)
} else {
let stream = async_tungstenite::client_async(url.to_string(), stream)
.await
.map_err(make_err)?
.0;
Self::Plain(stream)
};
Ok(stream)
}
}
impl Sink<Message> for WsStream {
type Error = TungsteniteError;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match &mut *self {
WsStream::Plain(s) => Pin::new(s).poll_ready(cx),
WsStream::Tls(s) => Pin::new(s).poll_ready(cx),
}
}
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
match &mut *self {
WsStream::Plain(s) => Pin::new(s).start_send(item),
WsStream::Tls(s) => Pin::new(s).start_send(item),
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match &mut *self {
WsStream::Plain(s) => Pin::new(s).poll_flush(cx),
WsStream::Tls(s) => Pin::new(s).poll_flush(cx),
}
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match &mut *self {
WsStream::Plain(s) => Pin::new(s).poll_close(cx),
WsStream::Tls(s) => Pin::new(s).poll_close(cx),
}
}
}
impl Stream for WsStream {
type Item = TungsteniteResult<Message>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match &mut *self {
WsStream::Plain(s) => Pin::new(s).poll_next(cx),
WsStream::Tls(s) => Pin::new(s).poll_next(cx),
}
}
}
fn make_err(e: impl ToString) -> io::Error {
io::Error::new(io::ErrorKind::Other, e.to_string())
}

View file

@ -9,9 +9,9 @@ pub(crate) mod shared;
pub(crate) mod url;
use self::{
client::config::RequestConfig,
client::{config::RequestConfig, ws_stream::WsStream},
server::{config::ServeConfig, handle::ServeHandle},
shared::{request::Request, response::Response},
shared::{request::Request, response::Response, websocket::Websocket},
};
const TYPEDEFS: &str = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/types.d.luau"));
@ -34,7 +34,7 @@ pub fn typedefs() -> String {
pub fn module(lua: Lua) -> LuaResult<LuaTable> {
TableBuilder::new(lua)?
.with_async_function("request", net_request)?
// .with_async_function("socket", net_socket)?
.with_async_function("socket", net_socket)?
.with_async_function("serve", net_serve)?
.with_function("urlEncode", net_url_encode)?
.with_function("urlDecode", net_url_decode)?
@ -45,6 +45,11 @@ async fn net_request(lua: Lua, config: RequestConfig) -> LuaResult<Response> {
self::client::send_request(Request::try_from(config)?, lua).await
}
async fn net_socket(_: Lua, url: String) -> LuaResult<Websocket<WsStream>> {
let url = url.parse().into_lua_err()?;
self::client::connect_websocket(url).await
}
async fn net_serve(lua: Lua, (port, config): (u16, ServeConfig)) -> LuaResult<ServeHandle> {
self::server::serve(lua, port, config).await
}

View file

@ -4,3 +4,4 @@ pub mod hyper;
pub mod incoming;
pub mod request;
pub mod response;
pub mod websocket;

View file

@ -0,0 +1,143 @@
use std::{
error::Error,
sync::{
atomic::{AtomicBool, AtomicU16, Ordering},
Arc,
},
};
use async_lock::Mutex as AsyncMutex;
use async_tungstenite::tungstenite::{
protocol::{frame::coding::CloseCode, CloseFrame},
Message as TungsteniteMessage, Result as TungsteniteResult, Utf8Bytes,
};
use bstr::{BString, ByteSlice};
use futures::{
stream::{SplitSink, SplitStream},
Sink, SinkExt, Stream, StreamExt,
};
use hyper::body::Bytes;
use mlua::prelude::*;
#[derive(Debug, Clone)]
pub struct Websocket<T> {
close_code_exists: Arc<AtomicBool>,
close_code_value: Arc<AtomicU16>,
read_stream: Arc<AsyncMutex<SplitStream<T>>>,
write_stream: Arc<AsyncMutex<SplitSink<T, TungsteniteMessage>>>,
}
impl<T> Websocket<T>
where
T: Stream<Item = TungsteniteResult<TungsteniteMessage>> + Sink<TungsteniteMessage> + 'static,
<T as Sink<TungsteniteMessage>>::Error: Into<Box<dyn Error + Send + Sync + 'static>>,
{
fn get_close_code(&self) -> Option<u16> {
if self.close_code_exists.load(Ordering::Relaxed) {
Some(self.close_code_value.load(Ordering::Relaxed))
} else {
None
}
}
fn set_close_code(&self, code: u16) {
self.close_code_exists.store(true, Ordering::Relaxed);
self.close_code_value.store(code, Ordering::Relaxed);
}
pub async fn send(&self, msg: TungsteniteMessage) -> LuaResult<()> {
let mut ws = self.write_stream.lock().await;
ws.send(msg).await.into_lua_err()
}
pub async fn next(&self) -> LuaResult<Option<TungsteniteMessage>> {
let mut ws = self.read_stream.lock().await;
ws.next().await.transpose().into_lua_err()
}
pub async fn close(&self, code: Option<u16>) -> LuaResult<()> {
if self.close_code_exists.load(Ordering::Relaxed) {
return Err(LuaError::runtime("Socket has already been closed"));
}
self.send(TungsteniteMessage::Close(Some(CloseFrame {
code: match code {
Some(code) if (1000..=4999).contains(&code) => CloseCode::from(code),
Some(code) => {
return Err(LuaError::runtime(format!(
"Close code must be between 1000 and 4999, got {code}"
)))
}
None => CloseCode::Normal,
},
reason: "".into(),
})))
.await?;
let mut ws = self.write_stream.lock().await;
ws.close().await.into_lua_err()
}
}
impl<T> From<T> for Websocket<T>
where
T: Stream<Item = TungsteniteResult<TungsteniteMessage>> + Sink<TungsteniteMessage> + 'static,
<T as Sink<TungsteniteMessage>>::Error: Into<Box<dyn Error + Send + Sync + 'static>>,
{
fn from(value: T) -> Self {
let (write, read) = value.split();
Self {
close_code_exists: Arc::new(AtomicBool::new(false)),
close_code_value: Arc::new(AtomicU16::new(0)),
read_stream: Arc::new(AsyncMutex::new(read)),
write_stream: Arc::new(AsyncMutex::new(write)),
}
}
}
impl<T> LuaUserData for Websocket<T>
where
T: Stream<Item = TungsteniteResult<TungsteniteMessage>> + Sink<TungsteniteMessage> + 'static,
<T as Sink<TungsteniteMessage>>::Error: Into<Box<dyn Error + Send + Sync + 'static>>,
{
fn add_fields<F: LuaUserDataFields<Self>>(fields: &mut F) {
fields.add_field_method_get("closeCode", |_, this| Ok(this.get_close_code()));
}
fn add_methods<M: LuaUserDataMethods<Self>>(methods: &mut M) {
methods.add_async_method("close", |_, this, code: Option<u16>| async move {
this.close(code).await
});
methods.add_async_method(
"send",
|_, this, (string, as_binary): (BString, Option<bool>)| async move {
this.send(if as_binary.unwrap_or_default() {
TungsteniteMessage::Binary(Bytes::from(string.to_vec()))
} else {
let s = string.to_str().into_lua_err()?;
TungsteniteMessage::Text(Utf8Bytes::from(s))
})
.await
},
);
methods.add_async_method("next", |lua, this, (): ()| async move {
let msg = this.next().await?;
if let Some(TungsteniteMessage::Close(Some(frame))) = msg.as_ref() {
this.set_close_code(frame.code.into());
}
Ok(match msg {
Some(TungsteniteMessage::Binary(bin)) => LuaValue::String(lua.create_string(bin)?),
Some(TungsteniteMessage::Text(txt)) => LuaValue::String(lua.create_string(txt)?),
Some(TungsteniteMessage::Close(_)) | None => LuaValue::Nil,
// Ignore ping/pong/frame messages, they are handled by tungstenite
msg => unreachable!("Unhandled message: {:?}", msg),
})
});
}
}