@@ -68,6 +68,101 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
        )
        scale.copy_(new_scale)

+# this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files
+# and modified to only support dynamic scaling
+@torch._dynamo.allow_in_graph
+class float8_linear(torch.autograd.Function):
+    """
+    Like F.linear, but with X and W in float8
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        x_fp8,
+        w_fp8,
+        emulate: bool,
+        # TODO(this PR): split config into fwd/bwd
+        mm_config: ScaledMMConfig,
+    ):
+        ctx.save_for_backward(x_fp8, w_fp8)
+        ctx.emulate = emulate
+        ctx.mm_config = mm_config
+        # orig_shape = x_fp8._data.shape
+        orig_shape = x_fp8.shape
+        # x_fp8_reshaped = Float8Tensor(
+        #     x_fp8._data.reshape(-1, orig_shape[-1]), x_fp8._scale, x_fp8._orig_dtype, mm_config
+        # )
+        x_fp8_reshaped = x_fp8.reshape(-1, orig_shape[-1])
+
+        # w_fp8_t = Float8Tensor(w_fp8._data.t(), w_fp8._scale, w_fp8._orig_dtype, mm_config)
+        w_fp8_t = w_fp8.t()
+
+        res_bits = torch.mm(
+            x_fp8_reshaped, w_fp8_t
+        )
+        res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
+        return res_bits
+
+    @staticmethod
+    def backward(ctx, go_fp8):
+        x_fp8, w_fp8 = ctx.saved_tensors
+        emulate = ctx.emulate
+        mm_config = ctx.mm_config
+
+        go_fp8_orig_shape = go_fp8.shape
+        # go_fp8_reshaped = Float8Tensor(
+        #     go_fp8._data.reshape(-1, go_fp8_orig_shape[-1]),
+        #     go_fp8._scale,
+        #     go_fp8._orig_dtype,
+        #     mm_config,
+        # )
+        go_fp8_reshaped = go_fp8.reshape(-1, go_fp8_orig_shape[-1])
+
+        # w_fp8_t_c_t = Float8Tensor(
+        #     w_fp8._data.t().contiguous().t(), w_fp8._scale, w_fp8._orig_dtype, mm_config
+        # )
+        w_fp8_t_c_t = w_fp8.t().contiguous().t()
+
+        #
+        # calculate dL/dX
+        #
+        dL_dX = torch.mm(
+            go_fp8_reshaped,
+            w_fp8_t_c_t,
+        )
+        dL_dX = dL_dX.reshape(*go_fp8_orig_shape[:-1], dL_dX.shape[-1])
+
+        # x_fp8_orig_shape = x_fp8._data.shape
+        x_fp8_orig_shape = x_fp8.shape
+        # x_fp8_reshaped_t_c = Float8Tensor(
+        #     x_fp8._data.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous(),
+        #     x_fp8._scale,
+        #     x_fp8._orig_dtype,
+        #     mm_config,
+        # )
+        x_fp8_reshaped_t_c = x_fp8.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous()
+
+        # go_fp8_reshaped_t_c_t = Float8Tensor(
+        #     go_fp8_reshaped._data.t().contiguous().t(),
+        #     go_fp8_reshaped._scale,
+        #     go_fp8_reshaped._orig_dtype,
+        #     mm_config,
+        # )
+        go_fp8_reshaped_t_c_t = go_fp8_reshaped.t().contiguous().t()
+
+        #
+        # calculate dL/dW
+        #
+        dL_dW = torch.mm(
+            x_fp8_reshaped_t_c,
+            go_fp8_reshaped_t_c_t,
+        )
+        dL_dW = dL_dW.t()
+
+        empty_grads = None, None, None, None, None, None, None, None, None
+        return dL_dX, dL_dW, *empty_grads
+

 @torch._dynamo.allow_in_graph
 class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
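
The hand-written backward above computes the standard linear-layer gradients, dL/dX = dL/dY @ W and dL/dW = (X^T @ dL/dY)^T; the extra .t().contiguous().t() calls only adjust memory layout. Below is a quick sanity check of that math against autograd using plain float32 tensors in place of Float8Tensor (a standalone sketch, not part of this diff; shapes and names are illustrative, and the ScaledMMConfig plumbing is omitted):

import torch

# y = x @ w.t(), matching float8_linear.forward on plain tensors
x = torch.randn(4, 3, 8, requires_grad=True)   # (batch, seq, in_features)
w = torch.randn(16, 8, requires_grad=True)     # (out_features, in_features)
go = torch.randn(4, 3, 16)                     # upstream gradient dL/dY

y = torch.matmul(x, w.t())
y.backward(go)

# mirror float8_linear.backward: dL/dX = dL/dY @ W, dL/dW = (X^T @ dL/dY)^T
go_r = go.reshape(-1, go.shape[-1])
x_r = x.detach().reshape(-1, x.shape[-1])
dL_dX = torch.mm(go_r, w.detach()).reshape(*go.shape[:-1], w.shape[-1])
dL_dW = torch.mm(x_r.t(), go_r).t()

assert torch.allclose(x.grad, dL_dX, atol=1e-5)
assert torch.allclose(w.grad, dL_dW, atol=1e-5)
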
@@ -394,7 +489,12 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
         w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)

-        y = torch.matmul(x_fp8, w_fp8.t())
+        if not self.has_any_delayed_scaling:
+            emulate = False
+            mm_config = self.forward_config
+            y = float8_linear.apply(x_fp8, w_fp8, emulate, mm_config)
+        else:
+            y = torch.matmul(x_fp8, w_fp8.t())

         # Cast gradY to float8_e5m2 during backward
         y = self.cast_y_to_float8_in_bw(y)
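
The dynamic-scaling path above routes the matmul through float8_linear, an autograd.Function decorated with @torch._dynamo.allow_in_graph, so torch.compile can keep the hand-written forward/backward as a single node rather than graph-breaking on the custom function; delayed scaling keeps the existing torch.matmul path. A toy standalone sketch of that pattern (illustrative names and shapes, not the library's API; assumes a PyTorch 2.x build):

import torch

@torch._dynamo.allow_in_graph
class toy_linear(torch.autograd.Function):
    """Hand-written y = x @ w.t() with an explicit backward, traceable by dynamo."""

    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return torch.mm(x, w.t())

    @staticmethod
    def backward(ctx, go):
        x, w = ctx.saved_tensors
        return torch.mm(go, w), torch.mm(go.t(), x)  # dL/dx, dL/dw

@torch.compile
def fn(x, w):
    return toy_linear.apply(x, w)

x = torch.randn(8, 16, requires_grad=True)
w = torch.randn(32, 16, requires_grad=True)
fn(x, w).sum().backward()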