Commit fc3789e

nvcastetssssnow authored and committed
Add support for NCCL symmetric memory for TP allreduces (#8238)
1 parent 25ef410 commit fc3789e

13 files changed: +266 −30 lines changed

docs/backend/server_arguments.md

Lines changed: 1 addition & 0 deletions
@@ -251,6 +251,7 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--disable-cuda-graph-padding` | Disable cuda graph when padding is needed. Still uses cuda graph when padding is not needed. | False |
 | `--enable-profile-cuda-graph` | Enable profiling of cuda graph capture. | False |
 | `--enable-nccl-nvls` | Enable NCCL NVLS for prefill heavy requests when available. | False |
+| `--enable-symm-mem` | Enable NCCL symmetric memory for fast collectives. | False |
 | `--enable-tokenizer-batch-encode` | Enable batch tokenization for improved performance when processing multiple text inputs. Do not use with image inputs, pre-tokenized input_ids, or input_embeds. | False |
 | `--disable-outlines-disk-cache` | Disable disk cache of outlines to avoid possible crashes related to file system or high concurrency. | False |
 | `--disable-custom-all-reduce` | Disable the custom all-reduce kernel and fall back to NCCL. | False |
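A minimal way to exercise the new flag, sketched with the offline Engine entrypoint (this assumes Engine forwards ServerArgs fields as keyword arguments; the model path is a placeholder):

    # Hedged sketch: enable NCCL symmetric memory for tensor-parallel all-reduces.
    import sglang as sgl

    llm = sgl.Engine(
        model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
        tp_size=8,               # symmetric memory targets TP all-reduces
        enable_symm_mem=True,    # equivalent to --enable-symm-mem on the CLI
    )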

python/sglang/srt/distributed/device_communicators/pynccl.py

Lines changed: 7 additions & 0 deletions
@@ -75,6 +75,7 @@ def __init__(
         self.available = True
         self.disabled = False

+        self.nccl_version = self.nccl.ncclGetRawVersion()
         if self.rank == 0:
             logger.info("sglang is using nccl==%s", self.nccl.ncclGetVersion())

@@ -259,6 +260,12 @@ def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
             cudaStream_t(stream.cuda_stream),
         )

+    def register_comm_window_raw(self, ptr: int, size: int):
+        return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), size, 1)
+
+    def deregister_comm_window(self, window):
+        return self.nccl.ncclCommWindowDeregister(self.comm, window)
+
     @contextmanager
     def change_state(
         self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
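The two new helpers wrap NCCL window registration for buffers that already live in NCCL-allocated memory. A hypothetical call sequence (not part of the commit; `comm` is assumed to be an initialized PyNcclCommunicator, and `segment_addr`/`segment_size` would come from a memory-pool segment as in the allocator below):

    # Register one pool segment as an NCCL window, use it, then tear it down.
    window = comm.register_comm_window_raw(segment_addr, segment_size)
    # ... collectives on tensors placed in that segment can use symmetric memory ...
    comm.deregister_comm_window(window)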
python/sglang/srt/distributed/device_communicators/pynccl_allocator.py

Lines changed: 133 additions & 0 deletions

@@ -0,0 +1,133 @@
+import tempfile
+
+import torch
+from packaging import version
+from torch.cuda.memory import CUDAPluggableAllocator
+
+from sglang.srt.distributed.parallel_state import GroupCoordinator
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+nccl_allocator_source = """
+#include <nccl.h>
+extern "C" {
+
+void* nccl_alloc_plug(size_t size, int device, void* stream) {
+  void* ptr;
+  ncclResult_t err = ncclMemAlloc(&ptr, size);
+  return ptr;
+
+}
+
+void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
+  ncclResult_t err = ncclMemFree(ptr);
+}
+
+}
+"""
+
+_allocator = None
+_mem_pool = None
+_registered_base_addrs = set()
+_graph_pool_id = None
+
+
+def is_symmetric_memory_enabled():
+    return global_server_args_dict["enable_symm_mem"]
+
+
+def set_graph_pool_id(graph_pool_id):
+    global _graph_pool_id
+    _graph_pool_id = graph_pool_id
+
+
+def get_nccl_mem_pool():
+    global _allocator, _mem_pool
+    if _mem_pool is None:
+        out_dir = tempfile.gettempdir()
+        nccl_allocator_libname = "nccl_allocator"
+        torch.utils.cpp_extension.load_inline(
+            name=nccl_allocator_libname,
+            cpp_sources=nccl_allocator_source,
+            with_cuda=True,
+            extra_ldflags=["-lnccl"],
+            verbose=True,
+            is_python_module=False,
+            build_directory=out_dir,
+        )
+        _allocator = CUDAPluggableAllocator(
+            f"{out_dir}/{nccl_allocator_libname}.so",
+            "nccl_alloc_plug",
+            "nccl_free_plug",
+        ).allocator()
+        _mem_pool = torch.cuda.MemPool(_allocator)
+    return _mem_pool
+
+
+class use_symmetric_memory:
+    def __init__(self, group_coordinator: GroupCoordinator):
+        if not is_symmetric_memory_enabled():
+            self.group_coordinator = None
+            self._mem_pool_ctx = None
+            self.is_graph_capture = None
+            self.device = None
+            self.pre_2_8_0 = None
+        else:
+            self.group_coordinator = group_coordinator
+            self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
+            self.is_graph_capture = torch.cuda.is_current_stream_capturing()
+            self.device = torch.cuda.current_device()
+            self.pre_2_8_0 = version.parse(torch.__version__) < version.parse("2.8.0")
+
+    def __enter__(self):
+        if not is_symmetric_memory_enabled():
+            return self
+        assert (
+            self.group_coordinator.pynccl_comm is not None
+        ), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'"
+        assert (
+            self.group_coordinator.pynccl_comm.nccl_version >= 22703
+        ), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory"
+        if self.is_graph_capture:
+            assert (
+                _graph_pool_id is not None
+            ), "graph_pool_id is not set under graph capture"
+            # Pause graph memory pool to use symmetric memory with cuda graph
+            if self.pre_2_8_0:
+                torch._C._cuda_endAllocateCurrentStreamToPool(
+                    self.device, _graph_pool_id
+                )
+            else:
+                torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id)
+        self._mem_pool_ctx.__enter__()
+        return self
+
+    def tag(self, tensor: torch.Tensor):
+        if not is_symmetric_memory_enabled():
+            return
+        tensor.symmetric_memory = True
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if not is_symmetric_memory_enabled():
+            return
+        global _registered_base_addrs
+        self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
+        for segment in get_nccl_mem_pool().snapshot():
+            if segment["address"] not in _registered_base_addrs:
+                if segment["stream"] == 0 and self.pre_2_8_0:
+                    # PyTorch version < 2.8.0 has a multi-thread MemPool bug
+                    # See https://github.com/pytorch/pytorch/issues/152861
+                    # Fixed at https://github.com/pytorch/pytorch/commit/f01e628e3b31852983ab30b25bf251f557ba9c0b
+                    # WAR is to skip allocations on the default stream since the forward_pass thread always runs on a custom stream
+                    continue
+                self.group_coordinator.pynccl_comm.register_comm_window_raw(
+                    segment["address"], segment["total_size"]
+                )
+                _registered_base_addrs.add(segment["address"])
+
+        if self.is_graph_capture:
+            if self.pre_2_8_0:
+                torch._C._cuda_beginAllocateToPool(self.device, _graph_pool_id)
+            else:
+                torch._C._cuda_beginAllocateCurrentThreadToPool(
+                    self.device, _graph_pool_id
+                )
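Taken together, the allocator is meant to be used as a producer-side context manager: allocate the TP-partial output inside the NCCL mem pool, tag it, and let the subsequent all-reduce take the symmetric path. A minimal usage sketch, mirroring the call sites patched later in this commit (the import paths for the TP group and all-reduce helper are assumed, and `compute_partial_output` is a hypothetical stand-in for a layer's GEMM):

    from sglang.srt.distributed import parallel_state, tensor_model_parallel_all_reduce
    from sglang.srt.distributed.device_communicators.pynccl_allocator import (
        use_symmetric_memory,
    )

    with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
        # Tensors created here come from the NCCL mem pool; new pool segments
        # are window-registered with the TP communicator on __exit__.
        partial = compute_partial_output(x)  # hypothetical TP-partial result
        sm.tag(partial)                      # mark it for the symmetric all-reduce path
    out = tensor_model_parallel_all_reduce(partial)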

python/sglang/srt/distributed/device_communicators/pynccl_wrapper.py

Lines changed: 42 additions & 3 deletions
@@ -67,6 +67,7 @@ def find_nccl_library() -> str:

 ncclResult_t = ctypes.c_int
 ncclComm_t = ctypes.c_void_p
+ncclWindow_t = ctypes.c_void_p


 class ncclUniqueId(ctypes.Structure):
@@ -279,6 +280,23 @@ class NCCLLibrary:
         Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
     ]

+    exported_functions_symm_mem = [
+        # ncclResult_t ncclCommWindowRegister(ncclComm_t comm, void* buff, size_t size, ncclWindow_t* win, int winFlags);
+        Function(
+            "ncclCommWindowRegister",
+            ncclResult_t,
+            [
+                ncclComm_t,
+                buffer_type,
+                ctypes.c_size_t,
+                ctypes.POINTER(ncclWindow_t),
+                ctypes.c_int,
+            ],
+        ),
+        # ncclResult_t ncclCommWindowDeregister(ncclComm_t comm, ncclWindow_t win);
+        Function("ncclCommWindowDeregister", ncclResult_t, [ncclComm_t, ncclWindow_t]),
+    ]
+
     # class attribute to store the mapping from the path to the library
     # to avoid loading the same library multiple times
     path_to_library_cache: Dict[str, Any] = {}
@@ -312,7 +330,10 @@ def __init__(self, so_file: Optional[str] = None):

         if so_file not in NCCLLibrary.path_to_dict_mapping:
             _funcs: Dict[str, Any] = {}
-            for func in NCCLLibrary.exported_functions:
+            exported_functions = NCCLLibrary.exported_functions
+            if hasattr(self.lib, "ncclCommWindowRegister"):
+                exported_functions.extend(NCCLLibrary.exported_functions_symm_mem)
+            for func in exported_functions:
                 f = getattr(self.lib, func.name)
                 f.restype = func.restype
                 f.argtypes = func.argtypes
@@ -328,10 +349,14 @@ def NCCL_CHECK(self, result: ncclResult_t) -> None:
             error_str = self.ncclGetErrorString(result)
             raise RuntimeError(f"NCCL error: {error_str}")

-    def ncclGetVersion(self) -> str:
+    def ncclGetRawVersion(self) -> int:
         version = ctypes.c_int()
         self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
-        version_str = str(version.value)
+        # something like 21903
+        return version.value
+
+    def ncclGetVersion(self) -> str:
+        version_str = str(self.ncclGetRawVersion())
         # something like 21903 --> "2.19.3"
         major = version_str[0].lstrip("0")
         minor = version_str[1:3].lstrip("0")
@@ -460,6 +485,20 @@ def ncclBroadcast(
     def ncclCommDestroy(self, comm: ncclComm_t) -> None:
         self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))

+    def ncclCommWindowRegister(
+        self, comm: ncclComm_t, buff: buffer_type, size: int, win_flags: int
+    ) -> ncclWindow_t:
+        window = ncclWindow_t()
+        self.NCCL_CHECK(
+            self._funcs["ncclCommWindowRegister"](
+                comm, buff, size, ctypes.byref(window), win_flags
+            )
+        )
+        return window
+
+    def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None:
+        self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))
+

 __all__ = [
     "NCCLLibrary",

python/sglang/srt/distributed/parallel_state.py

Lines changed: 11 additions & 0 deletions
@@ -497,6 +497,17 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
         if self.npu_communicator is not None and not self.npu_communicator.disabled:
             return self.npu_communicator.all_reduce(input_)

+        if (
+            self.pynccl_comm is not None
+            and hasattr(input_, "symmetric_memory")
+            and input_.symmetric_memory
+        ):
+            with self.pynccl_comm.change_state(
+                enable=True, stream=torch.cuda.current_stream()
+            ):
+                self.pynccl_comm.all_reduce(input_)
+                return input_
+
         outplace_all_reduce_method = None
         if (
             self.qr_comm is not None
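In effect, the tag set by use_symmetric_memory is what selects this branch: a tagged tensor is reduced in place by pynccl over its registered window, while untagged tensors keep flowing through the existing custom and outplace all-reduce paths. A hedged illustration (`tp_group` is assumed to be the TP GroupCoordinator):

    out = tp_group.all_reduce(tagged_tensor)  # symmetric-memory pynccl all-reduce, in place
    out = tp_group.all_reduce(plain_tensor)   # falls through to the pre-existing paths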

python/sglang/srt/entrypoints/engine.py

Lines changed: 3 additions & 2 deletions
@@ -623,8 +623,9 @@ async def async_score(
 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-    os.environ["NCCL_CUMEM_ENABLE"] = "0"
-    os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
+    os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem))
+    if not server_args.enable_symm_mem:
+        os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
 os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
 os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
 os.environ["CUDA_MODULE_LOADING"] = "AUTO"

python/sglang/srt/layers/linear.py

Lines changed: 7 additions & 1 deletion
@@ -13,10 +13,14 @@
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     BlockQuantScaleParameter,
@@ -1292,7 +1296,9 @@ def forward(self, input_, can_fuse_mlp_allreduce=False):
         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in TP>1 case)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
-        output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+            sm.tag(output_parallel)
         if self.reduce_results and self.tp_size > 1 and not can_fuse_mlp_allreduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
         else:

python/sglang/srt/layers/moe/fused_moe_triton/layer.py

Lines changed: 25 additions & 18 deletions
@@ -14,8 +14,12 @@
     get_moe_expert_parallel_world_size,
     get_moe_tensor_parallel_rank,
     get_moe_tensor_parallel_world_size,
+    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.layers.moe.topk import StandardTopKOutput
 from sglang.srt.layers.quantization.base_config import (
@@ -626,24 +630,27 @@ def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
         )

         # Matrix multiply.
-        final_hidden_states = self.quant_method.apply(
-            layer=self,
-            x=hidden_states,
-            topk_output=topk_output,
-            activation=self.activation,
-            apply_router_weight_on_input=self.apply_router_weight_on_input,
-            routed_scaling_factor=self.routed_scaling_factor,
-            **(
-                dict(
-                    tp_rank=self.moe_tp_rank,
-                    tp_size=self.moe_tp_size,
-                    ep_rank=self.moe_ep_rank,
-                    ep_size=self.moe_ep_size,
-                )
-                if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
-                else {}
-            ),
-        )
+        with use_symmetric_memory(get_tp_group()) as sm:
+            final_hidden_states = self.quant_method.apply(
+                layer=self,
+                x=hidden_states,
+                topk_output=topk_output,
+                activation=self.activation,
+                apply_router_weight_on_input=self.apply_router_weight_on_input,
+                routed_scaling_factor=self.routed_scaling_factor,
+                **(
+                    dict(
+                        tp_rank=self.moe_tp_rank,
+                        tp_size=self.moe_tp_size,
+                        ep_rank=self.moe_ep_rank,
+                        ep_size=self.moe_ep_size,
+                    )
+                    if self.quant_method.__class__.__name__
+                    == "ModelOptNvFp4FusedMoEMethod"
+                    else {}
+                ),
+            )
+            sm.tag(final_hidden_states)

         if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

python/sglang/srt/layers/vocab_parallel_embedding.py

Lines changed: 7 additions & 1 deletion
@@ -11,8 +11,12 @@
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter
@@ -464,7 +468,9 @@ def forward(self, input_):
         else:
             masked_input = input_
         # Get the embeddings.
-        output_parallel = self.quant_method.embedding(self, masked_input.long())
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            output_parallel = self.quant_method.embedding(self, masked_input.long())
+            sm.tag(output_parallel)
         # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)

python/sglang/srt/managers/schedule_batch.py

Lines changed: 1 addition & 0 deletions
@@ -108,6 +108,7 @@
     "weight_loader_disable_mmap",
     "enable_triton_kernel_moe",
     "enable_multimodal",
+    "enable_symm_mem",
 ]

 # Put some global args for easy access
