use std::collections::HashMap; use mlua::prelude::*; use console::style; use hyper::Server; use tokio::{sync::mpsc, task}; use crate::lua::{ net::{ EncodeDecodeConfig, EncodeDecodeFormat, NetClient, NetClientBuilder, NetLocalExec, NetService, NetWebSocket, RequestConfig, ServeConfig, }, table::TableBuilder, task::{TaskScheduler, TaskSchedulerAsyncExt}, }; pub fn create(lua: &'static Lua) -> LuaResult { // Create a reusable client for performing our // web requests and store it in the lua registry, // allowing us to reuse headers and internal structs let client = NetClientBuilder::new() .headers(&[("User-Agent", create_user_agent_header())])? .build()?; lua.set_named_registry_value("net.client", client)?; // Create the global table for net TableBuilder::new(lua)? .with_function("encode", net_encode)? .with_function("decode", net_decode)? .with_function("jsonEncode", net_json_encode)? .with_function("jsonDecode", net_json_decode)? .with_async_function("request", net_request)? .with_async_function("socket", net_socket)? .with_async_function("serve", net_serve)? .build_readonly() } fn create_user_agent_header() -> String { let (github_owner, github_repo) = env!("CARGO_PKG_REPOSITORY") .strip_prefix("https://github.com/") .unwrap() .split_once('/') .unwrap(); format!("{github_owner}-{github_repo}-cli") } fn net_encode<'a>( lua: &'static Lua, (format, val, pretty): (EncodeDecodeFormat, LuaValue<'a>, Option), ) -> LuaResult> { let config = EncodeDecodeConfig::from((format, pretty.unwrap_or_default())); config.serialize_to_string(lua, val) } fn net_decode<'a>( lua: &'static Lua, (format, str): (EncodeDecodeFormat, LuaString<'a>), ) -> LuaResult> { let config = EncodeDecodeConfig::from(format); config.deserialize_from_string(lua, str) } fn net_json_encode<'a>( lua: &'static Lua, (val, pretty): (LuaValue<'a>, Option), ) -> LuaResult> { EncodeDecodeConfig::from((EncodeDecodeFormat::Json, pretty.unwrap_or_default())) .serialize_to_string(lua, val) } fn net_json_decode<'a>(lua: &'static Lua, json: LuaString<'a>) -> LuaResult> { EncodeDecodeConfig::from(EncodeDecodeFormat::Json).deserialize_from_string(lua, json) } async fn net_request<'a>(lua: &'static Lua, config: RequestConfig<'a>) -> LuaResult> { // Create and send the request let client: NetClient = lua.named_registry_value("net.client")?; let mut request = client.request(config.method, &config.url); for (query, value) in config.query { request = request.query(&[(query.to_str()?, value.to_str()?)]); } for (header, value) in config.headers { request = request.header(header.to_str()?, value.to_str()?); } let res = request .body(config.body.unwrap_or_default()) .send() .await .map_err(LuaError::external)?; // Extract status, headers let res_status = res.status().as_u16(); let res_status_text = res.status().canonical_reason(); let res_headers = res .headers() .iter() .map(|(name, value)| (name.to_string(), value.to_str().unwrap().to_owned())) .collect::>(); // Read response bytes let res_bytes = res.bytes().await.map_err(LuaError::external)?; // Construct and return a readonly lua table with results TableBuilder::new(lua)? .with_value("ok", (200..300).contains(&res_status))? .with_value("statusCode", res_status)? .with_value("statusMessage", res_status_text)? .with_value("headers", res_headers)? .with_value("body", lua.create_string(&res_bytes)?)? .build_readonly() } async fn net_socket<'a>(lua: &'static Lua, url: String) -> LuaResult { let (ws, _) = tokio_tungstenite::connect_async(url) .await .map_err(LuaError::external)?; NetWebSocket::new(ws).into_lua_table(lua) } async fn net_serve<'a>( lua: &'static Lua, (port, config): (u16, ServeConfig<'a>), ) -> LuaResult> { // Note that we need to use a mpsc here and not // a oneshot channel since we move the sender // into our table with the stop function let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1); let server_request_callback = lua.create_registry_value(config.handle_request)?; let server_websocket_callback = config.handle_web_socket.map(|handler| { lua.create_registry_value(handler) .expect("Failed to store websocket handler") }); let sched = lua .app_data_ref::<&TaskScheduler>() .expect("Missing task scheduler - make sure it is added as a lua app data before the first scheduler resumption"); // Bind first to make sure that we can bind to this address let bound = match Server::try_bind(&([127, 0, 0, 1], port).into()) { Err(e) => { return Err(LuaError::external(format!( "Failed to bind to localhost on port {port}\n{}", format!("{e}").replace( "error creating server listener: ", &format!("{}", style("> ").dim()) ) ))); } Ok(bound) => bound, }; // Register a background task to prevent the task scheduler from // exiting early and start up our web server on the bound address let task = sched.register_background_task(); let server = bound .http1_only(true) // Web sockets can only use http1 .http1_keepalive(true) // Web sockets must be kept alive .executor(NetLocalExec) .serve(NetService::new( lua, server_request_callback, server_websocket_callback, )) .with_graceful_shutdown(async move { task.unregister(Ok(())); shutdown_rx .recv() .await .expect("Server was stopped instantly"); shutdown_rx.close(); }); // Spawn a new tokio task so we don't block task::spawn_local(server); // Create a new read-only table that contains methods // for manipulating server behavior and shutting it down let handle_stop = move |_, _: ()| match shutdown_tx.try_send(()) { Ok(_) => Ok(()), Err(_) => Err(LuaError::RuntimeError( "Server has already been stopped".to_string(), )), }; TableBuilder::new(lua)? .with_function("stop", handle_stop)? .build_readonly() }