Compare commits

...

26 commits
v0.9.0 ... main

Author SHA1 Message Date
Filip Tibell
df56cd58e7
Github does not like requests from github to github, remove it from https test 2025-05-02 22:33:28 +02:00
Filip Tibell
66e3b58cd7
Url and uri are not the same 2025-05-02 21:57:13 +02:00
Filip Tibell
fb33d1812d
Remove old unused app data 2025-05-02 12:31:58 +02:00
Filip Tibell
0ddaaaefb5
Update changelog 2025-05-02 12:29:46 +02:00
Filip Tibell
2e5b3bb5eb
Fix panicking during require because of long lived require context borrow 2025-05-02 12:27:20 +02:00
Filip Tibell
6645631c46
Properly store process args and env as part of runtime initialization instead of in std-process 2025-05-02 12:18:53 +02:00
Filip Tibell
120048ae95
Update changelog 2025-05-01 21:12:54 +02:00
Filip Tibell
2d8e58b028
Revamp handling of process args and env with fully featured newtypes 2025-05-01 21:08:58 +02:00
Filip Tibell
b1fc60023d
Add release date to changelog 2025-04-30 15:40:39 +02:00
Filip Tibell
1429450a64
Version 0.9.2 2025-04-30 15:40:08 +02:00
Filip Tibell
d2a89f41c8
Fixed https support in net client + update changelog 2025-04-30 15:28:29 +02:00
Filip Tibell
9c9b90d70d
Final optimizations and dependency cleanup for mlua-luau-scheduler 2025-04-30 14:54:10 +02:00
Filip Tibell
d425d2568a
Final optimizations for future and thread queues 2025-04-30 14:47:16 +02:00
Filip Tibell
4c2bbcf425
Optimize thread queue storage for spawned and deferred threads 2025-04-30 14:21:48 +02:00
Filip Tibell
461ca24c33
Implement optimized event listener for thread queues 2025-04-30 14:10:35 +02:00
Filip Tibell
7fd390dead
Organize queue-related files a bit better 2025-04-30 13:53:43 +02:00
Filip Tibell
c35eaa7899
Organize thread-related files a bit better 2025-04-30 13:39:35 +02:00
Filip Tibell
b57fa6fad3
Optimize tracking of thread results in mlua-luau-scheduler 2025-04-30 13:35:42 +02:00
Filip Tibell
3e80a0a1c4
Reduce overhead of lua result tracking 2025-04-29 23:31:10 +02:00
Filip Tibell
ac8c809a20
Implement zero-copy hyper body type that wraps over lua values 2025-04-29 23:00:03 +02:00
Filip Tibell
4079842a33
Re enable crate publishing in release workflow 2025-04-29 16:00:59 +02:00
Filip Tibell
464c431697
Version 0.9.1 2025-04-29 15:59:49 +02:00
Filip Tibell
39f6319bdb
Bye bye tokio 2025-04-29 15:38:04 +02:00
Filip Tibell
62910f02ab
Rewrite the net standard library with smol ecosystem of crates (#310) 2025-04-29 15:06:16 +02:00
Sasial
1f43ff89f7
RuntimeReturnValues should derive Debug (#309) 2025-04-26 22:51:17 +02:00
Filip Tibell
e234eab813
Make pretext in latest changelog entry clearer 2025-04-25 16:43:12 +02:00
109 changed files with 4058 additions and 2419 deletions

View file

@ -2,15 +2,16 @@ name: CI
on: on:
push: push:
pull_request:
workflow_dispatch: workflow_dispatch:
defaults: defaults:
run: run:
shell: bash shell: bash
jobs: env:
CARGO_TERM_COLOR: always
jobs:
fmt: fmt:
name: Check formatting name: Check formatting
runs-on: ubuntu-latest runs-on: ubuntu-latest
@ -79,23 +80,26 @@ jobs:
components: clippy components: clippy
targets: ${{ matrix.cargo-target }} targets: ${{ matrix.cargo-target }}
- name: Install binstall
uses: cargo-bins/cargo-binstall@main
- name: Install nextest
run: cargo binstall cargo-nextest
- name: Build - name: Build
run: | run: |
cargo build \ cargo build --workspace \
--workspace \
--locked --all-features \ --locked --all-features \
--target ${{ matrix.cargo-target }} --target ${{ matrix.cargo-target }}
- name: Lint - name: Lint
run: | run: |
cargo clippy \ cargo clippy --workspace \
--workspace \
--locked --all-features \ --locked --all-features \
--target ${{ matrix.cargo-target }} --target ${{ matrix.cargo-target }}
- name: Test - name: Test
run: | run: |
cargo test \ cargo nextest run --no-fail-fast \
--lib --workspace \
--locked --all-features \ --locked --all-features \
--target ${{ matrix.cargo-target }} --target ${{ matrix.cargo-target }}

View file

@ -27,23 +27,23 @@ jobs:
file: crates/lune/Cargo.toml file: crates/lune/Cargo.toml
field: package.version field: package.version
# dry-run: dry-run:
# name: Dry-run name: Dry-run
# needs: ["init"] needs: ["init"]
# runs-on: ubuntu-latest runs-on: ubuntu-latest
# steps: steps:
# - name: Checkout repository - name: Checkout repository
# uses: actions/checkout@v4 uses: actions/checkout@v4
# - name: Install Rust - name: Install Rust
# uses: dtolnay/rust-toolchain@stable uses: dtolnay/rust-toolchain@stable
# - name: Publish (dry-run) - name: Publish (dry-run)
# uses: katyo/publish-crates@v2 uses: katyo/publish-crates@v2
# with: with:
# dry-run: true dry-run: true
# check-repo: true check-repo: true
# registry-token: ${{ secrets.CARGO_REGISTRY_TOKEN }} registry-token: ${{ secrets.CARGO_REGISTRY_TOKEN }}
build: build:
needs: ["init"] # , "dry-run"] needs: ["init"] # , "dry-run"]
@ -139,20 +139,20 @@ jobs:
files: ./releases/*.zip files: ./releases/*.zip
draft: true draft: true
# release-crates: release-crates:
# name: Release (crates.io) name: Release (crates.io)
# runs-on: ubuntu-latest runs-on: ubuntu-latest
# needs: ["init", "dry-run", "build"] needs: ["init", "dry-run", "build"]
# steps: steps:
# - name: Checkout repository - name: Checkout repository
# uses: actions/checkout@v4 uses: actions/checkout@v4
# - name: Install Rust - name: Install Rust
# uses: dtolnay/rust-toolchain@stable uses: dtolnay/rust-toolchain@stable
# - name: Publish crates - name: Publish crates
# uses: katyo/publish-crates@v2 uses: katyo/publish-crates@v2
# with: with:
# dry-run: false dry-run: false
# check-repo: true check-repo: true
# registry-token: ${{ secrets.CARGO_REGISTRY_TOKEN }} registry-token: ${{ secrets.CARGO_REGISTRY_TOKEN }}

5
.gitignore vendored
View file

@ -21,7 +21,12 @@ lune.yml
luneDocs.json luneDocs.json
luneTypes.d.luau luneTypes.d.luau
# Dirs generated by runtime or build scripts
/types
# Files generated by runtime or build scripts # Files generated by runtime or build scripts
scripts/brick_color.rs scripts/brick_color.rs
scripts/font_enum_map.rs scripts/font_enum_map.rs
scripts/physical_properties_enum_map.rs scripts/physical_properties_enum_map.rs

View file

@ -65,6 +65,7 @@ fmt-check:
analyze: analyze:
#!/usr/bin/env bash #!/usr/bin/env bash
set -euo pipefail set -euo pipefail
lune run scripts/analyze_copy_typedefs
luau-lsp analyze \ luau-lsp analyze \
--settings=".vscode/settings.json" \ --settings=".vscode/settings.json" \
--ignore="tests/roblox/rbx-test-files/**" \ --ignore="tests/roblox/rbx-test-files/**" \

View file

@ -3,7 +3,6 @@
local net = require("@lune/net") local net = require("@lune/net")
local process = require("@lune/process") local process = require("@lune/process")
local task = require("@lune/task")
local PORT = if process.env.PORT ~= nil and #process.env.PORT > 0 local PORT = if process.env.PORT ~= nil and #process.env.PORT > 0
then assert(tonumber(process.env.PORT), "Failed to parse port from env") then assert(tonumber(process.env.PORT), "Failed to parse port from env")
@ -11,6 +10,10 @@ local PORT = if process.env.PORT ~= nil and #process.env.PORT > 0
-- Create our responder functions -- Create our responder functions
local function root(_request: net.ServeRequest): string
return `Hello from Lune server!`
end
local function pong(request: net.ServeRequest): string local function pong(request: net.ServeRequest): string
return `Pong!\n{request.path}\n{request.body}` return `Pong!\n{request.path}\n{request.body}`
end end
@ -29,10 +32,12 @@ local function notFound(_request: net.ServeRequest): net.ServeResponse
} }
end end
-- Run the server on port 8080 -- Run the server on the port forever
local handle = net.serve(PORT, function(request) net.serve(PORT, function(request)
if string.sub(request.path, 1, 5) == "/ping" then if request.path == "/" then
return root(request)
elseif string.sub(request.path, 1, 5) == "/ping" then
return pong(request) return pong(request)
elseif string.sub(request.path, 1, 7) == "/teapot" then elseif string.sub(request.path, 1, 7) == "/teapot" then
return teapot(request) return teapot(request)
@ -42,12 +47,4 @@ local handle = net.serve(PORT, function(request)
end) end)
print(`Listening on port {PORT} 🚀`) print(`Listening on port {PORT} 🚀`)
print("Press Ctrl+C to stop")
-- Exit our example after a small delay, if you copy this
-- example just remove this part to keep the server running
task.delay(2, function()
print("Shutting down...")
task.wait(1)
handle.stop()
end)

View file

@ -8,10 +8,52 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## Unreleased
### Added
- Added support for non-UTF8 strings in arguments to `process.exec` and `process.spawn`
### Changed
- Improved cross-platform compatibility and correctness for values in `process.args` and `process.env`, especially on Windows
### Fixed
- Fixed various crashes during require that had the error `cannot mutably borrow app data container`
## `0.9.2` - April 30th, 2025
### Changed
- Improved performance of `net.request` and `net.serve` when handling large request bodies
- Improved performance and memory usage of `task.spawn`, `task.defer`, and `task.delay`
### Fixed
- Fixed accidental breakage of `net.request` in version `0.9.1`
## `0.9.1` - April 29th, 2025
### Added
- Added support for automatic decompression of HTTP requests in `net.serve` ([#310])
### Fixed
- Fixed `net.serve` no longer serving requests if the returned `ServeHandle` is discarded ([#310])
- Fixed `net.serve` having various performance issues ([#310])
- Fixed Lune still running after cancelling a task such as `task.delay(5, ...)` and all tasks having completed
[#310]: https://github.com/lune-org/lune/pull/310
## `0.9.0` - April 25th, 2025 ## `0.9.0` - April 25th, 2025
This release has been a long time coming, and many breaking changes have been deferred until this version. The next major version of Lune has finally been released!
If you are an existing Lune user upgrading to this version, you will **most likely** be affected - please read the full list of breaking changes below.
This release has been a long time coming, and many breaking changes have been made.
If you are an existing Lune user upgrading to this version, you will **most likely** be affected.
The full list of breaking changes can be found on below.
### Breaking changes & additions ### Breaking changes & additions

989
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,6 @@
[package] [package]
name = "lune-roblox" name = "lune-roblox"
version = "0.2.0" version = "0.2.2"
edition = "2021" edition = "2021"
license = "MPL-2.0" license = "MPL-2.0"
repository = "https://github.com/lune-org/lune" repository = "https://github.com/lune-org/lune"
@ -25,4 +25,4 @@ rbx_reflection = "5.0"
rbx_reflection_database = "1.0" rbx_reflection_database = "1.0"
rbx_xml = "1.0" rbx_xml = "1.0"
lune-utils = { version = "0.2.0", path = "../lune-utils" } lune-utils = { version = "0.2.2", path = "../lune-utils" }

View file

@ -1,6 +1,6 @@
[package] [package]
name = "lune-std-datetime" name = "lune-std-datetime"
version = "0.2.0" version = "0.2.2"
edition = "2021" edition = "2021"
license = "MPL-2.0" license = "MPL-2.0"
repository = "https://github.com/lune-org/lune" repository = "https://github.com/lune-org/lune"
@ -19,4 +19,4 @@ thiserror = "2.0"
chrono = "0.4.38" chrono = "0.4.38"
chrono_lc = "0.1.6" chrono_lc = "0.1.6"
lune-utils = { version = "0.2.0", path = "../lune-utils" } lune-utils = { version = "0.2.2", path = "../lune-utils" }

View file

@ -1,6 +1,6 @@
[package] [package]
name = "lune-std-fs" name = "lune-std-fs"
version = "0.2.0" version = "0.2.2"
edition = "2021" edition = "2021"
license = "MPL-2.0" license = "MPL-2.0"
repository = "https://github.com/lune-org/lune" repository = "https://github.com/lune-org/lune"
@ -19,5 +19,5 @@ async-fs = "2.1"
bstr = "1.9" bstr = "1.9"
futures-lite = "2.6" futures-lite = "2.6"
lune-utils = { version = "0.2.0", path = "../lune-utils" } lune-utils = { version = "0.2.2", path = "../lune-utils" }
lune-std-datetime = { version = "0.2.0", path = "../lune-std-datetime" } lune-std-datetime = { version = "0.2.2", path = "../lune-std-datetime" }

View file

@ -1,6 +1,6 @@
[package] [package]
name = "lune-std-luau" name = "lune-std-luau"
version = "0.2.0" version = "0.2.2"
edition = "2021" edition = "2021"
license = "MPL-2.0" license = "MPL-2.0"
repository = "https://github.com/lune-org/lune" repository = "https://github.com/lune-org/lune"
@ -15,4 +15,4 @@ workspace = true
[dependencies] [dependencies]
mlua = { version = "0.10.3", features = ["luau", "luau-jit"] } mlua = { version = "0.10.3", features = ["luau", "luau-jit"] }
lune-utils = { version = "0.2.0", path = "../lune-utils" } lune-utils = { version = "0.2.2", path = "../lune-utils" }

View file

@ -1,6 +1,6 @@
[package] [package]
name = "lune-std-net" name = "lune-std-net"
version = "0.2.0" version = "0.2.2"
edition = "2021" edition = "2021"
license = "MPL-2.0" license = "MPL-2.0"
repository = "https://github.com/lune-org/lune" repository = "https://github.com/lune-org/lune"
@ -14,26 +14,29 @@ workspace = true
[dependencies] [dependencies]
mlua = { version = "0.10.3", features = ["luau"] } mlua = { version = "0.10.3", features = ["luau"] }
mlua-luau-scheduler = { version = "0.1.0", path = "../mlua-luau-scheduler" } mlua-luau-scheduler = { version = "0.1.2", path = "../mlua-luau-scheduler" }
async-channel = "2.3"
async-executor = "1.13"
async-io = "2.4"
async-lock = "3.4"
async-net = "2.0"
async-tungstenite = "0.29"
blocking = "1.6"
bstr = "1.9" bstr = "1.9"
futures-util = "0.3" form_urlencoded = "1.2"
hyper = { version = "1.1", features = ["full"] } futures = { version = "0.3", default-features = false, features = ["std"] }
hyper-util = { version = "0.1", features = ["full"] } futures-lite = "2.6"
http = "1.0" futures-rustls = "0.26"
http-body-util = { version = "0.1" } http-body-util = "0.1"
hyper-tungstenite = { version = "0.13" } hyper = { version = "1.6", default-features = false, features = ["http1", "client", "server"] }
reqwest = { version = "0.11", default-features = false, features = [ pin-project-lite = "0.2"
"rustls-tls", rustls = { version = "0.23", default-features = false, features = ["std", "tls12", "ring"] }
] } rustls-pki-types = "1.11"
tokio-tungstenite = { version = "0.21", features = ["rustls-tls-webpki-roots"] } url = "2.5"
urlencoding = "2.1" urlencoding = "2.1"
webpki = "0.22"
webpki-roots = "0.26"
tokio = { version = "1", default-features = false, features = [ lune-utils = { version = "0.2.2", path = "../lune-utils" }
"sync", lune-std-serde = { version = "0.2.2", path = "../lune-std-serde" }
"net",
"macros",
] }
lune-utils = { version = "0.2.0", path = "../lune-utils" }
lune-std-serde = { version = "0.2.0", path = "../lune-std-serde" }

View file

@ -0,0 +1,59 @@
use hyper::body::{Buf, Bytes};
use super::inner::ReadableBodyInner;
/**
The cursor keeping track of inner data and its position for a readable body.
*/
#[derive(Debug, Clone)]
pub struct ReadableBodyCursor {
inner: ReadableBodyInner,
start: usize,
}
impl ReadableBodyCursor {
pub fn len(&self) -> usize {
self.inner.len()
}
pub fn as_slice(&self) -> &[u8] {
&self.inner.as_slice()[self.start..]
}
pub fn advance(&mut self, cnt: usize) {
self.start += cnt;
if self.start > self.inner.len() {
self.start = self.inner.len();
}
}
pub fn into_bytes(self) -> Bytes {
self.inner.into_bytes()
}
}
impl Buf for ReadableBodyCursor {
fn remaining(&self) -> usize {
self.len().saturating_sub(self.start)
}
fn chunk(&self) -> &[u8] {
self.as_slice()
}
fn advance(&mut self, cnt: usize) {
self.advance(cnt);
}
}
impl<T> From<T> for ReadableBodyCursor
where
T: Into<ReadableBodyInner>,
{
fn from(value: T) -> Self {
Self {
inner: value.into(),
start: 0,
}
}
}

View file

@ -0,0 +1,35 @@
use http_body_util::BodyExt;
use hyper::{
body::{Bytes, Incoming},
header::CONTENT_ENCODING,
HeaderMap,
};
use mlua::prelude::*;
use lune_std_serde::{decompress, CompressDecompressFormat};
pub async fn handle_incoming_body(
headers: &HeaderMap,
body: Incoming,
should_decompress: bool,
) -> LuaResult<(Bytes, bool)> {
let mut body = body.collect().await.into_lua_err()?.to_bytes();
let was_decompressed = if should_decompress {
let decompress_format = headers
.get(CONTENT_ENCODING)
.and_then(|value| value.to_str().ok())
.and_then(CompressDecompressFormat::detect_from_header_str);
if let Some(format) = decompress_format {
body = Bytes::from(decompress(body, format).await?);
true
} else {
false
}
} else {
false
};
Ok((body, was_decompressed))
}

View file

@ -0,0 +1,110 @@
use hyper::body::{Buf as _, Bytes};
use mlua::{prelude::*, Buffer as LuaBuffer};
/**
The inner data for a readable body.
*/
#[derive(Debug, Clone)]
pub enum ReadableBodyInner {
Bytes(Bytes),
String(String),
LuaString(LuaString),
LuaBuffer(LuaBuffer),
}
impl ReadableBodyInner {
pub fn len(&self) -> usize {
match self {
Self::Bytes(b) => b.len(),
Self::String(s) => s.len(),
Self::LuaString(s) => s.as_bytes().len(),
Self::LuaBuffer(b) => b.len(),
}
}
pub fn as_slice(&self) -> &[u8] {
/*
SAFETY: Reading lua strings and lua buffers as raw slices is safe while we can
guarantee that the inner Lua value + main lua struct has not yet been dropped
1. Buffers are fixed-size and guaranteed to never resize
2. We do not expose any method for writing to the body, only reading
3. We guarantee that net.request and net.serve futures are only driven forward
while we also know that the Lua + scheduler pair have not yet been dropped
4. Any writes from within lua to a buffer, are considered user error,
and are not unsafe, since the only possible outcome with the above
guarantees is invalid / mangled contents in request / response bodies
*/
match self {
Self::Bytes(b) => b.chunk(),
Self::String(s) => s.as_bytes(),
Self::LuaString(s) => unsafe {
// BorrowedBytes would not let us return a plain slice here,
// which is what the Buf implementation below needs - we need to
// do a little hack here to re-create the slice without a lifetime
let b = s.as_bytes();
let ptr = b.as_ptr();
let len = b.len();
std::slice::from_raw_parts(ptr, len)
},
Self::LuaBuffer(b) => unsafe {
// Similar to above, we need to get the raw slice for the buffer,
// which is a bit trickier here because Buffer has a read + write
// interface instead of using slices for some unknown reason
let v = LuaValue::Buffer(b.clone());
let ptr = v.to_pointer().cast::<u8>();
let len = b.len();
std::slice::from_raw_parts(ptr, len)
},
}
}
pub fn into_bytes(self) -> Bytes {
match self {
Self::Bytes(b) => b,
Self::String(s) => Bytes::from(s),
Self::LuaString(s) => Bytes::from(s.as_bytes().to_vec()),
Self::LuaBuffer(b) => Bytes::from(b.to_vec()),
}
}
}
impl From<&'static str> for ReadableBodyInner {
fn from(value: &'static str) -> Self {
Self::Bytes(Bytes::from(value))
}
}
impl From<Vec<u8>> for ReadableBodyInner {
fn from(value: Vec<u8>) -> Self {
Self::Bytes(Bytes::from(value))
}
}
impl From<Bytes> for ReadableBodyInner {
fn from(value: Bytes) -> Self {
Self::Bytes(value)
}
}
impl From<String> for ReadableBodyInner {
fn from(value: String) -> Self {
Self::String(value)
}
}
impl From<LuaString> for ReadableBodyInner {
fn from(value: LuaString) -> Self {
Self::LuaString(value)
}
}
impl From<LuaBuffer> for ReadableBodyInner {
fn from(value: LuaBuffer) -> Self {
Self::LuaBuffer(value)
}
}

View file

@ -0,0 +1,11 @@
#![allow(unused_imports)]
mod cursor;
mod incoming;
mod inner;
mod readable;
pub use self::cursor::ReadableBodyCursor;
pub use self::incoming::handle_incoming_body;
pub use self::inner::ReadableBodyInner;
pub use self::readable::ReadableBody;

View file

@ -0,0 +1,105 @@
use std::convert::Infallible;
use std::pin::Pin;
use std::task::{Context, Poll};
use hyper::body::{Body, Bytes, Frame, SizeHint};
use mlua::prelude::*;
use super::cursor::ReadableBodyCursor;
/**
Zero-copy wrapper for a readable body.
Provides methods to read bytes that can be safely used if, and only
if, the respective Lua struct for the body has not yet been dropped.
If the body was created from a `Vec<u8>`, `Bytes`, or a `String`, reading
bytes is always safe and does not go through any additional indirections.
*/
#[derive(Debug, Clone)]
pub struct ReadableBody {
cursor: Option<ReadableBodyCursor>,
}
impl ReadableBody {
pub const fn empty() -> Self {
Self { cursor: None }
}
pub fn as_slice(&self) -> &[u8] {
match self.cursor.as_ref() {
Some(cursor) => cursor.as_slice(),
None => &[],
}
}
pub fn into_bytes(self) -> Bytes {
match self.cursor {
Some(cursor) => cursor.into_bytes(),
None => Bytes::new(),
}
}
}
impl Body for ReadableBody {
type Data = ReadableBodyCursor;
type Error = Infallible;
fn poll_frame(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
Poll::Ready(self.cursor.take().map(|d| Ok(Frame::data(d))))
}
fn is_end_stream(&self) -> bool {
self.cursor.is_none()
}
fn size_hint(&self) -> SizeHint {
self.cursor.as_ref().map_or_else(
|| SizeHint::with_exact(0),
|c| SizeHint::with_exact(c.len() as u64),
)
}
}
impl<T> From<T> for ReadableBody
where
T: Into<ReadableBodyCursor>,
{
fn from(value: T) -> Self {
Self {
cursor: Some(value.into()),
}
}
}
impl<T> From<Option<T>> for ReadableBody
where
T: Into<ReadableBodyCursor>,
{
fn from(value: Option<T>) -> Self {
Self {
cursor: value.map(Into::into),
}
}
}
impl FromLua for ReadableBody {
fn from_lua(value: LuaValue, _: &Lua) -> LuaResult<Self> {
match value {
LuaValue::Nil => Ok(Self::empty()),
LuaValue::String(str) => Ok(Self::from(str)),
LuaValue::Buffer(buf) => Ok(Self::from(buf)),
v => Err(LuaError::FromLuaConversionError {
from: v.type_name(),
to: "Body".to_string(),
message: Some(format!(
"Invalid body - expected string or buffer, got {}",
v.type_name()
)),
}),
}
}
}

View file

@ -1,163 +0,0 @@
use std::str::FromStr;
use mlua::prelude::*;
use reqwest::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_ENCODING};
use lune_std_serde::{decompress, CompressDecompressFormat};
use lune_utils::TableBuilder;
use super::{config::RequestConfig, util::header_map_to_table};
const REGISTRY_KEY: &str = "NetClient";
pub struct NetClientBuilder {
builder: reqwest::ClientBuilder,
}
impl NetClientBuilder {
pub fn new() -> NetClientBuilder {
Self {
builder: reqwest::ClientBuilder::new(),
}
}
pub fn headers<K, V>(mut self, headers: &[(K, V)]) -> LuaResult<Self>
where
K: AsRef<str>,
V: AsRef<[u8]>,
{
let mut map = HeaderMap::new();
for (key, val) in headers {
let hkey = HeaderName::from_str(key.as_ref()).into_lua_err()?;
let hval = HeaderValue::from_bytes(val.as_ref()).into_lua_err()?;
map.insert(hkey, hval);
}
self.builder = self.builder.default_headers(map);
Ok(self)
}
pub fn build(self) -> LuaResult<NetClient> {
let client = self.builder.build().into_lua_err()?;
Ok(NetClient { inner: client })
}
}
#[derive(Debug, Clone)]
pub struct NetClient {
inner: reqwest::Client,
}
impl NetClient {
pub fn from_registry(lua: &Lua) -> Self {
lua.named_registry_value(REGISTRY_KEY)
.expect("Failed to get NetClient from lua registry")
}
pub fn into_registry(self, lua: &Lua) {
lua.set_named_registry_value(REGISTRY_KEY, self)
.expect("Failed to store NetClient in lua registry");
}
pub async fn request(&self, config: RequestConfig) -> LuaResult<NetClientResponse> {
// Create and send the request
let mut request = self.inner.request(config.method, config.url);
for (query, values) in config.query {
request = request.query(
&values
.iter()
.map(|v| (query.as_str(), v))
.collect::<Vec<_>>(),
);
}
for (header, values) in config.headers {
for value in values {
request = request.header(header.as_str(), value);
}
}
let res = request
.body(config.body.unwrap_or_default())
.send()
.await
.into_lua_err()?;
// Extract status, headers
let res_status = res.status().as_u16();
let res_status_text = res.status().canonical_reason();
let res_headers = res.headers().clone();
// Read response bytes
let mut res_bytes = res.bytes().await.into_lua_err()?.to_vec();
let mut res_decompressed = false;
// Check for extra options, decompression
if config.options.decompress {
let decompress_format = res_headers
.iter()
.find(|(name, _)| {
name.as_str()
.eq_ignore_ascii_case(CONTENT_ENCODING.as_str())
})
.and_then(|(_, value)| value.to_str().ok())
.and_then(CompressDecompressFormat::detect_from_header_str);
if let Some(format) = decompress_format {
res_bytes = decompress(res_bytes, format).await?;
res_decompressed = true;
}
}
Ok(NetClientResponse {
ok: (200..300).contains(&res_status),
status_code: res_status,
status_message: res_status_text.unwrap_or_default().to_string(),
headers: res_headers,
body: res_bytes,
body_decompressed: res_decompressed,
})
}
}
impl LuaUserData for NetClient {}
impl FromLua for NetClient {
fn from_lua(value: LuaValue, _: &Lua) -> LuaResult<Self> {
if let LuaValue::UserData(ud) = value {
if let Ok(ctx) = ud.borrow::<NetClient>() {
return Ok(ctx.clone());
}
}
unreachable!("NetClient should only be used from registry")
}
}
impl From<&Lua> for NetClient {
fn from(value: &Lua) -> Self {
value
.named_registry_value(REGISTRY_KEY)
.expect("Missing require context in lua registry")
}
}
pub struct NetClientResponse {
ok: bool,
status_code: u16,
status_message: String,
headers: HeaderMap,
body: Vec<u8>,
body_decompressed: bool,
}
impl NetClientResponse {
pub fn into_lua_table(self, lua: &Lua) -> LuaResult<LuaTable> {
TableBuilder::new(lua.clone())?
.with_value("ok", self.ok)?
.with_value("statusCode", self.status_code)?
.with_value("statusMessage", self.status_message)?
.with_value(
"headers",
header_map_to_table(lua, self.headers, self.body_decompressed)?,
)?
.with_value("body", lua.create_string(&self.body)?)?
.build_readonly()
}
}

View file

@ -0,0 +1,95 @@
use std::{
io,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use async_net::TcpStream;
use futures_lite::prelude::*;
use futures_rustls::{TlsConnector, TlsStream};
use rustls_pki_types::ServerName;
use url::Url;
use crate::client::rustls::CLIENT_CONFIG;
#[derive(Debug)]
pub enum HttpStream {
Plain(TcpStream),
Tls(TlsStream<TcpStream>),
}
impl HttpStream {
pub async fn connect(url: Url) -> Result<Self, io::Error> {
let Some(host) = url.host() else {
return Err(make_err("unknown or missing host"));
};
let Some(port) = url.port_or_known_default() else {
return Err(make_err("unknown or missing port"));
};
let use_tls = match url.scheme() {
"http" => false,
"https" => true,
s => return Err(make_err(format!("unsupported scheme: {s}"))),
};
let host = host.to_string();
let stream = TcpStream::connect((host.clone(), port)).await?;
let stream = if use_tls {
let servname = ServerName::try_from(host).map_err(make_err)?.to_owned();
let connector = TlsConnector::from(Arc::clone(&CLIENT_CONFIG));
let stream = connector.connect(servname, stream).await?;
Self::Tls(TlsStream::Client(stream))
} else {
Self::Plain(stream)
};
Ok(stream)
}
}
impl AsyncRead for HttpStream {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
match &mut *self {
HttpStream::Plain(stream) => Pin::new(stream).poll_read(cx, buf),
HttpStream::Tls(stream) => Pin::new(stream).poll_read(cx, buf),
}
}
}
impl AsyncWrite for HttpStream {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
match &mut *self {
HttpStream::Plain(stream) => Pin::new(stream).poll_write(cx, buf),
HttpStream::Tls(stream) => Pin::new(stream).poll_write(cx, buf),
}
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut *self {
HttpStream::Plain(stream) => Pin::new(stream).poll_close(cx),
HttpStream::Tls(stream) => Pin::new(stream).poll_close(cx),
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
match &mut *self {
HttpStream::Plain(stream) => Pin::new(stream).poll_flush(cx),
HttpStream::Tls(stream) => Pin::new(stream).poll_flush(cx),
}
}
}
fn make_err(e: impl ToString) -> io::Error {
io::Error::new(io::ErrorKind::Other, e.to_string())
}

View file

@ -0,0 +1,125 @@
use http_body_util::Full;
use hyper::{
body::Incoming,
client::conn::http1::handshake,
header::{HeaderValue, ACCEPT, CONTENT_LENGTH, HOST, LOCATION, USER_AGENT},
Method, Request as HyperRequest, Response as HyperResponse, Uri,
};
use mlua::prelude::*;
use url::Url;
use crate::{
body::ReadableBody,
client::{http_stream::HttpStream, ws_stream::WsStream},
shared::{
headers::create_user_agent_header,
hyper::{HyperExecutor, HyperIo},
request::Request,
response::Response,
websocket::Websocket,
},
};
pub mod http_stream;
pub mod rustls;
pub mod ws_stream;
const MAX_REDIRECTS: usize = 10;
/**
Connects to a websocket at the given URL.
*/
pub async fn connect_websocket(url: Url) -> LuaResult<Websocket<WsStream>> {
let stream = WsStream::connect(url).await?;
Ok(Websocket::from(stream))
}
/**
Sends the request and returns the final response.
This will follow any redirects returned by the server,
modifying the request method and body as necessary.
*/
pub async fn send_request(mut request: Request, lua: Lua) -> LuaResult<Response> {
let url = request
.inner
.uri()
.to_string()
.parse::<Url>()
.into_lua_err()?;
// Some headers are required by most if not
// all servers, make sure those are present...
if !request.headers().contains_key(HOST.as_str()) {
if let Some(host) = url.host_str() {
let host = HeaderValue::from_str(host).into_lua_err()?;
request.inner.headers_mut().insert(HOST, host);
}
}
if !request.headers().contains_key(USER_AGENT.as_str()) {
let ua = create_user_agent_header(&lua)?;
let ua = HeaderValue::from_str(&ua).into_lua_err()?;
request.inner.headers_mut().insert(USER_AGENT, ua);
}
if !request.headers().contains_key(CONTENT_LENGTH.as_str()) && request.method() != Method::GET {
let len = request.body().len().to_string();
let len = HeaderValue::from_str(&len).into_lua_err()?;
request.inner.headers_mut().insert(CONTENT_LENGTH, len);
}
if !request.headers().contains_key(ACCEPT.as_str()) {
let accept = HeaderValue::from_static("*/*");
request.inner.headers_mut().insert(ACCEPT, accept);
}
// ... we can now safely continue and send the request
loop {
let stream = HttpStream::connect(url.clone()).await?;
let (mut sender, conn) = handshake(HyperIo::from(stream)).await.into_lua_err()?;
HyperExecutor::execute(lua.clone(), conn);
let (parts, body) = request.clone_inner().into_parts();
let data = HyperRequest::from_parts(parts, Full::new(body.into_bytes()));
let incoming = sender.send_request(data).await.into_lua_err()?;
if let Some((new_method, new_uri)) =
check_redirect(request.inner.method().clone(), &incoming)
{
if request.redirects.is_some_and(|r| r >= MAX_REDIRECTS) {
return Err(LuaError::external("Too many redirects"));
}
if new_method == Method::GET {
*request.inner.body_mut() = ReadableBody::empty();
}
*request.inner.method_mut() = new_method;
*request.inner.uri_mut() = new_uri;
*request.redirects.get_or_insert_default() += 1;
continue;
}
break Response::from_incoming(incoming, request.decompress).await;
}
}
fn check_redirect(method: Method, response: &HyperResponse<Incoming>) -> Option<(Method, Uri)> {
if !response.status().is_redirection() {
return None;
}
let location = response.headers().get(LOCATION)?;
let location = location.to_str().ok()?;
let location = location.parse().ok()?;
let method = match response.status().as_u16() {
301..=303 => Method::GET,
_ => method,
};
Some((method, location))
}

View file

@ -0,0 +1,26 @@
use std::sync::{
atomic::{AtomicBool, Ordering},
Arc, LazyLock,
};
use rustls::{crypto::ring, ClientConfig};
static PROVIDER_INITIALIZED: AtomicBool = AtomicBool::new(false);
pub fn initialize_provider() {
if !PROVIDER_INITIALIZED.load(Ordering::Relaxed) {
PROVIDER_INITIALIZED.store(true, Ordering::Relaxed);
// Only errors if already installed, which is fine
ring::default_provider().install_default().ok();
}
}
pub static CLIENT_CONFIG: LazyLock<Arc<ClientConfig>> = LazyLock::new(|| {
initialize_provider();
rustls::ClientConfig::builder()
.with_root_certificates(rustls::RootCertStore {
roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
})
.with_no_client_auth()
.into()
});

View file

@ -0,0 +1,114 @@
use std::{
io,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use async_net::TcpStream;
use async_tungstenite::{
tungstenite::{Error as TungsteniteError, Message, Result as TungsteniteResult},
WebSocketStream as TungsteniteStream,
};
use futures::Sink;
use futures_lite::prelude::*;
use futures_rustls::{TlsConnector, TlsStream};
use rustls_pki_types::ServerName;
use url::Url;
use crate::client::rustls::CLIENT_CONFIG;
#[derive(Debug)]
pub enum WsStream {
Plain(TungsteniteStream<TcpStream>),
Tls(TungsteniteStream<TlsStream<TcpStream>>),
}
impl WsStream {
pub async fn connect(url: Url) -> Result<Self, io::Error> {
let Some(host) = url.host() else {
return Err(make_err("unknown or missing host"));
};
let Some(port) = url.port_or_known_default() else {
return Err(make_err("unknown or missing port"));
};
let use_tls = match url.scheme() {
"ws" => false,
"wss" => true,
s => return Err(make_err(format!("unsupported scheme: {s}"))),
};
let host = host.to_string();
let stream = TcpStream::connect((host.clone(), port)).await?;
let stream = if use_tls {
let servname = ServerName::try_from(host).map_err(make_err)?.to_owned();
let connector = TlsConnector::from(Arc::clone(&CLIENT_CONFIG));
let stream = connector.connect(servname, stream).await?;
let stream = TlsStream::Client(stream);
let stream = async_tungstenite::client_async(url.to_string(), stream)
.await
.map_err(make_err)?
.0;
Self::Tls(stream)
} else {
let stream = async_tungstenite::client_async(url.to_string(), stream)
.await
.map_err(make_err)?
.0;
Self::Plain(stream)
};
Ok(stream)
}
}
impl Sink<Message> for WsStream {
type Error = TungsteniteError;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match &mut *self {
WsStream::Plain(s) => Pin::new(s).poll_ready(cx),
WsStream::Tls(s) => Pin::new(s).poll_ready(cx),
}
}
fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
match &mut *self {
WsStream::Plain(s) => Pin::new(s).start_send(item),
WsStream::Tls(s) => Pin::new(s).start_send(item),
}
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match &mut *self {
WsStream::Plain(s) => Pin::new(s).poll_flush(cx),
WsStream::Tls(s) => Pin::new(s).poll_flush(cx),
}
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
match &mut *self {
WsStream::Plain(s) => Pin::new(s).poll_close(cx),
WsStream::Tls(s) => Pin::new(s).poll_close(cx),
}
}
}
impl Stream for WsStream {
type Item = TungsteniteResult<Message>;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
match &mut *self {
WsStream::Plain(s) => Pin::new(s).poll_next(cx),
WsStream::Tls(s) => Pin::new(s).poll_next(cx),
}
}
}
fn make_err(e: impl ToString) -> io::Error {
io::Error::new(io::ErrorKind::Other, e.to_string())
}

View file

@ -1,231 +0,0 @@
use std::{
collections::HashMap,
net::{IpAddr, Ipv4Addr},
};
use bstr::{BString, ByteSlice};
use mlua::prelude::*;
use reqwest::Method;
use super::util::table_to_hash_map;
const DEFAULT_IP_ADDRESS: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
const WEB_SOCKET_UPDGRADE_REQUEST_HANDLER: &str = r#"
return {
status = 426,
body = "Upgrade Required",
headers = {
Upgrade = "websocket",
},
}
"#;
// Net request config
#[derive(Debug, Clone)]
pub struct RequestConfigOptions {
pub decompress: bool,
}
impl Default for RequestConfigOptions {
fn default() -> Self {
Self { decompress: true }
}
}
impl FromLua for RequestConfigOptions {
fn from_lua(value: LuaValue, _: &Lua) -> LuaResult<Self> {
if let LuaValue::Nil = value {
// Nil means default options
Ok(Self::default())
} else if let LuaValue::Table(tab) = value {
// Table means custom options
let decompress = match tab.get::<Option<bool>>("decompress") {
Ok(decomp) => Ok(decomp.unwrap_or(true)),
Err(_) => Err(LuaError::RuntimeError(
"Invalid option value for 'decompress' in request config options".to_string(),
)),
}?;
Ok(Self { decompress })
} else {
// Anything else is invalid
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "RequestConfigOptions".to_string(),
message: Some(format!(
"Invalid request config options - expected table or nil, got {}",
value.type_name()
)),
})
}
}
}
#[derive(Debug, Clone)]
pub struct RequestConfig {
pub url: String,
pub method: Method,
pub query: HashMap<String, Vec<String>>,
pub headers: HashMap<String, Vec<String>>,
pub body: Option<Vec<u8>>,
pub options: RequestConfigOptions,
}
impl FromLua for RequestConfig {
fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult<Self> {
// If we just got a string we assume its a GET request to a given url
if let LuaValue::String(s) = value {
Ok(Self {
url: s.to_string_lossy().to_string(),
method: Method::GET,
query: HashMap::new(),
headers: HashMap::new(),
body: None,
options: RequestConfigOptions::default(),
})
} else if let LuaValue::Table(tab) = value {
// If we got a table we are able to configure the entire request
// Extract url
let url = match tab.get::<LuaString>("url") {
Ok(config_url) => Ok(config_url.to_string_lossy().to_string()),
Err(_) => Err(LuaError::runtime("Missing 'url' in request config")),
}?;
// Extract method
let method = match tab.get::<LuaString>("method") {
Ok(config_method) => config_method.to_string_lossy().trim().to_ascii_uppercase(),
Err(_) => "GET".to_string(),
};
// Extract query
let query = match tab.get::<LuaTable>("query") {
Ok(tab) => table_to_hash_map(tab, "query")?,
Err(_) => HashMap::new(),
};
// Extract headers
let headers = match tab.get::<LuaTable>("headers") {
Ok(tab) => table_to_hash_map(tab, "headers")?,
Err(_) => HashMap::new(),
};
// Extract body
let body = match tab.get::<BString>("body") {
Ok(config_body) => Some(config_body.as_bytes().to_owned()),
Err(_) => None,
};
// Convert method string into proper enum
let method = method.trim().to_ascii_uppercase();
let method = match method.as_ref() {
"GET" => Ok(Method::GET),
"POST" => Ok(Method::POST),
"PUT" => Ok(Method::PUT),
"DELETE" => Ok(Method::DELETE),
"HEAD" => Ok(Method::HEAD),
"OPTIONS" => Ok(Method::OPTIONS),
"PATCH" => Ok(Method::PATCH),
_ => Err(LuaError::RuntimeError(format!(
"Invalid request config method '{}'",
&method
))),
}?;
// Parse any extra options given
let options = match tab.get::<LuaValue>("options") {
Ok(opts) => RequestConfigOptions::from_lua(opts, lua)?,
Err(_) => RequestConfigOptions::default(),
};
// All good, validated and we got what we need
Ok(Self {
url,
method,
query,
headers,
body,
options,
})
} else {
// Anything else is invalid
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "RequestConfig".to_string(),
message: Some(format!(
"Invalid request config - expected string or table, got {}",
value.type_name()
)),
})
}
}
}
// Net serve config
#[derive(Debug)]
pub struct ServeConfig {
pub address: IpAddr,
pub handle_request: LuaFunction,
pub handle_web_socket: Option<LuaFunction>,
}
impl FromLua for ServeConfig {
fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult<Self> {
if let LuaValue::Function(f) = &value {
// Single function = request handler, rest is default
Ok(ServeConfig {
handle_request: f.clone(),
handle_web_socket: None,
address: DEFAULT_IP_ADDRESS,
})
} else if let LuaValue::Table(t) = &value {
// Table means custom options
let address: Option<LuaString> = t.get("address")?;
let handle_request: Option<LuaFunction> = t.get("handleRequest")?;
let handle_web_socket: Option<LuaFunction> = t.get("handleWebSocket")?;
if handle_request.is_some() || handle_web_socket.is_some() {
let address: IpAddr = match &address {
Some(addr) => {
let addr_str = addr.to_str()?;
addr_str
.trim_start_matches("http://")
.trim_start_matches("https://")
.parse()
.map_err(|_e| LuaError::FromLuaConversionError {
from: value.type_name(),
to: "ServeConfig".to_string(),
message: Some(format!(
"IP address format is incorrect - \
expected an IP in the form 'http://0.0.0.0' or '0.0.0.0', \
got '{addr_str}'"
)),
})?
}
None => DEFAULT_IP_ADDRESS,
};
Ok(Self {
address,
handle_request: handle_request.unwrap_or_else(|| {
lua.load(WEB_SOCKET_UPDGRADE_REQUEST_HANDLER)
.into_function()
.expect("Failed to create default http responder function")
}),
handle_web_socket,
})
} else {
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "ServeConfig".to_string(),
message: Some(String::from(
"Invalid serve config - expected table with 'handleRequest' or 'handleWebSocket' function",
)),
})
}
} else {
// Anything else is invalid
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "ServeConfig".to_string(),
message: None,
})
}
}
}

View file

@ -1,22 +1,18 @@
#![allow(clippy::cargo_common_metadata)] #![allow(clippy::cargo_common_metadata)]
use mlua::prelude::*;
use mlua_luau_scheduler::LuaSpawnExt;
mod client;
mod config;
mod server;
mod util;
mod websocket;
use lune_utils::TableBuilder; use lune_utils::TableBuilder;
use mlua::prelude::*;
pub(crate) mod body;
pub(crate) mod client;
pub(crate) mod server;
pub(crate) mod shared;
pub(crate) mod url;
use self::{ use self::{
client::{NetClient, NetClientBuilder}, client::ws_stream::WsStream,
config::{RequestConfig, ServeConfig}, server::config::ServeConfig,
server::serve, shared::{request::Request, response::Response, websocket::Websocket},
util::create_user_agent_header,
websocket::NetWebSocket,
}; };
const TYPEDEFS: &str = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/types.d.luau")); const TYPEDEFS: &str = include_str!(concat!(env!("CARGO_MANIFEST_DIR"), "/types.d.luau"));
@ -37,10 +33,8 @@ pub fn typedefs() -> String {
Errors when out of memory. Errors when out of memory.
*/ */
pub fn module(lua: Lua) -> LuaResult<LuaTable> { pub fn module(lua: Lua) -> LuaResult<LuaTable> {
NetClientBuilder::new() // No initial rustls setup is necessary, the respective
.headers(&[("User-Agent", create_user_agent_header(&lua)?)])? // functions lazily initialize anything there as needed
.build()?
.into_registry(&lua);
TableBuilder::new(lua)? TableBuilder::new(lua)?
.with_async_function("request", net_request)? .with_async_function("request", net_request)?
.with_async_function("socket", net_socket)? .with_async_function("socket", net_socket)?
@ -50,42 +44,35 @@ pub fn module(lua: Lua) -> LuaResult<LuaTable> {
.build_readonly() .build_readonly()
} }
async fn net_request(lua: Lua, config: RequestConfig) -> LuaResult<LuaTable> { async fn net_request(lua: Lua, req: Request) -> LuaResult<Response> {
let client = NetClient::from_registry(&lua); self::client::send_request(req, lua).await
// NOTE: We spawn the request as a background task to free up resources in lua
let res = lua.spawn(async move { client.request(config).await });
res.await?.into_lua_table(&lua)
} }
async fn net_socket(lua: Lua, url: String) -> LuaResult<LuaValue> { async fn net_socket(_: Lua, url: String) -> LuaResult<Websocket<WsStream>> {
let (ws, _) = tokio_tungstenite::connect_async(url).await.into_lua_err()?; let url = url.parse().into_lua_err()?;
NetWebSocket::new(ws).into_lua(&lua) self::client::connect_websocket(url).await
} }
async fn net_serve(lua: Lua, (port, config): (u16, ServeConfig)) -> LuaResult<LuaTable> { async fn net_serve(lua: Lua, (port, config): (u16, ServeConfig)) -> LuaResult<LuaTable> {
serve(lua, port, config).await self::server::serve(lua.clone(), port, config)
.await?
.into_lua_table(lua)
} }
fn net_url_encode( fn net_url_encode(
lua: &Lua, lua: &Lua,
(lua_string, as_binary): (LuaString, Option<bool>), (lua_string, as_binary): (LuaString, Option<bool>),
) -> LuaResult<LuaValue> { ) -> LuaResult<LuaString> {
if matches!(as_binary, Some(true)) { let as_binary = as_binary.unwrap_or_default();
urlencoding::encode_binary(&lua_string.as_bytes()).into_lua(lua) let bytes = self::url::encode(lua_string, as_binary)?;
} else { lua.create_string(bytes)
urlencoding::encode(&lua_string.to_str()?).into_lua(lua)
}
} }
fn net_url_decode( fn net_url_decode(
lua: &Lua, lua: &Lua,
(lua_string, as_binary): (LuaString, Option<bool>), (lua_string, as_binary): (LuaString, Option<bool>),
) -> LuaResult<LuaValue> { ) -> LuaResult<LuaString> {
if matches!(as_binary, Some(true)) { let as_binary = as_binary.unwrap_or_default();
urlencoding::decode_binary(&lua_string.as_bytes()).into_lua(lua) let bytes = self::url::decode(lua_string, as_binary)?;
} else { lua.create_string(bytes)
urlencoding::decode(&lua_string.to_str()?)
.map_err(|e| LuaError::RuntimeError(format!("Encountered invalid encoding - {e}")))?
.into_lua(lua)
}
} }

View file

@ -0,0 +1,87 @@
use std::net::{IpAddr, Ipv4Addr};
use mlua::prelude::*;
const DEFAULT_IP_ADDRESS: IpAddr = IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1));
const WEB_SOCKET_UPDGRADE_REQUEST_HANDLER: &str = r#"
return {
status = 426,
body = "Upgrade Required",
headers = {
Upgrade = "websocket",
},
}
"#;
#[derive(Debug, Clone)]
pub struct ServeConfig {
pub address: IpAddr,
pub handle_request: LuaFunction,
pub handle_web_socket: Option<LuaFunction>,
}
impl FromLua for ServeConfig {
fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult<Self> {
if let LuaValue::Function(f) = &value {
// Single function = request handler, rest is default
Ok(ServeConfig {
handle_request: f.clone(),
handle_web_socket: None,
address: DEFAULT_IP_ADDRESS,
})
} else if let LuaValue::Table(t) = &value {
// Table means custom options
let address: Option<LuaString> = t.get("address")?;
let handle_request: Option<LuaFunction> = t.get("handleRequest")?;
let handle_web_socket: Option<LuaFunction> = t.get("handleWebSocket")?;
if handle_request.is_some() || handle_web_socket.is_some() {
let address: IpAddr = match &address {
Some(addr) => {
let addr_str = addr.to_str()?;
addr_str
.trim_start_matches("http://")
.trim_start_matches("https://")
.parse()
.map_err(|_e| LuaError::FromLuaConversionError {
from: value.type_name(),
to: "ServeConfig".to_string(),
message: Some(format!(
"IP address format is incorrect - \
expected an IP in the form 'http://0.0.0.0' or '0.0.0.0', \
got '{addr_str}'"
)),
})?
}
None => DEFAULT_IP_ADDRESS,
};
Ok(Self {
address,
handle_request: handle_request.unwrap_or_else(|| {
lua.load(WEB_SOCKET_UPDGRADE_REQUEST_HANDLER)
.into_function()
.expect("Failed to create default http responder function")
}),
handle_web_socket,
})
} else {
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "ServeConfig".to_string(),
message: Some(String::from(
"Invalid serve config - expected table with 'handleRequest' or 'handleWebSocket' function",
)),
})
}
} else {
// Anything else is invalid
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "ServeConfig".to_string(),
message: None,
})
}
}
}

View file

@ -0,0 +1,72 @@
use std::{
net::SocketAddr,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use async_channel::{unbounded, Receiver, Sender};
use lune_utils::TableBuilder;
use mlua::prelude::*;
#[derive(Debug, Clone)]
pub struct ServeHandle {
addr: SocketAddr,
shutdown: Arc<AtomicBool>,
sender: Sender<()>,
}
impl ServeHandle {
pub fn new(addr: SocketAddr) -> (Self, Receiver<()>) {
let (sender, receiver) = unbounded();
let this = Self {
addr,
shutdown: Arc::new(AtomicBool::new(false)),
sender,
};
(this, receiver)
}
// TODO: Remove this in the next major release to use colon/self
// based call syntax and userdata implementation below instead
pub fn into_lua_table(self, lua: Lua) -> LuaResult<LuaTable> {
let shutdown = self.shutdown.clone();
let sender = self.sender.clone();
TableBuilder::new(lua)?
.with_value("ip", self.addr.ip().to_string())?
.with_value("port", self.addr.port())?
.with_function("stop", move |_, ()| {
if shutdown.load(Ordering::SeqCst) {
Err(LuaError::runtime("Server already stopped"))
} else {
shutdown.store(true, Ordering::SeqCst);
sender.try_send(()).ok();
sender.close();
Ok(())
}
})?
.build()
}
}
impl LuaUserData for ServeHandle {
fn add_fields<F: LuaUserDataFields<Self>>(fields: &mut F) {
fields.add_field_method_get("ip", |_, this| Ok(this.addr.ip().to_string()));
fields.add_field_method_get("port", |_, this| Ok(this.addr.port()));
}
fn add_methods<M: LuaUserDataMethods<Self>>(methods: &mut M) {
methods.add_method("stop", |_, this, ()| {
if this.shutdown.load(Ordering::SeqCst) {
Err(LuaError::runtime("Server already stopped"))
} else {
this.shutdown.store(true, Ordering::SeqCst);
this.sender.try_send(()).ok();
this.sender.close();
Ok(())
}
});
}
}

View file

@ -1,58 +0,0 @@
use std::sync::atomic::{AtomicUsize, Ordering};
use mlua::prelude::*;
#[derive(Debug, Clone, Copy)]
pub(super) struct SvcKeys {
key_request: &'static str,
key_websocket: Option<&'static str>,
}
impl SvcKeys {
pub(super) fn new(
lua: Lua,
handle_request: LuaFunction,
handle_websocket: Option<LuaFunction>,
) -> LuaResult<Self> {
static SERVE_COUNTER: AtomicUsize = AtomicUsize::new(0);
let count = SERVE_COUNTER.fetch_add(1, Ordering::Relaxed);
// NOTE: We leak strings here, but this is an acceptable tradeoff since programs
// generally only start one or a couple of servers and they are usually never dropped.
// Leaking here lets us keep this struct Copy and access the request handler callbacks
// very performantly, significantly reducing the per-request overhead of the server.
let key_request: &'static str =
Box::leak(format!("__net_serve_request_{count}").into_boxed_str());
let key_websocket: Option<&'static str> = if handle_websocket.is_some() {
Some(Box::leak(
format!("__net_serve_websocket_{count}").into_boxed_str(),
))
} else {
None
};
lua.set_named_registry_value(key_request, handle_request)?;
if let Some(key) = key_websocket {
lua.set_named_registry_value(key, handle_websocket.unwrap())?;
}
Ok(Self {
key_request,
key_websocket,
})
}
pub(super) fn has_websocket_handler(&self) -> bool {
self.key_websocket.is_some()
}
pub(super) fn request_handler(&self, lua: &Lua) -> LuaResult<LuaFunction> {
lua.named_registry_value(self.key_request)
}
pub(super) fn websocket_handler(&self, lua: &Lua) -> LuaResult<Option<LuaFunction>> {
self.key_websocket
.map(|key| lua.named_registry_value(key))
.transpose()
}
}

View file

@ -1,92 +1,121 @@
use std::net::SocketAddr; use std::{cell::Cell, net::SocketAddr, rc::Rc};
use hyper::server::conn::http1; use async_net::TcpListener;
use hyper_util::rt::TokioIo; use futures_lite::pin;
use tokio::{net::TcpListener, pin}; use hyper::server::conn::http1::Builder as Http1Builder;
use mlua::prelude::*; use mlua::prelude::*;
use mlua_luau_scheduler::LuaSpawnExt; use mlua_luau_scheduler::LuaSpawnExt;
use lune_utils::TableBuilder; use crate::{
server::{config::ServeConfig, handle::ServeHandle, service::Service},
shared::{
futures::{either, Either},
hyper::{HyperIo, HyperTimer},
},
};
use super::config::ServeConfig; pub mod config;
pub mod handle;
pub mod service;
pub mod upgrade;
mod keys; /**
mod request; Starts an HTTP server using the given port and configuration.
mod response;
mod service;
use keys::SvcKeys; Returns a `ServeHandle` that can be used to gracefully stop the server.
use service::Svc; */
pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult<ServeHandle> {
pub async fn serve(lua: Lua, port: u16, config: ServeConfig) -> LuaResult<LuaTable> { let address = SocketAddr::from((config.address, port));
let addr: SocketAddr = (config.address, port).into(); let service = Service {
let listener = TcpListener::bind(addr).await?; lua: lua.clone(),
address,
let lua_svc = lua.clone(); config,
let lua_inner = lua.clone();
let keys = SvcKeys::new(lua.clone(), config.handle_request, config.handle_web_socket)?;
let svc = Svc {
lua: lua_svc,
addr,
keys,
}; };
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false); let listener = TcpListener::bind(address).await?;
lua.spawn_local(async move { let (handle, shutdown_rx) = ServeHandle::new(address);
let mut shutdown_rx_outer = shutdown_rx.clone();
loop { lua.spawn_local({
// Create futures for accepting new connections and shutting down let lua = lua.clone();
let fut_shutdown = shutdown_rx_outer.changed(); async move {
let fut_accept = async { let handle_dropped = Rc::new(Cell::new(false));
let stream = match listener.accept().await { loop {
Err(_) => return, // 1. Keep accepting new connections until we should shutdown
Ok((s, _)) => s, let (conn, addr) = if handle_dropped.get() {
// 1a. Handle has been dropped, and we don't need to listen for shutdown
match listener.accept().await {
Ok(acc) => acc,
Err(_err) => {
// TODO: Propagate error somehow
continue;
}
}
} else {
// 1b. Handle is possibly active, we must listen for shutdown
match either(shutdown_rx.recv(), listener.accept()).await {
Either::Left(Ok(())) => break,
Either::Left(Err(_)) => {
// NOTE #1: We will only get a RecvError if the serve handle is dropped,
// this means lua has garbage collected it and the user does not want
// to manually stop the server using the serve handle. Run forever.
handle_dropped.set(true);
continue;
}
Either::Right(Ok(acc)) => acc,
Either::Right(Err(_err)) => {
// TODO: Propagate error somehow
continue;
}
}
}; };
let io = TokioIo::new(stream); // 2. For each connection, spawn a new task to handle it
let svc = svc.clone(); lua.spawn_local({
let mut shutdown_rx_inner = shutdown_rx.clone(); let rx = shutdown_rx.clone();
let io = HyperIo::from(conn);
lua_inner.spawn_local(async move { let mut svc = service.clone();
let conn = http1::Builder::new() svc.address = addr;
.keep_alive(true) // Web sockets need this
.serve_connection(io, svc) let handle_dropped = Rc::clone(&handle_dropped);
.with_upgrades(); async move {
// NOTE: Because we need to use keep_alive for websockets, we need to let conn = Http1Builder::new()
// also manually poll this future and handle the shutdown signal here .writev(false)
pin!(conn); .timer(HyperTimer)
tokio::select! { .keep_alive(true)
_ = conn.as_mut() => {} .serve_connection(io, svc)
_ = shutdown_rx_inner.changed() => { .with_upgrades();
conn.as_mut().graceful_shutdown(); if handle_dropped.get() {
if let Err(_err) = conn.await {
// TODO: Propagate error somehow
}
} else {
// NOTE #2: Because we use keep_alive for websockets above, we need to
// also manually poll this future and handle the graceful shutdown,
// otherwise the already accepted connection will linger and run
// even if the stop method has been called on the serve handle
pin!(conn);
match either(rx.recv(), conn.as_mut()).await {
Either::Left(Ok(())) => conn.as_mut().graceful_shutdown(),
Either::Left(Err(_)) => {
// Same as note #1
handle_dropped.set(true);
if let Err(_err) = conn.await {
// TODO: Propagate error somehow
}
}
Either::Right(Ok(())) => {}
Either::Right(Err(_err)) => {
// TODO: Propagate error somehow
}
}
} }
} }
}); });
};
// Wait for either a new connection or a shutdown signal
tokio::select! {
() = fut_accept => {}
res = fut_shutdown => {
// NOTE: We will only get a RecvError here if the serve handle is dropped,
// this means lua has garbage collected it and the user does not want
// to manually stop the server using the serve handle. Run forever.
if res.is_ok() {
break;
}
}
} }
} }
}); });
TableBuilder::new(lua)? Ok(handle)
.with_value("ip", addr.ip().to_string())?
.with_value("port", addr.port())?
.with_function("stop", move |_, (): ()| match shutdown_tx.send(true) {
Ok(()) => Ok(()),
Err(_) => Err(LuaError::runtime("Server already stopped")),
})?
.build_readonly()
} }

View file

@ -1,56 +0,0 @@
use std::{collections::HashMap, net::SocketAddr};
use http::request::Parts;
use mlua::prelude::*;
use lune_utils::TableBuilder;
pub(super) struct LuaRequest {
pub(super) _remote_addr: SocketAddr,
pub(super) head: Parts,
pub(super) body: Vec<u8>,
}
impl LuaRequest {
pub fn into_lua_table(self, lua: &Lua) -> LuaResult<LuaTable> {
let method = self.head.method.as_str().to_string();
let path = self.head.uri.path().to_string();
let body = lua.create_string(&self.body)?;
#[allow(clippy::mutable_key_type)]
let query: HashMap<LuaString, LuaString> = self
.head
.uri
.query()
.unwrap_or_default()
.split('&')
.filter_map(|q| q.split_once('='))
.map(|(k, v)| {
let k = lua.create_string(k)?;
let v = lua.create_string(v)?;
Ok((k, v))
})
.collect::<LuaResult<_>>()?;
#[allow(clippy::mutable_key_type)]
let headers: HashMap<LuaString, LuaString> = self
.head
.headers
.iter()
.map(|(k, v)| {
let k = lua.create_string(k.as_str())?;
let v = lua.create_string(v.as_bytes())?;
Ok((k, v))
})
.collect::<LuaResult<_>>()?;
TableBuilder::new(lua.clone())?
.with_value("method", method)?
.with_value("path", path)?
.with_value("query", query)?
.with_value("headers", headers)?
.with_value("body", body)?
.build()
}
}

View file

@ -1,89 +0,0 @@
use std::str::FromStr;
use bstr::{BString, ByteSlice};
use http_body_util::Full;
use hyper::{
body::Bytes,
header::{HeaderName, HeaderValue},
HeaderMap, Response,
};
use mlua::prelude::*;
#[derive(Debug, Clone, Copy)]
pub(super) enum LuaResponseKind {
PlainText,
Table,
}
pub(super) struct LuaResponse {
pub(super) kind: LuaResponseKind,
pub(super) status: u16,
pub(super) headers: HeaderMap,
pub(super) body: Option<Vec<u8>>,
}
impl LuaResponse {
pub(super) fn into_response(self) -> LuaResult<Response<Full<Bytes>>> {
Ok(match self.kind {
LuaResponseKind::PlainText => Response::builder()
.status(200)
.header("Content-Type", "text/plain")
.body(Full::new(Bytes::from(self.body.unwrap())))
.into_lua_err()?,
LuaResponseKind::Table => {
let mut response = Response::builder()
.status(self.status)
.body(Full::new(Bytes::from(self.body.unwrap_or_default())))
.into_lua_err()?;
response.headers_mut().extend(self.headers);
response
}
})
}
}
impl FromLua for LuaResponse {
fn from_lua(value: LuaValue, _: &Lua) -> LuaResult<Self> {
match value {
// Plain strings from the handler are plaintext responses
LuaValue::String(s) => Ok(Self {
kind: LuaResponseKind::PlainText,
status: 200,
headers: HeaderMap::new(),
body: Some(s.as_bytes().to_vec()),
}),
// Tables are more detailed responses with potential status, headers, body
LuaValue::Table(t) => {
let status: Option<u16> = t.get("status")?;
let headers: Option<LuaTable> = t.get("headers")?;
let body: Option<BString> = t.get("body")?;
let mut headers_map = HeaderMap::new();
if let Some(headers) = headers {
for pair in headers.pairs::<String, LuaString>() {
let (h, v) = pair?;
let name = HeaderName::from_str(&h).into_lua_err()?;
let value = HeaderValue::from_bytes(&v.as_bytes()).into_lua_err()?;
headers_map.insert(name, value);
}
}
let body_bytes = body.map(|s| s.as_bytes().to_vec());
Ok(Self {
kind: LuaResponseKind::Table,
status: status.unwrap_or(200),
headers: headers_map,
body: body_bytes,
})
}
// Anything else is an error
value => Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "NetServeResponse".to_string(),
message: None,
}),
}
}
}

View file

@ -1,82 +1,116 @@
use std::{future::Future, net::SocketAddr, pin::Pin}; use std::{future::Future, net::SocketAddr, pin::Pin};
use http_body_util::{BodyExt, Full}; use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
use hyper::{ use hyper::{
body::{Bytes, Incoming}, body::Incoming, service::Service as HyperService, Request as HyperRequest,
service::Service, Response as HyperResponse, StatusCode,
Request, Response,
}; };
use hyper_tungstenite::{is_upgrade_request, upgrade};
use mlua::prelude::*; use mlua::prelude::*;
use mlua_luau_scheduler::{LuaSchedulerExt, LuaSpawnExt}; use mlua_luau_scheduler::{LuaSchedulerExt, LuaSpawnExt};
use super::{ use crate::{
super::websocket::NetWebSocket, keys::SvcKeys, request::LuaRequest, response::LuaResponse, body::ReadableBody,
server::{
config::ServeConfig,
upgrade::{is_upgrade_request, make_upgrade_response},
},
shared::{hyper::HyperIo, request::Request, response::Response, websocket::Websocket},
}; };
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(super) struct Svc { pub(super) struct Service {
pub(super) lua: Lua, pub(super) lua: Lua,
pub(super) addr: SocketAddr, pub(super) address: SocketAddr, // NOTE: This must be the remote address of the connected client
pub(super) keys: SvcKeys, pub(super) config: ServeConfig,
} }
impl Service<Request<Incoming>> for Svc { impl HyperService<HyperRequest<Incoming>> for Service {
type Response = Response<Full<Bytes>>; type Response = HyperResponse<ReadableBody>;
type Error = LuaError; type Error = LuaError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>; type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
fn call(&self, req: Request<Incoming>) -> Self::Future { fn call(&self, req: HyperRequest<Incoming>) -> Self::Future {
let lua = self.lua.clone(); if is_upgrade_request(&req) {
let addr = self.addr; if let Some(handler) = self.config.handle_web_socket.clone() {
let keys = self.keys; let lua = self.lua.clone();
return Box::pin(async move {
let response = match make_upgrade_response(&req) {
Ok(res) => res,
Err(err) => {
return Ok(HyperResponse::builder()
.status(StatusCode::BAD_REQUEST)
.body(ReadableBody::from(err.to_string()))
.unwrap())
}
};
if keys.has_websocket_handler() && is_upgrade_request(&req) { lua.spawn_local({
Box::pin(async move { let lua = lua.clone();
let (res, sock) = upgrade(req, None).into_lua_err()?; async move {
if let Err(_err) = handle_websocket(lua, handler, req).await {
// TODO: Propagate the error somehow?
}
}
});
let lua_inner = lua.clone(); Ok(response)
lua.spawn_local(async move {
let sock = sock.await.unwrap();
let lua_sock = NetWebSocket::new(sock);
let lua_val = lua_sock.into_lua(&lua_inner).unwrap();
let handler_websocket: LuaFunction =
keys.websocket_handler(&lua_inner).unwrap().unwrap();
lua_inner
.push_thread_back(handler_websocket, lua_val)
.unwrap();
}); });
}
Ok(res)
})
} else {
let (head, body) = req.into_parts();
Box::pin(async move {
let handler_request: LuaFunction = keys.request_handler(&lua).unwrap();
let body = body.collect().await.into_lua_err()?;
let body = body.to_bytes().to_vec();
let lua_req = LuaRequest {
_remote_addr: addr,
head,
body,
};
let lua_req_table = lua_req.into_lua_table(&lua)?;
let thread_id = lua.push_thread_back(handler_request, lua_req_table)?;
lua.track_thread(thread_id);
lua.wait_for_thread(thread_id).await;
let thread_res = lua
.get_thread_result(thread_id)
.expect("Missing handler thread result")?;
LuaResponse::from_lua_multi(thread_res, &lua)?.into_response()
})
} }
let lua = self.lua.clone();
let address = self.address;
let handler = self.config.handle_request.clone();
Box::pin(async move {
match handle_request(lua, handler, req, address).await {
Ok(response) => Ok(response),
Err(_err) => {
// TODO: Propagate the error somehow?
Ok(HyperResponse::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(ReadableBody::from("Lune: Internal server error"))
.unwrap())
}
}
})
} }
} }
async fn handle_request(
lua: Lua,
handler: LuaFunction,
request: HyperRequest<Incoming>,
address: SocketAddr,
) -> LuaResult<HyperResponse<ReadableBody>> {
let request = Request::from_incoming(request, true)
.await?
.with_address(address);
let thread_id = lua.push_thread_back(handler, request)?;
lua.track_thread(thread_id);
lua.wait_for_thread(thread_id).await;
let thread_res = lua
.get_thread_result(thread_id)
.expect("Missing handler thread result")?;
let response = Response::from_lua_multi(thread_res, &lua)?;
Ok(response.into_inner())
}
async fn handle_websocket(
lua: Lua,
handler: LuaFunction,
request: HyperRequest<Incoming>,
) -> LuaResult<()> {
let upgraded = hyper::upgrade::on(request).await.into_lua_err()?;
let stream =
WebSocketStream::from_raw_socket(HyperIo::from(upgraded), Role::Server, None).await;
let websocket = Websocket::from(stream);
lua.push_thread_back(handler, websocket)?;
Ok(())
}

View file

@ -0,0 +1,56 @@
use async_tungstenite::tungstenite::{error::ProtocolError, handshake::derive_accept_key};
use hyper::{
body::Incoming,
header::{HeaderName, CONNECTION, UPGRADE},
HeaderMap, Request as HyperRequest, Response as HyperResponse, StatusCode,
};
use crate::body::ReadableBody;
const SEC_WEBSOCKET_VERSION: HeaderName = HeaderName::from_static("sec-websocket-version");
const SEC_WEBSOCKET_KEY: HeaderName = HeaderName::from_static("sec-websocket-key");
const SEC_WEBSOCKET_ACCEPT: HeaderName = HeaderName::from_static("sec-websocket-accept");
pub fn is_upgrade_request(request: &HyperRequest<Incoming>) -> bool {
fn check_header_contains(headers: &HeaderMap, header_name: HeaderName, value: &str) -> bool {
headers.get(header_name).is_some_and(|header| {
header.to_str().map_or_else(
|_| false,
|header_str| {
header_str
.split(',')
.any(|part| part.trim().eq_ignore_ascii_case(value))
},
)
})
}
check_header_contains(request.headers(), CONNECTION, "Upgrade")
&& check_header_contains(request.headers(), UPGRADE, "websocket")
}
pub fn make_upgrade_response(
request: &HyperRequest<Incoming>,
) -> Result<HyperResponse<ReadableBody>, ProtocolError> {
let key = request
.headers()
.get(SEC_WEBSOCKET_KEY)
.ok_or(ProtocolError::MissingSecWebSocketKey)?;
if request
.headers()
.get(SEC_WEBSOCKET_VERSION)
.is_none_or(|v| v.as_bytes() != b"13")
{
return Err(ProtocolError::MissingSecWebSocketVersionHeader);
}
Ok(HyperResponse::builder()
.status(StatusCode::SWITCHING_PROTOCOLS)
.header(CONNECTION, "upgrade")
.header(UPGRADE, "websocket")
.header(SEC_WEBSOCKET_ACCEPT, derive_accept_key(key.as_bytes()))
.body(ReadableBody::from("switching to websocket protocol"))
.unwrap())
}

View file

@ -0,0 +1,19 @@
use futures_lite::prelude::*;
pub use http_body_util::Either;
/**
Combines the left and right futures into a single future
that resolves to either the left or right output.
This combinator is biased - if both futures resolve at
the same time, the left future's output is returned.
*/
pub fn either<L: Future, R: Future>(
left: L,
right: R,
) -> impl Future<Output = Either<L::Output, R::Output>> {
let fut_left = async move { Either::Left(left.await) };
let fut_right = async move { Either::Right(right.await) };
fut_left.or(fut_right)
}

View file

@ -0,0 +1,88 @@
use std::collections::HashMap;
use hyper::{
header::{CONTENT_ENCODING, CONTENT_LENGTH},
HeaderMap,
};
use lune_utils::TableBuilder;
use mlua::prelude::*;
pub fn create_user_agent_header(lua: &Lua) -> LuaResult<String> {
let version_global = lua
.globals()
.get::<LuaString>("_VERSION")
.expect("Missing _VERSION global");
let version_global_str = version_global
.to_str()
.context("Invalid utf8 found in _VERSION global")?;
let (package_name, full_version) = version_global_str.split_once(' ').unwrap();
Ok(format!("{}/{}", package_name.to_lowercase(), full_version))
}
pub fn header_map_to_table(
lua: &Lua,
headers: HeaderMap,
remove_content_headers: bool,
) -> LuaResult<LuaTable> {
let mut string_map = HashMap::<String, Vec<String>>::new();
for (name, value) in headers {
if let Some(name) = name {
if let Ok(value) = value.to_str() {
string_map
.entry(name.to_string())
.or_default()
.push(value.to_owned());
}
}
}
hash_map_to_table(lua, string_map, remove_content_headers)
}
pub fn hash_map_to_table(
lua: &Lua,
map: impl IntoIterator<Item = (String, Vec<String>)>,
remove_content_headers: bool,
) -> LuaResult<LuaTable> {
let mut string_map = HashMap::<String, Vec<String>>::new();
for (name, values) in map {
let name = name.as_str();
if remove_content_headers {
let content_encoding_header_str = CONTENT_ENCODING.as_str();
let content_length_header_str = CONTENT_LENGTH.as_str();
if name == content_encoding_header_str || name == content_length_header_str {
continue;
}
}
for value in values {
let value = value.as_str();
string_map
.entry(name.to_owned())
.or_default()
.push(value.to_owned());
}
}
let mut builder = TableBuilder::new(lua.clone())?;
for (name, mut values) in string_map {
if values.len() == 1 {
let value = values.pop().unwrap().into_lua(lua)?;
builder = builder.with_value(name, value)?;
} else {
let values = TableBuilder::new(lua.clone())?
.with_sequential_values(values)?
.build_readonly()?
.into_lua(lua)?;
builder = builder.with_value(name, values)?;
}
}
builder.build_readonly()
}

View file

@ -0,0 +1,198 @@
use std::{
future::Future,
io,
pin::Pin,
slice,
task::{Context, Poll},
time::{Duration, Instant},
};
use async_io::Timer;
use futures_lite::{prelude::*, ready};
use hyper::rt::{self, Executor, ReadBuf, ReadBufCursor};
use mlua::prelude::*;
use mlua_luau_scheduler::LuaSpawnExt;
// Hyper executor that spawns futures onto our Lua scheduler
#[derive(Debug, Clone)]
pub struct HyperExecutor {
lua: Lua,
}
#[allow(dead_code)]
impl HyperExecutor {
pub fn execute<Fut>(lua: Lua, fut: Fut)
where
Fut: Future + Send + 'static,
Fut::Output: Send + 'static,
{
let exec = if let Some(exec) = lua.app_data_ref::<Self>() {
exec
} else {
lua.set_app_data(Self { lua: lua.clone() });
lua.app_data_ref::<Self>().unwrap()
};
exec.execute(fut);
}
}
impl<Fut: Future + Send + 'static> rt::Executor<Fut> for HyperExecutor
where
Fut::Output: Send + 'static,
{
fn execute(&self, fut: Fut) {
self.lua.spawn(fut).detach();
}
}
// Hyper timer & sleep future wrapper for async-io
#[derive(Debug)]
pub struct HyperTimer;
impl rt::Timer for HyperTimer {
fn sleep(&self, duration: Duration) -> Pin<Box<dyn rt::Sleep>> {
Box::pin(HyperSleep::from(Timer::after(duration)))
}
fn sleep_until(&self, at: Instant) -> Pin<Box<dyn rt::Sleep>> {
Box::pin(HyperSleep::from(Timer::at(at)))
}
fn reset(&self, sleep: &mut Pin<Box<dyn rt::Sleep>>, new_deadline: Instant) {
if let Some(mut sleep) = sleep.as_mut().downcast_mut_pin::<HyperSleep>() {
sleep.inner.set_at(new_deadline);
} else {
*sleep = Box::pin(HyperSleep::from(Timer::at(new_deadline)));
}
}
}
#[derive(Debug)]
pub struct HyperSleep {
inner: Timer,
}
impl From<Timer> for HyperSleep {
fn from(inner: Timer) -> Self {
Self { inner }
}
}
impl Future for HyperSleep {
type Output = ();
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
match Pin::new(&mut self.inner).poll(cx) {
Poll::Ready(_) => Poll::Ready(()),
Poll::Pending => Poll::Pending,
}
}
}
impl rt::Sleep for HyperSleep {}
// Hyper I/O wrapper for bidirectional compatibility
// between hyper & futures-lite async read/write traits
pin_project_lite::pin_project! {
#[derive(Debug)]
pub struct HyperIo<T> {
#[pin]
inner: T
}
}
impl<T> From<T> for HyperIo<T> {
fn from(inner: T) -> Self {
Self { inner }
}
}
impl<T> HyperIo<T> {
pub fn pin_mut(self: Pin<&mut Self>) -> Pin<&mut T> {
self.project().inner
}
}
// Compat for futures-lite -> hyper runtime
impl<T: AsyncRead> rt::Read for HyperIo<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
mut buf: ReadBufCursor<'_>,
) -> Poll<io::Result<()>> {
// Fill the read buffer with initialized data
let read_slice = unsafe {
let buffer = buf.as_mut();
buffer.as_mut_ptr().write_bytes(0, buffer.len());
slice::from_raw_parts_mut(buffer.as_mut_ptr().cast::<u8>(), buffer.len())
};
// Read bytes from the underlying source
let n = match self.pin_mut().poll_read(cx, read_slice) {
Poll::Ready(Ok(n)) => n,
Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
Poll::Pending => return Poll::Pending,
};
unsafe {
buf.advance(n);
}
Poll::Ready(Ok(()))
}
}
impl<T: AsyncWrite> rt::Write for HyperIo<T> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
self.pin_mut().poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.pin_mut().poll_flush(cx)
}
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.pin_mut().poll_close(cx)
}
}
// Compat for hyper runtime -> futures-lite
impl<T: rt::Read> AsyncRead for HyperIo<T> {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let mut buf = ReadBuf::new(buf);
ready!(self.pin_mut().poll_read(cx, buf.unfilled()))?;
Poll::Ready(Ok(buf.filled().len()))
}
}
impl<T: rt::Write> AsyncWrite for HyperIo<T> {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<Result<usize, std::io::Error>> {
self.pin_mut().poll_write(cx, buf)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), std::io::Error>> {
self.pin_mut().poll_flush(cx)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
self.pin_mut().poll_shutdown(cx)
}
}

View file

@ -0,0 +1,41 @@
use hyper::{
header::{HeaderName, HeaderValue},
HeaderMap, Method,
};
use mlua::prelude::*;
pub fn lua_value_to_method(value: &LuaValue) -> LuaResult<Method> {
match value {
LuaValue::Nil => Ok(Method::GET),
LuaValue::String(str) => {
let bytes = str.as_bytes().trim_ascii().to_ascii_uppercase();
Method::from_bytes(&bytes).into_lua_err()
}
LuaValue::Buffer(buf) => {
let bytes = buf.to_vec().trim_ascii().to_ascii_uppercase();
Method::from_bytes(&bytes).into_lua_err()
}
v => Err(LuaError::FromLuaConversionError {
from: v.type_name(),
to: "Method".to_string(),
message: Some(format!(
"Invalid method - expected string or buffer, got {}",
v.type_name()
)),
}),
}
}
pub fn lua_table_to_header_map(table: &LuaTable) -> LuaResult<HeaderMap> {
let mut headers = HeaderMap::new();
for pair in table.pairs::<LuaString, LuaString>() {
let (key, val) = pair?;
let key = HeaderName::from_bytes(&key.as_bytes()).into_lua_err()?;
let val = HeaderValue::from_bytes(&val.as_bytes()).into_lua_err()?;
headers.insert(key, val);
}
Ok(headers)
}

View file

@ -0,0 +1,7 @@
pub mod futures;
pub mod headers;
pub mod hyper;
pub mod lua;
pub mod request;
pub mod response;
pub mod websocket;

View file

@ -0,0 +1,256 @@
use std::{collections::HashMap, net::SocketAddr};
use url::Url;
use hyper::{body::Incoming, HeaderMap, Method, Request as HyperRequest};
use mlua::prelude::*;
use crate::{
body::{handle_incoming_body, ReadableBody},
shared::{
headers::{hash_map_to_table, header_map_to_table},
lua::{lua_table_to_header_map, lua_value_to_method},
},
};
#[derive(Debug, Clone)]
pub struct RequestOptions {
pub decompress: bool,
}
impl Default for RequestOptions {
fn default() -> Self {
Self { decompress: true }
}
}
impl FromLua for RequestOptions {
fn from_lua(value: LuaValue, _: &Lua) -> LuaResult<Self> {
if let LuaValue::Nil = value {
// Nil means default options
Ok(Self::default())
} else if let LuaValue::Table(tab) = value {
// Table means custom options
let decompress = match tab.get::<Option<bool>>("decompress") {
Ok(decomp) => Ok(decomp.unwrap_or(true)),
Err(_) => Err(LuaError::RuntimeError(
"Invalid option value for 'decompress' in request options".to_string(),
)),
}?;
Ok(Self { decompress })
} else {
// Anything else is invalid
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "RequestOptions".to_string(),
message: Some(format!(
"Invalid request options - expected table or nil, got {}",
value.type_name()
)),
})
}
}
}
#[derive(Debug, Clone)]
pub struct Request {
pub(crate) inner: HyperRequest<ReadableBody>,
pub(crate) address: Option<SocketAddr>,
pub(crate) redirects: Option<usize>,
pub(crate) decompress: bool,
}
impl Request {
/**
Creates a new request from a raw incoming request.
*/
pub async fn from_incoming(
incoming: HyperRequest<Incoming>,
decompress: bool,
) -> LuaResult<Self> {
let (parts, body) = incoming.into_parts();
let (body, decompress) = handle_incoming_body(&parts.headers, body, decompress).await?;
Ok(Self {
inner: HyperRequest::from_parts(parts, ReadableBody::from(body)),
address: None,
redirects: None,
decompress,
})
}
/**
Attaches a socket address to the request.
This will make the `ip` and `port` fields available on the request.
*/
pub fn with_address(mut self, address: SocketAddr) -> Self {
self.address = Some(address);
self
}
/**
Returns the method of the request.
*/
pub fn method(&self) -> Method {
self.inner.method().clone()
}
/**
Returns the path of the request.
*/
pub fn path(&self) -> &str {
self.inner.uri().path()
}
/**
Returns the query parameters of the request.
*/
pub fn query(&self) -> HashMap<String, Vec<String>> {
let uri = self.inner.uri();
let mut result = HashMap::<String, Vec<String>>::new();
if let Some(query) = uri.query() {
for (key, value) in form_urlencoded::parse(query.as_bytes()) {
result
.entry(key.to_string())
.or_default()
.push(value.to_string());
}
}
result
}
/**
Returns the headers of the request.
*/
pub fn headers(&self) -> &HeaderMap {
self.inner.headers()
}
/**
Returns the body of the request.
*/
pub fn body(&self) -> &[u8] {
self.inner.body().as_slice()
}
/**
Clones the inner `hyper` request.
*/
#[allow(dead_code)]
pub fn clone_inner(&self) -> HyperRequest<ReadableBody> {
self.inner.clone()
}
/**
Takes the inner `hyper` request by ownership.
*/
#[allow(dead_code)]
pub fn into_inner(self) -> HyperRequest<ReadableBody> {
self.inner
}
}
impl FromLua for Request {
fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult<Self> {
if let LuaValue::String(s) = value {
// If we just got a string we assume
// its a GET request to a given url
let uri = s.to_str()?;
let uri = uri.parse().into_lua_err()?;
let mut request = HyperRequest::new(ReadableBody::empty());
*request.uri_mut() = uri;
Ok(Self {
inner: request,
address: None,
redirects: None,
decompress: RequestOptions::default().decompress,
})
} else if let LuaValue::Table(tab) = value {
// If we got a table we are able to configure the
// entire request, maybe with extra options too
let options = match tab.get::<LuaValue>("options") {
Ok(opts) => RequestOptions::from_lua(opts, lua)?,
Err(_) => RequestOptions::default(),
};
// Extract url (required) + optional structured query params
let url = tab.get::<LuaString>("url")?;
let mut url = url.to_str()?.parse::<Url>().into_lua_err()?;
if let Some(t) = tab.get::<Option<LuaTable>>("query")? {
let mut query = url.query_pairs_mut();
for pair in t.pairs::<LuaString, LuaString>() {
let (key, value) = pair?;
let key = key.to_str()?;
let value = value.to_str()?;
query.append_pair(&key, &value);
}
}
// Extract method
let method = tab.get::<LuaValue>("method")?;
let method = lua_value_to_method(&method)?;
// Extract headers
let headers = tab.get::<Option<LuaTable>>("headers")?;
let headers = headers
.map(|t| lua_table_to_header_map(&t))
.transpose()?
.unwrap_or_default();
// Extract body
let body = tab.get::<ReadableBody>("body")?;
// Build the full request
let mut request = HyperRequest::new(body);
request.headers_mut().extend(headers);
*request.uri_mut() = url.to_string().parse().unwrap();
*request.method_mut() = method;
// All good, validated and we got what we need
Ok(Self {
inner: request,
address: None,
redirects: None,
decompress: options.decompress,
})
} else {
// Anything else is invalid
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "Request".to_string(),
message: Some(format!(
"Invalid request - expected string or table, got {}",
value.type_name()
)),
})
}
}
}
impl LuaUserData for Request {
fn add_fields<F: LuaUserDataFields<Self>>(fields: &mut F) {
fields.add_field_method_get("ip", |_, this| {
Ok(this.address.map(|address| address.ip().to_string()))
});
fields.add_field_method_get("port", |_, this| {
Ok(this.address.map(|address| address.port()))
});
fields.add_field_method_get("method", |_, this| Ok(this.method().to_string()));
fields.add_field_method_get("path", |_, this| Ok(this.path().to_string()));
fields.add_field_method_get("query", |lua, this| {
hash_map_to_table(lua, this.query(), false)
});
fields.add_field_method_get("headers", |lua, this| {
header_map_to_table(lua, this.headers().clone(), this.decompress)
});
fields.add_field_method_get("body", |lua, this| lua.create_string(this.body()));
}
}

View file

@ -0,0 +1,153 @@
use hyper::{
body::Incoming,
header::{HeaderValue, CONTENT_TYPE},
HeaderMap, Response as HyperResponse, StatusCode,
};
use mlua::prelude::*;
use crate::{
body::{handle_incoming_body, ReadableBody},
shared::{headers::header_map_to_table, lua::lua_table_to_header_map},
};
#[derive(Debug, Clone)]
pub struct Response {
pub(crate) inner: HyperResponse<ReadableBody>,
pub(crate) decompressed: bool,
}
impl Response {
/**
Creates a new response from a raw incoming response.
*/
pub async fn from_incoming(
incoming: HyperResponse<Incoming>,
decompress: bool,
) -> LuaResult<Self> {
let (parts, body) = incoming.into_parts();
let (body, decompressed) = handle_incoming_body(&parts.headers, body, decompress).await?;
Ok(Self {
inner: HyperResponse::from_parts(parts, ReadableBody::from(body)),
decompressed,
})
}
/**
Returns whether the request was successful or not.
*/
pub fn status_ok(&self) -> bool {
self.inner.status().is_success()
}
/**
Returns the status code of the response.
*/
pub fn status_code(&self) -> u16 {
self.inner.status().as_u16()
}
/**
Returns the status message of the response.
*/
pub fn status_message(&self) -> &str {
self.inner.status().canonical_reason().unwrap_or_default()
}
/**
Returns the headers of the response.
*/
pub fn headers(&self) -> &HeaderMap {
self.inner.headers()
}
/**
Returns the body of the response.
*/
pub fn body(&self) -> &[u8] {
self.inner.body().as_slice()
}
/**
Clones the inner `hyper` response.
*/
#[allow(dead_code)]
pub fn clone_inner(&self) -> HyperResponse<ReadableBody> {
self.inner.clone()
}
/**
Takes the inner `hyper` response by ownership.
*/
#[allow(dead_code)]
pub fn into_inner(self) -> HyperResponse<ReadableBody> {
self.inner
}
}
impl FromLua for Response {
fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult<Self> {
if let Ok(body) = ReadableBody::from_lua(value.clone(), lua) {
// String or buffer is always a 200 text/plain response
let mut response = HyperResponse::new(body);
response
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("text/plain"));
Ok(Self {
inner: response,
decompressed: false,
})
} else if let LuaValue::Table(tab) = value {
// Extract status (required)
let status = tab.get::<u16>("status")?;
let status = StatusCode::from_u16(status).into_lua_err()?;
// Extract headers
let headers = tab.get::<Option<LuaTable>>("headers")?;
let headers = headers
.map(|t| lua_table_to_header_map(&t))
.transpose()?
.unwrap_or_default();
// Extract body
let body = tab.get::<ReadableBody>("body")?;
// Build the full response
let mut response = HyperResponse::new(body);
response.headers_mut().extend(headers);
*response.status_mut() = status;
// All good, validated and we got what we need
Ok(Self {
inner: response,
decompressed: false,
})
} else {
// Anything else is invalid
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "Response".to_string(),
message: Some(format!(
"Invalid response - expected table/string/buffer, got {}",
value.type_name()
)),
})
}
}
}
impl LuaUserData for Response {
fn add_fields<F: LuaUserDataFields<Self>>(fields: &mut F) {
fields.add_field_method_get("ok", |_, this| Ok(this.status_ok()));
fields.add_field_method_get("statusCode", |_, this| Ok(this.status_code()));
fields.add_field_method_get("statusMessage", |lua, this| {
lua.create_string(this.status_message())
});
fields.add_field_method_get("headers", |lua, this| {
header_map_to_table(lua, this.headers().clone(), this.decompressed)
});
fields.add_field_method_get("body", |lua, this| lua.create_string(this.body()));
}
}

View file

@ -1,62 +1,38 @@
use std::sync::{ use std::{
atomic::{AtomicBool, AtomicU16, Ordering}, error::Error,
Arc, sync::{
atomic::{AtomicBool, AtomicU16, Ordering},
Arc,
},
}; };
use async_lock::Mutex as AsyncMutex;
use async_tungstenite::tungstenite::{
protocol::{frame::coding::CloseCode, CloseFrame},
Message as TungsteniteMessage, Result as TungsteniteResult, Utf8Bytes,
};
use bstr::{BString, ByteSlice}; use bstr::{BString, ByteSlice};
use futures::{
stream::{SplitSink, SplitStream},
Sink, SinkExt, Stream, StreamExt,
};
use hyper::body::Bytes;
use mlua::prelude::*; use mlua::prelude::*;
use futures_util::{ #[derive(Debug, Clone)]
stream::{SplitSink, SplitStream}, pub struct Websocket<T> {
SinkExt, StreamExt,
};
use tokio::{
io::{AsyncRead, AsyncWrite},
sync::Mutex as AsyncMutex,
};
use hyper_tungstenite::{
tungstenite::{
protocol::{frame::coding::CloseCode as WsCloseCode, CloseFrame as WsCloseFrame},
Message as WsMessage,
},
WebSocketStream,
};
#[derive(Debug)]
pub struct NetWebSocket<T> {
close_code_exists: Arc<AtomicBool>, close_code_exists: Arc<AtomicBool>,
close_code_value: Arc<AtomicU16>, close_code_value: Arc<AtomicU16>,
read_stream: Arc<AsyncMutex<SplitStream<WebSocketStream<T>>>>, read_stream: Arc<AsyncMutex<SplitStream<T>>>,
write_stream: Arc<AsyncMutex<SplitSink<WebSocketStream<T>, WsMessage>>>, write_stream: Arc<AsyncMutex<SplitSink<T, TungsteniteMessage>>>,
} }
impl<T> Clone for NetWebSocket<T> { impl<T> Websocket<T>
fn clone(&self) -> Self {
Self {
close_code_exists: Arc::clone(&self.close_code_exists),
close_code_value: Arc::clone(&self.close_code_value),
read_stream: Arc::clone(&self.read_stream),
write_stream: Arc::clone(&self.write_stream),
}
}
}
impl<T> NetWebSocket<T>
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: Stream<Item = TungsteniteResult<TungsteniteMessage>> + Sink<TungsteniteMessage> + 'static,
<T as Sink<TungsteniteMessage>>::Error: Into<Box<dyn Error + Send + Sync + 'static>>,
{ {
pub fn new(value: WebSocketStream<T>) -> Self {
let (write, read) = value.split();
Self {
close_code_exists: Arc::new(AtomicBool::new(false)),
close_code_value: Arc::new(AtomicU16::new(0)),
read_stream: Arc::new(AsyncMutex::new(read)),
write_stream: Arc::new(AsyncMutex::new(write)),
}
}
fn get_close_code(&self) -> Option<u16> { fn get_close_code(&self) -> Option<u16> {
if self.close_code_exists.load(Ordering::Relaxed) { if self.close_code_exists.load(Ordering::Relaxed) {
Some(self.close_code_value.load(Ordering::Relaxed)) Some(self.close_code_value.load(Ordering::Relaxed))
@ -70,12 +46,12 @@ where
self.close_code_value.store(code, Ordering::Relaxed); self.close_code_value.store(code, Ordering::Relaxed);
} }
pub async fn send(&self, msg: WsMessage) -> LuaResult<()> { pub async fn send(&self, msg: TungsteniteMessage) -> LuaResult<()> {
let mut ws = self.write_stream.lock().await; let mut ws = self.write_stream.lock().await;
ws.send(msg).await.into_lua_err() ws.send(msg).await.into_lua_err()
} }
pub async fn next(&self) -> LuaResult<Option<WsMessage>> { pub async fn next(&self) -> LuaResult<Option<TungsteniteMessage>> {
let mut ws = self.read_stream.lock().await; let mut ws = self.read_stream.lock().await;
ws.next().await.transpose().into_lua_err() ws.next().await.transpose().into_lua_err()
} }
@ -85,15 +61,15 @@ where
return Err(LuaError::runtime("Socket has already been closed")); return Err(LuaError::runtime("Socket has already been closed"));
} }
self.send(WsMessage::Close(Some(WsCloseFrame { self.send(TungsteniteMessage::Close(Some(CloseFrame {
code: match code { code: match code {
Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code), Some(code) if (1000..=4999).contains(&code) => CloseCode::from(code),
Some(code) => { Some(code) => {
return Err(LuaError::runtime(format!( return Err(LuaError::runtime(format!(
"Close code must be between 1000 and 4999, got {code}" "Close code must be between 1000 and 4999, got {code}"
))) )))
} }
None => WsCloseCode::Normal, None => CloseCode::Normal,
}, },
reason: "".into(), reason: "".into(),
}))) })))
@ -104,9 +80,27 @@ where
} }
} }
impl<T> LuaUserData for NetWebSocket<T> impl<T> From<T> for Websocket<T>
where where
T: AsyncRead + AsyncWrite + Unpin + 'static, T: Stream<Item = TungsteniteResult<TungsteniteMessage>> + Sink<TungsteniteMessage> + 'static,
<T as Sink<TungsteniteMessage>>::Error: Into<Box<dyn Error + Send + Sync + 'static>>,
{
fn from(value: T) -> Self {
let (write, read) = value.split();
Self {
close_code_exists: Arc::new(AtomicBool::new(false)),
close_code_value: Arc::new(AtomicU16::new(0)),
read_stream: Arc::new(AsyncMutex::new(read)),
write_stream: Arc::new(AsyncMutex::new(write)),
}
}
}
impl<T> LuaUserData for Websocket<T>
where
T: Stream<Item = TungsteniteResult<TungsteniteMessage>> + Sink<TungsteniteMessage> + 'static,
<T as Sink<TungsteniteMessage>>::Error: Into<Box<dyn Error + Send + Sync + 'static>>,
{ {
fn add_fields<F: LuaUserDataFields<Self>>(fields: &mut F) { fn add_fields<F: LuaUserDataFields<Self>>(fields: &mut F) {
fields.add_field_method_get("closeCode", |_, this| Ok(this.get_close_code())); fields.add_field_method_get("closeCode", |_, this| Ok(this.get_close_code()));
@ -121,10 +115,10 @@ where
"send", "send",
|_, this, (string, as_binary): (BString, Option<bool>)| async move { |_, this, (string, as_binary): (BString, Option<bool>)| async move {
this.send(if as_binary.unwrap_or_default() { this.send(if as_binary.unwrap_or_default() {
WsMessage::Binary(string.as_bytes().to_vec()) TungsteniteMessage::Binary(Bytes::from(string.to_vec()))
} else { } else {
let s = string.to_str().into_lua_err()?; let s = string.to_str().into_lua_err()?;
WsMessage::Text(s.to_string()) TungsteniteMessage::Text(Utf8Bytes::from(s))
}) })
.await .await
}, },
@ -133,14 +127,14 @@ where
methods.add_async_method("next", |lua, this, (): ()| async move { methods.add_async_method("next", |lua, this, (): ()| async move {
let msg = this.next().await?; let msg = this.next().await?;
if let Some(WsMessage::Close(Some(frame))) = msg.as_ref() { if let Some(TungsteniteMessage::Close(Some(frame))) = msg.as_ref() {
this.set_close_code(frame.code.into()); this.set_close_code(frame.code.into());
} }
Ok(match msg { Ok(match msg {
Some(WsMessage::Binary(bin)) => LuaValue::String(lua.create_string(bin)?), Some(TungsteniteMessage::Binary(bin)) => LuaValue::String(lua.create_string(bin)?),
Some(WsMessage::Text(txt)) => LuaValue::String(lua.create_string(txt)?), Some(TungsteniteMessage::Text(txt)) => LuaValue::String(lua.create_string(txt)?),
Some(WsMessage::Close(_)) | None => LuaValue::Nil, Some(TungsteniteMessage::Close(_)) | None => LuaValue::Nil,
// Ignore ping/pong/frame messages, they are handled by tungstenite // Ignore ping/pong/frame messages, they are handled by tungstenite
msg => unreachable!("Unhandled message: {:?}", msg), msg => unreachable!("Unhandled message: {:?}", msg),
}) })

View file

@ -0,0 +1,12 @@
use mlua::prelude::*;
pub fn decode(lua_string: LuaString, as_binary: bool) -> LuaResult<Vec<u8>> {
if as_binary {
Ok(urlencoding::decode_binary(&lua_string.as_bytes()).into_owned())
} else {
Ok(urlencoding::decode(&lua_string.to_str()?)
.map_err(|e| LuaError::RuntimeError(format!("Encountered invalid encoding - {e}")))?
.into_owned()
.into_bytes())
}
}

View file

@ -0,0 +1,13 @@
use mlua::prelude::*;
pub fn encode(lua_string: LuaString, as_binary: bool) -> LuaResult<Vec<u8>> {
if as_binary {
Ok(urlencoding::encode_binary(&lua_string.as_bytes())
.into_owned()
.into_bytes())
} else {
Ok(urlencoding::encode(&lua_string.to_str()?)
.into_owned()
.into_bytes())
}
}

View file

@ -0,0 +1,5 @@
mod decode;
mod encode;
pub use self::decode::decode;
pub use self::encode::encode;

View file

@ -1,94 +0,0 @@
use std::collections::HashMap;
use hyper::header::{CONTENT_ENCODING, CONTENT_LENGTH};
use reqwest::header::HeaderMap;
use mlua::prelude::*;
use lune_utils::TableBuilder;
pub fn create_user_agent_header(lua: &Lua) -> LuaResult<String> {
let version_global = lua
.globals()
.get::<LuaString>("_VERSION")
.expect("Missing _VERSION global");
let version_global_str = version_global
.to_str()
.context("Invalid utf8 found in _VERSION global")?;
let (package_name, full_version) = version_global_str.split_once(' ').unwrap();
Ok(format!("{}/{}", package_name.to_lowercase(), full_version))
}
pub fn header_map_to_table(
lua: &Lua,
headers: HeaderMap,
remove_content_headers: bool,
) -> LuaResult<LuaTable> {
let mut res_headers: HashMap<String, Vec<String>> = HashMap::new();
for (name, value) in &headers {
let name = name.as_str();
let value = value.to_str().unwrap().to_owned();
if let Some(existing) = res_headers.get_mut(name) {
existing.push(value);
} else {
res_headers.insert(name.to_owned(), vec![value]);
}
}
if remove_content_headers {
let content_encoding_header_str = CONTENT_ENCODING.as_str();
let content_length_header_str = CONTENT_LENGTH.as_str();
res_headers.retain(|name, _| {
name != content_encoding_header_str && name != content_length_header_str
});
}
let mut builder = TableBuilder::new(lua.clone())?;
for (name, mut values) in res_headers {
if values.len() == 1 {
let value = values.pop().unwrap().into_lua(lua)?;
builder = builder.with_value(name, value)?;
} else {
let values = TableBuilder::new(lua.clone())?
.with_sequential_values(values)?
.build_readonly()?
.into_lua(lua)?;
builder = builder.with_value(name, values)?;
}
}
builder.build_readonly()
}
pub fn table_to_hash_map(
tab: LuaTable,
tab_origin_key: &'static str,
) -> LuaResult<HashMap<String, Vec<String>>> {
let mut map = HashMap::new();
for pair in tab.pairs::<String, LuaValue>() {
let (key, value) = pair?;
match value {
LuaValue::String(s) => {
map.insert(key, vec![s.to_str()?.to_owned()]);
}
LuaValue::Table(t) => {
let mut values = Vec::new();
for value in t.sequence_values::<LuaString>() {
values.push(value?.to_str()?.to_owned());
}
map.insert(key, values);
}
_ => {
return Err(LuaError::runtime(format!(
"Value for '{tab_origin_key}' must be a string or array of strings",
)))
}
}
}
Ok(map)
}

View file

@ -1,6 +1,6 @@
[package] [package]
name = "lune-std-process" name = "lune-std-process"
version = "0.2.0" version = "0.2.2"
edition = "2021" edition = "2021"
license = "MPL-2.0" license = "MPL-2.0"
repository = "https://github.com/lune-org/lune" repository = "https://github.com/lune-org/lune"
@ -14,11 +14,10 @@ workspace = true
[dependencies] [dependencies]
mlua = { version = "0.10.3", features = ["luau"] } mlua = { version = "0.10.3", features = ["luau"] }
mlua-luau-scheduler = { version = "0.1.0", path = "../mlua-luau-scheduler" } mlua-luau-scheduler = { version = "0.1.2", path = "../mlua-luau-scheduler" }
directories = "6.0" directories = "6.0"
pin-project = "1.0" pin-project = "1.0"
os_str_bytes = { version = "7.0", features = ["conversions"] }
bstr = "1.9" bstr = "1.9"
bytes = "1.6.0" bytes = "1.6.0"
@ -30,4 +29,4 @@ blocking = "1.6"
futures-lite = "2.6" futures-lite = "2.6"
futures-util = "0.3" # Needed for select! macro... futures-util = "0.3" # Needed for select! macro...
lune-utils = { version = "0.2.0", path = "../lune-utils" } lune-utils = { version = "0.2.2", path = "../lune-utils" }

View file

@ -1,10 +1,7 @@
#![allow(clippy::cargo_common_metadata)] #![allow(clippy::cargo_common_metadata)]
use std::{ use std::{
env::{ env::consts::{ARCH, OS},
self,
consts::{ARCH, OS},
},
path::MAIN_SEPARATOR, path::MAIN_SEPARATOR,
process::Stdio, process::Stdio,
}; };
@ -12,9 +9,11 @@ use std::{
use mlua::prelude::*; use mlua::prelude::*;
use mlua_luau_scheduler::Functions; use mlua_luau_scheduler::Functions;
use os_str_bytes::RawOsString; use lune_utils::{
path::get_current_dir,
use lune_utils::{path::get_current_dir, TableBuilder}; process::{ProcessArgs, ProcessEnv},
TableBuilder,
};
mod create; mod create;
mod exec; mod exec;
@ -58,25 +57,15 @@ pub fn module(lua: Lua) -> LuaResult<LuaTable> {
"little" "little"
})?; })?;
// Create readonly args array // Extract stored userdatas for args + env, the runtime struct should always provide this
let args_vec = lua let process_args = lua
.app_data_ref::<Vec<String>>() .app_data_ref::<ProcessArgs>()
.ok_or_else(|| LuaError::runtime("Missing args vec in Lua app data"))? .ok_or_else(|| LuaError::runtime("Missing process args in Lua app data"))?
.clone();
let process_env = lua
.app_data_ref::<ProcessEnv>()
.ok_or_else(|| LuaError::runtime("Missing process env in Lua app data"))?
.clone(); .clone();
let args_tab = TableBuilder::new(lua.clone())?
.with_sequential_values(args_vec)?
.build_readonly()?;
// Create proxied table for env that gets & sets real env vars
let env_tab = TableBuilder::new(lua.clone())?
.with_metatable(
TableBuilder::new(lua.clone())?
.with_function(LuaMetaMethod::Index.name(), process_env_get)?
.with_function(LuaMetaMethod::NewIndex.name(), process_env_set)?
.with_function(LuaMetaMethod::Iter.name(), process_env_iter)?
.build_readonly()?,
)?
.build_readonly()?;
// Create our process exit function, the scheduler crate provides this // Create our process exit function, the scheduler crate provides this
let fns = Functions::new(lua.clone())?; let fns = Functions::new(lua.clone())?;
@ -87,73 +76,18 @@ pub fn module(lua: Lua) -> LuaResult<LuaTable> {
.with_value("os", os)? .with_value("os", os)?
.with_value("arch", arch)? .with_value("arch", arch)?
.with_value("endianness", endianness)? .with_value("endianness", endianness)?
.with_value("args", args_tab)? .with_value("args", process_args)?
.with_value("cwd", cwd_str)? .with_value("cwd", cwd_str)?
.with_value("env", env_tab)? .with_value("env", process_env)?
.with_value("exit", process_exit)? .with_value("exit", process_exit)?
.with_async_function("exec", process_exec)? .with_async_function("exec", process_exec)?
.with_function("create", process_create)? .with_function("create", process_create)?
.build_readonly() .build_readonly()
} }
fn process_env_get(lua: &Lua, (_, key): (LuaValue, String)) -> LuaResult<LuaValue> {
match env::var_os(key) {
Some(value) => {
let raw_value = RawOsString::new(value);
Ok(LuaValue::String(
lua.create_string(raw_value.to_raw_bytes())?,
))
}
None => Ok(LuaValue::Nil),
}
}
fn process_env_set(_: &Lua, (_, key, value): (LuaValue, String, Option<String>)) -> LuaResult<()> {
// Make sure key is valid, otherwise set_var will panic
if key.is_empty() {
Err(LuaError::RuntimeError("Key must not be empty".to_string()))
} else if key.contains('=') {
Err(LuaError::RuntimeError(
"Key must not contain the equals character '='".to_string(),
))
} else if key.contains('\0') {
Err(LuaError::RuntimeError(
"Key must not contain the NUL character".to_string(),
))
} else if let Some(value) = value {
// Make sure value is valid, otherwise set_var will panic
if value.contains('\0') {
Err(LuaError::RuntimeError(
"Value must not contain the NUL character".to_string(),
))
} else {
env::set_var(&key, &value);
Ok(())
}
} else {
env::remove_var(&key);
Ok(())
}
}
fn process_env_iter(lua: &Lua, (_, ()): (LuaValue, ())) -> LuaResult<LuaFunction> {
let mut vars = env::vars_os().collect::<Vec<_>>().into_iter();
lua.create_function_mut(move |lua, (): ()| match vars.next() {
Some((key, value)) => {
let raw_key = RawOsString::new(key);
let raw_value = RawOsString::new(value);
Ok((
LuaValue::String(lua.create_string(raw_key.to_raw_bytes())?),
LuaValue::String(lua.create_string(raw_value.to_raw_bytes())?),
))
}
None => Ok((LuaValue::Nil, LuaValue::Nil)),
})
}
async fn process_exec( async fn process_exec(
lua: Lua, lua: Lua,
(program, args, mut options): (String, Option<Vec<String>>, ProcessSpawnOptions), (program, args, mut options): (String, ProcessArgs, ProcessSpawnOptions),
) -> LuaResult<LuaTable> { ) -> LuaResult<LuaTable> {
let stdin = options.stdio.stdin.take(); let stdin = options.stdio.stdin.take();
let stdout = options.stdio.stdout; let stdout = options.stdio.stdout;
@ -171,7 +105,7 @@ async fn process_exec(
fn process_create( fn process_create(
lua: &Lua, lua: &Lua,
(program, args, options): (String, Option<Vec<String>>, ProcessSpawnOptions), (program, args, options): (String, ProcessArgs, ProcessSpawnOptions),
) -> LuaResult<LuaValue> { ) -> LuaResult<LuaValue> {
let child = options let child = options
.into_command(program, args) .into_command(program, args)

View file

@ -1,9 +1,11 @@
use std::{ use std::{
collections::HashMap, collections::HashMap,
env::{self}, env::{self},
ffi::OsString,
path::PathBuf, path::PathBuf,
}; };
use lune_utils::process::ProcessArgs;
use mlua::prelude::*; use mlua::prelude::*;
use async_process::Command; use async_process::Command;
@ -129,31 +131,24 @@ impl FromLua for ProcessSpawnOptions {
} }
impl ProcessSpawnOptions { impl ProcessSpawnOptions {
pub fn into_command(self, program: impl Into<String>, args: Option<Vec<String>>) -> Command { pub fn into_command(self, program: impl Into<OsString>, args: ProcessArgs) -> Command {
let mut program = program.into(); let mut program: OsString = program.into();
let mut args = args.into_iter().collect::<Vec<_>>();
// Run a shell using the command param if wanted // Run a shell using the command param if wanted
let pargs = match self.shell { if let Some(shell) = self.shell {
None => args, let mut shell_command = program.clone();
Some(shell) => { for arg in args {
let shell_args = match args { shell_command.push(" ");
Some(args) => vec!["-c".to_string(), format!("{} {}", program, args.join(" "))], shell_command.push(arg);
None => vec!["-c".to_string(), program.to_string()],
};
program = shell.to_string();
Some(shell_args)
} }
}; args = vec![OsString::from("-c"), shell_command];
program = shell.into();
}
// Create command with the wanted options // Create command with the wanted options
let mut cmd = match pargs { let mut cmd = Command::new(program);
None => Command::new(program), cmd.args(args);
Some(args) => {
let mut cmd = Command::new(program);
cmd.args(args);
cmd
}
};
// Set dir to run in and env variables // Set dir to run in and env variables
if let Some(cwd) = self.cwd { if let Some(cwd) = self.cwd {

View file

@ -1,6 +1,6 @@
[package] [package]
name = "lune-std-regex" name = "lune-std-regex"
version = "0.2.0" version = "0.2.2"
edition = "2021" edition = "2021"
license = "MPL-2.0" license = "MPL-2.0"
repository = "https://github.com/lune-org/lune" repository = "https://github.com/lune-org/lune"
@ -18,4 +18,4 @@ mlua = { version = "0.10.3", features = ["luau"] }
regex = "1.10" regex = "1.10"
self_cell = "1.0" self_cell = "1.0"
lune-utils = { version = "0.2.0", path = "../lune-utils" } lune-utils = { version = "0.2.2", path = "../lune-utils" }

View file

@ -1,6 +1,6 @@
[package] [package]
name = "lune-std-roblox" name = "lune-std-roblox"
version = "0.2.0" version = "0.2.2"
edition = "2021" edition = "2021"
license = "MPL-2.0" license = "MPL-2.0"
repository = "https://github.com/lune-org/lune" repository = "https://github.com/lune-org/lune"
@ -14,10 +14,10 @@ workspace = true
[dependencies] [dependencies]
mlua = { version = "0.10.3", features = ["luau"] } mlua = { version = "0.10.3", features = ["luau"] }
mlua-luau-scheduler = { version = "0.1.0", path = "../mlua-luau-scheduler" } mlua-luau-scheduler = { version = "0.1.2", path = "../mlua-luau-scheduler" }
rbx_cookie = { version = "0.1.4", default-features = false } rbx_cookie = { version = "0.1.4", default-features = false }
roblox_install = "1.0" roblox_install = "1.0"
lune-utils = { version = "0.2.0", path = "../lune-utils" } lune-utils = { version = "0.2.2", path = "../lune-utils" }
lune-roblox = { version = "0.2.0", path = "../lune-roblox" } lune-roblox = { version = "0.2.2", path = "../lune-roblox" }

View file

@ -1,6 +1,6 @@
[package] [package]
name = "lune-std-serde" name = "lune-std-serde"
version = "0.2.0" version = "0.2.2"
edition = "2021" edition = "2021"
license = "MPL-2.0" license = "MPL-2.0"
repository = "https://github.com/lune-org/lune" repository = "https://github.com/lune-org/lune"
@ -42,4 +42,4 @@ sha3 = "0.10.8"
# Check before updating it. # Check before updating it.
blake3 = { version = "=1.5.0", features = ["traits-preview"] } blake3 = { version = "=1.5.0", features = ["traits-preview"] }
lune-utils = { version = "0.2.0", path = "../lune-utils" } lune-utils = { version = "0.2.2", path = "../lune-utils" }

View file

@ -1,6 +1,6 @@
[package] [package]
name = "lune-std-stdio" name = "lune-std-stdio"
version = "0.2.0" version = "0.2.2"
edition = "2021" edition = "2021"
license = "MPL-2.0" license = "MPL-2.0"
repository = "https://github.com/lune-org/lune" repository = "https://github.com/lune-org/lune"
@ -14,7 +14,7 @@ workspace = true
[dependencies] [dependencies]
mlua = { version = "0.10.3", features = ["luau", "error-send"] } mlua = { version = "0.10.3", features = ["luau", "error-send"] }
mlua-luau-scheduler = { version = "0.1.0", path = "../mlua-luau-scheduler" } mlua-luau-scheduler = { version = "0.1.2", path = "../mlua-luau-scheduler" }
async-io = "2.4" async-io = "2.4"
async-lock = "3.4" async-lock = "3.4"
@ -22,4 +22,4 @@ blocking = "1.6"
dialoguer = "0.11" dialoguer = "0.11"
futures-lite = "2.6" futures-lite = "2.6"
lune-utils = { version = "0.2.0", path = "../lune-utils" } lune-utils = { version = "0.2.2", path = "../lune-utils" }

View file

@ -1,6 +1,6 @@
[package] [package]
name = "lune-std-task" name = "lune-std-task"
version = "0.2.0" version = "0.2.2"
edition = "2021" edition = "2021"
license = "MPL-2.0" license = "MPL-2.0"
repository = "https://github.com/lune-org/lune" repository = "https://github.com/lune-org/lune"
@ -14,9 +14,9 @@ workspace = true
[dependencies] [dependencies]
mlua = { version = "0.10.3", features = ["luau"] } mlua = { version = "0.10.3", features = ["luau"] }
mlua-luau-scheduler = { version = "0.1.0", path = "../mlua-luau-scheduler" } mlua-luau-scheduler = { version = "0.1.2", path = "../mlua-luau-scheduler" }
async-io = "2.4" async-io = "2.4"
futures-lite = "2.6" futures-lite = "2.6"
lune-utils = { version = "0.2.0", path = "../lune-utils" } lune-utils = { version = "0.2.2", path = "../lune-utils" }

View file

@ -1,6 +1,6 @@
[package] [package]
name = "lune-std" name = "lune-std"
version = "0.2.0" version = "0.2.2"
edition = "2021" edition = "2021"
license = "MPL-2.0" license = "MPL-2.0"
repository = "https://github.com/lune-org/lune" repository = "https://github.com/lune-org/lune"
@ -39,7 +39,7 @@ task = ["dep:lune-std-task"]
[dependencies] [dependencies]
mlua = { version = "0.10.3", features = ["luau"] } mlua = { version = "0.10.3", features = ["luau"] }
mlua-luau-scheduler = { version = "0.1.0", path = "../mlua-luau-scheduler" } mlua-luau-scheduler = { version = "0.1.2", path = "../mlua-luau-scheduler" }
async-channel = "2.3" async-channel = "2.3"
async-fs = "2.1" async-fs = "2.1"
@ -48,15 +48,15 @@ async-lock = "3.4"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0" serde_json = "1.0"
lune-utils = { version = "0.2.0", path = "../lune-utils" } lune-utils = { version = "0.2.2", path = "../lune-utils" }
lune-std-datetime = { optional = true, version = "0.2.0", path = "../lune-std-datetime" } lune-std-datetime = { optional = true, version = "0.2.2", path = "../lune-std-datetime" }
lune-std-fs = { optional = true, version = "0.2.0", path = "../lune-std-fs" } lune-std-fs = { optional = true, version = "0.2.2", path = "../lune-std-fs" }
lune-std-luau = { optional = true, version = "0.2.0", path = "../lune-std-luau" } lune-std-luau = { optional = true, version = "0.2.2", path = "../lune-std-luau" }
lune-std-net = { optional = true, version = "0.2.0", path = "../lune-std-net" } lune-std-net = { optional = true, version = "0.2.2", path = "../lune-std-net" }
lune-std-process = { optional = true, version = "0.2.0", path = "../lune-std-process" } lune-std-process = { optional = true, version = "0.2.2", path = "../lune-std-process" }
lune-std-regex = { optional = true, version = "0.2.0", path = "../lune-std-regex" } lune-std-regex = { optional = true, version = "0.2.2", path = "../lune-std-regex" }
lune-std-roblox = { optional = true, version = "0.2.0", path = "../lune-std-roblox" } lune-std-roblox = { optional = true, version = "0.2.2", path = "../lune-std-roblox" }
lune-std-serde = { optional = true, version = "0.2.0", path = "../lune-std-serde" } lune-std-serde = { optional = true, version = "0.2.2", path = "../lune-std-serde" }
lune-std-stdio = { optional = true, version = "0.2.0", path = "../lune-std-stdio" } lune-std-stdio = { optional = true, version = "0.2.2", path = "../lune-std-stdio" }
lune-std-task = { optional = true, version = "0.2.0", path = "../lune-std-task" } lune-std-task = { optional = true, version = "0.2.2", path = "../lune-std-task" }

View file

@ -35,7 +35,18 @@ pub fn create(lua: Lua) -> LuaResult<LuaValue> {
3. The lua chunk we are require-ing from 3. The lua chunk we are require-ing from
*/ */
let require_fn = lua.create_async_function(require)?; let require_fn = lua.create_async_function(|lua, (source, path)| {
// NOTE: We need to make sure that the app data reference does not
// live through the entire require call, to prevent panicking from
// being unable to borrow other app data in the main body of scripts
let context = {
let context = lua
.app_data_ref::<RequireContext>()
.expect("Failed to get RequireContext from app data");
context.clone()
};
require(lua, context, source, path)
})?;
let get_source_fn = lua.create_function(move |lua, (): ()| match lua.inspect_stack(2) { let get_source_fn = lua.create_function(move |lua, (): ()| match lua.inspect_stack(2) {
None => Err(LuaError::runtime( None => Err(LuaError::runtime(
"Failed to get stack info for require source", "Failed to get stack info for require source",
@ -60,7 +71,12 @@ pub fn create(lua: Lua) -> LuaResult<LuaValue> {
.into_lua(&lua) .into_lua(&lua)
} }
async fn require(lua: Lua, (source, path): (LuaString, LuaString)) -> LuaResult<LuaMultiValue> { async fn require(
lua: Lua,
context: RequireContext,
source: LuaString,
path: LuaString,
) -> LuaResult<LuaMultiValue> {
let source = source let source = source
.to_str() .to_str()
.into_lua_err() .into_lua_err()
@ -73,11 +89,6 @@ async fn require(lua: Lua, (source, path): (LuaString, LuaString)) -> LuaResult<
.context("Failed to parse require path as string")? .context("Failed to parse require path as string")?
.to_string(); .to_string();
let context = lua
.app_data_ref::<RequireContext>()
.expect("Failed to get RequireContext from app data")
.clone();
if let Some(builtin_name) = path.strip_prefix("@lune/").map(str::to_ascii_lowercase) { if let Some(builtin_name) = path.strip_prefix("@lune/").map(str::to_ascii_lowercase) {
library::require(lua, &context, &builtin_name) library::require(lua, &context, &builtin_name)
} else if let Some(self_path) = path.strip_prefix("@self/") { } else if let Some(self_path) = path.strip_prefix("@self/") {

View file

@ -1,6 +1,6 @@
[package] [package]
name = "lune-utils" name = "lune-utils"
version = "0.2.0" version = "0.2.2"
edition = "2021" edition = "2021"
license = "MPL-2.0" license = "MPL-2.0"
repository = "https://github.com/lune-org/lune" repository = "https://github.com/lune-org/lune"
@ -17,6 +17,7 @@ mlua = { version = "0.10.3", features = ["luau", "async"] }
console = "0.15" console = "0.15"
dunce = "1.0" dunce = "1.0"
os_str_bytes = { version = "7.0", features = ["conversions"] }
path-clean = "1.0" path-clean = "1.0"
pathdiff = "0.2" pathdiff = "0.2"
parking_lot = "0.12.3" parking_lot = "0.12.3"

View file

@ -4,8 +4,13 @@ mod table_builder;
mod version_string; mod version_string;
pub mod fmt; pub mod fmt;
pub mod jit;
pub mod path; pub mod path;
pub mod process;
pub use self::table_builder::TableBuilder; pub use self::table_builder::TableBuilder;
pub use self::version_string::get_version_string; pub use self::version_string::get_version_string;
// TODO: Remove this in the next major semver
pub mod jit {
pub use super::process::ProcessJitEnablement as JitEnablement;
}

View file

@ -0,0 +1,252 @@
#![allow(clippy::missing_panics_doc)]
use std::{
env::args_os,
ffi::OsString,
sync::{Arc, Mutex},
};
use mlua::prelude::*;
use os_str_bytes::OsStringBytes;
// Inner (shared) struct
#[derive(Debug, Default)]
struct ProcessArgsInner {
values: Vec<OsString>,
}
impl FromIterator<OsString> for ProcessArgsInner {
fn from_iter<T: IntoIterator<Item = OsString>>(iter: T) -> Self {
Self {
values: iter.into_iter().collect(),
}
}
}
/**
A struct that can be easily shared, stored in Lua app data,
and that also guarantees the values are valid OS strings
that can be used for process arguments.
Usable directly from Lua, implementing both `FromLua` and `LuaUserData`.
Also provides convenience methods for working with the arguments
as either `OsString` or `Vec<u8>`, where using the latter implicitly
converts to an `OsString` and fails if the conversion is not possible.
*/
#[derive(Debug, Clone)]
pub struct ProcessArgs {
inner: Arc<Mutex<ProcessArgsInner>>,
}
impl ProcessArgs {
#[must_use]
pub fn empty() -> Self {
Self {
inner: Arc::new(Mutex::new(ProcessArgsInner::default())),
}
}
#[must_use]
pub fn current() -> Self {
Self {
inner: Arc::new(Mutex::new(args_os().collect())),
}
}
#[must_use]
pub fn len(&self) -> usize {
let inner = self.inner.lock().unwrap();
inner.values.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
let inner = self.inner.lock().unwrap();
inner.values.is_empty()
}
// OS strings
#[must_use]
pub fn all(&self) -> Vec<OsString> {
let inner = self.inner.lock().unwrap();
inner.values.clone()
}
#[must_use]
pub fn get(&self, index: usize) -> Option<OsString> {
let inner = self.inner.lock().unwrap();
inner.values.get(index).cloned()
}
pub fn set(&self, index: usize, val: impl Into<OsString>) {
let mut inner = self.inner.lock().unwrap();
if let Some(arg) = inner.values.get_mut(index) {
*arg = val.into();
}
}
pub fn push(&self, val: impl Into<OsString>) {
let mut inner = self.inner.lock().unwrap();
inner.values.push(val.into());
}
#[must_use]
pub fn pop(&self) -> Option<OsString> {
let mut inner = self.inner.lock().unwrap();
inner.values.pop()
}
pub fn insert(&self, index: usize, val: impl Into<OsString>) {
let mut inner = self.inner.lock().unwrap();
if index <= inner.values.len() {
inner.values.insert(index, val.into());
}
}
#[must_use]
pub fn remove(&self, index: usize) -> Option<OsString> {
let mut inner = self.inner.lock().unwrap();
if index < inner.values.len() {
Some(inner.values.remove(index))
} else {
None
}
}
// Bytes wrappers
#[must_use]
pub fn all_bytes(&self) -> Vec<Vec<u8>> {
self.all()
.into_iter()
.filter_map(OsString::into_io_vec)
.collect()
}
#[must_use]
pub fn get_bytes(&self, index: usize) -> Option<Vec<u8>> {
let val = self.get(index)?;
val.into_io_vec()
}
pub fn set_bytes(&self, index: usize, val: impl Into<Vec<u8>>) {
if let Some(val_os) = OsString::from_io_vec(val.into()) {
self.set(index, val_os);
}
}
pub fn push_bytes(&self, val: impl Into<Vec<u8>>) {
if let Some(val_os) = OsString::from_io_vec(val.into()) {
self.push(val_os);
}
}
#[must_use]
pub fn pop_bytes(&self) -> Option<Vec<u8>> {
self.pop().and_then(OsString::into_io_vec)
}
pub fn insert_bytes(&self, index: usize, val: impl Into<Vec<u8>>) {
if let Some(val_os) = OsString::from_io_vec(val.into()) {
self.insert(index, val_os);
}
}
pub fn remove_bytes(&self, index: usize) -> Option<Vec<u8>> {
self.remove(index).and_then(OsString::into_io_vec)
}
}
// Iterator implementations
impl IntoIterator for ProcessArgs {
type Item = OsString;
type IntoIter = std::vec::IntoIter<OsString>;
fn into_iter(self) -> Self::IntoIter {
let inner = self.inner.lock().unwrap();
inner.values.clone().into_iter()
}
}
impl<S: Into<OsString>> FromIterator<S> for ProcessArgs {
fn from_iter<T: IntoIterator<Item = S>>(iter: T) -> Self {
Self {
inner: Arc::new(Mutex::new(iter.into_iter().map(Into::into).collect())),
}
}
}
impl<S: Into<OsString>> Extend<S> for ProcessArgs {
fn extend<T: IntoIterator<Item = S>>(&mut self, iter: T) {
let mut inner = self.inner.lock().unwrap();
inner.values.extend(iter.into_iter().map(Into::into));
}
}
// Lua implementations
impl FromLua for ProcessArgs {
fn from_lua(value: LuaValue, _: &Lua) -> LuaResult<Self> {
if let LuaValue::Nil = value {
Ok(Self::from_iter([] as [OsString; 0]))
} else if let LuaValue::Boolean(true) = value {
Ok(Self::current())
} else if let Some(u) = value.as_userdata().and_then(|u| u.borrow::<Self>().ok()) {
Ok(u.clone())
} else if let LuaValue::Table(arr) = value {
let mut args = Vec::new();
for pair in arr.pairs::<LuaValue, LuaValue>() {
let val_res = pair.map(|p| p.1.clone());
let val = super::lua_value_to_os_string(val_res, "ProcessArgs")?;
super::validate_os_value(&val)?;
args.push(val);
}
Ok(Self::from_iter(args))
} else {
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: String::from("ProcessArgs"),
message: Some(format!(
"Invalid type for process args - expected table or nil, got '{}'",
value.type_name()
)),
})
}
}
}
impl LuaUserData for ProcessArgs {
fn add_methods<M: LuaUserDataMethods<Self>>(methods: &mut M) {
methods.add_meta_method(LuaMetaMethod::Len, |_, this, (): ()| Ok(this.len()));
methods.add_meta_method(LuaMetaMethod::Index, |_, this, index: usize| {
if index == 0 {
Ok(None)
} else {
Ok(this.get(index - 1))
}
});
methods.add_meta_method(LuaMetaMethod::NewIndex, |_, _, (): ()| {
Err::<(), _>(LuaError::runtime("ProcessArgs is read-only"))
});
methods.add_meta_method(LuaMetaMethod::Iter, |lua, this, (): ()| {
let mut vars = this
.clone()
.into_iter()
.filter_map(OsStringBytes::into_io_vec)
.enumerate();
lua.create_function_mut(move |lua, (): ()| match vars.next() {
None => Ok((LuaValue::Nil, LuaValue::Nil)),
Some((index, value)) => Ok((
LuaValue::Integer(index as i32),
LuaValue::String(lua.create_string(value)?),
)),
})
});
}
}

View file

@ -0,0 +1,254 @@
#![allow(clippy::missing_panics_doc)]
use std::{
collections::BTreeMap,
env::vars_os,
ffi::{OsStr, OsString},
sync::{Arc, Mutex},
};
use mlua::prelude::*;
use os_str_bytes::{OsStrBytes, OsStringBytes};
// Inner (shared) struct
#[derive(Debug, Default)]
struct ProcessEnvInner {
values: BTreeMap<OsString, OsString>,
}
impl FromIterator<(OsString, OsString)> for ProcessEnvInner {
fn from_iter<T: IntoIterator<Item = (OsString, OsString)>>(iter: T) -> Self {
Self {
values: iter.into_iter().collect(),
}
}
}
/**
A struct that can be easily shared, stored in Lua app data,
and that also guarantees the pairs are valid OS strings
that can be used for process environment variables.
Usable directly from Lua, implementing both `FromLua` and `LuaUserData`.
Also provides convenience methods for working with the variables
as either `OsString` or `Vec<u8>`, where using the latter implicitly
converts to an `OsString` and fails if the conversion is not possible.
*/
#[derive(Debug, Clone)]
pub struct ProcessEnv {
inner: Arc<Mutex<ProcessEnvInner>>,
}
impl ProcessEnv {
#[must_use]
pub fn empty() -> Self {
Self {
inner: Arc::new(Mutex::new(ProcessEnvInner::default())),
}
}
#[must_use]
pub fn current() -> Self {
Self {
inner: Arc::new(Mutex::new(vars_os().collect())),
}
}
#[must_use]
pub fn len(&self) -> usize {
let inner = self.inner.lock().unwrap();
inner.values.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
let inner = self.inner.lock().unwrap();
inner.values.is_empty()
}
// OS strings
#[must_use]
pub fn get_all(&self) -> Vec<(OsString, OsString)> {
let inner = self.inner.lock().unwrap();
inner.values.clone().into_iter().collect()
}
#[must_use]
pub fn get_value(&self, key: impl AsRef<OsStr>) -> Option<OsString> {
let key = key.as_ref();
super::validate_os_key(key).ok()?;
let inner = self.inner.lock().unwrap();
inner.values.get(key).cloned()
}
pub fn set_value(&self, key: impl Into<OsString>, val: impl Into<OsString>) {
let key = key.into();
let val = val.into();
if super::validate_os_pair((&key, &val)).is_err() {
return;
}
let mut inner = self.inner.lock().unwrap();
inner.values.insert(key, val);
}
pub fn remove_value(&self, key: impl AsRef<OsStr>) {
let key = key.as_ref();
if super::validate_os_key(key).is_err() {
return;
}
let mut inner = self.inner.lock().unwrap();
inner.values.remove(key);
}
// Bytes wrappers
#[must_use]
pub fn get_all_bytes(&self) -> Vec<(Vec<u8>, Vec<u8>)> {
self.get_all()
.into_iter()
.filter_map(|(k, v)| Some((k.into_io_vec()?, v.into_io_vec()?)))
.collect()
}
#[must_use]
pub fn get_value_bytes(&self, key: impl AsRef<[u8]>) -> Option<Vec<u8>> {
let key = OsStr::from_io_bytes(key.as_ref())?;
let val = self.get_value(key)?;
val.into_io_vec()
}
pub fn set_value_bytes(&self, key: impl AsRef<[u8]>, val: impl Into<Vec<u8>>) {
let key = OsStr::from_io_bytes(key.as_ref());
let val = OsString::from_io_vec(val.into());
if let (Some(key), Some(val)) = (key, val) {
self.set_value(key, val);
}
}
pub fn remove_value_bytes(&self, key: impl AsRef<[u8]>) {
let key = OsStr::from_io_bytes(key.as_ref());
if let Some(key) = key {
self.remove_value(key);
}
}
}
// Iterator implementations
impl IntoIterator for ProcessEnv {
type Item = (OsString, OsString);
type IntoIter = std::collections::btree_map::IntoIter<OsString, OsString>;
fn into_iter(self) -> Self::IntoIter {
let inner = self.inner.lock().unwrap();
inner.values.clone().into_iter()
}
}
impl<K: Into<OsString>, V: Into<OsString>> FromIterator<(K, V)> for ProcessEnv {
fn from_iter<T: IntoIterator<Item = (K, V)>>(iter: T) -> Self {
Self {
inner: Arc::new(Mutex::new(
iter.into_iter()
.map(|(k, v)| (k.into(), v.into()))
.filter(|(k, v)| super::validate_os_pair((k, v)).is_ok())
.collect(),
)),
}
}
}
impl<K: Into<OsString>, V: Into<OsString>> Extend<(K, V)> for ProcessEnv {
fn extend<T: IntoIterator<Item = (K, V)>>(&mut self, iter: T) {
let mut inner = self.inner.lock().unwrap();
inner.values.extend(
iter.into_iter()
.map(|(k, v)| (k.into(), v.into()))
.filter(|(k, v)| super::validate_os_pair((k, v)).is_ok()),
);
}
}
// Lua implementations
impl FromLua for ProcessEnv {
fn from_lua(value: LuaValue, _: &Lua) -> LuaResult<Self> {
if let LuaValue::Nil = value {
Ok(Self::from_iter([] as [(OsString, OsString); 0]))
} else if let LuaValue::Boolean(true) = value {
Ok(Self::current())
} else if let Some(u) = value.as_userdata().and_then(|u| u.borrow::<Self>().ok()) {
Ok(u.clone())
} else if let LuaValue::Table(arr) = value {
let mut args = Vec::new();
for pair in arr.pairs::<LuaValue, LuaValue>() {
let (key_res, val_res) = match pair {
Ok((key, val)) => (Ok(key), Ok(val)),
Err(err) => (Err(err.clone()), Err(err)),
};
let key = super::lua_value_to_os_string(key_res, "ProcessEnv")?;
let val = super::lua_value_to_os_string(val_res, "ProcessEnv")?;
super::validate_os_pair((&key, &val))?;
args.push((key, val));
}
Ok(Self::from_iter(args))
} else {
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: String::from("ProcessEnv"),
message: Some(format!(
"Invalid type for process env - expected table or nil, got '{}'",
value.type_name()
)),
})
}
}
}
impl LuaUserData for ProcessEnv {
fn add_methods<M: LuaUserDataMethods<Self>>(methods: &mut M) {
methods.add_meta_method(LuaMetaMethod::Len, |_, this, (): ()| Ok(this.len()));
methods.add_meta_method(LuaMetaMethod::Index, |_, this, key: LuaValue| {
let key = super::lua_value_to_os_string(Ok(key), "OsString")?;
Ok(this.get_value(key))
});
methods.add_meta_method(
LuaMetaMethod::NewIndex,
|_, this, (key, val): (LuaValue, Option<LuaValue>)| {
let key = super::lua_value_to_os_string(Ok(key), "OsString")?;
if let Some(val) = val {
let val = super::lua_value_to_os_string(Ok(val), "OsString")?;
this.set_value(key, val);
} else {
this.remove_value(key);
}
Ok(())
},
);
methods.add_meta_method(LuaMetaMethod::Iter, |lua, this, (): ()| {
let mut vars = this
.clone()
.into_iter()
.filter_map(|(key, val)| Some((key.into_io_vec()?, val.into_io_vec()?)));
lua.create_function_mut(move |lua, (): ()| match vars.next() {
None => Ok((LuaValue::Nil, LuaValue::Nil)),
Some((key, val)) => Ok((
LuaValue::String(lua.create_string(key)?),
LuaValue::String(lua.create_string(val)?),
)),
})
});
}
}

View file

@ -1,29 +1,31 @@
#[derive(Debug, Clone, Copy, Default)] #[derive(Debug, Clone, Copy, Default)]
pub struct JitEnablement(bool); pub struct ProcessJitEnablement {
enabled: bool,
}
impl JitEnablement { impl ProcessJitEnablement {
#[must_use] #[must_use]
pub fn new(enabled: bool) -> Self { pub fn new(enabled: bool) -> Self {
Self(enabled) Self { enabled }
} }
pub fn set_status(&mut self, enabled: bool) { pub fn set_status(&mut self, enabled: bool) {
self.0 = enabled; self.enabled = enabled;
} }
#[must_use] #[must_use]
pub fn enabled(self) -> bool { pub fn enabled(self) -> bool {
self.0 self.enabled
} }
} }
impl From<JitEnablement> for bool { impl From<ProcessJitEnablement> for bool {
fn from(val: JitEnablement) -> Self { fn from(val: ProcessJitEnablement) -> Self {
val.enabled() val.enabled()
} }
} }
impl From<bool> for JitEnablement { impl From<bool> for ProcessJitEnablement {
fn from(val: bool) -> Self { fn from(val: bool) -> Self {
Self::new(val) Self::new(val)
} }

View file

@ -0,0 +1,78 @@
use std::ffi::{OsStr, OsString};
use mlua::prelude::*;
use os_str_bytes::{OsStrBytes, OsStringBytes};
mod args;
mod env;
mod jit;
pub use self::args::ProcessArgs;
pub use self::env::ProcessEnv;
pub use self::jit::ProcessJitEnablement;
fn lua_value_to_os_string(res: LuaResult<LuaValue>, to: &'static str) -> LuaResult<OsString> {
let (btype, bs) = match res {
Ok(LuaValue::String(s)) => ("string", s.as_bytes().to_vec()),
Ok(LuaValue::Buffer(b)) => ("buffer", b.to_vec()),
res => {
let vtype = match res {
Ok(v) => v.type_name(),
Err(_) => "unknown",
};
return Err(LuaError::FromLuaConversionError {
from: vtype,
to: String::from(to),
message: Some(format!(
"Expected value to be a string or buffer, got '{vtype}'",
)),
});
}
};
let Some(s) = OsString::from_io_vec(bs) else {
return Err(LuaError::FromLuaConversionError {
from: btype,
to: String::from(to),
message: Some(String::from("Expected {btype} to contain valid OS bytes")),
});
};
Ok(s)
}
fn validate_os_key(key: &OsStr) -> LuaResult<()> {
let Some(key) = key.to_io_bytes() else {
return Err(LuaError::runtime("Key must be IO-safe"));
};
if key.is_empty() {
Err(LuaError::runtime("Key must not be empty"))
} else if key.contains(&b'=') {
Err(LuaError::runtime(
"Key must not contain the equals character '='",
))
} else if key.contains(&b'\0') {
Err(LuaError::runtime("Key must not contain the NUL character"))
} else {
Ok(())
}
}
fn validate_os_value(val: &OsStr) -> LuaResult<()> {
let Some(val) = val.to_io_bytes() else {
return Err(LuaError::runtime("Value must be IO-safe"));
};
if val.contains(&b'\0') {
Err(LuaError::runtime(
"Value must not contain the NUL character",
))
} else {
Ok(())
}
}
fn validate_os_pair((key, value): (&OsStr, &OsStr)) -> LuaResult<()> {
validate_os_key(key)?;
validate_os_value(value)?;
Ok(())
}

View file

@ -1,6 +1,6 @@
[package] [package]
name = "lune" name = "lune"
version = "0.9.0" version = "0.9.2"
edition = "2021" edition = "2021"
license = "MPL-2.0" license = "MPL-2.0"
repository = "https://github.com/lune-org/lune" repository = "https://github.com/lune-org/lune"
@ -26,7 +26,7 @@ std-luau = ["dep:lune-std", "lune-std/luau"]
std-net = ["dep:lune-std", "lune-std/net"] std-net = ["dep:lune-std", "lune-std/net"]
std-process = ["dep:lune-std", "lune-std/process"] std-process = ["dep:lune-std", "lune-std/process"]
std-regex = ["dep:lune-std", "lune-std/regex"] std-regex = ["dep:lune-std", "lune-std/regex"]
std-roblox = ["dep:lune-std", "lune-std/roblox", "dep:lune-roblox"] std-roblox = ["dep:lune-std", "lune-std/roblox"]
std-serde = ["dep:lune-std", "lune-std/serde"] std-serde = ["dep:lune-std", "lune-std/serde"]
std-stdio = ["dep:lune-std", "lune-std/stdio"] std-stdio = ["dep:lune-std", "lune-std/stdio"]
std-task = ["dep:lune-std", "lune-std/task"] std-task = ["dep:lune-std", "lune-std/task"]
@ -51,27 +51,28 @@ workspace = true
[dependencies] [dependencies]
mlua = { version = "0.10.3", features = ["luau"] } mlua = { version = "0.10.3", features = ["luau"] }
mlua-luau-scheduler = { version = "0.1.0", path = "../mlua-luau-scheduler" } mlua-luau-scheduler = { version = "0.1.2", path = "../mlua-luau-scheduler" }
anyhow = "1.0" anyhow = "1.0"
console = "0.15" console = "0.15"
dialoguer = "0.11" dialoguer = "0.11"
directories = "6.0" directories = "6.0"
futures-util = "0.3"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0" serde_json = "1.0"
thiserror = "2.0" thiserror = "2.0"
async-io = "2.4"
async-fs = "2.1"
blocking = "1.6"
futures-lite = "2.6"
rustls = { version = "0.23", default-features = false, features = ["std", "tls12", "ring"] }
ureq = { version = "3.0", default-features = false, features = ["rustls", "gzip"] }
tracing = "0.1" tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tokio = { version = "1", features = ["full"] }
reqwest = { version = "0.11", default-features = false, features = [
"rustls-tls",
] }
lune-std = { optional = true, version = "0.2.0", path = "../lune-std" } lune-std = { optional = true, version = "0.2.2", path = "../lune-std" }
lune-roblox = { optional = true, version = "0.2.0", path = "../lune-roblox" } lune-utils = { version = "0.2.2", path = "../lune-utils" }
lune-utils = { version = "0.2.0", path = "../lune-utils" }
### CLI ### CLI

View file

@ -3,7 +3,9 @@ use std::{
path::PathBuf, path::PathBuf,
}; };
use tokio::{fs, task}; use async_fs as fs;
use blocking::unblock;
use rustls::crypto::ring;
use crate::standalone::metadata::CURRENT_EXE; use crate::standalone::metadata::CURRENT_EXE;
@ -44,25 +46,32 @@ pub async fn get_or_download_base_executable(target: BuildTarget) -> BuildResult
// Try to request to download the zip file from the target url, // Try to request to download the zip file from the target url,
// making sure transient errors are handled gracefully and // making sure transient errors are handled gracefully and
// with a different error message than "not found" // with a different error message than "not found"
let response = reqwest::get(release_url).await?; let (res_status, res_body) = unblock(move || {
if !response.status().is_success() { // Only errors if already installed, which is fine
if response.status().as_u16() == 404 { ring::default_provider().install_default().ok();
let mut res = ureq::get(release_url).call()?;
let body = res.body_mut().read_to_vec()?;
Ok::<_, BuildError>((res.status(), body))
})
.await?;
if !res_status.is_success() {
if res_status.as_u16() == 404 {
return Err(BuildError::ReleaseTargetNotFound(target)); return Err(BuildError::ReleaseTargetNotFound(target));
} }
return Err(BuildError::Download( return Err(BuildError::Download(ureq::Error::StatusCode(
response.error_for_status().unwrap_err(), res_status.as_u16(),
)); )));
} }
// Receive the full zip file // Start reading the zip file
let zip_bytes = response.bytes().await?.to_vec(); let zip_file = Cursor::new(res_body);
let zip_file = Cursor::new(zip_bytes);
// Look for and extract the binary file from the zip file // Look for and extract the binary file from the zip file
// NOTE: We use spawn_blocking here since reading a zip // NOTE: We use spawn_blocking here since reading a zip
// archive is a somewhat slow / blocking operation // archive is a somewhat slow / blocking operation
let binary_file_name = format!("lune{}", target.exe_suffix()); let binary_file_name = format!("lune{}", target.exe_suffix());
let binary_file_handle = task::spawn_blocking(move || { let binary_file_handle = unblock(move || {
let mut archive = zip::ZipArchive::new(zip_file)?; let mut archive = zip::ZipArchive::new(zip_file)?;
let mut binary = Vec::new(); let mut binary = Vec::new();
@ -73,7 +82,7 @@ pub async fn get_or_download_base_executable(target: BuildTarget) -> BuildResult
Ok::<_, BuildError>(binary) Ok::<_, BuildError>(binary)
}); });
let binary_file_contents = binary_file_handle.await??; let binary_file_contents = binary_file_handle.await?;
// Finally, write the extracted binary to the cache // Finally, write the extracted binary to the cache
if !CACHE_DIR.exists() { if !CACHE_DIR.exists() {

View file

@ -1,7 +1,8 @@
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use anyhow::Result; use anyhow::Result;
use tokio::{fs, io::AsyncWriteExt}; use async_fs as fs;
use futures_lite::prelude::*;
/** /**
Removes the source file extension from the given path, if it has one. Removes the source file extension from the given path, if it has one.
@ -32,6 +33,7 @@ pub async fn write_executable_file_to(
#[cfg(unix)] #[cfg(unix)]
{ {
use fs::unix::OpenOptionsExt;
options.mode(0o755); // Read & execute for all, write for owner options.mode(0o755); // Read & execute for all, write for owner
} }

View file

@ -1,9 +1,9 @@
use std::{path::PathBuf, process::ExitCode}; use std::{path::PathBuf, process::ExitCode};
use anyhow::{bail, Context, Result}; use anyhow::{bail, Context, Result};
use async_fs as fs;
use clap::Parser; use clap::Parser;
use console::style; use console::style;
use tokio::fs;
use crate::standalone::metadata::Metadata; use crate::standalone::metadata::Metadata;

View file

@ -12,11 +12,9 @@ pub enum BuildError {
#[error("failed to find lune binary '{0}' in downloaded zip file")] #[error("failed to find lune binary '{0}' in downloaded zip file")]
ZippedBinaryNotFound(String), ZippedBinaryNotFound(String),
#[error("failed to download lune binary: {0}")] #[error("failed to download lune binary: {0}")]
Download(#[from] reqwest::Error), Download(#[from] ureq::Error),
#[error("failed to unzip lune binary: {0}")] #[error("failed to unzip lune binary: {0}")]
Unzip(#[from] zip::result::ZipError), Unzip(#[from] zip::result::ZipError),
#[error("panicked while unzipping lune binary: {0}")]
UnzipJoin(#[from] tokio::task::JoinError),
#[error("io error: {0}")] #[error("io error: {0}")]
IoError(#[from] std::io::Error), IoError(#[from] std::io::Error),
} }

View file

@ -1,6 +1,7 @@
use std::{path::PathBuf, process::ExitCode}; use std::{path::PathBuf, process::ExitCode};
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use async_fs as fs;
use clap::Parser; use clap::Parser;
use directories::UserDirs; use directories::UserDirs;
use rustyline::{error::ReadlineError, DefaultEditor}; use rustyline::{error::ReadlineError, DefaultEditor};
@ -28,7 +29,7 @@ impl ReplCommand {
.home_dir() .home_dir()
.join(".lune_history"); .join(".lune_history");
if !history_file_path.exists() { if !history_file_path.exists() {
tokio::fs::write(history_file_path, &[]).await?; fs::write(history_file_path, &[]).await?;
} }
let mut repl = DefaultEditor::new()?; let mut repl = DefaultEditor::new()?;

View file

@ -1,11 +1,10 @@
use std::{env, process::ExitCode}; use std::{env, io::stdin, process::ExitCode};
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use async_fs::read as read_to_vec;
use blocking::Unblock;
use clap::Parser; use clap::Parser;
use tokio::{ use futures_lite::prelude::*;
fs::read as read_to_vec,
io::{stdin, AsyncReadExt as _},
};
use lune::Runtime; use lune::Runtime;
@ -27,7 +26,7 @@ impl RunCommand {
// (dash) as the script name to run to the cli // (dash) as the script name to run to the cli
let (script_display_name, script_contents) = if &self.script_path == "-" { let (script_display_name, script_contents) = if &self.script_path == "-" {
let mut stdin_contents = Vec::new(); let mut stdin_contents = Vec::new();
stdin() Unblock::new(stdin())
.read_to_end(&mut stdin_contents) .read_to_end(&mut stdin_contents)
.await .await
.context("Failed to read script contents from stdin")?; .context("Failed to read script contents from stdin")?;

View file

@ -1,11 +1,10 @@
use std::{borrow::BorrowMut, env::current_dir, io::ErrorKind, path::PathBuf, process::ExitCode}; use std::{borrow::BorrowMut, env::current_dir, io::ErrorKind, path::PathBuf, process::ExitCode};
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use async_fs as fs;
use clap::Parser; use clap::Parser;
use directories::UserDirs; use directories::UserDirs;
use futures_util::future::try_join_all;
use thiserror::Error; use thiserror::Error;
use tokio::fs;
// TODO: Use a library that supports json with comments since VSCode settings may contain comments // TODO: Use a library that supports json with comments since VSCode settings may contain comments
use serde_json::Value as JsonValue; use serde_json::Value as JsonValue;
@ -157,16 +156,12 @@ async fn generate_typedef_files_from_definitions() -> Result<String> {
files_to_write.push((name, path, builtin.typedefs())); files_to_write.push((name, path, builtin.typedefs()));
} }
// Write all dirs and files only when we know generation was successful // Write all dirs and files
let futs_dirs = dirs_to_write for dir in dirs_to_write {
.drain(..) fs::create_dir_all(dir).await?;
.map(fs::create_dir_all) }
.collect::<Vec<_>>(); for (_name, path, contents) in files_to_write {
let futs_files = files_to_write fs::write(path, contents).await?;
.iter() }
.map(|(_, path, contents)| fs::write(path, contents))
.collect::<Vec<_>>();
try_join_all(futs_dirs).await?;
try_join_all(futs_files).await?;
Ok(version_string.to_string()) Ok(version_string.to_string())
} }

View file

@ -1,11 +1,14 @@
#![allow(clippy::match_same_arms)] #![allow(clippy::match_same_arms)]
use std::{cmp::Ordering, ffi::OsStr, fmt::Write as _, path::PathBuf, sync::LazyLock}; use std::{
cmp::Ordering, ffi::OsStr, fmt::Write as _, io::ErrorKind, path::PathBuf, sync::LazyLock,
};
use anyhow::{bail, Result}; use anyhow::{bail, Result};
use async_fs as fs;
use console::Style; use console::Style;
use directories::UserDirs; use directories::UserDirs;
use tokio::{fs, io}; use futures_lite::prelude::*;
use super::files::{discover_script_path, parse_lune_description_from_file}; use super::files::{discover_script_path, parse_lune_description_from_file};
@ -25,7 +28,7 @@ pub async fn find_lune_scripts(in_home_dir: bool) -> Result<Vec<(String, String)
match lune_dir { match lune_dir {
Ok(mut dir) => { Ok(mut dir) => {
let mut files = Vec::new(); let mut files = Vec::new();
while let Some(entry) = dir.next_entry().await? { while let Some(entry) = dir.try_next().await? {
let meta = entry.metadata().await?; let meta = entry.metadata().await?;
if meta.is_file() { if meta.is_file() {
let contents = fs::read(entry.path()).await?; let contents = fs::read(entry.path()).await?;
@ -77,7 +80,7 @@ pub async fn find_lune_scripts(in_home_dir: bool) -> Result<Vec<(String, String)
.collect(); .collect();
Ok(parsed) Ok(parsed)
} }
Err(e) if matches!(e.kind(), io::ErrorKind::NotFound) => { Err(e) if matches!(e.kind(), ErrorKind::NotFound) => {
bail!("No lune directory was found.") bail!("No lune directory was found.")
} }
Err(e) => { Err(e) => {

View file

@ -9,8 +9,7 @@ pub(crate) mod standalone;
use lune_utils::fmt::Label; use lune_utils::fmt::Label;
#[tokio::main(flavor = "multi_thread")] fn main() -> ExitCode {
async fn main() -> ExitCode {
tracing_subscriber::fmt() tracing_subscriber::fmt()
.compact() .compact()
.with_env_filter(tracing_subscriber::filter::EnvFilter::from_default_env()) .with_env_filter(tracing_subscriber::filter::EnvFilter::from_default_env())
@ -20,24 +19,26 @@ async fn main() -> ExitCode {
.with_writer(stderr) .with_writer(stderr)
.init(); .init();
if let Some(bin) = standalone::check().await { async_io::block_on(async {
return standalone::run(bin).await.unwrap(); if let Some(bin) = standalone::check().await {
} return standalone::run(bin).await.unwrap();
}
#[cfg(feature = "cli")] #[cfg(feature = "cli")]
{ {
match cli::Cli::new().run().await { match cli::Cli::new().run().await {
Ok(code) => code, Ok(code) => code,
Err(err) => { Err(err) => {
eprintln!("{}\n{err:?}", Label::Error); eprintln!("{}\n{err:?}", Label::Error);
ExitCode::FAILURE ExitCode::FAILURE
}
} }
} }
}
#[cfg(not(feature = "cli"))] #[cfg(not(feature = "cli"))]
{ {
eprintln!("{}\nCLI feature is disabled", Label::Error); eprintln!("{}\nCLI feature is disabled", Label::Error);
ExitCode::FAILURE ExitCode::FAILURE
} }
})
} }

View file

@ -1,11 +1,14 @@
#![allow(clippy::missing_panics_doc)] #![allow(clippy::missing_panics_doc)]
use std::sync::{ use std::{
atomic::{AtomicBool, Ordering}, ffi::OsString,
Arc, sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
}; };
use lune_utils::jit::JitEnablement; use lune_utils::process::{ProcessArgs, ProcessEnv, ProcessJitEnablement};
use mlua::prelude::*; use mlua::prelude::*;
use mlua_luau_scheduler::{Functions, Scheduler}; use mlua_luau_scheduler::{Functions, Scheduler};
@ -14,6 +17,7 @@ use super::{RuntimeError, RuntimeResult};
/** /**
Values returned by running a Lune runtime until completion. Values returned by running a Lune runtime until completion.
*/ */
#[derive(Debug)]
#[non_exhaustive] #[non_exhaustive]
pub struct RuntimeReturnValues { pub struct RuntimeReturnValues {
/// The exit code manually returned from the runtime, if any. /// The exit code manually returned from the runtime, if any.
@ -57,7 +61,9 @@ impl RuntimeReturnValues {
pub struct Runtime { pub struct Runtime {
lua: Lua, lua: Lua,
sched: Scheduler, sched: Scheduler,
jit: JitEnablement, args: ProcessArgs,
env: ProcessEnv,
jit: ProcessJitEnablement,
} }
impl Runtime { impl Runtime {
@ -74,8 +80,6 @@ impl Runtime {
pub fn new() -> LuaResult<Self> { pub fn new() -> LuaResult<Self> {
let lua = Lua::new(); let lua = Lua::new();
lua.set_app_data(Vec::<String>::new());
let sched = Scheduler::new(lua.clone()); let sched = Scheduler::new(lua.clone());
let fns = Functions::new(lua.clone()).expect("has scheduler"); let fns = Functions::new(lua.clone()).expect("has scheduler");
@ -125,21 +129,47 @@ impl Runtime {
.set(g_table.name(), g_table.create(lua.clone())?)?; .set(g_table.name(), g_table.create(lua.clone())?)?;
} }
let jit = JitEnablement::default(); let args = ProcessArgs::current();
Ok(Self { lua, sched, jit }) let env = ProcessEnv::current();
let jit = ProcessJitEnablement::default();
Ok(Self {
lua,
sched,
args,
env,
jit,
})
} }
/** /**
Sets arguments to give in `process.args` for Lune scripts. Sets arguments to give in `process.args` for Lune scripts.
By default, `std::env::args_os()` is used.
*/ */
#[must_use] #[must_use]
pub fn with_args<A, S>(self, args: A) -> Self pub fn with_args<A, S>(mut self, args: A) -> Self
where where
A: IntoIterator<Item = S>, A: IntoIterator<Item = S>,
S: Into<String>, S: Into<OsString>,
{ {
let args = args.into_iter().map(Into::into).collect::<Vec<_>>(); self.args = args.into_iter().map(Into::into).collect();
self.lua.set_app_data(args); self
}
/**
Sets environment values to give in `process.env` for Lune scripts.
By default, `std::env::vars_os()` is used.
*/
#[must_use]
pub fn with_env<E, K, V>(mut self, env: E) -> Self
where
E: IntoIterator<Item = (K, V)>,
K: Into<OsString>,
V: Into<OsString>,
{
self.env = env.into_iter().map(|(k, v)| (k.into(), v.into())).collect();
self self
} }
@ -147,7 +177,10 @@ impl Runtime {
Enables or disables JIT compilation. Enables or disables JIT compilation.
*/ */
#[must_use] #[must_use]
pub fn with_jit(mut self, jit_status: impl Into<JitEnablement>) -> Self { pub fn with_jit<J>(mut self, jit_status: J) -> Self
where
J: Into<ProcessJitEnablement>,
{
self.jit = jit_status.into(); self.jit = jit_status.into();
self self
} }
@ -174,8 +207,12 @@ impl Runtime {
eprintln!("{}", RuntimeError::from(e)); eprintln!("{}", RuntimeError::from(e));
}); });
// Enable / disable the JIT as requested and store the current status as AppData // Store the provided args, environment variables, and jit enablement as AppData
self.lua.set_app_data(self.args.clone());
self.lua.set_app_data(self.env.clone());
self.lua.set_app_data(self.jit); self.lua.set_app_data(self.jit);
// Enable / disable the JIT as requested, before loading anything
self.lua.enable_jit(self.jit.enabled()); self.lua.enable_jit(self.jit.enabled());
// Load our "main" thread // Load our "main" thread

View file

@ -1,8 +1,8 @@
use std::{env, path::PathBuf, sync::LazyLock}; use std::{env, path::PathBuf, sync::LazyLock};
use anyhow::{bail, Result}; use anyhow::{bail, Result};
use async_fs as fs;
use mlua::Compiler as LuaCompiler; use mlua::Compiler as LuaCompiler;
use tokio::fs;
pub static CURRENT_EXE: LazyLock<PathBuf> = pub static CURRENT_EXE: LazyLock<PathBuf> =
LazyLock::new(|| env::current_exe().expect("failed to get current exe")); LazyLock::new(|| env::current_exe().expect("failed to get current exe"));

View file

@ -3,9 +3,9 @@ use std::path::PathBuf;
use std::process::ExitCode; use std::process::ExitCode;
use anyhow::Result; use anyhow::Result;
use async_fs::read_to_string;
use console::set_colors_enabled; use console::set_colors_enabled;
use console::set_colors_enabled_stderr; use console::set_colors_enabled_stderr;
use tokio::fs::read_to_string;
use lune_utils::path::clean_path_and_make_absolute; use lune_utils::path::clean_path_and_make_absolute;
@ -15,37 +15,33 @@ const ARGS: &[&str] = &["Foo", "Bar"];
macro_rules! create_tests { macro_rules! create_tests {
($($name:ident: $value:expr,)*) => { $( ($($name:ident: $value:expr,)*) => { $(
#[tokio::test(flavor = "multi_thread")] #[test]
async fn $name() -> Result<ExitCode> { fn $name() -> Result<ExitCode> {
// We need to change the current directory to the workspace root since async_io::block_on(async {
// we are in a sub-crate and tests would run relative to the sub-crate // We need to change the current directory to the workspace root since
let workspace_dir_str = format!("{}/../../", env!("CARGO_MANIFEST_DIR")); // we are in a sub-crate and tests would run relative to the sub-crate
let workspace_dir = clean_path_and_make_absolute(PathBuf::from(workspace_dir_str)); let workspace_dir_str = format!("{}/../../", env!("CARGO_MANIFEST_DIR"));
set_current_dir(&workspace_dir)?; let workspace_dir = clean_path_and_make_absolute(PathBuf::from(workspace_dir_str));
set_current_dir(&workspace_dir)?;
// Disable styling for stdout and stderr since // Disable styling for stdout and stderr since
// some tests rely on output not being styled // some tests rely on output not being styled
set_colors_enabled(false); set_colors_enabled(false);
set_colors_enabled_stderr(false); set_colors_enabled_stderr(false);
// The rest of the test logic can continue as normal // The rest of the test logic can continue as normal
let full_name = format!("{}/tests/{}.luau", workspace_dir.display(), $value); let full_name = format!("{}/tests/{}.luau", workspace_dir.display(), $value);
let script = read_to_string(&full_name).await?; let script = read_to_string(&full_name).await?;
let mut lune = Runtime::new()? let mut lune = Runtime::new()?
.with_jit(true) .with_args(ARGS.iter().cloned())
.with_args( .with_jit(true);
ARGS let script_name = full_name
.clone() .trim_end_matches(".luau")
.iter() .trim_end_matches(".lua")
.map(ToString::to_string) .to_string();
.collect::<Vec<_>>() let script_values = lune.run(&script_name, &script).await?;
); Ok(ExitCode::from(script_values.status()))
let script_name = full_name })
.trim_end_matches(".luau")
.trim_end_matches(".lua")
.to_string();
let script_values = lune.run(&script_name, &script).await?;
Ok(ExitCode::from(script_values.status()))
} }
)* } )* }
} }
@ -124,16 +120,23 @@ create_tests! {
create_tests! { create_tests! {
net_request_codes: "net/request/codes", net_request_codes: "net/request/codes",
net_request_compression: "net/request/compression", net_request_compression: "net/request/compression",
net_request_https: "net/request/https",
net_request_methods: "net/request/methods", net_request_methods: "net/request/methods",
net_request_query: "net/request/query", net_request_query: "net/request/query",
net_request_redirect: "net/request/redirect", net_request_redirect: "net/request/redirect",
net_url_encode: "net/url/encode",
net_url_decode: "net/url/decode", net_serve_addresses: "net/serve/addresses",
net_serve_handles: "net/serve/handles",
net_serve_non_blocking: "net/serve/non_blocking",
net_serve_requests: "net/serve/requests", net_serve_requests: "net/serve/requests",
net_serve_websockets: "net/serve/websockets", net_serve_websockets: "net/serve/websockets",
net_socket_basic: "net/socket/basic", net_socket_basic: "net/socket/basic",
net_socket_wss: "net/socket/wss", net_socket_wss: "net/socket/wss",
net_socket_wss_rw: "net/socket/wss_rw", net_socket_wss_rw: "net/socket/wss_rw",
net_url_encode: "net/url/encode",
net_url_decode: "net/url/decode",
} }
#[cfg(feature = "std-process")] #[cfg(feature = "std-process")]

View file

@ -1,6 +1,6 @@
[package] [package]
name = "mlua-luau-scheduler" name = "mlua-luau-scheduler"
version = "0.1.0" version = "0.1.2"
edition = "2021" edition = "2021"
license = "MPL-2.0" license = "MPL-2.0"
repository = "https://github.com/lune-org/lune" repository = "https://github.com/lune-org/lune"
@ -16,13 +16,10 @@ path = "src/lib.rs"
workspace = true workspace = true
[dependencies] [dependencies]
async-executor = "1.8" async-executor = "1.13"
blocking = "1.5" blocking = "1.6"
concurrent-queue = "2.4" futures-lite = "2.6"
derive_more = "0.99" rustc-hash = "2.1"
event-listener = "4.0"
futures-lite = "2.2"
rustc-hash = "1.1"
tracing = "0.1" tracing = "0.1"
mlua = { version = "0.10.3", features = [ mlua = { version = "0.10.3", features = [
@ -34,7 +31,7 @@ mlua = { version = "0.10.3", features = [
[dev-dependencies] [dev-dependencies]
async-fs = "2.1" async-fs = "2.1"
async-io = "2.3" async-io = "2.4"
tracing-subscriber = { version = "0.3", features = ["env-filter"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] }
tracing-tracy = "0.11" tracing-tracy = "0.11"

View file

@ -0,0 +1,5 @@
mod multi;
mod once;
pub(crate) use self::multi::MultiEvent;
pub(crate) use self::once::OnceEvent;

View file

@ -0,0 +1,91 @@
use std::{
cell::{Cell, RefCell},
future::Future,
mem,
pin::Pin,
rc::Rc,
task::{Context, Poll, Waker},
};
/**
Internal state for events.
*/
#[derive(Debug, Default)]
struct MultiEventState {
generation: Cell<u64>,
wakers: RefCell<Vec<Waker>>,
}
/**
A single-threaded event signal that can be notified multiple times.
*/
#[derive(Debug, Clone, Default)]
pub(crate) struct MultiEvent {
state: Rc<MultiEventState>,
}
impl MultiEvent {
/**
Creates a new event.
*/
pub fn new() -> Self {
Self::default()
}
/**
Notifies all waiting listeners.
*/
pub fn notify(&self) {
self.state.generation.set(self.state.generation.get() + 1);
let wakers = {
let mut wakers = self.state.wakers.borrow_mut();
mem::take(&mut *wakers)
};
for waker in wakers {
waker.wake();
}
}
/**
Creates a listener that implements `Future` and resolves when `notify` is called.
*/
pub fn listen(&self) -> MultiListener {
MultiListener {
state: self.state.clone(),
generation: self.state.generation.get(),
}
}
}
/**
A listener future that resolves when the corresponding [`QueueEvent`] is notified.
*/
#[derive(Debug)]
pub(crate) struct MultiListener {
state: Rc<MultiEventState>,
generation: u64,
}
impl Future for MultiListener {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
// Check if notify was called (generation is more recent)
let current = self.state.generation.get();
if current > self.generation {
self.get_mut().generation = current;
return Poll::Ready(());
}
// No notification observed yet
let mut wakers = self.state.wakers.borrow_mut();
if !wakers.iter().any(|w| w.will_wake(cx.waker())) {
wakers.push(cx.waker().clone());
}
Poll::Pending
}
}
impl Unpin for MultiListener {}

View file

@ -0,0 +1,106 @@
use std::{
cell::RefCell,
future::Future,
pin::Pin,
rc::Rc,
task::{Context, Poll, Waker},
};
/**
State which is highly optimized for a single notification event.
`Some` means not notified yet, `None` means notified.
*/
#[derive(Debug, Default)]
struct OnceEventState {
wakers: RefCell<Option<Vec<Waker>>>,
}
impl OnceEventState {
fn new() -> Self {
Self {
wakers: RefCell::new(Some(Vec::new())),
}
}
}
/**
An event that may be notified exactly once.
May be cheaply cloned.
*/
#[derive(Debug, Clone, Default)]
pub struct OnceEvent {
state: Rc<OnceEventState>,
}
impl OnceEvent {
/**
Creates a new event that can be notified exactly once.
*/
pub fn new() -> Self {
let initial_state = OnceEventState::new();
Self {
state: Rc::new(initial_state),
}
}
/**
Notifies waiting listeners.
This is idempotent; subsequent calls do nothing.
*/
pub fn notify(&self) {
if let Some(wakers) = { self.state.wakers.borrow_mut().take() } {
for waker in wakers {
waker.wake();
}
}
}
/**
Creates a listener that implements `Future` and resolves when `notify` is called.
If `notify` has already been called, the future will resolve immediately.
*/
pub fn listen(&self) -> OnceListener {
OnceListener {
state: self.state.clone(),
}
}
}
/**
A listener that resolves when the event is notified.
May be cheaply cloned.
See [`OnceEvent`] for more information.
*/
#[derive(Debug)]
pub struct OnceListener {
state: Rc<OnceEventState>,
}
impl Future for OnceListener {
type Output = ();
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let mut wakers_guard = self.state.wakers.borrow_mut();
match &mut *wakers_guard {
Some(wakers) => {
// Not yet notified
if !wakers.iter().any(|w| w.will_wake(cx.waker())) {
wakers.push(cx.waker().clone());
}
Poll::Pending
}
None => {
// Already notified
Poll::Ready(())
}
}
}
}
impl Unpin for OnceListener {}

View file

@ -1,24 +1,24 @@
use std::{cell::Cell, rc::Rc}; use std::{cell::Cell, rc::Rc};
use event_listener::Event; use crate::events::OnceEvent;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub(crate) struct Exit { pub(crate) struct Exit {
code: Rc<Cell<Option<u8>>>, code: Rc<Cell<Option<u8>>>,
event: Rc<Event>, event: OnceEvent,
} }
impl Exit { impl Exit {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
code: Rc::new(Cell::new(None)), code: Rc::new(Cell::new(None)),
event: Rc::new(Event::new()), event: OnceEvent::new(),
} }
} }
pub fn set(&self, code: u8) { pub fn set(&self, code: u8) {
self.code.set(Some(code)); self.code.set(Some(code));
self.event.notify(usize::MAX); self.event.notify();
} }
pub fn get(&self) -> Option<u8> { pub fn get(&self) -> Option<u8> {

View file

@ -5,10 +5,9 @@ use mlua::prelude::*;
use crate::{ use crate::{
error_callback::ThreadErrorCallback, error_callback::ThreadErrorCallback,
queue::{DeferredThreadQueue, SpawnedThreadQueue}, queue::{DeferredThreadQueue, SpawnedThreadQueue},
result_map::ThreadResultMap, threads::{ThreadId, ThreadMap},
thread_id::ThreadId,
traits::LuaSchedulerExt, traits::LuaSchedulerExt,
util::{is_poll_pending, LuaThreadOrFunction, ThreadResult}, util::{is_poll_pending, LuaThreadOrFunction},
}; };
const ERR_METADATA_NOT_ATTACHED: &str = "\ const ERR_METADATA_NOT_ATTACHED: &str = "\
@ -102,13 +101,13 @@ impl Functions {
.app_data_ref::<ThreadErrorCallback>() .app_data_ref::<ThreadErrorCallback>()
.expect(ERR_METADATA_NOT_ATTACHED) .expect(ERR_METADATA_NOT_ATTACHED)
.clone(); .clone();
let result_map = lua let thread_map = lua
.app_data_ref::<ThreadResultMap>() .app_data_ref::<ThreadMap>()
.expect(ERR_METADATA_NOT_ATTACHED) .expect(ERR_METADATA_NOT_ATTACHED)
.clone(); .clone();
let resume_queue = defer_queue.clone(); let resume_queue = defer_queue.clone();
let resume_map = result_map.clone(); let resume_map = thread_map.clone();
let resume = let resume =
lua.create_function(move |lua, (thread, args): (LuaThread, LuaMultiValue)| { lua.create_function(move |lua, (thread, args): (LuaThread, LuaMultiValue)| {
let _span = tracing::trace_span!("Scheduler::fn_resume").entered(); let _span = tracing::trace_span!("Scheduler::fn_resume").entered();
@ -123,8 +122,7 @@ impl Functions {
if thread.status() != LuaThreadStatus::Resumable { if thread.status() != LuaThreadStatus::Resumable {
let id = ThreadId::from(&thread); let id = ThreadId::from(&thread);
if resume_map.is_tracked(id) { if resume_map.is_tracked(id) {
let res = ThreadResult::new(Ok(v.clone()), lua); resume_map.insert(id, Ok(v.clone()));
resume_map.insert(id, res);
} }
} }
(true, v).into_lua_multi(lua) (true, v).into_lua_multi(lua)
@ -134,8 +132,7 @@ impl Functions {
// Not pending, store the error // Not pending, store the error
let id = ThreadId::from(&thread); let id = ThreadId::from(&thread);
if resume_map.is_tracked(id) { if resume_map.is_tracked(id) {
let res = ThreadResult::new(Err(e.clone()), lua); resume_map.insert(id, Err(e.clone()));
resume_map.insert(id, res);
} }
(false, e.to_string()).into_lua_multi(lua) (false, e.to_string()).into_lua_multi(lua)
} }
@ -160,7 +157,7 @@ impl Functions {
.set_environment(wrap_env) .set_environment(wrap_env)
.into_function()?; .into_function()?;
let spawn_map = result_map.clone(); let spawn_map = thread_map.clone();
let spawn = lua.create_function( let spawn = lua.create_function(
move |lua, (tof, args): (LuaThreadOrFunction, LuaMultiValue)| { move |lua, (tof, args): (LuaThreadOrFunction, LuaMultiValue)| {
let _span = tracing::trace_span!("Scheduler::fn_spawn").entered(); let _span = tracing::trace_span!("Scheduler::fn_spawn").entered();
@ -177,8 +174,7 @@ impl Functions {
if thread.status() != LuaThreadStatus::Resumable { if thread.status() != LuaThreadStatus::Resumable {
let id = ThreadId::from(&thread); let id = ThreadId::from(&thread);
if spawn_map.is_tracked(id) { if spawn_map.is_tracked(id) {
let res = ThreadResult::new(Ok(v), lua); spawn_map.insert(id, Ok(v));
spawn_map.insert(id, res);
} }
} }
} }
@ -188,8 +184,7 @@ impl Functions {
// Not pending, store the error // Not pending, store the error
let id = ThreadId::from(&thread); let id = ThreadId::from(&thread);
if spawn_map.is_tracked(id) { if spawn_map.is_tracked(id) {
let res = ThreadResult::new(Err(e), lua); spawn_map.insert(id, Err(e));
spawn_map.insert(id, res);
} }
} }
} }

View file

@ -1,18 +1,18 @@
#![allow(clippy::cargo_common_metadata)] #![allow(clippy::cargo_common_metadata)]
mod error_callback; mod error_callback;
mod events;
mod exit; mod exit;
mod functions; mod functions;
mod queue; mod queue;
mod result_map;
mod scheduler; mod scheduler;
mod status; mod status;
mod thread_id; mod threads;
mod traits; mod traits;
mod util; mod util;
pub use functions::Functions; pub use functions::Functions;
pub use scheduler::Scheduler; pub use scheduler::Scheduler;
pub use status::Status; pub use status::Status;
pub use thread_id::ThreadId; pub use threads::ThreadId;
pub use traits::{IntoLuaThread, LuaSchedulerExt, LuaSpawnExt}; pub use traits::{IntoLuaThread, LuaSchedulerExt, LuaSpawnExt};

View file

@ -1,139 +0,0 @@
use std::{pin::Pin, rc::Rc};
use concurrent_queue::ConcurrentQueue;
use derive_more::{Deref, DerefMut};
use event_listener::Event;
use futures_lite::{Future, FutureExt};
use mlua::prelude::*;
use crate::{traits::IntoLuaThread, util::ThreadWithArgs, ThreadId};
/**
Queue for storing [`LuaThread`]s with associated arguments.
Provides methods for pushing and draining the queue, as
well as listening for new items being pushed to the queue.
*/
#[derive(Debug, Clone)]
pub(crate) struct ThreadQueue {
queue: Rc<ConcurrentQueue<ThreadWithArgs>>,
event: Rc<Event>,
}
impl ThreadQueue {
pub fn new() -> Self {
let queue = Rc::new(ConcurrentQueue::unbounded());
let event = Rc::new(Event::new());
Self { queue, event }
}
pub fn push_item(
&self,
lua: &Lua,
thread: impl IntoLuaThread,
args: impl IntoLuaMulti,
) -> LuaResult<ThreadId> {
let thread = thread.into_lua_thread(lua)?;
let args = args.into_lua_multi(lua)?;
tracing::trace!("pushing item to queue with {} args", args.len());
let id = ThreadId::from(&thread);
let stored = ThreadWithArgs::new(lua, thread, args)?;
self.queue.push(stored).into_lua_err()?;
self.event.notify(usize::MAX);
Ok(id)
}
#[inline]
pub fn drain_items<'outer, 'lua>(
&'outer self,
lua: &'lua Lua,
) -> impl Iterator<Item = (LuaThread, LuaMultiValue)> + 'outer
where
'lua: 'outer,
{
self.queue.try_iter().map(|stored| stored.into_inner(lua))
}
#[inline]
pub async fn wait_for_item(&self) {
if self.queue.is_empty() {
let listener = self.event.listen();
// NOTE: Need to check again, we could have gotten
// new queued items while creating our listener
if self.queue.is_empty() {
listener.await;
}
}
}
#[inline]
pub fn is_empty(&self) -> bool {
self.queue.is_empty()
}
}
/**
Alias for [`ThreadQueue`], providing a newtype to store in Lua app data.
*/
#[derive(Debug, Clone, Deref, DerefMut)]
pub(crate) struct SpawnedThreadQueue(ThreadQueue);
impl SpawnedThreadQueue {
pub fn new() -> Self {
Self(ThreadQueue::new())
}
}
/**
Alias for [`ThreadQueue`], providing a newtype to store in Lua app data.
*/
#[derive(Debug, Clone, Deref, DerefMut)]
pub(crate) struct DeferredThreadQueue(ThreadQueue);
impl DeferredThreadQueue {
pub fn new() -> Self {
Self(ThreadQueue::new())
}
}
pub type LocalBoxFuture<'fut> = Pin<Box<dyn Future<Output = ()> + 'fut>>;
/**
Queue for storing local futures.
Provides methods for pushing and draining the queue, as
well as listening for new items being pushed to the queue.
*/
#[derive(Debug, Clone)]
pub(crate) struct FuturesQueue<'fut> {
queue: Rc<ConcurrentQueue<LocalBoxFuture<'fut>>>,
event: Rc<Event>,
}
impl<'fut> FuturesQueue<'fut> {
pub fn new() -> Self {
let queue = Rc::new(ConcurrentQueue::unbounded());
let event = Rc::new(Event::new());
Self { queue, event }
}
pub fn push_item(&self, fut: impl Future<Output = ()> + 'fut) {
let _ = self.queue.push(fut.boxed_local());
self.event.notify(usize::MAX);
}
pub fn drain_items<'outer>(
&'outer self,
) -> impl Iterator<Item = LocalBoxFuture<'fut>> + 'outer {
self.queue.try_iter()
}
pub async fn wait_for_item(&self) {
if self.queue.is_empty() {
self.event.listen().await;
}
}
}

View file

@ -0,0 +1,28 @@
use std::ops::{Deref, DerefMut};
use super::threads::ThreadQueue;
/**
Alias for [`ThreadQueue`], providing a newtype to store in Lua app data.
*/
#[derive(Debug, Clone)]
pub(crate) struct DeferredThreadQueue(ThreadQueue);
impl DeferredThreadQueue {
pub fn new() -> Self {
Self(ThreadQueue::new())
}
}
impl Deref for DeferredThreadQueue {
type Target = ThreadQueue;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for DeferredThreadQueue {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

View file

@ -0,0 +1,55 @@
use std::{cell::RefCell, mem, pin::Pin, rc::Rc};
use futures_lite::prelude::*;
use crate::events::MultiEvent;
pub type LocalBoxFuture<'fut> = Pin<Box<dyn Future<Output = ()> + 'fut>>;
struct FuturesQueueInner<'fut> {
queue: RefCell<Vec<LocalBoxFuture<'fut>>>,
event: MultiEvent,
}
impl FuturesQueueInner<'_> {
pub fn new() -> Self {
Self {
queue: RefCell::new(Vec::new()),
event: MultiEvent::new(),
}
}
}
/**
Queue for storing local futures.
Provides methods for pushing and draining the queue, as
well as listening for new items being pushed to the queue.
*/
#[derive(Clone)]
pub(crate) struct FuturesQueue<'fut> {
inner: Rc<FuturesQueueInner<'fut>>,
}
impl<'fut> FuturesQueue<'fut> {
pub fn new() -> Self {
let inner = Rc::new(FuturesQueueInner::new());
Self { inner }
}
pub fn push_item(&self, fut: impl Future<Output = ()> + 'fut) {
self.inner.queue.borrow_mut().push(fut.boxed_local());
self.inner.event.notify();
}
pub fn take_items(&self) -> Vec<LocalBoxFuture<'fut>> {
let mut queue = self.inner.queue.borrow_mut();
mem::take(&mut *queue)
}
pub async fn wait_for_item(&self) {
if self.inner.queue.borrow().is_empty() {
self.inner.event.listen().await;
}
}
}

View file

@ -0,0 +1,8 @@
mod deferred;
mod futures;
mod spawned;
mod threads;
pub(crate) use self::deferred::DeferredThreadQueue;
pub(crate) use self::futures::FuturesQueue;
pub(crate) use self::spawned::SpawnedThreadQueue;

View file

@ -0,0 +1,28 @@
use std::ops::{Deref, DerefMut};
use super::threads::ThreadQueue;
/**
Alias for [`ThreadQueue`], providing a newtype to store in Lua app data.
*/
#[derive(Debug, Clone)]
pub(crate) struct SpawnedThreadQueue(ThreadQueue);
impl SpawnedThreadQueue {
pub fn new() -> Self {
Self(ThreadQueue::new())
}
}
impl Deref for SpawnedThreadQueue {
type Target = ThreadQueue;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl DerefMut for SpawnedThreadQueue {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}

View file

@ -0,0 +1,78 @@
#![allow(clippy::inline_always)]
use std::{cell::RefCell, mem, rc::Rc};
use mlua::prelude::*;
use crate::{threads::ThreadId, traits::IntoLuaThread};
use crate::events::MultiEvent;
#[derive(Debug)]
struct ThreadQueueInner {
queue: RefCell<Vec<(LuaThread, LuaMultiValue)>>,
event: MultiEvent,
}
impl ThreadQueueInner {
fn new() -> Self {
Self {
queue: RefCell::new(Vec::new()),
event: MultiEvent::new(),
}
}
}
/**
Queue for storing [`LuaThread`]s with associated arguments.
Provides methods for pushing and draining the queue, as
well as listening for new items being pushed to the queue.
*/
#[derive(Debug, Clone)]
pub(crate) struct ThreadQueue {
inner: Rc<ThreadQueueInner>,
}
impl ThreadQueue {
pub fn new() -> Self {
let inner = Rc::new(ThreadQueueInner::new());
Self { inner }
}
pub fn push_item(
&self,
lua: &Lua,
thread: impl IntoLuaThread,
args: impl IntoLuaMulti,
) -> LuaResult<ThreadId> {
let thread = thread.into_lua_thread(lua)?;
let args = args.into_lua_multi(lua)?;
tracing::trace!("pushing item to queue with {} args", args.len());
let id = ThreadId::from(&thread);
self.inner.queue.borrow_mut().push((thread, args));
self.inner.event.notify();
Ok(id)
}
#[inline(always)]
pub fn take_items(&self) -> Vec<(LuaThread, LuaMultiValue)> {
let mut queue = self.inner.queue.borrow_mut();
mem::take(&mut *queue)
}
#[inline(always)]
pub async fn wait_for_item(&self) {
if self.inner.queue.borrow().is_empty() {
self.inner.event.listen().await;
}
}
#[inline(always)]
pub fn is_empty(&self) -> bool {
self.inner.queue.borrow().is_empty()
}
}

View file

@ -1,64 +0,0 @@
#![allow(clippy::inline_always)]
use std::{cell::RefCell, rc::Rc};
use event_listener::Event;
// NOTE: This is the hash algorithm that mlua also uses, so we
// are not adding any additional dependencies / bloat by using it.
use rustc_hash::{FxHashMap, FxHashSet};
use crate::{thread_id::ThreadId, util::ThreadResult};
#[derive(Clone)]
pub(crate) struct ThreadResultMap {
tracked: Rc<RefCell<FxHashSet<ThreadId>>>,
results: Rc<RefCell<FxHashMap<ThreadId, ThreadResult>>>,
events: Rc<RefCell<FxHashMap<ThreadId, Rc<Event>>>>,
}
impl ThreadResultMap {
pub fn new() -> Self {
Self {
tracked: Rc::new(RefCell::new(FxHashSet::default())),
results: Rc::new(RefCell::new(FxHashMap::default())),
events: Rc::new(RefCell::new(FxHashMap::default())),
}
}
#[inline(always)]
pub fn track(&self, id: ThreadId) {
self.tracked.borrow_mut().insert(id);
}
#[inline(always)]
pub fn is_tracked(&self, id: ThreadId) -> bool {
self.tracked.borrow().contains(&id)
}
pub fn insert(&self, id: ThreadId, result: ThreadResult) {
debug_assert!(self.is_tracked(id), "Thread must be tracked");
self.results.borrow_mut().insert(id, result);
if let Some(event) = self.events.borrow_mut().remove(&id) {
event.notify(usize::MAX);
}
}
pub async fn listen(&self, id: ThreadId) {
debug_assert!(self.is_tracked(id), "Thread must be tracked");
if !self.results.borrow().contains_key(&id) {
let listener = {
let mut events = self.events.borrow_mut();
let event = events.entry(id).or_insert_with(|| Rc::new(Event::new()));
event.listen()
};
listener.await;
}
}
pub fn remove(&self, id: ThreadId) -> Option<ThreadResult> {
let res = self.results.borrow_mut().remove(&id)?;
self.tracked.borrow_mut().remove(&id);
self.events.borrow_mut().remove(&id);
Some(res)
}
}

View file

@ -2,7 +2,7 @@
use std::{ use std::{
cell::Cell, cell::Cell,
rc::{Rc, Weak as WeakRc}, rc::Rc,
sync::{Arc, Weak as WeakArc}, sync::{Arc, Weak as WeakArc},
thread::panicking, thread::panicking,
}; };
@ -17,11 +17,10 @@ use crate::{
error_callback::ThreadErrorCallback, error_callback::ThreadErrorCallback,
exit::Exit, exit::Exit,
queue::{DeferredThreadQueue, FuturesQueue, SpawnedThreadQueue}, queue::{DeferredThreadQueue, FuturesQueue, SpawnedThreadQueue},
result_map::ThreadResultMap,
status::Status, status::Status,
thread_id::ThreadId, threads::{ThreadId, ThreadMap},
traits::IntoLuaThread, traits::IntoLuaThread,
util::{run_until_yield, ThreadResult}, util::run_until_yield,
}; };
const ERR_METADATA_ALREADY_ATTACHED: &str = "\ const ERR_METADATA_ALREADY_ATTACHED: &str = "\
@ -48,7 +47,7 @@ pub struct Scheduler {
queue_spawn: SpawnedThreadQueue, queue_spawn: SpawnedThreadQueue,
queue_defer: DeferredThreadQueue, queue_defer: DeferredThreadQueue,
error_callback: ThreadErrorCallback, error_callback: ThreadErrorCallback,
result_map: ThreadResultMap, thread_map: ThreadMap,
status: Rc<Cell<Status>>, status: Rc<Cell<Status>>,
exit: Exit, exit: Exit,
} }
@ -68,7 +67,7 @@ impl Scheduler {
let queue_spawn = SpawnedThreadQueue::new(); let queue_spawn = SpawnedThreadQueue::new();
let queue_defer = DeferredThreadQueue::new(); let queue_defer = DeferredThreadQueue::new();
let error_callback = ThreadErrorCallback::default(); let error_callback = ThreadErrorCallback::default();
let result_map = ThreadResultMap::new(); let result_map = ThreadMap::new();
let exit = Exit::new(); let exit = Exit::new();
assert!( assert!(
@ -84,7 +83,7 @@ impl Scheduler {
"{ERR_METADATA_ALREADY_ATTACHED}" "{ERR_METADATA_ALREADY_ATTACHED}"
); );
assert!( assert!(
lua.app_data_ref::<ThreadResultMap>().is_none(), lua.app_data_ref::<ThreadMap>().is_none(),
"{ERR_METADATA_ALREADY_ATTACHED}" "{ERR_METADATA_ALREADY_ATTACHED}"
); );
assert!( assert!(
@ -105,7 +104,7 @@ impl Scheduler {
queue_spawn, queue_spawn,
queue_defer, queue_defer,
error_callback, error_callback,
result_map, thread_map: result_map,
status, status,
exit, exit,
} }
@ -201,7 +200,7 @@ impl Scheduler {
args: impl IntoLuaMulti, args: impl IntoLuaMulti,
) -> LuaResult<ThreadId> { ) -> LuaResult<ThreadId> {
let id = self.queue_spawn.push_item(&self.lua, thread, args)?; let id = self.queue_spawn.push_item(&self.lua, thread, args)?;
self.result_map.track(id); self.thread_map.track(id);
Ok(id) Ok(id)
} }
@ -228,7 +227,7 @@ impl Scheduler {
args: impl IntoLuaMulti, args: impl IntoLuaMulti,
) -> LuaResult<ThreadId> { ) -> LuaResult<ThreadId> {
let id = self.queue_defer.push_item(&self.lua, thread, args)?; let id = self.queue_defer.push_item(&self.lua, thread, args)?;
self.result_map.track(id); self.thread_map.track(id);
Ok(id) Ok(id)
} }
@ -248,7 +247,7 @@ impl Scheduler {
*/ */
#[must_use] #[must_use]
pub fn get_thread_result(&self, id: ThreadId) -> Option<LuaResult<LuaMultiValue>> { pub fn get_thread_result(&self, id: ThreadId) -> Option<LuaResult<LuaMultiValue>> {
self.result_map.remove(id).map(|r| r.value(&self.lua)) self.thread_map.remove(id)
} }
/** /**
@ -257,7 +256,7 @@ impl Scheduler {
This will return instantly if the thread has already completed. This will return instantly if the thread has already completed.
*/ */
pub async fn wait_for_thread(&self, id: ThreadId) { pub async fn wait_for_thread(&self, id: ThreadId) {
self.result_map.listen(id).await; self.thread_map.listen(id).await;
} }
/** /**
@ -286,7 +285,7 @@ impl Scheduler {
*/ */
let local_exec = LocalExecutor::new(); let local_exec = LocalExecutor::new();
let main_exec = Arc::new(Executor::new()); let main_exec = Arc::new(Executor::new());
let fut_queue = Rc::new(FuturesQueue::new()); let fut_queue = FuturesQueue::new();
/* /*
Store the main executor and queue in Lua, so that they may be used with LuaSchedulerExt. Store the main executor and queue in Lua, so that they may be used with LuaSchedulerExt.
@ -299,12 +298,12 @@ impl Scheduler {
"{ERR_METADATA_ALREADY_ATTACHED}" "{ERR_METADATA_ALREADY_ATTACHED}"
); );
assert!( assert!(
self.lua.app_data_ref::<WeakRc<FuturesQueue>>().is_none(), self.lua.app_data_ref::<FuturesQueue>().is_none(),
"{ERR_METADATA_ALREADY_ATTACHED}" "{ERR_METADATA_ALREADY_ATTACHED}"
); );
self.lua.set_app_data(Arc::downgrade(&main_exec)); self.lua.set_app_data(Arc::downgrade(&main_exec));
self.lua.set_app_data(Rc::downgrade(&fut_queue.clone())); self.lua.set_app_data(fut_queue.clone());
/* /*
Manually tick the Lua executor, while running under the main executor. Manually tick the Lua executor, while running under the main executor.
@ -320,7 +319,7 @@ impl Scheduler {
when there are new Lua threads to enqueue and potentially more work to be done. when there are new Lua threads to enqueue and potentially more work to be done.
*/ */
let fut = async { let fut = async {
let result_map = self.result_map.clone(); let result_map = self.thread_map.clone();
let process_thread = |thread: LuaThread, args| { let process_thread = |thread: LuaThread, args| {
// NOTE: Thread may have been cancelled from Lua // NOTE: Thread may have been cancelled from Lua
// before we got here, so we need to check it again // before we got here, so we need to check it again
@ -342,8 +341,7 @@ impl Scheduler {
self.error_callback.call(e); self.error_callback.call(e);
} }
if thread.status() != LuaThreadStatus::Resumable { if thread.status() != LuaThreadStatus::Resumable {
let thread_res = ThreadResult::new(res, &self.lua); result_map_inner.unwrap().insert(id, res);
result_map_inner.unwrap().insert(id, thread_res);
} }
} }
} else { } else {
@ -398,21 +396,21 @@ impl Scheduler {
let mut num_futures = 0; let mut num_futures = 0;
{ {
let _span = trace_span!("Scheduler::drain_spawned").entered(); let _span = trace_span!("Scheduler::drain_spawned").entered();
for (thread, args) in self.queue_spawn.drain_items(&self.lua) { for (thread, args) in self.queue_spawn.take_items() {
process_thread(thread, args); process_thread(thread, args);
num_spawned += 1; num_spawned += 1;
} }
} }
{ {
let _span = trace_span!("Scheduler::drain_deferred").entered(); let _span = trace_span!("Scheduler::drain_deferred").entered();
for (thread, args) in self.queue_defer.drain_items(&self.lua) { for (thread, args) in self.queue_defer.take_items() {
process_thread(thread, args); process_thread(thread, args);
num_deferred += 1; num_deferred += 1;
} }
} }
{ {
let _span = trace_span!("Scheduler::drain_futures").entered(); let _span = trace_span!("Scheduler::drain_futures").entered();
for fut in fut_queue.drain_items() { for fut in fut_queue.take_items() {
local_exec.spawn(fut).detach(); local_exec.spawn(fut).detach();
num_futures += 1; num_futures += 1;
} }
@ -446,7 +444,7 @@ impl Scheduler {
.remove_app_data::<WeakArc<Executor>>() .remove_app_data::<WeakArc<Executor>>()
.expect(ERR_METADATA_REMOVED); .expect(ERR_METADATA_REMOVED);
self.lua self.lua
.remove_app_data::<WeakRc<FuturesQueue>>() .remove_app_data::<FuturesQueue>()
.expect(ERR_METADATA_REMOVED); .expect(ERR_METADATA_REMOVED);
} }
} }
@ -459,7 +457,7 @@ impl Drop for Scheduler {
self.lua.remove_app_data::<SpawnedThreadQueue>(); self.lua.remove_app_data::<SpawnedThreadQueue>();
self.lua.remove_app_data::<DeferredThreadQueue>(); self.lua.remove_app_data::<DeferredThreadQueue>();
self.lua.remove_app_data::<ThreadErrorCallback>(); self.lua.remove_app_data::<ThreadErrorCallback>();
self.lua.remove_app_data::<ThreadResultMap>(); self.lua.remove_app_data::<ThreadMap>();
self.lua.remove_app_data::<Exit>(); self.lua.remove_app_data::<Exit>();
} else { } else {
// In any other case we panic if metadata was removed incorrectly // In any other case we panic if metadata was removed incorrectly
@ -473,7 +471,7 @@ impl Drop for Scheduler {
.remove_app_data::<ThreadErrorCallback>() .remove_app_data::<ThreadErrorCallback>()
.expect(ERR_METADATA_REMOVED); .expect(ERR_METADATA_REMOVED);
self.lua self.lua
.remove_app_data::<ThreadResultMap>() .remove_app_data::<ThreadMap>()
.expect(ERR_METADATA_REMOVED); .expect(ERR_METADATA_REMOVED);
self.lua self.lua
.remove_app_data::<Exit>() .remove_app_data::<Exit>()

View file

@ -1,4 +1,7 @@
use std::hash::{Hash, Hasher}; use std::{
ffi::c_void,
hash::{Hash, Hasher},
};
use mlua::prelude::*; use mlua::prelude::*;
@ -12,13 +15,13 @@ use mlua::prelude::*;
*/ */
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ThreadId { pub struct ThreadId {
inner: usize, inner: *const c_void,
} }
impl From<&LuaThread> for ThreadId { impl From<&LuaThread> for ThreadId {
fn from(thread: &LuaThread) -> Self { fn from(thread: &LuaThread) -> Self {
Self { Self {
inner: thread.to_pointer() as usize, inner: thread.to_pointer(),
} }
} }
} }

View file

@ -0,0 +1,77 @@
#![allow(clippy::inline_always)]
use std::{cell::RefCell, rc::Rc};
use mlua::prelude::*;
use rustc_hash::FxHashMap;
use super::id::ThreadId;
use crate::events::OnceEvent;
struct ThreadEvent {
result: Option<LuaResult<LuaMultiValue>>,
event: OnceEvent,
}
impl ThreadEvent {
fn new() -> Self {
Self {
result: None,
event: OnceEvent::new(),
}
}
}
#[derive(Clone)]
pub(crate) struct ThreadMap {
inner: Rc<RefCell<FxHashMap<ThreadId, ThreadEvent>>>,
}
impl ThreadMap {
pub fn new() -> Self {
let inner = Rc::new(RefCell::new(FxHashMap::default()));
Self { inner }
}
#[inline(always)]
pub fn track(&self, id: ThreadId) {
self.inner.borrow_mut().insert(id, ThreadEvent::new());
}
#[inline(always)]
pub fn is_tracked(&self, id: ThreadId) -> bool {
self.inner.borrow().contains_key(&id)
}
#[inline(always)]
pub fn insert(&self, id: ThreadId, result: LuaResult<LuaMultiValue>) {
if let Some(tracker) = self.inner.borrow_mut().get_mut(&id) {
tracker.result.replace(result);
tracker.event.notify();
} else {
panic!("Thread must be tracked");
}
}
#[inline(always)]
pub async fn listen(&self, id: ThreadId) {
if let Some(listener) = {
let inner = self.inner.borrow();
let tracker = inner.get(&id);
tracker.map(|t| t.event.listen())
} {
listener.await;
} else {
panic!("Thread must be tracked");
}
}
#[inline(always)]
pub fn remove(&self, id: ThreadId) -> Option<LuaResult<LuaMultiValue>> {
if let Some(mut tracker) = self.inner.borrow_mut().remove(&id) {
tracker.result.take()
} else {
None
}
}
}

View file

@ -0,0 +1,5 @@
mod id;
mod map;
pub use id::ThreadId;
pub(crate) use map::ThreadMap;

View file

@ -12,9 +12,8 @@ use tracing::trace;
use crate::{ use crate::{
exit::Exit, exit::Exit,
queue::{DeferredThreadQueue, FuturesQueue, SpawnedThreadQueue}, queue::{DeferredThreadQueue, FuturesQueue, SpawnedThreadQueue},
result_map::ThreadResultMap,
scheduler::Scheduler, scheduler::Scheduler,
thread_id::ThreadId, threads::{ThreadId, ThreadMap},
}; };
/** /**
@ -314,21 +313,21 @@ impl LuaSchedulerExt for Lua {
fn track_thread(&self, id: ThreadId) { fn track_thread(&self, id: ThreadId) {
let map = self let map = self
.app_data_ref::<ThreadResultMap>() .app_data_ref::<ThreadMap>()
.expect("lua threads can only be tracked from within an active scheduler"); .expect("lua threads can only be tracked from within an active scheduler");
map.track(id); map.track(id);
} }
fn get_thread_result(&self, id: ThreadId) -> Option<LuaResult<LuaMultiValue>> { fn get_thread_result(&self, id: ThreadId) -> Option<LuaResult<LuaMultiValue>> {
let map = self let map = self
.app_data_ref::<ThreadResultMap>() .app_data_ref::<ThreadMap>()
.expect("lua threads results can only be retrieved from within an active scheduler"); .expect("lua threads results can only be retrieved from within an active scheduler");
map.remove(id).map(|r| r.value(self)) map.remove(id)
} }
fn wait_for_thread(&self, id: ThreadId) -> impl Future<Output = ()> { fn wait_for_thread(&self, id: ThreadId) -> impl Future<Output = ()> {
let map = self let map = self
.app_data_ref::<ThreadResultMap>() .app_data_ref::<ThreadMap>()
.expect("lua threads results can only be retrieved from within an active scheduler"); .expect("lua threads results can only be retrieved from within an active scheduler");
async move { map.listen(id).await } async move { map.listen(id).await }
} }
@ -354,10 +353,8 @@ impl LuaSpawnExt for Lua {
F: Future<Output = ()> + 'static, F: Future<Output = ()> + 'static,
{ {
let queue = self let queue = self
.app_data_ref::<WeakRc<FuturesQueue>>() .app_data_ref::<FuturesQueue>()
.expect("tasks can only be spawned within an active scheduler") .expect("tasks can only be spawned within an active scheduler");
.upgrade()
.expect("executor was dropped");
trace!("spawning local task on executor"); trace!("spawning local task on executor");
queue.push_item(fut); queue.push_item(fut);
} }

View file

@ -40,74 +40,6 @@ pub(crate) fn is_poll_pending(value: &LuaValue) -> bool {
.is_some_and(|l| l == Lua::poll_pending()) .is_some_and(|l| l == Lua::poll_pending())
} }
/**
Representation of a [`LuaResult`] with an associated [`LuaMultiValue`] currently stored in the Lua registry.
*/
#[derive(Debug)]
pub(crate) struct ThreadResult {
inner: LuaResult<LuaRegistryKey>,
}
impl ThreadResult {
pub fn new(result: LuaResult<LuaMultiValue>, lua: &Lua) -> Self {
Self {
inner: match result {
Ok(v) => Ok({
let vec = v.into_vec();
lua.create_registry_value(vec).expect("out of memory")
}),
Err(e) => Err(e),
},
}
}
pub fn value(self, lua: &Lua) -> LuaResult<LuaMultiValue> {
match self.inner {
Ok(key) => {
let vec = lua.registry_value(&key).unwrap();
lua.remove_registry_value(key).unwrap();
Ok(LuaMultiValue::from_vec(vec))
}
Err(e) => Err(e.clone()),
}
}
}
/**
Representation of a [`LuaThread`] with its associated arguments currently stored in the Lua registry.
*/
#[derive(Debug)]
pub(crate) struct ThreadWithArgs {
key_thread: LuaRegistryKey,
key_args: LuaRegistryKey,
}
impl ThreadWithArgs {
pub fn new(lua: &Lua, thread: LuaThread, args: LuaMultiValue) -> LuaResult<Self> {
let argsv = args.into_vec();
let key_thread = lua.create_registry_value(thread)?;
let key_args = lua.create_registry_value(argsv)?;
Ok(Self {
key_thread,
key_args,
})
}
pub fn into_inner(self, lua: &Lua) -> (LuaThread, LuaMultiValue) {
let thread = lua.registry_value(&self.key_thread).unwrap();
let argsv = lua.registry_value(&self.key_args).unwrap();
let args = LuaMultiValue::from_vec(argsv);
lua.remove_registry_value(self.key_thread).unwrap();
lua.remove_registry_value(self.key_args).unwrap();
(thread, args)
}
}
/** /**
Wrapper struct to accept either a Lua thread or a Lua function as function argument. Wrapper struct to accept either a Lua thread or a Lua function as function argument.

View file

@ -1,4 +1,5 @@
[tools] [tools]
luau-lsp = "JohnnyMorganz/luau-lsp@1.33.1" luau-lsp = "JohnnyMorganz/luau-lsp@1.44.1"
stylua = "JohnnyMorganz/StyLua@0.20.0" lune = "lune-org/lune@0.9.0"
just = "casey/just@1.36.0" stylua = "JohnnyMorganz/StyLua@2.1.0"
just = "casey/just@1.40.0"

View file

@ -0,0 +1,14 @@
local fs = require("@lune/fs")
fs.writeDir("./types")
for _, dir in fs.readDir("./crates") do
local std = string.match(dir, "^lune%-std%-(%w+)$")
if std ~= nil then
local from = `./crates/{dir}/types.d.luau`
if fs.isFile(from) then
local to = `./types/{std}.luau`
fs.copy(from, to, true)
end
end
end

View file

@ -0,0 +1,23 @@
local task = require("@lune/task")
local util = require("./util")
local pass = util.pass
-- These are some public APIs that have, or most likely have, different
-- certificate authorities (CAs), plus are both free to use and stable.
-- This should be enough to ensure that rustls is configured correctly.
local servers = {
"https://www.googleapis.com/discovery/v1/apis",
"https://api.cloudflare.com/client/v4/ips",
"https://azure.microsoft.com/en-us/updates/feed/",
"https://acme-v02.api.letsencrypt.org/directory",
"https://ip-ranges.amazonaws.com/ip-ranges.json",
"https://en.wikipedia.org/w/api.php",
"https://status.godaddy.com/api/v2/summary.json",
}
for _, server in servers do
task.spawn(function()
pass("GET", server, server)
end)
end

View file

@ -4,22 +4,38 @@ local stdio = require("@lune/stdio")
local util = {} local util = {}
function util.pass(method, url, message) function util.pass(method, url, message)
local response = net.request({ local success, response = pcall(net.request, {
method = method, method = method,
url = url, url = url,
}) })
if not response.ok then if not success then
error(string.format("%s failed!\nResponse: %s", message, stdio.format(response))) error(`{message} errored!\nError message: {tostring(response)}`)
elseif not response.ok then
error(
`{message} failed, but should have passed!`
.. `\nStatus code: {response.statusCode}`
.. `\nStatus message: {response.statusMessage}`
.. `\nResponse headers: {stdio.format(response.headers)}`
.. `\nResponse body: {response.body}`
)
end end
end end
function util.fail(method, url, message) function util.fail(method, url, message)
local response = net.request({ local success, response = pcall(net.request, {
method = method, method = method,
url = url, url = url,
}) })
if response.ok then if not success then
error(string.format("%s passed!\nResponse: %s", message, stdio.format(response))) error(`{message} errored!\nError message: {tostring(response)}`)
elseif response.ok then
error(
`{message} passed, but should have failed!`
.. `\nStatus code: {response.statusCode}`
.. `\nStatus message: {response.statusMessage}`
.. `\nResponse headers: {stdio.format(response.headers)}`
.. `\nResponse body: {response.body}`
)
end end
end end

View file

@ -0,0 +1,35 @@
local net = require("@lune/net")
local PORT = 8081
local LOCALHOST = "http://localhost"
local BROADCAST = `http://0.0.0.0`
local RESPONSE = "Hello, lune!"
-- Serve should be able to bind to broadcast IP addresse
local handle = net.serve(PORT, {
address = BROADCAST,
handleRequest = function(request)
return `Response from {BROADCAST}:{PORT}`
end,
})
-- And any requests to localhost should then succeed
local response = net.request(`{LOCALHOST}:{PORT}`).body
assert(response ~= nil, "Invalid response from server")
handle.stop()
-- Attempting to serve with a malformed IP address should throw an error
local success = pcall(function()
net.serve(8080, {
address = "a.b.c.d",
handleRequest = function()
return RESPONSE
end,
})
end)
assert(not success, "Server was created with malformed address")

View file

@ -0,0 +1,51 @@
local net = require("@lune/net")
local task = require("@lune/task")
local PORT = 8082
local URL = `http://127.0.0.1:{PORT}`
local RESPONSE = "Hello, lune!"
local handle = net.serve(PORT, function(request)
return RESPONSE
end)
-- Stopping is not guaranteed to happen instantly since it is async, but
-- it should happen on the next yield, so we wait the minimum amount here
handle.stop()
task.wait()
-- Sending a request to the stopped server should now error
local success, response2 = pcall(net.request, URL)
if not success then
local message = tostring(response2)
assert(
string.find(message, "Connection reset")
or string.find(message, "Connection closed")
or string.find(message, "Connection refused")
or string.find(message, "No connection could be made"), -- Windows Request Error
"Server did not stop responding to requests"
)
else
assert(not response2.ok, "Server did not stop responding to requests")
end
--[[
Trying to *stop* the server again should error, and
also mention that the server has already been stopped
Note that we cast pcall to any because of a
Luau limitation where it throws a type error for
`err` because handle.stop doesn't return any value
]]
local success2, err = (pcall :: any)(handle.stop)
assert(not success2, "Calling stop twice on the net serve handle should error")
local message = tostring(err)
assert(
string.find(message, "stop")
or string.find(message, "shutdown")
or string.find(message, "shut down"),
"The error message for calling stop twice on the net serve handle should be descriptive"
)

View file

@ -0,0 +1,24 @@
local net = require("@lune/net")
local process = require("@lune/process")
local stdio = require("@lune/stdio")
local task = require("@lune/task")
local PORT = 8083
local RESPONSE = "Hello, lune!"
-- Serve should not yield the entire main thread forever, only
-- for the initial binding to socket which should be very fast
local thread = task.delay(1, function()
stdio.ewrite("Serve must not yield the current thread for too long\n")
task.wait(1)
process.exit(1)
end)
local handle = net.serve(PORT, function(request)
return RESPONSE
end)
task.cancel(thread)
handle.stop()

Some files were not shown because too many files have changed in this diff Show more