Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions composer/callbacks/runtime_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def load_state_dict(self, state: Dict[str, Any]) -> None:
self.eval_frequency_per_label = state['eval_frequency_per_label']
self.last_elapsed_fraction = state['last_elapsed_fraction']

def get_elapsed_duration(self, state: State) -> Optional[float]:
def _get_elapsed_duration(self, state: State) -> Optional[float]:
"""Get the elapsed duration.

Unlike `state.get_elapsed_duration`, this method computes fractional progress in an epoch
Expand All @@ -102,7 +102,7 @@ def get_elapsed_duration(self, state: State) -> Optional[float]:
def batch_start(self, state: State, logger: Logger) -> None:
if self._enabled and self.start_time is None and self.batches_left_to_skip == 0:
self.start_time = time.time()
self.start_dur = self.get_elapsed_duration(state)
self.start_dur = self._get_elapsed_duration(state)
if self.start_dur is None:
warnings.warn('`max_duration` is not set. Cannot estimate remaining time.')
self._enabled = False
Expand All @@ -114,7 +114,7 @@ def batch_end(self, state: State, logger: Logger) -> None:
self.batches_left_to_skip -= 1
return

elapsed_dur = self.get_elapsed_duration(state)
elapsed_dur = self._get_elapsed_duration(state)
assert elapsed_dur is not None, 'max_duration checked as non-None on batch_start'

assert self.start_dur is not None
Expand Down Expand Up @@ -153,7 +153,7 @@ def eval_end(self, state: State, logger: Logger) -> None:
if state.dataloader_label not in self.eval_wct_per_label:
self.eval_wct_per_label[state.dataloader_label] = []
self.eval_wct_per_label[state.dataloader_label].append(state.eval_timestamp.total_wct.total_seconds())
elapsed_fraction = self.get_elapsed_duration(state)
elapsed_fraction = self._get_elapsed_duration(state)
assert elapsed_fraction is not None, 'max_duration checked as non-None on batch_start'
num_evals_finished = len(self.eval_wct_per_label[state.dataloader_label])
self.eval_frequency_per_label[state.dataloader_label] = elapsed_fraction / num_evals_finished
1 change: 1 addition & 0 deletions docs/source/trainer/callbacks.rst
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ components of training.

~checkpoint_saver.CheckpointSaver
~speed_monitor.SpeedMonitor
~runtime_estimator.RuntimeEstimator
~lr_monitor.LRMonitor
~optimizer_monitor.OptimizerMonitor
~memory_monitor.MemoryMonitor
Expand Down