diff --git a/float8_experimental/float8_linear.py b/float8_experimental/float8_linear.py
index e67fe9ad..c7968217 100644
--- a/float8_experimental/float8_linear.py
+++ b/float8_experimental/float8_linear.py
@@ -18,7 +18,7 @@
 )
 
 from float8_experimental.float8_python_api import mm_float8
-from float8_experimental.float8_tensor import Float8Tensor, to_float8
+from float8_experimental.float8_tensor import Float8Tensor
 
 from float8_experimental.float8_utils import (
     amax_history_to_scale,
@@ -44,10 +44,12 @@ def forward(
         fp8_scale_dL_dY,
         scale_fn_name,
         is_amax_initialized,
+        emulate: bool,
     ):
         ctx.save_for_backward(fp8_amax_dL_dY, fp8_amax_history_dL_dY, fp8_scale_dL_dY)
         ctx.scale_fn_name = scale_fn_name
         ctx.is_amax_initialized = is_amax_initialized
+        ctx.emulate = emulate
         return tensor
 
     @staticmethod
@@ -69,99 +71,11 @@ def backward(ctx, go):
         fp8_amax_dL_dY.fill_(tensor_to_amax(go))
         go_scaled = go * fp8_scale_dL_dY
         bits_fp8 = to_fp8_saturated(go_scaled, torch.float8_e5m2)
-        empty_grads = None, None, None, None, None
-        res = Float8Tensor(bits_fp8, fp8_scale_dL_dY, go.dtype)
+        empty_grads = None, None, None, None, None, None
+        res = Float8Tensor(bits_fp8, fp8_scale_dL_dY, go.dtype, emulate=ctx.emulate)
         return res, *empty_grads
 
 
-class float8_linear(torch.autograd.Function):
-    """
-    Like F.linear, but with X and W in float8
-    """
-
-    @staticmethod
-    def forward(
-        ctx,
-        x_fp8,
-        w_fp8,
-        is_amax_initialized,
-        scale_fn_name,
-        emulate: bool,
-    ):
-        ctx.save_for_backward(x_fp8, w_fp8)
-        ctx.scale_fn_name = scale_fn_name
-        ctx.emulate = emulate
-        orig_shape = x_fp8._data.shape
-        x_fp8_reshaped = Float8Tensor(
-            x_fp8._data.reshape(-1, orig_shape[-1]), x_fp8._scale, x_fp8._orig_dtype
-        )
-        ctx.is_amax_initialized = is_amax_initialized
-
-        w_fp8_t = Float8Tensor(w_fp8._data.t(), w_fp8._scale, w_fp8._orig_dtype)
-
-        res_bits, _output_amax = mm_float8(
-            x_fp8_reshaped, w_fp8_t, output_dtype=x_fp8._orig_dtype, emulate=emulate
-        )
-        res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
-        return res_bits
-
-    @staticmethod
-    def backward(ctx, go_fp8):
-        x_fp8, w_fp8 = ctx.saved_tensors
-        scale_fn_name = ctx.scale_fn_name
-        emulate = ctx.emulate
-        is_amax_initialized = ctx.is_amax_initialized
-
-        go_fp8_orig_shape = go_fp8._data.shape
-        go_fp8_reshaped = Float8Tensor(
-            go_fp8._data.reshape(-1, go_fp8_orig_shape[-1]),
-            go_fp8._scale,
-            go_fp8._orig_dtype,
-        )
-
-        w_fp8_t_c_t = Float8Tensor(
-            w_fp8._data.t().contiguous().t(), w_fp8._scale, w_fp8._orig_dtype
-        )
-
-        #
-        # calculate dL/dX
-        #
-        dL_dX, _dL_dX_amax = mm_float8(
-            go_fp8_reshaped,
-            w_fp8_t_c_t,
-            output_dtype=x_fp8._orig_dtype,
-            emulate=emulate,
-        )
-        dL_dX = dL_dX.reshape(*go_fp8_orig_shape[:-1], dL_dX.shape[-1])
-
-        x_fp8_orig_shape = x_fp8._data.shape
-        x_fp8_reshaped_t_c = Float8Tensor(
-            x_fp8._data.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous(),
-            x_fp8._scale,
-            x_fp8._orig_dtype,
-        )
-
-        go_fp8_reshaped_t_c_t = Float8Tensor(
-            go_fp8_reshaped._data.t().contiguous().t(),
-            go_fp8_reshaped._scale,
-            go_fp8_reshaped._orig_dtype,
-        )
-
-        #
-        # calculate dL/dW
-        #
-        dL_dW, _dL_dW_amax = mm_float8(
-            x_fp8_reshaped_t_c,
-            go_fp8_reshaped_t_c_t,
-            output_dtype=x_fp8._orig_dtype,
-            emulate=emulate,
-        )
-        dL_dW = dL_dW.t()
-
-        empty_grads = None, None, None, None, None, None, None, None, None
-        return dL_dX, dL_dW, *empty_grads
-
-
 @dataclasses.dataclass
 class DelayedScalingRecipe:
     # Controls the history length of amax buffers
@@ -221,13 +135,17 @@ def __init__(self, *args, **kwargs):
         # will access the scale when it has ensured that it is on GPU.
         self._float8_tensor_ctor = lambda *args, **kwargs: Float8Tensor(*args, **kwargs)
 
-    def cast_x_to_float8(self, x, is_amax_initialized):
+    def cast_x_to_float8(
+        self, x: torch.Tensor, is_amax_initialized: bool
+    ) -> torch.Tensor:
         # Duplicate the autocast logic for F.linear, so that the output
         # of our module has the right original precision
         if torch.is_autocast_enabled():
             # For now, hardcode to GPU's autocast dtype
             # if we need CPU support in the future, we can add it
-            x = x.to(torch.get_autocast_gpu_dtype())
+            autocast_dtype = torch.get_autocast_gpu_dtype()
+            x = x.to(autocast_dtype)
+            self.bias_dtype = autocast_dtype
 
         scale_fn_name = self.recipe.scale_fn_name
         _maybe_initialize_amaxes_scales_for_float8_cast(
@@ -239,10 +157,14 @@ def cast_x_to_float8(self, x, is_amax_initialized):
             torch.float8_e4m3fn,
             is_amax_initialized,
         )
-        x_fp8 = to_float8(x, self.fp8_scale_x, torch.float8_e4m3fn, self.fp8_amax_x)
+        x_fp8 = Float8Tensor.to_float8(
+            x, self.fp8_scale_x, torch.float8_e4m3fn, self.fp8_amax_x, self.emulate
+        )
         return x_fp8
 
-    def cast_w_to_float8(self, w, is_amax_initialized):
+    def cast_w_to_float8(
+        self, w: torch.Tensor, is_amax_initialized: bool
+    ) -> torch.Tensor:
         scale_fn_name = self.recipe.scale_fn_name
         _maybe_initialize_amaxes_scales_for_float8_cast(
             w,
@@ -253,10 +175,14 @@ def cast_w_to_float8(self, w, is_amax_initialized):
             torch.float8_e4m3fn,
             is_amax_initialized,
         )
-        w_fp8 = to_float8(w, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w)
+        w_fp8 = Float8Tensor.to_float8(
+            w, self.fp8_scale_w, torch.float8_e4m3fn, self.fp8_amax_w, self.emulate
+        )
        return w_fp8
 
-    def cast_y_to_float8_in_bw(self, y):
+    def cast_y_to_float8_in_bw(
+        self, y: torch.Tensor, emulate: bool = False
+    ) -> torch.Tensor:
         scale_fn_name = self.recipe.scale_fn_name
         y = NoopFwToFloat8E5M2Bw.apply(
             y,
@@ -265,13 +191,7 @@ def cast_y_to_float8_in_bw(self, y):
             self.fp8_scale_dL_dY,
             scale_fn_name,
             self.is_amax_initialized,
-        )
-        return y
-
-    def float8_mm(self, x_fp8, w_fp8, is_amax_initialized):
-        scale_fn_name = self.recipe.scale_fn_name
-        y = float8_linear.apply(
-            x_fp8, w_fp8, is_amax_initialized, scale_fn_name, self.emulate
+            emulate,
         )
         return y
 
@@ -292,6 +212,11 @@ def float8_post_forward(self):
         self.is_amax_initialized = True
         self.amax_and_scale_synced = False
 
+    def add_weight_tag(self):
+        # We add a tag to the weight nn.Parameter in order to signal
+        # To FSDP that this param is a weight
+        self.weight._is_fp8_weight = True
+
 
 class Float8Linear(Float8LinearMixin, torch.nn.Linear):
     """
@@ -311,11 +236,14 @@ def forward(self, x):
             w_fp8 = self._w_fp8
         else:
             w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
-        y = self.float8_mm(x_fp8, w_fp8, self.is_amax_initialized)
-        y = self.cast_y_to_float8_in_bw(y)
+
+        y = torch.matmul(x_fp8, w_fp8.t())
+
+        # Cast gradY to float8_e5m2 during backward
+        y = self.cast_y_to_float8_in_bw(y, self.emulate)
 
         if self.bias is not None:
-            y = y + self.bias.to(x_fp8._orig_dtype)
+            y = y + self.bias.to(self.bias_dtype)
 
         self.float8_post_forward()
         return y
@@ -336,16 +264,13 @@ def from_float(cls, mod, emulate: bool = False):
         new_mod.weight = mod.weight
         new_mod.bias = mod.bias
         new_mod.emulate = emulate
+        if mod.bias is not None:
+            new_mod.bias_dtype = mod.bias.dtype
         # I think its okay to send all params and buffers to device
         new_mod.to(mod.weight.device)
         new_mod.add_weight_tag()
         return new_mod
 
-    def add_weight_tag(self):
-        # We add a tag to the weight nn.Parameter in order to signal
-        # To FSDP that this param is a weight
-        self.weight._is_fp8_weight = True
-
 
 def swap_linear_with_float8_linear(
     model,
diff --git a/float8_experimental/float8_ops.py b/float8_experimental/float8_ops.py
new file mode 100644
index 00000000..3050d101
--- /dev/null
+++ b/float8_experimental/float8_ops.py
@@ -0,0 +1,84 @@
+from typing import Any, Dict
+
+import torch
+from float8_experimental.float8_python_api import mm_float8_unwrapped
+from float8_experimental.float8_tensor import Float8Tensor
+from float8_experimental.float8_utils import is_row_major
+
+aten = torch.ops.aten
+FLOAT8_OPS_TABLE: Dict[Any, Any] = {}
+
+
+def implements(aten_ops):
+    """Register aten ops to the float8 op table"""
+
+    def decorator(func):
+        for op in aten_ops:
+            FLOAT8_OPS_TABLE[op] = func
+        return func
+
+    return decorator
+
+
+@implements(
+    [
+        aten.view.default,
+        aten._unsafe_view.default,
+        aten.t.default,
+        aten.as_strided.default,
+        aten.clone.default,
+        aten.detach.default,
+    ]
+)
+def float8_desugar_op(aten_op, args, kwargs=None):
+    new_data = aten_op(args[0]._data, *args[1:], **kwargs)
+    return Float8Tensor(new_data, args[0]._scale, args[0]._orig_dtype, args[0]._emulate)
+
+
+@implements([aten.mm.default])
+def float8_mm(aten_op, args, kwargs=None):
+    assert isinstance(args[0], Float8Tensor) and isinstance(args[1], Float8Tensor)
+    a = args[0]
+    b = args[1]
+    a_data = a._data
+    a_scale = a._scale
+    b_data = b._data
+
+    if not is_row_major(a_data.stride()):
+        a_data = a_data.contiguous()
+    if is_row_major(b_data.stride()):
+        b_data = b_data.t().contiguous().t()
+    b_scale = b._scale
+    output_dtype = a._orig_dtype
+    if a._emulate:
+        assert a._emulate == b._emulate
+        return torch.ops.aten.mm_float8_emulated(
+            a._data, a._scale, b._data, b._scale, output_dtype
+        )[0]
+    tensor_out, amax = mm_float8_unwrapped(
+        a_data, a_scale, b_data, b_scale, output_dtype, output_scale=None
+    )
+    return tensor_out
+
+
+@implements([aten.is_same_size.default])
+def float8_is_same_size(aten_op, args, kwargs=None):
+    return args[0].shape == args[1].shape
+
+
+@implements([aten._to_copy.default])
+def autocast_to_copy(aten_op, args, kwargs=None):
+    """This gets called when running matmul under autocast
+    when the input is a Float8Tensor, presenting as a fp32
+    tensor.
+    """
+    assert isinstance(args[0], Float8Tensor)
+    assert (
+        len(kwargs) == 1 and "dtype" in kwargs
+    ), "Only support dtype kwarg for autocast"
+    assert (
+        kwargs["dtype"] == torch.float16
+    ), "Only support floating point conversion for autocast w/ Float8Tensor"
+    return Float8Tensor(
+        args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._emulate
+    )
diff --git a/float8_experimental/float8_python_api.py b/float8_experimental/float8_python_api.py
index 96870d14..c7f6b3f7 100644
--- a/float8_experimental/float8_python_api.py
+++ b/float8_experimental/float8_python_api.py
@@ -4,7 +4,7 @@
 to simplify the product code.
""" -import warnings + from typing import Optional, Tuple import float8_experimental.float8_aten_api diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 82284278..6738706e 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -1,47 +1,10 @@ -from typing import Any, Dict +from typing import Dict import torch from float8_experimental.float8_utils import tensor_to_amax, to_fp8_saturated -from torch._subclasses.fake_tensor import is_fake aten = torch.ops.aten -FLOAT8_OPS_TABLE: Dict[Any, Any] = {} - - -def implements(aten_ops): - """Register aten ops to the float8 op table""" - - def decorator(func): - for op in aten_ops: - FLOAT8_OPS_TABLE[op] = func - return func - - return decorator - - -@implements( - [ - aten.view.default, - aten._unsafe_view.default, - aten.t.default, - aten.as_strided.default, - aten.clone.default, - aten.detach.default, - ] -) -def float8_desugar_op(aten_op, args, kwargs=None): - assert is_fake( - args[0] - ), "Float8Tensor.__torch_dispatch__ for user code is not supported" - new_data = aten_op(args[0]._data, *args[1:], **kwargs) - return Float8Tensor(new_data, args[0]._scale, args[0]._orig_dtype) - - -@implements([aten.is_same_size.default]) -def float8_is_same_size(aten_op, args, kwargs=None): - return args[0].shape == args[1].shape - class ToFloat8ConstrFunc(torch.autograd.Function): """ @@ -55,6 +18,7 @@ def forward( scale: float = None, float8_dtype=torch.float8_e4m3fn, amax_buffer=None, + emulate: bool = False, ): # In TransformerEngine, the casts to float8 are fused with calculating # the new amax value. In this codebase, the eager mode code for those @@ -65,34 +29,14 @@ def forward( tensor_scaled = tensor * scale bits_fp8 = to_fp8_saturated(tensor_scaled, float8_dtype) - return Float8Tensor(bits_fp8, scale, tensor.dtype) + return Float8Tensor(bits_fp8, scale, tensor.dtype, emulate=emulate) @staticmethod def backward(ctx, g): if isinstance(g, Float8Tensor): - return g.to_original_precision(), None, None, None + return g.to_original_precision(), None, None, None, None else: - return g, None, None, None - - -def to_float8( - tensor: torch.Tensor, - scale: torch.Tensor, - float8_dtype: torch.dtype, - amax_buffer: torch.Tensor = None, -) -> "Float8Tensor": - """Converts a higher precision tensor to float8 in a differentiable way. - - Args: - tensor: the tensor to convert - scale: the scale to use to convert the tensor - float8_dtype: the float8 dtype to use - amax_buffer: a buffer to store the amax value in prior to conversion - - Returns: - Float8Tensor: a float8 tensor - """ - return ToFloat8ConstrFunc.apply(tensor, scale, float8_dtype, amax_buffer) + return g, None, None, None, None class FromFloat8ConstrFunc(torch.autograd.Function): @@ -106,7 +50,7 @@ def forward(ctx, tensor): @staticmethod def backward(ctx, g): - return to_float8(g), None, None + return Float8Tensor.to_float8(g), None, None class Float8Tensor(torch.Tensor): @@ -118,6 +62,8 @@ class Float8Tensor(torch.Tensor): from fp8 range to fp32 range. * `_orig_dtype`: the original dtype of the tensor used to create this tensor. + * `_emulate`: if true using fp32 emulation for the matmuls, helpful + if you don't have access to h100 hardware. Intended usage of this abstraction: 1. 
to bundle raw data + fp8 metadata together for easy passing through @@ -131,9 +77,16 @@ class Float8Tensor(torch.Tensor): _data: torch.Tensor _scale: torch.Tensor _orig_dtype: torch.dtype - __slots__ = ["_data", "_scale", "_orig_dtype"] - - def __new__(cls, data: torch.Tensor, scale: torch.Tensor, orig_dtype: torch.dtype): + _emulate: bool + __slots__ = ["_data", "_scale", "_orig_dtype", "_emulate"] + + def __new__( + cls, + data: torch.Tensor, + scale: torch.Tensor, + orig_dtype: torch.dtype, + emulate=False, + ): assert scale.numel() == 1 self = torch.Tensor._make_wrapper_subclass( @@ -149,37 +102,60 @@ def __new__(cls, data: torch.Tensor, scale: torch.Tensor, orig_dtype: torch.dtyp self._data = data self._scale = scale self._orig_dtype = orig_dtype - + self._emulate = emulate return self def __repr__(self): - return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, as_orig_prec={self.to_original_precision()}" + return f"Float8Tensor(dtype={self._data.dtype}, scale={self._scale}, emulate={self._emulate}\nas_orig_prec={self.to_original_precision()}" def __tensor_flatten__(self): ctx = { - "_scale": self._scale, "_orig_dtype": self._orig_dtype, + "_emulate": self._emulate, } - # return ("_data", "_scale"), (self._orig_dtype) - return ["_data"], ctx + return ["_data", "_scale"], ctx @staticmethod def __tensor_unflatten__(inner_tensors: Dict, metadata): - assert len(inner_tensors) == 1 - # return Float8Tensor(tensors["_data"], tensors["_scale"], metadatas[0]) + assert len(inner_tensors) == 2 return Float8Tensor( - inner_tensors["_data"], metadata["_scale"], metadata["_orig_dtype"] + inner_tensors["_data"], + inner_tensors["_scale"], + metadata["_orig_dtype"], + metadata["_emulate"], ) def to_original_precision(self): return FromFloat8ConstrFunc.apply(self) + @staticmethod + @torch._dynamo.allow_in_graph + def to_float8(tensor, scale, float8_dtype, amax_buffer=None, emulate: bool = False): + """Converts a higher precision tensor to float8 in a differentiable way. + + Args: + tensor: the tensor to convert + scale: the scale to use to convert the tensor + float8_dtype: the float8 dtype to use + amax_buffer: a buffer to store the amax value in prior to conversion + + Returns: + Float8Tensor: a float8 tensor + """ + return ToFloat8ConstrFunc.apply( + tensor, scale, float8_dtype, amax_buffer, emulate + ) + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs=None): # 1. tracing through __torch_function__ logic is not supported yet in # PT2.0, so we explicitly disallow it here for callsites from user code. # 2. We do need to handle a couple of ops in order for # TorchDynamo tracing to succeed. 
+
+        # Lazy import to avoid circular dependency
+        from float8_experimental.float8_ops import FLOAT8_OPS_TABLE
+
         if func in FLOAT8_OPS_TABLE:
             return FLOAT8_OPS_TABLE[func](func, args, kwargs)
         raise NotImplementedError(f"attempting to run {func}, this is not supported")
diff --git a/float8_experimental/float8_utils.py b/float8_experimental/float8_utils.py
index 7283d2cd..273e1580 100644
--- a/float8_experimental/float8_utils.py
+++ b/float8_experimental/float8_utils.py
@@ -82,3 +82,7 @@ def compute_error(x, y):
     Ps = torch.norm(x)
     Pn = torch.norm(x - y)
     return 20 * torch.log10(Ps / Pn)
+
+def is_row_major(stride):
+    assert len(stride) == 2, "is_row_major only supports 2D tensors"
+    return stride[0] > stride[1] and stride[1] == 1
\ No newline at end of file
diff --git a/float8_experimental/tp_linear.py b/float8_experimental/tp_linear.py
index 4850adee..e8400c9b 100644
--- a/float8_experimental/tp_linear.py
+++ b/float8_experimental/tp_linear.py
@@ -14,7 +14,7 @@
     _ReduceScatterFwAllGatherFloat8Bw,
 )
 
-from float8_experimental.float8_linear import float8_linear, Float8LinearMixin
+from float8_experimental.float8_linear import Float8LinearMixin
 
 
 class Float8ColumnParallelLinear(Float8LinearMixin, ColumnParallelLinear):
@@ -44,16 +44,19 @@ def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type: ignore
             input_parallel, self.is_amax_initialized
         )
 
-        w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
+        if getattr(self, "_w_fp8", None) is not None:  # FSDP handled the cast
+            w_fp8 = self._w_fp8
+        else:
+            w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
 
         # Matrix multiply.
-        output_parallel = self.float8_mm(
-            input_parallel_fp8, w_fp8, self.is_amax_initialized
-        )
-        output_parallel = self.cast_y_to_float8_in_bw(output_parallel)
+        output_parallel = torch.matmul(input_parallel_fp8, w_fp8.t())
+
+        # Cast gradY to float8_e5m2 during backward
+        output_parallel = self.cast_y_to_float8_in_bw(output_parallel, self.emulate)
 
         if self.bias is not None:
-            output_parallel = output_parallel + self.bias.to(output_parallel.dtype)
+            output_parallel = output_parallel + self.bias.to(self.bias_dtype)
 
         if self.gather_output:
             assert not self.use_sequence_parallel, "unsupported"
@@ -85,12 +88,15 @@ def from_float(cls, mod, emulate=False):
         new_mod.out_features = mod.out_features
         new_mod.weight = mod.weight
         new_mod.bias = mod.bias
+        if mod.bias is not None:
+            new_mod.bias_dtype = mod.bias.dtype
         new_mod.gather_output = mod.gather_output
         new_mod.output_size_per_partition = mod.output_size_per_partition
         new_mod.master_weight = mod.master_weight
         device_to_use = next(mod.parameters()).device
         new_mod.to(device_to_use)
         new_mod.emulate = emulate
+        new_mod.add_weight_tag()
         # TODO: test when creation is on cuda
         return new_mod
 
@@ -125,12 +131,14 @@ def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type:ignore
         input_parallel_fp8 = self.cast_x_to_float8(
             input_parallel, self.is_amax_initialized
         )
-        w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
+
+        if getattr(self, "_w_fp8", None) is not None:  # FSDP handled the cast
+            w_fp8 = self._w_fp8
+        else:
+            w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
 
         # Matrix multiply.
-        output_parallel = self.float8_mm(
-            input_parallel_fp8, w_fp8, self.is_amax_initialized
-        )
+        output_parallel = torch.matmul(input_parallel_fp8, w_fp8.t())
 
         if self.use_sequence_parallel:
             # forward: reduce-scatter
@@ -138,14 +146,14 @@ def forward(self, input_: torch.Tensor) -> torch.Tensor:  # type:ignore
             # backward: all-gather
             # Float8 comms: yes
             output_ = _ReduceScatterFwAllGatherFloat8Bw.apply(output_parallel)
-            output_ = self.cast_y_to_float8_in_bw(output_)
+            output_ = self.cast_y_to_float8_in_bw(output_, self.emulate)
         else:
             # All-reduce across all the partitions.
             # forward: reduce
             # Float8 comms: none
             # backward: no-op
             # Float8 comms: none
-            output_parallel = self.cast_y_to_float8_in_bw(output_parallel)
+            output_parallel = self.cast_y_to_float8_in_bw(output_parallel, self.emulate)
 
             # adding zero below is a hack
             # without this hack, we see the following error: https://gist.github.com/vkuzo/0ed84e35081c8c7d20d0f46ed4322704
@@ -174,6 +182,8 @@ def from_float(cls, mod, emulate=False):
         new_mod.out_features = mod.out_features
         new_mod.weight = mod.weight
         new_mod.bias = mod.bias
+        if mod.bias is not None:
+            new_mod.bias_dtype = mod.bias.dtype
         new_mod.input_is_parallel = mod.input_is_parallel
         new_mod.input_size_per_partition = mod.input_size_per_partition
         new_mod.master_weight = mod.master_weight
@@ -181,6 +191,7 @@ def from_float(cls, mod, emulate=False):
         device_to_use = next(mod.parameters()).device
         new_mod.to(device_to_use)
         new_mod.emulate = emulate
+        new_mod.add_weight_tag()
 
         return new_mod
 
diff --git a/pyproject.toml b/pyproject.toml
index ebc60d6a..6097122d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,7 +33,7 @@ dev = ["black", "bumpver", "isort", "pip-tools"]
 packages = ["float8_experimental"]
 
 [tool.black]
-line-length = 99
+line-length = 88
 include = '\.pyi?$'
 exclude = '''
 /(
diff --git a/test/test_base.py b/test/test_base.py
index 19f720fa..950d3f21 100644
--- a/test/test_base.py
+++ b/test/test_base.py
@@ -15,7 +15,7 @@
 )
 from float8_experimental.float8_linear_nots import Float8LinearNoTensorSubclass
 from float8_experimental.float8_python_api import mm_float8
-from float8_experimental.float8_tensor import Float8Tensor, to_float8
+from float8_experimental.float8_tensor import Float8Tensor
 
 from float8_experimental.float8_utils import (
     amax_to_scale,
@@ -39,7 +39,7 @@ def test_preserves_dtype(self):
         for hp_dtype, lp_dtype in itertools.product(hp_dtypes, lp_dtypes):
             x1_hp = torch.randn(4, 4, dtype=hp_dtype)
             x1_s = tensor_to_scale(x1_hp, lp_dtype)
-            x2_lp = to_float8(x1_hp, x1_s, lp_dtype)
+            x2_lp = Float8Tensor.to_float8(x1_hp, x1_s, lp_dtype)
             x3_hp = x2_lp.to_original_precision()
             self.assertTrue(x3_hp.dtype == hp_dtype)
 
@@ -215,9 +215,8 @@ def test_pt2_nots(self, emulate: bool, device: torch.device):
 
     @pytest.mark.parametrize("emulate", [True, False])
     def test_pt2_ts(self, emulate: bool):
-        if emulate:
-            warnings.warn("PT2.0 tracing doesn't work with subclass + emulate yet")
-            pytest.skip()
+        warnings.warn("PT2.0 tracing doesn't work with subclass")
+        pytest.skip()
         self._test_pt2_impl(
             use_no_tensor_subclass=False, emulate=emulate, device="cuda"
         )
@@ -248,8 +247,8 @@ def test_scaled_mm_vs_emulated(self, base_dtype):
         a_scale = tensor_to_scale(a, input_dtype).float()
         b_scale = tensor_to_scale(b, input_dtype).float()
 
-        a_fp8 = to_float8(a, a_scale, input_dtype)
-        b_fp8 = to_float8(b, b_scale, input_dtype)
+        a_fp8 = Float8Tensor.to_float8(a, a_scale, input_dtype)
+        b_fp8 = Float8Tensor.to_float8(b, b_scale, input_dtype)
 
         out_scaled_mm, output_amax_scaled = mm_float8(
             a_fp8, b_fp8, output_dtype=output_dtype, emulate=False
diff --git a/test/test_compile.py b/test/test_compile.py
index b9800d42..0637bc7c 100644
--- a/test/test_compile.py
+++ b/test/test_compile.py
@@ -1,15 +1,16 @@
 import copy
 import random
+import unittest
 
 import pytest
+
 import torch
 import torch.nn as nn
-
 from float8_experimental.float8_linear import Float8Linear
 from float8_experimental.float8_linear_nots import Float8LinearNoTensorSubclass
-from torch._dynamo.testing import CompileCounterWithBackend, EagerAndRecordGraphs
 
 # Setting to unblock for calling contiguous in backwards
+is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
 
 
 def _test_compile_base(
@@ -34,33 +35,36 @@ def _test_compile_base(
     y_fp8.sum().backward()
     y_ref = m_ref(x)
     y_ref.sum().backward()
-    torch.testing.assert_close(y_fp8, y_ref, atol=8e-2, rtol=8e-2)
+    torch.testing.assert_close(y_fp8, y_ref, atol=9.5e-2, rtol=9.5e-2)
     torch.testing.assert_close(
         m_fp8.weight.grad, m_ref.weight.grad, atol=2e-1, rtol=2e-1
     )
     torch.testing.assert_close(m_fp8.bias.grad, m_ref.bias.grad, atol=8e-2, rtol=8e-2)
 
 
-@pytest.mark.parametrize("fullgraph", [False])
-@pytest.mark.parametrize("emulate", [True])
+@pytest.mark.parametrize("fullgraph", [True])
+@pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True])
 @pytest.mark.parametrize("use_subclass", [True])
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
+@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
 def test_eager_only(fullgraph, emulate: bool, use_subclass: bool, dtype: torch.dtype):
     _test_compile_base("eager", fullgraph, emulate, use_subclass, dtype)
 
 
-@pytest.mark.parametrize("fullgraph", [False])
-@pytest.mark.parametrize("emulate", [False])
+@pytest.mark.parametrize("fullgraph", [True])
+@pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True])
 @pytest.mark.parametrize("use_subclass", [True])
 @pytest.mark.parametrize("dtype", [torch.bfloat16])
+@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
 def test_aot_eager(fullgraph, emulate: bool, use_subclass: bool, dtype: torch.dtype):
     _test_compile_base("aot_eager", fullgraph, emulate, use_subclass, dtype)
 
 
 @pytest.mark.parametrize("fullgraph", [True])
-@pytest.mark.parametrize("emulate", [True, False])
+@pytest.mark.parametrize("emulate", [False, True] if is_H100 else [True])
 @pytest.mark.parametrize("use_subclass", [True])
-@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
+@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
 def test_inductor(fullgraph, emulate: bool, use_subclass: bool, dtype: torch.dtype):
     _test_compile_base("inductor", fullgraph, emulate, use_subclass, dtype)
diff --git a/test/test_everything.sh b/test/test_everything.sh
old mode 100644
new mode 100755
index 8ba17fca..14b4813d
--- a/test/test_everything.sh
+++ b/test/test_everything.sh
@@ -5,6 +5,7 @@ set -e
 
 pytest test/test_base.py
 pytest test/test_sam.py
+pytest test/test_compile.py
 ./test/test_fsdp.sh
 ./test/test_tp.sh
diff --git a/test/test_fsdp.sh b/test/test_fsdp.sh
old mode 100644
new mode 100755
diff --git a/test/test_tp.sh b/test/test_tp.sh
old mode 100644
new mode 100755
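The patch removes the bespoke `float8_linear` autograd function and instead routes a plain `torch.matmul` on `Float8Tensor` inputs through `__torch_dispatch__` to the op table now living in `float8_ops.py`. A minimal usage sketch follows; it is not part of the patch, and the shapes, device, and `emulate=True` setting are assumptions chosen only for illustration.

```python
import torch

from float8_experimental.float8_tensor import Float8Tensor
from float8_experimental.float8_utils import tensor_to_scale

# Illustrative shapes/dtypes (assumptions, not values from the diff).
x = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16)
w = torch.randn(64, 32, device="cuda", dtype=torch.bfloat16)

x_scale = tensor_to_scale(x, torch.float8_e4m3fn).float()
w_scale = tensor_to_scale(w, torch.float8_e4m3fn).float()

# Float8Tensor.to_float8 replaces the old module-level to_float8 and now
# threads `emulate` through to the Float8Tensor it returns.
x_fp8 = Float8Tensor.to_float8(x, x_scale, torch.float8_e4m3fn, emulate=True)
w_fp8 = Float8Tensor.to_float8(w, w_scale, torch.float8_e4m3fn, emulate=True)

# A plain matmul routes through Float8Tensor.__torch_dispatch__ to the
# aten.mm handler registered in float8_ops.py; with emulate=True it takes
# the fp32 emulation path, and the result comes back in the original precision.
y = torch.matmul(x_fp8, w_fp8.t())
assert y.dtype == torch.bfloat16
```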