
Commit 7ad634e

coryMosaicML authored and Bandish Shah committed
Quality of life updates to EMA (#1524)
1 parent 57655ca commit 7ad634e

6 files changed: +186 −149 lines changed


composer/algorithms/ema/README.md

Lines changed: 12 additions & 6 deletions
```diff
@@ -73,14 +73,20 @@ EMA also uses a bit of extra compute to calculate the moving average. This can l
 
 ## Suggested Hyperparameters
 
-The Composer Trainer implementation of EMA has two hyperparameters:
+The Composer Trainer implementation of EMA has several hyperparameters:
 
-- `half_life` - The half life for terms in the average. A longer half life means old information is remembered longer, a shorter half life means old information is discarded sooner.
-- `update_interval` - The period at which updates to the moving average are computed. A longer update interval means that updates are computed less frequently.
+- `half_life` - The half life for terms in the average. A longer half life means old information is remembered longer, a shorter half life means old information is discarded sooner. Defaults to `'1000ba'`.
+- `update_interval` - The period at which updates to the moving average are computed. A longer update interval means that updates are computed less frequently. If left unspecified, this defaults to `1` in the units of `half_life`, or `1ba` if using `smoothing`.
+- `ema_start` - The amount of training completed before EMA is applied. The default value is `'0.0dur'`, which starts EMA at the start of training.
 
-A good typical starting value for `half_life` is `half_life="100ba"`, for a half life of 100 batches. At the same time, `update_interval` can be left unspecified, which will default to `update_interval="1ba"`, or set to a larger value such as `update_interval="10ba"` to improve runtime. Shorter update intervals typically result in better generalization performance at the cost of somewhat increased runtime.
+A good typical starting value for `half_life` is `half_life="1000ba"`, for a half life of 1000 batches. At the same time, `update_interval` can be left unspecified, which will default to `update_interval="1ba"`, or set to a larger value such as `update_interval="10ba"` to improve runtime. Shorter update intervals typically result in better generalization performance at the cost of somewhat increased runtime.
+
+For compatibility with other implementations, there is also an option to specify the value of `smoothing` directly.
+
+- `smoothing` - The coefficient representing the degree to which older observations are kept. The default (unspecified) value is `None`. Should only be used if `half_life` is not used.
+
+To use this, `half_life` should be set to `half_life=None` and the value of `smoothing` given instead. This value is not modified when `update_interval` is changed, so changing `update_interval` while using `smoothing` will change the time scale of the average.
 
-Our implementation of EMA also provides the option to use the EMA weights as the training weights, which can be enabled by setting `train_with_ema_weights=True`. We recommend leaving this off with the default value of `train_with_ema_weights=False`.
 
 ## Technical Details
 
```
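To make these hyperparameters concrete, here is a minimal sketch of attaching EMA to the Composer `Trainer` with the suggested starting values. This is not part of the commit; `model`, `train_dataloader`, and `eval_dataloader` are assumed placeholders:

```python
from composer import Trainer
from composer.algorithms import EMA

# Suggested starting point: a 1000-batch half life, updates every batch,
# and averaging from the very start of training.
ema = EMA(half_life="1000ba", update_interval="1ba", ema_start="0.0dur")

trainer = Trainer(
    model=model,                        # assumed: a ComposerModel instance
    train_dataloader=train_dataloader,  # assumed placeholder
    eval_dataloader=eval_dataloader,    # assumed placeholder
    max_duration="10ep",
    algorithms=[ema],                   # attach EMA to the training run
)
trainer.fit()
```

Per the suggestion in the diff above, a larger `update_interval` such as `"10ba"` would reduce the averaging overhead at some cost to generalization.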
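For the `smoothing` form, a minimal sketch under the same assumptions; the coefficient `0.999` is an illustrative value, not one taken from the commit:

```python
from composer.algorithms import EMA

# Parameterize the average directly via `smoothing` instead of `half_life`.
# Because `smoothing` is applied once per update, changing `update_interval`
# changes the effective time scale of the average.
ema = EMA(half_life=None, smoothing=0.999, update_interval="1ba")
```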

```diff
@@ -98,7 +104,7 @@ Our implementation of EMA also provides the option to use the EMA weights as the
 
 > ❗ Evaluation should not be done with the training model
 >
-> Evaluation should be done with the `ema_model` in the functional implementation as this is the model containing the averaged parameters. The ema model can be accessed after training from the `EMA` object via `model = ema.get_ema_model(model)` in the composer trainer implementation.
+> Evaluation should be done with the `ema_model` in the functional implementation, as this is the model containing the averaged parameters. The ema model can be accessed after training from the `EMA` object via `model = ema.ema_model` in the composer trainer implementation. Similarly, the model without ema applied (the training model) can be accessed via `model = ema.training_model`. By default, when saving checkpoints with the `CheckpointSaver` callback or through trainer arguments, the weights saved will be the ema model weights. An exception is if saving is done by explicitly calling `trainer.save_checkpoint()`, which will result in the training model weights being saved as `state.model`.
 
 
 ## Attribution
```
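Continuing the training sketch above past `trainer.fit()`, the access pattern described in the evaluation note might look like this; the attribute names are taken from the README text in this diff:

```python
# After training, the averaged weights live on the EMA object itself.
ema_model = ema.ema_model            # model with the averaged (EMA) parameters
training_model = ema.training_model  # model without EMA applied

# Note on checkpoints: checkpoints written via the CheckpointSaver callback or
# the trainer's save arguments contain the EMA weights by default, while an
# explicit trainer.save_checkpoint() call saves the training weights as
# state.model.
```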
