feat: multithreaded dependency downloading

This commit is contained in:
daimond113 2024-07-23 01:20:50 +02:00
parent 2898b02e1c
commit d4371519c2
No known key found for this signature in database
GPG key ID: 3A8ECE51328B513C
10 changed files with 124 additions and 47 deletions

1
Cargo.lock generated
View file

@ -2644,7 +2644,6 @@ dependencies = [
"inquire", "inquire",
"keyring", "keyring",
"log", "log",
"once_cell",
"open", "open",
"pathdiff", "pathdiff",
"pretty_env_logger", "pretty_env_logger",

View file

@ -42,7 +42,6 @@ threadpool = "1.8.1"
full_moon = { version = "1.0.0-rc.5", features = ["luau"] } full_moon = { version = "1.0.0-rc.5", features = ["luau"] }
url = { version = "2.5.2", features = ["serde"] } url = { version = "2.5.2", features = ["serde"] }
cfg-if = "1.0.0" cfg-if = "1.0.0"
once_cell = "1.19.0"
# TODO: reevaluate whether to use this # TODO: reevaluate whether to use this
# secrecy = "0.8.0" # secrecy = "0.8.0"
chrono = { version = "0.4.38", features = ["serde"] } chrono = { version = "0.4.38", features = ["serde"] }

View file

@ -1,14 +1,19 @@
use crate::cli::IsUpToDate; use crate::cli::{reqwest_client, IsUpToDate};
use anyhow::Context; use anyhow::Context;
use clap::Args; use clap::Args;
use indicatif::MultiProgress;
use pesde::{lockfile::Lockfile, Project}; use pesde::{lockfile::Lockfile, Project};
use std::collections::HashSet; use std::{collections::HashSet, sync::Arc, time::Duration};
#[derive(Debug, Args)] #[derive(Debug, Args)]
pub struct InstallCommand {} pub struct InstallCommand {
/// The amount of threads to use for downloading, defaults to 6
#[arg(short, long)]
threads: Option<usize>,
}
impl InstallCommand { impl InstallCommand {
pub fn run(self, project: Project) -> anyhow::Result<()> { pub fn run(self, project: Project, multi: MultiProgress) -> anyhow::Result<()> {
let mut refreshed_sources = HashSet::new(); let mut refreshed_sources = HashSet::new();
let manifest = project let manifest = project
@ -51,10 +56,43 @@ impl InstallCommand {
let graph = project let graph = project
.dependency_graph(old_graph.as_ref(), &mut refreshed_sources) .dependency_graph(old_graph.as_ref(), &mut refreshed_sources)
.context("failed to build dependency graph")?; .context("failed to build dependency graph")?;
let downloaded_graph = project
.download_graph(&graph, &mut refreshed_sources) let bar = multi.add(
indicatif::ProgressBar::new(graph.values().map(|versions| versions.len() as u64).sum())
.with_style(
indicatif::ProgressStyle::default_bar().template(
"{msg} {bar:40.208/166} {pos}/{len} {percent}% {elapsed_precise}",
)?,
)
.with_message("downloading dependencies"),
);
bar.enable_steady_tick(Duration::from_millis(100));
let (rx, downloaded_graph) = project
.download_graph(
&graph,
&mut refreshed_sources,
&reqwest_client(project.data_dir())?,
self.threads.unwrap_or(6).max(1),
)
.context("failed to download dependencies")?; .context("failed to download dependencies")?;
while let Ok(result) = rx.recv() {
bar.inc(1);
match result {
Ok(()) => {}
Err(e) => return Err(e.into()),
}
}
bar.finish_with_message("finished downloading dependencies");
let downloaded_graph = Arc::into_inner(downloaded_graph)
.unwrap()
.into_inner()
.unwrap();
project project
.link_dependencies(&downloaded_graph) .link_dependencies(&downloaded_graph)
.context("failed to link dependencies")?; .context("failed to link dependencies")?;

View file

@ -1,6 +1,7 @@
use crate::util::authenticate_conn; use crate::util::authenticate_conn;
use anyhow::Context; use anyhow::Context;
use gix::remote::Direction; use gix::remote::Direction;
use indicatif::MultiProgress;
use keyring::Entry; use keyring::Entry;
use pesde::Project; use pesde::Project;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -306,13 +307,13 @@ pub enum Subcommand {
} }
impl Subcommand { impl Subcommand {
pub fn run(self, project: Project) -> anyhow::Result<()> { pub fn run(self, project: Project, multi: MultiProgress) -> anyhow::Result<()> {
match self { match self {
Subcommand::Auth(auth) => auth.run(project), Subcommand::Auth(auth) => auth.run(project),
Subcommand::Config(config) => config.run(project), Subcommand::Config(config) => config.run(project),
Subcommand::Init(init) => init.run(project), Subcommand::Init(init) => init.run(project),
Subcommand::Run(run) => run.run(project), Subcommand::Run(run) => run.run(project),
Subcommand::Install(install) => install.run(project), Subcommand::Install(install) => install.run(project, multi),
Subcommand::Publish(publish) => publish.run(project), Subcommand::Publish(publish) => publish.run(project),
Subcommand::SelfInstall(self_install) => self_install.run(project), Subcommand::SelfInstall(self_install) => self_install.run(project),
} }

View file

@ -59,6 +59,13 @@ impl PublishCommand {
); );
} }
if manifest.includes.remove(".git") {
println!(
"{}: .git was in includes, removing it",
"warn".yellow().bold()
);
}
for (name, path) in [("lib path", lib_path), ("bin path", bin_path)] { for (name, path) in [("lib path", lib_path), ("bin path", bin_path)] {
let Some(export_path) = path else { continue }; let Some(export_path) = path else { continue };

View file

@ -1,6 +1,7 @@
use std::{ use std::{
collections::{BTreeMap, HashSet}, collections::HashSet,
fs::create_dir_all, fs::create_dir_all,
sync::{mpsc::Receiver, Arc, Mutex},
}; };
use crate::{ use crate::{
@ -9,16 +10,26 @@ use crate::{
Project, PACKAGES_CONTAINER_NAME, Project, PACKAGES_CONTAINER_NAME,
}; };
type MultithreadedGraph = Arc<Mutex<DownloadedGraph>>;
type MultithreadDownloadJob = (
Receiver<Result<(), errors::DownloadGraphError>>,
MultithreadedGraph,
);
impl Project { impl Project {
// TODO: use threadpool for concurrent downloads
pub fn download_graph( pub fn download_graph(
&self, &self,
graph: &DependencyGraph, graph: &DependencyGraph,
refreshed_sources: &mut HashSet<PackageSources>, refreshed_sources: &mut HashSet<PackageSources>,
) -> Result<DownloadedGraph, errors::DownloadGraphError> { reqwest: &reqwest::blocking::Client,
threads: usize,
) -> Result<MultithreadDownloadJob, errors::DownloadGraphError> {
let manifest = self.deser_manifest()?; let manifest = self.deser_manifest()?;
let downloaded_graph: MultithreadedGraph = Arc::new(Mutex::new(Default::default()));
let mut downloaded_graph: DownloadedGraph = BTreeMap::new(); let threadpool = threadpool::ThreadPool::new(threads);
let (tx, rx) = std::sync::mpsc::channel();
for (name, versions) in graph { for (name, versions) in graph {
for (version_id, node) in versions { for (version_id, node) in versions {
@ -43,19 +54,41 @@ impl Project {
create_dir_all(&container_folder)?; create_dir_all(&container_folder)?;
let target = source.download(&node.pkg_ref, &container_folder, self)?; let tx = tx.clone();
downloaded_graph.entry(name.clone()).or_default().insert( let name = name.clone();
version_id.clone(), let version_id = version_id.clone();
DownloadedDependencyGraphNode { let node = node.clone();
node: node.clone(),
target, let project = Arc::new(self.clone());
}, let reqwest = reqwest.clone();
); let downloaded_graph = downloaded_graph.clone();
threadpool.execute(move || {
let project = project.clone();
let target =
match source.download(&node.pkg_ref, &container_folder, &project, &reqwest)
{
Ok(target) => target,
Err(e) => {
tx.send(Err(e.into())).unwrap();
return;
}
};
let mut downloaded_graph = downloaded_graph.lock().unwrap();
downloaded_graph
.entry(name)
.or_default()
.insert(version_id, DownloadedDependencyGraphNode { node, target });
tx.send(Ok(())).unwrap();
});
} }
} }
Ok(downloaded_graph) Ok((rx, downloaded_graph))
} }
} }

View file

@ -4,7 +4,6 @@
compile_error!("at least one of the features `roblox`, `lune`, or `luau` must be enabled"); compile_error!("at least one of the features `roblox`, `lune`, or `luau` must be enabled");
use crate::lockfile::Lockfile; use crate::lockfile::Lockfile;
use once_cell::sync::Lazy;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
pub mod download; pub mod download;
@ -23,17 +22,6 @@ pub const DEFAULT_INDEX_NAME: &str = "default";
pub const PACKAGES_CONTAINER_NAME: &str = ".pesde"; pub const PACKAGES_CONTAINER_NAME: &str = ".pesde";
pub const MAX_ARCHIVE_SIZE: usize = 4 * 1024 * 1024; pub const MAX_ARCHIVE_SIZE: usize = 4 * 1024 * 1024;
pub(crate) static REQWEST_CLIENT: Lazy<reqwest::blocking::Client> = Lazy::new(|| {
reqwest::blocking::Client::builder()
.user_agent(concat!(
env!("CARGO_PKG_NAME"),
"/",
env!("CARGO_PKG_VERSION")
))
.build()
.expect("failed to create reqwest client")
});
#[derive(Debug, Default, Clone)] #[derive(Debug, Default, Clone)]
pub struct AuthConfig { pub struct AuthConfig {
pesde_token: Option<String>, pesde_token: Option<String>,
@ -67,7 +55,7 @@ impl AuthConfig {
} }
} }
#[derive(Debug)] #[derive(Debug, Clone)]
pub struct Project { pub struct Project {
path: PathBuf, path: PathBuf,
data_dir: PathBuf, data_dir: PathBuf,

View file

@ -1,6 +1,8 @@
use crate::cli::get_token; use crate::cli::get_token;
use clap::Parser; use clap::Parser;
use colored::Colorize; use colored::Colorize;
use indicatif::MultiProgress;
use indicatif_log_bridge::LogWrapper;
use pesde::{AuthConfig, Project}; use pesde::{AuthConfig, Project};
use std::fs::create_dir_all; use std::fs::create_dir_all;
@ -20,7 +22,16 @@ struct Cli {
} }
fn main() { fn main() {
pretty_env_logger::init(); let multi = {
let logger = pretty_env_logger::formatted_builder()
.parse_env(pretty_env_logger::env_logger::Env::default().default_filter_or("info"))
.build();
let multi = MultiProgress::new();
LogWrapper::new(multi.clone(), logger).try_init().unwrap();
multi
};
let project_dirs = let project_dirs =
directories::ProjectDirs::from("com", env!("CARGO_PKG_NAME"), env!("CARGO_BIN_NAME")) directories::ProjectDirs::from("com", env!("CARGO_PKG_NAME"), env!("CARGO_BIN_NAME"))
@ -32,11 +43,10 @@ fn main() {
create_dir_all(data_dir).expect("failed to create data directory"); create_dir_all(data_dir).expect("failed to create data directory");
if let Err(err) = get_token(data_dir).and_then(|token| { if let Err(err) = get_token(data_dir).and_then(|token| {
cli.subcommand.run(Project::new( cli.subcommand.run(
cwd, Project::new(cwd, data_dir, AuthConfig::new().with_pesde_token(token)),
data_dir, multi,
AuthConfig::new().with_pesde_token(token), )
))
}) { }) {
eprintln!("{}: {err}\n", "error".red().bold()); eprintln!("{}: {err}\n", "error".red().bold());

View file

@ -134,6 +134,7 @@ pub trait PackageSource: Debug {
pkg_ref: &Self::Ref, pkg_ref: &Self::Ref,
destination: &Path, destination: &Path,
project: &Project, project: &Project,
reqwest: &reqwest::blocking::Client,
) -> Result<Target, Self::DownloadError>; ) -> Result<Target, Self::DownloadError>;
} }
impl PackageSource for PackageSources { impl PackageSource for PackageSources {
@ -178,10 +179,11 @@ impl PackageSource for PackageSources {
pkg_ref: &Self::Ref, pkg_ref: &Self::Ref,
destination: &Path, destination: &Path,
project: &Project, project: &Project,
reqwest: &reqwest::blocking::Client,
) -> Result<Target, Self::DownloadError> { ) -> Result<Target, Self::DownloadError> {
match (self, pkg_ref) { match (self, pkg_ref) {
(PackageSources::Pesde(source), PackageRefs::Pesde(pkg_ref)) => source (PackageSources::Pesde(source), PackageRefs::Pesde(pkg_ref)) => source
.download(pkg_ref, destination, project) .download(pkg_ref, destination, project, reqwest)
.map_err(Into::into), .map_err(Into::into),
_ => Err(errors::DownloadError::Mismatch), _ => Err(errors::DownloadError::Mismatch),

View file

@ -6,13 +6,12 @@ use serde::{Deserialize, Serialize};
use pkg_ref::PesdePackageRef; use pkg_ref::PesdePackageRef;
use specifier::PesdeDependencySpecifier; use specifier::PesdeDependencySpecifier;
use crate::manifest::TargetKind;
use crate::{ use crate::{
manifest::{DependencyType, Target}, manifest::{DependencyType, Target, TargetKind},
names::{PackageName, PackageNames}, names::{PackageName, PackageNames},
source::{hash, DependencySpecifiers, PackageSource, ResolveResult, VersionId}, source::{hash, DependencySpecifiers, PackageSource, ResolveResult, VersionId},
util::authenticate_conn, util::authenticate_conn,
Project, REQWEST_CLIENT, Project,
}; };
pub mod pkg_ref; pub mod pkg_ref;
@ -345,6 +344,7 @@ impl PackageSource for PesdePackageSource {
pkg_ref: &Self::Ref, pkg_ref: &Self::Ref,
destination: &Path, destination: &Path,
project: &Project, project: &Project,
reqwest: &reqwest::blocking::Client,
) -> Result<Target, Self::DownloadError> { ) -> Result<Target, Self::DownloadError> {
let config = self.config(project)?; let config = self.config(project)?;
@ -355,7 +355,7 @@ impl PackageSource for PesdePackageSource {
.replace("{PACKAGE_NAME}", name) .replace("{PACKAGE_NAME}", name)
.replace("{PACKAGE_VERSION}", &pkg_ref.version.to_string()); .replace("{PACKAGE_VERSION}", &pkg_ref.version.to_string());
let mut response = REQWEST_CLIENT.get(url); let mut response = reqwest.get(url);
if let Some(token) = &project.auth_config.pesde_token { if let Some(token) = &project.auth_config.pesde_token {
response = response.header("Authorization", format!("Bearer {token}")); response = response.header("Authorization", format!("Bearer {token}"));