Skip to content

Commit 35b9eb9

Browse files
authored
Add h200 to flops dict (#3889)
1 parent 1204a01 commit 35b9eb9

File tree

1 file changed

+21
-13
lines changed

1 file changed

+21
-13
lines changed

composer/callbacks/speed_monitor.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,25 @@
2020

2121
__all__ = ['SpeedMonitor']
2222

23+
_HSERIES_SXM = {
24+
'fp64': 67e12,
25+
'fp32': 67e12,
26+
'tf32': 989e12 / 2,
27+
'fp16': 1.979e15 / 2,
28+
'amp_fp16': 1.979e15 / 2,
29+
'bf16': 1.979e15 / 2,
30+
'amp_bf16': 1.979e15 / 2,
31+
'fp8': 3.958e15 / 2,
32+
'amp_fp8': 3.958e15 / 2,
33+
'int8': 3.958e15 / 2,
34+
}
35+
2336
GPU_AVAILABLE_FLOPS = {
37+
# source: https://resources.nvidia.com/en-us-data-center-overview-mc/en-us-data-center-overview/hpc-datasheet-sc23-h200
38+
'h200-sxm': _HSERIES_SXM,
2439
# source: https://resources.nvidia.com/en-us-tensor-core/nvidia-tensor-core-gpu-datasheet
2540
# NVIDIA's spec sheet quotes tensor-core throughput with a 2x sparsity factor, hence the `/ 2` on those values
26-
'h100-sxm': {
27-
'fp64': 67e12,
28-
'fp32': 67e12,
29-
'tf32': 989e12 / 2,
30-
'fp16': 1.979e15 / 2,
31-
'amp_fp16': 1.979e15 / 2,
32-
'bf16': 1.979e15 / 2,
33-
'amp_bf16': 1.979e15 / 2,
34-
'fp8': 3.958e15 / 2,
35-
'amp_fp8': 3.958e15 / 2,
36-
'int8': 3.958e15 / 2,
37-
},
41+
'h100-sxm': _HSERIES_SXM,
3842
'h100-pcie': {
3943
'fp64': 51e12,
4044
'fp32': 51e12,
@@ -116,7 +120,11 @@ def get_gpu_flops_available(state: State):
116120
if torch.cuda.is_available():
117121
# Example output of torch.cuda.get_device_name(): 'NVIDIA A100-SXM4-40GB'
118122
device_name = torch.cuda.get_device_name().lower()
119-
if 'h100' in device_name and 'hbm3' in device_name:
123+
if 'h200' in device_name:
124+
# Assume the SXM variant: the device name does not distinguish form factors, and telling them
125+
# apart would require checking something else, such as power or bandwidth.
126+
device_name = 'h200-sxm'
127+
elif 'h100' in device_name and 'hbm3' in device_name:
120128
device_name = 'h100-sxm'
121129
elif 'h100' in device_name and ('pcie' in device_name or 'hbm2e' in device_name):
122130
device_name = 'h100-pcie'

0 commit comments

Comments
 (0)