split websocket impl into send/recv objects

This commit is contained in:
AsynchronousMatrix 2023-07-17 19:23:05 +01:00
parent b0861ce0fb
commit ef2a57838b

View file

@ -3,7 +3,10 @@ use std::{cell::Cell, sync::Arc};
use hyper::upgrade::Upgraded; use hyper::upgrade::Upgraded;
use mlua::prelude::*; use mlua::prelude::*;
use futures_util::{SinkExt, StreamExt}; use futures_util::{
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
};
use tokio::{ use tokio::{
io::{AsyncRead, AsyncWrite}, io::{AsyncRead, AsyncWrite},
net::TcpStream, net::TcpStream,
@ -44,14 +47,16 @@ return freeze(setmetatable({
#[derive(Debug)] #[derive(Debug)]
pub struct NetWebSocket<T> { pub struct NetWebSocket<T> {
close_code: Arc<Cell<Option<u16>>>, close_code: Arc<Cell<Option<u16>>>,
stream: Arc<AsyncMutex<WebSocketStream<T>>>, stream_in: Arc<AsyncMutex<SplitSink<WebSocketStream<T>, WsMessage>>>,
stream_out: Arc<AsyncMutex<SplitStream<WebSocketStream<T>>>>,
} }
impl<T> Clone for NetWebSocket<T> { impl<T> Clone for NetWebSocket<T> {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
close_code: Arc::clone(&self.close_code), close_code: Arc::clone(&self.close_code),
stream: Arc::clone(&self.stream), stream_in: Arc::clone(&self.stream_in),
stream_out: Arc::clone(&self.stream_out),
} }
} }
} }
@ -61,9 +66,12 @@ where
T: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin,
{ {
pub fn new(value: WebSocketStream<T>) -> Self { pub fn new(value: WebSocketStream<T>) -> Self {
let (write, read) = value.split();
Self { Self {
close_code: Arc::new(Cell::new(None)), close_code: Arc::new(Cell::new(None)),
stream: Arc::new(AsyncMutex::new(value)), stream_in: Arc::new(AsyncMutex::new(write)),
stream_out: Arc::new(AsyncMutex::new(read)),
} }
} }
@ -144,19 +152,24 @@ async fn close<'lua, T>(
where where
T: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin,
{ {
let mut ws = socket.stream.lock().await; let mut ws_in = socket.stream_in.lock().await;
let res = ws.close(Some(WsCloseFrame {
code: match code { let _ = ws_in
Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code), .send(WsMessage::Close(Some(WsCloseFrame {
Some(code) => { code: match code {
return Err(LuaError::RuntimeError(format!( Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code),
"Close code must be between 1000 and 4999, got {code}" Some(code) => {
))) return Err(LuaError::RuntimeError(format!(
} "Close code must be between 1000 and 4999, got {code}"
None => WsCloseCode::Normal, )))
}, }
reason: "".into(), None => WsCloseCode::Normal,
})); },
reason: "".into(),
})))
.await;
let res = ws_in.close();
res.await.map_err(LuaError::external) res.await.map_err(LuaError::external)
} }
@ -177,7 +190,7 @@ where
let s = string.to_str().map_err(LuaError::external)?; let s = string.to_str().map_err(LuaError::external)?;
WsMessage::Text(s.to_string()) WsMessage::Text(s.to_string())
}; };
let mut ws = socket.stream.lock().await; let mut ws = socket.stream_in.lock().await;
ws.send(msg).await.map_err(LuaError::external) ws.send(msg).await.map_err(LuaError::external)
} }
@ -188,7 +201,7 @@ async fn next<'lua, T>(
where where
T: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin,
{ {
let mut ws = socket.stream.lock().await; let mut ws = socket.stream_out.lock().await;
let item = ws.next().await.transpose().map_err(LuaError::external); let item = ws.next().await.transpose().map_err(LuaError::external);
let msg = match item { let msg = match item {
Ok(Some(WsMessage::Close(msg))) => { Ok(Some(WsMessage::Close(msg))) => {