10
10
import re
11
11
import tempfile
12
12
import uuid
13
- from typing import TYPE_CHECKING , Optional , Union
13
+ from typing import TYPE_CHECKING , Any , Dict , Optional , Union
14
14
15
15
import requests
16
16
import tqdm
34
34
]
35
35
36
36
37
+ def _get_dist_config (strict : bool = True ) -> Dict [str , Any ]:
38
+ """Returns a dict of distributed settings (rank, world_size, etc.).
39
+
40
+ If ``strict=True``, will error if a setting is not available (e.g. the
41
+ environment variable is not set). Otherwise, will only return settings
42
+ that are availalbe.
43
+ """
44
+ settings = {
45
+ 'rank' : dist .get_global_rank ,
46
+ 'local_rank' : dist .get_local_rank ,
47
+ 'world_size' : dist .get_world_size ,
48
+ 'local_world_size' : dist .get_local_world_size ,
49
+ 'node_rank' : dist .get_node_rank ,
50
+ }
51
+
52
+ dist_config = {}
53
+ for name , func in settings .items ():
54
+ try :
55
+ value = func ()
56
+ except dist .MissingEnvironmentError as e :
57
+ if strict :
58
+ raise e
59
+ else :
60
+ dist_config [name ] = value
61
+
62
+ return dist_config
63
+
64
+
37
65
def is_tar (name : Union [str , pathlib .Path ]) -> bool :
38
66
"""Returns whether ``name`` has a tar-like extension.
39
67
@@ -89,11 +117,7 @@ def ensure_folder_has_no_conflicting_files(folder_name: Union[str, pathlib.Path]
89
117
pattern = pattern .replace (f'{{{ unit } }}' , f'(?P<{ unit } >\\ d+)' )
90
118
91
119
# Format rank information
92
- pattern = pattern .format (rank = dist .get_global_rank (),
93
- local_rank = dist .get_local_rank (),
94
- world_size = dist .get_world_size (),
95
- local_world_size = dist .get_local_world_size (),
96
- node_rank = dist .get_node_rank ())
120
+ pattern = pattern .format (** _get_dist_config (strict = False ))
97
121
98
122
template = re .compile (pattern )
99
123
@@ -143,11 +167,7 @@ def ensure_folder_has_no_conflicting_files(folder_name: Union[str, pathlib.Path]
143
167
def format_name_with_dist (format_str : str , run_name : str , ** extra_format_kwargs : object ): # noqa: D103
144
168
formatted_str = format_str .format (
145
169
run_name = run_name ,
146
- rank = dist .get_global_rank (),
147
- local_rank = dist .get_local_rank (),
148
- world_size = dist .get_world_size (),
149
- local_world_size = dist .get_local_world_size (),
150
- node_rank = dist .get_node_rank (),
170
+ ** _get_dist_config (strict = False ),
151
171
** extra_format_kwargs ,
152
172
)
153
173
return formatted_str
@@ -240,11 +260,6 @@ def format_name_with_dist_and_time(
240
260
): # noqa: D103
241
261
formatted_str = format_str .format (
242
262
run_name = run_name ,
243
- rank = dist .get_global_rank (),
244
- local_rank = dist .get_local_rank (),
245
- world_size = dist .get_world_size (),
246
- local_world_size = dist .get_local_world_size (),
247
- node_rank = dist .get_node_rank (),
248
263
epoch = int (timestamp .epoch ),
249
264
batch = int (timestamp .batch ),
250
265
batch_in_epoch = int (timestamp .batch_in_epoch ),
@@ -255,6 +270,7 @@ def format_name_with_dist_and_time(
255
270
total_wct = timestamp .total_wct .total_seconds (),
256
271
epoch_wct = timestamp .epoch_wct .total_seconds (),
257
272
batch_wct = timestamp .batch_wct .total_seconds (),
273
+ ** _get_dist_config (strict = False ),
258
274
** extra_format_kwargs ,
259
275
)
260
276
return formatted_str
0 commit comments