|
| 1 | +import tempfile |
| 2 | + |
| 3 | +import torch |
| 4 | +from packaging import version |
| 5 | +from torch.cuda.memory import CUDAPluggableAllocator |
| 6 | + |
| 7 | +from sglang.srt.distributed.parallel_state import GroupCoordinator |
| 8 | +from sglang.srt.managers.schedule_batch import global_server_args_dict |
| 9 | + |
| 10 | +nccl_allocator_source = """ |
| 11 | +#include <nccl.h> |
| 12 | +extern "C" { |
| 13 | +
|
| 14 | +void* nccl_alloc_plug(size_t size, int device, void* stream) { |
| 15 | + void* ptr; |
| 16 | + ncclResult_t err = ncclMemAlloc(&ptr, size); |
| 17 | + return ptr; |
| 18 | +
|
| 19 | +} |
| 20 | +
|
| 21 | +void nccl_free_plug(void* ptr, size_t size, int device, void* stream) { |
| 22 | + ncclResult_t err = ncclMemFree(ptr); |
| 23 | +} |
| 24 | +
|
| 25 | +} |
| 26 | +""" |
| 27 | + |
| 28 | +_allocator = None |
| 29 | +_mem_pool = None |
| 30 | +_registered_base_addrs = set() |
| 31 | +_graph_pool_id = None |
| 32 | + |
| 33 | + |
| 34 | +def is_symmetric_memory_enabled(): |
| 35 | + return global_server_args_dict["enable_symm_mem"] |
| 36 | + |
| 37 | + |
| 38 | +def set_graph_pool_id(graph_pool_id): |
| 39 | + global _graph_pool_id |
| 40 | + _graph_pool_id = graph_pool_id |
| 41 | + |
| 42 | + |
| 43 | +def get_nccl_mem_pool(): |
| 44 | + global _allocator, _mem_pool |
| 45 | + if _mem_pool is None: |
| 46 | + out_dir = tempfile.gettempdir() |
| 47 | + nccl_allocator_libname = "nccl_allocator" |
| 48 | + torch.utils.cpp_extension.load_inline( |
| 49 | + name=nccl_allocator_libname, |
| 50 | + cpp_sources=nccl_allocator_source, |
| 51 | + with_cuda=True, |
| 52 | + extra_ldflags=["-lnccl"], |
| 53 | + verbose=True, |
| 54 | + is_python_module=False, |
| 55 | + build_directory=out_dir, |
| 56 | + ) |
| 57 | + _allocator = CUDAPluggableAllocator( |
| 58 | + f"{out_dir}/{nccl_allocator_libname}.so", |
| 59 | + "nccl_alloc_plug", |
| 60 | + "nccl_free_plug", |
| 61 | + ).allocator() |
| 62 | + _mem_pool = torch.cuda.MemPool(_allocator) |
| 63 | + return _mem_pool |
| 64 | + |
| 65 | + |
| 66 | +class use_symmetric_memory: |
| 67 | + def __init__(self, group_coordinator: GroupCoordinator): |
| 68 | + if not is_symmetric_memory_enabled(): |
| 69 | + self.group_coordinator = None |
| 70 | + self._mem_pool_ctx = None |
| 71 | + self.is_graph_capture = None |
| 72 | + self.device = None |
| 73 | + self.pre_2_8_0 = None |
| 74 | + else: |
| 75 | + self.group_coordinator = group_coordinator |
| 76 | + self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool()) |
| 77 | + self.is_graph_capture = torch.cuda.is_current_stream_capturing() |
| 78 | + self.device = torch.cuda.current_device() |
| 79 | + self.pre_2_8_0 = version.parse(torch.__version__) < version.parse("2.8.0") |
| 80 | + |
| 81 | + def __enter__(self): |
| 82 | + if not is_symmetric_memory_enabled(): |
| 83 | + return self |
| 84 | + assert ( |
| 85 | + self.group_coordinator.pynccl_comm is not None |
| 86 | + ), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'" |
| 87 | + assert ( |
| 88 | + self.group_coordinator.pynccl_comm.nccl_version >= 22703 |
| 89 | + ), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory" |
| 90 | + if self.is_graph_capture: |
| 91 | + assert ( |
| 92 | + _graph_pool_id is not None |
| 93 | + ), "graph_pool_id is not set under graph capture" |
| 94 | + # Pause graph memory pool to use symmetric memory with cuda graph |
| 95 | + if self.pre_2_8_0: |
| 96 | + torch._C._cuda_endAllocateCurrentStreamToPool( |
| 97 | + self.device, _graph_pool_id |
| 98 | + ) |
| 99 | + else: |
| 100 | + torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id) |
| 101 | + self._mem_pool_ctx.__enter__() |
| 102 | + return self |
| 103 | + |
| 104 | + def tag(self, tensor: torch.Tensor): |
| 105 | + if not is_symmetric_memory_enabled(): |
| 106 | + return |
| 107 | + tensor.symmetric_memory = True |
| 108 | + |
| 109 | + def __exit__(self, exc_type, exc_val, exc_tb): |
| 110 | + if not is_symmetric_memory_enabled(): |
| 111 | + return |
| 112 | + global _registered_base_addrs |
| 113 | + self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb) |
| 114 | + for segment in get_nccl_mem_pool().snapshot(): |
| 115 | + if segment["address"] not in _registered_base_addrs: |
| 116 | + if segment["stream"] == 0 and self.pre_2_8_0: |
| 117 | + # PyTorch version < 2.8.0 has a multi-thread MemPool bug |
| 118 | + # See https://github.com/pytorch/pytorch/issues/152861 |
| 119 | + # Fixed at https://github.com/pytorch/pytorch/commit/f01e628e3b31852983ab30b25bf251f557ba9c0b |
| 120 | + # WAR is to skip allocations on the default stream since the forward_pass thread always runs on a custom stream |
| 121 | + continue |
| 122 | + self.group_coordinator.pynccl_comm.register_comm_window_raw( |
| 123 | + segment["address"], segment["total_size"] |
| 124 | + ) |
| 125 | + _registered_base_addrs.add(segment["address"]) |
| 126 | + |
| 127 | + if self.is_graph_capture: |
| 128 | + if self.pre_2_8_0: |
| 129 | + torch._C._cuda_beginAllocateToPool(self.device, _graph_pool_id) |
| 130 | + else: |
| 131 | + torch._C._cuda_beginAllocateCurrentThreadToPool( |
| 132 | + self.device, _graph_pool_id |
| 133 | + ) |
0 commit comments