@@ -27,15 +27,13 @@ def __init__(
         ignore_index: int = -100,
         weight_hard_loss: float = 0.5,
         weight_soft_loss: float = 0.5,
-        beta: float = 0.5,
     ):
         super().__init__(
             ignore_index=ignore_index,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
             temperature=temperature,
         )
-        self.beta = (beta,)

     def distillation_loss(self, student_logits, teacher_logits, beta=0.5):
         """
@@ -50,15 +48,20 @@ def distillation_loss(self, student_logits, teacher_logits, beta=0.5):
         student_log_probs = F.log_softmax(student_logits, dim=-1)
         teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)

-        # Compute probabilities (only required for mean calculation)
-        mean_probs = beta * student_log_probs.exp() + (1 - beta) * teacher_log_probs.exp()
-        log_mean_probs = mean_probs.log()
+        if beta == 0:
+            jsd_loss = F.kl_div(student_log_probs, teacher_log_probs, reduction="none", log_target=True)
+        elif beta == 1:
+            jsd_loss = F.kl_div(teacher_log_probs, student_log_probs, reduction="none", log_target=True)
+        else:
+            # Compute probabilities (only required for mean calculation)
+            mean_probs = (1 - beta) * student_log_probs.exp() + beta * teacher_log_probs.exp()
+            log_mean_probs = mean_probs.log()

-        student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="batchmean", log_target=True)
-        teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="batchmean", log_target=True)
+            student_kl = F.kl_div(log_mean_probs, student_log_probs, reduction="batchmean", log_target=True)
+            teacher_kl = F.kl_div(log_mean_probs, teacher_log_probs, reduction="batchmean", log_target=True)

-        # JSD is the weighted average of the KL divergences
-        jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
+            # JSD is the weighted average of the KL divergences
+            jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
         return jsd_loss


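Two things are worth spelling out about the branching above. First, the explicit `beta == 0` / `beta == 1` cases exist because the general formula degenerates at the endpoints: at `beta = 0` the mixture equals the student distribution, so the weighted sum collapses to `KL(student || student) = 0`; the branches instead define the endpoints as pure forward KL, `KL(teacher || student)`, and reverse KL, `KL(student || teacher)`, by convention. Second, `F.kl_div(input, target, log_target=True)` computes `KL(target || input)`, which is why the student log-probs appear as the first argument in the forward-KL branch. A self-contained sketch that mirrors the logic (assumes PyTorch; simplified to use `batchmean` reduction in every branch, unlike the hunk above, and not the fused kernel under test):

```python
import torch
import torch.nn.functional as F

def generalized_jsd(student_logits, teacher_logits, beta=0.5):
    # Mirrors the reference branching above for intuition only.
    s = F.log_softmax(student_logits, dim=-1)
    t = F.log_softmax(teacher_logits, dim=-1)
    if beta == 0:
        # Forward KL(teacher || student), by convention.
        return F.kl_div(s, t, reduction="batchmean", log_target=True)
    if beta == 1:
        # Reverse KL(student || teacher), by convention.
        return F.kl_div(t, s, reduction="batchmean", log_target=True)
    # Log of the beta-weighted mixture of the two distributions.
    m = ((1 - beta) * s.exp() + beta * t.exp()).log()
    return beta * F.kl_div(m, t, reduction="batchmean", log_target=True) + (
        1 - beta
    ) * F.kl_div(m, s, reduction="batchmean", log_target=True)

a, b = torch.randn(4, 8), torch.randn(4, 8)
# beta = 0.5 recovers the classic symmetric Jensen-Shannon divergence.
assert torch.allclose(generalized_jsd(a, b, 0.5), generalized_jsd(b, a, 0.5), atol=1e-6)
```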
@@ -88,12 +91,12 @@ def __init__(
         # smaller student model weights
         self.student_lin = torch.nn.Linear(in_features=H // 2, out_features=V, bias=bias, dtype=dtype, device=device)
         self.teacher_lin = torch.nn.Linear(in_features=H, out_features=V, bias=bias, dtype=dtype, device=device)
+        self.beta = beta
         self.jsd = HFJSDLoss(
             ignore_index=ignore_index,
             weight_hard_loss=weight_hard_loss,
             weight_soft_loss=weight_soft_loss,
             temperature=temperature,
-            beta=beta,
         ).get_batch_loss_metrics

     def forward(self, student_input, teacher_input, target):
@@ -105,6 +108,7 @@ def forward(self, student_input, teacher_input, target):
             target,
             self.student_lin.bias,
             self.teacher_lin.bias,
+            beta=self.beta,
         )
         return jsd_loss

@@ -132,6 +136,7 @@ def __init__(
             weight_soft_loss=weight_soft_loss,
             ignore_index=ignore_index,
             temperature=temperature,
+            beta=beta,
         )

     def forward(self, student_input, teacher_input, target):
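One design note on the plumbing in these hunks: `self.jsd` stores the *bound method* `HFJSDLoss(...).get_batch_loss_metrics`, so per-call knobs such as `beta` now travel as call-time keyword arguments (`beta=self.beta`) rather than constructor state on the loss object. A minimal, self-contained illustration of that pattern (toy class, not the loss API):

```python
class Scaler:
    """Toy stand-in for HFJSDLoss: fixed config in __init__, knobs per call."""

    def __init__(self, scale):
        self.scale = scale

    def apply(self, x, beta=0.5):
        return self.scale * x * beta

# Capture the bound method once; the instance stays alive via the reference.
apply = Scaler(scale=2.0).apply
assert apply(10.0) == 10.0            # default beta=0.5
assert apply(10.0, beta=1.0) == 20.0  # per-call override, like beta=self.beta above
```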