
[Performance]: 🐢 Sparse Embedding with SPLADE is Very Slow + OOM For larger batches - Expected Behavior? #539

@saadtahlx

Description


Context

I’m experimenting with the [FastEmbed](https://github.com/qdrant/fastembed) package to generate sparse embeddings for ~100K candidate profiles, each consisting of 2–10 lines of text. These embeddings are being stored in Qdrant.

Initially, I tested dense embeddings on 1000 profiles:

  • all-MiniLM-L6-v2 via FastEmbed: ~40 seconds
  • Switched to sentence-transformers: ~6s on GPU, ~10s on CPU

Now I'm trying sparse embeddings with SPLADE (prithivida/Splade_PP_en_v1). Processing the first 2,340 batches of 8 profiles each took ~100 minutes (~3 s per batch, roughly 0.4 s per profile). At this rate, getting through all ~100K profiles will take another 8+ hours.


Questions

  1. Is ~3 s per batch of 8 profiles (roughly 0.4 s per profile) expected for SPLADE on CPU?
  2. Is there any way to optimize this (code, hardware, batch size)?
  3. Are there lighter-weight sparse models or best practices I should consider? (See the snippet right after this list for how I'd enumerate the alternatives.)
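The only way I know to enumerate the built-in alternatives is FastEmbed's own model listing; I'm not claiming any of these is lighter, this is just how I'd look them up:

from fastembed import SparseTextEmbedding

# Print every sparse model FastEmbed currently ships with, so candidates
# other than prithivida/Splade_PP_en_v1 can be compared side by side.
for description in SparseTextEmbedding.list_supported_models():
    print(description)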

Hardware & Env

  • Machine: Intel Core i7 (11th Gen), 16 GB RAM
  • Platform: WSL2 (Ubuntu)
  • FastEmbed: Latest
  • Model: prithivida/Splade_PP_en_v1
  • Qdrant: Local instance with prefer_grpc=True
  • Batch size: 8 (OOMs at 16) [Note: I am not specifying a batch size in fastembed's model.embed; see the sketch right after this list]
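To illustrate the note above: a minimal sketch of passing the batching knobs explicitly, assuming model.embed accepts batch_size and parallel keywords (the values are guesses for a 16 GB machine, and the variable names match the batch function further down):

# Hypothetical: set FastEmbed's own batching/parallelism explicitly instead of
# relying on the defaults. batch_size is the number of documents per ONNX
# forward pass; parallel spawns extra worker processes.
fp_gen = list(
    model.embed(
        full_profile_docs,
        batch_size=8,   # kept small to stay inside 16 GB RAM
        parallel=2,     # guess: two workers without exhausting memory
    )
)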

Initialization Code

from fastembed import SparseTextEmbedding
from qdrant_client import QdrantClient, models

# Local Qdrant instance; gRPC preferred for the bulk vector updates
client = QdrantClient(
    url="http://localhost:6333",
    prefer_grpc=True,
)

FASTEMBED_CACHE_DIR = "/home/saad/.cache/fastembed"

# SPLADE sparse model (runs on CPU via ONNX Runtime by default)
model = SparseTextEmbedding(
    model_name="prithivida/Splade_PP_en_v1",
    cache_dir=FASTEMBED_CACHE_DIR,
)

process_sparse_embeddings_by_batch(
    client=client,
    collection_name="candidate_profiles_mvp_v1",
    data_path="indexing/test_data.jsonl",
    model=model,
    jsonl_batch_size=8,
)
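Since the dense test above did run on a GPU, I'm also curious whether pointing ONNX Runtime at CUDA would help. My understanding is this needs the fastembed-gpu package and a providers argument; I haven't verified it, so treat it as a sketch:

# Sketch: same SPLADE model, but asking ONNX Runtime for the CUDA execution
# provider first and falling back to CPU. Assumes fastembed-gpu is installed.
gpu_model = SparseTextEmbedding(
    model_name="prithivida/Splade_PP_en_v1",
    cache_dir=FASTEMBED_CACHE_DIR,
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)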

Batch Embedding Function

def process_sparse_embeddings_by_batch(
    client: QdrantClient,
    collection_name: str,
    data_path: str,
    model: SparseTextEmbedding,
    jsonl_batch_size: int = 8,
):
    with open(data_path) as f:
        total_records = sum(1 for _ in f)
    total_batches = (total_records + jsonl_batch_size - 1) // jsonl_batch_size
    logger.info(f"Total records: {total_records}, Total batches: {total_batches}")

    jsonl_batches = load_jsonl(data_path, batch_size=jsonl_batch_size)
    total_processed = 0

    for batch_idx, raw_batch in enumerate(tqdm(jsonl_batches, total=total_batches, desc="Processing JSONL batches")):
        try:
            extracted_batch = extract_profile_data(raw_batch)
            if not extracted_batch:
                logger.warning(f"Batch {batch_idx + 1} had no valid profiles, skipping")
                continue

            batch_df = create_profile_dataframe(extracted_batch)
            instance_ids = batch_df["instance_id"].tolist()
            full_profile_docs = batch_df["full_profile"].tolist()
            inferred_skills_docs = [doc if doc else "" for doc in batch_df["inferred_skills"].tolist()]

            # ❗ Sparse embeddings via SPLADE
            fp_gen = list(model.embed(full_profile_docs))          # Slow part
            is_gen = list(model.embed(inferred_skills_docs))

            batch_points = []
            for instance_id, fp_emb, is_emb in zip(instance_ids, fp_gen, is_gen):
                batch_points.append(
                    models.PointVectors(
                        id=instance_id,
                        vector={
                            "full_profile_sparse": models.SparseVector(
                                indices=fp_emb.indices,
                                values=fp_emb.values,
                            ),
                            "inferred_skills": models.SparseVector(
                                indices=is_emb.indices,
                                values=is_emb.values,
                            ),
                        },
                    )
                )

                # Flush every 256 points; with jsonl_batch_size=8 this never
                # triggers, so the flush after the loop does all the work.
                if len(batch_points) >= 256:
                    client.update_vectors(collection_name=collection_name, points=batch_points, wait=False)
                    batch_points = []

            if batch_points:
                client.update_vectors(collection_name=collection_name, points=batch_points, wait=False)

            total_processed += len(instance_ids)

        except Exception as e:
            logger.error(f"Error processing batch {batch_idx + 1}: {e}")
            continue

    logger.info(f"Completed processing. Total profiles processed: {total_processed}")

Sample Output

Processing JSONL batches:  19%|█▉        | 2340/12285 [1:39:38<8:24:04,  3.04s/it]

OOM errors start when I increase jsonl_batch_size beyond 8, which suggests SPLADE doesn't scale well on my CPU-based setup.
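For reference, the restructuring I've been considering (only the full_profile vector shown for brevity, using the same names as above, and again assuming model.embed takes a batch_size keyword): read larger slices of the JSONL per iteration, let FastEmbed handle the internal batching, and do one Qdrant update per slice.

EMBED_CHUNK = 256  # profiles read from the JSONL and embedded per iteration
ONNX_BATCH = 8     # FastEmbed's internal batch size, to keep peak memory bounded

for raw_batch in load_jsonl(data_path, batch_size=EMBED_CHUNK):
    batch_df = create_profile_dataframe(extract_profile_data(raw_batch))
    docs = batch_df["full_profile"].tolist()

    # One embed() call per EMBED_CHUNK documents instead of per 8; the
    # generator is consumed lazily while the point payload is built.
    points = [
        models.PointVectors(
            id=instance_id,
            vector={
                "full_profile_sparse": models.SparseVector(
                    indices=emb.indices,
                    values=emb.values,
                ),
            },
        )
        for instance_id, emb in zip(
            batch_df["instance_id"].tolist(),
            model.embed(docs, batch_size=ONNX_BATCH),
        )
    ]
    client.update_vectors(collection_name=collection_name, points=points, wait=False)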


Would appreciate any help to improve this or confirm if this is the expected behavior.
