Context
I’m experimenting with the [FastEmbed](https://github.com/qdrant/fastembed) package to generate sparse embeddings for ~100K candidate profiles, each consisting of 2–10 lines of text. These embeddings are being stored in Qdrant.
Initially, I tested dense embeddings on 1,000 profiles:
- `all-MiniLM-L6-v2` via FastEmbed: ~40 seconds
- Switched to `sentence-transformers`: ~6s on GPU, ~10s on CPU
Now I'm trying sparse embeddings with SPLADE (`prithivida/Splade_PP_en_v1`). Processing the first 2,340 JSONL batches (batch size 8, roughly 18.7K profiles) took ~100 minutes, i.e. ~3s per batch or roughly 0.4s per profile. At this rate, the full ~100K profiles will take 8+ hours.
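To isolate the model cost from the rest of my pipeline (JSONL parsing, dataframes, Qdrant upserts), this is the minimal timing check I use; the sample texts are placeholders, not real profiles:

```python
import time

from fastembed import SparseTextEmbedding

model = SparseTextEmbedding(model_name="prithivida/Splade_PP_en_v1")

# Placeholder profile-like texts (the real ones are 2-10 lines each).
docs = ["Senior data engineer, 8 years of Python and Spark, based in Berlin."] * 64

start = time.perf_counter()
embeddings = list(model.embed(docs))  # embed() is lazy; list() forces the full forward passes
elapsed = time.perf_counter() - start
print(f"{len(docs)} docs in {elapsed:.1f}s -> {elapsed / len(docs):.2f}s per doc")
```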
Questions
- Is ~3s per batch of 8 profiles (roughly 0.4s/profile) expected for SPLADE on CPU?
- Is there any way to optimize this (code, hardware, batch size)?
- Are there lighter-weight sparse models or best practices I should consider? (A sketch of what I have been looking at follows this list.)
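For the last question, this is the direction I have been exploring, sketch only: enumerate the sparse models the installed fastembed release ships, and try BM25, assuming `Qdrant/bm25` is available in that release:

```python
from fastembed import SparseTextEmbedding

# See which sparse models the installed fastembed release supports.
for entry in SparseTextEmbedding.list_supported_models():
    print(entry["model"])

# BM25 is purely term-frequency based (no neural forward pass), so it should be
# much cheaper on CPU than SPLADE, at the cost of no learned term expansion.
bm25 = SparseTextEmbedding(model_name="Qdrant/bm25")
vectors = list(bm25.embed(["example candidate profile text"]))
print(vectors[0].indices[:5], vectors[0].values[:5])
```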
Hardware & Env
- Machine: Intel Core i7 (11th Gen), 16 GB RAM
- Platform: WSL2 (Ubuntu)
- FastEmbed: Latest
- Model: `prithivida/Splade_PP_en_v1`
- Qdrant: Local instance with `prefer_grpc=True`
- Batch size: 8 (OOMs at 16). Note: I am not specifying a batch size in fastembed's `model.embed`; a sketch of passing it explicitly follows this list.
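Regarding the batch-size note above: since I pass nothing to `model.embed`, it runs with fastembed's defaults. A minimal sketch of what I could pass instead; `profile_texts` is a stand-in for whichever document list is being embedded, and the numbers are guesses to tune, not recommendations:

```python
# batch_size: texts per ONNX forward pass; parallel: worker processes.
# Both are embed() parameters in fastembed; the values here are guesses
# chosen to stay inside 16 GB RAM and should be tuned empirically.
sparse_embeddings = list(
    model.embed(
        profile_texts,   # stand-in for full_profile_docs / inferred_skills_docs
        batch_size=16,
        parallel=4,
    )
)
```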
Initialization Code
```python
from qdrant_client import QdrantClient
from fastembed import SparseTextEmbedding

client = QdrantClient(
    url="http://localhost:6333",
    prefer_grpc=True,
)

FASTEMBED_CACHE_DIR = "/home/saad/.cache/fastembed"

model = SparseTextEmbedding(
    model_name="prithivida/Splade_PP_en_v1",
    cache_dir=FASTEMBED_CACHE_DIR,
)

process_sparse_embeddings_by_batch(
    client=client,
    collection_name="candidate_profiles_mvp_v1",
    data_path="indexing/test_data.jsonl",
    model=model,
    jsonl_batch_size=8,
)
```
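One knob I have not touched is the model constructor. If the installed fastembed version accepts a `threads` argument (my understanding is that it is forwarded to onnxruntime's session options), something like this might help on the i7; treat it as an assumption to verify:

```python
# Assumption: `threads` is accepted by SparseTextEmbedding and forwarded to
# onnxruntime's session options; verify against the installed fastembed version.
model = SparseTextEmbedding(
    model_name="prithivida/Splade_PP_en_v1",
    cache_dir=FASTEMBED_CACHE_DIR,
    threads=8,  # guess: roughly match the i7's core count
)
```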
Batch Embedding Function
```python
from qdrant_client import QdrantClient, models
from fastembed import SparseTextEmbedding
from tqdm import tqdm

# load_jsonl, extract_profile_data, create_profile_dataframe and logger are
# project helpers defined elsewhere.


def process_sparse_embeddings_by_batch(
    client: QdrantClient,
    collection_name: str,
    data_path: str,
    model: SparseTextEmbedding,
    jsonl_batch_size: int = 8,
):
    total_records = sum(1 for _ in open(data_path))
    total_batches = (total_records + jsonl_batch_size - 1) // jsonl_batch_size
    logger.info(f"Total records: {total_records}, Total batches: {total_batches}")

    jsonl_batches = load_jsonl(data_path, batch_size=jsonl_batch_size)
    total_processed = 0

    for batch_idx, raw_batch in enumerate(tqdm(jsonl_batches, total=total_batches, desc="Processing JSONL batches")):
        try:
            extracted_batch = extract_profile_data(raw_batch)
            if not extracted_batch:
                logger.warning(f"Batch {batch_idx + 1} had no valid profiles, skipping")
                continue

            batch_df = create_profile_dataframe(extracted_batch)
            instance_ids = batch_df["instance_id"].tolist()
            full_profile_docs = batch_df["full_profile"].tolist()
            inferred_skills_docs = [doc if doc else "" for doc in batch_df["inferred_skills"].tolist()]

            # ❗ Sparse embeddings via SPLADE -- this is the slow part
            fp_gen = list(model.embed(full_profile_docs))
            is_gen = list(model.embed(inferred_skills_docs))

            batch_points = []
            for instance_id, fp_emb, is_emb in zip(instance_ids, fp_gen, is_gen):
                batch_points.append(
                    models.PointVectors(
                        id=instance_id,
                        vector={
                            "full_profile_sparse": models.SparseVector(
                                indices=fp_emb.indices,
                                values=fp_emb.values,
                            ),
                            "inferred_skills": models.SparseVector(
                                indices=is_emb.indices,
                                values=is_emb.values,
                            ),
                        },
                    )
                )
                # Flush to Qdrant in chunks of 256 points
                if len(batch_points) >= 256:
                    client.update_vectors(collection_name=collection_name, points=batch_points, wait=False)
                    batch_points = []

            if batch_points:
                client.update_vectors(collection_name=collection_name, points=batch_points, wait=False)

            total_processed += len(instance_ids)
        except Exception as e:
            logger.error(f"Error processing batch {batch_idx + 1}: {e}")
            continue

    logger.info(f"Completed processing. Total profiles processed: {total_processed}")
```
Sample Output
```
Processing JSONL batches: 19%|█▉ | 2340/12285 [1:39:38<8:24:04, 3.04s/it]
```
OOM starts happening when I increase `jsonl_batch_size` beyond 8, so SPLADE does not seem to scale well on my CPU-based setup. I would appreciate any help improving this, or confirmation that this is the expected behavior.