
Commit 5df12a4

tjruwase, loadams, and jomayeri authored
DeepNVMe tutorial (#6449)
Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: jomayeri <deepspeed@H100-VM2.shlnn55tgwve1eacvp21ie45dg.jx.internal.cloudapp.net>
1 parent cfc6ed3 commit 5df12a4

File tree

7 files changed: +257 -4 lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
```diff
@@ -76,7 +76,7 @@ repos:
     name: check-torchcuda
     entry: ./scripts/check-torchcuda.py
     language: python
-    exclude: ^(.github/workflows/|scripts/check-torchcuda.py|docs/_tutorials/accelerator-abstraction-interface.md|accelerator/cuda_accelerator.py|deepspeed/inference/engine.py|deepspeed/model_implementations/transformers/clip_encoder.py|deepspeed/model_implementations/diffusers/vae.py|deepspeed/model_implementations/diffusers/unet.py|op_builder/spatial_inference.py|op_builder/transformer_inference.py|op_builder/builder.py|setup.py|tests/unit/ops/sparse_attention/test_sparse_attention.py)
+    exclude: ^(.github/workflows/|scripts/check-torchcuda.py|docs/_tutorials/accelerator-abstraction-interface.md|docs/_tutorials/deepnvme.md|accelerator/cuda_accelerator.py|deepspeed/inference/engine.py|deepspeed/model_implementations/transformers/clip_encoder.py|deepspeed/model_implementations/diffusers/vae.py|deepspeed/model_implementations/diffusers/unet.py|op_builder/spatial_inference.py|op_builder/transformer_inference.py|op_builder/builder.py|setup.py|tests/unit/ops/sparse_attention/test_sparse_attention.py)
     # Specific deepspeed/ files are excluded for now until we wrap ProcessGroup in deepspeed.comm

 - repo: local
```

blogs/deepspeed-gds/README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -17,7 +17,7 @@ this problem, DeepSpeed has created a suite of I/O optimizations collectively ca
 DeepNVMe improves the performance and efficiency of I/O-bound DL applications by accelerating I/O operations
 and reducing hardware requirements. It achieves this by leveraging storage innovations such as Non-Volatile
-Memory Express (NVMe) Solid Storage Devices (SSDs) and NVIDIA Magnum IO<sup>TM</sup> GPUDirect® Storage (GDS). In this
+Memory Express (NVMe) Solid State Drives (SSDs) and NVIDIA Magnum IO<sup>TM</sup> GPUDirect® Storage (GDS). In this
 blog we show the benefits of DeepNVMe using microbenchmarks and an inference application. In experiments
 conducted on an Azure NC96ads\_A100\_v4 VM, we observed that DeepNVMe saturates available NVMe bandwidth for
 data transfers with GPU or CPU memory, achieving up to 10GB/sec reads and 5 GB/secs writes.
```

docs/_data/navigation.yml

Lines changed: 6 additions & 0 deletions
```diff
@@ -55,6 +55,10 @@ lnav:
     url: /getting-started/
   - title: 'Getting started on Azure'
     url: /tutorials/azure/
+  - title: 'Accelerator Abstraction'
+    url: /tutorials/accelerator-abstraction-interface/
+  - title: 'Accelerator Setup Guides'
+    url: /tutorials/accelerator-setup-guide/
   - title: 'Automatic Tensor Parallelism'
     url: /tutorials/automatic-tensor-parallelism/
   - title: 'Autotuning'
@@ -69,6 +73,8 @@ lnav:
     url: /tutorials/curriculum-learning/
   - title: 'Data Efficiency'
     url: /tutorials/data-efficiency/
+  - title: 'DeepNVMe'
+    url: /tutorials/deepnvme/
   - title: 'DS4Sci_EvoformerAttention'
     url: /tutorials/ds4sci_evoformerattention/
   - title: 'Flops Profiler'
```

docs/_tutorials/accelerator-abstraction-interface.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 ---
 title: DeepSpeed Accelerator Abstraction Interface
-tags: getting-started
+tags: getting-started training accelerator
 ---

 # Contents
```

docs/_tutorials/accelerator-setup-guide.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,6 +1,6 @@
 ---
 title: DeepSpeed Accelerator Setup Guides
-tags: getting-started
+tags: getting-started training accelerator
 ---

 # Contents
```

docs/_tutorials/deepnvme.md

Lines changed: 247 additions & 0 deletions
---
title: "DeepNVMe"
tags: training inference IO large-model
---

This tutorial will show how to use [DeepNVMe](https://github.com/microsoft/DeepSpeed/blob/master/blogs/deepspeed-gds/README.md) for data transfers between persistent storage and tensors residing in host or device memory. DeepNVMe improves the performance and efficiency of I/O operations in Deep Learning applications through powerful optimizations built on Non-Volatile Memory Express (NVMe) Solid State Drives (SSDs), Linux Asynchronous I/O (`libaio`), and NVIDIA Magnum IO<sup>TM</sup> GPUDirect® Storage (GDS).

## Requirements

Ensure your environment is properly configured to use DeepNVMe. First, install DeepSpeed version >= [0.15.0](https://github.com/microsoft/DeepSpeed/releases/tag/v0.15.0). Next, ensure that the DeepNVMe operators are available in the DeepSpeed installation. The `async_io` operator is required for any DeepNVMe functionality, while the `gds` operator is required only for GDS functionality. You can confirm the availability of each operator by checking that its compatibility status in the `ds_report` output is <span style="color:green">[OKAY]</span>. Below is a snippet of `ds_report` output confirming the availability of both the `async_io` and `gds` operators.

![deepnvme_ops_report](/assets/images/deepnvme_ops_report.png)

If the `async_io` operator is unavailable, you will need to install the appropriate `libaio` library binaries for your Linux flavor. For example, Ubuntu users will need to run `apt install libaio-dev`. In general, you should carefully inspect the `ds_report` output for helpful tips such as the following:

```bash
[WARNING] async_io requires the dev libaio .so object and headers but these were not found.
[WARNING] async_io: please install the libaio-dev package with apt
[WARNING] If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
```

To enable the `gds` operator, you will need to install NVIDIA GDS by consulting the appropriate guide for [bare-metal systems](https://docs.nvidia.com/gpudirect-storage/troubleshooting-guide/index.html) or Azure VMs (coming soon).
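
You can also probe operator availability programmatically. The following is a minimal sketch using the op builders' `is_compatible()` check, which runs the same compatibility probe that `ds_report` reports on:

```python
# Minimal sketch: probe DeepNVMe operator availability from Python.
# is_compatible() performs the same checks surfaced by ds_report.
from deepspeed.ops.op_builder import AsyncIOBuilder, GDSBuilder

print("async_io available:", AsyncIOBuilder().is_compatible())
print("gds available:", GDSBuilder().is_compatible())
```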

## Creating DeepNVMe Handles

DeepNVMe functionality can be accessed through two abstractions: `aio_handle` and `gds_handle`. The `aio_handle` is usable on both host and device tensors, while `gds_handle` works only on CUDA tensors but is more efficient for them. The first step in using DeepNVMe is to create the desired handle. `aio_handle` requires the `async_io` operator, while `gds_handle` requires both the `async_io` and `gds` operators. The following snippets illustrate `aio_handle` and `gds_handle` creation respectively.

```python
### Create aio_handle
from deepspeed.ops.op_builder import AsyncIOBuilder
aio_handle = AsyncIOBuilder().load().aio_handle()
```

```python
### Create gds_handle
from deepspeed.ops.op_builder import GDSBuilder
gds_handle = GDSBuilder().load().gds_handle()
```

For simplicity, the above examples illustrate handle creation using default parameters. We expect handles created with default parameters to provide good performance in most environments. However, see [below](#advanced-handle-creation) for advanced handle creation.

## Using DeepNVMe Handles

`aio_handle` and `gds_handle` provide identical APIs for storing tensors to files or loading tensors from files. A common feature of these APIs is that they take a tensor and a file path as arguments for the desired I/O operation. For best performance, pinned device or host tensors should be used for I/O operations (see [here](#pinned-tensors) for details). For brevity, this tutorial uses `aio_handle` for illustration, but keep in mind that `gds_handle` works similarly.

You can see the available APIs in a Python shell via tab completion on an `aio_handle` object. This is illustrated below using tab completion of `h.`.

```bash
> python
Python 3.10.12 (main, Jul 29 2024, 16:56:48) [GCC 11.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from deepspeed.ops.op_builder import AsyncIOBuilder
>>> h = AsyncIOBuilder().load().aio_handle()
>>> h.
h.async_pread(   h.free_cpu_locked_tensor(  h.get_overlap_events(  h.get_single_submit(  h.new_cpu_locked_tensor(  h.pwrite(  h.sync_pread(   h.wait(
h.async_pwrite(  h.get_block_size(          h.get_queue_depth(     h.get_thread_count(   h.pread(                  h.read(    h.sync_pwrite(  h.write(
```

The APIs of interest for performing I/O operations are those whose names contain the `pread` and `pwrite` substrings. For brevity, we will focus on the file write APIs, namely `sync_pwrite`, `async_pwrite`, and `pwrite`. We discuss only `sync_pwrite` and `async_pwrite` below because they are specializations of `pwrite`.
56+
57+
### Blocking File Write
58+
`sync_pwrite` provides the standard blocking semantics of Python file write. The example below illustrates using `sync_pwrite` to store a 1GB CUDA tensor to a local NVMe file.
59+
60+
```bash
61+
>>> import os
62+
>>> os.path.isfile('/local_nvme/test_1GB.pt')
63+
False
64+
>>> import torch
65+
>>> t=torch.empty(1024**3, dtype=torch.uint8).cuda()
66+
>>> from deepspeed.ops.op_builder import AsyncIOBuilder
67+
>>> h = AsyncIOBuilder().load().aio_handle()
68+
>>> h.sync_pwrite(t,'/local_nvme/test_1GB.pt')
69+
>>> os.path.isfile('/local_nvme/test_1GB.pt')
70+
True
71+
>>> os.path.getsize('/local_nvme/test_1GB.pt')
72+
1073741824
73+
74+
```
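
Although this tutorial focuses on the write APIs, the read APIs mirror them. As a minimal sketch (assuming the 1GB file written above exists), `sync_pread` fills a pre-allocated destination tensor from a file:

```python
# Minimal sketch: blocking file read into a pre-allocated tensor.
# The destination tensor's size determines how many bytes are read.
import torch
from deepspeed.ops.op_builder import AsyncIOBuilder

h = AsyncIOBuilder().load().aio_handle()
dst = torch.empty(1024**3, dtype=torch.uint8).pin_memory()  # pinned for best performance
h.sync_pread(dst, '/local_nvme/test_1GB.pt')
```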

### Non-Blocking File Write

An important DeepNVMe optimization is non-blocking I/O, which enables Python threads to overlap computation with I/O operations. `async_pwrite` provides the non-blocking semantics for file writes; the calling thread can later use `wait()` to synchronize with the I/O operation. `async_pwrite` can also be used to submit multiple back-to-back non-blocking I/O operations, which can then be completed later with a single `wait()`. The example below illustrates using `async_pwrite` to store a 1GB CUDA tensor to a local NVMe file.

```bash
>>> import os
>>> os.path.isfile('/local_nvme/test_1GB.pt')
False
>>> import torch
>>> t = torch.empty(1024**3, dtype=torch.uint8).cuda()
>>> from deepspeed.ops.op_builder import AsyncIOBuilder
>>> h = AsyncIOBuilder().load().aio_handle()
>>> h.async_pwrite(t, '/local_nvme/test_1GB.pt')
>>> h.wait()
1
>>> os.path.isfile('/local_nvme/test_1GB.pt')
True
>>> os.path.getsize('/local_nvme/test_1GB.pt')
1073741824
```

<span style="color:red">Warning for non-blocking I/O operations:</span> To avoid data races and corruption, `wait()` must be used carefully to serialize writes of source tensors and reads of destination tensors. For example, the following update of `t` during a non-blocking file write is unsafe and could corrupt `/local_nvme/test_1GB.pt`.

```bash
>>> t = torch.empty(1024**3, dtype=torch.uint8).cuda()
>>> from deepspeed.ops.op_builder import AsyncIOBuilder
>>> h = AsyncIOBuilder().load().aio_handle()
>>> h.async_pwrite(t, '/local_nvme/test_1GB.pt')
>>> t += 1  # <--- Data race; avoid by preceding with h.wait()
```

Similar safety problems apply to reading the destination tensor of a non-blocking file read without `wait()` synchronization.
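
For completeness, the following sketch shows the safe pattern this warning implies: overlap the in-flight write only with computation that does not touch the source tensor, and call `wait()` before reusing it. The matrix multiply here is an illustrative stand-in for useful work.

```python
# Sketch: safely overlapping computation with a non-blocking write.
import torch
from deepspeed.ops.op_builder import AsyncIOBuilder

h = AsyncIOBuilder().load().aio_handle()
t = torch.empty(1024**3, dtype=torch.uint8).cuda()
u = torch.randn(4096, 4096, device='cuda')

h.async_pwrite(t, '/local_nvme/test_1GB.pt')  # write of t is now in flight
v = u @ u        # safe overlap: this computation does not read or write t
h.wait()         # serialize with the in-flight write
t += 1           # safe: the write has completed
```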

### Parallel File Write

An important DeepNVMe optimization is the ability to parallelize individual I/O operations. This optimization is enabled by specifying the desired parallelism degree when constructing a DeepNVMe handle. Subsequent I/O operations with that handle are automatically parallelized over the requested number of host or device threads, as appropriate. I/O parallelism is composable with either the blocking or non-blocking I/O APIs. The example below illustrates 4-way parallelism of a file write using `async_pwrite`. Note the use of the `num_threads` argument to specify the desired parallelism degree during handle creation.

```bash
>>> import os
>>> os.path.isfile('/local_nvme/test_1GB.pt')
False
>>> import torch
>>> t = torch.empty(1024**3, dtype=torch.uint8).cuda()
>>> from deepspeed.ops.op_builder import AsyncIOBuilder
>>> h = AsyncIOBuilder().load().aio_handle(num_threads=4)
>>> h.async_pwrite(t, '/local_nvme/test_1GB.pt')
>>> h.wait()
1
>>> os.path.isfile('/local_nvme/test_1GB.pt')
True
>>> os.path.getsize('/local_nvme/test_1GB.pt')
1073741824
```

### Pinned Tensors

A key part of DeepNVMe optimization is the use of direct memory access (DMA) for I/O operations, which requires that the host or device tensor be pinned. To pin host tensors, you can use mechanisms provided by [Pytorch](https://pytorch.org/docs/stable/generated/torch.Tensor.pin_memory.html) or [DeepSpeed Accelerators](/tutorials/accelerator-abstraction-interface/#tensor-operations). The following example illustrates writing a pinned CPU tensor to a local NVMe file.

```bash
>>> import os
>>> os.path.isfile('/local_nvme/test_1GB.pt')
False
>>> import torch
>>> t = torch.empty(1024**3, dtype=torch.uint8).pin_memory()
>>> from deepspeed.ops.op_builder import AsyncIOBuilder
>>> h = AsyncIOBuilder().load().aio_handle()
>>> h.async_pwrite(t, '/local_nvme/test_1GB.pt')
>>> h.wait()
1
>>> os.path.isfile('/local_nvme/test_1GB.pt')
True
>>> os.path.getsize('/local_nvme/test_1GB.pt')
1073741824
```
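
As a sketch of the DeepSpeed Accelerators alternative mentioned above (assuming the `pin_memory()` tensor operation described in the linked accelerator interface), host pinning looks like this:

```python
# Sketch: pin a host tensor via the DeepSpeed accelerator abstraction,
# then write it with a blocking pwrite. Assumes get_accelerator().pin_memory()
# per the accelerator interface linked above.
import torch
from deepspeed.accelerator import get_accelerator
from deepspeed.ops.op_builder import AsyncIOBuilder

t = get_accelerator().pin_memory(torch.empty(1024**3, dtype=torch.uint8))
h = AsyncIOBuilder().load().aio_handle()
h.sync_pwrite(t, '/local_nvme/test_1GB.pt')
```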

On the other hand, `gds_handle` provides the `new_pinned_device_tensor()` and `pin_device_tensor()` functions for pinning CUDA tensors. The following example illustrates writing a pinned CUDA tensor to a local NVMe file.

```bash
>>> import os
>>> os.path.isfile('/local_nvme/test_1GB.pt')
False
>>> import torch
>>> t = torch.empty(1024**3, dtype=torch.uint8).cuda()
>>> from deepspeed.ops.op_builder import GDSBuilder
>>> h = GDSBuilder().load().gds_handle()
>>> h.pin_device_tensor(t)
>>> h.async_pwrite(t, '/local_nvme/test_1GB.pt')
>>> h.wait()
1
>>> os.path.isfile('/local_nvme/test_1GB.pt')
True
>>> os.path.getsize('/local_nvme/test_1GB.pt')
1073741824
>>> h.unpin_device_tensor(t)
```

## Putting it together

We hope the above material helps you get started with DeepNVMe. You can also use the following links to see DeepNVMe usage in real-world Deep Learning applications:

1. [Parameter swapper](https://github.com/microsoft/DeepSpeed/blob/9b7fc5452471392b0f58844219fcfdd14a9cdc77/deepspeed/runtime/swap_tensor/partitioned_param_swapper.py#L111-L117) in [ZeRO-Inference](https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/zero_inference/README.md) and [ZeRO-Infinity](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/).
2. [Optimizer swapper](https://github.com/microsoft/DeepSpeed/blob/9b7fc5452471392b0f58844219fcfdd14a9cdc77/deepspeed/runtime/swap_tensor/partitioned_optimizer_swapper.py#L36-L38) in [ZeRO-Infinity](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/).
3. [Gradient swapper](https://github.com/microsoft/DeepSpeed/blob/9b7fc5452471392b0f58844219fcfdd14a9cdc77/deepspeed/runtime/swap_tensor/partitioned_optimizer_swapper.py#L41-L43) in [ZeRO-Infinity](https://www.microsoft.com/en-us/research/blog/zero-infinity-and-deepspeed-unlocking-unprecedented-model-scale-for-deep-learning-training/).
4. Simple file read and write [operations](https://github.com/microsoft/DeepSpeedExamples/blob/master/deepnvme/file_access/README.md).

## Acknowledgements

This tutorial has been significantly improved by feedback from [Guanhua Wang](https://github.com/GuanhuaWang), [Masahiro Tanaka](https://github.com/tohtana), and [Stas Bekman](https://github.com/stas00).

## Appendix

### Advanced Handle Creation

Achieving peak I/O performance with DeepNVMe requires careful configuration of handle creation. In particular, the parameters of the `aio_handle` and `gds_handle` constructors are performance-critical because they determine how efficiently DeepNVMe interacts with the underlying storage subsystem (i.e., `libaio`, GDS, and the SSD). For convenience, handles can be created with default parameter values, which provide decent performance in most scenarios. However, squeezing out the best performance of your environment will likely require tuning the constructor parameters, namely `block_size`, `queue_depth`, `single_submit`, `overlap_events`, and `num_threads`. The `aio_handle` constructor parameters and default values are illustrated below:

```bash
>>> from deepspeed.ops.op_builder import AsyncIOBuilder
>>> help(AsyncIOBuilder().load().aio_handle())
Help on aio_handle in module async_io object:

class aio_handle(pybind11_builtins.pybind11_object)
 |  Method resolution order:
 |      aio_handle
 |      pybind11_builtins.pybind11_object
 |      builtins.object
 |
 |  Methods defined here:
 |
 |  __init__(...)
 |      __init__(self: async_io.aio_handle, block_size: int = 1048576, queue_depth: int = 128, single_submit: bool = False, overlap_events: bool = False, num_threads: int = 1) -> None
 |
 |      AIO handle constructor
```
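
As an illustration, here is a sketch of creating a handle with explicit constructor values. The numbers are illustrative assumptions, not recommendations; optimal settings depend on your SSDs and workload.

```python
# Sketch: create an aio_handle with explicit (illustrative) tuning values.
# Profile your own storage stack before settling on any of these.
from deepspeed.ops.op_builder import AsyncIOBuilder

h = AsyncIOBuilder().load().aio_handle(block_size=8 * 1024 * 1024,  # 8MB I/O blocks
                                       queue_depth=32,
                                       single_submit=False,
                                       overlap_events=True,
                                       num_threads=2)
```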

### DeepNVMe APIs

For convenience, we provide a listing and brief descriptions of the DeepNVMe APIs.

#### General I/O APIs

The following functions are used for I/O operations with both `aio_handle` and `gds_handle`.

| Function | Description |
|---|---|
| async_pread | Non-blocking file read into tensor |
| sync_pread | Blocking file read into tensor |
| pread | File read with blocking and non-blocking options |
| async_pwrite | Non-blocking file write from tensor |
| sync_pwrite | Blocking file write from tensor |
| pwrite | File write with blocking and non-blocking options |
| wait | Wait for non-blocking I/O operations to complete |

#### GDS-specific APIs

The following functions are available only for `gds_handle`.

| Function | Description |
|---|---|
| new_pinned_device_tensor | Allocate and pin a device tensor |
| free_pinned_device_tensor | Unpin and free a device tensor |
| pin_device_tensor | Pin a device tensor |
| unpin_device_tensor | Unpin a device tensor |

#### Handle Settings APIs

The following APIs can be used to probe handle configuration.

| Function | Description |
|---|---|
| get_queue_depth | Return queue depth setting |
| get_single_submit | Return whether single_submit is enabled |
| get_thread_count | Return I/O parallelism degree |
| get_block_size | Return I/O block size setting |
| get_overlap_events | Return whether overlap_events is enabled |
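
Since these getters are available on every handle, a quick sketch of probing the settings of a default handle:

```python
# Sketch: inspect the configuration of an existing aio_handle.
from deepspeed.ops.op_builder import AsyncIOBuilder

h = AsyncIOBuilder().load().aio_handle()
print("block_size:", h.get_block_size())          # default 1048576 (1MB)
print("queue_depth:", h.get_queue_depth())        # default 128
print("single_submit:", h.get_single_submit())    # default False
print("overlap_events:", h.get_overlap_events())  # default False
print("num_threads:", h.get_thread_count())       # default 1
```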