This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 41a6395

[TBD if for land] bring back torch.autograd.Function

Summary:
This approach is more readable as we add additional scaling options. For now, seeing how many things break in 2024-07 with torch.autograd.Function + subclasses + compile.

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:

ghstack-source-id: a7abb00
Pull Request resolved: #316

1 parent: c58fb5d
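
For context on the "torch.autograd.Function + subclasses + compile" interaction mentioned above, here is a minimal sketch of the pattern this commit exercises (illustrative only, not part of the commit; the _demo_matmul and _DemoLinear names are made up): a custom autograd.Function marked with torch._dynamo.allow_in_graph, called from a module run under torch.compile.

import torch

# Illustrative stand-in for manual_float8_matmul: a custom autograd.Function
# that dynamo is told to keep as a single node in the captured graph.
@torch._dynamo.allow_in_graph
class _demo_matmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b_t):
        ctx.save_for_backward(a, b_t)
        return torch.mm(a, b_t)

    @staticmethod
    def backward(ctx, grad_out):
        a, b_t = ctx.saved_tensors
        # gradients of out = A @ B_t with respect to A and B_t
        return torch.mm(grad_out, b_t.t()), torch.mm(a.t(), grad_out)

class _DemoLinear(torch.nn.Module):
    def __init__(self, in_features, out_features):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(out_features, in_features))

    def forward(self, x):
        return _demo_matmul.apply(x, self.weight.t())

m = torch.compile(_DemoLinear(16, 32))
out = m(torch.randn(4, 16, requires_grad=True))
out.sum().backward()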

File tree (2 files changed, +56 -2 lines):
- float8_experimental/float8_linear.py
- float8_experimental/float8_tensor.py


float8_experimental/float8_linear.py
Lines changed: 53 additions & 1 deletion

@@ -71,6 +71,58 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
         scale.copy_(new_scale)
 
 
+# this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files
+@torch._dynamo.allow_in_graph
+class manual_float8_matmul(torch.autograd.Function):
+    """
+    Like torch.matmul, but with the arguments in float8
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        x_fp8,
+        w_fp8_t,
+    ):
+        ctx.save_for_backward(x_fp8, w_fp8_t)
+        # the reshapes are needed in order to make the shapes compatible with
+        # torch.mm
+        orig_shape = x_fp8.shape
+        x_fp8_reshaped = x_fp8.reshape(-1, orig_shape[-1])
+        res_bits = torch.mm(x_fp8_reshaped, w_fp8_t)
+        res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
+        return res_bits
+
+    @staticmethod
+    def backward(ctx, go_fp8):
+        x_fp8, w_fp8_t = ctx.saved_tensors
+
+        # the reshapes are needed in order to make the shapes compatible with
+        # torch.mm
+        go_fp8_orig_shape = go_fp8.shape
+        go_fp8_reshaped = go_fp8.reshape(-1, go_fp8_orig_shape[-1])
+
+        # calculate dL/dX
+        dL_dX = torch.mm(
+            go_fp8_reshaped,
+            w_fp8_t.t(),
+        )
+        dL_dX = dL_dX.reshape(*go_fp8_orig_shape[:-1], dL_dX.shape[-1])
+
+        x_fp8_orig_shape = x_fp8.shape
+        x_fp8_reshaped = x_fp8.reshape(-1, x_fp8_orig_shape[-1])
+
+        # calculate dL/dW
+        # Note: the variant below is slightly faster on LLaMa 3 8B pretraining
+        # compared to calculating `dL_dW_t = x_fp8_t @ go_fp8_reshaped`
+        dL_dW = torch.mm(
+            go_fp8_reshaped.t(),
+            x_fp8_reshaped,
+        )
+
+        return dL_dX, dL_dW.t()
+
+
 @torch._dynamo.allow_in_graph
 class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
     """
@@ -410,7 +462,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
         w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
 
-        y = torch.matmul(x_fp8, w_fp8.t())
+        y = manual_float8_matmul.apply(x_fp8, w_fp8.t())
 
         # Cast gradY to float8_e5m2 during backward
         y = self.cast_y_to_float8_in_bw(y)
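
For reference, manual_float8_matmul is numerically equivalent to the torch.matmul call it replaces; the reshapes only flatten the leading dimensions so torch.mm can be used. A quick sanity sketch (not part of the commit; it uses plain double-precision tensors rather than Float8Tensor, and assumes manual_float8_matmul is importable from float8_experimental.float8_linear):

import torch
from float8_experimental.float8_linear import manual_float8_matmul

# hypothetical shapes: (batch, seq, in_features) input, (out, in) weight
x = torch.randn(2, 3, 8, dtype=torch.double, requires_grad=True)
w = torch.randn(4, 8, dtype=torch.double, requires_grad=True)
g = torch.randn(2, 3, 4, dtype=torch.double)

y_ref = torch.matmul(x, w.t())
y_man = manual_float8_matmul.apply(x, w.t())
assert torch.allclose(y_ref, y_man)

# gradients should also match the reference matmul
gx_ref, gw_ref = torch.autograd.grad(y_ref, (x, w), g)
gx_man, gw_man = torch.autograd.grad(y_man, (x, w), g)
assert torch.allclose(gx_ref, gx_man)
assert torch.allclose(gw_ref, gw_man)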

float8_experimental/float8_tensor.py
Lines changed: 3 additions & 1 deletion

@@ -101,7 +101,9 @@ def choose_scaled_mm_config(
             a_linear_mm_config.dL_dX == b_linear_mm_config.dL_dX
         ), f"linear_mm_config.dL_dX mismatch: {a_linear_mm_config.dL_dX} vs {b_linear_mm_config.dL_dX}"
         return a_linear_mm_config.dL_dX
-    elif a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.X:
+    elif (a_role is GemmInputRole.X and b_role is GemmInputRole.DL_DY) or (
+        a_role is GemmInputRole.DL_DY and b_role is GemmInputRole.X
+    ):
         assert (
             a_linear_mm_config.dL_dW == b_linear_mm_config.dL_dW
         ), f"linear_mm_config.dL_dW mismatch: {a_linear_mm_config.dL_dW} vs {b_linear_mm_config.dL_dW}"
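
The widened role check lines up with the dL/dW note in float8_linear.py above: the same gradient can be produced with either operand order, so both the (X, DL_DY) and (DL_DY, X) pairings should resolve to the dL_dW gemm config. A tiny plain-tensor illustration (not part of the commit):

import torch

go = torch.randn(6, 4)  # plays the DL_DY role
x = torch.randn(6, 8)   # plays the X role

dL_dW = torch.mm(go.t(), x)    # (DL_DY, X) ordering, as used in the backward above
dL_dW_t = torch.mm(x.t(), go)  # (X, DL_DY) ordering, the transposed variant
assert torch.allclose(dL_dW, dL_dW_t.t())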
