Commit 0ad098b

Revert "Fix nan value generated after custom all reduce (#8532)" (#8642)
1 parent 4a6e7a6 commit 0ad098b

File tree

1 file changed, 7 insertions(+), 7 deletions(-)

python/sglang/srt/distributed/device_communicators/custom_all_reduce.py

Lines changed: 7 additions & 7 deletions
@@ -184,7 +184,7 @@ def __init__(
             # 8*world_size bytes where world_size is at most 8. Allocating 8MB
             # is enough for 131072 such tuples. The largest model I've seen only
             # needs less than 10000 of registered tuples.
-            self.rank_data = torch.zeros(
+            self.rank_data = torch.empty(
                 8 * 1024 * 1024, dtype=torch.uint8, device=self.device
             )
             self._ptr = ops.init_custom_ar(
@@ -194,14 +194,14 @@ def __init__(
         else:
             # meta data buffers need to be "uncached" for signal on MI200
             self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
-            self.buffer = torch.zeros(max_size, dtype=torch.uint8, device=self.device)
+            self.buffer = torch.empty(max_size, dtype=torch.uint8, device=self.device)
             handle = ops.get_meta_buffer_ipc_handle(self.meta)
             shard_data = (
                 bytes(handle),  # ipc handle to base ptr
                 0,  # offset of base ptr
             )
             handles, offsets = self._gather_ipc_meta(shard_data)
-            self.rank_data = torch.zeros(
+            self.rank_data = torch.empty(
                 8 * 1024 * 1024, dtype=torch.uint8, device=self.device
             )
             self._ptr = ops.init_custom_ar(
@@ -350,14 +350,14 @@ def should_custom_ar(self, inp: torch.Tensor):
     # or, in the context of cuda graphs, register_graph_buffers
     def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
         if out is None:
-            out = torch.zeros_like(inp)
+            out = torch.empty_like(inp)
         ops.all_reduce_reg(self._ptr, inp, out)
         return out

     # all reduce, assuming inp tensor is NOT IPC registered
     def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
         if out is None:
-            out = torch.zeros_like(inp)
+            out = torch.empty_like(inp)
         ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
         return out

@@ -375,7 +375,7 @@ def all_reduce(
         buffer.
         """
         if out is None:
-            out = torch.zeros_like(inp)
+            out = torch.empty_like(inp)
         if registered:
             ops.all_reduce(self._ptr, inp, out, 0, 0)
         else:
@@ -398,7 +398,7 @@ def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
             else:
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
-                return torch.zeros_like(input)
+                return torch.empty_like(input)
         else:
             if _is_hip:
                 # note: outside of cuda graph context,
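For context: every hunk above swaps a zero-initialized allocation (torch.zeros / torch.zeros_like) back to an uninitialized one (torch.empty / torch.empty_like). torch.empty* skips the extra zero-fill kernel launch, which is only safe because the custom all-reduce is out-of-place and expected to overwrite every element of the output buffer before it is read. Below is a minimal sketch of that allocation pattern; the function name and the stand-in body are illustrative, not the actual sglang kernel call:

import torch

def all_reduce_out_of_place_sketch(inp: torch.Tensor, out: torch.Tensor = None) -> torch.Tensor:
    # Hypothetical stand-in for CustomAllreduce.all_reduce_unreg, for illustration only.
    if out is None:
        # empty_like allocates without zero-filling; the contents are garbage
        # (possibly NaN when reinterpreted as floats) until overwritten.
        out = torch.empty_like(inp)
    # Placeholder for ops.all_reduce_unreg(self._ptr, inp, self.buffer, out):
    # correctness hinges on the kernel writing every element of `out`.
    out.copy_(inp)
    return out

The trade-off: torch.zeros_like masks any element the kernel fails to write, at the cost of a fill kernel per call, while torch.empty_like makes such a miss visible as garbage or NaN in the output.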
