use std::{cell::Cell, sync::Arc}; use hyper::upgrade::Upgraded; use mlua::prelude::*; use futures_util::{ stream::{SplitSink, SplitStream}, SinkExt, StreamExt, }; use tokio::{ io::{AsyncRead, AsyncWrite}, net::TcpStream, sync::Mutex as AsyncMutex, }; use hyper_tungstenite::{ tungstenite::{ protocol::{frame::coding::CloseCode as WsCloseCode, CloseFrame as WsCloseFrame}, Message as WsMessage, }, WebSocketStream, }; use tokio_tungstenite::MaybeTlsStream; use crate::lune_temp::lua::table::TableBuilder; const WEB_SOCKET_IMPL_LUA: &str = r#" return freeze(setmetatable({ close = function(...) return close(websocket, ...) end, send = function(...) return send(websocket, ...) end, next = function(...) return next(websocket, ...) end, }, { __index = function(self, key) if key == "closeCode" then return close_code(websocket) end end, })) "#; #[derive(Debug)] pub struct NetWebSocket { close_code: Arc>>, read_stream: Arc>>>, write_stream: Arc, WsMessage>>>, } impl Clone for NetWebSocket { fn clone(&self) -> Self { Self { close_code: Arc::clone(&self.close_code), read_stream: Arc::clone(&self.read_stream), write_stream: Arc::clone(&self.write_stream), } } } impl NetWebSocket where T: AsyncRead + AsyncWrite + Unpin, { pub fn new(value: WebSocketStream) -> Self { let (write, read) = value.split(); Self { close_code: Arc::new(Cell::new(None)), read_stream: Arc::new(AsyncMutex::new(read)), write_stream: Arc::new(AsyncMutex::new(write)), } } fn into_lua_table_with_env<'lua>( lua: &'lua Lua, env: LuaTable<'lua>, ) -> LuaResult> { lua.load(WEB_SOCKET_IMPL_LUA) .set_name("websocket") .set_environment(env) .eval() } } type NetWebSocketStreamClient = MaybeTlsStream; impl NetWebSocket { pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult { let socket_env = TableBuilder::new(lua)? .with_value("websocket", self)? .with_function("close_code", close_code::)? .with_async_function("close", close::)? .with_async_function("send", send::)? .with_async_function("next", next::)? .with_value( "setmetatable", lua.named_registry_value::("tab.setmeta")?, )? .with_value( "freeze", lua.named_registry_value::("tab.freeze")?, )? .build_readonly()?; Self::into_lua_table_with_env(lua, socket_env) } } type NetWebSocketStreamServer = Upgraded; impl NetWebSocket { pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult { let socket_env = TableBuilder::new(lua)? .with_value("websocket", self)? .with_function("close_code", close_code::)? .with_async_function("close", close::)? .with_async_function("send", send::)? .with_async_function("next", next::)? .with_value( "setmetatable", lua.named_registry_value::("tab.setmeta")?, )? .with_value( "freeze", lua.named_registry_value::("tab.freeze")?, )? .build_readonly()?; Self::into_lua_table_with_env(lua, socket_env) } } impl LuaUserData for NetWebSocket {} fn close_code<'lua, T>( _lua: &'lua Lua, socket: LuaUserDataRef<'lua, NetWebSocket>, ) -> LuaResult> where T: AsyncRead + AsyncWrite + Unpin, { Ok(match socket.close_code.get() { Some(code) => LuaValue::Number(code as f64), None => LuaValue::Nil, }) } async fn close<'lua, T>( _lua: &'lua Lua, (socket, code): (LuaUserDataRef<'lua, NetWebSocket>, Option), ) -> LuaResult<()> where T: AsyncRead + AsyncWrite + Unpin, { let mut ws = socket.write_stream.lock().await; ws.send(WsMessage::Close(Some(WsCloseFrame { code: match code { Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code), Some(code) => { return Err(LuaError::RuntimeError(format!( "Close code must be between 1000 and 4999, got {code}" ))) } None => WsCloseCode::Normal, }, reason: "".into(), }))) .await .into_lua_err()?; let res = ws.close(); res.await.into_lua_err() } async fn send<'lua, T>( _lua: &'lua Lua, (socket, string, as_binary): ( LuaUserDataRef<'lua, NetWebSocket>, LuaString<'lua>, Option, ), ) -> LuaResult<()> where T: AsyncRead + AsyncWrite + Unpin, { let msg = if matches!(as_binary, Some(true)) { WsMessage::Binary(string.as_bytes().to_vec()) } else { let s = string.to_str().into_lua_err()?; WsMessage::Text(s.to_string()) }; let mut ws = socket.write_stream.lock().await; ws.send(msg).await.into_lua_err() } async fn next<'lua, T>( lua: &'lua Lua, socket: LuaUserDataRef<'lua, NetWebSocket>, ) -> LuaResult> where T: AsyncRead + AsyncWrite + Unpin, { let mut ws = socket.read_stream.lock().await; let item = ws.next().await.transpose().into_lua_err(); let msg = match item { Ok(Some(WsMessage::Close(msg))) => { if let Some(msg) = &msg { socket.close_code.replace(Some(msg.code.into())); } Ok(Some(WsMessage::Close(msg))) } val => val, }?; while let Some(msg) = &msg { let msg_string_opt = match msg { WsMessage::Binary(bin) => Some(lua.create_string(bin)?), WsMessage::Text(txt) => Some(lua.create_string(txt)?), // Stop waiting for next message if we get a close message WsMessage::Close(_) => return Ok(LuaValue::Nil), // Ignore ping/pong/frame messages, they are handled by tungstenite _ => None, }; if let Some(msg_string) = msg_string_opt { return Ok(LuaValue::String(msg_string)); } } Ok(LuaValue::Nil) }