mirror of
https://github.com/lune-org/lune.git
synced 2025-01-19 01:08:05 +00:00
Fix websockets memory leak
This commit is contained in:
parent
29a3b41e15
commit
1f887cef07
2 changed files with 185 additions and 105 deletions
|
@ -64,10 +64,14 @@ end
|
||||||
* `"co.yield"` -> `coroutine.yield`
|
* `"co.yield"` -> `coroutine.yield`
|
||||||
* `"co.close"` -> `coroutine.close`
|
* `"co.close"` -> `coroutine.close`
|
||||||
---
|
---
|
||||||
|
* `"tab.pack"` -> `table.pack`
|
||||||
|
* `"tab.unpack"` -> `table.unpack`
|
||||||
|
* `"tab.freeze"` -> `table.freeze`
|
||||||
|
* `"tab.getmeta"` -> `getmetatable`
|
||||||
|
* `"tab.setmeta"` -> `setmetatable`
|
||||||
|
---
|
||||||
* `"dbg.info"` -> `debug.info`
|
* `"dbg.info"` -> `debug.info`
|
||||||
* `"dbg.trace"` -> `debug.traceback`
|
* `"dbg.trace"` -> `debug.traceback`
|
||||||
* `"dbg.iserr"` -> `<custom function>`
|
|
||||||
* `"dbg.makeerr"` -> `<custom function>`
|
|
||||||
---
|
---
|
||||||
*/
|
*/
|
||||||
pub fn create() -> LuaResult<&'static Lua> {
|
pub fn create() -> LuaResult<&'static Lua> {
|
||||||
|
@ -93,6 +97,15 @@ pub fn create() -> LuaResult<&'static Lua> {
|
||||||
lua.set_named_registry_value("dbg.info", debug.get::<_, LuaFunction>("info")?)?;
|
lua.set_named_registry_value("dbg.info", debug.get::<_, LuaFunction>("info")?)?;
|
||||||
lua.set_named_registry_value("tab.pack", table.get::<_, LuaFunction>("pack")?)?;
|
lua.set_named_registry_value("tab.pack", table.get::<_, LuaFunction>("pack")?)?;
|
||||||
lua.set_named_registry_value("tab.unpack", table.get::<_, LuaFunction>("unpack")?)?;
|
lua.set_named_registry_value("tab.unpack", table.get::<_, LuaFunction>("unpack")?)?;
|
||||||
|
lua.set_named_registry_value("tab.freeze", table.get::<_, LuaFunction>("freeze")?)?;
|
||||||
|
lua.set_named_registry_value(
|
||||||
|
"tab.getmeta",
|
||||||
|
globals.get::<_, LuaFunction>("getmetatable")?,
|
||||||
|
)?;
|
||||||
|
lua.set_named_registry_value(
|
||||||
|
"tab.setmeta",
|
||||||
|
globals.get::<_, LuaFunction>("setmetatable")?,
|
||||||
|
)?;
|
||||||
// Create a trace function that can be called to obtain a full stack trace from
|
// Create a trace function that can be called to obtain a full stack trace from
|
||||||
// lua, this is not possible to do from rust when using our manual scheduler
|
// lua, this is not possible to do from rust when using our manual scheduler
|
||||||
let dbg_trace_env = lua.create_table_with_capacity(0, 1)?;
|
let dbg_trace_env = lua.create_table_with_capacity(0, 1)?;
|
||||||
|
|
|
@ -1,7 +1,15 @@
|
||||||
use std::{cell::Cell, sync::Arc};
|
use std::{cell::Cell, sync::Arc};
|
||||||
|
|
||||||
|
use hyper::upgrade::Upgraded;
|
||||||
use mlua::prelude::*;
|
use mlua::prelude::*;
|
||||||
|
|
||||||
|
use futures_util::{SinkExt, StreamExt};
|
||||||
|
use tokio::{
|
||||||
|
io::{AsyncRead, AsyncWrite},
|
||||||
|
net::TcpStream,
|
||||||
|
sync::Mutex as AsyncMutex,
|
||||||
|
};
|
||||||
|
|
||||||
use hyper_tungstenite::{
|
use hyper_tungstenite::{
|
||||||
tungstenite::{
|
tungstenite::{
|
||||||
protocol::{frame::coding::CloseCode as WsCloseCode, CloseFrame as WsCloseFrame},
|
protocol::{frame::coding::CloseCode as WsCloseCode, CloseFrame as WsCloseFrame},
|
||||||
|
@ -9,19 +17,43 @@ use hyper_tungstenite::{
|
||||||
},
|
},
|
||||||
WebSocketStream,
|
WebSocketStream,
|
||||||
};
|
};
|
||||||
|
use tokio_tungstenite::MaybeTlsStream;
|
||||||
use futures_util::{SinkExt, StreamExt};
|
|
||||||
use tokio::{
|
|
||||||
io::{AsyncRead, AsyncWrite},
|
|
||||||
sync::Mutex,
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::lua::table::TableBuilder;
|
use crate::lua::table::TableBuilder;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
const WEB_SOCKET_IMPL_LUA: &str = r#"
|
||||||
|
return freeze(setmetatable({
|
||||||
|
close = function(...)
|
||||||
|
return close(websocket, ...)
|
||||||
|
end,
|
||||||
|
send = function(...)
|
||||||
|
return send(websocket, ...)
|
||||||
|
end,
|
||||||
|
next = function(...)
|
||||||
|
return next(websocket, ...)
|
||||||
|
end,
|
||||||
|
}, {
|
||||||
|
__index = function(self, key)
|
||||||
|
if key == "closeCode" then
|
||||||
|
return close_code()
|
||||||
|
end
|
||||||
|
end,
|
||||||
|
}))
|
||||||
|
"#;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
pub struct NetWebSocket<T> {
|
pub struct NetWebSocket<T> {
|
||||||
close_code: Cell<Option<u16>>,
|
close_code: Arc<Cell<Option<u16>>>,
|
||||||
stream: Arc<Mutex<WebSocketStream<T>>>,
|
stream: Arc<AsyncMutex<WebSocketStream<T>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> Clone for NetWebSocket<T> {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
Self {
|
||||||
|
close_code: Arc::clone(&self.close_code),
|
||||||
|
stream: Arc::clone(&self.stream),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> NetWebSocket<T>
|
impl<T> NetWebSocket<T>
|
||||||
|
@ -30,20 +62,83 @@ where
|
||||||
{
|
{
|
||||||
pub fn new(value: WebSocketStream<T>) -> Self {
|
pub fn new(value: WebSocketStream<T>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
close_code: Cell::new(None),
|
close_code: Arc::new(Cell::new(None)),
|
||||||
stream: Arc::new(Mutex::new(value)),
|
stream: Arc::new(AsyncMutex::new(value)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_lua_close_code(&self) -> LuaValue {
|
fn into_lua_table_with_env<'lua>(
|
||||||
match self.close_code.get() {
|
lua: &'lua Lua,
|
||||||
|
env: LuaTable<'lua>,
|
||||||
|
) -> LuaResult<LuaTable<'lua>> {
|
||||||
|
lua.load(WEB_SOCKET_IMPL_LUA)
|
||||||
|
.set_name("websocket")?
|
||||||
|
.set_environment(env)?
|
||||||
|
.eval()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type NetWebSocketStreamClient = MaybeTlsStream<TcpStream>;
|
||||||
|
impl NetWebSocket<NetWebSocketStreamClient> {
|
||||||
|
pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult<LuaTable> {
|
||||||
|
let socket_env = TableBuilder::new(lua)?
|
||||||
|
.with_value("websocket", self)?
|
||||||
|
.with_function("close_code", close_code::<NetWebSocketStreamClient>)?
|
||||||
|
.with_async_function("close", close::<NetWebSocketStreamClient>)?
|
||||||
|
.with_async_function("send", send::<NetWebSocketStreamClient>)?
|
||||||
|
.with_async_function("next", next::<NetWebSocketStreamClient>)?
|
||||||
|
.with_value(
|
||||||
|
"setmetatable",
|
||||||
|
lua.named_registry_value::<_, LuaFunction>("tab.setmeta")?,
|
||||||
|
)?
|
||||||
|
.with_value(
|
||||||
|
"freeze",
|
||||||
|
lua.named_registry_value::<_, LuaFunction>("tab.freeze")?,
|
||||||
|
)?
|
||||||
|
.build_readonly()?;
|
||||||
|
Self::into_lua_table_with_env(lua, socket_env)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type NetWebSocketStreamServer = Upgraded;
|
||||||
|
impl NetWebSocket<NetWebSocketStreamServer> {
|
||||||
|
pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult<LuaTable> {
|
||||||
|
let socket_env = TableBuilder::new(lua)?
|
||||||
|
.with_value("websocket", self)?
|
||||||
|
.with_function("close_code", close_code::<NetWebSocketStreamServer>)?
|
||||||
|
.with_async_function("close", close::<NetWebSocketStreamServer>)?
|
||||||
|
.with_async_function("send", send::<NetWebSocketStreamServer>)?
|
||||||
|
.with_async_function("next", next::<NetWebSocketStreamServer>)?
|
||||||
|
.with_value(
|
||||||
|
"setmetatable",
|
||||||
|
lua.named_registry_value::<_, LuaFunction>("tab.setmeta")?,
|
||||||
|
)?
|
||||||
|
.with_value(
|
||||||
|
"freeze",
|
||||||
|
lua.named_registry_value::<_, LuaFunction>("tab.freeze")?,
|
||||||
|
)?
|
||||||
|
.build_readonly()?;
|
||||||
|
Self::into_lua_table_with_env(lua, socket_env)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> LuaUserData for NetWebSocket<T> {}
|
||||||
|
|
||||||
|
fn close_code<T>(_lua: &Lua, socket: NetWebSocket<T>) -> LuaResult<LuaValue>
|
||||||
|
where
|
||||||
|
T: AsyncRead + AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
|
Ok(match socket.close_code.get() {
|
||||||
Some(code) => LuaValue::Number(code as f64),
|
Some(code) => LuaValue::Number(code as f64),
|
||||||
None => LuaValue::Nil,
|
None => LuaValue::Nil,
|
||||||
}
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn close(&self, code: Option<u16>) -> LuaResult<()> {
|
async fn close<T>(_lua: &Lua, (socket, code): (NetWebSocket<T>, Option<u16>)) -> LuaResult<()>
|
||||||
let mut ws = self.stream.lock().await;
|
where
|
||||||
|
T: AsyncRead + AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
|
let mut ws = socket.stream.lock().await;
|
||||||
let res = ws.close(Some(WsCloseFrame {
|
let res = ws.close(Some(WsCloseFrame {
|
||||||
code: match code {
|
code: match code {
|
||||||
Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code),
|
Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code),
|
||||||
|
@ -59,41 +154,39 @@ where
|
||||||
res.await.map_err(LuaError::external)
|
res.await.map_err(LuaError::external)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn send(&self, msg: WsMessage) -> LuaResult<()> {
|
async fn send<T>(
|
||||||
let mut ws = self.stream.lock().await;
|
_lua: &Lua,
|
||||||
ws.send(msg).await.map_err(LuaError::external)
|
(socket, string, as_binary): (NetWebSocket<T>, LuaString<'_>, Option<bool>),
|
||||||
}
|
) -> LuaResult<()>
|
||||||
|
where
|
||||||
pub async fn send_lua_string<'lua>(
|
T: AsyncRead + AsyncWrite + Unpin,
|
||||||
&self,
|
{
|
||||||
string: LuaString<'lua>,
|
|
||||||
as_binary: Option<bool>,
|
|
||||||
) -> LuaResult<()> {
|
|
||||||
let msg = if matches!(as_binary, Some(true)) {
|
let msg = if matches!(as_binary, Some(true)) {
|
||||||
WsMessage::Binary(string.as_bytes().to_vec())
|
WsMessage::Binary(string.as_bytes().to_vec())
|
||||||
} else {
|
} else {
|
||||||
let s = string.to_str().map_err(LuaError::external)?;
|
let s = string.to_str().map_err(LuaError::external)?;
|
||||||
WsMessage::Text(s.to_string())
|
WsMessage::Text(s.to_string())
|
||||||
};
|
};
|
||||||
self.send(msg).await
|
let mut ws = socket.stream.lock().await;
|
||||||
|
ws.send(msg).await.map_err(LuaError::external)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn next(&self) -> LuaResult<Option<WsMessage>> {
|
async fn next<T>(lua: &Lua, socket: NetWebSocket<T>) -> LuaResult<LuaValue>
|
||||||
let mut ws = self.stream.lock().await;
|
where
|
||||||
|
T: AsyncRead + AsyncWrite + Unpin,
|
||||||
|
{
|
||||||
|
let mut ws = socket.stream.lock().await;
|
||||||
let item = ws.next().await.transpose().map_err(LuaError::external);
|
let item = ws.next().await.transpose().map_err(LuaError::external);
|
||||||
match item {
|
let msg = match item {
|
||||||
Ok(Some(WsMessage::Close(msg))) => {
|
Ok(Some(WsMessage::Close(msg))) => {
|
||||||
if let Some(msg) = &msg {
|
if let Some(msg) = &msg {
|
||||||
self.close_code.replace(Some(msg.code.into()));
|
socket.close_code.replace(Some(msg.code.into()));
|
||||||
}
|
}
|
||||||
Ok(Some(WsMessage::Close(msg)))
|
Ok(Some(WsMessage::Close(msg)))
|
||||||
}
|
}
|
||||||
val => val,
|
val => val,
|
||||||
}
|
}?;
|
||||||
}
|
while let Some(msg) = &msg {
|
||||||
|
|
||||||
pub async fn next_lua_string<'lua>(&'lua self, lua: &'lua Lua) -> LuaResult<LuaValue> {
|
|
||||||
while let Some(msg) = self.next().await? {
|
|
||||||
let msg_string_opt = match msg {
|
let msg_string_opt = match msg {
|
||||||
WsMessage::Binary(bin) => Some(lua.create_string(&bin)?),
|
WsMessage::Binary(bin) => Some(lua.create_string(&bin)?),
|
||||||
WsMessage::Text(txt) => Some(lua.create_string(&txt)?),
|
WsMessage::Text(txt) => Some(lua.create_string(&txt)?),
|
||||||
|
@ -108,29 +201,3 @@ where
|
||||||
}
|
}
|
||||||
Ok(LuaValue::Nil)
|
Ok(LuaValue::Nil)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
impl<T> NetWebSocket<T>
|
|
||||||
where
|
|
||||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
|
||||||
{
|
|
||||||
pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult<LuaTable> {
|
|
||||||
let ws = Box::leak(Box::new(self));
|
|
||||||
TableBuilder::new(lua)?
|
|
||||||
.with_async_function("close", |_, code| ws.close(code))?
|
|
||||||
.with_async_function("send", |_, (msg, bin)| ws.send_lua_string(msg, bin))?
|
|
||||||
.with_async_function("next", |lua, _: ()| ws.next_lua_string(lua))?
|
|
||||||
.with_metatable(
|
|
||||||
TableBuilder::new(lua)?
|
|
||||||
.with_function(LuaMetaMethod::Index.name(), |_, key: String| {
|
|
||||||
if key == "closeCode" {
|
|
||||||
Ok(ws.get_lua_close_code())
|
|
||||||
} else {
|
|
||||||
Ok(LuaValue::Nil)
|
|
||||||
}
|
|
||||||
})?
|
|
||||||
.build_readonly()?,
|
|
||||||
)?
|
|
||||||
.build_readonly()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
Loading…
Reference in a new issue