learning_rate_scheduler.py 20.7 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Y
yuyang18 已提交
14 15 16 17 18 19 20 21
"""
When training a model, it's often useful to decay the
learning rate during the training process; this is called
learning rate decay. There are many strategies for doing
this, and this module provides several classical ones.
Users can also implement their own learning rate decay
strategies following the patterns in this module.
"""
Q
Qiao Longfei 已提交
22

23
import math
Q
qingqing01 已提交
24
import numbers
25

26
import paddle
27
from . import nn
姜永久 已提交
28 29 30 31 32 33 34
from ..framework import (
    default_main_program,
    Parameter,
    unique_name,
    name_scope,
    in_dygraph_mode,
)
Q
qingqing01 已提交
35
from ..framework import Variable
M
minqiyang 已提交
36
from ..dygraph import learning_rate_scheduler as imperate_lr
37
from ..data_feeder import check_variable_and_dtype, check_type
Q
Qiao Longfei 已提交
38

39
# Public learning-rate schedule builders exported by this module.
__all__ = [
    # exponential-family schedules driven by decay_steps / decay_rate
    'exponential_decay', 'natural_exp_decay', 'inverse_time_decay',
    # other schedules
    'polynomial_decay', 'piecewise_decay', 'noam_decay', 'cosine_decay',
    # warm-up wrapper around another schedule or a constant rate
    'linear_lr_warmup',
]
Q
Qiao Longfei 已提交
49 50


51
def _decay_step_counter(begin=0):
    """Create the float32 step counter shared by the static-graph schedules.

    Args:
        begin (int): Value the counter holds on its first run. Default 0.

    Returns:
        Variable: The ``@LR_DECAY_COUNTER@`` counter (incremented by 1 each
        run), cast to float32 so it can be mixed with float arithmetic.
    """
    # The first global step is ``begin`` (0 by default) for learning rate decay.
    counter = nn.autoincreased_step_counter(
        counter_name='@LR_DECAY_COUNTER@', begin=begin, step=1
    )
    return paddle.cast(counter, 'float32')


60
def noam_decay(d_model, warmup_steps, learning_rate=1.0):
    """

    Noam decay method. The numpy implementation of noam decay is as follows.

    .. code-block:: python

      import paddle.fluid as fluid
      import numpy as np
      # set hyper parameters
      base_lr = 0.01
      d_model = 2
      current_steps = 20
      warmup_steps = 200
      # compute
      lr_value = base_lr * np.power(d_model, -0.5) * np.min([
                              np.power(current_steps, -0.5),
                              np.power(warmup_steps, -1.5) * current_steps])

    Please refer to `attention is all you need
    <https://arxiv.org/pdf/1706.03762.pdf>`_.

    Args:
        d_model(float|int|Variable): The dimensionality of input and output of model.

        warmup_steps(int|Variable): The number of warm-up steps (a hyperparameter).

        learning_rate(Variable|float|int): The initial learning rate. If the type
            is Variable, it's a tensor with shape [1], the data type can be
            float32 or float64. It also can be set to python int number. Default 1.0

    Returns:
        In dygraph mode, a ``paddle.optimizer.lr.NoamDecay`` scheduler; otherwise
        a Variable holding the decayed learning rate.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          warmup_steps = 100
          learning_rate = 0.01
          lr = fluid.layers.learning_rate_scheduler.noam_decay(
                         1/(warmup_steps *(learning_rate ** 2)),
                         warmup_steps,
                         learning_rate)
    """
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            # Eager mode delegates to the 2.x scheduler implementation.
            return paddle.optimizer.lr.NoamDecay(
                d_model, warmup_steps, learning_rate=learning_rate
            )

        # Static graph: the counter starts at 1 so that step ** -0.5 is finite
        # on the very first run.
        step = _decay_step_counter(1)
        inv_sqrt_step = step**-0.5
        warmup_ramp = (warmup_steps**-1.5) * step
        return (
            learning_rate
            * (d_model**-0.5)
            * paddle.minimum(inv_sqrt_step, warmup_ramp)
        )
118 119


Y
Yu Yang 已提交
120
def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """

    Applies exponential decay to the learning rate.

    When training a model, it is often recommended to lower the learning rate as the
    training progresses. By using this function, the learning rate will be decayed by
    'decay_rate' every 'decay_steps' steps.

    Decayed learning rate is calculated as follows:

    >>> if staircase == True:
    >>>     decayed_learning_rate = learning_rate * decay_rate ^ floor(global_step / decay_steps)
    >>> else:
    >>>     decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)

    Args:
        learning_rate(Variable|float): The initial learning rate. It should be a Variable
                                       or a float
        decay_steps(int): The learning rate decay steps. See the decay computation above.
        decay_rate(float): The learning rate decay rate. See the decay computation above.
        staircase(bool): If True, decay the learning rate at discrete intervals, which
                         means the learning rate will be decayed by `decay_rate` every
                         `decay_steps`. If False, learning rate will be decayed continuously
                         and following the formula above. Default: False

    Returns:
        Variable|LRScheduler: In static graph mode, a float32 Variable holding the
        decayed learning rate. In dygraph mode, a ``paddle.optimizer.lr.LambdaDecay``
        scheduler that follows the same formula, with ``global_step`` being the
        scheduler's step counter (advanced by calling its ``step()`` method).

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle

          paddle.enable_static()
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.exponential_decay(
                    learning_rate=base_lr,
                    decay_steps=10000,
                    decay_rate=0.5,
                    staircase=True))

    """
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            # ``paddle.optimizer.lr.ExponentialDecay(learning_rate, decay_rate)``
            # multiplies by ``decay_rate`` on every scheduler step and has no
            # notion of ``decay_steps`` or ``staircase``, so both arguments were
            # silently ignored. LambdaDecay evaluates the documented formula
            # directly, matching the static-graph branch below.
            def lr_factor(step):
                ratio = step / decay_steps
                if staircase:
                    ratio = math.floor(ratio)
                return decay_rate**ratio

            return paddle.optimizer.lr.LambdaDecay(learning_rate, lr_factor)

        global_step = _decay_step_counter()

        div_res = global_step / decay_steps
        if staircase:
            div_res = paddle.floor(div_res)
        decayed_lr = learning_rate * (decay_rate**div_res)

        return decayed_lr
Q
Qiao Longfei 已提交
180 181


Y
Yu Yang 已提交
182
def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """

    Applies natural exponential decay to the initial learning rate.

    When training a model, it is often recommended to lower the learning rate as the
    training progresses. By using this function, the learning rate will be decayed by
    natural exponential power 'decay_rate' every 'decay_steps' steps.

    Decayed learning rate is calculated as follows:

    >>> if not staircase:
    >>>     decayed_learning_rate = learning_rate * exp(- decay_rate * (global_step / decay_steps))
    >>> else:
    >>>     decayed_learning_rate = learning_rate * exp(- decay_rate * floor(global_step / decay_steps))

    Args:
        learning_rate(Variable|float): The initial learning rate. It should be a Variable
                                       or a float
        decay_steps(int): The learning rate decay steps. See the decay computation above.
        decay_rate(float): The learning rate decay rate. See the decay computation above.
        staircase(bool): If True, decay the learning rate at discrete intervals, which
                         means the learning rate will be decayed by natural exponential power
                         `decay_rate` every `decay_steps`. If False, learning rate will be
                         decayed continuously and following the formula above. Default: False

    Returns:
        In static graph mode, a float32 Variable holding the decayed learning rate.
        In dygraph mode, a ``paddle.optimizer.lr.LambdaDecay`` scheduler that follows
        the same formula, with ``global_step`` being the scheduler's step counter
        (advanced by calling its ``step()`` method).

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle

          paddle.enable_static()
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.natural_exp_decay(
                    learning_rate=base_lr,
                    decay_steps=10000,
                    decay_rate=0.5,
                    staircase=True))

    """
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            # ``paddle.optimizer.lr.NaturalExpDecay(learning_rate, decay_rate)``
            # computes exp(-decay_rate * step), ignoring ``decay_steps`` and
            # ``staircase``. LambdaDecay evaluates the documented formula
            # directly, matching the static-graph branch below.
            def lr_factor(step):
                ratio = step / decay_steps
                if staircase:
                    ratio = math.floor(ratio)
                return math.exp(-1 * decay_rate * ratio)

            return paddle.optimizer.lr.LambdaDecay(learning_rate, lr_factor)

        global_step = _decay_step_counter()

        div_res = global_step / decay_steps
        if staircase:
            div_res = paddle.floor(div_res)
        decayed_lr = learning_rate * paddle.exp(-1 * decay_rate * div_res)

        return decayed_lr
Q
Qiao Longfei 已提交
242 243


Y
Yu Yang 已提交
244
def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """

    Applies inverse time decay to the initial learning rate.

    When training a model, it is often recommended to lower the learning rate as the
    training progresses. By using this function, an inverse decay function will be
    applied to the initial learning rate.

    Decayed learning rate is calculated as follows:

    >>> if staircase == True:
    >>>     decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / decay_step))
    >>> else:
    >>>     decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / decay_step)

    Args:
        learning_rate(Variable|float): The initial learning rate. It should be a Variable
                                       or a float
        decay_steps(int): The learning rate decay steps. See the decay computation above.
        decay_rate(float): The learning rate decay rate. See the decay computation above.
        staircase(bool): If True, decay the learning rate at discrete intervals, which
                         means the learning rate will be decayed by `decay_rate` times
                         every `decay_steps`. If False, learning rate will be decayed
                         continuously and following the formula above. Default: False

    Returns:
        Variable|LRScheduler: In static graph mode, a float32 Variable holding the
        decayed learning rate. In dygraph mode, a ``paddle.optimizer.lr.LambdaDecay``
        scheduler that follows the same formula, with ``global_step`` being the
        scheduler's step counter (advanced by calling its ``step()`` method).

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle
          paddle.enable_static()
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.inverse_time_decay(
                    learning_rate=base_lr,
                    decay_steps=10000,
                    decay_rate=0.5,
                    staircase=True))
    """
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            # ``paddle.optimizer.lr.InverseTimeDecay(learning_rate, decay_rate)``
            # computes 1 / (1 + decay_rate * step), ignoring ``decay_steps`` and
            # ``staircase``. LambdaDecay evaluates the documented formula
            # directly, matching the static-graph branch below.
            def lr_factor(step):
                ratio = step / decay_steps
                if staircase:
                    ratio = math.floor(ratio)
                return 1.0 / (1 + decay_rate * ratio)

            return paddle.optimizer.lr.LambdaDecay(learning_rate, lr_factor)

        global_step = _decay_step_counter()

        div_res = global_step / decay_steps
        if staircase:
            div_res = paddle.floor(div_res)

        decayed_lr = learning_rate / (1 + decay_rate * div_res)

        return decayed_lr
303 304


305 306 307
def polynomial_decay(
    learning_rate, decay_steps, end_learning_rate=0.0001, power=1.0, cycle=False
):
    """
    Applies polynomial decay to the initial learning rate.

    .. code-block:: text

     if cycle:
       decay_steps = decay_steps * ceil(global_step / decay_steps)
     else:
       global_step = min(global_step, decay_steps)
     decayed_learning_rate = (learning_rate - end_learning_rate) *
          (1 - global_step / decay_steps) ^ power + end_learning_rate

    With ``cycle`` enabled, the multiplier ``ceil(global_step / decay_steps)`` is
    replaced by 1 when ``global_step`` is 0.

    Args:
        learning_rate(Variable|float32): A scalar float32 value or a Variable. This
          will be the initial learning rate during training.
        decay_steps(int32): A Python `int32` number.
        end_learning_rate(float): A Python `float` number.
        power(float): A Python `float` number.
        cycle(bool): If set true, decay the learning rate every decay_steps.

    Returns:
        Variable: The decayed learning rate in static graph mode; a
        ``paddle.optimizer.lr.PolynomialDecay`` scheduler in dygraph mode.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          start_lr = 0.01
          total_step = 5000
          end_lr = 0
          lr = fluid.layers.polynomial_decay(
              start_lr, total_step, end_lr, power=1)

    """
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            return paddle.optimizer.lr.PolynomialDecay(
                learning_rate, decay_steps, end_learning_rate, power, cycle
            )

        global_step = _decay_step_counter()

        if not cycle:
            # Clamp the step so the schedule stays at end_learning_rate
            # once decay_steps have elapsed.
            step_cap = paddle.tensor.fill_constant(
                shape=[1], dtype='float32', value=float(decay_steps)
            )
            global_step = paddle.minimum(x=global_step, y=step_cap)
        else:
            num_cycles = paddle.ceil(global_step / decay_steps)
            zero = paddle.tensor.fill_constant(
                shape=[1], dtype='float32', value=0.0
            )
            one = paddle.tensor.fill_constant(
                shape=[1], dtype='float32', value=1.0
            )
            # At step 0, ceil(0 / decay_steps) == 0 would make the effective
            # decay_steps zero; use a single cycle instead.
            safe_cycles = paddle.static.nn.cond(
                global_step == zero, lambda: one, lambda: num_cycles
            )
            paddle.assign(safe_cycles, output=num_cycles)
            decay_steps = decay_steps * num_cycles

        lr_span = learning_rate - end_learning_rate
        return lr_span * (
            (1 - global_step / decay_steps) ** power
        ) + end_learning_rate
376 377


Y
Yu Yang 已提交
378
def piecewise_decay(boundaries, values):
    """

    Applies piecewise decay to the initial learning rate.

    The algorithm can be described as the code below.

    .. code-block:: text

      boundaries = [10000, 20000]
      values = [1.0, 0.5, 0.1]
      if step < 10000:
          learning_rate = 1.0
      elif 10000 <= step < 20000:
          learning_rate = 0.5
      else:
          learning_rate = 0.1

    Args:
        boundaries: A list of step numbers.
        values: A list of learning rate values that will be picked during
            different step boundaries. It must contain exactly one more
            element than ``boundaries``.

    Returns:
        The decayed learning rate: a ``paddle.optimizer.lr.PiecewiseDecay``
        scheduler in dygraph mode, otherwise a float32 global Variable.

    Raises:
        ValueError: If ``len(values) - len(boundaries) != 1``.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle
          paddle.enable_static()
          boundaries = [10000, 20000]
          values = [1.0, 0.5, 0.1]
          lr = fluid.layers.piecewise_decay(boundaries, values)

    """
    with default_main_program()._lr_schedule_guard():
        if len(values) - len(boundaries) != 1:
            raise ValueError("len(values) - len(boundaries) should be 1")

        if in_dygraph_mode():
            return paddle.optimizer.lr.PiecewiseDecay(boundaries, values)

        global_step = _decay_step_counter()

        lr = paddle.static.create_global_var(
            shape=[1],
            value=0.0,
            dtype='float32',
            persistable=True,
            name="learning_rate",
        )
        last_value = values[len(values) - 1]
        with paddle.static.nn.control_flow.Switch() as switch:
            # The first case whose boundary exceeds the current step wins;
            # zip stops after the boundaries, leaving the last value unused.
            for boundary, value in zip(boundaries, values):
                boundary_var = paddle.tensor.fill_constant(
                    shape=[1],
                    dtype='float32',
                    value=float(boundary),
                    force_cpu=True,
                )
                with switch.case(global_step < boundary_var):
                    paddle.tensor.fill_constant(
                        shape=[1],
                        dtype="float32",
                        value=float(value),
                        out=lr,
                    )
            # Past every boundary: fall back to the final value.
            with switch.default():
                paddle.tensor.fill_constant(
                    shape=[1],
                    dtype="float32",
                    value=float(last_value),
                    out=lr,
                )
        return lr
W
Wu Yi 已提交
458 459


S
shippingwang 已提交
460
def cosine_decay(learning_rate, step_each_epoch, epochs):
    r"""

    Applies cosine decay to the learning rate.

    When training a model, it is often recommended to lower the learning rate as the
    training progresses. By using this function, the learning rate will be decayed by
    the following cosine decay strategy.

    .. math::

        decayed\_lr = learning\_rate * 0.5 * (\cos(epoch * \frac{\pi}{epochs}) + 1)

    where :math:`epoch = \lfloor global\_step / step\_each\_epoch \rfloor`.

    Args:
        learning_rate(Variable|float): The initial learning rate.
        step_each_epoch(int): the number of steps in an epoch.
        epochs(int): the number of epochs.

    Returns:
        Variable|LRScheduler: The decayed learning rate as a Variable in static
        graph mode. In dygraph mode, a ``paddle.optimizer.lr.LambdaDecay``
        scheduler that follows the same formula, with ``global_step`` being the
        scheduler's step counter (advanced by calling its ``step()`` method).

    Raises:
        TypeError: If ``learning_rate`` is neither a float nor a Variable.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            base_lr = 0.1
            lr = fluid.layers.cosine_decay(
                learning_rate=base_lr, step_each_epoch=10000, epochs=120)
    """
    check_type(
        learning_rate, 'learning_rate', (float, Variable), 'cosine_decay'
    )

    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            # ``CosineAnnealingDecay(learning_rate, epochs)`` treats every
            # scheduler step as an epoch and ignores ``step_each_epoch``.
            # LambdaDecay evaluates the documented per-epoch formula directly,
            # matching the static-graph branch below.
            def lr_factor(step):
                cur_epoch = math.floor(step / step_each_epoch)
                return 0.5 * (math.cos(cur_epoch * math.pi / epochs) + 1)

            return paddle.optimizer.lr.LambdaDecay(learning_rate, lr_factor)

        global_step = _decay_step_counter()

        cur_epoch = paddle.floor(global_step / step_each_epoch)
        decayed_lr = (
            learning_rate
            * 0.5
            * (paddle.cos(cur_epoch * math.pi / epochs) + 1)
        )
        return decayed_lr
S
shippingwang 已提交
509 510


511 512
def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
    """

    This operator uses the linear learning rate warm-up strategy to adjust the learning
    rate preliminarily, before the normal learning rate scheduling takes over.
    For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_

    When global_step < warmup_steps, the learning rate is updated as:

    .. code-block:: text

            linear_step = end_lr - start_lr
            lr = start_lr + linear_step * (global_step / warmup_steps)

    where start_lr is the initial learning rate, and end_lr is the final learning rate;

    When global_step >= warmup_steps, the learning rate is updated as:

    .. code-block:: text

            lr = learning_rate

    where lr is the learning_rate after warm-up.

    Args:
        learning_rate (Variable|float): Learning_rate after warm-up, it could be 1D-Tensor or single value with the data type of float32.
        warmup_steps (int): Steps for warm up.
        start_lr (float): Initial learning rate of warm up.
        end_lr (float): Final learning rate of warm up.

    Returns:
        Variable: Warm-up learning rate with the same data type as learning_rate
        (in dygraph mode, a ``paddle.optimizer.lr.LinearWarmup`` scheduler).


    Examples:

    .. code-block:: python

        import paddle.fluid as fluid

        boundaries = [100, 200]
        lr_steps = [0.1, 0.01, 0.001]
        learning_rate = fluid.layers.piecewise_decay(boundaries, lr_steps) #case1, 1D-Tensor
        #learning_rate = 0.1  #case2, single-value
        warmup_steps = 50
        start_lr = 1. / 3.
        end_lr = 0.1
        decayed_lr = fluid.layers.linear_lr_warmup(learning_rate,
            warmup_steps, start_lr, end_lr)

        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())
        out, = exe.run(fetch_list=[decayed_lr.name])
        print(out)
        # case1: [0.33333334]
        # case2: [0.33333334]
    """
    # The output variable inherits the dtype of a Variable learning rate.
    dtype = (
        learning_rate.dtype if isinstance(learning_rate, Variable) else 'float32'
    )
    linear_step = float(end_lr) - float(start_lr)

    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            return paddle.optimizer.lr.LinearWarmup(
                learning_rate, warmup_steps, start_lr, end_lr
            )

        lr = paddle.static.create_global_var(
            shape=[1],
            value=0.0,
            dtype=dtype,
            persistable=True,
            name="learning_rate_warmup",
        )

        global_step = _decay_step_counter()
        # ``case`` branches must return Variables, so wrap a scalar rate.
        if not isinstance(learning_rate, Variable):
            learning_rate = paddle.tensor.fill_constant(
                shape=[1], dtype=dtype, value=float(learning_rate)
            )

        def _during_warmup():
            return start_lr + linear_step * (global_step / float(warmup_steps))

        def _after_warmup():
            return learning_rate

        lr_val = paddle.static.nn.case(
            pred_fn_pairs=[(global_step < warmup_steps, _during_warmup)],
            default=_after_warmup,
        )
        paddle.assign(lr_val, lr)
        return lr