diff --git a/docs/luneTypes.d.luau b/docs/luneTypes.d.luau index 4b3359c..5a927fc 100644 --- a/docs/luneTypes.d.luau +++ b/docs/luneTypes.d.luau @@ -294,14 +294,17 @@ export type NetServeHandle = { * `next` will no longer return any message(s) and instead instantly return nil * `send` will throw an error stating that the socket has been closed + + Once the websocket has been closed, `closeCode` will no longer be nil, and will be populated with a close + code according to the [WebSocket specification](https://www.iana.org/assignments/websocket/websocket.xhtml). + This will be an integer between 1000 and 4999, where 1000 is the canonical code for normal, error-free closure. ]=] -declare class NetWebSocket - close: () -> () - send: (message: string) -> () - next: () -> string? - iter: () -> () -> string - function __iter(self): () -> string -end +export type NetWebSocket = { + closeCode: number?, + close: (code: number?) -> (), + send: (message: string, asBinaryMessage: boolean?) -> (), + next: () -> string?, +} --[=[ @class Net diff --git a/packages/lib/src/lua/net/websocket.rs b/packages/lib/src/lua/net/websocket.rs index c05491d..f495f23 100644 --- a/packages/lib/src/lua/net/websocket.rs +++ b/packages/lib/src/lua/net/websocket.rs @@ -1,8 +1,14 @@ -use std::sync::Arc; +use std::{cell::Cell, sync::Arc}; use mlua::prelude::*; -use hyper_tungstenite::{tungstenite::Message as WsMessage, WebSocketStream}; +use hyper_tungstenite::{ + tungstenite::{ + protocol::{frame::coding::CloseCode as WsCloseCode, CloseFrame as WsCloseFrame}, + Message as WsMessage, + }, + WebSocketStream, +}; use futures_util::{SinkExt, StreamExt}; use tokio::{ @@ -14,6 +20,7 @@ use crate::lua::table::TableBuilder; #[derive(Debug, Clone)] pub struct NetWebSocket { + close_code: Cell>, stream: Arc>>, } @@ -23,26 +30,83 @@ where { pub fn new(value: WebSocketStream) -> Self { Self { + close_code: Cell::new(None), stream: Arc::new(Mutex::new(value)), } } - pub async fn close(&self) -> LuaResult<()> { + pub fn get_lua_close_code(&self) -> LuaValue { + match self.close_code.get() { + Some(code) => LuaValue::Number(code as f64), + None => LuaValue::Nil, + } + } + + pub async fn close(&self, code: Option) -> LuaResult<()> { let mut ws = self.stream.lock().await; - ws.close(None).await.map_err(LuaError::external)?; - Ok(()) + let res = ws.close(Some(WsCloseFrame { + code: match code { + Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code), + Some(code) => { + return Err(LuaError::RuntimeError(format!( + "Close code must be between 1000 and 4999, got {code}" + ))) + } + None => WsCloseCode::Normal, + }, + reason: "".into(), + })); + res.await.map_err(LuaError::external) } pub async fn send(&self, msg: WsMessage) -> LuaResult<()> { let mut ws = self.stream.lock().await; - ws.send(msg).await.map_err(LuaError::external)?; - Ok(()) + ws.send(msg).await.map_err(LuaError::external) + } + + pub async fn send_lua_string<'lua>( + &self, + string: LuaString<'lua>, + as_binary: Option, + ) -> LuaResult<()> { + let msg = if matches!(as_binary, Some(true)) { + WsMessage::Binary(string.as_bytes().to_vec()) + } else { + let s = string.to_str().map_err(LuaError::external)?; + WsMessage::Text(s.to_string()) + }; + self.send(msg).await } pub async fn next(&self) -> LuaResult> { let mut ws = self.stream.lock().await; - let item = ws.next().await.transpose(); - item.map_err(LuaError::external) + let item = ws.next().await.transpose().map_err(LuaError::external); + match item { + Ok(Some(WsMessage::Close(msg))) => { + if let Some(msg) = &msg { + self.close_code.replace(Some(msg.code.into())); + } + Ok(Some(WsMessage::Close(msg))) + } + val => val, + } + } + + pub async fn next_lua_string<'lua>(&'lua self, lua: &'lua Lua) -> LuaResult { + while let Some(msg) = self.next().await? { + let msg_string_opt = match msg { + WsMessage::Binary(bin) => Some(lua.create_string(&bin)?), + WsMessage::Text(txt) => Some(lua.create_string(&txt)?), + // Stop waiting for next message if we get a close message + WsMessage::Close(_) => return Ok(LuaValue::Nil), + // Ignore ping/pong/frame messages, they are handled by tungstenite + _ => None, + }; + if let Some(msg_string) = msg_string_opt { + return Ok(LuaValue::String(msg_string)); + } + } + Ok(LuaValue::Nil) } } @@ -53,20 +117,20 @@ where pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult { let ws = Box::leak(Box::new(self)); TableBuilder::new(lua)? - .with_async_function("close", |_, _: ()| async { ws.close().await })? - .with_async_function("send", |_, msg: String| async { - ws.send(WsMessage::Text(msg)).await - })? - .with_async_function("next", |_, _: ()| async { - match ws.next().await? { - Some(msg) => Ok(match msg { - WsMessage::Binary(bin) => LuaValue::String(lua.create_string(&bin)?), - WsMessage::Text(txt) => LuaValue::String(lua.create_string(&txt)?), - _ => LuaValue::Nil, - }), - None => Ok(LuaValue::Nil), - } - })? + .with_async_function("close", |_, code| ws.close(code))? + .with_async_function("send", |_, (msg, bin)| ws.send_lua_string(msg, bin))? + .with_async_function("next", |lua, _: ()| ws.next_lua_string(lua))? + .with_metatable( + TableBuilder::new(lua)? + .with_function(LuaMetaMethod::Index.name(), |_, key: String| { + if key == "closeCode" { + Ok(ws.get_lua_close_code()) + } else { + Ok(LuaValue::Nil) + } + })? + .build_readonly()?, + )? .build_readonly() } }