learning_rate_scheduler.py 20.7 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Y
yuyang18 已提交
14 15 16 17 18 19 20 21
"""
When training a model, it's often useful to decay the
learning rate during training process, this is called
learning_rate_decay. There are many strategies to do
this, this module will provide some classical method.
User can also implement their own learning_rate_decay
strategy according to this module.
"""
Q
Qiao Longfei 已提交
22

23
import math
Q
qingqing01 已提交
24
import numbers
25

26
import paddle
27 28
from . import nn
from . import tensor
姜永久 已提交
29 30 31 32 33 34 35
from ..framework import (
    default_main_program,
    Parameter,
    unique_name,
    name_scope,
    in_dygraph_mode,
)
Q
qingqing01 已提交
36
from ..framework import Variable
M
minqiyang 已提交
37
from ..dygraph import learning_rate_scheduler as imperate_lr
38
from ..data_feeder import check_variable_and_dtype, check_type
Q
Qiao Longfei 已提交
39

40
# Names exported by ``from <this module> import *``: the learning-rate
# schedule builders defined in this file.
__all__ = [
    'exponential_decay',
    'natural_exp_decay',
    'inverse_time_decay',
    'polynomial_decay',
    'piecewise_decay',
    'noam_decay',
    'cosine_decay',
    'linear_lr_warmup',
]
Q
Qiao Longfei 已提交
50 51


52
def _decay_step_counter(begin=0):
    """Return the learning-rate decay step counter cast to float32.

    Args:
        begin (int): Value forwarded as ``begin`` to
            ``nn.autoincreased_step_counter``. With the default of 0 the
            first global step used for learning rate decay is zero.

    Returns:
        The ``@LR_DECAY_COUNTER@`` step counter, cast to ``float32``.
    """
    step_counter = nn.autoincreased_step_counter(
        counter_name='@LR_DECAY_COUNTER@', begin=begin, step=1
    )
    return paddle.cast(step_counter, 'float32')


61
def noam_decay(d_model, warmup_steps, learning_rate=1.0):
    """
    Noam decay schedule.

    The schedule corresponds to the following numpy computation:

    .. code-block:: python

      import numpy as np
      # set hyper parameters
      base_lr = 0.01
      d_model = 2
      current_steps = 20
      warmup_steps = 200
      # compute
      lr_value = base_lr * np.power(d_model, -0.5) * np.min([
                              np.power(current_steps, -0.5),
                              np.power(warmup_steps, -1.5) * current_steps])

    See `Attention Is All You Need
    <https://arxiv.org/pdf/1706.03762.pdf>`_ for details.

    Args:
        d_model(Variable): The dimensionality of input and output of model.

        warmup_steps(Variable): The number of warm-up steps, a hyper
            parameter.

        learning_rate(Variable|float|int): The initial learning rate. If the
            type is Variable, it is a tensor with shape [1] whose data type
            can be float32 or float64. It can also be a Python int number.
            Default 1.0

    Returns:
        The decayed learning rate. In dygraph mode this is the object built
        by ``paddle.optimizer.lr.NoamDecay``; otherwise it is the value
        computed from the global step counter.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          warmup_steps = 100
          learning_rate = 0.01
          lr = fluid.layers.learning_rate_scheduler.noam_decay(
                         1/(warmup_steps *(learning_rate ** 2)),
                         warmup_steps,
                         learning_rate)
    """
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            return paddle.optimizer.lr.NoamDecay(
                d_model, warmup_steps, learning_rate=learning_rate
            )

        # The counter begins at 1 here (presumably so that step ** -0.5 is
        # well defined on the very first step).
        step = _decay_step_counter(1)

        inv_sqrt_step = step**-0.5
        warmup_ramp = (warmup_steps**-1.5) * step
        return (
            learning_rate
            * (d_model**-0.5)
            * paddle.minimum(inv_sqrt_step, warmup_ramp)
        )
119 120


Y
Yu Yang 已提交
121
def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """
    Applies exponential decay to the learning rate.

    Lowering the learning rate as training progresses is often recommended.
    With this function the learning rate is decayed by ``decay_rate`` every
    ``decay_steps`` steps.

    The decayed learning rate is computed as follows:

    .. code-block:: text

        if staircase:
            decayed_learning_rate = learning_rate * decay_rate ^ floor(global_step / decay_steps)
        else:
            decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)

    Args:
        learning_rate(Variable|float): The initial learning rate. It should
            be a Variable or a float.
        decay_steps(int): The learning rate decay steps. See the decay
            computation above.
        decay_rate(float): The learning rate decay rate. See the decay
            computation above.
        staircase(bool): If True, decay the learning rate at discrete
            intervals, i.e. the learning rate is decayed by ``decay_rate``
            every ``decay_steps``. If False, the learning rate is decayed
            continuously following the formula above. Default: False

    Returns:
        Variable: The decayed learning rate. The data type is float32.
        In dygraph mode the object built by
        ``paddle.optimizer.lr.ExponentialDecay`` is returned instead.

    Note:
        In dygraph mode only ``learning_rate`` and ``decay_rate`` are
        forwarded to ``paddle.optimizer.lr.ExponentialDecay``;
        ``decay_steps`` and ``staircase`` are not used on that path.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle

          paddle.enable_static()
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.exponential_decay(
                    learning_rate=base_lr,
                    decay_steps=10000,
                    decay_rate=0.5,
                    staircase=True))

    """
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            return paddle.optimizer.lr.ExponentialDecay(
                learning_rate, decay_rate
            )

        exponent = _decay_step_counter() / decay_steps
        if staircase:
            # Hold the exponent constant within each decay_steps interval.
            exponent = paddle.floor(exponent)
        return learning_rate * (decay_rate**exponent)
Q
Qiao Longfei 已提交
181 182


Y
Yu Yang 已提交
183
def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """
    Applies natural exponential decay to the initial learning rate.

    Lowering the learning rate as training progresses is often recommended.
    With this function the learning rate is decayed by the natural
    exponential power ``decay_rate`` every ``decay_steps`` steps.

    The decayed learning rate is computed as follows:

    .. code-block:: text

        if not staircase:
            decayed_learning_rate = learning_rate * exp(- decay_rate * (global_step / decay_steps))
        else:
            decayed_learning_rate = learning_rate * exp(- decay_rate * floor(global_step / decay_steps))

    Args:
        learning_rate(Variable|float): The initial learning rate. It should
            be a Variable or a float.
        decay_steps(int): The learning rate decay steps. See the decay
            computation above.
        decay_rate(float): The learning rate decay rate. See the decay
            computation above.
        staircase(bool): If True, decay the learning rate at discrete
            intervals, i.e. the learning rate is decayed by the natural
            exponential power ``decay_rate`` every ``decay_steps``. If
            False, the learning rate is decayed continuously following the
            formula above. Default: False

    Returns:
        The decayed learning rate. The data type is float32. In dygraph mode
        the object built by ``paddle.optimizer.lr.NaturalExpDecay`` is
        returned instead.

    Note:
        In dygraph mode only ``learning_rate`` and ``decay_rate`` are
        forwarded to ``paddle.optimizer.lr.NaturalExpDecay``;
        ``decay_steps`` and ``staircase`` are not used on that path.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle

          paddle.enable_static()
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.natural_exp_decay(
                    learning_rate=base_lr,
                    decay_steps=10000,
                    decay_rate=0.5,
                    staircase=True))

    """
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            return paddle.optimizer.lr.NaturalExpDecay(
                learning_rate, decay_rate
            )

        progress = _decay_step_counter() / decay_steps
        if staircase:
            # Hold the exponent constant within each decay_steps interval.
            progress = paddle.floor(progress)
        return learning_rate * paddle.exp(-1 * decay_rate * progress)
Q
Qiao Longfei 已提交
243 244


Y
Yu Yang 已提交
245
def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """
    Applies inverse time decay to the initial learning rate.

    Lowering the learning rate as training progresses is often recommended.
    With this function an inverse decay function is applied to the initial
    learning rate.

    The decayed learning rate is computed as follows:

    .. code-block:: text

        if staircase:
            decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / decay_steps))
        else:
            decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / decay_steps)

    Args:
        learning_rate(Variable|float): The initial learning rate. It should
            be a Variable or a float.
        decay_steps(int): The learning rate decay steps. See the decay
            computation above.
        decay_rate(float): The learning rate decay rate. See the decay
            computation above.
        staircase(bool): If True, decay the learning rate at discrete
            intervals, i.e. the decay term only changes every
            ``decay_steps``. If False, the learning rate is decayed
            continuously following the formula above. Default: False

    Returns:
        Variable: The decayed learning rate. The data type is float32.
        In dygraph mode the object built by
        ``paddle.optimizer.lr.InverseTimeDecay`` is returned instead.

    Note:
        In dygraph mode only ``learning_rate`` and ``decay_rate`` are
        forwarded to ``paddle.optimizer.lr.InverseTimeDecay``;
        ``decay_steps`` and ``staircase`` are not used on that path.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle
          paddle.enable_static()
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.inverse_time_decay(
                    learning_rate=base_lr,
                    decay_steps=10000,
                    decay_rate=0.5,
                    staircase=True))
    """
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            return paddle.optimizer.lr.InverseTimeDecay(
                learning_rate, decay_rate
            )

        elapsed = _decay_step_counter() / decay_steps
        if staircase:
            # Hold the decay term constant within each decay_steps interval.
            elapsed = paddle.floor(elapsed)

        return learning_rate / (1 + decay_rate * elapsed)
304 305


306 307 308
def polynomial_decay(
    learning_rate, decay_steps, end_learning_rate=0.0001, power=1.0, cycle=False
):
    """
    Applies polynomial decay to the initial learning rate.

    .. code-block:: text

        if cycle:
            decay_steps = decay_steps * ceil(global_step / decay_steps)
        else:
            global_step = min(global_step, decay_steps)
        decayed_learning_rate = (learning_rate - end_learning_rate) *
            (1 - global_step / decay_steps) ^ power + end_learning_rate

    Args:
        learning_rate(Variable|float32): A scalar float32 value or a
            Variable. This is the initial learning rate during training.
        decay_steps(int32): A Python `int32` number.
        end_learning_rate(float): A Python `float` number.
        power(float): A Python `float` number.
        cycle(bool): If True, ``decay_steps`` is scaled by
            ``ceil(global_step / decay_steps)`` instead of clamping
            ``global_step`` to ``decay_steps``.

    Returns:
        Variable: The decayed learning rate. In dygraph mode the object
        built by ``paddle.optimizer.lr.PolynomialDecay`` is returned
        instead.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          start_lr = 0.01
          total_step = 5000
          end_lr = 0
          lr = fluid.layers.polynomial_decay(
              start_lr, total_step, end_lr, power=1)

    """
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            return paddle.optimizer.lr.PolynomialDecay(
                learning_rate, decay_steps, end_learning_rate, power, cycle
            )

        step = _decay_step_counter()

        if cycle:
            num_cycles = paddle.ceil(step / decay_steps)
            zero = paddle.tensor.fill_constant(
                shape=[1], dtype='float32', value=0.0
            )
            one = paddle.tensor.fill_constant(
                shape=[1], dtype='float32', value=1.0
            )

            # When the step equals zero, use a multiplier of 1 rather than
            # the ceil result, then write the choice back into num_cycles.
            multiplier = paddle.static.nn.cond(
                step == zero, lambda: one, lambda: num_cycles
            )
            paddle.assign(multiplier, output=num_cycles)

            decay_steps = decay_steps * num_cycles
        else:
            # Without cycling the step is clamped so the decay stops at
            # decay_steps.
            step_limit = paddle.tensor.fill_constant(
                shape=[1], dtype='float32', value=float(decay_steps)
            )
            step = paddle.minimum(x=step, y=step_limit)

        return (learning_rate - end_learning_rate) * (
            (1 - step / decay_steps) ** power
        ) + end_learning_rate
377 378


Y
Yu Yang 已提交
379
def piecewise_decay(boundaries, values):
    """
    Applies piecewise decay to the initial learning rate.

    The algorithm can be described as the code below.

    .. code-block:: text

      boundaries = [10000, 20000]
      values = [1.0, 0.5, 0.1]
      if step < 10000:
          learning_rate = 1.0
      elif 10000 <= step < 20000:
          learning_rate = 0.5
      else:
          learning_rate = 0.1

    Args:
        boundaries: A list of step numbers.
        values: A list of learning rate values that will be picked during
            different step boundaries. It must contain exactly one more
            element than ``boundaries``.

    Returns:
        The decayed learning rate. In dygraph mode this is the object built
        by ``paddle.optimizer.lr.PiecewiseDecay``; otherwise it is the
        persistable float32 global variable named ``learning_rate``.

    Raises:
        ValueError: If ``len(values) - len(boundaries)`` is not 1.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle
          paddle.enable_static()
          boundaries = [10000, 20000]
          values = [1.0, 0.5, 0.1]
          optimizer = paddle.optimizer.Momentum(
              momentum=0.9,
              learning_rate=paddle.optimizer.lr.PiecewiseDecay(boundaries, values),
              weight_decay=paddle.regularizer.L2Decay(1e-4))

    """
    # Validate the arguments before entering the LR-schedule guard so that
    # invalid input is rejected without touching the main program.
    if len(values) - len(boundaries) != 1:
        raise ValueError("len(values) - len(boundaries) should be 1")

    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            return paddle.optimizer.lr.PiecewiseDecay(boundaries, values)

        global_step = _decay_step_counter()

        lr = paddle.static.create_global_var(
            shape=[1],
            value=0.0,
            dtype='float32',
            persistable=True,
            name="learning_rate",
        )
        with paddle.static.nn.control_flow.Switch() as switch:
            # zip stops at the shorter list, so this pairs every boundary
            # with the value that applies below it; the one remaining value
            # is handled by the default branch.
            for boundary, value in zip(boundaries, values):
                boundary_val = paddle.tensor.fill_constant(
                    shape=[1],
                    dtype='float32',
                    value=float(boundary),
                    force_cpu=True,
                )
                with switch.case(global_step < boundary_val):
                    paddle.tensor.fill_constant(
                        shape=[1],
                        dtype="float32",
                        value=float(value),
                        out=lr,
                    )
            with switch.default():
                paddle.tensor.fill_constant(
                    shape=[1],
                    dtype="float32",
                    value=float(values[-1]),
                    out=lr,
                )
        return lr
W
Wu Yi 已提交
459 460


S
shippingwang 已提交
461
def cosine_decay(learning_rate, step_each_epoch, epochs):
    r"""
    Applies cosine decay to the learning rate.

    Lowering the learning rate as training progresses is often recommended.
    With this function the learning rate is decayed following the cosine
    decay strategy.

    .. math::

        decayed\_lr = learning\_rate * 0.5 * (\cos(epoch * \frac{\pi}{epochs}) + 1)

    where ``epoch`` is ``floor(global_step / step_each_epoch)``.

    Args:
        learning_rate(Variable|float): The initial learning rate.
        step_each_epoch(int): The number of steps in an epoch.
        epochs(int): The number of epochs.

    Returns:
        Variable: The decayed learning rate. In dygraph mode the object
        built by ``paddle.optimizer.lr.CosineAnnealingDecay`` is returned
        instead.

    Note:
        In dygraph mode only ``learning_rate`` and ``epochs`` are forwarded
        to ``paddle.optimizer.lr.CosineAnnealingDecay``; ``step_each_epoch``
        is not used on that path.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            base_lr = 0.1
            lr = fluid.layers.cosine_decay(
            learning_rate = base_lr, step_each_epoch=10000, epochs=120)
    """
    check_type(
        learning_rate, 'learning_rate', (float, tensor.Variable), 'cosine_decay'
    )

    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            return paddle.optimizer.lr.CosineAnnealingDecay(
                learning_rate, epochs
            )

        epoch_index = paddle.floor(_decay_step_counter() / step_each_epoch)
        half_lr = learning_rate * 0.5
        return half_lr * (paddle.cos(epoch_index * math.pi / epochs) + 1)
S
shippingwang 已提交
510 511


512 513
def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
    """
    Linearly warm up the learning rate before the normal learning rate
    schedule takes over.

    For more information, please refer to `Bag of Tricks for Image
    Classification with Convolutional Neural Networks
    <https://arxiv.org/abs/1812.01187>`_

    When global_step < warmup_steps, the learning rate is:

    .. code-block:: text

            linear_step = end_lr - start_lr
            lr = start_lr + linear_step * (global_step / warmup_steps)

    where start_lr is the initial learning rate and end_lr is the final
    learning rate of the warm-up.

    When global_step >= warmup_steps, the learning rate is:

    .. code-block:: text

            lr = learning_rate

    where learning_rate is the learning rate after warm-up.

    Args:
        learning_rate (Variable|float): Learning rate after warm-up. It can
            be a 1D-Tensor or a single value with the data type of float32.
        warmup_steps (int): Steps for warm up.
        start_lr (float): Initial learning rate of warm up.
        end_lr (float): Final learning rate of warm up.

    Returns:
        Variable: Warm-up learning rate with the same data type as
        learning_rate. In dygraph mode the object built by
        ``paddle.optimizer.lr.LinearWarmup`` is returned instead.

    Examples:

        .. code-block:: python

            import paddle.fluid as fluid

            boundaries = [100, 200]
            lr_steps = [0.1, 0.01, 0.001]
            learning_rate = fluid.layers.piecewise_decay(boundaries, lr_steps) #case1, 1D-Tensor
            #learning_rate = 0.1  #case2, single-value
            warmup_steps = 50
            start_lr = 1. / 3.
            end_lr = 0.1
            decayed_lr = fluid.layers.linear_lr_warmup(learning_rate,
                warmup_steps, start_lr, end_lr)

            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())
            out, = exe.run(fetch_list=[decayed_lr.name])
            print(out)
            # case1: [0.33333334]
            # case2: [0.33333334]
    """
    # A Variable learning rate dictates the dtype; plain numbers use float32.
    dtype = (
        learning_rate.dtype if isinstance(learning_rate, Variable) else 'float32'
    )

    linear_step = float(end_lr) - float(start_lr)
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            return paddle.optimizer.lr.LinearWarmup(
                learning_rate, warmup_steps, start_lr, end_lr
            )

        warmup_lr = paddle.static.create_global_var(
            shape=[1],
            value=0.0,
            dtype=dtype,
            persistable=True,
            name="learning_rate_warmup",
        )

        global_step = _decay_step_counter()
        if isinstance(learning_rate, Variable):
            post_warmup_lr = learning_rate
        else:
            # Wrap a plain number in a constant tensor for the default branch.
            post_warmup_lr = paddle.tensor.fill_constant(
                shape=[1], dtype=dtype, value=float(learning_rate)
            )

        def _during_warmup():
            return start_lr + linear_step * (
                global_step / float(warmup_steps)
            )

        def _after_warmup():
            return post_warmup_lr

        selected_lr = paddle.static.nn.case(
            pred_fn_pairs=[(global_step < warmup_steps, _during_warmup)],
            default=_after_warmup,
        )
        paddle.assign(selected_lr, warmup_lr)
        return warmup_lr