
Commit c950800

[JAX] Decouple Recipe and ScalingMode (#1728)
* Decouple recipe and scaling mode
* Expose global QuantizeConfig instance as a getter
* Format and lint
* Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling
* Rename UsageType to TensorSource
* Update test_layer.py

Signed-off-by: Jeremy Berchtold <[email protected]>
Signed-off-by: jberchtold-nvidia <[email protected]>
1 parent 04add79 commit c950800

File tree

8 files changed

+260 −192 lines changed
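The diffs below replace direct class-level access to QuantizeConfig with a get_quantize_config() getter and a per-tensor-source scaling-mode query. A minimal sketch of the resulting usage, assembled only from names that appear in this commit (get_quantize_config, TensorSource, ScalingMode, fp8_autocast, MeshResource); treat it as illustrative rather than the authoritative API surface:

from transformer_engine.common.recipe import DelayedScaling, Format as FP8Format
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.quantize import get_quantize_config, ScalingMode, TensorSource
from transformer_engine.jax.sharding import MeshResource

# Enter an FP8 region with a delayed-scaling recipe, then inspect the active config.
ds = DelayedScaling(margin=0.0, fp8_format=FP8Format.HYBRID, amax_history_len=1024)
with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
    cfg = get_quantize_config()  # a config instance, no longer the QuantizeConfig class itself
    assert cfg.is_fp8_enabled()
    # Scaling mode is now queried per tensor source instead of one global SCALING_MODE.
    assert cfg.get_scaling_mode(TensorSource.X) == ScalingMode.DELAYED_TENSOR_SCALING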

tests/jax/test_helper.py

Lines changed: 21 additions & 16 deletions
@@ -14,10 +14,11 @@
 from transformer_engine.common.recipe import Format as FP8Format
 from transformer_engine.jax import fp8_autocast, get_delayed_scaling
 from transformer_engine.jax.quantize import (
-    QuantizeConfig,
+    get_quantize_config,
     is_fp8_available,
     ScalingMode,
     update_collections,
+    TensorSource,
 )
 from transformer_engine.jax.sharding import MeshResource, global_mesh_resource

@@ -49,7 +50,7 @@ def test_update_collections(self):
 class TestFP8Functions(unittest.TestCase):

     def _check_default_state(self):
-        self.assertFalse(QuantizeConfig.is_fp8_enabled())
+        self.assertFalse(get_quantize_config().is_fp8_enabled())

     def _compare_delay_scaling(self, ref, test):
         self.assertTrue(ref.margin == test.margin)

@@ -58,17 +59,23 @@ def _compare_delay_scaling(self, ref, test):
         self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)

     def _compare_current_scaling(self, test):
-        self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
-        self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING)
+        self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format)
+        for tensor_source in TensorSource:
+            self.assertEqual(
+                get_quantize_config().get_scaling_mode(tensor_source),
+                ScalingMode.CURRENT_TENSOR_SCALING,
+            )

     def _compare_mxfp8_scaling(self, test):
-        self.assertEqual(QuantizeConfig.MARGIN, test.margin)
-        self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
-        self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.MXFP8_1D_SCALING)
+        self.assertEqual(get_quantize_config().MARGIN, test.margin)
+        self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format)
+        for tensor_source in TensorSource:
+            self.assertEqual(
+                get_quantize_config().get_scaling_mode(tensor_source), ScalingMode.MXFP8_1D_SCALING
+            )

     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_fp8_autocast_delayed_scaling(self):
-        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
         self._check_default_state()

         with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling(), mesh_resource=MeshResource()):

@@ -78,21 +85,20 @@ def test_fp8_autocast_delayed_scaling(self):

         ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
         with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
-            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self.assertTrue(get_quantize_config().is_fp8_enabled())
             self._compare_delay_scaling(get_delayed_scaling(), ds)

         self._check_default_state()

         ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
         with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
-            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self.assertTrue(get_quantize_config().is_fp8_enabled())
             self._compare_delay_scaling(get_delayed_scaling(), ds)

         self._check_default_state()

     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_fp8_autocast_current_scaling(self):
-        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
         self._check_default_state()

         with fp8_autocast(

@@ -104,21 +110,20 @@ def test_fp8_autocast_current_scaling(self):

         cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3)
         with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()):
-            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self.assertTrue(get_quantize_config().is_fp8_enabled())
             self._compare_current_scaling(cs)

         self._check_default_state()

         cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID)
         with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()):
-            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self.assertTrue(get_quantize_config().is_fp8_enabled())
             self._compare_current_scaling(cs)

         self._check_default_state()

     @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
     def test_fp8_autocast_mxfp8_block_scaling(self):
-        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
         self._check_default_state()

         with fp8_autocast(

@@ -130,14 +135,14 @@ def test_fp8_autocast_mxfp8_block_scaling(self):

         bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3)
         with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()):
-            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self.assertTrue(get_quantize_config().is_fp8_enabled())
             self._compare_mxfp8_scaling(bs)

         self._check_default_state()

         bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
         with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()):
-            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self.assertTrue(get_quantize_config().is_fp8_enabled())
             self._compare_mxfp8_scaling(bs)

         self._check_default_state()
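The updated comparison helpers above iterate over TensorSource and check the scaling mode for each source. Pulled out of the test class, that pattern looks roughly like this; the helper name is ours, and the Float8CurrentScaling import location is assumed to sit alongside Format in transformer_engine.common.recipe:

from transformer_engine.common.recipe import Float8CurrentScaling, Format as FP8Format
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.quantize import get_quantize_config, ScalingMode, TensorSource
from transformer_engine.jax.sharding import MeshResource

def assert_scaling_mode_for_all_sources(expected_mode):
    # Every tensor source should report the scaling mode implied by the active recipe.
    cfg = get_quantize_config()
    for tensor_source in TensorSource:
        assert cfg.get_scaling_mode(tensor_source) == expected_mode

cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3)
with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()):
    assert_scaling_mode_for_all_sources(ScalingMode.CURRENT_TENSOR_SCALING)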

tests/jax/test_layer.py

Lines changed: 21 additions & 24 deletions
@@ -23,12 +23,14 @@
 from transformer_engine.common import recipe
 from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
 from transformer_engine.jax.quantize import (
-    QuantizeConfig,
+    get_quantize_config,
     ScalingMode,
     is_fp8_available,
     update_collections,
+    TensorSource,
+    fp8_autocast,
 )
-from transformer_engine.jax.sharding import MeshResource, global_shard_guard
+from transformer_engine.jax.sharding import MeshResource


 @pytest.fixture(autouse=True, scope="function")

@@ -356,7 +358,7 @@ def test_backward(

         ref_params, test_params = self._sync_params(ref_params, test_params)

-        if QuantizeConfig.is_fp8_enabled():
+        if get_quantize_config().is_fp8_enabled():
             for _ in range(4):
                 _, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)(
                     inputs,

@@ -365,12 +367,15 @@ def test_backward(
                     test_others,
                     test_layer,
                 )
-                if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
+                if (
+                    get_quantize_config().get_scaling_mode(TensorSource.X)
+                    == ScalingMode.DELAYED_TENSOR_SCALING
+                ):
                     _, updated_quantize_meta = flax.core.pop(
-                        updated_state[0], QuantizeConfig.COLLECTION_NAME
+                        updated_state[0], get_quantize_config().COLLECTION_NAME
                     )
                     test_others = update_collections(
-                        {QuantizeConfig.COLLECTION_NAME: updated_quantize_meta}, test_others
+                        {get_quantize_config().COLLECTION_NAME: updated_quantize_meta}, test_others
                     )
                     del updated_quantize_meta
                 del updated_state

@@ -500,41 +505,33 @@ class BaseTester:

     def test_forward(self, data_shape, dtype, attrs):
         """Test normal datatype forward"""
-        QuantizeConfig.finalize()  # Ensure FP8 disabled.
-        with global_shard_guard(
-            MeshResource()
-        ):  # Empty MeshResource is used as we are running on a single device
+        # Ensure FP8 disabled.
+        # Empty MeshResource is used as we are running on a single device
+        with fp8_autocast(enabled=False, mesh_resource=MeshResource()):
             self.runner(attrs).test_forward(data_shape, dtype)

     def test_backward(self, data_shape, dtype, attrs):
         """Test normal datatype backward"""
-        QuantizeConfig.finalize()  # Ensure FP8 disabled.
-        with global_shard_guard(
-            MeshResource()
-        ):  # Empty MeshResource is used as we are running on a single device
+        # Ensure FP8 disabled.
+        # Empty MeshResource is used as we are running on a single device
+        with fp8_autocast(enabled=False, mesh_resource=MeshResource()):
             self.runner(attrs).test_backward(data_shape, dtype)

     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
     @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
     def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
         """Test forward with fp8 enabled"""
-        QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
-        with global_shard_guard(
-            MeshResource()
-        ):  # Empty MeshResource is used as we are running on a single device
+        # Empty MeshResource is used as we are running on a single device
+        with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
             self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
-        QuantizeConfig.finalize()

     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
     @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
     def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
         """Test backward with fp8 enabled"""
-        QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
-        with global_shard_guard(
-            MeshResource()
-        ):  # Empty MeshResource is used as we are running on a single device
+        # Empty MeshResource is used as we are running on a single device
+        with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
             self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
-        QuantizeConfig.finalize()


 class TestEncoderLayer(BaseTester):
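The BaseTester changes above fold the manual QuantizeConfig.initialize()/finalize() calls and the global_shard_guard wrapper into a single fp8_autocast context. Stripped of the test scaffolding, the pattern is roughly as follows (illustrative only; the body of each with-block stands in for the runner calls):

from transformer_engine.common import recipe
from transformer_engine.jax.quantize import fp8_autocast
from transformer_engine.jax.sharding import MeshResource

# High-precision path: FP8 disabled; an empty MeshResource is used because the
# tests run on a single device.
with fp8_autocast(enabled=False, mesh_resource=MeshResource()):
    pass  # run the non-FP8 forward/backward case here

# FP8 path: the recipe goes straight into fp8_autocast instead of
# QuantizeConfig.initialize(fp8_recipe=...) / QuantizeConfig.finalize().
with fp8_autocast(enabled=True, fp8_recipe=recipe.DelayedScaling(), mesh_resource=MeshResource()):
    pass  # run the FP8 forward/backward case here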

transformer_engine/jax/cpp_extensions/base.py

Lines changed: 1 addition & 1 deletion
@@ -219,7 +219,7 @@ def manage_primitives(enable_names=None, disable_names=None, disable_all_first=F
     """
     Helper function to manage primitive states by name without modifying environment variables.
     Allows enabling specific primitives, disabling specific primitives, or disabling all primitives.
-    This helper is used in the QuantizeConfig.initialize() methods.
+    This helper is used in the get_quantize_config().initialize() methods.

     Args:
         enable_names: List of strings, each representing the name of a primitive class to enable. Defaults to None.

transformer_engine/jax/cpp_extensions/gemm.py

Lines changed: 3 additions & 3 deletions
@@ -28,7 +28,7 @@
     ScalingMode,
     Quantizer,
     GroupedQuantizer,
-    QuantizeConfig,
+    get_quantize_config,
     QuantizerSet,
     QuantizeLayout,
     noop_quantizer_set,

@@ -754,7 +754,7 @@ def _te_gemm(
     fuse_bias: bool = False,
     fuse_gelu: bool = False,
     grad: bool = False,
-    use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP,
+    use_split_accumulator: bool = get_quantize_config().FP8_2X_ACC_FPROP,
 ) -> Tuple[jax.Array, ...]:

     # Prepare non-quantized GEMM operands

@@ -1107,7 +1107,7 @@ def _jax_gemm_fp8_impl(lhs, rhs):
     ), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}"
     precision = (
         jax.lax.Precision.HIGHEST
-        if QuantizeConfig.FP8_2X_ACC_FPROP
+        if get_quantize_config().FP8_2X_ACC_FPROP
         else jax.lax.Precision.DEFAULT
     )
     return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision)
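In gemm.py the split-accumulator flag is now read off the config instance returned by get_quantize_config(); in the pure-JAX FP8 fallback it selects the dot precision. A condensed sketch of that selection, where jnp.dot and the dummy operands stand in for the actual tensor-scaling GEMM:

import jax
import jax.numpy as jnp
from transformer_engine.jax.quantize import get_quantize_config

# Same flag used as the default for use_split_accumulator in _te_gemm above.
precision = (
    jax.lax.Precision.HIGHEST
    if get_quantize_config().FP8_2X_ACC_FPROP
    else jax.lax.Precision.DEFAULT
)
lhs = jnp.ones((128, 256), dtype=jnp.bfloat16)
rhs = jnp.ones((256, 64), dtype=jnp.bfloat16)
out = jnp.dot(lhs, rhs, precision=precision)  # stand-in for the FP8 GEMM call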

transformer_engine/jax/flax/module.py

Lines changed: 19 additions & 12 deletions
@@ -32,7 +32,14 @@
     jax_scaled_masked_softmax,
     jax_scaled_upper_triang_masked_softmax,
 )
-from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode
+from ..quantize import (
+    QuantizerFactory,
+    get_quantize_config,
+    QuantizeMeta,
+    QuantizeMetaSet,
+    ScalingMode,
+    TensorSource,
+)

 PRNGKey = Any
 Shape = Tuple[int, ...]

@@ -350,7 +357,7 @@ def generate_quantize_meta(quantizer_name: str):
             collection_name = (
                 variable_collection
                 if variable_collection is not None
-                else QuantizeConfig.COLLECTION_NAME
+                else get_quantize_config().COLLECTION_NAME
             )
             scale = self.variable(
                 collection_name,

@@ -363,14 +370,14 @@ def generate_quantize_meta(quantizer_name: str):
                 collection_name,
                 f"{quantizer_name}{postfix}_amax_history",
                 jnp.zeros,
-                (QuantizeConfig.AMAX_HISTORY_LEN,),
+                (get_quantize_config().AMAX_HISTORY_LEN,),
                 jnp.float32,
             ).value
             return QuantizeMeta(scale=scale, amax_history=amax_history)

-        if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(
-            fp8_recipe, recipe.DelayedScaling
-        ):
+        if get_quantize_config().get_scaling_mode(
+            TensorSource.X
+        ) == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(fp8_recipe, recipe.DelayedScaling):
             x_meta = generate_quantize_meta("x")
             kernel_meta = generate_quantize_meta("kernel")
             grad_meta = generate_quantize_meta("grad")

@@ -483,7 +490,7 @@ def __call__(self, inputs: Array) -> Array:
             self.dtype,
         )

-        if not QuantizeConfig.is_fp8_enabled():
+        if not get_quantize_config().is_fp8_enabled():
             kernel = kernel.astype(input_dtype)

         if self.use_bias:

@@ -692,7 +699,7 @@ def __call__(self, inputs: Array) -> Array:
         quantizer_set = self.generate_quantizer_set()

         fuse_layernorm = (
-            QuantizeConfig.is_fp8_enabled()
+            get_quantize_config().is_fp8_enabled()
             and not self.return_layernorm_output
             and self.enable_layernorm
         )

@@ -743,7 +750,7 @@ def __call__(self, inputs: Array) -> Array:
             kernel_shape,
             self.dtype,
         )
-        if not QuantizeConfig.is_fp8_enabled():
+        if not get_quantize_config().is_fp8_enabled():
             kernel = kernel.astype(input_dtype)

         contract_ind = tuple(range(0, len(axis)))

@@ -1005,7 +1012,7 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
         # TODO(Phuong): use fuse_layernorm for high-precision
         # when NoOpQuantizer and Tensor are implemented
         fuse_layernorm = (
-            QuantizeConfig.is_fp8_enabled()
+            get_quantize_config().is_fp8_enabled()
             and not self.return_layernorm_output
             and self.enable_layernorm
         )

@@ -1088,7 +1095,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
             self.dtype,
         )

-        if not QuantizeConfig.is_fp8_enabled():
+        if not get_quantize_config().is_fp8_enabled():
             kernel_1 = kernel_1.astype(input_dtype)

         hidden_size = inputs.shape[-1]

@@ -1100,7 +1107,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
             kernel_2_shape,
             self.dtype,
         )
-        if not QuantizeConfig.is_fp8_enabled():
+        if not get_quantize_config().is_fp8_enabled():
             kernel_2 = kernel_2.astype(input_dtype)

         contract_ind = tuple(range(0, len(axis)))
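The module.py hunks above gate the creation of delayed-scaling metadata (a scale plus an amax history of length AMAX_HISTORY_LEN) on the scaling mode of the activation tensor source rather than a single global mode. A trimmed sketch of that check, with a hypothetical helper name; the condition itself mirrors the diff:

from transformer_engine.common import recipe
from transformer_engine.jax.quantize import get_quantize_config, ScalingMode, TensorSource

def needs_delayed_scaling_meta(fp8_recipe) -> bool:
    # Scale/amax-history variables are only required when the input (X) tensor
    # source uses delayed tensor scaling, or the recipe is DelayedScaling.
    cfg = get_quantize_config()
    return (
        cfg.get_scaling_mode(TensorSource.X) == ScalingMode.DELAYED_TENSOR_SCALING
        or isinstance(fp8_recipe, recipe.DelayedScaling)
    )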
