# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import warnings

import paddle
from .. import unique_name
from ..framework import Variable
from ..data_feeder import check_type

# Public learning rate scheduler names exported by this module.
__all__ = [
    'NoamDecay',
    'PiecewiseDecay',
    'NaturalExpDecay',
    'ExponentialDecay',
    'InverseTimeDecay',
    'PolynomialDecay',
    'CosineDecay',
    'LinearLrWarmup',
    'ReduceLROnPlateau',
    'StepDecay',
    'MultiStepDecay',
    'LambdaDecay',
]


class LearningRateDecay:
    """
    Base class of learning rate decay.

    Defines the common interface of a LearningRateDecay. Users should not
    use this class directly, but one of its subclasses, which implement
    :meth:`step`.

    Calling a scheduler instance evaluates :meth:`step` for the current
    ``step_num``, converts a plain Python number into a learning rate
    variable via :meth:`create_lr_var`, and then advances ``step_num`` by
    ``step_size``.

    Parameters:
        begin(int, optional): Initial value of ``step_num``. The default value is 0.
        step(int, optional): Amount added to ``step_num`` after each call. The default value is 1.
        dtype(str, optional): Data type of the learning rate variable created by
            :meth:`create_lr_var`. The default value is 'float32'.
    """

    def __init__(self, begin=0, step=1, dtype='float32'):
        self.step_num = begin
        self.step_size = step
        self.dtype = dtype

    def __call__(self):
        """
        Return the learning rate for the current step, then advance ``step_num``.

        Returns:
            The result of :meth:`step`. Python ``int``/``float`` results are
            converted with :meth:`create_lr_var`; any other result (e.g. a
            variable already built by the subclass) is returned unchanged.
        """
        lr = self.step()
        # Accept ints as well as floats: an int learning rate (e.g. integer
        # piecewise ``values`` or an int base rate) would otherwise be handed
        # to the optimizer as a bare Python number instead of a variable.
        if isinstance(lr, (int, float)):
            lr = self.create_lr_var(lr)
        self.step_num += self.step_size
        return lr

    def create_lr_var(self, lr):
        """
        Convert a Python number to a learning rate variable.

        Args:
            lr(int|float): learning rate value; it is passed through ``float()``.

        Returns:
            A non-persistable global variable of shape [1] and dtype
            ``self.dtype`` holding ``lr``, named with a unique
            "learning_rate" prefix.
        """
        return paddle.static.create_global_var(
            name=unique_name.generate("learning_rate"),
            shape=[1],
            value=float(lr),
            dtype=self.dtype,
            persistable=False,
        )

    # Note: to change what optimizer.state_dict stores, override these
    # methods (usually ``_state_keys``); by default only "self.step_num" is stored.
    def state_dict(self):
        """
        Returns the state of the scheduler as a :class:`dict`.

        It is the subset of self.__dict__ selected by ``_state_keys``. Variable
        values must have shape [1] and are stored as their scalar numpy value.
        """
        self._state_keys()
        state_dict = {}
        for key in self.keys:
            if key not in self.__dict__:
                continue
            value = self.__dict__[key]
            if isinstance(value, Variable):
                assert value.shape == [
                    1
                ], "shape of Variable in state_dict must be [1] {}".format(
                    value.shape
                )
                value = value.numpy()[0]
            state_dict[key] = value

        return state_dict

    def _state_keys(self):
        """
        Set ``self.keys``: the keys in self.__dict__ that need to be saved.
        """
        self.keys = ['step_num']

    def set_state_dict(self, state_dict):
        """
        Loads the scheduler's state.

        Args:
            state_dict(dict): state previously produced by :meth:`state_dict`.

        Raises:
            RuntimeError: if a key listed by ``_state_keys`` is missing from
                ``state_dict``.

        A warning is emitted when ``state_dict`` holds more entries than the
        keys this scheduler restores.
        """
        self._state_keys()
        for key in self.keys:
            if key in state_dict:
                self.__dict__[key] = state_dict[key]
            else:
                raise RuntimeError(
                    "Please check whether state_dict is correct for optimizer. Can't find [ {} ] in state_dict".format(
                        key
                    )
                )
        if len(state_dict) > len(self.keys):
            warnings.warn(
                "There are some unused values in state_dict. Maybe the optimizer have different 'LearningRateDecay' when invoking state_dict and set_dict"
            )

    # [aliases] Compatible with old method names
    set_dict = set_state_dict

    def step(self):
        """Return the learning rate for the current ``step_num``; implemented by subclasses."""
        raise NotImplementedError()


class PiecewiseDecay(LearningRateDecay):
    """
    :api_attr: imperative

    Piecewise decay scheduler.

    The algorithm can be described as the code below.

    .. code-block:: text

        boundaries = [10000, 20000]
        values = [1.0, 0.5, 0.1]
        if global_step < 10000:
            learning_rate = 1.0
        elif 10000 <= global_step < 20000:
            learning_rate = 0.5
        else:
            learning_rate = 0.1

    Parameters:
        boundaries(list): A list of step numbers. The type of element in the list is python int.
        values(list): A list of learning rate values that will be picked during
            different step boundaries. The type of element in the list is python float.
            It must contain exactly ``len(boundaries) + 1`` elements.
        begin(int): The begin step to initialize the global_step in the description above.
        step(int, optional): The step size used to calculate the new global_step in the description above.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable. The data type can be set as
            'float32', 'float64'. The default value is 'float32'.

    Raises:
        ValueError: If ``len(values) != len(boundaries) + 1``.

    Returns:
        None.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle
          boundaries = [10000, 20000]
          values = [1.0, 0.5, 0.1]
          with fluid.dygraph.guard():
              emb = paddle.nn.Embedding(10, 10)
              optimizer = fluid.optimizer.SGD(
                 learning_rate=fluid.dygraph.PiecewiseDecay(boundaries, values, 0),
                 parameter_list=emb.parameters())
    """

    def __init__(self, boundaries, values, begin, step=1, dtype='float32'):
        super().__init__(begin, step, dtype)
        # Every interval (before the first boundary, between boundaries and
        # after the last one) needs its own value.
        if len(values) != len(boundaries) + 1:
            raise ValueError(
                "len(values) must be len(boundaries) + 1, but got "
                "len(values) = {} and len(boundaries) = {}".format(
                    len(values), len(boundaries)
                )
            )
        self.boundaries = boundaries
        self.values = values
        self.vars = list(values)

    def step(self):
        """Return the learning rate variable of the interval containing ``step_num``."""
        # The first boundary not yet reached selects the value; both branches
        # build the variable here so int values are handled like floats.
        for boundary, value in zip(self.boundaries, self.vars):
            if self.step_num < boundary:
                return self.create_lr_var(value)
        return self.create_lr_var(self.vars[-1])


class NaturalExpDecay(LearningRateDecay):
    r"""
    :api_attr: imperative

    Applies natural exponential decay to the initial learning rate.

    The algorithm can be described as following.

    .. math::

        decayed\_learning\_rate = learning\_rate * e^{y}

    If staircase is set to False, then:

    .. math::

        y = - decay\_rate * \frac{global\_step}{decay\_steps}

    If staircase is set to True, then:

    .. math::

        y = - decay\_rate * \left\lfloor \frac{global\_step}{decay\_steps} \right\rfloor

    Parameters:
        learning_rate(Variable|float): The initial learning rate. If the type
            is Variable, it's a tensor with shape [1], the data type can be
            float32 or float64. It also can be set to python int number.
        decay_steps(int): The decay step size. It determines the decay cycle.
        decay_rate(float): The decay rate.
        staircase(bool, optional): If set to True, decay the learning rate at discrete intervals. The
            default value is False.
        begin(int, optional): The begin step. The initial value of global_step described above. The default value is 0.
        step(int, optional): The step size used to calculate the new global_step in the description above.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable. The data type can be set as
            'float32', 'float64'. The default value is 'float32'.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            import paddle
            base_lr = 0.1
            with fluid.dygraph.guard():
                emb = paddle.nn.Embedding(10, 10)
                sgd_optimizer = fluid.optimizer.SGD(
                        learning_rate=fluid.dygraph.NaturalExpDecay(
                            learning_rate=base_lr,
                            decay_steps=10000,
                            decay_rate=0.5,
                            staircase=True),
                        parameter_list=emb.parameters())

    """

    def __init__(
        self,
        learning_rate,
        decay_steps,
        decay_rate,
        staircase=False,
        begin=0,
        step=1,
        dtype='float32',
    ):
        super().__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase

    def step(self):
        """Return ``learning_rate * exp(-decay_rate * ratio)`` for the current step."""
        if self.staircase:
            # Floor in exact Python arithmetic: flooring a float32 variable
            # can round up when the ratio is just below an integer at large
            # step counts.
            ratio = self.step_num // self.decay_steps
        else:
            ratio = self.step_num / self.decay_steps
        div_res = self.create_lr_var(ratio)
        decayed_lr = self.learning_rate * paddle.exp(
            -1 * self.decay_rate * div_res
        )

        return decayed_lr


class ExponentialDecay(LearningRateDecay):
    r"""
    :api_attr: imperative

    Applies exponential decay to the learning rate.

    The algorithm can be described as following.

    .. math::

        decayed\_learning\_rate = learning\_rate * decay\_rate ^ y

    If staircase is set to False, then:

    .. math::

        y = \frac{global\_step}{decay\_steps}

    If staircase is set to True, then:

    .. math::

        y = \left\lfloor \frac{global\_step}{decay\_steps} \right\rfloor


    Parameters:
        learning_rate(Variable|float): The initial learning rate. If the type
            is Variable, it's a tensor with shape [1], the data type can be
            float32 or float64. It also can be set to python int number.
        decay_steps(int): The decay step size. It determines the decay cycle.
        decay_rate(float): The decay rate.
        staircase(bool, optional): If set to True, decay the learning rate at discrete intervals. The
            default value is False.
        begin(int, optional): The begin step. The initial value of global_step described above. The default value is 0.
        step(int, optional): The step size used to calculate the new global_step in the description above.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable. The data type can be set as
            'float32', 'float64'. The default value is 'float32'.

    Returns:
        None.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle
          base_lr = 0.1
          with fluid.dygraph.guard():
              emb = paddle.nn.Embedding(10, 10)
              sgd_optimizer = fluid.optimizer.SGD(
                    learning_rate=fluid.dygraph.ExponentialDecay(
                        learning_rate=base_lr,
                        decay_steps=10000,
                        decay_rate=0.5,
                        staircase=True),
                    parameter_list=emb.parameters())

    """

    def __init__(
        self,
        learning_rate,
        decay_steps,
        decay_rate,
        staircase=False,
        begin=0,
        step=1,
        dtype='float32',
    ):
        super().__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase

    def step(self):
        """Return ``learning_rate * decay_rate ** ratio`` for the current step."""
        if self.staircase:
            # Floor in exact Python arithmetic: flooring a float32 variable
            # can round up when the ratio is just below an integer at large
            # step counts.
            ratio = self.step_num // self.decay_steps
        else:
            ratio = self.step_num / self.decay_steps
        div_res = self.create_lr_var(ratio)

        decayed_lr = self.learning_rate * (self.decay_rate**div_res)

        return decayed_lr


class InverseTimeDecay(LearningRateDecay):
    r"""
    :api_attr: imperative

    Applies inverse time decay to the initial learning rate.

    The algorithm can be described as following.
    If staircase is set to False, then:

    .. math::

        decayed\_learning\_rate = \frac{learning\_rate}{1 + decay\_rate * \frac{global\_step}{decay\_steps}}

    If staircase is set to True, then:

    .. math::

        decayed\_learning\_rate = \frac{learning\_rate}{1 + decay\_rate * \left\lfloor \frac{global\_step}{decay\_steps} \right\rfloor}

    Parameters:
        learning_rate(Variable|float): The initial learning rate. If the type
            is Variable, it's a tensor with shape [1], the data type can be
            float32 or float64. It also can be set to python int number.
        decay_steps(int): The decay step size. It determines the decay cycle.
        decay_rate(float): The decay rate.
        staircase(bool, optional): If set to True, decay the learning rate at discrete intervals. The
            default value is False.
        begin(int, optional): The begin step. The initial value of global_step described above. The default value is 0.
        step(int, optional): The step size used to calculate the new global_step in the description above.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable. The data type can be set as
            'float32', 'float64'. The default value is 'float32'.

    Returns:
        None.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle
          base_lr = 0.1
          with fluid.dygraph.guard():
              emb = paddle.nn.Embedding(10, 10)
              sgd_optimizer = fluid.optimizer.SGD(
                  learning_rate=fluid.dygraph.InverseTimeDecay(
                        learning_rate=base_lr,
                        decay_steps=10000,
                        decay_rate=0.5,
                        staircase=True),
                  parameter_list=emb.parameters())

    """

    def __init__(
        self,
        learning_rate,
        decay_steps,
        decay_rate,
        staircase=False,
        begin=0,
        step=1,
        dtype='float32',
    ):
        super().__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.decay_rate = decay_rate
        self.staircase = staircase

    def step(self):
        """Return ``learning_rate / (1 + decay_rate * ratio)`` for the current step."""
        if self.staircase:
            # Floor in exact Python arithmetic: flooring a float32 variable
            # can round up when the ratio is just below an integer at large
            # step counts.
            ratio = self.step_num // self.decay_steps
        else:
            ratio = self.step_num / self.decay_steps
        div_res = self.create_lr_var(ratio)

        decayed_lr = self.learning_rate / (1 + self.decay_rate * div_res)

        return decayed_lr


class PolynomialDecay(LearningRateDecay):
    r"""
    :api_attr: imperative

    Applies polynomial decay to the initial learning rate.

    The algorithm can be described as following.

    If cycle is set to True, then:

    .. math::

        decay\_steps & = decay\_steps * \left\lceil \frac{global\_step}{decay\_steps} \right\rceil

        decayed\_learning\_rate & = (learning\_rate-end\_learning\_rate)*(1-\frac{global\_step}{decay\_steps})^{power}+end\_learning\_rate

    (At global_step 0 a single cycle is used, so decay_steps is not zero.)

    If cycle is set to False, then:

    .. math::

        global\_step & = min(global\_step, decay\_steps)

        decayed\_learning\_rate & = (learning\_rate-end\_learning\_rate)*(1-\frac{global\_step}{decay\_steps})^{power}+end\_learning\_rate

    Parameters:
        learning_rate(Variable|float): The initial learning rate. If the type
            is Variable, it's a tensor with shape [1], the data type can be
            float32 or float64. It also can be set to python int number.
        decay_steps(int): The decay step size. It determines the decay cycle.
        end_learning_rate(float, optional): The minimum final learning rate. The default value is 0.0001.
        power(float, optional): Power of polynomial. The default value is 1.0.
        cycle(bool, optional): If set true, decay the learning rate every decay_steps. The default value is False.
        begin(int, optional): The begin step. The initial value of global_step described above. The default value is 0.
        step(int, optional): The step size used to calculate the new global_step in the description above.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable. The data type can be set as
            'float32', 'float64'. The default value is 'float32'.

    Returns:
        None.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle
          start_lr = 0.01
          total_step = 5000
          end_lr = 0
          with fluid.dygraph.guard():
              emb = paddle.nn.Embedding(10, 10)
              optimizer  = fluid.optimizer.SGD(
                  learning_rate = fluid.dygraph.PolynomialDecay(
                  start_lr, total_step, end_lr, power=1.0),
                  parameter_list = emb.parameters())

    """

    def __init__(
        self,
        learning_rate,
        decay_steps,
        end_learning_rate=0.0001,
        power=1.0,
        cycle=False,
        begin=0,
        step=1,
        dtype='float32',
    ):
        super().__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.decay_steps = decay_steps
        self.end_learning_rate = end_learning_rate
        self.power = power
        self.cycle = cycle

    def step(self):
        """Return the polynomially decayed learning rate for the current step."""
        tmp_step_num = self.step_num
        tmp_decay_steps = self.decay_steps
        if self.cycle:
            # Number of decay cycles reached so far, ceil(step_num / decay_steps),
            # computed with exact Python ceiling division instead of ceil() on a
            # float32 variable, which can round the wrong way at large step counts.
            cycles = -(-tmp_step_num // self.decay_steps)
            if tmp_step_num == 0:
                # ceil(0 / decay_steps) == 0 would make the decay horizon zero.
                cycles = 1
            tmp_decay_steps = self.decay_steps * self.create_lr_var(cycles)
        else:
            # Clamp the step so that, for power > 0, the rate settles at
            # end_learning_rate once decay_steps is reached.
            tmp_step_num = self.create_lr_var(
                min(tmp_step_num, self.decay_steps)
            )

        decayed_lr = (self.learning_rate - self.end_learning_rate) * (
            (1 - tmp_step_num / tmp_decay_steps) ** self.power
        ) + self.end_learning_rate
        return decayed_lr


class CosineDecay(LearningRateDecay):
    r"""
    :api_attr: imperative

    Applies cosine decay to the learning rate.

    The algorithm can be described as following.

    .. math::

        epoch & = \left\lfloor \frac{global\_step}{step\_each\_epoch} \right\rfloor

        decayed\_learning\_rate & = learning\_rate * 0.5 * (\cos(epoch * \frac{\pi}{epochs}) + 1)

    Parameters:
        learning_rate(Variable|float): The initial learning rate. If the type
            is Variable, it's a tensor with shape [1], the data type can be
            float32 or float64. It also can be set to python int number.
        step_each_epoch(int): The number of steps in an epoch.
        epochs(int): The number of epochs.
        begin(int, optional): The begin step. The initial value of global_step described above. The default value is 0.
        step(int, optional): The step size used to calculate the new global_step in the description above.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable. The data type can be set as
            'float32', 'float64'. The default value is 'float32'.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            import paddle
            base_lr = 0.1
            with fluid.dygraph.guard():
                emb = paddle.nn.Embedding(10, 10)
                optimizer  = fluid.optimizer.SGD(
                    learning_rate = fluid.dygraph.CosineDecay(
                            base_lr, 10000, 120),
                    parameter_list = emb.parameters())
    """

    def __init__(
        self,
        learning_rate,
        step_each_epoch,
        epochs,
        begin=0,
        step=1,
        dtype='float32',
    ):
        super().__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.step_each_epoch = step_each_epoch
        self.epochs = epochs

    def step(self):
        """Return the cosine-annealed learning rate for the current epoch."""
        # Current epoch index via exact Python floor division; flooring a
        # float32 variable can round up near epoch boundaries at large step counts.
        cur_epoch = self.create_lr_var(self.step_num // self.step_each_epoch)
        decayed_lr = (
            self.learning_rate
            * 0.5
            * (paddle.cos(cur_epoch * math.pi / self.epochs) + 1)
        )
        return decayed_lr


class NoamDecay(LearningRateDecay):
    r"""
    :api_attr: imperative

    Applies Noam decay to the initial learning rate.

    The algorithm can be described as following.

    .. math::

        decayed\_learning\_rate = learning\_rate * d_{model}^{-0.5} * min(global\_step^{-0.5}, global\_step * warmup\_steps^{-1.5})

    At global_step 0 the term :math:`global\_step^{-0.5}` is taken as 1, so the
    learning rate is 0 there instead of raising a division error.

    Please reference `attention is all you need <https://arxiv.org/pdf/1706.03762.pdf>`_

    Parameters:
        d_model(Variable|int): The dimensionality of input and output feature vector of model. If type is Variable,
            it's a tensor with shape [1] and the data type can be int32 or int64. The type can also be python int.
        warmup_steps(Variable|int): The number of warmup steps. A super parameter. If type is Variable,
            it's a tensor with shape [1] and the data type can be int32 or int64. The type can also be python int.
        begin(int, optional): The begin step. The initial value of global_step described above. The default value is 1.
        step(int, optional): The step size used to calculate the new global_step in the description above.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable. The data type can be set as
            'float32', 'float64'. The default value is 'float32'.
        learning_rate(Variable|float|int): The initial learning rate. If the type
            is Variable, it's a tensor with shape [1], the data type can be
            float32 or float64. It also can be set to python int number. Default 1.0

    Returns:
        None.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle
          warmup_steps = 100
          learning_rate = 0.01
          with fluid.dygraph.guard():
              emb = paddle.nn.Embedding(10, 10)
              optimizer  = fluid.optimizer.SGD(
                  learning_rate = fluid.dygraph.NoamDecay(
                         1/(warmup_steps *(learning_rate ** 2)),
                         warmup_steps),
                  parameter_list = emb.parameters())
    """

    def __init__(
        self,
        d_model,
        warmup_steps,
        begin=1,
        step=1,
        dtype='float32',
        learning_rate=1.0,
    ):
        super().__init__(begin, step, dtype)
        self.learning_rate = learning_rate
        self.d_model = d_model
        self.warmup_steps = warmup_steps

    def step(self):
        """Return the Noam learning rate for the current step."""
        # 0 ** -0.5 raises ZeroDivisionError; use 1 at step 0 (as
        # paddle.optimizer.lr.NoamDecay does), so min() selects the warmup
        # term, which is 0 at step 0.
        if self.step_num == 0:
            a = self.create_lr_var(1.0)
        else:
            a = self.create_lr_var(self.step_num**-0.5)
        b = self.create_lr_var((self.warmup_steps**-1.5) * self.step_num)
        lr_value = (
            self.learning_rate * (self.d_model**-0.5) * paddle.minimum(a, b)
        )
        return lr_value


class LinearLrWarmup(LearningRateDecay):
    """
    :api_attr: imperative

    This operator uses the linear learning rate warm up strategy to adjust the learning rate preliminarily before the normal learning rate scheduling.
    For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_

    When global_step < warmup_steps, learning rate is updated as:

    .. code-block:: text

            linear_step = end_lr - start_lr
            lr = start_lr + linear_step * (global_step / warmup_steps)

    where start_lr is the initial learning rate, and end_lr is the final learning rate;

    When global_step >= warmup_steps, learning rate is updated as:

    .. code-block:: text

            lr = learning_rate

    where lr is the learning_rate after warm-up.

    If ``learning_rate`` is a LearningRateDecay, it is evaluated (and thus
    advanced) on every step, including the warm-up steps.

    Args:
        learning_rate (float|int|LearningRateDecay): Learning rate after warm-up, either a python number
            or another LearningRateDecay scheduler.
        warmup_steps (int): Steps for warm up.
        start_lr (float): Initial learning rate of warm up.
        end_lr (float): Final learning rate of warm up. It must be greater than start_lr.
        begin(int, optional): The begin step. The initial value of global_step described above. The default value is 1.
        step(int, optional): The step size used to calculate the new global_step in the description above.
            The default value is 1.
        dtype(str, optional): The data type used to create the learning rate variable. The data type can be set as
            'float32', 'float64'. The default value is 'float32'.

    Raises:
        TypeError: If learning_rate is not an int, float or LearningRateDecay.

    Returns:
        None.

    Examples:

    .. code-block:: python

        import paddle.fluid as fluid

        learning_rate = 0.1
        warmup_steps = 50
        start_lr = 0
        end_lr = 0.1

        with fluid.dygraph.guard():
            lr_decay = fluid.dygraph.LinearLrWarmup( learning_rate, warmup_steps, start_lr, end_lr)

    """

    def __init__(
        self,
        learning_rate,
        warmup_steps,
        start_lr,
        end_lr,
        begin=1,
        step=1,
        dtype='float32',
    ):
        super().__init__(begin, step, dtype)
        if not isinstance(learning_rate, (float, int, LearningRateDecay)):
            raise TypeError(
                "the type of learning_rate should be [int, float or LearningRateDecay], the current type is {}".format(
                    type(learning_rate)
                )
            )
        self.learning_rate = learning_rate
        self.warmup_steps = warmup_steps
        self.start_lr = start_lr
        assert (
            end_lr > start_lr
        ), "end_lr {} must be greater than start_lr {}".format(end_lr, start_lr)
        # Learning rate increase per step during warm-up.
        self.lr_ratio_before_warmup = (float(end_lr) - float(start_lr)) / float(
            warmup_steps
        )

    def step(self):
        """Return the warm-up rate before ``warmup_steps``, the base rate afterwards."""
        base_lr = self.learning_rate
        # The wrapped scheduler is evaluated on every call, so it keeps
        # advancing during warm-up as well.
        if isinstance(self.learning_rate, LearningRateDecay):
            base_lr = base_lr()

        if self.step_num < self.warmup_steps:
            return self.lr_ratio_before_warmup * self.step_num + self.start_lr
        # Return an int base rate as float so it is converted into a
        # learning rate variable like every other step's result.
        if isinstance(base_lr, int):
            return float(base_lr)
        return base_lr


class ReduceLROnPlateau(LearningRateDecay):
    """
    :api_attr: imperative

    Reduce learning rate when ``loss`` has stopped descending. Models often benefit from reducing the learning rate
    by 2 to 10 times once model performance no longer improves.

    The ``loss`` is the one which has been passed into ``step`` ; it must be a 1-D Tensor with shape [1]. When ``loss``
    stops descending for a ``patience`` number of epochs, the learning rate will be reduced to ``learning_rate * decay_rate`` .
    (Specially, ``mode`` can also be set to ``'max'`` ; in this case, when ``loss`` stops ascending for a ``patience`` number
    of epochs, the learning rate will be reduced.)

    In addition, after each reduction, it will wait a ``cooldown`` number of epochs before resuming normal operation.

    Args:
        learning_rate (Variable|float|int): The initial learning rate. It can be set to python float or int number.
            If the type is Variable, it should be 1-D Tensor with shape [1], the data type can be 'float32' or 'float64'.
        mode (str, optional): ``'min'`` or ``'max'`` can be selected. Normally, it is ``'min'`` , which means that the
            learning rate will reduce when ``loss`` stops descending. Specially, if it's set to ``'max'`` ,  the learning
            rate will reduce when ``loss`` stops ascending. Default: ``'min'`` .
        decay_rate (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * decay_rate`` .
            It should be less than 1.0. Default: 0.1.
        patience (int, optional): When ``loss`` fails to improve for more than this number of consecutive epochs,
            the learning rate will be reduced. Default: 10.
        verbose (bool, optional): If ``True``, prints a message to stdout for each update. Default: ``False``.
        threshold (float, optional): ``threshold`` and ``threshold_mode`` will determine the minimum change of ``loss`` .
            Changes of ``loss`` smaller than this minimum are not counted as improvements. Default: 1e-4.
        threshold_mode (str, optional): ``'rel'`` or ``'abs'`` can be selected. In ``'rel'`` mode, the minimum change of ``loss``
            is ``best_loss * threshold`` , where ``best_loss`` is the best ``loss`` seen so far. In ``'abs'`` mode, the minimum
            change of ``loss`` is ``threshold`` . Default: ``'rel'`` .
        cooldown (int, optional): The number of epochs to wait before resuming normal operation. Default: 0.
        min_lr (float, optional): The lower bound of the learning rate after reduction. Default: 0.
        eps (float, optional): Minimal decay applied to lr. If the difference between new and old lr is not greater than
            eps, the update is ignored. Default: 1e-8.
        dtype (str, optional): The data type used to create the learning rate variable. The data type can be set as
            'float32', 'float64'. Default: 'float32'.

    Returns:
        Reduced learning rate.

    Examples:

    .. code-block:: python

        import paddle.fluid as fluid
        import paddle
        import numpy as np

        with fluid.dygraph.guard():
            x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
            linear = paddle.nn.Linear(10, 10)
            input = fluid.dygraph.to_variable(x)

            reduce_lr = fluid.dygraph.ReduceLROnPlateau(
                                    learning_rate = 1.0,
                                    decay_rate = 0.5,
                                    patience = 5,
                                    verbose = True,
                                    cooldown = 3)
            adam = fluid.optimizer.Adam(
                learning_rate = reduce_lr,
                parameter_list = linear.parameters())

            for epoch in range(10):
                total_loss = 0
                for batch_id in range(5):
                    out = linear(input)
                    loss = paddle.mean(out)
                    total_loss += loss
                    adam.minimize(loss)

                avg_loss = total_loss/5

                # adjust learning rate according to avg_loss
                reduce_lr.step(avg_loss)
                lr = adam.current_step_lr()
                print("current avg_loss is %s, current lr is %s" % (avg_loss.numpy()[0], lr))

    """

    def __init__(
        self,
        learning_rate,
        mode='min',
        decay_rate=0.1,
        patience=10,
        verbose=False,
        threshold=1e-4,
        threshold_mode='rel',
        cooldown=0,
        min_lr=0,
        eps=1e-8,
        dtype='float32',
    ):
        super().__init__(dtype=dtype)
        mode = mode.lower()
        if mode not in ['min', 'max']:
            raise ValueError('mode ' + mode + ' is unknown!')
        self.mode = mode

        if decay_rate >= 1.0:
            raise ValueError(
                'new_lr = origin_lr * decay_rate and decay_rate should be < 1.0.'
            )
        self.decay_rate = self.create_lr_var(decay_rate)

        threshold_mode = threshold_mode.lower()
        if threshold_mode not in ['rel', 'abs']:
            raise ValueError(
                'threshold mode ' + threshold_mode + ' is unknown!'
            )
        self.threshold_mode = threshold_mode
        check_type(
            learning_rate,
            'learning_rate',
            (float, int, Variable),
            'ReduceLROnPlateau',
        )
        # Explicit re-check kept alongside check_type so a TypeError is raised
        # here regardless of how check_type behaves in the current mode.
        if not isinstance(learning_rate, (float, int, Variable)):
            raise TypeError(
                "The type of 'learning_rate' in 'ReduceLROnPlateau' must be 'float, int, Variable', but received %s."
                % type(learning_rate)
            )

        self.learning_rate = learning_rate
        self.verbose = verbose
        self.patience = patience
        self.threshold = threshold
        self.cooldown = cooldown
        self.min_lr = self.create_lr_var(min_lr)
        self.eps = eps

        # Mutable state updated by ``step``.
        self.cooldown_counter = 0
        self.best_loss = None
        self.num_bad_epochs = 0
        self.epoch_num = 0

    # "cooldown_counter / best_loss / num_bad_epochs / epoch_num / learning_rate" will be stored.
    def _state_keys(self):
        self.keys = [
            'cooldown_counter',
            'best_loss',
            'num_bad_epochs',
            'epoch_num',
            'learning_rate',
        ]

    def __call__(self):
        """Return the current learning rate, converting a python number to a variable on first use."""
        if not isinstance(self.learning_rate, Variable):
            self.learning_rate = self.create_lr_var(self.learning_rate)
        return self.learning_rate

    def step(self, loss):
        """
        It should be invoked on each epoch. Update the learning rate in optimizer according to ``loss`` .
        The new learning rate will take effect on next call to ``optimizer.minimize`` .

        Args:
            loss (Variable): A ``Variable`` that will be monitored to determine whether the learning rate will reduce.
                If it stops descending for a ``patience`` number of epochs, the learning rate will reduce. It should
                be 1-D Tensor with shape [1].
                Specially, if ``mode`` has been set to ``'max'`` ,  the learning rate will reduce when it stops ascending.
        Returns:
            None

        Examples:
            Please refer to the example of current LearningRateDecay.
        """

        # loss must be 1-D Tensor with shape [1]
        check_type(loss, 'loss', Variable, 'ReduceLROnPlateau.step')
        assert len(loss.shape) == 1 and loss.shape[0] == 1, (
            "the loss.shape "
            "should be (1L,), but the current loss.shape is {}. Maybe that "
            "you should call paddle.mean to process it first.".format(
                loss.shape
            )
        )

        self.epoch_num += 1
        if self.cooldown_counter > 0:
            # Still cooling down after a reduction: do not track improvement.
            self.cooldown_counter -= 1
        else:
            if self.best_loss is None or self._is_better(loss, self.best_loss):
                self.best_loss = loss
                self.num_bad_epochs = 0
            else:
                self.num_bad_epochs += 1

            if self.num_bad_epochs > self.patience:
                self.cooldown_counter = self.cooldown
                self.num_bad_epochs = 0
                new_lr = paddle.maximum(
                    self.learning_rate * self.decay_rate, self.min_lr
                )
                # Skip updates smaller than eps (e.g. once min_lr is reached).
                if self.learning_rate - new_lr > self.eps:
                    if self.verbose:
                        old_lr = (
                            self.learning_rate.numpy()[0]
                            if isinstance(self.learning_rate, Variable)
                            else self.learning_rate
                        )
                        print(
                            'Epoch {}: reducing learning rate from {} to {}.'.format(
                                self.epoch_num, old_lr, new_lr.numpy()[0]
                            )
                        )
                    self.learning_rate = new_lr

    def _is_better(self, current, best):
        """Return whether ``current`` improves on ``best`` by more than the configured threshold."""
        if self.mode == 'min' and self.threshold_mode == 'rel':
            return current < best - best * self.threshold

        elif self.mode == 'min' and self.threshold_mode == 'abs':
            return current < best - self.threshold

        elif self.mode == 'max' and self.threshold_mode == 'rel':
            return current > best + best * self.threshold

        else:
            return current > best + self.threshold


class _LearningRateEpochDecay(LearningRateDecay):
    """
    :api_attr: imperative

    Base class of learning rate decay, which is updated each epoch.

    Define the common interface of a _LearningRateEpochDecay.
    User should not use this class directly,
    but need to use one of its implementations. And invoke method: `epoch()` each epoch.
    """

    def __init__(self, learning_rate, dtype=None):
        # NOTE: LearningRateDecay.__init__ is not called here. It would assign
        # ``self.step_size``, which StepDecay sets for its own period before
        # delegating to this constructor.
        if not isinstance(learning_rate, (float, int)):
            raise TypeError(
                "The type of 'learning_rate' must be 'float, int', but received %s."
                % type(learning_rate)
            )
        if learning_rate < 0:
            raise ValueError("Invalid learning rate: {}".format(learning_rate))

        self.base_lr = float(learning_rate)

        # ``epoch()`` below pre-increments, so the first computed rate is for epoch 0.
        self.epoch_num = -1
        self.dtype = "float32" if dtype is None else dtype
        self.learning_rate = self.create_lr_var(self.base_lr)

        self.epoch()

    # For those subclass who overload _LearningRateEpochDecay, "self.epoch_num/learning_rate" will be stored by default.
    # you can change it for your subclass.
    def _state_keys(self):
        self.keys = ['epoch_num', 'learning_rate']

    def __call__(self):
        """
        Return last computed learning rate on current epoch.
        """
        if not isinstance(self.learning_rate, Variable):
            self.learning_rate = self.create_lr_var(self.learning_rate)
        return self.learning_rate

    def epoch(self, epoch=None):
        """
        Compute the learning rate and update it when invoked.

        Args:
            epoch (int, optional): Epoch number to jump to. If None, the epoch
                counter is advanced by one.
        """
        if epoch is None:
            self.epoch_num += 1
        else:
            self.epoch_num = epoch

        self.learning_rate = self.get_lr()

    def get_lr(self):
        """Return the learning rate for ``self.epoch_num``; implemented by subclasses."""
        raise NotImplementedError


class StepDecay(_LearningRateEpochDecay):
    """
    :api_attr: imperative

    Decays the learning rate of ``optimizer`` by ``decay_rate`` every ``step_size`` number of epoch.

    The algorithm can be described as the code below.

    .. code-block:: text

        learning_rate = 0.5
        step_size = 30
        decay_rate = 0.1

        learning_rate = 0.5     if epoch < 30
        learning_rate = 0.05    if 30 <= epoch < 60
        learning_rate = 0.005   if 60 <= epoch < 90
        ...

    Parameters:
        learning_rate (float|int): The initial learning rate. It can be set to python float or int number.
        step_size (int): Period of learning rate decay.
        decay_rate (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * decay_rate`` .
            It should be less than 1.0. Default: 0.1.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            import numpy as np
            import paddle
            with fluid.dygraph.guard():
                x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
                linear = paddle.nn.Linear(10, 10)
                input = fluid.dygraph.to_variable(x)
                scheduler = fluid.dygraph.StepDecay(0.5, step_size=3)
                adam = fluid.optimizer.Adam(learning_rate = scheduler, parameter_list = linear.parameters())

                for epoch in range(9):
                    for batch_id in range(5):
                        out = linear(input)
                        loss = paddle.mean(out)
                        adam.minimize(loss)
                    scheduler.epoch()

                    print("epoch:{}, current lr is {}" .format(epoch, adam.current_step_lr()))
                    # epoch:0, current lr is 0.5
                    # epoch:1, current lr is 0.5
                    # epoch:2, current lr is 0.5
                    # epoch:3, current lr is 0.05
                    # epoch:4, current lr is 0.05
                    # epoch:5, current lr is 0.05
                    # epoch:6, current lr is 0.005
                    # epoch:7, current lr is 0.005
                    # epoch:8, current lr is 0.005

    """

    def __init__(self, learning_rate, step_size, decay_rate=0.1):
        if not isinstance(step_size, int):
            raise TypeError(
                "The type of 'step_size' must be 'int', but received %s."
                % type(step_size)
            )
        if decay_rate >= 1.0:
            raise ValueError('decay_rate should be < 1.0.')

        # Both attributes must exist before the base constructor runs, because it
        # calls ``epoch()`` -> ``get_lr()``.
        self.step_size = step_size
        self.decay_rate = decay_rate
        super().__init__(learning_rate)

    def get_lr(self):
        """Return ``base_lr * decay_rate ** (epoch_num // step_size)``."""
        decay_rate = self.create_lr_var(self.decay_rate)
        i = self.epoch_num // self.step_size
        return self.base_lr * (decay_rate**i)


class MultiStepDecay(_LearningRateEpochDecay):
    """
    :api_attr: imperative

    Decays the learning rate of ``optimizer`` by ``decay_rate`` once ``epoch`` reaches one of the milestones.

    The algorithm can be described as the code below.

    .. code-block:: text

        learning_rate = 0.5
        milestones = [30, 50]
        decay_rate = 0.1
        if epoch < 30:
            learning_rate = 0.5
        elif epoch < 50:
            learning_rate = 0.05
        else:
            learning_rate = 0.005

    Parameters:
        learning_rate (float|int): The initial learning rate. It can be set to python float or int number.
        milestones (tuple|list): List or tuple of epoch boundaries. Must be strictly increasing.
        decay_rate (float, optional): The Ratio that the learning rate will be reduced. ``new_lr = origin_lr * decay_rate`` .
            It should be less than 1.0. Default: 0.1.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            import numpy as np
            import paddle
            with fluid.dygraph.guard():
                x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
                linear = paddle.nn.Linear(10, 10)
                input = fluid.dygraph.to_variable(x)
                scheduler = fluid.dygraph.MultiStepDecay(0.5, milestones=[3, 5])
                adam = fluid.optimizer.Adam(learning_rate = scheduler, parameter_list = linear.parameters())

                for epoch in range(6):
                    for batch_id in range(5):
                        out = linear(input)
                        loss = paddle.mean(out)
                        adam.minimize(loss)
                    scheduler.epoch()

                    print("epoch:{}, current lr is {}" .format(epoch, adam.current_step_lr()))
                    # epoch:0, current lr is 0.5
                    # epoch:1, current lr is 0.5
                    # epoch:2, current lr is 0.5
                    # epoch:3, current lr is 0.05
                    # epoch:4, current lr is 0.05
                    # epoch:5, current lr is 0.005

    """

    def __init__(self, learning_rate, milestones, decay_rate=0.1):
        if not isinstance(milestones, (tuple, list)):
            raise TypeError(
                "The type of 'milestones' in 'MultiStepDecay' must be 'tuple, list', but received %s."
                % type(milestones)
            )

        # Every adjacent pair must be strictly increasing.
        if not all(
            prev < nxt for prev, nxt in zip(milestones, milestones[1:])
        ):
            raise ValueError('The elements of milestones must be incremented')
        if decay_rate >= 1.0:
            raise ValueError('decay_rate should be < 1.0.')

        # Set before the base constructor, which calls ``epoch()`` -> ``get_lr()``.
        self.milestones = milestones
        self.decay_rate = decay_rate
        super().__init__(learning_rate)

    def get_lr(self):
        """Return ``base_lr * decay_rate ** k``, where k is the number of milestones <= epoch_num."""
        decay_rate = self.create_lr_var(self.decay_rate)
        for i, milestone in enumerate(self.milestones):
            if self.epoch_num < milestone:
                return self.base_lr * (decay_rate**i)

        return self.base_lr * (decay_rate ** len(self.milestones))


class LambdaDecay(_LearningRateEpochDecay):
    """
    :api_attr: imperative

    Sets the learning rate of ``optimizer`` to the initial lr times a multiplicative factor, and this multiplicative
    factor is computed by function ``lr_lambda`` . ``lr_lambda`` is a function which receives ``epoch`` .

    The algorithm can be described as the code below.

    .. code-block:: text

        learning_rate = 0.5        # init learning_rate
        lr_lambda = lambda epoch: 0.95 ** epoch

        learning_rate = 0.5        # epoch 0
        learning_rate = 0.475      # epoch 1
        learning_rate = 0.45125    # epoch 2

    Parameters:
        learning_rate (float|int): The initial learning rate. It can be set to python float or int number.
        lr_lambda (function): A function which computes a multiplicative factor given an integer parameter ``epoch`` , and
            then multiply the initial learning rate by this multiplicative factor.

    Returns:
        None.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            import numpy as np
            import paddle
            with fluid.dygraph.guard():
                x = np.random.uniform(-1, 1, [10, 10]).astype("float32")
                linear = paddle.nn.Linear(10, 10)
                input = fluid.dygraph.to_variable(x)
                scheduler = fluid.dygraph.LambdaDecay(0.5, lr_lambda=lambda x: 0.95**x)
                adam = fluid.optimizer.Adam(learning_rate = scheduler, parameter_list = linear.parameters())

                for epoch in range(6):
                    for batch_id in range(5):
                        out = linear(input)
                        loss = paddle.mean(out)
                        adam.minimize(loss)
                    scheduler.epoch()

                    print("epoch:{}, current lr is {}".format(epoch, adam.current_step_lr()))
                    # epoch:0, current lr is 0.5
                    # epoch:1, current lr is 0.475
                    # epoch:2, current lr is 0.45125

    """

    def __init__(self, learning_rate, lr_lambda):
        if not callable(lr_lambda):
            raise TypeError(
                "The type of 'lr_lambda' in 'LambdaDecay' must be 'function', but received %s."
                % type(lr_lambda)
            )

        # Set before the base constructor, which calls ``epoch()`` -> ``get_lr()``.
        self.lr_lambda = lr_lambda
        super().__init__(learning_rate)

    def get_lr(self):
        """Return ``base_lr * lr_lambda(epoch_num)``."""
        # The previous implementation also built an lr variable from base_lr here
        # and never used it; the returned value is unchanged without it.
        return self.base_lr * self.lr_lambda(self.epoch_num)