
Commit df1b16d

[PyTorch] Fix bug in FP8 cast in LayerNormLinear/LayerNormMLP (#738)
Perform FP8 cast on gathered layernorm output in LayerNormLinear

Signed-off-by: Tim Moon <[email protected]>
1 parent c1a68f6 commit df1b16d

File tree

3 files changed: +32 -8 lines changed

transformer_engine/pytorch/cpp_extensions/gemm.py

Lines changed: 2 additions & 0 deletions
@@ -44,6 +44,8 @@ def fp8_gemm(
     assert fp8_meta_tensor is not None and out_index is not None
     assert_dim_for_fp8_exec(A)
     assert_dim_for_fp8_exec(B)
+    assert A.dtype == torch.uint8
+    assert B.dtype == torch.uint8

     if out is None:
         out = torch.empty(
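For context, a minimal standalone sketch of the guard the new asserts express: FP8 operands reach this extension as raw torch.uint8 buffers, so a non-uint8 dtype signals a missing cast upstream. The helper name and error message below are illustrative, not part of the library.

```python
import torch

def check_fp8_gemm_inputs(A: torch.Tensor, B: torch.Tensor) -> None:
    # Both operands must already be FP8-encoded (stored as torch.uint8);
    # a float16/bfloat16 tensor here means an upstream cast was skipped.
    if A.dtype != torch.uint8 or B.dtype != torch.uint8:
        raise TypeError(
            f"fp8_gemm expects uint8-encoded FP8 inputs, got {A.dtype} and {B.dtype}"
        )
```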

transformer_engine/pytorch/module/layernorm_linear.py

Lines changed: 9 additions & 2 deletions
@@ -169,12 +169,19 @@ def forward(
                         out=ln_out_fp8)
                     ln_out = ln_out_fp8
                 else:
-                    ln_out = tex.cast_to_fp8(
-                        ln_out,
+                    ln_out_total = tex.cast_to_fp8(
+                        ln_out_total,
                         fp8_meta["scaling_fwd"],
                         tex.FP8FwdTensors.GEMM1_INPUT,
                         fp8_dtype_forward,
                     )
+                    if ln_out_gathered:
+                        rank = torch.distributed.get_rank(tp_group)
+                        slice_start = rank * ln_out.size(0)
+                        slice_end = (rank + 1) * ln_out.size(0)
+                        ln_out = ln_out_total[slice_start:slice_end, ...]
+                    else:
+                        ln_out = ln_out_total

         if fp8:
             bias_dtype = (
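The slicing added above follows the usual all-gather layout, where equal-sized per-rank shards are stacked along the first dimension. A minimal sketch of that pattern, with an illustrative helper name (not from the library):

```python
import torch

def local_shard(gathered: torch.Tensor, rank: int, shard_rows: int) -> torch.Tensor:
    # Rank r owns rows [r * shard_rows, (r + 1) * shard_rows) of the gathered tensor.
    start = rank * shard_rows
    end = (rank + 1) * shard_rows
    return gathered[start:end, ...]

# Example: 4 ranks, each contributing 2 rows of width 8.
gathered = torch.arange(4 * 2 * 8, dtype=torch.float32).reshape(8, 8)
assert local_shard(gathered, rank=1, shard_rows=2).shape == (2, 8)
```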

transformer_engine/pytorch/module/layernorm_mlp.py

Lines changed: 21 additions & 6 deletions
@@ -187,12 +187,27 @@ def forward(
         if return_layernorm_output:
             ln_out_return = ln_out_total if return_layernorm_output_gathered else ln_out
         if fp8:
-            ln_out = tex.cast_to_fp8(
-                ln_out,
-                fp8_meta["scaling_fwd"],
-                tex.FP8FwdTensors.GEMM1_INPUT,
-                fp8_dtype_forward,
-            )
+            if ub_overlap_ag:
+                ln_out = tex.cast_to_fp8(
+                    ln_out,
+                    fp8_meta["scaling_fwd"],
+                    tex.FP8FwdTensors.GEMM1_INPUT,
+                    fp8_dtype_forward,
+                )
+            else:
+                ln_out_total = tex.cast_to_fp8(
+                    ln_out_total,
+                    fp8_meta["scaling_fwd"],
+                    tex.FP8FwdTensors.GEMM1_INPUT,
+                    fp8_dtype_forward,
+                )
+                if ln_out_gathered:
+                    rank = torch.distributed.get_rank(tp_group)
+                    slice_start = rank * ln_out.size(0)
+                    slice_end = (rank + 1) * ln_out.size(0)
+                    ln_out = ln_out_total[slice_start:slice_end, ...]
+                else:
+                    ln_out = ln_out_total

         if fp8:
             bias_dtype = (
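A simplified view of the branch above, assuming ub_overlap_ag means the all-gather is overlapped with the GEMM via userbuffers; cast_fn stands in for tex.cast_to_fp8 (the real code also passes fp8_meta, the scale index, and the FP8 dtype):

```python
def cast_for_fp8_gemm(ln_out, ln_out_total, ln_out_gathered, ub_overlap_ag, rank, cast_fn):
    if ub_overlap_ag:
        # With userbuffer overlap, only the local shard is cast here.
        return cast_fn(ln_out), ln_out_total
    # Otherwise cast the already-gathered tensor once, then slice this rank's
    # rows back out so ln_out is a view into the same FP8 buffer as ln_out_total.
    ln_out_total = cast_fn(ln_out_total)
    if ln_out_gathered:
        start = rank * ln_out.size(0)
        ln_out = ln_out_total[start:start + ln_out.size(0), ...]
    else:
        ln_out = ln_out_total
    return ln_out, ln_out_total
```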

Comments (0)