lr.py 84.1 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import warnings
17 18 19

import numpy

20
import paddle.fluid.core as core
21 22
from paddle import Tensor

G
guguguzi 已提交
23
# Public API of this module: the scheduler base class plus every concrete
# learning-rate schedule exported from here.
__all__ = [  # noqa
    'LRScheduler',
    'NoamDecay',
    'PiecewiseDecay',
    'NaturalExpDecay',
    'InverseTimeDecay',
    'PolynomialDecay',
    'LinearWarmup',
    'ExponentialDecay',
    'MultiStepDecay',
    'StepDecay',
    'LambdaDecay',
    'ReduceOnPlateau',
    'CosineAnnealingDecay',
    'MultiplicativeDecay',
    'OneCycleLR',
    'CyclicLR',
]


43
class LRScheduler:
    """

    LRScheduler Base class. Define the common interface of a learning rate scheduler.

    User can import it by ``from paddle.optimizer.lr import LRScheduler`` ,

    then overload it for your subclass and have a custom implementation of ``get_lr()`` .

    Otherwise, a ``NotImplementedError`` exception will be thrown.

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        instance to schedule learning rate.

    Examples:
        Here is an example of a simple ``StepDecay`` implementation.

        .. code-block:: python

            import paddle
            from paddle.optimizer.lr import LRScheduler

            class StepDecay(LRScheduler):
                def __init__(self,
                            learning_rate,
                            step_size,
                            gamma=0.1,
                            last_epoch=-1,
                            verbose=False):
                    if not isinstance(step_size, int):
                        raise TypeError(
                            "The type of 'step_size' must be 'int', but received %s." %
                            type(step_size))
                    if gamma >= 1.0:
                        raise ValueError('gamma should be < 1.0.')

                    self.step_size = step_size
                    self.gamma = gamma
                    super().__init__(learning_rate, last_epoch, verbose)

                def get_lr(self):
                    i = self.last_epoch // self.step_size
                    return self.base_lr * (self.gamma**i)

    """

    def __init__(self, learning_rate=0.1, last_epoch=-1, verbose=False):
        if not isinstance(learning_rate, (float, int)):
            raise TypeError(
                "The type of learning rate must be float, but received {}".format(
                    type(learning_rate)
                )
            )
        self.base_lr = float(learning_rate)
        self.last_lr = float(learning_rate)
        self.last_epoch = last_epoch
        self.verbose = verbose
        self._var_name = None

        # Run one step right away so that ``last_epoch`` advances from its
        # initial value and ``last_lr`` holds the rate computed by ``get_lr()``.
        # Subclasses must therefore set every attribute ``get_lr()`` reads
        # before calling ``super().__init__``.
        self.step()

    def __call__(self):
        """
        Return latest computed learning rate on current epoch.
        """
        return self.last_lr

    def step(self, epoch=None):
        """

        ``step`` should be called after ``optimizer.step`` . It will update the learning rate in optimizer according to current ``epoch`` .
        The new learning rate will take effect on next ``optimizer.step`` .

        Args:
            epoch (int, None): specify current epoch. Default: None. Auto-increment from last_epoch=-1.

        Returns:
            None

        """
        if epoch is None:
            self.last_epoch += 1
            self.last_lr = self.get_lr()
        else:
            self.last_epoch = epoch
            # When jumping to an explicit epoch, use the subclass's
            # ``_get_closed_form_lr`` if it defines one; otherwise fall back
            # to ``get_lr()``.
            if hasattr(self, "_get_closed_form_lr"):
                self.last_lr = self._get_closed_form_lr()
            else:
                self.last_lr = self.get_lr()

        if self.verbose:
            print(
                'Epoch {}: {} set learning rate to {}.'.format(
                    self.last_epoch, self.__class__.__name__, self.last_lr
                )
            )

    def state_dict(self):
        """

        Returns the state of the scheduler as a :class:`dict`.

        It is a subset of ``self.__dict__`` .
        """
        # Refresh ``self.keys`` with the attribute names that should be saved.
        self.state_keys()
        state_dict = {}
        for key in self.keys:
            # Keys listed by ``state_keys`` but absent on the instance are skipped.
            if key not in self.__dict__:
                continue
            value = self.__dict__[key]
            if isinstance(value, Tensor):
                # A Tensor is stored as its single element, so it must have shape [1].
                assert value.shape == [
                    1
                ], "shape of Tensor in state_dict must be [1] {}".format(
                    value.shape
                )
                value = value.numpy()[0]
            state_dict[key] = value

        return state_dict

    # For subclasses that overload LRScheduler, "last_epoch, last_lr" will be saved by default.
    # (Note): you can change it for your subclass.
    def state_keys(self):
        """

        For those subclass who overload ``LRScheduler`` (Base Class). By default, "last_epoch, last_lr" will be saved by ``self.keys = ['last_epoch', 'last_lr']`` .

        ``last_epoch`` is the current epoch num, and ``last_lr`` is the current learning rate.

        If you want to change the default behavior, you should have a custom implementation of ``state_keys()`` to redefine ``self.keys`` .

        """
        self.keys = ['last_epoch', 'last_lr']

    def set_state_dict(self, state_dict):
        """

        Loads the schedulers state.

        Every key listed by ``state_keys()`` must be present in ``state_dict`` ,
        otherwise a ``RuntimeError`` is raised. If ``state_dict`` holds more
        entries than there are keys, a warning is issued.
        """
        self.state_keys()
        for key in self.keys:
            if key in state_dict:
                self.__dict__[key] = state_dict[key]
            else:
                raise RuntimeError(
                    "Please check whether state_dict is correct for optimizer. Can't find [ {} ] in state_dict".format(
                        key
                    )
                )
        if len(state_dict) > len(self.keys):
            warnings.warn(
                "There are some unused values in state_dict. Maybe the optimizer have different 'LearningRateDecay' when invoking state_dict and set_dict"
            )

    # alias for set_state_dict
    set_dict = set_state_dict

    def get_lr(self):
        """

        For those subclass who overload ``LRScheduler`` (Base Class), User should have a custom implementation of ``get_lr()`` .

        Otherwise, a ``NotImplementedError`` exception will be thrown.
        """
        # calculate by python float
        raise NotImplementedError


217
class NoamDecay(LRScheduler):
    r"""

    Applies Noam Decay to the initial learning rate.

    The algorithm can be described as following.

    .. math::

        new\_learning\_rate = learning\_rate * d_{model}^{-0.5} * min(epoch^{-0.5}, epoch * warmup\_steps^{-1.5})

    Please reference `attention is all you need <https://arxiv.org/pdf/1706.03762.pdf>`_


    Args:
        d_model (int): The dimensionality of input and output feature vector of model. It must be greater than 0.
        warmup_steps (int): The number of warmup steps. A super parameter. It must be greater than 0.
        learning_rate (float): The initial learning rate. It is a python float number. Default: 1.0.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``NoamDecay`` instance to schedule learning rate.

    Raises:
        ValueError: If ``d_model`` or ``warmup_steps`` is not greater than 0.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.NoamDecay(d_model=0.01, warmup_steps=100, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.NoamDecay(d_model=0.01, warmup_steps=100, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(
        self,
        d_model,
        warmup_steps,
        learning_rate=1.0,
        last_epoch=-1,
        verbose=False,
    ):
        if d_model <= 0:
            raise ValueError("d_model should be greater than 0")
        # ``get_lr`` evaluates ``warmup_steps**-1.5``: zero raises
        # ZeroDivisionError and a negative base yields a complex number, so
        # reject both here with a clear message.
        if warmup_steps <= 0:
            raise ValueError("warmup_steps should be greater than 0")

        # These must be set before the base constructor runs, because it
        # calls ``step()`` and therefore ``get_lr()``.
        self.d_model = d_model
        self.warmup_steps = warmup_steps
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        """Return the Noam learning rate for ``self.last_epoch``."""
        # ``0**-0.5`` is undefined, so epoch 0 uses 1 for the decay term; the
        # warmup term is 0 at that epoch and wins the ``min`` anyway.
        if self.last_epoch == 0:
            decay_term = 1
        else:
            decay_term = self.last_epoch**-0.5
        warmup_term = self.warmup_steps**-1.5 * self.last_epoch
        return self.base_lr * (self.d_model**-0.5) * min(decay_term, warmup_term)


315
class PiecewiseDecay(LRScheduler):
    """

    Piecewise learning rate scheduler.

    The algorithm can be described as the code below:

    .. code-block:: text

        boundaries = [100, 200]
        values = [1.0, 0.5, 0.1]
        if epoch < 100:
            learning_rate = 1.0
        elif 100 <= epoch < 200:
            learning_rate = 0.5
        else:
            learning_rate = 0.1

    Args:
        boundaries(list|tuple): A list/tuple of steps numbers. The type of element in the list is python int.
        values(list|tuple): A list/tuple of learning rate values that will be picked during different epoch boundaries.
            The type of element in the list is python float. It must have more elements than ``boundaries`` .
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``PiecewiseDecay`` instance to schedule learning rate.

    Raises:
        ValueError: If ``boundaries`` is empty, or ``values`` does not have more elements than ``boundaries`` .

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[3, 6, 9], values=[0.1, 0.2, 0.3, 0.4], verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[3, 6, 9], values=[0.1, 0.2, 0.3, 0.4], verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self, boundaries, values, last_epoch=-1, verbose=False):
        if len(boundaries) == 0:
            raise ValueError('The boundaries cannot be empty.')

        if len(values) <= len(boundaries):
            raise ValueError(
                f'The values have one more element than boundaries, but received len(values) [{len(values)}] < len(boundaries) + 1 [{len(boundaries) + 1}].'
            )

        # These must be set before the base constructor runs, because it
        # calls ``step()`` and therefore ``get_lr()``.
        self.boundaries = boundaries
        self.values = values
        # ``learning_rate`` is not forwarded: the base class default is used
        # for ``base_lr`` and the actual rate always comes from ``values``.
        super().__init__(last_epoch=last_epoch, verbose=verbose)

    def get_lr(self):
        """Return the value of the first interval whose upper boundary exceeds ``self.last_epoch``."""
        # ``values`` is longer than ``boundaries`` (checked in ``__init__``),
        # so ``zip`` pairs every boundary with the value used below it.
        for boundary, value in zip(self.boundaries, self.values):
            if self.last_epoch < boundary:
                return value
        # Past the last boundary: use the final value.
        return self.values[-1]


413
class NaturalExpDecay(LRScheduler):
    r"""

    Applies natural exponential decay to the initial learning rate.

    The algorithm can be described as following:

    .. math::

        new\_learning\_rate = learning\_rate * e^{- gamma * epoch}

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float): A Ratio to update the learning rate, should greater than 0.0 to make learning rate decay.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``NaturalExpDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.NaturalExpDecay(learning_rate=0.5, gamma=0.1, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.NaturalExpDecay(learning_rate=0.5, gamma=0.1, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        # NOTE: this is an ``assert``, so the check is skipped when Python
        # runs with ``-O``; it is kept as-is so the exception type seen by
        # callers (AssertionError) does not change.
        assert (
            gamma > 0.0
        ), " 'gamma' must be a positive number so that the learning rate will decay."
        # Must be set before the base constructor runs, because it calls
        # ``step()`` and therefore ``get_lr()``.
        self.gamma = gamma
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        """Return ``base_lr * exp(-gamma * last_epoch)`` ."""
        return self.base_lr * math.exp(-1 * self.gamma * self.last_epoch)


494
class InverseTimeDecay(LRScheduler):
    r"""

    Applies inverse time decay to the initial learning rate.

    The algorithm can be described as following:

    .. math::

        new\_learning\_rate = \frac{learning\_rate}{1 + gamma * epoch}

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float): The decay rate used in the formula above,
            ``new_lr = learning_rate / (1 + gamma * epoch)`` . The value is not validated.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``InverseTimeDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.InverseTimeDecay(learning_rate=0.5, gamma=0.1, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.InverseTimeDecay(learning_rate=0.5, gamma=0.1, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        # Must be set before the base constructor runs, because it calls
        # ``step()`` and therefore ``get_lr()``.
        self.gamma = gamma
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        """Return ``base_lr / (1 + gamma * last_epoch)`` ."""
        return self.base_lr / (1 + self.gamma * self.last_epoch)


574
class PolynomialDecay(LRScheduler):
    r"""

    Applies polynomial decay to the initial learning rate.

    The algorithm can be described as following.

    If cycle is set to True, then:

    .. math::

        decay\_steps & = decay\_steps * math.ceil(\frac{epoch}{decay\_steps})

        new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\frac{epoch}{decay\_steps})^{power}+end\_lr

    If cycle is set to False, then:

    .. math::

        epoch & = min(epoch, decay\_steps)

        new\_learning\_rate & = (learning\_rate-end\_lr)*(1-\frac{epoch}{decay\_steps})^{power}+end\_lr


    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        decay_steps(int): The decay step size. It determines the decay cycle. It must be a positive integer.
        end_lr(float, optional): The minimum final learning rate. Default: 0.0001.
        power(float, optional): Power of polynomial, should greater than 0.0 to get learning rate decay. Default: 1.0.
        cycle(bool, optional): Whether the learning rate rises again. If True, then the learning rate will rise when it decrease
            to ``end_lr`` .  If False, the learning rate is monotone decreasing. Default: False.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``PolynomialDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.5, decay_steps=20, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.5, decay_steps=20, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(
        self,
        learning_rate,
        decay_steps,
        end_lr=0.0001,
        power=1.0,
        cycle=False,
        last_epoch=-1,
        verbose=False,
    ):
        # Check the type first so that a non-numeric ``decay_steps`` fails
        # with this message instead of a TypeError from the ``>`` comparison.
        assert (
            isinstance(decay_steps, int) and decay_steps > 0
        ), " 'decay_steps' must be a positive integer."
        self.decay_steps = decay_steps
        self.end_lr = end_lr
        assert (
            power > 0.0
        ), " 'power' must be greater than 0.0 so that the learning rate will decay."
        self.power = power
        self.cycle = cycle
        # All attributes above must be set before the base constructor runs,
        # because it calls ``step()`` and therefore ``get_lr()``.
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        """Return the polynomially decayed learning rate for ``self.last_epoch``."""
        epoch = self.last_epoch
        horizon = self.decay_steps
        if self.cycle:
            # Stretch the horizon to the smallest multiple of ``decay_steps``
            # that is not below the current epoch.
            multiplier = math.ceil(
                float(self.last_epoch) / float(self.decay_steps)
            )

            # ceil(0 / n) is 0, which would give a zero horizon at epoch 0.
            if self.last_epoch == 0:
                multiplier = 1
            horizon = self.decay_steps * multiplier
        else:
            # Without cycling, the rate stays at ``end_lr`` once
            # ``decay_steps`` is reached.
            epoch = min(self.last_epoch, self.decay_steps)

        return (self.base_lr - self.end_lr) * (
            (1 - float(epoch) / float(horizon)) ** self.power
        ) + self.end_lr
700 701


702
class LinearWarmup(LRScheduler):
    r"""

    Linear learning rate warm up strategy. Update the learning rate preliminarily before the normal learning rate scheduler.
    For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_

    When epoch < warmup_steps, learning rate is updated as:

    .. math::

            lr = start\_lr + (end\_lr - start\_lr) * \frac{epoch}{warmup\_steps}

    where start_lr is the initial learning rate, and end_lr is the final learning rate;

    When epoch >= warmup_steps, learning rate is updated as:

    .. math::

            lr = learning\_rate

    where ``learning_rate`` is float or any subclass of ``LRScheduler`` .

    Args:
        learning_rate (float|LRScheduler): The learning rate after warm-up. It is a python float number or any subclass of ``LRScheduler`` .
        warmup_steps (int): total steps of warm up. It must be a positive integer.
        start_lr (float): Initial learning rate of warm up.
        end_lr (float): Final learning rate of warm up. It must be greater than ``start_lr`` .
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``LinearWarmup`` instance to schedule learning rate.

    Raises:
        TypeError: If ``learning_rate`` is not an int, a float or an ``LRScheduler`` .
        AssertionError: If ``warmup_steps`` is not a positive integer, or ``end_lr`` is not greater than ``start_lr`` .

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.LinearWarmup(
                    learning_rate=0.5, warmup_steps=20, start_lr=0, end_lr=0.5, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.LinearWarmup(
                    learning_rate=0.5, warmup_steps=20, start_lr=0, end_lr=0.5, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(
        self,
        learning_rate,
        warmup_steps,
        start_lr,
        end_lr,
        last_epoch=-1,
        verbose=False,
    ):
        if not isinstance(learning_rate, (float, int, LRScheduler)):
            # Report the offending *type*: the message promises a type, but the
            # value itself used to be interpolated here.
            raise TypeError(
                "the type of learning_rate should be [int, float or LRScheduler], the current type is {}".format(
                    type(learning_rate)
                )
            )
        self.learning_rate = learning_rate

        # Kept as assertions so callers keep seeing AssertionError.
        assert warmup_steps > 0 and isinstance(
            warmup_steps, int
        ), " 'warmup_steps' must be a positive integer."
        self.warmup_steps = warmup_steps
        self.start_lr = start_lr
        self.end_lr = end_lr
        assert (
            end_lr > start_lr
        ), "end_lr {} must be greater than start_lr {}".format(end_lr, start_lr)

        # The warm-up starts from ``start_lr``, so that is the value handed to
        # the base class as its initial learning rate.
        # NOTE(review): attributes are set before this call, as in the original;
        # presumably the base initializer may invoke get_lr() -- confirm.
        super().__init__(start_lr, last_epoch, verbose)

    def state_dict(self):
        """
        Returns the state of the LinearWarmup scheduler as a :class:`dict`.

        It is a subset of ``self.__dict__`` . When the post-warm-up learning
        rate is itself an ``LRScheduler`` , its state is nested under the
        ``"LinearWarmup_LR"`` key.
        """
        state_dict = super().state_dict()
        if isinstance(self.learning_rate, LRScheduler):
            state_dict["LinearWarmup_LR"] = self.learning_rate.state_dict()
        return state_dict

    def set_state_dict(self, state_dict):
        """
        Loads state_dict for LinearWarmup scheduler.

        When the post-warm-up learning rate is an ``LRScheduler`` , its state is
        restored from the ``"LinearWarmup_LR"`` entry of ``state_dict`` .
        """
        super().set_state_dict(state_dict)
        if isinstance(self.learning_rate, LRScheduler):
            self.learning_rate.set_state_dict(state_dict["LinearWarmup_LR"])

    def get_lr(self):
        """
        Returns the learning rate for ``self.last_epoch`` .

        During warm-up the value is linearly interpolated between ``start_lr``
        and ``end_lr`` . Afterwards it is either the constant ``learning_rate``
        or, for a wrapped ``LRScheduler`` , the result of stepping that
        scheduler to the number of epochs elapsed since warm-up ended and then
        calling it.
        """
        if self.last_epoch < self.warmup_steps:
            return (self.end_lr - self.start_lr) * float(
                self.last_epoch
            ) / float(self.warmup_steps) + self.start_lr

        if isinstance(self.learning_rate, LRScheduler):
            # The wrapped scheduler counts epochs from the end of warm-up.
            self.learning_rate.step(self.last_epoch - self.warmup_steps)
            return self.learning_rate()

        return self.learning_rate


851
class ExponentialDecay(LRScheduler):
    r"""

    Multiply the learning rate by `gamma` once per epoch.

    The algorithm can be described as following.

    .. math::

        new\_learning\_rate = last\_learning\_rate * gamma

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
            It should be in interval (0.0, 1.0).
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``ExponentialDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.ExponentialDecay(learning_rate=0.5, gamma=0.9, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.ExponentialDecay(learning_rate=0.5, gamma=0.9, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        # gamma must lie strictly inside (0.0, 1.0); kept as an assertion so the
        # exception type seen by callers is unchanged.
        assert (
            0.0 < gamma < 1.0
        ), " 'gamma' must be in interval (0.0, 1.0) so that the learning rate will decay."
        self.gamma = gamma
        # NOTE(review): gamma is stored before delegating, as in the original;
        # presumably the base initializer may invoke get_lr() -- confirm.
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        """Return ``base_lr`` scaled by ``gamma`` raised to ``last_epoch``."""
        decay_factor = self.gamma**self.last_epoch
        return self.base_lr * decay_factor


933
class MultiStepDecay(LRScheduler):
    """
    Scale the learning rate by ``gamma`` each time ``epoch`` reaches one of the milestones.

    The algorithm can be described as the code below.

    .. code-block:: text

        learning_rate = 0.5
        milestones = [30, 50]
        gamma = 0.1
        if epoch < 30:
            learning_rate = 0.5
        elif epoch < 50:
            learning_rate = 0.05
        else:
            learning_rate = 0.005

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        milestones (tuple|list): List or tuple of each boundaries. Must be increasing.
        gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``MultiStepDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(
        self, learning_rate, milestones, gamma=0.1, last_epoch=-1, verbose=False
    ):
        if not isinstance(milestones, (tuple, list)):
            raise TypeError(
                "The type of 'milestones' in 'MultiStepDecay' must be 'tuple, list', but received %s."
                % type(milestones)
            )

        # Compare every neighbouring pair; all comparisons are evaluated before
        # the verdict, exactly like the index-based check this replaces.
        pair_is_increasing = [
            earlier < later
            for earlier, later in zip(milestones, milestones[1:])
        ]
        if not all(pair_is_increasing):
            raise ValueError('The elements of milestones must be incremented')
        if gamma >= 1.0:
            raise ValueError('gamma should be < 1.0.')

        self.milestones = milestones
        self.gamma = gamma
        # NOTE(review): attributes are stored before delegating, as in the
        # original; presumably the base initializer may invoke get_lr() -- confirm.
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        """Return ``base_lr * gamma ** k`` where ``k`` counts the milestones already reached."""
        for passed, boundary in enumerate(self.milestones):
            if self.last_epoch < boundary:
                return self.base_lr * (self.gamma**passed)
        # Every milestone has been reached.
        return self.base_lr * (self.gamma ** len(self.milestones))
1041 1042


1043
class StepDecay(LRScheduler):
    """
    Scale the learning rate of ``optimizer`` by ``gamma`` every ``step_size`` number of epoch.

    The algorithm can be described as the code below.

    .. code-block:: text

        learning_rate = 0.5
        step_size = 30
        gamma = 0.1

        learning_rate = 0.5     if epoch < 30
        learning_rate = 0.05    if 30 <= epoch < 60
        learning_rate = 0.005   if 60 <= epoch < 90
        ...

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        step_size (int): the interval to update. It must be a positive integer.
        gamma (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``StepDecay`` instance to schedule learning rate.


    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=5, gamma=0.8, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=5, gamma=0.8, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(
        self, learning_rate, step_size, gamma=0.1, last_epoch=-1, verbose=False
    ):
        if not isinstance(step_size, int):
            raise TypeError(
                "The type of 'step_size' must be 'int', but received %s."
                % type(step_size)
            )
        if gamma >= 1.0:
            raise ValueError('gamma should be < 1.0.')

        # step_size is known to be an int at this point, so only the sign is
        # left to verify; kept as an assertion to preserve the exception type.
        assert step_size > 0, " 'step_size' must be a positive integer."

        self.step_size, self.gamma = step_size, gamma
        # NOTE(review): attributes are stored before delegating, as in the
        # original; presumably the base initializer may invoke get_lr() -- confirm.
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        """Return ``base_lr * gamma ** (last_epoch // step_size)``."""
        completed_intervals = self.last_epoch // self.step_size
        return self.base_lr * (self.gamma**completed_intervals)


1145
class LambdaDecay(LRScheduler):
    """
    Sets the learning rate of ``optimizer`` by function ``lr_lambda`` . ``lr_lambda`` is a function which receives ``epoch`` .

    The algorithm can be described as the code below.

    .. code-block:: text

        learning_rate = 0.5        # init learning_rate
        lr_lambda = lambda epoch: 0.95 ** epoch

        learning_rate = 0.5        # epoch 0, 0.5*0.95**0
        learning_rate = 0.475      # epoch 1, 0.5*0.95**1
        learning_rate = 0.45125    # epoch 2, 0.5*0.95**2

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        lr_lambda (function): A function which computes a factor by ``epoch`` , and then multiply the initial learning rate by this factor.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``LambdaDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.LambdaDecay(learning_rate=0.5, lr_lambda=lambda x:0.95**x, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.LambdaDecay(learning_rate=0.5, lr_lambda=lambda x:0.95**x, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
              # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(self, learning_rate, lr_lambda, last_epoch=-1, verbose=False):
        # Guard clause: anything callable is accepted, not only plain functions.
        if not callable(lr_lambda):
            raise TypeError(
                "The type of 'lr_lambda' in 'LambdaDecay' must be 'function', but received %s."
                % type(lr_lambda)
            )
        self.lr_lambda = lr_lambda
        # NOTE(review): lr_lambda is stored before delegating, as in the
        # original; presumably the base initializer may invoke get_lr() -- confirm.
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        """Return ``base_lr`` multiplied by ``lr_lambda(last_epoch)``."""
        scale = self.lr_lambda(self.last_epoch)
        return self.base_lr * scale


1234
class ReduceOnPlateau(LRScheduler):
1235
    """
G
guguguzi 已提交
1236
    Reduce learning rate when ``metrics`` has stopped descending. Models often benefit from reducing the learning rate
1237 1238
    by 2 to 10 times once model performance has no longer improvement.

G
guguguzi 已提交
1239 1240 1241
    The ``metrics`` is the one which has been pass into ``step`` , it must be 1-D Tensor with shape [1]. When ``metrics``
    stop descending for a ``patience`` number of epochs, the learning rate will be reduced to ``learning_rate * factor`` .
    (Specially, ``mode`` can also be set to ``'max`` , in this case, when ``metrics`` stop ascending for a ``patience``
1242 1243 1244 1245 1246 1247
    number of epochs, the learning rate will be reduced.)

    In addition, After each reduction, it will wait a ``cooldown`` number of epochs before resuming above operation.

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
G
guguguzi 已提交
1248 1249
        mode (str, optional): ``'min'`` or ``'max'`` can be selected. Normally, it is ``'min'`` , which means that the
            learning rate will reduce when ``loss`` stops descending. Specially, if it's set to ``'max'`` ,  the learning
1250
            rate will reduce when ``loss`` stops ascending. Default: ``'min'`` .
G
guguguzi 已提交
1251
        factor (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * factor`` .
1252
            It should be less than 1.0. Default: 0.1.
G
guguguzi 已提交
1253
        patience (int, optional): When ``loss`` doesn't improve for this number of epochs, learing rate will be reduced.
1254
            Default: 10.
G
guguguzi 已提交
1255
        threshold (float, optional): ``threshold`` and ``threshold_mode`` will determine the minimum change of ``loss`` .
1256 1257
            This make tiny changes of ``loss`` will be ignored. Default: 1e-4.
        threshold_mode (str, optional): ``'rel'`` or ``'abs'`` can be selected. In ``'rel'`` mode, the minimum change of ``loss``
G
guguguzi 已提交
1258
            is ``last_loss * threshold`` , where ``last_loss`` is ``loss`` in last epoch. In ``'abs'`` mode, the minimum
1259 1260 1261
            change of ``loss`` is ``threshold`` . Default: ``'rel'`` .
        cooldown (int, optional): The number of epochs to wait before resuming normal operation. Default: 0.
        min_lr (float, optional): The lower bound of the learning rate after reduction. Default: 0.
G
guguguzi 已提交
1262
        epsilon (float, optional): Minimal decay applied to lr. If the difference between new and old lr is smaller than epsilon,
1263
            the update is ignored. Default: 1e-8.
1264 1265
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``.

G
guguguzi 已提交
1266

1267
    Returns:
1268
        ``ReduceOnPlateau`` instance to schedule learning rate.
1269 1270 1271 1272 1273 1274 1275 1276


    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

1277
            # train on default dynamic graph mode
1278
            linear = paddle.nn.Linear(10, 10)
1279 1280
            scheduler = paddle.optimizer.lr.ReduceOnPlateau(learning_rate=1.0, factor=0.5, patience=5, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
1281
            for epoch in range(20):
Z
Zhou Wei 已提交
1282
                for batch_id in range(5):
1283
                    x = paddle.uniform([10, 10])
1284
                    out = linear(x)
C
chentianyu03 已提交
1285
                    loss = paddle.mean(out)
1286
                    loss.backward()
1287 1288
                    sgd.step()
                    sgd.clear_gradients()
Z
Zhou Wei 已提交
1289 1290
                    scheduler.step(loss)    # If you update learning rate each step
              # scheduler.step(loss)        # If you update learning rate each epoch
1291

1292
            # train on static graph mode
1293 1294 1295 1296
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
1297 1298
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
1299 1300
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
1301
                scheduler = paddle.optimizer.lr.ReduceOnPlateau(learning_rate=1.0, factor=0.5, patience=5, verbose=True)
1302 1303 1304 1305 1306 1307
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
Z
Zhou Wei 已提交
1308
                for batch_id in range(5):
1309 1310 1311 1312 1313 1314
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
1315
                        fetch_list=loss.name)
Z
Zhou Wei 已提交
1316 1317
                    scheduler.step(out[0])    # If you update learning rate each step
              # scheduler.step(out[0])        # If you update learning rate each epoch
1318 1319 1320

    """

1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333
    def __init__(
        self,
        learning_rate,
        mode='min',
        factor=0.1,
        patience=10,
        threshold=1e-4,
        threshold_mode='rel',
        cooldown=0,
        min_lr=0,
        epsilon=1e-8,
        verbose=False,
    ):
1334 1335 1336 1337 1338 1339 1340
        mode = mode.lower()
        if mode not in ['min', 'max']:
            raise ValueError('mode: ' + mode + ' is unknown!')
        self.mode = mode

        if factor >= 1.0:
            raise ValueError(
1341 1342
                'new_lr = origin_lr * gamma and gamma should be < 1.0.'
            )
1343 1344 1345 1346
        self.factor = factor

        threshold_mode = threshold_mode.lower()
        if threshold_mode not in ['rel', 'abs']:
1347 1348 1349
            raise ValueError(
                'threshold mode: ' + threshold_mode + ' is unknown!'
            )
1350 1351 1352
        self.threshold_mode = threshold_mode
        if not isinstance(learning_rate, (float, int)):
            raise TypeError(
1353
                "The type of 'learning_rate' in 'ReduceOnPlateau' must be 'float', but received %s."
1354 1355
                % type(learning_rate)
            )
1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375

        self.patience = patience
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.cooldown = cooldown
        self.min_lr = min_lr
        self.epsilon = epsilon

        self.cooldown_counter = 0
        self.best = None
        self.num_bad_epochs = 0

        # Can not call Parent __init__, so implement here.
        self.base_lr = float(learning_rate)
        self.last_lr = float(learning_rate)
        self.last_epoch = 0
        self.verbose = verbose
        self._var_name = None

    # "cooldown_counter / best / num_bad_epochs / last_epoch / last_lr" will be stored.
1376
    def state_keys(self):
1377
        self.keys = [
1378 1379 1380 1381 1382
            'cooldown_counter',
            'best',
            'num_bad_epochs',
            'last_epoch',
            'last_lr',
1383 1384 1385 1386
        ]

    def step(self, metrics, epoch=None):
        """
        step should be called after `optimizer.step()` . It updates the learning rate according to ``metrics`` .
        The new learning rate will take effect on next epoch.

        Args:
            metrics (Tensor|numpy.ndarray|float): The monitored value that decides whether the learning rate is reduced.
                When it has not improved for more than ``patience`` consecutive calls, the learning rate is multiplied
                by ``factor`` (but never goes below ``min_lr``). If it is a 'Tensor' or 'numpy.ndarray', its shape must be [1].
            epoch (int, None): The current epoch. Default: None, in which case ``last_epoch`` is incremented by one.

        Returns:
            None

        Examples:
            Please refer to the example of current LRScheduler.
        """
        # The epoch counter is advanced before the metrics are validated.
        self.last_epoch = self.last_epoch + 1 if epoch is None else epoch

        # loss must be float, numpy.ndarray or 1-D Tensor with shape [1]
        array_like_types = (core.eager.Tensor, numpy.ndarray)
        scalar_types = (int, float, numpy.float32, numpy.float64)
        if isinstance(metrics, array_like_types):
            metrics_shape = metrics.shape
            assert len(metrics_shape) == 1 and metrics_shape[0] == 1, (
                "the metrics.shape "
                "should be (1L,), but the current metrics.shape is {}. Maybe that "
                "you should call paddle.mean to process it first.".format(
                    metrics.shape
                )
            )
        elif not isinstance(metrics, scalar_types):
            raise TypeError(
                "metrics must be 'int', 'float', 'np.float64', 'numpy.ndarray' or 'paddle.Tensor', but receive {}".format(
                    type(metrics)
                )
            )

        # While cooling down, only the counter moves; metrics are not tracked.
        if self.cooldown_counter > 0:
            self.cooldown_counter -= 1
            return

        improved = self.best is None or self._is_better(metrics, self.best)
        if improved:
            self.best = metrics
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1

        if not (self.num_bad_epochs > self.patience):
            return

        # Patience is exhausted: start the cooldown and try to decay.
        self.cooldown_counter = self.cooldown
        self.num_bad_epochs = 0
        candidate_lr = max(self.last_lr * self.factor, self.min_lr)
        # Decreases no larger than epsilon are ignored.
        if not (self.last_lr - candidate_lr > self.epsilon):
            return

        self.last_lr = candidate_lr
        if self.verbose:
            print(
                'Epoch {}: {} set learning rate to {}.'.format(
                    self.last_epoch,
                    self.__class__.__name__,
                    self.last_lr,
                )
            )
1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462

    def _is_better(self, current, best):
        if self.mode == 'min' and self.threshold_mode == 'rel':
            return current < best - best * self.threshold

        elif self.mode == 'min' and self.threshold_mode == 'abs':
            return current < best - self.threshold

        elif self.mode == 'max' and self.threshold_mode == 'rel':
            return current > best + best * self.threshold

        else:
            return current > best + self.threshold


1463
class CosineAnnealingDecay(LRScheduler):
    r"""

    Schedule the learning rate along a cosine annealing curve. :math:`\eta_{max}` is the
    initial learning_rate and :math:`T_{cur}` is the number of epochs since the last restart in
    SGDR.

    The algorithm can be described as following.

    .. math::

        \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
        + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
        & T_{cur} \neq (2k+1)T_{max};

        \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
        \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
        & T_{cur} = (2k+1)T_{max}.

    It has been proposed in `SGDR: Stochastic Gradient Descent with Warm Restarts <https://arxiv.org/abs/1608.03983>`_.
    Note that this only implements the cosine annealing part of SGDR, and not the restarts.

    Args:
        learning_rate (float): The initial learning rate, that is :math:`\eta_{max}` . It can be set to python float or int number.
        T_max (int): Maximum number of iterations. It is half of the decay cycle of learning rate. It must be a positive integer.
        eta_min (float|int, optional): Minimum learning rate, that is :math:`\eta_{min}` . Default: 0.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``CosineAnnealingDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.5, T_max=10, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.5, T_max=10, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(
        self, learning_rate, T_max, eta_min=0, last_epoch=-1, verbose=False
    ):
        if not isinstance(T_max, int):
            t_max_type = type(T_max)
            raise TypeError(
                "The type of 'T_max' in 'CosineAnnealingDecay' must be 'int', but received %s."
                % t_max_type
            )
        if not isinstance(eta_min, (float, int)):
            eta_min_type = type(eta_min)
            raise TypeError(
                "The type of 'eta_min' in 'CosineAnnealingDecay' must be 'float, int', but received %s."
                % eta_min_type
            )
        # T_max is already known to be an int here; only its sign is left to check.
        assert T_max > 0, " 'T_max' must be a positive integer."

        self.eta_min = float(eta_min)
        self.T_max = T_max
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        """Derive the learning rate of ``last_epoch`` from the previous learning rate."""
        step = self.last_epoch
        half_cycle = self.T_max

        if step == 0:
            return self.base_lr

        # One step past the bottom of the curve the ratio below would divide
        # by zero, so this step uses the additive form instead.
        if (step - 1 - half_cycle) % (2 * half_cycle) == 0:
            return (
                self.last_lr
                + (self.base_lr - self.eta_min)
                * (1 - math.cos(math.pi / half_cycle))
                / 2
            )

        current_term = 1 + math.cos(math.pi * step / half_cycle)
        previous_term = 1 + math.cos(math.pi * (step - 1) / half_cycle)
        return (
            current_term / previous_term * (self.last_lr - self.eta_min)
            + self.eta_min
        )

    def _get_closed_form_lr(self):
        """Evaluate the cosine curve directly at ``last_epoch``."""
        amplitude = self.base_lr - self.eta_min
        cosine_term = 1 + math.cos(math.pi * self.last_epoch / self.T_max)
        return self.eta_min + amplitude * cosine_term / 2
G
guguguzi 已提交
1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639


class MultiplicativeDecay(LRScheduler):
    """
    Scale the learning rate of ``optimizer`` by the factor that ``lr_lambda`` returns for each epoch.

    The algorithm can be described as the code below.

    .. code-block:: text

        learning_rate = 0.5        # init learning_rate
        lr_lambda = lambda epoch: 0.95

        learning_rate = 0.5        # epoch 0,
        learning_rate = 0.475      # epoch 1, 0.5*0.95
        learning_rate = 0.45125    # epoch 2, 0.475*0.95

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        lr_lambda (function): A function which computes a factor by ``epoch`` , and then multiply the last learning rate by this factor.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``MultiplicativeDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.MultiplicativeDecay(learning_rate=0.5, lr_lambda=lambda x:0.95, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(self, learning_rate, lr_lambda, last_epoch=-1, verbose=False):
        if not callable(lr_lambda):
            lambda_type = type(lr_lambda)
            raise TypeError(
                "The type of 'lr_lambda' in 'MultiplicativeDecay' must be 'function', but received %s."
                % lambda_type
            )

        self.lr_lambda = lr_lambda
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        """Multiply ``base_lr`` by the factors of epochs 1 through ``last_epoch``."""
        # The product is rebuilt from base_lr on every call, so the result
        # depends only on last_epoch and not on the previously returned value.
        scaled_lr = self.base_lr
        for epoch_index in range(1, self.last_epoch + 1):
            factor = self.lr_lambda(epoch_index)
            scaled_lr = scaled_lr * factor
        return scaled_lr
1651 1652 1653 1654


class OneCycleLR(LRScheduler):
    r"""

    Sets the learning rate according to the one cycle learning rate scheduler.
    The scheduler adjusts the learning rate from an initial learning rate to the maximum learning rate and then
    from that maximum learning rate to the minimum learning rate, which is much less than the initial learning rate.

    It has been proposed in `Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates <https://arxiv.org/abs/1708.07120>`_.

    Please note that the default behaviour of this scheduler follows the fastai implementation of one cycle,
    which claims that “unpublished work has shown even better results by using only two phases”.
    If you want the behaviour of this scheduler to be consistent with the paper, please set ``three_phase=True`` .

    Also note that you should update learning rate each step.

    Args:
        max_learning_rate (float): The maximum learning rate. It is a python float number. Functionally, it defines the initial learning rate by ``divide_factor`` .
        total_steps (int): Number of total training steps.
        divide_factor (float, optional): Initial learning rate will be determined by initial_learning_rate = max_learning_rate / divide_factor. Default: 25.
        end_learning_rate (float, optional): The minimum learning rate during training, it should be much less than initial learning rate. Default: 0.0001.
        phase_pct (float): The percentage of total steps which used to increasing learning rate. Default: 0.3.
        anneal_strategy (str, optional): Strategy of adjusting learning rate.'cos' for cosine annealing, 'linear' for linear annealing. Default: 'cos'.
        three_phase (bool, optional): Whether to use three phase. Default: ``False`` .

            If ``True``:

                1. The learning rate will first increase from initial learning rate to maximum learning rate.
                2. Then it will decrease to initial learning rate. Number of step in this phase is the same as the one in first phase.
                3. Finally, it will decrease to minimum learning rate which is much less than initial learning rate.

            If ``False``:

                1. The learning rate will increase to maximum learning rate.
                2. Then it will directly decrease to minimum learning rate.

        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``OneCycleLR`` instance to schedule learning rate.

    Raises:
        TypeError: If ``max_learning_rate``, ``end_learning_rate`` or ``divide_factor`` is not a float or int,
            if ``total_steps`` is not an int, or if ``phase_pct`` is not a float.
        ValueError: If ``max_learning_rate`` or ``end_learning_rate`` is negative, if ``total_steps`` is not positive,
            if ``phase_pct`` is outside [0, 1] (or not below 0.5 when ``three_phase`` is True),
            or if ``anneal_strategy`` is neither 'cos' nor 'linear'.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(5):
                for batch_id in range(20):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()        # You should update learning rate each step

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(5):
                for batch_id in range(20):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # You should update learning rate each step

    """

    def __init__(
        self,
        max_learning_rate,
        total_steps,
        divide_factor=25.0,
        end_learning_rate=0.0001,
        phase_pct=0.3,
        anneal_strategy='cos',
        three_phase=False,
        last_epoch=-1,
        verbose=False,
    ):
        # Check type and value of max_learning_rate
        if not isinstance(max_learning_rate, (float, int)):
            raise TypeError(
                "'max_learning_rate' must be 'float' or 'int', but received {}".format(
                    type(max_learning_rate)
                )
            )
        if max_learning_rate < 0:
            raise ValueError(
                "'max_learning_rate' must be a non-negative number."
            )

        # Check type and value of end_learning_rate
        if not isinstance(end_learning_rate, (float, int)):
            raise TypeError(
                "'end_learning_rate' must be 'float' or 'int', but received {}".format(
                    type(end_learning_rate)
                )
            )
        if end_learning_rate < 0:
            raise ValueError(
                "'end_learning_rate' must be a non-negative number."
            )

        # Check type and value of total_steps
        if not isinstance(total_steps, int):
            raise TypeError(
                "'total_steps' must be 'int', but received {}".format(
                    type(total_steps)
                )
            )
        if total_steps <= 0:
            raise ValueError("'total_steps' must be a positive integer.")
        self.total_steps = total_steps

        # Check type and value of phase_pct
        if not isinstance(phase_pct, float):
            raise TypeError(
                "'phase_pct' must be 'float', but received {}".format(
                    type(phase_pct)
                )
            )
        if phase_pct < 0 or phase_pct > 1:
            raise ValueError(
                "'phase_pct' must be between 0 and 1, but received {}".format(
                    phase_pct
                )
            )

        # Check type and value of divide_factor
        if not isinstance(divide_factor, (float, int)):
            raise TypeError(
                "'divide_factor' must be 'float' or 'int', but received {}".format(
                    type(divide_factor)
                )
            )

        initial_lr = max_learning_rate / float(divide_factor)
        min_lr = float(end_learning_rate)

        if three_phase:
            if phase_pct >= 0.5:
                raise ValueError(
                    "When three_phase is True, 'phase_pct' must be less than 0.5"
                )
            # start step and end step of each phase.
            self._step_config = [
                0,
                phase_pct * self.total_steps - 1,
                2 * phase_pct * self.total_steps - 2,
                self.total_steps - 1,
                self.total_steps - 1,  # for the last step.
            ]
            # step size of each phase.
            self._steps_size = [
                self._step_config[1] - self._step_config[0],
                self._step_config[2] - self._step_config[1],
                self._step_config[3] - self._step_config[2],
                self._step_config[3]
                - self._step_config[2],  # for the last step.
            ]
            # start lr and end lr of each phase.
            self._lr_config = [
                initial_lr,
                max_learning_rate,
                initial_lr,
                min_lr,
            ]
        else:
            # start step and end step of each phase.
            self._step_config = [
                0,
                phase_pct * self.total_steps - 1,
                self.total_steps - 1,
                self.total_steps - 1,  # for the last step.
            ]
            # step size of each phase.
            self._steps_size = [
                self._step_config[1] - self._step_config[0],
                self._step_config[2] - self._step_config[1],
                self._step_config[2]
                - self._step_config[1],  # for the last step.
            ]
            # start lr and end lr of each phase.
            self._lr_config = [initial_lr, max_learning_rate, min_lr]

        # Check anneal_strategy
        if anneal_strategy == 'cos':
            self.anneal_func = self._cos_annealing
        elif anneal_strategy == 'linear':
            self.anneal_func = self._linear_annealing
        else:
            raise ValueError(
                "'anneal_strategy' must be one of 'cos' or 'linear', but received {}".format(
                    anneal_strategy
                )
            )
        super().__init__(initial_lr, last_epoch, verbose)

    def _cos_annealing(self, start_lr, end_lr, pct):
        """Interpolate from ``start_lr`` (pct=0) to ``end_lr`` (pct=1) along a half cosine."""
        cos_out = math.cos(math.pi * pct) + 1
        return end_lr + (start_lr - end_lr) / 2.0 * cos_out

    def _linear_annealing(self, start_lr, end_lr, pct):
        """Interpolate linearly from ``start_lr`` (pct=0) to ``end_lr`` (pct=1)."""
        return (end_lr - start_lr) * pct + start_lr

    def get_lr(self):
        """Return the learning rate for the step stored in ``last_epoch``.

        Raises:
            ValueError: If the scheduler has been stepped more than
                ``total_steps`` times.
        """
        current_step = self.last_epoch

        if current_step > self.total_steps:
            raise ValueError(
                "Tried to step {} times. However the number of total steps is {}".format(
                    current_step, self.total_steps
                )
            )

        for (i, (end_step, step_size)) in enumerate(
            zip(self._step_config[1:], self._steps_size)
        ):
            # i == len(self._lr_config) - 2 catch the last step, otherwise it will return None.
            if current_step <= end_step or i == len(self._lr_config) - 2:
                if step_size == 0:
                    # A phase whose start and end step coincide (for example
                    # phase_pct * total_steps == 1) has no interior to
                    # interpolate over. Treat it as already finished instead
                    # of dividing by zero.
                    percentage = 1.0
                else:
                    # self._step_config[i] means start step of a phase.
                    percentage = (
                        current_step - self._step_config[i]
                    ) / step_size
                return self.anneal_func(
                    self._lr_config[i], self._lr_config[i + 1], percentage
                )
1892 1893 1894 1895 1896 1897 1898 1899 1900 1901 1902 1903 1904 1905 1906 1907 1908 1909 1910 1911 1912 1913 1914 1915 1916 1917 1918 1919 1920 1921 1922 1923 1924 1925 1926 1927 1928 1929 1930 1931 1932 1933 1934


class CyclicLR(LRScheduler):
    r"""
    Set the learning rate according to the cyclic learning rate (CLR) scheduler.
    The scheduler regards the process of learning rate adjustment as one cycle after another.
    It cycles the learning rate between two boundaries with a constant frequency.
    The distance between the two boundaries can be scaled on a per-iteration or per-cycle basis.

    It has been proposed in `Cyclic Learning Rates for Training Neural Networks <https://arxiv.org/abs/1506.01186>`_.

    According to the paper, the cyclic learning rate schedule has three built-in scale methods:

    * "triangular": A basic triangular cycle without any amplitude scaling.
    * "triangular2": A basic triangular cycle that reduces the initial amplitude by half each cycle.
    * "exp_range": A cycle that scales initial amplitude by scale function which is defined as :math:`gamma^{iterations}` .

    The initial amplitude is defined as max_learning_rate - base_learning_rate.
    Also note that you should update learning rate each step.

    Args:
        base_learning_rate (float): Initial learning rate, which is the lower boundary in the cycle. The paper recommends
            that set the base_learning_rate to 1/3 or 1/4 of max_learning_rate.
        max_learning_rate (float): Maximum learning rate in the cycle. It defines the cycle amplitude as above.
            Since there is some scaling operation during process of learning rate adjustment,
            max_learning_rate may not actually be reached.
        step_size_up (int): Number of training steps, which is used to increase learning rate in a cycle.
            The step size of one cycle will be defined by step_size_up + step_size_down. According to the paper, step
            size should be set as at least 3 or 4 times steps in one epoch.
        step_size_down (int, optional): Number of training steps, which is used to decrease learning rate in a cycle.
            If not specified, its value will be initialized to ``step_size_up`` . Default: None
        mode (str, optional): One of 'triangular', 'triangular2' or 'exp_range'.
            If scale_fn is specified, this argument will be ignored. Default: 'triangular'
        exp_gamma (float, optional): Constant in 'exp_range' scaling function: exp_gamma**iterations.
            Must be a python float. Used only when mode = 'exp_range'. Default: 1.0
        scale_fn (function, optional): A custom scaling function, which is used to replace three built-in methods.
            It should only have one argument. For all x >= 0, 0 <= scale_fn(x) <= 1.
            If specified, then 'mode' will be ignored. Default: None
        scale_mode (str, optional): One of 'cycle' or 'iterations'. Defines whether scale_fn is evaluated on cycle
            number or cycle iterations (total iterations since start of training). It only takes effect when
            scale_fn is specified; the built-in modes choose their own scale mode. Default: 'cycle'
        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``CyclicLR`` instance to schedule learning rate.

    Raises:
        TypeError: If ``max_learning_rate`` is not a float or int, if ``step_size_up`` or
            ``step_size_down`` is not an int, or if ``exp_gamma`` is not a float.
        ValueError: If ``max_learning_rate`` is negative, if ``step_size_up`` or ``step_size_down``
            is not positive, if ``mode`` is invalid while ``scale_fn`` is not given, or if
            ``scale_mode`` is not one of 'cycle' or 'iterations'.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5, max_learning_rate=1.0, step_size_up=15, step_size_down=5, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(5):
                for batch_id in range(20):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()        # You should update learning rate each step

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5,
                    max_learning_rate=1.0, step_size_up=15, step_size_down=5, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(5):
                for batch_id in range(20):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # You should update learning rate each step
    """

    def __init__(
        self,
        base_learning_rate,
        max_learning_rate,
        step_size_up,
        step_size_down=None,
        mode='triangular',
        exp_gamma=1.0,
        scale_fn=None,
        scale_mode='cycle',
        last_epoch=-1,
        verbose=False,
    ):
        # check type and value of max_learning_rate
        if not isinstance(max_learning_rate, (float, int)):
            raise TypeError(
                "'max_learning_rate' must be 'float' or 'int', but received {}".format(
                    type(max_learning_rate)
                )
            )
        if max_learning_rate < 0:
            # Floats and zero are accepted, so the message describes the
            # actual constraint: any non-negative number.
            raise ValueError(
                "'max_learning_rate' must be a non-negative number, but received {}".format(
                    max_learning_rate
                )
            )

        # check type and value of step_size_up
        if not isinstance(step_size_up, int):
            raise TypeError(
                "The type of 'step_size_up' must be int, but received {}".format(
                    type(step_size_up)
                )
            )
        if step_size_up <= 0:
            raise ValueError(
                "'step_size_up' must be a positive integer, but received {}".format(
                    step_size_up
                )
            )

        # check type and value of step_size_down
        if step_size_down is not None:
            if not isinstance(step_size_down, int):
                raise TypeError(
                    "The type of 'step_size_down' must be int, but received {}".format(
                        type(step_size_down)
                    )
                )
            if step_size_down <= 0:
                raise ValueError(
                    "'step_size_down' must be a positive integer, but received {}".format(
                        step_size_down
                    )
                )

        # check type of exp_gamma
        if not isinstance(exp_gamma, float):
            raise TypeError(
                "The type of 'exp_gamma' must be float, but received {}".format(
                    type(exp_gamma)
                )
            )

        # Work in floats from here on so that the cycle arithmetic in
        # get_lr() uses true division semantics throughout.
        step_size_up = float(step_size_up)
        step_size_down = (
            float(step_size_down)
            if step_size_down is not None
            else step_size_up
        )

        self.cycle_size = step_size_up + step_size_down
        # Fraction of one cycle spent increasing the learning rate.
        self.step_up_pct = step_size_up / self.cycle_size
        self.max_lr = float(max_learning_rate)
        self.amplitude = self.max_lr - base_learning_rate

        if (
            mode not in ['triangular', 'triangular2', 'exp_range']
            and scale_fn is None
        ):
            raise ValueError(
                "'mode' is invalid and 'scale_fn' is not specified, make sure one of 'mode' or 'scale_fn' is valid"
            )
        if scale_mode not in ['cycle', 'iterations']:
            raise ValueError(
                "'scale_mode' must be one of 'cycle' or 'iterations'"
            )

        self.mode = mode
        self.gamma = exp_gamma  # only for exp_range mode

        if scale_fn is None:
            # Built-in modes dictate their own scale mode; the user supplied
            # 'scale_mode' is only honoured together with a custom scale_fn.
            if self.mode == 'triangular':
                self.scale_fn = self._triangular_scale_fn
                self.scale_mode = 'cycle'
            elif self.mode == 'triangular2':
                self.scale_fn = self._triangular2_scale_fn
                self.scale_mode = 'cycle'
            elif self.mode == 'exp_range':
                self.scale_fn = self._exp_range_scale_fn
                self.scale_mode = 'iterations'
        else:
            self.scale_fn = scale_fn
            self.scale_mode = scale_mode
        super().__init__(base_learning_rate, last_epoch, verbose)

    def _triangular_scale_fn(self, x):
        """Scale for 'triangular' mode: the amplitude is never changed."""
        return 1.0

    def _triangular2_scale_fn(self, x):
        """Scale for 'triangular2' mode: halve the amplitude each cycle (``x`` is the 1-based cycle number)."""
        return 1 / (2.0 ** (x - 1))

    def _exp_range_scale_fn(self, x):
        """Scale for 'exp_range' mode: ``gamma ** x`` where ``x`` is the iteration count."""
        return self.gamma**x

    def get_lr(self):
        """
        Compute the learning rate for the current iteration.

        Returns:
            float: ``base_lr`` plus the triangular offset of the current
            position inside the cycle, scaled by ``scale_fn``.

        Raises:
            ValueError: If ``self.scale_mode`` is neither 'cycle' nor
                'iterations'.
        """
        iterations = self.last_epoch

        # 1-based index of the current cycle and the fractional position
        # (in [0, 1)) of the current iteration inside that cycle.
        cycle = 1 + iterations // self.cycle_size
        pct_per_cycle = 1.0 + iterations / self.cycle_size - cycle

        if pct_per_cycle <= self.step_up_pct:
            # Rising edge: 0 -> 1 over the step-up part of the cycle.
            scale_factor = pct_per_cycle / self.step_up_pct
        else:
            # Falling edge: 1 -> 0 over the step-down part of the cycle.
            scale_factor = (1 - pct_per_cycle) / (1 - self.step_up_pct)

        base_height = self.amplitude * scale_factor

        # Pick the argument of scale_fn explicitly rather than eval()-ing the
        # attribute, which would execute arbitrary text if it were ever altered.
        if self.scale_mode == 'cycle':
            scale_input = cycle
        elif self.scale_mode == 'iterations':
            scale_input = iterations
        else:
            raise ValueError(
                "'scale_mode' must be one of 'cycle' or 'iterations'"
            )

        lr = self.base_lr + base_height * self.scale_fn(scale_input)

        return lr