Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
8ae1a45
Treat test warnings as errors; test cleanup
ravi-mosaicml May 16, 2022
42556ad
Remove webdatasets from setup.py
ravi-mosaicml May 16, 2022
cf7854d
Ignore malformed yaml warnings coming from yahp
ravi-mosaicml May 17, 2022
8888a35
Merge branch 'dev' into test_warning_cleanup
ravi-mosaicml May 17, 2022
a71078b
Fixed filter
ravi-mosaicml May 17, 2022
b50912f
Disabling tqdm by default in tests
ravi-mosaicml May 17, 2022
0d91ba0
Fix pyright errors with uninitialized instance variables
ravi-mosaicml May 17, 2022
d69f473
Ignoring warnings for torch jit export
ravi-mosaicml May 17, 2022
be052ce
Fix ffcv tests?
ravi-mosaicml May 17, 2022
931b5da
Refactor Algorithms for AutoYAHP
ravi-mosaicml May 17, 2022
58f1383
Silence more warnings
ravi-mosaicml May 17, 2022
bbf1a70
Merge branch 'test_warning_cleanup' into remove_hparams_from_tests
ravi-mosaicml May 17, 2022
ad33e8d
Merge branch 'dev' into test_warning_cleanup
ravi-mosaicml May 17, 2022
e824353
Merge branch 'test_warning_cleanup' into remove_hparams_from_tests
ravi-mosaicml May 17, 2022
fb7ee90
Fix algorithm resumption
ravi-mosaicml May 17, 2022
f88ffec
Fix tests
ravi-mosaicml May 17, 2022
6c7b6f7
Merge branch 'test_warning_cleanup' into remove_hparams_from_tests
ravi-mosaicml May 17, 2022
6d2496e
Addressed some PR feedback
ravi-mosaicml May 17, 2022
19bf68e
Fix more deprecation warnings
ravi-mosaicml May 17, 2022
9f950f2
reportMissingImports=None; removed type ignores
ravi-mosaicml May 17, 2022
d5371a1
Fix warnings
ravi-mosaicml May 17, 2022
2de8ab6
Fix WandB Logger
ravi-mosaicml May 17, 2022
4637223
Merge branch 'dev' into test_warning_cleanup
ravi-mosaicml May 17, 2022
92e3823
Adjusted warning message
ravi-mosaicml May 17, 2022
52c7baa
Merge branch 'test_warning_cleanup' into remove_hparams_from_tests
ravi-mosaicml May 17, 2022
6b933c3
Fix more warnings
ravi-mosaicml May 18, 2022
6657bdc
Merge branch 'test_warning_cleanup' into remove_hparams_from_tests
ravi-mosaicml May 18, 2022
fb9b92f
Increase timeout
ravi-mosaicml May 18, 2022
32340a4
Merge branch 'test_warning_cleanup' into remove_hparams_from_tests
ravi-mosaicml May 18, 2022
cd25dc8
Force cpuonly in conda
ravi-mosaicml May 18, 2022
1f1a90f
Merge branch 'test_warning_cleanup' into remove_hparams_from_tests
ravi-mosaicml May 18, 2022
bee4f1e
Refactor Callbacks for AutoYAHP
ravi-mosaicml May 18, 2022
431fc43
Update pyproject.toml
ravi-mosaicml May 18, 2022
a50da20
Merge branch 'dev' into test_warning_cleanup
ravi-mosaicml May 20, 2022
915449a
Merge branch 'test_warning_cleanup' into remove_hparams_from_tests
ravi-mosaicml May 20, 2022
03b9b2d
Merge branch 'remove_hparams_from_tests' into refactor_callbacks_for_…
ravi-mosaicml May 20, 2022
1343f75
Merge branch 'dev' into remove_hparams_from_tests
ravi-mosaicml May 20, 2022
f166fe1
Merge branch 'remove_hparams_from_tests' into refactor_callbacks_for_…
ravi-mosaicml May 20, 2022
fc25f05
Merge branch 'dev' into remove_hparams_from_tests
ravi-mosaicml May 24, 2022
3d0d7f6
Fixed layer freezing in `test_algorithm_trains`
ravi-mosaicml May 24, 2022
51e5f81
Refactored alg settings
ravi-mosaicml May 24, 2022
9571d1d
Merge branch 'dev' into remove_hparams_from_tests
ravi-mosaicml May 24, 2022
12b817a
Merge branch 'remove_hparams_from_tests' into refactor_callbacks_for_…
ravi-mosaicml May 24, 2022
43f8d14
Refactored imports
ravi-mosaicml May 24, 2022
106c852
Fixed augmix and randaugment
ravi-mosaicml May 24, 2022
2246019
Fix more tests
ravi-mosaicml May 25, 2022
61d9a07
Merge branch 'dev' into remove_hparams_from_tests
ravi-mosaicml May 25, 2022
c6a0614
Merge branch 'remove_hparams_from_tests' into refactor_callbacks_for_…
ravi-mosaicml May 25, 2022
a0aa237
Merge branch 'dev' into refactor_callbacks_for_autoyahp
ravi-mosaicml May 25, 2022
337f283
Fix the memory monitor error message
ravi-mosaicml May 25, 2022
0253a1b
Fix more tests
ravi-mosaicml May 25, 2022
d4bd7ac
Fix tests
ravi-mosaicml May 25, 2022
12edac2
Fix tests again
ravi-mosaicml May 25, 2022
45ad601
Merge branch 'dev' into refactor_callbacks_for_autoyahp
ravi-mosaicml May 25, 2022
dfb346a
Address PR Feedback
ravi-mosaicml May 25, 2022
383fb46
Run the memory monitor on cpu; use a filterwarnings
ravi-mosaicml May 25, 2022
e08e9fa
Fix memory monitor test
ravi-mosaicml May 25, 2022
92564ba
Do the `test_hf_model` on gpu
ravi-mosaicml May 25, 2022
3745458
Fix tests hopefully for good
ravi-mosaicml May 25, 2022
7c62f94
Merge branch 'dev' into refactor_callbacks_for_autoyahp
ravi-mosaicml May 25, 2022
19ea309
Increase timeout
ravi-mosaicml May 26, 2022
443cdbb
Merge branch 'refactor_callbacks_for_autoyahp' of github.com:ravi-mos…
ravi-mosaicml May 26, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions composer/callbacks/callback_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,11 @@
"MemoryMonitorHparams",
"LRMonitorHparams",
"SpeedMonitorHparams",
"EarlyStopperHparams",
"CheckpointSaverHparams",
"ThresholdStopperHparams",
"MLPerfCallbackHparams",
"callback_registry",
]


Expand Down Expand Up @@ -321,3 +325,15 @@ def initialize_object(self) -> CheckpointSaver:
weights_only=self.weights_only,
num_checkpoints_to_keep=self.num_checkpoints_to_keep,
)


# Maps each callback's string key to the hparams class that builds it.
# NOTE(review): presumably used as the name -> hparams lookup when parsing configs — confirm against callers.
callback_registry = {
"checkpoint_saver": CheckpointSaverHparams,
"speed_monitor": SpeedMonitorHparams,
"lr_monitor": LRMonitorHparams,
"grad_monitor": GradMonitorHparams,
"memory_monitor": MemoryMonitorHparams,
"mlperf": MLPerfCallbackHparams,
"early_stopper": EarlyStopperHparams,
"threshold_stopper": ThresholdStopperHparams,
}
2 changes: 1 addition & 1 deletion composer/callbacks/checkpoint_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ class CheckpointSaver(Callback):
If ``False`` (the default), then the ``folder`` must not exist or must not contain checkpoints which may conflict
with the current run. Default: ``False``.

save_interval (:class:`.Time` | str | int | (:class:`.State`, :class:`.Event`) -> bool): A :class:`.Time`, time-string, integer (in epochs),
save_interval (Time | str | int | (State, Event) -> bool): A :class:`.Time`, time-string, integer (in epochs),
or a function that takes (state, event) and returns a boolean whether a checkpoint should be saved.

If an integer, checkpoints will be saved every n epochs.
Expand Down
28 changes: 11 additions & 17 deletions composer/callbacks/memory_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

"""Log memory usage during training."""
import logging
import warnings
from typing import Dict, Union

import torch.cuda
Expand Down Expand Up @@ -76,24 +77,24 @@ class MemoryMonitor(Callback):
Memory usage monitoring is only supported for the GPU devices.
"""

def __init__(self):
super().__init__()
log.info(
"Memory monitor just profiles the current GPU assuming that the memory footprint across GPUs is balanced.")
if torch.cuda.device_count() == 0:
log.warning("Memory monitor only works on GPU devices.")
def fit_start(self, state: State, logger: Logger) -> None:
# Emits a warning when the model's first parameter is not on a CUDA device.
# `logger` is not used here; the message goes through `warnings.warn`.
# TODO(ravi): move this check to Event.INIT after #1084 is merged.
# Not relying on `torch.cuda.is_available()` since the model could be on CPU.
# The device of the first parameter is taken as the device of the whole model.
model_device = next(state.model.parameters()).device

if model_device.type != 'cuda':
warnings.warn(f"The memory monitor only works on CUDA devices, but the model is on {model_device.type}.")

def after_train_batch(self, state: State, logger: Logger):
memory_report = {}

n_devices = torch.cuda.device_count()
if n_devices == 0:
model_device = next(state.model.parameters()).device
if model_device.type != 'cuda':
return

memory_report = _get_memory_report()

for mem_stat, val in memory_report.items():
logger.data_batch({'memory/{}'.format(mem_stat): val})
logger.data_batch({f'memory/{mem_stat}': val for (mem_stat, val) in memory_report.items()})


_MEMORY_STATS = {
Expand All @@ -108,15 +109,8 @@ def after_train_batch(self, state: State, logger: Logger):


def _get_memory_report() -> Dict[str, Union[int, float]]:
if not torch.cuda.is_available():
log.debug("Cuda is not available. The memory report will be empty.")
return {}
memory_stats = torch.cuda.memory_stats()

if len(memory_stats) == 0:
log.debug("No GPU memory was used, returning empty.")
return {}

# simplify the memory_stats
memory_report = {
name: memory_stats[torch_name] for (torch_name, name) in _MEMORY_STATS.items() if torch_name in memory_stats
Expand Down
16 changes: 8 additions & 8 deletions composer/callbacks/speed_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
class SpeedMonitor(Callback):
"""Logs the training throughput.

The training throughput in terms of number of samples per second is logged on the
:attr:`~composer.core.event.Event.BATCH_END` event if we have reached the ``window_size`` threshold. Per epoch
average throughput and wall clock train, validation, and total time is also logged on the
:attr:`~composer.core.event.Event.EPOCH_END` event.
The training throughput in terms of number of samples per second is logged on
the :attr:`~composer.core.event.Event.BATCH_END` event if we have reached the ``window_size`` threshold. Per epoch
average throughput and wall clock train, validation, and total time is also logged on
the :attr:`~composer.core.event.Event.EPOCH_END` event.

Example

Expand Down Expand Up @@ -49,11 +49,11 @@ class SpeedMonitor(Callback):
| Key | Logged data |
+=======================+=============================================================+
| | Rolling average (over ``window_size`` most recent |
| ``throughput/step`` | batches) of the number of samples processed per second |
| ``samples/step`` | batches) of the number of samples processed per second |
| | |
+-----------------------+-------------------------------------------------------------+
| | Number of samples processed per second (averaged over |
| ``throughput/epoch`` | an entire epoch) |
| ``samples/epoch`` | an entire epoch) |
+-----------------------+-------------------------------------------------------------+
|``wall_clock/train`` | Total elapsed training time |
+-----------------------+-------------------------------------------------------------+
Expand All @@ -63,8 +63,8 @@ class SpeedMonitor(Callback):
+-----------------------+-------------------------------------------------------------+

Args:
window_size (int, optional):
Number of batches to use for a rolling average of throughput. Default to 100.
window_size (int, optional): Number of batches to use for a rolling average of throughput.
Default to 100.
"""

def __init__(self, window_size: int = 100):
Expand Down
7 changes: 6 additions & 1 deletion composer/core/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import annotations

import abc
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

from composer.core.serializable import Serializable

Expand Down Expand Up @@ -88,6 +88,11 @@ class Callback(Serializable, abc.ABC):
trainer.engine.close()
"""

def __init__(self, *args: Any, **kwargs: Any) -> None:
# Accepts arbitrary positional and keyword arguments and discards them;
# no instance state is set here.
# Stub signature for pyright
del args, kwargs # unused
# NOTE(review): this trailing `pass` is redundant after the `del` statement above.
pass

def run_event(self, event: Event, state: State, logger: Logger) -> None:
"""This method is called by the engine on each event.

Expand Down
1 change: 1 addition & 0 deletions composer/loggers/file_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class FileLogger(LoggerDestination):
:attr:`~.LogLevel.EPOCH`, then the logfile will be flushed every n epochs. If
the ``log_level`` is :attr:`~.LogLevel.BATCH`, then the logfile will be
flushed every n batches. Default: ``100``.
overwrite (bool, optional): Whether to overwrite an existing logfile. (default: ``False``)
"""

def __init__(
Expand Down
12 changes: 6 additions & 6 deletions composer/loggers/logger_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,18 @@ def initialize_object(self) -> WandBLogger:
config_dict = self._flatten_dict(config_dict, prefix=[])

init_params = {
"project": self.project,
"name": self.name,
"group": self.group,
"entity": self.entity,
"tags": tags,
"config": config_dict,
}
init_params.update(self.extra_init_params)
return WandBLogger(
project=self.project,
group=self.group,
name=self.name,
entity=self.entity,
tags=tags,
log_artifacts=self.log_artifacts,
rank_zero_only=self.rank_zero_only,
init_params=init_params,
init_kwargs=init_params,
)

@classmethod
Expand Down
3 changes: 2 additions & 1 deletion composer/loggers/progress_bar_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ def should_log(state: State, log_level: LogLevel, console_log_level: LogLevel =
elif stream.lower() == "stderr":
stream = sys.stderr
else:
raise ValueError("Invalid stream option: Should be 'stdout', 'stderr', or a TextIO-like object.")
raise ValueError(
f"Invalid stream option: Should be 'stdout', 'stderr', or a TextIO-like object; got {stream}")
self.stream = stream

def log_data(self, state: State, log_level: LogLevel, data: Dict[str, Any]) -> None:
Expand Down
60 changes: 44 additions & 16 deletions composer/loggers/wandb_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import sys
import tempfile
import warnings
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

from composer.core.state import State
from composer.loggers.logger import Logger, LogLevel
Expand All @@ -28,6 +28,12 @@ class WandBLogger(LoggerDestination):
"""Log to Weights and Biases (https://wandb.ai/)

Args:
project (str, optional): WandB project name.
group (str, optional): WandB group name.
name (str, optional): WandB run name.
If not specified, the :attr:`.Logger.run_name` will be used.
entity (str, optional): WandB entity name.
tags (List[str], optional): WandB tags.
log_artifacts (bool, optional): Whether to log
`artifacts <https://docs.wandb.ai/ref/python/artifact>`_ (Default: ``False``).
rank_zero_only (bool, optional): Whether to log only on the rank-zero process.
Expand All @@ -36,15 +42,22 @@ class WandBLogger(LoggerDestination):
stored, which may discard pertinent information. For example, when using
Deepspeed ZeRO, it would be impossible to restore from checkpoints without
artifacts from all ranks (default: ``False``).
init_params (Dict[str, Any], optional): Parameters to pass into
init_kwargs (Dict[str, Any], optional): Any additional init kwargs to pass into
``wandb.init`` (see
`WandB documentation <https://docs.wandb.ai/ref/python/init>`_).
"""

def __init__(self,
log_artifacts: bool = False,
rank_zero_only: bool = True,
init_params: Optional[Dict[str, Any]] = None) -> None:
def __init__(
self,
project: Optional[str] = None,
group: Optional[str] = None,
name: Optional[str] = None,
entity: Optional[str] = None,
tags: Optional[List[str]] = None,
log_artifacts: bool = False,
rank_zero_only: bool = True,
init_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
try:
import wandb
except ImportError as e:
Expand All @@ -60,11 +73,27 @@ def __init__(self,
"restore from checkpoints."))
self._enabled = (not rank_zero_only) or dist.get_global_rank() == 0

if init_kwargs is None:
init_kwargs = {}

if project is not None:
init_kwargs['project'] = project

if group is not None:
init_kwargs['group'] = group

if name is not None:
init_kwargs['name'] = name

if entity is not None:
init_kwargs['entity'] = entity

if tags is not None:
init_kwargs['tags'] = tags

self._rank_zero_only = rank_zero_only
self._log_artifacts = log_artifacts
if init_params is None:
init_params = {}
self._init_params = init_params
self._init_kwargs = init_kwargs
self._is_in_atexit = False

def _set_is_in_atexit(self):
Expand Down Expand Up @@ -98,17 +127,16 @@ def init(self, state: State, logger: Logger) -> None:
del state # unused

# Use the logger run name if the name is not set.
if "name" not in self._init_params or self._init_params["name"] is None:
self._init_params["name"] = logger.run_name
if "name" not in self._init_kwargs or self._init_kwargs["name"] is None:
self._init_kwargs["name"] = logger.run_name

# Adjust name and group based on `rank_zero_only`.
if not self._rank_zero_only:
name = self._init_params["name"]
group = self._init_params.get("group", None)
self._init_params["name"] = f"{name} [RANK_{dist.get_global_rank()}]"
self._init_params["group"] = group if group else name
name = self._init_kwargs["name"]
self._init_kwargs["name"] += f"-rank{dist.get_global_rank()}"
self._init_kwargs["group"] = self._init_kwargs["group"] if "group" in self._init_kwargs else name
if self._enabled:
wandb.init(**self._init_params)
wandb.init(**self._init_kwargs)
atexit.register(self._set_is_in_atexit)

def log_file_artifact(self, state: State, log_level: LogLevel, artifact_name: str, file_path: pathlib.Path, *,
Expand Down
2 changes: 1 addition & 1 deletion composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
def _raise_missing_argument_exception(arg_name: str):
raise ValueError((f"{arg_name} is a required argument and must be specified when constructing the "
f"{Trainer.__name__} or when calling {Trainer.__name__}.{Trainer.fit.__name__}(). "
f"To fix, please specify `arg_name` via {Trainer.__name__}({arg_name}=...) or "
f"To fix, please specify `{arg_name}` via {Trainer.__name__}({arg_name}=...) or "
f"{Trainer.__name__}.{Trainer.fit.__name__}({arg_name}=...)."))


Expand Down
4 changes: 1 addition & 3 deletions examples/checkpoint_with_wandb.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
# The project name must be deterministic, so we can restore from it
wandb_logger = WandBLogger(
log_artifacts=True,
init_params={
'project': 'my-wandb-project-name',
},
project='my-wandb-project-name',
)

# Configure the trainer -- here, we train a simple MNIST classifier
Expand Down
Loading