From ca9fde75608606e35bc37473eebd78c167c90c62 Mon Sep 17 00:00:00 2001
From: Nikhil <101217697+nik-mosaic@users.noreply.github.com>
Date: Wed, 16 Nov 2022 07:35:14 -0800
Subject: [PATCH 01/11] Add weight copying for fused layernorm

---
 composer/algorithms/fused_layernorm/fused_layernorm.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/composer/algorithms/fused_layernorm/fused_layernorm.py b/composer/algorithms/fused_layernorm/fused_layernorm.py
index 2900342155..b5df679f7d 100644
--- a/composer/algorithms/fused_layernorm/fused_layernorm.py
+++ b/composer/algorithms/fused_layernorm/fused_layernorm.py
@@ -36,7 +36,11 @@ def from_LayerNorm(layer: torch.nn.Module, module_index: int) -> APEXFusedLayerN
     """Defines a replacement policy from a `torch.nn.LayerNorm` to a `apex.normalization.fused_layer_norm`"""
     assert isinstance(layer,
                       torch.nn.LayerNorm), 'The replacement policy will look for all instances of torch.nn.LayerNorm'
-    return APEXFusedLayerNorm(normalized_shape=layer.normalized_shape, eps=layer.eps)
+    fused_layernorm = APEXFusedLayerNorm(normalized_shape=layer.normalized_shape, eps=layer.eps)
+    with torch.no_grad():
+        fused_layernorm.weight.copy_(layer.weight)
+        fused_layernorm.bias.copy_(layer.bias)
+    return fused_layernorm


 def apply_fused_layernorm(model: torch.nn.Module, optimizers: Union[torch.optim.Optimizer,

From fce853333254e42b5b18c26cbaea71b3895e7e21 Mon Sep 17 00:00:00 2001
From: Nikhil <101217697+nik-mosaic@users.noreply.github.com>
Date: Wed, 16 Nov 2022 07:45:14 -0800
Subject: [PATCH 02/11] Change from weight copy to zero/one init

---
 composer/algorithms/fused_layernorm/fused_layernorm.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/composer/algorithms/fused_layernorm/fused_layernorm.py b/composer/algorithms/fused_layernorm/fused_layernorm.py
index b5df679f7d..d5e35cd407 100644
--- a/composer/algorithms/fused_layernorm/fused_layernorm.py
+++ b/composer/algorithms/fused_layernorm/fused_layernorm.py
@@ -37,9 +37,8 @@ def from_LayerNorm(layer: torch.nn.Module, module_index: int) -> APEXFusedLayerN
     assert isinstance(layer,
                       torch.nn.LayerNorm), 'The replacement policy will look for all instances of torch.nn.LayerNorm'
     fused_layernorm = APEXFusedLayerNorm(normalized_shape=layer.normalized_shape, eps=layer.eps)
-    with torch.no_grad():
-        fused_layernorm.weight.copy_(layer.weight)
-        fused_layernorm.bias.copy_(layer.bias)
+    torch.nn.init.ones_(fused_layernorm.weight)
+    torch.nn.init.zeros_(fused_layernorm.bias)
     return fused_layernorm
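Note: a minimal sketch, not part of the patch series, of the parameter-preserving swap that PATCH 01/11 performs. Copying the learned affine parameters keeps the replacement numerically identical to the layer it replaces on an already-trained model; the ones/zeros re-initialization in PATCH 02/11 only matches a freshly constructed LayerNorm. Here `swap_layernorm` and `new_cls` are illustrative names; `new_cls` stands in for any LayerNorm-compatible class such as apex's FusedLayerNorm.

import torch

def swap_layernorm(layer: torch.nn.LayerNorm, new_cls) -> torch.nn.Module:
    """Build a drop-in replacement for `layer`, preserving its affine parameters."""
    new_layer = new_cls(normalized_shape=layer.normalized_shape, eps=layer.eps)
    if layer.elementwise_affine:
        with torch.no_grad():
            new_layer.weight.copy_(layer.weight)
            new_layer.bias.copy_(layer.bias)
    return new_layer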
From 1ceead1c77f621c0ec1565cbe35f949a8a9c8e8a Mon Sep 17 00:00:00 2001
From: Nikhil <101217697+nik-mosaic@users.noreply.github.com>
Date: Tue, 6 Dec 2022 10:04:01 -0800
Subject: [PATCH 03/11] Revert FLN Update

---
 composer/algorithms/fused_layernorm/fused_layernorm.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/composer/algorithms/fused_layernorm/fused_layernorm.py b/composer/algorithms/fused_layernorm/fused_layernorm.py
index d5e35cd407..b9f9d85794 100644
--- a/composer/algorithms/fused_layernorm/fused_layernorm.py
+++ b/composer/algorithms/fused_layernorm/fused_layernorm.py
@@ -36,11 +36,7 @@ def from_LayerNorm(layer: torch.nn.Module, module_index: int) -> APEXFusedLayerN
     """Defines a replacement policy from a `torch.nn.LayerNorm` to a `apex.normalization.fused_layer_norm`"""
     assert isinstance(layer,
                       torch.nn.LayerNorm), 'The replacement policy will look for all instances of torch.nn.LayerNorm'
-    fused_layernorm = APEXFusedLayerNorm(normalized_shape=layer.normalized_shape, eps=layer.eps)
-    torch.nn.init.ones_(fused_layernorm.weight)
-    torch.nn.init.zeros_(fused_layernorm.bias)
-    return fused_layernorm
-
+    return APEXFusedLayerNorm(normalized_shape=layer.normalized_shape, eps=layer.eps)

 def apply_fused_layernorm(model: torch.nn.Module, optimizers: Union[torch.optim.Optimizer,

From f14eefcec3504b38190febb1094091df92eb525f Mon Sep 17 00:00:00 2001
From: Nikhil <101217697+nik-mosaic@users.noreply.github.com>
Date: Tue, 6 Dec 2022 10:05:11 -0800
Subject: [PATCH 04/11] Add space

---
 composer/algorithms/fused_layernorm/fused_layernorm.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/composer/algorithms/fused_layernorm/fused_layernorm.py b/composer/algorithms/fused_layernorm/fused_layernorm.py
index b9f9d85794..2900342155 100644
--- a/composer/algorithms/fused_layernorm/fused_layernorm.py
+++ b/composer/algorithms/fused_layernorm/fused_layernorm.py
@@ -38,6 +38,7 @@ def from_LayerNorm(layer: torch.nn.Module, module_index: int) -> APEXFusedLayerN
                       torch.nn.LayerNorm), 'The replacement policy will look for all instances of torch.nn.LayerNorm'
     return APEXFusedLayerNorm(normalized_shape=layer.normalized_shape, eps=layer.eps)

+
 def apply_fused_layernorm(model: torch.nn.Module, optimizers: Union[torch.optim.Optimizer,
                                                                     Sequence[torch.optim.Optimizer]]) -> None:
     """Replaces all instances of `torch.nn.LayerNorm` with a `apex.normalization.fused_layer_norm.FusedLayerNorm

From 9657dfd9b8e4bfee09bf46af58df7e66833356ed Mon Sep 17 00:00:00 2001
From: Nikhil <101217697+nik-mosaic@users.noreply.github.com>
Date: Mon, 6 Feb 2023 18:48:26 -0800
Subject: [PATCH 05/11] Fix model surgery failure due to functional API change

---
 composer/utils/inference.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/composer/utils/inference.py b/composer/utils/inference.py
index 4815f7b965..a31d15ebb0 100644
--- a/composer/utils/inference.py
+++ b/composer/utils/inference.py
@@ -162,7 +162,7 @@ def export_for_inference(

     # Apply surgery algorithms in the given order
     for alg in ensure_tuple(surgery_algs):
-        model = alg(model)
+        alg(model)

     if load_path is not None:
         # download checkpoint and load weights only

From c1ac784ef86141ebcca6d703c7a55498269042b1 Mon Sep 17 00:00:00 2001
From: Nikhil <101217697+nik-mosaic@users.noreply.github.com>
Date: Mon, 6 Feb 2023 19:23:15 -0800
Subject: [PATCH 06/11] Update inference algorithm comment

---
 composer/utils/inference.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/composer/utils/inference.py b/composer/utils/inference.py
index a31d15ebb0..b885348fd2 100644
--- a/composer/utils/inference.py
+++ b/composer/utils/inference.py
@@ -115,7 +115,7 @@ def export_for_inference(
         dynamic_axes (Any, optional): Dictionary specifying the axes of input/output tensors as dynamic. May be
             required for exporting models using older versions of PyTorch when types cannot be inferred.
         surgery_algs (Union[Callable, Sequence[Callable]], optional): Algorithms that should be applied to the model
-            before loading a checkpoint. Each should be callable that takes a model and returns modified model.
+            before loading a checkpoint. Each should be callable that takes a model and returns None.
             ``surgery_algs`` are applied before ``transforms``. (default: ``None``)
         transforms (Sequence[Transform], optional): transformations (usually optimizations) that should be applied to
             the model. Each Transform should be a callable that takes a model and returns a modified model.
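Note: after PATCH 05/11 and 06/11, `export_for_inference` calls each surgery algorithm purely for its side effects (`alg(model)`) instead of rebinding the model, so surgery algorithms are expected to mutate the model in place and return None. A hypothetical algorithm matching that contract, shown only for illustration (the Dropout-to-Identity swap is made up, not something the patches do):

import torch

def my_surgery_alg(model: torch.nn.Module) -> None:
    """Hypothetical surgery algorithm: mutate `model` in place, return nothing."""
    for name, child in model.named_children():
        if isinstance(child, torch.nn.Dropout):
            setattr(model, name, torch.nn.Identity())  # illustrative replacement only
        else:
            my_surgery_alg(child)  # recurse into submodules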
From ed3e8e9dbe9aab89b2c1ceeadfcd13922be07c8e Mon Sep 17 00:00:00 2001
From: nik-mosaic
Date: Thu, 23 Feb 2023 00:49:56 +0000
Subject: [PATCH 07/11] Update LPLN arguments to match LayerNorm

---
 .../low_precision_layernorm.py | 47 +++++++++++--------
 tests/utils/test_inference.py  | 21 ++++-----
 2 files changed, 38 insertions(+), 30 deletions(-)

diff --git a/composer/algorithms/low_precision_layernorm/low_precision_layernorm.py b/composer/algorithms/low_precision_layernorm/low_precision_layernorm.py
index b0c72f1e92..4005bd5371 100644
--- a/composer/algorithms/low_precision_layernorm/low_precision_layernorm.py
+++ b/composer/algorithms/low_precision_layernorm/low_precision_layernorm.py
@@ -42,7 +42,7 @@ def apply_low_precision_layernorm(model,
     if version.parse(torch.__version__) < version.parse('1.13') and precision == Precision.AMP_BF16:
         check_if_apex_installed()
         policy: Dict[Type[torch.nn.Module], module_surgery.ReplacementFunction] = {
-            torch.nn.LayerNorm: to_FusedLayerNorm
+            torch.nn.LayerNorm: _to_FusedLayerNorm
         }
         replaced_instances = module_surgery.replace_module_classes(module=model, optimizers=optimizers, policies=policy)
@@ -88,14 +88,12 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:

 class LPLayerNorm(torch.nn.LayerNorm):

-    def __init__(self, layer):
-        super().__init__(normalized_shape=layer.normalized_shape,
-                         eps=layer.eps,
-                         elementwise_affine=layer.elementwise_affine)
-
-        with torch.no_grad():
-            self.weight.copy_(layer.weight)
-            self.bias.copy_(layer.bias)
+    def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
+        super().__init__(normalized_shape=normalized_shape,
+                         eps=eps,
+                         elementwise_affine=elementwise_affine,
+                         device=device,
+                         dtype=dtype)

     def forward(self, x):
         module_device = x.device
@@ -106,27 +104,38 @@ def forward(self, x):
         return F.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)


-def _cast_if_autocast_enabled(hidden_states):
-    if not torch.is_autocast_enabled():
-        return hidden_states
-    else:
-        return torch.cuda.amp.autocast_mode._cast(hidden_states, torch.get_autocast_gpu_dtype())
+def _cast_if_autocast_enabled(tensor):
+    if torch.is_autocast_enabled():
+        if tensor.device.type == 'cuda':
+            dtype = torch.get_autocast_gpu_dtype()
+        elif tensor.device.type == 'cpu':
+            dtype = torch.get_autocast_cpu_dtype()
+        else:
+            raise NotImplementedError()
+        return tensor.to(dtype=dtype)
+    return tensor


 def check_if_apex_installed():
     if not APEX_INSTALLED:
         raise ImportError(
-            'https://github.com/NVIDIA/apex is not installed. The Low Precision LayerNorm algorithm cannot be applied. The MosaicML Docker Images (https://hub.docker.com/r/mosaicml/pytorch) contain a copy of APEX for easy use.'
+            'https://github.com/NVIDIA/apex is not installed. The Low Precision LayerNorm algorithm cannot be applied on PyTorch <1.13 without Apex. The MosaicML Docker Images (https://hub.docker.com/r/mosaicml/pytorch) contain a copy of APEX for easy use.'
        )


 def _to_LPLayerNorm(layer: torch.nn.Module, module_index: int) -> LPLayerNorm:
-    if not isinstance(layer, torch.nn.LayerNorm):
-        raise TypeError(f'Expected torch.nn.LayerNorm, got {type(layer)}')
-    return LPLayerNorm(layer)
+    """Defines a replacement policy from a `torch.nn.LayerNorm` to a `LPLayerNorm`"""
+    assert isinstance(layer,
+                      torch.nn.LayerNorm), 'The replacement policy will look for all instances of torch.nn.LayerNorm'
+    lp_layernorm = LPLayerNorm(layer.normalized_shape, layer.eps, layer.elementwise_affine, layer.weight.device,
+                               layer.weight.dtype)
+    with torch.no_grad():
+        lp_layernorm.weight.copy_(layer.weight)
+        lp_layernorm.bias.copy_(layer.bias)
+    return lp_layernorm


-def to_FusedLayerNorm(layer: torch.nn.Module, module_index: int) -> APEXFusedLayerNorm:
+def _to_FusedLayerNorm(layer: torch.nn.Module, module_index: int) -> APEXFusedLayerNorm:
     """Defines a replacement policy from a `torch.nn.LayerNorm` to a `apex.normalization.fused_layer_norm`"""
     if not isinstance(layer, torch.nn.LayerNorm):
         raise TypeError(f'Expected torch.nn.LayerNorm, got {type(layer)}')
diff --git a/tests/utils/test_inference.py b/tests/utils/test_inference.py
index 7ea6c08155..8f57615356 100644
--- a/tests/utils/test_inference.py
+++ b/tests/utils/test_inference.py
@@ -14,7 +14,7 @@
 import torch.nn as nn
 from torch.utils.data import DataLoader

-from composer.core import State
+from composer.core import Precision, State
 from composer.devices import DeviceCPU, DeviceGPU
 from composer.functional import apply_gated_linear_units
 from composer.loggers import InMemoryLogger, Logger
@@ -24,8 +24,8 @@
 from composer.trainer.trainer import Trainer
 from composer.utils import dist, export_with_logger, inference
 from composer.utils.device import get_device
-from tests.common import device
-from tests.common.datasets import RandomImageDataset
+from tests.common import SimpleTransformerClassifier, device
+from tests.common.datasets import RandomImageDataset, dummy_transformer_classifier_batch


 class MockFileUploader(LoggerDestination):
@@ -35,14 +35,13 @@ def can_upload_files(self) -> bool:
         return True


-@pytest.mark.parametrize(
-    'model_cls, sample_input',
-    [
-        (partial(composer_resnet, 'resnet18'), (torch.rand(4, 3, 224, 224), torch.randint(10, (4,)))),
-    ],
-)
+@pytest.mark.parametrize('model_cls, sample_input', [
+    (partial(composer_resnet, 'resnet18'), (torch.rand(4, 3, 224, 224), torch.randint(10, (4,)))),
+    (SimpleTransformerClassifier, dummy_transformer_classifier_batch(vocab_size=10)),
+])
 def test_export_for_inference_torchscript(model_cls, sample_input):
     model = model_cls()
+    model.eval()

     orig_out = model(sample_input)
@@ -163,7 +162,7 @@ def test_gpu_huggingface_export_for_inference_onnx():
     import onnxruntime as ort
     import transformers

-    from composer.functional import apply_fused_layernorm
+    from composer.functional import apply_low_precision_layernorm
     from composer.models import HuggingFaceModel

     # HuggingFace Bert Model
@@ -203,7 +202,7 @@ def test_gpu_huggingface_export_for_inference_onnx():
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

     apply_gated_linear_units(model, optimizer)
-    apply_fused_layernorm(model, optimizer)
+    apply_low_precision_layernorm(model, optimizer, Precision('amp_fp16'))

     model.eval()
     orig_out = model(sample_input)
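Note: the renamed `_to_LPLayerNorm`/`_to_FusedLayerNorm` policies are consumed by `module_surgery.replace_module_classes`, which walks the model and calls the policy once per matching module; that is why the policy itself must rebuild the replacement from the original layer's parameters, device, and dtype. A standalone sketch, not part of the patch, of driving the same mechanism with a made-up policy (`to_identity` is illustrative only; the keyword names are taken from the call shown in this patch):

import torch
from composer.utils import module_surgery

def to_identity(layer: torch.nn.Module, module_index: int) -> torch.nn.Identity:
    # Toy policy: drop the LayerNorm entirely (for illustration, not for real use).
    return torch.nn.Identity()

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.LayerNorm(4))
module_surgery.replace_module_classes(module=model, policies={torch.nn.LayerNorm: to_identity})
assert not any(isinstance(m, torch.nn.LayerNorm) for m in model.modules())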
From 5a96345870d545b52858be2ce718be3008a3ecdf Mon Sep 17 00:00:00 2001
From: nik-mosaic <101217697+nik-mosaic@users.noreply.github.com>
Date: Mon, 27 Feb 2023 16:34:56 -0800
Subject: [PATCH 08/11] Update
 composer/algorithms/low_precision_layernorm/low_precision_layernorm.py

Change assertion to if-Raise

Co-authored-by: Mihir Patel
---
 .../low_precision_layernorm/low_precision_layernorm.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/composer/algorithms/low_precision_layernorm/low_precision_layernorm.py b/composer/algorithms/low_precision_layernorm/low_precision_layernorm.py
index 4005bd5371..22785bebf9 100644
--- a/composer/algorithms/low_precision_layernorm/low_precision_layernorm.py
+++ b/composer/algorithms/low_precision_layernorm/low_precision_layernorm.py
@@ -125,8 +125,8 @@ def check_if_apex_installed():

 def _to_LPLayerNorm(layer: torch.nn.Module, module_index: int) -> LPLayerNorm:
     """Defines a replacement policy from a `torch.nn.LayerNorm` to a `LPLayerNorm`"""
-    assert isinstance(layer,
-                      torch.nn.LayerNorm), 'The replacement policy will look for all instances of torch.nn.LayerNorm'
+    if not isinstance(layer, torch.nn.LayerNorm):
+        raise TypeError(f'Expected torch.nn.LayerNorm, got {type(layer)}')
     lp_layernorm = LPLayerNorm(layer.normalized_shape, layer.eps, layer.elementwise_affine, layer.weight.device,
                                layer.weight.dtype)
     with torch.no_grad():

From b9ae2168741a2ea55b578bcfff7b5ec6bc07c0f6 Mon Sep 17 00:00:00 2001
From: Nikhil <101217697+nik-mosaic@users.noreply.github.com>
Date: Mon, 27 Feb 2023 16:39:39 -0800
Subject: [PATCH 09/11] add CPU amp option to LPGN

---
 .../low_precision_groupnorm.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/composer/algorithms/low_precision_groupnorm/low_precision_groupnorm.py b/composer/algorithms/low_precision_groupnorm/low_precision_groupnorm.py
index b1c127e535..82e028aa28 100644
--- a/composer/algorithms/low_precision_groupnorm/low_precision_groupnorm.py
+++ b/composer/algorithms/low_precision_groupnorm/low_precision_groupnorm.py
@@ -94,12 +94,16 @@ def forward(self, x):
         return F.group_norm(downcast_x, self.num_groups, downcast_weight, downcast_bias, self.eps)


-def _cast_if_autocast_enabled(hidden_states):
-    if not torch.is_autocast_enabled():
-        return hidden_states
-    else:
-        return torch.cuda.amp.autocast_mode._cast(hidden_states, torch.get_autocast_gpu_dtype())
-
+def _cast_if_autocast_enabled(tensor):
+    if torch.is_autocast_enabled():
+        if tensor.device.type == 'cuda':
+            dtype = torch.get_autocast_gpu_dtype()
+        elif tensor.device.type == 'cpu':
+            dtype = torch.get_autocast_cpu_dtype()
+        else:
+            raise NotImplementedError()
+        return tensor.to(dtype=dtype)
+    return tensor

 def _to_LPGroupNorm(layer: torch.nn.Module, module_index: int) -> LPGroupNorm:
     if not isinstance(layer, torch.nn.GroupNorm):
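Note: the reworked `_cast_if_autocast_enabled` in PATCH 07/11 and PATCH 09/11 replaces the CUDA-only `torch.cuda.amp.autocast_mode._cast` helper with a per-device dtype lookup, so LPLayerNorm and LPGroupNorm can also downcast under CPU autocast. A small sketch of the dtype selection it relies on (assumes the PyTorch 1.13-era autocast API; not part of the patches):

import torch

# Each device type carries its own autocast dtype; the helper picks the one that
# matches the tensor's device (torch.get_autocast_gpu_dtype() for CUDA tensors).
with torch.autocast(device_type='cpu', dtype=torch.bfloat16):
    print(torch.get_autocast_cpu_dtype())  # torch.bfloat16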
From ed4c829cc99b3527818586b986db4170752dfca9 Mon Sep 17 00:00:00 2001
From: Nikhil <101217697+nik-mosaic@users.noreply.github.com>
Date: Mon, 27 Feb 2023 16:49:31 -0800
Subject: [PATCH 10/11] Change arg order to match new interface

---
 tests/utils/test_inference.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/utils/test_inference.py b/tests/utils/test_inference.py
index 8f57615356..e22d4a29d1 100644
--- a/tests/utils/test_inference.py
+++ b/tests/utils/test_inference.py
@@ -202,7 +202,7 @@ def test_gpu_huggingface_export_for_inference_onnx():
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

     apply_gated_linear_units(model, optimizer)
-    apply_low_precision_layernorm(model, optimizer, Precision('amp_fp16'))
+    apply_low_precision_layernorm(model, Precision('amp_fp16'), optimizer)

     model.eval()
     orig_out = model(sample_input)
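Note: a usage sketch of the functional API with the argument order this patch adopts, `apply_low_precision_layernorm(model, precision, optimizers)`. Only the positional order and the imports are taken from the updated test; the toy model and the final check are assumptions for illustration.

import torch
from composer.core import Precision
from composer.functional import apply_low_precision_layernorm

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.LayerNorm(8))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

apply_low_precision_layernorm(model, Precision('amp_fp16'), optimizer)

# Every plain torch.nn.LayerNorm should now be an LPLayerNorm instance with the
# original weight, bias, and eps preserved.
assert all(type(m) is not torch.nn.LayerNorm for m in model.modules())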
From 4b6c406913893c514e3105f23c17794d983575a1 Mon Sep 17 00:00:00 2001
From: Nikhil <101217697+nik-mosaic@users.noreply.github.com>
Date: Mon, 27 Feb 2023 16:56:10 -0800
Subject: [PATCH 11/11] Run precommit

---
 .../low_precision_groupnorm/low_precision_groupnorm.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/composer/algorithms/low_precision_groupnorm/low_precision_groupnorm.py b/composer/algorithms/low_precision_groupnorm/low_precision_groupnorm.py
index 82e028aa28..b264916caf 100644
--- a/composer/algorithms/low_precision_groupnorm/low_precision_groupnorm.py
+++ b/composer/algorithms/low_precision_groupnorm/low_precision_groupnorm.py
@@ -105,6 +105,7 @@ def _cast_if_autocast_enabled(tensor):
         return tensor.to(dtype=dtype)
     return tensor

+
 def _to_LPGroupNorm(layer: torch.nn.Module, module_index: int) -> LPGroupNorm:
     if not isinstance(layer, torch.nn.GroupNorm):
         raise TypeError(f'Expected torch.nn.GroupNorm, got {type(layer)}')