lr.py 74.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import numpy
import warnings
from paddle import Tensor
19
import paddle.fluid.core as core
J
Jiabin Yang 已提交
20
from ..fluid.framework import _in_legacy_dygraph
21

G
guguguzi 已提交
22
# Public API of this module: the learning-rate scheduler classes.
__all__ = [  # noqa
    'LRScheduler', 'NoamDecay', 'PiecewiseDecay', 'NaturalExpDecay',
    'InverseTimeDecay', 'PolynomialDecay', 'LinearWarmup', 'ExponentialDecay',
    'MultiStepDecay', 'StepDecay', 'LambdaDecay', 'ReduceOnPlateau',
    'CosineAnnealingDecay', 'MultiplicativeDecay', 'OneCycleLR'
]


30 31 32 33 34
class LRScheduler(object):
    """

    LRScheduler Base class. Define the common interface of a learning rate scheduler.

    User can import it by ``from paddle.optimizer.lr import LRScheduler`` ,
    then overload it for your subclass and have a custom implementation of ``get_lr()`` .

    Otherwise, a ``NotImplementedError`` exception will be thrown.

    Note that the constructor calls ``step()`` once, so right after construction
    ``last_epoch`` is the given ``last_epoch + 1`` (0 by default) and the current
    learning rate is the value returned by ``get_lr()``.

    Args:
        learning_rate (float|int): The initial learning rate. It is a python float (or int) number.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        instance to schedule learning rate.

    Raises:
        TypeError: If ``learning_rate`` is neither a float nor an int.

    Examples:
        Here is an example of a simple ``StepDecay`` implementation.

        .. code-block:: python

            import paddle
            from paddle.optimizer.lr import LRScheduler

            class StepDecay(LRScheduler):
                def __init__(self,
                             learning_rate,
                             step_size,
                             gamma=0.1,
                             last_epoch=-1,
                             verbose=False):
                    if not isinstance(step_size, int):
                        raise TypeError(
                            "The type of 'step_size' must be 'int', but received %s." %
                            type(step_size))
                    if gamma >= 1.0:
                        raise ValueError('gamma should be < 1.0.')

                    self.step_size = step_size
                    self.gamma = gamma
                    super(StepDecay, self).__init__(learning_rate, last_epoch, verbose)

                def get_lr(self):
                    i = self.last_epoch // self.step_size
                    return self.base_lr * (self.gamma**i)

    """

    def __init__(self, learning_rate=0.1, last_epoch=-1, verbose=False):
        if not isinstance(learning_rate, (float, int)):
            raise TypeError(
                "The type of learning rate must be float, but received {}".
                format(type(learning_rate)))
        self.base_lr = float(learning_rate)
        self.last_lr = float(learning_rate)
        self.last_epoch = last_epoch
        self.verbose = verbose
        self._var_name = None

        # Advance once so that ``last_epoch`` becomes ``last_epoch + 1`` and
        # ``last_lr`` holds the value computed by the subclass's ``get_lr()``.
        self.step()

    def __call__(self):
        """
        Return the latest computed learning rate on the current epoch.
        """
        return self.last_lr

    def step(self, epoch=None):
        """

        ``step`` should be called after ``optimizer.step`` . It will update the learning rate in optimizer according to current ``epoch`` .
        The new learning rate will take effect on next ``optimizer.step`` .

        Args:
            epoch (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1.

        Returns:
            None

        """
        if epoch is None:
            self.last_epoch += 1
            self.last_lr = self.get_lr()
        else:
            self.last_epoch = epoch
            # Jumping to an explicit epoch: prefer a subclass-provided closed
            # form when available, otherwise fall back to ``get_lr()``.
            if hasattr(self, "_get_closed_form_lr"):
                self.last_lr = self._get_closed_form_lr()
            else:
                self.last_lr = self.get_lr()

        if self.verbose:
            print('Epoch {}: {} set learning rate to {}.'.format(
                self.last_epoch, self.__class__.__name__, self.last_lr))

    def state_dict(self):
        """

        Returns the state of the scheduler as a :class:`dict`.

        It is a subset of ``self.__dict__`` restricted to the names listed in
        ``self.keys`` (see ``state_keys()``); names not present on the instance
        are skipped. A ``Tensor`` value must have shape ``[1]`` and is stored as
        its single element.
        """
        self.state_keys()
        state_dict = {}
        for key in self.keys:
            if key not in self.__dict__:
                continue
            value = self.__dict__[key]
            if isinstance(value, Tensor):
                assert value.shape == [
                    1
                ], "shape of Tensor in state_dict must be [1] {}".format(
                    value.shape)
                value = value.numpy()[0]
            state_dict[key] = value

        return state_dict

    # For those subclass who overload LRScheduler, "last_epoch, last_lr" will be saved by default.
    # (Note): you can change it for your subclass.
    def state_keys(self):
        """

        For those subclass who overload ``LRScheduler`` (Base Class). By default, "last_epoch, last_lr" will be saved by ``self.keys = ['last_epoch', 'last_lr']`` .

        ``last_epoch`` is the current epoch num, and ``last_lr`` is the current learning rate.

        If you want to change the default behavior, you should have a custom implementation of ``state_keys()`` to redefine ``self.keys`` .

        """
        self.keys = ['last_epoch', 'last_lr']

    def set_state_dict(self, state_dict):
        """

        Loads the scheduler's state.

        Every name in ``self.keys`` (see ``state_keys()``) must be present in
        ``state_dict``; extra entries are ignored with a warning.

        Raises:
            RuntimeError: If a name from ``self.keys`` is missing in ``state_dict``.
        """
        self.state_keys()
        for key in self.keys:
            if key in state_dict:
                self.__dict__[key] = state_dict[key]
            else:
                raise RuntimeError(
                    "Please check whether state_dict is correct for optimizer. Can't find [ {} ] in state_dict"
                    .format(key))
        if len(state_dict) > len(self.keys):
            warnings.warn(
                "There are some unused values in state_dict. Maybe the optimizer have different 'LearningRateDecay' when invoking state_dict and set_dict"
            )

    # alias for set_state_dict
    set_dict = set_state_dict

    def get_lr(self):
        """

        For those subclass who overload ``LRScheduler`` (Base Class), User should have a custom implementation of ``get_lr()`` .

        Otherwise, a ``NotImplementedError`` exception will be thrown.
        """
        # calculate by python float
        raise NotImplementedError


196
class NoamDecay(LRScheduler):
    r"""

    Applies Noam Decay to the initial learning rate.

    The algorithm can be described as following.

    .. math::

        new\_learning\_rate = learning\_rate * d_{model}^{-0.5} * min(epoch^{-0.5}, epoch * warmup\_steps^{-1.5})

    Please reference `attention is all you need <https://arxiv.org/pdf/1706.03762.pdf>`_

    At epoch 0 the warmup term :math:`epoch * warmup\_steps^{-1.5}` is 0, so the
    learning rate starts at 0.

    Args:
        d_model (int|float): The dimensionality of input and output feature vector of model. It is a python int number (a float is also accepted).
        warmup_steps (int): The number of warmup steps. A super parameter. It is a python int number.
        learning_rate (float): The initial learning rate. It is a python float number. Default: 1.0.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``NoamDecay`` instance to schedule learning rate.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.NoamDecay(d_model=0.01, warmup_steps=100, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.NoamDecay(d_model=0.01, warmup_steps=100, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(self,
                 d_model,
                 warmup_steps,
                 learning_rate=1.0,
                 last_epoch=-1,
                 verbose=False):
        # Must be set before the base constructor, which calls step() -> get_lr().
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        super(NoamDecay, self).__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        # epoch ** -0.5 is undefined at epoch 0, so use 1 there; the warmup
        # term below is 0 at epoch 0 and therefore wins the ``min``.
        if self.last_epoch == 0:
            a = 1
        else:
            a = self.last_epoch**-0.5
        b = self.warmup_steps**-1.5 * self.last_epoch
        return self.base_lr * (self.d_model**-0.5) * min(a, b)


289
class PiecewiseDecay(LRScheduler):
    """

    Piecewise learning rate scheduler.

    The algorithm can be described as the code below:

    .. code-block:: text

        boundaries = [100, 200]
        values = [1.0, 0.5, 0.1]
        if epoch < 100:
            learning_rate = 1.0
        elif 100 <= epoch < 200:
            learning_rate = 0.5
        else:
            learning_rate = 0.1

    Args:
        boundaries(list|tuple): A list/tuple of steps numbers. The type of element in the list is python int.
        values(list|tuple): A list/tuple of learning rate values that will be picked during different epoch boundaries.
            The type of element in the list is python float. It must contain at least ``len(boundaries) + 1`` elements.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``PiecewiseDecay`` instance to schedule learning rate.

    Raises:
        ValueError: If ``values`` has fewer than ``len(boundaries) + 1`` elements.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[3, 6, 9], values=[0.1, 0.2, 0.3, 0.4], verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[3, 6, 9], values=[0.1, 0.2, 0.3, 0.4], verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self, boundaries, values, last_epoch=-1, verbose=False):
        # ``get_lr`` picks ``values[i]`` for the first boundary ``i`` greater
        # than the epoch, so one more value than boundaries is required.
        # Checking here fails fast instead of raising IndexError (or silently
        # reusing the wrong value) in the middle of training.
        if len(values) < len(boundaries) + 1:
            raise ValueError(
                "The number of values must be at least len(boundaries) + 1, "
                "but received len(values) = {} and len(boundaries) = {}.".
                format(len(values), len(boundaries)))
        # Must be set before the base constructor, which calls step() -> get_lr().
        self.boundaries = boundaries
        self.values = values
        super(PiecewiseDecay, self).__init__(last_epoch=last_epoch,
                                             verbose=verbose)

    def get_lr(self):
        for boundary, value in zip(self.boundaries, self.values):
            if self.last_epoch < boundary:
                return value
        # Past the last boundary: use the final value.
        return self.values[-1]


380
class NaturalExpDecay(LRScheduler):
    r"""

    Applies natural exponential decay to the initial learning rate.

    The algorithm can be described as following:

    .. math::

        new\_learning\_rate = learning\_rate * e^{- gamma * epoch}

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float): A Ratio to update the learning rate, must be greater than 0.0 to make learning rate decay. This argument is required.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``NaturalExpDecay`` instance to schedule learning rate.

    Raises:
        AssertionError: If ``gamma`` is not greater than 0.0.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.NaturalExpDecay(learning_rate=0.5, gamma=0.1, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.NaturalExpDecay(learning_rate=0.5, gamma=0.1, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        assert gamma > 0.0, " 'gamma' must be a positive number so that the learning rate will decay."
        # Must be set before the base constructor, which calls step() -> get_lr().
        self.gamma = gamma
        super(NaturalExpDecay, self).__init__(learning_rate, last_epoch,
                                              verbose)

    def get_lr(self):
        return self.base_lr * math.exp(-self.gamma * self.last_epoch)


460
class InverseTimeDecay(LRScheduler):
    r"""

    Applies inverse time decay to the initial learning rate.

    The algorithm can be described as following:

    .. math::

        new\_learning\_rate = \frac{learning\_rate}{1 + gamma * epoch}

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float): The decay rate in the formula above; a larger ``gamma`` makes the learning rate shrink faster. This argument is required.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``InverseTimeDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.InverseTimeDecay(learning_rate=0.5, gamma=0.1, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.InverseTimeDecay(learning_rate=0.5, gamma=0.1, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        # Must be set before the base constructor, which calls step() -> get_lr().
        self.gamma = gamma
        super(InverseTimeDecay, self).__init__(learning_rate, last_epoch,
                                               verbose)

    def get_lr(self):
        return self.base_lr / (1 + self.gamma * self.last_epoch)


541
class PolynomialDecay(LRScheduler):
    r"""

    Applies polynomial decay to the initial learning rate.

    The algorithm can be described as following.

    If cycle is set to True, then:

    .. math::

        decay\_steps & = decay\_steps * math.ceil(\frac{epoch}{decay\_steps})

        new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\frac{epoch}{decay\_steps})^{power}+end\_lr

    If cycle is set to False, then:

    .. math::

        epoch & = min(epoch, decay\_steps)

        new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\frac{epoch}{decay\_steps})^{power}+end\_lr


    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        decay_steps(int): The decay step size. It determines the decay cycle. It must be a positive integer.
        end_lr(float, optional): The minimum final learning rate. Default: 0.0001.
        power(float, optional): Power of polynomial, must be greater than 0.0 to get learning rate decay. Default: 1.0.
        cycle(bool, optional): Whether the learning rate rises again. If True, then the learning rate will rise when it decreases
            to ``end_lr`` .  If False, the learning rate is monotone decreasing. Default: False.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``PolynomialDecay`` instance to schedule learning rate.

    Raises:
        AssertionError: If ``decay_steps`` is not a positive int, or ``power`` is not greater than 0.0.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.5, decay_steps=20, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.5, decay_steps=20, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self,
                 learning_rate,
                 decay_steps,
                 end_lr=0.0001,
                 power=1.0,
                 cycle=False,
                 last_epoch=-1,
                 verbose=False):
        # Check the type first so a non-numeric ``decay_steps`` (e.g. None)
        # triggers this message rather than a TypeError from ``> 0``.
        assert isinstance(decay_steps, int) and decay_steps > 0, \
            " 'decay_steps' must be a positive integer."
        self.decay_steps = decay_steps
        self.end_lr = end_lr
        assert power > 0.0, " 'power' must be greater than 0.0 so that the learning rate will decay."
        self.power = power
        self.cycle = cycle
        # Attributes above must exist before the base constructor, which
        # calls step() -> get_lr().
        super(PolynomialDecay, self).__init__(learning_rate, last_epoch,
                                              verbose)

    def get_lr(self):
        tmp_epoch_num = self.last_epoch
        tmp_decay_steps = self.decay_steps
        if self.cycle:
            # Stretch the decay horizon to the next multiple of
            # ``decay_steps`` so the learning rate rises again after it has
            # reached ``end_lr``.
            div_res = math.ceil(
                float(self.last_epoch) / float(self.decay_steps))

            # ceil(0 / decay_steps) is 0, which would make the horizon 0 and
            # divide by zero below; use the first cycle at epoch 0 instead.
            if self.last_epoch == 0:
                div_res = 1
            tmp_decay_steps = self.decay_steps * div_res
        else:
            # Without cycling, clamp the epoch so the learning rate stays at
            # ``end_lr`` once ``decay_steps`` epochs have passed.
            tmp_epoch_num = min(self.last_epoch, self.decay_steps)

        return (self.base_lr - self.end_lr) * (
            (1 - float(tmp_epoch_num) / float(tmp_decay_steps))**
            self.power) + self.end_lr
662 663


664
class LinearWarmup(LRScheduler):
    r"""

    Linear learning rate warm up strategy. Update the learning rate preliminarily before the normal learning rate scheduler.
    For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_

    When epoch < warmup_steps, learning rate is updated as:

    .. math::

            lr = start\_lr + (end\_lr - start\_lr) * \frac{epoch}{warmup\_steps}

    where start_lr is the initial learning rate, and end_lr is the final learning rate;

    When epoch >= warmup_steps, learning rate is updated as:

    .. math::

            lr = learning\_rate

    where ``learning_rate`` is float or any subclass of ``LRScheduler`` . When it is an ``LRScheduler`` , that
    scheduler is stepped to epoch ``epoch - warmup_steps`` and its current learning rate is used.

    Args:
        learning_rate (int|float|LRScheduler): The learning rate after warm-up. It is a python number or any subclass of ``LRScheduler`` .
        warmup_steps (int): total steps of warm up. It must be a positive integer.
        start_lr (float): Initial learning rate of warm up.
        end_lr (float): Final learning rate of warm up. It must be greater than ``start_lr`` .
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``LinearWarmup`` instance to schedule learning rate.

    Raises:
        TypeError: If ``learning_rate`` is not an int, float or ``LRScheduler`` .

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.LinearWarmup(
                    learning_rate=0.5, warmup_steps=20, start_lr=0, end_lr=0.5, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.LinearWarmup(
                    learning_rate=0.5, warmup_steps=20, start_lr=0, end_lr=0.5, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self,
                 learning_rate,
                 warmup_steps,
                 start_lr,
                 end_lr,
                 last_epoch=-1,
                 verbose=False):
        if not isinstance(learning_rate, (float, int, LRScheduler)):
            # Report the offending *type*, as the message promises.
            raise TypeError(
                "the type of learning_rate should be [int, float or LRScheduler], the current type is {}"
                .format(type(learning_rate)))
        self.learning_rate = learning_rate
        assert warmup_steps > 0 and isinstance(
            warmup_steps, int), " 'warmup_steps' must be a positive integer."
        self.warmup_steps = warmup_steps
        self.start_lr = start_lr
        self.end_lr = end_lr
        assert end_lr > start_lr, "end_lr {} must be greater than start_lr {}".format(
            end_lr, start_lr)
        # The warm-up itself starts from ``start_lr``.
        super(LinearWarmup, self).__init__(start_lr, last_epoch, verbose)

    def state_dict(self):
        """
        Returns the state of the LinearWarmup scheduler as a :class:`dict`.

        It is a subset of ``self.__dict__`` . When the post-warm-up learning
        rate is itself an ``LRScheduler`` , its state is nested under the
        ``"LinearWarmup_LR"`` key.
        """
        state_dict = super(LinearWarmup, self).state_dict()
        if isinstance(self.learning_rate, LRScheduler):
            state_dict["LinearWarmup_LR"] = self.learning_rate.state_dict()
        return state_dict

    def set_state_dict(self, state_dict):
        """
        Loads state_dict for LinearWarmup scheduler.

        When the post-warm-up learning rate is an ``LRScheduler`` , its state
        is restored from the ``"LinearWarmup_LR"`` key.
        """
        super(LinearWarmup, self).set_state_dict(state_dict)
        if isinstance(self.learning_rate, LRScheduler):
            self.learning_rate.set_state_dict(state_dict["LinearWarmup_LR"])

    def get_lr(self):
        # Warm-up phase: interpolate linearly from start_lr towards end_lr.
        if self.last_epoch < self.warmup_steps:
            return (self.end_lr - self.start_lr) * float(
                self.last_epoch) / float(self.warmup_steps) + self.start_lr
        # After warm-up, defer to the wrapped scheduler, re-based so that its
        # epoch 0 is the first epoch after warm-up, or to the constant rate.
        if isinstance(self.learning_rate, LRScheduler):
            self.learning_rate.step(self.last_epoch - self.warmup_steps)
            return self.learning_rate()
        return self.learning_rate


803
class ExponentialDecay(LRScheduler):
    r"""

    Multiply the learning rate by ``gamma`` every epoch.

    The schedule is defined by the recurrence below.

    .. math::

        new\_learning\_rate = last\_learning\_rate * gamma

    In closed form, the learning rate at epoch ``epoch`` is ``learning_rate * gamma ** epoch`` .

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float): The ratio by which the learning rate is reduced. ``new_lr = origin_lr * gamma`` .
            It should be in interval (0.0, 1.0).
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``ExponentialDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.ExponentialDecay(learning_rate=0.5, gamma=0.9, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.ExponentialDecay(learning_rate=0.5, gamma=0.9, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        # gamma must lie strictly inside (0, 1) for the rate to shrink.
        assert 0.0 < gamma < 1.0, " 'gamma' must be in interval (0.0, 1.0) so that the learning rate will decay."
        self.gamma = gamma
        super(ExponentialDecay, self).__init__(learning_rate, last_epoch,
                                               verbose)

    def get_lr(self):
        # Closed form of multiplying by gamma once per elapsed epoch.
        decay_factor = self.gamma**self.last_epoch
        return self.base_lr * decay_factor


884
class MultiStepDecay(LRScheduler):
    """
    Multiply the learning rate by ``gamma`` each time ``epoch`` reaches one of the milestones.

    The schedule is equivalent to the code below.

    .. code-block:: text

        learning_rate = 0.5
        milestones = [30, 50]
        gamma = 0.1
        if epoch < 30:
            learning_rate = 0.5
        elif epoch < 50:
            learning_rate = 0.05
        else:
            learning_rate = 0.005

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        milestones (tuple|list): List or tuple of epoch boundaries. Must be strictly increasing.
        gamma (float, optional): The ratio by which the learning rate is reduced. ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``MultiStepDecay`` instance to schedule learning rate.

    Raises:
        TypeError: If ``milestones`` is neither a tuple nor a list.
        ValueError: If ``milestones`` is not strictly increasing, or ``gamma >= 1.0`` .

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self,
                 learning_rate,
                 milestones,
                 gamma=0.1,
                 last_epoch=-1,
                 verbose=False):
        if not isinstance(milestones, (tuple, list)):
            raise TypeError(
                "The type of 'milestones' in 'MultiStepDecay' must be 'tuple, list', but received %s."
                % type(milestones))

        # Compare every adjacent (lower, upper) pair of boundaries.
        ordered_pairs = [
            lower < upper for lower, upper in zip(milestones, milestones[1:])
        ]
        if not all(ordered_pairs):
            raise ValueError('The elements of milestones must be incremented')
        if gamma >= 1.0:
            raise ValueError('gamma should be < 1.0.')

        self.milestones = milestones
        self.gamma = gamma
        super(MultiStepDecay, self).__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        # The index of the first milestone not yet reached equals the number
        # of decays applied so far; past every milestone it is their count.
        applied = next((idx for idx, bound in enumerate(self.milestones)
                        if self.last_epoch < bound), len(self.milestones))
        return self.base_lr * (self.gamma**applied)


994
class StepDecay(LRScheduler):
    """
    Multiply the learning rate of ``optimizer`` by ``gamma`` once every ``step_size`` epochs.

    The schedule is equivalent to the code below.

    .. code-block:: text

        learning_rate = 0.5
        step_size = 30
        gamma = 0.1

        learning_rate = 0.5     if epoch < 30
        learning_rate = 0.05    if 30 <= epoch < 60
        learning_rate = 0.005   if 60 <= epoch < 90
        ...

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        step_size (int): the interval to update. It must be a positive integer.
        gamma (float, optional): The ratio by which the learning rate is reduced. ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``StepDecay`` instance to schedule learning rate.

    Raises:
        TypeError: If ``step_size`` is not an int.
        ValueError: If ``gamma >= 1.0`` .

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=5, gamma=0.8, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=5, gamma=0.8, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self,
                 learning_rate,
                 step_size,
                 gamma=0.1,
                 last_epoch=-1,
                 verbose=False):
        if not isinstance(step_size, int):
            raise TypeError(
                "The type of 'step_size' must be 'int', but received %s." %
                type(step_size))
        if gamma >= 1.0:
            raise ValueError('gamma should be < 1.0.')
        # step_size is already known to be an int, so only its sign is left.
        assert step_size > 0, " 'step_size' must be a positive integer."

        self.step_size = step_size
        self.gamma = gamma
        super(StepDecay, self).__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        # Number of complete ``step_size`` intervals elapsed so far.
        completed = self.last_epoch // self.step_size
        return self.base_lr * self.gamma**completed


1097
class LambdaDecay(LRScheduler):
    """
    Set the learning rate of ``optimizer`` through the function ``lr_lambda`` . ``lr_lambda`` is a function which receives ``epoch`` .

    The schedule is equivalent to the code below.

    .. code-block:: text

        learning_rate = 0.5        # init learning_rate
        lr_lambda = lambda epoch: 0.95 ** epoch

        learning_rate = 0.5        # epoch 0, 0.5*0.95**0
        learning_rate = 0.475      # epoch 1, 0.5*0.95**1
        learning_rate = 0.45125    # epoch 2, 0.5*0.95**2

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        lr_lambda (function): A function which computes a factor from ``epoch`` ; the initial learning rate is multiplied by this factor.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``LambdaDecay`` instance to schedule learning rate.

    Raises:
        TypeError: If ``lr_lambda`` is not callable.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.LambdaDecay(learning_rate=0.5, lr_lambda=lambda x:0.95**x, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.LambdaDecay(learning_rate=0.5, lr_lambda=lambda x:0.95**x, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(self, learning_rate, lr_lambda, last_epoch=-1, verbose=False):
        if not callable(lr_lambda):
            received = type(lr_lambda)
            raise TypeError(
                "The type of 'lr_lambda' in 'LambdaDecay' must be 'function', but received %s."
                % received)
        self.lr_lambda = lr_lambda
        super(LambdaDecay, self).__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        # The user function supplies a multiplier on the initial rate.
        factor = self.lr_lambda(self.last_epoch)
        return self.base_lr * factor


1185
class ReduceOnPlateau(LRScheduler):
1186
    """
G
guguguzi 已提交
1187
    Reduce learning rate when ``metrics`` has stopped descending. Models often benefit from reducing the learning rate
1188 1189
    by 2 to 10 times once model performance has no longer improvement.

G
guguguzi 已提交
1190 1191 1192
    The ``metrics`` is the value passed into ``step`` ; it must be a 1-D Tensor with shape [1]. When ``metrics``
    stops descending for a ``patience`` number of epochs, the learning rate will be reduced to ``learning_rate * factor`` .
    (Specially, ``mode`` can also be set to ``'max'`` ; in this case, when ``metrics`` stops ascending for a ``patience``
1193 1194 1195 1196 1197 1198
    number of epochs, the learning rate will be reduced.)

    In addition, after each reduction, it will wait a ``cooldown`` number of epochs before resuming the above operation.

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
G
guguguzi 已提交
1199 1200
        mode (str, optional): ``'min'`` or ``'max'`` can be selected. Normally, it is ``'min'`` , which means that the
            learning rate will reduce when ``loss`` stops descending. Specially, if it's set to ``'max'`` ,  the learning
1201
            rate will reduce when ``loss`` stops ascending. Default: ``'min'`` .
G
guguguzi 已提交
1202
        factor (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * factor`` .
1203
            It should be less than 1.0. Default: 0.1.
G
guguguzi 已提交
1204
        patience (int, optional): When ``loss`` doesn't improve for this number of epochs, the learning rate will be reduced.
1205
            Default: 10.
G
guguguzi 已提交
1206
        threshold (float, optional): ``threshold`` and ``threshold_mode`` will determine the minimum change of ``loss`` .
1207 1208
            Changes of ``loss`` smaller than this minimum are ignored. Default: 1e-4.
        threshold_mode (str, optional): ``'rel'`` or ``'abs'`` can be selected. In ``'rel'`` mode, the minimum change of ``loss``
G
guguguzi 已提交
1209
            is ``last_loss * threshold`` , where ``last_loss`` is ``loss`` in last epoch. In ``'abs'`` mode, the minimum
1210 1211 1212
            change of ``loss`` is ``threshold`` . Default: ``'rel'`` .
        cooldown (int, optional): The number of epochs to wait before resuming normal operation. Default: 0.
        min_lr (float, optional): The lower bound of the learning rate after reduction. Default: 0.
G
guguguzi 已提交
1213
        epsilon (float, optional): Minimal decay applied to lr. If the difference between new and old lr is smaller than epsilon,
1214
            the update is ignored. Default: 1e-8.
1215 1216
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``.

G
guguguzi 已提交
1217

1218
    Returns:
1219
        ``ReduceOnPlateau`` instance to schedule learning rate.
1220 1221 1222 1223 1224 1225 1226 1227


    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

1228
            # train on default dynamic graph mode
1229
            linear = paddle.nn.Linear(10, 10)
1230 1231
            scheduler = paddle.optimizer.lr.ReduceOnPlateau(learning_rate=1.0, factor=0.5, patience=5, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
1232
            for epoch in range(20):
Z
Zhou Wei 已提交
1233
                for batch_id in range(5):
1234
                    x = paddle.uniform([10, 10])
1235
                    out = linear(x)
C
chentianyu03 已提交
1236
                    loss = paddle.mean(out)
1237
                    loss.backward()
1238 1239
                    sgd.step()
                    sgd.clear_gradients()
Z
Zhou Wei 已提交
1240 1241
                    scheduler.step(loss)    # If you update learning rate each step
              # scheduler.step(loss)        # If you update learning rate each epoch
1242

1243
            # train on static graph mode
1244 1245 1246 1247
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
1248 1249
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
1250 1251
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
1252
                scheduler = paddle.optimizer.lr.ReduceOnPlateau(learning_rate=1.0, factor=0.5, patience=5, verbose=True)
1253 1254 1255 1256 1257 1258
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
Z
Zhou Wei 已提交
1259
                for batch_id in range(5):
1260 1261 1262 1263 1264 1265
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
1266
                        fetch_list=loss.name)
Z
Zhou Wei 已提交
1267 1268
                    scheduler.step(out[0])    # If you update learning rate each step
              # scheduler.step(out[0])        # If you update learning rate each epoch
1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299

    """

    def __init__(self,
                 learning_rate,
                 mode='min',
                 factor=0.1,
                 patience=10,
                 threshold=1e-4,
                 threshold_mode='rel',
                 cooldown=0,
                 min_lr=0,
                 epsilon=1e-8,
                 verbose=False):
        # Validate and normalise the configuration; see the class docstring
        # for the meaning of each argument.
        mode = mode.lower()
        if mode not in ['min', 'max']:
            raise ValueError('mode: ' + mode + ' is unknown!')
        self.mode = mode

        # The parameter is called ``factor``; name it in the message so users
        # know which argument to fix.
        if factor >= 1.0:
            raise ValueError(
                'new_lr = origin_lr * factor and factor should be < 1.0.')
        self.factor = factor

        threshold_mode = threshold_mode.lower()
        if threshold_mode not in ['rel', 'abs']:
            raise ValueError('threshold mode: ' + threshold_mode +
                             ' is unknown!')
        self.threshold_mode = threshold_mode

        if not isinstance(learning_rate, (float, int)):
            raise TypeError(
                "The type of 'learning_rate' in 'ReduceOnPlateau' must be 'float', but received %s."
                % type(learning_rate))

        self.patience = patience
        self.threshold = threshold
        self.cooldown = cooldown
        self.min_lr = min_lr
        self.epsilon = epsilon

        # Bookkeeping for plateau detection, updated by ``step``.
        self.cooldown_counter = 0
        self.best = None
        self.num_bad_epochs = 0

        # Can not call Parent __init__, so implement here: set the attributes
        # the base scheduler would otherwise initialise.
        self.base_lr = float(learning_rate)
        self.last_lr = float(learning_rate)
        self.last_epoch = 0
        self.verbose = verbose
        self._var_name = None

    # The attributes "cooldown_counter / best / num_bad_epochs / last_epoch / last_lr" are stored as the scheduler state.
    def state_keys(self):
        """
        Set ``self.keys`` to the attribute names stored as scheduler state.

        Besides ``last_epoch`` and ``last_lr``, ReduceOnPlateau must also keep
        its plateau-tracking counters and the best metric seen so far.
        """
        plateau_state = ['cooldown_counter', 'best', 'num_bad_epochs']
        self.keys = plateau_state + ['last_epoch', 'last_lr']

    def step(self, metrics, epoch=None):
        """
        step should be called after `optimizer.step()` . It will update the learning rate in optimizer according to ``metrics`` .
        The new learning rate will take effect on next epoch.

        Args:
            metrics (Tensor|numpy.ndarray|float): The quantity monitored to decide whether the learning rate should be reduced.
                If it has not improved (as defined by ``mode``, ``threshold`` and ``threshold_mode``) for more than
                ``patience`` epochs, the learning rate is reduced. If it's 'Tensor' or 'numpy.ndarray', its shape must be [1].
            epoch (int, None): specify current epoch. Default: None. If None, ``last_epoch`` (which is 0 right after
                construction) is incremented by 1.

        Returns:
            None

        Raises:
            AssertionError: If ``metrics`` is a Tensor or numpy.ndarray whose shape is not [1].
            TypeError: If ``metrics`` is not an int, float, numpy.float32, numpy.float64, numpy.ndarray or Tensor.

        Examples:
            Please refer to the example of current LRScheduler.
        """
        if epoch is None:
            self.last_epoch = self.last_epoch + 1
        else:
            self.last_epoch = epoch

        if not _in_legacy_dygraph():
            tensor_type = core.eager.Tensor
        else:
            # need to declarate explicitly
            from paddle.framework import VarBase
            tensor_type = VarBase

        # metrics must be float, numpy.ndarray or 1-D Tensor with shape [1]
        if isinstance(metrics, (tensor_type, numpy.ndarray)):
            # Raised explicitly instead of via ``assert`` so the check is not
            # stripped under ``python -O``; AssertionError is kept so existing
            # callers that catch it keep working.
            if not (len(metrics.shape) == 1 and metrics.shape[0] == 1):
                raise AssertionError(
                    "the metrics.shape should be (1L,), but the current "
                    "metrics.shape is {}. Maybe that you should call "
                    "paddle.mean to process it first.".format(metrics.shape))
        elif not isinstance(metrics,
                            (int, float, numpy.float32, numpy.float64)):
            raise TypeError(
                "metrics must be 'int', 'float', 'np.float', 'numpy.ndarray' or 'paddle.Tensor', but receive {}"
                .format(type(metrics)))

        if self.cooldown_counter > 0:
            # While cooling down, the metric is not compared at all.
            self.cooldown_counter -= 1
        else:
            if self.best is None or self._is_better(metrics, self.best):
                self.best = metrics
                self.num_bad_epochs = 0
            else:
                self.num_bad_epochs += 1

            if self.num_bad_epochs > self.patience:
                self.cooldown_counter = self.cooldown
                self.num_bad_epochs = 0
                new_lr = max(self.last_lr * self.factor, self.min_lr)
                # Skip negligible updates (e.g. once already clamped at min_lr).
                if self.last_lr - new_lr > self.epsilon:
                    self.last_lr = new_lr
                    if self.verbose:
                        print('Epoch {}: {} set learning rate to {}.'.format(
                            self.last_epoch, self.__class__.__name__,
                            self.last_lr))

    def _is_better(self, current, best):
        if self.mode == 'min' and self.threshold_mode == 'rel':
            return current < best - best * self.threshold

        elif self.mode == 'min' and self.threshold_mode == 'abs':
            return current < best - self.threshold

        elif self.mode == 'max' and self.threshold_mode == 'rel':
            return current > best + best * self.threshold

        else:
            return current > best + self.threshold


class CosineAnnealingDecay(LRScheduler):
    r"""

    Set the learning rate using a cosine annealing schedule, where :math:`\eta_{max}` is set to
    the initial learning_rate. :math:`T_{cur}` is the number of epochs since the last restart in
    SGDR.

    The algorithm can be described as follows.

    .. math::

        \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
        + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
        & T_{cur} \neq (2k+1)T_{max};

        \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
        \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
        & T_{cur} = (2k+1)T_{max}.

    It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts <https://arxiv.org/abs/1608.03983>`_.
    Note that this only implements the cosine annealing part of SGDR, and not the restarts.

    Args:
        learning_rate (float): The initial learning rate, that is :math:`\eta_{max}` . It can be set to python float or int number.
        T_max (int): Maximum number of iterations. It is half of the decay cycle of learning rate. It must be a positive integer.
        eta_min (float|int, optional): Minimum learning rate, that is :math:`\eta_{min}` . Default: 0.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``CosineAnnealingDecay`` instance to schedule learning rate.

    Raises:
        TypeError: If ``T_max`` is not an int, or ``eta_min`` is not a float or int.
        AssertionError: If ``T_max`` is not positive.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.5, T_max=10, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.5, T_max=10, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self,
                 learning_rate,
                 T_max,
                 eta_min=0,
                 last_epoch=-1,
                 verbose=False):
        if not isinstance(T_max, int):
            raise TypeError(
                "The type of 'T_max' in 'CosineAnnealingDecay' must be 'int', but received %s."
                % type(T_max))
        if not isinstance(eta_min, (float, int)):
            raise TypeError(
                "The type of 'eta_min' in 'CosineAnnealingDecay' must be 'float, int', but received %s."
                % type(eta_min))
        # Raised explicitly instead of via ``assert`` so the check survives
        # ``python -O``. Without it, a non-positive T_max is silently accepted,
        # and T_max == 0 later hits a modulo/division by zero in get_lr.
        # AssertionError is kept for backward compatibility.
        if T_max <= 0:
            raise AssertionError(" 'T_max' must be a positive integer.")
        self.T_max = T_max
        self.eta_min = float(eta_min)
        super(CosineAnnealingDecay, self).__init__(learning_rate, last_epoch,
                                                   verbose)

    def get_lr(self):
        """Compute the next learning rate recursively from ``self.last_lr``."""
        if self.last_epoch == 0:
            return self.base_lr
        elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
            # Here last_epoch - 1 == (2k+1) * T_max, so the denominator of the
            # ratio below, 1 + cos(pi * (last_epoch - 1) / T_max), is zero;
            # use the additive form from the docstring instead.
            return self.last_lr + (self.base_lr - self.eta_min) * (
                1 - math.cos(math.pi / self.T_max)) / 2

        return (1 + math.cos(math.pi * self.last_epoch / self.T_max)) / (
            1 + math.cos(math.pi * (self.last_epoch - 1) / self.T_max)) * (
                self.last_lr - self.eta_min) + self.eta_min

    def _get_closed_form_lr(self):
        """Compute the learning rate for ``self.last_epoch`` directly from ``base_lr``."""
        return self.eta_min + (self.base_lr - self.eta_min) * (
            1 + math.cos(math.pi * self.last_epoch / self.T_max)) / 2


class MultiplicativeDecay(LRScheduler):
    """
    Scale the learning rate of ``optimizer`` every epoch by the factor that ``lr_lambda`` returns for that epoch.

    Put differently, the learning rate at epoch ``n`` equals
    ``learning_rate * lr_lambda(1) * lr_lambda(2) * ... * lr_lambda(n)``.
    For example:

    .. code-block:: text

        learning_rate = 0.5        # init learning_rate
        lr_lambda = lambda epoch: 0.95

        learning_rate = 0.5        # epoch 0,
        learning_rate = 0.475      # epoch 1, 0.5*0.95
        learning_rate = 0.45125    # epoch 2, 0.475*0.95

    Args:
        learning_rate (float): The starting learning rate, given as a python float number.
        lr_lambda (function): Called with the epoch number; the value it returns is the factor the previous learning rate is multiplied by.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``MultiplicativeDecay`` instance to schedule learning rate.

    Raises:
        TypeError: If ``lr_lambda`` is not callable.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.MultiplicativeDecay(learning_rate=0.5, lr_lambda=lambda x:0.95, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(self, learning_rate, lr_lambda, last_epoch=-1, verbose=False):
        # Reject non-callables up front, before the base class triggers get_lr.
        if not callable(lr_lambda):
            raise TypeError(
                "The type of 'lr_lambda' in 'MultiplicativeDecay' must be 'function', but received %s."
                % type(lr_lambda))

        self.lr_lambda = lr_lambda
        super(MultiplicativeDecay, self).__init__(learning_rate, last_epoch,
                                                  verbose)

    def get_lr(self):
        """Return base_lr times the product of lr_lambda(1) .. lr_lambda(last_epoch)."""
        # The product is rebuilt from base_lr on every call, so the result
        # depends only on base_lr and last_epoch, never on last_lr.
        lr = self.base_lr
        for factor in map(self.lr_lambda, range(1, self.last_epoch + 1)):
            lr = lr * factor
        return lr


class OneCycleLR(LRScheduler):
    r"""
    Sets the learning rate according to the one cycle learning rate scheduler.
    The scheduler adjusts the learning rate from an initial learning rate to the maximum learning rate and then
    from that maximum learning rate to the minimum learning rate, which is much less than the initial learning rate.

    It has been proposed in `Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates <https://arxiv.org/abs/1708.07120>`_.

    Please note that the default behaviour of this scheduler follows the fastai implementation of one cycle,
    which claims that “unpublished work has shown even better results by using only two phases”.
    If you want the behaviour of this scheduler to be consistent with the paper, please set ``three_phase=True`` .

    Also note that you should update learning rate each step.

    Args:
        max_learning_rate (float): The maximum learning rate. It is a python float number.
             Functionally, it defines the initial learning rate by ``divide_factor`` .
        total_steps (int): Number of total training steps.
        divide_factor (float, optional): Initial learning rate will be determined by initial_learning_rate = max_learning_rate / divide_factor. Default: 25.
        end_learning_rate (float, optional): The minimum learning rate during training. It should be much less than the initial learning rate. Default: 0.0001.
        phase_pct (float, optional): The fraction of total steps used to increase the learning rate. Default: 0.3.
        anneal_strategy (str, optional): Strategy of adjusting learning rate. 'cos' for cosine annealing,
            'linear' for linear annealing. Default: 'cos'.
        three_phase (bool, optional): Whether to use three phase.
            If ``True``:
                1. The learning rate will first increase from initial learning rate to maximum learning rate.
                2. Then it will decrease to initial learning rate. Number of step in this phase is the same as the one in first phase.
                3. Finally, it will decrease to minimum learning rate which is much less than initial learning rate.
            If ``False``:
                1. The learning rate will increase to maximum learning rate.
                2. Then it will directly decrease to minimum learning rate.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``OneCycleLR`` instance to schedule learning rate.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(5):
                for batch_id in range(20):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()        # You should update learning rate each step

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(5):
                for batch_id in range(20):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # You should update learning rate each step
    """

    def __init__(self,
                 max_learning_rate,
                 total_steps,
                 divide_factor=25.,
                 end_learning_rate=0.0001,
                 phase_pct=0.3,
                 anneal_strategy='cos',
                 three_phase=False,
                 last_epoch=-1,
                 verbose=False):
        # Check type and value of max_learning_rate
        if not isinstance(max_learning_rate, (float, int)):
            raise TypeError(
                "'max_learning_rate' must be 'float' or 'int', but received {}".
                format(type(total_steps)))
        if max_learning_rate < 0:
            raise ValueError("'max_learning_rate' must be a positive integer.")

        # Check type and value of end_learning_rate
        if not isinstance(end_learning_rate, (float, int)):
            raise TypeError(
                "'end_learning_rate' must be 'float' or 'int', but received {}".
                format(type(total_steps)))
        if end_learning_rate < 0:
            raise ValueError("'end_learning_rate' must be a positive integer.")

        # Check type and value of total_steps
        if not isinstance(total_steps, int):
1698 1699 1700
            raise TypeError(
                "'total_step' must be 'int', but received {}".format(
                    type(total_steps)))
1701 1702 1703 1704 1705 1706
        if total_steps <= 0:
            raise ValueError("'total_step' must be a positive integer.")
        self.total_steps = total_steps

        # Check type and value of pac_start
        if not isinstance(phase_pct, float):
1707 1708 1709
            raise TypeError(
                "'phase_pct' must be 'float', but received {}".format(
                    type(phase_pct)))
1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 1748 1749 1750 1751 1752 1753 1754 1755 1756 1757 1758 1759 1760 1761 1762 1763 1764 1765 1766 1767
        if phase_pct < 0 or phase_pct > 1:
            raise ValueError(
                "'phase_pct' must be between 0 and 1, but received {}".format(
                    phase_pct))

        # Check type and value of divide_factor
        if not isinstance(divide_factor, (float, int)):
            raise TypeError(
                "'divide_factor' must be 'float' or 'int', but received {}".
                format(type(divide_factor)))

        initial_lr = max_learning_rate / float(divide_factor)
        min_lr = float(end_learning_rate)

        if three_phase:
            if phase_pct >= 0.5:
                raise ValueError(
                    "When three_phase is True, 'phase_pct' must be less than 0.5"
                )
            # start step and end step of each phase.
            self._step_config = [
                0,
                phase_pct * self.total_steps - 1,
                2 * phase_pct * self.total_steps - 2,
                self.total_steps - 1,
                self.total_steps - 1,  # for the last step.
            ]
            # step size of each phase.
            self._steps_size = [
                self._step_config[1] - self._step_config[0],
                self._step_config[2] - self._step_config[1],
                self._step_config[3] - self._step_config[2],
                self._step_config[3] -
                self._step_config[2],  # for the last step.
            ]
            # start lr and end lr of each phase.
            self._lr_config = [
                initial_lr, max_learning_rate, initial_lr, min_lr
            ]
        else:
            self._step_config = [
                0, phase_pct * self.total_steps - 1, self.total_steps - 1,
                self.total_steps - 1
            ]
            self._steps_size = [
                self._step_config[1] - self._step_config[0],
                self._step_config[2] - self._step_config[1],
                self._step_config[2] - self._step_config[1],
            ]
            self._lr_config = [initial_lr, max_learning_rate, min_lr]

        # Check anneal_strategy
        if anneal_strategy == 'cos':
            self.anneal_func = self._cos_annealing
        elif anneal_strategy == 'linear':
            self.anneal_func = self._linear_annealing
        else:
            raise ValueError(
1768 1769
                "'anneal_strategy' must by one of 'cos' or 'linear', but received {}"
                .format(anneal_strategy))
1770 1771 1772 1773 1774 1775 1776 1777 1778 1779 1780 1781 1782 1783 1784 1785 1786
        super(OneCycleLR, self).__init__(initial_lr, last_epoch, verbose)

    def _cos_annealing(self, start_lr, end_lr, pct):
        cos_out = math.cos(math.pi * pct) + 1
        return end_lr + (start_lr - end_lr) / 2.0 * cos_out

    def _linear_annealing(self, start_lr, end_lr, pct):
        return (end_lr - start_lr) * pct + start_lr

    def get_lr(self):
        current_step = self.last_epoch

        if current_step > self.total_steps:
            raise ValueError(
                "Tried to step {} times. However the number of total steps is {}"
                .format(current_step, self.total_steps))

1787 1788
        for (i, (end_step, step_size)) in enumerate(
                zip(self._step_config[1:], self._steps_size)):
1789 1790 1791 1792 1793 1794
            # i == len(self._lr_config) - 2 catch the last step, otherwise it will return None.
            if current_step <= end_step or i == len(self._lr_config) - 2:
                # self._step_config[i] means start step of a phase.
                percentage = (current_step - self._step_config[i]) / step_size
                return self.anneal_func(self._lr_config[i],
                                        self._lr_config[i + 1], percentage)