mirror of
https://github.com/lune-org/lune.git
synced 2025-05-04 10:43:57 +01:00
Implement full websocket server support and more graceful server internal error handling
This commit is contained in:
parent
3f179ab4ec
commit
bf819d1980
4 changed files with 188 additions and 25 deletions
|
@ -18,6 +18,7 @@ use crate::{
|
|||
pub mod config;
|
||||
pub mod handle;
|
||||
pub mod service;
|
||||
pub mod upgrade;
|
||||
|
||||
/**
|
||||
Starts an HTTP server using the given port and configuration.
|
||||
|
|
|
@ -1,24 +1,28 @@
|
|||
use std::{future::Future, net::SocketAddr, pin::Pin};
|
||||
|
||||
use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
|
||||
use http_body_util::Full;
|
||||
use hyper::{
|
||||
body::{Bytes, Incoming},
|
||||
service::Service as HyperService,
|
||||
Request as HyperRequest, Response as HyperResponse,
|
||||
Request as HyperRequest, Response as HyperResponse, StatusCode,
|
||||
};
|
||||
|
||||
use mlua::prelude::*;
|
||||
use mlua_luau_scheduler::LuaSchedulerExt;
|
||||
use mlua_luau_scheduler::{LuaSchedulerExt, LuaSpawnExt};
|
||||
|
||||
use crate::{
|
||||
server::config::{ResponseConfig, ServeConfig},
|
||||
shared::{request::Request, response::Response},
|
||||
server::{
|
||||
config::{ResponseConfig, ServeConfig},
|
||||
upgrade::{is_upgrade_request, make_upgrade_response},
|
||||
},
|
||||
shared::{hyper::HyperIo, request::Request, response::Response, websocket::Websocket},
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(super) struct Service {
|
||||
pub(super) lua: Lua,
|
||||
pub(super) address: SocketAddr, // NOTE: This should be the remote address of the connected client
|
||||
pub(super) address: SocketAddr, // NOTE: This must be the remote address of the connected client
|
||||
pub(super) config: ServeConfig,
|
||||
}
|
||||
|
||||
|
@ -32,24 +36,84 @@ impl HyperService<HyperRequest<Incoming>> for Service {
|
|||
let address = self.address;
|
||||
let config = self.config.clone();
|
||||
|
||||
if is_upgrade_request(&req) {
|
||||
if let Some(handler) = config.handle_web_socket {
|
||||
return Box::pin(async move {
|
||||
let response = match make_upgrade_response(&req) {
|
||||
Ok(res) => res,
|
||||
Err(err) => {
|
||||
return Ok(HyperResponse::builder()
|
||||
.status(StatusCode::BAD_REQUEST)
|
||||
.body(Full::new(Bytes::from(err.to_string())))
|
||||
.unwrap())
|
||||
}
|
||||
};
|
||||
|
||||
lua.spawn_local({
|
||||
let lua = lua.clone();
|
||||
async move {
|
||||
if let Err(_err) = handle_websocket(lua, handler, req).await {
|
||||
// TODO: Propagare the error somehow?
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(response)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Box::pin(async move {
|
||||
let handler = config.handle_request.clone();
|
||||
let request = Request::from_incoming(req, true)
|
||||
.await?
|
||||
.with_address(address);
|
||||
|
||||
let thread_id = lua.push_thread_back(handler, request)?;
|
||||
lua.track_thread(thread_id);
|
||||
lua.wait_for_thread(thread_id).await;
|
||||
|
||||
let thread_res = lua
|
||||
.get_thread_result(thread_id)
|
||||
.expect("Missing handler thread result")?;
|
||||
|
||||
let config = ResponseConfig::from_lua_multi(thread_res, &lua)?;
|
||||
let response = Response::try_from(config)?;
|
||||
|
||||
Ok(response.as_full())
|
||||
match handle_request(lua.clone(), config.handle_request, req, address).await {
|
||||
Ok(response) => Ok(response),
|
||||
Err(_err) => {
|
||||
// TODO: Propagare the error somehow?
|
||||
Ok(HyperResponse::builder()
|
||||
.status(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.body(Full::new(Bytes::from("Lune: Internal server error")))
|
||||
.unwrap())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_request(
|
||||
lua: Lua,
|
||||
handler: LuaFunction,
|
||||
request: HyperRequest<Incoming>,
|
||||
address: SocketAddr,
|
||||
) -> LuaResult<HyperResponse<Full<Bytes>>> {
|
||||
let request = Request::from_incoming(request, true)
|
||||
.await?
|
||||
.with_address(address);
|
||||
|
||||
let thread_id = lua.push_thread_back(handler, request)?;
|
||||
lua.track_thread(thread_id);
|
||||
lua.wait_for_thread(thread_id).await;
|
||||
|
||||
let thread_res = lua
|
||||
.get_thread_result(thread_id)
|
||||
.expect("Missing handler thread result")?;
|
||||
|
||||
let config = ResponseConfig::from_lua_multi(thread_res, &lua)?;
|
||||
let response = Response::try_from(config)?;
|
||||
|
||||
Ok(response.as_full())
|
||||
}
|
||||
|
||||
async fn handle_websocket(
|
||||
lua: Lua,
|
||||
handler: LuaFunction,
|
||||
request: HyperRequest<Incoming>,
|
||||
) -> LuaResult<()> {
|
||||
let upgraded = hyper::upgrade::on(request).await.into_lua_err()?;
|
||||
|
||||
let stream =
|
||||
WebSocketStream::from_raw_socket(HyperIo::from(upgraded), Role::Server, None).await;
|
||||
let websocket = Websocket::from(stream);
|
||||
|
||||
lua.push_thread_back(handler, websocket)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
55
crates/lune-std-net/src/server/upgrade.rs
Normal file
55
crates/lune-std-net/src/server/upgrade.rs
Normal file
|
@ -0,0 +1,55 @@
|
|||
use async_tungstenite::tungstenite::{error::ProtocolError, handshake::derive_accept_key};
|
||||
use http_body_util::Full;
|
||||
|
||||
use hyper::{
|
||||
body::{Bytes, Incoming},
|
||||
header::{HeaderName, CONNECTION, UPGRADE},
|
||||
HeaderMap, Request as HyperRequest, Response as HyperResponse, StatusCode,
|
||||
};
|
||||
|
||||
const SEC_WEBSOCKET_VERSION: HeaderName = HeaderName::from_static("sec-websocket-version");
|
||||
const SEC_WEBSOCKET_KEY: HeaderName = HeaderName::from_static("sec-websocket-key");
|
||||
const SEC_WEBSOCKET_ACCEPT: HeaderName = HeaderName::from_static("sec-websocket-accept");
|
||||
|
||||
pub fn is_upgrade_request(request: &HyperRequest<Incoming>) -> bool {
|
||||
fn check_header_contains(headers: &HeaderMap, header_name: HeaderName, value: &str) -> bool {
|
||||
headers.get(header_name).is_some_and(|header| {
|
||||
header.to_str().map_or_else(
|
||||
|_| false,
|
||||
|header_str| {
|
||||
header_str
|
||||
.split(',')
|
||||
.any(|part| part.trim().eq_ignore_ascii_case(value))
|
||||
},
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
check_header_contains(request.headers(), CONNECTION, "Upgrade")
|
||||
&& check_header_contains(request.headers(), UPGRADE, "websocket")
|
||||
}
|
||||
|
||||
pub fn make_upgrade_response(
|
||||
request: &HyperRequest<Incoming>,
|
||||
) -> Result<HyperResponse<Full<Bytes>>, ProtocolError> {
|
||||
let key = request
|
||||
.headers()
|
||||
.get(SEC_WEBSOCKET_KEY)
|
||||
.ok_or(ProtocolError::MissingSecWebSocketKey)?;
|
||||
|
||||
if request
|
||||
.headers()
|
||||
.get(SEC_WEBSOCKET_VERSION)
|
||||
.is_none_or(|v| v.as_bytes() != b"13")
|
||||
{
|
||||
return Err(ProtocolError::MissingSecWebSocketVersionHeader);
|
||||
}
|
||||
|
||||
Ok(HyperResponse::builder()
|
||||
.status(StatusCode::SWITCHING_PROTOCOLS)
|
||||
.header(CONNECTION, "upgrade")
|
||||
.header(UPGRADE, "websocket")
|
||||
.header(SEC_WEBSOCKET_ACCEPT, derive_accept_key(key.as_bytes()))
|
||||
.body(Full::new(Bytes::from("switching to websocket protocol")))
|
||||
.unwrap())
|
||||
}
|
|
@ -8,8 +8,8 @@ use std::{
|
|||
};
|
||||
|
||||
use async_io::Timer;
|
||||
use futures_lite::prelude::*;
|
||||
use hyper::rt::{self, Executor, ReadBufCursor};
|
||||
use futures_lite::{prelude::*, ready};
|
||||
use hyper::rt::{self, Executor, ReadBuf, ReadBufCursor};
|
||||
use mlua::prelude::*;
|
||||
use mlua_luau_scheduler::LuaSpawnExt;
|
||||
|
||||
|
@ -94,7 +94,8 @@ impl Future for HyperSleep {
|
|||
|
||||
impl rt::Sleep for HyperSleep {}
|
||||
|
||||
// Hyper I/O wrapper for futures-lite types
|
||||
// Hyper I/O wrapper for bidirectional compatibility
|
||||
// between hyper & futures-lite async read/write traits
|
||||
|
||||
pin_project_lite::pin_project! {
|
||||
#[derive(Debug)]
|
||||
|
@ -116,6 +117,8 @@ impl<T> HyperIo<T> {
|
|||
}
|
||||
}
|
||||
|
||||
// Compat for futures-lite -> hyper runtime
|
||||
|
||||
impl<T: AsyncRead> rt::Read for HyperIo<T> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
|
@ -169,3 +172,43 @@ impl<T: AsyncWrite> rt::Write for HyperIo<T> {
|
|||
self.pin_mut().poll_write_vectored(cx, bufs)
|
||||
}
|
||||
}
|
||||
|
||||
// Compat for hyper runtime -> futures-lite
|
||||
|
||||
impl<T: rt::Read> AsyncRead for HyperIo<T> {
|
||||
fn poll_read(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
let mut buf = ReadBuf::new(buf);
|
||||
ready!(self.pin_mut().poll_read(cx, buf.unfilled()))?;
|
||||
Poll::Ready(Ok(buf.filled().len()))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: rt::Write> AsyncWrite for HyperIo<T> {
|
||||
fn poll_write(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
self.pin_mut().poll_write(cx, buf)
|
||||
}
|
||||
|
||||
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
|
||||
self.pin_mut().poll_flush(cx)
|
||||
}
|
||||
|
||||
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
self.pin_mut().poll_shutdown(cx)
|
||||
}
|
||||
|
||||
fn poll_write_vectored(
|
||||
self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
bufs: &[std::io::IoSlice<'_>],
|
||||
) -> Poll<Result<usize, std::io::Error>> {
|
||||
self.pin_mut().poll_write_vectored(cx, bufs)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue