From ef2a57838b52b27a1f00b543feb6b60956d6c293 Mon Sep 17 00:00:00 2001 From: AsynchronousMatrix Date: Mon, 17 Jul 2023 19:23:05 +0100 Subject: [PATCH] split websocket impl into send/recv objects --- packages/lib/src/lua/net/websocket.rs | 51 +++++++++++++++++---------- 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/packages/lib/src/lua/net/websocket.rs b/packages/lib/src/lua/net/websocket.rs index bbf24e0..a53a02c 100644 --- a/packages/lib/src/lua/net/websocket.rs +++ b/packages/lib/src/lua/net/websocket.rs @@ -3,7 +3,10 @@ use std::{cell::Cell, sync::Arc}; use hyper::upgrade::Upgraded; use mlua::prelude::*; -use futures_util::{SinkExt, StreamExt}; +use futures_util::{ + stream::{SplitSink, SplitStream}, + SinkExt, StreamExt, +}; use tokio::{ io::{AsyncRead, AsyncWrite}, net::TcpStream, @@ -44,14 +47,16 @@ return freeze(setmetatable({ #[derive(Debug)] pub struct NetWebSocket { close_code: Arc>>, - stream: Arc>>, + stream_in: Arc, WsMessage>>>, + stream_out: Arc>>>, } impl Clone for NetWebSocket { fn clone(&self) -> Self { Self { close_code: Arc::clone(&self.close_code), - stream: Arc::clone(&self.stream), + stream_in: Arc::clone(&self.stream_in), + stream_out: Arc::clone(&self.stream_out), } } } @@ -61,9 +66,12 @@ where T: AsyncRead + AsyncWrite + Unpin, { pub fn new(value: WebSocketStream) -> Self { + let (write, read) = value.split(); + Self { close_code: Arc::new(Cell::new(None)), - stream: Arc::new(AsyncMutex::new(value)), + stream_in: Arc::new(AsyncMutex::new(write)), + stream_out: Arc::new(AsyncMutex::new(read)), } } @@ -144,19 +152,24 @@ async fn close<'lua, T>( where T: AsyncRead + AsyncWrite + Unpin, { - let mut ws = socket.stream.lock().await; - let res = ws.close(Some(WsCloseFrame { - code: match code { - Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code), - Some(code) => { - return Err(LuaError::RuntimeError(format!( - "Close code must be between 1000 and 4999, got {code}" - ))) - } - None => WsCloseCode::Normal, - }, - reason: "".into(), - })); + let mut ws_in = socket.stream_in.lock().await; + + let _ = ws_in + .send(WsMessage::Close(Some(WsCloseFrame { + code: match code { + Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code), + Some(code) => { + return Err(LuaError::RuntimeError(format!( + "Close code must be between 1000 and 4999, got {code}" + ))) + } + None => WsCloseCode::Normal, + }, + reason: "".into(), + }))) + .await; + + let res = ws_in.close(); res.await.map_err(LuaError::external) } @@ -177,7 +190,7 @@ where let s = string.to_str().map_err(LuaError::external)?; WsMessage::Text(s.to_string()) }; - let mut ws = socket.stream.lock().await; + let mut ws = socket.stream_in.lock().await; ws.send(msg).await.map_err(LuaError::external) } @@ -188,7 +201,7 @@ async fn next<'lua, T>( where T: AsyncRead + AsyncWrite + Unpin, { - let mut ws = socket.stream.lock().await; + let mut ws = socket.stream_out.lock().await; let item = ws.next().await.transpose().map_err(LuaError::external); let msg = match item { Ok(Some(WsMessage::Close(msg))) => {