diff --git a/.lune/http_server.luau b/.lune/http_server.luau index 0d0b2d3..9d6fa73 100644 --- a/.lune/http_server.luau +++ b/.lune/http_server.luau @@ -49,5 +49,5 @@ print(`Listening on port {PORT} 🚀`) task.delay(2, function() print("Shutting down...") task.wait(1) - handle.stop() + handle:stop() end) diff --git a/.lune/websocket_server.luau b/.lune/websocket_server.luau index c6f620d..4efb47d 100644 --- a/.lune/websocket_server.luau +++ b/.lune/websocket_server.luau @@ -32,6 +32,6 @@ print(`Listening on port {PORT} 🚀`) task.delay(10, function() print("Shutting down...") task.wait(1) - handle.stop() + handle:stop() task.wait(1) end) diff --git a/Cargo.lock b/Cargo.lock index 06dff52..e7cc7bf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1852,6 +1852,7 @@ dependencies = [ name = "lune-std-net" version = "0.2.0" dependencies = [ + "async-channel", "async-executor", "async-io", "async-lock", diff --git a/crates/lune-std-net/Cargo.toml b/crates/lune-std-net/Cargo.toml index b3ad290..9c6c70a 100644 --- a/crates/lune-std-net/Cargo.toml +++ b/crates/lune-std-net/Cargo.toml @@ -16,6 +16,7 @@ workspace = true mlua = { version = "0.10.3", features = ["luau"] } mlua-luau-scheduler = { version = "0.1.0", path = "../mlua-luau-scheduler" } +async-channel = "2.3" async-executor = "1.13" async-io = "2.4" async-lock = "3.4" diff --git a/crates/lune-std-net/src/client/mod.rs b/crates/lune-std-net/src/client/mod.rs index 683e3e6..fda1a76 100644 --- a/crates/lune-std-net/src/client/mod.rs +++ b/crates/lune-std-net/src/client/mod.rs @@ -41,7 +41,7 @@ pub async fn send_request(mut request: Request, lua: Lua) -> LuaResult .into_lua_err()?; if let Some((new_method, new_uri)) = check_redirect(&request.inner, &incoming) { - if request.redirects >= MAX_REDIRECTS { + if request.redirects.is_some_and(|r| r >= MAX_REDIRECTS) { return Err(LuaError::external("Too many redirects")); } @@ -52,7 +52,7 @@ pub async fn send_request(mut request: Request, lua: Lua) -> LuaResult *request.inner.method_mut() = new_method; *request.inner.uri_mut() = new_uri; - request.redirects += 1; + *request.redirects.get_or_insert_default() += 1; continue; } diff --git a/crates/lune-std-net/src/lib.rs b/crates/lune-std-net/src/lib.rs index 25b4143..1effacf 100644 --- a/crates/lune-std-net/src/lib.rs +++ b/crates/lune-std-net/src/lib.rs @@ -10,7 +10,7 @@ pub(crate) mod url; use self::{ client::config::RequestConfig, - server::config::ServeConfig, + server::{config::ServeConfig, handle::ServeHandle}, shared::{request::Request, response::Response}, }; @@ -45,7 +45,7 @@ async fn net_request(lua: Lua, config: RequestConfig) -> LuaResult { self::client::send_request(Request::try_from(config)?, lua).await } -async fn net_serve(lua: Lua, (port, config): (u16, ServeConfig)) -> LuaResult<()> { +async fn net_serve(lua: Lua, (port, config): (u16, ServeConfig)) -> LuaResult { self::server::serve(lua, port, config).await } diff --git a/crates/lune-std-net/src/server/handle.rs b/crates/lune-std-net/src/server/handle.rs new file mode 100644 index 0000000..63bbb54 --- /dev/null +++ b/crates/lune-std-net/src/server/handle.rs @@ -0,0 +1,50 @@ +use std::{ + net::SocketAddr, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; + +use async_channel::{unbounded, Receiver, Sender}; + +use mlua::prelude::*; + +#[derive(Debug, Clone)] +pub struct ServeHandle { + addr: SocketAddr, + shutdown: Arc, + sender: Sender<()>, +} + +impl ServeHandle { + pub fn new(addr: SocketAddr) -> (Self, Receiver<()>) { + let (sender, receiver) = unbounded(); + let this = Self { + addr, + shutdown: Arc::new(AtomicBool::new(false)), + sender, + }; + (this, receiver) + } +} + +impl LuaUserData for ServeHandle { + fn add_fields>(fields: &mut F) { + fields.add_field_method_get("ip", |_, this| Ok(this.addr.ip().to_string())); + fields.add_field_method_get("port", |_, this| Ok(this.addr.port())); + } + + fn add_methods>(methods: &mut M) { + methods.add_method("stop", |_, this, ()| { + if this.shutdown.load(Ordering::SeqCst) { + Err(LuaError::runtime("Server already stopped")) + } else { + this.shutdown.store(true, Ordering::SeqCst); + this.sender.try_send(()).ok(); + this.sender.close(); + Ok(()) + } + }); + } +} diff --git a/crates/lune-std-net/src/server/mod.rs b/crates/lune-std-net/src/server/mod.rs index 4c330c7..184431c 100644 --- a/crates/lune-std-net/src/server/mod.rs +++ b/crates/lune-std-net/src/server/mod.rs @@ -1,23 +1,30 @@ use std::net::SocketAddr; use async_net::TcpListener; +use futures_lite::pin; use hyper::server::conn::http1::Builder as Http1Builder; use mlua::prelude::*; use mlua_luau_scheduler::LuaSpawnExt; use crate::{ - server::{config::ServeConfig, service::Service}, - shared::hyper::{HyperIo, HyperTimer}, + server::{config::ServeConfig, handle::ServeHandle, service::Service}, + shared::{ + futures::{either, Either}, + hyper::{HyperIo, HyperTimer}, + }, }; pub mod config; +pub mod handle; pub mod service; /** Starts an HTTP server using the given port and configuration. + + Returns a `ServeHandle` that can be used to gracefully stop the server. */ -pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult<()> { +pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult { let address = SocketAddr::from((config.address, port)); let service = Service { lua: lua.clone(), @@ -26,13 +33,33 @@ pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult<()> { }; let listener = TcpListener::bind(address).await?; + let (handle, shutdown_rx) = ServeHandle::new(address); lua.spawn_local({ let lua = lua.clone(); async move { + let mut running_forever = false; loop { - let (connection, _addr) = match listener.accept().await { - Ok((connection, addr)) => (connection, addr), + let accepted = if running_forever { + listener.accept().await + } else { + match either(shutdown_rx.recv(), listener.accept()).await { + Either::Left(res) => { + if res.is_ok() { + break; + } + // NOTE: We will only get a RecvError if the serve handle is dropped, + // this means lua has garbage collected it and the user does not want + // to manually stop the server using the serve handle. Run forever. + running_forever = true; + continue; + } + Either::Right(acc) => acc, + } + }; + + let (conn, addr) = match accepted { + Ok((conn, addr)) => (conn, addr), Err(err) => { eprintln!("Error while accepting connection: {err}"); continue; @@ -40,16 +67,22 @@ pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult<()> { }; lua.spawn_local({ - let service = service.clone(); + let rx = shutdown_rx.clone(); + let io = HyperIo::from(conn); + let mut svc = service.clone(); + svc.address = addr; async move { - let result = Http1Builder::new() + let conn = Http1Builder::new() .timer(HyperTimer) - .keep_alive(true) // Needed for websockets - .serve_connection(HyperIo::from(connection), service) - .with_upgrades() - .await; - if let Err(err) = result { - eprintln!("Error while responding to request: {err}"); + .keep_alive(true) + .serve_connection(io, svc) + .with_upgrades(); + // NOTE: Because we use keep_alive for websockets above, we need to + // also manually poll this future and handle the shutdown signal here + pin!(conn); + match either(rx.recv(), conn.as_mut()).await { + Either::Left(_) => conn.as_mut().graceful_shutdown(), + Either::Right(_) => {} } } }); @@ -57,5 +90,5 @@ pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult<()> { } }); - Ok(()) + Ok(handle) } diff --git a/crates/lune-std-net/src/server/service.rs b/crates/lune-std-net/src/server/service.rs index 8b527bb..dff6574 100644 --- a/crates/lune-std-net/src/server/service.rs +++ b/crates/lune-std-net/src/server/service.rs @@ -18,7 +18,7 @@ use crate::{ #[derive(Debug, Clone)] pub(super) struct Service { pub(super) lua: Lua, - pub(super) address: SocketAddr, + pub(super) address: SocketAddr, // NOTE: This should be the remote address of the connected client pub(super) config: ServeConfig, } @@ -29,11 +29,14 @@ impl HyperService> for Service { fn call(&self, req: HyperRequest) -> Self::Future { let lua = self.lua.clone(); + let address = self.address; let config = self.config.clone(); Box::pin(async move { let handler = config.handle_request.clone(); - let request = Request::from_incoming(req, true).await?; + let request = Request::from_incoming(req, true) + .await? + .with_address(address); let thread_id = lua.push_thread_back(handler, request)?; lua.track_thread(thread_id); diff --git a/crates/lune-std-net/src/shared/futures.rs b/crates/lune-std-net/src/shared/futures.rs new file mode 100644 index 0000000..f00a55c --- /dev/null +++ b/crates/lune-std-net/src/shared/futures.rs @@ -0,0 +1,19 @@ +use futures_lite::prelude::*; + +pub use http_body_util::Either; + +/** + Combines the left and right futures into a single future + that resolves to either the left or right output. + + This combinator is biased - if both futures resolve at + the same time, the left future's output is returned. +*/ +pub fn either( + left: L, + right: R, +) -> impl Future> { + let fut_left = async move { Either::Left(left.await) }; + let fut_right = async move { Either::Right(right.await) }; + fut_left.or(fut_right) +} diff --git a/crates/lune-std-net/src/shared/mod.rs b/crates/lune-std-net/src/shared/mod.rs index 21a2272..43d3f74 100644 --- a/crates/lune-std-net/src/shared/mod.rs +++ b/crates/lune-std-net/src/shared/mod.rs @@ -1,3 +1,4 @@ +pub mod futures; pub mod headers; pub mod hyper; pub mod incoming; diff --git a/crates/lune-std-net/src/shared/request.rs b/crates/lune-std-net/src/shared/request.rs index 1a7d5a4..35a2048 100644 --- a/crates/lune-std-net/src/shared/request.rs +++ b/crates/lune-std-net/src/shared/request.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::{collections::HashMap, net::SocketAddr}; use http_body_util::Full; use url::Url; @@ -24,7 +24,8 @@ pub struct Request { // NOTE: We use Bytes instead of Full to avoid // needing async when getting a reference to the body pub(crate) inner: HyperRequest, - pub(crate) redirects: usize, + pub(crate) address: Option, + pub(crate) redirects: Option, pub(crate) decompress: bool, } @@ -42,11 +43,22 @@ impl Request { Ok(Self { inner: HyperRequest::from_parts(parts, body), - redirects: 0, + address: None, + redirects: None, decompress, }) } + /** + Attaches a socket address to the request. + + This will make the `ip` and `port` fields available on the request. + */ + pub fn with_address(mut self, address: SocketAddr) -> Self { + self.address = Some(address); + self + } + /** Returns the method of the request. */ @@ -154,7 +166,8 @@ impl TryFrom for Request { Ok(Self { inner, - redirects: 0, + address: None, + redirects: None, decompress: config.options.decompress, }) } @@ -162,6 +175,12 @@ impl TryFrom for Request { impl LuaUserData for Request { fn add_fields>(fields: &mut F) { + fields.add_field_method_get("ip", |_, this| { + Ok(this.address.map(|address| address.ip().to_string())) + }); + fields.add_field_method_get("port", |_, this| { + Ok(this.address.map(|address| address.port())) + }); fields.add_field_method_get("method", |_, this| Ok(this.method().to_string())); fields.add_field_method_get("path", |_, this| Ok(this.path().to_string())); fields.add_field_method_get("query", |lua, this| { diff --git a/tests/globals/pcall.luau b/tests/globals/pcall.luau index fa983cb..860b850 100644 --- a/tests/globals/pcall.luau +++ b/tests/globals/pcall.luau @@ -28,7 +28,7 @@ local handle = net.serve(PORT, function() end) task.delay(0.25, function() - handle.stop() + handle:stop() end) test(net.serve, PORT, function() end) diff --git a/tests/net/serve/addresses.luau b/tests/net/serve/addresses.luau index 57f591f..ac8ea11 100644 --- a/tests/net/serve/addresses.luau +++ b/tests/net/serve/addresses.luau @@ -19,7 +19,7 @@ local handle = net.serve(PORT, { local response = net.request(`{LOCALHOST}:{PORT}`).body assert(response ~= nil, "Invalid response from server") -handle.stop() +handle:stop() -- Attempting to serve with a malformed IP address should throw an error diff --git a/tests/net/serve/handles.luau b/tests/net/serve/handles.luau index 6b2b764..101dc25 100644 --- a/tests/net/serve/handles.luau +++ b/tests/net/serve/handles.luau @@ -12,7 +12,7 @@ end) -- Stopping is not guaranteed to happen instantly since it is async, but -- it should happen on the next yield, so we wait the minimum amount here -handle.stop() +handle:stop() task.wait() -- Sending a request to the stopped server should now error diff --git a/tests/net/serve/non_blocking.luau b/tests/net/serve/non_blocking.luau index 5463527..b05b14e 100644 --- a/tests/net/serve/non_blocking.luau +++ b/tests/net/serve/non_blocking.luau @@ -21,4 +21,4 @@ end) task.cancel(thread) -handle.stop() +handle:stop() diff --git a/tests/net/serve/requests.luau b/tests/net/serve/requests.luau index f2613fc..2d9d7e0 100644 --- a/tests/net/serve/requests.luau +++ b/tests/net/serve/requests.luau @@ -7,24 +7,39 @@ local PORT = 8083 local URL = `http://127.0.0.1:{PORT}` local RESPONSE = "Hello, lune!" --- Serve should respond to a request we send to it +-- Serve should get proper path, query, and other request information local handle = net.serve(PORT, function(request) + -- print("Got a request from", request.ip, "on port", request.port) + + assert(type(request.path) == "string") + assert(type(request.query) == "table") + assert(type(request.query.key) == "table") + assert(type(request.query.key2) == "string") + assert(request.path == "/some/path") - assert(request.query.key == "param2") + assert(request.query.key[1] == "param1") + assert(request.query.key[2] == "param2") assert(request.query.key2 == "param3") + return RESPONSE end) +-- Serve should be able to handle at least 1000 requests per second with a basic handler such as the above + local thread = task.delay(1, function() stdio.ewrite("Serve should respond to requests in a reasonable amount of time\n") task.wait(1) process.exit(1) end) -local response = net.request(URL .. "/some/path?key=param1&key=param2&key2=param3").body -assert(response == RESPONSE, "Invalid response from server") +-- Serve should respond to requests we send, and keep responding until we stop it + +for _ = 1, 1024 do + local response = net.request(URL .. "/some/path?key=param1&key=param2&key2=param3").body + assert(response == RESPONSE, "Invalid response from server") +end task.cancel(thread) -handle.stop() +handle:stop() diff --git a/tests/net/serve/websockets.luau b/tests/net/serve/websockets.luau index 51aea82..1b67b33 100644 --- a/tests/net/serve/websockets.luau +++ b/tests/net/serve/websockets.luau @@ -64,4 +64,4 @@ assert( ) -- Stop the server to end the test -handle.stop() +handle:stop()