mirror of
https://github.com/lune-org/lune.git
synced 2025-04-04 10:30:54 +01:00
merge: 'main' -> feature/std-buffer
This commit is contained in:
commit
652e71b60f
49 changed files with 1679 additions and 2650 deletions
32
CHANGELOG.md
32
CHANGELOG.md
|
@ -8,6 +8,38 @@ All notable changes to this project will be documented in this file.
|
|||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## `0.8.2` - March 12th, 2024
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed REPL panicking after the first evaluation / run.
|
||||
- Fixed globals reloading on each run in the REPL, causing unnecessary slowdowns.
|
||||
- Fixed `net.serve` requests no longer being plain tables in Lune `0.8.1`, breaking usage of things such as `table.clone`.
|
||||
|
||||
## `0.8.1` - March 11th, 2024
|
||||
|
||||
### Added
|
||||
|
||||
- Added the ability to specify an address in `net.serve`. ([#142])
|
||||
|
||||
### Changed
|
||||
|
||||
- Update to Luau version `0.616`.
|
||||
- Major performance improvements when using a large amount of threads / asynchronous Lune APIs. ([#165])
|
||||
- Minor performance improvements and less overhead for `net.serve` and `net.socket`. ([#165])
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed `fs.copy` not working with empty dirs. ([#155])
|
||||
- Fixed stack overflow when printing tables with cyclic references. ([#158])
|
||||
- Fixed not being able to yield in `net.serve` handlers without blocking other requests. ([#165])
|
||||
- Fixed various scheduler issues / panics. ([#165])
|
||||
|
||||
[#142]: https://github.com/lune-org/lune/pull/142
|
||||
[#155]: https://github.com/lune-org/lune/pull/155
|
||||
[#158]: https://github.com/lune-org/lune/pull/158
|
||||
[#165]: https://github.com/lune-org/lune/pull/165
|
||||
|
||||
## `0.8.0` - January 14th, 2024
|
||||
|
||||
### Breaking Changes
|
||||
|
|
988
Cargo.lock
generated
988
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
31
Cargo.toml
31
Cargo.toml
|
@ -1,6 +1,6 @@
|
|||
[package]
|
||||
name = "lune"
|
||||
version = "0.8.0"
|
||||
version = "0.8.2"
|
||||
edition = "2021"
|
||||
license = "MPL-2.0"
|
||||
repository = "https://github.com/lune-org/lune"
|
||||
|
@ -79,11 +79,19 @@ urlencoding = "2.1"
|
|||
|
||||
### RUNTIME
|
||||
|
||||
blocking = "1.5"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||
mlua = { version = "0.9.1", features = ["luau", "luau-jit", "serialize"] }
|
||||
tokio = { version = "1.24", features = ["full", "tracing"] }
|
||||
os_str_bytes = { version = "6.4", features = ["conversions"] }
|
||||
os_str_bytes = { version = "7.0", features = ["conversions"] }
|
||||
|
||||
mlua-luau-scheduler = { version = "0.0.2" }
|
||||
mlua = { version = "0.9.6", features = [
|
||||
"luau",
|
||||
"luau-jit",
|
||||
"async",
|
||||
"serialize",
|
||||
] }
|
||||
|
||||
### SERDE
|
||||
|
||||
|
@ -101,21 +109,26 @@ toml = { version = "0.8", features = ["preserve_order"] }
|
|||
|
||||
### NET
|
||||
|
||||
hyper = { version = "0.14", features = ["full"] }
|
||||
hyper-tungstenite = { version = "0.11" }
|
||||
hyper = { version = "1.1", features = ["full"] }
|
||||
hyper-util = { version = "0.1", features = ["full"] }
|
||||
http = "1.0"
|
||||
http-body-util = { version = "0.1" }
|
||||
hyper-tungstenite = { version = "0.13" }
|
||||
|
||||
reqwest = { version = "0.11", default-features = false, features = [
|
||||
"rustls-tls",
|
||||
] }
|
||||
tokio-tungstenite = { version = "0.20", features = ["rustls-tls-webpki-roots"] }
|
||||
|
||||
tokio-tungstenite = { version = "0.21", features = ["rustls-tls-webpki-roots"] }
|
||||
|
||||
### DATETIME
|
||||
chrono = "0.4"
|
||||
chrono = "=0.4.34" # NOTE: 0.4.35 does not compile with chrono_lc
|
||||
chrono_lc = "0.1"
|
||||
|
||||
### CLI
|
||||
|
||||
anyhow = { optional = true, version = "1.0" }
|
||||
env_logger = { optional = true, version = "0.10" }
|
||||
env_logger = { optional = true, version = "0.11" }
|
||||
itertools = { optional = true, version = "0.12" }
|
||||
clap = { optional = true, version = "4.1", features = ["derive"] }
|
||||
include_dir = { optional = true, version = "0.7", features = ["glob"] }
|
||||
|
@ -123,7 +136,7 @@ regex = { optional = true, version = "1.7", default-features = false, features =
|
|||
"std",
|
||||
"unicode-perl",
|
||||
] }
|
||||
rustyline = { optional = true, version = "13.0" }
|
||||
rustyline = { optional = true, version = "14.0" }
|
||||
|
||||
### ROBLOX
|
||||
|
||||
|
|
|
@ -15,7 +15,7 @@ use copy::copy;
|
|||
use metadata::FsMetadata;
|
||||
use options::FsWriteOptions;
|
||||
|
||||
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
|
||||
pub fn create(lua: &Lua) -> LuaResult<LuaTable> {
|
||||
TableBuilder::new(lua)?
|
||||
.with_async_function("readFile", fs_read_file)?
|
||||
.with_async_function("readDir", fs_read_dir)?
|
||||
|
|
|
@ -28,10 +28,7 @@ pub enum LuneBuiltin {
|
|||
Roblox,
|
||||
}
|
||||
|
||||
impl<'lua> LuneBuiltin
|
||||
where
|
||||
'lua: 'static, // FIXME: Remove static lifetime bound here when builtin libraries no longer need it
|
||||
{
|
||||
impl LuneBuiltin {
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
Self::DateTime => "datetime",
|
||||
|
@ -47,7 +44,7 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
pub fn create(&self, lua: &'lua Lua) -> LuaResult<LuaMultiValue<'lua>> {
|
||||
pub fn create<'lua>(&self, lua: &'lua Lua) -> LuaResult<LuaMultiValue<'lua>> {
|
||||
let res = match self {
|
||||
Self::DateTime => datetime::create(lua),
|
||||
Self::Fs => fs::create(lua),
|
||||
|
|
|
@ -2,8 +2,14 @@ use std::str::FromStr;
|
|||
|
||||
use mlua::prelude::*;
|
||||
|
||||
use hyper::{header::HeaderName, http::HeaderValue, HeaderMap};
|
||||
use reqwest::{IntoUrl, Method, RequestBuilder};
|
||||
use reqwest::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_ENCODING};
|
||||
|
||||
use crate::lune::{
|
||||
builtins::serde::compress_decompress::{decompress, CompressDecompressFormat},
|
||||
util::TableBuilder,
|
||||
};
|
||||
|
||||
use super::{config::RequestConfig, util::header_map_to_table};
|
||||
|
||||
const REGISTRY_KEY: &str = "NetClient";
|
||||
|
||||
|
@ -35,16 +41,19 @@ impl NetClientBuilder {
|
|||
|
||||
pub fn build(self) -> LuaResult<NetClient> {
|
||||
let client = self.builder.build().into_lua_err()?;
|
||||
Ok(NetClient(client))
|
||||
Ok(NetClient { inner: client })
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct NetClient(reqwest::Client);
|
||||
pub struct NetClient {
|
||||
inner: reqwest::Client,
|
||||
}
|
||||
|
||||
impl NetClient {
|
||||
pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder {
|
||||
self.0.request(method, url)
|
||||
pub fn from_registry(lua: &Lua) -> Self {
|
||||
lua.named_registry_value(REGISTRY_KEY)
|
||||
.expect("Failed to get NetClient from lua registry")
|
||||
}
|
||||
|
||||
pub fn into_registry(self, lua: &Lua) {
|
||||
|
@ -52,16 +61,68 @@ impl NetClient {
|
|||
.expect("Failed to store NetClient in lua registry");
|
||||
}
|
||||
|
||||
pub fn from_registry(lua: &Lua) -> Self {
|
||||
lua.named_registry_value(REGISTRY_KEY)
|
||||
.expect("Failed to get NetClient from lua registry")
|
||||
pub async fn request(&self, config: RequestConfig) -> LuaResult<NetClientResponse> {
|
||||
// Create and send the request
|
||||
let mut request = self.inner.request(config.method, config.url);
|
||||
for (query, values) in config.query {
|
||||
request = request.query(
|
||||
&values
|
||||
.iter()
|
||||
.map(|v| (query.as_str(), v))
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
}
|
||||
for (header, values) in config.headers {
|
||||
for value in values {
|
||||
request = request.header(header.as_str(), value);
|
||||
}
|
||||
}
|
||||
let res = request
|
||||
.body(config.body.unwrap_or_default())
|
||||
.send()
|
||||
.await
|
||||
.into_lua_err()?;
|
||||
|
||||
// Extract status, headers
|
||||
let res_status = res.status().as_u16();
|
||||
let res_status_text = res.status().canonical_reason();
|
||||
let res_headers = res.headers().clone();
|
||||
|
||||
// Read response bytes
|
||||
let mut res_bytes = res.bytes().await.into_lua_err()?.to_vec();
|
||||
let mut res_decompressed = false;
|
||||
|
||||
// Check for extra options, decompression
|
||||
if config.options.decompress {
|
||||
let decompress_format = res_headers
|
||||
.iter()
|
||||
.find(|(name, _)| {
|
||||
name.as_str()
|
||||
.eq_ignore_ascii_case(CONTENT_ENCODING.as_str())
|
||||
})
|
||||
.and_then(|(_, value)| value.to_str().ok())
|
||||
.and_then(CompressDecompressFormat::detect_from_header_str);
|
||||
if let Some(format) = decompress_format {
|
||||
res_bytes = decompress(format, res_bytes).await?;
|
||||
res_decompressed = true;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(NetClientResponse {
|
||||
ok: (200..300).contains(&res_status),
|
||||
status_code: res_status,
|
||||
status_message: res_status_text.unwrap_or_default().to_string(),
|
||||
headers: res_headers,
|
||||
body: res_bytes,
|
||||
body_decompressed: res_decompressed,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl LuaUserData for NetClient {}
|
||||
|
||||
impl<'lua> FromLua<'lua> for NetClient {
|
||||
fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult<Self> {
|
||||
impl FromLua<'_> for NetClient {
|
||||
fn from_lua(value: LuaValue, _: &Lua) -> LuaResult<Self> {
|
||||
if let LuaValue::UserData(ud) = value {
|
||||
if let Ok(ctx) = ud.borrow::<NetClient>() {
|
||||
return Ok(ctx.clone());
|
||||
|
@ -71,10 +132,34 @@ impl<'lua> FromLua<'lua> for NetClient {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'lua> From<&'lua Lua> for NetClient {
|
||||
fn from(value: &'lua Lua) -> Self {
|
||||
impl From<&Lua> for NetClient {
|
||||
fn from(value: &Lua) -> Self {
|
||||
value
|
||||
.named_registry_value(REGISTRY_KEY)
|
||||
.expect("Missing require context in lua registry")
|
||||
}
|
||||
}
|
||||
|
||||
pub struct NetClientResponse {
|
||||
ok: bool,
|
||||
status_code: u16,
|
||||
status_message: String,
|
||||
headers: HeaderMap,
|
||||
body: Vec<u8>,
|
||||
body_decompressed: bool,
|
||||
}
|
||||
|
||||
impl NetClientResponse {
|
||||
pub fn into_lua_table(self, lua: &Lua) -> LuaResult<LuaTable> {
|
||||
TableBuilder::new(lua)?
|
||||
.with_value("ok", self.ok)?
|
||||
.with_value("statusCode", self.status_code)?
|
||||
.with_value("statusMessage", self.status_message)?
|
||||
.with_value(
|
||||
"headers",
|
||||
header_map_to_table(lua, self.headers, self.body_decompressed)?,
|
||||
)?
|
||||
.with_value("body", lua.create_string(&self.body)?)?
|
||||
.build_readonly()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use std::collections::HashMap;
|
||||
use std::{collections::HashMap, net::Ipv4Addr};
|
||||
|
||||
use mlua::prelude::*;
|
||||
|
||||
|
@ -8,6 +8,18 @@ use crate::lune::util::buffer::buf_to_str;
|
|||
|
||||
use super::util::table_to_hash_map;
|
||||
|
||||
const DEFAULT_IP_ADDRESS: Ipv4Addr = Ipv4Addr::new(127, 0, 0, 1);
|
||||
|
||||
const WEB_SOCKET_UPDGRADE_REQUEST_HANDLER: &str = r#"
|
||||
return {
|
||||
status = 426,
|
||||
body = "Upgrade Required",
|
||||
headers = {
|
||||
Upgrade = "websocket",
|
||||
},
|
||||
}
|
||||
"#;
|
||||
|
||||
// Net request config
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
@ -23,28 +35,29 @@ impl Default for RequestConfigOptions {
|
|||
|
||||
impl<'lua> FromLua<'lua> for RequestConfigOptions {
|
||||
fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult<Self> {
|
||||
// Nil means default options, table means custom options
|
||||
if let LuaValue::Nil = value {
|
||||
return Ok(Self::default());
|
||||
// Nil means default options
|
||||
Ok(Self::default())
|
||||
} else if let LuaValue::Table(tab) = value {
|
||||
// Extract flags
|
||||
let decompress = match tab.raw_get::<_, Option<bool>>("decompress") {
|
||||
// Table means custom options
|
||||
let decompress = match tab.get::<_, Option<bool>>("decompress") {
|
||||
Ok(decomp) => Ok(decomp.unwrap_or(true)),
|
||||
Err(_) => Err(LuaError::RuntimeError(
|
||||
"Invalid option value for 'decompress' in request config options".to_string(),
|
||||
)),
|
||||
}?;
|
||||
return Ok(Self { decompress });
|
||||
Ok(Self { decompress })
|
||||
} else {
|
||||
// Anything else is invalid
|
||||
Err(LuaError::FromLuaConversionError {
|
||||
from: value.type_name(),
|
||||
to: "RequestConfigOptions",
|
||||
message: Some(format!(
|
||||
"Invalid request config options - expected table or nil, got {}",
|
||||
value.type_name()
|
||||
)),
|
||||
})
|
||||
}
|
||||
// Anything else is invalid
|
||||
Err(LuaError::FromLuaConversionError {
|
||||
from: value.type_name(),
|
||||
to: "RequestConfigOptions",
|
||||
message: Some(format!(
|
||||
"Invalid request config options - expected table or nil, got {}",
|
||||
value.type_name()
|
||||
)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -62,46 +75,40 @@ impl FromLua<'_> for RequestConfig {
|
|||
fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult<Self> {
|
||||
// If we just got a string we assume its a GET request to a given url
|
||||
if let LuaValue::String(s) = value {
|
||||
return Ok(Self {
|
||||
Ok(Self {
|
||||
url: s.to_string_lossy().to_string(),
|
||||
method: Method::GET,
|
||||
query: HashMap::new(),
|
||||
headers: HashMap::new(),
|
||||
body: None,
|
||||
options: Default::default(),
|
||||
});
|
||||
}
|
||||
// If we got a table we are able to configure the entire request
|
||||
if let LuaValue::Table(tab) = value {
|
||||
})
|
||||
} else if let LuaValue::Table(tab) = value {
|
||||
// If we got a table we are able to configure the entire request
|
||||
// Extract url
|
||||
let url = match tab.raw_get::<_, LuaString>("url") {
|
||||
let url = match tab.get::<_, LuaString>("url") {
|
||||
Ok(config_url) => Ok(config_url.to_string_lossy().to_string()),
|
||||
Err(_) => Err(LuaError::runtime("Missing 'url' in request config")),
|
||||
}?;
|
||||
// Extract method
|
||||
let method = match tab.raw_get::<_, LuaString>("method") {
|
||||
let method = match tab.get::<_, LuaString>("method") {
|
||||
Ok(config_method) => config_method.to_string_lossy().trim().to_ascii_uppercase(),
|
||||
Err(_) => "GET".to_string(),
|
||||
};
|
||||
// Extract query
|
||||
let query = match tab.raw_get::<_, LuaTable>("query") {
|
||||
let query = match tab.get::<_, LuaTable>("query") {
|
||||
Ok(tab) => table_to_hash_map(tab, "query")?,
|
||||
Err(_) => HashMap::new(),
|
||||
};
|
||||
// Extract headers
|
||||
let headers = match tab.raw_get::<_, LuaTable>("headers") {
|
||||
let headers = match tab.get::<_, LuaTable>("headers") {
|
||||
Ok(tab) => table_to_hash_map(tab, "headers")?,
|
||||
Err(_) => HashMap::new(),
|
||||
};
|
||||
// Extract body
|
||||
let body = match tab.raw_get::<_, LuaValue>("body")? {
|
||||
LuaValue::String(str) => Some(str.as_bytes().to_vec()),
|
||||
LuaValue::UserData(inner) => Some(
|
||||
buf_to_str(lua, LuaValue::UserData(inner))?
|
||||
.as_bytes()
|
||||
.to_vec(),
|
||||
),
|
||||
_ => None,
|
||||
let body = match tab.get::<_, LuaString>("body") {
|
||||
Ok(config_body) => Some(config_body.as_bytes().to_owned()),
|
||||
Err(_) => None,
|
||||
};
|
||||
|
||||
// Convert method string into proper enum
|
||||
|
@ -120,29 +127,30 @@ impl FromLua<'_> for RequestConfig {
|
|||
))),
|
||||
}?;
|
||||
// Parse any extra options given
|
||||
let options = match tab.raw_get::<_, LuaValue>("options") {
|
||||
let options = match tab.get::<_, LuaValue>("options") {
|
||||
Ok(opts) => RequestConfigOptions::from_lua(opts, lua)?,
|
||||
Err(_) => RequestConfigOptions::default(),
|
||||
};
|
||||
// All good, validated and we got what we need
|
||||
return Ok(Self {
|
||||
Ok(Self {
|
||||
url,
|
||||
method,
|
||||
query,
|
||||
headers,
|
||||
body,
|
||||
options,
|
||||
});
|
||||
};
|
||||
// Anything else is invalid
|
||||
Err(LuaError::FromLuaConversionError {
|
||||
from: value.type_name(),
|
||||
to: "RequestConfig",
|
||||
message: Some(format!(
|
||||
"Invalid request config - expected string or table, got {}",
|
||||
value.type_name()
|
||||
)),
|
||||
})
|
||||
})
|
||||
} else {
|
||||
// Anything else is invalid
|
||||
Err(LuaError::FromLuaConversionError {
|
||||
from: value.type_name(),
|
||||
to: "RequestConfig",
|
||||
message: Some(format!(
|
||||
"Invalid request config - expected string or table, got {}",
|
||||
value.type_name()
|
||||
)),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -150,54 +158,72 @@ impl FromLua<'_> for RequestConfig {
|
|||
|
||||
#[derive(Debug)]
|
||||
pub struct ServeConfig<'a> {
|
||||
pub address: Ipv4Addr,
|
||||
pub handle_request: LuaFunction<'a>,
|
||||
pub handle_web_socket: Option<LuaFunction<'a>>,
|
||||
pub address: Option<LuaString<'a>>,
|
||||
}
|
||||
|
||||
impl<'lua> FromLua<'lua> for ServeConfig<'lua> {
|
||||
fn from_lua(value: LuaValue<'lua>, lua: &'lua Lua) -> LuaResult<Self> {
|
||||
let message = match &value {
|
||||
LuaValue::Function(f) => {
|
||||
return Ok(ServeConfig {
|
||||
handle_request: f.clone(),
|
||||
handle_web_socket: None,
|
||||
address: None,
|
||||
if let LuaValue::Function(f) = &value {
|
||||
// Single function = request handler, rest is default
|
||||
Ok(ServeConfig {
|
||||
handle_request: f.clone(),
|
||||
handle_web_socket: None,
|
||||
address: DEFAULT_IP_ADDRESS,
|
||||
})
|
||||
} else if let LuaValue::Table(t) = &value {
|
||||
// Table means custom options
|
||||
let address: Option<LuaString> = t.get("address")?;
|
||||
let handle_request: Option<LuaFunction> = t.get("handleRequest")?;
|
||||
let handle_web_socket: Option<LuaFunction> = t.get("handleWebSocket")?;
|
||||
if handle_request.is_some() || handle_web_socket.is_some() {
|
||||
let address: Ipv4Addr = match &address {
|
||||
Some(addr) => {
|
||||
let addr_str = addr.to_str()?;
|
||||
|
||||
addr_str
|
||||
.trim_start_matches("http://")
|
||||
.trim_start_matches("https://")
|
||||
.parse()
|
||||
.map_err(|_e| LuaError::FromLuaConversionError {
|
||||
from: value.type_name(),
|
||||
to: "ServeConfig",
|
||||
message: Some(format!(
|
||||
"IP address format is incorrect - \
|
||||
expected an IP in the form 'http://0.0.0.0' or '0.0.0.0', \
|
||||
got '{addr_str}'"
|
||||
)),
|
||||
})?
|
||||
}
|
||||
None => DEFAULT_IP_ADDRESS,
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
address,
|
||||
handle_request: handle_request.unwrap_or_else(|| {
|
||||
lua.load(WEB_SOCKET_UPDGRADE_REQUEST_HANDLER)
|
||||
.into_function()
|
||||
.expect("Failed to create default http responder function")
|
||||
}),
|
||||
handle_web_socket,
|
||||
})
|
||||
} else {
|
||||
Err(LuaError::FromLuaConversionError {
|
||||
from: value.type_name(),
|
||||
to: "ServeConfig",
|
||||
message: Some(String::from(
|
||||
"Invalid serve config - expected table with 'handleRequest' or 'handleWebSocket' function",
|
||||
)),
|
||||
})
|
||||
}
|
||||
LuaValue::Table(t) => {
|
||||
let handle_request: Option<LuaFunction> = t.raw_get("handleRequest")?;
|
||||
let handle_web_socket: Option<LuaFunction> = t.raw_get("handleWebSocket")?;
|
||||
let address: Option<LuaString> = t.raw_get("address")?;
|
||||
if handle_request.is_some() || handle_web_socket.is_some() {
|
||||
return Ok(ServeConfig {
|
||||
handle_request: handle_request.unwrap_or_else(|| {
|
||||
let chunk = r#"
|
||||
return {
|
||||
status = 426,
|
||||
body = "Upgrade Required",
|
||||
headers = {
|
||||
Upgrade = "websocket",
|
||||
},
|
||||
}
|
||||
"#;
|
||||
lua.load(chunk)
|
||||
.into_function()
|
||||
.expect("Failed to create default http responder function")
|
||||
}),
|
||||
handle_web_socket,
|
||||
address,
|
||||
});
|
||||
} else {
|
||||
Some("Missing handleRequest and / or handleWebSocket".to_string())
|
||||
}
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
Err(LuaError::FromLuaConversionError {
|
||||
from: value.type_name(),
|
||||
to: "ServeConfig",
|
||||
message,
|
||||
})
|
||||
} else {
|
||||
// Anything else is invalid
|
||||
Err(LuaError::FromLuaConversionError {
|
||||
from: value.type_name(),
|
||||
to: "ServeConfig",
|
||||
message: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,39 +1,27 @@
|
|||
use std::net::Ipv4Addr;
|
||||
#![allow(unused_variables)]
|
||||
|
||||
use mlua::prelude::*;
|
||||
|
||||
use hyper::header::CONTENT_ENCODING;
|
||||
|
||||
use crate::lune::{
|
||||
scheduler::Scheduler,
|
||||
util::{buffer::create_lua_buffer, TableBuilder},
|
||||
};
|
||||
|
||||
use self::{
|
||||
server::{bind_to_addr, create_server},
|
||||
util::header_map_to_table,
|
||||
};
|
||||
|
||||
use super::serde::{
|
||||
compress_decompress::{decompress, CompressDecompressFormat},
|
||||
encode_decode::{EncodeDecodeConfig, EncodeDecodeFormat},
|
||||
};
|
||||
use mlua_luau_scheduler::LuaSpawnExt;
|
||||
|
||||
mod client;
|
||||
mod config;
|
||||
mod processing;
|
||||
mod response;
|
||||
mod server;
|
||||
mod util;
|
||||
mod websocket;
|
||||
|
||||
use client::{NetClient, NetClientBuilder};
|
||||
use config::{RequestConfig, ServeConfig};
|
||||
use websocket::NetWebSocket;
|
||||
use crate::lune::util::TableBuilder;
|
||||
|
||||
const DEFAULT_IP_ADDRESS: Ipv4Addr = Ipv4Addr::new(127, 0, 0, 1);
|
||||
use self::{
|
||||
client::{NetClient, NetClientBuilder},
|
||||
config::{RequestConfig, ServeConfig},
|
||||
server::serve,
|
||||
util::create_user_agent_header,
|
||||
websocket::NetWebSocket,
|
||||
};
|
||||
|
||||
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
|
||||
use super::serde::encode_decode::{EncodeDecodeConfig, EncodeDecodeFormat};
|
||||
|
||||
pub fn create(lua: &Lua) -> LuaResult<LuaTable> {
|
||||
NetClientBuilder::new()
|
||||
.headers(&[("User-Agent", create_user_agent_header())])?
|
||||
.build()?
|
||||
|
@ -49,14 +37,6 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
|
|||
.build_readonly()
|
||||
}
|
||||
|
||||
fn create_user_agent_header() -> String {
|
||||
let (github_owner, github_repo) = env!("CARGO_PKG_REPOSITORY")
|
||||
.trim_start_matches("https://github.com/")
|
||||
.split_once('/')
|
||||
.unwrap();
|
||||
format!("{github_owner}-{github_repo}-cli")
|
||||
}
|
||||
|
||||
fn net_json_encode<'lua>(
|
||||
lua: &'lua Lua,
|
||||
(val, pretty): (LuaValue<'lua>, Option<bool>),
|
||||
|
@ -69,68 +49,14 @@ fn net_json_decode<'lua>(lua: &'lua Lua, json: LuaString<'lua>) -> LuaResult<Lua
|
|||
EncodeDecodeConfig::from(EncodeDecodeFormat::Json).deserialize_from_string(lua, json)
|
||||
}
|
||||
|
||||
async fn net_request<'lua>(lua: &'lua Lua, config: RequestConfig) -> LuaResult<LuaTable<'lua>>
|
||||
where
|
||||
'lua: 'static, // FIXME: Get rid of static lifetime bound here
|
||||
{
|
||||
// Create and send the request
|
||||
async fn net_request(lua: &Lua, config: RequestConfig) -> LuaResult<LuaTable> {
|
||||
let client = NetClient::from_registry(lua);
|
||||
let mut request = client.request(config.method, &config.url);
|
||||
for (query, values) in config.query {
|
||||
request = request.query(
|
||||
&values
|
||||
.iter()
|
||||
.map(|v| (query.as_str(), v))
|
||||
.collect::<Vec<_>>(),
|
||||
);
|
||||
}
|
||||
for (header, values) in config.headers {
|
||||
for value in values {
|
||||
request = request.header(header.as_str(), value);
|
||||
}
|
||||
}
|
||||
let res = request
|
||||
.body(config.body.unwrap_or_default())
|
||||
.send()
|
||||
.await
|
||||
.into_lua_err()?;
|
||||
// Extract status, headers
|
||||
let res_status = res.status().as_u16();
|
||||
let res_status_text = res.status().canonical_reason();
|
||||
let res_headers = res.headers().clone();
|
||||
// Read response bytes
|
||||
let mut res_bytes = res.bytes().await.into_lua_err()?.to_vec();
|
||||
let mut res_decompressed = false;
|
||||
// Check for extra options, decompression
|
||||
if config.options.decompress {
|
||||
let decompress_format = res_headers
|
||||
.iter()
|
||||
.find(|(name, _)| {
|
||||
name.as_str()
|
||||
.eq_ignore_ascii_case(CONTENT_ENCODING.as_str())
|
||||
})
|
||||
.and_then(|(_, value)| value.to_str().ok())
|
||||
.and_then(CompressDecompressFormat::detect_from_header_str);
|
||||
if let Some(format) = decompress_format {
|
||||
res_bytes = decompress(format, res_bytes).await?;
|
||||
res_decompressed = true;
|
||||
}
|
||||
}
|
||||
// Construct and return a readonly lua table with results
|
||||
let res_headers_lua = header_map_to_table(lua, res_headers, res_decompressed)?;
|
||||
TableBuilder::new(lua)?
|
||||
.with_value("ok", (200..300).contains(&res_status))?
|
||||
.with_value("statusCode", res_status)?
|
||||
.with_value("statusMessage", res_status_text)?
|
||||
.with_value("headers", res_headers_lua)?
|
||||
.with_value("body", create_lua_buffer(lua, &res_bytes)?)?
|
||||
.build_readonly()
|
||||
// NOTE: We spawn the request as a background task to free up resources in lua
|
||||
let res = lua.spawn(async move { client.request(config).await });
|
||||
res.await?.into_lua_table(lua)
|
||||
}
|
||||
|
||||
async fn net_socket<'lua>(lua: &'lua Lua, url: String) -> LuaResult<LuaTable>
|
||||
where
|
||||
'lua: 'static, // FIXME: Get rid of static lifetime bound here
|
||||
{
|
||||
async fn net_socket(lua: &Lua, url: String) -> LuaResult<LuaTable> {
|
||||
let (ws, _) = tokio_tungstenite::connect_async(url).await.into_lua_err()?;
|
||||
NetWebSocket::new(ws).into_lua_table(lua)
|
||||
}
|
||||
|
@ -138,32 +64,8 @@ where
|
|||
async fn net_serve<'lua>(
|
||||
lua: &'lua Lua,
|
||||
(port, config): (u16, ServeConfig<'lua>),
|
||||
) -> LuaResult<LuaTable<'lua>>
|
||||
where
|
||||
'lua: 'static, // FIXME: Get rid of static lifetime bound here
|
||||
{
|
||||
let sched = lua
|
||||
.app_data_ref::<&Scheduler>()
|
||||
.expect("Lua struct is missing scheduler");
|
||||
|
||||
let address: Ipv4Addr = match &config.address {
|
||||
Some(addr) => {
|
||||
let addr_str = addr.to_str()?;
|
||||
|
||||
addr_str
|
||||
.trim_start_matches("http://")
|
||||
.trim_start_matches("https://")
|
||||
.parse()
|
||||
.map_err(|_e| LuaError::RuntimeError(format!(
|
||||
"IP address format is incorrect (expected an IP in the form 'http://0.0.0.0' or '0.0.0.0', got '{addr_str}')"
|
||||
)))?
|
||||
}
|
||||
None => DEFAULT_IP_ADDRESS,
|
||||
};
|
||||
|
||||
let builder = bind_to_addr(address, port)?;
|
||||
|
||||
create_server(lua, &sched, config, builder)
|
||||
) -> LuaResult<LuaTable<'lua>> {
|
||||
serve(lua, port, config).await
|
||||
}
|
||||
|
||||
fn net_url_encode<'lua>(
|
||||
|
|
|
@ -1,101 +0,0 @@
|
|||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
use hyper::{body::to_bytes, Body, Request};
|
||||
|
||||
use mlua::prelude::*;
|
||||
|
||||
use crate::lune::util::TableBuilder;
|
||||
|
||||
static ID_COUNTER: AtomicUsize = AtomicUsize::new(0);
|
||||
|
||||
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
||||
pub(super) struct ProcessedRequestId(usize);
|
||||
|
||||
impl ProcessedRequestId {
|
||||
pub fn new() -> Self {
|
||||
// NOTE: This may overflow after a couple billion requests,
|
||||
// but that's completely fine... unless a request is still
|
||||
// alive after billions more arrive and need to be handled
|
||||
Self(ID_COUNTER.fetch_add(1, Ordering::Relaxed))
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) struct ProcessedRequest {
|
||||
pub id: ProcessedRequestId,
|
||||
method: String,
|
||||
path: String,
|
||||
query: Vec<(String, String)>,
|
||||
headers: Vec<(String, Vec<u8>)>,
|
||||
body: Vec<u8>,
|
||||
}
|
||||
|
||||
impl ProcessedRequest {
|
||||
pub async fn from_request(req: Request<Body>) -> LuaResult<Self> {
|
||||
let (head, body) = req.into_parts();
|
||||
|
||||
// FUTURE: We can do extra processing like async decompression here
|
||||
let body = match to_bytes(body).await {
|
||||
Err(_) => return Err(LuaError::runtime("Failed to read request body bytes")),
|
||||
Ok(b) => b.to_vec(),
|
||||
};
|
||||
|
||||
let method = head.method.to_string().to_ascii_uppercase();
|
||||
|
||||
let mut path = head.uri.path().to_string();
|
||||
if path.is_empty() {
|
||||
path = "/".to_string();
|
||||
}
|
||||
|
||||
let query = head
|
||||
.uri
|
||||
.query()
|
||||
.unwrap_or_default()
|
||||
.split('&')
|
||||
.filter_map(|q| q.split_once('='))
|
||||
.map(|(k, v)| (k.to_string(), v.to_string()))
|
||||
.collect();
|
||||
|
||||
let mut headers = Vec::new();
|
||||
let mut header_name = String::new();
|
||||
for (name_opt, value) in head.headers.into_iter() {
|
||||
if let Some(name) = name_opt {
|
||||
header_name = name.to_string();
|
||||
}
|
||||
headers.push((header_name.clone(), value.as_bytes().to_vec()))
|
||||
}
|
||||
|
||||
let id = ProcessedRequestId::new();
|
||||
|
||||
Ok(Self {
|
||||
id,
|
||||
method,
|
||||
path,
|
||||
query,
|
||||
headers,
|
||||
body,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn into_lua_table(self, lua: &Lua) -> LuaResult<LuaTable> {
|
||||
// FUTURE: Make inner tables for query keys that have multiple values?
|
||||
let query = lua.create_table_with_capacity(0, self.query.len())?;
|
||||
for (key, value) in self.query.into_iter() {
|
||||
query.set(key, value)?;
|
||||
}
|
||||
|
||||
let headers = lua.create_table_with_capacity(0, self.headers.len())?;
|
||||
for (key, value) in self.headers.into_iter() {
|
||||
headers.set(key, lua.create_string(value)?)?;
|
||||
}
|
||||
|
||||
let body = lua.create_string(self.body)?;
|
||||
|
||||
TableBuilder::new(lua)?
|
||||
.with_value("method", self.method)?
|
||||
.with_value("path", self.path)?
|
||||
.with_value("query", query)?
|
||||
.with_value("headers", headers)?
|
||||
.with_value("body", body)?
|
||||
.build_readonly()
|
||||
}
|
||||
}
|
|
@ -1,223 +0,0 @@
|
|||
use std::{
|
||||
collections::HashMap,
|
||||
convert::Infallible,
|
||||
net::{Ipv4Addr, SocketAddr},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use hyper::{
|
||||
server::{conn::AddrIncoming, Builder},
|
||||
service::{make_service_fn, service_fn},
|
||||
Server,
|
||||
};
|
||||
|
||||
use hyper_tungstenite::{is_upgrade_request, upgrade, HyperWebsocket};
|
||||
use mlua::prelude::*;
|
||||
use tokio::sync::{mpsc, oneshot, Mutex};
|
||||
|
||||
use crate::lune::{
|
||||
scheduler::Scheduler,
|
||||
util::{futures::yield_forever, traits::LuaEmitErrorExt, TableBuilder},
|
||||
};
|
||||
|
||||
use super::{
|
||||
config::ServeConfig, processing::ProcessedRequest, response::NetServeResponse,
|
||||
websocket::NetWebSocket,
|
||||
};
|
||||
|
||||
pub(super) fn bind_to_addr(address: Ipv4Addr, port: u16) -> LuaResult<Builder<AddrIncoming>> {
|
||||
let addr = SocketAddr::from((address, port));
|
||||
|
||||
match Server::try_bind(&addr) {
|
||||
Ok(b) => Ok(b),
|
||||
Err(e) => Err(LuaError::external(format!(
|
||||
"Failed to bind to {addr}\n{}",
|
||||
e.to_string()
|
||||
.replace("error creating server listener: ", "> ")
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn create_server<'lua>(
|
||||
lua: &'lua Lua,
|
||||
sched: &'lua Scheduler,
|
||||
config: ServeConfig<'lua>,
|
||||
builder: Builder<AddrIncoming>,
|
||||
) -> LuaResult<LuaTable<'lua>>
|
||||
where
|
||||
'lua: 'static, // FIXME: Get rid of static lifetime bound here
|
||||
{
|
||||
// Note that we need to use a mpsc here and not
|
||||
// a oneshot channel since we move the sender
|
||||
// into our table with the stop function
|
||||
let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
|
||||
|
||||
// Communicate between background thread(s) and main lua thread using mpsc and oneshot
|
||||
let (tx_request, mut rx_request) = mpsc::channel::<ProcessedRequest>(64);
|
||||
let (tx_websocket, mut rx_websocket) = mpsc::channel::<HyperWebsocket>(64);
|
||||
let tx_request_arc = Arc::new(tx_request);
|
||||
let tx_websocket_arc = Arc::new(tx_websocket);
|
||||
|
||||
let response_senders = Arc::new(Mutex::new(HashMap::new()));
|
||||
let response_senders_bg = Arc::clone(&response_senders);
|
||||
let response_senders_lua = Arc::clone(&response_senders_bg);
|
||||
|
||||
// Create our background service which will accept
|
||||
// requests, do some processing, then forward to lua
|
||||
let has_websocket_handler = config.handle_web_socket.is_some();
|
||||
let hyper_make_service = make_service_fn(move |_| {
|
||||
let tx_request = Arc::clone(&tx_request_arc);
|
||||
let tx_websocket = Arc::clone(&tx_websocket_arc);
|
||||
let response_senders = Arc::clone(&response_senders_bg);
|
||||
|
||||
let handler = service_fn(move |mut req| {
|
||||
let tx_request = Arc::clone(&tx_request);
|
||||
let tx_websocket = Arc::clone(&tx_websocket);
|
||||
let response_senders = Arc::clone(&response_senders);
|
||||
async move {
|
||||
// FUTURE: Improve error messages when lua is busy and queue is full
|
||||
if has_websocket_handler && is_upgrade_request(&req) {
|
||||
let (response, ws) = match upgrade(&mut req, None) {
|
||||
Err(_) => return Err(LuaError::runtime("Failed to upgrade websocket")),
|
||||
Ok(v) => v,
|
||||
};
|
||||
if (tx_websocket.send(ws).await).is_err() {
|
||||
return Err(LuaError::runtime("Lua handler is busy"));
|
||||
}
|
||||
Ok(response)
|
||||
} else {
|
||||
let processed = ProcessedRequest::from_request(req).await?;
|
||||
let request_id = processed.id;
|
||||
if (tx_request.send(processed).await).is_err() {
|
||||
return Err(LuaError::runtime("Lua handler is busy"));
|
||||
}
|
||||
let (response_tx, response_rx) = oneshot::channel::<NetServeResponse>();
|
||||
response_senders
|
||||
.lock()
|
||||
.await
|
||||
.insert(request_id, response_tx);
|
||||
match response_rx.await {
|
||||
Err(_) => Err(LuaError::runtime("Internal Server Error")),
|
||||
Ok(r) => r.into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
async move { Ok::<_, Infallible>(handler) }
|
||||
});
|
||||
|
||||
// Start up our service
|
||||
sched.spawn(async move {
|
||||
let result = builder
|
||||
.http1_only(true) // Web sockets can only use http1
|
||||
.http1_keepalive(true) // Web sockets must be kept alive
|
||||
.serve(hyper_make_service)
|
||||
.with_graceful_shutdown(async move {
|
||||
if shutdown_rx.recv().await.is_none() {
|
||||
// The channel was closed, meaning the serve handle
|
||||
// was garbage collected by lua without being used
|
||||
yield_forever().await;
|
||||
}
|
||||
});
|
||||
if let Err(e) = result.await {
|
||||
eprintln!("Net serve error: {e}")
|
||||
}
|
||||
});
|
||||
|
||||
// Spawn a local thread with access to lua and the same lifetime
|
||||
sched.spawn_local(async move {
|
||||
loop {
|
||||
// Wait for either a request or a websocket to handle,
|
||||
// if we got neither it means both channels were dropped
|
||||
// and our server has stopped, either gracefully or panic
|
||||
let (req, sock) = tokio::select! {
|
||||
req = rx_request.recv() => (req, None),
|
||||
sock = rx_websocket.recv() => (None, sock),
|
||||
};
|
||||
if req.is_none() && sock.is_none() {
|
||||
break;
|
||||
}
|
||||
|
||||
// NOTE: The closure here is not really necessary, we
|
||||
// make the closure so that we can use the `?` operator
|
||||
// and make a catch-all for errors in spawn_local below
|
||||
let handle_request = config.handle_request.clone();
|
||||
let handle_web_socket = config.handle_web_socket.clone();
|
||||
let response_senders = Arc::clone(&response_senders_lua);
|
||||
let response_fut = async move {
|
||||
match (req, sock) {
|
||||
(Some(req), _) => {
|
||||
let req_id = req.id;
|
||||
let req_table = req.into_lua_table(lua)?;
|
||||
|
||||
let thread_id = sched.push_back(lua, handle_request, req_table)?;
|
||||
let thread_res = sched.wait_for_thread(lua, thread_id).await?;
|
||||
|
||||
let response = NetServeResponse::from_lua_multi(thread_res, lua)?;
|
||||
let response_sender = response_senders
|
||||
.lock()
|
||||
.await
|
||||
.remove(&req_id)
|
||||
.expect("Response channel was removed unexpectedly");
|
||||
|
||||
// NOTE: We ignore the error here, if the sender is no longer
|
||||
// being listened to its because our client disconnected during
|
||||
// handler being called, which is fine and should not emit errors
|
||||
response_sender.send(response).ok();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
(_, Some(sock)) => {
|
||||
let sock = sock.await.into_lua_err()?;
|
||||
|
||||
let sock_handler = handle_web_socket
|
||||
.as_ref()
|
||||
.cloned()
|
||||
.expect("Got web socket but web socket handler is missing");
|
||||
let sock_table = NetWebSocket::new(sock).into_lua_table(lua)?;
|
||||
|
||||
// NOTE: Web socket handler does not need to send any
|
||||
// response back, the websocket upgrade response is
|
||||
// automatically sent above in the background thread(s)
|
||||
let thread_id = sched.push_back(lua, sock_handler, sock_table)?;
|
||||
let _thread_res = sched.wait_for_thread(lua, thread_id).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
NOTE: It is currently not possible to spawn new background tasks from within
|
||||
another background task with the Lune scheduler since they are locked behind a
|
||||
mutex and we also need that mutex locked to be able to run a background task...
|
||||
|
||||
We need to do some work to make it so our unordered futures queues do
|
||||
not require locking and then we can replace the following bit of code:
|
||||
|
||||
sched.spawn_local(async {
|
||||
if let Err(e) = response_fut.await {
|
||||
lua.emit_error(e);
|
||||
}
|
||||
});
|
||||
*/
|
||||
if let Err(e) = response_fut.await {
|
||||
lua.emit_error(e);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Create a new read-only table that contains methods
|
||||
// for manipulating server behavior and shutting it down
|
||||
let handle_stop = move |_, _: ()| match shutdown_tx.try_send(()) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(_) => Err(LuaError::RuntimeError(
|
||||
"Server has already been stopped".to_string(),
|
||||
)),
|
||||
};
|
||||
TableBuilder::new(lua)?
|
||||
.with_function("stop", handle_stop)?
|
||||
.build_readonly()
|
||||
}
|
61
src/lune/builtins/net/server/keys.rs
Normal file
61
src/lune/builtins/net/server/keys.rs
Normal file
|
@ -0,0 +1,61 @@
|
|||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
|
||||
use mlua::prelude::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub(super) struct SvcKeys {
|
||||
key_request: &'static str,
|
||||
key_websocket: Option<&'static str>,
|
||||
}
|
||||
|
||||
impl SvcKeys {
|
||||
pub(super) fn new<'lua>(
|
||||
lua: &'lua Lua,
|
||||
handle_request: LuaFunction<'lua>,
|
||||
handle_websocket: Option<LuaFunction<'lua>>,
|
||||
) -> LuaResult<Self> {
|
||||
static SERVE_COUNTER: AtomicUsize = AtomicUsize::new(0);
|
||||
let count = SERVE_COUNTER.fetch_add(1, Ordering::Relaxed);
|
||||
|
||||
// NOTE: We leak strings here, but this is an acceptable tradeoff since programs
|
||||
// generally only start one or a couple of servers and they are usually never dropped.
|
||||
// Leaking here lets us keep this struct Copy and access the request handler callbacks
|
||||
// very performantly, significantly reducing the per-request overhead of the server.
|
||||
let key_request: &'static str =
|
||||
Box::leak(format!("__net_serve_request_{count}").into_boxed_str());
|
||||
let key_websocket: Option<&'static str> = if handle_websocket.is_some() {
|
||||
Some(Box::leak(
|
||||
format!("__net_serve_websocket_{count}").into_boxed_str(),
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
lua.set_named_registry_value(key_request, handle_request)?;
|
||||
if let Some(key) = key_websocket {
|
||||
lua.set_named_registry_value(key, handle_websocket.unwrap())?;
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
key_request,
|
||||
key_websocket,
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) fn has_websocket_handler(&self) -> bool {
|
||||
self.key_websocket.is_some()
|
||||
}
|
||||
|
||||
pub(super) fn request_handler<'lua>(&self, lua: &'lua Lua) -> LuaResult<LuaFunction<'lua>> {
|
||||
lua.named_registry_value(self.key_request)
|
||||
}
|
||||
|
||||
pub(super) fn websocket_handler<'lua>(
|
||||
&self,
|
||||
lua: &'lua Lua,
|
||||
) -> LuaResult<Option<LuaFunction<'lua>>> {
|
||||
self.key_websocket
|
||||
.map(|key| lua.named_registry_value(key))
|
||||
.transpose()
|
||||
}
|
||||
}
|
105
src/lune/builtins/net/server/mod.rs
Normal file
105
src/lune/builtins/net/server/mod.rs
Normal file
|
@ -0,0 +1,105 @@
|
|||
use std::{
|
||||
net::SocketAddr,
|
||||
rc::{Rc, Weak},
|
||||
};
|
||||
|
||||
use hyper::server::conn::http1;
|
||||
use hyper_util::rt::TokioIo;
|
||||
use tokio::{net::TcpListener, pin};
|
||||
|
||||
use mlua::prelude::*;
|
||||
use mlua_luau_scheduler::LuaSpawnExt;
|
||||
|
||||
use crate::lune::util::TableBuilder;
|
||||
|
||||
use super::config::ServeConfig;
|
||||
|
||||
mod keys;
|
||||
mod request;
|
||||
mod response;
|
||||
mod service;
|
||||
|
||||
use keys::SvcKeys;
|
||||
use service::Svc;
|
||||
|
||||
pub async fn serve<'lua>(
|
||||
lua: &'lua Lua,
|
||||
port: u16,
|
||||
config: ServeConfig<'lua>,
|
||||
) -> LuaResult<LuaTable<'lua>> {
|
||||
let addr: SocketAddr = (config.address, port).into();
|
||||
let listener = TcpListener::bind(addr).await?;
|
||||
|
||||
let (lua_svc, lua_inner) = {
|
||||
let rc = lua
|
||||
.app_data_ref::<Weak<Lua>>()
|
||||
.expect("Missing weak lua ref")
|
||||
.upgrade()
|
||||
.expect("Lua was dropped unexpectedly");
|
||||
(Rc::clone(&rc), rc)
|
||||
};
|
||||
|
||||
let keys = SvcKeys::new(lua, config.handle_request, config.handle_web_socket)?;
|
||||
let svc = Svc {
|
||||
lua: lua_svc,
|
||||
addr,
|
||||
keys,
|
||||
};
|
||||
|
||||
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
|
||||
lua.spawn_local(async move {
|
||||
let mut shutdown_rx_outer = shutdown_rx.clone();
|
||||
loop {
|
||||
// Create futures for accepting new connections and shutting down
|
||||
let fut_shutdown = shutdown_rx_outer.changed();
|
||||
let fut_accept = async {
|
||||
let stream = match listener.accept().await {
|
||||
Err(_) => return,
|
||||
Ok((s, _)) => s,
|
||||
};
|
||||
|
||||
let io = TokioIo::new(stream);
|
||||
let svc = svc.clone();
|
||||
let mut shutdown_rx_inner = shutdown_rx.clone();
|
||||
|
||||
lua_inner.spawn_local(async move {
|
||||
let conn = http1::Builder::new()
|
||||
.keep_alive(true) // Web sockets need this
|
||||
.serve_connection(io, svc)
|
||||
.with_upgrades();
|
||||
// NOTE: Because we need to use keep_alive for websockets, we need to
|
||||
// also manually poll this future and handle the shutdown signal here
|
||||
pin!(conn);
|
||||
tokio::select! {
|
||||
_ = conn.as_mut() => {}
|
||||
_ = shutdown_rx_inner.changed() => {
|
||||
conn.as_mut().graceful_shutdown();
|
||||
}
|
||||
}
|
||||
});
|
||||
};
|
||||
|
||||
// Wait for either a new connection or a shutdown signal
|
||||
tokio::select! {
|
||||
_ = fut_accept => {}
|
||||
res = fut_shutdown => {
|
||||
// NOTE: We will only get a RecvError here if the serve handle is dropped,
|
||||
// this means lua has garbage collected it and the user does not want
|
||||
// to manually stop the server using the serve handle. Run forever.
|
||||
if res.is_ok() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
TableBuilder::new(lua)?
|
||||
.with_value("ip", addr.ip().to_string())?
|
||||
.with_value("port", addr.port())?
|
||||
.with_function("stop", move |lua, _: ()| match shutdown_tx.send(true) {
|
||||
Ok(_) => Ok(()),
|
||||
Err(_) => Err(LuaError::runtime("Server already stopped")),
|
||||
})?
|
||||
.build_readonly()
|
||||
}
|
45
src/lune/builtins/net/server/request.rs
Normal file
45
src/lune/builtins/net/server/request.rs
Normal file
|
@ -0,0 +1,45 @@
|
|||
use std::{collections::HashMap, net::SocketAddr};
|
||||
|
||||
use http::request::Parts;
|
||||
|
||||
use mlua::prelude::*;
|
||||
|
||||
use crate::lune::util::TableBuilder;
|
||||
|
||||
pub(super) struct LuaRequest {
|
||||
pub(super) _remote_addr: SocketAddr,
|
||||
pub(super) head: Parts,
|
||||
pub(super) body: Vec<u8>,
|
||||
}
|
||||
|
||||
impl LuaRequest {
|
||||
pub fn into_lua_table(self, lua: &Lua) -> LuaResult<LuaTable> {
|
||||
let method = self.head.method.as_str().to_string();
|
||||
let path = self.head.uri.path().to_string();
|
||||
let body = lua.create_string(&self.body)?;
|
||||
|
||||
let query: HashMap<String, String> = self
|
||||
.head
|
||||
.uri
|
||||
.query()
|
||||
.unwrap_or_default()
|
||||
.split('&')
|
||||
.filter_map(|q| q.split_once('='))
|
||||
.map(|(k, v)| (k.to_string(), v.to_string()))
|
||||
.collect();
|
||||
let headers: HashMap<String, Vec<u8>> = self
|
||||
.head
|
||||
.headers
|
||||
.iter()
|
||||
.map(|(k, v)| (k.as_str().to_string(), v.as_bytes().to_vec()))
|
||||
.collect();
|
||||
|
||||
TableBuilder::new(lua)?
|
||||
.with_value("method", method)?
|
||||
.with_value("path", path)?
|
||||
.with_value("query", query)?
|
||||
.with_value("headers", headers)?
|
||||
.with_value("body", body)?
|
||||
.build()
|
||||
}
|
||||
}
|
|
@ -1,52 +1,55 @@
|
|||
use std::collections::HashMap;
|
||||
use std::str::FromStr;
|
||||
|
||||
use http_body_util::Full;
|
||||
use hyper::{
|
||||
body::Bytes,
|
||||
header::{HeaderName, HeaderValue},
|
||||
HeaderMap, Response,
|
||||
};
|
||||
|
||||
use hyper::{Body, Response};
|
||||
use mlua::prelude::*;
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub enum NetServeResponseKind {
|
||||
pub(super) enum LuaResponseKind {
|
||||
PlainText,
|
||||
Table,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct NetServeResponse {
|
||||
kind: NetServeResponseKind,
|
||||
status: u16,
|
||||
headers: HashMap<String, Vec<u8>>,
|
||||
body: Option<Vec<u8>>,
|
||||
pub(super) struct LuaResponse {
|
||||
pub(super) kind: LuaResponseKind,
|
||||
pub(super) status: u16,
|
||||
pub(super) headers: HeaderMap,
|
||||
pub(super) body: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
impl NetServeResponse {
|
||||
pub fn into_response(self) -> LuaResult<Response<Body>> {
|
||||
impl LuaResponse {
|
||||
pub(super) fn into_response(self) -> LuaResult<Response<Full<Bytes>>> {
|
||||
Ok(match self.kind {
|
||||
NetServeResponseKind::PlainText => Response::builder()
|
||||
LuaResponseKind::PlainText => Response::builder()
|
||||
.status(200)
|
||||
.header("Content-Type", "text/plain")
|
||||
.body(Body::from(self.body.unwrap()))
|
||||
.body(Full::new(Bytes::from(self.body.unwrap())))
|
||||
.into_lua_err()?,
|
||||
NetServeResponseKind::Table => {
|
||||
let mut response = Response::builder();
|
||||
for (key, value) in self.headers {
|
||||
response = response.header(&key, value);
|
||||
}
|
||||
response
|
||||
LuaResponseKind::Table => {
|
||||
let mut response = Response::builder()
|
||||
.status(self.status)
|
||||
.body(Body::from(self.body.unwrap_or_default()))
|
||||
.into_lua_err()?
|
||||
.body(Full::new(Bytes::from(self.body.unwrap_or_default())))
|
||||
.into_lua_err()?;
|
||||
response.headers_mut().extend(self.headers);
|
||||
response
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'lua> FromLua<'lua> for NetServeResponse {
|
||||
fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult<Self> {
|
||||
impl FromLua<'_> for LuaResponse {
|
||||
fn from_lua(value: LuaValue, _: &Lua) -> LuaResult<Self> {
|
||||
match value {
|
||||
// Plain strings from the handler are plaintext responses
|
||||
LuaValue::String(s) => Ok(Self {
|
||||
kind: NetServeResponseKind::PlainText,
|
||||
kind: LuaResponseKind::PlainText,
|
||||
status: 200,
|
||||
headers: HashMap::new(),
|
||||
headers: HeaderMap::new(),
|
||||
body: Some(s.as_bytes().to_vec()),
|
||||
}),
|
||||
// Tables are more detailed responses with potential status, headers, body
|
||||
|
@ -55,18 +58,20 @@ impl<'lua> FromLua<'lua> for NetServeResponse {
|
|||
let headers: Option<LuaTable> = t.get("headers")?;
|
||||
let body: Option<LuaString> = t.get("body")?;
|
||||
|
||||
let mut headers_map = HashMap::new();
|
||||
let mut headers_map = HeaderMap::new();
|
||||
if let Some(headers) = headers {
|
||||
for pair in headers.pairs::<String, LuaString>() {
|
||||
let (h, v) = pair?;
|
||||
headers_map.insert(h, v.as_bytes().to_vec());
|
||||
let name = HeaderName::from_str(&h).into_lua_err()?;
|
||||
let value = HeaderValue::from_bytes(v.as_bytes()).into_lua_err()?;
|
||||
headers_map.insert(name, value);
|
||||
}
|
||||
}
|
||||
|
||||
let body_bytes = body.map(|s| s.as_bytes().to_vec());
|
||||
|
||||
Ok(Self {
|
||||
kind: NetServeResponseKind::Table,
|
||||
kind: LuaResponseKind::Table,
|
||||
status: status.unwrap_or(200),
|
||||
headers: headers_map,
|
||||
body: body_bytes,
|
82
src/lune/builtins/net/server/service.rs
Normal file
82
src/lune/builtins/net/server/service.rs
Normal file
|
@ -0,0 +1,82 @@
|
|||
use std::{future::Future, net::SocketAddr, pin::Pin, rc::Rc};
|
||||
|
||||
use http_body_util::{BodyExt, Full};
|
||||
use hyper::{
|
||||
body::{Bytes, Incoming},
|
||||
service::Service,
|
||||
Request, Response,
|
||||
};
|
||||
use hyper_tungstenite::{is_upgrade_request, upgrade};
|
||||
|
||||
use mlua::prelude::*;
|
||||
use mlua_luau_scheduler::{LuaSchedulerExt, LuaSpawnExt};
|
||||
|
||||
use super::{
|
||||
super::websocket::NetWebSocket, keys::SvcKeys, request::LuaRequest, response::LuaResponse,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(super) struct Svc {
|
||||
pub(super) lua: Rc<Lua>,
|
||||
pub(super) addr: SocketAddr,
|
||||
pub(super) keys: SvcKeys,
|
||||
}
|
||||
|
||||
impl Service<Request<Incoming>> for Svc {
|
||||
type Response = Response<Full<Bytes>>;
|
||||
type Error = LuaError;
|
||||
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
|
||||
|
||||
fn call(&self, req: Request<Incoming>) -> Self::Future {
|
||||
let lua = self.lua.clone();
|
||||
let addr = self.addr;
|
||||
let keys = self.keys;
|
||||
|
||||
if keys.has_websocket_handler() && is_upgrade_request(&req) {
|
||||
Box::pin(async move {
|
||||
let (res, sock) = upgrade(req, None).into_lua_err()?;
|
||||
|
||||
let lua_inner = lua.clone();
|
||||
lua.spawn_local(async move {
|
||||
let sock = sock.await.unwrap();
|
||||
let lua_sock = NetWebSocket::new(sock);
|
||||
let lua_tab = lua_sock.into_lua_table(&lua_inner).unwrap();
|
||||
|
||||
let handler_websocket: LuaFunction =
|
||||
keys.websocket_handler(&lua_inner).unwrap().unwrap();
|
||||
|
||||
lua_inner
|
||||
.push_thread_back(handler_websocket, lua_tab)
|
||||
.unwrap();
|
||||
});
|
||||
|
||||
Ok(res)
|
||||
})
|
||||
} else {
|
||||
let (head, body) = req.into_parts();
|
||||
|
||||
Box::pin(async move {
|
||||
let handler_request: LuaFunction = keys.request_handler(&lua).unwrap();
|
||||
|
||||
let body = body.collect().await.into_lua_err()?;
|
||||
let body = body.to_bytes().to_vec();
|
||||
|
||||
let lua_req = LuaRequest {
|
||||
_remote_addr: addr,
|
||||
head,
|
||||
body,
|
||||
};
|
||||
let lua_req_table = lua_req.into_lua_table(&lua)?;
|
||||
|
||||
let thread_id = lua.push_thread_back(handler_request, lua_req_table)?;
|
||||
lua.track_thread(thread_id);
|
||||
lua.wait_for_thread(thread_id).await;
|
||||
let thread_res = lua
|
||||
.get_thread_result(thread_id)
|
||||
.expect("Missing handler thread result")?;
|
||||
|
||||
LuaResponse::from_lua_multi(thread_res, &lua)?.into_response()
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,14 +1,20 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use hyper::{
|
||||
header::{CONTENT_ENCODING, CONTENT_LENGTH},
|
||||
HeaderMap,
|
||||
};
|
||||
use hyper::header::{CONTENT_ENCODING, CONTENT_LENGTH};
|
||||
use reqwest::header::HeaderMap;
|
||||
|
||||
use mlua::prelude::*;
|
||||
|
||||
use crate::lune::util::TableBuilder;
|
||||
|
||||
pub fn create_user_agent_header() -> String {
|
||||
let (github_owner, github_repo) = env!("CARGO_PKG_REPOSITORY")
|
||||
.trim_start_matches("https://github.com/")
|
||||
.split_once('/')
|
||||
.unwrap();
|
||||
format!("{github_owner}-{github_repo}-cli")
|
||||
}
|
||||
|
||||
pub fn header_map_to_table(
|
||||
lua: &Lua,
|
||||
headers: HeaderMap,
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
use std::sync::Arc;
|
||||
use std::sync::{
|
||||
atomic::{AtomicBool, AtomicU16, Ordering},
|
||||
Arc,
|
||||
};
|
||||
|
||||
use hyper::upgrade::Upgraded;
|
||||
use mlua::prelude::*;
|
||||
|
||||
use futures_util::{
|
||||
|
@ -9,7 +11,6 @@ use futures_util::{
|
|||
};
|
||||
use tokio::{
|
||||
io::{AsyncRead, AsyncWrite},
|
||||
net::TcpStream,
|
||||
sync::Mutex as AsyncMutex,
|
||||
};
|
||||
|
||||
|
@ -20,25 +21,25 @@ use hyper_tungstenite::{
|
|||
},
|
||||
WebSocketStream,
|
||||
};
|
||||
use tokio_tungstenite::MaybeTlsStream;
|
||||
|
||||
use crate::lune::util::{buffer::buf_to_str, TableBuilder};
|
||||
|
||||
// Wrapper implementation for compatibility and changing colon syntax to dot syntax
|
||||
const WEB_SOCKET_IMPL_LUA: &str = r#"
|
||||
return freeze(setmetatable({
|
||||
close = function(...)
|
||||
return close(websocket, ...)
|
||||
return websocket:close(...)
|
||||
end,
|
||||
send = function(...)
|
||||
return send(websocket, ...)
|
||||
return websocket:send(...)
|
||||
end,
|
||||
next = function(...)
|
||||
return next(websocket, ...)
|
||||
return websocket:next(...)
|
||||
end,
|
||||
}, {
|
||||
__index = function(self, key)
|
||||
if key == "closeCode" then
|
||||
return close_code(websocket)
|
||||
return websocket.closeCode
|
||||
end
|
||||
end,
|
||||
}))
|
||||
|
@ -46,7 +47,8 @@ return freeze(setmetatable({
|
|||
|
||||
#[derive(Debug)]
|
||||
pub struct NetWebSocket<T> {
|
||||
close_code: Arc<AsyncMutex<Option<u16>>>,
|
||||
close_code_exists: Arc<AtomicBool>,
|
||||
close_code_value: Arc<AtomicU16>,
|
||||
read_stream: Arc<AsyncMutex<SplitStream<WebSocketStream<T>>>>,
|
||||
write_stream: Arc<AsyncMutex<SplitSink<WebSocketStream<T>, WsMessage>>>,
|
||||
}
|
||||
|
@ -54,7 +56,8 @@ pub struct NetWebSocket<T> {
|
|||
impl<T> Clone for NetWebSocket<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
close_code: Arc::clone(&self.close_code),
|
||||
close_code_exists: Arc::clone(&self.close_code_exists),
|
||||
close_code_value: Arc::clone(&self.close_code_value),
|
||||
read_stream: Arc::clone(&self.read_stream),
|
||||
write_stream: Arc::clone(&self.write_stream),
|
||||
}
|
||||
|
@ -63,22 +66,78 @@ impl<T> Clone for NetWebSocket<T> {
|
|||
|
||||
impl<T> NetWebSocket<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
pub fn new(value: WebSocketStream<T>) -> Self {
|
||||
let (write, read) = value.split();
|
||||
|
||||
Self {
|
||||
close_code: Arc::new(AsyncMutex::new(None)),
|
||||
close_code_exists: Arc::new(AtomicBool::new(false)),
|
||||
close_code_value: Arc::new(AtomicU16::new(0)),
|
||||
read_stream: Arc::new(AsyncMutex::new(read)),
|
||||
write_stream: Arc::new(AsyncMutex::new(write)),
|
||||
}
|
||||
}
|
||||
|
||||
fn into_lua_table_with_env<'lua>(
|
||||
lua: &'lua Lua,
|
||||
env: LuaTable<'lua>,
|
||||
) -> LuaResult<LuaTable<'lua>> {
|
||||
fn get_close_code(&self) -> Option<u16> {
|
||||
if self.close_code_exists.load(Ordering::Relaxed) {
|
||||
Some(self.close_code_value.load(Ordering::Relaxed))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn set_close_code(&self, code: u16) {
|
||||
self.close_code_exists.store(true, Ordering::Relaxed);
|
||||
self.close_code_value.store(code, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
pub async fn send(&self, msg: WsMessage) -> LuaResult<()> {
|
||||
let mut ws = self.write_stream.lock().await;
|
||||
ws.send(msg).await.into_lua_err()
|
||||
}
|
||||
|
||||
pub async fn next(&self) -> LuaResult<Option<WsMessage>> {
|
||||
let mut ws = self.read_stream.lock().await;
|
||||
ws.next().await.transpose().into_lua_err()
|
||||
}
|
||||
|
||||
pub async fn close(&self, code: Option<u16>) -> LuaResult<()> {
|
||||
if self.close_code_exists.load(Ordering::Relaxed) {
|
||||
return Err(LuaError::runtime("Socket has already been closed"));
|
||||
}
|
||||
|
||||
self.send(WsMessage::Close(Some(WsCloseFrame {
|
||||
code: match code {
|
||||
Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code),
|
||||
Some(code) => {
|
||||
return Err(LuaError::runtime(format!(
|
||||
"Close code must be between 1000 and 4999, got {code}"
|
||||
)))
|
||||
}
|
||||
None => WsCloseCode::Normal,
|
||||
},
|
||||
reason: "".into(),
|
||||
})))
|
||||
.await?;
|
||||
|
||||
let mut ws = self.write_stream.lock().await;
|
||||
ws.close().await.into_lua_err()
|
||||
}
|
||||
|
||||
pub fn into_lua_table(self, lua: &Lua) -> LuaResult<LuaTable> {
|
||||
let setmetatable = lua.globals().get::<_, LuaFunction>("setmetatable")?;
|
||||
let table_freeze = lua
|
||||
.globals()
|
||||
.get::<_, LuaTable>("table")?
|
||||
.get::<_, LuaFunction>("freeze")?;
|
||||
|
||||
let env = TableBuilder::new(lua)?
|
||||
.with_value("websocket", self.clone())?
|
||||
.with_value("setmetatable", setmetatable)?
|
||||
.with_value("freeze", table_freeze)?
|
||||
.build_readonly()?;
|
||||
|
||||
lua.load(WEB_SOCKET_IMPL_LUA)
|
||||
.set_name("websocket")
|
||||
.set_environment(env)
|
||||
|
@ -86,158 +145,46 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
type NetWebSocketStreamClient = MaybeTlsStream<TcpStream>;
|
||||
impl NetWebSocket<NetWebSocketStreamClient> {
|
||||
pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult<LuaTable> {
|
||||
let setmetatable = lua.globals().get::<_, LuaFunction>("setmetatable")?;
|
||||
let table_freeze = lua
|
||||
.globals()
|
||||
.get::<_, LuaTable>("table")?
|
||||
.get::<_, LuaFunction>("freeze")?;
|
||||
let socket_env = TableBuilder::new(lua)?
|
||||
.with_value("websocket", self)?
|
||||
.with_function("close_code", close_code::<NetWebSocketStreamClient>)?
|
||||
.with_async_function("close", close::<NetWebSocketStreamClient>)?
|
||||
.with_async_function("send", send::<NetWebSocketStreamClient>)?
|
||||
.with_async_function("next", next::<NetWebSocketStreamClient>)?
|
||||
.with_value("setmetatable", setmetatable)?
|
||||
.with_value("freeze", table_freeze)?
|
||||
.build_readonly()?;
|
||||
Self::into_lua_table_with_env(lua, socket_env)
|
||||
}
|
||||
}
|
||||
|
||||
type NetWebSocketStreamServer = Upgraded;
|
||||
impl NetWebSocket<NetWebSocketStreamServer> {
|
||||
pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult<LuaTable> {
|
||||
let setmetatable = lua.globals().get::<_, LuaFunction>("setmetatable")?;
|
||||
let table_freeze = lua
|
||||
.globals()
|
||||
.get::<_, LuaTable>("table")?
|
||||
.get::<_, LuaFunction>("freeze")?;
|
||||
let socket_env = TableBuilder::new(lua)?
|
||||
.with_value("websocket", self)?
|
||||
.with_function("close_code", close_code::<NetWebSocketStreamServer>)?
|
||||
.with_async_function("close", close::<NetWebSocketStreamServer>)?
|
||||
.with_async_function("send", send::<NetWebSocketStreamServer>)?
|
||||
.with_async_function("next", next::<NetWebSocketStreamServer>)?
|
||||
.with_value("setmetatable", setmetatable)?
|
||||
.with_value("freeze", table_freeze)?
|
||||
.build_readonly()?;
|
||||
Self::into_lua_table_with_env(lua, socket_env)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> LuaUserData for NetWebSocket<T> {}
|
||||
|
||||
fn close_code<'lua, T>(
|
||||
_lua: &'lua Lua,
|
||||
socket: LuaUserDataRef<'lua, NetWebSocket<T>>,
|
||||
) -> LuaResult<LuaValue<'lua>>
|
||||
impl<T> LuaUserData for NetWebSocket<T>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||
{
|
||||
Ok(
|
||||
match *socket
|
||||
.close_code
|
||||
.try_lock()
|
||||
.expect("Failed to lock close code")
|
||||
{
|
||||
Some(code) => LuaValue::Number(code as f64),
|
||||
None => LuaValue::Nil,
|
||||
},
|
||||
)
|
||||
}
|
||||
fn add_fields<'lua, F: LuaUserDataFields<'lua, Self>>(fields: &mut F) {
|
||||
fields.add_field_method_get("closeCode", |_, this| Ok(this.get_close_code()));
|
||||
}
|
||||
|
||||
async fn close<'lua, T>(
|
||||
_lua: &'lua Lua,
|
||||
(socket, code): (LuaUserDataRef<'lua, NetWebSocket<T>>, Option<u16>),
|
||||
) -> LuaResult<()>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let mut ws = socket.write_stream.lock().await;
|
||||
fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) {
|
||||
methods.add_async_method("close", |lua, this, code: Option<u16>| async move {
|
||||
this.close(code).await
|
||||
});
|
||||
|
||||
ws.send(WsMessage::Close(Some(WsCloseFrame {
|
||||
code: match code {
|
||||
Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code),
|
||||
Some(code) => {
|
||||
return Err(LuaError::RuntimeError(format!(
|
||||
"Close code must be between 1000 and 4999, got {code}"
|
||||
)))
|
||||
methods.add_async_method(
|
||||
"send",
|
||||
|_, this, (string, as_binary): (LuaString, Option<bool>)| async move {
|
||||
this.send(if as_binary.unwrap_or_default() {
|
||||
WsMessage::Binary(string.as_bytes().to_vec())
|
||||
} else {
|
||||
let s = string.to_str().into_lua_err()?;
|
||||
WsMessage::Text(s.to_string())
|
||||
})
|
||||
.await
|
||||
},
|
||||
);
|
||||
|
||||
methods.add_async_method("next", |lua, this, _: ()| async move {
|
||||
let msg = this.next().await?;
|
||||
|
||||
if let Some(WsMessage::Close(Some(frame))) = msg.as_ref() {
|
||||
this.set_close_code(frame.code.into());
|
||||
}
|
||||
None => WsCloseCode::Normal,
|
||||
},
|
||||
reason: "".into(),
|
||||
})))
|
||||
.await
|
||||
.into_lua_err()?;
|
||||
|
||||
let res = ws.close();
|
||||
res.await.into_lua_err()
|
||||
}
|
||||
|
||||
async fn send<'lua, T>(
|
||||
lua: &'lua Lua,
|
||||
(socket, data, as_binary): (
|
||||
LuaUserDataRef<'lua, NetWebSocket<T>>,
|
||||
LuaValue<'lua>,
|
||||
Option<bool>,
|
||||
),
|
||||
) -> LuaResult<()>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let string = match data {
|
||||
LuaValue::String(str) => Ok(str.to_str()?.to_string()),
|
||||
LuaValue::UserData(inner) => buf_to_str(lua, LuaValue::UserData(inner)),
|
||||
other => Err(LuaError::runtime(format!(
|
||||
"Expected data to be of type string or buffer, got {}",
|
||||
other.type_name()
|
||||
))),
|
||||
}?;
|
||||
|
||||
let msg = if matches!(as_binary, Some(true)) {
|
||||
WsMessage::Binary(string.as_bytes().to_vec())
|
||||
} else {
|
||||
let s = string;
|
||||
WsMessage::Text(s.to_string())
|
||||
};
|
||||
let mut ws = socket.write_stream.lock().await;
|
||||
ws.send(msg).await.into_lua_err()
|
||||
}
|
||||
|
||||
async fn next<'lua, T>(
|
||||
lua: &'lua Lua,
|
||||
socket: LuaUserDataRef<'lua, NetWebSocket<T>>,
|
||||
) -> LuaResult<LuaValue<'lua>>
|
||||
where
|
||||
T: AsyncRead + AsyncWrite + Unpin,
|
||||
{
|
||||
let mut ws = socket.read_stream.lock().await;
|
||||
let item = ws.next().await.transpose().into_lua_err();
|
||||
let msg = match item {
|
||||
Ok(Some(WsMessage::Close(msg))) => {
|
||||
if let Some(msg) = &msg {
|
||||
let mut code = socket.close_code.lock().await;
|
||||
*code = Some(msg.code.into());
|
||||
}
|
||||
Ok(Some(WsMessage::Close(msg)))
|
||||
}
|
||||
val => val,
|
||||
}?;
|
||||
while let Some(msg) = &msg {
|
||||
let msg_string_opt = match msg {
|
||||
WsMessage::Binary(bin) => Some(lua.create_string(bin)?),
|
||||
WsMessage::Text(txt) => Some(lua.create_string(txt)?),
|
||||
// Stop waiting for next message if we get a close message
|
||||
WsMessage::Close(_) => return Ok(LuaValue::Nil),
|
||||
// Ignore ping/pong/frame messages, they are handled by tungstenite
|
||||
_ => None,
|
||||
};
|
||||
if let Some(msg_string) = msg_string_opt {
|
||||
return Ok(LuaValue::String(msg_string));
|
||||
}
|
||||
Ok(match msg {
|
||||
Some(WsMessage::Binary(bin)) => LuaValue::String(lua.create_string(bin)?),
|
||||
Some(WsMessage::Text(txt)) => LuaValue::String(lua.create_string(txt)?),
|
||||
Some(WsMessage::Close(_)) | None => LuaValue::Nil,
|
||||
// Ignore ping/pong/frame messages, they are handled by tungstenite
|
||||
msg => unreachable!("Unhandled message: {:?}", msg),
|
||||
})
|
||||
});
|
||||
}
|
||||
Ok(LuaValue::Nil)
|
||||
}
|
||||
|
|
|
@ -5,13 +5,11 @@ use std::{
|
|||
};
|
||||
|
||||
use mlua::prelude::*;
|
||||
use mlua_luau_scheduler::{Functions, LuaSpawnExt};
|
||||
use os_str_bytes::RawOsString;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
use crate::lune::{
|
||||
scheduler::Scheduler,
|
||||
util::{paths::CWD, TableBuilder},
|
||||
};
|
||||
use crate::lune::util::{paths::CWD, TableBuilder};
|
||||
|
||||
mod tee_writer;
|
||||
|
||||
|
@ -21,12 +19,7 @@ use options::ProcessSpawnOptions;
|
|||
mod wait_for_child;
|
||||
use wait_for_child::{wait_for_child, WaitForChildResult};
|
||||
|
||||
const PROCESS_EXIT_IMPL_LUA: &str = r#"
|
||||
exit(...)
|
||||
yield()
|
||||
"#;
|
||||
|
||||
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
|
||||
pub fn create(lua: &Lua) -> LuaResult<LuaTable> {
|
||||
let cwd_str = {
|
||||
let cwd_str = CWD.to_string_lossy().to_string();
|
||||
if !cwd_str.ends_with(path::MAIN_SEPARATOR) {
|
||||
|
@ -56,30 +49,9 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
|
|||
.build_readonly()?,
|
||||
)?
|
||||
.build_readonly()?;
|
||||
// Create our process exit function, this is a bit involved since
|
||||
// we have no way to yield from c / rust, we need to load a lua
|
||||
// chunk that will set the exit code and yield for us instead
|
||||
let coroutine_yield = lua
|
||||
.globals()
|
||||
.get::<_, LuaTable>("coroutine")?
|
||||
.get::<_, LuaFunction>("yield")?;
|
||||
let set_scheduler_exit_code = lua.create_function(|lua, code: Option<u8>| {
|
||||
let sched = lua
|
||||
.app_data_ref::<&Scheduler>()
|
||||
.expect("Lua struct is missing scheduler");
|
||||
sched.set_exit_code(code.unwrap_or_default());
|
||||
Ok(())
|
||||
})?;
|
||||
let process_exit = lua
|
||||
.load(PROCESS_EXIT_IMPL_LUA)
|
||||
.set_name("=process.exit")
|
||||
.set_environment(
|
||||
TableBuilder::new(lua)?
|
||||
.with_value("yield", coroutine_yield)?
|
||||
.with_value("exit", set_scheduler_exit_code)?
|
||||
.build_readonly()?,
|
||||
)
|
||||
.into_function()?;
|
||||
// Create our process exit function, the scheduler crate provides this
|
||||
let fns = Functions::new(lua)?;
|
||||
let process_exit = fns.exit;
|
||||
// Create the full process table
|
||||
TableBuilder::new(lua)?
|
||||
.with_value("os", os)?
|
||||
|
@ -165,22 +137,10 @@ async fn process_spawn(
|
|||
lua: &Lua,
|
||||
(program, args, options): (String, Option<Vec<String>>, ProcessSpawnOptions),
|
||||
) -> LuaResult<LuaTable> {
|
||||
/*
|
||||
Spawn the new process in the background, letting the tokio
|
||||
runtime place it on a different thread if possible / necessary
|
||||
|
||||
Note that we have to use our scheduler here, we can't
|
||||
be using tokio::task::spawn directly because our lua
|
||||
scheduler would not drive those futures to completion
|
||||
*/
|
||||
let sched = lua
|
||||
.app_data_ref::<&Scheduler>()
|
||||
.expect("Lua struct is missing scheduler");
|
||||
|
||||
let res = sched
|
||||
let res = lua
|
||||
.spawn(spawn_command(program, args, options))
|
||||
.await
|
||||
.expect("Failed to receive result of spawned process")?;
|
||||
.expect("Failed to receive result of spawned process");
|
||||
|
||||
/*
|
||||
NOTE: If an exit code was not given by the child process,
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use mlua::prelude::*;
|
||||
use mlua_luau_scheduler::LuaSpawnExt;
|
||||
use once_cell::sync::OnceCell;
|
||||
|
||||
use crate::{
|
||||
|
@ -11,11 +12,9 @@ use crate::{
|
|||
},
|
||||
};
|
||||
|
||||
use tokio::task;
|
||||
|
||||
static REFLECTION_DATABASE: OnceCell<ReflectionDatabase> = OnceCell::new();
|
||||
|
||||
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
|
||||
pub fn create(lua: &Lua) -> LuaResult<LuaTable> {
|
||||
let mut roblox_constants = Vec::new();
|
||||
|
||||
let roblox_module = roblox::module(lua)?;
|
||||
|
@ -41,12 +40,12 @@ async fn deserialize_place<'lua>(
|
|||
contents: LuaString<'lua>,
|
||||
) -> LuaResult<LuaValue<'lua>> {
|
||||
let bytes = contents.as_bytes().to_vec();
|
||||
let fut = task::spawn_blocking(move || {
|
||||
let fut = lua.spawn_blocking(move || {
|
||||
let doc = Document::from_bytes(bytes, DocumentKind::Place)?;
|
||||
let data_model = doc.into_data_model_instance()?;
|
||||
Ok::<_, DocumentError>(data_model)
|
||||
});
|
||||
fut.await.into_lua_err()??.into_lua(lua)
|
||||
fut.await.into_lua_err()?.into_lua(lua)
|
||||
}
|
||||
|
||||
async fn deserialize_model<'lua>(
|
||||
|
@ -54,12 +53,12 @@ async fn deserialize_model<'lua>(
|
|||
contents: LuaString<'lua>,
|
||||
) -> LuaResult<LuaValue<'lua>> {
|
||||
let bytes = contents.as_bytes().to_vec();
|
||||
let fut = task::spawn_blocking(move || {
|
||||
let fut = lua.spawn_blocking(move || {
|
||||
let doc = Document::from_bytes(bytes, DocumentKind::Model)?;
|
||||
let instance_array = doc.into_instance_array()?;
|
||||
Ok::<_, DocumentError>(instance_array)
|
||||
});
|
||||
fut.await.into_lua_err()??.into_lua(lua)
|
||||
fut.await.into_lua_err()?.into_lua(lua)
|
||||
}
|
||||
|
||||
async fn serialize_place<'lua>(
|
||||
|
@ -67,7 +66,7 @@ async fn serialize_place<'lua>(
|
|||
(data_model, as_xml): (LuaUserDataRef<'lua, Instance>, Option<bool>),
|
||||
) -> LuaResult<LuaString<'lua>> {
|
||||
let data_model = (*data_model).clone();
|
||||
let fut = task::spawn_blocking(move || {
|
||||
let fut = lua.spawn_blocking(move || {
|
||||
let doc = Document::from_data_model_instance(data_model)?;
|
||||
let bytes = doc.to_bytes_with_format(match as_xml {
|
||||
Some(true) => DocumentFormat::Xml,
|
||||
|
@ -75,7 +74,7 @@ async fn serialize_place<'lua>(
|
|||
})?;
|
||||
Ok::<_, DocumentError>(bytes)
|
||||
});
|
||||
let bytes = fut.await.into_lua_err()??;
|
||||
let bytes = fut.await.into_lua_err()?;
|
||||
lua.create_string(bytes)
|
||||
}
|
||||
|
||||
|
@ -84,7 +83,7 @@ async fn serialize_model<'lua>(
|
|||
(instances, as_xml): (Vec<LuaUserDataRef<'lua, Instance>>, Option<bool>),
|
||||
) -> LuaResult<LuaString<'lua>> {
|
||||
let instances = instances.iter().map(|i| (*i).clone()).collect();
|
||||
let fut = task::spawn_blocking(move || {
|
||||
let fut = lua.spawn_blocking(move || {
|
||||
let doc = Document::from_instance_array(instances)?;
|
||||
let bytes = doc.to_bytes_with_format(match as_xml {
|
||||
Some(true) => DocumentFormat::Xml,
|
||||
|
@ -92,7 +91,7 @@ async fn serialize_model<'lua>(
|
|||
})?;
|
||||
Ok::<_, DocumentError>(bytes)
|
||||
});
|
||||
let bytes = fut.await.into_lua_err()??;
|
||||
let bytes = fut.await.into_lua_err()?;
|
||||
lua.create_string(bytes)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,9 +1,7 @@
|
|||
use lz4_flex::{compress_prepend_size, decompress_size_prepended};
|
||||
use mlua::prelude::*;
|
||||
use tokio::{
|
||||
io::{copy, BufReader},
|
||||
task,
|
||||
};
|
||||
|
||||
use lz4_flex::{compress_prepend_size, decompress_size_prepended};
|
||||
use tokio::io::{copy, BufReader};
|
||||
|
||||
use async_compression::{
|
||||
tokio::bufread::{
|
||||
|
@ -100,9 +98,7 @@ pub async fn compress<'lua>(
|
|||
) -> LuaResult<Vec<u8>> {
|
||||
if let CompressDecompressFormat::LZ4 = format {
|
||||
let source = source.as_ref().to_vec();
|
||||
return task::spawn_blocking(move || compress_prepend_size(&source))
|
||||
.await
|
||||
.into_lua_err();
|
||||
return Ok(blocking::unblock(move || compress_prepend_size(&source)).await);
|
||||
}
|
||||
|
||||
let mut bytes = Vec::new();
|
||||
|
@ -133,9 +129,8 @@ pub async fn decompress<'lua>(
|
|||
) -> LuaResult<Vec<u8>> {
|
||||
if let CompressDecompressFormat::LZ4 = format {
|
||||
let source = source.as_ref().to_vec();
|
||||
return task::spawn_blocking(move || decompress_size_prepended(&source))
|
||||
return blocking::unblock(move || decompress_size_prepended(&source))
|
||||
.await
|
||||
.into_lua_err()?
|
||||
.into_lua_err();
|
||||
}
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ use encode_decode::{EncodeDecodeConfig, EncodeDecodeFormat};
|
|||
|
||||
use crate::lune::util::TableBuilder;
|
||||
|
||||
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
|
||||
pub fn create(lua: &Lua) -> LuaResult<LuaTable> {
|
||||
TableBuilder::new(lua)?
|
||||
.with_function("encode", serde_encode)?
|
||||
.with_function("decode", serde_decode)?
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
use mlua::prelude::*;
|
||||
|
||||
use dialoguer::{theme::ColorfulTheme, Confirm, Input, MultiSelect, Select};
|
||||
use tokio::{
|
||||
io::{self, AsyncWriteExt},
|
||||
task,
|
||||
};
|
||||
use mlua_luau_scheduler::LuaSpawnExt;
|
||||
use tokio::io::{self, AsyncWriteExt};
|
||||
|
||||
use crate::lune::util::{
|
||||
formatting::{
|
||||
|
@ -16,7 +14,7 @@ use crate::lune::util::{
|
|||
mod prompt;
|
||||
use prompt::{PromptKind, PromptOptions, PromptResult};
|
||||
|
||||
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable<'_>> {
|
||||
pub fn create(lua: &Lua) -> LuaResult<LuaTable<'_>> {
|
||||
TableBuilder::new(lua)?
|
||||
.with_function("color", stdio_color)?
|
||||
.with_function("style", stdio_style)?
|
||||
|
@ -55,10 +53,10 @@ async fn stdio_ewrite(_: &Lua, s: LuaString<'_>) -> LuaResult<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn stdio_prompt(_: &Lua, options: PromptOptions) -> LuaResult<PromptResult> {
|
||||
task::spawn_blocking(move || prompt(options))
|
||||
async fn stdio_prompt(lua: &Lua, options: PromptOptions) -> LuaResult<PromptResult> {
|
||||
lua.spawn_blocking(move || prompt(options))
|
||||
.await
|
||||
.into_lua_err()?
|
||||
.into_lua_err()
|
||||
}
|
||||
|
||||
fn prompt(options: PromptOptions) -> LuaResult<PromptResult> {
|
||||
|
|
|
@ -2,120 +2,51 @@ use std::time::Duration;
|
|||
|
||||
use mlua::prelude::*;
|
||||
|
||||
use mlua_luau_scheduler::Functions;
|
||||
use tokio::time::{self, Instant};
|
||||
|
||||
use crate::lune::{scheduler::Scheduler, util::TableBuilder};
|
||||
use crate::lune::util::TableBuilder;
|
||||
|
||||
mod tof;
|
||||
use tof::LuaThreadOrFunction;
|
||||
|
||||
/*
|
||||
The spawn function needs special treatment,
|
||||
we need to yield right away to allow the
|
||||
spawned task to run until first yield
|
||||
|
||||
1. Schedule this current thread at the front
|
||||
2. Schedule given thread/function at the front,
|
||||
the previous schedule now comes right after
|
||||
3. Give control over to the scheduler, which will
|
||||
resume the above tasks in order when its ready
|
||||
*/
|
||||
const SPAWN_IMPL_LUA: &str = r#"
|
||||
push(currentThread())
|
||||
local thread = push(...)
|
||||
yield()
|
||||
return thread
|
||||
const DELAY_IMPL_LUA: &str = r#"
|
||||
return defer(function(...)
|
||||
wait(select(1, ...))
|
||||
spawn(select(2, ...))
|
||||
end, ...)
|
||||
"#;
|
||||
|
||||
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable<'_>> {
|
||||
let coroutine_running = lua
|
||||
.globals()
|
||||
.get::<_, LuaTable>("coroutine")?
|
||||
.get::<_, LuaFunction>("running")?;
|
||||
let coroutine_yield = lua
|
||||
.globals()
|
||||
.get::<_, LuaTable>("coroutine")?
|
||||
.get::<_, LuaFunction>("yield")?;
|
||||
let push_front =
|
||||
lua.create_function(|lua, (tof, args): (LuaThreadOrFunction, LuaMultiValue)| {
|
||||
let thread = tof.into_thread(lua)?;
|
||||
let sched = lua
|
||||
.app_data_ref::<&Scheduler>()
|
||||
.expect("Lua struct is missing scheduler");
|
||||
sched.push_front(lua, thread.clone(), args)?;
|
||||
Ok(thread)
|
||||
})?;
|
||||
let task_spawn_env = TableBuilder::new(lua)?
|
||||
.with_value("currentThread", coroutine_running)?
|
||||
.with_value("yield", coroutine_yield)?
|
||||
.with_value("push", push_front)?
|
||||
pub fn create(lua: &Lua) -> LuaResult<LuaTable<'_>> {
|
||||
let fns = Functions::new(lua)?;
|
||||
|
||||
// Create wait & delay functions
|
||||
let task_wait = lua.create_async_function(wait)?;
|
||||
let task_delay_env = TableBuilder::new(lua)?
|
||||
.with_value("select", lua.globals().get::<_, LuaFunction>("select")?)?
|
||||
.with_value("spawn", fns.spawn.clone())?
|
||||
.with_value("defer", fns.defer.clone())?
|
||||
.with_value("wait", task_wait.clone())?
|
||||
.build_readonly()?;
|
||||
let task_spawn = lua
|
||||
.load(SPAWN_IMPL_LUA)
|
||||
.set_name("task.spawn")
|
||||
.set_environment(task_spawn_env)
|
||||
let task_delay = lua
|
||||
.load(DELAY_IMPL_LUA)
|
||||
.set_name("task.delay")
|
||||
.set_environment(task_delay_env)
|
||||
.into_function()?;
|
||||
|
||||
// Overwrite resume & wrap functions on the coroutine global
|
||||
// with ones that are compatible with our scheduler
|
||||
let co = lua.globals().get::<_, LuaTable>("coroutine")?;
|
||||
co.set("resume", fns.resume.clone())?;
|
||||
co.set("wrap", fns.wrap.clone())?;
|
||||
|
||||
TableBuilder::new(lua)?
|
||||
.with_function("cancel", task_cancel)?
|
||||
.with_function("defer", task_defer)?
|
||||
.with_function("delay", task_delay)?
|
||||
.with_value("spawn", task_spawn)?
|
||||
.with_async_function("wait", task_wait)?
|
||||
.with_value("cancel", fns.cancel)?
|
||||
.with_value("defer", fns.defer)?
|
||||
.with_value("delay", task_delay)?
|
||||
.with_value("spawn", fns.spawn)?
|
||||
.with_value("wait", task_wait)?
|
||||
.build_readonly()
|
||||
}
|
||||
|
||||
fn task_cancel(lua: &Lua, thread: LuaThread) -> LuaResult<()> {
|
||||
let close = lua
|
||||
.globals()
|
||||
.get::<_, LuaTable>("coroutine")?
|
||||
.get::<_, LuaFunction>("close")?;
|
||||
match close.call(thread) {
|
||||
Err(LuaError::CoroutineInactive) => Ok(()),
|
||||
Err(e) => Err(e),
|
||||
Ok(()) => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
fn task_defer<'lua>(
|
||||
lua: &'lua Lua,
|
||||
(tof, args): (LuaThreadOrFunction<'lua>, LuaMultiValue<'_>),
|
||||
) -> LuaResult<LuaThread<'lua>> {
|
||||
let thread = tof.into_thread(lua)?;
|
||||
let sched = lua
|
||||
.app_data_ref::<&Scheduler>()
|
||||
.expect("Lua struct is missing scheduler");
|
||||
sched.push_back(lua, thread.clone(), args)?;
|
||||
Ok(thread)
|
||||
}
|
||||
|
||||
// FIXME: `self` escapes outside of method because we are borrowing `tof` and
|
||||
// `args` when we call `schedule_future_thread` in the lua function body below
|
||||
// For now we solve this by using the 'static lifetime bound in the impl
|
||||
fn task_delay<'lua>(
|
||||
lua: &'lua Lua,
|
||||
(secs, tof, args): (f64, LuaThreadOrFunction<'lua>, LuaMultiValue<'lua>),
|
||||
) -> LuaResult<LuaThread<'lua>>
|
||||
where
|
||||
'lua: 'static,
|
||||
{
|
||||
let thread = tof.into_thread(lua)?;
|
||||
let sched = lua
|
||||
.app_data_ref::<&Scheduler>()
|
||||
.expect("Lua struct is missing scheduler");
|
||||
|
||||
let thread2 = thread.clone();
|
||||
sched.spawn_thread(lua, thread.clone(), async move {
|
||||
let duration = Duration::from_secs_f64(secs);
|
||||
time::sleep(duration).await;
|
||||
sched.push_back(lua, thread2, args)?;
|
||||
Ok(())
|
||||
})?;
|
||||
|
||||
Ok(thread)
|
||||
}
|
||||
|
||||
async fn task_wait(_: &Lua, secs: Option<f64>) -> LuaResult<f64> {
|
||||
async fn wait(_: &Lua, secs: Option<f64>) -> LuaResult<f64> {
|
||||
let duration = Duration::from_secs_f64(secs.unwrap_or_default());
|
||||
|
||||
let before = Instant::now();
|
||||
|
|
|
@ -1,30 +0,0 @@
|
|||
use mlua::prelude::*;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(super) enum LuaThreadOrFunction<'lua> {
|
||||
Thread(LuaThread<'lua>),
|
||||
Function(LuaFunction<'lua>),
|
||||
}
|
||||
|
||||
impl<'lua> LuaThreadOrFunction<'lua> {
|
||||
pub(super) fn into_thread(self, lua: &'lua Lua) -> LuaResult<LuaThread<'lua>> {
|
||||
match self {
|
||||
Self::Thread(t) => Ok(t),
|
||||
Self::Function(f) => lua.create_thread(f),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'lua> FromLua<'lua> for LuaThreadOrFunction<'lua> {
|
||||
fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult<Self> {
|
||||
match value {
|
||||
LuaValue::Thread(t) => Ok(Self::Thread(t)),
|
||||
LuaValue::Function(f) => Ok(Self::Function(f)),
|
||||
value => Err(LuaError::FromLuaConversionError {
|
||||
from: value.type_name(),
|
||||
to: "LuaThreadOrFunction",
|
||||
message: Some("Expected thread or function".to_string()),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -8,7 +8,7 @@ mod require;
|
|||
mod version;
|
||||
mod warn;
|
||||
|
||||
pub fn inject_all(lua: &'static Lua) -> LuaResult<()> {
|
||||
pub fn inject_all(lua: &Lua) -> LuaResult<()> {
|
||||
let all = TableBuilder::new(lua)?
|
||||
.with_value("_G", g_table::create(lua)?)?
|
||||
.with_value("_VERSION", version::create(lua)?)?
|
||||
|
|
|
@ -9,7 +9,8 @@ use crate::lune::util::{
|
|||
use super::context::*;
|
||||
|
||||
pub(super) async fn require<'lua, 'ctx>(
|
||||
ctx: &'ctx RequireContext<'lua>,
|
||||
lua: &'lua Lua,
|
||||
ctx: &'ctx RequireContext,
|
||||
source: &str,
|
||||
alias: &str,
|
||||
path: &str,
|
||||
|
@ -18,7 +19,6 @@ where
|
|||
'lua: 'ctx,
|
||||
{
|
||||
let alias = alias.to_ascii_lowercase();
|
||||
let path = path.to_ascii_lowercase();
|
||||
|
||||
let parent = make_absolute_and_clean(source)
|
||||
.parent()
|
||||
|
@ -71,5 +71,5 @@ where
|
|||
LuaError::runtime(format!("failed to find relative path for alias '{alias}'"))
|
||||
})?;
|
||||
|
||||
super::path::require_abs_rel(ctx, abs_path, rel_path).await
|
||||
super::path::require_abs_rel(lua, ctx, abs_path, rel_path).await
|
||||
}
|
||||
|
|
|
@ -3,12 +3,12 @@ use mlua::prelude::*;
|
|||
use super::context::*;
|
||||
|
||||
pub(super) async fn require<'lua, 'ctx>(
|
||||
ctx: &'ctx RequireContext<'lua>,
|
||||
lua: &'lua Lua,
|
||||
ctx: &'ctx RequireContext,
|
||||
name: &str,
|
||||
) -> LuaResult<LuaMultiValue<'lua>>
|
||||
where
|
||||
'lua: 'ctx,
|
||||
'lua: 'static, // FIXME: Remove static lifetime bound here when builtin libraries no longer need it
|
||||
{
|
||||
ctx.load_builtin(name)
|
||||
ctx.load_builtin(lua, name)
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ use std::{
|
|||
};
|
||||
|
||||
use mlua::prelude::*;
|
||||
use mlua_luau_scheduler::LuaSchedulerExt;
|
||||
use tokio::{
|
||||
fs,
|
||||
sync::{
|
||||
|
@ -13,11 +14,7 @@ use tokio::{
|
|||
},
|
||||
};
|
||||
|
||||
use crate::lune::{
|
||||
builtins::LuneBuiltin,
|
||||
scheduler::{IntoLuaThread, Scheduler},
|
||||
util::paths::CWD,
|
||||
};
|
||||
use crate::lune::{builtins::LuneBuiltin, util::paths::CWD};
|
||||
|
||||
/**
|
||||
Context containing cached results for all `require` operations.
|
||||
|
@ -26,14 +23,13 @@ use crate::lune::{
|
|||
path will first be transformed into an absolute path.
|
||||
*/
|
||||
#[derive(Debug, Clone)]
|
||||
pub(super) struct RequireContext<'lua> {
|
||||
lua: &'lua Lua,
|
||||
pub(super) struct RequireContext {
|
||||
cache_builtins: Arc<AsyncMutex<HashMap<LuneBuiltin, LuaResult<LuaRegistryKey>>>>,
|
||||
cache_results: Arc<AsyncMutex<HashMap<PathBuf, LuaResult<LuaRegistryKey>>>>,
|
||||
cache_pending: Arc<AsyncMutex<HashMap<PathBuf, Sender<()>>>>,
|
||||
}
|
||||
|
||||
impl<'lua> RequireContext<'lua> {
|
||||
impl RequireContext {
|
||||
/**
|
||||
Creates a new require context for the given [`Lua`] struct.
|
||||
|
||||
|
@ -41,9 +37,8 @@ impl<'lua> RequireContext<'lua> {
|
|||
context should be created per [`Lua`] struct, creating more
|
||||
than one context may lead to undefined require-behavior.
|
||||
*/
|
||||
pub fn new(lua: &'lua Lua) -> Self {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
lua,
|
||||
cache_builtins: Arc::new(AsyncMutex::new(HashMap::new())),
|
||||
cache_results: Arc::new(AsyncMutex::new(HashMap::new())),
|
||||
cache_pending: Arc::new(AsyncMutex::new(HashMap::new())),
|
||||
|
@ -75,7 +70,7 @@ impl<'lua> RequireContext<'lua> {
|
|||
CWD.join(&rel_path)
|
||||
};
|
||||
|
||||
Ok((rel_path, abs_path))
|
||||
Ok((abs_path, rel_path))
|
||||
}
|
||||
|
||||
/**
|
||||
|
@ -107,7 +102,11 @@ impl<'lua> RequireContext<'lua> {
|
|||
|
||||
Will panic if the path has not been cached, use [`is_cached`] first.
|
||||
*/
|
||||
pub fn get_from_cache(&self, abs_path: impl AsRef<Path>) -> LuaResult<LuaMultiValue<'lua>> {
|
||||
pub fn get_from_cache<'lua>(
|
||||
&self,
|
||||
lua: &'lua Lua,
|
||||
abs_path: impl AsRef<Path>,
|
||||
) -> LuaResult<LuaMultiValue<'lua>> {
|
||||
let results = self
|
||||
.cache_results
|
||||
.try_lock()
|
||||
|
@ -119,8 +118,7 @@ impl<'lua> RequireContext<'lua> {
|
|||
match cached {
|
||||
Err(e) => Err(e.clone()),
|
||||
Ok(k) => {
|
||||
let multi_vec = self
|
||||
.lua
|
||||
let multi_vec = lua
|
||||
.registry_value::<Vec<LuaValue>>(k)
|
||||
.expect("Missing require result in lua registry");
|
||||
Ok(LuaMultiValue::from_vec(multi_vec))
|
||||
|
@ -133,8 +131,9 @@ impl<'lua> RequireContext<'lua> {
|
|||
|
||||
Will panic if the path has not been cached, use [`is_cached`] first.
|
||||
*/
|
||||
pub async fn wait_for_cache(
|
||||
pub async fn wait_for_cache<'lua>(
|
||||
&self,
|
||||
lua: &'lua Lua,
|
||||
abs_path: impl AsRef<Path>,
|
||||
) -> LuaResult<LuaMultiValue<'lua>> {
|
||||
let mut thread_recv = {
|
||||
|
@ -150,43 +149,37 @@ impl<'lua> RequireContext<'lua> {
|
|||
|
||||
thread_recv.recv().await.into_lua_err()?;
|
||||
|
||||
self.get_from_cache(abs_path.as_ref())
|
||||
self.get_from_cache(lua, abs_path.as_ref())
|
||||
}
|
||||
|
||||
async fn load(
|
||||
async fn load<'lua>(
|
||||
&self,
|
||||
lua: &'lua Lua,
|
||||
abs_path: impl AsRef<Path>,
|
||||
rel_path: impl AsRef<Path>,
|
||||
) -> LuaResult<LuaRegistryKey> {
|
||||
let abs_path = abs_path.as_ref();
|
||||
let rel_path = rel_path.as_ref();
|
||||
|
||||
let sched = self
|
||||
.lua
|
||||
.app_data_ref::<&Scheduler>()
|
||||
.expect("Lua struct is missing scheduler");
|
||||
|
||||
// Read the file at the given path, try to parse and
|
||||
// load it into a new lua thread that we can schedule
|
||||
let file_contents = fs::read(&abs_path).await?;
|
||||
let file_thread = self
|
||||
.lua
|
||||
let file_thread = lua
|
||||
.load(file_contents)
|
||||
.set_name(rel_path.to_string_lossy().to_string())
|
||||
.into_function()?
|
||||
.into_lua_thread(self.lua)?;
|
||||
.set_name(rel_path.to_string_lossy().to_string());
|
||||
|
||||
// Schedule the thread to run, wait for it to finish running
|
||||
let thread_id = sched.push_back(self.lua, file_thread, ())?;
|
||||
let thread_res = sched.wait_for_thread(self.lua, thread_id).await;
|
||||
let thread_id = lua.push_thread_back(file_thread, ())?;
|
||||
lua.track_thread(thread_id);
|
||||
lua.wait_for_thread(thread_id).await;
|
||||
let thread_res = lua.get_thread_result(thread_id).unwrap();
|
||||
|
||||
// Return the result of the thread, storing any lua value(s) in the registry
|
||||
match thread_res {
|
||||
Err(e) => Err(e),
|
||||
Ok(v) => {
|
||||
let multi_vec = v.into_vec();
|
||||
let multi_key = self
|
||||
.lua
|
||||
let multi_key = lua
|
||||
.create_registry_value(multi_vec)
|
||||
.expect("Failed to store require result in registry - out of memory");
|
||||
Ok(multi_key)
|
||||
|
@ -197,8 +190,9 @@ impl<'lua> RequireContext<'lua> {
|
|||
/**
|
||||
Loads (requires) the file at the given path.
|
||||
*/
|
||||
pub async fn load_with_caching(
|
||||
pub async fn load_with_caching<'lua>(
|
||||
&self,
|
||||
lua: &'lua Lua,
|
||||
abs_path: impl AsRef<Path>,
|
||||
rel_path: impl AsRef<Path>,
|
||||
) -> LuaResult<LuaMultiValue<'lua>> {
|
||||
|
@ -213,12 +207,11 @@ impl<'lua> RequireContext<'lua> {
|
|||
.insert(abs_path.to_path_buf(), broadcast_tx);
|
||||
|
||||
// Try to load at this abs path
|
||||
let load_res = self.load(abs_path, rel_path).await;
|
||||
let load_res = self.load(lua, abs_path, rel_path).await;
|
||||
let load_val = match &load_res {
|
||||
Err(e) => Err(e.clone()),
|
||||
Ok(k) => {
|
||||
let multi_vec = self
|
||||
.lua
|
||||
let multi_vec = lua
|
||||
.registry_value::<Vec<LuaValue>>(k)
|
||||
.expect("Failed to fetch require result from registry");
|
||||
Ok(LuaMultiValue::from_vec(multi_vec))
|
||||
|
@ -250,10 +243,11 @@ impl<'lua> RequireContext<'lua> {
|
|||
/**
|
||||
Loads (requires) the builtin with the given name.
|
||||
*/
|
||||
pub fn load_builtin(&self, name: impl AsRef<str>) -> LuaResult<LuaMultiValue<'lua>>
|
||||
where
|
||||
'lua: 'static, // FIXME: Remove static lifetime bound here when builtin libraries no longer need it
|
||||
{
|
||||
pub fn load_builtin<'lua>(
|
||||
&self,
|
||||
lua: &'lua Lua,
|
||||
name: impl AsRef<str>,
|
||||
) -> LuaResult<LuaMultiValue<'lua>> {
|
||||
let builtin: LuneBuiltin = match name.as_ref().parse() {
|
||||
Err(e) => return Err(LuaError::runtime(e)),
|
||||
Ok(b) => b,
|
||||
|
@ -268,8 +262,7 @@ impl<'lua> RequireContext<'lua> {
|
|||
return match res {
|
||||
Err(e) => return Err(e.clone()),
|
||||
Ok(key) => {
|
||||
let multi_vec = self
|
||||
.lua
|
||||
let multi_vec = lua
|
||||
.registry_value::<Vec<LuaValue>>(key)
|
||||
.expect("Missing builtin result in lua registry");
|
||||
Ok(LuaMultiValue::from_vec(multi_vec))
|
||||
|
@ -277,7 +270,7 @@ impl<'lua> RequireContext<'lua> {
|
|||
};
|
||||
};
|
||||
|
||||
let result = builtin.create(self.lua);
|
||||
let result = builtin.create(lua);
|
||||
|
||||
cache.insert(
|
||||
builtin,
|
||||
|
@ -285,8 +278,7 @@ impl<'lua> RequireContext<'lua> {
|
|||
Err(e) => Err(e),
|
||||
Ok(multi) => {
|
||||
let multi_vec = multi.into_vec();
|
||||
let multi_key = self
|
||||
.lua
|
||||
let multi_key = lua
|
||||
.create_registry_value(multi_vec)
|
||||
.expect("Failed to store require result in registry - out of memory");
|
||||
Ok(multi_key)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use mlua::prelude::*;
|
||||
|
||||
use crate::lune::{scheduler::LuaSchedulerExt, util::TableBuilder};
|
||||
use crate::lune::util::TableBuilder;
|
||||
|
||||
mod context;
|
||||
use context::RequireContext;
|
||||
|
@ -13,8 +13,8 @@ const REQUIRE_IMPL: &str = r#"
|
|||
return require(source(), ...)
|
||||
"#;
|
||||
|
||||
pub fn create(lua: &'static Lua) -> LuaResult<impl IntoLua<'_>> {
|
||||
lua.set_app_data(RequireContext::new(lua));
|
||||
pub fn create(lua: &Lua) -> LuaResult<impl IntoLua<'_>> {
|
||||
lua.set_app_data(RequireContext::new());
|
||||
|
||||
/*
|
||||
Require implementation needs a few workarounds:
|
||||
|
@ -62,10 +62,7 @@ pub fn create(lua: &'static Lua) -> LuaResult<impl IntoLua<'_>> {
|
|||
async fn require<'lua>(
|
||||
lua: &'lua Lua,
|
||||
(source, path): (LuaString<'lua>, LuaString<'lua>),
|
||||
) -> LuaResult<LuaMultiValue<'lua>>
|
||||
where
|
||||
'lua: 'static, // FIXME: Remove static lifetime bound here when builtin libraries no longer need it
|
||||
{
|
||||
) -> LuaResult<LuaMultiValue<'lua>> {
|
||||
let source = source
|
||||
.to_str()
|
||||
.into_lua_err()
|
||||
|
@ -86,13 +83,13 @@ where
|
|||
.strip_prefix("@lune/")
|
||||
.map(|name| name.to_ascii_lowercase())
|
||||
{
|
||||
builtin::require(&context, &builtin_name).await
|
||||
builtin::require(lua, &context, &builtin_name).await
|
||||
} else if let Some(aliased_path) = path.strip_prefix('@') {
|
||||
let (alias, path) = aliased_path.split_once('/').ok_or(LuaError::runtime(
|
||||
"Require with custom alias must contain '/' delimiter",
|
||||
))?;
|
||||
alias::require(&context, &source, alias, path).await
|
||||
alias::require(lua, &context, &source, alias, path).await
|
||||
} else {
|
||||
path::require(&context, &source, &path).await
|
||||
path::require(lua, &context, &source, &path).await
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,7 +5,8 @@ use mlua::prelude::*;
|
|||
use super::context::*;
|
||||
|
||||
pub(super) async fn require<'lua, 'ctx>(
|
||||
ctx: &'ctx RequireContext<'lua>,
|
||||
lua: &'lua Lua,
|
||||
ctx: &'ctx RequireContext,
|
||||
source: &str,
|
||||
path: &str,
|
||||
) -> LuaResult<LuaMultiValue<'lua>>
|
||||
|
@ -13,11 +14,12 @@ where
|
|||
'lua: 'ctx,
|
||||
{
|
||||
let (abs_path, rel_path) = ctx.resolve_paths(source, path)?;
|
||||
require_abs_rel(ctx, abs_path, rel_path).await
|
||||
require_abs_rel(lua, ctx, abs_path, rel_path).await
|
||||
}
|
||||
|
||||
pub(super) async fn require_abs_rel<'lua, 'ctx>(
|
||||
ctx: &'ctx RequireContext<'lua>,
|
||||
lua: &'lua Lua,
|
||||
ctx: &'ctx RequireContext,
|
||||
abs_path: PathBuf, // Absolute to filesystem
|
||||
rel_path: PathBuf, // Relative to CWD (for displaying)
|
||||
) -> LuaResult<LuaMultiValue<'lua>>
|
||||
|
@ -25,7 +27,7 @@ where
|
|||
'lua: 'ctx,
|
||||
{
|
||||
// 1. Try to require the exact path
|
||||
if let Ok(res) = require_inner(ctx, &abs_path, &rel_path).await {
|
||||
if let Ok(res) = require_inner(lua, ctx, &abs_path, &rel_path).await {
|
||||
return Ok(res);
|
||||
}
|
||||
|
||||
|
@ -34,7 +36,7 @@ where
|
|||
append_extension(&abs_path, "luau"),
|
||||
append_extension(&rel_path, "luau"),
|
||||
);
|
||||
if let Ok(res) = require_inner(ctx, &luau_abs_path, &luau_rel_path).await {
|
||||
if let Ok(res) = require_inner(lua, ctx, &luau_abs_path, &luau_rel_path).await {
|
||||
return Ok(res);
|
||||
}
|
||||
|
||||
|
@ -43,7 +45,7 @@ where
|
|||
append_extension(&abs_path, "lua"),
|
||||
append_extension(&rel_path, "lua"),
|
||||
);
|
||||
if let Ok(res) = require_inner(ctx, &lua_abs_path, &lua_rel_path).await {
|
||||
if let Ok(res) = require_inner(lua, ctx, &lua_abs_path, &lua_rel_path).await {
|
||||
return Ok(res);
|
||||
}
|
||||
|
||||
|
@ -57,7 +59,7 @@ where
|
|||
append_extension(&abs_init, "luau"),
|
||||
append_extension(&rel_init, "luau"),
|
||||
);
|
||||
if let Ok(res) = require_inner(ctx, &luau_abs_init, &luau_rel_init).await {
|
||||
if let Ok(res) = require_inner(lua, ctx, &luau_abs_init, &luau_rel_init).await {
|
||||
return Ok(res);
|
||||
}
|
||||
|
||||
|
@ -66,7 +68,7 @@ where
|
|||
append_extension(&abs_init, "lua"),
|
||||
append_extension(&rel_init, "lua"),
|
||||
);
|
||||
if let Ok(res) = require_inner(ctx, &lua_abs_init, &lua_rel_init).await {
|
||||
if let Ok(res) = require_inner(lua, ctx, &lua_abs_init, &lua_rel_init).await {
|
||||
return Ok(res);
|
||||
}
|
||||
|
||||
|
@ -78,7 +80,8 @@ where
|
|||
}
|
||||
|
||||
async fn require_inner<'lua, 'ctx>(
|
||||
ctx: &'ctx RequireContext<'lua>,
|
||||
lua: &'lua Lua,
|
||||
ctx: &'ctx RequireContext,
|
||||
abs_path: impl AsRef<Path>,
|
||||
rel_path: impl AsRef<Path>,
|
||||
) -> LuaResult<LuaMultiValue<'lua>>
|
||||
|
@ -89,11 +92,11 @@ where
|
|||
let rel_path = rel_path.as_ref();
|
||||
|
||||
if ctx.is_cached(abs_path)? {
|
||||
ctx.get_from_cache(abs_path)
|
||||
ctx.get_from_cache(lua, abs_path)
|
||||
} else if ctx.is_pending(abs_path)? {
|
||||
ctx.wait_for_cache(&abs_path).await
|
||||
ctx.wait_for_cache(lua, &abs_path).await
|
||||
} else {
|
||||
ctx.load_with_caching(&abs_path, &rel_path).await
|
||||
ctx.load_with_caching(lua, &abs_path, &rel_path).await
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,24 +1,39 @@
|
|||
use mlua::prelude::*;
|
||||
|
||||
pub fn create(lua: &Lua) -> LuaResult<impl IntoLua<'_>> {
|
||||
let lune_version = format!("Lune {}", env!("CARGO_PKG_VERSION"));
|
||||
|
||||
let luau_version_full = lua
|
||||
.globals()
|
||||
.get::<_, LuaString>("_VERSION")
|
||||
.expect("Missing _VERSION global");
|
||||
let luau_version_str = luau_version_full
|
||||
.to_str()
|
||||
.context("Invalid utf8 found in _VERSION global")?;
|
||||
|
||||
let luau_version = luau_version_full
|
||||
.to_str()?
|
||||
.strip_prefix("Luau 0.")
|
||||
.expect("_VERSION global is formatted incorrectly")
|
||||
.trim();
|
||||
|
||||
if luau_version.is_empty() {
|
||||
panic!("_VERSION global is missing version number")
|
||||
// If this function runs more than once, we
|
||||
// may get an already formatted lune version.
|
||||
if luau_version_str.starts_with(&lune_version) {
|
||||
return Ok(luau_version_full);
|
||||
}
|
||||
|
||||
lua.create_string(format!(
|
||||
"Lune {lune}+{luau}",
|
||||
lune = env!("CARGO_PKG_VERSION"),
|
||||
luau = luau_version,
|
||||
))
|
||||
// Luau version is expected to be in the format "Luau 0.x" and sometimes "Luau 0.x.y"
|
||||
if !luau_version_str.starts_with("Luau 0.") {
|
||||
panic!("_VERSION global is formatted incorrectly\nGot: '{luau_version_str}'")
|
||||
}
|
||||
let luau_version = luau_version_str.strip_prefix("Luau 0.").unwrap().trim();
|
||||
|
||||
// We make some guarantees about the format of the _VERSION global,
|
||||
// so make sure that the luau version also follows those rules.
|
||||
if luau_version.is_empty() {
|
||||
panic!("_VERSION global is missing version number\nGot: '{luau_version_str}'")
|
||||
} else if !luau_version.chars().all(is_valid_version_char) {
|
||||
panic!("_VERSION global contains invalid characters\nGot: '{luau_version_str}'")
|
||||
}
|
||||
|
||||
lua.create_string(format!("{lune_version}+{luau_version}"))
|
||||
}
|
||||
|
||||
fn is_valid_version_char(c: char) -> bool {
|
||||
matches!(c, '0'..='9' | '.')
|
||||
}
|
||||
|
|
|
@ -1,47 +1,44 @@
|
|||
use std::process::ExitCode;
|
||||
use std::{
|
||||
process::ExitCode,
|
||||
rc::Rc,
|
||||
sync::{
|
||||
atomic::{AtomicBool, Ordering},
|
||||
Arc,
|
||||
},
|
||||
};
|
||||
|
||||
use mlua::Lua;
|
||||
use mlua_luau_scheduler::Scheduler;
|
||||
|
||||
mod builtins;
|
||||
mod error;
|
||||
mod globals;
|
||||
mod scheduler;
|
||||
|
||||
pub(crate) mod util;
|
||||
|
||||
use self::scheduler::{LuaSchedulerExt, Scheduler};
|
||||
|
||||
pub use error::RuntimeError;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug)]
|
||||
pub struct Runtime {
|
||||
lua: &'static Lua,
|
||||
scheduler: &'static Scheduler<'static>,
|
||||
lua: Rc<Lua>,
|
||||
args: Vec<String>,
|
||||
}
|
||||
|
||||
impl Runtime {
|
||||
/**
|
||||
Creates a new Lune runtime, with a new Luau VM and task scheduler.
|
||||
Creates a new Lune runtime, with a new Luau VM.
|
||||
*/
|
||||
#[allow(clippy::new_without_default)]
|
||||
pub fn new() -> Self {
|
||||
/*
|
||||
FUTURE: Stop leaking these when we have removed the lifetime
|
||||
on the scheduler and can place them in lua app data using arc
|
||||
let lua = Rc::new(Lua::new());
|
||||
|
||||
See the scheduler struct for more notes
|
||||
*/
|
||||
let lua = Lua::new().into_static();
|
||||
let scheduler = Scheduler::new().into_static();
|
||||
|
||||
lua.set_scheduler(scheduler);
|
||||
lua.set_app_data(Rc::downgrade(&lua));
|
||||
lua.set_app_data(Vec::<String>::new());
|
||||
globals::inject_all(lua).expect("Failed to inject lua globals");
|
||||
|
||||
globals::inject_all(&lua).expect("Failed to inject globals");
|
||||
|
||||
Self {
|
||||
lua,
|
||||
scheduler,
|
||||
args: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
@ -68,13 +65,34 @@ impl Runtime {
|
|||
script_name: impl AsRef<str>,
|
||||
script_contents: impl AsRef<[u8]>,
|
||||
) -> Result<ExitCode, RuntimeError> {
|
||||
// Create a new scheduler for this run
|
||||
let sched = Scheduler::new(&self.lua);
|
||||
|
||||
// Add error callback to format errors nicely + store status
|
||||
let got_any_error = Arc::new(AtomicBool::new(false));
|
||||
let got_any_inner = Arc::clone(&got_any_error);
|
||||
sched.set_error_callback(move |e| {
|
||||
got_any_inner.store(true, Ordering::SeqCst);
|
||||
eprintln!("{}", RuntimeError::from(e));
|
||||
});
|
||||
|
||||
// Load our "main" thread
|
||||
let main = self
|
||||
.lua
|
||||
.load(script_contents.as_ref())
|
||||
.set_name(script_name.as_ref());
|
||||
|
||||
self.scheduler.push_back(self.lua, main, ())?;
|
||||
// Run it on our scheduler until it and any other spawned threads complete
|
||||
sched.push_thread_back(main, ())?;
|
||||
sched.run().await;
|
||||
|
||||
Ok(self.scheduler.run_to_completion(self.lua).await)
|
||||
// Return the exit code - default to FAILURE if we got any errors
|
||||
Ok(sched.get_exit_code().unwrap_or({
|
||||
if got_any_error.load(Ordering::SeqCst) {
|
||||
ExitCode::FAILURE
|
||||
} else {
|
||||
ExitCode::SUCCESS
|
||||
}
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,138 +0,0 @@
|
|||
use futures_util::Future;
|
||||
use mlua::prelude::*;
|
||||
use tokio::{
|
||||
sync::oneshot::{self, Receiver},
|
||||
task,
|
||||
};
|
||||
|
||||
use super::{IntoLuaThread, Scheduler};
|
||||
|
||||
impl<'fut> Scheduler<'fut> {
|
||||
/**
|
||||
Checks if there are any futures to run, for
|
||||
lua futures and background futures respectively.
|
||||
*/
|
||||
pub(super) fn has_futures(&self) -> (bool, bool) {
|
||||
(
|
||||
self.futures_lua
|
||||
.try_lock()
|
||||
.expect("Failed to lock lua futures for check")
|
||||
.len()
|
||||
> 0,
|
||||
self.futures_background
|
||||
.try_lock()
|
||||
.expect("Failed to lock background futures for check")
|
||||
.len()
|
||||
> 0,
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
Schedules a plain future to run in the background.
|
||||
|
||||
This will potentially spawn the future on a different thread, using
|
||||
[`task::spawn`], meaning the provided future must implement [`Send`].
|
||||
|
||||
Returns a [`Receiver`] which may be `await`-ed
|
||||
to retrieve the result of the spawned future.
|
||||
|
||||
This [`Receiver`] may be safely ignored if the result of the
|
||||
spawned future is not needed, the future will run either way.
|
||||
*/
|
||||
pub fn spawn<F>(&self, fut: F) -> Receiver<F::Output>
|
||||
where
|
||||
F: Future + Send + 'static,
|
||||
F::Output: Send + 'static,
|
||||
{
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
let handle = task::spawn(async move {
|
||||
let res = fut.await;
|
||||
tx.send(res).ok();
|
||||
});
|
||||
|
||||
// NOTE: We must spawn a future on our scheduler which awaits
|
||||
// the handle from tokio to start driving our future properly
|
||||
let futs = self
|
||||
.futures_background
|
||||
.try_lock()
|
||||
.expect("Failed to lock futures queue for background tasks");
|
||||
futs.push(Box::pin(async move {
|
||||
handle.await.ok();
|
||||
}));
|
||||
|
||||
// NOTE: We might be resuming lua futures, need to signal that a
|
||||
// new background future is ready to break out of futures resumption
|
||||
self.state.message_sender().send_spawned_background_future();
|
||||
|
||||
rx
|
||||
}
|
||||
|
||||
/**
|
||||
Equivalent to [`spawn`], except the future is only
|
||||
spawned on the Lune scheduler, and on the main thread.
|
||||
*/
|
||||
pub fn spawn_local<F>(&self, fut: F) -> Receiver<F::Output>
|
||||
where
|
||||
F: Future + 'static,
|
||||
F::Output: 'static,
|
||||
{
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
let futs = self
|
||||
.futures_background
|
||||
.try_lock()
|
||||
.expect("Failed to lock futures queue for background tasks");
|
||||
futs.push(Box::pin(async move {
|
||||
let res = fut.await;
|
||||
tx.send(res).ok();
|
||||
}));
|
||||
|
||||
// NOTE: We might be resuming lua futures, need to signal that a
|
||||
// new background future is ready to break out of futures resumption
|
||||
self.state.message_sender().send_spawned_background_future();
|
||||
|
||||
rx
|
||||
}
|
||||
|
||||
/**
|
||||
Schedules the given `thread` to run when the given `fut` completes.
|
||||
|
||||
If the given future returns a [`LuaError`], that error will be passed to the given `thread`.
|
||||
*/
|
||||
pub fn spawn_thread<F, FR>(
|
||||
&'fut self,
|
||||
lua: &'fut Lua,
|
||||
thread: impl IntoLuaThread<'fut>,
|
||||
fut: F,
|
||||
) -> LuaResult<()>
|
||||
where
|
||||
FR: IntoLuaMulti<'fut>,
|
||||
F: Future<Output = LuaResult<FR>> + 'fut,
|
||||
{
|
||||
let thread = thread.into_lua_thread(lua)?;
|
||||
let futs = self.futures_lua.try_lock().expect(
|
||||
"Failed to lock futures queue - \
|
||||
can't schedule future lua threads during futures resumption",
|
||||
);
|
||||
|
||||
futs.push(Box::pin(async move {
|
||||
match fut.await.and_then(|rets| rets.into_lua_multi(lua)) {
|
||||
Err(e) => {
|
||||
self.push_err(lua, thread, e)
|
||||
.expect("Failed to schedule future err thread");
|
||||
}
|
||||
Ok(v) => {
|
||||
self.push_back(lua, thread, v)
|
||||
.expect("Failed to schedule future thread");
|
||||
}
|
||||
}
|
||||
}));
|
||||
|
||||
// NOTE: We might be resuming background futures, need to signal that a
|
||||
// new background future is ready to break out of futures resumption
|
||||
self.state.message_sender().send_spawned_lua_future();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
|
@ -1,265 +0,0 @@
|
|||
use std::{process::ExitCode, sync::Arc};
|
||||
|
||||
use futures_util::StreamExt;
|
||||
use mlua::prelude::*;
|
||||
|
||||
use tokio::task::LocalSet;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::lune::util::traits::LuaEmitErrorExt;
|
||||
|
||||
use super::Scheduler;
|
||||
|
||||
impl<'fut> Scheduler<'fut> {
|
||||
/**
|
||||
Runs all lua threads to completion.
|
||||
*/
|
||||
fn run_lua_threads(&self, lua: &Lua) {
|
||||
if self.state.has_exit_code() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut count = 0;
|
||||
|
||||
// Pop threads from the scheduler until there are none left
|
||||
while let Some(thread) = self
|
||||
.pop_thread()
|
||||
.expect("Failed to pop thread from scheduler")
|
||||
{
|
||||
// Deconstruct the scheduler thread into its parts
|
||||
let thread_id = thread.id();
|
||||
let (thread, args) = thread.into_inner(lua);
|
||||
|
||||
// Make sure this thread is still resumable, it might have
|
||||
// been resumed somewhere else or even have been cancelled
|
||||
if thread.status() != LuaThreadStatus::Resumable {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Resume the thread, ensuring that the schedulers
|
||||
// current thread id is set correctly for error catching
|
||||
self.state.set_current_thread_id(Some(thread_id));
|
||||
let res = thread.resume::<_, LuaMultiValue>(args);
|
||||
self.state.set_current_thread_id(None);
|
||||
|
||||
count += 1;
|
||||
|
||||
// If we got any resumption (lua-side) error, increment
|
||||
// the error count of the scheduler so we can exit with
|
||||
// a non-zero exit code, and print it out to stderr
|
||||
if let Err(err) = &res {
|
||||
self.state.increment_error_count();
|
||||
lua.emit_error(err.clone());
|
||||
}
|
||||
|
||||
// If the thread has finished running completely,
|
||||
// send results of final resume to any listeners
|
||||
if thread.status() != LuaThreadStatus::Resumable {
|
||||
// NOTE: Threads that were spawned to resume
|
||||
// with an error will not have a result sender
|
||||
if let Some(sender) = self
|
||||
.thread_senders
|
||||
.try_lock()
|
||||
.expect("Failed to get thread senders")
|
||||
.remove(&thread_id)
|
||||
{
|
||||
if sender.receiver_count() > 0 {
|
||||
let stored = match res {
|
||||
Err(e) => Err(e),
|
||||
Ok(v) => Ok(Arc::new(lua.create_registry_value(v.into_vec()).expect(
|
||||
"Failed to store thread results in registry - out of memory",
|
||||
))),
|
||||
};
|
||||
sender
|
||||
.send(stored)
|
||||
.expect("Failed to broadcast thread results");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if self.state.has_exit_code() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
debug! {
|
||||
%count,
|
||||
"resumed lua"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
Runs the next lua future to completion.
|
||||
|
||||
Panics if no lua future is queued.
|
||||
*/
|
||||
async fn run_future_lua(&self) {
|
||||
let mut futs = self
|
||||
.futures_lua
|
||||
.try_lock()
|
||||
.expect("Failed to lock lua futures for resumption");
|
||||
assert!(futs.len() > 0, "No lua futures are queued");
|
||||
futs.next().await;
|
||||
}
|
||||
|
||||
/**
|
||||
Runs the next background future to completion.
|
||||
|
||||
Panics if no background future is queued.
|
||||
*/
|
||||
async fn run_future_background(&self) {
|
||||
let mut futs = self
|
||||
.futures_background
|
||||
.try_lock()
|
||||
.expect("Failed to lock background futures for resumption");
|
||||
assert!(futs.len() > 0, "No background futures are queued");
|
||||
futs.next().await;
|
||||
}
|
||||
|
||||
/**
|
||||
Runs as many futures as possible, until a new lua thread
|
||||
is ready, or an exit code has been set for the scheduler.
|
||||
|
||||
### Implementation details
|
||||
|
||||
Running futures on our scheduler consists of a couple moving parts:
|
||||
|
||||
1. An unordered futures queue for lua (main thread, local) futures
|
||||
2. An unordered futures queue for background (multithreaded, 'static lifetime) futures
|
||||
3. A signal for breaking out of futures resumption
|
||||
|
||||
The two unordered futures queues need to run concurrently,
|
||||
but since `FuturesUnordered` returns instantly if it does
|
||||
not currently have any futures queued on it, we need to do
|
||||
this branching loop, checking if each queue has futures first.
|
||||
|
||||
We also need to listen for our signal, to see if we should break out of resumption:
|
||||
|
||||
* Always break out of resumption if a new lua thread is ready
|
||||
* Always break out of resumption if an exit code has been set
|
||||
* Break out of lua futures resumption if we have a new background future
|
||||
* Break out of background futures resumption if we have a new lua future
|
||||
|
||||
We need to listen for both future queues concurrently,
|
||||
and break out whenever the other corresponding queue has
|
||||
a new future, since the other queue may resume sooner.
|
||||
*/
|
||||
async fn run_futures(&self) {
|
||||
let (mut has_lua, mut has_background) = self.has_futures();
|
||||
if !has_lua && !has_background {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut rx = self.state.message_receiver();
|
||||
let mut count = 0;
|
||||
|
||||
while has_lua || has_background {
|
||||
if has_lua && has_background {
|
||||
tokio::select! {
|
||||
_ = self.run_future_lua() => {},
|
||||
_ = self.run_future_background() => {},
|
||||
msg = rx.recv() => {
|
||||
if let Some(msg) = msg {
|
||||
if msg.should_break_futures() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
count += 1;
|
||||
} else if has_lua {
|
||||
tokio::select! {
|
||||
_ = self.run_future_lua() => {},
|
||||
msg = rx.recv() => {
|
||||
if let Some(msg) = msg {
|
||||
if msg.should_break_lua_futures() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
count += 1;
|
||||
} else if has_background {
|
||||
tokio::select! {
|
||||
_ = self.run_future_background() => {},
|
||||
msg = rx.recv() => {
|
||||
if let Some(msg) = msg {
|
||||
if msg.should_break_background_futures() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
count += 1;
|
||||
}
|
||||
(has_lua, has_background) = self.has_futures();
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
debug! {
|
||||
%count,
|
||||
"resumed lua futures"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
Runs the scheduler to completion in a [`LocalSet`],
|
||||
both normal lua threads and futures, prioritizing
|
||||
lua threads over completion of any pending futures.
|
||||
|
||||
Will emit lua output and errors to stdout and stderr.
|
||||
*/
|
||||
pub async fn run_to_completion(&self, lua: &Lua) -> ExitCode {
|
||||
if let Some(code) = self.state.exit_code() {
|
||||
return ExitCode::from(code);
|
||||
}
|
||||
|
||||
let set = LocalSet::new();
|
||||
let _guard = set.enter();
|
||||
|
||||
loop {
|
||||
// 1. Run lua threads until exit or there are none left
|
||||
self.run_lua_threads(lua);
|
||||
|
||||
// 2. If we got a manual exit code from lua we should
|
||||
// not try to wait for any pending futures to complete
|
||||
if self.state.has_exit_code() {
|
||||
break;
|
||||
}
|
||||
|
||||
// 3. Keep resuming futures until there are no futures left to
|
||||
// resume, or until we manually break out of resumption for any
|
||||
// reason, this may be because a future spawned a new lua thread
|
||||
self.run_futures().await;
|
||||
|
||||
// 4. Once again, check for an exit code, in case a future sets one
|
||||
if self.state.has_exit_code() {
|
||||
break;
|
||||
}
|
||||
|
||||
// 5. If we have no lua threads or futures remaining,
|
||||
// we have now run the scheduler until completion
|
||||
let (has_future_lua, has_future_background) = self.has_futures();
|
||||
if !has_future_lua && !has_future_background && !self.has_thread() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(code) = self.state.exit_code() {
|
||||
debug! {
|
||||
%code,
|
||||
"scheduler ran to completion"
|
||||
};
|
||||
ExitCode::from(code)
|
||||
} else if self.state.has_errored() {
|
||||
debug!("scheduler ran to completion, with failure");
|
||||
ExitCode::FAILURE
|
||||
} else {
|
||||
debug!("scheduler ran to completion, with success");
|
||||
ExitCode::SUCCESS
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,185 +0,0 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use mlua::prelude::*;
|
||||
|
||||
use super::{
|
||||
thread::{SchedulerThread, SchedulerThreadId, SchedulerThreadSender},
|
||||
IntoLuaThread, Scheduler,
|
||||
};
|
||||
|
||||
impl<'fut> Scheduler<'fut> {
|
||||
/**
|
||||
Checks if there are any lua threads to run.
|
||||
*/
|
||||
pub(super) fn has_thread(&self) -> bool {
|
||||
!self
|
||||
.threads
|
||||
.try_lock()
|
||||
.expect("Failed to lock threads vec")
|
||||
.is_empty()
|
||||
}
|
||||
|
||||
/**
|
||||
Pops the next thread to run, from the front of the scheduler.
|
||||
|
||||
Returns `None` if there are no threads left to run.
|
||||
*/
|
||||
pub(super) fn pop_thread(&self) -> LuaResult<Option<SchedulerThread>> {
|
||||
match self
|
||||
.threads
|
||||
.try_lock()
|
||||
.into_lua_err()
|
||||
.context("Failed to lock threads vec")?
|
||||
.pop_front()
|
||||
{
|
||||
Some(thread) => Ok(Some(thread)),
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
Schedules the `thread` to be resumed with the given [`LuaError`].
|
||||
*/
|
||||
pub fn push_err<'a>(
|
||||
&self,
|
||||
lua: &'a Lua,
|
||||
thread: impl IntoLuaThread<'a>,
|
||||
err: LuaError,
|
||||
) -> LuaResult<()> {
|
||||
let thread = thread.into_lua_thread(lua)?;
|
||||
let args = LuaMultiValue::new(); // Will be resumed with error, don't need real args
|
||||
|
||||
let thread = SchedulerThread::new(lua, thread, args);
|
||||
let thread_id = thread.id();
|
||||
|
||||
self.state.set_thread_error(thread_id, err);
|
||||
self.threads
|
||||
.try_lock()
|
||||
.into_lua_err()
|
||||
.context("Failed to lock threads vec")?
|
||||
.push_front(thread);
|
||||
|
||||
// NOTE: We might be resuming futures, need to signal that a
|
||||
// new lua thread is ready to break out of futures resumption
|
||||
self.state.message_sender().send_pushed_lua_thread();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/**
|
||||
Schedules the `thread` to be resumed with the given `args`
|
||||
right away, before any other currently scheduled threads.
|
||||
*/
|
||||
pub fn push_front<'a>(
|
||||
&self,
|
||||
lua: &'a Lua,
|
||||
thread: impl IntoLuaThread<'a>,
|
||||
args: impl IntoLuaMulti<'a>,
|
||||
) -> LuaResult<SchedulerThreadId> {
|
||||
let thread = thread.into_lua_thread(lua)?;
|
||||
let args = args.into_lua_multi(lua)?;
|
||||
|
||||
let thread = SchedulerThread::new(lua, thread, args);
|
||||
let thread_id = thread.id();
|
||||
|
||||
self.threads
|
||||
.try_lock()
|
||||
.into_lua_err()
|
||||
.context("Failed to lock threads vec")?
|
||||
.push_front(thread);
|
||||
|
||||
// NOTE: We might be resuming the same thread several times and
|
||||
// pushing it to the scheduler several times before it is done,
|
||||
// and we should only ever create one result sender per thread
|
||||
self.thread_senders
|
||||
.try_lock()
|
||||
.into_lua_err()
|
||||
.context("Failed to lock thread senders vec")?
|
||||
.entry(thread_id)
|
||||
.or_insert_with(|| SchedulerThreadSender::new(1));
|
||||
|
||||
// NOTE: We might be resuming futures, need to signal that a
|
||||
// new lua thread is ready to break out of futures resumption
|
||||
self.state.message_sender().send_pushed_lua_thread();
|
||||
|
||||
Ok(thread_id)
|
||||
}
|
||||
|
||||
/**
|
||||
Schedules the `thread` to be resumed with the given `args`
|
||||
after all other current threads have been resumed.
|
||||
*/
|
||||
pub fn push_back<'a>(
|
||||
&self,
|
||||
lua: &'a Lua,
|
||||
thread: impl IntoLuaThread<'a>,
|
||||
args: impl IntoLuaMulti<'a>,
|
||||
) -> LuaResult<SchedulerThreadId> {
|
||||
let thread = thread.into_lua_thread(lua)?;
|
||||
let args = args.into_lua_multi(lua)?;
|
||||
|
||||
let thread = SchedulerThread::new(lua, thread, args);
|
||||
let thread_id = thread.id();
|
||||
|
||||
self.threads
|
||||
.try_lock()
|
||||
.into_lua_err()
|
||||
.context("Failed to lock threads vec")?
|
||||
.push_back(thread);
|
||||
|
||||
// NOTE: We might be resuming the same thread several times and
|
||||
// pushing it to the scheduler several times before it is done,
|
||||
// and we should only ever create one result sender per thread
|
||||
self.thread_senders
|
||||
.try_lock()
|
||||
.into_lua_err()
|
||||
.context("Failed to lock thread senders vec")?
|
||||
.entry(thread_id)
|
||||
.or_insert_with(|| SchedulerThreadSender::new(1));
|
||||
|
||||
// NOTE: We might be resuming futures, need to signal that a
|
||||
// new lua thread is ready to break out of futures resumption
|
||||
self.state.message_sender().send_pushed_lua_thread();
|
||||
|
||||
Ok(thread_id)
|
||||
}
|
||||
|
||||
/**
|
||||
Waits for the given thread to finish running, and returns its result.
|
||||
*/
|
||||
pub async fn wait_for_thread<'a>(
|
||||
&self,
|
||||
lua: &'a Lua,
|
||||
thread_id: SchedulerThreadId,
|
||||
) -> LuaResult<LuaMultiValue<'a>> {
|
||||
let mut recv = {
|
||||
let senders = self.thread_senders.lock().await;
|
||||
let sender = senders
|
||||
.get(&thread_id)
|
||||
.expect("Tried to wait for thread that is not queued");
|
||||
sender.subscribe()
|
||||
};
|
||||
let res = match recv.recv().await {
|
||||
Err(_) => panic!("Sender was dropped while waiting for {thread_id:?}"),
|
||||
Ok(r) => r,
|
||||
};
|
||||
match res {
|
||||
Err(e) => Err(e),
|
||||
Ok(k) => {
|
||||
let vals = lua
|
||||
.registry_value::<Vec<LuaValue>>(&k)
|
||||
.expect("Received invalid registry key for thread");
|
||||
|
||||
// NOTE: This is not strictly necessary, mlua can clean
|
||||
// up registry values on its own, but doing this will add
|
||||
// some extra safety and clean up registry values faster
|
||||
if let Some(key) = Arc::into_inner(k) {
|
||||
lua.remove_registry_value(key)
|
||||
.expect("Failed to remove registry key for thread");
|
||||
}
|
||||
|
||||
Ok(LuaMultiValue::from_vec(vals))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,98 +0,0 @@
|
|||
use std::sync::{MutexGuard, TryLockError};
|
||||
|
||||
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
|
||||
|
||||
use super::state::SchedulerState;
|
||||
|
||||
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
||||
pub(crate) enum SchedulerMessage {
|
||||
ExitCodeSet,
|
||||
PushedLuaThread,
|
||||
SpawnedLuaFuture,
|
||||
SpawnedBackgroundFuture,
|
||||
}
|
||||
|
||||
impl SchedulerMessage {
|
||||
pub fn should_break_futures(self) -> bool {
|
||||
matches!(self, Self::ExitCodeSet | Self::PushedLuaThread)
|
||||
}
|
||||
|
||||
pub fn should_break_lua_futures(self) -> bool {
|
||||
self.should_break_futures() || matches!(self, Self::SpawnedBackgroundFuture)
|
||||
}
|
||||
|
||||
pub fn should_break_background_futures(self) -> bool {
|
||||
self.should_break_futures() || matches!(self, Self::SpawnedLuaFuture)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
A message sender for the scheduler.
|
||||
|
||||
As long as this sender is not dropped, the scheduler
|
||||
will be kept alive, waiting for more messages to arrive.
|
||||
*/
|
||||
pub(crate) struct SchedulerMessageSender(UnboundedSender<SchedulerMessage>);
|
||||
|
||||
impl SchedulerMessageSender {
|
||||
/**
|
||||
Creates a new message sender for the scheduler.
|
||||
*/
|
||||
pub fn new(state: &SchedulerState) -> Self {
|
||||
Self(
|
||||
state
|
||||
.message_sender
|
||||
.lock()
|
||||
.expect("Scheduler state was poisoned")
|
||||
.clone(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn send_exit_code_set(&self) {
|
||||
self.0.send(SchedulerMessage::ExitCodeSet).ok();
|
||||
}
|
||||
|
||||
pub fn send_pushed_lua_thread(&self) {
|
||||
self.0.send(SchedulerMessage::PushedLuaThread).ok();
|
||||
}
|
||||
|
||||
pub fn send_spawned_lua_future(&self) {
|
||||
self.0.send(SchedulerMessage::SpawnedLuaFuture).ok();
|
||||
}
|
||||
|
||||
pub fn send_spawned_background_future(&self) {
|
||||
self.0.send(SchedulerMessage::SpawnedBackgroundFuture).ok();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
A message receiver for the scheduler.
|
||||
|
||||
Only one message receiver may exist per scheduler.
|
||||
*/
|
||||
pub(crate) struct SchedulerMessageReceiver<'a>(MutexGuard<'a, UnboundedReceiver<SchedulerMessage>>);
|
||||
|
||||
impl<'a> SchedulerMessageReceiver<'a> {
|
||||
/**
|
||||
Creates a new message receiver for the scheduler.
|
||||
|
||||
Panics if the message receiver is already being used.
|
||||
*/
|
||||
pub fn new(state: &'a SchedulerState) -> Self {
|
||||
Self(match state.message_receiver.try_lock() {
|
||||
Err(TryLockError::Poisoned(_)) => panic!("Sheduler state was poisoned"),
|
||||
Err(TryLockError::WouldBlock) => {
|
||||
panic!("Message receiver may only be borrowed once at a time")
|
||||
}
|
||||
Ok(guard) => guard,
|
||||
})
|
||||
}
|
||||
|
||||
// NOTE: Holding this lock across await points is fine, since we
|
||||
// can only ever create lock exactly one SchedulerMessageReceiver
|
||||
// See above constructor for details on this
|
||||
#[allow(clippy::await_holding_lock)]
|
||||
pub async fn recv(&mut self) -> Option<SchedulerMessage> {
|
||||
self.0.recv().await
|
||||
}
|
||||
}
|
|
@ -1,120 +0,0 @@
|
|||
use std::{
|
||||
collections::{HashMap, VecDeque},
|
||||
pin::Pin,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use futures_util::{stream::FuturesUnordered, Future};
|
||||
use mlua::prelude::*;
|
||||
use tokio::sync::Mutex as AsyncMutex;
|
||||
|
||||
mod message;
|
||||
mod state;
|
||||
mod thread;
|
||||
mod traits;
|
||||
|
||||
mod impl_async;
|
||||
mod impl_runner;
|
||||
mod impl_threads;
|
||||
|
||||
pub use self::thread::SchedulerThreadId;
|
||||
pub use self::traits::*;
|
||||
|
||||
use self::{
|
||||
state::SchedulerState,
|
||||
thread::{SchedulerThread, SchedulerThreadSender},
|
||||
};
|
||||
|
||||
type SchedulerFuture<'fut> = Pin<Box<dyn Future<Output = ()> + 'fut>>;
|
||||
|
||||
/**
|
||||
Scheduler for Lua threads and futures.
|
||||
|
||||
This scheduler can be cheaply cloned and the underlying state
|
||||
and data will remain unchanged and accessible from all clones.
|
||||
*/
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct Scheduler<'fut> {
|
||||
state: Arc<SchedulerState>,
|
||||
threads: Arc<AsyncMutex<VecDeque<SchedulerThread>>>,
|
||||
thread_senders: Arc<AsyncMutex<HashMap<SchedulerThreadId, SchedulerThreadSender>>>,
|
||||
/*
|
||||
FUTURE: Get rid of these, let the tokio runtime handle running
|
||||
and resumption of futures completely, just use our scheduler
|
||||
state and receiver to know when we have run to completion.
|
||||
If we have no senders left, we have run to completion.
|
||||
|
||||
We should also investigate using smol / async-executor and its
|
||||
LocalExecutor struct which does not impose the 'static lifetime
|
||||
restriction on all of the futures spawned on it, unlike tokio.
|
||||
|
||||
If we no longer store futures directly in our scheduler, we
|
||||
can get rid of the lifetime on it, store it in our lua app
|
||||
data as a Weak<Scheduler>, together with a Weak<Lua>.
|
||||
|
||||
In our lua async functions we can then get a reference to this,
|
||||
upgrade it to an Arc<Scheduler> and Arc<Lua> to extend lifetimes,
|
||||
and hopefully get rid of Box::leak and 'static lifetimes for good.
|
||||
|
||||
Relevant comment on the mlua repository:
|
||||
https://github.com/khvzak/mlua/issues/169#issuecomment-1138863979
|
||||
*/
|
||||
futures_lua: Arc<AsyncMutex<FuturesUnordered<SchedulerFuture<'fut>>>>,
|
||||
futures_background: Arc<AsyncMutex<FuturesUnordered<SchedulerFuture<'static>>>>,
|
||||
}
|
||||
|
||||
impl<'fut> Scheduler<'fut> {
|
||||
/**
|
||||
Creates a new scheduler.
|
||||
*/
|
||||
#[allow(clippy::arc_with_non_send_sync)] // FIXME: Clippy lints our tokio mutexes that are definitely Send + Sync
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
state: Arc::new(SchedulerState::new()),
|
||||
threads: Arc::new(AsyncMutex::new(VecDeque::new())),
|
||||
thread_senders: Arc::new(AsyncMutex::new(HashMap::new())),
|
||||
futures_lua: Arc::new(AsyncMutex::new(FuturesUnordered::new())),
|
||||
futures_background: Arc::new(AsyncMutex::new(FuturesUnordered::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
Sets the luau interrupt for this scheduler.
|
||||
|
||||
This will propagate errors from any lua-spawned
|
||||
futures back to the lua threads that spawned them.
|
||||
*/
|
||||
pub fn set_interrupt_for(&self, lua: &Lua) {
|
||||
// Propagate errors given to the scheduler back to their lua threads
|
||||
// FUTURE: Do profiling and anything else we need inside of this interrupt
|
||||
let state = self.state.clone();
|
||||
lua.set_interrupt(move |_| {
|
||||
if let Some(id) = state.get_current_thread_id() {
|
||||
if let Some(err) = state.get_thread_error(id) {
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
Ok(LuaVmState::Continue)
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
Sets the exit code for the scheduler.
|
||||
|
||||
This will stop the scheduler from resuming any more lua threads or futures.
|
||||
|
||||
Panics if the exit code is set more than once.
|
||||
*/
|
||||
pub fn set_exit_code(&self, code: impl Into<u8>) {
|
||||
assert!(
|
||||
self.state.exit_code().is_none(),
|
||||
"Exit code may only be set exactly once"
|
||||
);
|
||||
self.state.set_exit_code(code.into());
|
||||
}
|
||||
|
||||
#[doc(hidden)]
|
||||
pub fn into_static(self) -> &'static Self {
|
||||
Box::leak(Box::new(self))
|
||||
}
|
||||
}
|
|
@ -1,176 +0,0 @@
|
|||
use std::{
|
||||
collections::HashMap,
|
||||
sync::{
|
||||
atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering},
|
||||
Arc, Mutex,
|
||||
},
|
||||
};
|
||||
|
||||
use mlua::Error as LuaError;
|
||||
|
||||
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
|
||||
|
||||
use super::{
|
||||
message::{SchedulerMessage, SchedulerMessageReceiver, SchedulerMessageSender},
|
||||
SchedulerThreadId,
|
||||
};
|
||||
|
||||
/**
|
||||
Internal state for a [`Scheduler`].
|
||||
|
||||
This scheduler state uses atomic operations for everything
|
||||
except lua error storage, and is completely thread safe.
|
||||
*/
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct SchedulerState {
|
||||
exit_state: AtomicBool,
|
||||
exit_code: AtomicU8,
|
||||
num_resumptions: AtomicUsize,
|
||||
num_errors: AtomicUsize,
|
||||
thread_id: Arc<Mutex<Option<SchedulerThreadId>>>,
|
||||
thread_errors: Arc<Mutex<HashMap<SchedulerThreadId, LuaError>>>,
|
||||
pub(super) message_sender: Arc<Mutex<UnboundedSender<SchedulerMessage>>>,
|
||||
pub(super) message_receiver: Arc<Mutex<UnboundedReceiver<SchedulerMessage>>>,
|
||||
}
|
||||
|
||||
impl SchedulerState {
|
||||
/**
|
||||
Creates a new scheduler state.
|
||||
*/
|
||||
pub fn new() -> Self {
|
||||
let (message_sender, message_receiver) = unbounded_channel();
|
||||
|
||||
Self {
|
||||
exit_state: AtomicBool::new(false),
|
||||
exit_code: AtomicU8::new(0),
|
||||
num_resumptions: AtomicUsize::new(0),
|
||||
num_errors: AtomicUsize::new(0),
|
||||
thread_id: Arc::new(Mutex::new(None)),
|
||||
thread_errors: Arc::new(Mutex::new(HashMap::new())),
|
||||
message_sender: Arc::new(Mutex::new(message_sender)),
|
||||
message_receiver: Arc::new(Mutex::new(message_receiver)),
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
Increments the total lua error count for the scheduler.
|
||||
|
||||
This is used to determine if the scheduler should exit with
|
||||
a non-zero exit code, when no exit code is explicitly set.
|
||||
*/
|
||||
pub fn increment_error_count(&self) {
|
||||
self.num_errors.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
|
||||
/**
|
||||
Checks if there have been any lua errors.
|
||||
|
||||
This is used to determine if the scheduler should exit with
|
||||
a non-zero exit code, when no exit code is explicitly set.
|
||||
*/
|
||||
pub fn has_errored(&self) -> bool {
|
||||
self.num_errors.load(Ordering::SeqCst) > 0
|
||||
}
|
||||
|
||||
/**
|
||||
Gets the currently set exit code for the scheduler, if any.
|
||||
*/
|
||||
pub fn exit_code(&self) -> Option<u8> {
|
||||
if self.exit_state.load(Ordering::SeqCst) {
|
||||
Some(self.exit_code.load(Ordering::SeqCst))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
Checks if the scheduler has an explicit exit code set.
|
||||
*/
|
||||
pub fn has_exit_code(&self) -> bool {
|
||||
self.exit_state.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
/**
|
||||
Sets the explicit exit code for the scheduler.
|
||||
*/
|
||||
pub fn set_exit_code(&self, code: impl Into<u8>) {
|
||||
self.exit_state.store(true, Ordering::SeqCst);
|
||||
self.exit_code.store(code.into(), Ordering::SeqCst);
|
||||
self.message_sender().send_exit_code_set();
|
||||
}
|
||||
|
||||
/**
|
||||
Gets the currently running lua scheduler thread id, if any.
|
||||
*/
|
||||
pub fn get_current_thread_id(&self) -> Option<SchedulerThreadId> {
|
||||
*self
|
||||
.thread_id
|
||||
.lock()
|
||||
.expect("Failed to lock current thread id")
|
||||
}
|
||||
|
||||
/**
|
||||
Sets the currently running lua scheduler thread id.
|
||||
|
||||
This must be set to `Some(id)` just before resuming a lua thread,
|
||||
and `None` while no lua thread is being resumed. If set to `Some`
|
||||
while the current thread id is also `Some`, this will panic.
|
||||
|
||||
Must only be set once per thread id, although this
|
||||
is not checked at runtime for performance reasons.
|
||||
*/
|
||||
pub fn set_current_thread_id(&self, id: Option<SchedulerThreadId>) {
|
||||
self.num_resumptions.fetch_add(1, Ordering::Relaxed);
|
||||
let mut thread_id = self
|
||||
.thread_id
|
||||
.lock()
|
||||
.expect("Failed to lock current thread id");
|
||||
assert!(
|
||||
id.is_none() || thread_id.is_none(),
|
||||
"Current thread id can not be overwritten"
|
||||
);
|
||||
*thread_id = id;
|
||||
}
|
||||
|
||||
/**
|
||||
Gets the [`LuaError`] (if any) for the given `id`.
|
||||
|
||||
Note that this removes the error from the scheduler state completely.
|
||||
*/
|
||||
pub fn get_thread_error(&self, id: SchedulerThreadId) -> Option<LuaError> {
|
||||
let mut thread_errors = self
|
||||
.thread_errors
|
||||
.lock()
|
||||
.expect("Failed to lock thread errors");
|
||||
thread_errors.remove(&id)
|
||||
}
|
||||
|
||||
/**
|
||||
Sets a [`LuaError`] for the given `id`.
|
||||
|
||||
Note that this will replace any already existing [`LuaError`].
|
||||
*/
|
||||
pub fn set_thread_error(&self, id: SchedulerThreadId, err: LuaError) {
|
||||
let mut thread_errors = self
|
||||
.thread_errors
|
||||
.lock()
|
||||
.expect("Failed to lock thread errors");
|
||||
thread_errors.insert(id, err);
|
||||
}
|
||||
|
||||
/**
|
||||
Creates a new message sender for the scheduler.
|
||||
*/
|
||||
pub fn message_sender(&self) -> SchedulerMessageSender {
|
||||
SchedulerMessageSender::new(self)
|
||||
}
|
||||
|
||||
/**
|
||||
Tries to borrow the message receiver for the scheduler.
|
||||
|
||||
Panics if the message receiver is already being used.
|
||||
*/
|
||||
pub fn message_receiver(&self) -> SchedulerMessageReceiver {
|
||||
SchedulerMessageReceiver::new(self)
|
||||
}
|
||||
}
|
|
@ -1,105 +0,0 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use mlua::prelude::*;
|
||||
use tokio::sync::broadcast::Sender;
|
||||
|
||||
/**
|
||||
Type alias for a broadcast [`Sender`], which will
|
||||
broadcast the result and return values of a lua thread.
|
||||
|
||||
The return values are stored in the lua registry as a
|
||||
`Vec<LuaValue<'_>>`, and the registry key pointing to
|
||||
those values will be sent using the broadcast sender.
|
||||
*/
|
||||
pub type SchedulerThreadSender = Sender<LuaResult<Arc<LuaRegistryKey>>>;
|
||||
|
||||
/**
|
||||
Unique, randomly generated id for a scheduler thread.
|
||||
*/
|
||||
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
||||
pub struct SchedulerThreadId(usize);
|
||||
|
||||
impl From<&LuaThread<'_>> for SchedulerThreadId {
|
||||
fn from(value: &LuaThread) -> Self {
|
||||
// HACK: We rely on the debug format of mlua
|
||||
// thread refs here, but currently this is the
|
||||
// only way to get a proper unique id using mlua
|
||||
let addr_string = format!("{value:?}");
|
||||
let addr = addr_string
|
||||
.strip_prefix("Thread(Ref(0x")
|
||||
.expect("Invalid thread address format - unknown prefix")
|
||||
.split_once(')')
|
||||
.map(|(s, _)| s)
|
||||
.expect("Invalid thread address format - missing ')'");
|
||||
let id = usize::from_str_radix(addr, 16)
|
||||
.expect("Failed to parse thread address as hexadecimal into usize");
|
||||
Self(id)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
Container for registry keys that point to a thread and thread arguments.
|
||||
*/
|
||||
#[derive(Debug)]
|
||||
pub(super) struct SchedulerThread {
|
||||
thread_id: SchedulerThreadId,
|
||||
key_thread: LuaRegistryKey,
|
||||
key_args: LuaRegistryKey,
|
||||
}
|
||||
|
||||
impl SchedulerThread {
|
||||
/**
|
||||
Creates a new scheduler thread container from the given thread and arguments.
|
||||
|
||||
May fail if an allocation error occurs, is not fallible otherwise.
|
||||
*/
|
||||
pub(super) fn new<'lua>(
|
||||
lua: &'lua Lua,
|
||||
thread: LuaThread<'lua>,
|
||||
args: LuaMultiValue<'lua>,
|
||||
) -> Self {
|
||||
let args_vec = args.into_vec();
|
||||
let thread_id = SchedulerThreadId::from(&thread);
|
||||
|
||||
let key_thread = lua
|
||||
.create_registry_value(thread)
|
||||
.expect("Failed to store thread in registry - out of memory");
|
||||
let key_args = lua
|
||||
.create_registry_value(args_vec)
|
||||
.expect("Failed to store thread args in registry - out of memory");
|
||||
|
||||
Self {
|
||||
thread_id,
|
||||
key_thread,
|
||||
key_args,
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
Extracts the inner thread and args from the container.
|
||||
*/
|
||||
pub(super) fn into_inner(self, lua: &Lua) -> (LuaThread<'_>, LuaMultiValue<'_>) {
|
||||
let thread = lua
|
||||
.registry_value(&self.key_thread)
|
||||
.expect("Failed to get thread from registry");
|
||||
let args_vec = lua
|
||||
.registry_value(&self.key_args)
|
||||
.expect("Failed to get thread args from registry");
|
||||
|
||||
let args = LuaMultiValue::from_vec(args_vec);
|
||||
|
||||
lua.remove_registry_value(self.key_thread)
|
||||
.expect("Failed to remove thread from registry");
|
||||
lua.remove_registry_value(self.key_args)
|
||||
.expect("Failed to remove thread args from registry");
|
||||
|
||||
(thread, args)
|
||||
}
|
||||
|
||||
/**
|
||||
Retrieves the unique, randomly generated id for this scheduler thread.
|
||||
*/
|
||||
pub(super) fn id(&self) -> SchedulerThreadId {
|
||||
self.thread_id
|
||||
}
|
||||
}
|
|
@ -1,118 +0,0 @@
|
|||
use futures_util::Future;
|
||||
use mlua::prelude::*;
|
||||
|
||||
use super::Scheduler;
|
||||
|
||||
const ASYNC_IMPL_LUA: &str = r#"
|
||||
schedule(...)
|
||||
return yield()
|
||||
"#;
|
||||
|
||||
/**
|
||||
Trait for extensions to the [`Lua`] struct, allowing
|
||||
for access to the scheduler without having to import
|
||||
it or handle registry / app data references manually.
|
||||
*/
|
||||
pub(crate) trait LuaSchedulerExt<'lua> {
|
||||
/**
|
||||
Sets the scheduler for the [`Lua`] struct.
|
||||
*/
|
||||
fn set_scheduler(&'lua self, scheduler: &'lua Scheduler);
|
||||
|
||||
/**
|
||||
Creates a function callable from Lua that runs an async
|
||||
closure and returns the results of it to the call site.
|
||||
*/
|
||||
fn create_async_function<A, R, F, FR>(&'lua self, func: F) -> LuaResult<LuaFunction<'lua>>
|
||||
where
|
||||
A: FromLuaMulti<'lua>,
|
||||
R: IntoLuaMulti<'lua>,
|
||||
F: Fn(&'lua Lua, A) -> FR + 'lua,
|
||||
FR: Future<Output = LuaResult<R>> + 'lua;
|
||||
}
|
||||
|
||||
// FIXME: `self` escapes outside of method because we are borrowing `func`
|
||||
// when we call `schedule_future_thread` in the lua function body below
|
||||
// For now we solve this by using the 'static lifetime bound in the impl
|
||||
impl<'lua> LuaSchedulerExt<'lua> for Lua
|
||||
where
|
||||
'lua: 'static,
|
||||
{
|
||||
fn set_scheduler(&'lua self, scheduler: &'lua Scheduler) {
|
||||
self.set_app_data(scheduler);
|
||||
scheduler.set_interrupt_for(self);
|
||||
}
|
||||
|
||||
fn create_async_function<A, R, F, FR>(&'lua self, func: F) -> LuaResult<LuaFunction<'lua>>
|
||||
where
|
||||
A: FromLuaMulti<'lua>,
|
||||
R: IntoLuaMulti<'lua>,
|
||||
F: Fn(&'lua Lua, A) -> FR + 'lua,
|
||||
FR: Future<Output = LuaResult<R>> + 'lua,
|
||||
{
|
||||
self.app_data_ref::<&Scheduler>()
|
||||
.expect("Lua must have a scheduler to create async functions");
|
||||
|
||||
let async_env = self.create_table_with_capacity(0, 2)?;
|
||||
|
||||
async_env.set(
|
||||
"yield",
|
||||
self.globals()
|
||||
.get::<_, LuaTable>("coroutine")?
|
||||
.get::<_, LuaFunction>("yield")?,
|
||||
)?;
|
||||
|
||||
async_env.set(
|
||||
"schedule",
|
||||
LuaFunction::wrap(move |lua: &Lua, args: A| {
|
||||
let thread = lua.current_thread();
|
||||
let future = func(lua, args);
|
||||
let sched = lua
|
||||
.app_data_ref::<&Scheduler>()
|
||||
.expect("Lua struct is missing scheduler");
|
||||
sched.spawn_thread(lua, thread, future)?;
|
||||
Ok(())
|
||||
}),
|
||||
)?;
|
||||
|
||||
let async_func = self
|
||||
.load(ASYNC_IMPL_LUA)
|
||||
.set_name("async")
|
||||
.set_environment(async_env)
|
||||
.into_function()?;
|
||||
Ok(async_func)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
Trait for any struct that can be turned into an [`LuaThread`]
|
||||
and given to the scheduler, implemented for the following types:
|
||||
|
||||
- Lua threads ([`LuaThread`])
|
||||
- Lua functions ([`LuaFunction`])
|
||||
- Lua chunks ([`LuaChunk`])
|
||||
*/
|
||||
pub trait IntoLuaThread<'lua> {
|
||||
/**
|
||||
Converts the value into a lua thread.
|
||||
*/
|
||||
fn into_lua_thread(self, lua: &'lua Lua) -> LuaResult<LuaThread<'lua>>;
|
||||
}
|
||||
|
||||
impl<'lua> IntoLuaThread<'lua> for LuaThread<'lua> {
|
||||
fn into_lua_thread(self, _: &'lua Lua) -> LuaResult<LuaThread<'lua>> {
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'lua> IntoLuaThread<'lua> for LuaFunction<'lua> {
|
||||
fn into_lua_thread(self, lua: &'lua Lua) -> LuaResult<LuaThread<'lua>> {
|
||||
lua.create_thread(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'lua, 'a> IntoLuaThread<'lua> for LuaChunk<'lua, 'a> {
|
||||
fn into_lua_thread(self, lua: &'lua Lua) -> LuaResult<LuaThread<'lua>> {
|
||||
lua.create_thread(self.into_function()?)
|
||||
}
|
||||
}
|
|
@ -1,18 +0,0 @@
|
|||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
use std::task::{Context, Poll};
|
||||
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct YieldForever;
|
||||
|
||||
impl Future for YieldForever {
|
||||
type Output = ();
|
||||
|
||||
fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
|
||||
Poll::Pending
|
||||
}
|
||||
}
|
||||
|
||||
pub fn yield_forever() -> YieldForever {
|
||||
YieldForever
|
||||
}
|
|
@ -2,7 +2,6 @@ mod table_builder;
|
|||
|
||||
pub mod buffer;
|
||||
pub mod formatting;
|
||||
pub mod futures;
|
||||
pub mod luaurc;
|
||||
pub mod paths;
|
||||
pub mod traits;
|
||||
|
|
|
@ -4,8 +4,6 @@ use std::future::Future;
|
|||
|
||||
use mlua::prelude::*;
|
||||
|
||||
use crate::lune::scheduler::LuaSchedulerExt;
|
||||
|
||||
pub struct TableBuilder<'lua> {
|
||||
lua: &'lua Lua,
|
||||
tab: LuaTable<'lua>,
|
||||
|
@ -79,20 +77,13 @@ impl<'lua> TableBuilder<'lua> {
|
|||
pub fn build(self) -> LuaResult<LuaTable<'lua>> {
|
||||
Ok(self.tab)
|
||||
}
|
||||
}
|
||||
|
||||
// FIXME: Remove static lifetime bound here when `create_async_function`
|
||||
// no longer needs it to compile, then move this into the above impl
|
||||
impl<'lua> TableBuilder<'lua>
|
||||
where
|
||||
'lua: 'static,
|
||||
{
|
||||
pub fn with_async_function<K, A, R, F, FR>(self, key: K, func: F) -> LuaResult<Self>
|
||||
where
|
||||
K: IntoLua<'lua>,
|
||||
A: FromLuaMulti<'lua>,
|
||||
R: IntoLuaMulti<'lua>,
|
||||
F: Fn(&'lua Lua, A) -> FR + 'lua,
|
||||
F: Fn(&'lua Lua, A) -> FR + 'static,
|
||||
FR: Future<Output = LuaResult<R>> + 'lua,
|
||||
{
|
||||
let f = self.lua.create_async_function(func)?;
|
||||
|
|
|
@ -68,6 +68,7 @@ create_tests! {
|
|||
net_url_decode: "net/url/decode",
|
||||
net_serve_requests: "net/serve/requests",
|
||||
net_serve_websockets: "net/serve/websockets",
|
||||
net_socket_basic: "net/socket/basic",
|
||||
net_socket_wss: "net/socket/wss",
|
||||
net_socket_wss_rw: "net/socket/wss_rw",
|
||||
|
||||
|
@ -84,7 +85,6 @@ create_tests! {
|
|||
|
||||
require_aliases: "require/tests/aliases",
|
||||
require_async: "require/tests/async",
|
||||
require_async_background: "require/tests/async_background",
|
||||
require_async_concurrent: "require/tests/async_concurrent",
|
||||
require_async_sequential: "require/tests/async_sequential",
|
||||
require_builtins: "require/tests/builtins",
|
||||
|
@ -95,6 +95,7 @@ create_tests! {
|
|||
require_nested: "require/tests/nested",
|
||||
require_parents: "require/tests/parents",
|
||||
require_siblings: "require/tests/siblings",
|
||||
require_state: "require/tests/state",
|
||||
|
||||
global_g_table: "globals/_G",
|
||||
global_version: "globals/_VERSION",
|
||||
|
|
28
tests/net/socket/basic.luau
Normal file
28
tests/net/socket/basic.luau
Normal file
|
@ -0,0 +1,28 @@
|
|||
local net = require("@lune/net")
|
||||
|
||||
-- We're going to use Discord's WebSocket gateway server for testing
|
||||
local socket = net.socket("wss://gateway.discord.gg/?v=10&encoding=json")
|
||||
|
||||
assert(type(socket.next) == "function", "next must be a function")
|
||||
assert(type(socket.send) == "function", "send must be a function")
|
||||
assert(type(socket.close) == "function", "close must be a function")
|
||||
|
||||
-- Request to close the socket
|
||||
socket.close()
|
||||
|
||||
-- Drain remaining messages, until we got our close message
|
||||
while socket.next() do
|
||||
end
|
||||
|
||||
assert(type(socket.closeCode) == "number", "closeCode should exist after closing")
|
||||
assert(socket.closeCode == 1000, "closeCode should be 1000 after closing")
|
||||
|
||||
local success, message = pcall(function()
|
||||
socket.send("Hello, world!")
|
||||
end)
|
||||
|
||||
assert(not success, "send should fail after closing")
|
||||
assert(
|
||||
string.find(tostring(message), "closed") or string.find(tostring(message), "closing"),
|
||||
"send should fail with a message that the socket was closed"
|
||||
)
|
|
@ -1,51 +0,0 @@
|
|||
local net = require("@lune/net")
|
||||
local process = require("@lune/process")
|
||||
local stdio = require("@lune/stdio")
|
||||
local task = require("@lune/task")
|
||||
|
||||
-- Spawn an asynchronous background task (eg. web server)
|
||||
|
||||
local PORT = 8082
|
||||
|
||||
task.delay(3, function()
|
||||
stdio.ewrite("Test did not complete in time\n")
|
||||
task.wait(1)
|
||||
process.exit(1)
|
||||
end)
|
||||
|
||||
local handle = net.serve(PORT, function(request)
|
||||
return ""
|
||||
end)
|
||||
|
||||
-- Require modules same way we did in the async_concurrent and async_sequential tests
|
||||
|
||||
local module3
|
||||
local module4
|
||||
|
||||
task.defer(function()
|
||||
module4 = require("./modules/async")
|
||||
end)
|
||||
|
||||
task.spawn(function()
|
||||
module3 = require("./modules/async")
|
||||
end)
|
||||
|
||||
local _module1 = require("./modules/async")
|
||||
local _module2 = require("./modules/async")
|
||||
|
||||
task.wait(1)
|
||||
|
||||
assert(type(module3) == "table", "Required module3 did not return a table")
|
||||
assert(module3.Foo == "Bar", "Required module3 did not contain correct values")
|
||||
assert(module3.Hello == "World", "Required module3 did not contain correct values")
|
||||
|
||||
assert(type(module4) == "table", "Required module4 did not return a table")
|
||||
assert(module4.Foo == "Bar", "Required module4 did not contain correct values")
|
||||
assert(module4.Hello == "World", "Required module4 did not contain correct values")
|
||||
|
||||
assert(module3 == module4, "Required modules should point to the same return value")
|
||||
|
||||
-- Stop the server and exit successfully
|
||||
|
||||
handle.stop()
|
||||
process.exit(0)
|
14
tests/require/tests/state.luau
Normal file
14
tests/require/tests/state.luau
Normal file
|
@ -0,0 +1,14 @@
|
|||
-- the idea of this test is that state_module stores some state in one of its local
|
||||
-- variable
|
||||
local state_module = require("./state_module")
|
||||
|
||||
-- we confirm that without anything happening, the initial value is what we expect
|
||||
assert(state_module.state == 10)
|
||||
|
||||
-- this second file also requires state_module and calls a function that changes the local
|
||||
-- state to 11
|
||||
require("./state_second")
|
||||
|
||||
-- with correct module caching, we should see the change done in state_secone reflected
|
||||
-- here
|
||||
assert(state_module.state == 11)
|
9
tests/require/tests/state_module.luau
Normal file
9
tests/require/tests/state_module.luau
Normal file
|
@ -0,0 +1,9 @@
|
|||
local M = {}
|
||||
|
||||
M.state = 10
|
||||
|
||||
function M.set_state(n: number)
|
||||
M.state = n
|
||||
end
|
||||
|
||||
return M
|
5
tests/require/tests/state_second.luau
Normal file
5
tests/require/tests/state_second.luau
Normal file
|
@ -0,0 +1,5 @@
|
|||
local state_module = require("./state_module")
|
||||
|
||||
state_module.set_state(11)
|
||||
|
||||
return {}
|
Loading…
Add table
Reference in a new issue