Improve websocket behavior, implement close code

This commit is contained in:
Filip Tibell 2023-03-02 19:46:17 +01:00
parent d63ac6191a
commit d01d2a27f4
No known key found for this signature in database
2 changed files with 97 additions and 30 deletions

View file

@ -294,14 +294,17 @@ export type NetServeHandle = {
* `next` will no longer return any message(s) and instead instantly return nil * `next` will no longer return any message(s) and instead instantly return nil
* `send` will throw an error stating that the socket has been closed * `send` will throw an error stating that the socket has been closed
Once the websocket has been closed, `closeCode` will no longer be nil, and will be populated with a close
code according to the [WebSocket specification](https://www.iana.org/assignments/websocket/websocket.xhtml).
This will be an integer between 1000 and 4999, where 1000 is the canonical code for normal, error-free closure.
]=] ]=]
declare class NetWebSocket export type NetWebSocket = {
close: () -> () closeCode: number?,
send: (message: string) -> () close: (code: number?) -> (),
next: () -> string? send: (message: string, asBinaryMessage: boolean?) -> (),
iter: () -> () -> string next: () -> string?,
function __iter(self): () -> string }
end
--[=[ --[=[
@class Net @class Net

View file

@ -1,8 +1,14 @@
use std::sync::Arc; use std::{cell::Cell, sync::Arc};
use mlua::prelude::*; use mlua::prelude::*;
use hyper_tungstenite::{tungstenite::Message as WsMessage, WebSocketStream}; use hyper_tungstenite::{
tungstenite::{
protocol::{frame::coding::CloseCode as WsCloseCode, CloseFrame as WsCloseFrame},
Message as WsMessage,
},
WebSocketStream,
};
use futures_util::{SinkExt, StreamExt}; use futures_util::{SinkExt, StreamExt};
use tokio::{ use tokio::{
@ -14,6 +20,7 @@ use crate::lua::table::TableBuilder;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct NetWebSocket<T> { pub struct NetWebSocket<T> {
close_code: Cell<Option<u16>>,
stream: Arc<Mutex<WebSocketStream<T>>>, stream: Arc<Mutex<WebSocketStream<T>>>,
} }
@ -23,26 +30,83 @@ where
{ {
pub fn new(value: WebSocketStream<T>) -> Self { pub fn new(value: WebSocketStream<T>) -> Self {
Self { Self {
close_code: Cell::new(None),
stream: Arc::new(Mutex::new(value)), stream: Arc::new(Mutex::new(value)),
} }
} }
pub async fn close(&self) -> LuaResult<()> { pub fn get_lua_close_code(&self) -> LuaValue {
match self.close_code.get() {
Some(code) => LuaValue::Number(code as f64),
None => LuaValue::Nil,
}
}
pub async fn close(&self, code: Option<u16>) -> LuaResult<()> {
let mut ws = self.stream.lock().await; let mut ws = self.stream.lock().await;
ws.close(None).await.map_err(LuaError::external)?; let res = ws.close(Some(WsCloseFrame {
Ok(()) code: match code {
Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code),
Some(code) => {
return Err(LuaError::RuntimeError(format!(
"Close code must be between 1000 and 4999, got {code}"
)))
}
None => WsCloseCode::Normal,
},
reason: "".into(),
}));
res.await.map_err(LuaError::external)
} }
pub async fn send(&self, msg: WsMessage) -> LuaResult<()> { pub async fn send(&self, msg: WsMessage) -> LuaResult<()> {
let mut ws = self.stream.lock().await; let mut ws = self.stream.lock().await;
ws.send(msg).await.map_err(LuaError::external)?; ws.send(msg).await.map_err(LuaError::external)
Ok(()) }
pub async fn send_lua_string<'lua>(
&self,
string: LuaString<'lua>,
as_binary: Option<bool>,
) -> LuaResult<()> {
let msg = if matches!(as_binary, Some(true)) {
WsMessage::Binary(string.as_bytes().to_vec())
} else {
let s = string.to_str().map_err(LuaError::external)?;
WsMessage::Text(s.to_string())
};
self.send(msg).await
} }
pub async fn next(&self) -> LuaResult<Option<WsMessage>> { pub async fn next(&self) -> LuaResult<Option<WsMessage>> {
let mut ws = self.stream.lock().await; let mut ws = self.stream.lock().await;
let item = ws.next().await.transpose(); let item = ws.next().await.transpose().map_err(LuaError::external);
item.map_err(LuaError::external) match item {
Ok(Some(WsMessage::Close(msg))) => {
if let Some(msg) = &msg {
self.close_code.replace(Some(msg.code.into()));
}
Ok(Some(WsMessage::Close(msg)))
}
val => val,
}
}
pub async fn next_lua_string<'lua>(&'lua self, lua: &'lua Lua) -> LuaResult<LuaValue> {
while let Some(msg) = self.next().await? {
let msg_string_opt = match msg {
WsMessage::Binary(bin) => Some(lua.create_string(&bin)?),
WsMessage::Text(txt) => Some(lua.create_string(&txt)?),
// Stop waiting for next message if we get a close message
WsMessage::Close(_) => return Ok(LuaValue::Nil),
// Ignore ping/pong/frame messages, they are handled by tungstenite
_ => None,
};
if let Some(msg_string) = msg_string_opt {
return Ok(LuaValue::String(msg_string));
}
}
Ok(LuaValue::Nil)
} }
} }
@ -53,20 +117,20 @@ where
pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult<LuaTable> { pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult<LuaTable> {
let ws = Box::leak(Box::new(self)); let ws = Box::leak(Box::new(self));
TableBuilder::new(lua)? TableBuilder::new(lua)?
.with_async_function("close", |_, _: ()| async { ws.close().await })? .with_async_function("close", |_, code| ws.close(code))?
.with_async_function("send", |_, msg: String| async { .with_async_function("send", |_, (msg, bin)| ws.send_lua_string(msg, bin))?
ws.send(WsMessage::Text(msg)).await .with_async_function("next", |lua, _: ()| ws.next_lua_string(lua))?
})? .with_metatable(
.with_async_function("next", |_, _: ()| async { TableBuilder::new(lua)?
match ws.next().await? { .with_function(LuaMetaMethod::Index.name(), |_, key: String| {
Some(msg) => Ok(match msg { if key == "closeCode" {
WsMessage::Binary(bin) => LuaValue::String(lua.create_string(&bin)?), Ok(ws.get_lua_close_code())
WsMessage::Text(txt) => LuaValue::String(lua.create_string(&txt)?), } else {
_ => LuaValue::Nil, Ok(LuaValue::Nil)
}), }
None => Ok(LuaValue::Nil), })?
} .build_readonly()?,
})? )?
.build_readonly() .build_readonly()
} }
} }