Skip to content

Commit ed3e8e9

Browse files
nik-mosaic authored and committed
Update LPLN arguments to match LayerNorm
1 parent fbbd6c8 commit ed3e8e9

File tree

2 files changed: 38 additions and 30 deletions.

composer/algorithms/low_precision_layernorm/low_precision_layernorm.py

Lines changed: 28 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -42,7 +42,7 @@ def apply_low_precision_layernorm(model,
4242
if version.parse(torch.__version__) < version.parse('1.13') and precision == Precision.AMP_BF16:
4343
check_if_apex_installed()
4444
policy: Dict[Type[torch.nn.Module], module_surgery.ReplacementFunction] = {
45-
torch.nn.LayerNorm: to_FusedLayerNorm
45+
torch.nn.LayerNorm: _to_FusedLayerNorm
4646
}
4747

4848
replaced_instances = module_surgery.replace_module_classes(module=model, optimizers=optimizers, policies=policy)
@@ -88,14 +88,12 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:
8888

8989
class LPLayerNorm(torch.nn.LayerNorm):
9090

91-
def __init__(self, layer):
92-
super().__init__(normalized_shape=layer.normalized_shape,
93-
eps=layer.eps,
94-
elementwise_affine=layer.elementwise_affine)
95-
96-
with torch.no_grad():
97-
self.weight.copy_(layer.weight)
98-
self.bias.copy_(layer.bias)
91+
def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
92+
super().__init__(normalized_shape=normalized_shape,
93+
eps=eps,
94+
elementwise_affine=elementwise_affine,
95+
device=device,
96+
dtype=dtype)
9997

10098
def forward(self, x):
10199
module_device = x.device
@@ -106,27 +104,38 @@ def forward(self, x):
106104
return F.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)
107105

108106

109-
def _cast_if_autocast_enabled(hidden_states):
110-
if not torch.is_autocast_enabled():
111-
return hidden_states
112-
else:
113-
return torch.cuda.amp.autocast_mode._cast(hidden_states, torch.get_autocast_gpu_dtype())
107+
def _cast_if_autocast_enabled(tensor):
108+
if torch.is_autocast_enabled():
109+
if tensor.device.type == 'cuda':
110+
dtype = torch.get_autocast_gpu_dtype()
111+
elif tensor.device.type == 'cpu':
112+
dtype = torch.get_autocast_cpu_dtype()
113+
else:
114+
raise NotImplementedError()
115+
return tensor.to(dtype=dtype)
116+
return tensor
114117

115118

116119
def check_if_apex_installed():
117120
if not APEX_INSTALLED:
118121
raise ImportError(
119-
'https://github.com/NVIDIA/apex is not installed. The Low Precision LayerNorm algorithm cannot be applied. The MosaicML Docker Images (https://hub.docker.com/r/mosaicml/pytorch) contain a copy of APEX for easy use.'
122+
'https://github.com/NVIDIA/apex is not installed. The Low Precision LayerNorm algorithm cannot be applied on PyTorch <1.13 without Apex. The MosaicML Docker Images (https://hub.docker.com/r/mosaicml/pytorch) contain a copy of APEX for easy use.'
120123
)
121124

122125

123126
def _to_LPLayerNorm(layer: torch.nn.Module, module_index: int) -> LPLayerNorm:
124-
if not isinstance(layer, torch.nn.LayerNorm):
125-
raise TypeError(f'Expected torch.nn.LayerNorm, got {type(layer)}')
126-
return LPLayerNorm(layer)
127+
"""Defines a replacement policy from a `torch.nn.LayerNorm` to a `LPLayerNorm`"""
128+
assert isinstance(layer,
129+
torch.nn.LayerNorm), 'The replacement policy will look for all instances of torch.nn.LayerNorm'
130+
lp_layernorm = LPLayerNorm(layer.normalized_shape, layer.eps, layer.elementwise_affine, layer.weight.device,
131+
layer.weight.dtype)
132+
with torch.no_grad():
133+
lp_layernorm.weight.copy_(layer.weight)
134+
lp_layernorm.bias.copy_(layer.bias)
135+
return lp_layernorm
127136

128137

129-
def to_FusedLayerNorm(layer: torch.nn.Module, module_index: int) -> APEXFusedLayerNorm:
138+
def _to_FusedLayerNorm(layer: torch.nn.Module, module_index: int) -> APEXFusedLayerNorm:
130139
"""Defines a replacement policy from a `torch.nn.LayerNorm` to a `apex.normalization.fused_layer_norm`"""
131140
if not isinstance(layer, torch.nn.LayerNorm):
132141
raise TypeError(f'Expected torch.nn.LayerNorm, got {type(layer)}')

tests/utils/test_inference.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import torch.nn as nn
1515
from torch.utils.data import DataLoader
1616

17-
from composer.core import State
17+
from composer.core import Precision, State
1818
from composer.devices import DeviceCPU, DeviceGPU
1919
from composer.functional import apply_gated_linear_units
2020
from composer.loggers import InMemoryLogger, Logger
@@ -24,8 +24,8 @@
2424
from composer.trainer.trainer import Trainer
2525
from composer.utils import dist, export_with_logger, inference
2626
from composer.utils.device import get_device
27-
from tests.common import device
28-
from tests.common.datasets import RandomImageDataset
27+
from tests.common import SimpleTransformerClassifier, device
28+
from tests.common.datasets import RandomImageDataset, dummy_transformer_classifier_batch
2929

3030

3131
class MockFileUploader(LoggerDestination):
@@ -35,14 +35,13 @@ def can_upload_files(self) -> bool:
3535
return True
3636

3737

38-
@pytest.mark.parametrize(
39-
'model_cls, sample_input',
40-
[
41-
(partial(composer_resnet, 'resnet18'), (torch.rand(4, 3, 224, 224), torch.randint(10, (4,)))),
42-
],
43-
)
38+
@pytest.mark.parametrize('model_cls, sample_input', [
39+
(partial(composer_resnet, 'resnet18'), (torch.rand(4, 3, 224, 224), torch.randint(10, (4,)))),
40+
(SimpleTransformerClassifier, dummy_transformer_classifier_batch(vocab_size=10)),
41+
])
4442
def test_export_for_inference_torchscript(model_cls, sample_input):
4543
model = model_cls()
44+
4645
model.eval()
4746

4847
orig_out = model(sample_input)
@@ -163,7 +162,7 @@ def test_gpu_huggingface_export_for_inference_onnx():
163162
import onnxruntime as ort
164163
import transformers
165164

166-
from composer.functional import apply_fused_layernorm
165+
from composer.functional import apply_low_precision_layernorm
167166
from composer.models import HuggingFaceModel
168167

169168
# HuggingFace Bert Model
@@ -203,7 +202,7 @@ def test_gpu_huggingface_export_for_inference_onnx():
203202

204203
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
205204
apply_gated_linear_units(model, optimizer)
206-
apply_fused_layernorm(model, optimizer)
205+
apply_low_precision_layernorm(model, optimizer, Precision('amp_fp16'))
207206

208207
model.eval()
209208
orig_out = model(sample_input)

0 commit comments

Comments
 (0)