from typing import Any, Dict, Optional, TypeVar

import torch
import torch.backends.cuda
import torch.backends.cudnn
import torch.cuda.amp
import torch.utils.data
    Args:
        device_id (int, optional): Integer ID of a GPU device to train with. If not specified, the local rank
            of the current process is used. Default: None.
        allow_tf32 (bool, optional): Whether to allow TF32 matrix multiplications. Defaults to True.
            For more information, see :ref:`torch:tf32_on_ampere`.
    """

    dist_backend = 'nccl'
31
- def __init__ (self , device_id : Optional [int ] = None ):
35
+ def __init__ (
36
+ self ,
37
+ device_id : Optional [int ] = None ,
38
+ * ,
39
+ allow_tf32 : bool = True ,
40
+ ):
32
41
if not torch .cuda .is_available ():
33
42
raise ValueError ('DeviceGPU cannot be created as torch.cuda is not available.' )
34
43
if not device_id :
35
44
device_id = dist .get_local_rank ()
36
45
self ._device = torch .device (f'cuda:{ device_id } ' )
37
46
torch .cuda .set_device (self ._device )
38
47
assert torch .cuda .current_device () == device_id
48
+ torch .backends .cuda .matmul .allow_tf32 = allow_tf32
49
+ # pyright error: "allow_tf32" is not a known member of module
50
+ # however, this flag exists on pytorch 1.9+: https://pytorch.org/docs/1.9.0/backends.html
51
+ torch .backends .cudnn .allow_tf32 = allow_tf32 # type: ignore
39
52
40
53
def module_to_device (self , module : T_nnModule ) -> T_nnModule :
41
54
return module .to (self ._device )
0 commit comments