diff --git a/composer/callbacks/runtime_estimator.py b/composer/callbacks/runtime_estimator.py index 29afffa9d2..b0fc96e20b 100644 --- a/composer/callbacks/runtime_estimator.py +++ b/composer/callbacks/runtime_estimator.py @@ -79,7 +79,7 @@ def load_state_dict(self, state: Dict[str, Any]) -> None: self.eval_frequency_per_label = state['eval_frequency_per_label'] self.last_elapsed_fraction = state['last_elapsed_fraction'] - def get_elapsed_duration(self, state: State) -> Optional[float]: + def _get_elapsed_duration(self, state: State) -> Optional[float]: """Get the elapsed duration. Unlike `state.get_elapsed_duration`, this method computes fractional progress in an epoch @@ -102,7 +102,7 @@ def get_elapsed_duration(self, state: State) -> Optional[float]: def batch_start(self, state: State, logger: Logger) -> None: if self._enabled and self.start_time is None and self.batches_left_to_skip == 0: self.start_time = time.time() - self.start_dur = self.get_elapsed_duration(state) + self.start_dur = self._get_elapsed_duration(state) if self.start_dur is None: warnings.warn('`max_duration` is not set. Cannot estimate remaining time.') self._enabled = False @@ -114,7 +114,7 @@ def batch_end(self, state: State, logger: Logger) -> None: self.batches_left_to_skip -= 1 return - elapsed_dur = self.get_elapsed_duration(state) + elapsed_dur = self._get_elapsed_duration(state) assert elapsed_dur is not None, 'max_duration checked as non-None on batch_start' assert self.start_dur is not None @@ -153,7 +153,7 @@ def eval_end(self, state: State, logger: Logger) -> None: if state.dataloader_label not in self.eval_wct_per_label: self.eval_wct_per_label[state.dataloader_label] = [] self.eval_wct_per_label[state.dataloader_label].append(state.eval_timestamp.total_wct.total_seconds()) - elapsed_fraction = self.get_elapsed_duration(state) + elapsed_fraction = self._get_elapsed_duration(state) assert elapsed_fraction is not None, 'max_duration checked as non-None on batch_start' num_evals_finished = len(self.eval_wct_per_label[state.dataloader_label]) self.eval_frequency_per_label[state.dataloader_label] = elapsed_fraction / num_evals_finished diff --git a/docs/source/trainer/callbacks.rst b/docs/source/trainer/callbacks.rst index 1573ba0afd..a12cdf5e52 100644 --- a/docs/source/trainer/callbacks.rst +++ b/docs/source/trainer/callbacks.rst @@ -46,6 +46,7 @@ components of training. ~checkpoint_saver.CheckpointSaver ~speed_monitor.SpeedMonitor + ~runtime_estimator.RuntimeEstimator ~lr_monitor.LRMonitor ~optimizer_monitor.OptimizerMonitor ~memory_monitor.MemoryMonitor