mirror of
https://github.com/lune-org/lune.git
synced 2025-05-04 10:43:57 +01:00
Implement serve handle again and proper graceful shutdowns
This commit is contained in:
parent
d57a7b949d
commit
a82cb1da33
18 changed files with 178 additions and 36 deletions
|
@ -49,5 +49,5 @@ print(`Listening on port {PORT} 🚀`)
|
||||||
task.delay(2, function()
|
task.delay(2, function()
|
||||||
print("Shutting down...")
|
print("Shutting down...")
|
||||||
task.wait(1)
|
task.wait(1)
|
||||||
handle.stop()
|
handle:stop()
|
||||||
end)
|
end)
|
||||||
|
|
|
@ -32,6 +32,6 @@ print(`Listening on port {PORT} 🚀`)
|
||||||
task.delay(10, function()
|
task.delay(10, function()
|
||||||
print("Shutting down...")
|
print("Shutting down...")
|
||||||
task.wait(1)
|
task.wait(1)
|
||||||
handle.stop()
|
handle:stop()
|
||||||
task.wait(1)
|
task.wait(1)
|
||||||
end)
|
end)
|
||||||
|
|
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -1852,6 +1852,7 @@ dependencies = [
|
||||||
name = "lune-std-net"
|
name = "lune-std-net"
|
||||||
version = "0.2.0"
|
version = "0.2.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"async-channel",
|
||||||
"async-executor",
|
"async-executor",
|
||||||
"async-io",
|
"async-io",
|
||||||
"async-lock",
|
"async-lock",
|
||||||
|
|
|
@ -16,6 +16,7 @@ workspace = true
|
||||||
mlua = { version = "0.10.3", features = ["luau"] }
|
mlua = { version = "0.10.3", features = ["luau"] }
|
||||||
mlua-luau-scheduler = { version = "0.1.0", path = "../mlua-luau-scheduler" }
|
mlua-luau-scheduler = { version = "0.1.0", path = "../mlua-luau-scheduler" }
|
||||||
|
|
||||||
|
async-channel = "2.3"
|
||||||
async-executor = "1.13"
|
async-executor = "1.13"
|
||||||
async-io = "2.4"
|
async-io = "2.4"
|
||||||
async-lock = "3.4"
|
async-lock = "3.4"
|
||||||
|
|
|
@ -41,7 +41,7 @@ pub async fn send_request(mut request: Request, lua: Lua) -> LuaResult<Response>
|
||||||
.into_lua_err()?;
|
.into_lua_err()?;
|
||||||
|
|
||||||
if let Some((new_method, new_uri)) = check_redirect(&request.inner, &incoming) {
|
if let Some((new_method, new_uri)) = check_redirect(&request.inner, &incoming) {
|
||||||
if request.redirects >= MAX_REDIRECTS {
|
if request.redirects.is_some_and(|r| r >= MAX_REDIRECTS) {
|
||||||
return Err(LuaError::external("Too many redirects"));
|
return Err(LuaError::external("Too many redirects"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -52,7 +52,7 @@ pub async fn send_request(mut request: Request, lua: Lua) -> LuaResult<Response>
|
||||||
*request.inner.method_mut() = new_method;
|
*request.inner.method_mut() = new_method;
|
||||||
*request.inner.uri_mut() = new_uri;
|
*request.inner.uri_mut() = new_uri;
|
||||||
|
|
||||||
request.redirects += 1;
|
*request.redirects.get_or_insert_default() += 1;
|
||||||
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,7 +10,7 @@ pub(crate) mod url;
|
||||||
|
|
||||||
use self::{
|
use self::{
|
||||||
client::config::RequestConfig,
|
client::config::RequestConfig,
|
||||||
server::config::ServeConfig,
|
server::{config::ServeConfig, handle::ServeHandle},
|
||||||
shared::{request::Request, response::Response},
|
shared::{request::Request, response::Response},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -45,7 +45,7 @@ async fn net_request(lua: Lua, config: RequestConfig) -> LuaResult<Response> {
|
||||||
self::client::send_request(Request::try_from(config)?, lua).await
|
self::client::send_request(Request::try_from(config)?, lua).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn net_serve(lua: Lua, (port, config): (u16, ServeConfig)) -> LuaResult<()> {
|
async fn net_serve(lua: Lua, (port, config): (u16, ServeConfig)) -> LuaResult<ServeHandle> {
|
||||||
self::server::serve(lua, port, config).await
|
self::server::serve(lua, port, config).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
50
crates/lune-std-net/src/server/handle.rs
Normal file
50
crates/lune-std-net/src/server/handle.rs
Normal file
|
@ -0,0 +1,50 @@
|
||||||
|
use std::{
|
||||||
|
net::SocketAddr,
|
||||||
|
sync::{
|
||||||
|
atomic::{AtomicBool, Ordering},
|
||||||
|
Arc,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
use async_channel::{unbounded, Receiver, Sender};
|
||||||
|
|
||||||
|
use mlua::prelude::*;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ServeHandle {
|
||||||
|
addr: SocketAddr,
|
||||||
|
shutdown: Arc<AtomicBool>,
|
||||||
|
sender: Sender<()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ServeHandle {
|
||||||
|
pub fn new(addr: SocketAddr) -> (Self, Receiver<()>) {
|
||||||
|
let (sender, receiver) = unbounded();
|
||||||
|
let this = Self {
|
||||||
|
addr,
|
||||||
|
shutdown: Arc::new(AtomicBool::new(false)),
|
||||||
|
sender,
|
||||||
|
};
|
||||||
|
(this, receiver)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LuaUserData for ServeHandle {
|
||||||
|
fn add_fields<F: LuaUserDataFields<Self>>(fields: &mut F) {
|
||||||
|
fields.add_field_method_get("ip", |_, this| Ok(this.addr.ip().to_string()));
|
||||||
|
fields.add_field_method_get("port", |_, this| Ok(this.addr.port()));
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_methods<M: LuaUserDataMethods<Self>>(methods: &mut M) {
|
||||||
|
methods.add_method("stop", |_, this, ()| {
|
||||||
|
if this.shutdown.load(Ordering::SeqCst) {
|
||||||
|
Err(LuaError::runtime("Server already stopped"))
|
||||||
|
} else {
|
||||||
|
this.shutdown.store(true, Ordering::SeqCst);
|
||||||
|
this.sender.try_send(()).ok();
|
||||||
|
this.sender.close();
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,23 +1,30 @@
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
|
|
||||||
use async_net::TcpListener;
|
use async_net::TcpListener;
|
||||||
|
use futures_lite::pin;
|
||||||
use hyper::server::conn::http1::Builder as Http1Builder;
|
use hyper::server::conn::http1::Builder as Http1Builder;
|
||||||
|
|
||||||
use mlua::prelude::*;
|
use mlua::prelude::*;
|
||||||
use mlua_luau_scheduler::LuaSpawnExt;
|
use mlua_luau_scheduler::LuaSpawnExt;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
server::{config::ServeConfig, service::Service},
|
server::{config::ServeConfig, handle::ServeHandle, service::Service},
|
||||||
shared::hyper::{HyperIo, HyperTimer},
|
shared::{
|
||||||
|
futures::{either, Either},
|
||||||
|
hyper::{HyperIo, HyperTimer},
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
pub mod config;
|
pub mod config;
|
||||||
|
pub mod handle;
|
||||||
pub mod service;
|
pub mod service;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
Starts an HTTP server using the given port and configuration.
|
Starts an HTTP server using the given port and configuration.
|
||||||
|
|
||||||
|
Returns a `ServeHandle` that can be used to gracefully stop the server.
|
||||||
*/
|
*/
|
||||||
pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult<()> {
|
pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult<ServeHandle> {
|
||||||
let address = SocketAddr::from((config.address, port));
|
let address = SocketAddr::from((config.address, port));
|
||||||
let service = Service {
|
let service = Service {
|
||||||
lua: lua.clone(),
|
lua: lua.clone(),
|
||||||
|
@ -26,13 +33,33 @@ pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult<()> {
|
||||||
};
|
};
|
||||||
|
|
||||||
let listener = TcpListener::bind(address).await?;
|
let listener = TcpListener::bind(address).await?;
|
||||||
|
let (handle, shutdown_rx) = ServeHandle::new(address);
|
||||||
|
|
||||||
lua.spawn_local({
|
lua.spawn_local({
|
||||||
let lua = lua.clone();
|
let lua = lua.clone();
|
||||||
async move {
|
async move {
|
||||||
|
let mut running_forever = false;
|
||||||
loop {
|
loop {
|
||||||
let (connection, _addr) = match listener.accept().await {
|
let accepted = if running_forever {
|
||||||
Ok((connection, addr)) => (connection, addr),
|
listener.accept().await
|
||||||
|
} else {
|
||||||
|
match either(shutdown_rx.recv(), listener.accept()).await {
|
||||||
|
Either::Left(res) => {
|
||||||
|
if res.is_ok() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
// NOTE: We will only get a RecvError if the serve handle is dropped,
|
||||||
|
// this means lua has garbage collected it and the user does not want
|
||||||
|
// to manually stop the server using the serve handle. Run forever.
|
||||||
|
running_forever = true;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
Either::Right(acc) => acc,
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let (conn, addr) = match accepted {
|
||||||
|
Ok((conn, addr)) => (conn, addr),
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
eprintln!("Error while accepting connection: {err}");
|
eprintln!("Error while accepting connection: {err}");
|
||||||
continue;
|
continue;
|
||||||
|
@ -40,16 +67,22 @@ pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult<()> {
|
||||||
};
|
};
|
||||||
|
|
||||||
lua.spawn_local({
|
lua.spawn_local({
|
||||||
let service = service.clone();
|
let rx = shutdown_rx.clone();
|
||||||
|
let io = HyperIo::from(conn);
|
||||||
|
let mut svc = service.clone();
|
||||||
|
svc.address = addr;
|
||||||
async move {
|
async move {
|
||||||
let result = Http1Builder::new()
|
let conn = Http1Builder::new()
|
||||||
.timer(HyperTimer)
|
.timer(HyperTimer)
|
||||||
.keep_alive(true) // Needed for websockets
|
.keep_alive(true)
|
||||||
.serve_connection(HyperIo::from(connection), service)
|
.serve_connection(io, svc)
|
||||||
.with_upgrades()
|
.with_upgrades();
|
||||||
.await;
|
// NOTE: Because we use keep_alive for websockets above, we need to
|
||||||
if let Err(err) = result {
|
// also manually poll this future and handle the shutdown signal here
|
||||||
eprintln!("Error while responding to request: {err}");
|
pin!(conn);
|
||||||
|
match either(rx.recv(), conn.as_mut()).await {
|
||||||
|
Either::Left(_) => conn.as_mut().graceful_shutdown(),
|
||||||
|
Either::Right(_) => {}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
@ -57,5 +90,5 @@ pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult<()> {
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
Ok(())
|
Ok(handle)
|
||||||
}
|
}
|
||||||
|
|
|
@ -18,7 +18,7 @@ use crate::{
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub(super) struct Service {
|
pub(super) struct Service {
|
||||||
pub(super) lua: Lua,
|
pub(super) lua: Lua,
|
||||||
pub(super) address: SocketAddr,
|
pub(super) address: SocketAddr, // NOTE: This should be the remote address of the connected client
|
||||||
pub(super) config: ServeConfig,
|
pub(super) config: ServeConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -29,11 +29,14 @@ impl HyperService<HyperRequest<Incoming>> for Service {
|
||||||
|
|
||||||
fn call(&self, req: HyperRequest<Incoming>) -> Self::Future {
|
fn call(&self, req: HyperRequest<Incoming>) -> Self::Future {
|
||||||
let lua = self.lua.clone();
|
let lua = self.lua.clone();
|
||||||
|
let address = self.address;
|
||||||
let config = self.config.clone();
|
let config = self.config.clone();
|
||||||
|
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
let handler = config.handle_request.clone();
|
let handler = config.handle_request.clone();
|
||||||
let request = Request::from_incoming(req, true).await?;
|
let request = Request::from_incoming(req, true)
|
||||||
|
.await?
|
||||||
|
.with_address(address);
|
||||||
|
|
||||||
let thread_id = lua.push_thread_back(handler, request)?;
|
let thread_id = lua.push_thread_back(handler, request)?;
|
||||||
lua.track_thread(thread_id);
|
lua.track_thread(thread_id);
|
||||||
|
|
19
crates/lune-std-net/src/shared/futures.rs
Normal file
19
crates/lune-std-net/src/shared/futures.rs
Normal file
|
@ -0,0 +1,19 @@
|
||||||
|
use futures_lite::prelude::*;
|
||||||
|
|
||||||
|
pub use http_body_util::Either;
|
||||||
|
|
||||||
|
/**
|
||||||
|
Combines the left and right futures into a single future
|
||||||
|
that resolves to either the left or right output.
|
||||||
|
|
||||||
|
This combinator is biased - if both futures resolve at
|
||||||
|
the same time, the left future's output is returned.
|
||||||
|
*/
|
||||||
|
pub fn either<L: Future, R: Future>(
|
||||||
|
left: L,
|
||||||
|
right: R,
|
||||||
|
) -> impl Future<Output = Either<L::Output, R::Output>> {
|
||||||
|
let fut_left = async move { Either::Left(left.await) };
|
||||||
|
let fut_right = async move { Either::Right(right.await) };
|
||||||
|
fut_left.or(fut_right)
|
||||||
|
}
|
|
@ -1,3 +1,4 @@
|
||||||
|
pub mod futures;
|
||||||
pub mod headers;
|
pub mod headers;
|
||||||
pub mod hyper;
|
pub mod hyper;
|
||||||
pub mod incoming;
|
pub mod incoming;
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use std::collections::HashMap;
|
use std::{collections::HashMap, net::SocketAddr};
|
||||||
|
|
||||||
use http_body_util::Full;
|
use http_body_util::Full;
|
||||||
use url::Url;
|
use url::Url;
|
||||||
|
@ -24,7 +24,8 @@ pub struct Request {
|
||||||
// NOTE: We use Bytes instead of Full<Bytes> to avoid
|
// NOTE: We use Bytes instead of Full<Bytes> to avoid
|
||||||
// needing async when getting a reference to the body
|
// needing async when getting a reference to the body
|
||||||
pub(crate) inner: HyperRequest<Bytes>,
|
pub(crate) inner: HyperRequest<Bytes>,
|
||||||
pub(crate) redirects: usize,
|
pub(crate) address: Option<SocketAddr>,
|
||||||
|
pub(crate) redirects: Option<usize>,
|
||||||
pub(crate) decompress: bool,
|
pub(crate) decompress: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -42,11 +43,22 @@ impl Request {
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
inner: HyperRequest::from_parts(parts, body),
|
inner: HyperRequest::from_parts(parts, body),
|
||||||
redirects: 0,
|
address: None,
|
||||||
|
redirects: None,
|
||||||
decompress,
|
decompress,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
Attaches a socket address to the request.
|
||||||
|
|
||||||
|
This will make the `ip` and `port` fields available on the request.
|
||||||
|
*/
|
||||||
|
pub fn with_address(mut self, address: SocketAddr) -> Self {
|
||||||
|
self.address = Some(address);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
Returns the method of the request.
|
Returns the method of the request.
|
||||||
*/
|
*/
|
||||||
|
@ -154,7 +166,8 @@ impl TryFrom<RequestConfig> for Request {
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
inner,
|
inner,
|
||||||
redirects: 0,
|
address: None,
|
||||||
|
redirects: None,
|
||||||
decompress: config.options.decompress,
|
decompress: config.options.decompress,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -162,6 +175,12 @@ impl TryFrom<RequestConfig> for Request {
|
||||||
|
|
||||||
impl LuaUserData for Request {
|
impl LuaUserData for Request {
|
||||||
fn add_fields<F: LuaUserDataFields<Self>>(fields: &mut F) {
|
fn add_fields<F: LuaUserDataFields<Self>>(fields: &mut F) {
|
||||||
|
fields.add_field_method_get("ip", |_, this| {
|
||||||
|
Ok(this.address.map(|address| address.ip().to_string()))
|
||||||
|
});
|
||||||
|
fields.add_field_method_get("port", |_, this| {
|
||||||
|
Ok(this.address.map(|address| address.port()))
|
||||||
|
});
|
||||||
fields.add_field_method_get("method", |_, this| Ok(this.method().to_string()));
|
fields.add_field_method_get("method", |_, this| Ok(this.method().to_string()));
|
||||||
fields.add_field_method_get("path", |_, this| Ok(this.path().to_string()));
|
fields.add_field_method_get("path", |_, this| Ok(this.path().to_string()));
|
||||||
fields.add_field_method_get("query", |lua, this| {
|
fields.add_field_method_get("query", |lua, this| {
|
||||||
|
|
|
@ -28,7 +28,7 @@ local handle = net.serve(PORT, function()
|
||||||
end)
|
end)
|
||||||
|
|
||||||
task.delay(0.25, function()
|
task.delay(0.25, function()
|
||||||
handle.stop()
|
handle:stop()
|
||||||
end)
|
end)
|
||||||
|
|
||||||
test(net.serve, PORT, function() end)
|
test(net.serve, PORT, function() end)
|
||||||
|
|
|
@ -19,7 +19,7 @@ local handle = net.serve(PORT, {
|
||||||
local response = net.request(`{LOCALHOST}:{PORT}`).body
|
local response = net.request(`{LOCALHOST}:{PORT}`).body
|
||||||
assert(response ~= nil, "Invalid response from server")
|
assert(response ~= nil, "Invalid response from server")
|
||||||
|
|
||||||
handle.stop()
|
handle:stop()
|
||||||
|
|
||||||
-- Attempting to serve with a malformed IP address should throw an error
|
-- Attempting to serve with a malformed IP address should throw an error
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ end)
|
||||||
-- Stopping is not guaranteed to happen instantly since it is async, but
|
-- Stopping is not guaranteed to happen instantly since it is async, but
|
||||||
-- it should happen on the next yield, so we wait the minimum amount here
|
-- it should happen on the next yield, so we wait the minimum amount here
|
||||||
|
|
||||||
handle.stop()
|
handle:stop()
|
||||||
task.wait()
|
task.wait()
|
||||||
|
|
||||||
-- Sending a request to the stopped server should now error
|
-- Sending a request to the stopped server should now error
|
||||||
|
|
|
@ -21,4 +21,4 @@ end)
|
||||||
|
|
||||||
task.cancel(thread)
|
task.cancel(thread)
|
||||||
|
|
||||||
handle.stop()
|
handle:stop()
|
||||||
|
|
|
@ -7,24 +7,39 @@ local PORT = 8083
|
||||||
local URL = `http://127.0.0.1:{PORT}`
|
local URL = `http://127.0.0.1:{PORT}`
|
||||||
local RESPONSE = "Hello, lune!"
|
local RESPONSE = "Hello, lune!"
|
||||||
|
|
||||||
-- Serve should respond to a request we send to it
|
-- Serve should get proper path, query, and other request information
|
||||||
|
|
||||||
local handle = net.serve(PORT, function(request)
|
local handle = net.serve(PORT, function(request)
|
||||||
|
-- print("Got a request from", request.ip, "on port", request.port)
|
||||||
|
|
||||||
|
assert(type(request.path) == "string")
|
||||||
|
assert(type(request.query) == "table")
|
||||||
|
assert(type(request.query.key) == "table")
|
||||||
|
assert(type(request.query.key2) == "string")
|
||||||
|
|
||||||
assert(request.path == "/some/path")
|
assert(request.path == "/some/path")
|
||||||
assert(request.query.key == "param2")
|
assert(request.query.key[1] == "param1")
|
||||||
|
assert(request.query.key[2] == "param2")
|
||||||
assert(request.query.key2 == "param3")
|
assert(request.query.key2 == "param3")
|
||||||
|
|
||||||
return RESPONSE
|
return RESPONSE
|
||||||
end)
|
end)
|
||||||
|
|
||||||
|
-- Serve should be able to handle at least 1000 requests per second with a basic handler such as the above
|
||||||
|
|
||||||
local thread = task.delay(1, function()
|
local thread = task.delay(1, function()
|
||||||
stdio.ewrite("Serve should respond to requests in a reasonable amount of time\n")
|
stdio.ewrite("Serve should respond to requests in a reasonable amount of time\n")
|
||||||
task.wait(1)
|
task.wait(1)
|
||||||
process.exit(1)
|
process.exit(1)
|
||||||
end)
|
end)
|
||||||
|
|
||||||
local response = net.request(URL .. "/some/path?key=param1&key=param2&key2=param3").body
|
-- Serve should respond to requests we send, and keep responding until we stop it
|
||||||
assert(response == RESPONSE, "Invalid response from server")
|
|
||||||
|
for _ = 1, 1024 do
|
||||||
|
local response = net.request(URL .. "/some/path?key=param1&key=param2&key2=param3").body
|
||||||
|
assert(response == RESPONSE, "Invalid response from server")
|
||||||
|
end
|
||||||
|
|
||||||
task.cancel(thread)
|
task.cancel(thread)
|
||||||
|
|
||||||
handle.stop()
|
handle:stop()
|
||||||
|
|
|
@ -64,4 +64,4 @@ assert(
|
||||||
)
|
)
|
||||||
|
|
||||||
-- Stop the server to end the test
|
-- Stop the server to end the test
|
||||||
handle.stop()
|
handle:stop()
|
||||||
|
|
Loading…
Add table
Reference in a new issue