learning_rate_scheduler.py 20.9 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Y
yuyang18 已提交
14 15 16 17 18 19 20 21
"""
When training a model, it is often useful to decay the
learning rate during the training process; this is called
learning rate decay. There are many strategies for doing
this, and this module provides some classical methods.
Users can also implement their own learning rate decay
strategies following the patterns in this module.
"""
Q
Qiao Longfei 已提交
22

23
import math
Q
qingqing01 已提交
24
import numbers
25

26 27 28 29
from . import control_flow
from . import nn
from . import ops
from . import tensor
30
from ..framework import default_main_program, Parameter, unique_name, name_scope
Q
qingqing01 已提交
31
from ..framework import Variable
J
Jiabin Yang 已提交
32
from ..framework import _non_static_mode
M
minqiyang 已提交
33
from ..dygraph import learning_rate_scheduler as imperate_lr
34
from ..data_feeder import check_variable_and_dtype, check_type
Q
Qiao Longfei 已提交
35

36 37
__all__ = [
    'exponential_decay',
    'natural_exp_decay',
    'inverse_time_decay',
    'polynomial_decay',
    'piecewise_decay',
    'noam_decay',
    'cosine_decay',
    'linear_lr_warmup',
]
Q
Qiao Longfei 已提交
41 42


43
def _decay_step_counter(begin=0):
    """Return the learning-rate decay step counter as a float32 variable.

    Every scheduler in this module calls this helper, and the counter is
    always requested under the same name (``@LR_DECAY_COUNTER@``), advancing
    by 1 per step.

    Args:
        begin (int): Initial value of the counter. With the default, the first
            global step is zero. Default: 0.

    Returns:
        Variable: The step counter cast to float32.
    """
    counter = nn.autoincreased_step_counter(counter_name='@LR_DECAY_COUNTER@',
                                            begin=begin,
                                            step=1)
    # Cast so the counter can be mixed with floating-point decay arithmetic.
    return tensor.cast(counter, 'float32')


51
def noam_decay(d_model, warmup_steps, learning_rate=1.0):
    """

    Noam decay method. A numpy implementation of noam decay is as follows.

    .. code-block:: python

      import paddle.fluid as fluid
      import numpy as np
      # set hyper parameters
      base_lr = 0.01
      d_model = 2
      current_steps = 20
      warmup_steps = 200
      # compute
      lr_value = base_lr * np.power(d_model, -0.5) * np.min([
          np.power(current_steps, -0.5),
          np.power(warmup_steps, -1.5) * current_steps])

    Please refer to `attention is all you need
    <https://arxiv.org/pdf/1706.03762.pdf>`_.

    Args:
        d_model(Variable|int|float): The dimensionality of the input and output
            of the model.

        warmup_steps(Variable|int): The number of warm-up steps (a
            hyperparameter).

        learning_rate(Variable|float|int): The initial learning rate. If the type
            is Variable, it's a tensor with shape [1], the data type can be
            float32 or float64. It also can be set to python int number. Default 1.0

    Returns:
        The decayed learning rate. In dynamic-graph mode this is the object
        returned by ``imperate_lr.NoamDecay``; in static-graph mode it is the
        computed learning-rate Variable.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          warmup_steps = 100
          learning_rate = 0.01
          lr = fluid.layers.learning_rate_scheduler.noam_decay(
                         1/(warmup_steps *(learning_rate ** 2)),
                         warmup_steps,
                         learning_rate)
    """
    with default_main_program()._lr_schedule_guard():
        if _non_static_mode():
            return imperate_lr.NoamDecay(d_model,
                                         warmup_steps,
                                         learning_rate=learning_rate)

        # The counter starts at 1 because step ** -0.5 is not finite at step 0.
        step = _decay_step_counter(1)

        inv_sqrt_step = step**-0.5
        warmup_ramp = (warmup_steps**-1.5) * step

        return learning_rate * (d_model**-0.5) * nn.elementwise_min(
            inv_sqrt_step, warmup_ramp)
110 111


Y
Yu Yang 已提交
112
def exponential_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """

    Applies exponential decay to the learning rate.

    When training a model, it is often recommended to lower the learning rate as the
    training progresses. By using this function, the learning rate will be decayed by
    'decay_rate' every 'decay_steps' steps.

    Decayed learning rate calculates as follows:

    >>> if staircase == True:
    >>>     decayed_learning_rate = learning_rate * decay_rate ^ floor(global_step / decay_steps)
    >>> else:
    >>>     decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)

    Args:
        learning_rate(Variable|float): The initial learning rate. It should be a Variable
                                       or a float
        decay_steps(int): The learning rate decay steps. See the decay computation above.
        decay_rate(float): The learning rate decay rate. See the decay computation above.
        staircase(bool): If True, decay the learning rate at discrete intervals, which
                         means the learning rate will be decayed by `decay_rate` every
                         `decay_steps`. If False, learning rate will be decayed continuously
                         and following the formula above. Default: False

    Returns:
        Variable: The decayed learning rate. The data type is float32.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle

          paddle.enable_static()
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.exponential_decay(
                  learning_rate=base_lr,
                  decay_steps=10000,
                  decay_rate=0.5,
                  staircase=True))

    """
    with default_main_program()._lr_schedule_guard():
        if _non_static_mode():
            return imperate_lr.ExponentialDecay(learning_rate, decay_steps,
                                                decay_rate, staircase)

        # Fraction of decay periods elapsed; floored in staircase mode so the
        # rate only changes every `decay_steps` steps.
        progress = _decay_step_counter() / decay_steps
        if staircase:
            progress = ops.floor(progress)

        return learning_rate * (decay_rate**progress)
Q
Qiao Longfei 已提交
171 172


Y
Yu Yang 已提交
173
def natural_exp_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """

    Applies natural exponential decay to the initial learning rate.

    When training a model, it is often recommended to lower the learning rate as the
    training progresses. By using this function, the learning rate will be decayed by
    natural exponential power 'decay_rate' every 'decay_steps' steps.

    Decayed learning rate calculates as follows:

    >>> if not staircase:
    >>>     decayed_learning_rate = learning_rate * exp(- decay_rate * (global_step / decay_steps))
    >>> else:
    >>>     decayed_learning_rate = learning_rate * exp(- decay_rate * floor(global_step / decay_steps))

    Args:
        learning_rate(Variable|float): The initial learning rate. It should be a Variable
                                       or a float
        decay_steps(int): The learning rate decay steps. See the decay computation above.
        decay_rate(float): The learning rate decay rate. See the decay computation above.
        staircase(bool): If True, decay the learning rate at discrete intervals, which
                         means the learning rate will be decayed by natural exponential power
                         `decay_rate` every `decay_steps`. If False, learning rate will be
                         decayed continuously and following the formula above. Default: False

    Returns:
        The decayed learning rate. The data type is float32.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle

          paddle.enable_static()
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.natural_exp_decay(
                  learning_rate=base_lr,
                  decay_steps=10000,
                  decay_rate=0.5,
                  staircase=True))

    """
    with default_main_program()._lr_schedule_guard():
        if _non_static_mode():
            return imperate_lr.NaturalExpDecay(learning_rate, decay_steps,
                                               decay_rate, staircase)

        # Elapsed decay periods, floored in staircase mode.
        progress = _decay_step_counter() / decay_steps
        if staircase:
            progress = ops.floor(progress)

        return learning_rate * ops.exp(-1 * decay_rate * progress)
Q
Qiao Longfei 已提交
232 233


Y
Yu Yang 已提交
234
def inverse_time_decay(learning_rate, decay_steps, decay_rate, staircase=False):
    """

    Applies inverse time decay to the initial learning rate.

    When training a model, it is often recommended to lower the learning rate as the
    training progresses. By using this function, an inverse decay function will be
    applied to the initial learning rate.

    Decayed learning rate calculates as follows:

    >>> if staircase == True:
    >>>     decayed_learning_rate = learning_rate / (1 + decay_rate * floor(global_step / decay_steps))
    >>> else:
    >>>     decayed_learning_rate = learning_rate / (1 + decay_rate * global_step / decay_steps)

    Args:
        learning_rate(Variable|float): The initial learning rate. It should be a Variable
                                       or a float
        decay_steps(int): The learning rate decay steps. See the decay computation above.
        decay_rate(float): The learning rate decay rate. See the decay computation above.
        staircase(bool): If True, decay the learning rate at discrete intervals, which
                         means the learning rate will be decayed by `decay_rate` times
                         every `decay_steps`. If False, learning rate will be decayed
                         continuously and following the formula above. Default: False

    Returns:
        Variable: The decayed learning rate. The data type is float32.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle
          paddle.enable_static()
          base_lr = 0.1
          sgd_optimizer = fluid.optimizer.SGD(
              learning_rate=fluid.layers.inverse_time_decay(
                  learning_rate=base_lr,
                  decay_steps=10000,
                  decay_rate=0.5,
                  staircase=True))
    """
    with default_main_program()._lr_schedule_guard():
        if _non_static_mode():
            return imperate_lr.InverseTimeDecay(learning_rate, decay_steps,
                                                decay_rate, staircase)

        # Elapsed decay periods, floored in staircase mode.
        progress = _decay_step_counter() / decay_steps
        if staircase:
            progress = ops.floor(progress)

        return learning_rate / (1 + decay_rate * progress)
292 293 294 295 296 297 298


def polynomial_decay(learning_rate,
                     decay_steps,
                     end_learning_rate=0.0001,
                     power=1.0,
                     cycle=False):
    """
    Applies polynomial decay to the initial learning rate.

    .. code-block:: text

     if cycle:
       decay_steps = decay_steps * ceil(global_step / decay_steps)
     else:
       global_step = min(global_step, decay_steps)
     decayed_learning_rate = (learning_rate - end_learning_rate) *
          (1 - global_step / decay_steps) ^ power + end_learning_rate

    In cycle mode, the ceil factor is treated as 1 at global_step 0 so that
    decay_steps never becomes zero.

    Args:
        learning_rate(Variable|float32): A scalar float32 value or a Variable. This
          will be the initial learning rate during training.
        decay_steps(int32): A Python `int32` number.
        end_learning_rate(float): A Python `float` number.
        power(float): A Python `float` number.
        cycle(bool): If set true, decay the learning rate every decay_steps.

    Returns:
        Variable: The decayed learning rate

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          start_lr = 0.01
          total_step = 5000
          end_lr = 0
          lr = fluid.layers.polynomial_decay(
              start_lr, total_step, end_lr, power=1)

    """
    with default_main_program()._lr_schedule_guard():
        if _non_static_mode():
            return imperate_lr.PolynomialDecay(learning_rate, decay_steps,
                                               end_learning_rate, power, cycle)

        global_step = _decay_step_counter()

        if not cycle:
            # Clamp the step so the rate settles at end_learning_rate once
            # decay_steps have elapsed.
            decay_steps_var = tensor.fill_constant(shape=[1],
                                                   dtype='float32',
                                                   value=float(decay_steps))
            global_step = nn.elementwise_min(x=global_step, y=decay_steps_var)
        else:
            num_cycles = ops.ceil(global_step / decay_steps)
            zero_var = tensor.fill_constant(shape=[1],
                                            dtype='float32',
                                            value=0.0)
            one_var = tensor.fill_constant(shape=[1],
                                           dtype='float32',
                                           value=1.0)

            # ceil(0 / decay_steps) is 0, which would make decay_steps 0 and
            # divide by zero below; force the factor to 1 at step 0.
            with control_flow.Switch() as switch:
                with switch.case(global_step == zero_var):
                    tensor.assign(input=one_var, output=num_cycles)
            decay_steps = decay_steps * num_cycles

        lr_span = learning_rate - end_learning_rate
        remaining = 1 - global_step / decay_steps
        return lr_span * (remaining**power) + end_learning_rate
364 365


Y
Yu Yang 已提交
366
def piecewise_decay(boundaries, values):
    """

    Applies piecewise decay to the initial learning rate.

    The algorithm can be described as the code below.

    .. code-block:: text

      boundaries = [10000, 20000]
      values = [1.0, 0.5, 0.1]
      if step < 10000:
          learning_rate = 1.0
      elif 10000 <= step < 20000:
          learning_rate = 0.5
      else:
          learning_rate = 0.1

    Args:
        boundaries: A list of step numbers.
        values: A list of learning rate values that will be picked during
            different step boundaries. Must contain exactly one more element
            than `boundaries`.

    Returns:
        The decayed learning rate.

    Raises:
        ValueError: If len(values) - len(boundaries) is not 1.

    Examples:
        .. code-block:: python

          import paddle.fluid as fluid
          import paddle
          paddle.enable_static()
          boundaries = [10000, 20000]
          values = [1.0, 0.5, 0.1]
          optimizer = fluid.optimizer.Momentum(
              momentum=0.9,
              learning_rate=fluid.layers.piecewise_decay(boundaries=boundaries, values=values),
              regularization=fluid.regularizer.L2Decay(1e-4))


    """
    with default_main_program()._lr_schedule_guard():
        if len(values) - len(boundaries) != 1:
            raise ValueError("len(values) - len(boundaries) should be 1")

        if _non_static_mode():
            return imperate_lr.PiecewiseDecay(boundaries, values, 0)

        global_step = _decay_step_counter()

        lr = tensor.create_global_var(shape=[1],
                                      value=0.0,
                                      dtype='float32',
                                      persistable=True,
                                      name="learning_rate")

        with control_flow.Switch() as switch:
            # zip stops at len(boundaries), leaving the final value for the
            # default case. The if/elif semantics described in the docstring
            # rely on Switch taking the first matching case.
            for boundary, value in zip(boundaries, values):
                boundary_val = tensor.fill_constant(shape=[1],
                                                    dtype='float32',
                                                    value=float(boundary),
                                                    force_cpu=True)
                with switch.case(global_step < boundary_val):
                    tensor.fill_constant(shape=[1],
                                         dtype="float32",
                                         value=float(value),
                                         out=lr)
            with switch.default():
                tensor.fill_constant(shape=[1],
                                     dtype="float32",
                                     value=float(values[-1]),
                                     out=lr)

        return lr
W
Wu Yi 已提交
441 442


S
shippingwang 已提交
443
def cosine_decay(learning_rate, step_each_epoch, epochs):
    r"""

    Applies cosine decay to the learning rate.

    When training a model, it is often recommended to lower the learning rate as the
    training progresses. By using this function, the learning rate will be decayed by
    following cosine decay strategy.

    .. math::

        decayed\_lr = learning\_rate * 0.5 * (\cos(epoch * \frac{\pi}{epochs}) + 1)

    where :math:`epoch = \lfloor global\_step / step\_each\_epoch \rfloor`.

    Args:
        learning_rate(Variable|float|int): The initial learning rate. Any other
            type is rejected by ``check_type``.
        step_each_epoch(int): the number of steps in an epoch.
        epochs(int): the number of epochs.

    Returns:
        Variable: The decayed learning rate.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            base_lr = 0.1
            lr = fluid.layers.cosine_decay(
                learning_rate=base_lr, step_each_epoch=10000, epochs=120)
    """
    # Accept plain ints as well, consistent with the other schedulers in this
    # module; the arithmetic below works for int, float and Variable alike.
    check_type(learning_rate, 'learning_rate', (float, int, tensor.Variable),
               'cosine_decay')

    with default_main_program()._lr_schedule_guard():
        if _non_static_mode():
            decay = imperate_lr.CosineDecay(learning_rate, step_each_epoch,
                                            epochs)
            return decay
        else:
            global_step = _decay_step_counter()

            # The rate changes once per epoch, not once per step.
            cur_epoch = ops.floor(global_step / step_each_epoch)
            decayed_lr = learning_rate * 0.5 * (
                ops.cos(cur_epoch * math.pi / epochs) + 1)
            return decayed_lr
S
shippingwang 已提交
487 488


489 490
def linear_lr_warmup(learning_rate, warmup_steps, start_lr, end_lr):
    """

    This operator uses a linear learning rate warm-up strategy to adjust the
    learning rate before the normal learning rate scheduling takes over.
    For more information, please refer to `Bag of Tricks for Image Classification with Convolutional Neural Networks <https://arxiv.org/abs/1812.01187>`_

    When global_step < warmup_steps, learning rate is updated as:

    .. code-block:: text

            linear_step = end_lr - start_lr
            lr = start_lr + linear_step * (global_step / warmup_steps)

    where start_lr is the initial learning rate, and end_lr is the final learning rate;

    When global_step >= warmup_steps, learning rate is updated as:

    .. code-block:: text

            lr = learning_rate

    where lr is the learning_rate after warm-up.

    Args:
        learning_rate (Variable|float): Learning_rate after warm-up, it could be 1D-Tensor or single value with the data type of float32.
        warmup_steps (int): Steps for warm up.
        start_lr (float): Initial learning rate of warm up.
        end_lr (float): Final learning rate of warm up.

    Returns:
        Variable: Warm-up learning rate with the same data type as learning_rate.

    Examples:

        .. code-block:: python

            import paddle.fluid as fluid

            boundaries = [100, 200]
            lr_steps = [0.1, 0.01, 0.001]
            learning_rate = fluid.layers.piecewise_decay(boundaries, lr_steps) #case1, 1D-Tensor
            #learning_rate = 0.1  #case2, single-value
            warmup_steps = 50
            start_lr = 1. / 3.
            end_lr = 0.1
            decayed_lr = fluid.layers.linear_lr_warmup(learning_rate,
                warmup_steps, start_lr, end_lr)

            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())
            out, = exe.run(fetch_list=[decayed_lr.name])
            print(out)
            # case1: [0.33333334]
            # case2: [0.33333334]
    """
    # The output variable takes learning_rate's dtype when it is a Variable.
    dtype = 'float32'
    if isinstance(learning_rate, Variable):
        dtype = learning_rate.dtype

    linear_step = float(end_lr) - float(start_lr)
    with default_main_program()._lr_schedule_guard():

        if _non_static_mode():
            lr = imperate_lr.LinearLrWarmup(learning_rate, warmup_steps,
                                            start_lr, end_lr)
            return lr
        else:
            lr = tensor.create_global_var(shape=[1],
                                          value=0.0,
                                          dtype=dtype,
                                          persistable=True,
                                          name="learning_rate_warmup")

            global_step = _decay_step_counter()

            with control_flow.Switch() as switch:
                with switch.case(global_step < warmup_steps):
                    decayed_lr = start_lr + linear_step * (global_step /
                                                           float(warmup_steps))
                    # The step counter is float32, so the warm-up value is
                    # float32; cast it when lr uses another dtype (e.g. a
                    # float64 learning_rate Variable) so the assigned value
                    # matches lr's declared dtype.
                    if decayed_lr.dtype != lr.dtype:
                        decayed_lr = tensor.cast(decayed_lr, lr.dtype)
                    tensor.assign(decayed_lr, lr)
                with switch.default():
                    if not isinstance(learning_rate, Variable):
                        learning_rate = tensor.fill_constant(
                            shape=[1], dtype=dtype, value=float(learning_rate))
                    tensor.assign(learning_rate, lr)
            return lr