From cd30ae6abad7c2d4a9bac7fb9324303111d9400b Mon Sep 17 00:00:00 2001
From: Valko Milev
Date: Tue, 22 Oct 2024 20:48:55 +0300
Subject: [PATCH 1/2] added batch norm

---
 main.py                                      |  46 ++++++++
 src/liger_kernel/ops/SyncBatchNorm.py        | 107 +++++++++++++++++++
 src/liger_kernel/transformers/functional.py  |   2 +
 test/transformers/test_batched_layer_norm.py |  61 +++++++++++
 4 files changed, 216 insertions(+)
 create mode 100644 main.py
 create mode 100644 src/liger_kernel/ops/SyncBatchNorm.py
 create mode 100644 test/transformers/test_batched_layer_norm.py

diff --git a/main.py b/main.py
new file mode 100644
index 000000000..4289f49b7
--- /dev/null
+++ b/main.py
@@ -0,0 +1,46 @@
+import torch
+import torch.nn as nn
+import torch.distributed as dist
+
+import os
+import datetime
+os.environ['MASTER_ADDR'] = 'localhost'
+os.environ['MASTER_PORT'] = '12355'
+os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
+class MyCNN(nn.Module):
+    def __init__(self):
+        super(MyCNN, self).__init__()
+        self.conv1 = nn.Conv2d(1, 10, kernel_size=3)
+        # Replace BatchNorm with SyncBatchNorm for distributed training
+        self.bn1 = nn.SyncBatchNorm(10)
+        self.conv2 = nn.Conv2d(10, 20, kernel_size=3)
+        self.bn2 = nn.SyncBatchNorm(20)
+        self.pool = nn.MaxPool2d(2, 2)
+        self.fc = nn.Linear(1440, 10)  # Assuming output size after convolutions
+
+    def forward(self, x):
+        x = self.pool(torch.relu(self.bn1(self.conv1(x))))
+        x = self.pool(torch.relu(self.bn2(self.conv2(x))))
+        x = x.view(-1, 1440)  # Flatten for fully-connected layer
+        x = self.fc(x)
+        return x
+
+
+def main():
+    # Initialize distributed training (replace with your specific initialization logic)
+    dist.init_process_group(backend="nccl", world_size=1, rank=0)  # Example for 4 processes
+    print('main')
+    model = MyCNN()
+    model.to('cuda')
+    # Wrap the model with DistributedDataParallel (DDP) for distributed training
+    model = nn.parallel.DistributedDataParallel(model)
+
+    # Define your optimizer, loss function, and training loop (omitted for brevity)
+
+    # ... training code ...
+
+    dist.destroy_process_group()  # Clean up after training
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/liger_kernel/ops/SyncBatchNorm.py b/src/liger_kernel/ops/SyncBatchNorm.py
new file mode 100644
index 000000000..98f2ad7fc
--- /dev/null
+++ b/src/liger_kernel/ops/SyncBatchNorm.py
@@ -0,0 +1,107 @@
+import operator
+
+import torch
+import triton
+import triton.language as tl
+
+from liger_kernel.ops.utils import (
+    calculate_settings,
+    compare_version,
+    ensure_contiguous,
+)
+
+if compare_version("triton", operator.ge, "3.0.0"):
+    try:
+        # typical import path with dispatch available
+        from triton.language.extra.libdevice import rsqrt
+    except ModuleNotFoundError:
+        # for working with NGC containers
+        from triton.language.extra.cuda.libdevice import rsqrt
+else:
+    from triton.language.math import rsqrt
+
+
+@triton.jit
+def _batch_norm_forward_kernel(
+        X_ptr,
+        X_batch_stride,
+        X_row_stride,
+        Y_ptr,
+        Y_batch_stride,
+        Y_row_stride,
+        # MEAN_ptr,
+        # MEAN_row_stride,
+        # VARIANCE_ptr,
+        # VARIANCE_row_stride,
+        # AXIS_ptr,
+        # AXIS_row_stride,
+        n_cols,
+        eps,
+        mean,
+        axis,
+        scale,
+        offset,
+        variance,
+        BLOCK_SIZE: tl.constexpr):
+    pass
+    row_idx = tl.program_id(1)
+    batch_idx = tl.program_id(0)
+    col_offsets = tl.arange(0, BLOCK_SIZE) * batch_idx
+
+    mask = True  # col_offsets < n_cols
+    X_ptr += (X_row_stride) * (row_idx)  # +X_batch_stride
+    Y_ptr += (Y_row_stride) * (row_idx)  # +Y_batch_stride
+    X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)
+
+    inv = tl.rsqrt(tl.load(variance) + eps)
+
+    if scale is not None:
+        inv = inv * tl.load(scale)
+
+    res = -tl.load(mean) * inv
+    if offset is not None:
+        res = res + tl.load(offset)
+
+    tl.store(Y_ptr + col_offsets, X_row * inv + res)
+
+
+def batch_norm_forward(X, axis, offset, scale, eps, mean, variance):
+    batch, n_rows, n_cols = X.shape
+    BLOCK_SIZE, num_warps = calculate_settings(n_cols)
+
+    Y = torch.zeros((batch, n_rows, n_cols), dtype=X.dtype, device=X.device)
+
+    _batch_norm_forward_kernel[(batch * n_cols * n_rows, n_rows,)](  # [(n_rows,)]
+        X,
+        X.stride(0),
+        X.stride(1),
+        Y,
+        Y.stride(0),
+        Y.stride(1),
+        n_cols,
+        eps,
+        mean,
+        axis,
+        scale,
+        offset,
+        variance,
+        BLOCK_SIZE=BLOCK_SIZE,
+        num_warps=num_warps,
+    )
+
+    return Y
+
+
+class LigerBatchNormFunction(torch.autograd.Function):
+    variance = torch.tensor([1] * 32 * 8).to('cuda')
+    scale = torch.tensor([1] * 32 * 8).to('cuda')
+    offset = torch.tensor([0] * 32 * 8).to('cuda')
+    mean = torch.tensor([0] * 32 * 8).to('cuda')
+
+    @staticmethod
+    @ensure_contiguous
+    def forward(ctx, X, axis, eps):
+        X = batch_norm_forward(X, axis, LigerBatchNormFunction.offset, LigerBatchNormFunction.scale,
+            eps, LigerBatchNormFunction.mean, LigerBatchNormFunction.variance)
+
+        return X
diff --git a/src/liger_kernel/transformers/functional.py b/src/liger_kernel/transformers/functional.py
index d63045efb..a22a4babc 100644
--- a/src/liger_kernel/transformers/functional.py
+++ b/src/liger_kernel/transformers/functional.py
@@ -5,6 +5,7 @@
 from liger_kernel.ops.geglu import LigerGELUMulFunction
 from liger_kernel.ops.kl_div import LigerKLDivLossFunction
 from liger_kernel.ops.layer_norm import LigerLayerNormFunction
+from liger_kernel.ops.SyncBatchNorm import LigerBatchNormFunction
 from liger_kernel.ops.rms_norm import LigerRMSNormFunction
 from liger_kernel.ops.rope import LigerRopeFunction
 from liger_kernel.ops.swiglu import LigerSiLUMulFunction
@@ -16,4 +17,5 @@
 liger_rms_norm = LigerRMSNormFunction.apply
 liger_rope = LigerRopeFunction.apply
 liger_layer_norm = LigerLayerNormFunction.apply
+liger_batch_norm = LigerBatchNormFunction.apply
 liger_kl_div = LigerKLDivLossFunction.apply
diff --git a/test/transformers/test_batched_layer_norm.py b/test/transformers/test_batched_layer_norm.py
new file mode 100644
index 000000000..d853f3d35
--- /dev/null
+++ b/test/transformers/test_batched_layer_norm.py
@@ -0,0 +1,61 @@
+import numpy
+import pytest
+import torch
+import keras
+from liger_kernel.transformers.functional import liger_layer_norm,liger_batch_norm
+
+import os
+os.environ['MASTER_ADDR'] = 'localhost'
+os.environ['MASTER_PORT'] = '12355'
+os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
+
+
+@pytest.mark.parametrize(
+    "hidden_size",
+    [
+        32,
+        32,
+        32,
+        32,
+    ],
+)
+@pytest.mark.parametrize(
+    "batch_size, seq_len",
+    [
+        (2, 32),
+        (8, 32),
+        (2, 32),
+    ],
+)
+@pytest.mark.parametrize(
+    "dtype, atol, rtol",
+    [
+        (torch.float32, 1e-5, 1e-5),
+    ],
+)
+
+def test_liger_bacthed_layer_norm(batch_size, seq_len, hidden_size, dtype, atol, rtol):
+    torch.manual_seed(0)
+
+    x = torch.randn(
+        batch_size, seq_len, hidden_size, dtype=dtype, device="cuda", requires_grad=True
+    )
+    keras_ln = keras.layers.BatchNormalization(epsilon=1e-6)
+
+    keras_ln.training = False
+
+    axis = -1
+
+    eps = 1e-6
+
+
+    liger_output = liger_batch_norm(x, axis, eps)
+
+    x = x.detach().cpu().numpy()
+    keras_output = keras_ln(x)
+
+    print(torch.Tensor(numpy.array(keras_output)).to('cuda'))
+    print(liger_output)
+    assert torch.allclose(liger_output,torch.Tensor(numpy.array(keras_output)).to('cuda'), atol=atol, rtol=rtol)
+
+

From 9b237b838f352e0d065c5a950b0127e5efe6b1ba Mon Sep 17 00:00:00 2001
From: Valko Milev
Date: Tue, 22 Oct 2024 21:28:49 +0300
Subject: [PATCH 2/2] formating

---
 src/liger_kernel/ops/SyncBatchNorm.py        | 85 ++++++++++----------
 test/transformers/test_batched_layer_norm.py | 27 ++++---
 2 files changed, 55 insertions(+), 57 deletions(-)

diff --git a/src/liger_kernel/ops/SyncBatchNorm.py b/src/liger_kernel/ops/SyncBatchNorm.py
index 98f2ad7fc..9af5be7cd 100644
--- a/src/liger_kernel/ops/SyncBatchNorm.py
+++ b/src/liger_kernel/ops/SyncBatchNorm.py
@@ -1,48 +1,33 @@
-import operator
-
 import torch
 import triton
 import triton.language as tl
 
-from liger_kernel.ops.utils import (
-    calculate_settings,
-    compare_version,
-    ensure_contiguous,
-)
-
-if compare_version("triton", operator.ge, "3.0.0"):
-    try:
-        # typical import path with dispatch available
-        from triton.language.extra.libdevice import rsqrt
-    except ModuleNotFoundError:
-        # for working with NGC containers
-        from triton.language.extra.cuda.libdevice import rsqrt
-else:
-    from triton.language.math import rsqrt
+from liger_kernel.ops.utils import calculate_settings, ensure_contiguous
 
 
 @triton.jit
 def _batch_norm_forward_kernel(
-        X_ptr,
-        X_batch_stride,
-        X_row_stride,
-        Y_ptr,
-        Y_batch_stride,
-        Y_row_stride,
-        # MEAN_ptr,
-        # MEAN_row_stride,
-        # VARIANCE_ptr,
-        # VARIANCE_row_stride,
-        # AXIS_ptr,
-        # AXIS_row_stride,
-        n_cols,
-        eps,
-        mean,
-        axis,
-        scale,
-        offset,
-        variance,
-        BLOCK_SIZE: tl.constexpr):
+    X_ptr,
+    X_batch_stride,
+    X_row_stride,
+    Y_ptr,
+    Y_batch_stride,
+    Y_row_stride,
+    # MEAN_ptr,
+    # MEAN_row_stride,
+    # VARIANCE_ptr,
+    # VARIANCE_row_stride,
+    # AXIS_ptr,
+    # AXIS_row_stride,
+    n_cols,
+    eps,
+    mean,
+    axis,
+    scale,
+    offset,
+    variance,
+    BLOCK_SIZE: tl.constexpr,
+):
     pass
     row_idx = tl.program_id(1)
     batch_idx = tl.program_id(0)
     col_offsets = tl.arange(0, BLOCK_SIZE) * batch_idx
@@ -71,7 +56,12 @@ def batch_norm_forward(X, axis, offset, scale, eps, mean, variance):
 
     Y = torch.zeros((batch, n_rows, n_cols), dtype=X.dtype, device=X.device)
 
-    _batch_norm_forward_kernel[(batch * n_cols * n_rows, n_rows,)](  # [(n_rows,)]
+    _batch_norm_forward_kernel[
+        (
+            batch * n_cols * n_rows,
+            n_rows,
+        )
+    ](  # [(n_rows,)]
         X,
         X.stride(0),
         X.stride(1),
         Y,
         Y.stride(0),
         Y.stride(1),
@@ -93,15 +83,22 @@ def batch_norm_forward(X, axis, offset, scale, eps, mean, variance):
 
 
 class LigerBatchNormFunction(torch.autograd.Function):
-    variance = torch.tensor([1] * 32 * 8).to('cuda')
-    scale = torch.tensor([1] * 32 * 8).to('cuda')
-    offset = torch.tensor([0] * 32 * 8).to('cuda')
-    mean = torch.tensor([0] * 32 * 8).to('cuda')
+    variance = torch.tensor([1] * 32 * 8).to("cuda")
+    scale = torch.tensor([1] * 32 * 8).to("cuda")
+    offset = torch.tensor([0] * 32 * 8).to("cuda")
+    mean = torch.tensor([0] * 32 * 8).to("cuda")
 
     @staticmethod
     @ensure_contiguous
     def forward(ctx, X, axis, eps):
-        X = batch_norm_forward(X, axis, LigerBatchNormFunction.offset, LigerBatchNormFunction.scale,
-            eps, LigerBatchNormFunction.mean, LigerBatchNormFunction.variance)
+        X = batch_norm_forward(
+            X,
+            axis,
+            LigerBatchNormFunction.offset,
+            LigerBatchNormFunction.scale,
+            eps,
+            LigerBatchNormFunction.mean,
+            LigerBatchNormFunction.variance,
+        )
 
         return X
diff --git a/test/transformers/test_batched_layer_norm.py b/test/transformers/test_batched_layer_norm.py
index d853f3d35..794dcb467 100644
--- a/test/transformers/test_batched_layer_norm.py
+++ b/test/transformers/test_batched_layer_norm.py
@@ -1,12 +1,14 @@
+import os
+
+import keras
 import numpy
 import pytest
 import torch
-import keras
-from liger_kernel.transformers.functional import liger_layer_norm,liger_batch_norm
 
-import os
-os.environ['MASTER_ADDR'] = 'localhost'
-os.environ['MASTER_PORT'] = '12355'
+from liger_kernel.transformers.functional import liger_batch_norm
+
+os.environ["MASTER_ADDR"] = "localhost"
+os.environ["MASTER_PORT"] = "12355"
 os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
 
 
@@ -33,7 +35,6 @@
         (torch.float32, 1e-5, 1e-5),
     ],
 )
-
 def test_liger_bacthed_layer_norm(batch_size, seq_len, hidden_size, dtype, atol, rtol):
     torch.manual_seed(0)
 
@@ -48,14 +49,14 @@ def test_liger_bacthed_layer_norm(batch_size, seq_len, hidden_size, dtype, atol,
 
     eps = 1e-6
 
-
-    liger_output = liger_batch_norm(x, axis, eps)
+    liger_output = liger_batch_norm(x, axis, eps)
 
     x = x.detach().cpu().numpy()
     keras_output = keras_ln(x)
 
-    print(torch.Tensor(numpy.array(keras_output)).to('cuda'))
-    print(liger_output)
-    assert torch.allclose(liger_output,torch.Tensor(numpy.array(keras_output)).to('cuda'), atol=atol, rtol=rtol)
-
-
+    assert torch.allclose(
+        liger_output,
+        torch.Tensor(numpy.array(keras_output)).to("cuda"),
+        atol=atol,
+        rtol=rtol,
+    )
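
Note (not part of the patch series): the per-element math inside _batch_norm_forward_kernel is the standard batch-norm transform y = (x - mean) / sqrt(variance + eps) * scale + offset, folded into a single multiply-add. A minimal eager-mode PyTorch sketch of the same computation is shown below for reference when checking the kernel's output; the helper name batch_norm_reference is illustrative only and does not exist in the repository.

import torch

def batch_norm_reference(x, mean, variance, scale=None, offset=None, eps=1e-6):
    # Mirrors the kernel: inv = rsqrt(variance + eps) [* scale], res = -mean * inv [+ offset]
    inv = torch.rsqrt(variance + eps)
    if scale is not None:
        inv = inv * scale
    res = -mean * inv
    if offset is not None:
        res = res + offset
    # y = x * inv + res  ==  (x - mean) / sqrt(variance + eps) * scale + offset
    return x * inv + res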