
Commit 5b9f26f

Add scale_warmup argument to schedulers (#1268)
1 parent 365c714 commit 5b9f26f

File tree: 3 files changed, +84 −18 lines

composer/optim/scheduler.py
docs/source/trainer/schedulers.rst
tests/optim/test_scheduler.py

composer/optim/scheduler.py

Lines changed: 56 additions & 17 deletions
@@ -537,19 +537,25 @@ class MultiStepWithWarmupScheduler(ComposerScheduler):
     rate multiplier until the warmup has completed.

     .. warning::
-        Initial warmup time is **not** scaled according to any provided scale schedule ratio! However, the milestones
-        will still be scaled accordingly.
+        By default, initial warmup time is **not** scaled according to any provided scale schedule ratio.
+        To change this behavior, set ``scale_warmup=True``.

     Args:
         t_warmup (str | Time): Warmup time.
         milestones (List[str | Time]): Times at which the learning rate should change.
         gamma (float): Multiplicative decay factor. Default = ``0.1``.
+        scale_warmup (bool): If ``True``, the scale schedule ratio (SSR) also scales the warmup period. Default = ``False``.
     """

-    def __init__(self, t_warmup: Union[str, Time], milestones: List[Union[str, Time]], gamma: float = 0.1):
+    def __init__(self,
+                 t_warmup: Union[str, Time],
+                 milestones: List[Union[str, Time]],
+                 gamma: float = 0.1,
+                 scale_warmup: bool = False):
         self.t_warmup = t_warmup
         self.milestones = milestones
         self.gamma = gamma
+        self.scale_warmup = scale_warmup
         self.warmup_scheduler = LinearScheduler(alpha_i=0.0, alpha_f=1.0, t_max=t_warmup)
         self.step_scheduler = MultiStepScheduler(milestones=milestones, gamma=gamma)

@@ -563,6 +569,8 @@ def __call__(self, state: State, ssr: float = 1.0):
                 same unit as the trainer's max_duration parameter."""))

         if state.timestamp < t_warmup:
+            if self.scale_warmup:
+                return self.warmup_scheduler(state, ssr)
             return self.warmup_scheduler(state)

         return self.step_scheduler(state, ssr)
@@ -587,17 +595,31 @@ class ConstantWithWarmupScheduler(ComposerScheduler):
     Where :math:`\alpha` represents the learning rate multiplier to maintain while this scheduler is active, and
     :math:`t_{max}` represents the duration of this scheduler.

+    .. warning::
+        By default, initial warmup time is **not** scaled according to any provided scale schedule ratio.
+        To change this behavior, set ``scale_warmup=True``.
+
     Args:
         t_warmup (str | Time): Warmup time.
         alpha (float): Learning rate multiplier to maintain while this scheduler is active. Default = ``1.0``.
         t_max (str | Time): Duration of this scheduler. Default = ``"1dur"``.
+        scale_warmup (bool): If ``True``, the scale schedule ratio (SSR) also scales the warmup period. Default = ``False``.
     """

-    def __init__(self, t_warmup: Union[str, Time], alpha: float = 1.0, t_max: Union[str, Time] = '1dur') -> None:
+    def __init__(self,
+                 t_warmup: Union[str, Time],
+                 alpha: float = 1.0,
+                 t_max: Union[str, Time] = '1dur',
+                 scale_warmup: bool = False) -> None:
         self.t_warmup = t_warmup
         self.alpha = alpha
         self.t_max = t_max
-        self.scheduler = LinearWithWarmupScheduler(t_warmup=t_warmup, alpha_i=alpha, alpha_f=alpha, t_max=t_max)
+        self.scale_warmup = scale_warmup
+        self.scheduler = LinearWithWarmupScheduler(t_warmup=t_warmup,
+                                                   alpha_i=alpha,
+                                                   alpha_f=alpha,
+                                                   t_max=t_max,
+                                                   scale_warmup=scale_warmup)

     def __call__(self, state: State, ssr: float = 1.0) -> float:
         return self.scheduler(state, ssr)
@@ -628,27 +650,31 @@ class LinearWithWarmupScheduler(ComposerScheduler):
     and :math:`\alpha_f` represents the learning rate multiplier to decay to, and :math:`t_{max}` represents the duration
     of this scheduler.

+
     .. warning::
-        Initial warmup time is **not** scaled according to any provided scale schedule ratio! However, the duration of
-        the scheduler is still scaled accordingly. To achieve this, after warmup, the scheduler's "pace" will be
-        slightly distorted from what would otherwise be expected.
+        By default, the initial warmup time is **not** scaled according to any provided scale schedule ratio, but the
+        duration of the scheduler is still scaled accordingly. As a result, after warmup, the scheduler's "slope" will
+        be slightly distorted from what would otherwise be expected. To scale the entire schedule, set ``scale_warmup=True``.

     Args:
         t_warmup (str | Time): Warmup time.
         alpha_i (float): Initial learning rate multiplier. Default = ``1.0``.
         alpha_f (float): Final learning rate multiplier. Default = ``0.0``.
         t_max (str | Time): The duration of this scheduler. Default = ``"1dur"``.
+        scale_warmup (bool): If ``True``, the scale schedule ratio (SSR) also scales the warmup period. Default = ``False``.
     """

     def __init__(self,
                  t_warmup: Union[str, Time],
                  alpha_i: float = 1.0,
                  alpha_f: float = 0.0,
-                 t_max: Union[str, Time] = '1dur'):
+                 t_max: Union[str, Time] = '1dur',
+                 scale_warmup: bool = False):
         self.t_warmup = t_warmup
         self.alpha_i = alpha_i
         self.alpha_f = alpha_f
         self.t_max = t_max
+        self.scale_warmup = scale_warmup
         self.warmup_scheduler = LinearScheduler(alpha_i=0.0, alpha_f=alpha_i, t_max=t_warmup)

     def __call__(self, state: State, ssr: float = 1.0):
@@ -661,6 +687,8 @@ def __call__(self, state: State, ssr: float = 1.0):
                 same unit as the trainer's max_duration parameter."""))

         if state.timestamp < t_warmup:
+            if self.scale_warmup:
+                return self.warmup_scheduler(state, ssr)
             return self.warmup_scheduler(state)

         t_max = _convert_time(self.t_max, state, ssr=ssr)
@@ -695,20 +723,25 @@ class CosineAnnealingWithWarmupScheduler(ComposerScheduler):
     :math:`\alpha_f` represents the learning rate multiplier to decay to.

     .. warning::
-        Initial warmup time is **not** scaled according to any provided scale schedule ratio! However, the duration of
-        the scheduler is still scaled accordingly. To achieve this, after warmup, the scheduler's "pace" will be
-        slightly distorted from what would otherwise be expected.
+        By default, initial warmup time is **not** scaled according to any provided scale schedule ratio.
+        To change this behavior, set ``scale_warmup=True``.

     Args:
         t_warmup (str | Time): Warmup time.
         t_max (str | Time): The duration of this scheduler. Default = ``"1dur"``.
         alpha_f (float): Learning rate multiplier to decay to. Default = ``0.0``.
+        scale_warmup (bool): If ``True``, the scale schedule ratio (SSR) also scales the warmup period. Default = ``False``.
     """

-    def __init__(self, t_warmup: Union[str, Time], t_max: Union[str, Time] = '1dur', alpha_f: float = 0.0):
+    def __init__(self,
+                 t_warmup: Union[str, Time],
+                 t_max: Union[str, Time] = '1dur',
+                 alpha_f: float = 0.0,
+                 scale_warmup: bool = False):
         self.t_warmup = t_warmup
         self.t_max = t_max
         self.alpha_f = alpha_f
+        self.scale_warmup = scale_warmup
         self.warmup_scheduler = LinearScheduler(alpha_i=0.0, alpha_f=1.0, t_max=t_warmup)

     def __call__(self, state: State, ssr: float = 1.0):
@@ -721,6 +754,8 @@ def __call__(self, state: State, ssr: float = 1.0):
                 same unit as the trainer's max_duration parameter."""))

         if state.timestamp < t_warmup:
+            if self.scale_warmup:
+                return self.warmup_scheduler(state, ssr)
             return self.warmup_scheduler(state)

         t_max = _convert_time(self.t_max, state, ssr=ssr)
@@ -754,26 +789,28 @@ class PolynomialWithWarmupScheduler(ComposerScheduler):
     :math:`\alpha_f` represents the learning rate multiplier to decay to.

     .. warning::
-        Initial warmup time is **not** scaled according to any provided scale schedule ratio! However, the duration of
-        the scheduler is still scaled accordingly. To achieve this, after warmup, the scheduler's "pace" will be
-        slightly distorted from what would otherwise be expected.
+        By default, initial warmup time is **not** scaled according to any provided scale schedule ratio.
+        To change this behavior, set ``scale_warmup=True``.

     Args:
         t_warmup (str | Time): Warmup time.
         power (float): The exponent to be used for the proportionality relationship. Default = ``2.0``.
         t_max (str | Time): The duration of this scheduler. Default = ``"1dur"``.
         alpha_f (float): Learning rate multiplier to decay to. Default = ``0.0``.
+        scale_warmup (bool): If ``True``, the scale schedule ratio (SSR) also scales the warmup period. Default = ``False``.
     """

     def __init__(self,
                  t_warmup: Union[str, Time],
                  power: float = 2.0,
                  t_max: Union[str, Time] = '1dur',
-                 alpha_f: float = 0.0):
+                 alpha_f: float = 0.0,
+                 scale_warmup: bool = False):
         self.t_warmup = t_warmup
         self.power = power
         self.t_max = t_max
         self.alpha_f = alpha_f
+        self.scale_warmup = scale_warmup
         self.warmup_scheduler = LinearScheduler(alpha_i=0.0, alpha_f=1.0, t_max=t_warmup)

     def __call__(self, state: State, ssr: float = 1.0):
@@ -786,6 +823,8 @@ def __call__(self, state: State, ssr: float = 1.0):
                 same unit as the trainer's max_duration parameter."""))

         if state.timestamp < t_warmup:
+            if self.scale_warmup:
+                return self.warmup_scheduler(state, ssr)
             return self.warmup_scheduler(state)

         t_max = _convert_time(self.t_max, state, ssr=ssr)
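
Editor's note: to make the LinearWithWarmupScheduler warning above concrete, here is a small illustrative arithmetic sketch (not part of the commit; all numbers and names are hypothetical). It shows why the post-warmup slope is distorted when the warmup is left unscaled, and why ``scale_warmup=True`` keeps the schedule's proportions intact.

    # Illustrative only: suppose t_warmup = 100 batches, t_max = 1000 batches,
    # and scale_schedule_ratio (ssr) = 0.5, so the scheduled duration shrinks to 500 batches.
    t_warmup, t_max, ssr = 100, 1000, 0.5
    scaled_t_max = t_max * ssr                          # 500.0 batches

    # Default (scale_warmup=False): warmup still spans 100 batches, so the decay phase
    # must cover its full range in 400 batches instead of a proportional 450 -- a steeper slope.
    decay_span_default = scaled_t_max - t_warmup        # 400.0
    decay_span_proportional = (t_max - t_warmup) * ssr  # 450.0

    # scale_warmup=True: the warmup is also halved (50 batches), so warmup and decay
    # keep the same relative proportions as the unscaled schedule.
    decay_span_scaled = scaled_t_max - t_warmup * ssr   # 450.0
    print(decay_span_default, decay_span_proportional, decay_span_scaled)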

docs/source/trainer/schedulers.rst

Lines changed: 16 additions & 1 deletion
@@ -122,7 +122,7 @@ For example, the code below will scale the training time by half
     )

 Importantly, for our schedulers that have warmup, the warmup
-period is *never* scaled. For example, if we apply
+period is *not* scaled by default. For example, if we apply
 ``scale_schedule_ratio=0.5`` to:

 .. testcode::
@@ -136,3 +136,18 @@ period is *never* scaled. For example, if we apply

 The resulting scheduler would warmup for 4 epochs and then
 have step milestones at 5 epochs and 10 epochs.
+
+To scale the warmup period as well, set ``scale_warmup=True``. For example:
+
+.. testcode::
+
+    from composer.optim.scheduler import MultiStepWithWarmupScheduler
+
+    scheduler = MultiStepWithWarmupScheduler(
+        milestones=["10ep", "20ep"],
+        t_warmup="4ep",
+        scale_warmup=True,
+    )
+
+With ``scale_schedule_ratio=0.5``, this scheduler would warmup for 2 epochs and then
+have step milestones at 5 epochs and 10 epochs.
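
Editor's note: a quick sanity check of the numbers quoted in the added docs text above (plain Python arithmetic, not the Composer API; the variable names are ours).

    ssr = 0.5                 # the scale_schedule_ratio applied to the run
    t_warmup_ep = 4
    milestones_ep = [10, 20]

    warmup_ep = t_warmup_ep * ssr                            # 2.0 epochs, since scale_warmup=True
    scaled_milestones_ep = [m * ssr for m in milestones_ep]  # [5.0, 10.0] epochs
    print(warmup_ep, scaled_milestones_ep)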

tests/optim/test_scheduler.py

Lines changed: 12 additions & 0 deletions
@@ -76,6 +76,10 @@ def dummy_schedulers_state(dummy_model: torch.nn.Module, rank_zero_seed: int):
                  ['0ba', '5ba', '15ba', '25ba', '45ba'], [0.0, 0.5, 1.0, 0.1, 0.01]),
     pytest.param(MultiStepWithWarmupScheduler(t_warmup='10ba', milestones=cast(List, ['2ep', '4ep']), gamma=0.5), 0.5,
                  ['0ba', '5ba', '15ba', '1500ba', '2500ba'], [0.0, 0.5, 1.0, 0.5, 0.25]),
+    pytest.param(
+        MultiStepWithWarmupScheduler(
+            t_warmup='10ba', milestones=cast(List, ['2ep', '4ep']), gamma=0.5, scale_warmup=True), 0.5,
+        ['0ba', '5ba', '15ba', '1500ba', '2500ba'], [0.0, 1.0, 1.0, 0.5, 0.25]),
     pytest.param(ConstantWithWarmupScheduler(t_warmup='500ep'), 1.0,
                  ['0ba', '250000ba', '500000ba', '750000ba', '1000000ba'], [0.0, 0.5, 1.0, 1.0, 1.0]),
     pytest.param(ConstantWithWarmupScheduler(t_warmup='500ep', alpha=3.0), 1.0,
@@ -92,14 +96,22 @@ def dummy_schedulers_state(dummy_model: torch.nn.Module, rank_zero_seed: int):
                  ['0ba', '250000ba', '500000ba', '500500ba', '501000ba', '502000ba'], [0.0, 1.5, 3.0, 2.5, 2.0, 2.0]),
     pytest.param(LinearWithWarmupScheduler(t_warmup='0.0005dur'), 1.0, ['0ba', '250ba', '500ba', '499750ba'],
                  [0.0, 0.5, 1.0, 0.5]),
+    pytest.param(LinearWithWarmupScheduler(t_warmup='500ba', scale_warmup=False), 0.5,
+                 ['0ba', '250ba', '500ba', '249875ba'], [0.0, 0.5, 1.0, 0.5]),
+    pytest.param(LinearWithWarmupScheduler(t_warmup='500ba', scale_warmup=True), 0.5,
+                 ['0ba', '125ba', '250ba', '249875ba'], [0.0, 0.5, 1.0, 0.5]),
     pytest.param(CosineAnnealingWithWarmupScheduler(t_warmup='0.9dur'), 1.0,
                  ['0ba', '450000ba', '900000ba', '933333ba', '950000ba', '1000000ba'], [0.0, 0.5, 1.0, 0.75, 0.5, 0.0]),
     pytest.param(CosineAnnealingWithWarmupScheduler(t_warmup='0.9dur', alpha_f=0.5), 0.01,
                  ['0ba', '4500ba', '9000ba', '9333ba', '9500ba', '10000ba'], [0.0, 0.5, 1.0, 0.875, 0.75, 0.5]),
+    pytest.param(CosineAnnealingWithWarmupScheduler(t_warmup='0.9dur', alpha_f=0.5, scale_warmup=True), 0.01,
+                 ['0ba', '4500ba', '9000ba', '9333ba', '9500ba', '10000ba'], [0.0, 0.5, 1.0, 0.875, 0.75, 0.5]),
     pytest.param(PolynomialWithWarmupScheduler(t_warmup='0.9dur'), 1.0,
                  ['0ba', '450000ba', '900000ba', '913397ba', '929289ba', '1000000ba'], [0.0, 0.5, 1.0, 0.75, 0.5, 0.0]),
     pytest.param(PolynomialWithWarmupScheduler(t_warmup='0.9dur', alpha_f=0.5), 0.01,
                  ['0ba', '4500ba', '9000ba', '9134ba', '9293ba', '10000ba'], [0.0, 0.5, 1.0, 0.875, 0.75, 0.5]),
+    pytest.param(PolynomialWithWarmupScheduler(t_warmup='0.9dur', alpha_f=0.5, scale_warmup=True), 0.01,
+                 ['0ba', '4500ba', '9000ba', '9134ba', '9293ba', '10000ba'], [0.0, 0.5, 1.0, 0.875, 0.75, 0.5]),
 ])
 def test_scheduler_init(scheduler: ComposerScheduler, ssr: float, test_times: List[str], expected_lrs: List[float],
                         dummy_schedulers_state: State):
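
Editor's note: to help read the two new LinearWithWarmupScheduler rows above (ssr = 0.5, t_warmup = '500ba'), here is a standalone arithmetic sketch of the expected warmup ramp. It is illustrative only, not Composer code; the helper name is ours.

    def warmup_multiplier(batch: int, t_warmup_ba: int, ssr: float, scale_warmup: bool) -> float:
        """Linear 0 -> 1 ramp over the (optionally SSR-scaled) warmup window."""
        warmup = t_warmup_ba * ssr if scale_warmup else t_warmup_ba
        return min(batch / warmup, 1.0)

    # scale_warmup=False: the ramp ignores ssr, reaching 0.5 at 250ba and 1.0 at 500ba.
    assert warmup_multiplier(250, 500, 0.5, scale_warmup=False) == 0.5
    assert warmup_multiplier(500, 500, 0.5, scale_warmup=False) == 1.0
    # scale_warmup=True: the warmup shrinks to 250ba, reaching 0.5 at 125ba and 1.0 at 250ba.
    assert warmup_multiplier(125, 500, 0.5, scale_warmup=True) == 0.5
    assert warmup_multiplier(250, 500, 0.5, scale_warmup=True) == 1.0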
