73 changes: 56 additions & 17 deletions composer/optim/scheduler.py
@@ -537,19 +537,25 @@ class MultiStepWithWarmupScheduler(ComposerScheduler):
rate multiplier until the warmup has completed.

.. warning::
Initial warmup time is **not** scaled according to any provided scale schedule ratio! However, the milestones
will still be scaled accordingly.
By default, initial warmup time is **not** scaled according to any provided scale schedule ratio.
To change this behavior, set ``scale_warmup=True``.

Args:
t_warmup (str | Time): Warmup time.
milestones (List[str | Time]): Times at which the learning rate should change.
gamma (float): Multiplicative decay factor. Default = ``0.1``.
scale_warmup (bool): If ``True``, the scale schedule ratio (SSR) also scales the warmup period. Default = ``False``.
"""

def __init__(self, t_warmup: Union[str, Time], milestones: List[Union[str, Time]], gamma: float = 0.1):
def __init__(self,
t_warmup: Union[str, Time],
milestones: List[Union[str, Time]],
gamma: float = 0.1,
scale_warmup: bool = False):
self.t_warmup = t_warmup
self.milestones = milestones
self.gamma = gamma
self.scale_warmup = scale_warmup
self.warmup_scheduler = LinearScheduler(alpha_i=0.0, alpha_f=1.0, t_max=t_warmup)
self.step_scheduler = MultiStepScheduler(milestones=milestones, gamma=gamma)

@@ -563,6 +569,8 @@ def __call__(self, state: State, ssr: float = 1.0):
same unit as the trainer's max_duration parameter."""))

if state.timestamp < t_warmup:
if self.scale_warmup:
return self.warmup_scheduler(state, ssr)
return self.warmup_scheduler(state)

return self.step_scheduler(state, ssr)
@@ -587,17 +595,31 @@ class ConstantWithWarmupScheduler(ComposerScheduler):
Where :math:`\alpha` represents the learning rate multiplier to maintain while this scheduler is active, and
:math:`t_{max}` represents the duration of this scheduler.

.. warning::
By default, initial warmup time is **not** scaled according to any provided scale schedule ratio.
To change this behavior, set ``scale_warmup=True``.

Args:
t_warmup (str | Time): Warmup time.
alpha (float): Learning rate multiplier to maintain while this scheduler is active. Default = ``1.0``.
t_max (str | Time): Duration of this scheduler. Default = ``"1dur"``.
scale_warmup (bool): If ``True``, the scale schedule ratio (SSR) also scales the warmup period. Default = ``False``.
"""

def __init__(self, t_warmup: Union[str, Time], alpha: float = 1.0, t_max: Union[str, Time] = '1dur') -> None:
def __init__(self,
t_warmup: Union[str, Time],
alpha: float = 1.0,
t_max: Union[str, Time] = '1dur',
scale_warmup: bool = False) -> None:
self.t_warmup = t_warmup
self.alpha = alpha
self.t_max = t_max
self.scheduler = LinearWithWarmupScheduler(t_warmup=t_warmup, alpha_i=alpha, alpha_f=alpha, t_max=t_max)
self.scale_warmup = scale_warmup
self.scheduler = LinearWithWarmupScheduler(t_warmup=t_warmup,
alpha_i=alpha,
alpha_f=alpha,
t_max=t_max,
scale_warmup=scale_warmup)

def __call__(self, state: State, ssr: float = 1.0) -> float:
return self.scheduler(state, ssr)
@@ -628,27 +650,31 @@ class LinearWithWarmupScheduler(ComposerScheduler):
and :math:`\alpha_f` represents the learning rate multiplier to decay to, and :math:`t_{max}` represents the duration
of this scheduler.


.. warning::
Initial warmup time is **not** scaled according to any provided scale schedule ratio! However, the duration of
the scheduler is still scaled accordingly. To achieve this, after warmup, the scheduler's "pace" will be
slightly distorted from what would otherwise be expected.
By default, the initial warmup time is **not** scaled according to any provided scale schedule ratio. However,
the duration of the scheduler is still scaled accordingly; as a result, after warmup, the scheduler's "slope"
will be slightly distorted from what would otherwise be expected. To scale the entire schedule, set ``scale_warmup=True``.

Args:
t_warmup (str | Time): Warmup time.
alpha_i (float): Initial learning rate multiplier. Default = ``1.0``.
alpha_f (float): Final learning rate multiplier. Default = ``0.0``.
t_max (str | Time): The duration of this scheduler. Default = ``"1dur"``.
scale_warmup (bool): If ``True``, the scale schedule ratio (SSR) also scales the warmup period. Default = ``False``.
"""

def __init__(self,
t_warmup: Union[str, Time],
alpha_i: float = 1.0,
alpha_f: float = 0.0,
t_max: Union[str, Time] = '1dur'):
t_max: Union[str, Time] = '1dur',
scale_warmup: bool = False):
self.t_warmup = t_warmup
self.alpha_i = alpha_i
self.alpha_f = alpha_f
self.t_max = t_max
self.scale_warmup = scale_warmup
self.warmup_scheduler = LinearScheduler(alpha_i=0.0, alpha_f=alpha_i, t_max=t_warmup)

def __call__(self, state: State, ssr: float = 1.0):
@@ -661,6 +687,8 @@ def __call__(self, state: State, ssr: float = 1.0):
same unit as the trainer's max_duration parameter."""))

if state.timestamp < t_warmup:
if self.scale_warmup:
return self.warmup_scheduler(state, ssr)
return self.warmup_scheduler(state)

t_max = _convert_time(self.t_max, state, ssr=ssr)
@@ -695,20 +723,25 @@ class CosineAnnealingWithWarmupScheduler(ComposerScheduler):
:math:`\alpha_f` represents the learning rate multiplier to decay to.

.. warning::
Initial warmup time is **not** scaled according to any provided scale schedule ratio! However, the duration of
the scheduler is still scaled accordingly. To achieve this, after warmup, the scheduler's "pace" will be
slightly distorted from what would otherwise be expected.
By default, initial warmup time is **not** scaled according to any provided scale schedule ratio.
To change this behavior, set ``scale_warmup=True``.

Args:
t_warmup (str | Time): Warmup time.
t_max (str | Time): The duration of this scheduler. Default = ``"1dur"``.
alpha_f (float): Learning rate multiplier to decay to. Default = ``0.0``.
scale_warmup (bool): If ``True``, the scale schedule ratio (SSR) also scales the warmup period. Default = ``False``.
"""

def __init__(self, t_warmup: Union[str, Time], t_max: Union[str, Time] = '1dur', alpha_f: float = 0.0):
def __init__(self,
t_warmup: Union[str, Time],
t_max: Union[str, Time] = '1dur',
alpha_f: float = 0.0,
scale_warmup: bool = False):
self.t_warmup = t_warmup
self.t_max = t_max
self.alpha_f = alpha_f
self.scale_warmup = scale_warmup
self.warmup_scheduler = LinearScheduler(alpha_i=0.0, alpha_f=1.0, t_max=t_warmup)

def __call__(self, state: State, ssr: float = 1.0):
@@ -721,6 +754,8 @@ def __call__(self, state: State, ssr: float = 1.0):
same unit as the trainer's max_duration parameter."""))

if state.timestamp < t_warmup:
if self.scale_warmup:
return self.warmup_scheduler(state, ssr)
return self.warmup_scheduler(state)

t_max = _convert_time(self.t_max, state, ssr=ssr)
@@ -754,26 +789,28 @@ class PolynomialWithWarmupScheduler(ComposerScheduler):
:math:`\alpha_f` represents the learning rate multiplier to decay to.

.. warning::
Initial warmup time is **not** scaled according to any provided scale schedule ratio! However, the duration of
the scheduler is still scaled accordingly. To achieve this, after warmup, the scheduler's "pace" will be
slightly distorted from what would otherwise be expected.
By default, initial warmup time is **not** scaled according to any provided scale schedule ratio.
To change this behavior, set ``scale_warmup=True``.

Args:
t_warmup (str | Time): Warmup time.
power (float): The exponent to be used for the proportionality relationship. Default = ``2.0``.
t_max (str | Time): The duration of this scheduler. Default = ``"1dur"``.
alpha_f (float): Learning rate multiplier to decay to. Default = ``0.0``.
scale_warmup (bool): If ``True``, the scale schedule ratio (SSR) also scales the warmup period. Default = ``False``.
"""

def __init__(self,
t_warmup: Union[str, Time],
power: float = 2.0,
t_max: Union[str, Time] = '1dur',
alpha_f: float = 0.0):
alpha_f: float = 0.0,
scale_warmup: bool = False):
self.t_warmup = t_warmup
self.power = power
self.t_max = t_max
self.alpha_f = alpha_f
self.scale_warmup = scale_warmup
self.warmup_scheduler = LinearScheduler(alpha_i=0.0, alpha_f=1.0, t_max=t_warmup)

def __call__(self, state: State, ssr: float = 1.0):
@@ -786,6 +823,8 @@ def __call__(self, state: State, ssr: float = 1.0):
same unit as the trainer's max_duration parameter."""))

if state.timestamp < t_warmup:
if self.scale_warmup:
return self.warmup_scheduler(state, ssr)
return self.warmup_scheduler(state)

t_max = _convert_time(self.t_max, state, ssr=ssr)
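
For reference, the sketch below is a minimal toy re-implementation of the pattern the diff adds to each scheduler: when ``scale_warmup`` is set, the warmup boundary is scaled by the scale schedule ratio (SSR), otherwise only the post-warmup schedule is scaled. This is not Composer's API; times are plain batch counts, and the milestone values in the asserts assume the 1000-batches-per-epoch fixture used by the tests further down:

    from typing import List

    def multiplier(batch: int, t_warmup: int, milestones: List[int], gamma: float,
                   ssr: float, scale_warmup: bool) -> float:
        """Learning-rate multiplier for linear warmup followed by multi-step decay."""
        warmup = t_warmup * ssr if scale_warmup else t_warmup  # the only thing the flag changes
        if batch < warmup:
            return batch / warmup  # linear ramp from 0 to 1 over the warmup
        scaled_milestones = [m * ssr for m in milestones]  # milestones always respect SSR
        return gamma ** sum(batch >= m for m in scaled_milestones)

    # These mirror the new MultiStepWithWarmupScheduler test case below:
    # t_warmup='10ba', milestones 2ep/4ep (2000/4000 batches at 1000 ba/ep), gamma=0.5, ssr=0.5.
    assert multiplier(5, 10, [2000, 4000], 0.5, ssr=0.5, scale_warmup=False) == 0.5
    assert multiplier(5, 10, [2000, 4000], 0.5, ssr=0.5, scale_warmup=True) == 1.0
    assert multiplier(1500, 10, [2000, 4000], 0.5, ssr=0.5, scale_warmup=True) == 0.5
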
17 changes: 16 additions & 1 deletion docs/source/trainer/schedulers.rst
@@ -122,7 +122,7 @@ For example, the code below will scale the training time by half
)

Importantly, for our schedulers that have warmup, the warmup
period is *never* scaled. For example, if we apply
period is *not* scaled by default. For example, if we apply
``scale_schedule_ratio=0.5`` to:

.. testcode::
@@ -136,3 +136,18 @@ period is *never* scaled. For example, if we apply

The resulting scheduler would warm up for 4 epochs and then
have step milestones at 5 epochs and 10 epochs.

To scale the warmup period as well, set ``scale_warmup=True``. For example:

.. testcode::

from composer.optim.scheduler import MultiStepWithWarmupScheduler

scheduler = MultiStepWithWarmupScheduler(
milestones=["10ep", "20ep"],
t_warmup="4ep",
scale_warmup=True,
)

With ``scale_schedule_ratio=0.5``, this scheduler will warm up for 2 epochs,
then step at 5 and 10 epochs.
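
A quick arithmetic check of the two examples above (plain Python with values in epochs, not Composer code):

    ssr = 0.5
    milestones = [10, 20]   # "10ep", "20ep"
    t_warmup = 4            # "4ep"

    print([m * ssr for m in milestones])  # [5.0, 10.0] -- milestones always scale with SSR
    print(t_warmup)                       # 4   -- warmup unscaled (default, scale_warmup=False)
    print(t_warmup * ssr)                 # 2.0 -- warmup scaled (scale_warmup=True)
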
12 changes: 12 additions & 0 deletions tests/optim/test_scheduler.py
@@ -76,6 +76,10 @@ def dummy_schedulers_state(dummy_model: torch.nn.Module, rank_zero_seed: int):
['0ba', '5ba', '15ba', '25ba', '45ba'], [0.0, 0.5, 1.0, 0.1, 0.01]),
pytest.param(MultiStepWithWarmupScheduler(t_warmup='10ba', milestones=cast(List, ['2ep', '4ep']), gamma=0.5), 0.5,
['0ba', '5ba', '15ba', '1500ba', '2500ba'], [0.0, 0.5, 1.0, 0.5, 0.25]),
pytest.param(
MultiStepWithWarmupScheduler(
t_warmup='10ba', milestones=cast(List, ['2ep', '4ep']), gamma=0.5, scale_warmup=True), 0.5,
['0ba', '5ba', '15ba', '1500ba', '2500ba'], [0.0, 1.0, 1.0, 0.5, 0.25]),
pytest.param(ConstantWithWarmupScheduler(t_warmup='500ep'), 1.0,
['0ba', '250000ba', '500000ba', '750000ba', '1000000ba'], [0.0, 0.5, 1.0, 1.0, 1.0]),
pytest.param(ConstantWithWarmupScheduler(t_warmup='500ep', alpha=3.0), 1.0,
@@ -92,14 +96,22 @@ def dummy_schedulers_state(dummy_model: torch.nn.Module, rank_zero_seed: int):
['0ba', '250000ba', '500000ba', '500500ba', '501000ba', '502000ba'], [0.0, 1.5, 3.0, 2.5, 2.0, 2.0]),
pytest.param(LinearWithWarmupScheduler(t_warmup='0.0005dur'), 1.0, ['0ba', '250ba', '500ba', '499750ba'],
[0.0, 0.5, 1.0, 0.5]),
pytest.param(LinearWithWarmupScheduler(t_warmup='500ba', scale_warmup=False), 0.5,
['0ba', '250ba', '500ba', '249875ba'], [0.0, 0.5, 1.0, 0.5]),
pytest.param(LinearWithWarmupScheduler(t_warmup='500ba', scale_warmup=True), 0.5,
['0ba', '125ba', '250ba', '249875ba'], [0.0, 0.5, 1.0, 0.5]),
pytest.param(CosineAnnealingWithWarmupScheduler(t_warmup='0.9dur'), 1.0,
['0ba', '450000ba', '900000ba', '933333ba', '950000ba', '1000000ba'], [0.0, 0.5, 1.0, 0.75, 0.5, 0.0]),
pytest.param(CosineAnnealingWithWarmupScheduler(t_warmup='0.9dur', alpha_f=0.5), 0.01,
['0ba', '4500ba', '9000ba', '9333ba', '9500ba', '10000ba'], [0.0, 0.5, 1.0, 0.875, 0.75, 0.5]),
pytest.param(CosineAnnealingWithWarmupScheduler(t_warmup='0.9dur', alpha_f=0.5, scale_warmup=True), 0.01,
['0ba', '4500ba', '9000ba', '9333ba', '9500ba', '10000ba'], [0.0, 0.5, 1.0, 0.875, 0.75, 0.5]),
pytest.param(PolynomialWithWarmupScheduler(t_warmup='0.9dur'), 1.0,
['0ba', '450000ba', '900000ba', '913397ba', '929289ba', '1000000ba'], [0.0, 0.5, 1.0, 0.75, 0.5, 0.0]),
pytest.param(PolynomialWithWarmupScheduler(t_warmup='0.9dur', alpha_f=0.5), 0.01,
['0ba', '4500ba', '9000ba', '9134ba', '9293ba', '10000ba'], [0.0, 0.5, 1.0, 0.875, 0.75, 0.5]),
pytest.param(PolynomialWithWarmupScheduler(t_warmup='0.9dur', alpha_f=0.5, scale_warmup=True), 0.01,
['0ba', '4500ba', '9000ba', '9134ba', '9293ba', '10000ba'], [0.0, 0.5, 1.0, 0.875, 0.75, 0.5]),
])
def test_scheduler_init(scheduler: ComposerScheduler, ssr: float, test_times: List[str], expected_lrs: List[float],
dummy_schedulers_state: State):
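
One detail worth noting in the new parametrizations: the rows that pass ``scale_warmup=True`` with a warmup given in ``dur`` units (the cosine and polynomial cases above) expect exactly the same learning rates as their ``scale_warmup=False`` counterparts, so at least at these checkpoints the flag appears to be a no-op when the warmup is already expressed as a fraction of the training duration. The arithmetic below is an inference from the test fixture, which, judging by the other cases, uses a max duration of 1000 epochs at 1000 batches per epoch:

    # Rough check of the '0.9dur' warmup boundary at ssr=0.01 (inferred fixture values).
    max_duration_ba = 1_000_000               # 1000ep * 1000 ba/ep, inferred from the other cases
    ssr = 0.01
    scaled_duration = max_duration_ba * ssr   # 10_000 batches of actual training
    warmup_end = 0.9 * scaled_duration        # 9_000 batches, matching both rows above
    print(scaled_duration, warmup_end)
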