From d4371519c2a14a229b881849e1b5f4eaace72853 Mon Sep 17 00:00:00 2001 From: daimond113 <72147841+daimond113@users.noreply.github.com> Date: Tue, 23 Jul 2024 01:20:50 +0200 Subject: [PATCH] feat: multithreaded dependency downloading --- Cargo.lock | 1 - Cargo.toml | 1 - src/cli/install.rs | 50 +++++++++++++++++++++++++++++----- src/cli/mod.rs | 5 ++-- src/cli/publish.rs | 7 +++++ src/download.rs | 59 ++++++++++++++++++++++++++++++++--------- src/lib.rs | 14 +--------- src/main.rs | 22 ++++++++++----- src/source/mod.rs | 4 ++- src/source/pesde/mod.rs | 8 +++--- 10 files changed, 124 insertions(+), 47 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 3601b9e..8916480 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2644,7 +2644,6 @@ dependencies = [ "inquire", "keyring", "log", - "once_cell", "open", "pathdiff", "pretty_env_logger", diff --git a/Cargo.toml b/Cargo.toml index ec07c29..40ae1d1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -42,7 +42,6 @@ threadpool = "1.8.1" full_moon = { version = "1.0.0-rc.5", features = ["luau"] } url = { version = "2.5.2", features = ["serde"] } cfg-if = "1.0.0" -once_cell = "1.19.0" # TODO: reevaluate whether to use this # secrecy = "0.8.0" chrono = { version = "0.4.38", features = ["serde"] } diff --git a/src/cli/install.rs b/src/cli/install.rs index f788fd6..a8a39de 100644 --- a/src/cli/install.rs +++ b/src/cli/install.rs @@ -1,14 +1,19 @@ -use crate::cli::IsUpToDate; +use crate::cli::{reqwest_client, IsUpToDate}; use anyhow::Context; use clap::Args; +use indicatif::MultiProgress; use pesde::{lockfile::Lockfile, Project}; -use std::collections::HashSet; +use std::{collections::HashSet, sync::Arc, time::Duration}; #[derive(Debug, Args)] -pub struct InstallCommand {} +pub struct InstallCommand { + /// The amount of threads to use for downloading, defaults to 6 + #[arg(short, long)] + threads: Option, +} impl InstallCommand { - pub fn run(self, project: Project) -> anyhow::Result<()> { + pub fn run(self, project: Project, multi: MultiProgress) -> anyhow::Result<()> { let mut refreshed_sources = HashSet::new(); let manifest = project @@ -51,10 +56,43 @@ impl InstallCommand { let graph = project .dependency_graph(old_graph.as_ref(), &mut refreshed_sources) .context("failed to build dependency graph")?; - let downloaded_graph = project - .download_graph(&graph, &mut refreshed_sources) + + let bar = multi.add( + indicatif::ProgressBar::new(graph.values().map(|versions| versions.len() as u64).sum()) + .with_style( + indicatif::ProgressStyle::default_bar().template( + "{msg} {bar:40.208/166} {pos}/{len} {percent}% {elapsed_precise}", + )?, + ) + .with_message("downloading dependencies"), + ); + bar.enable_steady_tick(Duration::from_millis(100)); + + let (rx, downloaded_graph) = project + .download_graph( + &graph, + &mut refreshed_sources, + &reqwest_client(project.data_dir())?, + self.threads.unwrap_or(6).max(1), + ) .context("failed to download dependencies")?; + while let Ok(result) = rx.recv() { + bar.inc(1); + + match result { + Ok(()) => {} + Err(e) => return Err(e.into()), + } + } + + bar.finish_with_message("finished downloading dependencies"); + + let downloaded_graph = Arc::into_inner(downloaded_graph) + .unwrap() + .into_inner() + .unwrap(); + project .link_dependencies(&downloaded_graph) .context("failed to link dependencies")?; diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 3304187..c21bdaf 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -1,6 +1,7 @@ use crate::util::authenticate_conn; use anyhow::Context; use gix::remote::Direction; +use indicatif::MultiProgress; use keyring::Entry; use pesde::Project; use serde::{Deserialize, Serialize}; @@ -306,13 +307,13 @@ pub enum Subcommand { } impl Subcommand { - pub fn run(self, project: Project) -> anyhow::Result<()> { + pub fn run(self, project: Project, multi: MultiProgress) -> anyhow::Result<()> { match self { Subcommand::Auth(auth) => auth.run(project), Subcommand::Config(config) => config.run(project), Subcommand::Init(init) => init.run(project), Subcommand::Run(run) => run.run(project), - Subcommand::Install(install) => install.run(project), + Subcommand::Install(install) => install.run(project, multi), Subcommand::Publish(publish) => publish.run(project), Subcommand::SelfInstall(self_install) => self_install.run(project), } diff --git a/src/cli/publish.rs b/src/cli/publish.rs index 160315f..c9097c6 100644 --- a/src/cli/publish.rs +++ b/src/cli/publish.rs @@ -59,6 +59,13 @@ impl PublishCommand { ); } + if manifest.includes.remove(".git") { + println!( + "{}: .git was in includes, removing it", + "warn".yellow().bold() + ); + } + for (name, path) in [("lib path", lib_path), ("bin path", bin_path)] { let Some(export_path) = path else { continue }; diff --git a/src/download.rs b/src/download.rs index 943e474..79648d7 100644 --- a/src/download.rs +++ b/src/download.rs @@ -1,6 +1,7 @@ use std::{ - collections::{BTreeMap, HashSet}, + collections::HashSet, fs::create_dir_all, + sync::{mpsc::Receiver, Arc, Mutex}, }; use crate::{ @@ -9,16 +10,26 @@ use crate::{ Project, PACKAGES_CONTAINER_NAME, }; +type MultithreadedGraph = Arc>; + +type MultithreadDownloadJob = ( + Receiver>, + MultithreadedGraph, +); + impl Project { - // TODO: use threadpool for concurrent downloads pub fn download_graph( &self, graph: &DependencyGraph, refreshed_sources: &mut HashSet, - ) -> Result { + reqwest: &reqwest::blocking::Client, + threads: usize, + ) -> Result { let manifest = self.deser_manifest()?; + let downloaded_graph: MultithreadedGraph = Arc::new(Mutex::new(Default::default())); - let mut downloaded_graph: DownloadedGraph = BTreeMap::new(); + let threadpool = threadpool::ThreadPool::new(threads); + let (tx, rx) = std::sync::mpsc::channel(); for (name, versions) in graph { for (version_id, node) in versions { @@ -43,19 +54,41 @@ impl Project { create_dir_all(&container_folder)?; - let target = source.download(&node.pkg_ref, &container_folder, self)?; + let tx = tx.clone(); - downloaded_graph.entry(name.clone()).or_default().insert( - version_id.clone(), - DownloadedDependencyGraphNode { - node: node.clone(), - target, - }, - ); + let name = name.clone(); + let version_id = version_id.clone(); + let node = node.clone(); + + let project = Arc::new(self.clone()); + let reqwest = reqwest.clone(); + let downloaded_graph = downloaded_graph.clone(); + + threadpool.execute(move || { + let project = project.clone(); + + let target = + match source.download(&node.pkg_ref, &container_folder, &project, &reqwest) + { + Ok(target) => target, + Err(e) => { + tx.send(Err(e.into())).unwrap(); + return; + } + }; + + let mut downloaded_graph = downloaded_graph.lock().unwrap(); + downloaded_graph + .entry(name) + .or_default() + .insert(version_id, DownloadedDependencyGraphNode { node, target }); + + tx.send(Ok(())).unwrap(); + }); } } - Ok(downloaded_graph) + Ok((rx, downloaded_graph)) } } diff --git a/src/lib.rs b/src/lib.rs index 9f54061..09fb938 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,7 +4,6 @@ compile_error!("at least one of the features `roblox`, `lune`, or `luau` must be enabled"); use crate::lockfile::Lockfile; -use once_cell::sync::Lazy; use std::path::{Path, PathBuf}; pub mod download; @@ -23,17 +22,6 @@ pub const DEFAULT_INDEX_NAME: &str = "default"; pub const PACKAGES_CONTAINER_NAME: &str = ".pesde"; pub const MAX_ARCHIVE_SIZE: usize = 4 * 1024 * 1024; -pub(crate) static REQWEST_CLIENT: Lazy = Lazy::new(|| { - reqwest::blocking::Client::builder() - .user_agent(concat!( - env!("CARGO_PKG_NAME"), - "/", - env!("CARGO_PKG_VERSION") - )) - .build() - .expect("failed to create reqwest client") -}); - #[derive(Debug, Default, Clone)] pub struct AuthConfig { pesde_token: Option, @@ -67,7 +55,7 @@ impl AuthConfig { } } -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Project { path: PathBuf, data_dir: PathBuf, diff --git a/src/main.rs b/src/main.rs index df11597..b3fa408 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,8 @@ use crate::cli::get_token; use clap::Parser; use colored::Colorize; +use indicatif::MultiProgress; +use indicatif_log_bridge::LogWrapper; use pesde::{AuthConfig, Project}; use std::fs::create_dir_all; @@ -20,7 +22,16 @@ struct Cli { } fn main() { - pretty_env_logger::init(); + let multi = { + let logger = pretty_env_logger::formatted_builder() + .parse_env(pretty_env_logger::env_logger::Env::default().default_filter_or("info")) + .build(); + let multi = MultiProgress::new(); + + LogWrapper::new(multi.clone(), logger).try_init().unwrap(); + + multi + }; let project_dirs = directories::ProjectDirs::from("com", env!("CARGO_PKG_NAME"), env!("CARGO_BIN_NAME")) @@ -32,11 +43,10 @@ fn main() { create_dir_all(data_dir).expect("failed to create data directory"); if let Err(err) = get_token(data_dir).and_then(|token| { - cli.subcommand.run(Project::new( - cwd, - data_dir, - AuthConfig::new().with_pesde_token(token), - )) + cli.subcommand.run( + Project::new(cwd, data_dir, AuthConfig::new().with_pesde_token(token)), + multi, + ) }) { eprintln!("{}: {err}\n", "error".red().bold()); diff --git a/src/source/mod.rs b/src/source/mod.rs index 63e8063..21ae138 100644 --- a/src/source/mod.rs +++ b/src/source/mod.rs @@ -134,6 +134,7 @@ pub trait PackageSource: Debug { pkg_ref: &Self::Ref, destination: &Path, project: &Project, + reqwest: &reqwest::blocking::Client, ) -> Result; } impl PackageSource for PackageSources { @@ -178,10 +179,11 @@ impl PackageSource for PackageSources { pkg_ref: &Self::Ref, destination: &Path, project: &Project, + reqwest: &reqwest::blocking::Client, ) -> Result { match (self, pkg_ref) { (PackageSources::Pesde(source), PackageRefs::Pesde(pkg_ref)) => source - .download(pkg_ref, destination, project) + .download(pkg_ref, destination, project, reqwest) .map_err(Into::into), _ => Err(errors::DownloadError::Mismatch), diff --git a/src/source/pesde/mod.rs b/src/source/pesde/mod.rs index df851b0..00dfd4d 100644 --- a/src/source/pesde/mod.rs +++ b/src/source/pesde/mod.rs @@ -6,13 +6,12 @@ use serde::{Deserialize, Serialize}; use pkg_ref::PesdePackageRef; use specifier::PesdeDependencySpecifier; -use crate::manifest::TargetKind; use crate::{ - manifest::{DependencyType, Target}, + manifest::{DependencyType, Target, TargetKind}, names::{PackageName, PackageNames}, source::{hash, DependencySpecifiers, PackageSource, ResolveResult, VersionId}, util::authenticate_conn, - Project, REQWEST_CLIENT, + Project, }; pub mod pkg_ref; @@ -345,6 +344,7 @@ impl PackageSource for PesdePackageSource { pkg_ref: &Self::Ref, destination: &Path, project: &Project, + reqwest: &reqwest::blocking::Client, ) -> Result { let config = self.config(project)?; @@ -355,7 +355,7 @@ impl PackageSource for PesdePackageSource { .replace("{PACKAGE_NAME}", name) .replace("{PACKAGE_VERSION}", &pkg_ref.version.to_string()); - let mut response = REQWEST_CLIENT.get(url); + let mut response = reqwest.get(url); if let Some(token) = &project.auth_config.pesde_token { response = response.header("Authorization", format!("Bearer {token}"));