@@ -42,7 +42,7 @@ def apply_low_precision_layernorm(model,
    if version.parse(torch.__version__) < version.parse('1.13') and precision == Precision.AMP_BF16:
        check_if_apex_installed()
    policy: Dict[Type[torch.nn.Module], module_surgery.ReplacementFunction] = {
-        torch.nn.LayerNorm: to_FusedLayerNorm
+        torch.nn.LayerNorm: _to_FusedLayerNorm
    }

    replaced_instances = module_surgery.replace_module_classes(module=model, optimizers=optimizers, policies=policy)
@@ -88,14 +88,12 @@ def apply(self, event: Event, state: State, logger: Logger) -> Optional[int]:

class LPLayerNorm(torch.nn.LayerNorm):

-    def __init__(self, layer):
-        super().__init__(normalized_shape=layer.normalized_shape,
-                         eps=layer.eps,
-                         elementwise_affine=layer.elementwise_affine)
-
-        with torch.no_grad():
-            self.weight.copy_(layer.weight)
-            self.bias.copy_(layer.bias)
+    def __init__(self, normalized_shape, eps=1e-05, elementwise_affine=True, device=None, dtype=None):
+        super().__init__(normalized_shape=normalized_shape,
+                         eps=eps,
+                         elementwise_affine=elementwise_affine,
+                         device=device,
+                         dtype=dtype)

    def forward(self, x):
        module_device = x.device
@@ -106,27 +104,38 @@ def forward(self, x):
        return F.layer_norm(downcast_x, self.normalized_shape, downcast_weight, downcast_bias, self.eps)


-def _cast_if_autocast_enabled(hidden_states):
-    if not torch.is_autocast_enabled():
-        return hidden_states
-    else:
-        return torch.cuda.amp.autocast_mode._cast(hidden_states, torch.get_autocast_gpu_dtype())
+def _cast_if_autocast_enabled(tensor):
+    if torch.is_autocast_enabled():
+        if tensor.device.type == 'cuda':
+            dtype = torch.get_autocast_gpu_dtype()
+        elif tensor.device.type == 'cpu':
+            dtype = torch.get_autocast_cpu_dtype()
+        else:
+            raise NotImplementedError()
+        return tensor.to(dtype=dtype)
+    return tensor


def check_if_apex_installed():
    if not APEX_INSTALLED:
        raise ImportError(
-            'https://github.com/NVIDIA/apex is not installed. The Low Precision LayerNorm algorithm cannot be applied. The MosaicML Docker Images (https://hub.docker.com/r/mosaicml/pytorch) contain a copy of APEX for easy use.'
+            'https://github.com/NVIDIA/apex is not installed. The Low Precision LayerNorm algorithm cannot be applied on PyTorch <1.13 without Apex. The MosaicML Docker Images (https://hub.docker.com/r/mosaicml/pytorch) contain a copy of APEX for easy use.'
        )


def _to_LPLayerNorm(layer: torch.nn.Module, module_index: int) -> LPLayerNorm:
-    if not isinstance(layer, torch.nn.LayerNorm):
-        raise TypeError(f'Expected torch.nn.LayerNorm, got {type(layer)}')
-    return LPLayerNorm(layer)
+    """Defines a replacement policy from a `torch.nn.LayerNorm` to a `LPLayerNorm`"""
+    assert isinstance(layer,
+                      torch.nn.LayerNorm), 'The replacement policy will look for all instances of torch.nn.LayerNorm'
+    lp_layernorm = LPLayerNorm(layer.normalized_shape, layer.eps, layer.elementwise_affine, layer.weight.device,
+                               layer.weight.dtype)
+    with torch.no_grad():
+        lp_layernorm.weight.copy_(layer.weight)
+        lp_layernorm.bias.copy_(layer.bias)
+    return lp_layernorm


-def to_FusedLayerNorm(layer: torch.nn.Module, module_index: int) -> APEXFusedLayerNorm:
+def _to_FusedLayerNorm(layer: torch.nn.Module, module_index: int) -> APEXFusedLayerNorm:
    """Defines a replacement policy from a `torch.nn.LayerNorm` to a `apex.normalization.fused_layer_norm`"""
    if not isinstance(layer, torch.nn.LayerNorm):
        raise TypeError(f'Expected torch.nn.LayerNorm, got {type(layer)}')
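For reference, a minimal sketch of how the reworked replacement policy is expected to behave: `_to_LPLayerNorm` builds an `LPLayerNorm` with the same normalized_shape, eps, affine flag, device, and dtype as the original module and copies its parameters, so with autocast disabled both modules should produce identical outputs. The import path below is an assumption about where this file lives in the package, and `_to_LPLayerNorm` is a private helper, so this is illustrative rather than a supported API.

import torch

# Assumed import path for the symbols defined in this file; adjust to the actual module layout.
from composer.algorithms.low_precision_layernorm.low_precision_layernorm import LPLayerNorm, _to_LPLayerNorm

torch.manual_seed(0)

# A LayerNorm with non-default affine parameters, standing in for a layer from a trained model.
original = torch.nn.LayerNorm(64)
with torch.no_grad():
    original.weight.mul_(1.5)
    original.bias.add_(0.1)

# The policy copies normalized_shape, eps, elementwise_affine, device, dtype, and the parameters.
replacement = _to_LPLayerNorm(original, module_index=0)
assert isinstance(replacement, LPLayerNorm)

x = torch.randn(8, 64)
# With autocast disabled, _cast_if_autocast_enabled is a no-op, so the outputs match exactly.
assert torch.equal(original(x), replacement(x))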