Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 25 additions & 7 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,13 +91,20 @@ pub async fn run(
// Optionally download the pooling config.
if pooling.is_none() {
// If a pooling config exist, download it
let _ = download_pool_config(&api_repo).await;
let _ = download_pool_config(&api_repo).await.map_err(|err| {
tracing::warn!("Download failed: {err}");
err
});
}

// Download sentence transformers config
// Download legacy sentence transformers config
// We don't warn on failure as it is a legacy file
let _ = download_st_config(&api_repo).await;
// Download new sentence transformers config
let _ = download_new_st_config(&api_repo).await;
let _ = download_new_st_config(&api_repo).await.map_err(|err| {
tracing::warn!("Download failed: {err}");
err
});

// Download model from the Hub
download_artifacts(&api_repo)
Expand Down Expand Up @@ -387,10 +394,21 @@ fn get_backend_model_type(
None => {
// Load pooling config
let config_path = model_root.join("1_Pooling/config.json");
let config = fs::read_to_string(config_path).context("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model.")?;
let config: PoolConfig =
serde_json::from_str(&config).context("Failed to parse `1_Pooling/config.json`")?;
Pool::try_from(config)?

match fs::read_to_string(config_path) {
Ok(config) => {
let config: PoolConfig = serde_json::from_str(&config)
.context("Failed to parse `1_Pooling/config.json`")?;
Pool::try_from(config)?
}
Err(err) => {
if !config.model_type.to_lowercase().contains("bert") {
return Err(err).context("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model.");
}
tracing::warn!("The `--pooling` arg is not set and we could not find a pooling configuration (`1_Pooling/config.json`) for this model but the model is a BERT variant. Defaulting to `CLS` pooling.");
text_embeddings_backend::Pool::Cls
}
}
}
};
Ok(text_embeddings_backend::ModelType::Embedding(pool))
Expand Down
Loading