fix(registry): avoid race condition in search
Some checks are pending
Debug / Get build version (push) Waiting to run
Debug / Build for linux-x86_64 (push) Blocked by required conditions
Debug / Build for macos-aarch64 (push) Blocked by required conditions
Debug / Build for macos-x86_64 (push) Blocked by required conditions
Debug / Build for windows-x86_64 (push) Blocked by required conditions
Test & Lint / lint (push) Waiting to run

This commit is contained in:
daimond113 2025-01-10 09:24:33 +01:00
parent 6f4c7137c0
commit dcc869c025
No known key found for this signature in database
GPG key ID: 3A8ECE51328B513C
4 changed files with 18 additions and 43 deletions

View file

@ -419,7 +419,7 @@ pub async fn publish_package(
) )
.await?; .await?;
update_search_version(&app_state, &manifest.name, &manifest.version, &new_entry); update_search_version(&app_state, &manifest.name, &new_entry);
} }
let version_id = VersionId::new(manifest.version.clone(), manifest.target.kind()); let version_id = VersionId::new(manifest.version.clone(), manifest.target.kind());

View file

@ -1,21 +1,20 @@
use std::collections::HashMap;
use crate::{ use crate::{
error::RegistryError, error::RegistryError,
package::{read_package, PackageResponse}, package::{read_package, PackageResponse},
search::find_max_searchable,
AppState, AppState,
}; };
use actix_web::{web, HttpResponse}; use actix_web::{web, HttpResponse};
use pesde::names::PackageName; use pesde::names::PackageName;
use semver::Version;
use serde::Deserialize; use serde::Deserialize;
use std::{collections::HashMap, sync::Arc};
use tantivy::{collector::Count, query::AllQuery, schema::Value, DateTime, Order}; use tantivy::{collector::Count, query::AllQuery, schema::Value, DateTime, Order};
use tokio::task::JoinSet; use tokio::task::JoinSet;
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct Request { pub struct Request {
#[serde(default)] #[serde(default)]
query: Option<String>, query: String,
#[serde(default)] #[serde(default)]
offset: usize, offset: usize,
} }
@ -28,9 +27,8 @@ pub async fn search_packages(
let schema = searcher.schema(); let schema = searcher.schema();
let id = schema.get_field("id").unwrap(); let id = schema.get_field("id").unwrap();
let version = schema.get_field("version").unwrap();
let query = request_query.query.as_deref().unwrap_or_default().trim(); let query = request_query.query.trim();
let query = if query.is_empty() { let query = if query.is_empty() {
Box::new(AllQuery) Box::new(AllQuery)
@ -50,8 +48,7 @@ pub async fn search_packages(
) )
.unwrap(); .unwrap();
// prevent a write lock on the source while we're reading the documents let source = Arc::new(app_state.source.clone().read_owned().await);
let _guard = app_state.source.read().await;
let mut results = Vec::with_capacity(top_docs.len()); let mut results = Vec::with_capacity(top_docs.len());
results.extend((0..top_docs.len()).map(|_| None::<PackageResponse>)); results.extend((0..top_docs.len()).map(|_| None::<PackageResponse>));
@ -62,6 +59,7 @@ pub async fn search_packages(
.map(|(i, (_, doc_address))| { .map(|(i, (_, doc_address))| {
let app_state = app_state.clone(); let app_state = app_state.clone();
let doc = searcher.doc::<HashMap<_, _>>(doc_address).unwrap(); let doc = searcher.doc::<HashMap<_, _>>(doc_address).unwrap();
let source = source.clone();
async move { async move {
let id = doc let id = doc
@ -71,24 +69,10 @@ pub async fn search_packages(
.unwrap() .unwrap()
.parse::<PackageName>() .parse::<PackageName>()
.unwrap(); .unwrap();
let version = doc
.get(&version)
.unwrap()
.as_str()
.unwrap()
.parse::<Version>()
.unwrap();
let file = read_package(&app_state, &id, &*app_state.source.read().await) let file = read_package(&app_state, &id, &source).await?.unwrap();
.await?
.unwrap();
let version_id = file let (version_id, _) = find_max_searchable(&file).unwrap();
.entries
.keys()
.filter(|v_id| *v_id.version() == version)
.max()
.unwrap();
Ok::<_, RegistryError>((i, PackageResponse::new(&id, version_id, &file))) Ok::<_, RegistryError>((i, PackageResponse::new(&id, version_id, &file)))
} }

View file

@ -20,7 +20,7 @@ use pesde::{
}, },
AuthConfig, Project, AuthConfig, Project,
}; };
use std::{env::current_dir, path::PathBuf}; use std::{env::current_dir, path::PathBuf, sync::Arc};
use tracing::level_filters::LevelFilter; use tracing::level_filters::LevelFilter;
use tracing_subscriber::{ use tracing_subscriber::{
fmt::format::FmtSpan, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter, fmt::format::FmtSpan, layer::SubscriberExt, util::SubscriberInitExt, EnvFilter,
@ -47,7 +47,7 @@ pub fn make_reqwest() -> reqwest::Client {
} }
pub struct AppState { pub struct AppState {
pub source: tokio::sync::RwLock<PesdePackageSource>, pub source: Arc<tokio::sync::RwLock<PesdePackageSource>>,
pub project: Project, pub project: Project,
pub storage: Storage, pub storage: Storage,
pub auth: Auth, pub auth: Auth,
@ -134,7 +134,7 @@ async fn run() -> std::io::Result<()> {
tracing::info!("auth: {auth}"); tracing::info!("auth: {auth}");
auth auth
}, },
source: tokio::sync::RwLock::new(source), source: Arc::new(tokio::sync::RwLock::new(source)),
project, project,
search_reader, search_reader,

View file

@ -10,7 +10,6 @@ use pesde::{
}, },
Project, Project,
}; };
use semver::Version;
use tantivy::{ use tantivy::{
doc, doc,
query::QueryParser, query::QueryParser,
@ -69,7 +68,7 @@ async fn all_packages(
} }
} }
fn find_max(file: &IndexFile) -> Option<(&VersionId, &IndexFileEntry)> { pub fn find_max_searchable(file: &IndexFile) -> Option<(&VersionId, &IndexFileEntry)> {
file.entries file.entries
.iter() .iter()
.filter(|(_, entry)| !entry.yanked) .filter(|(_, entry)| !entry.yanked)
@ -94,7 +93,6 @@ pub async fn make_search(
); );
let id_field = schema_builder.add_text_field("id", STRING | STORED); let id_field = schema_builder.add_text_field("id", STRING | STORED);
let version = schema_builder.add_text_field("version", STRING | STORED);
let scope = schema_builder.add_text_field("scope", field_options.clone()); let scope = schema_builder.add_text_field("scope", field_options.clone());
let name = schema_builder.add_text_field("name", field_options.clone()); let name = schema_builder.add_text_field("name", field_options.clone());
@ -124,14 +122,13 @@ pub async fn make_search(
continue; continue;
} }
let Some((v_id, latest_entry)) = find_max(&file) else { let Some((_, latest_entry)) = find_max_searchable(&file) else {
continue; continue;
}; };
search_writer search_writer
.add_document(doc!( .add_document(doc!(
id_field => pkg_name.to_string(), id_field => pkg_name.to_string(),
version => v_id.version().to_string(),
scope => pkg_name.scope(), scope => pkg_name.scope(),
name => pkg_name.name(), name => pkg_name.name(),
description => latest_entry.description.clone().unwrap_or_default(), description => latest_entry.description.clone().unwrap_or_default(),
@ -150,12 +147,7 @@ pub async fn make_search(
(search_reader, search_writer, query_parser) (search_reader, search_writer, query_parser)
} }
pub fn update_search_version( pub fn update_search_version(app_state: &AppState, name: &PackageName, entry: &IndexFileEntry) {
app_state: &AppState,
name: &PackageName,
version: &Version,
entry: &IndexFileEntry,
) {
let mut search_writer = app_state.search_writer.lock().unwrap(); let mut search_writer = app_state.search_writer.lock().unwrap();
let schema = search_writer.index().schema(); let schema = search_writer.index().schema();
let id_field = schema.get_field("id").unwrap(); let id_field = schema.get_field("id").unwrap();
@ -164,7 +156,6 @@ pub fn update_search_version(
search_writer.add_document(doc!( search_writer.add_document(doc!(
id_field => name.to_string(), id_field => name.to_string(),
schema.get_field("version").unwrap() => version.to_string(),
schema.get_field("scope").unwrap() => name.scope(), schema.get_field("scope").unwrap() => name.scope(),
schema.get_field("name").unwrap() => name.name(), schema.get_field("name").unwrap() => name.name(),
schema.get_field("description").unwrap() => entry.description.clone().unwrap_or_default(), schema.get_field("description").unwrap() => entry.description.clone().unwrap_or_default(),
@ -177,12 +168,12 @@ pub fn update_search_version(
pub fn search_version_changed(app_state: &AppState, name: &PackageName, file: &IndexFile) { pub fn search_version_changed(app_state: &AppState, name: &PackageName, file: &IndexFile) {
let entry = if file.meta.deprecated.is_empty() { let entry = if file.meta.deprecated.is_empty() {
find_max(file) find_max_searchable(file)
} else { } else {
None None
}; };
let Some((v_id, entry)) = entry else { let Some((_, entry)) = entry else {
let mut search_writer = app_state.search_writer.lock().unwrap(); let mut search_writer = app_state.search_writer.lock().unwrap();
let schema = search_writer.index().schema(); let schema = search_writer.index().schema();
let id_field = schema.get_field("id").unwrap(); let id_field = schema.get_field("id").unwrap();
@ -194,5 +185,5 @@ pub fn search_version_changed(app_state: &AppState, name: &PackageName, file: &I
return; return;
}; };
update_search_version(app_state, name, v_id.version(), entry); update_search_version(app_state, name, entry);
} }