From 62910f02ab4c2ca64a34a738116d88108785c4ec Mon Sep 17 00:00:00 2001 From: Filip Tibell Date: Tue, 29 Apr 2025 15:06:16 +0200 Subject: [PATCH] Rewrite the `net` standard library with smol ecosystem of crates (#310) --- .github/workflows/ci.yaml | 20 +- .gitignore | 5 + .justfile | 1 + .lune/http_server.luau | 23 +- Cargo.lock | 397 +++++++++++------- crates/lune-std-net/Cargo.toml | 34 +- crates/lune-std-net/src/client.rs | 163 ------- crates/lune-std-net/src/client/http_stream.rs | 95 +++++ crates/lune-std-net/src/client/mod.rs | 125 ++++++ crates/lune-std-net/src/client/rustls.rs | 12 + crates/lune-std-net/src/client/ws_stream.rs | 114 +++++ crates/lune-std-net/src/config.rs | 231 ---------- crates/lune-std-net/src/lib.rs | 66 ++- crates/lune-std-net/src/server/config.rs | 87 ++++ crates/lune-std-net/src/server/handle.rs | 72 ++++ crates/lune-std-net/src/server/keys.rs | 58 --- crates/lune-std-net/src/server/mod.rs | 171 ++++---- crates/lune-std-net/src/server/request.rs | 56 --- crates/lune-std-net/src/server/response.rs | 89 ---- crates/lune-std-net/src/server/service.rs | 151 ++++--- crates/lune-std-net/src/server/upgrade.rs | 55 +++ crates/lune-std-net/src/shared/body.rs | 43 ++ crates/lune-std-net/src/shared/futures.rs | 19 + crates/lune-std-net/src/shared/headers.rs | 88 ++++ crates/lune-std-net/src/shared/hyper.rs | 198 +++++++++ crates/lune-std-net/src/shared/lua.rs | 57 +++ crates/lune-std-net/src/shared/mod.rs | 8 + crates/lune-std-net/src/shared/request.rs | 272 ++++++++++++ crates/lune-std-net/src/shared/response.rs | 172 ++++++++ .../src/{ => shared}/websocket.rs | 116 +++-- crates/lune-std-net/src/url/decode.rs | 12 + crates/lune-std-net/src/url/encode.rs | 13 + crates/lune-std-net/src/url/mod.rs | 5 + crates/lune-std-net/src/util.rs | 94 ----- crates/lune/src/tests.rs | 10 +- crates/mlua-luau-scheduler/Cargo.toml | 15 +- crates/mlua-luau-scheduler/src/functions.rs | 14 +- crates/mlua-luau-scheduler/src/queue.rs | 119 ++++-- crates/mlua-luau-scheduler/src/result_map.rs | 59 ++- crates/mlua-luau-scheduler/src/scheduler.rs | 21 +- crates/mlua-luau-scheduler/src/traits.rs | 8 +- crates/mlua-luau-scheduler/src/util.rs | 68 --- rokit.toml | 7 +- scripts/analyze_copy_typedefs.luau | 14 + tests/net/serve/addresses.luau | 35 ++ tests/net/serve/handles.luau | 51 +++ tests/net/serve/non_blocking.luau | 24 ++ tests/net/serve/requests.luau | 116 +---- tests/net/serve/websockets.luau | 2 +- tests/process/exec/stdio.luau | 2 +- .../tests/modules/self_alias/init.luau | 4 +- tests/roblox/instance/custom/async.luau | 7 +- tests/roblox/instance/methods/Clone.luau | 2 +- tests/serde/json/decode.luau | 60 +-- tests/serde/json/encode.luau | 10 - 55 files changed, 2331 insertions(+), 1439 deletions(-) delete mode 100644 crates/lune-std-net/src/client.rs create mode 100644 crates/lune-std-net/src/client/http_stream.rs create mode 100644 crates/lune-std-net/src/client/mod.rs create mode 100644 crates/lune-std-net/src/client/rustls.rs create mode 100644 crates/lune-std-net/src/client/ws_stream.rs delete mode 100644 crates/lune-std-net/src/config.rs create mode 100644 crates/lune-std-net/src/server/config.rs create mode 100644 crates/lune-std-net/src/server/handle.rs delete mode 100644 crates/lune-std-net/src/server/keys.rs delete mode 100644 crates/lune-std-net/src/server/request.rs delete mode 100644 crates/lune-std-net/src/server/response.rs create mode 100644 crates/lune-std-net/src/server/upgrade.rs create mode 100644 crates/lune-std-net/src/shared/body.rs create mode 100644 crates/lune-std-net/src/shared/futures.rs create mode 100644 crates/lune-std-net/src/shared/headers.rs create mode 100644 crates/lune-std-net/src/shared/hyper.rs create mode 100644 crates/lune-std-net/src/shared/lua.rs create mode 100644 crates/lune-std-net/src/shared/mod.rs create mode 100644 crates/lune-std-net/src/shared/request.rs create mode 100644 crates/lune-std-net/src/shared/response.rs rename crates/lune-std-net/src/{ => shared}/websocket.rs (58%) create mode 100644 crates/lune-std-net/src/url/decode.rs create mode 100644 crates/lune-std-net/src/url/encode.rs create mode 100644 crates/lune-std-net/src/url/mod.rs delete mode 100644 crates/lune-std-net/src/util.rs create mode 100644 scripts/analyze_copy_typedefs.luau create mode 100644 tests/net/serve/addresses.luau create mode 100644 tests/net/serve/handles.luau create mode 100644 tests/net/serve/non_blocking.luau diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 3891fdf..3160008 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -2,15 +2,16 @@ name: CI on: push: - pull_request: workflow_dispatch: defaults: run: shell: bash -jobs: +env: + CARGO_TERM_COLOR: always +jobs: fmt: name: Check formatting runs-on: ubuntu-latest @@ -79,23 +80,26 @@ jobs: components: clippy targets: ${{ matrix.cargo-target }} + - name: Install binstall + uses: cargo-bins/cargo-binstall@main + + - name: Install nextest + run: cargo binstall cargo-nextest + - name: Build run: | - cargo build \ - --workspace \ + cargo build --workspace \ --locked --all-features \ --target ${{ matrix.cargo-target }} - name: Lint run: | - cargo clippy \ - --workspace \ + cargo clippy --workspace \ --locked --all-features \ --target ${{ matrix.cargo-target }} - name: Test run: | - cargo test \ - --lib --workspace \ + cargo nextest run --no-fail-fast \ --locked --all-features \ --target ${{ matrix.cargo-target }} diff --git a/.gitignore b/.gitignore index 6f7b83e..f88ebb4 100644 --- a/.gitignore +++ b/.gitignore @@ -21,7 +21,12 @@ lune.yml luneDocs.json luneTypes.d.luau +# Dirs generated by runtime or build scripts + +/types + # Files generated by runtime or build scripts + scripts/brick_color.rs scripts/font_enum_map.rs scripts/physical_properties_enum_map.rs diff --git a/.justfile b/.justfile index a3752ca..5e06c52 100644 --- a/.justfile +++ b/.justfile @@ -65,6 +65,7 @@ fmt-check: analyze: #!/usr/bin/env bash set -euo pipefail + lune run scripts/analyze_copy_typedefs luau-lsp analyze \ --settings=".vscode/settings.json" \ --ignore="tests/roblox/rbx-test-files/**" \ diff --git a/.lune/http_server.luau b/.lune/http_server.luau index 0d0b2d3..b3ad573 100644 --- a/.lune/http_server.luau +++ b/.lune/http_server.luau @@ -3,7 +3,6 @@ local net = require("@lune/net") local process = require("@lune/process") -local task = require("@lune/task") local PORT = if process.env.PORT ~= nil and #process.env.PORT > 0 then assert(tonumber(process.env.PORT), "Failed to parse port from env") @@ -11,6 +10,10 @@ local PORT = if process.env.PORT ~= nil and #process.env.PORT > 0 -- Create our responder functions +local function root(_request: net.ServeRequest): string + return `Hello from Lune server!` +end + local function pong(request: net.ServeRequest): string return `Pong!\n{request.path}\n{request.body}` end @@ -29,10 +32,12 @@ local function notFound(_request: net.ServeRequest): net.ServeResponse } end --- Run the server on port 8080 +-- Run the server on the port forever -local handle = net.serve(PORT, function(request) - if string.sub(request.path, 1, 5) == "/ping" then +net.serve(PORT, function(request) + if request.path == "/" then + return root(request) + elseif string.sub(request.path, 1, 5) == "/ping" then return pong(request) elseif string.sub(request.path, 1, 7) == "/teapot" then return teapot(request) @@ -42,12 +47,4 @@ local handle = net.serve(PORT, function(request) end) print(`Listening on port {PORT} 🚀`) - --- Exit our example after a small delay, if you copy this --- example just remove this part to keep the server running - -task.delay(2, function() - print("Shutting down...") - task.wait(1) - handle.stop() -end) +print("Press Ctrl+C to stop") diff --git a/Cargo.lock b/Cargo.lock index 15c4118..f7068be 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -238,11 +238,22 @@ version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff6e472cdea888a4bd64f342f09b3f50e1886d32afe8df3d663c01140b811b18" dependencies = [ - "event-listener 5.4.0", + "event-listener", "event-listener-strategy", "pin-project-lite", ] +[[package]] +name = "async-net" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b948000fad4873c1c9339d60f2623323a0cfd3816e5181033c6a5cb68b2accf7" +dependencies = [ + "async-io", + "blocking", + "futures-lite", +] + [[package]] name = "async-process" version = "2.3.0" @@ -256,7 +267,7 @@ dependencies = [ "async-task", "blocking", "cfg-if 1.0.0", - "event-listener 5.4.0", + "event-listener", "futures-lite", "rustix 0.38.44", "tracing", @@ -286,6 +297,22 @@ version = "4.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b75356056920673b02621b35afd0f7dda9306d03c79a30f5c56c44cf256e3de" +[[package]] +name = "async-tungstenite" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef0f7efedeac57d9b26170f72965ecfd31473ca52ca7a64e925b0b6f5f079886" +dependencies = [ + "atomic-waker", + "futures-core", + "futures-io", + "futures-task", + "futures-util", + "log", + "pin-project-lite", + "tungstenite", +] + [[package]] name = "atomic-waker" version = "1.1.2" @@ -298,6 +325,29 @@ version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" +[[package]] +name = "aws-lc-rs" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19b756939cb2f8dc900aa6dcd505e6e2428e9cae7ff7b028c49e3946efa70878" +dependencies = [ + "aws-lc-sys", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.28.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa9b6986f250236c27e5a204062434a773a13243d2ffc2955f37bdba4c5c6a1" +dependencies = [ + "bindgen", + "cc", + "cmake", + "dunce", + "fs_extra", +] + [[package]] name = "backtrace" version = "0.3.74" @@ -337,6 +387,29 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "bindgen" +version = "0.69.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088" +dependencies = [ + "bitflags 2.9.0", + "cexpr", + "clang-sys", + "itertools", + "lazy_static", + "lazycell", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash 1.1.0", + "shlex", + "syn 2.0.100", + "which", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -482,6 +555,15 @@ dependencies = [ "shlex", ] +[[package]] +name = "cexpr" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] + [[package]] name = "cfg-if" version = "0.1.10" @@ -539,6 +621,17 @@ dependencies = [ "inout", ] +[[package]] +name = "clang-sys" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading", +] + [[package]] name = "clap" version = "4.5.37" @@ -588,6 +681,15 @@ dependencies = [ "error-code", ] +[[package]] +name = "cmake" +version = "0.1.54" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0" +dependencies = [ + "cc", +] + [[package]] name = "colorchoice" version = "1.0.3" @@ -634,12 +736,6 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" -[[package]] -name = "convert_case" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" - [[package]] name = "cookie" version = "0.15.2" @@ -747,19 +843,6 @@ dependencies = [ "syn 2.0.100", ] -[[package]] -name = "derive_more" -version = "0.99.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6edb4b64a43d977b8e99788fe3a04d483834fba1215a7e02caa415b626497f7f" -dependencies = [ - "convert_case", - "proc-macro2", - "quote", - "rustc_version 0.4.1", - "syn 2.0.100", -] - [[package]] name = "dialoguer" version = "0.11.0" @@ -919,17 +1002,6 @@ version = "3.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a5d9305ccc6942a704f4335694ecd3de2ea531b114ac2d51f5f843750787a92f" -[[package]] -name = "event-listener" -version = "4.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67b215c49b2b248c855fb73579eb1f4f26c38ffdc12973e20e07b91d78d5646e" -dependencies = [ - "concurrent-queue", - "parking", - "pin-project-lite", -] - [[package]] name = "event-listener" version = "5.4.0" @@ -947,7 +1019,7 @@ version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93" dependencies = [ - "event-listener 5.4.0", + "event-listener", "pin-project-lite", ] @@ -993,6 +1065,26 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.31" @@ -1000,6 +1092,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -1038,6 +1131,17 @@ dependencies = [ "syn 2.0.100", ] +[[package]] +name = "futures-rustls" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f2f12607f92c69b12ed746fabf9ca4f5c482cba46679c1a75b874ed7c26adb" +dependencies = [ + "futures-io", + "rustls 0.23.26", + "rustls-pki-types", +] + [[package]] name = "futures-sink" version = "0.3.31" @@ -1056,10 +1160,13 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ + "futures-channel", "futures-core", + "futures-io", "futures-macro", "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", "slab", @@ -1136,6 +1243,12 @@ version = "0.30.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0e9b6647e9b41d3a5ef02964c6be01311a7f2472fea40595c635c6d046c259e" +[[package]] +name = "glob" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" + [[package]] name = "h2" version = "0.3.26" @@ -1155,25 +1268,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "h2" -version = "0.4.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "75249d144030531f8dee69fe9cea04d3edf809a017ae445e2abdff6629e86633" -dependencies = [ - "atomic-waker", - "bytes", - "fnv", - "futures-core", - "futures-sink", - "http 1.3.1", - "indexmap", - "slab", - "tokio", - "tokio-util", - "tracing", -] - [[package]] name = "hashbrown" version = "0.15.2" @@ -1288,7 +1382,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2 0.3.26", + "h2", "http 0.2.12", "http-body 0.4.6", "httparse", @@ -1311,7 +1405,6 @@ dependencies = [ "bytes", "futures-channel", "futures-util", - "h2 0.4.9", "http 1.3.1", "http-body 1.0.1", "httparse", @@ -1334,42 +1427,7 @@ dependencies = [ "hyper 0.14.32", "rustls 0.21.12", "tokio", - "tokio-rustls 0.24.1", -] - -[[package]] -name = "hyper-tungstenite" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a343d17fe7885302ed7252767dc7bb83609a874b6ff581142241ec4b73957ad" -dependencies = [ - "http-body-util", - "hyper 1.6.0", - "hyper-util", - "pin-project-lite", - "tokio", - "tokio-tungstenite", - "tungstenite", -] - -[[package]] -name = "hyper-util" -version = "0.1.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "497bbc33a26fdd4af9ed9c70d63f61cf56a938375fbb32df34db9b1cd6d643f2" -dependencies = [ - "bytes", - "futures-channel", - "futures-util", - "http 1.3.1", - "http-body 1.0.1", - "hyper 1.6.0", - "libc", - "pin-project-lite", - "socket2", - "tokio", - "tower-service", - "tracing", + "tokio-rustls", ] [[package]] @@ -1566,6 +1624,15 @@ version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.15" @@ -1607,6 +1674,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + [[package]] name = "libc" version = "0.2.172" @@ -1790,21 +1863,30 @@ dependencies = [ name = "lune-std-net" version = "0.2.0" dependencies = [ + "async-channel", + "async-executor", + "async-io", + "async-lock", + "async-net", + "async-tungstenite", + "blocking", "bstr", - "futures-util", - "http 1.3.1", + "futures", + "futures-lite", + "futures-rustls", "http-body-util", "hyper 1.6.0", - "hyper-tungstenite", - "hyper-util", "lune-std-serde", "lune-utils", "mlua", "mlua-luau-scheduler", - "reqwest", - "tokio", - "tokio-tungstenite", + "pin-project-lite", + "rustls 0.23.26", + "rustls-pki-types", + "url", "urlencoding", + "webpki", + "webpki-roots 0.26.8", ] [[package]] @@ -1982,6 +2064,12 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "miniz_oxide" version = "0.8.8" @@ -2030,11 +2118,10 @@ dependencies = [ "async-io", "blocking", "concurrent-queue", - "derive_more", - "event-listener 4.0.3", + "event-listener", "futures-lite", "mlua", - "rustc-hash 1.1.0", + "rustc-hash 2.1.1", "tracing", "tracing-subscriber", "tracing-tracy", @@ -2073,6 +2160,16 @@ dependencies = [ "libc", ] +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "nu-ansi-term" version = "0.46.0" @@ -2307,6 +2404,16 @@ dependencies = [ "zerocopy 0.8.24", ] +[[package]] +name = "prettyplease" +version = "0.2.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "664ec5419c51e34154eec046ebcba56312d5a2fc3b09a06da188e1ad21afadf6" +dependencies = [ + "proc-macro2", + "syn 2.0.100", +] + [[package]] name = "proc-macro-hack" version = "0.5.20+deprecated" @@ -2634,7 +2741,7 @@ dependencies = [ "encoding_rs", "futures-core", "futures-util", - "h2 0.3.26", + "h2", "http 0.2.12", "http-body 0.4.6", "hyper 0.14.32", @@ -2654,7 +2761,7 @@ dependencies = [ "sync_wrapper", "system-configuration", "tokio", - "tokio-rustls 0.24.1", + "tokio-rustls", "tower-service", "url", "wasm-bindgen", @@ -2750,15 +2857,6 @@ dependencies = [ "semver 0.9.0", ] -[[package]] -name = "rustc_version" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92" -dependencies = [ - "semver 1.0.26", -] - [[package]] name = "rustix" version = "0.38.44" @@ -2799,14 +2897,15 @@ dependencies = [ [[package]] name = "rustls" -version = "0.22.4" +version = "0.23.26" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" +checksum = "df51b5869f3a441595eac5e8ff14d486ff285f7b8c0df8770e49c3b56351f0f0" dependencies = [ + "aws-lc-rs", "log", - "ring", + "once_cell", "rustls-pki-types", - "rustls-webpki 0.102.8", + "rustls-webpki 0.103.1", "subtle", "zeroize", ] @@ -2838,10 +2937,11 @@ dependencies = [ [[package]] name = "rustls-webpki" -version = "0.102.8" +version = "0.103.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" +checksum = "fef8b8769aaccf73098557a87cd1816b4f9c7c16811c9c77142aa695c16f2c03" dependencies = [ + "aws-lc-rs", "ring", "rustls-pki-types", "untrusted", @@ -3146,7 +3246,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d022496b16281348b52d0e30ae99e01a73d737b2f45d38fed4edf79f9325a1d5" dependencies = [ "discard", - "rustc_version 0.2.3", + "rustc_version", "stdweb-derive", "stdweb-internal-macros", "stdweb-internal-runtime", @@ -3441,33 +3541,6 @@ dependencies = [ "tokio", ] -[[package]] -name = "tokio-rustls" -version = "0.25.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "775e0c0f0adb3a2f22a00c4745d728b479985fc15ee7ca6a2608388c5569860f" -dependencies = [ - "rustls 0.22.4", - "rustls-pki-types", - "tokio", -] - -[[package]] -name = "tokio-tungstenite" -version = "0.21.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38" -dependencies = [ - "futures-util", - "log", - "rustls 0.22.4", - "rustls-pki-types", - "tokio", - "tokio-rustls 0.25.0", - "tungstenite", - "webpki-roots 0.26.8", -] - [[package]] name = "tokio-util" version = "0.7.15" @@ -3623,22 +3696,18 @@ checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" [[package]] name = "tungstenite" -version = "0.21.0" +version = "0.26.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1" +checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13" dependencies = [ - "byteorder 1.5.0", "bytes", "data-encoding", "http 1.3.1", "httparse", "log", - "rand 0.8.5", - "rustls 0.22.4", - "rustls-pki-types", + "rand 0.9.1", "sha1 0.10.6", - "thiserror 1.0.69", - "url", + "thiserror 2.0.12", "utf-8", ] @@ -3871,6 +3940,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki" +version = "0.22.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed63aea5ce73d0ff405984102c42de94fc55a6b75765d621c65262469b3c9b53" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "webpki-roots" version = "0.25.4" @@ -3886,6 +3965,18 @@ dependencies = [ "rustls-pki-types", ] +[[package]] +name = "which" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix 0.38.44", +] + [[package]] name = "winapi" version = "0.3.9" diff --git a/crates/lune-std-net/Cargo.toml b/crates/lune-std-net/Cargo.toml index b143857..f5ec79d 100644 --- a/crates/lune-std-net/Cargo.toml +++ b/crates/lune-std-net/Cargo.toml @@ -16,24 +16,26 @@ workspace = true mlua = { version = "0.10.3", features = ["luau"] } mlua-luau-scheduler = { version = "0.1.0", path = "../mlua-luau-scheduler" } +async-channel = "2.3" +async-executor = "1.13" +async-io = "2.4" +async-lock = "3.4" +async-net = "2.0" +async-tungstenite = "0.29" +blocking = "1.6" bstr = "1.9" -futures-util = "0.3" -hyper = { version = "1.1", features = ["full"] } -hyper-util = { version = "0.1", features = ["full"] } -http = "1.0" -http-body-util = { version = "0.1" } -hyper-tungstenite = { version = "0.13" } -reqwest = { version = "0.11", default-features = false, features = [ - "rustls-tls", -] } -tokio-tungstenite = { version = "0.21", features = ["rustls-tls-webpki-roots"] } +futures = { version = "0.3", default-features = false, features = ["std"] } +futures-lite = "2.6" +futures-rustls = "0.26" +http-body-util = "0.1" +hyper = { version = "1.6", features = ["http1", "client", "server"] } +pin-project-lite = "0.2" +rustls = "0.23" +rustls-pki-types = "1.11" +url = "2.5" urlencoding = "2.1" - -tokio = { version = "1", default-features = false, features = [ - "sync", - "net", - "macros", -] } +webpki = "0.22" +webpki-roots = "0.26" lune-utils = { version = "0.2.0", path = "../lune-utils" } lune-std-serde = { version = "0.2.0", path = "../lune-std-serde" } diff --git a/crates/lune-std-net/src/client.rs b/crates/lune-std-net/src/client.rs deleted file mode 100644 index c796c93..0000000 --- a/crates/lune-std-net/src/client.rs +++ /dev/null @@ -1,163 +0,0 @@ -use std::str::FromStr; - -use mlua::prelude::*; - -use reqwest::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_ENCODING}; - -use lune_std_serde::{decompress, CompressDecompressFormat}; -use lune_utils::TableBuilder; - -use super::{config::RequestConfig, util::header_map_to_table}; - -const REGISTRY_KEY: &str = "NetClient"; - -pub struct NetClientBuilder { - builder: reqwest::ClientBuilder, -} - -impl NetClientBuilder { - pub fn new() -> NetClientBuilder { - Self { - builder: reqwest::ClientBuilder::new(), - } - } - - pub fn headers(mut self, headers: &[(K, V)]) -> LuaResult - where - K: AsRef, - V: AsRef<[u8]>, - { - let mut map = HeaderMap::new(); - for (key, val) in headers { - let hkey = HeaderName::from_str(key.as_ref()).into_lua_err()?; - let hval = HeaderValue::from_bytes(val.as_ref()).into_lua_err()?; - map.insert(hkey, hval); - } - self.builder = self.builder.default_headers(map); - Ok(self) - } - - pub fn build(self) -> LuaResult { - let client = self.builder.build().into_lua_err()?; - Ok(NetClient { inner: client }) - } -} - -#[derive(Debug, Clone)] -pub struct NetClient { - inner: reqwest::Client, -} - -impl NetClient { - pub fn from_registry(lua: &Lua) -> Self { - lua.named_registry_value(REGISTRY_KEY) - .expect("Failed to get NetClient from lua registry") - } - - pub fn into_registry(self, lua: &Lua) { - lua.set_named_registry_value(REGISTRY_KEY, self) - .expect("Failed to store NetClient in lua registry"); - } - - pub async fn request(&self, config: RequestConfig) -> LuaResult { - // Create and send the request - let mut request = self.inner.request(config.method, config.url); - for (query, values) in config.query { - request = request.query( - &values - .iter() - .map(|v| (query.as_str(), v)) - .collect::>(), - ); - } - for (header, values) in config.headers { - for value in values { - request = request.header(header.as_str(), value); - } - } - let res = request - .body(config.body.unwrap_or_default()) - .send() - .await - .into_lua_err()?; - - // Extract status, headers - let res_status = res.status().as_u16(); - let res_status_text = res.status().canonical_reason(); - let res_headers = res.headers().clone(); - - // Read response bytes - let mut res_bytes = res.bytes().await.into_lua_err()?.to_vec(); - let mut res_decompressed = false; - - // Check for extra options, decompression - if config.options.decompress { - let decompress_format = res_headers - .iter() - .find(|(name, _)| { - name.as_str() - .eq_ignore_ascii_case(CONTENT_ENCODING.as_str()) - }) - .and_then(|(_, value)| value.to_str().ok()) - .and_then(CompressDecompressFormat::detect_from_header_str); - if let Some(format) = decompress_format { - res_bytes = decompress(res_bytes, format).await?; - res_decompressed = true; - } - } - - Ok(NetClientResponse { - ok: (200..300).contains(&res_status), - status_code: res_status, - status_message: res_status_text.unwrap_or_default().to_string(), - headers: res_headers, - body: res_bytes, - body_decompressed: res_decompressed, - }) - } -} - -impl LuaUserData for NetClient {} - -impl FromLua for NetClient { - fn from_lua(value: LuaValue, _: &Lua) -> LuaResult { - if let LuaValue::UserData(ud) = value { - if let Ok(ctx) = ud.borrow::() { - return Ok(ctx.clone()); - } - } - unreachable!("NetClient should only be used from registry") - } -} - -impl From<&Lua> for NetClient { - fn from(value: &Lua) -> Self { - value - .named_registry_value(REGISTRY_KEY) - .expect("Missing require context in lua registry") - } -} - -pub struct NetClientResponse { - ok: bool, - status_code: u16, - status_message: String, - headers: HeaderMap, - body: Vec, - body_decompressed: bool, -} - -impl NetClientResponse { - pub fn into_lua_table(self, lua: &Lua) -> LuaResult { - TableBuilder::new(lua.clone())? - .with_value("ok", self.ok)? - .with_value("statusCode", self.status_code)? - .with_value("statusMessage", self.status_message)? - .with_value( - "headers", - header_map_to_table(lua, self.headers, self.body_decompressed)?, - )? - .with_value("body", lua.create_string(&self.body)?)? - .build_readonly() - } -} diff --git a/crates/lune-std-net/src/client/http_stream.rs b/crates/lune-std-net/src/client/http_stream.rs new file mode 100644 index 0000000..2aba704 --- /dev/null +++ b/crates/lune-std-net/src/client/http_stream.rs @@ -0,0 +1,95 @@ +use std::{ + io, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use async_net::TcpStream; +use futures_lite::prelude::*; +use futures_rustls::{TlsConnector, TlsStream}; +use rustls_pki_types::ServerName; +use url::Url; + +use crate::client::rustls::CLIENT_CONFIG; + +#[derive(Debug)] +pub enum HttpStream { + Plain(TcpStream), + Tls(TlsStream), +} + +impl HttpStream { + pub async fn connect(url: Url) -> Result { + let Some(host) = url.host() else { + return Err(make_err("unknown or missing host")); + }; + let Some(port) = url.port_or_known_default() else { + return Err(make_err("unknown or missing port")); + }; + + let use_tls = match url.scheme() { + "http" => false, + "https" => true, + s => return Err(make_err(format!("unsupported scheme: {s}"))), + }; + + let host = host.to_string(); + let stream = TcpStream::connect((host.clone(), port)).await?; + + let stream = if use_tls { + let servname = ServerName::try_from(host).map_err(make_err)?.to_owned(); + let connector = TlsConnector::from(Arc::clone(&CLIENT_CONFIG)); + let stream = connector.connect(servname, stream).await?; + Self::Tls(TlsStream::Client(stream)) + } else { + Self::Plain(stream) + }; + + Ok(stream) + } +} + +impl AsyncRead for HttpStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + match &mut *self { + HttpStream::Plain(stream) => Pin::new(stream).poll_read(cx, buf), + HttpStream::Tls(stream) => Pin::new(stream).poll_read(cx, buf), + } + } +} + +impl AsyncWrite for HttpStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + match &mut *self { + HttpStream::Plain(stream) => Pin::new(stream).poll_write(cx, buf), + HttpStream::Tls(stream) => Pin::new(stream).poll_write(cx, buf), + } + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + HttpStream::Plain(stream) => Pin::new(stream).poll_close(cx), + HttpStream::Tls(stream) => Pin::new(stream).poll_close(cx), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + HttpStream::Plain(stream) => Pin::new(stream).poll_flush(cx), + HttpStream::Tls(stream) => Pin::new(stream).poll_flush(cx), + } + } +} + +fn make_err(e: impl ToString) -> io::Error { + io::Error::new(io::ErrorKind::Other, e.to_string()) +} diff --git a/crates/lune-std-net/src/client/mod.rs b/crates/lune-std-net/src/client/mod.rs new file mode 100644 index 0000000..456b0c0 --- /dev/null +++ b/crates/lune-std-net/src/client/mod.rs @@ -0,0 +1,125 @@ +use hyper::{ + body::{Bytes, Incoming}, + client::conn::http1::handshake, + header::{HeaderValue, ACCEPT, CONTENT_LENGTH, HOST, LOCATION, USER_AGENT}, + Method, Request as HyperRequest, Response as HyperResponse, Uri, +}; + +use mlua::prelude::*; +use url::Url; + +use crate::{ + client::{http_stream::HttpStream, ws_stream::WsStream}, + shared::{ + headers::create_user_agent_header, + hyper::{HyperExecutor, HyperIo}, + request::Request, + response::Response, + websocket::Websocket, + }, +}; + +pub mod http_stream; +pub mod rustls; +pub mod ws_stream; + +const MAX_REDIRECTS: usize = 10; + +/** + Connects to a websocket at the given URL. +*/ +pub async fn connect_websocket(url: Url) -> LuaResult> { + let stream = WsStream::connect(url).await?; + Ok(Websocket::from(stream)) +} + +/** + Sends the request and returns the final response. + + This will follow any redirects returned by the server, + modifying the request method and body as necessary. +*/ +pub async fn send_request(mut request: Request, lua: Lua) -> LuaResult { + let url = request + .inner + .uri() + .to_string() + .parse::() + .expect("uri is valid"); + + // Some headers are required by most if not + // all servers, make sure those are present... + if !request.headers().contains_key(HOST.as_str()) { + if let Some(host) = url.host_str() { + let host = HeaderValue::from_str(host).into_lua_err()?; + request.inner.headers_mut().insert(HOST, host); + } + } + if !request.headers().contains_key(USER_AGENT.as_str()) { + let ua = create_user_agent_header(&lua)?; + let ua = HeaderValue::from_str(&ua).into_lua_err()?; + request.inner.headers_mut().insert(USER_AGENT, ua); + } + if !request.headers().contains_key(CONTENT_LENGTH.as_str()) && request.method() != Method::GET { + let len = request.inner.body().len().to_string(); + let len = HeaderValue::from_str(&len).into_lua_err()?; + request.inner.headers_mut().insert(CONTENT_LENGTH, len); + } + if !request.headers().contains_key(ACCEPT.as_str()) { + let accept = HeaderValue::from_static("*/*"); + request.inner.headers_mut().insert(ACCEPT, accept); + } + + // ... we can now safely continue and send the request + loop { + let stream = HttpStream::connect(url.clone()).await?; + + let (mut sender, conn) = handshake(HyperIo::from(stream)).await.into_lua_err()?; + + HyperExecutor::execute(lua.clone(), conn); + + let incoming = sender + .send_request(request.as_full()) + .await + .into_lua_err()?; + + if let Some((new_method, new_uri)) = check_redirect(&request.inner, &incoming) { + if request.redirects.is_some_and(|r| r >= MAX_REDIRECTS) { + return Err(LuaError::external("Too many redirects")); + } + + if new_method == Method::GET { + *request.inner.body_mut() = Bytes::new(); + } + + *request.inner.method_mut() = new_method; + *request.inner.uri_mut() = new_uri; + + *request.redirects.get_or_insert_default() += 1; + + continue; + } + + break Response::from_incoming(incoming, request.decompress).await; + } +} + +fn check_redirect( + request: &HyperRequest, + response: &HyperResponse, +) -> Option<(Method, Uri)> { + if !response.status().is_redirection() { + return None; + } + + let location = response.headers().get(LOCATION)?; + let location = location.to_str().ok()?; + let location = location.parse().ok()?; + + let method = match response.status().as_u16() { + 301..=303 => Method::GET, + _ => request.method().clone(), + }; + + Some((method, location)) +} diff --git a/crates/lune-std-net/src/client/rustls.rs b/crates/lune-std-net/src/client/rustls.rs new file mode 100644 index 0000000..ea864ab --- /dev/null +++ b/crates/lune-std-net/src/client/rustls.rs @@ -0,0 +1,12 @@ +use std::sync::{Arc, LazyLock}; + +use rustls::ClientConfig; + +pub static CLIENT_CONFIG: LazyLock> = LazyLock::new(|| { + rustls::ClientConfig::builder() + .with_root_certificates(rustls::RootCertStore { + roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(), + }) + .with_no_client_auth() + .into() +}); diff --git a/crates/lune-std-net/src/client/ws_stream.rs b/crates/lune-std-net/src/client/ws_stream.rs new file mode 100644 index 0000000..03537ec --- /dev/null +++ b/crates/lune-std-net/src/client/ws_stream.rs @@ -0,0 +1,114 @@ +use std::{ + io, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; + +use async_net::TcpStream; +use async_tungstenite::{ + tungstenite::{Error as TungsteniteError, Message, Result as TungsteniteResult}, + WebSocketStream as TungsteniteStream, +}; +use futures::Sink; +use futures_lite::prelude::*; +use futures_rustls::{TlsConnector, TlsStream}; +use rustls_pki_types::ServerName; +use url::Url; + +use crate::client::rustls::CLIENT_CONFIG; + +#[derive(Debug)] +pub enum WsStream { + Plain(TungsteniteStream), + Tls(TungsteniteStream>), +} + +impl WsStream { + pub async fn connect(url: Url) -> Result { + let Some(host) = url.host() else { + return Err(make_err("unknown or missing host")); + }; + let Some(port) = url.port_or_known_default() else { + return Err(make_err("unknown or missing port")); + }; + + let use_tls = match url.scheme() { + "ws" => false, + "wss" => true, + s => return Err(make_err(format!("unsupported scheme: {s}"))), + }; + + let host = host.to_string(); + let stream = TcpStream::connect((host.clone(), port)).await?; + + let stream = if use_tls { + let servname = ServerName::try_from(host).map_err(make_err)?.to_owned(); + let connector = TlsConnector::from(Arc::clone(&CLIENT_CONFIG)); + + let stream = connector.connect(servname, stream).await?; + let stream = TlsStream::Client(stream); + + let stream = async_tungstenite::client_async(url.to_string(), stream) + .await + .map_err(make_err)? + .0; + Self::Tls(stream) + } else { + let stream = async_tungstenite::client_async(url.to_string(), stream) + .await + .map_err(make_err)? + .0; + Self::Plain(stream) + }; + + Ok(stream) + } +} + +impl Sink for WsStream { + type Error = TungsteniteError; + + fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + WsStream::Plain(s) => Pin::new(s).poll_ready(cx), + WsStream::Tls(s) => Pin::new(s).poll_ready(cx), + } + } + + fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> { + match &mut *self { + WsStream::Plain(s) => Pin::new(s).start_send(item), + WsStream::Tls(s) => Pin::new(s).start_send(item), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + WsStream::Plain(s) => Pin::new(s).poll_flush(cx), + WsStream::Tls(s) => Pin::new(s).poll_flush(cx), + } + } + + fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + WsStream::Plain(s) => Pin::new(s).poll_close(cx), + WsStream::Tls(s) => Pin::new(s).poll_close(cx), + } + } +} + +impl Stream for WsStream { + type Item = TungsteniteResult; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + match &mut *self { + WsStream::Plain(s) => Pin::new(s).poll_next(cx), + WsStream::Tls(s) => Pin::new(s).poll_next(cx), + } + } +} + +fn make_err(e: impl ToString) -> io::Error { + io::Error::new(io::ErrorKind::Other, e.to_string()) +} diff --git a/crates/lune-std-net/src/config.rs b/crates/lune-std-net/src/config.rs deleted file mode 100644 index bcd9cef..0000000 --- a/crates/lune-std-net/src/config.rs +++ /dev/null @@ -1,231 +0,0 @@ -use std::{ - collections::HashMap, - net::{IpAddr, Ipv4Addr}, -}; - -use bstr::{BString, ByteSlice}; -use mlua::prelude::*; - -use reqwest::Method; - -use super::util::table_to_hash_map; - -const DEFAULT_IP_ADDRESS: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); - -const WEB_SOCKET_UPDGRADE_REQUEST_HANDLER: &str = r#" -return { - status = 426, - body = "Upgrade Required", - headers = { - Upgrade = "websocket", - }, -} -"#; - -// Net request config - -#[derive(Debug, Clone)] -pub struct RequestConfigOptions { - pub decompress: bool, -} - -impl Default for RequestConfigOptions { - fn default() -> Self { - Self { decompress: true } - } -} - -impl FromLua for RequestConfigOptions { - fn from_lua(value: LuaValue, _: &Lua) -> LuaResult { - if let LuaValue::Nil = value { - // Nil means default options - Ok(Self::default()) - } else if let LuaValue::Table(tab) = value { - // Table means custom options - let decompress = match tab.get::>("decompress") { - Ok(decomp) => Ok(decomp.unwrap_or(true)), - Err(_) => Err(LuaError::RuntimeError( - "Invalid option value for 'decompress' in request config options".to_string(), - )), - }?; - Ok(Self { decompress }) - } else { - // Anything else is invalid - Err(LuaError::FromLuaConversionError { - from: value.type_name(), - to: "RequestConfigOptions".to_string(), - message: Some(format!( - "Invalid request config options - expected table or nil, got {}", - value.type_name() - )), - }) - } - } -} - -#[derive(Debug, Clone)] -pub struct RequestConfig { - pub url: String, - pub method: Method, - pub query: HashMap>, - pub headers: HashMap>, - pub body: Option>, - pub options: RequestConfigOptions, -} - -impl FromLua for RequestConfig { - fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult { - // If we just got a string we assume its a GET request to a given url - if let LuaValue::String(s) = value { - Ok(Self { - url: s.to_string_lossy().to_string(), - method: Method::GET, - query: HashMap::new(), - headers: HashMap::new(), - body: None, - options: RequestConfigOptions::default(), - }) - } else if let LuaValue::Table(tab) = value { - // If we got a table we are able to configure the entire request - // Extract url - let url = match tab.get::("url") { - Ok(config_url) => Ok(config_url.to_string_lossy().to_string()), - Err(_) => Err(LuaError::runtime("Missing 'url' in request config")), - }?; - // Extract method - let method = match tab.get::("method") { - Ok(config_method) => config_method.to_string_lossy().trim().to_ascii_uppercase(), - Err(_) => "GET".to_string(), - }; - // Extract query - let query = match tab.get::("query") { - Ok(tab) => table_to_hash_map(tab, "query")?, - Err(_) => HashMap::new(), - }; - // Extract headers - let headers = match tab.get::("headers") { - Ok(tab) => table_to_hash_map(tab, "headers")?, - Err(_) => HashMap::new(), - }; - // Extract body - let body = match tab.get::("body") { - Ok(config_body) => Some(config_body.as_bytes().to_owned()), - Err(_) => None, - }; - - // Convert method string into proper enum - let method = method.trim().to_ascii_uppercase(); - let method = match method.as_ref() { - "GET" => Ok(Method::GET), - "POST" => Ok(Method::POST), - "PUT" => Ok(Method::PUT), - "DELETE" => Ok(Method::DELETE), - "HEAD" => Ok(Method::HEAD), - "OPTIONS" => Ok(Method::OPTIONS), - "PATCH" => Ok(Method::PATCH), - _ => Err(LuaError::RuntimeError(format!( - "Invalid request config method '{}'", - &method - ))), - }?; - // Parse any extra options given - let options = match tab.get::("options") { - Ok(opts) => RequestConfigOptions::from_lua(opts, lua)?, - Err(_) => RequestConfigOptions::default(), - }; - // All good, validated and we got what we need - Ok(Self { - url, - method, - query, - headers, - body, - options, - }) - } else { - // Anything else is invalid - Err(LuaError::FromLuaConversionError { - from: value.type_name(), - to: "RequestConfig".to_string(), - message: Some(format!( - "Invalid request config - expected string or table, got {}", - value.type_name() - )), - }) - } - } -} - -// Net serve config - -#[derive(Debug)] -pub struct ServeConfig { - pub address: IpAddr, - pub handle_request: LuaFunction, - pub handle_web_socket: Option, -} - -impl FromLua for ServeConfig { - fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult { - if let LuaValue::Function(f) = &value { - // Single function = request handler, rest is default - Ok(ServeConfig { - handle_request: f.clone(), - handle_web_socket: None, - address: DEFAULT_IP_ADDRESS, - }) - } else if let LuaValue::Table(t) = &value { - // Table means custom options - let address: Option = t.get("address")?; - let handle_request: Option = t.get("handleRequest")?; - let handle_web_socket: Option = t.get("handleWebSocket")?; - if handle_request.is_some() || handle_web_socket.is_some() { - let address: IpAddr = match &address { - Some(addr) => { - let addr_str = addr.to_str()?; - - addr_str - .trim_start_matches("http://") - .trim_start_matches("https://") - .parse() - .map_err(|_e| LuaError::FromLuaConversionError { - from: value.type_name(), - to: "ServeConfig".to_string(), - message: Some(format!( - "IP address format is incorrect - \ - expected an IP in the form 'http://0.0.0.0' or '0.0.0.0', \ - got '{addr_str}'" - )), - })? - } - None => DEFAULT_IP_ADDRESS, - }; - - Ok(Self { - address, - handle_request: handle_request.unwrap_or_else(|| { - lua.load(WEB_SOCKET_UPDGRADE_REQUEST_HANDLER) - .into_function() - .expect("Failed to create default http responder function") - }), - handle_web_socket, - }) - } else { - Err(LuaError::FromLuaConversionError { - from: value.type_name(), - to: "ServeConfig".to_string(), - message: Some(String::from( - "Invalid serve config - expected table with 'handleRequest' or 'handleWebSocket' function", - )), - }) - } - } else { - // Anything else is invalid - Err(LuaError::FromLuaConversionError { - from: value.type_name(), - to: "ServeConfig".to_string(), - message: None, - }) - } - } -} diff --git a/crates/lune-std-net/src/lib.rs b/crates/lune-std-net/src/lib.rs index 334a302..f4c8e06 100644 --- a/crates/lune-std-net/src/lib.rs +++ b/crates/lune-std-net/src/lib.rs @@ -1,22 +1,17 @@ #![allow(clippy::cargo_common_metadata)] -use mlua::prelude::*; -use mlua_luau_scheduler::LuaSpawnExt; - -mod client; -mod config; -mod server; -mod util; -mod websocket; - use lune_utils::TableBuilder; +use mlua::prelude::*; + +pub(crate) mod client; +pub(crate) mod server; +pub(crate) mod shared; +pub(crate) mod url; use self::{ - client::{NetClient, NetClientBuilder}, - config::{RequestConfig, ServeConfig}, - server::serve, - util::create_user_agent_header, - websocket::NetWebSocket, + client::ws_stream::WsStream, + server::config::ServeConfig, + shared::{request::Request, response::Response, websocket::Websocket}, }; const TYPEDEFS: &str = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/types.d.luau")); @@ -37,10 +32,6 @@ pub fn typedefs() -> String { Errors when out of memory. */ pub fn module(lua: Lua) -> LuaResult { - NetClientBuilder::new() - .headers(&[("User-Agent", create_user_agent_header(&lua)?)])? - .build()? - .into_registry(&lua); TableBuilder::new(lua)? .with_async_function("request", net_request)? .with_async_function("socket", net_socket)? @@ -50,42 +41,35 @@ pub fn module(lua: Lua) -> LuaResult { .build_readonly() } -async fn net_request(lua: Lua, config: RequestConfig) -> LuaResult { - let client = NetClient::from_registry(&lua); - // NOTE: We spawn the request as a background task to free up resources in lua - let res = lua.spawn(async move { client.request(config).await }); - res.await?.into_lua_table(&lua) +async fn net_request(lua: Lua, req: Request) -> LuaResult { + self::client::send_request(req, lua).await } -async fn net_socket(lua: Lua, url: String) -> LuaResult { - let (ws, _) = tokio_tungstenite::connect_async(url).await.into_lua_err()?; - NetWebSocket::new(ws).into_lua(&lua) +async fn net_socket(_: Lua, url: String) -> LuaResult> { + let url = url.parse().into_lua_err()?; + self::client::connect_websocket(url).await } async fn net_serve(lua: Lua, (port, config): (u16, ServeConfig)) -> LuaResult { - serve(lua, port, config).await + self::server::serve(lua.clone(), port, config) + .await? + .into_lua_table(lua) } fn net_url_encode( lua: &Lua, (lua_string, as_binary): (LuaString, Option), -) -> LuaResult { - if matches!(as_binary, Some(true)) { - urlencoding::encode_binary(&lua_string.as_bytes()).into_lua(lua) - } else { - urlencoding::encode(&lua_string.to_str()?).into_lua(lua) - } +) -> LuaResult { + let as_binary = as_binary.unwrap_or_default(); + let bytes = self::url::encode(lua_string, as_binary)?; + lua.create_string(bytes) } fn net_url_decode( lua: &Lua, (lua_string, as_binary): (LuaString, Option), -) -> LuaResult { - if matches!(as_binary, Some(true)) { - urlencoding::decode_binary(&lua_string.as_bytes()).into_lua(lua) - } else { - urlencoding::decode(&lua_string.to_str()?) - .map_err(|e| LuaError::RuntimeError(format!("Encountered invalid encoding - {e}")))? - .into_lua(lua) - } +) -> LuaResult { + let as_binary = as_binary.unwrap_or_default(); + let bytes = self::url::decode(lua_string, as_binary)?; + lua.create_string(bytes) } diff --git a/crates/lune-std-net/src/server/config.rs b/crates/lune-std-net/src/server/config.rs new file mode 100644 index 0000000..06fe6a8 --- /dev/null +++ b/crates/lune-std-net/src/server/config.rs @@ -0,0 +1,87 @@ +use std::net::{IpAddr, Ipv4Addr}; + +use mlua::prelude::*; + +const DEFAULT_IP_ADDRESS: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)); + +const WEB_SOCKET_UPDGRADE_REQUEST_HANDLER: &str = r#" +return { + status = 426, + body = "Upgrade Required", + headers = { + Upgrade = "websocket", + }, +} +"#; + +#[derive(Debug, Clone)] +pub struct ServeConfig { + pub address: IpAddr, + pub handle_request: LuaFunction, + pub handle_web_socket: Option, +} + +impl FromLua for ServeConfig { + fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult { + if let LuaValue::Function(f) = &value { + // Single function = request handler, rest is default + Ok(ServeConfig { + handle_request: f.clone(), + handle_web_socket: None, + address: DEFAULT_IP_ADDRESS, + }) + } else if let LuaValue::Table(t) = &value { + // Table means custom options + let address: Option = t.get("address")?; + let handle_request: Option = t.get("handleRequest")?; + let handle_web_socket: Option = t.get("handleWebSocket")?; + if handle_request.is_some() || handle_web_socket.is_some() { + let address: IpAddr = match &address { + Some(addr) => { + let addr_str = addr.to_str()?; + + addr_str + .trim_start_matches("http://") + .trim_start_matches("https://") + .parse() + .map_err(|_e| LuaError::FromLuaConversionError { + from: value.type_name(), + to: "ServeConfig".to_string(), + message: Some(format!( + "IP address format is incorrect - \ + expected an IP in the form 'http://0.0.0.0' or '0.0.0.0', \ + got '{addr_str}'" + )), + })? + } + None => DEFAULT_IP_ADDRESS, + }; + + Ok(Self { + address, + handle_request: handle_request.unwrap_or_else(|| { + lua.load(WEB_SOCKET_UPDGRADE_REQUEST_HANDLER) + .into_function() + .expect("Failed to create default http responder function") + }), + handle_web_socket, + }) + } else { + Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "ServeConfig".to_string(), + message: Some(String::from( + "Invalid serve config - expected table with 'handleRequest' or 'handleWebSocket' function", + )), + }) + } + } else { + // Anything else is invalid + Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "ServeConfig".to_string(), + message: None, + }) + } + } +} diff --git a/crates/lune-std-net/src/server/handle.rs b/crates/lune-std-net/src/server/handle.rs new file mode 100644 index 0000000..4f2f10c --- /dev/null +++ b/crates/lune-std-net/src/server/handle.rs @@ -0,0 +1,72 @@ +use std::{ + net::SocketAddr, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, + }, +}; + +use async_channel::{unbounded, Receiver, Sender}; + +use lune_utils::TableBuilder; +use mlua::prelude::*; + +#[derive(Debug, Clone)] +pub struct ServeHandle { + addr: SocketAddr, + shutdown: Arc, + sender: Sender<()>, +} + +impl ServeHandle { + pub fn new(addr: SocketAddr) -> (Self, Receiver<()>) { + let (sender, receiver) = unbounded(); + let this = Self { + addr, + shutdown: Arc::new(AtomicBool::new(false)), + sender, + }; + (this, receiver) + } + + // TODO: Remove this in the next major release to use colon/self + // based call syntax and userdata implementation below instead + pub fn into_lua_table(self, lua: Lua) -> LuaResult { + let shutdown = self.shutdown.clone(); + let sender = self.sender.clone(); + TableBuilder::new(lua)? + .with_value("ip", self.addr.ip().to_string())? + .with_value("port", self.addr.port())? + .with_function("stop", move |_, ()| { + if shutdown.load(Ordering::SeqCst) { + Err(LuaError::runtime("Server already stopped")) + } else { + shutdown.store(true, Ordering::SeqCst); + sender.try_send(()).ok(); + sender.close(); + Ok(()) + } + })? + .build() + } +} + +impl LuaUserData for ServeHandle { + fn add_fields>(fields: &mut F) { + fields.add_field_method_get("ip", |_, this| Ok(this.addr.ip().to_string())); + fields.add_field_method_get("port", |_, this| Ok(this.addr.port())); + } + + fn add_methods>(methods: &mut M) { + methods.add_method("stop", |_, this, ()| { + if this.shutdown.load(Ordering::SeqCst) { + Err(LuaError::runtime("Server already stopped")) + } else { + this.shutdown.store(true, Ordering::SeqCst); + this.sender.try_send(()).ok(); + this.sender.close(); + Ok(()) + } + }); + } +} diff --git a/crates/lune-std-net/src/server/keys.rs b/crates/lune-std-net/src/server/keys.rs deleted file mode 100644 index aa34607..0000000 --- a/crates/lune-std-net/src/server/keys.rs +++ /dev/null @@ -1,58 +0,0 @@ -use std::sync::atomic::{AtomicUsize, Ordering}; - -use mlua::prelude::*; - -#[derive(Debug, Clone, Copy)] -pub(super) struct SvcKeys { - key_request: &'static str, - key_websocket: Option<&'static str>, -} - -impl SvcKeys { - pub(super) fn new( - lua: Lua, - handle_request: LuaFunction, - handle_websocket: Option, - ) -> LuaResult { - static SERVE_COUNTER: AtomicUsize = AtomicUsize::new(0); - let count = SERVE_COUNTER.fetch_add(1, Ordering::Relaxed); - - // NOTE: We leak strings here, but this is an acceptable tradeoff since programs - // generally only start one or a couple of servers and they are usually never dropped. - // Leaking here lets us keep this struct Copy and access the request handler callbacks - // very performantly, significantly reducing the per-request overhead of the server. - let key_request: &'static str = - Box::leak(format!("__net_serve_request_{count}").into_boxed_str()); - let key_websocket: Option<&'static str> = if handle_websocket.is_some() { - Some(Box::leak( - format!("__net_serve_websocket_{count}").into_boxed_str(), - )) - } else { - None - }; - - lua.set_named_registry_value(key_request, handle_request)?; - if let Some(key) = key_websocket { - lua.set_named_registry_value(key, handle_websocket.unwrap())?; - } - - Ok(Self { - key_request, - key_websocket, - }) - } - - pub(super) fn has_websocket_handler(&self) -> bool { - self.key_websocket.is_some() - } - - pub(super) fn request_handler(&self, lua: &Lua) -> LuaResult { - lua.named_registry_value(self.key_request) - } - - pub(super) fn websocket_handler(&self, lua: &Lua) -> LuaResult> { - self.key_websocket - .map(|key| lua.named_registry_value(key)) - .transpose() - } -} diff --git a/crates/lune-std-net/src/server/mod.rs b/crates/lune-std-net/src/server/mod.rs index 65783ec..58ff65b 100644 --- a/crates/lune-std-net/src/server/mod.rs +++ b/crates/lune-std-net/src/server/mod.rs @@ -1,92 +1,121 @@ -use std::net::SocketAddr; +use std::{cell::Cell, net::SocketAddr, rc::Rc}; -use hyper::server::conn::http1; -use hyper_util::rt::TokioIo; -use tokio::{net::TcpListener, pin}; +use async_net::TcpListener; +use futures_lite::pin; +use hyper::server::conn::http1::Builder as Http1Builder; use mlua::prelude::*; use mlua_luau_scheduler::LuaSpawnExt; -use lune_utils::TableBuilder; +use crate::{ + server::{config::ServeConfig, handle::ServeHandle, service::Service}, + shared::{ + futures::{either, Either}, + hyper::{HyperIo, HyperTimer}, + }, +}; -use super::config::ServeConfig; +pub mod config; +pub mod handle; +pub mod service; +pub mod upgrade; -mod keys; -mod request; -mod response; -mod service; +/** + Starts an HTTP server using the given port and configuration. -use keys::SvcKeys; -use service::Svc; - -pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult { - let addr: SocketAddr = (config.address, port).into(); - let listener = TcpListener::bind(addr).await?; - - let lua_svc = lua.clone(); - let lua_inner = lua.clone(); - - let keys = SvcKeys::new(lua.clone(), config.handle_request, config.handle_web_socket)?; - let svc = Svc { - lua: lua_svc, - addr, - keys, + Returns a `ServeHandle` that can be used to gracefully stop the server. +*/ +pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult { + let address = SocketAddr::from((config.address, port)); + let service = Service { + lua: lua.clone(), + address, + config, }; - let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false); - lua.spawn_local(async move { - let mut shutdown_rx_outer = shutdown_rx.clone(); - loop { - // Create futures for accepting new connections and shutting down - let fut_shutdown = shutdown_rx_outer.changed(); - let fut_accept = async { - let stream = match listener.accept().await { - Err(_) => return, - Ok((s, _)) => s, + let listener = TcpListener::bind(address).await?; + let (handle, shutdown_rx) = ServeHandle::new(address); + + lua.spawn_local({ + let lua = lua.clone(); + async move { + let handle_dropped = Rc::new(Cell::new(false)); + loop { + // 1. Keep accepting new connections until we should shutdown + let (conn, addr) = if handle_dropped.get() { + // 1a. Handle has been dropped, and we don't need to listen for shutdown + match listener.accept().await { + Ok(acc) => acc, + Err(_err) => { + // TODO: Propagate error somehow + continue; + } + } + } else { + // 1b. Handle is possibly active, we must listen for shutdown + match either(shutdown_rx.recv(), listener.accept()).await { + Either::Left(Ok(())) => break, + Either::Left(Err(_)) => { + // NOTE #1: We will only get a RecvError if the serve handle is dropped, + // this means lua has garbage collected it and the user does not want + // to manually stop the server using the serve handle. Run forever. + handle_dropped.set(true); + continue; + } + Either::Right(Ok(acc)) => acc, + Either::Right(Err(_err)) => { + // TODO: Propagate error somehow + continue; + } + } }; - let io = TokioIo::new(stream); - let svc = svc.clone(); - let mut shutdown_rx_inner = shutdown_rx.clone(); + // 2. For each connection, spawn a new task to handle it + lua.spawn_local({ + let rx = shutdown_rx.clone(); + let io = HyperIo::from(conn); - lua_inner.spawn_local(async move { - let conn = http1::Builder::new() - .keep_alive(true) // Web sockets need this - .serve_connection(io, svc) - .with_upgrades(); - // NOTE: Because we need to use keep_alive for websockets, we need to - // also manually poll this future and handle the shutdown signal here - pin!(conn); - tokio::select! { - _ = conn.as_mut() => {} - _ = shutdown_rx_inner.changed() => { - conn.as_mut().graceful_shutdown(); + let mut svc = service.clone(); + svc.address = addr; + + let handle_dropped = Rc::clone(&handle_dropped); + async move { + let conn = Http1Builder::new() + .writev(false) + .timer(HyperTimer) + .keep_alive(true) + .serve_connection(io, svc) + .with_upgrades(); + if handle_dropped.get() { + if let Err(_err) = conn.await { + // TODO: Propagate error somehow + } + } else { + // NOTE #2: Because we use keep_alive for websockets above, we need to + // also manually poll this future and handle the graceful shutdown, + // otherwise the already accepted connection will linger and run + // even if the stop method has been called on the serve handle + pin!(conn); + match either(rx.recv(), conn.as_mut()).await { + Either::Left(Ok(())) => conn.as_mut().graceful_shutdown(), + Either::Left(Err(_)) => { + // Same as note #1 + handle_dropped.set(true); + if let Err(_err) = conn.await { + // TODO: Propagate error somehow + } + } + Either::Right(Ok(())) => {} + Either::Right(Err(_err)) => { + // TODO: Propagate error somehow + } + } } } }); - }; - - // Wait for either a new connection or a shutdown signal - tokio::select! { - () = fut_accept => {} - res = fut_shutdown => { - // NOTE: We will only get a RecvError here if the serve handle is dropped, - // this means lua has garbage collected it and the user does not want - // to manually stop the server using the serve handle. Run forever. - if res.is_ok() { - break; - } - } } } }); - TableBuilder::new(lua)? - .with_value("ip", addr.ip().to_string())? - .with_value("port", addr.port())? - .with_function("stop", move |_, (): ()| match shutdown_tx.send(true) { - Ok(()) => Ok(()), - Err(_) => Err(LuaError::runtime("Server already stopped")), - })? - .build_readonly() + Ok(handle) } diff --git a/crates/lune-std-net/src/server/request.rs b/crates/lune-std-net/src/server/request.rs deleted file mode 100644 index d1df208..0000000 --- a/crates/lune-std-net/src/server/request.rs +++ /dev/null @@ -1,56 +0,0 @@ -use std::{collections::HashMap, net::SocketAddr}; - -use http::request::Parts; - -use mlua::prelude::*; - -use lune_utils::TableBuilder; - -pub(super) struct LuaRequest { - pub(super) _remote_addr: SocketAddr, - pub(super) head: Parts, - pub(super) body: Vec, -} - -impl LuaRequest { - pub fn into_lua_table(self, lua: &Lua) -> LuaResult { - let method = self.head.method.as_str().to_string(); - let path = self.head.uri.path().to_string(); - let body = lua.create_string(&self.body)?; - - #[allow(clippy::mutable_key_type)] - let query: HashMap = self - .head - .uri - .query() - .unwrap_or_default() - .split('&') - .filter_map(|q| q.split_once('=')) - .map(|(k, v)| { - let k = lua.create_string(k)?; - let v = lua.create_string(v)?; - Ok((k, v)) - }) - .collect::>()?; - - #[allow(clippy::mutable_key_type)] - let headers: HashMap = self - .head - .headers - .iter() - .map(|(k, v)| { - let k = lua.create_string(k.as_str())?; - let v = lua.create_string(v.as_bytes())?; - Ok((k, v)) - }) - .collect::>()?; - - TableBuilder::new(lua.clone())? - .with_value("method", method)? - .with_value("path", path)? - .with_value("query", query)? - .with_value("headers", headers)? - .with_value("body", body)? - .build() - } -} diff --git a/crates/lune-std-net/src/server/response.rs b/crates/lune-std-net/src/server/response.rs deleted file mode 100644 index 1240ef1..0000000 --- a/crates/lune-std-net/src/server/response.rs +++ /dev/null @@ -1,89 +0,0 @@ -use std::str::FromStr; - -use bstr::{BString, ByteSlice}; -use http_body_util::Full; -use hyper::{ - body::Bytes, - header::{HeaderName, HeaderValue}, - HeaderMap, Response, -}; - -use mlua::prelude::*; - -#[derive(Debug, Clone, Copy)] -pub(super) enum LuaResponseKind { - PlainText, - Table, -} - -pub(super) struct LuaResponse { - pub(super) kind: LuaResponseKind, - pub(super) status: u16, - pub(super) headers: HeaderMap, - pub(super) body: Option>, -} - -impl LuaResponse { - pub(super) fn into_response(self) -> LuaResult>> { - Ok(match self.kind { - LuaResponseKind::PlainText => Response::builder() - .status(200) - .header("Content-Type", "text/plain") - .body(Full::new(Bytes::from(self.body.unwrap()))) - .into_lua_err()?, - LuaResponseKind::Table => { - let mut response = Response::builder() - .status(self.status) - .body(Full::new(Bytes::from(self.body.unwrap_or_default()))) - .into_lua_err()?; - response.headers_mut().extend(self.headers); - response - } - }) - } -} - -impl FromLua for LuaResponse { - fn from_lua(value: LuaValue, _: &Lua) -> LuaResult { - match value { - // Plain strings from the handler are plaintext responses - LuaValue::String(s) => Ok(Self { - kind: LuaResponseKind::PlainText, - status: 200, - headers: HeaderMap::new(), - body: Some(s.as_bytes().to_vec()), - }), - // Tables are more detailed responses with potential status, headers, body - LuaValue::Table(t) => { - let status: Option = t.get("status")?; - let headers: Option = t.get("headers")?; - let body: Option = t.get("body")?; - - let mut headers_map = HeaderMap::new(); - if let Some(headers) = headers { - for pair in headers.pairs::() { - let (h, v) = pair?; - let name = HeaderName::from_str(&h).into_lua_err()?; - let value = HeaderValue::from_bytes(&v.as_bytes()).into_lua_err()?; - headers_map.insert(name, value); - } - } - - let body_bytes = body.map(|s| s.as_bytes().to_vec()); - - Ok(Self { - kind: LuaResponseKind::Table, - status: status.unwrap_or(200), - headers: headers_map, - body: body_bytes, - }) - } - // Anything else is an error - value => Err(LuaError::FromLuaConversionError { - from: value.type_name(), - to: "NetServeResponse".to_string(), - message: None, - }), - } - } -} diff --git a/crates/lune-std-net/src/server/service.rs b/crates/lune-std-net/src/server/service.rs index ec89572..31f7d30 100644 --- a/crates/lune-std-net/src/server/service.rs +++ b/crates/lune-std-net/src/server/service.rs @@ -1,82 +1,117 @@ use std::{future::Future, net::SocketAddr, pin::Pin}; -use http_body_util::{BodyExt, Full}; +use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream}; +use http_body_util::Full; use hyper::{ body::{Bytes, Incoming}, - service::Service, - Request, Response, + service::Service as HyperService, + Request as HyperRequest, Response as HyperResponse, StatusCode, }; -use hyper_tungstenite::{is_upgrade_request, upgrade}; use mlua::prelude::*; use mlua_luau_scheduler::{LuaSchedulerExt, LuaSpawnExt}; -use super::{ - super::websocket::NetWebSocket, keys::SvcKeys, request::LuaRequest, response::LuaResponse, +use crate::{ + server::{ + config::ServeConfig, + upgrade::{is_upgrade_request, make_upgrade_response}, + }, + shared::{hyper::HyperIo, request::Request, response::Response, websocket::Websocket}, }; #[derive(Debug, Clone)] -pub(super) struct Svc { +pub(super) struct Service { pub(super) lua: Lua, - pub(super) addr: SocketAddr, - pub(super) keys: SvcKeys, + pub(super) address: SocketAddr, // NOTE: This must be the remote address of the connected client + pub(super) config: ServeConfig, } -impl Service> for Svc { - type Response = Response>; +impl HyperService> for Service { + type Response = HyperResponse>; type Error = LuaError; type Future = Pin>>>; - fn call(&self, req: Request) -> Self::Future { - let lua = self.lua.clone(); - let addr = self.addr; - let keys = self.keys; + fn call(&self, req: HyperRequest) -> Self::Future { + if is_upgrade_request(&req) { + if let Some(handler) = self.config.handle_web_socket.clone() { + let lua = self.lua.clone(); + return Box::pin(async move { + let response = match make_upgrade_response(&req) { + Ok(res) => res, + Err(err) => { + return Ok(HyperResponse::builder() + .status(StatusCode::BAD_REQUEST) + .body(Full::new(Bytes::from(err.to_string()))) + .unwrap()) + } + }; - if keys.has_websocket_handler() && is_upgrade_request(&req) { - Box::pin(async move { - let (res, sock) = upgrade(req, None).into_lua_err()?; + lua.spawn_local({ + let lua = lua.clone(); + async move { + if let Err(_err) = handle_websocket(lua, handler, req).await { + // TODO: Propagate the error somehow? + } + } + }); - let lua_inner = lua.clone(); - lua.spawn_local(async move { - let sock = sock.await.unwrap(); - let lua_sock = NetWebSocket::new(sock); - let lua_val = lua_sock.into_lua(&lua_inner).unwrap(); - - let handler_websocket: LuaFunction = - keys.websocket_handler(&lua_inner).unwrap().unwrap(); - - lua_inner - .push_thread_back(handler_websocket, lua_val) - .unwrap(); + Ok(response) }); - - Ok(res) - }) - } else { - let (head, body) = req.into_parts(); - - Box::pin(async move { - let handler_request: LuaFunction = keys.request_handler(&lua).unwrap(); - - let body = body.collect().await.into_lua_err()?; - let body = body.to_bytes().to_vec(); - - let lua_req = LuaRequest { - _remote_addr: addr, - head, - body, - }; - let lua_req_table = lua_req.into_lua_table(&lua)?; - - let thread_id = lua.push_thread_back(handler_request, lua_req_table)?; - lua.track_thread(thread_id); - lua.wait_for_thread(thread_id).await; - let thread_res = lua - .get_thread_result(thread_id) - .expect("Missing handler thread result")?; - - LuaResponse::from_lua_multi(thread_res, &lua)?.into_response() - }) + } } + + let lua = self.lua.clone(); + let address = self.address; + let handler = self.config.handle_request.clone(); + Box::pin(async move { + match handle_request(lua, handler, req, address).await { + Ok(response) => Ok(response), + Err(_err) => { + // TODO: Propagate the error somehow? + Ok(HyperResponse::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .body(Full::new(Bytes::from("Lune: Internal server error"))) + .unwrap()) + } + } + }) } } + +async fn handle_request( + lua: Lua, + handler: LuaFunction, + request: HyperRequest, + address: SocketAddr, +) -> LuaResult>> { + let request = Request::from_incoming(request, true) + .await? + .with_address(address); + + let thread_id = lua.push_thread_back(handler, request)?; + lua.track_thread(thread_id); + lua.wait_for_thread(thread_id).await; + + let thread_res = lua + .get_thread_result(thread_id) + .expect("Missing handler thread result")?; + + let response = Response::from_lua_multi(thread_res, &lua)?; + Ok(response.into_full()) +} + +async fn handle_websocket( + lua: Lua, + handler: LuaFunction, + request: HyperRequest, +) -> LuaResult<()> { + let upgraded = hyper::upgrade::on(request).await.into_lua_err()?; + + let stream = + WebSocketStream::from_raw_socket(HyperIo::from(upgraded), Role::Server, None).await; + + let websocket = Websocket::from(stream); + lua.push_thread_back(handler, websocket)?; + + Ok(()) +} diff --git a/crates/lune-std-net/src/server/upgrade.rs b/crates/lune-std-net/src/server/upgrade.rs new file mode 100644 index 0000000..6840957 --- /dev/null +++ b/crates/lune-std-net/src/server/upgrade.rs @@ -0,0 +1,55 @@ +use async_tungstenite::tungstenite::{error::ProtocolError, handshake::derive_accept_key}; +use http_body_util::Full; + +use hyper::{ + body::{Bytes, Incoming}, + header::{HeaderName, CONNECTION, UPGRADE}, + HeaderMap, Request as HyperRequest, Response as HyperResponse, StatusCode, +}; + +const SEC_WEBSOCKET_VERSION: HeaderName = HeaderName::from_static("sec-websocket-version"); +const SEC_WEBSOCKET_KEY: HeaderName = HeaderName::from_static("sec-websocket-key"); +const SEC_WEBSOCKET_ACCEPT: HeaderName = HeaderName::from_static("sec-websocket-accept"); + +pub fn is_upgrade_request(request: &HyperRequest) -> bool { + fn check_header_contains(headers: &HeaderMap, header_name: HeaderName, value: &str) -> bool { + headers.get(header_name).is_some_and(|header| { + header.to_str().map_or_else( + |_| false, + |header_str| { + header_str + .split(',') + .any(|part| part.trim().eq_ignore_ascii_case(value)) + }, + ) + }) + } + + check_header_contains(request.headers(), CONNECTION, "Upgrade") + && check_header_contains(request.headers(), UPGRADE, "websocket") +} + +pub fn make_upgrade_response( + request: &HyperRequest, +) -> Result>, ProtocolError> { + let key = request + .headers() + .get(SEC_WEBSOCKET_KEY) + .ok_or(ProtocolError::MissingSecWebSocketKey)?; + + if request + .headers() + .get(SEC_WEBSOCKET_VERSION) + .is_none_or(|v| v.as_bytes() != b"13") + { + return Err(ProtocolError::MissingSecWebSocketVersionHeader); + } + + Ok(HyperResponse::builder() + .status(StatusCode::SWITCHING_PROTOCOLS) + .header(CONNECTION, "upgrade") + .header(UPGRADE, "websocket") + .header(SEC_WEBSOCKET_ACCEPT, derive_accept_key(key.as_bytes())) + .body(Full::new(Bytes::from("switching to websocket protocol"))) + .unwrap()) +} diff --git a/crates/lune-std-net/src/shared/body.rs b/crates/lune-std-net/src/shared/body.rs new file mode 100644 index 0000000..3ab6541 --- /dev/null +++ b/crates/lune-std-net/src/shared/body.rs @@ -0,0 +1,43 @@ +use http_body_util::{BodyExt, Full}; +use hyper::{ + body::{Bytes, Incoming}, + header::CONTENT_ENCODING, + HeaderMap, +}; + +use mlua::prelude::*; + +use lune_std_serde::{decompress, CompressDecompressFormat}; + +pub async fn handle_incoming_body( + headers: &HeaderMap, + body: Incoming, + should_decompress: bool, +) -> LuaResult<(Bytes, bool)> { + let mut body = body.collect().await.into_lua_err()?.to_bytes(); + + let was_decompressed = if should_decompress { + let decompress_format = headers + .get(CONTENT_ENCODING) + .and_then(|value| value.to_str().ok()) + .and_then(CompressDecompressFormat::detect_from_header_str); + if let Some(format) = decompress_format { + body = Bytes::from(decompress(body, format).await?); + true + } else { + false + } + } else { + false + }; + + Ok((body, was_decompressed)) +} + +pub fn bytes_to_full(bytes: Bytes) -> Full { + if bytes.is_empty() { + Full::default() + } else { + Full::new(bytes) + } +} diff --git a/crates/lune-std-net/src/shared/futures.rs b/crates/lune-std-net/src/shared/futures.rs new file mode 100644 index 0000000..f00a55c --- /dev/null +++ b/crates/lune-std-net/src/shared/futures.rs @@ -0,0 +1,19 @@ +use futures_lite::prelude::*; + +pub use http_body_util::Either; + +/** + Combines the left and right futures into a single future + that resolves to either the left or right output. + + This combinator is biased - if both futures resolve at + the same time, the left future's output is returned. +*/ +pub fn either( + left: L, + right: R, +) -> impl Future> { + let fut_left = async move { Either::Left(left.await) }; + let fut_right = async move { Either::Right(right.await) }; + fut_left.or(fut_right) +} diff --git a/crates/lune-std-net/src/shared/headers.rs b/crates/lune-std-net/src/shared/headers.rs new file mode 100644 index 0000000..de0db54 --- /dev/null +++ b/crates/lune-std-net/src/shared/headers.rs @@ -0,0 +1,88 @@ +use std::collections::HashMap; + +use hyper::{ + header::{CONTENT_ENCODING, CONTENT_LENGTH}, + HeaderMap, +}; + +use lune_utils::TableBuilder; +use mlua::prelude::*; + +pub fn create_user_agent_header(lua: &Lua) -> LuaResult { + let version_global = lua + .globals() + .get::("_VERSION") + .expect("Missing _VERSION global"); + + let version_global_str = version_global + .to_str() + .context("Invalid utf8 found in _VERSION global")?; + + let (package_name, full_version) = version_global_str.split_once(' ').unwrap(); + + Ok(format!("{}/{}", package_name.to_lowercase(), full_version)) +} + +pub fn header_map_to_table( + lua: &Lua, + headers: HeaderMap, + remove_content_headers: bool, +) -> LuaResult { + let mut string_map = HashMap::>::new(); + + for (name, value) in headers { + if let Some(name) = name { + if let Ok(value) = value.to_str() { + string_map + .entry(name.to_string()) + .or_default() + .push(value.to_owned()); + } + } + } + + hash_map_to_table(lua, string_map, remove_content_headers) +} + +pub fn hash_map_to_table( + lua: &Lua, + map: impl IntoIterator)>, + remove_content_headers: bool, +) -> LuaResult { + let mut string_map = HashMap::>::new(); + for (name, values) in map { + let name = name.as_str(); + + if remove_content_headers { + let content_encoding_header_str = CONTENT_ENCODING.as_str(); + let content_length_header_str = CONTENT_LENGTH.as_str(); + if name == content_encoding_header_str || name == content_length_header_str { + continue; + } + } + + for value in values { + let value = value.as_str(); + string_map + .entry(name.to_owned()) + .or_default() + .push(value.to_owned()); + } + } + + let mut builder = TableBuilder::new(lua.clone())?; + for (name, mut values) in string_map { + if values.len() == 1 { + let value = values.pop().unwrap().into_lua(lua)?; + builder = builder.with_value(name, value)?; + } else { + let values = TableBuilder::new(lua.clone())? + .with_sequential_values(values)? + .build_readonly()? + .into_lua(lua)?; + builder = builder.with_value(name, values)?; + } + } + + builder.build_readonly() +} diff --git a/crates/lune-std-net/src/shared/hyper.rs b/crates/lune-std-net/src/shared/hyper.rs new file mode 100644 index 0000000..12882ee --- /dev/null +++ b/crates/lune-std-net/src/shared/hyper.rs @@ -0,0 +1,198 @@ +use std::{ + future::Future, + io, + pin::Pin, + slice, + task::{Context, Poll}, + time::{Duration, Instant}, +}; + +use async_io::Timer; +use futures_lite::{prelude::*, ready}; +use hyper::rt::{self, Executor, ReadBuf, ReadBufCursor}; +use mlua::prelude::*; +use mlua_luau_scheduler::LuaSpawnExt; + +// Hyper executor that spawns futures onto our Lua scheduler + +#[derive(Debug, Clone)] +pub struct HyperExecutor { + lua: Lua, +} + +#[allow(dead_code)] +impl HyperExecutor { + pub fn execute(lua: Lua, fut: Fut) + where + Fut: Future + Send + 'static, + Fut::Output: Send + 'static, + { + let exec = if let Some(exec) = lua.app_data_ref::() { + exec + } else { + lua.set_app_data(Self { lua: lua.clone() }); + lua.app_data_ref::().unwrap() + }; + + exec.execute(fut); + } +} + +impl rt::Executor for HyperExecutor +where + Fut::Output: Send + 'static, +{ + fn execute(&self, fut: Fut) { + self.lua.spawn(fut).detach(); + } +} + +// Hyper timer & sleep future wrapper for async-io + +#[derive(Debug)] +pub struct HyperTimer; + +impl rt::Timer for HyperTimer { + fn sleep(&self, duration: Duration) -> Pin> { + Box::pin(HyperSleep::from(Timer::after(duration))) + } + + fn sleep_until(&self, at: Instant) -> Pin> { + Box::pin(HyperSleep::from(Timer::at(at))) + } + + fn reset(&self, sleep: &mut Pin>, new_deadline: Instant) { + if let Some(mut sleep) = sleep.as_mut().downcast_mut_pin::() { + sleep.inner.set_at(new_deadline); + } else { + *sleep = Box::pin(HyperSleep::from(Timer::at(new_deadline))); + } + } +} + +#[derive(Debug)] +pub struct HyperSleep { + inner: Timer, +} + +impl From for HyperSleep { + fn from(inner: Timer) -> Self { + Self { inner } + } +} + +impl Future for HyperSleep { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + match Pin::new(&mut self.inner).poll(cx) { + Poll::Ready(_) => Poll::Ready(()), + Poll::Pending => Poll::Pending, + } + } +} + +impl rt::Sleep for HyperSleep {} + +// Hyper I/O wrapper for bidirectional compatibility +// between hyper & futures-lite async read/write traits + +pin_project_lite::pin_project! { + #[derive(Debug)] + pub struct HyperIo { + #[pin] + inner: T + } +} + +impl From for HyperIo { + fn from(inner: T) -> Self { + Self { inner } + } +} + +impl HyperIo { + pub fn pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> { + self.project().inner + } +} + +// Compat for futures-lite -> hyper runtime + +impl rt::Read for HyperIo { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + mut buf: ReadBufCursor<'_>, + ) -> Poll> { + // Fill the read buffer with initialized data + let read_slice = unsafe { + let buffer = buf.as_mut(); + buffer.as_mut_ptr().write_bytes(0, buffer.len()); + slice::from_raw_parts_mut(buffer.as_mut_ptr().cast::(), buffer.len()) + }; + + // Read bytes from the underlying source + let n = match self.pin_mut().poll_read(cx, read_slice) { + Poll::Ready(Ok(n)) => n, + Poll::Ready(Err(e)) => return Poll::Ready(Err(e)), + Poll::Pending => return Poll::Pending, + }; + + unsafe { + buf.advance(n); + } + + Poll::Ready(Ok(())) + } +} + +impl rt::Write for HyperIo { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.pin_mut().poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.pin_mut().poll_flush(cx) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.pin_mut().poll_close(cx) + } +} + +// Compat for hyper runtime -> futures-lite + +impl AsyncRead for HyperIo { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let mut buf = ReadBuf::new(buf); + ready!(self.pin_mut().poll_read(cx, buf.unfilled()))?; + Poll::Ready(Ok(buf.filled().len())) + } +} + +impl AsyncWrite for HyperIo { + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + self.pin_mut().poll_write(cx, buf) + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.pin_mut().poll_flush(cx) + } + + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.pin_mut().poll_shutdown(cx) + } +} diff --git a/crates/lune-std-net/src/shared/lua.rs b/crates/lune-std-net/src/shared/lua.rs new file mode 100644 index 0000000..1f81f96 --- /dev/null +++ b/crates/lune-std-net/src/shared/lua.rs @@ -0,0 +1,57 @@ +use hyper::{ + body::Bytes, + header::{HeaderName, HeaderValue}, + HeaderMap, Method, +}; +use mlua::prelude::*; + +pub fn lua_value_to_bytes(value: &LuaValue) -> LuaResult { + match value { + LuaValue::Nil => Ok(Bytes::new()), + LuaValue::Buffer(buf) => Ok(Bytes::from(buf.to_vec())), + LuaValue::String(str) => Ok(Bytes::copy_from_slice(&str.as_bytes())), + v => Err(LuaError::FromLuaConversionError { + from: v.type_name(), + to: "Bytes".to_string(), + message: Some(format!( + "Invalid body - expected string or buffer, got {}", + v.type_name() + )), + }), + } +} + +pub fn lua_value_to_method(value: &LuaValue) -> LuaResult { + match value { + LuaValue::Nil => Ok(Method::GET), + LuaValue::String(str) => { + let bytes = str.as_bytes().trim_ascii().to_ascii_uppercase(); + Method::from_bytes(&bytes).into_lua_err() + } + LuaValue::Buffer(buf) => { + let bytes = buf.to_vec().trim_ascii().to_ascii_uppercase(); + Method::from_bytes(&bytes).into_lua_err() + } + v => Err(LuaError::FromLuaConversionError { + from: v.type_name(), + to: "Method".to_string(), + message: Some(format!( + "Invalid method - expected string or buffer, got {}", + v.type_name() + )), + }), + } +} + +pub fn lua_table_to_header_map(table: &LuaTable) -> LuaResult { + let mut headers = HeaderMap::new(); + + for pair in table.pairs::() { + let (key, val) = pair?; + let key = HeaderName::from_bytes(&key.as_bytes()).into_lua_err()?; + let val = HeaderValue::from_bytes(&val.as_bytes()).into_lua_err()?; + headers.insert(key, val); + } + + Ok(headers) +} diff --git a/crates/lune-std-net/src/shared/mod.rs b/crates/lune-std-net/src/shared/mod.rs new file mode 100644 index 0000000..ce262d6 --- /dev/null +++ b/crates/lune-std-net/src/shared/mod.rs @@ -0,0 +1,8 @@ +pub mod body; +pub mod futures; +pub mod headers; +pub mod hyper; +pub mod lua; +pub mod request; +pub mod response; +pub mod websocket; diff --git a/crates/lune-std-net/src/shared/request.rs b/crates/lune-std-net/src/shared/request.rs new file mode 100644 index 0000000..ea9cc46 --- /dev/null +++ b/crates/lune-std-net/src/shared/request.rs @@ -0,0 +1,272 @@ +use std::{collections::HashMap, net::SocketAddr}; + +use http_body_util::Full; +use url::Url; + +use hyper::{ + body::{Bytes, Incoming}, + HeaderMap, Method, Request as HyperRequest, +}; + +use mlua::prelude::*; + +use crate::shared::{ + body::{bytes_to_full, handle_incoming_body}, + headers::{hash_map_to_table, header_map_to_table}, + lua::{lua_table_to_header_map, lua_value_to_bytes, lua_value_to_method}, +}; + +#[derive(Debug, Clone)] +pub struct RequestOptions { + pub decompress: bool, +} + +impl Default for RequestOptions { + fn default() -> Self { + Self { decompress: true } + } +} + +impl FromLua for RequestOptions { + fn from_lua(value: LuaValue, _: &Lua) -> LuaResult { + if let LuaValue::Nil = value { + // Nil means default options + Ok(Self::default()) + } else if let LuaValue::Table(tab) = value { + // Table means custom options + let decompress = match tab.get::>("decompress") { + Ok(decomp) => Ok(decomp.unwrap_or(true)), + Err(_) => Err(LuaError::RuntimeError( + "Invalid option value for 'decompress' in request options".to_string(), + )), + }?; + Ok(Self { decompress }) + } else { + // Anything else is invalid + Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "RequestOptions".to_string(), + message: Some(format!( + "Invalid request options - expected table or nil, got {}", + value.type_name() + )), + }) + } + } +} + +#[derive(Debug, Clone)] +pub struct Request { + // NOTE: We use Bytes instead of Full to avoid + // needing async when getting a reference to the body + pub(crate) inner: HyperRequest, + pub(crate) address: Option, + pub(crate) redirects: Option, + pub(crate) decompress: bool, +} + +impl Request { + /** + Creates a new request from a raw incoming request. + */ + pub async fn from_incoming( + incoming: HyperRequest, + decompress: bool, + ) -> LuaResult { + let (parts, body) = incoming.into_parts(); + + let (body, decompress) = handle_incoming_body(&parts.headers, body, decompress).await?; + + Ok(Self { + inner: HyperRequest::from_parts(parts, body), + address: None, + redirects: None, + decompress, + }) + } + + /** + Attaches a socket address to the request. + + This will make the `ip` and `port` fields available on the request. + */ + pub fn with_address(mut self, address: SocketAddr) -> Self { + self.address = Some(address); + self + } + + /** + Returns the method of the request. + */ + pub fn method(&self) -> Method { + self.inner.method().clone() + } + + /** + Returns the path of the request. + */ + pub fn path(&self) -> &str { + self.inner.uri().path() + } + + /** + Returns the query parameters of the request. + */ + pub fn query(&self) -> HashMap> { + let uri = self.inner.uri(); + let url = uri.to_string().parse::().expect("uri is valid"); + + let mut result = HashMap::>::new(); + for (key, value) in url.query_pairs() { + result + .entry(key.into_owned()) + .or_default() + .push(value.into_owned()); + } + result + } + + /** + Returns the headers of the request. + */ + pub fn headers(&self) -> &HeaderMap { + self.inner.headers() + } + + /** + Returns the body of the request. + */ + pub fn body(&self) -> &[u8] { + self.inner.body() + } + + /** + Clones the inner `hyper` request with its body + type modified to `Full` for sending. + */ + #[allow(dead_code)] + pub fn as_full(&self) -> HyperRequest> { + let mut builder = HyperRequest::builder() + .version(self.inner.version()) + .method(self.inner.method()) + .uri(self.inner.uri()); + + builder + .headers_mut() + .expect("request was valid") + .extend(self.inner.headers().clone()); + + let body = bytes_to_full(self.inner.body().clone()); + builder.body(body).expect("request was valid") + } + + /** + Takes the inner `hyper` request with its body + type modified to `Full` for sending. + */ + #[allow(dead_code)] + pub fn into_full(self) -> HyperRequest> { + let (parts, body) = self.inner.into_parts(); + HyperRequest::from_parts(parts, bytes_to_full(body)) + } +} + +impl FromLua for Request { + fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult { + if let LuaValue::String(s) = value { + // If we just got a string we assume + // its a GET request to a given url + let uri = s.to_str()?; + let uri = uri.parse().into_lua_err()?; + + let mut request = HyperRequest::new(Bytes::new()); + *request.uri_mut() = uri; + + Ok(Self { + inner: request, + address: None, + redirects: None, + decompress: RequestOptions::default().decompress, + }) + } else if let LuaValue::Table(tab) = value { + // If we got a table we are able to configure the + // entire request, maybe with extra options too + let options = match tab.get::("options") { + Ok(opts) => RequestOptions::from_lua(opts, lua)?, + Err(_) => RequestOptions::default(), + }; + + // Extract url (required) + optional structured query params + let url = tab.get::("url")?; + let mut url = url.to_str()?.parse::().into_lua_err()?; + if let Some(t) = tab.get::>("query")? { + let mut query = url.query_pairs_mut(); + for pair in t.pairs::() { + let (key, value) = pair?; + let key = key.to_str()?; + let value = value.to_str()?; + query.append_pair(&key, &value); + } + } + + // Extract method + let method = tab.get::("method")?; + let method = lua_value_to_method(&method)?; + + // Extract headers + let headers = tab.get::>("headers")?; + let headers = headers + .map(|t| lua_table_to_header_map(&t)) + .transpose()? + .unwrap_or_default(); + + // Extract body + let body = tab.get::("body")?; + let body = lua_value_to_bytes(&body)?; + + // Build the full request + let mut request = HyperRequest::new(body); + request.headers_mut().extend(headers); + *request.uri_mut() = url.to_string().parse().unwrap(); + *request.method_mut() = method; + + // All good, validated and we got what we need + Ok(Self { + inner: request, + address: None, + redirects: None, + decompress: options.decompress, + }) + } else { + // Anything else is invalid + Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "Request".to_string(), + message: Some(format!( + "Invalid request - expected string or table, got {}", + value.type_name() + )), + }) + } + } +} + +impl LuaUserData for Request { + fn add_fields>(fields: &mut F) { + fields.add_field_method_get("ip", |_, this| { + Ok(this.address.map(|address| address.ip().to_string())) + }); + fields.add_field_method_get("port", |_, this| { + Ok(this.address.map(|address| address.port())) + }); + fields.add_field_method_get("method", |_, this| Ok(this.method().to_string())); + fields.add_field_method_get("path", |_, this| Ok(this.path().to_string())); + fields.add_field_method_get("query", |lua, this| { + hash_map_to_table(lua, this.query(), false) + }); + fields.add_field_method_get("headers", |lua, this| { + header_map_to_table(lua, this.headers().clone(), this.decompress) + }); + fields.add_field_method_get("body", |lua, this| lua.create_string(this.body())); + } +} diff --git a/crates/lune-std-net/src/shared/response.rs b/crates/lune-std-net/src/shared/response.rs new file mode 100644 index 0000000..75a5687 --- /dev/null +++ b/crates/lune-std-net/src/shared/response.rs @@ -0,0 +1,172 @@ +use http_body_util::Full; + +use hyper::{ + body::{Bytes, Incoming}, + header::{HeaderValue, CONTENT_TYPE}, + HeaderMap, Response as HyperResponse, StatusCode, +}; + +use mlua::prelude::*; + +use crate::shared::{ + body::{bytes_to_full, handle_incoming_body}, + headers::header_map_to_table, + lua::{lua_table_to_header_map, lua_value_to_bytes}, +}; + +#[derive(Debug, Clone)] +pub struct Response { + // NOTE: We use Bytes instead of Full to avoid + // needing async when getting a reference to the body + pub(crate) inner: HyperResponse, + pub(crate) decompressed: bool, +} + +impl Response { + /** + Creates a new response from a raw incoming response. + */ + pub async fn from_incoming( + incoming: HyperResponse, + decompress: bool, + ) -> LuaResult { + let (parts, body) = incoming.into_parts(); + + let (body, decompressed) = handle_incoming_body(&parts.headers, body, decompress).await?; + + Ok(Self { + inner: HyperResponse::from_parts(parts, body), + decompressed, + }) + } + + /** + Returns whether the request was successful or not. + */ + pub fn status_ok(&self) -> bool { + self.inner.status().is_success() + } + + /** + Returns the status code of the response. + */ + pub fn status_code(&self) -> u16 { + self.inner.status().as_u16() + } + + /** + Returns the status message of the response. + */ + pub fn status_message(&self) -> &str { + self.inner.status().canonical_reason().unwrap_or_default() + } + + /** + Returns the headers of the response. + */ + pub fn headers(&self) -> &HeaderMap { + self.inner.headers() + } + + /** + Returns the body of the response. + */ + pub fn body(&self) -> &[u8] { + self.inner.body() + } + + /** + Clones the inner `hyper` response with its body + type modified to `Full` for sending. + */ + #[allow(dead_code)] + pub fn as_full(&self) -> HyperResponse> { + let mut builder = HyperResponse::builder() + .version(self.inner.version()) + .status(self.inner.status()); + + builder + .headers_mut() + .expect("request was valid") + .extend(self.inner.headers().clone()); + + let body = bytes_to_full(self.inner.body().clone()); + builder.body(body).expect("request was valid") + } + + /** + Takes the inner `hyper` response with its body + type modified to `Full` for sending. + */ + #[allow(dead_code)] + pub fn into_full(self) -> HyperResponse> { + let (parts, body) = self.inner.into_parts(); + HyperResponse::from_parts(parts, bytes_to_full(body)) + } +} + +impl FromLua for Response { + fn from_lua(value: LuaValue, _: &Lua) -> LuaResult { + if let Ok(body) = lua_value_to_bytes(&value) { + // String or buffer is always a 200 text/plain response + let mut response = HyperResponse::new(body); + response + .headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("text/plain")); + Ok(Self { + inner: response, + decompressed: false, + }) + } else if let LuaValue::Table(tab) = value { + // Extract status (required) + let status = tab.get::("status")?; + let status = StatusCode::from_u16(status).into_lua_err()?; + + // Extract headers + let headers = tab.get::>("headers")?; + let headers = headers + .map(|t| lua_table_to_header_map(&t)) + .transpose()? + .unwrap_or_default(); + + // Extract body + let body = tab.get::("body")?; + let body = lua_value_to_bytes(&body)?; + + // Build the full response + let mut response = HyperResponse::new(body); + response.headers_mut().extend(headers); + *response.status_mut() = status; + + // All good, validated and we got what we need + Ok(Self { + inner: response, + decompressed: false, + }) + } else { + // Anything else is invalid + Err(LuaError::FromLuaConversionError { + from: value.type_name(), + to: "Response".to_string(), + message: Some(format!( + "Invalid response - expected table/string/buffer, got {}", + value.type_name() + )), + }) + } + } +} + +impl LuaUserData for Response { + fn add_fields>(fields: &mut F) { + fields.add_field_method_get("ok", |_, this| Ok(this.status_ok())); + fields.add_field_method_get("statusCode", |_, this| Ok(this.status_code())); + fields.add_field_method_get("statusMessage", |lua, this| { + lua.create_string(this.status_message()) + }); + fields.add_field_method_get("headers", |lua, this| { + header_map_to_table(lua, this.headers().clone(), this.decompressed) + }); + fields.add_field_method_get("body", |lua, this| lua.create_string(this.body())); + } +} diff --git a/crates/lune-std-net/src/websocket.rs b/crates/lune-std-net/src/shared/websocket.rs similarity index 58% rename from crates/lune-std-net/src/websocket.rs rename to crates/lune-std-net/src/shared/websocket.rs index 35ecae1..1d22947 100644 --- a/crates/lune-std-net/src/websocket.rs +++ b/crates/lune-std-net/src/shared/websocket.rs @@ -1,62 +1,38 @@ -use std::sync::{ - atomic::{AtomicBool, AtomicU16, Ordering}, - Arc, +use std::{ + error::Error, + sync::{ + atomic::{AtomicBool, AtomicU16, Ordering}, + Arc, + }, }; +use async_lock::Mutex as AsyncMutex; +use async_tungstenite::tungstenite::{ + protocol::{frame::coding::CloseCode, CloseFrame}, + Message as TungsteniteMessage, Result as TungsteniteResult, Utf8Bytes, +}; use bstr::{BString, ByteSlice}; +use futures::{ + stream::{SplitSink, SplitStream}, + Sink, SinkExt, Stream, StreamExt, +}; +use hyper::body::Bytes; + use mlua::prelude::*; -use futures_util::{ - stream::{SplitSink, SplitStream}, - SinkExt, StreamExt, -}; -use tokio::{ - io::{AsyncRead, AsyncWrite}, - sync::Mutex as AsyncMutex, -}; - -use hyper_tungstenite::{ - tungstenite::{ - protocol::{frame::coding::CloseCode as WsCloseCode, CloseFrame as WsCloseFrame}, - Message as WsMessage, - }, - WebSocketStream, -}; - -#[derive(Debug)] -pub struct NetWebSocket { +#[derive(Debug, Clone)] +pub struct Websocket { close_code_exists: Arc, close_code_value: Arc, - read_stream: Arc>>>, - write_stream: Arc, WsMessage>>>, + read_stream: Arc>>, + write_stream: Arc>>, } -impl Clone for NetWebSocket { - fn clone(&self) -> Self { - Self { - close_code_exists: Arc::clone(&self.close_code_exists), - close_code_value: Arc::clone(&self.close_code_value), - read_stream: Arc::clone(&self.read_stream), - write_stream: Arc::clone(&self.write_stream), - } - } -} - -impl NetWebSocket +impl Websocket where - T: AsyncRead + AsyncWrite + Unpin + 'static, + T: Stream> + Sink + 'static, + >::Error: Into>, { - pub fn new(value: WebSocketStream) -> Self { - let (write, read) = value.split(); - - Self { - close_code_exists: Arc::new(AtomicBool::new(false)), - close_code_value: Arc::new(AtomicU16::new(0)), - read_stream: Arc::new(AsyncMutex::new(read)), - write_stream: Arc::new(AsyncMutex::new(write)), - } - } - fn get_close_code(&self) -> Option { if self.close_code_exists.load(Ordering::Relaxed) { Some(self.close_code_value.load(Ordering::Relaxed)) @@ -70,12 +46,12 @@ where self.close_code_value.store(code, Ordering::Relaxed); } - pub async fn send(&self, msg: WsMessage) -> LuaResult<()> { + pub async fn send(&self, msg: TungsteniteMessage) -> LuaResult<()> { let mut ws = self.write_stream.lock().await; ws.send(msg).await.into_lua_err() } - pub async fn next(&self) -> LuaResult> { + pub async fn next(&self) -> LuaResult> { let mut ws = self.read_stream.lock().await; ws.next().await.transpose().into_lua_err() } @@ -85,15 +61,15 @@ where return Err(LuaError::runtime("Socket has already been closed")); } - self.send(WsMessage::Close(Some(WsCloseFrame { + self.send(TungsteniteMessage::Close(Some(CloseFrame { code: match code { - Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code), + Some(code) if (1000..=4999).contains(&code) => CloseCode::from(code), Some(code) => { return Err(LuaError::runtime(format!( "Close code must be between 1000 and 4999, got {code}" ))) } - None => WsCloseCode::Normal, + None => CloseCode::Normal, }, reason: "".into(), }))) @@ -104,9 +80,27 @@ where } } -impl LuaUserData for NetWebSocket +impl From for Websocket where - T: AsyncRead + AsyncWrite + Unpin + 'static, + T: Stream> + Sink + 'static, + >::Error: Into>, +{ + fn from(value: T) -> Self { + let (write, read) = value.split(); + + Self { + close_code_exists: Arc::new(AtomicBool::new(false)), + close_code_value: Arc::new(AtomicU16::new(0)), + read_stream: Arc::new(AsyncMutex::new(read)), + write_stream: Arc::new(AsyncMutex::new(write)), + } + } +} + +impl LuaUserData for Websocket +where + T: Stream> + Sink + 'static, + >::Error: Into>, { fn add_fields>(fields: &mut F) { fields.add_field_method_get("closeCode", |_, this| Ok(this.get_close_code())); @@ -121,10 +115,10 @@ where "send", |_, this, (string, as_binary): (BString, Option)| async move { this.send(if as_binary.unwrap_or_default() { - WsMessage::Binary(string.as_bytes().to_vec()) + TungsteniteMessage::Binary(Bytes::from(string.to_vec())) } else { let s = string.to_str().into_lua_err()?; - WsMessage::Text(s.to_string()) + TungsteniteMessage::Text(Utf8Bytes::from(s)) }) .await }, @@ -133,14 +127,14 @@ where methods.add_async_method("next", |lua, this, (): ()| async move { let msg = this.next().await?; - if let Some(WsMessage::Close(Some(frame))) = msg.as_ref() { + if let Some(TungsteniteMessage::Close(Some(frame))) = msg.as_ref() { this.set_close_code(frame.code.into()); } Ok(match msg { - Some(WsMessage::Binary(bin)) => LuaValue::String(lua.create_string(bin)?), - Some(WsMessage::Text(txt)) => LuaValue::String(lua.create_string(txt)?), - Some(WsMessage::Close(_)) | None => LuaValue::Nil, + Some(TungsteniteMessage::Binary(bin)) => LuaValue::String(lua.create_string(bin)?), + Some(TungsteniteMessage::Text(txt)) => LuaValue::String(lua.create_string(txt)?), + Some(TungsteniteMessage::Close(_)) | None => LuaValue::Nil, // Ignore ping/pong/frame messages, they are handled by tungstenite msg => unreachable!("Unhandled message: {:?}", msg), }) diff --git a/crates/lune-std-net/src/url/decode.rs b/crates/lune-std-net/src/url/decode.rs new file mode 100644 index 0000000..b64523b --- /dev/null +++ b/crates/lune-std-net/src/url/decode.rs @@ -0,0 +1,12 @@ +use mlua::prelude::*; + +pub fn decode(lua_string: LuaString, as_binary: bool) -> LuaResult> { + if as_binary { + Ok(urlencoding::decode_binary(&lua_string.as_bytes()).into_owned()) + } else { + Ok(urlencoding::decode(&lua_string.to_str()?) + .map_err(|e| LuaError::RuntimeError(format!("Encountered invalid encoding - {e}")))? + .into_owned() + .into_bytes()) + } +} diff --git a/crates/lune-std-net/src/url/encode.rs b/crates/lune-std-net/src/url/encode.rs new file mode 100644 index 0000000..56e8995 --- /dev/null +++ b/crates/lune-std-net/src/url/encode.rs @@ -0,0 +1,13 @@ +use mlua::prelude::*; + +pub fn encode(lua_string: LuaString, as_binary: bool) -> LuaResult> { + if as_binary { + Ok(urlencoding::encode_binary(&lua_string.as_bytes()) + .into_owned() + .into_bytes()) + } else { + Ok(urlencoding::encode(&lua_string.to_str()?) + .into_owned() + .into_bytes()) + } +} diff --git a/crates/lune-std-net/src/url/mod.rs b/crates/lune-std-net/src/url/mod.rs new file mode 100644 index 0000000..de58cce --- /dev/null +++ b/crates/lune-std-net/src/url/mod.rs @@ -0,0 +1,5 @@ +mod decode; +mod encode; + +pub use self::decode::decode; +pub use self::encode::encode; diff --git a/crates/lune-std-net/src/util.rs b/crates/lune-std-net/src/util.rs deleted file mode 100644 index 9fe01a0..0000000 --- a/crates/lune-std-net/src/util.rs +++ /dev/null @@ -1,94 +0,0 @@ -use std::collections::HashMap; - -use hyper::header::{CONTENT_ENCODING, CONTENT_LENGTH}; -use reqwest::header::HeaderMap; - -use mlua::prelude::*; - -use lune_utils::TableBuilder; - -pub fn create_user_agent_header(lua: &Lua) -> LuaResult { - let version_global = lua - .globals() - .get::("_VERSION") - .expect("Missing _VERSION global"); - - let version_global_str = version_global - .to_str() - .context("Invalid utf8 found in _VERSION global")?; - - let (package_name, full_version) = version_global_str.split_once(' ').unwrap(); - - Ok(format!("{}/{}", package_name.to_lowercase(), full_version)) -} - -pub fn header_map_to_table( - lua: &Lua, - headers: HeaderMap, - remove_content_headers: bool, -) -> LuaResult { - let mut res_headers: HashMap> = HashMap::new(); - for (name, value) in &headers { - let name = name.as_str(); - let value = value.to_str().unwrap().to_owned(); - if let Some(existing) = res_headers.get_mut(name) { - existing.push(value); - } else { - res_headers.insert(name.to_owned(), vec![value]); - } - } - - if remove_content_headers { - let content_encoding_header_str = CONTENT_ENCODING.as_str(); - let content_length_header_str = CONTENT_LENGTH.as_str(); - res_headers.retain(|name, _| { - name != content_encoding_header_str && name != content_length_header_str - }); - } - - let mut builder = TableBuilder::new(lua.clone())?; - for (name, mut values) in res_headers { - if values.len() == 1 { - let value = values.pop().unwrap().into_lua(lua)?; - builder = builder.with_value(name, value)?; - } else { - let values = TableBuilder::new(lua.clone())? - .with_sequential_values(values)? - .build_readonly()? - .into_lua(lua)?; - builder = builder.with_value(name, values)?; - } - } - - builder.build_readonly() -} - -pub fn table_to_hash_map( - tab: LuaTable, - tab_origin_key: &'static str, -) -> LuaResult>> { - let mut map = HashMap::new(); - - for pair in tab.pairs::() { - let (key, value) = pair?; - match value { - LuaValue::String(s) => { - map.insert(key, vec![s.to_str()?.to_owned()]); - } - LuaValue::Table(t) => { - let mut values = Vec::new(); - for value in t.sequence_values::() { - values.push(value?.to_str()?.to_owned()); - } - map.insert(key, values); - } - _ => { - return Err(LuaError::runtime(format!( - "Value for '{tab_origin_key}' must be a string or array of strings", - ))) - } - } - } - - Ok(map) -} diff --git a/crates/lune/src/tests.rs b/crates/lune/src/tests.rs index e4809e8..cadca69 100644 --- a/crates/lune/src/tests.rs +++ b/crates/lune/src/tests.rs @@ -127,13 +127,19 @@ create_tests! { net_request_methods: "net/request/methods", net_request_query: "net/request/query", net_request_redirect: "net/request/redirect", - net_url_encode: "net/url/encode", - net_url_decode: "net/url/decode", + + net_serve_addresses: "net/serve/addresses", + net_serve_handles: "net/serve/handles", + net_serve_non_blocking: "net/serve/non_blocking", net_serve_requests: "net/serve/requests", net_serve_websockets: "net/serve/websockets", + net_socket_basic: "net/socket/basic", net_socket_wss: "net/socket/wss", net_socket_wss_rw: "net/socket/wss_rw", + + net_url_encode: "net/url/encode", + net_url_decode: "net/url/decode", } #[cfg(feature = "std-process")] diff --git a/crates/mlua-luau-scheduler/Cargo.toml b/crates/mlua-luau-scheduler/Cargo.toml index 9952f75..6eb0fb0 100644 --- a/crates/mlua-luau-scheduler/Cargo.toml +++ b/crates/mlua-luau-scheduler/Cargo.toml @@ -16,13 +16,12 @@ path = "src/lib.rs" workspace = true [dependencies] -async-executor = "1.8" -blocking = "1.5" -concurrent-queue = "2.4" -derive_more = "0.99" -event-listener = "4.0" -futures-lite = "2.2" -rustc-hash = "1.1" +async-executor = "1.13" +blocking = "1.6" +concurrent-queue = "2.5" +event-listener = "5.4" +futures-lite = "2.6" +rustc-hash = "2.1" tracing = "0.1" mlua = { version = "0.10.3", features = [ @@ -34,7 +33,7 @@ mlua = { version = "0.10.3", features = [ [dev-dependencies] async-fs = "2.1" -async-io = "2.3" +async-io = "2.4" tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-tracy = "0.11" diff --git a/crates/mlua-luau-scheduler/src/functions.rs b/crates/mlua-luau-scheduler/src/functions.rs index 307d0f1..cb26a23 100644 --- a/crates/mlua-luau-scheduler/src/functions.rs +++ b/crates/mlua-luau-scheduler/src/functions.rs @@ -8,7 +8,7 @@ use crate::{ result_map::ThreadResultMap, thread_id::ThreadId, traits::LuaSchedulerExt, - util::{is_poll_pending, LuaThreadOrFunction, ThreadResult}, + util::{is_poll_pending, LuaThreadOrFunction}, }; const ERR_METADATA_NOT_ATTACHED: &str = "\ @@ -123,8 +123,7 @@ impl Functions { if thread.status() != LuaThreadStatus::Resumable { let id = ThreadId::from(&thread); if resume_map.is_tracked(id) { - let res = ThreadResult::new(Ok(v.clone()), lua); - resume_map.insert(id, res); + resume_map.insert(id, Ok(v.clone())); } } (true, v).into_lua_multi(lua) @@ -134,8 +133,7 @@ impl Functions { // Not pending, store the error let id = ThreadId::from(&thread); if resume_map.is_tracked(id) { - let res = ThreadResult::new(Err(e.clone()), lua); - resume_map.insert(id, res); + resume_map.insert(id, Err(e.clone())); } (false, e.to_string()).into_lua_multi(lua) } @@ -177,8 +175,7 @@ impl Functions { if thread.status() != LuaThreadStatus::Resumable { let id = ThreadId::from(&thread); if spawn_map.is_tracked(id) { - let res = ThreadResult::new(Ok(v), lua); - spawn_map.insert(id, res); + spawn_map.insert(id, Ok(v)); } } } @@ -188,8 +185,7 @@ impl Functions { // Not pending, store the error let id = ThreadId::from(&thread); if spawn_map.is_tracked(id) { - let res = ThreadResult::new(Err(e), lua); - spawn_map.insert(id, res); + spawn_map.insert(id, Err(e)); } } } diff --git a/crates/mlua-luau-scheduler/src/queue.rs b/crates/mlua-luau-scheduler/src/queue.rs index 3d17249..4f22f28 100644 --- a/crates/mlua-luau-scheduler/src/queue.rs +++ b/crates/mlua-luau-scheduler/src/queue.rs @@ -1,12 +1,15 @@ -use std::{pin::Pin, rc::Rc}; +use std::{ + ops::{Deref, DerefMut}, + pin::Pin, + rc::Rc, +}; use concurrent_queue::ConcurrentQueue; -use derive_more::{Deref, DerefMut}; use event_listener::Event; use futures_lite::{Future, FutureExt}; use mlua::prelude::*; -use crate::{traits::IntoLuaThread, util::ThreadWithArgs, ThreadId}; +use crate::{traits::IntoLuaThread, ThreadId}; /** Queue for storing [`LuaThread`]s with associated arguments. @@ -16,15 +19,13 @@ use crate::{traits::IntoLuaThread, util::ThreadWithArgs, ThreadId}; */ #[derive(Debug, Clone)] pub(crate) struct ThreadQueue { - queue: Rc>, - event: Rc, + inner: Rc, } impl ThreadQueue { pub fn new() -> Self { - let queue = Rc::new(ConcurrentQueue::unbounded()); - let event = Rc::new(Event::new()); - Self { queue, event } + let inner = Rc::new(ThreadQueueInner::new()); + Self { inner } } pub fn push_item( @@ -38,32 +39,25 @@ impl ThreadQueue { tracing::trace!("pushing item to queue with {} args", args.len()); let id = ThreadId::from(&thread); - let stored = ThreadWithArgs::new(lua, thread, args)?; - self.queue.push(stored).into_lua_err()?; - self.event.notify(usize::MAX); + let _ = self.inner.queue.push((thread, args)); + self.inner.event.notify(usize::MAX); Ok(id) } #[inline] - pub fn drain_items<'outer, 'lua>( - &'outer self, - lua: &'lua Lua, - ) -> impl Iterator + 'outer - where - 'lua: 'outer, - { - self.queue.try_iter().map(|stored| stored.into_inner(lua)) + pub fn drain_items(&self) -> impl Iterator + '_ { + self.inner.queue.try_iter() } #[inline] pub async fn wait_for_item(&self) { - if self.queue.is_empty() { - let listener = self.event.listen(); + if self.inner.queue.is_empty() { + let listener = self.inner.event.listen(); // NOTE: Need to check again, we could have gotten // new queued items while creating our listener - if self.queue.is_empty() { + if self.inner.queue.is_empty() { listener.await; } } @@ -71,14 +65,14 @@ impl ThreadQueue { #[inline] pub fn is_empty(&self) -> bool { - self.queue.is_empty() + self.inner.queue.is_empty() } } /** Alias for [`ThreadQueue`], providing a newtype to store in Lua app data. */ -#[derive(Debug, Clone, Deref, DerefMut)] +#[derive(Debug, Clone)] pub(crate) struct SpawnedThreadQueue(ThreadQueue); impl SpawnedThreadQueue { @@ -87,10 +81,23 @@ impl SpawnedThreadQueue { } } +impl Deref for SpawnedThreadQueue { + type Target = ThreadQueue; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for SpawnedThreadQueue { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + /** Alias for [`ThreadQueue`], providing a newtype to store in Lua app data. */ -#[derive(Debug, Clone, Deref, DerefMut)] +#[derive(Debug, Clone)] pub(crate) struct DeferredThreadQueue(ThreadQueue); impl DeferredThreadQueue { @@ -99,6 +106,19 @@ impl DeferredThreadQueue { } } +impl Deref for DeferredThreadQueue { + type Target = ThreadQueue; + fn deref(&self) -> &Self::Target { + &self.0 + } +} + +impl DerefMut for DeferredThreadQueue { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } +} + pub type LocalBoxFuture<'fut> = Pin + 'fut>>; /** @@ -109,31 +129,60 @@ pub type LocalBoxFuture<'fut> = Pin + 'fut>>; */ #[derive(Debug, Clone)] pub(crate) struct FuturesQueue<'fut> { - queue: Rc>>, - event: Rc, + inner: Rc>, } impl<'fut> FuturesQueue<'fut> { pub fn new() -> Self { - let queue = Rc::new(ConcurrentQueue::unbounded()); - let event = Rc::new(Event::new()); - Self { queue, event } + let inner = Rc::new(FuturesQueueInner::new()); + Self { inner } } pub fn push_item(&self, fut: impl Future + 'fut) { - let _ = self.queue.push(fut.boxed_local()); - self.event.notify(usize::MAX); + let _ = self.inner.queue.push(fut.boxed_local()); + self.inner.event.notify(usize::MAX); } pub fn drain_items<'outer>( &'outer self, ) -> impl Iterator> + 'outer { - self.queue.try_iter() + self.inner.queue.try_iter() } pub async fn wait_for_item(&self) { - if self.queue.is_empty() { - self.event.listen().await; + if self.inner.queue.is_empty() { + self.inner.event.listen().await; } } } + +// Inner structs without ref counting so that outer structs +// have only a single ref counter for extremely cheap clones + +#[derive(Debug)] +struct ThreadQueueInner { + queue: ConcurrentQueue<(LuaThread, LuaMultiValue)>, + event: Event, +} + +impl ThreadQueueInner { + fn new() -> Self { + let queue = ConcurrentQueue::unbounded(); + let event = Event::new(); + Self { queue, event } + } +} + +#[derive(Debug)] +struct FuturesQueueInner<'fut> { + queue: ConcurrentQueue>, + event: Event, +} + +impl FuturesQueueInner<'_> { + pub fn new() -> Self { + let queue = ConcurrentQueue::unbounded(); + let event = Event::new(); + Self { queue, event } + } +} diff --git a/crates/mlua-luau-scheduler/src/result_map.rs b/crates/mlua-luau-scheduler/src/result_map.rs index fe08a5f..39a91b1 100644 --- a/crates/mlua-luau-scheduler/src/result_map.rs +++ b/crates/mlua-luau-scheduler/src/result_map.rs @@ -5,60 +5,77 @@ use std::{cell::RefCell, rc::Rc}; use event_listener::Event; // NOTE: This is the hash algorithm that mlua also uses, so we // are not adding any additional dependencies / bloat by using it. +use mlua::prelude::*; use rustc_hash::{FxHashMap, FxHashSet}; -use crate::{thread_id::ThreadId, util::ThreadResult}; +use crate::thread_id::ThreadId; + +struct ThreadResultMapInner { + tracked: FxHashSet, + results: FxHashMap>, + events: FxHashMap>, +} + +impl ThreadResultMapInner { + fn new() -> Self { + Self { + tracked: FxHashSet::default(), + results: FxHashMap::default(), + events: FxHashMap::default(), + } + } +} #[derive(Clone)] pub(crate) struct ThreadResultMap { - tracked: Rc>>, - results: Rc>>, - events: Rc>>>, + inner: Rc>, } impl ThreadResultMap { pub fn new() -> Self { - Self { - tracked: Rc::new(RefCell::new(FxHashSet::default())), - results: Rc::new(RefCell::new(FxHashMap::default())), - events: Rc::new(RefCell::new(FxHashMap::default())), - } + let inner = Rc::new(RefCell::new(ThreadResultMapInner::new())); + Self { inner } } #[inline(always)] pub fn track(&self, id: ThreadId) { - self.tracked.borrow_mut().insert(id); + self.inner.borrow_mut().tracked.insert(id); } #[inline(always)] pub fn is_tracked(&self, id: ThreadId) -> bool { - self.tracked.borrow().contains(&id) + self.inner.borrow().tracked.contains(&id) } - pub fn insert(&self, id: ThreadId, result: ThreadResult) { + pub fn insert(&self, id: ThreadId, result: LuaResult) { debug_assert!(self.is_tracked(id), "Thread must be tracked"); - self.results.borrow_mut().insert(id, result); - if let Some(event) = self.events.borrow_mut().remove(&id) { + let mut inner = self.inner.borrow_mut(); + inner.results.insert(id, result); + if let Some(event) = inner.events.remove(&id) { event.notify(usize::MAX); } } pub async fn listen(&self, id: ThreadId) { debug_assert!(self.is_tracked(id), "Thread must be tracked"); - if !self.results.borrow().contains_key(&id) { + if !self.inner.borrow().results.contains_key(&id) { let listener = { - let mut events = self.events.borrow_mut(); - let event = events.entry(id).or_insert_with(|| Rc::new(Event::new())); + let mut inner = self.inner.borrow_mut(); + let event = inner + .events + .entry(id) + .or_insert_with(|| Rc::new(Event::new())); event.listen() }; listener.await; } } - pub fn remove(&self, id: ThreadId) -> Option { - let res = self.results.borrow_mut().remove(&id)?; - self.tracked.borrow_mut().remove(&id); - self.events.borrow_mut().remove(&id); + pub fn remove(&self, id: ThreadId) -> Option> { + let mut inner = self.inner.borrow_mut(); + let res = inner.results.remove(&id)?; + inner.tracked.remove(&id); + inner.events.remove(&id); Some(res) } } diff --git a/crates/mlua-luau-scheduler/src/scheduler.rs b/crates/mlua-luau-scheduler/src/scheduler.rs index 7e1eded..7eeaa7d 100644 --- a/crates/mlua-luau-scheduler/src/scheduler.rs +++ b/crates/mlua-luau-scheduler/src/scheduler.rs @@ -2,7 +2,7 @@ use std::{ cell::Cell, - rc::{Rc, Weak as WeakRc}, + rc::Rc, sync::{Arc, Weak as WeakArc}, thread::panicking, }; @@ -21,7 +21,7 @@ use crate::{ status::Status, thread_id::ThreadId, traits::IntoLuaThread, - util::{run_until_yield, ThreadResult}, + util::run_until_yield, }; const ERR_METADATA_ALREADY_ATTACHED: &str = "\ @@ -248,7 +248,7 @@ impl Scheduler { */ #[must_use] pub fn get_thread_result(&self, id: ThreadId) -> Option> { - self.result_map.remove(id).map(|r| r.value(&self.lua)) + self.result_map.remove(id) } /** @@ -286,7 +286,7 @@ impl Scheduler { */ let local_exec = LocalExecutor::new(); let main_exec = Arc::new(Executor::new()); - let fut_queue = Rc::new(FuturesQueue::new()); + let fut_queue = FuturesQueue::new(); /* Store the main executor and queue in Lua, so that they may be used with LuaSchedulerExt. @@ -299,12 +299,12 @@ impl Scheduler { "{ERR_METADATA_ALREADY_ATTACHED}" ); assert!( - self.lua.app_data_ref::>().is_none(), + self.lua.app_data_ref::().is_none(), "{ERR_METADATA_ALREADY_ATTACHED}" ); self.lua.set_app_data(Arc::downgrade(&main_exec)); - self.lua.set_app_data(Rc::downgrade(&fut_queue.clone())); + self.lua.set_app_data(fut_queue.clone()); /* Manually tick the Lua executor, while running under the main executor. @@ -342,8 +342,7 @@ impl Scheduler { self.error_callback.call(e); } if thread.status() != LuaThreadStatus::Resumable { - let thread_res = ThreadResult::new(res, &self.lua); - result_map_inner.unwrap().insert(id, thread_res); + result_map_inner.unwrap().insert(id, res); } } } else { @@ -398,14 +397,14 @@ impl Scheduler { let mut num_futures = 0; { let _span = trace_span!("Scheduler::drain_spawned").entered(); - for (thread, args) in self.queue_spawn.drain_items(&self.lua) { + for (thread, args) in self.queue_spawn.drain_items() { process_thread(thread, args); num_spawned += 1; } } { let _span = trace_span!("Scheduler::drain_deferred").entered(); - for (thread, args) in self.queue_defer.drain_items(&self.lua) { + for (thread, args) in self.queue_defer.drain_items() { process_thread(thread, args); num_deferred += 1; } @@ -446,7 +445,7 @@ impl Scheduler { .remove_app_data::>() .expect(ERR_METADATA_REMOVED); self.lua - .remove_app_data::>() + .remove_app_data::() .expect(ERR_METADATA_REMOVED); } } diff --git a/crates/mlua-luau-scheduler/src/traits.rs b/crates/mlua-luau-scheduler/src/traits.rs index 6c854ea..27a3733 100644 --- a/crates/mlua-luau-scheduler/src/traits.rs +++ b/crates/mlua-luau-scheduler/src/traits.rs @@ -323,7 +323,7 @@ impl LuaSchedulerExt for Lua { let map = self .app_data_ref::() .expect("lua threads results can only be retrieved from within an active scheduler"); - map.remove(id).map(|r| r.value(self)) + map.remove(id) } fn wait_for_thread(&self, id: ThreadId) -> impl Future { @@ -354,10 +354,8 @@ impl LuaSpawnExt for Lua { F: Future + 'static, { let queue = self - .app_data_ref::>() - .expect("tasks can only be spawned within an active scheduler") - .upgrade() - .expect("executor was dropped"); + .app_data_ref::() + .expect("tasks can only be spawned within an active scheduler"); trace!("spawning local task on executor"); queue.push_item(fut); } diff --git a/crates/mlua-luau-scheduler/src/util.rs b/crates/mlua-luau-scheduler/src/util.rs index 8f6c188..30c9912 100644 --- a/crates/mlua-luau-scheduler/src/util.rs +++ b/crates/mlua-luau-scheduler/src/util.rs @@ -40,74 +40,6 @@ pub(crate) fn is_poll_pending(value: &LuaValue) -> bool { .is_some_and(|l| l == Lua::poll_pending()) } -/** - Representation of a [`LuaResult`] with an associated [`LuaMultiValue`] currently stored in the Lua registry. -*/ -#[derive(Debug)] -pub(crate) struct ThreadResult { - inner: LuaResult, -} - -impl ThreadResult { - pub fn new(result: LuaResult, lua: &Lua) -> Self { - Self { - inner: match result { - Ok(v) => Ok({ - let vec = v.into_vec(); - lua.create_registry_value(vec).expect("out of memory") - }), - Err(e) => Err(e), - }, - } - } - - pub fn value(self, lua: &Lua) -> LuaResult { - match self.inner { - Ok(key) => { - let vec = lua.registry_value(&key).unwrap(); - lua.remove_registry_value(key).unwrap(); - Ok(LuaMultiValue::from_vec(vec)) - } - Err(e) => Err(e.clone()), - } - } -} - -/** - Representation of a [`LuaThread`] with its associated arguments currently stored in the Lua registry. -*/ -#[derive(Debug)] -pub(crate) struct ThreadWithArgs { - key_thread: LuaRegistryKey, - key_args: LuaRegistryKey, -} - -impl ThreadWithArgs { - pub fn new(lua: &Lua, thread: LuaThread, args: LuaMultiValue) -> LuaResult { - let argsv = args.into_vec(); - - let key_thread = lua.create_registry_value(thread)?; - let key_args = lua.create_registry_value(argsv)?; - - Ok(Self { - key_thread, - key_args, - }) - } - - pub fn into_inner(self, lua: &Lua) -> (LuaThread, LuaMultiValue) { - let thread = lua.registry_value(&self.key_thread).unwrap(); - let argsv = lua.registry_value(&self.key_args).unwrap(); - - let args = LuaMultiValue::from_vec(argsv); - - lua.remove_registry_value(self.key_thread).unwrap(); - lua.remove_registry_value(self.key_args).unwrap(); - - (thread, args) - } -} - /** Wrapper struct to accept either a Lua thread or a Lua function as function argument. diff --git a/rokit.toml b/rokit.toml index 3988e56..b163dd5 100644 --- a/rokit.toml +++ b/rokit.toml @@ -1,4 +1,5 @@ [tools] -luau-lsp = "JohnnyMorganz/luau-lsp@1.33.1" -stylua = "JohnnyMorganz/StyLua@0.20.0" -just = "casey/just@1.36.0" +luau-lsp = "JohnnyMorganz/luau-lsp@1.44.1" +lune = "lune-org/lune@0.9.0" +stylua = "JohnnyMorganz/StyLua@2.1.0" +just = "casey/just@1.40.0" diff --git a/scripts/analyze_copy_typedefs.luau b/scripts/analyze_copy_typedefs.luau new file mode 100644 index 0000000..ade7b6d --- /dev/null +++ b/scripts/analyze_copy_typedefs.luau @@ -0,0 +1,14 @@ +local fs = require("@lune/fs") + +fs.writeDir("./types") + +for _, dir in fs.readDir("./crates") do + local std = string.match(dir, "^lune%-std%-(%w+)$") + if std ~= nil then + local from = `./crates/{dir}/types.d.luau` + if fs.isFile(from) then + local to = `./types/{std}.luau` + fs.copy(from, to, true) + end + end +end diff --git a/tests/net/serve/addresses.luau b/tests/net/serve/addresses.luau new file mode 100644 index 0000000..57f591f --- /dev/null +++ b/tests/net/serve/addresses.luau @@ -0,0 +1,35 @@ +local net = require("@lune/net") + +local PORT = 8081 +local LOCALHOST = "http://localhost" +local BROADCAST = `http://0.0.0.0` +local RESPONSE = "Hello, lune!" + +-- Serve should be able to bind to broadcast IP addresse + +local handle = net.serve(PORT, { + address = BROADCAST, + handleRequest = function(request) + return `Response from {BROADCAST}:{PORT}` + end, +}) + +-- And any requests to localhost should then succeed + +local response = net.request(`{LOCALHOST}:{PORT}`).body +assert(response ~= nil, "Invalid response from server") + +handle.stop() + +-- Attempting to serve with a malformed IP address should throw an error + +local success = pcall(function() + net.serve(8080, { + address = "a.b.c.d", + handleRequest = function() + return RESPONSE + end, + }) +end) + +assert(not success, "Server was created with malformed address") diff --git a/tests/net/serve/handles.luau b/tests/net/serve/handles.luau new file mode 100644 index 0000000..b3ce418 --- /dev/null +++ b/tests/net/serve/handles.luau @@ -0,0 +1,51 @@ +local net = require("@lune/net") +local task = require("@lune/task") + +local PORT = 8082 +local URL = `http://127.0.0.1:{PORT}` +local RESPONSE = "Hello, lune!" + +local handle = net.serve(PORT, function(request) + return RESPONSE +end) + +-- Stopping is not guaranteed to happen instantly since it is async, but +-- it should happen on the next yield, so we wait the minimum amount here + +handle.stop() +task.wait() + +-- Sending a request to the stopped server should now error + +local success, response2 = pcall(net.request, URL) +if not success then + local message = tostring(response2) + assert( + string.find(message, "Connection reset") + or string.find(message, "Connection closed") + or string.find(message, "Connection refused") + or string.find(message, "No connection could be made"), -- Windows Request Error + "Server did not stop responding to requests" + ) +else + assert(not response2.ok, "Server did not stop responding to requests") +end + +--[[ + Trying to *stop* the server again should error, and + also mention that the server has already been stopped + + Note that we cast pcall to any because of a + Luau limitation where it throws a type error for + `err` because handle.stop doesn't return any value +]] + +local success2, err = (pcall :: any)(handle.stop) +assert(not success2, "Calling stop twice on the net serve handle should error") +local message = tostring(err) +assert( + string.find(message, "stop") + or string.find(message, "shutdown") + or string.find(message, "shut down"), + "The error message for calling stop twice on the net serve handle should be descriptive" +) diff --git a/tests/net/serve/non_blocking.luau b/tests/net/serve/non_blocking.luau new file mode 100644 index 0000000..dea1100 --- /dev/null +++ b/tests/net/serve/non_blocking.luau @@ -0,0 +1,24 @@ +local net = require("@lune/net") +local process = require("@lune/process") +local stdio = require("@lune/stdio") +local task = require("@lune/task") + +local PORT = 8083 +local RESPONSE = "Hello, lune!" + +-- Serve should not yield the entire main thread forever, only +-- for the initial binding to socket which should be very fast + +local thread = task.delay(1, function() + stdio.ewrite("Serve must not yield the current thread for too long\n") + task.wait(1) + process.exit(1) +end) + +local handle = net.serve(PORT, function(request) + return RESPONSE +end) + +task.cancel(thread) + +handle.stop() diff --git a/tests/net/serve/requests.luau b/tests/net/serve/requests.luau index 36602a0..6d4a14a 100644 --- a/tests/net/serve/requests.luau +++ b/tests/net/serve/requests.luau @@ -3,117 +3,43 @@ local process = require("@lune/process") local stdio = require("@lune/stdio") local task = require("@lune/task") -local PORT = 8082 +local PORT = 8084 local URL = `http://127.0.0.1:{PORT}` -local URL_EXTERNAL = `http://0.0.0.0` local RESPONSE = "Hello, lune!" --- A server should never be running before testing -local isRunning = pcall(net.request, URL) -assert(not isRunning, `a server is already running at {URL}`) - --- Serve should not block the thread from continuing - -local thread = task.delay(1, function() - stdio.ewrite("Serve must not block the current thread\n") - task.wait(1) - process.exit(1) -end) +-- Serve should get proper path, query, and other request information local handle = net.serve(PORT, function(request) - -- print("Request:", request) - -- print("Responding with", RESPONSE) + -- print("Got a request from", request.ip, "on port", request.port) + + assert(type(request.path) == "string") + assert(type(request.query) == "table") + assert(type(request.query.key) == "table") + assert(type(request.query.key2) == "string") + assert(request.path == "/some/path") - assert(request.query.key == "param2") + assert(request.query.key[1] == "param1") + assert(request.query.key[2] == "param2") assert(request.query.key2 == "param3") + return RESPONSE end) -task.cancel(thread) +-- Serve should be able to handle at least 100 requests per second with a basic handler such as the above --- Serve should respond to a request we send to it - -local thread2 = task.delay(1, function() +local thread = task.delay(1, function() stdio.ewrite("Serve should respond to requests in a reasonable amount of time\n") task.wait(1) process.exit(1) end) -local response = net.request(URL .. "/some/path?key=param1&key=param2&key2=param3").body -assert(response == RESPONSE, "Invalid response from server") +-- Serve should respond to requests we send, and keep responding until we stop it -task.cancel(thread2) +for _ = 1, 100 do + local response = net.request(URL .. "/some/path?key=param1&key=param2&key2=param3").body + assert(response == RESPONSE, "Invalid response from server") +end + +task.cancel(thread) --- Stopping is not guaranteed to happen instantly since it is async, but --- it should happen on the next yield, so we wait the minimum amount here handle.stop() -task.wait() - --- Sending a net request may error if there was --- a connection issue, we should handle that here -local success, response2 = pcall(net.request, URL) -if not success then - local message = tostring(response2) - assert( - string.find(message, "Connection reset") - or string.find(message, "Connection closed") - or string.find(message, "Connection refused") - or string.find(message, "No connection could be made"), -- Windows Request Error - "Server did not stop responding to requests" - ) -else - assert(not response2.ok, "Server did not stop responding to requests") -end - ---[[ - Trying to stop the server again should error and - mention that the server has already been stopped - - Note that we cast pcall to any because of a - Luau limitation where it throws a type error for - `err` because handle.stop doesn't return any value -]] -local success2, err = (pcall :: any)(handle.stop) -assert(not success2, "Calling stop twice on the net serve handle should error") -local message = tostring(err) -assert( - string.find(message, "stop") - or string.find(message, "shutdown") - or string.find(message, "shut down"), - "The error message for calling stop twice on the net serve handle should be descriptive" -) - --- Serve should be able to bind to other IP addresses -local handle2 = net.serve(PORT, { - address = URL_EXTERNAL, - handleRequest = function(request) - return `Response from {URL_EXTERNAL}:{PORT}` - end, -}) - -if process.os == "windows" then - -- In Windows, client cannot directly connect to `0.0.0.0`. - -- `0.0.0.0` is a non-routable meta-address. - URL_EXTERNAL = "http://localhost" -end - --- And any requests to that IP should succeed -local response3 = net.request(`{URL_EXTERNAL}:{PORT}`).body -assert(response3 ~= nil, "Invalid response from server") - -handle2.stop() - --- Attempting to serve with a malformed IP address should throw an error -local success3 = pcall(function() - net.serve(8080, { - address = "a.b.c.d", - handleRequest = function() - return RESPONSE - end, - }) -end) - -assert(not success3, "Server was created with malformed address") - --- We have to manually exit so Windows CI doesn't get stuck forever -process.exit(0) diff --git a/tests/net/serve/websockets.luau b/tests/net/serve/websockets.luau index 51aea82..1c169ce 100644 --- a/tests/net/serve/websockets.luau +++ b/tests/net/serve/websockets.luau @@ -3,7 +3,7 @@ local process = require("@lune/process") local stdio = require("@lune/stdio") local task = require("@lune/task") -local PORT = 8081 +local PORT = 8085 local WS_URL = `ws://127.0.0.1:{PORT}` local REQUEST = "Hello from client!" local RESPONSE = "Hello, lune!" diff --git a/tests/process/exec/stdio.luau b/tests/process/exec/stdio.luau index 524713b..d4ef388 100644 --- a/tests/process/exec/stdio.luau +++ b/tests/process/exec/stdio.luau @@ -10,7 +10,7 @@ local echoResult = process.exec("echo", { }, { env = { TEST_VAR = echoMessage }, shell = if IS_WINDOWS then "powershell" else "bash", - stdio = "inherit" :: process.SpawnOptionsStdioKind, -- FIXME: This should just work without a cast? + stdio = "inherit", }) -- Windows uses \r\n (CRLF) and unix uses \n (LF) diff --git a/tests/require/tests/modules/self_alias/init.luau b/tests/require/tests/modules/self_alias/init.luau index 5003e13..960ecae 100644 --- a/tests/require/tests/modules/self_alias/init.luau +++ b/tests/require/tests/modules/self_alias/init.luau @@ -1,10 +1,10 @@ -local inner = require("@self/module") +local inner = require("@self/module") :: any -- FIXME: luau-lsp does not yet support self alias local outer = require("./module") assert(type(outer) == "table", "Outer module is not a table") assert(type(inner) == "table", "Inner module is not a table") assert(outer.Foo == inner.Foo, "Outer and inner modules have different Foo values") -assert(inner.Bar == outer.Bar, "Outer and inner modules have different Bar values") +assert(inner.Hello == outer.Hello, "Outer and inner modules have different Hello values") return inner diff --git a/tests/roblox/instance/custom/async.luau b/tests/roblox/instance/custom/async.luau index 471f143..09a6c57 100644 --- a/tests/roblox/instance/custom/async.luau +++ b/tests/roblox/instance/custom/async.luau @@ -16,14 +16,11 @@ end) -- Reference: https://create.roblox.com/docs/reference/engine/classes/HttpService#GetAsync -local URL_ASTROS = "http://api.open-notify.org/astros.json" - local game = roblox.Instance.new("DataModel") local HttpService = game:GetService("HttpService") :: any -local response = HttpService:GetAsync(URL_ASTROS) +local response = HttpService:GetAsync("https://httpbingo.org/json") local data = HttpService:JSONDecode(response) assert(type(data) == "table", "Returned JSON data should decode to a table") -assert(data.message == "success", "Returned JSON data should have a 'message' with value 'success'") -assert(type(data.people) == "table", "Returned JSON data should have a 'people' table") +assert(type(data.slideshow) == "table", "Returned JSON data should contain 'slideshow'") diff --git a/tests/roblox/instance/methods/Clone.luau b/tests/roblox/instance/methods/Clone.luau index 0472974..082c78a 100644 --- a/tests/roblox/instance/methods/Clone.luau +++ b/tests/roblox/instance/methods/Clone.luau @@ -7,7 +7,7 @@ local objValue1 = Instance.new("ObjectValue") local objValue2 = Instance.new("ObjectValue") objValue1.Name = "ObjectValue1" -objValue2.Name = "ObjectValue2"; +objValue2.Name = "ObjectValue2" (objValue1 :: any).Value = root; (objValue2 :: any).Value = child objValue1.Parent = child diff --git a/tests/serde/json/decode.luau b/tests/serde/json/decode.luau index d5ec908..d3f9af6 100644 --- a/tests/serde/json/decode.luau +++ b/tests/serde/json/decode.luau @@ -1,53 +1,13 @@ -local net = require("@lune/net") local serde = require("@lune/serde") +local source = require("./source") -type Response = { - products: { - { - id: number, - title: string, - description: string, - price: number, - discountPercentage: number, - rating: number, - stock: number, - brand: string, - category: string, - thumbnail: string, - images: { string }, - } - }, - total: number, - skip: number, - limit: number, -} +local decoded = serde.decode("json", source.pretty) -local response = net.request("https://dummyjson.com/products") - -assert(response.ok, "Dummy JSON api returned an error") -assert(#response.body > 0, "Dummy JSON api returned empty body") - -local data: Response = serde.decode("json", response.body) - -assert(type(data.limit) == "number", "Products limit was not a number") -assert(type(data.products) == "table", "Products was not a table") -assert(#data.products > 0, "Products table was empty") - -local productCount = 0 -for _, product in data.products do - productCount += 1 - assert(type(product.id) == "number", "Product id was not a number") - assert(type(product.title) == "string", "Product title was not a number") - assert(type(product.description) == "string", "Product description was not a number") - assert(type(product.images) == "table", "Product images was not a table") - assert(#product.images > 0, "Product images table was empty") -end - -assert( - data.limit == productCount, - string.format( - "Products limit and number of products in array mismatch (expected %d, got %d)", - data.limit, - productCount - ) -) +assert(type(decoded) == "table", "Decoded payload was not a table") +assert(decoded.Hello == "World", "Decoded payload Hello was not World") +assert(type(decoded.Inner) == "table", "Decoded payload Inner was not a table") +assert(type(decoded.Inner.Array) == "table", "Decoded payload Inner.Array was not a table") +assert(type(decoded.Inner.Array[1]) == "number", "Decoded payload Inner.Array[1] was not a number") +assert(type(decoded.Inner.Array[2]) == "number", "Decoded payload Inner.Array[2] was not a number") +assert(type(decoded.Inner.Array[3]) == "number", "Decoded payload Inner.Array[3] was not a number") +assert(decoded.Foo == "Bar", "Decoded payload Foo was not Bar") diff --git a/tests/serde/json/encode.luau b/tests/serde/json/encode.luau index 6be3b5a..3615a14 100644 --- a/tests/serde/json/encode.luau +++ b/tests/serde/json/encode.luau @@ -2,16 +2,6 @@ local serde = require("@lune/serde") local source = require("./source") local decoded = serde.decode("json", source.pretty) - -assert(type(decoded) == "table", "Decoded payload was not a table") -assert(decoded.Hello == "World", "Decoded payload Hello was not World") -assert(type(decoded.Inner) == "table", "Decoded payload Inner was not a table") -assert(type(decoded.Inner.Array) == "table", "Decoded payload Inner.Array was not a table") -assert(type(decoded.Inner.Array[1]) == "number", "Decoded payload Inner.Array[1] was not a number") -assert(type(decoded.Inner.Array[2]) == "number", "Decoded payload Inner.Array[2] was not a number") -assert(type(decoded.Inner.Array[3]) == "number", "Decoded payload Inner.Array[3] was not a number") -assert(decoded.Foo == "Bar", "Decoded payload Foo was not Bar") - local encoded = serde.encode("json", decoded, false) assert(encoded == source.encoded, "JSON round-trip did not produce the same result")