diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py
index 66457836..049f8c7b 100644
--- a/float8_experimental/float8_tensor.py
+++ b/float8_experimental/float8_tensor.py
@@ -122,7 +122,7 @@ def __tensor_flatten__(self):
         return ["_data", "_scale"], ctx
 
     @staticmethod
-    def __tensor_unflatten__(inner_tensors: Dict, metadata):
+    def __tensor_unflatten__(inner_tensors: Dict, metadata, outer_size, outer_stride):
         assert len(inner_tensors) == 2
         return Float8Tensor(
             inner_tensors["_data"],
diff --git a/test/fsdp_utils.py b/test/fsdp_utils.py
new file mode 100644
index 00000000..3bc87d23
--- /dev/null
+++ b/test/fsdp_utils.py
@@ -0,0 +1,38 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+
+from float8_experimental.float8_linear import Float8Linear
+from float8_experimental.float8_linear_utils import (
+    swap_linear_with_float8_linear,
+)
+
+def setup(rank, world_size):
+    os.environ["MASTER_ADDR"] = "localhost"
+    os.environ["MASTER_PORT"] = "12355"
+
+    # initialize the process group
+    dist.init_process_group("nccl", rank=rank, world_size=world_size)
+
+
+def cleanup():
+    dist.destroy_process_group()
+
+
+def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
+    m = nn.Sequential(
+        nn.Linear(K, N, dtype=base_dtype),
+        nn.Linear(N, N, dtype=base_dtype),
+    )
+    if is_fp8:
+        swap_linear_with_float8_linear(m, Float8Linear, emulate=emulate)
+    return m
+
diff --git a/test/test_everything.sh b/test/test_everything.sh
index 14b4813d..03e542d6 100755
--- a/test/test_everything.sh
+++ b/test/test_everything.sh
@@ -7,6 +7,7 @@ pytest test/test_base.py
 pytest test/test_sam.py
 pytest test/test_compile.py
 ./test/test_fsdp.sh
+./test/test_fsdp_compile.sh
 ./test/test_tp.sh
 
 echo "all tests successful"
diff --git a/test/test_fsdp.py b/test/test_fsdp.py
index 547821d3..ab23589f 100644
--- a/test/test_fsdp.py
+++ b/test/test_fsdp.py
@@ -33,6 +33,8 @@
     StateDictType,
 )
 
+from fsdp_utils import setup, cleanup, get_model
+
 torch.manual_seed(0)
 
 # assumes user is running the script from /data/users/{user}/float8_experimental
@@ -50,28 +52,6 @@
 N_ITER = 1
 
 
-def setup(rank, world_size):
-    os.environ["MASTER_ADDR"] = "localhost"
-    os.environ["MASTER_PORT"] = "12355"
-
-    # initialize the process group
-    dist.init_process_group("nccl", rank=rank, world_size=world_size)
-
-
-def cleanup():
-    dist.destroy_process_group()
-
-
-def get_model(K, N, is_fp8, emulate, base_dtype=torch.float32):
-    m = nn.Sequential(
-        nn.Linear(K, N, dtype=base_dtype),
-        nn.Linear(N, N, dtype=base_dtype),
-    )
-    if is_fp8:
-        swap_linear_with_float8_linear(m, Float8Linear, emulate=emulate)
-    return m
-
-
 # taken from https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html
 # and modified
 def fsdp_main(rank, world_size, args):
diff --git a/test/test_fsdp_compile.py b/test/test_fsdp_compile.py
new file mode 100644
index 00000000..10409b9a
--- /dev/null
+++ b/test/test_fsdp_compile.py
@@ -0,0 +1,138 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD 3-Clause license found in the
+# LICENSE file in the root directory of this source tree.
+ +""" +Smoke tests of FSDP + compile + Float8Linear +""" + +import os +import warnings + +import fire + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn +from float8_experimental.float8_linear_utils import ( + sync_float8_amax_and_scale_history, +) +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +from fsdp_utils import setup, cleanup, get_model + +torch.manual_seed(0) + +B, M, K, N = 8, 8, 32, 32 +lr = 0.01 +N_ITER = 3 + + +def test_no_compile(world_size, emulate, base_dtype, rank, ref_input_local): + model = get_model(K, N, is_fp8=True, emulate=emulate, base_dtype=base_dtype).to( + rank + ) + model = FSDP(model, use_orig_params=True) + optimizer = torch.optim.SGD(model.parameters(), lr=lr * world_size) + + for _ in range(N_ITER): + optimizer.zero_grad() + y_local = model(ref_input_local) + y_local.sum().backward() + sync_float8_amax_and_scale_history(model) + optimizer.step() + + dist.barrier() + +def test_fsdp_then_compile_with_workaround(world_size, emulate, base_dtype, rank, ref_input_local): + model = get_model(K, N, is_fp8=True, emulate=emulate, base_dtype=base_dtype).to( + rank + ) + model = FSDP(model, use_orig_params=True) + optimizer = torch.optim.SGD(model.parameters(), lr=lr * world_size) + sync_func = torch.compile(sync_float8_amax_and_scale_history) + + for _ in range(N_ITER): + optimizer.zero_grad() + y_local = model(ref_input_local) + y_local.sum().backward() + sync_func(model) + optimizer.step() + + if _ == 0: + # right now things only work if we compile after the first iteration + # otherwise, we get https://gist.github.com/vkuzo/665e27a4d362f3999ad9a9e786acbe02 + # TODO(future): fix this + model = torch.compile(model) + + dist.barrier() + +def test_compile_then_fsdp(world_size, emulate, base_dtype, rank, ref_input_local): + model = get_model(K, N, is_fp8=True, emulate=emulate, base_dtype=base_dtype).to( + rank + ) + model = torch.compile(model) + model = FSDP(model, use_orig_params=True) + optimizer = torch.optim.SGD(model.parameters(), lr=lr * world_size) + sync_func = torch.compile(sync_float8_amax_and_scale_history) + + for _ in range(N_ITER): + optimizer.zero_grad() + y_local = model(ref_input_local) + y_local.sum().backward() + sync_func(model) + optimizer.step() + + dist.barrier() + + +# taken from https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html +# and modified +def fsdp_main(rank, world_size, args): + setup(rank, world_size) + torch.cuda.set_device(rank) + + emulate, = args + base_dtype = torch.bfloat16 + ref_input_global = torch.randn(B, M, K).cuda().to(base_dtype) + # basic distributed data sampling + bsz_global = ref_input_global.shape[0] + assert B % world_size == 0 + bsz_local_start = int(rank / world_size * B) + bsz_local_end = int((rank + 1) / world_size * B) + ref_input_local = ref_input_global[bsz_local_start:bsz_local_end].to(rank) + + test_args = world_size, emulate, base_dtype, rank, ref_input_local + + test_no_compile(*test_args) + # TODO(future): remove the workaround + test_fsdp_then_compile_with_workaround(*test_args) + # TOOD(future): unbreak this if needed + # test_compile_then_fsdp(*test_args) + # fails with https://gist.github.com/vkuzo/d7c65a073ebf47d64aa5b1a56df171c6 + + cleanup() + + +def run(): + + emulate = False + if not torch.cuda.is_available(): + warnings.warn("CUDA not available, running in emulation_mode") + emulate = True + elif torch.cuda.get_device_capability() < (9, 0): + warnings.warn( + f"CUDA capability 
{torch.cuda.get_device_capability()} < (9.0), running in emulation mode" + ) + emulate = True + + WORLD_SIZE = torch.cuda.device_count() + args = (emulate,) + mp.spawn(fsdp_main, args=(WORLD_SIZE, args), nprocs=WORLD_SIZE, join=True) + + +if __name__ == "__main__": + fire.Fire(run) diff --git a/test/test_fsdp_compile.sh b/test/test_fsdp_compile.sh new file mode 100755 index 00000000..20a22633 --- /dev/null +++ b/test/test_fsdp_compile.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +# the NCCL_DEBUG setting is to avoid log spew +# the CUDA_VISIBLE_DEVICES setting is for easy debugging +NCCL_DEBUG=WARN CUDA_VISIBLE_DEVICES=0,1 python test/test_fsdp_compile.py + +echo "done!" diff --git a/test/test_tp.py b/test/test_tp.py index 85ac078c..24e7ad74 100644 --- a/test/test_tp.py +++ b/test/test_tp.py @@ -37,11 +37,14 @@ def setup_model_parallel(): return local_rank, world_size -def test_column_parallel_linear(): +def test_column_parallel_linear(use_float8=True, use_compile=False): M, K, N = 128, 64, 256 m_ref = nn.Sequential(ColumnParallelLinear(K, N)).cuda() m = copy.deepcopy(m_ref) - swap_tp_linear_with_float8_linear(m) + if use_float8: + swap_tp_linear_with_float8_linear(m) + if use_compile: + m = torch.compile(m) x = torch.randn(M, K, device="cuda") y_ref = m_ref(x) @@ -51,7 +54,7 @@ def test_column_parallel_linear(): y.sum().backward() sqnr_y = compute_error(y_ref, y) - sqnr_w_grad = compute_error(m_ref[0].weight.grad, m[0].weight.grad) + sqnr_w_grad = compute_error(m_ref[0].weight.grad, getattr(m, '0').weight.grad) assert sqnr_y >= 20.0, f"sqnr_y {sqnr_y} is too low" assert sqnr_w_grad >= 20.0, f"sqnr_w_grad {sqnr_w_grad} is too low" @@ -109,7 +112,7 @@ def test_ffn(): sqnr_w2_grad = compute_error(m_ref.w2.weight.grad, m.w2.weight.grad) assert sqnr_y >= 20.0, f"sqnr_y {sqnr_y} is too low" - assert sqnr_w1_grad >= 14.0, f"sqnr_w1_grad {sqnr_w1_grad} is too low" + assert sqnr_w1_grad >= 13.0, f"sqnr_w1_grad {sqnr_w1_grad} is too low" assert sqnr_w2_grad >= 30.0, f"sqnr_w2_grad {sqnr_w2_grad} is too low" @@ -150,7 +153,16 @@ def test_ffn_sp(local_rank, world_size): if __name__ == "__main__": local_rank, world_size = setup_model_parallel() - test_column_parallel_linear() + test_column_parallel_linear(use_float8=True, use_compile=False) + + # below passes, but a lot of graph breaks: + # https://gist.github.com/vkuzo/670b2806e222bef04da5f173c758a165 + # test_column_parallel_linear(use_float8=False, use_compile=True) + + # below fails with + # https://gist.github.com/vkuzo/c9891ab38c8f341243b393e8e07d40d4https://gist.github.com/vkuzo/c9891ab38c8f341243b393e8e07d40d4 + # test_column_parallel_linear(use_float8=True, use_compile=True) + test_row_parallel_linear() test_ffn() test_ffn_sp(local_rank, world_size)
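
For context on the first hunk: newer PyTorch builds pass two extra arguments, outer_size and outer_stride, to __tensor_unflatten__ when reconstructing a traceable tensor subclass, which is why the Float8Tensor signature gains them here. The sketch below is a hypothetical toy subclass (WrapperTensor, invented for illustration and not part of this patch) that exercises the same flatten/unflatten round trip under that assumption; like the Float8Tensor change, it accepts the two new arguments and ignores them.

# Hypothetical illustration only -- WrapperTensor is invented, not from this repo.
import torch


class WrapperTensor(torch.Tensor):
    # toy wrapper subclass: one inner tensor plus a scale, loosely modeled on Float8Tensor

    @staticmethod
    def __new__(cls, data, scale):
        # the outer wrapper advertises the inner tensor's metadata
        return torch.Tensor._make_wrapper_subclass(
            cls,
            data.shape,
            dtype=data.dtype,
            device=data.device,
            requires_grad=data.requires_grad,
        )

    def __init__(self, data, scale):
        self._data = data
        self._scale = scale

    def __tensor_flatten__(self):
        # names of the inner tensor attributes, plus non-tensor context (none here)
        return ["_data", "_scale"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, metadata, outer_size, outer_stride):
        # newer PyTorch also passes outer_size / outer_stride; this toy class,
        # like the Float8Tensor change above, accepts and ignores them
        return WrapperTensor(inner_tensors["_data"], inner_tensors["_scale"])


if __name__ == "__main__":
    t = WrapperTensor(torch.randn(4, 4), torch.tensor(2.0))
    names, ctx = t.__tensor_flatten__()
    inner = {name: getattr(t, name) for name in names}
    t2 = WrapperTensor.__tensor_unflatten__(inner, ctx, t.shape, t.stride())
    assert torch.equal(t2._data, t._data) and torch.equal(t2._scale, t._scale)
    print("flatten/unflatten round-trip ok")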