lr.py 88.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import warnings
17 18 19 20

import numpy

from paddle import Tensor
21
from paddle.fluid import core
22

G
guguguzi 已提交
23
__all__ = [  # noqa
    # Base class for user-defined schedulers.
    'LRScheduler',
    # Closed-form / rule-based decay schedules.
    'NoamDecay',
    'PiecewiseDecay',
    'NaturalExpDecay',
    'InverseTimeDecay',
    'PolynomialDecay',
    # Wrapper that warms up before another schedule.
    'LinearWarmup',
    # Step / multiplicative schedules.
    'ExponentialDecay',
    'MultiStepDecay',
    'StepDecay',
    'LambdaDecay',
    # Metric-driven schedule.
    'ReduceOnPlateau',
    # Cyclic / annealing schedules.
    'CosineAnnealingDecay',
    'MultiplicativeDecay',
    'OneCycleLR',
    'CyclicLR',
]


43
class LRScheduler:
    """

    LRScheduler Base class. Define the common interface of a learning rate scheduler.

    User can import it by ``from paddle.optimizer.lr import LRScheduler`` ,

    then overload it for your subclass and have a custom implementation of ``get_lr()`` .

    Otherwise, a ``NotImplementedError`` exception will be thrown.

    Note that the constructor calls ``step()`` once (moving ``last_epoch`` from -1 to 0),
    which invokes ``get_lr()``. Subclasses must therefore set every attribute that
    ``get_lr()`` reads *before* calling ``super().__init__``.

    Args:
        learning_rate (float|int, optional): The initial learning rate. It is a python float number. Default: 0.1.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        instance to schedule learning rate.

    Raises:
        TypeError: If ``learning_rate`` is not a float or int.
        ValueError: If ``learning_rate`` is negative.

    Examples:
        Here is an example of a simple ``StepDecay`` implementation.

        .. code-block:: python

            import paddle
            from paddle.optimizer.lr import LRScheduler

            class StepDecay(LRScheduler):
                def __init__(self,
                             learning_rate,
                             step_size,
                             gamma=0.1,
                             last_epoch=-1,
                             verbose=False):
                    if not isinstance(step_size, int):
                        raise TypeError(
                            "The type of 'step_size' must be 'int', but received %s." %
                            type(step_size))
                    if gamma >= 1.0:
                        raise ValueError('gamma should be < 1.0.')

                    self.step_size = step_size
                    self.gamma = gamma
                    super().__init__(learning_rate, last_epoch, verbose)

                def get_lr(self):
                    i = self.last_epoch // self.step_size
                    return self.base_lr * (self.gamma**i)

    """

    def __init__(self, learning_rate=0.1, last_epoch=-1, verbose=False):
        if not isinstance(learning_rate, (float, int)):
            raise TypeError(
                "The type of learning rate must be float, but received {}".format(
                    type(learning_rate)
                )
            )
        if learning_rate < 0:
            raise ValueError(f"Invalid learning rate: {learning_rate}")
        self.base_lr = float(learning_rate)
        self.last_lr = float(learning_rate)
        self.last_epoch = last_epoch
        self.verbose = verbose
        self._var_name = None

        # Advance from ``last_epoch`` to the first epoch and compute its rate;
        # this is why subclasses must be fully initialised before super().__init__.
        self.step()

    def __call__(self):
        """
        Return latest computed learning rate on current epoch.
        """
        return self.last_lr

    def step(self, epoch=None):
        """

        ``step`` should be called after ``optimizer.step`` . It will update the learning rate in optimizer according to current ``epoch`` .
        The new learning rate will take effect on next ``optimizer.step`` .

        When ``epoch`` is given and the subclass defines ``_get_closed_form_lr()``, that
        method is used to compute the rate instead of ``get_lr()``.

        Args:
            epoch (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1.

        Returns:
            None

        Examples:
            .. code-block:: python

                import paddle

                linear = paddle.nn.Linear(10, 10)
                scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=5, gamma=0.8)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
                for epoch in range(20):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # update the learning rate once per epoch
        """
        if epoch is None:
            self.last_epoch += 1
            self.last_lr = self.get_lr()
        else:
            self.last_epoch = epoch
            # Jumping to an arbitrary epoch: prefer a closed-form formula when the
            # subclass provides one, since get_lr() may assume sequential steps.
            if hasattr(self, "_get_closed_form_lr"):
                self.last_lr = self._get_closed_form_lr()
            else:
                self.last_lr = self.get_lr()

        if self.verbose:
            print(
                'Epoch {}: {} set learning rate to {}.'.format(
                    self.last_epoch, self.__class__.__name__, self.last_lr
                )
            )

    def state_dict(self):
        """

        Returns the state of the scheduler as a :class:`dict`.

        It is a subset of ``self.__dict__`` , restricted to the keys selected by ``state_keys()`` .
        Single-element ``Tensor`` values are converted to python floats.
        """
        self.state_keys()
        state_dict = {}
        for key in self.keys:
            if key not in self.__dict__:
                continue
            value = self.__dict__[key]
            if isinstance(value, Tensor):
                assert (
                    value.size == 1
                ), "numel of Tensor in state_dict must be 1"
                value = float(value)
            state_dict[key] = value

        return state_dict

    # For those subclass who overload LRScheduler, "last_epoch, last_lr" will be saved by default.
    # (Note): you can change it for your subclass.
    def state_keys(self):
        """

        For those subclass who overload ``LRScheduler`` (Base Class). By default, "last_epoch, last_lr" will be saved by ``self.keys = ['last_epoch', 'last_lr']`` .

        ``last_epoch`` is the current epoch num, and ``last_lr`` is the current learning rate.

        If you want to change the default behavior, you should have a custom implementation of ``state_keys()`` to redefine ``self.keys`` .

        """
        self.keys = ['last_epoch', 'last_lr']

    def set_state_dict(self, state_dict):
        """

        Loads the schedulers state.

        Every key listed by ``state_keys()`` must be present in ``state_dict``; a
        ``RuntimeError`` is raised for the first missing one. A warning is issued
        when ``state_dict`` holds more entries than are used.
        """
        self.state_keys()
        for key in self.keys:
            if key in state_dict:
                self.__dict__[key] = state_dict[key]
            else:
                raise RuntimeError(
                    "Please check whether state_dict is correct for optimizer. Can't find [ {} ] in state_dict".format(
                        key
                    )
                )
        if len(state_dict) > len(self.keys):
            warnings.warn(
                "There are some unused values in state_dict. Maybe the optimizer have different 'LearningRateDecay' when invoking state_dict and set_dict"
            )

    # alias for set_state_dict
    set_dict = set_state_dict

    def get_lr(self):
        """

        For those subclass who overload ``LRScheduler`` (Base Class), User should have a custom implementation of ``get_lr()`` .

        Otherwise, a ``NotImplementedError`` exception will be thrown.
        """
        # calculate by python float
        raise NotImplementedError


242
class NoamDecay(LRScheduler):
    r"""

    Applies Noam Decay to the initial learning rate.

    The algorithm can be described as following.

    .. math::

        new\_learning\_rate = learning\_rate * d_{model}^{-0.5} * min(epoch^{-0.5}, epoch * warmup\_steps^{-1.5})

    At ``epoch = 0`` the first term of ``min`` is taken as 1, so the learning rate starts from 0.

    Please reference `attention is all you need <https://arxiv.org/pdf/1706.03762.pdf>`_


    Args:
        d_model (int): The dimensionality of input and output feature vector of model. It is a python int number. It must be greater than 0.
        warmup_steps (int): The number of warmup steps. A hyperparameter. It is a python int number. It must be greater than 0.
        learning_rate (float): The initial learning rate. It is a python float number. Default: 1.0.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``NoamDecay`` instance to schedule learning rate.

    Raises:
        ValueError: If ``d_model`` or ``warmup_steps`` is not greater than 0.

    Examples:
        .. code-block:: python
            :name: code-example1

            # Example1: train on default dynamic graph mode
            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.NoamDecay(d_model=0.01, warmup_steps=100, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

        .. code-block:: python
            :name: code-example2

            # Example2: train on static graph mode
            import paddle
            import numpy as np
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.NoamDecay(d_model=0.01, warmup_steps=100, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(
        self,
        d_model,
        warmup_steps,
        learning_rate=1.0,
        last_epoch=-1,
        verbose=False,
    ):
        if d_model <= 0:
            raise ValueError("d_model should be greater than 0")
        # warmup_steps is raised to -1.5 in get_lr(): 0 would divide by zero and a
        # negative value would silently produce a complex learning rate.
        if warmup_steps <= 0:
            raise ValueError("warmup_steps should be greater than 0")

        self.d_model = d_model
        self.warmup_steps = warmup_steps
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        # epoch**-0.5 is undefined at epoch 0; use 1 so min() picks the warmup
        # term, which is 0 at epoch 0.
        if self.last_epoch == 0:
            a = 1
        else:
            a = self.last_epoch**-0.5
        b = self.warmup_steps**-1.5 * self.last_epoch
        return self.base_lr * (self.d_model**-0.5) * min(a, b)


347
class PiecewiseDecay(LRScheduler):
    """

    Piecewise learning rate scheduler.

    The algorithm can be described as the code below:

    .. code-block:: text

        boundaries = [100, 200]
        values = [1.0, 0.5, 0.1]
        if epoch < 100:
            learning_rate = 1.0
        elif 100 <= epoch < 200:
            learning_rate = 0.5
        else:
            learning_rate = 0.1

    Args:
        boundaries(list|tuple): A list/tuple of steps numbers. The type of element in the list is python int. It must not be empty.
        values(list|tuple): A list/tuple of learning rate values that will be picked during different epoch boundaries.
            The type of element in the list is python float. The ``values`` should have exactly one more element than ``boundaries``;
            once ``epoch`` passes the last boundary, the last element of ``values`` is used.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``PiecewiseDecay`` instance to schedule learning rate.

    Raises:
        ValueError: If ``boundaries`` is empty, or ``values`` does not have more elements than ``boundaries``.

    Examples:

        .. code-block:: python
            :name: code-example1

            # Example1: train on default dynamic graph mode
            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[3, 6, 9], values=[0.1, 0.2, 0.3, 0.4], verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

        .. code-block:: python
            :name: code-example2

            # Example2: train on static graph mode
            import paddle
            import numpy as np
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[3, 6, 9], values=[0.1, 0.2, 0.3, 0.4], verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self, boundaries, values, last_epoch=-1, verbose=False):
        if len(boundaries) == 0:
            raise ValueError('The boundaries cannot be empty.')

        if len(values) <= len(boundaries):
            raise ValueError(
                f'The values have one more element than boundaries, but received len(values) [{len(values)}] < len(boundaries) + 1 [{len(boundaries) + 1}].'
            )

        self.boundaries = boundaries
        self.values = values
        # The base learning rate is left at the parent default; get_lr() only
        # reads ``boundaries`` and ``values``.
        super().__init__(last_epoch=last_epoch, verbose=verbose)

    def get_lr(self):
        # Pick the value of the first interval whose upper boundary lies beyond
        # the current epoch.
        for index, boundary in enumerate(self.boundaries):
            if self.last_epoch < boundary:
                return self.values[index]
        # Past every boundary: fall back to the final entry of ``values``.
        return self.values[-1]


452
class NaturalExpDecay(LRScheduler):
    r"""

    Applies natural exponential decay to the initial learning rate.

    The algorithm can be described as following:

    .. math::

        new\_learning\_rate = learning\_rate * e^{- gamma * epoch}

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float): A Ratio to update the learning rate. It must be greater than 0.0 to make the learning rate decay.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``NaturalExpDecay`` instance to schedule learning rate.

    Raises:
        AssertionError: If ``gamma`` is not greater than 0.0.

    Examples:

        .. code-block:: python
            :name: code-example1

            # Example1: train on default dynamic graph mode
            import paddle
            import numpy as np
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.NaturalExpDecay(learning_rate=0.5, gamma=0.1, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

        .. code-block:: python
            :name: code-example2

            # Example2: train on static graph mode
            import paddle
            import numpy as np
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.NaturalExpDecay(learning_rate=0.5, gamma=0.1, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        # Raise explicitly instead of using ``assert`` so the check still runs
        # under ``python -O``; AssertionError is kept for backward compatibility.
        if not (gamma > 0.0):
            raise AssertionError(
                " 'gamma' must be a positive number so that the learning rate will decay."
            )
        self.gamma = gamma
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        return self.base_lr * math.exp(-1 * self.gamma * self.last_epoch)


538
class InverseTimeDecay(LRScheduler):
    r"""

    Applies inverse time decay to the initial learning rate.

    The algorithm can be described as following:

    .. math::

        new\_learning\_rate = \frac{learning\_rate}{1 + gamma * epoch}

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float): The decay rate ``gamma`` in the formula above; a larger ``gamma`` makes the learning rate
            decrease faster. It is not validated, so pass a non-negative value to obtain a decaying learning rate.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``InverseTimeDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python
            :name: code-example1

            # Example1: train on default dynamic graph mode
            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.InverseTimeDecay(learning_rate=0.5, gamma=0.1, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

        .. code-block:: python
            :name: code-example2

            # Example2: train on static graph mode
            import paddle
            import numpy as np
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.InverseTimeDecay(learning_rate=0.5, gamma=0.1, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        self.gamma = gamma
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        return self.base_lr / (1 + self.gamma * self.last_epoch)


625
class PolynomialDecay(LRScheduler):
    r"""

    Applies polynomial decay to the initial learning rate.

    The algorithm can be described as following.

    If cycle is set to True, then:

    .. math::

        decay\_steps & = decay\_steps * math.ceil(\frac{epoch}{decay\_steps})

        new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\frac{epoch}{decay\_steps})^{power}+end\_lr

    If cycle is set to False, then:

    .. math::

        epoch & = min(epoch, decay\_steps)

        new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\frac{epoch}{decay\_steps})^{power}+end\_lr


    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        decay_steps(int): The decay step size. It determines the decay cycle. It must be a positive integer.
        end_lr(float, optional): The minimum final learning rate. Default: 0.0001.
        power(float, optional): Power of polynomial. It must be greater than 0.0 to get learning rate decay. Default: 1.0.
        cycle(bool, optional): Whether the learning rate rises again. If True, the learning rate rises again after it decreases
            to ``end_lr`` .  If False, the learning rate is monotonically decreasing. Default: False.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``PolynomialDecay`` instance to schedule learning rate.

    Raises:
        AssertionError: If ``decay_steps`` is not a positive integer, or ``power`` is not greater than 0.0.

    Examples:

        .. code-block:: python
            :name: code-example1

            # Example1: train on default dynamic graph mode
            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.5, decay_steps=20, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

        .. code-block:: python
            :name: code-example2

            # Example2: train on static graph mode
            import paddle
            import numpy as np
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.5, decay_steps=20, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(
        self,
        learning_rate,
        decay_steps,
        end_lr=0.0001,
        power=1.0,
        cycle=False,
        last_epoch=-1,
        verbose=False,
    ):
        # Validate with explicit raises (not ``assert``) so the checks also run
        # under ``python -O``; AssertionError is kept for backward compatibility.
        # The type test comes first so non-numeric input reports this message
        # instead of a TypeError from the ``>`` comparison.
        if not (isinstance(decay_steps, int) and decay_steps > 0):
            raise AssertionError(" 'decay_steps' must be a positive integer.")
        self.decay_steps = decay_steps
        self.end_lr = end_lr
        if not (power > 0.0):
            raise AssertionError(
                " 'power' must be greater than 0.0 so that the learning rate will decay."
            )
        self.power = power
        self.cycle = cycle
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        epoch = self.last_epoch
        decay_steps = self.decay_steps
        if self.cycle:
            # Stretch the decay window to the smallest multiple of decay_steps
            # that covers the current epoch; epoch 0 uses a single window.
            if self.last_epoch == 0:
                cycles = 1
            else:
                cycles = math.ceil(
                    float(self.last_epoch) / float(self.decay_steps)
                )
            decay_steps = self.decay_steps * cycles
        else:
            # Without cycling the rate stays at end_lr once decay_steps is reached.
            epoch = min(self.last_epoch, self.decay_steps)

        remaining = 1 - float(epoch) / float(decay_steps)
        return (self.base_lr - self.end_lr) * (remaining**self.power) + self.end_lr
758 759


class LinearWarmup(LRScheduler):
    r"""

    Linear learning rate warm up strategy. Update the learning rate preliminarily before the normal learning rate scheduler.
    For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_

    When epoch < warmup_steps, learning rate is updated as:

    .. math::

            lr = start\_lr + (end\_lr - start\_lr) * \frac{epoch}{warmup\_steps}

    where start_lr is the initial learning rate, and end_lr is the final learning rate;

    When epoch >= warmup_steps, learning rate is updated as:

    .. math::

            lr = learning\_rate

    where ``learning_rate`` is an int, a float or any subclass of ``LRScheduler`` . When it is an
    ``LRScheduler``, that scheduler is stepped with ``epoch - warmup_steps`` and its learning rate is used.

    Args:
        learning_rate (int|float|LRScheduler): The learning rate after warm-up. It is a python int/float number or any subclass of ``LRScheduler`` .
        warmup_steps (int): total steps of warm up. It must be a positive integer.
        start_lr (float): Initial learning rate of warm up.
        end_lr (float): Final learning rate of warm up. It must be greater than ``start_lr`` .
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``LinearWarmup`` instance to schedule learning rate.

    Examples:

        .. code-block:: python
            :name: code-example1

            # Example1: train on default dynamic graph mode
            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.LinearWarmup(
                    learning_rate=0.5, warmup_steps=20, start_lr=0, end_lr=0.5, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

        .. code-block:: python
            :name: code-example2

            # Example2: train on static graph mode
            import paddle
            import numpy as np
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.LinearWarmup(
                    learning_rate=0.5, warmup_steps=20, start_lr=0, end_lr=0.5, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(
        self,
        learning_rate,
        warmup_steps,
        start_lr,
        end_lr,
        last_epoch=-1,
        verbose=False,
    ):
        type_check = isinstance(learning_rate, (float, int, LRScheduler))
        if not type_check:
            # Report the offending *type*; formatting the value itself made the
            # message read e.g. "the current type is abc".
            raise TypeError(
                "the type of learning_rate should be [int, float or LRScheduler], "
                f"the current type is {type(learning_rate)}"
            )
        self.learning_rate = learning_rate
        assert warmup_steps > 0 and isinstance(
            warmup_steps, int
        ), " 'warmup_steps' must be a positive integer."
        self.warmup_steps = warmup_steps
        self.start_lr = start_lr
        self.end_lr = end_lr
        assert (
            end_lr > start_lr
        ), f"end_lr {end_lr} must be greater than start_lr {start_lr}"
        # Warm-up begins at ``start_lr``, so that is the initial learning rate
        # handed to the base class.
        super().__init__(start_lr, last_epoch, verbose)

    def state_dict(self):
        """
        Returns the state of the LinearWarmup scheduler as a :class:`dict`.

        It is a subset of ``self.__dict__`` . When the post-warm-up
        ``learning_rate`` is itself an ``LRScheduler``, its state is nested
        under the ``"LinearWarmup_LR"`` key.
        """
        state_dict = super().state_dict()
        if isinstance(self.learning_rate, LRScheduler):
            state_dict["LinearWarmup_LR"] = self.learning_rate.state_dict()
        return state_dict

    def set_state_dict(self, state_dict):
        """
        Loads state_dict for LinearWarmup scheduler.

        When the wrapped ``learning_rate`` is an ``LRScheduler``, the nested
        ``"LinearWarmup_LR"`` entry is loaded into it as well.
        """
        super().set_state_dict(state_dict)
        if isinstance(self.learning_rate, LRScheduler):
            self.learning_rate.set_state_dict(state_dict["LinearWarmup_LR"])

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Linear interpolation from start_lr towards end_lr.
            return (self.end_lr - self.start_lr) * float(
                self.last_epoch
            ) / float(self.warmup_steps) + self.start_lr
        else:
            if isinstance(self.learning_rate, LRScheduler):
                # Drive the wrapped scheduler with the number of epochs that
                # have elapsed since warm-up finished.
                self.learning_rate.step(self.last_epoch - self.warmup_steps)
                return self.learning_rate()

            return self.learning_rate


class ExponentialDecay(LRScheduler):
    r"""

    Update learning rate by `gamma` each epoch.

    Every epoch the learning rate is multiplied by ``gamma``, so after ``epoch``
    epochs it equals ``learning_rate * gamma ** epoch``:

    .. math::

        new\_learning\_rate = last\_learning\_rate * gamma

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float): Multiplicative decay factor applied once per epoch, ``new_lr = origin_lr * gamma`` .
            It must lie in the open interval (0.0, 1.0).
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``ExponentialDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python
            :name: code-example1

            # Example1: train on default dynamic graph mode
            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.ExponentialDecay(learning_rate=0.5, gamma=0.9, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

        .. code-block:: python
            :name: code-example2

            # Example2: train on static graph mode
            import paddle
            import numpy as np
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.ExponentialDecay(learning_rate=0.5, gamma=0.9, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        # gamma outside (0, 1) would either grow or zero out the rate.
        assert (
            0.0 < gamma < 1.0
        ), " 'gamma' must be in interval (0.0, 1.0) so that the learning rate will decay."
        self.gamma = gamma
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        decay_factor = self.gamma**self.last_epoch
        return self.base_lr * decay_factor


class MultiStepDecay(LRScheduler):
    """
    Update the learning rate by ``gamma`` once ``epoch`` reaches one of the milestones.

    The learning rate equals ``learning_rate * gamma ** k``, where ``k`` is the number
    of milestones that ``epoch`` has already reached. For example:

    .. code-block:: text

        learning_rate = 0.5
        milestones = [30, 50]
        gamma = 0.1
        if epoch < 30:
            learning_rate = 0.5
        elif epoch < 50:
            learning_rate = 0.05
        else:
            learning_rate = 0.005

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        milestones (tuple|list): Epoch boundaries at which the learning rate is multiplied by ``gamma``. They must be strictly increasing.
        gamma (float, optional): The ratio by which the learning rate is reduced at each milestone, ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``MultiStepDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python
            :name: code-example1

            # Example1: train on default dynamic graph mode
            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

        .. code-block:: python
            :name: code-example2

            # Example2: train on static graph mode
            import paddle
            import numpy as np
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(
        self, learning_rate, milestones, gamma=0.1, last_epoch=-1, verbose=False
    ):
        if not isinstance(milestones, (tuple, list)):
            raise TypeError(
                "The type of 'milestones' in 'MultiStepDecay' must be 'tuple, list', but received %s."
                % type(milestones)
            )

        # Compare each milestone with its successor.
        strictly_increasing = all(
            earlier < later
            for earlier, later in zip(milestones, milestones[1:])
        )
        if not strictly_increasing:
            raise ValueError('The elements of milestones must be incremented')
        if gamma >= 1.0:
            raise ValueError('gamma should be < 1.0.')

        self.milestones = milestones
        self.gamma = gamma
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        # Index of the first milestone not yet reached (== number of milestones
        # already passed); all of them when the epoch is past the last one.
        passed = next(
            (
                idx
                for idx, milestone in enumerate(self.milestones)
                if self.last_epoch < milestone
            ),
            len(self.milestones),
        )
        return self.base_lr * (self.gamma**passed)


class StepDecay(LRScheduler):
    """
    Update the learning rate of ``optimizer`` by ``gamma`` every ``step_size`` number of epoch.

    After ``epoch`` epochs the learning rate is ``learning_rate * gamma ** (epoch // step_size)``.
    For example:

    .. code-block:: text

        learning_rate = 0.5
        step_size = 30
        gamma = 0.1

        learning_rate = 0.5     if epoch < 30
        learning_rate = 0.05    if 30 <= epoch < 60
        learning_rate = 0.005   if 60 <= epoch < 90
        ...

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        step_size (int): the interval to update. It must be a positive integer.
        gamma (float, optional): The ratio by which the learning rate is reduced every ``step_size`` epochs, ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``StepDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python
            :name: code-example1

            # Example1: train on default dynamic graph mode
            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=5, gamma=0.8, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

        .. code-block:: python
            :name: code-example2

            # Example2: train on static graph mode
            import paddle
            import numpy as np
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=5, gamma=0.8, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(
        self, learning_rate, step_size, gamma=0.1, last_epoch=-1, verbose=False
    ):
        if not isinstance(step_size, int):
            raise TypeError(
                "The type of 'step_size' must be 'int', but received %s."
                % type(step_size)
            )
        if gamma >= 1.0:
            raise ValueError('gamma should be < 1.0.')

        # The integer type was enforced above; only positivity is left to check.
        assert step_size > 0, " 'step_size' must be a positive integer."

        self.step_size = step_size
        self.gamma = gamma
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        completed_intervals = self.last_epoch // self.step_size
        return self.base_lr * (self.gamma**completed_intervals)


class LambdaDecay(LRScheduler):
    """
    Sets the learning rate of ``optimizer`` by function ``lr_lambda`` . ``lr_lambda`` is a callable which receives ``epoch`` .

    The learning rate is ``learning_rate * lr_lambda(epoch)``. For example:

    .. code-block:: text

        learning_rate = 0.5        # init learning_rate
        lr_lambda = lambda epoch: 0.95 ** epoch

        learning_rate = 0.5        # epoch 0, 0.5*0.95**0
        learning_rate = 0.475      # epoch 1, 0.5*0.95**1
        learning_rate = 0.45125    # epoch 2, 0.5*0.95**2

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        lr_lambda (function): A callable which computes a factor from ``epoch`` ; the initial learning rate is multiplied by this factor.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``LambdaDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python
            :name: code-example1

            # Example1: train on default dynamic graph mode
            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.LambdaDecay(learning_rate=0.5, lr_lambda=lambda x:0.95**x, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

        .. code-block:: python
            :name: code-example2

            # Example2: train on static graph mode
            import paddle
            import numpy as np
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.LambdaDecay(learning_rate=0.5, lr_lambda=lambda x:0.95**x, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(self, learning_rate, lr_lambda, last_epoch=-1, verbose=False):
        # Reject anything that cannot be invoked as ``lr_lambda(epoch)``.
        if not callable(lr_lambda):
            received = type(lr_lambda)
            raise TypeError(
                "The type of 'lr_lambda' in 'LambdaDecay' must be 'function', but received %s."
                % received
            )

        self.lr_lambda = lr_lambda
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        factor = self.lr_lambda(self.last_epoch)
        return self.base_lr * factor


1321
class ReduceOnPlateau(LRScheduler):
1322
    """
G
guguguzi 已提交
1323
    Reduce learning rate when ``metrics`` has stopped descending. Models often benefit from reducing the learning rate
1324 1325
    by 2 to 10 times once model performance has no longer improvement.

1326
    The ``metrics`` is the one which has been pass into ``step`` , it's shape must [] or [1]. When ``metrics``
G
guguguzi 已提交
1327 1328
    stop descending for a ``patience`` number of epochs, the learning rate will be reduced to ``learning_rate * factor`` .
    (Specially, ``mode`` can also be set to ``'max`` , in this case, when ``metrics`` stop ascending for a ``patience``
1329 1330 1331 1332 1333 1334
    number of epochs, the learning rate will be reduced.)

    In addition, After each reduction, it will wait a ``cooldown`` number of epochs before resuming above operation.

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
G
guguguzi 已提交
1335 1336
        mode (str, optional): ``'min'`` or ``'max'`` can be selected. Normally, it is ``'min'`` , which means that the
            learning rate will reduce when ``loss`` stops descending. Specially, if it's set to ``'max'`` ,  the learning
1337
            rate will reduce when ``loss`` stops ascending. Default: ``'min'`` .
G
guguguzi 已提交
1338
        factor (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * factor`` .
1339
            It should be less than 1.0. Default: 0.1.
G
guguguzi 已提交
1340
        patience (int, optional): When ``loss`` doesn't improve for this number of epochs, learing rate will be reduced.
1341
            Default: 10.
G
guguguzi 已提交
1342
        threshold (float, optional): ``threshold`` and ``threshold_mode`` will determine the minimum change of ``loss`` .
1343 1344
            This make tiny changes of ``loss`` will be ignored. Default: 1e-4.
        threshold_mode (str, optional): ``'rel'`` or ``'abs'`` can be selected. In ``'rel'`` mode, the minimum change of ``loss``
G
guguguzi 已提交
1345
            is ``last_loss * threshold`` , where ``last_loss`` is ``loss`` in last epoch. In ``'abs'`` mode, the minimum
1346 1347 1348
            change of ``loss`` is ``threshold`` . Default: ``'rel'`` .
        cooldown (int, optional): The number of epochs to wait before resuming normal operation. Default: 0.
        min_lr (float, optional): The lower bound of the learning rate after reduction. Default: 0.
G
guguguzi 已提交
1349
        epsilon (float, optional): Minimal decay applied to lr. If the difference between new and old lr is smaller than epsilon,
1350
            the update is ignored. Default: 1e-8.
1351 1352
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``.

G
guguguzi 已提交
1353

1354
    Returns:
1355
        ``ReduceOnPlateau`` instance to schedule learning rate.
1356 1357 1358 1359


    Examples:
        .. code-block:: python
1360
            :name: code-example1
1361

1362
            # Example1: train on default dynamic graph mode
1363 1364 1365
            import paddle
            import numpy as np

1366
            # train on default dynamic graph mode
1367
            linear = paddle.nn.Linear(10, 10)
1368 1369
            scheduler = paddle.optimizer.lr.ReduceOnPlateau(learning_rate=1.0, factor=0.5, patience=5, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
1370
            for epoch in range(20):
Z
Zhou Wei 已提交
1371
                for batch_id in range(5):
1372
                    x = paddle.uniform([10, 10])
1373
                    out = linear(x)
C
chentianyu03 已提交
1374
                    loss = paddle.mean(out)
1375
                    loss.backward()
1376 1377
                    sgd.step()
                    sgd.clear_gradients()
Z
Zhou Wei 已提交
1378 1379
                    scheduler.step(loss)    # If you update learning rate each step
              # scheduler.step(loss)        # If you update learning rate each epoch
1380

1381 1382 1383 1384 1385 1386
        .. code-block:: python
            :name: code-example2

            # Example2: train on static graph mode
            import paddle
            import numpy as np
1387 1388 1389 1390
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
1391 1392
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
1393 1394
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
1395
                scheduler = paddle.optimizer.lr.ReduceOnPlateau(learning_rate=1.0, factor=0.5, patience=5, verbose=True)
1396 1397 1398 1399 1400 1401
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
Z
Zhou Wei 已提交
1402
                for batch_id in range(5):
1403 1404 1405 1406 1407 1408
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
1409
                        fetch_list=loss.name)
Z
Zhou Wei 已提交
1410 1411
                    scheduler.step(out[0])    # If you update learning rate each step
              # scheduler.step(out[0])        # If you update learning rate each epoch
1412 1413 1414

    """

1415 1416 1417 1418 1419 1420 1421 1422 1423 1424 1425 1426 1427
    def __init__(
        self,
        learning_rate,
        mode='min',
        factor=0.1,
        patience=10,
        threshold=1e-4,
        threshold_mode='rel',
        cooldown=0,
        min_lr=0,
        epsilon=1e-8,
        verbose=False,
    ):
        """
        Validate the arguments and initialise the scheduler state directly
        (without calling the base-class initialiser).

        Raises:
            ValueError: If ``mode`` is not ``'min'``/``'max'``, ``factor`` is
                not less than 1.0, or ``threshold_mode`` is not ``'rel'``/``'abs'``.
            TypeError: If ``learning_rate`` is neither an ``int`` nor a ``float``.
        """
        mode = mode.lower()
        if mode not in ['min', 'max']:
            raise ValueError('mode: ' + mode + ' is unknown!')
        self.mode = mode

        if factor >= 1.0:
            # The parameter is called ``factor``; naming it "gamma" (as the
            # message previously did) pointed users at a nonexistent argument.
            raise ValueError(
                'new_lr = origin_lr * factor and factor should be < 1.0.'
            )
        self.factor = factor

        threshold_mode = threshold_mode.lower()
        if threshold_mode not in ['rel', 'abs']:
            raise ValueError(
                'threshold mode: ' + threshold_mode + ' is unknown!'
            )
        self.threshold_mode = threshold_mode

        if not isinstance(learning_rate, (float, int)):
            raise TypeError(
                "The type of 'learning_rate' in 'ReduceOnPlateau' must be 'float', but received %s."
                % type(learning_rate)
            )

        self.patience = patience
        self.threshold = threshold
        self.cooldown = cooldown
        self.min_lr = min_lr
        self.epsilon = epsilon

        # Plateau bookkeeping; these attributes are listed in ``state_keys``.
        self.cooldown_counter = 0
        self.best = None
        self.num_bad_epochs = 0

        # Can not call Parent __init__, so implement here.
        # NOTE(review): presumably because the base initialiser performs a
        # ``step()`` and this class's ``step`` requires ``metrics`` — confirm.
        self.base_lr = float(learning_rate)
        self.last_lr = float(learning_rate)
        self.last_epoch = 0
        self.verbose = verbose
        self._var_name = None

    # "cooldown_counter / best / num_bad_epochs / last_epoch / last_lr" will be stored.
    def state_keys(self):
        """Record the names of the attributes that make up this scheduler's state.

        Besides ``last_epoch`` and ``last_lr``, the plateau-tracking counters
        (``cooldown_counter``, ``best``, ``num_bad_epochs``) are listed so that
        they are stored together with the learning rate.

        Sets ``self.keys``; returns None.
        """
        self.keys = [
            'cooldown_counter',
            'best',
            'num_bad_epochs',
            'last_epoch',
            'last_lr',
        ]

    def step(self, metrics, epoch=None):
        """

        step should be called after `optimizer.step()` . It will update the learning rate in optimizer according to ``metrics`` .
        The new learning rate will take effect on next epoch.

        Args:
            metrics (Tensor|numpy.ndarray|float|int): Which will be monitored to determine whether the learning rate will reduce.
                If it has not improved (decreased in 'min' mode, increased in 'max' mode) for more than ``patience``
                consecutive epochs outside of cooldown, the learning rate will be reduced. If it's 'Tensor' or
                'numpy.ndarray', its numel must be 1.
            epoch (int, None): specify current epoch. Default: None, which increments ``last_epoch`` by one
                (``last_epoch`` is 0 right after construction).

        Returns:
            None

        Examples:
            Please refer to the example of current LRScheduler.
        """
        if epoch is None:
            self.last_epoch = self.last_epoch + 1
        else:
            self.last_epoch = epoch

        # loss must be float, numpy.ndarray or 1-D Tensor with numel 1
        if isinstance(metrics, (core.eager.Tensor, numpy.ndarray)):
            assert metrics.size == 1, (
                "the size of metrics must be 1, but the current metrics.size is {}. Maybe that "
                "you should call paddle.mean to process it first.".format(
                    metrics.size
                )
            )
        elif not isinstance(
            metrics, (int, float, numpy.float32, numpy.float64)
        ):
            raise TypeError(
                "metrics must be 'int', 'float', 'np.float64', 'numpy.ndarray' or 'paddle.Tensor', but receive {}".format(
                    type(metrics)
                )
            )

        if self.cooldown_counter > 0:
            # While cooling down after a reduction, the metric is not
            # compared against ``best`` at all.
            self.cooldown_counter -= 1
        else:
            if self.best is None or self._is_better(metrics, self.best):
                self.best = metrics
                self.num_bad_epochs = 0
            else:
                self.num_bad_epochs += 1

            if self.num_bad_epochs > self.patience:
                self.cooldown_counter = self.cooldown
                self.num_bad_epochs = 0
                new_lr = max(self.last_lr * self.factor, self.min_lr)
                # Skip negligible updates, e.g. once ``min_lr`` has been reached.
                if self.last_lr - new_lr > self.epsilon:
                    self.last_lr = new_lr
                    if self.verbose:
                        print(
                            'Epoch {}: {} set learning rate to {}.'.format(
                                self.last_epoch,
                                self.__class__.__name__,
                                self.last_lr,
                            )
                        )

    def _is_better(self, current, best):
        if self.mode == 'min' and self.threshold_mode == 'rel':
            return current < best - best * self.threshold

        elif self.mode == 'min' and self.threshold_mode == 'abs':
            return current < best - self.threshold

        elif self.mode == 'max' and self.threshold_mode == 'rel':
            return current > best + best * self.threshold

        else:
            return current > best + self.threshold


class CosineAnnealingDecay(LRScheduler):
    r"""

    Set the learning rate using a cosine annealing schedule, where :math:`\eta_{max}` is set to
    the initial learning_rate. :math:`T_{cur}` is the number of epochs since the last restart in
    SGDR.

    The algorithm can be described as following.

    .. math::

        \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
        + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
        & T_{cur} \neq (2k+1)T_{max};

        \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
        \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
        & T_{cur} = (2k+1)T_{max}.

    It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts <https://arxiv.org/abs/1608.03983>`_.
    Note that this only implements the cosine annealing part of SGDR, and not the restarts.

    Args:
        learning_rate (float): The initial learning rate, that is :math:`\eta_{max}` . It can be set to python float or int number.
        T_max (int): Maximum number of iterations. It is half of the decay cycle of learning rate. It must be a positive integer.
        eta_min (float|int, optional): Minimum learning rate, that is :math:`\eta_{min}` . Default: 0.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``CosineAnnealingDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python
            :name: code-example1

            # Example1: train on default dynamic graph mode
            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.5, T_max=10, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

        .. code-block:: python
            :name: code-example2

            # Example2: train on static graph mode
            import paddle
            import numpy as np
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.5, T_max=10, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(
        self, learning_rate, T_max, eta_min=0, last_epoch=-1, verbose=False
    ):
        if not isinstance(T_max, int):
            raise TypeError(
                "The type of 'T_max' in 'CosineAnnealingDecay' must be 'int', but received %s."
                % type(T_max)
            )
        if not isinstance(eta_min, (float, int)):
            raise TypeError(
                "The type of 'eta_min' in 'CosineAnnealingDecay' must be 'float, int', but received %s."
                % type(eta_min)
            )
        # The int type of T_max is already guaranteed above; only its sign
        # remains to be checked (T_max is a divisor in get_lr).
        assert T_max > 0, " 'T_max' must be a positive integer."
        self.T_max = T_max
        self.eta_min = float(eta_min)
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        if self.last_epoch == 0:
            return self.base_lr
        elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
            # Here (last_epoch - 1) is an odd multiple of T_max, so the
            # denominator of the ratio form below, 1 + cos((2k+1)*pi), would be
            # zero; use the additive update from the class docstring instead.
            return (
                self.last_lr
                + (self.base_lr - self.eta_min)
                * (1 - math.cos(math.pi / self.T_max))
                / 2
            )

        # Recursive form: rescale the distance of last_lr above eta_min by the
        # ratio of the current and previous cosine factors.
        return (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / (
            1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)
        ) * (self.last_lr - self.eta_min) + self.eta_min

    def _get_closed_form_lr(self):
        # Non-recursive form of the schedule: depends only on last_epoch,
        # not on last_lr.
        return (
            self.eta_min
            + (self.base_lr - self.eta_min)
            * (1 + math.cos(math.pi * self.last_epoch / self.T_max))
            / 2
        )


class MultiplicativeDecay(LRScheduler):
    """
    Multiply the learning rate of ``optimizer`` by the factor given in function ``lr_lambda`` .

    The algorithm can be described as the code below.

    .. code-block:: text

        learning_rate = 0.5        # init learning_rate
        lr_lambda = lambda epoch: 0.95

        learning_rate = 0.5        # epoch 0,
        learning_rate = 0.475      # epoch 1, 0.5*0.95
        learning_rate = 0.45125    # epoch 2, 0.475*0.95

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        lr_lambda (function): A callable that takes the epoch index and returns a factor; the learning rate
            of each epoch is the learning rate of the previous epoch multiplied by this factor.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``MultiplicativeDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.MultiplicativeDecay(learning_rate=0.5, lr_lambda=lambda x:0.95, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(self, learning_rate, lr_lambda, last_epoch=-1, verbose=False):
        if not callable(lr_lambda):
            raise TypeError(
                "The type of 'lr_lambda' in 'MultiplicativeDecay' must be 'function', but received %s."
                % type(lr_lambda)
            )

        self.lr_lambda = lr_lambda
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        # The product base_lr * lr_lambda(1) * ... * lr_lambda(last_epoch) is
        # recomputed from scratch instead of multiplying last_lr once, so the
        # result depends only on last_epoch (at the cost of last_epoch calls
        # to lr_lambda per step).
        cur_lr = self.base_lr
        for epoch in range(1, self.last_epoch + 1):
            cur_lr = cur_lr * self.lr_lambda(epoch)
        return cur_lr


class OneCycleLR(LRScheduler):
    r"""

    Sets the learning rate according to the one cycle learning rate scheduler.
    The scheduler adjusts the learning rate from an initial learning rate to the maximum learning rate and then
    from that maximum learning rate to the minimum learning rate, which is much less than the initial learning rate.

    It has been proposed in `Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates <https://arxiv.org/abs/1708.07120>`_.

    Please note that the default behaviour of this scheduler follows the fastai implementation of one cycle,
    which claims that “unpublished work has shown even better results by using only two phases”.
    If you want the behaviour of this scheduler to be consistent with the paper, please set ``three_phase=True`` .

    Also note that you should update learning rate each step.

    Args:
        max_learning_rate (float): The maximum learning rate. It is a python float number. Functionally, it defines the initial learning rate by ``divide_factor`` .
        total_steps (int): Number of total training steps.
        divide_factor (float, optional): Initial learning rate will be determined by initial_learning_rate = max_learning_rate / divide_factor. It must be positive. Default: 25.
        end_learning_rate (float, optional): The minimum learning rate during training, it should be much less than initial learning rate. Default: 0.0001.
        phase_pct (float, optional): The percentage of total steps which is used to increase the learning rate. Default: 0.3.
        anneal_strategy (str, optional): Strategy of adjusting learning rate. 'cos' for cosine annealing, 'linear' for linear annealing. Default: 'cos'.
        three_phase (bool, optional): Whether to use three phases. Default: ``False`` .

            If ``True``:

                1. The learning rate will first increase from initial learning rate to maximum learning rate.
                2. Then it will decrease to initial learning rate. Number of steps in this phase is the same as the one in first phase.
                3. Finally, it will decrease to minimum learning rate which is much less than initial learning rate.

            If ``False``:

                1. The learning rate will increase to maximum learning rate.
                2. Then it will directly decrease to minimum learning rate.

        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``OneCycleLR`` instance to schedule learning rate.

    Examples:
        .. code-block:: python
            :name: code-example1

            # Example1: train on default dynamic graph mode
            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(5):
                for batch_id in range(20):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()        # You should update learning rate each step

        .. code-block:: python
            :name: code-example2

            # Example2: train on static graph mode
            import paddle
            import numpy as np
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(5):
                for batch_id in range(20):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # You should update learning rate each step

    """

    def __init__(
        self,
        max_learning_rate,
        total_steps,
        divide_factor=25.0,
        end_learning_rate=0.0001,
        phase_pct=0.3,
        anneal_strategy='cos',
        three_phase=False,
        last_epoch=-1,
        verbose=False,
    ):
        # Check type and value of max_learning_rate
        if not isinstance(max_learning_rate, (float, int)):
            raise TypeError(
                "'max_learning_rate' must be 'float' or 'int', but received {}".format(
                    type(max_learning_rate)
                )
            )
        if max_learning_rate < 0:
            raise ValueError(
                "'max_learning_rate' must be a non-negative number."
            )

        # Check type and value of end_learning_rate
        if not isinstance(end_learning_rate, (float, int)):
            raise TypeError(
                "'end_learning_rate' must be 'float' or 'int', but received {}".format(
                    type(end_learning_rate)
                )
            )
        if end_learning_rate < 0:
            raise ValueError(
                "'end_learning_rate' must be a non-negative number."
            )

        # Check type and value of total_steps
        if not isinstance(total_steps, int):
            raise TypeError(
                "'total_steps' must be 'int', but received {}".format(
                    type(total_steps)
                )
            )
        if total_steps <= 0:
            raise ValueError("'total_steps' must be a positive integer.")
        self.total_steps = total_steps

        # Check type and value of phase_pct
        if not isinstance(phase_pct, float):
            raise TypeError(
                "'phase_pct' must be 'float', but received {}".format(
                    type(phase_pct)
                )
            )
        if phase_pct < 0 or phase_pct > 1:
            raise ValueError(
                "'phase_pct' must be between 0 and 1, but received {}".format(
                    phase_pct
                )
            )

        # Check type and value of divide_factor
        if not isinstance(divide_factor, (float, int)):
            raise TypeError(
                "'divide_factor' must be 'float' or 'int', but received {}".format(
                    type(divide_factor)
                )
            )
        # divide_factor is used as a divisor below; zero would raise
        # ZeroDivisionError and a negative value a negative initial lr.
        if divide_factor <= 0:
            raise ValueError(
                "'divide_factor' must be a positive number, but received {}".format(
                    divide_factor
                )
            )

        initial_lr = max_learning_rate / float(divide_factor)
        min_lr = float(end_learning_rate)

        # Phase i runs from step _step_config[i] to step _step_config[i + 1],
        # lasts _steps_size[i] steps, and anneals from _lr_config[i] to
        # _lr_config[i + 1].
        if three_phase:
            if phase_pct >= 0.5:
                raise ValueError(
                    "When three_phase is True, 'phase_pct' must be less than 0.5"
                )
            # start step and end step of each phase.
            self._step_config = [
                0,
                phase_pct * self.total_steps - 1,
                2 * phase_pct * self.total_steps - 2,
                self.total_steps - 1,
                self.total_steps - 1,  # for the last step.
            ]
            # step size of each phase.
            self._steps_size = [
                self._step_config[1] - self._step_config[0],
                self._step_config[2] - self._step_config[1],
                self._step_config[3] - self._step_config[2],
                self._step_config[3] - self._step_config[2],  # for the last step.
            ]
            # start lr and end lr of each phase.
            self._lr_config = [
                initial_lr,
                max_learning_rate,
                initial_lr,
                min_lr,
            ]
        else:
            self._step_config = [
                0,
                phase_pct * self.total_steps - 1,
                self.total_steps - 1,
                self.total_steps - 1,
            ]
            self._steps_size = [
                self._step_config[1] - self._step_config[0],
                self._step_config[2] - self._step_config[1],
                self._step_config[2] - self._step_config[1],
            ]
            self._lr_config = [initial_lr, max_learning_rate, min_lr]

        # Check anneal_strategy
        if anneal_strategy == 'cos':
            self.anneal_func = self._cos_annealing
        elif anneal_strategy == 'linear':
            self.anneal_func = self._linear_annealing
        else:
            raise ValueError(
                "'anneal_strategy' must be one of 'cos' or 'linear', but received {}".format(
                    anneal_strategy
                )
            )
        super().__init__(initial_lr, last_epoch, verbose)

    def _cos_annealing(self, start_lr, end_lr, pct):
        # Cosine interpolation: start_lr at pct == 0, end_lr at pct == 1.
        cos_out = math.cos(math.pi * pct) + 1
        return end_lr + (start_lr - end_lr) / 2.0 * cos_out

    def _linear_annealing(self, start_lr, end_lr, pct):
        # Linear interpolation: start_lr at pct == 0, end_lr at pct == 1.
        return (end_lr - start_lr) * pct + start_lr

    def get_lr(self):
        current_step = self.last_epoch

        if current_step > self.total_steps:
            raise ValueError(
                "Tried to step {} times. However the number of total steps is {}".format(
                    current_step, self.total_steps
                )
            )

        last_phase = len(self._lr_config) - 2
        for i, (end_step, step_size) in enumerate(
            zip(self._step_config[1:], self._steps_size)
        ):
            # i == last_phase catches the final step(s); otherwise the loop
            # could fall through and return None.
            if current_step <= end_step or i == last_phase:
                if step_size == 0:
                    # A zero-length phase (e.g. phase_pct * total_steps == 1,
                    # or phase_pct == 1.0 at the final step) would divide by
                    # zero; treat it as already finished.
                    percentage = 1.0
                else:
                    # self._step_config[i] is the start step of phase i.
                    percentage = (
                        current_step - self._step_config[i]
                    ) / step_size
                return self.anneal_func(
                    self._lr_config[i], self._lr_config[i + 1], percentage
                )


class CyclicLR(LRScheduler):
    r"""
    Set the learning rate according to the cyclic learning rate (CLR) scheduler.
    The scheduler regards the process of learning rate adjustment as one cycle after another.
    It cycles the learning rate between two boundaries with a constant frequency.
    The distance between the two boundaries can be scaled on a per-iteration or per-cycle basis.

    It has been proposed in `Cyclical Learning Rates for Training Neural Networks <https://arxiv.org/abs/1506.01186>`_.

    According to the paper, the cyclic learning rate schedule has three built-in scale methods:

    * "triangular": A basic triangular cycle without any amplitude scaling.
    * "triangular2": A basic triangular cycle that reduces the initial amplitude by half each cycle.
    * "exp_range": A cycle that scales the initial amplitude by the scale function :math:`gamma^{iterations}` .

    The initial amplitude is defined as max_learning_rate - base_learning_rate.
    Also note that you should update the learning rate each step.

    Args:
        base_learning_rate (float): Initial learning rate, which is the lower boundary in the cycle. The paper recommends
            setting base_learning_rate to 1/3 or 1/4 of max_learning_rate.
        max_learning_rate (float): Maximum learning rate in the cycle. It defines the cycle amplitude as above.
            Since there is some scaling operation during the process of learning rate adjustment,
            max_learning_rate may not actually be reached.
        step_size_up (int): Number of training steps used to increase the learning rate in a cycle.
            The step size of one cycle is step_size_up + step_size_down. According to the paper, the step
            size should be at least 3 or 4 times the number of steps in one epoch.
        step_size_down (int, optional): Number of training steps used to decrease the learning rate in a cycle.
            If not specified, its value is set to ``step_size_up`` . Default: None
        mode (str, optional): One of 'triangular', 'triangular2' or 'exp_range'.
            If scale_fn is specified, this argument will be ignored. Default: 'triangular'
        exp_gamma (float, optional): Constant in the 'exp_range' scaling function: exp_gamma**iterations.
            Used only when mode = 'exp_range'. Default: 1.0
        scale_fn (function, optional): A custom scaling function, which is used to replace the three built-in methods.
            It should take exactly one argument. For all x >= 0, 0 <= scale_fn(x) <= 1.
            If specified, then 'mode' will be ignored. Default: None
        scale_mode (str, optional): One of 'cycle' or 'iterations'. Defines whether scale_fn is evaluated on the cycle
            number or on cycle iterations (total iterations since start of training). It only takes effect when
            scale_fn is specified; the built-in modes choose their own scale mode. Default: 'cycle'
        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``CyclicLR`` instance to schedule learning rate.

    Examples:
        .. code-block:: python
            :name: code-example1

            # Example1: train on default dynamic graph mode
            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5, max_learning_rate=1.0, step_size_up=15, step_size_down=5, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(5):
                for batch_id in range(20):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()        # You should update learning rate each step

        .. code-block:: python
            :name: code-example2

            # Example2: train on static graph mode
            import paddle
            import numpy as np

            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5,
                    max_learning_rate=1.0, step_size_up=15, step_size_down=5, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(5):
                for batch_id in range(20):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # You should update learning rate each step
    """

    def __init__(
        self,
        base_learning_rate,
        max_learning_rate,
        step_size_up,
        step_size_down=None,
        mode='triangular',
        exp_gamma=1.0,
        scale_fn=None,
        scale_mode='cycle',
        last_epoch=-1,
        verbose=False,
    ):
        # check type and value of max_learning_rate
        if not isinstance(max_learning_rate, (float, int)):
            raise TypeError(
                "'max_learning_rate' must be 'float' or 'int', but received {}".format(
                    type(max_learning_rate)
                )
            )
        if max_learning_rate < 0:
            raise ValueError(
                "'max_learning_rate' must be a non-negative number, but received {}".format(
                    max_learning_rate
                )
            )

        # check type and value of step_size_up
        if not isinstance(step_size_up, int):
            raise TypeError(
                "The type of 'step_size_up' must be int, but received {}".format(
                    type(step_size_up)
                )
            )
        if step_size_up <= 0:
            raise ValueError(
                "'step_size_up' must be a positive integer, but received {}".format(
                    step_size_up
                )
            )

        # check type and value of step_size_down
        if step_size_down is not None:
            if not isinstance(step_size_down, int):
                raise TypeError(
                    "The type of 'step_size_down' must be int, but received {}".format(
                        type(step_size_down)
                    )
                )
            if step_size_down <= 0:
                raise ValueError(
                    "'step_size_down' must be a positive integer, but received {}".format(
                        step_size_down
                    )
                )

        # check type of exp_gamma
        if not isinstance(exp_gamma, float):
            raise TypeError(
                "The type of 'exp_gamma' must be float, but received {}".format(
                    type(exp_gamma)
                )
            )

        # Work in floats from here on. This also makes cycle_size a float,
        # so the cycle index computed in get_lr is a float too.
        step_size_up = float(step_size_up)
        step_size_down = (
            float(step_size_down)
            if step_size_down is not None
            else step_size_up
        )

        self.cycle_size = step_size_up + step_size_down
        # Fraction of each cycle spent increasing the learning rate.
        self.step_up_pct = step_size_up / self.cycle_size
        self.max_lr = float(max_learning_rate)
        self.amplitude = self.max_lr - base_learning_rate

        if (
            mode not in ['triangular', 'triangular2', 'exp_range']
            and scale_fn is None
        ):
            raise ValueError(
                "'mode' is invalid and 'scale_fn' is not specified, make sure one of 'mode' or 'scale_fn' is valid"
            )
        if scale_mode not in ['cycle', 'iterations']:
            raise ValueError(
                "'scale_mode' must be one of 'cycle' or 'iterations', but received {}".format(
                    scale_mode
                )
            )

        self.mode = mode
        self.gamma = exp_gamma  # only for exp_range mode

        if scale_fn is None:
            # Built-in modes fix their own scale_mode; the argument is ignored.
            if self.mode == 'triangular':
                self.scale_fn = self._triangular_scale_fn
                self.scale_mode = 'cycle'
            elif self.mode == 'triangular2':
                self.scale_fn = self._triangular2_scale_fn
                self.scale_mode = 'cycle'
            elif self.mode == 'exp_range':
                self.scale_fn = self._exp_range_scale_fn
                self.scale_mode = 'iterations'
        else:
            self.scale_fn = scale_fn
            self.scale_mode = scale_mode
        super().__init__(base_learning_rate, last_epoch, verbose)

    def _triangular_scale_fn(self, x):
        """Constant scale of 1: no amplitude decay ('triangular' mode)."""
        return 1.0

    def _triangular2_scale_fn(self, x):
        """Halve the amplitude every cycle ('triangular2' mode); x is the 1-based cycle index."""
        return 1 / (2.0 ** (x - 1))

    def _exp_range_scale_fn(self, x):
        """Scale by exp_gamma**x ('exp_range' mode); x is the iteration count."""
        return self.gamma**x

    def get_lr(self):
        """
        Compute the learning rate for the current iteration (``self.last_epoch``).

        Returns:
            float: ``base_lr + amplitude * triangle(pct) * scale_fn(arg)``.
            ``arg`` is the cycle index or the iteration count, depending on
            ``self.scale_mode``.
        """
        iterations = self.last_epoch

        # 1-based index of the current cycle (a float, since cycle_size is a float).
        cycle = 1 + iterations // self.cycle_size
        # Fraction of the current cycle already elapsed, 0 at the start of a cycle.
        pct_per_cycle = 1.0 + iterations / self.cycle_size - cycle

        if pct_per_cycle <= self.step_up_pct:
            # Rising edge of the triangle: 0 -> 1.
            scale_factor = pct_per_cycle / self.step_up_pct
        else:
            # Falling edge of the triangle: 1 -> 0.
            scale_factor = (1 - pct_per_cycle) / (1 - self.step_up_pct)

        base_height = self.amplitude * scale_factor

        # Select the scale_fn argument explicitly instead of eval()-ing the
        # attribute name; __init__ restricts scale_mode to 'cycle' / 'iterations'.
        scale_input = cycle if self.scale_mode == 'cycle' else iterations
        lr = self.base_lr + base_height * self.scale_fn(scale_input)

        return lr