From dcc869c0253b1183510878118271b5f349b3ba61 Mon Sep 17 00:00:00 2001 From: daimond113 <72147841+daimond113@users.noreply.github.com> Date: Fri, 10 Jan 2025 09:24:33 +0100 Subject: [PATCH] fix(registry): avoid race condition in search --- registry/src/endpoints/publish_version.rs | 2 +- registry/src/endpoints/search.rs | 32 ++++++----------------- registry/src/main.rs | 6 ++--- registry/src/search.rs | 21 +++++---------- 4 files changed, 18 insertions(+), 43 deletions(-) diff --git a/registry/src/endpoints/publish_version.rs b/registry/src/endpoints/publish_version.rs index 1356c1e..0d89b9e 100644 --- a/registry/src/endpoints/publish_version.rs +++ b/registry/src/endpoints/publish_version.rs @@ -419,7 +419,7 @@ pub async fn publish_package( ) .await?; - update_search_version(&app_state, &manifest.name, &manifest.version, &new_entry); + update_search_version(&app_state, &manifest.name, &new_entry); } let version_id = VersionId::new(manifest.version.clone(), manifest.target.kind()); diff --git a/registry/src/endpoints/search.rs b/registry/src/endpoints/search.rs index efbde30..3f23049 100644 --- a/registry/src/endpoints/search.rs +++ b/registry/src/endpoints/search.rs @@ -1,21 +1,20 @@ -use std::collections::HashMap; - use crate::{ error::RegistryError, package::{read_package, PackageResponse}, + search::find_max_searchable, AppState, }; use actix_web::{web, HttpResponse}; use pesde::names::PackageName; -use semver::Version; use serde::Deserialize; +use std::{collections::HashMap, sync::Arc}; use tantivy::{collector::Count, query::AllQuery, schema::Value, DateTime, Order}; use tokio::task::JoinSet; #[derive(Deserialize)] pub struct Request { #[serde(default)] - query: Option, + query: String, #[serde(default)] offset: usize, } @@ -28,9 +27,8 @@ pub async fn search_packages( let schema = searcher.schema(); let id = schema.get_field("id").unwrap(); - let version = schema.get_field("version").unwrap(); - let query = request_query.query.as_deref().unwrap_or_default().trim(); + let query = request_query.query.trim(); let query = if query.is_empty() { Box::new(AllQuery) @@ -50,8 +48,7 @@ pub async fn search_packages( ) .unwrap(); - // prevent a write lock on the source while we're reading the documents - let _guard = app_state.source.read().await; + let source = Arc::new(app_state.source.clone().read_owned().await); let mut results = Vec::with_capacity(top_docs.len()); results.extend((0..top_docs.len()).map(|_| None::)); @@ -62,6 +59,7 @@ pub async fn search_packages( .map(|(i, (_, doc_address))| { let app_state = app_state.clone(); let doc = searcher.doc::>(doc_address).unwrap(); + let source = source.clone(); async move { let id = doc @@ -71,24 +69,10 @@ pub async fn search_packages( .unwrap() .parse::() .unwrap(); - let version = doc - .get(&version) - .unwrap() - .as_str() - .unwrap() - .parse::() - .unwrap(); - let file = read_package(&app_state, &id, &*app_state.source.read().await) - .await? - .unwrap(); + let file = read_package(&app_state, &id, &source).await?.unwrap(); - let version_id = file - .entries - .keys() - .filter(|v_id| *v_id.version() == version) - .max() - .unwrap(); + let (version_id, _) = find_max_searchable(&file).unwrap(); Ok::<_, RegistryError>((i, PackageResponse::new(&id, version_id, &file))) } diff --git a/registry/src/main.rs b/registry/src/main.rs index 2045379..e2af10b 100644 --- a/registry/src/main.rs +++ b/registry/src/main.rs @@ -20,7 +20,7 @@ use pesde::{ }, AuthConfig, Project, }; -use std::{env::current_dir, path::PathBuf}; +use std::{env::current_dir, path::PathBuf, sync::Arc}; use tracing::level_filters::LevelFilter; use tracing_subscriber::{ fmt::format::FmtSpan, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, @@ -47,7 +47,7 @@ pub fn make_reqwest() -> reqwest::Client { } pub struct AppState { - pub source: tokio::sync::RwLock, + pub source: Arc>, pub project: Project, pub storage: Storage, pub auth: Auth, @@ -134,7 +134,7 @@ async fn run() -> std::io::Result<()> { tracing::info!("auth: {auth}"); auth }, - source: tokio::sync::RwLock::new(source), + source: Arc::new(tokio::sync::RwLock::new(source)), project, search_reader, diff --git a/registry/src/search.rs b/registry/src/search.rs index 036434b..4f69448 100644 --- a/registry/src/search.rs +++ b/registry/src/search.rs @@ -10,7 +10,6 @@ use pesde::{ }, Project, }; -use semver::Version; use tantivy::{ doc, query::QueryParser, @@ -69,7 +68,7 @@ async fn all_packages( } } -fn find_max(file: &IndexFile) -> Option<(&VersionId, &IndexFileEntry)> { +pub fn find_max_searchable(file: &IndexFile) -> Option<(&VersionId, &IndexFileEntry)> { file.entries .iter() .filter(|(_, entry)| !entry.yanked) @@ -94,7 +93,6 @@ pub async fn make_search( ); let id_field = schema_builder.add_text_field("id", STRING | STORED); - let version = schema_builder.add_text_field("version", STRING | STORED); let scope = schema_builder.add_text_field("scope", field_options.clone()); let name = schema_builder.add_text_field("name", field_options.clone()); @@ -124,14 +122,13 @@ pub async fn make_search( continue; } - let Some((v_id, latest_entry)) = find_max(&file) else { + let Some((_, latest_entry)) = find_max_searchable(&file) else { continue; }; search_writer .add_document(doc!( id_field => pkg_name.to_string(), - version => v_id.version().to_string(), scope => pkg_name.scope(), name => pkg_name.name(), description => latest_entry.description.clone().unwrap_or_default(), @@ -150,12 +147,7 @@ pub async fn make_search( (search_reader, search_writer, query_parser) } -pub fn update_search_version( - app_state: &AppState, - name: &PackageName, - version: &Version, - entry: &IndexFileEntry, -) { +pub fn update_search_version(app_state: &AppState, name: &PackageName, entry: &IndexFileEntry) { let mut search_writer = app_state.search_writer.lock().unwrap(); let schema = search_writer.index().schema(); let id_field = schema.get_field("id").unwrap(); @@ -164,7 +156,6 @@ pub fn update_search_version( search_writer.add_document(doc!( id_field => name.to_string(), - schema.get_field("version").unwrap() => version.to_string(), schema.get_field("scope").unwrap() => name.scope(), schema.get_field("name").unwrap() => name.name(), schema.get_field("description").unwrap() => entry.description.clone().unwrap_or_default(), @@ -177,12 +168,12 @@ pub fn update_search_version( pub fn search_version_changed(app_state: &AppState, name: &PackageName, file: &IndexFile) { let entry = if file.meta.deprecated.is_empty() { - find_max(file) + find_max_searchable(file) } else { None }; - let Some((v_id, entry)) = entry else { + let Some((_, entry)) = entry else { let mut search_writer = app_state.search_writer.lock().unwrap(); let schema = search_writer.index().schema(); let id_field = schema.get_field("id").unwrap(); @@ -194,5 +185,5 @@ pub fn search_version_changed(app_state: &AppState, name: &PackageName, file: &I return; }; - update_search_version(app_state, name, v_id.version(), entry); + update_search_version(app_state, name, entry); }