4 changes: 2 additions & 2 deletions src/liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -117,7 +117,7 @@ def _compute_loss(

         hard_loss /= full_target.shape[0]

-        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
+        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, **loss_kwargs)
         soft_loss /= full_target.shape[0]

         loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
@@ -180,9 +180,9 @@ def forward(
             ignore_index=ignore_index,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
-            beta=beta,
             compute_ce_loss=compute_ce_loss,
             temperature=temperature,
+            beta=beta,
             **loss_kwargs,
         )

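Note: the change above is pure plumbing — extra keyword arguments given to the base forward (and, symmetrically, to the HF reference in test/utils.py) are now forwarded untouched into distillation_loss_fn, so per-loss options such as beta no longer need a dedicated positional slot. A minimal sketch of the pass-through pattern; the names here are illustrative, not the library's API:

    import torch

    def toy_soft_loss(student_logits, teacher_logits, beta=0.5):
        # stand-in for the real fused/chunked distillation loss
        return beta * ((student_logits - teacher_logits) ** 2).mean()

    def compute_loss(student_chunk, teacher_chunk, loss_fn, **loss_kwargs):
        # extra keyword args (e.g. beta=0.1) flow straight into the concrete loss
        return loss_fn(student_chunk, teacher_chunk, **loss_kwargs)

    s_chunk, t_chunk = torch.randn(4, 8), torch.randn(4, 8)
    print(compute_loss(s_chunk, t_chunk, loss_fn=toy_soft_loss, beta=0.1))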
19 changes: 12 additions & 7 deletions src/liger_kernel/chunked_loss/jsd_loss.py
@@ -19,15 +19,20 @@ def distillation_loss_fn(student_logits, teacher_logits, beta=0.5):
         student_log_probs = F.log_softmax(student_logits, dim=-1)
         teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

-        # Compute probabilities (only required for mean calculation)
-        mean_probs = beta * student_log_probs.exp() + (1 - beta) * teacher_log_probs.exp()
-        log_mean_probs = mean_probs.log()
+        if beta == 0:
+            jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="sum", log_target=True)
+        elif beta == 1:
+            jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
+        else:
+            # Compute probabilities (only required for mean calculation)
+            mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
+            log_mean_probs = mean_probs.log()

-        student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
-        teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
+            student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
+            teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)

-        # JSD is the weighted average of the KL divergences
-        jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
+            # JSD is the weighted average of the KL divergences
+            jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
         return jsd_loss

     @classmethod
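Note: for 0 < beta < 1 the function above computes the interpolated JSD with M = (1 - beta) * P_student + beta * P_teacher and loss = beta * KL(teacher || M) + (1 - beta) * KL(student || M); beta = 0 and beta = 1 are special-cased to the plain forward KL(teacher || student) and reverse KL(student || teacher), so M is never formed when one weight is zero. A small standalone cross-check of the interior-beta formula against an explicit elementwise sum (a sketch; shapes, seed, and tolerance are arbitrary):

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    beta = 0.3
    log_q = F.log_softmax(torch.randn(2, 5), dim=-1)  # student
    log_p = F.log_softmax(torch.randn(2, 5), dim=-1)  # teacher

    # closed form, as in distillation_loss_fn for 0 < beta < 1
    m = (1 - beta) * log_q.exp() + beta * log_p.exp()
    log_m = m.log()
    jsd = beta * F.kl_div(log_m, log_p, reduction="sum", log_target=True) + (1 - beta) * F.kl_div(
        log_m, log_q, reduction="sum", log_target=True
    )

    # explicit elementwise sum of the same quantity
    expected = (beta * log_p.exp() * (log_p - log_m) + (1 - beta) * log_q.exp() * (log_q - log_m)).sum()
    assert torch.allclose(jsd, expected, atol=1e-6)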
41 changes: 30 additions & 11 deletions src/liger_kernel/ops/jsd.py
@@ -51,24 +51,43 @@ def _jsd_kernel(
         Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)

         if beta == 0.0:  # forward KL
-            Y_prob = tl.exp(Y)
+            Y_max = tl.max(Y, axis=0)
+            Y_shifted = Y - Y_max
+            Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max)  # Compensate for the shift
             loss = Y_prob * (Y - X)
             dX = -Y_prob
-        elif beta == 1.0:
-            X_prob = tl.exp(X)
+        elif beta == 1.0:  # reverse KL
+            X_max = tl.max(X, axis=0)
+            X_shifted = X - X_max
+            X_prob = tl.exp(X_shifted) * tl.exp(X_max)  # Compensate for the shift
             loss = X_prob * (X - Y)
             dX = loss + X_prob
         else:
-            Q = tl.exp(X)
-            P = tl.exp(Y)
-            M = beta * P + (1 - beta) * Q
-            log_M = tl.log(M)
+            max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0))
+            X_shifted = X - max_val
+            Y_shifted = Y - max_val

-            loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
-            dX = (1 - beta) * Q * (X - log_M)
+            # Pre-compute exp(max_val) since it's used twice
+            exp_max = tl.exp(max_val)
+
+            # Compute exp terms with compensation
+            Q = tl.exp(X_shifted) * exp_max  # = exp(X)
+            P = tl.exp(Y_shifted) * exp_max  # = exp(Y)
+
+            # Pre-compute common terms
+            beta_P = beta * P
+            one_minus_beta_Q = (1 - beta) * Q
+            M = beta_P + one_minus_beta_Q
+            log_M = tl.log(M)  # No need to compensate as M is already in original scale
+
+            loss = beta_P * Y + one_minus_beta_Q * X - M * log_M
+            dX = one_minus_beta_Q * (X - log_M)

-        loss = loss / n_non_ignore
-        dX = dX / n_non_ignore
+        # Pre-compute scaling factor
+        scale = 1.0 / n_non_ignore
+        loss = loss * scale
+        dX = dX * scale
+
         tl.store(loss_ptr + offsets, loss, mask=mask)
         tl.store(dX_ptr + offsets, dX, mask=mask)

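Note: the general-beta branch computes, per element, loss = beta * P * Y + (1 - beta) * Q * X - M * log M with Q = exp(X), P = exp(Y), M = beta * P + (1 - beta) * Q, and its gradient with respect to X is dX = (1 - beta) * Q * (X - log M); both are then scaled by 1 / n_non_ignore. A quick autograd cross-check of that closed-form gradient in plain PyTorch (a sketch mirroring the kernel's math, not the Triton code itself):

    import torch

    torch.manual_seed(0)
    beta = 0.3
    # stand-ins for student/teacher log-probs (log Q, log P); normalization is
    # not needed for the gradient identity to hold
    X = torch.randn(8, dtype=torch.double, requires_grad=True)
    Y = torch.randn(8, dtype=torch.double)

    Q, P = X.exp(), Y.exp()
    M = beta * P + (1 - beta) * Q
    loss = (beta * P * Y + (1 - beta) * Q * X - M * M.log()).sum()
    loss.backward()

    # closed-form gradient used by the kernel (before the 1/n_non_ignore scaling)
    dX = (1 - beta) * Q * (X - M.log())
    assert torch.allclose(X.grad, dX.detach())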
25 changes: 15 additions & 10 deletions test/chunked_loss/test_jsd_loss.py
@@ -27,15 +27,13 @@ def __init__(
         ignore_index: int = -100,
         weight_hard_loss: float = 0.5,
         weight_soft_loss: float = 0.5,
-        beta: float = 0.5,
     ):
         super().__init__(
             ignore_index=ignore_index,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
             temperature=temperature,
         )
-        self.beta = (beta,)

     def distillation_loss(self, student_logits, teacher_logits, beta=0.5):
         """
@@ -50,15 +48,20 @@ def distillation_loss(self, student_logits, teacher_logits, beta=0.5):
         student_log_probs = F.log_softmax(student_logits, dim=-1)
         teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

-        # Compute probabilities (only required for mean calculation)
-        mean_probs = beta * student_log_probs.exp() + (1 - beta) * teacher_log_probs.exp()
-        log_mean_probs = mean_probs.log()
+        if beta == 0:
+            jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
+        elif beta == 1:
+            jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
+        else:
+            # Compute probabilities (only required for mean calculation)
+            mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
+            log_mean_probs = mean_probs.log()

-        student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="batchmean", log_target=True)
-        teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="batchmean", log_target=True)
+            student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="batchmean", log_target=True)
+            teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="batchmean", log_target=True)

-        # JSD is the weighted average of the KL divergences
-        jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
+            # JSD is the weighted average of the KL divergences
+            jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
         return jsd_loss


@@ -88,12 +91,12 @@ def __init__(
         # smaller student model weights
         self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=bias, dtype=dtype, device=device)
         self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype, device=device)
+        self.beta = beta
         self.jsd = HFJSDLoss(
             ignore_index=ignore_index,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
             temperature=temperature,
-            beta=beta,
         ).get_batch_loss_metrics

     def forward(self, student_input, teacher_input, target):
@@ -105,6 +108,7 @@ def forward(self, student_input, teacher_input, target):
             target,
             self.student_lin.bias,
             self.teacher_lin.bias,
+            beta=self.beta,
         )
         return jsd_loss

@@ -132,6 +136,7 @@ def __init__(
             weight_soft_loss=weight_soft_loss,
             ignore_index=ignore_index,
             temperature=temperature,
+            beta=beta,
         )

     def forward(self, student_input, teacher_input, target):
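Note: the removed line self.beta = (beta,) stored a one-element tuple rather than the float (a trailing-comma slip), which is why the test reference now keeps beta on the torch wrapper and passes it per call via **loss_kwargs instead. Purely for illustration of the slip:

    beta = 0.5
    stored = (beta,)               # what the removed line actually kept
    print(type(stored), stored)    # <class 'tuple'> (0.5,)
    print(type(beta), beta)        # <class 'float'> 0.5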
6 changes: 3 additions & 3 deletions test/transformers/test_jsd.py
@@ -33,13 +33,13 @@ def __init__(

     def forward(
         self,
-        log_q: torch.Tensor,  # input
+        log_q: torch.Tensor,  # input student logits
         log_p: torch.Tensor,  # target
         label: Optional[torch.Tensor] = None,
     ):
-        if self.beta == 0.0:
+        if self.beta == 0.0:  # KL(p||q) -> kl(q, p)
             loss = self.kl(log_q, log_p).sum(dim=-1)
-        elif self.beta == 1.0:
+        elif self.beta == 1.0:  # KL(q||p) -> kl(p, q)
             loss = self.kl(log_p, log_q).sum(dim=-1)
         else:
             log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
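Note: the new comments spell out the argument-order convention the reference relies on: the mathematical KL(p||q) is obtained by calling the torch KL primitive with the arguments swapped, i.e. with the student distribution q as the input and the teacher p as the target. A tiny demonstration with torch.nn.functional.kl_div (a sketch, independent of the test harness):

    import torch
    import torch.nn.functional as F

    log_p = torch.log_softmax(torch.randn(5), dim=-1)  # target (teacher)
    log_q = torch.log_softmax(torch.randn(5), dim=-1)  # input (student)

    # KL(p || q) = sum p * (log p - log q); torch puts the *input* (q) first
    kl_pq = F.kl_div(log_q, log_p, reduction="sum", log_target=True)
    expected = (log_p.exp() * (log_p - log_q)).sum()
    assert torch.allclose(kl_pq, expected)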
5 changes: 3 additions & 2 deletions test/utils.py
@@ -593,7 +593,7 @@ def __init__(
         self.temperature = temperature

     @abstractmethod
-    def distillation_loss(self, student_logits, teacher_logits):
+    def distillation_loss(self, student_logits, teacher_logits, **loss_kwargs):
         """Abstract method for computing distillation loss."""
         pass

@@ -666,6 +666,7 @@ def get_batch_loss_metrics(
         target: torch.LongTensor,
         student_bias: torch.FloatTensor = None,
         teacher_bias: torch.FloatTensor = None,
+        **loss_kwargs,
     ):
         """Compute the distillation loss metrics for the given batch."""
         forward_output = self.concatenated_forward(
@@ -686,7 +687,7 @@ def get_batch_loss_metrics(
         student_logits /= self.temperature
         teacher_logits /= self.temperature

-        soft_loss = self.distillation_loss(student_logits, teacher_logits)
+        soft_loss = self.distillation_loss(student_logits, teacher_logits, **loss_kwargs)
         # full loss
         loss = self.weight_hard_loss * hard_loss + self.weight_soft_loss * soft_loss
         return loss