Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions composer/utils/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@
log = logging.getLogger(__name__)


class MissingEnvironmentError(Exception):
pass


def _get_distributed_config_var(
env_var: str,
human_name: str,
Expand All @@ -91,8 +95,8 @@ def _get_distributed_config_var(
return int(os.environ[env_var])

if dist.is_initialized():
raise RuntimeError('Torch distributed is initialized but environment variable '
f'{env_var} is not set.')
raise MissingEnvironmentError('Torch distributed is initialized but environment variable '
f'{env_var} is not set.')

return default

Expand Down
48 changes: 32 additions & 16 deletions composer/utils/file_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import re
import tempfile
import uuid
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

import requests
import tqdm
Expand All @@ -34,6 +34,34 @@
]


def _get_dist_config(strict: bool = True) -> Dict[str, Any]:
"""Returns a dict of distributed settings (rank, world_size, etc.).
If ``strict=True``, will error if a setting is not available (e.g. the
environment variable is not set). Otherwise, will only return settings
that are availalbe.
"""
settings = {
'rank': dist.get_global_rank,
'local_rank': dist.get_local_rank,
'world_size': dist.get_world_size,
'local_world_size': dist.get_local_world_size,
'node_rank': dist.get_node_rank,
}

dist_config = {}
for name, func in settings.items():
try:
value = func()
except dist.MissingEnvironmentError as e:
if strict:
raise e
else:
dist_config[name] = value

return dist_config


def is_tar(name: Union[str, pathlib.Path]) -> bool:
"""Returns whether ``name`` has a tar-like extension.
Expand Down Expand Up @@ -89,11 +117,7 @@ def ensure_folder_has_no_conflicting_files(folder_name: Union[str, pathlib.Path]
pattern = pattern.replace(f'{{{unit}}}', f'(?P<{unit}>\\d+)')

# Format rank information
pattern = pattern.format(rank=dist.get_global_rank(),
local_rank=dist.get_local_rank(),
world_size=dist.get_world_size(),
local_world_size=dist.get_local_world_size(),
node_rank=dist.get_node_rank())
pattern = pattern.format(**_get_dist_config(strict=False))

template = re.compile(pattern)

Expand Down Expand Up @@ -143,11 +167,7 @@ def ensure_folder_has_no_conflicting_files(folder_name: Union[str, pathlib.Path]
def format_name_with_dist(format_str: str, run_name: str, **extra_format_kwargs: object): # noqa: D103
formatted_str = format_str.format(
run_name=run_name,
rank=dist.get_global_rank(),
local_rank=dist.get_local_rank(),
world_size=dist.get_world_size(),
local_world_size=dist.get_local_world_size(),
node_rank=dist.get_node_rank(),
**_get_dist_config(strict=False),
**extra_format_kwargs,
)
return formatted_str
Expand Down Expand Up @@ -240,11 +260,6 @@ def format_name_with_dist_and_time(
): # noqa: D103
formatted_str = format_str.format(
run_name=run_name,
rank=dist.get_global_rank(),
local_rank=dist.get_local_rank(),
world_size=dist.get_world_size(),
local_world_size=dist.get_local_world_size(),
node_rank=dist.get_node_rank(),
epoch=int(timestamp.epoch),
batch=int(timestamp.batch),
batch_in_epoch=int(timestamp.batch_in_epoch),
Expand All @@ -255,6 +270,7 @@ def format_name_with_dist_and_time(
total_wct=timestamp.total_wct.total_seconds(),
epoch_wct=timestamp.epoch_wct.total_seconds(),
batch_wct=timestamp.batch_wct.total_seconds(),
**_get_dist_config(strict=False),
**extra_format_kwargs,
)
return formatted_str
Expand Down
23 changes: 23 additions & 0 deletions tests/utils/test_file_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from composer.utils.file_helpers import (ensure_folder_has_no_conflicting_files, ensure_folder_is_empty,
format_name_with_dist, format_name_with_dist_and_time, get_file, is_tar)
from composer.utils.object_store.libcloud_object_store import LibcloudObjectStore
from tests.common.markers import world_size


@pytest.mark.xfail(reason='Occassionally hits the timeout. Should refactor to use a local webserver.')
Expand Down Expand Up @@ -153,6 +154,28 @@ def test_format_name_with_dist():
assert format_name_with_dist(format_str, 'awesome_run', extra=42) == expected_str


@world_size(2)
def test_safe_format_name_with_dist(monkeypatch: pytest.MonkeyPatch, world_size):
"""node rank deleted, but not in format string, so format should complete."""
vars = ['run_name', 'world_size']
format_str = ','.join(f'{x}={{{x}}}' for x in vars)
expected_str = 'run_name=awesome_run,world_size=2'

monkeypatch.delenv('NODE_RANK')
assert format_name_with_dist(format_str, 'awesome_run') == expected_str


@world_size(2)
def test_unsafe_format_name_with_dist(monkeypatch: pytest.MonkeyPatch, world_size):
"""Node rank is deleted, but also in the format string, so expect error."""
vars = ['run_name', 'node_rank']
format_str = ','.join(f'{x}={{{x}}}' for x in vars)

monkeypatch.delenv('NODE_RANK')
with pytest.raises(KeyError):
assert format_name_with_dist(format_str, 'awesome_run') == 'run_name=awesome_run,node_rank=3'


def test_format_name_with_dist_and_time():
vars = [
'run_name',
Expand Down