From 5f3169c1bbb5accabfe106c9e32749d6e6a6980a Mon Sep 17 00:00:00 2001 From: Filip Tibell Date: Mon, 20 Feb 2023 16:08:45 +0100 Subject: [PATCH] Make websocket object generic over streams --- packages/lib/src/globals/net.rs | 6 +- packages/lib/src/lua/net/mod.rs | 6 +- packages/lib/src/lua/net/server.rs | 11 ++-- packages/lib/src/lua/net/websocket.rs | 86 +++++++++++++++++++++++++++ packages/lib/src/lua/net/ws_client.rs | 63 -------------------- packages/lib/src/lua/net/ws_server.rs | 63 -------------------- 6 files changed, 96 insertions(+), 139 deletions(-) create mode 100644 packages/lib/src/lua/net/websocket.rs delete mode 100644 packages/lib/src/lua/net/ws_client.rs delete mode 100644 packages/lib/src/lua/net/ws_server.rs diff --git a/packages/lib/src/globals/net.rs b/packages/lib/src/globals/net.rs index 9a14d9f..63f1b7b 100644 --- a/packages/lib/src/globals/net.rs +++ b/packages/lib/src/globals/net.rs @@ -9,8 +9,8 @@ use tokio::{sync::mpsc, task}; use crate::{ lua::{ net::{ - NetClient, NetClientBuilder, NetLocalExec, NetService, NetWebSocketClient, - RequestConfig, ServeConfig, + NetClient, NetClientBuilder, NetLocalExec, NetService, NetWebSocket, RequestConfig, + ServeConfig, }, task::{TaskScheduler, TaskSchedulerAsyncExt}, }, @@ -84,7 +84,7 @@ async fn net_socket<'a>(lua: &'static Lua, url: String) -> LuaResult { let (ws, _) = tokio_tungstenite::connect_async(url) .await .map_err(LuaError::external)?; - NetWebSocketClient::from(ws).into_lua_table(lua) + NetWebSocket::new(ws).into_lua_table(lua) } async fn net_serve<'a>( diff --git a/packages/lib/src/lua/net/mod.rs b/packages/lib/src/lua/net/mod.rs index 492ebf3..4042953 100644 --- a/packages/lib/src/lua/net/mod.rs +++ b/packages/lib/src/lua/net/mod.rs @@ -1,11 +1,9 @@ mod client; mod config; mod server; -mod ws_client; -mod ws_server; +mod websocket; pub use client::{NetClient, NetClientBuilder}; pub use config::{RequestConfig, ServeConfig}; pub use server::{NetLocalExec, NetService}; -pub use ws_client::NetWebSocketClient; -pub use ws_server::NetWebSocketServer; +pub use websocket::NetWebSocket; diff --git a/packages/lib/src/lua/net/server.rs b/packages/lib/src/lua/net/server.rs index 6e58822..58e3751 100644 --- a/packages/lib/src/lua/net/server.rs +++ b/packages/lib/src/lua/net/server.rs @@ -17,7 +17,7 @@ use crate::{ utils::table::TableBuilder, }; -use super::NetWebSocketServer; +use super::NetWebSocket; // Hyper service implementation for net, lots of boilerplate here // but make_svc and make_svc_function do not work for what we need @@ -52,20 +52,19 @@ impl Service> for NetServiceInner { // we want here is a long-running task that keeps the program alive let sched = lua .app_data_ref::<&TaskScheduler>() - .expect("Missing task scheduler - make sure it is added as a lua app data before the first scheduler resumption"); + .expect("Missing task scheduler"); let handle = sched.register_background_task(); task::spawn_local(async move { // Create our new full websocket object, then // schedule our handler to get called asap let ws = ws.await.map_err(LuaError::external)?; - let sock = NetWebSocketServer::from(ws); - let table = sock.into_lua_table(lua)?; + let sock = NetWebSocket::new(ws).into_lua_table(lua)?; let sched = lua .app_data_ref::<&TaskScheduler>() - .expect("Missing task scheduler - make sure it is added as a lua app data before the first scheduler resumption"); + .expect("Missing task scheduler"); let result = sched.schedule_blocking( lua.create_thread(handler)?, - LuaMultiValue::from_vec(vec![LuaValue::Table(table)]), + LuaMultiValue::from_vec(vec![LuaValue::Table(sock)]), ); handle.unregister(Ok(())); result diff --git a/packages/lib/src/lua/net/websocket.rs b/packages/lib/src/lua/net/websocket.rs new file mode 100644 index 0000000..d79ad8c --- /dev/null +++ b/packages/lib/src/lua/net/websocket.rs @@ -0,0 +1,86 @@ +use std::sync::Arc; + +use hyper::upgrade::Upgraded; +use mlua::prelude::*; + +use hyper_tungstenite::{tungstenite::Message as WsMessage, WebSocketStream}; + +use futures_util::{SinkExt, StreamExt}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + net::TcpStream, + sync::Mutex, +}; +use tokio_tungstenite::MaybeTlsStream; + +use crate::utils::table::TableBuilder; + +#[derive(Debug, Clone)] +pub struct NetWebSocket { + stream: Arc>>, +} + +impl NetWebSocket +where + T: AsyncRead + AsyncWrite + Unpin, +{ + pub fn new(value: WebSocketStream) -> Self { + Self { + stream: Arc::new(Mutex::new(value)), + } + } + + pub async fn close(&self) -> LuaResult<()> { + let mut inner = self.stream.lock().await; + inner.close(None).await.map_err(LuaError::external)?; + Ok(()) + } + + pub async fn send(&self, msg: WsMessage) -> LuaResult<()> { + let mut inner = self.stream.lock().await; + inner.send(msg).await.map_err(LuaError::external)?; + Ok(()) + } + + pub async fn next(&self) -> LuaResult> { + let mut inner = self.stream.lock().await; + let item = inner.next().await.transpose(); + item.map_err(LuaError::external) + } + + pub async fn send_string(&self, msg: String) -> LuaResult<()> { + self.send(WsMessage::Text(msg)).await + } + + pub async fn next_lua_value(&self, lua: &'static Lua) -> LuaResult { + Ok(match self.next().await? { + Some(WsMessage::Binary(bin)) => LuaValue::String(lua.create_string(&bin)?), + Some(WsMessage::Text(txt)) => LuaValue::String(lua.create_string(&txt)?), + _ => LuaValue::Nil, + }) + } +} + +impl NetWebSocket> { + pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult { + // FIXME: Deallocate when closed + let sock = Box::leak(Box::new(self)); + TableBuilder::new(lua)? + .with_async_function("close", |_, ()| sock.close())? + .with_async_function("send", |_, msg: String| sock.send_string(msg))? + .with_async_function("next", |lua, ()| sock.next_lua_value(lua))? + .build_readonly() + } +} + +impl NetWebSocket { + pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult { + // FIXME: Deallocate when closed + let sock = Box::leak(Box::new(self)); + TableBuilder::new(lua)? + .with_async_function("close", |_, ()| sock.close())? + .with_async_function("send", |_, msg: String| sock.send_string(msg))? + .with_async_function("next", |lua, ()| sock.next_lua_value(lua))? + .build_readonly() + } +} diff --git a/packages/lib/src/lua/net/ws_client.rs b/packages/lib/src/lua/net/ws_client.rs deleted file mode 100644 index bc11111..0000000 --- a/packages/lib/src/lua/net/ws_client.rs +++ /dev/null @@ -1,63 +0,0 @@ -use std::sync::Arc; - -use mlua::prelude::*; - -use hyper_tungstenite::{tungstenite::Message as WsMessage, WebSocketStream}; - -use futures_util::{SinkExt, StreamExt}; -use tokio::{net::TcpStream, sync::Mutex}; -use tokio_tungstenite::MaybeTlsStream; - -use crate::utils::table::TableBuilder; - -#[derive(Debug, Clone)] -pub struct NetWebSocketClient(Arc>>>); - -impl NetWebSocketClient { - pub async fn close(&self) -> LuaResult<()> { - let mut ws = self.0.lock().await; - ws.close(None).await.map_err(LuaError::external)?; - Ok(()) - } - - pub async fn send(&self, msg: WsMessage) -> LuaResult<()> { - let mut ws = self.0.lock().await; - ws.send(msg).await.map_err(LuaError::external)?; - Ok(()) - } - - pub async fn next(&self) -> LuaResult> { - let mut ws = self.0.lock().await; - let item = ws.next().await.transpose(); - item.map_err(LuaError::external) - } - - pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult { - // FIXME: Deallocate when closed - let client = Box::leak(Box::new(self)); - TableBuilder::new(lua)? - .with_async_function("close", |_, ()| async { - let result = client.close().await; - result - })? - .with_async_function("send", |_, message: String| async { - let result = client.send(WsMessage::Text(message)).await; - result - })? - .with_async_function("next", |lua, ()| async { - let result = client.next().await?; - Ok(match result { - Some(WsMessage::Binary(bin)) => LuaValue::String(lua.create_string(&bin)?), - Some(WsMessage::Text(txt)) => LuaValue::String(lua.create_string(&txt)?), - _ => LuaValue::Nil, - }) - })? - .build_readonly() - } -} - -impl From>> for NetWebSocketClient { - fn from(value: WebSocketStream>) -> Self { - Self(Arc::new(Mutex::new(value))) - } -} diff --git a/packages/lib/src/lua/net/ws_server.rs b/packages/lib/src/lua/net/ws_server.rs deleted file mode 100644 index 5b289be..0000000 --- a/packages/lib/src/lua/net/ws_server.rs +++ /dev/null @@ -1,63 +0,0 @@ -use std::sync::Arc; - -use mlua::prelude::*; - -use hyper::upgrade::Upgraded; -use hyper_tungstenite::{tungstenite::Message as WsMessage, WebSocketStream}; - -use futures_util::{SinkExt, StreamExt}; -use tokio::sync::Mutex; - -use crate::utils::table::TableBuilder; - -#[derive(Debug, Clone)] -pub struct NetWebSocketServer(Arc>>); - -impl NetWebSocketServer { - pub async fn close(&self) -> LuaResult<()> { - let mut ws = self.0.lock().await; - ws.close(None).await.map_err(LuaError::external)?; - Ok(()) - } - - pub async fn send(&self, msg: WsMessage) -> LuaResult<()> { - let mut ws = self.0.lock().await; - ws.send(msg).await.map_err(LuaError::external)?; - Ok(()) - } - - pub async fn next(&self) -> LuaResult> { - let mut ws = self.0.lock().await; - let item = ws.next().await.transpose(); - item.map_err(LuaError::external) - } - - pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult { - // FIXME: Deallocate when closed - let server = Box::leak(Box::new(self)); - TableBuilder::new(lua)? - .with_async_function("close", |_, ()| async { - let result = server.close().await; - result - })? - .with_async_function("send", |_, message: String| async { - let result = server.send(WsMessage::Text(message)).await; - result - })? - .with_async_function("next", |lua, ()| async { - let result = server.next().await?; - Ok(match result { - Some(WsMessage::Binary(bin)) => LuaValue::String(lua.create_string(&bin)?), - Some(WsMessage::Text(txt)) => LuaValue::String(lua.create_string(&txt)?), - _ => LuaValue::Nil, - }) - })? - .build_readonly() - } -} - -impl From> for NetWebSocketServer { - fn from(value: WebSocketStream) -> Self { - Self(Arc::new(Mutex::new(value))) - } -}