Skip to content
Merged
12 changes: 7 additions & 5 deletions tests/pytorch/debug/test_api_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def assert_empty():
)[0]

expected_underflows = (
((tensor_fp8._data == 0).sum() - (tensor == 0).sum()) * 100 / (100 * 100 * 5)
((tensor_fp8.dequantize() == 0).sum() - (tensor == 0).sum()) * 100 / (100 * 100 * 5)
)

assert debug_api.transformer_engine.inspect_tensor_enabled(
Expand Down Expand Up @@ -302,7 +302,7 @@ def assert_empty():
)[0]

# Second config in same yaml
tensor = torch.rand((100, 100, 5))
tensor = torch.rand((100, 100, 5)).cuda()
debug_api.transformer_engine.inspect_tensor(
"decoder.6.mlp.fc1",
tensor_name="activation",
Expand All @@ -316,7 +316,9 @@ def assert_empty():
stats = log()
stats_names = [x[3] for x in stats.keys()]
all(s in stats_names for s in ["cur_amax", "dynamic_range", "mean", "std", "l1_norm"])
assert stats[("decoder.6.mlp.fc1", "activation", "mean", 200)] == tensor.mean()
torch.testing.assert_close(
stats[("decoder.6.mlp.fc1", "activation", "mean", 200)], tensor.mean()
)

debug_api.transformer_engine.inspect_tensor(
"decoder.7.mlp.fc1",
Expand All @@ -331,7 +333,7 @@ def assert_empty():
stats = log()
stats_names = [x[3] for x in stats.keys()]
all(s in stats_names for s in ["mean", "std", "l1_norm", "min", "max"])
assert stats[("decoder.7.mlp.fc1", "weight", "max", 200)] == tensor.max()
torch.testing.assert_close(stats[("decoder.7.mlp.fc1", "weight", "max", 200)], tensor.max())

assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.7.mlp.fc1", tensor_name="weight", iteration=201
Expand Down Expand Up @@ -377,7 +379,7 @@ def fp8_tensor(t):
return quantizer(t.cuda())

shape = [1024, 1024]
tensors = [torch.randn(shape) for _ in range(2)]
tensors = [torch.randn(shape).cuda() for _ in range(2)]
tensors_fp8 = [fp8_tensor(tensors[i]) for i in range(2)]

feed(tensors[0], tensors_fp8[0], quantizer)
Expand Down
10 changes: 4 additions & 6 deletions tests/pytorch/debug/test_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ def test_numerics(fp8_recipe, feature_dirs):
num_quantizers=3,
)

tensor = torch.zeros(1024, 1024).cuda()
tensor[0, :] = 1000
tensor = torch.randn(1024, 1024).cuda()
tensor[0, 100:200] = -0.0
quantizer = recipe_state.make_quantizers()[0]
quantized_tensor = quantizer(tensor)

Expand All @@ -191,15 +191,13 @@ def test_numerics(fp8_recipe, feature_dirs):
if "underflows%" in line:
underflows = float(line.split("value=")[1])
expected = (
((dequantized_tensor == 0).sum() - (tensor == 0).sum())
/ dequantized_tensor.numel()
* 100
((dequantized_tensor == 0).sum() - (tensor == 0).sum()) / tensor.numel() * 100
)
assert underflows == pytest.approx(expected.cpu(), abs=1e-4)
if "mse" in line:
mse = float(line.split("value=")[1])
expected = torch.nn.functional.mse_loss(dequantized_tensor, tensor, reduction="mean")
assert mse == pytest.approx(expected.cpu(), abs=1e-6)
assert mse == pytest.approx(expected.cpu(), abs=1e-4)
if "overflows%" in line:
overflows = float(line.split("value=")[1])
expected = (
Expand Down
26 changes: 18 additions & 8 deletions transformer_engine/debug/features/utils/stats_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,15 @@ def _get(buffers, stat_name):
),
}

FP8_NEGATIVE_ZERO = 128 # represnts -0.0 in fp8


def count_nonzero_fp8(fp8_data: torch.Tensor) -> torch.Tensor:
"""Count the number of non-zero elements in the fp8 data."""
fp8_data = fp8_data.view(dtype=torch.uint8)
zero_vals = torch.tensor([0, FP8_NEGATIVE_ZERO], device=fp8_data.device, dtype=torch.uint8)
return fp8_data.numel() - torch.isin(fp8_data, zero_vals).sum()


def add_underflows_stats(recipe_name: str, columnwise: bool = False):
"""Register *both* underflow stats (num and %) for the given recipe."""
Expand All @@ -212,22 +221,23 @@ def add_underflows_stats(recipe_name: str, columnwise: bool = False):
stats_to_num[stat_pct] = len(stats_to_num)

STATS[stat_num] = (
lambda x, aux_dict: (
lambda x, aux_dict: x.count_nonzero()
- count_nonzero_fp8(
aux_dict[recipe_name].get_data_tensors(
rowwise_data=not columnwise, columnwise_data=columnwise
)
== 0
).sum()
- (x == 0).sum(),
),
lambda buffers, _sn=stat_num: sum(_get(buffers, _sn)),
)
STATS[stat_pct] = (
lambda x, aux_dict: (
aux_dict[recipe_name].get_data_tensors(
rowwise_data=not columnwise, columnwise_data=columnwise
x.count_nonzero()
- count_nonzero_fp8(
aux_dict[recipe_name].get_data_tensors(
rowwise_data=not columnwise, columnwise_data=columnwise
)
)
== 0
).sum()
)
/ aux_dict[recipe_name].numel()
* 100,
lambda buffers, _sn_num=stat_num: 100
Expand Down