From e3177eeb7569ad5790fc93c032c0093013ee0f24 Mon Sep 17 00:00:00 2001 From: daimond113 <72147841+daimond113@users.noreply.github.com> Date: Tue, 14 Jan 2025 14:33:26 +0100 Subject: [PATCH] fix(engines): store & link engines correctly Fixes issues with how engines were stored which resulted in errors. Also makes outdated linkers get updated. --- src/cli/reporters.rs | 28 +++++----- src/cli/version.rs | 90 +++++++++++++++++++++------------ src/download.rs | 8 +-- src/download_and_link.rs | 4 +- src/engine/source/archive.rs | 37 +++++++------- src/engine/source/github/mod.rs | 7 +-- src/engine/source/mod.rs | 3 +- src/engine/source/traits.rs | 5 +- src/main.rs | 2 +- src/patches.rs | 4 +- src/reporters.rs | 20 ++++---- 11 files changed, 115 insertions(+), 93 deletions(-) diff --git a/src/cli/reporters.rs b/src/cli/reporters.rs index 2cf1ca4..61f3c6e 100644 --- a/src/cli/reporters.rs +++ b/src/cli/reporters.rs @@ -99,31 +99,29 @@ impl CliReporter { } } -pub struct CliDownloadProgressReporter<'a, W> { - root_reporter: &'a CliReporter, +pub struct CliDownloadProgressReporter { + root_reporter: Arc>, name: String, progress: OnceLock, set_progress: Once, } -impl<'a, W: Write + Send + Sync + 'static> DownloadsReporter<'a> for CliReporter { - type DownloadProgressReporter = CliDownloadProgressReporter<'a, W>; +impl DownloadsReporter for CliReporter { + type DownloadProgressReporter = CliDownloadProgressReporter; - fn report_download<'b>(&'a self, name: &'b str) -> Self::DownloadProgressReporter { + fn report_download<'b>(self: Arc, name: String) -> Self::DownloadProgressReporter { self.root_progress.inc_length(1); CliDownloadProgressReporter { root_reporter: self, - name: name.to_string(), + name, progress: OnceLock::new(), set_progress: Once::new(), } } } -impl DownloadProgressReporter - for CliDownloadProgressReporter<'_, W> -{ +impl DownloadProgressReporter for CliDownloadProgressReporter { fn report_start(&self) { let progress = self.root_reporter.multi_progress.add(ProgressBar::new(0)); progress.set_style(self.root_reporter.child_style.clone()); @@ -171,16 +169,16 @@ impl DownloadProgressReporter } } -pub struct CliPatchProgressReporter<'a, W> { - root_reporter: &'a CliReporter, +pub struct CliPatchProgressReporter { + root_reporter: Arc>, name: String, progress: ProgressBar, } -impl<'a, W: Write + Send + Sync + 'static> PatchesReporter<'a> for CliReporter { - type PatchProgressReporter = CliPatchProgressReporter<'a, W>; +impl PatchesReporter for CliReporter { + type PatchProgressReporter = CliPatchProgressReporter; - fn report_patch<'b>(&'a self, name: &'b str) -> Self::PatchProgressReporter { + fn report_patch(self: Arc, name: String) -> Self::PatchProgressReporter { let progress = self.multi_progress.add(ProgressBar::new(0)); progress.set_style(self.child_style.clone()); progress.set_message(format!("- {name}")); @@ -195,7 +193,7 @@ impl<'a, W: Write + Send + Sync + 'static> PatchesReporter<'a> for CliReporter PatchProgressReporter for CliPatchProgressReporter<'_, W> { +impl PatchProgressReporter for CliPatchProgressReporter { fn report_done(&self) { if self.progress.is_hidden() { writeln!( diff --git a/src/cli/version.rs b/src/cli/version.rs index eda55f4..9a7b252 100644 --- a/src/cli/version.rs +++ b/src/cli/version.rs @@ -3,6 +3,7 @@ use crate::cli::{ config::{read_config, write_config, CliConfig}, files::make_executable, home_dir, + reporters::run_with_reporter, }; use anyhow::Context; use colored::Colorize; @@ -15,11 +16,13 @@ use pesde::{ }, EngineKind, }, + reporters::DownloadsReporter, version_matches, }; use semver::{Version, VersionReq}; use std::{ collections::BTreeSet, + env::current_exe, path::{Path, PathBuf}, sync::Arc, }; @@ -159,28 +162,25 @@ pub async fn get_or_download_engine( .await .context("failed to read engines directory")?; - let mut matching_versions = BTreeSet::new(); + let mut installed_versions = BTreeSet::new(); while let Some(entry) = read_dir.next_entry().await? { let path = entry.path(); - #[cfg(windows)] - let version = path.file_stem(); - #[cfg(not(windows))] - let version = path.file_name(); - - let Some(version) = version.and_then(|s| s.to_str()) else { + let Some(version) = path.file_name().and_then(|s| s.to_str()) else { continue; }; if let Ok(version) = Version::parse(version) { - if version_matches(&version, &req) { - matching_versions.insert(version); - } + installed_versions.insert(version); } } - if let Some(version) = matching_versions.pop_last() { + let max_matching = installed_versions + .iter() + .filter(|v| version_matches(v, &req)) + .last(); + if let Some(version) = max_matching { return Ok(path .join(version.to_string()) .join(source.expected_file_name()) @@ -198,40 +198,66 @@ pub async fn get_or_download_engine( .context("failed to resolve versions")?; let (version, engine_ref) = versions.pop_last().context("no matching versions found")?; + let path = path.join(version.to_string()); + + fs::create_dir_all(&path) + .await + .context("failed to create engine container folder")?; + let path = path - .join(version.to_string()) .join(source.expected_file_name()) .with_extension(std::env::consts::EXE_EXTENSION); - let archive = source - .download( - &engine_ref, - &DownloadOptions { - reqwest: reqwest.clone(), - reporter: Arc::new(()), - version, - }, - ) - .await - .context("failed to download engine")?; - let mut file = fs::File::create(&path) .await .context("failed to create new file")?; - tokio::io::copy( - &mut archive - .find_executable(source.expected_file_name()) + + run_with_reporter(|_, root_progress, reporter| async { + let root_progress = root_progress; + + root_progress.set_message("download"); + + let reporter = reporter.report_download(format!("{engine} v{version}")); + + let archive = source + .download( + &engine_ref, + &DownloadOptions { + reqwest: reqwest.clone(), + reporter: Arc::new(reporter), + version: version.clone(), + }, + ) .await - .context("failed to find executable")?, - &mut file, - ) - .await - .context("failed to write to file")?; + .context("failed to download engine")?; + + tokio::io::copy( + &mut archive + .find_executable(source.expected_file_name()) + .await + .context("failed to find executable")?, + &mut file, + ) + .await + .context("failed to write to file")?; + + Ok::<_, anyhow::Error>(()) + }) + .await?; make_executable(&path) .await .context("failed to make downloaded version executable")?; + // replace the executable if there isn't any installed, or the one installed is out of date + if installed_versions.pop_last().is_none_or(|v| version > v) { + replace_bin_exe( + engine, + ¤t_exe().context("failed to get current exe path")?, + ) + .await?; + } + Ok(path) } diff --git a/src/download.rs b/src/download.rs index 3e16935..3ced222 100644 --- a/src/download.rs +++ b/src/download.rs @@ -29,7 +29,7 @@ pub(crate) struct DownloadGraphOptions { impl DownloadGraphOptions where - Reporter: for<'a> DownloadsReporter<'a> + Send + Sync + 'static, + Reporter: DownloadsReporter + Send + Sync + 'static, { /// Creates a new download options with the given reqwest client and reporter. pub(crate) fn new(reqwest: reqwest::Client) -> Self { @@ -85,7 +85,7 @@ impl Project { errors::DownloadGraphError, > where - Reporter: for<'a> DownloadsReporter<'a> + Send + Sync + 'static, + Reporter: DownloadsReporter + Send + Sync + 'static, { let DownloadGraphOptions { reqwest, @@ -111,8 +111,8 @@ impl Project { async move { let progress_reporter = reporter - .as_deref() - .map(|reporter| reporter.report_download(&package_id.to_string())); + .clone() + .map(|reporter| reporter.report_download(package_id.to_string())); let _permit = semaphore.acquire().await; diff --git a/src/download_and_link.rs b/src/download_and_link.rs index b8c74e5..d799297 100644 --- a/src/download_and_link.rs +++ b/src/download_and_link.rs @@ -81,7 +81,7 @@ pub struct DownloadAndLinkOptions { impl DownloadAndLinkOptions where - Reporter: for<'a> DownloadsReporter<'a> + Send + Sync + 'static, + Reporter: DownloadsReporter + Send + Sync + 'static, Hooks: DownloadAndLinkHooks + Send + Sync + 'static, { /// Creates a new download options with the given reqwest client and reporter. @@ -149,7 +149,7 @@ impl Project { options: DownloadAndLinkOptions, ) -> Result> where - Reporter: for<'a> DownloadsReporter<'a> + 'static, + Reporter: DownloadsReporter + 'static, Hooks: DownloadAndLinkHooks + 'static, { let DownloadAndLinkOptions { diff --git a/src/engine/source/archive.rs b/src/engine/source/archive.rs index f73147c..28a2476 100644 --- a/src/engine/source/archive.rs +++ b/src/engine/source/archive.rs @@ -53,22 +53,22 @@ impl FromStr for ArchiveInfo { } } +pub(crate) type ArchiveReader = Pin>; + /// An archive -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct Archive { +pub struct Archive { pub(crate) info: ArchiveInfo, - pub(crate) reader: R, + pub(crate) reader: ArchiveReader, } -#[derive(Debug)] -enum TarReader { - Gzip(async_compression::tokio::bufread::GzipDecoder), - Plain(R), +enum TarReader { + Gzip(async_compression::tokio::bufread::GzipDecoder), + Plain(ArchiveReader), } // TODO: try to see if we can avoid the unsafe blocks -impl AsyncRead for TarReader { +impl AsyncRead for TarReader { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -83,8 +83,8 @@ impl AsyncRead for TarReader { } } -enum ArchiveEntryInner { - Tar(tokio_tar::Entry>>>>), +enum ArchiveEntryInner { + Tar(tokio_tar::Entry>), Zip { archive: *mut async_zip::tokio::read::seek::ZipFileReader>>, reader: ManuallyDrop< @@ -99,7 +99,7 @@ enum ArchiveEntryInner { }, } -impl Drop for ArchiveEntryInner { +impl Drop for ArchiveEntryInner { fn drop(&mut self) { match self { Self::Tar(_) => {} @@ -112,9 +112,9 @@ impl Drop for ArchiveEntryInner { } /// An entry in an archive. Usually the executable -pub struct ArchiveEntry(ArchiveEntryInner); +pub struct ArchiveEntry(ArchiveEntryInner); -impl AsyncRead for ArchiveEntry { +impl AsyncRead for ArchiveEntry { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -131,12 +131,12 @@ impl AsyncRead for ArchiveEntry { } } -impl Archive { +impl Archive { /// Finds the executable in the archive and returns it as an [`ArchiveEntry`] pub async fn find_executable( self, expected_file_name: &str, - ) -> Result, errors::FindExecutableError> { + ) -> Result { #[derive(Debug, PartialEq, Eq)] struct Candidate { path: PathBuf, @@ -188,10 +188,11 @@ impl Archive { ArchiveInfo(ArchiveKind::Tar, encoding) => { use async_compression::tokio::bufread as decoders; - let reader = Box::pin(self.reader); let reader = match encoding { - Some(EncodingKind::Gzip) => TarReader::Gzip(decoders::GzipDecoder::new(reader)), - None => TarReader::Plain(reader), + Some(EncodingKind::Gzip) => { + TarReader::Gzip(decoders::GzipDecoder::new(self.reader)) + } + None => TarReader::Plain(self.reader), }; let mut archive = tokio_tar::Archive::new(reader); diff --git a/src/engine/source/github/mod.rs b/src/engine/source/github/mod.rs index 95a31a2..7663d3c 100644 --- a/src/engine/source/github/mod.rs +++ b/src/engine/source/github/mod.rs @@ -13,7 +13,6 @@ use crate::{ use reqwest::header::ACCEPT; use semver::{Version, VersionReq}; use std::{collections::BTreeMap, path::PathBuf}; -use tokio::io::AsyncBufRead; /// The GitHub engine source #[derive(Debug, Eq, PartialEq, Hash, Clone)] @@ -73,7 +72,7 @@ impl EngineSource for GitHubEngineSource { &self, engine_ref: &Self::Ref, options: &DownloadOptions, - ) -> Result, Self::DownloadError> { + ) -> Result { let DownloadOptions { reqwest, reporter, @@ -91,6 +90,8 @@ impl EngineSource for GitHubEngineSource { .find(|asset| asset.name.eq_ignore_ascii_case(&desired_asset_name)) .ok_or(errors::DownloadError::AssetNotFound)?; + reporter.report_start(); + let response = reqwest .get(asset.url.clone()) .header(ACCEPT, "application/octet-stream") @@ -100,7 +101,7 @@ impl EngineSource for GitHubEngineSource { Ok(Archive { info: asset.name.parse()?, - reader: response_to_async_read(response, reporter.clone()), + reader: Box::pin(response_to_async_read(response, reporter.clone())), }) } } diff --git a/src/engine/source/mod.rs b/src/engine/source/mod.rs index 5d26501..b8a8a67 100644 --- a/src/engine/source/mod.rs +++ b/src/engine/source/mod.rs @@ -7,7 +7,6 @@ use crate::{ }; use semver::{Version, VersionReq}; use std::{collections::BTreeMap, path::PathBuf}; -use tokio::io::AsyncBufRead; /// Archives pub mod archive; @@ -69,7 +68,7 @@ impl EngineSource for EngineSources { &self, engine_ref: &Self::Ref, options: &DownloadOptions, - ) -> Result, Self::DownloadError> { + ) -> Result { match (self, engine_ref) { (EngineSources::GitHub(source), EngineRefs::GitHub(release)) => { source.download(release, options).await.map_err(Into::into) diff --git a/src/engine/source/traits.rs b/src/engine/source/traits.rs index 6457486..a06787b 100644 --- a/src/engine/source/traits.rs +++ b/src/engine/source/traits.rs @@ -1,7 +1,6 @@ use crate::{engine::source::archive::Archive, reporters::DownloadProgressReporter}; use semver::{Version, VersionReq}; use std::{collections::BTreeMap, fmt::Debug, future::Future, path::PathBuf, sync::Arc}; -use tokio::io::AsyncBufRead; /// Options for resolving an engine #[derive(Debug, Clone)] @@ -48,7 +47,5 @@ pub trait EngineSource: Debug { &self, engine_ref: &Self::Ref, options: &DownloadOptions, - ) -> impl Future, Self::DownloadError>> - + Send - + Sync; + ) -> impl Future> + Send + Sync; } diff --git a/src/main.rs b/src/main.rs index 7aa2c61..ad1d541 100644 --- a/src/main.rs +++ b/src/main.rs @@ -299,7 +299,7 @@ async fn run() -> anyhow::Result<()> { let exe_path = get_or_download_engine(&reqwest, engine, req.unwrap_or(VersionReq::STAR)).await?; if exe_path == current_exe { - break 'engines; + anyhow::bail!("engine linker executed by itself") } let status = std::process::Command::new(exe_path) diff --git a/src/patches.rs b/src/patches.rs index df4d8b2..7aab8ac 100644 --- a/src/patches.rs +++ b/src/patches.rs @@ -84,7 +84,7 @@ impl Project { reporter: Arc, ) -> Result<(), errors::ApplyPatchesError> where - Reporter: for<'a> PatchesReporter<'a> + Send + Sync + 'static, + Reporter: PatchesReporter + Send + Sync + 'static, { let manifest = self.deser_manifest().await?; @@ -112,7 +112,7 @@ impl Project { async move { tracing::debug!("applying patch"); - let progress_reporter = reporter.report_patch(&package_id.to_string()); + let progress_reporter = reporter.report_patch(package_id.to_string()); let patch = fs::read(&patch_path) .await diff --git a/src/reporters.rs b/src/reporters.rs index ef87c4e..1305f4e 100644 --- a/src/reporters.rs +++ b/src/reporters.rs @@ -15,17 +15,17 @@ use std::sync::Arc; use tokio::io::AsyncBufRead; /// Reports downloads. -pub trait DownloadsReporter<'a>: Send + Sync { +pub trait DownloadsReporter: Send + Sync { /// The [`DownloadProgressReporter`] type associated with this reporter. - type DownloadProgressReporter: DownloadProgressReporter + 'a; + type DownloadProgressReporter: DownloadProgressReporter + 'static; /// Starts a new download. - fn report_download<'b>(&'a self, name: &'b str) -> Self::DownloadProgressReporter; + fn report_download(self: Arc, name: String) -> Self::DownloadProgressReporter; } -impl DownloadsReporter<'_> for () { +impl DownloadsReporter for () { type DownloadProgressReporter = (); - fn report_download(&self, name: &str) -> Self::DownloadProgressReporter {} + fn report_download(self: Arc, name: String) -> Self::DownloadProgressReporter {} } /// Reports the progress of a single download. @@ -46,17 +46,17 @@ pub trait DownloadProgressReporter: Send + Sync { impl DownloadProgressReporter for () {} /// Reports the progress of applying patches. -pub trait PatchesReporter<'a>: Send + Sync { +pub trait PatchesReporter: Send + Sync { /// The [`PatchProgressReporter`] type associated with this reporter. - type PatchProgressReporter: PatchProgressReporter + 'a; + type PatchProgressReporter: PatchProgressReporter + 'static; /// Starts a new patch. - fn report_patch<'b>(&'a self, name: &'b str) -> Self::PatchProgressReporter; + fn report_patch(self: Arc, name: String) -> Self::PatchProgressReporter; } -impl PatchesReporter<'_> for () { +impl PatchesReporter for () { type PatchProgressReporter = (); - fn report_patch(&self, name: &str) -> Self::PatchProgressReporter {} + fn report_patch(self: Arc, name: String) -> Self::PatchProgressReporter {} } /// Reports the progress of a single patch.