# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import warnings

import numpy

import paddle.fluid.core as core
from paddle import Tensor

__all__ = [  # noqa
    'LRScheduler',
    'NoamDecay',
    'PiecewiseDecay',
    'NaturalExpDecay',
    'InverseTimeDecay',
    'PolynomialDecay',
    'LinearWarmup',
    'ExponentialDecay',
    'MultiStepDecay',
    'StepDecay',
    'LambdaDecay',
    'ReduceOnPlateau',
    'CosineAnnealingDecay',
    'MultiplicativeDecay',
    'OneCycleLR',
    'CyclicLR',
]


class LRScheduler:
    """

    LRScheduler Base class. Define the common interface of a learning rate scheduler.

    Users can import it by ``from paddle.optimizer.lr import LRScheduler`` ,

    then override it in your subclass and provide a custom implementation of ``get_lr()`` .

    Otherwise, a ``NotImplementedError`` exception will be raised.

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        instance to schedule learning rate.

    Examples:
        Here is an example of a simple ``StepDecay`` implementation.

        .. code-block:: python

            import paddle
            from paddle.optimizer.lr import LRScheduler

            class StepDecay(LRScheduler):
                def __init__(self,
                            learning_rate,
                            step_size,
                            gamma=0.1,
                            last_epoch=-1,
                            verbose=False):
                    if not isinstance(step_size, int):
                        raise TypeError(
                            "The type of 'step_size' must be 'int', but received %s." %
                            type(step_size))
                    if gamma >= 1.0:
                        raise ValueError('gamma should be < 1.0.')

                    self.step_size = step_size
                    self.gamma = gamma
                    super().__init__(learning_rate, last_epoch, verbose)

                def get_lr(self):
                    i = self.last_epoch // self.step_size
                    return self.base_lr * (self.gamma**i)

    """

    def __init__(self, learning_rate=0.1, last_epoch=-1, verbose=False):
        if not isinstance(learning_rate, (float, int)):
            raise TypeError(
                "The type of learning rate must be float, but received {}".format(
                    type(learning_rate)
                )
            )
        self.base_lr = float(learning_rate)
        self.last_lr = float(learning_rate)
        self.last_epoch = last_epoch
        self.verbose = verbose
        self._var_name = None

        self.step()

    def __call__(self):
        """
        Return the latest computed learning rate of the current epoch.
        """
        return self.last_lr

    def step(self, epoch=None):
        """

        ``step`` should be called after ``optimizer.step`` . It will update the learning rate in the optimizer according to the current ``epoch`` .
        The new learning rate will take effect on the next ``optimizer.step`` .

        Args:
            epoch (int, optional): Specify the current epoch. Default: None. If not given, the epoch auto-increments from ``last_epoch=-1`` .

        Returns:
            None
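
        A minimal usage sketch (``linear``, ``sgd`` and ``scheduler`` are illustrative placeholders, not fixed names):

        .. code-block:: python

            import paddle

            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=5)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())

            for epoch in range(3):
                x = paddle.uniform([10, 10])
                loss = paddle.mean(linear(x))
                loss.backward()
                sgd.step()
                sgd.clear_gradients()
                scheduler.step()    # update the learning rate once per epoch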

        """
        if epoch is None:
            self.last_epoch += 1
            self.last_lr = self.get_lr()
        else:
            self.last_epoch = epoch
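            # When an explicit epoch is given, prefer a closed-form formula if the subclass provides one.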
            if hasattr(self, "_get_closed_form_lr"):
                self.last_lr = self._get_closed_form_lr()
            else:
                self.last_lr = self.get_lr()

        if self.verbose:
            print(
                'Epoch {}: {} set learning rate to {}.'.format(
                    self.last_epoch, self.__class__.__name__, self.last_lr
                )
            )

    def state_dict(self):
        """

        Returns the state of the scheduler as a :class:`dict`.

        It is a subset of ``self.__dict__`` .
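
        A minimal sketch of saving and restoring scheduler state (the checkpoint handling here is illustrative only):

        .. code-block:: python

            import paddle

            scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=5)
            scheduler.step()

            state = scheduler.state_dict()   # e.g. {'last_epoch': 1, 'last_lr': 0.5}

            # later, restore onto a freshly constructed scheduler of the same type
            new_scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=5)
            new_scheduler.set_state_dict(state)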
        """
        self.state_keys()
        state_dict = {}
        for key in self.keys:
            if key not in self.__dict__:
                continue
            value = self.__dict__[key]
            if isinstance(value, Tensor):
                assert value.shape == [
                    1
                ], "shape of Tensor in state_dict must be [1] {}".format(
                    value.shape
                )
                value = value.numpy()[0]
            state_dict[key] = value

        return state_dict

    # For subclasses that override LRScheduler, "last_epoch" and "last_lr" will be saved by default.
    # (Note): you can change it for your subclass.
    def state_keys(self):
        """

        For subclasses that override ``LRScheduler`` (the base class), "last_epoch" and "last_lr" will be saved by default via ``self.keys = ['last_epoch', 'last_lr']`` .

        ``last_epoch`` is the current epoch number, and ``last_lr`` is the current learning rate.

        If you want to change the default behavior, you should provide a custom implementation of ``state_keys()`` to redefine ``self.keys`` .

        """
        self.keys = ['last_epoch', 'last_lr']

    def set_state_dict(self, state_dict):
        """

        Loads the scheduler's state.
        """
        self.state_keys()
        for key in self.keys:
            if key in state_dict:
                self.__dict__[key] = state_dict[key]
            else:
                raise RuntimeError(
                    "Please check whether state_dict is correct for optimizer. Can't find [ {} ] in state_dict".format(
                        key
                    )
                )
        if len(state_dict) > len(self.keys):
            warnings.warn(
                "There are some unused values in state_dict. Maybe the optimizer have different 'LearningRateDecay' when invoking state_dict and set_dict"
            )

    # alias for set_state_dict
    set_dict = set_state_dict

    def get_lr(self):
        """

        Subclasses that override ``LRScheduler`` (the base class) should provide a custom implementation of ``get_lr()`` .

        Otherwise, a ``NotImplementedError`` exception will be raised.
        """
        # calculate by python float
        raise NotImplementedError


class NoamDecay(LRScheduler):
    r"""

    Applies Noam Decay to the initial learning rate.

    The algorithm can be described as follows.

    .. math::

        new\_learning\_rate = learning\_rate * d_{model}^{-0.5} * min(epoch^{-0.5}, epoch * warmup\_steps^{-1.5})

    Please refer to `Attention Is All You Need <https://arxiv.org/pdf/1706.03762.pdf>`_


    Args:
        d_model(int): The dimensionality of the input and output feature vectors of the model. It is a python int number.
        warmup_steps(int): The number of warmup steps, a hyperparameter. It is a python int number.
        learning_rate (float): The initial learning rate. It is a python float number. Default: 1.0.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``NoamDecay`` instance to schedule learning rate.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.NoamDecay(d_model=0.01, warmup_steps=100, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.NoamDecay(d_model=0.01, warmup_steps=100, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(
        self,
        d_model,
        warmup_steps,
        learning_rate=1.0,
        last_epoch=-1,
        verbose=False,
    ):
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
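        # Noam schedule: min() picks the linear-warmup term (warmup_steps**-1.5 * epoch)
        # while last_epoch < warmup_steps, and the inverse-square-root decay term (epoch**-0.5) afterwards.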
        if self.last_epoch == 0:
            a = 1
        else:
            a = self.last_epoch**-0.5
        b = self.warmup_steps**-1.5 * self.last_epoch
        return self.base_lr * (self.d_model**-0.5) * min(a, b)


class PiecewiseDecay(LRScheduler):
    """

    Piecewise learning rate scheduler.

    The algorithm can be described as the code below:

    .. code-block:: text

        boundaries = [100, 200]
        values = [1.0, 0.5, 0.1]
        if epoch < 100:
            learning_rate = 1.0
        elif 100 <= epoch < 200:
            learning_rate = 0.5
        else:
            learning_rate = 0.1

    Args:
        boundaries(list|tuple): A list/tuple of step numbers. The type of element in the list is python int.
        values(list|tuple): A list/tuple of learning rate values that will be picked during different epoch boundaries.
            The type of element in the list is python float.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``PiecewiseDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[3, 6, 9], values=[0.1, 0.2, 0.3, 0.4], verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[3, 6, 9], values=[0.1, 0.2, 0.3, 0.4], verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self, boundaries, values, last_epoch=-1, verbose=False):
        self.boundaries = boundaries
        self.values = values
        super().__init__(last_epoch=last_epoch, verbose=verbose)

    def get_lr(self):
        for i in range(len(self.boundaries)):
            if self.last_epoch < self.boundaries[i]:
                return self.values[i]
        return self.values[len(self.values) - 1]


class NaturalExpDecay(LRScheduler):
    r"""

    Applies natural exponential decay to the initial learning rate.

    The algorithm can be described as follows:

    .. math::

        new\_learning\_rate = learning\_rate * e^{- gamma * epoch}

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float, optional): A ratio to update the learning rate; it should be greater than 0.0 to make the learning rate decay. Default: 0.1.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``NaturalExpDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.NaturalExpDecay(learning_rate=0.5, gamma=0.1, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.NaturalExpDecay(learning_rate=0.5, gamma=0.1, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        assert (
            gamma > 0.0
        ), " 'gamma' must be a positive number so that the learning rate will decay."
        self.gamma = gamma
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        return self.base_lr * math.exp(-1 * self.gamma * self.last_epoch)


class InverseTimeDecay(LRScheduler):
    r"""

    Applies inverse time decay to the initial learning rate.

    The algorithm can be described as follows:

    .. math::

        new\_learning\_rate = \frac{learning\_rate}{1 + gamma * epoch}

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float, optional): The ratio by which the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``InverseTimeDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.InverseTimeDecay(learning_rate=0.5, gamma=0.1, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.InverseTimeDecay(learning_rate=0.5, gamma=0.1, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        self.gamma = gamma
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        return self.base_lr / (1 + self.gamma * self.last_epoch)


class PolynomialDecay(LRScheduler):
    r"""

    Applies polynomial decay to the initial learning rate.

    The algorithm can be described as follows.

    If cycle is set to True, then:

    .. math::

        decay\_steps & = decay\_steps * math.ceil(\frac{epoch}{decay\_steps})

        new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\frac{epoch}{decay\_steps})^{power}+end\_lr

    If cycle is set to False, then:

    .. math::

        epoch & = min(epoch, decay\_steps)

        new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\frac{epoch}{decay\_steps})^{power}+end\_lr


    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        decay_steps(int): The decay step size. It determines the decay cycle. It must be a positive integer.
        end_lr(float, optional): The minimum final learning rate. Default: 0.0001.
        power(float, optional): Power of polynomial, should be greater than 0.0 to get learning rate decay. Default: 1.0.
        cycle(bool, optional): Whether the learning rate rises again. If True, then the learning rate will rise when it decreases
            to ``end_lr`` .  If False, the learning rate is monotonically decreasing. Default: False.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``PolynomialDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.5, decay_steps=20, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.5, decay_steps=20, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(
        self,
        learning_rate,
        decay_steps,
        end_lr=0.0001,
        power=1.0,
        cycle=False,
        last_epoch=-1,
        verbose=False,
    ):
        assert decay_steps > 0 and isinstance(
            decay_steps, int
        ), " 'decay_steps' must be a positive integer."
        self.decay_steps = decay_steps
        self.end_lr = end_lr
        assert (
            power > 0.0
        ), " 'power' must be greater than 0.0 so that the learning rate will decay."
        self.power = power
        self.cycle = cycle
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        tmp_epoch_num = self.last_epoch
        tmp_decay_steps = self.decay_steps
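        # With cycle=True, decay_steps is scaled up to the next multiple that covers last_epoch,
        # so the polynomial curve restarts and the learning rate rises again; otherwise the
        # epoch is clamped at decay_steps and the learning rate stays at end_lr.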
        if self.cycle:
            div_res = math.ceil(
                float(self.last_epoch) / float(self.decay_steps)
            )

            if self.last_epoch == 0:
                div_res = 1
            tmp_decay_steps = self.decay_steps * div_res
        else:
            tmp_epoch_num = min(self.last_epoch, self.decay_steps)

        return (self.base_lr - self.end_lr) * (
            (1 - float(tmp_epoch_num) / float(tmp_decay_steps)) ** self.power
        ) + self.end_lr


class LinearWarmup(LRScheduler):
    r"""

    Linear learning rate warm-up strategy. The learning rate is warmed up linearly before the normal learning rate scheduler takes over.
    For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_

    When epoch < warmup_steps, learning rate is updated as:

    .. math::

            lr = start\_lr + (end\_lr - start\_lr) * \frac{epoch}{warmup\_steps}

    where start_lr is the initial learning rate, and end_lr is the final learning rate;

    When epoch >= warmup_steps, learning rate is updated as:

    .. math::

            lr = learning_rate

    where ``learning_rate`` is a float or any subclass of ``LRScheduler`` .

    Args:
        learning_rate (float|LRScheduler): The learning rate after warm-up. It is a python float number or any subclass of ``LRScheduler`` .
        warmup_steps (int): Total steps of warm up. It must be a positive integer.
        start_lr (float): Initial learning rate of warm up.
        end_lr (float): Final learning rate of warm up.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``LinearWarmup`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.LinearWarmup(
                    learning_rate=0.5, warmup_steps=20, start_lr=0, end_lr=0.5, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.LinearWarmup(
                    learning_rate=0.5, warmup_steps=20, start_lr=0, end_lr=0.5, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(
        self,
        learning_rate,
        warmup_steps,
        start_lr,
        end_lr,
        last_epoch=-1,
        verbose=False,
    ):
        type_check = (
            isinstance(learning_rate, float)
            or isinstance(learning_rate, int)
            or isinstance(learning_rate, LRScheduler)
        )
        if not type_check:
            raise TypeError(
                "the type of learning_rate should be [int, float or LRScheduler], the current type is {}".format(
                    learning_rate
                )
            )
        self.learning_rate = learning_rate
        assert warmup_steps > 0 and isinstance(
            warmup_steps, int
        ), " 'warmup_steps' must be a positive integer."
        self.warmup_steps = warmup_steps
        self.start_lr = start_lr
        self.end_lr = end_lr
        assert (
            end_lr > start_lr
        ), "end_lr {} must be greater than start_lr {}".format(end_lr, start_lr)
        super().__init__(start_lr, last_epoch, verbose)

    def state_dict(self):
        """
        Returns the state of the LinearWarmup scheduler as a :class:`dict`.

        It is a subset of ``self.__dict__`` .
        """
        state_dict = super().state_dict()
        if isinstance(self.learning_rate, LRScheduler):
            state_dict["LinearWarmup_LR"] = self.learning_rate.state_dict()
        return state_dict

    def set_state_dict(self, state_dict):
        """
        Loads state_dict for LinearWarmup scheduler.
        """
        super().set_state_dict(state_dict)
        if isinstance(self.learning_rate, LRScheduler):
            self.learning_rate.set_state_dict(state_dict["LinearWarmup_LR"])

    def get_lr(self):
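        # During warm-up, interpolate linearly from start_lr to end_lr; afterwards defer to
        # the wrapped scheduler (stepped with the warm-up offset removed) or the constant rate.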
        if self.last_epoch < self.warmup_steps:
            return (self.end_lr - self.start_lr) * float(
                self.last_epoch
            ) / float(self.warmup_steps) + self.start_lr
        else:
            if isinstance(self.learning_rate, LRScheduler):
                self.learning_rate.step(self.last_epoch - self.warmup_steps)
                return self.learning_rate()

            return self.learning_rate


class ExponentialDecay(LRScheduler):
    r"""

    Update the learning rate by ``gamma`` each epoch.

    The algorithm can be described as follows.

    .. math::

        new\_learning\_rate = last\_learning\_rate * gamma

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float): The ratio by which the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
            It should be in the interval (0.0, 1.0).
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``ExponentialDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.ExponentialDecay(learning_rate=0.5, gamma=0.9, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.ExponentialDecay(learning_rate=0.5, gamma=0.9, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        assert (
            gamma > 0.0 and gamma < 1.0
        ), " 'gamma' must be in interval (0.0, 1.0) so that the learning rate will decay."
        self.gamma = gamma
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        return self.base_lr * (self.gamma**self.last_epoch)


class MultiStepDecay(LRScheduler):
    """
    Update the learning rate by ``gamma`` once ``epoch`` reaches one of the milestones.

    The algorithm can be described as the code below.

    .. code-block:: text

        learning_rate = 0.5
        milestones = [30, 50]
        gamma = 0.1
        if epoch < 30:
            learning_rate = 0.5
        elif epoch < 50:
            learning_rate = 0.05
        else:
            learning_rate = 0.005

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        milestones (tuple|list): List or tuple of each boundaries. Must be increasing.
        gamma (float, optional): The ratio by which the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .


    Returns:
        ``MultiStepDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(
        self, learning_rate, milestones, gamma=0.1, last_epoch=-1, verbose=False
    ):
        if not isinstance(milestones, (tuple, list)):
            raise TypeError(
                "The type of 'milestones' in 'MultiStepDecay' must be 'tuple, list', but received %s."
                % type(milestones)
            )

        if not all(
            [
                milestones[i] < milestones[i + 1]
                for i in range(len(milestones) - 1)
            ]
        ):
            raise ValueError('The elements of milestones must be increasing')
        if gamma >= 1.0:
            raise ValueError('gamma should be < 1.0.')

        self.milestones = milestones
        self.gamma = gamma
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        for i in range(len(self.milestones)):
            if self.last_epoch < self.milestones[i]:
                return self.base_lr * (self.gamma**i)
        return self.base_lr * (self.gamma ** len(self.milestones))


class StepDecay(LRScheduler):
    """
    Update the learning rate of ``optimizer`` by ``gamma`` every ``step_size`` number of epochs.

    The algorithm can be described as the code below.

    .. code-block:: text

        learning_rate = 0.5
        step_size = 30
        gamma = 0.1

        learning_rate = 0.5     if epoch < 30
        learning_rate = 0.05    if 30 <= epoch < 60
        learning_rate = 0.005   if 60 <= epoch < 90
        ...

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        step_size (int): The interval to update. It must be a positive integer.
        gamma (float, optional): The ratio by which the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``StepDecay`` instance to schedule learning rate.


    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=5, gamma=0.8, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=5, gamma=0.8, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(
        self, learning_rate, step_size, gamma=0.1, last_epoch=-1, verbose=False
    ):
        if not isinstance(step_size, int):
            raise TypeError(
                "The type of 'step_size' must be 'int', but received %s."
                % type(step_size)
            )
        if gamma >= 1.0:
            raise ValueError('gamma should be < 1.0.')

        assert step_size > 0 and isinstance(
            step_size, int
        ), " 'step_size' must be a positive integer."
        self.step_size = step_size
        self.gamma = gamma
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        i = self.last_epoch // self.step_size
        return self.base_lr * (self.gamma**i)


1134
class LambdaDecay(LRScheduler):
1135 1136 1137
    """
    Sets the learning rate of ``optimizer`` by function ``lr_lambda`` . ``lr_lambda`` is a function which receives ``epoch`` .

    The algorithm can be described as the code below.

    .. code-block:: text

        learning_rate = 0.5        # init learning_rate
        lr_lambda = lambda epoch: 0.95 ** epoch

        learning_rate = 0.5        # epoch 0, 0.5*0.95**0
        learning_rate = 0.475      # epoch 1, 0.5*0.95**1
        learning_rate = 0.45125    # epoch 2, 0.5*0.95**2

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        lr_lambda (function): A function which computes a factor from ``epoch`` , and then multiplies the initial learning rate by this factor.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``LambdaDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.LambdaDecay(learning_rate=0.5, lr_lambda=lambda x:0.95**x, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.LambdaDecay(learning_rate=0.5, lr_lambda=lambda x:0.95**x, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(self, learning_rate, lr_lambda, last_epoch=-1, verbose=False):
        if not callable(lr_lambda):
            raise TypeError(
                "The type of 'lr_lambda' in 'LambdaDecay' must be 'function', but received %s."
                % type(lr_lambda)
            )

        self.lr_lambda = lr_lambda
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        return self.base_lr * self.lr_lambda(self.last_epoch)


class ReduceOnPlateau(LRScheduler):
    """
    Reduce the learning rate when ``metrics`` has stopped descending. Models often benefit from reducing the learning rate
    by 2 to 10 times once model performance no longer improves.

    The ``metrics`` is the one that has been passed into ``step`` ; it must be a 1-D Tensor with shape [1]. When ``metrics``
    stops descending for a ``patience`` number of epochs, the learning rate will be reduced to ``learning_rate * factor`` .
    (Specially, ``mode`` can also be set to ``'max'`` ; in this case, when ``metrics`` stops ascending for a ``patience``
    number of epochs, the learning rate will be reduced.)

    In addition, after each reduction, it will wait a ``cooldown`` number of epochs before resuming the above operation.

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        mode (str, optional): ``'min'`` or ``'max'`` can be selected. Normally, it is ``'min'`` , which means that the
            learning rate will be reduced when ``loss`` stops descending. Specially, if it's set to ``'max'`` ,  the learning
            rate will be reduced when ``loss`` stops ascending. Default: ``'min'`` .
        factor (float, optional): The ratio by which the learning rate will be reduced. ``new_lr = origin_lr * factor`` .
            It should be less than 1.0. Default: 0.1.
        patience (int, optional): When ``loss`` doesn't improve for this number of epochs, the learning rate will be reduced.
            Default: 10.
        threshold (float, optional): ``threshold`` and ``threshold_mode`` will determine the minimum change of ``loss`` .
            This means that tiny changes of ``loss`` will be ignored. Default: 1e-4.
        threshold_mode (str, optional): ``'rel'`` or ``'abs'`` can be selected. In ``'rel'`` mode, the minimum change of ``loss``
            is ``last_loss * threshold`` , where ``last_loss`` is ``loss`` in last epoch. In ``'abs'`` mode, the minimum
            change of ``loss`` is ``threshold`` . Default: ``'rel'`` .
        cooldown (int, optional): The number of epochs to wait before resuming normal operation. Default: 0.
        min_lr (float, optional): The lower bound of the learning rate after reduction. Default: 0.
        epsilon (float, optional): Minimal decay applied to lr. If the difference between new and old lr is smaller than epsilon,
1252
            the update is ignored. Default: 1e-8.
1253 1254
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``.

    Returns:
        ``ReduceOnPlateau`` instance to schedule learning rate.


    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.ReduceOnPlateau(learning_rate=1.0, factor=0.5, patience=5, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step(loss)    # If you update learning rate each step
              # scheduler.step(loss)        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.ReduceOnPlateau(learning_rate=1.0, factor=0.5, patience=5, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step(out[0])    # If you update learning rate each step
              # scheduler.step(out[0])        # If you update learning rate each epoch

    """

    def __init__(
        self,
        learning_rate,
        mode='min',
        factor=0.1,
        patience=10,
        threshold=1e-4,
        threshold_mode='rel',
        cooldown=0,
        min_lr=0,
        epsilon=1e-8,
        verbose=False,
    ):
        mode = mode.lower()
        if mode not in ['min', 'max']:
            raise ValueError('mode: ' + mode + ' is unknown!')
        self.mode = mode

        if factor >= 1.0:
            raise ValueError(
                'new_lr = origin_lr * factor and factor should be < 1.0.'
            )
        self.factor = factor

        threshold_mode = threshold_mode.lower()
        if threshold_mode not in ['rel', 'abs']:
            raise ValueError(
                'threshold mode: ' + threshold_mode + ' is unknown!'
            )
        self.threshold_mode = threshold_mode
        if not isinstance(learning_rate, (float, int)):
            raise TypeError(
                "The type of 'learning_rate' in 'ReduceOnPlateau' must be 'float', but received %s."
                % type(learning_rate)
            )

        self.patience = patience
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.cooldown = cooldown
        self.min_lr = min_lr
        self.epsilon = epsilon

        self.cooldown_counter = 0
        self.best = None
        self.num_bad_epochs = 0

        # Cannot call the parent __init__ here, so set the relevant attributes directly.
        self.base_lr = float(learning_rate)
        self.last_lr = float(learning_rate)
        self.last_epoch = 0
        self.verbose = verbose
        self._var_name = None

    # "cooldown_counter / best / num_bad_epochs / last_epoch / last_lr" will be stored.
    def state_keys(self):
        self.keys = [
            'cooldown_counter',
            'best',
            'num_bad_epochs',
            'last_epoch',
            'last_lr',
        ]

    def step(self, metrics, epoch=None):
        """
        step should be called after `optimizer.step()` . It will update the learning rate in the optimizer according to ``metrics`` .
        The new learning rate will take effect on next epoch.

        Args:
            metrics (Tensor|numpy.ndarray|float): Which will be monitored to determine whether the learning rate will be reduced.
                If it stops descending for a ``patience`` number of epochs, the learning rate will be reduced. If it's 'Tensor' or
                'numpy.ndarray', its shape must be [1].
            epoch (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1.

        Returns:
            None

        Examples:
            Please refer to the example of current LRScheduler.
        """
        if epoch is None:
            self.last_epoch = self.last_epoch + 1
        else:
            self.last_epoch = epoch

        # loss must be float, numpy.ndarray or 1-D Tensor with shape [1]
        if isinstance(metrics, (core.eager.Tensor, numpy.ndarray)):
            assert len(metrics.shape) == 1 and metrics.shape[0] == 1, (
                "the metrics.shape "
                "should be (1L,), but the current metrics.shape is {}. Maybe that "
                "you should call paddle.mean to process it first.".format(
                    metrics.shape
                )
            )
        elif not isinstance(
            metrics, (int, float, numpy.float32, numpy.float64)
        ):
            raise TypeError(
                "metrics must be 'int', 'float', 'np.float', 'numpy.ndarray' or 'paddle.Tensor', but receive {}".format(
                    type(metrics)
                )
            )

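        # While in cooldown, just consume the cooldown counter and skip plateau tracking.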
        if self.cooldown_counter > 0:
            self.cooldown_counter -= 1
        else:
            if self.best is None or self._is_better(metrics, self.best):
                self.best = metrics
                self.num_bad_epochs = 0
            else:
                self.num_bad_epochs += 1

            if self.num_bad_epochs > self.patience:
                self.cooldown_counter = self.cooldown
                self.num_bad_epochs = 0
                new_lr = max(self.last_lr * self.factor, self.min_lr)
                if self.last_lr - new_lr > self.epsilon:
                    self.last_lr = new_lr
                    if self.verbose:
                        print(
                            'Epoch {}: {} set learning rate to {}.'.format(
                                self.last_epoch,
                                self.__class__.__name__,
                                self.last_lr,
                            )
                        )

    def _is_better(self, current, best):
        if self.mode == 'min' and self.threshold_mode == 'rel':
            return current < best - best * self.threshold

        elif self.mode == 'min' and self.threshold_mode == 'abs':
            return current < best - self.threshold

        elif self.mode == 'max' and self.threshold_mode == 'rel':
            return current > best + best * self.threshold

        else:
            return current > best + self.threshold


class CosineAnnealingDecay(LRScheduler):
    r"""

    Set the learning rate using a cosine annealing schedule, where :math:`\eta_{max}` is set to
    the initial learning_rate. :math:`T_{cur}` is the number of epochs since the last restart in
    SGDR.

    The algorithm can be described as follows.

    .. math::

        \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
        + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
        & T_{cur} \neq (2k+1)T_{max};

        \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
        \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
        & T_{cur} = (2k+1)T_{max}.

    It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts <https://arxiv.org/abs/1608.03983>`_.
    Note that this only implements the cosine annealing part of SGDR, and not the restarts.
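
    The closed form of the schedule can be sketched as below. This is a minimal, standalone illustration with example
    values for the maximum learning rate, the minimum learning rate and ``T_max``, not the scheduler itself:

    .. code-block:: python

        import math

        eta_max, eta_min, T_max = 0.5, 0.0, 10      # illustrative values
        for epoch in range(5):
            lr = eta_min + (eta_max - eta_min) * (1 + math.cos(math.pi * epoch / T_max)) / 2
            print(epoch, round(lr, 4))              # starts at 0.5 and decays along a cosine curve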

    Args:
        learning_rate (float): The initial learning rate, that is :math:`\eta_{max}` . It can be set to python float or int number.
        T_max (int): Maximum number of iterations. It is half of the decay cycle of learning rate. It must be a positive integer.
        eta_min (float|int, optional): Minimum learning rate, that is :math:`\eta_{min}` . Default: 0.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``CosineAnnealingDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.5, T_max=10, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.5, T_max=10, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(
        self, learning_rate, T_max, eta_min=0, last_epoch=-1, verbose=False
    ):
        if not isinstance(T_max, int):
            raise TypeError(
                "The type of 'T_max' in 'CosineAnnealingDecay' must be 'int', but received %s."
                % type(T_max)
            )
        if not isinstance(eta_min, (float, int)):
            raise TypeError(
                "The type of 'eta_min' in 'CosineAnnealingDecay' must be 'float, int', but received %s."
                % type(eta_min)
            )
        assert T_max > 0 and isinstance(
            T_max, int
        ), " 'T_max' must be a positive integer."
        self.T_max = T_max
        self.eta_min = float(eta_min)
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
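        # Recurrence form: compute the new learning rate from last_lr rather than evaluating the closed form each step.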
        if self.last_epoch == 0:
            return self.base_lr
        elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
            return (
                self.last_lr
                + (self.base_lr - self.eta_min)
                * (1 - math.cos(math.pi / self.T_max))
                / 2
            )

        return (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / (
            1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)
        ) * (self.last_lr - self.eta_min) + self.eta_min

    def _get_closed_form_lr(self):
        return (
            self.eta_min
            + (self.base_lr - self.eta_min)
            * (1 + math.cos(math.pi * self.last_epoch / self.T_max))
            / 2
        )


class MultiplicativeDecay(LRScheduler):
    """
    Multiply the learning rate of ``optimizer`` by the factor given in function ``lr_lambda`` .

    The algorithm can be described as the code below.

    .. code-block:: text

        learning_rate = 0.5        # init learning_rate
        lr_lambda = lambda epoch: 0.95

        learning_rate = 0.5        # epoch 0,
        learning_rate = 0.475      # epoch 1, 0.5*0.95
        learning_rate = 0.45125    # epoch 2, 0.475*0.95

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        lr_lambda (function): A function which computes a factor by ``epoch`` , and then multiply the last learning rate by this factor.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``MultiplicativeDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.MultiplicativeDecay(learning_rate=0.5, lr_lambda=lambda x:0.95, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(self, learning_rate, lr_lambda, last_epoch=-1, verbose=False):
        if not callable(lr_lambda):
            raise TypeError(
                "The type of 'lr_lambda' in 'MultiplicativeDecay' must be 'function', but received %s."
                % type(lr_lambda)
            )

        self.lr_lambda = lr_lambda
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
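        # Accumulate the multiplicative factors from epoch 1 through the current epoch.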
        cur_lr = self.base_lr
        for epoch in range(1, self.last_epoch + 1):
            cur_lr = cur_lr * self.lr_lambda(epoch)
        return cur_lr


class OneCycleLR(LRScheduler):
    r"""

    Sets the learning rate according to the one cycle learning rate scheduler.
    The scheduler adjusts the learning rate from an initial learning rate to the maximum learning rate and then
    from that maximum learning rate to the minimum learning rate, which is much less than the initial learning rate.

    It has been proposed in `Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates <https://arxiv.org/abs/1708.07120>`_.

    Please note that the default behaviour of this scheduler follows the fastai implementation of one cycle,
    which claims that “unpublished work has shown even better results by using only two phases”.
    If you want the behaviour of this scheduler to be consistent with the paper, please set ``three_phase=True`` .

    Also note that you should update learning rate each step.
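
    The default two-phase shape can be sketched as below. This is a minimal, standalone illustration with example values
    for the maximum learning rate, ``divide_factor``, ``end_learning_rate``, ``total_steps`` and ``phase_pct``, not the
    scheduler itself:

    .. code-block:: python

        import math

        max_lr, divide_factor, end_lr = 1.0, 25.0, 0.0001    # illustrative values
        total_steps, phase_pct = 100, 0.3
        initial_lr = max_lr / divide_factor
        peak_step = int(phase_pct * total_steps) - 1

        def cos_anneal(start, end, pct):
            return end + (start - end) / 2.0 * (math.cos(math.pi * pct) + 1)

        for step in (0, peak_step, total_steps - 1):
            if step <= peak_step:                             # phase 1: increase towards max_lr
                lr = cos_anneal(initial_lr, max_lr, step / peak_step)
            else:                                             # phase 2: anneal down to end_lr
                lr = cos_anneal(max_lr, end_lr, (step - peak_step) / (total_steps - 1 - peak_step))
            print(step, round(lr, 4))                         # 0.04 -> 1.0 -> 0.0001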

    Args:
        max_learning_rate (float): The maximum learning rate. It is a python float number. Functionally, it defines the initial learning rate together with ``divide_factor`` .
        total_steps (int): Number of total training steps.
        divide_factor (float, optional): Initial learning rate will be determined by initial_learning_rate = max_learning_rate / divide_factor. Default: 25.
        end_learning_rate (float, optional): The minimum learning rate during training; it should be much less than the initial learning rate.
        phase_pct (float): The percentage of total steps used to increase the learning rate. Default: 0.3.
        anneal_strategy (str, optional): Strategy of adjusting learning rate. 'cos' for cosine annealing, 'linear' for linear annealing. Default: 'cos'.
        three_phase (bool, optional): Whether to use three phases. Default: ``False`` .

            If ``True``:

                1. The learning rate will first increase from initial learning rate to maximum learning rate.
                2. Then it will decrease to initial learning rate. Number of step in this phase is the same as the one in first phase.
                3. Finally, it will decrease to minimum learning rate which is much less than initial learning rate.

            If ``False``:

                1. The learning rate will increase to maximum learning rate.
                2. Then it will directly decrease to minimum learning rate.

        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``OneCycleLR`` instance to schedule learning rate.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(5):
                for batch_id in range(20):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()        # You should update learning rate each step

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(5):
                for batch_id in range(20):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # You should update learning rate each step

    """

    def __init__(
        self,
        max_learning_rate,
        total_steps,
        divide_factor=25.0,
        end_learning_rate=0.0001,
        phase_pct=0.3,
        anneal_strategy='cos',
        three_phase=False,
        last_epoch=-1,
        verbose=False,
    ):
        # Check type and value of max_learning_rate
        if not isinstance(max_learning_rate, (float, int)):
            raise TypeError(
                "'max_learning_rate' must be 'float' or 'int', but received {}".format(
                    type(max_learning_rate)
                )
            )
        if max_learning_rate < 0:
            raise ValueError("'max_learning_rate' must be a positive value.")

        # Check type and value of end_learning_rate
        if not isinstance(end_learning_rate, (float, int)):
            raise TypeError(
1756 1757 1758 1759
                "'end_learning_rate' must be 'float' or 'int', but received {}".format(
                    type(end_learning_rate)
                )
            )
        if end_learning_rate < 0:
            raise ValueError("'end_learning_rate' must be a positive value.")

        # Check type and value of total_steps
        if not isinstance(total_steps, int):
            raise TypeError(
                "'total_step' must be 'int', but received {}".format(
                    type(total_steps)
                )
            )
        if total_steps <= 0:
            raise ValueError("'total_step' must be a positive integer.")
        self.total_steps = total_steps

        # Check type and value of phase_pct
        if not isinstance(phase_pct, float):
            raise TypeError(
                "'phase_pct' must be 'float', but received {}".format(
                    type(phase_pct)
                )
            )
        if phase_pct < 0 or phase_pct > 1:
            raise ValueError(
                "'phase_pct' must be between 0 and 1, but received {}".format(
                    phase_pct
                )
            )

        # Check type and value of divide_factor
        if not isinstance(divide_factor, (float, int)):
            raise TypeError(
                "'divide_factor' must be 'float' or 'int', but received {}".format(
                    type(divide_factor)
                )
            )

        initial_lr = max_learning_rate / float(divide_factor)
        min_lr = float(end_learning_rate)

        if three_phase:
            if phase_pct >= 0.5:
                raise ValueError(
                    "When three_phase is True, 'phase_pct' must be less than 0.5"
                )
            # start step and end step of each phase.
            self._step_config = [
                0,
                phase_pct * self.total_steps - 1,
                2 * phase_pct * self.total_steps - 2,
                self.total_steps - 1,
                self.total_steps - 1,  # for the last step.
            ]
            # step size of each phase.
            self._steps_size = [
                self._step_config[1] - self._step_config[0],
                self._step_config[2] - self._step_config[1],
                self._step_config[3] - self._step_config[2],
                self._step_config[3]
                - self._step_config[2],  # for the last step.
            ]
            # start lr and end lr of each phase.
            self._lr_config = [
                initial_lr,
                max_learning_rate,
                initial_lr,
                min_lr,
1826 1827 1828
            ]
        else:
            self._step_config = [
1829 1830 1831 1832
                0,
                phase_pct * self.total_steps - 1,
                self.total_steps - 1,
                self.total_steps - 1,
1833 1834 1835 1836 1837 1838 1839 1840 1841 1842 1843 1844 1845 1846 1847
            ]
            self._steps_size = [
                self._step_config[1] - self._step_config[0],
                self._step_config[2] - self._step_config[1],
                self._step_config[2] - self._step_config[1],
            ]
            self._lr_config = [initial_lr, max_learning_rate, min_lr]

        # Check anneal_strategy
        if anneal_strategy == 'cos':
            self.anneal_func = self._cos_annealing
        elif anneal_strategy == 'linear':
            self.anneal_func = self._linear_annealing
        else:
            raise ValueError(
                "'anneal_strategy' must be one of 'cos' or 'linear', but received {}".format(
                    anneal_strategy
                )
            )
        super().__init__(initial_lr, last_epoch, verbose)

    def _cos_annealing(self, start_lr, end_lr, pct):
        cos_out = math.cos(math.pi * pct) + 1
        return end_lr + (start_lr - end_lr) / 2.0 * cos_out

    def _linear_annealing(self, start_lr, end_lr, pct):
        return (end_lr - start_lr) * pct + start_lr

    def get_lr(self):
        current_step = self.last_epoch

        if current_step > self.total_steps:
            raise ValueError(
                "Tried to step {} times. However the number of total steps is {}".format(
                    current_step, self.total_steps
                )
            )

        for (i, (end_step, step_size)) in enumerate(
            zip(self._step_config[1:], self._steps_size)
        ):
            # i == len(self._lr_config) - 2 catch the last step, otherwise it will return None.
            if current_step <= end_step or i == len(self._lr_config) - 2:
                # self._step_config[i] means start step of a phase.
                percentage = (current_step - self._step_config[i]) / step_size
                return self.anneal_func(
                    self._lr_config[i], self._lr_config[i + 1], percentage
                )


class CyclicLR(LRScheduler):
    r"""
    Set the learning rate according to the cyclic learning rate (CLR) scheduler.
    The scheduler regards the process of learning rate adjustment as one cycle after another.
    It cycles the learning rate between two boundaries with a constant frequency.
    The distance between the two boundaries can be scaled on a per-iteration or per-cycle basis.

    It has been proposed in `Cyclic Learning Rates for Training Neural Networks <https://arxiv.org/abs/1506.01186>`_.

    According to the paper, the cyclic learning rate schedule has three built-in scale methods:

    * "triangular": A basic triangular cycle without any amplitude scaling.
    * "triangular2": A basic triangular cycle that reduces the initial amplitude by half each cycle.
    * "exp_range": A cycle that scales the initial amplitude by a scaling function which is defined as :math:`gamma^{iterations}` .

    The initial amplitude is defined as max_learning_rate - base_learning_rate.
    Also note that you should update learning rate each step.
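
    One ``'triangular'`` cycle can be sketched as below. This is a minimal, standalone illustration with example
    boundaries and step sizes, not the scheduler itself:

    .. code-block:: python

        base_lr, max_lr = 0.5, 1.0                   # illustrative boundaries
        step_size_up, step_size_down = 15.0, 5.0
        cycle_size = step_size_up + step_size_down
        step_up_pct = step_size_up / cycle_size
        amplitude = max_lr - base_lr

        for it in (0, 7, 15, 17, 19):
            cycle = 1 + it // cycle_size
            pct = 1.0 + it / cycle_size - cycle      # position inside the current cycle
            if pct <= step_up_pct:                   # rising edge of the triangle
                scale = pct / step_up_pct
            else:                                    # falling edge of the triangle
                scale = (1 - pct) / (1 - step_up_pct)
            print(it, round(base_lr + amplitude * scale, 4))   # rises to max_lr at step 15, then falls back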

    Args:
        base_learning_rate (float): Initial learning rate, which is the lower boundary in the cycle. The paper recommends
            setting the base_learning_rate to 1/3 or 1/4 of max_learning_rate.
        max_learning_rate (float): Maximum learning rate in the cycle. It defines the cycle amplitude as above.
            Since there is some scaling operation during process of learning rate adjustment,
            max_learning_rate may not actually be reached.
        step_size_up (int): Number of training steps, which is used to increase learning rate in a cycle.
            The step size of one cycle will be defined by step_size_up + step_size_down. According to the paper, the step
            size should be at least 3 or 4 times the number of steps in one epoch.
        step_size_down (int, optional): Number of training steps, which is used to decrease learning rate in a cycle.
            If not specified, its value will be initialized to ``step_size_up`` . Default: None
        mode (str, optional): one of 'triangular', 'triangular2' or 'exp_range'.
            If scale_fn is specified, this argument will be ignored. Default: 'triangular'
        exp_gamma (float): Constant in 'exp_range' scaling function: exp_gamma**iterations. Used only when mode = 'exp_range'. Default: 1.0
        scale_fn (function, optional): A custom scaling function, which is used to replace the three built-in methods.
            It should only have one argument. For all x >= 0, 0 <= scale_fn(x) <= 1.
            If specified, then 'mode' will be ignored. Default: None
        scale_mode (str, optional): One of 'cycle' or 'iterations'. Defines whether scale_fn is evaluated on cycle
            number or cycle iterations (total iterations since start of training). Default: 'cycle'
        last_epoch (int, optional): The index of last epoch. Can be set to restart training.Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``CyclicLR`` instance to schedule learning rate.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5, max_learning_rate=1.0, step_size_up=15, step_size_down=5, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(5):
                for batch_id in range(20):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()        # You should update learning rate each step

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5,
                    max_learning_rate=1.0, step_size_up=15, step_size_down=5, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(5):
                for batch_id in range(20):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # You should update learning rate each step
    """

    def __init__(
        self,
        base_learning_rate,
        max_learning_rate,
        step_size_up,
        step_size_down=None,
        mode='triangular',
        exp_gamma=1.0,
        scale_fn=None,
        scale_mode='cycle',
        last_epoch=-1,
        verbose=False,
    ):
        # check type and value of max_learning_rate
        if not isinstance(max_learning_rate, (float, int)):
            raise TypeError(
                "'max_learning_rate' must be 'float' or 'int', but received {}".format(
                    type(max_learning_rate)
                )
            )
        if max_learning_rate < 0:
            raise ValueError(
                "'max_learning_rate' must be a positive value, but received {}".format(
                    max_learning_rate
                )
            )

        # check type and value of step_size_up
        if not isinstance(step_size_up, int):
            raise TypeError(
                "The type of 'step_size_up' must be int, but received {}".format(
                    type(step_size_up)
                )
            )
        if step_size_up <= 0:
            raise ValueError(
                "'step_size_up' must be a positive integer, but received {}".format(
                    step_size_up
                )
            )

        # check type and value of step_size_down
        if step_size_down is not None:
            if not isinstance(step_size_down, int):
                raise TypeError(
                    "The type of 'step_size_down' must be int, but received {}".format(
                        type(step_size_down)
                    )
                )
            if step_size_down <= 0:
                raise ValueError(
                    "'step_size_down' must be a positive integer, but received {}".format(
                        step_size_down
                    )
                )

        # check type of exp_gamma
        if not isinstance(exp_gamma, float):
            raise TypeError(
                "The type of 'exp_gamma' must be float, but received {}".format(
                    type(exp_gamma)
                )
            )

        step_size_up = float(step_size_up)
        step_size_down = (
            float(step_size_down)
            if step_size_down is not None
            else step_size_up
        )

        self.cycle_size = step_size_up + step_size_down
        self.step_up_pct = step_size_up / self.cycle_size
        self.max_lr = float(max_learning_rate)
        self.amplitude = self.max_lr - base_learning_rate

        if (
            mode not in ['triangular', 'triangular2', 'exp_range']
            and scale_fn is None
        ):
            raise ValueError(
                "'mode' is invalid and 'scale_fn' is not specified, make sure one of 'mode' or 'scale_fn' is valid"
            )
        if scale_mode not in ['cycle', 'iterations']:
            raise ValueError(
                "'scale_mode' must be one of 'cycle' or 'iterations'"
            )

        self.mode = mode
        self.gamma = exp_gamma  # only for exp_range mode

        if scale_fn is None:
            if self.mode == 'triangular':
                self.scale_fn = self._triangular_scale_fn
                self.scale_mode = 'cycle'
            elif self.mode == 'triangular2':
                self.scale_fn = self._triangular2_scale_fn
                self.scale_mode = 'cycle'
            elif self.mode == 'exp_range':
                self.scale_fn = self._exp_range_scale_fn
                self.scale_mode = 'iterations'
        else:
            self.scale_fn = scale_fn
            self.scale_mode = scale_mode
        super().__init__(base_learning_rate, last_epoch, verbose)

    def _triangular_scale_fn(self, x):
        return 1.0

    def _triangular2_scale_fn(self, x):
        return 1 / (2.0 ** (x - 1))

    def _exp_range_scale_fn(self, x):
        return self.gamma**x

    def get_lr(self):
        iterations = self.last_epoch

        cycle = 1 + iterations // self.cycle_size
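        # Fractional position within the current cycle, in [0, 1).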
        pct_per_cycle = 1.0 + iterations / self.cycle_size - cycle

        if pct_per_cycle <= self.step_up_pct:
            scale_factor = pct_per_cycle / self.step_up_pct
        else:
            scale_factor = (1 - pct_per_cycle) / (1 - self.step_up_pct)

        base_height = self.amplitude * scale_factor

        scale_x = cycle if self.scale_mode == 'cycle' else iterations
        lr = self.base_lr + base_height * self.scale_fn(scale_x)

        return lr