@@ -50,7 +50,11 @@ def decorator(func):
def float8_desugar_op(aten_op, args, kwargs=None):
    new_data = aten_op(args[0]._data, *args[1:], **kwargs)
    return Float8Tensor(
-        new_data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config
+        new_data,
+        args[0]._scale,
+        args[0]._orig_dtype,
+        args[0]._mm_config,
+        args[0]._scaling_strategy,
    )

@@ -60,7 +64,11 @@ def float8_split(aten_op, args, kwargs=None):

    def make_float8(data):
        return Float8Tensor(
-            data, args[0]._scale, args[0]._orig_dtype, args[0]._mm_config
+            data,
+            args[0]._scale,
+            args[0]._orig_dtype,
+            args[0]._mm_config,
+            args[0]._scaling_strategy,
        )

    out = map(make_float8, new_data_tensors)
@@ -75,6 +83,7 @@ def float8_cat(aten_op, args, kwargs=None):
    orig_dtype = chunked_tensors[0]._orig_dtype
    scale = chunked_tensors[0]._scale
    mm_config = chunked_tensors[0]._mm_config
+    scaling_strategy = chunked_tensors[0]._scaling_strategy
    fp8_dtype = chunked_tensors[0]._data.dtype
    chunk_data = []
    for chunk in chunked_tensors:
@@ -93,11 +102,14 @@ def float8_cat(aten_op, args, kwargs=None):
        assert (
            chunk._data.dtype == fp8_dtype
        ), "Expecting all chunks to be of the same dtype as a result of a split"
+        assert (
+            chunk._scaling_strategy is scaling_strategy
+        ), "Expecting all chunks to have the same scaling strategy as a result of a split"
        chunk_data.append(chunk._data.view(torch.uint8))

    new_data = aten_op(chunk_data, *args[1:], **kwargs)
    new_data = new_data.view(fp8_dtype)
-    return Float8Tensor(new_data, scale, orig_dtype, mm_config)
+    return Float8Tensor(new_data, scale, orig_dtype, mm_config, scaling_strategy)


@implements([aten.sum.dim_IntList])
@@ -162,6 +174,11 @@ def float8_mm(aten_op, args, kwargs=None):
        return torch.ops.aten.mm_float8_emulated(
            a._data, a._scale, b._data, b._scale, output_dtype
        )
+    scaling_strategy = a._scaling_strategy
+    # TODO We can enable this by broadcasting to the more generic form
+    assert (
+        scaling_strategy == b._scaling_strategy
+    ), "Scaling strategies are currently required to be the same"
    tensor_out = addmm_float8_unwrapped(
        a_data,
        a_scale,
@@ -191,6 +208,11 @@ def float8_addmm(aten_op, args, kwargs=None):
    a_mm_config: ScaledMMConfig = a._mm_config
    b_mm_config: ScaledMMConfig = b._mm_config
    mm_config: ScaledMMConfig = merge_mm_configs(a_mm_config, b_mm_config)
+    scaling_strategy = a._scaling_strategy
+    # TODO We can enable this by broadcasting to the more generic form
+    assert (
+        scaling_strategy == b._scaling_strategy
+    ), "Scaling strategies are currently required to be the same"
    if mm_config.emulate:
        out = torch.ops.aten.mm_float8_emulated(
            a._data, a._scale, b._data, b._scale, output_dtype
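Note: both float8_mm and float8_addmm now perform the same guard before dispatching: the two operands must carry identical scaling strategies until broadcasting to a more general form is implemented (see the TODO above). A minimal sketch of that guard as a standalone helper follows; the helper name is hypothetical, the diff itself keeps the asserts inline, and it assumes the Float8Tensor attributes shown in this diff.

def _check_matching_scaling_strategy(a, b):
    # Hypothetical factoring of the inline asserts added to float8_mm/float8_addmm:
    # both Float8Tensor operands must share one scaling strategy for now.
    scaling_strategy = a._scaling_strategy
    assert (
        scaling_strategy == b._scaling_strategy
    ), "Scaling strategies are currently required to be the same"
    return scaling_strategy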
@@ -229,7 +251,11 @@ def autocast_to_copy(aten_op, args, kwargs=None):
        torch.bfloat16,
    }, "Only support floating point conversion for autocast w/ Float8Tensor"
    return Float8Tensor(
-        args[0]._data, args[0]._scale, kwargs["dtype"], args[0]._mm_config
+        args[0]._data,
+        args[0]._scale,
+        kwargs["dtype"],
+        args[0]._mm_config,
+        args[0]._scaling_strategy,
    )

@@ -252,7 +278,11 @@ def allgather_fp8(aten_op, args, kwargs=None):
    fp8_data = fp8_data.contiguous()
    fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
    return Float8Tensor(
-        fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
+        fp8_out,
+        fp8_input._scale,
+        fp8_input._orig_dtype,
+        fp8_input._mm_config,
+        fp8_input._scaling_strategy,
    )

@@ -264,7 +294,11 @@ def wait_tensor_fp8(aten_op, args, kwargs=None):
    fp8_data = fp8_input._data
    fp8_out = aten_op(fp8_data, *args[1:], **kwargs)
    return Float8Tensor(
-        fp8_out, fp8_input._scale, fp8_input._orig_dtype, fp8_input._mm_config
+        fp8_out,
+        fp8_input._scale,
+        fp8_input._orig_dtype,
+        fp8_input._mm_config,
+        fp8_input._scaling_strategy,
    )

@@ -282,7 +316,11 @@ def index_put_fp8(aten_op, args, kwargs=None):
    fp8_values_data = fp8_values._data
    fp8_out = aten_op(fp8_data, args[1], fp8_values_data, *args[3:], **kwargs)
    return Float8Tensor(
-        fp8_out, fp8_self._scale, fp8_self._orig_dtype, fp8_self._mm_config
+        fp8_out,
+        fp8_self._scale,
+        fp8_self._orig_dtype,
+        fp8_self._mm_config,
+        fp8_self._scaling_strategy,
    )

@@ -315,6 +353,12 @@ def copy_fp8(aten_op, args, kwargs=None):
            self._data.dtype == src._data.dtype
        ), "Expecting both Float8Tensors to be of the same dtype"
        fp8_out = aten_op(self._data, src._data, *args[2:], **kwargs)
-        return Float8Tensor(fp8_out, self._scale, self._orig_dtype, self._mm_config)
+        return Float8Tensor(
+            fp8_out,
+            self._scale,
+            self._orig_dtype,
+            self._mm_config,
+            self._scaling_strategy,
+        )
    else:
        raise RuntimeError("Unsupported semantics for copy_ in Float8Tensor")
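All of these handlers follow one propagation rule: any op that produces a new Float8Tensor from an existing one now forwards the input's _scale, _orig_dtype, _mm_config, and the newly threaded _scaling_strategy unchanged. A minimal sketch of that shared pattern follows, assuming the five-argument Float8Tensor constructor used throughout this diff; the helper name is hypothetical and does not appear in the change.

def _rewrap_like(src_fp8, new_data):
    # Hypothetical helper: wrap transformed raw data in a new Float8Tensor while
    # carrying over all metadata (scale, original dtype, mm config, scaling strategy)
    # from the source tensor, mirroring the constructor calls updated in this diff.
    return Float8Tensor(
        new_data,
        src_fp8._scale,
        src_fp8._orig_dtype,
        src_fp8._mm_config,
        src_fp8._scaling_strategy,
    )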