From fc349e6f21ee803675857ec683c6ceb963fe13dc Mon Sep 17 00:00:00 2001 From: daimond113 <72147841+daimond113@users.noreply.github.com> Date: Sun, 12 Jan 2025 23:12:27 +0100 Subject: [PATCH] feat: add engines Adds the initial implementation of the engines feature. Not tested yet. Requires documentation and more work for non-pesde engines to be usable. --- Cargo.toml | 7 +- registry/src/endpoints/search.rs | 6 +- src/cli/commands/self_install.rs | 10 +- src/cli/commands/self_upgrade.rs | 19 +- src/cli/version.rs | 328 +++++++++---------------- src/engine/mod.rs | 63 +++++ src/engine/source/archive.rs | 319 ++++++++++++++++++++++++ src/engine/source/github/engine_ref.rs | 19 ++ src/engine/source/github/mod.rs | 137 +++++++++++ src/engine/source/mod.rs | 144 +++++++++++ src/engine/source/traits.rs | 54 ++++ src/lib.rs | 33 ++- src/main.rs | 101 +++++--- src/manifest/mod.rs | 10 +- src/reporters.rs | 34 +++ src/source/pesde/mod.rs | 26 +- src/source/wally/mod.rs | 26 +- 17 files changed, 1014 insertions(+), 322 deletions(-) create mode 100644 src/engine/mod.rs create mode 100644 src/engine/source/archive.rs create mode 100644 src/engine/source/github/engine_ref.rs create mode 100644 src/engine/source/github/mod.rs create mode 100644 src/engine/source/mod.rs create mode 100644 src/engine/source/traits.rs diff --git a/Cargo.toml b/Cargo.toml index 5f5c616..c27420b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,6 @@ bin = [ "dep:clap", "dep:dirs", "dep:tracing-subscriber", - "reqwest/json", "dep:indicatif", "dep:inquire", "dep:toml_edit", @@ -30,7 +29,7 @@ bin = [ "tokio/rt-multi-thread", "tokio/macros", ] -wally-compat = ["dep:async_zip", "dep:serde_json"] +wally-compat = ["dep:serde_json"] patches = ["dep:git2"] version-management = ["bin"] schema = ["dep:schemars"] @@ -49,7 +48,7 @@ toml = "0.8.19" serde_with = "3.11.0" gix = { version = "0.68.0", default-features = false, features = ["blocking-http-transport-reqwest-rust-tls", "revparse-regex", "credentials", "parallel"] } semver = { version = "1.0.24", features = ["serde"] } -reqwest = { version = "0.12.9", default-features = false, features = ["rustls-tls", "stream"] } +reqwest = { version = "0.12.9", default-features = false, features = ["rustls-tls", "stream", "json"] } tokio-tar = "0.3.1" async-compression = { version = "0.4.18", features = ["tokio", "gzip"] } pathdiff = "0.2.3" @@ -68,11 +67,11 @@ tempfile = "3.14.0" wax = { version = "0.6.0", default-features = false } fs-err = { version = "3.0.0", features = ["tokio"] } urlencoding = "2.1.3" +async_zip = { version = "0.0.17", features = ["tokio", "deflate", "deflate64", "tokio-fs"] } # TODO: remove this when gitoxide adds support for: committing, pushing, adding git2 = { version = "0.19.0", optional = true } -async_zip = { version = "0.0.17", features = ["tokio", "deflate", "deflate64", "tokio-fs"], optional = true } serde_json = { version = "1.0.133", optional = true } schemars = { git = "https://github.com/daimond113/schemars", rev = "bc7c7d6", features = ["semver1", "url2"], optional = true } diff --git a/registry/src/endpoints/search.rs b/registry/src/endpoints/search.rs index 3f23049..9efcb5c 100644 --- a/registry/src/endpoints/search.rs +++ b/registry/src/endpoints/search.rs @@ -50,8 +50,10 @@ pub async fn search_packages( let source = Arc::new(app_state.source.clone().read_owned().await); - let mut results = Vec::with_capacity(top_docs.len()); - results.extend((0..top_docs.len()).map(|_| None::)); + let mut results = top_docs + .iter() + .map(|_| None::) + .collect::>(); let mut tasks = top_docs .into_iter() diff --git a/src/cli/commands/self_install.rs b/src/cli/commands/self_install.rs index 3763d71..6a2c597 100644 --- a/src/cli/commands/self_install.rs +++ b/src/cli/commands/self_install.rs @@ -1,8 +1,10 @@ -use crate::cli::{version::update_bin_exe, HOME_DIR}; +use crate::cli::{version::replace_bin_exe, HOME_DIR}; use anyhow::Context; use clap::Args; use colored::Colorize; +use pesde::engine::EngineKind; use std::env::current_exe; + #[derive(Debug, Args)] pub struct SelfInstallCommand { /// Skip adding the bin directory to the PATH @@ -70,7 +72,11 @@ and then restart your shell. ); } - update_bin_exe(¤t_exe().context("failed to get current exe path")?).await?; + replace_bin_exe( + EngineKind::Pesde, + ¤t_exe().context("failed to get current exe path")?, + ) + .await?; Ok(()) } diff --git a/src/cli/commands/self_upgrade.rs b/src/cli/commands/self_upgrade.rs index 2759fff..f355e82 100644 --- a/src/cli/commands/self_upgrade.rs +++ b/src/cli/commands/self_upgrade.rs @@ -1,13 +1,15 @@ use crate::cli::{ config::read_config, version::{ - current_version, get_or_download_version, get_remote_version, no_build_metadata, - update_bin_exe, TagInfo, VersionType, + current_version, find_latest_version, get_or_download_engine, no_build_metadata, + replace_bin_exe, }, }; use anyhow::Context; use clap::Args; use colored::Colorize; +use pesde::engine::EngineKind; +use semver::VersionReq; #[derive(Debug, Args)] pub struct SelfUpgradeCommand { @@ -25,7 +27,7 @@ impl SelfUpgradeCommand { .context("no cached version found")? .1 } else { - get_remote_version(&reqwest, VersionType::Latest).await? + find_latest_version(&reqwest).await? }; let latest_version_no_metadata = no_build_metadata(&latest_version); @@ -46,10 +48,13 @@ impl SelfUpgradeCommand { return Ok(()); } - let path = get_or_download_version(&reqwest, TagInfo::Complete(latest_version), true) - .await? - .unwrap(); - update_bin_exe(&path).await?; + let path = get_or_download_engine( + &reqwest, + EngineKind::Pesde, + VersionReq::parse(&format!("={latest_version}")).unwrap(), + ) + .await?; + replace_bin_exe(EngineKind::Pesde, &path).await?; println!("upgraded to version {display_latest_version}!"); diff --git a/src/cli/version.rs b/src/cli/version.rs index c88d9df..eda55f4 100644 --- a/src/cli/version.rs +++ b/src/cli/version.rs @@ -7,83 +7,28 @@ use crate::cli::{ use anyhow::Context; use colored::Colorize; use fs_err::tokio as fs; -use futures::StreamExt; -use reqwest::header::ACCEPT; -use semver::Version; -use serde::Deserialize; -use std::{ - env::current_exe, - path::{Path, PathBuf}, +use pesde::{ + engine::{ + source::{ + traits::{DownloadOptions, EngineSource, ResolveOptions}, + EngineSources, + }, + EngineKind, + }, + version_matches, +}; +use semver::{Version, VersionReq}; +use std::{ + collections::BTreeSet, + path::{Path, PathBuf}, + sync::Arc, }; -use tokio::io::AsyncWrite; use tracing::instrument; pub fn current_version() -> Version { Version::parse(env!("CARGO_PKG_VERSION")).unwrap() } -#[derive(Debug, Deserialize)] -struct Release { - tag_name: String, - assets: Vec, -} - -#[derive(Debug, Deserialize)] -struct Asset { - name: String, - url: url::Url, -} - -#[instrument(level = "trace")] -fn get_repo() -> (String, String) { - let mut parts = env!("CARGO_PKG_REPOSITORY").split('/').skip(3); - let (owner, repo) = ( - parts.next().unwrap().to_string(), - parts.next().unwrap().to_string(), - ); - - tracing::trace!("repository for updates: {owner}/{repo}"); - - (owner, repo) -} - -#[derive(Debug)] -pub enum VersionType { - Latest, - Specific(Version), -} - -#[instrument(skip(reqwest), level = "trace")] -pub async fn get_remote_version( - reqwest: &reqwest::Client, - ty: VersionType, -) -> anyhow::Result { - let (owner, repo) = get_repo(); - - let mut releases = reqwest - .get(format!( - "https://api.github.com/repos/{owner}/{repo}/releases", - )) - .send() - .await - .context("failed to send request to GitHub API")? - .error_for_status() - .context("failed to get GitHub API response")? - .json::>() - .await - .context("failed to parse GitHub API response")? - .into_iter() - .filter_map(|release| Version::parse(release.tag_name.trim_start_matches('v')).ok()); - - match ty { - VersionType::Latest => releases.max(), - VersionType::Specific(version) => { - releases.find(|v| no_build_metadata(v) == no_build_metadata(&version)) - } - } - .context("failed to find latest version") -} - pub fn no_build_metadata(version: &Version) -> Version { let mut version = version.clone(); version.build = semver::BuildMetadata::EMPTY; @@ -92,6 +37,23 @@ pub fn no_build_metadata(version: &Version) -> Version { const CHECK_INTERVAL: chrono::Duration = chrono::Duration::hours(6); +pub async fn find_latest_version(reqwest: &reqwest::Client) -> anyhow::Result { + let version = EngineSources::pesde() + .resolve( + &VersionReq::STAR, + &ResolveOptions { + reqwest: reqwest.clone(), + }, + ) + .await + .context("failed to resolve version")? + .pop_last() + .context("no versions found")? + .0; + + Ok(version) +} + #[instrument(skip(reqwest), level = "trace")] pub async fn check_for_updates(reqwest: &reqwest::Client) -> anyhow::Result<()> { let config = read_config().await?; @@ -104,7 +66,7 @@ pub async fn check_for_updates(reqwest: &reqwest::Client) -> anyhow::Result<()> version } else { tracing::debug!("checking for updates"); - let version = get_remote_version(reqwest, VersionType::Latest).await?; + let version = find_latest_version(reqwest).await?; write_config(&CliConfig { last_checked_updates: Some((chrono::Utc::now(), version.clone())), @@ -180,154 +142,105 @@ pub async fn check_for_updates(reqwest: &reqwest::Client) -> anyhow::Result<()> Ok(()) } -#[instrument(skip(reqwest, writer), level = "trace")] -pub async fn download_github_release( - reqwest: &reqwest::Client, - version: &Version, - mut writer: W, -) -> anyhow::Result<()> { - let (owner, repo) = get_repo(); - - let release = reqwest - .get(format!( - "https://api.github.com/repos/{owner}/{repo}/releases/tags/v{version}", - )) - .send() - .await - .context("failed to send request to GitHub API")? - .error_for_status() - .context("failed to get GitHub API response")? - .json::() - .await - .context("failed to parse GitHub API response")?; - - let asset = release - .assets - .into_iter() - .find(|asset| { - asset.name.ends_with(&format!( - "-{}-{}.tar.gz", - std::env::consts::OS, - std::env::consts::ARCH - )) - }) - .context("failed to find asset for current platform")?; - - let bytes = reqwest - .get(asset.url) - .header(ACCEPT, "application/octet-stream") - .send() - .await - .context("failed to send request to download asset")? - .error_for_status() - .context("failed to download asset")? - .bytes() - .await - .context("failed to download asset")?; - - let mut decoder = async_compression::tokio::bufread::GzipDecoder::new(bytes.as_ref()); - let mut archive = tokio_tar::Archive::new(&mut decoder); - - let mut entry = archive - .entries() - .context("failed to read archive entries")? - .next() - .await - .context("archive has no entry")? - .context("failed to get first archive entry")?; - - tokio::io::copy(&mut entry, &mut writer) - .await - .context("failed to write archive entry to file") - .map(|_| ()) -} - -#[derive(Debug)] -pub enum TagInfo { - Complete(Version), - Incomplete(Version), -} - #[instrument(skip(reqwest), level = "trace")] -pub async fn get_or_download_version( +pub async fn get_or_download_engine( reqwest: &reqwest::Client, - tag: TagInfo, - always_give_path: bool, -) -> anyhow::Result> { - let path = home_dir()?.join("versions"); + engine: EngineKind, + req: VersionReq, +) -> anyhow::Result { + let source = engine.source(); + + let path = home_dir()?.join("engines").join(source.directory()); fs::create_dir_all(&path) .await - .context("failed to create versions directory")?; + .context("failed to create engines directory")?; - let version = match &tag { - TagInfo::Complete(version) => version, - // don't fetch the version since it could be cached - TagInfo::Incomplete(version) => version, - }; + let mut read_dir = fs::read_dir(&path) + .await + .context("failed to read engines directory")?; - let path = path.join(format!( - "{}{}", - no_build_metadata(version), - std::env::consts::EXE_SUFFIX - )); + let mut matching_versions = BTreeSet::new(); - let is_requested_version = !always_give_path && *version == current_version(); + while let Some(entry) = read_dir.next_entry().await? { + let path = entry.path(); - if path.exists() { - tracing::debug!("version already exists"); + #[cfg(windows)] + let version = path.file_stem(); + #[cfg(not(windows))] + let version = path.file_name(); - return Ok(if is_requested_version { - None - } else { - Some(path) - }); - } - - if is_requested_version { - tracing::debug!("copying current executable to version directory"); - fs::copy(current_exe()?, &path) - .await - .context("failed to copy current executable to version directory")?; - } else { - let version = match tag { - TagInfo::Complete(version) => version, - TagInfo::Incomplete(version) => { - get_remote_version(reqwest, VersionType::Specific(version)) - .await - .context("failed to get remote version")? - } + let Some(version) = version.and_then(|s| s.to_str()) else { + continue; }; - tracing::debug!("downloading version"); - download_github_release( - reqwest, - &version, - fs::File::create(&path) - .await - .context("failed to create version file")?, - ) - .await?; + if let Ok(version) = Version::parse(version) { + if version_matches(&version, &req) { + matching_versions.insert(version); + } + } } + if let Some(version) = matching_versions.pop_last() { + return Ok(path + .join(version.to_string()) + .join(source.expected_file_name()) + .with_extension(std::env::consts::EXE_EXTENSION)); + } + + let mut versions = source + .resolve( + &req, + &ResolveOptions { + reqwest: reqwest.clone(), + }, + ) + .await + .context("failed to resolve versions")?; + let (version, engine_ref) = versions.pop_last().context("no matching versions found")?; + + let path = path + .join(version.to_string()) + .join(source.expected_file_name()) + .with_extension(std::env::consts::EXE_EXTENSION); + + let archive = source + .download( + &engine_ref, + &DownloadOptions { + reqwest: reqwest.clone(), + reporter: Arc::new(()), + version, + }, + ) + .await + .context("failed to download engine")?; + + let mut file = fs::File::create(&path) + .await + .context("failed to create new file")?; + tokio::io::copy( + &mut archive + .find_executable(source.expected_file_name()) + .await + .context("failed to find executable")?, + &mut file, + ) + .await + .context("failed to write to file")?; + make_executable(&path) .await .context("failed to make downloaded version executable")?; - Ok(if is_requested_version { - None - } else { - Some(path) - }) + Ok(path) } #[instrument(level = "trace")] -pub async fn update_bin_exe(downloaded_file: &Path) -> anyhow::Result<()> { - let bin_exe_path = bin_dir().await?.join(format!( - "{}{}", - env!("CARGO_BIN_NAME"), - std::env::consts::EXE_SUFFIX - )); - let mut downloaded_file = downloaded_file.to_path_buf(); +pub async fn replace_bin_exe(engine: EngineKind, with: &Path) -> anyhow::Result<()> { + let bin_exe_path = bin_dir() + .await? + .join(engine.to_string()) + .with_extension(std::env::consts::EXE_EXTENSION); let exists = bin_exe_path.exists(); @@ -339,21 +252,18 @@ pub async fn update_bin_exe(downloaded_file: &Path) -> anyhow::Result<()> { let tempfile = tempfile::Builder::new() .make(|_| Ok(())) .context("failed to create temporary file")?; - let path = tempfile.into_temp_path().to_path_buf(); + let temp_path = tempfile.into_temp_path().to_path_buf(); #[cfg(windows)] - let path = path.with_extension("exe"); + let temp_path = temp_path.with_extension("exe"); - let current_exe = current_exe().context("failed to get current exe path")?; - if current_exe == downloaded_file { - downloaded_file = path.to_path_buf(); + match fs::rename(&bin_exe_path, &temp_path).await { + Ok(_) => {} + Err(e) if e.kind() == std::io::ErrorKind::NotFound => {} + Err(e) => return Err(e).context("failed to rename existing executable"), } - - fs::rename(&bin_exe_path, &path) - .await - .context("failed to rename current executable")?; } - fs::copy(downloaded_file, &bin_exe_path) + fs::copy(with, &bin_exe_path) .await .context("failed to copy executable to bin folder")?; diff --git a/src/engine/mod.rs b/src/engine/mod.rs new file mode 100644 index 0000000..3dd11e1 --- /dev/null +++ b/src/engine/mod.rs @@ -0,0 +1,63 @@ +/// Sources of engines +pub mod source; + +use crate::engine::source::EngineSources; +use serde_with::{DeserializeFromStr, SerializeDisplay}; +use std::{fmt::Display, str::FromStr}; + +/// All supported engines +#[derive( + SerializeDisplay, DeserializeFromStr, Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord, +)] +#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] +#[cfg_attr(feature = "schema", schemars(rename_all = "snake_case"))] +pub enum EngineKind { + /// The pesde package manager + Pesde, + /// The Lune runtime + Lune, +} + +impl Display for EngineKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + EngineKind::Pesde => write!(f, "pesde"), + EngineKind::Lune => write!(f, "lune"), + } + } +} + +impl FromStr for EngineKind { + type Err = errors::EngineKindFromStrError; + + fn from_str(s: &str) -> Result { + match s.to_lowercase().as_str() { + "pesde" => Ok(EngineKind::Pesde), + "lune" => Ok(EngineKind::Lune), + _ => Err(errors::EngineKindFromStrError::Unknown(s.to_string())), + } + } +} + +impl EngineKind { + /// Returns the source to get this engine from + pub fn source(&self) -> EngineSources { + match self { + EngineKind::Pesde => EngineSources::pesde(), + EngineKind::Lune => EngineSources::lune(), + } + } +} + +/// Errors related to engine kinds +pub mod errors { + use thiserror::Error; + + /// Errors which can occur while using the FromStr implementation of EngineKind + #[derive(Debug, Error)] + pub enum EngineKindFromStrError { + /// The string isn't a recognized EngineKind + #[error("unknown engine kind {0}")] + Unknown(String), + } +} diff --git a/src/engine/source/archive.rs b/src/engine/source/archive.rs new file mode 100644 index 0000000..f73147c --- /dev/null +++ b/src/engine/source/archive.rs @@ -0,0 +1,319 @@ +use futures::StreamExt; +use std::{ + collections::BTreeSet, + mem::ManuallyDrop, + path::{Path, PathBuf}, + pin::Pin, + str::FromStr, + task::{Context, Poll}, +}; +use tokio::{ + io::{AsyncBufRead, AsyncRead, AsyncReadExt, ReadBuf}, + pin, +}; +use tokio_util::compat::{Compat, FuturesAsyncReadCompatExt}; + +/// The kind of encoding used for the archive +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum EncodingKind { + /// Gzip + Gzip, +} + +/// The kind of archive +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ArchiveKind { + /// Tar + Tar, + /// Zip + Zip, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub(crate) struct ArchiveInfo(ArchiveKind, Option); + +impl FromStr for ArchiveInfo { + type Err = errors::ArchiveInfoFromStrError; + + fn from_str(s: &str) -> Result { + let parts = s.split('.').collect::>(); + + Ok(match &*parts { + [.., "tar", "gz"] => ArchiveInfo(ArchiveKind::Tar, Some(EncodingKind::Gzip)), + [.., "tar"] => ArchiveInfo(ArchiveKind::Tar, None), + [.., "zip", "gz"] => { + return Err(errors::ArchiveInfoFromStrError::Unsupported( + ArchiveKind::Zip, + Some(EncodingKind::Gzip), + )) + } + [.., "zip"] => ArchiveInfo(ArchiveKind::Zip, None), + _ => return Err(errors::ArchiveInfoFromStrError::Invalid(s.to_string())), + }) + } +} + +/// An archive +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct Archive { + pub(crate) info: ArchiveInfo, + pub(crate) reader: R, +} + +#[derive(Debug)] +enum TarReader { + Gzip(async_compression::tokio::bufread::GzipDecoder), + Plain(R), +} + +// TODO: try to see if we can avoid the unsafe blocks + +impl AsyncRead for TarReader { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + unsafe { + match self.get_unchecked_mut() { + Self::Gzip(r) => Pin::new_unchecked(r).poll_read(cx, buf), + Self::Plain(r) => Pin::new_unchecked(r).poll_read(cx, buf), + } + } + } +} + +enum ArchiveEntryInner { + Tar(tokio_tar::Entry>>>>), + Zip { + archive: *mut async_zip::tokio::read::seek::ZipFileReader>>, + reader: ManuallyDrop< + Compat< + async_zip::tokio::read::ZipEntryReader< + 'static, + std::io::Cursor>, + async_zip::base::read::WithoutEntry, + >, + >, + >, + }, +} + +impl Drop for ArchiveEntryInner { + fn drop(&mut self) { + match self { + Self::Tar(_) => {} + Self::Zip { archive, reader } => unsafe { + ManuallyDrop::drop(reader); + drop(Box::from_raw(*archive)); + }, + } + } +} + +/// An entry in an archive. Usually the executable +pub struct ArchiveEntry(ArchiveEntryInner); + +impl AsyncRead for ArchiveEntry { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + unsafe { + match &mut self.get_unchecked_mut().0 { + ArchiveEntryInner::Tar(r) => Pin::new_unchecked(r).poll_read(cx, buf), + ArchiveEntryInner::Zip { reader, .. } => { + Pin::new_unchecked(&mut **reader).poll_read(cx, buf) + } + } + } + } +} + +impl Archive { + /// Finds the executable in the archive and returns it as an [`ArchiveEntry`] + pub async fn find_executable( + self, + expected_file_name: &str, + ) -> Result, errors::FindExecutableError> { + #[derive(Debug, PartialEq, Eq)] + struct Candidate { + path: PathBuf, + file_name_matches: bool, + extension_matches: bool, + has_permissions: bool, + } + + impl Candidate { + fn new(path: PathBuf, perms: u32, expected_file_name: &str) -> Self { + Self { + file_name_matches: path + .file_name() + .is_some_and(|name| name == expected_file_name), + extension_matches: match path.extension() { + Some(ext) if ext == std::env::consts::EXE_EXTENSION => true, + None if std::env::consts::EXE_EXTENSION.is_empty() => true, + _ => false, + }, + path, + has_permissions: perms & 0o111 != 0, + } + } + + fn should_be_considered(&self) -> bool { + // if nothing matches, we should not consider this candidate as it is most likely not + self.file_name_matches || self.extension_matches || self.has_permissions + } + } + + impl Ord for Candidate { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + self.file_name_matches + .cmp(&other.file_name_matches) + .then(self.extension_matches.cmp(&other.extension_matches)) + .then(self.has_permissions.cmp(&other.has_permissions)) + } + } + + impl PartialOrd for Candidate { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } + } + + let mut candidates = BTreeSet::new(); + + match self.info { + ArchiveInfo(ArchiveKind::Tar, encoding) => { + use async_compression::tokio::bufread as decoders; + + let reader = Box::pin(self.reader); + let reader = match encoding { + Some(EncodingKind::Gzip) => TarReader::Gzip(decoders::GzipDecoder::new(reader)), + None => TarReader::Plain(reader), + }; + + let mut archive = tokio_tar::Archive::new(reader); + let mut entries = archive.entries()?; + + while let Some(entry) = entries.next().await.transpose()? { + if entry.header().entry_type().is_dir() { + continue; + } + + let candidate = Candidate::new( + entry.path()?.to_path_buf(), + entry.header().mode()?, + expected_file_name, + ); + if candidate.should_be_considered() { + candidates.insert(candidate); + } + } + + let Some(candidate) = candidates.pop_last() else { + return Err(errors::FindExecutableError::ExecutableNotFound); + }; + + let mut entries = archive.entries()?; + + while let Some(entry) = entries.next().await.transpose()? { + if entry.header().entry_type().is_dir() { + continue; + } + + let path = entry.path()?; + if path == candidate.path { + return Ok(ArchiveEntry(ArchiveEntryInner::Tar(entry))); + } + } + } + ArchiveInfo(ArchiveKind::Zip, _) => { + let reader = self.reader; + pin!(reader); + + // TODO: would be lovely to not have to read the whole archive into memory + let mut buf = vec![]; + reader.read_to_end(&mut buf).await?; + + let archive = async_zip::base::read::seek::ZipFileReader::with_tokio( + std::io::Cursor::new(buf), + ) + .await?; + for entry in archive.file().entries() { + if entry.dir()? { + continue; + } + + let path: &Path = entry.filename().as_str()?.as_ref(); + let candidate = Candidate::new( + path.to_path_buf(), + entry.unix_permissions().unwrap_or(0) as u32, + expected_file_name, + ); + if candidate.should_be_considered() { + candidates.insert(candidate); + } + } + + let Some(candidate) = candidates.pop_last() else { + return Err(errors::FindExecutableError::ExecutableNotFound); + }; + + for (i, entry) in archive.file().entries().iter().enumerate() { + if entry.dir()? { + continue; + } + + let path: &Path = entry.filename().as_str()?.as_ref(); + if candidate.path == path { + let ptr = Box::into_raw(Box::new(archive)); + let reader = (unsafe { &mut *ptr }).reader_without_entry(i).await?; + return Ok(ArchiveEntry(ArchiveEntryInner::Zip { + archive: ptr, + reader: ManuallyDrop::new(reader.compat()), + })); + } + } + } + } + + Err(errors::FindExecutableError::ExecutableNotFound) + } +} + +/// Errors that can occur when working with archives +pub mod errors { + use thiserror::Error; + + /// Errors that can occur when parsing archive info + #[derive(Debug, Error)] + #[non_exhaustive] + pub enum ArchiveInfoFromStrError { + /// The string is not a valid archive descriptor. E.g. `{name}.tar.gz` + #[error("string `{0}` is not a valid archive descriptor")] + Invalid(String), + + /// The archive type is not supported. E.g. `{name}.zip.gz` + #[error("archive type {0:?} with encoding {1:?} is not supported")] + Unsupported(super::ArchiveKind, Option), + } + + /// Errors that can occur when finding an executable in an archive + #[derive(Debug, Error)] + #[non_exhaustive] + pub enum FindExecutableError { + /// The executable was not found in the archive + #[error("failed to find executable in archive")] + ExecutableNotFound, + + /// An IO error occurred + #[error("IO error")] + Io(#[from] std::io::Error), + + /// An error occurred reading the zip archive + #[error("failed to read zip archive")] + Zip(#[from] async_zip::error::ZipError), + } +} diff --git a/src/engine/source/github/engine_ref.rs b/src/engine/source/github/engine_ref.rs new file mode 100644 index 0000000..8792b02 --- /dev/null +++ b/src/engine/source/github/engine_ref.rs @@ -0,0 +1,19 @@ +use serde::Deserialize; + +/// A GitHub release +#[derive(Debug, Eq, PartialEq, Hash, Clone, Deserialize)] +pub struct Release { + /// The tag name of the release + pub tag_name: String, + /// The assets of the release + pub assets: Vec, +} + +/// An asset of a GitHub release +#[derive(Debug, Eq, PartialEq, Hash, Clone, Deserialize)] +pub struct Asset { + /// The name of the asset + pub name: String, + /// The download URL of the asset + pub url: url::Url, +} diff --git a/src/engine/source/github/mod.rs b/src/engine/source/github/mod.rs new file mode 100644 index 0000000..95a31a2 --- /dev/null +++ b/src/engine/source/github/mod.rs @@ -0,0 +1,137 @@ +/// The GitHub engine reference +pub mod engine_ref; + +use crate::{ + engine::source::{ + archive::Archive, + github::engine_ref::Release, + traits::{DownloadOptions, EngineSource, ResolveOptions}, + }, + reporters::{response_to_async_read, DownloadProgressReporter}, + version_matches, +}; +use reqwest::header::ACCEPT; +use semver::{Version, VersionReq}; +use std::{collections::BTreeMap, path::PathBuf}; +use tokio::io::AsyncBufRead; + +/// The GitHub engine source +#[derive(Debug, Eq, PartialEq, Hash, Clone)] +pub struct GitHubEngineSource { + /// The owner of the repository to download from + pub owner: String, + /// The repository of which to download releases from + pub repo: String, + /// The template for the asset name. `{VERSION}` will be replaced with the version + pub asset_template: String, +} + +impl EngineSource for GitHubEngineSource { + type Ref = Release; + type ResolveError = errors::ResolveError; + type DownloadError = errors::DownloadError; + + fn directory(&self) -> PathBuf { + PathBuf::from("github").join(&self.owner).join(&self.repo) + } + + fn expected_file_name(&self) -> &str { + &self.repo + } + + async fn resolve( + &self, + requirement: &VersionReq, + options: &ResolveOptions, + ) -> Result, Self::ResolveError> { + let ResolveOptions { reqwest, .. } = options; + + Ok(reqwest + .get(format!( + "https://api.github.com/repos/{}/{}/releases", + urlencoding::encode(&self.owner), + urlencoding::encode(&self.repo), + )) + .send() + .await? + .error_for_status()? + .json::>() + .await? + .into_iter() + .filter_map( + |release| match release.tag_name.trim_start_matches('v').parse() { + Ok(version) if version_matches(&version, requirement) => { + Some((version, release)) + } + _ => None, + }, + ) + .collect()) + } + + async fn download( + &self, + engine_ref: &Self::Ref, + options: &DownloadOptions, + ) -> Result, Self::DownloadError> { + let DownloadOptions { + reqwest, + reporter, + version, + .. + } = options; + + let desired_asset_name = self + .asset_template + .replace("{VERSION}", &version.to_string()); + + let asset = engine_ref + .assets + .iter() + .find(|asset| asset.name.eq_ignore_ascii_case(&desired_asset_name)) + .ok_or(errors::DownloadError::AssetNotFound)?; + + let response = reqwest + .get(asset.url.clone()) + .header(ACCEPT, "application/octet-stream") + .send() + .await? + .error_for_status()?; + + Ok(Archive { + info: asset.name.parse()?, + reader: response_to_async_read(response, reporter.clone()), + }) + } +} + +/// Errors that can occur when working with the GitHub engine source +pub mod errors { + use thiserror::Error; + + /// Errors that can occur when resolving a GitHub engine + #[derive(Debug, Error)] + #[non_exhaustive] + pub enum ResolveError { + /// Handling the request failed + #[error("failed to handle GitHub API request")] + Request(#[from] reqwest::Error), + } + + /// Errors that can occur when downloading a GitHub engine + #[derive(Debug, Error)] + #[non_exhaustive] + pub enum DownloadError { + /// An asset for the current platform could not be found + #[error("failed to find asset for current platform")] + AssetNotFound, + + /// Handling the request failed + #[error("failed to handle GitHub API request")] + Request(#[from] reqwest::Error), + + /// The asset's name could not be parsed + #[error("failed to parse asset name")] + ParseAssetName(#[from] crate::engine::source::archive::errors::ArchiveInfoFromStrError), + } +} diff --git a/src/engine/source/mod.rs b/src/engine/source/mod.rs new file mode 100644 index 0000000..5d26501 --- /dev/null +++ b/src/engine/source/mod.rs @@ -0,0 +1,144 @@ +use crate::{ + engine::source::{ + archive::Archive, + traits::{DownloadOptions, EngineSource, ResolveOptions}, + }, + reporters::DownloadProgressReporter, +}; +use semver::{Version, VersionReq}; +use std::{collections::BTreeMap, path::PathBuf}; +use tokio::io::AsyncBufRead; + +/// Archives +pub mod archive; +/// The GitHub engine source +pub mod github; +/// Traits for engine sources +pub mod traits; + +/// Engine references +#[derive(Debug, Eq, PartialEq, Hash, Clone)] +pub enum EngineRefs { + /// A GitHub engine reference + GitHub(github::engine_ref::Release), +} + +/// Engine sources +#[derive(Debug, Eq, PartialEq, Hash, Clone)] +pub enum EngineSources { + /// A GitHub engine source + GitHub(github::GitHubEngineSource), +} + +impl EngineSource for EngineSources { + type Ref = EngineRefs; + type ResolveError = errors::ResolveError; + type DownloadError = errors::DownloadError; + + fn directory(&self) -> PathBuf { + match self { + EngineSources::GitHub(source) => source.directory(), + } + } + + fn expected_file_name(&self) -> &str { + match self { + EngineSources::GitHub(source) => source.expected_file_name(), + } + } + + async fn resolve( + &self, + requirement: &VersionReq, + options: &ResolveOptions, + ) -> Result, Self::ResolveError> { + match self { + EngineSources::GitHub(source) => source + .resolve(requirement, options) + .await + .map(|map| { + map.into_iter() + .map(|(version, release)| (version, EngineRefs::GitHub(release))) + .collect() + }) + .map_err(Into::into), + } + } + + async fn download( + &self, + engine_ref: &Self::Ref, + options: &DownloadOptions, + ) -> Result, Self::DownloadError> { + match (self, engine_ref) { + (EngineSources::GitHub(source), EngineRefs::GitHub(release)) => { + source.download(release, options).await.map_err(Into::into) + } + + // for the future + #[allow(unreachable_patterns)] + _ => Err(errors::DownloadError::Mismatch), + } + } +} + +impl EngineSources { + /// Returns the source for the pesde engine + pub fn pesde() -> Self { + let mut parts = env!("CARGO_PKG_REPOSITORY").split('/').skip(3); + let (owner, repo) = ( + parts.next().unwrap().to_string(), + parts.next().unwrap().to_string(), + ); + + EngineSources::GitHub(github::GitHubEngineSource { + owner, + repo, + asset_template: format!( + "pesde-{{VERSION}}-{}-{}.zip", + std::env::consts::OS, + std::env::consts::ARCH + ), + }) + } + + /// Returns the source for the lune engine + pub fn lune() -> Self { + EngineSources::GitHub(github::GitHubEngineSource { + owner: "lune-org".into(), + repo: "lune".into(), + asset_template: format!( + "lune-{{VERSION}}-{}-{}.zip", + std::env::consts::OS, + std::env::consts::ARCH + ), + }) + } +} + +/// Errors that can occur when working with engine sources +pub mod errors { + use thiserror::Error; + + /// Errors that can occur when resolving an engine + #[derive(Debug, Error)] + #[non_exhaustive] + pub enum ResolveError { + /// Failed to resolve the GitHub engine + #[error("failed to resolve github engine")] + GitHub(#[from] super::github::errors::ResolveError), + } + + /// Errors that can occur when downloading an engine + #[derive(Debug, Error)] + #[non_exhaustive] + pub enum DownloadError { + /// Failed to download the GitHub engine + #[error("failed to download github engine")] + GitHub(#[from] super::github::errors::DownloadError), + + /// Mismatched engine reference + #[error("mismatched engine reference")] + Mismatch, + } +} diff --git a/src/engine/source/traits.rs b/src/engine/source/traits.rs new file mode 100644 index 0000000..6457486 --- /dev/null +++ b/src/engine/source/traits.rs @@ -0,0 +1,54 @@ +use crate::{engine::source::archive::Archive, reporters::DownloadProgressReporter}; +use semver::{Version, VersionReq}; +use std::{collections::BTreeMap, fmt::Debug, future::Future, path::PathBuf, sync::Arc}; +use tokio::io::AsyncBufRead; + +/// Options for resolving an engine +#[derive(Debug, Clone)] +pub struct ResolveOptions { + /// The reqwest client to use + pub reqwest: reqwest::Client, +} + +/// Options for downloading an engine +#[derive(Debug, Clone)] +pub struct DownloadOptions { + /// The reqwest client to use + pub reqwest: reqwest::Client, + /// The reporter to use + pub reporter: Arc, + /// The version of the engine to be downloaded + pub version: Version, +} + +/// A source of engines +pub trait EngineSource: Debug { + /// The reference type for this source + type Ref; + /// The error type for resolving an engine from this source + type ResolveError: std::error::Error + Send + Sync + 'static; + /// The error type for downloading an engine from this source + type DownloadError: std::error::Error + Send + Sync + 'static; + + /// Returns the folder to store the engine's versions in + fn directory(&self) -> PathBuf; + + /// Returns the expected file name of the engine in the archive + fn expected_file_name(&self) -> &str; + + /// Resolves a requirement to a reference + fn resolve( + &self, + requirement: &VersionReq, + options: &ResolveOptions, + ) -> impl Future, Self::ResolveError>> + Send + Sync; + + /// Downloads an engine + fn download( + &self, + engine_ref: &Self::Ref, + options: &DownloadOptions, + ) -> impl Future, Self::DownloadError>> + + Send + + Sync; +} diff --git a/src/lib.rs b/src/lib.rs index 9595a4d..642bbd8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,6 +15,7 @@ use async_stream::try_stream; use fs_err::tokio as fs; use futures::Stream; use gix::sec::identity::Account; +use semver::{Version, VersionReq}; use std::{ collections::{HashMap, HashSet}, fmt::Debug, @@ -29,6 +30,8 @@ use wax::Pattern; pub mod download; /// Utility for downloading and linking in the correct order pub mod download_and_link; +/// Handling of engines +pub mod engine; /// Graphs pub mod graph; /// Linking packages @@ -117,8 +120,8 @@ struct ProjectShared { package_dir: PathBuf, workspace_dir: Option, data_dir: PathBuf, - auth_config: AuthConfig, cas_dir: PathBuf, + auth_config: AuthConfig, } /// The main struct of the pesde library, representing a project @@ -130,11 +133,11 @@ pub struct Project { impl Project { /// Create a new `Project` - pub fn new, Q: AsRef, R: AsRef, S: AsRef>( - package_dir: P, - workspace_dir: Option, - data_dir: R, - cas_dir: S, + pub fn new( + package_dir: impl AsRef, + workspace_dir: Option>, + data_dir: impl AsRef, + cas_dir: impl AsRef, auth_config: AuthConfig, ) -> Self { Project { @@ -142,8 +145,8 @@ impl Project { package_dir: package_dir.as_ref().to_path_buf(), workspace_dir: workspace_dir.map(|d| d.as_ref().to_path_buf()), data_dir: data_dir.as_ref().to_path_buf(), - auth_config, cas_dir: cas_dir.as_ref().to_path_buf(), + auth_config, }), } } @@ -163,16 +166,16 @@ impl Project { &self.shared.data_dir } - /// The authentication configuration - pub fn auth_config(&self) -> &AuthConfig { - &self.shared.auth_config - } - /// The CAS (content-addressable storage) directory pub fn cas_dir(&self) -> &Path { &self.shared.cas_dir } + /// The authentication configuration + pub fn auth_config(&self) -> &AuthConfig { + &self.shared.auth_config + } + /// Read the manifest file #[instrument(skip(self), ret(level = "trace"), level = "debug")] pub async fn read_manifest(&self) -> Result { @@ -425,6 +428,12 @@ pub async fn find_roots( Ok((project_root.unwrap_or(cwd), workspace_dir)) } +/// Returns whether a version matches a version requirement +/// Differs from `VersionReq::matches` in that EVERY version matches `*` +pub fn version_matches(version: &Version, req: &VersionReq) -> bool { + *req == VersionReq::STAR || req.matches(version) +} + /// Errors that can occur when using the pesde library pub mod errors { use std::path::PathBuf; diff --git a/src/main.rs b/src/main.rs index 984147c..7aa2c61 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,14 +1,16 @@ #[cfg(feature = "version-management")] -use crate::cli::version::{check_for_updates, get_or_download_version, TagInfo}; +use crate::cli::version::{check_for_updates, current_version, get_or_download_engine}; use crate::cli::{auth::get_tokens, display_err, home_dir, HOME_DIR}; use anyhow::Context; use clap::{builder::styling::AnsiColor, Parser}; use fs_err::tokio as fs; use indicatif::MultiProgress; -use pesde::{find_roots, AuthConfig, Project}; +use pesde::{engine::EngineKind, find_roots, AuthConfig, Project}; +use semver::VersionReq; use std::{ io, path::{Path, PathBuf}, + str::FromStr, sync::Mutex, }; use tempfile::NamedTempFile; @@ -135,27 +137,30 @@ impl<'a> MakeWriter<'a> for IndicatifWriter { async fn run() -> anyhow::Result<()> { let cwd = std::env::current_dir().expect("failed to get current working directory"); + let current_exe = std::env::current_exe().expect("failed to get current executable path"); + let exe_name = current_exe.file_stem().unwrap(); #[cfg(windows)] 'scripts: { - let exe = std::env::current_exe().expect("failed to get current executable path"); - if exe.parent().is_some_and(|parent| { - parent.file_name().is_some_and(|parent| parent != "bin") - || parent - .parent() - .and_then(|parent| parent.file_name()) - .is_some_and(|parent| parent != HOME_DIR) - }) { + // we're called the same as the binary, so we're not a (legal) script + if exe_name == env!("CARGO_PKG_NAME") { break 'scripts; } - let exe_name = exe.file_name().unwrap().to_string_lossy(); - let exe_name = exe_name - .strip_suffix(std::env::consts::EXE_SUFFIX) - .unwrap_or(&exe_name); + if let Some(bin_folder) = current_exe.parent() { + // we're not in {path}/bin/{exe} + if bin_folder.file_name().is_some_and(|parent| parent != "bin") { + break 'scripts; + } - if exe_name == env!("CARGO_BIN_NAME") { - break 'scripts; + // we're not in {path}/.pesde/bin/{exe} + if bin_folder + .parent() + .and_then(|home_folder| home_folder.file_name()) + .is_some_and(|home_folder| home_folder != HOME_DIR) + { + break 'scripts; + } } // the bin script will search for the project root itself, so we do that to ensure @@ -164,9 +169,11 @@ async fn run() -> anyhow::Result<()> { let status = std::process::Command::new("lune") .arg("run") .arg( - exe.parent() - .map(|p| p.join(".impl").join(exe.file_name().unwrap())) - .unwrap_or(exe) + current_exe + .parent() + .unwrap_or(¤t_exe) + .join(".impl") + .join(current_exe.file_name().unwrap()) .with_extension("luau"), ) .arg("--") @@ -265,34 +272,50 @@ async fn run() -> anyhow::Result<()> { }; #[cfg(feature = "version-management")] - { - let target_version = project + 'engines: { + let Some(engine) = exe_name + .to_str() + .and_then(|str| EngineKind::from_str(str).ok()) + else { + break 'engines; + }; + + let req = project .deser_manifest() .await .ok() - .and_then(|manifest| manifest.pesde_version); + .and_then(|mut manifest| manifest.engines.remove(&engine)); - let exe_path = if let Some(version) = target_version { - get_or_download_version(&reqwest, TagInfo::Incomplete(version), false).await? - } else { - None - }; - - if let Some(exe_path) = exe_path { - let status = std::process::Command::new(exe_path) - .args(std::env::args_os().skip(1)) - .status() - .expect("failed to run new version"); - - std::process::exit(status.code().unwrap()); + if engine == EngineKind::Pesde { + match &req { + // we're already running a compatible version + Some(req) if req.matches(¤t_version()) => break 'engines, + // the user has not requested a specific version, so we'll just use the current one + None => break 'engines, + _ => (), + } } - display_err( - check_for_updates(&reqwest).await, - " while checking for updates", - ); + let exe_path = + get_or_download_engine(&reqwest, engine, req.unwrap_or(VersionReq::STAR)).await?; + if exe_path == current_exe { + break 'engines; + } + + let status = std::process::Command::new(exe_path) + .args(std::env::args_os().skip(1)) + .status() + .expect("failed to run new version"); + + std::process::exit(status.code().unwrap()); } + #[cfg(feature = "version-management")] + display_err( + check_for_updates(&reqwest).await, + " while checking for updates", + ); + let cli = Cli::parse(); cli.subcommand.run(project, reqwest).await diff --git a/src/manifest/mod.rs b/src/manifest/mod.rs index 1e6a671..66d4c4b 100644 --- a/src/manifest/mod.rs +++ b/src/manifest/mod.rs @@ -1,4 +1,5 @@ use crate::{ + engine::EngineKind, manifest::{ overrides::{OverrideKey, OverrideSpecifier}, target::Target, @@ -7,7 +8,7 @@ use crate::{ source::specifiers::DependencySpecifiers, }; use relative_path::RelativePathBuf; -use semver::Version; +use semver::{Version, VersionReq}; use serde::{Deserialize, Serialize}; use std::collections::{BTreeMap, HashMap}; use tracing::instrument; @@ -85,15 +86,16 @@ pub struct Manifest { crate::names::PackageNames, BTreeMap, >, - #[serde(default, skip_serializing)] - /// Which version of the pesde CLI this package uses - pub pesde_version: Option, /// A list of globs pointing to workspace members' directories #[serde(default, skip_serializing_if = "Vec::is_empty")] pub workspace_members: Vec, /// The Roblox place of this project #[serde(default, skip_serializing)] pub place: BTreeMap, + /// The engines this package supports + #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] + #[cfg_attr(feature = "schema", schemars(with = "BTreeMap"))] + pub engines: BTreeMap, /// The standard dependencies of the package #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] diff --git a/src/reporters.rs b/src/reporters.rs index d34dc52..ef87c4e 100644 --- a/src/reporters.rs +++ b/src/reporters.rs @@ -9,6 +9,11 @@ #![allow(unused_variables)] +use async_stream::stream; +use futures::StreamExt; +use std::sync::Arc; +use tokio::io::AsyncBufRead; + /// Reports downloads. pub trait DownloadsReporter<'a>: Send + Sync { /// The [`DownloadProgressReporter`] type associated with this reporter. @@ -61,3 +66,32 @@ pub trait PatchProgressReporter: Send + Sync { } impl PatchProgressReporter for () {} + +pub(crate) fn response_to_async_read( + response: reqwest::Response, + reporter: Arc, +) -> impl AsyncBufRead { + let total_len = response.content_length().unwrap_or(0); + reporter.report_progress(total_len, 0); + + let mut bytes_downloaded = 0; + let mut stream = response.bytes_stream(); + let bytes = stream!({ + while let Some(chunk) = stream.next().await { + let chunk = match chunk { + Ok(chunk) => chunk, + Err(err) => { + yield Err(std::io::Error::new(std::io::ErrorKind::Other, err)); + continue; + } + }; + bytes_downloaded += chunk.len() as u64; + reporter.report_progress(total_len, bytes_downloaded); + yield Ok(chunk); + } + + reporter.report_done(); + }); + + tokio_util::io::StreamReader::new(bytes) +} diff --git a/src/source/pesde/mod.rs b/src/source/pesde/mod.rs index 33a6ced..fe0ebe3 100644 --- a/src/source/pesde/mod.rs +++ b/src/source/pesde/mod.rs @@ -8,7 +8,6 @@ use std::{ hash::Hash, path::PathBuf, }; -use tokio_util::io::StreamReader; use pkg_ref::PesdePackageRef; use specifier::PesdeDependencySpecifier; @@ -16,7 +15,7 @@ use specifier::PesdeDependencySpecifier; use crate::{ manifest::{target::Target, DependencyType}, names::{PackageName, PackageNames}, - reporters::DownloadProgressReporter, + reporters::{response_to_async_read, DownloadProgressReporter}, source::{ fs::{store_in_cas, FsEntry, PackageFs}, git_index::{read_file, root_tree, GitBasedSource}, @@ -28,7 +27,7 @@ use crate::{ }; use fs_err::tokio as fs; use futures::StreamExt; -use tokio::task::spawn_blocking; +use tokio::{pin, task::spawn_blocking}; use tracing::instrument; /// The pesde package reference @@ -229,23 +228,8 @@ impl PackageSource for PesdePackageSource { let response = request.send().await?.error_for_status()?; - let total_len = response.content_length().unwrap_or(0); - reporter.report_progress(total_len, 0); - - let mut bytes_downloaded = 0; - let bytes = response - .bytes_stream() - .inspect(|chunk| { - chunk.as_ref().ok().inspect(|chunk| { - bytes_downloaded += chunk.len() as u64; - reporter.report_progress(total_len, bytes_downloaded); - }); - }) - .map(|result| { - result.map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err)) - }); - - let bytes = StreamReader::new(bytes); + let bytes = response_to_async_read(response, reporter.clone()); + pin!(bytes); let mut decoder = async_compression::tokio::bufread::GzipDecoder::new(bytes); let mut archive = tokio_tar::Archive::new(&mut decoder); @@ -297,8 +281,6 @@ impl PackageSource for PesdePackageSource { .await .map_err(errors::DownloadError::WriteIndex)?; - reporter.report_done(); - Ok(fs) } diff --git a/src/source/wally/mod.rs b/src/source/wally/mod.rs index 49d5e4e..36b13c4 100644 --- a/src/source/wally/mod.rs +++ b/src/source/wally/mod.rs @@ -1,7 +1,7 @@ use crate::{ manifest::target::{Target, TargetKind}, names::PackageNames, - reporters::DownloadProgressReporter, + reporters::{response_to_async_read, DownloadProgressReporter}, source::{ fs::{store_in_cas, FsEntry, PackageFs}, git_index::{read_file, root_tree, GitBasedSource}, @@ -20,14 +20,13 @@ use crate::{ Project, }; use fs_err::tokio as fs; -use futures::StreamExt; use gix::Url; use relative_path::RelativePathBuf; use reqwest::header::AUTHORIZATION; use serde::Deserialize; use std::{collections::BTreeMap, path::PathBuf}; -use tokio::{io::AsyncReadExt, task::spawn_blocking}; -use tokio_util::{compat::FuturesAsyncReadCompatExt, io::StreamReader}; +use tokio::{io::AsyncReadExt, pin, task::spawn_blocking}; +use tokio_util::compat::FuturesAsyncReadCompatExt; use tracing::instrument; pub(crate) mod compat_util; @@ -268,22 +267,9 @@ impl PackageSource for WallyPackageSource { let response = request.send().await?.error_for_status()?; let total_len = response.content_length().unwrap_or(0); - reporter.report_progress(total_len, 0); + let bytes = response_to_async_read(response, reporter.clone()); + pin!(bytes); - let mut bytes_downloaded = 0; - let bytes = response - .bytes_stream() - .inspect(|chunk| { - chunk.as_ref().ok().inspect(|chunk| { - bytes_downloaded += chunk.len() as u64; - reporter.report_progress(total_len, bytes_downloaded); - }); - }) - .map(|result| { - result.map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err)) - }); - - let mut bytes = StreamReader::new(bytes); let mut buf = Vec::with_capacity(total_len as usize); bytes.read_to_end(&mut buf).await?; @@ -335,8 +321,6 @@ impl PackageSource for WallyPackageSource { .await .map_err(errors::DownloadError::WriteIndex)?; - reporter.report_done(); - Ok(fs) }