mirror of
https://github.com/lune-org/lune.git
synced 2025-05-04 10:43:57 +01:00
Implement serve handle again and proper graceful shutdowns
This commit is contained in:
parent
d57a7b949d
commit
a82cb1da33
18 changed files with 178 additions and 36 deletions
|
@ -49,5 +49,5 @@ print(`Listening on port {PORT} 🚀`)
|
|||
task.delay(2, function()
|
||||
print("Shutting down...")
|
||||
task.wait(1)
|
||||
handle.stop()
|
||||
handle:stop()
|
||||
end)
|
||||
|
|
|
@ -32,6 +32,6 @@ print(`Listening on port {PORT} 🚀`)
|
|||
task.delay(10, function()
|
||||
print("Shutting down...")
|
||||
task.wait(1)
|
||||
handle.stop()
|
||||
handle:stop()
|
||||
task.wait(1)
|
||||
end)
|
||||
|
|
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -1852,6 +1852,7 @@ dependencies = [
|
|||
name = "lune-std-net"
|
||||
version = "0.2.0"
|
||||
dependencies = [
|
||||
"async-channel",
|
||||
"async-executor",
|
||||
"async-io",
|
||||
"async-lock",
|
||||
|
|
|
@ -16,6 +16,7 @@ workspace = true
|
|||
mlua = { version = "0.10.3", features = ["luau"] }
|
||||
mlua-luau-scheduler = { version = "0.1.0", path = "../mlua-luau-scheduler" }
|
||||
|
||||
async-channel = "2.3"
|
||||
async-executor = "1.13"
|
||||
async-io = "2.4"
|
||||
async-lock = "3.4"
|
||||
|
|
|
@ -41,7 +41,7 @@ pub async fn send_request(mut request: Request, lua: Lua) -> LuaResult<Response>
|
|||
.into_lua_err()?;
|
||||
|
||||
if let Some((new_method, new_uri)) = check_redirect(&request.inner, &incoming) {
|
||||
if request.redirects >= MAX_REDIRECTS {
|
||||
if request.redirects.is_some_and(|r| r >= MAX_REDIRECTS) {
|
||||
return Err(LuaError::external("Too many redirects"));
|
||||
}
|
||||
|
||||
|
@ -52,7 +52,7 @@ pub async fn send_request(mut request: Request, lua: Lua) -> LuaResult<Response>
|
|||
*request.inner.method_mut() = new_method;
|
||||
*request.inner.uri_mut() = new_uri;
|
||||
|
||||
request.redirects += 1;
|
||||
*request.redirects.get_or_insert_default() += 1;
|
||||
|
||||
continue;
|
||||
}
|
||||
|
|
|
@ -10,7 +10,7 @@ pub(crate) mod url;
|
|||
|
||||
use self::{
|
||||
client::config::RequestConfig,
|
||||
server::config::ServeConfig,
|
||||
server::{config::ServeConfig, handle::ServeHandle},
|
||||
shared::{request::Request, response::Response},
|
||||
};
|
||||
|
||||
|
@ -45,7 +45,7 @@ async fn net_request(lua: Lua, config: RequestConfig) -> LuaResult<Response> {
|
|||
self::client::send_request(Request::try_from(config)?, lua).await
|
||||
}
|
||||
|
||||
async fn net_serve(lua: Lua, (port, config): (u16, ServeConfig)) -> LuaResult<()> {
|
||||
async fn net_serve(lua: Lua, (port, config): (u16, ServeConfig)) -> LuaResult<ServeHandle> {
|
||||
self::server::serve(lua, port, config).await
|
||||
}
|
||||
|
||||
|
|
50
crates/lune-std-net/src/server/handle.rs
Normal file
50
crates/lune-std-net/src/server/handle.rs
Normal file
|
@ -0,0 +1,50 @@
|
|||
use std::{
|
||||
net::SocketAddr,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use async_channel::{unbounded, Receiver, Sender};
|
||||
|
||||
use mlua::prelude::*;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ServeHandle {
|
||||
addr: SocketAddr,
|
||||
shutdown: Arc<AtomicBool>,
|
||||
sender: Sender<()>,
|
||||
}
|
||||
|
||||
impl ServeHandle {
|
||||
pub fn new(addr: SocketAddr) -> (Self, Receiver<()>) {
|
||||
let (sender, receiver) = unbounded();
|
||||
let this = Self {
|
||||
addr,
|
||||
shutdown: Arc::new(AtomicBool::new(false)),
|
||||
sender,
|
||||
};
|
||||
(this, receiver)
|
||||
}
|
||||
}
|
||||
|
||||
impl LuaUserData for ServeHandle {
|
||||
fn add_fields<F: LuaUserDataFields<Self>>(fields: &mut F) {
|
||||
fields.add_field_method_get("ip", |_, this| Ok(this.addr.ip().to_string()));
|
||||
fields.add_field_method_get("port", |_, this| Ok(this.addr.port()));
|
||||
}
|
||||
|
||||
fn add_methods<M: LuaUserDataMethods<Self>>(methods: &mut M) {
|
||||
methods.add_method("stop", |_, this, ()| {
|
||||
if this.shutdown.load(Ordering::SeqCst) {
|
||||
Err(LuaError::runtime("Server already stopped"))
|
||||
} else {
|
||||
this.shutdown.store(true, Ordering::SeqCst);
|
||||
this.sender.try_send(()).ok();
|
||||
this.sender.close();
|
||||
Ok(())
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
|
@ -1,23 +1,30 @@
|
|||
use std::net::SocketAddr;
|
||||
|
||||
use async_net::TcpListener;
|
||||
use futures_lite::pin;
|
||||
use hyper::server::conn::http1::Builder as Http1Builder;
|
||||
|
||||
use mlua::prelude::*;
|
||||
use mlua_luau_scheduler::LuaSpawnExt;
|
||||
|
||||
use crate::{
|
||||
server::{config::ServeConfig, service::Service},
|
||||
shared::hyper::{HyperIo, HyperTimer},
|
||||
server::{config::ServeConfig, handle::ServeHandle, service::Service},
|
||||
shared::{
|
||||
futures::{either, Either},
|
||||
hyper::{HyperIo, HyperTimer},
|
||||
},
|
||||
};
|
||||
|
||||
pub mod config;
|
||||
pub mod handle;
|
||||
pub mod service;
|
||||
|
||||
/**
|
||||
Starts an HTTP server using the given port and configuration.
|
||||
|
||||
Returns a `ServeHandle` that can be used to gracefully stop the server.
|
||||
*/
|
||||
pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult<()> {
|
||||
pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult<ServeHandle> {
|
||||
let address = SocketAddr::from((config.address, port));
|
||||
let service = Service {
|
||||
lua: lua.clone(),
|
||||
|
@ -26,13 +33,33 @@ pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult<()> {
|
|||
};
|
||||
|
||||
let listener = TcpListener::bind(address).await?;
|
||||
let (handle, shutdown_rx) = ServeHandle::new(address);
|
||||
|
||||
lua.spawn_local({
|
||||
let lua = lua.clone();
|
||||
async move {
|
||||
let mut running_forever = false;
|
||||
loop {
|
||||
let (connection, _addr) = match listener.accept().await {
|
||||
Ok((connection, addr)) => (connection, addr),
|
||||
let accepted = if running_forever {
|
||||
listener.accept().await
|
||||
} else {
|
||||
match either(shutdown_rx.recv(), listener.accept()).await {
|
||||
Either::Left(res) => {
|
||||
if res.is_ok() {
|
||||
break;
|
||||
}
|
||||
// NOTE: We will only get a RecvError if the serve handle is dropped,
|
||||
// this means lua has garbage collected it and the user does not want
|
||||
// to manually stop the server using the serve handle. Run forever.
|
||||
running_forever = true;
|
||||
continue;
|
||||
}
|
||||
Either::Right(acc) => acc,
|
||||
}
|
||||
};
|
||||
|
||||
let (conn, addr) = match accepted {
|
||||
Ok((conn, addr)) => (conn, addr),
|
||||
Err(err) => {
|
||||
eprintln!("Error while accepting connection: {err}");
|
||||
continue;
|
||||
|
@ -40,16 +67,22 @@ pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult<()> {
|
|||
};
|
||||
|
||||
lua.spawn_local({
|
||||
let service = service.clone();
|
||||
let rx = shutdown_rx.clone();
|
||||
let io = HyperIo::from(conn);
|
||||
let mut svc = service.clone();
|
||||
svc.address = addr;
|
||||
async move {
|
||||
let result = Http1Builder::new()
|
||||
let conn = Http1Builder::new()
|
||||
.timer(HyperTimer)
|
||||
.keep_alive(true) // Needed for websockets
|
||||
.serve_connection(HyperIo::from(connection), service)
|
||||
.with_upgrades()
|
||||
.await;
|
||||
if let Err(err) = result {
|
||||
eprintln!("Error while responding to request: {err}");
|
||||
.keep_alive(true)
|
||||
.serve_connection(io, svc)
|
||||
.with_upgrades();
|
||||
// NOTE: Because we use keep_alive for websockets above, we need to
|
||||
// also manually poll this future and handle the shutdown signal here
|
||||
pin!(conn);
|
||||
match either(rx.recv(), conn.as_mut()).await {
|
||||
Either::Left(_) => conn.as_mut().graceful_shutdown(),
|
||||
Either::Right(_) => {}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
@ -57,5 +90,5 @@ pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult<()> {
|
|||
}
|
||||
});
|
||||
|
||||
Ok(())
|
||||
Ok(handle)
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@ use crate::{
|
|||
#[derive(Debug, Clone)]
|
||||
pub(super) struct Service {
|
||||
pub(super) lua: Lua,
|
||||
pub(super) address: SocketAddr,
|
||||
pub(super) address: SocketAddr, // NOTE: This should be the remote address of the connected client
|
||||
pub(super) config: ServeConfig,
|
||||
}
|
||||
|
||||
|
@ -29,11 +29,14 @@ impl HyperService<HyperRequest<Incoming>> for Service {
|
|||
|
||||
fn call(&self, req: HyperRequest<Incoming>) -> Self::Future {
|
||||
let lua = self.lua.clone();
|
||||
let address = self.address;
|
||||
let config = self.config.clone();
|
||||
|
||||
Box::pin(async move {
|
||||
let handler = config.handle_request.clone();
|
||||
let request = Request::from_incoming(req, true).await?;
|
||||
let request = Request::from_incoming(req, true)
|
||||
.await?
|
||||
.with_address(address);
|
||||
|
||||
let thread_id = lua.push_thread_back(handler, request)?;
|
||||
lua.track_thread(thread_id);
|
||||
|
|
19
crates/lune-std-net/src/shared/futures.rs
Normal file
19
crates/lune-std-net/src/shared/futures.rs
Normal file
|
@ -0,0 +1,19 @@
|
|||
use futures_lite::prelude::*;
|
||||
|
||||
pub use http_body_util::Either;
|
||||
|
||||
/**
|
||||
Combines the left and right futures into a single future
|
||||
that resolves to either the left or right output.
|
||||
|
||||
This combinator is biased - if both futures resolve at
|
||||
the same time, the left future's output is returned.
|
||||
*/
|
||||
pub fn either<L: Future, R: Future>(
|
||||
left: L,
|
||||
right: R,
|
||||
) -> impl Future<Output = Either<L::Output, R::Output>> {
|
||||
let fut_left = async move { Either::Left(left.await) };
|
||||
let fut_right = async move { Either::Right(right.await) };
|
||||
fut_left.or(fut_right)
|
||||
}
|
|
@ -1,3 +1,4 @@
|
|||
pub mod futures;
|
||||
pub mod headers;
|
||||
pub mod hyper;
|
||||
pub mod incoming;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use std::collections::HashMap;
|
||||
use std::{collections::HashMap, net::SocketAddr};
|
||||
|
||||
use http_body_util::Full;
|
||||
use url::Url;
|
||||
|
@ -24,7 +24,8 @@ pub struct Request {
|
|||
// NOTE: We use Bytes instead of Full<Bytes> to avoid
|
||||
// needing async when getting a reference to the body
|
||||
pub(crate) inner: HyperRequest<Bytes>,
|
||||
pub(crate) redirects: usize,
|
||||
pub(crate) address: Option<SocketAddr>,
|
||||
pub(crate) redirects: Option<usize>,
|
||||
pub(crate) decompress: bool,
|
||||
}
|
||||
|
||||
|
@ -42,11 +43,22 @@ impl Request {
|
|||
|
||||
Ok(Self {
|
||||
inner: HyperRequest::from_parts(parts, body),
|
||||
redirects: 0,
|
||||
address: None,
|
||||
redirects: None,
|
||||
decompress,
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
Attaches a socket address to the request.
|
||||
|
||||
This will make the `ip` and `port` fields available on the request.
|
||||
*/
|
||||
pub fn with_address(mut self, address: SocketAddr) -> Self {
|
||||
self.address = Some(address);
|
||||
self
|
||||
}
|
||||
|
||||
/**
|
||||
Returns the method of the request.
|
||||
*/
|
||||
|
@ -154,7 +166,8 @@ impl TryFrom<RequestConfig> for Request {
|
|||
|
||||
Ok(Self {
|
||||
inner,
|
||||
redirects: 0,
|
||||
address: None,
|
||||
redirects: None,
|
||||
decompress: config.options.decompress,
|
||||
})
|
||||
}
|
||||
|
@ -162,6 +175,12 @@ impl TryFrom<RequestConfig> for Request {
|
|||
|
||||
impl LuaUserData for Request {
|
||||
fn add_fields<F: LuaUserDataFields<Self>>(fields: &mut F) {
|
||||
fields.add_field_method_get("ip", |_, this| {
|
||||
Ok(this.address.map(|address| address.ip().to_string()))
|
||||
});
|
||||
fields.add_field_method_get("port", |_, this| {
|
||||
Ok(this.address.map(|address| address.port()))
|
||||
});
|
||||
fields.add_field_method_get("method", |_, this| Ok(this.method().to_string()));
|
||||
fields.add_field_method_get("path", |_, this| Ok(this.path().to_string()));
|
||||
fields.add_field_method_get("query", |lua, this| {
|
||||
|
|
|
@ -28,7 +28,7 @@ local handle = net.serve(PORT, function()
|
|||
end)
|
||||
|
||||
task.delay(0.25, function()
|
||||
handle.stop()
|
||||
handle:stop()
|
||||
end)
|
||||
|
||||
test(net.serve, PORT, function() end)
|
||||
|
|
|
@ -19,7 +19,7 @@ local handle = net.serve(PORT, {
|
|||
local response = net.request(`{LOCALHOST}:{PORT}`).body
|
||||
assert(response ~= nil, "Invalid response from server")
|
||||
|
||||
handle.stop()
|
||||
handle:stop()
|
||||
|
||||
-- Attempting to serve with a malformed IP address should throw an error
|
||||
|
||||
|
|
|
@ -12,7 +12,7 @@ end)
|
|||
-- Stopping is not guaranteed to happen instantly since it is async, but
|
||||
-- it should happen on the next yield, so we wait the minimum amount here
|
||||
|
||||
handle.stop()
|
||||
handle:stop()
|
||||
task.wait()
|
||||
|
||||
-- Sending a request to the stopped server should now error
|
||||
|
|
|
@ -21,4 +21,4 @@ end)
|
|||
|
||||
task.cancel(thread)
|
||||
|
||||
handle.stop()
|
||||
handle:stop()
|
||||
|
|
|
@ -7,24 +7,39 @@ local PORT = 8083
|
|||
local URL = `http://127.0.0.1:{PORT}`
|
||||
local RESPONSE = "Hello, lune!"
|
||||
|
||||
-- Serve should respond to a request we send to it
|
||||
-- Serve should get proper path, query, and other request information
|
||||
|
||||
local handle = net.serve(PORT, function(request)
|
||||
-- print("Got a request from", request.ip, "on port", request.port)
|
||||
|
||||
assert(type(request.path) == "string")
|
||||
assert(type(request.query) == "table")
|
||||
assert(type(request.query.key) == "table")
|
||||
assert(type(request.query.key2) == "string")
|
||||
|
||||
assert(request.path == "/some/path")
|
||||
assert(request.query.key == "param2")
|
||||
assert(request.query.key[1] == "param1")
|
||||
assert(request.query.key[2] == "param2")
|
||||
assert(request.query.key2 == "param3")
|
||||
|
||||
return RESPONSE
|
||||
end)
|
||||
|
||||
-- Serve should be able to handle at least 1000 requests per second with a basic handler such as the above
|
||||
|
||||
local thread = task.delay(1, function()
|
||||
stdio.ewrite("Serve should respond to requests in a reasonable amount of time\n")
|
||||
task.wait(1)
|
||||
process.exit(1)
|
||||
end)
|
||||
|
||||
local response = net.request(URL .. "/some/path?key=param1&key=param2&key2=param3").body
|
||||
assert(response == RESPONSE, "Invalid response from server")
|
||||
-- Serve should respond to requests we send, and keep responding until we stop it
|
||||
|
||||
for _ = 1, 1024 do
|
||||
local response = net.request(URL .. "/some/path?key=param1&key=param2&key2=param3").body
|
||||
assert(response == RESPONSE, "Invalid response from server")
|
||||
end
|
||||
|
||||
task.cancel(thread)
|
||||
|
||||
handle.stop()
|
||||
handle:stop()
|
||||
|
|
|
@ -64,4 +64,4 @@ assert(
|
|||
)
|
||||
|
||||
-- Stop the server to end the test
|
||||
handle.stop()
|
||||
handle:stop()
|
||||
|
|
Loading…
Add table
Reference in a new issue