From 9f304bddc03970940ac8b414208816fb28f4675b Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Fri, 14 Mar 2025 13:48:35 +0100
Subject: [PATCH 1/3] initial fix

---
 src/liger_kernel/chunked_loss/jsd_loss.py | 19 ++++++++++++-------
 test/transformers/test_jsd.py             |  6 +++---
 2 files changed, 15 insertions(+), 10 deletions(-)

diff --git a/src/liger_kernel/chunked_loss/jsd_loss.py b/src/liger_kernel/chunked_loss/jsd_loss.py
index 2f1ca7d32..05440cbbb 100644
--- a/src/liger_kernel/chunked_loss/jsd_loss.py
+++ b/src/liger_kernel/chunked_loss/jsd_loss.py
@@ -19,15 +19,20 @@ def distillation_loss_fn(student_logits, teacher_logits, beta=0.5):
         student_log_probs = F.log_softmax(student_logits, dim=-1)
         teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

-        # Compute probabilities (only required for mean calculation)
-        mean_probs = beta * student_log_probs.exp() + (1 - beta) * teacher_log_probs.exp()
-        log_mean_probs = mean_probs.log()
+        if beta == 0:
+            jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="sum", log_target=True)
+        elif beta == 1:
+            jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
+        else:
+            # Compute probabilities (only required for mean calculation)
+            mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
+            log_mean_probs = mean_probs.log()

-        student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
-        teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
+            student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
+            teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)

-        # JSD is the weighted average of the KL divergences
-        jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
+            # JSD is the weighted average of the KL divergences
+            jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
         return jsd_loss

     @classmethod
diff --git a/test/transformers/test_jsd.py b/test/transformers/test_jsd.py
index c0214010c..42cbc4a26 100644
--- a/test/transformers/test_jsd.py
+++ b/test/transformers/test_jsd.py
@@ -33,13 +33,13 @@ def __init__(

     def forward(
         self,
-        log_q: torch.Tensor, # input
+        log_q: torch.Tensor, # input student logits
         log_p: torch.Tensor, # target
         label: Optional[torch.Tensor] = None,
     ):
-        if self.beta == 0.0:
+        if self.beta == 0.0: # KL(p||q) -> kl(q, p)
             loss = self.kl(log_q, log_p).sum(dim=-1)
-        elif self.beta == 1.0:
+        elif self.beta == 1.0: # KL(q||p) -> kl(p, q)
             loss = self.kl(log_p, log_q).sum(dim=-1)
         else:
             log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
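
Note on the hunk above: after this change, beta interpolates between forward KL (beta = 0), reverse KL (beta = 1), and a beta-weighted Jensen-Shannon divergence in between, with the mixture built as (1 - beta) * Q + beta * P from the student distribution Q and teacher distribution P. A minimal standalone PyTorch sketch of that objective follows; the function name and the sanity check are illustrative only and not part of the patched API.

    import torch
    import torch.nn.functional as F

    def generalized_jsd(student_logits, teacher_logits, beta=0.5):
        # mirrors the patched distillation_loss_fn: beta=0 -> KL(teacher||student),
        # beta=1 -> KL(student||teacher), otherwise a beta-weighted JSD
        student_log_probs = F.log_softmax(student_logits, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
        if beta == 0:
            return F.kl_div(student_log_probs, teacher_log_probs, reduction="sum", log_target=True)
        if beta == 1:
            return F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
        mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
        log_mean_probs = mean_probs.log()
        student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
        teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
        return beta * teacher_kl + (1 - beta) * student_kl

    # quick check: at beta=0.5 the divergence is symmetric in its two arguments
    s, t = torch.randn(4, 8), torch.randn(4, 8)
    assert torch.allclose(generalized_jsd(s, t, 0.5), generalized_jsd(t, s, 0.5), atol=1e-5)
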
From a26f0ab649964aa1dc591e9c958e2d8501ab43a4 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Fri, 14 Mar 2025 18:20:27 +0100
Subject: [PATCH 2/3] use max trick

---
 src/liger_kernel/ops/jsd.py   | 41 +++++++++++++++++++++++++----------
 test/transformers/test_jsd.py |  6 ++---
 2 files changed, 33 insertions(+), 14 deletions(-)

diff --git a/src/liger_kernel/ops/jsd.py b/src/liger_kernel/ops/jsd.py
index 882b4099b..da2f4bd36 100644
--- a/src/liger_kernel/ops/jsd.py
+++ b/src/liger_kernel/ops/jsd.py
@@ -51,24 +51,43 @@ def _jsd_kernel(
     Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)

     if beta == 0.0:  # forward KL
-        Y_prob = tl.exp(Y)
+        Y_max = tl.max(Y, axis=0)
+        Y_shifted = Y - Y_max
+        Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max)  # Compensate for the shift
         loss = Y_prob * (Y - X)
         dX = -Y_prob
-    elif beta == 1.0:
-        X_prob = tl.exp(X)
+    elif beta == 1.0:  # reverse KL
+        X_max = tl.max(X, axis=0)
+        X_shifted = X - X_max
+        X_prob = tl.exp(X_shifted) * tl.exp(X_max)  # Compensate for the shift
         loss = X_prob * (X - Y)
         dX = loss + X_prob
     else:
-        Q = tl.exp(X)
-        P = tl.exp(Y)
-        M = beta * P + (1 - beta) * Q
-        log_M = tl.log(M)
+        max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0))
+        X_shifted = X - max_val
+        Y_shifted = Y - max_val

-        loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
-        dX = (1 - beta) * Q * (X - log_M)
+        # Pre-compute exp(max_val) since it's used twice
+        exp_max = tl.exp(max_val)
+
+        # Compute exp terms with compensation
+        Q = tl.exp(X_shifted) * exp_max  # = exp(X)
+        P = tl.exp(Y_shifted) * exp_max  # = exp(Y)
+
+        # Pre-compute common terms
+        beta_P = beta * P
+        one_minus_beta_Q = (1 - beta) * Q
+        M = beta_P + one_minus_beta_Q
+        log_M = tl.log(M)  # No need to compensate as M is already in original scale
+
+        loss = beta_P * Y + one_minus_beta_Q * X - M * log_M
+        dX = one_minus_beta_Q * (X - log_M)
+
+    # Pre-compute scaling factor
+    scale = 1.0 / n_non_ignore
+    loss = loss * scale
+    dX = dX * scale

-    loss = loss / n_non_ignore
-    dX = dX / n_non_ignore
     tl.store(loss_ptr + offsets, loss, mask=mask)
     tl.store(dX_ptr + offsets, dX, mask=mask)
diff --git a/test/transformers/test_jsd.py b/test/transformers/test_jsd.py
index 42cbc4a26..e41f7a117 100644
--- a/test/transformers/test_jsd.py
+++ b/test/transformers/test_jsd.py
@@ -33,13 +33,13 @@ def __init__(

     def forward(
         self,
-        log_q: torch.Tensor, # input student logits
+        log_q: torch.Tensor,  # input student logits
         log_p: torch.Tensor, # target
         label: Optional[torch.Tensor] = None,
     ):
-        if self.beta == 0.0: # KL(p||q) -> kl(q, p)
+        if self.beta == 0.0:  # KL(p||q) -> kl(q, p)
             loss = self.kl(log_q, log_p).sum(dim=-1)
-        elif self.beta == 1.0: # KL(q||p) -> kl(p, q)
+        elif self.beta == 1.0:  # KL(q||p) -> kl(p, q)
             loss = self.kl(log_p, log_q).sum(dim=-1)
         else:
             log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
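
The kernel rewrite above is the usual shift-by-max trick: exp(x) is evaluated as exp(x - m) * exp(m) with m the block maximum, so the exponent of the shifted values stays bounded, and multiplying exp(m) back in restores the original scale; P, Q, and M therefore remain ordinary (unnormalized) probabilities and log_M needs no further correction. A small PyTorch sketch of that identity outside Triton, with illustrative variable names:

    import torch

    def shifted_exp(x: torch.Tensor) -> torch.Tensor:
        # analogue of: X_max = tl.max(X, axis=0); tl.exp(X - X_max) * tl.exp(X_max)
        m = x.max(dim=-1, keepdim=True).values
        return torch.exp(x - m) * torch.exp(m)  # compensate for the shift

    x = torch.tensor([[-3.0, 0.0, 2.5]])
    assert torch.allclose(shifted_exp(x), torch.exp(x))
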
From e12b36d82eb4bab36647b106aa38bf97be0b9a38 Mon Sep 17 00:00:00 2001
From: Kashif Rasul
Date: Fri, 14 Mar 2025 19:06:03 +0100
Subject: [PATCH 3/3] use the beta

---
 .../chunked_loss/fused_linear_distillation.py |  4 +--
 test/chunked_loss/test_jsd_loss.py            | 25 +++++++++++--------
 test/utils.py                                 |  5 ++--
 3 files changed, 20 insertions(+), 14 deletions(-)

diff --git a/src/liger_kernel/chunked_loss/fused_linear_distillation.py b/src/liger_kernel/chunked_loss/fused_linear_distillation.py
index b8461edfe..5b0ef192a 100644
--- a/src/liger_kernel/chunked_loss/fused_linear_distillation.py
+++ b/src/liger_kernel/chunked_loss/fused_linear_distillation.py
@@ -117,7 +117,7 @@ def _compute_loss(

         hard_loss /= full_target.shape[0]

-        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
+        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, **loss_kwargs)
         soft_loss /= full_target.shape[0]

         loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
@@ -180,9 +180,9 @@ def forward(
             ignore_index=ignore_index,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
-            beta=beta,
             compute_ce_loss=compute_ce_loss,
             temperature=temperature,
+            beta=beta,
             **loss_kwargs,
         )
diff --git a/test/chunked_loss/test_jsd_loss.py b/test/chunked_loss/test_jsd_loss.py
index df588250a..e2caf1b66 100644
--- a/test/chunked_loss/test_jsd_loss.py
+++ b/test/chunked_loss/test_jsd_loss.py
@@ -27,7 +27,6 @@ def __init__(
         ignore_index: int = -100,
         weight_hard_loss: float = 0.5,
         weight_soft_loss: float = 0.5,
-        beta: float = 0.5,
     ):
         super().__init__(
             ignore_index=ignore_index,
@@ -35,7 +34,6 @@
             weight_soft_loss=weight_soft_loss,
             temperature=temperature,
         )
-        self.beta = (beta,)

     def distillation_loss(self, student_logits, teacher_logits, beta=0.5):
         """
@@ -50,15 +48,20 @@
         student_log_probs = F.log_softmax(student_logits, dim=-1)
         teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

-        # Compute probabilities (only required for mean calculation)
-        mean_probs = beta * student_log_probs.exp() + (1 - beta) * teacher_log_probs.exp()
-        log_mean_probs = mean_probs.log()
+        if beta == 0:
+            jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
+        elif beta == 1:
+            jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
+        else:
+            # Compute probabilities (only required for mean calculation)
+            mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
+            log_mean_probs = mean_probs.log()

-        student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="batchmean", log_target=True)
-        teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="batchmean", log_target=True)
+            student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="batchmean", log_target=True)
+            teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="batchmean", log_target=True)

-        # JSD is the weighted average of the KL divergences
-        jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
+            # JSD is the weighted average of the KL divergences
+            jsd_loss = beta * teacher_kl + (1 - beta) * student_kl

         return jsd_loss

@@ -88,12 +91,12 @@
         # smaller student model weights
         self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=bias, dtype=dtype, device=device)
         self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype, device=device)
+        self.beta = beta
         self.jsd = HFJSDLoss(
             ignore_index=ignore_index,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
             temperature=temperature,
-            beta=beta,
         ).get_batch_loss_metrics

     def forward(self, student_input, teacher_input, target):
@@ -105,6 +108,7 @@
             target,
             self.student_lin.bias,
             self.teacher_lin.bias,
+            beta=self.beta,
         )
         return jsd_loss

@@ -132,6 +136,7 @@
             weight_soft_loss=weight_soft_loss,
             ignore_index=ignore_index,
             temperature=temperature,
+            beta=beta,
         )

     def forward(self, student_input, teacher_input, target):
diff --git a/test/utils.py b/test/utils.py
index 64a261bef..ecb180209 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -593,7 +593,7 @@ def __init__(
         self.temperature = temperature

     @abstractmethod
-    def distillation_loss(self, student_logits, teacher_logits):
+    def distillation_loss(self, student_logits, teacher_logits, **loss_kwargs):
         """Abstract method for computing distillation loss."""
         pass

@@ -666,6 +666,7 @@ def get_batch_loss_metrics(
         target: torch.LongTensor,
         student_bias: torch.FloatTensor = None,
         teacher_bias: torch.FloatTensor = None,
+        **loss_kwargs,
     ):
         """Compute the distillation loss metrics for the given batch."""
         forward_output = self.concatenated_forward(
@@ -686,7 +687,7 @@
         student_logits /= self.temperature
         teacher_logits /= self.temperature

-        soft_loss = self.distillation_loss(student_logits, teacher_logits)
+        soft_loss = self.distillation_loss(student_logits, teacher_logits, **loss_kwargs)
         # full loss
         loss = self.weight_hard_loss * hard_loss + self.weight_soft_loss * soft_loss
         return loss
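
With the test/utils.py change above, the reference implementation no longer stores beta on the loss object (the removed self.beta = (beta,) tuple) and instead forwards it per call through **loss_kwargs down to distillation_loss. A toy sketch of that plumbing, deliberately not the real classes; all names below are made up for illustration:

    class ToyDistillationBase:
        """Mimics the **loss_kwargs forwarding introduced in this patch."""

        def distillation_loss(self, student_logits, teacher_logits, **loss_kwargs):
            raise NotImplementedError

        def get_batch_loss_metrics(self, student_logits, teacher_logits, **loss_kwargs):
            # keyword arguments such as beta pass through untouched to the subclass hook
            return self.distillation_loss(student_logits, teacher_logits, **loss_kwargs)

    class ToyJSD(ToyDistillationBase):
        def distillation_loss(self, student_logits, teacher_logits, beta=0.5):
            return f"received beta={beta}"

    print(ToyJSD().get_batch_loss_metrics(None, None, beta=0.1))  # received beta=0.1
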