lr.py 60.8 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import numpy
import warnings
from paddle import Tensor

# Public learning-rate scheduler classes exported by this module.
__all__ = [
    'LRScheduler', 'NoamDecay', 'PiecewiseDecay', 'NaturalExpDecay',
    'InverseTimeDecay', 'PolynomialDecay', 'LinearWarmup', 'ExponentialDecay',
    'MultiStepDecay', 'StepDecay', 'LambdaDecay', 'ReduceOnPlateau',
    'CosineAnnealingDecay'
]


class LRScheduler(object):
    """

    LRScheduler Base class. Defines the common interface of a learning rate scheduler.

    Users can import it with ``from paddle.optimizer.lr import LRScheduler`` ,
    then subclass it and provide a custom implementation of ``get_lr()`` .

    Otherwise, a ``NotImplementedError`` exception will be raised.

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        instance to schedule learning rate.

    Examples:
        Here is an example of a simple ``StepDecay`` implementation.

        .. code-block:: python

            import paddle
            from paddle.optimizer.lr import LRScheduler

            class StepDecay(LRScheduler):
                def __init__(self,
                            learning_rate,
                            step_size,
                            gamma=0.1,
                            last_epoch=-1,
                            verbose=False):
                    if not isinstance(step_size, int):
                        raise TypeError(
                            "The type of 'step_size' must be 'int', but received %s." %
                            type(step_size))
                    if gamma >= 1.0:
                        raise ValueError('gamma should be < 1.0.')

                    self.step_size = step_size
                    self.gamma = gamma
                    super(StepDecay, self).__init__(learning_rate, last_epoch, verbose)

                def get_lr(self):
                    i = self.last_epoch // self.step_size
                    return self.base_lr * (self.gamma**i)

    """

    def __init__(self, learning_rate=0.1, last_epoch=-1, verbose=False):
        # Python ints are accepted as well as floats; both are stored as float.
        if not isinstance(learning_rate, (float, int)):
            raise TypeError(
                "The type of learning rate must be float, but received {}".
                format(type(learning_rate)))
        self.base_lr = float(learning_rate)
        self.last_lr = float(learning_rate)
        self.last_epoch = last_epoch
        self.verbose = verbose
        self._var_name = None

        # Advance from ``last_epoch`` to the next epoch (0 by default) and
        # compute its rate, so ``last_lr`` is valid right after construction.
        self.step()

    def __call__(self):
        """

        Return the latest computed learning rate on the current epoch.
        """
        return self.last_lr

    def step(self, epoch=None):
        """

        ``step`` should be called after ``optimizer.step`` . It will update the learning rate in optimizer according to current ``epoch`` .
        The new learning rate will take effect on next ``optimizer.step`` .

        Args:
            epoch (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1.

        Returns:
            None

        """
        if epoch is None:
            self.last_epoch += 1
            self.last_lr = self.get_lr()
        else:
            self.last_epoch = epoch
            # A subclass may provide a closed-form formula to jump directly to
            # an arbitrary epoch; otherwise fall back to ``get_lr``.
            if hasattr(self, "_get_closed_form_lr"):
                self.last_lr = self._get_closed_form_lr()
            else:
                self.last_lr = self.get_lr()

        if self.verbose:
            print('Epoch {}: {} set learning rate to {}.'.format(
                self.last_epoch, self.__class__.__name__, self.last_lr))

    def state_dict(self):
        """

        Returns the state of the scheduler as a :class:`dict`.

        It is a subset of ``self.__dict__`` . Single-element ``Tensor`` values
        are converted to python scalars.
        """
        self.state_keys()
        state_dict = {}
        for key in self.keys:
            # Keys a subclass lists but never set are silently skipped.
            if key not in self.__dict__:
                continue
            value = self.__dict__[key]
            if isinstance(value, Tensor):
                assert value.shape == [
                    1
                ], "shape of Tensor in state_dict must be [1] {}".format(
                    value.shape)
                value = value.numpy()[0]
            state_dict[key] = value

        return state_dict

    # For those subclass who overload LRScheduler, "last_epoch, last_lr" will be saved by default.
    # (Note): you can change it for your subclass.
    def state_keys(self):
        """

        For those subclass who overload ``LRScheduler`` (Base Class). By default, "last_epoch, last_lr" will be saved by ``self.keys = ['last_epoch', 'last_lr']`` .

        ``last_epoch`` is the current epoch num, and ``last_lr`` is the current learning rate.

        If you want to change the default behavior, you should have a custom implementation of ``state_keys()`` to redefine ``self.keys`` .

        """
        self.keys = ['last_epoch', 'last_lr']

    def set_state_dict(self, state_dict):
        """

        Loads the scheduler's state.

        Raises:
            RuntimeError: if a key listed by ``state_keys()`` is missing from ``state_dict``.
        """
        self.state_keys()
        for key in self.keys:
            if key in state_dict:
                self.__dict__[key] = state_dict[key]
            else:
                raise RuntimeError(
                    "Please check whether state_dict is correct for optimizer. Can't find [ {} ] in state_dict".
                    format(key))
        if len(state_dict) > len(self.keys):
            warnings.warn(
                "There are some unused values in state_dict. Maybe the optimizer have different 'LearningRateDecay' when invoking state_dict and set_dict"
            )

    # alias for set_state_dict
    set_dict = set_state_dict

    def get_lr(self):
        """

        For those subclass who overload ``LRScheduler`` (Base Class), User should have a custom implementation of ``get_lr()`` .

        Otherwise, a ``NotImplementedError`` exception will be raised.
        """
        # calculate by python float
        raise NotImplementedError


class NoamDecay(LRScheduler):
    r"""

    Applies Noam Decay to the initial learning rate.

    The algorithm can be described as following.

    .. math::

        new\_learning\_rate = learning\_rate * d_{model}^{-0.5} * min(epoch^{-0.5}, epoch * warmup\_steps^{-1.5})

    Please reference `attention is all you need <https://arxiv.org/pdf/1706.03762.pdf>`_

    Note that at epoch 0 the second term of ``min`` is 0, so the learning rate is 0.

    Args:
        d_model (int|float): The dimensionality of input and output feature vector of model.
        warmup_steps (int): The number of warmup steps. A hyperparameter. It is a python int number.
        learning_rate (float): The initial learning rate. It is a python float number. Default: 1.0.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``NoamDecay`` instance to schedule learning rate.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.NoamDecay(d_model=0.01, warmup_steps=100, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.NoamDecay(d_model=0.01, warmup_steps=100, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch

    """

    def __init__(self,
                 d_model,
                 warmup_steps,
                 learning_rate=1.0,
                 last_epoch=-1,
                 verbose=False):
        # Must be set before the base __init__, which calls step() -> get_lr().
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        super(NoamDecay, self).__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        # epoch ** -0.5 is undefined at 0; use 1 there. The warmup term
        # ``b`` is 0 at epoch 0, so ``min`` still yields a rate of 0.
        if self.last_epoch == 0:
            a = 1
        else:
            a = self.last_epoch**-0.5
        b = self.warmup_steps**-1.5 * self.last_epoch
        return self.base_lr * (self.d_model**-0.5) * min(a, b)


class PiecewiseDecay(LRScheduler):
    """

    Piecewise learning rate scheduler.

    The algorithm can be described as the code below:

    .. code-block:: text

        boundaries = [100, 200]
        values = [1.0, 0.5, 0.1]
        if epoch < 100:
            learning_rate = 1.0
        elif 100 <= epoch < 200:
            learning_rate = 0.5
        else:
            learning_rate = 0.1

    Args:
        boundaries(list): A list of steps numbers. The type of element in the list is python int.
        values(list): A list of learning rate values that will be picked during different epoch boundaries.
            The type of element in the list is python float. It is expected to contain one more element
            than ``boundaries``; the last element is used once ``epoch`` passes the final boundary.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``PiecewiseDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[3, 6, 9], values=[0.1, 0.2, 0.3, 0.4], verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[3, 6, 9], values=[0.1, 0.2, 0.3, 0.4], verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch
    """

    def __init__(self, boundaries, values, last_epoch=-1, verbose=False):
        # Must be set before the base __init__, which calls step() -> get_lr().
        self.boundaries = boundaries
        self.values = values
        # No learning_rate is passed: the rate comes entirely from ``values``.
        super(PiecewiseDecay, self).__init__(
            last_epoch=last_epoch, verbose=verbose)

    def get_lr(self):
        # Return the value of the first segment whose boundary lies beyond the
        # current epoch; past every boundary, use the last value.
        for i, boundary in enumerate(self.boundaries):
            if self.last_epoch < boundary:
                return self.values[i]
        return self.values[-1]


class NaturalExpDecay(LRScheduler):
    r"""

    Applies natural exponential decay to the initial learning rate.

    The algorithm can be described as following:

    .. math::

        new\_learning\_rate = learning\_rate * e^{- gamma * epoch}

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float): A ratio to update the learning rate. A larger ``gamma`` makes the learning rate decay faster.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``NaturalExpDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.NaturalExpDecay(learning_rate=0.5, gamma=0.1, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.NaturalExpDecay(learning_rate=0.5, gamma=0.1, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch
    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        # Must be set before the base __init__, which calls step() -> get_lr().
        self.gamma = gamma
        super(NaturalExpDecay, self).__init__(learning_rate, last_epoch,
                                              verbose)

    def get_lr(self):
        return self.base_lr * math.exp(-1 * self.gamma * self.last_epoch)


class InverseTimeDecay(LRScheduler):
    r"""

    Applies inverse time decay to the initial learning rate.

    The algorithm can be described as following:

    .. math::

        new\_learning\_rate = \frac{learning\_rate}{1 + gamma * epoch}

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float): The ratio that controls how fast the learning rate decays:
            ``new_lr = learning_rate / (1 + gamma * epoch)`` . A larger ``gamma`` decays faster.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``InverseTimeDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.InverseTimeDecay(learning_rate=0.5, gamma=0.1, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.InverseTimeDecay(learning_rate=0.5, gamma=0.1, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch

    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        # Must be set before the base __init__, which calls step() -> get_lr().
        self.gamma = gamma
        super(InverseTimeDecay, self).__init__(learning_rate, last_epoch,
                                               verbose)

    def get_lr(self):
        return self.base_lr / (1 + self.gamma * self.last_epoch)


class PolynomialDecay(LRScheduler):
    r"""

    Applies polynomial decay to the initial learning rate.

    The algorithm can be described as following.

    If cycle is set to True, then:

    .. math::

        decay\_steps & = decay\_steps * math.ceil(\frac{epoch}{decay\_steps})

        new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\frac{epoch}{decay\_steps})^{power}+end\_lr

    (At epoch 0, one full period of ``decay_steps`` is used.)

    If cycle is set to False, then:

    .. math::

        epoch & = min(epoch, decay\_steps)

        new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\frac{epoch}{decay\_steps})^{power}+end\_lr


    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        decay_steps(int): The decay step size. It determines the decay cycle.
        end_lr(float, optional): The minimum final learning rate. Default: 0.0001.
        power(float, optional): Power of polynomial. Default: 1.0.
        cycle(bool, optional): Whether the learning rate rises again. If True, then the learning rate will rise when it decreases
            to ``end_lr`` .  If False, the learning rate is monotone decreasing. Default: False.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``PolynomialDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.5, decay_steps=20, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.5, decay_steps=20, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch
    """

    def __init__(self,
                 learning_rate,
                 decay_steps,
                 end_lr=0.0001,
                 power=1.0,
                 cycle=False,
                 last_epoch=-1,
                 verbose=False):
        # Must be set before the base __init__, which calls step() -> get_lr().
        self.decay_steps = decay_steps
        self.end_lr = end_lr
        self.power = power
        self.cycle = cycle
        super(PolynomialDecay, self).__init__(learning_rate, last_epoch,
                                              verbose)

    def get_lr(self):
        tmp_epoch_num = self.last_epoch
        tmp_decay_steps = self.decay_steps
        if self.cycle:
            # Stretch the period to the smallest multiple of decay_steps that
            # covers the current epoch. ceil(0 / decay_steps) is 0, which would
            # divide by zero below, so epoch 0 uses one full period.
            div_res = math.ceil(
                float(self.last_epoch) / float(self.decay_steps))

            if self.last_epoch == 0:
                div_res = 1
            tmp_decay_steps = self.decay_steps * div_res
        else:
            # Without cycling, hold at end_lr once decay_steps is reached.
            tmp_epoch_num = min(self.last_epoch, self.decay_steps)

        return (self.base_lr - self.end_lr) * (
            (1 - float(tmp_epoch_num) / float(tmp_decay_steps)
             )**self.power) + self.end_lr


class LinearWarmup(LRScheduler):
    r"""

    Linear learning rate warm up strategy. Update the learning rate preliminarily before the normal learning rate scheduler.
    For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_

    When epoch < warmup_steps, learning rate is updated as:

    .. math::

            lr = start\_lr + (end\_lr - start\_lr) * \frac{epoch}{warmup\_steps}

    where start_lr is the initial learning rate, and end_lr is the final learning rate;

    When epoch >= warmup_steps, learning rate is updated as:

    .. math::

            lr = learning_rate

    where ``learning_rate`` is float or any subclass of ``LRScheduler`` . If it is an ``LRScheduler``,
    its current value is returned and it is advanced by one step on every update after warm-up.

    Args:
        learning_rate (float|LRScheduler): The learning rate after warm-up. It is a python float number or any subclass of ``LRScheduler`` .
        warmup_steps (int): total steps of warm up.
        start_lr (float): Initial learning rate of warm up.
        end_lr (float): Final learning rate of warm up. Must be greater than ``start_lr`` .
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``LinearWarmup`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.LinearWarmup(
                    learning_rate=0.5, warmup_steps=20, start_lr=0, end_lr=0.5, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.LinearWarmup(
                    learning_rate=0.5, warmup_steps=20, start_lr=0, end_lr=0.5, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch
    """

    # Key under which the wrapped scheduler's state is stored in state_dict().
    _INNER_STATE_KEY = "LinearWarmup_LR"

    def __init__(self,
                 learning_rate,
                 warmup_steps,
                 start_lr,
                 end_lr,
                 last_epoch=-1,
                 verbose=False):
        if not isinstance(learning_rate, (float, int, LRScheduler)):
            raise TypeError(
                "the type of learning_rate should be [int, float or LRScheduler], the current type is {}".
                format(type(learning_rate)))
        self.learning_rate = learning_rate
        self.warmup_steps = warmup_steps
        self.start_lr = start_lr
        self.end_lr = end_lr
        assert end_lr > start_lr, "end_lr {} must be greater than start_lr {}".format(
            end_lr, start_lr)
        # Warm-up starts from start_lr, so it is the base learning rate.
        super(LinearWarmup, self).__init__(start_lr, last_epoch, verbose)

    def state_dict(self):
        """
        Returns the state of the LinearWarmup scheduler as a :class:`dict`.

        It is a subset of ``self.__dict__`` . When ``learning_rate`` is an
        ``LRScheduler``, its own state is nested under ``"LinearWarmup_LR"`` .
        """
        state_dict = super(LinearWarmup, self).state_dict()
        if isinstance(self.learning_rate, LRScheduler):
            state_dict[self._INNER_STATE_KEY] = self.learning_rate.state_dict()
        return state_dict

    def set_state_dict(self, state_dict):
        """
        Loads state_dict for LinearWarmup scheduler.

        When ``learning_rate`` is an ``LRScheduler``, the nested
        ``"LinearWarmup_LR"`` entry is loaded into it. That entry is not passed
        to the base loader, so it is not reported as an unused value.
        """
        if isinstance(self.learning_rate, LRScheduler):
            own_state = {
                key: value
                for key, value in state_dict.items()
                if key != self._INNER_STATE_KEY
            }
            super(LinearWarmup, self).set_state_dict(own_state)
            self.learning_rate.set_state_dict(
                state_dict[self._INNER_STATE_KEY])
        else:
            super(LinearWarmup, self).set_state_dict(state_dict)

    # The base-class alias is bound to LRScheduler.set_state_dict; re-bind it
    # so ``set_dict`` also restores the wrapped scheduler.
    set_dict = set_state_dict

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            return (self.end_lr - self.start_lr) * float(
                self.last_epoch) / float(self.warmup_steps) + self.start_lr
        if isinstance(self.learning_rate, LRScheduler):
            # Report the wrapped scheduler's current rate, then advance it so
            # the next call sees the following epoch's rate.
            lr_value = self.learning_rate()
            self.learning_rate.step()
            return lr_value

        return self.learning_rate


796
class ExponentialDecay(LRScheduler):
    r"""

    Update learning rate by `gamma` each epoch.

    The algorithm can be described as following.

    .. math::

        new\_learning\_rate = last\_learning\_rate * gamma

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``ExponentialDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.ExponentialDecay(learning_rate=0.5, gamma=0.9, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.ExponentialDecay(learning_rate=0.5, gamma=0.9, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        # ``gamma`` must be set before the parent constructor runs, because the
        # parent initialisation computes the first learning rate via get_lr().
        self.gamma = gamma
        super(ExponentialDecay, self).__init__(learning_rate, last_epoch,
                                               verbose)

    def get_lr(self):
        # Closed form of repeated multiplication by gamma: base_lr * gamma^epoch.
        return self.base_lr * (self.gamma**self.last_epoch)


class MultiStepDecay(LRScheduler):
    """
    Multiply the learning rate by ``gamma`` each time ``epoch`` passes one of the ``milestones``.

    The algorithm can be described as the code below.

    .. code-block:: text

        learning_rate = 0.5
        milestones = [30, 50]
        gamma = 0.1
        if epoch < 30:
            learning_rate = 0.5
        elif epoch < 50:
            learning_rate = 0.05
        else:
            learning_rate = 0.005

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        milestones (tuple|list): List or tuple of each boundaries. Must be strictly increasing.
        gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``MultiStepDecay`` instance to schedule learning rate.

    Raises:
        TypeError: if ``milestones`` is not a tuple or list.
        ValueError: if ``milestones`` is not strictly increasing, or ``gamma >= 1.0``.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self,
                 learning_rate,
                 milestones,
                 gamma=0.1,
                 last_epoch=-1,
                 verbose=False):
        if not isinstance(milestones, (tuple, list)):
            raise TypeError(
                "The type of 'milestones' in 'MultiStepDecay' must be 'tuple, list', but received %s."
                % type(milestones))

        # Compare every adjacent pair (eagerly, as a list) to check ordering.
        adjacent_ok = [
            earlier < later
            for earlier, later in zip(milestones[:-1], milestones[1:])
        ]
        if not all(adjacent_ok):
            raise ValueError('The elements of milestones must be incremented')
        if gamma >= 1.0:
            raise ValueError('gamma should be < 1.0.')

        self.milestones = milestones
        self.gamma = gamma
        # Parent init computes the first rate via get_lr(), so fields go first.
        super(MultiStepDecay, self).__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        # Number of milestones already reached == exponent applied to gamma.
        decay_count = len(self.milestones)
        for idx, boundary in enumerate(self.milestones):
            if self.last_epoch < boundary:
                decay_count = idx
                break
        return self.base_lr * (self.gamma**decay_count)


class StepDecay(LRScheduler):
    """
    Multiply the learning rate of ``optimizer`` by ``gamma`` once every ``step_size`` epochs.

    The algorithm can be described as the code below.

    .. code-block:: text

        learning_rate = 0.5
        step_size = 30
        gamma = 0.1

        learning_rate = 0.5     if epoch < 30
        learning_rate = 0.05    if 30 <= epoch < 60
        learning_rate = 0.005   if 60 <= epoch < 90
        ...

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        step_size (int): the interval to update.
        gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``StepDecay`` instance to schedule learning rate.

    Raises:
        TypeError: if ``step_size`` is not an ``int``.
        ValueError: if ``gamma >= 1.0``.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=5, gamma=0.8, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=5, gamma=0.8, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self,
                 learning_rate,
                 step_size,
                 gamma=0.1,
                 last_epoch=-1,
                 verbose=False):
        step_size_is_int = isinstance(step_size, int)
        if not step_size_is_int:
            raise TypeError(
                "The type of 'step_size' must be 'int', but received %s." %
                type(step_size))
        if gamma >= 1.0:
            raise ValueError('gamma should be < 1.0.')

        self.step_size = step_size
        self.gamma = gamma
        # Parent init computes the first rate via get_lr(), so fields go first.
        super(StepDecay, self).__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        # How many full ``step_size`` intervals have elapsed so far.
        completed_intervals = self.last_epoch // self.step_size
        return self.base_lr * (self.gamma**completed_intervals)


class LambdaDecay(LRScheduler):
    """
    Set the learning rate of ``optimizer`` to the initial rate scaled by ``lr_lambda(epoch)`` , where ``lr_lambda`` is a function which receives ``epoch`` .

    The algorithm can be described as the code below.

    .. code-block:: text

        learning_rate = 0.5        # init learning_rate
        lr_lambda = lambda epoch: 0.95 ** epoch

        learning_rate = 0.5        # epoch 0, 0.5*0.95**0
        learning_rate = 0.475      # epoch 1, 0.5*0.95**1
        learning_rate = 0.45125    # epoch 2, 0.5*0.95**2

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        lr_lambda (function): A function which computes a factor by ``epoch`` , and then multiply the initial learning rate by this factor.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``LambdaDecay`` instance to schedule learning rate.

    Raises:
        TypeError: if ``lr_lambda`` is not callable.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.LambdaDecay(learning_rate=0.5, lr_lambda=lambda x:0.95**x, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.LambdaDecay(learning_rate=0.5, lr_lambda=lambda x:0.95**x, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(self, learning_rate, lr_lambda, last_epoch=-1, verbose=False):
        if callable(lr_lambda):
            self.lr_lambda = lr_lambda
        else:
            raise TypeError(
                "The type of 'lr_lambda' in 'LambdaDecay' must be 'function', but received %s."
                % type(lr_lambda))
        # Parent init computes the first rate via get_lr(), so lr_lambda goes first.
        super(LambdaDecay, self).__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        # Scale the initial rate by the user-supplied per-epoch factor.
        scale = self.lr_lambda(self.last_epoch)
        return self.base_lr * scale


class ReduceOnPlateau(LRScheduler):
    """
    Reduce learning rate when ``metrics`` has stopped descending. Models often benefit from reducing the learning rate
    by 2 to 10 times once model performance has no longer improvement.

    The ``metrics`` is the one which has been passed into ``step`` ; it can be a python number, or a ``numpy.ndarray`` /
    1-D Tensor with shape [1]. When ``metrics`` stop descending for a ``patience`` number of epochs, the learning rate
    will be reduced to ``learning_rate * factor`` . (Specially, ``mode`` can also be set to ``'max'`` , in this case,
    when ``metrics`` stop ascending for a ``patience`` number of epochs, the learning rate will be reduced.)

    In addition, after each reduction, it will wait a ``cooldown`` number of epochs before resuming above operation.

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        mode (str, optional): ``'min'`` or ``'max'`` can be selected. Normally, it is ``'min'`` , which means that the
            learning rate will reduce when ``loss`` stops descending. Specially, if it's set to ``'max'`` ,  the learning
            rate will reduce when ``loss`` stops ascending. Default: ``'min'`` .
        factor (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * factor`` .
            It should be less than 1.0. Default: 0.1.
        patience (int, optional): When ``loss`` doesn't improve for this number of epochs, learning rate will be reduced.
            Default: 10.
        threshold (float, optional): ``threshold`` and ``threshold_mode`` will determine the minimum change of ``loss`` ,
            so that tiny changes of ``loss`` are ignored. Default: 1e-4.
        threshold_mode (str, optional): ``'rel'`` or ``'abs'`` can be selected. In ``'rel'`` mode, the minimum change of ``loss``
            is ``best * threshold`` , where ``best`` is the best ``loss`` observed so far. In ``'abs'`` mode, the minimum
            change of ``loss`` is ``threshold`` . Default: ``'rel'`` .
        cooldown (int, optional): The number of epochs to wait before resuming normal operation. Default: 0.
        min_lr (float, optional): The lower bound of the learning rate after reduction. Default: 0.
        epsilon (float, optional): Minimal decay applied to lr. If the difference between new and old lr is not larger than epsilon,
            the update is ignored. Default: 1e-8.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``.

    Returns:
        ``ReduceOnPlateau`` instance to schedule learning rate.

    Raises:
        ValueError: if ``mode`` or ``threshold_mode`` is unknown, or ``factor >= 1.0``.
        TypeError: if ``learning_rate`` is not a python ``float`` or ``int``.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.ReduceOnPlateau(learning_rate=1.0, factor=0.5, patience=5, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step(loss)    # If you update learning rate each step
                # scheduler.step(loss)        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.ReduceOnPlateau(learning_rate=1.0, factor=0.5, patience=5, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step(out[0])    # If you update learning rate each step
                # scheduler.step(out[0])        # If you update learning rate each epoch

    """

    def __init__(self,
                 learning_rate,
                 mode='min',
                 factor=0.1,
                 patience=10,
                 threshold=1e-4,
                 threshold_mode='rel',
                 cooldown=0,
                 min_lr=0,
                 epsilon=1e-8,
                 verbose=False):
        mode = mode.lower()
        if mode not in ['min', 'max']:
            raise ValueError('mode: ' + mode + ' is unknown!')
        self.mode = mode

        if factor >= 1.0:
            raise ValueError(
                'new_lr = origin_lr * factor and factor should be < 1.0.')
        self.factor = factor

        threshold_mode = threshold_mode.lower()
        if threshold_mode not in ['rel', 'abs']:
            raise ValueError('threshold mode: ' + threshold_mode +
                             ' is unknown!')
        self.threshold_mode = threshold_mode
        if not isinstance(learning_rate, (float, int)):
            raise TypeError(
                "The type of 'learning_rate' in 'ReduceOnPlateau' must be 'float', but received %s."
                % type(learning_rate))

        self.patience = patience
        self.threshold = threshold
        self.cooldown = cooldown
        self.min_lr = min_lr
        self.epsilon = epsilon

        # Plateau-tracking state, updated by step().
        self.cooldown_counter = 0
        self.best = None
        self.num_bad_epochs = 0

        # Can not call Parent __init__ (it would compute the first rate via
        # get_lr(), which this scheduler does not use), so implement here.
        self.base_lr = float(learning_rate)
        self.last_lr = float(learning_rate)
        self.last_epoch = 0
        self.verbose = verbose
        self._var_name = None

    def state_keys(self):
        """
        Declare the attributes to be stored in the scheduler state:
        ``cooldown_counter``, ``best``, ``num_bad_epochs``, ``last_epoch`` and ``last_lr``.
        """
        self.keys = [
            'cooldown_counter', 'best', 'num_bad_epochs', 'last_epoch',
            'last_lr'
        ]

    def step(self, metrics, epoch=None):
        """
        step should be called after `optimizer.step()` . It will update the learning rate in optimizer according to ``metrics`` .
        The new learning rate will take effect on next epoch.

        Args:
            metrics (Tensor|numpy.ndarray|float): Which will be monitored to determine whether the learning rate will reduce.
                If it stop descending for a ``patience`` number of epochs, the learning rate will reduce. If it's 'Tensor' or
                'numpy.ndarray', its shape must be [1].
            epoch (int, None): specify current epoch. Default: None, in which case ``last_epoch`` is incremented by 1
                (it is 0 right after construction).

        Returns:
            None

        Raises:
            AssertionError: if ``metrics`` is a Tensor/ndarray whose shape is not [1].
            TypeError: if ``metrics`` is not a supported type.

        Examples:
            Please refer to the example of current LRScheduler.
        """
        if epoch is None:
            self.last_epoch = self.last_epoch + 1
        else:
            self.last_epoch = epoch

        # loss must be float, numpy.ndarray or 1-D Tensor with shape [1]
        if isinstance(metrics, (Tensor, numpy.ndarray)):
            # Format the message with ``metrics`` (the previous code referenced an
            # undefined ``loss`` name and raised NameError instead).
            assert len(metrics.shape) == 1 and metrics.shape[0] == 1, "the metrics.shape " \
                "should be (1L,), but the current metrics.shape is {}. Maybe that "  \
                "you should call paddle.mean to process it first.".format(metrics.shape)
        elif not isinstance(metrics,
                            (int, float, numpy.float32, numpy.float64)):
            raise TypeError(
                "metrics must be 'int', 'float', 'np.float', 'numpy.ndarray' or 'paddle.Tensor', but receive {}".
                format(type(metrics)))

        if self.cooldown_counter > 0:
            # Still cooling down after a reduction: skip plateau tracking.
            self.cooldown_counter -= 1
        else:
            if self.best is None or self._is_better(metrics, self.best):
                self.best = metrics
                self.num_bad_epochs = 0
            else:
                self.num_bad_epochs += 1

            if self.num_bad_epochs > self.patience:
                self.cooldown_counter = self.cooldown
                self.num_bad_epochs = 0
                new_lr = max(self.last_lr * self.factor, self.min_lr)
                # Ignore reductions not larger than epsilon (e.g. when already
                # clamped at min_lr).
                if self.last_lr - new_lr > self.epsilon:
                    self.last_lr = new_lr
                    if self.verbose:
                        print('Epoch {}: {} set learning rate to {}.'.format(
                            self.last_epoch, self.__class__.__name__,
                            self.last_lr))

    def _is_better(self, current, best):
        """Return whether ``current`` improves on ``best`` by more than the threshold."""
        if self.mode == 'min' and self.threshold_mode == 'rel':
            return current < best - best * self.threshold

        elif self.mode == 'min' and self.threshold_mode == 'abs':
            return current < best - self.threshold

        elif self.mode == 'max' and self.threshold_mode == 'rel':
            return current > best + best * self.threshold

        else:
            return current > best + self.threshold


class CosineAnnealingDecay(LRScheduler):
    r"""

    Set the learning rate using a cosine annealing schedule, where :math:`\eta_{max}` is set to
    the initial learning_rate. :math:`T_{cur}` is the number of epochs since the last restart in
    SGDR.

    The algorithm can be described as follows.

    .. math::

        \begin{aligned}
            \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
            + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
            & T_{cur} \neq (2k+1)T_{max}; \\
            \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
            \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
            & T_{cur} = (2k+1)T_{max}.
        \end{aligned}

    It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts <https://arxiv.org/abs/1608.03983>`_.
    Note that this only implements the cosine annealing part of SGDR, and not the restarts.

    Args:
        learning_rate (float): The initial learning rate, that is :math:`\eta_{max}` . It can be set to python float or int number.
        T_max (int): Maximum number of iterations. It is half of the decay cycle of learning rate. Must be a positive integer.
        eta_min (float|int, optional): Minimum learning rate, that is :math:`\eta_{min}` . Default: 0.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``CosineAnnealingDecay`` instance to schedule learning rate.

    Raises:
        TypeError: If ``T_max`` is not an ``int``, or ``eta_min`` is neither a ``float`` nor an ``int``.
        ValueError: If ``T_max`` is not a positive integer.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.5, T_max=10, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.5, T_max=10, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self,
                 learning_rate,
                 T_max,
                 eta_min=0,
                 last_epoch=-1,
                 verbose=False):
        if not isinstance(T_max, int):
            raise TypeError(
                "The type of 'T_max' in 'CosineAnnealingDecay' must be 'int', but received %s."
                % type(T_max))
        if not isinstance(eta_min, (float, int)):
            raise TypeError(
                "The type of 'eta_min' in 'CosineAnnealingDecay' must be 'float, int', but received %s."
                % type(eta_min))
        if T_max <= 0:
            # Reject early: T_max == 0 would otherwise only surface later as a
            # ZeroDivisionError from the `% (2 * self.T_max)` in get_lr(), and a
            # non-positive half-period has no meaning for this schedule.
            raise ValueError(
                "'T_max' in 'CosineAnnealingDecay' must be a positive integer, but received %s."
                % T_max)
        self.T_max = T_max
        self.eta_min = float(eta_min)
        # NOTE(review): both attributes are set before the base constructor runs,
        # presumably because it already computes an initial rate via get_lr() —
        # confirm against LRScheduler.__init__.
        super(CosineAnnealingDecay, self).__init__(learning_rate, last_epoch,
                                                   verbose)

    def get_lr(self):
        """
        Return the learning rate for ``self.last_epoch``.

        Epoch 0 returns ``self.base_lr``; later epochs are derived recursively
        from the previous rate ``self.last_lr``.
        """
        if self.last_epoch == 0:
            return self.base_lr
        elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
            # The previous epoch sat at the bottom of the cosine
            # (last_epoch - 1 == (2k+1) * T_max), where the ratio below would
            # divide by 1 + cos((2k+1) * pi) == 0; use the additive update
            # (second case of the formula) instead.
            return self.last_lr + (self.base_lr - self.eta_min) * (1 - math.cos(
                math.pi / self.T_max)) / 2

        # Rescale the previous distance to eta_min by the ratio of the current
        # and previous cosine factors.
        return (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / (
            1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) * (
                self.last_lr - self.eta_min) + self.eta_min

    def _get_closed_form_lr(self):
        """
        Return the non-recursive cosine value for ``self.last_epoch``; it
        depends only on ``base_lr``, ``eta_min``, ``T_max`` and the epoch,
        not on ``self.last_lr``.
        """
        return self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos(
            math.pi * self.last_epoch / self.T_max)) / 2