Skip to content

Commit b441869

Browse files
hanlintBandish Shah
authored and committed
less strict dist formatting (#1535)
1 parent 96d832e commit b441869

File tree

3 files changed

+61
-18
lines changed

3 files changed

+61
-18
lines changed

composer/utils/dist.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,10 @@
6868
log = logging.getLogger(__name__)
6969

7070

71+
class MissingEnvironmentError(Exception):
72+
pass
73+
74+
7175
def _get_distributed_config_var(
7276
env_var: str,
7377
human_name: str,
@@ -91,8 +95,8 @@ def _get_distributed_config_var(
9195
return int(os.environ[env_var])
9296

9397
if dist.is_initialized():
94-
raise RuntimeError('Torch distributed is initialized but environment variable '
95-
f'{env_var} is not set.')
98+
raise MissingEnvironmentError('Torch distributed is initialized but environment variable '
99+
f'{env_var} is not set.')
96100

97101
return default
98102

composer/utils/file_helpers.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import re
1111
import tempfile
1212
import uuid
13-
from typing import TYPE_CHECKING, Optional, Union
13+
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
1414

1515
import requests
1616
import tqdm
@@ -34,6 +34,34 @@
3434
]
3535

3636

37+
def _get_dist_config(strict: bool = True) -> Dict[str, Any]:
38+
"""Returns a dict of distributed settings (rank, world_size, etc.).
39+
40+
If ``strict=True``, will error if a setting is not available (e.g. the
41+
environment variable is not set). Otherwise, will only return settings
42+
that are available.
43+
"""
44+
settings = {
45+
'rank': dist.get_global_rank,
46+
'local_rank': dist.get_local_rank,
47+
'world_size': dist.get_world_size,
48+
'local_world_size': dist.get_local_world_size,
49+
'node_rank': dist.get_node_rank,
50+
}
51+
52+
dist_config = {}
53+
for name, func in settings.items():
54+
try:
55+
value = func()
56+
except dist.MissingEnvironmentError as e:
57+
if strict:
58+
raise e
59+
else:
60+
dist_config[name] = value
61+
62+
return dist_config
63+
64+
3765
def is_tar(name: Union[str, pathlib.Path]) -> bool:
3866
"""Returns whether ``name`` has a tar-like extension.
3967
@@ -89,11 +117,7 @@ def ensure_folder_has_no_conflicting_files(folder_name: Union[str, pathlib.Path]
89117
pattern = pattern.replace(f'{{{unit}}}', f'(?P<{unit}>\\d+)')
90118

91119
# Format rank information
92-
pattern = pattern.format(rank=dist.get_global_rank(),
93-
local_rank=dist.get_local_rank(),
94-
world_size=dist.get_world_size(),
95-
local_world_size=dist.get_local_world_size(),
96-
node_rank=dist.get_node_rank())
120+
pattern = pattern.format(**_get_dist_config(strict=False))
97121

98122
template = re.compile(pattern)
99123

@@ -143,11 +167,7 @@ def ensure_folder_has_no_conflicting_files(folder_name: Union[str, pathlib.Path]
143167
def format_name_with_dist(format_str: str, run_name: str, **extra_format_kwargs: object): # noqa: D103
144168
formatted_str = format_str.format(
145169
run_name=run_name,
146-
rank=dist.get_global_rank(),
147-
local_rank=dist.get_local_rank(),
148-
world_size=dist.get_world_size(),
149-
local_world_size=dist.get_local_world_size(),
150-
node_rank=dist.get_node_rank(),
170+
**_get_dist_config(strict=False),
151171
**extra_format_kwargs,
152172
)
153173
return formatted_str
@@ -240,11 +260,6 @@ def format_name_with_dist_and_time(
240260
): # noqa: D103
241261
formatted_str = format_str.format(
242262
run_name=run_name,
243-
rank=dist.get_global_rank(),
244-
local_rank=dist.get_local_rank(),
245-
world_size=dist.get_world_size(),
246-
local_world_size=dist.get_local_world_size(),
247-
node_rank=dist.get_node_rank(),
248263
epoch=int(timestamp.epoch),
249264
batch=int(timestamp.batch),
250265
batch_in_epoch=int(timestamp.batch_in_epoch),
@@ -255,6 +270,7 @@ def format_name_with_dist_and_time(
255270
total_wct=timestamp.total_wct.total_seconds(),
256271
epoch_wct=timestamp.epoch_wct.total_seconds(),
257272
batch_wct=timestamp.batch_wct.total_seconds(),
273+
**_get_dist_config(strict=False),
258274
**extra_format_kwargs,
259275
)
260276
return formatted_str

tests/utils/test_file_helpers.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from composer.utils.file_helpers import (ensure_folder_has_no_conflicting_files, ensure_folder_is_empty,
1313
format_name_with_dist, format_name_with_dist_and_time, get_file, is_tar)
1414
from composer.utils.object_store.libcloud_object_store import LibcloudObjectStore
15+
from tests.common.markers import world_size
1516

1617

1718
@pytest.mark.xfail(reason='Occassionally hits the timeout. Should refactor to use a local webserver.')
@@ -153,6 +154,28 @@ def test_format_name_with_dist():
153154
assert format_name_with_dist(format_str, 'awesome_run', extra=42) == expected_str
154155

155156

157+
@world_size(2)
158+
def test_safe_format_name_with_dist(monkeypatch: pytest.MonkeyPatch, world_size):
159+
"""node rank deleted, but not in format string, so format should complete."""
160+
vars = ['run_name', 'world_size']
161+
format_str = ','.join(f'{x}={{{x}}}' for x in vars)
162+
expected_str = 'run_name=awesome_run,world_size=2'
163+
164+
monkeypatch.delenv('NODE_RANK')
165+
assert format_name_with_dist(format_str, 'awesome_run') == expected_str
166+
167+
168+
@world_size(2)
169+
def test_unsafe_format_name_with_dist(monkeypatch: pytest.MonkeyPatch, world_size):
170+
"""Node rank is deleted, but also in the format string, so expect error."""
171+
vars = ['run_name', 'node_rank']
172+
format_str = ','.join(f'{x}={{{x}}}' for x in vars)
173+
174+
monkeypatch.delenv('NODE_RANK')
175+
with pytest.raises(KeyError):
176+
assert format_name_with_dist(format_str, 'awesome_run') == 'run_name=awesome_run,node_rank=3'
177+
178+
156179
def test_format_name_with_dist_and_time():
157180
vars = [
158181
'run_name',

0 commit comments

Comments
 (0)