Merged
26 commits
ca9fde7
Add weight copying for fused layernorm
nik-mosaic Nov 16, 2022
fce8533
Change from weight copy to zero/one init
nik-mosaic Nov 16, 2022
d5c0ba8
Merge branch 'mosaicml:dev' into dev
nik-mosaic Nov 29, 2022
9e04286
Merge branch 'mosaicml:dev' into dev
nik-mosaic Dec 5, 2022
5517a09
Merge branch 'mosaicml:dev' into dev
nik-mosaic Dec 6, 2022
1ceead1
Revert FLN Update
nik-mosaic Dec 6, 2022
0b377ae
Merge branch 'dev' of https://github.com/nik-mosaic/composer into dev
nik-mosaic Dec 6, 2022
f14eefc
Add space
nik-mosaic Dec 6, 2022
b71375a
Merge branch 'mosaicml:dev' into dev
nik-mosaic Dec 13, 2022
a28c3eb
Merge branch 'dev' of github.com:nik-mosaic/composer into dev
Jan 30, 2023
9657dfd
Fix model surgery failure due to functional API change
nik-mosaic Feb 7, 2023
d016390
Merge branch 'dev' of github.com:nik-mosaic/composer into dev
Feb 7, 2023
c1ac784
Update inference algorithm comment
nik-mosaic Feb 7, 2023
d1f0b90
Merge branch 'dev' into dev
nik-mosaic Feb 8, 2023
43e1462
Merge branch 'dev' into dev
bandish-shah Feb 9, 2023
923c1f6
Merge branch 'dev' into dev
bandish-shah Feb 9, 2023
b3baf8b
Merge branch 'mosaicml:dev' into dev
nik-mosaic Feb 17, 2023
a700505
Merge branch 'mosaicml:dev' into dev
nik-mosaic Feb 21, 2023
d2287f2
Merge branch 'dev' of github.com:nik-mosaic/composer into dev
Feb 22, 2023
35d8bca
Merge branch 'mosaicml:dev' into dev
nik-mosaic Feb 22, 2023
fbbd6c8
Merge branch 'mosaicml:dev' into dev
nik-mosaic Feb 28, 2023
ed3e8e9
Update LPLN arguments to match LayerNorm
Feb 23, 2023
5a96345
Update composer/algorithms/low_precision_layernorm/low_precision_laye…
nik-mosaic Feb 28, 2023
b9ae216
add CPU amp option to LPGN
nik-mosaic Feb 28, 2023
ed4c829
Change arg order to match new interface
nik-mosaic Feb 28, 2023
4b6c406
Run precommit
nik-mosaic Feb 28, 2023
@@ -94,11 +94,16 @@ def forward(self, x):
        return F.group_norm(downcast_x, self.num_groups, downcast_weight, downcast_bias, self.eps)


-def _cast_if_autocast_enabled(hidden_states):
-    if not torch.is_autocast_enabled():
-        return hidden_states
-    else:
-        return torch.cuda.amp.autocast_mode._cast(hidden_states, torch.get_autocast_gpu_dtype())
+def _cast_if_autocast_enabled(tensor):
+    if torch.is_autocast_enabled():
+        if tensor.device.type == 'cuda':
+            dtype = torch.get_autocast_gpu_dtype()
+        elif tensor.device.type == 'cpu':
+            dtype = torch.get_autocast_cpu_dtype()
+        else:
+            raise NotImplementedError()
+        return tensor.to(dtype=dtype)
+    return tensor


def _to_LPGroupNorm(layer: torch.nn.Module, module_index: int) -> LPGroupNorm:
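The new helper avoids the private torch.cuda.amp.autocast_mode._cast API and picks the autocast dtype that matches the tensor's device. A minimal sketch of the intended behavior, assuming a CUDA device is available and _cast_if_autocast_enabled is imported from the module above:

import torch

x = torch.randn(2, 4, device='cuda')  # fp32 activations
with torch.autocast(device_type='cuda', dtype=torch.float16):
    # inside autocast, the tensor is downcast to the active autocast dtype
    assert _cast_if_autocast_enabled(x).dtype == torch.float16
# outside autocast, the tensor is returned unchanged
assert _cast_if_autocast_enabled(x).dtype == torch.float32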
@@ -42,7 +42,7 @@ def apply_low_precision_layernorm(model,
    if version.parse(torch.__version__) < version.parse('1.13') and precision == Precision.AMP_BF16:
        check_if_apex_installed()
        policy: Dict[Type[torch.nn.Module], module_surgery.ReplacementFunction] = {
-            torch.nn.LayerNorm: to_FusedLayerNorm
+            torch.nn.LayerNorm: _to_FusedLayerNorm
        }

    replaced_instances = module_surgery.replace_module_classes(module=model, optimizers=optimizers, policies=policy)
@@ -88,14 +88,12 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:

class LPLayerNorm(torch.nn.LayerNorm):

-    def __init__(self, layer):
-        super().__init__(normalized_shape=layer.normalized_shape,
-                         eps=layer.eps,
-                         elementwise_affine=layer.elementwise_affine)
-
-        with torch.no_grad():
-            self.weight.copy_(layer.weight)
-            self.bias.copy_(layer.bias)
+    def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
+        super().__init__(normalized_shape=normalized_shape,
+                         eps=eps,
+                         elementwise_affine=elementwise_affine,
+                         device=device,
+                         dtype=dtype)

    def forward(self, x):
        module_device = x.device
@@ -106,27 +104,38 @@ def forward(self, x):
        return F.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)


-def _cast_if_autocast_enabled(hidden_states):
-    if not torch.is_autocast_enabled():
-        return hidden_states
-    else:
-        return torch.cuda.amp.autocast_mode._cast(hidden_states, torch.get_autocast_gpu_dtype())
+def _cast_if_autocast_enabled(tensor):
+    if torch.is_autocast_enabled():
+        if tensor.device.type == 'cuda':
+            dtype = torch.get_autocast_gpu_dtype()
+        elif tensor.device.type == 'cpu':
+            dtype = torch.get_autocast_cpu_dtype()
+        else:
+            raise NotImplementedError()
+        return tensor.to(dtype=dtype)
+    return tensor


def check_if_apex_installed():
    if not APEX_INSTALLED:
        raise ImportError(
-            'https://github.com/NVIDIA/apex is not installed. The Low Precision LayerNorm algorithm cannot be applied. The MosaicML Docker Images (https://hub.docker.com/r/mosaicml/pytorch) contain a copy of APEX for easy use.'
+            'https://github.com/NVIDIA/apex is not installed. The Low Precision LayerNorm algorithm cannot be applied on PyTorch <1.13 without Apex. The MosaicML Docker Images (https://hub.docker.com/r/mosaicml/pytorch) contain a copy of APEX for easy use.'
        )


def _to_LPLayerNorm(layer: torch.nn.Module, module_index: int) -> LPLayerNorm:
    """Defines a replacement policy from a `torch.nn.LayerNorm` to a `LPLayerNorm`"""
    if not isinstance(layer, torch.nn.LayerNorm):
        raise TypeError(f'Expected torch.nn.LayerNorm, got {type(layer)}')
-    return LPLayerNorm(layer)
+    lp_layernorm = LPLayerNorm(layer.normalized_shape, layer.eps, layer.elementwise_affine, layer.weight.device,
+                               layer.weight.dtype)
+    with torch.no_grad():
+        lp_layernorm.weight.copy_(layer.weight)
+        lp_layernorm.bias.copy_(layer.bias)
+    return lp_layernorm


-def to_FusedLayerNorm(layer: torch.nn.Module, module_index: int) -> APEXFusedLayerNorm:
+def _to_FusedLayerNorm(layer: torch.nn.Module, module_index: int) -> APEXFusedLayerNorm:
    """Defines a replacement policy from a `torch.nn.LayerNorm` to a `apex.normalization.fused_layer_norm`"""
    if not isinstance(layer, torch.nn.LayerNorm):
        raise TypeError(f'Expected torch.nn.LayerNorm, got {type(layer)}')
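With the new constructor signature, LPLayerNorm is built like a plain torch.nn.LayerNorm and the parameter copy happens inside the replacement function during module surgery. A minimal sketch of the expected round trip, assuming these names are importable from composer.algorithms.low_precision_layernorm.low_precision_layernorm:

import torch
from composer.algorithms.low_precision_layernorm.low_precision_layernorm import _to_LPLayerNorm

ln = torch.nn.LayerNorm(8)
torch.nn.init.normal_(ln.weight)  # make the affine parameters non-trivial
lp_ln = _to_LPLayerNorm(ln, module_index=0)

x = torch.randn(2, 8)
assert torch.equal(lp_ln.weight, ln.weight)   # parameters are copied, not re-initialized
assert torch.allclose(lp_ln(x), ln(x))        # outside autocast, forward matches nn.LayerNorm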
21 changes: 10 additions & 11 deletions tests/utils/test_inference.py
@@ -14,7 +14,7 @@
import torch.nn as nn
from torch.utils.data import DataLoader

-from composer.core import State
+from composer.core import Precision, State
from composer.devices import DeviceCPU, DeviceGPU
from composer.functional import apply_gated_linear_units
from composer.loggers import InMemoryLogger, Logger
@@ -24,8 +24,8 @@
from composer.trainer.trainer import Trainer
from composer.utils import dist, export_with_logger, inference
from composer.utils.device import get_device
-from tests.common import device
-from tests.common.datasets import RandomImageDataset
+from tests.common import SimpleTransformerClassifier, device
+from tests.common.datasets import RandomImageDataset, dummy_transformer_classifier_batch


class MockFileUploader(LoggerDestination):
@@ -35,14 +35,13 @@ def can_upload_files(self) -> bool:
return True


-@pytest.mark.parametrize(
-    'model_cls, sample_input',
-    [
-        (partial(composer_resnet, 'resnet18'), (torch.rand(4, 3, 224, 224), torch.randint(10, (4,)))),
-    ],
-)
+@pytest.mark.parametrize('model_cls, sample_input', [
+    (partial(composer_resnet, 'resnet18'), (torch.rand(4, 3, 224, 224), torch.randint(10, (4,)))),
+    (SimpleTransformerClassifier, dummy_transformer_classifier_batch(vocab_size=10)),
+])
def test_export_for_inference_torchscript(model_cls, sample_input):
model = model_cls()

model.eval()

orig_out = model(sample_input)
@@ -163,7 +162,7 @@ def test_gpu_huggingface_export_for_inference_onnx():
import onnxruntime as ort
import transformers

-    from composer.functional import apply_fused_layernorm
+    from composer.functional import apply_low_precision_layernorm
from composer.models import HuggingFaceModel

# HuggingFace Bert Model
@@ -203,7 +202,7 @@

optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
apply_gated_linear_units(model, optimizer)
-    apply_fused_layernorm(model, optimizer)
+    apply_low_precision_layernorm(model, Precision('amp_fp16'), optimizer)

model.eval()
orig_out = model(sample_input)
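For reference, a minimal sketch of the updated functional call exercised by this test, assuming any model that contains torch.nn.LayerNorm modules; the algorithm now takes the target precision explicitly:

import torch
from composer.core import Precision
from composer.functional import apply_low_precision_layernorm

model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.LayerNorm(16))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
# replaces every torch.nn.LayerNorm in the model with LPLayerNorm (or Apex
# FusedLayerNorm on PyTorch <1.13 with bf16), updating the optimizer's params
apply_low_precision_layernorm(model, Precision('amp_fp16'), optimizer)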