From bf819d198078c3a9d29a1c8e230fcc673bd3c54c Mon Sep 17 00:00:00 2001 From: Filip Tibell Date: Sun, 27 Apr 2025 22:10:57 +0200 Subject: [PATCH] Implement full websocket server support and more graceful server internal error handling --- crates/lune-std-net/src/server/mod.rs | 1 + crates/lune-std-net/src/server/service.rs | 108 +++++++++++++++++----- crates/lune-std-net/src/server/upgrade.rs | 55 +++++++++++ crates/lune-std-net/src/shared/hyper.rs | 49 +++++++++- 4 files changed, 188 insertions(+), 25 deletions(-) create mode 100644 crates/lune-std-net/src/server/upgrade.rs diff --git a/crates/lune-std-net/src/server/mod.rs b/crates/lune-std-net/src/server/mod.rs index 184431c..241cbb7 100644 --- a/crates/lune-std-net/src/server/mod.rs +++ b/crates/lune-std-net/src/server/mod.rs @@ -18,6 +18,7 @@ use crate::{ pub mod config; pub mod handle; pub mod service; +pub mod upgrade; /** Starts an HTTP server using the given port and configuration. diff --git a/crates/lune-std-net/src/server/service.rs b/crates/lune-std-net/src/server/service.rs index dff6574..22f7459 100644 --- a/crates/lune-std-net/src/server/service.rs +++ b/crates/lune-std-net/src/server/service.rs @@ -1,24 +1,28 @@ use std::{future::Future, net::SocketAddr, pin::Pin}; +use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream}; use http_body_util::Full; use hyper::{ body::{Bytes, Incoming}, service::Service as HyperService, - Request as HyperRequest, Response as HyperResponse, + Request as HyperRequest, Response as HyperResponse, StatusCode, }; use mlua::prelude::*; -use mlua_luau_scheduler::LuaSchedulerExt; +use mlua_luau_scheduler::{LuaSchedulerExt, LuaSpawnExt}; use crate::{ - server::config::{ResponseConfig, ServeConfig}, - shared::{request::Request, response::Response}, + server::{ + config::{ResponseConfig, ServeConfig}, + upgrade::{is_upgrade_request, make_upgrade_response}, + }, + shared::{hyper::HyperIo, request::Request, response::Response, websocket::Websocket}, }; #[derive(Debug, Clone)] pub(super) struct Service { pub(super) lua: Lua, - pub(super) address: SocketAddr, // NOTE: This should be the remote address of the connected client + pub(super) address: SocketAddr, // NOTE: This must be the remote address of the connected client pub(super) config: ServeConfig, } @@ -32,24 +36,84 @@ impl HyperService> for Service { let address = self.address; let config = self.config.clone(); + if is_upgrade_request(&req) { + if let Some(handler) = config.handle_web_socket { + return Box::pin(async move { + let response = match make_upgrade_response(&req) { + Ok(res) => res, + Err(err) => { + return Ok(HyperResponse::builder() + .status(StatusCode::BAD_REQUEST) + .body(Full::new(Bytes::from(err.to_string()))) + .unwrap()) + } + }; + + lua.spawn_local({ + let lua = lua.clone(); + async move { + if let Err(_err) = handle_websocket(lua, handler, req).await { + // TODO: Propagare the error somehow? + } + } + }); + + Ok(response) + }); + } + } + Box::pin(async move { - let handler = config.handle_request.clone(); - let request = Request::from_incoming(req, true) - .await? - .with_address(address); - - let thread_id = lua.push_thread_back(handler, request)?; - lua.track_thread(thread_id); - lua.wait_for_thread(thread_id).await; - - let thread_res = lua - .get_thread_result(thread_id) - .expect("Missing handler thread result")?; - - let config = ResponseConfig::from_lua_multi(thread_res, &lua)?; - let response = Response::try_from(config)?; - - Ok(response.as_full()) + match handle_request(lua.clone(), config.handle_request, req, address).await { + Ok(response) => Ok(response), + Err(_err) => { + // TODO: Propagare the error somehow? + Ok(HyperResponse::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Full::new(Bytes::from("Lune: Internal server error"))) + .unwrap()) + } + } }) } } + +async fn handle_request( + lua: Lua, + handler: LuaFunction, + request: HyperRequest, + address: SocketAddr, +) -> LuaResult>> { + let request = Request::from_incoming(request, true) + .await? + .with_address(address); + + let thread_id = lua.push_thread_back(handler, request)?; + lua.track_thread(thread_id); + lua.wait_for_thread(thread_id).await; + + let thread_res = lua + .get_thread_result(thread_id) + .expect("Missing handler thread result")?; + + let config = ResponseConfig::from_lua_multi(thread_res, &lua)?; + let response = Response::try_from(config)?; + + Ok(response.as_full()) +} + +async fn handle_websocket( + lua: Lua, + handler: LuaFunction, + request: HyperRequest, +) -> LuaResult<()> { + let upgraded = hyper::upgrade::on(request).await.into_lua_err()?; + + let stream = + WebSocketStream::from_raw_socket(HyperIo::from(upgraded), Role::Server, None).await; + let websocket = Websocket::from(stream); + + lua.push_thread_back(handler, websocket)?; + + Ok(()) +} diff --git a/crates/lune-std-net/src/server/upgrade.rs b/crates/lune-std-net/src/server/upgrade.rs new file mode 100644 index 0000000..6840957 --- /dev/null +++ b/crates/lune-std-net/src/server/upgrade.rs @@ -0,0 +1,55 @@ +use async_tungstenite::tungstenite::{error::ProtocolError, handshake::derive_accept_key}; +use http_body_util::Full; + +use hyper::{ + body::{Bytes, Incoming}, + header::{HeaderName, CONNECTION, UPGRADE}, + HeaderMap, Request as HyperRequest, Response as HyperResponse, StatusCode, +}; + +const SEC_WEBSOCKET_VERSION: HeaderName = HeaderName::from_static("sec-websocket-version"); +const SEC_WEBSOCKET_KEY: HeaderName = HeaderName::from_static("sec-websocket-key"); +const SEC_WEBSOCKET_ACCEPT: HeaderName = HeaderName::from_static("sec-websocket-accept"); + +pub fn is_upgrade_request(request: &HyperRequest) -> bool { + fn check_header_contains(headers: &HeaderMap, header_name: HeaderName, value: &str) -> bool { + headers.get(header_name).is_some_and(|header| { + header.to_str().map_or_else( + |_| false, + |header_str| { + header_str + .split(',') + .any(|part| part.trim().eq_ignore_ascii_case(value)) + }, + ) + }) + } + + check_header_contains(request.headers(), CONNECTION, "Upgrade") + && check_header_contains(request.headers(), UPGRADE, "websocket") +} + +pub fn make_upgrade_response( + request: &HyperRequest, +) -> Result>, ProtocolError> { + let key = request + .headers() + .get(SEC_WEBSOCKET_KEY) + .ok_or(ProtocolError::MissingSecWebSocketKey)?; + + if request + .headers() + .get(SEC_WEBSOCKET_VERSION) + .is_none_or(|v| v.as_bytes() != b"13") + { + return Err(ProtocolError::MissingSecWebSocketVersionHeader); + } + + Ok(HyperResponse::builder() + .status(StatusCode::SWITCHING_PROTOCOLS) + .header(CONNECTION, "upgrade") + .header(UPGRADE, "websocket") + .header(SEC_WEBSOCKET_ACCEPT, derive_accept_key(key.as_bytes())) + .body(Full::new(Bytes::from("switching to websocket protocol"))) + .unwrap()) +} diff --git a/crates/lune-std-net/src/shared/hyper.rs b/crates/lune-std-net/src/shared/hyper.rs index 2dc1ea8..33fdaa8 100644 --- a/crates/lune-std-net/src/shared/hyper.rs +++ b/crates/lune-std-net/src/shared/hyper.rs @@ -8,8 +8,8 @@ use std::{ }; use async_io::Timer; -use futures_lite::prelude::*; -use hyper::rt::{self, Executor, ReadBufCursor}; +use futures_lite::{prelude::*, ready}; +use hyper::rt::{self, Executor, ReadBuf, ReadBufCursor}; use mlua::prelude::*; use mlua_luau_scheduler::LuaSpawnExt; @@ -94,7 +94,8 @@ impl Future for HyperSleep { impl rt::Sleep for HyperSleep {} -// Hyper I/O wrapper for futures-lite types +// Hyper I/O wrapper for bidirectional compatibility +// between hyper & futures-lite async read/write traits pin_project_lite::pin_project! { #[derive(Debug)] @@ -116,6 +117,8 @@ impl HyperIo { } } +// Compat for futures-lite -> hyper runtime + impl rt::Read for HyperIo { fn poll_read( self: Pin<&mut Self>, @@ -169,3 +172,43 @@ impl rt::Write for HyperIo { self.pin_mut().poll_write_vectored(cx, bufs) } } + +// Compat for hyper runtime -> futures-lite + +impl AsyncRead for HyperIo { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let mut buf = ReadBuf::new(buf); + ready!(self.pin_mut().poll_read(cx, buf.unfilled()))?; + Poll::Ready(Ok(buf.filled().len())) + } +} + +impl AsyncWrite for HyperIo { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.pin_mut().poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.pin_mut().poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.pin_mut().poll_shutdown(cx) + } + + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[std::io::IoSlice<'_>], + ) -> Poll> { + self.pin_mut().poll_write_vectored(cx, bufs) + } +}