This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 352f29e

[TBD if for land] bring back torch.autograd.Function
Summary:

This approach is more readable as we add additional scaling options.

For now, seeing how many things break in 2024-07 with torch.autograd.Function + subclasses + compile.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 09c4625
Pull Request resolved: #316
1 parent de93990 commit 352f29e
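
The combination the summary refers to (a custom torch.autograd.Function decorated with @torch._dynamo.allow_in_graph, called from torch.compile'd code) can be illustrated with a minimal sketch. This is not from the commit: the Float8Tensor subclass and scaling pieces are omitted, and the names MyMM and f are illustrative only.

```python
# Minimal sketch (not from this commit): a custom autograd.Function marked
# allow_in_graph, invoked from a torch.compile'd function. Requires PyTorch 2.x.
import torch


@torch._dynamo.allow_in_graph
class MyMM(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return torch.mm(x, w.t())

    @staticmethod
    def backward(ctx, go):
        x, w = ctx.saved_tensors
        # dL/dX = dY @ W, dL/dW = dY^T @ X
        return torch.mm(go, w), torch.mm(go.t(), x)


@torch.compile
def f(x, w):
    return MyMM.apply(x, w)


x = torch.randn(4, 8, requires_grad=True)
w = torch.randn(16, 8, requires_grad=True)
f(x, w).sum().backward()
```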

File tree

1 file changed: +101 -1 lines changed


float8_experimental/float8_linear.py

Lines changed: 101 additions & 1 deletion
@@ -68,6 +68,101 @@ def _maybe_initialize_amaxes_scales_for_float8_cast(
         )
         scale.copy_(new_scale)
 
+# this code was resurrected from https://github.com/pytorch-labs/float8_experimental/pull/128/files
+# and modified to only support dynamic scaling
+@torch._dynamo.allow_in_graph
+class float8_linear(torch.autograd.Function):
+    """
+    Like F.linear, but with X and W in float8
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        x_fp8,
+        w_fp8,
+        emulate: bool,
+        # TODO(this PR): split config into fwd/bwd
+        mm_config: ScaledMMConfig,
+    ):
+        ctx.save_for_backward(x_fp8, w_fp8)
+        ctx.emulate = emulate
+        ctx.mm_config = mm_config
+        # orig_shape = x_fp8._data.shape
+        orig_shape = x_fp8.shape
+        # x_fp8_reshaped = Float8Tensor(
+        #     x_fp8._data.reshape(-1, orig_shape[-1]), x_fp8._scale, x_fp8._orig_dtype, mm_config
+        # )
+        x_fp8_reshaped = x_fp8.reshape(-1, orig_shape[-1])
+
+        # w_fp8_t = Float8Tensor(w_fp8._data.t(), w_fp8._scale, w_fp8._orig_dtype, mm_config)
+        w_fp8_t = w_fp8.t()
+
+        res_bits = torch.mm(
+            x_fp8_reshaped, w_fp8_t
+        )
+        res_bits = res_bits.reshape(*orig_shape[:-1], res_bits.shape[-1])
+        return res_bits
+
+    @staticmethod
+    def backward(ctx, go_fp8):
+        x_fp8, w_fp8 = ctx.saved_tensors
+        emulate = ctx.emulate
+        mm_config = ctx.mm_config
+
+        go_fp8_orig_shape = go_fp8.shape
+        # go_fp8_reshaped = Float8Tensor(
+        #     go_fp8._data.reshape(-1, go_fp8_orig_shape[-1]),
+        #     go_fp8._scale,
+        #     go_fp8._orig_dtype,
+        #     mm_config,
+        # )
+        go_fp8_reshaped = go_fp8.reshape(-1, go_fp8_orig_shape[-1])
+
+        # w_fp8_t_c_t = Float8Tensor(
+        #     w_fp8._data.t().contiguous().t(), w_fp8._scale, w_fp8._orig_dtype, mm_config
+        # )
+        w_fp8_t_c_t = w_fp8.t().contiguous().t()
+
+        #
+        # calculate dL/dX
+        #
+        dL_dX = torch.mm(
+            go_fp8_reshaped,
+            w_fp8_t_c_t,
+        )
+        dL_dX = dL_dX.reshape(*go_fp8_orig_shape[:-1], dL_dX.shape[-1])
+
+        # x_fp8_orig_shape = x_fp8._data.shape
+        x_fp8_orig_shape = x_fp8.shape
+        # x_fp8_reshaped_t_c = Float8Tensor(
+        #     x_fp8._data.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous(),
+        #     x_fp8._scale,
+        #     x_fp8._orig_dtype,
+        #     mm_config,
+        # )
+        x_fp8_reshaped_t_c = x_fp8.reshape(-1, x_fp8_orig_shape[-1]).t().contiguous()
+
+        # go_fp8_reshaped_t_c_t = Float8Tensor(
+        #     go_fp8_reshaped._data.t().contiguous().t(),
+        #     go_fp8_reshaped._scale,
+        #     go_fp8_reshaped._orig_dtype,
+        #     mm_config,
+        # )
+        go_fp8_reshaped_t_c_t = go_fp8_reshaped.t().contiguous().t()
+
+        #
+        # calculate dL/dW
+        #
+        dL_dW = torch.mm(
+            x_fp8_reshaped_t_c,
+            go_fp8_reshaped_t_c_t,
+        )
+        dL_dW = dL_dW.t()
+
+        empty_grads = None, None, None, None, None, None, None, None, None
+        return dL_dX, dL_dW, *empty_grads
+
 
 @torch._dynamo.allow_in_graph
 class NoopFwToFloat8E5M2Bw(torch.autograd.Function):
@@ -394,7 +489,12 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         x_fp8 = self.cast_x_to_float8(input, self.is_amax_initialized)
         w_fp8 = self.cast_w_to_float8(self.weight, self.is_amax_initialized)
 
-        y = torch.matmul(x_fp8, w_fp8.t())
+        if not self.has_any_delayed_scaling:
+            emulate = False
+            mm_config = self.forward_config
+            y = float8_linear.apply(x_fp8, w_fp8, emulate, mm_config)
+        else:
+            y = torch.matmul(x_fp8, w_fp8.t())
 
         # Cast gradY to float8_e5m2 during backward
         y = self.cast_y_to_float8_in_bw(y)
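
As a sanity check (not part of the commit): the backward in the diff implements the standard linear-layer gradients, dL/dX = dY @ W and dL/dW = (X^T @ dY)^T. With plain float tensors, and ignoring the Float8Tensor, scaling, and reshape details, the same formulas match autograd:

```python
# Hedged sanity check with plain float64 tensors; mirrors the mm calls in
# float8_linear.backward but without Float8Tensor, scaling, or reshapes.
import torch

N, K, M = 4, 8, 16
x = torch.randn(N, K, dtype=torch.float64, requires_grad=True)
w = torch.randn(M, K, dtype=torch.float64, requires_grad=True)

y = torch.mm(x, w.t())       # same shape math as the forward above
go = torch.randn_like(y)     # stand-in for the incoming gradient
y.backward(go)

dL_dX = torch.mm(go, w)              # (N, M) @ (M, K) -> (N, K)
dL_dW = torch.mm(x.t(), go).t()      # ((K, N) @ (N, M)).t() -> (M, K)

assert torch.allclose(x.grad, dL_dX)
assert torch.allclose(w.grad, dL_dW)
```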
