Implement full websocket server support and more graceful server internal error handling

This commit is contained in:
Filip Tibell 2025-04-27 22:10:57 +02:00
parent 3f179ab4ec
commit bf819d1980
No known key found for this signature in database
4 changed files with 188 additions and 25 deletions

View file

@ -18,6 +18,7 @@ use crate::{
pub mod config;
pub mod handle;
pub mod service;
pub mod upgrade;
/**
Starts an HTTP server using the given port and configuration.

View file

@ -1,24 +1,28 @@
use std::{future::Future, net::SocketAddr, pin::Pin};
use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
use http_body_util::Full;
use hyper::{
body::{Bytes, Incoming},
service::Service as HyperService,
Request as HyperRequest, Response as HyperResponse,
Request as HyperRequest, Response as HyperResponse, StatusCode,
};
use mlua::prelude::*;
use mlua_luau_scheduler::LuaSchedulerExt;
use mlua_luau_scheduler::{LuaSchedulerExt, LuaSpawnExt};
use crate::{
server::config::{ResponseConfig, ServeConfig},
shared::{request::Request, response::Response},
server::{
config::{ResponseConfig, ServeConfig},
upgrade::{is_upgrade_request, make_upgrade_response},
},
shared::{hyper::HyperIo, request::Request, response::Response, websocket::Websocket},
};
#[derive(Debug, Clone)]
pub(super) struct Service {
pub(super) lua: Lua,
pub(super) address: SocketAddr, // NOTE: This should be the remote address of the connected client
pub(super) address: SocketAddr, // NOTE: This must be the remote address of the connected client
pub(super) config: ServeConfig,
}
@ -32,24 +36,84 @@ impl HyperService<HyperRequest<Incoming>> for Service {
let address = self.address;
let config = self.config.clone();
if is_upgrade_request(&req) {
if let Some(handler) = config.handle_web_socket {
return Box::pin(async move {
let response = match make_upgrade_response(&req) {
Ok(res) => res,
Err(err) => {
return Ok(HyperResponse::builder()
.status(StatusCode::BAD_REQUEST)
.body(Full::new(Bytes::from(err.to_string())))
.unwrap())
}
};
lua.spawn_local({
let lua = lua.clone();
async move {
if let Err(_err) = handle_websocket(lua, handler, req).await {
// TODO: Propagare the error somehow?
}
}
});
Ok(response)
});
}
}
Box::pin(async move {
let handler = config.handle_request.clone();
let request = Request::from_incoming(req, true)
.await?
.with_address(address);
let thread_id = lua.push_thread_back(handler, request)?;
lua.track_thread(thread_id);
lua.wait_for_thread(thread_id).await;
let thread_res = lua
.get_thread_result(thread_id)
.expect("Missing handler thread result")?;
let config = ResponseConfig::from_lua_multi(thread_res, &lua)?;
let response = Response::try_from(config)?;
Ok(response.as_full())
match handle_request(lua.clone(), config.handle_request, req, address).await {
Ok(response) => Ok(response),
Err(_err) => {
// TODO: Propagare the error somehow?
Ok(HyperResponse::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Full::new(Bytes::from("Lune: Internal server error")))
.unwrap())
}
}
})
}
}
async fn handle_request(
lua: Lua,
handler: LuaFunction,
request: HyperRequest<Incoming>,
address: SocketAddr,
) -> LuaResult<HyperResponse<Full<Bytes>>> {
let request = Request::from_incoming(request, true)
.await?
.with_address(address);
let thread_id = lua.push_thread_back(handler, request)?;
lua.track_thread(thread_id);
lua.wait_for_thread(thread_id).await;
let thread_res = lua
.get_thread_result(thread_id)
.expect("Missing handler thread result")?;
let config = ResponseConfig::from_lua_multi(thread_res, &lua)?;
let response = Response::try_from(config)?;
Ok(response.as_full())
}
async fn handle_websocket(
lua: Lua,
handler: LuaFunction,
request: HyperRequest<Incoming>,
) -> LuaResult<()> {
let upgraded = hyper::upgrade::on(request).await.into_lua_err()?;
let stream =
WebSocketStream::from_raw_socket(HyperIo::from(upgraded), Role::Server, None).await;
let websocket = Websocket::from(stream);
lua.push_thread_back(handler, websocket)?;
Ok(())
}

View file

@ -0,0 +1,55 @@
use async_tungstenite::tungstenite::{error::ProtocolError, handshake::derive_accept_key};
use http_body_util::Full;
use hyper::{
body::{Bytes, Incoming},
header::{HeaderName, CONNECTION, UPGRADE},
HeaderMap, Request as HyperRequest, Response as HyperResponse, StatusCode,
};
const SEC_WEBSOCKET_VERSION: HeaderName = HeaderName::from_static("sec-websocket-version");
const SEC_WEBSOCKET_KEY: HeaderName = HeaderName::from_static("sec-websocket-key");
const SEC_WEBSOCKET_ACCEPT: HeaderName = HeaderName::from_static("sec-websocket-accept");
pub fn is_upgrade_request(request: &HyperRequest<Incoming>) -> bool {
fn check_header_contains(headers: &HeaderMap, header_name: HeaderName, value: &str) -> bool {
headers.get(header_name).is_some_and(|header| {
header.to_str().map_or_else(
|_| false,
|header_str| {
header_str
.split(',')
.any(|part| part.trim().eq_ignore_ascii_case(value))
},
)
})
}
check_header_contains(request.headers(), CONNECTION, "Upgrade")
&& check_header_contains(request.headers(), UPGRADE, "websocket")
}
pub fn make_upgrade_response(
request: &HyperRequest<Incoming>,
) -> Result<HyperResponse<Full<Bytes>>, ProtocolError> {
let key = request
.headers()
.get(SEC_WEBSOCKET_KEY)
.ok_or(ProtocolError::MissingSecWebSocketKey)?;
if request
.headers()
.get(SEC_WEBSOCKET_VERSION)
.is_none_or(|v| v.as_bytes() != b"13")
{
return Err(ProtocolError::MissingSecWebSocketVersionHeader);
}
Ok(HyperResponse::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(CONNECTION, "upgrade")
.header(UPGRADE, "websocket")
.header(SEC_WEBSOCKET_ACCEPT, derive_accept_key(key.as_bytes()))
.body(Full::new(Bytes::from("switching to websocket protocol")))
.unwrap())
}

View file

@ -8,8 +8,8 @@ use std::{
};
use async_io::Timer;
use futures_lite::prelude::*;
use hyper::rt::{self, Executor, ReadBufCursor};
use futures_lite::{prelude::*, ready};
use hyper::rt::{self, Executor, ReadBuf, ReadBufCursor};
use mlua::prelude::*;
use mlua_luau_scheduler::LuaSpawnExt;
@ -94,7 +94,8 @@ impl Future for HyperSleep {
impl rt::Sleep for HyperSleep {}
// Hyper I/O wrapper for futures-lite types
// Hyper I/O wrapper for bidirectional compatibility
// between hyper & futures-lite async read/write traits
pin_project_lite::pin_project! {
#[derive(Debug)]
@ -116,6 +117,8 @@ impl<T> HyperIo<T> {
}
}
// Compat for futures-lite -> hyper runtime
impl<T: AsyncRead> rt::Read for HyperIo<T> {
fn poll_read(
self: Pin<&mut Self>,
@ -169,3 +172,43 @@ impl<T: AsyncWrite> rt::Write for HyperIo<T> {
self.pin_mut().poll_write_vectored(cx, bufs)
}
}
// Compat for hyper runtime -> futures-lite
impl<T: rt::Read> AsyncRead for HyperIo<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let mut buf = ReadBuf::new(buf);
ready!(self.pin_mut().poll_read(cx, buf.unfilled()))?;
Poll::Ready(Ok(buf.filled().len()))
}
}
impl<T: rt::Write> AsyncWrite for HyperIo<T> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
self.pin_mut().poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
self.pin_mut().poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.pin_mut().poll_shutdown(cx)
}
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[std::io::IoSlice<'_>],
) -> Poll<Result<usize, std::io::Error>> {
self.pin_mut().poll_write_vectored(cx, bufs)
}
}