# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
When training a model, it's often useful to decay the
learning rate during training process, this is called
learning_rate_decay. There are many strategies to do
this, this module will provide some classical method.
User can also implement their own learning_rate_decay
strategy according to this module.
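
For example, a decayed learning rate is usually passed straight to an
optimizer (mirroring the `exponential_decay` example below):

.. code-block:: python

    import paddle
    import paddle.fluid as fluid

    paddle.enable_static()
    base_lr = 0.1
    sgd_optimizer = fluid.optimizer.SGD(
        learning_rate=fluid.layers.exponential_decay(
            learning_rate=base_lr,
            decay_steps=10000,
            decay_rate=0.5,
            staircase=True))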
"""

import math
import numbers

from . import control_flow
from . import nn
from . import ops
from . import tensor
from ..framework import default_main_program, Parameter, unique_name, name_scope
from ..framework import Variable
from ..framework import _non_static_mode
from ..dygraph import learning_rate_scheduler as imperate_lr
from ..data_feeder import check_variable_and_dtype, check_type

__all__ = [
    'exponential_decay',
    'natural_exp_decay',
    'inverse_time_decay',
    'polynomial_decay',
    'piecewise_decay',
    'noam_decay',
    'cosine_decay',
    'linear_lr_warmup',
]


def _decay_step_counter(begin=0):
    # the first global step is zero in learning rate decay
    global_step = nn.autoincreased_step_counter(
        counter_name='@LR_DECAY_COUNTER@', begin=begin, step=1
    )
    global_step = tensor.cast(global_step, 'float32')
    return global_step


def noam_decay(d_model, warmup_steps, learning_rate=1.0):
    """

    Noam decay method. The numpy implementation of noam decay is as follows.

    .. code-block:: python

      import paddle.fluid as fluid
      import numpy as np
      # set hyper parameters
      base_lr = 0.01
      d_model = 2
      current_steps = 20
      warmup_steps = 200
      # compute
      lr_value = base_lr * np.power(d_model, -0.5) * np.min([
                              np.power(current_steps, -0.5),
                              np.power(warmup_steps, -1.5) * current_steps])

    Please refer to `Attention Is All You Need
    <https://arxiv.org/pdf/1706.03762.pdf>`_.

    Args:
        d_model(Variable): The dimensionality of the input and output of the model.

        warmup_steps(Variable): The number of warmup steps, a hyperparameter.

        learning_rate(Variable|float|int): The initial learning rate. If the type
            is Variable, it is a tensor with shape [1] and data type float32 or
            float64. It can also be set to a Python int. Default: 1.0.

    Returns:
        The decayed learning rate.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          warmup_steps = 100
          learning_rate = 0.01
          lr = fluid.layers.learning_rate_scheduler.noam_decay(
                         1/(warmup_steps *(learning_rate ** 2)),
                         warmup_steps,
                         learning_rate)
    """
    with default_main_program()._lr_schedule_guard():
        if _non_static_mode():
            decay = imperate_lr.NoamDecay(
                d_model, warmup_steps, learning_rate=learning_rate
            )
            return decay
        else:
            global_step = _decay_step_counter(1)

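            # Noam schedule (static graph): scale learning_rate by
            # d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5)).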
            a = global_step**-0.5
            b = (warmup_steps**-1.5) * global_step
            lr_value = (
                learning_rate * (d_model**-0.5) * nn.elementwise_min(a, b)
            )

            return lr_value


def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """

    Applies exponential decay to the learning rate.

    When training a model, it is often recommended to lower the learning rate as the
    training progresses. By using this function, the learning rate will be decayed by
    'decay_rate' every 'decay_steps' steps.

    The decayed learning rate is calculated as follows:

    >>> if staircase == True:
    >>>     decayed_learning_rate = learning_rate * decay_rate ^ floor(global_step / decay_steps)
    >>> else:
    >>>     decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)

    Args:
        learning_rate(Variable|float): The initial learning rate. It should be a Variable
                                       or a float.
        decay_steps(int): The learning rate decay steps. See the decay computation above.
        decay_rate(float): The learning rate decay rate. See the decay computation above.
        staircase(bool): If True, decay the learning rate at discrete intervals, which
                         means the learning rate will be decayed by `decay_rate` every
                         `decay_steps`. If False, the learning rate will be decayed
                         continuously following the formula above. Default: False.

    Returns:
        Variable: The decayed learning rate. The data type is float32.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle

          paddle.enable_static()
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.exponential_decay(
                    learning_rate=base_lr,
                    decay_steps=10000,
                    decay_rate=0.5,
                    staircase=True))

    """
    with default_main_program()._lr_schedule_guard():
        if _non_static_mode():
            decay = imperate_lr.ExponentialDecay(
                learning_rate, decay_steps, decay_rate, staircase
            )
            return decay
        else:
            global_step = _decay_step_counter()

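            # decayed_lr = learning_rate * decay_rate ^ (global_step / decay_steps);
            # the exponent is floored when staircase is True.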
            div_res = global_step / decay_steps
            if staircase:
                div_res = ops.floor(div_res)
            decayed_lr = learning_rate * (decay_rate**div_res)

            return decayed_lr


def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """

    Applies natural exponential decay to the initial learning rate.

    When training a model, it is often recommended to lower the learning rate as the
    training progresses. By using this function, the learning rate will be decayed by
    natural exponential power 'decay_rate' every 'decay_steps' steps.

    The decayed learning rate is calculated as follows:

    >>> if not staircase:
    >>>     decayed_learning_rate = learning_rate * exp(- decay_rate * (global_step / decay_steps))
    >>> else:
    >>>     decayed_learning_rate = learning_rate * exp(- decay_rate * floor(global_step / decay_steps))

    Args:
        learning_rate(Variable|float): The initial learning rate. It should be a Variable
                                       or a float.
        decay_steps(int): The learning rate decay steps. See the decay computation above.
        decay_rate(float): The learning rate decay rate. See the decay computation above.
        staircase(bool): If True, decay the learning rate at discrete intervals, which
                         means the learning rate will be decayed by natural exponential power
                         `decay_rate` every `decay_steps`. If False, the learning rate will be
                         decayed continuously following the formula above. Default: False.

    Returns:
        The decayed learning rate. The data type is float32.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle

          paddle.enable_static()
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.natural_exp_decay(
                    learning_rate=base_lr,
                    decay_steps=10000,
                    decay_rate=0.5,
                    staircase=True))

    """
    with default_main_program()._lr_schedule_guard():
        if _non_static_mode():
            decay = imperate_lr.NaturalExpDecay(
                learning_rate, decay_steps, decay_rate, staircase
            )
            return decay
        else:
            global_step = _decay_step_counter()

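            # decayed_lr = learning_rate * exp(-decay_rate * global_step / decay_steps);
            # the step ratio is floored when staircase is True.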
            div_res = global_step / decay_steps
            if staircase:
                div_res = ops.floor(div_res)
            decayed_lr = learning_rate * ops.exp(-1 * decay_rate * div_res)

            return decayed_lr


def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """

    Applies inverse time decay to the initial learning rate.

    When training a model, it is often recommended to lower the learning rate as the
    training progresses. By using this function, an inverse decay function will be
    applied to the initial learning rate.

    The decayed learning rate is calculated as follows:

    >>> if staircase == True:
    >>>     decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / decay_step))
    >>> else:
    >>>     decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / decay_step)

    Args:
        learning_rate(Variable|float): The initial learning rate. It should be a Variable
                                       or a float.
        decay_steps(int): The learning rate decay steps. See the decay computation above.
        decay_rate(float): The learning rate decay rate. See the decay computation above.
        staircase(bool): If True, decay the learning rate at discrete intervals, which
                         means the learning rate will be decayed by `decay_rate` times
                         every `decay_steps`. If False, the learning rate will be decayed
                         continuously following the formula above. Default: False.

    Returns:
        Variable: The decayed learning rate. The data type is float32.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle
          paddle.enable_static()
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.inverse_time_decay(
                    learning_rate=base_lr,
                    decay_steps=10000,
                    decay_rate=0.5,
                    staircase=True))
    """
    with default_main_program()._lr_schedule_guard():
        if _non_static_mode():
            decay = imperate_lr.InverseTimeDecay(
                learning_rate, decay_steps, decay_rate, staircase
            )
            return decay
        else:
            global_step = _decay_step_counter()

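            # decayed_lr = learning_rate / (1 + decay_rate * global_step / decay_steps);
            # the step ratio is floored when staircase is True.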
            div_res = global_step / decay_steps
            if staircase:
                div_res = ops.floor(div_res)

            decayed_lr = learning_rate / (1 + decay_rate * div_res)

            return decayed_lr


def polynomial_decay(
    learning_rate, decay_steps, end_learning_rate=0.0001, power=1.0, cycle=False
):
    """
    Applies polynomial decay to the initial learning rate.

    .. code-block:: text

     if cycle:
       decay_steps = decay_steps * ceil(global_step / decay_steps)
     else:
       global_step = min(global_step, decay_steps)
       decayed_learning_rate = (learning_rate - end_learning_rate) *
            (1 - global_step / decay_steps) ^ power + end_learning_rate

    Args:
        learning_rate(Variable|float32): A scalar float32 value or a Variable. This
          will be the initial learning rate during training.
        decay_steps(int32): A Python `int` number.
        end_learning_rate(float): A Python `float` number.
        power(float): A Python `float` number.
        cycle(bool): If set to True, decay the learning rate every decay_steps.

    Returns:
        Variable: The decayed learning rate.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          start_lr = 0.01
          total_step = 5000
          end_lr = 0
          lr = fluid.layers.polynomial_decay(
              start_lr, total_step, end_lr, power=1)

    """
    with default_main_program()._lr_schedule_guard():
        if _non_static_mode():
            decay = imperate_lr.PolynomialDecay(
                learning_rate, decay_steps, end_learning_rate, power, cycle
            )
            return decay
        else:
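            # decayed_lr = (learning_rate - end_learning_rate)
            #              * (1 - global_step / decay_steps) ^ power + end_learning_rate.
            # When cycle is True, decay_steps is scaled up to the end of the current
            # cycle; otherwise global_step is clipped to decay_steps.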
            global_step = _decay_step_counter()

            if cycle:
                div_res = ops.ceil(global_step / decay_steps)
                zero_var = tensor.fill_constant(
                    shape=[1], dtype='float32', value=0.0
                )
                one_var = tensor.fill_constant(
                    shape=[1], dtype='float32', value=1.0
                )

                with control_flow.Switch() as switch:
                    with switch.case(global_step == zero_var):
                        tensor.assign(input=one_var, output=div_res)
                decay_steps = decay_steps * div_res
            else:
                decay_steps_var = tensor.fill_constant(
                    shape=[1], dtype='float32', value=float(decay_steps)
                )
                global_step = nn.elementwise_min(
                    x=global_step, y=decay_steps_var
                )

            decayed_lr = (learning_rate - end_learning_rate) * (
                (1 - global_step / decay_steps) ** power
            ) + end_learning_rate
            return decayed_lr


def piecewise_decay(boundaries, values):
    """

    Applies piecewise decay to the initial learning rate.

    The algorithm can be described as the code below:

    .. code-block:: text

      boundaries = [10000, 20000]
      values = [1.0, 0.5, 0.1]
      if step < 10000:
          learning_rate = 1.0
      elif 10000 <= step < 20000:
          learning_rate = 0.5
      else:
          learning_rate = 0.1

    Args:
        boundaries: A list of step numbers.
        values: A list of learning rate values that will be picked during
            different step boundaries.

    Returns:
        The decayed learning rate.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle
          paddle.enable_static()
          boundaries = [10000, 20000]
          values = [1.0, 0.5, 0.1]
          optimizer = fluid.optimizer.Momentum(
              momentum=0.9,
              learning_rate=fluid.layers.piecewise_decay(boundaries=boundaries, values=values),
              regularization=fluid.regularizer.L2Decay(1e-4))

    """
    with default_main_program()._lr_schedule_guard():
        if len(values) - len(boundaries) != 1:
            raise ValueError("len(values) - len(boundaries) should be 1")

        if _non_static_mode():
            decay = imperate_lr.PiecewiseDecay(boundaries, values, 0)
            return decay
        else:
            global_step = _decay_step_counter()

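            # A persistable global variable holds the learning rate; the Switch
            # below assigns values[i] while global_step < boundaries[i], and the
            # last value once all boundaries have been passed.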
            lr = tensor.create_global_var(
                shape=[1],
                value=0.0,
                dtype='float32',
                persistable=True,
                name="learning_rate",
            )

            with control_flow.Switch() as switch:
                for i in range(len(boundaries)):
                    boundary_val = tensor.fill_constant(
                        shape=[1],
                        dtype='float32',
                        value=float(boundaries[i]),
                        force_cpu=True,
                    )
                    with switch.case(global_step < boundary_val):
                        tensor.fill_constant(
                            shape=[1],
                            dtype="float32",
                            value=float(values[i]),
                            out=lr,
                        )
                with switch.default():
                    tensor.fill_constant(
                        shape=[1],
                        dtype="float32",
                        value=float(values[len(values) - 1]),
                        out=lr,
                    )

            return lr


def cosine_decay(learning_rate, step_each_epoch, epochs):
    r"""

    Applies cosine decay to the learning rate.

    When training a model, it is often recommended to lower the learning rate as the
    training progresses. By using this function, the learning rate will be decayed by
    following the cosine decay strategy.

    .. math::

        decayed\_lr = learning\_rate * 0.5 * (\cos(epoch * \frac{\pi}{epochs}) + 1)

    Args:
        learning_rate(Variable|float): The initial learning rate.
        step_each_epoch(int): the number of steps in an epoch.
        epochs(int): the number of epochs.

    Returns:
        Variable: The decayed learning rate.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            base_lr = 0.1
            lr = fluid.layers.cosine_decay(
                learning_rate = base_lr, step_each_epoch=10000, epochs=120)
    """
    check_type(
        learning_rate, 'learning_rate', (float, tensor.Variable), 'cosine_decay'
    )

    with default_main_program()._lr_schedule_guard():
        if _non_static_mode():
            decay = imperate_lr.CosineDecay(
                learning_rate, step_each_epoch, epochs
            )
            return decay
        else:
            global_step = _decay_step_counter()

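            # decayed_lr = learning_rate * 0.5 * (cos(cur_epoch * pi / epochs) + 1),
            # where cur_epoch = floor(global_step / step_each_epoch).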
            cur_epoch = ops.floor(global_step / step_each_epoch)
            decayed_lr = (
                learning_rate
                * 0.5
                * (ops.cos(cur_epoch * math.pi / epochs) + 1)
            )
            return decayed_lr


def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
    """

    This operator uses the linear learning rate warm-up strategy to adjust the
    learning rate before the normal learning rate scheduling starts.
    For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_

    When global_step < warmup_steps, learning rate is updated as:

    .. code-block:: text

            linear_step = end_lr - start_lr
            lr = start_lr + linear_step * (global_step / warmup_steps)

    where start_lr is the initial learning rate, and end_lr is the final learning rate;

    When global_step >= warmup_steps, learning rate is updated as:

    .. code-block:: text

            lr = learning_rate

    where lr is the learning_rate after warm-up.

    Args:
        learning_rate (Variable|float): Learning rate after warm-up. It could be a 1D Tensor or a single value with the data type of float32.
        warmup_steps (int): Steps for warm up.
        start_lr (float): Initial learning rate of warm up.
        end_lr (float): Final learning rate of warm up.

    Returns:
        Variable: Warm-up learning rate with the same data type as learning_rate.

    Examples:

    .. code-block:: python

        import paddle.fluid as fluid

        boundaries = [100, 200]
        lr_steps = [0.1, 0.01, 0.001]
        learning_rate = fluid.layers.piecewise_decay(boundaries, lr_steps) #case1, 1D-Tensor
        #learning_rate = 0.1  #case2, single-value
        warmup_steps = 50
        start_lr = 1. / 3.
        end_lr = 0.1
        decayed_lr = fluid.layers.linear_lr_warmup(learning_rate,
            warmup_steps, start_lr, end_lr)

        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())
        out, = exe.run(fetch_list=[decayed_lr.name])
        print(out)
        # case1: [0.33333334]
        # case2: [0.33333334]

    """
    dtype = 'float32'
    if isinstance(learning_rate, Variable):
        dtype = learning_rate.dtype

    linear_step = float(end_lr) - float(start_lr)
    with default_main_program()._lr_schedule_guard():

        if _non_static_mode():
            lr = imperate_lr.LinearLrWarmup(
                learning_rate, warmup_steps, start_lr, end_lr
            )
            return lr
        else:
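            # A global variable holds the warm-up learning rate: for the first
            # warmup_steps steps it is linearly interpolated from start_lr towards
            # end_lr, after which the wrapped learning_rate is used unchanged.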
            lr = tensor.create_global_var(
                shape=[1],
                value=0.0,
                dtype=dtype,
                persistable=True,
                name="learning_rate_warmup",
            )

            global_step = _decay_step_counter()

            with control_flow.Switch() as switch:
                with switch.case(global_step < warmup_steps):
                    decayed_lr = start_lr + linear_step * (
                        global_step / float(warmup_steps)
                    )
                    tensor.assign(decayed_lr, lr)
                with switch.default():
                    if not isinstance(learning_rate, Variable):
                        learning_rate = tensor.fill_constant(
                            shape=[1], dtype=dtype, value=float(learning_rate)
                        )
                    tensor.assign(learning_rate, lr)
            return lr