# lr.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import numpy
import warnings
from paddle import Tensor

# Public learning-rate scheduler classes exported by this module.
__all__ = [
    'LRScheduler', 'NoamDecay', 'PiecewiseDecay', 'NaturalExpDecay',
    'InverseTimeDecay', 'PolynomialDecay', 'LinearWarmup', 'ExponentialDecay',
    'MultiStepDecay', 'StepDecay', 'LambdaDecay', 'ReduceOnPlateau',
    'CosineAnnealingDecay'
]


class LRScheduler(object):
    """

    LRScheduler Base class. Defines the common interface of a learning rate scheduler.

    User can import it by ``from paddle.optimizer.lr import LRScheduler`` ,

    then overload it for your subclass and provide a custom implementation of ``get_lr()`` .

    Otherwise, a ``NotImplementedError`` exception will be thrown.

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        instance to schedule learning rate.

    Examples:
        Here is an example of a simple ``StepDecay`` implementation.

        .. code-block:: python

            import paddle
            from paddle.optimizer.lr import LRScheduler

            class StepDecay(LRScheduler):
                def __init__(self,
                            learning_rate,
                            step_size,
                            gamma=0.1,
                            last_epoch=-1,
                            verbose=False):
                    if not isinstance(step_size, int):
                        raise TypeError(
                            "The type of 'step_size' must be 'int', but received %s." %
                            type(step_size))
                    if gamma >= 1.0:
                        raise ValueError('gamma should be < 1.0.')

                    self.step_size = step_size
                    self.gamma = gamma
                    super(StepDecay, self).__init__(learning_rate, last_epoch, verbose)

                def get_lr(self):
                    i = self.last_epoch // self.step_size
                    return self.base_lr * (self.gamma**i)

    """

    def __init__(self, learning_rate=0.1, last_epoch=-1, verbose=False):
        if not isinstance(learning_rate, (float, int)):
            raise TypeError(
                "The type of learning rate must be float, but received {}".
                format(type(learning_rate)))
        self.base_lr = float(learning_rate)
        self.last_lr = float(learning_rate)
        self.last_epoch = last_epoch
        self.verbose = verbose
        self._var_name = None

        # Advance once so ``last_epoch`` becomes ``last_epoch + 1`` (0 with the
        # default of -1) and ``last_lr`` holds the rate from ``get_lr()``.
        self.step()

    def __call__(self):
        """

        Return the latest computed learning rate of the current epoch.

        """
        return self.last_lr

    def step(self, epoch=None):
        """

        ``step`` should be called after ``optimizer.step`` . It will update the learning rate in optimizer according to current ``epoch`` .
        The new learning rate will take effect on next ``optimizer.step`` .

        Args:
            epoch (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1.

        Returns:
            None

        """
        if epoch is None:
            self.last_epoch += 1
            self.last_lr = self.get_lr()
        else:
            self.last_epoch = epoch
            # Subclasses may provide a closed-form formula that does not depend
            # on having stepped through every intermediate epoch.
            if hasattr(self, "_get_closed_form_lr"):
                self.last_lr = self._get_closed_form_lr()
            else:
                self.last_lr = self.get_lr()

        if self.verbose:
            print('Epoch {}: {} set learning rate to {}.'.format(
                self.last_epoch, self.__class__.__name__, self.last_lr))

    def state_dict(self):
        """

        Returns the state of the scheduler as a :class:`dict`.

        It is a subset of ``self.__dict__`` , restricted to the names listed in
        ``self.keys`` by ``state_keys()`` . Single-element ``Tensor`` values are
        converted to their scalar value.
        """
        self.state_keys()
        state_dict = {}
        for key in self.keys:
            if key not in self.__dict__:
                continue
            value = self.__dict__[key]
            if isinstance(value, Tensor):
                assert value.shape == [
                    1
                ], "shape of Tensor in state_dict must be [1] {}".format(
                    value.shape)
                value = value.numpy()[0]
            state_dict[key] = value

        return state_dict

    # For subclasses that overload LRScheduler, "last_epoch, last_lr" are saved by default.
    # (Note): you can change it for your subclass.
    def state_keys(self):
        """

        For subclasses that overload ``LRScheduler`` (Base Class). By default, "last_epoch, last_lr" will be saved by ``self.keys = ['last_epoch', 'last_lr']`` .

        ``last_epoch`` is the current epoch num, and ``last_lr`` is the current learning rate.

        If you want to change the default behavior, you should have a custom implementation of ``state_keys()`` to redefine ``self.keys`` .

        """
        self.keys = ['last_epoch', 'last_lr']

    def set_state_dict(self, state_dict):
        """

        Loads the scheduler's state.

        Args:
            state_dict (dict): A dict such as one produced by ``state_dict()`` .
                Every name in ``self.keys`` must be present.

        Raises:
            RuntimeError: If a name from ``self.keys`` is missing in ``state_dict`` .
        """
        self.state_keys()
        for key in self.keys:
            if key in state_dict:
                self.__dict__[key] = state_dict[key]
            else:
                raise RuntimeError(
                    "Please check whether state_dict is correct for optimizer. Can't find [ {} ] in state_dict".
                    format(key))
        # Extra entries are ignored, but warn because they usually indicate a
        # mismatch between the saving and loading schedulers.
        if len(state_dict) > len(self.keys):
            warnings.warn(
                "There are some unused values in state_dict. Maybe the optimizer have different 'LearningRateDecay' when invoking state_dict and set_dict"
            )

    # alias for set_state_dict
    set_dict = set_state_dict

    def get_lr(self):
        """

        For subclasses that overload ``LRScheduler`` (Base Class), the user should provide a custom implementation of ``get_lr()`` .

        Otherwise, a ``NotImplementedError`` exception will be thrown.
        """
        # calculate by python float
        raise NotImplementedError


class NoamDecay(LRScheduler):
    r"""

    Applies Noam Decay to the initial learning rate.

    The algorithm can be described as following.

    .. math::

        new\_learning\_rate = learning\_rate * d_{model}^{-0.5} * min(epoch^{-0.5}, epoch * warmup\_steps^{-1.5})

    Please refer to `attention is all you need <https://arxiv.org/pdf/1706.03762.pdf>`_


    Args:
        d_model (int|float): The dimensionality of input and output feature vector of model. It is a python number.
        warmup_steps (int): The number of warmup steps. A hyperparameter. It is a python int number.
        learning_rate (float): The initial learning rate. It is a python float number. Default: 1.0.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``NoamDecay`` instance to schedule learning rate.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.NoamDecay(d_model=0.01, warmup_steps=100, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(2):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                scheduler.step()

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.NoamDecay(d_model=0.01, warmup_steps=100, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(2):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                scheduler.step()

    """

    def __init__(self,
                 d_model,
                 warmup_steps,
                 learning_rate=1.0,
                 last_epoch=-1,
                 verbose=False):
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        super(NoamDecay, self).__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        # ``0 ** -0.5`` would raise ZeroDivisionError, so epoch 0 uses a = 1.
        # At epoch 0 the warmup term ``b`` is 0, so the returned rate is 0.
        if self.last_epoch == 0:
            a = 1
        else:
            a = self.last_epoch**-0.5
        b = self.warmup_steps**-1.5 * self.last_epoch
        return self.base_lr * (self.d_model**-0.5) * min(a, b)


class PiecewiseDecay(LRScheduler):
    """

    Piecewise learning rate scheduler.

    The algorithm can be described as the code below:

    .. code-block:: text

        boundaries = [100, 200]
        values = [1.0, 0.5, 0.1]
        if epoch < 100:
            learning_rate = 1.0
        elif 100 <= epoch < 200:
            learning_rate = 0.5
        else:
            learning_rate = 0.1

    Args:
        boundaries(list): A list of steps numbers. The type of element in the list is python int.
        values(list): A list of learning rate values that will be picked during different epoch boundaries.
            The type of element in the list is python float.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``PiecewiseDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[3, 6, 9], values=[0.1, 0.2, 0.3, 0.4], verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(2):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                scheduler.step()

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[3, 6, 9], values=[0.1, 0.2, 0.3, 0.4], verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(2):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                scheduler.step()
    """

    def __init__(self, boundaries, values, last_epoch=-1, verbose=False):
        self.boundaries = boundaries
        self.values = values
        # The base learning rate is unused here: ``get_lr`` only reads ``values``.
        super(PiecewiseDecay, self).__init__(
            last_epoch=last_epoch, verbose=verbose)

    def get_lr(self):
        # Pick the value of the first interval whose upper boundary is still
        # ahead of the current epoch; past every boundary, use the last value.
        for i, boundary in enumerate(self.boundaries):
            if self.last_epoch < boundary:
                return self.values[i]
        return self.values[-1]


class NaturalExpDecay(LRScheduler):
    r"""

    Applies natural exponential decay to the initial learning rate.

    The algorithm can be described as following:

    .. math::

        new\_learning\_rate = learning\_rate * e^{- gamma * epoch}

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float): A ratio to update the learning rate. The learning rate at ``epoch`` is
            ``learning_rate * exp(-gamma * epoch)`` .
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``NaturalExpDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.NaturalExpDecay(learning_rate=0.5, gamma=0.1, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(2):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                scheduler.step()

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.NaturalExpDecay(learning_rate=0.5, gamma=0.1, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(2):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                scheduler.step()
    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        self.gamma = gamma
        super(NaturalExpDecay, self).__init__(learning_rate, last_epoch,
                                              verbose)

    def get_lr(self):
        return self.base_lr * math.exp(-self.gamma * self.last_epoch)


class InverseTimeDecay(LRScheduler):
    r"""

    Applies inverse time decay to the initial learning rate.

    The algorithm can be described as following:

    .. math::

        new\_learning\_rate = \frac{learning\_rate}{1 + gamma * epoch}

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float): The ratio that controls how fast the learning rate decays. The learning rate
            at ``epoch`` is ``learning_rate / (1 + gamma * epoch)`` .
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``InverseTimeDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.InverseTimeDecay(learning_rate=0.5, gamma=0.1, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(2):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                scheduler.step()

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.InverseTimeDecay(learning_rate=0.5, gamma=0.1, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(2):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                scheduler.step()

    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        self.gamma = gamma
        super(InverseTimeDecay, self).__init__(learning_rate, last_epoch,
                                               verbose)

    def get_lr(self):
        return self.base_lr / (1 + self.gamma * self.last_epoch)


class PolynomialDecay(LRScheduler):
    r"""

    Applies polynomial decay to the initial learning rate.

    The algorithm can be described as following.

    If cycle is set to True, then:

    .. math::

        decay\_steps & = decay\_steps * ceil(\frac{epoch}{decay\_steps})

        new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\frac{epoch}{decay\_steps})^{power}+end\_lr

    If cycle is set to False, then:

    .. math::

        epoch & = min(epoch, decay\_steps)

        new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\frac{epoch}{decay\_steps})^{power}+end\_lr


    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        decay_steps(int): The decay step size. It determines the decay cycle.
        end_lr(float, optional): The minimum final learning rate. Default: 0.0001.
        power(float, optional): Power of polynomial. Default: 1.0.
        cycle(bool, optional): Whether the learning rate rises again. If True, then the learning rate will rise when it decreases
            to ``end_lr`` .  If False, the learning rate is monotonically decreasing. Default: False.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``PolynomialDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.5, decay_steps=20, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(2):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                scheduler.step()

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.5, decay_steps=20, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(2):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                scheduler.step()
    """

    def __init__(self,
                 learning_rate,
                 decay_steps,
                 end_lr=0.0001,
                 power=1.0,
                 cycle=False,
                 last_epoch=-1,
                 verbose=False):
        self.decay_steps = decay_steps
        self.end_lr = end_lr
        self.power = power
        self.cycle = cycle
        super(PolynomialDecay, self).__init__(learning_rate, last_epoch,
                                              verbose)

    def get_lr(self):
        tmp_epoch_num = self.last_epoch
        tmp_decay_steps = self.decay_steps
        if self.cycle:
            # Stretch the decay horizon to the smallest multiple of
            # ``decay_steps`` that covers the current epoch.
            div_res = math.ceil(
                float(self.last_epoch) / float(self.decay_steps))

            # At epoch 0 the ceiling is 0, which would give a zero horizon.
            if self.last_epoch == 0:
                div_res = 1
            tmp_decay_steps = self.decay_steps * div_res
        else:
            # Without cycling, hold at ``end_lr`` once ``decay_steps`` is reached.
            tmp_epoch_num = min(self.last_epoch, self.decay_steps)

        return (self.base_lr - self.end_lr) * (
            (1 - float(tmp_epoch_num) / float(tmp_decay_steps)
             )**self.power) + self.end_lr


class LinearWarmup(LRScheduler):
    r"""

    Linear learning rate warm up strategy. Update the learning rate preliminarily before the normal learning rate scheduler.
    For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_

    When epoch < warmup_steps, learning rate is updated as:

    .. math::

            lr = start\_lr + (end\_lr - start\_lr) * \frac{epoch}{warmup\_steps}

    where start_lr is the initial learning rate, and end_lr is the final learning rate;

    When epoch >= warmup_steps, learning rate is updated as:

    .. math::

            lr = learning\_rate

    where ``learning_rate`` is float or any subclass of ``LRScheduler`` .

    Args:
        learning_rate (float|LRScheduler): The learning rate after warm-up. It is a python float number or any subclass of ``LRScheduler`` .
        warmup_steps (int): total steps of warm up.
        start_lr (float): Initial learning rate of warm up.
        end_lr (float): Final learning rate of warm up. Must be greater than ``start_lr`` .
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``LinearWarmup`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.LinearWarmup(
                    learning_rate=0.5, warmup_steps=20, start_lr=0, end_lr=0.5, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(2):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                scheduler.step()

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.LinearWarmup(
                    learning_rate=0.5, warmup_steps=20, start_lr=0, end_lr=0.5, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(2):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                scheduler.step()
    """

    def __init__(self,
                 learning_rate,
                 warmup_steps,
                 start_lr,
                 end_lr,
                 last_epoch=-1,
                 verbose=False):
        if not isinstance(learning_rate, (float, int, LRScheduler)):
            # Report the offending type, as the message says, not the value.
            raise TypeError(
                "the type of learning_rate should be [int, float or LRScheduler], the current type is {}".
                format(type(learning_rate)))
        self.learning_rate = learning_rate
        self.warmup_steps = warmup_steps
        self.start_lr = start_lr
        self.end_lr = end_lr
        assert end_lr > start_lr, "end_lr {} must be greater than start_lr {}".format(
            end_lr, start_lr)
        super(LinearWarmup, self).__init__(start_lr, last_epoch, verbose)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Linear interpolation from start_lr towards end_lr.
            return (self.end_lr - self.start_lr) * float(
                self.last_epoch) / float(self.warmup_steps) + self.start_lr
        else:
            if isinstance(self.learning_rate, LRScheduler):
                # Each call after warm-up advances the wrapped scheduler by one
                # step and returns its new learning rate.
                self.learning_rate.step()
                return self.learning_rate()

            return self.learning_rate


class ExponentialDecay(LRScheduler):
766 767
    """

768
    Update learning rate by `gamma` each epoch.
769 770 771 772 773 774 775 776 777

    The algorithm can be described as following.
    
    .. math::

        new\_learning\_rate = last\_learning\_rate * gamma

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
778 779
        gamma (float): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` . 
            It should be less than 1.0.
780
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
781
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
782 783

    Returns:
784
        ``ExponentialDecay`` instance to schedule learning rate.
785 786 787 788 789 790 791 792

    Examples:
        
        .. code-block:: python

            import paddle
            import numpy as np

793
            # train on default dynamic graph mode
794
            linear = paddle.nn.Linear(10, 10)
795 796
            scheduler = paddle.optimizer.lr.ExponentialDecay(learning_rate=0.5, gamma=0.9, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
797 798
            for epoch in range(20):
                for batch_id in range(2):
799
                    x = paddle.uniform([10, 10])
800
                    out = linear(x)
801
                    loss = paddle.fluid.layers.reduce_mean(out)
802
                    loss.backward()
803 804
                    sgd.step()
                    sgd.clear_gradients()
805 806
                scheduler.step()

807
            # train on static graph mode
808 809 810 811
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
812 813
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
814 815
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
816
                scheduler = paddle.optimizer.lr.ExponentialDecay(learning_rate=0.5, gamma=0.9, verbose=True)
817 818 819 820 821 822 823 824 825 826 827 828 829
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(2):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
830
                        fetch_list=loss.name)
831 832 833 834 835
                scheduler.step()
    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        self.gamma = gamma
836 837
        super(ExponentialDecay, self).__init__(learning_rate, last_epoch,
                                               verbose)
838 839 840 841 842

    def get_lr(self):
        return self.base_lr * (self.gamma**self.last_epoch)


class MultiStepDecay(LRScheduler):
    """
    Update the learning rate by ``gamma`` once ``epoch`` reaches one of the milestones.

    The algorithm can be described as the code below.

    .. code-block:: text

        learning_rate = 0.5
        milestones = [30, 50]
        gamma = 0.1
        if epoch < 30:
            learning_rate = 0.5
        elif epoch < 50:
            learning_rate = 0.05
        else:
            learning_rate = 0.005

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        milestones (tuple|list): List or tuple of each boundaries. Must be strictly increasing.
        gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``MultiStepDecay`` instance to schedule learning rate.

    Raises:
        TypeError: If ``milestones`` is not a tuple or list.
        ValueError: If ``milestones`` is not strictly increasing, or ``gamma >= 1.0``.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(2):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.fluid.layers.reduce_mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                scheduler.step()

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(2):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                scheduler.step()
    """

    def __init__(self,
                 learning_rate,
                 milestones,
                 gamma=0.1,
                 last_epoch=-1,
                 verbose=False):
        if not isinstance(milestones, (tuple, list)):
            raise TypeError(
                "The type of 'milestones' in 'MultiStepDecay' must be 'tuple, list', but received %s."
                % type(milestones))

        # Every milestone must be strictly greater than the one before it.
        if not all(prev < nxt for prev, nxt in zip(milestones, milestones[1:])):
            raise ValueError('The elements of milestones must be incremented')
        if gamma >= 1.0:
            raise ValueError('gamma should be < 1.0.')

        self.milestones = milestones
        self.gamma = gamma
        super(MultiStepDecay, self).__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        # The exponent is the number of milestones already reached, i.e. the
        # index of the first milestone still ahead of ``last_epoch``.
        for i, milestone in enumerate(self.milestones):
            if self.last_epoch < milestone:
                return self.base_lr * (self.gamma**i)
        return self.base_lr * (self.gamma**len(self.milestones))


class StepDecay(LRScheduler):
    """
    Update the learning rate of ``optimizer`` by ``gamma`` every ``step_size`` number of epoch.

    The algorithm can be described as the code below.

    .. code-block:: text

        learning_rate = 0.5
        step_size = 30
        gamma = 0.1

        learning_rate = 0.5     if epoch < 30
        learning_rate = 0.05    if 30 <= epoch < 60
        learning_rate = 0.005   if 60 <= epoch < 90
        ...

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        step_size (int): the interval to update. Must be a positive integer.
        gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``StepDecay`` instance to schedule learning rate.

    Raises:
        TypeError: If ``step_size`` is not an ``int``.
        ValueError: If ``step_size`` is not positive, or ``gamma >= 1.0``.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=5, gamma=0.8, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(2):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.fluid.layers.reduce_mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                scheduler.step()

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=5, gamma=0.8, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(2):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                scheduler.step()
    """

    def __init__(self,
                 learning_rate,
                 step_size,
                 gamma=0.1,
                 last_epoch=-1,
                 verbose=False):
        if not isinstance(step_size, int):
            raise TypeError(
                "The type of 'step_size' must be 'int', but received %s." %
                type(step_size))
        # step_size == 0 would divide by zero in ``get_lr`` and a negative
        # step_size would make the learning rate grow instead of decay.
        if step_size <= 0:
            raise ValueError(
                "The 'step_size' must be a positive integer, but received %s."
                % step_size)
        if gamma >= 1.0:
            raise ValueError('gamma should be < 1.0.')

        self.step_size = step_size
        self.gamma = gamma
        super(StepDecay, self).__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        # Number of complete ``step_size`` intervals elapsed so far.
        i = self.last_epoch // self.step_size
        return self.base_lr * (self.gamma**i)


class LambdaDecay(LRScheduler):
    """
    Sets the learning rate of ``optimizer`` by function ``lr_lambda`` . ``lr_lambda`` is a function which receives ``epoch`` .

    The algorithm can be described as the code below.

    .. code-block:: text

        learning_rate = 0.5        # init learning_rate
        lr_lambda = lambda epoch: 0.95 ** epoch

        learning_rate = 0.5        # epoch 0, 0.5*0.95**0
        learning_rate = 0.475      # epoch 1, 0.5*0.95**1
        learning_rate = 0.45125    # epoch 2, 0.5*0.95**2

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        lr_lambda (function): A function which computes a factor from ``epoch`` ; the initial learning rate is multiplied by this factor.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``LambdaDecay`` instance to schedule learning rate.

    Raises:
        TypeError: If ``lr_lambda`` is not callable.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.LambdaDecay(learning_rate=0.5, lr_lambda=lambda x:0.95**x, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(2):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.fluid.layers.reduce_mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                scheduler.step()

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.LambdaDecay(learning_rate=0.5, lr_lambda=lambda x:0.95**x, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(2):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                scheduler.step()

    """

    def __init__(self, learning_rate, lr_lambda, last_epoch=-1, verbose=False):
        if not callable(lr_lambda):
            raise TypeError(
                "The type of 'lr_lambda' in 'LambdaDecay' must be 'function', but received %s."
                % type(lr_lambda))

        self.lr_lambda = lr_lambda
        super(LambdaDecay, self).__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        # The factor is always applied to the initial learning rate, not to
        # the previous epoch's value.
        return self.base_lr * self.lr_lambda(self.last_epoch)


class ReduceOnPlateau(LRScheduler):
    """
    Reduce learning rate when ``metrics`` has stopped descending. Models often benefit from reducing the learning rate
    by 2 to 10 times once model performance stops improving.

    The ``metrics`` is the one which has been passed into ``step`` , it must be 1-D Tensor with shape [1]. When ``metrics``
    stops descending for a ``patience`` number of epochs, the learning rate will be reduced to ``learning_rate * factor`` .
    (Specially, ``mode`` can also be set to ``'max'`` , in this case, when ``metrics`` stops ascending for a ``patience``
    number of epochs, the learning rate will be reduced.)

    In addition, after each reduction, it will wait a ``cooldown`` number of epochs before resuming the above operation.

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        mode (str, optional): ``'min'`` or ``'max'`` can be selected. Normally, it is ``'min'`` , which means that the
            learning rate will reduce when ``loss`` stops descending. Specially, if it's set to ``'max'`` ,  the learning
            rate will reduce when ``loss`` stops ascending. Default: ``'min'`` .
        factor (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * factor`` .
            It should be less than 1.0. Default: 0.1.
        patience (int, optional): When ``loss`` doesn't improve for this number of epochs, learning rate will be reduced.
            Default: 10.
        threshold (float, optional): ``threshold`` and ``threshold_mode`` will determine the minimum change of ``loss`` .
            Changes of ``loss`` smaller than this are ignored. Default: 1e-4.
        threshold_mode (str, optional): ``'rel'`` or ``'abs'`` can be selected. In ``'rel'`` mode, the minimum change of ``loss``
            is ``best_loss * threshold`` , where ``best_loss`` is the best ``loss`` seen so far. In ``'abs'`` mode, the minimum
            change of ``loss`` is ``threshold`` . Default: ``'rel'`` .
        cooldown (int, optional): The number of epochs to wait before resuming normal operation. Default: 0.
        min_lr (float, optional): The lower bound of the learning rate after reduction. Default: 0.
        epsilon (float, optional): Minimal decay applied to lr. If the difference between new and old lr is smaller than epsilon,
            the update is ignored. Default: 1e-8.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``.

    Returns:
        ``ReduceOnPlateau`` instance to schedule learning rate.

    Raises:
        ValueError: If ``mode`` or ``threshold_mode`` is unknown, or ``factor >= 1.0``.
        TypeError: If ``learning_rate`` is not a python ``float`` or ``int``.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.ReduceOnPlateau(learning_rate=1.0, factor=0.5, patience=5, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(2):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.fluid.layers.reduce_mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                scheduler.step(loss)

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.ReduceOnPlateau(learning_rate=1.0, factor=0.5, patience=5, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(2):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                scheduler.step(out[0])

    """

    def __init__(self,
                 learning_rate,
                 mode='min',
                 factor=0.1,
                 patience=10,
                 threshold=1e-4,
                 threshold_mode='rel',
                 cooldown=0,
                 min_lr=0,
                 epsilon=1e-8,
                 verbose=False):
        mode = mode.lower()
        if mode not in ['min', 'max']:
            raise ValueError('mode: ' + mode + ' is unknown!')
        self.mode = mode

        if factor >= 1.0:
            raise ValueError(
                'new_lr = origin_lr * factor and factor should be < 1.0.')
        self.factor = factor

        threshold_mode = threshold_mode.lower()
        if threshold_mode not in ['rel', 'abs']:
            raise ValueError('threshold mode: ' + threshold_mode +
                             ' is unknown!')
        self.threshold_mode = threshold_mode
        if not isinstance(learning_rate, (float, int)):
            raise TypeError(
                "The type of 'learning_rate' in 'ReduceOnPlateau' must be 'float', but received %s."
                % type(learning_rate))

        self.patience = patience
        self.threshold = threshold
        self.cooldown = cooldown
        self.min_lr = min_lr
        self.epsilon = epsilon

        # Plateau-tracking state, updated by ``step``.
        self.cooldown_counter = 0
        self.best = None
        self.num_bad_epochs = 0

        # Can not call Parent __init__, so implement here (presumably because
        # this class's ``step`` requires a ``metrics`` argument).
        self.base_lr = float(learning_rate)
        self.last_lr = float(learning_rate)
        self.last_epoch = 0
        self.verbose = verbose
        self._var_name = None

    # "cooldown_counter / best / num_bad_epochs / last_epoch / last_lr" will be stored.
    def state_keys(self):
        self.keys = [
            'cooldown_counter', 'best', 'num_bad_epochs', 'last_epoch',
            'last_lr'
        ]

    def step(self, metrics, epoch=None):
        """
        step should be called after `optimizer.step()` . It will update the learning rate in optimizer according to ``metrics`` .
        The new learning rate will take effect on next epoch.

        Args:
            metrics (Tensor|numpy.ndarray|float): Which will be monitored to determine whether the learning rate will reduce.
                If it stops descending for a ``patience`` number of epochs, the learning rate will reduce. If it's 'Tensor' or
                'numpy.ndarray', its shape must be [1].
            epoch (int, None): specify current epoch. Default: None, which increments ``last_epoch`` by 1.

        Returns:
            None

        Raises:
            TypeError: If ``metrics`` is not one of the supported types.

        Examples:
            Please refer to the example of current LRScheduler.
        """
        if epoch is None:
            self.last_epoch = self.last_epoch + 1
        else:
            self.last_epoch = epoch

        # loss must be float, numpy.ndarray or 1-D Tensor with shape [1]
        if isinstance(metrics, (Tensor, numpy.ndarray)):
            assert len(metrics.shape) == 1 and metrics.shape[0] == 1, "the metrics.shape " \
                "should be (1L,), but the current metrics.shape is {}. Maybe that "  \
                "you should call paddle.mean to process it first.".format(metrics.shape)
        elif not isinstance(metrics,
                            (int, float, numpy.float32, numpy.float64)):
            raise TypeError(
                "metrics must be 'int', 'float', 'np.float', 'numpy.ndarray' or 'paddle.Tensor', but receive {}".
                format(type(metrics)))

        if self.cooldown_counter > 0:
            # Still cooling down after a reduction: skip plateau tracking.
            self.cooldown_counter -= 1
        else:
            if self.best is None or self._is_better(metrics, self.best):
                self.best = metrics
                self.num_bad_epochs = 0
            else:
                self.num_bad_epochs += 1

            if self.num_bad_epochs > self.patience:
                self.cooldown_counter = self.cooldown
                self.num_bad_epochs = 0
                new_lr = max(self.last_lr * self.factor, self.min_lr)
                # Ignore reductions smaller than ``epsilon`` (e.g. once
                # ``min_lr`` has been reached).
                if self.last_lr - new_lr > self.epsilon:
                    self.last_lr = new_lr
                    if self.verbose:
                        print('Epoch {}: {} set learning rate to {}.'.format(
                            self.last_epoch, self.__class__.__name__,
                            self.last_lr))

    def _is_better(self, current, best):
        """Return whether ``current`` beats ``best`` by more than the threshold
        defined by ``mode`` and ``threshold_mode``."""
        if self.mode == 'min' and self.threshold_mode == 'rel':
            return current < best - best * self.threshold

        elif self.mode == 'min' and self.threshold_mode == 'abs':
            return current < best - self.threshold

        elif self.mode == 'max' and self.threshold_mode == 'rel':
            return current > best + best * self.threshold

        else:
            return current > best + self.threshold


class CosineAnnealingDecay(LRScheduler):
1346 1347 1348 1349
    """

    Set the learning rate using a cosine annealing schedule, where :math:`\eta_{max}` is set to 
    the initial learning_rate. :math:`T_{cur}` is the number of epochs since the last restart in 
1350
    SGDR.
1351 1352 1353 1354

    The algorithm can be described as following.

    .. math::
1355 1356 1357 1358 1359 1360 1361

        \\begin{aligned}
            \eta_t & = \eta_{min} + \\frac{1}{2}(\eta_{max} - \eta_{min})\left(1
            + \cos\left(\\frac{T_{cur}}{T_{max}}\pi\\right)\\right),
            & T_{cur} \\neq (2k+1)T_{max}; \\
            \eta_{t+1} & = \eta_{t} + \\frac{1}{2}(\eta_{max} - \eta_{min})
            \left(1 - \cos\left(\\frac{1}{T_{max}}\pi\\right)\\right),
1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372
            & T_{cur} = (2k+1)T_{max}.
        \end{aligned}
    
    It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts <https://arxiv.org/abs/1608.03983>`_. 
    Note that this only implements the cosine annealing part of SGDR, and not the restarts.
    
    Args:
        learning_rate (float): The initial learning rate, that is :math:`\eta_{max}` . It can be set to python float or int number.
        T_max (int): Maximum number of iterations. It is half of the decay cycle of learning rate.
        eta_min (float|int, optional): Minimum learning rate, that is :math:`\eta_{min}` . Default: 0.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
1373
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .
1374 1375

    Returns:
1376
        ``CosineAnnealingDecay`` instance to schedule learning rate.
1377 1378 1379 1380 1381 1382 1383 1384

    Examples:
        
        .. code-block:: python

            import paddle
            import numpy as np

1385
            # train on default dynamic graph mode
1386
            linear = paddle.nn.Linear(10, 10)
1387 1388
            scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.5, T_max=10, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
1389 1390
            for epoch in range(20):
                for batch_id in range(2):
1391
                    x = paddle.uniform([10, 10])
1392
                    out = linear(x)
1393
                    loss = paddle.fluid.layers.reduce_mean(out)
1394
                    loss.backward()
1395 1396
                    sgd.step()
                    sgd.clear_gradients()
1397 1398
                scheduler.step()

1399
            # train on static graph mode
1400 1401 1402 1403
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
1404 1405
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
1406 1407
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
1408
                scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.5, T_max=10, verbose=True)
1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 1420 1421
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(2):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
1422
                        fetch_list=loss.name)
1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433
                scheduler.step()
    """

    def __init__(self,
                 learning_rate,
                 T_max,
                 eta_min=0,
                 last_epoch=-1,
                 verbose=False):
        if not isinstance(T_max, int):
            raise TypeError(
1434
                "The type of 'T_max' in 'CosineAnnealingDecay' must be 'int', but received %s."
1435 1436 1437
                % type(T_max))
        if not isinstance(eta_min, (float, int)):
            raise TypeError(
1438
                "The type of 'eta_min' in 'CosineAnnealingDecay' must be 'float, int', but received %s."
1439 1440 1441
                % type(eta_min))
        self.T_max = T_max
        self.eta_min = float(eta_min)
1442 1443
        super(CosineAnnealingDecay, self).__init__(learning_rate, last_epoch,
                                                   verbose)
1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458

    def get_lr(self):
        if self.last_epoch == 0:
            return self.base_lr
        elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
            return self.last_lr + (self.base_lr - self.eta_min) * (1 - math.cos(
                math.pi / self.T_max)) / 2

        return (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / (
            1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) * (
                self.last_lr - self.eta_min) + self.eta_min

    def _get_closed_form_lr(self):
        return self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos(
            math.pi * self.last_epoch / self.T_max)) / 2