From 32eead6ae96bcb3b90803969527205416214e322 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Thu, 2 Feb 2023 01:15:09 +0000 Subject: [PATCH 1/5] Initial switch to a more memory efficient implementation --- composer/algorithms/ema/ema.py | 145 ++++++++++++++++++++------------- tests/algorithms/test_ema.py | 38 +++++---- 2 files changed, 111 insertions(+), 72 deletions(-) diff --git a/composer/algorithms/ema/ema.py b/composer/algorithms/ema/ema.py index 8b0b2628e9..5b277164fa 100644 --- a/composer/algorithms/ema/ema.py +++ b/composer/algorithms/ema/ema.py @@ -8,7 +8,7 @@ import copy import itertools import logging -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union import torch @@ -21,7 +21,7 @@ __all__ = ['EMA', 'compute_ema'] -def compute_ema(model: torch.nn.Module, ema_model: torch.nn.Module, smoothing: float = 0.99): +def compute_ema(model: torch.nn.Module, ema_model: Union[torch.nn.Module, EMAParameters], smoothing: float = 0.99): r"""Updates the weights of ``ema_model`` to be closer to the weights of ``model`` according to an exponential weighted average. Weights are updated according to @@ -42,7 +42,7 @@ def compute_ema(model: torch.nn.Module, ema_model: torch.nn.Module, smoothing: f Args: model (torch.nn.Module): the model containing the latest weights to use to update the moving average weights. - ema_model (torch.nn.Module): the model containing the moving average weights to be updated. + ema_model (torch.nn.Module, EMAParameters): the model containing the moving average weights to be updated. smoothing (float, optional): the coefficient representing the degree to which older observations are kept. Must be in the interval :math:`(0, 1)`. Default: ``0.99``. @@ -56,16 +56,28 @@ def compute_ema(model: torch.nn.Module, ema_model: torch.nn.Module, smoothing: f cf.compute_ema(model, ema_model, smoothing=0.9) """ with torch.no_grad(): - model_params = itertools.chain(model.parameters(), model.buffers()) - ema_model_params = itertools.chain(ema_model.parameters(), ema_model.buffers()) - - for ema_param, model_param in zip(ema_model_params, model_params): - model_param = model_param.detach() - ema_param.copy_(ema_param * smoothing + (1. - smoothing) * model_param) + # If the ema model is a pytorch module, can just use the state_dict + if isinstance(ema_model, torch.nn.Module): + ema_params = ema_model.state_dict() + for name, param in itertools.chain(model.named_parameters(), model.named_buffers()): + if name in ema_params: + ema_params[name].copy_(ema_params[name] * smoothing + param.data * (1. - smoothing)) + # Otherwise, the ema model needs to define the named_parameters and named_buffers dictionaries + # These should contain the parameters and buffers to average. + elif isinstance(ema_model, EMAParameters): + ema_parameters = ema_model.named_parameters_dict + ema_buffers = ema_model.named_buffers_dict + for name, param in itertools.chain(model.named_parameters(), model.named_buffers()): + if name in ema_parameters: + ema_parameters[name].copy_(ema_parameters[name] * smoothing + param.data * (1. - smoothing)) + if name in ema_buffers: + ema_buffers[name].copy_(ema_buffers[name] * smoothing + param.data * (1. - smoothing)) + else: + raise ValueError("ema_model must be a torch.nn.Module or EMAParameters") class EMA(Algorithm): - r"""Maintains a shadow model with weights that follow the exponential moving average of the trained model weights. + r"""Maintains a set of weights that follow the exponential moving average of the training model weights. Weights are updated according to @@ -78,7 +90,7 @@ class EMA(Algorithm): smoothing = \exp\left[-\frac{\log(2)}{t_{1/2}}\right] Model evaluation is done with the moving average weights, which can result in better generalization. Because of the - shadow models, EMA triples the model's memory consumption. Note that this does not mean that the total memory + ema weights, EMA can double the model's memory consumption. Note that this does not mean that the total memory required doubles, since stored activations and the optimizer state are not duplicated. EMA also uses a small amount of extra compute to update the moving average weights. @@ -124,10 +136,9 @@ def __init__(self, ema_start: str = '0.0dur', update_interval: Optional[str] = None): self.ema_model = None - self.training_model = None self.ema_weights_active = False self.ema_started = False - self.serialized_attributes = ['ema_model', 'training_model', 'ema_weights_active', 'ema_started'] + self.serialized_attributes = ['ema_model', 'ema_weights_active', 'ema_started'] # Verify that either half_life or smoothing has been specified if half_life is None and smoothing is None: @@ -191,6 +202,16 @@ def _should_start(self, state: State) -> bool: return should_start + def _ensure_training_weights_active(self, state: State): + if self.ema_weights_active is True: + _swap_params(model=state.model, ema_parameters=self.ema_model) + self.ema_weights_active = False + + def _ensure_ema_weights_active(self, state: State): + if self.ema_weights_active is False: + _swap_params(model=state.model, ema_parameters=self.ema_model) + self.ema_weights_active = True + def match(self, event: Event, state: State) -> bool: # Always run on init if event == Event.INIT: @@ -198,8 +219,7 @@ def match(self, event: Event, state: State) -> bool: # Check if ema should start running, and if so reinitialize models if event == self.update_event and self.ema_started is False and self._should_start(state): - self.ema_model = copy.deepcopy(state.model) - self.training_model = copy.deepcopy(state.model) + self.ema_model = EMAParameters(state.model) self.ema_started = True # Match on checkpointing events if a checkpoint is to be saved @@ -225,21 +245,17 @@ def apply(self, event: Event, state: State, logger: Logger) -> None: if event == Event.INIT: # Create the models so that the checkpoints can be loaded - self.ema_model = copy.deepcopy(state.model) - self.training_model = copy.deepcopy(state.model) + self.ema_model = EMAParameters(state.model) assert self.ema_model is not None - assert self.training_model is not None if event == Event.FIT_START: # Ensure that params are on the right device if a checkpoint has been loaded - _move_params_to_device(model=self.ema_model, destination_model=state.model) - _move_params_to_device(model=self.training_model, destination_model=state.model) + _move_params_to_device(ema_parameters=self.ema_model, destination_model=state.model) if event == Event.BATCH_START and self.ema_weights_active: # Ensure the model being trained has the correct weights - _copy_params(source_model=self.training_model, destination_model=state.model) - self.ema_weights_active = False + self._ensure_training_weights_active(state) if event in [Event.BATCH_END, Event.EPOCH_END]: # Update the ema model @@ -247,27 +263,24 @@ def apply(self, event: Event, state: State, logger: Logger) -> None: if event == Event.EVAL_START and self.ema_weights_active is False: # Swap out the training model for the ema model in state - _copy_params(source_model=state.model, destination_model=self.training_model) - _copy_params(source_model=self.ema_model, destination_model=state.model) - self.ema_weights_active = True + self._ensure_ema_weights_active(state) if event == Event.EVAL_END: # Swap out the ema model for the training model in state - _copy_params(source_model=self.training_model, destination_model=state.model) - self.ema_weights_active = False + self._ensure_training_weights_active(state) - if event in self.checkpoint_events and self.ema_weights_active is False: + if event in self.checkpoint_events: # Swap the training model out for the ema model for checkpointing - _copy_params(source_model=state.model, destination_model=self.training_model) - _copy_params(source_model=self.ema_model, destination_model=state.model) - self.ema_weights_active = True + self._ensure_ema_weights_active(state) def state_dict(self) -> Dict[str, Any]: state_dict = super().state_dict() for attribute_name in self.serialized_attributes: - if attribute_name in ['ema_model', 'training_model']: - model = getattr(self, attribute_name) - state_dict[attribute_name] = model.state_dict() + if attribute_name == 'ema_model': + ema_model = getattr(self, attribute_name) + state_dict[attribute_name] = {} + state_dict[attribute_name]['named_parameters_dict'] = ema_model.named_parameters_dict + state_dict[attribute_name]['named_buffers_dict'] = ema_model.named_buffers_dict else: state_dict[attribute_name] = getattr(self, attribute_name) return state_dict @@ -275,33 +288,53 @@ def state_dict(self) -> Dict[str, Any]: def load_state_dict(self, state: Dict[str, Any], strict: bool = False): for attribute_name, serialized_value in state.items(): if attribute_name != 'repr': # skip attribute added by parent class - if attribute_name == 'ema_model' and self.ema_model is not None: - self.ema_model.load_state_dict(serialized_value) - elif attribute_name == 'training_model' and self.training_model is not None: - self.training_model.load_state_dict(serialized_value) + if attribute_name == 'ema_model': + self.ema_model = EMAParameters(None) + self.ema_model.named_parameters_dict = serialized_value['named_parameters_dict'] + self.ema_model.named_buffers_dict = serialized_value['named_buffers_dict'] else: setattr(self, attribute_name, serialized_value) -def _copy_params(source_model: torch.nn.Module, destination_model: torch.nn.Module): - """Copies parameters and buffers from ``source_model`` to ``destination_model``.""" - with torch.no_grad(): - source_params = itertools.chain(source_model.parameters(), source_model.buffers()) - destination_params = itertools.chain(destination_model.parameters(), destination_model.buffers()) +class EMAParameters: + """A class that stores the parameters and buffers of a model needed for averaging.""" + + def __init__(self, model: Union[None, torch.nn.Module]): + if model is not None: + # Copy the trainable parameters and buffers. + self.named_parameters_dict = {name: param.data.clone() for name, param in model.named_parameters() if param.requires_grad} + self.named_buffers_dict = {name: buffer.data.clone() for name, buffer in model.named_buffers()} + else: + # Empty storage + self.named_parameters_dict = {} + self.named_buffers_dict = {} - for source_param, destination_param in zip(source_params, destination_params): - destination_param.data = source_param.data + def named_parameters(self): + return self.named_parameters_dict.items() + def named_buffers(self): + return self.named_buffers_dict.items() -def _move_params_to_device(model: torch.nn.Module, destination_model: torch.nn.Module): - """Ensures the parameters of a model are on the same device as a destination model.""" + +def _swap_params(model: torch.nn.Module, ema_parameters: EMAParameters): + """Swaps the parameters and buffers of a model with those in ema_parameters.""" with torch.no_grad(): - destination_params = destination_model.parameters() - params = model.parameters() - for s, d in zip(params, destination_params): - s.to(d.device) - - destination_buffers = destination_model.buffers() - buffers = model.buffers() - for s, d in zip(buffers, destination_buffers): - s.to(d.device) + ema_params = ema_parameters.named_parameters_dict + ema_buffers = ema_parameters.named_buffers_dict + + for name, param in model.named_parameters(): + if name in ema_params: + param.data, ema_params[name] = ema_params[name], param.data + + for name, buffer in model.named_buffers(): + buffer.data, ema_buffers[name] = ema_buffers[name], buffer.data + + +def _move_params_to_device(ema_parameters: EMAParameters, destination_model: torch.nn.Module): + """Moves the ema parameters and buffers to the device of a destination model.""" + model_state_dict = destination_model.state_dict() + for name, param in ema_parameters.named_parameters_dict.items(): + ema_parameters.named_parameters_dict[name] = param.to(model_state_dict[name].device) + + for name, buffer in ema_parameters.named_buffers_dict.items(): + ema_parameters.named_buffers_dict[name] = buffer.to(model_state_dict[name].device) diff --git a/tests/algorithms/test_ema.py b/tests/algorithms/test_ema.py index bc7f64bcce..46c62059b3 100644 --- a/tests/algorithms/test_ema.py +++ b/tests/algorithms/test_ema.py @@ -9,29 +9,35 @@ import torch from composer.algorithms import EMA -from composer.algorithms.ema.ema import compute_ema +from composer.algorithms.ema.ema import compute_ema, EMAParameters from composer.core import Event, Time, Timestamp, TimeUnit from tests.common import SimpleConvModel, SimpleTransformerClassifier from tests.common.models import configure_tiny_bert_hf_model def validate_ema(model, original_model, ema_model, smoothing): - model_params = itertools.chain(model.parameters(), model.buffers()) - original_params = itertools.chain(original_model.parameters(), original_model.buffers()) - ema_params = itertools.chain(ema_model.parameters(), ema_model.buffers()) + model_params, model_buffers = dict(model.named_parameters()), dict(model.named_buffers()) + original_params, original_buffers = dict(original_model.named_parameters()), dict(original_model.named_buffers()) + ema_params, ema_buffers = dict(ema_model.named_parameters()), dict(ema_model.named_buffers()) - for model_param, original_param, ema_param in zip(model_params, original_params, ema_params): - new_param = (original_param * smoothing + (1. - smoothing) * model_param).type(ema_param.data.dtype) - torch.testing.assert_close(ema_param.data, new_param) + for name, param in model_params.items(): + new_param = (original_params[name] * smoothing + (1. - smoothing) * param) + torch.testing.assert_close(ema_params[name].data, new_param) + + for name, buffer in model_buffers.items(): + new_buffer = (original_buffers[name] * smoothing + (1. - smoothing) * buffer).type(ema_buffers[name].data.dtype) + torch.testing.assert_close(ema_buffers[name].data, new_buffer) def validate_model(model1, model2): - model1_params = itertools.chain(model1.parameters(), model1.buffers()) - model2_params = itertools.chain(model2.parameters(), model2.buffers()) + model1_params, model1_buffers = dict(model1.named_parameters()), dict(model1.named_buffers()) + model2_params, model2_buffers = dict(model2.named_parameters()), dict(model2.named_buffers()) - for model1_param, model2_param in zip(model1_params, model2_params): - torch.testing.assert_close(model1_param.data, model2_param) + for name, param in model1_params.items(): + torch.testing.assert_close(model1_params[name].data, model2_params[name].data) + for name, buffer in model1_buffers.items(): + torch.testing.assert_close(model1_buffers[name].data, model2_buffers[name].data) @pytest.mark.parametrize('smoothing', [0, 0.5, 0.99, 1]) @pytest.mark.parametrize('model_cls', [(SimpleConvModel), (SimpleTransformerClassifier), @@ -43,7 +49,6 @@ def test_ema(smoothing, model_cls): compute_ema(model=model, ema_model=ema_model, smoothing=smoothing) validate_ema(model, original_model, ema_model, smoothing) - # params = [(half_life, update_interval)] @pytest.mark.parametrize('params', [{ 'half_life': '10ba', @@ -72,8 +77,7 @@ def test_ema_algorithm(params, model_cls, minimal_state, empty_logger): state.batch = (input, torch.Tensor()) # Start EMA - algorithm.ema_model = copy.deepcopy(state.model) - algorithm.training_model = copy.deepcopy(state.model) + algorithm.ema_model = EMAParameters(state.model) # Check if ema correctly calculated smoothing update_interval = Time.from_timestring(params['update_interval']) if 'half_life' in params: @@ -84,6 +88,7 @@ def test_ema_algorithm(params, model_cls, minimal_state, empty_logger): # Fake a training update by replacing state.model after ema grabbed it. original_model = copy.deepcopy(state.model) state.model = model_cls() + training_updated_model = copy.deepcopy(state.model) # Do the EMA update state.timestamp = Timestamp() if update_interval.unit == TimeUnit.BATCH: @@ -96,9 +101,10 @@ def test_ema_algorithm(params, model_cls, minimal_state, empty_logger): raise ValueError(f'Invalid time string for parameter half_life') # Check if EMA correctly computed the average. validate_ema(state.model, original_model, algorithm.ema_model, algorithm.smoothing) + ema_updated_model = copy.deepcopy(algorithm.ema_model) # Check if the EMA model is swapped in for testing algorithm.apply(Event.EVAL_START, state, empty_logger) - validate_model(state.model, algorithm.ema_model) + validate_model(state.model, ema_updated_model) # Check if the training model is swapped back in for training algorithm.apply(Event.EVAL_END, state, empty_logger) - validate_model(state.model, algorithm.training_model) + validate_model(state.model, training_updated_model) From d08fd93e5719344ac50e2e790ffdf35a187dec19 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Fri, 3 Feb 2023 19:09:10 +0000 Subject: [PATCH 2/5] Update readme and add methods to extract ema/training weights --- composer/algorithms/ema/README.md | 4 +- composer/algorithms/ema/ema.py | 85 ++++++++++++++++++++++--------- 2 files changed, 64 insertions(+), 25 deletions(-) diff --git a/composer/algorithms/ema/README.md b/composer/algorithms/ema/README.md index d716fbbb94..5cb414bd0a 100644 --- a/composer/algorithms/ema/README.md +++ b/composer/algorithms/ema/README.md @@ -66,7 +66,7 @@ model = ema.ema_model ### Implementation Details -Because EMA needs to maintain a copy of the model's (averaged) weights, it requires a bit more on-device memory. In the functional implementation, the amount of extra memory is 2x the size of the model. In the composer trainer implementation, it is 3x the size of the model to allow for swapping the training and evaluation models. In practice, the extra memory used is small relative to the total amount of memory used, as activations and optimizer state are not duplicated. +Because EMA needs to maintain a copy of the model's (averaged) weights, it requires a bit more on-device memory. The amount of extra memory used is equal to the size of the model's trainable parameters and buffers. In practice, the extra memory used is small relative to the total amount of memory used, as activations and optimizer state are not duplicated. EMA also uses a bit of extra compute to calculate the moving average. This can lead to a small slowdown. The extra compute can be reduced by not computing the moving average every iteration. In the composer trainer implementation this can be done by using a larger `update_interval`. In practice we find that as long as `half_life` is much larger than `update_interval`, increasing `update_interval` does not have much effect on generalization performance. @@ -113,7 +113,7 @@ To use this, `half_life` should be set to `half_life=None`, and the value of smo > ❗ Evaluation should not be done with the training model > -> Evaluation should be done with the `ema_model` in the functional impementation as this is the model containing the averaged parameters. The ema model can be accessed after training from the `EMA` object via `model = ema.ema_model` in the composer trainer implementation. Similarly, the model without ema applied (the training model) can be accessed via `model=ema.training_model`. By default, when saving checkpoints with the `CheckpointSaver` callback or through trainer arguments the weights saved will be the ema model weights. An exception is if saving is done by explicitly calling `trainer.save_checkpoint()` which will result in the training model weights being saved as `state.model`. +> Evaluation should be done with the `ema_model` in the functional impementation as this is the model containing the averaged parameters. The ema model can be accessed after training from the `EMA` object via `model = ema.get_ema_model(model)` in the composer trainer implementation. This replaces the parameters of the supplied model with the ema_weights unless composer's model already contains them. Similarly, the model without ema applied (the training model) can be accessed via `model=ema.get_training_model(model)`. By default, when saving checkpoints with the `CheckpointSaver` callback or through trainer arguments the weights saved will be the ema model weights. An exception is if saving is done by explicitly calling `trainer.save_checkpoint()` which will result in the training model weights being saved as `state.model`. ## Attribution diff --git a/composer/algorithms/ema/ema.py b/composer/algorithms/ema/ema.py index 5b277164fa..2cf63a5908 100644 --- a/composer/algorithms/ema/ema.py +++ b/composer/algorithms/ema/ema.py @@ -5,7 +5,6 @@ from __future__ import annotations -import copy import itertools import logging from typing import Any, Dict, Optional, Union @@ -203,13 +202,13 @@ def _should_start(self, state: State) -> bool: return should_start def _ensure_training_weights_active(self, state: State): - if self.ema_weights_active is True: - _swap_params(model=state.model, ema_parameters=self.ema_model) + if self.ema_weights_active is True and self.ema_model is not None: + self.ema_model.swap_params(model=state.model) self.ema_weights_active = False def _ensure_ema_weights_active(self, state: State): - if self.ema_weights_active is False: - _swap_params(model=state.model, ema_parameters=self.ema_model) + if self.ema_weights_active is False and self.ema_model is not None: + self.ema_model.swap_params(model=state.model) self.ema_weights_active = True def match(self, event: Event, state: State) -> bool: @@ -251,7 +250,7 @@ def apply(self, event: Event, state: State, logger: Logger) -> None: if event == Event.FIT_START: # Ensure that params are on the right device if a checkpoint has been loaded - _move_params_to_device(ema_parameters=self.ema_model, destination_model=state.model) + self.ema_model.move_params_to_device(destination_model=state.model) if event == Event.BATCH_START and self.ema_weights_active: # Ensure the model being trained has the correct weights @@ -295,6 +294,38 @@ def load_state_dict(self, state: Dict[str, Any], strict: bool = False): else: setattr(self, attribute_name, serialized_value) + def get_ema_model(self, model: torch.nn.Module) -> torch.nn.Module: + """Replaces the parameters of the supplied model with the ema parameters if they are not already active. + + Args: + model (torch.nn.Module): The model to replace the parameters of. + + Returns: + torch.nn.Module: The model with the ema parameters. + """ + assert self.ema_model is not None + # Ensure that self.ema_model contains the ema weights. If not raise an error. + if self.ema_weights_active == True: + raise ValueError('The ema weight are currently contained in the composer model.') + self.ema_model.transfer_ema_params(model=model) + return model + + def get_training_model(self, model: torch.nn.Module) -> torch.nn.Module: + """Replaces the parameters of the supplied model with the training parameters if they are not already active. + + Args: + model (torch.nn.Module): The model to replace the parameters of. + + Returns: + torch.nn.Module: The model with the training parameters. + """ + assert self.ema_model is not None + # Ensure that self.ema_model contains the training weights. If not raise an error. + if self.ema_weights_active == False: + raise ValueError('The training weights are currently contained in the composer model.') + self.ema_model.transfer_ema_params(model=model) + return model + class EMAParameters: """A class that stores the parameters and buffers of a model needed for averaging.""" @@ -315,26 +346,34 @@ def named_parameters(self): def named_buffers(self): return self.named_buffers_dict.items() + def swap_params(self, model: torch.nn.Module): + """Swaps the parameters and buffers of a model with the ema parameters.""" + with torch.no_grad(): + ema_params = self.named_parameters_dict + ema_buffers = self.named_buffers_dict -def _swap_params(model: torch.nn.Module, ema_parameters: EMAParameters): - """Swaps the parameters and buffers of a model with those in ema_parameters.""" - with torch.no_grad(): - ema_params = ema_parameters.named_parameters_dict - ema_buffers = ema_parameters.named_buffers_dict + for name, param in model.named_parameters(): + if name in ema_params: + param.data, ema_params[name] = ema_params[name], param.data - for name, param in model.named_parameters(): - if name in ema_params: - param.data, ema_params[name] = ema_params[name], param.data + for name, buffer in model.named_buffers(): + buffer.data, ema_buffers[name] = ema_buffers[name], buffer.data - for name, buffer in model.named_buffers(): - buffer.data, ema_buffers[name] = ema_buffers[name], buffer.data + def transfer_ema_params(self, model: torch.nn.Module): + """Transfers the parameters and buffers from the ema model to the supplied model.""" + with torch.no_grad(): + for name, param in model.named_parameters(): + if name in self.named_parameters_dict: + param.data = self.named_parameters_dict[name] + for name, buffer in model.named_buffers(): + buffer.data = self.named_buffers_dict[name] -def _move_params_to_device(ema_parameters: EMAParameters, destination_model: torch.nn.Module): - """Moves the ema parameters and buffers to the device of a destination model.""" - model_state_dict = destination_model.state_dict() - for name, param in ema_parameters.named_parameters_dict.items(): - ema_parameters.named_parameters_dict[name] = param.to(model_state_dict[name].device) + def move_params_to_device(self, destination_model: torch.nn.Module): + """Moves the ema parameters and buffers to the device of a destination model.""" + model_state_dict = destination_model.state_dict() + for name, param in self.named_parameters_dict.items(): + self.named_parameters_dict[name] = param.to(model_state_dict[name].device) - for name, buffer in ema_parameters.named_buffers_dict.items(): - ema_parameters.named_buffers_dict[name] = buffer.to(model_state_dict[name].device) + for name, buffer in self.named_buffers_dict.items(): + self.named_buffers_dict[name] = buffer.to(model_state_dict[name].device) From 2d109172ff3e8fc3a9d05068a12d69ad2637af16 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Fri, 3 Feb 2023 19:10:12 +0000 Subject: [PATCH 3/5] Some code cleanup --- composer/algorithms/ema/ema.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/composer/algorithms/ema/ema.py b/composer/algorithms/ema/ema.py index 2cf63a5908..8188468ae3 100644 --- a/composer/algorithms/ema/ema.py +++ b/composer/algorithms/ema/ema.py @@ -333,7 +333,9 @@ class EMAParameters: def __init__(self, model: Union[None, torch.nn.Module]): if model is not None: # Copy the trainable parameters and buffers. - self.named_parameters_dict = {name: param.data.clone() for name, param in model.named_parameters() if param.requires_grad} + self.named_parameters_dict = { + name: param.data.clone() for name, param in model.named_parameters() if param.requires_grad + } self.named_buffers_dict = {name: buffer.data.clone() for name, buffer in model.named_buffers()} else: # Empty storage From 673e0d7a5d8dcbe54bb64453fef03bcf618515e0 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Fri, 3 Feb 2023 19:40:43 +0000 Subject: [PATCH 4/5] Formatting --- composer/algorithms/ema/ema.py | 2 +- tests/algorithms/test_ema.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/composer/algorithms/ema/ema.py b/composer/algorithms/ema/ema.py index 8188468ae3..ba3abac7d5 100644 --- a/composer/algorithms/ema/ema.py +++ b/composer/algorithms/ema/ema.py @@ -72,7 +72,7 @@ def compute_ema(model: torch.nn.Module, ema_model: Union[torch.nn.Module, EMAPar if name in ema_buffers: ema_buffers[name].copy_(ema_buffers[name] * smoothing + param.data * (1. - smoothing)) else: - raise ValueError("ema_model must be a torch.nn.Module or EMAParameters") + raise ValueError('ema_model must be a torch.nn.Module or EMAParameters') class EMA(Algorithm): diff --git a/tests/algorithms/test_ema.py b/tests/algorithms/test_ema.py index 46c62059b3..a09bbc5a57 100644 --- a/tests/algorithms/test_ema.py +++ b/tests/algorithms/test_ema.py @@ -2,14 +2,13 @@ # SPDX-License-Identifier: Apache-2.0 import copy -import itertools import numpy as np import pytest import torch from composer.algorithms import EMA -from composer.algorithms.ema.ema import compute_ema, EMAParameters +from composer.algorithms.ema.ema import EMAParameters, compute_ema from composer.core import Event, Time, Timestamp, TimeUnit from tests.common import SimpleConvModel, SimpleTransformerClassifier from tests.common.models import configure_tiny_bert_hf_model @@ -33,12 +32,13 @@ def validate_model(model1, model2): model1_params, model1_buffers = dict(model1.named_parameters()), dict(model1.named_buffers()) model2_params, model2_buffers = dict(model2.named_parameters()), dict(model2.named_buffers()) - for name, param in model1_params.items(): + for name, _ in model1_params.items(): torch.testing.assert_close(model1_params[name].data, model2_params[name].data) - for name, buffer in model1_buffers.items(): + for name, _ in model1_buffers.items(): torch.testing.assert_close(model1_buffers[name].data, model2_buffers[name].data) + @pytest.mark.parametrize('smoothing', [0, 0.5, 0.99, 1]) @pytest.mark.parametrize('model_cls', [(SimpleConvModel), (SimpleTransformerClassifier), (configure_tiny_bert_hf_model)]) @@ -49,6 +49,7 @@ def test_ema(smoothing, model_cls): compute_ema(model=model, ema_model=ema_model, smoothing=smoothing) validate_ema(model, original_model, ema_model, smoothing) + # params = [(half_life, update_interval)] @pytest.mark.parametrize('params', [{ 'half_life': '10ba', From 08f13476500cea05ef3d84c1295f91d3595eccf8 Mon Sep 17 00:00:00 2001 From: Cory Stephenson Date: Wed, 8 Feb 2023 13:16:45 -0800 Subject: [PATCH 5/5] Fix lint --- composer/algorithms/ema/ema.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/composer/algorithms/ema/ema.py b/composer/algorithms/ema/ema.py index 6f126afc63..442a74237d 100644 --- a/composer/algorithms/ema/ema.py +++ b/composer/algorithms/ema/ema.py @@ -20,7 +20,9 @@ __all__ = ['EMA', 'compute_ema'] -def compute_ema(model: torch.nn.Module, ema_model: Union[torch.nn.Module, EMAParameters], smoothing: float = 0.99) -> None: +def compute_ema(model: torch.nn.Module, + ema_model: Union[torch.nn.Module, EMAParameters], + smoothing: float = 0.99) -> None: r"""Updates the weights of ``ema_model`` to be closer to the weights of ``model`` according to an exponential weighted average. Weights are updated according to