From 1f887cef07fdd4e4c4287c3bb88dd9cf146f0e83 Mon Sep 17 00:00:00 2001 From: Filip Tibell Date: Tue, 21 Mar 2023 15:46:14 +0100 Subject: [PATCH] Fix websockets memory leak --- packages/lib/src/lua/create.rs | 17 +- packages/lib/src/lua/net/websocket.rs | 273 ++++++++++++++++---------- 2 files changed, 185 insertions(+), 105 deletions(-) diff --git a/packages/lib/src/lua/create.rs b/packages/lib/src/lua/create.rs index 4d76ad0..5c97488 100644 --- a/packages/lib/src/lua/create.rs +++ b/packages/lib/src/lua/create.rs @@ -64,10 +64,14 @@ end * `"co.yield"` -> `coroutine.yield` * `"co.close"` -> `coroutine.close` --- + * `"tab.pack"` -> `table.pack` + * `"tab.unpack"` -> `table.unpack` + * `"tab.freeze"` -> `table.freeze` + * `"tab.getmeta"` -> `getmetatable` + * `"tab.setmeta"` -> `setmetatable` + --- * `"dbg.info"` -> `debug.info` * `"dbg.trace"` -> `debug.traceback` - * `"dbg.iserr"` -> `` - * `"dbg.makeerr"` -> `` --- */ pub fn create() -> LuaResult<&'static Lua> { @@ -93,6 +97,15 @@ pub fn create() -> LuaResult<&'static Lua> { lua.set_named_registry_value("dbg.info", debug.get::<_, LuaFunction>("info")?)?; lua.set_named_registry_value("tab.pack", table.get::<_, LuaFunction>("pack")?)?; lua.set_named_registry_value("tab.unpack", table.get::<_, LuaFunction>("unpack")?)?; + lua.set_named_registry_value("tab.freeze", table.get::<_, LuaFunction>("freeze")?)?; + lua.set_named_registry_value( + "tab.getmeta", + globals.get::<_, LuaFunction>("getmetatable")?, + )?; + lua.set_named_registry_value( + "tab.setmeta", + globals.get::<_, LuaFunction>("setmetatable")?, + )?; // Create a trace function that can be called to obtain a full stack trace from // lua, this is not possible to do from rust when using our manual scheduler let dbg_trace_env = lua.create_table_with_capacity(0, 1)?; diff --git a/packages/lib/src/lua/net/websocket.rs b/packages/lib/src/lua/net/websocket.rs index f495f23..0c09827 100644 --- a/packages/lib/src/lua/net/websocket.rs +++ b/packages/lib/src/lua/net/websocket.rs @@ -1,7 +1,15 @@ use std::{cell::Cell, sync::Arc}; +use hyper::upgrade::Upgraded; use mlua::prelude::*; +use futures_util::{SinkExt, StreamExt}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::TcpStream, + sync::Mutex as AsyncMutex, +}; + use hyper_tungstenite::{ tungstenite::{ protocol::{frame::coding::CloseCode as WsCloseCode, CloseFrame as WsCloseFrame}, @@ -9,19 +17,43 @@ use hyper_tungstenite::{ }, WebSocketStream, }; - -use futures_util::{SinkExt, StreamExt}; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - sync::Mutex, -}; +use tokio_tungstenite::MaybeTlsStream; use crate::lua::table::TableBuilder; -#[derive(Debug, Clone)] +const WEB_SOCKET_IMPL_LUA: &str = r#" +return freeze(setmetatable({ + close = function(...) + return close(websocket, ...) + end, + send = function(...) + return send(websocket, ...) + end, + next = function(...) + return next(websocket, ...) + end, +}, { + __index = function(self, key) + if key == "closeCode" then + return close_code() + end + end, +})) +"#; + +#[derive(Debug)] pub struct NetWebSocket { - close_code: Cell>, - stream: Arc>>, + close_code: Arc>>, + stream: Arc>>, +} + +impl Clone for NetWebSocket { + fn clone(&self) -> Self { + Self { + close_code: Arc::clone(&self.close_code), + stream: Arc::clone(&self.stream), + } + } } impl NetWebSocket @@ -30,107 +62,142 @@ where { pub fn new(value: WebSocketStream) -> Self { Self { - close_code: Cell::new(None), - stream: Arc::new(Mutex::new(value)), + close_code: Arc::new(Cell::new(None)), + stream: Arc::new(AsyncMutex::new(value)), } } - pub fn get_lua_close_code(&self) -> LuaValue { - match self.close_code.get() { - Some(code) => LuaValue::Number(code as f64), - None => LuaValue::Nil, - } - } - - pub async fn close(&self, code: Option) -> LuaResult<()> { - let mut ws = self.stream.lock().await; - let res = ws.close(Some(WsCloseFrame { - code: match code { - Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code), - Some(code) => { - return Err(LuaError::RuntimeError(format!( - "Close code must be between 1000 and 4999, got {code}" - ))) - } - None => WsCloseCode::Normal, - }, - reason: "".into(), - })); - res.await.map_err(LuaError::external) - } - - pub async fn send(&self, msg: WsMessage) -> LuaResult<()> { - let mut ws = self.stream.lock().await; - ws.send(msg).await.map_err(LuaError::external) - } - - pub async fn send_lua_string<'lua>( - &self, - string: LuaString<'lua>, - as_binary: Option, - ) -> LuaResult<()> { - let msg = if matches!(as_binary, Some(true)) { - WsMessage::Binary(string.as_bytes().to_vec()) - } else { - let s = string.to_str().map_err(LuaError::external)?; - WsMessage::Text(s.to_string()) - }; - self.send(msg).await - } - - pub async fn next(&self) -> LuaResult> { - let mut ws = self.stream.lock().await; - let item = ws.next().await.transpose().map_err(LuaError::external); - match item { - Ok(Some(WsMessage::Close(msg))) => { - if let Some(msg) = &msg { - self.close_code.replace(Some(msg.code.into())); - } - Ok(Some(WsMessage::Close(msg))) - } - val => val, - } - } - - pub async fn next_lua_string<'lua>(&'lua self, lua: &'lua Lua) -> LuaResult { - while let Some(msg) = self.next().await? { - let msg_string_opt = match msg { - WsMessage::Binary(bin) => Some(lua.create_string(&bin)?), - WsMessage::Text(txt) => Some(lua.create_string(&txt)?), - // Stop waiting for next message if we get a close message - WsMessage::Close(_) => return Ok(LuaValue::Nil), - // Ignore ping/pong/frame messages, they are handled by tungstenite - _ => None, - }; - if let Some(msg_string) = msg_string_opt { - return Ok(LuaValue::String(msg_string)); - } - } - Ok(LuaValue::Nil) + fn into_lua_table_with_env<'lua>( + lua: &'lua Lua, + env: LuaTable<'lua>, + ) -> LuaResult> { + lua.load(WEB_SOCKET_IMPL_LUA) + .set_name("websocket")? + .set_environment(env)? + .eval() } } -impl NetWebSocket -where - T: AsyncRead + AsyncWrite + Unpin + 'static, -{ +type NetWebSocketStreamClient = MaybeTlsStream; +impl NetWebSocket { pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult { - let ws = Box::leak(Box::new(self)); - TableBuilder::new(lua)? - .with_async_function("close", |_, code| ws.close(code))? - .with_async_function("send", |_, (msg, bin)| ws.send_lua_string(msg, bin))? - .with_async_function("next", |lua, _: ()| ws.next_lua_string(lua))? - .with_metatable( - TableBuilder::new(lua)? - .with_function(LuaMetaMethod::Index.name(), |_, key: String| { - if key == "closeCode" { - Ok(ws.get_lua_close_code()) - } else { - Ok(LuaValue::Nil) - } - })? - .build_readonly()?, + let socket_env = TableBuilder::new(lua)? + .with_value("websocket", self)? + .with_function("close_code", close_code::)? + .with_async_function("close", close::)? + .with_async_function("send", send::)? + .with_async_function("next", next::)? + .with_value( + "setmetatable", + lua.named_registry_value::<_, LuaFunction>("tab.setmeta")?, )? - .build_readonly() + .with_value( + "freeze", + lua.named_registry_value::<_, LuaFunction>("tab.freeze")?, + )? + .build_readonly()?; + Self::into_lua_table_with_env(lua, socket_env) } } + +type NetWebSocketStreamServer = Upgraded; +impl NetWebSocket { + pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult { + let socket_env = TableBuilder::new(lua)? + .with_value("websocket", self)? + .with_function("close_code", close_code::)? + .with_async_function("close", close::)? + .with_async_function("send", send::)? + .with_async_function("next", next::)? + .with_value( + "setmetatable", + lua.named_registry_value::<_, LuaFunction>("tab.setmeta")?, + )? + .with_value( + "freeze", + lua.named_registry_value::<_, LuaFunction>("tab.freeze")?, + )? + .build_readonly()?; + Self::into_lua_table_with_env(lua, socket_env) + } +} + +impl LuaUserData for NetWebSocket {} + +fn close_code(_lua: &Lua, socket: NetWebSocket) -> LuaResult +where + T: AsyncRead + AsyncWrite + Unpin, +{ + Ok(match socket.close_code.get() { + Some(code) => LuaValue::Number(code as f64), + None => LuaValue::Nil, + }) +} + +async fn close(_lua: &Lua, (socket, code): (NetWebSocket, Option)) -> LuaResult<()> +where + T: AsyncRead + AsyncWrite + Unpin, +{ + let mut ws = socket.stream.lock().await; + let res = ws.close(Some(WsCloseFrame { + code: match code { + Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code), + Some(code) => { + return Err(LuaError::RuntimeError(format!( + "Close code must be between 1000 and 4999, got {code}" + ))) + } + None => WsCloseCode::Normal, + }, + reason: "".into(), + })); + res.await.map_err(LuaError::external) +} + +async fn send( + _lua: &Lua, + (socket, string, as_binary): (NetWebSocket, LuaString<'_>, Option), +) -> LuaResult<()> +where + T: AsyncRead + AsyncWrite + Unpin, +{ + let msg = if matches!(as_binary, Some(true)) { + WsMessage::Binary(string.as_bytes().to_vec()) + } else { + let s = string.to_str().map_err(LuaError::external)?; + WsMessage::Text(s.to_string()) + }; + let mut ws = socket.stream.lock().await; + ws.send(msg).await.map_err(LuaError::external) +} + +async fn next(lua: &Lua, socket: NetWebSocket) -> LuaResult +where + T: AsyncRead + AsyncWrite + Unpin, +{ + let mut ws = socket.stream.lock().await; + let item = ws.next().await.transpose().map_err(LuaError::external); + let msg = match item { + Ok(Some(WsMessage::Close(msg))) => { + if let Some(msg) = &msg { + socket.close_code.replace(Some(msg.code.into())); + } + Ok(Some(WsMessage::Close(msg))) + } + val => val, + }?; + while let Some(msg) = &msg { + let msg_string_opt = match msg { + WsMessage::Binary(bin) => Some(lua.create_string(&bin)?), + WsMessage::Text(txt) => Some(lua.create_string(&txt)?), + // Stop waiting for next message if we get a close message + WsMessage::Close(_) => return Ok(LuaValue::Nil), + // Ignore ping/pong/frame messages, they are handled by tungstenite + _ => None, + }; + if let Some(msg_string) = msg_string_opt { + return Ok(LuaValue::String(msg_string)); + } + } + Ok(LuaValue::Nil) +}