From 53da0758e5bec5d36a24e3a2ea12a9d8ab43bcd8 Mon Sep 17 00:00:00 2001 From: Filip Tibell Date: Fri, 16 Feb 2024 13:11:11 +0100 Subject: [PATCH] Optimize handling of registry for new net serve --- src/lune/builtins/net/server.rs | 131 ++++++++++++++++++-------------- 1 file changed, 74 insertions(+), 57 deletions(-) diff --git a/src/lune/builtins/net/server.rs b/src/lune/builtins/net/server.rs index fe6155d..3c32e20 100644 --- a/src/lune/builtins/net/server.rs +++ b/src/lune/builtins/net/server.rs @@ -5,6 +5,7 @@ use std::{ pin::Pin, rc::{Rc, Weak}, str::FromStr, + sync::atomic::{AtomicUsize, Ordering}, }; use http::request::Parts; @@ -146,49 +147,66 @@ impl FromLua<'_> for LuaResponse { } } +#[derive(Debug, Clone, Copy)] +struct SvcKeys { + key_request: &'static str, + key_websocket: Option<&'static str>, +} + +impl SvcKeys { + fn new<'lua>( + lua: &'lua Lua, + handle_request: LuaFunction<'lua>, + handle_websocket: Option>, + ) -> LuaResult { + static SERVE_COUNTER: AtomicUsize = AtomicUsize::new(0); + let count = SERVE_COUNTER.fetch_add(1, Ordering::Relaxed); + + // NOTE: We leak strings here, but this is an acceptable tradeoff since programs + // generally only start one or a couple of servers and they are usually never dropped. + // Leaking here lets us keep this struct Copy and access the request handler callbacks + // very performantly, significantly reducing the per-request overhead of the server. + let key_request: &'static str = + Box::leak(format!("__net_serve_request_{count}").into_boxed_str()); + let key_websocket: Option<&'static str> = if handle_websocket.is_some() { + Some(Box::leak( + format!("__net_serve_websocket_{count}").into_boxed_str(), + )) + } else { + None + }; + + lua.set_named_registry_value(key_request, handle_request)?; + if let Some(key) = key_websocket { + lua.set_named_registry_value(key, handle_websocket.unwrap())?; + } + + Ok(Self { + key_request, + key_websocket, + }) + } + + fn has_websocket_handler(&self) -> bool { + self.key_websocket.is_some() + } + + fn request_handler<'lua>(&self, lua: &'lua Lua) -> LuaResult> { + lua.named_registry_value(self.key_request) + } + + fn websocket_handler<'lua>(&self, lua: &'lua Lua) -> LuaResult>> { + self.key_websocket + .map(|key| lua.named_registry_value(key)) + .transpose() + } +} + +#[derive(Debug, Clone)] struct Svc { lua: Rc, addr: SocketAddr, - handler_request: LuaRegistryKey, - handler_websocket: LuaRegistryKey, - has_websocket_handler: bool, -} - -impl Svc { - fn clone_registry_keys(&self) -> (LuaRegistryKey, LuaRegistryKey) { - let cloned_request = self - .lua - .registry_value::(&self.handler_request) - .expect("Failed to clone registry value"); - let cloned_websocket = self - .lua - .registry_value::>(&self.handler_websocket) - .expect("Failed to clone registry value"); - - let stored_request = self - .lua - .create_registry_value(cloned_request) - .expect("Failed to clone registry value"); - let stored_websocket = self - .lua - .create_registry_value(cloned_websocket) - .expect("Failed to clone registry value"); - - (stored_request, stored_websocket) - } -} - -impl Clone for Svc { - fn clone(&self) -> Self { - let (handler_request, handler_websocket) = self.clone_registry_keys(); - Self { - lua: self.lua.clone(), - addr: self.addr, - handler_request, - handler_websocket, - has_websocket_handler: self.has_websocket_handler, - } - } + keys: SvcKeys, } impl Service> for Svc { @@ -197,26 +215,25 @@ impl Service> for Svc { type Future = Pin>>>; fn call(&self, req: Request) -> Self::Future { - let addr = self.addr; let lua = self.lua.clone(); + let addr = self.addr; + let keys = self.keys; - let (handler_request, handler_websocket) = self.clone_registry_keys(); - - if self.has_websocket_handler && is_upgrade_request(&req) { + if keys.has_websocket_handler() && is_upgrade_request(&req) { Box::pin(async move { let (res, sock) = upgrade(req, None).into_lua_err()?; let lua_inner = lua.clone(); lua.spawn_local(async move { let sock = sock.await.unwrap(); - let lua_sock = NetWebSocket::new(sock).into_lua_table(&lua_inner).unwrap(); + let lua_sock = NetWebSocket::new(sock); + let lua_tab = lua_sock.into_lua_table(&lua_inner).unwrap(); - let handler_websocket = lua_inner - .registry_value::(&handler_websocket) - .unwrap(); + let handler_websocket: LuaFunction = + keys.websocket_handler(&lua_inner).unwrap().unwrap(); lua_inner - .push_thread_back(handler_websocket, lua_sock) + .push_thread_back(handler_websocket, lua_tab) .unwrap(); }); @@ -226,9 +243,10 @@ impl Service> for Svc { let (head, body) = req.into_parts(); Box::pin(async move { - let handler_request = lua.registry_value::(&handler_request)?; + let handler_request: LuaFunction = keys.request_handler(&lua).unwrap(); - let body = body.collect().await.into_lua_err()?.to_bytes().to_vec(); + let body = body.collect().await.into_lua_err()?; + let body = body.to_bytes().to_vec(); let lua_req = LuaRequest { _remote_addr: addr, @@ -257,7 +275,7 @@ pub async fn serve<'lua>( let addr: SocketAddr = (config.address, port).into(); let listener = TcpListener::bind(addr).await?; - let (lua_inner, lua_inner_2) = { + let (lua_svc, lua_inner) = { let rc = lua .app_data_ref::>() .expect("Missing weak lua ref") @@ -266,12 +284,11 @@ pub async fn serve<'lua>( (Rc::clone(&rc), rc) }; + let keys = SvcKeys::new(lua, config.handle_request, config.handle_web_socket)?; let svc = Svc { - lua: lua_inner, + lua: lua_svc, addr, - has_websocket_handler: config.handle_web_socket.is_some(), - handler_request: lua.create_registry_value(config.handle_request)?, - handler_websocket: lua.create_registry_value(config.handle_web_socket)?, + keys, }; let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false); @@ -290,7 +307,7 @@ pub async fn serve<'lua>( let svc = svc.clone(); let mut shutdown_rx_inner = shutdown_rx.clone(); - lua_inner_2.spawn_local(async move { + lua_inner.spawn_local(async move { let conn = http1::Builder::new() .keep_alive(true) // Web sockets need this .serve_connection(io, svc)