Fix websockets memory leak

This commit is contained in:
Filip Tibell 2023-03-21 15:46:14 +01:00
parent 29a3b41e15
commit 1f887cef07
No known key found for this signature in database
2 changed files with 185 additions and 105 deletions

View file

@ -64,10 +64,14 @@ end
* `"co.yield"` -> `coroutine.yield` * `"co.yield"` -> `coroutine.yield`
* `"co.close"` -> `coroutine.close` * `"co.close"` -> `coroutine.close`
--- ---
* `"tab.pack"` -> `table.pack`
* `"tab.unpack"` -> `table.unpack`
* `"tab.freeze"` -> `table.freeze`
* `"tab.getmeta"` -> `getmetatable`
* `"tab.setmeta"` -> `setmetatable`
---
* `"dbg.info"` -> `debug.info` * `"dbg.info"` -> `debug.info`
* `"dbg.trace"` -> `debug.traceback` * `"dbg.trace"` -> `debug.traceback`
* `"dbg.iserr"` -> `<custom function>`
* `"dbg.makeerr"` -> `<custom function>`
--- ---
*/ */
pub fn create() -> LuaResult<&'static Lua> { pub fn create() -> LuaResult<&'static Lua> {
@ -93,6 +97,15 @@ pub fn create() -> LuaResult<&'static Lua> {
lua.set_named_registry_value("dbg.info", debug.get::<_, LuaFunction>("info")?)?; lua.set_named_registry_value("dbg.info", debug.get::<_, LuaFunction>("info")?)?;
lua.set_named_registry_value("tab.pack", table.get::<_, LuaFunction>("pack")?)?; lua.set_named_registry_value("tab.pack", table.get::<_, LuaFunction>("pack")?)?;
lua.set_named_registry_value("tab.unpack", table.get::<_, LuaFunction>("unpack")?)?; lua.set_named_registry_value("tab.unpack", table.get::<_, LuaFunction>("unpack")?)?;
lua.set_named_registry_value("tab.freeze", table.get::<_, LuaFunction>("freeze")?)?;
lua.set_named_registry_value(
"tab.getmeta",
globals.get::<_, LuaFunction>("getmetatable")?,
)?;
lua.set_named_registry_value(
"tab.setmeta",
globals.get::<_, LuaFunction>("setmetatable")?,
)?;
// Create a trace function that can be called to obtain a full stack trace from // Create a trace function that can be called to obtain a full stack trace from
// lua, this is not possible to do from rust when using our manual scheduler // lua, this is not possible to do from rust when using our manual scheduler
let dbg_trace_env = lua.create_table_with_capacity(0, 1)?; let dbg_trace_env = lua.create_table_with_capacity(0, 1)?;

View file

@ -1,7 +1,15 @@
use std::{cell::Cell, sync::Arc}; use std::{cell::Cell, sync::Arc};
use hyper::upgrade::Upgraded;
use mlua::prelude::*; use mlua::prelude::*;
use futures_util::{SinkExt, StreamExt};
use tokio::{
io::{AsyncRead, AsyncWrite},
net::TcpStream,
sync::Mutex as AsyncMutex,
};
use hyper_tungstenite::{ use hyper_tungstenite::{
tungstenite::{ tungstenite::{
protocol::{frame::coding::CloseCode as WsCloseCode, CloseFrame as WsCloseFrame}, protocol::{frame::coding::CloseCode as WsCloseCode, CloseFrame as WsCloseFrame},
@ -9,19 +17,43 @@ use hyper_tungstenite::{
}, },
WebSocketStream, WebSocketStream,
}; };
use tokio_tungstenite::MaybeTlsStream;
use futures_util::{SinkExt, StreamExt};
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::Mutex,
};
use crate::lua::table::TableBuilder; use crate::lua::table::TableBuilder;
#[derive(Debug, Clone)] const WEB_SOCKET_IMPL_LUA: &str = r#"
return freeze(setmetatable({
close = function(...)
return close(websocket, ...)
end,
send = function(...)
return send(websocket, ...)
end,
next = function(...)
return next(websocket, ...)
end,
}, {
__index = function(self, key)
if key == "closeCode" then
return close_code()
end
end,
}))
"#;
#[derive(Debug)]
pub struct NetWebSocket<T> { pub struct NetWebSocket<T> {
close_code: Cell<Option<u16>>, close_code: Arc<Cell<Option<u16>>>,
stream: Arc<Mutex<WebSocketStream<T>>>, stream: Arc<AsyncMutex<WebSocketStream<T>>>,
}
impl<T> Clone for NetWebSocket<T> {
fn clone(&self) -> Self {
Self {
close_code: Arc::clone(&self.close_code),
stream: Arc::clone(&self.stream),
}
}
} }
impl<T> NetWebSocket<T> impl<T> NetWebSocket<T>
@ -30,20 +62,83 @@ where
{ {
pub fn new(value: WebSocketStream<T>) -> Self { pub fn new(value: WebSocketStream<T>) -> Self {
Self { Self {
close_code: Cell::new(None), close_code: Arc::new(Cell::new(None)),
stream: Arc::new(Mutex::new(value)), stream: Arc::new(AsyncMutex::new(value)),
} }
} }
pub fn get_lua_close_code(&self) -> LuaValue { fn into_lua_table_with_env<'lua>(
match self.close_code.get() { lua: &'lua Lua,
env: LuaTable<'lua>,
) -> LuaResult<LuaTable<'lua>> {
lua.load(WEB_SOCKET_IMPL_LUA)
.set_name("websocket")?
.set_environment(env)?
.eval()
}
}
type NetWebSocketStreamClient = MaybeTlsStream<TcpStream>;
impl NetWebSocket<NetWebSocketStreamClient> {
pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult<LuaTable> {
let socket_env = TableBuilder::new(lua)?
.with_value("websocket", self)?
.with_function("close_code", close_code::<NetWebSocketStreamClient>)?
.with_async_function("close", close::<NetWebSocketStreamClient>)?
.with_async_function("send", send::<NetWebSocketStreamClient>)?
.with_async_function("next", next::<NetWebSocketStreamClient>)?
.with_value(
"setmetatable",
lua.named_registry_value::<_, LuaFunction>("tab.setmeta")?,
)?
.with_value(
"freeze",
lua.named_registry_value::<_, LuaFunction>("tab.freeze")?,
)?
.build_readonly()?;
Self::into_lua_table_with_env(lua, socket_env)
}
}
type NetWebSocketStreamServer = Upgraded;
impl NetWebSocket<NetWebSocketStreamServer> {
pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult<LuaTable> {
let socket_env = TableBuilder::new(lua)?
.with_value("websocket", self)?
.with_function("close_code", close_code::<NetWebSocketStreamServer>)?
.with_async_function("close", close::<NetWebSocketStreamServer>)?
.with_async_function("send", send::<NetWebSocketStreamServer>)?
.with_async_function("next", next::<NetWebSocketStreamServer>)?
.with_value(
"setmetatable",
lua.named_registry_value::<_, LuaFunction>("tab.setmeta")?,
)?
.with_value(
"freeze",
lua.named_registry_value::<_, LuaFunction>("tab.freeze")?,
)?
.build_readonly()?;
Self::into_lua_table_with_env(lua, socket_env)
}
}
impl<T> LuaUserData for NetWebSocket<T> {}
fn close_code<T>(_lua: &Lua, socket: NetWebSocket<T>) -> LuaResult<LuaValue>
where
T: AsyncRead + AsyncWrite + Unpin,
{
Ok(match socket.close_code.get() {
Some(code) => LuaValue::Number(code as f64), Some(code) => LuaValue::Number(code as f64),
None => LuaValue::Nil, None => LuaValue::Nil,
} })
} }
pub async fn close(&self, code: Option<u16>) -> LuaResult<()> { async fn close<T>(_lua: &Lua, (socket, code): (NetWebSocket<T>, Option<u16>)) -> LuaResult<()>
let mut ws = self.stream.lock().await; where
T: AsyncRead + AsyncWrite + Unpin,
{
let mut ws = socket.stream.lock().await;
let res = ws.close(Some(WsCloseFrame { let res = ws.close(Some(WsCloseFrame {
code: match code { code: match code {
Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code), Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code),
@ -59,41 +154,39 @@ where
res.await.map_err(LuaError::external) res.await.map_err(LuaError::external)
} }
pub async fn send(&self, msg: WsMessage) -> LuaResult<()> { async fn send<T>(
let mut ws = self.stream.lock().await; _lua: &Lua,
ws.send(msg).await.map_err(LuaError::external) (socket, string, as_binary): (NetWebSocket<T>, LuaString<'_>, Option<bool>),
} ) -> LuaResult<()>
where
pub async fn send_lua_string<'lua>( T: AsyncRead + AsyncWrite + Unpin,
&self, {
string: LuaString<'lua>,
as_binary: Option<bool>,
) -> LuaResult<()> {
let msg = if matches!(as_binary, Some(true)) { let msg = if matches!(as_binary, Some(true)) {
WsMessage::Binary(string.as_bytes().to_vec()) WsMessage::Binary(string.as_bytes().to_vec())
} else { } else {
let s = string.to_str().map_err(LuaError::external)?; let s = string.to_str().map_err(LuaError::external)?;
WsMessage::Text(s.to_string()) WsMessage::Text(s.to_string())
}; };
self.send(msg).await let mut ws = socket.stream.lock().await;
ws.send(msg).await.map_err(LuaError::external)
} }
pub async fn next(&self) -> LuaResult<Option<WsMessage>> { async fn next<T>(lua: &Lua, socket: NetWebSocket<T>) -> LuaResult<LuaValue>
let mut ws = self.stream.lock().await; where
T: AsyncRead + AsyncWrite + Unpin,
{
let mut ws = socket.stream.lock().await;
let item = ws.next().await.transpose().map_err(LuaError::external); let item = ws.next().await.transpose().map_err(LuaError::external);
match item { let msg = match item {
Ok(Some(WsMessage::Close(msg))) => { Ok(Some(WsMessage::Close(msg))) => {
if let Some(msg) = &msg { if let Some(msg) = &msg {
self.close_code.replace(Some(msg.code.into())); socket.close_code.replace(Some(msg.code.into()));
} }
Ok(Some(WsMessage::Close(msg))) Ok(Some(WsMessage::Close(msg)))
} }
val => val, val => val,
} }?;
} while let Some(msg) = &msg {
pub async fn next_lua_string<'lua>(&'lua self, lua: &'lua Lua) -> LuaResult<LuaValue> {
while let Some(msg) = self.next().await? {
let msg_string_opt = match msg { let msg_string_opt = match msg {
WsMessage::Binary(bin) => Some(lua.create_string(&bin)?), WsMessage::Binary(bin) => Some(lua.create_string(&bin)?),
WsMessage::Text(txt) => Some(lua.create_string(&txt)?), WsMessage::Text(txt) => Some(lua.create_string(&txt)?),
@ -108,29 +201,3 @@ where
} }
Ok(LuaValue::Nil) Ok(LuaValue::Nil)
} }
}
impl<T> NetWebSocket<T>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult<LuaTable> {
let ws = Box::leak(Box::new(self));
TableBuilder::new(lua)?
.with_async_function("close", |_, code| ws.close(code))?
.with_async_function("send", |_, (msg, bin)| ws.send_lua_string(msg, bin))?
.with_async_function("next", |lua, _: ()| ws.next_lua_string(lua))?
.with_metatable(
TableBuilder::new(lua)?
.with_function(LuaMetaMethod::Index.name(), |_, key: String| {
if key == "closeCode" {
Ok(ws.get_lua_close_code())
} else {
Ok(LuaValue::Nil)
}
})?
.build_readonly()?,
)?
.build_readonly()
}
}