Commit f9503b7

haileyschoelkopf authored, with co-authors github-actions and Quentin-Anthony

Add s3 checkpoint syncing (#1010)
* add s3 checkpoint syncing
* Update NeoXArgs docs automatically
* remove CPCargo requirement
* Update NeoXArgs docs automatically
* Make s3 imports try-except and separate requirements to s3 file
* Update NeoXArgs docs automatically
* Announce feature
* Update NeoXArgs docs automatically

---------

Co-authored-by: github-actions <[email protected]>
Co-authored-by: Quentin Anthony <[email protected]>
1 parent 444c0ef commit f9503b7

5 files changed: +157 additions, -3 deletions

README.md

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,8 @@ GPT-NeoX leverages many of the same features and technologies as the popular Meg
 * Easy connections with the open source ecosystem, including Hugging Face's [tokenizers](https://github.com/huggingface/tokenizers) and [transformers](https://github.com/huggingface/transformers/) libraries, logging via [WandB](https://wandb.ai/site), and evaluation via our [Language Model Evaluation Harness](https://github.com/EleutherAI/lm-evaluation-harness).

 ## News
+**[10/9/2023]** We now support checkpointing with AWS S3! Activate with the `s3_path` config option (for more detail, see [the PR](https://github.com/EleutherAI/gpt-neox/pull/1010))
+
 **[9/20/2023]** As of https://github.com/EleutherAI/gpt-neox/pull/1035, we have deprecated Flash Attention 0.x and 1.x, and migrated support to Flash Attention 2.x. We don't believe this will cause problems, but if you have a specific use-case that requires old flash support using the latest GPT-NeoX, please raise an issue.

 **[8/10/2023]** We have experimental support for LLaMA 2 and Flash Attention v2 supported in our [math-lm](https://github.com/EleutherAI/math-lm) project that will be upstreamed later this month.
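Enabling the new feature only requires adding the two new options next to the existing `save` directory option in a GPT-NeoX `.yml` config. A minimal, hypothetical excerpt is shown below; the bucket name and paths are placeholders, not values from this PR:

```yaml
# Hypothetical excerpt from a GPT-NeoX .yml config; bucket and paths are placeholders.
{
  # existing option: local directory that checkpoints are first written to
  "save": "checkpoints/my-run",

  # new in this PR: S3 destination for uploaded checkpoints (setting it activates the feature)
  "s3_path": "s3://my-bucket/neox-checkpoints",

  # new in this PR: multipart-upload chunk size in bytes (this is the default, 100MiB)
  "s3_chunk_size": 104857600
}
```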

configs/neox_arguments.md

Lines changed: 17 additions & 1 deletion
@@ -111,7 +111,7 @@ Logging Arguments

 - **git_hash**: str

-Default = 1d20559
+Default = aa6c176

 current git hash of repository


@@ -1169,6 +1169,22 @@ Training Arguments



+- **s3_path**: str
+
+Default = None
+
+Path to s3 bucket for saving checkpoints.
+
+
+
+- **s3_chunk_size**: int
+
+Default = 104857600
+
+The number of bytes in each file chunk when uploading to s3. Defaults to 100MiB.
+
+
+
 - **config_files**: dict

 Default = None
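As a quick sanity check on the `s3_chunk_size` default (not part of the diff, just arithmetic): 104857600 bytes is exactly 100 MiB, and the upload helper added below splits each file into `ceil(file_size / chunk_size)` parts, so a hypothetical 2 GiB checkpoint shard would be uploaded in 21 chunks:

```python
# Quick arithmetic for the s3_chunk_size default (the file size is illustrative only).
import math

s3_chunk_size = 104_857_600            # default value
assert s3_chunk_size == 100 * 1024 * 1024   # i.e. 100 MiB

file_size = 2 * 1024**3                # hypothetical 2 GiB checkpoint shard
nb_parts = math.ceil(file_size / s3_chunk_size)
print(nb_parts)                        # 21 multipart chunks for this file
```

This mirrors the `math.ceil(file_size / chunk_size)` computation in the `_upload()` helper further down.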

megatron/checkpointing.py

Lines changed: 126 additions & 2 deletions
@@ -18,13 +18,23 @@
 """Input/output checkpointing."""

 import json
+import math
 import os
 import re
 import shutil
+import time
 import random
 import sys
 import numpy as np

+try:
+    import boto3
+except ModuleNotFoundError:
+    print("For s3 checkpointing, please install boto3 either using requirements/requirements-s3.txt or https://github.com/boto/boto3")
+try:
+    import hf_transfer
+except ModuleNotFoundError:
+    print("For s3 checkpointing, please install hf_transfer either using requirements/requirements-s3.txt or https://github.com/huggingface/hf_transfer")
 import torch
 from glob import glob

@@ -137,6 +147,10 @@ def get_checkpoint_name(checkpoints_path, iteration, release=False, mp_rank=None
     )


+def get_checkpoint_tag(iteration: int) -> str:
+    return f"global_step{iteration}"
+
+
 def delete_old_checkpoints(save_dir, n_to_keep):
     if torch.distributed.get_rank() == 0:
         ckpt_dir_regex = r"global_step[\d]*"
@@ -188,7 +202,7 @@ def save_ds_checkpoint(iteration, model, neox_args):
         sd["checkpoint_validation_logits"] = logits

     # checkpoint folder name
-    tag = f"global_step{iteration}"
+    tag = get_checkpoint_tag(iteration)

     # save checkpoint
     model.save_checkpoint(neox_args.save, tag=tag, client_state=sd)
@@ -203,6 +217,111 @@ def save_ds_checkpoint(iteration, model, neox_args):
                 f.write(config_data)
             else:
                 json.dump(config_data, f)
+def multiprocessing_starmap(func, args, num_processes=None):
+    """Wrapper to allow for re-usable multiprocessing pools with `spawn` context handling
+    Args:
+        func (Callable): Function to call
+        args (Iterable): Iterable of arguments to pass to `func`
+        num_processes (int, optional): Number of processes to spawn. Defaults to `multiprocessing.cpu_count() - 1`
+    """
+    import multiprocessing
+    num_processes = num_processes or (multiprocessing.cpu_count() - 1)
+    with multiprocessing.get_context("spawn").Pool(processes=num_processes) as process_pool:
+        process_pool.starmap(func, args)
+        process_pool.terminate()
+        process_pool.join()
+        del process_pool
+
+
+def _upload(
+    file_path: str,
+    s3_key: str,
+    chunk_size: int = 104_857_600,
+    max_files: int = 64,
+    parallel_failures: int = 63,
+    max_retries: int = 5,
+):
+    """Upload local file to S3 using `hf_transfer` library
+    Args:
+        file_path (str): Local filename to upload
+        s3_key (str): S3 key to upload to. E.g. `s3://bucket-name/path/to/file`
+        chunk_size (int, optional): Chunk size to use for multipart upload.
+            Defaults to 100MiB = 104_857_600
+        max_files (int, optional): Number of open file handles, which determines
+            the maximum number of parallel uploads. Defaults to 64
+        parallel_failures (int, optional): Number of maximum failures of different
+            chunks in parallel (cannot exceed max_files). Defaults to 63
+        max_retries (int, optional): Number of retries for each chunk. Defaults to 5
+    """
+    s3 = boto3.client('s3')
+    bucket = s3_key.split("s3://")[1].split("/")[0]
+    key = s3_key.split(bucket)[1].lstrip("/")
+
+    # 1. Init multipart upload and obtain unique upload identifier
+    upload = s3.create_multipart_upload(
+        ACL="bucket-owner-full-control",
+        Bucket=bucket,
+        Key=key,
+    )
+    upload_id = upload["UploadId"]
+
+    # 2. Generate presigned URLs for each part
+    file_size = os.stat(file_path).st_size
+    urls = []
+    nb_parts = math.ceil(file_size / chunk_size)
+    for part_number in range(1, nb_parts + 1):
+        params = {
+            "Bucket": bucket,
+            "Key": key,
+            "PartNumber": part_number,
+            "UploadId": upload_id,
+        }
+        urls.append(
+            s3.generate_presigned_url(
+                ClientMethod="upload_part", Params=params, ExpiresIn=86400
+            )
+        )
+
+    # 3. Upload parts in parallel
+    responses = hf_transfer.multipart_upload(
+        file_path=file_path,
+        parts_urls=urls,
+        chunk_size=chunk_size,
+        max_files=max_files,
+        parallel_failures=parallel_failures,
+        max_retries=max_retries,
+    )
+
+    # 4. Complete multipart upload request with ETag values
+    etag_with_parts = []
+    for part_number, header in enumerate(responses):
+        etag = header.get("etag")
+        etag_with_parts.append({"ETag": etag, "PartNumber": part_number + 1})
+    parts = {"Parts": etag_with_parts}
+    s3.complete_multipart_upload(
+        Bucket=bucket, Key=key, MultipartUpload=parts, UploadId=upload_id
+    )
+
+
+def upload_checkpoint(iteration, neox_args):
+    local_checkpoint_path = os.path.join(os.path.abspath(neox_args.save), get_checkpoint_tag(iteration))
+    local_checkpoint_list = sorted(filter(
+        lambda x: os.path.isfile(x),
+        [str(p) for p in Path(local_checkpoint_path).rglob("*")],
+    ))
+    remote_checkpoint_path = os.path.join(
+        neox_args.s3_path, os.path.basename(neox_args.save), get_checkpoint_tag(iteration))
+    remote_checkpoint_list = [
+        os.path.join(remote_checkpoint_path, os.path.relpath(local_checkpoint, local_checkpoint_path))
+        for local_checkpoint in local_checkpoint_list
+    ]
+    inputs = zip(local_checkpoint_list, remote_checkpoint_list, [neox_args.s3_chunk_size] * len(local_checkpoint_list))
+
+    print_rank_0(f"[RANK {torch.distributed.get_rank()}] Uploading checkpoint `{local_checkpoint_path}` to `{remote_checkpoint_path}`...")
+    start = time.time()
+    multiprocessing_starmap(_upload, inputs)
+    total_time = time.time() - start
+    print_rank_0(f"[RANK {torch.distributed.get_rank()}] Uploaded checkpoint `{local_checkpoint_path}` to `{remote_checkpoint_path}` in {total_time:.2f}s")


 def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler):
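To make the path handling in `upload_checkpoint()` concrete, here is a small sketch (not from the PR) of the local-to-remote mapping it performs; the run directory, bucket, and filename below are hypothetical:

```python
# Sketch of the path mapping done by upload_checkpoint(); all values are hypothetical.
import os

save = "/data/checkpoints/neox-1b3"                # neox_args.save
s3_path = "s3://my-bucket/experiments"             # neox_args.s3_path
tag = "global_step1000"                            # get_checkpoint_tag(1000)

local_root = os.path.join(os.path.abspath(save), tag)
local_file = os.path.join(local_root, "mp_rank_00_model_states.pt")

remote_root = os.path.join(s3_path, os.path.basename(save), tag)
remote_file = os.path.join(remote_root, os.path.relpath(local_file, local_root))
print(remote_file)
# s3://my-bucket/experiments/neox-1b3/global_step1000/mp_rank_00_model_states.pt
```

Every file found under the local `global_step{N}` directory is mapped this way, and the resulting (local path, remote key, chunk size) triples are fanned out to worker processes via `multiprocessing_starmap(_upload, inputs)`.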
@@ -213,6 +332,11 @@ def save_checkpoint(neox_args, iteration, model, optimizer, lr_scheduler):
     else:
         raise ValueError("Must be using deepspeed to use neox")

+    torch.distributed.barrier()
+    upload_to_s3 = torch.distributed.get_rank() == 0 and neox_args.s3_path is not None
+    if upload_to_s3:
+        upload_checkpoint(iteration, neox_args)
+
     # Wait so everyone is done (necessary)
     torch.distributed.barrier()
     if neox_args.keep_last_n_checkpoints is not None:
@@ -233,7 +357,7 @@ def load_checkpoint(
     if neox_args.finetune:
         load_optim_and_scheduler = False
     if iteration is not None:
-        tag = f"global_step{iteration}"
+        tag = get_checkpoint_tag(iteration)
     else:
         tag = None
     checkpoint_name, state_dict = model.load_checkpoint(
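Outside the training loop, the same helpers can in principle be reused for ad-hoc uploads. A minimal sketch, assuming AWS credentials are configured for `boto3`, `hf_transfer` is installed, and `megatron.checkpointing` is importable; the file and bucket names are placeholders:

```python
import os

from megatron.checkpointing import _upload, multiprocessing_starmap

if __name__ == "__main__":
    # Local files to push, and the S3 keys they should land at (placeholders).
    local_files = ["/tmp/ckpt/mp_rank_00_model_states.pt", "/tmp/ckpt/configs/config.yml"]
    remote_keys = [f"s3://my-bucket/adhoc/{os.path.basename(p)}" for p in local_files]

    chunk_size = 104_857_600  # 100 MiB, same as the s3_chunk_size default

    # One worker process per file, mirroring what upload_checkpoint() does;
    # the "spawn" pool requires this __main__ guard when run as a script.
    multiprocessing_starmap(
        _upload, zip(local_files, remote_keys, [chunk_size] * len(local_files))
    )
```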

megatron/neox_arguments/neox_args.py

Lines changed: 10 additions & 0 deletions
@@ -793,6 +793,16 @@ class NeoXArgsTraining(NeoXArgsTemplate):
     Output directory to save checkpoints to.
     """

+    s3_path: str = None
+    """
+    Path to s3 bucket for saving checkpoints.
+    """
+
+    s3_chunk_size: int = 104_857_600
+    """
+    The number of bytes in each file chunk when uploading to s3. Defaults to 100MiB.
+    """
+
     config_files: dict = None
     """
     Store of original config files mapping config filename to file contents

requirements/requirements-s3.txt

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+hf-transfer>=0.1.3
+boto3
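These extra dependencies are optional: as the try/except imports in `megatron/checkpointing.py` above indicate, they are only needed when `s3_path` is set, and can be installed with `pip install -r requirements/requirements-s3.txt`.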
