Implement serve handle again and proper graceful shutdowns

This commit is contained in:
Filip Tibell 2025-04-27 15:33:02 +02:00
parent d57a7b949d
commit a82cb1da33
18 changed files with 178 additions and 36 deletions

View file

@ -49,5 +49,5 @@ print(`Listening on port {PORT} 🚀`)
task.delay(2, function()
print("Shutting down...")
task.wait(1)
handle.stop()
handle:stop()
end)

View file

@ -32,6 +32,6 @@ print(`Listening on port {PORT} 🚀`)
task.delay(10, function()
print("Shutting down...")
task.wait(1)
handle.stop()
handle:stop()
task.wait(1)
end)

1
Cargo.lock generated
View file

@ -1852,6 +1852,7 @@ dependencies = [
name = "lune-std-net"
version = "0.2.0"
dependencies = [
"async-channel",
"async-executor",
"async-io",
"async-lock",

View file

@ -16,6 +16,7 @@ workspace = true
mlua = { version = "0.10.3", features = ["luau"] }
mlua-luau-scheduler = { version = "0.1.0", path = "../mlua-luau-scheduler" }
async-channel = "2.3"
async-executor = "1.13"
async-io = "2.4"
async-lock = "3.4"

View file

@ -41,7 +41,7 @@ pub async fn send_request(mut request: Request, lua: Lua) -> LuaResult<Response>
.into_lua_err()?;
if let Some((new_method, new_uri)) = check_redirect(&request.inner, &incoming) {
if request.redirects >= MAX_REDIRECTS {
if request.redirects.is_some_and(|r| r >= MAX_REDIRECTS) {
return Err(LuaError::external("Too many redirects"));
}
@ -52,7 +52,7 @@ pub async fn send_request(mut request: Request, lua: Lua) -> LuaResult<Response>
*request.inner.method_mut() = new_method;
*request.inner.uri_mut() = new_uri;
request.redirects += 1;
*request.redirects.get_or_insert_default() += 1;
continue;
}

View file

@ -10,7 +10,7 @@ pub(crate) mod url;
use self::{
client::config::RequestConfig,
server::config::ServeConfig,
server::{config::ServeConfig, handle::ServeHandle},
shared::{request::Request, response::Response},
};
@ -45,7 +45,7 @@ async fn net_request(lua: Lua, config: RequestConfig) -> LuaResult<Response> {
self::client::send_request(Request::try_from(config)?, lua).await
}
async fn net_serve(lua: Lua, (port, config): (u16, ServeConfig)) -> LuaResult<()> {
async fn net_serve(lua: Lua, (port, config): (u16, ServeConfig)) -> LuaResult<ServeHandle> {
self::server::serve(lua, port, config).await
}

View file

@ -0,0 +1,50 @@
use std::{
net::SocketAddr,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use async_channel::{unbounded, Receiver, Sender};
use mlua::prelude::*;
#[derive(Debug, Clone)]
pub struct ServeHandle {
addr: SocketAddr,
shutdown: Arc<AtomicBool>,
sender: Sender<()>,
}
impl ServeHandle {
pub fn new(addr: SocketAddr) -> (Self, Receiver<()>) {
let (sender, receiver) = unbounded();
let this = Self {
addr,
shutdown: Arc::new(AtomicBool::new(false)),
sender,
};
(this, receiver)
}
}
impl LuaUserData for ServeHandle {
fn add_fields<F: LuaUserDataFields<Self>>(fields: &mut F) {
fields.add_field_method_get("ip", |_, this| Ok(this.addr.ip().to_string()));
fields.add_field_method_get("port", |_, this| Ok(this.addr.port()));
}
fn add_methods<M: LuaUserDataMethods<Self>>(methods: &mut M) {
methods.add_method("stop", |_, this, ()| {
if this.shutdown.load(Ordering::SeqCst) {
Err(LuaError::runtime("Server already stopped"))
} else {
this.shutdown.store(true, Ordering::SeqCst);
this.sender.try_send(()).ok();
this.sender.close();
Ok(())
}
});
}
}

View file

@ -1,23 +1,30 @@
use std::net::SocketAddr;
use async_net::TcpListener;
use futures_lite::pin;
use hyper::server::conn::http1::Builder as Http1Builder;
use mlua::prelude::*;
use mlua_luau_scheduler::LuaSpawnExt;
use crate::{
server::{config::ServeConfig, service::Service},
shared::hyper::{HyperIo, HyperTimer},
server::{config::ServeConfig, handle::ServeHandle, service::Service},
shared::{
futures::{either, Either},
hyper::{HyperIo, HyperTimer},
},
};
pub mod config;
pub mod handle;
pub mod service;
/**
Starts an HTTP server using the given port and configuration.
Returns a `ServeHandle` that can be used to gracefully stop the server.
*/
pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult<()> {
pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult<ServeHandle> {
let address = SocketAddr::from((config.address, port));
let service = Service {
lua: lua.clone(),
@ -26,13 +33,33 @@ pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult<()> {
};
let listener = TcpListener::bind(address).await?;
let (handle, shutdown_rx) = ServeHandle::new(address);
lua.spawn_local({
let lua = lua.clone();
async move {
let mut running_forever = false;
loop {
let (connection, _addr) = match listener.accept().await {
Ok((connection, addr)) => (connection, addr),
let accepted = if running_forever {
listener.accept().await
} else {
match either(shutdown_rx.recv(), listener.accept()).await {
Either::Left(res) => {
if res.is_ok() {
break;
}
// NOTE: We will only get a RecvError if the serve handle is dropped,
// this means lua has garbage collected it and the user does not want
// to manually stop the server using the serve handle. Run forever.
running_forever = true;
continue;
}
Either::Right(acc) => acc,
}
};
let (conn, addr) = match accepted {
Ok((conn, addr)) => (conn, addr),
Err(err) => {
eprintln!("Error while accepting connection: {err}");
continue;
@ -40,16 +67,22 @@ pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult<()> {
};
lua.spawn_local({
let service = service.clone();
let rx = shutdown_rx.clone();
let io = HyperIo::from(conn);
let mut svc = service.clone();
svc.address = addr;
async move {
let result = Http1Builder::new()
let conn = Http1Builder::new()
.timer(HyperTimer)
.keep_alive(true) // Needed for websockets
.serve_connection(HyperIo::from(connection), service)
.with_upgrades()
.await;
if let Err(err) = result {
eprintln!("Error while responding to request: {err}");
.keep_alive(true)
.serve_connection(io, svc)
.with_upgrades();
// NOTE: Because we use keep_alive for websockets above, we need to
// also manually poll this future and handle the shutdown signal here
pin!(conn);
match either(rx.recv(), conn.as_mut()).await {
Either::Left(_) => conn.as_mut().graceful_shutdown(),
Either::Right(_) => {}
}
}
});
@ -57,5 +90,5 @@ pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult<()> {
}
});
Ok(())
Ok(handle)
}

View file

@ -18,7 +18,7 @@ use crate::{
#[derive(Debug, Clone)]
pub(super) struct Service {
pub(super) lua: Lua,
pub(super) address: SocketAddr,
pub(super) address: SocketAddr, // NOTE: This should be the remote address of the connected client
pub(super) config: ServeConfig,
}
@ -29,11 +29,14 @@ impl HyperService<HyperRequest<Incoming>> for Service {
fn call(&self, req: HyperRequest<Incoming>) -> Self::Future {
let lua = self.lua.clone();
let address = self.address;
let config = self.config.clone();
Box::pin(async move {
let handler = config.handle_request.clone();
let request = Request::from_incoming(req, true).await?;
let request = Request::from_incoming(req, true)
.await?
.with_address(address);
let thread_id = lua.push_thread_back(handler, request)?;
lua.track_thread(thread_id);

View file

@ -0,0 +1,19 @@
use futures_lite::prelude::*;
pub use http_body_util::Either;
/**
Combines the left and right futures into a single future
that resolves to either the left or right output.
This combinator is biased - if both futures resolve at
the same time, the left future's output is returned.
*/
pub fn either<L: Future, R: Future>(
left: L,
right: R,
) -> impl Future<Output = Either<L::Output, R::Output>> {
let fut_left = async move { Either::Left(left.await) };
let fut_right = async move { Either::Right(right.await) };
fut_left.or(fut_right)
}

View file

@ -1,3 +1,4 @@
pub mod futures;
pub mod headers;
pub mod hyper;
pub mod incoming;

View file

@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::{collections::HashMap, net::SocketAddr};
use http_body_util::Full;
use url::Url;
@ -24,7 +24,8 @@ pub struct Request {
// NOTE: We use Bytes instead of Full<Bytes> to avoid
// needing async when getting a reference to the body
pub(crate) inner: HyperRequest<Bytes>,
pub(crate) redirects: usize,
pub(crate) address: Option<SocketAddr>,
pub(crate) redirects: Option<usize>,
pub(crate) decompress: bool,
}
@ -42,11 +43,22 @@ impl Request {
Ok(Self {
inner: HyperRequest::from_parts(parts, body),
redirects: 0,
address: None,
redirects: None,
decompress,
})
}
/**
Attaches a socket address to the request.
This will make the `ip` and `port` fields available on the request.
*/
pub fn with_address(mut self, address: SocketAddr) -> Self {
self.address = Some(address);
self
}
/**
Returns the method of the request.
*/
@ -154,7 +166,8 @@ impl TryFrom<RequestConfig> for Request {
Ok(Self {
inner,
redirects: 0,
address: None,
redirects: None,
decompress: config.options.decompress,
})
}
@ -162,6 +175,12 @@ impl TryFrom<RequestConfig> for Request {
impl LuaUserData for Request {
fn add_fields<F: LuaUserDataFields<Self>>(fields: &mut F) {
fields.add_field_method_get("ip", |_, this| {
Ok(this.address.map(|address| address.ip().to_string()))
});
fields.add_field_method_get("port", |_, this| {
Ok(this.address.map(|address| address.port()))
});
fields.add_field_method_get("method", |_, this| Ok(this.method().to_string()));
fields.add_field_method_get("path", |_, this| Ok(this.path().to_string()));
fields.add_field_method_get("query", |lua, this| {

View file

@ -28,7 +28,7 @@ local handle = net.serve(PORT, function()
end)
task.delay(0.25, function()
handle.stop()
handle:stop()
end)
test(net.serve, PORT, function() end)

View file

@ -19,7 +19,7 @@ local handle = net.serve(PORT, {
local response = net.request(`{LOCALHOST}:{PORT}`).body
assert(response ~= nil, "Invalid response from server")
handle.stop()
handle:stop()
-- Attempting to serve with a malformed IP address should throw an error

View file

@ -12,7 +12,7 @@ end)
-- Stopping is not guaranteed to happen instantly since it is async, but
-- it should happen on the next yield, so we wait the minimum amount here
handle.stop()
handle:stop()
task.wait()
-- Sending a request to the stopped server should now error

View file

@ -21,4 +21,4 @@ end)
task.cancel(thread)
handle.stop()
handle:stop()

View file

@ -7,24 +7,39 @@ local PORT = 8083
local URL = `http://127.0.0.1:{PORT}`
local RESPONSE = "Hello, lune!"
-- Serve should respond to a request we send to it
-- Serve should get proper path, query, and other request information
local handle = net.serve(PORT, function(request)
-- print("Got a request from", request.ip, "on port", request.port)
assert(type(request.path) == "string")
assert(type(request.query) == "table")
assert(type(request.query.key) == "table")
assert(type(request.query.key2) == "string")
assert(request.path == "/some/path")
assert(request.query.key == "param2")
assert(request.query.key[1] == "param1")
assert(request.query.key[2] == "param2")
assert(request.query.key2 == "param3")
return RESPONSE
end)
-- Serve should be able to handle at least 1000 requests per second with a basic handler such as the above
local thread = task.delay(1, function()
stdio.ewrite("Serve should respond to requests in a reasonable amount of time\n")
task.wait(1)
process.exit(1)
end)
local response = net.request(URL .. "/some/path?key=param1&key=param2&key2=param3").body
assert(response == RESPONSE, "Invalid response from server")
-- Serve should respond to requests we send, and keep responding until we stop it
for _ = 1, 1024 do
local response = net.request(URL .. "/some/path?key=param1&key=param2&key2=param3").body
assert(response == RESPONSE, "Invalid response from server")
end
task.cancel(thread)
handle.stop()
handle:stop()

View file

@ -64,4 +64,4 @@ assert(
)
-- Stop the server to end the test
handle.stop()
handle:stop()