# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import warnings

import numpy

import paddle.fluid.core as core
from paddle import Tensor

# Public scheduler classes exported by ``from paddle.optimizer.lr import *``.
__all__ = [  # noqa
    'LRScheduler',
    'NoamDecay',
    'PiecewiseDecay',
    'NaturalExpDecay',
    'InverseTimeDecay',
    'PolynomialDecay',
    'LinearWarmup',
    'ExponentialDecay',
    'MultiStepDecay',
    'StepDecay',
    'LambdaDecay',
    'ReduceOnPlateau',
    'CosineAnnealingDecay',
    'MultiplicativeDecay',
    'OneCycleLR',
    'CyclicLR',
]


class LRScheduler:
    """

    LRScheduler base class. Defines the common interface of a learning rate scheduler.

    Users can import it with ``from paddle.optimizer.lr import LRScheduler``,
    then subclass it and provide a custom implementation of ``get_lr()``.
    Otherwise, a ``NotImplementedError`` exception will be raised.

    Args:
        learning_rate (float|int): The initial learning rate. It is converted to a python float.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        instance to schedule learning rate.

    Examples:
        Here is an example of a simple ``StepDecay`` implementation.

        .. code-block:: python

            import paddle
            from paddle.optimizer.lr import LRScheduler

            class StepDecay(LRScheduler):
                def __init__(self,
                            learning_rate,
                            step_size,
                            gamma=0.1,
                            last_epoch=-1,
                            verbose=False):
                    if not isinstance(step_size, int):
                        raise TypeError(
                            "The type of 'step_size' must be 'int', but received %s." %
                            type(step_size))
                    if gamma >= 1.0:
                        raise ValueError('gamma should be < 1.0.')

                    self.step_size = step_size
                    self.gamma = gamma
                    super().__init__(learning_rate, last_epoch, verbose)

                def get_lr(self):
                    i = self.last_epoch // self.step_size
                    return self.base_lr * (self.gamma**i)

    """

    def __init__(self, learning_rate=0.1, last_epoch=-1, verbose=False):
        if not isinstance(learning_rate, (float, int)):
            raise TypeError(
                "The type of learning rate must be float, but received {}".format(
                    type(learning_rate)
                )
            )
        self.base_lr = float(learning_rate)
        self.last_lr = float(learning_rate)
        self.last_epoch = last_epoch
        self.verbose = verbose
        self._var_name = None

        # Advance once so that ``last_epoch`` moves from its starting value
        # (-1 by default, i.e. epoch 0 after this call) and ``last_lr`` holds
        # the value computed by the subclass's ``get_lr()``.
        self.step()

    def __call__(self):
        """
        Return the latest computed learning rate for the current epoch.
        """
        return self.last_lr

    def step(self, epoch=None):
        """

        ``step`` should be called after ``optimizer.step`` . It updates the learning rate in the optimizer according to the current ``epoch`` .
        The new learning rate takes effect on the next ``optimizer.step`` .

        Args:
            epoch (int, None): Specify the current epoch. Default: None, which auto-increments from ``last_epoch`` (initially -1).

        If ``epoch`` is None, ``last_epoch`` is incremented and ``get_lr()`` is used.
        Otherwise ``last_epoch`` is set to ``epoch`` and, when the subclass defines
        ``_get_closed_form_lr()``, that method is used instead of ``get_lr()``.

        Returns:
            None

        """
        if epoch is None:
            self.last_epoch += 1
            self.last_lr = self.get_lr()
        else:
            self.last_epoch = epoch
            # Jumping to an arbitrary epoch: prefer a closed-form formula when
            # the subclass provides one.
            if hasattr(self, "_get_closed_form_lr"):
                self.last_lr = self._get_closed_form_lr()
            else:
                self.last_lr = self.get_lr()

        if self.verbose:
            print(
                'Epoch {}: {} set learning rate to {}.'.format(
                    self.last_epoch, self.__class__.__name__, self.last_lr
                )
            )

    def state_dict(self):
        """

        Returns the state of the scheduler as a :class:`dict`.

        It is a subset of ``self.__dict__`` restricted to the names listed by
        ``state_keys()``. Keys missing from ``self.__dict__`` are skipped, and
        ``Tensor`` values (which must have shape ``[1]``) are converted to a
        scalar via ``value.numpy()[0]``.
        """
        self.state_keys()
        state_dict = {}
        for key in self.keys:
            if key not in self.__dict__:
                continue
            value = self.__dict__[key]
            if isinstance(value, Tensor):
                assert value.shape == [
                    1
                ], "shape of Tensor in state_dict must be [1] {}".format(
                    value.shape
                )
                value = value.numpy()[0]
            state_dict[key] = value

        return state_dict

    # For subclasses that overload LRScheduler, "last_epoch, last_lr" are saved by default.
    # (Note): you can change it for your subclass.
    def state_keys(self):
        """

        For subclasses that overload ``LRScheduler`` (Base Class). By default, "last_epoch, last_lr" are saved via ``self.keys = ['last_epoch', 'last_lr']`` .

        ``last_epoch`` is the current epoch num, and ``last_lr`` is the current learning rate.

        To change the default behavior, override ``state_keys()`` in the subclass to redefine ``self.keys`` .

        """
        self.keys = ['last_epoch', 'last_lr']

    def set_state_dict(self, state_dict):
        """

        Loads the scheduler's state.

        Every key listed by ``state_keys()`` must be present in ``state_dict``,
        otherwise a ``RuntimeError`` is raised. A warning is issued when
        ``state_dict`` holds more entries than ``state_keys()`` lists.
        """
        self.state_keys()
        for key in self.keys:
            if key in state_dict:
                self.__dict__[key] = state_dict[key]
            else:
                raise RuntimeError(
                    "Please check whether state_dict is correct for optimizer. Can't find [ {} ] in state_dict".format(
                        key
                    )
                )
        if len(state_dict) > len(self.keys):
            warnings.warn(
                "There are some unused values in state_dict. Maybe the optimizer have different 'LearningRateDecay' when invoking state_dict and set_dict"
            )

    # alias for set_state_dict
    set_dict = set_state_dict

    def get_lr(self):
        """

        For subclasses that overload ``LRScheduler`` (Base Class), users should provide a custom implementation of ``get_lr()`` .

        Otherwise, a ``NotImplementedError`` exception will be raised.
        """
        # calculate by python float
        raise NotImplementedError


class NoamDecay(LRScheduler):
    r"""

    Applies Noam Decay to the initial learning rate.

    The algorithm can be described as following.

    .. math::

        new\_learning\_rate = learning\_rate * d_{model}^{-0.5} * min(epoch^{-0.5}, epoch * warmup\_steps^{-1.5})

    At epoch 0 the term :math:`epoch^{-0.5}` is replaced by 1 (to avoid dividing by zero),
    so the learning rate computed for epoch 0 is 0.

    Please reference `attention is all you need <https://arxiv.org/pdf/1706.03762.pdf>`_


    Args:
        d_model (int|float): The dimensionality of the model's input and output feature vectors. It must be greater than 0.
        warmup_steps (int): The number of warmup steps. A hyperparameter. It must be greater than 0.
        learning_rate (float): The initial learning rate. It is a python float number. Default: 1.0.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``NoamDecay`` instance to schedule learning rate.

    Raises:
        ValueError: If ``d_model`` or ``warmup_steps`` is not greater than 0.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.NoamDecay(d_model=0.01, warmup_steps=100, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.NoamDecay(d_model=0.01, warmup_steps=100, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(
        self,
        d_model,
        warmup_steps,
        learning_rate=1.0,
        last_epoch=-1,
        verbose=False,
    ):
        if d_model <= 0:
            raise ValueError("d_model should be grater than 0")
        # warmup_steps ** -1.5 is undefined for 0 (ZeroDivisionError) and
        # yields a complex number for negative values, so reject both up front.
        if warmup_steps <= 0:
            raise ValueError("warmup_steps should be greater than 0")

        self.d_model = d_model
        self.warmup_steps = warmup_steps
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        # epoch ** -0.5 is undefined at epoch 0; use 1 so that min() selects
        # the warmup term, which is 0 at epoch 0.
        if self.last_epoch == 0:
            a = 1
        else:
            a = self.last_epoch**-0.5
        b = self.warmup_steps**-1.5 * self.last_epoch
        return self.base_lr * (self.d_model**-0.5) * min(a, b)


class PiecewiseDecay(LRScheduler):
    """

    Piecewise learning rate scheduler.

    The algorithm can be described as the code below:

    .. code-block:: text

        boundaries = [100, 200]
        values = [1.0, 0.5, 0.1]
        if epoch < 100:
            learning_rate = 1.0
        elif 100 <= epoch < 200:
            learning_rate = 0.5
        else:
            learning_rate = 0.1

    Args:
        boundaries(list|tuple): A list/tuple of step numbers in ascending order. The type of element in the list is python int.
        values(list|tuple): A list/tuple of learning rate values that will be picked during different epoch boundaries.
            The type of element in the list is python float. It should contain ``len(boundaries) + 1`` values:
            ``values[i]`` is used while ``epoch < boundaries[i]``, and ``values[-1]`` once the last boundary is reached.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``PiecewiseDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[3, 6, 9], values=[0.1, 0.2, 0.3, 0.4], verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.PiecewiseDecay(boundaries=[3, 6, 9], values=[0.1, 0.2, 0.3, 0.4], verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self, boundaries, values, last_epoch=-1, verbose=False):
        self.boundaries = boundaries
        self.values = values
        # The base learning rate is not used by get_lr(); the base class
        # default is passed through unchanged.
        super().__init__(last_epoch=last_epoch, verbose=verbose)

    def get_lr(self):
        # Pick the value of the first interval whose upper boundary lies
        # beyond the current epoch; past the last boundary use the final value.
        for i, boundary in enumerate(self.boundaries):
            if self.last_epoch < boundary:
                return self.values[i]
        return self.values[-1]


class NaturalExpDecay(LRScheduler):
    r"""

    Applies natural exponential decay to the initial learning rate.

    The algorithm can be described as following:

    .. math::

        new\_learning\_rate = learning\_rate * e^{- gamma * epoch}

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float): A ratio to update the learning rate. It must be greater than 0.0 so that the learning rate decays.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``NaturalExpDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.NaturalExpDecay(learning_rate=0.5, gamma=0.1, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.NaturalExpDecay(learning_rate=0.5, gamma=0.1, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        assert (
            gamma > 0.0
        ), " 'gamma' must be a positive number so that the learning rate will decay."
        self.gamma = gamma
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        # lr = base_lr * e^(-gamma * epoch)
        return self.base_lr * math.exp(-1 * self.gamma * self.last_epoch)


class InverseTimeDecay(LRScheduler):
    r"""

    Applies inverse time decay to the initial learning rate.

    The algorithm can be described as following:

    .. math::

        new\_learning\_rate = \frac{learning\_rate}{1 + gamma * epoch}

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float): The decay rate in the denominator above, i.e.
            ``new_lr = learning_rate / (1 + gamma * epoch)``. A positive value makes the learning rate decay.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``InverseTimeDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.InverseTimeDecay(learning_rate=0.5, gamma=0.1, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.InverseTimeDecay(learning_rate=0.5, gamma=0.1, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        self.gamma = gamma
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        # lr = base_lr / (1 + gamma * epoch)
        return self.base_lr / (1 + self.gamma * self.last_epoch)


class PolynomialDecay(LRScheduler):
    r"""

    Applies polynomial decay to the initial learning rate.

    The algorithm can be described as following.

    If cycle is set to True, then:

    .. math::

        decay\_steps = decay\_steps * \lceil \frac{epoch}{decay\_steps} \rceil

        new\_learning\_rate = (learning\_rate-end\_lr)*(1-\frac{epoch}{decay\_steps})^{power}+end\_lr

    (at epoch 0 the ceiling factor is taken as 1).

    If cycle is set to False, then:

    .. math::

        epoch = min(epoch, decay\_steps)

        new\_learning\_rate = (learning\_rate-end\_lr)*(1-\frac{epoch}{decay\_steps})^{power}+end\_lr


    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        decay_steps(int): The decay step size. It determines the decay cycle. It must be a positive integer.
        end_lr(float, optional): The minimum final learning rate. Default: 0.0001.
        power(float, optional): Power of polynomial. It must be greater than 0.0 so that the learning rate decays. Default: 1.0.
        cycle(bool, optional): Whether the learning rate rises again. If True, the learning rate rises again after it decreases
            to ``end_lr`` .  If False, the learning rate decreases monotonically. Default: False.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``PolynomialDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.5, decay_steps=20, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.PolynomialDecay(learning_rate=0.5, decay_steps=20, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(
        self,
        learning_rate,
        decay_steps,
        end_lr=0.0001,
        power=1.0,
        cycle=False,
        last_epoch=-1,
        verbose=False,
    ):
        assert decay_steps > 0 and isinstance(
            decay_steps, int
        ), " 'decay_steps' must be a positive integer."
        self.decay_steps = decay_steps
        self.end_lr = end_lr
        assert (
            power > 0.0
        ), " 'power' must be greater than 0.0 so that the learning rate will decay."
        self.power = power
        self.cycle = cycle
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        tmp_epoch_num = self.last_epoch
        tmp_decay_steps = self.decay_steps
        if self.cycle:
            # Stretch the decay horizon to the next multiple of decay_steps so
            # the schedule restarts each cycle; at epoch 0 ceil() would give 0
            # (a zero divisor below), so use a single cycle instead.
            div_res = math.ceil(
                float(self.last_epoch) / float(self.decay_steps)
            )

            if self.last_epoch == 0:
                div_res = 1
            tmp_decay_steps = self.decay_steps * div_res
        else:
            # Without cycling, hold the learning rate at end_lr once
            # decay_steps epochs have elapsed.
            tmp_epoch_num = min(self.last_epoch, self.decay_steps)

        return (self.base_lr - self.end_lr) * (
            (1 - float(tmp_epoch_num) / float(tmp_decay_steps)) ** self.power
        ) + self.end_lr


class LinearWarmup(LRScheduler):
    r"""

    Linear learning rate warm up strategy. Update the learning rate preliminarily before the normal learning rate scheduler.
    For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_

    When epoch < warmup_steps, learning rate is updated as:

    .. math::

            lr = start\_lr + (end\_lr - start\_lr) * \frac{epoch}{warmup\_steps}

    where start_lr is the initial learning rate, and end_lr is the final learning rate;

    When epoch >= warmup_steps, learning rate is updated as:

    .. math::

            lr = learning\_rate

    where ``learning_rate`` is float or any subclass of ``LRScheduler`` .

    Args:
        learning_rate (float|LRScheduler): The learning rate after warm-up. It is a python float number or any subclass of ``LRScheduler`` .
        warmup_steps (int): total steps of warm up. It must be a positive integer.
        start_lr (float): Initial learning rate of warm up.
        end_lr (float): Final learning rate of warm up. It must be greater than ``start_lr`` .
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``LinearWarmup`` instance to schedule learning rate.

    Raises:
        TypeError: If ``learning_rate`` is not an int, float or ``LRScheduler`` instance.
        AssertionError: If ``warmup_steps`` is not a positive integer, or ``end_lr`` is not greater than ``start_lr`` .

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.LinearWarmup(
                    learning_rate=0.5, warmup_steps=20, start_lr=0, end_lr=0.5, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.LinearWarmup(
                    learning_rate=0.5, warmup_steps=20, start_lr=0, end_lr=0.5, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch
    """

    def __init__(
        self,
        learning_rate,
        warmup_steps,
        start_lr,
        end_lr,
        last_epoch=-1,
        verbose=False,
    ):
        if not isinstance(learning_rate, (float, int, LRScheduler)):
            # Report the offending *type*, as the message promises; formatting the
            # value itself is misleading (and can be huge for arrays/tensors).
            raise TypeError(
                "the type of learning_rate should be [int, float or LRScheduler], the current type is {}".format(
                    type(learning_rate)
                )
            )
        self.learning_rate = learning_rate
        assert warmup_steps > 0 and isinstance(
            warmup_steps, int
        ), " 'warmup_steps' must be a positive integer."
        self.warmup_steps = warmup_steps
        self.start_lr = start_lr
        self.end_lr = end_lr
        assert (
            end_lr > start_lr
        ), "end_lr {} must be greater than start_lr {}".format(end_lr, start_lr)
        # The base scheduler is initialised with start_lr, i.e. the warm-up starting point.
        super().__init__(start_lr, last_epoch, verbose)

    def state_dict(self):
        """
        Returns the state of the LinearWarmup scheduler as a :class:`dict`.

        It is a subset of ``self.__dict__`` . When the post-warm-up learning rate
        is itself an ``LRScheduler``, its state is nested under ``"LinearWarmup_LR"``.
        """
        state_dict = super().state_dict()
        if isinstance(self.learning_rate, LRScheduler):
            state_dict["LinearWarmup_LR"] = self.learning_rate.state_dict()
        return state_dict

    def set_state_dict(self, state_dict):
        """
        Loads state_dict for LinearWarmup scheduler.

        When the post-warm-up learning rate is an ``LRScheduler``, its state is
        restored from the ``"LinearWarmup_LR"`` entry.
        """
        super().set_state_dict(state_dict)
        if isinstance(self.learning_rate, LRScheduler):
            self.learning_rate.set_state_dict(state_dict["LinearWarmup_LR"])

    def get_lr(self):
        """
        Return the learning rate for ``self.last_epoch``.

        During warm-up, linearly interpolate from ``start_lr`` to ``end_lr``.
        Afterwards, return the float ``learning_rate``. If ``learning_rate`` is a
        wrapped scheduler, step it with the epoch count measured from the end of
        warm-up and return its value.
        """
        if self.last_epoch < self.warmup_steps:
            return (self.end_lr - self.start_lr) * float(
                self.last_epoch
            ) / float(self.warmup_steps) + self.start_lr

        if isinstance(self.learning_rate, LRScheduler):
            self.learning_rate.step(self.last_epoch - self.warmup_steps)
            return self.learning_rate()

        return self.learning_rate


843
class ExponentialDecay(LRScheduler):
    r"""

    Multiply the learning rate by ``gamma`` once per epoch.

    Starting from the initial learning rate, each epoch applies

    .. math::

        new\_learning\_rate = last\_learning\_rate * gamma

    which, in closed form, gives ``learning_rate * gamma ** epoch`` .

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        gamma (float): The ratio by which the learning rate is reduced, i.e. ``new_lr = origin_lr * gamma`` .
            It must lie in the open interval (0.0, 1.0).
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``ExponentialDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.ExponentialDecay(learning_rate=0.5, gamma=0.9, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.ExponentialDecay(learning_rate=0.5, gamma=0.9, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch
    """

    def __init__(self, learning_rate, gamma, last_epoch=-1, verbose=False):
        # Outside (0, 1) the schedule would grow, stay flat, or flip sign.
        assert (
            0.0 < gamma < 1.0
        ), " 'gamma' must be in interval (0.0, 1.0) so that the learning rate will decay."
        self.gamma = gamma
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        """Return ``base_lr * gamma ** last_epoch`` (closed form of per-epoch decay)."""
        decay_factor = self.gamma**self.last_epoch
        return self.base_lr * decay_factor


925
class MultiStepDecay(LRScheduler):
    """
    Multiply the learning rate by ``gamma`` each time ``epoch`` passes one of the ``milestones``.

    Equivalent pseudo-code:

    .. code-block:: text

        learning_rate = 0.5
        milestones = [30, 50]
        gamma = 0.1
        if epoch < 30:
            learning_rate = 0.5
        elif epoch < 50:
            learning_rate = 0.05
        else:
            learning_rate = 0.005

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        milestones (tuple|list): Epoch boundaries at which the learning rate decays. Must be strictly increasing.
        gamma (float, optional): The ratio by which the learning rate is reduced, i.e. ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``MultiStepDecay`` instance to schedule learning rate.

    Raises:
        TypeError: If ``milestones`` is not a tuple or list.
        ValueError: If ``milestones`` is not strictly increasing, or ``gamma >= 1.0`` .

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.MultiStepDecay(learning_rate=0.5, milestones=[2, 4, 6], gamma=0.8, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch
    """

    def __init__(
        self, learning_rate, milestones, gamma=0.1, last_epoch=-1, verbose=False
    ):
        if not isinstance(milestones, (tuple, list)):
            raise TypeError(
                "The type of 'milestones' in 'MultiStepDecay' must be 'tuple, list', but received %s."
                % type(milestones)
            )

        # Compare each adjacent pair (a list, so every comparison is evaluated).
        pairwise_ok = [
            earlier < later
            for earlier, later in zip(milestones[:-1], milestones[1:])
        ]
        if not all(pairwise_ok):
            raise ValueError('The elements of milestones must be incremented')
        if gamma >= 1.0:
            raise ValueError('gamma should be < 1.0.')

        self.milestones = milestones
        self.gamma = gamma
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        """Return ``base_lr * gamma ** k`` where ``k`` is the number of milestones already reached."""
        passed = len(self.milestones)
        for position, boundary in enumerate(self.milestones):
            if self.last_epoch < boundary:
                passed = position
                break
        return self.base_lr * (self.gamma**passed)
1033 1034


1035
class StepDecay(LRScheduler):
    """
    Multiply the learning rate by ``gamma`` after every ``step_size`` epochs.

    Equivalent pseudo-code:

    .. code-block:: text

        learning_rate = 0.5
        step_size = 30
        gamma = 0.1

        learning_rate = 0.5     if epoch < 30
        learning_rate = 0.05    if 30 <= epoch < 60
        learning_rate = 0.005   if 60 <= epoch < 90
        ...

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        step_size (int): Number of epochs between two decays. It must be a positive integer.
        gamma (float, optional): The ratio by which the learning rate is reduced, i.e. ``new_lr = origin_lr * gamma`` .
            It should be less than 1.0. Default: 0.1.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``StepDecay`` instance to schedule learning rate.

    Raises:
        TypeError: If ``step_size`` is not an int.
        ValueError: If ``gamma >= 1.0`` .

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=5, gamma=0.8, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.StepDecay(learning_rate=0.5, step_size=5, gamma=0.8, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch
    """

    def __init__(
        self, learning_rate, step_size, gamma=0.1, last_epoch=-1, verbose=False
    ):
        if not isinstance(step_size, int):
            raise TypeError(
                "The type of 'step_size' must be 'int', but received %s."
                % type(step_size)
            )
        if gamma >= 1.0:
            raise ValueError('gamma should be < 1.0.')

        # step_size is known to be an int at this point; only its sign remains to check.
        assert step_size > 0, " 'step_size' must be a positive integer."
        self.step_size = step_size
        self.gamma = gamma
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        """Return ``base_lr * gamma ** (last_epoch // step_size)``."""
        completed_intervals = self.last_epoch // self.step_size
        return self.base_lr * self.gamma**completed_intervals


1137
class LambdaDecay(LRScheduler):
    """
    Scale the initial learning rate by a user-supplied factor ``lr_lambda(epoch)`` .

    ``lr_lambda`` is a function that receives the current ``epoch`` and returns a
    multiplicative factor. The factor is applied to the *initial* learning rate,
    not the previous one.

    Equivalent pseudo-code:

    .. code-block:: text

        learning_rate = 0.5        # init learning_rate
        lr_lambda = lambda epoch: 0.95 ** epoch

        learning_rate = 0.5        # epoch 0, 0.5*0.95**0
        learning_rate = 0.475      # epoch 1, 0.5*0.95**1
        learning_rate = 0.45125    # epoch 2, 0.5*0.95**2

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        lr_lambda (function): A function that computes a factor from ``epoch`` ; the initial learning rate is multiplied by this factor.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``LambdaDecay`` instance to schedule learning rate.

    Raises:
        TypeError: If ``lr_lambda`` is not callable.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.LambdaDecay(learning_rate=0.5, lr_lambda=lambda x:0.95**x, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.LambdaDecay(learning_rate=0.5, lr_lambda=lambda x:0.95**x, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()      # If you update learning rate each epoch

    """

    def __init__(self, learning_rate, lr_lambda, last_epoch=-1, verbose=False):
        if callable(lr_lambda):
            self.lr_lambda = lr_lambda
        else:
            raise TypeError(
                "The type of 'lr_lambda' in 'LambdaDecay' must be 'function', but received %s."
                % type(lr_lambda)
            )
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        """Return ``base_lr * lr_lambda(last_epoch)``."""
        factor = self.lr_lambda(self.last_epoch)
        return self.base_lr * factor


1226
class ReduceOnPlateau(LRScheduler):
1227
    """
G
guguguzi 已提交
1228
    Reduce learning rate when ``metrics`` has stopped descending. Models often benefit from reducing the learning rate
1229 1230
    by 2 to 10 times once model performance has no longer improvement.

G
guguguzi 已提交
1231 1232 1233
    The ``metrics`` is the one which has been pass into ``step`` , it must be 1-D Tensor with shape [1]. When ``metrics``
    stop descending for a ``patience`` number of epochs, the learning rate will be reduced to ``learning_rate * factor`` .
    (Specially, ``mode`` can also be set to ``'max`` , in this case, when ``metrics`` stop ascending for a ``patience``
1234 1235 1236 1237 1238 1239
    number of epochs, the learning rate will be reduced.)

    In addition, After each reduction, it will wait a ``cooldown`` number of epochs before resuming above operation.

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
G
guguguzi 已提交
1240 1241
        mode (str, optional): ``'min'`` or ``'max'`` can be selected. Normally, it is ``'min'`` , which means that the
            learning rate will reduce when ``loss`` stops descending. Specially, if it's set to ``'max'`` ,  the learning
1242
            rate will reduce when ``loss`` stops ascending. Default: ``'min'`` .
G
guguguzi 已提交
1243
        factor (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * factor`` .
1244
            It should be less than 1.0. Default: 0.1.
G
guguguzi 已提交
1245
        patience (int, optional): When ``loss`` doesn't improve for this number of epochs, learing rate will be reduced.
1246
            Default: 10.
G
guguguzi 已提交
1247
        threshold (float, optional): ``threshold`` and ``threshold_mode`` will determine the minimum change of ``loss`` .
1248 1249
            This make tiny changes of ``loss`` will be ignored. Default: 1e-4.
        threshold_mode (str, optional): ``'rel'`` or ``'abs'`` can be selected. In ``'rel'`` mode, the minimum change of ``loss``
G
guguguzi 已提交
1250
            is ``last_loss * threshold`` , where ``last_loss`` is ``loss`` in last epoch. In ``'abs'`` mode, the minimum
1251 1252 1253
            change of ``loss`` is ``threshold`` . Default: ``'rel'`` .
        cooldown (int, optional): The number of epochs to wait before resuming normal operation. Default: 0.
        min_lr (float, optional): The lower bound of the learning rate after reduction. Default: 0.
G
guguguzi 已提交
1254
        epsilon (float, optional): Minimal decay applied to lr. If the difference between new and old lr is smaller than epsilon,
1255
            the update is ignored. Default: 1e-8.
1256 1257
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``.

G
guguguzi 已提交
1258

1259
    Returns:
1260
        ``ReduceOnPlateau`` instance to schedule learning rate.
1261 1262 1263 1264 1265 1266 1267 1268


    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

1269
            # train on default dynamic graph mode
1270
            linear = paddle.nn.Linear(10, 10)
1271 1272
            scheduler = paddle.optimizer.lr.ReduceOnPlateau(learning_rate=1.0, factor=0.5, patience=5, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
1273
            for epoch in range(20):
Z
Zhou Wei 已提交
1274
                for batch_id in range(5):
1275
                    x = paddle.uniform([10, 10])
1276
                    out = linear(x)
C
chentianyu03 已提交
1277
                    loss = paddle.mean(out)
1278
                    loss.backward()
1279 1280
                    sgd.step()
                    sgd.clear_gradients()
Z
Zhou Wei 已提交
1281 1282
                    scheduler.step(loss)    # If you update learning rate each step
              # scheduler.step(loss)        # If you update learning rate each epoch
1283

1284
            # train on static graph mode
1285 1286 1287 1288
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
1289 1290
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
1291 1292
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
1293
                scheduler = paddle.optimizer.lr.ReduceOnPlateau(learning_rate=1.0, factor=0.5, patience=5, verbose=True)
1294 1295 1296 1297 1298 1299
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
Z
Zhou Wei 已提交
1300
                for batch_id in range(5):
1301 1302 1303 1304 1305 1306
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
1307
                        fetch_list=loss.name)
Z
Zhou Wei 已提交
1308 1309
                    scheduler.step(out[0])    # If you update learning rate each step
              # scheduler.step(out[0])        # If you update learning rate each epoch
1310 1311 1312

    """

1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325
    def __init__(
        self,
        learning_rate,
        mode='min',
        factor=0.1,
        patience=10,
        threshold=1e-4,
        threshold_mode='rel',
        cooldown=0,
        min_lr=0,
        epsilon=1e-8,
        verbose=False,
    ):
1326 1327 1328 1329 1330 1331 1332
        mode = mode.lower()
        if mode not in ['min', 'max']:
            raise ValueError('mode: ' + mode + ' is unknown!')
        self.mode = mode

        if factor >= 1.0:
            raise ValueError(
1333 1334
                'new_lr = origin_lr * gamma and gamma should be < 1.0.'
            )
1335 1336 1337 1338
        self.factor = factor

        threshold_mode = threshold_mode.lower()
        if threshold_mode not in ['rel', 'abs']:
1339 1340 1341
            raise ValueError(
                'threshold mode: ' + threshold_mode + ' is unknown!'
            )
1342 1343 1344
        self.threshold_mode = threshold_mode
        if not isinstance(learning_rate, (float, int)):
            raise TypeError(
1345
                "The type of 'learning_rate' in 'ReduceOnPlateau' must be 'float', but received %s."
1346 1347
                % type(learning_rate)
            )
1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367

        self.patience = patience
        self.threshold = threshold
        self.threshold_mode = threshold_mode
        self.cooldown = cooldown
        self.min_lr = min_lr
        self.epsilon = epsilon

        self.cooldown_counter = 0
        self.best = None
        self.num_bad_epochs = 0

        # Can not call Parent __init__, so implement here.
        self.base_lr = float(learning_rate)
        self.last_lr = float(learning_rate)
        self.last_epoch = 0
        self.verbose = verbose
        self._var_name = None

    # "cooldown_counter / best / num_bad_epochs / last_epoch / last_lr" will be stored.
1368
    def state_keys(self):
1369
        self.keys = [
1370 1371 1372 1373 1374
            'cooldown_counter',
            'best',
            'num_bad_epochs',
            'last_epoch',
            'last_lr',
1375 1376 1377 1378
        ]

    def step(self, metrics, epoch=None):
        """
        step should be called after `optimizer.step()` . It will update the learning rate in optimizer according to ``metrics`` .
        The new learning rate will take effect on next epoch.

        Args:
            metrics (Tensor|numpy.ndarray|float): Which will be monitored to determine whether the learning rate will reduce.
                If it stop descending for a ``patience`` number of epochs, the learning rate will reduce. If it's 'Tensor' or
                'numpy.ndarray', it must hold exactly one value: its shape must be [] (0-D) or [1].
                Python ``int``/``float`` and numpy integer/floating scalars are also accepted.
            epoch (int, None): specify current epoch. Default: None, which increments ``last_epoch`` by 1
                (``last_epoch`` starts at 0 for this scheduler).

        Returns:
            None

        Raises:
            TypeError: If ``metrics`` is not one of the accepted types.

        Examples:
            Please refer to the example of current LRScheduler.
        """
        if epoch is None:
            self.last_epoch = self.last_epoch + 1
        else:
            self.last_epoch = epoch

        # metrics must be a scalar, or a Tensor/ndarray holding exactly one value.
        # 0-D inputs are accepted as well as shape [1] (e.g. the result of paddle.mean).
        if isinstance(metrics, (core.eager.Tensor, numpy.ndarray)):
            shape = list(metrics.shape)
            assert len(shape) == 0 or shape == [1], (
                "the metrics.shape "
                "should be [] or [1], but the current metrics.shape is {}. Maybe that "
                "you should call paddle.mean to process it first.".format(
                    metrics.shape
                )
            )
        elif not isinstance(
            metrics, (int, float, numpy.integer, numpy.floating)
        ):
            raise TypeError(
                "metrics must be 'int', 'float', a numpy integer/floating scalar, 'numpy.ndarray' or 'paddle.Tensor', but receive {}".format(
                    type(metrics)
                )
            )

        if self.cooldown_counter > 0:
            # Still cooling down after the last reduction: ``best`` is not
            # updated and bad epochs are not counted.
            self.cooldown_counter -= 1
        else:
            if self.best is None or self._is_better(metrics, self.best):
                self.best = metrics
                self.num_bad_epochs = 0
            else:
                self.num_bad_epochs += 1

            if self.num_bad_epochs > self.patience:
                self.cooldown_counter = self.cooldown
                self.num_bad_epochs = 0
                new_lr = max(self.last_lr * self.factor, self.min_lr)
                # Only apply reductions larger than ``epsilon``; once ``min_lr``
                # is reached the difference is 0 and nothing changes.
                if self.last_lr - new_lr > self.epsilon:
                    self.last_lr = new_lr
                    if self.verbose:
                        print(
                            'Epoch {}: {} set learning rate to {}.'.format(
                                self.last_epoch,
                                self.__class__.__name__,
                                self.last_lr,
                            )
                        )

    def _is_better(self, current, best):
        if self.mode == 'min' and self.threshold_mode == 'rel':
            return current < best - best * self.threshold

        elif self.mode == 'min' and self.threshold_mode == 'abs':
            return current < best - self.threshold

        elif self.mode == 'max' and self.threshold_mode == 'rel':
            return current > best + best * self.threshold

        else:
            return current > best + self.threshold


1455
class CosineAnnealingDecay(LRScheduler):
    r"""

    Anneal the learning rate along a cosine curve. :math:`\eta_{max}` is the initial
    ``learning_rate`` and :math:`T_{cur}` is the number of epochs since the last restart in
    SGDR.

    The schedule is computed as follows.

    .. math::

        \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
        + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
        & T_{cur} \neq (2k+1)T_{max};

        \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
        \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
        & T_{cur} = (2k+1)T_{max}.

    The method comes from `SGDR: Stochastic Gradient Descent with Warm Restarts <https://arxiv.org/abs/1608.03983>`_.
    Only the cosine annealing part of SGDR is implemented here; the restarts are not.

    Args:
        learning_rate (float): The initial learning rate, that is :math:`\eta_{max}` . It can be set to python float or int number.
        T_max (int): Maximum number of iterations. It is half of the decay cycle of learning rate. It must be a positive integer.
        eta_min (float|int, optional): Minimum learning rate, that is :math:`\eta_{min}` . Default: 0.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``CosineAnnealingDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.5, T_max=10, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.CosineAnnealingDecay(learning_rate=0.5, T_max=10, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(20):
                for batch_id in range(5):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch
    """

    def __init__(
        self, learning_rate, T_max, eta_min=0, last_epoch=-1, verbose=False
    ):
        t_max_is_int = isinstance(T_max, int)
        if not t_max_is_int:
            raise TypeError(
                "The type of 'T_max' in 'CosineAnnealingDecay' must be 'int', but received %s."
                % type(T_max)
            )
        if not isinstance(eta_min, (float, int)):
            raise TypeError(
                "The type of 'eta_min' in 'CosineAnnealingDecay' must be 'float, int', but received %s."
                % type(eta_min)
            )
        # The type was already validated above, so only the sign is left to check.
        assert T_max > 0, " 'T_max' must be a positive integer."
        self.T_max, self.eta_min = T_max, float(eta_min)
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        """Return the next learning rate, derived recursively from ``self.last_lr``."""
        epoch = self.last_epoch
        if epoch == 0:
            return self.base_lr

        if (epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
            # The previous epoch is an odd multiple of T_max, where the cosine term is -1.
            # The ratio used below would then divide by (numerically) zero, so step up
            # by the fixed increment instead.
            rise = (
                (self.base_lr - self.eta_min)
                * (1 - math.cos(math.pi / self.T_max))
                / 2
            )
            return self.last_lr + rise

        numerator = 1 + math.cos(math.pi * epoch / self.T_max)
        denominator = 1 + math.cos(math.pi * (epoch - 1) / self.T_max)
        return (
            numerator / denominator * (self.last_lr - self.eta_min)
            + self.eta_min
        )

    def _get_closed_form_lr(self):
        """Return the learning rate for ``self.last_epoch`` computed directly from ``base_lr``."""
        cosine = math.cos(math.pi * self.last_epoch / self.T_max)
        return self.eta_min + (self.base_lr - self.eta_min) * (1 + cosine) / 2


class MultiplicativeDecay(LRScheduler):
    """
    Multiply the learning rate of ``optimizer`` by the factor given in function ``lr_lambda`` .

    The algorithm can be described as the code below.

    .. code-block:: text

        learning_rate = 0.5        # init learning_rate
        lr_lambda = lambda epoch: 0.95

        learning_rate = 0.5        # epoch 0,
        learning_rate = 0.475      # epoch 1, 0.5*0.95
        learning_rate = 0.45125    # epoch 2, 0.475*0.95

    Args:
        learning_rate (float): The initial learning rate. It is a python float number.
        lr_lambda (function): A function which computes a factor by ``epoch`` , and then multiply the last learning rate by this factor.
            It should be deterministic in ``epoch``: the running product is cached between calls, so the factor
            for a given epoch is not necessarily re-evaluated on every step.
        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``MultiplicativeDecay`` instance to schedule learning rate.

    Examples:

        .. code-block:: python

            import paddle

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.MultiplicativeDecay(learning_rate=0.5, lr_lambda=lambda x:0.95, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(20):
                for batch_id in range(5):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()    # If you update learning rate each step
                # scheduler.step()        # If you update learning rate each epoch

    """

    def __init__(self, learning_rate, lr_lambda, last_epoch=-1, verbose=False):
        if not callable(lr_lambda):
            raise TypeError(
                "The type of 'lr_lambda' in 'MultiplicativeDecay' must be 'function', but received %s."
                % type(lr_lambda)
            )

        self.lr_lambda = lr_lambda
        # Memo of the last computed product: (base_lr, lr_lambda, epoch, lr).
        # It is set before the parent constructor in case that constructor already
        # calls ``get_lr`` (not visible here; ``get_lr`` also tolerates its absence).
        self._lr_cache = None
        super().__init__(learning_rate, last_epoch, verbose)

    def get_lr(self):
        """
        Return ``base_lr * lr_lambda(1) * ... * lr_lambda(last_epoch)``.

        Recomputing the whole product on every call costs O(last_epoch) per step
        (O(n^2) over a run when stepping every iteration). Instead, the product is
        resumed from the cached epoch whenever that epoch is not ahead of the
        requested one and ``base_lr``/``lr_lambda`` are unchanged. The
        multiplications happen in the same order as a full recomputation, so the
        result is bit-identical. Backward jumps in ``last_epoch`` fall back to
        recomputing from ``base_lr``.
        """
        epoch = self.last_epoch
        cache = getattr(self, '_lr_cache', None)
        if (
            cache is not None
            and cache[0] == self.base_lr
            and cache[1] is self.lr_lambda
            and cache[2] <= epoch
        ):
            start_epoch, cur_lr = cache[2], cache[3]
        else:
            start_epoch, cur_lr = 0, self.base_lr
        for e in range(start_epoch + 1, epoch + 1):
            cur_lr = cur_lr * self.lr_lambda(e)
        # Epochs <= 0 involve no factors, so their product equals the one at epoch 0.
        self._lr_cache = (self.base_lr, self.lr_lambda, max(epoch, 0), cur_lr)
        return cur_lr


class OneCycleLR(LRScheduler):
    r"""

    Sets the learning rate according to the one cycle learning rate scheduler.
    The scheduler adjusts the learning rate from an initial learning rate to the maximum learning rate and then
    from that maximum learning rate to the minimum learning rate, which is much less than the initial learning rate.

    It has been proposed in `Super-Convergence: Very Fast Training of Neural Networks Using Large Learning Rates <https://arxiv.org/abs/1708.07120>`_.

    Please note that the default behaviour of this scheduler follows the fastai implementation of one cycle,
    which claims that “unpublished work has shown even better results by using only two phases”.
    If you want the behaviour of this scheduler to be consistent with the paper, please set ``three_phase=True`` .

    Also note that you should update learning rate each step.

    Args:
        max_learning_rate (float): The maximum learning rate. It is a python float number. Functionally, it defines the initial learning rate by ``divide_factor`` .
        total_steps (int): Number of total training steps.
        divide_factor (float, optional): Initial learning rate will be determined by initial_learning_rate = max_learning_rate / divide_factor. It must be positive. Default: 25.
        end_learning_rate (float, optional): The minimum learning rate during training, it should be much less than initial learning rate. Default: 0.0001.
        phase_pct (float, optional): The percentage of total steps which are used to increase the learning rate. Default: 0.3.
        anneal_strategy (str, optional): Strategy of adjusting learning rate.'cos' for cosine annealing, 'linear' for linear annealing. Default: 'cos'.
        three_phase (bool, optional): Whether to use three phase. Default: False.

            If ``True``:

                1. The learning rate will first increase from initial learning rate to maximum learning rate.
                2. Then it will decrease to initial learning rate. Number of step in this phase is the same as the one in first phase.
                3. Finally, it will decrease to minimum learning rate which is much less than initial learning rate.

            If ``False``:

                1. The learning rate will increase to maximum learning rate.
                2. Then it will directly decrease to minimum learning rate.

        last_epoch (int, optional):  The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``OneCycleLR`` instance to schedule learning rate.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(5):
                for batch_id in range(20):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()        # You should update learning rate each step

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.OneCycleLR(max_learning_rate=1.0, total_steps=100, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(5):
                for batch_id in range(20):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # You should update learning rate each step

    """

    def __init__(
        self,
        max_learning_rate,
        total_steps,
        divide_factor=25.0,
        end_learning_rate=0.0001,
        phase_pct=0.3,
        anneal_strategy='cos',
        three_phase=False,
        last_epoch=-1,
        verbose=False,
    ):
        # Check type and value of max_learning_rate
        if not isinstance(max_learning_rate, (float, int)):
            raise TypeError(
                "'max_learning_rate' must be 'float' or 'int', but received {}".format(
                    type(max_learning_rate)
                )
            )
        if max_learning_rate < 0:
            raise ValueError(
                "'max_learning_rate' must be a non-negative number, but received {}".format(
                    max_learning_rate
                )
            )

        # Check type and value of end_learning_rate
        if not isinstance(end_learning_rate, (float, int)):
            raise TypeError(
                "'end_learning_rate' must be 'float' or 'int', but received {}".format(
                    type(end_learning_rate)
                )
            )
        if end_learning_rate < 0:
            raise ValueError(
                "'end_learning_rate' must be a non-negative number, but received {}".format(
                    end_learning_rate
                )
            )

        # Check type and value of total_steps
        if not isinstance(total_steps, int):
            raise TypeError(
                "'total_steps' must be 'int', but received {}".format(
                    type(total_steps)
                )
            )
        if total_steps <= 0:
            raise ValueError("'total_steps' must be a positive integer.")
        self.total_steps = total_steps

        # Check type and value of phase_pct
        if not isinstance(phase_pct, float):
            raise TypeError(
                "'phase_pct' must be 'float', but received {}".format(
                    type(phase_pct)
                )
            )
        if phase_pct < 0 or phase_pct > 1:
            raise ValueError(
                "'phase_pct' must be between 0 and 1, but received {}".format(
                    phase_pct
                )
            )

        # Check type and value of divide_factor. Zero would divide by zero below
        # and a negative value would give a negative initial learning rate.
        if not isinstance(divide_factor, (float, int)):
            raise TypeError(
                "'divide_factor' must be 'float' or 'int', but received {}".format(
                    type(divide_factor)
                )
            )
        if divide_factor <= 0:
            raise ValueError(
                "'divide_factor' must be a positive number, but received {}".format(
                    divide_factor
                )
            )

        initial_lr = max_learning_rate / float(divide_factor)
        min_lr = float(end_learning_rate)

        if three_phase:
            if phase_pct >= 0.5:
                raise ValueError(
                    "When three_phase is True, 'phase_pct' must be less than 0.5"
                )
            # start step and end step of each phase.
            self._step_config = [
                0,
                phase_pct * self.total_steps - 1,
                2 * phase_pct * self.total_steps - 2,
                self.total_steps - 1,
                self.total_steps - 1,  # for the last step.
            ]
            # step size of each phase.
            self._steps_size = [
                self._step_config[1] - self._step_config[0],
                self._step_config[2] - self._step_config[1],
                self._step_config[3] - self._step_config[2],
                self._step_config[3]
                - self._step_config[2],  # for the last step.
            ]
            # start lr and end lr of each phase.
            self._lr_config = [
                initial_lr,
                max_learning_rate,
                initial_lr,
                min_lr,
            ]
        else:
            self._step_config = [
                0,
                phase_pct * self.total_steps - 1,
                self.total_steps - 1,
                self.total_steps - 1,
            ]
            self._steps_size = [
                self._step_config[1] - self._step_config[0],
                self._step_config[2] - self._step_config[1],
                self._step_config[2] - self._step_config[1],
            ]
            self._lr_config = [initial_lr, max_learning_rate, min_lr]

        # Check anneal_strategy
        if anneal_strategy == 'cos':
            self.anneal_func = self._cos_annealing
        elif anneal_strategy == 'linear':
            self.anneal_func = self._linear_annealing
        else:
            raise ValueError(
                "'anneal_strategy' must be one of 'cos' or 'linear', but received {}".format(
                    anneal_strategy
                )
            )
        super().__init__(initial_lr, last_epoch, verbose)

    def _cos_annealing(self, start_lr, end_lr, pct):
        """Cosine interpolation from ``start_lr`` (pct=0) to ``end_lr`` (pct=1)."""
        cos_out = math.cos(math.pi * pct) + 1
        return end_lr + (start_lr - end_lr) / 2.0 * cos_out

    def _linear_annealing(self, start_lr, end_lr, pct):
        """Linear interpolation from ``start_lr`` (pct=0) to ``end_lr`` (pct=1)."""
        return (end_lr - start_lr) * pct + start_lr

    def get_lr(self):
        """
        Return the learning rate for step ``self.last_epoch``.

        Raises:
            ValueError: If ``self.last_epoch`` exceeds ``total_steps``.
        """
        current_step = self.last_epoch

        if current_step > self.total_steps:
            raise ValueError(
                "Tried to step {} times. However the number of total steps is {}".format(
                    current_step, self.total_steps
                )
            )

        for (i, (end_step, step_size)) in enumerate(
            zip(self._step_config[1:], self._steps_size)
        ):
            # i == len(self._lr_config) - 2 catch the last step, otherwise it will return None.
            if current_step <= end_step or i == len(self._lr_config) - 2:
                if step_size == 0:
                    # A zero-length phase (e.g. phase_pct * total_steps == 1, or
                    # phase_pct == 1) would divide by zero. Treat it as finished,
                    # which matches the value every phase reaches at its end step.
                    percentage = 1.0
                else:
                    # self._step_config[i] means start step of a phase.
                    percentage = (
                        current_step - self._step_config[i]
                    ) / step_size
                return self.anneal_func(
                    self._lr_config[i], self._lr_config[i + 1], percentage
                )


class CyclicLR(LRScheduler):
    r"""
    Set the learning rate according to the cyclic learning rate (CLR) scheduler.
    The scheduler regards the process of learning rate adjustment as one cycle after another.
    It cycles the learning rate between two boundaries with a constant frequency.
    The distance between the two boundaries can be scaled on a per-iteration or per-cycle basis.

    It has been proposed in `Cyclical Learning Rates for Training Neural Networks <https://arxiv.org/abs/1506.01186>`_.

    According to the paper, the cyclic learning rate schedule has three built-in scale methods:

    * "triangular": A basic triangular cycle without any amplitude scaling.
    * "triangular2": A basic triangular cycle that reduces the initial amplitude by half each cycle.
    * "exp_range": A cycle that scales the initial amplitude by a scale function defined as :math:`gamma^{iterations}` .

    The initial amplitude is defined as max_learning_rate - base_learning_rate.
    Also note that you should update the learning rate each step.

    Args:
        base_learning_rate (float): Initial learning rate, which is the lower boundary in the cycle. The paper recommends
            setting base_learning_rate to 1/3 or 1/4 of max_learning_rate.
        max_learning_rate (float): Maximum learning rate in the cycle. It defines the cycle amplitude as above.
            Since there is some scaling operation during the process of learning rate adjustment,
            max_learning_rate may not actually be reached.
        step_size_up (int): Number of training steps used to increase the learning rate in a cycle.
            The step size of one cycle is defined as step_size_up + step_size_down. According to the paper, the step
            size should be set to at least 3 or 4 times the number of steps in one epoch.
        step_size_down (int, optional): Number of training steps used to decrease the learning rate in a cycle.
            If not specified, its value is initialized to ``step_size_up`` . Default: None
        mode (str, optional): One of 'triangular', 'triangular2' or 'exp_range'.
            If scale_fn is specified, this argument will be ignored. Default: 'triangular'
        exp_gamma (float|int, optional): Constant in the 'exp_range' scaling function: exp_gamma**iterations.
            Used only when mode = 'exp_range'. Default: 1.0
        scale_fn (function, optional): A custom scaling function, which is used to replace the three built-in methods.
            It should take exactly one argument. For all x >= 0, 0 <= scale_fn(x) <= 1.
            If specified, then 'mode' will be ignored. Default: None
        scale_mode (str, optional): One of 'cycle' or 'iterations'. Defines whether scale_fn is evaluated on the cycle
            number or on cycle iterations (total iterations since start of training). Only used when scale_fn
            is specified; the built-in modes choose their own scale mode. Default: 'cycle'
        last_epoch (int, optional): The index of last epoch. Can be set to restart training. Default: -1, means initial learning rate.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False`` .

    Returns:
        ``CyclicLR`` instance to schedule learning rate.

    Examples:
        .. code-block:: python

            import paddle
            import numpy as np

            # train on default dynamic graph mode
            linear = paddle.nn.Linear(10, 10)
            scheduler = paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5, max_learning_rate=1.0, step_size_up=15, step_size_down=5, verbose=True)
            sgd = paddle.optimizer.SGD(learning_rate=scheduler, parameters=linear.parameters())
            for epoch in range(5):
                for batch_id in range(20):
                    x = paddle.uniform([10, 10])
                    out = linear(x)
                    loss = paddle.mean(out)
                    loss.backward()
                    sgd.step()
                    sgd.clear_gradients()
                    scheduler.step()        # You should update learning rate each step

            # train on static graph mode
            paddle.enable_static()
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.static.program_guard(main_prog, start_prog):
                x = paddle.static.data(name='x', shape=[None, 4, 5])
                y = paddle.static.data(name='y', shape=[None, 4, 5])
                z = paddle.static.nn.fc(x, 100)
                loss = paddle.mean(z)
                scheduler = paddle.optimizer.lr.CyclicLR(base_learning_rate=0.5,
                    max_learning_rate=1.0, step_size_up=15, step_size_down=5, verbose=True)
                sgd = paddle.optimizer.SGD(learning_rate=scheduler)
                sgd.minimize(loss)

            exe = paddle.static.Executor()
            exe.run(start_prog)
            for epoch in range(5):
                for batch_id in range(20):
                    out = exe.run(
                        main_prog,
                        feed={
                            'x': np.random.randn(3, 4, 5).astype('float32'),
                            'y': np.random.randn(3, 4, 5).astype('float32')
                        },
                        fetch_list=loss.name)
                    scheduler.step()    # You should update learning rate each step
    """

    def __init__(
        self,
        base_learning_rate,
        max_learning_rate,
        step_size_up,
        step_size_down=None,
        mode='triangular',
        exp_gamma=1.0,
        scale_fn=None,
        scale_mode='cycle',
        last_epoch=-1,
        verbose=False,
    ):
        # check type and value of max_learning_rate
        if not isinstance(max_learning_rate, (float, int)):
            raise TypeError(
                "'max_learning_rate' must be 'float' or 'int', but received {}".format(
                    type(max_learning_rate)
                )
            )
        # 0 is accepted; only negative values are rejected.
        if max_learning_rate < 0:
            raise ValueError(
                "'max_learning_rate' must be a non-negative number, but received {}".format(
                    max_learning_rate
                )
            )

        # check type and value of step_size_up
        if not isinstance(step_size_up, int):
            raise TypeError(
                "The type of 'step_size_up' must be int, but received {}".format(
                    type(step_size_up)
                )
            )
        if step_size_up <= 0:
            raise ValueError(
                "'step_size_up' must be a positive integer, but received {}".format(
                    step_size_up
                )
            )

        # check type and value of step_size_down
        if step_size_down is not None:
            if not isinstance(step_size_down, int):
                raise TypeError(
                    "The type of 'step_size_down' must be int, but received {}".format(
                        type(step_size_down)
                    )
                )
            if step_size_down <= 0:
                raise ValueError(
                    "'step_size_down' must be a positive integer, but received {}".format(
                        step_size_down
                    )
                )

        # check type of exp_gamma; ints are accepted like max_learning_rate
        if not isinstance(exp_gamma, (float, int)):
            raise TypeError(
                "The type of 'exp_gamma' must be float or int, but received {}".format(
                    type(exp_gamma)
                )
            )

        step_size_up = float(step_size_up)
        step_size_down = (
            float(step_size_down)
            if step_size_down is not None
            else step_size_up
        )

        # One cycle = rising edge (step_size_up) + falling edge (step_size_down).
        self.cycle_size = step_size_up + step_size_down
        # Fraction of a cycle spent on the rising edge; always < 1 because
        # step_size_down > 0.
        self.step_up_pct = step_size_up / self.cycle_size
        self.max_lr = float(max_learning_rate)
        self.amplitude = self.max_lr - base_learning_rate

        if (
            mode not in ['triangular', 'triangular2', 'exp_range']
            and scale_fn is None
        ):
            raise ValueError(
                "'mode' is invalid and 'scale_fn' is not specified, make sure one of 'mode' or 'scale_fn' is valid"
            )
        if scale_mode not in ['cycle', 'iterations']:
            raise ValueError(
                "'scale_mode' must be one of 'cycle' or 'iterations', but received {}".format(
                    scale_mode
                )
            )

        self.mode = mode
        self.gamma = float(exp_gamma)  # only for exp_range mode

        if scale_fn is None:
            # Built-in modes fix their own scale_mode; the argument is ignored.
            if self.mode == 'triangular':
                self.scale_fn = self._triangular_scale_fn
                self.scale_mode = 'cycle'
            elif self.mode == 'triangular2':
                self.scale_fn = self._triangular2_scale_fn
                self.scale_mode = 'cycle'
            elif self.mode == 'exp_range':
                self.scale_fn = self._exp_range_scale_fn
                self.scale_mode = 'iterations'
        else:
            self.scale_fn = scale_fn
            self.scale_mode = scale_mode
        super().__init__(base_learning_rate, last_epoch, verbose)

    def _triangular_scale_fn(self, x):
        """No amplitude scaling."""
        return 1.0

    def _triangular2_scale_fn(self, x):
        """Halve the amplitude on every cycle after the first (x is the 1-based cycle number)."""
        return 1 / (2.0 ** (x - 1))

    def _exp_range_scale_fn(self, x):
        """Scale the amplitude by gamma ** x (x is the iteration count)."""
        return self.gamma**x

    def get_lr(self):
        """Return the learning rate for iteration ``self.last_epoch``."""
        iterations = self.last_epoch

        # 1-based cycle index and the fraction of the current cycle completed.
        cycle = 1 + iterations // self.cycle_size
        pct_per_cycle = 1.0 + iterations / self.cycle_size - cycle

        # Rising edge goes 0 -> 1, falling edge goes 1 -> 0.
        if pct_per_cycle <= self.step_up_pct:
            scale_factor = pct_per_cycle / self.step_up_pct
        else:
            scale_factor = (1 - pct_per_cycle) / (1 - self.step_up_pct)

        base_height = self.amplitude * scale_factor

        # scale_mode is validated to be 'cycle' or 'iterations'; pick the
        # matching value explicitly instead of eval()-ing the variable name.
        scale_arg = cycle if self.scale_mode == 'cycle' else iterations
        lr = self.base_lr + base_height * self.scale_fn(scale_arg)

        return lr