diff --git a/packages/lib/src/lua/net/websocket.rs b/packages/lib/src/lua/net/websocket.rs index bbf24e0..c66b93c 100644 --- a/packages/lib/src/lua/net/websocket.rs +++ b/packages/lib/src/lua/net/websocket.rs @@ -3,7 +3,10 @@ use std::{cell::Cell, sync::Arc}; use hyper::upgrade::Upgraded; use mlua::prelude::*; -use futures_util::{SinkExt, StreamExt}; +use futures_util::{ + stream::{SplitSink, SplitStream}, + SinkExt, StreamExt, +}; use tokio::{ io::{AsyncRead, AsyncWrite}, net::TcpStream, @@ -44,14 +47,16 @@ return freeze(setmetatable({ #[derive(Debug)] pub struct NetWebSocket { close_code: Arc>>, - stream: Arc>>, + read_stream: Arc>>>, + write_stream: Arc, WsMessage>>>, } impl Clone for NetWebSocket { fn clone(&self) -> Self { Self { close_code: Arc::clone(&self.close_code), - stream: Arc::clone(&self.stream), + read_stream: Arc::clone(&self.read_stream), + write_stream: Arc::clone(&self.write_stream), } } } @@ -61,9 +66,12 @@ where T: AsyncRead + AsyncWrite + Unpin, { pub fn new(value: WebSocketStream) -> Self { + let (write, read) = value.split(); + Self { close_code: Arc::new(Cell::new(None)), - stream: Arc::new(AsyncMutex::new(value)), + read_stream: Arc::new(AsyncMutex::new(read)), + write_stream: Arc::new(AsyncMutex::new(write)), } } @@ -144,8 +152,9 @@ async fn close<'lua, T>( where T: AsyncRead + AsyncWrite + Unpin, { - let mut ws = socket.stream.lock().await; - let res = ws.close(Some(WsCloseFrame { + let mut ws = socket.write_stream.lock().await; + + ws.send(WsMessage::Close(Some(WsCloseFrame { code: match code { Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code), Some(code) => { @@ -156,7 +165,11 @@ where None => WsCloseCode::Normal, }, reason: "".into(), - })); + }))) + .await + .map_err(LuaError::external)?; + + let res = ws.close(); res.await.map_err(LuaError::external) } @@ -177,7 +190,7 @@ where let s = string.to_str().map_err(LuaError::external)?; WsMessage::Text(s.to_string()) }; - let mut ws = socket.stream.lock().await; + let mut ws = socket.write_stream.lock().await; ws.send(msg).await.map_err(LuaError::external) } @@ -188,7 +201,7 @@ async fn next<'lua, T>( where T: AsyncRead + AsyncWrite + Unpin, { - let mut ws = socket.stream.lock().await; + let mut ws = socket.read_stream.lock().await; let item = ws.next().await.transpose().map_err(LuaError::external); let msg = match item { Ok(Some(WsMessage::Close(msg))) => { diff --git a/packages/lib/src/tests.rs b/packages/lib/src/tests.rs index 6c32104..1c41660 100644 --- a/packages/lib/src/tests.rs +++ b/packages/lib/src/tests.rs @@ -58,6 +58,7 @@ create_tests! { net_serve_requests: "net/serve/requests", net_serve_websockets: "net/serve/websockets", net_socket_wss: "net/socket/wss", + net_socket_wss_rw: "net/socket/wss_rw", process_args: "process/args", process_cwd: "process/cwd", diff --git a/tests/net/socket/wss_rw.luau b/tests/net/socket/wss_rw.luau new file mode 100644 index 0000000..5aca618 --- /dev/null +++ b/tests/net/socket/wss_rw.luau @@ -0,0 +1,29 @@ +local net = require("@lune/net") +local process = require("@lune/process") +local task = require("@lune/task") + +-- We're going to use Discord's WebSocket gateway server +-- for testing that we can both read from a stream, +-- as well as write to the same stream concurrently +local socket = net.socket("wss://gateway.discord.gg/?v=10&encoding=json") + +local spawnedThread = task.spawn(function() + while not socket.closeCode do + socket.next() + end +end) + +local delayedThread = task.delay(5, function() + task.defer(process.exit, 1) + error("`socket.send` halted, failed to write to socket") + + process.exit(1) +end) + +task.wait(1) + +socket.send('{"op":1,"d":null}') +socket.close(1000) + +task.cancel(delayedThread) +task.cancel(spawnedThread)