# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
When training a model, it is often useful to decay the
learning rate during training; this is called
learning_rate_decay. There are many strategies to do
this, and this module provides several classical methods.
Users can also implement their own learning_rate_decay
strategy based on this module.
"""

from __future__ import print_function

import math
import numbers

from . import control_flow
from . import nn
from . import ops
from . import tensor
from ..framework import default_main_program, Parameter, unique_name, name_scope
from ..framework import Variable
from ..framework import in_dygraph_mode
from ..dygraph import learning_rate_scheduler as imperate_lr
from ..data_feeder import check_variable_and_dtype, check_type

__all__ = [
    'exponential_decay', 'natural_exp_decay', 'inverse_time_decay',
    'polynomial_decay', 'piecewise_decay', 'noam_decay', 'cosine_decay',
    'linear_lr_warmup'
]


def _decay_step_counter(begin=0):
    # the first global step is zero in learning rate decay
    global_step = nn.autoincreased_step_counter(
        counter_name='@LR_DECAY_COUNTER@', begin=begin, step=1)
    global_step = tensor.cast(global_step, 'float32')
    return global_step


def noam_decay(d_model, warmup_steps, learning_rate=1.0):
    """

    Noam decay method. The numpy implementation of noam decay is as follows.

    .. code-block:: python

      import paddle.fluid as fluid
      import numpy as np
      # set hyperparameters
      base_lr = 0.01
      d_model = 2
      current_steps = 20
      warmup_steps = 200
      # compute
      lr_value = base_lr * np.power(d_model, -0.5) * np.min([
                              np.power(current_steps, -0.5),
                              np.power(warmup_steps, -1.5) * current_steps])

    Please refer to `Attention Is All You Need
    <https://arxiv.org/pdf/1706.03762.pdf>`_.

    Args:
        d_model(Variable): The dimensionality of input and output of the model.

        warmup_steps(Variable): The number of warmup steps, a hyperparameter.

        learning_rate(Variable|float|int): The initial learning rate. If the type
            is Variable, it is a tensor with shape [1] whose data type can be
            float32 or float64. It can also be set to a Python int number. Default 1.0

    Returns:
        The decayed learning rate.
    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          warmup_steps = 100
          learning_rate = 0.01
          lr = fluid.layers.learning_rate_scheduler.noam_decay(
                         1 / (warmup_steps * (learning_rate ** 2)),
                         warmup_steps,
                         learning_rate)
    """
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            decay = imperate_lr.NoamDecay(
                d_model, warmup_steps, learning_rate=learning_rate)
            return decay
        else:
            global_step = _decay_step_counter(1)

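            # a is the decay term (global_step ** -0.5) and b is the warm-up term
            # (warmup_steps ** -1.5 * global_step); the schedule follows their
            # minimum, scaled by learning_rate * d_model ** -0.5.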
            a = global_step**-0.5
            b = (warmup_steps**-1.5) * global_step
            lr_value = learning_rate * (d_model**-0.5) * nn.elementwise_min(a,
                                                                            b)

            return lr_value


def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """

    Applies exponential decay to the learning rate.

    When training a model, it is often recommended to lower the learning rate as the
    training progresses. By using this function, the learning rate will be decayed by
    'decay_rate' every 'decay_steps' steps.

    The decayed learning rate is calculated as follows:

    >>> if staircase == True:
    >>>     decayed_learning_rate = learning_rate * decay_rate ^ floor(global_step / decay_steps)
    >>> else:
    >>>     decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)
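
    As a quick numerical sketch of the staircase and continuous variants above
    (the values below are illustrative only, not defaults):

    .. code-block:: python

      import numpy as np
      base_lr, decay_rate, decay_steps = 0.1, 0.5, 10000
      global_step = 25000.0
      # staircase: 0.1 * 0.5 ** floor(2.5) = 0.025
      staircase_lr = base_lr * decay_rate**np.floor(global_step / decay_steps)
      # continuous: 0.1 * 0.5 ** 2.5 ~= 0.0177
      continuous_lr = base_lr * decay_rate**(global_step / decay_steps)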

    Args:
        learning_rate(Variable|float): The initial learning rate. It should be a Variable
                                       or a float.
        decay_steps(int): The learning rate decay steps. See the decay computation above.
        decay_rate(float): The learning rate decay rate. See the decay computation above.
        staircase(bool): If True, decay the learning rate at discrete intervals, which
                         means the learning rate will be decayed by `decay_rate` every
                         `decay_steps` steps. If False, the learning rate will be decayed
                         continuously, following the formula above. Default: False

    Returns:
        Variable: The decayed learning rate. The data type is float32.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle

          paddle.enable_static()
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.exponential_decay(
                  learning_rate=base_lr,
                  decay_steps=10000,
                  decay_rate=0.5,
                  staircase=True))

    """
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            decay = imperate_lr.ExponentialDecay(learning_rate, decay_steps,
                                                 decay_rate, staircase)
            return decay
        else:
            global_step = _decay_step_counter()

            div_res = global_step / decay_steps
            if staircase:
                div_res = ops.floor(div_res)
            decayed_lr = learning_rate * (decay_rate**div_res)

            return decayed_lr


def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """

    Applies natural exponential decay to the initial learning rate.

    When training a model, it is often recommended to lower the learning rate as the
    training progresses. By using this function, the learning rate will be decayed by
    natural exponential power 'decay_rate' every 'decay_steps' steps.

    The decayed learning rate is calculated as follows:

    >>> if not staircase:
    >>>     decayed_learning_rate = learning_rate * exp(- decay_rate * (global_step / decay_steps))
    >>> else:
    >>>     decayed_learning_rate = learning_rate * exp(- decay_rate * floor(global_step / decay_steps))
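
    As a quick numerical sketch (the values below are illustrative only, not defaults):

    .. code-block:: python

      import numpy as np
      base_lr, decay_rate, decay_steps = 0.1, 0.5, 10000
      global_step = 25000.0
      # staircase: 0.1 * exp(-0.5 * floor(2.5)) ~= 0.0368
      staircase_lr = base_lr * np.exp(-decay_rate * np.floor(global_step / decay_steps))
      # continuous: 0.1 * exp(-0.5 * 2.5) ~= 0.0287
      continuous_lr = base_lr * np.exp(-decay_rate * (global_step / decay_steps))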

    Args:
        learning_rate(Variable|float): The initial learning rate. It should be a Variable
                                       or a float.
        decay_steps(int): The learning rate decay steps. See the decay computation above.
        decay_rate(float): The learning rate decay rate. See the decay computation above.
        staircase(bool): If True, decay the learning rate at discrete intervals, which
                         means the learning rate will be decayed by natural exponential power
                         `decay_rate` every `decay_steps` steps. If False, the learning rate will
                         be decayed continuously, following the formula above. Default: False

    Returns:
        The decayed learning rate. The data type is float32.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle

          paddle.enable_static()
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.natural_exp_decay(
                  learning_rate=base_lr,
                  decay_steps=10000,
                  decay_rate=0.5,
                  staircase=True))

    """
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            decay = imperate_lr.NaturalExpDecay(learning_rate, decay_steps,
                                                decay_rate, staircase)
            return decay
        else:
            global_step = _decay_step_counter()

            div_res = global_step / decay_steps
            if staircase:
                div_res = ops.floor(div_res)
            decayed_lr = learning_rate * ops.exp(-1 * decay_rate * div_res)

            return decayed_lr


def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """

    Applies inverse time decay to the initial learning rate.

    When training a model, it is often recommended to lower the learning rate as the
    training progresses. By using this function, an inverse decay function will be
    applied to the initial learning rate.

    The decayed learning rate is calculated as follows:

    >>> if staircase == True:
    >>>     decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / decay_step))
    >>> else:
    >>>     decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / decay_step)
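
    As a quick numerical sketch (the values below are illustrative only, not defaults):

    .. code-block:: python

      import numpy as np
      base_lr, decay_rate, decay_steps = 0.1, 0.5, 10000
      global_step = 25000.0
      # staircase: 0.1 / (1 + 0.5 * floor(2.5)) = 0.05
      staircase_lr = base_lr / (1 + decay_rate * np.floor(global_step / decay_steps))
      # continuous: 0.1 / (1 + 0.5 * 2.5) ~= 0.0444
      continuous_lr = base_lr / (1 + decay_rate * global_step / decay_steps)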

    Args:
        learning_rate(Variable|float): The initial learning rate. It should be a Variable
                                       or a float.
        decay_steps(int): The learning rate decay steps. See the decay computation above.
        decay_rate(float): The learning rate decay rate. See the decay computation above.
        staircase(bool): If True, decay the learning rate at discrete intervals, which
                         means the learning rate will be decayed by `decay_rate` times
                         every `decay_steps` steps. If False, the learning rate will be decayed
                         continuously, following the formula above. Default: False

    Returns:
        Variable: The decayed learning rate. The data type is float32.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle
          paddle.enable_static()
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.inverse_time_decay(
                  learning_rate=base_lr,
                  decay_steps=10000,
                  decay_rate=0.5,
                  staircase=True))
    """
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            decay = imperate_lr.InverseTimeDecay(learning_rate, decay_steps,
                                                 decay_rate, staircase)
            return decay
        else:
            global_step = _decay_step_counter()

            div_res = global_step / decay_steps
            if staircase:
                div_res = ops.floor(div_res)

            decayed_lr = learning_rate / (1 + decay_rate * div_res)

            return decayed_lr


def polynomial_decay(learning_rate,
                     decay_steps,
                     end_learning_rate=0.0001,
                     power=1.0,
                     cycle=False):
    """
    Applies polynomial decay to the initial learning rate.

    .. code-block:: text

     if cycle:
       decay_steps = decay_steps * ceil(global_step / decay_steps)
     else:
       global_step = min(global_step, decay_steps)
     decayed_learning_rate = (learning_rate - end_learning_rate) *
          (1 - global_step / decay_steps) ^ power + end_learning_rate
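
    As a quick numerical sketch with power=1 (the values below are illustrative and
    match the Examples section):

    .. code-block:: python

      start_lr, end_lr, total_step, power = 0.01, 0.0, 5000, 1.0
      global_step = min(1000.0, float(total_step))
      # linear interpolation from 0.01 down to 0.0: 0.01 * (1 - 0.2) = 0.008
      lr = (start_lr - end_lr) * (1 - global_step / total_step)**power + end_lr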

    Args:
        learning_rate(Variable|float32): A scalar float32 value or a Variable. This
            will be the initial learning rate during training.
        decay_steps(int32): A Python `int32` number.
        end_learning_rate(float): A Python `float` number.
        power(float): A Python `float` number.
        cycle(bool): If set to True, decay the learning rate every decay_steps.

    Returns:
        Variable: The decayed learning rate.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          start_lr = 0.01
          total_step = 5000
          end_lr = 0
          lr = fluid.layers.polynomial_decay(
              start_lr, total_step, end_lr, power=1)

    """
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            decay = imperate_lr.PolynomialDecay(learning_rate, decay_steps,
                                                end_learning_rate, power, cycle)
            return decay
        else:
            global_step = _decay_step_counter()

            if cycle:
                div_res = ops.ceil(global_step / decay_steps)
                zero_var = tensor.fill_constant(
                    shape=[1], dtype='float32', value=0.0)
                one_var = tensor.fill_constant(
                    shape=[1], dtype='float32', value=1.0)

                with control_flow.Switch() as switch:
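                    # At step 0, ceil(0 / decay_steps) is 0, which would make the
                    # rescaled decay_steps zero; assign 1 so the very first step
                    # counts as one full decay cycle.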
                    with switch.case(global_step == zero_var):
                        tensor.assign(input=one_var, output=div_res)
                decay_steps = decay_steps * div_res
            else:
                decay_steps_var = tensor.fill_constant(
                    shape=[1], dtype='float32', value=float(decay_steps))
                global_step = nn.elementwise_min(
                    x=global_step, y=decay_steps_var)

            decayed_lr = (learning_rate - end_learning_rate) * \
                ((1 - global_step / decay_steps) ** power) + end_learning_rate
            return decayed_lr


def piecewise_decay(boundaries, values):
    """

    Applies piecewise decay to the initial learning rate.

    The algorithm can be described as the code below.

    .. code-block:: text

      boundaries = [10000, 20000]
      values = [1.0, 0.5, 0.1]
      if step < 10000:
          learning_rate = 1.0
      elif 10000 <= step < 20000:
          learning_rate = 0.5
      else:
          learning_rate = 0.1
    Args:
        boundaries: A list of step numbers.
        values: A list of learning rate values that will be picked during
            different step boundaries.

    Returns:
        The decayed learning rate.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle
          paddle.enable_static()
          boundaries = [10000, 20000]
          values = [1.0, 0.5, 0.1]
          optimizer = fluid.optimizer.Momentum(
              momentum=0.9,
              learning_rate=fluid.layers.piecewise_decay(boundaries=boundaries, values=values),
              regularization=fluid.regularizer.L2Decay(1e-4))


    """
    with default_main_program()._lr_schedule_guard():
        if len(values) - len(boundaries) != 1:
            raise ValueError("len(values) - len(boundaries) should be 1")

        if in_dygraph_mode():
            decay = imperate_lr.PiecewiseDecay(boundaries, values, 0)
            return decay
        else:
            global_step = _decay_step_counter()

            lr = tensor.create_global_var(
                shape=[1],
                value=0.0,
                dtype='float32',
                persistable=True,
                name="learning_rate")

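            # Pick the value paired with the first boundary that global_step has
            # not yet reached; if global_step is past every boundary, the default
            # case below assigns the last value.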
            with control_flow.Switch() as switch:
                for i in range(len(boundaries)):
                    boundary_val = tensor.fill_constant(
                        shape=[1],
                        dtype='float32',
                        value=float(boundaries[i]),
                        force_cpu=True)
                    value_var = tensor.fill_constant(
                        shape=[1], dtype='float32', value=float(values[i]))
                    with switch.case(global_step < boundary_val):
                        tensor.assign(value_var, lr)
                last_value_var = tensor.fill_constant(
                    shape=[1],
                    dtype='float32',
                    value=float(values[len(values) - 1]))
                with switch.default():
                    tensor.assign(last_value_var, lr)

            return lr


def cosine_decay(learning_rate, step_each_epoch, epochs):
    """

    Applies cosine decay to the learning rate.

    When training a model, it is often recommended to lower the learning rate as the
    training progresses. By using this function, the learning rate will be decayed by
    following cosine decay strategy.

    .. math::

        decayed\_lr = learning\_rate * 0.5 * (cos(epoch * \\frac{math.pi}{epochs}) + 1)
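
    As a quick numerical sketch (the values below match the Examples section):

    .. code-block:: python

        import numpy as np
        base_lr, step_each_epoch, epochs = 0.1, 10000, 120
        global_step = 600000.0
        cur_epoch = np.floor(global_step / step_each_epoch)  # 60.0
        # halfway through training, cos(pi / 2) == 0, so lr is half of base_lr
        lr = base_lr * 0.5 * (np.cos(cur_epoch * np.pi / epochs) + 1)  # 0.05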

    Args:
        learning_rate(Variable|float): The initial learning rate.
        step_each_epoch(int): The number of steps in an epoch.
        epochs(int): The number of epochs.

    Returns:
        Variable: The decayed learning rate.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            base_lr = 0.1
            lr = fluid.layers.cosine_decay(
                learning_rate=base_lr, step_each_epoch=10000, epochs=120)
    """
    check_type(learning_rate, 'learning_rate', (float, tensor.Variable),
               'cosine_decay')

    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            decay = imperate_lr.CosineDecay(learning_rate, step_each_epoch,
                                            epochs)
            return decay
        else:
            global_step = _decay_step_counter()

            cur_epoch = ops.floor(global_step / step_each_epoch)
            decayed_lr = learning_rate * 0.5 * (
                ops.cos(cur_epoch * math.pi / epochs) + 1)
            return decayed_lr


def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
    """

    This operator uses the linear learning rate warm-up strategy to adjust the learning rate preliminarily before the normal learning rate scheduling.
    For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_
    
    When global_step < warmup_steps, learning rate is updated as:
    
    .. code-block:: text
    
            linear_step = end_lr - start_lr
            lr = start_lr + linear_step * (global_step / warmup_steps)
    
    where start_lr is the initial learning rate, and end_lr is the final learning rate;
    
    When global_step >= warmup_steps, learning rate is updated as:
    
    .. code-block:: text
    
            lr = learning_rate
    
    where lr is the learning_rate after warm-up.
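
    As a quick numerical sketch (using the same illustrative values as the example below):

    .. code-block:: python

        start_lr, end_lr, warmup_steps = 1. / 3., 0.1, 50
        linear_step = end_lr - start_lr
        # step 0: lr == start_lr (~0.3333); step 25: halfway to end_lr (~0.2167);
        # from step 50 on, lr equals the wrapped learning_rate.
        lr_at_0 = start_lr + linear_step * (0.0 / warmup_steps)
        lr_at_25 = start_lr + linear_step * (25.0 / warmup_steps)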
    
    Args:
        learning_rate (Variable|float): The learning rate to be used after warm-up. It can be a 1D-Tensor or a single value with the data type of float32.
        warmup_steps (int): Steps for warm up.
        start_lr (float): Initial learning rate of warm up.
        end_lr (float): Final learning rate of warm up.

    Returns:
        Variable: Warm-up learning rate with the same data type as learning_rate.
    
    
    Examples:
    
    .. code-block:: python
    
        import paddle.fluid as fluid
    
        boundaries = [100, 200]
        lr_steps = [0.1, 0.01, 0.001]
        learning_rate = fluid.layers.piecewise_decay(boundaries, lr_steps) #case1, 1D-Tensor
        #learning_rate = 0.1  #case2, single-value
        warmup_steps = 50
        start_lr = 1. / 3.
        end_lr = 0.1
        decayed_lr = fluid.layers.linear_lr_warmup(learning_rate,
            warmup_steps, start_lr, end_lr)
    
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(fluid.default_startup_program())
        out, = exe.run(fetch_list=[decayed_lr.name])
        print(out)
        # case1: [0.33333334]
        # case2: [0.33333334]
    """
    dtype = 'float32'
    if isinstance(learning_rate, Variable):
        dtype = learning_rate.dtype

    linear_step = float(end_lr) - float(start_lr)
    with default_main_program()._lr_schedule_guard():

        if in_dygraph_mode():
            lr = imperate_lr.LinearLrWarmup(learning_rate, warmup_steps,
                                            start_lr, end_lr)
            return lr
        else:
            lr = tensor.create_global_var(
                shape=[1],
                value=0.0,
                dtype=dtype,
                persistable=True,
                name="learning_rate_warmup")

            global_step = _decay_step_counter()

            with control_flow.Switch() as switch:
                with switch.case(global_step < warmup_steps):
                    decayed_lr = start_lr + linear_step * (global_step /
                                                           float(warmup_steps))
                    tensor.assign(decayed_lr, lr)
                with switch.default():
                    if not isinstance(learning_rate, Variable):
                        learning_rate = tensor.fill_constant(
                            shape=[1], dtype=dtype, value=float(learning_rate))
                    tensor.assign(learning_rate, lr)
            return lr