Skip to content

Commit d474723

Browse files
committed
fix underflows log issue
Signed-off-by: Pawel Gadzinski <[email protected]>
1 parent 20be25a commit d474723

File tree

1 file changed

+5
-8
lines changed

1 file changed

+5
-8
lines changed

transformer_engine/debug/features/utils/stats_computation.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -211,23 +211,20 @@ def add_underflows_stats(recipe_name: str, columnwise: bool = False):
211211
stats_to_num[stat_num] = len(stats_to_num)
212212
stats_to_num[stat_pct] = len(stats_to_num)
213213

214+
zero_values = torch.tensor([0, 127], device="cuda")
215+
214216
STATS[stat_num] = (
215-
lambda x, aux_dict: (
217+
lambda x, aux_dict:
216218
aux_dict[recipe_name].get_data_tensors(
217219
rowwise_data=not columnwise, columnwise_data=columnwise
218-
)
219-
== 0
220-
).sum()
221-
- (x == 0).sum(),
220+
).isin(zero_values).sum() - (x == 0).sum(),
222221
lambda buffers, _sn=stat_num: sum(_get(buffers, _sn)),
223222
)
224223
STATS[stat_pct] = (
225224
lambda x, aux_dict: (
226225
aux_dict[recipe_name].get_data_tensors(
227226
rowwise_data=not columnwise, columnwise_data=columnwise
228-
)
229-
== 0
230-
).sum()
227+
).isin(zero_values).sum() - (x == 0).sum())
231228
/ aux_dict[recipe_name].numel()
232229
* 100,
233230
lambda buffers, _sn_num=stat_num: 100

0 commit comments

Comments
 (0)