
Commit 449f5cb

[JSD] JSD fixes (#609)
## Summary

Fix the JSD loss to use `beta`, and make the Triton kernel more numerically stable by using the max-exp trick.

- Hardware Type: <BLANK>
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
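For reference, the generalized JSD the updated loss computes can be written as below (notation added here, not from the commit). Note that the `beta = 0` and `beta = 1` branches are special-cased to plain forward and reverse KL respectively, matching the kernel comments, rather than taking the literal limit of the mixture formula.

```latex
% Beta-weighted Jensen-Shannon divergence between the teacher distribution P
% and the student distribution Q, with mixture M.
M = (1 - \beta)\,Q + \beta\,P
\qquad
\mathrm{JSD}_\beta(P \,\|\, Q) = \beta\,\mathrm{KL}(P \,\|\, M) + (1 - \beta)\,\mathrm{KL}(Q \,\|\, M)
```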
1 parent a6dc70d commit 449f5cb

6 files changed: 65 additions, 35 deletions


src/liger_kernel/chunked_loss/fused_linear_distillation.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -117,7 +117,7 @@ def _compute_loss(

         hard_loss /= full_target.shape[0]

-        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk)
+        soft_loss = distillation_loss_fn(student_logits_chunk, teacher_logits_chunk, **loss_kwargs)
         soft_loss /= full_target.shape[0]

         loss = weight_hard_loss * hard_loss + weight_soft_loss * soft_loss
@@ -180,9 +180,9 @@ def forward(
             ignore_index=ignore_index,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
-            beta=beta,
             compute_ce_loss=compute_ce_loss,
             temperature=temperature,
+            beta=beta,
             **loss_kwargs,
         )
```

src/liger_kernel/chunked_loss/jsd_loss.py

Lines changed: 12 additions & 7 deletions
```diff
@@ -19,15 +19,20 @@ def distillation_loss_fn(student_logits, teacher_logits, beta=0.5):
         student_log_probs = F.log_softmax(student_logits, dim=-1)
         teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

-        # Compute probabilities (only required for mean calculation)
-        mean_probs = beta * student_log_probs.exp() + (1 - beta) * teacher_log_probs.exp()
-        log_mean_probs = mean_probs.log()
+        if beta == 0:
+            jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="sum", log_target=True)
+        elif beta == 1:
+            jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)
+        else:
+            # Compute probabilities (only required for mean calculation)
+            mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
+            log_mean_probs = mean_probs.log()

-        student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
-        teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
+            student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
+            teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)

-        # JSD is the weighted average of the KL divergences
-        jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
+            # JSD is the weighted average of the KL divergences
+            jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
         return jsd_loss

     @classmethod
```
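Outside the diff, a minimal standalone sketch of the same computation; the `beta_jsd` helper name and toy shapes are illustrative, not part of the library API:

```python
import torch
import torch.nn.functional as F

def beta_jsd(student_logits, teacher_logits, beta=0.5):
    # Mirrors the updated distillation_loss_fn: special-case beta = 0/1,
    # otherwise form the beta-weighted mixture M and combine the two KL terms.
    student_log_probs = F.log_softmax(student_logits, dim=-1)
    teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

    if beta == 0:  # forward KL(teacher || student)
        return F.kl_div(student_log_probs, teacher_log_probs, reduction="sum", log_target=True)
    if beta == 1:  # reverse KL(student || teacher)
        return F.kl_div(teacher_log_probs, student_log_probs, reduction="sum", log_target=True)

    mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
    log_mean_probs = mean_probs.log()
    student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="sum", log_target=True)
    teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="sum", log_target=True)
    return beta * teacher_kl + (1 - beta) * student_kl

# Quick check: beta = 0.5 recovers the symmetric JSD, which is non-negative.
s, t = torch.randn(4, 16), torch.randn(4, 16)
assert beta_jsd(s, t, beta=0.5) >= 0
```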

src/liger_kernel/ops/jsd.py

Lines changed: 30 additions & 11 deletions
```diff
@@ -51,24 +51,43 @@ def _jsd_kernel(
     Y = tl.load(Y_ptr + offsets, mask=mask, other=float("-inf")).to(tl.float32)

     if beta == 0.0:  # forward KL
-        Y_prob = tl.exp(Y)
+        Y_max = tl.max(Y, axis=0)
+        Y_shifted = Y - Y_max
+        Y_prob = tl.exp(Y_shifted) * tl.exp(Y_max)  # Compensate for the shift
         loss = Y_prob * (Y - X)
         dX = -Y_prob
-    elif beta == 1.0:
-        X_prob = tl.exp(X)
+    elif beta == 1.0:  # reverse KL
+        X_max = tl.max(X, axis=0)
+        X_shifted = X - X_max
+        X_prob = tl.exp(X_shifted) * tl.exp(X_max)  # Compensate for the shift
         loss = X_prob * (X - Y)
         dX = loss + X_prob
     else:
-        Q = tl.exp(X)
-        P = tl.exp(Y)
-        M = beta * P + (1 - beta) * Q
-        log_M = tl.log(M)
+        max_val = tl.maximum(tl.max(X, axis=0), tl.max(Y, axis=0))
+        X_shifted = X - max_val
+        Y_shifted = Y - max_val

-        loss = beta * P * Y + (1 - beta) * Q * X - M * log_M
-        dX = (1 - beta) * Q * (X - log_M)
+        # Pre-compute exp(max_val) since it's used twice
+        exp_max = tl.exp(max_val)
+
+        # Compute exp terms with compensation
+        Q = tl.exp(X_shifted) * exp_max  # = exp(X)
+        P = tl.exp(Y_shifted) * exp_max  # = exp(Y)
+
+        # Pre-compute common terms
+        beta_P = beta * P
+        one_minus_beta_Q = (1 - beta) * Q
+        M = beta_P + one_minus_beta_Q
+        log_M = tl.log(M)  # No need to compensate as M is already in original scale
+
+        loss = beta_P * Y + one_minus_beta_Q * X - M * log_M
+        dX = one_minus_beta_Q * (X - log_M)
+
+    # Pre-compute scaling factor
+    scale = 1.0 / n_non_ignore
+    loss = loss * scale
+    dX = dX * scale

-    loss = loss / n_non_ignore
-    dX = dX / n_non_ignore
     tl.store(loss_ptr + offsets, loss, mask=mask)
     tl.store(dX_ptr + offsets, dX, mask=mask)
```
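The kernel edits all rely on the identity exp(x) = exp(x - m) * exp(m). A small PyTorch sketch of that compensation (the `stable_exp` helper is illustrative only, not part of the kernel); since the kernel's X and Y are log-probabilities, as in the tests, the row max is at most 0 and the compensation factor exp(m) stays at most 1:

```python
import torch

def stable_exp(x: torch.Tensor) -> torch.Tensor:
    # Max-exp compensation, mirroring the kernel: exponentiate the shifted
    # values and multiply the shift back in afterwards.
    m = x.max(dim=-1, keepdim=True).values
    return torch.exp(x - m) * torch.exp(m)

log_probs = torch.log_softmax(torch.randn(2, 8), dim=-1)
assert torch.allclose(stable_exp(log_probs), torch.exp(log_probs))
```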

test/chunked_loss/test_jsd_loss.py

Lines changed: 15 additions & 10 deletions
```diff
@@ -27,15 +27,13 @@ def __init__(
         ignore_index: int = -100,
         weight_hard_loss: float = 0.5,
         weight_soft_loss: float = 0.5,
-        beta: float = 0.5,
     ):
         super().__init__(
             ignore_index=ignore_index,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
             temperature=temperature,
         )
-        self.beta = (beta,)

     def distillation_loss(self, student_logits, teacher_logits, beta=0.5):
         """
@@ -50,15 +48,20 @@ def distillation_loss(self, student_logits, teacher_logits, beta=0.5):
         student_log_probs = F.log_softmax(student_logits, dim=-1)
         teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

-        # Compute probabilities (only required for mean calculation)
-        mean_probs = beta * student_log_probs.exp() + (1 - beta) * teacher_log_probs.exp()
-        log_mean_probs = mean_probs.log()
+        if beta == 0:
+            jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
+        elif beta == 1:
+            jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
+        else:
+            # Compute probabilities (only required for mean calculation)
+            mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
+            log_mean_probs = mean_probs.log()

-        student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="batchmean", log_target=True)
-        teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="batchmean", log_target=True)
+            student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="batchmean", log_target=True)
+            teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="batchmean", log_target=True)

-        # JSD is the weighted average of the KL divergences
-        jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
+            # JSD is the weighted average of the KL divergences
+            jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
         return jsd_loss


@@ -88,12 +91,12 @@ def __init__(
         # smaller student model weights
         self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=bias, dtype=dtype, device=device)
         self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype, device=device)
+        self.beta = beta
         self.jsd = HFJSDLoss(
             ignore_index=ignore_index,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
             temperature=temperature,
-            beta=beta,
         ).get_batch_loss_metrics

     def forward(self, student_input, teacher_input, target):
@@ -105,6 +108,7 @@ def forward(self, student_input, teacher_input, target):
             target,
             self.student_lin.bias,
             self.teacher_lin.bias,
+            beta=self.beta,
         )
         return jsd_loss

@@ -132,6 +136,7 @@ def __init__(
             weight_soft_loss=weight_soft_loss,
             ignore_index=ignore_index,
             temperature=temperature,
+            beta=beta,
         )

     def forward(self, student_input, teacher_input, target):
```

test/transformers/test_jsd.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -33,13 +33,13 @@ def __init__(

     def forward(
         self,
-        log_q: torch.Tensor,  # input
+        log_q: torch.Tensor,  # input student logits
         log_p: torch.Tensor,  # target
         label: Optional[torch.Tensor] = None,
     ):
-        if self.beta == 0.0:
+        if self.beta == 0.0:  # KL(p||q) -> kl(q, p)
             loss = self.kl(log_q, log_p).sum(dim=-1)
-        elif self.beta == 1.0:
+        elif self.beta == 1.0:  # KL(q||p) -> kl(p, q)
             loss = self.kl(log_p, log_q).sum(dim=-1)
         else:
             log_p, log_q = log_p.to(torch.float), log_q.to(torch.float)
```
test/utils.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -593,7 +593,7 @@ def __init__(
         self.temperature = temperature

     @abstractmethod
-    def distillation_loss(self, student_logits, teacher_logits):
+    def distillation_loss(self, student_logits, teacher_logits, **loss_kwargs):
         """Abstract method for computing distillation loss."""
         pass

@@ -666,6 +666,7 @@ def get_batch_loss_metrics(
         target: torch.LongTensor,
         student_bias: torch.FloatTensor = None,
         teacher_bias: torch.FloatTensor = None,
+        **loss_kwargs,
     ):
         """Compute the distillation loss metrics for the given batch."""
         forward_output = self.concatenated_forward(
@@ -686,7 +687,7 @@ def get_batch_loss_metrics(
         student_logits /= self.temperature
         teacher_logits /= self.temperature

-        soft_loss = self.distillation_loss(student_logits, teacher_logits)
+        soft_loss = self.distillation_loss(student_logits, teacher_logits, **loss_kwargs)
         # full loss
         loss = self.weight_hard_loss * hard_loss + self.weight_soft_loss * soft_loss
         return loss
```
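The `**loss_kwargs` threading above is what carries `beta` from `get_batch_loss_metrics` down to `distillation_loss` without widening any intermediate signature; a stripped-down sketch of the pattern (class and method names here are hypothetical):

```python
class DistillationBase:
    # Extra keyword arguments pass through untouched, so a new knob like
    # `beta` reaches the loss function without changing intermediate code.
    def distillation_loss(self, student_logits, teacher_logits, **loss_kwargs):
        raise NotImplementedError

    def get_batch_loss_metrics(self, student_logits, teacher_logits, **loss_kwargs):
        return self.distillation_loss(student_logits, teacher_logits, **loss_kwargs)


class JSDDistillation(DistillationBase):
    def distillation_loss(self, student_logits, teacher_logits, beta=0.5):
        # beta-weighted JSD as in the jsd_loss.py diff above (omitted here).
        ...


# Caller side: beta rides along as a keyword argument, e.g.
# JSDDistillation().get_batch_loss_metrics(student_logits, teacher_logits, beta=0.1)
```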
