Rewrite scheduler and make it smol (#165)

This commit is contained in:
Filip Tibell 2024-03-11 19:11:14 +01:00 committed by GitHub
parent 1f211ca0ab
commit cd34dcb0dd
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
44 changed files with 1489 additions and 2525 deletions

806
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -79,11 +79,19 @@ urlencoding = "2.1"
### RUNTIME
blocking = "1.5"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
mlua = { version = "0.9.1", features = ["luau", "luau-jit", "serialize"] }
tokio = { version = "1.24", features = ["full", "tracing"] }
os_str_bytes = { version = "6.4", features = ["conversions"] }
os_str_bytes = { version = "7.0", features = ["conversions"] }
mlua-luau-scheduler = { version = "0.0.2" }
mlua = { version = "0.9.6", features = [
"luau",
"luau-jit",
"async",
"serialize",
] }
### SERDE
@ -101,12 +109,17 @@ toml = { version = "0.8", features = ["preserve_order"] }
### NET
hyper = { version = "0.14", features = ["full"] }
hyper-tungstenite = { version = "0.11" }
hyper = { version = "1.1", features = ["full"] }
hyper-util = { version = "0.1", features = ["full"] }
http = "1.0"
http-body-util = { version = "0.1" }
hyper-tungstenite = { version = "0.13" }
reqwest = { version = "0.11", default-features = false, features = [
"rustls-tls",
] }
tokio-tungstenite = { version = "0.20", features = ["rustls-tls-webpki-roots"] }
tokio-tungstenite = { version = "0.21", features = ["rustls-tls-webpki-roots"] }
### DATETIME
chrono = "0.4"
@ -115,7 +128,7 @@ chrono_lc = "0.1"
### CLI
anyhow = { optional = true, version = "1.0" }
env_logger = { optional = true, version = "0.10" }
env_logger = { optional = true, version = "0.11" }
itertools = { optional = true, version = "0.12" }
clap = { optional = true, version = "4.1", features = ["derive"] }
include_dir = { optional = true, version = "0.7", features = ["glob"] }

View file

@ -14,7 +14,7 @@ use copy::copy;
use metadata::FsMetadata;
use options::FsWriteOptions;
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
pub fn create(lua: &Lua) -> LuaResult<LuaTable> {
TableBuilder::new(lua)?
.with_async_function("readFile", fs_read_file)?
.with_async_function("readDir", fs_read_dir)?

View file

@ -28,10 +28,7 @@ pub enum LuneBuiltin {
Roblox,
}
impl<'lua> LuneBuiltin
where
'lua: 'static, // FIXME: Remove static lifetime bound here when builtin libraries no longer need it
{
impl LuneBuiltin {
pub fn name(&self) -> &'static str {
match self {
Self::DateTime => "datetime",
@ -47,7 +44,7 @@ where
}
}
pub fn create(&self, lua: &'lua Lua) -> LuaResult<LuaMultiValue<'lua>> {
pub fn create<'lua>(&self, lua: &'lua Lua) -> LuaResult<LuaMultiValue<'lua>> {
let res = match self {
Self::DateTime => datetime::create(lua),
Self::Fs => fs::create(lua),

View file

@ -2,8 +2,14 @@ use std::str::FromStr;
use mlua::prelude::*;
use hyper::{header::HeaderName, http::HeaderValue, HeaderMap};
use reqwest::{IntoUrl, Method, RequestBuilder};
use reqwest::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_ENCODING};
use crate::lune::{
builtins::serde::compress_decompress::{decompress, CompressDecompressFormat},
util::TableBuilder,
};
use super::{config::RequestConfig, util::header_map_to_table};
const REGISTRY_KEY: &str = "NetClient";
@ -35,16 +41,19 @@ impl NetClientBuilder {
pub fn build(self) -> LuaResult<NetClient> {
let client = self.builder.build().into_lua_err()?;
Ok(NetClient(client))
Ok(NetClient { inner: client })
}
}
#[derive(Debug, Clone)]
pub struct NetClient(reqwest::Client);
pub struct NetClient {
inner: reqwest::Client,
}
impl NetClient {
pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder {
self.0.request(method, url)
pub fn from_registry(lua: &Lua) -> Self {
lua.named_registry_value(REGISTRY_KEY)
.expect("Failed to get NetClient from lua registry")
}
pub fn into_registry(self, lua: &Lua) {
@ -52,16 +61,68 @@ impl NetClient {
.expect("Failed to store NetClient in lua registry");
}
pub fn from_registry(lua: &Lua) -> Self {
lua.named_registry_value(REGISTRY_KEY)
.expect("Failed to get NetClient from lua registry")
pub async fn request(&self, config: RequestConfig) -> LuaResult<NetClientResponse> {
// Create and send the request
let mut request = self.inner.request(config.method, config.url);
for (query, values) in config.query {
request = request.query(
&values
.iter()
.map(|v| (query.as_str(), v))
.collect::<Vec<_>>(),
);
}
for (header, values) in config.headers {
for value in values {
request = request.header(header.as_str(), value);
}
}
let res = request
.body(config.body.unwrap_or_default())
.send()
.await
.into_lua_err()?;
// Extract status, headers
let res_status = res.status().as_u16();
let res_status_text = res.status().canonical_reason();
let res_headers = res.headers().clone();
// Read response bytes
let mut res_bytes = res.bytes().await.into_lua_err()?.to_vec();
let mut res_decompressed = false;
// Check for extra options, decompression
if config.options.decompress {
let decompress_format = res_headers
.iter()
.find(|(name, _)| {
name.as_str()
.eq_ignore_ascii_case(CONTENT_ENCODING.as_str())
})
.and_then(|(_, value)| value.to_str().ok())
.and_then(CompressDecompressFormat::detect_from_header_str);
if let Some(format) = decompress_format {
res_bytes = decompress(format, res_bytes).await?;
res_decompressed = true;
}
}
Ok(NetClientResponse {
ok: (200..300).contains(&res_status),
status_code: res_status,
status_message: res_status_text.unwrap_or_default().to_string(),
headers: res_headers,
body: res_bytes,
body_decompressed: res_decompressed,
})
}
}
impl LuaUserData for NetClient {}
impl<'lua> FromLua<'lua> for NetClient {
fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult<Self> {
impl FromLua<'_> for NetClient {
fn from_lua(value: LuaValue, _: &Lua) -> LuaResult<Self> {
if let LuaValue::UserData(ud) = value {
if let Ok(ctx) = ud.borrow::<NetClient>() {
return Ok(ctx.clone());
@ -71,10 +132,34 @@ impl<'lua> FromLua<'lua> for NetClient {
}
}
impl<'lua> From<&'lua Lua> for NetClient {
fn from(value: &'lua Lua) -> Self {
impl From<&Lua> for NetClient {
fn from(value: &Lua) -> Self {
value
.named_registry_value(REGISTRY_KEY)
.expect("Missing require context in lua registry")
}
}
pub struct NetClientResponse {
ok: bool,
status_code: u16,
status_message: String,
headers: HeaderMap,
body: Vec<u8>,
body_decompressed: bool,
}
impl NetClientResponse {
pub fn into_lua_table(self, lua: &Lua) -> LuaResult<LuaTable> {
TableBuilder::new(lua)?
.with_value("ok", self.ok)?
.with_value("statusCode", self.status_code)?
.with_value("statusMessage", self.status_message)?
.with_value(
"headers",
header_map_to_table(lua, self.headers, self.body_decompressed)?,
)?
.with_value("body", lua.create_string(&self.body)?)?
.build_readonly()
}
}

View file

@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::{collections::HashMap, net::Ipv4Addr};
use mlua::prelude::*;
@ -6,6 +6,18 @@ use reqwest::Method;
use super::util::table_to_hash_map;
const DEFAULT_IP_ADDRESS: Ipv4Addr = Ipv4Addr::new(127, 0, 0, 1);
const WEB_SOCKET_UPDGRADE_REQUEST_HANDLER: &str = r#"
return {
status = 426,
body = "Upgrade Required",
headers = {
Upgrade = "websocket",
},
}
"#;
// Net request config
#[derive(Debug, Clone)]
@ -21,28 +33,29 @@ impl Default for RequestConfigOptions {
impl<'lua> FromLua<'lua> for RequestConfigOptions {
fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult<Self> {
// Nil means default options, table means custom options
if let LuaValue::Nil = value {
return Ok(Self::default());
// Nil means default options
Ok(Self::default())
} else if let LuaValue::Table(tab) = value {
// Extract flags
let decompress = match tab.raw_get::<_, Option<bool>>("decompress") {
// Table means custom options
let decompress = match tab.get::<_, Option<bool>>("decompress") {
Ok(decomp) => Ok(decomp.unwrap_or(true)),
Err(_) => Err(LuaError::RuntimeError(
"Invalid option value for 'decompress' in request config options".to_string(),
)),
}?;
return Ok(Self { decompress });
Ok(Self { decompress })
} else {
// Anything else is invalid
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "RequestConfigOptions",
message: Some(format!(
"Invalid request config options - expected table or nil, got {}",
value.type_name()
)),
})
}
// Anything else is invalid
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "RequestConfigOptions",
message: Some(format!(
"Invalid request config options - expected table or nil, got {}",
value.type_name()
)),
})
}
}
@ -60,39 +73,38 @@ impl FromLua<'_> for RequestConfig {
fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult<Self> {
// If we just got a string we assume its a GET request to a given url
if let LuaValue::String(s) = value {
return Ok(Self {
Ok(Self {
url: s.to_string_lossy().to_string(),
method: Method::GET,
query: HashMap::new(),
headers: HashMap::new(),
body: None,
options: Default::default(),
});
}
// If we got a table we are able to configure the entire request
if let LuaValue::Table(tab) = value {
})
} else if let LuaValue::Table(tab) = value {
// If we got a table we are able to configure the entire request
// Extract url
let url = match tab.raw_get::<_, LuaString>("url") {
let url = match tab.get::<_, LuaString>("url") {
Ok(config_url) => Ok(config_url.to_string_lossy().to_string()),
Err(_) => Err(LuaError::runtime("Missing 'url' in request config")),
}?;
// Extract method
let method = match tab.raw_get::<_, LuaString>("method") {
let method = match tab.get::<_, LuaString>("method") {
Ok(config_method) => config_method.to_string_lossy().trim().to_ascii_uppercase(),
Err(_) => "GET".to_string(),
};
// Extract query
let query = match tab.raw_get::<_, LuaTable>("query") {
let query = match tab.get::<_, LuaTable>("query") {
Ok(tab) => table_to_hash_map(tab, "query")?,
Err(_) => HashMap::new(),
};
// Extract headers
let headers = match tab.raw_get::<_, LuaTable>("headers") {
let headers = match tab.get::<_, LuaTable>("headers") {
Ok(tab) => table_to_hash_map(tab, "headers")?,
Err(_) => HashMap::new(),
};
// Extract body
let body = match tab.raw_get::<_, LuaString>("body") {
let body = match tab.get::<_, LuaString>("body") {
Ok(config_body) => Some(config_body.as_bytes().to_owned()),
Err(_) => None,
};
@ -112,29 +124,30 @@ impl FromLua<'_> for RequestConfig {
))),
}?;
// Parse any extra options given
let options = match tab.raw_get::<_, LuaValue>("options") {
let options = match tab.get::<_, LuaValue>("options") {
Ok(opts) => RequestConfigOptions::from_lua(opts, lua)?,
Err(_) => RequestConfigOptions::default(),
};
// All good, validated and we got what we need
return Ok(Self {
Ok(Self {
url,
method,
query,
headers,
body,
options,
});
};
// Anything else is invalid
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "RequestConfig",
message: Some(format!(
"Invalid request config - expected string or table, got {}",
value.type_name()
)),
})
})
} else {
// Anything else is invalid
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "RequestConfig",
message: Some(format!(
"Invalid request config - expected string or table, got {}",
value.type_name()
)),
})
}
}
}
@ -142,54 +155,72 @@ impl FromLua<'_> for RequestConfig {
#[derive(Debug)]
pub struct ServeConfig<'a> {
pub address: Ipv4Addr,
pub handle_request: LuaFunction<'a>,
pub handle_web_socket: Option<LuaFunction<'a>>,
pub address: Option<LuaString<'a>>,
}
impl<'lua> FromLua<'lua> for ServeConfig<'lua> {
fn from_lua(value: LuaValue<'lua>, lua: &'lua Lua) -> LuaResult<Self> {
let message = match &value {
LuaValue::Function(f) => {
return Ok(ServeConfig {
handle_request: f.clone(),
handle_web_socket: None,
address: None,
if let LuaValue::Function(f) = &value {
// Single function = request handler, rest is default
Ok(ServeConfig {
handle_request: f.clone(),
handle_web_socket: None,
address: DEFAULT_IP_ADDRESS,
})
} else if let LuaValue::Table(t) = &value {
// Table means custom options
let address: Option<LuaString> = t.get("address")?;
let handle_request: Option<LuaFunction> = t.get("handleRequest")?;
let handle_web_socket: Option<LuaFunction> = t.get("handleWebSocket")?;
if handle_request.is_some() || handle_web_socket.is_some() {
let address: Ipv4Addr = match &address {
Some(addr) => {
let addr_str = addr.to_str()?;
addr_str
.trim_start_matches("http://")
.trim_start_matches("https://")
.parse()
.map_err(|_e| LuaError::FromLuaConversionError {
from: value.type_name(),
to: "ServeConfig",
message: Some(format!(
"IP address format is incorrect - \
expected an IP in the form 'http://0.0.0.0' or '0.0.0.0', \
got '{addr_str}'"
)),
})?
}
None => DEFAULT_IP_ADDRESS,
};
Ok(Self {
address,
handle_request: handle_request.unwrap_or_else(|| {
lua.load(WEB_SOCKET_UPDGRADE_REQUEST_HANDLER)
.into_function()
.expect("Failed to create default http responder function")
}),
handle_web_socket,
})
} else {
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "ServeConfig",
message: Some(String::from(
"Invalid serve config - expected table with 'handleRequest' or 'handleWebSocket' function",
)),
})
}
LuaValue::Table(t) => {
let handle_request: Option<LuaFunction> = t.raw_get("handleRequest")?;
let handle_web_socket: Option<LuaFunction> = t.raw_get("handleWebSocket")?;
let address: Option<LuaString> = t.raw_get("address")?;
if handle_request.is_some() || handle_web_socket.is_some() {
return Ok(ServeConfig {
handle_request: handle_request.unwrap_or_else(|| {
let chunk = r#"
return {
status = 426,
body = "Upgrade Required",
headers = {
Upgrade = "websocket",
},
}
"#;
lua.load(chunk)
.into_function()
.expect("Failed to create default http responder function")
}),
handle_web_socket,
address,
});
} else {
Some("Missing handleRequest and / or handleWebSocket".to_string())
}
}
_ => None,
};
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "ServeConfig",
message,
})
} else {
// Anything else is invalid
Err(LuaError::FromLuaConversionError {
from: value.type_name(),
to: "ServeConfig",
message: None,
})
}
}
}

View file

@ -1,36 +1,27 @@
use std::net::Ipv4Addr;
#![allow(unused_variables)]
use mlua::prelude::*;
use hyper::header::CONTENT_ENCODING;
use crate::lune::{scheduler::Scheduler, util::TableBuilder};
use self::{
server::{bind_to_addr, create_server},
util::header_map_to_table,
};
use super::serde::{
compress_decompress::{decompress, CompressDecompressFormat},
encode_decode::{EncodeDecodeConfig, EncodeDecodeFormat},
};
use mlua_luau_scheduler::LuaSpawnExt;
mod client;
mod config;
mod processing;
mod response;
mod server;
mod util;
mod websocket;
use client::{NetClient, NetClientBuilder};
use config::{RequestConfig, ServeConfig};
use websocket::NetWebSocket;
use crate::lune::util::TableBuilder;
const DEFAULT_IP_ADDRESS: Ipv4Addr = Ipv4Addr::new(127, 0, 0, 1);
use self::{
client::{NetClient, NetClientBuilder},
config::{RequestConfig, ServeConfig},
server::serve,
util::create_user_agent_header,
websocket::NetWebSocket,
};
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
use super::serde::encode_decode::{EncodeDecodeConfig, EncodeDecodeFormat};
pub fn create(lua: &Lua) -> LuaResult<LuaTable> {
NetClientBuilder::new()
.headers(&[("User-Agent", create_user_agent_header())])?
.build()?
@ -46,14 +37,6 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
.build_readonly()
}
fn create_user_agent_header() -> String {
let (github_owner, github_repo) = env!("CARGO_PKG_REPOSITORY")
.trim_start_matches("https://github.com/")
.split_once('/')
.unwrap();
format!("{github_owner}-{github_repo}-cli")
}
fn net_json_encode<'lua>(
lua: &'lua Lua,
(val, pretty): (LuaValue<'lua>, Option<bool>),
@ -66,68 +49,14 @@ fn net_json_decode<'lua>(lua: &'lua Lua, json: LuaString<'lua>) -> LuaResult<Lua
EncodeDecodeConfig::from(EncodeDecodeFormat::Json).deserialize_from_string(lua, json)
}
async fn net_request<'lua>(lua: &'lua Lua, config: RequestConfig) -> LuaResult<LuaTable<'lua>>
where
'lua: 'static, // FIXME: Get rid of static lifetime bound here
{
// Create and send the request
async fn net_request(lua: &Lua, config: RequestConfig) -> LuaResult<LuaTable> {
let client = NetClient::from_registry(lua);
let mut request = client.request(config.method, &config.url);
for (query, values) in config.query {
request = request.query(
&values
.iter()
.map(|v| (query.as_str(), v))
.collect::<Vec<_>>(),
);
}
for (header, values) in config.headers {
for value in values {
request = request.header(header.as_str(), value);
}
}
let res = request
.body(config.body.unwrap_or_default())
.send()
.await
.into_lua_err()?;
// Extract status, headers
let res_status = res.status().as_u16();
let res_status_text = res.status().canonical_reason();
let res_headers = res.headers().clone();
// Read response bytes
let mut res_bytes = res.bytes().await.into_lua_err()?.to_vec();
let mut res_decompressed = false;
// Check for extra options, decompression
if config.options.decompress {
let decompress_format = res_headers
.iter()
.find(|(name, _)| {
name.as_str()
.eq_ignore_ascii_case(CONTENT_ENCODING.as_str())
})
.and_then(|(_, value)| value.to_str().ok())
.and_then(CompressDecompressFormat::detect_from_header_str);
if let Some(format) = decompress_format {
res_bytes = decompress(format, res_bytes).await?;
res_decompressed = true;
}
}
// Construct and return a readonly lua table with results
let res_headers_lua = header_map_to_table(lua, res_headers, res_decompressed)?;
TableBuilder::new(lua)?
.with_value("ok", (200..300).contains(&res_status))?
.with_value("statusCode", res_status)?
.with_value("statusMessage", res_status_text)?
.with_value("headers", res_headers_lua)?
.with_value("body", lua.create_string(&res_bytes)?)?
.build_readonly()
// NOTE: We spawn the request as a background task to free up resources in lua
let res = lua.spawn(async move { client.request(config).await });
res.await?.into_lua_table(lua)
}
async fn net_socket<'lua>(lua: &'lua Lua, url: String) -> LuaResult<LuaTable>
where
'lua: 'static, // FIXME: Get rid of static lifetime bound here
{
async fn net_socket(lua: &Lua, url: String) -> LuaResult<LuaTable> {
let (ws, _) = tokio_tungstenite::connect_async(url).await.into_lua_err()?;
NetWebSocket::new(ws).into_lua_table(lua)
}
@ -135,32 +64,8 @@ where
async fn net_serve<'lua>(
lua: &'lua Lua,
(port, config): (u16, ServeConfig<'lua>),
) -> LuaResult<LuaTable<'lua>>
where
'lua: 'static, // FIXME: Get rid of static lifetime bound here
{
let sched = lua
.app_data_ref::<&Scheduler>()
.expect("Lua struct is missing scheduler");
let address: Ipv4Addr = match &config.address {
Some(addr) => {
let addr_str = addr.to_str()?;
addr_str
.trim_start_matches("http://")
.trim_start_matches("https://")
.parse()
.map_err(|_e| LuaError::RuntimeError(format!(
"IP address format is incorrect (expected an IP in the form 'http://0.0.0.0' or '0.0.0.0', got '{addr_str}')"
)))?
}
None => DEFAULT_IP_ADDRESS,
};
let builder = bind_to_addr(address, port)?;
create_server(lua, &sched, config, builder)
) -> LuaResult<LuaTable<'lua>> {
serve(lua, port, config).await
}
fn net_url_encode<'lua>(

View file

@ -1,101 +0,0 @@
use std::sync::atomic::{AtomicUsize, Ordering};
use hyper::{body::to_bytes, Body, Request};
use mlua::prelude::*;
use crate::lune::util::TableBuilder;
static ID_COUNTER: AtomicUsize = AtomicUsize::new(0);
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub(super) struct ProcessedRequestId(usize);
impl ProcessedRequestId {
pub fn new() -> Self {
// NOTE: This may overflow after a couple billion requests,
// but that's completely fine... unless a request is still
// alive after billions more arrive and need to be handled
Self(ID_COUNTER.fetch_add(1, Ordering::Relaxed))
}
}
pub(super) struct ProcessedRequest {
pub id: ProcessedRequestId,
method: String,
path: String,
query: Vec<(String, String)>,
headers: Vec<(String, Vec<u8>)>,
body: Vec<u8>,
}
impl ProcessedRequest {
pub async fn from_request(req: Request<Body>) -> LuaResult<Self> {
let (head, body) = req.into_parts();
// FUTURE: We can do extra processing like async decompression here
let body = match to_bytes(body).await {
Err(_) => return Err(LuaError::runtime("Failed to read request body bytes")),
Ok(b) => b.to_vec(),
};
let method = head.method.to_string().to_ascii_uppercase();
let mut path = head.uri.path().to_string();
if path.is_empty() {
path = "/".to_string();
}
let query = head
.uri
.query()
.unwrap_or_default()
.split('&')
.filter_map(|q| q.split_once('='))
.map(|(k, v)| (k.to_string(), v.to_string()))
.collect();
let mut headers = Vec::new();
let mut header_name = String::new();
for (name_opt, value) in head.headers.into_iter() {
if let Some(name) = name_opt {
header_name = name.to_string();
}
headers.push((header_name.clone(), value.as_bytes().to_vec()))
}
let id = ProcessedRequestId::new();
Ok(Self {
id,
method,
path,
query,
headers,
body,
})
}
pub fn into_lua_table(self, lua: &Lua) -> LuaResult<LuaTable> {
// FUTURE: Make inner tables for query keys that have multiple values?
let query = lua.create_table_with_capacity(0, self.query.len())?;
for (key, value) in self.query.into_iter() {
query.set(key, value)?;
}
let headers = lua.create_table_with_capacity(0, self.headers.len())?;
for (key, value) in self.headers.into_iter() {
headers.set(key, lua.create_string(value)?)?;
}
let body = lua.create_string(self.body)?;
TableBuilder::new(lua)?
.with_value("method", self.method)?
.with_value("path", self.path)?
.with_value("query", query)?
.with_value("headers", headers)?
.with_value("body", body)?
.build_readonly()
}
}

View file

@ -1,223 +0,0 @@
use std::{
collections::HashMap,
convert::Infallible,
net::{Ipv4Addr, SocketAddr},
sync::Arc,
};
use hyper::{
server::{conn::AddrIncoming, Builder},
service::{make_service_fn, service_fn},
Server,
};
use hyper_tungstenite::{is_upgrade_request, upgrade, HyperWebsocket};
use mlua::prelude::*;
use tokio::sync::{mpsc, oneshot, Mutex};
use crate::lune::{
scheduler::Scheduler,
util::{futures::yield_forever, traits::LuaEmitErrorExt, TableBuilder},
};
use super::{
config::ServeConfig, processing::ProcessedRequest, response::NetServeResponse,
websocket::NetWebSocket,
};
pub(super) fn bind_to_addr(address: Ipv4Addr, port: u16) -> LuaResult<Builder<AddrIncoming>> {
let addr = SocketAddr::from((address, port));
match Server::try_bind(&addr) {
Ok(b) => Ok(b),
Err(e) => Err(LuaError::external(format!(
"Failed to bind to {addr}\n{}",
e.to_string()
.replace("error creating server listener: ", "> ")
))),
}
}
pub(super) fn create_server<'lua>(
lua: &'lua Lua,
sched: &'lua Scheduler,
config: ServeConfig<'lua>,
builder: Builder<AddrIncoming>,
) -> LuaResult<LuaTable<'lua>>
where
'lua: 'static, // FIXME: Get rid of static lifetime bound here
{
// Note that we need to use a mpsc here and not
// a oneshot channel since we move the sender
// into our table with the stop function
let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
// Communicate between background thread(s) and main lua thread using mpsc and oneshot
let (tx_request, mut rx_request) = mpsc::channel::<ProcessedRequest>(64);
let (tx_websocket, mut rx_websocket) = mpsc::channel::<HyperWebsocket>(64);
let tx_request_arc = Arc::new(tx_request);
let tx_websocket_arc = Arc::new(tx_websocket);
let response_senders = Arc::new(Mutex::new(HashMap::new()));
let response_senders_bg = Arc::clone(&response_senders);
let response_senders_lua = Arc::clone(&response_senders_bg);
// Create our background service which will accept
// requests, do some processing, then forward to lua
let has_websocket_handler = config.handle_web_socket.is_some();
let hyper_make_service = make_service_fn(move |_| {
let tx_request = Arc::clone(&tx_request_arc);
let tx_websocket = Arc::clone(&tx_websocket_arc);
let response_senders = Arc::clone(&response_senders_bg);
let handler = service_fn(move |mut req| {
let tx_request = Arc::clone(&tx_request);
let tx_websocket = Arc::clone(&tx_websocket);
let response_senders = Arc::clone(&response_senders);
async move {
// FUTURE: Improve error messages when lua is busy and queue is full
if has_websocket_handler && is_upgrade_request(&req) {
let (response, ws) = match upgrade(&mut req, None) {
Err(_) => return Err(LuaError::runtime("Failed to upgrade websocket")),
Ok(v) => v,
};
if (tx_websocket.send(ws).await).is_err() {
return Err(LuaError::runtime("Lua handler is busy"));
}
Ok(response)
} else {
let processed = ProcessedRequest::from_request(req).await?;
let request_id = processed.id;
if (tx_request.send(processed).await).is_err() {
return Err(LuaError::runtime("Lua handler is busy"));
}
let (response_tx, response_rx) = oneshot::channel::<NetServeResponse>();
response_senders
.lock()
.await
.insert(request_id, response_tx);
match response_rx.await {
Err(_) => Err(LuaError::runtime("Internal Server Error")),
Ok(r) => r.into_response(),
}
}
}
});
async move { Ok::<_, Infallible>(handler) }
});
// Start up our service
sched.spawn(async move {
let result = builder
.http1_only(true) // Web sockets can only use http1
.http1_keepalive(true) // Web sockets must be kept alive
.serve(hyper_make_service)
.with_graceful_shutdown(async move {
if shutdown_rx.recv().await.is_none() {
// The channel was closed, meaning the serve handle
// was garbage collected by lua without being used
yield_forever().await;
}
});
if let Err(e) = result.await {
eprintln!("Net serve error: {e}")
}
});
// Spawn a local thread with access to lua and the same lifetime
sched.spawn_local(async move {
loop {
// Wait for either a request or a websocket to handle,
// if we got neither it means both channels were dropped
// and our server has stopped, either gracefully or panic
let (req, sock) = tokio::select! {
req = rx_request.recv() => (req, None),
sock = rx_websocket.recv() => (None, sock),
};
if req.is_none() && sock.is_none() {
break;
}
// NOTE: The closure here is not really necessary, we
// make the closure so that we can use the `?` operator
// and make a catch-all for errors in spawn_local below
let handle_request = config.handle_request.clone();
let handle_web_socket = config.handle_web_socket.clone();
let response_senders = Arc::clone(&response_senders_lua);
let response_fut = async move {
match (req, sock) {
(Some(req), _) => {
let req_id = req.id;
let req_table = req.into_lua_table(lua)?;
let thread_id = sched.push_back(lua, handle_request, req_table)?;
let thread_res = sched.wait_for_thread(lua, thread_id).await?;
let response = NetServeResponse::from_lua_multi(thread_res, lua)?;
let response_sender = response_senders
.lock()
.await
.remove(&req_id)
.expect("Response channel was removed unexpectedly");
// NOTE: We ignore the error here, if the sender is no longer
// being listened to its because our client disconnected during
// handler being called, which is fine and should not emit errors
response_sender.send(response).ok();
Ok(())
}
(_, Some(sock)) => {
let sock = sock.await.into_lua_err()?;
let sock_handler = handle_web_socket
.as_ref()
.cloned()
.expect("Got web socket but web socket handler is missing");
let sock_table = NetWebSocket::new(sock).into_lua_table(lua)?;
// NOTE: Web socket handler does not need to send any
// response back, the websocket upgrade response is
// automatically sent above in the background thread(s)
let thread_id = sched.push_back(lua, sock_handler, sock_table)?;
let _thread_res = sched.wait_for_thread(lua, thread_id).await?;
Ok(())
}
_ => unreachable!(),
}
};
/*
NOTE: It is currently not possible to spawn new background tasks from within
another background task with the Lune scheduler since they are locked behind a
mutex and we also need that mutex locked to be able to run a background task...
We need to do some work to make it so our unordered futures queues do
not require locking and then we can replace the following bit of code:
sched.spawn_local(async {
if let Err(e) = response_fut.await {
lua.emit_error(e);
}
});
*/
if let Err(e) = response_fut.await {
lua.emit_error(e);
}
}
});
// Create a new read-only table that contains methods
// for manipulating server behavior and shutting it down
let handle_stop = move |_, _: ()| match shutdown_tx.try_send(()) {
Ok(_) => Ok(()),
Err(_) => Err(LuaError::RuntimeError(
"Server has already been stopped".to_string(),
)),
};
TableBuilder::new(lua)?
.with_function("stop", handle_stop)?
.build_readonly()
}

View file

@ -0,0 +1,61 @@
use std::sync::atomic::{AtomicUsize, Ordering};
use mlua::prelude::*;
#[derive(Debug, Clone, Copy)]
pub(super) struct SvcKeys {
key_request: &'static str,
key_websocket: Option<&'static str>,
}
impl SvcKeys {
pub(super) fn new<'lua>(
lua: &'lua Lua,
handle_request: LuaFunction<'lua>,
handle_websocket: Option<LuaFunction<'lua>>,
) -> LuaResult<Self> {
static SERVE_COUNTER: AtomicUsize = AtomicUsize::new(0);
let count = SERVE_COUNTER.fetch_add(1, Ordering::Relaxed);
// NOTE: We leak strings here, but this is an acceptable tradeoff since programs
// generally only start one or a couple of servers and they are usually never dropped.
// Leaking here lets us keep this struct Copy and access the request handler callbacks
// very performantly, significantly reducing the per-request overhead of the server.
let key_request: &'static str =
Box::leak(format!("__net_serve_request_{count}").into_boxed_str());
let key_websocket: Option<&'static str> = if handle_websocket.is_some() {
Some(Box::leak(
format!("__net_serve_websocket_{count}").into_boxed_str(),
))
} else {
None
};
lua.set_named_registry_value(key_request, handle_request)?;
if let Some(key) = key_websocket {
lua.set_named_registry_value(key, handle_websocket.unwrap())?;
}
Ok(Self {
key_request,
key_websocket,
})
}
pub(super) fn has_websocket_handler(&self) -> bool {
self.key_websocket.is_some()
}
pub(super) fn request_handler<'lua>(&self, lua: &'lua Lua) -> LuaResult<LuaFunction<'lua>> {
lua.named_registry_value(self.key_request)
}
pub(super) fn websocket_handler<'lua>(
&self,
lua: &'lua Lua,
) -> LuaResult<Option<LuaFunction<'lua>>> {
self.key_websocket
.map(|key| lua.named_registry_value(key))
.transpose()
}
}

View file

@ -0,0 +1,105 @@
use std::{
net::SocketAddr,
rc::{Rc, Weak},
};
use hyper::server::conn::http1;
use hyper_util::rt::TokioIo;
use tokio::{net::TcpListener, pin};
use mlua::prelude::*;
use mlua_luau_scheduler::LuaSpawnExt;
use crate::lune::util::TableBuilder;
use super::config::ServeConfig;
mod keys;
mod request;
mod response;
mod service;
use keys::SvcKeys;
use service::Svc;
/// Binds a TCP listener on the given port and serves HTTP (and optionally
/// websocket) requests using the handlers stored in the given [`ServeConfig`].
///
/// Returns a lua table containing the bound `ip` and `port`, plus a `stop`
/// function that gracefully shuts the server down when called.
///
/// # Errors
///
/// Errors if the listener can not be bound, or if handler keys / the
/// returned table can not be created in the lua state.
pub async fn serve<'lua>(
    lua: &'lua Lua,
    port: u16,
    config: ServeConfig<'lua>,
) -> LuaResult<LuaTable<'lua>> {
    let addr: SocketAddr = (config.address, port).into();
    let listener = TcpListener::bind(addr).await?;
    // Grab two strong references to the lua state - one is moved into the
    // hyper service, the other into the connection-accepting task below.
    // NOTE(review): assumes whoever constructed the Lua stored a Weak<Lua>
    // in its app data - the expect()s here panic otherwise.
    let (lua_svc, lua_inner) = {
        let rc = lua
            .app_data_ref::<Weak<Lua>>()
            .expect("Missing weak lua ref")
            .upgrade()
            .expect("Lua was dropped unexpectedly");
        (Rc::clone(&rc), rc)
    };
    // Store the lua handler functions in the registry so the (Clone) service
    // can look them up later without holding lua references directly
    let keys = SvcKeys::new(lua, config.handle_request, config.handle_web_socket)?;
    let svc = Svc {
        lua: lua_svc,
        addr,
        keys,
    };
    // Watch channel used to broadcast the shutdown signal to the accept
    // loop and to every in-flight connection task
    let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
    lua.spawn_local(async move {
        let mut shutdown_rx_outer = shutdown_rx.clone();
        loop {
            // Create futures for accepting new connections and shutting down
            let fut_shutdown = shutdown_rx_outer.changed();
            let fut_accept = async {
                let stream = match listener.accept().await {
                    Err(_) => return,
                    Ok((s, _)) => s,
                };
                let io = TokioIo::new(stream);
                let svc = svc.clone();
                let mut shutdown_rx_inner = shutdown_rx.clone();
                // One task per connection, driven on the local scheduler
                lua_inner.spawn_local(async move {
                    let conn = http1::Builder::new()
                        .keep_alive(true) // Web sockets need this
                        .serve_connection(io, svc)
                        .with_upgrades();
                    // NOTE: Because we need to use keep_alive for websockets, we need to
                    // also manually poll this future and handle the shutdown signal here
                    pin!(conn);
                    tokio::select! {
                        _ = conn.as_mut() => {}
                        _ = shutdown_rx_inner.changed() => {
                            conn.as_mut().graceful_shutdown();
                        }
                    }
                });
            };
            // Wait for either a new connection or a shutdown signal
            tokio::select! {
                _ = fut_accept => {}
                res = fut_shutdown => {
                    // NOTE: We will only get a RecvError here if the serve handle is dropped,
                    // this means lua has garbage collected it and the user does not want
                    // to manually stop the server using the serve handle. Run forever.
                    if res.is_ok() {
                        break;
                    }
                }
            }
        }
    });
    // Handle table returned to lua - sending on the watch channel fails only
    // after the accept loop (the sole receiver) has already exited
    TableBuilder::new(lua)?
        .with_value("ip", addr.ip().to_string())?
        .with_value("port", addr.port())?
        .with_function("stop", move |lua, _: ()| match shutdown_tx.send(true) {
            Ok(_) => Ok(()),
            Err(_) => Err(LuaError::runtime("Server already stopped")),
        })?
        .build_readonly()
}

View file

@ -0,0 +1,46 @@
use std::{collections::HashMap, net::SocketAddr};
use http::request::Parts;
use mlua::prelude::*;
/// Plain-data snapshot of an incoming HTTP request, handed to the lua
/// request handler as userdata (see the `LuaUserData` impl below).
pub(super) struct LuaRequest {
    // Peer address of the connection - stored but not yet exposed to lua
    pub(super) _remote_addr: SocketAddr,
    // Request head (method, uri, headers, ...) from hyper
    pub(super) head: Parts,
    // Fully buffered request body bytes
    pub(super) body: Vec<u8>,
}
impl LuaUserData for LuaRequest {
    // Exposes the request to lua as read-only fields:
    // `method`, `path`, `query`, `headers` and `body`
    fn add_fields<'lua, F: LuaUserDataFields<'lua, Self>>(fields: &mut F) {
        // HTTP method as an uppercase string, eg "GET"
        fields.add_field_method_get("method", |_, this| {
            Ok(this.head.method.as_str().to_string())
        });
        // Path component of the request uri
        fields.add_field_method_get("path", |_, this| Ok(this.head.uri.path().to_string()));
        // Query string parsed into key/value pairs; pairs without an '='
        // are skipped and a repeated key keeps its last value
        fields.add_field_method_get("query", |_, this| {
            let mut pairs: HashMap<String, String> = HashMap::new();
            let raw_query = this.head.uri.query().unwrap_or_default();
            for chunk in raw_query.split('&') {
                if let Some((key, value)) = chunk.split_once('=') {
                    pairs.insert(key.to_string(), value.to_string());
                }
            }
            Ok(pairs)
        });
        // Header map with raw byte values; a repeated header name keeps
        // its last value
        fields.add_field_method_get("headers", |_, this| {
            let mut headers: HashMap<String, Vec<u8>> = HashMap::new();
            for (name, value) in this.head.headers.iter() {
                headers.insert(name.as_str().to_string(), value.as_bytes().to_vec());
            }
            Ok(headers)
        });
        // Raw body bytes as a lua string
        fields.add_field_method_get("body", |lua, this| lua.create_string(&this.body));
    }
}

View file

@ -1,52 +1,55 @@
use std::collections::HashMap;
use std::str::FromStr;
use http_body_util::Full;
use hyper::{
body::Bytes,
header::{HeaderName, HeaderValue},
HeaderMap, Response,
};
use hyper::{Body, Response};
use mlua::prelude::*;
#[derive(Debug, Clone, Copy)]
pub enum NetServeResponseKind {
pub(super) enum LuaResponseKind {
PlainText,
Table,
}
#[derive(Debug)]
pub struct NetServeResponse {
kind: NetServeResponseKind,
status: u16,
headers: HashMap<String, Vec<u8>>,
body: Option<Vec<u8>>,
pub(super) struct LuaResponse {
pub(super) kind: LuaResponseKind,
pub(super) status: u16,
pub(super) headers: HeaderMap,
pub(super) body: Option<Vec<u8>>,
}
impl NetServeResponse {
pub fn into_response(self) -> LuaResult<Response<Body>> {
impl LuaResponse {
pub(super) fn into_response(self) -> LuaResult<Response<Full<Bytes>>> {
Ok(match self.kind {
NetServeResponseKind::PlainText => Response::builder()
LuaResponseKind::PlainText => Response::builder()
.status(200)
.header("Content-Type", "text/plain")
.body(Body::from(self.body.unwrap()))
.body(Full::new(Bytes::from(self.body.unwrap())))
.into_lua_err()?,
NetServeResponseKind::Table => {
let mut response = Response::builder();
for (key, value) in self.headers {
response = response.header(&key, value);
}
response
LuaResponseKind::Table => {
let mut response = Response::builder()
.status(self.status)
.body(Body::from(self.body.unwrap_or_default()))
.into_lua_err()?
.body(Full::new(Bytes::from(self.body.unwrap_or_default())))
.into_lua_err()?;
response.headers_mut().extend(self.headers);
response
}
})
}
}
impl<'lua> FromLua<'lua> for NetServeResponse {
fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult<Self> {
impl FromLua<'_> for LuaResponse {
fn from_lua(value: LuaValue, _: &Lua) -> LuaResult<Self> {
match value {
// Plain strings from the handler are plaintext responses
LuaValue::String(s) => Ok(Self {
kind: NetServeResponseKind::PlainText,
kind: LuaResponseKind::PlainText,
status: 200,
headers: HashMap::new(),
headers: HeaderMap::new(),
body: Some(s.as_bytes().to_vec()),
}),
// Tables are more detailed responses with potential status, headers, body
@ -55,18 +58,20 @@ impl<'lua> FromLua<'lua> for NetServeResponse {
let headers: Option<LuaTable> = t.get("headers")?;
let body: Option<LuaString> = t.get("body")?;
let mut headers_map = HashMap::new();
let mut headers_map = HeaderMap::new();
if let Some(headers) = headers {
for pair in headers.pairs::<String, LuaString>() {
let (h, v) = pair?;
headers_map.insert(h, v.as_bytes().to_vec());
let name = HeaderName::from_str(&h).into_lua_err()?;
let value = HeaderValue::from_bytes(v.as_bytes()).into_lua_err()?;
headers_map.insert(name, value);
}
}
let body_bytes = body.map(|s| s.as_bytes().to_vec());
Ok(Self {
kind: NetServeResponseKind::Table,
kind: LuaResponseKind::Table,
status: status.unwrap_or(200),
headers: headers_map,
body: body_bytes,

View file

@ -0,0 +1,81 @@
use std::{future::Future, net::SocketAddr, pin::Pin, rc::Rc};
use http_body_util::{BodyExt, Full};
use hyper::{
body::{Bytes, Incoming},
service::Service,
Request, Response,
};
use hyper_tungstenite::{is_upgrade_request, upgrade};
use mlua::prelude::*;
use mlua_luau_scheduler::{LuaSchedulerExt, LuaSpawnExt};
use super::{
super::websocket::NetWebSocket, keys::SvcKeys, request::LuaRequest, response::LuaResponse,
};
/// Hyper service that dispatches incoming requests to lua handlers.
///
/// Cloned once per accepted connection; all clones share the same lua
/// state (`Rc<Lua>`) and the same registry keys for the handlers.
#[derive(Debug, Clone)]
pub(super) struct Svc {
    pub(super) lua: Rc<Lua>,
    pub(super) addr: SocketAddr,
    pub(super) keys: SvcKeys,
}
impl Service<Request<Incoming>> for Svc {
    type Response = Response<Full<Bytes>>;
    type Error = LuaError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
    // Routes a single request: websocket upgrade requests are accepted and
    // handed to the lua websocket handler in a background task, everything
    // else is buffered and run through the lua request handler.
    fn call(&self, req: Request<Incoming>) -> Self::Future {
        let lua = self.lua.clone();
        let addr = self.addr;
        let keys = self.keys;
        if keys.has_websocket_handler() && is_upgrade_request(&req) {
            Box::pin(async move {
                // Respond with the upgrade response immediately; the actual
                // socket handling continues in the spawned task below
                let (res, sock) = upgrade(req, None).into_lua_err()?;
                let lua_inner = lua.clone();
                lua.spawn_local(async move {
                    // NOTE(review): these unwraps panic the spawned task if the
                    // upgrade or handler lookup fails - confirm this is the
                    // intended failure mode rather than logging the error
                    let sock = sock.await.unwrap();
                    let lua_sock = NetWebSocket::new(sock);
                    let lua_tab = lua_sock.into_lua_table(&lua_inner).unwrap();
                    let handler_websocket: LuaFunction =
                        keys.websocket_handler(&lua_inner).unwrap().unwrap();
                    // Fire-and-forget: the websocket handler runs on the
                    // scheduler, its result is not awaited here
                    lua_inner
                        .push_thread_back(handler_websocket, lua_tab)
                        .unwrap();
                });
                Ok(res)
            })
        } else {
            let (head, body) = req.into_parts();
            Box::pin(async move {
                let handler_request: LuaFunction = keys.request_handler(&lua).unwrap();
                // Buffer the whole request body before calling into lua
                let body = body.collect().await.into_lua_err()?;
                let body = body.to_bytes().to_vec();
                let lua_req = LuaRequest {
                    _remote_addr: addr,
                    head,
                    body,
                };
                // Schedule the handler, wait for it to finish, then convert
                // its return value(s) into an HTTP response
                let thread_id = lua.push_thread_back(handler_request, lua_req)?;
                lua.track_thread(thread_id);
                lua.wait_for_thread(thread_id).await;
                let thread_res = lua
                    .get_thread_result(thread_id)
                    .expect("Missing handler thread result")?;
                LuaResponse::from_lua_multi(thread_res, &lua)?.into_response()
            })
        }
    }
}

View file

@ -1,14 +1,20 @@
use std::collections::HashMap;
use hyper::{
header::{CONTENT_ENCODING, CONTENT_LENGTH},
HeaderMap,
};
use hyper::header::{CONTENT_ENCODING, CONTENT_LENGTH};
use reqwest::header::HeaderMap;
use mlua::prelude::*;
use crate::lune::util::TableBuilder;
/// Builds the default user agent string, `"<owner>-<repo>-cli"`, derived
/// from this crate's `repository` field in its Cargo manifest.
///
/// # Panics
///
/// Panics at runtime if the repository url does not contain a `/` after
/// the stripped github prefix (checked at compile time in practice, since
/// the url is baked in via `env!`).
pub fn create_user_agent_header() -> String {
    let repo_url = env!("CARGO_PKG_REPOSITORY").trim_start_matches("https://github.com/");
    let (owner, repo) = repo_url.split_once('/').unwrap();
    format!("{owner}-{repo}-cli")
}
pub fn header_map_to_table(
lua: &Lua,
headers: HeaderMap,

View file

@ -1,6 +1,8 @@
use std::sync::Arc;
use std::sync::{
atomic::{AtomicBool, AtomicU16, Ordering},
Arc,
};
use hyper::upgrade::Upgraded;
use mlua::prelude::*;
use futures_util::{
@ -9,7 +11,6 @@ use futures_util::{
};
use tokio::{
io::{AsyncRead, AsyncWrite},
net::TcpStream,
sync::Mutex as AsyncMutex,
};
@ -20,25 +21,25 @@ use hyper_tungstenite::{
},
WebSocketStream,
};
use tokio_tungstenite::MaybeTlsStream;
use crate::lune::util::TableBuilder;
// Wrapper implementation for compatibility and changing colon syntax to dot syntax
const WEB_SOCKET_IMPL_LUA: &str = r#"
return freeze(setmetatable({
close = function(...)
return close(websocket, ...)
return websocket:close(...)
end,
send = function(...)
return send(websocket, ...)
return websocket:send(...)
end,
next = function(...)
return next(websocket, ...)
return websocket:next(...)
end,
}, {
__index = function(self, key)
if key == "closeCode" then
return close_code(websocket)
return websocket.closeCode
end
end,
}))
@ -46,7 +47,8 @@ return freeze(setmetatable({
#[derive(Debug)]
pub struct NetWebSocket<T> {
close_code: Arc<AsyncMutex<Option<u16>>>,
close_code_exists: Arc<AtomicBool>,
close_code_value: Arc<AtomicU16>,
read_stream: Arc<AsyncMutex<SplitStream<WebSocketStream<T>>>>,
write_stream: Arc<AsyncMutex<SplitSink<WebSocketStream<T>, WsMessage>>>,
}
@ -54,7 +56,8 @@ pub struct NetWebSocket<T> {
impl<T> Clone for NetWebSocket<T> {
fn clone(&self) -> Self {
Self {
close_code: Arc::clone(&self.close_code),
close_code_exists: Arc::clone(&self.close_code_exists),
close_code_value: Arc::clone(&self.close_code_value),
read_stream: Arc::clone(&self.read_stream),
write_stream: Arc::clone(&self.write_stream),
}
@ -63,22 +66,78 @@ impl<T> Clone for NetWebSocket<T> {
impl<T> NetWebSocket<T>
where
T: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
pub fn new(value: WebSocketStream<T>) -> Self {
let (write, read) = value.split();
Self {
close_code: Arc::new(AsyncMutex::new(None)),
close_code_exists: Arc::new(AtomicBool::new(false)),
close_code_value: Arc::new(AtomicU16::new(0)),
read_stream: Arc::new(AsyncMutex::new(read)),
write_stream: Arc::new(AsyncMutex::new(write)),
}
}
fn into_lua_table_with_env<'lua>(
lua: &'lua Lua,
env: LuaTable<'lua>,
) -> LuaResult<LuaTable<'lua>> {
fn get_close_code(&self) -> Option<u16> {
if self.close_code_exists.load(Ordering::Relaxed) {
Some(self.close_code_value.load(Ordering::Relaxed))
} else {
None
}
}
fn set_close_code(&self, code: u16) {
self.close_code_exists.store(true, Ordering::Relaxed);
self.close_code_value.store(code, Ordering::Relaxed);
}
pub async fn send(&self, msg: WsMessage) -> LuaResult<()> {
let mut ws = self.write_stream.lock().await;
ws.send(msg).await.into_lua_err()
}
pub async fn next(&self) -> LuaResult<Option<WsMessage>> {
let mut ws = self.read_stream.lock().await;
ws.next().await.transpose().into_lua_err()
}
pub async fn close(&self, code: Option<u16>) -> LuaResult<()> {
if self.close_code_exists.load(Ordering::Relaxed) {
return Err(LuaError::runtime("Socket has already been closed"));
}
self.send(WsMessage::Close(Some(WsCloseFrame {
code: match code {
Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code),
Some(code) => {
return Err(LuaError::runtime(format!(
"Close code must be between 1000 and 4999, got {code}"
)))
}
None => WsCloseCode::Normal,
},
reason: "".into(),
})))
.await?;
let mut ws = self.write_stream.lock().await;
ws.close().await.into_lua_err()
}
pub fn into_lua_table(self, lua: &Lua) -> LuaResult<LuaTable> {
let setmetatable = lua.globals().get::<_, LuaFunction>("setmetatable")?;
let table_freeze = lua
.globals()
.get::<_, LuaTable>("table")?
.get::<_, LuaFunction>("freeze")?;
let env = TableBuilder::new(lua)?
.with_value("websocket", self.clone())?
.with_value("setmetatable", setmetatable)?
.with_value("freeze", table_freeze)?
.build_readonly()?;
lua.load(WEB_SOCKET_IMPL_LUA)
.set_name("websocket")
.set_environment(env)
@ -86,149 +145,46 @@ where
}
}
type NetWebSocketStreamClient = MaybeTlsStream<TcpStream>;
impl NetWebSocket<NetWebSocketStreamClient> {
pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult<LuaTable> {
let setmetatable = lua.globals().get::<_, LuaFunction>("setmetatable")?;
let table_freeze = lua
.globals()
.get::<_, LuaTable>("table")?
.get::<_, LuaFunction>("freeze")?;
let socket_env = TableBuilder::new(lua)?
.with_value("websocket", self)?
.with_function("close_code", close_code::<NetWebSocketStreamClient>)?
.with_async_function("close", close::<NetWebSocketStreamClient>)?
.with_async_function("send", send::<NetWebSocketStreamClient>)?
.with_async_function("next", next::<NetWebSocketStreamClient>)?
.with_value("setmetatable", setmetatable)?
.with_value("freeze", table_freeze)?
.build_readonly()?;
Self::into_lua_table_with_env(lua, socket_env)
}
}
type NetWebSocketStreamServer = Upgraded;
impl NetWebSocket<NetWebSocketStreamServer> {
pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult<LuaTable> {
let setmetatable = lua.globals().get::<_, LuaFunction>("setmetatable")?;
let table_freeze = lua
.globals()
.get::<_, LuaTable>("table")?
.get::<_, LuaFunction>("freeze")?;
let socket_env = TableBuilder::new(lua)?
.with_value("websocket", self)?
.with_function("close_code", close_code::<NetWebSocketStreamServer>)?
.with_async_function("close", close::<NetWebSocketStreamServer>)?
.with_async_function("send", send::<NetWebSocketStreamServer>)?
.with_async_function("next", next::<NetWebSocketStreamServer>)?
.with_value("setmetatable", setmetatable)?
.with_value("freeze", table_freeze)?
.build_readonly()?;
Self::into_lua_table_with_env(lua, socket_env)
}
}
impl<T> LuaUserData for NetWebSocket<T> {}
fn close_code<'lua, T>(
_lua: &'lua Lua,
socket: LuaUserDataRef<'lua, NetWebSocket<T>>,
) -> LuaResult<LuaValue<'lua>>
impl<T> LuaUserData for NetWebSocket<T>
where
T: AsyncRead + AsyncWrite + Unpin,
T: AsyncRead + AsyncWrite + Unpin + 'static,
{
Ok(
match *socket
.close_code
.try_lock()
.expect("Failed to lock close code")
{
Some(code) => LuaValue::Number(code as f64),
None => LuaValue::Nil,
},
)
}
fn add_fields<'lua, F: LuaUserDataFields<'lua, Self>>(fields: &mut F) {
fields.add_field_method_get("closeCode", |_, this| Ok(this.get_close_code()));
}
async fn close<'lua, T>(
_lua: &'lua Lua,
(socket, code): (LuaUserDataRef<'lua, NetWebSocket<T>>, Option<u16>),
) -> LuaResult<()>
where
T: AsyncRead + AsyncWrite + Unpin,
{
let mut ws = socket.write_stream.lock().await;
fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) {
methods.add_async_method("close", |lua, this, code: Option<u16>| async move {
this.close(code).await
});
ws.send(WsMessage::Close(Some(WsCloseFrame {
code: match code {
Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code),
Some(code) => {
return Err(LuaError::RuntimeError(format!(
"Close code must be between 1000 and 4999, got {code}"
)))
methods.add_async_method(
"send",
|_, this, (string, as_binary): (LuaString, Option<bool>)| async move {
this.send(if as_binary.unwrap_or_default() {
WsMessage::Binary(string.as_bytes().to_vec())
} else {
let s = string.to_str().into_lua_err()?;
WsMessage::Text(s.to_string())
})
.await
},
);
methods.add_async_method("next", |lua, this, _: ()| async move {
let msg = this.next().await?;
if let Some(WsMessage::Close(Some(frame))) = msg.as_ref() {
this.set_close_code(frame.code.into());
}
None => WsCloseCode::Normal,
},
reason: "".into(),
})))
.await
.into_lua_err()?;
let res = ws.close();
res.await.into_lua_err()
}
async fn send<'lua, T>(
_lua: &'lua Lua,
(socket, string, as_binary): (
LuaUserDataRef<'lua, NetWebSocket<T>>,
LuaString<'lua>,
Option<bool>,
),
) -> LuaResult<()>
where
T: AsyncRead + AsyncWrite + Unpin,
{
let msg = if matches!(as_binary, Some(true)) {
WsMessage::Binary(string.as_bytes().to_vec())
} else {
let s = string.to_str().into_lua_err()?;
WsMessage::Text(s.to_string())
};
let mut ws = socket.write_stream.lock().await;
ws.send(msg).await.into_lua_err()
}
async fn next<'lua, T>(
lua: &'lua Lua,
socket: LuaUserDataRef<'lua, NetWebSocket<T>>,
) -> LuaResult<LuaValue<'lua>>
where
T: AsyncRead + AsyncWrite + Unpin,
{
let mut ws = socket.read_stream.lock().await;
let item = ws.next().await.transpose().into_lua_err();
let msg = match item {
Ok(Some(WsMessage::Close(msg))) => {
if let Some(msg) = &msg {
let mut code = socket.close_code.lock().await;
*code = Some(msg.code.into());
}
Ok(Some(WsMessage::Close(msg)))
}
val => val,
}?;
while let Some(msg) = &msg {
let msg_string_opt = match msg {
WsMessage::Binary(bin) => Some(lua.create_string(bin)?),
WsMessage::Text(txt) => Some(lua.create_string(txt)?),
// Stop waiting for next message if we get a close message
WsMessage::Close(_) => return Ok(LuaValue::Nil),
// Ignore ping/pong/frame messages, they are handled by tungstenite
_ => None,
};
if let Some(msg_string) = msg_string_opt {
return Ok(LuaValue::String(msg_string));
}
Ok(match msg {
Some(WsMessage::Binary(bin)) => LuaValue::String(lua.create_string(bin)?),
Some(WsMessage::Text(txt)) => LuaValue::String(lua.create_string(txt)?),
Some(WsMessage::Close(_)) | None => LuaValue::Nil,
// Ignore ping/pong/frame messages, they are handled by tungstenite
msg => unreachable!("Unhandled message: {:?}", msg),
})
});
}
Ok(LuaValue::Nil)
}

View file

@ -5,13 +5,11 @@ use std::{
};
use mlua::prelude::*;
use mlua_luau_scheduler::{Functions, LuaSpawnExt};
use os_str_bytes::RawOsString;
use tokio::io::AsyncWriteExt;
use crate::lune::{
scheduler::Scheduler,
util::{paths::CWD, TableBuilder},
};
use crate::lune::util::{paths::CWD, TableBuilder};
mod tee_writer;
@ -21,12 +19,7 @@ use options::ProcessSpawnOptions;
mod wait_for_child;
use wait_for_child::{wait_for_child, WaitForChildResult};
const PROCESS_EXIT_IMPL_LUA: &str = r#"
exit(...)
yield()
"#;
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
pub fn create(lua: &Lua) -> LuaResult<LuaTable> {
let cwd_str = {
let cwd_str = CWD.to_string_lossy().to_string();
if !cwd_str.ends_with(path::MAIN_SEPARATOR) {
@ -56,30 +49,9 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
.build_readonly()?,
)?
.build_readonly()?;
// Create our process exit function, this is a bit involved since
// we have no way to yield from c / rust, we need to load a lua
// chunk that will set the exit code and yield for us instead
let coroutine_yield = lua
.globals()
.get::<_, LuaTable>("coroutine")?
.get::<_, LuaFunction>("yield")?;
let set_scheduler_exit_code = lua.create_function(|lua, code: Option<u8>| {
let sched = lua
.app_data_ref::<&Scheduler>()
.expect("Lua struct is missing scheduler");
sched.set_exit_code(code.unwrap_or_default());
Ok(())
})?;
let process_exit = lua
.load(PROCESS_EXIT_IMPL_LUA)
.set_name("=process.exit")
.set_environment(
TableBuilder::new(lua)?
.with_value("yield", coroutine_yield)?
.with_value("exit", set_scheduler_exit_code)?
.build_readonly()?,
)
.into_function()?;
// Create our process exit function, the scheduler crate provides this
let fns = Functions::new(lua)?;
let process_exit = fns.exit;
// Create the full process table
TableBuilder::new(lua)?
.with_value("os", os)?
@ -165,22 +137,10 @@ async fn process_spawn(
lua: &Lua,
(program, args, options): (String, Option<Vec<String>>, ProcessSpawnOptions),
) -> LuaResult<LuaTable> {
/*
Spawn the new process in the background, letting the tokio
runtime place it on a different thread if possible / necessary
Note that we have to use our scheduler here, we can't
be using tokio::task::spawn directly because our lua
scheduler would not drive those futures to completion
*/
let sched = lua
.app_data_ref::<&Scheduler>()
.expect("Lua struct is missing scheduler");
let res = sched
let res = lua
.spawn(spawn_command(program, args, options))
.await
.expect("Failed to receive result of spawned process")?;
.expect("Failed to receive result of spawned process");
/*
NOTE: If an exit code was not given by the child process,

View file

@ -1,4 +1,5 @@
use mlua::prelude::*;
use mlua_luau_scheduler::LuaSpawnExt;
use once_cell::sync::OnceCell;
use crate::{
@ -11,11 +12,9 @@ use crate::{
},
};
use tokio::task;
static REFLECTION_DATABASE: OnceCell<ReflectionDatabase> = OnceCell::new();
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
pub fn create(lua: &Lua) -> LuaResult<LuaTable> {
let mut roblox_constants = Vec::new();
let roblox_module = roblox::module(lua)?;
@ -41,12 +40,12 @@ async fn deserialize_place<'lua>(
contents: LuaString<'lua>,
) -> LuaResult<LuaValue<'lua>> {
let bytes = contents.as_bytes().to_vec();
let fut = task::spawn_blocking(move || {
let fut = lua.spawn_blocking(move || {
let doc = Document::from_bytes(bytes, DocumentKind::Place)?;
let data_model = doc.into_data_model_instance()?;
Ok::<_, DocumentError>(data_model)
});
fut.await.into_lua_err()??.into_lua(lua)
fut.await.into_lua_err()?.into_lua(lua)
}
async fn deserialize_model<'lua>(
@ -54,12 +53,12 @@ async fn deserialize_model<'lua>(
contents: LuaString<'lua>,
) -> LuaResult<LuaValue<'lua>> {
let bytes = contents.as_bytes().to_vec();
let fut = task::spawn_blocking(move || {
let fut = lua.spawn_blocking(move || {
let doc = Document::from_bytes(bytes, DocumentKind::Model)?;
let instance_array = doc.into_instance_array()?;
Ok::<_, DocumentError>(instance_array)
});
fut.await.into_lua_err()??.into_lua(lua)
fut.await.into_lua_err()?.into_lua(lua)
}
async fn serialize_place<'lua>(
@ -67,7 +66,7 @@ async fn serialize_place<'lua>(
(data_model, as_xml): (LuaUserDataRef<'lua, Instance>, Option<bool>),
) -> LuaResult<LuaString<'lua>> {
let data_model = (*data_model).clone();
let fut = task::spawn_blocking(move || {
let fut = lua.spawn_blocking(move || {
let doc = Document::from_data_model_instance(data_model)?;
let bytes = doc.to_bytes_with_format(match as_xml {
Some(true) => DocumentFormat::Xml,
@ -75,7 +74,7 @@ async fn serialize_place<'lua>(
})?;
Ok::<_, DocumentError>(bytes)
});
let bytes = fut.await.into_lua_err()??;
let bytes = fut.await.into_lua_err()?;
lua.create_string(bytes)
}
@ -84,7 +83,7 @@ async fn serialize_model<'lua>(
(instances, as_xml): (Vec<LuaUserDataRef<'lua, Instance>>, Option<bool>),
) -> LuaResult<LuaString<'lua>> {
let instances = instances.iter().map(|i| (*i).clone()).collect();
let fut = task::spawn_blocking(move || {
let fut = lua.spawn_blocking(move || {
let doc = Document::from_instance_array(instances)?;
let bytes = doc.to_bytes_with_format(match as_xml {
Some(true) => DocumentFormat::Xml,
@ -92,7 +91,7 @@ async fn serialize_model<'lua>(
})?;
Ok::<_, DocumentError>(bytes)
});
let bytes = fut.await.into_lua_err()??;
let bytes = fut.await.into_lua_err()?;
lua.create_string(bytes)
}

View file

@ -1,9 +1,7 @@
use lz4_flex::{compress_prepend_size, decompress_size_prepended};
use mlua::prelude::*;
use tokio::{
io::{copy, BufReader},
task,
};
use lz4_flex::{compress_prepend_size, decompress_size_prepended};
use tokio::io::{copy, BufReader};
use async_compression::{
tokio::bufread::{
@ -100,9 +98,7 @@ pub async fn compress<'lua>(
) -> LuaResult<Vec<u8>> {
if let CompressDecompressFormat::LZ4 = format {
let source = source.as_ref().to_vec();
return task::spawn_blocking(move || compress_prepend_size(&source))
.await
.into_lua_err();
return Ok(blocking::unblock(move || compress_prepend_size(&source)).await);
}
let mut bytes = Vec::new();
@ -133,9 +129,8 @@ pub async fn decompress<'lua>(
) -> LuaResult<Vec<u8>> {
if let CompressDecompressFormat::LZ4 = format {
let source = source.as_ref().to_vec();
return task::spawn_blocking(move || decompress_size_prepended(&source))
return blocking::unblock(move || decompress_size_prepended(&source))
.await
.into_lua_err()?
.into_lua_err();
}

View file

@ -8,7 +8,7 @@ use encode_decode::{EncodeDecodeConfig, EncodeDecodeFormat};
use crate::lune::util::TableBuilder;
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
pub fn create(lua: &Lua) -> LuaResult<LuaTable> {
TableBuilder::new(lua)?
.with_function("encode", serde_encode)?
.with_function("decode", serde_decode)?

View file

@ -1,10 +1,8 @@
use mlua::prelude::*;
use dialoguer::{theme::ColorfulTheme, Confirm, Input, MultiSelect, Select};
use tokio::{
io::{self, AsyncWriteExt},
task,
};
use mlua_luau_scheduler::LuaSpawnExt;
use tokio::io::{self, AsyncWriteExt};
use crate::lune::util::{
formatting::{
@ -16,7 +14,7 @@ use crate::lune::util::{
mod prompt;
use prompt::{PromptKind, PromptOptions, PromptResult};
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable<'_>> {
pub fn create(lua: &Lua) -> LuaResult<LuaTable<'_>> {
TableBuilder::new(lua)?
.with_function("color", stdio_color)?
.with_function("style", stdio_style)?
@ -55,10 +53,10 @@ async fn stdio_ewrite(_: &Lua, s: LuaString<'_>) -> LuaResult<()> {
Ok(())
}
async fn stdio_prompt(_: &Lua, options: PromptOptions) -> LuaResult<PromptResult> {
task::spawn_blocking(move || prompt(options))
async fn stdio_prompt(lua: &Lua, options: PromptOptions) -> LuaResult<PromptResult> {
lua.spawn_blocking(move || prompt(options))
.await
.into_lua_err()?
.into_lua_err()
}
fn prompt(options: PromptOptions) -> LuaResult<PromptResult> {

View file

@ -2,120 +2,51 @@ use std::time::Duration;
use mlua::prelude::*;
use mlua_luau_scheduler::Functions;
use tokio::time::{self, Instant};
use crate::lune::{scheduler::Scheduler, util::TableBuilder};
use crate::lune::util::TableBuilder;
mod tof;
use tof::LuaThreadOrFunction;
/*
The spawn function needs special treatment,
we need to yield right away to allow the
spawned task to run until first yield
1. Schedule this current thread at the front
2. Schedule given thread/function at the front,
the previous schedule now comes right after
3. Give control over to the scheduler, which will
resume the above tasks in order when its ready
*/
const SPAWN_IMPL_LUA: &str = r#"
push(currentThread())
local thread = push(...)
yield()
return thread
const DELAY_IMPL_LUA: &str = r#"
return defer(function(...)
wait(select(1, ...))
spawn(select(2, ...))
end, ...)
"#;
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable<'_>> {
let coroutine_running = lua
.globals()
.get::<_, LuaTable>("coroutine")?
.get::<_, LuaFunction>("running")?;
let coroutine_yield = lua
.globals()
.get::<_, LuaTable>("coroutine")?
.get::<_, LuaFunction>("yield")?;
let push_front =
lua.create_function(|lua, (tof, args): (LuaThreadOrFunction, LuaMultiValue)| {
let thread = tof.into_thread(lua)?;
let sched = lua
.app_data_ref::<&Scheduler>()
.expect("Lua struct is missing scheduler");
sched.push_front(lua, thread.clone(), args)?;
Ok(thread)
})?;
let task_spawn_env = TableBuilder::new(lua)?
.with_value("currentThread", coroutine_running)?
.with_value("yield", coroutine_yield)?
.with_value("push", push_front)?
pub fn create(lua: &Lua) -> LuaResult<LuaTable<'_>> {
let fns = Functions::new(lua)?;
// Create wait & delay functions
let task_wait = lua.create_async_function(wait)?;
let task_delay_env = TableBuilder::new(lua)?
.with_value("select", lua.globals().get::<_, LuaFunction>("select")?)?
.with_value("spawn", fns.spawn.clone())?
.with_value("defer", fns.defer.clone())?
.with_value("wait", task_wait.clone())?
.build_readonly()?;
let task_spawn = lua
.load(SPAWN_IMPL_LUA)
.set_name("task.spawn")
.set_environment(task_spawn_env)
let task_delay = lua
.load(DELAY_IMPL_LUA)
.set_name("task.delay")
.set_environment(task_delay_env)
.into_function()?;
// Overwrite resume & wrap functions on the coroutine global
// with ones that are compatible with our scheduler
let co = lua.globals().get::<_, LuaTable>("coroutine")?;
co.set("resume", fns.resume.clone())?;
co.set("wrap", fns.wrap.clone())?;
TableBuilder::new(lua)?
.with_function("cancel", task_cancel)?
.with_function("defer", task_defer)?
.with_function("delay", task_delay)?
.with_value("spawn", task_spawn)?
.with_async_function("wait", task_wait)?
.with_value("cancel", fns.cancel)?
.with_value("defer", fns.defer)?
.with_value("delay", task_delay)?
.with_value("spawn", fns.spawn)?
.with_value("wait", task_wait)?
.build_readonly()
}
fn task_cancel(lua: &Lua, thread: LuaThread) -> LuaResult<()> {
let close = lua
.globals()
.get::<_, LuaTable>("coroutine")?
.get::<_, LuaFunction>("close")?;
match close.call(thread) {
Err(LuaError::CoroutineInactive) => Ok(()),
Err(e) => Err(e),
Ok(()) => Ok(()),
}
}
fn task_defer<'lua>(
lua: &'lua Lua,
(tof, args): (LuaThreadOrFunction<'lua>, LuaMultiValue<'_>),
) -> LuaResult<LuaThread<'lua>> {
let thread = tof.into_thread(lua)?;
let sched = lua
.app_data_ref::<&Scheduler>()
.expect("Lua struct is missing scheduler");
sched.push_back(lua, thread.clone(), args)?;
Ok(thread)
}
// FIXME: `self` escapes outside of method because we are borrowing `tof` and
// `args` when we call `schedule_future_thread` in the lua function body below
// For now we solve this by using the 'static lifetime bound in the impl
fn task_delay<'lua>(
lua: &'lua Lua,
(secs, tof, args): (f64, LuaThreadOrFunction<'lua>, LuaMultiValue<'lua>),
) -> LuaResult<LuaThread<'lua>>
where
'lua: 'static,
{
let thread = tof.into_thread(lua)?;
let sched = lua
.app_data_ref::<&Scheduler>()
.expect("Lua struct is missing scheduler");
let thread2 = thread.clone();
sched.spawn_thread(lua, thread.clone(), async move {
let duration = Duration::from_secs_f64(secs);
time::sleep(duration).await;
sched.push_back(lua, thread2, args)?;
Ok(())
})?;
Ok(thread)
}
async fn task_wait(_: &Lua, secs: Option<f64>) -> LuaResult<f64> {
async fn wait(_: &Lua, secs: Option<f64>) -> LuaResult<f64> {
let duration = Duration::from_secs_f64(secs.unwrap_or_default());
let before = Instant::now();

View file

@ -1,30 +0,0 @@
use mlua::prelude::*;
/// Argument type accepted by the task-library functions: either an
/// existing lua thread, or a function to be wrapped in a new thread.
#[derive(Clone)]
pub(super) enum LuaThreadOrFunction<'lua> {
    Thread(LuaThread<'lua>),
    Function(LuaFunction<'lua>),
}
impl<'lua> LuaThreadOrFunction<'lua> {
    /// Converts this value into a runnable lua thread, creating a fresh
    /// thread from the function when necessary.
    pub(super) fn into_thread(self, lua: &'lua Lua) -> LuaResult<LuaThread<'lua>> {
        match self {
            Self::Function(func) => lua.create_thread(func),
            Self::Thread(thread) => Ok(thread),
        }
    }
}
impl<'lua> FromLua<'lua> for LuaThreadOrFunction<'lua> {
    // Accepts only thread or function values; anything else is a
    // conversion error reported with the offending lua type name.
    fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult<Self> {
        match value {
            LuaValue::Function(func) => Ok(Self::Function(func)),
            LuaValue::Thread(thread) => Ok(Self::Thread(thread)),
            other => Err(LuaError::FromLuaConversionError {
                from: other.type_name(),
                to: "LuaThreadOrFunction",
                message: Some("Expected thread or function".to_string()),
            }),
        }
    }
}

View file

@ -8,7 +8,7 @@ mod require;
mod version;
mod warn;
pub fn inject_all(lua: &'static Lua) -> LuaResult<()> {
pub fn inject_all(lua: &Lua) -> LuaResult<()> {
let all = TableBuilder::new(lua)?
.with_value("_G", g_table::create(lua)?)?
.with_value("_VERSION", version::create(lua)?)?

View file

@ -9,7 +9,8 @@ use crate::lune::util::{
use super::context::*;
pub(super) async fn require<'lua, 'ctx>(
ctx: &'ctx RequireContext<'lua>,
lua: &'lua Lua,
ctx: &'ctx RequireContext,
source: &str,
alias: &str,
path: &str,
@ -71,5 +72,5 @@ where
LuaError::runtime(format!("failed to find relative path for alias '{alias}'"))
})?;
super::path::require_abs_rel(ctx, abs_path, rel_path).await
super::path::require_abs_rel(lua, ctx, abs_path, rel_path).await
}

View file

@ -3,12 +3,12 @@ use mlua::prelude::*;
use super::context::*;
pub(super) async fn require<'lua, 'ctx>(
ctx: &'ctx RequireContext<'lua>,
lua: &'lua Lua,
ctx: &'ctx RequireContext,
name: &str,
) -> LuaResult<LuaMultiValue<'lua>>
where
'lua: 'ctx,
'lua: 'static, // FIXME: Remove static lifetime bound here when builtin libraries no longer need it
{
ctx.load_builtin(name)
ctx.load_builtin(lua, name)
}

View file

@ -5,6 +5,7 @@ use std::{
};
use mlua::prelude::*;
use mlua_luau_scheduler::LuaSchedulerExt;
use tokio::{
fs,
sync::{
@ -13,11 +14,7 @@ use tokio::{
},
};
use crate::lune::{
builtins::LuneBuiltin,
scheduler::{IntoLuaThread, Scheduler},
util::paths::CWD,
};
use crate::lune::{builtins::LuneBuiltin, util::paths::CWD};
/**
Context containing cached results for all `require` operations.
@ -26,14 +23,13 @@ use crate::lune::{
path will first be transformed into an absolute path.
*/
#[derive(Debug, Clone)]
pub(super) struct RequireContext<'lua> {
lua: &'lua Lua,
pub(super) struct RequireContext {
cache_builtins: Arc<AsyncMutex<HashMap<LuneBuiltin, LuaResult<LuaRegistryKey>>>>,
cache_results: Arc<AsyncMutex<HashMap<PathBuf, LuaResult<LuaRegistryKey>>>>,
cache_pending: Arc<AsyncMutex<HashMap<PathBuf, Sender<()>>>>,
}
impl<'lua> RequireContext<'lua> {
impl RequireContext {
/**
Creates a new require context for the given [`Lua`] struct.
@ -41,9 +37,8 @@ impl<'lua> RequireContext<'lua> {
context should be created per [`Lua`] struct, creating more
than one context may lead to undefined require-behavior.
*/
pub fn new(lua: &'lua Lua) -> Self {
pub fn new() -> Self {
Self {
lua,
cache_builtins: Arc::new(AsyncMutex::new(HashMap::new())),
cache_results: Arc::new(AsyncMutex::new(HashMap::new())),
cache_pending: Arc::new(AsyncMutex::new(HashMap::new())),
@ -107,7 +102,11 @@ impl<'lua> RequireContext<'lua> {
Will panic if the path has not been cached, use [`is_cached`] first.
*/
pub fn get_from_cache(&self, abs_path: impl AsRef<Path>) -> LuaResult<LuaMultiValue<'lua>> {
pub fn get_from_cache<'lua>(
&self,
lua: &'lua Lua,
abs_path: impl AsRef<Path>,
) -> LuaResult<LuaMultiValue<'lua>> {
let results = self
.cache_results
.try_lock()
@ -119,8 +118,7 @@ impl<'lua> RequireContext<'lua> {
match cached {
Err(e) => Err(e.clone()),
Ok(k) => {
let multi_vec = self
.lua
let multi_vec = lua
.registry_value::<Vec<LuaValue>>(k)
.expect("Missing require result in lua registry");
Ok(LuaMultiValue::from_vec(multi_vec))
@ -133,8 +131,9 @@ impl<'lua> RequireContext<'lua> {
Will panic if the path has not been cached, use [`is_cached`] first.
*/
pub async fn wait_for_cache(
pub async fn wait_for_cache<'lua>(
&self,
lua: &'lua Lua,
abs_path: impl AsRef<Path>,
) -> LuaResult<LuaMultiValue<'lua>> {
let mut thread_recv = {
@ -150,43 +149,37 @@ impl<'lua> RequireContext<'lua> {
thread_recv.recv().await.into_lua_err()?;
self.get_from_cache(abs_path.as_ref())
self.get_from_cache(lua, abs_path.as_ref())
}
async fn load(
async fn load<'lua>(
&self,
lua: &'lua Lua,
abs_path: impl AsRef<Path>,
rel_path: impl AsRef<Path>,
) -> LuaResult<LuaRegistryKey> {
let abs_path = abs_path.as_ref();
let rel_path = rel_path.as_ref();
let sched = self
.lua
.app_data_ref::<&Scheduler>()
.expect("Lua struct is missing scheduler");
// Read the file at the given path, try to parse and
// load it into a new lua thread that we can schedule
let file_contents = fs::read(&abs_path).await?;
let file_thread = self
.lua
let file_thread = lua
.load(file_contents)
.set_name(rel_path.to_string_lossy().to_string())
.into_function()?
.into_lua_thread(self.lua)?;
.set_name(rel_path.to_string_lossy().to_string());
// Schedule the thread to run, wait for it to finish running
let thread_id = sched.push_back(self.lua, file_thread, ())?;
let thread_res = sched.wait_for_thread(self.lua, thread_id).await;
let thread_id = lua.push_thread_back(file_thread, ())?;
lua.track_thread(thread_id);
lua.wait_for_thread(thread_id).await;
let thread_res = lua.get_thread_result(thread_id).unwrap();
// Return the result of the thread, storing any lua value(s) in the registry
match thread_res {
Err(e) => Err(e),
Ok(v) => {
let multi_vec = v.into_vec();
let multi_key = self
.lua
let multi_key = lua
.create_registry_value(multi_vec)
.expect("Failed to store require result in registry - out of memory");
Ok(multi_key)
@ -197,8 +190,9 @@ impl<'lua> RequireContext<'lua> {
/**
Loads (requires) the file at the given path.
*/
pub async fn load_with_caching(
pub async fn load_with_caching<'lua>(
&self,
lua: &'lua Lua,
abs_path: impl AsRef<Path>,
rel_path: impl AsRef<Path>,
) -> LuaResult<LuaMultiValue<'lua>> {
@ -213,12 +207,11 @@ impl<'lua> RequireContext<'lua> {
.insert(abs_path.to_path_buf(), broadcast_tx);
// Try to load at this abs path
let load_res = self.load(abs_path, rel_path).await;
let load_res = self.load(lua, abs_path, rel_path).await;
let load_val = match &load_res {
Err(e) => Err(e.clone()),
Ok(k) => {
let multi_vec = self
.lua
let multi_vec = lua
.registry_value::<Vec<LuaValue>>(k)
.expect("Failed to fetch require result from registry");
Ok(LuaMultiValue::from_vec(multi_vec))
@ -250,10 +243,11 @@ impl<'lua> RequireContext<'lua> {
/**
Loads (requires) the builtin with the given name.
*/
pub fn load_builtin(&self, name: impl AsRef<str>) -> LuaResult<LuaMultiValue<'lua>>
where
'lua: 'static, // FIXME: Remove static lifetime bound here when builtin libraries no longer need it
{
pub fn load_builtin<'lua>(
&self,
lua: &'lua Lua,
name: impl AsRef<str>,
) -> LuaResult<LuaMultiValue<'lua>> {
let builtin: LuneBuiltin = match name.as_ref().parse() {
Err(e) => return Err(LuaError::runtime(e)),
Ok(b) => b,
@ -268,8 +262,7 @@ impl<'lua> RequireContext<'lua> {
return match res {
Err(e) => return Err(e.clone()),
Ok(key) => {
let multi_vec = self
.lua
let multi_vec = lua
.registry_value::<Vec<LuaValue>>(key)
.expect("Missing builtin result in lua registry");
Ok(LuaMultiValue::from_vec(multi_vec))
@ -277,7 +270,7 @@ impl<'lua> RequireContext<'lua> {
};
};
let result = builtin.create(self.lua);
let result = builtin.create(lua);
cache.insert(
builtin,
@ -285,8 +278,7 @@ impl<'lua> RequireContext<'lua> {
Err(e) => Err(e),
Ok(multi) => {
let multi_vec = multi.into_vec();
let multi_key = self
.lua
let multi_key = lua
.create_registry_value(multi_vec)
.expect("Failed to store require result in registry - out of memory");
Ok(multi_key)

View file

@ -1,6 +1,6 @@
use mlua::prelude::*;
use crate::lune::{scheduler::LuaSchedulerExt, util::TableBuilder};
use crate::lune::util::TableBuilder;
mod context;
use context::RequireContext;
@ -13,8 +13,8 @@ const REQUIRE_IMPL: &str = r#"
return require(source(), ...)
"#;
pub fn create(lua: &'static Lua) -> LuaResult<impl IntoLua<'_>> {
lua.set_app_data(RequireContext::new(lua));
pub fn create(lua: &Lua) -> LuaResult<impl IntoLua<'_>> {
lua.set_app_data(RequireContext::new());
/*
Require implementation needs a few workarounds:
@ -62,10 +62,7 @@ pub fn create(lua: &'static Lua) -> LuaResult<impl IntoLua<'_>> {
async fn require<'lua>(
lua: &'lua Lua,
(source, path): (LuaString<'lua>, LuaString<'lua>),
) -> LuaResult<LuaMultiValue<'lua>>
where
'lua: 'static, // FIXME: Remove static lifetime bound here when builtin libraries no longer need it
{
) -> LuaResult<LuaMultiValue<'lua>> {
let source = source
.to_str()
.into_lua_err()
@ -86,13 +83,13 @@ where
.strip_prefix("@lune/")
.map(|name| name.to_ascii_lowercase())
{
builtin::require(&context, &builtin_name).await
builtin::require(lua, &context, &builtin_name).await
} else if let Some(aliased_path) = path.strip_prefix('@') {
let (alias, path) = aliased_path.split_once('/').ok_or(LuaError::runtime(
"Require with custom alias must contain '/' delimiter",
))?;
alias::require(&context, &source, alias, path).await
alias::require(lua, &context, &source, alias, path).await
} else {
path::require(&context, &source, &path).await
path::require(lua, &context, &source, &path).await
}
}

View file

@ -5,7 +5,8 @@ use mlua::prelude::*;
use super::context::*;
pub(super) async fn require<'lua, 'ctx>(
ctx: &'ctx RequireContext<'lua>,
lua: &'lua Lua,
ctx: &'ctx RequireContext,
source: &str,
path: &str,
) -> LuaResult<LuaMultiValue<'lua>>
@ -13,11 +14,12 @@ where
'lua: 'ctx,
{
let (abs_path, rel_path) = ctx.resolve_paths(source, path)?;
require_abs_rel(ctx, abs_path, rel_path).await
require_abs_rel(lua, ctx, abs_path, rel_path).await
}
pub(super) async fn require_abs_rel<'lua, 'ctx>(
ctx: &'ctx RequireContext<'lua>,
lua: &'lua Lua,
ctx: &'ctx RequireContext,
abs_path: PathBuf, // Absolute to filesystem
rel_path: PathBuf, // Relative to CWD (for displaying)
) -> LuaResult<LuaMultiValue<'lua>>
@ -25,7 +27,7 @@ where
'lua: 'ctx,
{
// 1. Try to require the exact path
if let Ok(res) = require_inner(ctx, &abs_path, &rel_path).await {
if let Ok(res) = require_inner(lua, ctx, &abs_path, &rel_path).await {
return Ok(res);
}
@ -34,7 +36,7 @@ where
append_extension(&abs_path, "luau"),
append_extension(&rel_path, "luau"),
);
if let Ok(res) = require_inner(ctx, &luau_abs_path, &luau_rel_path).await {
if let Ok(res) = require_inner(lua, ctx, &luau_abs_path, &luau_rel_path).await {
return Ok(res);
}
@ -43,7 +45,7 @@ where
append_extension(&abs_path, "lua"),
append_extension(&rel_path, "lua"),
);
if let Ok(res) = require_inner(ctx, &lua_abs_path, &lua_rel_path).await {
if let Ok(res) = require_inner(lua, ctx, &lua_abs_path, &lua_rel_path).await {
return Ok(res);
}
@ -57,7 +59,7 @@ where
append_extension(&abs_init, "luau"),
append_extension(&rel_init, "luau"),
);
if let Ok(res) = require_inner(ctx, &luau_abs_init, &luau_rel_init).await {
if let Ok(res) = require_inner(lua, ctx, &luau_abs_init, &luau_rel_init).await {
return Ok(res);
}
@ -66,7 +68,7 @@ where
append_extension(&abs_init, "lua"),
append_extension(&rel_init, "lua"),
);
if let Ok(res) = require_inner(ctx, &lua_abs_init, &lua_rel_init).await {
if let Ok(res) = require_inner(lua, ctx, &lua_abs_init, &lua_rel_init).await {
return Ok(res);
}
@ -78,7 +80,8 @@ where
}
async fn require_inner<'lua, 'ctx>(
ctx: &'ctx RequireContext<'lua>,
lua: &'lua Lua,
ctx: &'ctx RequireContext,
abs_path: impl AsRef<Path>,
rel_path: impl AsRef<Path>,
) -> LuaResult<LuaMultiValue<'lua>>
@ -89,11 +92,11 @@ where
let rel_path = rel_path.as_ref();
if ctx.is_cached(abs_path)? {
ctx.get_from_cache(abs_path)
ctx.get_from_cache(lua, abs_path)
} else if ctx.is_pending(abs_path)? {
ctx.wait_for_cache(&abs_path).await
ctx.wait_for_cache(lua, &abs_path).await
} else {
ctx.load_with_caching(&abs_path, &rel_path).await
ctx.load_with_caching(lua, &abs_path, &rel_path).await
}
}

View file

@ -1,47 +1,42 @@
use std::process::ExitCode;
use std::{
process::ExitCode,
rc::Rc,
sync::{
atomic::{AtomicBool, Ordering},
Arc,
},
};
use mlua::Lua;
use mlua_luau_scheduler::Scheduler;
mod builtins;
mod error;
mod globals;
mod scheduler;
pub(crate) mod util;
use self::scheduler::{LuaSchedulerExt, Scheduler};
pub use error::RuntimeError;
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct Runtime {
lua: &'static Lua,
scheduler: &'static Scheduler<'static>,
lua: Rc<Lua>,
args: Vec<String>,
}
impl Runtime {
/**
Creates a new Lune runtime, with a new Luau VM and task scheduler.
Creates a new Lune runtime, with a new Luau VM.
*/
#[allow(clippy::new_without_default)]
pub fn new() -> Self {
/*
FUTURE: Stop leaking these when we have removed the lifetime
on the scheduler and can place them in lua app data using arc
let lua = Rc::new(Lua::new());
See the scheduler struct for more notes
*/
let lua = Lua::new().into_static();
let scheduler = Scheduler::new().into_static();
lua.set_scheduler(scheduler);
lua.set_app_data(Rc::downgrade(&lua));
lua.set_app_data(Vec::<String>::new());
globals::inject_all(lua).expect("Failed to inject lua globals");
Self {
lua,
scheduler,
args: Vec::new(),
}
}
@ -68,13 +63,35 @@ impl Runtime {
script_name: impl AsRef<str>,
script_contents: impl AsRef<[u8]>,
) -> Result<ExitCode, RuntimeError> {
// Create a new scheduler for this run
let sched = Scheduler::new(&self.lua);
globals::inject_all(&self.lua)?;
// Add error callback to format errors nicely + store status
let got_any_error = Arc::new(AtomicBool::new(false));
let got_any_inner = Arc::clone(&got_any_error);
sched.set_error_callback(move |e| {
got_any_inner.store(true, Ordering::SeqCst);
eprintln!("{}", RuntimeError::from(e));
});
// Load our "main" thread
let main = self
.lua
.load(script_contents.as_ref())
.set_name(script_name.as_ref());
self.scheduler.push_back(self.lua, main, ())?;
// Run it on our scheduler until it and any other spawned threads complete
sched.push_thread_back(main, ())?;
sched.run().await;
Ok(self.scheduler.run_to_completion(self.lua).await)
// Return the exit code - default to FAILURE if we got any errors
Ok(sched.get_exit_code().unwrap_or({
if got_any_error.load(Ordering::SeqCst) {
ExitCode::FAILURE
} else {
ExitCode::SUCCESS
}
}))
}
}

View file

@ -1,138 +0,0 @@
use futures_util::Future;
use mlua::prelude::*;
use tokio::{
sync::oneshot::{self, Receiver},
task,
};
use super::{IntoLuaThread, Scheduler};
impl<'fut> Scheduler<'fut> {
/**
Checks if there are any futures to run, for
lua futures and background futures respectively.
*/
pub(super) fn has_futures(&self) -> (bool, bool) {
(
self.futures_lua
.try_lock()
.expect("Failed to lock lua futures for check")
.len()
> 0,
self.futures_background
.try_lock()
.expect("Failed to lock background futures for check")
.len()
> 0,
)
}
/**
Schedules a plain future to run in the background.
This will potentially spawn the future on a different thread, using
[`task::spawn`], meaning the provided future must implement [`Send`].
Returns a [`Receiver`] which may be `await`-ed
to retrieve the result of the spawned future.
This [`Receiver`] may be safely ignored if the result of the
spawned future is not needed, the future will run either way.
*/
pub fn spawn<F>(&self, fut: F) -> Receiver<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let (tx, rx) = oneshot::channel();
let handle = task::spawn(async move {
let res = fut.await;
tx.send(res).ok();
});
// NOTE: We must spawn a future on our scheduler which awaits
// the handle from tokio to start driving our future properly
let futs = self
.futures_background
.try_lock()
.expect("Failed to lock futures queue for background tasks");
futs.push(Box::pin(async move {
handle.await.ok();
}));
// NOTE: We might be resuming lua futures, need to signal that a
// new background future is ready to break out of futures resumption
self.state.message_sender().send_spawned_background_future();
rx
}
/**
Equivalent to [`spawn`], except the future is only
spawned on the Lune scheduler, and on the main thread.
*/
pub fn spawn_local<F>(&self, fut: F) -> Receiver<F::Output>
where
F: Future + 'static,
F::Output: 'static,
{
let (tx, rx) = oneshot::channel();
let futs = self
.futures_background
.try_lock()
.expect("Failed to lock futures queue for background tasks");
futs.push(Box::pin(async move {
let res = fut.await;
tx.send(res).ok();
}));
// NOTE: We might be resuming lua futures, need to signal that a
// new background future is ready to break out of futures resumption
self.state.message_sender().send_spawned_background_future();
rx
}
/**
Schedules the given `thread` to run when the given `fut` completes.
If the given future returns a [`LuaError`], that error will be passed to the given `thread`.
*/
pub fn spawn_thread<F, FR>(
&'fut self,
lua: &'fut Lua,
thread: impl IntoLuaThread<'fut>,
fut: F,
) -> LuaResult<()>
where
FR: IntoLuaMulti<'fut>,
F: Future<Output = LuaResult<FR>> + 'fut,
{
let thread = thread.into_lua_thread(lua)?;
let futs = self.futures_lua.try_lock().expect(
"Failed to lock futures queue - \
can't schedule future lua threads during futures resumption",
);
futs.push(Box::pin(async move {
match fut.await.and_then(|rets| rets.into_lua_multi(lua)) {
Err(e) => {
self.push_err(lua, thread, e)
.expect("Failed to schedule future err thread");
}
Ok(v) => {
self.push_back(lua, thread, v)
.expect("Failed to schedule future thread");
}
}
}));
// NOTE: We might be resuming background futures, need to signal that a
// new background future is ready to break out of futures resumption
self.state.message_sender().send_spawned_lua_future();
Ok(())
}
}

View file

@ -1,265 +0,0 @@
use std::{process::ExitCode, sync::Arc};
use futures_util::StreamExt;
use mlua::prelude::*;
use tokio::task::LocalSet;
use tracing::debug;
use crate::lune::util::traits::LuaEmitErrorExt;
use super::Scheduler;
impl<'fut> Scheduler<'fut> {
/**
Runs all lua threads to completion.
*/
fn run_lua_threads(&self, lua: &Lua) {
if self.state.has_exit_code() {
return;
}
let mut count = 0;
// Pop threads from the scheduler until there are none left
while let Some(thread) = self
.pop_thread()
.expect("Failed to pop thread from scheduler")
{
// Deconstruct the scheduler thread into its parts
let thread_id = thread.id();
let (thread, args) = thread.into_inner(lua);
// Make sure this thread is still resumable, it might have
// been resumed somewhere else or even have been cancelled
if thread.status() != LuaThreadStatus::Resumable {
continue;
}
// Resume the thread, ensuring that the schedulers
// current thread id is set correctly for error catching
self.state.set_current_thread_id(Some(thread_id));
let res = thread.resume::<_, LuaMultiValue>(args);
self.state.set_current_thread_id(None);
count += 1;
// If we got any resumption (lua-side) error, increment
// the error count of the scheduler so we can exit with
// a non-zero exit code, and print it out to stderr
if let Err(err) = &res {
self.state.increment_error_count();
lua.emit_error(err.clone());
}
// If the thread has finished running completely,
// send results of final resume to any listeners
if thread.status() != LuaThreadStatus::Resumable {
// NOTE: Threads that were spawned to resume
// with an error will not have a result sender
if let Some(sender) = self
.thread_senders
.try_lock()
.expect("Failed to get thread senders")
.remove(&thread_id)
{
if sender.receiver_count() > 0 {
let stored = match res {
Err(e) => Err(e),
Ok(v) => Ok(Arc::new(lua.create_registry_value(v.into_vec()).expect(
"Failed to store thread results in registry - out of memory",
))),
};
sender
.send(stored)
.expect("Failed to broadcast thread results");
}
}
}
if self.state.has_exit_code() {
break;
}
}
if count > 0 {
debug! {
%count,
"resumed lua"
}
}
}
/**
Runs the next lua future to completion.
Panics if no lua future is queued.
*/
async fn run_future_lua(&self) {
let mut futs = self
.futures_lua
.try_lock()
.expect("Failed to lock lua futures for resumption");
assert!(futs.len() > 0, "No lua futures are queued");
futs.next().await;
}
/**
Runs the next background future to completion.
Panics if no background future is queued.
*/
async fn run_future_background(&self) {
let mut futs = self
.futures_background
.try_lock()
.expect("Failed to lock background futures for resumption");
assert!(futs.len() > 0, "No background futures are queued");
futs.next().await;
}
/**
Runs as many futures as possible, until a new lua thread
is ready, or an exit code has been set for the scheduler.
### Implementation details
Running futures on our scheduler consists of a couple moving parts:
1. An unordered futures queue for lua (main thread, local) futures
2. An unordered futures queue for background (multithreaded, 'static lifetime) futures
3. A signal for breaking out of futures resumption
The two unordered futures queues need to run concurrently,
but since `FuturesUnordered` returns instantly if it does
not currently have any futures queued on it, we need to do
this branching loop, checking if each queue has futures first.
We also need to listen for our signal, to see if we should break out of resumption:
* Always break out of resumption if a new lua thread is ready
* Always break out of resumption if an exit code has been set
* Break out of lua futures resumption if we have a new background future
* Break out of background futures resumption if we have a new lua future
We need to listen for both future queues concurrently,
and break out whenever the other corresponding queue has
a new future, since the other queue may resume sooner.
*/
async fn run_futures(&self) {
let (mut has_lua, mut has_background) = self.has_futures();
if !has_lua && !has_background {
return;
}
let mut rx = self.state.message_receiver();
let mut count = 0;
while has_lua || has_background {
if has_lua && has_background {
tokio::select! {
_ = self.run_future_lua() => {},
_ = self.run_future_background() => {},
msg = rx.recv() => {
if let Some(msg) = msg {
if msg.should_break_futures() {
break;
}
}
}
}
count += 1;
} else if has_lua {
tokio::select! {
_ = self.run_future_lua() => {},
msg = rx.recv() => {
if let Some(msg) = msg {
if msg.should_break_lua_futures() {
break;
}
}
}
}
count += 1;
} else if has_background {
tokio::select! {
_ = self.run_future_background() => {},
msg = rx.recv() => {
if let Some(msg) = msg {
if msg.should_break_background_futures() {
break;
}
}
}
}
count += 1;
}
(has_lua, has_background) = self.has_futures();
}
if count > 0 {
debug! {
%count,
"resumed lua futures"
}
}
}
/**
Runs the scheduler to completion in a [`LocalSet`],
both normal lua threads and futures, prioritizing
lua threads over completion of any pending futures.
Will emit lua output and errors to stdout and stderr.
*/
pub async fn run_to_completion(&self, lua: &Lua) -> ExitCode {
if let Some(code) = self.state.exit_code() {
return ExitCode::from(code);
}
let set = LocalSet::new();
let _guard = set.enter();
loop {
// 1. Run lua threads until exit or there are none left
self.run_lua_threads(lua);
// 2. If we got a manual exit code from lua we should
// not try to wait for any pending futures to complete
if self.state.has_exit_code() {
break;
}
// 3. Keep resuming futures until there are no futures left to
// resume, or until we manually break out of resumption for any
// reason, this may be because a future spawned a new lua thread
self.run_futures().await;
// 4. Once again, check for an exit code, in case a future sets one
if self.state.has_exit_code() {
break;
}
// 5. If we have no lua threads or futures remaining,
// we have now run the scheduler until completion
let (has_future_lua, has_future_background) = self.has_futures();
if !has_future_lua && !has_future_background && !self.has_thread() {
break;
}
}
if let Some(code) = self.state.exit_code() {
debug! {
%code,
"scheduler ran to completion"
};
ExitCode::from(code)
} else if self.state.has_errored() {
debug!("scheduler ran to completion, with failure");
ExitCode::FAILURE
} else {
debug!("scheduler ran to completion, with success");
ExitCode::SUCCESS
}
}
}

View file

@ -1,185 +0,0 @@
use std::sync::Arc;
use mlua::prelude::*;
use super::{
thread::{SchedulerThread, SchedulerThreadId, SchedulerThreadSender},
IntoLuaThread, Scheduler,
};
impl<'fut> Scheduler<'fut> {
/**
Checks if there are any lua threads to run.
*/
pub(super) fn has_thread(&self) -> bool {
!self
.threads
.try_lock()
.expect("Failed to lock threads vec")
.is_empty()
}
/**
Pops the next thread to run, from the front of the scheduler.
Returns `None` if there are no threads left to run.
*/
pub(super) fn pop_thread(&self) -> LuaResult<Option<SchedulerThread>> {
match self
.threads
.try_lock()
.into_lua_err()
.context("Failed to lock threads vec")?
.pop_front()
{
Some(thread) => Ok(Some(thread)),
None => Ok(None),
}
}
/**
Schedules the `thread` to be resumed with the given [`LuaError`].
*/
pub fn push_err<'a>(
&self,
lua: &'a Lua,
thread: impl IntoLuaThread<'a>,
err: LuaError,
) -> LuaResult<()> {
let thread = thread.into_lua_thread(lua)?;
let args = LuaMultiValue::new(); // Will be resumed with error, don't need real args
let thread = SchedulerThread::new(lua, thread, args);
let thread_id = thread.id();
self.state.set_thread_error(thread_id, err);
self.threads
.try_lock()
.into_lua_err()
.context("Failed to lock threads vec")?
.push_front(thread);
// NOTE: We might be resuming futures, need to signal that a
// new lua thread is ready to break out of futures resumption
self.state.message_sender().send_pushed_lua_thread();
Ok(())
}
/**
Schedules the `thread` to be resumed with the given `args`
right away, before any other currently scheduled threads.
*/
pub fn push_front<'a>(
&self,
lua: &'a Lua,
thread: impl IntoLuaThread<'a>,
args: impl IntoLuaMulti<'a>,
) -> LuaResult<SchedulerThreadId> {
let thread = thread.into_lua_thread(lua)?;
let args = args.into_lua_multi(lua)?;
let thread = SchedulerThread::new(lua, thread, args);
let thread_id = thread.id();
self.threads
.try_lock()
.into_lua_err()
.context("Failed to lock threads vec")?
.push_front(thread);
// NOTE: We might be resuming the same thread several times and
// pushing it to the scheduler several times before it is done,
// and we should only ever create one result sender per thread
self.thread_senders
.try_lock()
.into_lua_err()
.context("Failed to lock thread senders vec")?
.entry(thread_id)
.or_insert_with(|| SchedulerThreadSender::new(1));
// NOTE: We might be resuming futures, need to signal that a
// new lua thread is ready to break out of futures resumption
self.state.message_sender().send_pushed_lua_thread();
Ok(thread_id)
}
/**
Schedules the `thread` to be resumed with the given `args`
after all other current threads have been resumed.
*/
pub fn push_back<'a>(
&self,
lua: &'a Lua,
thread: impl IntoLuaThread<'a>,
args: impl IntoLuaMulti<'a>,
) -> LuaResult<SchedulerThreadId> {
let thread = thread.into_lua_thread(lua)?;
let args = args.into_lua_multi(lua)?;
let thread = SchedulerThread::new(lua, thread, args);
let thread_id = thread.id();
self.threads
.try_lock()
.into_lua_err()
.context("Failed to lock threads vec")?
.push_back(thread);
// NOTE: We might be resuming the same thread several times and
// pushing it to the scheduler several times before it is done,
// and we should only ever create one result sender per thread
self.thread_senders
.try_lock()
.into_lua_err()
.context("Failed to lock thread senders vec")?
.entry(thread_id)
.or_insert_with(|| SchedulerThreadSender::new(1));
// NOTE: We might be resuming futures, need to signal that a
// new lua thread is ready to break out of futures resumption
self.state.message_sender().send_pushed_lua_thread();
Ok(thread_id)
}
/**
Waits for the given thread to finish running, and returns its result.
*/
pub async fn wait_for_thread<'a>(
&self,
lua: &'a Lua,
thread_id: SchedulerThreadId,
) -> LuaResult<LuaMultiValue<'a>> {
let mut recv = {
let senders = self.thread_senders.lock().await;
let sender = senders
.get(&thread_id)
.expect("Tried to wait for thread that is not queued");
sender.subscribe()
};
let res = match recv.recv().await {
Err(_) => panic!("Sender was dropped while waiting for {thread_id:?}"),
Ok(r) => r,
};
match res {
Err(e) => Err(e),
Ok(k) => {
let vals = lua
.registry_value::<Vec<LuaValue>>(&k)
.expect("Received invalid registry key for thread");
// NOTE: This is not strictly necessary, mlua can clean
// up registry values on its own, but doing this will add
// some extra safety and clean up registry values faster
if let Some(key) = Arc::into_inner(k) {
lua.remove_registry_value(key)
.expect("Failed to remove registry key for thread");
}
Ok(LuaMultiValue::from_vec(vals))
}
}
}
}

View file

@ -1,98 +0,0 @@
use std::sync::{MutexGuard, TryLockError};
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use super::state::SchedulerState;
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub(crate) enum SchedulerMessage {
ExitCodeSet,
PushedLuaThread,
SpawnedLuaFuture,
SpawnedBackgroundFuture,
}
impl SchedulerMessage {
pub fn should_break_futures(self) -> bool {
matches!(self, Self::ExitCodeSet | Self::PushedLuaThread)
}
pub fn should_break_lua_futures(self) -> bool {
self.should_break_futures() || matches!(self, Self::SpawnedBackgroundFuture)
}
pub fn should_break_background_futures(self) -> bool {
self.should_break_futures() || matches!(self, Self::SpawnedLuaFuture)
}
}
/**
A message sender for the scheduler.
As long as this sender is not dropped, the scheduler
will be kept alive, waiting for more messages to arrive.
*/
pub(crate) struct SchedulerMessageSender(UnboundedSender<SchedulerMessage>);
impl SchedulerMessageSender {
/**
Creates a new message sender for the scheduler.
*/
pub fn new(state: &SchedulerState) -> Self {
Self(
state
.message_sender
.lock()
.expect("Scheduler state was poisoned")
.clone(),
)
}
pub fn send_exit_code_set(&self) {
self.0.send(SchedulerMessage::ExitCodeSet).ok();
}
pub fn send_pushed_lua_thread(&self) {
self.0.send(SchedulerMessage::PushedLuaThread).ok();
}
pub fn send_spawned_lua_future(&self) {
self.0.send(SchedulerMessage::SpawnedLuaFuture).ok();
}
pub fn send_spawned_background_future(&self) {
self.0.send(SchedulerMessage::SpawnedBackgroundFuture).ok();
}
}
/**
A message receiver for the scheduler.
Only one message receiver may exist per scheduler.
*/
pub(crate) struct SchedulerMessageReceiver<'a>(MutexGuard<'a, UnboundedReceiver<SchedulerMessage>>);
impl<'a> SchedulerMessageReceiver<'a> {
/**
Creates a new message receiver for the scheduler.
Panics if the message receiver is already being used.
*/
pub fn new(state: &'a SchedulerState) -> Self {
Self(match state.message_receiver.try_lock() {
Err(TryLockError::Poisoned(_)) => panic!("Sheduler state was poisoned"),
Err(TryLockError::WouldBlock) => {
panic!("Message receiver may only be borrowed once at a time")
}
Ok(guard) => guard,
})
}
// NOTE: Holding this lock across await points is fine, since we
// can only ever create lock exactly one SchedulerMessageReceiver
// See above constructor for details on this
#[allow(clippy::await_holding_lock)]
pub async fn recv(&mut self) -> Option<SchedulerMessage> {
self.0.recv().await
}
}

View file

@ -1,120 +0,0 @@
use std::{
collections::{HashMap, VecDeque},
pin::Pin,
sync::Arc,
};
use futures_util::{stream::FuturesUnordered, Future};
use mlua::prelude::*;
use tokio::sync::Mutex as AsyncMutex;
mod message;
mod state;
mod thread;
mod traits;
mod impl_async;
mod impl_runner;
mod impl_threads;
pub use self::thread::SchedulerThreadId;
pub use self::traits::*;
use self::{
state::SchedulerState,
thread::{SchedulerThread, SchedulerThreadSender},
};
type SchedulerFuture<'fut> = Pin<Box<dyn Future<Output = ()> + 'fut>>;
/**
Scheduler for Lua threads and futures.
This scheduler can be cheaply cloned and the underlying state
and data will remain unchanged and accessible from all clones.
*/
#[derive(Debug, Clone)]
pub(crate) struct Scheduler<'fut> {
state: Arc<SchedulerState>,
threads: Arc<AsyncMutex<VecDeque<SchedulerThread>>>,
thread_senders: Arc<AsyncMutex<HashMap<SchedulerThreadId, SchedulerThreadSender>>>,
/*
FUTURE: Get rid of these, let the tokio runtime handle running
and resumption of futures completely, just use our scheduler
state and receiver to know when we have run to completion.
If we have no senders left, we have run to completion.
We should also investigate using smol / async-executor and its
LocalExecutor struct which does not impose the 'static lifetime
restriction on all of the futures spawned on it, unlike tokio.
If we no longer store futures directly in our scheduler, we
can get rid of the lifetime on it, store it in our lua app
data as a Weak<Scheduler>, together with a Weak<Lua>.
In our lua async functions we can then get a reference to this,
upgrade it to an Arc<Scheduler> and Arc<Lua> to extend lifetimes,
and hopefully get rid of Box::leak and 'static lifetimes for good.
Relevant comment on the mlua repository:
https://github.com/khvzak/mlua/issues/169#issuecomment-1138863979
*/
futures_lua: Arc<AsyncMutex<FuturesUnordered<SchedulerFuture<'fut>>>>,
futures_background: Arc<AsyncMutex<FuturesUnordered<SchedulerFuture<'static>>>>,
}
impl<'fut> Scheduler<'fut> {
/**
Creates a new scheduler.
*/
#[allow(clippy::arc_with_non_send_sync)] // FIXME: Clippy lints our tokio mutexes that are definitely Send + Sync
pub fn new() -> Self {
Self {
state: Arc::new(SchedulerState::new()),
threads: Arc::new(AsyncMutex::new(VecDeque::new())),
thread_senders: Arc::new(AsyncMutex::new(HashMap::new())),
futures_lua: Arc::new(AsyncMutex::new(FuturesUnordered::new())),
futures_background: Arc::new(AsyncMutex::new(FuturesUnordered::new())),
}
}
/**
Sets the luau interrupt for this scheduler.
This will propagate errors from any lua-spawned
futures back to the lua threads that spawned them.
*/
pub fn set_interrupt_for(&self, lua: &Lua) {
// Propagate errors given to the scheduler back to their lua threads
// FUTURE: Do profiling and anything else we need inside of this interrupt
let state = self.state.clone();
lua.set_interrupt(move |_| {
if let Some(id) = state.get_current_thread_id() {
if let Some(err) = state.get_thread_error(id) {
return Err(err);
}
}
Ok(LuaVmState::Continue)
});
}
/**
Sets the exit code for the scheduler.
This will stop the scheduler from resuming any more lua threads or futures.
Panics if the exit code is set more than once.
*/
pub fn set_exit_code(&self, code: impl Into<u8>) {
assert!(
self.state.exit_code().is_none(),
"Exit code may only be set exactly once"
);
self.state.set_exit_code(code.into());
}
#[doc(hidden)]
pub fn into_static(self) -> &'static Self {
Box::leak(Box::new(self))
}
}

View file

@ -1,176 +0,0 @@
use std::{
collections::HashMap,
sync::{
atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering},
Arc, Mutex,
},
};
use mlua::Error as LuaError;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use super::{
message::{SchedulerMessage, SchedulerMessageReceiver, SchedulerMessageSender},
SchedulerThreadId,
};
/**
Internal state for a [`Scheduler`].
This scheduler state uses atomic operations for everything
except lua error storage, and is completely thread safe.
*/
#[derive(Debug)]
pub(crate) struct SchedulerState {
exit_state: AtomicBool,
exit_code: AtomicU8,
num_resumptions: AtomicUsize,
num_errors: AtomicUsize,
thread_id: Arc<Mutex<Option<SchedulerThreadId>>>,
thread_errors: Arc<Mutex<HashMap<SchedulerThreadId, LuaError>>>,
pub(super) message_sender: Arc<Mutex<UnboundedSender<SchedulerMessage>>>,
pub(super) message_receiver: Arc<Mutex<UnboundedReceiver<SchedulerMessage>>>,
}
impl SchedulerState {
    /**
        Creates a new scheduler state.
    */
    pub fn new() -> Self {
        let (message_sender, message_receiver) = unbounded_channel();
        Self {
            exit_state: AtomicBool::new(false),
            exit_code: AtomicU8::new(0),
            num_resumptions: AtomicUsize::new(0),
            num_errors: AtomicUsize::new(0),
            thread_id: Arc::new(Mutex::new(None)),
            thread_errors: Arc::new(Mutex::new(HashMap::new())),
            message_sender: Arc::new(Mutex::new(message_sender)),
            message_receiver: Arc::new(Mutex::new(message_receiver)),
        }
    }

    /**
        Increments the total lua error count for the scheduler.

        This is used to determine if the scheduler should exit with
        a non-zero exit code, when no exit code is explicitly set.
    */
    pub fn increment_error_count(&self) {
        self.num_errors.fetch_add(1, Ordering::Relaxed);
    }

    /**
        Checks if there have been any lua errors.

        This is used to determine if the scheduler should exit with
        a non-zero exit code, when no exit code is explicitly set.
    */
    pub fn has_errored(&self) -> bool {
        self.num_errors.load(Ordering::SeqCst) > 0
    }

    /**
        Gets the currently set exit code for the scheduler, if any.
    */
    pub fn exit_code(&self) -> Option<u8> {
        if self.exit_state.load(Ordering::SeqCst) {
            Some(self.exit_code.load(Ordering::SeqCst))
        } else {
            None
        }
    }

    /**
        Checks if the scheduler has an explicit exit code set.
    */
    pub fn has_exit_code(&self) -> bool {
        self.exit_state.load(Ordering::SeqCst)
    }

    /**
        Sets the explicit exit code for the scheduler.
    */
    pub fn set_exit_code(&self, code: impl Into<u8>) {
        // FIX: store the code *before* raising the flag - a concurrent call
        // to `exit_code()` that observes `exit_state == true` must never read
        // a stale code value. (Previously the flag was raised first, leaving
        // a window where exit_code() could return Some(0) for a non-zero code.)
        self.exit_code.store(code.into(), Ordering::SeqCst);
        self.exit_state.store(true, Ordering::SeqCst);
        self.message_sender().send_exit_code_set();
    }

    /**
        Gets the currently running lua scheduler thread id, if any.
    */
    pub fn get_current_thread_id(&self) -> Option<SchedulerThreadId> {
        *self
            .thread_id
            .lock()
            .expect("Failed to lock current thread id")
    }

    /**
        Sets the currently running lua scheduler thread id.

        This must be set to `Some(id)` just before resuming a lua thread,
        and `None` while no lua thread is being resumed. If set to `Some`
        while the current thread id is also `Some`, this will panic.

        Must only be set once per thread id, although this
        is not checked at runtime for performance reasons.
    */
    pub fn set_current_thread_id(&self, id: Option<SchedulerThreadId>) {
        // FIX: only count actual resumptions - each resumption consists of a
        // `Some(id)` call followed by a `None` call, so incrementing
        // unconditionally would double-count every resumption.
        if id.is_some() {
            self.num_resumptions.fetch_add(1, Ordering::Relaxed);
        }
        let mut thread_id = self
            .thread_id
            .lock()
            .expect("Failed to lock current thread id");
        assert!(
            id.is_none() || thread_id.is_none(),
            "Current thread id can not be overwritten"
        );
        *thread_id = id;
    }

    /**
        Gets the [`LuaError`] (if any) for the given `id`.

        Note that this removes the error from the scheduler state completely.
    */
    pub fn get_thread_error(&self, id: SchedulerThreadId) -> Option<LuaError> {
        let mut thread_errors = self
            .thread_errors
            .lock()
            .expect("Failed to lock thread errors");
        thread_errors.remove(&id)
    }

    /**
        Sets a [`LuaError`] for the given `id`.

        Note that this will replace any already existing [`LuaError`].
    */
    pub fn set_thread_error(&self, id: SchedulerThreadId, err: LuaError) {
        let mut thread_errors = self
            .thread_errors
            .lock()
            .expect("Failed to lock thread errors");
        thread_errors.insert(id, err);
    }

    /**
        Creates a new message sender for the scheduler.
    */
    pub fn message_sender(&self) -> SchedulerMessageSender {
        SchedulerMessageSender::new(self)
    }

    /**
        Tries to borrow the message receiver for the scheduler.

        Panics if the message receiver is already being used.
    */
    pub fn message_receiver(&self) -> SchedulerMessageReceiver {
        SchedulerMessageReceiver::new(self)
    }
}

View file

@ -1,105 +0,0 @@
use std::sync::Arc;
use mlua::prelude::*;
use tokio::sync::broadcast::Sender;
/**
    Type alias for a broadcast [`Sender`], which will
    broadcast the result and return values of a lua thread.

    The return values are stored in the lua registry as a
    `Vec<LuaValue<'_>>`, and the registry key pointing to
    those values will be sent using the broadcast sender.
*/
pub type SchedulerThreadSender = Sender<LuaResult<Arc<LuaRegistryKey>>>;

/**
    Unique, randomly generated id for a scheduler thread.
*/
// NOTE(review): the id is derived from the address in mlua's debug
// representation of the thread ref (see the `From<&LuaThread>` impl),
// not from an RNG - "randomly generated" refers to address randomness.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct SchedulerThreadId(usize);
impl From<&LuaThread<'_>> for SchedulerThreadId {
    fn from(value: &LuaThread) -> Self {
        // HACK: We rely on the debug format of mlua
        // thread refs here, but currently this is the
        // only way to get a proper unique id using mlua
        //
        // The debug repr looks like `Thread(Ref(0x<hex-address>))`, so we
        // slice out the hex digits between the prefix and the first `)`.
        let repr = format!("{value:?}");
        let hex = repr
            .strip_prefix("Thread(Ref(0x")
            .expect("Invalid thread address format - unknown prefix");
        let end = hex
            .find(')')
            .expect("Invalid thread address format - missing ')'");
        let id = usize::from_str_radix(&hex[..end], 16)
            .expect("Failed to parse thread address as hexadecimal into usize");
        Self(id)
    }
}
/**
    Container for registry keys that point to a thread and thread arguments.
*/
#[derive(Debug)]
pub(super) struct SchedulerThread {
    // Id derived from the original thread, kept so the thread can be
    // identified without pulling it back out of the registry
    thread_id: SchedulerThreadId,
    // Registry key holding the lua thread itself
    key_thread: LuaRegistryKey,
    // Registry key holding the thread's arguments as a Vec<LuaValue>
    key_args: LuaRegistryKey,
}
impl SchedulerThread {
/**
Creates a new scheduler thread container from the given thread and arguments.
May fail if an allocation error occurs, is not fallible otherwise.
*/
pub(super) fn new<'lua>(
lua: &'lua Lua,
thread: LuaThread<'lua>,
args: LuaMultiValue<'lua>,
) -> Self {
let args_vec = args.into_vec();
let thread_id = SchedulerThreadId::from(&thread);
let key_thread = lua
.create_registry_value(thread)
.expect("Failed to store thread in registry - out of memory");
let key_args = lua
.create_registry_value(args_vec)
.expect("Failed to store thread args in registry - out of memory");
Self {
thread_id,
key_thread,
key_args,
}
}
/**
Extracts the inner thread and args from the container.
*/
pub(super) fn into_inner(self, lua: &Lua) -> (LuaThread<'_>, LuaMultiValue<'_>) {
let thread = lua
.registry_value(&self.key_thread)
.expect("Failed to get thread from registry");
let args_vec = lua
.registry_value(&self.key_args)
.expect("Failed to get thread args from registry");
let args = LuaMultiValue::from_vec(args_vec);
lua.remove_registry_value(self.key_thread)
.expect("Failed to remove thread from registry");
lua.remove_registry_value(self.key_args)
.expect("Failed to remove thread args from registry");
(thread, args)
}
/**
Retrieves the unique, randomly generated id for this scheduler thread.
*/
pub(super) fn id(&self) -> SchedulerThreadId {
self.thread_id
}
}

View file

@ -1,118 +0,0 @@
use futures_util::Future;
use mlua::prelude::*;
use super::Scheduler;
// Lua shim used by `create_async_function`: `schedule` (provided via the
// function environment) queues the async future on the scheduler for the
// current thread, then `yield` suspends the thread until the scheduler
// resumes it with the future's results.
const ASYNC_IMPL_LUA: &str = r#"
schedule(...)
return yield()
"#;
/**
    Trait for extensions to the [`Lua`] struct, allowing
    for access to the scheduler without having to import
    it or handle registry / app data references manually.
*/
pub(crate) trait LuaSchedulerExt<'lua> {
    /**
        Sets the scheduler for the [`Lua`] struct.
    */
    fn set_scheduler(&'lua self, scheduler: &'lua Scheduler);

    /**
        Creates a function callable from Lua that runs an async
        closure and returns the results of it to the call site.

        A scheduler must have been set on the [`Lua`] struct
        beforehand - see `set_scheduler`.
    */
    fn create_async_function<A, R, F, FR>(&'lua self, func: F) -> LuaResult<LuaFunction<'lua>>
    where
        A: FromLuaMulti<'lua>,
        R: IntoLuaMulti<'lua>,
        F: Fn(&'lua Lua, A) -> FR + 'lua,
        FR: Future<Output = LuaResult<R>> + 'lua;
}
// FIXME: `self` escapes outside of method because we are borrowing `func`
// when we call `schedule_future_thread` in the lua function body below
// For now we solve this by using the 'static lifetime bound in the impl
impl<'lua> LuaSchedulerExt<'lua> for Lua
where
    'lua: 'static,
{
    fn set_scheduler(&'lua self, scheduler: &'lua Scheduler) {
        // Store the scheduler in lua app data so lua-callable functions
        // (see `create_async_function` below) can retrieve it later
        self.set_app_data(scheduler);
        scheduler.set_interrupt_for(self);
    }

    fn create_async_function<A, R, F, FR>(&'lua self, func: F) -> LuaResult<LuaFunction<'lua>>
    where
        A: FromLuaMulti<'lua>,
        R: IntoLuaMulti<'lua>,
        F: Fn(&'lua Lua, A) -> FR + 'lua,
        FR: Future<Output = LuaResult<R>> + 'lua,
    {
        // Panic early if no scheduler was set - calling the created
        // function later without one would fail anyway (see below)
        self.app_data_ref::<&Scheduler>()
            .expect("Lua must have a scheduler to create async functions");
        // Build the environment for the lua shim (ASYNC_IMPL_LUA):
        // it needs exactly two globals, `yield` and `schedule`
        let async_env = self.create_table_with_capacity(0, 2)?;
        async_env.set(
            "yield",
            self.globals()
                .get::<_, LuaTable>("coroutine")?
                .get::<_, LuaFunction>("yield")?,
        )?;
        async_env.set(
            "schedule",
            LuaFunction::wrap(move |lua: &Lua, args: A| {
                // Spawn the future on the scheduler, tied to the calling
                // thread - the shim then yields that thread until the
                // scheduler resumes it with the future's results
                let thread = lua.current_thread();
                let future = func(lua, args);
                let sched = lua
                    .app_data_ref::<&Scheduler>()
                    .expect("Lua struct is missing scheduler");
                sched.spawn_thread(lua, thread, future)?;
                Ok(())
            }),
        )?;
        let async_func = self
            .load(ASYNC_IMPL_LUA)
            .set_name("async")
            .set_environment(async_env)
            .into_function()?;
        Ok(async_func)
    }
}
/**
    Trait for any struct that can be turned into an [`LuaThread`]
    and given to the scheduler, implemented for the following types:

    - Lua threads ([`LuaThread`])
    - Lua functions ([`LuaFunction`])
    - Lua chunks ([`LuaChunk`])
*/
pub trait IntoLuaThread<'lua> {
    /**
        Converts the value into a lua thread.

        Errors if the value could not be turned into a thread,
        e.g. if a chunk fails to compile into a function.
    */
    fn into_lua_thread(self, lua: &'lua Lua) -> LuaResult<LuaThread<'lua>>;
}
impl<'lua> IntoLuaThread<'lua> for LuaThread<'lua> {
    // Already a thread - pass through unchanged
    fn into_lua_thread(self, _: &'lua Lua) -> LuaResult<LuaThread<'lua>> {
        Ok(self)
    }
}
impl<'lua> IntoLuaThread<'lua> for LuaFunction<'lua> {
    // Wrap the function in a fresh coroutine
    fn into_lua_thread(self, lua: &'lua Lua) -> LuaResult<LuaThread<'lua>> {
        lua.create_thread(self)
    }
}
impl<'lua, 'a> IntoLuaThread<'lua> for LuaChunk<'lua, 'a> {
    // Compile the chunk into a function, then wrap it in a fresh coroutine
    fn into_lua_thread(self, lua: &'lua Lua) -> LuaResult<LuaThread<'lua>> {
        lua.create_thread(self.into_function()?)
    }
}

View file

@ -1,18 +0,0 @@
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
/**
    A [`Future`] that never completes - polling it always returns
    [`Poll::Pending`].

    Note that it never registers the waker, so the surrounding task will
    not be polled again unless something else wakes it.
*/
#[derive(Debug, Clone, Copy)]
#[must_use = "futures do nothing unless polled"]
pub struct YieldForever;

impl Future for YieldForever {
    type Output = ();

    fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
        // Intentionally pending forever - the waker is never stored or woken
        Poll::Pending
    }
}

/**
    Creates a future that yields forever, never completing.
*/
pub fn yield_forever() -> YieldForever {
    YieldForever
}

View file

@ -1,7 +1,6 @@
mod table_builder;
pub mod formatting;
pub mod futures;
pub mod luaurc;
pub mod paths;
pub mod traits;

View file

@ -4,8 +4,6 @@ use std::future::Future;
use mlua::prelude::*;
use crate::lune::scheduler::LuaSchedulerExt;
pub struct TableBuilder<'lua> {
lua: &'lua Lua,
tab: LuaTable<'lua>,
@ -79,20 +77,13 @@ impl<'lua> TableBuilder<'lua> {
pub fn build(self) -> LuaResult<LuaTable<'lua>> {
Ok(self.tab)
}
}
// FIXME: Remove static lifetime bound here when `create_async_function`
// no longer needs it to compile, then move this into the above impl
impl<'lua> TableBuilder<'lua>
where
'lua: 'static,
{
pub fn with_async_function<K, A, R, F, FR>(self, key: K, func: F) -> LuaResult<Self>
where
K: IntoLua<'lua>,
A: FromLuaMulti<'lua>,
R: IntoLuaMulti<'lua>,
F: Fn(&'lua Lua, A) -> FR + 'lua,
F: Fn(&'lua Lua, A) -> FR + 'static,
FR: Future<Output = LuaResult<R>> + 'lua,
{
let f = self.lua.create_async_function(func)?;

View file

@ -68,6 +68,7 @@ create_tests! {
net_url_decode: "net/url/decode",
net_serve_requests: "net/serve/requests",
net_serve_websockets: "net/serve/websockets",
net_socket_basic: "net/socket/basic",
net_socket_wss: "net/socket/wss",
net_socket_wss_rw: "net/socket/wss_rw",
@ -84,7 +85,6 @@ create_tests! {
require_aliases: "require/tests/aliases",
require_async: "require/tests/async",
require_async_background: "require/tests/async_background",
require_async_concurrent: "require/tests/async_concurrent",
require_async_sequential: "require/tests/async_sequential",
require_builtins: "require/tests/builtins",

View file

@ -0,0 +1,28 @@
local net = require("@lune/net")

-- Discord's WebSocket gateway server is used as a publicly reachable test endpoint
local socket = net.socket("wss://gateway.discord.gg/?v=10&encoding=json")

-- The socket handle should expose its full api surface
assert(type(socket.next) == "function", "next must be a function")
assert(type(socket.send) == "function", "send must be a function")
assert(type(socket.close) == "function", "close must be a function")

-- Ask for the connection to be closed
socket.close()

-- Consume any remaining buffered messages until the close message arrives
repeat
until not socket.next()

assert(type(socket.closeCode) == "number", "closeCode should exist after closing")
assert(socket.closeCode == 1000, "closeCode should be 1000 after closing")

-- Sending on a closed socket must error with a message mentioning the closure
local ok, err = pcall(function()
	socket.send("Hello, world!")
end)

assert(not ok, "send should fail after closing")
assert(
	string.find(tostring(err), "closed") or string.find(tostring(err), "closing"),
	"send should fail with a message that the socket was closed"
)

View file

@ -1,51 +0,0 @@
local net = require("@lune/net")
local process = require("@lune/process")
local stdio = require("@lune/stdio")
local task = require("@lune/task")

-- Spawn an asynchronous background task (eg. web server)
local PORT = 8082

-- Watchdog: fail the test if it does not finish in time, since the
-- server below keeps the process alive indefinitely
task.delay(3, function()
	stdio.ewrite("Test did not complete in time\n")
	task.wait(1)
	process.exit(1)
end)

local handle = net.serve(PORT, function(request)
	return ""
end)

-- Require modules same way we did in the async_concurrent and async_sequential tests
-- NOTE(review): the mix of defer / spawn / direct requires below exercises
-- concurrent require of the same async module from differently-scheduled threads
local module3
local module4

task.defer(function()
	module4 = require("./modules/async")
end)

task.spawn(function()
	module3 = require("./modules/async")
end)

local _module1 = require("./modules/async")
local _module2 = require("./modules/async")

-- Give the deferred / spawned requires time to finish before asserting
task.wait(1)

assert(type(module3) == "table", "Required module3 did not return a table")
assert(module3.Foo == "Bar", "Required module3 did not contain correct values")
assert(module3.Hello == "World", "Required module3 did not contain correct values")

assert(type(module4) == "table", "Required module4 did not return a table")
assert(module4.Foo == "Bar", "Required module4 did not contain correct values")
assert(module4.Hello == "World", "Required module4 did not contain correct values")

-- The module cache must hand back the same table to every requirer
assert(module3 == module4, "Required modules should point to the same return value")

-- Stop the server and exit successfully
handle.stop()
process.exit(0)