@@ -184,7 +184,7 @@ def __init__(
             # 8*world_size bytes where world_size is at most 8. Allocating 8MB
             # is enough for 131072 such tuples. The largest model I've seen only
             # needs less than 10000 of registered tuples.
-            self.rank_data = torch.zeros(
+            self.rank_data = torch.empty(
                 8 * 1024 * 1024, dtype=torch.uint8, device=self.device
             )
             self._ptr = ops.init_custom_ar(
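The sizing comment above checks out: each registered tuple costs at most 8 bytes per rank and world_size is capped at 8, so one tuple never exceeds 64 bytes. A quick stand-alone check (plain Python, no vLLM dependency):

    bytes_per_tuple = 8 * 8                    # 8-byte pointer per rank, world_size <= 8
    rank_data_bytes = 8 * 1024 * 1024          # the 8MB rank_data allocation
    print(rank_data_bytes // bytes_per_tuple)  # 131072 tuples, matching the comment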
@@ -194,14 +194,14 @@ def __init__(
         else:
             # meta data buffers need to be "uncached" for signal on MI200
             self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
-            self.buffer = torch.zeros(max_size, dtype=torch.uint8, device=self.device)
+            self.buffer = torch.empty(max_size, dtype=torch.uint8, device=self.device)
             handle = ops.get_meta_buffer_ipc_handle(self.meta)
             shard_data = (
                 bytes(handle),  # ipc handle to base ptr
                 0,  # offset of base ptr
             )
             handles, offsets = self._gather_ipc_meta(shard_data)
-            self.rank_data = torch.zeros(
+            self.rank_data = torch.empty(
                 8 * 1024 * 1024, dtype=torch.uint8, device=self.device
             )
             self._ptr = ops.init_custom_ar(
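Dropping the zero-fill is safe here on the assumption that self.buffer is a staging area: input data is copied into it before the reduction reads it, so its initial contents are never observed. A minimal stand-alone sketch of that invariant (plain torch, no custom ops):

    import torch

    staging = torch.empty(16, dtype=torch.uint8)   # uninitialized, like self.buffer
    payload = torch.arange(16, dtype=torch.uint8)
    staging.copy_(payload)                         # full overwrite precedes any read
    assert torch.equal(staging, payload)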
@@ -350,14 +350,14 @@ def should_custom_ar(self, inp: torch.Tensor):
     # or, in the context of cuda graphs, register_graph_buffers
     def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
         if out is None:
-            out = torch.zeros_like(inp)
+            out = torch.empty_like(inp)
         ops.all_reduce_reg(self._ptr, inp, out)
         return out

     # all reduce, assuming inp tensor is NOT IPC registered
     def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
         if out is None:
-            out = torch.zeros_like(inp)
+            out = torch.empty_like(inp)
         ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
         return out

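With the default out=None, both helpers now hand the kernel an uninitialized tensor; the switch to empty_like is correct on the assumption that the reduction writes every element of out, making the zeros_like fill pure overhead. A pure-torch stand-in for the out-of-place pattern (torch.add plays the role of ops.all_reduce_*):

    import torch

    def reduce_out_of_place(inp: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
        if out is None:
            out = torch.empty_like(inp)  # no zero-fill; the kernel writes every element
        torch.add(inp, inp, out=out)     # stand-in for the real reduction kernel
        return out

    print(reduce_out_of_place(torch.ones(4)))  # tensor([2., 2., 2., 2.])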
@@ -375,7 +375,7 @@ def all_reduce(
         buffer.
         """
         if out is None:
-            out = torch.zeros_like(inp)
+            out = torch.empty_like(inp)
         if registered:
             ops.all_reduce(self._ptr, inp, out, 0, 0)
         else:
@@ -398,7 +398,7 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
             else:
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
-                return torch.zeros_like(input)
+                return torch.empty_like(input)
         else:
             if _is_hip:
                 # note: outside of cuda graph context,
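Per the comment in this hunk, the warm-up branch only has to reproduce the allocation pattern of the real out-of-place allreduce so that CUDA graph capture later sees the same sequence of allocations; warm-up results are presumably discarded, so an uninitialized tensor suffices and the zeros_like fill kernel was wasted work on every warm-up call. A hypothetical stand-alone sketch of that shape of logic (not vLLM's actual API):

    import torch

    def allreduce_or_mimic(inp: torch.Tensor, warming_up: bool) -> torch.Tensor:
        if warming_up:
            return torch.empty_like(inp)  # allocate only; warm-up output is discarded
        out = torch.empty_like(inp)
        out.copy_(inp)                    # stand-in for the real reduction kernel
        return out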