Use mlua trait for mapping into lua errors

This commit is contained in:
Filip Tibell 2023-08-05 17:12:25 -05:00
parent cacaa97b6e
commit 483713e635
No known key found for this signature in database
13 changed files with 45 additions and 55 deletions

View file

@ -26,14 +26,14 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
} }
async fn fs_read_file(lua: &'static Lua, path: String) -> LuaResult<LuaString> { async fn fs_read_file(lua: &'static Lua, path: String) -> LuaResult<LuaString> {
let bytes = fs::read(&path).await.map_err(LuaError::external)?; let bytes = fs::read(&path).await.into_lua_err()?;
lua.create_string(bytes) lua.create_string(bytes)
} }
async fn fs_read_dir(_: &'static Lua, path: String) -> LuaResult<Vec<String>> { async fn fs_read_dir(_: &'static Lua, path: String) -> LuaResult<Vec<String>> {
let mut dir_strings = Vec::new(); let mut dir_strings = Vec::new();
let mut dir = fs::read_dir(&path).await.map_err(LuaError::external)?; let mut dir = fs::read_dir(&path).await.into_lua_err()?;
while let Some(dir_entry) = dir.next_entry().await.map_err(LuaError::external)? { while let Some(dir_entry) = dir.next_entry().await.into_lua_err()? {
if let Some(dir_path_str) = dir_entry.path().to_str() { if let Some(dir_path_str) = dir_entry.path().to_str() {
dir_strings.push(dir_path_str.to_owned()); dir_strings.push(dir_path_str.to_owned());
} else { } else {
@ -63,21 +63,19 @@ async fn fs_write_file(
_: &'static Lua, _: &'static Lua,
(path, contents): (String, LuaString<'_>), (path, contents): (String, LuaString<'_>),
) -> LuaResult<()> { ) -> LuaResult<()> {
fs::write(&path, &contents.as_bytes()) fs::write(&path, &contents.as_bytes()).await.into_lua_err()
.await
.map_err(LuaError::external)
} }
async fn fs_write_dir(_: &'static Lua, path: String) -> LuaResult<()> { async fn fs_write_dir(_: &'static Lua, path: String) -> LuaResult<()> {
fs::create_dir_all(&path).await.map_err(LuaError::external) fs::create_dir_all(&path).await.into_lua_err()
} }
async fn fs_remove_file(_: &'static Lua, path: String) -> LuaResult<()> { async fn fs_remove_file(_: &'static Lua, path: String) -> LuaResult<()> {
fs::remove_file(&path).await.map_err(LuaError::external) fs::remove_file(&path).await.into_lua_err()
} }
async fn fs_remove_dir(_: &'static Lua, path: String) -> LuaResult<()> { async fn fs_remove_dir(_: &'static Lua, path: String) -> LuaResult<()> {
fs::remove_dir_all(&path).await.map_err(LuaError::external) fs::remove_dir_all(&path).await.into_lua_err()
} }
async fn fs_metadata(_: &'static Lua, path: String) -> LuaResult<FsMetadata> { async fn fs_metadata(_: &'static Lua, path: String) -> LuaResult<FsMetadata> {
@ -122,9 +120,7 @@ async fn fs_move(
path_to.display() path_to.display()
))); )));
} }
fs::rename(path_from, path_to) fs::rename(path_from, path_to).await.into_lua_err()?;
.await
.map_err(LuaError::external)?;
Ok(()) Ok(())
} }

View file

@ -73,7 +73,7 @@ async fn net_request<'a>(lua: &'static Lua, config: RequestConfig<'a>) -> LuaRes
.body(config.body.unwrap_or_default()) .body(config.body.unwrap_or_default())
.send() .send()
.await .await
.map_err(LuaError::external)?; .into_lua_err()?;
// Extract status, headers // Extract status, headers
let res_status = res.status().as_u16(); let res_status = res.status().as_u16();
let res_status_text = res.status().canonical_reason(); let res_status_text = res.status().canonical_reason();
@ -88,7 +88,7 @@ async fn net_request<'a>(lua: &'static Lua, config: RequestConfig<'a>) -> LuaRes
}) })
.collect::<HashMap<String, String>>(); .collect::<HashMap<String, String>>();
// Read response bytes // Read response bytes
let mut res_bytes = res.bytes().await.map_err(LuaError::external)?.to_vec(); let mut res_bytes = res.bytes().await.into_lua_err()?.to_vec();
// Check for extra options, decompression // Check for extra options, decompression
if config.options.decompress { if config.options.decompress {
// NOTE: Header names are guaranteed to be lowercase because of the above // NOTE: Header names are guaranteed to be lowercase because of the above
@ -120,9 +120,7 @@ async fn net_request<'a>(lua: &'static Lua, config: RequestConfig<'a>) -> LuaRes
} }
async fn net_socket<'a>(lua: &'static Lua, url: String) -> LuaResult<LuaTable> { async fn net_socket<'a>(lua: &'static Lua, url: String) -> LuaResult<LuaTable> {
let (ws, _) = tokio_tungstenite::connect_async(url) let (ws, _) = tokio_tungstenite::connect_async(url).await.into_lua_err()?;
.await
.map_err(LuaError::external)?;
NetWebSocket::new(ws).into_lua_table(lua) NetWebSocket::new(ws).into_lua_table(lua)
} }

View file

@ -41,7 +41,7 @@ async fn deserialize_place<'lua>(
let data_model = doc.into_data_model_instance()?; let data_model = doc.into_data_model_instance()?;
Ok::<_, DocumentError>(data_model) Ok::<_, DocumentError>(data_model)
}); });
fut.await.map_err(LuaError::external)??.into_lua(lua) fut.await.into_lua_err()??.into_lua(lua)
} }
async fn deserialize_model<'lua>( async fn deserialize_model<'lua>(
@ -54,7 +54,7 @@ async fn deserialize_model<'lua>(
let instance_array = doc.into_instance_array()?; let instance_array = doc.into_instance_array()?;
Ok::<_, DocumentError>(instance_array) Ok::<_, DocumentError>(instance_array)
}); });
fut.await.map_err(LuaError::external)??.into_lua(lua) fut.await.into_lua_err()??.into_lua(lua)
} }
async fn serialize_place<'lua>( async fn serialize_place<'lua>(
@ -70,7 +70,7 @@ async fn serialize_place<'lua>(
})?; })?;
Ok::<_, DocumentError>(bytes) Ok::<_, DocumentError>(bytes)
}); });
let bytes = fut.await.map_err(LuaError::external)??; let bytes = fut.await.into_lua_err()??;
lua.create_string(bytes) lua.create_string(bytes)
} }
@ -87,7 +87,7 @@ async fn serialize_model<'lua>(
})?; })?;
Ok::<_, DocumentError>(bytes) Ok::<_, DocumentError>(bytes)
}); });
let bytes = fut.await.map_err(LuaError::external)??; let bytes = fut.await.into_lua_err()??;
lua.create_string(bytes) lua.create_string(bytes)
} }

View file

@ -43,7 +43,7 @@ pub fn create(lua: &'static Lua) -> LuaResult<LuaTable> {
.with_async_function("prompt", |_, options: PromptOptions| async move { .with_async_function("prompt", |_, options: PromptOptions| async move {
task::spawn_blocking(move || prompt(options)) task::spawn_blocking(move || prompt(options))
.await .await
.map_err(LuaError::external)? .into_lua_err()?
})? })?
.build_readonly() .build_readonly()
} }

View file

@ -202,7 +202,7 @@ async fn load_file<'lua>(
} }
// Try to read the wanted file, note that we use bytes instead of reading // Try to read the wanted file, note that we use bytes instead of reading
// to a string since lua scripts are not necessarily valid utf-8 strings // to a string since lua scripts are not necessarily valid utf-8 strings
let contents = fs::read(&absolute_path).await.map_err(LuaError::external)?; let contents = fs::read(&absolute_path).await.into_lua_err()?;
// Use a name without extensions for loading the chunk, some // Use a name without extensions for loading the chunk, some
// other code assumes the require path is without extensions // other code assumes the require path is without extensions
let path_relative_no_extension = relative_path let path_relative_no_extension = relative_path

View file

@ -23,8 +23,8 @@ impl NetClientBuilder {
{ {
let mut map = HeaderMap::new(); let mut map = HeaderMap::new();
for (key, val) in headers { for (key, val) in headers {
let hkey = HeaderName::from_str(key.as_ref()).map_err(LuaError::external)?; let hkey = HeaderName::from_str(key.as_ref()).into_lua_err()?;
let hval = HeaderValue::from_bytes(val.as_ref()).map_err(LuaError::external)?; let hval = HeaderValue::from_bytes(val.as_ref()).into_lua_err()?;
map.insert(hkey, hval); map.insert(hkey, hval);
} }
self.builder = self.builder.default_headers(map); self.builder = self.builder.default_headers(map);
@ -32,7 +32,7 @@ impl NetClientBuilder {
} }
pub fn build(self) -> LuaResult<NetClient> { pub fn build(self) -> LuaResult<NetClient> {
let client = self.builder.build().map_err(LuaError::external)?; let client = self.builder.build().into_lua_err()?;
Ok(NetClient(client)) Ok(NetClient(client))
} }
} }

View file

@ -24,7 +24,7 @@ impl NetServeResponse {
.status(200) .status(200)
.header("Content-Type", "text/plain") .header("Content-Type", "text/plain")
.body(Body::from(self.body.unwrap())) .body(Body::from(self.body.unwrap()))
.map_err(LuaError::external)?, .into_lua_err()?,
NetServeResponseKind::Table => { NetServeResponseKind::Table => {
let mut response = Response::builder(); let mut response = Response::builder();
for (key, value) in self.headers { for (key, value) in self.headers {
@ -33,7 +33,7 @@ impl NetServeResponse {
response response
.status(self.status) .status(self.status)
.body(Body::from(self.body.unwrap_or_default())) .body(Body::from(self.body.unwrap_or_default()))
.map_err(LuaError::external)? .into_lua_err()?
} }
}) })
} }

View file

@ -57,7 +57,7 @@ impl Service<Request<Body>> for NetServiceInner {
task::spawn_local(async move { task::spawn_local(async move {
// Create our new full websocket object, then // Create our new full websocket object, then
// schedule our handler to get called asap // schedule our handler to get called asap
let ws = ws.await.map_err(LuaError::external)?; let ws = ws.await.into_lua_err()?;
let sock = NetWebSocket::new(ws).into_lua_table(lua)?; let sock = NetWebSocket::new(ws).into_lua_table(lua)?;
let sched = lua let sched = lua
.app_data_ref::<&TaskScheduler>() .app_data_ref::<&TaskScheduler>()
@ -77,7 +77,7 @@ impl Service<Request<Body>> for NetServiceInner {
let (parts, body) = req.into_parts(); let (parts, body) = req.into_parts();
Box::pin(async move { Box::pin(async move {
// Convert request body into bytes, extract handler // Convert request body into bytes, extract handler
let bytes = to_bytes(body).await.map_err(LuaError::external)?; let bytes = to_bytes(body).await.into_lua_err()?;
let handler: LuaFunction = lua.registry_value(&key)?; let handler: LuaFunction = lua.registry_value(&key)?;
// Create a readonly table for the request query params // Create a readonly table for the request query params
let query_params = TableBuilder::new(lua)? let query_params = TableBuilder::new(lua)?

View file

@ -167,10 +167,10 @@ where
reason: "".into(), reason: "".into(),
}))) })))
.await .await
.map_err(LuaError::external)?; .into_lua_err()?;
let res = ws.close(); let res = ws.close();
res.await.map_err(LuaError::external) res.await.into_lua_err()
} }
async fn send<'lua, T>( async fn send<'lua, T>(
@ -187,11 +187,11 @@ where
let msg = if matches!(as_binary, Some(true)) { let msg = if matches!(as_binary, Some(true)) {
WsMessage::Binary(string.as_bytes().to_vec()) WsMessage::Binary(string.as_bytes().to_vec())
} else { } else {
let s = string.to_str().map_err(LuaError::external)?; let s = string.to_str().into_lua_err()?;
WsMessage::Text(s.to_string()) WsMessage::Text(s.to_string())
}; };
let mut ws = socket.write_stream.lock().await; let mut ws = socket.write_stream.lock().await;
ws.send(msg).await.map_err(LuaError::external) ws.send(msg).await.into_lua_err()
} }
async fn next<'lua, T>( async fn next<'lua, T>(
@ -202,7 +202,7 @@ where
T: AsyncRead + AsyncWrite + Unpin, T: AsyncRead + AsyncWrite + Unpin,
{ {
let mut ws = socket.read_stream.lock().await; let mut ws = socket.read_stream.lock().await;
let item = ws.next().await.transpose().map_err(LuaError::external); let item = ws.next().await.transpose().into_lua_err();
let msg = match item { let msg = match item {
Ok(Some(WsMessage::Close(msg))) => { Ok(Some(WsMessage::Close(msg))) => {
if let Some(msg) = &msg { if let Some(msg) = &msg {

View file

@ -24,9 +24,7 @@ pub async fn pipe_and_inherit_child_process_stdio(
let mut stdout = io::stdout(); let mut stdout = io::stdout();
let mut tee = AsyncTeeWriter::new(&mut stdout); let mut tee = AsyncTeeWriter::new(&mut stdout);
io::copy(&mut child_stdout, &mut tee) io::copy(&mut child_stdout, &mut tee).await.into_lua_err()?;
.await
.map_err(LuaError::external)?;
Ok::<_, LuaError>(tee.into_vec()) Ok::<_, LuaError>(tee.into_vec())
}); });
@ -35,9 +33,7 @@ pub async fn pipe_and_inherit_child_process_stdio(
let mut stderr = io::stderr(); let mut stderr = io::stderr();
let mut tee = AsyncTeeWriter::new(&mut stderr); let mut tee = AsyncTeeWriter::new(&mut stderr);
io::copy(&mut child_stderr, &mut tee) io::copy(&mut child_stderr, &mut tee).await.into_lua_err()?;
.await
.map_err(LuaError::external)?;
Ok::<_, LuaError>(tee.into_vec()) Ok::<_, LuaError>(tee.into_vec())
}); });

View file

@ -102,7 +102,7 @@ pub async fn compress<'lua>(
let source = source.as_ref().to_vec(); let source = source.as_ref().to_vec();
return task::spawn_blocking(move || compress_prepend_size(&source)) return task::spawn_blocking(move || compress_prepend_size(&source))
.await .await
.map_err(LuaError::external); .into_lua_err();
} }
let mut bytes = Vec::new(); let mut bytes = Vec::new();
@ -135,8 +135,8 @@ pub async fn decompress<'lua>(
let source = source.as_ref().to_vec(); let source = source.as_ref().to_vec();
return task::spawn_blocking(move || decompress_size_prepended(&source)) return task::spawn_blocking(move || decompress_size_prepended(&source))
.await .await
.map_err(LuaError::external)? .into_lua_err()?
.map_err(LuaError::external); .into_lua_err();
} }
let mut bytes = Vec::new(); let mut bytes = Vec::new();

View file

@ -61,23 +61,23 @@ impl EncodeDecodeConfig {
EncodeDecodeFormat::Json => { EncodeDecodeFormat::Json => {
let serialized: JsonValue = lua.from_value_with(value, LUA_DESERIALIZE_OPTIONS)?; let serialized: JsonValue = lua.from_value_with(value, LUA_DESERIALIZE_OPTIONS)?;
if self.pretty { if self.pretty {
serde_json::to_vec_pretty(&serialized).map_err(LuaError::external)? serde_json::to_vec_pretty(&serialized).into_lua_err()?
} else { } else {
serde_json::to_vec(&serialized).map_err(LuaError::external)? serde_json::to_vec(&serialized).into_lua_err()?
} }
} }
EncodeDecodeFormat::Yaml => { EncodeDecodeFormat::Yaml => {
let serialized: YamlValue = lua.from_value_with(value, LUA_DESERIALIZE_OPTIONS)?; let serialized: YamlValue = lua.from_value_with(value, LUA_DESERIALIZE_OPTIONS)?;
let mut writer = Vec::with_capacity(128); let mut writer = Vec::with_capacity(128);
serde_yaml::to_writer(&mut writer, &serialized).map_err(LuaError::external)?; serde_yaml::to_writer(&mut writer, &serialized).into_lua_err()?;
writer writer
} }
EncodeDecodeFormat::Toml => { EncodeDecodeFormat::Toml => {
let serialized: TomlValue = lua.from_value_with(value, LUA_DESERIALIZE_OPTIONS)?; let serialized: TomlValue = lua.from_value_with(value, LUA_DESERIALIZE_OPTIONS)?;
let s = if self.pretty { let s = if self.pretty {
toml::to_string_pretty(&serialized).map_err(LuaError::external)? toml::to_string_pretty(&serialized).into_lua_err()?
} else { } else {
toml::to_string(&serialized).map_err(LuaError::external)? toml::to_string(&serialized).into_lua_err()?
}; };
s.as_bytes().to_vec() s.as_bytes().to_vec()
} }
@ -93,16 +93,16 @@ impl EncodeDecodeConfig {
let bytes = string.as_bytes(); let bytes = string.as_bytes();
match self.format { match self.format {
EncodeDecodeFormat::Json => { EncodeDecodeFormat::Json => {
let value: JsonValue = serde_json::from_slice(bytes).map_err(LuaError::external)?; let value: JsonValue = serde_json::from_slice(bytes).into_lua_err()?;
lua.to_value_with(&value, LUA_SERIALIZE_OPTIONS) lua.to_value_with(&value, LUA_SERIALIZE_OPTIONS)
} }
EncodeDecodeFormat::Yaml => { EncodeDecodeFormat::Yaml => {
let value: YamlValue = serde_yaml::from_slice(bytes).map_err(LuaError::external)?; let value: YamlValue = serde_yaml::from_slice(bytes).into_lua_err()?;
lua.to_value_with(&value, LUA_SERIALIZE_OPTIONS) lua.to_value_with(&value, LUA_SERIALIZE_OPTIONS)
} }
EncodeDecodeFormat::Toml => { EncodeDecodeFormat::Toml => {
if let Ok(s) = string.to_str() { if let Ok(s) = string.to_str() {
let value: TomlValue = toml::from_str(s).map_err(LuaError::external)?; let value: TomlValue = toml::from_str(s).into_lua_err()?;
lua.to_value_with(&value, LUA_SERIALIZE_OPTIONS) lua.to_value_with(&value, LUA_SERIALIZE_OPTIONS)
} else { } else {
Err(LuaError::RuntimeError( Err(LuaError::RuntimeError(

View file

@ -197,12 +197,12 @@ pub fn pretty_format_multi_value(multi: &LuaMultiValue) -> LuaResult<String> {
for value in multi { for value in multi {
counter += 1; counter += 1;
if let LuaValue::String(s) = value { if let LuaValue::String(s) = value {
write!(buffer, "{}", s.to_string_lossy()).map_err(LuaError::external)?; write!(buffer, "{}", s.to_string_lossy()).into_lua_err()?;
} else { } else {
pretty_format_value(&mut buffer, value, 0).map_err(LuaError::external)?; pretty_format_value(&mut buffer, value, 0).into_lua_err()?;
} }
if counter < multi.len() { if counter < multi.len() {
write!(&mut buffer, " ").map_err(LuaError::external)?; write!(&mut buffer, " ").into_lua_err()?;
} }
} }
Ok(buffer) Ok(buffer)