Make websocket object generic over streams

This commit is contained in:
Filip Tibell 2023-02-20 16:08:45 +01:00
parent c57677bdd3
commit 5f3169c1bb
No known key found for this signature in database
6 changed files with 96 additions and 139 deletions

View file

@ -9,8 +9,8 @@ use tokio::{sync::mpsc, task};
use crate::{ use crate::{
lua::{ lua::{
net::{ net::{
NetClient, NetClientBuilder, NetLocalExec, NetService, NetWebSocketClient, NetClient, NetClientBuilder, NetLocalExec, NetService, NetWebSocket, RequestConfig,
RequestConfig, ServeConfig, ServeConfig,
}, },
task::{TaskScheduler, TaskSchedulerAsyncExt}, task::{TaskScheduler, TaskSchedulerAsyncExt},
}, },
@ -84,7 +84,7 @@ async fn net_socket<'a>(lua: &'static Lua, url: String) -> LuaResult<LuaTable> {
let (ws, _) = tokio_tungstenite::connect_async(url) let (ws, _) = tokio_tungstenite::connect_async(url)
.await .await
.map_err(LuaError::external)?; .map_err(LuaError::external)?;
NetWebSocketClient::from(ws).into_lua_table(lua) NetWebSocket::new(ws).into_lua_table(lua)
} }
async fn net_serve<'a>( async fn net_serve<'a>(

View file

@ -1,11 +1,9 @@
mod client; mod client;
mod config; mod config;
mod server; mod server;
mod ws_client; mod websocket;
mod ws_server;
pub use client::{NetClient, NetClientBuilder}; pub use client::{NetClient, NetClientBuilder};
pub use config::{RequestConfig, ServeConfig}; pub use config::{RequestConfig, ServeConfig};
pub use server::{NetLocalExec, NetService}; pub use server::{NetLocalExec, NetService};
pub use ws_client::NetWebSocketClient; pub use websocket::NetWebSocket;
pub use ws_server::NetWebSocketServer;

View file

@ -17,7 +17,7 @@ use crate::{
utils::table::TableBuilder, utils::table::TableBuilder,
}; };
use super::NetWebSocketServer; use super::NetWebSocket;
// Hyper service implementation for net, lots of boilerplate here // Hyper service implementation for net, lots of boilerplate here
// but make_svc and make_svc_function do not work for what we need // but make_svc and make_svc_function do not work for what we need
@ -52,20 +52,19 @@ impl Service<Request<Body>> for NetServiceInner {
// we want here is a long-running task that keeps the program alive // we want here is a long-running task that keeps the program alive
let sched = lua let sched = lua
.app_data_ref::<&TaskScheduler>() .app_data_ref::<&TaskScheduler>()
.expect("Missing task scheduler - make sure it is added as a lua app data before the first scheduler resumption"); .expect("Missing task scheduler");
let handle = sched.register_background_task(); let handle = sched.register_background_task();
task::spawn_local(async move { task::spawn_local(async move {
// Create our new full websocket object, then // Create our new full websocket object, then
// schedule our handler to get called asap // schedule our handler to get called asap
let ws = ws.await.map_err(LuaError::external)?; let ws = ws.await.map_err(LuaError::external)?;
let sock = NetWebSocketServer::from(ws); let sock = NetWebSocket::new(ws).into_lua_table(lua)?;
let table = sock.into_lua_table(lua)?;
let sched = lua let sched = lua
.app_data_ref::<&TaskScheduler>() .app_data_ref::<&TaskScheduler>()
.expect("Missing task scheduler - make sure it is added as a lua app data before the first scheduler resumption"); .expect("Missing task scheduler");
let result = sched.schedule_blocking( let result = sched.schedule_blocking(
lua.create_thread(handler)?, lua.create_thread(handler)?,
LuaMultiValue::from_vec(vec![LuaValue::Table(table)]), LuaMultiValue::from_vec(vec![LuaValue::Table(sock)]),
); );
handle.unregister(Ok(())); handle.unregister(Ok(()));
result result

View file

@ -0,0 +1,86 @@
use std::sync::Arc;
use hyper::upgrade::Upgraded;
use mlua::prelude::*;
use hyper_tungstenite::{tungstenite::Message as WsMessage, WebSocketStream};
use futures_util::{SinkExt, StreamExt};
use tokio::{
io::{AsyncRead, AsyncWrite},
net::TcpStream,
sync::Mutex,
};
use tokio_tungstenite::MaybeTlsStream;
use crate::utils::table::TableBuilder;
#[derive(Debug, Clone)]
pub struct NetWebSocket<T> {
stream: Arc<Mutex<WebSocketStream<T>>>,
}
impl<T> NetWebSocket<T>
where
T: AsyncRead + AsyncWrite + Unpin,
{
pub fn new(value: WebSocketStream<T>) -> Self {
Self {
stream: Arc::new(Mutex::new(value)),
}
}
pub async fn close(&self) -> LuaResult<()> {
let mut inner = self.stream.lock().await;
inner.close(None).await.map_err(LuaError::external)?;
Ok(())
}
pub async fn send(&self, msg: WsMessage) -> LuaResult<()> {
let mut inner = self.stream.lock().await;
inner.send(msg).await.map_err(LuaError::external)?;
Ok(())
}
pub async fn next(&self) -> LuaResult<Option<WsMessage>> {
let mut inner = self.stream.lock().await;
let item = inner.next().await.transpose();
item.map_err(LuaError::external)
}
pub async fn send_string(&self, msg: String) -> LuaResult<()> {
self.send(WsMessage::Text(msg)).await
}
pub async fn next_lua_value(&self, lua: &'static Lua) -> LuaResult<LuaValue> {
Ok(match self.next().await? {
Some(WsMessage::Binary(bin)) => LuaValue::String(lua.create_string(&bin)?),
Some(WsMessage::Text(txt)) => LuaValue::String(lua.create_string(&txt)?),
_ => LuaValue::Nil,
})
}
}
impl NetWebSocket<MaybeTlsStream<TcpStream>> {
pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult<LuaTable> {
// FIXME: Deallocate when closed
let sock = Box::leak(Box::new(self));
TableBuilder::new(lua)?
.with_async_function("close", |_, ()| sock.close())?
.with_async_function("send", |_, msg: String| sock.send_string(msg))?
.with_async_function("next", |lua, ()| sock.next_lua_value(lua))?
.build_readonly()
}
}
impl NetWebSocket<Upgraded> {
pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult<LuaTable> {
// FIXME: Deallocate when closed
let sock = Box::leak(Box::new(self));
TableBuilder::new(lua)?
.with_async_function("close", |_, ()| sock.close())?
.with_async_function("send", |_, msg: String| sock.send_string(msg))?
.with_async_function("next", |lua, ()| sock.next_lua_value(lua))?
.build_readonly()
}
}

View file

@ -1,63 +0,0 @@
use std::sync::Arc;
use mlua::prelude::*;
use hyper_tungstenite::{tungstenite::Message as WsMessage, WebSocketStream};
use futures_util::{SinkExt, StreamExt};
use tokio::{net::TcpStream, sync::Mutex};
use tokio_tungstenite::MaybeTlsStream;
use crate::utils::table::TableBuilder;
#[derive(Debug, Clone)]
pub struct NetWebSocketClient(Arc<Mutex<WebSocketStream<MaybeTlsStream<TcpStream>>>>);
impl NetWebSocketClient {
pub async fn close(&self) -> LuaResult<()> {
let mut ws = self.0.lock().await;
ws.close(None).await.map_err(LuaError::external)?;
Ok(())
}
pub async fn send(&self, msg: WsMessage) -> LuaResult<()> {
let mut ws = self.0.lock().await;
ws.send(msg).await.map_err(LuaError::external)?;
Ok(())
}
pub async fn next(&self) -> LuaResult<Option<WsMessage>> {
let mut ws = self.0.lock().await;
let item = ws.next().await.transpose();
item.map_err(LuaError::external)
}
pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult<LuaTable> {
// FIXME: Deallocate when closed
let client = Box::leak(Box::new(self));
TableBuilder::new(lua)?
.with_async_function("close", |_, ()| async {
let result = client.close().await;
result
})?
.with_async_function("send", |_, message: String| async {
let result = client.send(WsMessage::Text(message)).await;
result
})?
.with_async_function("next", |lua, ()| async {
let result = client.next().await?;
Ok(match result {
Some(WsMessage::Binary(bin)) => LuaValue::String(lua.create_string(&bin)?),
Some(WsMessage::Text(txt)) => LuaValue::String(lua.create_string(&txt)?),
_ => LuaValue::Nil,
})
})?
.build_readonly()
}
}
impl From<WebSocketStream<MaybeTlsStream<TcpStream>>> for NetWebSocketClient {
fn from(value: WebSocketStream<MaybeTlsStream<TcpStream>>) -> Self {
Self(Arc::new(Mutex::new(value)))
}
}

View file

@ -1,63 +0,0 @@
use std::sync::Arc;
use mlua::prelude::*;
use hyper::upgrade::Upgraded;
use hyper_tungstenite::{tungstenite::Message as WsMessage, WebSocketStream};
use futures_util::{SinkExt, StreamExt};
use tokio::sync::Mutex;
use crate::utils::table::TableBuilder;
#[derive(Debug, Clone)]
pub struct NetWebSocketServer(Arc<Mutex<WebSocketStream<Upgraded>>>);
impl NetWebSocketServer {
pub async fn close(&self) -> LuaResult<()> {
let mut ws = self.0.lock().await;
ws.close(None).await.map_err(LuaError::external)?;
Ok(())
}
pub async fn send(&self, msg: WsMessage) -> LuaResult<()> {
let mut ws = self.0.lock().await;
ws.send(msg).await.map_err(LuaError::external)?;
Ok(())
}
pub async fn next(&self) -> LuaResult<Option<WsMessage>> {
let mut ws = self.0.lock().await;
let item = ws.next().await.transpose();
item.map_err(LuaError::external)
}
pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult<LuaTable> {
// FIXME: Deallocate when closed
let server = Box::leak(Box::new(self));
TableBuilder::new(lua)?
.with_async_function("close", |_, ()| async {
let result = server.close().await;
result
})?
.with_async_function("send", |_, message: String| async {
let result = server.send(WsMessage::Text(message)).await;
result
})?
.with_async_function("next", |lua, ()| async {
let result = server.next().await?;
Ok(match result {
Some(WsMessage::Binary(bin)) => LuaValue::String(lua.create_string(&bin)?),
Some(WsMessage::Text(txt)) => LuaValue::String(lua.create_string(&txt)?),
_ => LuaValue::Nil,
})
})?
.build_readonly()
}
}
impl From<WebSocketStream<Upgraded>> for NetWebSocketServer {
fn from(value: WebSocketStream<Upgraded>) -> Self {
Self(Arc::new(Mutex::new(value)))
}
}