lr.py 74.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import numpy
import warnings
from paddle import Tensor
19
import paddle.fluid.core as core
J
Jiabin Yang 已提交
20
from ..fluid.framework import _in_legacy_dygraph
21

G
guguguzi 已提交
22
# Public scheduler classes exported by ``from paddle.optimizer.lr import *``.
__all__ = [  # noqa
    'LRScheduler',
    'NoamDecay',
    'PiecewiseDecay',
    'NaturalExpDecay',
    'InverseTimeDecay',
    'PolynomialDecay',
    'LinearWarmup',
    'ExponentialDecay',
    'MultiStepDecay',
    'StepDecay',
    'LambdaDecay',
    'ReduceOnPlateau',
    'CosineAnnealingDecay',
    'MultiplicativeDecay',
    'OneCycleLR',
]


class LRScheduler(object):
    """

    LRScheduler Base class. Define the common interface of a learning rate scheduler.

    Users can import it by ``from paddle.optimizer.lr import LRScheduler`` ,

    then overload it for your subclass and have a custom implementation of ``get_lr()`` .

    Otherwise, a ``NotImplementedError`` exception will be thrown.

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        instance to schedule learning rate.

    Examples:
        Here is an example of a simple ``StepDecay`` implementation.

        .. code-block:: python

            import paddle
            from paddle.optimizer.lr import LRScheduler

            class StepDecay(LRScheduler):
                def __init__(self,
                            learning_rate,
                            step_size,
                            gamma=0.1,
                            last_epoch=-1,
                            verbose=False):
                    if not isinstance(step_size, int):
                        raise TypeError(
                            "The type of 'step_size' must be 'int', but received %s." %
                            type(step_size))
                    if gamma >= 1.0:
                        raise ValueError('gamma should be < 1.0.')

                    self.step_size = step_size
                    self.gamma = gamma
                    super(StepDecay, self).__init__(learning_rate, last_epoch, verbose)

                def get_lr(self):
                    i = self.last_epoch // self.step_size
                    return self.base_lr * (self.gamma**i)

    """

    def __init__(self, learning_rate=0.1, last_epoch=-1, verbose=False):
        if not isinstance(learning_rate, (float, int)):
            raise TypeError(
                "The type of learning rate must be float, but received {}".
                format(type(learning_rate)))
        self.base_lr = float(learning_rate)
        self.last_lr = float(learning_rate)
        self.last_epoch = last_epoch
        self.verbose = verbose
        self._var_name = None

        # Advance once from ``last_epoch`` so that ``last_lr`` already holds
        # the rate of the first epoch (epoch 0 when last_epoch is -1).
        self.step()

    def __call__(self):
        """
        Return the latest computed learning rate of the current epoch.
        """
        return self.last_lr

    def step(self, epoch=None):
        """

        ``step`` should be called after ``optimizer.step`` . It will update the learning rate in optimizer according to current ``epoch`` .
        The new learning rate will take effect on next ``optimizer.step`` .

        Args:
            epoch (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1.

        Returns:
            None

        """
        if epoch is None:
            self.last_epoch += 1
            self.last_lr = self.get_lr()
        else:
            self.last_epoch = epoch
            # Subclasses may provide a closed-form formula that is valid for an
            # arbitrary (non-sequential) epoch; prefer it when jumping.
            if hasattr(self, "_get_closed_form_lr"):
                self.last_lr = self._get_closed_form_lr()
            else:
                self.last_lr = self.get_lr()

        if self.verbose:
            print('Epoch {}: {} set learning rate to {}.'.format(
                self.last_epoch, self.__class__.__name__, self.last_lr))

    def state_dict(self):
        """

        Returns the state of the scheduler as a :class:`dict`.

        It is a subset of ``self.__dict__`` restricted to the keys selected by
        ``state_keys()`` . A ``Tensor`` value must have shape ``[1]`` and is
        stored as its single element.
        """
        self.state_keys()
        state_dict = {}
        for key in self.keys:
            # Keys listed by state_keys() but never assigned are skipped.
            if key not in self.__dict__:
                continue
            value = self.__dict__[key]
            if isinstance(value, Tensor):
                assert value.shape == [
                    1
                ], "shape of Tensor in state_dict must be [1] {}".format(
                    value.shape)
                value = value.numpy()[0]
            state_dict[key] = value

        return state_dict

    # For those subclass who overload LRScheduler, "last_epoch, last_lr" will be saved by default.
    # (Note): you can change it for your subclass.
    def state_keys(self):
        """

        For those subclass who overload ``LRScheduler`` (Base Class). By default, "last_epoch, last_lr" will be saved by ``self.keys = ['last_epoch', 'last_lr']`` .

        ``last_epoch`` is the current epoch num, and ``last_lr`` is the current learning rate.

        If you want to change the default behavior, you should have a custom implementation of ``state_keys()`` to redefine ``self.keys`` .

        """
        self.keys = ['last_epoch', 'last_lr']

    def set_state_dict(self, state_dict):
        """

        Loads the scheduler's state.

        Every key selected by ``state_keys()`` must be present in
        ``state_dict`` , otherwise ``RuntimeError`` is raised. Extra entries
        only trigger a warning.
        """
        self.state_keys()
        for key in self.keys:
            if key in state_dict:
                self.__dict__[key] = state_dict[key]
            else:
                raise RuntimeError(
                    "Please check whether state_dict is correct for optimizer. Can't find [ {} ] in state_dict".
                    format(key))
        if len(state_dict) > len(self.keys):
            warnings.warn(
                "There are some unused values in state_dict. Maybe the optimizer have different 'LearningRateDecay' when invoking state_dict and set_dict"
            )

    # alias for set_state_dict
    set_dict = set_state_dict

    def get_lr(self):
        """

        For those subclass who overload ``LRScheduler`` (Base Class), User should have a custom implementation of ``get_lr()`` .

        Otherwise, a ``NotImplementedError`` exception will be thrown.
        """
        # calculate by python float
        raise NotImplementedError


207
class NoamDecay(LRScheduler):
    r"""

    Applies Noam Decay to the initial learning rate.

    The algorithm can be described as following.

    .. math::

        new\_learning\_rate = learning\_rate * d_{model}^{-0.5} * min(epoch^{-0.5}, epoch * warmup\_steps^{-1.5})

    Please refer to `attention is all you need <https://arxiv.org/pdf/1706.03762.pdf>`_


    Args:
        d_model (int|float): The dimensionality of input and output feature vector of model. It is a python int or float number.
        warmup_steps (int): The number of warmup steps. A super parameter. It is a python int number.
        learning_rate (float): The initial learning rate. It is a python float number. Default: 1.0.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``NoamDecay`` instance to schedule learning rate.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.NoamDecay(d_model=0.01, warmup_steps=100, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.NoamDecay(d_model=0.01, warmup_steps=100, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(self,
                 d_model,
                 warmup_steps,
                 learning_rate=1.0,
                 last_epoch=-1,
                 verbose=False):
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        super(NoamDecay, self).__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        # 0 ** -0.5 raises ZeroDivisionError, so use a placeholder at epoch 0;
        # the warmup term ``b`` is 0 there anyway, so min(a, b) is 0.
        if self.last_epoch == 0:
            a = 1
        else:
            a = self.last_epoch**-0.5
        b = self.warmup_steps**-1.5 * self.last_epoch
        return self.base_lr * (self.d_model**-0.5) * min(a, b)


300
class PiecewiseDecay(LRScheduler):
    """

    Piecewise learning rate scheduler.

    The algorithm can be described as the code below:

    .. code-block:: text

        boundaries = [100, 200]
        values = [1.0, 0.5, 0.1]
        if epoch < 100:
            learning_rate = 1.0
        elif 100 <= epoch < 200:
            learning_rate = 0.5
        else:
            learning_rate = 0.1

    Args:
        boundaries(list|tuple): A list/tuple of steps numbers. The type of element in the list is python int.
        values(list|tuple): A list/tuple of learning rate values that will be picked during different epoch boundaries.
            The type of element in the list is python float. It should contain one more element than ``boundaries`` .
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``PiecewiseDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[3, 6, 9], values=[0.1, 0.2, 0.3, 0.4], verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[3, 6, 9], values=[0.1, 0.2, 0.3, 0.4], verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self, boundaries, values, last_epoch=-1, verbose=False):
        self.boundaries = boundaries
        self.values = values
        # The base-class learning_rate is left at its default: get_lr() reads
        # only ``values`` and never ``base_lr``.
        super(PiecewiseDecay, self).__init__(
            last_epoch=last_epoch, verbose=verbose)

    def get_lr(self):
        # Use the value of the first interval whose upper boundary has not
        # been reached yet; from the last boundary onward use the final value.
        for i, boundary in enumerate(self.boundaries):
            if self.last_epoch < boundary:
                return self.values[i]
        return self.values[-1]


391
class NaturalExpDecay(LRScheduler):
    r"""

    Applies natural exponential decay to the initial learning rate.

    The algorithm can be described as following:

    .. math::

        new\_learning\_rate = learning\_rate * e^{- gamma * epoch}

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float): A ratio to update the learning rate. It must be greater than 0.0 so that the learning rate decays.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``NaturalExpDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.NaturalExpDecay(learning_rate=0.5, gamma=0.1, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.NaturalExpDecay(learning_rate=0.5, gamma=0.1, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        assert gamma > 0.0, " 'gamma' must be a positive number so that the learning rate will decay."
        self.gamma = gamma
        super(NaturalExpDecay, self).__init__(learning_rate, last_epoch,
                                              verbose)

    def get_lr(self):
        return self.base_lr * math.exp(-1 * self.gamma * self.last_epoch)


471
class InverseTimeDecay(LRScheduler):
    r"""

    Applies inverse time decay to the initial learning rate.

    The algorithm can be described as following:

    .. math::

        new\_learning\_rate = \frac{learning\_rate}{1 + gamma * epoch}

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float): The decay ratio. The learning rate at ``epoch`` is ``learning_rate / (1 + gamma * epoch)`` .
            It should be greater than 0.0 to make the learning rate decay.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``InverseTimeDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.InverseTimeDecay(learning_rate=0.5, gamma=0.1, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.InverseTimeDecay(learning_rate=0.5, gamma=0.1, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        self.gamma = gamma
        super(InverseTimeDecay, self).__init__(learning_rate, last_epoch,
                                               verbose)

    def get_lr(self):
        return self.base_lr / (1 + self.gamma * self.last_epoch)


552
class PolynomialDecay(LRScheduler):
    r"""

    Applies polynomial decay to the initial learning rate.

    The algorithm can be described as following.

    If cycle is set to True, then:

    .. math::

        decay\_steps & = decay\_steps * \lceil \frac{epoch}{decay\_steps} \rceil

        new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\frac{epoch}{decay\_steps})^{power}+end\_lr

    (At ``epoch`` 0 the ceiling term is taken as 1.)

    If cycle is set to False, then:

    .. math::

        epoch & = min(epoch, decay\_steps)

        new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\frac{epoch}{decay\_steps})^{power}+end\_lr


    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        decay_steps(int): The decay step size. It determines the decay cycle. It must be a positive integer.
        end_lr(float, optional): The minimum final learning rate. Default: 0.0001.
        power(float, optional): Power of polynomial, should greater than 0.0 to get learning rate decay. Default: 1.0.
        cycle(bool, optional): Whether the learning rate rises again. If True, then the learning rate will rise when it decrease
            to ``end_lr`` .  If False, the learning rate is monotone decreasing. Default: False.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``PolynomialDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.5, decay_steps=20, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.5, decay_steps=20, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self,
                 learning_rate,
                 decay_steps,
                 end_lr=0.0001,
                 power=1.0,
                 cycle=False,
                 last_epoch=-1,
                 verbose=False):
        # Check the type first so that a non-numeric value fails this
        # assertion instead of raising TypeError from the ``>`` comparison.
        assert isinstance(decay_steps, int) and decay_steps > 0, (
            " 'decay_steps' must be a positive integer.")
        self.decay_steps = decay_steps
        self.end_lr = end_lr
        assert power > 0.0, " 'power' must be greater than 0.0 so that the learning rate will decay."
        self.power = power
        self.cycle = cycle
        super(PolynomialDecay, self).__init__(learning_rate, last_epoch,
                                              verbose)

    def get_lr(self):
        tmp_epoch_num = self.last_epoch
        tmp_decay_steps = self.decay_steps
        if self.cycle:
            # Stretch the decay period to the smallest multiple of decay_steps
            # that covers the current epoch, so the rate rises again once it
            # has reached end_lr.
            div_res = math.ceil(
                float(self.last_epoch) / float(self.decay_steps))

            # ceil(0 / decay_steps) is 0, which would make the period zero and
            # divide by zero below; use a single period instead.
            if self.last_epoch == 0:
                div_res = 1
            tmp_decay_steps = self.decay_steps * div_res
        else:
            # Without cycling the rate is held at end_lr after decay_steps.
            tmp_epoch_num = min(self.last_epoch, self.decay_steps)

        return (self.base_lr - self.end_lr) * (
            (1 - float(tmp_epoch_num) / float(tmp_decay_steps)
             )**self.power) + self.end_lr


675
class LinearWarmup(LRScheduler):
    r"""

    Linear learning rate warm up strategy. Update the learning rate preliminarily before the normal learning rate scheduler.
    For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_

    When epoch < warmup_steps, learning rate is updated as:

    .. math::

            lr = start\_lr + (end\_lr - start\_lr) * \frac{epoch}{warmup\_steps}

    where start_lr is the initial learning rate, and end_lr is the final learning rate;

    When epoch >= warmup_steps, learning rate is updated as:

    .. math::

            lr = learning\_rate

    where ``learning_rate`` is a number or any subclass of ``LRScheduler`` . In the latter case the wrapped
    scheduler is stepped with ``epoch - warmup_steps`` and its current value is used.

    Args:
        learning_rate (int|float|LRScheduler): The learning rate after warm-up. It is a python int/float number or any subclass of ``LRScheduler`` .
        warmup_steps (int): total steps of warm up. It must be a positive integer.
        start_lr (float): Initial learning rate of warm up.
        end_lr (float): Final learning rate of warm up. It must be greater than ``start_lr`` .
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Raises:
        TypeError: If ``learning_rate`` is not an int, float or ``LRScheduler`` .
        AssertionError: If ``warmup_steps`` is not a positive integer, or ``end_lr`` is not greater than ``start_lr`` .

    Returns:
        ``LinearWarmup`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.LinearWarmup(
                    learning_rate=0.5, warmup_steps=20, start_lr=0, end_lr=0.5, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.LinearWarmup(
                    learning_rate=0.5, warmup_steps=20, start_lr=0, end_lr=0.5, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self,
                 learning_rate,
                 warmup_steps,
                 start_lr,
                 end_lr,
                 last_epoch=-1,
                 verbose=False):
        if not isinstance(learning_rate, (float, int, LRScheduler)):
            # Report the offending *type*, as the message promises, not the value.
            raise TypeError(
                "the type of learning_rate should be [int, float or LRScheduler], the current type is {}".
                format(type(learning_rate)))
        self.learning_rate = learning_rate
        # Check the type first so a non-integer trips the assertion message
        # instead of raising a TypeError from the comparison.
        assert isinstance(warmup_steps, int) and warmup_steps > 0, \
            " 'warmup_steps' must be a positive integer."
        self.warmup_steps = warmup_steps
        self.start_lr = start_lr
        self.end_lr = end_lr
        assert end_lr > start_lr, "end_lr {} must be greater than start_lr {}".format(
            end_lr, start_lr)
        # The base scheduler starts from ``start_lr``; the post-warm-up value
        # comes from ``self.learning_rate`` in get_lr().
        super(LinearWarmup, self).__init__(start_lr, last_epoch, verbose)

    def state_dict(self):
        """
        Returns the state of the LinearWarmup scheduler as a :class:`dict`.

        It is a subset of ``self.__dict__`` . When the post-warm-up learning rate is
        itself an ``LRScheduler``, its state is nested under the ``"LinearWarmup_LR"`` key.
        """
        state_dict = super(LinearWarmup, self).state_dict()
        if isinstance(self.learning_rate, LRScheduler):
            state_dict["LinearWarmup_LR"] = self.learning_rate.state_dict()
        return state_dict

    def set_state_dict(self, state_dict):
        """
        Loads state_dict for LinearWarmup scheduler, including the nested state of a
        wrapped ``LRScheduler`` (stored under ``"LinearWarmup_LR"``).
        """
        super(LinearWarmup, self).set_state_dict(state_dict)
        if isinstance(self.learning_rate, LRScheduler):
            self.learning_rate.set_state_dict(state_dict["LinearWarmup_LR"])

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            # Linear interpolation from start_lr towards end_lr.
            return (self.end_lr - self.start_lr) * float(
                self.last_epoch) / float(self.warmup_steps) + self.start_lr
        else:
            if isinstance(self.learning_rate, LRScheduler):
                # Drive the wrapped scheduler with the epochs elapsed since warm-up ended.
                self.learning_rate.step(self.last_epoch - self.warmup_steps)
                return self.learning_rate()

            return self.learning_rate


814
class ExponentialDecay(LRScheduler):
    r"""

    Exponential decay: the learning rate is multiplied by ``gamma`` once per epoch.

    Each update applies the rule below.

    .. math::

        new\_learning\_rate = last\_learning\_rate * gamma

    Args:
        learning_rate (float): The starting learning rate, as a python float number.
        gamma (float): Multiplicative decay ratio, ``new_lr = origin_lr * gamma`` .
            It must lie in the open interval (0.0, 1.0).
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Raises:
        AssertionError: If ``gamma`` is not in (0.0, 1.0).

    Returns:
        ``ExponentialDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.ExponentialDecay(learning_rate=0.5, gamma=0.9, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.ExponentialDecay(learning_rate=0.5, gamma=0.9, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        assert 0.0 < gamma < 1.0, " 'gamma' must be in interval (0.0, 1.0) so that the learning rate will decay."
        self.gamma = gamma
        super(ExponentialDecay, self).__init__(learning_rate, last_epoch,
                                               verbose)

    def get_lr(self):
        # Closed form of multiplying by gamma once for every elapsed epoch.
        decay_factor = self.gamma**self.last_epoch
        return self.base_lr * decay_factor


895
class MultiStepDecay(LRScheduler):
    """
    Multiply the learning rate by ``gamma`` each time ``epoch`` reaches one of the milestones.

    The schedule is equivalent to the following code.

    .. code-block:: text

        learning_rate = 0.5
        milestones = [30, 50]
        gamma = 0.1
        if epoch < 30:
            learning_rate = 0.5
        elif epoch < 50:
            learning_rate = 0.05
        else:
            learning_rate = 0.005

    Args:
        learning_rate (float): The starting learning rate, as a python float number.
        milestones (tuple|list): Epoch boundaries at which the rate decays. Must be strictly increasing.
        gamma (float, optional): Multiplicative decay ratio, ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Raises:
        TypeError: If ``milestones`` is neither a tuple nor a list.
        ValueError: If ``milestones`` is not strictly increasing, or ``gamma`` >= 1.0.

    Returns:
        ``MultiStepDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self,
                 learning_rate,
                 milestones,
                 gamma=0.1,
                 last_epoch=-1,
                 verbose=False):
        if not isinstance(milestones, (tuple, list)):
            raise TypeError(
                "The type of 'milestones' in 'MultiStepDecay' must be 'tuple, list', but received %s."
                % type(milestones))

        # Compare every neighbouring pair (eagerly, as a list).
        pairwise_ok = [
            earlier < later
            for earlier, later in zip(milestones, milestones[1:])
        ]
        if not all(pairwise_ok):
            raise ValueError('The elements of milestones must be incremented')
        if gamma >= 1.0:
            raise ValueError('gamma should be < 1.0.')

        self.milestones = milestones
        self.gamma = gamma
        super(MultiStepDecay, self).__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        # Index of the first milestone not yet reached; if all have been
        # passed, every milestone contributes one factor of gamma.
        exponent = next((idx for idx, boundary in enumerate(self.milestones)
                         if self.last_epoch < boundary), len(self.milestones))
        return self.base_lr * (self.gamma**exponent)


1005
class StepDecay(LRScheduler):
    """
    Multiply the learning rate of ``optimizer`` by ``gamma`` after every ``step_size`` epochs.

    The schedule is equivalent to the following code.

    .. code-block:: text

        learning_rate = 0.5
        step_size = 30
        gamma = 0.1

        learning_rate = 0.5     if epoch < 30
        learning_rate = 0.05    if 30 <= epoch < 60
        learning_rate = 0.005   if 60 <= epoch < 90
        ...

    Args:
        learning_rate (float): The starting learning rate, as a python float number.
        step_size (int): Number of epochs between two decays. It must be a positive integer.
        gamma (float, optional): Multiplicative decay ratio, ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Raises:
        TypeError: If ``step_size`` is not an int.
        ValueError: If ``gamma`` >= 1.0.
        AssertionError: If ``step_size`` is not positive.

    Returns:
        ``StepDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=5, gamma=0.8, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=5, gamma=0.8, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self,
                 learning_rate,
                 step_size,
                 gamma=0.1,
                 last_epoch=-1,
                 verbose=False):
        if not isinstance(step_size, int):
            raise TypeError(
                "The type of 'step_size' must be 'int', but received %s." %
                type(step_size))
        if gamma >= 1.0:
            raise ValueError('gamma should be < 1.0.')
        # step_size is already known to be an int here; only the sign is left to check.
        assert step_size > 0, " 'step_size' must be a positive integer."

        self.step_size = step_size
        self.gamma = gamma
        super(StepDecay, self).__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        # How many full step_size windows have elapsed so far.
        num_decays = self.last_epoch // self.step_size
        return self.base_lr * (self.gamma**num_decays)


1108
class LambdaDecay(LRScheduler):
    """
    Set the learning rate of ``optimizer`` through a user function ``lr_lambda`` . ``lr_lambda`` is a function that receives ``epoch`` .

    The schedule is equivalent to the following code.

    .. code-block:: text

        learning_rate = 0.5        # init learning_rate
        lr_lambda = lambda epoch: 0.95 ** epoch

        learning_rate = 0.5        # epoch 0, 0.5*0.95**0
        learning_rate = 0.475      # epoch 1, 0.5*0.95**1
        learning_rate = 0.45125    # epoch 2, 0.5*0.95**2

    Args:
        learning_rate (float): The starting learning rate, as a python float number.
        lr_lambda (function): Callable mapping ``epoch`` to a factor; the initial learning rate is multiplied by that factor.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Raises:
        TypeError: If ``lr_lambda`` is not callable.

    Returns:
        ``LambdaDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.LambdaDecay(learning_rate=0.5, lr_lambda=lambda x:0.95**x, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.LambdaDecay(learning_rate=0.5, lr_lambda=lambda x:0.95**x, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(self, learning_rate, lr_lambda, last_epoch=-1, verbose=False):
        if callable(lr_lambda):
            self.lr_lambda = lr_lambda
        else:
            raise TypeError(
                "The type of 'lr_lambda' in 'LambdaDecay' must be 'function', but received %s."
                % type(lr_lambda))
        super(LambdaDecay, self).__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        # The user function supplies a scale factor relative to the initial rate.
        factor = self.lr_lambda(self.last_epoch)
        return self.base_lr * factor


1196
class ReduceOnPlateau(LRScheduler):
1197
    """
G
guguguzi 已提交
1198
    Reduce learning rate when ``metrics`` has stopped descending. Models often benefit from reducing the learning rate
1199 1200
    by 2 to 10 times once model performance has no longer improvement.

G
guguguzi 已提交
1201 1202 1203
    The ``metrics`` is the one which has been pass into ``step`` , it must be 1-D Tensor with shape [1]. When ``metrics``
    stop descending for a ``patience`` number of epochs, the learning rate will be reduced to ``learning_rate * factor`` .
    (Specially, ``mode`` can also be set to ``'max`` , in this case, when ``metrics`` stop ascending for a ``patience``
1204 1205 1206 1207 1208 1209
    number of epochs, the learning rate will be reduced.)

    In addition, After each reduction, it will wait a ``cooldown`` number of epochs before resuming above operation.

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
G
guguguzi 已提交
1210 1211
        mode (str, optional): ``'min'`` or ``'max'`` can be selected. Normally, it is ``'min'`` , which means that the
            learning rate will reduce when ``loss`` stops descending. Specially, if it's set to ``'max'`` ,  the learning
1212
            rate will reduce when ``loss`` stops ascending. Default: ``'min'`` .
G
guguguzi 已提交
1213
        factor (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * factor`` .
1214
            It should be less than 1.0. Default: 0.1.
G
guguguzi 已提交
1215
        patience (int, optional): When ``loss`` doesn't improve for this number of epochs, learing rate will be reduced.
1216
            Default: 10.
G
guguguzi 已提交
1217
        threshold (float, optional): ``threshold`` and ``threshold_mode`` will determine the minimum change of ``loss`` .
1218 1219
            This make tiny changes of ``loss`` will be ignored. Default: 1e-4.
        threshold_mode (str, optional): ``'rel'`` or ``'abs'`` can be selected. In ``'rel'`` mode, the minimum change of ``loss``
G
guguguzi 已提交
1220
            is ``last_loss * threshold`` , where ``last_loss`` is ``loss`` in last epoch. In ``'abs'`` mode, the minimum
1221 1222 1223
            change of ``loss`` is ``threshold`` . Default: ``'rel'`` .
        cooldown (int, optional): The number of epochs to wait before resuming normal operation. Default: 0.
        min_lr (float, optional): The lower bound of the learning rate after reduction. Default: 0.
G
guguguzi 已提交
1224
        epsilon (float, optional): Minimal decay applied to lr. If the difference between new and old lr is smaller than epsilon,
1225
            the update is ignored. Default: 1e-8.
1226 1227
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``.

G
guguguzi 已提交
1228

1229
    Returns:
1230
        ``ReduceOnPlateau`` instance to schedule learning rate.
1231 1232 1233 1234 1235 1236 1237 1238


    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

1239
            # train on default dynamic graph mode
1240
            linear = paddle.nn.Linear(10, 10)
1241 1242
            scheduler = paddle.optimizer.lr.ReduceOnPlateau(learning_rate=1.0, factor=0.5, patience=5, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
1243
            for epoch in range(20):
Z
Zhou Wei 已提交
1244
                for batch_id in range(5):
1245
                    x = paddle.uniform([10, 10])
1246
                    out = linear(x)
C
chentianyu03 已提交
1247
                    loss = paddle.mean(out)
1248
                    loss.backward()
1249 1250
                    sgd.step()
                    sgd.clear_gradients()
Z
Zhou Wei 已提交
1251 1252
                    scheduler.step(loss)    # If you update learning rate each step
              # scheduler.step(loss)        # If you update learning rate each epoch
1253

1254
            # train on static graph mode
1255 1256 1257 1258
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
1259 1260
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
1261 1262
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
1263
                scheduler = paddle.optimizer.lr.ReduceOnPlateau(learning_rate=1.0, factor=0.5, patience=5, verbose=True)
1264 1265 1266 1267 1268 1269
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
Z
Zhou Wei 已提交
1270
                for batch_id in range(5):
1271 1272 1273 1274 1275 1276
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
1277
                        fetch_list=loss.name)
Z
Zhou Wei 已提交
1278 1279
                    scheduler.step(out[0])    # If you update learning rate each step
              # scheduler.step(out[0])        # If you update learning rate each epoch
1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310

    """

    def __init__(self,
                 learning_rate,
                 mode='min',
                 factor=0.1,
                 patience=10,
                 threshold=1e-4,
                 threshold_mode='rel',
                 cooldown=0,
                 min_lr=0,
                 epsilon=1e-8,
                 verbose=False):
        mode = mode.lower()
        if mode not in ['min', 'max']:
            raise ValueError('mode: ' + mode + ' is unknown!')
        self.mode = mode

        if factor >= 1.0:
            raise ValueError(
                'new_lr = origin_lr * gamma and gamma should be < 1.0.')
        self.factor = factor

        threshold_mode = threshold_mode.lower()
        if threshold_mode not in ['rel', 'abs']:
            raise ValueError('threshold mode: ' + threshold_mode +
                             ' is unknown!')
        self.threshold_mode = threshold_mode
        if not isinstance(learning_rate, (float, int)):
            raise TypeError(
1311
                "The type of 'learning_rate' in 'ReduceOnPlateau' must be 'float', but received %s."
1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332
                % type(learning_rate))

        self.patience = patience
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.cooldown = cooldown
        self.min_lr = min_lr
        self.epsilon = epsilon

        self.cooldown_counter = 0
        self.best = None
        self.num_bad_epochs = 0

        # Can not call Parent __init__, so implement here.
        self.base_lr = float(learning_rate)
        self.last_lr = float(learning_rate)
        self.last_epoch = 0
        self.verbose = verbose
        self._var_name = None

    # "cooldown_counter / best / num_bad_epochs / last_epoch / last_lr" will be stored.
    def state_keys(self):
        """Record in ``self.keys`` the attribute names stored as this scheduler's state."""
        persisted = ('cooldown_counter', 'best', 'num_bad_epochs',
                     'last_epoch', 'last_lr')
        self.keys = list(persisted)

    def step(self, metrics, epoch=None):
        """
        step should be called after `optimizer.step()` . It will update the learning rate in optimizer according to ``metrics`` .
        The new learning rate will take effect on next epoch.

        Args:
            metrics (Tensor|numpy.ndarray|float): Which will be monitored to determine whether the learning rate will reduce.
                If it stop descending for a ``patience`` number of epochs, the learning rate will reduce. If it's 'Tensor' or
                'numpy.ndarray', it must hold exactly one element (for example shape [1]).
            epoch (int, None): specify current epoch. Default: None. Auto-increment from ``last_epoch`` ,
                which is 0 right after construction.

        Returns:
            None

        Examples:
            Please refer to the example of current LRScheduler.
        """
        if epoch is None:
            self.last_epoch = self.last_epoch + 1
        else:
            self.last_epoch = epoch

        # Python / numpy scalars are accepted as-is. Anything else must be a
        # numpy.ndarray or a Tensor of the active dygraph mode holding one element.
        if not isinstance(metrics, (int, float, numpy.float32, numpy.float64)):
            if not _in_legacy_dygraph():
                tensor_type = core.eager.Tensor
            else:
                # need to declarate explicitly
                from paddle.framework import VarBase as Tensor
                tensor_type = Tensor
            if not isinstance(metrics, (tensor_type, numpy.ndarray)):
                raise TypeError(
                    "metrics must be 'int', 'float', 'np.float', 'numpy.ndarray' or 'paddle.Tensor', but receive {}".
                    format(type(metrics)))
            # Any single-element container (shape [1], [1, 1], a 0-d array, ...)
            # is fine: ``_is_better`` then yields a one-element boolean that
            # ``if`` can evaluate.
            assert int(numpy.prod(metrics.shape)) == 1, (
                "the metrics should hold exactly one element (e.g. shape (1L,)), "
                "but the current metrics.shape is {}. Maybe that "
                "you should call paddle.mean to process it first.".format(
                    metrics.shape))

        if self.cooldown_counter > 0:
            # Still cooling down after a reduction: skip monitoring this epoch.
            self.cooldown_counter -= 1
        else:
            if self.best is None or self._is_better(metrics, self.best):
                self.best = metrics
                self.num_bad_epochs = 0
            else:
                self.num_bad_epochs += 1

            if self.num_bad_epochs > self.patience:
                self.cooldown_counter = self.cooldown
                self.num_bad_epochs = 0
                new_lr = max(self.last_lr * self.factor, self.min_lr)
                # Ignore reductions smaller than ``epsilon``.
                if self.last_lr - new_lr > self.epsilon:
                    self.last_lr = new_lr
                    if self.verbose:
                        print('Epoch {}: {} set learning rate to {}.'.format(
                            self.last_epoch, self.__class__.__name__,
                            self.last_lr))

    def _is_better(self, current, best):
        if self.mode == 'min' and self.threshold_mode == 'rel':
            return current < best - best * self.threshold

        elif self.mode == 'min' and self.threshold_mode == 'abs':
            return current < best - self.threshold

        elif self.mode == 'max' and self.threshold_mode == 'rel':
            return current > best + best * self.threshold

        else:
            return current > best + self.threshold


class CosineAnnealingDecay(LRScheduler):
    r"""

    Set the learning rate using a cosine annealing schedule, where :math:`\eta_{max}` is set to
    the initial learning_rate. :math:`T_{cur}` is the number of epochs since the last restart in
    SGDR.

    The algorithm can be described as following.

    .. math::

        \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
        + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
        & T_{cur} \neq (2k+1)T_{max};

        \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
        \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
        & T_{cur} = (2k+1)T_{max}.

    It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts <https://arxiv.org/abs/1608.03983>`_.
    Note that this only implements the cosine annealing part of SGDR, and not the restarts.

    Args:
        learning_rate (float): The initial learning rate, that is :math:`\eta_{max}` . It can be set to python float or int number.
        T_max (int): Maximum number of iterations. It is half of the decay cycle of learning rate. It must be a positive integer.
        eta_min (float|int, optional): Minimum learning rate, that is :math:`\eta_{min}` . Default: 0.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``CosineAnnealingDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.5, T_max=10, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.5, T_max=10, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self,
                 learning_rate,
                 T_max,
                 eta_min=0,
                 last_epoch=-1,
                 verbose=False):
        if not isinstance(T_max, int):
            raise TypeError(
                "The type of 'T_max' in 'CosineAnnealingDecay' must be 'int', but received %s."
                % type(T_max))
        if not isinstance(eta_min, (float, int)):
            raise TypeError(
                "The type of 'eta_min' in 'CosineAnnealingDecay' must be 'float, int', but received %s."
                % type(eta_min))
        # The type of T_max was verified above; only its sign is left to check.
        assert T_max > 0, " 'T_max' must be a positive integer."
        self.T_max = T_max
        self.eta_min = float(eta_min)
        super(CosineAnnealingDecay, self).__init__(learning_rate, last_epoch,
                                                   verbose)

    def get_lr(self):
        """Compute the learning rate for ``last_epoch`` recursively from ``last_lr``."""
        if self.last_epoch == 0:
            return self.base_lr
        elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
            # (last_epoch - 1) is an odd multiple of T_max: the ratio below
            # would divide by 1 + cos((2k+1) * pi) == 0, so use the additive
            # form of the update instead.
            return self.last_lr + (self.base_lr - self.eta_min) * (1 - math.cos(
                math.pi / self.T_max)) / 2

        return (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / (
            1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) * (
                self.last_lr - self.eta_min) + self.eta_min

    def _get_closed_form_lr(self):
        """Evaluate the cosine schedule directly at ``last_epoch``.

        NOTE(review): presumably used by the base class when an explicit
        epoch is passed to ``step`` -- confirm against LRScheduler.
        """
        return self.eta_min + (self.base_lr - self.eta_min) * (1 + math.cos(
            math.pi * self.last_epoch / self.T_max)) / 2


class MultiplicativeDecay(LRScheduler):
    """
    Multiply the learning rate of ``optimizer`` by the factor returned from ``lr_lambda`` .

    At epoch ``n`` the learning rate equals the initial learning rate multiplied by
    ``lr_lambda(1) * lr_lambda(2) * ... * lr_lambda(n)`` , as in the example below.

    .. code-block:: text

        learning_rate = 0.5        # init learning_rate
        lr_lambda = lambda epoch: 0.95

        learning_rate = 0.5        # epoch 0,
        learning_rate = 0.475      # epoch 1, 0.5*0.95
        learning_rate = 0.45125    # epoch 2, 0.475*0.95

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        lr_lambda (function): A function which computes a factor by ``epoch`` , and then multiply the last learning rate by this factor.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``MultiplicativeDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.MultiplicativeDecay(learning_rate=0.5, lr_lambda=lambda x:0.95, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(self, learning_rate, lr_lambda, last_epoch=-1, verbose=False):
        if callable(lr_lambda):
            self.lr_lambda = lr_lambda
        else:
            raise TypeError(
                "The type of 'lr_lambda' in 'MultiplicativeDecay' must be 'function', but received %s."
                % type(lr_lambda))
        super(MultiplicativeDecay, self).__init__(learning_rate, last_epoch,
                                                  verbose)

    def get_lr(self):
        """Return ``base_lr`` times every factor ``lr_lambda(1) .. lr_lambda(last_epoch)``."""
        # Lazily produce the factors so lr_lambda is called in epoch order.
        factors = (self.lr_lambda(e) for e in range(1, self.last_epoch + 1))
        result = self.base_lr
        for factor in factors:
            result *= factor
        return result


class OneCycleLR(LRScheduler):
    r"""
    Sets the learning rate according to the one cycle learning rate scheduler.
    The scheduler adjusts the learning rate from an initial learning rate to the maximum learning rate and then
    from that maximum learning rate to the minimum learning rate, which is much less than the initial learning rate.

    It has been proposed in `Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates <https://arxiv.org/abs/1708.07120>`_.

    Please note that the default behaviour of this scheduler follows the fastai implementation of one cycle,
    which claims that “unpublished work has shown even better results by using only two phases”.
    If you want the behaviour of this scheduler to be consistent with the paper, please set ``three_phase=True`` .

    Also note that you should update learning rate each step.

    Args:
        max_learning_rate (float): The maximum learning rate. It is a python float number.
             Functionally, it defines the initial learning rate by ``divide_factor`` .
        total_steps (int): Number of total training steps.
        divide_factor (float, optional): Initial learning rate will be determined by initial_learning_rate = max_learning_rate / divide_factor. Default: 25.
        end_learning_rate (float, optional): The minimum learning rate during training, it should be much less than initial learning rate. Default: 0.0001.
        phase_pct (float, optional): The percentage of total steps which used to increasing learning rate. Default: 0.3.
        anneal_strategy (str, optional): Strategy of adjusting learning rate.'cos' for cosine annealing,
            'linear' for linear annealing. Default: 'cos'.
        three_phase (bool, optional): Whether to use three phase. Default: ``False`` .
            If ``True``:
                1. The learning rate will first increase from initial learning rate to maximum learning rate.
                2. Then it will decrease to initial learning rate. Number of step in this phase is the same as the one in first phase.
                3. Finally, it will decrease to minimum learning rate which is much less than initial learning rate.
            If ``False``:
                1. The learning rate will increase to maximum learning rate.
                2. Then it will directly decrease to minimum learning rate.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``OneCycleLR`` instance to schedule learning rate.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(5):
                for batch_id in range(20):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()        # You should update learning rate each step

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(5):
                for batch_id in range(20):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # You should update learning rate each step
    """

    def __init__(self,
                 max_learning_rate,
                 total_steps,
                 divide_factor=25.,
                 end_learning_rate=0.0001,
                 phase_pct=0.3,
                 anneal_strategy='cos',
                 three_phase=False,
                 last_epoch=-1,
                 verbose=False):
        # Check type and value of max_learning_rate
        if not isinstance(max_learning_rate, (float, int)):
            raise TypeError(
                "'max_learning_rate' must be 'float' or 'int', but received {}".
                format(type(max_learning_rate)))
        if max_learning_rate < 0:
            raise ValueError(
                "'max_learning_rate' must be a non-negative number.")

        # Check type and value of end_learning_rate
        if not isinstance(end_learning_rate, (float, int)):
            raise TypeError(
                "'end_learning_rate' must be 'float' or 'int', but received {}".
                format(type(end_learning_rate)))
        if end_learning_rate < 0:
            raise ValueError(
                "'end_learning_rate' must be a non-negative number.")

        # Check type and value of total_steps
        if not isinstance(total_steps, int):
            raise TypeError("'total_step' must be 'int', but received {}".
                            format(type(total_steps)))
        if total_steps <= 0:
            raise ValueError("'total_step' must be a positive integer.")
        self.total_steps = total_steps

        # Check type and value of phase_pct
        if not isinstance(phase_pct, float):
            raise TypeError("'phase_pct' must be 'float', but received {}".
                            format(type(phase_pct)))
        if phase_pct < 0 or phase_pct > 1:
            raise ValueError(
                "'phase_pct' must be between 0 and 1, but received {}".format(
                    phase_pct))

        # Check type and value of divide_factor
        if not isinstance(divide_factor, (float, int)):
            raise TypeError(
                "'divide_factor' must be 'float' or 'int', but received {}".
                format(type(divide_factor)))

        initial_lr = max_learning_rate / float(divide_factor)
        min_lr = float(end_learning_rate)

        if three_phase:
            if phase_pct >= 0.5:
                raise ValueError(
                    "When three_phase is True, 'phase_pct' must be less than 0.5"
                )
            # start step and end step of each phase.
            self._step_config = [
                0,
                phase_pct * self.total_steps - 1,
                2 * phase_pct * self.total_steps - 2,
                self.total_steps - 1,
                self.total_steps - 1,  # for the last step.
            ]
            # step size of each phase.
            self._steps_size = [
                self._step_config[1] - self._step_config[0],
                self._step_config[2] - self._step_config[1],
                self._step_config[3] - self._step_config[2],
                self._step_config[3] -
                self._step_config[2],  # for the last step.
            ]
            # start lr and end lr of each phase.
            self._lr_config = [
                initial_lr, max_learning_rate, initial_lr, min_lr
            ]
        else:
            self._step_config = [
                0, phase_pct * self.total_steps - 1, self.total_steps - 1,
                self.total_steps - 1
            ]
            self._steps_size = [
                self._step_config[1] - self._step_config[0],
                self._step_config[2] - self._step_config[1],
                self._step_config[2] - self._step_config[1],
            ]
            self._lr_config = [initial_lr, max_learning_rate, min_lr]

        # Check anneal_strategy
        if anneal_strategy == 'cos':
            self.anneal_func = self._cos_annealing
        elif anneal_strategy == 'linear':
            self.anneal_func = self._linear_annealing
        else:
            raise ValueError(
                "'anneal_strategy' must by one of 'cos' or 'linear', but received {}".
                format(anneal_strategy))
        super(OneCycleLR, self).__init__(initial_lr, last_epoch, verbose)

    def _cos_annealing(self, start_lr, end_lr, pct):
        """Cosine interpolation: ``start_lr`` at ``pct == 0``, ``end_lr`` at ``pct == 1``."""
        cos_out = math.cos(math.pi * pct) + 1
        return end_lr + (start_lr - end_lr) / 2.0 * cos_out

    def _linear_annealing(self, start_lr, end_lr, pct):
        """Linear interpolation: ``start_lr`` at ``pct == 0``, ``end_lr`` at ``pct == 1``."""
        return (end_lr - start_lr) * pct + start_lr

    def get_lr(self):
        """Return the learning rate for step ``last_epoch`` within its phase.

        Raises:
            ValueError: if ``last_epoch`` exceeds ``total_steps``.
        """
        current_step = self.last_epoch

        if current_step > self.total_steps:
            raise ValueError(
                "Tried to step {} times. However the number of total steps is {}"
                .format(current_step, self.total_steps))

        for (i, (end_step, step_size)
             ) in enumerate(zip(self._step_config[1:], self._steps_size)):
            # i == len(self._lr_config) - 2 catch the last step, otherwise it will return None.
            if current_step <= end_step or i == len(self._lr_config) - 2:
                if step_size == 0:
                    # A phase spanning zero steps (e.g. phase_pct * total_steps == 1)
                    # has no interior: use its end learning rate, which is also the
                    # start learning rate of the following phase, instead of
                    # dividing by zero.
                    percentage = 1.0
                else:
                    # self._step_config[i] means start step of a phase.
                    percentage = (current_step -
                                  self._step_config[i]) / step_size
                return self.anneal_func(self._lr_config[i],
                                        self._lr_config[i + 1], percentage)