diff --git a/composer/utils/dist.py b/composer/utils/dist.py index 0909477042..e5d5c735a5 100644 --- a/composer/utils/dist.py +++ b/composer/utils/dist.py @@ -68,6 +68,10 @@ log = logging.getLogger(__name__) +class MissingEnvironmentError(Exception): + pass + + def _get_distributed_config_var( env_var: str, human_name: str, @@ -91,8 +95,8 @@ def _get_distributed_config_var( return int(os.environ[env_var]) if dist.is_initialized(): - raise RuntimeError('Torch distributed is initialized but environment variable ' - f'{env_var} is not set.') + raise MissingEnvironmentError('Torch distributed is initialized but environment variable ' + f'{env_var} is not set.') return default diff --git a/composer/utils/file_helpers.py b/composer/utils/file_helpers.py index 66003fd92c..1c550377f8 100644 --- a/composer/utils/file_helpers.py +++ b/composer/utils/file_helpers.py @@ -10,7 +10,7 @@ import re import tempfile import uuid -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Union import requests import tqdm @@ -34,6 +34,34 @@ ] +def _get_dist_config(strict: bool = True) -> Dict[str, Any]: + """Returns a dict of distributed settings (rank, world_size, etc.). + + If ``strict=True``, will error if a setting is not available (e.g. the + environment variable is not set). Otherwise, will only return settings + that are availalbe. + """ + settings = { + 'rank': dist.get_global_rank, + 'local_rank': dist.get_local_rank, + 'world_size': dist.get_world_size, + 'local_world_size': dist.get_local_world_size, + 'node_rank': dist.get_node_rank, + } + + dist_config = {} + for name, func in settings.items(): + try: + value = func() + except dist.MissingEnvironmentError as e: + if strict: + raise e + else: + dist_config[name] = value + + return dist_config + + def is_tar(name: Union[str, pathlib.Path]) -> bool: """Returns whether ``name`` has a tar-like extension. @@ -89,11 +117,7 @@ def ensure_folder_has_no_conflicting_files(folder_name: Union[str, pathlib.Path] pattern = pattern.replace(f'{{{unit}}}', f'(?P<{unit}>\\d+)') # Format rank information - pattern = pattern.format(rank=dist.get_global_rank(), - local_rank=dist.get_local_rank(), - world_size=dist.get_world_size(), - local_world_size=dist.get_local_world_size(), - node_rank=dist.get_node_rank()) + pattern = pattern.format(**_get_dist_config(strict=False)) template = re.compile(pattern) @@ -143,11 +167,7 @@ def ensure_folder_has_no_conflicting_files(folder_name: Union[str, pathlib.Path] def format_name_with_dist(format_str: str, run_name: str, **extra_format_kwargs: object): # noqa: D103 formatted_str = format_str.format( run_name=run_name, - rank=dist.get_global_rank(), - local_rank=dist.get_local_rank(), - world_size=dist.get_world_size(), - local_world_size=dist.get_local_world_size(), - node_rank=dist.get_node_rank(), + **_get_dist_config(strict=False), **extra_format_kwargs, ) return formatted_str @@ -240,11 +260,6 @@ def format_name_with_dist_and_time( ): # noqa: D103 formatted_str = format_str.format( run_name=run_name, - rank=dist.get_global_rank(), - local_rank=dist.get_local_rank(), - world_size=dist.get_world_size(), - local_world_size=dist.get_local_world_size(), - node_rank=dist.get_node_rank(), epoch=int(timestamp.epoch), batch=int(timestamp.batch), batch_in_epoch=int(timestamp.batch_in_epoch), @@ -255,6 +270,7 @@ def format_name_with_dist_and_time( total_wct=timestamp.total_wct.total_seconds(), epoch_wct=timestamp.epoch_wct.total_seconds(), batch_wct=timestamp.batch_wct.total_seconds(), + **_get_dist_config(strict=False), **extra_format_kwargs, ) return formatted_str diff --git a/tests/utils/test_file_helpers.py b/tests/utils/test_file_helpers.py index e4cee6e136..fa3d202fb0 100644 --- a/tests/utils/test_file_helpers.py +++ b/tests/utils/test_file_helpers.py @@ -12,6 +12,7 @@ from composer.utils.file_helpers import (ensure_folder_has_no_conflicting_files, ensure_folder_is_empty, format_name_with_dist, format_name_with_dist_and_time, get_file, is_tar) from composer.utils.object_store.libcloud_object_store import LibcloudObjectStore +from tests.common.markers import world_size @pytest.mark.xfail(reason='Occassionally hits the timeout. Should refactor to use a local webserver.') @@ -153,6 +154,28 @@ def test_format_name_with_dist(): assert format_name_with_dist(format_str, 'awesome_run', extra=42) == expected_str +@world_size(2) +def test_safe_format_name_with_dist(monkeypatch: pytest.MonkeyPatch, world_size): + """node rank deleted, but not in format string, so format should complete.""" + vars = ['run_name', 'world_size'] + format_str = ','.join(f'{x}={{{x}}}' for x in vars) + expected_str = 'run_name=awesome_run,world_size=2' + + monkeypatch.delenv('NODE_RANK') + assert format_name_with_dist(format_str, 'awesome_run') == expected_str + + +@world_size(2) +def test_unsafe_format_name_with_dist(monkeypatch: pytest.MonkeyPatch, world_size): + """Node rank is deleted, but also in the format string, so expect error.""" + vars = ['run_name', 'node_rank'] + format_str = ','.join(f'{x}={{{x}}}' for x in vars) + + monkeypatch.delenv('NODE_RANK') + with pytest.raises(KeyError): + assert format_name_with_dist(format_str, 'awesome_run') == 'run_name=awesome_run,node_rank=3' + + def test_format_name_with_dist_and_time(): vars = [ 'run_name',