File tree Expand file tree Collapse file tree 3 files changed +32
-8
lines changed
transformer_engine/pytorch Expand file tree Collapse file tree 3 files changed +32
-8
lines changed Original file line number Diff line number Diff line change @@ -44,6 +44,8 @@ def fp8_gemm(
44
44
assert fp8_meta_tensor is not None and out_index is not None
45
45
assert_dim_for_fp8_exec (A )
46
46
assert_dim_for_fp8_exec (B )
47
+ assert A .dtype == torch .uint8
48
+ assert B .dtype == torch .uint8
47
49
48
50
if out is None :
49
51
out = torch .empty (
Original file line number Diff line number Diff line change @@ -169,12 +169,19 @@ def forward(
169
169
out = ln_out_fp8 )
170
170
ln_out = ln_out_fp8
171
171
else :
172
- ln_out = tex .cast_to_fp8 (
173
- ln_out ,
172
+ ln_out_total = tex .cast_to_fp8 (
173
+ ln_out_total ,
174
174
fp8_meta ["scaling_fwd" ],
175
175
tex .FP8FwdTensors .GEMM1_INPUT ,
176
176
fp8_dtype_forward ,
177
177
)
178
+ if ln_out_gathered :
179
+ rank = torch .distributed .get_rank (tp_group )
180
+ slice_start = rank * ln_out .size (0 )
181
+ slice_end = (rank + 1 ) * ln_out .size (0 )
182
+ ln_out = ln_out_total [slice_start :slice_end , ...]
183
+ else :
184
+ ln_out = ln_out_total
178
185
179
186
if fp8 :
180
187
bias_dtype = (
Original file line number Diff line number Diff line change @@ -187,12 +187,27 @@ def forward(
187
187
if return_layernorm_output :
188
188
ln_out_return = ln_out_total if return_layernorm_output_gathered else ln_out
189
189
if fp8 :
190
- ln_out = tex .cast_to_fp8 (
191
- ln_out ,
192
- fp8_meta ["scaling_fwd" ],
193
- tex .FP8FwdTensors .GEMM1_INPUT ,
194
- fp8_dtype_forward ,
195
- )
190
+ if ub_overlap_ag :
191
+ ln_out = tex .cast_to_fp8 (
192
+ ln_out ,
193
+ fp8_meta ["scaling_fwd" ],
194
+ tex .FP8FwdTensors .GEMM1_INPUT ,
195
+ fp8_dtype_forward ,
196
+ )
197
+ else :
198
+ ln_out_total = tex .cast_to_fp8 (
199
+ ln_out_total ,
200
+ fp8_meta ["scaling_fwd" ],
201
+ tex .FP8FwdTensors .GEMM1_INPUT ,
202
+ fp8_dtype_forward ,
203
+ )
204
+ if ln_out_gathered :
205
+ rank = torch .distributed .get_rank (tp_group )
206
+ slice_start = rank * ln_out .size (0 )
207
+ slice_end = (rank + 1 ) * ln_out .size (0 )
208
+ ln_out = ln_out_total [slice_start :slice_end , ...]
209
+ else :
210
+ ln_out = ln_out_total
196
211
197
212
if fp8 :
198
213
bias_dtype = (
You can’t perform that action at this time.
0 commit comments