26 changes: 26 additions & 0 deletions cacheflow/models/layernorm.py
@@ -0,0 +1,26 @@
import torch
import torch.nn as nn

from cacheflow import layernorm_ops


class RMSNorm(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = torch.empty_like(x)
        layernorm_ops.rms_norm(
            out,
            x,
            self.weight.data,
            self.variance_epsilon,
        )
        return out
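For reference, a minimal usage sketch of the new module (assuming a CUDA build of cacheflow with the layernorm_ops extension compiled; the shapes below are made up for illustration):

import torch
from cacheflow.models.layernorm import RMSNorm

x = torch.randn(16, 4096, dtype=torch.float16, device='cuda')   # [num_tokens, hidden_size]
norm = RMSNorm(hidden_size=4096).half().cuda()
out = norm(x)   # same shape as x; each row is scaled by rsqrt(mean(x**2) + eps), then by norm.weight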
23 changes: 4 additions & 19 deletions cacheflow/models/llama.py
@@ -12,6 +12,7 @@

from cacheflow.models import InputMetadata
from cacheflow.models.attention import LlamaCacheFlowAttention
from cacheflow.models.layernorm import RMSNorm
from cacheflow.models.sample import Sampler
from cacheflow.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
@@ -23,22 +24,6 @@
KVCache = Tuple[torch.Tensor, torch.Tensor]


class LlamaRMSNorm(nn.Module):

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)
        return self.weight * hidden_states


class LlamaMLP(nn.Module):

    def __init__(
@@ -148,8 +133,8 @@ def __init__(self, config: LlamaConfig):
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
@@ -190,7 +175,7 @@ def __init__(self, config: LlamaConfig):
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size,
                                                   perform_initialization=False)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
1 change: 1 addition & 0 deletions csrc/attention_kernels.cu
@@ -3,6 +3,7 @@

#include "attention_utils.h"
#include "cuda_primitives.h"
#include "reduction_utils.h"

#include <algorithm>

39 changes: 0 additions & 39 deletions csrc/attention_utils.h
@@ -159,45 +159,6 @@ struct Qk_dot<uint16_t, 4> {
  }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
inline __device__ float block_sum(float* red_smem, float sum)
{

  // Decompose the thread index into warp / lane.
  int warp = threadIdx.x / WARP_SIZE;
  int lane = threadIdx.x % WARP_SIZE;

  // Compute the sum per warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
  }

  // Warp leaders store the data to shared memory.
  if (lane == 0) {
    red_smem[warp] = sum;
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // The warps compute the final sums.
  if (lane < WARPS_PER_BLOCK) {
    sum = red_smem[lane];
  }

  // Parallel reduction inside the warp.
#pragma unroll
  for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
  }

  // Broadcast to other threads.
  return __shfl_sync(uint32_t(-1), sum, 0);
}

} // namespace cacheflow

#undef MMHA_USE_FP32_ACUM_FOR_FMA
14 changes: 14 additions & 0 deletions csrc/layernorm.cpp
@@ -0,0 +1,14 @@
#include <torch/extension.h>

void rms_norm(
  torch::Tensor& out,
  torch::Tensor& input,
  torch::Tensor& weight,
  float epsilon);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
    "rms_norm",
    &rms_norm,
    "Apply Root Mean Square (RMS) Normalization to the input tensor.");
}
61 changes: 61 additions & 0 deletions csrc/layernorm_kernels.cu
@@ -0,0 +1,61 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "reduction_utils.h"

namespace cacheflow {

// TODO(woosuk): Further optimize this kernel.
template<typename scalar_t>
__global__ void rms_norm_kernel(
  scalar_t* __restrict__ out,             // [num_tokens, hidden_size]
  const scalar_t* __restrict__ input,     // [num_tokens, hidden_size]
  const scalar_t* __restrict__ weight,    // [hidden_size]
  const float epsilon,
  const int num_tokens,
  const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float) input[blockIdx.x * hidden_size + idx];
    variance += x * x;
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float) input[blockIdx.x * hidden_size + idx];
    out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
  }
}

} // namespace cacheflow

void rms_norm(
  torch::Tensor& out,      // [num_tokens, hidden_size]
  torch::Tensor& input,    // [num_tokens, hidden_size]
  torch::Tensor& weight,   // [hidden_size]
  float epsilon) {
  int num_tokens = input.size(0);
  int hidden_size = input.size(1);

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    input.scalar_type(),
    "rms_norm_kernel",
    [&] {
      cacheflow::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        epsilon,
        num_tokens,
        hidden_size);
    });
}
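For reference, the per-row computation the kernel and launcher above implement (this simply restates the code; H is hidden_size, w is weight, eps is epsilon):

out[i, j] = w[j] * x[i, j] / sqrt(mean_k(x[i, k]^2) + eps)

Each thread block handles one token row i; its threads stride over the hidden dimension, and blockReduceSum combines their partial sums of squares before the row is rescaled.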
76 changes: 76 additions & 0 deletions csrc/reduction_utils.h
@@ -0,0 +1,76 @@
#pragma once

namespace cacheflow {

template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
inline __device__ float block_sum(float* red_smem, float sum)
{

  // Decompose the thread index into warp / lane.
  int warp = threadIdx.x / WARP_SIZE;
  int lane = threadIdx.x % WARP_SIZE;

  // Compute the sum per warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
  }

  // Warp leaders store the data to shared memory.
  if (lane == 0) {
    red_smem[warp] = sum;
  }

  // Make sure the data is in shared memory.
  __syncthreads();

  // The warps compute the final sums.
  if (lane < WARPS_PER_BLOCK) {
    sum = red_smem[lane];
  }

  // Parallel reduction inside the warp.
#pragma unroll
  for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
  }

  // Broadcast to other threads.
  return __shfl_sync(uint32_t(-1), sum, 0);
}

#define FINAL_MASK 0xffffffff

template<typename T>
__inline__ __device__ T warpReduceSum(T val)
{
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1)
    val += __shfl_xor_sync(FINAL_MASK, val, mask, 32);
  return val;
}

/* Calculate the sum of all elements in a block */
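/* Note: the static `shared` buffer below holds one partial sum per warp; its 32
   entries cover the CUDA limit of 1024 threads per block, which also matches the
   rms_norm launcher's block size of std::min(hidden_size, 1024). */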
template<typename T>
__inline__ __device__ T blockReduceSum(T val)
{
  static __shared__ T shared[32];
  int lane = threadIdx.x & 0x1f;
  int wid = threadIdx.x >> 5;

  val = warpReduceSum<T>(val);

  if (lane == 0)
    shared[wid] = val;

  __syncthreads();

  // Use blockDim.x / 32.f rather than blockDim.x >> 5 so the comparison rounds up
  // and the last, partially filled warp is still included when blockDim.x is not
  // divisible by 32.
  val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
  val = warpReduceSum<T>(val);

  return val;
}

} // namespace cacheflow
8 changes: 8 additions & 0 deletions setup.py
@@ -31,6 +31,14 @@
)
ext_modules.append(positional_encoding_extension)

# Layer normalization kernels.
layernorm_extension = cpp_extension.CUDAExtension(
    name='cacheflow.layernorm_ops',
    sources=['csrc/layernorm.cpp', 'csrc/layernorm_kernels.cu'],
    extra_compile_args={'cxx': CXX_FLAGS, 'nvcc': NVCC_FLAGS},
)
ext_modules.append(layernorm_extension)

setuptools.setup(
    name='cacheflow',
    ext_modules=ext_modules,
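A quick import check after rebuilding the package (a sketch; it assumes the new extension compiled successfully):

from cacheflow import layernorm_ops
print(layernorm_ops.rms_norm.__doc__)   # doc text registered in csrc/layernorm.cpp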
53 changes: 53 additions & 0 deletions tests/kernels/layernorm.py
@@ -0,0 +1,53 @@
import torch
import torch.nn as nn

from cacheflow import layernorm_ops


class RefRMSNorm(nn.Module):

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        weight = torch.randn(hidden_size) / (hidden_size ** 0.5)
        self.weight = nn.Parameter(weight)
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)
        return self.weight * hidden_states


@torch.inference_mode()
def test_rms_norm(
    num_tokens: int,
    hidden_size: int,
    dtype: torch.dtype,
) -> None:
    x = torch.randn(num_tokens, hidden_size, dtype=dtype, device='cuda')
    ref = RefRMSNorm(hidden_size).to(dtype).cuda()

    out = torch.empty_like(x)
    layernorm_ops.rms_norm(
        out,
        x,
        ref.weight.data,
        ref.variance_epsilon,
    )
    ref_out = ref(x)
    assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-5)


if __name__ == '__main__':
    for dtype in [torch.half, torch.float]:
        for num_tokens in [7, 128, 2048]:
            for hidden_size in [13, 64, 1024, 5120]:
                print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
                      f'{num_tokens}, hidden_size={hidden_size}')
                test_rms_norm(
                    num_tokens=num_tokens,
                    hidden_size=hidden_size,
                    dtype=dtype,
                )