From a41d9950f81c48c20a1f1003acece41ee46db2f2 Mon Sep 17 00:00:00 2001 From: Luka <47296785+lukadev-0@users.noreply.github.com> Date: Fri, 27 Dec 2024 22:04:47 +0100 Subject: [PATCH] feat: better install (#17) * feat: better install * feat: support progress reporting for wally * chore: remove tracing-indicatif * chore: fix Cargo.toml * fix: indentation in bin link script * fix: spinner tick chars * feat: change progress message color * fix: remove pretty from fmt_layer Co-authored-by: dai <72147841+daimond113@users.noreply.github.com> * style: format code --------- Co-authored-by: dai <72147841+daimond113@users.noreply.github.com> --- Cargo.lock | 68 ++--- Cargo.toml | 6 +- src/cli/commands/execute.rs | 206 ++++++++------- src/cli/commands/install.rs | 320 ++---------------------- src/cli/commands/patch.rs | 4 +- src/cli/commands/update.rs | 104 +++----- src/cli/install.rs | 482 ++++++++++++++++++++++++++++++++++++ src/cli/mod.rs | 36 +-- src/cli/reporters.rs | 213 ++++++++++++++++ src/download.rs | 229 +++++++++++------ src/download_and_link.rs | 322 ++++++++++++++++-------- src/lib.rs | 5 +- src/main.rs | 56 ++++- src/manifest/mod.rs | 2 +- src/patches.rs | 140 ++++------- src/reporters.rs | 63 +++++ src/source/git/mod.rs | 2 + src/source/mod.rs | 11 +- src/source/pesde/mod.rs | 27 +- src/source/traits.rs | 3 + src/source/wally/mod.rs | 42 +++- src/source/workspace/mod.rs | 7 +- 22 files changed, 1511 insertions(+), 837 deletions(-) create mode 100644 src/cli/install.rs create mode 100644 src/cli/reporters.rs create mode 100644 src/reporters.rs diff --git a/Cargo.lock b/Cargo.lock index d6193dc..5d69f6a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -357,12 +357,6 @@ version = "1.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" -[[package]] -name = "arrayvec" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" - [[package]] name = "async-broadcast" version = "0.7.1" @@ -2880,7 +2874,6 @@ dependencies = [ "number_prefix", "portable-atomic", "unicode-width 0.2.0", - "vt100", "web-time", ] @@ -3697,7 +3690,6 @@ dependencies = [ "toml", "toml_edit", "tracing", - "tracing-indicatif", "tracing-subscriber", "url", "wax", @@ -4147,10 +4139,12 @@ dependencies = [ "tokio", "tokio-native-tls", "tokio-rustls", + "tokio-util", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "webpki-roots", "windows-registry", @@ -5265,18 +5259,6 @@ dependencies = [ "valuable", ] -[[package]] -name = "tracing-indicatif" -version = "0.3.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74ba258e9de86447f75edf6455fded8e5242704c6fccffe7bf8d7fb6daef1180" -dependencies = [ - "indicatif", - "tracing", - "tracing-core", - "tracing-subscriber", -] - [[package]] name = "tracing-log" version = "0.2.0" @@ -5474,39 +5456,6 @@ version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" -[[package]] -name = "vt100" -version = "0.15.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "84cd863bf0db7e392ba3bd04994be3473491b31e66340672af5d11943c6274de" -dependencies = [ - "itoa", - "log", - "unicode-width 0.1.14", - "vte", -] - -[[package]] -name = "vte" -version = "0.11.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5022b5fbf9407086c180e9557be968742d839e68346af7792b8592489732197" -dependencies = [ - "arrayvec", - "utf8parse", - "vte_generate_state_changes", -] - -[[package]] -name = "vte_generate_state_changes" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e369bee1b05d510a7b4ed645f5faa90619e05437111783ea5848f28d97d3c2e" -dependencies = [ - "proc-macro2", - "quote", -] - [[package]] name = "walkdir" version = "2.5.0" @@ -5599,6 +5548,19 @@ version = "0.2.99" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "943aab3fdaaa029a6e0271b35ea10b72b943135afe9bffca82384098ad0e06a6" +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "wax" version = "0.6.0" diff --git a/Cargo.toml b/Cargo.toml index 96831b6..02cef76 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,6 @@ bin = [ "dep:tracing-subscriber", "reqwest/json", "dep:indicatif", - "dep:tracing-indicatif", "dep:inquire", "dep:toml_edit", "dep:colored", @@ -49,7 +48,7 @@ toml = "0.8.19" serde_with = "3.11.0" gix = { version = "0.68.0", default-features = false, features = ["blocking-http-transport-reqwest-rust-tls", "revparse-regex", "credentials", "parallel"] } semver = { version = "1.0.24", features = ["serde"] } -reqwest = { version = "0.12.9", default-features = false, features = ["rustls-tls"] } +reqwest = { version = "0.12.9", default-features = false, features = ["rustls-tls", "stream"] } tokio-tar = "0.3.1" async-compression = { version = "0.4.18", features = ["tokio", "gzip"] } pathdiff = "0.2.3" @@ -83,7 +82,6 @@ clap = { version = "4.5.23", features = ["derive"], optional = true } dirs = { version = "5.0.1", optional = true } tracing-subscriber = { version = "0.3.19", features = ["env-filter"], optional = true } indicatif = { version = "0.17.9", optional = true } -tracing-indicatif = { version = "0.3.8", optional = true } inquire = { version = "0.7.5", optional = true } [target.'cfg(target_os = "windows")'.dependencies] @@ -104,4 +102,4 @@ codegen-units = 1 [profile.release.package.pesde-registry] # add debug symbols for Sentry stack traces -debug = "full" \ No newline at end of file +debug = "full" diff --git a/src/cli/commands/execute.rs b/src/cli/commands/execute.rs index a613324..6dab97d 100644 --- a/src/cli/commands/execute.rs +++ b/src/cli/commands/execute.rs @@ -1,8 +1,15 @@ -use crate::cli::{config::read_config, progress_bar, VersionedPackageName}; +use crate::cli::{ + config::read_config, + reporters::{self, CliReporter}, + VersionedPackageName, +}; use anyhow::Context; use clap::Args; +use colored::Colorize; use fs_err::tokio as fs; +use indicatif::MultiProgress; use pesde::{ + download_and_link::DownloadAndLinkOptions, linking::generator::generate_bin_linking_module, manifest::target::TargetKind, names::PackageName, @@ -14,7 +21,12 @@ use pesde::{ }; use semver::VersionReq; use std::{ - collections::HashSet, env::current_dir, ffi::OsString, io::Write, process::Command, sync::Arc, + collections::HashSet, + env::current_dir, + ffi::OsString, + io::{Stderr, Write}, + process::Command, + sync::Arc, }; use tokio::sync::Mutex; @@ -35,109 +47,123 @@ pub struct ExecuteCommand { impl ExecuteCommand { pub async fn run(self, project: Project, reqwest: reqwest::Client) -> anyhow::Result<()> { - let index = match self.index { - Some(index) => Some(index), - None => read_config().await.ok().map(|c| c.default_index), - } - .context("no index specified")?; - let source = PesdePackageSource::new(index); - source - .refresh(&project) - .await - .context("failed to refresh source")?; + let multi_progress = MultiProgress::new(); + crate::PROGRESS_BARS + .lock() + .unwrap() + .replace(multi_progress.clone()); - let version_req = self.package.1.unwrap_or(VersionReq::STAR); - let Some((version, pkg_ref)) = ('finder: { - let specifier = PesdeDependencySpecifier { - name: self.package.0.clone(), - version: version_req.clone(), - index: None, - target: None, - }; + let (tempdir, bin_path) = reporters::run_with_reporter_and_writer( + std::io::stderr(), + |multi_progress, root_progress, reporter| async { + let multi_progress = multi_progress; + let root_progress = root_progress; - if let Some(res) = source - .resolve(&specifier, &project, TargetKind::Lune, &mut HashSet::new()) - .await - .context("failed to resolve package")? - .1 - .pop_last() - { - break 'finder Some(res); - } + root_progress.set_message("resolve"); - source - .resolve(&specifier, &project, TargetKind::Luau, &mut HashSet::new()) - .await - .context("failed to resolve package")? - .1 - .pop_last() - }) else { - anyhow::bail!( - "no Lune or Luau package could be found for {}@{version_req}", - self.package.0, - ); - }; + let index = match self.index { + Some(index) => Some(index), + None => read_config().await.ok().map(|c| c.default_index), + } + .context("no index specified")?; + let source = PesdePackageSource::new(index); + source + .refresh(&project) + .await + .context("failed to refresh source")?; - println!("using {}@{version}", pkg_ref.name); + let version_req = self.package.1.unwrap_or(VersionReq::STAR); + let Some((version, pkg_ref)) = ('finder: { + let specifier = PesdeDependencySpecifier { + name: self.package.0.clone(), + version: version_req.clone(), + index: None, + target: None, + }; - let tmp_dir = project.cas_dir().join(".tmp"); - fs::create_dir_all(&tmp_dir) - .await - .context("failed to create temporary directory")?; - let tempdir = - tempfile::tempdir_in(tmp_dir).context("failed to create temporary directory")?; + if let Some(res) = source + .resolve(&specifier, &project, TargetKind::Lune, &mut HashSet::new()) + .await + .context("failed to resolve package")? + .1 + .pop_last() + { + break 'finder Some(res); + } - let project = Project::new( - tempdir.path(), - None::, - project.data_dir(), - project.cas_dir(), - project.auth_config().clone(), - ); + source + .resolve(&specifier, &project, TargetKind::Luau, &mut HashSet::new()) + .await + .context("failed to resolve package")? + .1 + .pop_last() + }) else { + anyhow::bail!( + "no Lune or Luau package could be found for {}@{version_req}", + self.package.0, + ); + }; - let (fs, target) = source - .download(&pkg_ref, &project, &reqwest) - .await - .context("failed to download package")?; - let bin_path = target.bin_path().context("package has no binary export")?; + let tmp_dir = project.cas_dir().join(".tmp"); + fs::create_dir_all(&tmp_dir) + .await + .context("failed to create temporary directory")?; + let tempdir = tempfile::tempdir_in(tmp_dir) + .context("failed to create temporary directory")?; - fs.write_to(tempdir.path(), project.cas_dir(), true) - .await - .context("failed to write package contents")?; + let project = Project::new( + tempdir.path(), + None::, + project.data_dir(), + project.cas_dir(), + project.auth_config().clone(), + ); - let mut refreshed_sources = HashSet::new(); + let (fs, target) = source + .download(&pkg_ref, &project, &reqwest, Arc::new(())) + .await + .context("failed to download package")?; + let bin_path = target.bin_path().context("package has no binary export")?; - let graph = project - .dependency_graph(None, &mut refreshed_sources, true) - .await - .context("failed to build dependency graph")?; - let graph = Arc::new(graph); + fs.write_to(tempdir.path(), project.cas_dir(), true) + .await + .context("failed to write package contents")?; - let (rx, downloaded_graph) = project - .download_and_link( - &graph, - &Arc::new(Mutex::new(refreshed_sources)), - &reqwest, - true, - true, - |_| async { Ok::<_, std::io::Error>(()) }, - ) - .await - .context("failed to download dependencies")?; + let mut refreshed_sources = HashSet::new(); - progress_bar( - graph.values().map(|versions| versions.len() as u64).sum(), - rx, - "📥 ".to_string(), - "downloading dependencies".to_string(), - "downloaded dependencies".to_string(), + let graph = project + .dependency_graph(None, &mut refreshed_sources, true) + .await + .context("failed to build dependency graph")?; + + multi_progress.suspend(|| { + eprintln!( + "{}", + format!("using {}", format!("{}@{version}", pkg_ref.name).bold()).dimmed() + ) + }); + + root_progress.reset(); + root_progress.set_message("download"); + root_progress.set_style(reporters::root_progress_style_with_progress()); + + project + .download_and_link( + &Arc::new(graph), + DownloadAndLinkOptions::, ()>::new(reqwest) + .reporter(reporter) + .refreshed_sources(Mutex::new(refreshed_sources)) + .prod(true) + .write(true), + ) + .await + .context("failed to download and link dependencies")?; + + anyhow::Ok((tempdir, bin_path.clone())) + }, ) .await?; - downloaded_graph - .await - .context("failed to download & link dependencies")?; - let mut caller = tempfile::NamedTempFile::new_in(tempdir.path()).context("failed to create tempfile")?; caller diff --git a/src/cli/commands/install.rs b/src/cli/commands/install.rs index 3ecc6ad..1ada08e 100644 --- a/src/cli/commands/install.rs +++ b/src/cli/commands/install.rs @@ -1,20 +1,10 @@ use crate::cli::{ - bin_dir, files::make_executable, progress_bar, run_on_workspace_members, up_to_date_lockfile, + install::{install, InstallOptions}, + run_on_workspace_members, }; -use anyhow::Context; use clap::Args; -use colored::{ColoredString, Colorize}; -use fs_err::tokio as fs; -use futures::future::try_join_all; -use pesde::{ - download_and_link::filter_graph, lockfile::Lockfile, manifest::target::TargetKind, Project, - MANIFEST_FILE_NAME, -}; -use std::{ - collections::{BTreeSet, HashMap, HashSet}, - sync::Arc, -}; -use tokio::sync::Mutex; +use pesde::Project; +use std::num::NonZeroUsize; #[derive(Debug, Args, Copy, Clone)] pub struct InstallCommand { @@ -25,303 +15,35 @@ pub struct InstallCommand { /// Whether to not install dev dependencies #[arg(long)] prod: bool, -} -fn bin_link_file(alias: &str) -> String { - let mut all_combinations = BTreeSet::new(); - - for a in TargetKind::VARIANTS { - for b in TargetKind::VARIANTS { - all_combinations.insert((a, b)); - } - } - - let all_folders = all_combinations - .into_iter() - .map(|(a, b)| format!("{:?}", a.packages_folder(b))) - .collect::>() - .into_iter() - .collect::>() - .join(", "); - - format!( - r#"local process = require("@lune/process") -local fs = require("@lune/fs") -local stdio = require("@lune/stdio") - -local project_root = process.cwd -local path_components = string.split(string.gsub(project_root, "\\", "/"), "/") - -for i = #path_components, 1, -1 do - local path = table.concat(path_components, "/", 1, i) - if fs.isFile(path .. "/{MANIFEST_FILE_NAME}") then - project_root = path - break - end -end - -for _, packages_folder in {{ {all_folders} }} do - local path = `{{project_root}}/{{packages_folder}}/{alias}.bin.luau` - - if fs.isFile(path) then - require(path) - return - end -end - -stdio.ewrite(stdio.color("red") .. "binary `{alias}` not found. are you in the right directory?" .. stdio.color("reset") .. "\n") - "#, - ) -} - -#[cfg(feature = "patches")] -const JOBS: u8 = 5; -#[cfg(not(feature = "patches"))] -const JOBS: u8 = 4; - -fn job(n: u8) -> ColoredString { - format!("[{n}/{JOBS}]").dimmed().bold() + /// The maximum number of concurrent network requests + #[arg(long, default_value = "16")] + network_concurrency: NonZeroUsize, } #[derive(Debug, thiserror::Error)] #[error(transparent)] struct CallbackError(#[from] anyhow::Error); - impl InstallCommand { pub async fn run(self, project: Project, reqwest: reqwest::Client) -> anyhow::Result<()> { - let mut refreshed_sources = HashSet::new(); - - let manifest = project - .deser_manifest() - .await - .context("failed to read manifest")?; - - let lockfile = if self.locked { - match up_to_date_lockfile(&project).await? { - None => { - anyhow::bail!( - "lockfile is out of sync, run `{} install` to update it", - env!("CARGO_BIN_NAME") - ); - } - file => file, - } - } else { - match project.deser_lockfile().await { - Ok(lockfile) => { - if lockfile.overrides != manifest.overrides { - tracing::debug!("overrides are different"); - None - } else if lockfile.target != manifest.target.kind() { - tracing::debug!("target kind is different"); - None - } else { - Some(lockfile) - } - } - Err(pesde::errors::LockfileReadError::Io(e)) - if e.kind() == std::io::ErrorKind::NotFound => - { - None - } - Err(e) => return Err(e.into()), - } + let options = InstallOptions { + locked: self.locked, + prod: self.prod, + write: true, + network_concurrency: self.network_concurrency, + use_lockfile: true, }; - println!( - "\n{}\n", - format!("[now installing {} {}]", manifest.name, manifest.target) - .bold() - .on_bright_black() - ); + install(&options, &project, reqwest.clone(), true).await?; - println!("{} ❌ removing current package folders", job(1)); - - { - let mut deleted_folders = HashMap::new(); - - for target_kind in TargetKind::VARIANTS { - let folder = manifest.target.kind().packages_folder(target_kind); - let package_dir = project.package_dir(); - - deleted_folders - .entry(folder.to_string()) - .or_insert_with(|| async move { - tracing::debug!("deleting the {folder} folder"); - - if let Some(e) = fs::remove_dir_all(package_dir.join(&folder)) - .await - .err() - .filter(|e| e.kind() != std::io::ErrorKind::NotFound) - { - return Err(e).context(format!("failed to remove the {folder} folder")); - }; - - Ok(()) - }); + run_on_workspace_members(&project, |project| { + let reqwest = reqwest.clone(); + async move { + install(&options, &project, reqwest, false).await?; + Ok(()) } - - try_join_all(deleted_folders.into_values()) - .await - .context("failed to remove package folders")?; - } - - let old_graph = lockfile.map(|lockfile| { - lockfile - .graph - .into_iter() - .map(|(name, versions)| { - ( - name, - versions - .into_iter() - .map(|(version, node)| (version, node.node)) - .collect(), - ) - }) - .collect() - }); - - println!("{} 📦 building dependency graph", job(2)); - - let graph = project - .dependency_graph(old_graph.as_ref(), &mut refreshed_sources, false) - .await - .context("failed to build dependency graph")?; - let graph = Arc::new(graph); - - let bin_folder = bin_dir().await?; - - let downloaded_graph = { - let (rx, downloaded_graph) = project - .download_and_link( - &graph, - &Arc::new(Mutex::new(refreshed_sources)), - &reqwest, - self.prod, - true, - |graph| { - let graph = graph.clone(); - - async move { - try_join_all( - graph - .values() - .flat_map(|versions| versions.values()) - .filter(|node| node.target.bin_path().is_some()) - .filter_map(|node| node.node.direct.as_ref()) - .map(|(alias, _, _)| alias) - .filter(|alias| { - if *alias == env!("CARGO_BIN_NAME") { - tracing::warn!( - "package {alias} has the same name as the CLI, skipping bin link" - ); - return false; - } - - true - }) - .map(|alias| { - let bin_folder = bin_folder.clone(); - async move { - let bin_exec_file = bin_folder.join(alias).with_extension(std::env::consts::EXE_EXTENSION); - - let impl_folder = bin_folder.join(".impl"); - fs::create_dir_all(&impl_folder).await.context("failed to create bin link folder")?; - - let bin_file = impl_folder.join(alias).with_extension("luau"); - fs::write(&bin_file, bin_link_file(alias)) - .await - .context("failed to write bin link file")?; - - - #[cfg(windows)] - { - fs::copy( - std::env::current_exe() - .context("failed to get current executable path")?, - &bin_exec_file, - ) - .await - .context("failed to copy bin link file")?; - } - - #[cfg(not(windows))] - { - fs::write( - &bin_exec_file, - format!(r#"#!/bin/sh -exec lune run "$(dirname "$0")/.impl/{alias}.luau" -- "$@""# - ), - ) - .await - .context("failed to link bin link file")?; - } - - make_executable(&bin_exec_file).await.context("failed to make bin link file executable")?; - - Ok::<_, CallbackError>(()) - } - }), - ) - .await - .map(|_| ()) - } - } - ) - .await - .context("failed to download dependencies")?; - - progress_bar( - graph.values().map(|versions| versions.len() as u64).sum(), - rx, - format!("{} 📥 ", job(3)), - "downloading dependencies".to_string(), - "downloaded dependencies".to_string(), - ) - .await?; - - downloaded_graph - .await - .context("failed to download & link dependencies")? - }; - - #[cfg(feature = "patches")] - { - let rx = project - .apply_patches(&filter_graph(&downloaded_graph, self.prod)) - .await - .context("failed to apply patches")?; - - progress_bar( - manifest.patches.values().map(|v| v.len() as u64).sum(), - rx, - format!("{} 🩹 ", job(JOBS - 1)), - "applying patches".to_string(), - "applied patches".to_string(), - ) - .await?; - } - - println!("{} 🧹 finishing up", job(JOBS)); - - project - .write_lockfile(Lockfile { - name: manifest.name, - version: manifest.version, - target: manifest.target.kind(), - overrides: manifest.overrides, - - graph: downloaded_graph, - - workspace: run_on_workspace_members(&project, |project| { - let reqwest = reqwest.clone(); - async move { Box::pin(self.run(project, reqwest)).await } - }) - .await?, - }) - .await - .context("failed to write lockfile")?; + }) + .await?; Ok(()) } diff --git a/src/cli/commands/patch.rs b/src/cli/commands/patch.rs index ea2518d..5ddbb19 100644 --- a/src/cli/commands/patch.rs +++ b/src/cli/commands/patch.rs @@ -1,3 +1,5 @@ +use std::sync::Arc; + use crate::cli::{up_to_date_lockfile, VersionedPackageName}; use anyhow::Context; use clap::Args; @@ -49,7 +51,7 @@ impl PatchCommand { fs::create_dir_all(&directory).await?; source - .download(&node.node.pkg_ref, &project, &reqwest) + .download(&node.node.pkg_ref, &project, &reqwest, Arc::new(())) .await? .0 .write_to(&directory, project.cas_dir(), false) diff --git a/src/cli/commands/update.rs b/src/cli/commands/update.rs index 3793093..ad48470 100644 --- a/src/cli/commands/update.rs +++ b/src/cli/commands/update.rs @@ -1,84 +1,42 @@ -use crate::cli::{progress_bar, run_on_workspace_members}; -use anyhow::Context; +use crate::cli::{ + install::{install, InstallOptions}, + run_on_workspace_members, +}; use clap::Args; -use colored::Colorize; -use pesde::{lockfile::Lockfile, Project}; -use std::{collections::HashSet, sync::Arc}; -use tokio::sync::Mutex; +use pesde::Project; +use std::num::NonZeroUsize; #[derive(Debug, Args, Copy, Clone)] -pub struct UpdateCommand {} +pub struct UpdateCommand { + /// Update the dependencies but don't install them + #[arg(long)] + no_install: bool, + + /// The maximum number of concurrent network requests + #[arg(long, default_value = "16")] + network_concurrency: NonZeroUsize, +} impl UpdateCommand { pub async fn run(self, project: Project, reqwest: reqwest::Client) -> anyhow::Result<()> { - let mut refreshed_sources = HashSet::new(); + let options = InstallOptions { + locked: false, + prod: false, + write: !self.no_install, + network_concurrency: self.network_concurrency, + use_lockfile: false, + }; - let manifest = project - .deser_manifest() - .await - .context("failed to read manifest")?; + install(&options, &project, reqwest.clone(), true).await?; - println!( - "\n{}\n", - format!("[now updating {} {}]", manifest.name, manifest.target) - .bold() - .on_bright_black() - ); - - let graph = project - .dependency_graph(None, &mut refreshed_sources, false) - .await - .context("failed to build dependency graph")?; - let graph = Arc::new(graph); - - project - .write_lockfile(Lockfile { - name: manifest.name, - version: manifest.version, - target: manifest.target.kind(), - overrides: manifest.overrides, - - graph: { - let (rx, downloaded_graph) = project - .download_and_link( - &graph, - &Arc::new(Mutex::new(refreshed_sources)), - &reqwest, - false, - false, - |_| async { Ok::<_, std::io::Error>(()) }, - ) - .await - .context("failed to download dependencies")?; - - progress_bar( - graph.values().map(|versions| versions.len() as u64).sum(), - rx, - "📥 ".to_string(), - "downloading dependencies".to_string(), - "downloaded dependencies".to_string(), - ) - .await?; - - downloaded_graph - .await - .context("failed to download dependencies")? - }, - - workspace: run_on_workspace_members(&project, |project| { - let reqwest = reqwest.clone(); - async move { Box::pin(self.run(project, reqwest)).await } - }) - .await?, - }) - .await - .context("failed to write lockfile")?; - - println!( - "\n\n{}. run `{} install` in order to install the new dependencies", - "✅ done".green(), - env!("CARGO_BIN_NAME") - ); + run_on_workspace_members(&project, |project| { + let reqwest = reqwest.clone(); + async move { + install(&options, &project, reqwest, false).await?; + Ok(()) + } + }) + .await?; Ok(()) } diff --git a/src/cli/install.rs b/src/cli/install.rs new file mode 100644 index 0000000..f7f1f86 --- /dev/null +++ b/src/cli/install.rs @@ -0,0 +1,482 @@ +use std::{ + collections::{BTreeMap, BTreeSet, HashMap, HashSet}, + num::NonZeroUsize, + sync::Arc, + time::Instant, +}; + +use anyhow::Context; +use colored::Colorize; +use fs_err::tokio as fs; +use futures::future::try_join_all; +use pesde::{ + download_and_link::{filter_graph, DownloadAndLinkHooks, DownloadAndLinkOptions}, + lockfile::{DependencyGraph, DownloadedGraph, Lockfile}, + manifest::{target::TargetKind, DependencyType}, + Project, MANIFEST_FILE_NAME, +}; +use tokio::{sync::Mutex, task::JoinSet}; + +use crate::cli::{ + bin_dir, + reporters::{self, CliReporter}, + run_on_workspace_members, up_to_date_lockfile, +}; + +use super::files::make_executable; + +fn bin_link_file(alias: &str) -> String { + let mut all_combinations = BTreeSet::new(); + + for a in TargetKind::VARIANTS { + for b in TargetKind::VARIANTS { + all_combinations.insert((a, b)); + } + } + + let all_folders = all_combinations + .into_iter() + .map(|(a, b)| format!("{:?}", a.packages_folder(b))) + .collect::>() + .into_iter() + .collect::>() + .join(", "); + + format!( + r#"local process = require("@lune/process") +local fs = require("@lune/fs") +local stdio = require("@lune/stdio") + +local project_root = process.cwd +local path_components = string.split(string.gsub(project_root, "\\", "/"), "/") + +for i = #path_components, 1, -1 do + local path = table.concat(path_components, "/", 1, i) + if fs.isFile(path .. "/{MANIFEST_FILE_NAME}") then + project_root = path + break + end +end + +for _, packages_folder in {{ {all_folders} }} do + local path = `{{project_root}}/{{packages_folder}}/{alias}.bin.luau` + + if fs.isFile(path) then + require(path) + return + end +end + +stdio.ewrite(stdio.color("red") .. "binary `{alias}` not found. are you in the right directory?" .. stdio.color("reset") .. "\n") + "#, + ) +} + +pub struct InstallHooks { + pub bin_folder: std::path::PathBuf, +} + +#[derive(Debug, thiserror::Error)] +#[error(transparent)] +pub struct InstallHooksError(#[from] anyhow::Error); + +impl DownloadAndLinkHooks for InstallHooks { + type Error = InstallHooksError; + + async fn on_bins_downloaded( + &self, + downloaded_graph: &pesde::lockfile::DownloadedGraph, + ) -> Result<(), Self::Error> { + let mut tasks = downloaded_graph + .values() + .flat_map(|versions| versions.values()) + .filter(|node| node.target.bin_path().is_some()) + .filter_map(|node| node.node.direct.as_ref()) + .map(|(alias, _, _)| alias) + .filter(|alias| { + if *alias == env!("CARGO_BIN_NAME") { + tracing::warn!( + "package {alias} has the same name as the CLI, skipping bin link" + ); + return false; + } + true + }) + .map(|alias| { + let bin_folder = self.bin_folder.clone(); + let alias = alias.clone(); + + async move { + let bin_exec_file = bin_folder + .join(&alias) + .with_extension(std::env::consts::EXE_EXTENSION); + + let impl_folder = bin_folder.join(".impl"); + fs::create_dir_all(&impl_folder) + .await + .context("failed to create bin link folder")?; + + let bin_file = impl_folder.join(&alias).with_extension("luau"); + fs::write(&bin_file, bin_link_file(&alias)) + .await + .context("failed to write bin link file")?; + + #[cfg(windows)] + match fs::symlink_file( + std::env::current_exe().context("failed to get current executable path")?, + &bin_exec_file, + ) + .await + { + Ok(_) => {} + Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {} + e => e.context("failed to copy bin link file")?, + } + + #[cfg(not(windows))] + fs::write( + &bin_exec_file, + format!( + r#"#!/bin/sh +exec lune run "$(dirname "$0")/.impl/{alias}.luau" -- "$@""# + ), + ) + .await + .context("failed to link bin link file")?; + + make_executable(&bin_exec_file) + .await + .context("failed to make bin link file executable")?; + + Ok::<_, anyhow::Error>(()) + } + }) + .collect::>(); + + while let Some(task) = tasks.join_next().await { + task.unwrap()?; + } + + Ok(()) + } +} + +#[derive(Debug, Clone, Copy)] +pub struct InstallOptions { + pub locked: bool, + pub prod: bool, + pub write: bool, + pub use_lockfile: bool, + pub network_concurrency: NonZeroUsize, +} + +pub async fn install( + options: &InstallOptions, + project: &Project, + reqwest: reqwest::Client, + is_root: bool, +) -> anyhow::Result<()> { + let start = Instant::now(); + + let mut refreshed_sources = HashSet::new(); + + let manifest = project + .deser_manifest() + .await + .context("failed to read manifest")?; + + let lockfile = if options.locked { + match up_to_date_lockfile(project).await? { + None => { + anyhow::bail!( + "lockfile is out of sync, run `{} install` to update it", + env!("CARGO_BIN_NAME") + ); + } + file => file, + } + } else { + match project.deser_lockfile().await { + Ok(lockfile) => { + if lockfile.overrides != manifest.overrides { + tracing::debug!("overrides are different"); + None + } else if lockfile.target != manifest.target.kind() { + tracing::debug!("target kind is different"); + None + } else { + Some(lockfile) + } + } + Err(pesde::errors::LockfileReadError::Io(e)) + if e.kind() == std::io::ErrorKind::NotFound => + { + None + } + Err(e) => return Err(e.into()), + } + }; + + let (new_lockfile, old_graph) = + reporters::run_with_reporter(|_, root_progress, reporter| async { + let root_progress = root_progress; + + root_progress.set_prefix(format!("{} {}: ", manifest.name, manifest.target)); + root_progress.set_message("clean"); + + if options.write { + let mut deleted_folders = HashMap::new(); + + for target_kind in TargetKind::VARIANTS { + let folder = manifest.target.kind().packages_folder(target_kind); + let package_dir = project.package_dir(); + + deleted_folders + .entry(folder.to_string()) + .or_insert_with(|| async move { + tracing::debug!("deleting the {folder} folder"); + + if let Some(e) = fs::remove_dir_all(package_dir.join(&folder)) + .await + .err() + .filter(|e| e.kind() != std::io::ErrorKind::NotFound) + { + return Err(e) + .context(format!("failed to remove the {folder} folder")); + }; + + Ok(()) + }); + } + + try_join_all(deleted_folders.into_values()) + .await + .context("failed to remove package folders")?; + } + + root_progress.reset(); + root_progress.set_message("resolve"); + + let old_graph = lockfile.map(|lockfile| { + lockfile + .graph + .into_iter() + .map(|(name, versions)| { + ( + name, + versions + .into_iter() + .map(|(version, node)| (version, node.node)) + .collect(), + ) + }) + .collect() + }); + + let graph = project + .dependency_graph( + old_graph.as_ref().filter(|_| options.use_lockfile), + &mut refreshed_sources, + false, + ) + .await + .context("failed to build dependency graph")?; + let graph = Arc::new(graph); + + root_progress.reset(); + root_progress.set_length(0); + root_progress.set_message("download"); + root_progress.set_style(reporters::root_progress_style_with_progress()); + + let hooks = InstallHooks { + bin_folder: bin_dir().await?, + }; + + let downloaded_graph = project + .download_and_link( + &graph, + DownloadAndLinkOptions::::new(reqwest.clone()) + .reporter(reporter.clone()) + .hooks(hooks) + .refreshed_sources(Mutex::new(refreshed_sources)) + .prod(options.prod) + .write(options.write) + .network_concurrency(options.network_concurrency), + ) + .await + .context("failed to download and link dependencies")?; + + #[cfg(feature = "patches")] + if options.write { + root_progress.reset(); + root_progress.set_length(0); + root_progress.set_message("patch"); + + project + .apply_patches(&filter_graph(&downloaded_graph, options.prod), reporter) + .await?; + } + + root_progress.set_message("finish"); + + let new_lockfile = Lockfile { + name: manifest.name.clone(), + version: manifest.version, + target: manifest.target.kind(), + overrides: manifest.overrides, + + graph: downloaded_graph, + + workspace: run_on_workspace_members(project, |_| async { Ok(()) }).await?, + }; + + project + .write_lockfile(&new_lockfile) + .await + .context("failed to write lockfile")?; + + anyhow::Ok((new_lockfile, old_graph.unwrap_or_default())) + }) + .await?; + + let elapsed = start.elapsed(); + + if is_root { + println!(); + } + + print_package_diff( + &format!("{} {}:", manifest.name, manifest.target), + old_graph, + new_lockfile.graph, + ); + + println!("done in {:.2}s", elapsed.as_secs_f64()); + println!(); + + Ok(()) +} + +/// Prints the difference between two graphs. +pub fn print_package_diff(prefix: &str, old_graph: DependencyGraph, new_graph: DownloadedGraph) { + let mut old_pkg_map = BTreeMap::new(); + let mut old_direct_pkg_map = BTreeMap::new(); + let mut new_pkg_map = BTreeMap::new(); + let mut new_direct_pkg_map = BTreeMap::new(); + + for (name, versions) in &old_graph { + for (version, node) in versions { + old_pkg_map.insert((name.clone(), version), node); + if node.direct.is_some() { + old_direct_pkg_map.insert((name.clone(), version), node); + } + } + } + + for (name, versions) in &new_graph { + for (version, node) in versions { + new_pkg_map.insert((name.clone(), version), &node.node); + if node.node.direct.is_some() { + new_direct_pkg_map.insert((name.clone(), version), &node.node); + } + } + } + + let added_pkgs = new_pkg_map + .iter() + .filter(|(key, _)| !old_pkg_map.contains_key(key)) + .map(|(key, &node)| (key, node)) + .collect::>(); + let removed_pkgs = old_pkg_map + .iter() + .filter(|(key, _)| !new_pkg_map.contains_key(key)) + .map(|(key, &node)| (key, node)) + .collect::>(); + let added_direct_pkgs = new_direct_pkg_map + .iter() + .filter(|(key, _)| !old_direct_pkg_map.contains_key(key)) + .map(|(key, &node)| (key, node)) + .collect::>(); + let removed_direct_pkgs = old_direct_pkg_map + .iter() + .filter(|(key, _)| !new_direct_pkg_map.contains_key(key)) + .map(|(key, &node)| (key, node)) + .collect::>(); + + let prefix = prefix.bold(); + + let no_changes = added_pkgs.is_empty() + && removed_pkgs.is_empty() + && added_direct_pkgs.is_empty() + && removed_direct_pkgs.is_empty(); + + if no_changes { + println!("{prefix} already up to date"); + } else { + let mut change_signs = [ + (!added_pkgs.is_empty()).then(|| format!("+{}", added_pkgs.len()).green().to_string()), + (!removed_pkgs.is_empty()) + .then(|| format!("-{}", removed_pkgs.len()).red().to_string()), + ] + .into_iter() + .flatten() + .collect::>() + .join(" "); + + let changes_empty = change_signs.is_empty(); + if changes_empty { + change_signs = "(no changes)".dimmed().to_string(); + } + + println!("{prefix} {change_signs}"); + + if !changes_empty { + println!( + "{}{}", + "+".repeat(added_pkgs.len()).green(), + "-".repeat(removed_pkgs.len()).red() + ); + } + + let dependency_groups = added_direct_pkgs + .iter() + .map(|(key, node)| (true, key, node)) + .chain( + removed_direct_pkgs + .iter() + .map(|(key, node)| (false, key, node)), + ) + .filter_map(|(added, key, node)| { + node.direct.as_ref().map(|(_, _, ty)| (added, key, ty)) + }) + .fold( + BTreeMap::>::new(), + |mut map, (added, key, &ty)| { + map.entry(ty).or_default().insert((key, added)); + map + }, + ); + + for (ty, set) in dependency_groups { + println!(); + + let ty_name = match ty { + DependencyType::Standard => "dependencies", + DependencyType::Peer => "peer_dependencies", + DependencyType::Dev => "dev_dependencies", + }; + println!("{}", format!("{ty_name}:").yellow().bold()); + + for ((name, version), added) in set { + println!( + "{} {} {}", + if added { "+".green() } else { "-".red() }, + name, + version.to_string().dimmed() + ); + } + } + + println!(); + } +} diff --git a/src/cli/mod.rs b/src/cli/mod.rs index a18373b..8e27e6a 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -15,7 +15,6 @@ use std::{ future::Future, path::PathBuf, str::FromStr, - time::Duration, }; use tokio::pin; use tracing::instrument; @@ -24,6 +23,8 @@ pub mod auth; pub mod commands; pub mod config; pub mod files; +pub mod install; +pub mod reporters; #[cfg(feature = "version-management")] pub mod version; @@ -193,39 +194,6 @@ pub fn parse_gix_url(s: &str) -> Result { s.try_into() } -pub async fn progress_bar>( - len: u64, - mut rx: tokio::sync::mpsc::Receiver>, - prefix: String, - progress_msg: String, - finish_msg: String, -) -> anyhow::Result<()> { - let bar = indicatif::ProgressBar::new(len) - .with_style( - indicatif::ProgressStyle::default_bar() - .template("{prefix}[{elapsed_precise}] {bar:40.208/166} {pos}/{len} {msg}")? - .progress_chars("█▓▒░ "), - ) - .with_prefix(prefix) - .with_message(progress_msg); - bar.enable_steady_tick(Duration::from_millis(100)); - - while let Some(result) = rx.recv().await { - bar.inc(1); - - match result { - Ok(text) => { - bar.set_message(text); - } - Err(e) => return Err(e.into()), - } - } - - bar.finish_with_message(finish_msg); - - Ok(()) -} - pub fn shift_project_dir(project: &Project, pkg_dir: PathBuf) -> Project { Project::new( pkg_dir, diff --git a/src/cli/reporters.rs b/src/cli/reporters.rs new file mode 100644 index 0000000..1ea5b00 --- /dev/null +++ b/src/cli/reporters.rs @@ -0,0 +1,213 @@ +//! Progress reporters for the CLI + +use std::{ + future::Future, + io::{Stdout, Write}, + sync::{Arc, Mutex, Once, OnceLock}, + time::Duration, +}; + +use colored::Colorize; +use indicatif::{MultiProgress, ProgressBar, ProgressStyle}; +use pesde::reporters::{ + DownloadProgressReporter, DownloadsReporter, PatchProgressReporter, PatchesReporter, +}; + +pub const TICK_CHARS: &str = "⣷⣯⣟⡿⢿⣻⣽⣾"; + +pub fn root_progress_style() -> ProgressStyle { + ProgressStyle::with_template("{prefix:.dim}{msg:>8.214/yellow} {spinner} [{elapsed_precise}]") + .unwrap() + .tick_chars(TICK_CHARS) +} + +pub fn root_progress_style_with_progress() -> ProgressStyle { + ProgressStyle::with_template( + "{prefix:.dim}{msg:>8.214/yellow} {spinner} [{elapsed_precise}] {bar:20} {pos}/{len}", + ) + .unwrap() + .tick_chars(TICK_CHARS) +} + +pub async fn run_with_reporter_and_writer(writer: W, f: F) -> R +where + W: Write + Send + Sync + 'static, + F: FnOnce(MultiProgress, ProgressBar, Arc>) -> Fut, + Fut: Future, +{ + let multi_progress = MultiProgress::new(); + crate::PROGRESS_BARS + .lock() + .unwrap() + .replace(multi_progress.clone()); + + let root_progress = multi_progress.add(ProgressBar::new(0)); + root_progress.set_style(root_progress_style()); + root_progress.enable_steady_tick(Duration::from_millis(100)); + + let reporter = Arc::new(CliReporter::with_writer( + writer, + multi_progress.clone(), + root_progress.clone(), + )); + let result = f(multi_progress.clone(), root_progress.clone(), reporter).await; + + root_progress.finish(); + multi_progress.clear().unwrap(); + crate::PROGRESS_BARS.lock().unwrap().take(); + + result +} + +pub async fn run_with_reporter(f: F) -> R +where + F: FnOnce(MultiProgress, ProgressBar, Arc>) -> Fut, + Fut: Future, +{ + run_with_reporter_and_writer(std::io::stdout(), f).await +} + +pub struct CliReporter { + writer: Mutex, + child_style: ProgressStyle, + child_style_with_bytes: ProgressStyle, + child_style_with_bytes_without_total: ProgressStyle, + multi_progress: MultiProgress, + root_progress: ProgressBar, +} + +impl CliReporter { + pub fn with_writer( + writer: W, + multi_progress: MultiProgress, + root_progress: ProgressBar, + ) -> Self { + Self { + writer: Mutex::new(writer), + child_style: ProgressStyle::with_template(&"{msg}".dimmed().to_string()).unwrap(), + child_style_with_bytes: ProgressStyle::with_template( + &"{msg} {bytes}/{total_bytes}".dimmed().to_string(), + ) + .unwrap(), + child_style_with_bytes_without_total: ProgressStyle::with_template( + &"{msg} {bytes}".dimmed().to_string(), + ) + .unwrap(), + multi_progress, + root_progress, + } + } +} + +pub struct CliDownloadProgressReporter<'a, W> { + root_reporter: &'a CliReporter, + name: String, + progress: OnceLock, + set_progress: Once, +} + +impl<'a, W: Write + Send + Sync + 'static> DownloadsReporter<'a> for CliReporter { + type DownloadProgressReporter = CliDownloadProgressReporter<'a, W>; + + fn report_download<'b>(&'a self, name: &'b str) -> Self::DownloadProgressReporter { + self.root_progress.inc_length(1); + + CliDownloadProgressReporter { + root_reporter: self, + name: name.to_string(), + progress: OnceLock::new(), + set_progress: Once::new(), + } + } +} + +impl DownloadProgressReporter + for CliDownloadProgressReporter<'_, W> +{ + fn report_start(&self) { + let progress = self.root_reporter.multi_progress.add(ProgressBar::new(0)); + progress.set_style(self.root_reporter.child_style.clone()); + progress.set_message(format!("- {}", self.name)); + + self.progress + .set(progress) + .expect("report_start called more than once"); + } + + fn report_progress(&self, total: u64, len: u64) { + if let Some(progress) = self.progress.get() { + progress.set_length(total); + progress.set_position(len); + + self.set_progress.call_once(|| { + if total > 0 { + progress.set_style(self.root_reporter.child_style_with_bytes.clone()); + } else { + progress.set_style( + self.root_reporter + .child_style_with_bytes_without_total + .clone(), + ); + } + }); + } + } + + fn report_done(&self) { + if let Some(progress) = self.progress.get() { + if progress.is_hidden() { + writeln!( + self.root_reporter.writer.lock().unwrap(), + "downloaded {}", + self.name + ) + .unwrap(); + } + + progress.finish(); + self.root_reporter.multi_progress.remove(progress); + self.root_reporter.root_progress.inc(1); + } + } +} + +pub struct CliPatchProgressReporter<'a, W> { + root_reporter: &'a CliReporter, + name: String, + progress: ProgressBar, +} + +impl<'a, W: Write + Send + Sync + 'static> PatchesReporter<'a> for CliReporter { + type PatchProgressReporter = CliPatchProgressReporter<'a, W>; + + fn report_patch<'b>(&'a self, name: &'b str) -> Self::PatchProgressReporter { + let progress = self.multi_progress.add(ProgressBar::new(0)); + progress.set_style(self.child_style.clone()); + progress.set_message(format!("- {name}")); + + self.root_progress.inc_length(1); + + CliPatchProgressReporter { + root_reporter: self, + name: name.to_string(), + progress, + } + } +} + +impl PatchProgressReporter for CliPatchProgressReporter<'_, W> { + fn report_done(&self) { + if self.progress.is_hidden() { + writeln!( + self.root_reporter.writer.lock().unwrap(), + "patched {}", + self.name + ) + .unwrap(); + } + + self.progress.finish(); + self.root_reporter.multi_progress.remove(&self.progress); + self.root_reporter.root_progress.inc(1); + } +} diff --git a/src/download.rs b/src/download.rs index 0b6705b..e1c19b1 100644 --- a/src/download.rs +++ b/src/download.rs @@ -1,50 +1,132 @@ use crate::{ - lockfile::{DependencyGraph, DownloadedDependencyGraphNode, DownloadedGraph}, + lockfile::{DependencyGraph, DownloadedDependencyGraphNode}, manifest::DependencyType, + names::PackageNames, refresh_sources, + reporters::{DownloadProgressReporter, DownloadsReporter}, source::{ traits::{PackageRef, PackageSource}, + version_id::VersionId, PackageSources, }, Project, PACKAGES_CONTAINER_NAME, }; +use async_stream::try_stream; use fs_err::tokio as fs; -use std::{ - collections::HashSet, - sync::{Arc, Mutex}, -}; +use futures::Stream; +use std::{collections::HashSet, num::NonZeroUsize, sync::Arc}; +use tokio::{sync::Semaphore, task::JoinSet}; use tracing::{instrument, Instrument}; -type MultithreadedGraph = Arc>; +/// Options for downloading. +#[derive(Debug)] +pub struct DownloadGraphOptions { + /// The reqwest client. + pub reqwest: reqwest::Client, + /// The downloads reporter. + pub reporter: Option>, + /// Whether to skip dev dependencies. + pub prod: bool, + /// Whether to write the downloaded packages to disk. + pub write: bool, + /// Whether to download Wally packages. + pub wally: bool, + /// The max number of concurrent network requests. + pub network_concurrency: NonZeroUsize, +} -pub(crate) type MultithreadDownloadJob = ( - tokio::sync::mpsc::Receiver>, - MultithreadedGraph, -); +impl DownloadGraphOptions +where + Reporter: for<'a> DownloadsReporter<'a> + Send + Sync + 'static, +{ + /// Creates a new download options with the given reqwest client and reporter. + pub fn new(reqwest: reqwest::Client) -> Self { + Self { + reqwest, + reporter: None, + prod: false, + write: false, + wally: false, + network_concurrency: NonZeroUsize::new(16).unwrap(), + } + } + + /// Sets the downloads reporter. + pub fn reporter(mut self, reporter: impl Into>) -> Self { + self.reporter.replace(reporter.into()); + self + } + + /// Sets whether to skip dev dependencies. + pub fn prod(mut self, prod: bool) -> Self { + self.prod = prod; + self + } + + /// Sets whether to write the downloaded packages to disk. + pub fn write(mut self, write: bool) -> Self { + self.write = write; + self + } + + /// Sets whether to download Wally packages. + pub fn wally(mut self, wally: bool) -> Self { + self.wally = wally; + self + } + + /// Sets the max number of concurrent network requests. + pub fn network_concurrency(mut self, network_concurrency: NonZeroUsize) -> Self { + self.network_concurrency = network_concurrency; + self + } +} + +impl Clone for DownloadGraphOptions { + fn clone(&self) -> Self { + Self { + reqwest: self.reqwest.clone(), + reporter: self.reporter.clone(), + prod: self.prod, + write: self.write, + wally: self.wally, + network_concurrency: self.network_concurrency, + } + } +} impl Project { - /// Downloads a graph of dependencies - #[instrument(skip(self, graph, refreshed_sources, reqwest), level = "debug")] - pub async fn download_graph( + /// Downloads a graph of dependencies. + #[instrument(skip_all, fields(prod = options.prod, wally = options.wally, write = options.write), level = "debug")] + pub async fn download_graph( &self, graph: &DependencyGraph, refreshed_sources: &mut HashSet, - reqwest: &reqwest::Client, - prod: bool, - write: bool, - wally: bool, - ) -> Result { + options: DownloadGraphOptions, + ) -> Result< + impl Stream< + Item = Result< + (DownloadedDependencyGraphNode, PackageNames, VersionId), + errors::DownloadGraphError, + >, + >, + errors::DownloadGraphError, + > + where + Reporter: for<'a> DownloadsReporter<'a> + Send + Sync + 'static, + { + let DownloadGraphOptions { + reqwest, + reporter, + prod, + write, + wally, + network_concurrency, + } = options; + let manifest = self.deser_manifest().await?; let manifest_target_kind = manifest.target.kind(); - let downloaded_graph: MultithreadedGraph = Arc::new(Mutex::new(Default::default())); - - let (tx, rx) = tokio::sync::mpsc::channel( - graph - .iter() - .map(|(_, versions)| versions.len()) - .sum::() - .max(1), - ); + let project = Arc::new(self.clone()); refresh_sources( self, @@ -56,7 +138,8 @@ impl Project { ) .await?; - let project = Arc::new(self.clone()); + let mut tasks = JoinSet::>::new(); + let semaphore = Arc::new(Semaphore::new(network_concurrency.get())); for (name, versions) in graph { for (version_id, node) in versions { @@ -65,8 +148,6 @@ impl Project { continue; } - let tx = tx.clone(); - let name = name.clone(); let version_id = version_id.clone(); let node = node.clone(); @@ -79,14 +160,24 @@ impl Project { let project = project.clone(); let reqwest = reqwest.clone(); - let downloaded_graph = downloaded_graph.clone(); + let reporter = reporter.clone(); + let package_dir = project.package_dir().to_path_buf(); + let semaphore = semaphore.clone(); - let package_dir = self.package_dir().to_path_buf(); - - tokio::spawn( + tasks.spawn( async move { - let source = node.pkg_ref.source(); + let display_name = format!("{name}@{version_id}"); + let progress_reporter = reporter + .as_deref() + .map(|reporter| reporter.report_download(&display_name)); + let _permit = semaphore.acquire().await; + + if let Some(ref progress_reporter) = progress_reporter { + progress_reporter.report_start(); + } + + let source = node.pkg_ref.source(); let container_folder = node.container_folder( &package_dir .join(manifest_target_kind.packages_folder(version_id.target())) @@ -95,42 +186,37 @@ impl Project { version_id.version(), ); - match fs::create_dir_all(&container_folder).await { - Ok(_) => {} - Err(e) => { - tx.send(Err(errors::DownloadGraphError::Io(e))) - .await - .unwrap(); - return; - } - } + fs::create_dir_all(&container_folder).await?; let project = project.clone(); tracing::debug!("downloading"); - let (fs, target) = - match source.download(&node.pkg_ref, &project, &reqwest).await { - Ok(target) => target, - Err(e) => { - tx.send(Err(Box::new(e).into())).await.unwrap(); - return; - } - }; + let (fs, target) = match progress_reporter { + Some(progress_reporter) => { + source + .download( + &node.pkg_ref, + &project, + &reqwest, + Arc::new(progress_reporter), + ) + .await + } + None => { + source + .download(&node.pkg_ref, &project, &reqwest, Arc::new(())) + .await + } + } + .map_err(Box::new)?; tracing::debug!("downloaded"); if write { if !prod || node.resolved_ty != DependencyType::Dev { - match fs.write_to(container_folder, project.cas_dir(), true).await { - Ok(_) => {} - Err(e) => { - tx.send(Err(errors::DownloadGraphError::WriteFailed(e))) - .await - .unwrap(); - return; - } - }; + fs.write_to(container_folder, project.cas_dir(), true) + .await?; } else { tracing::debug!( "skipping write to disk, dev dependency in prod mode" @@ -138,24 +224,21 @@ impl Project { } } - let display_name = format!("{name}@{version_id}"); - - { - let mut downloaded_graph = downloaded_graph.lock().unwrap(); - downloaded_graph - .entry(name) - .or_default() - .insert(version_id, DownloadedDependencyGraphNode { node, target }); - } - - tx.send(Ok(display_name)).await.unwrap(); + let downloaded_node = DownloadedDependencyGraphNode { node, target }; + Ok((downloaded_node, name, version_id)) } .instrument(span), ); } } - Ok((rx, downloaded_graph)) + let stream = try_stream! { + while let Some(res) = tasks.join_next().await { + yield res.unwrap()?; + } + }; + + Ok(stream) } } diff --git a/src/download_and_link.rs b/src/download_and_link.rs index a8d7b19..3d553ff 100644 --- a/src/download_and_link.rs +++ b/src/download_and_link.rs @@ -1,14 +1,18 @@ use crate::{ + download::DownloadGraphOptions, lockfile::{DependencyGraph, DownloadedGraph}, manifest::DependencyType, + reporters::DownloadsReporter, source::PackageSources, Project, }; -use futures::FutureExt; +use futures::TryStreamExt; use std::{ collections::HashSet, - future::Future, - sync::{Arc, Mutex as StdMutex}, + convert::Infallible, + future::{self, Future}, + num::NonZeroUsize, + sync::Arc, }; use tokio::sync::Mutex; use tracing::{instrument, Instrument}; @@ -38,118 +42,242 @@ pub fn filter_graph(graph: &DownloadedGraph, prod: bool) -> DownloadedGraph { pub type DownloadAndLinkReceiver = tokio::sync::mpsc::Receiver>; +/// Hooks to perform actions after certain events during download and linking. +#[allow(unused_variables)] +pub trait DownloadAndLinkHooks { + /// The error type for the hooks. + type Error: std::error::Error + Send + Sync + 'static; + + /// Called after scripts have been downloaded. The `downloaded_graph` + /// contains all downloaded packages. + fn on_scripts_downloaded( + &self, + downloaded_graph: &DownloadedGraph, + ) -> impl Future> + Send { + future::ready(Ok(())) + } + + /// Called after binary dependencies have been downloaded. The + /// `downloaded_graph` contains all downloaded packages. + fn on_bins_downloaded( + &self, + downloaded_graph: &DownloadedGraph, + ) -> impl Future> + Send { + future::ready(Ok(())) + } + + /// Called after all dependencies have been downloaded. The + /// `downloaded_graph` contains all downloaded packages. + fn on_all_downloaded( + &self, + downloaded_graph: &DownloadedGraph, + ) -> impl Future> + Send { + future::ready(Ok(())) + } +} + +impl DownloadAndLinkHooks for () { + type Error = Infallible; +} + +/// Options for downloading and linking. +#[derive(Debug)] +pub struct DownloadAndLinkOptions { + /// The reqwest client. + pub reqwest: reqwest::Client, + /// The downloads reporter. + pub reporter: Option>, + /// The download and link hooks. + pub hooks: Option>, + /// The refreshed sources. + pub refreshed_sources: Arc>>, + /// Whether to skip dev dependencies. + pub prod: bool, + /// Whether to write the downloaded packages to disk. + pub write: bool, + /// The max number of concurrent network requests. + pub network_concurrency: NonZeroUsize, +} + +impl DownloadAndLinkOptions +where + Reporter: for<'a> DownloadsReporter<'a> + Send + Sync + 'static, + Hooks: DownloadAndLinkHooks + Send + Sync + 'static, +{ + /// Creates a new download options with the given reqwest client and reporter. + pub fn new(reqwest: reqwest::Client) -> Self { + Self { + reqwest, + reporter: None, + hooks: None, + refreshed_sources: Default::default(), + prod: false, + write: true, + network_concurrency: NonZeroUsize::new(16).unwrap(), + } + } + + /// Sets the downloads reporter. + pub fn reporter(mut self, reporter: impl Into>) -> Self { + self.reporter.replace(reporter.into()); + self + } + + /// Sets the download and link hooks. + pub fn hooks(mut self, hooks: impl Into>) -> Self { + self.hooks.replace(hooks.into()); + self + } + + /// Sets the refreshed sources. + pub fn refreshed_sources( + mut self, + refreshed_sources: impl Into>>>, + ) -> Self { + self.refreshed_sources = refreshed_sources.into(); + self + } + + /// Sets whether to skip dev dependencies. + pub fn prod(mut self, prod: bool) -> Self { + self.prod = prod; + self + } + + /// Sets whether to write the downloaded packages to disk. + pub fn write(mut self, write: bool) -> Self { + self.write = write; + self + } + + /// Sets the max number of concurrent network requests. + pub fn network_concurrency(mut self, network_concurrency: NonZeroUsize) -> Self { + self.network_concurrency = network_concurrency; + self + } +} + +impl Clone for DownloadAndLinkOptions { + fn clone(&self) -> Self { + Self { + reqwest: self.reqwest.clone(), + reporter: self.reporter.clone(), + hooks: self.hooks.clone(), + refreshed_sources: self.refreshed_sources.clone(), + prod: self.prod, + write: self.write, + network_concurrency: self.network_concurrency, + } + } +} + impl Project { /// Downloads a graph of dependencies and links them in the correct order - #[instrument( - skip(self, graph, refreshed_sources, reqwest, pesde_cb), - level = "debug" - )] - pub async fn download_and_link< - F: FnOnce(&Arc) -> R + Send + 'static, - R: Future> + Send, - E: Send + Sync + 'static, - >( + #[instrument(skip_all, fields(prod = options.prod, write = options.write), level = "debug")] + pub async fn download_and_link( &self, graph: &Arc, - refreshed_sources: &Arc>>, - reqwest: &reqwest::Client, - prod: bool, - write: bool, - pesde_cb: F, - ) -> Result< - ( - DownloadAndLinkReceiver, - impl Future>>, - ), - errors::DownloadAndLinkError, - > { - let (tx, rx) = tokio::sync::mpsc::channel( - graph - .iter() - .map(|(_, versions)| versions.len()) - .sum::() - .max(1), - ); - let downloaded_graph = Arc::new(StdMutex::new(DownloadedGraph::default())); + options: DownloadAndLinkOptions, + ) -> Result> + where + Reporter: for<'a> DownloadsReporter<'a> + 'static, + Hooks: DownloadAndLinkHooks + 'static, + { + let DownloadAndLinkOptions { + reqwest, + reporter, + hooks, + refreshed_sources, + prod, + write, + network_concurrency, + } = options; - let this = self.clone(); let graph = graph.clone(); let reqwest = reqwest.clone(); - let refreshed_sources = refreshed_sources.clone(); - Ok(( - rx, - tokio::spawn(async move { - let mut refreshed_sources = refreshed_sources.lock().await; + let mut refreshed_sources = refreshed_sources.lock().await; + let mut downloaded_graph = DownloadedGraph::new(); - // step 1. download pesde dependencies - let (mut pesde_rx, pesde_graph) = this - .download_graph(&graph, &mut refreshed_sources, &reqwest, prod, write, false) - .instrument(tracing::debug_span!("download (pesde)")) - .await?; + let mut download_graph_options = DownloadGraphOptions::::new(reqwest.clone()) + .prod(prod) + .write(write) + .network_concurrency(network_concurrency); - while let Some(result) = pesde_rx.recv().await { - tx.send(result).await.unwrap(); - } + if let Some(reporter) = reporter { + download_graph_options = download_graph_options.reporter(reporter.clone()); + } - let pesde_graph = Arc::into_inner(pesde_graph).unwrap().into_inner().unwrap(); + // step 1. download pesde dependencies + self.download_graph( + &graph, + &mut refreshed_sources, + download_graph_options.clone(), + ) + .instrument(tracing::debug_span!("download (pesde)")) + .await? + .try_for_each(|(downloaded_node, name, version_id)| { + downloaded_graph + .entry(name) + .or_default() + .insert(version_id, downloaded_node); - // step 2. link pesde dependencies. do so without types - if write { - this.link_dependencies(&filter_graph(&pesde_graph, prod), false) - .instrument(tracing::debug_span!("link (pesde)")) - .await?; - } + future::ready(Ok(())) + }) + .await?; - let pesde_graph = Arc::new(pesde_graph); + // step 2. link pesde dependencies. do so without types + if write { + self.link_dependencies(&filter_graph(&downloaded_graph, prod), false) + .instrument(tracing::debug_span!("link (pesde)")) + .await?; + } - pesde_cb(&pesde_graph) - .await - .map_err(errors::DownloadAndLinkError::PesdeCallback)?; + if let Some(ref hooks) = hooks { + hooks + .on_scripts_downloaded(&downloaded_graph) + .await + .map_err(errors::DownloadAndLinkError::Hook)?; - let pesde_graph = Arc::into_inner(pesde_graph).unwrap(); + hooks + .on_bins_downloaded(&downloaded_graph) + .await + .map_err(errors::DownloadAndLinkError::Hook)?; + } - // step 3. download wally dependencies - let (mut wally_rx, wally_graph) = this - .download_graph(&graph, &mut refreshed_sources, &reqwest, prod, write, true) - .instrument(tracing::debug_span!("download (wally)")) - .await?; + // step 3. download wally dependencies + self.download_graph( + &graph, + &mut refreshed_sources, + download_graph_options.clone().wally(true), + ) + .instrument(tracing::debug_span!("download (wally)")) + .await? + .try_for_each(|(downloaded_node, name, version_id)| { + downloaded_graph + .entry(name) + .or_default() + .insert(version_id, downloaded_node); - while let Some(result) = wally_rx.recv().await { - tx.send(result).await.unwrap(); - } + future::ready(Ok(())) + }) + .await?; - let wally_graph = Arc::into_inner(wally_graph).unwrap().into_inner().unwrap(); + // step 4. link ALL dependencies. do so with types + if write { + self.link_dependencies(&filter_graph(&downloaded_graph, prod), true) + .instrument(tracing::debug_span!("link (all)")) + .await?; + } - { - let mut downloaded_graph = downloaded_graph.lock().unwrap(); - downloaded_graph.extend(pesde_graph); - for (name, versions) in wally_graph { - for (version_id, node) in versions { - downloaded_graph - .entry(name.clone()) - .or_default() - .insert(version_id, node); - } - } - } + if let Some(ref hooks) = hooks { + hooks + .on_all_downloaded(&downloaded_graph) + .await + .map_err(errors::DownloadAndLinkError::Hook)?; + } - let graph = Arc::into_inner(downloaded_graph) - .unwrap() - .into_inner() - .unwrap(); - - // step 4. link ALL dependencies. do so with types - if write { - this.link_dependencies(&filter_graph(&graph, prod), true) - .instrument(tracing::debug_span!("link (all)")) - .await?; - } - - Ok(graph) - }) - .map(|r| r.unwrap()), - )) + Ok(downloaded_graph) } } @@ -170,7 +298,7 @@ pub mod errors { Linking(#[from] crate::linking::errors::LinkingError), /// An error occurred while executing the pesde callback - #[error("error executing pesde callback")] - PesdeCallback(#[source] E), + #[error("error executing hook")] + Hook(#[source] E), } } diff --git a/src/lib.rs b/src/lib.rs index a8fb864..f21b1b4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -35,6 +35,7 @@ pub mod names; /// Patching packages #[cfg(feature = "patches")] pub mod patches; +pub mod reporters; /// Resolving packages pub mod resolver; /// Running scripts @@ -182,9 +183,9 @@ impl Project { #[instrument(skip(self, lockfile), level = "debug")] pub async fn write_lockfile( &self, - lockfile: Lockfile, + lockfile: &Lockfile, ) -> Result<(), errors::LockfileWriteError> { - let string = toml::to_string(&lockfile)?; + let string = toml::to_string(lockfile)?; fs::write(self.package_dir.join(LOCKFILE_FILE_NAME), string).await?; Ok(()) } diff --git a/src/main.rs b/src/main.rs index e4bdc06..b3ab90b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,16 +4,18 @@ use crate::cli::{auth::get_tokens, display_err, home_dir, HOME_DIR}; use anyhow::Context; use clap::{builder::styling::AnsiColor, Parser}; use fs_err::tokio as fs; +use indicatif::MultiProgress; use pesde::{matching_globs, AuthConfig, Project, MANIFEST_FILE_NAME}; use std::{ collections::HashSet, + io, path::{Path, PathBuf}, + sync::Mutex, }; use tempfile::NamedTempFile; use tracing::instrument; -use tracing_indicatif::{filter::IndicatifFilter, IndicatifLayer}; use tracing_subscriber::{ - filter::LevelFilter, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, Layer, + filter::LevelFilter, fmt::MakeWriter, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, }; mod cli; @@ -88,6 +90,50 @@ async fn get_linkable_dir(path: &Path) -> PathBuf { ); } +pub static PROGRESS_BARS: Mutex> = Mutex::new(None); + +#[derive(Clone, Copy)] +pub struct IndicatifWriter; + +impl IndicatifWriter { + fn suspend R, R>(f: F) -> R { + match *PROGRESS_BARS.lock().unwrap() { + Some(ref progress_bars) => progress_bars.suspend(f), + None => f(), + } + } +} + +impl io::Write for IndicatifWriter { + fn write(&mut self, buf: &[u8]) -> io::Result { + Self::suspend(|| io::stderr().write(buf)) + } + + fn flush(&mut self) -> io::Result<()> { + Self::suspend(|| io::stderr().flush()) + } + + fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result { + Self::suspend(|| io::stderr().write_vectored(bufs)) + } + + fn write_all(&mut self, buf: &[u8]) -> io::Result<()> { + Self::suspend(|| io::stderr().write_all(buf)) + } + + fn write_fmt(&mut self, fmt: std::fmt::Arguments<'_>) -> io::Result<()> { + Self::suspend(|| io::stderr().write_fmt(fmt)) + } +} + +impl<'a> MakeWriter<'a> for IndicatifWriter { + type Writer = IndicatifWriter; + + fn make_writer(&'a self) -> Self::Writer { + *self + } +} + async fn run() -> anyhow::Result<()> { let cwd = std::env::current_dir().expect("failed to get current working directory"); @@ -133,8 +179,6 @@ async fn run() -> anyhow::Result<()> { std::process::exit(status.code().unwrap()); } - let indicatif_layer = IndicatifLayer::new().with_filter(IndicatifFilter::new(false)); - let tracing_env_filter = EnvFilter::builder() .with_default_directive(LevelFilter::INFO.into()) .from_env_lossy() @@ -146,8 +190,7 @@ async fn run() -> anyhow::Result<()> { .add_directive("hyper=info".parse().unwrap()) .add_directive("h2=info".parse().unwrap()); - let fmt_layer = - tracing_subscriber::fmt::layer().with_writer(indicatif_layer.inner().get_stderr_writer()); + let fmt_layer = tracing_subscriber::fmt::layer().with_writer(IndicatifWriter); #[cfg(debug_assertions)] let fmt_layer = fmt_layer.with_timer(tracing_subscriber::fmt::time::uptime()); @@ -163,7 +206,6 @@ async fn run() -> anyhow::Result<()> { tracing_subscriber::registry() .with(tracing_env_filter) .with(fmt_layer) - .with(indicatif_layer) .init(); let (project_root_dir, project_workspace_dir) = 'finder: { diff --git a/src/manifest/mod.rs b/src/manifest/mod.rs index 336f5f6..638c008 100644 --- a/src/manifest/mod.rs +++ b/src/manifest/mod.rs @@ -94,7 +94,7 @@ pub struct Manifest { } /// A dependency type -#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] #[serde(rename_all = "snake_case")] pub enum DependencyType { /// A standard dependency diff --git a/src/patches.rs b/src/patches.rs index 13dd8f6..5953bc1 100644 --- a/src/patches.rs +++ b/src/patches.rs @@ -1,8 +1,14 @@ -use crate::{lockfile::DownloadedGraph, Project, MANIFEST_FILE_NAME, PACKAGES_CONTAINER_NAME}; +use crate::{ + lockfile::DownloadedGraph, + reporters::{PatchProgressReporter, PatchesReporter}, + Project, MANIFEST_FILE_NAME, PACKAGES_CONTAINER_NAME, +}; use fs_err::tokio as fs; +use futures::TryFutureExt; use git2::{ApplyLocation, Diff, DiffFormat, DiffLineType, Repository, Signature}; use relative_path::RelativePathBuf; -use std::path::Path; +use std::{path::Path, sync::Arc}; +use tokio::task::JoinSet; use tracing::instrument; /// Set up a git repository for patches @@ -70,28 +76,21 @@ pub fn create_patch>(dir: P) -> Result, git2::Error> { impl Project { /// Apply patches to the project's dependencies - #[instrument(skip(self, graph), level = "debug")] - pub async fn apply_patches( + #[instrument(skip(self, graph, reporter), level = "debug")] + pub async fn apply_patches( &self, graph: &DownloadedGraph, - ) -> Result< - tokio::sync::mpsc::Receiver>, - errors::ApplyPatchesError, - > { + reporter: Arc, + ) -> Result<(), errors::ApplyPatchesError> + where + Reporter: for<'a> PatchesReporter<'a> + Send + Sync + 'static, + { let manifest = self.deser_manifest().await?; - let (tx, rx) = tokio::sync::mpsc::channel( - manifest - .patches - .values() - .map(|v| v.len()) - .sum::() - .max(1), - ); + + let mut tasks = JoinSet::>::new(); for (name, versions) in manifest.patches { for (version_id, patch_path) in versions { - let tx = tx.clone(); - let name = name.clone(); let patch_path = patch_path.to_path(self.package_dir()); @@ -102,7 +101,6 @@ impl Project { tracing::warn!( "patch for {name}@{version_id} not applied because it is not in the graph" ); - tx.send(Ok(format!("{name}@{version_id}"))).await.unwrap(); continue; }; @@ -115,41 +113,23 @@ impl Project { version_id.version(), ); - tokio::spawn(async move { + let reporter = reporter.clone(); + + tasks.spawn(async move { tracing::debug!("applying patch to {name}@{version_id}"); - let patch = match fs::read(&patch_path).await { - Ok(patch) => patch, - Err(e) => { - tx.send(Err(errors::ApplyPatchesError::PatchRead(e))) - .await - .unwrap(); - return; - } - }; + let display_name = format!("{name}@{version_id}"); + let progress_reporter = reporter.report_patch(&display_name); - let patch = match Diff::from_buffer(&patch) { - Ok(patch) => patch, - Err(e) => { - tx.send(Err(errors::ApplyPatchesError::Git(e))) - .await - .unwrap(); - return; - } - }; + let patch = fs::read(&patch_path) + .await + .map_err(errors::ApplyPatchesError::PatchRead)?; + let patch = Diff::from_buffer(&patch)?; { - let repo = match setup_patches_repo(&container_folder) { - Ok(repo) => repo, - Err(e) => { - tx.send(Err(errors::ApplyPatchesError::Git(e))) - .await - .unwrap(); - return; - } - }; + let repo = setup_patches_repo(&container_folder)?; - let modified_files = patch + let mut apply_delta_tasks = patch .deltas() .filter(|delta| matches!(delta.status(), git2::Delta::Modified)) .filter_map(|delta| delta.new_file().path()) @@ -159,61 +139,45 @@ impl Project { .to_path(&container_folder) }) .filter(|path| path.is_file()) - .collect::>(); - - for path in modified_files { - // there is no way (as far as I know) to check if it's hardlinked - // so, we always unlink it - let content = match fs::read(&path).await { - Ok(content) => content, - Err(e) => { - tx.send(Err(errors::ApplyPatchesError::File(e))) - .await - .unwrap(); - return; + .map(|path| { + async { + // so, we always unlink it + let content = fs::read(&path).await?; + fs::remove_file(&path).await?; + fs::write(path, content).await?; + Ok(()) } - }; + .map_err(errors::ApplyPatchesError::File) + }) + .collect::>(); - if let Err(e) = fs::remove_file(&path).await { - tx.send(Err(errors::ApplyPatchesError::File(e))) - .await - .unwrap(); - return; - } - - if let Err(e) = fs::write(path, content).await { - tx.send(Err(errors::ApplyPatchesError::File(e))) - .await - .unwrap(); - return; - } + while let Some(res) = apply_delta_tasks.join_next().await { + res.unwrap()?; } - if let Err(e) = repo.apply(&patch, ApplyLocation::Both, None) { - tx.send(Err(errors::ApplyPatchesError::Git(e))) - .await - .unwrap(); - return; - } + repo.apply(&patch, ApplyLocation::Both, None)?; } tracing::debug!( "patch applied to {name}@{version_id}, removing .git directory" ); - if let Err(e) = fs::remove_dir_all(container_folder.join(".git")).await { - tx.send(Err(errors::ApplyPatchesError::DotGitRemove(e))) - .await - .unwrap(); - return; - } + fs::remove_dir_all(container_folder.join(".git")) + .await + .map_err(errors::ApplyPatchesError::DotGitRemove)?; - tx.send(Ok(format!("{name}@{version_id}"))).await.unwrap(); + progress_reporter.report_done(); + + Ok(()) }); } } - Ok(rx) + while let Some(res) = tasks.join_next().await { + res.unwrap()? + } + + Ok(()) } } diff --git a/src/reporters.rs b/src/reporters.rs new file mode 100644 index 0000000..9849895 --- /dev/null +++ b/src/reporters.rs @@ -0,0 +1,63 @@ +//! Progress reporting +//! +//! Certain operations will ask for a progress reporter to be passed in, this +//! allows the caller to be notified of progress during the operation. This can +//! be used to show progress to the user. +//! +//! All reporter traits are implemented for `()`. These implementations do +//! nothing, and can be used to ignore progress reporting. + +#![allow(unused_variables)] + +/// Reports downloads. +pub trait DownloadsReporter<'a>: Send + Sync { + /// The [`DownloadProgressReporter`] type associated with this reporter. + type DownloadProgressReporter: DownloadProgressReporter + 'a; + + /// Starts a new download. + fn report_download<'b>(&'a self, name: &'b str) -> Self::DownloadProgressReporter; +} + +impl DownloadsReporter<'_> for () { + type DownloadProgressReporter = (); + fn report_download(&self, name: &str) -> Self::DownloadProgressReporter {} +} + +/// Reports the progress of a single download. +pub trait DownloadProgressReporter: Send + Sync { + /// Reports that the download has started. + fn report_start(&self) {} + + /// Reports the progress of the download. + /// + /// `total` is the total number of bytes to download, and `len` is the number + /// of bytes downloaded so far. + fn report_progress(&self, total: u64, len: u64) {} + + /// Reports that the download is done. + fn report_done(&self) {} +} + +impl DownloadProgressReporter for () {} + +/// Reports the progress of applying patches. +pub trait PatchesReporter<'a>: Send + Sync { + /// The [`PatchProgressReporter`] type associated with this reporter. + type PatchProgressReporter: PatchProgressReporter + 'a; + + /// Starts a new patch. + fn report_patch<'b>(&'a self, name: &'b str) -> Self::PatchProgressReporter; +} + +impl PatchesReporter<'_> for () { + type PatchProgressReporter = (); + fn report_patch(&self, name: &str) -> Self::PatchProgressReporter {} +} + +/// Reports the progress of a single patch. +pub trait PatchProgressReporter: Send + Sync { + /// Reports that the patch has been applied. + fn report_done(&self) {} +} + +impl PatchProgressReporter for () {} diff --git a/src/source/git/mod.rs b/src/source/git/mod.rs index 149ad93..fc729d2 100644 --- a/src/source/git/mod.rs +++ b/src/source/git/mod.rs @@ -4,6 +4,7 @@ use crate::{ Manifest, }, names::PackageNames, + reporters::DownloadProgressReporter, source::{ fs::{store_in_cas, FSEntry, PackageFS}, git::{pkg_ref::GitPackageRef, specifier::GitDependencySpecifier}, @@ -338,6 +339,7 @@ impl PackageSource for GitPackageSource { pkg_ref: &Self::Ref, project: &Project, _reqwest: &reqwest::Client, + _reporter: Arc, ) -> Result<(PackageFS, Target), Self::DownloadError> { let index_file = project .cas_dir diff --git a/src/source/mod.rs b/src/source/mod.rs index e02dd10..db1aa2c 100644 --- a/src/source/mod.rs +++ b/src/source/mod.rs @@ -1,6 +1,7 @@ use crate::{ manifest::target::{Target, TargetKind}, names::PackageNames, + reporters::DownloadProgressReporter, source::{ fs::PackageFS, refs::PackageRefs, specifiers::DependencySpecifiers, traits::*, version_id::VersionId, @@ -10,6 +11,7 @@ use crate::{ use std::{ collections::{BTreeMap, HashSet}, fmt::Debug, + sync::Arc, }; /// Packages' filesystems @@ -152,26 +154,27 @@ impl PackageSource for PackageSources { pkg_ref: &Self::Ref, project: &Project, reqwest: &reqwest::Client, + reporter: Arc, ) -> Result<(PackageFS, Target), Self::DownloadError> { match (self, pkg_ref) { (PackageSources::Pesde(source), PackageRefs::Pesde(pkg_ref)) => source - .download(pkg_ref, project, reqwest) + .download(pkg_ref, project, reqwest, reporter) .await .map_err(Into::into), #[cfg(feature = "wally-compat")] (PackageSources::Wally(source), PackageRefs::Wally(pkg_ref)) => source - .download(pkg_ref, project, reqwest) + .download(pkg_ref, project, reqwest, reporter) .await .map_err(Into::into), (PackageSources::Git(source), PackageRefs::Git(pkg_ref)) => source - .download(pkg_ref, project, reqwest) + .download(pkg_ref, project, reqwest, reporter) .await .map_err(Into::into), (PackageSources::Workspace(source), PackageRefs::Workspace(pkg_ref)) => source - .download(pkg_ref, project, reqwest) + .download(pkg_ref, project, reqwest, reporter) .await .map_err(Into::into), diff --git a/src/source/pesde/mod.rs b/src/source/pesde/mod.rs index 746eef6..30a30a7 100644 --- a/src/source/pesde/mod.rs +++ b/src/source/pesde/mod.rs @@ -7,7 +7,9 @@ use std::{ fmt::Debug, hash::Hash, path::PathBuf, + sync::Arc, }; +use tokio_util::io::StreamReader; use pkg_ref::PesdePackageRef; use specifier::PesdeDependencySpecifier; @@ -18,6 +20,7 @@ use crate::{ DependencyType, }, names::{PackageName, PackageNames}, + reporters::DownloadProgressReporter, source::{ fs::{store_in_cas, FSEntry, PackageFS}, git_index::{read_file, root_tree, GitBasedSource}, @@ -165,6 +168,7 @@ impl PackageSource for PesdePackageSource { pkg_ref: &Self::Ref, project: &Project, reqwest: &reqwest::Client, + reporter: Arc, ) -> Result<(PackageFS, Target), Self::DownloadError> { let config = self.config(project).await.map_err(Box::new)?; let index_file = project @@ -202,9 +206,26 @@ impl PackageSource for PesdePackageSource { } let response = request.send().await?.error_for_status()?; - let bytes = response.bytes().await?; - let mut decoder = async_compression::tokio::bufread::GzipDecoder::new(bytes.as_ref()); + let total_len = response.content_length().unwrap_or(0); + reporter.report_progress(total_len, 0); + + let mut bytes_downloaded = 0; + let bytes = response + .bytes_stream() + .inspect(|chunk| { + chunk.as_ref().ok().inspect(|chunk| { + bytes_downloaded += chunk.len() as u64; + reporter.report_progress(total_len, bytes_downloaded); + }); + }) + .map(|result| { + result.map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err)) + }); + + let bytes = StreamReader::new(bytes); + + let mut decoder = async_compression::tokio::bufread::GzipDecoder::new(bytes); let mut archive = tokio_tar::Archive::new(&mut decoder); let mut entries = BTreeMap::new(); @@ -254,6 +275,8 @@ impl PackageSource for PesdePackageSource { .await .map_err(errors::DownloadError::WriteIndex)?; + reporter.report_done(); + Ok((fs, pkg_ref.target.clone())) } } diff --git a/src/source/traits.rs b/src/source/traits.rs index 53fd066..c3777d7 100644 --- a/src/source/traits.rs +++ b/src/source/traits.rs @@ -4,12 +4,14 @@ use crate::{ target::{Target, TargetKind}, DependencyType, }, + reporters::DownloadProgressReporter, source::{DependencySpecifiers, PackageFS, PackageSources, ResolveResult}, Project, }; use std::{ collections::{BTreeMap, HashSet}, fmt::{Debug, Display}, + sync::Arc, }; /// A specifier for a dependency @@ -58,5 +60,6 @@ pub trait PackageSource: Debug { pkg_ref: &Self::Ref, project: &Project, reqwest: &reqwest::Client, + reporter: Arc, ) -> Result<(PackageFS, Target), Self::DownloadError>; } diff --git a/src/source/wally/mod.rs b/src/source/wally/mod.rs index 50df173..6ffd7b2 100644 --- a/src/source/wally/mod.rs +++ b/src/source/wally/mod.rs @@ -1,6 +1,7 @@ use crate::{ manifest::target::{Target, TargetKind}, names::PackageNames, + reporters::DownloadProgressReporter, source::{ fs::{store_in_cas, FSEntry, PackageFS}, git_index::{read_file, root_tree, GitBasedSource}, @@ -17,7 +18,7 @@ use crate::{ Project, }; use fs_err::tokio as fs; -use futures::future::try_join_all; +use futures::{future::try_join_all, StreamExt}; use gix::Url; use relative_path::RelativePathBuf; use reqwest::header::AUTHORIZATION; @@ -28,8 +29,12 @@ use std::{ sync::Arc, }; use tempfile::tempdir; -use tokio::{io::AsyncWriteExt, sync::Mutex, task::spawn_blocking}; -use tokio_util::compat::FuturesAsyncReadCompatExt; +use tokio::{ + io::{AsyncReadExt, AsyncWriteExt}, + sync::Mutex, + task::spawn_blocking, +}; +use tokio_util::{compat::FuturesAsyncReadCompatExt, io::StreamReader}; use tracing::instrument; pub(crate) mod compat_util; @@ -202,6 +207,7 @@ impl PackageSource for WallyPackageSource { pkg_ref: &Self::Ref, project: &Project, reqwest: &reqwest::Client, + reporter: Arc, ) -> Result<(PackageFS, Target), Self::DownloadError> { let config = self.config(project).await.map_err(Box::new)?; let index_file = project @@ -250,12 +256,30 @@ impl PackageSource for WallyPackageSource { } let response = request.send().await?.error_for_status()?; - let mut bytes = response.bytes().await?; - let archive = async_zip::tokio::read::seek::ZipFileReader::with_tokio( - std::io::Cursor::new(&mut bytes), - ) - .await?; + let total_len = response.content_length().unwrap_or(0); + reporter.report_progress(total_len, 0); + + let mut bytes_downloaded = 0; + let bytes = response + .bytes_stream() + .inspect(|chunk| { + chunk.as_ref().ok().inspect(|chunk| { + bytes_downloaded += chunk.len() as u64; + reporter.report_progress(total_len, bytes_downloaded); + }); + }) + .map(|result| { + result.map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err)) + }); + + let mut bytes = StreamReader::new(bytes); + let mut buf = vec![]; + bytes.read_to_end(&mut buf).await?; + + let archive = + async_zip::tokio::read::seek::ZipFileReader::with_tokio(std::io::Cursor::new(&mut buf)) + .await?; let entries = (0..archive.file().entries().len()) .map(|index| { @@ -328,6 +352,8 @@ impl PackageSource for WallyPackageSource { .await .map_err(errors::DownloadError::WriteIndex)?; + reporter.report_done(); + Ok((fs, get_target(project, &tempdir).await?)) } } diff --git a/src/source/workspace/mod.rs b/src/source/workspace/mod.rs index 36e75f9..8953476 100644 --- a/src/source/workspace/mod.rs +++ b/src/source/workspace/mod.rs @@ -1,6 +1,7 @@ use crate::{ manifest::target::{Target, TargetKind}, names::PackageNames, + reporters::DownloadProgressReporter, source::{ fs::PackageFS, specifiers::DependencySpecifiers, traits::PackageSource, version_id::VersionId, workspace::pkg_ref::WorkspacePackageRef, PackageSources, @@ -11,7 +12,10 @@ use crate::{ use futures::StreamExt; use relative_path::RelativePathBuf; use reqwest::Client; -use std::collections::{BTreeMap, HashSet}; +use std::{ + collections::{BTreeMap, HashSet}, + sync::Arc, +}; use tokio::pin; use tracing::instrument; @@ -134,6 +138,7 @@ impl PackageSource for WorkspacePackageSource { pkg_ref: &Self::Ref, project: &Project, _reqwest: &Client, + _reporter: Arc, ) -> Result<(PackageFS, Target), Self::DownloadError> { let path = pkg_ref.path.to_path(project.workspace_dir.clone().unwrap());