# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
When training a model, it's often useful to decay the
learning rate during training process, this is called
learning_rate_decay. There are many strategies to do
this, this module will provide some classical method.
User can also implement their own learning_rate_decay
strategy according to this module.
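
For example, a user-defined schedule can follow the same pattern as the
built-in schedulers below: create a global step counter, then derive the
learning rate from it. The snippet is only an illustrative sketch for the
static-graph mode, and ``my_step_decay`` is not part of this module:

.. code-block:: python

    import paddle
    import paddle.fluid as fluid

    def my_step_decay(learning_rate, step_size, gamma=0.1):
        # multiply the learning rate by `gamma` every `step_size` steps
        global_step = fluid.layers.autoincreased_step_counter(
            counter_name='@MY_LR_COUNTER@', begin=0, step=1)
        global_step = paddle.cast(global_step, 'float32')
        return learning_rate * gamma ** paddle.floor(global_step / step_size)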
"""

import math
import numbers

import paddle
from . import control_flow
from . import nn
from . import tensor
from ..framework import (
    default_main_program,
    Parameter,
    unique_name,
    name_scope,
    in_dygraph_mode,
)
from ..framework import Variable
from ..dygraph import learning_rate_scheduler as imperate_lr
from ..data_feeder import check_variable_and_dtype, check_type

__all__ = [
    'exponential_decay',
    'natural_exp_decay',
    'inverse_time_decay',
    'polynomial_decay',
    'piecewise_decay',
    'noam_decay',
    'cosine_decay',
    'linear_lr_warmup',
]


def _decay_step_counter(begin=0):
    # the first global step is zero in learning rate decay
    global_step = nn.autoincreased_step_counter(
        counter_name='@LR_DECAY_COUNTER@', begin=begin, step=1
    )
    global_step = paddle.cast(global_step, 'float32')
    return global_step


def noam_decay(d_model, warmup_steps, learning_rate=1.0):
    """

    Noam decay method. The numpy implementation of noam decay is as follows.

    .. code-block:: python

      import paddle.fluid as fluid
      import numpy as np
      # set hyper parameters
      base_lr = 0.01
      d_model = 2
      current_steps = 20
      warmup_steps = 200
      # compute
      lr_value = base_lr * np.power(d_model, -0.5) * np.min([
                              np.power(current_steps, -0.5),
                              np.power(warmup_steps, -1.5) * current_steps])

    Please refer to `Attention Is All You Need
    <https://arxiv.org/pdf/1706.03762.pdf>`_.

    Args:
        d_model(Variable): The dimensionality of the model's input and output.
        warmup_steps(Variable): A hyperparameter that sets the number of warm-up steps.
        learning_rate(Variable|float|int): The initial learning rate. If the type
            is Variable, it is a tensor with shape [1], and the data type can be
            float32 or float64. It can also be set to a Python int. Default 1.0.

    Returns:
        The decayed learning rate.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          warmup_steps = 100
          learning_rate = 0.01
          lr = fluid.layers.learning_rate_scheduler.noam_decay(
                         1/(warmup_steps *(learning_rate ** 2)),
                         warmup_steps,
                         learning_rate)
    """
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            decay = imperate_lr.NoamDecay(
                d_model, warmup_steps, learning_rate=learning_rate
            )
            return decay
        else:
            global_step = _decay_step_counter(1)

            a = global_step**-0.5
            b = (warmup_steps**-1.5) * global_step
            lr_value = learning_rate * (d_model**-0.5) * paddle.minimum(a, b)

            return lr_value


def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """

    Applies exponential decay to the learning rate.

    When training a model, it is often recommended to lower the learning rate as the
    training progresses. By using this function, the learning rate will be decayed by
    'decay_rate' every 'decay_steps' steps.

    The decayed learning rate is calculated as follows:

    >>> if staircase == True:
    >>>     decayed_learning_rate = learning_rate * decay_rate ^ floor(global_step / decay_steps)
    >>> else:
    >>>     decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
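
    As a quick numerical check of the formulas above (plain numpy, with
    arbitrary example values; the snippet is only illustrative):

    .. code-block:: python

      import numpy as np
      learning_rate, decay_rate, decay_steps = 0.1, 0.5, 10000
      global_step = 25000
      # staircase=True decays in discrete jumps
      staircase_lr = learning_rate * decay_rate ** np.floor(global_step / decay_steps)  # 0.025
      # staircase=False decays continuously
      continuous_lr = learning_rate * decay_rate ** (global_step / decay_steps)  # ~0.0177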

    Args:
        learning_rate(Variable|float): The initial learning rate. It should be a Variable
                                       or a float.
        decay_steps(int): The learning rate decay steps. See the decay computation above.
        decay_rate(float): The learning rate decay rate. See the decay computation above.
        staircase(bool): If True, decay the learning rate at discrete intervals, which
                         means the learning rate will be decayed by `decay_rate` every
                         `decay_steps`. If False, the learning rate will be decayed
                         continuously, following the formula above. Default: False

    Returns:
        Variable: The decayed learning rate. The data type is float32.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle

          paddle.enable_static()
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.exponential_decay(
                    learning_rate=base_lr,
                    decay_steps=10000,
                    decay_rate=0.5,
                    staircase=True))

    """
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            decay = paddle.optimizer.lr.ExponentialDecay(
                learning_rate, decay_rate
            )
            return decay
        else:
            global_step = _decay_step_counter()

            div_res = global_step / decay_steps
            if staircase:
                div_res = paddle.floor(div_res)
            decayed_lr = learning_rate * (decay_rate**div_res)

            return decayed_lr


def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """

    Applies natural exponential decay to the initial learning rate.

    When training a model, it is often recommended to lower the learning rate as the
    training progresses. By using this function, the learning rate will be decayed by
    natural exponential power 'decay_rate' every 'decay_steps' steps.

    The decayed learning rate is calculated as follows:

    >>> if not staircase:
    >>>     decayed_learning_rate = learning_rate * exp(- decay_rate * (global_step / decay_steps))
    >>> else:
    >>>     decayed_learning_rate = learning_rate * exp(- decay_rate * floor(global_step / decay_steps))
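
    As a quick numerical check of the formulas above (plain numpy, with
    arbitrary example values; the snippet is only illustrative):

    .. code-block:: python

      import numpy as np
      learning_rate, decay_rate, decay_steps = 0.1, 0.5, 10000
      global_step = 25000
      continuous_lr = learning_rate * np.exp(-decay_rate * (global_step / decay_steps))  # ~0.0287
      staircase_lr = learning_rate * np.exp(-decay_rate * np.floor(global_step / decay_steps))  # ~0.0368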

    Args:
        learning_rate(Variable|float): The initial learning rate. It should be a Variable
                                       or a float.
        decay_steps(int): The learning rate decay steps. See the decay computation above.
        decay_rate(float): The learning rate decay rate. See the decay computation above.
        staircase(bool): If True, decay the learning rate at discrete intervals, which
                         means the learning rate will be decayed by natural exponential power
                         `decay_rate` every `decay_steps`. If False, the learning rate will be
                         decayed continuously, following the formula above. Default: False

    Returns:
        The decayed learning rate. The data type is float32.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle

          paddle.enable_static()
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.natural_exp_decay(
                    learning_rate=base_lr,
                    decay_steps=10000,
                    decay_rate=0.5,
                    staircase=True))

    """
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            decay = paddle.optimizer.lr.NaturalExpDecay(
                learning_rate, decay_rate
            )
            return decay
        else:
            global_step = _decay_step_counter()

            div_res = global_step / decay_steps
            if staircase:
                div_res = paddle.floor(div_res)
            decayed_lr = learning_rate * paddle.exp(-1 * decay_rate * div_res)

            return decayed_lr


def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """

    Applies inverse time decay to the initial learning rate.

    When training a model, it is often recommended to lower the learning rate as the
    training progresses. By using this function, an inverse decay function will be
    applied to the initial learning rate.

    The decayed learning rate is calculated as follows:

    >>> if staircase == True:
    >>>     decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / decay_step))
    >>> else:
    >>>     decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / decay_step)
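
    As a quick numerical check of the formulas above (plain numpy, with
    arbitrary example values; the snippet is only illustrative):

    .. code-block:: python

      import numpy as np
      learning_rate, decay_rate, decay_steps = 0.1, 0.5, 10000
      global_step = 25000
      staircase_lr = learning_rate / (1 + decay_rate * np.floor(global_step / decay_steps))  # 0.05
      continuous_lr = learning_rate / (1 + decay_rate * global_step / decay_steps)  # ~0.0444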

    Args:
        learning_rate(Variable|float): The initial learning rate. It should be a Variable
                                       or a float.
        decay_steps(int): The learning rate decay steps. See the decay computation above.
        decay_rate(float): The learning rate decay rate. See the decay computation above.
        staircase(bool): If True, decay the learning rate at discrete intervals, which
                         means the learning rate will be decayed by `decay_rate` times
                         every `decay_steps`. If False, the learning rate will be decayed
                         continuously, following the formula above. Default: False

    Returns:
        Variable: The decayed learning rate. The data type is float32.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle
          paddle.enable_static()
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.inverse_time_decay(
                    learning_rate=base_lr,
                    decay_steps=10000,
                    decay_rate=0.5,
                    staircase=True))
    """
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            decay = paddle.optimizer.lr.InverseTimeDecay(
                learning_rate, decay_rate
            )
            return decay
        else:
            global_step = _decay_step_counter()

            div_res = global_step / decay_steps
            if staircase:
                div_res = paddle.floor(div_res)

            decayed_lr = learning_rate / (1 + decay_rate * div_res)

            return decayed_lr


def polynomial_decay(
    learning_rate, decay_steps, end_learning_rate=0.0001, power=1.0, cycle=False
):
    """
    Applies polynomial decay to the initial learning rate.

    .. code-block:: text

     if cycle:
       decay_steps = decay_steps * ceil(global_step / decay_steps)
     else:
       global_step = min(global_step, decay_steps)
     decayed_learning_rate = (learning_rate - end_learning_rate) *
          (1 - global_step / decay_steps) ^ power + end_learning_rate
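
    As a quick numerical check for the common case ``cycle=False`` (plain Python,
    with arbitrary example values; the snippet is only illustrative):

    .. code-block:: python

      learning_rate, end_learning_rate = 0.01, 0.0001
      decay_steps, power = 5000, 1.0
      global_step = min(2000, decay_steps)   # cycle=False clips the step
      lr = (learning_rate - end_learning_rate) * (
          1 - global_step / decay_steps) ** power + end_learning_rate  # ~0.00604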

    Args:
        learning_rate(Variable|float32): A scalar float32 value or a Variable. This
          will be the initial learning rate during training.
        decay_steps(int32): A Python `int32` number.
        end_learning_rate(float): A Python `float` number.
        power(float): A Python `float` number.
        cycle(bool): If set true, decay the learning rate every decay_steps.

    Returns:
        Variable: The decayed learning rate.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          start_lr = 0.01
          total_step = 5000
          end_lr = 0
          lr = fluid.layers.polynomial_decay(
              start_lr, total_step, end_lr, power=1)

    """
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            decay = imperate_lr.PolynomialDecay(
                learning_rate, decay_steps, end_learning_rate, power, cycle
            )
            return decay
        else:
            global_step = _decay_step_counter()

            if cycle:
                div_res = paddle.ceil(global_step / decay_steps)
                zero_var = paddle.tensor.fill_constant(
                    shape=[1], dtype='float32', value=0.0
                )
                one_var = paddle.tensor.fill_constant(
                    shape=[1], dtype='float32', value=1.0
                )

                div_val = paddle.static.nn.cond(
                    global_step == zero_var, lambda: one_var, lambda: div_res
                )
                paddle.assign(div_val, output=div_res)

                decay_steps = decay_steps * div_res
            else:
                decay_steps_var = paddle.tensor.fill_constant(
                    shape=[1], dtype='float32', value=float(decay_steps)
                )
                global_step = paddle.minimum(x=global_step, y=decay_steps_var)

            decayed_lr = (learning_rate - end_learning_rate) * (
                (1 - global_step / decay_steps) ** power
            ) + end_learning_rate
            return decayed_lr


def piecewise_decay(boundaries, values):
    """

    Applies piecewise decay to the initial learning rate.

    The algorithm can be described as the code below:

    .. code-block:: text

      boundaries = [10000, 20000]
      values = [1.0, 0.5, 0.1]
      if step < 10000:
          learning_rate = 1.0
      elif 10000 <= step < 20000:
          learning_rate = 0.5
      else:
          learning_rate = 0.1

    Args:
        boundaries: A list of step numbers.
        values: A list of learning rate values that will be picked during
            different step boundaries.

    Returns:
        The decayed learning rate.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle
          paddle.enable_static()
          boundaries = [10000, 20000]
          values = [1.0, 0.5, 0.1]
          optimizer = fluid.optimizer.Momentum(
              momentum=0.9,
              learning_rate=fluid.layers.piecewise_decay(boundaries=boundaries, values=values),
              regularization=paddle.regularizer.L2Decay(1e-4))

    """
    with default_main_program()._lr_schedule_guard():
        if len(values) - len(boundaries) != 1:
            raise ValueError("len(values) - len(boundaries) should be 1")

        if in_dygraph_mode():
            decay = imperate_lr.PiecewiseDecay(boundaries, values, 0)
            return decay
        else:
            global_step = _decay_step_counter()

            lr = paddle.static.create_global_var(
                shape=[1],
                value=0.0,
                dtype='float32',
                persistable=True,
                name="learning_rate",
            )
            # TODO: fluid.layers.control_flow.Switch should be replaced by paddle.static.nn.case(or cond) if possible
            with control_flow.Switch() as switch:
                for i in range(len(boundaries)):
                    boundary_val = paddle.tensor.fill_constant(
                        shape=[1],
                        dtype='float32',
                        value=float(boundaries[i]),
                        force_cpu=True,
                    )
                    with switch.case(global_step < boundary_val):
                        paddle.tensor.fill_constant(
                            shape=[1],
                            dtype="float32",
                            value=float(values[i]),
                            out=lr,
                        )
                with switch.default():
                    paddle.tensor.fill_constant(
                        shape=[1],
                        dtype="float32",
                        value=float(values[len(values) - 1]),
                        out=lr,
                    )
            return lr


def cosine_decay(learning_rate, step_each_epoch, epochs):
    r"""

    Applies cosine decay to the learning rate.

    When training a model, it is often recommended to lower the learning rate as the
    training progresses. By using this function, the learning rate will be decayed by
    following the cosine decay strategy.

    .. math::

        decayed\_lr = learning\_rate * 0.5 * (\cos(epoch * \frac{\pi}{epochs}) + 1)
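
    As a quick numerical check of the formula (plain numpy, with arbitrary
    example values; the snippet is only illustrative):

    .. code-block:: python

      import numpy as np
      base_lr, step_each_epoch, epochs = 0.1, 10000, 120
      global_step = 600000                                 # i.e. epoch 60
      cur_epoch = np.floor(global_step / step_each_epoch)
      lr = base_lr * 0.5 * (np.cos(cur_epoch * np.pi / epochs) + 1)  # ~0.05 halfway through training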

    Args:
        learning_rate(Variable|float): The initial learning rate.
        step_each_epoch(int): The number of steps in an epoch.
        epochs(int): The number of epochs.

    Returns:
        Variable: The decayed learning rate.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            base_lr = 0.1
            lr = fluid.layers.cosine_decay(
                learning_rate=base_lr, step_each_epoch=10000, epochs=120)
    """
    check_type(
        learning_rate, 'learning_rate', (float, tensor.Variable), 'cosine_decay'
    )

    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            decay = imperate_lr.CosineDecay(
                learning_rate, step_each_epoch, epochs
            )
            return decay
        else:
            global_step = _decay_step_counter()

            cur_epoch = paddle.floor(global_step / step_each_epoch)
            decayed_lr = (
                learning_rate
                * 0.5
                * (paddle.cos(cur_epoch * math.pi / epochs) + 1)
            )
            return decayed_lr


def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
    """

    This operator uses the linear learning rate warm-up strategy to adjust the learning rate preliminarily before the normal learning rate scheduling.
    For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_.

    When global_step < warmup_steps, learning rate is updated as:

    .. code-block:: text

            linear_step = end_lr - start_lr
            lr = start_lr + linear_step * (global_step / warmup_steps)

    where start_lr is the initial learning rate, and end_lr is the final learning rate;

    When global_step >= warmup_steps, learning rate is updated as:

    .. code-block:: text

            lr = learning_rate

    where lr is the learning_rate after warm-up.
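
    As a quick numerical check of the warm-up formula (plain Python, with
    arbitrary example values; the snippet is only illustrative):

    .. code-block:: python

      learning_rate = 0.1                     # learning rate after warm-up
      warmup_steps, start_lr, end_lr = 50, 1. / 3., 0.1
      global_step = 25
      if global_step < warmup_steps:
          lr = start_lr + (end_lr - start_lr) * (global_step / warmup_steps)  # ~0.2167
      else:
          lr = learning_rate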

    Args:
        learning_rate (Variable|float): Learning rate after warm-up; it could be a 1D-Tensor or a single value with the data type of float32.
        warmup_steps (int): Steps for warm up.
        start_lr (float): Initial learning rate of warm up.
        end_lr (float): Final learning rate of warm up.

    Returns:
        Variable: Warm-up learning rate with the same data type as learning_rate.

    Examples:

    .. code-block:: python

        import paddle.fluid as fluid

        boundaries = [100, 200]
        lr_steps = [0.1, 0.01, 0.001]
        learning_rate = fluid.layers.piecewise_decay(boundaries, lr_steps) #case1, 1D-Tensor
        #learning_rate = 0.1  #case2, single-value
        warmup_steps = 50
        start_lr = 1. / 3.
        end_lr = 0.1
        decayed_lr = fluid.layers.linear_lr_warmup(learning_rate,
            warmup_steps, start_lr, end_lr)

        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())
        out, = exe.run(fetch_list=[decayed_lr.name])
        print(out)
        # case1: [0.33333334]
        # case2: [0.33333334]
    """
    dtype = 'float32'
    if isinstance(learning_rate, Variable):
        dtype = learning_rate.dtype

    linear_step = float(end_lr) - float(start_lr)
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            lr = imperate_lr.LinearLrWarmup(
                learning_rate, warmup_steps, start_lr, end_lr
            )
            return lr
        else:
            lr = paddle.static.create_global_var(
                shape=[1],
                value=0.0,
                dtype=dtype,
                persistable=True,
                name="learning_rate_warmup",
            )

            global_step = _decay_step_counter()
            if not isinstance(learning_rate, Variable):
                learning_rate = paddle.tensor.fill_constant(
                    shape=[1], dtype=dtype, value=float(learning_rate)
                )
            lr_val = paddle.static.nn.case(
                pred_fn_pairs=[
                    (
                        global_step < warmup_steps,
                        lambda: start_lr
                        + linear_step * (global_step / float(warmup_steps)),
                    )
                ],
                default=lambda: learning_rate,
            )
            paddle.assign(lr_val, lr)
            return lr