37 changes: 21 additions & 16 deletions tests/jax/test_helper.py
@@ -14,10 +14,11 @@
 from transformer_engine.common.recipe import Format as FP8Format
 from transformer_engine.jax import fp8_autocast, get_delayed_scaling
 from transformer_engine.jax.quantize import (
-    QuantizeConfig,
+    get_quantize_config,
     is_fp8_available,
     ScalingMode,
     update_collections,
+    TensorSource,
 )
 from transformer_engine.jax.sharding import MeshResource, global_mesh_resource
 
@@ -49,7 +50,7 @@ def test_update_collections(self):
 class TestFP8Functions(unittest.TestCase):
 
     def _check_default_state(self):
-        self.assertFalse(QuantizeConfig.is_fp8_enabled())
+        self.assertFalse(get_quantize_config().is_fp8_enabled())
 
     def _compare_delay_scaling(self, ref, test):
         self.assertTrue(ref.margin == test.margin)
@@ -58,17 +59,23 @@ def _compare_delay_scaling(self, ref, test):
         self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo)
 
     def _compare_current_scaling(self, test):
-        self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
-        self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING)
+        self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format)
+        for tensor_source in TensorSource:
+            self.assertEqual(
+                get_quantize_config().get_scaling_mode(tensor_source),
+                ScalingMode.CURRENT_TENSOR_SCALING,
+            )
 
     def _compare_mxfp8_scaling(self, test):
-        self.assertEqual(QuantizeConfig.MARGIN, test.margin)
-        self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format)
-        self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.MXFP8_1D_SCALING)
+        self.assertEqual(get_quantize_config().MARGIN, test.margin)
+        self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format)
+        for tensor_source in TensorSource:
+            self.assertEqual(
+                get_quantize_config().get_scaling_mode(tensor_source), ScalingMode.MXFP8_1D_SCALING
+            )
 
     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_fp8_autocast_delayed_scaling(self):
-        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
         self._check_default_state()
 
         with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling(), mesh_resource=MeshResource()):
@@ -78,21 +85,20 @@ def test_fp8_autocast_delayed_scaling(self):
 
         ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1)
         with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
-            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self.assertTrue(get_quantize_config().is_fp8_enabled())
             self._compare_delay_scaling(get_delayed_scaling(), ds)
 
         self._check_default_state()
 
         ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1)
         with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()):
-            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self.assertTrue(get_quantize_config().is_fp8_enabled())
             self._compare_delay_scaling(get_delayed_scaling(), ds)
 
         self._check_default_state()
 
     @unittest.skipIf(not is_fp8_supported, reason=reason)
     def test_fp8_autocast_current_scaling(self):
-        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
         self._check_default_state()
 
         with fp8_autocast(
@@ -104,21 +110,20 @@ def test_fp8_autocast_current_scaling(self):
 
         cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3)
         with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()):
-            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self.assertTrue(get_quantize_config().is_fp8_enabled())
             self._compare_current_scaling(cs)
 
         self._check_default_state()
 
         cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID)
         with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()):
-            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self.assertTrue(get_quantize_config().is_fp8_enabled())
             self._compare_current_scaling(cs)
 
         self._check_default_state()
 
     @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason)
     def test_fp8_autocast_mxfp8_block_scaling(self):
-        QuantizeConfig.finalize()  # Ensure the testing not affect by previous tests.
         self._check_default_state()
 
         with fp8_autocast(
@@ -130,14 +135,14 @@ def test_fp8_autocast_mxfp8_block_scaling(self):
 
         bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3)
         with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()):
-            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self.assertTrue(get_quantize_config().is_fp8_enabled())
             self._compare_mxfp8_scaling(bs)
 
         self._check_default_state()
 
         bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID)
         with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()):
-            self.assertTrue(QuantizeConfig.is_fp8_enabled())
+            self.assertTrue(get_quantize_config().is_fp8_enabled())
             self._compare_mxfp8_scaling(bs)
 
         self._check_default_state()
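For reference, a minimal usage sketch distilled from the tests above, using only names already shown in this diff: the module-level `get_quantize_config()` accessor replaces direct `QuantizeConfig` class access, and the scaling mode is now queried per `TensorSource` instead of through a single class-level `SCALING_MODE` attribute.

```python
# Minimal sketch; assumes the imports shown in this diff. The recipe
# choice is illustrative, any supported recipe follows the same pattern.
from transformer_engine.common.recipe import Float8CurrentScaling, Format
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.quantize import (
    ScalingMode,
    TensorSource,
    get_quantize_config,
)
from transformer_engine.jax.sharding import MeshResource

cs = Float8CurrentScaling(fp8_format=Format.E4M3)
with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()):
    assert get_quantize_config().is_fp8_enabled()
    # Scaling mode is queried per tensor source (x / kernel / grad),
    # not through a single class-level SCALING_MODE attribute.
    for source in TensorSource:
        assert (
            get_quantize_config().get_scaling_mode(source)
            == ScalingMode.CURRENT_TENSOR_SCALING
        )
# Outside the context manager the config reverts to its default state.
assert not get_quantize_config().is_fp8_enabled()
```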
45 changes: 21 additions & 24 deletions tests/jax/test_layer.py
@@ -23,12 +23,14 @@
 from transformer_engine.common import recipe
 from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType
 from transformer_engine.jax.quantize import (
-    QuantizeConfig,
+    get_quantize_config,
     ScalingMode,
     is_fp8_available,
     update_collections,
+    TensorSource,
+    fp8_autocast,
 )
-from transformer_engine.jax.sharding import MeshResource, global_shard_guard
+from transformer_engine.jax.sharding import MeshResource
 
 
 @pytest.fixture(autouse=True, scope="function")
@@ -356,7 +358,7 @@ def test_backward(
 
         ref_params, test_params = self._sync_params(ref_params, test_params)
 
-        if QuantizeConfig.is_fp8_enabled():
+        if get_quantize_config().is_fp8_enabled():
             for _ in range(4):
                 _, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)(
                     inputs,
@@ -365,12 +367,15 @@
                     test_others,
                     test_layer,
                 )
-                if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING:
+                if (
+                    get_quantize_config().get_scaling_mode(TensorSource.X)
+                    == ScalingMode.DELAYED_TENSOR_SCALING
+                ):
                     _, updated_quantize_meta = flax.core.pop(
-                        updated_state[0], QuantizeConfig.COLLECTION_NAME
+                        updated_state[0], get_quantize_config().COLLECTION_NAME
                     )
                     test_others = update_collections(
-                        {QuantizeConfig.COLLECTION_NAME: updated_quantize_meta}, test_others
+                        {get_quantize_config().COLLECTION_NAME: updated_quantize_meta}, test_others
                     )
                     del updated_quantize_meta
                     del updated_state
@@ -500,41 +505,33 @@ class BaseTester:
 
     def test_forward(self, data_shape, dtype, attrs):
         """Test normal datatype forward"""
-        QuantizeConfig.finalize()  # Ensure FP8 disabled.
-        with global_shard_guard(
-            MeshResource()
-        ):  # Empty MeshResource is used as we are running on a single device
+        # Ensure FP8 disabled.
+        # Empty MeshResource is used as we are running on a single device
+        with fp8_autocast(enabled=False, mesh_resource=MeshResource()):
             self.runner(attrs).test_forward(data_shape, dtype)
 
     def test_backward(self, data_shape, dtype, attrs):
         """Test normal datatype backward"""
-        QuantizeConfig.finalize()  # Ensure FP8 disabled.
-        with global_shard_guard(
-            MeshResource()
-        ):  # Empty MeshResource is used as we are running on a single device
+        # Ensure FP8 disabled.
+        # Empty MeshResource is used as we are running on a single device
+        with fp8_autocast(enabled=False, mesh_resource=MeshResource()):
             self.runner(attrs).test_backward(data_shape, dtype)
 
     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
     @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
     def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
         """Test forward with fp8 enabled"""
-        QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
-        with global_shard_guard(
-            MeshResource()
-        ):  # Empty MeshResource is used as we are running on a single device
+        # Empty MeshResource is used as we are running on a single device
+        with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
             self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3)
-        QuantizeConfig.finalize()
 
     @pytest.mark.skipif(not is_fp8_supported, reason=reason)
     @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES)
     def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe):
         """Test backward with fp8 enabled"""
-        QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
-        with global_shard_guard(
-            MeshResource()
-        ):  # Empty MeshResource is used as we are running on a single device
+        # Empty MeshResource is used as we are running on a single device
+        with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()):
             self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3)
-        QuantizeConfig.finalize()
 
 
 class TestEncoderLayer(BaseTester):
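The BaseTester hunk above collapses the manual `QuantizeConfig.initialize()`/`finalize()` lifecycle and the `global_shard_guard` wrapper into a single `fp8_autocast` context manager. A condensed before/after sketch, with `run_test()` as a hypothetical stand-in for the `self.runner(attrs)` calls:

```python
from transformer_engine.jax.quantize import fp8_autocast
from transformer_engine.jax.sharding import MeshResource

def run_test():
    """Placeholder for self.runner(attrs).test_forward / test_backward."""

# Old pattern (removed by this diff): manual config lifecycle plus a
# separate shard guard.
#   QuantizeConfig.initialize(fp8_recipe=fp8_recipe)
#   with global_shard_guard(MeshResource()):
#       run_test()
#   QuantizeConfig.finalize()

# New pattern: fp8_autocast owns both the config lifecycle and the mesh
# resource, so setup and teardown cannot leak across tests.
with fp8_autocast(enabled=False, mesh_resource=MeshResource()):
    run_test()
```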
2 changes: 1 addition & 1 deletion transformer_engine/jax/cpp_extensions/base.py
@@ -219,7 +219,7 @@ def manage_primitives(enable_names=None, disable_names=None, disable_all_first=F
     """
     Helper function to manage primitive states by name without modifying environment variables.
     Allows enabling specific primitives, disabling specific primitives, or disabling all primitives.
-    This helper is used in the QuantizeConfig.initialize() methods.
+    This helper is used in the get_quantize_config().initialize() methods.
 
     Args:
         enable_names: List of strings, each representing the name of a primitive class to enable. Defaults to None.
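The docstring above describes `manage_primitives` in terms of primitive class names; a hedged call sketch follows. The primitive name is illustrative, not a confirmed registry entry, and `disable_all_first` is assumed to default to `False` (the hunk header truncates it):

```python
# Hypothetical usage sketch. "GemmPrimitive" is an illustrative name only.
from transformer_engine.jax.cpp_extensions.base import manage_primitives

# Disable every registered primitive, then re-enable a single one by name,
# without touching any NVTE_* environment variables.
manage_primitives(enable_names=["GemmPrimitive"], disable_all_first=True)
```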
6 changes: 3 additions & 3 deletions transformer_engine/jax/cpp_extensions/gemm.py
@@ -28,7 +28,7 @@
     ScalingMode,
     Quantizer,
     GroupedQuantizer,
-    QuantizeConfig,
+    get_quantize_config,
     QuantizerSet,
     QuantizeLayout,
     noop_quantizer_set,
@@ -754,7 +754,7 @@ def _te_gemm(
     fuse_bias: bool = False,
     fuse_gelu: bool = False,
     grad: bool = False,
-    use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP,
+    use_split_accumulator: bool = get_quantize_config().FP8_2X_ACC_FPROP,
 ) -> Tuple[jax.Array, ...]:
 
     # Prepare non-quantized GEMM operands
@@ -1107,7 +1107,7 @@ def _jax_gemm_fp8_impl(lhs, rhs):
         ), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}"
         precision = (
             jax.lax.Precision.HIGHEST
-            if QuantizeConfig.FP8_2X_ACC_FPROP
+            if get_quantize_config().FP8_2X_ACC_FPROP
             else jax.lax.Precision.DEFAULT
         )
         return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision)
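The same accessor swap drives precision selection in the pure-JAX fallback shown in `_jax_gemm_fp8_impl` above; a condensed restatement of that mapping, under the assumption that the config flag is readable outside the GEMM path:

```python
import jax
from transformer_engine.jax.quantize import get_quantize_config

# When 2x (split) accumulation is requested for the forward pass, the
# pure-JAX tensor-scaling GEMM runs its dot at highest precision.
precision = (
    jax.lax.Precision.HIGHEST
    if get_quantize_config().FP8_2X_ACC_FPROP
    else jax.lax.Precision.DEFAULT
)
```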
31 changes: 19 additions & 12 deletions transformer_engine/jax/flax/module.py
@@ -32,7 +32,14 @@
     jax_scaled_masked_softmax,
     jax_scaled_upper_triang_masked_softmax,
 )
-from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode
+from ..quantize import (
+    QuantizerFactory,
+    get_quantize_config,
+    QuantizeMeta,
+    QuantizeMetaSet,
+    ScalingMode,
+    TensorSource,
+)
 
 PRNGKey = Any
 Shape = Tuple[int, ...]
@@ -350,7 +357,7 @@ def generate_quantize_meta(quantizer_name: str):
            collection_name = (
                variable_collection
                if variable_collection is not None
-               else QuantizeConfig.COLLECTION_NAME
+               else get_quantize_config().COLLECTION_NAME
            )
            scale = self.variable(
                collection_name,
@@ -363,14 +370,14 @@ def generate_quantize_meta(quantizer_name: str):
                collection_name,
                f"{quantizer_name}{postfix}_amax_history",
                jnp.zeros,
-               (QuantizeConfig.AMAX_HISTORY_LEN,),
+               (get_quantize_config().AMAX_HISTORY_LEN,),
                jnp.float32,
            ).value
            return QuantizeMeta(scale=scale, amax_history=amax_history)
 
-        if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(
-            fp8_recipe, recipe.DelayedScaling
-        ):
+        if get_quantize_config().get_scaling_mode(
+            TensorSource.X
+        ) == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(fp8_recipe, recipe.DelayedScaling):
             x_meta = generate_quantize_meta("x")
             kernel_meta = generate_quantize_meta("kernel")
             grad_meta = generate_quantize_meta("grad")
@@ -483,7 +490,7 @@ def __call__(self, inputs: Array) -> Array:
             self.dtype,
         )
 
-        if not QuantizeConfig.is_fp8_enabled():
+        if not get_quantize_config().is_fp8_enabled():
             kernel = kernel.astype(input_dtype)
 
         if self.use_bias:
@@ -692,7 +699,7 @@ def __call__(self, inputs: Array) -> Array:
         quantizer_set = self.generate_quantizer_set()
 
         fuse_layernorm = (
-            QuantizeConfig.is_fp8_enabled()
+            get_quantize_config().is_fp8_enabled()
             and not self.return_layernorm_output
             and self.enable_layernorm
         )
@@ -743,7 +750,7 @@ def __call__(self, inputs: Array) -> Array:
             kernel_shape,
             self.dtype,
         )
-        if not QuantizeConfig.is_fp8_enabled():
+        if not get_quantize_config().is_fp8_enabled():
             kernel = kernel.astype(input_dtype)
 
         contract_ind = tuple(range(0, len(axis)))
@@ -1005,7 +1012,7 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array:
         # TODO(Phuong): use fuse_layernorm for high-precision
         # when NoOpQuantizer and Tensor are implemented
         fuse_layernorm = (
-            QuantizeConfig.is_fp8_enabled()
+            get_quantize_config().is_fp8_enabled()
             and not self.return_layernorm_output
             and self.enable_layernorm
         )
@@ -1088,7 +1095,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
             self.dtype,
         )
 
-        if not QuantizeConfig.is_fp8_enabled():
+        if not get_quantize_config().is_fp8_enabled():
             kernel_1 = kernel_1.astype(input_dtype)
 
         hidden_size = inputs.shape[-1]
@@ -1100,7 +1107,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args):
             kernel_2_shape,
             self.dtype,
         )
-        if not QuantizeConfig.is_fp8_enabled():
+        if not get_quantize_config().is_fp8_enabled():
             kernel_2 = kernel_2.astype(input_dtype)
 
         contract_ind = tuple(range(0, len(axis)))
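The module.py hunks replace the class-level `QuantizeConfig.SCALING_MODE` comparison with a per-tensor-source query. A free-function restatement of the refactored condition, for clarity; the wrapper name is hypothetical, while the body mirrors the diff:

```python
from transformer_engine.common import recipe
from transformer_engine.jax.quantize import (
    ScalingMode,
    TensorSource,
    get_quantize_config,
)

def needs_delayed_scaling_meta(fp8_recipe) -> bool:
    """Hypothetical wrapper: True when delayed-scaling quantize metadata
    (scale and amax-history variables) must be materialized."""
    return get_quantize_config().get_scaling_mode(
        TensorSource.X
    ) == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(
        fp8_recipe, recipe.DelayedScaling
    )
```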