examples/train_pylate_contrastive.py (new file, 69 additions)
import torch
from datasets import load_dataset
from sentence_transformers import (
SentenceTransformerTrainer,
SentenceTransformerTrainingArguments,
)

from pylate import evaluation, losses, models, utils

# Stabilize torch.compile/inductor behavior for reproducible runs
torch._inductor.config.fallback_random = True
torch._inductor.config.triton.unique_kernel_names = True
# Or, to compile the forward pass explicitly (apply after the model is created):
# model.forward = torch.compile(model.forward, mode="max-autotune", fullgraph=True)

# Define model parameters for contrastive training
model_name = "answerdotai/ModernBERT-large"  # The pre-trained model to use as the base encoder
batch_size = 64 # Larger batch size often improves results, but requires more memory

num_train_epochs = 5 # Adjust based on your requirements
# Set the run name for logging and output directory
run_name = "contrastive-sigrid-241226"
output_dir = f"output/{run_name}"

# Define the ColBERT model. If the checkpoint is not already a ColBERT model, a linear projection layer is added on top of the base encoder.
model = models.ColBERT(model_name_or_path=model_name)
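
# Optional sanity check: ColBERT encodes each text into a bag of token-level
# vectors rather than a single embedding. A minimal sketch, assuming pylate's
# `encode` accepts an `is_query` flag (queries and documents are encoded
# differently in ColBERT):
# embeddings = model.encode(["what is colbert?"], is_query=True)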

# Load dataset
dataset = load_dataset("sentence-transformers/msmarco-bm25", "triplet", split="train")
# Split the dataset (this dataset does not have a validation set, so we split the training set)
splits = dataset.train_test_split(test_size=0.01)
train_dataset = splits["train"]
eval_dataset = splits["test"]
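
# Quick sanity check: each example has "query", "positive", and "negative" columns
print(train_dataset[0])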

# Define the loss function
train_loss = losses.Contrastive(model=model)

# Initialize the evaluator
dev_evaluator = evaluation.ColBERTTripletEvaluator(
anchors=eval_dataset["query"],
positives=eval_dataset["positive"],
negatives=eval_dataset["negative"],
)
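
# Optional: score the model once before training to get a baseline accuracy.
# This assumes ColBERTTripletEvaluator is callable on a model, like the
# standard sentence-transformers evaluators:
# dev_evaluator(model)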

# Configure the training arguments (e.g., batch size, evaluation strategy, logging steps)
args = SentenceTransformerTrainingArguments(
output_dir=output_dir,
num_train_epochs=num_train_epochs,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
fp16=True, # Set to False if you get an error that your GPU can't run on FP16
bf16=False, # Set to True if you have a GPU that supports BF16
run_name=run_name, # Will be used in W&B if `wandb` is installed
learning_rate=3e-6,
)

# Initialize the trainer for the contrastive training
trainer = SentenceTransformerTrainer(
model=model,
args=args,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
loss=train_loss,
evaluator=dev_evaluator,
data_collator=utils.ColBERTCollator(model.tokenize),
)
# Start the training process
trainer.train()
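
# Persist the trained model so it can be reloaded for indexing/retrieval.
# A minimal sketch, assuming the sentence-transformers `save_pretrained` API
# that pylate's ColBERT inherits:
model.save_pretrained(f"{output_dir}/final")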