Skip to content

Commit 04b4a83

Browse files
AutoYAHP Part 2: Clean up the Callbacks and Loggers for AutoYAHP (#1071)
Similar to #1056, this PR cleans up the callbacks and loggers for AutoYAHP. It does not depend on AutoYAHP itself (a future PR will remove the underlying hparam classes).

- Refactored callback and logger tests to not depend on hparams.
- Reformatted the docstrings so they would be correctly parsed by auto-yahp.
- Added `callback_settings.py`, similar to `algorithm_settings.py`, to return a list of pytest param objects for parameterization across callback tests. These param objects include appropriate markers (e.g. conditional skipping for wandb and mlperf; requiring that the memory monitor runs on GPU, ...).
- Moved the `TestTrainerAssets` into `tests/callbacks/test_callbacks.py`, since it tests the individual callbacks and loggers, not the trainer, and thus should live in `tests/callbacks`.
- Cleaned up the `MemoryMonitor` warnings when CUDA is not available. Now, it warns when the model is not on CUDA, to catch the edge case where one does CPU training when GPUs are available.
1 parent 660e3f0 commit 04b4a83

22 files changed

+484
-410
lines changed

composer/callbacks/callback_hparams.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,11 @@
2929
"MemoryMonitorHparams",
3030
"LRMonitorHparams",
3131
"SpeedMonitorHparams",
32+
"EarlyStopperHparams",
3233
"CheckpointSaverHparams",
34+
"ThresholdStopperHparams",
35+
"MLPerfCallbackHparams",
36+
"callback_registry",
3337
]
3438

3539

@@ -321,3 +325,15 @@ def initialize_object(self) -> CheckpointSaver:
321325
weights_only=self.weights_only,
322326
num_checkpoints_to_keep=self.num_checkpoints_to_keep,
323327
)
328+
329+
330+
callback_registry = {
331+
"checkpoint_saver": CheckpointSaverHparams,
332+
"speed_monitor": SpeedMonitorHparams,
333+
"lr_monitor": LRMonitorHparams,
334+
"grad_monitor": GradMonitorHparams,
335+
"memory_monitor": MemoryMonitorHparams,
336+
"mlperf": MLPerfCallbackHparams,
337+
"early_stopper": EarlyStopperHparams,
338+
"threshold_stopper": ThresholdStopperHparams,
339+
}

composer/callbacks/checkpoint_saver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ class CheckpointSaver(Callback):
245245
If ``False`` (the default), then the ``folder`` must not exist or must not contain checkpoints which may conflict
246246
with the current run. Default: ``False``.
247247
248-
save_interval (:class:`.Time` | str | int | (:class:`.State`, :class:`.Event`) -> bool): A :class:`.Time`, time-string, integer (in epochs),
248+
save_interval (Time | str | int | (State, Event) -> bool): A :class:`.Time`, time-string, integer (in epochs),
249249
or a function that takes (state, event) and returns a boolean whether a checkpoint should be saved.
250250
251251
If an integer, checkpoints will be saved every n epochs.

composer/callbacks/memory_monitor.py

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
"""Log memory usage during training."""
55
import logging
6+
import warnings
67
from typing import Dict, Union
78

89
import torch.cuda
@@ -76,24 +77,24 @@ class MemoryMonitor(Callback):
7677
Memory usage monitoring is only supported for the GPU devices.
7778
"""
7879

79-
def __init__(self):
80-
super().__init__()
81-
log.info(
82-
"Memory monitor just profiles the current GPU assuming that the memory footprint across GPUs is balanced.")
83-
if torch.cuda.device_count() == 0:
84-
log.warning("Memory monitor only works on GPU devices.")
80+
def fit_start(self, state: State, logger: Logger) -> None:
81+
# TODO(ravi): move this check to Event.INIT after #1084 is merged
82+
# Not relying on `torch.cuda.is_available()` since the model could be on CPU.
83+
model_device = next(state.model.parameters()).device
84+
85+
if model_device.type != 'cuda':
86+
warnings.warn(f"The memory monitor only works on CUDA devices, but the model is on {model_device.type}.")
8587

8688
def after_train_batch(self, state: State, logger: Logger):
8789
memory_report = {}
8890

89-
n_devices = torch.cuda.device_count()
90-
if n_devices == 0:
91+
model_device = next(state.model.parameters()).device
92+
if model_device.type != 'cuda':
9193
return
9294

9395
memory_report = _get_memory_report()
9496

95-
for mem_stat, val in memory_report.items():
96-
logger.data_batch({'memory/{}'.format(mem_stat): val})
97+
logger.data_batch({f'memory/{mem_stat}': val for (mem_stat, val) in memory_report.items()})
9798

9899

99100
_MEMORY_STATS = {
@@ -108,15 +109,8 @@ def after_train_batch(self, state: State, logger: Logger):
108109

109110

110111
def _get_memory_report() -> Dict[str, Union[int, float]]:
111-
if not torch.cuda.is_available():
112-
log.debug("Cuda is not available. The memory report will be empty.")
113-
return {}
114112
memory_stats = torch.cuda.memory_stats()
115113

116-
if len(memory_stats) == 0:
117-
log.debug("No GPU memory was used, returning empty.")
118-
return {}
119-
120114
# simplify the memory_stats
121115
memory_report = {
122116
name: memory_stats[torch_name] for (torch_name, name) in _MEMORY_STATS.items() if torch_name in memory_stats

composer/callbacks/speed_monitor.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@
1818
class SpeedMonitor(Callback):
1919
"""Logs the training throughput.
2020
21-
The training throughput in terms of number of samples per second is logged on the
22-
:attr:`~composer.core.event.Event.BATCH_END` event if we have reached the ``window_size`` threshold. Per epoch
23-
average throughput and wall clock train, validation, and total time is also logged on the
24-
:attr:`~composer.core.event.Event.EPOCH_END` event.
21+
The training throughput in terms of number of samples per second is logged on
22+
the :attr:`~composer.core.event.Event.BATCH_END` event if we have reached the ``window_size`` threshold. Per epoch
23+
average throughput and wall clock train, validation, and total time are also logged on
24+
the :attr:`~composer.core.event.Event.EPOCH_END` event.
2525
2626
Example
2727
@@ -49,11 +49,11 @@ class SpeedMonitor(Callback):
4949
| Key | Logged data |
5050
+=======================+=============================================================+
5151
| | Rolling average (over ``window_size`` most recent |
52-
| ``throughput/step`` | batches) of the number of samples processed per second |
52+
| ``samples/step`` | batches) of the number of samples processed per second |
5353
| | |
5454
+-----------------------+-------------------------------------------------------------+
5555
| | Number of samples processed per second (averaged over |
56-
| ``throughput/epoch`` | an entire epoch) |
56+
| ``samples/epoch`` | an entire epoch) |
5757
+-----------------------+-------------------------------------------------------------+
5858
|``wall_clock/train`` | Total elapsed training time |
5959
+-----------------------+-------------------------------------------------------------+
@@ -63,8 +63,8 @@ class SpeedMonitor(Callback):
6363
+-----------------------+-------------------------------------------------------------+
6464
6565
Args:
66-
window_size (int, optional):
67-
Number of batches to use for a rolling average of throughput. Default to 100.
66+
window_size (int, optional): Number of batches to use for a rolling average of throughput.
67+
Defaults to 100.
6868
"""
6969

7070
def __init__(self, window_size: int = 100):

composer/core/callback.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from __future__ import annotations
66

77
import abc
8-
from typing import TYPE_CHECKING
8+
from typing import TYPE_CHECKING, Any
99

1010
from composer.core.serializable import Serializable
1111

@@ -88,6 +88,11 @@ class Callback(Serializable, abc.ABC):
8888
trainer.engine.close()
8989
"""
9090

91+
def __init__(self, *args: Any, **kwargs: Any) -> None:
92+
# Stub signature for pyright
93+
del args, kwargs # unused
94+
pass
95+
9196
def run_event(self, event: Event, state: State, logger: Logger) -> None:
9297
"""This method is called by the engine on each event.
9398

composer/loggers/file_logger.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ class FileLogger(LoggerDestination):
113113
:attr:`~.LogLevel.EPOCH`, then the logfile will be flushed every n epochs. If
114114
the ``log_level`` is :attr:`~.LogLevel.BATCH`, then the logfile will be
115115
flushed every n batches. Default: ``100``.
116+
overwrite (bool, optional): Whether to overwrite an existing logfile. Default: ``False``.
116117
"""
117118

118119
def __init__(

composer/loggers/logger_hparams.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -136,18 +136,18 @@ def initialize_object(self) -> WandBLogger:
136136
config_dict = self._flatten_dict(config_dict, prefix=[])
137137

138138
init_params = {
139-
"project": self.project,
140-
"name": self.name,
141-
"group": self.group,
142-
"entity": self.entity,
143-
"tags": tags,
144139
"config": config_dict,
145140
}
146141
init_params.update(self.extra_init_params)
147142
return WandBLogger(
143+
project=self.project,
144+
group=self.group,
145+
name=self.name,
146+
entity=self.entity,
147+
tags=tags,
148148
log_artifacts=self.log_artifacts,
149149
rank_zero_only=self.rank_zero_only,
150-
init_params=init_params,
150+
init_kwargs=init_params,
151151
)
152152

153153
@classmethod

composer/loggers/progress_bar_logger.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,8 @@ def should_log(state: State, log_level: LogLevel, console_log_level: LogLevel =
129129
elif stream.lower() == "stderr":
130130
stream = sys.stderr
131131
else:
132-
raise ValueError("Invalid stream option: Should be 'stdout', 'stderr', or a TextIO-like object.")
132+
raise ValueError(
133+
f"Invalid stream option: Should be 'stdout', 'stderr', or a TextIO-like object; got {stream}")
133134
self.stream = stream
134135

135136
def log_data(self, state: State, log_level: LogLevel, data: Dict[str, Any]) -> None:

composer/loggers/wandb_logger.py

Lines changed: 44 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import sys
1414
import tempfile
1515
import warnings
16-
from typing import Any, Dict, Optional
16+
from typing import Any, Dict, List, Optional
1717

1818
from composer.core.state import State
1919
from composer.loggers.logger import Logger, LogLevel
@@ -28,6 +28,12 @@ class WandBLogger(LoggerDestination):
2828
"""Log to Weights and Biases (https://wandb.ai/)
2929
3030
Args:
31+
project (str, optional): WandB project name.
32+
group (str, optional): WandB group name.
33+
name (str, optional): WandB run name.
34+
If not specified, the :attr:`.Logger.run_name` will be used.
35+
entity (str, optional): WandB entity name.
36+
tags (List[str], optional): WandB tags.
3137
log_artifacts (bool, optional): Whether to log
3238
`artifacts <https://docs.wandb.ai/ref/python/artifact>`_ (Default: ``False``).
3339
rank_zero_only (bool, optional): Whether to log only on the rank-zero process.
@@ -36,15 +42,22 @@ class WandBLogger(LoggerDestination):
3642
stored, which may discard pertinent information. For example, when using
3743
Deepspeed ZeRO, it would be impossible to restore from checkpoints without
3844
artifacts from all ranks (default: ``False``).
39-
init_params (Dict[str, Any], optional): Parameters to pass into
45+
init_kwargs (Dict[str, Any], optional): Any additional init kwargs to pass into
4046
``wandb.init`` (see
4147
`WandB documentation <https://docs.wandb.ai/ref/python/init>`_).
4248
"""
4349

44-
def __init__(self,
45-
log_artifacts: bool = False,
46-
rank_zero_only: bool = True,
47-
init_params: Optional[Dict[str, Any]] = None) -> None:
50+
def __init__(
51+
self,
52+
project: Optional[str] = None,
53+
group: Optional[str] = None,
54+
name: Optional[str] = None,
55+
entity: Optional[str] = None,
56+
tags: Optional[List[str]] = None,
57+
log_artifacts: bool = False,
58+
rank_zero_only: bool = True,
59+
init_kwargs: Optional[Dict[str, Any]] = None,
60+
) -> None:
4861
try:
4962
import wandb
5063
except ImportError as e:
@@ -60,11 +73,27 @@ def __init__(self,
6073
"restore from checkpoints."))
6174
self._enabled = (not rank_zero_only) or dist.get_global_rank() == 0
6275

76+
if init_kwargs is None:
77+
init_kwargs = {}
78+
79+
if project is not None:
80+
init_kwargs['project'] = project
81+
82+
if group is not None:
83+
init_kwargs['group'] = group
84+
85+
if name is not None:
86+
init_kwargs['name'] = name
87+
88+
if entity is not None:
89+
init_kwargs['entity'] = entity
90+
91+
if tags is not None:
92+
init_kwargs['tags'] = tags
93+
6394
self._rank_zero_only = rank_zero_only
6495
self._log_artifacts = log_artifacts
65-
if init_params is None:
66-
init_params = {}
67-
self._init_params = init_params
96+
self._init_kwargs = init_kwargs
6897
self._is_in_atexit = False
6998

7099
def _set_is_in_atexit(self):
@@ -98,17 +127,16 @@ def init(self, state: State, logger: Logger) -> None:
98127
del state # unused
99128

100129
# Use the logger run name if the name is not set.
101-
if "name" not in self._init_params or self._init_params["name"] is None:
102-
self._init_params["name"] = logger.run_name
130+
if "name" not in self._init_kwargs or self._init_kwargs["name"] is None:
131+
self._init_kwargs["name"] = logger.run_name
103132

104133
# Adjust name and group based on `rank_zero_only`.
105134
if not self._rank_zero_only:
106-
name = self._init_params["name"]
107-
group = self._init_params.get("group", None)
108-
self._init_params["name"] = f"{name} [RANK_{dist.get_global_rank()}]"
109-
self._init_params["group"] = group if group else name
135+
name = self._init_kwargs["name"]
136+
self._init_kwargs["name"] += f"-rank{dist.get_global_rank()}"
137+
self._init_kwargs["group"] = self._init_kwargs["group"] if "group" in self._init_kwargs else name
110138
if self._enabled:
111-
wandb.init(**self._init_params)
139+
wandb.init(**self._init_kwargs)
112140
atexit.register(self._set_is_in_atexit)
113141

114142
def log_file_artifact(self, state: State, log_level: LogLevel, artifact_name: str, file_path: pathlib.Path, *,

composer/trainer/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
def _raise_missing_argument_exception(arg_name: str):
5656
raise ValueError((f"{arg_name} is a required argument and must be specified when constructing the "
5757
f"{Trainer.__name__} or when calling {Trainer.__name__}.{Trainer.fit.__name__}(). "
58-
f"To fix, please specify `arg_name` via {Trainer.__name__}({arg_name}=...) or "
58+
f"To fix, please specify `{arg_name}` via {Trainer.__name__}({arg_name}=...) or "
5959
f"{Trainer.__name__}.{Trainer.fit.__name__}({arg_name}=...)."))
6060

6161

0 commit comments

Comments
 (0)