Skip to content

Commit a8321fe

Browse files
committed
Default to allow_tf32=True for GPU Devices (#1275)
1 parent 419d3b5 commit a8321fe

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

composer/trainer/devices/device_gpu.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from typing import Any, Dict, Optional, TypeVar
99

1010
import torch
11+
import torch.backends.cuda
12+
import torch.backends.cudnn
1113
import torch.cuda.amp
1214
import torch.utils.data
1315

@@ -24,18 +26,29 @@ class DeviceGPU(Device):
2426
2527
Args:
2628
device_id (int, optional): Integer ID of a GPU device to train with. If not specified, the local rank
27-
of the current process is used. Default: None.
29+
of the current process is used. Default: None.
30+
allow_tf32 (bool, optional): Whether to allow TF32 matrix multiplications. Defaults to True.
31+
For more information, see :ref:`torch:tf32_on_ampere`.
2832
"""
2933
dist_backend = 'nccl'
3034

31-
def __init__(self, device_id: Optional[int] = None):
35+
def __init__(
    self,
    device_id: Optional[int] = None,
    *,
    allow_tf32: bool = True,
):
    """Select the CUDA device for this process and configure TF32 usage.

    Args:
        device_id (int, optional): Integer ID of the GPU device to use. If ``None``,
            the value returned by ``dist.get_local_rank()`` is used. Default: ``None``.
        allow_tf32 (bool, optional): Value assigned to both
            ``torch.backends.cuda.matmul.allow_tf32`` and
            ``torch.backends.cudnn.allow_tf32``. Default: ``True``.

    Raises:
        ValueError: If ``torch.cuda.is_available()`` returns ``False``.
    """
    if not torch.cuda.is_available():
        raise ValueError('DeviceGPU cannot be created as torch.cuda is not available.')

    # Compare against None instead of relying on truthiness: 0 is a valid GPU
    # index and must not be silently replaced by the local rank.
    if device_id is None:
        device_id = dist.get_local_rank()

    self._device = torch.device(f'cuda:{device_id}')
    torch.cuda.set_device(self._device)
    # Sanity check that the device switch actually took effect.
    assert torch.cuda.current_device() == device_id

    # These flags are attributes of the torch.backends modules, so setting them
    # here is not scoped to this instance.
    torch.backends.cuda.matmul.allow_tf32 = allow_tf32
    # pyright error: "allow_tf32" is not a known member of module
    # however, this flag exists on pytorch 1.9+: https://pytorch.org/docs/1.9.0/backends.html
    torch.backends.cudnn.allow_tf32 = allow_tf32  # type: ignore
3952

4053
def module_to_device(self, module: T_nnModule) -> T_nnModule:
    """Return the result of moving ``module`` onto this device via ``module.to``."""
    target = self._device
    moved = module.to(target)
    return moved

0 commit comments

Comments
 (0)