learning_rate_scheduler.py 20.7 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Y
yuyang18 已提交
14 15 16 17 18 19 20 21
"""
When training a model, it is often useful to decay the
learning rate as training progresses; this is called
learning rate decay. There are many strategies for doing
this, and this module provides several classical ones.
Users can also implement their own learning rate decay
strategies following the patterns in this module.
"""
Q
Qiao Longfei 已提交
22

23
import math
Q
qingqing01 已提交
24
import numbers
25

26
import paddle
27 28 29
from . import control_flow
from . import nn
from . import tensor
姜永久 已提交
30 31 32 33 34 35 36
from ..framework import (
    default_main_program,
    Parameter,
    unique_name,
    name_scope,
    in_dygraph_mode,
)
Q
qingqing01 已提交
37
from ..framework import Variable
M
minqiyang 已提交
38
from ..dygraph import learning_rate_scheduler as imperate_lr
39
from ..data_feeder import check_variable_and_dtype, check_type
Q
Qiao Longfei 已提交
40

41
# Public learning-rate decay functions exported by this module.
__all__ = [
    'exponential_decay',
    'natural_exp_decay',
    'inverse_time_decay',
    'polynomial_decay',
    'piecewise_decay',
    'noam_decay',
    'cosine_decay',
    'linear_lr_warmup',
]
Q
Qiao Longfei 已提交
51 52


53
def _decay_step_counter(begin=0):
    """
    Return the step counter used by the static-graph decay functions.

    Wraps ``nn.autoincreased_step_counter`` with the counter name
    ``'@LR_DECAY_COUNTER@'`` and an increment of 1, then casts the result to
    float32 so it can be combined with float learning-rate arithmetic.

    Args:
        begin(int): Initial value of the counter. Default: 0.

    Returns:
        Variable: The current step, cast to float32.
    """
    # With the default begin=0 the first global step is zero.
    global_step = nn.autoincreased_step_counter(
        counter_name='@LR_DECAY_COUNTER@', begin=begin, step=1
    )
    global_step = paddle.cast(global_step, 'float32')
    return global_step


62
def noam_decay(d_model, warmup_steps, learning_rate=1.0):
    """
    Noam decay method, as used in `Attention Is All You Need
    <https://arxiv.org/pdf/1706.03762.pdf>`_.

    The learning rate grows linearly for the first ``warmup_steps`` steps and
    afterwards decays proportionally to the inverse square root of the step
    number. An equivalent NumPy computation is:

    .. code-block:: python

      import numpy as np
      # set hyper parameters
      base_lr = 0.01
      d_model = 2
      current_steps = 20
      warmup_steps = 200
      # compute
      lr_value = base_lr * np.power(d_model, -0.5) * np.min([
                              np.power(current_steps, -0.5),
                              np.power(warmup_steps, -1.5) * current_steps])

    Args:
        d_model(int|float): The dimensionality of the model's input and output.

        warmup_steps(int|float): The number of warm-up steps (a hyperparameter).

        learning_rate(Variable|float|int): The initial learning rate. If it is
            a Variable, it is a tensor with shape [1] and data type float32 or
            float64. It can also be a Python int or float. Default: 1.0.

    Returns:
        In dygraph mode, a ``NoamDecay`` object from the dygraph
        ``learning_rate_scheduler`` module; otherwise a Variable holding the
        decayed learning rate.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          warmup_steps = 100
          learning_rate = 0.01
          lr = fluid.layers.learning_rate_scheduler.noam_decay(
                         1/(warmup_steps *(learning_rate ** 2)),
                         warmup_steps,
                         learning_rate)
    """
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            decay = imperate_lr.NoamDecay(
                d_model, warmup_steps, learning_rate=learning_rate
            )
            return decay
        else:
            # Start counting at 1: global_step ** -0.5 is infinite at step 0.
            global_step = _decay_step_counter(1)

            # Inverse-sqrt decay term and linear warm-up term; the smaller
            # one wins, so warm-up dominates while step < warmup_steps.
            a = global_step**-0.5
            b = (warmup_steps**-1.5) * global_step
            lr_value = learning_rate * (d_model**-0.5) * paddle.minimum(a, b)

            return lr_value
120 121


Y
Yu Yang 已提交
122
def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """
    Applies exponential decay to the learning rate.

    When training a model, it is often recommended to lower the learning rate as
    training progresses. With this function the learning rate is multiplied by
    ``decay_rate`` once every ``decay_steps`` steps.

    The decayed learning rate is computed as follows:

    .. code-block:: text

        if staircase == True:
            decayed_learning_rate = learning_rate * decay_rate ^ floor(global_step / decay_steps)
        else:
            decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)

    Args:
        learning_rate(Variable|float): The initial learning rate. It should be a
            Variable or a float.
        decay_steps(int): The learning rate decay steps. See the formula above.
        decay_rate(float): The learning rate decay rate. See the formula above.
        staircase(bool): If True, decay the learning rate at discrete intervals,
            i.e. multiply it by ``decay_rate`` every ``decay_steps`` steps. If
            False, the learning rate decays continuously following the formula
            above. Default: False.

    Returns:
        In dygraph mode, an ``ExponentialDecay`` object from the dygraph
        ``learning_rate_scheduler`` module; otherwise a Variable holding the
        decayed learning rate. The data type is float32.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle

          paddle.enable_static()
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.exponential_decay(
                    learning_rate=base_lr,
                    decay_steps=10000,
                    decay_rate=0.5,
                    staircase=True))

    """
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            decay = imperate_lr.ExponentialDecay(
                learning_rate, decay_steps, decay_rate, staircase
            )
            return decay
        else:
            global_step = _decay_step_counter()

            div_res = global_step / decay_steps
            if staircase:
                # Only advance the exponent at whole multiples of decay_steps.
                div_res = paddle.floor(div_res)
            decayed_lr = learning_rate * (decay_rate**div_res)

            return decayed_lr
Q
Qiao Longfei 已提交
182 183


Y
Yu Yang 已提交
184
def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """
    Applies natural exponential decay to the initial learning rate.

    When training a model, it is often recommended to lower the learning rate as
    training progresses. With this function the learning rate is multiplied by
    ``exp(-decay_rate)`` once every ``decay_steps`` steps.

    The decayed learning rate is computed as follows:

    .. code-block:: text

        if not staircase:
            decayed_learning_rate = learning_rate * exp(- decay_rate * (global_step / decay_steps))
        else:
            decayed_learning_rate = learning_rate * exp(- decay_rate * floor(global_step / decay_steps))

    Args:
        learning_rate(Variable|float): The initial learning rate. It should be a
            Variable or a float.
        decay_steps(int): The learning rate decay steps. See the formula above.
        decay_rate(float): The learning rate decay rate. See the formula above.
        staircase(bool): If True, decay the learning rate at discrete intervals,
            i.e. multiply it by ``exp(-decay_rate)`` every ``decay_steps`` steps.
            If False, the learning rate decays continuously following the
            formula above. Default: False.

    Returns:
        In dygraph mode, a ``NaturalExpDecay`` object from the dygraph
        ``learning_rate_scheduler`` module; otherwise a Variable holding the
        decayed learning rate. The data type is float32.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle

          paddle.enable_static()
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.natural_exp_decay(
                    learning_rate=base_lr,
                    decay_steps=10000,
                    decay_rate=0.5,
                    staircase=True))

    """
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            decay = imperate_lr.NaturalExpDecay(
                learning_rate, decay_steps, decay_rate, staircase
            )
            return decay
        else:
            global_step = _decay_step_counter()

            div_res = global_step / decay_steps
            if staircase:
                # Only advance the exponent at whole multiples of decay_steps.
                div_res = paddle.floor(div_res)
            decayed_lr = learning_rate * paddle.exp(-1 * decay_rate * div_res)

            return decayed_lr
Q
Qiao Longfei 已提交
244 245


Y
Yu Yang 已提交
246
def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """
    Applies inverse time decay to the initial learning rate.

    When training a model, it is often recommended to lower the learning rate as
    training progresses. With this function an inverse-time decay function is
    applied to the initial learning rate.

    The decayed learning rate is computed as follows:

    .. code-block:: text

        if staircase == True:
            decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / decay_steps))
        else:
            decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / decay_steps)

    Args:
        learning_rate(Variable|float): The initial learning rate. It should be a
            Variable or a float.
        decay_steps(int): The learning rate decay steps. See the formula above.
        decay_rate(float): The learning rate decay rate. See the formula above.
        staircase(bool): If True, decay the learning rate at discrete intervals,
            i.e. the denominator grows by ``decay_rate`` every ``decay_steps``
            steps. If False, the learning rate decays continuously following
            the formula above. Default: False.

    Returns:
        In dygraph mode, an ``InverseTimeDecay`` object from the dygraph
        ``learning_rate_scheduler`` module; otherwise a Variable holding the
        decayed learning rate. The data type is float32.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle
          paddle.enable_static()
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.inverse_time_decay(
                    learning_rate=base_lr,
                    decay_steps=10000,
                    decay_rate=0.5,
                    staircase=True))
    """
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            decay = imperate_lr.InverseTimeDecay(
                learning_rate, decay_steps, decay_rate, staircase
            )
            return decay
        else:
            global_step = _decay_step_counter()

            div_res = global_step / decay_steps
            if staircase:
                # Only grow the denominator at whole multiples of decay_steps.
                div_res = paddle.floor(div_res)

            decayed_lr = learning_rate / (1 + decay_rate * div_res)

            return decayed_lr
305 306


307 308 309
def polynomial_decay(
    learning_rate, decay_steps, end_learning_rate=0.0001, power=1.0, cycle=False
):
    """
    Applies polynomial decay to the initial learning rate.

    .. code-block:: text

     if cycle:
       decay_steps = decay_steps * ceil(global_step / decay_steps)
     else:
       global_step = min(global_step, decay_steps)
     decayed_learning_rate = (learning_rate - end_learning_rate) *
          (1 - global_step / decay_steps) ^ power + end_learning_rate

    When ``cycle`` is True and ``global_step`` is 0, the multiplier
    ``ceil(global_step / decay_steps)`` is replaced by 1 so that
    ``decay_steps`` never becomes 0.

    Args:
        learning_rate(Variable|float): A scalar float32 value or a Variable. This
          will be the initial learning rate during training.
        decay_steps(int): A Python `int` number.
        end_learning_rate(float): A Python `float` number; the learning rate
          reached at the end of a decay period. Default: 0.0001.
        power(float): A Python `float` number; the polynomial exponent.
          Default: 1.0.
        cycle(bool): If True, restart the decay every ``decay_steps`` steps;
          otherwise the learning rate stays at ``end_learning_rate`` once
          ``decay_steps`` is reached. Default: False.

    Returns:
        In dygraph mode, a ``PolynomialDecay`` object from the dygraph
        ``learning_rate_scheduler`` module; otherwise a Variable holding the
        decayed learning rate.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          start_lr = 0.01
          total_step = 5000
          end_lr = 0
          lr = fluid.layers.polynomial_decay(
              start_lr, total_step, end_lr, power=1)

    """
    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            decay = imperate_lr.PolynomialDecay(
                learning_rate, decay_steps, end_learning_rate, power, cycle
            )
            return decay
        else:
            global_step = _decay_step_counter()

            if cycle:
                div_res = paddle.ceil(global_step / decay_steps)
                zero_var = tensor.fill_constant(
                    shape=[1], dtype='float32', value=0.0
                )
                one_var = tensor.fill_constant(
                    shape=[1], dtype='float32', value=1.0
                )

                # At step 0, ceil(0 / decay_steps) == 0 would make the
                # effective decay_steps zero; force the multiplier to 1.
                with control_flow.Switch() as switch:
                    with switch.case(global_step == zero_var):
                        paddle.assign(one_var, output=div_res)
                decay_steps = decay_steps * div_res
            else:
                # Clamp the step so the rate holds at end_learning_rate.
                decay_steps_var = tensor.fill_constant(
                    shape=[1], dtype='float32', value=float(decay_steps)
                )
                global_step = paddle.minimum(x=global_step, y=decay_steps_var)

            decayed_lr = (learning_rate - end_learning_rate) * (
                (1 - global_step / decay_steps) ** power
            ) + end_learning_rate
            return decayed_lr
376 377


Y
Yu Yang 已提交
378
def piecewise_decay(boundaries, values):
    """
    Applies piecewise decay to the initial learning rate.

    ``values[i]`` is used while ``step < boundaries[i]`` (checking the
    boundaries in order); once every boundary has been passed, the last value
    is used. The algorithm can be described as the code below.

    .. code-block:: text

      boundaries = [10000, 20000]
      values = [1.0, 0.5, 0.1]
      if step < 10000:
          learning_rate = 1.0
      elif 10000 <= step < 20000:
          learning_rate = 0.5
      else:
          learning_rate = 0.1

    Args:
        boundaries: A list of step numbers.
        values: A list of learning rate values that will be picked during
            different step boundaries. Must contain exactly one more element
            than ``boundaries``.

    Returns:
        In dygraph mode, a ``PiecewiseDecay`` object from the dygraph
        ``learning_rate_scheduler`` module; otherwise a persistable float32
        Variable named ``"learning_rate"`` holding the decayed learning rate.

    Raises:
        ValueError: If ``len(values) - len(boundaries) != 1``.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle
          paddle.enable_static()
          boundaries = [10000, 20000]
          values = [1.0, 0.5, 0.1]
          optimizer = fluid.optimizer.Momentum(
              momentum=0.9,
              learning_rate=fluid.layers.piecewise_decay(boundaries=boundaries, values=values),
              regularization=fluid.regularizer.L2Decay(1e-4))

    """
    with default_main_program()._lr_schedule_guard():
        if len(values) - len(boundaries) != 1:
            raise ValueError("len(values) - len(boundaries) should be 1")

        if in_dygraph_mode():
            decay = imperate_lr.PiecewiseDecay(boundaries, values, 0)
            return decay
        else:
            global_step = _decay_step_counter()

            lr = paddle.static.create_global_var(
                shape=[1],
                value=0.0,
                dtype='float32',
                persistable=True,
                name="learning_rate",
            )

            with control_flow.Switch() as switch:
                # zip stops at the last boundary; values[-1] is handled by
                # the default branch below.
                for boundary, value in zip(boundaries, values):
                    boundary_val = tensor.fill_constant(
                        shape=[1],
                        dtype='float32',
                        value=float(boundary),
                        force_cpu=True,
                    )
                    with switch.case(global_step < boundary_val):
                        tensor.fill_constant(
                            shape=[1],
                            dtype="float32",
                            value=float(value),
                            out=lr,
                        )
                with switch.default():
                    tensor.fill_constant(
                        shape=[1],
                        dtype="float32",
                        value=float(values[-1]),
                        out=lr,
                    )

            return lr
W
Wu Yi 已提交
460 461


S
shippingwang 已提交
462
def cosine_decay(learning_rate, step_each_epoch, epochs):
    r"""
    Applies cosine decay to the learning rate.

    When training a model, it is often recommended to lower the learning rate as
    training progresses. With this function the learning rate is decayed
    following a cosine schedule over ``epochs`` epochs:

    .. math::

        decayed\_lr = learning\_rate * 0.5 * (\cos(epoch * \frac{\pi}{epochs}) + 1)

    where ``epoch = floor(global_step / step_each_epoch)``.

    Args:
        learning_rate(Variable|float): The initial learning rate. In static
            graph mode it is validated with ``check_type`` against
            ``float`` / ``Variable``.
        step_each_epoch(int): The number of steps in an epoch.
        epochs(int): The number of epochs.

    Returns:
        In dygraph mode, a ``CosineDecay`` object from the dygraph
        ``learning_rate_scheduler`` module; otherwise a Variable holding the
        decayed learning rate.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            base_lr = 0.1
            lr = fluid.layers.cosine_decay(
                learning_rate=base_lr, step_each_epoch=10000, epochs=120)
    """
    check_type(
        learning_rate, 'learning_rate', (float, tensor.Variable), 'cosine_decay'
    )

    with default_main_program()._lr_schedule_guard():
        if in_dygraph_mode():
            decay = imperate_lr.CosineDecay(
                learning_rate, step_each_epoch, epochs
            )
            return decay
        else:
            global_step = _decay_step_counter()

            # The schedule advances once per whole epoch, not per step.
            cur_epoch = paddle.floor(global_step / step_each_epoch)
            decayed_lr = (
                learning_rate
                * 0.5
                * (paddle.cos(cur_epoch * math.pi / epochs) + 1)
            )
            return decayed_lr
S
shippingwang 已提交
511 512


513 514
def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
    """
    This operator uses a linear learning rate warm-up strategy to adjust the
    learning rate before the normal learning rate schedule takes over.
    For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_

    When global_step < warmup_steps, the learning rate is updated as:

    .. code-block:: text

            linear_step = end_lr - start_lr
            lr = start_lr + linear_step * (global_step / warmup_steps)

    where start_lr is the initial learning rate and end_lr is the final
    learning rate of the warm-up.

    When global_step >= warmup_steps, the learning rate is updated as:

    .. code-block:: text

            lr = learning_rate

    where lr is the learning rate after warm-up.

    Args:
        learning_rate (Variable|float): Learning rate after warm-up. It can be a
            1D-Tensor or a single value with the data type float32.
        warmup_steps (int): Number of warm-up steps.
        start_lr (float): Initial learning rate of the warm-up.
        end_lr (float): Final learning rate of the warm-up.

    Returns:
        In dygraph mode, a ``LinearLrWarmup`` object from the dygraph
        ``learning_rate_scheduler`` module; otherwise a persistable Variable
        named ``"learning_rate_warmup"`` with the same data type as
        ``learning_rate`` (float32 when ``learning_rate`` is a Python number).

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid

            boundaries = [100, 200]
            lr_steps = [0.1, 0.01, 0.001]
            learning_rate = fluid.layers.piecewise_decay(boundaries, lr_steps) #case1, 1D-Tensor
            #learning_rate = 0.1  #case2, single-value
            warmup_steps = 50
            start_lr = 1. / 3.
            end_lr = 0.1
            decayed_lr = fluid.layers.linear_lr_warmup(learning_rate,
                warmup_steps, start_lr, end_lr)

            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())
            out, = exe.run(fetch_list=[decayed_lr.name])
            print(out)
            # case1: [0.33333334]
            # case2: [0.33333334]
    """
    # The output variable follows learning_rate's dtype when it is a Variable.
    dtype = 'float32'
    if isinstance(learning_rate, Variable):
        dtype = learning_rate.dtype

    linear_step = float(end_lr) - float(start_lr)
    with default_main_program()._lr_schedule_guard():

        if in_dygraph_mode():
            lr = imperate_lr.LinearLrWarmup(
                learning_rate, warmup_steps, start_lr, end_lr
            )
            return lr
        else:
            lr = paddle.static.create_global_var(
                shape=[1],
                value=0.0,
                dtype=dtype,
                persistable=True,
                name="learning_rate_warmup",
            )

            global_step = _decay_step_counter()

            with control_flow.Switch() as switch:
                with switch.case(global_step < warmup_steps):
                    # Linear interpolation from start_lr towards end_lr.
                    decayed_lr = start_lr + linear_step * (
                        global_step / float(warmup_steps)
                    )
                    paddle.assign(decayed_lr, lr)
                with switch.default():
                    # A Python number must be materialized as a tensor
                    # before it can be assigned into lr.
                    if not isinstance(learning_rate, Variable):
                        learning_rate = tensor.fill_constant(
                            shape=[1], dtype=dtype, value=float(learning_rate)
                        )
                    paddle.assign(learning_rate, lr)
            return lr