14
14
from transformer_engine .common .recipe import Format as FP8Format
15
15
from transformer_engine .jax import fp8_autocast , get_delayed_scaling
16
16
from transformer_engine .jax .quantize import (
17
- QuantizeConfig ,
17
+ get_quantize_config ,
18
18
is_fp8_available ,
19
19
ScalingMode ,
20
20
update_collections ,
21
+ TensorSource ,
21
22
)
22
23
from transformer_engine .jax .sharding import MeshResource , global_mesh_resource
23
24
@@ -49,7 +50,7 @@ def test_update_collections(self):
49
50
class TestFP8Functions (unittest .TestCase ):
50
51
51
52
def _check_default_state (self ):
52
- self .assertFalse (QuantizeConfig .is_fp8_enabled ())
53
+ self .assertFalse (get_quantize_config () .is_fp8_enabled ())
53
54
54
55
def _compare_delay_scaling (self , ref , test ):
55
56
self .assertTrue (ref .margin == test .margin )
@@ -58,17 +59,23 @@ def _compare_delay_scaling(self, ref, test):
58
59
self .assertTrue (ref .amax_compute_algo == test .amax_compute_algo )
59
60
60
61
def _compare_current_scaling (self , test ):
61
- self .assertEqual (QuantizeConfig .FP8_FORMAT , test .fp8_format )
62
- self .assertEqual (QuantizeConfig .SCALING_MODE , ScalingMode .CURRENT_TENSOR_SCALING )
62
+ self .assertEqual (get_quantize_config ().FP8_FORMAT , test .fp8_format )
63
+ for tensor_source in TensorSource :
64
+ self .assertEqual (
65
+ get_quantize_config ().get_scaling_mode (tensor_source ),
66
+ ScalingMode .CURRENT_TENSOR_SCALING ,
67
+ )
63
68
64
69
def _compare_mxfp8_scaling (self , test ):
65
- self .assertEqual (QuantizeConfig .MARGIN , test .margin )
66
- self .assertEqual (QuantizeConfig .FP8_FORMAT , test .fp8_format )
67
- self .assertEqual (QuantizeConfig .SCALING_MODE , ScalingMode .MXFP8_1D_SCALING )
70
+ self .assertEqual (get_quantize_config ().MARGIN , test .margin )
71
+ self .assertEqual (get_quantize_config ().FP8_FORMAT , test .fp8_format )
72
+ for tensor_source in TensorSource :
73
+ self .assertEqual (
74
+ get_quantize_config ().get_scaling_mode (tensor_source ), ScalingMode .MXFP8_1D_SCALING
75
+ )
68
76
69
77
@unittest .skipIf (not is_fp8_supported , reason = reason )
70
78
def test_fp8_autocast_delayed_scaling (self ):
71
- QuantizeConfig .finalize () # Ensure the testing not affect by previous tests.
72
79
self ._check_default_state ()
73
80
74
81
with fp8_autocast (enabled = False , fp8_recipe = DelayedScaling (), mesh_resource = MeshResource ()):
@@ -78,21 +85,20 @@ def test_fp8_autocast_delayed_scaling(self):
78
85
79
86
ds = DelayedScaling (margin = 5.0 , fp8_format = FP8Format .E4M3 , amax_history_len = 1 )
80
87
with fp8_autocast (enabled = True , fp8_recipe = ds , mesh_resource = MeshResource ()):
81
- self .assertTrue (QuantizeConfig .is_fp8_enabled ())
88
+ self .assertTrue (get_quantize_config () .is_fp8_enabled ())
82
89
self ._compare_delay_scaling (get_delayed_scaling (), ds )
83
90
84
91
self ._check_default_state ()
85
92
86
93
ds = DelayedScaling (margin = 3.0 , fp8_format = FP8Format .HYBRID , amax_history_len = 1 )
87
94
with fp8_autocast (enabled = True , fp8_recipe = ds , mesh_resource = MeshResource ()):
88
- self .assertTrue (QuantizeConfig .is_fp8_enabled ())
95
+ self .assertTrue (get_quantize_config () .is_fp8_enabled ())
89
96
self ._compare_delay_scaling (get_delayed_scaling (), ds )
90
97
91
98
self ._check_default_state ()
92
99
93
100
@unittest .skipIf (not is_fp8_supported , reason = reason )
94
101
def test_fp8_autocast_current_scaling (self ):
95
- QuantizeConfig .finalize () # Ensure the testing not affect by previous tests.
96
102
self ._check_default_state ()
97
103
98
104
with fp8_autocast (
@@ -104,21 +110,20 @@ def test_fp8_autocast_current_scaling(self):
104
110
105
111
cs = Float8CurrentScaling (fp8_format = FP8Format .E4M3 )
106
112
with fp8_autocast (enabled = True , fp8_recipe = cs , mesh_resource = MeshResource ()):
107
- self .assertTrue (QuantizeConfig .is_fp8_enabled ())
113
+ self .assertTrue (get_quantize_config () .is_fp8_enabled ())
108
114
self ._compare_current_scaling (cs )
109
115
110
116
self ._check_default_state ()
111
117
112
118
cs = Float8CurrentScaling (fp8_format = FP8Format .HYBRID )
113
119
with fp8_autocast (enabled = True , fp8_recipe = cs , mesh_resource = MeshResource ()):
114
- self .assertTrue (QuantizeConfig .is_fp8_enabled ())
120
+ self .assertTrue (get_quantize_config () .is_fp8_enabled ())
115
121
self ._compare_current_scaling (cs )
116
122
117
123
self ._check_default_state ()
118
124
119
125
@unittest .skipIf (not is_mxfp8_supported , reason = mxfp8_reason )
120
126
def test_fp8_autocast_mxfp8_block_scaling (self ):
121
- QuantizeConfig .finalize () # Ensure the testing not affect by previous tests.
122
127
self ._check_default_state ()
123
128
124
129
with fp8_autocast (
@@ -130,14 +135,14 @@ def test_fp8_autocast_mxfp8_block_scaling(self):
130
135
131
136
bs = MXFP8BlockScaling (margin = 5.0 , fp8_format = FP8Format .E4M3 )
132
137
with fp8_autocast (enabled = True , fp8_recipe = bs , mesh_resource = MeshResource ()):
133
- self .assertTrue (QuantizeConfig .is_fp8_enabled ())
138
+ self .assertTrue (get_quantize_config () .is_fp8_enabled ())
134
139
self ._compare_mxfp8_scaling (bs )
135
140
136
141
self ._check_default_state ()
137
142
138
143
bs = MXFP8BlockScaling (margin = 3.0 , fp8_format = FP8Format .HYBRID )
139
144
with fp8_autocast (enabled = True , fp8_recipe = bs , mesh_resource = MeshResource ()):
140
- self .assertTrue (QuantizeConfig .is_fp8_enabled ())
145
+ self .assertTrue (get_quantize_config () .is_fp8_enabled ())
141
146
self ._compare_mxfp8_scaling (bs )
142
147
143
148
self ._check_default_state ()
0 commit comments