Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 20 additions & 33 deletions composer/utils/file_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import requests
import tqdm

from composer.core.time import Time, Timestamp
from composer.core.time import Timestamp
from composer.utils import dist
from composer.utils.iter_helpers import iterate_with_callback
from composer.utils.object_store import ObjectStore
Expand Down Expand Up @@ -82,12 +82,12 @@ def ensure_folder_has_no_conflicting_files(folder_name: Union[str, pathlib.Path]
"""
# Prepare regex pattern by replacing f-string formatting with regex.
pattern = f'^{filename}$'
# Format time vars for capture
time_names = ['epoch', 'batch', 'sample', 'token', 'batch_in_epoch', 'sample_in_epoch', 'token_in_epoch']
captured_names = {time_name: f'{{{time_name}}}' in filename for time_name in time_names}
for time_name, is_captured in captured_names.items():
if is_captured:
pattern = pattern.replace(f'{{{time_name}}}', f'(?P<{time_name}>\\d+)')

# Format time vars for regex match
for unit in ['epoch', 'batch', 'sample', 'token', 'batch_in_epoch', 'sample_in_epoch', 'token_in_epoch']:
if unit in filename:
pattern = pattern.replace(f'{{{unit}}}', f'(?P<{unit}>\\d+)')

# Format rank information
pattern = pattern.format(rank=dist.get_global_rank(),
local_rank=dist.get_local_rank(),
Expand All @@ -99,33 +99,20 @@ def ensure_folder_has_no_conflicting_files(folder_name: Union[str, pathlib.Path]

for file in os.listdir(folder_name):
match = template.match(file)
# Encountered an invalid match

if match is not None:
valid_match = True
# Check each base unit of time and flag later checkpoints
if captured_names['token'] and Time.from_token(int(match.group('token'))) > timestamp.token:
valid_match = False
elif captured_names['sample'] and Time.from_sample(int(match.group('sample'))) > timestamp.sample:
valid_match = False
elif captured_names['batch'] and Time.from_batch(int(match.group('batch'))) > timestamp.batch:
valid_match = False
elif captured_names['epoch'] and Time.from_epoch(int(match.group('epoch'))) > timestamp.epoch:
valid_match = False
# If epoch count is same, check batch_in_epoch, sample_in_epoch, token_in_epoch
elif captured_names['epoch'] and Time.from_epoch(int(match.group('epoch'))) == timestamp.epoch:
if captured_names['token_in_epoch'] and Time.from_token(int(
match.group('token_in_epoch'))) > timestamp.token_in_epoch:
valid_match = False
elif captured_names['sample_in_epoch'] and Time.from_sample(int(
match.group('sample_in_epoch'))) > timestamp.sample_in_epoch:
valid_match = False
elif captured_names['batch_in_epoch'] and Time.from_batch(int(
match.group('batch_in_epoch'))) > timestamp.batch_in_epoch:
valid_match = False
if not valid_match:
raise FileExistsError(
f'{os.path.join(folder_name, file)} exists and conflicts in namespace with a future checkpoint of the current run.'
)
match = match.groupdict()
for unit, value in match.items():
if unit.endswith('_in_epoch'):
if 'epoch' not in match:
raise ValueError(f'{filename} has {{unit}} but not {{epoch}}. Add {{epoch}} for uniqueness.')
if int(match['epoch']) != timestamp.epoch:
continue # only check _in_epoch if both files have same epoch count

if int(value) > int(getattr(timestamp, unit)):
raise FileExistsError(
f'{os.path.join(folder_name, file)} may conflict with a future checkpoint of the current run.'
'Please delete that file, change to a new folder, or set overwrite=True.')


FORMAT_NAME_WITH_DIST_TABLE = """
Expand Down