mirror of
https://github.com/lune-org/lune.git
synced 2025-05-04 10:43:57 +01:00
Full implementation for websocket streams and client
This commit is contained in:
parent
a82cb1da33
commit
3f179ab4ec
9 changed files with 391 additions and 39 deletions
66
Cargo.lock
generated
66
Cargo.lock
generated
|
@ -297,6 +297,22 @@ version = "4.7.1"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b75356056920673b02621b35afd0f7dda9306d03c79a30f5c56c44cf256e3de"
|
||||
|
||||
[[package]]
|
||||
name = "async-tungstenite"
|
||||
version = "0.29.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ef0f7efedeac57d9b26170f72965ecfd31473ca52ca7a64e925b0b6f5f079886"
|
||||
dependencies = [
|
||||
"atomic-waker",
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"futures-task",
|
||||
"futures-util",
|
||||
"log",
|
||||
"pin-project-lite",
|
||||
"tungstenite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "atomic-waker"
|
||||
version = "1.1.2"
|
||||
|
@ -801,6 +817,12 @@ dependencies = [
|
|||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "data-encoding"
|
||||
version = "2.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476"
|
||||
|
||||
[[package]]
|
||||
name = "deflate64"
|
||||
version = "0.1.9"
|
||||
|
@ -1079,6 +1101,20 @@ version = "1.3.0"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
|
||||
|
||||
[[package]]
|
||||
name = "futures"
|
||||
version = "0.3.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"futures-sink",
|
||||
"futures-task",
|
||||
"futures-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-channel"
|
||||
version = "0.3.31"
|
||||
|
@ -1086,6 +1122,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -1153,9 +1190,13 @@ version = "0.3.31"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"futures-macro",
|
||||
"futures-sink",
|
||||
"futures-task",
|
||||
"memchr",
|
||||
"pin-project-lite",
|
||||
"pin-utils",
|
||||
"slab",
|
||||
|
@ -1857,8 +1898,10 @@ dependencies = [
|
|||
"async-io",
|
||||
"async-lock",
|
||||
"async-net",
|
||||
"async-tungstenite",
|
||||
"blocking",
|
||||
"bstr",
|
||||
"futures",
|
||||
"futures-lite",
|
||||
"futures-rustls",
|
||||
"http-body-util",
|
||||
|
@ -3691,6 +3734,23 @@ version = "0.2.5"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
||||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.26.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"data-encoding",
|
||||
"http 1.3.1",
|
||||
"httparse",
|
||||
"log",
|
||||
"rand 0.9.1",
|
||||
"sha1 0.10.6",
|
||||
"thiserror 2.0.12",
|
||||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typeid"
|
||||
version = "1.0.3"
|
||||
|
@ -3763,6 +3823,12 @@ dependencies = [
|
|||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "utf-8"
|
||||
version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
|
||||
|
||||
[[package]]
|
||||
name = "utf16_iter"
|
||||
version = "1.0.5"
|
||||
|
|
|
@ -21,8 +21,10 @@ async-executor = "1.13"
|
|||
async-io = "2.4"
|
||||
async-lock = "3.4"
|
||||
async-net = "2.0"
|
||||
async-tungstenite = "0.29"
|
||||
blocking = "1.6"
|
||||
bstr = "1.9"
|
||||
futures = { version = "0.3", default-features = false, features = ["std"] }
|
||||
futures-lite = "2.6"
|
||||
futures-rustls = "0.26"
|
||||
http-body-util = "0.1"
|
||||
|
|
|
@ -1,50 +1,41 @@
|
|||
use std::{
|
||||
io,
|
||||
pin::Pin,
|
||||
sync::{Arc, LazyLock},
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use async_net::TcpStream;
|
||||
use futures_lite::prelude::*;
|
||||
use futures_rustls::{TlsConnector, TlsStream};
|
||||
use hyper::Uri;
|
||||
use rustls::ClientConfig;
|
||||
use rustls_pki_types::ServerName;
|
||||
use url::Url;
|
||||
|
||||
static CLIENT_CONFIG: LazyLock<Arc<ClientConfig>> = LazyLock::new(|| {
|
||||
rustls::ClientConfig::builder()
|
||||
.with_root_certificates(rustls::RootCertStore {
|
||||
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
|
||||
})
|
||||
.with_no_client_auth()
|
||||
.into()
|
||||
});
|
||||
use crate::client::rustls::CLIENT_CONFIG;
|
||||
|
||||
pub enum HttpRequestStream {
|
||||
#[derive(Debug)]
|
||||
pub enum HttpStream {
|
||||
Plain(TcpStream),
|
||||
Tls(TlsStream<TcpStream>),
|
||||
}
|
||||
|
||||
impl HttpRequestStream {
|
||||
pub async fn connect(url: &Uri) -> Result<Self, io::Error> {
|
||||
impl HttpStream {
|
||||
pub async fn connect(url: Url) -> Result<Self, io::Error> {
|
||||
let Some(host) = url.host() else {
|
||||
return Err(make_err("unknown or missing host"));
|
||||
};
|
||||
let Some(scheme) = url.scheme_str() else {
|
||||
return Err(make_err("unknown scheme"));
|
||||
let Some(port) = url.port_or_known_default() else {
|
||||
return Err(make_err("unknown or missing port"));
|
||||
};
|
||||
|
||||
let (use_tls, port) = match scheme {
|
||||
"http" => (false, 80),
|
||||
"https" => (true, 443),
|
||||
let use_tls = match url.scheme() {
|
||||
"http" => false,
|
||||
"https" => true,
|
||||
s => return Err(make_err(format!("unsupported scheme: {s}"))),
|
||||
};
|
||||
|
||||
let stream = {
|
||||
let port = url.port_u16().unwrap_or(port);
|
||||
TcpStream::connect((host, port)).await?
|
||||
};
|
||||
let host = host.to_string();
|
||||
let stream = TcpStream::connect((host.clone(), port)).await?;
|
||||
|
||||
let stream = if use_tls {
|
||||
let servname = ServerName::try_from(host).map_err(make_err)?.to_owned();
|
||||
|
@ -59,42 +50,42 @@ impl HttpRequestStream {
|
|||
}
|
||||
}
|
||||
|
||||
impl AsyncRead for HttpRequestStream {
|
||||
impl AsyncRead for HttpStream {
|
||||
fn poll_read(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &mut [u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
match &mut *self {
|
||||
HttpRequestStream::Plain(stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
HttpRequestStream::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
HttpStream::Plain(stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
HttpStream::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsyncWrite for HttpRequestStream {
|
||||
impl AsyncWrite for HttpStream {
|
||||
fn poll_write(
|
||||
mut self: Pin<&mut Self>,
|
||||
cx: &mut Context<'_>,
|
||||
buf: &[u8],
|
||||
) -> Poll<io::Result<usize>> {
|
||||
match &mut *self {
|
||||
HttpRequestStream::Plain(stream) => Pin::new(stream).poll_write(cx, buf),
|
||||
HttpRequestStream::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
|
||||
HttpStream::Plain(stream) => Pin::new(stream).poll_write(cx, buf),
|
||||
HttpStream::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
match &mut *self {
|
||||
HttpRequestStream::Plain(stream) => Pin::new(stream).poll_close(cx),
|
||||
HttpRequestStream::Tls(stream) => Pin::new(stream).poll_close(cx),
|
||||
HttpStream::Plain(stream) => Pin::new(stream).poll_close(cx),
|
||||
HttpStream::Tls(stream) => Pin::new(stream).poll_close(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
|
||||
match &mut *self {
|
||||
HttpRequestStream::Plain(stream) => Pin::new(stream).poll_flush(cx),
|
||||
HttpRequestStream::Tls(stream) => Pin::new(stream).poll_flush(cx),
|
||||
HttpStream::Plain(stream) => Pin::new(stream).poll_flush(cx),
|
||||
HttpStream::Tls(stream) => Pin::new(stream).poll_flush(cx),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -6,21 +6,33 @@ use hyper::{
|
|||
};
|
||||
|
||||
use mlua::prelude::*;
|
||||
use url::Url;
|
||||
|
||||
use crate::{
|
||||
client::stream::HttpRequestStream,
|
||||
client::{http_stream::HttpStream, ws_stream::WsStream},
|
||||
shared::{
|
||||
hyper::{HyperExecutor, HyperIo},
|
||||
request::Request,
|
||||
response::Response,
|
||||
websocket::Websocket,
|
||||
},
|
||||
};
|
||||
|
||||
pub mod config;
|
||||
pub mod stream;
|
||||
pub mod http_stream;
|
||||
pub mod rustls;
|
||||
pub mod ws_stream;
|
||||
|
||||
const MAX_REDIRECTS: usize = 10;
|
||||
|
||||
/**
|
||||
Connects to a websocket at the given URL.
|
||||
*/
|
||||
pub async fn connect_websocket(url: Url) -> LuaResult<Websocket<WsStream>> {
|
||||
let stream = WsStream::connect(url).await?;
|
||||
Ok(Websocket::from(stream))
|
||||
}
|
||||
|
||||
/**
|
||||
Sends the request and returns the final response.
|
||||
|
||||
|
@ -28,8 +40,14 @@ const MAX_REDIRECTS: usize = 10;
|
|||
modifying the request method and body as necessary.
|
||||
*/
|
||||
pub async fn send_request(mut request: Request, lua: Lua) -> LuaResult<Response> {
|
||||
let url = request
|
||||
.inner
|
||||
.uri()
|
||||
.to_string()
|
||||
.parse::<Url>()
|
||||
.expect("uri is valid");
|
||||
loop {
|
||||
let stream = HttpRequestStream::connect(request.inner.uri()).await?;
|
||||
let stream = HttpStream::connect(url.clone()).await?;
|
||||
|
||||
let (mut sender, conn) = handshake(HyperIo::from(stream)).await.into_lua_err()?;
|
||||
|
||||
|
|
12
crates/lune-std-net/src/client/rustls.rs
Normal file
12
crates/lune-std-net/src/client/rustls.rs
Normal file
|
@ -0,0 +1,12 @@
|
|||
use std::sync::{Arc, LazyLock};
|
||||
|
||||
use rustls::ClientConfig;
|
||||
|
||||
pub static CLIENT_CONFIG: LazyLock<Arc<ClientConfig>> = LazyLock::new(|| {
|
||||
rustls::ClientConfig::builder()
|
||||
.with_root_certificates(rustls::RootCertStore {
|
||||
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
|
||||
})
|
||||
.with_no_client_auth()
|
||||
.into()
|
||||
});
|
114
crates/lune-std-net/src/client/ws_stream.rs
Normal file
114
crates/lune-std-net/src/client/ws_stream.rs
Normal file
|
@ -0,0 +1,114 @@
|
|||
use std::{
|
||||
io,
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
task::{Context, Poll},
|
||||
};
|
||||
|
||||
use async_net::TcpStream;
|
||||
use async_tungstenite::{
|
||||
tungstenite::{Error as TungsteniteError, Message, Result as TungsteniteResult},
|
||||
WebSocketStream as TungsteniteStream,
|
||||
};
|
||||
use futures::Sink;
|
||||
use futures_lite::prelude::*;
|
||||
use futures_rustls::{TlsConnector, TlsStream};
|
||||
use rustls_pki_types::ServerName;
|
||||
use url::Url;
|
||||
|
||||
use crate::client::rustls::CLIENT_CONFIG;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub enum WsStream {
|
||||
Plain(TungsteniteStream<TcpStream>),
|
||||
Tls(TungsteniteStream<TlsStream<TcpStream>>),
|
||||
}
|
||||
|
||||
impl WsStream {
|
||||
pub async fn connect(url: Url) -> Result<Self, io::Error> {
|
||||
let Some(host) = url.host() else {
|
||||
return Err(make_err("unknown or missing host"));
|
||||
};
|
||||
let Some(port) = url.port_or_known_default() else {
|
||||
return Err(make_err("unknown or missing port"));
|
||||
};
|
||||
|
||||
let use_tls = match url.scheme() {
|
||||
"ws" => false,
|
||||
"wss" => true,
|
||||
s => return Err(make_err(format!("unsupported scheme: {s}"))),
|
||||
};
|
||||
|
||||
let host = host.to_string();
|
||||
let stream = TcpStream::connect((host.clone(), port)).await?;
|
||||
|
||||
let stream = if use_tls {
|
||||
let servname = ServerName::try_from(host).map_err(make_err)?.to_owned();
|
||||
let connector = TlsConnector::from(Arc::clone(&CLIENT_CONFIG));
|
||||
|
||||
let stream = connector.connect(servname, stream).await?;
|
||||
let stream = TlsStream::Client(stream);
|
||||
|
||||
let stream = async_tungstenite::client_async(url.to_string(), stream)
|
||||
.await
|
||||
.map_err(make_err)?
|
||||
.0;
|
||||
Self::Tls(stream)
|
||||
} else {
|
||||
let stream = async_tungstenite::client_async(url.to_string(), stream)
|
||||
.await
|
||||
.map_err(make_err)?
|
||||
.0;
|
||||
Self::Plain(stream)
|
||||
};
|
||||
|
||||
Ok(stream)
|
||||
}
|
||||
}
|
||||
|
||||
impl Sink<Message> for WsStream {
|
||||
type Error = TungsteniteError;
|
||||
|
||||
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
match &mut *self {
|
||||
WsStream::Plain(s) => Pin::new(s).poll_ready(cx),
|
||||
WsStream::Tls(s) => Pin::new(s).poll_ready(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
|
||||
match &mut *self {
|
||||
WsStream::Plain(s) => Pin::new(s).start_send(item),
|
||||
WsStream::Tls(s) => Pin::new(s).start_send(item),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
match &mut *self {
|
||||
WsStream::Plain(s) => Pin::new(s).poll_flush(cx),
|
||||
WsStream::Tls(s) => Pin::new(s).poll_flush(cx),
|
||||
}
|
||||
}
|
||||
|
||||
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
|
||||
match &mut *self {
|
||||
WsStream::Plain(s) => Pin::new(s).poll_close(cx),
|
||||
WsStream::Tls(s) => Pin::new(s).poll_close(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Stream for WsStream {
|
||||
type Item = TungsteniteResult<Message>;
|
||||
|
||||
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
|
||||
match &mut *self {
|
||||
WsStream::Plain(s) => Pin::new(s).poll_next(cx),
|
||||
WsStream::Tls(s) => Pin::new(s).poll_next(cx),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn make_err(e: impl ToString) -> io::Error {
|
||||
io::Error::new(io::ErrorKind::Other, e.to_string())
|
||||
}
|
|
@ -9,9 +9,9 @@ pub(crate) mod shared;
|
|||
pub(crate) mod url;
|
||||
|
||||
use self::{
|
||||
client::config::RequestConfig,
|
||||
client::{config::RequestConfig, ws_stream::WsStream},
|
||||
server::{config::ServeConfig, handle::ServeHandle},
|
||||
shared::{request::Request, response::Response},
|
||||
shared::{request::Request, response::Response, websocket::Websocket},
|
||||
};
|
||||
|
||||
const TYPEDEFS: &str = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/types.d.luau"));
|
||||
|
@ -34,7 +34,7 @@ pub fn typedefs() -> String {
|
|||
pub fn module(lua: Lua) -> LuaResult<LuaTable> {
|
||||
TableBuilder::new(lua)?
|
||||
.with_async_function("request", net_request)?
|
||||
// .with_async_function("socket", net_socket)?
|
||||
.with_async_function("socket", net_socket)?
|
||||
.with_async_function("serve", net_serve)?
|
||||
.with_function("urlEncode", net_url_encode)?
|
||||
.with_function("urlDecode", net_url_decode)?
|
||||
|
@ -45,6 +45,11 @@ async fn net_request(lua: Lua, config: RequestConfig) -> LuaResult<Response> {
|
|||
self::client::send_request(Request::try_from(config)?, lua).await
|
||||
}
|
||||
|
||||
async fn net_socket(_: Lua, url: String) -> LuaResult<Websocket<WsStream>> {
|
||||
let url = url.parse().into_lua_err()?;
|
||||
self::client::connect_websocket(url).await
|
||||
}
|
||||
|
||||
async fn net_serve(lua: Lua, (port, config): (u16, ServeConfig)) -> LuaResult<ServeHandle> {
|
||||
self::server::serve(lua, port, config).await
|
||||
}
|
||||
|
|
|
@ -4,3 +4,4 @@ pub mod hyper;
|
|||
pub mod incoming;
|
||||
pub mod request;
|
||||
pub mod response;
|
||||
pub mod websocket;
|
||||
|
|
143
crates/lune-std-net/src/shared/websocket.rs
Normal file
143
crates/lune-std-net/src/shared/websocket.rs
Normal file
|
@ -0,0 +1,143 @@
|
|||
use std::{
|
||||
error::Error,
|
||||
sync::{
|
||||
atomic::{AtomicBool, AtomicU16, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use async_lock::Mutex as AsyncMutex;
|
||||
use async_tungstenite::tungstenite::{
|
||||
protocol::{frame::coding::CloseCode, CloseFrame},
|
||||
Message as TungsteniteMessage, Result as TungsteniteResult, Utf8Bytes,
|
||||
};
|
||||
use bstr::{BString, ByteSlice};
|
||||
use futures::{
|
||||
stream::{SplitSink, SplitStream},
|
||||
Sink, SinkExt, Stream, StreamExt,
|
||||
};
|
||||
use hyper::body::Bytes;
|
||||
|
||||
use mlua::prelude::*;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Websocket<T> {
|
||||
close_code_exists: Arc<AtomicBool>,
|
||||
close_code_value: Arc<AtomicU16>,
|
||||
read_stream: Arc<AsyncMutex<SplitStream<T>>>,
|
||||
write_stream: Arc<AsyncMutex<SplitSink<T, TungsteniteMessage>>>,
|
||||
}
|
||||
|
||||
impl<T> Websocket<T>
|
||||
where
|
||||
T: Stream<Item = TungsteniteResult<TungsteniteMessage>> + Sink<TungsteniteMessage> + 'static,
|
||||
<T as Sink<TungsteniteMessage>>::Error: Into<Box<dyn Error + Send + Sync + 'static>>,
|
||||
{
|
||||
fn get_close_code(&self) -> Option<u16> {
|
||||
if self.close_code_exists.load(Ordering::Relaxed) {
|
||||
Some(self.close_code_value.load(Ordering::Relaxed))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn set_close_code(&self, code: u16) {
|
||||
self.close_code_exists.store(true, Ordering::Relaxed);
|
||||
self.close_code_value.store(code, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub async fn send(&self, msg: TungsteniteMessage) -> LuaResult<()> {
|
||||
let mut ws = self.write_stream.lock().await;
|
||||
ws.send(msg).await.into_lua_err()
|
||||
}
|
||||
|
||||
pub async fn next(&self) -> LuaResult<Option<TungsteniteMessage>> {
|
||||
let mut ws = self.read_stream.lock().await;
|
||||
ws.next().await.transpose().into_lua_err()
|
||||
}
|
||||
|
||||
pub async fn close(&self, code: Option<u16>) -> LuaResult<()> {
|
||||
if self.close_code_exists.load(Ordering::Relaxed) {
|
||||
return Err(LuaError::runtime("Socket has already been closed"));
|
||||
}
|
||||
|
||||
self.send(TungsteniteMessage::Close(Some(CloseFrame {
|
||||
code: match code {
|
||||
Some(code) if (1000..=4999).contains(&code) => CloseCode::from(code),
|
||||
Some(code) => {
|
||||
return Err(LuaError::runtime(format!(
|
||||
"Close code must be between 1000 and 4999, got {code}"
|
||||
)))
|
||||
}
|
||||
None => CloseCode::Normal,
|
||||
},
|
||||
reason: "".into(),
|
||||
})))
|
||||
.await?;
|
||||
|
||||
let mut ws = self.write_stream.lock().await;
|
||||
ws.close().await.into_lua_err()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<T> for Websocket<T>
|
||||
where
|
||||
T: Stream<Item = TungsteniteResult<TungsteniteMessage>> + Sink<TungsteniteMessage> + 'static,
|
||||
<T as Sink<TungsteniteMessage>>::Error: Into<Box<dyn Error + Send + Sync + 'static>>,
|
||||
{
|
||||
fn from(value: T) -> Self {
|
||||
let (write, read) = value.split();
|
||||
|
||||
Self {
|
||||
close_code_exists: Arc::new(AtomicBool::new(false)),
|
||||
close_code_value: Arc::new(AtomicU16::new(0)),
|
||||
read_stream: Arc::new(AsyncMutex::new(read)),
|
||||
write_stream: Arc::new(AsyncMutex::new(write)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> LuaUserData for Websocket<T>
|
||||
where
|
||||
T: Stream<Item = TungsteniteResult<TungsteniteMessage>> + Sink<TungsteniteMessage> + 'static,
|
||||
<T as Sink<TungsteniteMessage>>::Error: Into<Box<dyn Error + Send + Sync + 'static>>,
|
||||
{
|
||||
fn add_fields<F: LuaUserDataFields<Self>>(fields: &mut F) {
|
||||
fields.add_field_method_get("closeCode", |_, this| Ok(this.get_close_code()));
|
||||
}
|
||||
|
||||
fn add_methods<M: LuaUserDataMethods<Self>>(methods: &mut M) {
|
||||
methods.add_async_method("close", |_, this, code: Option<u16>| async move {
|
||||
this.close(code).await
|
||||
});
|
||||
|
||||
methods.add_async_method(
|
||||
"send",
|
||||
|_, this, (string, as_binary): (BString, Option<bool>)| async move {
|
||||
this.send(if as_binary.unwrap_or_default() {
|
||||
TungsteniteMessage::Binary(Bytes::from(string.to_vec()))
|
||||
} else {
|
||||
let s = string.to_str().into_lua_err()?;
|
||||
TungsteniteMessage::Text(Utf8Bytes::from(s))
|
||||
})
|
||||
.await
|
||||
},
|
||||
);
|
||||
|
||||
methods.add_async_method("next", |lua, this, (): ()| async move {
|
||||
let msg = this.next().await?;
|
||||
|
||||
if let Some(TungsteniteMessage::Close(Some(frame))) = msg.as_ref() {
|
||||
this.set_close_code(frame.code.into());
|
||||
}
|
||||
|
||||
Ok(match msg {
|
||||
Some(TungsteniteMessage::Binary(bin)) => LuaValue::String(lua.create_string(bin)?),
|
||||
Some(TungsteniteMessage::Text(txt)) => LuaValue::String(lua.create_string(txt)?),
|
||||
Some(TungsteniteMessage::Close(_)) | None => LuaValue::Nil,
|
||||
// Ignore ping/pong/frame messages, they are handled by tungstenite
|
||||
msg => unreachable!("Unhandled message: {:?}", msg),
|
||||
})
|
||||
});
|
||||
}
|
||||
}
|
Loading…
Add table
Reference in a new issue