Implement web sockets for net serve

This commit is contained in:
Filip Tibell 2023-02-11 22:40:14 +01:00
parent 962d89fd40
commit 6c97003571
No known key found for this signature in database
13 changed files with 320 additions and 99 deletions

View file

@ -9,6 +9,31 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added ### Added
- `net.serve` now supports web sockets in addition to normal http requests!
Example usage:
```lua
net.serve(8080, {
handleRequest = function(request)
return "Hello, world!"
end,
handleWebSocket = function(socket)
task.delay(10, function()
socket.send("Timed out!")
socket.close()
end)
-- This will yield waiting for new messages, and will break
-- when the socket was closed by either the server or client
for message in socket do
if message == "Ping" then
socket.send("Pong")
end
end
end,
})
```
- `net.serve` now returns a `NetServeHandle` which can be used to stop serving requests safely. - `net.serve` now returns a `NetServeHandle` which can be used to stop serving requests safely.
Example usage: Example usage:
@ -19,6 +44,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
end) end)
print("Shutting down after 1 second...") print("Shutting down after 1 second...")
task.wait(1)
handle.stop() handle.stop()
print("Shut down succesfully") print("Shut down succesfully")
``` ```

View file

@ -48,7 +48,7 @@ globals:
net.serve: net.serve:
args: args:
- type: number - type: number
- type: function - type: function | table
# Processs # Processs
process.args: process.args:
property: read-only property: read-only

View file

@ -153,10 +153,24 @@ export type NetResponse = {
body: string?, body: string?,
} }
type NetServeHttpHandler = (request: NetRequest) -> (string | NetResponse)
type NetServeWebSocketHandler = (socket: NetWebSocket) -> ()
export type NetServeConfig = {
handleRequest: NetServeHttpHandler?,
handleWebSocket: NetServeWebSocketHandler?,
}
export type NetServeHandle = { export type NetServeHandle = {
stop: () -> (), stop: () -> (),
} }
declare class NetWebSocket
close: () -> ()
send: (message: string) -> ()
function __iter(self): () -> string
end
--[=[ --[=[
@class net @class net
@ -183,9 +197,9 @@ declare net: {
until the `stop` function on the returned `NetServeHandle` has been called. until the `stop` function on the returned `NetServeHandle` has been called.
@param port The port to use for the server @param port The port to use for the server
@param handler The handler function to use for the server @param handlerOrConfig The handler function or config to use for the server
]=] ]=]
serve: (port: number, handler: (request: NetRequest) -> (string | NetResponse)) -> NetServeHandle, serve: (port: number, handlerOrConfig: NetServeHttpHandler | NetServeConfig) -> NetServeHandle,
--[=[ --[=[
@within net @within net

View file

@ -8,40 +8,27 @@ use std::{
use mlua::prelude::*; use mlua::prelude::*;
use hyper::{body::to_bytes, http::HeaderValue, server::conn::AddrStream, service::Service}; use hyper::{body::to_bytes, server::conn::AddrStream, service::Service};
use hyper::{Body, HeaderMap, Request, Response, Server}; use hyper::{Body, Request, Response, Server};
use hyper_tungstenite::{ use hyper_tungstenite::{is_upgrade_request as is_ws_upgrade_request, upgrade as ws_upgrade};
is_upgrade_request as is_ws_upgrade_request, tungstenite::Message as WsMessage,
upgrade as ws_upgrade,
};
use futures_util::{SinkExt, StreamExt}; use reqwest::Method;
use reqwest::{ClientBuilder, Method};
use tokio::{ use tokio::{
sync::mpsc::{self, Sender}, sync::mpsc::{self, Sender},
task, task,
}; };
use crate::utils::{ use crate::{
message::LuneMessage, lua::net::{NetClient, NetClientBuilder, NetWebSocketServer, ServeConfig},
net::{get_request_user_agent_header, NetClient}, utils::{message::LuneMessage, net::get_request_user_agent_header, table::TableBuilder},
table::TableBuilder,
}; };
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> { pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
// Create a reusable client for performing our // Create a reusable client for performing our
// web requests and store it in the lua registry // web requests and store it in the lua registry
let mut default_headers = HeaderMap::new(); let client = NetClientBuilder::new()
default_headers.insert( .headers(&[("User-Agent", get_request_user_agent_header())])?
"User-Agent", .build()?;
HeaderValue::from_str(&get_request_user_agent_header()).map_err(LuaError::external)?,
);
let client = NetClient::new(
ClientBuilder::new()
.default_headers(default_headers)
.build()
.map_err(LuaError::external)?,
);
lua.set_named_registry_value("NetClient", client)?; lua.set_named_registry_value("NetClient", client)?;
// Create the global table for net // Create the global table for net
TableBuilder::new(lua)? TableBuilder::new(lua)?
@ -158,17 +145,24 @@ async fn net_request<'a>(lua: &'static Lua, config: LuaValue<'a>) -> LuaResult<L
async fn net_serve<'a>( async fn net_serve<'a>(
lua: &'static Lua, lua: &'static Lua,
(port, callback): (u16, LuaFunction<'a>), // TODO: Parse options as either callback or table with request callback + websocket callback (port, config): (u16, ServeConfig<'a>),
) -> LuaResult<LuaTable<'a>> { ) -> LuaResult<LuaTable<'a>> {
// Note that we need to use a mpsc here and not // Note that we need to use a mpsc here and not
// a oneshot channel since we move the sender // a oneshot channel since we move the sender
// into our table with the stop function // into our table with the stop function
let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1); let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
let websocket_callback = Arc::new(None); // TODO: Store websocket callback, if given let server_request_callback = Arc::new(lua.create_registry_value(config.handle_request)?);
let server_callback = Arc::new(lua.create_registry_value(callback)?); let server_websocket_callback = Arc::new(config.handle_web_socket.map(|handler| {
lua.create_registry_value(handler)
.expect("Failed to store websocket handler")
}));
let server = Server::bind(&([127, 0, 0, 1], port).into()) let server = Server::bind(&([127, 0, 0, 1], port).into())
.executor(LocalExec) .executor(LocalExec)
.serve(MakeNetService(lua, server_callback, websocket_callback)) .serve(MakeNetService(
lua,
server_request_callback,
server_websocket_callback,
))
.with_graceful_shutdown(async move { .with_graceful_shutdown(async move {
shutdown_rx.recv().await.unwrap(); shutdown_rx.recv().await.unwrap();
shutdown_rx.close(); shutdown_rx.close();
@ -225,40 +219,25 @@ impl Service<Request<Body>> for NetService {
fn call(&mut self, mut req: Request<Body>) -> Self::Future { fn call(&mut self, mut req: Request<Body>) -> Self::Future {
let lua = self.0; let lua = self.0;
if self.2.is_some() && is_ws_upgrade_request(&req) { if self.2.is_some() && is_ws_upgrade_request(&req) {
// Websocket request + websocket handler exists, // Websocket upgrade request + websocket handler exists,
// we should upgrade this connection to a websocket // we should now upgrade this connection to a websocket
// and then pass a socket object to our lua handler // and then call our handler with a new socket object
let kopt = self.2.clone(); let kopt = self.2.clone();
let key = kopt.as_ref().as_ref().unwrap(); let key = kopt.as_ref().as_ref().unwrap();
let handler: LuaFunction = lua.registry_value(key).expect("Missing websocket handler"); let handler: LuaFunction = lua.registry_value(key).expect("Missing websocket handler");
let (response, ws) = ws_upgrade(&mut req, None).expect("Failed to upgrade websocket"); let (response, ws) = ws_upgrade(&mut req, None).expect("Failed to upgrade websocket");
task::spawn_local(async move { task::spawn_local(async move {
if let Ok(mut websocket) = ws.await { // Create our new full websocket object
// TODO: Create lua userdata websocket object let ws = ws.await.map_err(LuaError::external)?;
// with methods for interacting with the websocket let ws_lua = NetWebSocketServer::from(ws);
// TODO: Start waiting for messages when we know let ws_proper = ws_lua.into_proper(lua).await?;
// for sure that we have gotten a message handler // Call our handler with it
// and move the following logic into there instead handler.call_async::<_, ()>(ws_proper).await
while let Some(message) = websocket.next().await {
// Create lua strings from websocket messages
if let Some(handler_str) = match message.map_err(LuaError::external)? {
WsMessage::Text(msg) => Some(lua.create_string(&msg)?),
WsMessage::Binary(msg) => Some(lua.create_string(&msg)?),
// Tungstenite takes care of these messages
WsMessage::Ping(_) => None,
WsMessage::Pong(_) => None,
WsMessage::Close(_) => None,
WsMessage::Frame(_) => None,
} {
// TODO: Call whatever lua handler we have registered, with our message string
}
}
}
Ok::<_, LuaError>(())
}); });
Box::pin(async move { Ok(response) }) Box::pin(async move { Ok(response) })
} else { } else {
// Normal http request or no websocket handler exists, call the http request handler // Got a normal http request or no websocket handler
// exists, just call the http request handler
let key = self.1.clone(); let key = self.1.clone();
let (parts, body) = req.into_parts(); let (parts, body) = req.into_parts();
Box::pin(async move { Box::pin(async move {

View file

@ -4,6 +4,7 @@ use mlua::prelude::*;
use tokio::{sync::mpsc, task}; use tokio::{sync::mpsc, task};
pub(crate) mod globals; pub(crate) mod globals;
pub(crate) mod lua;
pub(crate) mod utils; pub(crate) mod utils;
#[cfg(test)] #[cfg(test)]

View file

@ -0,0 +1 @@
pub mod net;

View file

@ -0,0 +1,49 @@
use std::str::FromStr;
use mlua::prelude::*;
use hyper::{header::HeaderName, http::HeaderValue, HeaderMap};
use reqwest::{IntoUrl, Method, RequestBuilder};
pub struct NetClientBuilder {
builder: reqwest::ClientBuilder,
}
impl NetClientBuilder {
pub fn new() -> NetClientBuilder {
Self {
builder: reqwest::ClientBuilder::new(),
}
}
pub fn headers<K, V>(mut self, headers: &[(K, V)]) -> LuaResult<Self>
where
K: AsRef<str>,
V: AsRef<[u8]>,
{
let mut map = HeaderMap::new();
for (key, val) in headers {
let hkey = HeaderName::from_str(key.as_ref()).map_err(LuaError::external)?;
let hval = HeaderValue::from_bytes(val.as_ref()).map_err(LuaError::external)?;
map.insert(hkey, hval);
}
self.builder = self.builder.default_headers(map);
Ok(self)
}
pub fn build(self) -> LuaResult<NetClient> {
let client = self.builder.build().map_err(LuaError::external)?;
Ok(NetClient(client))
}
}
#[derive(Debug, Clone)]
pub struct NetClient(reqwest::Client);
impl NetClient {
pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder {
self.0.request(method, url)
}
}
impl LuaUserData for NetClient {}

View file

@ -0,0 +1,50 @@
use mlua::prelude::*;
pub struct ServeConfig<'a> {
pub handle_request: LuaFunction<'a>,
pub handle_web_socket: Option<LuaFunction<'a>>,
}
impl<'lua> FromLua<'lua> for ServeConfig<'lua> {
fn from_lua(value: LuaValue<'lua>, lua: &'lua Lua) -> LuaResult<Self> {
let message = match &value {
LuaValue::Function(f) => {
return Ok(ServeConfig {
handle_request: f.clone(),
handle_web_socket: None,
})
}
LuaValue::Table(t) => {
let handle_request: Option<LuaFunction> = t.raw_get("handleRequest")?;
let handle_web_socket: Option<LuaFunction> = t.raw_get("handleWebSocket")?;
if handle_request.is_some() || handle_web_socket.is_some() {
return Ok(ServeConfig {
handle_request: handle_request.unwrap_or_else(|| {
let chunk = r#"
return {
status = 426,
body = "Upgrade Required",
headers = {
Upgrade = "websocket",
},
}
"#;
lua.load(chunk)
.into_function()
.expect("Failed to create default http responder function")
}),
handle_web_socket,
});
} else {
Some("Missing handleRequest and / or handleWebSocket".to_string())
}
}
_ => None,
};
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "ServeConfig",
message,
})
}
}

View file

@ -0,0 +1,7 @@
mod client;
mod config;
mod ws_server;
pub use client::{NetClient, NetClientBuilder};
pub use config::ServeConfig;
pub use ws_server::NetWebSocketServer;

View file

@ -0,0 +1,87 @@
use std::sync::Arc;
use mlua::prelude::*;
use hyper::upgrade::Upgraded;
use hyper_tungstenite::{tungstenite::Message as WsMessage, WebSocketStream};
use futures_util::{SinkExt, StreamExt};
use tokio::sync::Mutex;
#[derive(Debug, Clone)]
pub struct NetWebSocketServer(Arc<Mutex<WebSocketStream<Upgraded>>>);
impl NetWebSocketServer {
pub async fn close(&self) -> LuaResult<()> {
let mut ws = self.0.lock().await;
ws.close(None).await.map_err(LuaError::external)?;
Ok(())
}
pub async fn send(&self, msg: WsMessage) -> LuaResult<()> {
let mut ws = self.0.lock().await;
ws.send(msg).await.map_err(LuaError::external)?;
Ok(())
}
pub async fn next(&self) -> LuaResult<Option<WsMessage>> {
let mut ws = self.0.lock().await;
let item = ws.next().await.transpose();
item.map_err(LuaError::external)
}
pub async fn into_proper(self, lua: &'static Lua) -> LuaResult<LuaTable> {
// HACK: This creates a new userdata that consumes and proxies this one,
// since there's no great way to implement this in pure async Rust
// and as a plain table without tons of strange lifetime issues
let chunk = r#"
return function(ws)
local proxy = newproxy(true)
local meta = getmetatable(proxy)
meta.__index = {
close = function()
return ws:close()
end,
send = function(...)
return ws:send(...)
end,
next = function()
return ws:next()
end,
}
meta.__iter = function()
return function()
return ws:next()
end
end
return proxy
end
"#;
lua.load(chunk).call_async(self).await
}
}
impl LuaUserData for NetWebSocketServer {
fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_async_method("close", |_, this, _: ()| async move { this.close().await });
methods.add_async_method("send", |_, this, msg: String| async move {
this.send(WsMessage::Text(msg)).await
});
methods.add_async_method("next", |lua, this, _: ()| async move {
match this.next().await? {
Some(msg) => Ok(match msg {
WsMessage::Binary(bin) => LuaValue::String(lua.create_string(&bin)?),
WsMessage::Text(txt) => LuaValue::String(lua.create_string(&txt)?),
_ => LuaValue::Nil,
}),
None => Ok(LuaValue::Nil),
}
})
}
}
impl From<WebSocketStream<Upgraded>> for NetWebSocketServer {
fn from(value: WebSocketStream<Upgraded>) -> Self {
Self(Arc::new(Mutex::new(value)))
}
}

View file

@ -1,21 +1,3 @@
use mlua::prelude::*;
use reqwest::{IntoUrl, Method, RequestBuilder};
#[derive(Clone)]
pub struct NetClient(reqwest::Client);
impl NetClient {
pub fn new(client: reqwest::Client) -> Self {
Self(client)
}
pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder {
self.0.request(method, url)
}
}
impl LuaUserData for NetClient {}
pub fn get_github_owner_and_repo() -> (String, String) { pub fn get_github_owner_and_repo() -> (String, String) {
let (github_owner, github_repo) = env!("CARGO_PKG_REPOSITORY") let (github_owner, github_repo) = env!("CARGO_PKG_REPOSITORY")
.strip_prefix("https://github.com/") .strip_prefix("https://github.com/")

View file

@ -2,22 +2,22 @@ use std::future::Future;
use mlua::prelude::*; use mlua::prelude::*;
pub struct TableBuilder<'lua> { pub struct TableBuilder {
lua: &'lua Lua, lua: &'static Lua,
tab: LuaTable<'lua>, tab: LuaTable<'static>,
} }
#[allow(dead_code)] #[allow(dead_code)]
impl<'lua> TableBuilder<'lua> { impl TableBuilder {
pub fn new(lua: &'lua Lua) -> LuaResult<Self> { pub fn new(lua: &'static Lua) -> LuaResult<Self> {
let tab = lua.create_table()?; let tab = lua.create_table()?;
Ok(Self { lua, tab }) Ok(Self { lua, tab })
} }
pub fn with_value<K, V>(self, key: K, value: V) -> LuaResult<Self> pub fn with_value<K, V>(self, key: K, value: V) -> LuaResult<Self>
where where
K: ToLua<'lua>, K: ToLua<'static>,
V: ToLua<'lua>, V: ToLua<'static>,
{ {
self.tab.raw_set(key, value)?; self.tab.raw_set(key, value)?;
Ok(self) Ok(self)
@ -25,8 +25,8 @@ impl<'lua> TableBuilder<'lua> {
pub fn with_values<K, V>(self, values: Vec<(K, V)>) -> LuaResult<Self> pub fn with_values<K, V>(self, values: Vec<(K, V)>) -> LuaResult<Self>
where where
K: ToLua<'lua>, K: ToLua<'static>,
V: ToLua<'lua>, V: ToLua<'static>,
{ {
for (key, value) in values { for (key, value) in values {
self.tab.raw_set(key, value)?; self.tab.raw_set(key, value)?;
@ -36,7 +36,7 @@ impl<'lua> TableBuilder<'lua> {
pub fn with_sequential_value<V>(self, value: V) -> LuaResult<Self> pub fn with_sequential_value<V>(self, value: V) -> LuaResult<Self>
where where
V: ToLua<'lua>, V: ToLua<'static>,
{ {
self.tab.raw_push(value)?; self.tab.raw_push(value)?;
Ok(self) Ok(self)
@ -44,7 +44,7 @@ impl<'lua> TableBuilder<'lua> {
pub fn with_sequential_values<V>(self, values: Vec<V>) -> LuaResult<Self> pub fn with_sequential_values<V>(self, values: Vec<V>) -> LuaResult<Self>
where where
V: ToLua<'lua>, V: ToLua<'static>,
{ {
for value in values { for value in values {
self.tab.raw_push(value)?; self.tab.raw_push(value)?;
@ -59,10 +59,10 @@ impl<'lua> TableBuilder<'lua> {
pub fn with_function<K, A, R, F>(self, key: K, func: F) -> LuaResult<Self> pub fn with_function<K, A, R, F>(self, key: K, func: F) -> LuaResult<Self>
where where
K: ToLua<'lua>, K: ToLua<'static>,
A: FromLuaMulti<'lua>, A: FromLuaMulti<'static>,
R: ToLuaMulti<'lua>, R: ToLuaMulti<'static>,
F: 'static + Fn(&'lua Lua, A) -> LuaResult<R>, F: 'static + Fn(&'static Lua, A) -> LuaResult<R>,
{ {
let f = self.lua.create_function(func)?; let f = self.lua.create_function(func)?;
self.with_value(key, LuaValue::Function(f)) self.with_value(key, LuaValue::Function(f))
@ -70,22 +70,22 @@ impl<'lua> TableBuilder<'lua> {
pub fn with_async_function<K, A, R, F, FR>(self, key: K, func: F) -> LuaResult<Self> pub fn with_async_function<K, A, R, F, FR>(self, key: K, func: F) -> LuaResult<Self>
where where
K: ToLua<'lua>, K: ToLua<'static>,
A: FromLuaMulti<'lua>, A: FromLuaMulti<'static>,
R: ToLuaMulti<'lua>, R: ToLuaMulti<'static>,
F: 'static + Fn(&'lua Lua, A) -> FR, F: 'static + Fn(&'static Lua, A) -> FR,
FR: 'lua + Future<Output = LuaResult<R>>, FR: 'static + Future<Output = LuaResult<R>>,
{ {
let f = self.lua.create_async_function(func)?; let f = self.lua.create_async_function(func)?;
self.with_value(key, LuaValue::Function(f)) self.with_value(key, LuaValue::Function(f))
} }
pub fn build_readonly(self) -> LuaResult<LuaTable<'lua>> { pub fn build_readonly(self) -> LuaResult<LuaTable<'static>> {
self.tab.set_readonly(true); self.tab.set_readonly(true);
Ok(self.tab) Ok(self.tab)
} }
pub fn build(self) -> LuaResult<LuaTable<'lua>> { pub fn build(self) -> LuaResult<LuaTable<'static>> {
Ok(self.tab) Ok(self.tab)
} }
} }

View file

@ -1,6 +1,7 @@
local PORT = 8080
local RESPONSE = "Hello, lune!" local RESPONSE = "Hello, lune!"
local handle = net.serve(8080, function(request) local handle = net.serve(PORT, function(request)
-- info("Request:", request) -- info("Request:", request)
-- info("Responding with", RESPONSE) -- info("Responding with", RESPONSE)
assert(request.path == "/some/path") assert(request.path == "/some/path")
@ -10,7 +11,7 @@ local handle = net.serve(8080, function(request)
end) end)
local response = local response =
net.request("http://127.0.0.1:8080/some/path?key=param1&key=param2&key2=param3").body net.request(`http://127.0.0.1:{PORT}/some/path?key=param1&key=param2&key2=param3`).body
assert(response == RESPONSE, "Invalid response from server") assert(response == RESPONSE, "Invalid response from server")
handle.stop() handle.stop()
@ -21,7 +22,7 @@ task.wait()
-- Sending a net request may error if there was -- Sending a net request may error if there was
-- a connection issue, we should handle that here -- a connection issue, we should handle that here
local success, response2 = pcall(net.request, "http://127.0.0.1:8080/") local success, response2 = pcall(net.request, `http://127.0.0.1:{PORT}/`)
if not success then if not success then
local message = tostring(response2) local message = tostring(response2)
assert( assert(
@ -51,3 +52,27 @@ assert(
or string.find(message, "shut down"), or string.find(message, "shut down"),
"The error message for calling stop twice on the net serve handle should be descriptive" "The error message for calling stop twice on the net serve handle should be descriptive"
) )
--[[
Serve should also take a full config with handler functions
A server should also be able to start on the previously closed port
]]
local handle2 = net.serve(PORT, {
handleRequest = function()
return RESPONSE
end,
handleWebSocket = function(socket)
socket.close()
end,
})
local response3 = net.request(`http://127.0.0.1:{PORT}/`).body
assert(response3 == RESPONSE, "Invalid response from server")
-- TODO: Test web sockets properly when we have a web socket client
-- Stop the server and yield once more to end the test
handle2.stop()
task.wait()