Rewrite the net standard library with smol ecosystem of crates (#310)

This commit is contained in:
Filip Tibell 2025-04-29 15:06:16 +02:00 committed by GitHub
parent 1f43ff89f7
commit 62910f02ab
Signed by: DevComp
GPG key ID: B5690EEEBB952194
55 changed files with 2331 additions and 1439 deletions

View file

@ -2,15 +2,16 @@ name: CI
on:
push:
pull_request:
workflow_dispatch:
defaults:
run:
shell: bash
jobs:
env:
CARGO_TERM_COLOR: always
jobs:
fmt:
name: Check formatting
runs-on: ubuntu-latest
@ -79,23 +80,26 @@ jobs:
components: clippy
targets: ${{ matrix.cargo-target }}
- name: Install binstall
uses: cargo-bins/cargo-binstall@main
- name: Install nextest
run: cargo binstall cargo-nextest
- name: Build
run: |
cargo build \
--workspace \
cargo build --workspace \
--locked --all-features \
--target ${{ matrix.cargo-target }}
- name: Lint
run: |
cargo clippy \
--workspace \
cargo clippy --workspace \
--locked --all-features \
--target ${{ matrix.cargo-target }}
- name: Test
run: |
cargo test \
--lib --workspace \
cargo nextest run --no-fail-fast \
--locked --all-features \
--target ${{ matrix.cargo-target }}

5
.gitignore vendored
View file

@ -21,7 +21,12 @@ lune.yml
luneDocs.json
luneTypes.d.luau
# Dirs generated by runtime or build scripts
/types
# Files generated by runtime or build scripts
scripts/brick_color.rs
scripts/font_enum_map.rs
scripts/physical_properties_enum_map.rs

View file

@ -65,6 +65,7 @@ fmt-check:
analyze:
#!/usr/bin/env bash
set -euo pipefail
lune run scripts/analyze_copy_typedefs
luau-lsp analyze \
--settings=".vscode/settings.json" \
--ignore="tests/roblox/rbx-test-files/**" \

View file

@ -3,7 +3,6 @@
local net = require("@lune/net")
local process = require("@lune/process")
local task = require("@lune/task")
local PORT = if process.env.PORT ~= nil and #process.env.PORT > 0
then assert(tonumber(process.env.PORT), "Failed to parse port from env")
@ -11,6 +10,10 @@ local PORT = if process.env.PORT ~= nil and #process.env.PORT > 0
-- Create our responder functions
local function root(_request: net.ServeRequest): string
return `Hello from Lune server!`
end
local function pong(request: net.ServeRequest): string
return `Pong!\n{request.path}\n{request.body}`
end
@ -29,10 +32,12 @@ local function notFound(_request: net.ServeRequest): net.ServeResponse
}
end
-- Run the server on port 8080
-- Run the server on the port forever
local handle = net.serve(PORT, function(request)
if string.sub(request.path, 1, 5) == "/ping" then
net.serve(PORT, function(request)
if request.path == "/" then
return root(request)
elseif string.sub(request.path, 1, 5) == "/ping" then
return pong(request)
elseif string.sub(request.path, 1, 7) == "/teapot" then
return teapot(request)
@ -42,12 +47,4 @@ local handle = net.serve(PORT, function(request)
end)
print(`Listening on port {PORT} 🚀`)
-- Exit our example after a small delay, if you copy this
-- example just remove this part to keep the server running
task.delay(2, function()
print("Shutting down...")
task.wait(1)
handle.stop()
end)
print("Press Ctrl+C to stop")

397
Cargo.lock generated
View file

@ -238,11 +238,22 @@ version = "3.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ff6e472cdea888a4bd64f342f09b3f50e1886d32afe8df3d663c01140b811b18"
dependencies = [
"event-listener 5.4.0",
"event-listener",
"event-listener-strategy",
"pin-project-lite",
]
[[package]]
name = "async-net"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b948000fad4873c1c9339d60f2623323a0cfd3816e5181033c6a5cb68b2accf7"
dependencies = [
"async-io",
"blocking",
"futures-lite",
]
[[package]]
name = "async-process"
version = "2.3.0"
@ -256,7 +267,7 @@ dependencies = [
"async-task",
"blocking",
"cfg-if 1.0.0",
"event-listener 5.4.0",
"event-listener",
"futures-lite",
"rustix 0.38.44",
"tracing",
@ -286,6 +297,22 @@ version = "4.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b75356056920673b02621b35afd0f7dda9306d03c79a30f5c56c44cf256e3de"
[[package]]
name = "async-tungstenite"
version = "0.29.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef0f7efedeac57d9b26170f72965ecfd31473ca52ca7a64e925b0b6f5f079886"
dependencies = [
"atomic-waker",
"futures-core",
"futures-io",
"futures-task",
"futures-util",
"log",
"pin-project-lite",
"tungstenite",
]
[[package]]
name = "atomic-waker"
version = "1.1.2"
@ -298,6 +325,29 @@ version = "1.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
[[package]]
name = "aws-lc-rs"
version = "1.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19b756939cb2f8dc900aa6dcd505e6e2428e9cae7ff7b028c49e3946efa70878"
dependencies = [
"aws-lc-sys",
"zeroize",
]
[[package]]
name = "aws-lc-sys"
version = "0.28.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa9b6986f250236c27e5a204062434a773a13243d2ffc2955f37bdba4c5c6a1"
dependencies = [
"bindgen",
"cc",
"cmake",
"dunce",
"fs_extra",
]
[[package]]
name = "backtrace"
version = "0.3.74"
@ -337,6 +387,29 @@ version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]]
name = "bindgen"
version = "0.69.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271383c67ccabffb7381723dea0672a673f292304fcb45c01cc648c7a8d58088"
dependencies = [
"bitflags 2.9.0",
"cexpr",
"clang-sys",
"itertools",
"lazy_static",
"lazycell",
"log",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash 1.1.0",
"shlex",
"syn 2.0.100",
"which",
]
[[package]]
name = "bitflags"
version = "1.3.2"
@ -482,6 +555,15 @@ dependencies = [
"shlex",
]
[[package]]
name = "cexpr"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766"
dependencies = [
"nom",
]
[[package]]
name = "cfg-if"
version = "0.1.10"
@ -539,6 +621,17 @@ dependencies = [
"inout",
]
[[package]]
name = "clang-sys"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4"
dependencies = [
"glob",
"libc",
"libloading",
]
[[package]]
name = "clap"
version = "4.5.37"
@ -588,6 +681,15 @@ dependencies = [
"error-code",
]
[[package]]
name = "cmake"
version = "0.1.54"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7caa3f9de89ddbe2c607f4101924c5abec803763ae9534e4f4d7d8f84aa81f0"
dependencies = [
"cc",
]
[[package]]
name = "colorchoice"
version = "1.0.3"
@ -634,12 +736,6 @@ version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6"
[[package]]
name = "convert_case"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e"
[[package]]
name = "cookie"
version = "0.15.2"
@ -747,19 +843,6 @@ dependencies = [
"syn 2.0.100",
]
[[package]]
name = "derive_more"
version = "0.99.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6edb4b64a43d977b8e99788fe3a04d483834fba1215a7e02caa415b626497f7f"
dependencies = [
"convert_case",
"proc-macro2",
"quote",
"rustc_version 0.4.1",
"syn 2.0.100",
]
[[package]]
name = "dialoguer"
version = "0.11.0"
@ -919,17 +1002,6 @@ version = "3.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5d9305ccc6942a704f4335694ecd3de2ea531b114ac2d51f5f843750787a92f"
[[package]]
name = "event-listener"
version = "4.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67b215c49b2b248c855fb73579eb1f4f26c38ffdc12973e20e07b91d78d5646e"
dependencies = [
"concurrent-queue",
"parking",
"pin-project-lite",
]
[[package]]
name = "event-listener"
version = "5.4.0"
@ -947,7 +1019,7 @@ version = "0.5.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93"
dependencies = [
"event-listener 5.4.0",
"event-listener",
"pin-project-lite",
]
@ -993,6 +1065,26 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "fs_extra"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c"
[[package]]
name = "futures"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.31"
@ -1000,6 +1092,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10"
dependencies = [
"futures-core",
"futures-sink",
]
[[package]]
@ -1038,6 +1131,17 @@ dependencies = [
"syn 2.0.100",
]
[[package]]
name = "futures-rustls"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8f2f12607f92c69b12ed746fabf9ca4f5c482cba46679c1a75b874ed7c26adb"
dependencies = [
"futures-io",
"rustls 0.23.26",
"rustls-pki-types",
]
[[package]]
name = "futures-sink"
version = "0.3.31"
@ -1056,10 +1160,13 @@ version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
"pin-project-lite",
"pin-utils",
"slab",
@ -1136,6 +1243,12 @@ version = "0.30.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d0e9b6647e9b41d3a5ef02964c6be01311a7f2472fea40595c635c6d046c259e"
[[package]]
name = "glob"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2"
[[package]]
name = "h2"
version = "0.3.26"
@ -1155,25 +1268,6 @@ dependencies = [
"tracing",
]
[[package]]
name = "h2"
version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75249d144030531f8dee69fe9cea04d3edf809a017ae445e2abdff6629e86633"
dependencies = [
"atomic-waker",
"bytes",
"fnv",
"futures-core",
"futures-sink",
"http 1.3.1",
"indexmap",
"slab",
"tokio",
"tokio-util",
"tracing",
]
[[package]]
name = "hashbrown"
version = "0.15.2"
@ -1288,7 +1382,7 @@ dependencies = [
"futures-channel",
"futures-core",
"futures-util",
"h2 0.3.26",
"h2",
"http 0.2.12",
"http-body 0.4.6",
"httparse",
@ -1311,7 +1405,6 @@ dependencies = [
"bytes",
"futures-channel",
"futures-util",
"h2 0.4.9",
"http 1.3.1",
"http-body 1.0.1",
"httparse",
@ -1334,42 +1427,7 @@ dependencies = [
"hyper 0.14.32",
"rustls 0.21.12",
"tokio",
"tokio-rustls 0.24.1",
]
[[package]]
name = "hyper-tungstenite"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a343d17fe7885302ed7252767dc7bb83609a874b6ff581142241ec4b73957ad"
dependencies = [
"http-body-util",
"hyper 1.6.0",
"hyper-util",
"pin-project-lite",
"tokio",
"tokio-tungstenite",
"tungstenite",
]
[[package]]
name = "hyper-util"
version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "497bbc33a26fdd4af9ed9c70d63f61cf56a938375fbb32df34db9b1cd6d643f2"
dependencies = [
"bytes",
"futures-channel",
"futures-util",
"http 1.3.1",
"http-body 1.0.1",
"hyper 1.6.0",
"libc",
"pin-project-lite",
"socket2",
"tokio",
"tower-service",
"tracing",
"tokio-rustls",
]
[[package]]
@ -1566,6 +1624,15 @@ version = "1.70.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf"
[[package]]
name = "itertools"
version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "1.0.15"
@ -1607,6 +1674,12 @@ version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "lazycell"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]]
name = "libc"
version = "0.2.172"
@ -1790,21 +1863,30 @@ dependencies = [
name = "lune-std-net"
version = "0.2.0"
dependencies = [
"async-channel",
"async-executor",
"async-io",
"async-lock",
"async-net",
"async-tungstenite",
"blocking",
"bstr",
"futures-util",
"http 1.3.1",
"futures",
"futures-lite",
"futures-rustls",
"http-body-util",
"hyper 1.6.0",
"hyper-tungstenite",
"hyper-util",
"lune-std-serde",
"lune-utils",
"mlua",
"mlua-luau-scheduler",
"reqwest",
"tokio",
"tokio-tungstenite",
"pin-project-lite",
"rustls 0.23.26",
"rustls-pki-types",
"url",
"urlencoding",
"webpki",
"webpki-roots 0.26.8",
]
[[package]]
@ -1982,6 +2064,12 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "minimal-lexical"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a"
[[package]]
name = "miniz_oxide"
version = "0.8.8"
@ -2030,11 +2118,10 @@ dependencies = [
"async-io",
"blocking",
"concurrent-queue",
"derive_more",
"event-listener 4.0.3",
"event-listener",
"futures-lite",
"mlua",
"rustc-hash 1.1.0",
"rustc-hash 2.1.1",
"tracing",
"tracing-subscriber",
"tracing-tracy",
@ -2073,6 +2160,16 @@ dependencies = [
"libc",
]
[[package]]
name = "nom"
version = "7.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a"
dependencies = [
"memchr",
"minimal-lexical",
]
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
@ -2307,6 +2404,16 @@ dependencies = [
"zerocopy 0.8.24",
]
[[package]]
name = "prettyplease"
version = "0.2.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "664ec5419c51e34154eec046ebcba56312d5a2fc3b09a06da188e1ad21afadf6"
dependencies = [
"proc-macro2",
"syn 2.0.100",
]
[[package]]
name = "proc-macro-hack"
version = "0.5.20+deprecated"
@ -2634,7 +2741,7 @@ dependencies = [
"encoding_rs",
"futures-core",
"futures-util",
"h2 0.3.26",
"h2",
"http 0.2.12",
"http-body 0.4.6",
"hyper 0.14.32",
@ -2654,7 +2761,7 @@ dependencies = [
"sync_wrapper",
"system-configuration",
"tokio",
"tokio-rustls 0.24.1",
"tokio-rustls",
"tower-service",
"url",
"wasm-bindgen",
@ -2750,15 +2857,6 @@ dependencies = [
"semver 0.9.0",
]
[[package]]
name = "rustc_version"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cfcb3a22ef46e85b45de6ee7e79d063319ebb6594faafcf1c225ea92ab6e9b92"
dependencies = [
"semver 1.0.26",
]
[[package]]
name = "rustix"
version = "0.38.44"
@ -2799,14 +2897,15 @@ dependencies = [
[[package]]
name = "rustls"
version = "0.22.4"
version = "0.23.26"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432"
checksum = "df51b5869f3a441595eac5e8ff14d486ff285f7b8c0df8770e49c3b56351f0f0"
dependencies = [
"aws-lc-rs",
"log",
"ring",
"once_cell",
"rustls-pki-types",
"rustls-webpki 0.102.8",
"rustls-webpki 0.103.1",
"subtle",
"zeroize",
]
@ -2838,10 +2937,11 @@ dependencies = [
[[package]]
name = "rustls-webpki"
version = "0.102.8"
version = "0.103.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9"
checksum = "fef8b8769aaccf73098557a87cd1816b4f9c7c16811c9c77142aa695c16f2c03"
dependencies = [
"aws-lc-rs",
"ring",
"rustls-pki-types",
"untrusted",
@ -3146,7 +3246,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d022496b16281348b52d0e30ae99e01a73d737b2f45d38fed4edf79f9325a1d5"
dependencies = [
"discard",
"rustc_version 0.2.3",
"rustc_version",
"stdweb-derive",
"stdweb-internal-macros",
"stdweb-internal-runtime",
@ -3441,33 +3541,6 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "775e0c0f0adb3a2f22a00c4745d728b479985fc15ee7ca6a2608388c5569860f"
dependencies = [
"rustls 0.22.4",
"rustls-pki-types",
"tokio",
]
[[package]]
name = "tokio-tungstenite"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38"
dependencies = [
"futures-util",
"log",
"rustls 0.22.4",
"rustls-pki-types",
"tokio",
"tokio-rustls 0.25.0",
"tungstenite",
"webpki-roots 0.26.8",
]
[[package]]
name = "tokio-util"
version = "0.7.15"
@ -3623,22 +3696,18 @@ checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]]
name = "tungstenite"
version = "0.21.0"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1"
checksum = "4793cb5e56680ecbb1d843515b23b6de9a75eb04b66643e256a396d43be33c13"
dependencies = [
"byteorder 1.5.0",
"bytes",
"data-encoding",
"http 1.3.1",
"httparse",
"log",
"rand 0.8.5",
"rustls 0.22.4",
"rustls-pki-types",
"rand 0.9.1",
"sha1 0.10.6",
"thiserror 1.0.69",
"url",
"thiserror 2.0.12",
"utf-8",
]
@ -3871,6 +3940,16 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "webpki"
version = "0.22.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed63aea5ce73d0ff405984102c42de94fc55a6b75765d621c65262469b3c9b53"
dependencies = [
"ring",
"untrusted",
]
[[package]]
name = "webpki-roots"
version = "0.25.4"
@ -3886,6 +3965,18 @@ dependencies = [
"rustls-pki-types",
]
[[package]]
name = "which"
version = "4.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7"
dependencies = [
"either",
"home",
"once_cell",
"rustix 0.38.44",
]
[[package]]
name = "winapi"
version = "0.3.9"

View file

@ -16,24 +16,26 @@ workspace = true
mlua = { version = "0.10.3", features = ["luau"] }
mlua-luau-scheduler = { version = "0.1.0", path = "../mlua-luau-scheduler" }
async-channel = "2.3"
async-executor = "1.13"
async-io = "2.4"
async-lock = "3.4"
async-net = "2.0"
async-tungstenite = "0.29"
blocking = "1.6"
bstr = "1.9"
futures-util = "0.3"
hyper = { version = "1.1", features = ["full"] }
hyper-util = { version = "0.1", features = ["full"] }
http = "1.0"
http-body-util = { version = "0.1" }
hyper-tungstenite = { version = "0.13" }
reqwest = { version = "0.11", default-features = false, features = [
"rustls-tls",
] }
tokio-tungstenite = { version = "0.21", features = ["rustls-tls-webpki-roots"] }
futures = { version = "0.3", default-features = false, features = ["std"] }
futures-lite = "2.6"
futures-rustls = "0.26"
http-body-util = "0.1"
hyper = { version = "1.6", features = ["http1", "client", "server"] }
pin-project-lite = "0.2"
rustls = "0.23"
rustls-pki-types = "1.11"
url = "2.5"
urlencoding = "2.1"
tokio = { version = "1", default-features = false, features = [
"sync",
"net",
"macros",
] }
webpki = "0.22"
webpki-roots = "0.26"
lune-utils = { version = "0.2.0", path = "../lune-utils" }
lune-std-serde = { version = "0.2.0", path = "../lune-std-serde" }

View file

@ -1,163 +0,0 @@
use std::str::FromStr;
use mlua::prelude::*;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_ENCODING};
use lune_std_serde::{decompress, CompressDecompressFormat};
use lune_utils::TableBuilder;
use super::{config::RequestConfig, util::header_map_to_table};
const REGISTRY_KEY: &str = "NetClient";
pub struct NetClientBuilder {
builder: reqwest::ClientBuilder,
}
impl NetClientBuilder {
pub fn new() -> NetClientBuilder {
Self {
builder: reqwest::ClientBuilder::new(),
}
}
pub fn headers<K, V>(mut self, headers: &[(K, V)]) -> LuaResult<Self>
where
K: AsRef<str>,
V: AsRef<[u8]>,
{
let mut map = HeaderMap::new();
for (key, val) in headers {
let hkey = HeaderName::from_str(key.as_ref()).into_lua_err()?;
let hval = HeaderValue::from_bytes(val.as_ref()).into_lua_err()?;
map.insert(hkey, hval);
}
self.builder = self.builder.default_headers(map);
Ok(self)
}
pub fn build(self) -> LuaResult<NetClient> {
let client = self.builder.build().into_lua_err()?;
Ok(NetClient { inner: client })
}
}
#[derive(Debug, Clone)]
pub struct NetClient {
inner: reqwest::Client,
}
impl NetClient {
pub fn from_registry(lua: &Lua) -> Self {
lua.named_registry_value(REGISTRY_KEY)
.expect("Failed to get NetClient from lua registry")
}
pub fn into_registry(self, lua: &Lua) {
lua.set_named_registry_value(REGISTRY_KEY, self)
.expect("Failed to store NetClient in lua registry");
}
pub async fn request(&self, config: RequestConfig) -> LuaResult<NetClientResponse> {
// Create and send the request
let mut request = self.inner.request(config.method, config.url);
for (query, values) in config.query {
request = request.query(
&values
.iter()
.map(|v| (query.as_str(), v))
.collect::<Vec<_>>(),
);
}
for (header, values) in config.headers {
for value in values {
request = request.header(header.as_str(), value);
}
}
let res = request
.body(config.body.unwrap_or_default())
.send()
.await
.into_lua_err()?;
// Extract status, headers
let res_status = res.status().as_u16();
let res_status_text = res.status().canonical_reason();
let res_headers = res.headers().clone();
// Read response bytes
let mut res_bytes = res.bytes().await.into_lua_err()?.to_vec();
let mut res_decompressed = false;
// Check for extra options, decompression
if config.options.decompress {
let decompress_format = res_headers
.iter()
.find(|(name, _)| {
name.as_str()
.eq_ignore_ascii_case(CONTENT_ENCODING.as_str())
})
.and_then(|(_, value)| value.to_str().ok())
.and_then(CompressDecompressFormat::detect_from_header_str);
if let Some(format) = decompress_format {
res_bytes = decompress(res_bytes, format).await?;
res_decompressed = true;
}
}
Ok(NetClientResponse {
ok: (200..300).contains(&res_status),
status_code: res_status,
status_message: res_status_text.unwrap_or_default().to_string(),
headers: res_headers,
body: res_bytes,
body_decompressed: res_decompressed,
})
}
}
impl LuaUserData for NetClient {}
impl FromLua for NetClient {
fn from_lua(value: LuaValue, _: &Lua) -> LuaResult<Self> {
if let LuaValue::UserData(ud) = value {
if let Ok(ctx) = ud.borrow::<NetClient>() {
return Ok(ctx.clone());
}
}
unreachable!("NetClient should only be used from registry")
}
}
impl From<&Lua> for NetClient {
fn from(value: &Lua) -> Self {
value
.named_registry_value(REGISTRY_KEY)
.expect("Missing require context in lua registry")
}
}
pub struct NetClientResponse {
ok: bool,
status_code: u16,
status_message: String,
headers: HeaderMap,
body: Vec<u8>,
body_decompressed: bool,
}
impl NetClientResponse {
pub fn into_lua_table(self, lua: &Lua) -> LuaResult<LuaTable> {
TableBuilder::new(lua.clone())?
.with_value("ok", self.ok)?
.with_value("statusCode", self.status_code)?
.with_value("statusMessage", self.status_message)?
.with_value(
"headers",
header_map_to_table(lua, self.headers, self.body_decompressed)?,
)?
.with_value("body", lua.create_string(&self.body)?)?
.build_readonly()
}
}

View file

@ -0,0 +1,95 @@
use std::{
io,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use async_net::TcpStream;
use futures_lite::prelude::*;
use futures_rustls::{TlsConnector, TlsStream};
use rustls_pki_types::ServerName;
use url::Url;
use crate::client::rustls::CLIENT_CONFIG;
#[derive(Debug)]
pub enum HttpStream {
Plain(TcpStream),
Tls(TlsStream<TcpStream>),
}
impl HttpStream {
pub async fn connect(url: Url) -> Result<Self, io::Error> {
let Some(host) = url.host() else {
return Err(make_err("unknown or missing host"));
};
let Some(port) = url.port_or_known_default() else {
return Err(make_err("unknown or missing port"));
};
let use_tls = match url.scheme() {
"http" => false,
"https" => true,
s => return Err(make_err(format!("unsupported scheme: {s}"))),
};
let host = host.to_string();
let stream = TcpStream::connect((host.clone(), port)).await?;
let stream = if use_tls {
let servname = ServerName::try_from(host).map_err(make_err)?.to_owned();
let connector = TlsConnector::from(Arc::clone(&CLIENT_CONFIG));
let stream = connector.connect(servname, stream).await?;
Self::Tls(TlsStream::Client(stream))
} else {
Self::Plain(stream)
};
Ok(stream)
}
}
impl AsyncRead for HttpStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
match &mut *self {
HttpStream::Plain(stream) => Pin::new(stream).poll_read(cx, buf),
HttpStream::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}
impl AsyncWrite for HttpStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match &mut *self {
HttpStream::Plain(stream) => Pin::new(stream).poll_write(cx, buf),
HttpStream::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
}
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut *self {
HttpStream::Plain(stream) => Pin::new(stream).poll_close(cx),
HttpStream::Tls(stream) => Pin::new(stream).poll_close(cx),
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut *self {
HttpStream::Plain(stream) => Pin::new(stream).poll_flush(cx),
HttpStream::Tls(stream) => Pin::new(stream).poll_flush(cx),
}
}
}
fn make_err(e: impl ToString) -> io::Error {
io::Error::new(io::ErrorKind::Other, e.to_string())
}

View file

@ -0,0 +1,125 @@
use hyper::{
body::{Bytes, Incoming},
client::conn::http1::handshake,
header::{HeaderValue, ACCEPT, CONTENT_LENGTH, HOST, LOCATION, USER_AGENT},
Method, Request as HyperRequest, Response as HyperResponse, Uri,
};
use mlua::prelude::*;
use url::Url;
use crate::{
client::{http_stream::HttpStream, ws_stream::WsStream},
shared::{
headers::create_user_agent_header,
hyper::{HyperExecutor, HyperIo},
request::Request,
response::Response,
websocket::Websocket,
},
};
pub mod http_stream;
pub mod rustls;
pub mod ws_stream;
const MAX_REDIRECTS: usize = 10;
/**
Connects to a websocket at the given URL.
*/
pub async fn connect_websocket(url: Url) -> LuaResult<Websocket<WsStream>> {
let stream = WsStream::connect(url).await?;
Ok(Websocket::from(stream))
}
/**
Sends the request and returns the final response.
This will follow any redirects returned by the server,
modifying the request method and body as necessary.
*/
pub async fn send_request(mut request: Request, lua: Lua) -> LuaResult<Response> {
let url = request
.inner
.uri()
.to_string()
.parse::<Url>()
.expect("uri is valid");
// Some headers are required by most if not
// all servers, make sure those are present...
if !request.headers().contains_key(HOST.as_str()) {
if let Some(host) = url.host_str() {
let host = HeaderValue::from_str(host).into_lua_err()?;
request.inner.headers_mut().insert(HOST, host);
}
}
if !request.headers().contains_key(USER_AGENT.as_str()) {
let ua = create_user_agent_header(&lua)?;
let ua = HeaderValue::from_str(&ua).into_lua_err()?;
request.inner.headers_mut().insert(USER_AGENT, ua);
}
if !request.headers().contains_key(CONTENT_LENGTH.as_str()) && request.method() != Method::GET {
let len = request.inner.body().len().to_string();
let len = HeaderValue::from_str(&len).into_lua_err()?;
request.inner.headers_mut().insert(CONTENT_LENGTH, len);
}
if !request.headers().contains_key(ACCEPT.as_str()) {
let accept = HeaderValue::from_static("*/*");
request.inner.headers_mut().insert(ACCEPT, accept);
}
// ... we can now safely continue and send the request
loop {
let stream = HttpStream::connect(url.clone()).await?;
let (mut sender, conn) = handshake(HyperIo::from(stream)).await.into_lua_err()?;
HyperExecutor::execute(lua.clone(), conn);
let incoming = sender
.send_request(request.as_full())
.await
.into_lua_err()?;
if let Some((new_method, new_uri)) = check_redirect(&request.inner, &incoming) {
if request.redirects.is_some_and(|r| r >= MAX_REDIRECTS) {
return Err(LuaError::external("Too many redirects"));
}
if new_method == Method::GET {
*request.inner.body_mut() = Bytes::new();
}
*request.inner.method_mut() = new_method;
*request.inner.uri_mut() = new_uri;
*request.redirects.get_or_insert_default() += 1;
continue;
}
break Response::from_incoming(incoming, request.decompress).await;
}
}
fn check_redirect(
request: &HyperRequest<Bytes>,
response: &HyperResponse<Incoming>,
) -> Option<(Method, Uri)> {
if !response.status().is_redirection() {
return None;
}
let location = response.headers().get(LOCATION)?;
let location = location.to_str().ok()?;
let location = location.parse().ok()?;
let method = match response.status().as_u16() {
301..=303 => Method::GET,
_ => request.method().clone(),
};
Some((method, location))
}

View file

@ -0,0 +1,12 @@
use std::sync::{Arc, LazyLock};
use rustls::ClientConfig;
pub static CLIENT_CONFIG: LazyLock<Arc<ClientConfig>> = LazyLock::new(|| {
rustls::ClientConfig::builder()
.with_root_certificates(rustls::RootCertStore {
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
})
.with_no_client_auth()
.into()
});

View file

@ -0,0 +1,114 @@
use std::{
io,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use async_net::TcpStream;
use async_tungstenite::{
tungstenite::{Error as TungsteniteError, Message, Result as TungsteniteResult},
WebSocketStream as TungsteniteStream,
};
use futures::Sink;
use futures_lite::prelude::*;
use futures_rustls::{TlsConnector, TlsStream};
use rustls_pki_types::ServerName;
use url::Url;
use crate::client::rustls::CLIENT_CONFIG;
#[derive(Debug)]
pub enum WsStream {
Plain(TungsteniteStream<TcpStream>),
Tls(TungsteniteStream<TlsStream<TcpStream>>),
}
impl WsStream {
pub async fn connect(url: Url) -> Result<Self, io::Error> {
let Some(host) = url.host() else {
return Err(make_err("unknown or missing host"));
};
let Some(port) = url.port_or_known_default() else {
return Err(make_err("unknown or missing port"));
};
let use_tls = match url.scheme() {
"ws" => false,
"wss" => true,
s => return Err(make_err(format!("unsupported scheme: {s}"))),
};
let host = host.to_string();
let stream = TcpStream::connect((host.clone(), port)).await?;
let stream = if use_tls {
let servname = ServerName::try_from(host).map_err(make_err)?.to_owned();
let connector = TlsConnector::from(Arc::clone(&CLIENT_CONFIG));
let stream = connector.connect(servname, stream).await?;
let stream = TlsStream::Client(stream);
let stream = async_tungstenite::client_async(url.to_string(), stream)
.await
.map_err(make_err)?
.0;
Self::Tls(stream)
} else {
let stream = async_tungstenite::client_async(url.to_string(), stream)
.await
.map_err(make_err)?
.0;
Self::Plain(stream)
};
Ok(stream)
}
}
impl Sink<Message> for WsStream {
type Error = TungsteniteError;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match &mut *self {
WsStream::Plain(s) => Pin::new(s).poll_ready(cx),
WsStream::Tls(s) => Pin::new(s).poll_ready(cx),
}
}
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
match &mut *self {
WsStream::Plain(s) => Pin::new(s).start_send(item),
WsStream::Tls(s) => Pin::new(s).start_send(item),
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match &mut *self {
WsStream::Plain(s) => Pin::new(s).poll_flush(cx),
WsStream::Tls(s) => Pin::new(s).poll_flush(cx),
}
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match &mut *self {
WsStream::Plain(s) => Pin::new(s).poll_close(cx),
WsStream::Tls(s) => Pin::new(s).poll_close(cx),
}
}
}
impl Stream for WsStream {
type Item = TungsteniteResult<Message>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match &mut *self {
WsStream::Plain(s) => Pin::new(s).poll_next(cx),
WsStream::Tls(s) => Pin::new(s).poll_next(cx),
}
}
}
fn make_err(e: impl ToString) -> io::Error {
io::Error::new(io::ErrorKind::Other, e.to_string())
}

View file

@ -1,231 +0,0 @@
use std::{
collections::HashMap,
net::{IpAddr, Ipv4Addr},
};
use bstr::{BString, ByteSlice};
use mlua::prelude::*;
use reqwest::Method;
use super::util::table_to_hash_map;
const DEFAULT_IP_ADDRESS: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
const WEB_SOCKET_UPDGRADE_REQUEST_HANDLER: &str = r#"
return {
status = 426,
body = "Upgrade Required",
headers = {
Upgrade = "websocket",
},
}
"#;
// Net request config
#[derive(Debug, Clone)]
pub struct RequestConfigOptions {
pub decompress: bool,
}
impl Default for RequestConfigOptions {
fn default() -> Self {
Self { decompress: true }
}
}
impl FromLua for RequestConfigOptions {
fn from_lua(value: LuaValue, _: &Lua) -> LuaResult<Self> {
if let LuaValue::Nil = value {
// Nil means default options
Ok(Self::default())
} else if let LuaValue::Table(tab) = value {
// Table means custom options
let decompress = match tab.get::<Option<bool>>("decompress") {
Ok(decomp) => Ok(decomp.unwrap_or(true)),
Err(_) => Err(LuaError::RuntimeError(
"Invalid option value for 'decompress' in request config options".to_string(),
)),
}?;
Ok(Self { decompress })
} else {
// Anything else is invalid
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "RequestConfigOptions".to_string(),
message: Some(format!(
"Invalid request config options - expected table or nil, got {}",
value.type_name()
)),
})
}
}
}
#[derive(Debug, Clone)]
pub struct RequestConfig {
pub url: String,
pub method: Method,
pub query: HashMap<String, Vec<String>>,
pub headers: HashMap<String, Vec<String>>,
pub body: Option<Vec<u8>>,
pub options: RequestConfigOptions,
}
impl FromLua for RequestConfig {
fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult<Self> {
// If we just got a string we assume its a GET request to a given url
if let LuaValue::String(s) = value {
Ok(Self {
url: s.to_string_lossy().to_string(),
method: Method::GET,
query: HashMap::new(),
headers: HashMap::new(),
body: None,
options: RequestConfigOptions::default(),
})
} else if let LuaValue::Table(tab) = value {
// If we got a table we are able to configure the entire request
// Extract url
let url = match tab.get::<LuaString>("url") {
Ok(config_url) => Ok(config_url.to_string_lossy().to_string()),
Err(_) => Err(LuaError::runtime("Missing 'url' in request config")),
}?;
// Extract method
let method = match tab.get::<LuaString>("method") {
Ok(config_method) => config_method.to_string_lossy().trim().to_ascii_uppercase(),
Err(_) => "GET".to_string(),
};
// Extract query
let query = match tab.get::<LuaTable>("query") {
Ok(tab) => table_to_hash_map(tab, "query")?,
Err(_) => HashMap::new(),
};
// Extract headers
let headers = match tab.get::<LuaTable>("headers") {
Ok(tab) => table_to_hash_map(tab, "headers")?,
Err(_) => HashMap::new(),
};
// Extract body
let body = match tab.get::<BString>("body") {
Ok(config_body) => Some(config_body.as_bytes().to_owned()),
Err(_) => None,
};
// Convert method string into proper enum
let method = method.trim().to_ascii_uppercase();
let method = match method.as_ref() {
"GET" => Ok(Method::GET),
"POST" => Ok(Method::POST),
"PUT" => Ok(Method::PUT),
"DELETE" => Ok(Method::DELETE),
"HEAD" => Ok(Method::HEAD),
"OPTIONS" => Ok(Method::OPTIONS),
"PATCH" => Ok(Method::PATCH),
_ => Err(LuaError::RuntimeError(format!(
"Invalid request config method '{}'",
&method
))),
}?;
// Parse any extra options given
let options = match tab.get::<LuaValue>("options") {
Ok(opts) => RequestConfigOptions::from_lua(opts, lua)?,
Err(_) => RequestConfigOptions::default(),
};
// All good, validated and we got what we need
Ok(Self {
url,
method,
query,
headers,
body,
options,
})
} else {
// Anything else is invalid
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "RequestConfig".to_string(),
message: Some(format!(
"Invalid request config - expected string or table, got {}",
value.type_name()
)),
})
}
}
}
// Net serve config
#[derive(Debug)]
pub struct ServeConfig {
pub address: IpAddr,
pub handle_request: LuaFunction,
pub handle_web_socket: Option<LuaFunction>,
}
impl FromLua for ServeConfig {
fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult<Self> {
if let LuaValue::Function(f) = &value {
// Single function = request handler, rest is default
Ok(ServeConfig {
handle_request: f.clone(),
handle_web_socket: None,
address: DEFAULT_IP_ADDRESS,
})
} else if let LuaValue::Table(t) = &value {
// Table means custom options
let address: Option<LuaString> = t.get("address")?;
let handle_request: Option<LuaFunction> = t.get("handleRequest")?;
let handle_web_socket: Option<LuaFunction> = t.get("handleWebSocket")?;
if handle_request.is_some() || handle_web_socket.is_some() {
let address: IpAddr = match &address {
Some(addr) => {
let addr_str = addr.to_str()?;
addr_str
.trim_start_matches("http://")
.trim_start_matches("https://")
.parse()
.map_err(|_e| LuaError::FromLuaConversionError {
from: value.type_name(),
to: "ServeConfig".to_string(),
message: Some(format!(
"IP address format is incorrect - \
expected an IP in the form 'http://0.0.0.0' or '0.0.0.0', \
got '{addr_str}'"
)),
})?
}
None => DEFAULT_IP_ADDRESS,
};
Ok(Self {
address,
handle_request: handle_request.unwrap_or_else(|| {
lua.load(WEB_SOCKET_UPDGRADE_REQUEST_HANDLER)
.into_function()
.expect("Failed to create default http responder function")
}),
handle_web_socket,
})
} else {
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "ServeConfig".to_string(),
message: Some(String::from(
"Invalid serve config - expected table with 'handleRequest' or 'handleWebSocket' function",
)),
})
}
} else {
// Anything else is invalid
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "ServeConfig".to_string(),
message: None,
})
}
}
}

View file

@ -1,22 +1,17 @@
#![allow(clippy::cargo_common_metadata)]
use mlua::prelude::*;
use mlua_luau_scheduler::LuaSpawnExt;
mod client;
mod config;
mod server;
mod util;
mod websocket;
use lune_utils::TableBuilder;
use mlua::prelude::*;
pub(crate) mod client;
pub(crate) mod server;
pub(crate) mod shared;
pub(crate) mod url;
use self::{
client::{NetClient, NetClientBuilder},
config::{RequestConfig, ServeConfig},
server::serve,
util::create_user_agent_header,
websocket::NetWebSocket,
client::ws_stream::WsStream,
server::config::ServeConfig,
shared::{request::Request, response::Response, websocket::Websocket},
};
const TYPEDEFS: &str = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/types.d.luau"));
@ -37,10 +32,6 @@ pub fn typedefs() -> String {
Errors when out of memory.
*/
pub fn module(lua: Lua) -> LuaResult<LuaTable> {
NetClientBuilder::new()
.headers(&[("User-Agent", create_user_agent_header(&lua)?)])?
.build()?
.into_registry(&lua);
TableBuilder::new(lua)?
.with_async_function("request", net_request)?
.with_async_function("socket", net_socket)?
@ -50,42 +41,35 @@ pub fn module(lua: Lua) -> LuaResult<LuaTable> {
.build_readonly()
}
async fn net_request(lua: Lua, config: RequestConfig) -> LuaResult<LuaTable> {
let client = NetClient::from_registry(&lua);
// NOTE: We spawn the request as a background task to free up resources in lua
let res = lua.spawn(async move { client.request(config).await });
res.await?.into_lua_table(&lua)
async fn net_request(lua: Lua, req: Request) -> LuaResult<Response> {
self::client::send_request(req, lua).await
}
async fn net_socket(lua: Lua, url: String) -> LuaResult<LuaValue> {
let (ws, _) = tokio_tungstenite::connect_async(url).await.into_lua_err()?;
NetWebSocket::new(ws).into_lua(&lua)
async fn net_socket(_: Lua, url: String) -> LuaResult<Websocket<WsStream>> {
let url = url.parse().into_lua_err()?;
self::client::connect_websocket(url).await
}
async fn net_serve(lua: Lua, (port, config): (u16, ServeConfig)) -> LuaResult<LuaTable> {
serve(lua, port, config).await
self::server::serve(lua.clone(), port, config)
.await?
.into_lua_table(lua)
}
fn net_url_encode(
lua: &Lua,
(lua_string, as_binary): (LuaString, Option<bool>),
) -> LuaResult<LuaValue> {
if matches!(as_binary, Some(true)) {
urlencoding::encode_binary(&lua_string.as_bytes()).into_lua(lua)
} else {
urlencoding::encode(&lua_string.to_str()?).into_lua(lua)
}
) -> LuaResult<LuaString> {
let as_binary = as_binary.unwrap_or_default();
let bytes = self::url::encode(lua_string, as_binary)?;
lua.create_string(bytes)
}
fn net_url_decode(
lua: &Lua,
(lua_string, as_binary): (LuaString, Option<bool>),
) -> LuaResult<LuaValue> {
if matches!(as_binary, Some(true)) {
urlencoding::decode_binary(&lua_string.as_bytes()).into_lua(lua)
} else {
urlencoding::decode(&lua_string.to_str()?)
.map_err(|e| LuaError::RuntimeError(format!("Encountered invalid encoding - {e}")))?
.into_lua(lua)
}
) -> LuaResult<LuaString> {
let as_binary = as_binary.unwrap_or_default();
let bytes = self::url::decode(lua_string, as_binary)?;
lua.create_string(bytes)
}

View file

@ -0,0 +1,87 @@
use std::net::{IpAddr, Ipv4Addr};
use mlua::prelude::*;
const DEFAULT_IP_ADDRESS: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
const WEB_SOCKET_UPDGRADE_REQUEST_HANDLER: &str = r#"
return {
status = 426,
body = "Upgrade Required",
headers = {
Upgrade = "websocket",
},
}
"#;
#[derive(Debug, Clone)]
pub struct ServeConfig {
pub address: IpAddr,
pub handle_request: LuaFunction,
pub handle_web_socket: Option<LuaFunction>,
}
impl FromLua for ServeConfig {
fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult<Self> {
if let LuaValue::Function(f) = &value {
// Single function = request handler, rest is default
Ok(ServeConfig {
handle_request: f.clone(),
handle_web_socket: None,
address: DEFAULT_IP_ADDRESS,
})
} else if let LuaValue::Table(t) = &value {
// Table means custom options
let address: Option<LuaString> = t.get("address")?;
let handle_request: Option<LuaFunction> = t.get("handleRequest")?;
let handle_web_socket: Option<LuaFunction> = t.get("handleWebSocket")?;
if handle_request.is_some() || handle_web_socket.is_some() {
let address: IpAddr = match &address {
Some(addr) => {
let addr_str = addr.to_str()?;
addr_str
.trim_start_matches("http://")
.trim_start_matches("https://")
.parse()
.map_err(|_e| LuaError::FromLuaConversionError {
from: value.type_name(),
to: "ServeConfig".to_string(),
message: Some(format!(
"IP address format is incorrect - \
expected an IP in the form 'http://0.0.0.0' or '0.0.0.0', \
got '{addr_str}'"
)),
})?
}
None => DEFAULT_IP_ADDRESS,
};
Ok(Self {
address,
handle_request: handle_request.unwrap_or_else(|| {
lua.load(WEB_SOCKET_UPDGRADE_REQUEST_HANDLER)
.into_function()
.expect("Failed to create default http responder function")
}),
handle_web_socket,
})
} else {
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "ServeConfig".to_string(),
message: Some(String::from(
"Invalid serve config - expected table with 'handleRequest' or 'handleWebSocket' function",
)),
})
}
} else {
// Anything else is invalid
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "ServeConfig".to_string(),
message: None,
})
}
}
}

View file

@ -0,0 +1,72 @@
use std::{
net::SocketAddr,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use async_channel::{unbounded, Receiver, Sender};
use lune_utils::TableBuilder;
use mlua::prelude::*;
#[derive(Debug, Clone)]
pub struct ServeHandle {
addr: SocketAddr,
shutdown: Arc<AtomicBool>,
sender: Sender<()>,
}
impl ServeHandle {
pub fn new(addr: SocketAddr) -> (Self, Receiver<()>) {
let (sender, receiver) = unbounded();
let this = Self {
addr,
shutdown: Arc::new(AtomicBool::new(false)),
sender,
};
(this, receiver)
}
// TODO: Remove this in the next major release to use colon/self
// based call syntax and userdata implementation below instead
pub fn into_lua_table(self, lua: Lua) -> LuaResult<LuaTable> {
let shutdown = self.shutdown.clone();
let sender = self.sender.clone();
TableBuilder::new(lua)?
.with_value("ip", self.addr.ip().to_string())?
.with_value("port", self.addr.port())?
.with_function("stop", move |_, ()| {
if shutdown.load(Ordering::SeqCst) {
Err(LuaError::runtime("Server already stopped"))
} else {
shutdown.store(true, Ordering::SeqCst);
sender.try_send(()).ok();
sender.close();
Ok(())
}
})?
.build()
}
}
impl LuaUserData for ServeHandle {
fn add_fields<F: LuaUserDataFields<Self>>(fields: &mut F) {
fields.add_field_method_get("ip", |_, this| Ok(this.addr.ip().to_string()));
fields.add_field_method_get("port", |_, this| Ok(this.addr.port()));
}
fn add_methods<M: LuaUserDataMethods<Self>>(methods: &mut M) {
methods.add_method("stop", |_, this, ()| {
if this.shutdown.load(Ordering::SeqCst) {
Err(LuaError::runtime("Server already stopped"))
} else {
this.shutdown.store(true, Ordering::SeqCst);
this.sender.try_send(()).ok();
this.sender.close();
Ok(())
}
});
}
}

View file

@ -1,58 +0,0 @@
use std::sync::atomic::{AtomicUsize, Ordering};
use mlua::prelude::*;
#[derive(Debug, Clone, Copy)]
pub(super) struct SvcKeys {
key_request: &'static str,
key_websocket: Option<&'static str>,
}
impl SvcKeys {
pub(super) fn new(
lua: Lua,
handle_request: LuaFunction,
handle_websocket: Option<LuaFunction>,
) -> LuaResult<Self> {
static SERVE_COUNTER: AtomicUsize = AtomicUsize::new(0);
let count = SERVE_COUNTER.fetch_add(1, Ordering::Relaxed);
// NOTE: We leak strings here, but this is an acceptable tradeoff since programs
// generally only start one or a couple of servers and they are usually never dropped.
// Leaking here lets us keep this struct Copy and access the request handler callbacks
// very performantly, significantly reducing the per-request overhead of the server.
let key_request: &'static str =
Box::leak(format!("__net_serve_request_{count}").into_boxed_str());
let key_websocket: Option<&'static str> = if handle_websocket.is_some() {
Some(Box::leak(
format!("__net_serve_websocket_{count}").into_boxed_str(),
))
} else {
None
};
lua.set_named_registry_value(key_request, handle_request)?;
if let Some(key) = key_websocket {
lua.set_named_registry_value(key, handle_websocket.unwrap())?;
}
Ok(Self {
key_request,
key_websocket,
})
}
pub(super) fn has_websocket_handler(&self) -> bool {
self.key_websocket.is_some()
}
pub(super) fn request_handler(&self, lua: &Lua) -> LuaResult<LuaFunction> {
lua.named_registry_value(self.key_request)
}
pub(super) fn websocket_handler(&self, lua: &Lua) -> LuaResult<Option<LuaFunction>> {
self.key_websocket
.map(|key| lua.named_registry_value(key))
.transpose()
}
}

View file

@ -1,92 +1,121 @@
use std::net::SocketAddr;
use std::{cell::Cell, net::SocketAddr, rc::Rc};
use hyper::server::conn::http1;
use hyper_util::rt::TokioIo;
use tokio::{net::TcpListener, pin};
use async_net::TcpListener;
use futures_lite::pin;
use hyper::server::conn::http1::Builder as Http1Builder;
use mlua::prelude::*;
use mlua_luau_scheduler::LuaSpawnExt;
use lune_utils::TableBuilder;
use crate::{
server::{config::ServeConfig, handle::ServeHandle, service::Service},
shared::{
futures::{either, Either},
hyper::{HyperIo, HyperTimer},
},
};
use super::config::ServeConfig;
pub mod config;
pub mod handle;
pub mod service;
pub mod upgrade;
mod keys;
mod request;
mod response;
mod service;
/**
Starts an HTTP server using the given port and configuration.
use keys::SvcKeys;
use service::Svc;
pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult<LuaTable> {
let addr: SocketAddr = (config.address, port).into();
let listener = TcpListener::bind(addr).await?;
let lua_svc = lua.clone();
let lua_inner = lua.clone();
let keys = SvcKeys::new(lua.clone(), config.handle_request, config.handle_web_socket)?;
let svc = Svc {
lua: lua_svc,
addr,
keys,
Returns a `ServeHandle` that can be used to gracefully stop the server.
*/
pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult<ServeHandle> {
let address = SocketAddr::from((config.address, port));
let service = Service {
lua: lua.clone(),
address,
config,
};
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
lua.spawn_local(async move {
let mut shutdown_rx_outer = shutdown_rx.clone();
loop {
// Create futures for accepting new connections and shutting down
let fut_shutdown = shutdown_rx_outer.changed();
let fut_accept = async {
let stream = match listener.accept().await {
Err(_) => return,
Ok((s, _)) => s,
let listener = TcpListener::bind(address).await?;
let (handle, shutdown_rx) = ServeHandle::new(address);
lua.spawn_local({
let lua = lua.clone();
async move {
let handle_dropped = Rc::new(Cell::new(false));
loop {
// 1. Keep accepting new connections until we should shutdown
let (conn, addr) = if handle_dropped.get() {
// 1a. Handle has been dropped, and we don't need to listen for shutdown
match listener.accept().await {
Ok(acc) => acc,
Err(_err) => {
// TODO: Propagate error somehow
continue;
}
}
} else {
// 1b. Handle is possibly active, we must listen for shutdown
match either(shutdown_rx.recv(), listener.accept()).await {
Either::Left(Ok(())) => break,
Either::Left(Err(_)) => {
// NOTE #1: We will only get a RecvError if the serve handle is dropped,
// this means lua has garbage collected it and the user does not want
// to manually stop the server using the serve handle. Run forever.
handle_dropped.set(true);
continue;
}
Either::Right(Ok(acc)) => acc,
Either::Right(Err(_err)) => {
// TODO: Propagate error somehow
continue;
}
}
};
let io = TokioIo::new(stream);
let svc = svc.clone();
let mut shutdown_rx_inner = shutdown_rx.clone();
// 2. For each connection, spawn a new task to handle it
lua.spawn_local({
let rx = shutdown_rx.clone();
let io = HyperIo::from(conn);
lua_inner.spawn_local(async move {
let conn = http1::Builder::new()
.keep_alive(true) // Web sockets need this
.serve_connection(io, svc)
.with_upgrades();
// NOTE: Because we need to use keep_alive for websockets, we need to
// also manually poll this future and handle the shutdown signal here
pin!(conn);
tokio::select! {
_ = conn.as_mut() => {}
_ = shutdown_rx_inner.changed() => {
conn.as_mut().graceful_shutdown();
let mut svc = service.clone();
svc.address = addr;
let handle_dropped = Rc::clone(&handle_dropped);
async move {
let conn = Http1Builder::new()
.writev(false)
.timer(HyperTimer)
.keep_alive(true)
.serve_connection(io, svc)
.with_upgrades();
if handle_dropped.get() {
if let Err(_err) = conn.await {
// TODO: Propagate error somehow
}
} else {
// NOTE #2: Because we use keep_alive for websockets above, we need to
// also manually poll this future and handle the graceful shutdown,
// otherwise the already accepted connection will linger and run
// even if the stop method has been called on the serve handle
pin!(conn);
match either(rx.recv(), conn.as_mut()).await {
Either::Left(Ok(())) => conn.as_mut().graceful_shutdown(),
Either::Left(Err(_)) => {
// Same as note #1
handle_dropped.set(true);
if let Err(_err) = conn.await {
// TODO: Propagate error somehow
}
}
Either::Right(Ok(())) => {}
Either::Right(Err(_err)) => {
// TODO: Propagate error somehow
}
}
}
}
});
};
// Wait for either a new connection or a shutdown signal
tokio::select! {
() = fut_accept => {}
res = fut_shutdown => {
// NOTE: We will only get a RecvError here if the serve handle is dropped,
// this means lua has garbage collected it and the user does not want
// to manually stop the server using the serve handle. Run forever.
if res.is_ok() {
break;
}
}
}
}
});
TableBuilder::new(lua)?
.with_value("ip", addr.ip().to_string())?
.with_value("port", addr.port())?
.with_function("stop", move |_, (): ()| match shutdown_tx.send(true) {
Ok(()) => Ok(()),
Err(_) => Err(LuaError::runtime("Server already stopped")),
})?
.build_readonly()
Ok(handle)
}

View file

@ -1,56 +0,0 @@
use std::{collections::HashMap, net::SocketAddr};
use http::request::Parts;
use mlua::prelude::*;
use lune_utils::TableBuilder;
pub(super) struct LuaRequest {
pub(super) _remote_addr: SocketAddr,
pub(super) head: Parts,
pub(super) body: Vec<u8>,
}
impl LuaRequest {
pub fn into_lua_table(self, lua: &Lua) -> LuaResult<LuaTable> {
let method = self.head.method.as_str().to_string();
let path = self.head.uri.path().to_string();
let body = lua.create_string(&self.body)?;
#[allow(clippy::mutable_key_type)]
let query: HashMap<LuaString, LuaString> = self
.head
.uri
.query()
.unwrap_or_default()
.split('&')
.filter_map(|q| q.split_once('='))
.map(|(k, v)| {
let k = lua.create_string(k)?;
let v = lua.create_string(v)?;
Ok((k, v))
})
.collect::<LuaResult<_>>()?;
#[allow(clippy::mutable_key_type)]
let headers: HashMap<LuaString, LuaString> = self
.head
.headers
.iter()
.map(|(k, v)| {
let k = lua.create_string(k.as_str())?;
let v = lua.create_string(v.as_bytes())?;
Ok((k, v))
})
.collect::<LuaResult<_>>()?;
TableBuilder::new(lua.clone())?
.with_value("method", method)?
.with_value("path", path)?
.with_value("query", query)?
.with_value("headers", headers)?
.with_value("body", body)?
.build()
}
}

View file

@ -1,89 +0,0 @@
use std::str::FromStr;
use bstr::{BString, ByteSlice};
use http_body_util::Full;
use hyper::{
body::Bytes,
header::{HeaderName, HeaderValue},
HeaderMap, Response,
};
use mlua::prelude::*;
#[derive(Debug, Clone, Copy)]
pub(super) enum LuaResponseKind {
PlainText,
Table,
}
pub(super) struct LuaResponse {
pub(super) kind: LuaResponseKind,
pub(super) status: u16,
pub(super) headers: HeaderMap,
pub(super) body: Option<Vec<u8>>,
}
impl LuaResponse {
pub(super) fn into_response(self) -> LuaResult<Response<Full<Bytes>>> {
Ok(match self.kind {
LuaResponseKind::PlainText => Response::builder()
.status(200)
.header("Content-Type", "text/plain")
.body(Full::new(Bytes::from(self.body.unwrap())))
.into_lua_err()?,
LuaResponseKind::Table => {
let mut response = Response::builder()
.status(self.status)
.body(Full::new(Bytes::from(self.body.unwrap_or_default())))
.into_lua_err()?;
response.headers_mut().extend(self.headers);
response
}
})
}
}
impl FromLua for LuaResponse {
fn from_lua(value: LuaValue, _: &Lua) -> LuaResult<Self> {
match value {
// Plain strings from the handler are plaintext responses
LuaValue::String(s) => Ok(Self {
kind: LuaResponseKind::PlainText,
status: 200,
headers: HeaderMap::new(),
body: Some(s.as_bytes().to_vec()),
}),
// Tables are more detailed responses with potential status, headers, body
LuaValue::Table(t) => {
let status: Option<u16> = t.get("status")?;
let headers: Option<LuaTable> = t.get("headers")?;
let body: Option<BString> = t.get("body")?;
let mut headers_map = HeaderMap::new();
if let Some(headers) = headers {
for pair in headers.pairs::<String, LuaString>() {
let (h, v) = pair?;
let name = HeaderName::from_str(&h).into_lua_err()?;
let value = HeaderValue::from_bytes(&v.as_bytes()).into_lua_err()?;
headers_map.insert(name, value);
}
}
let body_bytes = body.map(|s| s.as_bytes().to_vec());
Ok(Self {
kind: LuaResponseKind::Table,
status: status.unwrap_or(200),
headers: headers_map,
body: body_bytes,
})
}
// Anything else is an error
value => Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "NetServeResponse".to_string(),
message: None,
}),
}
}
}

View file

@ -1,82 +1,117 @@
use std::{future::Future, net::SocketAddr, pin::Pin};
use http_body_util::{BodyExt, Full};
use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
use http_body_util::Full;
use hyper::{
body::{Bytes, Incoming},
service::Service,
Request, Response,
service::Service as HyperService,
Request as HyperRequest, Response as HyperResponse, StatusCode,
};
use hyper_tungstenite::{is_upgrade_request, upgrade};
use mlua::prelude::*;
use mlua_luau_scheduler::{LuaSchedulerExt, LuaSpawnExt};
use super::{
super::websocket::NetWebSocket, keys::SvcKeys, request::LuaRequest, response::LuaResponse,
use crate::{
server::{
config::ServeConfig,
upgrade::{is_upgrade_request, make_upgrade_response},
},
shared::{hyper::HyperIo, request::Request, response::Response, websocket::Websocket},
};
#[derive(Debug, Clone)]
pub(super) struct Svc {
pub(super) struct Service {
pub(super) lua: Lua,
pub(super) addr: SocketAddr,
pub(super) keys: SvcKeys,
pub(super) address: SocketAddr, // NOTE: This must be the remote address of the connected client
pub(super) config: ServeConfig,
}
impl Service<Request<Incoming>> for Svc {
type Response = Response<Full<Bytes>>;
impl HyperService<HyperRequest<Incoming>> for Service {
type Response = HyperResponse<Full<Bytes>>;
type Error = LuaError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
fn call(&self, req: Request<Incoming>) -> Self::Future {
let lua = self.lua.clone();
let addr = self.addr;
let keys = self.keys;
fn call(&self, req: HyperRequest<Incoming>) -> Self::Future {
if is_upgrade_request(&req) {
if let Some(handler) = self.config.handle_web_socket.clone() {
let lua = self.lua.clone();
return Box::pin(async move {
let response = match make_upgrade_response(&req) {
Ok(res) => res,
Err(err) => {
return Ok(HyperResponse::builder()
.status(StatusCode::BAD_REQUEST)
.body(Full::new(Bytes::from(err.to_string())))
.unwrap())
}
};
if keys.has_websocket_handler() && is_upgrade_request(&req) {
Box::pin(async move {
let (res, sock) = upgrade(req, None).into_lua_err()?;
lua.spawn_local({
let lua = lua.clone();
async move {
if let Err(_err) = handle_websocket(lua, handler, req).await {
// TODO: Propagate the error somehow?
}
}
});
let lua_inner = lua.clone();
lua.spawn_local(async move {
let sock = sock.await.unwrap();
let lua_sock = NetWebSocket::new(sock);
let lua_val = lua_sock.into_lua(&lua_inner).unwrap();
let handler_websocket: LuaFunction =
keys.websocket_handler(&lua_inner).unwrap().unwrap();
lua_inner
.push_thread_back(handler_websocket, lua_val)
.unwrap();
Ok(response)
});
Ok(res)
})
} else {
let (head, body) = req.into_parts();
Box::pin(async move {
let handler_request: LuaFunction = keys.request_handler(&lua).unwrap();
let body = body.collect().await.into_lua_err()?;
let body = body.to_bytes().to_vec();
let lua_req = LuaRequest {
_remote_addr: addr,
head,
body,
};
let lua_req_table = lua_req.into_lua_table(&lua)?;
let thread_id = lua.push_thread_back(handler_request, lua_req_table)?;
lua.track_thread(thread_id);
lua.wait_for_thread(thread_id).await;
let thread_res = lua
.get_thread_result(thread_id)
.expect("Missing handler thread result")?;
LuaResponse::from_lua_multi(thread_res, &lua)?.into_response()
})
}
}
let lua = self.lua.clone();
let address = self.address;
let handler = self.config.handle_request.clone();
Box::pin(async move {
match handle_request(lua, handler, req, address).await {
Ok(response) => Ok(response),
Err(_err) => {
// TODO: Propagate the error somehow?
Ok(HyperResponse::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Full::new(Bytes::from("Lune: Internal server error")))
.unwrap())
}
}
})
}
}
async fn handle_request(
lua: Lua,
handler: LuaFunction,
request: HyperRequest<Incoming>,
address: SocketAddr,
) -> LuaResult<HyperResponse<Full<Bytes>>> {
let request = Request::from_incoming(request, true)
.await?
.with_address(address);
let thread_id = lua.push_thread_back(handler, request)?;
lua.track_thread(thread_id);
lua.wait_for_thread(thread_id).await;
let thread_res = lua
.get_thread_result(thread_id)
.expect("Missing handler thread result")?;
let response = Response::from_lua_multi(thread_res, &lua)?;
Ok(response.into_full())
}
async fn handle_websocket(
lua: Lua,
handler: LuaFunction,
request: HyperRequest<Incoming>,
) -> LuaResult<()> {
let upgraded = hyper::upgrade::on(request).await.into_lua_err()?;
let stream =
WebSocketStream::from_raw_socket(HyperIo::from(upgraded), Role::Server, None).await;
let websocket = Websocket::from(stream);
lua.push_thread_back(handler, websocket)?;
Ok(())
}

View file

@ -0,0 +1,55 @@
use async_tungstenite::tungstenite::{error::ProtocolError, handshake::derive_accept_key};
use http_body_util::Full;
use hyper::{
body::{Bytes, Incoming},
header::{HeaderName, CONNECTION, UPGRADE},
HeaderMap, Request as HyperRequest, Response as HyperResponse, StatusCode,
};
const SEC_WEBSOCKET_VERSION: HeaderName = HeaderName::from_static("sec-websocket-version");
const SEC_WEBSOCKET_KEY: HeaderName = HeaderName::from_static("sec-websocket-key");
const SEC_WEBSOCKET_ACCEPT: HeaderName = HeaderName::from_static("sec-websocket-accept");
pub fn is_upgrade_request(request: &HyperRequest<Incoming>) -> bool {
fn check_header_contains(headers: &HeaderMap, header_name: HeaderName, value: &str) -> bool {
headers.get(header_name).is_some_and(|header| {
header.to_str().map_or_else(
|_| false,
|header_str| {
header_str
.split(',')
.any(|part| part.trim().eq_ignore_ascii_case(value))
},
)
})
}
check_header_contains(request.headers(), CONNECTION, "Upgrade")
&& check_header_contains(request.headers(), UPGRADE, "websocket")
}
pub fn make_upgrade_response(
request: &HyperRequest<Incoming>,
) -> Result<HyperResponse<Full<Bytes>>, ProtocolError> {
let key = request
.headers()
.get(SEC_WEBSOCKET_KEY)
.ok_or(ProtocolError::MissingSecWebSocketKey)?;
if request
.headers()
.get(SEC_WEBSOCKET_VERSION)
.is_none_or(|v| v.as_bytes() != b"13")
{
return Err(ProtocolError::MissingSecWebSocketVersionHeader);
}
Ok(HyperResponse::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(CONNECTION, "upgrade")
.header(UPGRADE, "websocket")
.header(SEC_WEBSOCKET_ACCEPT, derive_accept_key(key.as_bytes()))
.body(Full::new(Bytes::from("switching to websocket protocol")))
.unwrap())
}

View file

@ -0,0 +1,43 @@
use http_body_util::{BodyExt, Full};
use hyper::{
body::{Bytes, Incoming},
header::CONTENT_ENCODING,
HeaderMap,
};
use mlua::prelude::*;
use lune_std_serde::{decompress, CompressDecompressFormat};
pub async fn handle_incoming_body(
headers: &HeaderMap,
body: Incoming,
should_decompress: bool,
) -> LuaResult<(Bytes, bool)> {
let mut body = body.collect().await.into_lua_err()?.to_bytes();
let was_decompressed = if should_decompress {
let decompress_format = headers
.get(CONTENT_ENCODING)
.and_then(|value| value.to_str().ok())
.and_then(CompressDecompressFormat::detect_from_header_str);
if let Some(format) = decompress_format {
body = Bytes::from(decompress(body, format).await?);
true
} else {
false
}
} else {
false
};
Ok((body, was_decompressed))
}
pub fn bytes_to_full(bytes: Bytes) -> Full<Bytes> {
if bytes.is_empty() {
Full::default()
} else {
Full::new(bytes)
}
}

View file

@ -0,0 +1,19 @@
use futures_lite::prelude::*;
pub use http_body_util::Either;
/**
Combines the left and right futures into a single future
that resolves to either the left or right output.
This combinator is biased - if both futures resolve at
the same time, the left future's output is returned.
*/
pub fn either<L: Future, R: Future>(
left: L,
right: R,
) -> impl Future<Output = Either<L::Output, R::Output>> {
let fut_left = async move { Either::Left(left.await) };
let fut_right = async move { Either::Right(right.await) };
fut_left.or(fut_right)
}

View file

@ -0,0 +1,88 @@
use std::collections::HashMap;
use hyper::{
header::{CONTENT_ENCODING, CONTENT_LENGTH},
HeaderMap,
};
use lune_utils::TableBuilder;
use mlua::prelude::*;
pub fn create_user_agent_header(lua: &Lua) -> LuaResult<String> {
let version_global = lua
.globals()
.get::<LuaString>("_VERSION")
.expect("Missing _VERSION global");
let version_global_str = version_global
.to_str()
.context("Invalid utf8 found in _VERSION global")?;
let (package_name, full_version) = version_global_str.split_once(' ').unwrap();
Ok(format!("{}/{}", package_name.to_lowercase(), full_version))
}
pub fn header_map_to_table(
lua: &Lua,
headers: HeaderMap,
remove_content_headers: bool,
) -> LuaResult<LuaTable> {
let mut string_map = HashMap::<String, Vec<String>>::new();
for (name, value) in headers {
if let Some(name) = name {
if let Ok(value) = value.to_str() {
string_map
.entry(name.to_string())
.or_default()
.push(value.to_owned());
}
}
}
hash_map_to_table(lua, string_map, remove_content_headers)
}
pub fn hash_map_to_table(
lua: &Lua,
map: impl IntoIterator<Item = (String, Vec<String>)>,
remove_content_headers: bool,
) -> LuaResult<LuaTable> {
let mut string_map = HashMap::<String, Vec<String>>::new();
for (name, values) in map {
let name = name.as_str();
if remove_content_headers {
let content_encoding_header_str = CONTENT_ENCODING.as_str();
let content_length_header_str = CONTENT_LENGTH.as_str();
if name == content_encoding_header_str || name == content_length_header_str {
continue;
}
}
for value in values {
let value = value.as_str();
string_map
.entry(name.to_owned())
.or_default()
.push(value.to_owned());
}
}
let mut builder = TableBuilder::new(lua.clone())?;
for (name, mut values) in string_map {
if values.len() == 1 {
let value = values.pop().unwrap().into_lua(lua)?;
builder = builder.with_value(name, value)?;
} else {
let values = TableBuilder::new(lua.clone())?
.with_sequential_values(values)?
.build_readonly()?
.into_lua(lua)?;
builder = builder.with_value(name, values)?;
}
}
builder.build_readonly()
}

View file

@ -0,0 +1,198 @@
use std::{
future::Future,
io,
pin::Pin,
slice,
task::{Context, Poll},
time::{Duration, Instant},
};
use async_io::Timer;
use futures_lite::{prelude::*, ready};
use hyper::rt::{self, Executor, ReadBuf, ReadBufCursor};
use mlua::prelude::*;
use mlua_luau_scheduler::LuaSpawnExt;
// Hyper executor that spawns futures onto our Lua scheduler
#[derive(Debug, Clone)]
pub struct HyperExecutor {
lua: Lua,
}
#[allow(dead_code)]
impl HyperExecutor {
pub fn execute<Fut>(lua: Lua, fut: Fut)
where
Fut: Future + Send + 'static,
Fut::Output: Send + 'static,
{
let exec = if let Some(exec) = lua.app_data_ref::<Self>() {
exec
} else {
lua.set_app_data(Self { lua: lua.clone() });
lua.app_data_ref::<Self>().unwrap()
};
exec.execute(fut);
}
}
impl<Fut: Future + Send + 'static> rt::Executor<Fut> for HyperExecutor
where
Fut::Output: Send + 'static,
{
fn execute(&self, fut: Fut) {
self.lua.spawn(fut).detach();
}
}
// Hyper timer & sleep future wrapper for async-io
#[derive(Debug)]
pub struct HyperTimer;
impl rt::Timer for HyperTimer {
fn sleep(&self, duration: Duration) -> Pin<Box<dyn rt::Sleep>> {
Box::pin(HyperSleep::from(Timer::after(duration)))
}
fn sleep_until(&self, at: Instant) -> Pin<Box<dyn rt::Sleep>> {
Box::pin(HyperSleep::from(Timer::at(at)))
}
fn reset(&self, sleep: &mut Pin<Box<dyn rt::Sleep>>, new_deadline: Instant) {
if let Some(mut sleep) = sleep.as_mut().downcast_mut_pin::<HyperSleep>() {
sleep.inner.set_at(new_deadline);
} else {
*sleep = Box::pin(HyperSleep::from(Timer::at(new_deadline)));
}
}
}
#[derive(Debug)]
pub struct HyperSleep {
inner: Timer,
}
impl From<Timer> for HyperSleep {
fn from(inner: Timer) -> Self {
Self { inner }
}
}
impl Future for HyperSleep {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
match Pin::new(&mut self.inner).poll(cx) {
Poll::Ready(_) => Poll::Ready(()),
Poll::Pending => Poll::Pending,
}
}
}
impl rt::Sleep for HyperSleep {}
// Hyper I/O wrapper for bidirectional compatibility
// between hyper & futures-lite async read/write traits
pin_project_lite::pin_project! {
#[derive(Debug)]
pub struct HyperIo<T> {
#[pin]
inner: T
}
}
impl<T> From<T> for HyperIo<T> {
fn from(inner: T) -> Self {
Self { inner }
}
}
impl<T> HyperIo<T> {
pub fn pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
self.project().inner
}
}
// Compat for futures-lite -> hyper runtime
impl<T: AsyncRead> rt::Read for HyperIo<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
mut buf: ReadBufCursor<'_>,
) -> Poll<io::Result<()>> {
// Fill the read buffer with initialized data
let read_slice = unsafe {
let buffer = buf.as_mut();
buffer.as_mut_ptr().write_bytes(0, buffer.len());
slice::from_raw_parts_mut(buffer.as_mut_ptr().cast::<u8>(), buffer.len())
};
// Read bytes from the underlying source
let n = match self.pin_mut().poll_read(cx, read_slice) {
Poll::Ready(Ok(n)) => n,
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
};
unsafe {
buf.advance(n);
}
Poll::Ready(Ok(()))
}
}
impl<T: AsyncWrite> rt::Write for HyperIo<T> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.pin_mut().poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.pin_mut().poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.pin_mut().poll_close(cx)
}
}
// Compat for hyper runtime -> futures-lite
impl<T: rt::Read> AsyncRead for HyperIo<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let mut buf = ReadBuf::new(buf);
ready!(self.pin_mut().poll_read(cx, buf.unfilled()))?;
Poll::Ready(Ok(buf.filled().len()))
}
}
impl<T: rt::Write> AsyncWrite for HyperIo<T> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
self.pin_mut().poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
self.pin_mut().poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.pin_mut().poll_shutdown(cx)
}
}

View file

@ -0,0 +1,57 @@
use hyper::{
body::Bytes,
header::{HeaderName, HeaderValue},
HeaderMap, Method,
};
use mlua::prelude::*;
pub fn lua_value_to_bytes(value: &LuaValue) -> LuaResult<Bytes> {
match value {
LuaValue::Nil => Ok(Bytes::new()),
LuaValue::Buffer(buf) => Ok(Bytes::from(buf.to_vec())),
LuaValue::String(str) => Ok(Bytes::copy_from_slice(&str.as_bytes())),
v => Err(LuaError::FromLuaConversionError {
from: v.type_name(),
to: "Bytes".to_string(),
message: Some(format!(
"Invalid body - expected string or buffer, got {}",
v.type_name()
)),
}),
}
}
pub fn lua_value_to_method(value: &LuaValue) -> LuaResult<Method> {
match value {
LuaValue::Nil => Ok(Method::GET),
LuaValue::String(str) => {
let bytes = str.as_bytes().trim_ascii().to_ascii_uppercase();
Method::from_bytes(&bytes).into_lua_err()
}
LuaValue::Buffer(buf) => {
let bytes = buf.to_vec().trim_ascii().to_ascii_uppercase();
Method::from_bytes(&bytes).into_lua_err()
}
v => Err(LuaError::FromLuaConversionError {
from: v.type_name(),
to: "Method".to_string(),
message: Some(format!(
"Invalid method - expected string or buffer, got {}",
v.type_name()
)),
}),
}
}
pub fn lua_table_to_header_map(table: &LuaTable) -> LuaResult<HeaderMap> {
let mut headers = HeaderMap::new();
for pair in table.pairs::<LuaString, LuaString>() {
let (key, val) = pair?;
let key = HeaderName::from_bytes(&key.as_bytes()).into_lua_err()?;
let val = HeaderValue::from_bytes(&val.as_bytes()).into_lua_err()?;
headers.insert(key, val);
}
Ok(headers)
}

View file

@ -0,0 +1,8 @@
pub mod body;
pub mod futures;
pub mod headers;
pub mod hyper;
pub mod lua;
pub mod request;
pub mod response;
pub mod websocket;

View file

@ -0,0 +1,272 @@
use std::{collections::HashMap, net::SocketAddr};
use http_body_util::Full;
use url::Url;
use hyper::{
body::{Bytes, Incoming},
HeaderMap, Method, Request as HyperRequest,
};
use mlua::prelude::*;
use crate::shared::{
body::{bytes_to_full, handle_incoming_body},
headers::{hash_map_to_table, header_map_to_table},
lua::{lua_table_to_header_map, lua_value_to_bytes, lua_value_to_method},
};
#[derive(Debug, Clone)]
pub struct RequestOptions {
pub decompress: bool,
}
impl Default for RequestOptions {
fn default() -> Self {
Self { decompress: true }
}
}
impl FromLua for RequestOptions {
fn from_lua(value: LuaValue, _: &Lua) -> LuaResult<Self> {
if let LuaValue::Nil = value {
// Nil means default options
Ok(Self::default())
} else if let LuaValue::Table(tab) = value {
// Table means custom options
let decompress = match tab.get::<Option<bool>>("decompress") {
Ok(decomp) => Ok(decomp.unwrap_or(true)),
Err(_) => Err(LuaError::RuntimeError(
"Invalid option value for 'decompress' in request options".to_string(),
)),
}?;
Ok(Self { decompress })
} else {
// Anything else is invalid
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "RequestOptions".to_string(),
message: Some(format!(
"Invalid request options - expected table or nil, got {}",
value.type_name()
)),
})
}
}
}
#[derive(Debug, Clone)]
pub struct Request {
// NOTE: We use Bytes instead of Full<Bytes> to avoid
// needing async when getting a reference to the body
pub(crate) inner: HyperRequest<Bytes>,
pub(crate) address: Option<SocketAddr>,
pub(crate) redirects: Option<usize>,
pub(crate) decompress: bool,
}
impl Request {
/**
Creates a new request from a raw incoming request.
*/
pub async fn from_incoming(
incoming: HyperRequest<Incoming>,
decompress: bool,
) -> LuaResult<Self> {
let (parts, body) = incoming.into_parts();
let (body, decompress) = handle_incoming_body(&parts.headers, body, decompress).await?;
Ok(Self {
inner: HyperRequest::from_parts(parts, body),
address: None,
redirects: None,
decompress,
})
}
/**
Attaches a socket address to the request.
This will make the `ip` and `port` fields available on the request.
*/
pub fn with_address(mut self, address: SocketAddr) -> Self {
self.address = Some(address);
self
}
/**
Returns the method of the request.
*/
pub fn method(&self) -> Method {
self.inner.method().clone()
}
/**
Returns the path of the request.
*/
pub fn path(&self) -> &str {
self.inner.uri().path()
}
/**
Returns the query parameters of the request.
*/
pub fn query(&self) -> HashMap<String, Vec<String>> {
let uri = self.inner.uri();
let url = uri.to_string().parse::<Url>().expect("uri is valid");
let mut result = HashMap::<String, Vec<String>>::new();
for (key, value) in url.query_pairs() {
result
.entry(key.into_owned())
.or_default()
.push(value.into_owned());
}
result
}
/**
Returns the headers of the request.
*/
pub fn headers(&self) -> &HeaderMap {
self.inner.headers()
}
/**
Returns the body of the request.
*/
pub fn body(&self) -> &[u8] {
self.inner.body()
}
/**
Clones the inner `hyper` request with its body
type modified to `Full<Bytes>` for sending.
*/
#[allow(dead_code)]
pub fn as_full(&self) -> HyperRequest<Full<Bytes>> {
let mut builder = HyperRequest::builder()
.version(self.inner.version())
.method(self.inner.method())
.uri(self.inner.uri());
builder
.headers_mut()
.expect("request was valid")
.extend(self.inner.headers().clone());
let body = bytes_to_full(self.inner.body().clone());
builder.body(body).expect("request was valid")
}
/**
Takes the inner `hyper` request with its body
type modified to `Full<Bytes>` for sending.
*/
#[allow(dead_code)]
pub fn into_full(self) -> HyperRequest<Full<Bytes>> {
let (parts, body) = self.inner.into_parts();
HyperRequest::from_parts(parts, bytes_to_full(body))
}
}
impl FromLua for Request {
fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult<Self> {
if let LuaValue::String(s) = value {
// If we just got a string we assume
// its a GET request to a given url
let uri = s.to_str()?;
let uri = uri.parse().into_lua_err()?;
let mut request = HyperRequest::new(Bytes::new());
*request.uri_mut() = uri;
Ok(Self {
inner: request,
address: None,
redirects: None,
decompress: RequestOptions::default().decompress,
})
} else if let LuaValue::Table(tab) = value {
// If we got a table we are able to configure the
// entire request, maybe with extra options too
let options = match tab.get::<LuaValue>("options") {
Ok(opts) => RequestOptions::from_lua(opts, lua)?,
Err(_) => RequestOptions::default(),
};
// Extract url (required) + optional structured query params
let url = tab.get::<LuaString>("url")?;
let mut url = url.to_str()?.parse::<Url>().into_lua_err()?;
if let Some(t) = tab.get::<Option<LuaTable>>("query")? {
let mut query = url.query_pairs_mut();
for pair in t.pairs::<LuaString, LuaString>() {
let (key, value) = pair?;
let key = key.to_str()?;
let value = value.to_str()?;
query.append_pair(&key, &value);
}
}
// Extract method
let method = tab.get::<LuaValue>("method")?;
let method = lua_value_to_method(&method)?;
// Extract headers
let headers = tab.get::<Option<LuaTable>>("headers")?;
let headers = headers
.map(|t| lua_table_to_header_map(&t))
.transpose()?
.unwrap_or_default();
// Extract body
let body = tab.get::<LuaValue>("body")?;
let body = lua_value_to_bytes(&body)?;
// Build the full request
let mut request = HyperRequest::new(body);
request.headers_mut().extend(headers);
*request.uri_mut() = url.to_string().parse().unwrap();
*request.method_mut() = method;
// All good, validated and we got what we need
Ok(Self {
inner: request,
address: None,
redirects: None,
decompress: options.decompress,
})
} else {
// Anything else is invalid
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "Request".to_string(),
message: Some(format!(
"Invalid request - expected string or table, got {}",
value.type_name()
)),
})
}
}
}
impl LuaUserData for Request {
fn add_fields<F: LuaUserDataFields<Self>>(fields: &mut F) {
fields.add_field_method_get("ip", |_, this| {
Ok(this.address.map(|address| address.ip().to_string()))
});
fields.add_field_method_get("port", |_, this| {
Ok(this.address.map(|address| address.port()))
});
fields.add_field_method_get("method", |_, this| Ok(this.method().to_string()));
fields.add_field_method_get("path", |_, this| Ok(this.path().to_string()));
fields.add_field_method_get("query", |lua, this| {
hash_map_to_table(lua, this.query(), false)
});
fields.add_field_method_get("headers", |lua, this| {
header_map_to_table(lua, this.headers().clone(), this.decompress)
});
fields.add_field_method_get("body", |lua, this| lua.create_string(this.body()));
}
}

View file

@ -0,0 +1,172 @@
use http_body_util::Full;
use hyper::{
body::{Bytes, Incoming},
header::{HeaderValue, CONTENT_TYPE},
HeaderMap, Response as HyperResponse, StatusCode,
};
use mlua::prelude::*;
use crate::shared::{
body::{bytes_to_full, handle_incoming_body},
headers::header_map_to_table,
lua::{lua_table_to_header_map, lua_value_to_bytes},
};
#[derive(Debug, Clone)]
pub struct Response {
// NOTE: We use Bytes instead of Full<Bytes> to avoid
// needing async when getting a reference to the body
pub(crate) inner: HyperResponse<Bytes>,
pub(crate) decompressed: bool,
}
impl Response {
/**
Creates a new response from a raw incoming response.
*/
pub async fn from_incoming(
incoming: HyperResponse<Incoming>,
decompress: bool,
) -> LuaResult<Self> {
let (parts, body) = incoming.into_parts();
let (body, decompressed) = handle_incoming_body(&parts.headers, body, decompress).await?;
Ok(Self {
inner: HyperResponse::from_parts(parts, body),
decompressed,
})
}
/**
Returns whether the request was successful or not.
*/
pub fn status_ok(&self) -> bool {
self.inner.status().is_success()
}
/**
Returns the status code of the response.
*/
pub fn status_code(&self) -> u16 {
self.inner.status().as_u16()
}
/**
Returns the status message of the response.
*/
pub fn status_message(&self) -> &str {
self.inner.status().canonical_reason().unwrap_or_default()
}
/**
Returns the headers of the response.
*/
pub fn headers(&self) -> &HeaderMap {
self.inner.headers()
}
/**
Returns the body of the response.
*/
pub fn body(&self) -> &[u8] {
self.inner.body()
}
/**
Clones the inner `hyper` response with its body
type modified to `Full<Bytes>` for sending.
*/
#[allow(dead_code)]
pub fn as_full(&self) -> HyperResponse<Full<Bytes>> {
let mut builder = HyperResponse::builder()
.version(self.inner.version())
.status(self.inner.status());
builder
.headers_mut()
.expect("request was valid")
.extend(self.inner.headers().clone());
let body = bytes_to_full(self.inner.body().clone());
builder.body(body).expect("request was valid")
}
/**
Takes the inner `hyper` response with its body
type modified to `Full<Bytes>` for sending.
*/
#[allow(dead_code)]
pub fn into_full(self) -> HyperResponse<Full<Bytes>> {
let (parts, body) = self.inner.into_parts();
HyperResponse::from_parts(parts, bytes_to_full(body))
}
}
impl FromLua for Response {
fn from_lua(value: LuaValue, _: &Lua) -> LuaResult<Self> {
if let Ok(body) = lua_value_to_bytes(&value) {
// String or buffer is always a 200 text/plain response
let mut response = HyperResponse::new(body);
response
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("text/plain"));
Ok(Self {
inner: response,
decompressed: false,
})
} else if let LuaValue::Table(tab) = value {
// Extract status (required)
let status = tab.get::<u16>("status")?;
let status = StatusCode::from_u16(status).into_lua_err()?;
// Extract headers
let headers = tab.get::<Option<LuaTable>>("headers")?;
let headers = headers
.map(|t| lua_table_to_header_map(&t))
.transpose()?
.unwrap_or_default();
// Extract body
let body = tab.get::<LuaValue>("body")?;
let body = lua_value_to_bytes(&body)?;
// Build the full response
let mut response = HyperResponse::new(body);
response.headers_mut().extend(headers);
*response.status_mut() = status;
// All good, validated and we got what we need
Ok(Self {
inner: response,
decompressed: false,
})
} else {
// Anything else is invalid
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "Response".to_string(),
message: Some(format!(
"Invalid response - expected table/string/buffer, got {}",
value.type_name()
)),
})
}
}
}
impl LuaUserData for Response {
fn add_fields<F: LuaUserDataFields<Self>>(fields: &mut F) {
fields.add_field_method_get("ok", |_, this| Ok(this.status_ok()));
fields.add_field_method_get("statusCode", |_, this| Ok(this.status_code()));
fields.add_field_method_get("statusMessage", |lua, this| {
lua.create_string(this.status_message())
});
fields.add_field_method_get("headers", |lua, this| {
header_map_to_table(lua, this.headers().clone(), this.decompressed)
});
fields.add_field_method_get("body", |lua, this| lua.create_string(this.body()));
}
}

View file

@ -1,62 +1,38 @@
use std::sync::{
atomic::{AtomicBool, AtomicU16, Ordering},
Arc,
use std::{
error::Error,
sync::{
atomic::{AtomicBool, AtomicU16, Ordering},
Arc,
},
};
use async_lock::Mutex as AsyncMutex;
use async_tungstenite::tungstenite::{
protocol::{frame::coding::CloseCode, CloseFrame},
Message as TungsteniteMessage, Result as TungsteniteResult, Utf8Bytes,
};
use bstr::{BString, ByteSlice};
use futures::{
stream::{SplitSink, SplitStream},
Sink, SinkExt, Stream, StreamExt,
};
use hyper::body::Bytes;
use mlua::prelude::*;
use futures_util::{
stream::{SplitSink, SplitStream},
SinkExt, StreamExt,
};
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::Mutex as AsyncMutex,
};
use hyper_tungstenite::{
tungstenite::{
protocol::{frame::coding::CloseCode as WsCloseCode, CloseFrame as WsCloseFrame},
Message as WsMessage,
},
WebSocketStream,
};
#[derive(Debug)]
pub struct NetWebSocket<T> {
#[derive(Debug, Clone)]
pub struct Websocket<T> {
close_code_exists: Arc<AtomicBool>,
close_code_value: Arc<AtomicU16>,
read_stream: Arc<AsyncMutex<SplitStream<WebSocketStream<T>>>>,
write_stream: Arc<AsyncMutex<SplitSink<WebSocketStream<T>, WsMessage>>>,
read_stream: Arc<AsyncMutex<SplitStream<T>>>,
write_stream: Arc<AsyncMutex<SplitSink<T, TungsteniteMessage>>>,
}
impl<T> Clone for NetWebSocket<T> {
fn clone(&self) -> Self {
Self {
close_code_exists: Arc::clone(&self.close_code_exists),
close_code_value: Arc::clone(&self.close_code_value),
read_stream: Arc::clone(&self.read_stream),
write_stream: Arc::clone(&self.write_stream),
}
}
}
impl<T> NetWebSocket<T>
impl<T> Websocket<T>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
T: Stream<Item = TungsteniteResult<TungsteniteMessage>> + Sink<TungsteniteMessage> + 'static,
<T as Sink<TungsteniteMessage>>::Error: Into<Box<dyn Error + Send + Sync + 'static>>,
{
pub fn new(value: WebSocketStream<T>) -> Self {
let (write, read) = value.split();
Self {
close_code_exists: Arc::new(AtomicBool::new(false)),
close_code_value: Arc::new(AtomicU16::new(0)),
read_stream: Arc::new(AsyncMutex::new(read)),
write_stream: Arc::new(AsyncMutex::new(write)),
}
}
fn get_close_code(&self) -> Option<u16> {
if self.close_code_exists.load(Ordering::Relaxed) {
Some(self.close_code_value.load(Ordering::Relaxed))
@ -70,12 +46,12 @@ where
self.close_code_value.store(code, Ordering::Relaxed);
}
pub async fn send(&self, msg: WsMessage) -> LuaResult<()> {
pub async fn send(&self, msg: TungsteniteMessage) -> LuaResult<()> {
let mut ws = self.write_stream.lock().await;
ws.send(msg).await.into_lua_err()
}
pub async fn next(&self) -> LuaResult<Option<WsMessage>> {
pub async fn next(&self) -> LuaResult<Option<TungsteniteMessage>> {
let mut ws = self.read_stream.lock().await;
ws.next().await.transpose().into_lua_err()
}
@ -85,15 +61,15 @@ where
return Err(LuaError::runtime("Socket has already been closed"));
}
self.send(WsMessage::Close(Some(WsCloseFrame {
self.send(TungsteniteMessage::Close(Some(CloseFrame {
code: match code {
Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code),
Some(code) if (1000..=4999).contains(&code) => CloseCode::from(code),
Some(code) => {
return Err(LuaError::runtime(format!(
"Close code must be between 1000 and 4999, got {code}"
)))
}
None => WsCloseCode::Normal,
None => CloseCode::Normal,
},
reason: "".into(),
})))
@ -104,9 +80,27 @@ where
}
}
impl<T> LuaUserData for NetWebSocket<T>
impl<T> From<T> for Websocket<T>
where
T: AsyncRead + AsyncWrite + Unpin + 'static,
T: Stream<Item = TungsteniteResult<TungsteniteMessage>> + Sink<TungsteniteMessage> + 'static,
<T as Sink<TungsteniteMessage>>::Error: Into<Box<dyn Error + Send + Sync + 'static>>,
{
fn from(value: T) -> Self {
let (write, read) = value.split();
Self {
close_code_exists: Arc::new(AtomicBool::new(false)),
close_code_value: Arc::new(AtomicU16::new(0)),
read_stream: Arc::new(AsyncMutex::new(read)),
write_stream: Arc::new(AsyncMutex::new(write)),
}
}
}
impl<T> LuaUserData for Websocket<T>
where
T: Stream<Item = TungsteniteResult<TungsteniteMessage>> + Sink<TungsteniteMessage> + 'static,
<T as Sink<TungsteniteMessage>>::Error: Into<Box<dyn Error + Send + Sync + 'static>>,
{
fn add_fields<F: LuaUserDataFields<Self>>(fields: &mut F) {
fields.add_field_method_get("closeCode", |_, this| Ok(this.get_close_code()));
@ -121,10 +115,10 @@ where
"send",
|_, this, (string, as_binary): (BString, Option<bool>)| async move {
this.send(if as_binary.unwrap_or_default() {
WsMessage::Binary(string.as_bytes().to_vec())
TungsteniteMessage::Binary(Bytes::from(string.to_vec()))
} else {
let s = string.to_str().into_lua_err()?;
WsMessage::Text(s.to_string())
TungsteniteMessage::Text(Utf8Bytes::from(s))
})
.await
},
@ -133,14 +127,14 @@ where
methods.add_async_method("next", |lua, this, (): ()| async move {
let msg = this.next().await?;
if let Some(WsMessage::Close(Some(frame))) = msg.as_ref() {
if let Some(TungsteniteMessage::Close(Some(frame))) = msg.as_ref() {
this.set_close_code(frame.code.into());
}
Ok(match msg {
Some(WsMessage::Binary(bin)) => LuaValue::String(lua.create_string(bin)?),
Some(WsMessage::Text(txt)) => LuaValue::String(lua.create_string(txt)?),
Some(WsMessage::Close(_)) | None => LuaValue::Nil,
Some(TungsteniteMessage::Binary(bin)) => LuaValue::String(lua.create_string(bin)?),
Some(TungsteniteMessage::Text(txt)) => LuaValue::String(lua.create_string(txt)?),
Some(TungsteniteMessage::Close(_)) | None => LuaValue::Nil,
// Ignore ping/pong/frame messages, they are handled by tungstenite
msg => unreachable!("Unhandled message: {:?}", msg),
})

View file

@ -0,0 +1,12 @@
use mlua::prelude::*;
pub fn decode(lua_string: LuaString, as_binary: bool) -> LuaResult<Vec<u8>> {
if as_binary {
Ok(urlencoding::decode_binary(&lua_string.as_bytes()).into_owned())
} else {
Ok(urlencoding::decode(&lua_string.to_str()?)
.map_err(|e| LuaError::RuntimeError(format!("Encountered invalid encoding - {e}")))?
.into_owned()
.into_bytes())
}
}

View file

@ -0,0 +1,13 @@
use mlua::prelude::*;
pub fn encode(lua_string: LuaString, as_binary: bool) -> LuaResult<Vec<u8>> {
if as_binary {
Ok(urlencoding::encode_binary(&lua_string.as_bytes())
.into_owned()
.into_bytes())
} else {
Ok(urlencoding::encode(&lua_string.to_str()?)
.into_owned()
.into_bytes())
}
}

View file

@ -0,0 +1,5 @@
mod decode;
mod encode;
pub use self::decode::decode;
pub use self::encode::encode;

View file

@ -1,94 +0,0 @@
use std::collections::HashMap;
use hyper::header::{CONTENT_ENCODING, CONTENT_LENGTH};
use reqwest::header::HeaderMap;
use mlua::prelude::*;
use lune_utils::TableBuilder;
pub fn create_user_agent_header(lua: &Lua) -> LuaResult<String> {
let version_global = lua
.globals()
.get::<LuaString>("_VERSION")
.expect("Missing _VERSION global");
let version_global_str = version_global
.to_str()
.context("Invalid utf8 found in _VERSION global")?;
let (package_name, full_version) = version_global_str.split_once(' ').unwrap();
Ok(format!("{}/{}", package_name.to_lowercase(), full_version))
}
pub fn header_map_to_table(
lua: &Lua,
headers: HeaderMap,
remove_content_headers: bool,
) -> LuaResult<LuaTable> {
let mut res_headers: HashMap<String, Vec<String>> = HashMap::new();
for (name, value) in &headers {
let name = name.as_str();
let value = value.to_str().unwrap().to_owned();
if let Some(existing) = res_headers.get_mut(name) {
existing.push(value);
} else {
res_headers.insert(name.to_owned(), vec![value]);
}
}
if remove_content_headers {
let content_encoding_header_str = CONTENT_ENCODING.as_str();
let content_length_header_str = CONTENT_LENGTH.as_str();
res_headers.retain(|name, _| {
name != content_encoding_header_str && name != content_length_header_str
});
}
let mut builder = TableBuilder::new(lua.clone())?;
for (name, mut values) in res_headers {
if values.len() == 1 {
let value = values.pop().unwrap().into_lua(lua)?;
builder = builder.with_value(name, value)?;
} else {
let values = TableBuilder::new(lua.clone())?
.with_sequential_values(values)?
.build_readonly()?
.into_lua(lua)?;
builder = builder.with_value(name, values)?;
}
}
builder.build_readonly()
}
pub fn table_to_hash_map(
tab: LuaTable,
tab_origin_key: &'static str,
) -> LuaResult<HashMap<String, Vec<String>>> {
let mut map = HashMap::new();
for pair in tab.pairs::<String, LuaValue>() {
let (key, value) = pair?;
match value {
LuaValue::String(s) => {
map.insert(key, vec![s.to_str()?.to_owned()]);
}
LuaValue::Table(t) => {
let mut values = Vec::new();
for value in t.sequence_values::<LuaString>() {
values.push(value?.to_str()?.to_owned());
}
map.insert(key, values);
}
_ => {
return Err(LuaError::runtime(format!(
"Value for '{tab_origin_key}' must be a string or array of strings",
)))
}
}
}
Ok(map)
}

View file

@ -127,13 +127,19 @@ create_tests! {
net_request_methods: "net/request/methods",
net_request_query: "net/request/query",
net_request_redirect: "net/request/redirect",
net_url_encode: "net/url/encode",
net_url_decode: "net/url/decode",
net_serve_addresses: "net/serve/addresses",
net_serve_handles: "net/serve/handles",
net_serve_non_blocking: "net/serve/non_blocking",
net_serve_requests: "net/serve/requests",
net_serve_websockets: "net/serve/websockets",
net_socket_basic: "net/socket/basic",
net_socket_wss: "net/socket/wss",
net_socket_wss_rw: "net/socket/wss_rw",
net_url_encode: "net/url/encode",
net_url_decode: "net/url/decode",
}
#[cfg(feature = "std-process")]

View file

@ -16,13 +16,12 @@ path = "src/lib.rs"
workspace = true
[dependencies]
async-executor = "1.8"
blocking = "1.5"
concurrent-queue = "2.4"
derive_more = "0.99"
event-listener = "4.0"
futures-lite = "2.2"
rustc-hash = "1.1"
async-executor = "1.13"
blocking = "1.6"
concurrent-queue = "2.5"
event-listener = "5.4"
futures-lite = "2.6"
rustc-hash = "2.1"
tracing = "0.1"
mlua = { version = "0.10.3", features = [
@ -34,7 +33,7 @@ mlua = { version = "0.10.3", features = [
[dev-dependencies]
async-fs = "2.1"
async-io = "2.3"
async-io = "2.4"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tracing-tracy = "0.11"

View file

@ -8,7 +8,7 @@ use crate::{
result_map::ThreadResultMap,
thread_id::ThreadId,
traits::LuaSchedulerExt,
util::{is_poll_pending, LuaThreadOrFunction, ThreadResult},
util::{is_poll_pending, LuaThreadOrFunction},
};
const ERR_METADATA_NOT_ATTACHED: &str = "\
@ -123,8 +123,7 @@ impl Functions {
if thread.status() != LuaThreadStatus::Resumable {
let id = ThreadId::from(&thread);
if resume_map.is_tracked(id) {
let res = ThreadResult::new(Ok(v.clone()), lua);
resume_map.insert(id, res);
resume_map.insert(id, Ok(v.clone()));
}
}
(true, v).into_lua_multi(lua)
@ -134,8 +133,7 @@ impl Functions {
// Not pending, store the error
let id = ThreadId::from(&thread);
if resume_map.is_tracked(id) {
let res = ThreadResult::new(Err(e.clone()), lua);
resume_map.insert(id, res);
resume_map.insert(id, Err(e.clone()));
}
(false, e.to_string()).into_lua_multi(lua)
}
@ -177,8 +175,7 @@ impl Functions {
if thread.status() != LuaThreadStatus::Resumable {
let id = ThreadId::from(&thread);
if spawn_map.is_tracked(id) {
let res = ThreadResult::new(Ok(v), lua);
spawn_map.insert(id, res);
spawn_map.insert(id, Ok(v));
}
}
}
@ -188,8 +185,7 @@ impl Functions {
// Not pending, store the error
let id = ThreadId::from(&thread);
if spawn_map.is_tracked(id) {
let res = ThreadResult::new(Err(e), lua);
spawn_map.insert(id, res);
spawn_map.insert(id, Err(e));
}
}
}

View file

@ -1,12 +1,15 @@
use std::{pin::Pin, rc::Rc};
use std::{
ops::{Deref, DerefMut},
pin::Pin,
rc::Rc,
};
use concurrent_queue::ConcurrentQueue;
use derive_more::{Deref, DerefMut};
use event_listener::Event;
use futures_lite::{Future, FutureExt};
use mlua::prelude::*;
use crate::{traits::IntoLuaThread, util::ThreadWithArgs, ThreadId};
use crate::{traits::IntoLuaThread, ThreadId};
/**
Queue for storing [`LuaThread`]s with associated arguments.
@ -16,15 +19,13 @@ use crate::{traits::IntoLuaThread, util::ThreadWithArgs, ThreadId};
*/
#[derive(Debug, Clone)]
pub(crate) struct ThreadQueue {
queue: Rc<ConcurrentQueue<ThreadWithArgs>>,
event: Rc<Event>,
inner: Rc<ThreadQueueInner>,
}
impl ThreadQueue {
pub fn new() -> Self {
let queue = Rc::new(ConcurrentQueue::unbounded());
let event = Rc::new(Event::new());
Self { queue, event }
let inner = Rc::new(ThreadQueueInner::new());
Self { inner }
}
pub fn push_item(
@ -38,32 +39,25 @@ impl ThreadQueue {
tracing::trace!("pushing item to queue with {} args", args.len());
let id = ThreadId::from(&thread);
let stored = ThreadWithArgs::new(lua, thread, args)?;
self.queue.push(stored).into_lua_err()?;
self.event.notify(usize::MAX);
let _ = self.inner.queue.push((thread, args));
self.inner.event.notify(usize::MAX);
Ok(id)
}
#[inline]
pub fn drain_items<'outer, 'lua>(
&'outer self,
lua: &'lua Lua,
) -> impl Iterator<Item = (LuaThread, LuaMultiValue)> + 'outer
where
'lua: 'outer,
{
self.queue.try_iter().map(|stored| stored.into_inner(lua))
pub fn drain_items(&self) -> impl Iterator<Item = (LuaThread, LuaMultiValue)> + '_ {
self.inner.queue.try_iter()
}
#[inline]
pub async fn wait_for_item(&self) {
if self.queue.is_empty() {
let listener = self.event.listen();
if self.inner.queue.is_empty() {
let listener = self.inner.event.listen();
// NOTE: Need to check again, we could have gotten
// new queued items while creating our listener
if self.queue.is_empty() {
if self.inner.queue.is_empty() {
listener.await;
}
}
@ -71,14 +65,14 @@ impl ThreadQueue {
#[inline]
pub fn is_empty(&self) -> bool {
self.queue.is_empty()
self.inner.queue.is_empty()
}
}
/**
Alias for [`ThreadQueue`], providing a newtype to store in Lua app data.
*/
#[derive(Debug, Clone, Deref, DerefMut)]
#[derive(Debug, Clone)]
pub(crate) struct SpawnedThreadQueue(ThreadQueue);
impl SpawnedThreadQueue {
@ -87,10 +81,23 @@ impl SpawnedThreadQueue {
}
}
impl Deref for SpawnedThreadQueue {
type Target = ThreadQueue;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for SpawnedThreadQueue {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
/**
Alias for [`ThreadQueue`], providing a newtype to store in Lua app data.
*/
#[derive(Debug, Clone, Deref, DerefMut)]
#[derive(Debug, Clone)]
pub(crate) struct DeferredThreadQueue(ThreadQueue);
impl DeferredThreadQueue {
@ -99,6 +106,19 @@ impl DeferredThreadQueue {
}
}
impl Deref for DeferredThreadQueue {
type Target = ThreadQueue;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for DeferredThreadQueue {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
pub type LocalBoxFuture<'fut> = Pin<Box<dyn Future<Output = ()> + 'fut>>;
/**
@ -109,31 +129,60 @@ pub type LocalBoxFuture<'fut> = Pin<Box<dyn Future<Output = ()> + 'fut>>;
*/
#[derive(Debug, Clone)]
pub(crate) struct FuturesQueue<'fut> {
queue: Rc<ConcurrentQueue<LocalBoxFuture<'fut>>>,
event: Rc<Event>,
inner: Rc<FuturesQueueInner<'fut>>,
}
impl<'fut> FuturesQueue<'fut> {
pub fn new() -> Self {
let queue = Rc::new(ConcurrentQueue::unbounded());
let event = Rc::new(Event::new());
Self { queue, event }
let inner = Rc::new(FuturesQueueInner::new());
Self { inner }
}
pub fn push_item(&self, fut: impl Future<Output = ()> + 'fut) {
let _ = self.queue.push(fut.boxed_local());
self.event.notify(usize::MAX);
let _ = self.inner.queue.push(fut.boxed_local());
self.inner.event.notify(usize::MAX);
}
pub fn drain_items<'outer>(
&'outer self,
) -> impl Iterator<Item = LocalBoxFuture<'fut>> + 'outer {
self.queue.try_iter()
self.inner.queue.try_iter()
}
pub async fn wait_for_item(&self) {
if self.queue.is_empty() {
self.event.listen().await;
if self.inner.queue.is_empty() {
self.inner.event.listen().await;
}
}
}
// Inner structs without ref counting so that outer structs
// have only a single ref counter for extremely cheap clones
#[derive(Debug)]
struct ThreadQueueInner {
queue: ConcurrentQueue<(LuaThread, LuaMultiValue)>,
event: Event,
}
impl ThreadQueueInner {
fn new() -> Self {
let queue = ConcurrentQueue::unbounded();
let event = Event::new();
Self { queue, event }
}
}
#[derive(Debug)]
struct FuturesQueueInner<'fut> {
queue: ConcurrentQueue<LocalBoxFuture<'fut>>,
event: Event,
}
impl FuturesQueueInner<'_> {
pub fn new() -> Self {
let queue = ConcurrentQueue::unbounded();
let event = Event::new();
Self { queue, event }
}
}

View file

@ -5,60 +5,77 @@ use std::{cell::RefCell, rc::Rc};
use event_listener::Event;
// NOTE: This is the hash algorithm that mlua also uses, so we
// are not adding any additional dependencies / bloat by using it.
use mlua::prelude::*;
use rustc_hash::{FxHashMap, FxHashSet};
use crate::{thread_id::ThreadId, util::ThreadResult};
use crate::thread_id::ThreadId;
struct ThreadResultMapInner {
tracked: FxHashSet<ThreadId>,
results: FxHashMap<ThreadId, LuaResult<LuaMultiValue>>,
events: FxHashMap<ThreadId, Rc<Event>>,
}
impl ThreadResultMapInner {
fn new() -> Self {
Self {
tracked: FxHashSet::default(),
results: FxHashMap::default(),
events: FxHashMap::default(),
}
}
}
#[derive(Clone)]
pub(crate) struct ThreadResultMap {
tracked: Rc<RefCell<FxHashSet<ThreadId>>>,
results: Rc<RefCell<FxHashMap<ThreadId, ThreadResult>>>,
events: Rc<RefCell<FxHashMap<ThreadId, Rc<Event>>>>,
inner: Rc<RefCell<ThreadResultMapInner>>,
}
impl ThreadResultMap {
pub fn new() -> Self {
Self {
tracked: Rc::new(RefCell::new(FxHashSet::default())),
results: Rc::new(RefCell::new(FxHashMap::default())),
events: Rc::new(RefCell::new(FxHashMap::default())),
}
let inner = Rc::new(RefCell::new(ThreadResultMapInner::new()));
Self { inner }
}
#[inline(always)]
pub fn track(&self, id: ThreadId) {
self.tracked.borrow_mut().insert(id);
self.inner.borrow_mut().tracked.insert(id);
}
#[inline(always)]
pub fn is_tracked(&self, id: ThreadId) -> bool {
self.tracked.borrow().contains(&id)
self.inner.borrow().tracked.contains(&id)
}
pub fn insert(&self, id: ThreadId, result: ThreadResult) {
pub fn insert(&self, id: ThreadId, result: LuaResult<LuaMultiValue>) {
debug_assert!(self.is_tracked(id), "Thread must be tracked");
self.results.borrow_mut().insert(id, result);
if let Some(event) = self.events.borrow_mut().remove(&id) {
let mut inner = self.inner.borrow_mut();
inner.results.insert(id, result);
if let Some(event) = inner.events.remove(&id) {
event.notify(usize::MAX);
}
}
pub async fn listen(&self, id: ThreadId) {
debug_assert!(self.is_tracked(id), "Thread must be tracked");
if !self.results.borrow().contains_key(&id) {
if !self.inner.borrow().results.contains_key(&id) {
let listener = {
let mut events = self.events.borrow_mut();
let event = events.entry(id).or_insert_with(|| Rc::new(Event::new()));
let mut inner = self.inner.borrow_mut();
let event = inner
.events
.entry(id)
.or_insert_with(|| Rc::new(Event::new()));
event.listen()
};
listener.await;
}
}
pub fn remove(&self, id: ThreadId) -> Option<ThreadResult> {
let res = self.results.borrow_mut().remove(&id)?;
self.tracked.borrow_mut().remove(&id);
self.events.borrow_mut().remove(&id);
pub fn remove(&self, id: ThreadId) -> Option<LuaResult<LuaMultiValue>> {
let mut inner = self.inner.borrow_mut();
let res = inner.results.remove(&id)?;
inner.tracked.remove(&id);
inner.events.remove(&id);
Some(res)
}
}

View file

@ -2,7 +2,7 @@
use std::{
cell::Cell,
rc::{Rc, Weak as WeakRc},
rc::Rc,
sync::{Arc, Weak as WeakArc},
thread::panicking,
};
@ -21,7 +21,7 @@ use crate::{
status::Status,
thread_id::ThreadId,
traits::IntoLuaThread,
util::{run_until_yield, ThreadResult},
util::run_until_yield,
};
const ERR_METADATA_ALREADY_ATTACHED: &str = "\
@ -248,7 +248,7 @@ impl Scheduler {
*/
#[must_use]
pub fn get_thread_result(&self, id: ThreadId) -> Option<LuaResult<LuaMultiValue>> {
self.result_map.remove(id).map(|r| r.value(&self.lua))
self.result_map.remove(id)
}
/**
@ -286,7 +286,7 @@ impl Scheduler {
*/
let local_exec = LocalExecutor::new();
let main_exec = Arc::new(Executor::new());
let fut_queue = Rc::new(FuturesQueue::new());
let fut_queue = FuturesQueue::new();
/*
Store the main executor and queue in Lua, so that they may be used with LuaSchedulerExt.
@ -299,12 +299,12 @@ impl Scheduler {
"{ERR_METADATA_ALREADY_ATTACHED}"
);
assert!(
self.lua.app_data_ref::<WeakRc<FuturesQueue>>().is_none(),
self.lua.app_data_ref::<FuturesQueue>().is_none(),
"{ERR_METADATA_ALREADY_ATTACHED}"
);
self.lua.set_app_data(Arc::downgrade(&main_exec));
self.lua.set_app_data(Rc::downgrade(&fut_queue.clone()));
self.lua.set_app_data(fut_queue.clone());
/*
Manually tick the Lua executor, while running under the main executor.
@ -342,8 +342,7 @@ impl Scheduler {
self.error_callback.call(e);
}
if thread.status() != LuaThreadStatus::Resumable {
let thread_res = ThreadResult::new(res, &self.lua);
result_map_inner.unwrap().insert(id, thread_res);
result_map_inner.unwrap().insert(id, res);
}
}
} else {
@ -398,14 +397,14 @@ impl Scheduler {
let mut num_futures = 0;
{
let _span = trace_span!("Scheduler::drain_spawned").entered();
for (thread, args) in self.queue_spawn.drain_items(&self.lua) {
for (thread, args) in self.queue_spawn.drain_items() {
process_thread(thread, args);
num_spawned += 1;
}
}
{
let _span = trace_span!("Scheduler::drain_deferred").entered();
for (thread, args) in self.queue_defer.drain_items(&self.lua) {
for (thread, args) in self.queue_defer.drain_items() {
process_thread(thread, args);
num_deferred += 1;
}
@ -446,7 +445,7 @@ impl Scheduler {
.remove_app_data::<WeakArc<Executor>>()
.expect(ERR_METADATA_REMOVED);
self.lua
.remove_app_data::<WeakRc<FuturesQueue>>()
.remove_app_data::<FuturesQueue>()
.expect(ERR_METADATA_REMOVED);
}
}

View file

@ -323,7 +323,7 @@ impl LuaSchedulerExt for Lua {
let map = self
.app_data_ref::<ThreadResultMap>()
.expect("lua threads results can only be retrieved from within an active scheduler");
map.remove(id).map(|r| r.value(self))
map.remove(id)
}
fn wait_for_thread(&self, id: ThreadId) -> impl Future<Output = ()> {
@ -354,10 +354,8 @@ impl LuaSpawnExt for Lua {
F: Future<Output = ()> + 'static,
{
let queue = self
.app_data_ref::<WeakRc<FuturesQueue>>()
.expect("tasks can only be spawned within an active scheduler")
.upgrade()
.expect("executor was dropped");
.app_data_ref::<FuturesQueue>()
.expect("tasks can only be spawned within an active scheduler");
trace!("spawning local task on executor");
queue.push_item(fut);
}

View file

@ -40,74 +40,6 @@ pub(crate) fn is_poll_pending(value: &LuaValue) -> bool {
.is_some_and(|l| l == Lua::poll_pending())
}
/**
Representation of a [`LuaResult`] with an associated [`LuaMultiValue`] currently stored in the Lua registry.
*/
#[derive(Debug)]
pub(crate) struct ThreadResult {
inner: LuaResult<LuaRegistryKey>,
}
impl ThreadResult {
pub fn new(result: LuaResult<LuaMultiValue>, lua: &Lua) -> Self {
Self {
inner: match result {
Ok(v) => Ok({
let vec = v.into_vec();
lua.create_registry_value(vec).expect("out of memory")
}),
Err(e) => Err(e),
},
}
}
pub fn value(self, lua: &Lua) -> LuaResult<LuaMultiValue> {
match self.inner {
Ok(key) => {
let vec = lua.registry_value(&key).unwrap();
lua.remove_registry_value(key).unwrap();
Ok(LuaMultiValue::from_vec(vec))
}
Err(e) => Err(e.clone()),
}
}
}
/**
Representation of a [`LuaThread`] with its associated arguments currently stored in the Lua registry.
*/
#[derive(Debug)]
pub(crate) struct ThreadWithArgs {
key_thread: LuaRegistryKey,
key_args: LuaRegistryKey,
}
impl ThreadWithArgs {
pub fn new(lua: &Lua, thread: LuaThread, args: LuaMultiValue) -> LuaResult<Self> {
let argsv = args.into_vec();
let key_thread = lua.create_registry_value(thread)?;
let key_args = lua.create_registry_value(argsv)?;
Ok(Self {
key_thread,
key_args,
})
}
pub fn into_inner(self, lua: &Lua) -> (LuaThread, LuaMultiValue) {
let thread = lua.registry_value(&self.key_thread).unwrap();
let argsv = lua.registry_value(&self.key_args).unwrap();
let args = LuaMultiValue::from_vec(argsv);
lua.remove_registry_value(self.key_thread).unwrap();
lua.remove_registry_value(self.key_args).unwrap();
(thread, args)
}
}
/**
Wrapper struct to accept either a Lua thread or a Lua function as function argument.

View file

@ -1,4 +1,5 @@
[tools]
luau-lsp = "JohnnyMorganz/luau-lsp@1.33.1"
stylua = "JohnnyMorganz/StyLua@0.20.0"
just = "casey/just@1.36.0"
luau-lsp = "JohnnyMorganz/luau-lsp@1.44.1"
lune = "lune-org/lune@0.9.0"
stylua = "JohnnyMorganz/StyLua@2.1.0"
just = "casey/just@1.40.0"

View file

@ -0,0 +1,14 @@
local fs = require("@lune/fs")
fs.writeDir("./types")
for _, dir in fs.readDir("./crates") do
local std = string.match(dir, "^lune%-std%-(%w+)$")
if std ~= nil then
local from = `./crates/{dir}/types.d.luau`
if fs.isFile(from) then
local to = `./types/{std}.luau`
fs.copy(from, to, true)
end
end
end

View file

@ -0,0 +1,35 @@
local net = require("@lune/net")
local PORT = 8081
local LOCALHOST = "http://localhost"
local BROADCAST = `http://0.0.0.0`
local RESPONSE = "Hello, lune!"
-- Serve should be able to bind to broadcast IP addresse
local handle = net.serve(PORT, {
address = BROADCAST,
handleRequest = function(request)
return `Response from {BROADCAST}:{PORT}`
end,
})
-- And any requests to localhost should then succeed
local response = net.request(`{LOCALHOST}:{PORT}`).body
assert(response ~= nil, "Invalid response from server")
handle.stop()
-- Attempting to serve with a malformed IP address should throw an error
local success = pcall(function()
net.serve(8080, {
address = "a.b.c.d",
handleRequest = function()
return RESPONSE
end,
})
end)
assert(not success, "Server was created with malformed address")

View file

@ -0,0 +1,51 @@
local net = require("@lune/net")
local task = require("@lune/task")
local PORT = 8082
local URL = `http://127.0.0.1:{PORT}`
local RESPONSE = "Hello, lune!"
local handle = net.serve(PORT, function(request)
return RESPONSE
end)
-- Stopping is not guaranteed to happen instantly since it is async, but
-- it should happen on the next yield, so we wait the minimum amount here
handle.stop()
task.wait()
-- Sending a request to the stopped server should now error
local success, response2 = pcall(net.request, URL)
if not success then
local message = tostring(response2)
assert(
string.find(message, "Connection reset")
or string.find(message, "Connection closed")
or string.find(message, "Connection refused")
or string.find(message, "No connection could be made"), -- Windows Request Error
"Server did not stop responding to requests"
)
else
assert(not response2.ok, "Server did not stop responding to requests")
end
--[[
Trying to *stop* the server again should error, and
also mention that the server has already been stopped
Note that we cast pcall to any because of a
Luau limitation where it throws a type error for
`err` because handle.stop doesn't return any value
]]
local success2, err = (pcall :: any)(handle.stop)
assert(not success2, "Calling stop twice on the net serve handle should error")
local message = tostring(err)
assert(
string.find(message, "stop")
or string.find(message, "shutdown")
or string.find(message, "shut down"),
"The error message for calling stop twice on the net serve handle should be descriptive"
)

View file

@ -0,0 +1,24 @@
local net = require("@lune/net")
local process = require("@lune/process")
local stdio = require("@lune/stdio")
local task = require("@lune/task")
local PORT = 8083
local RESPONSE = "Hello, lune!"
-- Serve should not yield the entire main thread forever, only
-- for the initial binding to socket which should be very fast
local thread = task.delay(1, function()
stdio.ewrite("Serve must not yield the current thread for too long\n")
task.wait(1)
process.exit(1)
end)
local handle = net.serve(PORT, function(request)
return RESPONSE
end)
task.cancel(thread)
handle.stop()

View file

@ -3,117 +3,43 @@ local process = require("@lune/process")
local stdio = require("@lune/stdio")
local task = require("@lune/task")
local PORT = 8082
local PORT = 8084
local URL = `http://127.0.0.1:{PORT}`
local URL_EXTERNAL = `http://0.0.0.0`
local RESPONSE = "Hello, lune!"
-- A server should never be running before testing
local isRunning = pcall(net.request, URL)
assert(not isRunning, `a server is already running at {URL}`)
-- Serve should not block the thread from continuing
local thread = task.delay(1, function()
stdio.ewrite("Serve must not block the current thread\n")
task.wait(1)
process.exit(1)
end)
-- Serve should get proper path, query, and other request information
local handle = net.serve(PORT, function(request)
-- print("Request:", request)
-- print("Responding with", RESPONSE)
-- print("Got a request from", request.ip, "on port", request.port)
assert(type(request.path) == "string")
assert(type(request.query) == "table")
assert(type(request.query.key) == "table")
assert(type(request.query.key2) == "string")
assert(request.path == "/some/path")
assert(request.query.key == "param2")
assert(request.query.key[1] == "param1")
assert(request.query.key[2] == "param2")
assert(request.query.key2 == "param3")
return RESPONSE
end)
task.cancel(thread)
-- Serve should be able to handle at least 100 requests per second with a basic handler such as the above
-- Serve should respond to a request we send to it
local thread2 = task.delay(1, function()
local thread = task.delay(1, function()
stdio.ewrite("Serve should respond to requests in a reasonable amount of time\n")
task.wait(1)
process.exit(1)
end)
local response = net.request(URL .. "/some/path?key=param1&key=param2&key2=param3").body
assert(response == RESPONSE, "Invalid response from server")
-- Serve should respond to requests we send, and keep responding until we stop it
task.cancel(thread2)
for _ = 1, 100 do
local response = net.request(URL .. "/some/path?key=param1&key=param2&key2=param3").body
assert(response == RESPONSE, "Invalid response from server")
end
task.cancel(thread)
-- Stopping is not guaranteed to happen instantly since it is async, but
-- it should happen on the next yield, so we wait the minimum amount here
handle.stop()
task.wait()
-- Sending a net request may error if there was
-- a connection issue, we should handle that here
local success, response2 = pcall(net.request, URL)
if not success then
local message = tostring(response2)
assert(
string.find(message, "Connection reset")
or string.find(message, "Connection closed")
or string.find(message, "Connection refused")
or string.find(message, "No connection could be made"), -- Windows Request Error
"Server did not stop responding to requests"
)
else
assert(not response2.ok, "Server did not stop responding to requests")
end
--[[
Trying to stop the server again should error and
mention that the server has already been stopped
Note that we cast pcall to any because of a
Luau limitation where it throws a type error for
`err` because handle.stop doesn't return any value
]]
local success2, err = (pcall :: any)(handle.stop)
assert(not success2, "Calling stop twice on the net serve handle should error")
local message = tostring(err)
assert(
string.find(message, "stop")
or string.find(message, "shutdown")
or string.find(message, "shut down"),
"The error message for calling stop twice on the net serve handle should be descriptive"
)
-- Serve should be able to bind to other IP addresses
local handle2 = net.serve(PORT, {
address = URL_EXTERNAL,
handleRequest = function(request)
return `Response from {URL_EXTERNAL}:{PORT}`
end,
})
if process.os == "windows" then
-- In Windows, client cannot directly connect to `0.0.0.0`.
-- `0.0.0.0` is a non-routable meta-address.
URL_EXTERNAL = "http://localhost"
end
-- And any requests to that IP should succeed
local response3 = net.request(`{URL_EXTERNAL}:{PORT}`).body
assert(response3 ~= nil, "Invalid response from server")
handle2.stop()
-- Attempting to serve with a malformed IP address should throw an error
local success3 = pcall(function()
net.serve(8080, {
address = "a.b.c.d",
handleRequest = function()
return RESPONSE
end,
})
end)
assert(not success3, "Server was created with malformed address")
-- We have to manually exit so Windows CI doesn't get stuck forever
process.exit(0)

View file

@ -3,7 +3,7 @@ local process = require("@lune/process")
local stdio = require("@lune/stdio")
local task = require("@lune/task")
local PORT = 8081
local PORT = 8085
local WS_URL = `ws://127.0.0.1:{PORT}`
local REQUEST = "Hello from client!"
local RESPONSE = "Hello, lune!"

View file

@ -10,7 +10,7 @@ local echoResult = process.exec("echo", {
}, {
env = { TEST_VAR = echoMessage },
shell = if IS_WINDOWS then "powershell" else "bash",
stdio = "inherit" :: process.SpawnOptionsStdioKind, -- FIXME: This should just work without a cast?
stdio = "inherit",
})
-- Windows uses \r\n (CRLF) and unix uses \n (LF)

View file

@ -1,10 +1,10 @@
local inner = require("@self/module")
local inner = require("@self/module") :: any -- FIXME: luau-lsp does not yet support self alias
local outer = require("./module")
assert(type(outer) == "table", "Outer module is not a table")
assert(type(inner) == "table", "Inner module is not a table")
assert(outer.Foo == inner.Foo, "Outer and inner modules have different Foo values")
assert(inner.Bar == outer.Bar, "Outer and inner modules have different Bar values")
assert(inner.Hello == outer.Hello, "Outer and inner modules have different Hello values")
return inner

View file

@ -16,14 +16,11 @@ end)
-- Reference: https://create.roblox.com/docs/reference/engine/classes/HttpService#GetAsync
local URL_ASTROS = "http://api.open-notify.org/astros.json"
local game = roblox.Instance.new("DataModel")
local HttpService = game:GetService("HttpService") :: any
local response = HttpService:GetAsync(URL_ASTROS)
local response = HttpService:GetAsync("https://httpbingo.org/json")
local data = HttpService:JSONDecode(response)
assert(type(data) == "table", "Returned JSON data should decode to a table")
assert(data.message == "success", "Returned JSON data should have a 'message' with value 'success'")
assert(type(data.people) == "table", "Returned JSON data should have a 'people' table")
assert(type(data.slideshow) == "table", "Returned JSON data should contain 'slideshow'")

View file

@ -7,7 +7,7 @@ local objValue1 = Instance.new("ObjectValue")
local objValue2 = Instance.new("ObjectValue")
objValue1.Name = "ObjectValue1"
objValue2.Name = "ObjectValue2";
objValue2.Name = "ObjectValue2"
(objValue1 :: any).Value = root;
(objValue2 :: any).Value = child
objValue1.Parent = child

View file

@ -1,53 +1,13 @@
local net = require("@lune/net")
local serde = require("@lune/serde")
local source = require("./source")
type Response = {
products: {
{
id: number,
title: string,
description: string,
price: number,
discountPercentage: number,
rating: number,
stock: number,
brand: string,
category: string,
thumbnail: string,
images: { string },
}
},
total: number,
skip: number,
limit: number,
}
local decoded = serde.decode("json", source.pretty)
local response = net.request("https://dummyjson.com/products")
assert(response.ok, "Dummy JSON api returned an error")
assert(#response.body > 0, "Dummy JSON api returned empty body")
local data: Response = serde.decode("json", response.body)
assert(type(data.limit) == "number", "Products limit was not a number")
assert(type(data.products) == "table", "Products was not a table")
assert(#data.products > 0, "Products table was empty")
local productCount = 0
for _, product in data.products do
productCount += 1
assert(type(product.id) == "number", "Product id was not a number")
assert(type(product.title) == "string", "Product title was not a number")
assert(type(product.description) == "string", "Product description was not a number")
assert(type(product.images) == "table", "Product images was not a table")
assert(#product.images > 0, "Product images table was empty")
end
assert(
data.limit == productCount,
string.format(
"Products limit and number of products in array mismatch (expected %d, got %d)",
data.limit,
productCount
)
)
assert(type(decoded) == "table", "Decoded payload was not a table")
assert(decoded.Hello == "World", "Decoded payload Hello was not World")
assert(type(decoded.Inner) == "table", "Decoded payload Inner was not a table")
assert(type(decoded.Inner.Array) == "table", "Decoded payload Inner.Array was not a table")
assert(type(decoded.Inner.Array[1]) == "number", "Decoded payload Inner.Array[1] was not a number")
assert(type(decoded.Inner.Array[2]) == "number", "Decoded payload Inner.Array[2] was not a number")
assert(type(decoded.Inner.Array[3]) == "number", "Decoded payload Inner.Array[3] was not a number")
assert(decoded.Foo == "Bar", "Decoded payload Foo was not Bar")

View file

@ -2,16 +2,6 @@ local serde = require("@lune/serde")
local source = require("./source")
local decoded = serde.decode("json", source.pretty)
assert(type(decoded) == "table", "Decoded payload was not a table")
assert(decoded.Hello == "World", "Decoded payload Hello was not World")
assert(type(decoded.Inner) == "table", "Decoded payload Inner was not a table")
assert(type(decoded.Inner.Array) == "table", "Decoded payload Inner.Array was not a table")
assert(type(decoded.Inner.Array[1]) == "number", "Decoded payload Inner.Array[1] was not a number")
assert(type(decoded.Inner.Array[2]) == "number", "Decoded payload Inner.Array[2] was not a number")
assert(type(decoded.Inner.Array[3]) == "number", "Decoded payload Inner.Array[3] was not a number")
assert(decoded.Foo == "Bar", "Decoded payload Foo was not Bar")
local encoded = serde.encode("json", decoded, false)
assert(encoded == source.encoded, "JSON round-trip did not produce the same result")