# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import numpy
import warnings
from paddle import Tensor

# Public API of this module: the learning rate scheduler classes it exports.
__all__ = [  # noqa
    'LRScheduler',
    'NoamDecay',
    'PiecewiseDecay',
    'NaturalExpDecay',
    'InverseTimeDecay',
    'PolynomialDecay',
    'LinearWarmup',
    'ExponentialDecay',
    'MultiStepDecay',
    'StepDecay',
    'LambdaDecay',
    'ReduceOnPlateau',
    'CosineAnnealingDecay',
]


class LRScheduler(object):
    """

    LRScheduler Base class. Define the common interface of a learning rate scheduler.

    User can import it by ``from paddle.optimizer.lr import LRScheduler`` ,
    then subclass it and provide a custom implementation of ``get_lr()`` .

    Otherwise, a ``NotImplementedError`` exception will be thrown (the constructor
    already calls ``get_lr()`` once through ``step()``).

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        instance to schedule learning rate.

    Examples:
        Here is an example of a simple ``StepDecay`` implementation.

        .. code-block:: python

            import paddle
            from paddle.optimizer.lr import LRScheduler

            class StepDecay(LRScheduler):
                def __init__(self,
                             learning_rate,
                             step_size,
                             gamma=0.1,
                             last_epoch=-1,
                             verbose=False):
                    if not isinstance(step_size, int):
                        raise TypeError(
                            "The type of 'step_size' must be 'int', but received %s." %
                            type(step_size))
                    if gamma >= 1.0:
                        raise ValueError('gamma should be < 1.0.')

                    self.step_size = step_size
                    self.gamma = gamma
                    super(StepDecay, self).__init__(learning_rate, last_epoch, verbose)

                def get_lr(self):
                    i = self.last_epoch // self.step_size
                    return self.base_lr * (self.gamma**i)

    """

    def __init__(self, learning_rate=0.1, last_epoch=-1, verbose=False):
        if not isinstance(learning_rate, (float, int)):
            raise TypeError(
                "The type of learning rate must be float, but received {}".
                format(type(learning_rate)))
        self.base_lr = float(learning_rate)
        self.last_lr = float(learning_rate)
        self.last_epoch = last_epoch
        self.verbose = verbose
        self._var_name = None

        # Advance once so that last_epoch moves from its initial value
        # (-1 by default) to the first epoch and last_lr holds get_lr()'s value.
        self.step()

    def __call__(self):
        """
        Return the latest computed learning rate of the current epoch.
        """
        return self.last_lr

    def step(self, epoch=None):
        """

        ``step`` should be called after ``optimizer.step`` . It will update the learning rate in optimizer according to current ``epoch`` .
        The new learning rate will take effect on next ``optimizer.step`` .

        Args:
            epoch (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1.
                When ``None``, ``last_epoch`` is incremented by one and ``get_lr()`` is used.
                Otherwise ``last_epoch`` is set to ``epoch`` and ``_get_closed_form_lr()`` is
                used if the subclass defines it, falling back to ``get_lr()``.

        Returns:
            None

        """
        if epoch is None:
            self.last_epoch += 1
            self.last_lr = self.get_lr()
        else:
            self.last_epoch = epoch
            # Jumping to an arbitrary epoch: prefer a closed-form formula when
            # the subclass provides one.
            if hasattr(self, "_get_closed_form_lr"):
                self.last_lr = self._get_closed_form_lr()
            else:
                self.last_lr = self.get_lr()

        if self.verbose:
            print('Epoch {}: {} set learning rate to {}.'.format(
                self.last_epoch, self.__class__.__name__, self.last_lr))

    def state_dict(self):
        """

        Returns the state of the scheduler as a :class:`dict`.

        It is a subset of ``self.__dict__`` restricted to the keys chosen by
        ``state_keys()``. Keys missing from ``self.__dict__`` are skipped, and
        ``Tensor`` values (which must have shape ``[1]``) are converted to
        their scalar value.
        """
        self.state_keys()
        state_dict = {}
        for key in self.keys:
            if key not in self.__dict__:
                continue
            value = self.__dict__[key]
            if isinstance(value, Tensor):
                assert value.shape == [
                    1
                ], "shape of Tensor in state_dict must be [1] {}".format(
                    value.shape)
                value = value.numpy()[0]
            state_dict[key] = value

        return state_dict

    # Subclasses of LRScheduler save "last_epoch" and "last_lr" by default.
    # NOTE: override state_keys() in a subclass to change this.
    def state_keys(self):
        """

        For subclasses of ``LRScheduler`` (Base Class), "last_epoch" and "last_lr" are saved by default through ``self.keys = ['last_epoch', 'last_lr']`` .

        ``last_epoch`` is the current epoch num, and ``last_lr`` is the current learning rate.

        If you want to change the default behavior, override ``state_keys()`` in your subclass to redefine ``self.keys`` .

        """
        self.keys = ['last_epoch', 'last_lr']

    def set_state_dict(self, state_dict):
        """

        Loads the scheduler's state.

        Every key chosen by ``state_keys()`` must be present in ``state_dict``,
        otherwise a ``RuntimeError`` is raised. A warning is emitted when
        ``state_dict`` holds more entries than those keys.
        """
        self.state_keys()
        for key in self.keys:
            if key in state_dict:
                self.__dict__[key] = state_dict[key]
            else:
                raise RuntimeError(
                    "Please check whether state_dict is correct for optimizer. Can't find [ {} ] in state_dict".
                    format(key))
        if len(state_dict) > len(self.keys):
            warnings.warn(
                "There are some unused values in state_dict. Maybe the optimizer have different 'LearningRateDecay' when invoking state_dict and set_dict"
            )

    # alias for set_state_dict
    set_dict = set_state_dict

    def get_lr(self):
        """

        For subclasses of ``LRScheduler`` (Base Class), users should provide a custom implementation of ``get_lr()`` .

        Otherwise, a ``NotImplementedError`` exception will be thrown.
        """
        # calculate by python float
        raise NotImplementedError


class NoamDecay(LRScheduler):
    r"""

    Applies Noam Decay to the initial learning rate.

    The algorithm can be described as following.

    .. math::

        new\_learning\_rate = learning\_rate * d_{model}^{-0.5} * min(epoch^{-0.5}, epoch * warmup\_steps^{-1.5})

    Please refer to `Attention Is All You Need <https://arxiv.org/pdf/1706.03762.pdf>`_ .


    Args:
        d_model (int): The dimensionality of input and output feature vector of model. It is a python int number.
        warmup_steps (int): The number of warmup steps. A hyperparameter. It is a python int number.
        learning_rate (float, optional): The initial learning rate. It is a python float number. Default: 1.0.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``NoamDecay`` instance to schedule learning rate.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.NoamDecay(d_model=0.01, warmup_steps=100, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.NoamDecay(d_model=0.01, warmup_steps=100, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(self,
                 d_model,
                 warmup_steps,
                 learning_rate=1.0,
                 last_epoch=-1,
                 verbose=False):
        # Set before the base constructor, which already calls get_lr().
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        super(NoamDecay, self).__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        # 0 ** -0.5 would raise ZeroDivisionError, so epoch 0 uses a
        # placeholder; b is 0 at epoch 0, so min(a, b) is 0 there anyway.
        if self.last_epoch == 0:
            a = 1
        else:
            a = self.last_epoch**-0.5
        b = self.warmup_steps**-1.5 * self.last_epoch
        return self.base_lr * (self.d_model**-0.5) * min(a, b)


class PiecewiseDecay(LRScheduler):
    """

    Piecewise learning rate scheduler.

    The algorithm can be described as the code below:

    .. code-block:: text

        boundaries = [100, 200]
        values = [1.0, 0.5, 0.1]
        if epoch < 100:
            learning_rate = 1.0
        elif 100 <= epoch < 200:
            learning_rate = 0.5
        else:
            learning_rate = 0.1

    Args:
        boundaries(list): A list of epoch boundaries in ascending order. The type of element in the list is python int.
        values(list): A list of learning rate values that will be picked during different epoch boundaries.
            It should contain one more element than ``boundaries`` . The type of element in the list is python float.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``PiecewiseDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[3, 6, 9], values=[0.1, 0.2, 0.3, 0.4], verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[3, 6, 9], values=[0.1, 0.2, 0.3, 0.4], verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self, boundaries, values, last_epoch=-1, verbose=False):
        self.boundaries = boundaries
        self.values = values
        # No learning_rate is passed: the base class default is used as
        # base_lr, which get_lr() below never reads.
        super(PiecewiseDecay, self).__init__(
            last_epoch=last_epoch, verbose=verbose)

    def get_lr(self):
        # Pick the value of the first boundary the current epoch has not reached.
        for index, boundary in enumerate(self.boundaries):
            if self.last_epoch < boundary:
                return self.values[index]
        # Past every boundary: use the last value.
        return self.values[-1]


class NaturalExpDecay(LRScheduler):
    r"""

    Applies natural exponential decay to the initial learning rate.

    The algorithm can be described as following:

    .. math::

        new\_learning\_rate = learning\_rate * e^{- gamma * epoch}

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float): The decay rate. A positive value makes the learning rate decay. It has no default value and must be given.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``NaturalExpDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.NaturalExpDecay(learning_rate=0.5, gamma=0.1, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.NaturalExpDecay(learning_rate=0.5, gamma=0.1, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        # Set before the base constructor, which already calls get_lr().
        self.gamma = gamma
        super(NaturalExpDecay, self).__init__(learning_rate, last_epoch,
                                              verbose)

    def get_lr(self):
        return self.base_lr * math.exp(-1 * self.gamma * self.last_epoch)


class InverseTimeDecay(LRScheduler):
    r"""

    Applies inverse time decay to the initial learning rate.

    The algorithm can be described as following:

    .. math::

        new\_learning\_rate = \frac{learning\_rate}{1 + gamma * epoch}

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float): The decay rate in ``new_lr = learning_rate / (1 + gamma * epoch)`` .
            A positive value makes the learning rate decay. It has no default value and must be given.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``InverseTimeDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.InverseTimeDecay(learning_rate=0.5, gamma=0.1, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.InverseTimeDecay(learning_rate=0.5, gamma=0.1, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        # Set before the base constructor, which already calls get_lr().
        self.gamma = gamma
        super(InverseTimeDecay, self).__init__(learning_rate, last_epoch,
                                               verbose)

    def get_lr(self):
        return self.base_lr / (1 + self.gamma * self.last_epoch)


class PolynomialDecay(LRScheduler):
    r"""

    Applies polynomial decay to the initial learning rate.

    The algorithm can be described as following.

    If cycle is set to True, then (the ceiling is taken as 1 at epoch 0):

    .. math::

        decay\_steps & = decay\_steps * \lceil \frac{epoch}{decay\_steps} \rceil

        new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\frac{epoch}{decay\_steps})^{power}+end\_lr

    If cycle is set to False, then:

    .. math::

        epoch & = min(epoch, decay\_steps)

        new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\frac{epoch}{decay\_steps})^{power}+end\_lr


    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        decay_steps(int): The decay step size. It determines the decay cycle.
        end_lr(float, optional): The minimum final learning rate. Default: 0.0001.
        power(float, optional): Power of polynomial. Default: 1.0.
        cycle(bool, optional): Whether the learning rate rises again. If True, the learning rate rises again once it has
            decreased to ``end_lr`` .  If False, the learning rate is monotonically decreasing. Default: False.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``PolynomialDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.5, decay_steps=20, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.5, decay_steps=20, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self,
                 learning_rate,
                 decay_steps,
                 end_lr=0.0001,
                 power=1.0,
                 cycle=False,
                 last_epoch=-1,
                 verbose=False):
        # Set before the base constructor, which already calls get_lr().
        self.decay_steps = decay_steps
        self.end_lr = end_lr
        self.power = power
        self.cycle = cycle
        super(PolynomialDecay, self).__init__(learning_rate, last_epoch,
                                              verbose)

    def get_lr(self):
        tmp_epoch_num = self.last_epoch
        tmp_decay_steps = self.decay_steps
        if self.cycle:
            # Stretch the decay period to the next multiple of decay_steps.
            div_res = math.ceil(
                float(self.last_epoch) / float(self.decay_steps))

            # At epoch 0 the ceiling is 0, which would make the period 0 and
            # divide by zero below; use one full period instead.
            if self.last_epoch == 0:
                div_res = 1
            tmp_decay_steps = self.decay_steps * div_res
        else:
            # Without cycling the learning rate stays at end_lr after decay_steps.
            tmp_epoch_num = min(self.last_epoch, self.decay_steps)

        return (self.base_lr - self.end_lr) * (
            (1 - float(tmp_epoch_num) / float(tmp_decay_steps)
             )**self.power) + self.end_lr


class LinearWarmup(LRScheduler):
    r"""

    Linear learning rate warm up strategy. Update the learning rate preliminarily before the normal learning rate scheduler.
    For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_

    When epoch < warmup_steps, learning rate is updated as:

    .. math::

            lr = start\_lr + (end\_lr - start\_lr) * \frac{epoch}{warmup\_steps}

    where start_lr is the initial learning rate, and end_lr is the final learning rate;

    When epoch >= warmup_steps, learning rate is updated as:

    .. math::

            lr = learning\_rate

    where ``learning_rate`` is float or any subclass of ``LRScheduler`` . When it is an ``LRScheduler`` ,
    it is stepped to ``epoch - warmup_steps`` and its learning rate is used.

    Args:
        learning_rate (float|LRScheduler): The learning rate after warm-up. It is a python float number or any subclass of ``LRScheduler`` .
        warmup_steps (int): total steps of warm up.
        start_lr (float): Initial learning rate of warm up.
        end_lr (float): Final learning rate of warm up. It must be greater than ``start_lr`` .
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``LinearWarmup`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.LinearWarmup(
                    learning_rate=0.5, warmup_steps=20, start_lr=0, end_lr=0.5, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.LinearWarmup(
                    learning_rate=0.5, warmup_steps=20, start_lr=0, end_lr=0.5, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self,
                 learning_rate,
                 warmup_steps,
                 start_lr,
                 end_lr,
                 last_epoch=-1,
                 verbose=False):
        if not isinstance(learning_rate, (float, int, LRScheduler)):
            raise TypeError(
                "the type of learning_rate should be [int, float or LRScheduler], the current type is {}".
                format(type(learning_rate)))
        self.learning_rate = learning_rate
        self.warmup_steps = warmup_steps
        self.start_lr = start_lr
        self.end_lr = end_lr
        assert end_lr > start_lr, "end_lr {} must be greater than start_lr {}".format(
            end_lr, start_lr)
        # start_lr becomes base_lr; the base constructor also type-checks it
        # and computes the first learning rate through get_lr().
        super(LinearWarmup, self).__init__(start_lr, last_epoch, verbose)

    def state_dict(self):
        """
        Returns the state of the LinearWarmup scheduler as a :class:`dict`.

        It is a subset of ``self.__dict__`` . When ``learning_rate`` is an
        ``LRScheduler`` , its own state is nested under ``"LinearWarmup_LR"`` .
        """
        state_dict = super(LinearWarmup, self).state_dict()
        if isinstance(self.learning_rate, LRScheduler):
            state_dict["LinearWarmup_LR"] = self.learning_rate.state_dict()
        return state_dict

    def set_state_dict(self, state_dict):
        """
        Loads state_dict for LinearWarmup scheduler.

        When ``learning_rate`` is an ``LRScheduler`` , the nested
        ``"LinearWarmup_LR"`` entry is loaded into it. That entry is withheld
        from the base-class loader so that it is not reported as an unused value.
        """
        if isinstance(self.learning_rate, LRScheduler):
            own_state = {
                key: value
                for key, value in state_dict.items()
                if key != "LinearWarmup_LR"
            }
            super(LinearWarmup, self).set_state_dict(own_state)
            self.learning_rate.set_state_dict(state_dict["LinearWarmup_LR"])
        else:
            super(LinearWarmup, self).set_state_dict(state_dict)

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Linear interpolation from start_lr towards end_lr.
            return (self.end_lr - self.start_lr) * float(
                self.last_epoch) / float(self.warmup_steps) + self.start_lr
        else:
            if isinstance(self.learning_rate, LRScheduler):
                # The wrapped scheduler counts epochs from the end of warm-up.
                self.learning_rate.step(self.last_epoch - self.warmup_steps)
                return self.learning_rate()

            return self.learning_rate


804
class ExponentialDecay(LRScheduler):
    r"""

    Multiply the learning rate by ``gamma`` once per epoch.

    The schedule is described by the recurrence below.

    .. math::

        new\_learning\_rate = last\_learning\_rate * gamma

    In closed form, the rate at epoch ``last_epoch`` is ``learning_rate * gamma ** last_epoch``.

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float): The multiplicative decay ratio, ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0 for the rate to decay; this class does not validate it.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``ExponentialDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.ExponentialDecay(learning_rate=0.5, gamma=0.9, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.ExponentialDecay(learning_rate=0.5, gamma=0.9, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch
    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        # gamma must be set before the parent constructor runs, since it may
        # compute the initial learning rate through get_lr().
        self.gamma = gamma
        super(ExponentialDecay, self).__init__(learning_rate, last_epoch,
                                               verbose)

    def get_lr(self):
        """Return ``base_lr * gamma ** last_epoch``."""
        decay = self.gamma**self.last_epoch
        return self.base_lr * decay


884
class MultiStepDecay(LRScheduler):
    """
    Multiply the learning rate by ``gamma`` each time ``epoch`` reaches one of the milestones.

    The schedule behaves like the snippet below.

    .. code-block:: text

        learning_rate = 0.5
        milestones = [30, 50]
        gamma = 0.1
        if epoch < 30:
            learning_rate = 0.5
        elif epoch < 50:
            learning_rate = 0.05
        else:
            learning_rate = 0.005

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        milestones (tuple|list): Epoch boundaries at which the rate is decayed. Must be strictly increasing.
        gamma (float, optional): The decay ratio, ``new_lr = origin_lr * gamma`` .
            Must be less than 1.0. Default: 0.1.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Raises:
        TypeError: If ``milestones`` is neither a tuple nor a list.
        ValueError: If ``milestones`` is not strictly increasing, or ``gamma >= 1.0``.

    Returns:
        ``MultiStepDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch
    """

    def __init__(self,
                 learning_rate,
                 milestones,
                 gamma=0.1,
                 last_epoch=-1,
                 verbose=False):
        if not isinstance(milestones, (tuple, list)):
            raise TypeError(
                "The type of 'milestones' in 'MultiStepDecay' must be 'tuple, list', but received %s."
                % type(milestones))

        # Compare every adjacent pair eagerly (a list, not a generator), so all
        # comparisons are evaluated exactly as before.
        pair_ok = [
            earlier < later
            for earlier, later in zip(milestones, milestones[1:])
        ]
        if not all(pair_ok):
            raise ValueError('The elements of milestones must be incremented')
        if gamma >= 1.0:
            raise ValueError('gamma should be < 1.0.')

        self.milestones = milestones
        self.gamma = gamma
        super(MultiStepDecay, self).__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        """Return ``base_lr`` decayed once per milestone already reached by ``last_epoch``."""
        # Default: every milestone has been passed.
        passed = len(self.milestones)
        for idx, boundary in enumerate(self.milestones):
            if self.last_epoch < boundary:
                passed = idx
                break
        return self.base_lr * (self.gamma**passed)


994
class StepDecay(LRScheduler):
    """
    Update the learning rate of ``optimizer`` by ``gamma`` every ``step_size`` number of epoch.

    The algorithm can be described as the code below.

    .. code-block:: text

        learning_rate = 0.5
        step_size = 30
        gamma = 0.1

        learning_rate = 0.5     if epoch < 30
        learning_rate = 0.05    if 30 <= epoch < 60
        learning_rate = 0.005   if 60 <= epoch < 90
        ...

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        step_size (int): The interval (in epochs) between two decays. Must be a positive integer.
        gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Raises:
        TypeError: If ``step_size`` is not an ``int``.
        ValueError: If ``step_size`` is not positive, or ``gamma >= 1.0``.

    Returns:
        ``StepDecay`` instance to schedule learning rate.


    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=5, gamma=0.8, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=5, gamma=0.8, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch
    """

    def __init__(self,
                 learning_rate,
                 step_size,
                 gamma=0.1,
                 last_epoch=-1,
                 verbose=False):
        if not isinstance(step_size, int):
            raise TypeError(
                "The type of 'step_size' must be 'int', but received %s." %
                type(step_size))
        if step_size <= 0:
            # get_lr divides by step_size: zero would raise ZeroDivisionError and
            # a negative value would yield negative exponents (a growing rate).
            raise ValueError(
                "The 'step_size' must be a positive integer, but received %d."
                % step_size)
        if gamma >= 1.0:
            raise ValueError('gamma should be < 1.0.')

        self.step_size = step_size
        self.gamma = gamma
        super(StepDecay, self).__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        """Return ``base_lr * gamma ** (last_epoch // step_size)``."""
        decays = self.last_epoch // self.step_size
        return self.base_lr * (self.gamma**decays)


1095
class LambdaDecay(LRScheduler):
    """
    Set the learning rate of ``optimizer`` to ``learning_rate * lr_lambda(epoch)``, where ``lr_lambda`` is a function of ``epoch`` .

    The schedule behaves like the snippet below.

    .. code-block:: text

        learning_rate = 0.5        # init learning_rate
        lr_lambda = lambda epoch: 0.95 ** epoch

        learning_rate = 0.5        # epoch 0, 0.5*0.95**0
        learning_rate = 0.475      # epoch 1, 0.5*0.95**1
        learning_rate = 0.45125    # epoch 2, 0.5*0.95**2

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        lr_lambda (function): A callable mapping ``epoch`` to a factor; the initial learning rate is multiplied by that factor.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Raises:
        TypeError: If ``lr_lambda`` is not callable.

    Returns:
        ``LambdaDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.LambdaDecay(learning_rate=0.5, lr_lambda=lambda x:0.95**x, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.LambdaDecay(learning_rate=0.5, lr_lambda=lambda x:0.95**x, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch

    """

    def __init__(self, learning_rate, lr_lambda, last_epoch=-1, verbose=False):
        if callable(lr_lambda):
            self.lr_lambda = lr_lambda
        else:
            raise TypeError(
                "The type of 'lr_lambda' in 'LambdaDecay' must be 'function', but received %s."
                % type(lr_lambda))

        super(LambdaDecay, self).__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        """Return ``base_lr`` scaled by ``lr_lambda(last_epoch)``."""
        factor = self.lr_lambda(self.last_epoch)
        return self.base_lr * factor


1183
class ReduceOnPlateau(LRScheduler):
    """
    Reduce the learning rate when ``metrics`` has stopped descending. Models often benefit from reducing the learning rate
    by 2 to 10 times once model performance no longer improves.

    ``metrics`` is the value passed into ``step`` ; it must be a python/numpy number, or a 1-D Tensor / numpy.ndarray
    with shape [1]. When ``metrics`` stops descending for more than ``patience`` consecutive calls, the learning rate
    is reduced to ``learning_rate * factor`` . (Specially, ``mode`` can also be set to ``'max'`` , in which case the
    learning rate is reduced when ``metrics`` stops ascending.)

    In addition, after each reduction, it will wait ``cooldown`` calls before resuming the above operation.

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        mode (str, optional): ``'min'`` or ``'max'`` can be selected. Normally, it is ``'min'`` , which means that the
            learning rate will reduce when ``loss`` stops descending. Specially, if it's set to ``'max'`` ,  the learning
            rate will reduce when ``loss`` stops ascending. Default: ``'min'`` .
        factor (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * factor`` .
            It should be less than 1.0. Default: 0.1.
        patience (int, optional): The number of consecutive non-improving calls tolerated; once this count is
            exceeded, the learning rate is reduced. Default: 10.
        threshold (float, optional): ``threshold`` and ``threshold_mode`` will determine the minimum change of ``loss`` .
            This makes tiny changes of ``loss`` be ignored. Default: 1e-4.
        threshold_mode (str, optional): ``'rel'`` or ``'abs'`` can be selected. In ``'rel'`` mode, the minimum change of ``loss``
            is ``best * threshold`` , where ``best`` is the best ``loss`` seen so far. In ``'abs'`` mode, the minimum
            change of ``loss`` is ``threshold`` . Default: ``'rel'`` .
        cooldown (int, optional): The number of calls to wait after a reduction before resuming normal operation. Default: 0.
        min_lr (float, optional): The lower bound of the learning rate after reduction. Default: 0.
        epsilon (float, optional): Minimal decay applied to lr. If the difference between new and old lr is not larger
            than epsilon, the update is ignored. Default: 1e-8.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``.

    Raises:
        ValueError: If ``mode`` or ``threshold_mode`` is unknown, or ``factor >= 1.0``.
        TypeError: If ``learning_rate`` is not a ``float`` or ``int``.

    Returns:
        ``ReduceOnPlateau`` instance to schedule learning rate.


    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.ReduceOnPlateau(learning_rate=1.0, factor=0.5, patience=5, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step(loss)    # If you update learning rate each step
                # scheduler.step(loss)      # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.ReduceOnPlateau(learning_rate=1.0, factor=0.5, patience=5, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step(out[0])    # If you update learning rate each step
                # scheduler.step(out[0])      # If you update learning rate each epoch

    """

    def __init__(self,
                 learning_rate,
                 mode='min',
                 factor=0.1,
                 patience=10,
                 threshold=1e-4,
                 threshold_mode='rel',
                 cooldown=0,
                 min_lr=0,
                 epsilon=1e-8,
                 verbose=False):
        mode = mode.lower()
        if mode not in ['min', 'max']:
            raise ValueError('mode: ' + mode + ' is unknown!')
        self.mode = mode

        if factor >= 1.0:
            raise ValueError(
                'new_lr = origin_lr * factor and factor should be < 1.0.')
        self.factor = factor

        threshold_mode = threshold_mode.lower()
        if threshold_mode not in ['rel', 'abs']:
            raise ValueError('threshold mode: ' + threshold_mode +
                             ' is unknown!')
        self.threshold_mode = threshold_mode
        if not isinstance(learning_rate, (float, int)):
            raise TypeError(
                "The type of 'learning_rate' in 'ReduceOnPlateau' must be 'float', but received %s."
                % type(learning_rate))

        self.patience = patience
        self.threshold = threshold
        self.cooldown = cooldown
        self.min_lr = min_lr
        self.epsilon = epsilon

        # Plateau-tracking state, updated by step().
        self.cooldown_counter = 0
        self.best = None
        self.num_bad_epochs = 0

        # Can not call Parent __init__, so implement here.
        self.base_lr = float(learning_rate)
        self.last_lr = float(learning_rate)
        self.last_epoch = 0
        self.verbose = verbose
        self._var_name = None

    # "cooldown_counter / best / num_bad_epochs / last_epoch / last_lr" will be stored.
    def state_keys(self):
        self.keys = [
            'cooldown_counter', 'best', 'num_bad_epochs', 'last_epoch',
            'last_lr'
        ]

    def step(self, metrics, epoch=None):
        """
        step should be called after `optimizer.step()` . It will update the learning rate in optimizer according to ``metrics`` .
        The new learning rate will take effect on next epoch.

        Args:
            metrics (Tensor|numpy.ndarray|float): Which will be monitored to determine whether the learning rate will reduce.
                If it stops descending for more than ``patience`` calls, the learning rate will reduce. If it's 'Tensor' or
                'numpy.ndarray', its shape must be [1].
            epoch (int, None): specify current epoch. Default: None, which increments ``last_epoch`` by 1.

        Returns:
            None

        Raises:
            AssertionError: If a Tensor/ndarray ``metrics`` does not have shape [1].
            TypeError: If ``metrics`` has an unsupported type.

        Examples:
            Please refer to the example of current LRScheduler.
        """
        if epoch is None:
            self.last_epoch = self.last_epoch + 1
        else:
            self.last_epoch = epoch

        # loss must be float, numpy.ndarray or 1-D Tensor with shape [1]
        if isinstance(metrics, (Tensor, numpy.ndarray)):
            # The message must format ``metrics.shape`` (the value being checked);
            # referencing any other name would raise NameError instead.
            assert len(metrics.shape) == 1 and metrics.shape[0] == 1, "the metrics.shape " \
                "should be (1L,), but the current metrics.shape is {}. Maybe that "  \
                "you should call paddle.mean to process it first.".format(metrics.shape)
        elif not isinstance(metrics,
                            (int, float, numpy.float32, numpy.float64)):
            raise TypeError(
                "metrics must be 'int', 'float', 'np.float', 'numpy.ndarray' or 'paddle.Tensor', but receive {}".
                format(type(metrics)))

        if self.cooldown_counter > 0:
            # Still cooling down after a reduction: do not track improvement.
            self.cooldown_counter -= 1
        else:
            if self.best is None or self._is_better(metrics, self.best):
                self.best = metrics
                self.num_bad_epochs = 0
            else:
                self.num_bad_epochs += 1

            if self.num_bad_epochs > self.patience:
                self.cooldown_counter = self.cooldown
                self.num_bad_epochs = 0
                new_lr = max(self.last_lr * self.factor, self.min_lr)
                # Skip negligible reductions (e.g. once min_lr has been reached).
                if self.last_lr - new_lr > self.epsilon:
                    self.last_lr = new_lr
                    if self.verbose:
                        print('Epoch {}: {} set learning rate to {}.'.format(
                            self.last_epoch, self.__class__.__name__,
                            self.last_lr))

    def _is_better(self, current, best):
        """Return whether ``current`` improves on ``best`` by more than the configured threshold."""
        if self.mode == 'min' and self.threshold_mode == 'rel':
            return current < best - best * self.threshold

        elif self.mode == 'min' and self.threshold_mode == 'abs':
            return current < best - self.threshold

        elif self.mode == 'max' and self.threshold_mode == 'rel':
            return current > best + best * self.threshold

        else:
            return current > best + self.threshold


1393
class CosineAnnealingDecay(LRScheduler):
    r"""

    Set the learning rate using a cosine annealing schedule, where :math:`\eta_{max}` is set to
    the initial learning_rate. :math:`T_{cur}` is the number of epochs since the last restart in
    SGDR.

    The algorithm can be described as follows.

    .. math::

        \begin{aligned}
            \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
            + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
            & T_{cur} \neq (2k+1)T_{max}; \\
            \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
            \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
            & T_{cur} = (2k+1)T_{max}.
        \end{aligned}

    It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts <https://arxiv.org/abs/1608.03983>`_.
    Note that this only implements the cosine annealing part of SGDR, and not the restarts.

    Args:
        learning_rate (float): The initial learning rate, that is :math:`\eta_{max}` . It can be set to python float or int number.
        T_max (int): Maximum number of iterations. It is half of the decay cycle of learning rate. Must be a positive integer.
        eta_min (float|int, optional): Minimum learning rate, that is :math:`\eta_{min}` . Default: 0.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``CosineAnnealingDecay`` instance to schedule learning rate.

    Raises:
        TypeError: If ``T_max`` is not an int, or ``eta_min`` is not a float or int.
        ValueError: If ``T_max`` is not a positive integer.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.5, T_max=10, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.5, T_max=10, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch
    """

    def __init__(self,
                 learning_rate,
                 T_max,
                 eta_min=0,
                 last_epoch=-1,
                 verbose=False):
        if not isinstance(T_max, int):
            raise TypeError(
                "The type of 'T_max' in 'CosineAnnealingDecay' must be 'int', but received %s."
                % type(T_max))
        # get_lr() computes `% (2 * T_max)` and divides by T_max, so a zero
        # T_max would only fail later with ZeroDivisionError; reject it here.
        if T_max <= 0:
            raise ValueError(
                "'T_max' in 'CosineAnnealingDecay' must be a positive integer, but received %s."
                % T_max)
        if not isinstance(eta_min, (float, int)):
            raise TypeError(
                "The type of 'eta_min' in 'CosineAnnealingDecay' must be 'float, int', but received %s."
                % type(eta_min))
        # Set before calling the base constructor, which may already compute
        # the first learning rate via get_lr() — confirm against LRScheduler.
        self.T_max = T_max
        self.eta_min = float(eta_min)
        super(CosineAnnealingDecay, self).__init__(learning_rate, last_epoch,
                                                   verbose)

    def get_lr(self):
        """Return the learning rate for ``self.last_epoch``.

        The value is derived recursively from the previous rate ``self.last_lr``
        rather than from the closed form (see ``_get_closed_form_lr``).
        """
        if self.last_epoch == 0:
            return self.base_lr
        elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
            # The previous epoch was an odd multiple of T_max, where
            # cos(pi * (last_epoch - 1) / T_max) == -1. The ratio form below
            # would divide by zero there, so apply the additive update instead.
            return self.last_lr + (self.base_lr - self.eta_min) * (1 - math.cos(
                math.pi / self.T_max)) / 2

        # Scale the previous distance to eta_min by the ratio of consecutive
        # cosine factors (1 + cos(pi*t/T_max)) / (1 + cos(pi*(t-1)/T_max)).
        return (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / (
            1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) * (
                self.last_lr - self.eta_min) + self.eta_min

    def _get_closed_form_lr(self):
        """Return the closed-form cosine-annealed rate for ``self.last_epoch``.

        Computes ``eta_min + (base_lr - eta_min) * (1 + cos(pi * last_epoch / T_max)) / 2``.
        NOTE(review): presumably called by the base class when an explicit
        epoch is passed to ``step()`` — confirm against LRScheduler.
        """
        return self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos(
            math.pi * self.last_epoch / self.T_max)) / 2