mirror of
https://github.com/lune-org/lune.git
synced 2025-01-07 20:09:09 +00:00
Rewrite scheduler and make it smol (#165)
This commit is contained in:
parent
1f211ca0ab
commit
cd34dcb0dd
44 changed files with 1489 additions and 2525 deletions
806
Cargo.lock
generated
806
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
25
Cargo.toml
25
Cargo.toml
|
@ -79,11 +79,19 @@ urlencoding = "2.1"
|
||||||
|
|
||||||
### RUNTIME
|
### RUNTIME
|
||||||
|
|
||||||
|
blocking = "1.5"
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
mlua = { version = "0.9.1", features = ["luau", "luau-jit", "serialize"] }
|
|
||||||
tokio = { version = "1.24", features = ["full", "tracing"] }
|
tokio = { version = "1.24", features = ["full", "tracing"] }
|
||||||
os_str_bytes = { version = "6.4", features = ["conversions"] }
|
os_str_bytes = { version = "7.0", features = ["conversions"] }
|
||||||
|
|
||||||
|
mlua-luau-scheduler = { version = "0.0.2" }
|
||||||
|
mlua = { version = "0.9.6", features = [
|
||||||
|
"luau",
|
||||||
|
"luau-jit",
|
||||||
|
"async",
|
||||||
|
"serialize",
|
||||||
|
] }
|
||||||
|
|
||||||
### SERDE
|
### SERDE
|
||||||
|
|
||||||
|
@ -101,12 +109,17 @@ toml = { version = "0.8", features = ["preserve_order"] }
|
||||||
|
|
||||||
### NET
|
### NET
|
||||||
|
|
||||||
hyper = { version = "0.14", features = ["full"] }
|
hyper = { version = "1.1", features = ["full"] }
|
||||||
hyper-tungstenite = { version = "0.11" }
|
hyper-util = { version = "0.1", features = ["full"] }
|
||||||
|
http = "1.0"
|
||||||
|
http-body-util = { version = "0.1" }
|
||||||
|
hyper-tungstenite = { version = "0.13" }
|
||||||
|
|
||||||
reqwest = { version = "0.11", default-features = false, features = [
|
reqwest = { version = "0.11", default-features = false, features = [
|
||||||
"rustls-tls",
|
"rustls-tls",
|
||||||
] }
|
] }
|
||||||
tokio-tungstenite = { version = "0.20", features = ["rustls-tls-webpki-roots"] }
|
|
||||||
|
tokio-tungstenite = { version = "0.21", features = ["rustls-tls-webpki-roots"] }
|
||||||
|
|
||||||
### DATETIME
|
### DATETIME
|
||||||
chrono = "0.4"
|
chrono = "0.4"
|
||||||
|
@ -115,7 +128,7 @@ chrono_lc = "0.1"
|
||||||
### CLI
|
### CLI
|
||||||
|
|
||||||
anyhow = { optional = true, version = "1.0" }
|
anyhow = { optional = true, version = "1.0" }
|
||||||
env_logger = { optional = true, version = "0.10" }
|
env_logger = { optional = true, version = "0.11" }
|
||||||
itertools = { optional = true, version = "0.12" }
|
itertools = { optional = true, version = "0.12" }
|
||||||
clap = { optional = true, version = "4.1", features = ["derive"] }
|
clap = { optional = true, version = "4.1", features = ["derive"] }
|
||||||
include_dir = { optional = true, version = "0.7", features = ["glob"] }
|
include_dir = { optional = true, version = "0.7", features = ["glob"] }
|
||||||
|
|
|
@ -14,7 +14,7 @@ use copy::copy;
|
||||||
use metadata::FsMetadata;
|
use metadata::FsMetadata;
|
||||||
use options::FsWriteOptions;
|
use options::FsWriteOptions;
|
||||||
|
|
||||||
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
|
pub fn create(lua: &Lua) -> LuaResult<LuaTable> {
|
||||||
TableBuilder::new(lua)?
|
TableBuilder::new(lua)?
|
||||||
.with_async_function("readFile", fs_read_file)?
|
.with_async_function("readFile", fs_read_file)?
|
||||||
.with_async_function("readDir", fs_read_dir)?
|
.with_async_function("readDir", fs_read_dir)?
|
||||||
|
|
|
@ -28,10 +28,7 @@ pub enum LuneBuiltin {
|
||||||
Roblox,
|
Roblox,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'lua> LuneBuiltin
|
impl LuneBuiltin {
|
||||||
where
|
|
||||||
'lua: 'static, // FIXME: Remove static lifetime bound here when builtin libraries no longer need it
|
|
||||||
{
|
|
||||||
pub fn name(&self) -> &'static str {
|
pub fn name(&self) -> &'static str {
|
||||||
match self {
|
match self {
|
||||||
Self::DateTime => "datetime",
|
Self::DateTime => "datetime",
|
||||||
|
@ -47,7 +44,7 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn create(&self, lua: &'lua Lua) -> LuaResult<LuaMultiValue<'lua>> {
|
pub fn create<'lua>(&self, lua: &'lua Lua) -> LuaResult<LuaMultiValue<'lua>> {
|
||||||
let res = match self {
|
let res = match self {
|
||||||
Self::DateTime => datetime::create(lua),
|
Self::DateTime => datetime::create(lua),
|
||||||
Self::Fs => fs::create(lua),
|
Self::Fs => fs::create(lua),
|
||||||
|
|
|
@ -2,8 +2,14 @@ use std::str::FromStr;
|
||||||
|
|
||||||
use mlua::prelude::*;
|
use mlua::prelude::*;
|
||||||
|
|
||||||
use hyper::{header::HeaderName, http::HeaderValue, HeaderMap};
|
use reqwest::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_ENCODING};
|
||||||
use reqwest::{IntoUrl, Method, RequestBuilder};
|
|
||||||
|
use crate::lune::{
|
||||||
|
builtins::serde::compress_decompress::{decompress, CompressDecompressFormat},
|
||||||
|
util::TableBuilder,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{config::RequestConfig, util::header_map_to_table};
|
||||||
|
|
||||||
const REGISTRY_KEY: &str = "NetClient";
|
const REGISTRY_KEY: &str = "NetClient";
|
||||||
|
|
||||||
|
@ -35,16 +41,19 @@ impl NetClientBuilder {
|
||||||
|
|
||||||
pub fn build(self) -> LuaResult<NetClient> {
|
pub fn build(self) -> LuaResult<NetClient> {
|
||||||
let client = self.builder.build().into_lua_err()?;
|
let client = self.builder.build().into_lua_err()?;
|
||||||
Ok(NetClient(client))
|
Ok(NetClient { inner: client })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct NetClient(reqwest::Client);
|
pub struct NetClient {
|
||||||
|
inner: reqwest::Client,
|
||||||
|
}
|
||||||
|
|
||||||
impl NetClient {
|
impl NetClient {
|
||||||
pub fn request<U: IntoUrl>(&self, method: Method, url: U) -> RequestBuilder {
|
pub fn from_registry(lua: &Lua) -> Self {
|
||||||
self.0.request(method, url)
|
lua.named_registry_value(REGISTRY_KEY)
|
||||||
|
.expect("Failed to get NetClient from lua registry")
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn into_registry(self, lua: &Lua) {
|
pub fn into_registry(self, lua: &Lua) {
|
||||||
|
@ -52,16 +61,68 @@ impl NetClient {
|
||||||
.expect("Failed to store NetClient in lua registry");
|
.expect("Failed to store NetClient in lua registry");
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_registry(lua: &Lua) -> Self {
|
pub async fn request(&self, config: RequestConfig) -> LuaResult<NetClientResponse> {
|
||||||
lua.named_registry_value(REGISTRY_KEY)
|
// Create and send the request
|
||||||
.expect("Failed to get NetClient from lua registry")
|
let mut request = self.inner.request(config.method, config.url);
|
||||||
|
for (query, values) in config.query {
|
||||||
|
request = request.query(
|
||||||
|
&values
|
||||||
|
.iter()
|
||||||
|
.map(|v| (query.as_str(), v))
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
for (header, values) in config.headers {
|
||||||
|
for value in values {
|
||||||
|
request = request.header(header.as_str(), value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let res = request
|
||||||
|
.body(config.body.unwrap_or_default())
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.into_lua_err()?;
|
||||||
|
|
||||||
|
// Extract status, headers
|
||||||
|
let res_status = res.status().as_u16();
|
||||||
|
let res_status_text = res.status().canonical_reason();
|
||||||
|
let res_headers = res.headers().clone();
|
||||||
|
|
||||||
|
// Read response bytes
|
||||||
|
let mut res_bytes = res.bytes().await.into_lua_err()?.to_vec();
|
||||||
|
let mut res_decompressed = false;
|
||||||
|
|
||||||
|
// Check for extra options, decompression
|
||||||
|
if config.options.decompress {
|
||||||
|
let decompress_format = res_headers
|
||||||
|
.iter()
|
||||||
|
.find(|(name, _)| {
|
||||||
|
name.as_str()
|
||||||
|
.eq_ignore_ascii_case(CONTENT_ENCODING.as_str())
|
||||||
|
})
|
||||||
|
.and_then(|(_, value)| value.to_str().ok())
|
||||||
|
.and_then(CompressDecompressFormat::detect_from_header_str);
|
||||||
|
if let Some(format) = decompress_format {
|
||||||
|
res_bytes = decompress(format, res_bytes).await?;
|
||||||
|
res_decompressed = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(NetClientResponse {
|
||||||
|
ok: (200..300).contains(&res_status),
|
||||||
|
status_code: res_status,
|
||||||
|
status_message: res_status_text.unwrap_or_default().to_string(),
|
||||||
|
headers: res_headers,
|
||||||
|
body: res_bytes,
|
||||||
|
body_decompressed: res_decompressed,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LuaUserData for NetClient {}
|
impl LuaUserData for NetClient {}
|
||||||
|
|
||||||
impl<'lua> FromLua<'lua> for NetClient {
|
impl FromLua<'_> for NetClient {
|
||||||
fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult<Self> {
|
fn from_lua(value: LuaValue, _: &Lua) -> LuaResult<Self> {
|
||||||
if let LuaValue::UserData(ud) = value {
|
if let LuaValue::UserData(ud) = value {
|
||||||
if let Ok(ctx) = ud.borrow::<NetClient>() {
|
if let Ok(ctx) = ud.borrow::<NetClient>() {
|
||||||
return Ok(ctx.clone());
|
return Ok(ctx.clone());
|
||||||
|
@ -71,10 +132,34 @@ impl<'lua> FromLua<'lua> for NetClient {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'lua> From<&'lua Lua> for NetClient {
|
impl From<&Lua> for NetClient {
|
||||||
fn from(value: &'lua Lua) -> Self {
|
fn from(value: &Lua) -> Self {
|
||||||
value
|
value
|
||||||
.named_registry_value(REGISTRY_KEY)
|
.named_registry_value(REGISTRY_KEY)
|
||||||
.expect("Missing require context in lua registry")
|
.expect("Missing require context in lua registry")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub struct NetClientResponse {
|
||||||
|
ok: bool,
|
||||||
|
status_code: u16,
|
||||||
|
status_message: String,
|
||||||
|
headers: HeaderMap,
|
||||||
|
body: Vec<u8>,
|
||||||
|
body_decompressed: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NetClientResponse {
|
||||||
|
pub fn into_lua_table(self, lua: &Lua) -> LuaResult<LuaTable> {
|
||||||
|
TableBuilder::new(lua)?
|
||||||
|
.with_value("ok", self.ok)?
|
||||||
|
.with_value("statusCode", self.status_code)?
|
||||||
|
.with_value("statusMessage", self.status_message)?
|
||||||
|
.with_value(
|
||||||
|
"headers",
|
||||||
|
header_map_to_table(lua, self.headers, self.body_decompressed)?,
|
||||||
|
)?
|
||||||
|
.with_value("body", lua.create_string(&self.body)?)?
|
||||||
|
.build_readonly()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use std::collections::HashMap;
|
use std::{collections::HashMap, net::Ipv4Addr};
|
||||||
|
|
||||||
use mlua::prelude::*;
|
use mlua::prelude::*;
|
||||||
|
|
||||||
|
@ -6,6 +6,18 @@ use reqwest::Method;
|
||||||
|
|
||||||
use super::util::table_to_hash_map;
|
use super::util::table_to_hash_map;
|
||||||
|
|
||||||
|
const DEFAULT_IP_ADDRESS: Ipv4Addr = Ipv4Addr::new(127, 0, 0, 1);
|
||||||
|
|
||||||
|
const WEB_SOCKET_UPDGRADE_REQUEST_HANDLER: &str = r#"
|
||||||
|
return {
|
||||||
|
status = 426,
|
||||||
|
body = "Upgrade Required",
|
||||||
|
headers = {
|
||||||
|
Upgrade = "websocket",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
"#;
|
||||||
|
|
||||||
// Net request config
|
// Net request config
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
|
@ -21,28 +33,29 @@ impl Default for RequestConfigOptions {
|
||||||
|
|
||||||
impl<'lua> FromLua<'lua> for RequestConfigOptions {
|
impl<'lua> FromLua<'lua> for RequestConfigOptions {
|
||||||
fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult<Self> {
|
fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult<Self> {
|
||||||
// Nil means default options, table means custom options
|
|
||||||
if let LuaValue::Nil = value {
|
if let LuaValue::Nil = value {
|
||||||
return Ok(Self::default());
|
// Nil means default options
|
||||||
|
Ok(Self::default())
|
||||||
} else if let LuaValue::Table(tab) = value {
|
} else if let LuaValue::Table(tab) = value {
|
||||||
// Extract flags
|
// Table means custom options
|
||||||
let decompress = match tab.raw_get::<_, Option<bool>>("decompress") {
|
let decompress = match tab.get::<_, Option<bool>>("decompress") {
|
||||||
Ok(decomp) => Ok(decomp.unwrap_or(true)),
|
Ok(decomp) => Ok(decomp.unwrap_or(true)),
|
||||||
Err(_) => Err(LuaError::RuntimeError(
|
Err(_) => Err(LuaError::RuntimeError(
|
||||||
"Invalid option value for 'decompress' in request config options".to_string(),
|
"Invalid option value for 'decompress' in request config options".to_string(),
|
||||||
)),
|
)),
|
||||||
}?;
|
}?;
|
||||||
return Ok(Self { decompress });
|
Ok(Self { decompress })
|
||||||
|
} else {
|
||||||
|
// Anything else is invalid
|
||||||
|
Err(LuaError::FromLuaConversionError {
|
||||||
|
from: value.type_name(),
|
||||||
|
to: "RequestConfigOptions",
|
||||||
|
message: Some(format!(
|
||||||
|
"Invalid request config options - expected table or nil, got {}",
|
||||||
|
value.type_name()
|
||||||
|
)),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
// Anything else is invalid
|
|
||||||
Err(LuaError::FromLuaConversionError {
|
|
||||||
from: value.type_name(),
|
|
||||||
to: "RequestConfigOptions",
|
|
||||||
message: Some(format!(
|
|
||||||
"Invalid request config options - expected table or nil, got {}",
|
|
||||||
value.type_name()
|
|
||||||
)),
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -60,39 +73,38 @@ impl FromLua<'_> for RequestConfig {
|
||||||
fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult<Self> {
|
fn from_lua(value: LuaValue, lua: &Lua) -> LuaResult<Self> {
|
||||||
// If we just got a string we assume its a GET request to a given url
|
// If we just got a string we assume its a GET request to a given url
|
||||||
if let LuaValue::String(s) = value {
|
if let LuaValue::String(s) = value {
|
||||||
return Ok(Self {
|
Ok(Self {
|
||||||
url: s.to_string_lossy().to_string(),
|
url: s.to_string_lossy().to_string(),
|
||||||
method: Method::GET,
|
method: Method::GET,
|
||||||
query: HashMap::new(),
|
query: HashMap::new(),
|
||||||
headers: HashMap::new(),
|
headers: HashMap::new(),
|
||||||
body: None,
|
body: None,
|
||||||
options: Default::default(),
|
options: Default::default(),
|
||||||
});
|
})
|
||||||
}
|
} else if let LuaValue::Table(tab) = value {
|
||||||
// If we got a table we are able to configure the entire request
|
// If we got a table we are able to configure the entire request
|
||||||
if let LuaValue::Table(tab) = value {
|
|
||||||
// Extract url
|
// Extract url
|
||||||
let url = match tab.raw_get::<_, LuaString>("url") {
|
let url = match tab.get::<_, LuaString>("url") {
|
||||||
Ok(config_url) => Ok(config_url.to_string_lossy().to_string()),
|
Ok(config_url) => Ok(config_url.to_string_lossy().to_string()),
|
||||||
Err(_) => Err(LuaError::runtime("Missing 'url' in request config")),
|
Err(_) => Err(LuaError::runtime("Missing 'url' in request config")),
|
||||||
}?;
|
}?;
|
||||||
// Extract method
|
// Extract method
|
||||||
let method = match tab.raw_get::<_, LuaString>("method") {
|
let method = match tab.get::<_, LuaString>("method") {
|
||||||
Ok(config_method) => config_method.to_string_lossy().trim().to_ascii_uppercase(),
|
Ok(config_method) => config_method.to_string_lossy().trim().to_ascii_uppercase(),
|
||||||
Err(_) => "GET".to_string(),
|
Err(_) => "GET".to_string(),
|
||||||
};
|
};
|
||||||
// Extract query
|
// Extract query
|
||||||
let query = match tab.raw_get::<_, LuaTable>("query") {
|
let query = match tab.get::<_, LuaTable>("query") {
|
||||||
Ok(tab) => table_to_hash_map(tab, "query")?,
|
Ok(tab) => table_to_hash_map(tab, "query")?,
|
||||||
Err(_) => HashMap::new(),
|
Err(_) => HashMap::new(),
|
||||||
};
|
};
|
||||||
// Extract headers
|
// Extract headers
|
||||||
let headers = match tab.raw_get::<_, LuaTable>("headers") {
|
let headers = match tab.get::<_, LuaTable>("headers") {
|
||||||
Ok(tab) => table_to_hash_map(tab, "headers")?,
|
Ok(tab) => table_to_hash_map(tab, "headers")?,
|
||||||
Err(_) => HashMap::new(),
|
Err(_) => HashMap::new(),
|
||||||
};
|
};
|
||||||
// Extract body
|
// Extract body
|
||||||
let body = match tab.raw_get::<_, LuaString>("body") {
|
let body = match tab.get::<_, LuaString>("body") {
|
||||||
Ok(config_body) => Some(config_body.as_bytes().to_owned()),
|
Ok(config_body) => Some(config_body.as_bytes().to_owned()),
|
||||||
Err(_) => None,
|
Err(_) => None,
|
||||||
};
|
};
|
||||||
|
@ -112,29 +124,30 @@ impl FromLua<'_> for RequestConfig {
|
||||||
))),
|
))),
|
||||||
}?;
|
}?;
|
||||||
// Parse any extra options given
|
// Parse any extra options given
|
||||||
let options = match tab.raw_get::<_, LuaValue>("options") {
|
let options = match tab.get::<_, LuaValue>("options") {
|
||||||
Ok(opts) => RequestConfigOptions::from_lua(opts, lua)?,
|
Ok(opts) => RequestConfigOptions::from_lua(opts, lua)?,
|
||||||
Err(_) => RequestConfigOptions::default(),
|
Err(_) => RequestConfigOptions::default(),
|
||||||
};
|
};
|
||||||
// All good, validated and we got what we need
|
// All good, validated and we got what we need
|
||||||
return Ok(Self {
|
Ok(Self {
|
||||||
url,
|
url,
|
||||||
method,
|
method,
|
||||||
query,
|
query,
|
||||||
headers,
|
headers,
|
||||||
body,
|
body,
|
||||||
options,
|
options,
|
||||||
});
|
})
|
||||||
};
|
} else {
|
||||||
// Anything else is invalid
|
// Anything else is invalid
|
||||||
Err(LuaError::FromLuaConversionError {
|
Err(LuaError::FromLuaConversionError {
|
||||||
from: value.type_name(),
|
from: value.type_name(),
|
||||||
to: "RequestConfig",
|
to: "RequestConfig",
|
||||||
message: Some(format!(
|
message: Some(format!(
|
||||||
"Invalid request config - expected string or table, got {}",
|
"Invalid request config - expected string or table, got {}",
|
||||||
value.type_name()
|
value.type_name()
|
||||||
)),
|
)),
|
||||||
})
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -142,54 +155,72 @@ impl FromLua<'_> for RequestConfig {
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct ServeConfig<'a> {
|
pub struct ServeConfig<'a> {
|
||||||
|
pub address: Ipv4Addr,
|
||||||
pub handle_request: LuaFunction<'a>,
|
pub handle_request: LuaFunction<'a>,
|
||||||
pub handle_web_socket: Option<LuaFunction<'a>>,
|
pub handle_web_socket: Option<LuaFunction<'a>>,
|
||||||
pub address: Option<LuaString<'a>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'lua> FromLua<'lua> for ServeConfig<'lua> {
|
impl<'lua> FromLua<'lua> for ServeConfig<'lua> {
|
||||||
fn from_lua(value: LuaValue<'lua>, lua: &'lua Lua) -> LuaResult<Self> {
|
fn from_lua(value: LuaValue<'lua>, lua: &'lua Lua) -> LuaResult<Self> {
|
||||||
let message = match &value {
|
if let LuaValue::Function(f) = &value {
|
||||||
LuaValue::Function(f) => {
|
// Single function = request handler, rest is default
|
||||||
return Ok(ServeConfig {
|
Ok(ServeConfig {
|
||||||
handle_request: f.clone(),
|
handle_request: f.clone(),
|
||||||
handle_web_socket: None,
|
handle_web_socket: None,
|
||||||
address: None,
|
address: DEFAULT_IP_ADDRESS,
|
||||||
|
})
|
||||||
|
} else if let LuaValue::Table(t) = &value {
|
||||||
|
// Table means custom options
|
||||||
|
let address: Option<LuaString> = t.get("address")?;
|
||||||
|
let handle_request: Option<LuaFunction> = t.get("handleRequest")?;
|
||||||
|
let handle_web_socket: Option<LuaFunction> = t.get("handleWebSocket")?;
|
||||||
|
if handle_request.is_some() || handle_web_socket.is_some() {
|
||||||
|
let address: Ipv4Addr = match &address {
|
||||||
|
Some(addr) => {
|
||||||
|
let addr_str = addr.to_str()?;
|
||||||
|
|
||||||
|
addr_str
|
||||||
|
.trim_start_matches("http://")
|
||||||
|
.trim_start_matches("https://")
|
||||||
|
.parse()
|
||||||
|
.map_err(|_e| LuaError::FromLuaConversionError {
|
||||||
|
from: value.type_name(),
|
||||||
|
to: "ServeConfig",
|
||||||
|
message: Some(format!(
|
||||||
|
"IP address format is incorrect - \
|
||||||
|
expected an IP in the form 'http://0.0.0.0' or '0.0.0.0', \
|
||||||
|
got '{addr_str}'"
|
||||||
|
)),
|
||||||
|
})?
|
||||||
|
}
|
||||||
|
None => DEFAULT_IP_ADDRESS,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
address,
|
||||||
|
handle_request: handle_request.unwrap_or_else(|| {
|
||||||
|
lua.load(WEB_SOCKET_UPDGRADE_REQUEST_HANDLER)
|
||||||
|
.into_function()
|
||||||
|
.expect("Failed to create default http responder function")
|
||||||
|
}),
|
||||||
|
handle_web_socket,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Err(LuaError::FromLuaConversionError {
|
||||||
|
from: value.type_name(),
|
||||||
|
to: "ServeConfig",
|
||||||
|
message: Some(String::from(
|
||||||
|
"Invalid serve config - expected table with 'handleRequest' or 'handleWebSocket' function",
|
||||||
|
)),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
LuaValue::Table(t) => {
|
} else {
|
||||||
let handle_request: Option<LuaFunction> = t.raw_get("handleRequest")?;
|
// Anything else is invalid
|
||||||
let handle_web_socket: Option<LuaFunction> = t.raw_get("handleWebSocket")?;
|
Err(LuaError::FromLuaConversionError {
|
||||||
let address: Option<LuaString> = t.raw_get("address")?;
|
from: value.type_name(),
|
||||||
if handle_request.is_some() || handle_web_socket.is_some() {
|
to: "ServeConfig",
|
||||||
return Ok(ServeConfig {
|
message: None,
|
||||||
handle_request: handle_request.unwrap_or_else(|| {
|
})
|
||||||
let chunk = r#"
|
}
|
||||||
return {
|
|
||||||
status = 426,
|
|
||||||
body = "Upgrade Required",
|
|
||||||
headers = {
|
|
||||||
Upgrade = "websocket",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
"#;
|
|
||||||
lua.load(chunk)
|
|
||||||
.into_function()
|
|
||||||
.expect("Failed to create default http responder function")
|
|
||||||
}),
|
|
||||||
handle_web_socket,
|
|
||||||
address,
|
|
||||||
});
|
|
||||||
} else {
|
|
||||||
Some("Missing handleRequest and / or handleWebSocket".to_string())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => None,
|
|
||||||
};
|
|
||||||
Err(LuaError::FromLuaConversionError {
|
|
||||||
from: value.type_name(),
|
|
||||||
to: "ServeConfig",
|
|
||||||
message,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,36 +1,27 @@
|
||||||
use std::net::Ipv4Addr;
|
#![allow(unused_variables)]
|
||||||
|
|
||||||
use mlua::prelude::*;
|
use mlua::prelude::*;
|
||||||
|
use mlua_luau_scheduler::LuaSpawnExt;
|
||||||
use hyper::header::CONTENT_ENCODING;
|
|
||||||
|
|
||||||
use crate::lune::{scheduler::Scheduler, util::TableBuilder};
|
|
||||||
|
|
||||||
use self::{
|
|
||||||
server::{bind_to_addr, create_server},
|
|
||||||
util::header_map_to_table,
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::serde::{
|
|
||||||
compress_decompress::{decompress, CompressDecompressFormat},
|
|
||||||
encode_decode::{EncodeDecodeConfig, EncodeDecodeFormat},
|
|
||||||
};
|
|
||||||
|
|
||||||
mod client;
|
mod client;
|
||||||
mod config;
|
mod config;
|
||||||
mod processing;
|
|
||||||
mod response;
|
|
||||||
mod server;
|
mod server;
|
||||||
mod util;
|
mod util;
|
||||||
mod websocket;
|
mod websocket;
|
||||||
|
|
||||||
use client::{NetClient, NetClientBuilder};
|
use crate::lune::util::TableBuilder;
|
||||||
use config::{RequestConfig, ServeConfig};
|
|
||||||
use websocket::NetWebSocket;
|
|
||||||
|
|
||||||
const DEFAULT_IP_ADDRESS: Ipv4Addr = Ipv4Addr::new(127, 0, 0, 1);
|
use self::{
|
||||||
|
client::{NetClient, NetClientBuilder},
|
||||||
|
config::{RequestConfig, ServeConfig},
|
||||||
|
server::serve,
|
||||||
|
util::create_user_agent_header,
|
||||||
|
websocket::NetWebSocket,
|
||||||
|
};
|
||||||
|
|
||||||
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
|
use super::serde::encode_decode::{EncodeDecodeConfig, EncodeDecodeFormat};
|
||||||
|
|
||||||
|
pub fn create(lua: &Lua) -> LuaResult<LuaTable> {
|
||||||
NetClientBuilder::new()
|
NetClientBuilder::new()
|
||||||
.headers(&[("User-Agent", create_user_agent_header())])?
|
.headers(&[("User-Agent", create_user_agent_header())])?
|
||||||
.build()?
|
.build()?
|
||||||
|
@ -46,14 +37,6 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
|
||||||
.build_readonly()
|
.build_readonly()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn create_user_agent_header() -> String {
|
|
||||||
let (github_owner, github_repo) = env!("CARGO_PKG_REPOSITORY")
|
|
||||||
.trim_start_matches("https://github.com/")
|
|
||||||
.split_once('/')
|
|
||||||
.unwrap();
|
|
||||||
format!("{github_owner}-{github_repo}-cli")
|
|
||||||
}
|
|
||||||
|
|
||||||
fn net_json_encode<'lua>(
|
fn net_json_encode<'lua>(
|
||||||
lua: &'lua Lua,
|
lua: &'lua Lua,
|
||||||
(val, pretty): (LuaValue<'lua>, Option<bool>),
|
(val, pretty): (LuaValue<'lua>, Option<bool>),
|
||||||
|
@ -66,68 +49,14 @@ fn net_json_decode<'lua>(lua: &'lua Lua, json: LuaString<'lua>) -> LuaResult<Lua
|
||||||
EncodeDecodeConfig::from(EncodeDecodeFormat::Json).deserialize_from_string(lua, json)
|
EncodeDecodeConfig::from(EncodeDecodeFormat::Json).deserialize_from_string(lua, json)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn net_request<'lua>(lua: &'lua Lua, config: RequestConfig) -> LuaResult<LuaTable<'lua>>
|
async fn net_request(lua: &Lua, config: RequestConfig) -> LuaResult<LuaTable> {
|
||||||
where
|
|
||||||
'lua: 'static, // FIXME: Get rid of static lifetime bound here
|
|
||||||
{
|
|
||||||
// Create and send the request
|
|
||||||
let client = NetClient::from_registry(lua);
|
let client = NetClient::from_registry(lua);
|
||||||
let mut request = client.request(config.method, &config.url);
|
// NOTE: We spawn the request as a background task to free up resources in lua
|
||||||
for (query, values) in config.query {
|
let res = lua.spawn(async move { client.request(config).await });
|
||||||
request = request.query(
|
res.await?.into_lua_table(lua)
|
||||||
&values
|
|
||||||
.iter()
|
|
||||||
.map(|v| (query.as_str(), v))
|
|
||||||
.collect::<Vec<_>>(),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
for (header, values) in config.headers {
|
|
||||||
for value in values {
|
|
||||||
request = request.header(header.as_str(), value);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
let res = request
|
|
||||||
.body(config.body.unwrap_or_default())
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.into_lua_err()?;
|
|
||||||
// Extract status, headers
|
|
||||||
let res_status = res.status().as_u16();
|
|
||||||
let res_status_text = res.status().canonical_reason();
|
|
||||||
let res_headers = res.headers().clone();
|
|
||||||
// Read response bytes
|
|
||||||
let mut res_bytes = res.bytes().await.into_lua_err()?.to_vec();
|
|
||||||
let mut res_decompressed = false;
|
|
||||||
// Check for extra options, decompression
|
|
||||||
if config.options.decompress {
|
|
||||||
let decompress_format = res_headers
|
|
||||||
.iter()
|
|
||||||
.find(|(name, _)| {
|
|
||||||
name.as_str()
|
|
||||||
.eq_ignore_ascii_case(CONTENT_ENCODING.as_str())
|
|
||||||
})
|
|
||||||
.and_then(|(_, value)| value.to_str().ok())
|
|
||||||
.and_then(CompressDecompressFormat::detect_from_header_str);
|
|
||||||
if let Some(format) = decompress_format {
|
|
||||||
res_bytes = decompress(format, res_bytes).await?;
|
|
||||||
res_decompressed = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// Construct and return a readonly lua table with results
|
|
||||||
let res_headers_lua = header_map_to_table(lua, res_headers, res_decompressed)?;
|
|
||||||
TableBuilder::new(lua)?
|
|
||||||
.with_value("ok", (200..300).contains(&res_status))?
|
|
||||||
.with_value("statusCode", res_status)?
|
|
||||||
.with_value("statusMessage", res_status_text)?
|
|
||||||
.with_value("headers", res_headers_lua)?
|
|
||||||
.with_value("body", lua.create_string(&res_bytes)?)?
|
|
||||||
.build_readonly()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn net_socket<'lua>(lua: &'lua Lua, url: String) -> LuaResult<LuaTable>
|
async fn net_socket(lua: &Lua, url: String) -> LuaResult<LuaTable> {
|
||||||
where
|
|
||||||
'lua: 'static, // FIXME: Get rid of static lifetime bound here
|
|
||||||
{
|
|
||||||
let (ws, _) = tokio_tungstenite::connect_async(url).await.into_lua_err()?;
|
let (ws, _) = tokio_tungstenite::connect_async(url).await.into_lua_err()?;
|
||||||
NetWebSocket::new(ws).into_lua_table(lua)
|
NetWebSocket::new(ws).into_lua_table(lua)
|
||||||
}
|
}
|
||||||
|
@ -135,32 +64,8 @@ where
|
||||||
async fn net_serve<'lua>(
|
async fn net_serve<'lua>(
|
||||||
lua: &'lua Lua,
|
lua: &'lua Lua,
|
||||||
(port, config): (u16, ServeConfig<'lua>),
|
(port, config): (u16, ServeConfig<'lua>),
|
||||||
) -> LuaResult<LuaTable<'lua>>
|
) -> LuaResult<LuaTable<'lua>> {
|
||||||
where
|
serve(lua, port, config).await
|
||||||
'lua: 'static, // FIXME: Get rid of static lifetime bound here
|
|
||||||
{
|
|
||||||
let sched = lua
|
|
||||||
.app_data_ref::<&Scheduler>()
|
|
||||||
.expect("Lua struct is missing scheduler");
|
|
||||||
|
|
||||||
let address: Ipv4Addr = match &config.address {
|
|
||||||
Some(addr) => {
|
|
||||||
let addr_str = addr.to_str()?;
|
|
||||||
|
|
||||||
addr_str
|
|
||||||
.trim_start_matches("http://")
|
|
||||||
.trim_start_matches("https://")
|
|
||||||
.parse()
|
|
||||||
.map_err(|_e| LuaError::RuntimeError(format!(
|
|
||||||
"IP address format is incorrect (expected an IP in the form 'http://0.0.0.0' or '0.0.0.0', got '{addr_str}')"
|
|
||||||
)))?
|
|
||||||
}
|
|
||||||
None => DEFAULT_IP_ADDRESS,
|
|
||||||
};
|
|
||||||
|
|
||||||
let builder = bind_to_addr(address, port)?;
|
|
||||||
|
|
||||||
create_server(lua, &sched, config, builder)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn net_url_encode<'lua>(
|
fn net_url_encode<'lua>(
|
||||||
|
|
|
@ -1,101 +0,0 @@
|
||||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
|
||||||
|
|
||||||
use hyper::{body::to_bytes, Body, Request};
|
|
||||||
|
|
||||||
use mlua::prelude::*;
|
|
||||||
|
|
||||||
use crate::lune::util::TableBuilder;
|
|
||||||
|
|
||||||
static ID_COUNTER: AtomicUsize = AtomicUsize::new(0);
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
|
||||||
pub(super) struct ProcessedRequestId(usize);
|
|
||||||
|
|
||||||
impl ProcessedRequestId {
|
|
||||||
pub fn new() -> Self {
|
|
||||||
// NOTE: This may overflow after a couple billion requests,
|
|
||||||
// but that's completely fine... unless a request is still
|
|
||||||
// alive after billions more arrive and need to be handled
|
|
||||||
Self(ID_COUNTER.fetch_add(1, Ordering::Relaxed))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(super) struct ProcessedRequest {
|
|
||||||
pub id: ProcessedRequestId,
|
|
||||||
method: String,
|
|
||||||
path: String,
|
|
||||||
query: Vec<(String, String)>,
|
|
||||||
headers: Vec<(String, Vec<u8>)>,
|
|
||||||
body: Vec<u8>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl ProcessedRequest {
|
|
||||||
pub async fn from_request(req: Request<Body>) -> LuaResult<Self> {
|
|
||||||
let (head, body) = req.into_parts();
|
|
||||||
|
|
||||||
// FUTURE: We can do extra processing like async decompression here
|
|
||||||
let body = match to_bytes(body).await {
|
|
||||||
Err(_) => return Err(LuaError::runtime("Failed to read request body bytes")),
|
|
||||||
Ok(b) => b.to_vec(),
|
|
||||||
};
|
|
||||||
|
|
||||||
let method = head.method.to_string().to_ascii_uppercase();
|
|
||||||
|
|
||||||
let mut path = head.uri.path().to_string();
|
|
||||||
if path.is_empty() {
|
|
||||||
path = "/".to_string();
|
|
||||||
}
|
|
||||||
|
|
||||||
let query = head
|
|
||||||
.uri
|
|
||||||
.query()
|
|
||||||
.unwrap_or_default()
|
|
||||||
.split('&')
|
|
||||||
.filter_map(|q| q.split_once('='))
|
|
||||||
.map(|(k, v)| (k.to_string(), v.to_string()))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let mut headers = Vec::new();
|
|
||||||
let mut header_name = String::new();
|
|
||||||
for (name_opt, value) in head.headers.into_iter() {
|
|
||||||
if let Some(name) = name_opt {
|
|
||||||
header_name = name.to_string();
|
|
||||||
}
|
|
||||||
headers.push((header_name.clone(), value.as_bytes().to_vec()))
|
|
||||||
}
|
|
||||||
|
|
||||||
let id = ProcessedRequestId::new();
|
|
||||||
|
|
||||||
Ok(Self {
|
|
||||||
id,
|
|
||||||
method,
|
|
||||||
path,
|
|
||||||
query,
|
|
||||||
headers,
|
|
||||||
body,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn into_lua_table(self, lua: &Lua) -> LuaResult<LuaTable> {
|
|
||||||
// FUTURE: Make inner tables for query keys that have multiple values?
|
|
||||||
let query = lua.create_table_with_capacity(0, self.query.len())?;
|
|
||||||
for (key, value) in self.query.into_iter() {
|
|
||||||
query.set(key, value)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let headers = lua.create_table_with_capacity(0, self.headers.len())?;
|
|
||||||
for (key, value) in self.headers.into_iter() {
|
|
||||||
headers.set(key, lua.create_string(value)?)?;
|
|
||||||
}
|
|
||||||
|
|
||||||
let body = lua.create_string(self.body)?;
|
|
||||||
|
|
||||||
TableBuilder::new(lua)?
|
|
||||||
.with_value("method", self.method)?
|
|
||||||
.with_value("path", self.path)?
|
|
||||||
.with_value("query", query)?
|
|
||||||
.with_value("headers", headers)?
|
|
||||||
.with_value("body", body)?
|
|
||||||
.build_readonly()
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,223 +0,0 @@
|
||||||
use std::{
|
|
||||||
collections::HashMap,
|
|
||||||
convert::Infallible,
|
|
||||||
net::{Ipv4Addr, SocketAddr},
|
|
||||||
sync::Arc,
|
|
||||||
};
|
|
||||||
|
|
||||||
use hyper::{
|
|
||||||
server::{conn::AddrIncoming, Builder},
|
|
||||||
service::{make_service_fn, service_fn},
|
|
||||||
Server,
|
|
||||||
};
|
|
||||||
|
|
||||||
use hyper_tungstenite::{is_upgrade_request, upgrade, HyperWebsocket};
|
|
||||||
use mlua::prelude::*;
|
|
||||||
use tokio::sync::{mpsc, oneshot, Mutex};
|
|
||||||
|
|
||||||
use crate::lune::{
|
|
||||||
scheduler::Scheduler,
|
|
||||||
util::{futures::yield_forever, traits::LuaEmitErrorExt, TableBuilder},
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::{
|
|
||||||
config::ServeConfig, processing::ProcessedRequest, response::NetServeResponse,
|
|
||||||
websocket::NetWebSocket,
|
|
||||||
};
|
|
||||||
|
|
||||||
pub(super) fn bind_to_addr(address: Ipv4Addr, port: u16) -> LuaResult<Builder<AddrIncoming>> {
|
|
||||||
let addr = SocketAddr::from((address, port));
|
|
||||||
|
|
||||||
match Server::try_bind(&addr) {
|
|
||||||
Ok(b) => Ok(b),
|
|
||||||
Err(e) => Err(LuaError::external(format!(
|
|
||||||
"Failed to bind to {addr}\n{}",
|
|
||||||
e.to_string()
|
|
||||||
.replace("error creating server listener: ", "> ")
|
|
||||||
))),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(super) fn create_server<'lua>(
|
|
||||||
lua: &'lua Lua,
|
|
||||||
sched: &'lua Scheduler,
|
|
||||||
config: ServeConfig<'lua>,
|
|
||||||
builder: Builder<AddrIncoming>,
|
|
||||||
) -> LuaResult<LuaTable<'lua>>
|
|
||||||
where
|
|
||||||
'lua: 'static, // FIXME: Get rid of static lifetime bound here
|
|
||||||
{
|
|
||||||
// Note that we need to use a mpsc here and not
|
|
||||||
// a oneshot channel since we move the sender
|
|
||||||
// into our table with the stop function
|
|
||||||
let (shutdown_tx, mut shutdown_rx) = mpsc::channel::<()>(1);
|
|
||||||
|
|
||||||
// Communicate between background thread(s) and main lua thread using mpsc and oneshot
|
|
||||||
let (tx_request, mut rx_request) = mpsc::channel::<ProcessedRequest>(64);
|
|
||||||
let (tx_websocket, mut rx_websocket) = mpsc::channel::<HyperWebsocket>(64);
|
|
||||||
let tx_request_arc = Arc::new(tx_request);
|
|
||||||
let tx_websocket_arc = Arc::new(tx_websocket);
|
|
||||||
|
|
||||||
let response_senders = Arc::new(Mutex::new(HashMap::new()));
|
|
||||||
let response_senders_bg = Arc::clone(&response_senders);
|
|
||||||
let response_senders_lua = Arc::clone(&response_senders_bg);
|
|
||||||
|
|
||||||
// Create our background service which will accept
|
|
||||||
// requests, do some processing, then forward to lua
|
|
||||||
let has_websocket_handler = config.handle_web_socket.is_some();
|
|
||||||
let hyper_make_service = make_service_fn(move |_| {
|
|
||||||
let tx_request = Arc::clone(&tx_request_arc);
|
|
||||||
let tx_websocket = Arc::clone(&tx_websocket_arc);
|
|
||||||
let response_senders = Arc::clone(&response_senders_bg);
|
|
||||||
|
|
||||||
let handler = service_fn(move |mut req| {
|
|
||||||
let tx_request = Arc::clone(&tx_request);
|
|
||||||
let tx_websocket = Arc::clone(&tx_websocket);
|
|
||||||
let response_senders = Arc::clone(&response_senders);
|
|
||||||
async move {
|
|
||||||
// FUTURE: Improve error messages when lua is busy and queue is full
|
|
||||||
if has_websocket_handler && is_upgrade_request(&req) {
|
|
||||||
let (response, ws) = match upgrade(&mut req, None) {
|
|
||||||
Err(_) => return Err(LuaError::runtime("Failed to upgrade websocket")),
|
|
||||||
Ok(v) => v,
|
|
||||||
};
|
|
||||||
if (tx_websocket.send(ws).await).is_err() {
|
|
||||||
return Err(LuaError::runtime("Lua handler is busy"));
|
|
||||||
}
|
|
||||||
Ok(response)
|
|
||||||
} else {
|
|
||||||
let processed = ProcessedRequest::from_request(req).await?;
|
|
||||||
let request_id = processed.id;
|
|
||||||
if (tx_request.send(processed).await).is_err() {
|
|
||||||
return Err(LuaError::runtime("Lua handler is busy"));
|
|
||||||
}
|
|
||||||
let (response_tx, response_rx) = oneshot::channel::<NetServeResponse>();
|
|
||||||
response_senders
|
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.insert(request_id, response_tx);
|
|
||||||
match response_rx.await {
|
|
||||||
Err(_) => Err(LuaError::runtime("Internal Server Error")),
|
|
||||||
Ok(r) => r.into_response(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
async move { Ok::<_, Infallible>(handler) }
|
|
||||||
});
|
|
||||||
|
|
||||||
// Start up our service
|
|
||||||
sched.spawn(async move {
|
|
||||||
let result = builder
|
|
||||||
.http1_only(true) // Web sockets can only use http1
|
|
||||||
.http1_keepalive(true) // Web sockets must be kept alive
|
|
||||||
.serve(hyper_make_service)
|
|
||||||
.with_graceful_shutdown(async move {
|
|
||||||
if shutdown_rx.recv().await.is_none() {
|
|
||||||
// The channel was closed, meaning the serve handle
|
|
||||||
// was garbage collected by lua without being used
|
|
||||||
yield_forever().await;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
if let Err(e) = result.await {
|
|
||||||
eprintln!("Net serve error: {e}")
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Spawn a local thread with access to lua and the same lifetime
|
|
||||||
sched.spawn_local(async move {
|
|
||||||
loop {
|
|
||||||
// Wait for either a request or a websocket to handle,
|
|
||||||
// if we got neither it means both channels were dropped
|
|
||||||
// and our server has stopped, either gracefully or panic
|
|
||||||
let (req, sock) = tokio::select! {
|
|
||||||
req = rx_request.recv() => (req, None),
|
|
||||||
sock = rx_websocket.recv() => (None, sock),
|
|
||||||
};
|
|
||||||
if req.is_none() && sock.is_none() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// NOTE: The closure here is not really necessary, we
|
|
||||||
// make the closure so that we can use the `?` operator
|
|
||||||
// and make a catch-all for errors in spawn_local below
|
|
||||||
let handle_request = config.handle_request.clone();
|
|
||||||
let handle_web_socket = config.handle_web_socket.clone();
|
|
||||||
let response_senders = Arc::clone(&response_senders_lua);
|
|
||||||
let response_fut = async move {
|
|
||||||
match (req, sock) {
|
|
||||||
(Some(req), _) => {
|
|
||||||
let req_id = req.id;
|
|
||||||
let req_table = req.into_lua_table(lua)?;
|
|
||||||
|
|
||||||
let thread_id = sched.push_back(lua, handle_request, req_table)?;
|
|
||||||
let thread_res = sched.wait_for_thread(lua, thread_id).await?;
|
|
||||||
|
|
||||||
let response = NetServeResponse::from_lua_multi(thread_res, lua)?;
|
|
||||||
let response_sender = response_senders
|
|
||||||
.lock()
|
|
||||||
.await
|
|
||||||
.remove(&req_id)
|
|
||||||
.expect("Response channel was removed unexpectedly");
|
|
||||||
|
|
||||||
// NOTE: We ignore the error here, if the sender is no longer
|
|
||||||
// being listened to its because our client disconnected during
|
|
||||||
// handler being called, which is fine and should not emit errors
|
|
||||||
response_sender.send(response).ok();
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
(_, Some(sock)) => {
|
|
||||||
let sock = sock.await.into_lua_err()?;
|
|
||||||
|
|
||||||
let sock_handler = handle_web_socket
|
|
||||||
.as_ref()
|
|
||||||
.cloned()
|
|
||||||
.expect("Got web socket but web socket handler is missing");
|
|
||||||
let sock_table = NetWebSocket::new(sock).into_lua_table(lua)?;
|
|
||||||
|
|
||||||
// NOTE: Web socket handler does not need to send any
|
|
||||||
// response back, the websocket upgrade response is
|
|
||||||
// automatically sent above in the background thread(s)
|
|
||||||
let thread_id = sched.push_back(lua, sock_handler, sock_table)?;
|
|
||||||
let _thread_res = sched.wait_for_thread(lua, thread_id).await?;
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
_ => unreachable!(),
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/*
|
|
||||||
NOTE: It is currently not possible to spawn new background tasks from within
|
|
||||||
another background task with the Lune scheduler since they are locked behind a
|
|
||||||
mutex and we also need that mutex locked to be able to run a background task...
|
|
||||||
|
|
||||||
We need to do some work to make it so our unordered futures queues do
|
|
||||||
not require locking and then we can replace the following bit of code:
|
|
||||||
|
|
||||||
sched.spawn_local(async {
|
|
||||||
if let Err(e) = response_fut.await {
|
|
||||||
lua.emit_error(e);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
*/
|
|
||||||
if let Err(e) = response_fut.await {
|
|
||||||
lua.emit_error(e);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
// Create a new read-only table that contains methods
|
|
||||||
// for manipulating server behavior and shutting it down
|
|
||||||
let handle_stop = move |_, _: ()| match shutdown_tx.try_send(()) {
|
|
||||||
Ok(_) => Ok(()),
|
|
||||||
Err(_) => Err(LuaError::RuntimeError(
|
|
||||||
"Server has already been stopped".to_string(),
|
|
||||||
)),
|
|
||||||
};
|
|
||||||
TableBuilder::new(lua)?
|
|
||||||
.with_function("stop", handle_stop)?
|
|
||||||
.build_readonly()
|
|
||||||
}
|
|
61
src/lune/builtins/net/server/keys.rs
Normal file
61
src/lune/builtins/net/server/keys.rs
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
|
||||||
|
use mlua::prelude::*;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy)]
|
||||||
|
pub(super) struct SvcKeys {
|
||||||
|
key_request: &'static str,
|
||||||
|
key_websocket: Option<&'static str>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SvcKeys {
|
||||||
|
pub(super) fn new<'lua>(
|
||||||
|
lua: &'lua Lua,
|
||||||
|
handle_request: LuaFunction<'lua>,
|
||||||
|
handle_websocket: Option<LuaFunction<'lua>>,
|
||||||
|
) -> LuaResult<Self> {
|
||||||
|
static SERVE_COUNTER: AtomicUsize = AtomicUsize::new(0);
|
||||||
|
let count = SERVE_COUNTER.fetch_add(1, Ordering::Relaxed);
|
||||||
|
|
||||||
|
// NOTE: We leak strings here, but this is an acceptable tradeoff since programs
|
||||||
|
// generally only start one or a couple of servers and they are usually never dropped.
|
||||||
|
// Leaking here lets us keep this struct Copy and access the request handler callbacks
|
||||||
|
// very performantly, significantly reducing the per-request overhead of the server.
|
||||||
|
let key_request: &'static str =
|
||||||
|
Box::leak(format!("__net_serve_request_{count}").into_boxed_str());
|
||||||
|
let key_websocket: Option<&'static str> = if handle_websocket.is_some() {
|
||||||
|
Some(Box::leak(
|
||||||
|
format!("__net_serve_websocket_{count}").into_boxed_str(),
|
||||||
|
))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
lua.set_named_registry_value(key_request, handle_request)?;
|
||||||
|
if let Some(key) = key_websocket {
|
||||||
|
lua.set_named_registry_value(key, handle_websocket.unwrap())?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
key_request,
|
||||||
|
key_websocket,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) fn has_websocket_handler(&self) -> bool {
|
||||||
|
self.key_websocket.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) fn request_handler<'lua>(&self, lua: &'lua Lua) -> LuaResult<LuaFunction<'lua>> {
|
||||||
|
lua.named_registry_value(self.key_request)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(super) fn websocket_handler<'lua>(
|
||||||
|
&self,
|
||||||
|
lua: &'lua Lua,
|
||||||
|
) -> LuaResult<Option<LuaFunction<'lua>>> {
|
||||||
|
self.key_websocket
|
||||||
|
.map(|key| lua.named_registry_value(key))
|
||||||
|
.transpose()
|
||||||
|
}
|
||||||
|
}
|
105
src/lune/builtins/net/server/mod.rs
Normal file
105
src/lune/builtins/net/server/mod.rs
Normal file
|
@ -0,0 +1,105 @@
|
||||||
|
use std::{
|
||||||
|
net::SocketAddr,
|
||||||
|
rc::{Rc, Weak},
|
||||||
|
};
|
||||||
|
|
||||||
|
use hyper::server::conn::http1;
|
||||||
|
use hyper_util::rt::TokioIo;
|
||||||
|
use tokio::{net::TcpListener, pin};
|
||||||
|
|
||||||
|
use mlua::prelude::*;
|
||||||
|
use mlua_luau_scheduler::LuaSpawnExt;
|
||||||
|
|
||||||
|
use crate::lune::util::TableBuilder;
|
||||||
|
|
||||||
|
use super::config::ServeConfig;
|
||||||
|
|
||||||
|
mod keys;
|
||||||
|
mod request;
|
||||||
|
mod response;
|
||||||
|
mod service;
|
||||||
|
|
||||||
|
use keys::SvcKeys;
|
||||||
|
use service::Svc;
|
||||||
|
|
||||||
|
pub async fn serve<'lua>(
|
||||||
|
lua: &'lua Lua,
|
||||||
|
port: u16,
|
||||||
|
config: ServeConfig<'lua>,
|
||||||
|
) -> LuaResult<LuaTable<'lua>> {
|
||||||
|
let addr: SocketAddr = (config.address, port).into();
|
||||||
|
let listener = TcpListener::bind(addr).await?;
|
||||||
|
|
||||||
|
let (lua_svc, lua_inner) = {
|
||||||
|
let rc = lua
|
||||||
|
.app_data_ref::<Weak<Lua>>()
|
||||||
|
.expect("Missing weak lua ref")
|
||||||
|
.upgrade()
|
||||||
|
.expect("Lua was dropped unexpectedly");
|
||||||
|
(Rc::clone(&rc), rc)
|
||||||
|
};
|
||||||
|
|
||||||
|
let keys = SvcKeys::new(lua, config.handle_request, config.handle_web_socket)?;
|
||||||
|
let svc = Svc {
|
||||||
|
lua: lua_svc,
|
||||||
|
addr,
|
||||||
|
keys,
|
||||||
|
};
|
||||||
|
|
||||||
|
let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false);
|
||||||
|
lua.spawn_local(async move {
|
||||||
|
let mut shutdown_rx_outer = shutdown_rx.clone();
|
||||||
|
loop {
|
||||||
|
// Create futures for accepting new connections and shutting down
|
||||||
|
let fut_shutdown = shutdown_rx_outer.changed();
|
||||||
|
let fut_accept = async {
|
||||||
|
let stream = match listener.accept().await {
|
||||||
|
Err(_) => return,
|
||||||
|
Ok((s, _)) => s,
|
||||||
|
};
|
||||||
|
|
||||||
|
let io = TokioIo::new(stream);
|
||||||
|
let svc = svc.clone();
|
||||||
|
let mut shutdown_rx_inner = shutdown_rx.clone();
|
||||||
|
|
||||||
|
lua_inner.spawn_local(async move {
|
||||||
|
let conn = http1::Builder::new()
|
||||||
|
.keep_alive(true) // Web sockets need this
|
||||||
|
.serve_connection(io, svc)
|
||||||
|
.with_upgrades();
|
||||||
|
// NOTE: Because we need to use keep_alive for websockets, we need to
|
||||||
|
// also manually poll this future and handle the shutdown signal here
|
||||||
|
pin!(conn);
|
||||||
|
tokio::select! {
|
||||||
|
_ = conn.as_mut() => {}
|
||||||
|
_ = shutdown_rx_inner.changed() => {
|
||||||
|
conn.as_mut().graceful_shutdown();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
// Wait for either a new connection or a shutdown signal
|
||||||
|
tokio::select! {
|
||||||
|
_ = fut_accept => {}
|
||||||
|
res = fut_shutdown => {
|
||||||
|
// NOTE: We will only get a RecvError here if the serve handle is dropped,
|
||||||
|
// this means lua has garbage collected it and the user does not want
|
||||||
|
// to manually stop the server using the serve handle. Run forever.
|
||||||
|
if res.is_ok() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
TableBuilder::new(lua)?
|
||||||
|
.with_value("ip", addr.ip().to_string())?
|
||||||
|
.with_value("port", addr.port())?
|
||||||
|
.with_function("stop", move |lua, _: ()| match shutdown_tx.send(true) {
|
||||||
|
Ok(_) => Ok(()),
|
||||||
|
Err(_) => Err(LuaError::runtime("Server already stopped")),
|
||||||
|
})?
|
||||||
|
.build_readonly()
|
||||||
|
}
|
46
src/lune/builtins/net/server/request.rs
Normal file
46
src/lune/builtins/net/server/request.rs
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
use std::{collections::HashMap, net::SocketAddr};
|
||||||
|
|
||||||
|
use http::request::Parts;
|
||||||
|
|
||||||
|
use mlua::prelude::*;
|
||||||
|
|
||||||
|
pub(super) struct LuaRequest {
|
||||||
|
pub(super) _remote_addr: SocketAddr,
|
||||||
|
pub(super) head: Parts,
|
||||||
|
pub(super) body: Vec<u8>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LuaUserData for LuaRequest {
|
||||||
|
fn add_fields<'lua, F: LuaUserDataFields<'lua, Self>>(fields: &mut F) {
|
||||||
|
fields.add_field_method_get("method", |_, this| {
|
||||||
|
Ok(this.head.method.as_str().to_string())
|
||||||
|
});
|
||||||
|
|
||||||
|
fields.add_field_method_get("path", |_, this| Ok(this.head.uri.path().to_string()));
|
||||||
|
|
||||||
|
fields.add_field_method_get("query", |_, this| {
|
||||||
|
let query: HashMap<String, String> = this
|
||||||
|
.head
|
||||||
|
.uri
|
||||||
|
.query()
|
||||||
|
.unwrap_or_default()
|
||||||
|
.split('&')
|
||||||
|
.filter_map(|q| q.split_once('='))
|
||||||
|
.map(|(k, v)| (k.to_string(), v.to_string()))
|
||||||
|
.collect();
|
||||||
|
Ok(query)
|
||||||
|
});
|
||||||
|
|
||||||
|
fields.add_field_method_get("headers", |_, this| {
|
||||||
|
let headers: HashMap<String, Vec<u8>> = this
|
||||||
|
.head
|
||||||
|
.headers
|
||||||
|
.iter()
|
||||||
|
.map(|(k, v)| (k.as_str().to_string(), v.as_bytes().to_vec()))
|
||||||
|
.collect();
|
||||||
|
Ok(headers)
|
||||||
|
});
|
||||||
|
|
||||||
|
fields.add_field_method_get("body", |lua, this| lua.create_string(&this.body));
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,52 +1,55 @@
|
||||||
use std::collections::HashMap;
|
use std::str::FromStr;
|
||||||
|
|
||||||
|
use http_body_util::Full;
|
||||||
|
use hyper::{
|
||||||
|
body::Bytes,
|
||||||
|
header::{HeaderName, HeaderValue},
|
||||||
|
HeaderMap, Response,
|
||||||
|
};
|
||||||
|
|
||||||
use hyper::{Body, Response};
|
|
||||||
use mlua::prelude::*;
|
use mlua::prelude::*;
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
#[derive(Debug, Clone, Copy)]
|
||||||
pub enum NetServeResponseKind {
|
pub(super) enum LuaResponseKind {
|
||||||
PlainText,
|
PlainText,
|
||||||
Table,
|
Table,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug)]
|
pub(super) struct LuaResponse {
|
||||||
pub struct NetServeResponse {
|
pub(super) kind: LuaResponseKind,
|
||||||
kind: NetServeResponseKind,
|
pub(super) status: u16,
|
||||||
status: u16,
|
pub(super) headers: HeaderMap,
|
||||||
headers: HashMap<String, Vec<u8>>,
|
pub(super) body: Option<Vec<u8>>,
|
||||||
body: Option<Vec<u8>>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl NetServeResponse {
|
impl LuaResponse {
|
||||||
pub fn into_response(self) -> LuaResult<Response<Body>> {
|
pub(super) fn into_response(self) -> LuaResult<Response<Full<Bytes>>> {
|
||||||
Ok(match self.kind {
|
Ok(match self.kind {
|
||||||
NetServeResponseKind::PlainText => Response::builder()
|
LuaResponseKind::PlainText => Response::builder()
|
||||||
.status(200)
|
.status(200)
|
||||||
.header("Content-Type", "text/plain")
|
.header("Content-Type", "text/plain")
|
||||||
.body(Body::from(self.body.unwrap()))
|
.body(Full::new(Bytes::from(self.body.unwrap())))
|
||||||
.into_lua_err()?,
|
.into_lua_err()?,
|
||||||
NetServeResponseKind::Table => {
|
LuaResponseKind::Table => {
|
||||||
let mut response = Response::builder();
|
let mut response = Response::builder()
|
||||||
for (key, value) in self.headers {
|
|
||||||
response = response.header(&key, value);
|
|
||||||
}
|
|
||||||
response
|
|
||||||
.status(self.status)
|
.status(self.status)
|
||||||
.body(Body::from(self.body.unwrap_or_default()))
|
.body(Full::new(Bytes::from(self.body.unwrap_or_default())))
|
||||||
.into_lua_err()?
|
.into_lua_err()?;
|
||||||
|
response.headers_mut().extend(self.headers);
|
||||||
|
response
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'lua> FromLua<'lua> for NetServeResponse {
|
impl FromLua<'_> for LuaResponse {
|
||||||
fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult<Self> {
|
fn from_lua(value: LuaValue, _: &Lua) -> LuaResult<Self> {
|
||||||
match value {
|
match value {
|
||||||
// Plain strings from the handler are plaintext responses
|
// Plain strings from the handler are plaintext responses
|
||||||
LuaValue::String(s) => Ok(Self {
|
LuaValue::String(s) => Ok(Self {
|
||||||
kind: NetServeResponseKind::PlainText,
|
kind: LuaResponseKind::PlainText,
|
||||||
status: 200,
|
status: 200,
|
||||||
headers: HashMap::new(),
|
headers: HeaderMap::new(),
|
||||||
body: Some(s.as_bytes().to_vec()),
|
body: Some(s.as_bytes().to_vec()),
|
||||||
}),
|
}),
|
||||||
// Tables are more detailed responses with potential status, headers, body
|
// Tables are more detailed responses with potential status, headers, body
|
||||||
|
@ -55,18 +58,20 @@ impl<'lua> FromLua<'lua> for NetServeResponse {
|
||||||
let headers: Option<LuaTable> = t.get("headers")?;
|
let headers: Option<LuaTable> = t.get("headers")?;
|
||||||
let body: Option<LuaString> = t.get("body")?;
|
let body: Option<LuaString> = t.get("body")?;
|
||||||
|
|
||||||
let mut headers_map = HashMap::new();
|
let mut headers_map = HeaderMap::new();
|
||||||
if let Some(headers) = headers {
|
if let Some(headers) = headers {
|
||||||
for pair in headers.pairs::<String, LuaString>() {
|
for pair in headers.pairs::<String, LuaString>() {
|
||||||
let (h, v) = pair?;
|
let (h, v) = pair?;
|
||||||
headers_map.insert(h, v.as_bytes().to_vec());
|
let name = HeaderName::from_str(&h).into_lua_err()?;
|
||||||
|
let value = HeaderValue::from_bytes(v.as_bytes()).into_lua_err()?;
|
||||||
|
headers_map.insert(name, value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let body_bytes = body.map(|s| s.as_bytes().to_vec());
|
let body_bytes = body.map(|s| s.as_bytes().to_vec());
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
kind: NetServeResponseKind::Table,
|
kind: LuaResponseKind::Table,
|
||||||
status: status.unwrap_or(200),
|
status: status.unwrap_or(200),
|
||||||
headers: headers_map,
|
headers: headers_map,
|
||||||
body: body_bytes,
|
body: body_bytes,
|
81
src/lune/builtins/net/server/service.rs
Normal file
81
src/lune/builtins/net/server/service.rs
Normal file
|
@ -0,0 +1,81 @@
|
||||||
|
use std::{future::Future, net::SocketAddr, pin::Pin, rc::Rc};
|
||||||
|
|
||||||
|
use http_body_util::{BodyExt, Full};
|
||||||
|
use hyper::{
|
||||||
|
body::{Bytes, Incoming},
|
||||||
|
service::Service,
|
||||||
|
Request, Response,
|
||||||
|
};
|
||||||
|
use hyper_tungstenite::{is_upgrade_request, upgrade};
|
||||||
|
|
||||||
|
use mlua::prelude::*;
|
||||||
|
use mlua_luau_scheduler::{LuaSchedulerExt, LuaSpawnExt};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
super::websocket::NetWebSocket, keys::SvcKeys, request::LuaRequest, response::LuaResponse,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub(super) struct Svc {
|
||||||
|
pub(super) lua: Rc<Lua>,
|
||||||
|
pub(super) addr: SocketAddr,
|
||||||
|
pub(super) keys: SvcKeys,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Service<Request<Incoming>> for Svc {
|
||||||
|
type Response = Response<Full<Bytes>>;
|
||||||
|
type Error = LuaError;
|
||||||
|
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
|
||||||
|
|
||||||
|
fn call(&self, req: Request<Incoming>) -> Self::Future {
|
||||||
|
let lua = self.lua.clone();
|
||||||
|
let addr = self.addr;
|
||||||
|
let keys = self.keys;
|
||||||
|
|
||||||
|
if keys.has_websocket_handler() && is_upgrade_request(&req) {
|
||||||
|
Box::pin(async move {
|
||||||
|
let (res, sock) = upgrade(req, None).into_lua_err()?;
|
||||||
|
|
||||||
|
let lua_inner = lua.clone();
|
||||||
|
lua.spawn_local(async move {
|
||||||
|
let sock = sock.await.unwrap();
|
||||||
|
let lua_sock = NetWebSocket::new(sock);
|
||||||
|
let lua_tab = lua_sock.into_lua_table(&lua_inner).unwrap();
|
||||||
|
|
||||||
|
let handler_websocket: LuaFunction =
|
||||||
|
keys.websocket_handler(&lua_inner).unwrap().unwrap();
|
||||||
|
|
||||||
|
lua_inner
|
||||||
|
.push_thread_back(handler_websocket, lua_tab)
|
||||||
|
.unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(res)
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
let (head, body) = req.into_parts();
|
||||||
|
|
||||||
|
Box::pin(async move {
|
||||||
|
let handler_request: LuaFunction = keys.request_handler(&lua).unwrap();
|
||||||
|
|
||||||
|
let body = body.collect().await.into_lua_err()?;
|
||||||
|
let body = body.to_bytes().to_vec();
|
||||||
|
|
||||||
|
let lua_req = LuaRequest {
|
||||||
|
_remote_addr: addr,
|
||||||
|
head,
|
||||||
|
body,
|
||||||
|
};
|
||||||
|
|
||||||
|
let thread_id = lua.push_thread_back(handler_request, lua_req)?;
|
||||||
|
lua.track_thread(thread_id);
|
||||||
|
lua.wait_for_thread(thread_id).await;
|
||||||
|
let thread_res = lua
|
||||||
|
.get_thread_result(thread_id)
|
||||||
|
.expect("Missing handler thread result")?;
|
||||||
|
|
||||||
|
LuaResponse::from_lua_multi(thread_res, &lua)?.into_response()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,14 +1,20 @@
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use hyper::{
|
use hyper::header::{CONTENT_ENCODING, CONTENT_LENGTH};
|
||||||
header::{CONTENT_ENCODING, CONTENT_LENGTH},
|
use reqwest::header::HeaderMap;
|
||||||
HeaderMap,
|
|
||||||
};
|
|
||||||
|
|
||||||
use mlua::prelude::*;
|
use mlua::prelude::*;
|
||||||
|
|
||||||
use crate::lune::util::TableBuilder;
|
use crate::lune::util::TableBuilder;
|
||||||
|
|
||||||
|
pub fn create_user_agent_header() -> String {
|
||||||
|
let (github_owner, github_repo) = env!("CARGO_PKG_REPOSITORY")
|
||||||
|
.trim_start_matches("https://github.com/")
|
||||||
|
.split_once('/')
|
||||||
|
.unwrap();
|
||||||
|
format!("{github_owner}-{github_repo}-cli")
|
||||||
|
}
|
||||||
|
|
||||||
pub fn header_map_to_table(
|
pub fn header_map_to_table(
|
||||||
lua: &Lua,
|
lua: &Lua,
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
use std::sync::Arc;
|
use std::sync::{
|
||||||
|
atomic::{AtomicBool, AtomicU16, Ordering},
|
||||||
|
Arc,
|
||||||
|
};
|
||||||
|
|
||||||
use hyper::upgrade::Upgraded;
|
|
||||||
use mlua::prelude::*;
|
use mlua::prelude::*;
|
||||||
|
|
||||||
use futures_util::{
|
use futures_util::{
|
||||||
|
@ -9,7 +11,6 @@ use futures_util::{
|
||||||
};
|
};
|
||||||
use tokio::{
|
use tokio::{
|
||||||
io::{AsyncRead, AsyncWrite},
|
io::{AsyncRead, AsyncWrite},
|
||||||
net::TcpStream,
|
|
||||||
sync::Mutex as AsyncMutex,
|
sync::Mutex as AsyncMutex,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -20,25 +21,25 @@ use hyper_tungstenite::{
|
||||||
},
|
},
|
||||||
WebSocketStream,
|
WebSocketStream,
|
||||||
};
|
};
|
||||||
use tokio_tungstenite::MaybeTlsStream;
|
|
||||||
|
|
||||||
use crate::lune::util::TableBuilder;
|
use crate::lune::util::TableBuilder;
|
||||||
|
|
||||||
|
// Wrapper implementation for compatibility and changing colon syntax to dot syntax
|
||||||
const WEB_SOCKET_IMPL_LUA: &str = r#"
|
const WEB_SOCKET_IMPL_LUA: &str = r#"
|
||||||
return freeze(setmetatable({
|
return freeze(setmetatable({
|
||||||
close = function(...)
|
close = function(...)
|
||||||
return close(websocket, ...)
|
return websocket:close(...)
|
||||||
end,
|
end,
|
||||||
send = function(...)
|
send = function(...)
|
||||||
return send(websocket, ...)
|
return websocket:send(...)
|
||||||
end,
|
end,
|
||||||
next = function(...)
|
next = function(...)
|
||||||
return next(websocket, ...)
|
return websocket:next(...)
|
||||||
end,
|
end,
|
||||||
}, {
|
}, {
|
||||||
__index = function(self, key)
|
__index = function(self, key)
|
||||||
if key == "closeCode" then
|
if key == "closeCode" then
|
||||||
return close_code(websocket)
|
return websocket.closeCode
|
||||||
end
|
end
|
||||||
end,
|
end,
|
||||||
}))
|
}))
|
||||||
|
@ -46,7 +47,8 @@ return freeze(setmetatable({
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct NetWebSocket<T> {
|
pub struct NetWebSocket<T> {
|
||||||
close_code: Arc<AsyncMutex<Option<u16>>>,
|
close_code_exists: Arc<AtomicBool>,
|
||||||
|
close_code_value: Arc<AtomicU16>,
|
||||||
read_stream: Arc<AsyncMutex<SplitStream<WebSocketStream<T>>>>,
|
read_stream: Arc<AsyncMutex<SplitStream<WebSocketStream<T>>>>,
|
||||||
write_stream: Arc<AsyncMutex<SplitSink<WebSocketStream<T>, WsMessage>>>,
|
write_stream: Arc<AsyncMutex<SplitSink<WebSocketStream<T>, WsMessage>>>,
|
||||||
}
|
}
|
||||||
|
@ -54,7 +56,8 @@ pub struct NetWebSocket<T> {
|
||||||
impl<T> Clone for NetWebSocket<T> {
|
impl<T> Clone for NetWebSocket<T> {
|
||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
Self {
|
Self {
|
||||||
close_code: Arc::clone(&self.close_code),
|
close_code_exists: Arc::clone(&self.close_code_exists),
|
||||||
|
close_code_value: Arc::clone(&self.close_code_value),
|
||||||
read_stream: Arc::clone(&self.read_stream),
|
read_stream: Arc::clone(&self.read_stream),
|
||||||
write_stream: Arc::clone(&self.write_stream),
|
write_stream: Arc::clone(&self.write_stream),
|
||||||
}
|
}
|
||||||
|
@ -63,22 +66,78 @@ impl<T> Clone for NetWebSocket<T> {
|
||||||
|
|
||||||
impl<T> NetWebSocket<T>
|
impl<T> NetWebSocket<T>
|
||||||
where
|
where
|
||||||
T: AsyncRead + AsyncWrite + Unpin,
|
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||||
{
|
{
|
||||||
pub fn new(value: WebSocketStream<T>) -> Self {
|
pub fn new(value: WebSocketStream<T>) -> Self {
|
||||||
let (write, read) = value.split();
|
let (write, read) = value.split();
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
close_code: Arc::new(AsyncMutex::new(None)),
|
close_code_exists: Arc::new(AtomicBool::new(false)),
|
||||||
|
close_code_value: Arc::new(AtomicU16::new(0)),
|
||||||
read_stream: Arc::new(AsyncMutex::new(read)),
|
read_stream: Arc::new(AsyncMutex::new(read)),
|
||||||
write_stream: Arc::new(AsyncMutex::new(write)),
|
write_stream: Arc::new(AsyncMutex::new(write)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn into_lua_table_with_env<'lua>(
|
fn get_close_code(&self) -> Option<u16> {
|
||||||
lua: &'lua Lua,
|
if self.close_code_exists.load(Ordering::Relaxed) {
|
||||||
env: LuaTable<'lua>,
|
Some(self.close_code_value.load(Ordering::Relaxed))
|
||||||
) -> LuaResult<LuaTable<'lua>> {
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn set_close_code(&self, code: u16) {
|
||||||
|
self.close_code_exists.store(true, Ordering::Relaxed);
|
||||||
|
self.close_code_value.store(code, Ordering::Relaxed);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn send(&self, msg: WsMessage) -> LuaResult<()> {
|
||||||
|
let mut ws = self.write_stream.lock().await;
|
||||||
|
ws.send(msg).await.into_lua_err()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn next(&self) -> LuaResult<Option<WsMessage>> {
|
||||||
|
let mut ws = self.read_stream.lock().await;
|
||||||
|
ws.next().await.transpose().into_lua_err()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn close(&self, code: Option<u16>) -> LuaResult<()> {
|
||||||
|
if self.close_code_exists.load(Ordering::Relaxed) {
|
||||||
|
return Err(LuaError::runtime("Socket has already been closed"));
|
||||||
|
}
|
||||||
|
|
||||||
|
self.send(WsMessage::Close(Some(WsCloseFrame {
|
||||||
|
code: match code {
|
||||||
|
Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code),
|
||||||
|
Some(code) => {
|
||||||
|
return Err(LuaError::runtime(format!(
|
||||||
|
"Close code must be between 1000 and 4999, got {code}"
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
None => WsCloseCode::Normal,
|
||||||
|
},
|
||||||
|
reason: "".into(),
|
||||||
|
})))
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let mut ws = self.write_stream.lock().await;
|
||||||
|
ws.close().await.into_lua_err()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn into_lua_table(self, lua: &Lua) -> LuaResult<LuaTable> {
|
||||||
|
let setmetatable = lua.globals().get::<_, LuaFunction>("setmetatable")?;
|
||||||
|
let table_freeze = lua
|
||||||
|
.globals()
|
||||||
|
.get::<_, LuaTable>("table")?
|
||||||
|
.get::<_, LuaFunction>("freeze")?;
|
||||||
|
|
||||||
|
let env = TableBuilder::new(lua)?
|
||||||
|
.with_value("websocket", self.clone())?
|
||||||
|
.with_value("setmetatable", setmetatable)?
|
||||||
|
.with_value("freeze", table_freeze)?
|
||||||
|
.build_readonly()?;
|
||||||
|
|
||||||
lua.load(WEB_SOCKET_IMPL_LUA)
|
lua.load(WEB_SOCKET_IMPL_LUA)
|
||||||
.set_name("websocket")
|
.set_name("websocket")
|
||||||
.set_environment(env)
|
.set_environment(env)
|
||||||
|
@ -86,149 +145,46 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type NetWebSocketStreamClient = MaybeTlsStream<TcpStream>;
|
impl<T> LuaUserData for NetWebSocket<T>
|
||||||
impl NetWebSocket<NetWebSocketStreamClient> {
|
|
||||||
pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult<LuaTable> {
|
|
||||||
let setmetatable = lua.globals().get::<_, LuaFunction>("setmetatable")?;
|
|
||||||
let table_freeze = lua
|
|
||||||
.globals()
|
|
||||||
.get::<_, LuaTable>("table")?
|
|
||||||
.get::<_, LuaFunction>("freeze")?;
|
|
||||||
let socket_env = TableBuilder::new(lua)?
|
|
||||||
.with_value("websocket", self)?
|
|
||||||
.with_function("close_code", close_code::<NetWebSocketStreamClient>)?
|
|
||||||
.with_async_function("close", close::<NetWebSocketStreamClient>)?
|
|
||||||
.with_async_function("send", send::<NetWebSocketStreamClient>)?
|
|
||||||
.with_async_function("next", next::<NetWebSocketStreamClient>)?
|
|
||||||
.with_value("setmetatable", setmetatable)?
|
|
||||||
.with_value("freeze", table_freeze)?
|
|
||||||
.build_readonly()?;
|
|
||||||
Self::into_lua_table_with_env(lua, socket_env)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type NetWebSocketStreamServer = Upgraded;
|
|
||||||
impl NetWebSocket<NetWebSocketStreamServer> {
|
|
||||||
pub fn into_lua_table(self, lua: &'static Lua) -> LuaResult<LuaTable> {
|
|
||||||
let setmetatable = lua.globals().get::<_, LuaFunction>("setmetatable")?;
|
|
||||||
let table_freeze = lua
|
|
||||||
.globals()
|
|
||||||
.get::<_, LuaTable>("table")?
|
|
||||||
.get::<_, LuaFunction>("freeze")?;
|
|
||||||
let socket_env = TableBuilder::new(lua)?
|
|
||||||
.with_value("websocket", self)?
|
|
||||||
.with_function("close_code", close_code::<NetWebSocketStreamServer>)?
|
|
||||||
.with_async_function("close", close::<NetWebSocketStreamServer>)?
|
|
||||||
.with_async_function("send", send::<NetWebSocketStreamServer>)?
|
|
||||||
.with_async_function("next", next::<NetWebSocketStreamServer>)?
|
|
||||||
.with_value("setmetatable", setmetatable)?
|
|
||||||
.with_value("freeze", table_freeze)?
|
|
||||||
.build_readonly()?;
|
|
||||||
Self::into_lua_table_with_env(lua, socket_env)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T> LuaUserData for NetWebSocket<T> {}
|
|
||||||
|
|
||||||
fn close_code<'lua, T>(
|
|
||||||
_lua: &'lua Lua,
|
|
||||||
socket: LuaUserDataRef<'lua, NetWebSocket<T>>,
|
|
||||||
) -> LuaResult<LuaValue<'lua>>
|
|
||||||
where
|
where
|
||||||
T: AsyncRead + AsyncWrite + Unpin,
|
T: AsyncRead + AsyncWrite + Unpin + 'static,
|
||||||
{
|
{
|
||||||
Ok(
|
fn add_fields<'lua, F: LuaUserDataFields<'lua, Self>>(fields: &mut F) {
|
||||||
match *socket
|
fields.add_field_method_get("closeCode", |_, this| Ok(this.get_close_code()));
|
||||||
.close_code
|
}
|
||||||
.try_lock()
|
|
||||||
.expect("Failed to lock close code")
|
|
||||||
{
|
|
||||||
Some(code) => LuaValue::Number(code as f64),
|
|
||||||
None => LuaValue::Nil,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn close<'lua, T>(
|
fn add_methods<'lua, M: LuaUserDataMethods<'lua, Self>>(methods: &mut M) {
|
||||||
_lua: &'lua Lua,
|
methods.add_async_method("close", |lua, this, code: Option<u16>| async move {
|
||||||
(socket, code): (LuaUserDataRef<'lua, NetWebSocket<T>>, Option<u16>),
|
this.close(code).await
|
||||||
) -> LuaResult<()>
|
});
|
||||||
where
|
|
||||||
T: AsyncRead + AsyncWrite + Unpin,
|
|
||||||
{
|
|
||||||
let mut ws = socket.write_stream.lock().await;
|
|
||||||
|
|
||||||
ws.send(WsMessage::Close(Some(WsCloseFrame {
|
methods.add_async_method(
|
||||||
code: match code {
|
"send",
|
||||||
Some(code) if (1000..=4999).contains(&code) => WsCloseCode::from(code),
|
|_, this, (string, as_binary): (LuaString, Option<bool>)| async move {
|
||||||
Some(code) => {
|
this.send(if as_binary.unwrap_or_default() {
|
||||||
return Err(LuaError::RuntimeError(format!(
|
WsMessage::Binary(string.as_bytes().to_vec())
|
||||||
"Close code must be between 1000 and 4999, got {code}"
|
} else {
|
||||||
)))
|
let s = string.to_str().into_lua_err()?;
|
||||||
|
WsMessage::Text(s.to_string())
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
methods.add_async_method("next", |lua, this, _: ()| async move {
|
||||||
|
let msg = this.next().await?;
|
||||||
|
|
||||||
|
if let Some(WsMessage::Close(Some(frame))) = msg.as_ref() {
|
||||||
|
this.set_close_code(frame.code.into());
|
||||||
}
|
}
|
||||||
None => WsCloseCode::Normal,
|
|
||||||
},
|
|
||||||
reason: "".into(),
|
|
||||||
})))
|
|
||||||
.await
|
|
||||||
.into_lua_err()?;
|
|
||||||
|
|
||||||
let res = ws.close();
|
Ok(match msg {
|
||||||
res.await.into_lua_err()
|
Some(WsMessage::Binary(bin)) => LuaValue::String(lua.create_string(bin)?),
|
||||||
}
|
Some(WsMessage::Text(txt)) => LuaValue::String(lua.create_string(txt)?),
|
||||||
|
Some(WsMessage::Close(_)) | None => LuaValue::Nil,
|
||||||
async fn send<'lua, T>(
|
// Ignore ping/pong/frame messages, they are handled by tungstenite
|
||||||
_lua: &'lua Lua,
|
msg => unreachable!("Unhandled message: {:?}", msg),
|
||||||
(socket, string, as_binary): (
|
})
|
||||||
LuaUserDataRef<'lua, NetWebSocket<T>>,
|
});
|
||||||
LuaString<'lua>,
|
|
||||||
Option<bool>,
|
|
||||||
),
|
|
||||||
) -> LuaResult<()>
|
|
||||||
where
|
|
||||||
T: AsyncRead + AsyncWrite + Unpin,
|
|
||||||
{
|
|
||||||
let msg = if matches!(as_binary, Some(true)) {
|
|
||||||
WsMessage::Binary(string.as_bytes().to_vec())
|
|
||||||
} else {
|
|
||||||
let s = string.to_str().into_lua_err()?;
|
|
||||||
WsMessage::Text(s.to_string())
|
|
||||||
};
|
|
||||||
let mut ws = socket.write_stream.lock().await;
|
|
||||||
ws.send(msg).await.into_lua_err()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn next<'lua, T>(
|
|
||||||
lua: &'lua Lua,
|
|
||||||
socket: LuaUserDataRef<'lua, NetWebSocket<T>>,
|
|
||||||
) -> LuaResult<LuaValue<'lua>>
|
|
||||||
where
|
|
||||||
T: AsyncRead + AsyncWrite + Unpin,
|
|
||||||
{
|
|
||||||
let mut ws = socket.read_stream.lock().await;
|
|
||||||
let item = ws.next().await.transpose().into_lua_err();
|
|
||||||
let msg = match item {
|
|
||||||
Ok(Some(WsMessage::Close(msg))) => {
|
|
||||||
if let Some(msg) = &msg {
|
|
||||||
let mut code = socket.close_code.lock().await;
|
|
||||||
*code = Some(msg.code.into());
|
|
||||||
}
|
|
||||||
Ok(Some(WsMessage::Close(msg)))
|
|
||||||
}
|
|
||||||
val => val,
|
|
||||||
}?;
|
|
||||||
while let Some(msg) = &msg {
|
|
||||||
let msg_string_opt = match msg {
|
|
||||||
WsMessage::Binary(bin) => Some(lua.create_string(bin)?),
|
|
||||||
WsMessage::Text(txt) => Some(lua.create_string(txt)?),
|
|
||||||
// Stop waiting for next message if we get a close message
|
|
||||||
WsMessage::Close(_) => return Ok(LuaValue::Nil),
|
|
||||||
// Ignore ping/pong/frame messages, they are handled by tungstenite
|
|
||||||
_ => None,
|
|
||||||
};
|
|
||||||
if let Some(msg_string) = msg_string_opt {
|
|
||||||
return Ok(LuaValue::String(msg_string));
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
Ok(LuaValue::Nil)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,13 +5,11 @@ use std::{
|
||||||
};
|
};
|
||||||
|
|
||||||
use mlua::prelude::*;
|
use mlua::prelude::*;
|
||||||
|
use mlua_luau_scheduler::{Functions, LuaSpawnExt};
|
||||||
use os_str_bytes::RawOsString;
|
use os_str_bytes::RawOsString;
|
||||||
use tokio::io::AsyncWriteExt;
|
use tokio::io::AsyncWriteExt;
|
||||||
|
|
||||||
use crate::lune::{
|
use crate::lune::util::{paths::CWD, TableBuilder};
|
||||||
scheduler::Scheduler,
|
|
||||||
util::{paths::CWD, TableBuilder},
|
|
||||||
};
|
|
||||||
|
|
||||||
mod tee_writer;
|
mod tee_writer;
|
||||||
|
|
||||||
|
@ -21,12 +19,7 @@ use options::ProcessSpawnOptions;
|
||||||
mod wait_for_child;
|
mod wait_for_child;
|
||||||
use wait_for_child::{wait_for_child, WaitForChildResult};
|
use wait_for_child::{wait_for_child, WaitForChildResult};
|
||||||
|
|
||||||
const PROCESS_EXIT_IMPL_LUA: &str = r#"
|
pub fn create(lua: &Lua) -> LuaResult<LuaTable> {
|
||||||
exit(...)
|
|
||||||
yield()
|
|
||||||
"#;
|
|
||||||
|
|
||||||
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
|
|
||||||
let cwd_str = {
|
let cwd_str = {
|
||||||
let cwd_str = CWD.to_string_lossy().to_string();
|
let cwd_str = CWD.to_string_lossy().to_string();
|
||||||
if !cwd_str.ends_with(path::MAIN_SEPARATOR) {
|
if !cwd_str.ends_with(path::MAIN_SEPARATOR) {
|
||||||
|
@ -56,30 +49,9 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
|
||||||
.build_readonly()?,
|
.build_readonly()?,
|
||||||
)?
|
)?
|
||||||
.build_readonly()?;
|
.build_readonly()?;
|
||||||
// Create our process exit function, this is a bit involved since
|
// Create our process exit function, the scheduler crate provides this
|
||||||
// we have no way to yield from c / rust, we need to load a lua
|
let fns = Functions::new(lua)?;
|
||||||
// chunk that will set the exit code and yield for us instead
|
let process_exit = fns.exit;
|
||||||
let coroutine_yield = lua
|
|
||||||
.globals()
|
|
||||||
.get::<_, LuaTable>("coroutine")?
|
|
||||||
.get::<_, LuaFunction>("yield")?;
|
|
||||||
let set_scheduler_exit_code = lua.create_function(|lua, code: Option<u8>| {
|
|
||||||
let sched = lua
|
|
||||||
.app_data_ref::<&Scheduler>()
|
|
||||||
.expect("Lua struct is missing scheduler");
|
|
||||||
sched.set_exit_code(code.unwrap_or_default());
|
|
||||||
Ok(())
|
|
||||||
})?;
|
|
||||||
let process_exit = lua
|
|
||||||
.load(PROCESS_EXIT_IMPL_LUA)
|
|
||||||
.set_name("=process.exit")
|
|
||||||
.set_environment(
|
|
||||||
TableBuilder::new(lua)?
|
|
||||||
.with_value("yield", coroutine_yield)?
|
|
||||||
.with_value("exit", set_scheduler_exit_code)?
|
|
||||||
.build_readonly()?,
|
|
||||||
)
|
|
||||||
.into_function()?;
|
|
||||||
// Create the full process table
|
// Create the full process table
|
||||||
TableBuilder::new(lua)?
|
TableBuilder::new(lua)?
|
||||||
.with_value("os", os)?
|
.with_value("os", os)?
|
||||||
|
@ -165,22 +137,10 @@ async fn process_spawn(
|
||||||
lua: &Lua,
|
lua: &Lua,
|
||||||
(program, args, options): (String, Option<Vec<String>>, ProcessSpawnOptions),
|
(program, args, options): (String, Option<Vec<String>>, ProcessSpawnOptions),
|
||||||
) -> LuaResult<LuaTable> {
|
) -> LuaResult<LuaTable> {
|
||||||
/*
|
let res = lua
|
||||||
Spawn the new process in the background, letting the tokio
|
|
||||||
runtime place it on a different thread if possible / necessary
|
|
||||||
|
|
||||||
Note that we have to use our scheduler here, we can't
|
|
||||||
be using tokio::task::spawn directly because our lua
|
|
||||||
scheduler would not drive those futures to completion
|
|
||||||
*/
|
|
||||||
let sched = lua
|
|
||||||
.app_data_ref::<&Scheduler>()
|
|
||||||
.expect("Lua struct is missing scheduler");
|
|
||||||
|
|
||||||
let res = sched
|
|
||||||
.spawn(spawn_command(program, args, options))
|
.spawn(spawn_command(program, args, options))
|
||||||
.await
|
.await
|
||||||
.expect("Failed to receive result of spawned process")?;
|
.expect("Failed to receive result of spawned process");
|
||||||
|
|
||||||
/*
|
/*
|
||||||
NOTE: If an exit code was not given by the child process,
|
NOTE: If an exit code was not given by the child process,
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
use mlua::prelude::*;
|
use mlua::prelude::*;
|
||||||
|
use mlua_luau_scheduler::LuaSpawnExt;
|
||||||
use once_cell::sync::OnceCell;
|
use once_cell::sync::OnceCell;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -11,11 +12,9 @@ use crate::{
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
use tokio::task;
|
|
||||||
|
|
||||||
static REFLECTION_DATABASE: OnceCell<ReflectionDatabase> = OnceCell::new();
|
static REFLECTION_DATABASE: OnceCell<ReflectionDatabase> = OnceCell::new();
|
||||||
|
|
||||||
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
|
pub fn create(lua: &Lua) -> LuaResult<LuaTable> {
|
||||||
let mut roblox_constants = Vec::new();
|
let mut roblox_constants = Vec::new();
|
||||||
|
|
||||||
let roblox_module = roblox::module(lua)?;
|
let roblox_module = roblox::module(lua)?;
|
||||||
|
@ -41,12 +40,12 @@ async fn deserialize_place<'lua>(
|
||||||
contents: LuaString<'lua>,
|
contents: LuaString<'lua>,
|
||||||
) -> LuaResult<LuaValue<'lua>> {
|
) -> LuaResult<LuaValue<'lua>> {
|
||||||
let bytes = contents.as_bytes().to_vec();
|
let bytes = contents.as_bytes().to_vec();
|
||||||
let fut = task::spawn_blocking(move || {
|
let fut = lua.spawn_blocking(move || {
|
||||||
let doc = Document::from_bytes(bytes, DocumentKind::Place)?;
|
let doc = Document::from_bytes(bytes, DocumentKind::Place)?;
|
||||||
let data_model = doc.into_data_model_instance()?;
|
let data_model = doc.into_data_model_instance()?;
|
||||||
Ok::<_, DocumentError>(data_model)
|
Ok::<_, DocumentError>(data_model)
|
||||||
});
|
});
|
||||||
fut.await.into_lua_err()??.into_lua(lua)
|
fut.await.into_lua_err()?.into_lua(lua)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn deserialize_model<'lua>(
|
async fn deserialize_model<'lua>(
|
||||||
|
@ -54,12 +53,12 @@ async fn deserialize_model<'lua>(
|
||||||
contents: LuaString<'lua>,
|
contents: LuaString<'lua>,
|
||||||
) -> LuaResult<LuaValue<'lua>> {
|
) -> LuaResult<LuaValue<'lua>> {
|
||||||
let bytes = contents.as_bytes().to_vec();
|
let bytes = contents.as_bytes().to_vec();
|
||||||
let fut = task::spawn_blocking(move || {
|
let fut = lua.spawn_blocking(move || {
|
||||||
let doc = Document::from_bytes(bytes, DocumentKind::Model)?;
|
let doc = Document::from_bytes(bytes, DocumentKind::Model)?;
|
||||||
let instance_array = doc.into_instance_array()?;
|
let instance_array = doc.into_instance_array()?;
|
||||||
Ok::<_, DocumentError>(instance_array)
|
Ok::<_, DocumentError>(instance_array)
|
||||||
});
|
});
|
||||||
fut.await.into_lua_err()??.into_lua(lua)
|
fut.await.into_lua_err()?.into_lua(lua)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn serialize_place<'lua>(
|
async fn serialize_place<'lua>(
|
||||||
|
@ -67,7 +66,7 @@ async fn serialize_place<'lua>(
|
||||||
(data_model, as_xml): (LuaUserDataRef<'lua, Instance>, Option<bool>),
|
(data_model, as_xml): (LuaUserDataRef<'lua, Instance>, Option<bool>),
|
||||||
) -> LuaResult<LuaString<'lua>> {
|
) -> LuaResult<LuaString<'lua>> {
|
||||||
let data_model = (*data_model).clone();
|
let data_model = (*data_model).clone();
|
||||||
let fut = task::spawn_blocking(move || {
|
let fut = lua.spawn_blocking(move || {
|
||||||
let doc = Document::from_data_model_instance(data_model)?;
|
let doc = Document::from_data_model_instance(data_model)?;
|
||||||
let bytes = doc.to_bytes_with_format(match as_xml {
|
let bytes = doc.to_bytes_with_format(match as_xml {
|
||||||
Some(true) => DocumentFormat::Xml,
|
Some(true) => DocumentFormat::Xml,
|
||||||
|
@ -75,7 +74,7 @@ async fn serialize_place<'lua>(
|
||||||
})?;
|
})?;
|
||||||
Ok::<_, DocumentError>(bytes)
|
Ok::<_, DocumentError>(bytes)
|
||||||
});
|
});
|
||||||
let bytes = fut.await.into_lua_err()??;
|
let bytes = fut.await.into_lua_err()?;
|
||||||
lua.create_string(bytes)
|
lua.create_string(bytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -84,7 +83,7 @@ async fn serialize_model<'lua>(
|
||||||
(instances, as_xml): (Vec<LuaUserDataRef<'lua, Instance>>, Option<bool>),
|
(instances, as_xml): (Vec<LuaUserDataRef<'lua, Instance>>, Option<bool>),
|
||||||
) -> LuaResult<LuaString<'lua>> {
|
) -> LuaResult<LuaString<'lua>> {
|
||||||
let instances = instances.iter().map(|i| (*i).clone()).collect();
|
let instances = instances.iter().map(|i| (*i).clone()).collect();
|
||||||
let fut = task::spawn_blocking(move || {
|
let fut = lua.spawn_blocking(move || {
|
||||||
let doc = Document::from_instance_array(instances)?;
|
let doc = Document::from_instance_array(instances)?;
|
||||||
let bytes = doc.to_bytes_with_format(match as_xml {
|
let bytes = doc.to_bytes_with_format(match as_xml {
|
||||||
Some(true) => DocumentFormat::Xml,
|
Some(true) => DocumentFormat::Xml,
|
||||||
|
@ -92,7 +91,7 @@ async fn serialize_model<'lua>(
|
||||||
})?;
|
})?;
|
||||||
Ok::<_, DocumentError>(bytes)
|
Ok::<_, DocumentError>(bytes)
|
||||||
});
|
});
|
||||||
let bytes = fut.await.into_lua_err()??;
|
let bytes = fut.await.into_lua_err()?;
|
||||||
lua.create_string(bytes)
|
lua.create_string(bytes)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,7 @@
|
||||||
use lz4_flex::{compress_prepend_size, decompress_size_prepended};
|
|
||||||
use mlua::prelude::*;
|
use mlua::prelude::*;
|
||||||
use tokio::{
|
|
||||||
io::{copy, BufReader},
|
use lz4_flex::{compress_prepend_size, decompress_size_prepended};
|
||||||
task,
|
use tokio::io::{copy, BufReader};
|
||||||
};
|
|
||||||
|
|
||||||
use async_compression::{
|
use async_compression::{
|
||||||
tokio::bufread::{
|
tokio::bufread::{
|
||||||
|
@ -100,9 +98,7 @@ pub async fn compress<'lua>(
|
||||||
) -> LuaResult<Vec<u8>> {
|
) -> LuaResult<Vec<u8>> {
|
||||||
if let CompressDecompressFormat::LZ4 = format {
|
if let CompressDecompressFormat::LZ4 = format {
|
||||||
let source = source.as_ref().to_vec();
|
let source = source.as_ref().to_vec();
|
||||||
return task::spawn_blocking(move || compress_prepend_size(&source))
|
return Ok(blocking::unblock(move || compress_prepend_size(&source)).await);
|
||||||
.await
|
|
||||||
.into_lua_err();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut bytes = Vec::new();
|
let mut bytes = Vec::new();
|
||||||
|
@ -133,9 +129,8 @@ pub async fn decompress<'lua>(
|
||||||
) -> LuaResult<Vec<u8>> {
|
) -> LuaResult<Vec<u8>> {
|
||||||
if let CompressDecompressFormat::LZ4 = format {
|
if let CompressDecompressFormat::LZ4 = format {
|
||||||
let source = source.as_ref().to_vec();
|
let source = source.as_ref().to_vec();
|
||||||
return task::spawn_blocking(move || decompress_size_prepended(&source))
|
return blocking::unblock(move || decompress_size_prepended(&source))
|
||||||
.await
|
.await
|
||||||
.into_lua_err()?
|
|
||||||
.into_lua_err();
|
.into_lua_err();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,7 @@ use encode_decode::{EncodeDecodeConfig, EncodeDecodeFormat};
|
||||||
|
|
||||||
use crate::lune::util::TableBuilder;
|
use crate::lune::util::TableBuilder;
|
||||||
|
|
||||||
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
|
pub fn create(lua: &Lua) -> LuaResult<LuaTable> {
|
||||||
TableBuilder::new(lua)?
|
TableBuilder::new(lua)?
|
||||||
.with_function("encode", serde_encode)?
|
.with_function("encode", serde_encode)?
|
||||||
.with_function("decode", serde_decode)?
|
.with_function("decode", serde_decode)?
|
||||||
|
|
|
@ -1,10 +1,8 @@
|
||||||
use mlua::prelude::*;
|
use mlua::prelude::*;
|
||||||
|
|
||||||
use dialoguer::{theme::ColorfulTheme, Confirm, Input, MultiSelect, Select};
|
use dialoguer::{theme::ColorfulTheme, Confirm, Input, MultiSelect, Select};
|
||||||
use tokio::{
|
use mlua_luau_scheduler::LuaSpawnExt;
|
||||||
io::{self, AsyncWriteExt},
|
use tokio::io::{self, AsyncWriteExt};
|
||||||
task,
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::lune::util::{
|
use crate::lune::util::{
|
||||||
formatting::{
|
formatting::{
|
||||||
|
@ -16,7 +14,7 @@ use crate::lune::util::{
|
||||||
mod prompt;
|
mod prompt;
|
||||||
use prompt::{PromptKind, PromptOptions, PromptResult};
|
use prompt::{PromptKind, PromptOptions, PromptResult};
|
||||||
|
|
||||||
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable<'_>> {
|
pub fn create(lua: &Lua) -> LuaResult<LuaTable<'_>> {
|
||||||
TableBuilder::new(lua)?
|
TableBuilder::new(lua)?
|
||||||
.with_function("color", stdio_color)?
|
.with_function("color", stdio_color)?
|
||||||
.with_function("style", stdio_style)?
|
.with_function("style", stdio_style)?
|
||||||
|
@ -55,10 +53,10 @@ async fn stdio_ewrite(_: &Lua, s: LuaString<'_>) -> LuaResult<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn stdio_prompt(_: &Lua, options: PromptOptions) -> LuaResult<PromptResult> {
|
async fn stdio_prompt(lua: &Lua, options: PromptOptions) -> LuaResult<PromptResult> {
|
||||||
task::spawn_blocking(move || prompt(options))
|
lua.spawn_blocking(move || prompt(options))
|
||||||
.await
|
.await
|
||||||
.into_lua_err()?
|
.into_lua_err()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prompt(options: PromptOptions) -> LuaResult<PromptResult> {
|
fn prompt(options: PromptOptions) -> LuaResult<PromptResult> {
|
||||||
|
|
|
@ -2,120 +2,51 @@ use std::time::Duration;
|
||||||
|
|
||||||
use mlua::prelude::*;
|
use mlua::prelude::*;
|
||||||
|
|
||||||
|
use mlua_luau_scheduler::Functions;
|
||||||
use tokio::time::{self, Instant};
|
use tokio::time::{self, Instant};
|
||||||
|
|
||||||
use crate::lune::{scheduler::Scheduler, util::TableBuilder};
|
use crate::lune::util::TableBuilder;
|
||||||
|
|
||||||
mod tof;
|
const DELAY_IMPL_LUA: &str = r#"
|
||||||
use tof::LuaThreadOrFunction;
|
return defer(function(...)
|
||||||
|
wait(select(1, ...))
|
||||||
/*
|
spawn(select(2, ...))
|
||||||
The spawn function needs special treatment,
|
end, ...)
|
||||||
we need to yield right away to allow the
|
|
||||||
spawned task to run until first yield
|
|
||||||
|
|
||||||
1. Schedule this current thread at the front
|
|
||||||
2. Schedule given thread/function at the front,
|
|
||||||
the previous schedule now comes right after
|
|
||||||
3. Give control over to the scheduler, which will
|
|
||||||
resume the above tasks in order when its ready
|
|
||||||
*/
|
|
||||||
const SPAWN_IMPL_LUA: &str = r#"
|
|
||||||
push(currentThread())
|
|
||||||
local thread = push(...)
|
|
||||||
yield()
|
|
||||||
return thread
|
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
pub fn create(lua: &'static Lua) -> LuaResult<LuaTable<'_>> {
|
pub fn create(lua: &Lua) -> LuaResult<LuaTable<'_>> {
|
||||||
let coroutine_running = lua
|
let fns = Functions::new(lua)?;
|
||||||
.globals()
|
|
||||||
.get::<_, LuaTable>("coroutine")?
|
// Create wait & delay functions
|
||||||
.get::<_, LuaFunction>("running")?;
|
let task_wait = lua.create_async_function(wait)?;
|
||||||
let coroutine_yield = lua
|
let task_delay_env = TableBuilder::new(lua)?
|
||||||
.globals()
|
.with_value("select", lua.globals().get::<_, LuaFunction>("select")?)?
|
||||||
.get::<_, LuaTable>("coroutine")?
|
.with_value("spawn", fns.spawn.clone())?
|
||||||
.get::<_, LuaFunction>("yield")?;
|
.with_value("defer", fns.defer.clone())?
|
||||||
let push_front =
|
.with_value("wait", task_wait.clone())?
|
||||||
lua.create_function(|lua, (tof, args): (LuaThreadOrFunction, LuaMultiValue)| {
|
|
||||||
let thread = tof.into_thread(lua)?;
|
|
||||||
let sched = lua
|
|
||||||
.app_data_ref::<&Scheduler>()
|
|
||||||
.expect("Lua struct is missing scheduler");
|
|
||||||
sched.push_front(lua, thread.clone(), args)?;
|
|
||||||
Ok(thread)
|
|
||||||
})?;
|
|
||||||
let task_spawn_env = TableBuilder::new(lua)?
|
|
||||||
.with_value("currentThread", coroutine_running)?
|
|
||||||
.with_value("yield", coroutine_yield)?
|
|
||||||
.with_value("push", push_front)?
|
|
||||||
.build_readonly()?;
|
.build_readonly()?;
|
||||||
let task_spawn = lua
|
let task_delay = lua
|
||||||
.load(SPAWN_IMPL_LUA)
|
.load(DELAY_IMPL_LUA)
|
||||||
.set_name("task.spawn")
|
.set_name("task.delay")
|
||||||
.set_environment(task_spawn_env)
|
.set_environment(task_delay_env)
|
||||||
.into_function()?;
|
.into_function()?;
|
||||||
|
|
||||||
|
// Overwrite resume & wrap functions on the coroutine global
|
||||||
|
// with ones that are compatible with our scheduler
|
||||||
|
let co = lua.globals().get::<_, LuaTable>("coroutine")?;
|
||||||
|
co.set("resume", fns.resume.clone())?;
|
||||||
|
co.set("wrap", fns.wrap.clone())?;
|
||||||
|
|
||||||
TableBuilder::new(lua)?
|
TableBuilder::new(lua)?
|
||||||
.with_function("cancel", task_cancel)?
|
.with_value("cancel", fns.cancel)?
|
||||||
.with_function("defer", task_defer)?
|
.with_value("defer", fns.defer)?
|
||||||
.with_function("delay", task_delay)?
|
.with_value("delay", task_delay)?
|
||||||
.with_value("spawn", task_spawn)?
|
.with_value("spawn", fns.spawn)?
|
||||||
.with_async_function("wait", task_wait)?
|
.with_value("wait", task_wait)?
|
||||||
.build_readonly()
|
.build_readonly()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn task_cancel(lua: &Lua, thread: LuaThread) -> LuaResult<()> {
|
async fn wait(_: &Lua, secs: Option<f64>) -> LuaResult<f64> {
|
||||||
let close = lua
|
|
||||||
.globals()
|
|
||||||
.get::<_, LuaTable>("coroutine")?
|
|
||||||
.get::<_, LuaFunction>("close")?;
|
|
||||||
match close.call(thread) {
|
|
||||||
Err(LuaError::CoroutineInactive) => Ok(()),
|
|
||||||
Err(e) => Err(e),
|
|
||||||
Ok(()) => Ok(()),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn task_defer<'lua>(
|
|
||||||
lua: &'lua Lua,
|
|
||||||
(tof, args): (LuaThreadOrFunction<'lua>, LuaMultiValue<'_>),
|
|
||||||
) -> LuaResult<LuaThread<'lua>> {
|
|
||||||
let thread = tof.into_thread(lua)?;
|
|
||||||
let sched = lua
|
|
||||||
.app_data_ref::<&Scheduler>()
|
|
||||||
.expect("Lua struct is missing scheduler");
|
|
||||||
sched.push_back(lua, thread.clone(), args)?;
|
|
||||||
Ok(thread)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FIXME: `self` escapes outside of method because we are borrowing `tof` and
|
|
||||||
// `args` when we call `schedule_future_thread` in the lua function body below
|
|
||||||
// For now we solve this by using the 'static lifetime bound in the impl
|
|
||||||
fn task_delay<'lua>(
|
|
||||||
lua: &'lua Lua,
|
|
||||||
(secs, tof, args): (f64, LuaThreadOrFunction<'lua>, LuaMultiValue<'lua>),
|
|
||||||
) -> LuaResult<LuaThread<'lua>>
|
|
||||||
where
|
|
||||||
'lua: 'static,
|
|
||||||
{
|
|
||||||
let thread = tof.into_thread(lua)?;
|
|
||||||
let sched = lua
|
|
||||||
.app_data_ref::<&Scheduler>()
|
|
||||||
.expect("Lua struct is missing scheduler");
|
|
||||||
|
|
||||||
let thread2 = thread.clone();
|
|
||||||
sched.spawn_thread(lua, thread.clone(), async move {
|
|
||||||
let duration = Duration::from_secs_f64(secs);
|
|
||||||
time::sleep(duration).await;
|
|
||||||
sched.push_back(lua, thread2, args)?;
|
|
||||||
Ok(())
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(thread)
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn task_wait(_: &Lua, secs: Option<f64>) -> LuaResult<f64> {
|
|
||||||
let duration = Duration::from_secs_f64(secs.unwrap_or_default());
|
let duration = Duration::from_secs_f64(secs.unwrap_or_default());
|
||||||
|
|
||||||
let before = Instant::now();
|
let before = Instant::now();
|
||||||
|
|
|
@ -1,30 +0,0 @@
|
||||||
use mlua::prelude::*;
|
|
||||||
|
|
||||||
#[derive(Clone)]
|
|
||||||
pub(super) enum LuaThreadOrFunction<'lua> {
|
|
||||||
Thread(LuaThread<'lua>),
|
|
||||||
Function(LuaFunction<'lua>),
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'lua> LuaThreadOrFunction<'lua> {
|
|
||||||
pub(super) fn into_thread(self, lua: &'lua Lua) -> LuaResult<LuaThread<'lua>> {
|
|
||||||
match self {
|
|
||||||
Self::Thread(t) => Ok(t),
|
|
||||||
Self::Function(f) => lua.create_thread(f),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'lua> FromLua<'lua> for LuaThreadOrFunction<'lua> {
|
|
||||||
fn from_lua(value: LuaValue<'lua>, _: &'lua Lua) -> LuaResult<Self> {
|
|
||||||
match value {
|
|
||||||
LuaValue::Thread(t) => Ok(Self::Thread(t)),
|
|
||||||
LuaValue::Function(f) => Ok(Self::Function(f)),
|
|
||||||
value => Err(LuaError::FromLuaConversionError {
|
|
||||||
from: value.type_name(),
|
|
||||||
to: "LuaThreadOrFunction",
|
|
||||||
message: Some("Expected thread or function".to_string()),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -8,7 +8,7 @@ mod require;
|
||||||
mod version;
|
mod version;
|
||||||
mod warn;
|
mod warn;
|
||||||
|
|
||||||
pub fn inject_all(lua: &'static Lua) -> LuaResult<()> {
|
pub fn inject_all(lua: &Lua) -> LuaResult<()> {
|
||||||
let all = TableBuilder::new(lua)?
|
let all = TableBuilder::new(lua)?
|
||||||
.with_value("_G", g_table::create(lua)?)?
|
.with_value("_G", g_table::create(lua)?)?
|
||||||
.with_value("_VERSION", version::create(lua)?)?
|
.with_value("_VERSION", version::create(lua)?)?
|
||||||
|
|
|
@ -9,7 +9,8 @@ use crate::lune::util::{
|
||||||
use super::context::*;
|
use super::context::*;
|
||||||
|
|
||||||
pub(super) async fn require<'lua, 'ctx>(
|
pub(super) async fn require<'lua, 'ctx>(
|
||||||
ctx: &'ctx RequireContext<'lua>,
|
lua: &'lua Lua,
|
||||||
|
ctx: &'ctx RequireContext,
|
||||||
source: &str,
|
source: &str,
|
||||||
alias: &str,
|
alias: &str,
|
||||||
path: &str,
|
path: &str,
|
||||||
|
@ -71,5 +72,5 @@ where
|
||||||
LuaError::runtime(format!("failed to find relative path for alias '{alias}'"))
|
LuaError::runtime(format!("failed to find relative path for alias '{alias}'"))
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
super::path::require_abs_rel(ctx, abs_path, rel_path).await
|
super::path::require_abs_rel(lua, ctx, abs_path, rel_path).await
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,12 +3,12 @@ use mlua::prelude::*;
|
||||||
use super::context::*;
|
use super::context::*;
|
||||||
|
|
||||||
pub(super) async fn require<'lua, 'ctx>(
|
pub(super) async fn require<'lua, 'ctx>(
|
||||||
ctx: &'ctx RequireContext<'lua>,
|
lua: &'lua Lua,
|
||||||
|
ctx: &'ctx RequireContext,
|
||||||
name: &str,
|
name: &str,
|
||||||
) -> LuaResult<LuaMultiValue<'lua>>
|
) -> LuaResult<LuaMultiValue<'lua>>
|
||||||
where
|
where
|
||||||
'lua: 'ctx,
|
'lua: 'ctx,
|
||||||
'lua: 'static, // FIXME: Remove static lifetime bound here when builtin libraries no longer need it
|
|
||||||
{
|
{
|
||||||
ctx.load_builtin(name)
|
ctx.load_builtin(lua, name)
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ use std::{
|
||||||
};
|
};
|
||||||
|
|
||||||
use mlua::prelude::*;
|
use mlua::prelude::*;
|
||||||
|
use mlua_luau_scheduler::LuaSchedulerExt;
|
||||||
use tokio::{
|
use tokio::{
|
||||||
fs,
|
fs,
|
||||||
sync::{
|
sync::{
|
||||||
|
@ -13,11 +14,7 @@ use tokio::{
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::lune::{
|
use crate::lune::{builtins::LuneBuiltin, util::paths::CWD};
|
||||||
builtins::LuneBuiltin,
|
|
||||||
scheduler::{IntoLuaThread, Scheduler},
|
|
||||||
util::paths::CWD,
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
Context containing cached results for all `require` operations.
|
Context containing cached results for all `require` operations.
|
||||||
|
@ -26,14 +23,13 @@ use crate::lune::{
|
||||||
path will first be transformed into an absolute path.
|
path will first be transformed into an absolute path.
|
||||||
*/
|
*/
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub(super) struct RequireContext<'lua> {
|
pub(super) struct RequireContext {
|
||||||
lua: &'lua Lua,
|
|
||||||
cache_builtins: Arc<AsyncMutex<HashMap<LuneBuiltin, LuaResult<LuaRegistryKey>>>>,
|
cache_builtins: Arc<AsyncMutex<HashMap<LuneBuiltin, LuaResult<LuaRegistryKey>>>>,
|
||||||
cache_results: Arc<AsyncMutex<HashMap<PathBuf, LuaResult<LuaRegistryKey>>>>,
|
cache_results: Arc<AsyncMutex<HashMap<PathBuf, LuaResult<LuaRegistryKey>>>>,
|
||||||
cache_pending: Arc<AsyncMutex<HashMap<PathBuf, Sender<()>>>>,
|
cache_pending: Arc<AsyncMutex<HashMap<PathBuf, Sender<()>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'lua> RequireContext<'lua> {
|
impl RequireContext {
|
||||||
/**
|
/**
|
||||||
Creates a new require context for the given [`Lua`] struct.
|
Creates a new require context for the given [`Lua`] struct.
|
||||||
|
|
||||||
|
@ -41,9 +37,8 @@ impl<'lua> RequireContext<'lua> {
|
||||||
context should be created per [`Lua`] struct, creating more
|
context should be created per [`Lua`] struct, creating more
|
||||||
than one context may lead to undefined require-behavior.
|
than one context may lead to undefined require-behavior.
|
||||||
*/
|
*/
|
||||||
pub fn new(lua: &'lua Lua) -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
lua,
|
|
||||||
cache_builtins: Arc::new(AsyncMutex::new(HashMap::new())),
|
cache_builtins: Arc::new(AsyncMutex::new(HashMap::new())),
|
||||||
cache_results: Arc::new(AsyncMutex::new(HashMap::new())),
|
cache_results: Arc::new(AsyncMutex::new(HashMap::new())),
|
||||||
cache_pending: Arc::new(AsyncMutex::new(HashMap::new())),
|
cache_pending: Arc::new(AsyncMutex::new(HashMap::new())),
|
||||||
|
@ -107,7 +102,11 @@ impl<'lua> RequireContext<'lua> {
|
||||||
|
|
||||||
Will panic if the path has not been cached, use [`is_cached`] first.
|
Will panic if the path has not been cached, use [`is_cached`] first.
|
||||||
*/
|
*/
|
||||||
pub fn get_from_cache(&self, abs_path: impl AsRef<Path>) -> LuaResult<LuaMultiValue<'lua>> {
|
pub fn get_from_cache<'lua>(
|
||||||
|
&self,
|
||||||
|
lua: &'lua Lua,
|
||||||
|
abs_path: impl AsRef<Path>,
|
||||||
|
) -> LuaResult<LuaMultiValue<'lua>> {
|
||||||
let results = self
|
let results = self
|
||||||
.cache_results
|
.cache_results
|
||||||
.try_lock()
|
.try_lock()
|
||||||
|
@ -119,8 +118,7 @@ impl<'lua> RequireContext<'lua> {
|
||||||
match cached {
|
match cached {
|
||||||
Err(e) => Err(e.clone()),
|
Err(e) => Err(e.clone()),
|
||||||
Ok(k) => {
|
Ok(k) => {
|
||||||
let multi_vec = self
|
let multi_vec = lua
|
||||||
.lua
|
|
||||||
.registry_value::<Vec<LuaValue>>(k)
|
.registry_value::<Vec<LuaValue>>(k)
|
||||||
.expect("Missing require result in lua registry");
|
.expect("Missing require result in lua registry");
|
||||||
Ok(LuaMultiValue::from_vec(multi_vec))
|
Ok(LuaMultiValue::from_vec(multi_vec))
|
||||||
|
@ -133,8 +131,9 @@ impl<'lua> RequireContext<'lua> {
|
||||||
|
|
||||||
Will panic if the path has not been cached, use [`is_cached`] first.
|
Will panic if the path has not been cached, use [`is_cached`] first.
|
||||||
*/
|
*/
|
||||||
pub async fn wait_for_cache(
|
pub async fn wait_for_cache<'lua>(
|
||||||
&self,
|
&self,
|
||||||
|
lua: &'lua Lua,
|
||||||
abs_path: impl AsRef<Path>,
|
abs_path: impl AsRef<Path>,
|
||||||
) -> LuaResult<LuaMultiValue<'lua>> {
|
) -> LuaResult<LuaMultiValue<'lua>> {
|
||||||
let mut thread_recv = {
|
let mut thread_recv = {
|
||||||
|
@ -150,43 +149,37 @@ impl<'lua> RequireContext<'lua> {
|
||||||
|
|
||||||
thread_recv.recv().await.into_lua_err()?;
|
thread_recv.recv().await.into_lua_err()?;
|
||||||
|
|
||||||
self.get_from_cache(abs_path.as_ref())
|
self.get_from_cache(lua, abs_path.as_ref())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn load(
|
async fn load<'lua>(
|
||||||
&self,
|
&self,
|
||||||
|
lua: &'lua Lua,
|
||||||
abs_path: impl AsRef<Path>,
|
abs_path: impl AsRef<Path>,
|
||||||
rel_path: impl AsRef<Path>,
|
rel_path: impl AsRef<Path>,
|
||||||
) -> LuaResult<LuaRegistryKey> {
|
) -> LuaResult<LuaRegistryKey> {
|
||||||
let abs_path = abs_path.as_ref();
|
let abs_path = abs_path.as_ref();
|
||||||
let rel_path = rel_path.as_ref();
|
let rel_path = rel_path.as_ref();
|
||||||
|
|
||||||
let sched = self
|
|
||||||
.lua
|
|
||||||
.app_data_ref::<&Scheduler>()
|
|
||||||
.expect("Lua struct is missing scheduler");
|
|
||||||
|
|
||||||
// Read the file at the given path, try to parse and
|
// Read the file at the given path, try to parse and
|
||||||
// load it into a new lua thread that we can schedule
|
// load it into a new lua thread that we can schedule
|
||||||
let file_contents = fs::read(&abs_path).await?;
|
let file_contents = fs::read(&abs_path).await?;
|
||||||
let file_thread = self
|
let file_thread = lua
|
||||||
.lua
|
|
||||||
.load(file_contents)
|
.load(file_contents)
|
||||||
.set_name(rel_path.to_string_lossy().to_string())
|
.set_name(rel_path.to_string_lossy().to_string());
|
||||||
.into_function()?
|
|
||||||
.into_lua_thread(self.lua)?;
|
|
||||||
|
|
||||||
// Schedule the thread to run, wait for it to finish running
|
// Schedule the thread to run, wait for it to finish running
|
||||||
let thread_id = sched.push_back(self.lua, file_thread, ())?;
|
let thread_id = lua.push_thread_back(file_thread, ())?;
|
||||||
let thread_res = sched.wait_for_thread(self.lua, thread_id).await;
|
lua.track_thread(thread_id);
|
||||||
|
lua.wait_for_thread(thread_id).await;
|
||||||
|
let thread_res = lua.get_thread_result(thread_id).unwrap();
|
||||||
|
|
||||||
// Return the result of the thread, storing any lua value(s) in the registry
|
// Return the result of the thread, storing any lua value(s) in the registry
|
||||||
match thread_res {
|
match thread_res {
|
||||||
Err(e) => Err(e),
|
Err(e) => Err(e),
|
||||||
Ok(v) => {
|
Ok(v) => {
|
||||||
let multi_vec = v.into_vec();
|
let multi_vec = v.into_vec();
|
||||||
let multi_key = self
|
let multi_key = lua
|
||||||
.lua
|
|
||||||
.create_registry_value(multi_vec)
|
.create_registry_value(multi_vec)
|
||||||
.expect("Failed to store require result in registry - out of memory");
|
.expect("Failed to store require result in registry - out of memory");
|
||||||
Ok(multi_key)
|
Ok(multi_key)
|
||||||
|
@ -197,8 +190,9 @@ impl<'lua> RequireContext<'lua> {
|
||||||
/**
|
/**
|
||||||
Loads (requires) the file at the given path.
|
Loads (requires) the file at the given path.
|
||||||
*/
|
*/
|
||||||
pub async fn load_with_caching(
|
pub async fn load_with_caching<'lua>(
|
||||||
&self,
|
&self,
|
||||||
|
lua: &'lua Lua,
|
||||||
abs_path: impl AsRef<Path>,
|
abs_path: impl AsRef<Path>,
|
||||||
rel_path: impl AsRef<Path>,
|
rel_path: impl AsRef<Path>,
|
||||||
) -> LuaResult<LuaMultiValue<'lua>> {
|
) -> LuaResult<LuaMultiValue<'lua>> {
|
||||||
|
@ -213,12 +207,11 @@ impl<'lua> RequireContext<'lua> {
|
||||||
.insert(abs_path.to_path_buf(), broadcast_tx);
|
.insert(abs_path.to_path_buf(), broadcast_tx);
|
||||||
|
|
||||||
// Try to load at this abs path
|
// Try to load at this abs path
|
||||||
let load_res = self.load(abs_path, rel_path).await;
|
let load_res = self.load(lua, abs_path, rel_path).await;
|
||||||
let load_val = match &load_res {
|
let load_val = match &load_res {
|
||||||
Err(e) => Err(e.clone()),
|
Err(e) => Err(e.clone()),
|
||||||
Ok(k) => {
|
Ok(k) => {
|
||||||
let multi_vec = self
|
let multi_vec = lua
|
||||||
.lua
|
|
||||||
.registry_value::<Vec<LuaValue>>(k)
|
.registry_value::<Vec<LuaValue>>(k)
|
||||||
.expect("Failed to fetch require result from registry");
|
.expect("Failed to fetch require result from registry");
|
||||||
Ok(LuaMultiValue::from_vec(multi_vec))
|
Ok(LuaMultiValue::from_vec(multi_vec))
|
||||||
|
@ -250,10 +243,11 @@ impl<'lua> RequireContext<'lua> {
|
||||||
/**
|
/**
|
||||||
Loads (requires) the builtin with the given name.
|
Loads (requires) the builtin with the given name.
|
||||||
*/
|
*/
|
||||||
pub fn load_builtin(&self, name: impl AsRef<str>) -> LuaResult<LuaMultiValue<'lua>>
|
pub fn load_builtin<'lua>(
|
||||||
where
|
&self,
|
||||||
'lua: 'static, // FIXME: Remove static lifetime bound here when builtin libraries no longer need it
|
lua: &'lua Lua,
|
||||||
{
|
name: impl AsRef<str>,
|
||||||
|
) -> LuaResult<LuaMultiValue<'lua>> {
|
||||||
let builtin: LuneBuiltin = match name.as_ref().parse() {
|
let builtin: LuneBuiltin = match name.as_ref().parse() {
|
||||||
Err(e) => return Err(LuaError::runtime(e)),
|
Err(e) => return Err(LuaError::runtime(e)),
|
||||||
Ok(b) => b,
|
Ok(b) => b,
|
||||||
|
@ -268,8 +262,7 @@ impl<'lua> RequireContext<'lua> {
|
||||||
return match res {
|
return match res {
|
||||||
Err(e) => return Err(e.clone()),
|
Err(e) => return Err(e.clone()),
|
||||||
Ok(key) => {
|
Ok(key) => {
|
||||||
let multi_vec = self
|
let multi_vec = lua
|
||||||
.lua
|
|
||||||
.registry_value::<Vec<LuaValue>>(key)
|
.registry_value::<Vec<LuaValue>>(key)
|
||||||
.expect("Missing builtin result in lua registry");
|
.expect("Missing builtin result in lua registry");
|
||||||
Ok(LuaMultiValue::from_vec(multi_vec))
|
Ok(LuaMultiValue::from_vec(multi_vec))
|
||||||
|
@ -277,7 +270,7 @@ impl<'lua> RequireContext<'lua> {
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
|
|
||||||
let result = builtin.create(self.lua);
|
let result = builtin.create(lua);
|
||||||
|
|
||||||
cache.insert(
|
cache.insert(
|
||||||
builtin,
|
builtin,
|
||||||
|
@ -285,8 +278,7 @@ impl<'lua> RequireContext<'lua> {
|
||||||
Err(e) => Err(e),
|
Err(e) => Err(e),
|
||||||
Ok(multi) => {
|
Ok(multi) => {
|
||||||
let multi_vec = multi.into_vec();
|
let multi_vec = multi.into_vec();
|
||||||
let multi_key = self
|
let multi_key = lua
|
||||||
.lua
|
|
||||||
.create_registry_value(multi_vec)
|
.create_registry_value(multi_vec)
|
||||||
.expect("Failed to store require result in registry - out of memory");
|
.expect("Failed to store require result in registry - out of memory");
|
||||||
Ok(multi_key)
|
Ok(multi_key)
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use mlua::prelude::*;
|
use mlua::prelude::*;
|
||||||
|
|
||||||
use crate::lune::{scheduler::LuaSchedulerExt, util::TableBuilder};
|
use crate::lune::util::TableBuilder;
|
||||||
|
|
||||||
mod context;
|
mod context;
|
||||||
use context::RequireContext;
|
use context::RequireContext;
|
||||||
|
@ -13,8 +13,8 @@ const REQUIRE_IMPL: &str = r#"
|
||||||
return require(source(), ...)
|
return require(source(), ...)
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
pub fn create(lua: &'static Lua) -> LuaResult<impl IntoLua<'_>> {
|
pub fn create(lua: &Lua) -> LuaResult<impl IntoLua<'_>> {
|
||||||
lua.set_app_data(RequireContext::new(lua));
|
lua.set_app_data(RequireContext::new());
|
||||||
|
|
||||||
/*
|
/*
|
||||||
Require implementation needs a few workarounds:
|
Require implementation needs a few workarounds:
|
||||||
|
@ -62,10 +62,7 @@ pub fn create(lua: &'static Lua) -> LuaResult<impl IntoLua<'_>> {
|
||||||
async fn require<'lua>(
|
async fn require<'lua>(
|
||||||
lua: &'lua Lua,
|
lua: &'lua Lua,
|
||||||
(source, path): (LuaString<'lua>, LuaString<'lua>),
|
(source, path): (LuaString<'lua>, LuaString<'lua>),
|
||||||
) -> LuaResult<LuaMultiValue<'lua>>
|
) -> LuaResult<LuaMultiValue<'lua>> {
|
||||||
where
|
|
||||||
'lua: 'static, // FIXME: Remove static lifetime bound here when builtin libraries no longer need it
|
|
||||||
{
|
|
||||||
let source = source
|
let source = source
|
||||||
.to_str()
|
.to_str()
|
||||||
.into_lua_err()
|
.into_lua_err()
|
||||||
|
@ -86,13 +83,13 @@ where
|
||||||
.strip_prefix("@lune/")
|
.strip_prefix("@lune/")
|
||||||
.map(|name| name.to_ascii_lowercase())
|
.map(|name| name.to_ascii_lowercase())
|
||||||
{
|
{
|
||||||
builtin::require(&context, &builtin_name).await
|
builtin::require(lua, &context, &builtin_name).await
|
||||||
} else if let Some(aliased_path) = path.strip_prefix('@') {
|
} else if let Some(aliased_path) = path.strip_prefix('@') {
|
||||||
let (alias, path) = aliased_path.split_once('/').ok_or(LuaError::runtime(
|
let (alias, path) = aliased_path.split_once('/').ok_or(LuaError::runtime(
|
||||||
"Require with custom alias must contain '/' delimiter",
|
"Require with custom alias must contain '/' delimiter",
|
||||||
))?;
|
))?;
|
||||||
alias::require(&context, &source, alias, path).await
|
alias::require(lua, &context, &source, alias, path).await
|
||||||
} else {
|
} else {
|
||||||
path::require(&context, &source, &path).await
|
path::require(lua, &context, &source, &path).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,7 +5,8 @@ use mlua::prelude::*;
|
||||||
use super::context::*;
|
use super::context::*;
|
||||||
|
|
||||||
pub(super) async fn require<'lua, 'ctx>(
|
pub(super) async fn require<'lua, 'ctx>(
|
||||||
ctx: &'ctx RequireContext<'lua>,
|
lua: &'lua Lua,
|
||||||
|
ctx: &'ctx RequireContext,
|
||||||
source: &str,
|
source: &str,
|
||||||
path: &str,
|
path: &str,
|
||||||
) -> LuaResult<LuaMultiValue<'lua>>
|
) -> LuaResult<LuaMultiValue<'lua>>
|
||||||
|
@ -13,11 +14,12 @@ where
|
||||||
'lua: 'ctx,
|
'lua: 'ctx,
|
||||||
{
|
{
|
||||||
let (abs_path, rel_path) = ctx.resolve_paths(source, path)?;
|
let (abs_path, rel_path) = ctx.resolve_paths(source, path)?;
|
||||||
require_abs_rel(ctx, abs_path, rel_path).await
|
require_abs_rel(lua, ctx, abs_path, rel_path).await
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(super) async fn require_abs_rel<'lua, 'ctx>(
|
pub(super) async fn require_abs_rel<'lua, 'ctx>(
|
||||||
ctx: &'ctx RequireContext<'lua>,
|
lua: &'lua Lua,
|
||||||
|
ctx: &'ctx RequireContext,
|
||||||
abs_path: PathBuf, // Absolute to filesystem
|
abs_path: PathBuf, // Absolute to filesystem
|
||||||
rel_path: PathBuf, // Relative to CWD (for displaying)
|
rel_path: PathBuf, // Relative to CWD (for displaying)
|
||||||
) -> LuaResult<LuaMultiValue<'lua>>
|
) -> LuaResult<LuaMultiValue<'lua>>
|
||||||
|
@ -25,7 +27,7 @@ where
|
||||||
'lua: 'ctx,
|
'lua: 'ctx,
|
||||||
{
|
{
|
||||||
// 1. Try to require the exact path
|
// 1. Try to require the exact path
|
||||||
if let Ok(res) = require_inner(ctx, &abs_path, &rel_path).await {
|
if let Ok(res) = require_inner(lua, ctx, &abs_path, &rel_path).await {
|
||||||
return Ok(res);
|
return Ok(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -34,7 +36,7 @@ where
|
||||||
append_extension(&abs_path, "luau"),
|
append_extension(&abs_path, "luau"),
|
||||||
append_extension(&rel_path, "luau"),
|
append_extension(&rel_path, "luau"),
|
||||||
);
|
);
|
||||||
if let Ok(res) = require_inner(ctx, &luau_abs_path, &luau_rel_path).await {
|
if let Ok(res) = require_inner(lua, ctx, &luau_abs_path, &luau_rel_path).await {
|
||||||
return Ok(res);
|
return Ok(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -43,7 +45,7 @@ where
|
||||||
append_extension(&abs_path, "lua"),
|
append_extension(&abs_path, "lua"),
|
||||||
append_extension(&rel_path, "lua"),
|
append_extension(&rel_path, "lua"),
|
||||||
);
|
);
|
||||||
if let Ok(res) = require_inner(ctx, &lua_abs_path, &lua_rel_path).await {
|
if let Ok(res) = require_inner(lua, ctx, &lua_abs_path, &lua_rel_path).await {
|
||||||
return Ok(res);
|
return Ok(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -57,7 +59,7 @@ where
|
||||||
append_extension(&abs_init, "luau"),
|
append_extension(&abs_init, "luau"),
|
||||||
append_extension(&rel_init, "luau"),
|
append_extension(&rel_init, "luau"),
|
||||||
);
|
);
|
||||||
if let Ok(res) = require_inner(ctx, &luau_abs_init, &luau_rel_init).await {
|
if let Ok(res) = require_inner(lua, ctx, &luau_abs_init, &luau_rel_init).await {
|
||||||
return Ok(res);
|
return Ok(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -66,7 +68,7 @@ where
|
||||||
append_extension(&abs_init, "lua"),
|
append_extension(&abs_init, "lua"),
|
||||||
append_extension(&rel_init, "lua"),
|
append_extension(&rel_init, "lua"),
|
||||||
);
|
);
|
||||||
if let Ok(res) = require_inner(ctx, &lua_abs_init, &lua_rel_init).await {
|
if let Ok(res) = require_inner(lua, ctx, &lua_abs_init, &lua_rel_init).await {
|
||||||
return Ok(res);
|
return Ok(res);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,7 +80,8 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn require_inner<'lua, 'ctx>(
|
async fn require_inner<'lua, 'ctx>(
|
||||||
ctx: &'ctx RequireContext<'lua>,
|
lua: &'lua Lua,
|
||||||
|
ctx: &'ctx RequireContext,
|
||||||
abs_path: impl AsRef<Path>,
|
abs_path: impl AsRef<Path>,
|
||||||
rel_path: impl AsRef<Path>,
|
rel_path: impl AsRef<Path>,
|
||||||
) -> LuaResult<LuaMultiValue<'lua>>
|
) -> LuaResult<LuaMultiValue<'lua>>
|
||||||
|
@ -89,11 +92,11 @@ where
|
||||||
let rel_path = rel_path.as_ref();
|
let rel_path = rel_path.as_ref();
|
||||||
|
|
||||||
if ctx.is_cached(abs_path)? {
|
if ctx.is_cached(abs_path)? {
|
||||||
ctx.get_from_cache(abs_path)
|
ctx.get_from_cache(lua, abs_path)
|
||||||
} else if ctx.is_pending(abs_path)? {
|
} else if ctx.is_pending(abs_path)? {
|
||||||
ctx.wait_for_cache(&abs_path).await
|
ctx.wait_for_cache(lua, &abs_path).await
|
||||||
} else {
|
} else {
|
||||||
ctx.load_with_caching(&abs_path, &rel_path).await
|
ctx.load_with_caching(lua, &abs_path, &rel_path).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,47 +1,42 @@
|
||||||
use std::process::ExitCode;
|
use std::{
|
||||||
|
process::ExitCode,
|
||||||
|
rc::Rc,
|
||||||
|
sync::{
|
||||||
|
atomic::{AtomicBool, Ordering},
|
||||||
|
Arc,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
use mlua::Lua;
|
use mlua::Lua;
|
||||||
|
use mlua_luau_scheduler::Scheduler;
|
||||||
|
|
||||||
mod builtins;
|
mod builtins;
|
||||||
mod error;
|
mod error;
|
||||||
mod globals;
|
mod globals;
|
||||||
mod scheduler;
|
|
||||||
|
|
||||||
pub(crate) mod util;
|
pub(crate) mod util;
|
||||||
|
|
||||||
use self::scheduler::{LuaSchedulerExt, Scheduler};
|
|
||||||
|
|
||||||
pub use error::RuntimeError;
|
pub use error::RuntimeError;
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug)]
|
||||||
pub struct Runtime {
|
pub struct Runtime {
|
||||||
lua: &'static Lua,
|
lua: Rc<Lua>,
|
||||||
scheduler: &'static Scheduler<'static>,
|
|
||||||
args: Vec<String>,
|
args: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Runtime {
|
impl Runtime {
|
||||||
/**
|
/**
|
||||||
Creates a new Lune runtime, with a new Luau VM and task scheduler.
|
Creates a new Lune runtime, with a new Luau VM.
|
||||||
*/
|
*/
|
||||||
#[allow(clippy::new_without_default)]
|
#[allow(clippy::new_without_default)]
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
/*
|
let lua = Rc::new(Lua::new());
|
||||||
FUTURE: Stop leaking these when we have removed the lifetime
|
|
||||||
on the scheduler and can place them in lua app data using arc
|
|
||||||
|
|
||||||
See the scheduler struct for more notes
|
lua.set_app_data(Rc::downgrade(&lua));
|
||||||
*/
|
|
||||||
let lua = Lua::new().into_static();
|
|
||||||
let scheduler = Scheduler::new().into_static();
|
|
||||||
|
|
||||||
lua.set_scheduler(scheduler);
|
|
||||||
lua.set_app_data(Vec::<String>::new());
|
lua.set_app_data(Vec::<String>::new());
|
||||||
globals::inject_all(lua).expect("Failed to inject lua globals");
|
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
lua,
|
lua,
|
||||||
scheduler,
|
|
||||||
args: Vec::new(),
|
args: Vec::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -68,13 +63,35 @@ impl Runtime {
|
||||||
script_name: impl AsRef<str>,
|
script_name: impl AsRef<str>,
|
||||||
script_contents: impl AsRef<[u8]>,
|
script_contents: impl AsRef<[u8]>,
|
||||||
) -> Result<ExitCode, RuntimeError> {
|
) -> Result<ExitCode, RuntimeError> {
|
||||||
|
// Create a new scheduler for this run
|
||||||
|
let sched = Scheduler::new(&self.lua);
|
||||||
|
globals::inject_all(&self.lua)?;
|
||||||
|
|
||||||
|
// Add error callback to format errors nicely + store status
|
||||||
|
let got_any_error = Arc::new(AtomicBool::new(false));
|
||||||
|
let got_any_inner = Arc::clone(&got_any_error);
|
||||||
|
sched.set_error_callback(move |e| {
|
||||||
|
got_any_inner.store(true, Ordering::SeqCst);
|
||||||
|
eprintln!("{}", RuntimeError::from(e));
|
||||||
|
});
|
||||||
|
|
||||||
|
// Load our "main" thread
|
||||||
let main = self
|
let main = self
|
||||||
.lua
|
.lua
|
||||||
.load(script_contents.as_ref())
|
.load(script_contents.as_ref())
|
||||||
.set_name(script_name.as_ref());
|
.set_name(script_name.as_ref());
|
||||||
|
|
||||||
self.scheduler.push_back(self.lua, main, ())?;
|
// Run it on our scheduler until it and any other spawned threads complete
|
||||||
|
sched.push_thread_back(main, ())?;
|
||||||
|
sched.run().await;
|
||||||
|
|
||||||
Ok(self.scheduler.run_to_completion(self.lua).await)
|
// Return the exit code - default to FAILURE if we got any errors
|
||||||
|
Ok(sched.get_exit_code().unwrap_or({
|
||||||
|
if got_any_error.load(Ordering::SeqCst) {
|
||||||
|
ExitCode::FAILURE
|
||||||
|
} else {
|
||||||
|
ExitCode::SUCCESS
|
||||||
|
}
|
||||||
|
}))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,138 +0,0 @@
|
||||||
use futures_util::Future;
|
|
||||||
use mlua::prelude::*;
|
|
||||||
use tokio::{
|
|
||||||
sync::oneshot::{self, Receiver},
|
|
||||||
task,
|
|
||||||
};
|
|
||||||
|
|
||||||
use super::{IntoLuaThread, Scheduler};
|
|
||||||
|
|
||||||
impl<'fut> Scheduler<'fut> {
|
|
||||||
/**
|
|
||||||
Checks if there are any futures to run, for
|
|
||||||
lua futures and background futures respectively.
|
|
||||||
*/
|
|
||||||
pub(super) fn has_futures(&self) -> (bool, bool) {
|
|
||||||
(
|
|
||||||
self.futures_lua
|
|
||||||
.try_lock()
|
|
||||||
.expect("Failed to lock lua futures for check")
|
|
||||||
.len()
|
|
||||||
> 0,
|
|
||||||
self.futures_background
|
|
||||||
.try_lock()
|
|
||||||
.expect("Failed to lock background futures for check")
|
|
||||||
.len()
|
|
||||||
> 0,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Schedules a plain future to run in the background.
|
|
||||||
|
|
||||||
This will potentially spawn the future on a different thread, using
|
|
||||||
[`task::spawn`], meaning the provided future must implement [`Send`].
|
|
||||||
|
|
||||||
Returns a [`Receiver`] which may be `await`-ed
|
|
||||||
to retrieve the result of the spawned future.
|
|
||||||
|
|
||||||
This [`Receiver`] may be safely ignored if the result of the
|
|
||||||
spawned future is not needed, the future will run either way.
|
|
||||||
*/
|
|
||||||
pub fn spawn<F>(&self, fut: F) -> Receiver<F::Output>
|
|
||||||
where
|
|
||||||
F: Future + Send + 'static,
|
|
||||||
F::Output: Send + 'static,
|
|
||||||
{
|
|
||||||
let (tx, rx) = oneshot::channel();
|
|
||||||
|
|
||||||
let handle = task::spawn(async move {
|
|
||||||
let res = fut.await;
|
|
||||||
tx.send(res).ok();
|
|
||||||
});
|
|
||||||
|
|
||||||
// NOTE: We must spawn a future on our scheduler which awaits
|
|
||||||
// the handle from tokio to start driving our future properly
|
|
||||||
let futs = self
|
|
||||||
.futures_background
|
|
||||||
.try_lock()
|
|
||||||
.expect("Failed to lock futures queue for background tasks");
|
|
||||||
futs.push(Box::pin(async move {
|
|
||||||
handle.await.ok();
|
|
||||||
}));
|
|
||||||
|
|
||||||
// NOTE: We might be resuming lua futures, need to signal that a
|
|
||||||
// new background future is ready to break out of futures resumption
|
|
||||||
self.state.message_sender().send_spawned_background_future();
|
|
||||||
|
|
||||||
rx
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Equivalent to [`spawn`], except the future is only
|
|
||||||
spawned on the Lune scheduler, and on the main thread.
|
|
||||||
*/
|
|
||||||
pub fn spawn_local<F>(&self, fut: F) -> Receiver<F::Output>
|
|
||||||
where
|
|
||||||
F: Future + 'static,
|
|
||||||
F::Output: 'static,
|
|
||||||
{
|
|
||||||
let (tx, rx) = oneshot::channel();
|
|
||||||
|
|
||||||
let futs = self
|
|
||||||
.futures_background
|
|
||||||
.try_lock()
|
|
||||||
.expect("Failed to lock futures queue for background tasks");
|
|
||||||
futs.push(Box::pin(async move {
|
|
||||||
let res = fut.await;
|
|
||||||
tx.send(res).ok();
|
|
||||||
}));
|
|
||||||
|
|
||||||
// NOTE: We might be resuming lua futures, need to signal that a
|
|
||||||
// new background future is ready to break out of futures resumption
|
|
||||||
self.state.message_sender().send_spawned_background_future();
|
|
||||||
|
|
||||||
rx
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Schedules the given `thread` to run when the given `fut` completes.
|
|
||||||
|
|
||||||
If the given future returns a [`LuaError`], that error will be passed to the given `thread`.
|
|
||||||
*/
|
|
||||||
pub fn spawn_thread<F, FR>(
|
|
||||||
&'fut self,
|
|
||||||
lua: &'fut Lua,
|
|
||||||
thread: impl IntoLuaThread<'fut>,
|
|
||||||
fut: F,
|
|
||||||
) -> LuaResult<()>
|
|
||||||
where
|
|
||||||
FR: IntoLuaMulti<'fut>,
|
|
||||||
F: Future<Output = LuaResult<FR>> + 'fut,
|
|
||||||
{
|
|
||||||
let thread = thread.into_lua_thread(lua)?;
|
|
||||||
let futs = self.futures_lua.try_lock().expect(
|
|
||||||
"Failed to lock futures queue - \
|
|
||||||
can't schedule future lua threads during futures resumption",
|
|
||||||
);
|
|
||||||
|
|
||||||
futs.push(Box::pin(async move {
|
|
||||||
match fut.await.and_then(|rets| rets.into_lua_multi(lua)) {
|
|
||||||
Err(e) => {
|
|
||||||
self.push_err(lua, thread, e)
|
|
||||||
.expect("Failed to schedule future err thread");
|
|
||||||
}
|
|
||||||
Ok(v) => {
|
|
||||||
self.push_back(lua, thread, v)
|
|
||||||
.expect("Failed to schedule future thread");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}));
|
|
||||||
|
|
||||||
// NOTE: We might be resuming background futures, need to signal that a
|
|
||||||
// new background future is ready to break out of futures resumption
|
|
||||||
self.state.message_sender().send_spawned_lua_future();
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,265 +0,0 @@
|
||||||
use std::{process::ExitCode, sync::Arc};
|
|
||||||
|
|
||||||
use futures_util::StreamExt;
|
|
||||||
use mlua::prelude::*;
|
|
||||||
|
|
||||||
use tokio::task::LocalSet;
|
|
||||||
use tracing::debug;
|
|
||||||
|
|
||||||
use crate::lune::util::traits::LuaEmitErrorExt;
|
|
||||||
|
|
||||||
use super::Scheduler;
|
|
||||||
|
|
||||||
impl<'fut> Scheduler<'fut> {
|
|
||||||
/**
|
|
||||||
Runs all lua threads to completion.
|
|
||||||
*/
|
|
||||||
fn run_lua_threads(&self, lua: &Lua) {
|
|
||||||
if self.state.has_exit_code() {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut count = 0;
|
|
||||||
|
|
||||||
// Pop threads from the scheduler until there are none left
|
|
||||||
while let Some(thread) = self
|
|
||||||
.pop_thread()
|
|
||||||
.expect("Failed to pop thread from scheduler")
|
|
||||||
{
|
|
||||||
// Deconstruct the scheduler thread into its parts
|
|
||||||
let thread_id = thread.id();
|
|
||||||
let (thread, args) = thread.into_inner(lua);
|
|
||||||
|
|
||||||
// Make sure this thread is still resumable, it might have
|
|
||||||
// been resumed somewhere else or even have been cancelled
|
|
||||||
if thread.status() != LuaThreadStatus::Resumable {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Resume the thread, ensuring that the schedulers
|
|
||||||
// current thread id is set correctly for error catching
|
|
||||||
self.state.set_current_thread_id(Some(thread_id));
|
|
||||||
let res = thread.resume::<_, LuaMultiValue>(args);
|
|
||||||
self.state.set_current_thread_id(None);
|
|
||||||
|
|
||||||
count += 1;
|
|
||||||
|
|
||||||
// If we got any resumption (lua-side) error, increment
|
|
||||||
// the error count of the scheduler so we can exit with
|
|
||||||
// a non-zero exit code, and print it out to stderr
|
|
||||||
if let Err(err) = &res {
|
|
||||||
self.state.increment_error_count();
|
|
||||||
lua.emit_error(err.clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the thread has finished running completely,
|
|
||||||
// send results of final resume to any listeners
|
|
||||||
if thread.status() != LuaThreadStatus::Resumable {
|
|
||||||
// NOTE: Threads that were spawned to resume
|
|
||||||
// with an error will not have a result sender
|
|
||||||
if let Some(sender) = self
|
|
||||||
.thread_senders
|
|
||||||
.try_lock()
|
|
||||||
.expect("Failed to get thread senders")
|
|
||||||
.remove(&thread_id)
|
|
||||||
{
|
|
||||||
if sender.receiver_count() > 0 {
|
|
||||||
let stored = match res {
|
|
||||||
Err(e) => Err(e),
|
|
||||||
Ok(v) => Ok(Arc::new(lua.create_registry_value(v.into_vec()).expect(
|
|
||||||
"Failed to store thread results in registry - out of memory",
|
|
||||||
))),
|
|
||||||
};
|
|
||||||
sender
|
|
||||||
.send(stored)
|
|
||||||
.expect("Failed to broadcast thread results");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.state.has_exit_code() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if count > 0 {
|
|
||||||
debug! {
|
|
||||||
%count,
|
|
||||||
"resumed lua"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Runs the next lua future to completion.
|
|
||||||
|
|
||||||
Panics if no lua future is queued.
|
|
||||||
*/
|
|
||||||
async fn run_future_lua(&self) {
|
|
||||||
let mut futs = self
|
|
||||||
.futures_lua
|
|
||||||
.try_lock()
|
|
||||||
.expect("Failed to lock lua futures for resumption");
|
|
||||||
assert!(futs.len() > 0, "No lua futures are queued");
|
|
||||||
futs.next().await;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Runs the next background future to completion.
|
|
||||||
|
|
||||||
Panics if no background future is queued.
|
|
||||||
*/
|
|
||||||
async fn run_future_background(&self) {
|
|
||||||
let mut futs = self
|
|
||||||
.futures_background
|
|
||||||
.try_lock()
|
|
||||||
.expect("Failed to lock background futures for resumption");
|
|
||||||
assert!(futs.len() > 0, "No background futures are queued");
|
|
||||||
futs.next().await;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Runs as many futures as possible, until a new lua thread
|
|
||||||
is ready, or an exit code has been set for the scheduler.
|
|
||||||
|
|
||||||
### Implementation details
|
|
||||||
|
|
||||||
Running futures on our scheduler consists of a couple moving parts:
|
|
||||||
|
|
||||||
1. An unordered futures queue for lua (main thread, local) futures
|
|
||||||
2. An unordered futures queue for background (multithreaded, 'static lifetime) futures
|
|
||||||
3. A signal for breaking out of futures resumption
|
|
||||||
|
|
||||||
The two unordered futures queues need to run concurrently,
|
|
||||||
but since `FuturesUnordered` returns instantly if it does
|
|
||||||
not currently have any futures queued on it, we need to do
|
|
||||||
this branching loop, checking if each queue has futures first.
|
|
||||||
|
|
||||||
We also need to listen for our signal, to see if we should break out of resumption:
|
|
||||||
|
|
||||||
* Always break out of resumption if a new lua thread is ready
|
|
||||||
* Always break out of resumption if an exit code has been set
|
|
||||||
* Break out of lua futures resumption if we have a new background future
|
|
||||||
* Break out of background futures resumption if we have a new lua future
|
|
||||||
|
|
||||||
We need to listen for both future queues concurrently,
|
|
||||||
and break out whenever the other corresponding queue has
|
|
||||||
a new future, since the other queue may resume sooner.
|
|
||||||
*/
|
|
||||||
async fn run_futures(&self) {
|
|
||||||
let (mut has_lua, mut has_background) = self.has_futures();
|
|
||||||
if !has_lua && !has_background {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut rx = self.state.message_receiver();
|
|
||||||
let mut count = 0;
|
|
||||||
|
|
||||||
while has_lua || has_background {
|
|
||||||
if has_lua && has_background {
|
|
||||||
tokio::select! {
|
|
||||||
_ = self.run_future_lua() => {},
|
|
||||||
_ = self.run_future_background() => {},
|
|
||||||
msg = rx.recv() => {
|
|
||||||
if let Some(msg) = msg {
|
|
||||||
if msg.should_break_futures() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
count += 1;
|
|
||||||
} else if has_lua {
|
|
||||||
tokio::select! {
|
|
||||||
_ = self.run_future_lua() => {},
|
|
||||||
msg = rx.recv() => {
|
|
||||||
if let Some(msg) = msg {
|
|
||||||
if msg.should_break_lua_futures() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
count += 1;
|
|
||||||
} else if has_background {
|
|
||||||
tokio::select! {
|
|
||||||
_ = self.run_future_background() => {},
|
|
||||||
msg = rx.recv() => {
|
|
||||||
if let Some(msg) = msg {
|
|
||||||
if msg.should_break_background_futures() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
count += 1;
|
|
||||||
}
|
|
||||||
(has_lua, has_background) = self.has_futures();
|
|
||||||
}
|
|
||||||
|
|
||||||
if count > 0 {
|
|
||||||
debug! {
|
|
||||||
%count,
|
|
||||||
"resumed lua futures"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Runs the scheduler to completion in a [`LocalSet`],
|
|
||||||
both normal lua threads and futures, prioritizing
|
|
||||||
lua threads over completion of any pending futures.
|
|
||||||
|
|
||||||
Will emit lua output and errors to stdout and stderr.
|
|
||||||
*/
|
|
||||||
pub async fn run_to_completion(&self, lua: &Lua) -> ExitCode {
|
|
||||||
if let Some(code) = self.state.exit_code() {
|
|
||||||
return ExitCode::from(code);
|
|
||||||
}
|
|
||||||
|
|
||||||
let set = LocalSet::new();
|
|
||||||
let _guard = set.enter();
|
|
||||||
|
|
||||||
loop {
|
|
||||||
// 1. Run lua threads until exit or there are none left
|
|
||||||
self.run_lua_threads(lua);
|
|
||||||
|
|
||||||
// 2. If we got a manual exit code from lua we should
|
|
||||||
// not try to wait for any pending futures to complete
|
|
||||||
if self.state.has_exit_code() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 3. Keep resuming futures until there are no futures left to
|
|
||||||
// resume, or until we manually break out of resumption for any
|
|
||||||
// reason, this may be because a future spawned a new lua thread
|
|
||||||
self.run_futures().await;
|
|
||||||
|
|
||||||
// 4. Once again, check for an exit code, in case a future sets one
|
|
||||||
if self.state.has_exit_code() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
// 5. If we have no lua threads or futures remaining,
|
|
||||||
// we have now run the scheduler until completion
|
|
||||||
let (has_future_lua, has_future_background) = self.has_futures();
|
|
||||||
if !has_future_lua && !has_future_background && !self.has_thread() {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if let Some(code) = self.state.exit_code() {
|
|
||||||
debug! {
|
|
||||||
%code,
|
|
||||||
"scheduler ran to completion"
|
|
||||||
};
|
|
||||||
ExitCode::from(code)
|
|
||||||
} else if self.state.has_errored() {
|
|
||||||
debug!("scheduler ran to completion, with failure");
|
|
||||||
ExitCode::FAILURE
|
|
||||||
} else {
|
|
||||||
debug!("scheduler ran to completion, with success");
|
|
||||||
ExitCode::SUCCESS
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,185 +0,0 @@
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use mlua::prelude::*;
|
|
||||||
|
|
||||||
use super::{
|
|
||||||
thread::{SchedulerThread, SchedulerThreadId, SchedulerThreadSender},
|
|
||||||
IntoLuaThread, Scheduler,
|
|
||||||
};
|
|
||||||
|
|
||||||
impl<'fut> Scheduler<'fut> {
|
|
||||||
/**
|
|
||||||
Checks if there are any lua threads to run.
|
|
||||||
*/
|
|
||||||
pub(super) fn has_thread(&self) -> bool {
|
|
||||||
!self
|
|
||||||
.threads
|
|
||||||
.try_lock()
|
|
||||||
.expect("Failed to lock threads vec")
|
|
||||||
.is_empty()
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Pops the next thread to run, from the front of the scheduler.
|
|
||||||
|
|
||||||
Returns `None` if there are no threads left to run.
|
|
||||||
*/
|
|
||||||
pub(super) fn pop_thread(&self) -> LuaResult<Option<SchedulerThread>> {
|
|
||||||
match self
|
|
||||||
.threads
|
|
||||||
.try_lock()
|
|
||||||
.into_lua_err()
|
|
||||||
.context("Failed to lock threads vec")?
|
|
||||||
.pop_front()
|
|
||||||
{
|
|
||||||
Some(thread) => Ok(Some(thread)),
|
|
||||||
None => Ok(None),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Schedules the `thread` to be resumed with the given [`LuaError`].
|
|
||||||
*/
|
|
||||||
pub fn push_err<'a>(
|
|
||||||
&self,
|
|
||||||
lua: &'a Lua,
|
|
||||||
thread: impl IntoLuaThread<'a>,
|
|
||||||
err: LuaError,
|
|
||||||
) -> LuaResult<()> {
|
|
||||||
let thread = thread.into_lua_thread(lua)?;
|
|
||||||
let args = LuaMultiValue::new(); // Will be resumed with error, don't need real args
|
|
||||||
|
|
||||||
let thread = SchedulerThread::new(lua, thread, args);
|
|
||||||
let thread_id = thread.id();
|
|
||||||
|
|
||||||
self.state.set_thread_error(thread_id, err);
|
|
||||||
self.threads
|
|
||||||
.try_lock()
|
|
||||||
.into_lua_err()
|
|
||||||
.context("Failed to lock threads vec")?
|
|
||||||
.push_front(thread);
|
|
||||||
|
|
||||||
// NOTE: We might be resuming futures, need to signal that a
|
|
||||||
// new lua thread is ready to break out of futures resumption
|
|
||||||
self.state.message_sender().send_pushed_lua_thread();
|
|
||||||
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Schedules the `thread` to be resumed with the given `args`
|
|
||||||
right away, before any other currently scheduled threads.
|
|
||||||
*/
|
|
||||||
pub fn push_front<'a>(
|
|
||||||
&self,
|
|
||||||
lua: &'a Lua,
|
|
||||||
thread: impl IntoLuaThread<'a>,
|
|
||||||
args: impl IntoLuaMulti<'a>,
|
|
||||||
) -> LuaResult<SchedulerThreadId> {
|
|
||||||
let thread = thread.into_lua_thread(lua)?;
|
|
||||||
let args = args.into_lua_multi(lua)?;
|
|
||||||
|
|
||||||
let thread = SchedulerThread::new(lua, thread, args);
|
|
||||||
let thread_id = thread.id();
|
|
||||||
|
|
||||||
self.threads
|
|
||||||
.try_lock()
|
|
||||||
.into_lua_err()
|
|
||||||
.context("Failed to lock threads vec")?
|
|
||||||
.push_front(thread);
|
|
||||||
|
|
||||||
// NOTE: We might be resuming the same thread several times and
|
|
||||||
// pushing it to the scheduler several times before it is done,
|
|
||||||
// and we should only ever create one result sender per thread
|
|
||||||
self.thread_senders
|
|
||||||
.try_lock()
|
|
||||||
.into_lua_err()
|
|
||||||
.context("Failed to lock thread senders vec")?
|
|
||||||
.entry(thread_id)
|
|
||||||
.or_insert_with(|| SchedulerThreadSender::new(1));
|
|
||||||
|
|
||||||
// NOTE: We might be resuming futures, need to signal that a
|
|
||||||
// new lua thread is ready to break out of futures resumption
|
|
||||||
self.state.message_sender().send_pushed_lua_thread();
|
|
||||||
|
|
||||||
Ok(thread_id)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Schedules the `thread` to be resumed with the given `args`
|
|
||||||
after all other current threads have been resumed.
|
|
||||||
*/
|
|
||||||
pub fn push_back<'a>(
|
|
||||||
&self,
|
|
||||||
lua: &'a Lua,
|
|
||||||
thread: impl IntoLuaThread<'a>,
|
|
||||||
args: impl IntoLuaMulti<'a>,
|
|
||||||
) -> LuaResult<SchedulerThreadId> {
|
|
||||||
let thread = thread.into_lua_thread(lua)?;
|
|
||||||
let args = args.into_lua_multi(lua)?;
|
|
||||||
|
|
||||||
let thread = SchedulerThread::new(lua, thread, args);
|
|
||||||
let thread_id = thread.id();
|
|
||||||
|
|
||||||
self.threads
|
|
||||||
.try_lock()
|
|
||||||
.into_lua_err()
|
|
||||||
.context("Failed to lock threads vec")?
|
|
||||||
.push_back(thread);
|
|
||||||
|
|
||||||
// NOTE: We might be resuming the same thread several times and
|
|
||||||
// pushing it to the scheduler several times before it is done,
|
|
||||||
// and we should only ever create one result sender per thread
|
|
||||||
self.thread_senders
|
|
||||||
.try_lock()
|
|
||||||
.into_lua_err()
|
|
||||||
.context("Failed to lock thread senders vec")?
|
|
||||||
.entry(thread_id)
|
|
||||||
.or_insert_with(|| SchedulerThreadSender::new(1));
|
|
||||||
|
|
||||||
// NOTE: We might be resuming futures, need to signal that a
|
|
||||||
// new lua thread is ready to break out of futures resumption
|
|
||||||
self.state.message_sender().send_pushed_lua_thread();
|
|
||||||
|
|
||||||
Ok(thread_id)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Waits for the given thread to finish running, and returns its result.
|
|
||||||
*/
|
|
||||||
pub async fn wait_for_thread<'a>(
|
|
||||||
&self,
|
|
||||||
lua: &'a Lua,
|
|
||||||
thread_id: SchedulerThreadId,
|
|
||||||
) -> LuaResult<LuaMultiValue<'a>> {
|
|
||||||
let mut recv = {
|
|
||||||
let senders = self.thread_senders.lock().await;
|
|
||||||
let sender = senders
|
|
||||||
.get(&thread_id)
|
|
||||||
.expect("Tried to wait for thread that is not queued");
|
|
||||||
sender.subscribe()
|
|
||||||
};
|
|
||||||
let res = match recv.recv().await {
|
|
||||||
Err(_) => panic!("Sender was dropped while waiting for {thread_id:?}"),
|
|
||||||
Ok(r) => r,
|
|
||||||
};
|
|
||||||
match res {
|
|
||||||
Err(e) => Err(e),
|
|
||||||
Ok(k) => {
|
|
||||||
let vals = lua
|
|
||||||
.registry_value::<Vec<LuaValue>>(&k)
|
|
||||||
.expect("Received invalid registry key for thread");
|
|
||||||
|
|
||||||
// NOTE: This is not strictly necessary, mlua can clean
|
|
||||||
// up registry values on its own, but doing this will add
|
|
||||||
// some extra safety and clean up registry values faster
|
|
||||||
if let Some(key) = Arc::into_inner(k) {
|
|
||||||
lua.remove_registry_value(key)
|
|
||||||
.expect("Failed to remove registry key for thread");
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(LuaMultiValue::from_vec(vals))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,98 +0,0 @@
|
||||||
use std::sync::{MutexGuard, TryLockError};
|
|
||||||
|
|
||||||
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
|
|
||||||
|
|
||||||
use super::state::SchedulerState;
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
|
||||||
pub(crate) enum SchedulerMessage {
|
|
||||||
ExitCodeSet,
|
|
||||||
PushedLuaThread,
|
|
||||||
SpawnedLuaFuture,
|
|
||||||
SpawnedBackgroundFuture,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SchedulerMessage {
|
|
||||||
pub fn should_break_futures(self) -> bool {
|
|
||||||
matches!(self, Self::ExitCodeSet | Self::PushedLuaThread)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn should_break_lua_futures(self) -> bool {
|
|
||||||
self.should_break_futures() || matches!(self, Self::SpawnedBackgroundFuture)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn should_break_background_futures(self) -> bool {
|
|
||||||
self.should_break_futures() || matches!(self, Self::SpawnedLuaFuture)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
A message sender for the scheduler.
|
|
||||||
|
|
||||||
As long as this sender is not dropped, the scheduler
|
|
||||||
will be kept alive, waiting for more messages to arrive.
|
|
||||||
*/
|
|
||||||
pub(crate) struct SchedulerMessageSender(UnboundedSender<SchedulerMessage>);
|
|
||||||
|
|
||||||
impl SchedulerMessageSender {
|
|
||||||
/**
|
|
||||||
Creates a new message sender for the scheduler.
|
|
||||||
*/
|
|
||||||
pub fn new(state: &SchedulerState) -> Self {
|
|
||||||
Self(
|
|
||||||
state
|
|
||||||
.message_sender
|
|
||||||
.lock()
|
|
||||||
.expect("Scheduler state was poisoned")
|
|
||||||
.clone(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn send_exit_code_set(&self) {
|
|
||||||
self.0.send(SchedulerMessage::ExitCodeSet).ok();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn send_pushed_lua_thread(&self) {
|
|
||||||
self.0.send(SchedulerMessage::PushedLuaThread).ok();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn send_spawned_lua_future(&self) {
|
|
||||||
self.0.send(SchedulerMessage::SpawnedLuaFuture).ok();
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn send_spawned_background_future(&self) {
|
|
||||||
self.0.send(SchedulerMessage::SpawnedBackgroundFuture).ok();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
A message receiver for the scheduler.
|
|
||||||
|
|
||||||
Only one message receiver may exist per scheduler.
|
|
||||||
*/
|
|
||||||
pub(crate) struct SchedulerMessageReceiver<'a>(MutexGuard<'a, UnboundedReceiver<SchedulerMessage>>);
|
|
||||||
|
|
||||||
impl<'a> SchedulerMessageReceiver<'a> {
|
|
||||||
/**
|
|
||||||
Creates a new message receiver for the scheduler.
|
|
||||||
|
|
||||||
Panics if the message receiver is already being used.
|
|
||||||
*/
|
|
||||||
pub fn new(state: &'a SchedulerState) -> Self {
|
|
||||||
Self(match state.message_receiver.try_lock() {
|
|
||||||
Err(TryLockError::Poisoned(_)) => panic!("Sheduler state was poisoned"),
|
|
||||||
Err(TryLockError::WouldBlock) => {
|
|
||||||
panic!("Message receiver may only be borrowed once at a time")
|
|
||||||
}
|
|
||||||
Ok(guard) => guard,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// NOTE: Holding this lock across await points is fine, since we
|
|
||||||
// can only ever create lock exactly one SchedulerMessageReceiver
|
|
||||||
// See above constructor for details on this
|
|
||||||
#[allow(clippy::await_holding_lock)]
|
|
||||||
pub async fn recv(&mut self) -> Option<SchedulerMessage> {
|
|
||||||
self.0.recv().await
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,120 +0,0 @@
|
||||||
use std::{
|
|
||||||
collections::{HashMap, VecDeque},
|
|
||||||
pin::Pin,
|
|
||||||
sync::Arc,
|
|
||||||
};
|
|
||||||
|
|
||||||
use futures_util::{stream::FuturesUnordered, Future};
|
|
||||||
use mlua::prelude::*;
|
|
||||||
use tokio::sync::Mutex as AsyncMutex;
|
|
||||||
|
|
||||||
mod message;
|
|
||||||
mod state;
|
|
||||||
mod thread;
|
|
||||||
mod traits;
|
|
||||||
|
|
||||||
mod impl_async;
|
|
||||||
mod impl_runner;
|
|
||||||
mod impl_threads;
|
|
||||||
|
|
||||||
pub use self::thread::SchedulerThreadId;
|
|
||||||
pub use self::traits::*;
|
|
||||||
|
|
||||||
use self::{
|
|
||||||
state::SchedulerState,
|
|
||||||
thread::{SchedulerThread, SchedulerThreadSender},
|
|
||||||
};
|
|
||||||
|
|
||||||
type SchedulerFuture<'fut> = Pin<Box<dyn Future<Output = ()> + 'fut>>;
|
|
||||||
|
|
||||||
/**
|
|
||||||
Scheduler for Lua threads and futures.
|
|
||||||
|
|
||||||
This scheduler can be cheaply cloned and the underlying state
|
|
||||||
and data will remain unchanged and accessible from all clones.
|
|
||||||
*/
|
|
||||||
#[derive(Debug, Clone)]
|
|
||||||
pub(crate) struct Scheduler<'fut> {
|
|
||||||
state: Arc<SchedulerState>,
|
|
||||||
threads: Arc<AsyncMutex<VecDeque<SchedulerThread>>>,
|
|
||||||
thread_senders: Arc<AsyncMutex<HashMap<SchedulerThreadId, SchedulerThreadSender>>>,
|
|
||||||
/*
|
|
||||||
FUTURE: Get rid of these, let the tokio runtime handle running
|
|
||||||
and resumption of futures completely, just use our scheduler
|
|
||||||
state and receiver to know when we have run to completion.
|
|
||||||
If we have no senders left, we have run to completion.
|
|
||||||
|
|
||||||
We should also investigate using smol / async-executor and its
|
|
||||||
LocalExecutor struct which does not impose the 'static lifetime
|
|
||||||
restriction on all of the futures spawned on it, unlike tokio.
|
|
||||||
|
|
||||||
If we no longer store futures directly in our scheduler, we
|
|
||||||
can get rid of the lifetime on it, store it in our lua app
|
|
||||||
data as a Weak<Scheduler>, together with a Weak<Lua>.
|
|
||||||
|
|
||||||
In our lua async functions we can then get a reference to this,
|
|
||||||
upgrade it to an Arc<Scheduler> and Arc<Lua> to extend lifetimes,
|
|
||||||
and hopefully get rid of Box::leak and 'static lifetimes for good.
|
|
||||||
|
|
||||||
Relevant comment on the mlua repository:
|
|
||||||
https://github.com/khvzak/mlua/issues/169#issuecomment-1138863979
|
|
||||||
*/
|
|
||||||
futures_lua: Arc<AsyncMutex<FuturesUnordered<SchedulerFuture<'fut>>>>,
|
|
||||||
futures_background: Arc<AsyncMutex<FuturesUnordered<SchedulerFuture<'static>>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'fut> Scheduler<'fut> {
|
|
||||||
/**
|
|
||||||
Creates a new scheduler.
|
|
||||||
*/
|
|
||||||
#[allow(clippy::arc_with_non_send_sync)] // FIXME: Clippy lints our tokio mutexes that are definitely Send + Sync
|
|
||||||
pub fn new() -> Self {
|
|
||||||
Self {
|
|
||||||
state: Arc::new(SchedulerState::new()),
|
|
||||||
threads: Arc::new(AsyncMutex::new(VecDeque::new())),
|
|
||||||
thread_senders: Arc::new(AsyncMutex::new(HashMap::new())),
|
|
||||||
futures_lua: Arc::new(AsyncMutex::new(FuturesUnordered::new())),
|
|
||||||
futures_background: Arc::new(AsyncMutex::new(FuturesUnordered::new())),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Sets the luau interrupt for this scheduler.
|
|
||||||
|
|
||||||
This will propagate errors from any lua-spawned
|
|
||||||
futures back to the lua threads that spawned them.
|
|
||||||
*/
|
|
||||||
pub fn set_interrupt_for(&self, lua: &Lua) {
|
|
||||||
// Propagate errors given to the scheduler back to their lua threads
|
|
||||||
// FUTURE: Do profiling and anything else we need inside of this interrupt
|
|
||||||
let state = self.state.clone();
|
|
||||||
lua.set_interrupt(move |_| {
|
|
||||||
if let Some(id) = state.get_current_thread_id() {
|
|
||||||
if let Some(err) = state.get_thread_error(id) {
|
|
||||||
return Err(err);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Ok(LuaVmState::Continue)
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Sets the exit code for the scheduler.
|
|
||||||
|
|
||||||
This will stop the scheduler from resuming any more lua threads or futures.
|
|
||||||
|
|
||||||
Panics if the exit code is set more than once.
|
|
||||||
*/
|
|
||||||
pub fn set_exit_code(&self, code: impl Into<u8>) {
|
|
||||||
assert!(
|
|
||||||
self.state.exit_code().is_none(),
|
|
||||||
"Exit code may only be set exactly once"
|
|
||||||
);
|
|
||||||
self.state.set_exit_code(code.into());
|
|
||||||
}
|
|
||||||
|
|
||||||
#[doc(hidden)]
|
|
||||||
pub fn into_static(self) -> &'static Self {
|
|
||||||
Box::leak(Box::new(self))
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,176 +0,0 @@
|
||||||
use std::{
|
|
||||||
collections::HashMap,
|
|
||||||
sync::{
|
|
||||||
atomic::{AtomicBool, AtomicU8, AtomicUsize, Ordering},
|
|
||||||
Arc, Mutex,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
|
|
||||||
use mlua::Error as LuaError;
|
|
||||||
|
|
||||||
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
|
|
||||||
|
|
||||||
use super::{
|
|
||||||
message::{SchedulerMessage, SchedulerMessageReceiver, SchedulerMessageSender},
|
|
||||||
SchedulerThreadId,
|
|
||||||
};
|
|
||||||
|
|
||||||
/**
|
|
||||||
Internal state for a [`Scheduler`].
|
|
||||||
|
|
||||||
This scheduler state uses atomic operations for everything
|
|
||||||
except lua error storage, and is completely thread safe.
|
|
||||||
*/
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub(crate) struct SchedulerState {
|
|
||||||
exit_state: AtomicBool,
|
|
||||||
exit_code: AtomicU8,
|
|
||||||
num_resumptions: AtomicUsize,
|
|
||||||
num_errors: AtomicUsize,
|
|
||||||
thread_id: Arc<Mutex<Option<SchedulerThreadId>>>,
|
|
||||||
thread_errors: Arc<Mutex<HashMap<SchedulerThreadId, LuaError>>>,
|
|
||||||
pub(super) message_sender: Arc<Mutex<UnboundedSender<SchedulerMessage>>>,
|
|
||||||
pub(super) message_receiver: Arc<Mutex<UnboundedReceiver<SchedulerMessage>>>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SchedulerState {
|
|
||||||
/**
|
|
||||||
Creates a new scheduler state.
|
|
||||||
*/
|
|
||||||
pub fn new() -> Self {
|
|
||||||
let (message_sender, message_receiver) = unbounded_channel();
|
|
||||||
|
|
||||||
Self {
|
|
||||||
exit_state: AtomicBool::new(false),
|
|
||||||
exit_code: AtomicU8::new(0),
|
|
||||||
num_resumptions: AtomicUsize::new(0),
|
|
||||||
num_errors: AtomicUsize::new(0),
|
|
||||||
thread_id: Arc::new(Mutex::new(None)),
|
|
||||||
thread_errors: Arc::new(Mutex::new(HashMap::new())),
|
|
||||||
message_sender: Arc::new(Mutex::new(message_sender)),
|
|
||||||
message_receiver: Arc::new(Mutex::new(message_receiver)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Increments the total lua error count for the scheduler.
|
|
||||||
|
|
||||||
This is used to determine if the scheduler should exit with
|
|
||||||
a non-zero exit code, when no exit code is explicitly set.
|
|
||||||
*/
|
|
||||||
pub fn increment_error_count(&self) {
|
|
||||||
self.num_errors.fetch_add(1, Ordering::Relaxed);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Checks if there have been any lua errors.
|
|
||||||
|
|
||||||
This is used to determine if the scheduler should exit with
|
|
||||||
a non-zero exit code, when no exit code is explicitly set.
|
|
||||||
*/
|
|
||||||
pub fn has_errored(&self) -> bool {
|
|
||||||
self.num_errors.load(Ordering::SeqCst) > 0
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Gets the currently set exit code for the scheduler, if any.
|
|
||||||
*/
|
|
||||||
pub fn exit_code(&self) -> Option<u8> {
|
|
||||||
if self.exit_state.load(Ordering::SeqCst) {
|
|
||||||
Some(self.exit_code.load(Ordering::SeqCst))
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Checks if the scheduler has an explicit exit code set.
|
|
||||||
*/
|
|
||||||
pub fn has_exit_code(&self) -> bool {
|
|
||||||
self.exit_state.load(Ordering::SeqCst)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Sets the explicit exit code for the scheduler.
|
|
||||||
*/
|
|
||||||
pub fn set_exit_code(&self, code: impl Into<u8>) {
|
|
||||||
self.exit_state.store(true, Ordering::SeqCst);
|
|
||||||
self.exit_code.store(code.into(), Ordering::SeqCst);
|
|
||||||
self.message_sender().send_exit_code_set();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Gets the currently running lua scheduler thread id, if any.
|
|
||||||
*/
|
|
||||||
pub fn get_current_thread_id(&self) -> Option<SchedulerThreadId> {
|
|
||||||
*self
|
|
||||||
.thread_id
|
|
||||||
.lock()
|
|
||||||
.expect("Failed to lock current thread id")
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Sets the currently running lua scheduler thread id.
|
|
||||||
|
|
||||||
This must be set to `Some(id)` just before resuming a lua thread,
|
|
||||||
and `None` while no lua thread is being resumed. If set to `Some`
|
|
||||||
while the current thread id is also `Some`, this will panic.
|
|
||||||
|
|
||||||
Must only be set once per thread id, although this
|
|
||||||
is not checked at runtime for performance reasons.
|
|
||||||
*/
|
|
||||||
pub fn set_current_thread_id(&self, id: Option<SchedulerThreadId>) {
|
|
||||||
self.num_resumptions.fetch_add(1, Ordering::Relaxed);
|
|
||||||
let mut thread_id = self
|
|
||||||
.thread_id
|
|
||||||
.lock()
|
|
||||||
.expect("Failed to lock current thread id");
|
|
||||||
assert!(
|
|
||||||
id.is_none() || thread_id.is_none(),
|
|
||||||
"Current thread id can not be overwritten"
|
|
||||||
);
|
|
||||||
*thread_id = id;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Gets the [`LuaError`] (if any) for the given `id`.
|
|
||||||
|
|
||||||
Note that this removes the error from the scheduler state completely.
|
|
||||||
*/
|
|
||||||
pub fn get_thread_error(&self, id: SchedulerThreadId) -> Option<LuaError> {
|
|
||||||
let mut thread_errors = self
|
|
||||||
.thread_errors
|
|
||||||
.lock()
|
|
||||||
.expect("Failed to lock thread errors");
|
|
||||||
thread_errors.remove(&id)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Sets a [`LuaError`] for the given `id`.
|
|
||||||
|
|
||||||
Note that this will replace any already existing [`LuaError`].
|
|
||||||
*/
|
|
||||||
pub fn set_thread_error(&self, id: SchedulerThreadId, err: LuaError) {
|
|
||||||
let mut thread_errors = self
|
|
||||||
.thread_errors
|
|
||||||
.lock()
|
|
||||||
.expect("Failed to lock thread errors");
|
|
||||||
thread_errors.insert(id, err);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Creates a new message sender for the scheduler.
|
|
||||||
*/
|
|
||||||
pub fn message_sender(&self) -> SchedulerMessageSender {
|
|
||||||
SchedulerMessageSender::new(self)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Tries to borrow the message receiver for the scheduler.
|
|
||||||
|
|
||||||
Panics if the message receiver is already being used.
|
|
||||||
*/
|
|
||||||
pub fn message_receiver(&self) -> SchedulerMessageReceiver {
|
|
||||||
SchedulerMessageReceiver::new(self)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,105 +0,0 @@
|
||||||
use std::sync::Arc;
|
|
||||||
|
|
||||||
use mlua::prelude::*;
|
|
||||||
use tokio::sync::broadcast::Sender;
|
|
||||||
|
|
||||||
/**
|
|
||||||
Type alias for a broadcast [`Sender`], which will
|
|
||||||
broadcast the result and return values of a lua thread.
|
|
||||||
|
|
||||||
The return values are stored in the lua registry as a
|
|
||||||
`Vec<LuaValue<'_>>`, and the registry key pointing to
|
|
||||||
those values will be sent using the broadcast sender.
|
|
||||||
*/
|
|
||||||
pub type SchedulerThreadSender = Sender<LuaResult<Arc<LuaRegistryKey>>>;
|
|
||||||
|
|
||||||
/**
|
|
||||||
Unique, randomly generated id for a scheduler thread.
|
|
||||||
*/
|
|
||||||
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
|
|
||||||
pub struct SchedulerThreadId(usize);
|
|
||||||
|
|
||||||
impl From<&LuaThread<'_>> for SchedulerThreadId {
|
|
||||||
fn from(value: &LuaThread) -> Self {
|
|
||||||
// HACK: We rely on the debug format of mlua
|
|
||||||
// thread refs here, but currently this is the
|
|
||||||
// only way to get a proper unique id using mlua
|
|
||||||
let addr_string = format!("{value:?}");
|
|
||||||
let addr = addr_string
|
|
||||||
.strip_prefix("Thread(Ref(0x")
|
|
||||||
.expect("Invalid thread address format - unknown prefix")
|
|
||||||
.split_once(')')
|
|
||||||
.map(|(s, _)| s)
|
|
||||||
.expect("Invalid thread address format - missing ')'");
|
|
||||||
let id = usize::from_str_radix(addr, 16)
|
|
||||||
.expect("Failed to parse thread address as hexadecimal into usize");
|
|
||||||
Self(id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Container for registry keys that point to a thread and thread arguments.
|
|
||||||
*/
|
|
||||||
#[derive(Debug)]
|
|
||||||
pub(super) struct SchedulerThread {
|
|
||||||
thread_id: SchedulerThreadId,
|
|
||||||
key_thread: LuaRegistryKey,
|
|
||||||
key_args: LuaRegistryKey,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl SchedulerThread {
|
|
||||||
/**
|
|
||||||
Creates a new scheduler thread container from the given thread and arguments.
|
|
||||||
|
|
||||||
May fail if an allocation error occurs, is not fallible otherwise.
|
|
||||||
*/
|
|
||||||
pub(super) fn new<'lua>(
|
|
||||||
lua: &'lua Lua,
|
|
||||||
thread: LuaThread<'lua>,
|
|
||||||
args: LuaMultiValue<'lua>,
|
|
||||||
) -> Self {
|
|
||||||
let args_vec = args.into_vec();
|
|
||||||
let thread_id = SchedulerThreadId::from(&thread);
|
|
||||||
|
|
||||||
let key_thread = lua
|
|
||||||
.create_registry_value(thread)
|
|
||||||
.expect("Failed to store thread in registry - out of memory");
|
|
||||||
let key_args = lua
|
|
||||||
.create_registry_value(args_vec)
|
|
||||||
.expect("Failed to store thread args in registry - out of memory");
|
|
||||||
|
|
||||||
Self {
|
|
||||||
thread_id,
|
|
||||||
key_thread,
|
|
||||||
key_args,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Extracts the inner thread and args from the container.
|
|
||||||
*/
|
|
||||||
pub(super) fn into_inner(self, lua: &Lua) -> (LuaThread<'_>, LuaMultiValue<'_>) {
|
|
||||||
let thread = lua
|
|
||||||
.registry_value(&self.key_thread)
|
|
||||||
.expect("Failed to get thread from registry");
|
|
||||||
let args_vec = lua
|
|
||||||
.registry_value(&self.key_args)
|
|
||||||
.expect("Failed to get thread args from registry");
|
|
||||||
|
|
||||||
let args = LuaMultiValue::from_vec(args_vec);
|
|
||||||
|
|
||||||
lua.remove_registry_value(self.key_thread)
|
|
||||||
.expect("Failed to remove thread from registry");
|
|
||||||
lua.remove_registry_value(self.key_args)
|
|
||||||
.expect("Failed to remove thread args from registry");
|
|
||||||
|
|
||||||
(thread, args)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Retrieves the unique, randomly generated id for this scheduler thread.
|
|
||||||
*/
|
|
||||||
pub(super) fn id(&self) -> SchedulerThreadId {
|
|
||||||
self.thread_id
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,118 +0,0 @@
|
||||||
use futures_util::Future;
|
|
||||||
use mlua::prelude::*;
|
|
||||||
|
|
||||||
use super::Scheduler;
|
|
||||||
|
|
||||||
const ASYNC_IMPL_LUA: &str = r#"
|
|
||||||
schedule(...)
|
|
||||||
return yield()
|
|
||||||
"#;
|
|
||||||
|
|
||||||
/**
|
|
||||||
Trait for extensions to the [`Lua`] struct, allowing
|
|
||||||
for access to the scheduler without having to import
|
|
||||||
it or handle registry / app data references manually.
|
|
||||||
*/
|
|
||||||
pub(crate) trait LuaSchedulerExt<'lua> {
|
|
||||||
/**
|
|
||||||
Sets the scheduler for the [`Lua`] struct.
|
|
||||||
*/
|
|
||||||
fn set_scheduler(&'lua self, scheduler: &'lua Scheduler);
|
|
||||||
|
|
||||||
/**
|
|
||||||
Creates a function callable from Lua that runs an async
|
|
||||||
closure and returns the results of it to the call site.
|
|
||||||
*/
|
|
||||||
fn create_async_function<A, R, F, FR>(&'lua self, func: F) -> LuaResult<LuaFunction<'lua>>
|
|
||||||
where
|
|
||||||
A: FromLuaMulti<'lua>,
|
|
||||||
R: IntoLuaMulti<'lua>,
|
|
||||||
F: Fn(&'lua Lua, A) -> FR + 'lua,
|
|
||||||
FR: Future<Output = LuaResult<R>> + 'lua;
|
|
||||||
}
|
|
||||||
|
|
||||||
// FIXME: `self` escapes outside of method because we are borrowing `func`
|
|
||||||
// when we call `schedule_future_thread` in the lua function body below
|
|
||||||
// For now we solve this by using the 'static lifetime bound in the impl
|
|
||||||
impl<'lua> LuaSchedulerExt<'lua> for Lua
|
|
||||||
where
|
|
||||||
'lua: 'static,
|
|
||||||
{
|
|
||||||
fn set_scheduler(&'lua self, scheduler: &'lua Scheduler) {
|
|
||||||
self.set_app_data(scheduler);
|
|
||||||
scheduler.set_interrupt_for(self);
|
|
||||||
}
|
|
||||||
|
|
||||||
fn create_async_function<A, R, F, FR>(&'lua self, func: F) -> LuaResult<LuaFunction<'lua>>
|
|
||||||
where
|
|
||||||
A: FromLuaMulti<'lua>,
|
|
||||||
R: IntoLuaMulti<'lua>,
|
|
||||||
F: Fn(&'lua Lua, A) -> FR + 'lua,
|
|
||||||
FR: Future<Output = LuaResult<R>> + 'lua,
|
|
||||||
{
|
|
||||||
self.app_data_ref::<&Scheduler>()
|
|
||||||
.expect("Lua must have a scheduler to create async functions");
|
|
||||||
|
|
||||||
let async_env = self.create_table_with_capacity(0, 2)?;
|
|
||||||
|
|
||||||
async_env.set(
|
|
||||||
"yield",
|
|
||||||
self.globals()
|
|
||||||
.get::<_, LuaTable>("coroutine")?
|
|
||||||
.get::<_, LuaFunction>("yield")?,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
async_env.set(
|
|
||||||
"schedule",
|
|
||||||
LuaFunction::wrap(move |lua: &Lua, args: A| {
|
|
||||||
let thread = lua.current_thread();
|
|
||||||
let future = func(lua, args);
|
|
||||||
let sched = lua
|
|
||||||
.app_data_ref::<&Scheduler>()
|
|
||||||
.expect("Lua struct is missing scheduler");
|
|
||||||
sched.spawn_thread(lua, thread, future)?;
|
|
||||||
Ok(())
|
|
||||||
}),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let async_func = self
|
|
||||||
.load(ASYNC_IMPL_LUA)
|
|
||||||
.set_name("async")
|
|
||||||
.set_environment(async_env)
|
|
||||||
.into_function()?;
|
|
||||||
Ok(async_func)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
Trait for any struct that can be turned into an [`LuaThread`]
|
|
||||||
and given to the scheduler, implemented for the following types:
|
|
||||||
|
|
||||||
- Lua threads ([`LuaThread`])
|
|
||||||
- Lua functions ([`LuaFunction`])
|
|
||||||
- Lua chunks ([`LuaChunk`])
|
|
||||||
*/
|
|
||||||
pub trait IntoLuaThread<'lua> {
|
|
||||||
/**
|
|
||||||
Converts the value into a lua thread.
|
|
||||||
*/
|
|
||||||
fn into_lua_thread(self, lua: &'lua Lua) -> LuaResult<LuaThread<'lua>>;
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'lua> IntoLuaThread<'lua> for LuaThread<'lua> {
|
|
||||||
fn into_lua_thread(self, _: &'lua Lua) -> LuaResult<LuaThread<'lua>> {
|
|
||||||
Ok(self)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'lua> IntoLuaThread<'lua> for LuaFunction<'lua> {
|
|
||||||
fn into_lua_thread(self, lua: &'lua Lua) -> LuaResult<LuaThread<'lua>> {
|
|
||||||
lua.create_thread(self)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<'lua, 'a> IntoLuaThread<'lua> for LuaChunk<'lua, 'a> {
|
|
||||||
fn into_lua_thread(self, lua: &'lua Lua) -> LuaResult<LuaThread<'lua>> {
|
|
||||||
lua.create_thread(self.into_function()?)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,18 +0,0 @@
|
||||||
use std::future::Future;
|
|
||||||
use std::pin::Pin;
|
|
||||||
use std::task::{Context, Poll};
|
|
||||||
|
|
||||||
#[derive(Debug, Clone, Copy)]
|
|
||||||
pub struct YieldForever;
|
|
||||||
|
|
||||||
impl Future for YieldForever {
|
|
||||||
type Output = ();
|
|
||||||
|
|
||||||
fn poll(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
|
|
||||||
Poll::Pending
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn yield_forever() -> YieldForever {
|
|
||||||
YieldForever
|
|
||||||
}
|
|
|
@ -1,7 +1,6 @@
|
||||||
mod table_builder;
|
mod table_builder;
|
||||||
|
|
||||||
pub mod formatting;
|
pub mod formatting;
|
||||||
pub mod futures;
|
|
||||||
pub mod luaurc;
|
pub mod luaurc;
|
||||||
pub mod paths;
|
pub mod paths;
|
||||||
pub mod traits;
|
pub mod traits;
|
||||||
|
|
|
@ -4,8 +4,6 @@ use std::future::Future;
|
||||||
|
|
||||||
use mlua::prelude::*;
|
use mlua::prelude::*;
|
||||||
|
|
||||||
use crate::lune::scheduler::LuaSchedulerExt;
|
|
||||||
|
|
||||||
pub struct TableBuilder<'lua> {
|
pub struct TableBuilder<'lua> {
|
||||||
lua: &'lua Lua,
|
lua: &'lua Lua,
|
||||||
tab: LuaTable<'lua>,
|
tab: LuaTable<'lua>,
|
||||||
|
@ -79,20 +77,13 @@ impl<'lua> TableBuilder<'lua> {
|
||||||
pub fn build(self) -> LuaResult<LuaTable<'lua>> {
|
pub fn build(self) -> LuaResult<LuaTable<'lua>> {
|
||||||
Ok(self.tab)
|
Ok(self.tab)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// FIXME: Remove static lifetime bound here when `create_async_function`
|
|
||||||
// no longer needs it to compile, then move this into the above impl
|
|
||||||
impl<'lua> TableBuilder<'lua>
|
|
||||||
where
|
|
||||||
'lua: 'static,
|
|
||||||
{
|
|
||||||
pub fn with_async_function<K, A, R, F, FR>(self, key: K, func: F) -> LuaResult<Self>
|
pub fn with_async_function<K, A, R, F, FR>(self, key: K, func: F) -> LuaResult<Self>
|
||||||
where
|
where
|
||||||
K: IntoLua<'lua>,
|
K: IntoLua<'lua>,
|
||||||
A: FromLuaMulti<'lua>,
|
A: FromLuaMulti<'lua>,
|
||||||
R: IntoLuaMulti<'lua>,
|
R: IntoLuaMulti<'lua>,
|
||||||
F: Fn(&'lua Lua, A) -> FR + 'lua,
|
F: Fn(&'lua Lua, A) -> FR + 'static,
|
||||||
FR: Future<Output = LuaResult<R>> + 'lua,
|
FR: Future<Output = LuaResult<R>> + 'lua,
|
||||||
{
|
{
|
||||||
let f = self.lua.create_async_function(func)?;
|
let f = self.lua.create_async_function(func)?;
|
||||||
|
|
|
@ -68,6 +68,7 @@ create_tests! {
|
||||||
net_url_decode: "net/url/decode",
|
net_url_decode: "net/url/decode",
|
||||||
net_serve_requests: "net/serve/requests",
|
net_serve_requests: "net/serve/requests",
|
||||||
net_serve_websockets: "net/serve/websockets",
|
net_serve_websockets: "net/serve/websockets",
|
||||||
|
net_socket_basic: "net/socket/basic",
|
||||||
net_socket_wss: "net/socket/wss",
|
net_socket_wss: "net/socket/wss",
|
||||||
net_socket_wss_rw: "net/socket/wss_rw",
|
net_socket_wss_rw: "net/socket/wss_rw",
|
||||||
|
|
||||||
|
@ -84,7 +85,6 @@ create_tests! {
|
||||||
|
|
||||||
require_aliases: "require/tests/aliases",
|
require_aliases: "require/tests/aliases",
|
||||||
require_async: "require/tests/async",
|
require_async: "require/tests/async",
|
||||||
require_async_background: "require/tests/async_background",
|
|
||||||
require_async_concurrent: "require/tests/async_concurrent",
|
require_async_concurrent: "require/tests/async_concurrent",
|
||||||
require_async_sequential: "require/tests/async_sequential",
|
require_async_sequential: "require/tests/async_sequential",
|
||||||
require_builtins: "require/tests/builtins",
|
require_builtins: "require/tests/builtins",
|
||||||
|
|
28
tests/net/socket/basic.luau
Normal file
28
tests/net/socket/basic.luau
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
local net = require("@lune/net")
|
||||||
|
|
||||||
|
-- We're going to use Discord's WebSocket gateway server for testing
|
||||||
|
local socket = net.socket("wss://gateway.discord.gg/?v=10&encoding=json")
|
||||||
|
|
||||||
|
assert(type(socket.next) == "function", "next must be a function")
|
||||||
|
assert(type(socket.send) == "function", "send must be a function")
|
||||||
|
assert(type(socket.close) == "function", "close must be a function")
|
||||||
|
|
||||||
|
-- Request to close the socket
|
||||||
|
socket.close()
|
||||||
|
|
||||||
|
-- Drain remaining messages, until we got our close message
|
||||||
|
while socket.next() do
|
||||||
|
end
|
||||||
|
|
||||||
|
assert(type(socket.closeCode) == "number", "closeCode should exist after closing")
|
||||||
|
assert(socket.closeCode == 1000, "closeCode should be 1000 after closing")
|
||||||
|
|
||||||
|
local success, message = pcall(function()
|
||||||
|
socket.send("Hello, world!")
|
||||||
|
end)
|
||||||
|
|
||||||
|
assert(not success, "send should fail after closing")
|
||||||
|
assert(
|
||||||
|
string.find(tostring(message), "closed") or string.find(tostring(message), "closing"),
|
||||||
|
"send should fail with a message that the socket was closed"
|
||||||
|
)
|
|
@ -1,51 +0,0 @@
|
||||||
local net = require("@lune/net")
|
|
||||||
local process = require("@lune/process")
|
|
||||||
local stdio = require("@lune/stdio")
|
|
||||||
local task = require("@lune/task")
|
|
||||||
|
|
||||||
-- Spawn an asynchronous background task (eg. web server)
|
|
||||||
|
|
||||||
local PORT = 8082
|
|
||||||
|
|
||||||
task.delay(3, function()
|
|
||||||
stdio.ewrite("Test did not complete in time\n")
|
|
||||||
task.wait(1)
|
|
||||||
process.exit(1)
|
|
||||||
end)
|
|
||||||
|
|
||||||
local handle = net.serve(PORT, function(request)
|
|
||||||
return ""
|
|
||||||
end)
|
|
||||||
|
|
||||||
-- Require modules same way we did in the async_concurrent and async_sequential tests
|
|
||||||
|
|
||||||
local module3
|
|
||||||
local module4
|
|
||||||
|
|
||||||
task.defer(function()
|
|
||||||
module4 = require("./modules/async")
|
|
||||||
end)
|
|
||||||
|
|
||||||
task.spawn(function()
|
|
||||||
module3 = require("./modules/async")
|
|
||||||
end)
|
|
||||||
|
|
||||||
local _module1 = require("./modules/async")
|
|
||||||
local _module2 = require("./modules/async")
|
|
||||||
|
|
||||||
task.wait(1)
|
|
||||||
|
|
||||||
assert(type(module3) == "table", "Required module3 did not return a table")
|
|
||||||
assert(module3.Foo == "Bar", "Required module3 did not contain correct values")
|
|
||||||
assert(module3.Hello == "World", "Required module3 did not contain correct values")
|
|
||||||
|
|
||||||
assert(type(module4) == "table", "Required module4 did not return a table")
|
|
||||||
assert(module4.Foo == "Bar", "Required module4 did not contain correct values")
|
|
||||||
assert(module4.Hello == "World", "Required module4 did not contain correct values")
|
|
||||||
|
|
||||||
assert(module3 == module4, "Required modules should point to the same return value")
|
|
||||||
|
|
||||||
-- Stop the server and exit successfully
|
|
||||||
|
|
||||||
handle.stop()
|
|
||||||
process.exit(0)
|
|
Loading…
Reference in a new issue